Files
AITrading/python-app/app/market/service.py
2026-03-26 14:13:44 +08:00

165 lines
5.4 KiB
Python

from __future__ import annotations
from datetime import datetime, timedelta
from typing import Any
from time import time
import pandas as pd
from market.provider_base import MarketDataProvider
# Spot-quote column names (Chinese headers as returned by the provider), in a
# fixed order. English glosses follow the keys `_build_info` maps them to.
DISPLAY_FIELDS = [
    # identity
    "代码",  # code
    "名称",  # name
    # price action
    "最新价",  # latest price
    "涨跌幅",  # change, percent
    "涨跌额",  # change, absolute
    # activity
    "成交量",  # volume
    "成交额",  # turnover
    "振幅",  # amplitude
    # intraday levels
    "最高",  # high
    "最低",  # low
    "今开",  # today's open
    "昨收",  # previous close
    # valuation
    "市盈率-动态",  # dynamic P/E
    "市净率",  # P/B
    "总市值",  # total market cap
    "流通市值",  # free-float market cap
]
def _safe(v: Any) -> Any:
if v is None:
return None
if isinstance(v, float) and pd.isna(v):
return None
return v
def build_signals(hist: pd.DataFrame, sma_fast: int, sma_slow: int) -> list[dict[str, Any]]:
    """Detect simple-moving-average crossovers in a price history.

    Signal types:
        ``"buy"``: emitted on the bar where the fast SMA moves from at-or-below
            the slow SMA (previous bar) to strictly above it (this bar).
        ``"sell"``: emitted on the bar where it moves from at-or-above to
            strictly below.

    A bar is skipped if either average is undefined on it or on the previous bar.

    Args:
        hist: Frame with a ``"close"`` column. Rows that produce a signal must
            also have a ``"date"`` value supporting ``strftime``.
        sma_fast: Fast moving-average window, in rows.
        sma_slow: Slow moving-average window, in rows.

    Returns:
        Signals in row order, each
        ``{"date": "YYYY-MM-DD", "type": "buy" | "sell", "price": float(close)}``.
        ``hist`` is not modified.

    Note:
        The "before" comparison is inclusive. A bar where both averages are
        exactly equal can therefore trigger a signal even if the fast line then
        returns to the side it came from.
    """
    closes = hist["close"]
    fast_now = closes.rolling(sma_fast).mean()
    slow_now = closes.rolling(sma_slow).mean()
    # Positional shift by one row lines each bar up with its predecessor.
    fast_prev = fast_now.shift(1)
    slow_prev = slow_now.shift(1)

    defined = fast_now.notna() & slow_now.notna() & fast_prev.notna() & slow_prev.notna()
    went_above = defined & (fast_prev <= slow_prev) & (fast_now > slow_now)
    went_below = defined & (fast_prev >= slow_prev) & (fast_now < slow_now)

    above_flags = went_above.to_numpy()
    hit_rows = (went_above | went_below).to_numpy().nonzero()[0]

    out: list[dict[str, Any]] = []
    for pos in hit_rows:
        bar = hist.iloc[pos]
        out.append(
            {
                "date": bar["date"].strftime("%Y-%m-%d"),
                "type": "buy" if above_flags[pos] else "sell",
                "price": float(bar["close"]),
            }
        )
    return out
def resolve_candidates(provider: MarketDataProvider, query: str, limit: int) -> pd.DataFrame:
    """Search the provider's spot quotes for ``query``.

    Thin pass-through to ``provider.search_spot``. ``query`` and ``limit`` are
    forwarded unchanged, and the provider's result is returned as-is.
    """
    matches = provider.search_spot(query=query, limit=limit)
    return matches
def _pick_first_candidate(provider: MarketDataProvider, query: str) -> pd.Series:
    """Return the top spot-quote row matching ``query``.

    Args:
        provider: Market data source used for the lookup.
        query: Search text passed to ``resolve_candidates`` (limited to 1 row).

    Returns:
        The first candidate row. Its ``"代码"`` cell is guaranteed to be
        present, non-missing and non-blank.

    Raises:
        RuntimeError: If nothing matches, or the matched row has no usable code.
    """
    candidates = resolve_candidates(provider, query=query, limit=1)
    if candidates.empty:
        raise RuntimeError(f"未找到匹配股票: {query}")
    row = candidates.iloc[0]
    raw_code = row.get("代码")
    # A missing cell arrives as None/NaN. ``str()`` would turn those into the
    # truthy strings "None"/"nan", which would then be used as a real code.
    missing = raw_code is None or (pd.api.types.is_scalar(raw_code) and pd.isna(raw_code))
    if missing or not str(raw_code).strip():
        raise RuntimeError("行情返回缺少代码字段")
    return row
def _build_info(provider: MarketDataProvider, row: pd.Series) -> dict[str, Any]:
    """Flatten a spot-quote row into an English-keyed info dict.

    Field handling:
        ``code``: stringified and stripped.
        ``name``: stringified.
        Numeric quote fields: passed through ``_safe``, so missing values
            become ``None``.
        ``provider`` / ``channel``: the provider's ``provider_name`` and
            ``channel``, attached for provenance.
    """
    # (output key, provider column) pairs, in output order.
    quote_columns = (
        ("price", "最新价"),
        ("change_pct", "涨跌幅"),
        ("change", "涨跌额"),
        ("volume", "成交量"),
        ("turnover", "成交额"),
        ("amplitude", "振幅"),
        ("high", "最高"),
        ("low", "最低"),
        ("open", "今开"),
        ("prev_close", "昨收"),
        ("pe", "市盈率-动态"),
        ("pb", "市净率"),
        ("market_cap", "总市值"),
        ("float_market_cap", "流通市值"),
    )
    info: dict[str, Any] = {
        "code": str(row.get("代码", "")).strip(),
        "name": str(row.get("名称", "")),
    }
    info.update((out_key, _safe(row.get(column))) for out_key, column in quote_columns)
    info["provider"] = provider.provider_name
    info["channel"] = provider.channel
    return info
