feat: 修复报错

This commit is contained in:
Daniel
2026-03-26 14:13:44 +08:00
commit b2223ec058
31 changed files with 17401 additions and 0 deletions

18
.gitignore vendored Normal file
View File

@@ -0,0 +1,18 @@
.DS_Store
.idea/
.vscode/
__pycache__/
*.pyc
*.pyo
*.pyd
.pytest_cache/
.mypy_cache/
build/
dist/
venv/
.venv/
*.log
*.csv
*.db
data/
!data/.gitkeep

152
A.md Normal file
View File

@@ -0,0 +1,152 @@
使用清华源安装(推荐)
如官方源下载较慢,推荐使用清华镜像源:
pip install cmesdata -U -i https://pypi.tuna.tsinghua.edu.cn/simple
接口介绍
展示国内股票数据接口的功能说明、调用样例和注意事项。
股票接口提供实时行情、历史K线、分笔tick等多种数据获取方式。
实时行情支持获取沪深股票、ETF、指数、可转债的五档订单薄实时行情数据。
历史数据支持获取日线、分钟线等不同频率的历史K线数据可盘中更新。
分笔数据提供股票、ETF、可转债的分笔成交tick历史行情数据盘后更新。
登录、登出接口
登录 / token
登录接口传参为个人token。
#登录接口
login("c79cesdffwerfer56422f9sdffgrtg35ec8")
#退出登录
login_out()
*注意事项:
1、token可在权限申请处复制开通权限后即可使用。
2、登录接口在程序启动后只需要调用一次建立长连接即可请不要频繁调用登录接口
五档订单薄实时行情
五档 / 实时
盘中提供沪深股票、ETF、指数的五档订单薄实时行情数据。
*入参为列表返回为DataFrame格式数据
3、接口仅盘中提供实时行情盘后返回空
df = get_real_hq(['SH.600000', 'SZ.000001']) # 注意:品种入参是列表,不要忘记[]
*注意事项:
1、单次可获取80只股票超过请分批获取
2、单次请求股票代码列表中不能有错误或者停牌的代码否则返回空数据。
代码 价格 昨收价 开盘价 最高价 最低价 时间 成交量 成交额 总卖 总买 买一价 卖一价 买一量 卖一量 买二价 卖二价 买二量 卖二量 买三价 卖三价 买三量 卖三量 买四价 卖四价 买四量 卖四量 买五价 卖五价 买五量 卖五量
0 600000 11.37 11.62 11.60 11.65 11.35 2025-11-17 21:51:54 984855 1.127109e+09 644410 340445 11.36 11.37 4767 944 11.35 11.38 9624 4537 11.34 11.39 1696 3879 11.33 11.40 3259 1090 11.32 11.41 1176 422
1 000001 11.67 11.75 11.75 11.75 11.62 2025-11-17 21:51:54 995232 1.161416e+09 531978 463255 11.67 11.68 1559 1312 11.66 11.69 6836 6015 11.65 11.70 11362 5688 11.64 11.71 5905 2062 11.63 11.72 5870 4163
可转债五档订单薄实时行情
五档 / 实时
提供可转债的五档订单薄实时行情数据。
*入参为列表返回为DataFrame格式数据
df = get_real_kzz(['SH.113634', 'SZ.127071']) # 注意:品种入参是列表,不要忘记[]
*注意事项:
1、单次可获取80只可转债超过请分批获取
2、单次请求代码列表中不能有错误或者停牌的可转债代码否则返回空数据。
3、接口仅盘中提供实时行情盘后返回空
代码 价格 昨收价 开盘价 最高价 最低价 时间 成交量 成交额 总卖 总买 买一价 卖一价 买一量 卖一量 买二价 卖二价 买二量 卖二量 买三价 卖三价 买三量 卖三量 买四价 卖四价 买四量 卖四量 买五价 卖五价 买五量 卖五量
0 113634 126.930 127.587 127.398 127.648 126.704 2025-11-17 21:54:26 21831 27720952.0 10838 10993 126.900 127.001 645 4 126.843 127.002 1 99 126.800 127.040 200 2 126.78 127.043 4 99 126.763 127.094 1 28
1 127071 135.413 136.074 136.670 139.000 135.101 2025-11-17 21:54:26 29663 40522148.0 17221 12442 135.348 135.413 2 2 135.300 135.414 1 13 135.211 135.415 1 2 135.21 135.416 3 30 135.209 135.458 3 8
历史K线数据
股票 / K线
提供沪深股票、ETF、可转债日线、分钟线等近5个月历史行情数据可盘中实时更新。
*入参为代码、起始日期、结束日期、频率返回为DataFrame格式数据
频率1min、5min、15min、30min、60min、D(日线)、W(周线)、M(月线)
df = get_history_data("SZ.000001", "2025-01-01", "2025-11-13", "1min")
注意上海和深圳股票、ETF、可转债代码前缀分别为SH和SZ。
时间 开盘价 最高价 最低价 收盘价 成交量 成交额
0 2025-07-07 09:31:00 12.60 12.65 12.60 12.65 5170900.0 65207856.0
1 2025-07-07 09:32:00 12.64 12.68 12.60 12.62 3552900.0 44927408.0
2 2025-07-07 09:33:00 12.62 12.67 12.60 12.67 2450600.0 30951484.0
3 2025-07-07 09:34:00 12.66 12.66 12.63 12.64 1479900.0 18712490.0
4 2025-07-07 09:35:00 12.64 12.64 12.61 12.61 1256700.0 15860039.0
... ... ... ... ... ... ... ...
21115 2025-11-13 14:56:00 11.70 11.71 11.69 11.70 875500.0 10241476.0
21116 2025-11-13 14:57:00 11.69 11.71 11.69 11.69 207700.0 2430150.0
21117 2025-11-13 14:58:00 11.70 11.70 11.70 11.70 1300.0 15210.0
21118 2025-11-13 14:59:00 11.70 11.70 11.70 11.70 0.0 0.0
21119 2025-11-13 15:00:00 11.70 11.70 11.70 11.70 1150500.0 13460850.0
指数历史K线数据
指数 / K线
提供指数日线、分钟线等近5个月历史行情数据可盘中实时更新。
*入参为代码、起始日期、结束日期、频率返回为DataFrame格式数据
频率1min、5min、15min、30min、60min、D(日线)、W(周线)、M(月线)
df = get_index_data("SH.000001", "2025-01-01", "2025-11-13", "1min")
注意:
上海指数00开头例如SH.000001(上证指数)
深圳指数399开头例如SZ.399001(深证指数)
时间 开盘价 最高价 最低价 收盘价 成交量 成交额
0 2025-07-07 09:31:00 3467.98 3467.98 3466.68 3466.68 212590960.0 2.125910e+10
1 2025-07-07 09:32:00 3466.52 3468.88 3466.32 3467.20 126269952.0 1.262700e+10
2 2025-07-07 09:33:00 3467.32 3469.01 3465.88 3468.71 86889384.0 8.688939e+09
3 2025-07-07 09:34:00 3469.52 3469.52 3467.25 3467.40 74650608.0 7.465062e+09
4 2025-07-07 09:35:00 3467.20 3467.20 3464.09 3464.17 70837960.0 7.083796e+09
... ... ... ... ... ... ... ...
21115 2025-11-13 14:56:00 4029.80 4030.13 4029.43 4029.93 60431156.0 6.043117e+09
21116 2025-11-13 14:57:00 4030.20 4030.40 4029.53 4029.53 67377768.0 6.737778e+09
21117 2025-11-13 14:58:00 4030.11 4030.33 4030.11 4030.28 4642201.0 4.642203e+08
21118 2025-11-13 14:59:00 4030.28 4030.28 4030.28 4030.28 0.0 0.000000e+00
21119 2025-11-13 15:00:00 4029.99 4029.99 4029.50 4029.50 102880552.0 1.028806e+10
分笔成交tick
高频 / tick
提供股票、ETF、可转债分笔成交tick近5个月历史行情数据盘后更新。
*入参为代码、日期返回为DataFrame格式数据
df = get_tick("SZ.000001", "2025-11-17")
*注意事项:返回结果中的买卖方向:0(买)1(卖)2(中性)
时间 价格 成交量 买卖方向
0 09:25 11.68 5277 2
1 09:30 11.70 1806 0
2 09:30 11.68 4927 1
3 09:30 11.68 993 1
4 09:30 11.69 265 1
... ... ... ... ...
4223 14:56 11.70 51 0
4224 14:56 11.70 11 0
4225 14:56 11.69 78 1
4226 14:57 11.70 13 2
4227 15:00 11.70 11505 2
完整代码展示
样例 / demo
使用接口前请先安装库cmesdata参考接口安装页面指导以下为样例代码展示如未获取到数据请先检查入参是否正确实时行情接口入参中不能有错误代码和停盘股票
from cmesdata import *
# 使用自己的token登录
login("c79cesab1a56422h9a920fe0aad35e38")
# 实时股票五档订单薄
df = get_real_hq(['SH.600000', 'SZ.000001']) # 注意:品种入参是列表,不要忘记[]
print(df)
# 实时可转债五档订单薄
df = get_real_kzz(['SH.113634', 'SZ.127071']) # 注意:品种入参是列表,不要忘记[]
print(df)
# 历史行情数据
df = get_history_data("SZ.000001", "2026-03-12", "2026-03-13", "1min")
print(df)
# 指数历史行情
df = get_index_data("SH.000001", "2026-03-12", "2026-03-13", "D")
print(df)
# 历史分笔tick
df = get_tick("SH.600000", "2026-03-13")
print(df)
# 程序结束退出登录
login_out()

15140
Futu-API-Doc-zh-Python.md Normal file

File diff suppressed because it is too large Load Diff

142
README.md Normal file
View File

