feat: 修复报错
This commit is contained in:
24
python-app/Dockerfile
Normal file
24
python-app/Dockerfile
Normal file
@@ -0,0 +1,24 @@
|
||||
ARG DOCKER_MIRROR_PREFIX=m.daocloud.io/docker.io/library
|
||||
FROM ${DOCKER_MIRROR_PREFIX}/python:3.11-slim-bookworm
|
||||
|
||||
ENV PIP_INDEX_URL=https://mirrors.aliyun.com/pypi/simple/ \
|
||||
PIP_TRUSTED_HOST=mirrors.aliyun.com \
|
||||
PYTHONDONTWRITEBYTECODE=1 \
|
||||
PYTHONUNBUFFERED=1
|
||||
|
||||
RUN sed -i 's|deb.debian.org|mirrors.aliyun.com|g' /etc/apt/sources.list.d/debian.sources \
|
||||
&& sed -i 's|security.debian.org|mirrors.aliyun.com|g' /etc/apt/sources.list.d/debian.sources \
|
||||
&& apt-get update \
|
||||
&& apt-get install -y --no-install-recommends \
|
||||
build-essential \
|
||||
tzdata \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
COPY requirements.txt ./
|
||||
RUN pip install --no-cache-dir -r requirements.txt
|
||||
|
||||
COPY app ./app
|
||||
|
||||
CMD ["python", "app/backtest_runner.py", "--help"]
|
||||
142
python-app/app/backtest_runner.py
Normal file
142
python-app/app/backtest_runner.py
Normal file
@@ -0,0 +1,142 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import ast
|
||||
import importlib
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Any, Type
|
||||
|
||||
import backtrader as bt
|
||||
import pandas as pd
|
||||
|
||||
from data_protocol import KlineMessage
|
||||
|
||||
|
||||
def parse_args() -> argparse.Namespace:
|
||||
parser = argparse.ArgumentParser(description="Backtrader runner")
|
||||
parser.add_argument("--input", required=True, help="Path to kline jsonl file")
|
||||
parser.add_argument(
|
||||
"--strategy",
|
||||
default="strategies.sma_cross.SmaCrossStrategy",
|
||||
help="Strategy class path, e.g. strategies.sma_cross.SmaCrossStrategy",
|
||||
)
|
||||
parser.add_argument("--cash", type=float, default=100000.0, help="Initial cash")
|
||||
parser.add_argument("--commission", type=float, default=0.001, help="Commission")
|
||||
parser.add_argument(
|
||||
"--strategy-param",
|
||||
action="append",
|
||||
default=[],
|
||||
help="Strategy parameter in key=value format, repeatable",
|
||||
)
|
||||
parser.add_argument("--plot", action="store_true", help="Plot result")
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def load_strategy_class(class_path: str) -> Type[bt.Strategy]:
|
||||
module_name, class_name = class_path.rsplit(".", 1)
|
||||
module = importlib.import_module(module_name)
|
||||
strategy_cls = getattr(module, class_name)
|
||||
if not issubclass(strategy_cls, bt.Strategy):
|
||||
raise TypeError(f"{class_path} is not a backtrader.Strategy subclass")
|
||||
return strategy_cls
|
||||
|
||||
|
||||
def load_dataframe(path: Path) -> pd.DataFrame:
|
||||
rows = []
|
||||
for line in path.read_text(encoding="utf-8").splitlines():
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
payload = json.loads(line)
|
||||
msg = KlineMessage.from_dict(payload)
|
||||
rows.append(
|
||||
{
|
||||
"datetime": pd.to_datetime(msg.open_time, unit="ms"),
|
||||
"open": float(msg.open),
|
||||
"high": float(msg.high),
|
||||
"low": float(msg.low),
|
||||
"close": float(msg.close),
|
||||
"volume": float(msg.volume),
|
||||
"openinterest": 0.0,
|
||||
}
|
||||
)
|
||||
if not rows:
|
||||
raise RuntimeError("no kline rows found for backtest")
|
||||
frame = pd.DataFrame(rows).set_index("datetime")
|
||||
return frame.sort_index()
|
||||
|
||||
|
||||
def parse_strategy_params(raw_params: list[str]) -> dict[str, Any]:
|
||||
parsed: dict[str, Any] = {}
|
||||
for item in raw_params:
|
||||
if "=" not in item:
|
||||
raise ValueError(f"invalid --strategy-param '{item}', expected key=value")
|
||||
key, raw_value = item.split("=", 1)
|
||||
key = key.strip()
|
||||
raw_value = raw_value.strip()
|
||||
if not key:
|
||||
raise ValueError(f"invalid --strategy-param '{item}', empty key")
|
||||
try:
|
||||
value = ast.literal_eval(raw_value)
|
||||
except (ValueError, SyntaxError):
|
||||
value = raw_value
|
||||
parsed[key] = value
|
||||
return parsed
|
||||
|
||||
|
||||
def extract_report(result: list[bt.Strategy], final_value: float) -> dict[str, Any]:
|
||||
strategy = result[0]
|
||||
drawdown_info = strategy.analyzers.drawdown.get_analysis()
|
||||
sharpe_info = strategy.analyzers.sharpe.get_analysis()
|
||||
|
||||
max_drawdown = drawdown_info.get("max", {}).get("drawdown")
|
||||
sharpe_ratio = sharpe_info.get("sharperatio")
|
||||
return {
|
||||
"final_value": final_value,
|
||||
"max_drawdown_pct": None if max_drawdown is None else float(max_drawdown),
|
||||
"sharpe_ratio": None if sharpe_ratio is None else float(sharpe_ratio),
|
||||
}
|
||||
|
||||
|
||||
def print_report(report: dict[str, Any]) -> None:
|
||||
print("\n===== Backtest Report =====")
|
||||
print(f"Final Value : {report['final_value']:.2f}")
|
||||
max_drawdown = report["max_drawdown_pct"]
|
||||
sharpe_ratio = report["sharpe_ratio"]
|
||||
print(f"Max Drawdown (%) : {'N/A' if max_drawdown is None else f'{max_drawdown:.4f}'}")
|
||||
print(f"Sharpe Ratio : {'N/A' if sharpe_ratio is None else f'{sharpe_ratio:.4f}'}")
|
||||
print("===========================\n")
|
||||
|
||||
|
||||
def run_backtest() -> None:
|
||||
args = parse_args()
|
||||
data_path = Path(args.input)
|
||||
if not data_path.exists():
|
||||
raise FileNotFoundError(f"input file not found: {data_path}")
|
||||
|
||||
strategy_cls = load_strategy_class(args.strategy)
|
||||
strategy_params = parse_strategy_params(args.strategy_param)
|
||||
df = load_dataframe(data_path)
|
||||
|
||||
cerebro = bt.Cerebro()
|
||||
cerebro.addstrategy(strategy_cls, **strategy_params)
|
||||
cerebro.broker.setcash(args.cash)
|
||||
cerebro.broker.setcommission(commission=args.commission)
|
||||
cerebro.adddata(bt.feeds.PandasData(dataname=df))
|
||||
cerebro.addanalyzer(bt.analyzers.DrawDown, _name="drawdown")
|
||||
cerebro.addanalyzer(bt.analyzers.SharpeRatio_A, _name="sharpe", riskfreerate=0.0)
|
||||
|
||||
print(f"starting portfolio value: {cerebro.broker.getvalue():.2f}")
|
||||
result = cerebro.run()
|
||||
final_value = cerebro.broker.getvalue()
|
||||
print(f"final portfolio value: {final_value:.2f}")
|
||||
print(f"strategies executed: {len(result)}")
|
||||
print_report(extract_report(result, final_value))
|
||||
|
||||
if args.plot:
|
||||
cerebro.plot(style="candlestick")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_backtest()
|
||||
59
python-app/app/data_protocol.py
Normal file
59
python-app/app/data_protocol.py
Normal file
@@ -0,0 +1,59 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict
|
||||
|
||||
|
||||
@dataclass
|
||||
class KlineMessage:
|
||||
type: str
|
||||
source: str
|
||||
symbol: str
|
||||
interval: str
|
||||
event_time: int
|
||||
open_time: int
|
||||
close_time: int
|
||||
open: str
|
||||
high: str
|
||||
low: str
|
||||
close: str
|
||||
volume: str
|
||||
trade_num: int
|
||||
final: bool
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, payload: Dict[str, Any]) -> "KlineMessage":
|
||||
return cls(
|
||||
type=str(payload["type"]),
|
||||
source=str(payload["source"]),
|
||||
symbol=str(payload["symbol"]),
|
||||
interval=str(payload["interval"]),
|
||||
event_time=int(payload["event_time"]),
|
||||
open_time=int(payload["open_time"]),
|
||||
close_time=int(payload["close_time"]),
|
||||
open=str(payload["open"]),
|
||||
high=str(payload["high"]),
|
||||
low=str(payload["low"]),
|
||||
close=str(payload["close"]),
|
||||
volume=str(payload["volume"]),
|
||||
trade_num=int(payload["trade_num"]),
|
||||
final=bool(payload["final"]),
|
||||
)
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"type": self.type,
|
||||
"source": self.source,
|
||||
"symbol": self.symbol,
|
||||
"interval": self.interval,
|
||||
"event_time": self.event_time,
|
||||
"open_time": self.open_time,
|
||||
"close_time": self.close_time,
|
||||
"open": self.open,
|
||||
"high": self.high,
|
||||
"low": self.low,
|
||||
"close": self.close,
|
||||
"volume": self.volume,
|
||||
"trade_num": self.trade_num,
|
||||
"final": self.final,
|
||||
}
|
||||
113
python-app/app/kline_viewer.py
Normal file
113
python-app/app/kline_viewer.py
Normal file
@@ -0,0 +1,113 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import json
|
||||
from collections import OrderedDict
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import mplfinance as mpf
|
||||
import pandas as pd
|
||||
import redis
|
||||
|
||||
from data_protocol import KlineMessage
|
||||
|
||||
|
||||
def parse_args() -> argparse.Namespace:
|
||||
parser = argparse.ArgumentParser(description="Realtime K-line viewer from Redis")
|
||||
parser.add_argument("--redis-host", default="127.0.0.1", help="Redis host")
|
||||
parser.add_argument("--redis-port", type=int, default=6379, help="Redis port")
|
||||
parser.add_argument("--redis-db", type=int, default=0, help="Redis DB")
|
||||
parser.add_argument("--redis-password", default="", help="Redis password")
|
||||
parser.add_argument("--channel", default="kline.stream", help="Pub/Sub channel")
|
||||
parser.add_argument("--window", type=int, default=80, help="Window size to render")
|
||||
parser.add_argument("--only-final", action="store_true", help="Render only closed candles")
|
||||
parser.add_argument("--timeout", type=float, default=1.0, help="Subscriber timeout (seconds)")
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def upsert_message(buffer: OrderedDict[int, KlineMessage], message: KlineMessage, max_size: int) -> None:
|
||||
buffer[message.open_time] = message
|
||||
buffer.move_to_end(message.open_time)
|
||||
while len(buffer) > max_size:
|
||||
buffer.popitem(last=False)
|
||||
|
||||
|
||||
def to_dataframe(messages: list[KlineMessage]) -> pd.DataFrame:
|
||||
rows = []
|
||||
for msg in messages:
|
||||
rows.append(
|
||||
{
|
||||
"Date": pd.to_datetime(msg.open_time, unit="ms"),
|
||||
"Open": float(msg.open),
|
||||
"High": float(msg.high),
|
||||
"Low": float(msg.low),
|
||||
"Close": float(msg.close),
|
||||
"Volume": float(msg.volume),
|
||||
}
|
||||
)
|
||||
frame = pd.DataFrame(rows).set_index("Date")
|
||||
return frame
|
||||
|
||||
|
||||
def main() -> None:
|
||||
args = parse_args()
|
||||
plt.ion()
|
||||
client = redis.Redis(
|
||||
host=args.redis_host,
|
||||
port=args.redis_port,
|
||||
db=args.redis_db,
|
||||
password=args.redis_password or None,
|
||||
decode_responses=True,
|
||||
)
|
||||
pubsub = client.pubsub(ignore_subscribe_messages=True)
|
||||
pubsub.subscribe(args.channel)
|
||||
|
||||
print(
|
||||
f"subscribed redis channel={args.channel} "
|
||||
f"at {args.redis_host}:{args.redis_port}/{args.redis_db}"
|
||||
)
|
||||
|
||||
candles: OrderedDict[int, KlineMessage] = OrderedDict()
|
||||
|
||||
try:
|
||||
while True:
|
||||
packet = pubsub.get_message(timeout=args.timeout)
|
||||
if not packet:
|
||||
plt.pause(0.01)
|
||||
continue
|
||||
|
||||
raw = packet.get("data")
|
||||
if not isinstance(raw, str):
|
||||
continue
|
||||
|
||||
payload = json.loads(raw)
|
||||
message = KlineMessage.from_dict(payload)
|
||||
if args.only_final and not message.final:
|
||||
continue
|
||||
|
||||
upsert_message(candles, message, args.window)
|
||||
frame = to_dataframe(list(candles.values()))
|
||||
if frame.empty:
|
||||
continue
|
||||
|
||||
plt.clf()
|
||||
mpf.plot(
|
||||
frame,
|
||||
type="candle",
|
||||
style="yahoo",
|
||||
volume=True,
|
||||
datetime_format="%H:%M",
|
||||
tight_layout=True,
|
||||
)
|
||||
plt.pause(0.001)
|
||||
except KeyboardInterrupt:
|
||||
print("viewer stopped by user")
|
||||
finally:
|
||||
pubsub.close()
|
||||
client.close()
|
||||
plt.ioff()
|
||||
plt.show()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
4
python-app/app/market/__init__.py
Normal file
4
python-app/app/market/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
from market.config import load_market_config
|
||||
from market.factory import create_provider
|
||||
|
||||
__all__ = ["load_market_config", "create_provider"]
|
||||
40
python-app/app/market/config.py
Normal file
40
python-app/app/market/config.py
Normal file
@@ -0,0 +1,40 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass
|
||||
class MarketConfig:
|
||||
channel: str
|
||||
provider: str
|
||||
cmes_token: str
|
||||
futu_host: str
|
||||
futu_port: int
|
||||
futu_is_encrypt: bool | None
|
||||
futu_market: str
|
||||
|
||||
|
||||
def load_market_config() -> MarketConfig:
|
||||
channel = os.getenv("MARKET_CHANNEL", "cn").strip().lower()
|
||||
provider = os.getenv("MARKET_PROVIDER", "akshare").strip().lower()
|
||||
cmes_token = os.getenv("CMES_TOKEN", "").strip()
|
||||
futu_host = os.getenv("FUTU_HOST", "127.0.0.1").strip()
|
||||
futu_port = int(os.getenv("FUTU_PORT", "11111").strip())
|
||||
futu_encrypt_raw = os.getenv("FUTU_IS_ENCRYPT", "").strip().lower()
|
||||
if futu_encrypt_raw in {"1", "true", "yes", "y"}:
|
||||
futu_is_encrypt = True
|
||||
elif futu_encrypt_raw in {"0", "false", "no", "n"}:
|
||||
futu_is_encrypt = False
|
||||
else:
|
||||
futu_is_encrypt = None
|
||||
futu_market = os.getenv("FUTU_MARKET", "").strip().lower()
|
||||
return MarketConfig(
|
||||
channel=channel,
|
||||
provider=provider,
|
||||
cmes_token=cmes_token,
|
||||
futu_host=futu_host,
|
||||
futu_port=futu_port,
|
||||
futu_is_encrypt=futu_is_encrypt,
|
||||
futu_market=futu_market,
|
||||
)
|
||||
29
python-app/app/market/factory.py
Normal file
29
python-app/app/market/factory.py
Normal file
@@ -0,0 +1,29 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from market.config import MarketConfig
|
||||
from market.provider_base import MarketDataProvider
|
||||
from market.providers.akshare_provider import AkshareCnProvider
|
||||
from market.providers.cmes_provider import CmesCnProvider
|
||||
|
||||
|
||||
def create_provider(config: MarketConfig) -> MarketDataProvider:
|
||||
if config.channel == "cn" and config.provider == "akshare":
|
||||
return AkshareCnProvider()
|
||||
if config.channel == "cn" and config.provider == "cmesdata":
|
||||
return CmesCnProvider(token=config.cmes_token)
|
||||
if config.provider == "futu" and config.channel in {"cn", "hk", "us"}:
|
||||
from market.providers.futu_provider import FutuProvider
|
||||
|
||||
return FutuProvider(
|
||||
channel=config.channel,
|
||||
host=config.futu_host,
|
||||
port=config.futu_port,
|
||||
is_encrypt=config.futu_is_encrypt,
|
||||
market=config.futu_market,
|
||||
)
|
||||
if config.channel in {"us", "hk"}:
|
||||
raise RuntimeError(
|
||||
f"channel={config.channel} provider={config.provider} 尚未实现,"
|
||||
"请新增 Provider 后接入 factory。"
|
||||
)
|
||||
raise RuntimeError(f"不支持的通道配置: channel={config.channel}, provider={config.provider}")
|
||||
30
python-app/app/market/provider_base.py
Normal file
30
python-app/app/market/provider_base.py
Normal file
@@ -0,0 +1,30 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from datetime import datetime
|
||||
|
||||
import pandas as pd
|
||||
|
||||
|
||||
class MarketDataProvider(ABC):
|
||||
@property
|
||||
@abstractmethod
|
||||
def provider_name(self) -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def channel(self) -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def fetch_spot(self) -> pd.DataFrame:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def search_spot(self, query: str, limit: int) -> pd.DataFrame:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def fetch_daily_kline(self, code: str, start: datetime, end: datetime) -> pd.DataFrame:
|
||||
raise NotImplementedError
|
||||
12
python-app/app/market/providers/__init__.py
Normal file
12
python-app/app/market/providers/__init__.py
Normal file
@@ -0,0 +1,12 @@
|
||||
from market.providers.akshare_provider import AkshareCnProvider
|
||||
from market.providers.cmes_provider import CmesCnProvider
|
||||
|
||||
__all__ = ["AkshareCnProvider", "CmesCnProvider"]
|
||||
|
||||
try:
|
||||
from market.providers.futu_provider import FutuProvider
|
||||
|
||||
__all__.append("FutuProvider")
|
||||
except Exception:
|
||||
# Allow non-futu environments to keep using other providers.
|
||||
pass
|
||||
63
python-app/app/market/providers/akshare_provider.py
Normal file
63
python-app/app/market/providers/akshare_provider.py
Normal file
@@ -0,0 +1,63 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
|
||||
import akshare as ak
|
||||
import pandas as pd
|
||||
|
||||
from market.provider_base import MarketDataProvider
|
||||
|
||||
|
||||
class AkshareCnProvider(MarketDataProvider):
|
||||
@property
|
||||
def provider_name(self) -> str:
|
||||
return "akshare"
|
||||
|
||||
@property
|
||||
def channel(self) -> str:
|
||||
return "cn"
|
||||
|
||||
def fetch_spot(self) -> pd.DataFrame:
|
||||
df = ak.stock_zh_a_spot_em()
|
||||
if df.empty:
|
||||
raise RuntimeError("未获取到 A 股实时行情数据")
|
||||
return df
|
||||
|
||||
def search_spot(self, query: str, limit: int) -> pd.DataFrame:
|
||||
q = query.strip().lower()
|
||||
if not q:
|
||||
return pd.DataFrame()
|
||||
|
||||
df = self.fetch_spot()
|
||||
code = df["代码"].astype(str).str.lower()
|
||||
name = df["名称"].astype(str).str.lower()
|
||||
|
||||
exact = df[(code == q) | (name == q)]
|
||||
if not exact.empty:
|
||||
return exact.head(limit)
|
||||
|
||||
starts = df[code.str.startswith(q) | name.str.startswith(q)]
|
||||
if not starts.empty:
|
||||
return starts.head(limit)
|
||||
|
||||
return df[code.str.contains(q, na=False) | name.str.contains(q, na=False)].head(limit)
|
||||
|
||||
def fetch_daily_kline(self, code: str, start: datetime, end: datetime) -> pd.DataFrame:
|
||||
hist = ak.stock_zh_a_hist(
|
||||
symbol=code,
|
||||
period="daily",
|
||||
start_date=start.strftime("%Y%m%d"),
|
||||
end_date=end.strftime("%Y%m%d"),
|
||||
adjust="qfq",
|
||||
)
|
||||
if hist.empty:
|
||||
raise RuntimeError(f"未获取到 K 线数据: {code}")
|
||||
|
||||
frame = pd.DataFrame(
|
||||
{
|
||||
"date": pd.to_datetime(hist["日期"]),
|
||||
"close": pd.to_numeric(hist["收盘"], errors="coerce"),
|
||||
"volume": pd.to_numeric(hist["成交量"], errors="coerce"),
|
||||
}
|
||||
).dropna()
|
||||
return frame.sort_values("date").reset_index(drop=True)
|
||||
70
python-app/app/market/providers/cmes_provider.py
Normal file
70
python-app/app/market/providers/cmes_provider.py
Normal file
@@ -0,0 +1,70 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib
|
||||
import os
|
||||
from datetime import datetime
|
||||
|
||||
import pandas as pd
|
||||
|
||||
from market.provider_base import MarketDataProvider
|
||||
|
||||
|
||||
class CmesCnProvider(MarketDataProvider):
|
||||
def __init__(self, token: str | None = None) -> None:
|
||||
self._token = token or os.getenv("CMES_TOKEN", "").strip()
|
||||
self._module = importlib.import_module("cmesdata")
|
||||
self._login_once()
|
||||
|
||||
@property
|
||||
def provider_name(self) -> str:
|
||||
return "cmesdata"
|
||||
|
||||
@property
|
||||
def channel(self) -> str:
|
||||
return "cn"
|
||||
|
||||
def _login_once(self) -> None:
|
||||
if not self._token:
|
||||
raise RuntimeError("CMES_TOKEN 未配置,无法使用 cmesdata 通道")
|
||||
self._module.login(self._token)
|
||||
|
||||
@staticmethod
|
||||
def _to_prefixed_code(code: str) -> str:
|
||||
raw = code.strip().upper().replace(".", "")
|
||||
if raw.startswith("SH") or raw.startswith("SZ"):
|
||||
return f"{raw[:2]}.{raw[2:]}"
|
||||
if raw.isdigit() and len(raw) == 6:
|
||||
if raw.startswith(("6", "9")):
|
||||
return f"SH.{raw}"
|
||||
return f"SZ.{raw}"
|
||||
raise RuntimeError("cmesdata 通道仅支持 6 位 A 股代码或 SH./SZ. 前缀代码")
|
||||
|
||||
def fetch_spot(self) -> pd.DataFrame:
|
||||
raise RuntimeError("cmesdata 不支持全市场快照拉取,请通过精确代码查询")
|
||||
|
||||
def search_spot(self, query: str, limit: int) -> pd.DataFrame:
|
||||
_ = limit
|
||||
code = self._to_prefixed_code(query)
|
||||
df = self._module.get_real_hq([code])
|
||||
if df is None or df.empty:
|
||||
raise RuntimeError(f"未获取到实时行情: {code}")
|
||||
return df
|
||||
|
||||
def fetch_daily_kline(self, code: str, start: datetime, end: datetime) -> pd.DataFrame:
|
||||
prefixed = self._to_prefixed_code(code)
|
||||
df = self._module.get_history_data(
|
||||
prefixed,
|
||||
start.strftime("%Y-%m-%d"),
|
||||
end.strftime("%Y-%m-%d"),
|
||||
"D",
|
||||
)
|
||||
if df is None or df.empty:
|
||||
raise RuntimeError(f"未获取到历史 K 线: {prefixed}")
|
||||
frame = pd.DataFrame(
|
||||
{
|
||||
"date": pd.to_datetime(df["时间"]),
|
||||
"close": pd.to_numeric(df["收盘价"], errors="coerce"),
|
||||
"volume": pd.to_numeric(df["成交量"], errors="coerce"),
|
||||
}
|
||||
).dropna()
|
||||
return frame.sort_values("date").reset_index(drop=True)
|
||||
207
python-app/app/market/providers/futu_provider.py
Normal file
207
python-app/app/market/providers/futu_provider.py
Normal file
@@ -0,0 +1,207 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
|
||||
import pandas as pd
|
||||
from futu import AuType, KLType, Market, OpenQuoteContext, RET_OK, SecurityType
|
||||
|
||||
from market.provider_base import MarketDataProvider
|
||||
|
||||
|
||||
class FutuProvider(MarketDataProvider):
|
||||
def __init__(
|
||||
self,
|
||||
channel: str,
|
||||
host: str = "127.0.0.1",
|
||||
port: int = 11111,
|
||||
is_encrypt: bool | None = None,
|
||||
market: str = "",
|
||||
) -> None:
|
||||
self._channel = channel.strip().lower()
|
||||
self._host = host
|
||||
self._port = int(port)
|
||||
self._is_encrypt = is_encrypt
|
||||
self._market = (market or self._channel).strip().lower()
|
||||
self._ctx = OpenQuoteContext(host=self._host, port=self._port, is_encrypt=self._is_encrypt)
|
||||
self._basicinfo_cache: pd.DataFrame | None = None
|
||||
|
||||
@property
|
||||
def provider_name(self) -> str:
|
||||
return "futu"
|
||||
|
||||
@property
|
||||
def channel(self) -> str:
|
||||
return self._channel
|
||||
|
||||
def __del__(self) -> None:
|
||||
try:
|
||||
self._ctx.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def _require_market(self) -> Market:
|
||||
if self._market == "cn":
|
||||
return Market.SH
|
||||
if self._market == "hk":
|
||||
return Market.HK
|
||||
if self._market == "us":
|
||||
return Market.US
|
||||
raise RuntimeError(f"不支持的 FUTU_MARKET: {self._market},可选: cn/hk/us")
|
||||
|
||||
def _normalize_code(self, query: str) -> str:
|
||||
q = query.strip().upper()
|
||||
if not q:
|
||||
return ""
|
||||
if "." in q:
|
||||
return q
|
||||
if self._market == "cn":
|
||||
if q.isdigit() and len(q) == 6:
|
||||
prefix = "SH" if q.startswith(("5", "6", "9")) else "SZ"
|
||||
return f"{prefix}.{q}"
|
||||
return q
|
||||
if self._market == "hk":
|
||||
if q.isdigit():
|
||||
return f"HK.{q.zfill(5)}"
|
||||
return f"HK.{q}"
|
||||
if self._market == "us":
|
||||
return f"US.{q}"
|
||||
return q
|
||||
|
||||
def _snapshot_to_unified(self, df: pd.DataFrame) -> pd.DataFrame:
|
||||
frame = df.copy()
|
||||
if "code" not in frame.columns:
|
||||
raise RuntimeError("Futu 快照返回缺少 code 字段")
|
||||
if "name" not in frame.columns:
|
||||
frame["name"] = ""
|
||||
frame["涨跌额"] = pd.to_numeric(frame.get("last_price"), errors="coerce") - pd.to_numeric(
|
||||
frame.get("prev_close_price"), errors="coerce"
|
||||
)
|
||||
frame["市盈率-动态"] = frame.get("pe_ttm_ratio", frame.get("pe_ratio"))
|
||||
frame["市净率"] = frame.get("pb_ratio")
|
||||
frame["总市值"] = frame.get("total_market_val")
|
||||
frame["流通市值"] = frame.get("circular_market_val")
|
||||
frame["代码"] = frame["code"].astype(str)
|
||||
frame["名称"] = frame["name"]
|
||||
frame["最新价"] = frame.get("last_price")
|
||||
frame["涨跌幅"] = frame.get("change_rate")
|
||||
frame["成交量"] = frame.get("volume")
|
||||
frame["成交额"] = frame.get("turnover")
|
||||
frame["振幅"] = frame.get("amplitude")
|
||||
frame["最高"] = frame.get("high_price")
|
||||
frame["最低"] = frame.get("low_price")
|
||||
frame["今开"] = frame.get("open_price")
|
||||
frame["昨收"] = frame.get("prev_close_price")
|
||||
columns = [
|
||||
"代码",
|
||||
"名称",
|
||||
"最新价",
|
||||
"涨跌幅",
|
||||
"涨跌额",
|
||||
"成交量",
|
||||
"成交额",
|
||||
"振幅",
|
||||
"最高",
|
||||
"最低",
|
||||
"今开",
|
||||
"昨收",
|
||||
"市盈率-动态",
|
||||
"市净率",
|
||||
"总市值",
|
||||
"流通市值",
|
||||
]
|
||||
return frame[columns]
|
||||
|
||||
def _get_snapshot(self, codes: list[str]) -> pd.DataFrame:
|
||||
ret, data = self._ctx.get_market_snapshot(codes)
|
||||
if ret != RET_OK:
|
||||
raise RuntimeError(f"Futu 获取快照失败: {data}")
|
||||
if data is None or data.empty:
|
||||
return pd.DataFrame()
|
||||
return self._snapshot_to_unified(data)
|
||||
|
||||
def _load_basicinfo(self) -> pd.DataFrame:
|
||||
if self._basicinfo_cache is not None:
|
||||
return self._basicinfo_cache
|
||||
market = self._require_market()
|
||||
ret, data = self._ctx.get_stock_basicinfo(market=market, stock_type=SecurityType.STOCK)
|
||||
if ret != RET_OK:
|
||||
raise RuntimeError(f"Futu 获取股票基础信息失败: {data}")
|
||||
if data is None:
|
||||
self._basicinfo_cache = pd.DataFrame(columns=["code", "name"])
|
||||
return self._basicinfo_cache
|
||||
frame = data.copy()
|
||||
frame["code"] = frame["code"].astype(str)
|
||||
frame["name"] = frame["name"].astype(str)
|
||||
self._basicinfo_cache = frame
|
||||
return frame
|
||||
|
||||
def fetch_spot(self) -> pd.DataFrame:
|
||||
raise RuntimeError("Futu 不支持直接拉取全市场实时快照,请使用 search_spot")
|
||||
|
||||
def search_spot(self, query: str, limit: int) -> pd.DataFrame:
|
||||
q = query.strip()
|
||||
if not q:
|
||||
return pd.DataFrame()
|
||||
|
||||
code = self._normalize_code(q)
|
||||
if code:
|
||||
exact = self._get_snapshot([code])
|
||||
if not exact.empty:
|
||||
return exact.head(limit)
|
||||
|
||||
basic = self._load_basicinfo()
|
||||
q_lower = q.lower()
|
||||
code_col = basic["code"].astype(str)
|
||||
name_col = basic["name"].astype(str)
|
||||
mask = (
|
||||
code_col.str.lower().eq(q_lower)
|
||||
| name_col.str.lower().eq(q_lower)
|
||||
| code_col.str.lower().str.startswith(q_lower)
|
||||
| name_col.str.lower().str.startswith(q_lower)
|
||||
| code_col.str.lower().str.contains(q_lower, na=False)
|
||||
| name_col.str.lower().str.contains(q_lower, na=False)
|
||||
)
|
||||
candidates = basic.loc[mask, "code"].drop_duplicates().head(max(limit * 2, 20)).tolist()
|
||||
if not candidates:
|
||||
return pd.DataFrame()
|
||||
snap = self._get_snapshot(candidates)
|
||||
if snap.empty:
|
||||
return pd.DataFrame()
|
||||
return snap.head(limit)
|
||||
|
||||
def fetch_daily_kline(self, code: str, start: datetime, end: datetime) -> pd.DataFrame:
|
||||
normalized = self._normalize_code(code)
|
||||
page_key = None
|
||||
frames: list[pd.DataFrame] = []
|
||||
|
||||
while True:
|
||||
ret, data, page_key = self._ctx.request_history_kline(
|
||||
normalized,
|
||||
start=start.strftime("%Y-%m-%d"),
|
||||
end=end.strftime("%Y-%m-%d"),
|
||||
ktype=KLType.K_DAY,
|
||||
autype=AuType.QFQ,
|
||||
max_count=1000,
|
||||
page_req_key=page_key,
|
||||
)
|
||||
if ret != RET_OK:
|
||||
raise RuntimeError(f"Futu 获取历史 K 线失败: {data}")
|
||||
if data is not None and not data.empty:
|
||||
frames.append(data.copy())
|
||||
if page_key is None:
|
||||
break
|
||||
|
||||
if not frames:
|
||||
raise RuntimeError(f"未获取到 K 线数据: {normalized}")
|
||||
|
||||
full = pd.concat(frames, ignore_index=True)
|
||||
frame = pd.DataFrame(
|
||||
{
|
||||
"date": pd.to_datetime(full["time_key"]),
|
||||
"close": pd.to_numeric(full["close"], errors="coerce"),
|
||||
"volume": pd.to_numeric(full["volume"], errors="coerce"),
|
||||
}
|
||||
).dropna()
|
||||
if frame.empty:
|
||||
raise RuntimeError(f"K 线数据为空: {normalized}")
|
||||
return frame.sort_values("date").reset_index(drop=True)
|
||||
164
python-app/app/market/service.py
Normal file
164
python-app/app/market/service.py
Normal file
@@ -0,0 +1,164 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Any
|
||||
from time import time
|
||||
|
||||
import pandas as pd
|
||||
|
||||
from market.provider_base import MarketDataProvider
|
||||
|
||||
|
||||
DISPLAY_FIELDS = [
|
||||
"代码",
|
||||
"名称",
|
||||
"最新价",
|
||||
"涨跌幅",
|
||||
"涨跌额",
|
||||
"成交量",
|
||||
"成交额",
|
||||
"振幅",
|
||||
"最高",
|
||||
"最低",
|
||||
"今开",
|
||||
"昨收",
|
||||
"市盈率-动态",
|
||||
"市净率",
|
||||
"总市值",
|
||||
"流通市值",
|
||||
]
|
||||
|
||||
|
||||
def _safe(v: Any) -> Any:
|
||||
if v is None:
|
||||
return None
|
||||
if isinstance(v, float) and pd.isna(v):
|
||||
return None
|
||||
return v
|
||||
|
||||
|
||||
def build_signals(hist: pd.DataFrame, sma_fast: int, sma_slow: int) -> list[dict[str, Any]]:
|
||||
frame = hist.copy()
|
||||
frame["sma_fast"] = frame["close"].rolling(sma_fast).mean()
|
||||
frame["sma_slow"] = frame["close"].rolling(sma_slow).mean()
|
||||
signals: list[dict[str, Any]] = []
|
||||
|
||||
for i in range(1, len(frame)):
|
||||
prev = frame.iloc[i - 1]
|
||||
curr = frame.iloc[i]
|
||||
if pd.isna(prev["sma_fast"]) or pd.isna(prev["sma_slow"]) or pd.isna(curr["sma_fast"]) or pd.isna(curr["sma_slow"]):
|
||||
continue
|
||||
cross_up = prev["sma_fast"] <= prev["sma_slow"] and curr["sma_fast"] > curr["sma_slow"]
|
||||
cross_down = prev["sma_fast"] >= prev["sma_slow"] and curr["sma_fast"] < curr["sma_slow"]
|
||||
if cross_up:
|
||||
signals.append({"date": curr["date"].strftime("%Y-%m-%d"), "type": "buy", "price": float(curr["close"])})
|
||||
elif cross_down:
|
||||
signals.append({"date": curr["date"].strftime("%Y-%m-%d"), "type": "sell", "price": float(curr["close"])})
|
||||
return signals
|
||||
|
||||
|
||||
def resolve_candidates(provider: MarketDataProvider, query: str, limit: int) -> pd.DataFrame:
|
||||
return provider.search_spot(query=query, limit=limit)
|
||||
|
||||
|
||||
def _pick_first_candidate(provider: MarketDataProvider, query: str) -> pd.Series:
|
||||
candidates = resolve_candidates(provider, query=query, limit=1)
|
||||
if candidates.empty:
|
||||
raise RuntimeError(f"未找到匹配股票: {query}")
|
||||
row = candidates.iloc[0]
|
||||
code = str(row.get("代码", "")).strip()
|
||||
if not code:
|
||||
raise RuntimeError("行情返回缺少代码字段")
|
||||
return row
|
||||
|
||||
|
||||
def _build_info(provider: MarketDataProvider, row: pd.Series) -> dict[str, Any]:
|
||||
code = str(row.get("代码", "")).strip()
|
||||
return {
|
||||
"code": code,
|
||||
"name": str(row.get("名称", "")),
|
||||
"price": _safe(row.get("最新价")),
|
||||
"change_pct": _safe(row.get("涨跌幅")),
|
||||
"change": _safe(row.get("涨跌额")),
|
||||
"volume": _safe(row.get("成交量")),
|
||||
"turnover": _safe(row.get("成交额")),
|
||||
"amplitude": _safe(row.get("振幅")),
|
||||
"high": _safe(row.get("最高")),
|
||||
"low": _safe(row.get("最低")),
|
||||
"open": _safe(row.get("今开")),
|
||||
"prev_close": _safe(row.get("昨收")),
|
||||
"pe": _safe(row.get("市盈率-动态")),
|
||||
"pb": _safe(row.get("市净率")),
|
||||
"market_cap": _safe(row.get("总市值")),
|
||||
"float_market_cap": _safe(row.get("流通市值")),
|
||||
"provider": provider.provider_name,
|
||||
"channel": provider.channel,
|
||||
}
|
||||
|
||||
|
||||
def build_realtime_info(provider: MarketDataProvider, query: str) -> dict[str, Any]:
|
||||
row = _pick_first_candidate(provider, query=query)
|
||||
info = _build_info(provider=provider, row=row)
|
||||
now = datetime.now()
|
||||
price = info.get("price")
|
||||
volume = info.get("volume")
|
||||
realtime_point = {
|
||||
"date": now.strftime("%Y-%m-%d"),
|
||||
"close": float(price) if price is not None else None,
|
||||
"volume": float(volume) if volume is not None else 0.0,
|
||||
}
|
||||
return {
|
||||
"info": info,
|
||||
"realtime_point": realtime_point,
|
||||
"source": "realtime",
|
||||
"updated_at": now.strftime("%Y-%m-%d %H:%M:%S"),
|
||||
}
|
||||
|
||||
|
||||
_DASHBOARD_CACHE: dict[str, tuple[float, dict[str, Any]]] = {}
|
||||
_DASHBOARD_CACHE_TTL_SECONDS = 15.0
|
||||
|
||||
|
||||
def build_dashboard(
|
||||
provider: MarketDataProvider,
|
||||
query: str,
|
||||
days: int,
|
||||
sma_fast: int,
|
||||
sma_slow: int,
|
||||
) -> dict[str, Any]:
|
||||
if sma_fast >= sma_slow:
|
||||
raise RuntimeError("sma_fast must be less than sma_slow")
|
||||
|
||||
row = _pick_first_candidate(provider, query=query)
|
||||
code = str(row.get("代码", "")).strip()
|
||||
cache_key = "|".join([provider.provider_name, provider.channel, code, str(days), str(sma_fast), str(sma_slow)])
|
||||
now = time()
|
||||
cached = _DASHBOARD_CACHE.get(cache_key)
|
||||
if cached is not None and now - cached[0] <= _DASHBOARD_CACHE_TTL_SECONDS:
|
||||
return cached[1]
|
||||
|
||||
end = datetime.now()
|
||||
lookback_days = max(days + (sma_slow * 3), days + 30)
|
||||
start = end - timedelta(days=lookback_days)
|
||||
hist = provider.fetch_daily_kline(code=code, start=start, end=end).tail(days).reset_index(drop=True)
|
||||
if hist.empty:
|
||||
raise RuntimeError(f"未获取到 K 线数据: {code}")
|
||||
|
||||
signals = build_signals(hist, sma_fast=sma_fast, sma_slow=sma_slow)
|
||||
hist["sma_fast"] = hist["close"].rolling(sma_fast).mean()
|
||||
hist["sma_slow"] = hist["close"].rolling(sma_slow).mean()
|
||||
|
||||
points = [
|
||||
{
|
||||
"date": d.strftime("%Y-%m-%d"),
|
||||
"close": float(c),
|
||||
"sma_fast": _safe(f),
|
||||
"sma_slow": _safe(s),
|
||||
"volume": float(v),
|
||||
}
|
||||
for d, c, f, s, v in zip(hist["date"], hist["close"], hist["sma_fast"], hist["sma_slow"], hist["volume"])
|
||||
]
|
||||
|
||||
result = {"info": _build_info(provider=provider, row=row), "points": points, "signals": signals}
|
||||
_DASHBOARD_CACHE[cache_key] = (now, result)
|
||||
return result
|
||||
32
python-app/app/market/types.py
Normal file
32
python-app/app/market/types.py
Normal file
@@ -0,0 +1,32 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
|
||||
@dataclass
|
||||
class SpotRow:
|
||||
code: str
|
||||
name: str
|
||||
price: Any
|
||||
change_pct: Any
|
||||
change: Any
|
||||
volume: Any
|
||||
turnover: Any
|
||||
amplitude: Any
|
||||
high: Any
|
||||
low: Any
|
||||
open: Any
|
||||
prev_close: Any
|
||||
pe: Any
|
||||
pb: Any
|
||||
market_cap: Any
|
||||
float_market_cap: Any
|
||||
|
||||
|
||||
@dataclass
|
||||
class KlineRow:
|
||||
date: datetime
|
||||
close: float
|
||||
volume: float
|
||||
20
python-app/app/sample_data/klines.jsonl
Normal file
20
python-app/app/sample_data/klines.jsonl
Normal file
@@ -0,0 +1,20 @@
|
||||
{"type":"kline","source":"binance_ws","symbol":"btcusdt","interval":"1m","event_time":1711430400000,"open_time":1711430400000,"close_time":1711430459999,"open":"68000.00","high":"68025.00","low":"67990.00","close":"68010.00","volume":"12.40","trade_num":102,"final":true}
|
||||
{"type":"kline","source":"binance_ws","symbol":"btcusdt","interval":"1m","event_time":1711430460000,"open_time":1711430460000,"close_time":1711430519999,"open":"68010.00","high":"68040.00","low":"68000.00","close":"68028.00","volume":"10.12","trade_num":96,"final":true}
|
||||
{"type":"kline","source":"binance_ws","symbol":"btcusdt","interval":"1m","event_time":1711430520000,"open_time":1711430520000,"close_time":1711430579999,"open":"68028.00","high":"68036.00","low":"68005.00","close":"68012.00","volume":"8.80","trade_num":87,"final":true}
|
||||
{"type":"kline","source":"binance_ws","symbol":"btcusdt","interval":"1m","event_time":1711430580000,"open_time":1711430580000,"close_time":1711430639999,"open":"68012.00","high":"68022.00","low":"67980.00","close":"67992.00","volume":"11.50","trade_num":120,"final":true}
|
||||
{"type":"kline","source":"binance_ws","symbol":"btcusdt","interval":"1m","event_time":1711430640000,"open_time":1711430640000,"close_time":1711430699999,"open":"67992.00","high":"68018.00","low":"67970.00","close":"68005.00","volume":"9.64","trade_num":91,"final":true}
|
||||
{"type":"kline","source":"binance_ws","symbol":"btcusdt","interval":"1m","event_time":1711430700000,"open_time":1711430700000,"close_time":1711430759999,"open":"68005.00","high":"68055.00","low":"68002.00","close":"68048.00","volume":"14.03","trade_num":133,"final":true}
|
||||
{"type":"kline","source":"binance_ws","symbol":"btcusdt","interval":"1m","event_time":1711430760000,"open_time":1711430760000,"close_time":1711430819999,"open":"68048.00","high":"68070.00","low":"68030.00","close":"68062.00","volume":"13.38","trade_num":121,"final":true}
|
||||
{"type":"kline","source":"binance_ws","symbol":"btcusdt","interval":"1m","event_time":1711430820000,"open_time":1711430820000,"close_time":1711430879999,"open":"68062.00","high":"68080.00","low":"68040.00","close":"68044.00","volume":"10.99","trade_num":99,"final":true}
|
||||
{"type":"kline","source":"binance_ws","symbol":"btcusdt","interval":"1m","event_time":1711430880000,"open_time":1711430880000,"close_time":1711430939999,"open":"68044.00","high":"68046.00","low":"68000.00","close":"68006.00","volume":"12.75","trade_num":115,"final":true}
|
||||
{"type":"kline","source":"binance_ws","symbol":"btcusdt","interval":"1m","event_time":1711430940000,"open_time":1711430940000,"close_time":1711430999999,"open":"68006.00","high":"68012.00","low":"67972.00","close":"67986.00","volume":"11.66","trade_num":105,"final":true}
|
||||
{"type":"kline","source":"binance_ws","symbol":"btcusdt","interval":"1m","event_time":1711431000000,"open_time":1711431000000,"close_time":1711431059999,"open":"67986.00","high":"68008.00","low":"67970.00","close":"67996.00","volume":"9.21","trade_num":88,"final":true}
|
||||
{"type":"kline","source":"binance_ws","symbol":"btcusdt","interval":"1m","event_time":1711431060000,"open_time":1711431060000,"close_time":1711431119999,"open":"67996.00","high":"68035.00","low":"67992.00","close":"68020.00","volume":"10.77","trade_num":94,"final":true}
|
||||
{"type":"kline","source":"binance_ws","symbol":"btcusdt","interval":"1m","event_time":1711431120000,"open_time":1711431120000,"close_time":1711431179999,"open":"68020.00","high":"68058.00","low":"68018.00","close":"68052.00","volume":"12.62","trade_num":110,"final":true}
|
||||
{"type":"kline","source":"binance_ws","symbol":"btcusdt","interval":"1m","event_time":1711431180000,"open_time":1711431180000,"close_time":1711431239999,"open":"68052.00","high":"68066.00","low":"68030.00","close":"68042.00","volume":"8.19","trade_num":76,"final":true}
|
||||
{"type":"kline","source":"binance_ws","symbol":"btcusdt","interval":"1m","event_time":1711431240000,"open_time":1711431240000,"close_time":1711431299999,"open":"68042.00","high":"68052.00","low":"68008.00","close":"68016.00","volume":"10.40","trade_num":93,"final":true}
|
||||
{"type":"kline","source":"binance_ws","symbol":"btcusdt","interval":"1m","event_time":1711431300000,"open_time":1711431300000,"close_time":1711431359999,"open":"68016.00","high":"68018.00","low":"67986.00","close":"67998.00","volume":"9.07","trade_num":84,"final":true}
|
||||
{"type":"kline","source":"binance_ws","symbol":"btcusdt","interval":"1m","event_time":1711431360000,"open_time":1711431360000,"close_time":1711431419999,"open":"67998.00","high":"68012.00","low":"67974.00","close":"67988.00","volume":"7.55","trade_num":69,"final":true}
|
||||
{"type":"kline","source":"binance_ws","symbol":"btcusdt","interval":"1m","event_time":1711431420000,"open_time":1711431420000,"close_time":1711431479999,"open":"67988.00","high":"68028.00","low":"67982.00","close":"68024.00","volume":"11.88","trade_num":101,"final":true}
|
||||
{"type":"kline","source":"binance_ws","symbol":"btcusdt","interval":"1m","event_time":1711431480000,"open_time":1711431480000,"close_time":1711431539999,"open":"68024.00","high":"68070.00","low":"68020.00","close":"68062.00","volume":"13.45","trade_num":124,"final":true}
|
||||
{"type":"kline","source":"binance_ws","symbol":"btcusdt","interval":"1m","event_time":1711431540000,"open_time":1711431540000,"close_time":1711431599999,"open":"68062.00","high":"68082.00","low":"68048.00","close":"68074.00","volume":"12.03","trade_num":111,"final":true}
|
||||
91
python-app/app/stock_lookup.py
Normal file
91
python-app/app/stock_lookup.py
Normal file
@@ -0,0 +1,91 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import pandas as pd
|
||||
|
||||
from market import create_provider, load_market_config
|
||||
from market.service import DISPLAY_FIELDS, resolve_candidates
|
||||
|
||||
|
||||
def parse_args() -> argparse.Namespace:
|
||||
parser = argparse.ArgumentParser(description="Stock quick lookup (A-share)")
|
||||
parser.add_argument("query", nargs="?", default="", help="stock code/name/abbr")
|
||||
parser.add_argument("--limit", type=int, default=8, help="max candidate rows")
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def safe_str(value: object) -> str:
|
||||
if value is None:
|
||||
return ""
|
||||
if isinstance(value, float) and pd.isna(value):
|
||||
return ""
|
||||
return str(value).strip()
|
||||
|
||||
|
||||
def print_row(row: pd.Series) -> None:
|
||||
print("\n===== 股票信息 =====")
|
||||
for field in DISPLAY_FIELDS:
|
||||
if field in row.index:
|
||||
print(f"{field}: {safe_str(row[field])}")
|
||||
print("===================\n")
|
||||
|
||||
|
||||
def pick_candidate(candidates: pd.DataFrame) -> pd.Series | None:
|
||||
if len(candidates) == 1:
|
||||
return candidates.iloc[0]
|
||||
|
||||
print("\n匹配到多个标的,请选择:")
|
||||
for idx, (_, row) in enumerate(candidates.iterrows(), start=1):
|
||||
print(f"{idx}. {safe_str(row['代码'])} {safe_str(row['名称'])} 最新价={safe_str(row.get('最新价'))}")
|
||||
|
||||
choice = input("输入序号(直接回车默认 1): ").strip()
|
||||
if not choice:
|
||||
return candidates.iloc[0]
|
||||
if not choice.isdigit():
|
||||
print("输入无效")
|
||||
return None
|
||||
pos = int(choice)
|
||||
if pos < 1 or pos > len(candidates):
|
||||
print("输入超出范围")
|
||||
return None
|
||||
return candidates.iloc[pos - 1]
|
||||
|
||||
|
||||
def run_once(query: str, limit: int) -> None:
|
||||
provider = create_provider(load_market_config())
|
||||
candidates = resolve_candidates(provider, query, limit)
|
||||
if candidates.empty:
|
||||
print(f"未找到匹配股票: {query}")
|
||||
return
|
||||
selected = pick_candidate(candidates)
|
||||
if selected is None:
|
||||
return
|
||||
print_row(selected)
|
||||
|
||||
|
||||
def interactive_loop(limit: int) -> None:
|
||||
print("股票速查已启动,输入代码/名称(输入 q 退出)")
|
||||
while True:
|
||||
query = input("查询> ").strip()
|
||||
if query.lower() in {"q", "quit", "exit"}:
|
||||
print("已退出股票速查")
|
||||
return
|
||||
if not query:
|
||||
continue
|
||||
try:
|
||||
run_once(query, limit)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
print(f"查询失败: {exc}")
|
||||
|
||||
|
||||
def main() -> None:
|
||||
args = parse_args()
|
||||
if not args.query:
|
||||
interactive_loop(args.limit)
|
||||
return
|
||||
|
||||
run_once(args.query, args.limit)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
3
python-app/app/strategies/__init__.py
Normal file
3
python-app/app/strategies/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from .sma_cross import SmaCrossStrategy
|
||||
|
||||
__all__ = ["SmaCrossStrategy"]
|
||||
22
python-app/app/strategies/sma_cross.py
Normal file
22
python-app/app/strategies/sma_cross.py
Normal file
@@ -0,0 +1,22 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import backtrader as bt
|
||||
|
||||
|
||||
class SmaCrossStrategy(bt.Strategy):
|
||||
params = (
|
||||
("fast", 3),
|
||||
("slow", 8),
|
||||
("stake", 1),
|
||||
)
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.fast_sma = bt.indicators.SMA(self.data.close, period=self.params.fast)
|
||||
self.slow_sma = bt.indicators.SMA(self.data.close, period=self.params.slow)
|
||||
self.cross_signal = bt.indicators.CrossOver(self.fast_sma, self.slow_sma)
|
||||
|
||||
def next(self) -> None:
|
||||
if not self.position and self.cross_signal > 0:
|
||||
self.buy(size=self.params.stake)
|
||||
elif self.position and self.cross_signal < 0:
|
||||
self.close()
|
||||
211
python-app/app/web/index.html
Normal file
211
python-app/app/web/index.html
Normal file
@@ -0,0 +1,211 @@
|
||||
<!DOCTYPE html>
|
||||
<html lang="zh-CN">
|
||||
<head>
|
||||
<meta charset="UTF-8" />
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
|
||||
<title>AITrading 股票看板</title>
|
||||
<script src="https://cdn.jsdelivr.net/npm/chart.js@4.4.3/dist/chart.umd.min.js"></script>
|
||||
<style>
|
||||
body { font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Arial, sans-serif; margin: 20px; background:#f7f8fa; color:#1f2937; }
|
||||
.card { background:#fff; border-radius:10px; box-shadow:0 2px 8px rgba(0,0,0,.06); padding:16px; margin-bottom:16px; }
|
||||
.row { display:flex; gap:12px; flex-wrap:wrap; }
|
||||
input, button { padding:8px 10px; border-radius:8px; border:1px solid #d1d5db; }
|
||||
button { background:#2563eb; color:#fff; border:none; cursor:pointer; }
|
||||
.grid { display:grid; grid-template-columns: repeat(3, minmax(180px,1fr)); gap:8px; }
|
||||
.item { background:#f9fafb; border-radius:8px; padding:8px; }
|
||||
.label { color:#6b7280; font-size:12px; }
|
||||
.value { font-size:14px; font-weight:600; }
|
||||
#status { font-size:13px; color:#6b7280; }
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<div class="card">
|
||||
<h2>股票信息与买卖点</h2>
|
||||
<div class="row">
|
||||
<input id="query" placeholder="输入股票代码或名称,如 600519 / 贵州茅台" style="min-width:300px;" />
|
||||
<input id="fast" type="number" min="2" max="60" value="5" />
|
||||
<input id="slow" type="number" min="3" max="120" value="20" />
|
||||
<button onclick="searchStock()">查询</button>
|
||||
</div>
|
||||
<p id="status">请输入代码或名称并点击查询。</p>
|
||||
</div>
|
||||
|
||||
<div class="card">
|
||||
<div id="info" class="grid"></div>
|
||||
</div>
|
||||
|
||||
<div class="card">
|
||||
<canvas id="chart" height="120"></canvas>
|
||||
</div>
|
||||
|
||||
<script>
|
||||
let chart;
|
||||
let realtimeWs = null;
|
||||
let currentPoints = [];
|
||||
let currentSignals = [];
|
||||
let currentFast = 5;
|
||||
let currentSlow = 20;
|
||||
|
||||
function renderInfo(info) {
|
||||
const fields = [
|
||||
["代码", info.code], ["名称", info.name], ["最新价", info.price],
|
||||
["涨跌幅", info.change_pct], ["涨跌额", info.change], ["成交额", info.turnover],
|
||||
["PE(动态)", info.pe], ["PB", info.pb], ["总市值", info.market_cap]
|
||||
];
|
||||
const container = document.getElementById("info");
|
||||
container.innerHTML = fields.map(([k, v]) =>
|
||||
`<div class="item"><div class="label">${k}</div><div class="value">${v ?? "-"}</div></div>`
|
||||
).join("");
|
||||
}
|
||||
|
||||
function buildSignals(points, fast, slow) {
|
||||
const signals = [];
|
||||
for (let i = 1; i < points.length; i++) {
|
||||
const prev = points[i - 1];
|
||||
const curr = points[i];
|
||||
if (prev.sma_fast == null || prev.sma_slow == null || curr.sma_fast == null || curr.sma_slow == null) continue;
|
||||
const crossUp = prev.sma_fast <= prev.sma_slow && curr.sma_fast > curr.sma_slow;
|
||||
const crossDown = prev.sma_fast >= prev.sma_slow && curr.sma_fast < curr.sma_slow;
|
||||
if (crossUp) signals.push({ date: curr.date, type: "buy", price: curr.close });
|
||||
else if (crossDown) signals.push({ date: curr.date, type: "sell", price: curr.close });
|
||||
}
|
||||
return signals;
|
||||
}
|
||||
|
||||
function recomputeSma(points, fast, slow) {
|
||||
const sumFast = [];
|
||||
const sumSlow = [];
|
||||
for (let i = 0; i < points.length; i++) {
|
||||
sumFast[i] = (sumFast[i - 1] || 0) + Number(points[i].close || 0);
|
||||
sumSlow[i] = (sumSlow[i - 1] || 0) + Number(points[i].close || 0);
|
||||
if (i >= fast - 1) {
|
||||
const prev = i - fast >= 0 ? sumFast[i - fast] : 0;
|
||||
points[i].sma_fast = (sumFast[i] - prev) / fast;
|
||||
} else {
|
||||
points[i].sma_fast = null;
|
||||
}
|
||||
if (i >= slow - 1) {
|
||||
const prev = i - slow >= 0 ? sumSlow[i - slow] : 0;
|
||||
points[i].sma_slow = (sumSlow[i] - prev) / slow;
|
||||
} else {
|
||||
points[i].sma_slow = null;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
function upsertRealtimePoint(point) {
|
||||
if (!point || point.close == null || currentPoints.length === 0) return;
|
||||
const last = currentPoints[currentPoints.length - 1];
|
||||
if (last.date === point.date) {
|
||||
last.close = Number(point.close);
|
||||
if (point.volume != null) last.volume = Number(point.volume);
|
||||
} else if (point.date > last.date) {
|
||||
currentPoints.push({
|
||||
date: point.date,
|
||||
close: Number(point.close),
|
||||
volume: Number(point.volume || 0),
|
||||
sma_fast: null,
|
||||
sma_slow: null,
|
||||
});
|
||||
}
|
||||
recomputeSma(currentPoints, currentFast, currentSlow);
|
||||
currentSignals = buildSignals(currentPoints, currentFast, currentSlow);
|
||||
drawChart(currentPoints, currentSignals);
|
||||
}
|
||||
|
||||
function drawChart(points, signals) {
|
||||
const labels = points.map(p => p.date);
|
||||
const close = points.map(p => p.close);
|
||||
const fast = points.map(p => p.sma_fast);
|
||||
const slow = points.map(p => p.sma_slow);
|
||||
|
||||
const buyPoints = signals.filter(s => s.type === "buy").map(s => ({x: s.date, y: s.price}));
|
||||
const sellPoints = signals.filter(s => s.type === "sell").map(s => ({x: s.date, y: s.price}));
|
||||
|
||||
if (!chart) {
|
||||
chart = new Chart(document.getElementById("chart"), {
|
||||
type: "line",
|
||||
data: {
|
||||
labels,
|
||||
datasets: [
|
||||
{ label: "收盘价", data: close, borderColor: "#1f2937", tension: 0.2, pointRadius: 0 },
|
||||
{ label: "SMA Fast", data: fast, borderColor: "#2563eb", tension: 0.2, pointRadius: 0 },
|
||||
{ label: "SMA Slow", data: slow, borderColor: "#16a34a", tension: 0.2, pointRadius: 0 },
|
||||
{ label: "买点", data: buyPoints, parsing: {xAxisKey: "x", yAxisKey: "y"}, showLine:false, pointRadius:5, pointStyle:"triangle", pointBackgroundColor:"#dc2626", pointBorderColor:"#dc2626" },
|
||||
{ label: "卖点", data: sellPoints, parsing: {xAxisKey: "x", yAxisKey: "y"}, showLine:false, pointRadius:5, pointStyle:"rectRot", pointBackgroundColor:"#7c3aed", pointBorderColor:"#7c3aed" }
|
||||
]
|
||||
},
|
||||
options: {
|
||||
responsive: true,
|
||||
animation: false,
|
||||
scales: { x: { ticks: { maxTicksLimit: 10 } } }
|
||||
}
|
||||
});
|
||||
return;
|
||||
}
|
||||
|
||||
chart.data.labels = labels;
|
||||
chart.data.datasets[0].data = close;
|
||||
chart.data.datasets[1].data = fast;
|
||||
chart.data.datasets[2].data = slow;
|
||||
chart.data.datasets[3].data = buyPoints;
|
||||
chart.data.datasets[4].data = sellPoints;
|
||||
chart.update("none");
|
||||
}
|
||||
|
||||
async function searchStock() {
|
||||
const query = document.getElementById("query").value.trim();
|
||||
const fast = Number(document.getElementById("fast").value);
|
||||
const slow = Number(document.getElementById("slow").value);
|
||||
if (!query) return;
|
||||
const status = document.getElementById("status");
|
||||
status.innerText = "正在建立实时 WebSocket 连接...";
|
||||
try {
|
||||
if (realtimeWs) {
|
||||
realtimeWs.close();
|
||||
realtimeWs = null;
|
||||
}
|
||||
const protocol = window.location.protocol === "https:" ? "wss" : "ws";
|
||||
realtimeWs = new WebSocket(`${protocol}://${window.location.host}/ws/stock/realtime?query=${encodeURIComponent(query)}`);
|
||||
realtimeWs.onmessage = (event) => {
|
||||
const data = JSON.parse(event.data);
|
||||
if (data.error) {
|
||||
status.innerText = `实时通道错误: ${data.error}`;
|
||||
return;
|
||||
}
|
||||
renderInfo(data.info);
|
||||
upsertRealtimePoint(data.realtime_point);
|
||||
status.innerText = `实时更新中:${data.info.code} ${data.info.name}(${data.updated_at})`;
|
||||
};
|
||||
realtimeWs.onclose = () => {
|
||||
if (status.innerText.startsWith("实时更新中")) {
|
||||
status.innerText = "实时连接已断开";
|
||||
}
|
||||
};
|
||||
realtimeWs.onerror = () => {
|
||||
status.innerText = "实时连接失败";
|
||||
};
|
||||
|
||||
const res = await fetch(`/api/stock?query=${encodeURIComponent(query)}&sma_fast=${fast}&sma_slow=${slow}`);
|
||||
const data = await res.json();
|
||||
if (!res.ok) throw new Error(data.detail || "查询失败");
|
||||
currentFast = fast;
|
||||
currentSlow = slow;
|
||||
currentPoints = data.points.map(p => ({
|
||||
date: p.date,
|
||||
close: Number(p.close),
|
||||
sma_fast: p.sma_fast == null ? null : Number(p.sma_fast),
|
||||
sma_slow: p.sma_slow == null ? null : Number(p.sma_slow),
|
||||
volume: p.volume == null ? 0 : Number(p.volume),
|
||||
}));
|
||||
currentSignals = data.signals;
|
||||
renderInfo(data.info);
|
||||
drawChart(currentPoints, currentSignals);
|
||||
status.innerText = `历史K线加载完成:${data.info.code} ${data.info.name},买点 ${data.signals.filter(s=>s.type==="buy").length},卖点 ${data.signals.filter(s=>s.type==="sell").length}(实时通道已连接)`;
|
||||
} catch (err) {
|
||||
status.innerText = `错误: ${err.message}`;
|
||||
}
|
||||
}
|
||||
</script>
|
||||
</body>
|
||||
</html>
|
||||
82
python-app/app/web_app.py
Normal file
82
python-app/app/web_app.py
Normal file
@@ -0,0 +1,82 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from fastapi import FastAPI, HTTPException, Query, WebSocket, WebSocketDisconnect
|
||||
from fastapi.responses import FileResponse
|
||||
|
||||
from market import create_provider, load_market_config
|
||||
from market.service import build_dashboard, build_realtime_info
|
||||
|
||||
app = FastAPI(title="AITrading Simple Frontend API", version="1.0.0")
|
||||
|
||||
BASE_DIR = Path(__file__).resolve().parent
|
||||
WEB_INDEX = BASE_DIR / "web" / "index.html"
|
||||
|
||||
|
||||
@app.get("/")
|
||||
def index() -> FileResponse:
|
||||
if not WEB_INDEX.exists():
|
||||
raise HTTPException(status_code=500, detail="frontend page not found")
|
||||
return FileResponse(WEB_INDEX)
|
||||
|
||||
|
||||
@app.get("/api/stock")
|
||||
def stock_dashboard(
|
||||
query: str = Query(..., description="stock code or name"),
|
||||
days: int = Query(120, ge=30, le=365),
|
||||
sma_fast: int = Query(5, ge=2, le=60),
|
||||
sma_slow: int = Query(20, ge=3, le=120),
|
||||
) -> dict[str, Any]:
|
||||
try:
|
||||
provider = create_provider(load_market_config())
|
||||
return build_dashboard(
|
||||
provider=provider,
|
||||
query=query,
|
||||
days=days,
|
||||
sma_fast=sma_fast,
|
||||
sma_slow=sma_slow,
|
||||
)
|
||||
except RuntimeError as exc:
|
||||
raise HTTPException(status_code=400, detail=str(exc)) from exc
|
||||
|
||||
|
||||
@app.get("/api/stock/realtime")
|
||||
def stock_realtime(
|
||||
query: str = Query(..., description="stock code or name"),
|
||||
) -> dict[str, Any]:
|
||||
try:
|
||||
provider = create_provider(load_market_config())
|
||||
return build_realtime_info(provider=provider, query=query)
|
||||
except RuntimeError as exc:
|
||||
raise HTTPException(status_code=400, detail=str(exc)) from exc
|
||||
|
||||
|
||||
@app.websocket("/ws/stock/realtime")
|
||||
async def stock_realtime_ws(websocket: WebSocket) -> None:
|
||||
await websocket.accept()
|
||||
query = (websocket.query_params.get("query") or "").strip()
|
||||
if not query:
|
||||
await websocket.send_json({"error": "query is required"})
|
||||
await websocket.close(code=1008)
|
||||
return
|
||||
|
||||
try:
|
||||
while True:
|
||||
provider = create_provider(load_market_config())
|
||||
payload = build_realtime_info(provider=provider, query=query)
|
||||
await websocket.send_json(payload)
|
||||
# WebSocket channel keeps the connection alive, front-end no longer polls via HTTP.
|
||||
await asyncio.sleep(2)
|
||||
except WebSocketDisconnect:
|
||||
return
|
||||
except Exception as exc: # noqa: BLE001
|
||||
await websocket.send_json({"error": str(exc)})
|
||||
await websocket.close(code=1011)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
|
||||
uvicorn.run("web_app:app", host="0.0.0.0", port=8000, reload=False)
|
||||
10
python-app/requirements.txt
Normal file
10
python-app/requirements.txt
Normal file
@@ -0,0 +1,10 @@
|
||||
pandas
|
||||
mplfinance
|
||||
matplotlib
|
||||
backtrader
|
||||
redis
|
||||
akshare
|
||||
fastapi
|
||||
uvicorn
|
||||
futu-api
|
||||
websockets
|
||||
Reference in New Issue
Block a user