1508 lines
61 KiB
Python
1508 lines
61 KiB
Python
from __future__ import annotations
|
|
|
|
import hashlib
|
|
import hmac
|
|
import json
|
|
import secrets
|
|
import sqlite3
|
|
import time
|
|
from pathlib import Path
|
|
|
|
|
|
class UserStore:
|
|
def __init__(self, db_path: str) -> None:
    """Bind the store to *db_path* and make sure the schema exists.

    The parent directory of *db_path* is created (including missing ancestors)
    when it does not exist yet; afterwards ``_init_db`` builds/migrates tables.

    Args:
        db_path: Filesystem path of the SQLite database file.
    """
    self._db_path = db_path
    folder = Path(db_path).parent
    # Path objects are always truthy, so the existence test is the real guard.
    missing = bool(folder) and not folder.exists()
    if missing:
        folder.mkdir(parents=True, exist_ok=True)
    self._init_db()
|
|
|
|
def _conn(self) -> sqlite3.Connection:
    """Open a fresh connection to the store's database.

    Rows come back as ``sqlite3.Row`` so columns can be read by name. Callers
    use ``with self._conn() as c:``; that context manager commits or rolls back
    but does not close the connection.

    Returns:
        A new ``sqlite3.Connection`` with ``row_factory`` set to ``sqlite3.Row``.
    """
    connection = sqlite3.connect(self._db_path)
    connection.row_factory = sqlite3.Row
    return connection
|
|
|
|
def _init_db(self) -> None:
    """Create every table the store uses and apply additive schema migrations.

    Runs on every construction, so each step can be repeated: tables use
    ``CREATE TABLE IF NOT EXISTS``, columns are added only when
    ``PRAGMA table_info`` does not list them, and legacy single-binding rows in
    ``wechat_bindings`` are copied into ``wechat_accounts`` with
    ``INSERT OR IGNORE`` under the fixed name ``'默认公众号'`` ("default
    official account").
    """
    create_statements = (
        """
        CREATE TABLE IF NOT EXISTS wechat_bindings (
            user_id INTEGER PRIMARY KEY,
            appid TEXT NOT NULL,
            secret TEXT NOT NULL,
            author TEXT NOT NULL DEFAULT '',
            thumb_media_id TEXT NOT NULL DEFAULT '',
            thumb_image_path TEXT NOT NULL DEFAULT '',
            updated_at INTEGER NOT NULL,
            FOREIGN KEY(user_id) REFERENCES users(id)
        )
        """,
        """
        CREATE TABLE IF NOT EXISTS wechat_accounts (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            user_id INTEGER NOT NULL,
            account_name TEXT NOT NULL,
            appid TEXT NOT NULL,
            secret TEXT NOT NULL,
            author TEXT NOT NULL DEFAULT '',
            thumb_media_id TEXT NOT NULL DEFAULT '',
            thumb_image_path TEXT NOT NULL DEFAULT '',
            updated_at INTEGER NOT NULL,
            UNIQUE(user_id, account_name),
            FOREIGN KEY(user_id) REFERENCES users(id)
        )
        """,
        """
        CREATE TABLE IF NOT EXISTS user_prefs (
            user_id INTEGER PRIMARY KEY,
            active_wechat_account_id INTEGER,
            updated_at INTEGER NOT NULL,
            FOREIGN KEY(user_id) REFERENCES users(id)
        )
        """,
        """
        CREATE TABLE IF NOT EXISTS ai_models (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            user_id INTEGER NOT NULL,
            model_name TEXT NOT NULL,
            api_key TEXT NOT NULL,
            base_url TEXT NOT NULL DEFAULT '',
            model TEXT NOT NULL,
            image_model TEXT NOT NULL DEFAULT '',
            timeout_sec REAL NOT NULL DEFAULT 120.0,
            max_output_tokens INTEGER NOT NULL DEFAULT 8192,
            max_retries INTEGER NOT NULL DEFAULT 0,
            updated_at INTEGER NOT NULL,
            UNIQUE(user_id, model_name),
            FOREIGN KEY(user_id) REFERENCES users(id)
        )
        """,
        """
        CREATE TABLE IF NOT EXISTS user_wallets (
            user_id INTEGER PRIMARY KEY,
            vip_enabled INTEGER NOT NULL DEFAULT 0,
            token_balance INTEGER NOT NULL DEFAULT 0,
            total_consumed_tokens INTEGER NOT NULL DEFAULT 0,
            seat_quota_credits INTEGER NOT NULL DEFAULT 1500,
            seat_used_credits INTEGER NOT NULL DEFAULT 0,
            seat_cycle TEXT NOT NULL DEFAULT '',
            cycle_started_at INTEGER NOT NULL DEFAULT 0,
            cycle_expires_at INTEGER NOT NULL DEFAULT 0,
            updated_at INTEGER NOT NULL,
            FOREIGN KEY(user_id) REFERENCES users(id)
        )
        """,
        """
        CREATE TABLE IF NOT EXISTS recharge_orders (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            user_id INTEGER NOT NULL,
            order_no TEXT NOT NULL UNIQUE,
            channel TEXT NOT NULL DEFAULT '',
            token_amount INTEGER NOT NULL DEFAULT 0,
            amount_cny REAL NOT NULL DEFAULT 0.0,
            status TEXT NOT NULL DEFAULT 'pending',
            external_txn_id TEXT NOT NULL DEFAULT '',
            meta_json TEXT NOT NULL DEFAULT '{}',
            created_at INTEGER NOT NULL,
            paid_at INTEGER,
            FOREIGN KEY(user_id) REFERENCES users(id)
        )
        """,
        """
        CREATE TABLE IF NOT EXISTS token_ledger (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            user_id INTEGER NOT NULL,
            direction TEXT NOT NULL,
            token_change INTEGER NOT NULL,
            balance_after INTEGER NOT NULL,
            kind TEXT NOT NULL DEFAULT '',
            ref_type TEXT NOT NULL DEFAULT '',
            ref_id TEXT NOT NULL DEFAULT '',
            detail_json TEXT NOT NULL DEFAULT '{}',
            created_at INTEGER NOT NULL,
            FOREIGN KEY(user_id) REFERENCES users(id)
        )
        """,
    )
    # Columns introduced after a table first shipped, grouped by table; each
    # ALTER runs only when the column is not present yet.
    added_columns = {
        "user_prefs": (
            ("active_ai_model_id", "ALTER TABLE user_prefs ADD COLUMN active_ai_model_id INTEGER"),
            ("subscriber_name", "ALTER TABLE user_prefs ADD COLUMN subscriber_name TEXT NOT NULL DEFAULT ''"),
            ("subscriber_phone", "ALTER TABLE user_prefs ADD COLUMN subscriber_phone TEXT NOT NULL DEFAULT ''"),
            ("shipping_address", "ALTER TABLE user_prefs ADD COLUMN shipping_address TEXT NOT NULL DEFAULT ''"),
        ),
        "user_wallets": (
            # NOTE(review): pre-existing wallets get a 25000 seat quota here, while the
            # CREATE TABLE above defaults new rows to 1500 -- confirm this is intended.
            ("seat_quota_credits", "ALTER TABLE user_wallets ADD COLUMN seat_quota_credits INTEGER NOT NULL DEFAULT 25000"),
            ("seat_used_credits", "ALTER TABLE user_wallets ADD COLUMN seat_used_credits INTEGER NOT NULL DEFAULT 0"),
            ("seat_cycle", "ALTER TABLE user_wallets ADD COLUMN seat_cycle TEXT NOT NULL DEFAULT ''"),
            ("cycle_started_at", "ALTER TABLE user_wallets ADD COLUMN cycle_started_at INTEGER NOT NULL DEFAULT 0"),
            ("cycle_expires_at", "ALTER TABLE user_wallets ADD COLUMN cycle_expires_at INTEGER NOT NULL DEFAULT 0"),
        ),
        "ai_models": (
            ("image_model", "ALTER TABLE ai_models ADD COLUMN image_model TEXT NOT NULL DEFAULT ''"),
        ),
    }

    with self._conn() as c:
        self._ensure_users_table(c)
        self._ensure_sessions_table(c)
        for ddl in create_statements:
            c.execute(ddl)
        for table, additions in added_columns.items():
            present = self._table_columns(c, table)
            for column, alter_sql in additions:
                if column not in present:
                    c.execute(alter_sql)

        # Legacy layout kept one WeChat binding per user: copy each one into the
        # multi-account table as the default account and make it the active account
        # when the user has none selected yet.
        # NOTE(review): wechat_bindings rows are never removed here, so a default
        # account later deleted from wechat_accounts is recreated on the next start-up.
        legacy_rows = c.execute(
            "SELECT user_id, appid, secret, author, thumb_media_id, thumb_image_path, updated_at FROM wechat_bindings"
        ).fetchall()
        for binding in legacy_rows:
            now = int(time.time())
            owner = int(binding["user_id"])
            inserted = c.execute(
                """
                INSERT OR IGNORE INTO wechat_accounts
                (user_id, account_name, appid, secret, author, thumb_media_id, thumb_image_path, updated_at)
                VALUES (?, '默认公众号', ?, ?, ?, ?, ?, ?)
                """,
                (
                    owner,
                    binding["appid"] or "",
                    binding["secret"] or "",
                    binding["author"] or "",
                    binding["thumb_media_id"] or "",
                    binding["thumb_image_path"] or "",
                    int(binding["updated_at"] or now),
                ),
            ).rowcount
            if not inserted:
                continue
            account_id = int(
                c.execute(
                    "SELECT id FROM wechat_accounts WHERE user_id=? AND account_name='默认公众号'",
                    (owner,),
                ).fetchone()["id"]
            )
            c.execute(
                """
                INSERT INTO user_prefs(user_id, active_wechat_account_id, updated_at)
                VALUES (?, ?, ?)
                ON CONFLICT(user_id) DO UPDATE SET
                active_wechat_account_id=COALESCE(user_prefs.active_wechat_account_id, excluded.active_wechat_account_id),
                updated_at=excluded.updated_at
                """,
                (owner, account_id, now),
            )
