Files
wechatAiclaw/backend/main.py
2026-03-11 09:44:17 +08:00

1072 lines
40 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import asyncio
import html
import logging
import os
from contextlib import asynccontextmanager
from datetime import datetime
from typing import Any, List, Optional
from urllib.parse import urlencode
import httpx
from fastapi import FastAPI, HTTPException, Query, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import HTMLResponse
from pydantic import BaseModel
try:
from backend import store
from backend.llm_client import chat as llm_chat
from backend.ws_sync import is_ws_connected, set_message_callback, start_ws_sync
except ImportError:
import store
from llm_client import chat as llm_chat
from ws_sync import is_ws_connected, set_message_callback, start_ws_sync
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(name)s - %(message)s",
)
logger = logging.getLogger("wechat-backend")


def _with_leading_slash(path: str) -> str:
    """Return ``path`` prefixed with "/" so it can be appended to a base URL.

    Empty strings and paths that already start with "/" are returned unchanged.
    """
    return path if not path or path.startswith("/") else f"/{path}"


def _env_int(name: str, default: int) -> int:
    """Read integer env var ``name``.

    Falls back to ``default`` when the variable is unset or blank. An invalid
    value is logged as a warning and also falls back to ``default``, instead of
    crashing at import time.
    """
    raw = (os.getenv(name) or "").strip()
    if not raw:
        return default
    try:
        return int(raw)
    except ValueError:
        logger.warning("Invalid integer for %s=%r; using default %s", name, raw, default)
        return default


WECHAT_UPSTREAM_BASE_URL = os.getenv("WECHAT_UPSTREAM_BASE_URL", "http://localhost:8080").rstrip("/")
CHECK_STATUS_BASE_URL = os.getenv("CHECK_STATUS_BASE_URL", "http://113.44.162.180:7006").rstrip("/")
SLIDER_VERIFY_BASE_URL = os.getenv("SLIDER_VERIFY_BASE_URL", "http://113.44.162.180:7765").rstrip("/")
# NOTE(review): hard-coded fallback key; consider requiring SLIDER_VERIFY_KEY
# (or KEY) to be set in the environment instead of shipping a default.
SLIDER_VERIFY_KEY = os.getenv("SLIDER_VERIFY_KEY", os.getenv("KEY", "408449830"))
# Text message sending: per the upstream swagger this is POST
# /message/SendTextMessage with a body holding an array of MsgItem objects.
# All paths below are normalized to start with "/" so they join cleanly
# onto the base URL even if the env var omits the slash.
SEND_MSG_PATH = _with_leading_slash((os.getenv("SEND_MSG_PATH") or "/message/SendTextMessage").strip())
# Image sending: some upstreams expose a separate endpoint, others reuse the
# text endpoint and only change MsgType (e.g. 3 = image).
SEND_IMAGE_PATH = _with_leading_slash((os.getenv("SEND_IMAGE_PATH") or "").strip() or SEND_MSG_PATH)
# Contact list: the 7006 upstream uses POST /friend/GetContactList with
# CurrentChatRoomContactSeq/CurrentWxcontactSeq = 0 in the body.
CONTACT_LIST_PATH = _with_leading_slash(
    (os.getenv("CONTACT_LIST_PATH") or os.getenv("FRIEND_LIST_PATH") or "/friend/GetContactList").strip()
)
FRIEND_LIST_PATH = _with_leading_slash((os.getenv("FRIEND_LIST_PATH") or CONTACT_LIST_PATH).strip())
# MsgType used for image messages: commonly 3, some upstreams use 0.
IMAGE_MSG_TYPE = _env_int("IMAGE_MSG_TYPE", 3)
# Per-key cache of the login QR-code response and its Data62, reused by later
# login steps (the slider verification form).
qrcode_store: dict = {}
def _is_self_sent(msg: dict) -> bool:
    """Tell whether ``msg`` was sent by the current account itself.

    True when the stored record has ``direction == "out"`` or the upstream
    flag ``IsSelf`` is 1 / True / "1". Such messages never trigger an AI reply.
    """
    return msg.get("direction") == "out" or msg.get("IsSelf") in (1, True, "1")
def _allowed_ai_reply(key: str, from_user: str) -> bool:
    """Decide whether ``from_user`` may receive AI auto-replies on account ``key``.

    Only wxids listed in the account's ``super_admin_wxids`` or
    ``whitelist_wxids`` qualify. An empty sender, or an account with no AI
    reply config, always yields False.
    """
    sender = (from_user or "").strip()
    if not sender:
        return False
    cfg = store.get_ai_reply_config(key)
    if not cfg:
        return False
    permitted = set(cfg.get("super_admin_wxids") or []) | set(cfg.get("whitelist_wxids") or [])
    return sender in permitted
async def _ai_takeover_reply(
    key: str,
    from_user: str,
    content: str,
    history_limit: int = 10,
    max_context: int = 6,
) -> None:
    """Generate an AI reply to ``content`` from ``from_user`` and send it back.

    Builds a short chat context from the account's recent messages:
    - the account's own outgoing messages to ``from_user`` become "assistant" turns;
    - ``from_user``'s messages become "user" turns.
    At most ``max_context`` turns are taken from the last ``history_limit``
    stored messages, and they are passed to the LLM oldest-first. If the
    newest turn is not a user turn, ``content`` is appended as one.

    Any error is logged and swallowed, because this runs as a fire-and-forget task.
    """
    if not from_user or not content or not content.strip():
        return
    try:
        # NOTE(review): assumes list_sync_messages returns messages oldest-first
        # (the original comment says "take the last few"); confirm in store.
        recent = store.list_sync_messages(key, limit=history_limit)
        context = []
        # Walk newest -> oldest so the max_context cap keeps the most recent turns.
        for m in reversed(recent):
            if not isinstance(m, dict):
                continue
            c = (m.get("Content") or m.get("content") or "").strip()
            if not c:
                continue
            if m.get("direction") == "out" and (m.get("ToUserName") or "").strip() == from_user:
                context.append({"role": "assistant", "content": c})
            elif (m.get("FromUserName") or m.get("from") or "").strip() == from_user and not _is_self_sent(m):
                context.append({"role": "user", "content": c})
            if len(context) >= max_context:
                break
        # Restore chronological order: chat models expect oldest -> newest.
        context.reverse()
        if not context or context[-1].get("role") != "user":
            context.append({"role": "user", "content": content})
        text = await llm_chat(context)
        if text and text.strip():
            await _send_message_upstream(key, from_user, text.strip())
            logger.info("AI takeover replied to %s: %s", from_user[:20], text.strip()[:50])
    except Exception as e:
        logger.exception("AI takeover reply error (from=%s): %s", from_user, e)
# Strong references to in-flight AI reply tasks. The event loop keeps only weak
# references to tasks, so an unreferenced task may be garbage-collected before
# it finishes. Each task removes itself from this set when done.
_ai_reply_tasks: set = set()


