fix:优化数据
This commit is contained in:
@@ -1,37 +1,71 @@
|
||||
import base64
|
||||
import json
|
||||
import os
|
||||
from typing import Any, Dict
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Tuple
|
||||
|
||||
from openai import AsyncOpenAI
|
||||
from openai import NotFoundError as OpenAINotFoundError
|
||||
|
||||
AI_CONFIG_PATH = Path("data/ai_config.json")
|
||||
AI_CONFIGS_PATH = Path("data/ai_configs.json")
|
||||
|
||||
|
||||
_client: AsyncOpenAI | None = None
|
||||
|
||||
|
||||
def get_ai_client() -> AsyncOpenAI:
|
||||
def get_active_ai_config() -> Dict[str, Any]:
    """
    Return the AI configuration that is currently in effect.

    Lookup order:
      1. ``AI_CONFIGS_PATH`` (multi-config file): the entry whose ``id`` equals
         ``active_id``; if none matches, the first entry in ``configs``.
      2. ``AI_CONFIG_PATH`` (legacy single-config file), read as-is.
      3. Built-in defaults, filled from the environment variables
         ``AI_API_KEY`` / ``OPENAI_API_KEY``, ``AI_BASE_URL`` and ``AI_MODEL``.

    Values found in a file are overlaid on the built-in defaults, so every key
    of the defaults dict is always present in the result. Environment variables
    are only consulted when neither file yields a usable configuration.

    An unreadable or malformed file is logged and skipped rather than raised,
    so callers always receive a dict.
    """
    import logging

    logger = logging.getLogger(__name__)

    defaults: Dict[str, Any] = {
        "id": "",
        "name": "",
        "provider": "OpenAI",
        "api_key": "",
        "base_url": "",
        "model_name": "gpt-4o-mini",
        "temperature": 0.2,
        "system_prompt_override": "",
    }

    if AI_CONFIGS_PATH.exists():
        try:
            data = json.loads(AI_CONFIGS_PATH.read_text(encoding="utf-8"))
        except (OSError, ValueError) as exc:
            # ValueError covers both JSONDecodeError and UnicodeDecodeError.
            logger.warning("Cannot read AI configs file %s: %s", AI_CONFIGS_PATH, exc)
        else:
            raw_configs = data.get("configs") if isinstance(data, dict) else None
            # Ignore entries that are not JSON objects instead of failing on them.
            configs = (
                [c for c in raw_configs if isinstance(c, dict)]
                if isinstance(raw_configs, list)
                else []
            )
            if configs:
                active_id = data.get("active_id") or ""
                for candidate in configs:
                    if candidate.get("id") == active_id:
                        return {**defaults, **candidate}
                # No entry matches active_id: fall back to the first one.
                return {**defaults, **configs[0]}

    # Backward compatibility with the legacy single-config file.
    if AI_CONFIG_PATH.exists():
        try:
            data = json.loads(AI_CONFIG_PATH.read_text(encoding="utf-8"))
        except (OSError, ValueError) as exc:
            logger.warning("Cannot read legacy AI config file %s: %s", AI_CONFIG_PATH, exc)
        else:
            if isinstance(data, dict):
                return {**defaults, **data}
            logger.warning(
                "Legacy AI config file %s does not contain a JSON object; ignoring it.",
                AI_CONFIG_PATH,
            )

    # No usable file: take whatever the environment provides.
    defaults["api_key"] = os.getenv("AI_API_KEY") or os.getenv("OPENAI_API_KEY") or ""
    defaults["base_url"] = os.getenv("AI_BASE_URL") or ""
    defaults["model_name"] = os.getenv("AI_MODEL") or "gpt-4o-mini"
    return defaults
|
||||
|
||||
api_key = os.getenv("AI_API_KEY") or os.getenv("OPENAI_API_KEY")
|
||||
|
||||
def _load_ai_config() -> Dict[str, Any]:
    """Return the AI config currently in effect (used by requirement analysis, invoice recognition, etc.)."""
    active_config = get_active_ai_config()
    return active_config