|
|
|
|
def _table_columns(self, c: sqlite3.Connection, table_name: str) -> set[str]:
    """Return the names of the columns currently defined on *table_name*.

    *table_name* is interpolated into the PRAGMA text (PRAGMA arguments cannot
    be bound as parameters), so it must be a trusted internal identifier. Rows
    are read by column name, so *c* needs a name-indexable row factory such as
    ``sqlite3.Row``.
    """
    pragma = f"PRAGMA table_info({table_name})"
    names: set[str] = set()
    for info in c.execute(pragma).fetchall():
        names.add(str(info["name"]))
    return names
|
|
|
|
def _ensure_users_table(self, c: sqlite3.Connection) -> None:
    """Create ``users`` or rebuild an older ``users`` table into the current layout.

    If the table already has every column in ``required``, nothing else happens.
    Otherwise rows are copied into a fresh ``users_new`` table (ids preserved),
    the old table is dropped, and ``users_new`` is renamed to ``users``:

    * tables that already store ``password_hash``/``password_salt`` keep those
      values, and a reset-code hash is generated for rows that lack one;
    * tables with a plaintext ``password`` column get a freshly salted password
      hash and a new reset-code hash.

    Rows with a blank username (and, for plaintext tables, a blank password) are
    skipped. Generated reset codes are neither stored nor returned in plaintext.

    Args:
        c: Open connection; every statement runs on it.
    """
    required = {
        "id",
        "username",
        "password_hash",
        "password_salt",
        "reset_code_hash",
        "reset_code_salt",
        "created_at",
        "deleted_at",
    }
    c.execute(
        """
        CREATE TABLE IF NOT EXISTS users (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            username TEXT NOT NULL UNIQUE,
            password_hash TEXT NOT NULL,
            password_salt TEXT NOT NULL,
            reset_code_hash TEXT NOT NULL DEFAULT '',
            reset_code_salt TEXT NOT NULL DEFAULT '',
            created_at INTEGER NOT NULL,
            deleted_at INTEGER
        )
        """
    )
    cols = self._table_columns(c, "users")
    if required.issubset(cols):
        return

    now = int(time.time())
    # SQLite ignores this pragma inside an open transaction; it only takes effect
    # when nothing has been written on this connection yet.
    c.execute("PRAGMA foreign_keys=OFF")
    c.execute("DROP TABLE IF EXISTS users_new")
    c.execute(
        """
        CREATE TABLE users_new (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            username TEXT NOT NULL UNIQUE,
            password_hash TEXT NOT NULL,
            password_salt TEXT NOT NULL,
            reset_code_hash TEXT NOT NULL,
            reset_code_salt TEXT NOT NULL,
            created_at INTEGER NOT NULL,
            deleted_at INTEGER
        )
        """
    )
    insert_sql = """
        INSERT OR IGNORE INTO users_new(
            id, username, password_hash, password_salt, reset_code_hash, reset_code_salt, created_at, deleted_at
        )
        VALUES (?, ?, ?, ?, ?, ?, ?, ?)
    """

    if {"username", "password_hash", "password_salt"}.issubset(cols):
        rows = c.execute("SELECT * FROM users").fetchall()
        for r in rows:
            username = (r["username"] or "").strip()
            if not username:
                continue
            # Coerce NULLs to "": reset_code_* are NOT NULL in users_new, and with
            # INSERT OR IGNORE a NULL would silently drop the whole account.
            reset_salt = (r["reset_code_salt"] if "reset_code_salt" in cols else secrets.token_hex(8)) or ""
            reset_hash = (r["reset_code_hash"] if "reset_code_hash" in cols else "") or ""
            if not reset_hash:
                # No usable reset hash: mint one over a guaranteed non-empty salt.
                reset_salt = reset_salt or secrets.token_hex(8)
                reset_hash = self._hash_reset_code(self._generate_reset_code(), reset_salt)
            created_at = int(r["created_at"]) if "created_at" in cols and r["created_at"] else now
            deleted_at = int(r["deleted_at"]) if "deleted_at" in cols and r["deleted_at"] else None
            c.execute(
                insert_sql,
                (
                    int(r["id"]),
                    username,
                    r["password_hash"],
                    r["password_salt"],
                    reset_hash,
                    reset_salt,
                    created_at,
                    deleted_at,
                ),
            )
    elif {"username", "password"}.issubset(cols):
        if "created_at" in cols:
            rows = c.execute("SELECT id, username, password, created_at FROM users").fetchall()
        else:
            rows = c.execute("SELECT id, username, password FROM users").fetchall()
        for r in rows:
            username = (r["username"] or "").strip()
            raw_pwd = str(r["password"] or "")
            if not username or not raw_pwd:
                continue
            salt = secrets.token_hex(16)
            pwd_hash = self._hash_password(raw_pwd, salt)
            reset_salt = secrets.token_hex(8)
            reset_hash = self._hash_reset_code(self._generate_reset_code(), reset_salt)
            created_at = int(r["created_at"]) if "created_at" in r.keys() and r["created_at"] else now
            c.execute(
                insert_sql,
                (int(r["id"]), username, pwd_hash, salt, reset_hash, reset_salt, created_at, None),
            )

    c.execute("DROP TABLE users")
    c.execute("ALTER TABLE users_new RENAME TO users")
    c.execute("PRAGMA foreign_keys=ON")
|
|
|
|
def _ensure_sessions_table(self, c: sqlite3.Connection) -> None:
    """Guarantee a ``sessions`` table with the current four-column layout.

    An older ``sessions`` table that lacks any expected column is dropped and
    recreated empty. Its rows are discarded rather than migrated, which signs
    every user out.

    Args:
        c: Open connection to run the DDL on.
    """
    c.execute(
        """
        CREATE TABLE IF NOT EXISTS sessions (
            token_hash TEXT PRIMARY KEY,
            user_id INTEGER NOT NULL,
            expires_at INTEGER NOT NULL,
            created_at INTEGER NOT NULL,
            FOREIGN KEY(user_id) REFERENCES users(id)
        )
        """
    )
    expected = {"token_hash", "user_id", "expires_at", "created_at"}
    missing = expected - self._table_columns(c, "sessions")
    if not missing:
        return
    c.execute("DROP TABLE IF EXISTS sessions")
    c.execute(
        """
        CREATE TABLE sessions (
            token_hash TEXT PRIMARY KEY,
            user_id INTEGER NOT NULL,
            expires_at INTEGER NOT NULL,
            created_at INTEGER NOT NULL,
            FOREIGN KEY(user_id) REFERENCES users(id)
        )
        """
    )
|
|
|
|
def _hash_password(self, password: str, salt: str) -> str:
    """Return the hex PBKDF2-HMAC-SHA256 digest of *password* salted with *salt*.

    Both inputs are UTF-8 encoded; the key derivation uses 120,000 iterations.
    """
    iterations = 120_000
    derived = hashlib.pbkdf2_hmac(
        "sha256",
        password.encode("utf-8"),
        salt.encode("utf-8"),
        iterations,
    )
    return derived.hex()
|
|
|
|
def _hash_token(self, token: str) -> str:
    """Return the hex SHA-256 of the UTF-8 encoded session *token*."""
    digest = hashlib.sha256()
    digest.update(token.encode("utf-8"))
    return digest.hexdigest()
|
|
|
|
def _generate_reset_code(self, length: int = 12) -> str:
    """Return a random alphanumeric reset code of exactly *length* characters.

    Each character is drawn independently with ``secrets.choice`` from
    ``[A-Za-z0-9]``. This is the same alphabet as before, but the old approach
    of stripping ``-``/``_`` from ``token_urlsafe(9)`` often produced codes
    shorter than 12 characters, with correspondingly less entropy.

    Args:
        length: Number of characters in the code; must be positive.

    Raises:
        ValueError: If *length* is not positive.
    """
    import string

    if length <= 0:
        raise ValueError("length must be positive")
    alphabet = string.ascii_letters + string.digits
    return "".join(secrets.choice(alphabet) for _ in range(length))
|
|
|
|
def _hash_reset_code(self, reset_code: str, salt: str) -> str:
    """Return the hex SHA-256 of ``"<salt>:<reset_code>"`` encoded as UTF-8."""
    material = "{}:{}".format(salt, reset_code)
    return hashlib.sha256(material.encode("utf-8")).hexdigest()
