303 lines
11 KiB
Python
303 lines
11 KiB
Python
import base64
|
||
import json
|
||
import os
|
||
import re
|
||
from pathlib import Path
|
||
from typing import Any, Dict, Tuple, List
|
||
|
||
from openai import AsyncOpenAI
|
||
from openai import NotFoundError as OpenAINotFoundError
|
||
|
||
# Legacy single-configuration file (kept for backward compatibility).
AI_CONFIG_PATH = Path("data", "ai_config.json")
# Multi-configuration file: holds a list of configs plus the active id.
AI_CONFIGS_PATH = Path("data", "ai_configs.json")
|
||
|
||
|
||
def _read_json_config(path: Path) -> Any:
    """Return the parsed JSON content of *path*, or ``None`` if unavailable.

    ``None`` is returned when the file does not exist or cannot be read or
    decoded. Failures are logged rather than silently discarded so that a
    corrupted configuration file can be diagnosed.
    """
    import logging

    if not path.exists():
        return None
    try:
        return json.loads(path.read_text(encoding="utf-8"))
    except Exception:
        # Best-effort by design: a broken config file must not crash the
        # caller, it only makes us fall through to the next source.
        logging.getLogger(__name__).warning(
            "读取 AI 配置文件失败: %s", path, exc_info=True
        )
        return None


def get_active_ai_config() -> Dict[str, Any]:
    """Return the AI configuration that is currently selected.

    Lookup order:

    1. ``AI_CONFIGS_PATH``: the entry whose ``id`` equals ``active_id``, or
       the first entry when none matches.
    2. ``AI_CONFIG_PATH``: the legacy single-config file.
    3. Built-in defaults, overlaid with the ``AI_API_KEY`` /
       ``OPENAI_API_KEY``, ``AI_BASE_URL`` and ``AI_MODEL`` environment
       variables.

    Values loaded from a file are merged over the built-in defaults. The
    environment variables are consulted only when neither file yielded a
    configuration. This function only reads; it never writes either file.

    Returns:
        A dict containing at least the keys ``id``, ``name``, ``provider``,
        ``api_key``, ``base_url``, ``model_name``, ``temperature`` and
        ``system_prompt_override``.
    """
    defaults: Dict[str, Any] = {
        "id": "",
        "name": "",
        "provider": "OpenAI",
        "api_key": "",
        "base_url": "",
        "model_name": "gpt-4o-mini",
        "temperature": 0.2,
        "system_prompt_override": "",
    }

    multi = _read_json_config(AI_CONFIGS_PATH)
    if isinstance(multi, dict):
        raw_configs = multi.get("configs")
        # Ignore malformed entries instead of discarding the whole file.
        configs = (
            [entry for entry in raw_configs if isinstance(entry, dict)]
            if isinstance(raw_configs, list)
            else []
        )
        active_id = multi.get("active_id") or ""
        for entry in configs:
            if entry.get("id") == active_id:
                return {**defaults, **entry}
        if configs:
            return {**defaults, **configs[0]}

    # Backward compatibility with the legacy single-file layout.
    legacy = _read_json_config(AI_CONFIG_PATH)
    if isinstance(legacy, dict):
        return {**defaults, **legacy}

    # No usable file: fall back to environment variables over the defaults.
    defaults["api_key"] = os.getenv("AI_API_KEY") or os.getenv("OPENAI_API_KEY") or ""
    env_base_url = os.getenv("AI_BASE_URL")
    if env_base_url:
        defaults["base_url"] = env_base_url
    env_model = os.getenv("AI_MODEL")
    if env_model:
        defaults["model_name"] = env_model
    return defaults
|
||
|
||
|
||
def _load_ai_config() -> Dict[str, Any]:
    """Return the AI configuration currently in effect.

    Delegates to ``get_active_ai_config``; used by requirement analysis,
    invoice recognition and similar features.
    """
    active_config = get_active_ai_config()
    return active_config
