Files
wechatAiclaw/backend/store.py
2026-03-11 00:22:41 +08:00

391 lines
12 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
# -*- coding: utf-8 -*-
"""JSON 文件存储:客户档案、定时问候任务、商品标签、推送群组、推送任务、同步消息。"""
import json
import os
import threading
import uuid
from typing import Any, Dict, List, Optional
# Directory that holds the JSON files; it is the "data" folder next to this module.
_DATA_DIR = os.path.join(os.path.dirname(__file__), "data")
# Module-wide lock that _load and _save take around each individual file access.
_LOCK = threading.Lock()
def _path(name: str) -> str:
    """Return the path of the JSON file backing ``name``.

    The data directory is created on demand, so the returned path can be
    opened for writing right away.
    """
    os.makedirs(_DATA_DIR, exist_ok=True)
    filename = "{}.json".format(name)
    return os.path.join(_DATA_DIR, filename)
def _load(name: str) -> list:
    """Read the JSON list stored for ``name``.

    Returns an empty list when the file does not exist, cannot be read,
    does not contain valid JSON, or contains JSON that is not a list
    (callers append to and iterate over the result, so anything else would
    break them).
    """
    with _LOCK:
        p = _path(name)
        try:
            with open(p, "r", encoding="utf-8") as f:
                data = json.load(f)
        except (OSError, ValueError):
            # OSError covers a missing/unreadable file; ValueError covers both
            # malformed JSON and undecodable bytes. Loading stays best-effort.
            return []
    return data if isinstance(data, list) else []
def _save(name: str, data: list) -> None:
    """Write ``data`` as pretty-printed UTF-8 JSON to the file for ``name``.

    The payload is first written to a sibling temporary file and then moved
    over the real file with ``os.replace``. If serialisation fails part-way
    (for example because ``data`` holds a value JSON cannot encode), the
    previous file content is left untouched instead of being truncated.

    Raises whatever ``open``, ``json.dump`` or ``os.replace`` raise.
    """
    with _LOCK:
        p = _path(name)
        tmp = p + ".tmp"
        try:
            with open(tmp, "w", encoding="utf-8") as f:
                json.dump(data, f, ensure_ascii=False, indent=2)
            os.replace(tmp, p)
        finally:
            # After a successful replace the temp file is gone; after a
            # failure this removes the partial file.
            try:
                os.remove(tmp)
            except OSError:
                pass
# ---------- 客户档案 R1-2 ----------
def list_customers(key: Optional[str] = None) -> List[Dict]:
    """Return customer records sorted by remark name, falling back to wxid.

    key: WeChat key; when given, only customers stored under that key are
    returned.
    """
    def display_name(customer):
        return customer.get("remark_name") or customer.get("wxid") or ""

    customers = _load("customers")
    if key:
        customers = [c for c in customers if c.get("key") == key]
    return sorted(customers, key=display_name)
def get_customer(customer_id: str) -> Optional[Dict]:
    """Return the first customer whose ``id`` equals ``customer_id``, or None."""
    matching = (c for c in _load("customers") if c.get("id") == customer_id)
    return next(matching, None)
def upsert_customer(key: str, wxid: str, remark_name: str = "", region: str = "",
                    age: str = "", gender: str = "", level: str = "", tags: Optional[List[str]] = None,
                    customer_id: Optional[str] = None) -> Dict:
    """Create or update a customer record and return it.

    ``level`` is the purchase tier; ``tags`` is a list of labels used for
    segmentation and greetings.

    An existing record is updated when its ``id`` equals ``customer_id``, or,
    when no ``customer_id`` is given, when its ``key`` and ``wxid`` both
    match. Otherwise a new record is appended, using ``customer_id`` as its
    id if provided and a fresh UUID if not.
    """
    customers = _load("customers")
    fields = {
        "key": key, "wxid": wxid, "remark_name": remark_name, "region": region,
        "age": age, "gender": gender, "level": level,
        "tags": [] if tags is None else tags,
    }
    record_id = customer_id or str(uuid.uuid4())

    def is_target(customer):
        if customer.get("id") == record_id:
            return True
        # Matching on (key, wxid) only applies when the caller gave no id.
        return (not customer_id
                and customer.get("key") == key
                and customer.get("wxid") == wxid)

    record = next((c for c in customers if is_target(c)), None)
    if record is None:
        record = {"id": record_id, **fields}
        customers.append(record)
    else:
        record.update(fields)
    _save("customers", customers)
    return record
def delete_customer(customer_id: str) -> bool:
    """Remove the first customer with the given id; return whether one was removed."""
    customers = _load("customers")
    index = next(
        (i for i, c in enumerate(customers) if c.get("id") == customer_id),
        None,
    )
    if index is None:
        return False
    del customers[index]
    _save("customers", customers)
    return True