def _schedule_ai_reply_if_eligible(key: str, m: Any) -> None:
    """Start an AI takeover reply for one synced message if it qualifies.

    A message qualifies when all of these hold:
    - it is a dict and was not sent by this account;
    - it has a non-empty sender and content;
    - it is plain text (MsgType 1, or no type given);
    - the sender passes ``_allowed_ai_reply``.
    """
    if not isinstance(m, dict) or _is_self_sent(m):
        return
    from_user = (m.get("FromUserName") or m.get("from") or "").strip()
    content = (m.get("Content") or m.get("content") or "").strip()
    msg_type = m.get("MsgType") or m.get("msgType")
    is_text = msg_type in (1, None) or str(msg_type) == "1"
    if not (from_user and content and is_text):
        return
    if not _allowed_ai_reply(key, from_user):
        return
    try:
        loop = asyncio.get_running_loop()
    except RuntimeError:
        # Invoked outside the event loop: there is nothing to run the reply on.
        logger.warning("No running event loop; AI reply to %s skipped", from_user)
        return
    task = loop.create_task(_ai_takeover_reply(key, from_user, content))
    _ai_reply_tasks.add(task)
    task.add_done_callback(_ai_reply_tasks.discard)


def _on_ws_message(key: str, data: Any) -> None:
    """Handle a GetSyncMsg payload: persist it, then let the AI answer others.

    ``data`` may take one of three shapes:
    - a dict carrying a non-empty list under MsgList / List / msgList;
    - a bare list of messages;
    - a single message.
    The messages are stored via ``store.append_sync_messages``. Each eligible
    incoming text message then gets an AI reply scheduled.
    """
    msg_list = None
    # Only dicts have .get(); checking first keeps list payloads from raising.
    if isinstance(data, dict):
        msg_list = data.get("MsgList") or data.get("List") or data.get("msgList")
    if isinstance(msg_list, list) and msg_list:
        messages = msg_list
    elif isinstance(data, list):
        messages = data
    else:
        messages = [data]
    store.append_sync_messages(key, messages)
    for m in messages:
        _schedule_ai_reply_if_eligible(key, m)
def _is_greeting_task_due(task: Any, now: datetime) -> bool:
    """Return True if ``task`` should run now.

    That means it is enabled, not yet executed, and its trigger time
    (``send_time``, else ``cron``) parses to a time at or before ``now``.
    """
    if not isinstance(task, dict):
        return False
    if not task.get("enabled") or task.get("executed_at"):
        return False
    send_time = task.get("send_time") or task.get("cron")
    if not send_time:
        return False
    due_at = _parse_send_time(send_time)
    return due_at is not None and due_at <= now


async def _compose_greeting(task_id: Any, template: str, use_qwen: bool, customer: dict, remark_name: str) -> str:
    """Build the greeting text for one customer.

    The template (with ``{{name}}`` replaced by ``remark_name``) is used
    directly when ``use_qwen`` is false. Otherwise the LLM is asked for a
    greeting, and the template is the fallback if the LLM raises or returns
    empty text.
    """
    fallback = template.replace("{{name}}", remark_name)
    if not use_qwen:
        return fallback
    prompt = f"请生成一句简短的微信问候语1-2句话客户备注名{remark_name}"
    region = (customer.get("region") or "").strip()
    if region:
        prompt += f",地区:{region}"
    tags = customer.get("tags") or []
    if tags:
        prompt += f",标签:{','.join(tags)}"
    prompt += "。不要解释,只输出问候语本身。"
    try:
        content = await llm_chat([{"role": "user", "content": prompt}])
    except Exception as e:
        logger.warning("Greeting task %s llm_chat error: %s", task_id, e)
        return fallback
    if not content or not content.strip():
        return fallback
    return content


async def _execute_greeting_task(task: dict, now: datetime) -> None:
    """Send one greeting task to its matching customers, then mark it executed.

    Recipients are the account's customers that share at least one tag with
    ``customer_tags``, or all customers when the task has no tags. A failure
    for one customer is logged and skipped, so the others still get the
    greeting. The task is then marked executed and disabled, which keeps later
    ticks from re-sending it.
    """
    task_id = task.get("id")
    key = task.get("key")
    customer_tags = set(task.get("customer_tags") or [])
    template = (task.get("template") or "").strip() or "{{name}},您好!"
    use_qwen = bool(task.get("use_qwen"))
    customers = store.list_customers(key)
    if customer_tags:
        customers = [
            c for c in customers
            if isinstance(c, dict) and set(c.get("tags") or []) & customer_tags
        ]
    for customer in customers:
        if not isinstance(customer, dict):
            continue
        wxid = customer.get("wxid")
        if not wxid:
            continue
        try:
            remark_name = (customer.get("remark_name") or "").strip() or wxid
            content = await _compose_greeting(task_id, template, use_qwen, customer, remark_name)
            await _send_message_upstream(key, wxid, content)
            logger.info("Greeting task %s sent to %s", task_id, wxid)
        except asyncio.CancelledError:
            raise
        except Exception as e:
            logger.warning("Greeting task %s send to %s failed: %s", task_id, wxid, e)
    store.update_greeting_task(task_id, executed_at=now.isoformat(), enabled=False)
    logger.info("Greeting task %s executed_at set", task_id)


async def _run_greeting_scheduler(check_interval: float = 30) -> None:
    """Run due greeting tasks every ``check_interval`` seconds until cancelled.

    Each task is processed in isolation: an exception in one task is logged
    and does not stop the remaining due tasks in the same tick.
    """
    while True:
        try:
            await asyncio.sleep(check_interval)
            now = datetime.now()
            for task in store.list_greeting_tasks(key=None):
                try:
                    if _is_greeting_task_due(task, now):
                        await _execute_greeting_task(task, now)
                except asyncio.CancelledError:
                    raise
                except Exception as e:
                    task_id = task.get("id") if isinstance(task, dict) else None
                    logger.exception("Greeting task %s failed: %s", task_id, e)
        except asyncio.CancelledError:
            break
        except Exception as e:
            logger.exception("Greeting scheduler error: %s", e)
