from __future__ import annotations

import hashlib
import hmac
import secrets
import sqlite3
import time
from collections.abc import Iterator
from contextlib import contextmanager
from pathlib import Path


class UserStore:
    """SQLite-backed store for users, login sessions and WeChat account bindings.

    Every public method opens its own short-lived connection, runs inside a
    single transaction and closes the connection again before returning.
    Passwords are stored as salted PBKDF2-HMAC-SHA256 hashes; session tokens
    are stored only as SHA-256 digests.
    """

    # Name given to accounts migrated from the legacy single-binding table.
    _DEFAULT_ACCOUNT_NAME = "默认公众号"
    _PBKDF2_ITERATIONS = 120_000
    # create_session() never issues a session shorter than this (seconds).
    _MIN_SESSION_TTL = 600

    _USERS_COLUMNS_SQL = """
        id INTEGER PRIMARY KEY AUTOINCREMENT,
        username TEXT NOT NULL UNIQUE,
        password_hash TEXT NOT NULL,
        password_salt TEXT NOT NULL,
        created_at INTEGER NOT NULL
    """

    _SESSIONS_COLUMNS_SQL = """
        token_hash TEXT PRIMARY KEY,
        user_id INTEGER NOT NULL,
        expires_at INTEGER NOT NULL,
        created_at INTEGER NOT NULL,
        FOREIGN KEY(user_id) REFERENCES users(id)
    """

    def __init__(self, db_path: str) -> None:
        """Open (creating if necessary) the database at *db_path*.

        Missing parent directories are created, then the schema is created
        or migrated to the current layout.
        """
        self._db_path = db_path
        parent = Path(db_path).parent
        if not parent.exists():
            parent.mkdir(parents=True, exist_ok=True)
        self._init_db()

    # ------------------------------------------------------------------
    # Connection handling
    # ------------------------------------------------------------------

    def _conn(self) -> sqlite3.Connection:
        """Return a new connection whose rows are ``sqlite3.Row`` objects.

        The caller owns the connection and is responsible for closing it.
        """
        c = sqlite3.connect(self._db_path)
        c.row_factory = sqlite3.Row
        return c

    @contextmanager
    def _transaction(self) -> Iterator[sqlite3.Connection]:
        """Yield a connection; commit on success, roll back on error, always close.

        ``sqlite3.Connection`` used directly as a context manager only
        commits or rolls back -- it does not close the connection -- so the
        close is done explicitly here.
        """
        conn = self._conn()
        try:
            with conn:
                yield conn
        finally:
            conn.close()

    # ------------------------------------------------------------------
    # Schema creation and migration
    # ------------------------------------------------------------------

    def _init_db(self) -> None:
        """Create all tables if missing and run the in-place migrations."""
        with self._transaction() as c:
            self._ensure_users_table(c)
            self._ensure_sessions_table(c)
            c.execute(
                """
                CREATE TABLE IF NOT EXISTS wechat_bindings (
                    user_id INTEGER PRIMARY KEY,
                    appid TEXT NOT NULL,
                    secret TEXT NOT NULL,
                    author TEXT NOT NULL DEFAULT '',
                    thumb_media_id TEXT NOT NULL DEFAULT '',
                    thumb_image_path TEXT NOT NULL DEFAULT '',
                    updated_at INTEGER NOT NULL,
                    FOREIGN KEY(user_id) REFERENCES users(id)
                )
                """
            )
            c.execute(
                """
                CREATE TABLE IF NOT EXISTS wechat_accounts (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    user_id INTEGER NOT NULL,
                    account_name TEXT NOT NULL,
                    appid TEXT NOT NULL,
                    secret TEXT NOT NULL,
                    author TEXT NOT NULL DEFAULT '',
                    thumb_media_id TEXT NOT NULL DEFAULT '',
                    thumb_image_path TEXT NOT NULL DEFAULT '',
                    updated_at INTEGER NOT NULL,
                    UNIQUE(user_id, account_name),
                    FOREIGN KEY(user_id) REFERENCES users(id)
                )
                """
            )
            c.execute(
                """
                CREATE TABLE IF NOT EXISTS user_prefs (
                    user_id INTEGER PRIMARY KEY,
                    active_wechat_account_id INTEGER,
                    updated_at INTEGER NOT NULL,
                    FOREIGN KEY(user_id) REFERENCES users(id)
                )
                """
            )
            self._migrate_legacy_bindings(c)

    def _migrate_legacy_bindings(self, c: sqlite3.Connection) -> None:
        """Copy rows of the legacy single-binding table into ``wechat_accounts``.

        Backward compatibility with the historical one-binding-per-user
        schema: each legacy row becomes that user's default account. Users
        that already own an account with the default name are skipped
        (``INSERT OR IGNORE``), and an already chosen active account is
        never overridden.
        """
        now = int(time.time())
        legacy_rows = c.execute(
            "SELECT user_id, appid, secret, author, thumb_media_id, "
            "thumb_image_path, updated_at FROM wechat_bindings"
        ).fetchall()
        for r in legacy_rows:
            user_id = int(r["user_id"])
            cur = c.execute(
                """
                INSERT OR IGNORE INTO wechat_accounts
                    (user_id, account_name, appid, secret, author,
                     thumb_media_id, thumb_image_path, updated_at)
                VALUES (?, ?, ?, ?, ?, ?, ?, ?)
                """,
                (
                    user_id,
                    self._DEFAULT_ACCOUNT_NAME,
                    r["appid"] or "",
                    r["secret"] or "",
                    r["author"] or "",
                    r["thumb_media_id"] or "",
                    r["thumb_image_path"] or "",
                    int(r["updated_at"] or now),
                ),
            )
            if not cur.rowcount:
                # The default account already existed; nothing was inserted.
                continue
            account_id = int(
                c.execute(
                    "SELECT id FROM wechat_accounts WHERE user_id=? AND account_name=?",
                    (user_id, self._DEFAULT_ACCOUNT_NAME),
                ).fetchone()["id"]
            )
            self._set_active_account(c, user_id, account_id, now, keep_existing=True)

    def _table_columns(self, c: sqlite3.Connection, table_name: str) -> set[str]:
        """Return the column names of *table_name* (empty set if it does not exist).

        *table_name* is interpolated into the statement, so it must be a
        trusted identifier; all callers in this class pass constants.
        """
        rows = c.execute(f"PRAGMA table_info({table_name})").fetchall()
        return {str(r["name"]) for r in rows}

    def _ensure_users_table(self, c: sqlite3.Connection) -> None:
        """Create the ``users`` table, rebuilding it when columns are missing.

        Two legacy layouts are migrated into a fresh table:

        * rows that already carry ``password_hash``/``password_salt`` are
          copied as-is (``created_at`` defaults to the current time);
        * rows with a plain ``password`` column are re-hashed with a fresh
          salt; rows with an empty username or password are dropped.

        Any other layout results in an empty ``users`` table.
        """
        required = {"id", "username", "password_hash", "password_salt", "created_at"}
        c.execute(f"CREATE TABLE IF NOT EXISTS users ({self._USERS_COLUMNS_SQL})")
        cols = self._table_columns(c, "users")
        if required.issubset(cols):
            return

        now = int(time.time())
        has_created_at = "created_at" in cols
        # NOTE(review): SQLite documents PRAGMA foreign_keys as a no-op while
        # a transaction is open, so whether these two toggles take effect
        # depends on the transaction state at that point. Statement order is
        # kept exactly as before.
        c.execute("PRAGMA foreign_keys=OFF")
        c.execute("DROP TABLE IF EXISTS users_new")
        c.execute(f"CREATE TABLE users_new ({self._USERS_COLUMNS_SQL})")

        if {"username", "password_hash", "password_salt"}.issubset(cols):
            created_expr = "COALESCE(created_at, ?)" if has_created_at else "?"
            c.execute(
                f"""
                INSERT OR IGNORE INTO users_new
                    (id, username, password_hash, password_salt, created_at)
                SELECT id, username, password_hash, password_salt, {created_expr}
                FROM users
                """,
                (now,),
            )
        elif {"username", "password"}.issubset(cols):
            select_cols = "id, username, password"
            if has_created_at:
                select_cols += ", created_at"
            for r in c.execute(f"SELECT {select_cols} FROM users").fetchall():
                username = (r["username"] or "").strip()
                raw_pwd = str(r["password"] or "")
                if not username or not raw_pwd:
                    continue
                salt = secrets.token_hex(16)
                pwd_hash = self._hash_password(raw_pwd, salt)
                created_at = (
                    int(r["created_at"]) if has_created_at and r["created_at"] else now
                )
                c.execute(
                    """
                    INSERT OR IGNORE INTO users_new
                        (id, username, password_hash, password_salt, created_at)
                    VALUES (?, ?, ?, ?, ?)
                    """,
                    (int(r["id"]), username, pwd_hash, salt, created_at),
                )

        c.execute("DROP TABLE users")
        c.execute("ALTER TABLE users_new RENAME TO users")
        c.execute("PRAGMA foreign_keys=ON")

    def _ensure_sessions_table(self, c: sqlite3.Connection) -> None:
        """Create the ``sessions`` table, recreating it when columns are missing.

        An outdated table is dropped rather than migrated, which discards
        every stored session.
        """
        required = {"token_hash", "user_id", "expires_at", "created_at"}
        c.execute(f"CREATE TABLE IF NOT EXISTS sessions ({self._SESSIONS_COLUMNS_SQL})")
        if required.issubset(self._table_columns(c, "sessions")):
            return
        c.execute("DROP TABLE IF EXISTS sessions")
        c.execute(f"CREATE TABLE sessions ({self._SESSIONS_COLUMNS_SQL})")

    # ------------------------------------------------------------------
    # Hashing helpers
    # ------------------------------------------------------------------

    def _hash_password(self, password: str, salt: str) -> str:
        """Return the hex PBKDF2-HMAC-SHA256 digest of *password* under *salt*."""
        data = hashlib.pbkdf2_hmac(
            "sha256",
            password.encode("utf-8"),
            salt.encode("utf-8"),
            self._PBKDF2_ITERATIONS,
        )
        return data.hex()

    def _hash_token(self, token: str) -> str:
        """Return the hex SHA-256 digest under which a session token is stored."""
        return hashlib.sha256(token.encode("utf-8")).hexdigest()

    def _set_password(self, c: sqlite3.Connection, user_id: int, new_password: str) -> None:
        """Store *new_password* for *user_id* under a freshly generated salt."""
        new_salt = secrets.token_hex(16)
        new_hash = self._hash_password(new_password, new_salt)
        c.execute(
            "UPDATE users SET password_hash=?, password_salt=? WHERE id=?",
            (new_hash, new_salt, user_id),
        )

    # ------------------------------------------------------------------
    # Users
    # ------------------------------------------------------------------

    def create_user(self, username: str, password: str) -> dict | None:
        """Insert a new user.

        Returns:
            ``{"id", "username"}`` for the new user, or ``None`` when the
            insert violates a constraint (e.g. the username is taken).

        Raises:
            RuntimeError: on any other SQLite error.
        """
        now = int(time.time())
        salt = secrets.token_hex(16)
        pwd_hash = self._hash_password(password, salt)
        try:
            with self._transaction() as c:
                cur = c.execute(
                    "INSERT INTO users(username, password_hash, password_salt, created_at) "
                    "VALUES (?, ?, ?, ?)",
                    (username, pwd_hash, salt, now),
                )
                uid = int(cur.lastrowid)
            return {"id": uid, "username": username}
        except sqlite3.IntegrityError:
            return None
        except sqlite3.Error as exc:
            raise RuntimeError(f"create_user_db_error: {exc}") from exc

    def verify_user(self, username: str, password: str) -> dict | None:
        """Check *password* for *username*.

        Returns:
            ``{"id", "username"}`` on success, ``None`` when the user does
            not exist or the password does not match.

        Raises:
            RuntimeError: when the lookup fails with a SQLite error.
        """
        try:
            with self._transaction() as c:
                row = c.execute(
                    "SELECT id, username, password_hash, password_salt "
                    "FROM users WHERE username=?",
                    (username,),
                ).fetchone()
        except sqlite3.Error as exc:
            raise RuntimeError(f"verify_user_db_error: {exc}") from exc
        if not row:
            return None
        calc = self._hash_password(password, row["password_salt"])
        # Constant-time comparison of the two hex digests.
        if not hmac.compare_digest(calc, row["password_hash"]):
            return None
        return {"id": int(row["id"]), "username": row["username"]}

    def change_password(self, user_id: int, old_password: str, new_password: str) -> bool:
        """Replace the password of *user_id* after verifying *old_password*.

        Returns ``False`` when the user does not exist or *old_password* is
        wrong. Existing sessions are left untouched.
        """
        with self._transaction() as c:
            row = c.execute(
                "SELECT password_hash, password_salt FROM users WHERE id=?",
                (user_id,),
            ).fetchone()
            if not row:
                return False
            calc_old = self._hash_password(old_password, row["password_salt"])
            if not hmac.compare_digest(calc_old, row["password_hash"]):
                return False
            self._set_password(c, user_id, new_password)
        return True

    def reset_password_by_username(self, username: str, new_password: str) -> bool:
        """Set a new password for *username* without checking the old one.

        The username is stripped of surrounding whitespace first. Returns
        ``False`` for an empty or unknown username.
        """
        uname = (username or "").strip()
        if not uname:
            return False
        with self._transaction() as c:
            row = c.execute("SELECT id FROM users WHERE username=?", (uname,)).fetchone()
            if not row:
                return False
            self._set_password(c, int(row["id"]), new_password)
        return True

    # ------------------------------------------------------------------
    # Sessions
    # ------------------------------------------------------------------

    def create_session(self, user_id: int, ttl_seconds: int = 7 * 24 * 3600) -> str:
        """Create a session for *user_id* and return its opaque token.

        Only the SHA-256 digest of the token is stored. *ttl_seconds* is
        clamped to a minimum of 600 seconds.
        """
        token = secrets.token_urlsafe(32)
        token_hash = self._hash_token(token)
        now = int(time.time())
        exp = now + max(self._MIN_SESSION_TTL, int(ttl_seconds))
        with self._transaction() as c:
            c.execute(
                "INSERT OR REPLACE INTO sessions(token_hash, user_id, expires_at, created_at) "
                "VALUES (?, ?, ?, ?)",
                (token_hash, user_id, exp, now),
            )
        return token

    def delete_session(self, token: str) -> None:
        """Delete the session identified by *token*; an empty token is ignored."""
        if not token:
            return
        with self._transaction() as c:
            c.execute("DELETE FROM sessions WHERE token_hash=?", (self._hash_token(token),))

    def delete_sessions_by_user(self, user_id: int) -> None:
        """Delete every session that belongs to *user_id*."""
        with self._transaction() as c:
            c.execute("DELETE FROM sessions WHERE user_id=?", (user_id,))

    def get_user_by_session(self, token: str) -> dict | None:
        """Return ``{"id", "username"}`` for a valid session token, else ``None``.

        As a side effect every expired session in the table is purged.
        """
        if not token:
            return None
        now = int(time.time())
        th = self._hash_token(token)
        with self._transaction() as c:
            c.execute("DELETE FROM sessions WHERE expires_at < ?", (now,))
            row = c.execute(
                """
                SELECT u.id, u.username
                FROM sessions s
                JOIN users u ON u.id=s.user_id
                WHERE s.token_hash=? AND s.expires_at>=?
                """,
                (th, now),
            ).fetchone()
        if not row:
            return None
        return {"id": int(row["id"]), "username": row["username"]}

    # ------------------------------------------------------------------
    # WeChat bindings
    # ------------------------------------------------------------------

    def _set_active_account(
        self,
        c: sqlite3.Connection,
        user_id: int,
        account_id: int,
        now: int,
        *,
        keep_existing: bool = False,
    ) -> None:
        """Upsert the active-account preference of *user_id*.

        With ``keep_existing=True`` an already stored, non-NULL active
        account wins and only ``updated_at`` is refreshed.
        """
        if keep_existing:
            new_value = (
                "COALESCE(user_prefs.active_wechat_account_id, "
                "excluded.active_wechat_account_id)"
            )
        else:
            new_value = "excluded.active_wechat_account_id"
        c.execute(
            f"""
            INSERT INTO user_prefs(user_id, active_wechat_account_id, updated_at)
            VALUES (?, ?, ?)
            ON CONFLICT(user_id) DO UPDATE SET
                active_wechat_account_id={new_value},
                updated_at=excluded.updated_at
            """,
            (user_id, account_id, now),
        )

    def _unique_account_name(
        self, c: sqlite3.Connection, user_id: int, base: str, now: int
    ) -> str:
        """Return *base*, or a suffixed variant not yet used by *user_id*.

        The first fallback is ``"<base>-<now % 1000>"``; if that is taken
        too, a running counter is appended until a free name is found.
        """
        taken = {
            r["account_name"]
            for r in c.execute(
                "SELECT account_name FROM wechat_accounts WHERE user_id=?",
                (user_id,),
            ).fetchall()
        }
        if base not in taken:
            return base
        stem = f"{base}-{now % 1000}"
        candidate = stem
        serial = 1
        while candidate in taken:
            serial += 1
            candidate = f"{stem}-{serial}"
        return candidate

    def save_wechat_binding(
        self,
        user_id: int,
        appid: str,
        secret: str,
        author: str = "",
        thumb_media_id: str = "",
        thumb_image_path: str = "",
    ) -> None:
        """Insert or update the legacy single binding of *user_id*.

        NOTE(review): this writes only the legacy ``wechat_bindings`` table,
        whereas ``get_wechat_binding`` reads the multi-account tables.
        Legacy rows are copied over by the schema initialisation, and only
        for users that do not yet own a default-named account -- confirm
        that callers expect this.
        """
        now = int(time.time())
        with self._transaction() as c:
            c.execute(
                """
                INSERT INTO wechat_bindings
                    (user_id, appid, secret, author,
                     thumb_media_id, thumb_image_path, updated_at)
                VALUES (?, ?, ?, ?, ?, ?, ?)
                ON CONFLICT(user_id) DO UPDATE SET
                    appid=excluded.appid,
                    secret=excluded.secret,
                    author=excluded.author,
                    thumb_media_id=excluded.thumb_media_id,
                    thumb_image_path=excluded.thumb_image_path,
                    updated_at=excluded.updated_at
                """,
                (user_id, appid, secret, author, thumb_media_id, thumb_image_path, now),
            )

    def get_wechat_binding(self, user_id: int) -> dict | None:
        """Alias of :meth:`get_active_wechat_binding`."""
        return self.get_active_wechat_binding(user_id)

    def list_wechat_bindings(self, user_id: int) -> list[dict]:
        """List the accounts of *user_id*, most recently updated first.

        The secret is deliberately not part of the result. Each entry has an
        ``active`` flag that is true for the account stored as active.
        """
        with self._transaction() as c:
            rows = c.execute(
                """
                SELECT id, account_name, appid, author,
                       thumb_media_id, thumb_image_path, updated_at
                FROM wechat_accounts
                WHERE user_id=?
                ORDER BY updated_at DESC, id DESC
                """,
                (user_id,),
            ).fetchall()
            pref = c.execute(
                "SELECT active_wechat_account_id FROM user_prefs WHERE user_id=?",
                (user_id,),
            ).fetchone()
        active_id = (
            int(pref["active_wechat_account_id"])
            if pref and pref["active_wechat_account_id"]
            else None
        )
        return [
            {
                "id": int(r["id"]),
                "account_name": r["account_name"] or "",
                "appid": r["appid"] or "",
                "author": r["author"] or "",
                "thumb_media_id": r["thumb_media_id"] or "",
                "thumb_image_path": r["thumb_image_path"] or "",
                "updated_at": int(r["updated_at"] or 0),
                "active": int(r["id"]) == active_id,
            }
            for r in rows
        ]

    def add_wechat_binding(
        self,
        user_id: int,
        account_name: str,
        appid: str,
        secret: str,
        author: str = "",
        thumb_media_id: str = "",
        thumb_image_path: str = "",
    ) -> dict:
        """Add an account for *user_id* and make it the active one.

        A blank *account_name* is replaced by a generated one. If the name
        is already used by this user, a suffix is appended until it is
        unique, so repeated adds with the same name never fail on the
        uniqueness constraint.

        Returns:
            ``{"id", "account_name"}`` with the name actually stored.
        """
        now = int(time.time())
        base = (account_name or "").strip() or f"公众号{now % 10000}"
        with self._transaction() as c:
            name = self._unique_account_name(c, user_id, base, now)
            cur = c.execute(
                """
                INSERT INTO wechat_accounts
                    (user_id, account_name, appid, secret, author,
                     thumb_media_id, thumb_image_path, updated_at)
                VALUES (?, ?, ?, ?, ?, ?, ?, ?)
                """,
                (user_id, name, appid, secret, author, thumb_media_id, thumb_image_path, now),
            )
            aid = int(cur.lastrowid)
            self._set_active_account(c, user_id, aid, now)
        return {"id": aid, "account_name": name}

    def switch_active_wechat_binding(self, user_id: int, account_id: int) -> bool:
        """Make *account_id* the active account of *user_id*.

        Returns ``False`` when the account does not exist or belongs to
        another user.
        """
        now = int(time.time())
        with self._transaction() as c:
            row = c.execute(
                "SELECT id FROM wechat_accounts WHERE id=? AND user_id=?",
                (account_id, user_id),
            ).fetchone()
            if not row:
                return False
            self._set_active_account(c, user_id, account_id, now)
        return True

    def get_active_wechat_binding(self, user_id: int) -> dict | None:
        """Return the active account of *user_id* including its secret.

        When no valid active account is stored, the most recently updated
        account is returned and recorded as the new active one. Returns
        ``None`` when the user has no accounts at all.
        """
        columns = (
            "id, account_name, appid, secret, author, "
            "thumb_media_id, thumb_image_path, updated_at"
        )
        with self._transaction() as c:
            pref = c.execute(
                "SELECT active_wechat_account_id FROM user_prefs WHERE user_id=?",
                (user_id,),
            ).fetchone()
            aid = (
                int(pref["active_wechat_account_id"])
                if pref and pref["active_wechat_account_id"]
                else None
            )
            row = None
            if aid:
                row = c.execute(
                    f"SELECT {columns} FROM wechat_accounts WHERE id=? AND user_id=?",
                    (aid, user_id),
                ).fetchone()
            if not row:
                # Stored preference is missing or stale: fall back to the
                # newest account and remember it.
                row = c.execute(
                    f"""
                    SELECT {columns}
                    FROM wechat_accounts
                    WHERE user_id=?
                    ORDER BY updated_at DESC, id DESC
                    LIMIT 1
                    """,
                    (user_id,),
                ).fetchone()
                if row:
                    self._set_active_account(c, user_id, int(row["id"]), int(time.time()))
        if not row:
            return None
        return {
            "id": int(row["id"]),
            "account_name": row["account_name"] or "",
            "appid": row["appid"] or "",
            "secret": row["secret"] or "",
            "author": row["author"] or "",
            "thumb_media_id": row["thumb_media_id"] or "",
            "thumb_image_path": row["thumb_image_path"] or "",
            "updated_at": int(row["updated_at"] or 0),
        }