165 lines
5.4 KiB
Python
165 lines
5.4 KiB
Python
from __future__ import annotations
|
|
|
|
from datetime import datetime, timedelta
|
|
from typing import Any
|
|
from time import time
|
|
|
|
import pandas as pd
|
|
|
|
from market.provider_base import MarketDataProvider
|
|
|
|
|
|
# Column labels of a spot-quote row, in display order. The labels are the
# Chinese column names used by the market data rows handled in this module;
# the English gloss for each is given alongside.
DISPLAY_FIELDS = [
    # Identity
    "代码",  # code
    "名称",  # name
    # Price movement
    "最新价",  # latest price
    "涨跌幅",  # change (percent)
    "涨跌额",  # change (amount)
    # Activity
    "成交量",  # volume
    "成交额",  # turnover
    "振幅",  # amplitude
    # Session range
    "最高",  # high
    "最低",  # low
    "今开",  # today's open
    "昨收",  # previous close
    # Valuation
    "市盈率-动态",  # P/E (dynamic)
    "市净率",  # P/B
    "总市值",  # total market cap
    "流通市值",  # float market cap
]
|
|
|
|
|
|
def _safe(v: Any) -> Any:
|
|
if v is None:
|
|
return None
|
|
if isinstance(v, float) and pd.isna(v):
|
|
return None
|
|
return v
|
|
|
|
|
|
def build_signals(hist: pd.DataFrame, sma_fast: int, sma_slow: int) -> list[dict[str, Any]]:
    """Detect moving-average crossovers in a price history.

    Two simple moving averages of the ``close`` column are computed with
    windows ``sma_fast`` and ``sma_slow``. Each row is compared with the row
    before it: when the fast average moves from at-or-below the slow average
    to strictly above it, a ``"buy"`` signal is emitted; the mirror-image move
    emits a ``"sell"`` signal. Row pairs where any of the four averages is
    missing are skipped. *hist* is not modified.

    Args:
        hist: Frame with a ``close`` column and a ``date`` column whose values
            support ``strftime``.
        sma_fast: Rolling window of the fast average.
        sma_slow: Rolling window of the slow average.

    Returns:
        Signals in row order, each ``{"date": "YYYY-MM-DD", "type":
        "buy" | "sell", "price": <close as float>}``.
    """
    closes = hist["close"]
    fast_line = closes.rolling(sma_fast).mean().tolist()
    slow_line = closes.rolling(sma_slow).mean().tolist()

    found: list[dict[str, Any]] = []
    for pos in range(1, len(hist)):
        before_fast, before_slow = fast_line[pos - 1], slow_line[pos - 1]
        after_fast, after_slow = fast_line[pos], slow_line[pos]

        # A crossover needs both averages defined on both sides of the step.
        if any(pd.isna(x) for x in (before_fast, before_slow, after_fast, after_slow)):
            continue

        if before_fast <= before_slow and after_fast > after_slow:
            kind = "buy"
        elif before_fast >= before_slow and after_fast < after_slow:
            kind = "sell"
        else:
            continue

        # The row is only looked up when a signal actually fires.
        bar = hist.iloc[pos]
        found.append(
            {
                "date": bar["date"].strftime("%Y-%m-%d"),
                "type": kind,
                "price": float(bar["close"]),
            }
        )
    return found
|
|
|
|
|
|
def resolve_candidates(provider: MarketDataProvider, query: str, limit: int) -> pd.DataFrame:
    """Search the provider's spot data for rows matching *query*.

    This is a direct pass-through to ``provider.search_spot``; the meaning of
    *query* and the shape of the returned frame are defined by the provider.

    Args:
        provider: Market data source to search.
        query: Search text forwarded to the provider.
        limit: Row limit forwarded to the provider.

    Returns:
        Whatever frame ``provider.search_spot`` returns.
    """
    matches = provider.search_spot(query=query, limit=limit)
    return matches
