feat: 修复报错

This commit is contained in:
Daniel
2026-03-26 14:13:44 +08:00
commit b2223ec058
31 changed files with 17401 additions and 0 deletions

24
python-app/Dockerfile Normal file
View File

@@ -0,0 +1,24 @@
ARG DOCKER_MIRROR_PREFIX=m.daocloud.io/docker.io/library
FROM ${DOCKER_MIRROR_PREFIX}/python:3.11-slim-bookworm
ENV PIP_INDEX_URL=https://mirrors.aliyun.com/pypi/simple/ \
PIP_TRUSTED_HOST=mirrors.aliyun.com \
PYTHONDONTWRITEBYTECODE=1 \
PYTHONUNBUFFERED=1
RUN sed -i 's|deb.debian.org|mirrors.aliyun.com|g' /etc/apt/sources.list.d/debian.sources \
&& sed -i 's|security.debian.org|mirrors.aliyun.com|g' /etc/apt/sources.list.d/debian.sources \
&& apt-get update \
&& apt-get install -y --no-install-recommends \
build-essential \
tzdata \
&& rm -rf /var/lib/apt/lists/*
WORKDIR /app
COPY requirements.txt ./
RUN pip install --no-cache-dir -r requirements.txt
COPY app ./app
CMD ["python", "app/backtest_runner.py", "--help"]

View File

@@ -0,0 +1,142 @@
from __future__ import annotations
import argparse
import ast
import importlib
import json
from pathlib import Path
from typing import Any, Type
import backtrader as bt
import pandas as pd
from data_protocol import KlineMessage
def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(description="Backtrader runner")
parser.add_argument("--input", required=True, help="Path to kline jsonl file")
parser.add_argument(
"--strategy",
default="strategies.sma_cross.SmaCrossStrategy",
help="Strategy class path, e.g. strategies.sma_cross.SmaCrossStrategy",
)
parser.add_argument("--cash", type=float, default=100000.0, help="Initial cash")
parser.add_argument("--commission", type=float, default=0.001, help="Commission")
parser.add_argument(
"--strategy-param",
action="append",
default=[],
help="Strategy parameter in key=value format, repeatable",
)
parser.add_argument("--plot", action="store_true", help="Plot result")
return parser.parse_args()
def load_strategy_class(class_path: str) -> Type[bt.Strategy]:
module_name, class_name = class_path.rsplit(".", 1)
module = importlib.import_module(module_name)
strategy_cls = getattr(module, class_name)
if not issubclass(strategy_cls, bt.Strategy):
raise TypeError(f"{class_path} is not a backtrader.Strategy subclass")
return strategy_cls
def load_dataframe(path: Path) -> pd.DataFrame:
rows = []
for line in path.read_text(encoding="utf-8").splitlines():
line = line.strip()
if not line:
continue
payload = json.loads(line)
msg = KlineMessage.from_dict(payload)
rows.append(
{
"datetime": pd.to_datetime(msg.open_time, unit="ms"),
"open": float(msg.open),
"high": float(msg.high),
"low": float(msg.low),
"close": float(msg.close),
"volume": float(msg.volume),
"openinterest": 0.0,
}
)
if not rows:
raise RuntimeError("no kline rows found for backtest")
frame = pd.DataFrame(rows).set_index("datetime")
return frame.sort_index()
def parse_strategy_params(raw_params: list[str]) -> dict[str, Any]:
parsed: dict[str, Any] = {}
for item in raw_params:
if "=" not in item:
raise ValueError(f"invalid --strategy-param '{item}', expected key=value")
key, raw_value = item.split("=", 1)
key = key.strip()
raw_value = raw_value.strip()
if not key:
raise ValueError(f"invalid --strategy-param '{item}', empty key")
try:
value = ast.literal_eval(raw_value)
except (ValueError, SyntaxError):
value = raw_value
parsed[key] = value
return parsed
def extract_report(result: list[bt.Strategy], final_value: float) -> dict[str, Any]:
strategy = result[0]
drawdown_info = strategy.analyzers.drawdown.get_analysis()
sharpe_info = strategy.analyzers.sharpe.get_analysis()
max_drawdown = drawdown_info.get("max", {}).get("drawdown")
sharpe_ratio = sharpe_info.get("sharperatio")
return {
"final_value": final_value,
"max_drawdown_pct": None if max_drawdown is None else float(max_drawdown),
"sharpe_ratio": None if sharpe_ratio is None else float(sharpe_ratio),
}
def print_report(report: dict[str, Any]) -> None:
print("\n===== Backtest Report =====")
print(f"Final Value : {report['final_value']:.2f}")
max_drawdown = report["max_drawdown_pct"]
sharpe_ratio = report["sharpe_ratio"]
print(f"Max Drawdown (%) : {'N/A' if max_drawdown is None else f'{max_drawdown:.4f}'}")
print(f"Sharpe Ratio : {'N/A' if sharpe_ratio is None else f'{sharpe_ratio:.4f}'}")
print("===========================\n")
def run_backtest() -> None:
args = parse_args()
data_path = Path(args.input)
if not data_path.exists():
raise FileNotFoundError(f"input file not found: {data_path}")
strategy_cls = load_strategy_class(args.strategy)
strategy_params = parse_strategy_params(args.strategy_param)
df = load_dataframe(data_path)
cerebro = bt.Cerebro()
cerebro.addstrategy(strategy_cls, **strategy_params)
cerebro.broker.setcash(args.cash)
cerebro.broker.setcommission(commission=args.commission)
cerebro.adddata(bt.feeds.PandasData(dataname=df))
cerebro.addanalyzer(bt.analyzers.DrawDown, _name="drawdown")
cerebro.addanalyzer(bt.analyzers.SharpeRatio_A, _name="sharpe", riskfreerate=0.0)
print(f"starting portfolio value: {cerebro.broker.getvalue():.2f}")
result = cerebro.run()
final_value = cerebro.broker.getvalue()
print(f"final portfolio value: {final_value:.2f}")
print(f"strategies executed: {len(result)}")
print_report(extract_report(result, final_value))
if args.plot:
cerebro.plot(style="candlestick")
if __name__ == "__main__":
run_backtest()

View File

@@ -0,0 +1,59 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import Any, Dict
@dataclass
class KlineMessage:
type: str
source: str
symbol: str
interval: str
event_time: int
open_time: int
close_time: int
open: str
high: str
low: str
close: str
volume: str
trade_num: int
final: bool
@classmethod
def from_dict(cls, payload: Dict[str, Any]) -> "KlineMessage":
return cls(
type=str(payload["type"]),
source=str(payload["source"]),
symbol=str(payload["symbol"]),
interval=str(payload["interval"]),
event_time=int(payload["event_time"]),
open_time=int(payload["open_time"]),
close_time=int(payload["close_time"]),
open=str(payload["open"]),
high=str(payload["high"]),
low=str(payload["low"]),
close=str(payload["close"]),
volume=str(payload["volume"]),
trade_num=int(payload["trade_num"]),
final=bool(payload["final"]),
)
def to_dict(self) -> Dict[str, Any]:
return {
"type": self.type,
"source": self.source,
"symbol": self.symbol,
"interval": self.interval,
"event_time": self.event_time,
"open_time": self.open_time,
"close_time": self.close_time,
"open": self.open,
"high": self.high,
"low": self.low,
"close": self.close,
"volume": self.volume,
"trade_num": self.trade_num,
"final": self.final,
}

View File

@@ -0,0 +1,113 @@
from __future__ import annotations
import argparse
import json
from collections import OrderedDict
import matplotlib.pyplot as plt
import mplfinance as mpf
import pandas as pd
import redis
from data_protocol import KlineMessage
def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(description="Realtime K-line viewer from Redis")
parser.add_argument("--redis-host", default="127.0.0.1", help="Redis host")
parser.add_argument("--redis-port", type=int, default=6379, help="Redis port")
parser.add_argument("--redis-db", type=int, default=0, help="Redis DB")
parser.add_argument("--redis-password", default="", help="Redis password")
parser.add_argument("--channel", default="kline.stream", help="Pub/Sub channel")
parser.add_argument("--window", type=int, default=80, help="Window size to render")
parser.add_argument("--only-final", action="store_true", help="Render only closed candles")
parser.add_argument("--timeout", type=float, default=1.0, help="Subscriber timeout (seconds)")
return parser.parse_args()
def upsert_message(buffer: OrderedDict[int, KlineMessage], message: KlineMessage, max_size: int) -> None:
buffer[message.open_time] = message
buffer.move_to_end(message.open_time)
while len(buffer) > max_size:
buffer.popitem(last=False)
def to_dataframe(messages: list[KlineMessage]) -> pd.DataFrame:
rows = []
for msg in messages:
rows.append(
{
"Date": pd.to_datetime(msg.open_time, unit="ms"),
"Open": float(msg.open),
"High": float(msg.high),
"Low": float(msg.low),
"Close": float(msg.close),
"Volume": float(msg.volume),
}
)
frame = pd.DataFrame(rows).set_index("Date")
return frame
def main() -> None:
args = parse_args()
plt.ion()
client = redis.Redis(
host=args.redis_host,
port=args.redis_port,
db=args.redis_db,
password=args.redis_password or None,
decode_responses=True,
)
pubsub = client.pubsub(ignore_subscribe_messages=True)
pubsub.subscribe(args.channel)
print(
f"subscribed redis channel={args.channel} "
f"at {args.redis_host}:{args.redis_port}/{args.redis_db}"
)
candles: OrderedDict[int, KlineMessage] = OrderedDict()
try:
while True:
packet = pubsub.get_message(timeout=args.timeout)
if not packet:
plt.pause(0.01)
continue
raw = packet.get("data")
if not isinstance(raw, str):
continue
payload = json.loads(raw)
message = KlineMessage.from_dict(payload)
if args.only_final and not message.final:
continue
upsert_message(candles, message, args.window)
frame = to_dataframe(list(candles.values()))
if frame.empty:
continue
plt.clf()
mpf.plot(
frame,
type="candle",
style="yahoo",
volume=True,
datetime_format="%H:%M",
tight_layout=True,
)
plt.pause(0.001)
except KeyboardInterrupt:
print("viewer stopped by user")
finally:
pubsub.close()
client.close()
plt.ioff()
plt.show()
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,4 @@
from market.config import load_market_config
from market.factory import create_provider
__all__ = ["load_market_config", "create_provider"]

View File

@@ -0,0 +1,40 @@
from __future__ import annotations
import os
from dataclasses import dataclass
@dataclass
class MarketConfig:
channel: str
provider: str
cmes_token: str
futu_host: str
futu_port: int
futu_is_encrypt: bool | None
futu_market: str
def load_market_config() -> MarketConfig:
channel = os.getenv("MARKET_CHANNEL", "cn").strip().lower()
provider = os.getenv("MARKET_PROVIDER", "akshare").strip().lower()
cmes_token = os.getenv("CMES_TOKEN", "").strip()
futu_host = os.getenv("FUTU_HOST", "127.0.0.1").strip()
futu_port = int(os.getenv("FUTU_PORT", "11111").strip())
futu_encrypt_raw = os.getenv("FUTU_IS_ENCRYPT", "").strip().lower()
if futu_encrypt_raw in {"1", "true", "yes", "y"}:
futu_is_encrypt = True
elif futu_encrypt_raw in {"0", "false", "no", "n"}:
futu_is_encrypt = False
else:
futu_is_encrypt = None
futu_market = os.getenv("FUTU_MARKET", "").strip().lower()
return MarketConfig(
channel=channel,
provider=provider,
cmes_token=cmes_token,
futu_host=futu_host,
futu_port=futu_port,
futu_is_encrypt=futu_is_encrypt,
futu_market=futu_market,
)

View File

@@ -0,0 +1,29 @@
from __future__ import annotations
from market.config import MarketConfig
from market.provider_base import MarketDataProvider
from market.providers.akshare_provider import AkshareCnProvider
from market.providers.cmes_provider import CmesCnProvider
def create_provider(config: MarketConfig) -> MarketDataProvider:
if config.channel == "cn" and config.provider == "akshare":
return AkshareCnProvider()
if config.channel == "cn" and config.provider == "cmesdata":
return CmesCnProvider(token=config.cmes_token)
if config.provider == "futu" and config.channel in {"cn", "hk", "us"}:
from market.providers.futu_provider import FutuProvider
return FutuProvider(
channel=config.channel,
host=config.futu_host,
port=config.futu_port,
is_encrypt=config.futu_is_encrypt,
market=config.futu_market,
)
if config.channel in {"us", "hk"}:
raise RuntimeError(
f"channel={config.channel} provider={config.provider} 尚未实现,"
"请新增 Provider 后接入 factory。"
)
raise RuntimeError(f"不支持的通道配置: channel={config.channel}, provider={config.provider}")

View File

@@ -0,0 +1,30 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from datetime import datetime
import pandas as pd
class MarketDataProvider(ABC):
@property
@abstractmethod
def provider_name(self) -> str:
raise NotImplementedError
@property
@abstractmethod
def channel(self) -> str:
raise NotImplementedError
@abstractmethod
def fetch_spot(self) -> pd.DataFrame:
raise NotImplementedError
@abstractmethod
def search_spot(self, query: str, limit: int) -> pd.DataFrame:
raise NotImplementedError
@abstractmethod
def fetch_daily_kline(self, code: str, start: datetime, end: datetime) -> pd.DataFrame:
raise NotImplementedError

View File

@@ -0,0 +1,12 @@
from market.providers.akshare_provider import AkshareCnProvider
from market.providers.cmes_provider import CmesCnProvider
__all__ = ["AkshareCnProvider", "CmesCnProvider"]
try:
from market.providers.futu_provider import FutuProvider
__all__.append("FutuProvider")
except Exception:
# Allow non-futu environments to keep using other providers.
pass

View File

@@ -0,0 +1,63 @@
from __future__ import annotations
from datetime import datetime
import akshare as ak
import pandas as pd
from market.provider_base import MarketDataProvider
class AkshareCnProvider(MarketDataProvider):
@property
def provider_name(self) -> str:
return "akshare"
@property
def channel(self) -> str:
return "cn"
def fetch_spot(self) -> pd.DataFrame:
df = ak.stock_zh_a_spot_em()
if df.empty:
raise RuntimeError("未获取到 A 股实时行情数据")
return df
def search_spot(self, query: str, limit: int) -> pd.DataFrame:
q = query.strip().lower()
if not q:
return pd.DataFrame()
df = self.fetch_spot()
code = df["代码"].astype(str).str.lower()
name = df["名称"].astype(str).str.lower()
exact = df[(code == q) | (name == q)]
if not exact.empty:
return exact.head(limit)
starts = df[code.str.startswith(q) | name.str.startswith(q)]
if not starts.empty:
return starts.head(limit)
return df[code.str.contains(q, na=False) | name.str.contains(q, na=False)].head(limit)
def fetch_daily_kline(self, code: str, start: datetime, end: datetime) -> pd.DataFrame:
hist = ak.stock_zh_a_hist(
symbol=code,
period="daily",
start_date=start.strftime("%Y%m%d"),
end_date=end.strftime("%Y%m%d"),
adjust="qfq",
)
if hist.empty:
raise RuntimeError(f"未获取到 K 线数据: {code}")
frame = pd.DataFrame(
{
"date": pd.to_datetime(hist["日期"]),
"close": pd.to_numeric(hist["收盘"], errors="coerce"),
"volume": pd.to_numeric(hist["成交量"], errors="coerce"),
}
).dropna()
return frame.sort_values("date").reset_index(drop=True)

View File

@@ -0,0 +1,70 @@
from __future__ import annotations
import importlib
import os
from datetime import datetime
import pandas as pd
from market.provider_base import MarketDataProvider
class CmesCnProvider(MarketDataProvider):
def __init__(self, token: str | None = None) -> None:
self._token = token or os.getenv("CMES_TOKEN", "").strip()
self._module = importlib.import_module("cmesdata")
self._login_once()
@property
def provider_name(self) -> str:
return "cmesdata"
@property
def channel(self) -> str:
return "cn"
def _login_once(self) -> None:
if not self._token:
raise RuntimeError("CMES_TOKEN 未配置,无法使用 cmesdata 通道")
self._module.login(self._token)
@staticmethod
def _to_prefixed_code(code: str) -> str:
raw = code.strip().upper().replace(".", "")
if raw.startswith("SH") or raw.startswith("SZ"):
return f"{raw[:2]}.{raw[2:]}"
if raw.isdigit() and len(raw) == 6:
if raw.startswith(("6", "9")):
return f"SH.{raw}"
return f"SZ.{raw}"
raise RuntimeError("cmesdata 通道仅支持 6 位 A 股代码或 SH./SZ. 前缀代码")
def fetch_spot(self) -> pd.DataFrame:
raise RuntimeError("cmesdata 不支持全市场快照拉取,请通过精确代码查询")
def search_spot(self, query: str, limit: int) -> pd.DataFrame:
_ = limit
code = self._to_prefixed_code(query)
df = self._module.get_real_hq([code])
if df is None or df.empty:
raise RuntimeError(f"未获取到实时行情: {code}")
return df
def fetch_daily_kline(self, code: str, start: datetime, end: datetime) -> pd.DataFrame:
prefixed = self._to_prefixed_code(code)
df = self._module.get_history_data(
prefixed,
start.strftime("%Y-%m-%d"),
end.strftime("%Y-%m-%d"),
"D",
)
if df is None or df.empty:
raise RuntimeError(f"未获取到历史 K 线: {prefixed}")
frame = pd.DataFrame(
{
"date": pd.to_datetime(df["时间"]),
"close": pd.to_numeric(df["收盘价"], errors="coerce"),
"volume": pd.to_numeric(df["成交量"], errors="coerce"),
}
).dropna()
return frame.sort_values("date").reset_index(drop=True)

View File

@@ -0,0 +1,207 @@
from __future__ import annotations
from datetime import datetime
import pandas as pd
from futu import AuType, KLType, Market, OpenQuoteContext, RET_OK, SecurityType
from market.provider_base import MarketDataProvider
class FutuProvider(MarketDataProvider):
def __init__(
self,
channel: str,
host: str = "127.0.0.1",
port: int = 11111,
is_encrypt: bool | None = None,
market: str = "",
) -> None:
self._channel = channel.strip().lower()
self._host = host
self._port = int(port)
self._is_encrypt = is_encrypt
self._market = (market or self._channel).strip().lower()
self._ctx = OpenQuoteContext(host=self._host, port=self._port, is_encrypt=self._is_encrypt)
self._basicinfo_cache: pd.DataFrame | None = None
@property
def provider_name(self) -> str:
return "futu"
@property
def channel(self) -> str:
return self._channel
def __del__(self) -> None:
try:
self._ctx.close()
except Exception:
pass
def _require_market(self) -> Market:
if self._market == "cn":
return Market.SH
if self._market == "hk":
return Market.HK
if self._market == "us":
return Market.US
raise RuntimeError(f"不支持的 FUTU_MARKET: {self._market},可选: cn/hk/us")
def _normalize_code(self, query: str) -> str:
q = query.strip().upper()
if not q:
return ""
if "." in q:
return q
if self._market == "cn":
if q.isdigit() and len(q) == 6:
prefix = "SH" if q.startswith(("5", "6", "9")) else "SZ"
return f"{prefix}.{q}"
return q
if self._market == "hk":
if q.isdigit():
return f"HK.{q.zfill(5)}"
return f"HK.{q}"
if self._market == "us":
return f"US.{q}"
return q
def _snapshot_to_unified(self, df: pd.DataFrame) -> pd.DataFrame:
frame = df.copy()
if "code" not in frame.columns:
raise RuntimeError("Futu 快照返回缺少 code 字段")
if "name" not in frame.columns:
frame["name"] = ""
frame["涨跌额"] = pd.to_numeric(frame.get("last_price"), errors="coerce") - pd.to_numeric(
frame.get("prev_close_price"), errors="coerce"
)
frame["市盈率-动态"] = frame.get("pe_ttm_ratio", frame.get("pe_ratio"))
frame["市净率"] = frame.get("pb_ratio")
frame["总市值"] = frame.get("total_market_val")
frame["流通市值"] = frame.get("circular_market_val")
frame["代码"] = frame["code"].astype(str)
frame["名称"] = frame["name"]
frame["最新价"] = frame.get("last_price")
frame["涨跌幅"] = frame.get("change_rate")
frame["成交量"] = frame.get("volume")
frame["成交额"] = frame.get("turnover")
frame["振幅"] = frame.get("amplitude")
frame["最高"] = frame.get("high_price")
frame["最低"] = frame.get("low_price")
frame["今开"] = frame.get("open_price")
frame["昨收"] = frame.get("prev_close_price")
columns = [
"代码",
"名称",
"最新价",
"涨跌幅",
"涨跌额",
"成交量",
"成交额",
"振幅",
"最高",
"最低",
"今开",
"昨收",
"市盈率-动态",
"市净率",
"总市值",
"流通市值",
]
return frame[columns]
def _get_snapshot(self, codes: list[str]) -> pd.DataFrame:
ret, data = self._ctx.get_market_snapshot(codes)
if ret != RET_OK:
raise RuntimeError(f"Futu 获取快照失败: {data}")
if data is None or data.empty:
return pd.DataFrame()
return self._snapshot_to_unified(data)
def _load_basicinfo(self) -> pd.DataFrame:
if self._basicinfo_cache is not None:
return self._basicinfo_cache
market = self._require_market()
ret, data = self._ctx.get_stock_basicinfo(market=market, stock_type=SecurityType.STOCK)
if ret != RET_OK:
raise RuntimeError(f"Futu 获取股票基础信息失败: {data}")
if data is None:
self._basicinfo_cache = pd.DataFrame(columns=["code", "name"])
return self._basicinfo_cache
frame = data.copy()
frame["code"] = frame["code"].astype(str)
frame["name"] = frame["name"].astype(str)
self._basicinfo_cache = frame
return frame
def fetch_spot(self) -> pd.DataFrame:
raise RuntimeError("Futu 不支持直接拉取全市场实时快照,请使用 search_spot")
def search_spot(self, query: str, limit: int) -> pd.DataFrame:
q = query.strip()
if not q:
return pd.DataFrame()
code = self._normalize_code(q)
if code:
exact = self._get_snapshot([code])
if not exact.empty:
return exact.head(limit)
basic = self._load_basicinfo()
q_lower = q.lower()
code_col = basic["code"].astype(str)
name_col = basic["name"].astype(str)
mask = (
code_col.str.lower().eq(q_lower)
| name_col.str.lower().eq(q_lower)
| code_col.str.lower().str.startswith(q_lower)
| name_col.str.lower().str.startswith(q_lower)
| code_col.str.lower().str.contains(q_lower, na=False)
| name_col.str.lower().str.contains(q_lower, na=False)
)
candidates = basic.loc[mask, "code"].drop_duplicates().head(max(limit * 2, 20)).tolist()
if not candidates:
return pd.DataFrame()
snap = self._get_snapshot(candidates)
if snap.empty:
return pd.DataFrame()
return snap.head(limit)
def fetch_daily_kline(self, code: str, start: datetime, end: datetime) -> pd.DataFrame:
normalized = self._normalize_code(code)
page_key = None
frames: list[pd.DataFrame] = []
while True:
ret, data, page_key = self._ctx.request_history_kline(
normalized,
start=start.strftime("%Y-%m-%d"),
end=end.strftime("%Y-%m-%d"),
ktype=KLType.K_DAY,
autype=AuType.QFQ,
max_count=1000,
page_req_key=page_key,
)
if ret != RET_OK:
raise RuntimeError(f"Futu 获取历史 K 线失败: {data}")
if data is not None and not data.empty:
frames.append(data.copy())
if page_key is None:
break
if not frames:
raise RuntimeError(f"未获取到 K 线数据: {normalized}")
full = pd.concat(frames, ignore_index=True)
frame = pd.DataFrame(
{
"date": pd.to_datetime(full["time_key"]),
"close": pd.to_numeric(full["close"], errors="coerce"),
"volume": pd.to_numeric(full["volume"], errors="coerce"),
}
).dropna()
if frame.empty:
raise RuntimeError(f"K 线数据为空: {normalized}")
return frame.sort_values("date").reset_index(drop=True)

View File

@@ -0,0 +1,164 @@
from __future__ import annotations
from datetime import datetime, timedelta
from typing import Any
from time import time
import pandas as pd
from market.provider_base import MarketDataProvider
DISPLAY_FIELDS = [
"代码",
"名称",
"最新价",
"涨跌幅",
"涨跌额",
"成交量",
"成交额",
"振幅",
"最高",
"最低",
"今开",
"昨收",
"市盈率-动态",
"市净率",
"总市值",
"流通市值",
]
def _safe(v: Any) -> Any:
if v is None:
return None
if isinstance(v, float) and pd.isna(v):
return None
return v
def build_signals(hist: pd.DataFrame, sma_fast: int, sma_slow: int) -> list[dict[str, Any]]:
frame = hist.copy()
frame["sma_fast"] = frame["close"].rolling(sma_fast).mean()
frame["sma_slow"] = frame["close"].rolling(sma_slow).mean()
signals: list[dict[str, Any]] = []
for i in range(1, len(frame)):
prev = frame.iloc[i - 1]
curr = frame.iloc[i]
if pd.isna(prev["sma_fast"]) or pd.isna(prev["sma_slow"]) or pd.isna(curr["sma_fast"]) or pd.isna(curr["sma_slow"]):
continue
cross_up = prev["sma_fast"] <= prev["sma_slow"] and curr["sma_fast"] > curr["sma_slow"]
cross_down = prev["sma_fast"] >= prev["sma_slow"] and curr["sma_fast"] < curr["sma_slow"]
if cross_up:
signals.append({"date": curr["date"].strftime("%Y-%m-%d"), "type": "buy", "price": float(curr["close"])})
elif cross_down:
signals.append({"date": curr["date"].strftime("%Y-%m-%d"), "type": "sell", "price": float(curr["close"])})
return signals
def resolve_candidates(provider: MarketDataProvider, query: str, limit: int) -> pd.DataFrame:
return provider.search_spot(query=query, limit=limit)
def _pick_first_candidate(provider: MarketDataProvider, query: str) -> pd.Series:
candidates = resolve_candidates(provider, query=query, limit=1)
if candidates.empty:
raise RuntimeError(f"未找到匹配股票: {query}")
row = candidates.iloc[0]
code = str(row.get("代码", "")).strip()
if not code:
raise RuntimeError("行情返回缺少代码字段")
return row
def _build_info(provider: MarketDataProvider, row: pd.Series) -> dict[str, Any]:
code = str(row.get("代码", "")).strip()
return {
"code": code,
"name": str(row.get("名称", "")),
"price": _safe(row.get("最新价")),
"change_pct": _safe(row.get("涨跌幅")),
"change": _safe(row.get("涨跌额")),
"volume": _safe(row.get("成交量")),
"turnover": _safe(row.get("成交额")),
"amplitude": _safe(row.get("振幅")),
"high": _safe(row.get("最高")),
"low": _safe(row.get("最低")),
"open": _safe(row.get("今开")),
"prev_close": _safe(row.get("昨收")),
"pe": _safe(row.get("市盈率-动态")),
"pb": _safe(row.get("市净率")),
"market_cap": _safe(row.get("总市值")),
"float_market_cap": _safe(row.get("流通市值")),
"provider": provider.provider_name,
"channel": provider.channel,
}
def build_realtime_info(provider: MarketDataProvider, query: str) -> dict[str, Any]:
row = _pick_first_candidate(provider, query=query)
info = _build_info(provider=provider, row=row)
now = datetime.now()
price = info.get("price")
volume = info.get("volume")
realtime_point = {
"date": now.strftime("%Y-%m-%d"),
"close": float(price) if price is not None else None,
"volume": float(volume) if volume is not None else 0.0,
}
return {
"info": info,
"realtime_point": realtime_point,
"source": "realtime",
"updated_at": now.strftime("%Y-%m-%d %H:%M:%S"),
}
_DASHBOARD_CACHE: dict[str, tuple[float, dict[str, Any]]] = {}
_DASHBOARD_CACHE_TTL_SECONDS = 15.0
def build_dashboard(
provider: MarketDataProvider,
query: str,
days: int,
sma_fast: int,
sma_slow: int,
) -> dict[str, Any]:
if sma_fast >= sma_slow:
raise RuntimeError("sma_fast must be less than sma_slow")
row = _pick_first_candidate(provider, query=query)
code = str(row.get("代码", "")).strip()
cache_key = "|".join([provider.provider_name, provider.channel, code, str(days), str(sma_fast), str(sma_slow)])
now = time()
cached = _DASHBOARD_CACHE.get(cache_key)
if cached is not None and now - cached[0] <= _DASHBOARD_CACHE_TTL_SECONDS:
return cached[1]
end = datetime.now()
lookback_days = max(days + (sma_slow * 3), days + 30)
start = end - timedelta(days=lookback_days)
hist = provider.fetch_daily_kline(code=code, start=start, end=end).tail(days).reset_index(drop=True)
if hist.empty:
raise RuntimeError(f"未获取到 K 线数据: {code}")
signals = build_signals(hist, sma_fast=sma_fast, sma_slow=sma_slow)
hist["sma_fast"] = hist["close"].rolling(sma_fast).mean()
hist["sma_slow"] = hist["close"].rolling(sma_slow).mean()
points = [
{
"date": d.strftime("%Y-%m-%d"),
"close": float(c),
"sma_fast": _safe(f),
"sma_slow": _safe(s),
"volume": float(v),
}
for d, c, f, s, v in zip(hist["date"], hist["close"], hist["sma_fast"], hist["sma_slow"], hist["volume"])
]
result = {"info": _build_info(provider=provider, row=row), "points": points, "signals": signals}
_DASHBOARD_CACHE[cache_key] = (now, result)
return result

View File

@@ -0,0 +1,32 @@
from __future__ import annotations
from dataclasses import dataclass
from datetime import datetime
from typing import Any
@dataclass
class SpotRow:
code: str
name: str
price: Any
change_pct: Any
change: Any
volume: Any
turnover: Any
amplitude: Any
high: Any
low: Any
open: Any
prev_close: Any
pe: Any
pb: Any
market_cap: Any
float_market_cap: Any
@dataclass
class KlineRow:
date: datetime
close: float
volume: float

View File

@@ -0,0 +1,20 @@
{"type":"kline","source":"binance_ws","symbol":"btcusdt","interval":"1m","event_time":1711430400000,"open_time":1711430400000,"close_time":1711430459999,"open":"68000.00","high":"68025.00","low":"67990.00","close":"68010.00","volume":"12.40","trade_num":102,"final":true}
{"type":"kline","source":"binance_ws","symbol":"btcusdt","interval":"1m","event_time":1711430460000,"open_time":1711430460000,"close_time":1711430519999,"open":"68010.00","high":"68040.00","low":"68000.00","close":"68028.00","volume":"10.12","trade_num":96,"final":true}
{"type":"kline","source":"binance_ws","symbol":"btcusdt","interval":"1m","event_time":1711430520000,"open_time":1711430520000,"close_time":1711430579999,"open":"68028.00","high":"68036.00","low":"68005.00","close":"68012.00","volume":"8.80","trade_num":87,"final":true}
{"type":"kline","source":"binance_ws","symbol":"btcusdt","interval":"1m","event_time":1711430580000,"open_time":1711430580000,"close_time":1711430639999,"open":"68012.00","high":"68022.00","low":"67980.00","close":"67992.00","volume":"11.50","trade_num":120,"final":true}
{"type":"kline","source":"binance_ws","symbol":"btcusdt","interval":"1m","event_time":1711430640000,"open_time":1711430640000,"close_time":1711430699999,"open":"67992.00","high":"68018.00","low":"67970.00","close":"68005.00","volume":"9.64","trade_num":91,"final":true}
{"type":"kline","source":"binance_ws","symbol":"btcusdt","interval":"1m","event_time":1711430700000,"open_time":1711430700000,"close_time":1711430759999,"open":"68005.00","high":"68055.00","low":"68002.00","close":"68048.00","volume":"14.03","trade_num":133,"final":true}
{"type":"kline","source":"binance_ws","symbol":"btcusdt","interval":"1m","event_time":1711430760000,"open_time":1711430760000,"close_time":1711430819999,"open":"68048.00","high":"68070.00","low":"68030.00","close":"68062.00","volume":"13.38","trade_num":121,"final":true}
{"type":"kline","source":"binance_ws","symbol":"btcusdt","interval":"1m","event_time":1711430820000,"open_time":1711430820000,"close_time":1711430879999,"open":"68062.00","high":"68080.00","low":"68040.00","close":"68044.00","volume":"10.99","trade_num":99,"final":true}
{"type":"kline","source":"binance_ws","symbol":"btcusdt","interval":"1m","event_time":1711430880000,"open_time":1711430880000,"close_time":1711430939999,"open":"68044.00","high":"68046.00","low":"68000.00","close":"68006.00","volume":"12.75","trade_num":115,"final":true}
{"type":"kline","source":"binance_ws","symbol":"btcusdt","interval":"1m","event_time":1711430940000,"open_time":1711430940000,"close_time":1711430999999,"open":"68006.00","high":"68012.00","low":"67972.00","close":"67986.00","volume":"11.66","trade_num":105,"final":true}
{"type":"kline","source":"binance_ws","symbol":"btcusdt","interval":"1m","event_time":1711431000000,"open_time":1711431000000,"close_time":1711431059999,"open":"67986.00","high":"68008.00","low":"67970.00","close":"67996.00","volume":"9.21","trade_num":88,"final":true}
{"type":"kline","source":"binance_ws","symbol":"btcusdt","interval":"1m","event_time":1711431060000,"open_time":1711431060000,"close_time":1711431119999,"open":"67996.00","high":"68035.00","low":"67992.00","close":"68020.00","volume":"10.77","trade_num":94,"final":true}
{"type":"kline","source":"binance_ws","symbol":"btcusdt","interval":"1m","event_time":1711431120000,"open_time":1711431120000,"close_time":1711431179999,"open":"68020.00","high":"68058.00","low":"68018.00","close":"68052.00","volume":"12.62","trade_num":110,"final":true}
{"type":"kline","source":"binance_ws","symbol":"btcusdt","interval":"1m","event_time":1711431180000,"open_time":1711431180000,"close_time":1711431239999,"open":"68052.00","high":"68066.00","low":"68030.00","close":"68042.00","volume":"8.19","trade_num":76,"final":true}
{"type":"kline","source":"binance_ws","symbol":"btcusdt","interval":"1m","event_time":1711431240000,"open_time":1711431240000,"close_time":1711431299999,"open":"68042.00","high":"68052.00","low":"68008.00","close":"68016.00","volume":"10.40","trade_num":93,"final":true}
{"type":"kline","source":"binance_ws","symbol":"btcusdt","interval":"1m","event_time":1711431300000,"open_time":1711431300000,"close_time":1711431359999,"open":"68016.00","high":"68018.00","low":"67986.00","close":"67998.00","volume":"9.07","trade_num":84,"final":true}
{"type":"kline","source":"binance_ws","symbol":"btcusdt","interval":"1m","event_time":1711431360000,"open_time":1711431360000,"close_time":1711431419999,"open":"67998.00","high":"68012.00","low":"67974.00","close":"67988.00","volume":"7.55","trade_num":69,"final":true}
{"type":"kline","source":"binance_ws","symbol":"btcusdt","interval":"1m","event_time":1711431420000,"open_time":1711431420000,"close_time":1711431479999,"open":"67988.00","high":"68028.00","low":"67982.00","close":"68024.00","volume":"11.88","trade_num":101,"final":true}
{"type":"kline","source":"binance_ws","symbol":"btcusdt","interval":"1m","event_time":1711431480000,"open_time":1711431480000,"close_time":1711431539999,"open":"68024.00","high":"68070.00","low":"68020.00","close":"68062.00","volume":"13.45","trade_num":124,"final":true}
{"type":"kline","source":"binance_ws","symbol":"btcusdt","interval":"1m","event_time":1711431540000,"open_time":1711431540000,"close_time":1711431599999,"open":"68062.00","high":"68082.00","low":"68048.00","close":"68074.00","volume":"12.03","trade_num":111,"final":true}

View File

@@ -0,0 +1,91 @@
from __future__ import annotations
import argparse
import pandas as pd
from market import create_provider, load_market_config
from market.service import DISPLAY_FIELDS, resolve_candidates
def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(description="Stock quick lookup (A-share)")
parser.add_argument("query", nargs="?", default="", help="stock code/name/abbr")
parser.add_argument("--limit", type=int, default=8, help="max candidate rows")
return parser.parse_args()
def safe_str(value: object) -> str:
if value is None:
return ""
if isinstance(value, float) and pd.isna(value):
return ""
return str(value).strip()
def print_row(row: pd.Series) -> None:
print("\n===== 股票信息 =====")
for field in DISPLAY_FIELDS:
if field in row.index:
print(f"{field}: {safe_str(row[field])}")
print("===================\n")
def pick_candidate(candidates: pd.DataFrame) -> pd.Series | None:
if len(candidates) == 1:
return candidates.iloc[0]
print("\n匹配到多个标的,请选择:")
for idx, (_, row) in enumerate(candidates.iterrows(), start=1):
print(f"{idx}. {safe_str(row['代码'])} {safe_str(row['名称'])} 最新价={safe_str(row.get('最新价'))}")
choice = input("输入序号(直接回车默认 1): ").strip()
if not choice:
return candidates.iloc[0]
if not choice.isdigit():
print("输入无效")
return None
pos = int(choice)
if pos < 1 or pos > len(candidates):
print("输入超出范围")
return None
return candidates.iloc[pos - 1]
def run_once(query: str, limit: int) -> None:
provider = create_provider(load_market_config())
candidates = resolve_candidates(provider, query, limit)
if candidates.empty:
print(f"未找到匹配股票: {query}")
return
selected = pick_candidate(candidates)
if selected is None:
return
print_row(selected)
def interactive_loop(limit: int) -> None:
print("股票速查已启动,输入代码/名称(输入 q 退出)")
while True:
query = input("查询> ").strip()
if query.lower() in {"q", "quit", "exit"}:
print("已退出股票速查")
return
if not query:
continue
try:
run_once(query, limit)
except Exception as exc: # noqa: BLE001
print(f"查询失败: {exc}")
def main() -> None:
args = parse_args()
if not args.query:
interactive_loop(args.limit)
return
run_once(args.query, args.limit)
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,3 @@
from .sma_cross import SmaCrossStrategy
__all__ = ["SmaCrossStrategy"]

View File

@@ -0,0 +1,22 @@
from __future__ import annotations
import backtrader as bt
class SmaCrossStrategy(bt.Strategy):
params = (
("fast", 3),
("slow", 8),
("stake", 1),
)
def __init__(self) -> None:
self.fast_sma = bt.indicators.SMA(self.data.close, period=self.params.fast)
self.slow_sma = bt.indicators.SMA(self.data.close, period=self.params.slow)
self.cross_signal = bt.indicators.CrossOver(self.fast_sma, self.slow_sma)
def next(self) -> None:
if not self.position and self.cross_signal > 0:
self.buy(size=self.params.stake)
elif self.position and self.cross_signal < 0:
self.close()

View File

@@ -0,0 +1,211 @@
<!DOCTYPE html>
<html lang="zh-CN">
<head>
<meta charset="UTF-8" />
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
<title>AITrading 股票看板</title>
<script src="https://cdn.jsdelivr.net/npm/chart.js@4.4.3/dist/chart.umd.min.js"></script>
<style>
body { font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Arial, sans-serif; margin: 20px; background:#f7f8fa; color:#1f2937; }
.card { background:#fff; border-radius:10px; box-shadow:0 2px 8px rgba(0,0,0,.06); padding:16px; margin-bottom:16px; }
.row { display:flex; gap:12px; flex-wrap:wrap; }
input, button { padding:8px 10px; border-radius:8px; border:1px solid #d1d5db; }
button { background:#2563eb; color:#fff; border:none; cursor:pointer; }
.grid { display:grid; grid-template-columns: repeat(3, minmax(180px,1fr)); gap:8px; }
.item { background:#f9fafb; border-radius:8px; padding:8px; }
.label { color:#6b7280; font-size:12px; }
.value { font-size:14px; font-weight:600; }
#status { font-size:13px; color:#6b7280; }
</style>
</head>
<body>
<div class="card">
<h2>股票信息与买卖点</h2>
<div class="row">
<input id="query" placeholder="输入股票代码或名称,如 600519 / 贵州茅台" style="min-width:300px;" />
<input id="fast" type="number" min="2" max="60" value="5" />
<input id="slow" type="number" min="3" max="120" value="20" />
<button onclick="searchStock()">查询</button>
</div>
<p id="status">请输入代码或名称并点击查询。</p>
</div>
<div class="card">
<div id="info" class="grid"></div>
</div>
<div class="card">
<canvas id="chart" height="120"></canvas>
</div>
<script>
let chart;
let realtimeWs = null;
let currentPoints = [];
let currentSignals = [];
let currentFast = 5;
let currentSlow = 20;
function renderInfo(info) {
const fields = [
["代码", info.code], ["名称", info.name], ["最新价", info.price],
["涨跌幅", info.change_pct], ["涨跌额", info.change], ["成交额", info.turnover],
["PE(动态)", info.pe], ["PB", info.pb], ["总市值", info.market_cap]
];
const container = document.getElementById("info");
container.innerHTML = fields.map(([k, v]) =>
`<div class="item"><div class="label">${k}</div><div class="value">${v ?? "-"}</div></div>`
).join("");
}
function buildSignals(points, fast, slow) {
const signals = [];
for (let i = 1; i < points.length; i++) {
const prev = points[i - 1];
const curr = points[i];
if (prev.sma_fast == null || prev.sma_slow == null || curr.sma_fast == null || curr.sma_slow == null) continue;
const crossUp = prev.sma_fast <= prev.sma_slow && curr.sma_fast > curr.sma_slow;
const crossDown = prev.sma_fast >= prev.sma_slow && curr.sma_fast < curr.sma_slow;
if (crossUp) signals.push({ date: curr.date, type: "buy", price: curr.close });
else if (crossDown) signals.push({ date: curr.date, type: "sell", price: curr.close });
}
return signals;
}
function recomputeSma(points, fast, slow) {
const sumFast = [];
const sumSlow = [];
for (let i = 0; i < points.length; i++) {
sumFast[i] = (sumFast[i - 1] || 0) + Number(points[i].close || 0);
sumSlow[i] = (sumSlow[i - 1] || 0) + Number(points[i].close || 0);
if (i >= fast - 1) {
const prev = i - fast >= 0 ? sumFast[i - fast] : 0;
points[i].sma_fast = (sumFast[i] - prev) / fast;
} else {
points[i].sma_fast = null;
}
if (i >= slow - 1) {
const prev = i - slow >= 0 ? sumSlow[i - slow] : 0;
points[i].sma_slow = (sumSlow[i] - prev) / slow;
} else {
points[i].sma_slow = null;
}
}
}
function upsertRealtimePoint(point) {
if (!point || point.close == null || currentPoints.length === 0) return;
const last = currentPoints[currentPoints.length - 1];
if (last.date === point.date) {
last.close = Number(point.close);
if (point.volume != null) last.volume = Number(point.volume);
} else if (point.date > last.date) {
currentPoints.push({
date: point.date,
close: Number(point.close),
volume: Number(point.volume || 0),
sma_fast: null,
sma_slow: null,
});
}
recomputeSma(currentPoints, currentFast, currentSlow);
currentSignals = buildSignals(currentPoints, currentFast, currentSlow);
drawChart(currentPoints, currentSignals);
}
function drawChart(points, signals) {
const labels = points.map(p => p.date);
const close = points.map(p => p.close);
const fast = points.map(p => p.sma_fast);
const slow = points.map(p => p.sma_slow);
const buyPoints = signals.filter(s => s.type === "buy").map(s => ({x: s.date, y: s.price}));
const sellPoints = signals.filter(s => s.type === "sell").map(s => ({x: s.date, y: s.price}));
if (!chart) {
chart = new Chart(document.getElementById("chart"), {
type: "line",
data: {
labels,
datasets: [
{ label: "收盘价", data: close, borderColor: "#1f2937", tension: 0.2, pointRadius: 0 },
{ label: "SMA Fast", data: fast, borderColor: "#2563eb", tension: 0.2, pointRadius: 0 },
{ label: "SMA Slow", data: slow, borderColor: "#16a34a", tension: 0.2, pointRadius: 0 },
{ label: "买点", data: buyPoints, parsing: {xAxisKey: "x", yAxisKey: "y"}, showLine:false, pointRadius:5, pointStyle:"triangle", pointBackgroundColor:"#dc2626", pointBorderColor:"#dc2626" },
{ label: "卖点", data: sellPoints, parsing: {xAxisKey: "x", yAxisKey: "y"}, showLine:false, pointRadius:5, pointStyle:"rectRot", pointBackgroundColor:"#7c3aed", pointBorderColor:"#7c3aed" }
]
},
options: {
responsive: true,
animation: false,
scales: { x: { ticks: { maxTicksLimit: 10 } } }
}
});
return;
}
chart.data.labels = labels;
chart.data.datasets[0].data = close;
chart.data.datasets[1].data = fast;
chart.data.datasets[2].data = slow;
chart.data.datasets[3].data = buyPoints;
chart.data.datasets[4].data = sellPoints;
chart.update("none");
}
async function searchStock() {
const query = document.getElementById("query").value.trim();
const fast = Number(document.getElementById("fast").value);
const slow = Number(document.getElementById("slow").value);
if (!query) return;
const status = document.getElementById("status");
status.innerText = "正在建立实时 WebSocket 连接...";
try {
if (realtimeWs) {
realtimeWs.close();
realtimeWs = null;
}
const protocol = window.location.protocol === "https:" ? "wss" : "ws";
realtimeWs = new WebSocket(`${protocol}://${window.location.host}/ws/stock/realtime?query=${encodeURIComponent(query)}`);
realtimeWs.onmessage = (event) => {
const data = JSON.parse(event.data);
if (data.error) {
status.innerText = `实时通道错误: ${data.error}`;
return;
}
renderInfo(data.info);
upsertRealtimePoint(data.realtime_point);
status.innerText = `实时更新中:${data.info.code} ${data.info.name}${data.updated_at}`;
};
realtimeWs.onclose = () => {
if (status.innerText.startsWith("实时更新中")) {
status.innerText = "实时连接已断开";
}
};
realtimeWs.onerror = () => {
status.innerText = "实时连接失败";
};
const res = await fetch(`/api/stock?query=${encodeURIComponent(query)}&sma_fast=${fast}&sma_slow=${slow}`);
const data = await res.json();
if (!res.ok) throw new Error(data.detail || "查询失败");
currentFast = fast;
currentSlow = slow;
currentPoints = data.points.map(p => ({
date: p.date,
close: Number(p.close),
sma_fast: p.sma_fast == null ? null : Number(p.sma_fast),
sma_slow: p.sma_slow == null ? null : Number(p.sma_slow),
volume: p.volume == null ? 0 : Number(p.volume),
}));
currentSignals = data.signals;
renderInfo(data.info);
drawChart(currentPoints, currentSignals);
status.innerText = `历史K线加载完成${data.info.code} ${data.info.name},买点 ${data.signals.filter(s=>s.type==="buy").length},卖点 ${data.signals.filter(s=>s.type==="sell").length}(实时通道已连接)`;
} catch (err) {
status.innerText = `错误: ${err.message}`;
}
}
</script>
</body>
</html>

82
python-app/app/web_app.py Normal file
View File

@@ -0,0 +1,82 @@
from __future__ import annotations
import asyncio
from pathlib import Path
from typing import Any
from fastapi import FastAPI, HTTPException, Query, WebSocket, WebSocketDisconnect
from fastapi.responses import FileResponse
from market import create_provider, load_market_config
from market.service import build_dashboard, build_realtime_info
app = FastAPI(title="AITrading Simple Frontend API", version="1.0.0")
BASE_DIR = Path(__file__).resolve().parent
WEB_INDEX = BASE_DIR / "web" / "index.html"
@app.get("/")
def index() -> FileResponse:
if not WEB_INDEX.exists():
raise HTTPException(status_code=500, detail="frontend page not found")
return FileResponse(WEB_INDEX)
@app.get("/api/stock")
def stock_dashboard(
query: str = Query(..., description="stock code or name"),
days: int = Query(120, ge=30, le=365),
sma_fast: int = Query(5, ge=2, le=60),
sma_slow: int = Query(20, ge=3, le=120),
) -> dict[str, Any]:
try:
provider = create_provider(load_market_config())
return build_dashboard(
provider=provider,
query=query,
days=days,
sma_fast=sma_fast,
sma_slow=sma_slow,
)
except RuntimeError as exc:
raise HTTPException(status_code=400, detail=str(exc)) from exc
@app.get("/api/stock/realtime")
def stock_realtime(
query: str = Query(..., description="stock code or name"),
) -> dict[str, Any]:
try:
provider = create_provider(load_market_config())
return build_realtime_info(provider=provider, query=query)
except RuntimeError as exc:
raise HTTPException(status_code=400, detail=str(exc)) from exc
@app.websocket("/ws/stock/realtime")
async def stock_realtime_ws(websocket: WebSocket) -> None:
await websocket.accept()
query = (websocket.query_params.get("query") or "").strip()
if not query:
await websocket.send_json({"error": "query is required"})
await websocket.close(code=1008)
return
try:
while True:
provider = create_provider(load_market_config())
payload = build_realtime_info(provider=provider, query=query)
await websocket.send_json(payload)
# WebSocket channel keeps the connection alive, front-end no longer polls via HTTP.
await asyncio.sleep(2)
except WebSocketDisconnect:
return
except Exception as exc: # noqa: BLE001
await websocket.send_json({"error": str(exc)})
await websocket.close(code=1011)
if __name__ == "__main__":
import uvicorn
uvicorn.run("web_app:app", host="0.0.0.0", port=8000, reload=False)

View File

@@ -0,0 +1,10 @@
pandas
mplfinance
matplotlib
backtrader
redis
akshare
fastapi
uvicorn
futu-api
websockets