# -*- coding: utf-8 -*-
"""SQLite database layer: schema initialization and connections.

The data directory is selected by the DATA_DIR environment variable (so it
can be mounted from the host machine); it defaults to ./data next to this
module.
"""
import json
import os
import sqlite3

# Data directory: DATA_DIR env var wins; falls back to ./data beside this file.
_DATA_DIR = os.getenv("DATA_DIR") or os.path.join(os.path.dirname(__file__), "data")
# Single SQLite file that holds every application table.
_DB_PATH = os.path.join(_DATA_DIR, "wechat.db")
def get_db_path() -> str:
    """Return the filesystem path of the SQLite database file."""
    return _DB_PATH
def get_conn() -> sqlite3.Connection:
    """Open a connection to the application database.

    Ensures the data directory exists first. The connection is opened with
    check_same_thread=False (shared across threads by the caller) and uses
    sqlite3.Row so columns are accessible by name.
    """
    os.makedirs(_DATA_DIR, exist_ok=True)
    connection = sqlite3.connect(_DB_PATH, check_same_thread=False)
    connection.row_factory = sqlite3.Row
    return connection
def init_schema(conn: sqlite3.Connection) -> None:
    """Create all application tables and indexes if they do not exist.

    Idempotent: every statement uses IF NOT EXISTS. After the DDL is
    committed, a one-time migration from legacy JSON files is attempted.
    """
    ddl_statements = (
        # Customer profiles
        """
        CREATE TABLE IF NOT EXISTS customers (
            id TEXT PRIMARY KEY,
            key TEXT NOT NULL,
            wxid TEXT NOT NULL,
            remark_name TEXT,
            region TEXT,
            age TEXT,
            gender TEXT,
            level TEXT,
            tags TEXT
        )
        """,
        "CREATE INDEX IF NOT EXISTS idx_customers_key ON customers(key)",
        # Scheduled greeting tasks
        """
        CREATE TABLE IF NOT EXISTS greeting_tasks (
            id TEXT PRIMARY KEY,
            key TEXT NOT NULL,
            name TEXT,
            send_time TEXT,
            customer_tags TEXT,
            template TEXT,
            use_qwen INTEGER DEFAULT 0,
            enabled INTEGER DEFAULT 1,
            executed_at TEXT
        )
        """,
        "CREATE INDEX IF NOT EXISTS idx_greeting_tasks_key ON greeting_tasks(key)",
        # Product tags
        """
        CREATE TABLE IF NOT EXISTS product_tags (
            id TEXT PRIMARY KEY,
            key TEXT NOT NULL,
            name TEXT
        )
        """,
        # Push groups
        """
        CREATE TABLE IF NOT EXISTS push_groups (
            id TEXT PRIMARY KEY,
            key TEXT NOT NULL,
            name TEXT,
            customer_ids TEXT,
            tag_ids TEXT
        )
        """,
        # Push tasks
        """
        CREATE TABLE IF NOT EXISTS push_tasks (
            id TEXT PRIMARY KEY,
            key TEXT NOT NULL,
            product_tag_id TEXT,
            group_id TEXT,
            content TEXT,
            send_at TEXT,
            status TEXT,
            created_at TEXT
        )
        """,
        # Synced messages (pulled over WS, plus records of messages we sent)
        """
        CREATE TABLE IF NOT EXISTS sync_messages (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            key TEXT NOT NULL,
            create_time INTEGER DEFAULT 0,
            payload TEXT
        )
        """,
        "CREATE INDEX IF NOT EXISTS idx_sync_messages_key ON sync_messages(key)",
        # Raw callback bodies, persisted for auditing and statistics
        """
        CREATE TABLE IF NOT EXISTS callback_log (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            key TEXT NOT NULL,
            received_at TEXT NOT NULL,
            raw_body TEXT
        )
        """,
        "CREATE INDEX IF NOT EXISTS idx_callback_log_key ON callback_log(key)",
        "CREATE INDEX IF NOT EXISTS idx_callback_log_received ON callback_log(received_at)",
        # Model configuration
        """
        CREATE TABLE IF NOT EXISTS models (
            id TEXT PRIMARY KEY,
            name TEXT,
            provider TEXT,
            api_key TEXT,
            base_url TEXT,
            model_name TEXT,
            is_current INTEGER DEFAULT 0
        )
        """,
        # AI auto-reply configuration (whitelist / super admins)
        """
        CREATE TABLE IF NOT EXISTS ai_reply_config (
            key TEXT PRIMARY KEY,
            super_admin_wxids TEXT,
            whitelist_wxids TEXT
        )
        """,
    )
    cur = conn.cursor()
    for statement in ddl_statements:
        cur.execute(statement)
    conn.commit()
    _migrate_json_if_needed(conn)