|
|
|
|
def create_user(self, username: str, password: str) -> dict | None:
    """Register *username* with a salted password hash and an empty wallet.

    Returns:
        ``{"id", "username", "reset_code"}`` on success. ``reset_code`` is the only
        place the plaintext reset code appears; the database keeps only its salted
        hash. ``None`` when the insert raises ``sqlite3.IntegrityError`` (for
        example, the username is already taken).

    Raises:
        RuntimeError: On any other SQLite error.
    """
    created = int(time.time())
    password_salt = secrets.token_hex(16)
    password_hash = self._hash_password(password, password_salt)
    plain_reset = self._generate_reset_code()
    reset_salt = secrets.token_hex(8)
    reset_digest = self._hash_reset_code(plain_reset, reset_salt)
    insert_user = """
        INSERT INTO users(
            username, password_hash, password_salt, reset_code_hash, reset_code_salt, created_at, deleted_at
        ) VALUES (?, ?, ?, ?, ?, ?, NULL)
    """
    insert_wallet = """
        INSERT OR IGNORE INTO user_wallets(user_id, vip_enabled, token_balance, total_consumed_tokens, updated_at)
        VALUES (?, 0, 0, 0, ?)
    """
    try:
        with self._conn() as c:
            cursor = c.execute(
                insert_user,
                (username, password_hash, password_salt, reset_digest, reset_salt, created),
            )
            new_id = int(cursor.lastrowid)
            c.execute(insert_wallet, (new_id, created))
    except sqlite3.IntegrityError:
        return None
    except sqlite3.Error as exc:
        raise RuntimeError(f"create_user_db_error: {exc}") from exc
    return {"id": new_id, "username": username, "reset_code": plain_reset}
|
|
|
|
def verify_user(self, username: str, password: str) -> dict | None:
    """Check *password* for the active (not deleted) account named *username*.

    A PBKDF2 hash is computed even when no such account exists. Without that,
    a missing username would return immediately, and the response time would
    reveal which usernames are registered.

    Returns:
        ``{"id", "username"}`` when the credentials match, otherwise ``None``.

    Raises:
        RuntimeError: If the lookup fails with a SQLite error.
    """
    try:
        with self._conn() as c:
            row = c.execute(
                """
                SELECT id, username, password_hash, password_salt
                FROM users
                WHERE username=? AND deleted_at IS NULL
                """,
                (username,),
            ).fetchone()
    except sqlite3.Error as exc:
        raise RuntimeError(f"verify_user_db_error: {exc}") from exc
    if not row:
        # Spend the same key-derivation cost as a real check, then fail.
        self._hash_password(password, "0" * 32)
        return None
    calc = self._hash_password(password, row["password_salt"])
    if not hmac.compare_digest(calc, row["password_hash"]):
        return None
    return {"id": int(row["id"]), "username": row["username"]}
|
|
|
|
def change_password(self, user_id: int, old_password: str, new_password: str) -> bool:
    """Set a new password for *user_id* once *old_password* is confirmed.

    The account is looked up by id only (``deleted_at`` is not checked), a new
    16-byte hex salt is generated, and existing sessions are left untouched.

    Returns:
        ``True`` if the password was replaced; ``False`` if the id is unknown or
        *old_password* does not match.
    """
    with self._conn() as c:
        stored = c.execute(
            "SELECT password_hash, password_salt FROM users WHERE id=?",
            (user_id,),
        ).fetchone()
        if not stored:
            return False
        attempt = self._hash_password(old_password, stored["password_salt"])
        if not hmac.compare_digest(attempt, stored["password_hash"]):
            return False
        fresh_salt = secrets.token_hex(16)
        fresh_hash = self._hash_password(new_password, fresh_salt)
        c.execute(
            "UPDATE users SET password_hash=?, password_salt=? WHERE id=?",
            (fresh_hash, fresh_salt, user_id),
        )
    return True
|
|
|
|
def reset_password_by_username(self, username: str, reset_code: str, new_password: str) -> bool:
    """Set a new password for an active account using its reset code.

    *username* and *reset_code* are whitespace-stripped first. Only the password
    columns are updated, so the stored reset code stays valid afterwards.

    Returns:
        ``True`` when the password was replaced; ``False`` for a blank username,
        an unknown or deleted account, a blank code, or a code mismatch.
    """
    name = (username or "").strip()
    code = (reset_code or "").strip()
    if not name:
        return False
    with self._conn() as c:
        account = c.execute(
            """
            SELECT id, reset_code_hash, reset_code_salt
            FROM users
            WHERE username=? AND deleted_at IS NULL
            """,
            (name,),
        ).fetchone()
        if not account or not code:
            return False
        supplied = self._hash_reset_code(code, account["reset_code_salt"] or "")
        if not hmac.compare_digest(supplied, account["reset_code_hash"] or ""):
            return False
        salt = secrets.token_hex(16)
        c.execute(
            "UPDATE users SET password_hash=?, password_salt=? WHERE id=?",
            (self._hash_password(new_password, salt), salt, int(account["id"])),
        )
    return True
|
|
|
|
def create_session(self, user_id: int, ttl_seconds: int = 7 * 24 * 3600) -> str:
    """Issue a new login token for *user_id*; only its hash is written to ``sessions``.

    Args:
        user_id: Owner of the session.
        ttl_seconds: Requested lifetime in seconds, raised to at least 600.

    Returns:
        The plaintext token to hand to the client.
    """
    plain = secrets.token_urlsafe(32)
    hashed = self._hash_token(plain)
    issued_at = int(time.time())
    lifetime = max(600, int(ttl_seconds))
    with self._conn() as c:
        c.execute(
            "INSERT OR REPLACE INTO sessions(token_hash, user_id, expires_at, created_at) VALUES (?, ?, ?, ?)",
            (hashed, user_id, issued_at + lifetime, issued_at),
        )
    return plain
|
|
|
|
def delete_session(self, token: str) -> None:
    """Remove the session identified by plaintext *token*; an empty token is ignored."""
    if token:
        with self._conn() as c:
            c.execute(
                "DELETE FROM sessions WHERE token_hash=?",
                (self._hash_token(token),),
            )
|
|
|
|
def delete_sessions_by_user(self, user_id: int) -> None:
    """Revoke every session row that belongs to *user_id*."""
    statement = "DELETE FROM sessions WHERE user_id=?"
    with self._conn() as c:
        c.execute(statement, (user_id,))
|
|
|
|
def get_user_by_session(self, token: str) -> dict | None:
    """Resolve a plaintext session token to its active user.

    Each call first deletes every session (of any user) whose ``expires_at`` is
    in the past, so this lookup also performs a write.

    Returns:
        ``{"id", "username"}``, or ``None`` for an empty, unknown or expired token
        or a logically deleted user.
    """
    if not token:
        return None
    now = int(time.time())
    token_hash = self._hash_token(token)
    with self._conn() as c:
        c.execute("DELETE FROM sessions WHERE expires_at < ?", (now,))
        found = c.execute(
            """
            SELECT u.id, u.username
            FROM sessions s
            JOIN users u ON u.id=s.user_id
            WHERE s.token_hash=? AND s.expires_at>=? AND u.deleted_at IS NULL
            """,
            (token_hash, now),
        ).fetchone()
    return {"id": int(found["id"]), "username": found["username"]} if found else None
|
|
|
|
def get_user_profile(self, user_id: int) -> dict:
    """Return the subscriber contact fields stored in ``user_prefs`` for *user_id*.

    Values are whitespace-stripped. Every field is ``""`` when the user has no
    ``user_prefs`` row or when the column is NULL.
    """
    keys = ("subscriber_name", "subscriber_phone", "shipping_address")
    with self._conn() as c:
        prefs = c.execute(
            """
            SELECT subscriber_name, subscriber_phone, shipping_address
            FROM user_prefs
            WHERE user_id=?
            LIMIT 1
            """,
            (user_id,),
        ).fetchone()
    if not prefs:
        return dict.fromkeys(keys, "")
    return {key: (prefs[key] or "").strip() for key in keys}
|
|
|
|
def save_user_profile(
    self,
    user_id: int,
    *,
    subscriber_name: str,
    subscriber_phone: str,
    shipping_address: str,
) -> dict:
    """Upsert the contact/shipping fields of *user_id* and return the stored profile.

    Each value is whitespace-stripped (``None`` becomes ``""``). A new
    ``user_prefs`` row starts with NULL active WeChat account and AI model; on
    an existing row only these three fields and ``updated_at`` change.

    Returns:
        The result of ``get_user_profile(user_id)`` after the write.
    """
    cleaned = tuple(
        (value or "").strip()
        for value in (subscriber_name, subscriber_phone, shipping_address)
    )
    with self._conn() as c:
        c.execute(
            """
            INSERT INTO user_prefs(
                user_id, active_wechat_account_id, active_ai_model_id,
                subscriber_name, subscriber_phone, shipping_address, updated_at
            ) VALUES (?, NULL, NULL, ?, ?, ?, ?)
            ON CONFLICT(user_id) DO UPDATE SET
                subscriber_name=excluded.subscriber_name,
                subscriber_phone=excluded.subscriber_phone,
                shipping_address=excluded.shipping_address,
                updated_at=excluded.updated_at
            """,
            (user_id, *cleaned, int(time.time())),
        )
    return self.get_user_profile(user_id)