|
||||
|
||||
|
||||
def _client_from_config(config: Dict[str, Any]) -> AsyncOpenAI:
    """
    Build a new AsyncOpenAI client from a configuration dict.

    Reads ``api_key`` (required) and ``base_url`` (optional) from ``config``;
    both are stripped of surrounding whitespace, and a blank ``base_url`` is
    passed to the client as ``None``.

    Raises:
        RuntimeError: if ``api_key`` is missing or blank.
    """
    key = config.get("api_key") or ""
    key = key.strip()
    if key == "":
        raise RuntimeError("AI API Key 未配置,请在 设置 → AI 模型配置 中填写。")

    endpoint = config.get("base_url") or ""
    endpoint = endpoint.strip()
    if not endpoint:
        endpoint = None
    return AsyncOpenAI(api_key=key, base_url=endpoint)
|
||||
|
||||
|
||||
def _build_requirement_prompt(raw_text: str) -> str:
|
||||
@@ -71,38 +105,139 @@ def _build_requirement_prompt(raw_text: str) -> str:
|
||||
async def analyze_requirement(raw_text: str) -> Dict[str, Any]:
    """
    Call the AI model to analyze customer requirements.

    The active AI configuration is loaded through `_load_ai_config` on every
    call, so configuration changes take effect without a restart.

    Returns a Python dict matching the JSON structure described in
    `_build_requirement_prompt`. A bare JSON list returned by the model is
    wrapped as ``{"modules": [...]}`` with the remaining keys set to ``None``;
    any other non-object JSON value yields an empty dict.

    Raises:
        RuntimeError: if no API key is configured (raised by
            `_client_from_config`), if the provider reports the model as not
            found, or if the model's reply is not valid JSON.
    """
    import logging

    logger = logging.getLogger(__name__)

    config = _load_ai_config()
    client = _client_from_config(config)
    model = config.get("model_name") or "gpt-4o-mini"
    # A config file may hold null or a non-numeric temperature; fall back to
    # the default instead of failing the whole request.
    try:
        temperature = float(config.get("temperature", 0.2))
    except (TypeError, ValueError):
        temperature = 0.2
    system_override = (config.get("system_prompt_override") or "").strip()

    logger.info("AI 需求解析: 调用模型 %s,输入长度 %d 字符", model, len(raw_text))

    prompt = _build_requirement_prompt(raw_text)
    system_content = (
        system_override
        if system_override
        else "你是一名严谨的系统架构师,只能输出有效的 JSON,不要输出任何解释文字。"
    )

    try:
        completion = await client.chat.completions.create(
            model=model,
            response_format={"type": "json_object"},
            messages=[
                {"role": "system", "content": system_content},
                {"role": "user", "content": prompt},
            ],
            temperature=temperature,
        )
    except OpenAINotFoundError as e:
        raise RuntimeError(
            "当前配置的模型不存在或无权访问。请在 设置 → AI 模型配置 中确认「模型名称」与当前提供商一致(如阿里云使用 qwen 系列、OpenAI 使用 gpt-4o-mini 等)。"
        ) from e

    content = completion.choices[0].message.content or "{}"
    try:
        data: Any = json.loads(content)
    except json.JSONDecodeError as exc:
        logger.error("AI 返回非 JSON,片段: %s", content[:200])
        raise RuntimeError(f"AI 返回的内容不是合法 JSON:{content}") from exc

    # Some models return a list (e.g. modules only); normalize to expected dict shape
    if isinstance(data, list):
        data = {
            "modules": data,
            "total_estimated_hours": None,
            "total_amount": None,
            "notes": None,
        }
    if not isinstance(data, dict):
        data = {}

    mods = data.get("modules") or []
    logger.info("AI 需求解析完成: 模块数 %d", len(mods) if isinstance(mods, list) else 0)
    return data
|
||||
|
||||
|
||||
async def test_connection() -> str:
    """Test the connection using the currently selected configuration."""
    active_config = get_active_ai_config()
    reply = await test_connection_with_config(active_config)
    return reply
