513 lines
20 KiB
Python
513 lines
20 KiB
Python
from __future__ import annotations

import hashlib
import hmac
import secrets
import sqlite3
import time
from collections.abc import Iterator
from contextlib import contextmanager
from pathlib import Path


class UserStore:
|
|
def __init__(self, db_path: str) -> None:
|
|
self._db_path = db_path
|
|
p = Path(db_path)
|
|
if p.parent and not p.parent.exists():
|
|
p.parent.mkdir(parents=True, exist_ok=True)
|
|
self._init_db()
|
|
|
|
def _conn(self) -> sqlite3.Connection:
|
|
c = sqlite3.connect(self._db_path)
|
|
c.row_factory = sqlite3.Row
|
|
return c
|
|
|
|
def _init_db(self) -> None:
|
|
with self._conn() as c:
|
|
self._ensure_users_table(c)
|
|
self._ensure_sessions_table(c)
|
|
c.execute(
|
|
"""
|
|
CREATE TABLE IF NOT EXISTS wechat_bindings (
|
|
user_id INTEGER PRIMARY KEY,
|
|
appid TEXT NOT NULL,
|
|
secret TEXT NOT NULL,
|
|
author TEXT NOT NULL DEFAULT '',
|
|
thumb_media_id TEXT NOT NULL DEFAULT '',
|
|
thumb_image_path TEXT NOT NULL DEFAULT '',
|
|
updated_at INTEGER NOT NULL,
|
|
FOREIGN KEY(user_id) REFERENCES users(id)
|
|
)
|
|
"""
|
|
)
|
|
c.execute(
|
|
"""
|
|
CREATE TABLE IF NOT EXISTS wechat_accounts (
|
|
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
|
user_id INTEGER NOT NULL,
|
|
account_name TEXT NOT NULL,
|
|
appid TEXT NOT NULL,
|
|
secret TEXT NOT NULL,
|
|
author TEXT NOT NULL DEFAULT '',
|
|
thumb_media_id TEXT NOT NULL DEFAULT '',
|
|
thumb_image_path TEXT NOT NULL DEFAULT '',
|
|
updated_at INTEGER NOT NULL,
|
|
UNIQUE(user_id, account_name),
|
|
FOREIGN KEY(user_id) REFERENCES users(id)
|
|
)
|
|
"""
|
|
)
|
|
c.execute(
|
|
"""
|
|
CREATE TABLE IF NOT EXISTS user_prefs (
|
|
user_id INTEGER PRIMARY KEY,
|
|
active_wechat_account_id INTEGER,
|
|
updated_at INTEGER NOT NULL,
|
|
FOREIGN KEY(user_id) REFERENCES users(id)
|
|
)
|
|
"""
|
|
)
|
|
# 兼容历史单绑定结构,自动迁移为默认账号
|
|
rows = c.execute(
|
|
"SELECT user_id, appid, secret, author, thumb_media_id, thumb_image_path, updated_at FROM wechat_bindings"
|
|
).fetchall()
|
|
for r in rows:
|
|
now = int(time.time())
|
|
cur = c.execute(
|
|
"""
|
|
INSERT OR IGNORE INTO wechat_accounts
|
|
(user_id, account_name, appid, secret, author, thumb_media_id, thumb_image_path, updated_at)
|
|
VALUES (?, '默认公众号', ?, ?, ?, ?, ?, ?)
|
|
""",
|
|
(
|
|
int(r["user_id"]),
|
|
r["appid"] or "",
|
|
r["secret"] or "",
|
|
r["author"] or "",
|
|
r["thumb_media_id"] or "",
|
|
r["thumb_image_path"] or "",
|
|
int(r["updated_at"] or now),
|
|
),
|
|
)
|
|
if cur.rowcount:
|
|
aid = int(
|
|
c.execute(
|
|
"SELECT id FROM wechat_accounts WHERE user_id=? AND account_name='默认公众号'",
|
|
(int(r["user_id"]),),
|
|
).fetchone()["id"]
|
|
)
|
|
c.execute(
|
|
"""
|
|
INSERT INTO user_prefs(user_id, active_wechat_account_id, updated_at)
|
|
VALUES (?, ?, ?)
|
|
ON CONFLICT(user_id) DO UPDATE SET
|
|
active_wechat_account_id=COALESCE(user_prefs.active_wechat_account_id, excluded.active_wechat_account_id),
|
|
updated_at=excluded.updated_at
|
|
""",
|
|
(int(r["user_id"]), aid, now),
|
|
)
|
|
|
|
def _table_columns(self, c: sqlite3.Connection, table_name: str) -> set[str]:
|
|
rows = c.execute(f"PRAGMA table_info({table_name})").fetchall()
|
|
return {str(r["name"]) for r in rows}
|
|
|
|
def _ensure_users_table(self, c: sqlite3.Connection) -> None:
|
|
required = {"id", "username", "password_hash", "password_salt", "created_at"}
|
|
c.execute(
|
|
"""
|
|
CREATE TABLE IF NOT EXISTS users (
|
|
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
|
username TEXT NOT NULL UNIQUE,
|
|
password_hash TEXT NOT NULL,
|
|
password_salt TEXT NOT NULL,
|
|
created_at INTEGER NOT NULL
|
|
)
|
|
"""
|
|
)
|
|
cols = self._table_columns(c, "users")
|
|
if required.issubset(cols):
|
|
return
|
|
|
|
now = int(time.time())
|
|
c.execute("PRAGMA foreign_keys=OFF")
|
|
c.execute("DROP TABLE IF EXISTS users_new")
|
|
c.execute(
|
|
"""
|
|
CREATE TABLE users_new (
|
|
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
|
username TEXT NOT NULL UNIQUE,
|
|
password_hash TEXT NOT NULL,
|
|
password_salt TEXT NOT NULL,
|
|
created_at INTEGER NOT NULL
|
|
)
|
|
"""
|
|
)
|
|
|
|
if {"username", "password_hash", "password_salt"}.issubset(cols):
|
|
if "created_at" in cols:
|
|
c.execute(
|
|
"""
|
|
INSERT OR IGNORE INTO users_new(id, username, password_hash, password_salt, created_at)
|
|
SELECT id, username, password_hash, password_salt, COALESCE(created_at, ?)
|
|
FROM users
|
|
""",
|
|
(now,),
|
|
)
|
|
else:
|
|
c.execute(
|
|
"""
|
|
INSERT OR IGNORE INTO users_new(id, username, password_hash, password_salt, created_at)
|
|
SELECT id, username, password_hash, password_salt, ?
|
|
FROM users
|
|
""",
|
|
(now,),
|
|
)
|
|
elif {"username", "password"}.issubset(cols):
|
|
if "created_at" in cols:
|
|
rows = c.execute("SELECT id, username, password, created_at FROM users").fetchall()
|
|
else:
|
|
rows = c.execute("SELECT id, username, password FROM users").fetchall()
|
|
for r in rows:
|
|
username = (r["username"] or "").strip()
|
|
raw_pwd = str(r["password"] or "")
|
|
if not username or not raw_pwd:
|
|
continue
|
|
salt = secrets.token_hex(16)
|
|
pwd_hash = self._hash_password(raw_pwd, salt)
|
|
created_at = int(r["created_at"]) if "created_at" in r.keys() and r["created_at"] else now
|
|
c.execute(
|
|
"""
|
|
INSERT OR IGNORE INTO users_new(id, username, password_hash, password_salt, created_at)
|
|
VALUES (?, ?, ?, ?, ?)
|
|
""",
|
|
(int(r["id"]), username, pwd_hash, salt, created_at),
|
|
)
|
|
|
|
c.execute("DROP TABLE users")
|
|
c.execute("ALTER TABLE users_new RENAME TO users")
|
|
c.execute("PRAGMA foreign_keys=ON")
|
|
|
|
def _ensure_sessions_table(self, c: sqlite3.Connection) -> None:
|
|
required = {"token_hash", "user_id", "expires_at", "created_at"}
|
|
c.execute(
|
|
"""
|
|
CREATE TABLE IF NOT EXISTS sessions (
|
|
token_hash TEXT PRIMARY KEY,
|
|
user_id INTEGER NOT NULL,
|
|
expires_at INTEGER NOT NULL,
|
|
created_at INTEGER NOT NULL,
|
|
FOREIGN KEY(user_id) REFERENCES users(id)
|
|
)
|
|
"""
|
|
)
|
|
cols = self._table_columns(c, "sessions")
|
|
if required.issubset(cols):
|
|
return
|
|
c.execute("DROP TABLE IF EXISTS sessions")
|
|
c.execute(
|
|
"""
|
|
CREATE TABLE sessions (
|
|
token_hash TEXT PRIMARY KEY,
|
|
user_id INTEGER NOT NULL,
|
|
expires_at INTEGER NOT NULL,
|
|
created_at INTEGER NOT NULL,
|
|
FOREIGN KEY(user_id) REFERENCES users(id)
|
|
)
|
|
"""
|
|
)
|
|
|
|
def _hash_password(self, password: str, salt: str) -> str:
|
|
data = hashlib.pbkdf2_hmac("sha256", password.encode("utf-8"), salt.encode("utf-8"), 120_000)
|
|
return data.hex()
|
|
|
|
def _hash_token(self, token: str) -> str:
|
|
return hashlib.sha256(token.encode("utf-8")).hexdigest()
|
|
|
|
def create_user(self, username: str, password: str) -> dict | None:
|
|
now = int(time.time())
|
|
salt = secrets.token_hex(16)
|
|
pwd_hash = self._hash_password(password, salt)
|
|
try:
|
|
with self._conn() as c:
|
|
cur = c.execute(
|
|
"INSERT INTO users(username, password_hash, password_salt, created_at) VALUES (?, ?, ?, ?)",
|
|
(username, pwd_hash, salt, now),
|
|
)
|
|
uid = int(cur.lastrowid)
|
|
return {"id": uid, "username": username}
|
|
except sqlite3.IntegrityError:
|
|
return None
|
|
except sqlite3.Error as exc:
|
|
raise RuntimeError(f"create_user_db_error: {exc}") from exc
|
|
|
|
def verify_user(self, username: str, password: str) -> dict | None:
|
|
try:
|
|
with self._conn() as c:
|
|
row = c.execute(
|
|
"SELECT id, username, password_hash, password_salt FROM users WHERE username=?",
|
|
(username,),
|
|
).fetchone()
|
|
except sqlite3.Error as exc:
|
|
raise RuntimeError(f"verify_user_db_error: {exc}") from exc
|
|
if not row:
|
|
return None
|
|
calc = self._hash_password(password, row["password_salt"])
|
|
if not hmac.compare_digest(calc, row["password_hash"]):
|
|
return None
|
|
return {"id": int(row["id"]), "username": row["username"]}
|
|
|
|
def change_password(self, user_id: int, old_password: str, new_password: str) -> bool:
|
|
with self._conn() as c:
|
|
row = c.execute(
|
|
"SELECT password_hash, password_salt FROM users WHERE id=?",
|
|
(user_id,),
|
|
).fetchone()
|
|
if not row:
|
|
return False
|
|
calc_old = self._hash_password(old_password, row["password_salt"])
|
|
if not hmac.compare_digest(calc_old, row["password_hash"]):
|
|
return False
|
|
new_salt = secrets.token_hex(16)
|
|
new_hash = self._hash_password(new_password, new_salt)
|
|
c.execute(
|
|
"UPDATE users SET password_hash=?, password_salt=? WHERE id=?",
|
|
(new_hash, new_salt, user_id),
|
|
)
|
|
return True
|
|
|
|
def reset_password_by_username(self, username: str, new_password: str) -> bool:
|
|
uname = (username or "").strip()
|
|
if not uname:
|
|
return False
|
|
with self._conn() as c:
|
|
row = c.execute("SELECT id FROM users WHERE username=?", (uname,)).fetchone()
|
|
if not row:
|
|
return False
|
|
new_salt = secrets.token_hex(16)
|
|
new_hash = self._hash_password(new_password, new_salt)
|
|
c.execute(
|
|
"UPDATE users SET password_hash=?, password_salt=? WHERE id=?",
|
|
(new_hash, new_salt, int(row["id"])),
|
|
)
|
|
return True
|
|
|
|
def create_session(self, user_id: int, ttl_seconds: int = 7 * 24 * 3600) -> str:
|
|
token = secrets.token_urlsafe(32)
|
|
token_hash = self._hash_token(token)
|
|
now = int(time.time())
|
|
exp = now + max(600, int(ttl_seconds))
|
|
with self._conn() as c:
|
|
c.execute(
|
|
"INSERT OR REPLACE INTO sessions(token_hash, user_id, expires_at, created_at) VALUES (?, ?, ?, ?)",
|
|
(token_hash, user_id, exp, now),
|
|
)
|
|
return token
|
|
|
|
def delete_session(self, token: str) -> None:
|
|
if not token:
|
|
return
|
|
with self._conn() as c:
|
|
c.execute("DELETE FROM sessions WHERE token_hash=?", (self._hash_token(token),))
|
|
|
|
def delete_sessions_by_user(self, user_id: int) -> None:
|
|
with self._conn() as c:
|
|
c.execute("DELETE FROM sessions WHERE user_id=?", (user_id,))
|
|
|
|
def get_user_by_session(self, token: str) -> dict | None:
|
|
if not token:
|
|
return None
|
|
now = int(time.time())
|
|
th = self._hash_token(token)
|
|
with self._conn() as c:
|
|
c.execute("DELETE FROM sessions WHERE expires_at < ?", (now,))
|
|
row = c.execute(
|
|
"""
|
|
SELECT u.id, u.username
|
|
FROM sessions s
|
|
JOIN users u ON u.id=s.user_id
|
|
WHERE s.token_hash=? AND s.expires_at>=?
|
|
""",
|
|
(th, now),
|
|
).fetchone()
|
|
if not row:
|
|
return None
|
|
return {"id": int(row["id"]), "username": row["username"]}
|
|
|
|
def save_wechat_binding(
|
|
self,
|
|
user_id: int,
|
|
appid: str,
|
|
secret: str,
|
|
author: str = "",
|
|
thumb_media_id: str = "",
|
|
thumb_image_path: str = "",
|
|
) -> None:
|
|
now = int(time.time())
|
|
with self._conn() as c:
|
|
c.execute(
|
|
"""
|
|
INSERT INTO wechat_bindings(user_id, appid, secret, author, thumb_media_id, thumb_image_path, updated_at)
|
|
VALUES (?, ?, ?, ?, ?, ?, ?)
|
|
ON CONFLICT(user_id) DO UPDATE SET
|
|
appid=excluded.appid,
|
|
secret=excluded.secret,
|
|
author=excluded.author,
|
|
thumb_media_id=excluded.thumb_media_id,
|
|
thumb_image_path=excluded.thumb_image_path,
|
|
updated_at=excluded.updated_at
|
|
""",
|
|
(user_id, appid, secret, author, thumb_media_id, thumb_image_path, now),
|
|
)
|
|
|
|
def get_wechat_binding(self, user_id: int) -> dict | None:
|
|
return self.get_active_wechat_binding(user_id)
|
|
|
|
def list_wechat_bindings(self, user_id: int) -> list[dict]:
|
|
with self._conn() as c:
|
|
rows = c.execute(
|
|
"""
|
|
SELECT id, account_name, appid, author, thumb_media_id, thumb_image_path, updated_at
|
|
FROM wechat_accounts
|
|
WHERE user_id=?
|
|
ORDER BY updated_at DESC, id DESC
|
|
""",
|
|
(user_id,),
|
|
).fetchall()
|
|
pref = c.execute(
|
|
"SELECT active_wechat_account_id FROM user_prefs WHERE user_id=?",
|
|
(user_id,),
|
|
).fetchone()
|
|
active_id = int(pref["active_wechat_account_id"]) if pref and pref["active_wechat_account_id"] else None
|
|
out: list[dict] = []
|
|
for r in rows:
|
|
out.append(
|
|
{
|
|
"id": int(r["id"]),
|
|
"account_name": r["account_name"] or "",
|
|
"appid": r["appid"] or "",
|
|
"author": r["author"] or "",
|
|
"thumb_media_id": r["thumb_media_id"] or "",
|
|
"thumb_image_path": r["thumb_image_path"] or "",
|
|
"updated_at": int(r["updated_at"] or 0),
|
|
"active": int(r["id"]) == active_id,
|
|
}
|
|
)
|
|
return out
|
|
|
|
def add_wechat_binding(
|
|
self,
|
|
user_id: int,
|
|
account_name: str,
|
|
appid: str,
|
|
secret: str,
|
|
author: str = "",
|
|
thumb_media_id: str = "",
|
|
thumb_image_path: str = "",
|
|
) -> dict:
|
|
now = int(time.time())
|
|
name = account_name.strip() or f"公众号{now % 10000}"
|
|
with self._conn() as c:
|
|
try:
|
|
cur = c.execute(
|
|
"""
|
|
INSERT INTO wechat_accounts
|
|
(user_id, account_name, appid, secret, author, thumb_media_id, thumb_image_path, updated_at)
|
|
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
|
|
""",
|
|
(user_id, name, appid, secret, author, thumb_media_id, thumb_image_path, now),
|
|
)
|
|
except sqlite3.IntegrityError:
|
|
name = f"{name}-{now % 1000}"
|
|
cur = c.execute(
|
|
"""
|
|
INSERT INTO wechat_accounts
|
|
(user_id, account_name, appid, secret, author, thumb_media_id, thumb_image_path, updated_at)
|
|
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
|
|
""",
|
|
(user_id, name, appid, secret, author, thumb_media_id, thumb_image_path, now),
|
|
)
|
|
aid = int(cur.lastrowid)
|
|
c.execute(
|
|
"""
|
|
INSERT INTO user_prefs(user_id, active_wechat_account_id, updated_at)
|
|
VALUES (?, ?, ?)
|
|
ON CONFLICT(user_id) DO UPDATE SET
|
|
active_wechat_account_id=excluded.active_wechat_account_id,
|
|
updated_at=excluded.updated_at
|
|
""",
|
|
(user_id, aid, now),
|
|
)
|
|
return {"id": aid, "account_name": name}
|
|
|
|
def switch_active_wechat_binding(self, user_id: int, account_id: int) -> bool:
|
|
now = int(time.time())
|
|
with self._conn() as c:
|
|
row = c.execute(
|
|
"SELECT id FROM wechat_accounts WHERE id=? AND user_id=?",
|
|
(account_id, user_id),
|
|
).fetchone()
|
|
if not row:
|
|
return False
|
|
c.execute(
|
|
"""
|
|
INSERT INTO user_prefs(user_id, active_wechat_account_id, updated_at)
|
|
VALUES (?, ?, ?)
|
|
ON CONFLICT(user_id) DO UPDATE SET
|
|
active_wechat_account_id=excluded.active_wechat_account_id,
|
|
updated_at=excluded.updated_at
|
|
""",
|
|
(user_id, account_id, now),
|
|
)
|
|
return True
|
|
|
|
def get_active_wechat_binding(self, user_id: int) -> dict | None:
|
|
with self._conn() as c:
|
|
pref = c.execute(
|
|
"SELECT active_wechat_account_id FROM user_prefs WHERE user_id=?",
|
|
(user_id,),
|
|
).fetchone()
|
|
aid = int(pref["active_wechat_account_id"]) if pref and pref["active_wechat_account_id"] else None
|
|
row = None
|
|
if aid:
|
|
row = c.execute(
|
|
"""
|
|
SELECT id, account_name, appid, secret, author, thumb_media_id, thumb_image_path, updated_at
|
|
FROM wechat_accounts
|
|
WHERE id=? AND user_id=?
|
|
""",
|
|
(aid, user_id),
|
|
).fetchone()
|
|
if not row:
|
|
row = c.execute(
|
|
"""
|
|
SELECT id, account_name, appid, secret, author, thumb_media_id, thumb_image_path, updated_at
|
|
FROM wechat_accounts
|
|
WHERE user_id=?
|
|
ORDER BY updated_at DESC, id DESC
|
|
LIMIT 1
|
|
""",
|
|
(user_id,),
|
|
).fetchone()
|
|
if row:
|
|
c.execute(
|
|
"""
|
|
INSERT INTO user_prefs(user_id, active_wechat_account_id, updated_at)
|
|
VALUES (?, ?, ?)
|
|
ON CONFLICT(user_id) DO UPDATE SET
|
|
active_wechat_account_id=excluded.active_wechat_account_id,
|
|
updated_at=excluded.updated_at
|
|
""",
|
|
(user_id, int(row["id"]), int(time.time())),
|
|
)
|
|
if not row:
|
|
return None
|
|
return {
|
|
"id": int(row["id"]),
|
|
"account_name": row["account_name"] or "",
|
|
"appid": row["appid"] or "",
|
|
"secret": row["secret"] or "",
|
|
"author": row["author"] or "",
|
|
"thumb_media_id": row["thumb_media_id"] or "",
|
|
"thumb_image_path": row["thumb_image_path"] or "",
|
|
"updated_at": int(row["updated_at"] or 0),
|
|
}
|