@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan: start background workers on startup, stop them on shutdown.

    Registers ``_on_ws_message`` as the sync-message callback, then starts the
    WS sync loop and the greeting scheduler. Both task references are kept, so
    the tasks cannot be garbage-collected while running and can be cancelled
    on shutdown.
    """
    set_message_callback(_on_ws_message)
    ws_task = asyncio.create_task(start_ws_sync())
    scheduler = asyncio.create_task(_run_greeting_scheduler())
    try:
        yield
    finally:
        background = (scheduler, ws_task)
        for task in background:
            task.cancel()
        # return_exceptions=True: collect cancellations/errors instead of raising at shutdown.
        results = await asyncio.gather(*background, return_exceptions=True)
        for task_name, result in zip(("greeting scheduler", "ws sync"), results):
            if isinstance(result, Exception):
                logger.warning("Background %s ended with error: %s", task_name, result)
app = FastAPI(title="WeChat Admin Backend (FastAPI)", lifespan=lifespan)

# Allowed CORS origins as a comma-separated list (e.g. "https://admin.example.com").
# Defaults to "*" (any origin), the same as the previous hard-coded setting.
# NOTE(review): "*" combined with allow_credentials=True can let any site make
# credentialed cross-origin requests; set CORS_ALLOW_ORIGINS in production.
CORS_ALLOW_ORIGINS = [
    origin.strip()
    for origin in os.getenv("CORS_ALLOW_ORIGINS", "*").split(",")
    if origin.strip()
] or ["*"]

app.add_middleware(
    CORSMiddleware,
    allow_origins=CORS_ALLOW_ORIGINS,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
class QrCodeRequest(BaseModel):
    """Request body for POST /auth/qrcode.

    ``key`` identifies the account and is sent upstream as a query parameter.
    The remaining fields are forwarded unchanged as the JSON body of the
    upstream GetLoginQrCodeNewDirect call.
    """
    key: str
    Proxy: Optional[str] = ""
    IpadOrmac: Optional[str] = ""
    Check: Optional[bool] = False
@app.middleware("http")
async def log_requests(request: Request, call_next):
    """Log every HTTP request on arrival and its status code once handled."""
    method, path = request.method, request.url.path
    client_host = request.client.host if request.client else "-"
    logger.info("HTTP %s %s from %s", method, path, client_host)
    response = await call_next(request)
    logger.info("HTTP %s %s -> %s", method, path, response.status_code)
    return response
@app.get("/health")
async def health() -> dict:
    """Liveness probe: report this backend as up, plus the upstream base URL it targets."""
    logger.info("Health check")
    payload = {"status": "ok", "backend": "fastapi"}
    payload["upstream"] = WECHAT_UPSTREAM_BASE_URL
    return payload
@app.get("/api/ws-status")
async def api_ws_status() -> dict:
    """Report whether the GetSyncMsg WebSocket is connected.

    The frontend uses this to redirect to the login page after a disconnect.
    """
    connected = is_ws_connected()
    return {"connected": connected}
@app.post("/auth/qrcode")
async def get_login_qrcode(body: QrCodeRequest):
    """Fetch a login QR code from the upstream and cache its Data62 under ``key``.

    Every request field except ``key`` is sent as the JSON body of
    GetLoginQrCodeNewDirect; ``key`` goes in the query string.

    On success the upstream JSON is returned. When it is a JSON object, two
    fields are added: ``_data62_stored`` and ``_data62_length``. In that case
    ``qrcode_store[key]`` also holds the Data62 and the raw response, for the
    later slider-verification step.

    Raises:
        HTTPException(400): ``key`` is empty.
        HTTPException(502): the upstream is unreachable, answered with HTTP
            status >= 400, or returned a body that is not valid JSON.
    """
    key = body.key
    if not key:
        raise HTTPException(status_code=400, detail="key is required")
    payload = body.dict(exclude={"key"})
    url = f"{WECHAT_UPSTREAM_BASE_URL}/login/GetLoginQrCodeNewDirect"
    logger.info("GetLoginQrCodeNewDirect: key=%s, payload=%s, url=%s", key, payload, url)
    try:
        async with httpx.AsyncClient(timeout=20.0) as client:
            resp = await client.post(url, params={"key": key}, json=payload)
    except Exception as exc:
        logger.exception("Error calling upstream GetLoginQrCodeNewDirect: %s", exc)
        raise HTTPException(
            status_code=502,
            detail={"error": "upstream_connect_error", "detail": str(exc)},
        ) from exc
    body_text = resp.text[:500]
    if resp.status_code >= 400:
        logger.warning(
            "Upstream GetLoginQrCodeNewDirect bad response: status=%s, body=%s",
            resp.status_code,
            body_text,
        )
        raise HTTPException(
            status_code=502,
            detail={
                "error": "upstream_bad_response",
                "status_code": resp.status_code,
                "body": body_text,
            },
        )
    logger.info(
        "Upstream GetLoginQrCodeNewDirect success: status=%s, body=%s",
        resp.status_code,
        body_text,
    )
    try:
        data = resp.json()
    except ValueError as exc:
        # A non-JSON 2xx body would otherwise surface as an unhandled 500.
        logger.warning("Upstream GetLoginQrCodeNewDirect returned non-JSON body: %s", body_text)
        raise HTTPException(
            status_code=502,
            detail={
                "error": "upstream_invalid_json",
                "status_code": resp.status_code,
                "body": body_text,
            },
        ) from exc
    # Step 1 of login: keep Data62 so the slider step can pre-fill it later.
    if not isinstance(data, dict):
        logger.warning("QR code response for key=%s is not a JSON object; Data62 not stored", key)
        return data
    try:
        inner = data.get("Data")
        data62 = data.get("Data62") or (inner.get("data62") if isinstance(inner, dict) else None) or ""
        qrcode_store[key] = {"data62": data62, "response": data}
        # Flag in the response that Data62 is cached under this key for later steps.
        data["_data62_stored"] = True
        data["_data62_length"] = len(data62)
        logger.info("Stored Data62 for key=%s (len=%s)", key, len(data62))
    except Exception as e:
        # Best effort: a malformed Data62 must not fail the QR-code request itself.
        logger.warning("Store qrcode data for key=%s failed: %s", key, e)
    return data
@app.get("/auth/status")
async def get_online_status(
    key: str = Query(..., description="账号唯一标识"),
):
    """Proxy the upstream GetLoginStatus call for account ``key``.

    The upstream JSON body is returned as-is, whatever its HTTP status.

    Raises:
        HTTPException(400): ``key`` is empty.
        HTTPException(502): the upstream is unreachable or returned non-JSON.
    """
    if not key:
        raise HTTPException(status_code=400, detail="key is required")
    url = f"{WECHAT_UPSTREAM_BASE_URL}/login/GetLoginStatus"
    logger.info("GetLoginStatus: key=%s, url=%s", key, url)
    try:
        async with httpx.AsyncClient(timeout=15.0) as client:
            resp = await client.get(url, params={"key": key})
    except Exception as exc:
        logger.exception("Error calling upstream GetLoginStatus: %s", exc)
        raise HTTPException(status_code=502, detail=f"upstream_error: {exc}") from exc
    body_text = resp.text[:500]
    logger.info(
        "Upstream GetLoginStatus response: status=%s, body=%s",
        resp.status_code,
        body_text,
    )
    try:
        return resp.json()
    except ValueError as exc:
        logger.warning("Upstream GetLoginStatus returned non-JSON body: %s", body_text)
        raise HTTPException(status_code=502, detail=f"upstream_invalid_json: {body_text[:200]}") from exc
def _extract_clean_ticket(obj: dict) -> Optional[str]:
    """Pull the slider-verification ticket out of a scan-status response.

    Lookup order:
    1. ``obj["Data"]["ticket"]`` (or ``obj["ticket"]`` when Data is not a dict);
    2. ``obj["ticket"]``, then ``obj["Ticket"]``;
    3. the ``ticket=`` query parameter of ``obj["wechat_verify_url"]``.

    The raw value is cut at its first character outside printable ASCII
    (0x20-0x7E), which drops trailing garbage such as U+FFFD.

    Returns None when:
    - ``obj`` is not a non-empty dict;
    - no string ticket is found;
    - the very first character is already invalid.
    """
    if not isinstance(obj, dict) or not obj:
        return None
    data = obj.get("Data")
    source = data if isinstance(data, dict) else obj
    candidates = (
        source.get("ticket") if source else None,
        obj.get("ticket"),
        obj.get("Ticket"),
    )
    raw = next((value for value in candidates if value), None)
    if not raw:
        verify_url = obj.get("wechat_verify_url") or ""
        if isinstance(verify_url, str) and "ticket=" in verify_url:
            raw = verify_url.partition("ticket=")[2].partition("&")[0]
    if not raw or not isinstance(raw, str):
        return None
    cut = next((i for i, ch in enumerate(raw) if not 32 <= ord(ch) <= 126), len(raw))
    return raw[:cut] or None
@app.get("/auth/scan-status")
async def check_scan_status(
    key: str = Query(..., description="账号唯一标识"),
):
    """Proxy the upstream CheckLoginStatus call and attach a slider-form URL when needed.

    When the response carries a verification ticket, ``slider_url`` is added.
    It points at /auth/slider-form, pre-filled with SLIDER_VERIFY_KEY, the
    cleaned ticket and the Data62 cached by /auth/qrcode. The slider service
    itself is not called here.

    Raises:
        HTTPException(400): ``key`` is empty.
        HTTPException(502): the upstream is unreachable or returned non-JSON.
    """
    if not key:
        raise HTTPException(status_code=400, detail="key is required")
    url = f"{CHECK_STATUS_BASE_URL}/login/CheckLoginStatus"
    logger.info("CheckLoginStatus: key=%s, url=%s", key, url)
    try:
        async with httpx.AsyncClient(timeout=15.0) as client:
            resp = await client.get(url, params={"key": key})
    except Exception as exc:
        logger.exception("Error calling upstream CheckLoginStatus: %s", exc)
        raise HTTPException(status_code=502, detail=f"upstream_error: {exc}") from exc
    body_full = resp.text
    logger.info(
        "Upstream CheckLoginStatus response: status=%s, body=%s",
        resp.status_code,
        body_full[:2000] if len(body_full) > 2000 else body_full,
    )
    try:
        data = resp.json()
    except ValueError as exc:
        logger.warning("Upstream CheckLoginStatus returned non-JSON body: %s", body_full[:500])
        raise HTTPException(status_code=502, detail=f"upstream_invalid_json: {body_full[:200]}") from exc
    # _extract_clean_ticket returns None for non-dict data, so data is a dict below.
    ticket = _extract_clean_ticket(data)
    if ticket:
        # The iframe loads this self-hosted form; the user submits it to the third-party service.
        stored = qrcode_store.get(key) or {}
        data62 = stored.get("data62") or ""
        params = {"key": SLIDER_VERIFY_KEY, "ticket": ticket}
        if data62:
            params["data62"] = data62
        data["slider_url"] = f"/auth/slider-form?{urlencode(params)}"
        logger.info(
            "Attached slider_url (slider-form) for key=%s (ticket len=%s, data62 len=%s)",
            key,
            len(ticket),
            len(data62),
        )
    return data
def _slider_form_html(key_val: str, data62_val: str, ticket_val: str) -> str:
    """Render the slider-verification form page.

    The form is pre-filled with the key, Data62 and original ticket. All
    values, and the SLIDER_VERIFY_BASE_URL form action, are HTML-escaped
    (quotes included) before interpolation. Submitting sends a GET request to
    the third-party service in the top-level window.
    """
    k, d, t, action = (
        html.escape(raw_value, quote=True)
        for raw_value in (key_val, data62_val, ticket_val, SLIDER_VERIFY_BASE_URL)
    )
    return f"""<!DOCTYPE html>
