fix: bug
This commit is contained in:
@@ -66,6 +66,27 @@ class UserStore:
|
||||
)
|
||||
"""
|
||||
)
|
||||
pref_cols = self._table_columns(c, "user_prefs")
|
||||
if "active_ai_model_id" not in pref_cols:
|
||||
c.execute("ALTER TABLE user_prefs ADD COLUMN active_ai_model_id INTEGER")
|
||||
c.execute(
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS ai_models (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
user_id INTEGER NOT NULL,
|
||||
model_name TEXT NOT NULL,
|
||||
api_key TEXT NOT NULL,
|
||||
base_url TEXT NOT NULL DEFAULT '',
|
||||
model TEXT NOT NULL,
|
||||
timeout_sec REAL NOT NULL DEFAULT 120.0,
|
||||
max_output_tokens INTEGER NOT NULL DEFAULT 8192,
|
||||
max_retries INTEGER NOT NULL DEFAULT 0,
|
||||
updated_at INTEGER NOT NULL,
|
||||
UNIQUE(user_id, model_name),
|
||||
FOREIGN KEY(user_id) REFERENCES users(id)
|
||||
)
|
||||
"""
|
||||
)
|
||||
# 兼容历史单绑定结构,自动迁移为默认账号
|
||||
rows = c.execute(
|
||||
"SELECT user_id, appid, secret, author, thumb_media_id, thumb_image_path, updated_at FROM wechat_bindings"
|
||||
@@ -111,7 +132,16 @@ class UserStore:
|
||||
return {str(r["name"]) for r in rows}
|
||||
|
||||
def _ensure_users_table(self, c: sqlite3.Connection) -> None:
|
||||
required = {"id", "username", "password_hash", "password_salt", "created_at"}
|
||||
required = {
|
||||
"id",
|
||||
"username",
|
||||
"password_hash",
|
||||
"password_salt",
|
||||
"reset_code_hash",
|
||||
"reset_code_salt",
|
||||
"created_at",
|
||||
"deleted_at",
|
||||
}
|
||||
c.execute(
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS users (
|
||||
@@ -119,7 +149,10 @@ class UserStore:
|
||||
username TEXT NOT NULL UNIQUE,
|
||||
password_hash TEXT NOT NULL,
|
||||
password_salt TEXT NOT NULL,
|
||||
created_at INTEGER NOT NULL
|
||||
reset_code_hash TEXT NOT NULL DEFAULT '',
|
||||
reset_code_salt TEXT NOT NULL DEFAULT '',
|
||||
created_at INTEGER NOT NULL,
|
||||
deleted_at INTEGER
|
||||
)
|
||||
"""
|
||||
)
|
||||
@@ -137,29 +170,44 @@ class UserStore:
|
||||
username TEXT NOT NULL UNIQUE,
|
||||
password_hash TEXT NOT NULL,
|
||||
password_salt TEXT NOT NULL,
|
||||
created_at INTEGER NOT NULL
|
||||
reset_code_hash TEXT NOT NULL,
|
||||
reset_code_salt TEXT NOT NULL,
|
||||
created_at INTEGER NOT NULL,
|
||||
deleted_at INTEGER
|
||||
)
|
||||
"""
|
||||
)
|
||||
|
||||
if {"username", "password_hash", "password_salt"}.issubset(cols):
|
||||
if "created_at" in cols:
|
||||
rows = c.execute("SELECT * FROM users").fetchall()
|
||||
for r in rows:
|
||||
username = (r["username"] or "").strip()
|
||||
if not username:
|
||||
continue
|
||||
reset_salt = r["reset_code_salt"] if "reset_code_salt" in cols else secrets.token_hex(8)
|
||||
reset_hash = r["reset_code_hash"] if "reset_code_hash" in cols else ""
|
||||
if not reset_hash:
|
||||
legacy_code = self._generate_reset_code()
|
||||
reset_hash = self._hash_reset_code(legacy_code, reset_salt)
|
||||
created_at = int(r["created_at"]) if "created_at" in cols and r["created_at"] else now
|
||||
deleted_at = int(r["deleted_at"]) if "deleted_at" in cols and r["deleted_at"] else None
|
||||
c.execute(
|
||||
"""
|
||||
INSERT OR IGNORE INTO users_new(id, username, password_hash, password_salt, created_at)
|
||||
SELECT id, username, password_hash, password_salt, COALESCE(created_at, ?)
|
||||
FROM users
|
||||
INSERT OR IGNORE INTO users_new(
|
||||
id, username, password_hash, password_salt, reset_code_hash, reset_code_salt, created_at, deleted_at
|
||||
)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
|
||||
""",
|
||||
(now,),
|
||||
)
|
||||
else:
|
||||
c.execute(
|
||||
"""
|
||||
INSERT OR IGNORE INTO users_new(id, username, password_hash, password_salt, created_at)
|
||||
SELECT id, username, password_hash, password_salt, ?
|
||||
FROM users
|
||||
""",
|
||||
(now,),
|
||||
(
|
||||
int(r["id"]),
|
||||
username,
|
||||
r["password_hash"],
|
||||
r["password_salt"],
|
||||
reset_hash,
|
||||
reset_salt,
|
||||
created_at,
|
||||
deleted_at,
|
||||
),
|
||||
)
|
||||
elif {"username", "password"}.issubset(cols):
|
||||
if "created_at" in cols:
|
||||
@@ -173,13 +221,18 @@ class UserStore:
|
||||
continue
|
||||
salt = secrets.token_hex(16)
|
||||
pwd_hash = self._hash_password(raw_pwd, salt)
|
||||
reset_code = self._generate_reset_code()
|
||||
reset_salt = secrets.token_hex(8)
|
||||
reset_hash = self._hash_reset_code(reset_code, reset_salt)
|
||||
created_at = int(r["created_at"]) if "created_at" in r.keys() and r["created_at"] else now
|
||||
c.execute(
|
||||
"""
|
||||
INSERT OR IGNORE INTO users_new(id, username, password_hash, password_salt, created_at)
|
||||
VALUES (?, ?, ?, ?, ?)
|
||||
INSERT OR IGNORE INTO users_new(
|
||||
id, username, password_hash, password_salt, reset_code_hash, reset_code_salt, created_at, deleted_at
|
||||
)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, NULL)
|
||||
""",
|
||||
(int(r["id"]), username, pwd_hash, salt, created_at),
|
||||
(int(r["id"]), username, pwd_hash, salt, reset_hash, reset_salt, created_at),
|
||||
)
|
||||
|
||||
c.execute("DROP TABLE users")
|
||||
@@ -222,18 +275,31 @@ class UserStore:
|
||||
def _hash_token(self, token: str) -> str:
|
||||
return hashlib.sha256(token.encode("utf-8")).hexdigest()
|
||||
|
||||
def _generate_reset_code(self) -> str:
|
||||
return secrets.token_urlsafe(9).replace("-", "").replace("_", "")[:12]
|
||||
|
||||
def _hash_reset_code(self, reset_code: str, salt: str) -> str:
|
||||
return hashlib.sha256(f"{salt}:{reset_code}".encode("utf-8")).hexdigest()
|
||||
|
||||
def create_user(self, username: str, password: str) -> dict | None:
|
||||
now = int(time.time())
|
||||
salt = secrets.token_hex(16)
|
||||
pwd_hash = self._hash_password(password, salt)
|
||||
reset_code = self._generate_reset_code()
|
||||
reset_salt = secrets.token_hex(8)
|
||||
reset_hash = self._hash_reset_code(reset_code, reset_salt)
|
||||
try:
|
||||
with self._conn() as c:
|
||||
cur = c.execute(
|
||||
"INSERT INTO users(username, password_hash, password_salt, created_at) VALUES (?, ?, ?, ?)",
|
||||
(username, pwd_hash, salt, now),
|
||||
"""
|
||||
INSERT INTO users(
|
||||
username, password_hash, password_salt, reset_code_hash, reset_code_salt, created_at, deleted_at
|
||||
) VALUES (?, ?, ?, ?, ?, ?, NULL)
|
||||
""",
|
||||
(username, pwd_hash, salt, reset_hash, reset_salt, now),
|
||||
)
|
||||
uid = int(cur.lastrowid)
|
||||
return {"id": uid, "username": username}
|
||||
return {"id": uid, "username": username, "reset_code": reset_code}
|
||||
except sqlite3.IntegrityError:
|
||||
return None
|
||||
except sqlite3.Error as exc:
|
||||
@@ -243,7 +309,11 @@ class UserStore:
|
||||
try:
|
||||
with self._conn() as c:
|
||||
row = c.execute(
|
||||
"SELECT id, username, password_hash, password_salt FROM users WHERE username=?",
|
||||
"""
|
||||
SELECT id, username, password_hash, password_salt
|
||||
FROM users
|
||||
WHERE username=? AND deleted_at IS NULL
|
||||
""",
|
||||
(username,),
|
||||
).fetchone()
|
||||
except sqlite3.Error as exc:
|
||||
@@ -274,14 +344,27 @@ class UserStore:
|
||||
)
|
||||
return True
|
||||
|
||||
def reset_password_by_username(self, username: str, new_password: str) -> bool:
|
||||
def reset_password_by_username(self, username: str, reset_code: str, new_password: str) -> bool:
|
||||
uname = (username or "").strip()
|
||||
rcode = (reset_code or "").strip()
|
||||
if not uname:
|
||||
return False
|
||||
with self._conn() as c:
|
||||
row = c.execute("SELECT id FROM users WHERE username=?", (uname,)).fetchone()
|
||||
row = c.execute(
|
||||
"""
|
||||
SELECT id, reset_code_hash, reset_code_salt
|
||||
FROM users
|
||||
WHERE username=? AND deleted_at IS NULL
|
||||
""",
|
||||
(uname,),
|
||||
).fetchone()
|
||||
if not row:
|
||||
return False
|
||||
if not rcode:
|
||||
return False
|
||||
calc = self._hash_reset_code(rcode, row["reset_code_salt"] or "")
|
||||
if not hmac.compare_digest(calc, row["reset_code_hash"] or ""):
|
||||
return False
|
||||
new_salt = secrets.token_hex(16)
|
||||
new_hash = self._hash_password(new_password, new_salt)
|
||||
c.execute(
|
||||
@@ -324,7 +407,7 @@ class UserStore:
|
||||
SELECT u.id, u.username
|
||||
FROM sessions s
|
||||
JOIN users u ON u.id=s.user_id
|
||||
WHERE s.token_hash=? AND s.expires_at>=?
|
||||
WHERE s.token_hash=? AND s.expires_at>=? AND u.deleted_at IS NULL
|
||||
""",
|
||||
(th, now),
|
||||
).fetchone()
|
||||
@@ -332,6 +415,37 @@ class UserStore:
|
||||
return None
|
||||
return {"id": int(row["id"]), "username": row["username"]}
|
||||
|
||||
def delete_user_logically(self, user_id: int, password: str, reset_code: str) -> bool:
|
||||
now = int(time.time())
|
||||
with self._conn() as c:
|
||||
row = c.execute(
|
||||
"""
|
||||
SELECT id, password_hash, password_salt, reset_code_hash, reset_code_salt
|
||||
FROM users
|
||||
WHERE id=? AND deleted_at IS NULL
|
||||
""",
|
||||
(user_id,),
|
||||
).fetchone()
|
||||
if not row:
|
||||
return False
|
||||
calc_pwd = self._hash_password(password or "", row["password_salt"] or "")
|
||||
if not hmac.compare_digest(calc_pwd, row["password_hash"] or ""):
|
||||
return False
|
||||
calc_reset = self._hash_reset_code(reset_code or "", row["reset_code_salt"] or "")
|
||||
if not hmac.compare_digest(calc_reset, row["reset_code_hash"] or ""):
|
||||
return False
|
||||
|
||||
c.execute("DELETE FROM sessions WHERE user_id=?", (user_id,))
|
||||
c.execute("DELETE FROM wechat_accounts WHERE user_id=?", (user_id,))
|
||||
c.execute("DELETE FROM ai_models WHERE user_id=?", (user_id,))
|
||||
c.execute("DELETE FROM user_prefs WHERE user_id=?", (user_id,))
|
||||
c.execute("DELETE FROM wechat_bindings WHERE user_id=?", (user_id,))
|
||||
c.execute(
|
||||
"UPDATE users SET deleted_at=?, username=username || '#deleted' || ? WHERE id=?",
|
||||
(now, str(now), user_id),
|
||||
)
|
||||
return True
|
||||
|
||||
def save_wechat_binding(
|
||||
self,
|
||||
user_id: int,
|
||||
@@ -510,3 +624,230 @@ class UserStore:
|
||||
"thumb_image_path": row["thumb_image_path"] or "",
|
||||
"updated_at": int(row["updated_at"] or 0),
|
||||
}
|
||||
|
||||
def list_ai_models(self, user_id: int) -> list[dict]:
|
||||
with self._conn() as c:
|
||||
rows = c.execute(
|
||||
"""
|
||||
SELECT id, model_name, base_url, model, timeout_sec, max_output_tokens, max_retries, updated_at
|
||||
FROM ai_models
|
||||
WHERE user_id=?
|
||||
ORDER BY updated_at DESC, id DESC
|
||||
""",
|
||||
(user_id,),
|
||||
).fetchall()
|
||||
pref = c.execute(
|
||||
"SELECT active_ai_model_id FROM user_prefs WHERE user_id=?",
|
||||
(user_id,),
|
||||
).fetchone()
|
||||
active_id = int(pref["active_ai_model_id"]) if pref and pref["active_ai_model_id"] else None
|
||||
out: list[dict] = []
|
||||
for r in rows:
|
||||
out.append(
|
||||
{
|
||||
"id": int(r["id"]),
|
||||
"model_name": r["model_name"] or "",
|
||||
"base_url": r["base_url"] or "",
|
||||
"model": r["model"] or "",
|
||||
"timeout_sec": float(r["timeout_sec"] or 120.0),
|
||||
"max_output_tokens": int(r["max_output_tokens"] or 8192),
|
||||
"max_retries": int(r["max_retries"] or 0),
|
||||
"updated_at": int(r["updated_at"] or 0),
|
||||
"active": int(r["id"]) == active_id,
|
||||
}
|
||||
)
|
||||
return out
|
||||
|
||||
def add_ai_model(
|
||||
self,
|
||||
user_id: int,
|
||||
model_name: str,
|
||||
api_key: str,
|
||||
base_url: str,
|
||||
model: str,
|
||||
timeout_sec: float = 120.0,
|
||||
max_output_tokens: int = 8192,
|
||||
max_retries: int = 0,
|
||||
) -> dict:
|
||||
now = int(time.time())
|
||||
name = model_name.strip() or f"模型{now % 10000}"
|
||||
with self._conn() as c:
|
||||
try:
|
||||
cur = c.execute(
|
||||
"""
|
||||
INSERT INTO ai_models
|
||||
(user_id, model_name, api_key, base_url, model, timeout_sec, max_output_tokens, max_retries, updated_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
""",
|
||||
(user_id, name, api_key, base_url, model, timeout_sec, max_output_tokens, max_retries, now),
|
||||
)
|
||||
except sqlite3.IntegrityError:
|
||||
name = f"{name}-{now % 1000}"
|
||||
cur = c.execute(
|
||||
"""
|
||||
INSERT INTO ai_models
|
||||
(user_id, model_name, api_key, base_url, model, timeout_sec, max_output_tokens, max_retries, updated_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
""",
|
||||
(user_id, name, api_key, base_url, model, timeout_sec, max_output_tokens, max_retries, now),
|
||||
)
|
||||
aid = int(cur.lastrowid)
|
||||
c.execute(
|
||||
"""
|
||||
INSERT INTO user_prefs(user_id, active_ai_model_id, updated_at)
|
||||
VALUES (?, ?, ?)
|
||||
ON CONFLICT(user_id) DO UPDATE SET
|
||||
active_ai_model_id=excluded.active_ai_model_id,
|
||||
updated_at=excluded.updated_at
|
||||
""",
|
||||
(user_id, aid, now),
|
||||
)
|
||||
return {"id": aid, "model_name": name}
|
||||
|
||||
def switch_active_ai_model(self, user_id: int, model_id: int) -> bool:
|
||||
now = int(time.time())
|
||||
with self._conn() as c:
|
||||
row = c.execute(
|
||||
"SELECT id FROM ai_models WHERE id=? AND user_id=?",
|
||||
(model_id, user_id),
|
||||
).fetchone()
|
||||
if not row:
|
||||
return False
|
||||
c.execute(
|
||||
"""
|
||||
INSERT INTO user_prefs(user_id, active_ai_model_id, updated_at)
|
||||
VALUES (?, ?, ?)
|
||||
ON CONFLICT(user_id) DO UPDATE SET
|
||||
active_ai_model_id=excluded.active_ai_model_id,
|
||||
updated_at=excluded.updated_at
|
||||
""",
|
||||
(user_id, model_id, now),
|
||||
)
|
||||
return True
|
||||
|
||||
def delete_ai_model(self, user_id: int, model_id: int) -> bool:
|
||||
now = int(time.time())
|
||||
with self._conn() as c:
|
||||
row = c.execute(
|
||||
"SELECT id FROM ai_models WHERE id=? AND user_id=?",
|
||||
(model_id, user_id),
|
||||
).fetchone()
|
||||
if not row:
|
||||
return False
|
||||
c.execute("DELETE FROM ai_models WHERE id=? AND user_id=?", (model_id, user_id))
|
||||
pref = c.execute(
|
||||
"SELECT active_ai_model_id FROM user_prefs WHERE user_id=?",
|
||||
(user_id,),
|
||||
).fetchone()
|
||||
active_id = int(pref["active_ai_model_id"]) if pref and pref["active_ai_model_id"] else None
|
||||
if active_id == model_id:
|
||||
replacement = c.execute(
|
||||
"""
|
||||
SELECT id FROM ai_models WHERE user_id=?
|
||||
ORDER BY updated_at DESC, id DESC
|
||||
LIMIT 1
|
||||
""",
|
||||
(user_id,),
|
||||
).fetchone()
|
||||
replacement_id = int(replacement["id"]) if replacement else None
|
||||
c.execute(
|
||||
"""
|
||||
INSERT INTO user_prefs(user_id, active_ai_model_id, updated_at)
|
||||
VALUES (?, ?, ?)
|
||||
ON CONFLICT(user_id) DO UPDATE SET
|
||||
active_ai_model_id=excluded.active_ai_model_id,
|
||||
updated_at=excluded.updated_at
|
||||
""",
|
||||
(user_id, replacement_id, now),
|
||||
)
|
||||
return True
|
||||
|
||||
def get_active_ai_model(self, user_id: int) -> dict | None:
|
||||
with self._conn() as c:
|
||||
pref = c.execute(
|
||||
"SELECT active_ai_model_id FROM user_prefs WHERE user_id=?",
|
||||
(user_id,),
|
||||
).fetchone()
|
||||
aid = int(pref["active_ai_model_id"]) if pref and pref["active_ai_model_id"] else None
|
||||
row = None
|
||||
if aid:
|
||||
row = c.execute(
|
||||
"""
|
||||
SELECT id, model_name, api_key, base_url, model, timeout_sec, max_output_tokens, max_retries, updated_at
|
||||
FROM ai_models
|
||||
WHERE id=? AND user_id=?
|
||||
""",
|
||||
(aid, user_id),
|
||||
).fetchone()
|
||||
if not row:
|
||||
row = c.execute(
|
||||
"""
|
||||
SELECT id, model_name, api_key, base_url, model, timeout_sec, max_output_tokens, max_retries, updated_at
|
||||
FROM ai_models
|
||||
WHERE user_id=?
|
||||
ORDER BY updated_at DESC, id DESC
|
||||
LIMIT 1
|
||||
""",
|
||||
(user_id,),
|
||||
).fetchone()
|
||||
if row:
|
||||
c.execute(
|
||||
"""
|
||||
INSERT INTO user_prefs(user_id, active_ai_model_id, updated_at)
|
||||
VALUES (?, ?, ?)
|
||||
ON CONFLICT(user_id) DO UPDATE SET
|
||||
active_ai_model_id=excluded.active_ai_model_id,
|
||||
updated_at=excluded.updated_at
|
||||
""",
|
||||
(user_id, int(row["id"]), int(time.time())),
|
||||
)
|
||||
if not row:
|
||||
return None
|
||||
return {
|
||||
"id": int(row["id"]),
|
||||
"model_name": row["model_name"] or "",
|
||||
"api_key": row["api_key"] or "",
|
||||
"base_url": row["base_url"] or "",
|
||||
"model": row["model"] or "",
|
||||
"timeout_sec": float(row["timeout_sec"] or 120.0),
|
||||
"max_output_tokens": int(row["max_output_tokens"] or 8192),
|
||||
"max_retries": int(row["max_retries"] or 0),
|
||||
"updated_at": int(row["updated_at"] or 0),
|
||||
}
|
||||
|
||||
def delete_wechat_binding(self, user_id: int, account_id: int) -> bool:
|
||||
now = int(time.time())
|
||||
with self._conn() as c:
|
||||
row = c.execute(
|
||||
"SELECT id FROM wechat_accounts WHERE id=? AND user_id=?",
|
||||
(account_id, user_id),
|
||||
).fetchone()
|
||||
if not row:
|
||||
return False
|
||||
c.execute("DELETE FROM wechat_accounts WHERE id=? AND user_id=?", (account_id, user_id))
|
||||
pref = c.execute(
|
||||
"SELECT active_wechat_account_id FROM user_prefs WHERE user_id=?",
|
||||
(user_id,),
|
||||
).fetchone()
|
||||
active_id = int(pref["active_wechat_account_id"]) if pref and pref["active_wechat_account_id"] else None
|
||||
if active_id == account_id:
|
||||
replacement = c.execute(
|
||||
"""
|
||||
SELECT id FROM wechat_accounts WHERE user_id=?
|
||||
ORDER BY updated_at DESC, id DESC
|
||||
LIMIT 1
|
||||
""",
|
||||
(user_id,),
|
||||
).fetchone()
|
||||
replacement_id = int(replacement["id"]) if replacement else None
|
||||
c.execute(
|
||||
"""
|
||||
INSERT INTO user_prefs(user_id, active_wechat_account_id, updated_at)
|
||||
VALUES (?, ?, ?)
|
||||
ON CONFLICT(user_id) DO UPDATE SET
|
||||
active_wechat_account_id=excluded.active_wechat_account_id,
|
||||
updated_at=excluded.updated_at
|
||||
""",
|
||||
(user_id, replacement_id, now),
|
||||
)
|
||||
return True
|
||||
|
||||
Reference in New Issue
Block a user