Files
AiTool/backend/app/services/ai_service.py
2026-03-15 16:38:59 +08:00

244 lines
9.2 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import base64
import json
import os
import re
from pathlib import Path
from typing import Any, Dict, Tuple
from openai import AsyncOpenAI
from openai import NotFoundError as OpenAINotFoundError
# Legacy single-config file. It is only consulted when the multi-config file
# below is missing, unreadable, or contains no usable entry.
AI_CONFIG_PATH = Path("data/ai_config.json")
# Multi-config store: a JSON object with a "configs" list and the "active_id"
# of the entry currently selected. Both paths are relative, so they resolve
# against the process's working directory.
AI_CONFIGS_PATH = Path("data/ai_configs.json")
def get_active_ai_config() -> Dict[str, Any]:
"""
从 data/ai_configs.json 读取当前选用的配置;若无则从旧版 ai_config.json 迁移并返回。
供 router 与内部调用。
"""
defaults = {
"id": "",
"name": "",
"provider": "OpenAI",
"api_key": "",
"base_url": "",
"model_name": "gpt-4o-mini",
"temperature": 0.2,
"system_prompt_override": "",
}
if AI_CONFIGS_PATH.exists():
try:
data = json.loads(AI_CONFIGS_PATH.read_text(encoding="utf-8"))
configs = data.get("configs") or []
active_id = data.get("active_id") or ""
for c in configs:
if c.get("id") == active_id:
return {**defaults, **c}
if configs:
return {**defaults, **configs[0]}
except Exception:
pass
# 兼容旧版单文件
if AI_CONFIG_PATH.exists():
try:
data = json.loads(AI_CONFIG_PATH.read_text(encoding="utf-8"))
return {**defaults, **data}
except Exception:
pass
if not defaults.get("api_key"):
defaults["api_key"] = os.getenv("AI_API_KEY") or os.getenv("OPENAI_API_KEY") or ""
if not defaults.get("base_url") and os.getenv("AI_BASE_URL"):
defaults["base_url"] = os.getenv("AI_BASE_URL")
if defaults.get("model_name") == "gpt-4o-mini" and os.getenv("AI_MODEL"):
defaults["model_name"] = os.getenv("AI_MODEL")
return defaults
def _load_ai_config() -> Dict[str, Any]:
    """Return the AI configuration currently in effect.

    Internal alias for ``get_active_ai_config()`` used by the requirement
    analysis and invoice recognition helpers in this module.
    """
    active_config = get_active_ai_config()
    return active_config
def _client_from_config(config: Dict[str, Any]) -> AsyncOpenAI:
    """Create an ``AsyncOpenAI`` client from a configuration mapping.

    Args:
        config: Mapping that may contain ``api_key`` and ``base_url``. Both
            values are stripped of surrounding whitespace before use.

    Returns:
        A client built with the configured key. ``base_url`` is passed as
        ``None`` when it is missing or blank.

    Raises:
        RuntimeError: If ``api_key`` is missing or blank.
    """
    key = config.get("api_key") or ""
    key = key.strip()
    if key == "":
        raise RuntimeError("AI API Key 未配置,请在 设置 → AI 模型配置 中填写。")

    endpoint = (config.get("base_url") or "").strip()
    return AsyncOpenAI(api_key=key, base_url=endpoint if endpoint else None)
def _build_requirement_prompt(raw_text: str) -> str:
    """Compose the user prompt sent to the model for requirement analysis.

    The prompt has three parts, in this order: the role and the analysis
    rules, a sample of the JSON structure the model must reply with, and the
    customer's original requirement text appended verbatim at the end.

    Args:
        raw_text: The customer's requirement text.

    Returns:
        The full prompt string.
    """
    # Part 1: role description and the numbered analysis rules.
    briefing = "".join(
        (
            "你是一名资深的系统架构师,请阅读以下来自客户的原始需求文本,",
            "提炼出清晰的交付方案,并严格按照指定 JSON 结构输出。\n\n",
            "【要求】\n",
            "1. 按功能模块拆分需求。\n",
            "2. 每个模块给出简要说明和技术实现思路。\n",
            "3. 估算建议工时(以人天或人小时为单位,使用数字)。\n",
            "4. 可以根据你的经验给出每个模块的单价与小计金额,并给出总金额,",
            "方便后续生成报价单。\n\n",
        )
    )

    # Part 2: the reply-format heading followed by the sample JSON, one entry
    # per output line (joined with newlines below).
    schema_lines = (
        "【返回格式】请只返回 JSON不要包含任何额外说明文字",
        "{",
        ' "modules": [',
        " {",
        ' "name": "模块名称",',
        ' "description": "模块说明(可以为 Markdown 格式)",',
        ' "technical_approach": "技术实现思路Markdown 格式)",',
        ' "estimated_hours": 16,',
        ' "unit_price": 800,',
        ' "subtotal": 12800',
        " }",
        " ],",
        ' "total_estimated_hours": 40,',
        ' "total_amount": 32000,',
        ' "notes": "整体方案备注可选Markdown 格式)"',
        "}",
    )
    reply_format = "\n".join(schema_lines) + "\n\n"

    # Part 3: the customer's text, under its own heading.
    return f"{briefing}{reply_format}【客户原始需求】\n{raw_text}"
