feat: 修复报错
This commit is contained in:
164
python-app/app/market/service.py
Normal file
164
python-app/app/market/service.py
Normal file
@@ -0,0 +1,164 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Any
|
||||
from time import time
|
||||
|
||||
import pandas as pd
|
||||
|
||||
from market.provider_base import MarketDataProvider
|
||||
|
||||
|
||||
DISPLAY_FIELDS = [
|
||||
"代码",
|
||||
"名称",
|
||||
"最新价",
|
||||
"涨跌幅",
|
||||
"涨跌额",
|
||||
"成交量",
|
||||
"成交额",
|
||||
"振幅",
|
||||
"最高",
|
||||
"最低",
|
||||
"今开",
|
||||
"昨收",
|
||||
"市盈率-动态",
|
||||
"市净率",
|
||||
"总市值",
|
||||
"流通市值",
|
||||
]
|
||||
|
||||
|
||||
def _safe(v: Any) -> Any:
|
||||
if v is None:
|
||||
return None
|
||||
if isinstance(v, float) and pd.isna(v):
|
||||
return None
|
||||
return v
|
||||
|
||||
|
||||
def build_signals(hist: pd.DataFrame, sma_fast: int, sma_slow: int) -> list[dict[str, Any]]:
|
||||
frame = hist.copy()
|
||||
frame["sma_fast"] = frame["close"].rolling(sma_fast).mean()
|
||||
frame["sma_slow"] = frame["close"].rolling(sma_slow).mean()
|
||||
signals: list[dict[str, Any]] = []
|
||||
|
||||
for i in range(1, len(frame)):
|
||||
prev = frame.iloc[i - 1]
|
||||
curr = frame.iloc[i]
|
||||
if pd.isna(prev["sma_fast"]) or pd.isna(prev["sma_slow"]) or pd.isna(curr["sma_fast"]) or pd.isna(curr["sma_slow"]):
|
||||
continue
|
||||
cross_up = prev["sma_fast"] <= prev["sma_slow"] and curr["sma_fast"] > curr["sma_slow"]
|
||||
cross_down = prev["sma_fast"] >= prev["sma_slow"] and curr["sma_fast"] < curr["sma_slow"]
|
||||
if cross_up:
|
||||
signals.append({"date": curr["date"].strftime("%Y-%m-%d"), "type": "buy", "price": float(curr["close"])})
|
||||
elif cross_down:
|
||||
signals.append({"date": curr["date"].strftime("%Y-%m-%d"), "type": "sell", "price": float(curr["close"])})
|
||||
return signals
|
||||
|
||||
|
||||
def resolve_candidates(provider: MarketDataProvider, query: str, limit: int) -> pd.DataFrame:
|
||||
return provider.search_spot(query=query, limit=limit)
|
||||
|
||||
|
||||
def _pick_first_candidate(provider: MarketDataProvider, query: str) -> pd.Series:
|
||||
candidates = resolve_candidates(provider, query=query, limit=1)
|
||||
if candidates.empty:
|
||||
raise RuntimeError(f"未找到匹配股票: {query}")
|
||||
row = candidates.iloc[0]
|
||||
code = str(row.get("代码", "")).strip()
|
||||
if not code:
|
||||
raise RuntimeError("行情返回缺少代码字段")
|
||||
return row
|
||||
|
||||
|
||||
def _build_info(provider: MarketDataProvider, row: pd.Series) -> dict[str, Any]:
|
||||
code = str(row.get("代码", "")).strip()
|
||||
return {
|
||||
"code": code,
|
||||
"name": str(row.get("名称", "")),
|
||||
"price": _safe(row.get("最新价")),
|
||||
"change_pct": _safe(row.get("涨跌幅")),
|
||||
"change": _safe(row.get("涨跌额")),
|
||||
"volume": _safe(row.get("成交量")),
|
||||
"turnover": _safe(row.get("成交额")),
|
||||
"amplitude": _safe(row.get("振幅")),
|
||||
"high": _safe(row.get("最高")),
|
||||
"low": _safe(row.get("最低")),
|
||||
"open": _safe(row.get("今开")),
|
||||
"prev_close": _safe(row.get("昨收")),
|
||||
"pe": _safe(row.get("市盈率-动态")),
|
||||
"pb": _safe(row.get("市净率")),
|
||||
"market_cap": _safe(row.get("总市值")),
|
||||
"float_market_cap": _safe(row.get("流通市值")),
|
||||
"provider": provider.provider_name,
|
||||
"channel": provider.channel,
|
||||
}
|
||||
|
||||
|
||||
def build_realtime_info(provider: MarketDataProvider, query: str) -> dict[str, Any]:
|
||||
row = _pick_first_candidate(provider, query=query)
|
||||
info = _build_info(provider=provider, row=row)
|
||||
now = datetime.now()
|
||||
price = info.get("price")
|
||||
volume = info.get("volume")
|
||||
realtime_point = {
|
||||
"date": now.strftime("%Y-%m-%d"),
|
||||
"close": float(price) if price is not None else None,
|
||||
"volume": float(volume) if volume is not None else 0.0,
|
||||
}
|
||||
return {
|
||||
"info": info,
|
||||
"realtime_point": realtime_point,
|
||||
"source": "realtime",
|
||||
"updated_at": now.strftime("%Y-%m-%d %H:%M:%S"),
|
||||
}
|
||||
|
||||
|
||||
_DASHBOARD_CACHE: dict[str, tuple[float, dict[str, Any]]] = {}
|
||||
_DASHBOARD_CACHE_TTL_SECONDS = 15.0
|
||||
|
||||
|
||||
def build_dashboard(
|
||||
provider: MarketDataProvider,
|
||||
query: str,
|
||||
days: int,
|
||||
sma_fast: int,
|
||||
sma_slow: int,
|
||||
) -> dict[str, Any]:
|
||||
if sma_fast >= sma_slow:
|
||||
raise RuntimeError("sma_fast must be less than sma_slow")
|
||||
|
||||
row = _pick_first_candidate(provider, query=query)
|
||||
code = str(row.get("代码", "")).strip()
|
||||
cache_key = "|".join([provider.provider_name, provider.channel, code, str(days), str(sma_fast), str(sma_slow)])
|
||||
now = time()
|
||||
cached = _DASHBOARD_CACHE.get(cache_key)
|
||||
if cached is not None and now - cached[0] <= _DASHBOARD_CACHE_TTL_SECONDS:
|
||||
return cached[1]
|
||||
|
||||
end = datetime.now()
|
||||
lookback_days = max(days + (sma_slow * 3), days + 30)
|
||||
start = end - timedelta(days=lookback_days)
|
||||
hist = provider.fetch_daily_kline(code=code, start=start, end=end).tail(days).reset_index(drop=True)
|
||||
if hist.empty:
|
||||
raise RuntimeError(f"未获取到 K 线数据: {code}")
|
||||
|
||||
signals = build_signals(hist, sma_fast=sma_fast, sma_slow=sma_slow)
|
||||
hist["sma_fast"] = hist["close"].rolling(sma_fast).mean()
|
||||
hist["sma_slow"] = hist["close"].rolling(sma_slow).mean()
|
||||
|
||||
points = [
|
||||
{
|
||||
"date": d.strftime("%Y-%m-%d"),
|
||||
"close": float(c),
|
||||
"sma_fast": _safe(f),
|
||||
"sma_slow": _safe(s),
|
||||
"volume": float(v),
|
||||
}
|
||||
for d, c, f, s, v in zip(hist["date"], hist["close"], hist["sma_fast"], hist["sma_slow"], hist["volume"])
|
||||
]
|
||||
|
||||
result = {"info": _build_info(provider=provider, row=row), "points": points, "signals": signals}
|
||||
_DASHBOARD_CACHE[cache_key] = (now, result)
|
||||
return result
|
||||
Reference in New Issue
Block a user