@@ -0,0 +1,142 @@
# AITrading 基础设施
本项目提供一个可直接在 Docker 中运行的量化交易基础设施,包含:
- Go 实时行情采集服务Binance WebSocket K 线 + Redis Publisher 接口)
- Python 本地 K 线动态展示工具pandas + mplfinance
- Backtrader 回测框架(支持 Strategy Class 快速接入)
- Go 与 Python 之间统一 JSON 数据交换协议
## 目录结构
```text
.
├── docker-compose.yml
├── go-service
│ ├── Dockerfile
│ ├── go.mod
│ └── main.go
├── python-app
│ ├── Dockerfile
│ ├── requirements.txt
│ └── app
│ ├── data_protocol.py
│ ├── kline_viewer.py
│ ├── backtest_runner.py
│ ├── strategies
│ │ ├── __init__.py
│ │ └── sma_cross.py
│ └── sample_data
│ └── klines.jsonl
└── shared
└── protocol.md
```
## 快速启动Docker
默认已使用国内 Docker Hub 镜像前缀(`m.daocloud.io/docker.io/library`)加速基础镜像拉取。
如需切换其它镜像代理,可先设置:
```bash
export DOCKER_MIRROR_PREFIX=m.daocloud.io/docker.io/library
```
1. 构建并启动 Redis + Go 行情服务:
```bash
docker compose up --build redis go-service
```
2. 单独运行 Python 动态 K 线展示(实时订阅 Redis
```bash
docker compose run --rm python-app python app/kline_viewer.py \
--redis-host redis \
--redis-port 6379 \
--channel kline.stream
```
3. 运行 Backtrader 回测(可指定自定义策略类):
```bash
docker compose run --rm python-app python app/backtest_runner.py \
--input app/sample_data/klines.jsonl \
--strategy strategies.sma_cross.SmaCrossStrategy \
--strategy-param fast=5 \
--strategy-param slow=13
```
4. 股票速查入口(输入代码/名称):
```bash
./start.sh local quote 600519
./start.sh local quote 贵州茅台
```
5. 启动简单前端页面(信息 + 曲线 + 买卖点):
```bash
./start.sh local web
```
打开浏览器访问:`http://localhost:8000`
## 多通道数据接入(解耦)
当前已将上层业务与下游数据源解耦,上层统一走 `market` 服务层:
- `MARKET_CHANNEL=cn` + `MARKET_PROVIDER=akshare`(默认)
- `MARKET_CHANNEL=cn` + `MARKET_PROVIDER=cmesdata`(需 `CMES_TOKEN`
- `MARKET_CHANNEL=cn|hk|us` + `MARKET_PROVIDER=futu`(需本地运行 Futu OpenD
示例:
```bash
export MARKET_CHANNEL=cn
export MARKET_PROVIDER=akshare
# 或使用 cmesdata
# export MARKET_PROVIDER=cmesdata
# export CMES_TOKEN=你的token
# 或使用 futu需先启动 OpenD
# export MARKET_PROVIDER=futu
# export FUTU_HOST=127.0.0.1
# export FUTU_PORT=11111
# export FUTU_MARKET=cn # 可选: cn/hk/us默认跟随 MARKET_CHANNEL
# export FUTU_IS_ENCRYPT=0 # 可选: 1/0
```
说明:
- `web_app.py``stock_lookup.py` 不再直接依赖具体 SDK
- 新增市场(如 `us` / `hk`)时,只需新增 Provider 并在 `market/factory.py` 注册
- `futu` Provider 已接入:支持按代码/名称检索股票与日 K 线读取(通过 OpenD
## Go 服务说明
- 默认订阅:`btcusdt` `1m` K 线
- 会将标准 JSON K 线数据打印到 stdout
- 预留 `RedisPublisher` 接口,默认实现为 `NoopPublisher`
- 若配置 Redis将自动切换到 `RedisChannelPublisher`
可用环境变量:
- `BINANCE_SYMBOL`(默认 `btcusdt`
- `BINANCE_INTERVAL`(默认 `1m`
- `REDIS_ADDR`(例如 `redis:6379`
- `REDIS_PASSWORD`(可选)
- `REDIS_DB`(默认 `0`
- `REDIS_CHANNEL`(默认 `kline.stream`
## Python 回测快速接入策略
你可以新增一个策略类并通过完整路径接入:
- 文件示例:`python-app/app/strategies/my_strategy.py`
- 类示例:`class MyStrategy(bt.Strategy): ...`
- 启动参数:`--strategy strategies.my_strategy.MyStrategy`
- 参数注入:`--strategy-param key=value`(可重复传入)
- 回测简报自动输出:`Final Value``Max Drawdown (%)``Sharpe Ratio`
## 协议定义
统一 JSON 格式位于 `shared/protocol.md`Go 与 Python 都使用同一字段语义与时间戳单位(毫秒)。

43
docker-compose.yml Normal file
View File

@@ -0,0 +1,43 @@
name: ai-trading
services:
redis:
image: ${DOCKER_MIRROR_PREFIX:-m.daocloud.io/docker.io/library}/redis:7.2-alpine
command: ["redis-server", "--appendonly", "yes"]
go-service:
build:
context: ./go-service
dockerfile: Dockerfile
args:
DOCKER_MIRROR_PREFIX: ${DOCKER_MIRROR_PREFIX:-m.daocloud.io/docker.io/library}
environment:
BINANCE_SYMBOL: btcusdt
BINANCE_INTERVAL: 1m
REDIS_ADDR: redis:6379
REDIS_CHANNEL: kline.stream
REDIS_DB: 0
depends_on:
- redis
python-app:
build:
context: ./python-app
dockerfile: Dockerfile
args:
DOCKER_MIRROR_PREFIX: ${DOCKER_MIRROR_PREFIX:-m.daocloud.io/docker.io/library}
depends_on:
- redis
command: ["sleep", "infinity"]
web:
build:
context: ./python-app
dockerfile: Dockerfile
args:
DOCKER_MIRROR_PREFIX: ${DOCKER_MIRROR_PREFIX:-m.daocloud.io/docker.io/library}
depends_on:
- redis
command: ["python", "app/web_app.py"]
ports:
- "8000:8000"

20
go-service/Dockerfile Normal file
View File

@@ -0,0 +1,20 @@
ARG DOCKER_MIRROR_PREFIX=m.daocloud.io/docker.io/library
FROM ${DOCKER_MIRROR_PREFIX}/golang:1.22-bookworm
ENV GOPROXY=https://goproxy.cn,direct \
GOSUMDB=sum.golang.google.cn \
CGO_ENABLED=0
RUN sed -i 's|deb.debian.org|mirrors.aliyun.com|g' /etc/apt/sources.list.d/debian.sources \
&& sed -i 's|security.debian.org|mirrors.aliyun.com|g' /etc/apt/sources.list.d/debian.sources
WORKDIR /app
COPY go.mod ./
RUN go mod tidy || true
COPY . .
RUN go mod tidy
RUN go build -o /bin/go-service main.go
CMD ["/bin/go-service"]

3
go-service/go.mod Normal file
View File

@@ -0,0 +1,3 @@
module ai-trading/go-service
go 1.22

171
go-service/main.go Normal file
View File

@@ -0,0 +1,171 @@
package main
import (
"context"
"encoding/json"
"fmt"
"log"
"os"
"os/signal"
"strconv"
"strings"
"syscall"
"time"
binance "github.com/adshao/go-binance/v2"
"github.com/redis/go-redis/v9"
)
type KlineMessage struct {
Type string `json:"type"`
Source string `json:"source"`
Symbol string `json:"symbol"`
Interval string `json:"interval"`
EventTime int64 `json:"event_time"`
OpenTime int64 `json:"open_time"`
CloseTime int64 `json:"close_time"`
Open string `json:"open"`
High string `json:"high"`
Low string `json:"low"`
Close string `json:"close"`
Volume string `json:"volume"`
TradeNum int64 `json:"trade_num"`
Final bool `json:"final"`
}
type RedisPublisher interface {
Publish(ctx context.Context, channel string, payload []byte) error
Close() error
}
type NoopPublisher struct{}
func (p *NoopPublisher) Publish(_ context.Context, _ string, _ []byte) error {
return nil
}
func (p *NoopPublisher) Close() error {
return nil
}
type RedisChannelPublisher struct {
client *redis.Client
}
func (p *RedisChannelPublisher) Publish(ctx context.Context, channel string, payload []byte) error {
return p.client.Publish(ctx, channel, payload).Err()
}
func (p *RedisChannelPublisher) Close() error {
return p.client.Close()
}
func main() {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
symbol := getEnv("BINANCE_SYMBOL", "btcusdt")
interval := getEnv("BINANCE_INTERVAL", "1m")
channel := getEnv("REDIS_CHANNEL", "kline.stream")
publisher, err := initPublisherFromEnv(ctx)
if err != nil {
log.Fatalf("init publisher failed: %v", err)
}
defer func() {
if err := publisher.Close(); err != nil {
log.Printf("publisher close failed: %v", err)
}
}()
doneC, stopC, err := binance.WsKlineServe(symbol, interval,
func(event *binance.WsKlineEvent) {
msg := KlineMessage{
Type: "kline",
Source: "binance_ws",
Symbol: strings.ToLower(event.Symbol),
Interval: event.Kline.Interval,
EventTime: event.Time,
OpenTime: event.Kline.StartTime,
CloseTime: event.Kline.EndTime,
Open: event.Kline.Open,
High: event.Kline.High,
Low: event.Kline.Low,
Close: event.Kline.Close,
Volume: event.Kline.Volume,
TradeNum: int64(event.Kline.TradeNum),
Final: event.Kline.IsFinal,
}
payload, err := json.Marshal(msg)
if err != nil {
log.Printf("marshal failed: %v", err)
return
}
fmt.Println(string(payload))
if err := publisher.Publish(ctx, channel, payload); err != nil {
log.Printf("redis publish failed: %v", err)
}
},
func(err error) {
log.Printf("ws error: %v", err)
},
)
if err != nil {
log.Fatalf("start ws stream failed: %v", err)
}
log.Printf("subscribed binance kline stream: symbol=%s interval=%s channel=%s", symbol, interval, channel)
sig := make(chan os.Signal, 1)
signal.Notify(sig, syscall.SIGINT, syscall.SIGTERM)
select {
case <-sig:
log.Println("received shutdown signal")
close(stopC)
cancel()
case <-doneC:
log.Println("ws stream closed")
cancel()
}
}
func initPublisherFromEnv(ctx context.Context) (RedisPublisher, error) {
addr := strings.TrimSpace(os.Getenv("REDIS_ADDR"))
if addr == "" {
log.Println("REDIS_ADDR not configured, using NoopPublisher")
return &NoopPublisher{}, nil
}
db, err := strconv.Atoi(getEnv("REDIS_DB", "0"))
if err != nil {
return nil, fmt.Errorf("invalid REDIS_DB: %w", err)
}
client := redis.NewClient(&redis.Options{
Addr: addr,
Password: os.Getenv("REDIS_PASSWORD"),
DB: db,
DialTimeout: 5 * time.Second,
ReadTimeout: 3 * time.Second,
WriteTimeout: 3 * time.Second,
})
if err := client.Ping(ctx).Err(); err != nil {
return nil, fmt.Errorf("redis ping failed: %w", err)
}
log.Printf("connected redis at %s db=%d", addr, db)
return &RedisChannelPublisher{client: client}, nil
}
func getEnv(key, fallback string) string {
v := strings.TrimSpace(os.Getenv(key))
if v == "" {
return fallback
}
return v
}

24
python-app/Dockerfile Normal file
View File

@@ -0,0 +1,24 @@
ARG DOCKER_MIRROR_PREFIX=m.daocloud.io/docker.io/library
FROM ${DOCKER_MIRROR_PREFIX}/python:3.11-slim-bookworm
ENV PIP_INDEX_URL=https://mirrors.aliyun.com/pypi/simple/ \
PIP_TRUSTED_HOST=mirrors.aliyun.com \
PYTHONDONTWRITEBYTECODE=1 \
PYTHONUNBUFFERED=1
RUN sed -i 's|deb.debian.org|mirrors.aliyun.com|g' /etc/apt/sources.list.d/debian.sources \
&& sed -i 's|security.debian.org|mirrors.aliyun.com|g' /etc/apt/sources.list.d/debian.sources \
&& apt-get update \
&& apt-get install -y --no-install-recommends \
build-essential \
tzdata \
&& rm -rf /var/lib/apt/lists/*
WORKDIR /app
COPY requirements.txt ./
RUN pip install --no-cache-dir -r requirements.txt
COPY app ./app
CMD ["python", "app/backtest_runner.py", "--help"]

View File

@@ -0,0 +1,142 @@
from __future__ import annotations
import argparse
import ast
import importlib
import json
from pathlib import Path
from typing import Any, Type
import backtrader as bt
import pandas as pd
from data_protocol import KlineMessage
def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(description="Backtrader runner")
parser.add_argument("--input", required=True, help="Path to kline jsonl file")
parser.add_argument(
"--strategy",
default="strategies.sma_cross.SmaCrossStrategy",
help="Strategy class path, e.g. strategies.sma_cross.SmaCrossStrategy",
)
parser.add_argument("--cash", type=float, default=100000.0, help="Initial cash")
parser.add_argument("--commission", type=float, default=0.001, help="Commission")
parser.add_argument(
"--strategy-param",
action="append",
default=[],
help="Strategy parameter in key=value format, repeatable",
)
parser.add_argument("--plot", action="store_true", help="Plot result")
return parser.parse_args()
def load_strategy_class(class_path: str) -> Type[bt.Strategy]:
module_name, class_name = class_path.rsplit(".", 1)
module = importlib.import_module(module_name)
strategy_cls = getattr(module, class_name)
if not issubclass(strategy_cls, bt.Strategy):
raise TypeError(f"{class_path} is not a backtrader.Strategy subclass")
return strategy_cls
def load_dataframe(path: Path) -> pd.DataFrame:
rows = []
for line in path.read_text(encoding="utf-8").splitlines():
line = line.strip()
if not line:
continue
payload = json.loads(line)
msg = KlineMessage.from_dict(payload)
rows.append(
{
"datetime": pd.to_datetime(msg.open_time, unit="ms"),
"open": float(msg.open),
"high": float(msg.high),
"low": float(msg.low),
"close": float(msg.close),
"volume": float(msg.volume),
"openinterest": 0.0,
}
)
if not rows:
raise RuntimeError("no kline rows found for backtest")
frame = pd.DataFrame(rows).set_index("datetime")
return frame.sort_index()
def parse_strategy_params(raw_params: list[str]) -> dict[str, Any]:
parsed: dict[str, Any] = {}
for item in raw_params:
if "=" not in item:
raise ValueError(f"invalid --strategy-param '{item}', expected key=value")
key, raw_value = item.split("=", 1)
key = key.strip()
raw_value = raw_value.strip()
if not key:
raise ValueError(f"invalid --strategy-param '{item}', empty key")
try:
value = ast.literal_eval(raw_value)
except (ValueError, SyntaxError):
value = raw_value
parsed[key] = value
return parsed
def extract_report(result: list[bt.Strategy], final_value: float) -> dict[str, Any]:
strategy = result[0]
drawdown_info = strategy.analyzers.drawdown.get_analysis()
sharpe_info = strategy.analyzers.sharpe.get_analysis()
max_drawdown = drawdown_info.get("max", {}).get("drawdown")
sharpe_ratio = sharpe_info.get("sharperatio")
return {
"final_value": final_value,
"max_drawdown_pct": None if max_drawdown is None else float(max_drawdown),
"sharpe_ratio": None if sharpe_ratio is None else float(sharpe_ratio),
}
def print_report(report: dict[str, Any]) -> None:
print("\n===== Backtest Report =====")
print(f"Final Value : {report['final_value']:.2f}")
max_drawdown = report["max_drawdown_pct"]
sharpe_ratio = report["sharpe_ratio"]
print(f"Max Drawdown (%) : {'N/A' if max_drawdown is None else f'{max_drawdown:.4f}'}")
print(f"Sharpe Ratio : {'N/A' if sharpe_ratio is None else f'{sharpe_ratio:.4f}'}")
print("===========================\n")
def run_backtest() -> None:
args = parse_args()
data_path = Path(args.input)
if not data_path.exists():
raise FileNotFoundError(f"input file not found: {data_path}")
strategy_cls = load_strategy_class(args.strategy)
strategy_params = parse_strategy_params(args.strategy_param)
df = load_dataframe(data_path)
cerebro = bt.Cerebro()
cerebro.addstrategy(strategy_cls, **strategy_params)
cerebro.broker.setcash(args.cash)
cerebro.broker.setcommission(commission=args.commission)
cerebro.adddata(bt.feeds.PandasData(dataname=df))
cerebro.addanalyzer(bt.analyzers.DrawDown, _name="drawdown")
cerebro.addanalyzer(bt.analyzers.SharpeRatio_A, _name="sharpe", riskfreerate=0.0)
print(f"starting portfolio value: {cerebro.broker.getvalue():.2f}")
result = cerebro.run()
final_value = cerebro.broker.getvalue()
print(f"final portfolio value: {final_value:.2f}")
print(f"strategies executed: {len(result)}")
print_report(extract_report(result, final_value))
if args.plot:
cerebro.plot(style="candlestick")
if __name__ == "__main__":
run_backtest()

View File

@@ -0,0 +1,59 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import Any, Dict
@dataclass
class KlineMessage:
type: str
source: str
symbol: str
interval: str
event_time: int
open_time: int
close_time: int
open: str
high: str
low: str
close: str
volume: str
trade_num: int
final: bool
@classmethod
def from_dict(cls, payload: Dict[str, Any]) -> "KlineMessage":
return cls(
type=str(payload["type"]),
source=str(payload["source"]),
symbol=str(payload["symbol"]),
interval=str(payload["interval"]),
event_time=int(payload["event_time"]),
open_time=int(payload["open_time"]),
close_time=int(payload["close_time"]),
open=str(payload["open"]),
high=str(payload["high"]),
low=str(payload["low"]),
close=str(payload["close"]),
volume=str(payload["volume"]),
trade_num=int(payload["trade_num"]),
final=bool(payload["final"]),
)
def to_dict(self) -> Dict[str, Any]:
return {
"type": self.type,
"source": self.source,
"symbol": self.symbol,
"interval": self.interval,
"event_time": self.event_time,
"open_time": self.open_time,
"close_time": self.close_time,
"open": self.open,
"high": self.high,
"low": self.low,
"close": self.close,
"volume": self.volume,
"trade_num": self.trade_num,
"final": self.final,
}

View File

@@ -0,0 +1,113 @@
from __future__ import annotations
import argparse
import json
from collections import OrderedDict
import matplotlib.pyplot as plt
import mplfinance as mpf
import pandas as pd
import redis
from data_protocol import KlineMessage
def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(description="Realtime K-line viewer from Redis")
parser.add_argument("--redis-host", default="127.0.0.1", help="Redis host")
parser.add_argument("--redis-port", type=int, default=6379, help="Redis port")
parser.add_argument("--redis-db", type=int, default=0, help="Redis DB")
parser.add_argument("--redis-password", default="", help="Redis password")
parser.add_argument("--channel", default="kline.stream", help="Pub/Sub channel")
parser.add_argument("--window", type=int, default=80, help="Window size to render")
parser.add_argument("--only-final", action="store_true", help="Render only closed candles")
parser.add_argument("--timeout", type=float, default=1.0, help="Subscriber timeout (seconds)")
return parser.parse_args()
def upsert_message(buffer: OrderedDict[int, KlineMessage], message: KlineMessage, max_size: int) -> None:
buffer[message.open_time] = message
buffer.move_to_end(message.open_time)
while len(buffer) > max_size:
buffer.popitem(last=False)
def to_dataframe(messages: list[KlineMessage]) -> pd.DataFrame:
rows = []
for msg in messages:
rows.append(
{
"Date": pd.to_datetime(msg.open_time, unit="ms"),
"Open": float(msg.open),
"High": float(msg.high),
"Low": float(msg.low),
"Close": float(msg.close),
"Volume": float(msg.volume),
}
)
frame = pd.DataFrame(rows).set_index("Date")
return frame
def main() -> None:
args = parse_args()
plt.ion()
client = redis.Redis(
host=args.redis_host,
port=args.redis_port,
db=args.redis_db,
password=args.redis_password or None,
decode_responses=True,
)
pubsub = client.pubsub(ignore_subscribe_messages=True)
pubsub.subscribe(args.channel)
print(
f"subscribed redis channel={args.channel} "
f"at {args.redis_host}:{args.redis_port}/{args.redis_db}"
)
candles: OrderedDict[int, KlineMessage] = OrderedDict()
try:
while True:
packet = pubsub.get_message(timeout=args.timeout)
if not packet:
plt.pause(0.01)
continue
raw = packet.get("data")
if not isinstance(raw, str):
continue
payload = json.loads(raw)
message = KlineMessage.from_dict(payload)
if args.only_final and not message.final:
continue
upsert_message(candles, message, args.window)
frame = to_dataframe(list(candles.values()))
if frame.empty:
continue
plt.clf()
mpf.plot(
frame,
type="candle",
style="yahoo",
volume=True,
datetime_format="%H:%M",
tight_layout=True,
)
plt.pause(0.001)
except KeyboardInterrupt:
print("viewer stopped by user")
finally:
pubsub.close()
client.close()
plt.ioff()
plt.show()
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,4 @@
from market.config import load_market_config
from market.factory import create_provider
__all__ = ["load_market_config", "create_provider"]

View File

@@ -0,0 +1,40 @@
from __future__ import annotations
import os
from dataclasses import dataclass
@dataclass
class MarketConfig:
channel: str
provider: str
cmes_token: str
futu_host: str
futu_port: int
futu_is_encrypt: bool | None
futu_market: str
def load_market_config() -> MarketConfig:
channel = os.getenv("MARKET_CHANNEL", "cn").strip().lower()
provider = os.getenv("MARKET_PROVIDER", "akshare").strip().lower()
cmes_token = os.getenv("CMES_TOKEN", "").strip()
futu_host = os.getenv("FUTU_HOST", "127.0.0.1").strip()
futu_port = int(os.getenv("FUTU_PORT", "11111").strip())
futu_encrypt_raw = os.getenv("FUTU_IS_ENCRYPT", "").strip().lower()
if futu_encrypt_raw in {"1", "true", "yes", "y"}:
futu_is_encrypt = True
elif futu_encrypt_raw in {"0", "false", "no", "n"}:
futu_is_encrypt = False
else:
futu_is_encrypt = None
futu_market = os.getenv("FUTU_MARKET", "").strip().lower()
return MarketConfig(
channel=channel,
provider=provider,
cmes_token=cmes_token,
futu_host=futu_host,
futu_port=futu_port,
futu_is_encrypt=futu_is_encrypt,
futu_market=futu_market,
)

View File

@@ -0,0 +1,29 @@
from __future__ import annotations
from market.config import MarketConfig
from market.provider_base import MarketDataProvider
from market.providers.akshare_provider import AkshareCnProvider
from market.providers.cmes_provider import CmesCnProvider
def create_provider(config: MarketConfig) -> MarketDataProvider:
if config.channel == "cn" and config.provider == "akshare":
return AkshareCnProvider()
if config.channel == "cn" and config.provider == "cmesdata":
return CmesCnProvider(token=config.cmes_token)
if config.provider == "futu" and config.channel in {"cn", "hk", "us"}:
from market.providers.futu_provider import FutuProvider
return FutuProvider(
channel=config.channel,
host=config.futu_host,
port=config.futu_port,
is_encrypt=config.futu_is_encrypt,
market=config.futu_market,
)
if config.channel in {"us", "hk"}:
raise RuntimeError(
f"channel={config.channel} provider={config.provider} 尚未实现,"
"请新增 Provider 后接入 factory。"
)
raise RuntimeError(f"不支持的通道配置: channel={config.channel}, provider={config.provider}")

View File

@@ -0,0 +1,30 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from datetime import datetime
import pandas as pd
class MarketDataProvider(ABC):
@property
@abstractmethod
def provider_name(self) -> str:
raise NotImplementedError
@property
@abstractmethod
def channel(self) -> str:
raise NotImplementedError
@abstractmethod
def fetch_spot(self) -> pd.DataFrame:
raise NotImplementedError
@abstractmethod
def search_spot(self, query: str, limit: int) -> pd.DataFrame:
raise NotImplementedError
@abstractmethod
def fetch_daily_kline(self, code: str, start: datetime, end: datetime) -> pd.DataFrame:
raise NotImplementedError

View File

@@ -0,0 +1,12 @@
from market.providers.akshare_provider import AkshareCnProvider
from market.providers.cmes_provider import CmesCnProvider
__all__ = ["AkshareCnProvider", "CmesCnProvider"]
try:
from market.providers.futu_provider import FutuProvider
__all__.append("FutuProvider")
except Exception:
# Allow non-futu environments to keep using other providers.
pass

View File

@@ -0,0 +1,63 @@
from __future__ import annotations
from datetime import datetime
import akshare as ak
import pandas as pd
from market.provider_base import MarketDataProvider
class AkshareCnProvider(MarketDataProvider):
@property
def provider_name(self) -> str:
return "akshare"
@property
def channel(self) -> str:
return "cn"
def fetch_spot(self) -> pd.DataFrame:
df = ak.stock_zh_a_spot_em()
if df.empty:
raise RuntimeError("未获取到 A 股实时行情数据")
return df
def search_spot(self, query: str, limit: int) -> pd.DataFrame:
q = query.strip().lower()
if not q:
return pd.DataFrame()
df = self.fetch_spot()
code = df["代码"].astype(str).str.lower()
name = df["名称"].astype(str).str.lower()
exact = df[(code == q) | (name == q)]
if not exact.empty:
return exact.head(limit)
starts = df[code.str.startswith(q) | name.str.startswith(q)]
if not starts.empty:
return starts.head(limit)
return df[code.str.contains(q, na=False) | name.str.contains(q, na=False)].head(limit)
def fetch_daily_kline(self, code: str, start: datetime, end: datetime) -> pd.DataFrame:
hist = ak.stock_zh_a_hist(
symbol=code,
period="daily",
start_date=start.strftime("%Y%m%d"),
end_date=end.strftime("%Y%m%d"),
adjust="qfq",
)
if hist.empty:
raise RuntimeError(f"未获取到 K 线数据: {code}")
frame = pd.DataFrame(
{
"date": pd.to_datetime(hist["日期"]),
"close": pd.to_numeric(hist["收盘"], errors="coerce"),
"volume": pd.to_numeric(hist["成交量"], errors="coerce"),
}
).dropna()
return frame.sort_values("date").reset_index(drop=True)

View File

@@ -0,0 +1,70 @@
from __future__ import annotations
import importlib
import os
from datetime import datetime
import pandas as pd
from market.provider_base import MarketDataProvider
class CmesCnProvider(MarketDataProvider):
def __init__(self, token: str | None = None) -> None:
self._token = token or os.getenv("CMES_TOKEN", "").strip()
self._module = importlib.import_module("cmesdata")
self._login_once()
@property
def provider_name(self) -> str:
return "cmesdata"
@property
def channel(self) -> str:
return "cn"
def _login_once(self) -> None:
if not self._token:
raise RuntimeError("CMES_TOKEN 未配置,无法使用 cmesdata 通道")
self._module.login(self._token)
@staticmethod
def _to_prefixed_code(code: str) -> str:
raw = code.strip().upper().replace(".", "")
if raw.startswith("SH") or raw.startswith("SZ"):
return f"{raw[:2]}.{raw[2:]}"
if raw.isdigit() and len(raw) == 6:
if raw.startswith(("6", "9")):
return f"SH.{raw}"
return f"SZ.{raw}"
raise RuntimeError("cmesdata 通道仅支持 6 位 A 股代码或 SH./SZ. 前缀代码")
def fetch_spot(self) -> pd.DataFrame:
raise RuntimeError("cmesdata 不支持全市场快照拉取,请通过精确代码查询")
def search_spot(self, query: str, limit: int) -> pd.DataFrame:
_ = limit
code = self._to_prefixed_code(query)
df = self._module.get_real_hq([code])
if df is None or df.empty:
raise RuntimeError(f"未获取到实时行情: {code}")
return df
def fetch_daily_kline(self, code: str, start: datetime, end: datetime) -> pd.DataFrame:
prefixed = self._to_prefixed_code(code)
df = self._module.get_history_data(
prefixed,
start.strftime("%Y-%m-%d"),
end.strftime("%Y-%m-%d"),
"D",
)
if df is None or df.empty:
raise RuntimeError(f"未获取到历史 K 线: {prefixed}")
frame = pd.DataFrame(
{
"date": pd.to_datetime(df["时间"]),
"close": pd.to_numeric(df["收盘价"], errors="coerce"),
"volume": pd.to_numeric(df["成交量"], errors="coerce"),
}
).dropna()
return frame.sort_values("date").reset_index(drop=True)

View File

@@ -0,0 +1,207 @@
from __future__ import annotations
from datetime import datetime
import pandas as pd
from futu import AuType, KLType, Market, OpenQuoteContext, RET_OK, SecurityType
from market.provider_base import MarketDataProvider
class FutuProvider(MarketDataProvider):
def __init__(
self,
channel: str,
host: str = "127.0.0.1",
port: int = 11111,
is_encrypt: bool | None = None,
market: str = "",
) -> None:
self._channel = channel.strip().lower()
self._host = host
self._port = int(port)
self._is_encrypt = is_encrypt
self._market = (market or self._channel).strip().lower()
self._ctx = OpenQuoteContext(host=self._host, port=self._port, is_encrypt=self._is_encrypt)
self._basicinfo_cache: pd.DataFrame | None = None
@property
def provider_name(self) -> str:
return "futu"
@property
def channel(self) -> str:
return self._channel
def __del__(self) -> None:
try:
self._ctx.close()
except Exception:
pass
def _require_market(self) -> Market:
if self._market == "cn":
return Market.SH
if self._market == "hk":
return Market.HK
if self._market == "us":
return Market.US
raise RuntimeError(f"不支持的 FUTU_MARKET: {self._market},可选: cn/hk/us")
def _normalize_code(self, query: str) -> str:
q = query.strip().upper()
if not q:
return ""
if "." in q:
return q
if self._market == "cn":
if q.isdigit() and len(q) == 6:
prefix = "SH" if q.startswith(("5", "6", "9")) else "SZ"
return f"{prefix}.{q}"
return q
if self._market == "hk":
if q.isdigit():
return f"HK.{q.zfill(5)}"
return f"HK.{q}"
if self._market == "us":
return f"US.{q}"
return q
def _snapshot_to_unified(self, df: pd.DataFrame) -> pd.DataFrame:
frame = df.copy()
if "code" not in frame.columns:
raise RuntimeError("Futu 快照返回缺少 code 字段")
if "name" not in frame.columns:
frame["name"] = ""
frame["涨跌额"] = pd.to_numeric(frame.get("last_price"), errors="coerce") - pd.to_numeric(
frame.get("prev_close_price"), errors="coerce"
)
frame["市盈率-动态"] = frame.get("pe_ttm_ratio", frame.get("pe_ratio"))
frame["市净率"] = frame.get("pb_ratio")
frame["总市值"] = frame.get("total_market_val")
frame["流通市值"] = frame.get("circular_market_val")
frame["代码"] = frame["code"].astype(str)
frame["名称"] = frame["name"]
frame["最新价"] = frame.get("last_price")
frame["涨跌幅"] = frame.get("change_rate")
frame["成交量"] = frame.get("volume")
frame["成交额"] = frame.get("turnover")
frame["振幅"] = frame.get("amplitude")
frame["最高"] = frame.get("high_price")
frame["最低"] = frame.get("low_price")
frame["今开"] = frame.get("open_price")
frame["昨收"] = frame.get("prev_close_price")
columns = [
"代码",
"名称",
"最新价",
"涨跌幅",
"涨跌额",
"成交量",
"成交额",
"振幅",
"最高",
"最低",
"今开",
"昨收",
"市盈率-动态",
"市净率",
"总市值",
"流通市值",
]
return frame[columns]
def _get_snapshot(self, codes: list[str]) -> pd.DataFrame:
ret, data = self._ctx.get_market_snapshot(codes)
if ret != RET_OK:
raise RuntimeError(f"Futu 获取快照失败: {data}")
if data is None or data.empty:
return pd.DataFrame()
return self._snapshot_to_unified(data)
def _load_basicinfo(self) -> pd.DataFrame:
if self._basicinfo_cache is not None:
return self._basicinfo_cache
market = self._require_market()
ret, data = self._ctx.get_stock_basicinfo(market=market, stock_type=SecurityType.STOCK)
if ret != RET_OK:
raise RuntimeError(f"Futu 获取股票基础信息失败: {data}")
if data is None:
self._basicinfo_cache = pd.DataFrame(columns=["code", "name"])
return self._basicinfo_cache
frame = data.copy()
frame["code"] = frame["code"].astype(str)
frame["name"] = frame["name"].astype(str)
self._basicinfo_cache = frame
return frame
def fetch_spot(self) -> pd.DataFrame:
raise RuntimeError("Futu 不支持直接拉取全市场实时快照,请使用 search_spot")
def search_spot(self, query: str, limit: int) -> pd.DataFrame:
q = query.strip()
if not q:
return pd.DataFrame()
code = self._normalize_code(q)
if code:
exact = self._get_snapshot([code])
if not exact.empty:
return exact.head(limit)
basic = self._load_basicinfo()
q_lower = q.lower()
code_col = basic["code"].astype(str)
name_col = basic["name"].astype(str)
mask = (
code_col.str.lower().eq(q_lower)
| name_col.str.lower().eq(q_lower)
| code_col.str.lower().str.startswith(q_lower)
| name_col.str.lower().str.startswith(q_lower)
| code_col.str.lower().str.contains(q_lower, na=False)
| name_col.str.lower().str.contains(q_lower, na=False)
)
candidates = basic.loc[mask, "code"].drop_duplicates().head(max(limit * 2, 20)).tolist()
if not candidates:
return pd.DataFrame()
snap = self._get_snapshot(candidates)
if snap.empty:
return pd.DataFrame()
return snap.head(limit)
def fetch_daily_kline(self, code: str, start: datetime, end: datetime) -> pd.DataFrame:
normalized = self._normalize_code(code)
page_key = None
frames: list[pd.DataFrame] = []
while True:
ret, data, page_key = self._ctx.request_history_kline(
normalized,
start=start.strftime("%Y-%m-%d"),
end=end.strftime("%Y-%m-%d"),
ktype=KLType.K_DAY,
autype=AuType.QFQ,
max_count=1000,
page_req_key=page_key,
)
if ret != RET_OK:
raise RuntimeError(f"Futu 获取历史 K 线失败: {data}")
if data is not None and not data.empty:
frames.append(data.copy())
if page_key is None:
break
if not frames:
raise RuntimeError(f"未获取到 K 线数据: {normalized}")
full = pd.concat(frames, ignore_index=True)
frame = pd.DataFrame(
{
"date": pd.to_datetime(full["time_key"]),
"close": pd.to_numeric(full["close"], errors="coerce"),
"volume": pd.to_numeric(full["volume"], errors="coerce"),
}
).dropna()
if frame.empty:
raise RuntimeError(f"K 线数据为空: {normalized}")
return frame.sort_values("date").reset_index(drop=True)

View File

@@ -0,0 +1,164 @@
from __future__ import annotations
from datetime import datetime, timedelta
from typing import Any
from time import time
import pandas as pd
from market.provider_base import MarketDataProvider
DISPLAY_FIELDS = [
"代码",
"名称",
"最新价",
"涨跌幅",
"涨跌额",
"成交量",
"成交额",
"振幅",
"最高",
"最低",
"今开",
"昨收",
"市盈率-动态",
"市净率",
"总市值",
"流通市值",
]
def _safe(v: Any) -> Any:
if v is None:
return None
if isinstance(v, float) and pd.isna(v):
return None
return v
def build_signals(hist: pd.DataFrame, sma_fast: int, sma_slow: int) -> list[dict[str, Any]]:
frame = hist.copy()
frame["sma_fast"] = frame["close"].rolling(sma_fast).mean()
frame["sma_slow"] = frame["close"].rolling(sma_slow).mean()
signals: list[dict[str, Any]] = []
for i in range(1, len(frame)):
prev = frame.iloc[i - 1]
curr = frame.iloc[i]
if pd.isna(prev["sma_fast"]) or pd.isna(prev["sma_slow"]) or pd.isna(curr["sma_fast"]) or pd.isna(curr["sma_slow"]):
continue
cross_up = prev["sma_fast"] <= prev["sma_slow"] and curr["sma_fast"] > curr["sma_slow"]
cross_down = prev["sma_fast"] >= prev["sma_slow"] and curr["sma_fast"] < curr["sma_slow"]
if cross_up:
signals.append({"date": curr["date"].strftime("%Y-%m-%d"), "type": "buy", "price": float(curr["close"])})
elif cross_down:
signals.append({"date": curr["date"].strftime("%Y-%m-%d"), "type": "sell", "price": float(curr["close"])})
return signals
def resolve_candidates(provider: MarketDataProvider, query: str, limit: int) -> pd.DataFrame:
return provider.search_spot(query=query, limit=limit)
def _pick_first_candidate(provider: MarketDataProvider, query: str) -> pd.Series:
candidates = resolve_candidates(provider, query=query, limit=1)
if candidates.empty:
raise RuntimeError(f"未找到匹配股票: {query}")
row = candidates.iloc[0]
code = str(row.get("代码", "")).strip()
if not code:
raise RuntimeError("行情返回缺少代码字段")
return row
def _build_info(provider: MarketDataProvider, row: pd.Series) -> dict[str, Any]:
code = str(row.get("代码", "")).strip()
return {
"code": code,
"name": str(row.get("名称", "")),
"price": _safe(row.get("最新价")),
"change_pct": _safe(row.get("涨跌幅")),
"change": _safe(row.get("涨跌额")),
"volume": _safe(row.get("成交量")),
"turnover": _safe(row.get("成交额")),
"amplitude": _safe(row.get("振幅")),
"high": _safe(row.get("最高")),
"low": _safe(row.get("最低")),
"open": _safe(row.get("今开")),
"prev_close": _safe(row.get("昨收")),
"pe": _safe(row.get("市盈率-动态")),
"pb": _safe(row.get("市净率")),
"market_cap": _safe(row.get("总市值")),
"float_market_cap": _safe(row.get("流通市值")),
"provider": provider.provider_name,
"channel": provider.channel,
}
def build_realtime_info(provider: MarketDataProvider, query: str) -> dict[str, Any]:
row = _pick_first_candidate(provider, query=query)
info = _build_info(provider=provider, row=row)
now = datetime.now()
price = info.get("price")
volume = info.get("volume")
realtime_point = {
"date": now.strftime("%Y-%m-%d"),
"close": float(price) if price is not None else None,
"volume": float(volume) if volume is not None else 0.0,
}
return {
"info": info,
"realtime_point": realtime_point,
"source": "realtime",
"updated_at": now.strftime("%Y-%m-%d %H:%M:%S"),
}
_DASHBOARD_CACHE: dict[str, tuple[float, dict[str, Any]]] = {}
_DASHBOARD_CACHE_TTL_SECONDS = 15.0
def build_dashboard(
provider: MarketDataProvider,
query: str,
days: int,
sma_fast: int,
sma_slow: int,
) -> dict[str, Any]:
if sma_fast >= sma_slow:
raise RuntimeError("sma_fast must be less than sma_slow")
row = _pick_first_candidate(provider, query=query)
code = str(row.get("代码", "")).strip()
cache_key = "|".join([provider.provider_name, provider.channel, code, str(days), str(sma_fast), str(sma_slow)])
now = time()
cached = _DASHBOARD_CACHE.get(cache_key)
if cached is not None and now - cached[0] <= _DASHBOARD_CACHE_TTL_SECONDS:
return cached[1]
end = datetime.now()
lookback_days = max(days + (sma_slow * 3), days + 30)
start = end - timedelta(days=lookback_days)
hist = provider.fetch_daily_kline(code=code, start=start, end=end).tail(days).reset_index(drop=True)
if hist.empty:
raise RuntimeError(f"未获取到 K 线数据: {code}")
signals = build_signals(hist, sma_fast=sma_fast, sma_slow=sma_slow)
hist["sma_fast"] = hist["close"].rolling(sma_fast).mean()
hist["sma_slow"] = hist["close"].rolling(sma_slow).mean()
points = [
{
"date": d.strftime("%Y-%m-%d"),
"close": float(c),
"sma_fast": _safe(f),
"sma_slow": _safe(s),
"volume": float(v),
}
for d, c, f, s, v in zip(hist["date"], hist["close"], hist["sma_fast"], hist["sma_slow"], hist["volume"])
]
result = {"info": _build_info(provider=provider, row=row), "points": points, "signals": signals}
_DASHBOARD_CACHE[cache_key] = (now, result)
return result

View File

@@ -0,0 +1,32 @@
from __future__ import annotations
from dataclasses import dataclass
from datetime import datetime
from typing import Any
@dataclass
class SpotRow:
code: str
name: str
price: Any
change_pct: Any
change: Any
volume: Any
turnover: Any
amplitude: Any
high: Any
low: Any
open: Any
prev_close: Any
pe: Any
pb: Any
market_cap: Any
float_market_cap: Any
@dataclass
class KlineRow:
date: datetime
close: float
volume: float

View File

@@ -0,0 +1,20 @@
{"type":"kline","source":"binance_ws","symbol":"btcusdt","interval":"1m","event_time":1711430400000,"open_time":1711430400000,"close_time":1711430459999,"open":"68000.00","high":"68025.00","low":"67990.00","close":"68010.00","volume":"12.40","trade_num":102,"final":true}
{"type":"kline","source":"binance_ws","symbol":"btcusdt","interval":"1m","event_time":1711430460000,"open_time":1711430460000,"close_time":1711430519999,"open":"68010.00","high":"68040.00","low":"68000.00","close":"68028.00","volume":"10.12","trade_num":96,"final":true}
{"type":"kline","source":"binance_ws","symbol":"btcusdt","interval":"1m","event_time":1711430520000,"open_time":1711430520000,"close_time":1711430579999,"open":"68028.00","high":"68036.00","low":"68005.00","close":"68012.00","volume":"8.80","trade_num":87,"final":true}
{"type":"kline","source":"binance_ws","symbol":"btcusdt","interval":"1m","event_time":1711430580000,"open_time":1711430580000,"close_time":1711430639999,"open":"68012.00","high":"68022.00","low":"67980.00","close":"67992.00","volume":"11.50","trade_num":120,"final":true}
{"type":"kline","source":"binance_ws","symbol":"btcusdt","interval":"1m","event_time":1711430640000,"open_time":1711430640000,"close_time":1711430699999,"open":"67992.00","high":"68018.00","low":"67970.00","close":"68005.00","volume":"9.64","trade_num":91,"final":true}
{"type":"kline","source":"binance_ws","symbol":"btcusdt","interval":"1m","event_time":1711430700000,"open_time":1711430700000,"close_time":1711430759999,"open":"68005.00","high":"68055.00","low":"68002.00","close":"68048.00","volume":"14.03","trade_num":133,"final":true}
{"type":"kline","source":"binance_ws","symbol":"btcusdt","interval":"1m","event_time":1711430760000,"open_time":1711430760000,"close_time":1711430819999,"open":"68048.00","high":"68070.00","low":"68030.00","close":"68062.00","volume":"13.38","trade_num":121,"final":true}
{"type":"kline","source":"binance_ws","symbol":"btcusdt","interval":"1m","event_time":1711430820000,"open_time":1711430820000,"close_time":1711430879999,"open":"68062.00","high":"68080.00","low":"68040.00","close":"68044.00","volume":"10.99","trade_num":99,"final":true}
{"type":"kline","source":"binance_ws","symbol":"btcusdt","interval":"1m","event_time":1711430880000,"open_time":1711430880000,"close_time":1711430939999,"open":"68044.00","high":"68046.00","low":"68000.00","close":"68006.00","volume":"12.75","trade_num":115,"final":true}
{"type":"kline","source":"binance_ws","symbol":"btcusdt","interval":"1m","event_time":1711430940000,"open_time":1711430940000,"close_time":1711430999999,"open":"68006.00","high":"68012.00","low":"67972.00","close":"67986.00","volume":"11.66","trade_num":105,"final":true}
{"type":"kline","source":"binance_ws","symbol":"btcusdt","interval":"1m","event_time":1711431000000,"open_time":1711431000000,"close_time":1711431059999,"open":"67986.00","high":"68008.00","low":"67970.00","close":"67996.00","volume":"9.21","trade_num":88,"final":true}
{"type":"kline","source":"binance_ws","symbol":"btcusdt","interval":"1m","event_time":1711431060000,"open_time":1711431060000,"close_time":1711431119999,"open":"67996.00","high":"68035.00","low":"67992.00","close":"68020.00","volume":"10.77","trade_num":94,"final":true}
{"type":"kline","source":"binance_ws","symbol":"btcusdt","interval":"1m","event_time":1711431120000,"open_time":1711431120000,"close_time":1711431179999,"open":"68020.00","high":"68058.00","low":"68018.00","close":"68052.00","volume":"12.62","trade_num":110,"final":true}
{"type":"kline","source":"binance_ws","symbol":"btcusdt","interval":"1m","event_time":1711431180000,"open_time":1711431180000,"close_time":1711431239999,"open":"68052.00","high":"68066.00","low":"68030.00","close":"68042.00","volume":"8.19","trade_num":76,"final":true}
{"type":"kline","source":"binance_ws","symbol":"btcusdt","interval":"1m","event_time":1711431240000,"open_time":1711431240000,"close_time":1711431299999,"open":"68042.00","high":"68052.00","low":"68008.00","close":"68016.00","volume":"10.40","trade_num":93,"final":true}
{"type":"kline","source":"binance_ws","symbol":"btcusdt","interval":"1m","event_time":1711431300000,"open_time":1711431300000,"close_time":1711431359999,"open":"68016.00","high":"68018.00","low":"67986.00","close":"67998.00","volume":"9.07","trade_num":84,"final":true}
{"type":"kline","source":"binance_ws","symbol":"btcusdt","interval":"1m","event_time":1711431360000,"open_time":1711431360000,"close_time":1711431419999,"open":"67998.00","high":"68012.00","low":"67974.00","close":"67988.00","volume":"7.55","trade_num":69,"final":true}
{"type":"kline","source":"binance_ws","symbol":"btcusdt","interval":"1m","event_time":1711431420000,"open_time":1711431420000,"close_time":1711431479999,"open":"67988.00","high":"68028.00","low":"67982.00","close":"68024.00","volume":"11.88","trade_num":101,"final":true}
{"type":"kline","source":"binance_ws","symbol":"btcusdt","interval":"1m","event_time":1711431480000,"open_time":1711431480000,"close_time":1711431539999,"open":"68024.00","high":"68070.00","low":"68020.00","close":"68062.00","volume":"13.45","trade_num":124,"final":true}
{"type":"kline","source":"binance_ws","symbol":"btcusdt","interval":"1m","event_time":1711431540000,"open_time":1711431540000,"close_time":1711431599999,"open":"68062.00","high":"68082.00","low":"68048.00","close":"68074.00","volume":"12.03","trade_num":111,"final":true}

View File

@@ -0,0 +1,91 @@
from __future__ import annotations
import argparse
import pandas as pd
from market import create_provider, load_market_config
from market.service import DISPLAY_FIELDS, resolve_candidates
def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(description="Stock quick lookup (A-share)")
parser.add_argument("query", nargs="?", default="", help="stock code/name/abbr")
parser.add_argument("--limit", type=int, default=8, help="max candidate rows")
return parser.parse_args()
def safe_str(value: object) -> str:
if value is None:
return ""
if isinstance(value, float) and pd.isna(value):
return ""
return str(value).strip()
def print_row(row: pd.Series) -> None:
print("\n===== 股票信息 =====")
for field in DISPLAY_FIELDS:
if field in row.index:
print(f"{field}: {safe_str(row[field])}")
print("===================\n")
def pick_candidate(candidates: pd.DataFrame) -> pd.Series | None:
if len(candidates) == 1:
return candidates.iloc[0]
print("\n匹配到多个标的,请选择:")
for idx, (_, row) in enumerate(candidates.iterrows(), start=1):
print(f"{idx}. {safe_str(row['代码'])} {safe_str(row['名称'])} 最新价={safe_str(row.get('最新价'))}")
choice = input("输入序号(直接回车默认 1): ").strip()
if not choice:
return candidates.iloc[0]
if not choice.isdigit():
print("输入无效")
return None
pos = int(choice)
if pos < 1 or pos > len(candidates):
print("输入超出范围")
return None
return candidates.iloc[pos - 1]
def run_once(query: str, limit: int) -> None:
provider = create_provider(load_market_config())
candidates = resolve_candidates(provider, query, limit)
if candidates.empty:
print(f"未找到匹配股票: {query}")
return
selected = pick_candidate(candidates)
if selected is None:
return
print_row(selected)
def interactive_loop(limit: int) -> None:
print("股票速查已启动,输入代码/名称(输入 q 退出)")
while True:
query = input("查询> ").strip()
if query.lower() in {"q", "quit", "exit"}:
print("已退出股票速查")
return
if not query:
continue
try:
run_once(query, limit)
except Exception as exc: # noqa: BLE001
print(f"查询失败: {exc}")
def main() -> None:
args = parse_args()
if not args.query:
interactive_loop(args.limit)
return
run_once(args.query, args.limit)
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,3 @@
from .sma_cross import SmaCrossStrategy
__all__ = ["SmaCrossStrategy"]

View File

@@ -0,0 +1,22 @@
from __future__ import annotations
import backtrader as bt
class SmaCrossStrategy(bt.Strategy):
params = (
("fast", 3),
("slow", 8),
("stake", 1),
)
def __init__(self) -> None:
self.fast_sma = bt.indicators.SMA(self.data.close, period=self.params.fast)
self.slow_sma = bt.indicators.SMA(self.data.close, period=self.params.slow)
self.cross_signal = bt.indicators.CrossOver(self.fast_sma, self.slow_sma)
def next(self) -> None:
if not self.position and self.cross_signal > 0:
self.buy(size=self.params.stake)
elif self.position and self.cross_signal < 0:
self.close()

View File

@@ -0,0 +1,211 @@
<!DOCTYPE html>
<html lang="zh-CN">
<head>
<meta charset="UTF-8" />
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
<title>AITrading 股票看板</title>
<script src="https://cdn.jsdelivr.net/npm/chart.js@4.4.3/dist/chart.umd.min.js"></script>
<style>
body { font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Arial, sans-serif; margin: 20px; background:#f7f8fa; color:#1f2937; }
.card { background:#fff; border-radius:10px; box-shadow:0 2px 8px rgba(0,0,0,.06); padding:16px; margin-bottom:16px; }
.row { display:flex; gap:12px; flex-wrap:wrap; }
input, button { padding:8px 10px; border-radius:8px; border:1px solid #d1d5db; }
button { background:#2563eb; color:#fff; border:none; cursor:pointer; }
.grid { display:grid; grid-template-columns: repeat(3, minmax(180px,1fr)); gap:8px; }
.item { background:#f9fafb; border-radius:8px; padding:8px; }
.label { color:#6b7280; font-size:12px; }
.value { font-size:14px; font-weight:600; }
#status { font-size:13px; color:#6b7280; }
</style>
</head>
<body>
<div class="card">
<h2>股票信息与买卖点</h2>
<div class="row">
<input id="query" placeholder="输入股票代码或名称,如 600519 / 贵州茅台" style="min-width:300px;" />
<input id="fast" type="number" min="2" max="60" value="5" />
<input id="slow" type="number" min="3" max="120" value="20" />
<button onclick="searchStock()">查询</button>
</div>
<p id="status">请输入代码或名称并点击查询。</p>
</div>
<div class="card">
<div id="info" class="grid"></div>
</div>
<div class="card">
<canvas id="chart" height="120"></canvas>
</div>
<script>
let chart;
let realtimeWs = null;
let currentPoints = [];
let currentSignals = [];
let currentFast = 5;
let currentSlow = 20;
function renderInfo(info) {
const fields = [
["代码", info.code], ["名称", info.name], ["最新价", info.price],
["涨跌幅", info.change_pct], ["涨跌额", info.change], ["成交额", info.turnover],
["PE(动态)", info.pe], ["PB", info.pb], ["总市值", info.market_cap]
];
const container = document.getElementById("info");
container.innerHTML = fields.map(([k, v]) =>
`<div class="item"><div class="label">${k}</div><div class="value">${v ?? "-"}</div></div>`
).join("");
}
function buildSignals(points, fast, slow) {
const signals = [];
for (let i = 1; i < points.length; i++) {
const prev = points[i - 1];
const curr = points[i];
if (prev.sma_fast == null || prev.sma_slow == null || curr.sma_fast == null || curr.sma_slow == null) continue;
const crossUp = prev.sma_fast <= prev.sma_slow && curr.sma_fast > curr.sma_slow;
const crossDown = prev.sma_fast >= prev.sma_slow && curr.sma_fast < curr.sma_slow;
if (crossUp) signals.push({ date: curr.date, type: "buy", price: curr.close });
else if (crossDown) signals.push({ date: curr.date, type: "sell", price: curr.close });
}
return signals;
}
function recomputeSma(points, fast, slow) {
const sumFast = [];
const sumSlow = [];
for (let i = 0; i < points.length; i++) {
sumFast[i] = (sumFast[i - 1] || 0) + Number(points[i].close || 0);
sumSlow[i] = (sumSlow[i - 1] || 0) + Number(points[i].close || 0);
if (i >= fast - 1) {
const prev = i - fast >= 0 ? sumFast[i - fast] : 0;
points[i].sma_fast = (sumFast[i] - prev) / fast;
} else {
points[i].sma_fast = null;
}
if (i >= slow - 1) {
const prev = i - slow >= 0 ? sumSlow[i - slow] : 0;
points[i].sma_slow = (sumSlow[i] - prev) / slow;
} else {
points[i].sma_slow = null;
}
}
}
function upsertRealtimePoint(point) {
if (!point || point.close == null || currentPoints.length === 0) return;
const last = currentPoints[currentPoints.length - 1];
if (last.date === point.date) {
last.close = Number(point.close);
if (point.volume != null) last.volume = Number(point.volume);
} else if (point.date > last.date) {
currentPoints.push({
date: point.date,
close: Number(point.close),
volume: Number(point.volume || 0),
sma_fast: null,
sma_slow: null,
});
}
recomputeSma(currentPoints, currentFast, currentSlow);
currentSignals = buildSignals(currentPoints, currentFast, currentSlow);
drawChart(currentPoints, currentSignals);
}
function drawChart(points, signals) {
const labels = points.map(p => p.date);
const close = points.map(p => p.close);
const fast = points.map(p => p.sma_fast);
const slow = points.map(p => p.sma_slow);
const buyPoints = signals.filter(s => s.type === "buy").map(s => ({x: s.date, y: s.price}));
const sellPoints = signals.filter(s => s.type === "sell").map(s => ({x: s.date, y: s.price}));
if (!chart) {
chart = new Chart(document.getElementById("chart"), {
type: "line",
data: {
labels,
datasets: [
{ label: "收盘价", data: close, borderColor: "#1f2937", tension: 0.2, pointRadius: 0 },
{ label: "SMA Fast", data: fast, borderColor: "#2563eb", tension: 0.2, pointRadius: 0 },
{ label: "SMA Slow", data: slow, borderColor: "#16a34a", tension: 0.2, pointRadius: 0 },
{ label: "买点", data: buyPoints, parsing: {xAxisKey: "x", yAxisKey: "y"}, showLine:false, pointRadius:5, pointStyle:"triangle", pointBackgroundColor:"#dc2626", pointBorderColor:"#dc2626" },
{ label: "卖点", data: sellPoints, parsing: {xAxisKey: "x", yAxisKey: "y"}, showLine:false, pointRadius:5, pointStyle:"rectRot", pointBackgroundColor:"#7c3aed", pointBorderColor:"#7c3aed" }
]
},
options: {
responsive: true,
animation: false,
scales: { x: { ticks: { maxTicksLimit: 10 } } }
}
});
return;
}
chart.data.labels = labels;
chart.data.datasets[0].data = close;
chart.data.datasets[1].data = fast;
chart.data.datasets[2].data = slow;
chart.data.datasets[3].data = buyPoints;
chart.data.datasets[4].data = sellPoints;
chart.update("none");
}
async function searchStock() {
const query = document.getElementById("query").value.trim();
const fast = Number(document.getElementById("fast").value);
const slow = Number(document.getElementById("slow").value);
if (!query) return;
const status = document.getElementById("status");
status.innerText = "正在建立实时 WebSocket 连接...";
try {
if (realtimeWs) {
realtimeWs.close();
realtimeWs = null;
}
const protocol = window.location.protocol === "https:" ? "wss" : "ws";
realtimeWs = new WebSocket(`${protocol}://${window.location.host}/ws/stock/realtime?query=${encodeURIComponent(query)}`);
realtimeWs.onmessage = (event) => {
const data = JSON.parse(event.data);
if (data.error) {
status.innerText = `实时通道错误: ${data.error}`;
return;
}
renderInfo(data.info);
upsertRealtimePoint(data.realtime_point);
status.innerText = `实时更新中:${data.info.code} ${data.info.name}${data.updated_at}`;
};
realtimeWs.onclose = () => {
if (status.innerText.startsWith("实时更新中")) {
status.innerText = "实时连接已断开";
}
};
realtimeWs.onerror = () => {
status.innerText = "实时连接失败";
};
const res = await fetch(`/api/stock?query=${encodeURIComponent(query)}&sma_fast=${fast}&sma_slow=${slow}`);
const data = await res.json();
if (!res.ok) throw new Error(data.detail || "查询失败");
currentFast = fast;
currentSlow = slow;
currentPoints = data.points.map(p => ({
date: p.date,
close: Number(p.close),
sma_fast: p.sma_fast == null ? null : Number(p.sma_fast),
sma_slow: p.sma_slow == null ? null : Number(p.sma_slow),
volume: p.volume == null ? 0 : Number(p.volume),
}));
currentSignals = data.signals;
renderInfo(data.info);
drawChart(currentPoints, currentSignals);
status.innerText = `历史K线加载完成${data.info.code} ${data.info.name},买点 ${data.signals.filter(s=>s.type==="buy").length},卖点 ${data.signals.filter(s=>s.type==="sell").length}(实时通道已连接)`;
} catch (err) {
status.innerText = `错误: ${err.message}`;
}
}
</script>
</body>
</html>

82
python-app/app/web_app.py Normal file
View File

@@ -0,0 +1,82 @@
from __future__ import annotations
import asyncio
from pathlib import Path
from typing import Any
from fastapi import FastAPI, HTTPException, Query, WebSocket, WebSocketDisconnect
from fastapi.responses import FileResponse
from market import create_provider, load_market_config
from market.service import build_dashboard, build_realtime_info
app = FastAPI(title="AITrading Simple Frontend API", version="1.0.0")
BASE_DIR = Path(__file__).resolve().parent
WEB_INDEX = BASE_DIR / "web" / "index.html"
@app.get("/")
def index() -> FileResponse:
if not WEB_INDEX.exists():
raise HTTPException(status_code=500, detail="frontend page not found")
return FileResponse(WEB_INDEX)
@app.get("/api/stock")
def stock_dashboard(
query: str = Query(..., description="stock code or name"),
days: int = Query(120, ge=30, le=365),
sma_fast: int = Query(5, ge=2, le=60),
sma_slow: int = Query(20, ge=3, le=120),
) -> dict[str, Any]:
try:
provider = create_provider(load_market_config())
return build_dashboard(
provider=provider,
query=query,
days=days,
sma_fast=sma_fast,
sma_slow=sma_slow,
)
except RuntimeError as exc:
raise HTTPException(status_code=400, detail=str(exc)) from exc
@app.get("/api/stock/realtime")
def stock_realtime(
query: str = Query(..., description="stock code or name"),
) -> dict[str, Any]:
try:
provider = create_provider(load_market_config())
return build_realtime_info(provider=provider, query=query)
except RuntimeError as exc:
raise HTTPException(status_code=400, detail=str(exc)) from exc
@app.websocket("/ws/stock/realtime")
async def stock_realtime_ws(websocket: WebSocket) -> None:
await websocket.accept()
query = (websocket.query_params.get("query") or "").strip()
if not query:
await websocket.send_json({"error": "query is required"})
await websocket.close(code=1008)
return
try:
while True:
provider = create_provider(load_market_config())
payload = build_realtime_info(provider=provider, query=query)
await websocket.send_json(payload)
# WebSocket channel keeps the connection alive, front-end no longer polls via HTTP.
await asyncio.sleep(2)
except WebSocketDisconnect:
return
except Exception as exc: # noqa: BLE001
await websocket.send_json({"error": str(exc)})
await websocket.close(code=1011)
if __name__ == "__main__":
import uvicorn
uvicorn.run("web_app:app", host="0.0.0.0", port=8000, reload=False)

View File

@@ -0,0 +1,10 @@
pandas
mplfinance
matplotlib
backtrader
redis
akshare
fastapi
uvicorn
futu-api
websockets

44
shared/protocol.md Normal file
View File

@@ -0,0 +1,44 @@
# Kline JSON Protocol (Go <-> Python)
用于 Go 行情采集端与 Python 展示/回测端的数据交换。
## 标准消息格式
```json
{
"type": "kline",
"source": "binance_ws",
"symbol": "btcusdt",
"interval": "1m",
"event_time": 1711430400000,
"open_time": 1711430400000,
"close_time": 1711430459999,
"open": "68000.00",
"high": "68025.00",
"low": "67990.00",
"close": "68010.00",
"volume": "12.40",
"trade_num": 102,
"final": true
}
```
## 字段说明
- `type`: 消息类型,固定 `kline`
- `source`: 数据来源,默认 `binance_ws`
- `symbol`: 交易对(小写,如 `btcusdt`
- `interval`: K 线周期(如 `1m``5m``1h`
- `event_time`: 事件时间(毫秒时间戳)
- `open_time`: K 线起始时间(毫秒时间戳)
- `close_time`: K 线结束时间(毫秒时间戳)
- `open/high/low/close`: 价格字段,使用字符串保证精度
- `volume`: 成交量,字符串
- `trade_num`: 成交笔数,整数
- `final`: 是否为收盘完成 K 线(`true` 表示该根 K 线已闭合)
## 约束约定
- 时间戳统一为 UTC 毫秒
- 数值字段在传输层统一为字符串trade_num 除外),消费端按需转换
- 数据可通过 Redis Pub/Sub、文件流jsonl或消息队列进行传输

240
start.sh Executable file
View File

@@ -0,0 +1,240 @@
#!/usr/bin/env bash
set -euo pipefail
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
cd "$SCRIPT_DIR"
COMPOSE_FILE="${COMPOSE_FILE:-docker-compose.yml}"
ENV_FILE="${ENV_FILE:-.env}"
PROJECT_NAME="${PROJECT_NAME:-ai-trading}"
if ! command -v docker >/dev/null 2>&1; then
echo "[ERROR] docker 未安装,请先安装 Docker。"
exit 1
fi
if ! docker compose version >/dev/null 2>&1; then
echo "[ERROR] docker compose 不可用,请升级 Docker Desktop 或安装 compose 插件。"
exit 1
fi
if [[ ! -f "$COMPOSE_FILE" ]]; then
echo "[ERROR] 未找到 $COMPOSE_FILE"
exit 1
fi
MODE="${1:-local}"
ACTION="${2:-up}"
QUERY="${3:-}"
compose_cmd=(
docker compose
-p "$PROJECT_NAME"
-f "$COMPOSE_FILE"
)
if [[ -n "$ENV_FILE" ]]; then
if [[ -f "$ENV_FILE" ]]; then
compose_cmd+=(--env-file "$ENV_FILE")
else
echo "[WARN] 未找到 env 文件: ${ENV_FILE}, 已跳过 --env-file"
fi
fi
show_help() {
cat <<'EOF'
用法:
./start.sh [mode] [action]
mode:
local 本地开发(默认)
prod 生产部署
action:
up 启动服务(默认)
down 停止并删除容器/网络
restart 重启服务
quick 一键清理并启动cleanup + up + web
logs 查看日志Ctrl+C 退出)
status 查看服务状态
pull 拉取基础镜像
build 构建镜像
quote 股票速查(代码/名称)
web 启动前端页面(信息+曲线+买卖点)
cleanup 清理历史遗留容器/项目(推荐迁移后执行一次)
示例:
./start.sh local up
./start.sh local logs
./start.sh local quote 600519
./start.sh local quote 贵州茅台
./start.sh local web
./start.sh prod up
./start.sh prod restart
可选环境变量:
COMPOSE_FILE 指定 compose 文件(默认 docker-compose.yml
ENV_FILE 指定环境变量文件(默认 .env
PROJECT_NAME compose 项目名(默认 ai-trading
EOF
}
run_local() {
case "$ACTION" in
up)
"${compose_cmd[@]}" up --build -d redis go-service
;;
down)
"${compose_cmd[@]}" down
;;
restart)
"${compose_cmd[@]}" down
"${compose_cmd[@]}" up --build -d
;;
quick)
"${compose_cmd[@]}" down --remove-orphans || true
docker compose -p aitrading -f "$COMPOSE_FILE" down --remove-orphans 2>/dev/null || true
stale_run_ids="$(docker ps -aq --filter "name=^ai-trading-python-app-run-")"
if [[ -n "$stale_run_ids" ]]; then
docker rm -f $stale_run_ids >/dev/null 2>&1 || true
fi
docker rm -f ai-trading-redis >/dev/null 2>&1 || true
"${compose_cmd[@]}" up --build -d redis go-service web
echo "[INFO] Quick 启动完成: http://localhost:8000"
;;
logs)
"${compose_cmd[@]}" logs -f --tail=200
;;
status)
"${compose_cmd[@]}" ps
;;
pull)
"${compose_cmd[@]}" pull || true
;;
build)
"${compose_cmd[@]}" build
;;
quote)
if [[ -z "$QUERY" ]]; then
"${compose_cmd[@]}" run --rm --no-deps python-app python app/stock_lookup.py
else
"${compose_cmd[@]}" run --rm --no-deps python-app python app/stock_lookup.py "$QUERY"
fi
;;
web)
"${compose_cmd[@]}" up --build -d web
echo "[INFO] Web 已启动: http://localhost:8000"
;;
cleanup)
# 1) 优先清理当前 compose 项目里的孤儿资源
"${compose_cmd[@]}" down --remove-orphans || true
# 2) 清理历史旧项目名(未显式 -p 时常见)
docker compose -p aitrading -f "$COMPOSE_FILE" down --remove-orphans 2>/dev/null || true
# 3) 清理遗留临时容器(例如 ai-trading-python-app-run-xxxxx
stale_run_ids="$(docker ps -aq --filter "name=^ai-trading-python-app-run-")"
if [[ -n "$stale_run_ids" ]]; then
docker rm -f $stale_run_ids >/dev/null 2>&1 || true
fi
# 4) 清理曾经手动起的固定 redis 容器
docker rm -f ai-trading-redis >/dev/null 2>&1 || true
echo "[INFO] 清理完成,建议执行: ./start.sh local up && ./start.sh local web"
;;
*)
echo "[ERROR] 不支持的 action: $ACTION"
show_help
exit 1
;;
esac
}
run_prod() {
case "$ACTION" in
up)
"${compose_cmd[@]}" pull || true
"${compose_cmd[@]}" up -d --build --remove-orphans redis go-service
;;
down)
"${compose_cmd[@]}" down
;;
restart)
"${compose_cmd[@]}" up -d --build --force-recreate --remove-orphans
;;
quick)
"${compose_cmd[@]}" down --remove-orphans || true
docker compose -p aitrading -f "$COMPOSE_FILE" down --remove-orphans 2>/dev/null || true
stale_run_ids="$(docker ps -aq --filter "name=^ai-trading-python-app-run-")"
if [[ -n "$stale_run_ids" ]]; then
docker rm -f $stale_run_ids >/dev/null 2>&1 || true
fi
docker rm -f ai-trading-redis >/dev/null 2>&1 || true
"${compose_cmd[@]}" pull || true
"${compose_cmd[@]}" up -d --build --remove-orphans redis go-service web
echo "[INFO] Quick 启动完成: http://localhost:8000"
;;
logs)
"${compose_cmd[@]}" logs -f --tail=300
;;
status)
"${compose_cmd[@]}" ps
;;
pull)
"${compose_cmd[@]}" pull || true
;;
build)
"${compose_cmd[@]}" build --pull
;;
quote)
if [[ -z "$QUERY" ]]; then
"${compose_cmd[@]}" run --rm --no-deps python-app python app/stock_lookup.py
else
"${compose_cmd[@]}" run --rm --no-deps python-app python app/stock_lookup.py "$QUERY"
fi
;;
web)
"${compose_cmd[@]}" up -d --build --remove-orphans web
echo "[INFO] Web 已启动: http://localhost:8000"
;;
cleanup)
"${compose_cmd[@]}" down --remove-orphans || true
docker compose -p aitrading -f "$COMPOSE_FILE" down --remove-orphans 2>/dev/null || true
stale_run_ids="$(docker ps -aq --filter "name=^ai-trading-python-app-run-")"
if [[ -n "$stale_run_ids" ]]; then
docker rm -f $stale_run_ids >/dev/null 2>&1 || true
fi
docker rm -f ai-trading-redis >/dev/null 2>&1 || true
echo "[INFO] 清理完成,建议执行: ./start.sh prod up && ./start.sh prod web"
;;
*)
echo "[ERROR] 不支持的 action: $ACTION"
show_help
exit 1
;;
esac
}
if [[ "$MODE" == "-h" || "$MODE" == "--help" ]]; then
show_help
exit 0
fi
echo "[INFO] mode=$MODE action=$ACTION compose=$COMPOSE_FILE env=$ENV_FILE project=$PROJECT_NAME"
case "$MODE" in
local)
run_local
;;
prod)
run_prod
;;
*)
echo "[ERROR] 不支持的 mode: $MODE"
show_help
exit 1
;;
esac
echo "[INFO] 完成: mode=$MODE action=$ACTION"