def _migrate_json_if_needed(conn: sqlite3.Connection) -> None:
    """One-time migration: if a table is empty and a JSON file of the same
    name exists in the data directory, import its rows once.

    Best-effort by design: unreadable/invalid JSON files are skipped
    silently, and duplicate ids are ignored via INSERT OR IGNORE.
    """
    cur = conn.cursor()
    # (table, filename, columns, json_columns)
    # json_columns hold list/dict values that must be serialized to JSON text.
    tables_files = [
        ("customers", "customers.json", ["id", "key", "wxid", "remark_name", "region", "age", "gender", "level", "tags"], ["tags"]),
        ("greeting_tasks", "greeting_tasks.json", ["id", "key", "name", "send_time", "customer_tags", "template", "use_qwen", "enabled", "executed_at"], ["customer_tags"]),
        ("product_tags", "product_tags.json", ["id", "key", "name"], []),
        ("push_groups", "push_groups.json", ["id", "key", "name", "customer_ids", "tag_ids"], ["customer_ids", "tag_ids"]),
        ("push_tasks", "push_tasks.json", ["id", "key", "product_tag_id", "group_id", "content", "send_at", "status", "created_at"], []),
        ("models", "models.json", ["id", "name", "provider", "api_key", "base_url", "model_name", "is_current"], []),
        ("ai_reply_config", "ai_reply_config.json", ["key", "super_admin_wxids", "whitelist_wxids"], ["super_admin_wxids", "whitelist_wxids"]),
    ]
    for table, filename, columns, json_cols in tables_files:
        json_cols_set = set(json_cols)
        # Only migrate into a completely empty table — never merge.
        cur.execute(f"SELECT COUNT(*) FROM {table}")
        if cur.fetchone()[0] > 0:
            continue
        path = os.path.join(_DATA_DIR, filename)
        if not os.path.isfile(path):
            continue
        try:
            with open(path, "r", encoding="utf-8") as f:
                rows = json.load(f)
        except Exception:
            # Unreadable or malformed JSON: skip this table entirely.
            continue
        if not rows:
            continue
        for r in rows:
            if not isinstance(r, dict):
                continue
            vals = []
            for c in columns:
                v = r.get(c)
                if c in json_cols_set and isinstance(v, (list, dict)):
                    # Store structured values as compact JSON text.
                    v = json.dumps(v, ensure_ascii=False)
                elif isinstance(v, bool):
                    # SQLite columns are INTEGER flags (e.g. use_qwen, enabled).
                    v = 1 if v else 0
                vals.append(v)
            placeholders = ",".join("?" * len(columns))
            # Table/column names come from the fixed list above, never from input.
            cur.execute(f"INSERT OR IGNORE INTO {table} ({','.join(columns)}) VALUES ({placeholders})", vals)
    # sync_messages: migrated separately — each record is kept whole as a
    # JSON payload, keyed by "key" with an optional numeric CreateTime.
    cur.execute("SELECT COUNT(*) FROM sync_messages")
    if cur.fetchone()[0] == 0:
        path = os.path.join(_DATA_DIR, "sync_messages.json")
        if os.path.isfile(path):
            try:
                with open(path, "r", encoding="utf-8") as f:
                    rows = json.load(f)
                for r in rows:
                    if not isinstance(r, dict):
                        continue
                    key = r.get("key", "")
                    # CreateTime defaults to 0 unless it is already numeric.
                    ct = int(r.get("CreateTime") or 0) if isinstance(r.get("CreateTime"), (int, float)) else 0
                    cur.execute("INSERT INTO sync_messages (key, create_time, payload) VALUES (?,?,?)", (key, ct, json.dumps(r, ensure_ascii=False)))
            except Exception:
                # Best-effort: a bad file or row aborts this migration silently.
                pass
    conn.commit()
def _conn():
    """Return a ready-to-use connection with the schema already initialized."""
    connection = get_conn()
    init_schema(connection)
    return connection