|
||||
|
||||
|
||||
async def test_connection_with_config(config: Dict[str, Any]) -> str:
    """
    Send a minimal completion with the given config to verify the API key and base URL.

    Used both for testing the active configuration and for testing a specific
    configuration chosen by the caller.

    Returns the model's reply with surrounding whitespace removed, or "OK"
    when the reply is empty.

    Raises:
        RuntimeError: if the provider reports the configured model as not
            found (also raised by `_client_from_config` when no API key is set).
    """
    client = _client_from_config(config)
    request_kwargs = {
        "model": config.get("model_name") or "gpt-4o-mini",
        "messages": [{"role": "user", "content": "Hello"}],
        "max_tokens": 50,
    }

    try:
        completion = await client.chat.completions.create(**request_kwargs)
    except OpenAINotFoundError as e:
        raise RuntimeError(
            "当前配置的模型不存在或无权访问。请在 设置 → AI 模型配置 中确认「模型名称」(如阿里云使用 qwen 系列)。"
        ) from e

    reply = completion.choices[0].message.content
    text = (reply or "").strip()
    if text:
        return text
    return "OK"
|
||||
|
||||
|
||||
async def extract_invoice_metadata(image_bytes: bytes, mime: str = "image/jpeg") -> Tuple[float | None, str | None]:
    """
    Use AI vision to extract total amount and invoice date from an image.

    The image is sent to the configured model as a base64 data URL together
    with a prompt asking for a JSON object with ``amount`` and ``date`` keys.

    Args:
        image_bytes: Raw bytes of the invoice/receipt image.
        mime: MIME type placed in the data URL.

    Returns:
        ``(amount, date_yyyy_mm_dd)``. Each element is ``None`` when it could
        not be extracted. This function is best-effort: with no API key
        configured, or on any error (including an unsupported model), it
        returns ``(None, None)`` instead of raising.
    """
    import logging

    config = _load_ai_config()
    if not (config.get("api_key") or "").strip():
        return (None, None)

    try:
        client = _client_from_config(config)
        model = config.get("model_name") or "gpt-4o-mini"
        b64 = base64.b64encode(image_bytes).decode("ascii")
        data_url = f"data:{mime};base64,{b64}"
        prompt = (
            "从这张发票/收据图片中识别并提取:1) 价税合计/总金额(数字,不含货币符号);2) 开票日期(格式 YYYY-MM-DD)。"
            "只返回 JSON,不要其他文字,格式:{\"amount\": 数字或null, \"date\": \"YYYY-MM-DD\" 或 null}。"
        )
        completion = await client.chat.completions.create(
            model=model,
            messages=[
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": prompt},
                        {"type": "image_url", "image_url": {"url": data_url}},
                    ],
                }
            ],
            max_tokens=150,
        )
        content = (completion.choices[0].message.content or "").strip()
        return _parse_invoice_reply(content)
    except Exception as exc:
        # Deliberate best-effort boundary: callers rely on (None, None) instead
        # of an exception, but the cause is logged so failures stay diagnosable.
        logging.getLogger(__name__).warning("Invoice metadata extraction failed: %s", exc)
        return (None, None)


def _parse_invoice_reply(content: str) -> Tuple[float | None, str | None]:
    """Parse the model's reply text into ``(amount, date_yyyy_mm_dd)``; unusable parts become ``None``."""
    from datetime import date

    if not content:
        return (None, None)

    # Unwrap a markdown code fence. DOTALL lets the match span lines, so text
    # on lines before the opening fence is skipped; the `$` alternative
    # tolerates a reply whose closing fence is missing.
    if "```" in content:
        fenced = re.search(r"```(?:json)?\s*(.*?)\s*(?:```|$)", content, re.DOTALL)
        if fenced:
            content = fenced.group(1).strip()

    try:
        data = json.loads(content)
    except ValueError:
        return (None, None)
    if not isinstance(data, dict):
        return (None, None)

    amount = None
    amount_raw = data.get("amount")
    if amount_raw is not None:
        try:
            amount = float(amount_raw)
        except (TypeError, ValueError):
            amount = None

    date_str = None
    date_raw = data.get("date")
    if isinstance(date_raw, str) and re.match(r"\d{4}-\d{2}-\d{2}", date_raw):
        candidate = date_raw[:10]
        # The pattern alone accepts impossible dates such as 2024-13-45.
        try:
            date.fromisoformat(candidate)
        except ValueError:
            pass
        else:
            date_str = candidate

    return (amount, date_str)
|
||||
|
||||
Reference in New Issue
Block a user