|
||
|
||
|
||
def _client_from_config(config: Dict[str, Any]) -> AsyncOpenAI:
    """Build an ``AsyncOpenAI`` client from a configuration dict.

    Reads ``api_key`` and ``base_url`` from *config*, stripping surrounding
    whitespace. An empty ``base_url`` is passed to the client as ``None``.

    Raises:
        RuntimeError: if no API key is configured.
    """

    def _cleaned(field: str) -> str:
        # Missing or falsy values are treated as an empty string.
        return (config.get(field) or "").strip()

    key = _cleaned("api_key")
    if not key:
        raise RuntimeError("AI API Key 未配置,请在 设置 → AI 模型配置 中填写。")

    endpoint = _cleaned("base_url")
    return AsyncOpenAI(api_key=key, base_url=endpoint if endpoint else None)
|
||
|
||
|
||
def _build_requirement_prompt(raw_text: str) -> str:
    """Compose the user prompt for requirement analysis.

    The prompt consists of a role briefing, the numbered analysis rules, an
    example of the JSON structure the model must return, and finally the
    customer's raw requirement text.
    """
    briefing = (
        "你是一名资深的系统架构师,请阅读以下来自客户的原始需求文本,"
        "提炼出清晰的交付方案,并严格按照指定 JSON 结构输出。\n\n"
    )
    rules = (
        "【要求】\n"
        "1. 按功能模块拆分需求。\n"
        "2. 每个模块给出简要说明和技术实现思路。\n"
        "3. 估算建议工时(以人天或人小时为单位,使用数字)。\n"
        "4. 可以根据你的经验给出每个模块的单价与小计金额,并给出总金额,"
        "方便后续生成报价单。\n\n"
    )
    # Example of the expected reply; every fragment ends with its own newline.
    schema_lines = [
        "【返回格式】请只返回 JSON,不要包含任何额外说明文字:\n",
        "{\n",
        ' "modules": [\n',
        " {\n",
        ' "name": "模块名称",\n',
        ' "description": "模块说明(可以为 Markdown 格式)",\n',
        ' "technical_approach": "技术实现思路(Markdown 格式)",\n',
        ' "estimated_hours": 16,\n',
        ' "unit_price": 800,\n',
        ' "subtotal": 12800\n',
        " }\n",
        " ],\n",
        ' "total_estimated_hours": 40,\n',
        ' "total_amount": 32000,\n',
        ' "notes": "整体方案备注(可选,Markdown 格式)"\n',
        "}\n\n",
    ]
    reply_schema = "".join(schema_lines)
    return f"{briefing}{rules}{reply_schema}【客户原始需求】\n{raw_text}"
|
||
|
||
|
||
def _coerce_temperature(value: Any, default: float = 0.2) -> float:
    """Return *value* as a float, or *default* if it is missing or non-numeric."""
    try:
        return float(value)
    except (TypeError, ValueError):
        return default


def _decode_json_reply(content: str) -> Any:
    """Decode *content* as JSON, tolerating a surrounding Markdown code fence.

    Raises:
        json.JSONDecodeError: if neither the raw text nor a fenced payload
            inside it is valid JSON.
    """
    try:
        return json.loads(content)
    except json.JSONDecodeError:
        # Fallback only: a reply wrapped in a ``` fence is retried on the
        # fenced payload; anything else re-raises the original error.
        fenced = re.search(r"```(?:json)?\s*(.*?)\s*```", content, re.DOTALL)
        if fenced is None:
            raise
        return json.loads(fenced.group(1))


def _parse_requirement_reply(content: str) -> Dict[str, Any]:
    """Turn the model reply into the dict shape returned to callers.

    A top-level JSON list is treated as the ``modules`` list and wrapped in a
    dict whose totals and notes are ``None``. Any other non-dict JSON value
    yields an empty dict.

    Raises:
        RuntimeError: if the reply is not valid JSON.
    """
    import logging

    try:
        data: Any = _decode_json_reply(content)
    except json.JSONDecodeError as exc:
        logging.getLogger(__name__).error(
            "AI 返回非 JSON,片段: %s", (content or "")[:200]
        )
        raise RuntimeError(f"AI 返回的内容不是合法 JSON:{content}") from exc

    if isinstance(data, list):
        return {
            "modules": data,
            "total_estimated_hours": None,
            "total_amount": None,
            "notes": None,
        }
    return data if isinstance(data, dict) else {}


