# -*- coding: utf-8 -*-
"""Unified multi-model chat client.

Reads the current model configuration (API key, base_url, model_name) from
the store and dispatches the request either to an OpenAI-compatible
chat/completions endpoint or to the Doubao (Volcengine ARK) Responses API.
"""
import logging

from typing import List, Optional


logger = logging.getLogger("wechat-backend.llm")


# Support both package layout ("backend.store") and flat-script layout.
try:
    from backend import store
except ImportError:
    import store
|
||
|
||
|
||
def _doubao_input_from_messages(messages: List[dict]) -> list:
|
||
"""将 [{"role":"user","content":"..."}] 转为豆包 Responses API 的 input 格式。"""
|
||
result = []
|
||
for m in messages:
|
||
role = (m.get("role") or "user").lower()
|
||
if role == "system":
|
||
role = "user"
|
||
content = m.get("content")
|
||
if isinstance(content, str):
|
||
content = [{"type": "input_text", "text": content}]
|
||
elif isinstance(content, list):
|
||
# 已是多模态格式,或需转为 input_text
|
||
out = []
|
||
for item in content:
|
||
if isinstance(item, dict):
|
||
if item.get("type") == "input_text":
|
||
out.append(item)
|
||
elif "text" in item:
|
||
out.append({"type": "input_text", "text": item["text"]})
|
||
else:
|
||
out.append({"type": "input_text", "text": str(item)})
|
||
content = out if out else [{"type": "input_text", "text": ""}]
|
||
else:
|
||
content = [{"type": "input_text", "text": str(content or "")}]
|
||
result.append({"role": role, "content": content})
|
||
return result
|
||
|
||
|
||
async def _chat_doubao(api_key: str, base_url: str, model_name: str, messages: List[dict]) -> Optional[str]:
    """POST to the Doubao (Volcengine ARK) Responses API (``/responses``).

    Returns the assistant text on success, or None on any HTTP or parsing
    failure; errors are logged, never raised to the caller.
    """
    import httpx

    endpoint = f"{base_url.rstrip('/')}/responses"
    body = {"model": model_name, "input": _doubao_input_from_messages(messages)}
    try:
        async with httpx.AsyncClient(timeout=60.0) as client:
            resp = await client.post(
                endpoint,
                headers={
                    "Authorization": f"Bearer {api_key}",
                    "Content-Type": "application/json",
                },
                json=body,
            )
        if resp.status_code != 200:
            logger.warning("doubao API status=%s body=%s", resp.status_code, resp.text[:500])
            return None
        data = resp.json()
        # The API has been observed returning several shapes; probe in order:
        # output.output_text / output.text, then OpenAI-style choices inside
        # "output", then top-level choices.
        output = data.get("output") or data
        if isinstance(output, dict):
            direct = output.get("output_text") or output.get("text")
            if isinstance(direct, str):
                return direct
            choices = output.get("choices")
            if choices and len(choices) > 0:
                message = choices[0].get("message") or choices[0]
                if isinstance(message, dict) and message.get("content"):
                    return message["content"]
        if "choices" in data and data["choices"]:
            first = data["choices"][0]
            message = first.get("message", first)
            if isinstance(message, dict) and message.get("content"):
                return message["content"]
        logger.warning("doubao unknown response shape: %s", list(data.keys()))
        return None
    except Exception as e:
        logger.exception("doubao chat error: %s", e)
        return None
|
||
|
||
|
||
async def _openai_compatible_chat(api_key: str, base_url: str, model_name: str, messages: List[dict]) -> Optional[str]:
    """One chat/completions round-trip via the OpenAI-compatible SDK.

    Returns the reply content (possibly empty/None when the response has no
    choices). SDK exceptions propagate to the caller.
    """
    from openai import AsyncOpenAI

    client = AsyncOpenAI(api_key=api_key, base_url=base_url)
    r = await client.chat.completions.create(model=model_name, messages=messages)
    return r.choices[0].message.content if r.choices else None


async def chat(messages: List[dict], model_id: Optional[str] = None) -> Optional[str]:
    """Run one chat completion against a configured model.

    Args:
        messages: OpenAI-style [{"role": ..., "content": ...}] list.
        model_id: Explicit model-config id; when omitted, the currently
            selected model from the store is used.

    Returns:
        The assistant reply text, or None when no model/api_key is
        configured or the request fails (errors are logged, never raised).
    """
    model_config = store.get_model(model_id) if model_id else store.get_current_model()
    if not model_config or not model_config.get("api_key"):
        logger.warning("No model configured or no api_key")
        return None

    api_key = model_config["api_key"]
    base_url = (model_config.get("base_url") or "").strip()
    model_name = (model_config.get("model_name") or "qwen-turbo").strip()
    provider = (model_config.get("provider") or "openai").lower()

    # Fill in the provider's default endpoint when none was configured.
    if not base_url:
        if provider == "qwen":
            base_url = "https://dashscope.aliyuncs.com/compatible-mode/v1"
        elif provider == "doubao":
            base_url = "https://ark.cn-beijing.volces.com/api/v3"
        else:
            base_url = "https://api.openai.com/v1"

    # Doubao (Volcengine ARK) also exposes an OpenAI-compatible
    # chat/completions endpoint; try that first, and fall back to the native
    # Responses API on failure or empty content.
    if provider == "doubao":
        try:
            text = await _openai_compatible_chat(api_key, base_url, model_name, messages)
            if text:
                return text
        except Exception as e:
            logger.info("doubao chat/completions failed, try responses API: %s", e)
        return await _chat_doubao(api_key, base_url, model_name, messages)

    try:
        return await _openai_compatible_chat(api_key, base_url, model_name, messages)
    except Exception as e:
        logger.exception("llm chat error: %s", e)
        return None
|