diff --git a/app/config.py b/app/config.py index 20c9594..fe7b690 100644 --- a/app/config.py +++ b/app/config.py @@ -74,5 +74,28 @@ class Settings(BaseSettings): auth_remember_session_ttl_sec: int = Field(default=604800, alias="AUTH_REMEMBER_SESSION_TTL_SEC") auth_password_reset_key: str | None = Field(default="x2ws-reset-2026", alias="AUTH_PASSWORD_RESET_KEY") + vip_trial_tokens: int = Field( + default=20000, + alias="VIP_TRIAL_TOKENS", + description="新用户试用赠送 token", + ) + vip_rewrite_token_per_1k_chars: int = Field( + default=1200, + alias="VIP_REWRITE_TOKEN_PER_1K_CHARS", + description="改写按千字计费 token 单价", + ) + vip_image_token_per_image: int = Field( + default=1800, + alias="VIP_IMAGE_TOKEN_PER_IMAGE", + description="文生图单张扣减 token", + ) + platform_openai_api_key: str | None = Field(default=None, alias="PLATFORM_OPENAI_API_KEY") + platform_openai_base_url: str | None = Field(default=None, alias="PLATFORM_OPENAI_BASE_URL") + platform_openai_model: str = Field(default="gpt-4.1-mini", alias="PLATFORM_OPENAI_MODEL") + platform_openai_image_model: str = Field(default="gpt-image-1", alias="PLATFORM_OPENAI_IMAGE_MODEL") + platform_openai_timeout: float = Field(default=120.0, alias="PLATFORM_OPENAI_TIMEOUT") + platform_openai_max_output_tokens: int = Field(default=8192, alias="PLATFORM_OPENAI_MAX_OUTPUT_TOKENS") + platform_openai_max_retries: int = Field(default=0, alias="PLATFORM_OPENAI_MAX_RETRIES") + settings = Settings() diff --git a/app/main.py b/app/main.py index d2f41cd..6445f1f 100644 --- a/app/main.py +++ b/app/main.py @@ -30,6 +30,8 @@ from app.schemas import ( WechatBindingRequest, WechatPublishRequest, WechatSwitchRequest, + VipRechargeRequest, + VipToggleRequest, ) from app.services.ai_rewriter import AIRewriter from app.services.im import IMPublisher @@ -82,6 +84,37 @@ def _require_user(request: Request) -> dict | None: return u +def _platform_model_cfg() -> dict: + return { + "api_key": settings.platform_openai_api_key or "", + "base_url": settings.platform_openai_base_url or "", + "model": settings.platform_openai_model, + "image_model": settings.platform_openai_image_model, + "timeout_sec": float(settings.platform_openai_timeout), + "max_output_tokens": int(settings.platform_openai_max_output_tokens), + "max_retries": int(settings.platform_openai_max_retries), + "model_name": "平台模型", + } + + +def _select_model_cfg(user_id: int, prefer_vip: bool = True) -> tuple[dict | None, str]: + vip = users.get_vip_status(user_id) + if prefer_vip and vip.get("vip_enabled") and int(vip.get("token_balance") or 0) > 0: + cfg = _platform_model_cfg() + if cfg.get("api_key"): + return cfg, "vip" + cfg = users.get_active_ai_model(user_id) + return cfg, "user" + + +def _estimate_rewrite_cost(req: RewriteRequest, result) -> int: + src_chars = len((req.source_text or "").strip()) + out_chars = len((result.body_markdown or "").strip()) + len((result.title or "").strip()) + len((result.summary or "").strip()) + total_chars = max(1, src_chars + out_chars) + blocks = (total_chars + 999) // 1000 + return int(blocks * max(1, int(settings.vip_rewrite_token_per_1k_chars))) + + @app.get("/", response_class=HTMLResponse) async def index(request: Request): if not _current_user(request): @@ -129,12 +162,14 @@ async def api_config(request: Request): provider = "dashscope" if "dashscope.aliyuncs.com" in base else "openai_compatible" host = urlparse(base).netloc if base else "" model_name = (model_cfg or {}).get("model") or None + image_model_name = (model_cfg or {}).get("image_model") or settings.openai_image_model timeout_sec = (model_cfg or {}).get("timeout_sec") or None max_output_tokens = (model_cfg or {}).get("max_output_tokens") or None key_configured = bool((model_cfg or {}).get("api_key")) return { "openai_configured": key_configured, "openai_model": model_name, + "openai_image_model": image_model_name, "provider": provider, "base_url_host": host or None, "openai_timeout_sec": timeout_sec, @@ -158,6 +193,7 @@ async def auth_me(request: Request): "wechat_accounts": bindings, "active_ai_model": users.get_active_ai_model(user["id"]), "ai_models": users.list_ai_models(user["id"]), + "vip": users.get_vip_status(user["id"]), } @@ -176,6 +212,7 @@ async def auth_register(req: AuthCredentialRequest, response: Response): return {"ok": False, "detail": "注册失败:账号库异常,请稍后重试"} if not user: return {"ok": False, "detail": "用户名已存在"} + users.ensure_trial_tokens(user["id"], settings.vip_trial_tokens) ttl = _session_ttl(bool(req.remember_me)) token = users.create_session(user["id"], ttl_seconds=ttl) response.set_cookie( @@ -188,7 +225,7 @@ async def auth_register(req: AuthCredentialRequest, response: Response): ) return { "ok": True, - "detail": "注册并登录成功,请保存重置码", + "detail": "注册并登录成功,已赠送试用 token,请保存重置码", "user": {"id": user["id"], "username": user["username"]}, "reset_code": user.get("reset_code", ""), } @@ -358,6 +395,7 @@ async def auth_ai_model_add(req: AIModelCreateRequest, request: Request): api_key=api_key, base_url=(req.base_url or "").strip(), model=model, + image_model=(req.image_model or "").strip(), timeout_sec=max(10.0, float(req.timeout_sec or 120.0)), max_output_tokens=max(256, int(req.max_output_tokens or 8192)), max_retries=max(0, int(req.max_retries or 0)), @@ -416,6 +454,7 @@ async def rewrite(req: RewriteRequest, request: Request): "openai_timeout": settings.openai_timeout, "openai_max_output_tokens": settings.openai_max_output_tokens, "openai_max_retries": settings.openai_max_retries, + "openai_image_model": settings.openai_image_model, } try: settings.openai_api_key = model_cfg.get("api_key") or "" @@ -432,6 +471,7 @@ async def rewrite(req: RewriteRequest, request: Request): settings.openai_timeout = backup["openai_timeout"] settings.openai_max_output_tokens = backup["openai_max_output_tokens"] settings.openai_max_retries = backup["openai_max_retries"] + settings.openai_image_model = backup["openai_image_model"] tr = result.trace or {} logger.info( "api_rewrite_out rid=%s mode=%s duration_ms=%s quality_notes=%d trace_steps=%s soft_accept=%s", @@ -567,6 +607,7 @@ async def generate_wechat_cover(req: WechatCoverGenerateRequest, request: Reques "openai_timeout": settings.openai_timeout, "openai_max_output_tokens": settings.openai_max_output_tokens, "openai_max_retries": settings.openai_max_retries, + "openai_image_model": settings.openai_image_model, } try: if model_cfg: @@ -576,6 +617,7 @@ async def generate_wechat_cover(req: WechatCoverGenerateRequest, request: Reques settings.openai_timeout = float(model_cfg.get("timeout_sec") or 120.0) settings.openai_max_output_tokens = int(model_cfg.get("max_output_tokens") or 8192) settings.openai_max_retries = int(model_cfg.get("max_retries") or 0) + settings.openai_image_model = (model_cfg.get("image_model") or "").strip() or backup["openai_image_model"] else: settings.openai_api_key = "" settings.openai_base_url = "" @@ -583,6 +625,7 @@ async def generate_wechat_cover(req: WechatCoverGenerateRequest, request: Reques settings.openai_timeout = 120.0 settings.openai_max_output_tokens = 8192 settings.openai_max_retries = 0 + settings.openai_image_model = backup["openai_image_model"] out = await PosterMaterialService(wechat).generate_cover(req, request_id=rid, account=binding) finally: settings.openai_api_key = backup["openai_api_key"] @@ -591,6 +634,7 @@ async def generate_wechat_cover(req: WechatCoverGenerateRequest, request: Reques settings.openai_timeout = backup["openai_timeout"] settings.openai_max_output_tokens = backup["openai_max_output_tokens"] settings.openai_max_retries = backup["openai_max_retries"] + settings.openai_image_model = backup["openai_image_model"] logger.info( "api_wechat_cover_generate_out rid=%s ok=%s thumb=%s note=%s warnings=%d", rid, @@ -648,6 +692,7 @@ async def generate_posters(req: PosterGenerateRequest, request: Request): "openai_timeout": settings.openai_timeout, "openai_max_output_tokens": settings.openai_max_output_tokens, "openai_max_retries": settings.openai_max_retries, + "openai_image_model": settings.openai_image_model, } try: if model_cfg: @@ -657,6 +702,7 @@ async def generate_posters(req: PosterGenerateRequest, request: Request): settings.openai_timeout = float(model_cfg.get("timeout_sec") or 120.0) settings.openai_max_output_tokens = int(model_cfg.get("max_output_tokens") or 8192) settings.openai_max_retries = int(model_cfg.get("max_retries") or 0) + settings.openai_image_model = (model_cfg.get("image_model") or "").strip() or backup["openai_image_model"] else: settings.openai_api_key = "" settings.openai_base_url = "" @@ -664,6 +710,7 @@ async def generate_posters(req: PosterGenerateRequest, request: Request): settings.openai_timeout = 120.0 settings.openai_max_output_tokens = 8192 settings.openai_max_retries = 0 + settings.openai_image_model = backup["openai_image_model"] out = await PosterMaterialService(wechat).generate(req, request_id=rid, account=binding) finally: settings.openai_api_key = backup["openai_api_key"] @@ -672,6 +719,7 @@ async def generate_posters(req: PosterGenerateRequest, request: Request): settings.openai_timeout = backup["openai_timeout"] settings.openai_max_output_tokens = backup["openai_max_output_tokens"] settings.openai_max_retries = backup["openai_max_retries"] + settings.openai_image_model = backup["openai_image_model"] logger.info( "api_poster_generate_out rid=%s ok=%s posters=%d warnings=%d", rid, diff --git a/app/schemas.py b/app/schemas.py index 370155a..2ca6ba4 100644 --- a/app/schemas.py +++ b/app/schemas.py @@ -105,6 +105,7 @@ class AIModelCreateRequest(BaseModel): api_key: str base_url: str = "" model: str + image_model: str = "" timeout_sec: float = 120.0 max_output_tokens: int = 8192 max_retries: int = 0 @@ -118,6 +119,14 @@ class AIModelDeleteRequest(BaseModel): model_id: int +class VipToggleRequest(BaseModel): + enabled: bool = True + + +class VipRechargeRequest(BaseModel): + tokens: int = Field(default=10000, ge=1, le=10_000_000) + + class PosterGenerateRequest(BaseModel): title: str = "" summary: str = "" diff --git a/app/services/user_store.py b/app/services/user_store.py index 3c1bc9d..3fe6551 100644 --- a/app/services/user_store.py +++ b/app/services/user_store.py @@ -78,6 +78,7 @@ class UserStore: api_key TEXT NOT NULL, base_url TEXT NOT NULL DEFAULT '', model TEXT NOT NULL, + image_model TEXT NOT NULL DEFAULT '', timeout_sec REAL NOT NULL DEFAULT 120.0, max_output_tokens INTEGER NOT NULL DEFAULT 8192, max_retries INTEGER NOT NULL DEFAULT 0, @@ -87,6 +88,21 @@ class UserStore: ) """ ) + c.execute( + """ + CREATE TABLE IF NOT EXISTS user_wallets ( + user_id INTEGER PRIMARY KEY, + vip_enabled INTEGER NOT NULL DEFAULT 0, + token_balance INTEGER NOT NULL DEFAULT 0, + total_consumed_tokens INTEGER NOT NULL DEFAULT 0, + updated_at INTEGER NOT NULL, + FOREIGN KEY(user_id) REFERENCES users(id) + ) + """ + ) + ai_cols = self._table_columns(c, "ai_models") + if "image_model" not in ai_cols: + c.execute("ALTER TABLE ai_models ADD COLUMN image_model TEXT NOT NULL DEFAULT ''") # 兼容历史单绑定结构,自动迁移为默认账号 rows = c.execute( "SELECT user_id, appid, secret, author, thumb_media_id, thumb_image_path, updated_at FROM wechat_bindings" @@ -299,6 +315,13 @@ class UserStore: (username, pwd_hash, salt, reset_hash, reset_salt, now), ) uid = int(cur.lastrowid) + c.execute( + """ + INSERT OR IGNORE INTO user_wallets(user_id, vip_enabled, token_balance, total_consumed_tokens, updated_at) + VALUES (?, 0, 0, 0, ?) + """, + (uid, now), + ) return {"id": uid, "username": username, "reset_code": reset_code} except sqlite3.IntegrityError: return None @@ -440,12 +463,112 @@ class UserStore: c.execute("DELETE FROM ai_models WHERE user_id=?", (user_id,)) c.execute("DELETE FROM user_prefs WHERE user_id=?", (user_id,)) c.execute("DELETE FROM wechat_bindings WHERE user_id=?", (user_id,)) + c.execute("DELETE FROM user_wallets WHERE user_id=?", (user_id,)) c.execute( "UPDATE users SET deleted_at=?, username=username || '#deleted' || ? WHERE id=?", (now, str(now), user_id), ) return True + def _ensure_wallet_row(self, c: sqlite3.Connection, user_id: int) -> None: + now = int(time.time()) + c.execute( + """ + INSERT OR IGNORE INTO user_wallets(user_id, vip_enabled, token_balance, total_consumed_tokens, updated_at) + VALUES (?, 0, 0, 0, ?) + """, + (user_id, now), + ) + + def ensure_trial_tokens(self, user_id: int, trial_tokens: int) -> dict: + amount = max(0, int(trial_tokens)) + now = int(time.time()) + with self._conn() as c: + self._ensure_wallet_row(c, user_id) + row = c.execute( + "SELECT token_balance, total_consumed_tokens FROM user_wallets WHERE user_id=?", + (user_id,), + ).fetchone() + current = int(row["token_balance"] or 0) if row else 0 + if current <= 0 and amount > 0: + c.execute( + """ + UPDATE user_wallets + SET vip_enabled=1, token_balance=?, updated_at=? + WHERE user_id=? + """, + (amount, now, user_id), + ) + return self.get_vip_status(user_id) + + def get_vip_status(self, user_id: int) -> dict: + with self._conn() as c: + self._ensure_wallet_row(c, user_id) + row = c.execute( + """ + SELECT vip_enabled, token_balance, total_consumed_tokens, updated_at + FROM user_wallets + WHERE user_id=? + """, + (user_id,), + ).fetchone() + return { + "vip_enabled": bool(int(row["vip_enabled"] or 0)) if row else False, + "token_balance": int(row["token_balance"] or 0) if row else 0, + "total_consumed_tokens": int(row["total_consumed_tokens"] or 0) if row else 0, + "updated_at": int(row["updated_at"] or 0) if row else 0, + } + + def set_vip_enabled(self, user_id: int, enabled: bool) -> dict: + now = int(time.time()) + with self._conn() as c: + self._ensure_wallet_row(c, user_id) + c.execute( + "UPDATE user_wallets SET vip_enabled=?, updated_at=? WHERE user_id=?", + (1 if enabled else 0, now, user_id), + ) + return self.get_vip_status(user_id) + + def recharge_tokens(self, user_id: int, tokens: int) -> dict: + add = max(0, int(tokens)) + now = int(time.time()) + with self._conn() as c: + self._ensure_wallet_row(c, user_id) + c.execute( + """ + UPDATE user_wallets + SET token_balance=token_balance + ?, vip_enabled=1, updated_at=? + WHERE user_id=? + """, + (add, now, user_id), + ) + return self.get_vip_status(user_id) + + def consume_tokens(self, user_id: int, tokens: int) -> tuple[bool, int]: + cost = max(0, int(tokens)) + now = int(time.time()) + with self._conn() as c: + self._ensure_wallet_row(c, user_id) + row = c.execute( + "SELECT token_balance FROM user_wallets WHERE user_id=?", + (user_id,), + ).fetchone() + balance = int(row["token_balance"] or 0) if row else 0 + if cost <= 0: + return True, balance + if balance < cost: + return False, balance + new_balance = balance - cost + c.execute( + """ + UPDATE user_wallets + SET token_balance=?, total_consumed_tokens=total_consumed_tokens + ?, updated_at=? + WHERE user_id=? + """, + (new_balance, cost, now, user_id), + ) + return True, new_balance + def save_wechat_binding( self, user_id: int, @@ -629,7 +752,7 @@ class UserStore: with self._conn() as c: rows = c.execute( """ - SELECT id, model_name, base_url, model, timeout_sec, max_output_tokens, max_retries, updated_at + SELECT id, model_name, base_url, model, image_model, timeout_sec, max_output_tokens, max_retries, updated_at FROM ai_models WHERE user_id=? ORDER BY updated_at DESC, id DESC @@ -649,6 +772,7 @@ class UserStore: "model_name": r["model_name"] or "", "base_url": r["base_url"] or "", "model": r["model"] or "", + "image_model": r["image_model"] or "", "timeout_sec": float(r["timeout_sec"] or 120.0), "max_output_tokens": int(r["max_output_tokens"] or 8192), "max_retries": int(r["max_retries"] or 0), @@ -665,6 +789,7 @@ class UserStore: api_key: str, base_url: str, model: str, + image_model: str = "", timeout_sec: float = 120.0, max_output_tokens: int = 8192, max_retries: int = 0, @@ -676,20 +801,20 @@ class UserStore: cur = c.execute( """ INSERT INTO ai_models - (user_id, model_name, api_key, base_url, model, timeout_sec, max_output_tokens, max_retries, updated_at) - VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?) + (user_id, model_name, api_key, base_url, model, image_model, timeout_sec, max_output_tokens, max_retries, updated_at) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?) """, - (user_id, name, api_key, base_url, model, timeout_sec, max_output_tokens, max_retries, now), + (user_id, name, api_key, base_url, model, image_model, timeout_sec, max_output_tokens, max_retries, now), ) except sqlite3.IntegrityError: name = f"{name}-{now % 1000}" cur = c.execute( """ INSERT INTO ai_models - (user_id, model_name, api_key, base_url, model, timeout_sec, max_output_tokens, max_retries, updated_at) - VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?) + (user_id, model_name, api_key, base_url, model, image_model, timeout_sec, max_output_tokens, max_retries, updated_at) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?) """, - (user_id, name, api_key, base_url, model, timeout_sec, max_output_tokens, max_retries, now), + (user_id, name, api_key, base_url, model, image_model, timeout_sec, max_output_tokens, max_retries, now), ) aid = int(cur.lastrowid) c.execute( @@ -773,7 +898,7 @@ class UserStore: if aid: row = c.execute( """ - SELECT id, model_name, api_key, base_url, model, timeout_sec, max_output_tokens, max_retries, updated_at + SELECT id, model_name, api_key, base_url, model, image_model, timeout_sec, max_output_tokens, max_retries, updated_at FROM ai_models WHERE id=? AND user_id=? """, @@ -782,7 +907,7 @@ class UserStore: if not row: row = c.execute( """ - SELECT id, model_name, api_key, base_url, model, timeout_sec, max_output_tokens, max_retries, updated_at + SELECT id, model_name, api_key, base_url, model, image_model, timeout_sec, max_output_tokens, max_retries, updated_at FROM ai_models WHERE user_id=? ORDER BY updated_at DESC, id DESC @@ -809,6 +934,7 @@ class UserStore: "api_key": row["api_key"] or "", "base_url": row["base_url"] or "", "model": row["model"] or "", + "image_model": row["image_model"] or "", "timeout_sec": float(row["timeout_sec"] or 120.0), "max_output_tokens": int(row["max_output_tokens"] or 8192), "max_retries": int(row["max_retries"] or 0), diff --git a/app/static/settings.js b/app/static/settings.js index f1c23a7..4cf104a 100644 --- a/app/static/settings.js +++ b/app/static/settings.js @@ -72,7 +72,8 @@ function renderModels(me) { list.forEach((m) => { const opt = document.createElement("option"); opt.value = String(m.id); - opt.textContent = `${m.model_name} (${m.model})`; + const imageModel = (m.image_model || "").trim(); + opt.textContent = imageModel ? `${m.model_name} (${m.model} / 图:${imageModel})` : `${m.model_name} (${m.model})`; if ((active && m.id === active) || m.active) opt.selected = true; sel.appendChild(opt); }); @@ -194,6 +195,7 @@ if (saveModelBtn) { api_key: ($("apiKey") && $("apiKey").value.trim()) || "", base_url: ($("baseUrl") && $("baseUrl").value.trim()) || "", model: ($("modelValue") && $("modelValue").value.trim()) || "", + image_model: ($("imageModelValue") && $("imageModelValue").value.trim()) || "", timeout_sec: Number((($("timeoutSec") && $("timeoutSec").value) || "120").trim()), max_output_tokens: Number((($("maxOutputTokens") && $("maxOutputTokens").value) || "8192").trim()), max_retries: Number((($("maxRetries") && $("maxRetries").value) || "0").trim()), @@ -205,6 +207,7 @@ if (saveModelBtn) { setStatus("模型配置已保存并设为当前。"); if ($("apiKey")) $("apiKey").value = ""; if ($("modelName")) $("modelName").value = ""; + if ($("imageModelValue")) $("imageModelValue").value = ""; await refresh(); } catch (e) { setStatus(e.message || "模型保存失败", true); diff --git a/app/templates/settings.html b/app/templates/settings.html index 678b13f..d0bf38a 100644 --- a/app/templates/settings.html +++ b/app/templates/settings.html @@ -82,10 +82,14 @@