async def analyze_requirement(raw_text: str) -> Dict[str, Any]:
    """Ask the configured AI model to break customer requirements into modules.

    The active configuration is loaded through ``_load_ai_config()`` on every
    call, so configuration changes take effect without a restart.

    Args:
        raw_text: The customer's original requirement text.

    Returns:
        The decoded JSON object produced by the model (see
        ``_parse_requirement_reply`` for the normalisation rules).

    Raises:
        RuntimeError: if the API key is missing, the configured model is not
            found, the model returns no choices, or the reply is not JSON.
    """
    import logging

    logger = logging.getLogger(__name__)

    config = _load_ai_config()
    client = _client_from_config(config)
    model = config.get("model_name") or "gpt-4o-mini"
    # Config files may hold null or a non-numeric string here.
    temperature = _coerce_temperature(config.get("temperature"))
    system_override = (config.get("system_prompt_override") or "").strip()

    logger.info("AI 需求解析: 调用模型 %s,输入长度 %d 字符", model, len(raw_text))

    prompt = _build_requirement_prompt(raw_text)
    system_content = (
        system_override
        if system_override
        else "你是一名严谨的系统架构师,只能输出有效的 JSON,不要输出任何解释文字。"
    )

    try:
        completion = await client.chat.completions.create(
            model=model,
            response_format={"type": "json_object"},
            messages=[
                {"role": "system", "content": system_content},
                {"role": "user", "content": prompt},
            ],
            temperature=temperature,
        )
    except OpenAINotFoundError as e:
        raise RuntimeError(
            "当前配置的模型不存在或无权访问。请在 设置 → AI 模型配置 中确认「模型名称」与当前提供商一致(如阿里云使用 qwen 系列、OpenAI 使用 gpt-4o-mini 等)。"
        ) from e

    if not completion.choices:
        # Surface an empty response as a clear error instead of an IndexError.
        raise RuntimeError("AI 未返回任何结果,请稍后重试或检查模型配置。")

    content = completion.choices[0].message.content or "{}"
    data = _parse_requirement_reply(content)

    mods = data.get("modules") or []
    logger.info("AI 需求解析完成: 模块数 %d", len(mods) if isinstance(mods, list) else 0)
    return data
|
||
|
||
|
||
async def test_connection() -> str:
    """Test connectivity using the currently selected configuration.

    Returns whatever ``test_connection_with_config`` returns for the active
    configuration.
    """
    active_config = get_active_ai_config()
    reply = await test_connection_with_config(active_config)
    return reply
|
||
|
||
|
||
async def test_connection_with_config(config: Dict[str, Any]) -> str:
    """Send a minimal completion to verify the API key and base URL.

    Used both for the active configuration and for an explicitly supplied one.

    Args:
        config: Configuration dict; ``model_name`` falls back to
            ``gpt-4o-mini`` when empty.

    Returns:
        The model's stripped reply, or ``"OK"`` when the reply is empty.

    Raises:
        RuntimeError: if the API key is missing or the model is not found.
    """
    client = _client_from_config(config)
    request = {
        "model": config.get("model_name") or "gpt-4o-mini",
        "messages": [{"role": "user", "content": "Hello"}],
        "max_tokens": 50,
    }
    try:
        completion = await client.chat.completions.create(**request)
    except OpenAINotFoundError as e:
        raise RuntimeError(
            "当前配置的模型不存在或无权访问。请在 设置 → AI 模型配置 中确认「模型名称」(如阿里云使用 qwen 系列)。"
        ) from e

    reply = completion.choices[0].message.content
    text = reply.strip() if reply else ""
    if text:
        return text
    return "OK"