<html lang="zh-CN">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1">
<title>滑块验证</title>
<style>
body {{ font-family: sans-serif; background: #f0f0f0; margin: 20px; }}
.card {{ background: #fff; border-radius: 8px; padding: 20px; max-width: 480px; margin: 0 auto; box-shadow: 0 1px 3px rgba(0,0,0,.1); }}
h2 {{ margin-top: 0; }}
label {{ display: block; margin: 10px 0 4px; color: #333; }}
input {{ width: 100%; box-sizing: border-box; padding: 8px; border: 1px solid #ccc; border-radius: 4px; }}
button {{ margin-top: 16px; padding: 10px 20px; background: #07c160; color: #fff; border: none; border-radius: 4px; cursor: pointer; font-size: 14px; }}
button:hover {{ background: #06ad56; }}
.hint {{ font-size: 12px; color: #888; margin-top: 12px; }}
</style>
</head>
<body>
<div class="card">
<h2>滑块验证</h2>
<form id="f" action="{action}" method="get" target="_top">
<label>Key:</label>
<input type="text" name="key" value="{k}" placeholder="请输入key" />
<label>Data62:</label>
<input type="text" name="data62" value="{d}" placeholder="请输入data62" />
<label>Original Ticket:</label>
<input type="text" name="original_ticket" value="{t}" placeholder="请输入original_ticket" />
<button type="submit">开始验证</button>
</form>
<p class="hint">参数已自动填充,点击「开始验证」将提交到第三方滑块服务。</p>
</div>
</body>
</html>"""
@app.get("/auth/slider-form", response_class=HTMLResponse)
async def slider_form(
    key: str = Query(..., description="Key提交到第三方滑块"),
    data62: str = Query("", description="Data62"),
    ticket: str = Query(..., description="Original Ticket"),
):
    """Serve the pre-filled slider form (loaded in an iframe) that submits to the third-party service."""
    page = _slider_form_html(key, data62, ticket)
    return HTMLResponse(content=page)
# ---------- R1-2 customer profiles / R1-3 scheduled greetings / R1-4 segmented push / messaging & sending ----------
class CustomerCreate(BaseModel):
    """Request body for POST /api/customers (create or update a customer profile)."""
    key: str
    wxid: str
    remark_name: Optional[str] = ""
    region: Optional[str] = ""
    age: Optional[str] = ""
    gender: Optional[str] = ""
    level: Optional[str] = ""  # purchasing (wholesale) tier
    tags: Optional[List[str]] = None
class GreetingTaskCreate(BaseModel):
    """Request body for POST /api/greeting-tasks."""
    key: str
    name: str
    send_time: str  # ISO-format trigger time, e.g. 2026-03-11T14:30:00; must be in the future
    customer_tags: Optional[List[str]] = None
    template: str
    use_qwen: Optional[bool] = False
class ProductTagCreate(BaseModel):
    """Request body for POST /api/product-tags."""
    key: str
    name: str
class PushGroupCreate(BaseModel):
    """Request body for POST /api/push-groups."""
    key: str
    name: str
    customer_ids: Optional[List[str]] = None
    tag_ids: Optional[List[str]] = None
class PushTaskCreate(BaseModel):
    """Request body for POST /api/push-tasks."""
    key: str
    product_tag_id: str
    group_id: str
    content: str
    send_at: Optional[str] = None  # optional send time; format is interpreted by store — confirm there
class SendMessageBody(BaseModel):
    """Request body for POST /api/send-message (single text message)."""
    key: str
    to_user_name: str
    content: str
class BatchSendItem(BaseModel):
    """One recipient/text pair inside BatchSendBody."""
    to_user_name: str
    content: str
class BatchSendBody(BaseModel):
    """Request body for POST /api/send-batch."""
    key: str
    items: List[BatchSendItem]
class SendImageBody(BaseModel):
    """Request body for POST /api/send-image."""
    key: str
    to_user_name: str
    image_content: str  # image as base64 or URL, per the upstream's convention
    text_content: Optional[str] = ""
    at_wxid_list: Optional[List[str]] = None
class QwenGenerateBody(BaseModel):
    """Request body for POST /api/qwen/generate; ``system`` is an optional system prompt."""
    prompt: str
    system: Optional[str] = None
@app.get("/api/customers")
async def api_list_customers(key: str = Query(..., description="账号 key")):
    """List the customer profiles stored for account ``key``."""
    customers = store.list_customers(key)
    return {"items": customers}


@app.post("/api/customers")
async def api_upsert_customer(body: CustomerCreate):
    """Create or update a customer profile; missing text fields are passed to the store as ""."""
    text_fields = {
        field: getattr(body, field) or ""
        for field in ("remark_name", "region", "age", "gender", "level")
    }
    return store.upsert_customer(body.key, body.wxid, tags=body.tags, **text_fields)


@app.get("/api/customers/{customer_id}")
async def api_get_customer(customer_id: str):
    """Return one customer profile by id, or 404."""
    row = store.get_customer(customer_id)
    if row:
        return row
    raise HTTPException(status_code=404, detail="customer not found")


@app.get("/api/customer-tags")
async def api_list_customer_tags(key: str = Query(..., description="账号 key")):
    """Return every tag used in this account's customer profiles, for dropdowns such as greeting tasks."""
    tags = store.list_customer_tags(key)
    return {"tags": tags}


@app.delete("/api/customers/{customer_id}")
async def api_delete_customer(customer_id: str):
    """Delete one customer profile by id, or 404 if it does not exist."""
    deleted = store.delete_customer(customer_id)
    if deleted:
        return {"ok": True}
    raise HTTPException(status_code=404, detail="customer not found")
@app.get("/api/greeting-tasks")
async def api_list_greeting_tasks(key: str = Query(..., description="账号 key")):
    """List the scheduled greeting tasks of account ``key``."""
    tasks = store.list_greeting_tasks(key)
    return {"items": tasks}
def _parse_send_time(s: str) -> Optional[datetime]:
    """Parse a task trigger time into a naive datetime in server-local time.

    Accepted formats:
    - ISO-8601, e.g. ``2026-03-11T14:30:00`` or ``2026-03-11T14:30``, with
      optional fractional seconds and an optional ``Z`` / ``±HH:MM`` offset;
    - ``YYYY-MM-DD HH:MM:SS``.

    Timezone-aware inputs (such as JS ``toISOString()`` output) are converted
    to server-local time before tzinfo is dropped, so they compare correctly
    with ``datetime.now()``. Previously the offset was silently discarded.

    Returns None for non-strings and unparseable text (e.g. cron expressions).
    """
    if not isinstance(s, str):
        return None
    text = s.strip()
    if not text:
        return None
    # datetime.fromisoformat only understands a trailing "Z" from Python 3.11 on.
    if text[-1] in "Zz":
        text = text[:-1] + "+00:00"
    try:
        dt = datetime.fromisoformat(text)
    except ValueError:
        # Fallback: parse only the first 19 characters ("YYYY-MM-DD?HH:MM:SS").
        # This tolerates trailing noise such as unusual fractional-second widths.
        head = text[:19]
        try:
            if "T" in head:
                dt = datetime.fromisoformat(head)
            else:
                dt = datetime.strptime(head, "%Y-%m-%d %H:%M:%S")
        except ValueError:
            return None
    if dt.tzinfo is not None:
        # Convert to server-local wall-clock time, then drop tzinfo so the
        # result compares with naive datetime.now().
        dt = dt.astimezone().replace(tzinfo=None)
    return dt
@app.post("/api/greeting-tasks")
async def api_create_greeting_task(body: GreetingTaskCreate):
    """Create a greeting task after checking that ``send_time`` parses and lies in the future.

    Raises HTTPException(400) with a user-facing message when the check fails.
    """
    trigger_at = _parse_send_time(body.send_time)
    if not trigger_at:
        raise HTTPException(status_code=400, detail="触发时间格式无效,请使用 日期+时分秒 选择器")
    if trigger_at <= datetime.now():
        raise HTTPException(status_code=400, detail="触发时间必须是未来时间,请重新选择")
    options = {
        "customer_tags": body.customer_tags or [],
        "template": body.template,
        "use_qwen": bool(body.use_qwen),
    }
    return store.create_greeting_task(body.key, body.name, body.send_time, **options)
@app.patch("/api/greeting-tasks/{task_id}")
async def api_update_greeting_task(task_id: str, body: dict):
    """Partially update a greeting task.

    Only name, send_time, customer_tags, template, use_qwen and enabled are
    applied; other keys are ignored. A new ``send_time`` must parse and lie in
    the future (otherwise 400). Unknown ids give 404.
    """
    if "send_time" in body:
        new_time = _parse_send_time(body["send_time"])
        if not new_time:
            raise HTTPException(status_code=400, detail="触发时间格式无效")
        if new_time <= datetime.now():
            raise HTTPException(status_code=400, detail="触发时间必须是未来时间")
    editable = ("name", "send_time", "customer_tags", "template", "use_qwen", "enabled")
    changes = {field: value for field, value in body.items() if field in editable}
    row = store.update_greeting_task(task_id, **changes)
    if row:
        return row
    raise HTTPException(status_code=404, detail="task not found")
@app.delete("/api/greeting-tasks/{task_id}")
async def api_delete_greeting_task(task_id: str):
    """Delete a greeting task by id, or 404 if it does not exist."""
    deleted = store.delete_greeting_task(task_id)
    if deleted:
        return {"ok": True}
    raise HTTPException(status_code=404, detail="task not found")
@app.get("/api/product-tags")
async def api_list_product_tags(key: str = Query(..., description="账号 key")):
    """List the product tags of account ``key``."""
    tags = store.list_product_tags(key)
    return {"items": tags}


@app.post("/api/product-tags")
async def api_create_product_tag(body: ProductTagCreate):
    """Create a product tag and return the stored row."""
    created = store.create_product_tag(body.key, body.name)
    return created


@app.delete("/api/product-tags/{tag_id}")
async def api_delete_product_tag(tag_id: str):
    """Delete a product tag by id, or 404 if it does not exist."""
    deleted = store.delete_product_tag(tag_id)
    if deleted:
        return {"ok": True}
    raise HTTPException(status_code=404, detail="tag not found")
@app.get("/api/push-groups")
async def api_list_push_groups(key: str = Query(..., description="账号 key")):
    """List the push groups of account ``key``."""
    groups = store.list_push_groups(key)
    return {"items": groups}


@app.post("/api/push-groups")
async def api_create_push_group(body: PushGroupCreate):
    """Create a push group; missing id lists are passed to the store as empty lists."""
    customer_ids = body.customer_ids or []
    tag_ids = body.tag_ids or []
    return store.create_push_group(body.key, body.name, customer_ids, tag_ids)


@app.patch("/api/push-groups/{group_id}")
async def api_update_push_group(group_id: str, body: dict):
    """Update a push group.

    Absent fields are passed to the store as None. Unknown ids give 404.
    """
    updates = {field: body.get(field) for field in ("name", "customer_ids", "tag_ids")}
    row = store.update_push_group(group_id, **updates)
    if row:
        return row
    raise HTTPException(status_code=404, detail="group not found")


@app.delete("/api/push-groups/{group_id}")
async def api_delete_push_group(group_id: str):
    """Delete a push group by id, or 404 if it does not exist."""
    deleted = store.delete_push_group(group_id)
    if deleted:
        return {"ok": True}
    raise HTTPException(status_code=404, detail="group not found")
@app.get("/api/push-tasks")
async def api_list_push_tasks(key: str = Query(..., description="账号 key"), limit: int = Query(100, le=500)):
    """List push tasks of account ``key``; ``limit`` (at most 500) is passed to the store."""
    tasks = store.list_push_tasks(key, limit=limit)
    return {"items": tasks}


@app.post("/api/push-tasks")
async def api_create_push_task(body: PushTaskCreate):
    """Create a push task for a product tag and push group, and return the stored row."""
    created = store.create_push_task(body.key, body.product_tag_id, body.group_id, body.content, body.send_at)
    return created


@app.get("/api/messages")
async def api_list_messages(key: str = Query(..., description="账号 key"), limit: int = Query(100, le=500)):
    """List stored sync/sent messages of account ``key``; ``limit`` (at most 500) is passed to the store."""
    messages = store.list_sync_messages(key, limit=limit)
    return {"items": messages}
def _upstream_url(path: str) -> str:
    """Join ``path`` onto WECHAT_UPSTREAM_BASE_URL, adding a leading "/" if it is missing."""
    if not path.startswith("/"):
        path = f"/{path}"
    return f"{WECHAT_UPSTREAM_BASE_URL.rstrip('/')}{path}"


async def _post_msg_items(
    key: str,
    path: str,
    msg_items: List[dict],
    *,
    timeout: float,
    label: str,
) -> httpx.Response:
    """POST ``{"MsgItem": msg_items}`` to the upstream ``path`` for account ``key``.

    ``label`` prefixes the warning log line on failure. Transport errors
    raised by httpx are not caught here.

    Raises:
        HTTPException(502): the upstream answered with HTTP status >= 400.
    """
    url = _upstream_url(path)
    async with httpx.AsyncClient(timeout=timeout) as client:
        resp = await client.post(url, params={"key": key}, json={"MsgItem": msg_items})
    if resp.status_code >= 400:
        body_preview = resp.text[:400] if resp.text else ""
        logger.warning("%s upstream %s: %s", label, resp.status_code, body_preview)
        raise HTTPException(
            status_code=502,
            detail=f"upstream_returned_{resp.status_code}: {body_preview}",
        )
    return resp


def _json_or(resp: httpx.Response, fallback: dict) -> Any:
    """Return the decoded JSON body of ``resp``, or ``fallback`` if it is not valid JSON."""
    try:
        return resp.json()
    except ValueError:
        return fallback


async def _send_message_upstream(key: str, to_user_name: str, content: str) -> dict:
    """Send one text message via the upstream and record it as sent.

    Returns the upstream JSON, or ``{"ok": True, "raw": ...}`` for a non-JSON body.
    Raises HTTPException(502) when the upstream answers with status >= 400.
    """
    resp = await _post_msg_items(
        key,
        SEND_MSG_PATH,
        [{"ToUserName": to_user_name, "MsgType": 1, "TextContent": content}],
        timeout=15.0,
        label="Send message",
    )
    store.append_sent_message(key, to_user_name, content)
    return _json_or(resp, {"ok": True, "raw": resp.text[:500]})


async def _send_batch_upstream(key: str, items: List[dict]) -> dict:
    """Send many text messages in a single upstream request and record each as sent.

    Items are dicts with ``to_user_name``/``ToUserName`` and
    ``content``/``TextContent``. Items without a recipient or without text are
    skipped, as the 400 message below requires both.

    Raises:
        HTTPException(400): no valid item remains.
        HTTPException(502): the upstream answers with status >= 400.
    """
    msg_items = []
    for it in items:
        to_user = (it.get("to_user_name") or it.get("ToUserName") or "").strip()
        content = (it.get("content") or it.get("TextContent") or "").strip()
        if not to_user or not content:
            continue
        msg_items.append({"ToUserName": to_user, "MsgType": 1, "TextContent": content})
    if not msg_items:
        raise HTTPException(status_code=400, detail="items 中至少需要一条有效 to_user_name 与 content")
    resp = await _post_msg_items(key, SEND_MSG_PATH, msg_items, timeout=30.0, label="Batch send")
    for it in msg_items:
        store.append_sent_message(key, it["ToUserName"], it["TextContent"])
    return _json_or(resp, {"ok": True, "sent": len(msg_items), "raw": resp.text[:500]})


async def _send_image_upstream(key: str, to_user_name: str, image_content: str,
                               text_content: Optional[str] = "",
                               at_wxid_list: Optional[List[str]] = None) -> dict:
    """Send an image message via the upstream and record it as "[图片]" plus any caption.

    The MsgItem carries ImageContent, MsgType=IMAGE_MSG_TYPE, and optionally
    TextContent and AtWxIDList.
    Raises HTTPException(502) when the upstream answers with status >= 400.
    """
    item = {
        "ToUserName": to_user_name,
        "MsgType": IMAGE_MSG_TYPE,
        "ImageContent": image_content or "",
        "TextContent": text_content or "",
        "AtWxIDList": at_wxid_list or [],
    }
    resp = await _post_msg_items(key, SEND_IMAGE_PATH, [item], timeout=15.0, label="Send image")
    store.append_sent_message(key, to_user_name, "[图片]" + ((" " + text_content) if text_content else ""))
    return _json_or(resp, {"ok": True, "raw": resp.text[:500]})
async def _relay_upstream(call: Any, failure_log: str) -> Any:
    """Await an upstream send coroutine.

    An HTTPException is re-raised unchanged. Any other error is logged with
    ``failure_log`` and wrapped in an HTTPException(502).
    """
    try:
        return await call
    except HTTPException:
        raise
    except Exception as exc:
        logger.exception(failure_log, exc)
        raise HTTPException(status_code=502, detail=f"upstream_error: {exc}") from exc


@app.post("/api/send-message")
async def api_send_message(body: SendMessageBody):
    """Send a single text message through the upstream."""
    return await _relay_upstream(
        _send_message_upstream(body.key, body.to_user_name, body.content),
        "Send message upstream error: %s",
    )


@app.post("/api/send-batch")
async def api_send_batch(body: BatchSendBody):
    """Bulk send: one upstream request delivering messages to many recipients (e.g. picked from friends/customers)."""
    items = [{"to_user_name": entry.to_user_name, "content": entry.content} for entry in body.items]
    return await _relay_upstream(_send_batch_upstream(body.key, items), "Batch send error: %s")


@app.post("/api/send-image")
async def api_send_image(body: SendImageBody):
    """Send an image message; fields map to MsgItem ImageContent, TextContent, ToUserName and AtWxIDList."""
    upload = _send_image_upstream(
        body.key,
        body.to_user_name,
        body.image_content,
        text_content=body.text_content or "",
        at_wxid_list=body.at_wxid_list,
    )
    return await _relay_upstream(upload, "Send image error: %s")
def _normalize_contact_list(raw: Any) -> List[dict]:
    """Flatten the various GetContactList response shapes into ``[{wxid, remark_name, ...}]``.

    ``raw`` may be:
    - a list of contacts;
    - a dict whose contacts sit under ``Data``/``data``, either as a list or
      as a dict holding ContactList, contactList, WxcontactList,
      wxcontactList, CachedContactList, List, list or items;
    - a dict with contacts directly under ``items``/``list``/``List``.

    Non-dict entries are dropped. For each dict entry:
    - ``wxid`` is the first truthy value of wxid, Wxid, UserName, userName,
      Alias, else "";
    - ``remark_name`` is the first truthy value of remark_name, RemarkName,
      NickName, nickName, DisplayName, else the wxid;
    - every other field is copied through, except wxid, Wxid, remark_name and
      RemarkName.
    """
    def first_truthy(src: dict, names: tuple, default: Any) -> Any:
        for name in names:
            value = src.get(name)
            if value:
                return value
        return default

    entries: Any = []
    if isinstance(raw, list):
        entries = raw
    elif isinstance(raw, dict):
        payload = raw.get("Data") or raw.get("data") or raw
        if isinstance(payload, list):
            entries = payload
        elif isinstance(payload, dict):
            entries = first_truthy(
                payload,
                ("ContactList", "contactList", "WxcontactList", "wxcontactList",
                 "CachedContactList", "List", "list", "items"),
                [],
            )
        if not entries:
            entries = first_truthy(raw, ("items", "list", "List"), [])

    reserved = ("wxid", "Wxid", "remark_name", "RemarkName")
    normalized = []
    for entry in entries:
        if not isinstance(entry, dict):
            continue
        wxid = first_truthy(entry, ("wxid", "Wxid", "UserName", "userName", "Alias"), "")
        remark = first_truthy(entry, ("remark_name", "RemarkName", "NickName", "nickName", "DisplayName"), wxid)
        extras = {name: value for name, value in entry.items() if name not in reserved}
        normalized.append({"wxid": wxid, "remark_name": remark, **extras})
    return normalized
# Request body for the upstream GetContactList call: CurrentChatRoomContactSeq
# and CurrentWxcontactSeq set to 0 request the full contact list.
GET_CONTACT_LIST_BODY = {"CurrentChatRoomContactSeq": 0, "CurrentWxcontactSeq": 0}


@app.get("/api/contact-list")
async def api_contact_list(key: str = Query(..., description="账号 key")):
    """Fetch all contacts of account ``key`` from the upstream, normalized.

    Makes a POST with GET_CONTACT_LIST_BODY as JSON and ``key`` in the query
    string. Never raises: an upstream HTTP error or any exception yields
    ``{"items": [], "error": ...}``.
    """
    path = CONTACT_LIST_PATH if CONTACT_LIST_PATH.startswith("/") else "/" + CONTACT_LIST_PATH
    url = WECHAT_UPSTREAM_BASE_URL.rstrip("/") + path
    try:
        async with httpx.AsyncClient(timeout=15.0) as client:
            resp = await client.post(url, params={"key": key}, json=GET_CONTACT_LIST_BODY)
        if resp.status_code >= 400:
            error_preview = resp.text[:200]
            logger.warning("GetContactList %s: %s", resp.status_code, error_preview)
            return {"items": [], "error": error_preview}
        raw = resp.json()
        if isinstance(raw, dict):
            # Log only the response's key structure (not the full list) to help confirm its shape.
            inner = raw.get("Data") or raw.get("data")
            if isinstance(inner, dict):
                inner_desc = list(inner.keys())
            else:
                inner_desc = getattr(inner, "__name__", type(inner).__name__)
            logger.info("GetContactList response keys: raw=%s, Data=%s", list(raw.keys()), inner_desc)
        items = _normalize_contact_list(raw)
        if not items and isinstance(raw, dict):
            # Second attempt: treat the Data payload itself as the top-level response.
            items = _normalize_contact_list(raw.get("Data") or raw.get("data") or raw)
        logger.info("GetContactList normalized items count: %s", len(items))
        return {"items": items}
    except Exception as e:
        logger.warning("GetContactList error: %s", e)
        return {"items": [], "error": str(e)}
@app.get("/api/friends")
async def api_list_friends(key: str = Query(..., description="账号 key")):
    """Friend list: proxy the upstream contact list (same source as /api/contact-list).

    If the upstream yields no contacts, fall back to the account's customer
    profiles via ``_friends_fallback``, marked with ``"source": "customers"``.
    This is what the original docstring promised; ``_friends_fallback`` was
    previously never called. If no customers exist either, or the fallback
    fails, the upstream result is returned unchanged.
    """
    result = await api_contact_list(key)
    if result.get("items"):
        return result
    try:
        fallback = _friends_fallback(key)
    except Exception as e:
        logger.warning("Friends fallback for key=%s failed: %s", key, e)
        return result
    if not fallback:
        return result
    logger.info("Friends list for key=%s served from customer profiles (%s items)", key, len(fallback))
    return {"items": fallback, "source": "customers"}
def _friends_fallback(key: str) -> List[dict]:
    """Expose the account's customer profiles as selectable contacts for bulk sending.

    Customers without a wxid are skipped; remark_name falls back to the wxid.
    """
    contacts = []
    for customer in store.list_customers(key):
        wxid = customer.get("wxid")
        if not wxid:
            continue
        contacts.append({
            "wxid": wxid,
            "remark_name": customer.get("remark_name") or wxid,
            "id": customer.get("id"),
        })
    return contacts
# ---------- AI takeover reply configuration (whitelist + super admins) ----------
class AIReplyConfigUpdate(BaseModel):
    """Request body for PATCH /api/ai-reply-config.

    Only wxids in ``super_admin_wxids`` or ``whitelist_wxids`` receive AI
    auto-replies (see ``_allowed_ai_reply``). How the store treats None
    ("leave unchanged" vs "clear") is not visible here — confirm in store.
    """
    key: str
    super_admin_wxids: Optional[List[str]] = None
    whitelist_wxids: Optional[List[str]] = None
@app.get("/api/ai-reply-config")
async def api_get_ai_reply_config(key: str = Query(..., description="账号 key")):
    """Return the account's AI reply config (super-admin and whitelist wxids).

    Empty lists are returned when nothing is configured.
    """
    cfg = store.get_ai_reply_config(key)
    return cfg if cfg else {"key": key, "super_admin_wxids": [], "whitelist_wxids": []}


@app.patch("/api/ai-reply-config")
async def api_update_ai_reply_config(body: AIReplyConfigUpdate):
    """Set the AI reply whitelist and super admins; only listed contacts get AI auto-replies."""
    lists = {
        "super_admin_wxids": body.super_admin_wxids,
        "whitelist_wxids": body.whitelist_wxids,
    }
    return store.update_ai_reply_config(body.key, **lists)
# ---------- Model management (switch between multiple models; API key configured per model) ----------
class ModelCreate(BaseModel):
    """Request body for POST /api/models (register an LLM configuration)."""
    name: str
    provider: str  # one of "qwen", "openai", "doubao" (validated in api_create_model)
    api_key: str
    base_url: Optional[str] = ""
    model_name: Optional[str] = ""
    is_current: Optional[bool] = False
class ModelUpdate(BaseModel):
    """Request body for PATCH /api/models/{model_id}; every field is optional.

    Fields left as None are passed to store.update_model as None; whether
    that means "leave unchanged" is decided in store — confirm there.
    """
    name: Optional[str] = None
    api_key: Optional[str] = None
    base_url: Optional[str] = None
    model_name: Optional[str] = None
def _mask_api_key(m: dict) -> dict:
    """Return a copy of model row ``m`` with a non-empty ``api_key`` replaced by "***".

    Falsy or non-dict inputs are returned unchanged (same object). The input
    dict is never mutated.
    """
    if not isinstance(m, dict) or not m:
        return m
    masked = {"api_key": "***"} if m.get("api_key") else {}
    return {**m, **masked}
@app.get("/api/models")
async def api_list_models():
    """List all configured models with their API keys masked."""
    models = store.list_models()
    return {"items": list(map(_mask_api_key, models))}


@app.get("/api/models/current")
async def api_get_current_model():
    """Return the currently selected model (API key masked), or ``{"current": None}``."""
    current = store.get_current_model()
    return {"current": _mask_api_key(current) if current else None}


@app.post("/api/models")
async def api_create_model(body: ModelCreate):
    """Register a model; ``provider`` must be qwen, openai or doubao (else 400)."""
    if body.provider not in {"qwen", "openai", "doubao"}:
        raise HTTPException(status_code=400, detail="provider must be qwen, openai or doubao")
    created = store.create_model(
        name=body.name,
        provider=body.provider,
        api_key=body.api_key,
        base_url=body.base_url or "",
        model_name=body.model_name or "",
        is_current=bool(body.is_current),
    )
    return _mask_api_key(created)


@app.patch("/api/models/{model_id}")
async def api_update_model(model_id: str, body: ModelUpdate):
    """Update a model's settings and return it with the API key masked, or 404."""
    fields = {
        "name": body.name,
        "api_key": body.api_key,
        "base_url": body.base_url,
        "model_name": body.model_name,
    }
    row = store.update_model(model_id, **fields)
    if row:
        return _mask_api_key(row)
    raise HTTPException(status_code=404, detail="model not found")


@app.post("/api/models/{model_id}/set-current")
async def api_set_current_model(model_id: str):
    """Make ``model_id`` the model used for all generation, or 404."""
    row = store.set_current_model(model_id)
    if row:
        return _mask_api_key(row)
    raise HTTPException(status_code=404, detail="model not found")


@app.delete("/api/models/{model_id}")
async def api_delete_model(model_id: str):
    """Delete a model by id, or 404 if it does not exist."""
    deleted = store.delete_model(model_id)
    if deleted:
        return {"ok": True}
    raise HTTPException(status_code=404, detail="model not found")
@app.post("/api/qwen/generate")
async def api_qwen_generate(body: QwenGenerateBody):
    """Generate text with the currently selected model; there is no env-var fallback.

    Returns 503 when ``llm_chat`` returns None (no usable model configured).
    """
    messages = [{"role": "system", "content": body.system}] if body.system else []
    messages.append({"role": "user", "content": body.prompt})
    text = await llm_chat(messages)
    if text is not None:
        return {"text": text}
    raise HTTPException(status_code=503, detail="请在「模型管理」页添加并选中模型、填写 API Key")


@app.post("/api/qwen/generate-greeting")
async def api_qwen_generate_greeting(
    remark_name: str = Query(...),
    region: str = Query(""),
    tags: Optional[str] = Query(None),
):
    """Generate a short greeting for a customer with the currently selected model.

    ``tags`` is a comma-separated list; blank entries are ignored. Returns 503
    when ``llm_chat`` returns None.
    """
    tag_list = [t.strip() for t in (tags or "").split(",") if t.strip()]
    parts = [f"请生成一句简短的微信问候语1-2句话客户备注名{remark_name}"]
    if region:
        parts.append(f",地区:{region}")
    if tag_list:
        parts.append(f",标签:{','.join(tag_list)}")
    parts.append("。不要解释,只输出问候语本身。")
    text = await llm_chat([{"role": "user", "content": "".join(parts)}])
    if text is not None:
        return {"text": text}
    raise HTTPException(status_code=503, detail="请在「模型管理」页添加并选中模型、填写 API Key")
class LogoutBody(BaseModel):
    """Request body for POST /auth/logout; ``key`` identifies the account to log out."""
    key: str
@app.post("/auth/logout")
async def logout(body: LogoutBody):
    """Log account ``key`` out via the upstream LogOut call.

    The upstream JSON body is returned as-is, whatever its HTTP status.

    Raises:
        HTTPException(400): ``key`` is empty.
        HTTPException(502): the upstream is unreachable or returned non-JSON.
    """
    key = body.key
    if not key:
        raise HTTPException(status_code=400, detail="key is required")
    url = f"{WECHAT_UPSTREAM_BASE_URL}/login/LogOut"
    logger.info("LogOut: key=%s, url=%s", key, url)
    try:
        async with httpx.AsyncClient(timeout=15.0) as client:
            resp = await client.get(url, params={"key": key})
    except Exception as exc:
        logger.exception("Error calling upstream LogOut: %s", exc)
        raise HTTPException(status_code=502, detail=f"upstream_error: {exc}") from exc
    body_text = resp.text[:500]
    logger.info(
        "Upstream LogOut response: status=%s, body=%s",
        resp.status_code,
        body_text,
    )
    try:
        return resp.json()
    except ValueError as exc:
        logger.warning("Upstream LogOut returned non-JSON body: %s", body_text)
        raise HTTPException(status_code=502, detail=f"upstream_invalid_json: {body_text[:200]}") from exc