fix 修复bug

This commit is contained in:
Daniel
2026-04-28 19:40:02 +08:00
parent c234fe64d6
commit 0134a5ef64
11 changed files with 720 additions and 14 deletions

View File

@@ -4,6 +4,7 @@ import logging
import math
import re
import secrets
import sqlite3
import socket
import time
import uuid
@@ -33,6 +34,7 @@ from app.schemas import (
ForgotPasswordResetRequest,
IMPublishRequest,
PosterGenerateRequest,
ResetCodeRegenerateRequest,
RewriteRequest,
WechatCoverUploadByUrlRequest,
WechatCoverGenerateRequest,
@@ -80,6 +82,7 @@ _login_rate: dict[str, list[float]] = {}
_challenge_pool: dict[str, dict] = {}
USERNAME_RE = re.compile(r"^[A-Za-z0-9_]{4,24}$")
PASSWORD_STRONG_RE = re.compile(r"^(?=.*[a-z])(?=.*[A-Z])(?=.*\d)(?=.*[^A-Za-z0-9])\S{10,64}$")
TABLE_NAME_RE = re.compile(r"^[A-Za-z_][A-Za-z0-9_]*$")
def _session_ttl(remember_me: bool) -> int:
@@ -100,6 +103,17 @@ def _require_user(request: Request) -> dict | None:
return u
def _require_super_admin(request: Request) -> None:
configured = (settings.super_admin_token or "").strip()
provided = (
request.headers.get("X-Admin-Token")
or request.query_params.get("token")
or ""
).strip()
if not configured or provided != configured:
raise HTTPException(status_code=403, detail="forbidden")
def _client_ip(request: Request) -> str:
xff = (request.headers.get("x-forwarded-for") or "").strip()
if xff:
@@ -162,18 +176,66 @@ def _validate_username_password(username: str, password: str) -> tuple[bool, str
def _platform_model_cfg() -> dict:
override = _get_platform_model_override()
return {
"api_key": settings.platform_openai_api_key or "",
"base_url": settings.platform_openai_base_url or "",
"model": settings.platform_openai_model,
"image_model": settings.platform_openai_image_model,
"timeout_sec": float(settings.platform_openai_timeout),
"max_output_tokens": int(settings.platform_openai_max_output_tokens),
"max_retries": int(settings.platform_openai_max_retries),
"api_key": str(override.get("api_key") or settings.platform_openai_api_key or ""),
"base_url": str(override.get("base_url") or settings.platform_openai_base_url or ""),
"model": str(override.get("model") or settings.platform_openai_model),
"image_model": str(override.get("image_model") or settings.platform_openai_image_model),
"timeout_sec": float(override.get("timeout_sec") or settings.platform_openai_timeout),
"max_output_tokens": int(override.get("max_output_tokens") or settings.platform_openai_max_output_tokens),
"max_retries": int(override.get("max_retries") or settings.platform_openai_max_retries),
"model_name": "平台模型",
}
def _ensure_system_settings_table() -> None:
with sqlite3.connect(settings.auth_db_path) as c:
c.execute(
"""
CREATE TABLE IF NOT EXISTS system_settings (
key TEXT PRIMARY KEY,
value TEXT NOT NULL,
updated_at INTEGER NOT NULL DEFAULT (strftime('%s','now'))
)
"""
)
c.commit()
def _get_system_setting(key: str, default: str = "") -> str:
_ensure_system_settings_table()
with sqlite3.connect(settings.auth_db_path) as c:
row = c.execute("SELECT value FROM system_settings WHERE key=? LIMIT 1", (key,)).fetchone()
return str(row[0]) if row and row[0] is not None else default
def _set_system_setting(key: str, value: str) -> None:
_ensure_system_settings_table()
with sqlite3.connect(settings.auth_db_path) as c:
c.execute(
"""
INSERT INTO system_settings(key, value, updated_at)
VALUES (?, ?, ?)
ON CONFLICT(key) DO UPDATE SET value=excluded.value, updated_at=excluded.updated_at
""",
(key, value, int(time.time())),
)
c.commit()
def _get_platform_model_override() -> dict:
return {
"api_key": _get_system_setting("platform_model_api_key", ""),
"base_url": _get_system_setting("platform_model_base_url", ""),
"model": _get_system_setting("platform_model_text_model", ""),
"image_model": _get_system_setting("platform_model_image_model", ""),
"timeout_sec": _get_system_setting("platform_model_timeout_sec", ""),
"max_output_tokens": _get_system_setting("platform_model_max_output_tokens", ""),
"max_retries": _get_system_setting("platform_model_max_retries", ""),
}
def _select_model_cfg(user_id: int, prefer_vip: bool = True) -> tuple[dict | None, str]:
vip = users.get_vip_status(user_id)
now = int(time.time())
@@ -390,6 +452,12 @@ async def guide_page(request: Request):
return templates.TemplateResponse("guide.html", {"request": request, "app_name": settings.app_name})
@app.get("/admin", response_class=HTMLResponse)
async def admin_page(request: Request):
_require_super_admin(request)
return templates.TemplateResponse("admin.html", {"request": request, "app_name": settings.app_name})
@app.get("/favicon.ico", include_in_schema=False)
async def favicon():
# 浏览器通常请求 /favicon.ico统一跳转到静态图标
@@ -420,6 +488,139 @@ async def api_config(request: Request):
}
@app.get("/api/admin/tables")
async def admin_tables(request: Request):
_require_super_admin(request)
with sqlite3.connect(settings.auth_db_path) as c:
rows = c.execute(
"SELECT name FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite_%' ORDER BY name ASC"
).fetchall()
return {"ok": True, "tables": [str(r[0]) for r in rows]}
@app.get("/api/admin/table/{table_name}")
async def admin_table_rows(table_name: str, request: Request, limit: int = 100, offset: int = 0):
_require_super_admin(request)
tname = (table_name or "").strip()
if not TABLE_NAME_RE.match(tname):
return {"ok": False, "detail": "invalid table name"}
page_limit = max(1, min(int(limit), 500))
page_offset = max(0, min(int(offset), 100000))
with sqlite3.connect(settings.auth_db_path) as c:
c.row_factory = sqlite3.Row
exists = c.execute(
"SELECT 1 FROM sqlite_master WHERE type='table' AND name=? LIMIT 1",
(tname,),
).fetchone()
if not exists:
return {"ok": False, "detail": "table not found"}
count_row = c.execute(f'SELECT COUNT(*) AS cnt FROM "{tname}"').fetchone()
total = int((count_row["cnt"] if count_row else 0) or 0)
rows = c.execute(
f'SELECT * FROM "{tname}" LIMIT ? OFFSET ?',
(page_limit, page_offset),
).fetchall()
items = [dict(r) for r in rows]
columns = list(items[0].keys()) if items else []
return {
"ok": True,
"table": tname,
"limit": page_limit,
"offset": page_offset,
"total": total,
"columns": columns,
"rows": items,
}
@app.get("/api/admin/platform-model")
async def admin_get_platform_model(request: Request):
_require_super_admin(request)
cfg = _platform_model_cfg()
return {
"ok": True,
"config": {
"api_key": cfg.get("api_key", ""),
"base_url": cfg.get("base_url", ""),
"model": cfg.get("model", ""),
"image_model": cfg.get("image_model", ""),
"timeout_sec": cfg.get("timeout_sec", 120),
"max_output_tokens": cfg.get("max_output_tokens", 8192),
"max_retries": cfg.get("max_retries", 0),
},
"text_model_options": [s.strip() for s in (settings.platform_openai_text_model_options or "").split(",") if s.strip()],
"image_model_options": [s.strip() for s in (settings.platform_openai_image_model_options or "").split(",") if s.strip()],
}
@app.post("/api/admin/platform-model")
async def admin_update_platform_model(request: Request):
_require_super_admin(request)
body = await request.json()
api_key = str((body or {}).get("api_key") or "").strip()
base_url = str((body or {}).get("base_url") or "").strip()
model = str((body or {}).get("model") or "").strip()
image_model = str((body or {}).get("image_model") or "").strip()
timeout_sec = max(5.0, min(600.0, float((body or {}).get("timeout_sec") or settings.platform_openai_timeout)))
max_output_tokens = max(256, min(65535, int((body or {}).get("max_output_tokens") or settings.platform_openai_max_output_tokens)))
max_retries = max(0, min(5, int((body or {}).get("max_retries") or settings.platform_openai_max_retries)))
if not api_key or not model:
return {"ok": False, "detail": "平台模型配置至少需要 API Key 和文本模型"}
_set_system_setting("platform_model_api_key", api_key)
_set_system_setting("platform_model_base_url", base_url)
_set_system_setting("platform_model_text_model", model)
_set_system_setting("platform_model_image_model", image_model)
_set_system_setting("platform_model_timeout_sec", str(timeout_sec))
_set_system_setting("platform_model_max_output_tokens", str(max_output_tokens))
_set_system_setting("platform_model_max_retries", str(max_retries))
return {"ok": True, "detail": "平台模型配置已保存"}
@app.get("/api/admin/users/overview")
async def admin_users_overview(request: Request, limit: int = 30):
_require_super_admin(request)
now = int(time.time())
today_start = now - (now % 86400)
page_limit = max(1, min(int(limit), 200))
with sqlite3.connect(settings.auth_db_path) as c:
c.row_factory = sqlite3.Row
total_row = c.execute("SELECT COUNT(*) AS cnt FROM users").fetchone()
active_row = c.execute("SELECT COUNT(*) AS cnt FROM users WHERE deleted_at IS NULL").fetchone()
deleted_row = c.execute("SELECT COUNT(*) AS cnt FROM users WHERE deleted_at IS NOT NULL").fetchone()
today_row = c.execute(
"SELECT COUNT(*) AS cnt FROM users WHERE created_at>=? AND created_at<?",
(today_start, today_start + 86400),
).fetchone()
recent = c.execute(
"""
SELECT id, username, created_at, deleted_at
FROM users
ORDER BY created_at DESC, id DESC
LIMIT ?
""",
(page_limit,),
).fetchall()
return {
"ok": True,
"stats": {
"total_users": int((total_row["cnt"] if total_row else 0) or 0),
"active_users": int((active_row["cnt"] if active_row else 0) or 0),
"deleted_users": int((deleted_row["cnt"] if deleted_row else 0) or 0),
"today_new_users": int((today_row["cnt"] if today_row else 0) or 0),
},
"recent_users": [
{
"id": int(r["id"] or 0),
"username": str(r["username"] or ""),
"created_at": int(r["created_at"] or 0),
"deleted_at": int(r["deleted_at"] or 0) if r["deleted_at"] else 0,
}
for r in recent
],
}
@app.get("/api/auth/me")
async def auth_me(request: Request):
user = _current_user(request)
@@ -470,9 +671,9 @@ async def auth_challenge():
@app.post("/api/auth/register")
async def auth_register(req: AuthCredentialRequest, request: Request, response: Response):
ip = _client_ip(request)
if _hit_limit(_register_rate, f"ip:{ip}", limit=8, window_sec=600):
if _hit_limit(_register_rate, f"ip:{ip}", limit=20, window_sec=300):
return {"ok": False, "detail": "请求过于频繁,请稍后再试"}
if _hit_limit(_register_rate, f"user:{(req.username or '').strip().lower()}", limit=6, window_sec=600):
if _hit_limit(_register_rate, f"user:{(req.username or '').strip().lower()}", limit=12, window_sec=300):
return {"ok": False, "detail": "该用户名操作过于频繁,请稍后再试"}
username = (req.username or "").strip()
password = req.password or ""
@@ -505,15 +706,17 @@ async def auth_register(req: AuthCredentialRequest, request: Request, response:
"detail": "注册并登录成功,已赠送试用 Credits请保存重置码",
"user": {"id": user["id"], "username": user["username"]},
"reset_code": user.get("reset_code", ""),
"is_new_user": True,
"redirect_to": "/guide",
}
@app.post("/api/auth/login")
async def auth_login(req: AuthCredentialRequest, request: Request, response: Response):
ip = _client_ip(request)
if _hit_limit(_login_rate, f"ip:{ip}", limit=20, window_sec=600):
if _hit_limit(_login_rate, f"ip:{ip}", limit=60, window_sec=300):
return {"ok": False, "detail": "登录过于频繁,请稍后再试"}
if _hit_limit(_login_rate, f"user:{(req.username or '').strip().lower()}", limit=12, window_sec=600):
if _hit_limit(_login_rate, f"user:{(req.username or '').strip().lower()}", limit=30, window_sec=300):
return {"ok": False, "detail": "该账户登录尝试过多,请稍后再试"}
try:
user = users.verify_user((req.username or "").strip(), req.password or "")
@@ -565,6 +768,21 @@ async def auth_forgot_password_reset(req: ForgotPasswordResetRequest):
return {"ok": True, "detail": "密码重置成功,请返回登录页重新登录"}
@app.post("/api/auth/reset-code/regenerate")
async def auth_regenerate_reset_code(req: ResetCodeRegenerateRequest, request: Request):
user = _require_user(request)
if not user:
return {"ok": False, "detail": "请先登录"}
new_code = users.regenerate_reset_code(user["id"], req.password or "")
if not new_code:
return {"ok": False, "detail": "当前密码错误,无法生成新重置码"}
return {
"ok": True,
"detail": "新重置码已生成,仅展示一次,请立即保存",
"reset_code": new_code,
}
@app.post("/api/auth/password/change")
async def auth_change_password(req: ChangePasswordRequest, request: Request, response: Response):
user = _require_user(request)