def list_customer_tags(key: str) -> List[str]:
    """Return every tag used by customers under ``key``, de-duplicated and sorted.

    Tags are converted to strings and stripped; empty or whitespace-only
    tags are skipped.
    """
    unique_tags = {
        str(tag).strip()
        for customer in _load("customers")
        if customer.get("key") == key
        for tag in (customer.get("tags") or [])
        if tag and str(tag).strip()
    }
    return sorted(unique_tags)
# ---------- 定时问候任务 R1-3 ----------
def list_greeting_tasks(key: Optional[str] = None) -> List[Dict]:
    """Return greeting tasks ordered by ``send_time``, falling back to ``cron``.

    When ``key`` is truthy, only tasks stored under that key are returned.
    """
    tasks = _load("greeting_tasks")
    if key:
        tasks = [t for t in tasks if t.get("key") == key]
    # The trailing ``or ""`` keeps the sort key a string even when both fields
    # are missing or stored as null; a None key would make sorted() raise
    # TypeError as soon as it is compared with a string.
    return sorted(tasks, key=lambda t: t.get("send_time") or t.get("cron") or "")
def get_greeting_task(task_id: str) -> Optional[Dict]:
    """Return the first greeting task whose ``id`` equals ``task_id``, or None."""
    matching = (t for t in _load("greeting_tasks") if t.get("id") == task_id)
    return next(matching, None)
def create_greeting_task(key: str, name: str, send_time: str, customer_tags: List[str],
                         template: str, use_qwen: bool = False) -> Dict:
    """Append a new greeting task and return it.

    The task gets a fresh UUID, starts enabled, and has ``executed_at`` set
    to None. A falsy ``customer_tags`` is stored as an empty list.
    """
    task = {
        "id": str(uuid.uuid4()),
        "key": key,
        "name": name,
        "send_time": send_time,
        "customer_tags": customer_tags or [],
        "template": template,
        "use_qwen": use_qwen,
        "enabled": True,
        "executed_at": None,
    }
    tasks = _load("greeting_tasks")
    tasks.append(task)
    _save("greeting_tasks", tasks)
    return task
def update_greeting_task(task_id: str, **kwargs) -> Optional[Dict]:
    """Update whitelisted fields of a greeting task.

    Only the keyword arguments named in ``editable`` are applied; anything
    else is ignored. Returns the updated task, or None when no task has the
    given id.
    """
    editable = ("name", "send_time", "cron", "customer_tags", "template", "use_qwen", "enabled", "executed_at")
    tasks = _load("greeting_tasks")
    task = next((t for t in tasks if t.get("id") == task_id), None)
    if task is None:
        return None
    task.update({field: value for field, value in kwargs.items() if field in editable})
    _save("greeting_tasks", tasks)
    return task
def delete_greeting_task(task_id: str) -> bool:
    """Remove the first greeting task with the given id; return whether one was removed."""
    tasks = _load("greeting_tasks")
    for position, task in enumerate(tasks):
        if task.get("id") != task_id:
            continue
        del tasks[position]
        _save("greeting_tasks", tasks)
        return True
    return False
# ---------- 商品标签 R1-4 ----------
def list_product_tags(key: Optional[str] = None) -> List[Dict]:
    """Return product tags in stored order, limited to ``key`` when it is truthy."""
    product_tags = _load("product_tags")
    if not key:
        return product_tags
    return [tag for tag in product_tags if tag.get("key") == key]
def create_product_tag(key: str, name: str) -> Dict:
    """Append a product tag with a fresh UUID and return it."""
    product_tag = {"id": str(uuid.uuid4()), "key": key, "name": name}
    product_tags = _load("product_tags")
    product_tags.append(product_tag)
    _save("product_tags", product_tags)
    return product_tag
def delete_product_tag(tag_id: str) -> bool:
    """Remove the first product tag with the given id; return whether one was removed."""
    product_tags = _load("product_tags")
    index = next(
        (i for i, tag in enumerate(product_tags) if tag.get("id") == tag_id),
        None,
    )
    if index is None:
        return False
    del product_tags[index]
    _save("product_tags", product_tags)
    return True
# ---------- 推送群组(客户群组) ----------
def list_push_groups(key: Optional[str] = None) -> List[Dict]:
    """Return push groups in stored order, limited to ``key`` when it is truthy."""
    groups = _load("push_groups")
    if not key:
        return groups
    return [group for group in groups if group.get("key") == key]
