"""HTTP entry point: pages, authentication, WeChat binding and publish APIs."""

from __future__ import annotations

import hmac
import logging
from urllib.parse import urlparse

from fastapi import FastAPI, File, Request, Response, UploadFile
from fastapi.responses import HTMLResponse, RedirectResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates

from app.config import settings
from app.logging_setup import configure_logging
from app.middleware import RequestContextMiddleware
from app.schemas import (
    AuthCredentialRequest,
    ChangePasswordRequest,
    ForgotPasswordResetRequest,
    IMPublishRequest,
    RewriteRequest,
    WechatBindingRequest,
    WechatPublishRequest,
    WechatSwitchRequest,
)
from app.services.ai_rewriter import AIRewriter
from app.services.im import IMPublisher
from app.services.user_store import UserStore
from app.services.wechat import WechatPublisher

configure_logging()
logger = logging.getLogger(__name__)

app = FastAPI(title=settings.app_name)


@app.on_event("startup")
async def _log_startup() -> None:
    """Log the application name and AI-related configuration flags at startup."""
    logger.info(
        "app_start name=%s openai_configured=%s ai_soft_accept=%s",
        settings.app_name,
        bool(settings.openai_api_key),
        settings.ai_soft_accept,
    )


app.add_middleware(RequestContextMiddleware)
app.mount("/static", StaticFiles(directory="app/static"), name="static")
templates = Jinja2Templates(directory="app/templates")

rewriter = AIRewriter()
wechat = WechatPublisher()
im = IMPublisher()
users = UserStore(settings.auth_db_path)


def _session_ttl(remember_me: bool) -> int:
    """Return the session lifetime in seconds.

    The normal lifetime is never below 600 seconds, and the "remember me"
    lifetime is never below the normal one.
    """
    normal = max(600, int(settings.auth_session_ttl_sec))
    remembered = max(normal, int(settings.auth_remember_session_ttl_sec))
    return remembered if remember_me else normal


def _current_user(request: Request) -> dict | None:
    """Return the user for the request's session cookie, or None without a cookie."""
    token = request.cookies.get(settings.auth_cookie_name, "")
    return users.get_user_by_session(token) if token else None


def _require_user(request: Request) -> dict | None:
    """Return the logged-in user, or None (normalising any falsy lookup result)."""
    return _current_user(request) or None


def _safe_next_path(raw: str | None) -> str:
    """Return ``raw`` if it is a same-site absolute path, otherwise ``"/"``.

    The value comes from the query string and is used as a redirect target,
    so anything that could lead to another origin is rejected.
    """
    candidate = (raw or "/").strip() or "/"
    # Must be an absolute path; "//host" is a scheme-relative URL.
    if not candidate.startswith("/") or candidate.startswith("//"):
        return "/"
    # Browsers may treat backslashes as slashes and drop control characters,
    # which could turn an innocuous-looking path into "//host".
    if "\\" in candidate or any(ord(ch) < 0x20 or ord(ch) == 0x7F for ch in candidate):
        return "/"
    parsed = urlparse(candidate)
    if parsed.scheme or parsed.netloc:
        return "/"
    return candidate


def _start_session(response: Response, user_id, *, remember_me: bool) -> None:
    """Create a session for ``user_id`` and attach its cookie to ``response``."""
    ttl = _session_ttl(remember_me)
    token = users.create_session(user_id, ttl_seconds=ttl)
    response.set_cookie(
        key=settings.auth_cookie_name,
        value=token,
        httponly=True,
        samesite="lax",
        max_age=ttl,
        path="/",
    )


@app.get("/", response_class=HTMLResponse)
async def index(request: Request):
    """Render the main page, or redirect anonymous visitors to the auth page."""
    if not _current_user(request):
        return RedirectResponse(url="/auth?next=/", status_code=302)
    return templates.TemplateResponse(
        "index.html",
        {"request": request, "app_name": settings.app_name},
    )


@app.get("/auth", response_class=HTMLResponse)
async def auth_page(request: Request):
    """Render the login/register page; logged-in users are sent to ``next``."""
    # ``next`` is attacker-controllable, so only same-site paths are honoured.
    nxt = _safe_next_path(request.query_params.get("next"))
    if _current_user(request):
        return RedirectResponse(url=nxt, status_code=302)
    return templates.TemplateResponse(
        "auth.html",
        {"request": request, "app_name": settings.app_name, "next": nxt},
    )