|
|
|
|
|
|
def _pick_first_candidate(provider: MarketDataProvider, query: str) -> pd.Series:
    """Return the top search hit for *query*, guaranteed to carry a stock code.

    Args:
        provider: Market data source to search.
        query: Search text forwarded to ``resolve_candidates``.

    Returns:
        The first row of the search result.

    Raises:
        RuntimeError: If the search returns no rows, or if the first row has
            no usable ``代码`` (code) value: absent, null/NaN, or blank.
    """
    candidates = resolve_candidates(provider, query=query, limit=1)
    if candidates.empty:
        raise RuntimeError(f"未找到匹配股票: {query}")

    row = candidates.iloc[0]
    raw_code = row.get("代码")
    # A null cell must be rejected before stringifying: str(None) == "None"
    # and str(nan) == "nan" are non-empty and would pass a plain blank check.
    code_is_null = raw_code is None or (pd.api.types.is_scalar(raw_code) and pd.isna(raw_code))
    if code_is_null or not str(raw_code).strip():
        raise RuntimeError("行情返回缺少代码字段")
    return row
|
|
|
|
|
|
def _build_info(provider: MarketDataProvider, row: pd.Series) -> dict[str, Any]:
    """Translate one spot-quote row into the English-keyed ``info`` payload.

    ``code`` is the stripped string form of the ``代码`` cell and ``name`` the
    string form of the ``名称`` cell (both default to an empty string when the
    column is absent). Every other quote field is read from its column and
    passed through ``_safe``. The provider's ``provider_name`` and ``channel``
    attributes are appended last.

    Args:
        provider: Source of the row; only its ``provider_name`` and
            ``channel`` attributes are read.
        row: A single quote row keyed by the Chinese column labels.

    Returns:
        A dict with keys ``code``, ``name``, the quote fields listed below in
        that order, then ``provider`` and ``channel``.
    """
    # (payload key, source column) pairs, in payload order.
    quote_fields = (
        ("price", "最新价"),
        ("change_pct", "涨跌幅"),
        ("change", "涨跌额"),
        ("volume", "成交量"),
        ("turnover", "成交额"),
        ("amplitude", "振幅"),
        ("high", "最高"),
        ("low", "最低"),
        ("open", "今开"),
        ("prev_close", "昨收"),
        ("pe", "市盈率-动态"),
        ("pb", "市净率"),
        ("market_cap", "总市值"),
        ("float_market_cap", "流通市值"),
    )

    info: dict[str, Any] = {
        "code": str(row.get("代码", "")).strip(),
        "name": str(row.get("名称", "")),
    }
    for key, column in quote_fields:
        info[key] = _safe(row.get(column))
    info["provider"] = provider.provider_name
    info["channel"] = provider.channel
    return info
|
|
|
|
|
|
def build_realtime_info(provider: MarketDataProvider, query: str) -> dict[str, Any]:
    """Build a realtime snapshot payload for the best match of *query*.

    Resolves the first matching quote row, converts it into the ``info``
    payload, and derives a single chart point stamped with the current
    ``datetime.now()`` date.

    Args:
        provider: Market data source used for the lookup.
        query: Search text identifying the stock.

    Returns:
        A dict with:
          * ``info`` - the payload produced by ``_build_info``;
          * ``realtime_point`` - ``{"date", "close", "volume"}`` where
            ``close`` is ``None`` when no price is available and ``volume``
            falls back to ``0.0`` when no volume is available;
          * ``source`` - the literal ``"realtime"``;
          * ``updated_at`` - the same timestamp as ``YYYY-MM-DD HH:MM:SS``.

    Raises:
        RuntimeError: Propagated from ``_pick_first_candidate`` when no usable
            match is found.
    """
    match = _pick_first_candidate(provider, query=query)
    info = _build_info(provider=provider, row=match)
    stamp = datetime.now()

    last_price = info.get("price")
    traded = info.get("volume")

    # A missing price stays None, while a missing volume is reported as zero.
    if last_price is None:
        close_value = None
    else:
        close_value = float(last_price)
    if traded is None:
        volume_value = 0.0
    else:
        volume_value = float(traded)

    snapshot = {
        "date": stamp.strftime("%Y-%m-%d"),
        "close": close_value,
        "volume": volume_value,
    }
    return {
        "info": info,
        "realtime_point": snapshot,
        "source": "realtime",
        "updated_at": stamp.strftime("%Y-%m-%d %H:%M:%S"),
    }
|
|
|
|
|
|
# In-process memo of recently built dashboards:
# cache key -> (time() stamp taken when the build started, payload).
_DASHBOARD_CACHE: dict[str, tuple[float, dict[str, Any]]] = {}
# Entries older than this many seconds are rebuilt instead of being served.
_DASHBOARD_CACHE_TTL_SECONDS = 15.0


