# -*- coding: utf-8 -*-
"""SQLite database: schema initialization and connection helpers.

The data directory comes from the ``DATA_DIR`` environment variable (so it
can be mounted from the host) and defaults to a ``data`` directory next to
this file. Tables that are still empty are seeded once from same-named
legacy JSON files found in that directory.
"""
import json
import logging
import os
import sqlite3
from typing import Optional

_logger = logging.getLogger(__name__)

_DATA_DIR = os.getenv("DATA_DIR") or os.path.join(os.path.dirname(__file__), "data")
_DB_PATH = os.path.join(_DATA_DIR, "wechat.db")

# DDL executed (in order) by init_schema; every statement is idempotent.
_SCHEMA_STATEMENTS = (
    # Customer profiles
    """
    CREATE TABLE IF NOT EXISTS customers (
        id TEXT PRIMARY KEY,
        key TEXT NOT NULL,
        wxid TEXT NOT NULL,
        remark_name TEXT,
        region TEXT,
        age TEXT,
        gender TEXT,
        level TEXT,
        tags TEXT
    )
    """,
    "CREATE INDEX IF NOT EXISTS idx_customers_key ON customers(key)",
    # Scheduled greeting tasks
    """
    CREATE TABLE IF NOT EXISTS greeting_tasks (
        id TEXT PRIMARY KEY,
        key TEXT NOT NULL,
        name TEXT,
        send_time TEXT,
        customer_tags TEXT,
        template TEXT,
        use_qwen INTEGER DEFAULT 0,
        enabled INTEGER DEFAULT 1,
        executed_at TEXT
    )
    """,
    "CREATE INDEX IF NOT EXISTS idx_greeting_tasks_key ON greeting_tasks(key)",
    # Product tags
    """
    CREATE TABLE IF NOT EXISTS product_tags (
        id TEXT PRIMARY KEY,
        key TEXT NOT NULL,
        name TEXT
    )
    """,
    # Push groups
    """
    CREATE TABLE IF NOT EXISTS push_groups (
        id TEXT PRIMARY KEY,
        key TEXT NOT NULL,
        name TEXT,
        customer_ids TEXT,
        tag_ids TEXT
    )
    """,
    # Push tasks
    """
    CREATE TABLE IF NOT EXISTS push_tasks (
        id TEXT PRIMARY KEY,
        key TEXT NOT NULL,
        product_tag_id TEXT,
        group_id TEXT,
        content TEXT,
        send_at TEXT,
        status TEXT,
        created_at TEXT
    )
    """,
    # Synced messages (pulled over WS + records of sent messages)
    """
    CREATE TABLE IF NOT EXISTS sync_messages (
        id INTEGER PRIMARY KEY AUTOINCREMENT,
        key TEXT NOT NULL,
        create_time INTEGER DEFAULT 0,
        payload TEXT
    )
    """,
    "CREATE INDEX IF NOT EXISTS idx_sync_messages_key ON sync_messages(key)",
    # Model configuration
    """
    CREATE TABLE IF NOT EXISTS models (
        id TEXT PRIMARY KEY,
        name TEXT,
        provider TEXT,
        api_key TEXT,
        base_url TEXT,
        model_name TEXT,
        is_current INTEGER DEFAULT 0
    )
    """,
    # AI reply configuration (whitelist / super admins)
    """
    CREATE TABLE IF NOT EXISTS ai_reply_config (
        key TEXT PRIMARY KEY,
        super_admin_wxids TEXT,
        whitelist_wxids TEXT
    )
    """,
)

# Legacy JSON files to import: (table, file name, columns copied from each
# JSON object). List/dict values (e.g. tags, customer_ids, whitelist_wxids)
# are stored as JSON text; booleans as 1/0.
_JSON_MIGRATIONS = (
    ("customers", "customers.json",
     ("id", "key", "wxid", "remark_name", "region", "age", "gender", "level", "tags")),
    ("greeting_tasks", "greeting_tasks.json",
     ("id", "key", "name", "send_time", "customer_tags", "template",
      "use_qwen", "enabled", "executed_at")),
    ("product_tags", "product_tags.json", ("id", "key", "name")),
    ("push_groups", "push_groups.json", ("id", "key", "name", "customer_ids", "tag_ids")),
    ("push_tasks", "push_tasks.json",
     ("id", "key", "product_tag_id", "group_id", "content", "send_at", "status", "created_at")),
    ("models", "models.json",
     ("id", "name", "provider", "api_key", "base_url", "model_name", "is_current")),
    ("ai_reply_config", "ai_reply_config.json",
     ("key", "super_admin_wxids", "whitelist_wxids")),
)


def get_db_path() -> str:
    """Return the path of the SQLite database file."""
    return _DB_PATH