def _loads_model_json(content: str) -> Any:
    """Parse a model reply as JSON, tolerating a Markdown code fence around it.

    The reply is parsed as-is first, so a reply that is already valid JSON is
    unaffected. Only when that fails is the first fenced block (with or
    without a ``json`` tag) extracted and parsed instead.

    Raises:
        json.JSONDecodeError: If neither the raw reply nor its first fenced
            block is valid JSON.
    """
    try:
        return json.loads(content)
    except json.JSONDecodeError:
        # DOTALL lets the match span the line breaks inside the fenced block.
        fenced = re.search(
            r"```(?:json)?\s*(.*?)```", content, re.DOTALL | re.IGNORECASE
        )
        if fenced is None:
            raise
        return json.loads(fenced.group(1))


async def analyze_requirement(raw_text: str) -> Dict[str, Any]:
    """Ask the configured AI model to break customer requirements into modules.

    The active configuration is loaded through ``_load_ai_config()`` on every
    call. The model is asked for a JSON object (``response_format`` is
    ``json_object``); the system message is the configured
    ``system_prompt_override`` when present, otherwise a built-in instruction.

    Args:
        raw_text: The customer's requirement text.

    Returns:
        The decoded JSON reply. A top-level list is wrapped as
        ``{"modules": <list>, "total_estimated_hours": None,
        "total_amount": None, "notes": None}``; any other non-object reply
        becomes ``{}``.

    Raises:
        RuntimeError: If the API key is not configured, the configured model
            is not found, or the reply cannot be decoded as JSON.
    """
    import logging

    logger = logging.getLogger(__name__)
    config = _load_ai_config()
    client = _client_from_config(config)
    model = config.get("model_name") or "gpt-4o-mini"
    try:
        temperature = float(config.get("temperature", 0.2))
    except (TypeError, ValueError):
        # A null or non-numeric value stored in the config file must not
        # abort the request; use the same default as a missing key.
        temperature = 0.2
    system_override = (config.get("system_prompt_override") or "").strip()
    logger.info("AI 需求解析: 调用模型 %s,输入长度 %d 字符", model, len(raw_text))
    prompt = _build_requirement_prompt(raw_text)
    system_content = (
        system_override
        if system_override
        else "你是一名严谨的系统架构师,只能输出有效的 JSON不要输出任何解释文字。"
    )
    try:
        completion = await client.chat.completions.create(
            model=model,
            response_format={"type": "json_object"},
            messages=[
                {"role": "system", "content": system_content},
                {"role": "user", "content": prompt},
            ],
            temperature=temperature,
        )
    except OpenAINotFoundError as e:
        raise RuntimeError(
            "当前配置的模型不存在或无权访问。请在 设置 → AI 模型配置 中确认「模型名称」与当前提供商一致(如阿里云使用 qwen 系列、OpenAI 使用 gpt-4o-mini 等)。"
        ) from e
    # An empty reply is treated as an empty JSON object.
    content = completion.choices[0].message.content or "{}"
    try:
        data: Any = _loads_model_json(content)
    except json.JSONDecodeError as exc:
        logger.error("AI 返回非 JSON片段: %s", (content or "")[:200])
        raise RuntimeError(f"AI 返回的内容不是合法 JSON{content}") from exc
    # Some models return a list (e.g. modules only); normalize to expected dict shape
    if isinstance(data, list):
        data = {
            "modules": data,
            "total_estimated_hours": None,
            "total_amount": None,
            "notes": None,
        }
    if not isinstance(data, dict):
        data = {}
    mods = data.get("modules") or []
    logger.info("AI 需求解析完成: 模块数 %d", len(mods) if isinstance(mods, list) else 0)
    return data
async def test_connection() -> str:
    """Run the connection test against the currently active configuration.

    Returns:
        Whatever ``test_connection_with_config`` returns for that config.
    """
    active_config = get_active_ai_config()
    return await test_connection_with_config(active_config)