|
|
|
|
def delete_user_logically(self, user_id: int, password: str, reset_code: str) -> bool:
    """Soft-delete an account after confirming both its password and reset code.

    On success, every row owned by the user is deleted from ``sessions``,
    ``wechat_accounts``, ``ai_models``, ``user_prefs``, ``wechat_bindings``,
    ``user_wallets``, ``recharge_orders`` and ``token_ledger``. The ``users`` row
    is kept but stamped with ``deleted_at``, and its username gets the suffix
    ``#deleted<timestamp>``, which frees the original name under the UNIQUE
    username constraint.

    Returns:
        ``True`` if the account was deleted; ``False`` if it is missing or already
        deleted, or either credential does not match.
    """
    stamp = int(time.time())
    owned_tables = (
        "sessions",
        "wechat_accounts",
        "ai_models",
        "user_prefs",
        "wechat_bindings",
        "user_wallets",
        "recharge_orders",
        "token_ledger",
    )
    with self._conn() as c:
        account = c.execute(
            """
            SELECT id, password_hash, password_salt, reset_code_hash, reset_code_salt
            FROM users
            WHERE id=? AND deleted_at IS NULL
            """,
            (user_id,),
        ).fetchone()
        if not account:
            return False
        password_ok = hmac.compare_digest(
            self._hash_password(password or "", account["password_salt"] or ""),
            account["password_hash"] or "",
        )
        if not password_ok:
            return False
        code_ok = hmac.compare_digest(
            self._hash_reset_code(reset_code or "", account["reset_code_salt"] or ""),
            account["reset_code_hash"] or "",
        )
        if not code_ok:
            return False
        # Table names come from the fixed tuple above, never from callers.
        for table in owned_tables:
            c.execute(f"DELETE FROM {table} WHERE user_id=?", (user_id,))
        c.execute(
            "UPDATE users SET deleted_at=?, username=username || '#deleted' || ? WHERE id=?",
            (stamp, str(stamp), user_id),
        )
    return True
|
|
|
|
def _ensure_wallet_row(self, c: sqlite3.Connection, user_id: int) -> None:
    """Insert a default wallet for *user_id* unless one already exists.

    A new wallet starts with VIP off, zero balances, a 1500 seat quota, and
    ``seat_cycle`` set to the current local ``YYYY-MM``. An existing row is
    left untouched (``INSERT OR IGNORE``).
    """
    stamp = int(time.time())
    month = time.strftime("%Y-%m", time.localtime(stamp))
    c.execute(
        """
        INSERT OR IGNORE INTO user_wallets(
            user_id, vip_enabled, token_balance, total_consumed_tokens,
            seat_quota_credits, seat_used_credits, seat_cycle, cycle_started_at, cycle_expires_at, updated_at
        )
        VALUES (?, 0, 0, 0, 1500, 0, ?, 0, 0, ?)
        """,
        (user_id, month, stamp),
    )
|
|
|
|
def _refresh_billing_cycle(self, c: sqlite3.Connection, user_id: int) -> None:
    """Expire or roll over the billing cycle stored on *user_id*'s wallet.

    * ``cycle_expires_at`` set and in the past: zero seat usage and the shared
      balance, and clear the cycle fields.
    * ``cycle_expires_at`` set and in the future: nothing changes.
    * No cycle recorded (legacy data, ``cycle_expires_at <= 0``):
      - if the user has a paid recharge order, derive a 30-day cycle from the
        latest ``paid_at`` (expiring it immediately if 30 days have passed);
      - otherwise keep the old calendar-month rule and reset seat usage whenever
        ``seat_cycle`` differs from the current local ``YYYY-MM``.
    """
    now = int(time.time())
    wallet = c.execute(
        "SELECT seat_cycle, cycle_expires_at, token_balance, seat_used_credits FROM user_wallets WHERE user_id=?",
        (user_id,),
    ).fetchone()
    current_cycle = (wallet["seat_cycle"] or "") if wallet else ""
    expires_at = int(wallet["cycle_expires_at"] or 0) if wallet else 0

    def wipe() -> None:
        # Cycle over: drop unused seat usage and shared credits, clear cycle fields.
        c.execute(
            """
            UPDATE user_wallets
            SET seat_used_credits=0, token_balance=0, seat_cycle='', cycle_started_at=0, cycle_expires_at=0, updated_at=?
            WHERE user_id=?
            """,
            (now, user_id),
        )

    if expires_at > 0:
        if now >= expires_at:
            wipe()
        return

    latest_paid = c.execute(
        """
        SELECT paid_at
        FROM recharge_orders
        WHERE user_id=? AND status='paid' AND paid_at IS NOT NULL
        ORDER BY paid_at DESC, id DESC
        LIMIT 1
        """,
        (user_id,),
    ).fetchone()
    if latest_paid and int(latest_paid["paid_at"] or 0) > 0:
        start_at = int(latest_paid["paid_at"])
        end_at = start_at + 30 * 24 * 3600
        if now >= end_at:
            wipe()
        else:
            c.execute(
                """
                UPDATE user_wallets
                SET seat_cycle=?, cycle_started_at=?, cycle_expires_at=?, updated_at=?
                WHERE user_id=?
                """,
                (time.strftime("%Y-%m", time.localtime(start_at)), start_at, end_at, now, user_id),
            )
        return

    month = time.strftime("%Y-%m", time.localtime(now))
    if current_cycle != month:
        c.execute(
            """
            UPDATE user_wallets
            SET seat_used_credits=0, seat_cycle=?, updated_at=?
            WHERE user_id=?
            """,
            (month, now, user_id),
        )
|
|
|
|
def ensure_trial_tokens(self, user_id: int, trial_tokens: int) -> dict:
    """Grant *trial_tokens* to *user_id* once, if their shared balance is empty.

    A grant sets ``token_balance`` to the trial amount, turns ``vip_enabled``
    on, and writes a ``token_ledger`` entry of kind ``'trial_grant'``. That
    ledger entry marks the trial as used: a user who already has one never gets
    another grant, even after the balance drops back to zero. (Previously the
    trial was re-granted whenever the balance reached zero.)

    Args:
        user_id: Wallet owner.
        trial_tokens: Trial amount; negative values are treated as 0 (no grant).

    Returns:
        The wallet summary from ``get_vip_status``.
    """
    amount = max(0, int(trial_tokens))
    now = int(time.time())
    with self._conn() as c:
        self._ensure_wallet_row(c, user_id)
        row = c.execute(
            "SELECT token_balance FROM user_wallets WHERE user_id=?",
            (user_id,),
        ).fetchone()
        current = int(row["token_balance"] or 0) if row else 0
        already_granted = (
            c.execute(
                "SELECT 1 FROM token_ledger WHERE user_id=? AND kind='trial_grant' LIMIT 1",
                (user_id,),
            ).fetchone()
            is not None
        )
        if current <= 0 and amount > 0 and not already_granted:
            new_balance = amount
            c.execute(
                """
                UPDATE user_wallets
                SET vip_enabled=1, token_balance=?, updated_at=?
                WHERE user_id=?
                """,
                (new_balance, now, user_id),
            )
            c.execute(
                """
                INSERT INTO token_ledger(
                    user_id, direction, token_change, balance_after, kind, ref_type, ref_id, detail_json, created_at
                ) VALUES (?, 'in', ?, ?, 'trial_grant', 'system', '', '{}', ?)
                """,
                (user_id, amount, new_balance, now),
            )
    # Called after the block above has committed: get_vip_status opens its own
    # connection and writes (wallet creation/refresh).
    return self.get_vip_status(user_id)
|
|
|
|
def get_vip_status(self, user_id: int) -> dict:
    """Summarise *user_id*'s wallet after creating it and refreshing its cycle.

    When a recorded cycle (``cycle_expires_at > 0``) is no longer in the future,
    the remaining seat credits and shared credits are reported as 0. A wallet
    with no recorded cycle counts as active.

    Returns:
        A dict with VIP flag, balances, seat quota/usage/remaining, the combined
        ``total_available_credits`` and the cycle fields.
    """
    with self._conn() as c:
        self._ensure_wallet_row(c, user_id)
        self._refresh_billing_cycle(c, user_id)
        wallet = c.execute(
            """
            SELECT
                vip_enabled, token_balance, total_consumed_tokens,
                seat_quota_credits, seat_used_credits, seat_cycle, cycle_started_at, cycle_expires_at, updated_at
            FROM user_wallets
            WHERE user_id=?
            """,
            (user_id,),
        ).fetchone()

    def num(column: str) -> int:
        # NULL columns and a missing wallet both read as 0.
        return int(wallet[column] or 0) if wallet else 0

    quota = num("seat_quota_credits")
    used = num("seat_used_credits")
    started = num("cycle_started_at")
    expires = num("cycle_expires_at")
    active = expires > int(time.time()) if expires > 0 else True
    seat_left = max(0, quota - used) if active else 0
    shared = num("token_balance") if active else 0
    return {
        "vip_enabled": bool(num("vip_enabled")),
        "token_balance": shared,
        "total_consumed_tokens": num("total_consumed_tokens"),
        "seat_quota_credits": quota,
        "seat_used_credits": used,
        "seat_remaining_credits": seat_left,
        "shared_credits": shared,
        "total_available_credits": seat_left + shared,
        "seat_cycle": (wallet["seat_cycle"] or "") if wallet else "",
        "cycle_started_at": started,
        "cycle_expires_at": expires,
        "cycle_active": active,
        "updated_at": num("updated_at"),
    }