def build_realtime_info(provider: MarketDataProvider, query: str) -> dict[str, Any]:
    """Build a live quote snapshot for the best match of ``query``.

    Returns:
        A dict with these keys:

        ``info``: the flattened quote fields of the matched instrument.
        ``realtime_point``: one chart point stamped with today's local date.
            Its ``close`` is the latest price (``None`` if unavailable), and its
            ``volume`` is the latest volume (``0.0`` if unavailable).
        ``source``: always ``"realtime"``.
        ``updated_at``: local wall-clock time of this call.

    Raises:
        RuntimeError: Propagated from the candidate lookup when no usable match exists.
    """
    quote = _build_info(provider=provider, row=_pick_first_candidate(provider, query=query))
    stamp = datetime.now()
    last_price = quote.get("price")
    last_volume = quote.get("volume")
    point = {
        "date": stamp.strftime("%Y-%m-%d"),
        "close": None if last_price is None else float(last_price),
        "volume": 0.0 if last_volume is None else float(last_volume),
    }
    return {
        "info": quote,
        "realtime_point": point,
        "source": "realtime",
        "updated_at": stamp.strftime("%Y-%m-%d %H:%M:%S"),
    }
_DASHBOARD_CACHE: dict[str, tuple[float, dict[str, Any]]] = {}
_DASHBOARD_CACHE_TTL_SECONDS = 15.0
def build_dashboard(
    provider: MarketDataProvider,
    query: str,
    days: int,
    sma_fast: int,
    sma_slow: int,
) -> dict[str, Any]:
    """Assemble quote info, daily chart points and SMA crossover signals for ``query``.

    Results are memoised in ``_DASHBOARD_CACHE`` for
    ``_DASHBOARD_CACHE_TTL_SECONDS``. The cache key combines the provider name,
    channel, resolved code and chart parameters. Because the key needs the
    resolved code, the candidate lookup still runs on every call.

    Args:
        provider: Market data source.
        query: Search text; the first matching instrument is charted.
        days: Number of most recent daily bars to return in ``points``; must be >= 1.
        sma_fast: Fast moving-average window, in bars; must be < ``sma_slow``.
        sma_slow: Slow moving-average window, in bars.

    Returns:
        ``{"info": ..., "points": [...], "signals": [...]}``. Only crossovers
        falling on displayed bars are included in ``signals``. Cached results are
        returned by reference, so callers should not mutate them.

    Raises:
        RuntimeError: On invalid parameters, no matching instrument, or no kline data.
    """
    if sma_fast >= sma_slow:
        raise RuntimeError("sma_fast must be less than sma_slow")
    if days < 1:
        raise RuntimeError("days must be positive")
    row = _pick_first_candidate(provider, query=query)
    code = str(row.get("代码", "")).strip()
    cache_key = "|".join([provider.provider_name, provider.channel, code, str(days), str(sma_fast), str(sma_slow)])
    now = time()
    cached = _DASHBOARD_CACHE.get(cache_key)
    if cached is not None and now - cached[0] <= _DASHBOARD_CACHE_TTL_SECONDS:
        return cached[1]

    end = datetime.now()
    # Fetch extra calendar days beyond ``days`` so the moving averages have warm-up data.
    lookback_days = max(days + (sma_slow * 3), days + 30)
    start = end - timedelta(days=lookback_days)
    history = provider.fetch_daily_kline(code=code, start=start, end=end).reset_index(drop=True)
    if history.empty:
        raise RuntimeError(f"未获取到 K 线数据: {code}")

    # Compute the averages on the FULL history, then trim to ``days``. Trimming
    # first (as before) threw away the warm-up bars fetched above. That left the
    # first sma_slow-1 displayed SMAs empty and hid crossovers near the start.
    history["sma_fast"] = history["close"].rolling(sma_fast).mean()
    history["sma_slow"] = history["close"].rolling(sma_slow).mean()
    window = history.tail(days)
    points = [
        {
            "date": d.strftime("%Y-%m-%d"),
            "close": float(c),
            "sma_fast": _safe(f),
            "sma_slow": _safe(s),
            "volume": float(v),
        }
        for d, c, f, s, v in zip(window["date"], window["close"], window["sma_fast"], window["sma_slow"], window["volume"])
    ]
    # Detect crossovers on the full history too, keeping only those on displayed bars.
    shown_dates = {p["date"] for p in points}
    signals = [
        s for s in build_signals(history, sma_fast=sma_fast, sma_slow=sma_slow) if s["date"] in shown_dates
    ]
    result = {"info": _build_info(provider=provider, row=row), "points": points, "signals": signals}

    # Evict expired entries so the cache cannot grow without bound across distinct
    # keys. Snapshot the items first so iteration never sees the dict resize.
    stale_keys = [k for k, (stored_at, _) in list(_DASHBOARD_CACHE.items()) if now - stored_at > _DASHBOARD_CACHE_TTL_SECONDS]
    for stale_key in stale_keys:
        _DASHBOARD_CACHE.pop(stale_key, None)
    _DASHBOARD_CACHE[cache_key] = (now, result)
    return result