@app.get("/settings", response_class=HTMLResponse)
async def settings_page(request: Request):
    """Render the settings page, or redirect anonymous visitors to the auth page."""
    if not _current_user(request):
        return RedirectResponse(url="/auth?next=/settings", status_code=302)
    return templates.TemplateResponse(
        "settings.html",
        {"request": request, "app_name": settings.app_name},
    )


@app.get("/favicon.ico", include_in_schema=False)
async def favicon():
    """Redirect the conventional favicon URL to the static SVG icon."""
    # Browsers usually request /favicon.ico; always redirect to the static icon.
    return RedirectResponse(url="/static/favicon.svg?v=20260406a")


@app.get("/api/config")
async def api_config():
    """For page display: whether a model is configured, its name and provider (no key)."""
    base = settings.openai_base_url or ""
    provider = "dashscope" if "dashscope.aliyuncs.com" in base else "openai_compatible"
    host = urlparse(base).netloc if base else ""
    return {
        "openai_configured": bool(settings.openai_api_key),
        "openai_model": settings.openai_model,
        "provider": provider,
        "base_url_host": host or None,
        "openai_timeout_sec": settings.openai_timeout,
        "openai_max_output_tokens": settings.openai_max_output_tokens,
    }


@app.get("/api/auth/me")
async def auth_me(request: Request):
    """Report the login state and, when logged in, the user's WeChat bindings."""
    user = _current_user(request)
    if not user:
        return {"ok": True, "logged_in": False}
    binding = users.get_active_wechat_binding(user["id"])
    bindings = users.list_wechat_bindings(user["id"])
    # NOTE(review): the binding is returned as-is and is checked for a "secret"
    # key just below, so the secret may be sent to the client — confirm intended.
    return {
        "ok": True,
        "logged_in": True,
        "user": {"id": user["id"], "username": user["username"]},
        "wechat_bound": bool(binding and binding.get("appid") and binding.get("secret")),
        "active_wechat_account": binding,
        "wechat_accounts": bindings,
    }


@app.post("/api/auth/register")
async def auth_register(req: AuthCredentialRequest, response: Response):
    """Create an account and log it in by setting the session cookie."""
    username = (req.username or "").strip()
    password = req.password or ""
    if len(username) < 2:
        return {"ok": False, "detail": "用户名至少 2 个字符"}
    if len(password) < 6:
        return {"ok": False, "detail": "密码至少 6 个字符"}
    try:
        user = users.create_user(username, password)
    except Exception as exc:
        # Boundary handler: log with traceback and answer with a generic message.
        logger.exception("auth_register_failed username=%s detail=%s", username, exc)
        return {"ok": False, "detail": "注册失败:账号库异常,请稍后重试"}
    if not user:
        return {"ok": False, "detail": "用户名已存在"}
    _start_session(response, user["id"], remember_me=bool(req.remember_me))
    return {"ok": True, "detail": "注册并登录成功", "user": user}


@app.post("/api/auth/login")
async def auth_login(req: AuthCredentialRequest, response: Response):
    """Verify credentials and set the session cookie on success."""
    username = (req.username or "").strip()
    try:
        user = users.verify_user(username, req.password or "")
    except Exception as exc:
        # Boundary handler: log with traceback and answer with a generic message.
        logger.exception("auth_login_failed username=%s detail=%s", username, exc)
        return {"ok": False, "detail": "登录失败:账号库异常,请稍后重试"}
    if not user:
        return {"ok": False, "detail": "用户名或密码错误"}
    _start_session(response, user["id"], remember_me=bool(req.remember_me))
    return {"ok": True, "detail": "登录成功", "user": user}


@app.post("/api/auth/logout")
async def auth_logout(request: Request, response: Response):
    """Delete the current session (if any) and clear the session cookie."""
    token = request.cookies.get(settings.auth_cookie_name, "")
    if token:
        users.delete_session(token)
    response.delete_cookie(settings.auth_cookie_name, path="/")
    return {"ok": True, "detail": "已退出登录"}


@app.get("/auth/forgot", response_class=HTMLResponse)
async def forgot_password_page(request: Request):
    """Render the forgot-password page."""
    return templates.TemplateResponse(
        "forgot_password.html",
        {"request": request, "app_name": settings.app_name},
    )


