Files
wechatAiclaw/backend/db.py
2026-03-11 09:44:17 +08:00

180 lines
6.3 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
# -*- coding: utf-8 -*-
"""SQLite 数据库:表结构初始化与连接,数据目录由 DATA_DIR 决定(可挂载到宿主机)。"""
import json
import logging
import os
import sqlite3
# Directory that holds the SQLite file (and any legacy JSON files). The
# DATA_DIR environment variable wins when it is set and non-empty; otherwise a
# "data" folder next to this module is used.
_DATA_DIR = os.environ.get("DATA_DIR") or os.path.join(
    os.path.dirname(__file__), "data"
)
# Path of the SQLite database file inside the data directory.
_DB_PATH = os.path.join(_DATA_DIR, "wechat.db")
def get_db_path() -> str:
    """Return the path of the SQLite database file used by this module."""
    db_path = _DB_PATH
    return db_path
def get_conn() -> sqlite3.Connection:
    """Open a new SQLite connection to the database file.

    The data directory is created first when it does not exist. The
    connection is opened with ``check_same_thread=False`` (sqlite3 will not
    reject use from a thread other than the creating one) and yields
    ``sqlite3.Row`` objects, so columns can be read by name.
    """
    os.makedirs(_DATA_DIR, exist_ok=True)
    connection = sqlite3.connect(_DB_PATH, check_same_thread=False)
    connection.row_factory = sqlite3.Row
    return connection
def init_schema(conn: sqlite3.Connection) -> None:
    """Create every table and index that does not exist yet, then migrate.

    All statements use ``IF NOT EXISTS``. After the schema is committed,
    ``_migrate_json_if_needed`` is called on the same connection.
    """
    cursor = conn.cursor()
    # One entry per table, in creation order:
    # (table name, column definitions, create an index on ``key``?)
    schema = (
        (
            # Customer profiles
            "customers",
            (
                "id TEXT PRIMARY KEY",
                "key TEXT NOT NULL",
                "wxid TEXT NOT NULL",
                "remark_name TEXT",
                "region TEXT",
                "age TEXT",
                "gender TEXT",
                "level TEXT",
                "tags TEXT",
            ),
            True,
        ),
        (
            # Scheduled greeting tasks
            "greeting_tasks",
            (
                "id TEXT PRIMARY KEY",
                "key TEXT NOT NULL",
                "name TEXT",
                "send_time TEXT",
                "customer_tags TEXT",
                "template TEXT",
                "use_qwen INTEGER DEFAULT 0",
                "enabled INTEGER DEFAULT 1",
                "executed_at TEXT",
            ),
            True,
        ),
        (
            # Product tags
            "product_tags",
            (
                "id TEXT PRIMARY KEY",
                "key TEXT NOT NULL",
                "name TEXT",
            ),
            False,
        ),
        (
            # Push groups
            "push_groups",
            (
                "id TEXT PRIMARY KEY",
                "key TEXT NOT NULL",
                "name TEXT",
                "customer_ids TEXT",
                "tag_ids TEXT",
            ),
            False,
        ),
        (
            # Push tasks
            "push_tasks",
            (
                "id TEXT PRIMARY KEY",
                "key TEXT NOT NULL",
                "product_tag_id TEXT",
                "group_id TEXT",
                "content TEXT",
                "send_at TEXT",
                "status TEXT",
                "created_at TEXT",
            ),
            False,
        ),
        (
            # Synced messages (pulled over WS + sent records)
            "sync_messages",
            (
                "id INTEGER PRIMARY KEY AUTOINCREMENT",
                "key TEXT NOT NULL",
                "create_time INTEGER DEFAULT 0",
                "payload TEXT",
            ),
            True,
        ),
        (
            # Model configuration
            "models",
            (
                "id TEXT PRIMARY KEY",
                "name TEXT",
                "provider TEXT",
                "api_key TEXT",
                "base_url TEXT",
                "model_name TEXT",
                "is_current INTEGER DEFAULT 0",
            ),
            False,
        ),
        (
            # AI reply configuration (whitelist / super admins)
            "ai_reply_config",
            (
                "key TEXT PRIMARY KEY",
                "super_admin_wxids TEXT",
                "whitelist_wxids TEXT",
            ),
            False,
        ),
    )
    for table, column_defs, index_on_key in schema:
        column_block = ",\n".join(column_defs)
        cursor.execute(
            f"\nCREATE TABLE IF NOT EXISTS {table} (\n{column_block}\n)\n"
        )
        if index_on_key:
            cursor.execute(
                f"CREATE INDEX IF NOT EXISTS idx_{table}_key ON {table}(key)"
            )
    conn.commit()
    _migrate_json_if_needed(conn)
# Errors that point at one bad legacy record (constraint violation, value
# that cannot be bound or is out of range) rather than at the database itself.
_BAD_ROW_ERRORS = (
    sqlite3.IntegrityError,
    sqlite3.DataError,
    sqlite3.InterfaceError,
    sqlite3.ProgrammingError,
    OverflowError,
)