def create_push_group(key: str, name: str, customer_ids: List[str], tag_ids: List[str]) -> Dict:
    """Append a push group with a fresh UUID and return it.

    Falsy ``customer_ids`` / ``tag_ids`` are stored as empty lists.
    """
    group = {
        "id": str(uuid.uuid4()),
        "key": key,
        "name": name,
        "customer_ids": customer_ids or [],
        "tag_ids": tag_ids or [],
    }
    groups = _load("push_groups")
    groups.append(group)
    _save("push_groups", groups)
    return group
def update_push_group(group_id: str, name: Optional[str] = None, customer_ids: Optional[List[str]] = None,
                      tag_ids: Optional[List[str]] = None) -> Optional[Dict]:
    """Update the given fields of a push group.

    Arguments left as None are not touched. Returns the updated group, or
    None when no group has the given id.
    """
    groups = _load("push_groups")
    group = next((g for g in groups if g.get("id") == group_id), None)
    if group is None:
        return None
    requested = {"name": name, "customer_ids": customer_ids, "tag_ids": tag_ids}
    group.update({field: value for field, value in requested.items() if value is not None})
    _save("push_groups", groups)
    return group
def delete_push_group(group_id: str) -> bool:
    """Remove the first push group with the given id; return whether one was removed."""
    groups = _load("push_groups")
    for position, group in enumerate(groups):
        if group.get("id") != group_id:
            continue
        del groups[position]
        _save("push_groups", groups)
        return True
    return False
# ---------- 推送任务(一键/定时发送) ----------
def list_push_tasks(key: Optional[str] = None, limit: int = 200) -> List[Dict]:
    """Return push tasks newest-first by ``created_at``, sliced to ``limit``.

    When ``key`` is truthy, only tasks stored under that key are considered.
    """
    def created_at(task):
        return task.get("created_at", "")

    tasks = _load("push_tasks")
    if key:
        tasks = [t for t in tasks if t.get("key") == key]
    tasks.sort(key=created_at, reverse=True)
    return tasks[:limit]
def create_push_task(key: str, product_tag_id: str, group_id: str, content: str,
                     send_at: Optional[str] = None) -> Dict:
    """Append a push task in ``pending`` status and return it.

    ``created_at`` is the current local time formatted as
    ``YYYY-MM-DDTHH:MM:SS``.
    """
    import time

    task = {
        "id": str(uuid.uuid4()),
        "key": key,
        "product_tag_id": product_tag_id,
        "group_id": group_id,
        "content": content,
        "send_at": send_at,
        "status": "pending",
        "created_at": time.strftime("%Y-%m-%dT%H:%M:%S"),
    }
    tasks = _load("push_tasks")
    tasks.append(task)
    _save("push_tasks", tasks)
    return task
def update_push_task_status(task_id: str, status: str) -> Optional[Dict]:
    """Set ``status`` on the first push task with the given id.

    Returns the updated task, or None when no task has that id.
    """
    tasks = _load("push_tasks")
    task = next((t for t in tasks if t.get("id") == task_id), None)
    if task is None:
        return None
    task["status"] = status
    _save("push_tasks", tasks)
    return task
# ---------- WS sync messages (GetSyncMsg results) ----------
def append_sync_messages(key: str, messages: List[Dict], max_per_key: int = 500) -> None:
    """Store ``messages`` under ``key`` and trim every key to its newest entries.

    Each message is stored as a copy with its ``"key"`` field set to ``key``;
    the caller's dicts are left unmodified. After appending, at most
    ``max_per_key`` of the most recently stored messages are kept for every
    key in the file (not just ``key``). A ``max_per_key`` of zero or less
    keeps nothing.
    """
    rows = _load("sync_messages")
    rows.extend({**m, "key": key} for m in messages)

    by_key: Dict[str, List[Dict]] = {}
    for m in rows:
        by_key.setdefault(m.get("key", ""), []).append(m)

    new_rows: List[Dict] = []
    # Guard the slice: lst[-0:] is the whole list, so a zero limit would
    # otherwise keep everything instead of nothing.
    if max_per_key > 0:
        for lst in by_key.values():
            new_rows.extend(lst[-max_per_key:])
    _save("sync_messages", new_rows)
def list_sync_messages(key: str, limit: int = 100) -> List[Dict]:
    """Return the messages stored under ``key``, newest first, sliced to ``limit``.

    Messages are ordered by ``CreateTime``. Integer and float timestamps are
    used directly (floats truncated to int), and numeric strings are parsed.
    Anything else, including NaN and infinity, sorts as 0. Messages with equal
    timestamps keep their stored order.
    """
    def timestamp(message):
        raw = message.get("CreateTime")
        if isinstance(raw, str):
            try:
                raw = float(raw.strip())
            except ValueError:
                return 0
        if isinstance(raw, (int, float)):
            try:
                return int(raw)
            except (ValueError, OverflowError):
                # int() rejects NaN (ValueError) and infinity (OverflowError).
                return 0
        return 0

    rows = [r for r in _load("sync_messages") if r.get("key") == key]
    rows.sort(key=timestamp, reverse=True)
    return rows[:limit]