@app.post("/api/auth/password/forgot")
async def auth_forgot_password_reset(req: ForgotPasswordResetRequest):
    """Reset a user's password when the submitted reset key matches the configured one."""
    reset_key = (req.reset_key or "").strip()
    # No built-in fallback key: a hard-coded default would let anyone who has
    # seen the source reset any account. An unset key disables the feature.
    expected_key = (settings.auth_password_reset_key or "").strip()
    username = (req.username or "").strip()
    new_password = req.new_password or ""
    if not expected_key:
        return {"ok": False, "detail": "系统未启用忘记密码重置功能,请联系管理员"}
    if len(username) < 2:
        return {"ok": False, "detail": "请输入正确的用户名"}
    if len(new_password) < 6:
        return {"ok": False, "detail": "新密码至少 6 个字符"}
    # compare_digest rejects non-ASCII str arguments with TypeError, so compare
    # the UTF-8 bytes instead; the comparison stays constant-time.
    if not hmac.compare_digest(reset_key.encode("utf-8"), expected_key.encode("utf-8")):
        return {"ok": False, "detail": "重置码错误"}
    ok = users.reset_password_by_username(username, new_password)
    if not ok:
        return {"ok": False, "detail": "用户不存在,无法重置"}
    return {"ok": True, "detail": "密码重置成功,请返回登录页重新登录"}


@app.post("/api/auth/password/change")
async def auth_change_password(req: ChangePasswordRequest, request: Request, response: Response):
    """Change the logged-in user's password and replace all of their sessions."""
    user = _require_user(request)
    if not user:
        return {"ok": False, "detail": "请先登录"}
    old_password = req.old_password or ""
    new_password = req.new_password or ""
    if not old_password:
        return {"ok": False, "detail": "请输入当前密码"}
    if len(new_password) < 6:
        return {"ok": False, "detail": "新密码至少 6 个字符"}
    if old_password == new_password:
        return {"ok": False, "detail": "新密码不能与当前密码相同"}
    ok = users.change_password(user["id"], old_password, new_password)
    if not ok:
        return {"ok": False, "detail": "当前密码错误,修改失败"}
    # Drop every existing session, then issue a fresh non-remembered one.
    users.delete_sessions_by_user(user["id"])
    _start_session(response, user["id"], remember_me=False)
    return {"ok": True, "detail": "密码修改成功,已刷新登录状态"}


@app.post("/api/auth/wechat/bind")
async def auth_wechat_bind(req: WechatBindingRequest, request: Request):
    """Store a new WeChat official-account binding for the logged-in user."""
    user = _require_user(request)
    if not user:
        return {"ok": False, "detail": "请先登录"}
    appid = (req.appid or "").strip()
    secret = (req.secret or "").strip()
    if not appid or not secret:
        return {"ok": False, "detail": "appid/secret 不能为空"}
    created = users.add_wechat_binding(
        user_id=user["id"],
        account_name=(req.account_name or "").strip() or "公众号账号",
        appid=appid,
        secret=secret,
        author=(req.author or "").strip(),
        thumb_media_id=(req.thumb_media_id or "").strip(),
        thumb_image_path=(req.thumb_image_path or "").strip(),
    )
    return {"ok": True, "detail": "公众号账号绑定成功", "account": created}


@app.post("/api/auth/wechat/switch")
async def auth_wechat_switch(req: WechatSwitchRequest, request: Request):
    """Switch the logged-in user's active WeChat binding to ``req.account_id``."""
    user = _require_user(request)
    if not user:
        return {"ok": False, "detail": "请先登录"}
    ok = users.switch_active_wechat_binding(user["id"], int(req.account_id))
    if not ok:
        return {"ok": False, "detail": "切换失败:账号不存在或无权限"}
    return {"ok": True, "detail": "已切换当前公众号账号"}