def build_dashboard(
    provider: MarketDataProvider,
    query: str,
    days: int,
    sma_fast: int,
    sma_slow: int,
) -> dict[str, Any]:
    """Build the chart payload (quote info, price points, crossover signals).

    The best match for *query* is resolved first. Its daily K-line history is
    then fetched over a padded look-back period, the moving averages and
    crossover signals are computed on that padded history, and the result is
    trimmed to the most recent *days* rows. Computing before trimming lets the
    leading rows of the displayed window carry averages whenever the padding
    contains enough earlier bars.

    Results are memoised per (provider name, channel, code, days, sma_fast,
    sma_slow) for ``_DASHBOARD_CACHE_TTL_SECONDS`` seconds. The cached dict
    itself is returned on a hit, so callers must treat the result as
    read-only.

    Args:
        provider: Market data source; its ``provider_name`` and ``channel``
            attributes and ``fetch_daily_kline`` method are used.
        query: Search text identifying the stock.
        days: Number of most recent history rows to display; must be >= 1.
        sma_fast: Window of the fast moving average; must be >= 1 and less
            than *sma_slow*.
        sma_slow: Window of the slow moving average.

    Returns:
        ``{"info": ..., "points": [...], "signals": [...]}`` where each point
        is ``{"date", "close", "sma_fast", "sma_slow", "volume"}`` (averages
        are ``None`` where undefined) and ``signals`` holds the crossovers
        whose date falls inside the displayed window.

    Raises:
        RuntimeError: On invalid parameters, when no matching stock is found,
            or when the provider returns no K-line rows.
    """
    if sma_fast >= sma_slow:
        raise RuntimeError("sma_fast must be less than sma_slow")
    # A negative `days` would make tail() silently drop rows from the front,
    # and a non-positive window cannot produce a meaningful average.
    if days < 1 or sma_fast < 1:
        raise RuntimeError("days and sma_fast must be positive")

    row = _pick_first_candidate(provider, query=query)
    code = str(row.get("代码", "")).strip()
    cache_key = "|".join(
        [provider.provider_name, provider.channel, code, str(days), str(sma_fast), str(sma_slow)]
    )
    now = time()
    cached = _DASHBOARD_CACHE.get(cache_key)
    if cached is not None and now - cached[0] <= _DASHBOARD_CACHE_TTL_SECONDS:
        return cached[1]

    # Evict expired entries on every miss so the cache cannot grow without
    # bound as distinct codes / parameter combinations are requested.
    expired = [
        key
        for key, (stamp, _payload) in _DASHBOARD_CACHE.items()
        if now - stamp > _DASHBOARD_CACHE_TTL_SECONDS
    ]
    for key in expired:
        _DASHBOARD_CACHE.pop(key, None)

    end = datetime.now()
    lookback_days = max(days + (sma_slow * 3), days + 30)
    start = end - timedelta(days=lookback_days)
    padded = provider.fetch_daily_kline(code=code, start=start, end=end).reset_index(drop=True)
    if padded.empty:
        raise RuntimeError(f"未获取到 K 线数据: {code}")

    # Averages and signals are derived from the padded history *before*
    # trimming, so the extra look-back rows serve as warm-up data.
    all_signals = build_signals(padded, sma_fast=sma_fast, sma_slow=sma_slow)
    padded["sma_fast"] = padded["close"].rolling(sma_fast).mean()
    padded["sma_slow"] = padded["close"].rolling(sma_slow).mean()
    hist = padded.tail(days).reset_index(drop=True)

    points = [
        {
            "date": d.strftime("%Y-%m-%d"),
            "close": float(c),
            "sma_fast": _safe(f),
            "sma_slow": _safe(s),
            "volume": float(v),
        }
        for d, c, f, s, v in zip(
            hist["date"], hist["close"], hist["sma_fast"], hist["sma_slow"], hist["volume"]
        )
    ]

    # Keep only the signals that land on a displayed bar.
    shown_dates = {point["date"] for point in points}
    signals = [signal for signal in all_signals if signal["date"] in shown_dates]

    result = {"info": _build_info(provider=provider, row=row), "points": points, "signals": signals}
    _DASHBOARD_CACHE[cache_key] = (now, result)
    return result
|