|
||
|
||
|
||
async def extract_invoice_metadata(image_bytes: bytes, mime: str = "image/jpeg") -> Tuple[float | None, str | None]:
|
||
"""
|
||
Use AI vision to extract total amount and invoice date from an image.
|
||
Returns (amount, date_yyyy_mm_dd). On any error or unsupported model, returns (None, None).
|
||
"""
|
||
config = _load_ai_config()
|
||
api_key = (config.get("api_key") or "").strip()
|
||
if not api_key:
|
||
return (None, None)
|
||
|
||
|
||
async def extract_finance_tags(
|
||
content_text: str,
|
||
doc_type: str,
|
||
filename: str = "",
|
||
) -> Tuple[List[str], Dict[str, Any]]:
|
||
"""
|
||
从附件文本内容中抽取标签与结构化信息(JSON)。
|
||
返回 (tags, meta)。
|
||
"""
|
||
config = _load_ai_config()
|
||
client = _client_from_config(config)
|
||
model = config.get("model_name") or "gpt-4o-mini"
|
||
temperature = float(config.get("temperature", 0.2))
|
||
|
||
prompt = (
|
||
"你是一名财务助理。请根据附件的文本内容,为它生成可检索的标签,并抽取关键字段。\n"
|
||
"只返回 JSON,不要任何解释文字。\n"
|
||
"输入信息:\n"
|
||
f"- 类型 doc_type: {doc_type}\n"
|
||
f"- 文件名 filename: {filename}\n"
|
||
"- 附件文本 content_text: (见下)\n\n"
|
||
"返回 JSON 格式:\n"
|
||
"{\n"
|
||
' "tags": ["标签1","标签2"],\n'
|
||
' "meta": {\n'
|
||
' "counterparty": "对方单位/收款方/付款方(如能识别)或 null",\n'
|
||
' "account": "账户/卡号后四位(如能识别)或 null",\n'
|
||
' "amount": "金额数字字符串或 null",\n'
|
||
' "date": "YYYY-MM-DD 或 null",\n'
|
||
' "summary": "一句话摘要"\n'
|
||
" }\n"
|
||
"}\n\n"
|
||
"content_text:\n"
|
||
f"{content_text[:12000]}\n"
|
||
)
|
||
|
||
completion = await client.chat.completions.create(
|
||
model=model,
|
||
response_format={"type": "json_object"},
|
||
messages=[{"role": "user", "content": prompt}],
|
||
temperature=temperature,
|
||
max_tokens=500,
|
||
)
|
||
content = completion.choices[0].message.content or "{}"
|
||
try:
|
||
data: Any = json.loads(content)
|
||
except Exception:
|
||
return ([], {"summary": "", "raw": content})
|
||
|
||
tags = data.get("tags") if isinstance(data, dict) else None
|
||
meta = data.get("meta") if isinstance(data, dict) else None
|
||
if not isinstance(tags, list):
|
||
tags = []
|
||
tags = [str(t).strip() for t in tags if str(t).strip()][:12]
|
||
if not isinstance(meta, dict):
|
||
meta = {}
|
||
return (tags, meta)
|
||
try:
|
||
client = _client_from_config(config)
|
||
model = config.get("model_name") or "gpt-4o-mini"
|
||
b64 = base64.b64encode(image_bytes).decode("ascii")
|
||
data_url = f"data:{mime};base64,{b64}"
|
||
prompt = (
|
||
"从这张发票/收据图片中识别并提取:1) 价税合计/总金额(数字,不含货币符号);2) 开票日期(格式 YYYY-MM-DD)。"
|
||
"只返回 JSON,不要其他文字,格式:{\"amount\": 数字或null, \"date\": \"YYYY-MM-DD\" 或 null}。"
|
||
)
|
||
completion = await client.chat.completions.create(
|
||
model=model,
|
||
messages=[
|
||
{
|
||
"role": "user",
|
||
"content": [
|
||
{"type": "text", "text": prompt},
|
||
{"type": "image_url", "image_url": {"url": data_url}},
|
||
],
|
||
}
|
||
],
|
||
max_tokens=150,
|
||
)
|
||
content = (completion.choices[0].message.content or "").strip()
|
||
if not content:
|
||
return (None, None)
|
||
# Handle markdown code block
|
||
if "```" in content:
|
||
content = re.sub(r"^.*?```(?:json)?\s*", "", content).strip()
|
||
content = re.sub(r"\s*```.*$", "", content).strip()
|
||
data = json.loads(content)
|
||
amount_raw = data.get("amount")
|
||
date_raw = data.get("date")
|
||
amount = None
|
||
if amount_raw is not None:
|
||
try:
|
||
amount = float(amount_raw)
|
||
except (TypeError, ValueError):
|
||
pass
|
||
date_str = None
|
||
if isinstance(date_raw, str) and re.match(r"\d{4}-\d{2}-\d{2}", date_raw):
|
||
date_str = date_raw[:10]
|
||
return (amount, date_str)
|
||
except Exception:
|
||
return (None, None)
|