def get_conn() -> sqlite3.Connection:
    """Open a connection to the database file, creating the data directory if needed.

    Rows are returned as ``sqlite3.Row``. ``check_same_thread=False`` allows the
    connection to be used from threads other than the creating one; callers
    must serialize such use themselves.
    """
    os.makedirs(_DATA_DIR, exist_ok=True)
    conn = sqlite3.connect(_DB_PATH, check_same_thread=False)
    conn.row_factory = sqlite3.Row
    return conn


def init_schema(conn: sqlite3.Connection, data_dir: Optional[str] = None) -> None:
    """Create all tables and indexes if missing, then seed empty tables from JSON.

    Args:
        conn: open connection to initialize.
        data_dir: directory searched for legacy JSON files; defaults to the
            module data directory (``DATA_DIR``).
    """
    cur = conn.cursor()
    for statement in _SCHEMA_STATEMENTS:
        cur.execute(statement)
    conn.commit()
    _migrate_json_if_needed(conn, data_dir)


def _migrate_json_if_needed(conn: sqlite3.Connection, data_dir: Optional[str] = None) -> None:
    """Seed each still-empty table once from its same-named JSON file, if present.

    Best-effort: missing files are ignored; unreadable files and rows that
    cannot be stored are logged and skipped instead of aborting initialization.
    """
    base = _DATA_DIR if data_dir is None else data_dir
    cur = conn.cursor()
    for table, filename, columns in _JSON_MIGRATIONS:
        if _table_is_empty(cur, table):
            _migrate_table(cur, table, os.path.join(base, filename), columns)
    if _table_is_empty(cur, "sync_messages"):
        _migrate_sync_messages(cur, os.path.join(base, "sync_messages.json"))
    conn.commit()


def _table_is_empty(cur: sqlite3.Cursor, table: str) -> bool:
    """Return True if *table* has no rows (*table* is an internal, trusted name)."""
    cur.execute(f"SELECT 1 FROM {table} LIMIT 1")
    return cur.fetchone() is None


def _load_json_list(path: str) -> list:
    """Return the JSON array stored at *path*; [] if missing, unreadable or not an array."""
    if not os.path.isfile(path):
        return []
    try:
        # utf-8-sig reads plain UTF-8 too, and also accepts a leading BOM.
        with open(path, "r", encoding="utf-8-sig") as f:
            data = json.load(f)
    except (OSError, ValueError) as exc:  # JSONDecodeError / UnicodeDecodeError are ValueErrors
        _logger.warning("Skipping JSON migration from %s: %s", path, exc)
        return []
    return data if isinstance(data, list) else []


def _to_db_value(value):
    """Convert a decoded JSON value into something sqlite3 can bind."""
    if isinstance(value, (list, dict)):
        return json.dumps(value, ensure_ascii=False)
    if isinstance(value, bool):
        return 1 if value else 0
    return value


def _insert_row(cur: sqlite3.Cursor, sql: str, params, path: str) -> None:
    """Execute one INSERT; log and skip rows whose values SQLite rejects."""
    try:
        cur.execute(sql, params)
    except (sqlite3.Error, OverflowError) as exc:  # bad types, constraints, ints > 64 bit
        _logger.warning("Skipping unstorable row from %s: %s", path, exc)


def _migrate_table(cur: sqlite3.Cursor, table: str, path: str, columns) -> None:
    """Copy the given *columns* of every JSON object in *path* into *table*."""
    rows = _load_json_list(path)
    if not rows:
        return
    placeholders = ",".join("?" * len(columns))
    sql = f"INSERT OR IGNORE INTO {table} ({','.join(columns)}) VALUES ({placeholders})"
    for row in rows:
        if isinstance(row, dict):
            _insert_row(cur, sql, [_to_db_value(row.get(c)) for c in columns], path)


def _create_time(value) -> int:
    """Return a message ``CreateTime`` as int; non-numeric or non-finite values become 0."""
    if not isinstance(value, (int, float)):
        return 0
    try:
        return int(value or 0)
    except (ValueError, OverflowError):  # NaN / +-Infinity accepted by json.load
        return 0


def _migrate_sync_messages(cur: sqlite3.Cursor, path: str) -> None:
    """Import messages from *path*, keeping each full JSON object as the payload."""
    sql = "INSERT INTO sync_messages (key, create_time, payload) VALUES (?,?,?)"
    for row in _load_json_list(path):
        if not isinstance(row, dict):
            continue
        params = (
            row.get("key", ""),
            _create_time(row.get("CreateTime")),
            json.dumps(row, ensure_ascii=False),
        )
        _insert_row(cur, sql, params, path)


def _conn():
    """Open a connection and make sure the schema exists; close it if init fails."""
    c = get_conn()
    try:
        init_schema(c)
    except BaseException:
        c.close()
        raise
    return c