|
|
|
|
def set_vip_enabled(self, user_id: int, enabled: bool) -> dict:
    """Set the ``vip_enabled`` flag on *user_id*'s wallet, creating the wallet if needed.

    Returns:
        The refreshed summary from ``get_vip_status``.
    """
    flag = int(bool(enabled))
    with self._conn() as c:
        self._ensure_wallet_row(c, user_id)
        c.execute(
            "UPDATE user_wallets SET vip_enabled=?, updated_at=? WHERE user_id=?",
            (flag, int(time.time()), user_id),
        )
    return self.get_vip_status(user_id)
|
|
|
|
def recharge_tokens(
    self,
    user_id: int,
    tokens: int,
    *,
    kind: str = "manual_recharge",
    ref_type: str = "",
    ref_id: str = "",
    detail: dict | None = None,
    cycle_start_at: int | None = None,
    cycle_days: int = 30,
) -> dict:
    """Add *tokens* to *user_id*'s shared balance and record the credit in the ledger.

    The wallet is always marked VIP. When *cycle_start_at* is a positive epoch
    timestamp, a new billing cycle is also started: seat usage resets to 0 and
    the cycle expires ``max(1, cycle_days)`` days after *cycle_start_at*.

    Args:
        tokens: Amount to credit; negative values are treated as 0.
        kind, ref_type, ref_id: Ledger classification of this credit.
        detail: Extra data stored as ASCII-escaped JSON in ``detail_json``.
        cycle_start_at: Optional cycle start (epoch seconds).
        cycle_days: Cycle length in days when a cycle is started.

    Returns:
        The refreshed summary from ``get_vip_status``.
    """
    credit = max(0, int(tokens))
    now = int(time.time())
    start_at = int(cycle_start_at or 0)
    with self._conn() as c:
        self._ensure_wallet_row(c, user_id)
        self._refresh_billing_cycle(c, user_id)
        before = c.execute("SELECT token_balance FROM user_wallets WHERE user_id=?", (user_id,)).fetchone()
        balance_after = (int(before["token_balance"] or 0) if before else 0) + credit
        if start_at <= 0:
            c.execute(
                """
                UPDATE user_wallets
                SET token_balance=token_balance + ?, vip_enabled=1, updated_at=?
                WHERE user_id=?
                """,
                (credit, now, user_id),
            )
        else:
            ends_at = start_at + max(1, int(cycle_days)) * 24 * 3600
            c.execute(
                """
                UPDATE user_wallets
                SET token_balance=token_balance + ?, vip_enabled=1, seat_used_credits=0, seat_cycle=?, cycle_started_at=?, cycle_expires_at=?, updated_at=?
                WHERE user_id=?
                """,
                (credit, time.strftime("%Y-%m", time.localtime(start_at)), start_at, ends_at, now, user_id),
            )
        c.execute(
            """
            INSERT INTO token_ledger(
                user_id, direction, token_change, balance_after, kind, ref_type, ref_id, detail_json, created_at
            ) VALUES (?, 'in', ?, ?, ?, ?, ?, ?, ?)
            """,
            (
                user_id,
                credit,
                balance_after,
                kind,
                ref_type,
                ref_id or "",
                json.dumps(detail or {}, ensure_ascii=True),
                now,
            ),
        )
    return self.get_vip_status(user_id)
|
|
|
|
def consume_tokens(
    self,
    user_id: int,
    tokens: int,
    *,
    kind: str = "usage",
    ref_type: str = "",
    ref_id: str = "",
    detail: dict | None = None,
) -> tuple[bool, int]:
    """Spend *tokens* credits, drawing on the seat allowance first, then the shared balance.

    Nothing is spent if the seat remainder plus the shared balance cannot cover
    the cost. A successful spend updates the wallet and writes an ``'out'``
    ledger entry whose detail adds ``credit_source`` (the seat/shared split).

    Args:
        tokens: Cost to charge; negative values are treated as 0.
        kind, ref_type, ref_id: Ledger classification of this charge.
        detail: Extra data merged into the ledger's ``detail_json``.

    Returns:
        ``(ok, remaining)``: ``ok`` is ``False`` when credits are insufficient;
        ``remaining`` is the seat remainder plus the shared balance afterwards.
        A zero cost returns ``(True, remaining)`` without writing a ledger entry.
    """
    spend = max(0, int(tokens))
    now = int(time.time())
    with self._conn() as c:
        self._ensure_wallet_row(c, user_id)
        self._refresh_billing_cycle(c, user_id)
        wallet = c.execute(
            "SELECT token_balance, seat_quota_credits, seat_used_credits FROM user_wallets WHERE user_id=?",
            (user_id,),
        ).fetchone()
        shared = int(wallet["token_balance"] or 0) if wallet else 0
        quota = int(wallet["seat_quota_credits"] or 0) if wallet else 0
        used = int(wallet["seat_used_credits"] or 0) if wallet else 0
        seat_left = max(0, quota - used)
        available = seat_left + shared
        if spend <= 0:
            return True, available
        seat_part = min(seat_left, spend)
        shared_part = spend - seat_part
        if shared < shared_part:
            return False, available
        shared_after = shared - shared_part
        used_after = used + seat_part
        left_after = max(0, quota - used_after) + shared_after
        c.execute(
            """
            UPDATE user_wallets
            SET token_balance=?, seat_used_credits=?, total_consumed_tokens=total_consumed_tokens + ?, updated_at=?
            WHERE user_id=?
            """,
            (shared_after, used_after, spend, now, user_id),
        )
        ledger_detail = dict(detail or {})
        ledger_detail["credit_source"] = {"seat": seat_part, "shared": shared_part}
        c.execute(
            """
            INSERT INTO token_ledger(
                user_id, direction, token_change, balance_after, kind, ref_type, ref_id, detail_json, created_at
            ) VALUES (?, 'out', ?, ?, ?, ?, ?, ?, ?)
            """,
            (
                user_id,
                spend,
                left_after,
                kind,
                ref_type,
                ref_id or "",
                json.dumps(ledger_detail, ensure_ascii=True),
                now,
            ),
        )
    return True, left_after
|
|
|
|
def create_recharge_order(
    self,
    user_id: int,
    order_no: str,
    channel: str,
    token_amount: int,
    amount_cny: float,
    meta: dict | None = None,
) -> dict:
    """Insert a new ``'pending'`` recharge order and return a summary of it.

    *meta* is stored as ASCII-escaped JSON. A falsy *channel* is stored as ``""``
    but echoed back unchanged in the returned dict.

    Raises:
        sqlite3.IntegrityError: If *order_no* already exists (the column is UNIQUE).
    """
    created_at = int(time.time())
    tokens = int(token_amount)
    price = float(amount_cny)
    with self._conn() as c:
        c.execute(
            """
            INSERT INTO recharge_orders(
                user_id, order_no, channel, token_amount, amount_cny, status, external_txn_id, meta_json, created_at, paid_at
            ) VALUES (?, ?, ?, ?, ?, 'pending', '', ?, ?, NULL)
            """,
            (user_id, order_no, channel or "", tokens, price, json.dumps(meta or {}, ensure_ascii=True), created_at),
        )
    return {
        "order_no": order_no,
        "channel": channel,
        "token_amount": tokens,
        "amount_cny": price,
        "status": "pending",
        "created_at": created_at,
    }
|
|
|
|
def mark_recharge_order_paid(
    self,
    user_id: int,
    order_no: str,
    paid_amount_cny: float,
    external_txn_id: str = "",
    meta: dict | None = None,
) -> tuple[bool, str]:
    """Mark *order_no* as paid and credit its tokens to *user_id* exactly once.

    The switch to ``'paid'`` is a conditional ``UPDATE ... AND status<>'paid'``.
    Only the caller whose UPDATE actually changes the row goes on to credit
    tokens. Previously, two concurrent confirmations (e.g. a payment callback
    racing a client poll) could both see a non-paid status and both credit the
    order. Orders already switched to ``'cancelled'`` can still be paid.

    Returns:
        ``(ok, message)``:
        ``(False, "订单不存在")`` the order does not exist;
        ``(False, "订单无权限")`` the order belongs to another user;
        ``(True, "already_paid")`` the order was already paid;
        ``(False, "支付金额不足")`` *paid_amount_cny* is below the order amount;
        ``(True, "ok")`` the order was marked paid and its tokens credited with a
        fresh 30-day cycle starting now.
    """
    now = int(time.time())
    with self._conn() as c:
        row = c.execute(
            """
            SELECT user_id, token_amount, amount_cny, status
            FROM recharge_orders
            WHERE order_no=?
            """,
            (order_no,),
        ).fetchone()
        if not row:
            return False, "订单不存在"
        if int(row["user_id"]) != int(user_id):
            return False, "订单无权限"
        if (row["status"] or "") == "paid":
            return True, "already_paid"
        # The 1e-9 tolerance absorbs float rounding when comparing CNY amounts.
        if float(paid_amount_cny or 0.0) + 1e-9 < float(row["amount_cny"] or 0.0):
            return False, "支付金额不足"
        claimed = c.execute(
            """
            UPDATE recharge_orders
            SET status='paid', external_txn_id=?, paid_at=?, meta_json=?
            WHERE order_no=? AND status<>'paid'
            """,
            (
                external_txn_id or "",
                now,
                json.dumps(meta or {}, ensure_ascii=True),
                order_no,
            ),
        ).rowcount
        if not claimed:
            # Another caller marked the order paid between our SELECT and UPDATE;
            # it is responsible for crediting the tokens.
            return True, "already_paid"
        token_amount = int(row["token_amount"] or 0)
    # Credit only after the block above has committed: recharge_tokens writes
    # through its own connection, which would wait on this connection's open
    # write transaction.
    # NOTE(review): these are two separate transactions; a crash in between leaves
    # a paid order whose tokens were never credited.
    self.recharge_tokens(
        user_id,
        token_amount,
        kind="paid_recharge",
        ref_type="order",
        ref_id=order_no,
        detail={"paid_amount_cny": float(paid_amount_cny or 0.0), "external_txn_id": external_txn_id or ""},
        cycle_start_at=now,
        cycle_days=30,
    )
    return True, "ok"
