fix:优化数据
This commit is contained in:
@@ -6,7 +6,19 @@ from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
|
||||
from backend.app.routers import customers, projects, finance
|
||||
from backend.app.routers import (
|
||||
customers,
|
||||
projects,
|
||||
finance,
|
||||
settings,
|
||||
ai_settings,
|
||||
email_configs,
|
||||
cloud_docs,
|
||||
cloud_doc_config,
|
||||
portal_links,
|
||||
)
|
||||
from backend.app.db import Base, engine
|
||||
from backend.app import models # noqa: F401 - ensure models are imported
|
||||
|
||||
|
||||
def create_app() -> FastAPI:
|
||||
@@ -16,6 +28,35 @@ def create_app() -> FastAPI:
|
||||
version="0.1.0",
|
||||
)
|
||||
|
||||
# Ensure database tables exist (especially when running in Docker)
|
||||
@app.on_event("startup")
|
||||
def on_startup() -> None:
|
||||
Base.metadata.create_all(bind=engine)
|
||||
# Add new columns to finance_records if they don't exist (Module 6)
|
||||
try:
|
||||
from sqlalchemy import text
|
||||
with engine.connect() as conn:
|
||||
r = conn.execute(text("PRAGMA table_info(finance_records)"))
|
||||
cols = [row[1] for row in r]
|
||||
if "amount" not in cols:
|
||||
conn.execute(text("ALTER TABLE finance_records ADD COLUMN amount NUMERIC(12,2)"))
|
||||
if "billing_date" not in cols:
|
||||
conn.execute(text("ALTER TABLE finance_records ADD COLUMN billing_date DATE"))
|
||||
conn.commit()
|
||||
except Exception:
|
||||
pass
|
||||
# Add customers.tags if missing (customer tags for project 收纳)
|
||||
try:
|
||||
from sqlalchemy import text
|
||||
with engine.connect() as conn:
|
||||
r = conn.execute(text("PRAGMA table_info(customers)"))
|
||||
cols = [row[1] for row in r]
|
||||
if "tags" not in cols:
|
||||
conn.execute(text("ALTER TABLE customers ADD COLUMN tags VARCHAR(512)"))
|
||||
conn.commit()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# CORS
|
||||
raw_origins = os.getenv("CORS_ORIGINS")
|
||||
if raw_origins:
|
||||
@@ -35,6 +76,12 @@ def create_app() -> FastAPI:
|
||||
app.include_router(customers.router)
|
||||
app.include_router(projects.router)
|
||||
app.include_router(finance.router)
|
||||
app.include_router(settings.router)
|
||||
app.include_router(ai_settings.router)
|
||||
app.include_router(email_configs.router)
|
||||
app.include_router(cloud_docs.router)
|
||||
app.include_router(cloud_doc_config.router)
|
||||
app.include_router(portal_links.router)
|
||||
|
||||
# Static data mount (for quotes, contracts, finance archives, etc.)
|
||||
data_dir = Path("data")
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from datetime import datetime
|
||||
from datetime import date, datetime
|
||||
|
||||
from sqlalchemy import (
|
||||
Date,
|
||||
Column,
|
||||
DateTime,
|
||||
ForeignKey,
|
||||
@@ -20,6 +21,7 @@ class Customer(Base):
|
||||
id: Mapped[int] = mapped_column(Integer, primary_key=True, index=True)
|
||||
name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
contact_info: Mapped[str | None] = mapped_column(String(512), nullable=True)
|
||||
tags: Mapped[str | None] = mapped_column(String(512), nullable=True) # 逗号分隔,如:重点客户,已签约
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), default=datetime.utcnow, nullable=False
|
||||
)
|
||||
@@ -51,6 +53,30 @@ class Project(Base):
|
||||
quotes: Mapped[list["Quote"]] = relationship(
|
||||
"Quote", back_populates="project", cascade="all, delete-orphan"
|
||||
)
|
||||
cloud_docs: Mapped[list["ProjectCloudDoc"]] = relationship(
|
||||
"ProjectCloudDoc", back_populates="project", cascade="all, delete-orphan"
|
||||
)
|
||||
|
||||
|
||||
class ProjectCloudDoc(Base):
|
||||
"""项目与云文档的映射,用于增量更新(有则 PATCH,无则 POST)。"""
|
||||
__tablename__ = "project_cloud_docs"
|
||||
|
||||
id: Mapped[int] = mapped_column(Integer, primary_key=True, index=True)
|
||||
project_id: Mapped[int] = mapped_column(
|
||||
Integer, ForeignKey("projects.id", ondelete="CASCADE"), nullable=False, index=True
|
||||
)
|
||||
platform: Mapped[str] = mapped_column(String(32), nullable=False, index=True) # feishu | yuque | tencent
|
||||
cloud_doc_id: Mapped[str] = mapped_column(String(256), nullable=False)
|
||||
cloud_url: Mapped[str | None] = mapped_column(String(512), nullable=True)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), default=datetime.utcnow, nullable=False
|
||||
)
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), default=datetime.utcnow, onupdate=datetime.utcnow, nullable=False
|
||||
)
|
||||
|
||||
project: Mapped["Project"] = relationship("Project", back_populates="cloud_docs")
|
||||
|
||||
|
||||
class Quote(Base):
|
||||
@@ -74,9 +100,11 @@ class FinanceRecord(Base):
|
||||
|
||||
id: Mapped[int] = mapped_column(Integer, primary_key=True, index=True)
|
||||
month: Mapped[str] = mapped_column(String(7), nullable=False, index=True) # YYYY-MM
|
||||
type: Mapped[str] = mapped_column(String(50), nullable=False) # invoice / bank_receipt / ...
|
||||
type: Mapped[str] = mapped_column(String(50), nullable=False) # invoice / bank_receipt / manual / ...
|
||||
file_name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
file_path: Mapped[str] = mapped_column(String(512), nullable=False)
|
||||
amount: Mapped[float | None] = mapped_column(Numeric(12, 2), nullable=True)
|
||||
billing_date: Mapped[date | None] = mapped_column(Date, nullable=True)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), default=datetime.utcnow, nullable=False
|
||||
)
|
||||
|
||||
300
backend/app/routers/ai_settings.py
Normal file
300
backend/app/routers/ai_settings.py
Normal file
@@ -0,0 +1,300 @@
|
||||
"""
|
||||
AI 模型配置:支持多套配置,持久化在 data/ai_configs.json,可选用当前生效配置。
|
||||
GET /settings/ai 当前选用配置;GET /settings/ai/list 列表;POST 新增;PUT /:id 更新;DELETE /:id 删除;POST /:id/activate 选用。
|
||||
"""
|
||||
import json
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Query, status
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from backend.app.services.ai_service import get_active_ai_config, test_connection_with_config
|
||||
|
||||
router = APIRouter(prefix="/settings/ai", tags=["ai-settings"])
|
||||
|
||||
CONFIGS_PATH = Path("data/ai_configs.json")
|
||||
LEGACY_CONFIG_PATH = Path("data/ai_config.json")
|
||||
|
||||
DEFAULT_FIELDS: Dict[str, Any] = {
|
||||
"provider": "OpenAI",
|
||||
"api_key": "",
|
||||
"base_url": "",
|
||||
"model_name": "gpt-4o-mini",
|
||||
"temperature": 0.2,
|
||||
"system_prompt_override": "",
|
||||
}
|
||||
|
||||
|
||||
class AIConfigRead(BaseModel):
|
||||
model_config = {"protected_namespaces": ()}
|
||||
|
||||
id: str = ""
|
||||
name: str = ""
|
||||
provider: str = "OpenAI"
|
||||
api_key: str = ""
|
||||
base_url: str = ""
|
||||
model_name: str = "gpt-4o-mini"
|
||||
temperature: float = 0.2
|
||||
system_prompt_override: str = ""
|
||||
|
||||
|
||||
class AIConfigListItem(BaseModel):
|
||||
"""列表项:不含完整 api_key,仅标记是否已配置"""
|
||||
id: str
|
||||
name: str
|
||||
provider: str
|
||||
model_name: str
|
||||
base_url: str = ""
|
||||
api_key_configured: bool = False
|
||||
is_active: bool = False
|
||||
|
||||
|
||||
class AIConfigCreate(BaseModel):
|
||||
model_config = {"protected_namespaces": ()}
|
||||
|
||||
name: str = Field("", max_length=64)
|
||||
provider: str = "OpenAI"
|
||||
api_key: str = ""
|
||||
base_url: str = ""
|
||||
model_name: str = "gpt-4o-mini"
|
||||
temperature: float = 0.2
|
||||
system_prompt_override: str = ""
|
||||
|
||||
|
||||
class AIConfigUpdate(BaseModel):
|
||||
model_config = {"protected_namespaces": ()}
|
||||
|
||||
name: str | None = Field(None, max_length=64)
|
||||
provider: str | None = None
|
||||
api_key: str | None = None
|
||||
base_url: str | None = None
|
||||
model_name: str | None = None
|
||||
temperature: float | None = None
|
||||
system_prompt_override: str | None = None
|
||||
|
||||
|
||||
def _load_configs_file() -> Dict[str, Any]:
|
||||
if not CONFIGS_PATH.exists():
|
||||
return {"configs": [], "active_id": ""}
|
||||
try:
|
||||
data = json.loads(CONFIGS_PATH.read_text(encoding="utf-8"))
|
||||
return {"configs": data.get("configs", []), "active_id": data.get("active_id", "") or ""}
|
||||
except Exception:
|
||||
return {"configs": [], "active_id": ""}
|
||||
|
||||
|
||||
def _migrate_from_legacy() -> None:
|
||||
if CONFIGS_PATH.exists():
|
||||
return
|
||||
if not LEGACY_CONFIG_PATH.exists():
|
||||
return
|
||||
try:
|
||||
legacy = json.loads(LEGACY_CONFIG_PATH.read_text(encoding="utf-8"))
|
||||
except Exception:
|
||||
return
|
||||
cfg = {**DEFAULT_FIELDS, **legacy}
|
||||
new_id = str(uuid.uuid4())[:8]
|
||||
payload = {
|
||||
"configs": [
|
||||
{
|
||||
"id": new_id,
|
||||
"name": "默认配置",
|
||||
"provider": cfg.get("provider", "OpenAI"),
|
||||
"api_key": cfg.get("api_key", ""),
|
||||
"base_url": cfg.get("base_url", ""),
|
||||
"model_name": cfg.get("model_name", "gpt-4o-mini"),
|
||||
"temperature": cfg.get("temperature", 0.2),
|
||||
"system_prompt_override": cfg.get("system_prompt_override", ""),
|
||||
}
|
||||
],
|
||||
"active_id": new_id,
|
||||
}
|
||||
CONFIGS_PATH.parent.mkdir(parents=True, exist_ok=True)
|
||||
CONFIGS_PATH.write_text(json.dumps(payload, ensure_ascii=False, indent=2), encoding="utf-8")
|
||||
|
||||
|
||||
def _save_configs(configs: List[Dict], active_id: str) -> None:
|
||||
CONFIGS_PATH.parent.mkdir(parents=True, exist_ok=True)
|
||||
CONFIGS_PATH.write_text(
|
||||
json.dumps({"configs": configs, "active_id": active_id}, ensure_ascii=False, indent=2),
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
|
||||
@router.get("", response_model=AIConfigRead)
|
||||
async def get_current_ai_settings():
|
||||
"""返回当前选用的 AI 配置(用于编辑表单与兼容旧接口)。"""
|
||||
_migrate_from_legacy()
|
||||
cfg = get_active_ai_config()
|
||||
return AIConfigRead(
|
||||
id=cfg.get("id", ""),
|
||||
name=cfg.get("name", ""),
|
||||
provider=cfg.get("provider", "OpenAI"),
|
||||
api_key=cfg.get("api_key", ""),
|
||||
base_url=cfg.get("base_url", ""),
|
||||
model_name=cfg.get("model_name", "gpt-4o-mini"),
|
||||
temperature=float(cfg.get("temperature", 0.2)),
|
||||
system_prompt_override=cfg.get("system_prompt_override", ""),
|
||||
)
|
||||
|
||||
|
||||
@router.get("/list", response_model=List[AIConfigListItem])
|
||||
async def list_ai_configs():
|
||||
"""列出所有已配置的模型,方便查看、选用或编辑。"""
|
||||
_migrate_from_legacy()
|
||||
data = _load_configs_file()
|
||||
configs = data.get("configs") or []
|
||||
active_id = data.get("active_id") or ""
|
||||
out = []
|
||||
for c in configs:
|
||||
out.append(
|
||||
AIConfigListItem(
|
||||
id=c.get("id", ""),
|
||||
name=c.get("name", "未命名"),
|
||||
provider=c.get("provider", "OpenAI"),
|
||||
model_name=c.get("model_name", ""),
|
||||
base_url=(c.get("base_url") or "")[:64] or "",
|
||||
api_key_configured=bool((c.get("api_key") or "").strip()),
|
||||
is_active=(c.get("id") == active_id),
|
||||
)
|
||||
)
|
||||
return out
|
||||
|
||||
|
||||
@router.get("/{config_id}", response_model=AIConfigRead)
|
||||
async def get_ai_config_by_id(config_id: str):
|
||||
"""获取单条配置(用于编辑)。"""
|
||||
_migrate_from_legacy()
|
||||
data = _load_configs_file()
|
||||
for c in data.get("configs") or []:
|
||||
if c.get("id") == config_id:
|
||||
return AIConfigRead(
|
||||
id=c.get("id", ""),
|
||||
name=c.get("name", ""),
|
||||
provider=c.get("provider", "OpenAI"),
|
||||
api_key=c.get("api_key", ""),
|
||||
base_url=c.get("base_url", ""),
|
||||
model_name=c.get("model_name", "gpt-4o-mini"),
|
||||
temperature=float(c.get("temperature", 0.2)),
|
||||
system_prompt_override=c.get("system_prompt_override", ""),
|
||||
)
|
||||
raise HTTPException(status_code=404, detail="配置不存在")
|
||||
|
||||
|
||||
@router.post("", response_model=AIConfigRead, status_code=status.HTTP_201_CREATED)
|
||||
async def create_ai_config(payload: AIConfigCreate):
|
||||
"""新增一套模型配置。"""
|
||||
_migrate_from_legacy()
|
||||
data = _load_configs_file()
|
||||
configs = list(data.get("configs") or [])
|
||||
active_id = data.get("active_id") or ""
|
||||
new_id = str(uuid.uuid4())[:8]
|
||||
name = (payload.name or "").strip() or f"{payload.provider} - {payload.model_name}"
|
||||
new_cfg = {
|
||||
"id": new_id,
|
||||
"name": name[:64],
|
||||
"provider": payload.provider or "OpenAI",
|
||||
"api_key": payload.api_key or "",
|
||||
"base_url": (payload.base_url or "").strip(),
|
||||
"model_name": (payload.model_name or "gpt-4o-mini").strip(),
|
||||
"temperature": float(payload.temperature) if payload.temperature is not None else 0.2,
|
||||
"system_prompt_override": (payload.system_prompt_override or "").strip(),
|
||||
}
|
||||
configs.append(new_cfg)
|
||||
if not active_id:
|
||||
active_id = new_id
|
||||
_save_configs(configs, active_id)
|
||||
return AIConfigRead(**new_cfg)
|
||||
|
||||
|
||||
@router.put("/{config_id}", response_model=AIConfigRead)
|
||||
async def update_ai_config(config_id: str, payload: AIConfigUpdate):
|
||||
"""更新指定配置。"""
|
||||
_migrate_from_legacy()
|
||||
data = _load_configs_file()
|
||||
configs = data.get("configs") or []
|
||||
for c in configs:
|
||||
if c.get("id") == config_id:
|
||||
if payload.name is not None:
|
||||
c["name"] = (payload.name or "").strip()[:64] or c.get("name", "")
|
||||
if payload.provider is not None:
|
||||
c["provider"] = payload.provider
|
||||
if payload.api_key is not None:
|
||||
c["api_key"] = payload.api_key
|
||||
if payload.base_url is not None:
|
||||
c["base_url"] = (payload.base_url or "").strip()
|
||||
if payload.model_name is not None:
|
||||
c["model_name"] = (payload.model_name or "").strip()
|
||||
if payload.temperature is not None:
|
||||
c["temperature"] = float(payload.temperature)
|
||||
if payload.system_prompt_override is not None:
|
||||
c["system_prompt_override"] = (payload.system_prompt_override or "").strip()
|
||||
_save_configs(configs, data.get("active_id", ""))
|
||||
return AIConfigRead(
|
||||
id=c.get("id", ""),
|
||||
name=c.get("name", ""),
|
||||
provider=c.get("provider", "OpenAI"),
|
||||
api_key=c.get("api_key", ""),
|
||||
base_url=c.get("base_url", ""),
|
||||
model_name=c.get("model_name", "gpt-4o-mini"),
|
||||
temperature=float(c.get("temperature", 0.2)),
|
||||
system_prompt_override=c.get("system_prompt_override", ""),
|
||||
)
|
||||
raise HTTPException(status_code=404, detail="配置不存在")
|
||||
|
||||
|
||||
@router.delete("/{config_id}", status_code=status.HTTP_204_NO_CONTENT)
|
||||
async def delete_ai_config(config_id: str):
|
||||
"""删除指定配置;若为当前选用,则改用列表第一项。"""
|
||||
_migrate_from_legacy()
|
||||
data = _load_configs_file()
|
||||
configs = [c for c in (data.get("configs") or []) if c.get("id") != config_id]
|
||||
active_id = data.get("active_id", "")
|
||||
if active_id == config_id:
|
||||
active_id = configs[0].get("id", "") if configs else ""
|
||||
_save_configs(configs, active_id)
|
||||
return None
|
||||
|
||||
|
||||
@router.post("/test")
|
||||
async def test_ai_connection(config_id: str | None = Query(None, description="指定配置 ID,不传则用当前选用")):
|
||||
"""测试连接;不传 config_id 时使用当前选用配置。"""
|
||||
if config_id:
|
||||
data = _load_configs_file()
|
||||
for c in data.get("configs") or []:
|
||||
if c.get("id") == config_id:
|
||||
try:
|
||||
result = await test_connection_with_config(c)
|
||||
return {"status": "ok", "message": result}
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e)) from e
|
||||
raise HTTPException(status_code=404, detail="配置不存在")
|
||||
try:
|
||||
result = await test_connection_with_config(get_active_ai_config())
|
||||
return {"status": "ok", "message": result}
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e)) from e
|
||||
|
||||
|
||||
@router.post("/{config_id}/activate", response_model=AIConfigRead)
|
||||
async def activate_ai_config(config_id: str):
|
||||
"""选用该配置为当前生效。"""
|
||||
_migrate_from_legacy()
|
||||
data = _load_configs_file()
|
||||
exists = any(c.get("id") == config_id for c in (data.get("configs") or []))
|
||||
if not exists:
|
||||
raise HTTPException(status_code=404, detail="配置不存在")
|
||||
_save_configs(data.get("configs", []), config_id)
|
||||
cfg = get_active_ai_config()
|
||||
return AIConfigRead(
|
||||
id=cfg.get("id", ""),
|
||||
name=cfg.get("name", ""),
|
||||
provider=cfg.get("provider", "OpenAI"),
|
||||
api_key=cfg.get("api_key", ""),
|
||||
base_url=cfg.get("base_url", ""),
|
||||
model_name=cfg.get("model_name", "gpt-4o-mini"),
|
||||
temperature=float(cfg.get("temperature", 0.2)),
|
||||
system_prompt_override=cfg.get("system_prompt_override", ""),
|
||||
)
|
||||
139
backend/app/routers/cloud_doc_config.py
Normal file
139
backend/app/routers/cloud_doc_config.py
Normal file
@@ -0,0 +1,139 @@
|
||||
"""
|
||||
云文档配置:各平台 API 凭证的存储与读取。
|
||||
飞书 App ID/Secret、语雀 Token、腾讯文档 Client ID/Secret。
|
||||
"""
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict
|
||||
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
router = APIRouter(prefix="/settings/cloud-doc-config", tags=["cloud-doc-config"])
|
||||
|
||||
CONFIG_PATH = Path("data/cloud_doc_credentials.json")
|
||||
|
||||
PLATFORMS = ("feishu", "yuque", "tencent")
|
||||
|
||||
|
||||
class FeishuConfig(BaseModel):
|
||||
app_id: str = Field("", description="飞书应用 App ID")
|
||||
app_secret: str = Field("", description="飞书应用 App Secret")
|
||||
|
||||
|
||||
class YuqueConfig(BaseModel):
|
||||
token: str = Field("", description="语雀 Personal Access Token")
|
||||
default_repo: str = Field("", description="默认知识库 namespace,如 my/repo")
|
||||
|
||||
|
||||
class TencentConfig(BaseModel):
|
||||
client_id: str = Field("", description="腾讯文档应用 Client ID")
|
||||
client_secret: str = Field("", description="腾讯文档应用 Client Secret")
|
||||
|
||||
|
||||
class FeishuConfigRead(BaseModel):
|
||||
app_id: str = ""
|
||||
app_secret_configured: bool = False
|
||||
|
||||
|
||||
class YuqueConfigRead(BaseModel):
|
||||
token_configured: bool = False
|
||||
default_repo: str = ""
|
||||
|
||||
|
||||
class TencentConfigRead(BaseModel):
|
||||
client_id: str = ""
|
||||
client_secret_configured: bool = False
|
||||
|
||||
|
||||
class CloudDocConfigRead(BaseModel):
|
||||
feishu: FeishuConfigRead
|
||||
yuque: YuqueConfigRead
|
||||
tencent: TencentConfigRead
|
||||
|
||||
|
||||
def _load_config() -> Dict[str, Any]:
|
||||
if not CONFIG_PATH.exists():
|
||||
return {}
|
||||
try:
|
||||
data = json.loads(CONFIG_PATH.read_text(encoding="utf-8"))
|
||||
return data if isinstance(data, dict) else {}
|
||||
except Exception:
|
||||
return {}
|
||||
|
||||
|
||||
def _save_config(data: Dict[str, Any]) -> None:
|
||||
CONFIG_PATH.parent.mkdir(parents=True, exist_ok=True)
|
||||
CONFIG_PATH.write_text(
|
||||
json.dumps(data, ensure_ascii=False, indent=2),
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
|
||||
def _mask_secrets_for_read(raw: Dict[str, Any]) -> CloudDocConfigRead:
|
||||
f = raw.get("feishu") or {}
|
||||
y = raw.get("yuque") or {}
|
||||
t = raw.get("tencent") or {}
|
||||
return CloudDocConfigRead(
|
||||
feishu=FeishuConfigRead(
|
||||
app_id=f.get("app_id") or "",
|
||||
app_secret_configured=bool((f.get("app_secret") or "").strip()),
|
||||
),
|
||||
yuque=YuqueConfigRead(
|
||||
token_configured=bool((y.get("token") or "").strip()),
|
||||
default_repo=(y.get("default_repo") or "").strip(),
|
||||
),
|
||||
tencent=TencentConfigRead(
|
||||
client_id=t.get("client_id") or "",
|
||||
client_secret_configured=bool((t.get("client_secret") or "").strip()),
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@router.get("", response_model=CloudDocConfigRead)
|
||||
async def get_cloud_doc_config():
|
||||
"""获取云文档配置(凭证以是否已配置返回,不返回明文)。"""
|
||||
raw = _load_config()
|
||||
return _mask_secrets_for_read(raw)
|
||||
|
||||
|
||||
@router.put("", response_model=CloudDocConfigRead)
|
||||
async def update_cloud_doc_config(payload: Dict[str, Any]):
|
||||
"""
|
||||
更新云文档配置。传各平台字段,未传的保留原值。
|
||||
例: { "feishu": { "app_id": "xxx", "app_secret": "yyy" }, "yuque": { "token": "zzz", "default_repo": "a/b" } }
|
||||
"""
|
||||
raw = _load_config()
|
||||
for platform in PLATFORMS:
|
||||
if platform not in payload or not isinstance(payload[platform], dict):
|
||||
continue
|
||||
p = payload[platform]
|
||||
if platform == "feishu":
|
||||
if "app_id" in p and p["app_id"] is not None:
|
||||
raw.setdefault("feishu", {})["app_id"] = str(p["app_id"]).strip()
|
||||
if "app_secret" in p and p["app_secret"] is not None:
|
||||
raw.setdefault("feishu", {})["app_secret"] = str(p["app_secret"]).strip()
|
||||
elif platform == "yuque":
|
||||
if "token" in p and p["token"] is not None:
|
||||
raw.setdefault("yuque", {})["token"] = str(p["token"]).strip()
|
||||
if "default_repo" in p and p["default_repo"] is not None:
|
||||
raw.setdefault("yuque", {})["default_repo"] = str(p["default_repo"]).strip()
|
||||
elif platform == "tencent":
|
||||
if "client_id" in p and p["client_id"] is not None:
|
||||
raw.setdefault("tencent", {})["client_id"] = str(p["client_id"]).strip()
|
||||
if "client_secret" in p and p["client_secret"] is not None:
|
||||
raw.setdefault("tencent", {})["client_secret"] = str(p["client_secret"]).strip()
|
||||
_save_config(raw)
|
||||
return _mask_secrets_for_read(raw)
|
||||
|
||||
|
||||
def get_credentials(platform: str) -> Dict[str, str]:
|
||||
"""供 cloud_doc_service 使用:读取某平台明文凭证。"""
|
||||
raw = _load_config()
|
||||
return (raw.get(platform) or {}).copy()
|
||||
|
||||
|
||||
def get_all_credentials() -> Dict[str, Dict[str, str]]:
|
||||
"""供推送流程使用:读取全部平台凭证(明文)。"""
|
||||
raw = _load_config()
|
||||
return {k: dict(v) for k, v in raw.items() if isinstance(v, dict)}
|
||||
91
backend/app/routers/cloud_docs.py
Normal file
91
backend/app/routers/cloud_docs.py
Normal file
@@ -0,0 +1,91 @@
|
||||
"""
|
||||
云文档快捷入口:持久化在 data/cloud_docs.json,支持增删改查。
|
||||
"""
|
||||
import json
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from fastapi import APIRouter, HTTPException, status
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
router = APIRouter(prefix="/settings/cloud-docs", tags=["cloud-docs"])
|
||||
|
||||
CONFIG_PATH = Path("data/cloud_docs.json")
|
||||
|
||||
|
||||
class CloudDocLinkCreate(BaseModel):
|
||||
name: str = Field(..., min_length=1, max_length=64, description="显示名称")
|
||||
url: str = Field(..., min_length=1, max_length=512, description="登录/入口 URL")
|
||||
|
||||
|
||||
class CloudDocLinkRead(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
url: str
|
||||
|
||||
|
||||
class CloudDocLinkUpdate(BaseModel):
|
||||
name: str | None = Field(None, min_length=1, max_length=64)
|
||||
url: str | None = Field(None, min_length=1, max_length=512)
|
||||
|
||||
|
||||
def _load_links() -> List[Dict[str, Any]]:
|
||||
if not CONFIG_PATH.exists():
|
||||
return []
|
||||
try:
|
||||
data = json.loads(CONFIG_PATH.read_text(encoding="utf-8"))
|
||||
return data if isinstance(data, list) else []
|
||||
except Exception:
|
||||
return []
|
||||
|
||||
|
||||
def _save_links(links: List[Dict[str, Any]]) -> None:
|
||||
CONFIG_PATH.parent.mkdir(parents=True, exist_ok=True)
|
||||
CONFIG_PATH.write_text(
|
||||
json.dumps(links, ensure_ascii=False, indent=2),
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
|
||||
@router.get("", response_model=List[CloudDocLinkRead])
|
||||
async def list_cloud_docs():
|
||||
"""获取所有云文档快捷入口。"""
|
||||
links = _load_links()
|
||||
return [CloudDocLinkRead(**x) for x in links]
|
||||
|
||||
|
||||
@router.post("", response_model=CloudDocLinkRead, status_code=status.HTTP_201_CREATED)
|
||||
async def create_cloud_doc(payload: CloudDocLinkCreate):
|
||||
"""新增一条云文档入口。"""
|
||||
links = _load_links()
|
||||
new_id = str(uuid.uuid4())[:8]
|
||||
new_item = {"id": new_id, "name": payload.name.strip(), "url": payload.url.strip()}
|
||||
links.append(new_item)
|
||||
_save_links(links)
|
||||
return CloudDocLinkRead(**new_item)
|
||||
|
||||
|
||||
@router.put("/{link_id}", response_model=CloudDocLinkRead)
|
||||
async def update_cloud_doc(link_id: str, payload: CloudDocLinkUpdate):
|
||||
"""更新名称或 URL。"""
|
||||
links = _load_links()
|
||||
for item in links:
|
||||
if item.get("id") == link_id:
|
||||
if payload.name is not None:
|
||||
item["name"] = payload.name.strip()
|
||||
if payload.url is not None:
|
||||
item["url"] = payload.url.strip()
|
||||
_save_links(links)
|
||||
return CloudDocLinkRead(**item)
|
||||
raise HTTPException(status_code=404, detail="云文档入口不存在")
|
||||
|
||||
|
||||
@router.delete("/{link_id}", status_code=status.HTTP_204_NO_CONTENT)
|
||||
async def delete_cloud_doc(link_id: str):
|
||||
"""删除一条云文档入口。"""
|
||||
links = _load_links()
|
||||
new_list = [x for x in links if x.get("id") != link_id]
|
||||
if len(new_list) == len(links):
|
||||
raise HTTPException(status_code=404, detail="云文档入口不存在")
|
||||
_save_links(new_list)
|
||||
@@ -16,9 +16,18 @@ router = APIRouter(prefix="/customers", tags=["customers"])
|
||||
|
||||
|
||||
@router.get("/", response_model=List[CustomerRead])
|
||||
async def list_customers(db: Session = Depends(get_db)):
|
||||
customers = db.query(models.Customer).order_by(models.Customer.created_at.desc()).all()
|
||||
return customers
|
||||
async def list_customers(
|
||||
q: str | None = None,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""列表客户,支持 q 按名称、联系方式模糊搜索。"""
|
||||
query = db.query(models.Customer).order_by(models.Customer.created_at.desc())
|
||||
if q and q.strip():
|
||||
term = f"%{q.strip()}%"
|
||||
query = query.filter(
|
||||
(models.Customer.name.ilike(term)) | (models.Customer.contact_info.ilike(term))
|
||||
)
|
||||
return query.all()
|
||||
|
||||
|
||||
@router.post("/", response_model=CustomerRead, status_code=status.HTTP_201_CREATED)
|
||||
@@ -26,6 +35,7 @@ async def create_customer(payload: CustomerCreate, db: Session = Depends(get_db)
|
||||
customer = models.Customer(
|
||||
name=payload.name,
|
||||
contact_info=payload.contact_info,
|
||||
tags=payload.tags,
|
||||
)
|
||||
db.add(customer)
|
||||
db.commit()
|
||||
@@ -53,6 +63,8 @@ async def update_customer(
|
||||
customer.name = payload.name
|
||||
if payload.contact_info is not None:
|
||||
customer.contact_info = payload.contact_info
|
||||
if payload.tags is not None:
|
||||
customer.tags = payload.tags
|
||||
|
||||
db.commit()
|
||||
db.refresh(customer)
|
||||
|
||||
183
backend/app/routers/email_configs.py
Normal file
183
backend/app/routers/email_configs.py
Normal file
@@ -0,0 +1,183 @@
|
||||
"""
|
||||
Email accounts for multi-email finance sync. Stored in data/email_configs.json.
|
||||
"""
|
||||
import json
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from fastapi import APIRouter, HTTPException, status
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
router = APIRouter(prefix="/settings/email", tags=["email-configs"])
|
||||
|
||||
CONFIG_PATH = Path("data/email_configs.json")
|
||||
|
||||
|
||||
class EmailConfigCreate(BaseModel):
|
||||
host: str = Field(..., description="IMAP host")
|
||||
port: int = Field(993, description="IMAP port")
|
||||
user: str = Field(..., description="Email address")
|
||||
password: str = Field(..., description="Password or authorization code")
|
||||
mailbox: str = Field("INBOX", description="Mailbox name")
|
||||
active: bool = Field(True, description="Include in sync")
|
||||
|
||||
|
||||
class EmailConfigRead(BaseModel):
|
||||
id: str
|
||||
host: str
|
||||
port: int
|
||||
user: str
|
||||
mailbox: str
|
||||
active: bool
|
||||
|
||||
|
||||
class EmailConfigUpdate(BaseModel):
|
||||
host: str | None = None
|
||||
port: int | None = None
|
||||
user: str | None = None
|
||||
password: str | None = None
|
||||
mailbox: str | None = None
|
||||
active: bool | None = None
|
||||
|
||||
|
||||
def _load_configs() -> List[Dict[str, Any]]:
|
||||
if not CONFIG_PATH.exists():
|
||||
return []
|
||||
try:
|
||||
data = json.loads(CONFIG_PATH.read_text(encoding="utf-8"))
|
||||
return data if isinstance(data, list) else []
|
||||
except Exception:
|
||||
return []
|
||||
|
||||
|
||||
def _save_configs(configs: List[Dict[str, Any]]) -> None:
|
||||
CONFIG_PATH.parent.mkdir(parents=True, exist_ok=True)
|
||||
CONFIG_PATH.write_text(
|
||||
json.dumps(configs, ensure_ascii=False, indent=2),
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
|
||||
def _to_read(c: Dict[str, Any]) -> Dict[str, Any]:
|
||||
return {
|
||||
"id": c["id"],
|
||||
"host": c["host"],
|
||||
"port": c["port"],
|
||||
"user": c["user"],
|
||||
"mailbox": c.get("mailbox", "INBOX"),
|
||||
"active": c.get("active", True),
|
||||
}
|
||||
|
||||
|
||||
@router.get("", response_model=List[EmailConfigRead])
|
||||
async def list_email_configs():
|
||||
"""List all email account configs (password omitted)."""
|
||||
configs = _load_configs()
|
||||
return [_to_read(c) for c in configs]
|
||||
|
||||
|
||||
@router.post("", response_model=EmailConfigRead, status_code=status.HTTP_201_CREATED)
|
||||
async def create_email_config(payload: EmailConfigCreate):
|
||||
"""Add a new email account."""
|
||||
configs = _load_configs()
|
||||
new_id = str(uuid.uuid4())
|
||||
configs.append({
|
||||
"id": new_id,
|
||||
"host": payload.host,
|
||||
"port": payload.port,
|
||||
"user": payload.user,
|
||||
"password": payload.password,
|
||||
"mailbox": payload.mailbox,
|
||||
"active": payload.active,
|
||||
})
|
||||
_save_configs(configs)
|
||||
return _to_read(configs[-1])
|
||||
|
||||
|
||||
@router.put("/{config_id}", response_model=EmailConfigRead)
|
||||
async def update_email_config(config_id: str, payload: EmailConfigUpdate):
|
||||
"""Update an email account (omit password to keep existing)."""
|
||||
configs = _load_configs()
|
||||
for c in configs:
|
||||
if c.get("id") == config_id:
|
||||
if payload.host is not None:
|
||||
c["host"] = payload.host
|
||||
if payload.port is not None:
|
||||
c["port"] = payload.port
|
||||
if payload.user is not None:
|
||||
c["user"] = payload.user
|
||||
if payload.password is not None:
|
||||
c["password"] = payload.password
|
||||
if payload.mailbox is not None:
|
||||
c["mailbox"] = payload.mailbox
|
||||
if payload.active is not None:
|
||||
c["active"] = payload.active
|
||||
_save_configs(configs)
|
||||
return _to_read(c)
|
||||
raise HTTPException(status_code=404, detail="Email config not found")
|
||||
|
||||
|
||||
@router.delete("/{config_id}", status_code=status.HTTP_204_NO_CONTENT)
|
||||
async def delete_email_config(config_id: str):
|
||||
"""Remove an email account."""
|
||||
configs = _load_configs()
|
||||
new_list = [c for c in configs if c.get("id") != config_id]
|
||||
if len(new_list) == len(configs):
|
||||
raise HTTPException(status_code=404, detail="Email config not found")
|
||||
_save_configs(new_list)
|
||||
|
||||
|
||||
@router.get("/{config_id}/folders")
|
||||
async def list_email_folders(config_id: str):
|
||||
"""
|
||||
List mailbox folders for this account (for choosing custom labels).
|
||||
Returns [{ "raw": "...", "decoded": "收件箱" }, ...]. Use decoded for display and for mailbox config.
|
||||
"""
|
||||
import asyncio
|
||||
from backend.app.services.email_service import list_mailboxes_for_config
|
||||
|
||||
configs = _load_configs()
|
||||
config = next((c for c in configs if c.get("id") == config_id), None)
|
||||
if not config:
|
||||
raise HTTPException(status_code=404, detail="Email config not found")
|
||||
host = config.get("host")
|
||||
user = config.get("user")
|
||||
password = config.get("password")
|
||||
port = int(config.get("port", 993))
|
||||
if not all([host, user, password]):
|
||||
raise HTTPException(status_code=400, detail="Config missing host/user/password")
|
||||
|
||||
def _fetch():
|
||||
return list_mailboxes_for_config(host, port, user, password)
|
||||
|
||||
try:
|
||||
folders = await asyncio.to_thread(_fetch)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=502, detail=f"无法连接邮箱或获取文件夹列表: {e}") from e
|
||||
|
||||
return {"folders": [{"raw": r, "decoded": d} for r, d in folders]}
|
||||
|
||||
|
||||
def get_email_configs_for_sync() -> List[Dict[str, Any]]:
|
||||
"""Return list of configs that are active (for sync). Falls back to env if file empty."""
|
||||
configs = _load_configs()
|
||||
active = [c for c in configs if c.get("active", True)]
|
||||
if active:
|
||||
return active
|
||||
# Fallback to single account from env
|
||||
import os
|
||||
host = os.getenv("IMAP_HOST")
|
||||
user = os.getenv("IMAP_USER")
|
||||
password = os.getenv("IMAP_PASSWORD")
|
||||
if host and user and password:
|
||||
return [{
|
||||
"id": "env",
|
||||
"host": host,
|
||||
"port": int(os.getenv("IMAP_PORT", "993")),
|
||||
"user": user,
|
||||
"password": password,
|
||||
"mailbox": os.getenv("IMAP_MAILBOX", "INBOX"),
|
||||
"active": True,
|
||||
}]
|
||||
return []
|
||||
@@ -1,8 +1,20 @@
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from fastapi.responses import FileResponse
|
||||
from typing import List
|
||||
|
||||
from backend.app.schemas import FinanceSyncResponse, FinanceSyncResult
|
||||
from fastapi import APIRouter, Depends, File, HTTPException, Query, UploadFile
|
||||
from fastapi.responses import FileResponse
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from backend.app.db import get_db
|
||||
from backend.app import models
|
||||
from backend.app.schemas import (
|
||||
FinanceRecordRead,
|
||||
FinanceRecordUpdate,
|
||||
FinanceSyncResponse,
|
||||
FinanceSyncResult,
|
||||
FinanceUploadResponse,
|
||||
)
|
||||
from backend.app.services.email_service import create_monthly_zip, sync_finance_emails
|
||||
from backend.app.services.invoice_upload import process_invoice_upload
|
||||
|
||||
|
||||
router = APIRouter(prefix="/finance", tags=["finance"])
|
||||
@@ -13,10 +25,87 @@ async def sync_finance():
|
||||
try:
|
||||
items_raw = await sync_finance_emails()
|
||||
except RuntimeError as exc:
|
||||
raise HTTPException(status_code=500, detail=str(exc)) from exc
|
||||
# 邮箱配置/连接等问题属于可预期的业务错误,用 400 让前端直接展示原因,而不是泛化为 500。
|
||||
raise HTTPException(status_code=400, detail=str(exc)) from exc
|
||||
|
||||
items = [FinanceSyncResult(**item) for item in items_raw]
|
||||
return FinanceSyncResponse(items=items)
|
||||
details = [FinanceSyncResult(**item) for item in items_raw]
|
||||
return FinanceSyncResponse(
|
||||
status="success",
|
||||
new_files=len(details),
|
||||
details=details,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/months", response_model=List[str])
|
||||
async def list_finance_months(db: Session = Depends(get_db)):
|
||||
"""List distinct months that have finance records (YYYY-MM), newest first."""
|
||||
from sqlalchemy import distinct
|
||||
|
||||
rows = (
|
||||
db.query(distinct(models.FinanceRecord.month))
|
||||
.order_by(models.FinanceRecord.month.desc())
|
||||
.all()
|
||||
)
|
||||
return [r[0] for r in rows]
|
||||
|
||||
|
||||
@router.get("/records", response_model=List[FinanceRecordRead])
|
||||
async def list_finance_records(
|
||||
month: str = Query(..., description="YYYY-MM"),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""List finance records for a given month."""
|
||||
records = (
|
||||
db.query(models.FinanceRecord)
|
||||
.filter(models.FinanceRecord.month == month)
|
||||
.order_by(models.FinanceRecord.created_at.desc())
|
||||
.all()
|
||||
)
|
||||
return records
|
||||
|
||||
|
||||
@router.post("/upload", response_model=FinanceUploadResponse, status_code=201)
|
||||
async def upload_invoice(
|
||||
file: UploadFile = File(...),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""Upload an invoice (PDF or image). Saves to data/finance/{YYYY-MM}/manual/, runs AI OCR for amount/date."""
|
||||
suf = (file.filename or "").lower().split(".")[-1] if "." in (file.filename or "") else ""
|
||||
allowed = {"pdf", "jpg", "jpeg", "png", "webp"}
|
||||
if suf not in allowed:
|
||||
raise HTTPException(400, "仅支持 PDF、JPG、PNG、WEBP 格式")
|
||||
file_name, file_path, month_str, amount, billing_date = await process_invoice_upload(file)
|
||||
record = models.FinanceRecord(
|
||||
month=month_str,
|
||||
type="manual",
|
||||
file_name=file_name,
|
||||
file_path=file_path,
|
||||
amount=amount,
|
||||
billing_date=billing_date,
|
||||
)
|
||||
db.add(record)
|
||||
db.commit()
|
||||
db.refresh(record)
|
||||
return record
|
||||
|
||||
|
||||
@router.patch("/records/{record_id}", response_model=FinanceRecordRead)
|
||||
async def update_finance_record(
|
||||
record_id: int,
|
||||
payload: FinanceRecordUpdate,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""Update amount and/or billing_date of a finance record (e.g. after manual review)."""
|
||||
record = db.query(models.FinanceRecord).get(record_id)
|
||||
if not record:
|
||||
raise HTTPException(404, "记录不存在")
|
||||
if payload.amount is not None:
|
||||
record.amount = payload.amount
|
||||
if payload.billing_date is not None:
|
||||
record.billing_date = payload.billing_date
|
||||
db.commit()
|
||||
db.refresh(record)
|
||||
return record
|
||||
|
||||
|
||||
@router.get("/download/{month}")
|
||||
|
||||
91
backend/app/routers/portal_links.py
Normal file
91
backend/app/routers/portal_links.py
Normal file
@@ -0,0 +1,91 @@
|
||||
"""
|
||||
快捷门户入口:持久化在 data/portal_links.json,支持增删改查。
|
||||
"""
|
||||
import json
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from fastapi import APIRouter, HTTPException, status
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
router = APIRouter(prefix="/settings/portal-links", tags=["portal-links"])
|
||||
|
||||
CONFIG_PATH = Path("data/portal_links.json")
|
||||
|
||||
|
||||
class PortalLinkCreate(BaseModel):
|
||||
name: str = Field(..., min_length=1, max_length=64, description="显示名称")
|
||||
url: str = Field(..., min_length=1, max_length=512, description="门户 URL")
|
||||
|
||||
|
||||
class PortalLinkRead(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
url: str
|
||||
|
||||
|
||||
class PortalLinkUpdate(BaseModel):
|
||||
name: str | None = Field(None, min_length=1, max_length=64)
|
||||
url: str | None = Field(None, min_length=1, max_length=512)
|
||||
|
||||
|
||||
def _load_links() -> List[Dict[str, Any]]:
|
||||
if not CONFIG_PATH.exists():
|
||||
return []
|
||||
try:
|
||||
data = json.loads(CONFIG_PATH.read_text(encoding="utf-8"))
|
||||
return data if isinstance(data, list) else []
|
||||
except Exception:
|
||||
return []
|
||||
|
||||
|
||||
def _save_links(links: List[Dict[str, Any]]) -> None:
|
||||
CONFIG_PATH.parent.mkdir(parents=True, exist_ok=True)
|
||||
CONFIG_PATH.write_text(
|
||||
json.dumps(links, ensure_ascii=False, indent=2),
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
|
||||
@router.get("", response_model=List[PortalLinkRead])
|
||||
async def list_portal_links():
|
||||
"""获取所有快捷门户入口。"""
|
||||
links = _load_links()
|
||||
return [PortalLinkRead(**x) for x in links]
|
||||
|
||||
|
||||
@router.post("", response_model=PortalLinkRead, status_code=status.HTTP_201_CREATED)
|
||||
async def create_portal_link(payload: PortalLinkCreate):
|
||||
"""新增一条快捷门户入口。"""
|
||||
links = _load_links()
|
||||
new_id = str(uuid.uuid4())[:8]
|
||||
new_item = {"id": new_id, "name": payload.name.strip(), "url": payload.url.strip()}
|
||||
links.append(new_item)
|
||||
_save_links(links)
|
||||
return PortalLinkRead(**new_item)
|
||||
|
||||
|
||||
@router.put("/{link_id}", response_model=PortalLinkRead)
|
||||
async def update_portal_link(link_id: str, payload: PortalLinkUpdate):
|
||||
"""更新名称或 URL。"""
|
||||
links = _load_links()
|
||||
for item in links:
|
||||
if item.get("id") == link_id:
|
||||
if payload.name is not None:
|
||||
item["name"] = payload.name.strip()
|
||||
if payload.url is not None:
|
||||
item["url"] = payload.url.strip()
|
||||
_save_links(links)
|
||||
return PortalLinkRead(**item)
|
||||
raise HTTPException(status_code=404, detail="快捷门户入口不存在")
|
||||
|
||||
|
||||
@router.delete("/{link_id}", status_code=status.HTTP_204_NO_CONTENT)
|
||||
async def delete_portal_link(link_id: str):
|
||||
"""删除一条快捷门户入口。"""
|
||||
links = _load_links()
|
||||
new_list = [x for x in links if x.get("id") != link_id]
|
||||
if len(new_list) == len(links):
|
||||
raise HTTPException(status_code=404, detail="快捷门户入口不存在")
|
||||
_save_links(new_list)
|
||||
@@ -1,8 +1,10 @@
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict
|
||||
from typing import Any, Dict, List, Union
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.orm import Session, joinedload
|
||||
|
||||
from backend.app import models
|
||||
from backend.app.db import get_db
|
||||
@@ -11,25 +13,35 @@ from backend.app.schemas import (
|
||||
ContractGenerateResponse,
|
||||
ProjectRead,
|
||||
ProjectUpdate,
|
||||
PushToCloudRequest,
|
||||
PushToCloudResponse,
|
||||
QuoteGenerateResponse,
|
||||
RequirementAnalyzeRequest,
|
||||
RequirementAnalyzeResponse,
|
||||
)
|
||||
from backend.app.services.ai_service import analyze_requirement
|
||||
from backend.app.services.cloud_doc_service import CloudDocManager
|
||||
from backend.app.services.doc_service import (
|
||||
generate_contract_word,
|
||||
generate_quote_excel,
|
||||
generate_quote_pdf_from_data,
|
||||
)
|
||||
from backend.app.routers.cloud_doc_config import get_all_credentials
|
||||
|
||||
|
||||
router = APIRouter(prefix="/projects", tags=["projects"])
|
||||
|
||||
|
||||
def _build_markdown_from_analysis(data: Dict[str, Any]) -> str:
|
||||
def _build_markdown_from_analysis(data: Union[Dict[str, Any], List[Any]]) -> str:
|
||||
"""
|
||||
Convert structured AI analysis JSON into a human-editable Markdown document.
|
||||
Tolerates AI returning a list (e.g. modules only) and normalizes to a dict.
|
||||
"""
|
||||
if isinstance(data, list):
|
||||
data = {"modules": data, "total_estimated_hours": None, "total_amount": None, "notes": None}
|
||||
if not isinstance(data, dict):
|
||||
data = {}
|
||||
|
||||
lines: list[str] = []
|
||||
lines.append("# 项目方案草稿")
|
||||
lines.append("")
|
||||
@@ -48,7 +60,15 @@ def _build_markdown_from_analysis(data: Dict[str, Any]) -> str:
|
||||
if modules:
|
||||
lines.append("## 功能模块与技术实现")
|
||||
for idx, module in enumerate(modules, start=1):
|
||||
name = module.get("name", f"模块 {idx}")
|
||||
if not isinstance(module, dict):
|
||||
# AI sometimes returns strings or other shapes; treat as a single title line
|
||||
raw_name = str(module).strip() if module else ""
|
||||
name = raw_name if len(raw_name) > 1 and raw_name not in (":", "[", "{", "}") else f"模块 {idx}"
|
||||
lines.append(f"### {idx}. {name}")
|
||||
lines.append("")
|
||||
continue
|
||||
raw_name = (module.get("name") or "").strip()
|
||||
name = raw_name if len(raw_name) > 1 and raw_name not in (":", "[", "{", "}") else f"模块 {idx}"
|
||||
desc = module.get("description") or ""
|
||||
tech = module.get("technical_approach") or ""
|
||||
hours = module.get("estimated_hours")
|
||||
@@ -83,13 +103,31 @@ def _build_markdown_from_analysis(data: Dict[str, Any]) -> str:
|
||||
|
||||
|
||||
@router.get("/", response_model=list[ProjectRead])
|
||||
async def list_projects(db: Session = Depends(get_db)):
|
||||
projects = (
|
||||
async def list_projects(
|
||||
customer_tag: str | None = None,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""列表项目;customer_tag 不为空时只返回该客户标签下的项目(按客户 tags 筛选)。"""
|
||||
query = (
|
||||
db.query(models.Project)
|
||||
.options(joinedload(models.Project.customer))
|
||||
.join(models.Customer)
|
||||
.order_by(models.Project.created_at.desc())
|
||||
.all()
|
||||
)
|
||||
return projects
|
||||
if customer_tag and customer_tag.strip():
|
||||
tag = customer_tag.strip()
|
||||
# 客户 tags 逗号分隔,按整词匹配
|
||||
from sqlalchemy import or_
|
||||
t = models.Customer.tags
|
||||
query = query.filter(
|
||||
or_(
|
||||
t == tag,
|
||||
t.ilike(f"{tag},%"),
|
||||
t.ilike(f"%,{tag},%"),
|
||||
t.ilike(f"%,{tag}"),
|
||||
)
|
||||
)
|
||||
return query.all()
|
||||
|
||||
|
||||
@router.get("/{project_id}", response_model=ProjectRead)
|
||||
@@ -109,6 +147,8 @@ async def update_project(
|
||||
project = db.query(models.Project).get(project_id)
|
||||
if not project:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Project not found")
|
||||
if payload.raw_requirement is not None:
|
||||
project.raw_requirement = payload.raw_requirement
|
||||
if payload.ai_solution_md is not None:
|
||||
project.ai_solution_md = payload.ai_solution_md
|
||||
if payload.status is not None:
|
||||
@@ -123,12 +163,24 @@ async def analyze_project_requirement(
|
||||
payload: RequirementAnalyzeRequest,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
logging.getLogger(__name__).info(
|
||||
"收到 AI 解析请求: customer_id=%s, 需求长度=%d",
|
||||
payload.customer_id,
|
||||
len(payload.raw_text or ""),
|
||||
)
|
||||
# Ensure customer exists
|
||||
customer = db.query(models.Customer).get(payload.customer_id)
|
||||
if not customer:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Customer not found")
|
||||
|
||||
analysis = await analyze_requirement(payload.raw_text)
|
||||
try:
|
||||
analysis = await analyze_requirement(payload.raw_text)
|
||||
except RuntimeError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=str(e),
|
||||
) from e
|
||||
|
||||
ai_solution_md = _build_markdown_from_analysis(analysis)
|
||||
|
||||
project = models.Project(
|
||||
@@ -151,6 +203,7 @@ async def analyze_project_requirement(
|
||||
@router.post("/{project_id}/generate_quote", response_model=QuoteGenerateResponse)
|
||||
async def generate_project_quote(
|
||||
project_id: int,
|
||||
template: str | None = None,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
project = db.query(models.Project).get(project_id)
|
||||
@@ -167,7 +220,9 @@ async def generate_project_quote(
|
||||
excel_path = base_dir / f"quote_project_{project.id}.xlsx"
|
||||
pdf_path = base_dir / f"quote_project_{project.id}.pdf"
|
||||
|
||||
template_path = Path("templates/quote_template.xlsx")
|
||||
from backend.app.routers.settings import get_quote_template_path
|
||||
|
||||
template_path = get_quote_template_path(template)
|
||||
if not template_path.exists():
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
@@ -254,3 +309,61 @@ async def generate_project_contract(
|
||||
|
||||
return ContractGenerateResponse(project_id=project.id, contract_path=str(output_path))
|
||||
|
||||
|
||||
@router.post("/{project_id}/push-to-cloud", response_model=PushToCloudResponse)
|
||||
async def push_project_to_cloud(
|
||||
project_id: int,
|
||||
payload: PushToCloudRequest,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
将当前项目方案(Markdown)推送到云文档。若该项目此前已推送过该平台,则更新原文档(增量同步)。
|
||||
"""
|
||||
project = db.query(models.Project).options(joinedload(models.Project.customer)).get(project_id)
|
||||
if not project:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Project not found")
|
||||
platform = (payload.platform or "").strip().lower()
|
||||
if platform not in ("feishu", "yuque", "tencent"):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="platform 须为 feishu / yuque / tencent",
|
||||
)
|
||||
title = (payload.title or "").strip() or f"项目方案 - 项目#{project_id}"
|
||||
body_md = (payload.body_md if payload.body_md is not None else project.ai_solution_md) or ""
|
||||
if not body_md.strip():
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="暂无方案内容,请先在编辑器中填写或保存方案后再推送",
|
||||
)
|
||||
existing = (
|
||||
db.query(models.ProjectCloudDoc)
|
||||
.filter(
|
||||
models.ProjectCloudDoc.project_id == project_id,
|
||||
models.ProjectCloudDoc.platform == platform,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
existing_doc_id = existing.cloud_doc_id if existing else None
|
||||
credentials = get_all_credentials()
|
||||
manager = CloudDocManager(credentials)
|
||||
try:
|
||||
cloud_doc_id, url = await manager.push_markdown(
|
||||
platform, title, body_md, existing_doc_id=existing_doc_id
|
||||
)
|
||||
except RuntimeError as e:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e)) from e
|
||||
if existing:
|
||||
existing.cloud_doc_id = cloud_doc_id
|
||||
existing.cloud_url = url
|
||||
existing.updated_at = datetime.now(timezone.utc)
|
||||
else:
|
||||
record = models.ProjectCloudDoc(
|
||||
project_id=project_id,
|
||||
platform=platform,
|
||||
cloud_doc_id=cloud_doc_id,
|
||||
cloud_url=url,
|
||||
)
|
||||
db.add(record)
|
||||
db.commit()
|
||||
return PushToCloudResponse(url=url, cloud_doc_id=cloud_doc_id)
|
||||
|
||||
|
||||
93
backend/app/routers/settings.py
Normal file
93
backend/app/routers/settings.py
Normal file
@@ -0,0 +1,93 @@
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
|
||||
from fastapi import APIRouter, File, HTTPException, UploadFile, status
|
||||
|
||||
router = APIRouter(prefix="/settings", tags=["settings"])
|
||||
|
||||
TEMPLATES_DIR = Path("data/templates")
|
||||
ALLOWED_EXCEL = {".xlsx", ".xltx"}
|
||||
ALLOWED_WORD = {".docx", ".dotx"}
|
||||
ALLOWED_EXTENSIONS = ALLOWED_EXCEL | ALLOWED_WORD
|
||||
|
||||
# Allowed MIME types when client sends Content-Type (validate if present)
|
||||
ALLOWED_MIME_TYPES = frozenset({
|
||||
"application/vnd.openxmlformats-officedocument.spreadsheetml.sheet", # .xlsx
|
||||
"application/vnd.openxmlformats-officedocument.spreadsheetml.template", # .xltx
|
||||
"application/vnd.openxmlformats-officedocument.wordprocessingml.document", # .docx
|
||||
"application/vnd.openxmlformats-officedocument.wordprocessingml.template", # .dotx
|
||||
})
|
||||
|
||||
|
||||
def _ensure_templates_dir() -> Path:
|
||||
TEMPLATES_DIR.mkdir(parents=True, exist_ok=True)
|
||||
return TEMPLATES_DIR
|
||||
|
||||
|
||||
@router.get("/templates", response_model=List[dict])
|
||||
async def list_templates():
|
||||
"""List uploaded template files (name, type, size, mtime)."""
|
||||
_ensure_templates_dir()
|
||||
out: List[dict] = []
|
||||
for f in sorted(TEMPLATES_DIR.iterdir(), key=lambda p: p.stat().st_mtime, reverse=True):
|
||||
if not f.is_file():
|
||||
continue
|
||||
suf = f.suffix.lower()
|
||||
if suf not in ALLOWED_EXTENSIONS:
|
||||
continue
|
||||
st = f.stat()
|
||||
out.append({
|
||||
"name": f.name,
|
||||
"type": "excel" if suf in ALLOWED_EXCEL else "word",
|
||||
"size": st.st_size,
|
||||
"uploaded_at": st.st_mtime,
|
||||
})
|
||||
return out
|
||||
|
||||
|
||||
@router.post("/templates/upload", status_code=status.HTTP_201_CREATED)
|
||||
async def upload_template(file: UploadFile = File(...)):
|
||||
"""Upload a .xlsx, .xltx, .docx or .dotx template to data/templates/."""
|
||||
suf = Path(file.filename or "").suffix.lower()
|
||||
if suf not in ALLOWED_EXTENSIONS:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Only .xlsx, .xltx, .docx and .dotx files are allowed.",
|
||||
)
|
||||
content_type = (file.content_type or "").strip().split(";")[0].strip().lower()
|
||||
if content_type and content_type not in ALLOWED_MIME_TYPES:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Invalid content type. Allowed: .xlsx, .xltx, .docx, .dotx Office formats.",
|
||||
)
|
||||
dir_path = _ensure_templates_dir()
|
||||
dest = dir_path / (file.filename or "template" + suf)
|
||||
content = await file.read()
|
||||
dest.write_bytes(content)
|
||||
return {"name": dest.name, "path": str(dest)}
|
||||
|
||||
|
||||
def get_latest_excel_template() -> Path | None:
|
||||
"""Return path to the most recently modified .xlsx or .xltx in data/templates, or None."""
|
||||
if not TEMPLATES_DIR.exists():
|
||||
return None
|
||||
excel_files = [
|
||||
f for f in TEMPLATES_DIR.iterdir()
|
||||
if f.is_file() and f.suffix.lower() in ALLOWED_EXCEL
|
||||
]
|
||||
if not excel_files:
|
||||
return None
|
||||
return max(excel_files, key=lambda p: p.stat().st_mtime)
|
||||
|
||||
|
||||
def get_quote_template_path(template_filename: str | None) -> Path:
|
||||
"""Resolve quote template path: optional filename in data/templates or latest excel template or default."""
|
||||
if template_filename:
|
||||
candidate = TEMPLATES_DIR / template_filename
|
||||
if candidate.is_file() and candidate.suffix.lower() in ALLOWED_EXCEL:
|
||||
return candidate
|
||||
latest = get_latest_excel_template()
|
||||
if latest:
|
||||
return latest
|
||||
default = Path("templates/quote_template.xlsx")
|
||||
return default
|
||||
@@ -1,4 +1,4 @@
|
||||
from datetime import datetime
|
||||
from datetime import date, datetime
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
@@ -7,6 +7,7 @@ from pydantic import BaseModel, Field
|
||||
class CustomerBase(BaseModel):
|
||||
name: str = Field(..., description="Customer name")
|
||||
contact_info: Optional[str] = Field(None, description="Contact information")
|
||||
tags: Optional[str] = Field(None, description="Comma-separated tags, e.g. 重点客户,已签约")
|
||||
|
||||
|
||||
class CustomerCreate(CustomerBase):
|
||||
@@ -16,6 +17,7 @@ class CustomerCreate(CustomerBase):
|
||||
class CustomerUpdate(BaseModel):
|
||||
name: Optional[str] = None
|
||||
contact_info: Optional[str] = None
|
||||
tags: Optional[str] = None
|
||||
|
||||
|
||||
class CustomerRead(CustomerBase):
|
||||
@@ -33,12 +35,14 @@ class ProjectRead(BaseModel):
|
||||
ai_solution_md: Optional[str] = None
|
||||
status: str
|
||||
created_at: datetime
|
||||
customer: Optional[CustomerRead] = None
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class ProjectUpdate(BaseModel):
|
||||
raw_requirement: Optional[str] = None
|
||||
ai_solution_md: Optional[str] = None
|
||||
status: Optional[str] = None
|
||||
|
||||
@@ -75,6 +79,17 @@ class ContractGenerateResponse(BaseModel):
|
||||
contract_path: str
|
||||
|
||||
|
||||
class PushToCloudRequest(BaseModel):
|
||||
platform: str = Field(..., description="feishu | yuque | tencent")
|
||||
title: Optional[str] = Field(None, description="文档标题,默认使用「项目方案 - 项目#id」")
|
||||
body_md: Optional[str] = Field(None, description="要推送的 Markdown 内容,不传则使用项目已保存的方案")
|
||||
|
||||
|
||||
class PushToCloudResponse(BaseModel):
|
||||
url: str
|
||||
cloud_doc_id: str
|
||||
|
||||
|
||||
class FinanceSyncResult(BaseModel):
|
||||
id: int
|
||||
month: str
|
||||
@@ -84,5 +99,40 @@ class FinanceSyncResult(BaseModel):
|
||||
|
||||
|
||||
class FinanceSyncResponse(BaseModel):
|
||||
items: List[FinanceSyncResult]
|
||||
status: str = "success"
|
||||
new_files: int = 0
|
||||
details: List[FinanceSyncResult] = Field(default_factory=list)
|
||||
|
||||
|
||||
class FinanceRecordRead(BaseModel):
|
||||
id: int
|
||||
month: str
|
||||
type: str
|
||||
file_name: str
|
||||
file_path: str
|
||||
amount: Optional[float] = None
|
||||
billing_date: Optional[date] = None
|
||||
created_at: datetime
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class FinanceRecordUpdate(BaseModel):
|
||||
amount: Optional[float] = None
|
||||
billing_date: Optional[date] = None
|
||||
|
||||
|
||||
class FinanceUploadResponse(BaseModel):
|
||||
id: int
|
||||
month: str
|
||||
type: str
|
||||
file_name: str
|
||||
file_path: str
|
||||
amount: Optional[float] = None
|
||||
billing_date: Optional[date] = None
|
||||
created_at: datetime
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
@@ -1,37 +1,71 @@
|
||||
import base64
|
||||
import json
|
||||
import os
|
||||
from typing import Any, Dict
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Tuple
|
||||
|
||||
from openai import AsyncOpenAI
|
||||
from openai import NotFoundError as OpenAINotFoundError
|
||||
|
||||
AI_CONFIG_PATH = Path("data/ai_config.json")
|
||||
AI_CONFIGS_PATH = Path("data/ai_configs.json")
|
||||
|
||||
|
||||
_client: AsyncOpenAI | None = None
|
||||
|
||||
|
||||
def get_ai_client() -> AsyncOpenAI:
|
||||
def get_active_ai_config() -> Dict[str, Any]:
|
||||
"""
|
||||
Create (or reuse) a singleton AsyncOpenAI client.
|
||||
|
||||
The client is configured via:
|
||||
- AI_API_KEY / OPENAI_API_KEY
|
||||
- AI_BASE_URL (optional, defaults to official OpenAI endpoint)
|
||||
- AI_MODEL (optional, defaults to gpt-4.1-mini or a similar capable model)
|
||||
从 data/ai_configs.json 读取当前选用的配置;若无则从旧版 ai_config.json 迁移并返回。
|
||||
供 router 与内部调用。
|
||||
"""
|
||||
global _client
|
||||
if _client is not None:
|
||||
return _client
|
||||
defaults = {
|
||||
"id": "",
|
||||
"name": "",
|
||||
"provider": "OpenAI",
|
||||
"api_key": "",
|
||||
"base_url": "",
|
||||
"model_name": "gpt-4o-mini",
|
||||
"temperature": 0.2,
|
||||
"system_prompt_override": "",
|
||||
}
|
||||
if AI_CONFIGS_PATH.exists():
|
||||
try:
|
||||
data = json.loads(AI_CONFIGS_PATH.read_text(encoding="utf-8"))
|
||||
configs = data.get("configs") or []
|
||||
active_id = data.get("active_id") or ""
|
||||
for c in configs:
|
||||
if c.get("id") == active_id:
|
||||
return {**defaults, **c}
|
||||
if configs:
|
||||
return {**defaults, **configs[0]}
|
||||
except Exception:
|
||||
pass
|
||||
# 兼容旧版单文件
|
||||
if AI_CONFIG_PATH.exists():
|
||||
try:
|
||||
data = json.loads(AI_CONFIG_PATH.read_text(encoding="utf-8"))
|
||||
return {**defaults, **data}
|
||||
except Exception:
|
||||
pass
|
||||
if not defaults.get("api_key"):
|
||||
defaults["api_key"] = os.getenv("AI_API_KEY") or os.getenv("OPENAI_API_KEY") or ""
|
||||
if not defaults.get("base_url") and os.getenv("AI_BASE_URL"):
|
||||
defaults["base_url"] = os.getenv("AI_BASE_URL")
|
||||
if defaults.get("model_name") == "gpt-4o-mini" and os.getenv("AI_MODEL"):
|
||||
defaults["model_name"] = os.getenv("AI_MODEL")
|
||||
return defaults
|
||||
|
||||
api_key = os.getenv("AI_API_KEY") or os.getenv("OPENAI_API_KEY")
|
||||
|
||||
def _load_ai_config() -> Dict[str, Any]:
|
||||
"""当前生效的 AI 配置(供需求解析、发票识别等使用)。"""
|
||||
return get_active_ai_config()
|
||||
|
||||
|
||||
def _client_from_config(config: Dict[str, Any]) -> AsyncOpenAI:
|
||||
api_key = (config.get("api_key") or "").strip()
|
||||
if not api_key:
|
||||
raise RuntimeError("AI_API_KEY or OPENAI_API_KEY must be set in environment.")
|
||||
|
||||
base_url = os.getenv("AI_BASE_URL") # can point to OpenAI, DeepSeek, Qwen, etc.
|
||||
|
||||
_client = AsyncOpenAI(
|
||||
api_key=api_key,
|
||||
base_url=base_url or None,
|
||||
)
|
||||
return _client
|
||||
raise RuntimeError("AI API Key 未配置,请在 设置 → AI 模型配置 中填写。")
|
||||
base_url = (config.get("base_url") or "").strip() or None
|
||||
return AsyncOpenAI(api_key=api_key, base_url=base_url)
|
||||
|
||||
|
||||
def _build_requirement_prompt(raw_text: str) -> str:
|
||||
@@ -71,38 +105,139 @@ def _build_requirement_prompt(raw_text: str) -> str:
|
||||
async def analyze_requirement(raw_text: str) -> Dict[str, Any]:
|
||||
"""
|
||||
Call the AI model to analyze customer requirements.
|
||||
|
||||
Returns a Python dict matching the JSON structure described
|
||||
in `_build_requirement_prompt`.
|
||||
Reads config from data/ai_config.json (and env fallback) on every request.
|
||||
"""
|
||||
client = get_ai_client()
|
||||
model = os.getenv("AI_MODEL", "gpt-4.1-mini")
|
||||
import logging
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
config = _load_ai_config()
|
||||
client = _client_from_config(config)
|
||||
model = config.get("model_name") or "gpt-4o-mini"
|
||||
temperature = float(config.get("temperature", 0.2))
|
||||
system_override = (config.get("system_prompt_override") or "").strip()
|
||||
|
||||
logger.info("AI 需求解析: 调用模型 %s,输入长度 %d 字符", model, len(raw_text))
|
||||
|
||||
prompt = _build_requirement_prompt(raw_text)
|
||||
|
||||
completion = await client.chat.completions.create(
|
||||
model=model,
|
||||
response_format={"type": "json_object"},
|
||||
messages=[
|
||||
{
|
||||
"role": "system",
|
||||
"content": (
|
||||
"你是一名严谨的系统架构师,只能输出有效的 JSON,不要输出任何解释文字。"
|
||||
),
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": prompt,
|
||||
},
|
||||
],
|
||||
temperature=0.2,
|
||||
system_content = (
|
||||
system_override
|
||||
if system_override
|
||||
else "你是一名严谨的系统架构师,只能输出有效的 JSON,不要输出任何解释文字。"
|
||||
)
|
||||
|
||||
try:
|
||||
completion = await client.chat.completions.create(
|
||||
model=model,
|
||||
response_format={"type": "json_object"},
|
||||
messages=[
|
||||
{"role": "system", "content": system_content},
|
||||
{"role": "user", "content": prompt},
|
||||
],
|
||||
temperature=temperature,
|
||||
)
|
||||
except OpenAINotFoundError as e:
|
||||
raise RuntimeError(
|
||||
"当前配置的模型不存在或无权访问。请在 设置 → AI 模型配置 中确认「模型名称」与当前提供商一致(如阿里云使用 qwen 系列、OpenAI 使用 gpt-4o-mini 等)。"
|
||||
) from e
|
||||
|
||||
content = completion.choices[0].message.content or "{}"
|
||||
try:
|
||||
data: Dict[str, Any] = json.loads(content)
|
||||
data: Any = json.loads(content)
|
||||
except json.JSONDecodeError as exc:
|
||||
logger.error("AI 返回非 JSON,片段: %s", (content or "")[:200])
|
||||
raise RuntimeError(f"AI 返回的内容不是合法 JSON:{content}") from exc
|
||||
|
||||
# Some models return a list (e.g. modules only); normalize to expected dict shape
|
||||
if isinstance(data, list):
|
||||
data = {
|
||||
"modules": data,
|
||||
"total_estimated_hours": None,
|
||||
"total_amount": None,
|
||||
"notes": None,
|
||||
}
|
||||
if not isinstance(data, dict):
|
||||
data = {}
|
||||
|
||||
mods = data.get("modules") or []
|
||||
logger.info("AI 需求解析完成: 模块数 %d", len(mods) if isinstance(mods, list) else 0)
|
||||
return data
|
||||
|
||||
|
||||
async def test_connection() -> str:
|
||||
"""使用当前选用配置测试连接。"""
|
||||
return await test_connection_with_config(get_active_ai_config())
|
||||
|
||||
|
||||
async def test_connection_with_config(config: Dict[str, Any]) -> str:
|
||||
"""
|
||||
使用指定配置发送简单补全以验证 API Key 与 Base URL。
|
||||
供测试当前配置或指定 config_id 时使用。
|
||||
"""
|
||||
client = _client_from_config(config)
|
||||
model = config.get("model_name") or "gpt-4o-mini"
|
||||
try:
|
||||
completion = await client.chat.completions.create(
|
||||
model=model,
|
||||
messages=[{"role": "user", "content": "Hello"}],
|
||||
max_tokens=50,
|
||||
)
|
||||
except OpenAINotFoundError as e:
|
||||
raise RuntimeError(
|
||||
"当前配置的模型不存在或无权访问。请在 设置 → AI 模型配置 中确认「模型名称」(如阿里云使用 qwen 系列)。"
|
||||
) from e
|
||||
return (completion.choices[0].message.content or "").strip() or "OK"
|
||||
|
||||
|
||||
async def extract_invoice_metadata(image_bytes: bytes, mime: str = "image/jpeg") -> Tuple[float | None, str | None]:
|
||||
"""
|
||||
Use AI vision to extract total amount and invoice date from an image.
|
||||
Returns (amount, date_yyyy_mm_dd). On any error or unsupported model, returns (None, None).
|
||||
"""
|
||||
config = _load_ai_config()
|
||||
api_key = (config.get("api_key") or "").strip()
|
||||
if not api_key:
|
||||
return (None, None)
|
||||
try:
|
||||
client = _client_from_config(config)
|
||||
model = config.get("model_name") or "gpt-4o-mini"
|
||||
b64 = base64.b64encode(image_bytes).decode("ascii")
|
||||
data_url = f"data:{mime};base64,{b64}"
|
||||
prompt = (
|
||||
"从这张发票/收据图片中识别并提取:1) 价税合计/总金额(数字,不含货币符号);2) 开票日期(格式 YYYY-MM-DD)。"
|
||||
"只返回 JSON,不要其他文字,格式:{\"amount\": 数字或null, \"date\": \"YYYY-MM-DD\" 或 null}。"
|
||||
)
|
||||
completion = await client.chat.completions.create(
|
||||
model=model,
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": prompt},
|
||||
{"type": "image_url", "image_url": {"url": data_url}},
|
||||
],
|
||||
}
|
||||
],
|
||||
max_tokens=150,
|
||||
)
|
||||
content = (completion.choices[0].message.content or "").strip()
|
||||
if not content:
|
||||
return (None, None)
|
||||
# Handle markdown code block
|
||||
if "```" in content:
|
||||
content = re.sub(r"^.*?```(?:json)?\s*", "", content).strip()
|
||||
content = re.sub(r"\s*```.*$", "", content).strip()
|
||||
data = json.loads(content)
|
||||
amount_raw = data.get("amount")
|
||||
date_raw = data.get("date")
|
||||
amount = None
|
||||
if amount_raw is not None:
|
||||
try:
|
||||
amount = float(amount_raw)
|
||||
except (TypeError, ValueError):
|
||||
pass
|
||||
date_str = None
|
||||
if isinstance(date_raw, str) and re.match(r"\d{4}-\d{2}-\d{2}", date_raw):
|
||||
date_str = date_raw[:10]
|
||||
return (amount, date_str)
|
||||
except Exception:
|
||||
return (None, None)
|
||||
|
||||
315
backend/app/services/cloud_doc_service.py
Normal file
315
backend/app/services/cloud_doc_service.py
Normal file
@@ -0,0 +1,315 @@
|
||||
"""
|
||||
云文档集成:飞书、语雀、腾讯文档的文档创建/更新。
|
||||
统一以 Markdown 为中间格式,由各平台 API 写入。
|
||||
|
||||
扩展建议:可增加「月度财务明细表」自动导出——每月在飞书/腾讯文档生成表格,
|
||||
插入当月发票等附件预览链接,供财务查看(需对接财务记录与附件列表)。
|
||||
"""
|
||||
from typing import Any, Dict, List, Tuple
|
||||
|
||||
import httpx
|
||||
|
||||
FEISHU_BASE = "https://open.feishu.cn"
|
||||
YUQUE_BASE = "https://www.yuque.com/api/v2"
|
||||
|
||||
|
||||
async def get_feishu_tenant_token(app_id: str, app_secret: str) -> str:
|
||||
"""获取飞书 tenant_access_token。"""
|
||||
async with httpx.AsyncClient() as client:
|
||||
r = await client.post(
|
||||
f"{FEISHU_BASE}/open-apis/auth/v3/tenant_access_token/internal",
|
||||
json={"app_id": app_id, "app_secret": app_secret},
|
||||
timeout=10.0,
|
||||
)
|
||||
r.raise_for_status()
|
||||
data = r.json()
|
||||
if data.get("code") != 0:
|
||||
raise RuntimeError(data.get("msg", "飞书鉴权失败"))
|
||||
return data["tenant_access_token"]
|
||||
|
||||
|
||||
def _feishu_text_block_elements(md: str) -> List[Dict[str, Any]]:
|
||||
"""将 Markdown 转为飞书文本块 elements(按行拆成 textRun,简单实现)。"""
|
||||
elements: List[Dict[str, Any]] = []
|
||||
for line in md.split("\n"):
|
||||
line = line.rstrip()
|
||||
if not line:
|
||||
elements.append({"type": "textRun", "text_run": {"text": "\n"}})
|
||||
else:
|
||||
elements.append({"type": "textRun", "text_run": {"text": line + "\n"}})
|
||||
if not elements:
|
||||
elements.append({"type": "textRun", "text_run": {"text": " "}})
|
||||
return elements
|
||||
|
||||
|
||||
async def feishu_create_doc(
|
||||
token: str, title: str, body_md: str, folder_token: str = ""
|
||||
) -> Tuple[str, str]:
|
||||
"""
|
||||
创建飞书文档并写入内容。返回 (document_id, url)。
|
||||
使用 docx/v1:创建文档后向根块下添加子块写入 Markdown 文本。
|
||||
"""
|
||||
async with httpx.AsyncClient() as client:
|
||||
headers = {"Authorization": f"Bearer {token}"}
|
||||
# 1. 创建文档
|
||||
create_body: Dict[str, Any] = {"title": title[:50] or "未命名文档"}
|
||||
if folder_token:
|
||||
create_body["folder_token"] = folder_token
|
||||
r = await client.post(
|
||||
f"{FEISHU_BASE}/open-apis/docx/v1/documents",
|
||||
headers=headers,
|
||||
json=create_body,
|
||||
timeout=15.0,
|
||||
)
|
||||
r.raise_for_status()
|
||||
data = r.json()
|
||||
if data.get("code") != 0:
|
||||
raise RuntimeError(data.get("msg", "飞书创建文档失败"))
|
||||
doc = data.get("data", {})
|
||||
document_id = doc.get("document", {}).get("document_id")
|
||||
if not document_id:
|
||||
raise RuntimeError("飞书未返回 document_id")
|
||||
url = doc.get("document", {}).get("url", "")
|
||||
# 2. 根块 ID 即 document_id(飞书约定)
|
||||
block_id = document_id
|
||||
# 3. 添加子块(内容)
|
||||
elements = _feishu_text_block_elements(body_md)
|
||||
# 单块有长度限制,分批写入多块
|
||||
chunk_size = 3000
|
||||
for i in range(0, len(elements), chunk_size):
|
||||
chunk = elements[i : i + chunk_size]
|
||||
body_json = {"children": [{"block_type": "text", "text": {"elements": chunk}}], "index": -1}
|
||||
r3 = await client.post(
|
||||
f"{FEISHU_BASE}/open-apis/docx/v1/documents/{document_id}/blocks/{block_id}/children",
|
||||
headers=headers,
|
||||
json=body_json,
|
||||
timeout=15.0,
|
||||
)
|
||||
r3.raise_for_status()
|
||||
res = r3.json()
|
||||
if res.get("code") != 0:
|
||||
raise RuntimeError(res.get("msg", "飞书写入块失败"))
|
||||
# 下一批挂在刚创建的块下
|
||||
new_items = res.get("data", {}).get("children", [])
|
||||
if new_items:
|
||||
block_id = new_items[0].get("block_id", block_id)
|
||||
return document_id, url or f"https://feishu.cn/docx/{document_id}"
|
||||
|
||||
|
||||
async def feishu_update_doc(token: str, document_id: str, body_md: str) -> str:
|
||||
"""
|
||||
更新飞书文档内容:获取现有块并批量更新首个文本块,或追加新块。
|
||||
返回文档 URL。
|
||||
"""
|
||||
async with httpx.AsyncClient() as client:
|
||||
headers = {"Authorization": f"Bearer {token}"}
|
||||
r = await client.get(
|
||||
f"{FEISHU_BASE}/open-apis/docx/v1/documents/{document_id}/blocks",
|
||||
headers=headers,
|
||||
params={"document_id": document_id},
|
||||
timeout=10.0,
|
||||
)
|
||||
r.raise_for_status()
|
||||
data = r.json()
|
||||
if data.get("code") != 0:
|
||||
raise RuntimeError(data.get("msg", "飞书获取块失败"))
|
||||
items = data.get("data", {}).get("items", [])
|
||||
elements = _feishu_text_block_elements(body_md)
|
||||
if items:
|
||||
first_id = items[0].get("block_id")
|
||||
if first_id:
|
||||
# 批量更新:只更新第一个块的内容
|
||||
update_body = {
|
||||
"requests": [
|
||||
{
|
||||
"request_type": "blockUpdate",
|
||||
"block_id": first_id,
|
||||
"update_text": {"elements": elements},
|
||||
}
|
||||
]
|
||||
}
|
||||
r2 = await client.patch(
|
||||
f"{FEISHU_BASE}/open-apis/docx/v1/documents/{document_id}/blocks/batch_update",
|
||||
headers=headers,
|
||||
json=update_body,
|
||||
timeout=15.0,
|
||||
)
|
||||
r2.raise_for_status()
|
||||
if r2.json().get("code") != 0:
|
||||
# 若 PATCH 不支持该块类型,则追加新块
|
||||
pass
|
||||
else:
|
||||
return f"https://feishu.cn/docx/{document_id}"
|
||||
# 无块或更新失败:在根下追加子块
|
||||
block_id = document_id
|
||||
for i in range(0, len(elements), 3000):
|
||||
chunk = elements[i : i + 3000]
|
||||
body_json = {"children": [{"block_type": "text", "text": {"elements": chunk}}], "index": -1}
|
||||
r3 = await client.post(
|
||||
f"{FEISHU_BASE}/open-apis/docx/v1/documents/{document_id}/blocks/{block_id}/children",
|
||||
headers=headers,
|
||||
json=body_json,
|
||||
timeout=15.0,
|
||||
)
|
||||
r3.raise_for_status()
|
||||
res = r3.json()
|
||||
if res.get("data", {}).get("children"):
|
||||
block_id = res["data"]["children"][0].get("block_id", block_id)
|
||||
return f"https://feishu.cn/docx/{document_id}"
|
||||
|
||||
|
||||
# --------------- 语雀 ---------------
|
||||
|
||||
|
||||
async def yuque_create_doc(
|
||||
token: str, repo_id_or_namespace: str, title: str, body_md: str
|
||||
) -> Tuple[str, str]:
|
||||
"""
|
||||
在语雀知识库创建文档。repo_id_or_namespace 可为 repo_id 或 namespace(如 user/repo)。
|
||||
返回 (doc_id, url)。
|
||||
"""
|
||||
async with httpx.AsyncClient() as client:
|
||||
headers = {
|
||||
"X-Auth-Token": token,
|
||||
"Content-Type": "application/json",
|
||||
"User-Agent": "OpsCore-CloudDoc/1.0",
|
||||
}
|
||||
# 若为 namespace 需先解析为 repo_id(语雀 API 创建文档用 repo_id)
|
||||
repo_id = repo_id_or_namespace
|
||||
if "/" in repo_id_or_namespace:
|
||||
r_repo = await client.get(
|
||||
f"{YUQUE_BASE}/repos/{repo_id_or_namespace}",
|
||||
headers=headers,
|
||||
timeout=10.0,
|
||||
)
|
||||
if r_repo.status_code == 200 and r_repo.json().get("data"):
|
||||
repo_id = str(r_repo.json()["data"]["id"])
|
||||
r = await client.post(
|
||||
f"{YUQUE_BASE}/repos/{repo_id}/docs",
|
||||
headers=headers,
|
||||
json={
|
||||
"title": title[:100] or "未命名",
|
||||
"body": body_md,
|
||||
"format": "markdown",
|
||||
},
|
||||
timeout=15.0,
|
||||
)
|
||||
r.raise_for_status()
|
||||
data = r.json()
|
||||
doc = data.get("data", {})
|
||||
doc_id = str(doc.get("id", ""))
|
||||
url = doc.get("url", "")
|
||||
if not url and doc.get("slug"):
|
||||
url = f"https://www.yuque.com/{doc.get('namespace', '').replace('/', '/')}/{doc.get('slug', '')}"
|
||||
return doc_id, url or ""
|
||||
|
||||
|
||||
async def yuque_update_doc(
|
||||
token: str, repo_id_or_namespace: str, doc_id: str, title: str, body_md: str
|
||||
) -> str:
|
||||
"""更新语雀文档。返回文档 URL。"""
|
||||
async with httpx.AsyncClient() as client:
|
||||
headers = {
|
||||
"X-Auth-Token": token,
|
||||
"Content-Type": "application/json",
|
||||
"User-Agent": "OpsCore-CloudDoc/1.0",
|
||||
}
|
||||
r = await client.put(
|
||||
f"{YUQUE_BASE}/repos/{repo_id_or_namespace}/docs/{doc_id}",
|
||||
headers=headers,
|
||||
json={
|
||||
"title": title[:100] or "未命名",
|
||||
"body": body_md,
|
||||
"format": "markdown",
|
||||
},
|
||||
timeout=15.0,
|
||||
)
|
||||
r.raise_for_status()
|
||||
data = r.json()
|
||||
doc = data.get("data", {})
|
||||
return doc.get("url", "") or f"https://www.yuque.com/docs/{doc_id}"
|
||||
|
||||
|
||||
async def yuque_list_docs(token: str, repo_id_or_namespace: str) -> List[Dict[str, Any]]:
|
||||
"""获取知识库文档列表。"""
|
||||
async with httpx.AsyncClient() as client:
|
||||
headers = {
|
||||
"X-Auth-Token": token,
|
||||
"User-Agent": "OpsCore-CloudDoc/1.0",
|
||||
}
|
||||
r = await client.get(
|
||||
f"{YUQUE_BASE}/repos/{repo_id_or_namespace}/docs",
|
||||
headers=headers,
|
||||
timeout=10.0,
|
||||
)
|
||||
r.raise_for_status()
|
||||
data = r.json()
|
||||
return data.get("data", [])
|
||||
|
||||
|
||||
# --------------- 腾讯文档(占位) ---------------
|
||||
|
||||
|
||||
async def tencent_create_doc(client_id: str, client_secret: str, title: str, body_md: str) -> Tuple[str, str]:
|
||||
"""
|
||||
腾讯文档需 OAuth 用户授权与文件创建 API,此处返回占位。
|
||||
正式接入需在腾讯开放平台创建应用并走 OAuth 流程。
|
||||
"""
|
||||
raise RuntimeError(
|
||||
"腾讯文档 Open API 需在开放平台配置 OAuth 并获取用户授权;当前版本请先用飞书或语雀推送。"
|
||||
)
|
||||
|
||||
|
||||
# --------------- 统一入口 ---------------
|
||||
|
||||
|
||||
class CloudDocManager:
|
||||
"""统一封装:读取配置并执行创建/更新,支持增量(有 cloud_doc_id 则更新)。"""
|
||||
|
||||
def __init__(self, credentials: Dict[str, Dict[str, str]]):
|
||||
self.credentials = credentials
|
||||
|
||||
async def push_markdown(
|
||||
self,
|
||||
platform: str,
|
||||
title: str,
|
||||
body_md: str,
|
||||
existing_doc_id: str | None = None,
|
||||
extra: Dict[str, str] | None = None,
|
||||
) -> Tuple[str, str]:
|
||||
"""
|
||||
将 Markdown 推送到指定平台。若 existing_doc_id 存在则更新,否则创建。
|
||||
返回 (cloud_doc_id, url)。
|
||||
extra: 平台相关参数,如 yuque 的 default_repo。
|
||||
"""
|
||||
extra = extra or {}
|
||||
if platform == "feishu":
|
||||
cred = self.credentials.get("feishu") or {}
|
||||
app_id = (cred.get("app_id") or "").strip()
|
||||
app_secret = (cred.get("app_secret") or "").strip()
|
||||
if not app_id or not app_secret:
|
||||
raise RuntimeError("请先在设置中配置飞书 App ID 与 App Secret")
|
||||
token = await get_feishu_tenant_token(app_id, app_secret)
|
||||
if existing_doc_id:
|
||||
url = await feishu_update_doc(token, existing_doc_id, body_md)
|
||||
return existing_doc_id, url
|
||||
return await feishu_create_doc(token, title, body_md)
|
||||
|
||||
if platform == "yuque":
|
||||
cred = self.credentials.get("yuque") or {}
|
||||
token = (cred.get("token") or "").strip()
|
||||
default_repo = (cred.get("default_repo") or extra.get("repo") or "").strip()
|
||||
if not token:
|
||||
raise RuntimeError("请先在设置中配置语雀 Personal Access Token")
|
||||
if not default_repo:
|
||||
raise RuntimeError("请先在设置中配置语雀默认知识库(namespace,如 user/repo)")
|
||||
if existing_doc_id:
|
||||
url = await yuque_update_doc(token, default_repo, existing_doc_id, title, body_md)
|
||||
return existing_doc_id, url
|
||||
return await yuque_create_doc(token, default_repo, title, body_md)
|
||||
|
||||
if platform == "tencent":
|
||||
await tencent_create_doc("", "", title, body_md)
|
||||
return "", ""
|
||||
|
||||
raise RuntimeError(f"不支持的平台: {platform}")
|
||||
@@ -44,7 +44,12 @@ async def generate_quote_excel(
|
||||
# Assume the first worksheet is used for the quote.
|
||||
ws = wb.active
|
||||
|
||||
modules: List[Dict[str, Any]] = project_data.get("modules", [])
|
||||
raw_modules: List[Any] = project_data.get("modules", [])
|
||||
# Normalize: only dicts have .get(); coerce others to a minimal dict
|
||||
modules: List[Dict[str, Any]] = [
|
||||
m if isinstance(m, dict) else {"name": str(m) or f"模块 {i}"}
|
||||
for i, m in enumerate(raw_modules, start=1)
|
||||
]
|
||||
total_amount = project_data.get("total_amount")
|
||||
total_hours = project_data.get("total_estimated_hours")
|
||||
notes = project_data.get("notes")
|
||||
@@ -157,7 +162,11 @@ async def generate_quote_pdf_from_data(
|
||||
|
||||
c.setFont("Helvetica", 10)
|
||||
|
||||
modules: List[Dict[str, Any]] = project_data.get("modules", [])
|
||||
raw_modules: List[Any] = project_data.get("modules", [])
|
||||
modules = [
|
||||
m if isinstance(m, dict) else {"name": str(m) or f"模块 {i}"}
|
||||
for i, m in enumerate(raw_modules, start=1)
|
||||
]
|
||||
for idx, module in enumerate(modules, start=1):
|
||||
name = module.get("name", "")
|
||||
hours = module.get("estimated_hours", "")
|
||||
|
||||
@@ -1,17 +1,34 @@
|
||||
import asyncio
|
||||
import email
|
||||
import hashlib
|
||||
import imaplib
|
||||
import logging
|
||||
import os
|
||||
from datetime import datetime
|
||||
import re
|
||||
import sqlite3
|
||||
import ssl
|
||||
from datetime import date, datetime
|
||||
from email.header import decode_header
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Tuple
|
||||
|
||||
# Ensure IMAP ID command is recognised by imaplib so we can spoof a
|
||||
# desktop mail client (Foxmail/Outlook) for providers like NetEase/163.
|
||||
imaplib.Commands["ID"] = ("NONAUTH", "AUTH", "SELECTED")
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from backend.app.db import SessionLocal
|
||||
from backend.app.models import FinanceRecord
|
||||
|
||||
|
||||
FINANCE_BASE_DIR = Path("data/finance")
|
||||
SYNC_DB_PATH = Path("data/finance/sync_history.db")
|
||||
|
||||
# Folder names for classification (invoices, receipts, statements)
|
||||
INVOICES_DIR = "invoices"
|
||||
RECEIPTS_DIR = "receipts"
|
||||
STATEMENTS_DIR = "statements"
|
||||
|
||||
|
||||
def _decode_header_value(value: str | None) -> str:
|
||||
@@ -27,17 +44,21 @@ def _decode_header_value(value: str | None) -> str:
|
||||
return decoded
|
||||
|
||||
|
||||
def _classify_type(subject: str) -> str:
|
||||
def _classify_type(subject: str, filename: str) -> str:
|
||||
"""
|
||||
Classify finance document type based on subject keywords.
|
||||
Classify finance document type. Returns: invoices, receipts, statements, others.
|
||||
Maps to folders: invoices/, receipts/, statements/.
|
||||
"""
|
||||
subject_lower = subject.lower()
|
||||
text = f"{subject} {filename}".lower()
|
||||
# 发票 / 开票类
|
||||
if any(k in subject for k in ["发票", "开票", "票据", "invoice"]):
|
||||
if any(k in text for k in ["发票", "开票", "票据", "invoice", "fapiao"]):
|
||||
return "invoices"
|
||||
# 回执
|
||||
if any(k in text for k in ["回执", "签收单", "receipt"]):
|
||||
return "receipts"
|
||||
# 银行流水 / 账户明细 / 对公活期等
|
||||
if any(
|
||||
k in subject
|
||||
k in text
|
||||
for k in [
|
||||
"流水",
|
||||
"活期",
|
||||
@@ -50,9 +71,7 @@ def _classify_type(subject: str) -> str:
|
||||
"statement",
|
||||
]
|
||||
):
|
||||
return "bank_records"
|
||||
if any(k in subject for k in ["回执", "receipt"]):
|
||||
return "receipts"
|
||||
return "statements"
|
||||
return "others"
|
||||
|
||||
|
||||
@@ -71,132 +90,474 @@ def _parse_email_date(msg: email.message.Message) -> datetime:
|
||||
return dt
|
||||
|
||||
|
||||
def _run_invoice_ocr_sync(file_path: str, mime: str, raw_bytes: bytes) -> Tuple[float | None, str | None]:
|
||||
"""Run extract_invoice_metadata from a sync context (new event loop). Handles PDF via first page image."""
|
||||
from backend.app.services.ai_service import extract_invoice_metadata
|
||||
from backend.app.services.invoice_upload import _pdf_first_page_to_image
|
||||
|
||||
if "pdf" in (mime or "").lower() or Path(file_path).suffix.lower() == ".pdf":
|
||||
img_result = _pdf_first_page_to_image(raw_bytes)
|
||||
if img_result:
|
||||
image_bytes, img_mime = img_result
|
||||
raw_bytes, mime = image_bytes, img_mime
|
||||
# else keep raw_bytes and try anyway (may fail)
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
try:
|
||||
return loop.run_until_complete(extract_invoice_metadata(raw_bytes, mime))
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
|
||||
def _rename_invoice_file(
|
||||
file_path: str,
|
||||
amount: float | None,
|
||||
billing_date: date | None,
|
||||
) -> Tuple[str, str]:
|
||||
"""
|
||||
Rename invoice file to YYYYMMDD_金额_原文件名.
|
||||
Returns (new_file_name, new_file_path).
|
||||
"""
|
||||
path = Path(file_path)
|
||||
if not path.exists():
|
||||
return (path.name, file_path)
|
||||
date_str = (billing_date or date.today()).strftime("%Y%m%d")
|
||||
amount_str = f"{amount:.2f}" if amount is not None else "0.00"
|
||||
# Sanitize original name: take stem, limit length
|
||||
orig_stem = path.stem[: 80] if len(path.stem) > 80 else path.stem
|
||||
suffix = path.suffix
|
||||
new_name = f"{date_str}_{amount_str}_{orig_stem}{suffix}"
|
||||
new_path = path.parent / new_name
|
||||
counter = 1
|
||||
while new_path.exists():
|
||||
new_path = path.parent / f"{date_str}_{amount_str}_{orig_stem}_{counter}{suffix}"
|
||||
counter += 1
|
||||
path.rename(new_path)
|
||||
return (new_path.name, str(new_path))
|
||||
|
||||
|
||||
def _ensure_sync_history_table(conn: sqlite3.Connection) -> None:
|
||||
conn.execute(
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS attachment_history (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
message_id TEXT,
|
||||
file_hash TEXT NOT NULL,
|
||||
month TEXT,
|
||||
doc_type TEXT,
|
||||
file_name TEXT,
|
||||
file_path TEXT,
|
||||
created_at TEXT DEFAULT CURRENT_TIMESTAMP,
|
||||
UNIQUE(message_id, file_hash)
|
||||
)
|
||||
"""
|
||||
)
|
||||
conn.commit()
|
||||
|
||||
|
||||
def _has_sync_history() -> bool:
|
||||
"""是否有过同步记录;无记录视为首次同步,需拉全量;有记录则只拉增量(UNSEEN)。"""
|
||||
if not SYNC_DB_PATH.exists():
|
||||
return False
|
||||
try:
|
||||
conn = sqlite3.connect(SYNC_DB_PATH)
|
||||
try:
|
||||
cur = conn.execute("SELECT 1 FROM attachment_history LIMIT 1")
|
||||
return cur.fetchone() is not None
|
||||
finally:
|
||||
conn.close()
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
def _save_attachment(
|
||||
msg: email.message.Message,
|
||||
month_str: str,
|
||||
doc_type: str,
|
||||
) -> List[Tuple[str, str]]:
|
||||
) -> List[Tuple[str, str, str, bytes, str]]:
|
||||
"""
|
||||
Save PDF/image attachments and return list of (file_name, file_path).
|
||||
Save PDF/image attachments.
|
||||
Returns list of (file_name, file_path, mime, raw_bytes, doc_type).
|
||||
raw_bytes kept for invoice OCR when doc_type == invoices.
|
||||
|
||||
同时使用 data/finance/sync_history.db 做增量去重:
|
||||
- 以 (message_id, MD5(content)) 为唯一键,避免重复保存相同附件。
|
||||
"""
|
||||
saved: List[Tuple[str, str]] = []
|
||||
base_dir = _ensure_month_dir(month_str, doc_type)
|
||||
saved: List[Tuple[str, str, str, bytes, str]] = []
|
||||
|
||||
for part in msg.walk():
|
||||
content_disposition = part.get("Content-Disposition", "")
|
||||
if "attachment" not in content_disposition:
|
||||
continue
|
||||
msg_id = msg.get("Message-ID") or ""
|
||||
subject = _decode_header_value(msg.get("Subject"))
|
||||
|
||||
filename = part.get_filename()
|
||||
filename = _decode_header_value(filename)
|
||||
if not filename:
|
||||
continue
|
||||
SYNC_DB_PATH.parent.mkdir(parents=True, exist_ok=True)
|
||||
conn = sqlite3.connect(SYNC_DB_PATH)
|
||||
try:
|
||||
_ensure_sync_history_table(conn)
|
||||
|
||||
content_type = part.get_content_type()
|
||||
maintype = part.get_content_maintype()
|
||||
for part in msg.walk():
|
||||
content_disposition = part.get("Content-Disposition", "")
|
||||
if "attachment" not in content_disposition:
|
||||
continue
|
||||
|
||||
# Accept pdf and common images
|
||||
if maintype not in ("application", "image"):
|
||||
continue
|
||||
filename = part.get_filename()
|
||||
filename = _decode_header_value(filename)
|
||||
if not filename:
|
||||
continue
|
||||
|
||||
data = part.get_payload(decode=True)
|
||||
if not data:
|
||||
continue
|
||||
ext = Path(filename).suffix.lower()
|
||||
if ext not in (".pdf", ".jpg", ".jpeg", ".png", ".xlsx"):
|
||||
continue
|
||||
|
||||
file_path = base_dir / filename
|
||||
# Ensure unique filename
|
||||
counter = 1
|
||||
while file_path.exists():
|
||||
stem = file_path.stem
|
||||
suffix = file_path.suffix
|
||||
file_path = base_dir / f"{stem}_{counter}{suffix}"
|
||||
counter += 1
|
||||
maintype = part.get_content_maintype()
|
||||
if maintype not in ("application", "image"):
|
||||
continue
|
||||
|
||||
with open(file_path, "wb") as f:
|
||||
f.write(data)
|
||||
data = part.get_payload(decode=True)
|
||||
if not data:
|
||||
continue
|
||||
|
||||
saved.append((filename, str(file_path)))
|
||||
# 分类:基于主题 + 文件名
|
||||
doc_type = _classify_type(subject, filename)
|
||||
base_dir = _ensure_month_dir(month_str, doc_type)
|
||||
|
||||
# 增量去重:根据 (message_id, md5) 判断是否已同步过
|
||||
file_hash = hashlib.md5(data).hexdigest() # nosec - content hash only
|
||||
cur = conn.execute(
|
||||
"SELECT 1 FROM attachment_history WHERE message_id = ? AND file_hash = ?",
|
||||
(msg_id, file_hash),
|
||||
)
|
||||
if cur.fetchone():
|
||||
continue
|
||||
|
||||
mime = part.get_content_type() or "application/octet-stream"
|
||||
file_path = base_dir / filename
|
||||
counter = 1
|
||||
while file_path.exists():
|
||||
stem, suffix = file_path.stem, file_path.suffix
|
||||
file_path = base_dir / f"{stem}_{counter}{suffix}"
|
||||
counter += 1
|
||||
|
||||
file_path.write_bytes(data)
|
||||
|
||||
conn.execute(
|
||||
"""
|
||||
INSERT OR IGNORE INTO attachment_history
|
||||
(message_id, file_hash, month, doc_type, file_name, file_path)
|
||||
VALUES (?, ?, ?, ?, ?, ?)
|
||||
""",
|
||||
(msg_id, file_hash, month_str, doc_type, file_path.name, str(file_path)),
|
||||
)
|
||||
|
||||
saved.append((file_path.name, str(file_path), mime, data, doc_type))
|
||||
|
||||
finally:
|
||||
conn.commit()
|
||||
conn.close()
|
||||
|
||||
return saved
|
||||
|
||||
|
||||
def _decode_imap_utf7(s: str | bytes) -> str:
|
||||
"""Decode IMAP4 UTF-7 mailbox name (RFC 3501). Returns decoded string."""
|
||||
if isinstance(s, bytes):
|
||||
s = s.decode("ascii", errors="replace")
|
||||
if "&" not in s:
|
||||
return s
|
||||
parts = s.split("&")
|
||||
out = [parts[0]]
|
||||
for i in range(1, len(parts)):
|
||||
chunk = parts[i]
|
||||
if "-" in chunk:
|
||||
u, rest = chunk.split("-", 1)
|
||||
if u == "":
|
||||
out.append("&")
|
||||
else:
|
||||
try:
|
||||
# IMAP UTF-7: &BASE64- where BASE64 is modified (,+ instead of /,=)
|
||||
pad = (4 - len(u) % 4) % 4
|
||||
b = (u + "=" * pad).translate(str.maketrans(",+", "/="))
|
||||
decoded = __import__("base64").b64decode(b).decode("utf-16-be")
|
||||
out.append(decoded)
|
||||
except Exception:
|
||||
out.append("&" + chunk)
|
||||
out.append(rest)
|
||||
else:
|
||||
out.append("&" + chunk)
|
||||
return "".join(out)
|
||||
|
||||
|
||||
def _parse_list_response(data: List[bytes]) -> List[Tuple[str, str]]:
|
||||
"""Parse imap.list() response to [(raw_name, decoded_name), ...]. Format: (flags) \"delim\" \"mailbox\"."""
|
||||
import shlex
|
||||
result: List[Tuple[str, str]] = []
|
||||
for line in data:
|
||||
if not isinstance(line, bytes):
|
||||
continue
|
||||
try:
|
||||
line_str = line.decode("ascii", errors="replace")
|
||||
except Exception:
|
||||
continue
|
||||
try:
|
||||
parts = shlex.split(line_str)
|
||||
except ValueError:
|
||||
continue
|
||||
if not parts:
|
||||
continue
|
||||
# Mailbox name is the last part (RFC 3501 LIST: (attrs) delim name)
|
||||
raw = parts[-1]
|
||||
decoded = _decode_imap_utf7(raw)
|
||||
result.append((raw, decoded))
|
||||
return result
|
||||
|
||||
|
||||
def _list_mailboxes(imap: imaplib.IMAP4_SSL) -> List[Tuple[str, str]]:
|
||||
"""List all mailboxes. Returns [(raw_name, decoded_name), ...]."""
|
||||
status, data = imap.list()
|
||||
if status != "OK" or not data:
|
||||
return []
|
||||
return _parse_list_response(data)
|
||||
|
||||
|
||||
def list_mailboxes_for_config(host: str, port: int, user: str, password: str) -> List[Tuple[str, str]]:
|
||||
"""Connect and list all mailboxes (for dropdown). Returns [(raw_name, decoded_name), ...]."""
|
||||
with imaplib.IMAP4_SSL(host, int(port)) as imap:
|
||||
imap.login(user, password)
|
||||
return _list_mailboxes(imap)
|
||||
|
||||
|
||||
def _select_mailbox(imap: imaplib.IMAP4_SSL, mailbox: str) -> bool:
|
||||
"""
|
||||
Robust mailbox selection with deep discovery scan.
|
||||
|
||||
Strategy:
|
||||
1. LIST all folders, log raw lines for debugging.
|
||||
2. Look for entry containing '\\Inbox' flag; if found, SELECT that folder.
|
||||
3. Try standard candidates: user-configured name / INBOX / common UTF-7 收件箱编码.
|
||||
4. As last resort, attempt SELECT on every listed folder and log which succeed/fail.
|
||||
"""
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
name = (mailbox or "INBOX").strip() or "INBOX"
|
||||
|
||||
# 1) Discovery scan: list all folders and log raw entries
|
||||
try:
|
||||
status, data = imap.list()
|
||||
if status != "OK" or not data:
|
||||
logger.warning("IMAP LIST returned no data or non-OK status: %s", status)
|
||||
data = []
|
||||
except Exception as exc:
|
||||
logger.error("IMAP LIST failed: %s", exc)
|
||||
data = []
|
||||
|
||||
logger.info("IMAP Discovery Scan: listing all folders for mailbox=%s", name)
|
||||
for raw in data:
|
||||
logger.info("IMAP FOLDER RAW: %r", raw)
|
||||
|
||||
# 2) 优先按 \\Inbox 属性查找“真正的收件箱”
|
||||
inbox_candidates: list[str] = []
|
||||
for raw in data:
|
||||
line = raw.decode("utf-8", errors="ignore") if isinstance(raw, bytes) else str(raw)
|
||||
if "\\Inbox" not in line:
|
||||
continue
|
||||
m = re.search(r'"([^"]+)"\s*$', line)
|
||||
if not m:
|
||||
continue
|
||||
folder_name = m.group(1)
|
||||
inbox_candidates.append(folder_name)
|
||||
|
||||
# 3) 补充常规候选:配置名 / INBOX / 常见 UTF-7 收件箱编码
|
||||
primary_names = [name, "INBOX"]
|
||||
utf7_names = ["&XfJT0ZTx-"]
|
||||
for nm in primary_names + utf7_names:
|
||||
if nm not in inbox_candidates:
|
||||
inbox_candidates.append(nm)
|
||||
|
||||
logger.info("IMAP Inbox candidate list (ordered): %r", inbox_candidates)
|
||||
|
||||
# 4) 依次尝试候选收件箱
|
||||
for candidate in inbox_candidates:
|
||||
for readonly in (False, True):
|
||||
try:
|
||||
status, _ = imap.select(candidate, readonly=readonly)
|
||||
logger.info(
|
||||
"IMAP SELECT candidate=%r readonly=%s -> %s", candidate, readonly, status
|
||||
)
|
||||
if status == "OK":
|
||||
return True
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
"IMAP SELECT failed for candidate=%r readonly=%s: %s",
|
||||
candidate,
|
||||
readonly,
|
||||
exc,
|
||||
)
|
||||
|
||||
# 5) 最后手段:尝试 LIST 返回的每一个文件夹
|
||||
logger.info("IMAP Fallback: trying SELECT on every listed folder...")
|
||||
for raw in data:
|
||||
line = raw.decode("utf-8", errors="ignore") if isinstance(raw, bytes) else str(raw)
|
||||
m = re.search(r'"([^"]+)"\s*$', line)
|
||||
if not m:
|
||||
continue
|
||||
folder_name = m.group(1)
|
||||
for readonly in (False, True):
|
||||
try:
|
||||
status, _ = imap.select(folder_name, readonly=readonly)
|
||||
logger.info(
|
||||
"IMAP SELECT fallback folder=%r readonly=%s -> %s",
|
||||
folder_name,
|
||||
readonly,
|
||||
status,
|
||||
)
|
||||
if status == "OK":
|
||||
return True
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
"IMAP SELECT fallback failed for folder=%r readonly=%s: %s",
|
||||
folder_name,
|
||||
readonly,
|
||||
exc,
|
||||
)
|
||||
|
||||
logger.error("IMAP: unable to SELECT any inbox-like folder for mailbox=%s", name)
|
||||
return False
|
||||
|
||||
|
||||
def _sync_one_account(config: Dict[str, Any], db: Session, results: List[Dict[str, Any]]) -> None:
|
||||
host = config.get("host")
|
||||
user = config.get("user")
|
||||
password = config.get("password")
|
||||
port = int(config.get("port", 993))
|
||||
mailbox = (config.get("mailbox") or "INBOX").strip() or "INBOX"
|
||||
|
||||
if not all([host, user, password]):
|
||||
return
|
||||
|
||||
# Use strict TLS context for modern protocols (TLS 1.2+)
|
||||
tls_context = ssl.create_default_context()
|
||||
|
||||
with imaplib.IMAP4_SSL(host, port, ssl_context=tls_context) as imap:
|
||||
# Enable low-level IMAP debug output to backend logs to help diagnose
|
||||
# handshake / protocol / mailbox selection issues with specific providers.
|
||||
imap.debug = 4
|
||||
imap.login(user, password)
|
||||
# NetEase / 163 等会对未知客户端静默限制 SELECT,这里通过 ID 命令伪装为常见桌面客户端。
|
||||
try:
|
||||
logger = logging.getLogger(__name__)
|
||||
id_str = (
|
||||
'("name" "Foxmail" '
|
||||
'"version" "7.2.25.170" '
|
||||
'"vendor" "Tencent" '
|
||||
'"os" "Windows" '
|
||||
'"os-version" "10.0")'
|
||||
)
|
||||
logger.info("IMAP sending Foxmail-style ID: %s", id_str)
|
||||
# Use low-level command so it works across Python versions.
|
||||
typ, dat = imap._command("ID", id_str) # type: ignore[attr-defined]
|
||||
logger.info("IMAP ID command result: %s %r", typ, dat)
|
||||
except Exception as exc:
|
||||
# ID 失败不应阻断登录,只记录日志,方便后续排查。
|
||||
logging.getLogger(__name__).warning("IMAP ID command failed: %s", exc)
|
||||
if not _select_mailbox(imap, mailbox):
|
||||
raise RuntimeError(
|
||||
f"无法选择邮箱「{mailbox}」,请检查该账户的 Mailbox 配置(如 163 使用 INBOX)"
|
||||
)
|
||||
|
||||
# 首次同步(历史库无记录):拉取全部邮件中的附件,由 attachment_history 去重
|
||||
# 已有历史:只拉取未读邮件,避免重复拉取
|
||||
is_first_sync = not _has_sync_history()
|
||||
search_criterion = "ALL" if is_first_sync else "UNSEEN"
|
||||
logging.getLogger(__name__).info(
|
||||
"Finance sync: %s (criterion=%s)",
|
||||
"全量" if is_first_sync else "增量",
|
||||
search_criterion,
|
||||
)
|
||||
status, data = imap.search(None, search_criterion)
|
||||
if status != "OK":
|
||||
return
|
||||
|
||||
id_list = data[0].split()
|
||||
for msg_id in id_list:
|
||||
status, msg_data = imap.fetch(msg_id, "(RFC822)")
|
||||
if status != "OK":
|
||||
continue
|
||||
|
||||
raw_email = msg_data[0][1]
|
||||
msg = email.message_from_bytes(raw_email)
|
||||
dt = _parse_email_date(msg)
|
||||
month_str = dt.strftime("%Y-%m")
|
||||
|
||||
saved = _save_attachment(msg, month_str)
|
||||
for file_name, file_path, mime, raw_bytes, doc_type in saved:
|
||||
final_name = file_name
|
||||
final_path = file_path
|
||||
amount = None
|
||||
billing_date = None
|
||||
|
||||
if doc_type == "invoices":
|
||||
amount, date_str = _run_invoice_ocr_sync(file_path, mime, raw_bytes)
|
||||
if date_str:
|
||||
try:
|
||||
billing_date = date.fromisoformat(date_str[:10])
|
||||
except ValueError:
|
||||
billing_date = date.today()
|
||||
else:
|
||||
billing_date = date.today()
|
||||
final_name, final_path = _rename_invoice_file(
|
||||
file_path, amount, billing_date
|
||||
)
|
||||
|
||||
record = FinanceRecord(
|
||||
month=month_str,
|
||||
type=doc_type,
|
||||
file_name=final_name,
|
||||
file_path=final_path,
|
||||
amount=amount,
|
||||
billing_date=billing_date,
|
||||
)
|
||||
db.add(record)
|
||||
db.flush()
|
||||
results.append({
|
||||
"id": record.id,
|
||||
"month": record.month,
|
||||
"type": record.type,
|
||||
"file_name": record.file_name,
|
||||
"file_path": record.file_path,
|
||||
})
|
||||
|
||||
imap.store(msg_id, "+FLAGS", "\\Seen \\Flagged")
|
||||
|
||||
|
||||
async def sync_finance_emails() -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Connect to IMAP, fetch unread finance-related emails, download attachments,
|
||||
save to filesystem and record FinanceRecord entries.
|
||||
Sync from all active email configs (data/email_configs.json).
|
||||
Falls back to env vars if no configs. Classifies into invoices/, receipts/, statements/.
|
||||
Invoices are renamed to YYYYMMDD_金额_原文件名 using OCR.
|
||||
"""
|
||||
|
||||
def _sync() -> List[Dict[str, Any]]:
|
||||
host = os.getenv("IMAP_HOST")
|
||||
user = os.getenv("IMAP_USER")
|
||||
password = os.getenv("IMAP_PASSWORD")
|
||||
port = int(os.getenv("IMAP_PORT", "993"))
|
||||
mailbox = os.getenv("IMAP_MAILBOX", "INBOX")
|
||||
from backend.app.routers.email_configs import get_email_configs_for_sync
|
||||
|
||||
if not all([host, user, password]):
|
||||
raise RuntimeError("IMAP_HOST, IMAP_USER, IMAP_PASSWORD must be set.")
|
||||
configs = get_email_configs_for_sync()
|
||||
if not configs:
|
||||
raise RuntimeError("未配置邮箱。请在 设置 → 邮箱账户 中添加,或配置 IMAP_* 环境变量。")
|
||||
|
||||
results: List[Dict[str, Any]] = []
|
||||
errors: List[str] = []
|
||||
db = SessionLocal()
|
||||
try:
|
||||
for config in configs:
|
||||
try:
|
||||
_sync_one_account(config, db, results)
|
||||
except Exception as e:
|
||||
# 不让单个账户的异常中断全部同步,记录错误并继续其他账户。
|
||||
user = config.get("user", "") or config.get("id", "")
|
||||
errors.append(f"同步账户 {user} 失败: {e}")
|
||||
db.commit()
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
with imaplib.IMAP4_SSL(host, port) as imap:
|
||||
imap.login(user, password)
|
||||
imap.select(mailbox)
|
||||
|
||||
# Search for UNSEEN emails with finance related keywords in subject.
|
||||
# Note: IMAP SEARCH is limited; here we search UNSEEN first then filter in Python.
|
||||
status, data = imap.search(None, "UNSEEN")
|
||||
if status != "OK":
|
||||
return results
|
||||
|
||||
id_list = data[0].split()
|
||||
db = SessionLocal()
|
||||
try:
|
||||
for msg_id in id_list:
|
||||
status, msg_data = imap.fetch(msg_id, "(RFC822)")
|
||||
if status != "OK":
|
||||
continue
|
||||
|
||||
raw_email = msg_data[0][1]
|
||||
msg = email.message_from_bytes(raw_email)
|
||||
|
||||
subject = _decode_header_value(msg.get("Subject"))
|
||||
doc_type = _classify_type(subject)
|
||||
|
||||
# Filter by keywords first
|
||||
if doc_type == "others":
|
||||
continue
|
||||
|
||||
dt = _parse_email_date(msg)
|
||||
month_str = dt.strftime("%Y-%m")
|
||||
|
||||
saved_files = _save_attachment(msg, month_str, doc_type)
|
||||
for file_name, file_path in saved_files:
|
||||
record = FinanceRecord(
|
||||
month=month_str,
|
||||
type=doc_type,
|
||||
file_name=file_name,
|
||||
file_path=file_path,
|
||||
)
|
||||
# NOTE: created_at defaults at DB layer
|
||||
db.add(record)
|
||||
db.flush()
|
||||
|
||||
results.append(
|
||||
{
|
||||
"id": record.id,
|
||||
"month": record.month,
|
||||
"type": record.type,
|
||||
"file_name": record.file_name,
|
||||
"file_path": record.file_path,
|
||||
}
|
||||
)
|
||||
|
||||
# Mark email as seen and flagged to avoid re-processing
|
||||
imap.store(msg_id, "+FLAGS", "\\Seen \\Flagged")
|
||||
|
||||
db.commit()
|
||||
finally:
|
||||
db.close()
|
||||
if not results and errors:
|
||||
# 所有账户都失败了,整体报错,前端可显示详细原因。
|
||||
raise RuntimeError("; ".join(errors))
|
||||
|
||||
return results
|
||||
|
||||
@@ -205,7 +566,8 @@ async def sync_finance_emails() -> List[Dict[str, Any]]:
|
||||
|
||||
async def create_monthly_zip(month_str: str) -> str:
|
||||
"""
|
||||
Zip the finance folder for a given month (YYYY-MM) and return the zip path.
|
||||
Zip the finance folder for a given month (YYYY-MM).
|
||||
Preserves folder structure (invoices/, receipts/, statements/, manual/) inside the zip.
|
||||
"""
|
||||
import zipfile
|
||||
|
||||
@@ -227,4 +589,3 @@ async def create_monthly_zip(month_str: str) -> str:
|
||||
return str(zip_path)
|
||||
|
||||
return await asyncio.to_thread(_zip)
|
||||
|
||||
|
||||
90
backend/app/services/invoice_upload.py
Normal file
90
backend/app/services/invoice_upload.py
Normal file
@@ -0,0 +1,90 @@
|
||||
"""
|
||||
Manual invoice upload: save file, optionally run AI vision to extract amount/date.
|
||||
"""
|
||||
import io
|
||||
from datetime import date, datetime
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Tuple
|
||||
|
||||
from fastapi import UploadFile
|
||||
|
||||
from backend.app.services.ai_service import extract_invoice_metadata
|
||||
|
||||
FINANCE_BASE = Path("data/finance")
|
||||
ALLOWED_IMAGE = {".jpg", ".jpeg", ".png", ".webp"}
|
||||
ALLOWED_PDF = {".pdf"}
|
||||
|
||||
|
||||
def _current_month() -> str:
|
||||
return datetime.utcnow().strftime("%Y-%m")
|
||||
|
||||
|
||||
def _pdf_first_page_to_image(pdf_bytes: bytes) -> Tuple[bytes, str] | None:
|
||||
"""Render first page of PDF to PNG bytes. Returns (bytes, 'image/png') or None on error."""
|
||||
try:
|
||||
import fitz
|
||||
doc = fitz.open(stream=pdf_bytes, filetype="pdf")
|
||||
if doc.page_count == 0:
|
||||
doc.close()
|
||||
return None
|
||||
page = doc[0]
|
||||
pix = page.get_pixmap(dpi=150)
|
||||
png_bytes = pix.tobytes("png")
|
||||
doc.close()
|
||||
return (png_bytes, "image/png")
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
async def process_invoice_upload(
|
||||
file: UploadFile,
|
||||
) -> Tuple[str, str, str, float | None, date | None]:
|
||||
"""
|
||||
Save uploaded file to data/finance/{YYYY-MM}/manual/, run OCR for amount/date.
|
||||
Returns (file_name, file_path, month_str, amount, billing_date).
|
||||
"""
|
||||
month_str = _current_month()
|
||||
manual_dir = FINANCE_BASE / month_str / "manual"
|
||||
manual_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
raw = await file.read()
|
||||
filename = file.filename or "upload"
|
||||
suf = Path(filename).suffix.lower()
|
||||
|
||||
if suf in ALLOWED_IMAGE:
|
||||
image_bytes, mime = raw, (file.content_type or "image/jpeg")
|
||||
if "png" in (suf or ""):
|
||||
mime = "image/png"
|
||||
amount, date_str = await extract_invoice_metadata(image_bytes, mime)
|
||||
elif suf in ALLOWED_PDF:
|
||||
image_result = _pdf_first_page_to_image(raw)
|
||||
if image_result:
|
||||
image_bytes, mime = image_result
|
||||
amount, date_str = await extract_invoice_metadata(image_bytes, mime)
|
||||
else:
|
||||
amount, date_str = None, None
|
||||
# Save original PDF
|
||||
else:
|
||||
amount, date_str = None, None
|
||||
|
||||
# Unique filename
|
||||
dest = manual_dir / filename
|
||||
counter = 1
|
||||
while dest.exists():
|
||||
dest = manual_dir / f"{dest.stem}_{counter}{dest.suffix}"
|
||||
counter += 1
|
||||
|
||||
dest.write_bytes(raw)
|
||||
file_path = str(dest)
|
||||
file_name = dest.name
|
||||
|
||||
billing_date = None
|
||||
if date_str:
|
||||
try:
|
||||
billing_date = date.fromisoformat(date_str)
|
||||
except ValueError:
|
||||
pass
|
||||
if billing_date is None:
|
||||
billing_date = date.today()
|
||||
|
||||
return (file_name, file_path, month_str, amount, billing_date)
|
||||
Reference in New Issue
Block a user