feat: 修复报错
This commit is contained in:
142
python-app/app/backtest_runner.py
Normal file
142
python-app/app/backtest_runner.py
Normal file
@@ -0,0 +1,142 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import ast
|
||||
import importlib
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Any, Type
|
||||
|
||||
import backtrader as bt
|
||||
import pandas as pd
|
||||
|
||||
from data_protocol import KlineMessage
|
||||
|
||||
|
||||
def parse_args() -> argparse.Namespace:
|
||||
parser = argparse.ArgumentParser(description="Backtrader runner")
|
||||
parser.add_argument("--input", required=True, help="Path to kline jsonl file")
|
||||
parser.add_argument(
|
||||
"--strategy",
|
||||
default="strategies.sma_cross.SmaCrossStrategy",
|
||||
help="Strategy class path, e.g. strategies.sma_cross.SmaCrossStrategy",
|
||||
)
|
||||
parser.add_argument("--cash", type=float, default=100000.0, help="Initial cash")
|
||||
parser.add_argument("--commission", type=float, default=0.001, help="Commission")
|
||||
parser.add_argument(
|
||||
"--strategy-param",
|
||||
action="append",
|
||||
default=[],
|
||||
help="Strategy parameter in key=value format, repeatable",
|
||||
)
|
||||
parser.add_argument("--plot", action="store_true", help="Plot result")
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def load_strategy_class(class_path: str) -> Type[bt.Strategy]:
|
||||
module_name, class_name = class_path.rsplit(".", 1)
|
||||
module = importlib.import_module(module_name)
|
||||
strategy_cls = getattr(module, class_name)
|
||||
if not issubclass(strategy_cls, bt.Strategy):
|
||||
raise TypeError(f"{class_path} is not a backtrader.Strategy subclass")
|
||||
return strategy_cls
|
||||
|
||||
|
||||
def load_dataframe(path: Path) -> pd.DataFrame:
|
||||
rows = []
|
||||
for line in path.read_text(encoding="utf-8").splitlines():
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
payload = json.loads(line)
|
||||
msg = KlineMessage.from_dict(payload)
|
||||
rows.append(
|
||||
{
|
||||
"datetime": pd.to_datetime(msg.open_time, unit="ms"),
|
||||
"open": float(msg.open),
|
||||
"high": float(msg.high),
|
||||
"low": float(msg.low),
|
||||
"close": float(msg.close),
|
||||
"volume": float(msg.volume),
|
||||
"openinterest": 0.0,
|
||||
}
|
||||
)
|
||||
if not rows:
|
||||
raise RuntimeError("no kline rows found for backtest")
|
||||
frame = pd.DataFrame(rows).set_index("datetime")
|
||||
return frame.sort_index()
|
||||
|
||||
|
||||
def parse_strategy_params(raw_params: list[str]) -> dict[str, Any]:
|
||||
parsed: dict[str, Any] = {}
|
||||
for item in raw_params:
|
||||
if "=" not in item:
|
||||
raise ValueError(f"invalid --strategy-param '{item}', expected key=value")
|
||||
key, raw_value = item.split("=", 1)
|
||||
key = key.strip()
|
||||
raw_value = raw_value.strip()
|
||||
if not key:
|
||||
raise ValueError(f"invalid --strategy-param '{item}', empty key")
|
||||
try:
|
||||
value = ast.literal_eval(raw_value)
|
||||
except (ValueError, SyntaxError):
|
||||
value = raw_value
|
||||
parsed[key] = value
|
||||
return parsed
|
||||
|
||||
|
||||
def extract_report(result: list[bt.Strategy], final_value: float) -> dict[str, Any]:
|
||||
strategy = result[0]
|
||||
drawdown_info = strategy.analyzers.drawdown.get_analysis()
|
||||
sharpe_info = strategy.analyzers.sharpe.get_analysis()
|
||||
|
||||
max_drawdown = drawdown_info.get("max", {}).get("drawdown")
|
||||
sharpe_ratio = sharpe_info.get("sharperatio")
|
||||
return {
|
||||
"final_value": final_value,
|
||||
"max_drawdown_pct": None if max_drawdown is None else float(max_drawdown),
|
||||
"sharpe_ratio": None if sharpe_ratio is None else float(sharpe_ratio),
|
||||
}
|
||||
|
||||
|
||||
def print_report(report: dict[str, Any]) -> None:
|
||||
print("\n===== Backtest Report =====")
|
||||
print(f"Final Value : {report['final_value']:.2f}")
|
||||
max_drawdown = report["max_drawdown_pct"]
|
||||
sharpe_ratio = report["sharpe_ratio"]
|
||||
print(f"Max Drawdown (%) : {'N/A' if max_drawdown is None else f'{max_drawdown:.4f}'}")
|
||||
print(f"Sharpe Ratio : {'N/A' if sharpe_ratio is None else f'{sharpe_ratio:.4f}'}")
|
||||
print("===========================\n")
|
||||
|
||||
|
||||
def run_backtest() -> None:
|
||||
args = parse_args()
|
||||
data_path = Path(args.input)
|
||||
if not data_path.exists():
|
||||
raise FileNotFoundError(f"input file not found: {data_path}")
|
||||
|
||||
strategy_cls = load_strategy_class(args.strategy)
|
||||
strategy_params = parse_strategy_params(args.strategy_param)
|
||||
df = load_dataframe(data_path)
|
||||
|
||||
cerebro = bt.Cerebro()
|
||||
cerebro.addstrategy(strategy_cls, **strategy_params)
|
||||
cerebro.broker.setcash(args.cash)
|
||||
cerebro.broker.setcommission(commission=args.commission)
|
||||
cerebro.adddata(bt.feeds.PandasData(dataname=df))
|
||||
cerebro.addanalyzer(bt.analyzers.DrawDown, _name="drawdown")
|
||||
cerebro.addanalyzer(bt.analyzers.SharpeRatio_A, _name="sharpe", riskfreerate=0.0)
|
||||
|
||||
print(f"starting portfolio value: {cerebro.broker.getvalue():.2f}")
|
||||
result = cerebro.run()
|
||||
final_value = cerebro.broker.getvalue()
|
||||
print(f"final portfolio value: {final_value:.2f}")
|
||||
print(f"strategies executed: {len(result)}")
|
||||
print_report(extract_report(result, final_value))
|
||||
|
||||
if args.plot:
|
||||
cerebro.plot(style="candlestick")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_backtest()
|
||||
Reference in New Issue
Block a user