|
|
|
|
def list_recharge_orders(self, user_id: int, limit: int = 50) -> list[dict]:
    """List *user_id*'s recharge orders, newest first.

    Before reading, every ``'pending'`` order of this user created 15 minutes
    ago or earlier is switched to ``'cancelled'``. A row whose ``meta_json``
    cannot be decoded is returned with ``meta == {}``, the same fallback used
    when reading a single order, so one bad row no longer breaks the listing.

    Args:
        user_id: Owner of the orders.
        limit: Maximum number of rows, clamped to the range 1..200.

    Returns:
        One dict per order with its amounts, status, transaction id, timestamps
        (``paid_at`` is ``None`` when unset) and decoded ``meta``.
    """
    now = int(time.time())
    expire_before = now - 15 * 60
    with self._conn() as c:
        c.execute(
            """
            UPDATE recharge_orders
            SET status='cancelled'
            WHERE user_id=? AND status='pending' AND created_at<=?
            """,
            (user_id, expire_before),
        )
        rows = c.execute(
            """
            SELECT order_no, channel, token_amount, amount_cny, status, external_txn_id, created_at, paid_at, meta_json
            FROM recharge_orders
            WHERE user_id=?
            ORDER BY id DESC
            LIMIT ?
            """,
            (user_id, max(1, min(int(limit), 200))),
        ).fetchall()

    orders: list[dict] = []
    for r in rows:
        try:
            meta = json.loads(r["meta_json"] or "{}")
        except (TypeError, ValueError):
            meta = {}
        orders.append(
            {
                "order_no": r["order_no"] or "",
                "channel": r["channel"] or "",
                "token_amount": int(r["token_amount"] or 0),
                "amount_cny": float(r["amount_cny"] or 0.0),
                "status": r["status"] or "",
                "external_txn_id": r["external_txn_id"] or "",
                "created_at": int(r["created_at"] or 0),
                "paid_at": int(r["paid_at"] or 0) if r["paid_at"] else None,
                "meta": meta,
            }
        )
    return orders
|
|
|
|
def get_recharge_order(self, user_id: int, order_no: str) -> dict | None:
|
|
now = int(time.time())
|
|
with self._conn() as c:
|
|
expire_before = now - 15 * 60
|
|
c.execute(
|
|
"""
|
|
UPDATE recharge_orders
|
|
SET status='cancelled'
|
|
WHERE user_id=? AND status='pending' AND created_at<=?
|
|
""",
|
|
(user_id, expire_before),
|
|
)
|
|
row = c.execute(
|
|
"""
|
|
SELECT order_no, channel, token_amount, amount_cny, status, external_txn_id, created_at, paid_at, meta_json
|
|
FROM recharge_orders
|
|
WHERE user_id=? AND order_no=?
|
|
LIMIT 1
|
|
""",
|
|
(user_id, order_no),
|
|
).fetchone()
|
|
if not row:
|
|
return None
|
|
try:
|
|
meta = json.loads(row["meta_json"] or "{}")
|
|
except Exception:
|
|
meta = {}
|
|
return {
|
|
"order_no": row["order_no"] or "",
|
|
"channel": row["channel"] or "",
|
|
"token_amount": int(row["token_amount"] or 0),
|
|
"amount_cny": float(row["amount_cny"] or 0.0),
|
|
"status": row["status"] or "",
|
|
"external_txn_id": row["external_txn_id"] or "",
|
|
"created_at": int(row["created_at"] or 0),
|
|
"paid_at": int(row["paid_at"] or 0) if row["paid_at"] else None,
|
|
"meta": meta,
|
|
}
|
|
|
|
def get_recharge_order_user_id(self, order_no: str) -> int | None:
|
|
with self._conn() as c:
|
|
row = c.execute("SELECT user_id FROM recharge_orders WHERE order_no=?", (order_no,)).fetchone()
|
|
return int(row["user_id"]) if row and row["user_id"] else None
|
|
|
|
def list_token_ledger(self, user_id: int, limit: int = 100) -> list[dict]:
    """Return ``user_id``'s token ledger entries, newest (highest id) first.

    ``limit`` is clamped to ``[1, 500]``. Each entry's ``detail`` is the
    decoded ``detail_json`` column, or ``{}`` when that column cannot be
    parsed. NULL text columns become ``""`` and NULL integers become ``0``.
    """
    with self._conn() as conn:
        entries = conn.execute(
            """
            SELECT direction, token_change, balance_after, kind, ref_type, ref_id, detail_json, created_at
            FROM token_ledger
            WHERE user_id=?
            ORDER BY id DESC
            LIMIT ?
            """,
            (user_id, max(1, min(int(limit), 500))),
        ).fetchall()

    def _decode_detail(raw):
        # Tolerate corrupt JSON per row instead of failing the whole listing.
        try:
            return json.loads(raw or "{}")
        except Exception:
            return {}

    return [
        {
            "direction": entry["direction"] or "",
            "token_change": int(entry["token_change"] or 0),
            "balance_after": int(entry["balance_after"] or 0),
            "kind": entry["kind"] or "",
            "ref_type": entry["ref_type"] or "",
            "ref_id": entry["ref_id"] or "",
            "detail": _decode_detail(entry["detail_json"]),
            "created_at": int(entry["created_at"] or 0),
        }
        for entry in entries
    ]
|
|
|
|
def save_wechat_binding(
    self,
    user_id: int,
    appid: str,
    secret: str,
    author: str = "",
    thumb_media_id: str = "",
    thumb_image_path: str = "",
) -> None:
    """Upsert the single ``wechat_bindings`` row keyed by ``user_id``.

    On conflict every column is overwritten, and ``updated_at`` is set to the
    current Unix time.

    NOTE(review): this writes only the ``wechat_bindings`` table, but
    ``get_wechat_binding`` returns data from ``wechat_accounts`` (via
    ``get_active_wechat_binding``). Confirm whether anything still reads
    ``wechat_bindings``, e.g. a migration in ``_init_db``, before relying on
    this method.
    """
    record = (user_id, appid, secret, author, thumb_media_id, thumb_image_path, int(time.time()))
    with self._conn() as conn:
        conn.execute(
            """
            INSERT INTO wechat_bindings(user_id, appid, secret, author, thumb_media_id, thumb_image_path, updated_at)
            VALUES (?, ?, ?, ?, ?, ?, ?)
            ON CONFLICT(user_id) DO UPDATE SET
            appid=excluded.appid,
            secret=excluded.secret,
            author=excluded.author,
            thumb_media_id=excluded.thumb_media_id,
            thumb_image_path=excluded.thumb_image_path,
            updated_at=excluded.updated_at
            """,
            record,
        )
|
|
|
|
def get_wechat_binding(self, user_id: int) -> dict | None:
|
|
return self.get_active_wechat_binding(user_id)
|
|
|
|
def list_wechat_bindings(self, user_id: int) -> list[dict]:
    """List every WeChat account bound by ``user_id``, newest first.

    Rows are ordered by ``updated_at`` then ``id``, both descending. The
    ``secret`` column is not selected, so it never appears in the output.
    ``active`` is True only for the account referenced by
    ``user_prefs.active_wechat_account_id``.
    """
    with self._conn() as conn:
        accounts = conn.execute(
            """
            SELECT id, account_name, appid, author, thumb_media_id, thumb_image_path, updated_at
            FROM wechat_accounts
            WHERE user_id=?
            ORDER BY updated_at DESC, id DESC
            """,
            (user_id,),
        ).fetchall()
        pref = conn.execute(
            "SELECT active_wechat_account_id FROM user_prefs WHERE user_id=?",
            (user_id,),
        ).fetchone()

    current_id = None
    if pref and pref["active_wechat_account_id"]:
        current_id = int(pref["active_wechat_account_id"])

    return [
        {
            "id": int(acct["id"]),
            "account_name": acct["account_name"] or "",
            "appid": acct["appid"] or "",
            "author": acct["author"] or "",
            "thumb_media_id": acct["thumb_media_id"] or "",
            "thumb_image_path": acct["thumb_image_path"] or "",
            "updated_at": int(acct["updated_at"] or 0),
            "active": int(acct["id"]) == current_id,
        }
        for acct in accounts
    ]