async def test_connection_with_config(config: Dict[str, Any]) -> str:
    """Send a minimal chat completion to check that *config* is usable.

    A single "Hello" user message, capped at 50 tokens, is sent to the
    configured model (``gpt-4o-mini`` when ``model_name`` is empty).

    Args:
        config: The configuration to test; it does not have to be the active one.

    Returns:
        The model's reply with surrounding whitespace removed, or ``"OK"``
        when the reply is empty.

    Raises:
        RuntimeError: If the API key is not configured or the model is not found.
    """
    ai_client = _client_from_config(config)
    request = {
        "model": config.get("model_name") or "gpt-4o-mini",
        "messages": [{"role": "user", "content": "Hello"}],
        "max_tokens": 50,
    }
    try:
        reply = await ai_client.chat.completions.create(**request)
    except OpenAINotFoundError as not_found:
        raise RuntimeError(
            "当前配置的模型不存在或无权访问。请在 设置 → AI 模型配置 中确认「模型名称」(如阿里云使用 qwen 系列)。"
        ) from not_found

    answer = (reply.choices[0].message.content or "").strip()
    return answer if answer else "OK"
def _parse_invoice_reply(content: str) -> Tuple[float | None, str | None]:
    """Turn the model's reply text into ``(amount, date)``.

    The reply is parsed as JSON directly; if that fails, the span from the
    first ``{`` to the last ``}`` is parsed instead, which copes with Markdown
    code fences and with explanatory text before or after the object.

    Returns:
        ``amount`` as a finite float (booleans, NaN and infinities are
        rejected) and ``date`` as a ``YYYY-MM-DD`` string that names a real
        calendar day. Each element is ``None`` when it is missing or unusable.
    """
    import math
    from datetime import date as _date

    text = content.strip()
    if not text:
        return (None, None)
    try:
        data = json.loads(text)
    except ValueError:
        first, last = text.find("{"), text.rfind("}")
        if first == -1 or last < first:
            return (None, None)
        try:
            data = json.loads(text[first : last + 1])
        except ValueError:
            return (None, None)
    if not isinstance(data, dict):
        return (None, None)

    amount = None
    amount_raw = data.get("amount")
    # bool is a subclass of int; float(True) == 1.0 would be a bogus amount.
    if amount_raw is not None and not isinstance(amount_raw, bool):
        try:
            candidate = float(amount_raw)
        except (TypeError, ValueError, OverflowError):
            candidate = None
        if candidate is not None and math.isfinite(candidate):
            amount = candidate

    date_str = None
    date_raw = data.get("date")
    if isinstance(date_raw, str) and re.match(r"\d{4}-\d{2}-\d{2}", date_raw):
        try:
            # The regex only checks the shape; this rejects e.g. 2024-13-45.
            _date.fromisoformat(date_raw[:10])
        except ValueError:
            pass
        else:
            date_str = date_raw[:10]
    return (amount, date_str)


async def extract_invoice_metadata(image_bytes: bytes, mime: str = "image/jpeg") -> Tuple[float | None, str | None]:
    """Use the configured AI model to read the total and date off an invoice image.

    The image is sent inline as a base64 ``data:`` URL together with a prompt
    asking for a JSON object with ``amount`` and ``date``.

    Args:
        image_bytes: Raw bytes of the invoice or receipt image.
        mime: MIME type placed in the data URL.

    Returns:
        ``(amount, date_yyyy_mm_dd)``. This is best-effort: when no API key is
        configured, or on any error (request failure, unsupported model,
        unparsable reply), ``(None, None)`` is returned instead of raising.
    """
    import logging

    config = _load_ai_config()
    api_key = (config.get("api_key") or "").strip()
    if not api_key:
        return (None, None)
    try:
        client = _client_from_config(config)
        model = config.get("model_name") or "gpt-4o-mini"
        b64 = base64.b64encode(image_bytes).decode("ascii")
        data_url = f"data:{mime};base64,{b64}"
        prompt = (
            "从这张发票/收据图片中识别并提取1) 价税合计/总金额数字不含货币符号2) 开票日期(格式 YYYY-MM-DD"
            "只返回 JSON不要其他文字格式{\"amount\": 数字或null, \"date\": \"YYYY-MM-DD\" 或 null}。"
        )
        completion = await client.chat.completions.create(
            model=model,
            messages=[
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": prompt},
                        {"type": "image_url", "image_url": {"url": data_url}},
                    ],
                }
            ],
            max_tokens=150,
        )
        content = completion.choices[0].message.content or ""
        return _parse_invoice_reply(content)
    except Exception as exc:
        # Deliberately best-effort: callers rely on (None, None) rather than
        # an exception. Log the cause so failures can still be diagnosed.
        logging.getLogger(__name__).warning(
            "Invoice metadata extraction failed: %s", exc
        )
        return (None, None)