def append_sent_message(key: str, to_user_name: str, content: str) -> None:
    """Record an outgoing message after a successful send.

    The entry is stored with ``direction`` set to ``"out"`` so the live
    message page can show the complete conversation.
    """
    import time

    outgoing = {
        "direction": "out",
        "ToUserName": to_user_name,
        "Content": content,
        "CreateTime": int(time.time()),
    }
    append_sync_messages(key, [outgoing])
# ---------- Model management (multi-model switching; API key configured per model) ----------
def list_models() -> List[Dict]:
    """Return all models, current model(s) first, then ordered by name."""
    def ordering(model):
        is_not_current = not model.get("is_current")
        return (is_not_current, model.get("name") or "")

    return sorted(_load("models"), key=ordering)
def get_model(model_id: str) -> Optional[Dict]:
    """Return the first model whose ``id`` equals ``model_id``, or None."""
    matching = (m for m in _load("models") if m.get("id") == model_id)
    return next(matching, None)
def get_current_model() -> Optional[Dict]:
    """Return the first model flagged ``is_current``, or None when there is none."""
    flagged = (m for m in _load("models") if m.get("is_current"))
    return next(flagged, None)
def create_model(
    name: str,
    provider: str,
    api_key: str,
    base_url: str = "",
    model_name: str = "",
    is_current: bool = False,
) -> Dict:
    """Append a model configuration and return it.

    ``base_url`` and ``model_name`` fall back to provider-specific defaults
    when they are empty or contain only whitespace. Providers other than
    ``"qwen"`` and ``"doubao"`` get the OpenAI defaults.

    The new model becomes the current one when ``is_current`` is truthy or
    when it is the first model stored; in that case every existing model has
    its ``is_current`` flag cleared.
    """
    provider_defaults = {
        "qwen": ("https://dashscope.aliyuncs.com/compatible-mode/v1", "qwen-turbo"),
        "doubao": ("https://ark.cn-beijing.volces.com/api/v3", "doubao-seed-2-0-pro-260215"),
    }
    default_base, default_model = provider_defaults.get(
        provider, ("https://api.openai.com/v1", "gpt-3.5-turbo")
    )

    rows = _load("models")
    row = {
        "id": str(uuid.uuid4()),
        "name": name,
        "provider": provider,
        "api_key": api_key,
        # Strip before applying the default so a whitespace-only value does
        # not end up stored as an empty string.
        "base_url": (base_url or "").strip() or default_base,
        "model_name": (model_name or "").strip() or default_model,
        "is_current": is_current or not rows,
    }
    if row["is_current"]:
        for r in rows:
            r["is_current"] = False
    rows.append(row)
    _save("models", rows)
    return row
def update_model(
    model_id: str,
    name: Optional[str] = None,
    api_key: Optional[str] = None,
    base_url: Optional[str] = None,
    model_name: Optional[str] = None,
) -> Optional[Dict]:
    """Update the given fields of a model.

    Arguments left as None are not touched. Returns the updated model, or
    None when no model has the given id.
    """
    models = _load("models")
    model = next((m for m in models if m.get("id") == model_id), None)
    if model is None:
        return None
    requested = {
        "name": name,
        "api_key": api_key,
        "base_url": base_url,
        "model_name": model_name,
    }
    model.update({field: value for field, value in requested.items() if value is not None})
    _save("models", models)
    return model
def set_current_model(model_id: str) -> Optional[Dict]:
    """Flag the model with ``model_id`` as current and clear the flag on all others.

    Nothing is written when no model has that id. Returns the selected model,
    or None when it was not found.
    """
    models = _load("models")
    selected = None
    for model in models:
        is_match = model.get("id") == model_id
        model["is_current"] = True if is_match else False
        if is_match:
            selected = model
    if selected is not None:
        _save("models", models)
    return selected
def delete_model(model_id: str) -> bool:
    """Remove the first model with the given id; return whether one was removed.

    When the removed model was the current one and other models remain, the
    first remaining model is flagged as current.
    """
    models = _load("models")
    index = next(
        (i for i, m in enumerate(models) if m.get("id") == model_id),
        None,
    )
    if index is None:
        return False
    removed = models.pop(index)
    if removed.get("is_current") and models:
        models[0]["is_current"] = True
    _save("models", models)
    return True