|
|
|
|
def add_wechat_binding(
    self,
    user_id: int,
    account_name: str,
    appid: str,
    secret: str,
    author: str = "",
    thumb_media_id: str = "",
    thumb_image_path: str = "",
) -> dict:
    """Create a WeChat account for ``user_id`` and make it the active one.

    A blank (or ``None``) ``account_name`` becomes ``公众号NNNN``, using the
    last four digits of the current Unix time. ``wechat_accounts`` enforces
    ``UNIQUE(user_id, account_name)``. When the insert raises
    ``sqlite3.IntegrityError``, the name is retried with a ``-NNN`` suffix
    (as before), then ``-NNN-2``, ``-NNN-3``, ... up to a fixed number of
    attempts.

    Returns:
        ``{"id": new_account_id, "account_name": stored_name}``.

    Raises:
        sqlite3.IntegrityError: if every candidate name fails to insert
        (e.g. a NOT NULL violation that renaming cannot fix).
    """
    now = int(time.time())
    base_name = (account_name or "").strip() or f"公众号{now % 10000}"
    suffixed = f"{base_name}-{now % 1000}"
    # The first two candidates match the historical naming scheme. The
    # numbered ones cover the case where the suffixed name is itself taken,
    # e.g. a third copy of the same name added within the same second.
    candidates = [base_name, suffixed, *(f"{suffixed}-{n}" for n in range(2, 12))]

    with self._conn() as c:
        for attempt, name in enumerate(candidates, start=1):
            try:
                cur = c.execute(
                    """
                    INSERT INTO wechat_accounts
                    (user_id, account_name, appid, secret, author, thumb_media_id, thumb_image_path, updated_at)
                    VALUES (?, ?, ?, ?, ?, ?, ?, ?)
                    """,
                    (user_id, name, appid, secret, author, thumb_media_id, thumb_image_path, now),
                )
            except sqlite3.IntegrityError:
                if attempt == len(candidates):
                    raise
            else:
                break
        aid = int(cur.lastrowid)
        # The newly added account becomes the user's active account.
        c.execute(
            """
            INSERT INTO user_prefs(user_id, active_wechat_account_id, updated_at)
            VALUES (?, ?, ?)
            ON CONFLICT(user_id) DO UPDATE SET
            active_wechat_account_id=excluded.active_wechat_account_id,
            updated_at=excluded.updated_at
            """,
            (user_id, aid, now),
        )
    return {"id": aid, "account_name": name}
|
|
|
|
def switch_active_wechat_binding(self, user_id: int, account_id: int) -> bool:
    """Make ``account_id`` the user's active WeChat account.

    Returns ``False`` and writes nothing when the account does not exist or
    belongs to a different user. Otherwise it upserts
    ``user_prefs.active_wechat_account_id`` and returns ``True``.
    """
    stamp = int(time.time())
    with self._conn() as conn:
        owned = conn.execute(
            "SELECT id FROM wechat_accounts WHERE id=? AND user_id=?",
            (account_id, user_id),
        ).fetchone()
        if owned is not None:
            conn.execute(
                """
                INSERT INTO user_prefs(user_id, active_wechat_account_id, updated_at)
                VALUES (?, ?, ?)
                ON CONFLICT(user_id) DO UPDATE SET
                active_wechat_account_id=excluded.active_wechat_account_id,
                updated_at=excluded.updated_at
                """,
                (user_id, account_id, stamp),
            )
    return owned is not None
|
|
|
|
def get_active_wechat_binding(self, user_id: int) -> dict | None:
|
|
with self._conn() as c:
|
|
pref = c.execute(
|
|
"SELECT active_wechat_account_id FROM user_prefs WHERE user_id=?",
|
|
(user_id,),
|
|
).fetchone()
|
|
aid = int(pref["active_wechat_account_id"]) if pref and pref["active_wechat_account_id"] else None
|
|
row = None
|
|
if aid:
|
|
row = c.execute(
|
|
"""
|
|
SELECT id, account_name, appid, secret, author, thumb_media_id, thumb_image_path, updated_at
|
|
FROM wechat_accounts
|
|
WHERE id=? AND user_id=?
|
|
""",
|
|
(aid, user_id),
|
|
).fetchone()
|
|
if not row:
|
|
row = c.execute(
|
|
"""
|
|
SELECT id, account_name, appid, secret, author, thumb_media_id, thumb_image_path, updated_at
|
|
FROM wechat_accounts
|
|
WHERE user_id=?
|
|
ORDER BY updated_at DESC, id DESC
|
|
LIMIT 1
|
|
""",
|
|
(user_id,),
|
|
).fetchone()
|
|
if row:
|
|
c.execute(
|
|
"""
|
|
INSERT INTO user_prefs(user_id, active_wechat_account_id, updated_at)
|
|
VALUES (?, ?, ?)
|
|
ON CONFLICT(user_id) DO UPDATE SET
|
|
active_wechat_account_id=excluded.active_wechat_account_id,
|
|
updated_at=excluded.updated_at
|
|
""",
|
|
(user_id, int(row["id"]), int(time.time())),
|
|
)
|
|
if not row:
|
|
return None
|
|
return {
|
|
"id": int(row["id"]),
|
|
"account_name": row["account_name"] or "",
|
|
"appid": row["appid"] or "",
|
|
"secret": row["secret"] or "",
|
|
"author": row["author"] or "",
|
|
"thumb_media_id": row["thumb_media_id"] or "",
|
|
"thumb_image_path": row["thumb_image_path"] or "",
|
|
"updated_at": int(row["updated_at"] or 0),
|
|
}
|
|
|
|
def list_ai_models(self, user_id: int) -> list[dict]:
    """List the AI model configurations saved by ``user_id``, newest first.

    Rows are ordered by ``updated_at`` then ``id``, both descending.
    ``api_key`` is not selected, so it is never returned. NULL or zero numeric
    settings fall back to ``timeout_sec=120.0``, ``max_output_tokens=8192``
    and ``max_retries=0``. ``active`` marks the model referenced by
    ``user_prefs.active_ai_model_id``.
    """
    with self._conn() as conn:
        models = conn.execute(
            """
            SELECT id, model_name, base_url, model, image_model, timeout_sec, max_output_tokens, max_retries, updated_at
            FROM ai_models
            WHERE user_id=?
            ORDER BY updated_at DESC, id DESC
            """,
            (user_id,),
        ).fetchall()
        pref = conn.execute(
            "SELECT active_ai_model_id FROM user_prefs WHERE user_id=?",
            (user_id,),
        ).fetchone()

    current_id = None
    if pref and pref["active_ai_model_id"]:
        current_id = int(pref["active_ai_model_id"])

    def _summarize(m):
        model_id = int(m["id"])
        return {
            "id": model_id,
            "model_name": m["model_name"] or "",
            "base_url": m["base_url"] or "",
            "model": m["model"] or "",
            "image_model": m["image_model"] or "",
            "timeout_sec": float(m["timeout_sec"] or 120.0),
            "max_output_tokens": int(m["max_output_tokens"] or 8192),
            "max_retries": int(m["max_retries"] or 0),
            "updated_at": int(m["updated_at"] or 0),
            "active": model_id == current_id,
        }

    return [_summarize(m) for m in models]
|
|
|
|
def add_ai_model(
    self,
    user_id: int,
    model_name: str,
    api_key: str,
    base_url: str,
    model: str,
    image_model: str = "",
    timeout_sec: float = 120.0,
    max_output_tokens: int = 8192,
    max_retries: int = 0,
) -> dict:
    """Store a new AI model configuration for ``user_id`` and make it active.

    A blank (or ``None``) ``model_name`` becomes ``模型NNNN``, using the last
    four digits of the current Unix time. When the insert raises
    ``sqlite3.IntegrityError`` (presumably a duplicate name under a
    UNIQUE(user_id, model_name) constraint; confirm against the
    ``ai_models`` schema), the name is retried with a ``-NNN`` suffix (as
    before), then ``-NNN-2``, ``-NNN-3``, ... up to a fixed number of
    attempts.

    Returns:
        ``{"id": new_model_id, "model_name": stored_name}``.

    Raises:
        sqlite3.IntegrityError: if every candidate name fails to insert.
    """
    now = int(time.time())
    base_name = (model_name or "").strip() or f"模型{now % 10000}"
    suffixed = f"{base_name}-{now % 1000}"
    # The first two candidates match the historical naming scheme. The
    # numbered ones cover the case where the suffixed name is itself taken,
    # e.g. a third copy of the same name added within the same second.
    candidates = [base_name, suffixed, *(f"{suffixed}-{n}" for n in range(2, 12))]

    with self._conn() as c:
        for attempt, name in enumerate(candidates, start=1):
            try:
                cur = c.execute(
                    """
                    INSERT INTO ai_models
                    (user_id, model_name, api_key, base_url, model, image_model, timeout_sec, max_output_tokens, max_retries, updated_at)
                    VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
                    """,
                    (user_id, name, api_key, base_url, model, image_model, timeout_sec, max_output_tokens, max_retries, now),
                )
            except sqlite3.IntegrityError:
                if attempt == len(candidates):
                    raise
            else:
                break
        aid = int(cur.lastrowid)
        # The newly added model becomes the user's active model.
        c.execute(
            """
            INSERT INTO user_prefs(user_id, active_ai_model_id, updated_at)
            VALUES (?, ?, ?)
            ON CONFLICT(user_id) DO UPDATE SET
            active_ai_model_id=excluded.active_ai_model_id,
            updated_at=excluded.updated_at
            """,
            (user_id, aid, now),
        )
    return {"id": aid, "model_name": name}