def _load_json_rows(path: str) -> list:
    """Return the list stored in the JSON file at *path*, or ``[]``.

    A missing, unreadable or malformed file, and a top-level value that is not
    a list, all yield an empty list so that migration stays best-effort.
    """
    if not os.path.isfile(path):
        return []
    log = logging.getLogger(__name__)
    try:
        with open(path, "r", encoding="utf-8") as f:
            data = json.load(f)
    except (OSError, ValueError, RecursionError) as exc:
        # ValueError covers both JSONDecodeError and UnicodeDecodeError.
        log.warning("skip JSON migration of %s: %s", path, exc)
        return []
    if isinstance(data, list):
        return data
    if data:
        log.warning(
            "skip JSON migration of %s: expected a list, got %s",
            path,
            type(data).__name__,
        )
    return []


def _legacy_row_values(record: dict, columns: list, json_cols: set) -> list:
    """Convert one legacy JSON record into values bindable for *columns*."""
    values = []
    for col in columns:
        v = record.get(col)
        if col in json_cols and isinstance(v, (list, dict)):
            # Structured columns are stored as JSON text.
            v = json.dumps(v, ensure_ascii=False)
        elif isinstance(v, bool):
            v = 1 if v else 0
        values.append(v)
    return values


def _insert_legacy_row(cur: sqlite3.Cursor, sql: str, params, table: str) -> None:
    """Run one migration INSERT, skipping (and logging) a bad record."""
    try:
        cur.execute(sql, params)
    except _BAD_ROW_ERRORS as exc:
        logging.getLogger(__name__).warning(
            "skip unmigratable %s row: %s", table, exc
        )


def _migrate_json_if_needed(conn: sqlite3.Connection) -> None:
    """Import legacy JSON data into tables that are still empty.

    For each table that has no rows and whose same-named JSON file exists in
    the data directory, the file's records are inserted once. Files that
    cannot be read or parsed are skipped, and a single bad record no longer
    aborts the rest of its file. Everything is committed at the end.
    """
    cur = conn.cursor()
    # (table, filename, columns, json_columns)
    tables_files = [
        ("customers", "customers.json", ["id", "key", "wxid", "remark_name", "region", "age", "gender", "level", "tags"], ["tags"]),
        ("greeting_tasks", "greeting_tasks.json", ["id", "key", "name", "send_time", "customer_tags", "template", "use_qwen", "enabled", "executed_at"], ["customer_tags"]),
        ("product_tags", "product_tags.json", ["id", "key", "name"], []),
        ("push_groups", "push_groups.json", ["id", "key", "name", "customer_ids", "tag_ids"], ["customer_ids", "tag_ids"]),
        ("push_tasks", "push_tasks.json", ["id", "key", "product_tag_id", "group_id", "content", "send_at", "status", "created_at"], []),
        ("models", "models.json", ["id", "name", "provider", "api_key", "base_url", "model_name", "is_current"], []),
        ("ai_reply_config", "ai_reply_config.json", ["key", "super_admin_wxids", "whitelist_wxids"], ["super_admin_wxids", "whitelist_wxids"]),
    ]
    for table, filename, columns, json_cols in tables_files:
        cur.execute(f"SELECT COUNT(*) FROM {table}")
        if cur.fetchone()[0] > 0:
            continue
        rows = _load_json_rows(os.path.join(_DATA_DIR, filename))
        if not rows:
            continue
        # The statement is the same for every row of a table: build it once.
        json_cols_set = set(json_cols)
        placeholders = ",".join("?" * len(columns))
        insert_sql = f"INSERT OR IGNORE INTO {table} ({','.join(columns)}) VALUES ({placeholders})"
        for r in rows:
            if not isinstance(r, dict):
                continue
            vals = _legacy_row_values(r, columns, json_cols_set)
            _insert_legacy_row(cur, insert_sql, vals, table)
    # sync_messages: migrated by key + payload (the whole record as JSON).
    cur.execute("SELECT COUNT(*) FROM sync_messages")
    if cur.fetchone()[0] == 0:
        path = os.path.join(_DATA_DIR, "sync_messages.json")
        for r in _load_json_rows(path):
            if not isinstance(r, dict):
                continue
            key = r.get("key", "")
            raw_time = r.get("CreateTime")
            try:
                ct = int(raw_time) if isinstance(raw_time, (int, float)) else 0
            except (ValueError, OverflowError):
                # NaN / Infinity cannot be converted to an integer.
                ct = 0
            _insert_legacy_row(
                cur,
                "INSERT INTO sync_messages (key, create_time, payload) VALUES (?,?,?)",
                (key, ct, json.dumps(r, ensure_ascii=False)),
                "sync_messages",
            )
    conn.commit()
def _conn():
    """Return a new connection on which ``init_schema`` has already been run."""
    connection = get_conn()
    init_schema(connection)
    return connection