@app.post("/api/rewrite")
async def rewrite(req: RewriteRequest, request: Request):
    """Run the AI rewriter on the request and return its result, logging sizes and timing."""
    rid = getattr(request.state, "request_id", "")
    src = req.source_text or ""
    logger.info(
        "api_rewrite_in rid=%s source_chars=%d title_hint_chars=%d tone=%s audience=%s "
        "keep_points_chars=%d avoid_words_chars=%d",
        rid,
        len(src),
        len(req.title_hint or ""),
        req.tone,
        req.audience,
        len(req.keep_points or ""),
        len(req.avoid_words or ""),
    )
    # NOTE(review): this call is not awaited; if it performs blocking I/O it
    # will stall the event loop for its duration — confirm and offload if so.
    result = rewriter.rewrite(req, request_id=rid)
    tr = result.trace or {}
    logger.info(
        "api_rewrite_out rid=%s mode=%s duration_ms=%s quality_notes=%d trace_steps=%s soft_accept=%s",
        rid,
        result.mode,
        tr.get("duration_ms"),
        len(result.quality_notes or []),
        len(tr.get("steps") or []),
        tr.get("quality_soft_accept"),
    )
    return result


@app.post("/api/publish/wechat")
async def publish_wechat(req: WechatPublishRequest, request: Request):
    """Publish a draft through the logged-in user's active WeChat binding."""
    user = _require_user(request)
    if not user:
        return {"ok": False, "detail": "请先登录"}
    binding = users.get_active_wechat_binding(user["id"])
    if not binding:
        return {"ok": False, "detail": "当前账号未绑定公众号 token,请先在页面绑定"}
    rid = getattr(request.state, "request_id", "")
    logger.info(
        "api_wechat_in rid=%s title_chars=%d summary_chars=%d body_md_chars=%d author_set=%s",
        rid,
        len(req.title or ""),
        len(req.summary or ""),
        len(req.body_markdown or ""),
        bool((req.author or "").strip()),
    )
    out = await wechat.publish_draft(req, request_id=rid, account=binding)
    # Only a dict payload can carry an errcode.
    wcode = (out.data or {}).get("errcode") if isinstance(out.data, dict) else None
    logger.info(
        "api_wechat_out rid=%s ok=%s wechat_errcode=%s detail_preview=%s",
        rid,
        out.ok,
        wcode,
        (out.detail or "")[:240],
    )
    return out


@app.post("/api/wechat/cover/upload")
async def upload_wechat_cover(request: Request, file: UploadFile = File(...)):
    """Upload a cover image through the logged-in user's active WeChat binding."""
    user = _require_user(request)
    if not user:
        return {"ok": False, "detail": "请先登录"}
    binding = users.get_active_wechat_binding(user["id"])
    if not binding:
        return {"ok": False, "detail": "当前账号未绑定公众号 token,请先在页面绑定"}
    rid = getattr(request.state, "request_id", "")
    fn = file.filename or "cover.jpg"
    content = await file.read()
    logger.info("api_wechat_cover_upload_in rid=%s filename=%s bytes=%d", rid, fn, len(content))
    out = await wechat.upload_cover(fn, content, request_id=rid, account=binding)
    logger.info(
        "api_wechat_cover_upload_out rid=%s ok=%s detail=%s",
        rid,
        out.ok,
        (out.detail or "")[:160],
    )
    return out


@app.post("/api/wechat/material/upload")
async def upload_wechat_material(request: Request, file: UploadFile = File(...)):
    """Upload body material through the logged-in user's active WeChat binding."""
    user = _require_user(request)
    if not user:
        return {"ok": False, "detail": "请先登录"}
    binding = users.get_active_wechat_binding(user["id"])
    if not binding:
        return {"ok": False, "detail": "当前账号未绑定公众号 token,请先在页面绑定"}
    rid = getattr(request.state, "request_id", "")
    fn = file.filename or "material.jpg"
    content = await file.read()
    logger.info("api_wechat_material_upload_in rid=%s filename=%s bytes=%d", rid, fn, len(content))
    out = await wechat.upload_body_material(fn, content, request_id=rid, account=binding)
    logger.info(
        "api_wechat_material_upload_out rid=%s ok=%s detail=%s",
        rid,
        out.ok,
        (out.detail or "")[:160],
    )
    return out


@app.post("/api/publish/im")
async def publish_im(req: IMPublishRequest, request: Request):
    """Publish the given title and markdown body through the IM publisher."""
    rid = getattr(request.state, "request_id", "")
    logger.info(
        "api_im_in rid=%s title_chars=%d body_md_chars=%d",
        rid,
        len(req.title or ""),
        len(req.body_markdown or ""),
    )
    out = await im.publish(req, request_id=rid)
    logger.info("api_im_out rid=%s ok=%s detail=%s", rid, out.ok, (out.detail or "")[:120])
    return out