|
|
|
|
def switch_active_ai_model(self, user_id: int, model_id: int) -> bool:
    """Make ``model_id`` the user's active AI model configuration.

    Returns ``False`` and writes nothing when the model does not exist or
    belongs to a different user. Otherwise it upserts
    ``user_prefs.active_ai_model_id`` and returns ``True``.
    """
    stamp = int(time.time())
    with self._conn() as conn:
        owned = conn.execute(
            "SELECT id FROM ai_models WHERE id=? AND user_id=?",
            (model_id, user_id),
        ).fetchone()
        if owned is not None:
            conn.execute(
                """
                INSERT INTO user_prefs(user_id, active_ai_model_id, updated_at)
                VALUES (?, ?, ?)
                ON CONFLICT(user_id) DO UPDATE SET
                active_ai_model_id=excluded.active_ai_model_id,
                updated_at=excluded.updated_at
                """,
                (user_id, model_id, stamp),
            )
    return owned is not None
|
|
|
|
def delete_ai_model(self, user_id: int, model_id: int) -> bool:
    """Delete one of ``user_id``'s AI model configurations.

    If the deleted model was the active one, ``user_prefs.active_ai_model_id``
    is repointed at the most recently updated remaining model, or set to NULL
    when none remain.

    Returns:
        ``False`` when ``model_id`` does not belong to this user (nothing is
        deleted), ``True`` otherwise.
    """
    stamp = int(time.time())
    with self._conn() as conn:
        target = conn.execute(
            "SELECT id FROM ai_models WHERE id=? AND user_id=?",
            (model_id, user_id),
        ).fetchone()
        if target is None:
            return False

        conn.execute("DELETE FROM ai_models WHERE id=? AND user_id=?", (model_id, user_id))

        pref = conn.execute(
            "SELECT active_ai_model_id FROM user_prefs WHERE user_id=?",
            (user_id,),
        ).fetchone()
        was_active = int(pref["active_ai_model_id"]) if pref and pref["active_ai_model_id"] else None
        if was_active != model_id:
            return True

        # The active model was just removed: fall back to the newest survivor (or NULL).
        successor = conn.execute(
            """
            SELECT id FROM ai_models WHERE user_id=?
            ORDER BY updated_at DESC, id DESC
            LIMIT 1
            """,
            (user_id,),
        ).fetchone()
        conn.execute(
            """
            INSERT INTO user_prefs(user_id, active_ai_model_id, updated_at)
            VALUES (?, ?, ?)
            ON CONFLICT(user_id) DO UPDATE SET
            active_ai_model_id=excluded.active_ai_model_id,
            updated_at=excluded.updated_at
            """,
            (user_id, int(successor["id"]) if successor else None, stamp),
        )
    return True
|
|
|
|
def update_active_ai_image_model(self, user_id: int, image_model: str) -> bool:
    """Set ``image_model`` (whitespace-stripped) on the user's active AI model.

    The target model is ``user_prefs.active_ai_model_id`` when that id still
    refers to one of the user's models. Otherwise the most recently updated
    model is used, the same fallback ``get_active_ai_model`` applies, and it
    is recorded as active. Previously a stale preference (pointing at a
    deleted or foreign model) made this return ``False`` even though the user
    had models.

    Returns:
        ``True`` if a model row was updated, ``False`` if the user has no AI
        models.
    """
    now = int(time.time())
    name = (image_model or "").strip()
    with self._conn() as c:
        pref = c.execute(
            "SELECT active_ai_model_id FROM user_prefs WHERE user_id=?",
            (user_id,),
        ).fetchone()
        aid = int(pref["active_ai_model_id"]) if pref and pref["active_ai_model_id"] else None
        if aid:
            # Ignore a preference that no longer points at one of this user's models.
            owned = c.execute(
                "SELECT id FROM ai_models WHERE id=? AND user_id=?",
                (aid, user_id),
            ).fetchone()
            if not owned:
                aid = None
        if not aid:
            row = c.execute(
                "SELECT id FROM ai_models WHERE user_id=? ORDER BY updated_at DESC, id DESC LIMIT 1",
                (user_id,),
            ).fetchone()
            aid = int(row["id"]) if row else None
        if not aid:
            return False
        cur = c.execute(
            "UPDATE ai_models SET image_model=?, updated_at=? WHERE id=? AND user_id=?",
            (name, now, aid, user_id),
        )
        # Cursor.rowcount is per-statement; Connection.total_changes is cumulative
        # for the whole connection and only worked here by accident.
        if cur.rowcount <= 0:
            return False
        c.execute(
            """
            INSERT INTO user_prefs(user_id, active_ai_model_id, updated_at)
            VALUES (?, ?, ?)
            ON CONFLICT(user_id) DO UPDATE SET
            active_ai_model_id=excluded.active_ai_model_id,
            updated_at=excluded.updated_at
            """,
            (user_id, aid, now),
        )
    return True
|
|
|
|
def get_active_ai_model(self, user_id: int) -> dict | None:
|
|
with self._conn() as c:
|
|
pref = c.execute(
|
|
"SELECT active_ai_model_id FROM user_prefs WHERE user_id=?",
|
|
(user_id,),
|
|
).fetchone()
|
|
aid = int(pref["active_ai_model_id"]) if pref and pref["active_ai_model_id"] else None
|
|
row = None
|
|
if aid:
|
|
row = c.execute(
|
|
"""
|
|
SELECT id, model_name, api_key, base_url, model, image_model, timeout_sec, max_output_tokens, max_retries, updated_at
|
|
FROM ai_models
|
|
WHERE id=? AND user_id=?
|
|
""",
|
|
(aid, user_id),
|
|
).fetchone()
|
|
if not row:
|
|
row = c.execute(
|
|
"""
|
|
SELECT id, model_name, api_key, base_url, model, image_model, timeout_sec, max_output_tokens, max_retries, updated_at
|
|
FROM ai_models
|
|
WHERE user_id=?
|
|
ORDER BY updated_at DESC, id DESC
|
|
LIMIT 1
|
|
""",
|
|
(user_id,),
|
|
).fetchone()
|
|
if row:
|
|
c.execute(
|
|
"""
|
|
INSERT INTO user_prefs(user_id, active_ai_model_id, updated_at)
|
|
VALUES (?, ?, ?)
|
|
ON CONFLICT(user_id) DO UPDATE SET
|
|
active_ai_model_id=excluded.active_ai_model_id,
|
|
updated_at=excluded.updated_at
|
|
""",
|
|
(user_id, int(row["id"]), int(time.time())),
|
|
)
|
|
if not row:
|
|
return None
|
|
return {
|
|
"id": int(row["id"]),
|
|
"model_name": row["model_name"] or "",
|
|
"api_key": row["api_key"] or "",
|
|
"base_url": row["base_url"] or "",
|
|
"model": row["model"] or "",
|
|
"image_model": row["image_model"] or "",
|
|
"timeout_sec": float(row["timeout_sec"] or 120.0),
|
|
"max_output_tokens": int(row["max_output_tokens"] or 8192),
|
|
"max_retries": int(row["max_retries"] or 0),
|
|
"updated_at": int(row["updated_at"] or 0),
|
|
}
|
|
|
|
def delete_wechat_binding(self, user_id: int, account_id: int) -> bool:
    """Remove one of ``user_id``'s WeChat accounts.

    If the removed account was the active one,
    ``user_prefs.active_wechat_account_id`` is repointed at the most recently
    updated remaining account, or set to NULL when none remain.

    Returns:
        ``False`` when ``account_id`` does not belong to this user (nothing is
        deleted), ``True`` otherwise.
    """
    stamp = int(time.time())
    with self._conn() as conn:
        target = conn.execute(
            "SELECT id FROM wechat_accounts WHERE id=? AND user_id=?",
            (account_id, user_id),
        ).fetchone()
        if target is None:
            return False

        conn.execute("DELETE FROM wechat_accounts WHERE id=? AND user_id=?", (account_id, user_id))

        pref = conn.execute(
            "SELECT active_wechat_account_id FROM user_prefs WHERE user_id=?",
            (user_id,),
        ).fetchone()
        was_active = (
            int(pref["active_wechat_account_id"])
            if pref and pref["active_wechat_account_id"]
            else None
        )
        if was_active != account_id:
            return True

        # The active account was just removed: fall back to the newest survivor (or NULL).
        successor = conn.execute(
            """
            SELECT id FROM wechat_accounts WHERE user_id=?
            ORDER BY updated_at DESC, id DESC
            LIMIT 1
            """,
            (user_id,),
        ).fetchone()
        conn.execute(
            """
            INSERT INTO user_prefs(user_id, active_wechat_account_id, updated_at)
            VALUES (?, ?, ?)
            ON CONFLICT(user_id) DO UPDATE SET
            active_wechat_account_id=excluded.active_wechat_account_id,
            updated_at=excluded.updated_at
            """,
            (user_id, int(successor["id"]) if successor else None, stamp),
        )
    return True
|