Files
AiTool/backend/app/routers/finance.py
2026-03-18 17:01:10 +08:00

241 lines
7.5 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
from typing import List
from datetime import date
import os
from pathlib import Path
from fastapi import APIRouter, Body, Depends, File, HTTPException, Query, UploadFile
from fastapi.responses import FileResponse
from sqlalchemy.orm import Session
from backend.app.db import get_db
from backend.app import models
from backend.app.schemas import (
FinanceRecordRead,
FinanceRecordUpdate,
FinanceBatchDeleteRequest,
FinanceSyncRequest,
FinanceSyncResponse,
FinanceSyncResult,
FinanceUploadResponse,
)
from backend.app.services.email_service import create_monthly_zip, sync_finance_emails
from backend.app.services.invoice_upload import process_invoice_upload
# Every endpoint in this module is registered on this router: paths are
# mounted under /finance and tagged "finance".
router = APIRouter(tags=["finance"], prefix="/finance")
@router.post("/sync", response_model=FinanceSyncResponse)
async def sync_finance(payload: FinanceSyncRequest = Body(default=FinanceSyncRequest())):
    """Trigger a mailbox sync and report the finance documents it produced.

    The request body is optional. When it is omitted, the ``FinanceSyncRequest()``
    instance created once when this function is defined is used as the default;
    this handler only reads from it.

    Args:
        payload: Sync options; ``mode``, ``start_date``, ``end_date`` and
            ``doc_types`` are passed through unchanged to ``sync_finance_emails``.

    Returns:
        ``FinanceSyncResponse`` with ``status="success"``, ``new_files`` equal to
        the number of returned items, and one ``FinanceSyncResult`` per item in
        ``details``.

    Raises:
        HTTPException: 400 carrying the message of any ``RuntimeError`` raised by
            ``sync_finance_emails``.
    """
    try:
        raw_results = await sync_finance_emails(
            mode=payload.mode,
            start_date=payload.start_date,
            end_date=payload.end_date,
            doc_types=payload.doc_types,
        )
    except RuntimeError as err:
        # Mailbox configuration / connection problems are expected business
        # errors: answer 400 so the frontend can show the reason directly,
        # instead of letting them surface as a generic 500.
        raise HTTPException(status_code=400, detail=str(err)) from err

    synced = [FinanceSyncResult(**raw) for raw in raw_results]
    response = FinanceSyncResponse(
        status="success",
        new_files=len(synced),
        details=synced,
    )
    return response
@router.get("/months", response_model=List[str])
async def list_finance_months(db: Session = Depends(get_db)):
    """Return each distinct ``FinanceRecord.month`` value, newest first.

    Per the endpoint contract, values are ``YYYY-MM`` strings. They are sorted
    descending by the ``month`` column itself.
    """
    month_col = models.FinanceRecord.month
    month_query = db.query(month_col).distinct().order_by(month_col.desc())
    return [month for (month,) in month_query.all()]
@router.get("/records", response_model=List[FinanceRecordRead])
async def list_finance_records(
    month: str = Query(..., description="YYYY-MM"),
    db: Session = Depends(get_db),
):
    """Return every finance record whose ``month`` equals *month*.

    Results are ordered by ``created_at``, most recent first. An unknown month
    simply yields an empty list.
    """
    record_model = models.FinanceRecord
    in_requested_month = record_model.month == month
    newest_first = record_model.created_at.desc()
    return db.query(record_model).filter(in_requested_month).order_by(newest_first).all()
@router.post("/upload", response_model=FinanceUploadResponse, status_code=201)
async def upload_invoice(
    file: UploadFile = File(...),
    db: Session = Depends(get_db),
):
    """Accept a manually uploaded invoice and store it as a ``manual`` finance record.

    Only files whose name ends in .pdf/.jpg/.jpeg/.png/.webp (case-insensitive)
    are accepted. The upload itself is handled by ``process_invoice_upload``.
    According to the original description, that function saves the file under
    ``data/finance/{YYYY-MM}/manual/`` and extracts the amount and date via AI
    OCR. It returns the stored name, path, month, amount and billing date, which
    are persisted here.

    Raises:
        HTTPException: 400 when the file extension is not allowed.
    """
    original_name = file.filename or ""
    # Everything after the last "." (lower-cased); no dot at all means no extension.
    extension = original_name.lower().rsplit(".", 1)[-1] if "." in original_name else ""
    if extension not in {"pdf", "jpg", "jpeg", "png", "webp"}:
        raise HTTPException(400, "仅支持 PDF、JPG、PNG、WEBP 格式")

    stored_name, stored_path, month_str, amount, billing_date = await process_invoice_upload(file)

    new_record = models.FinanceRecord(
        month=month_str,
        type="manual",
        file_name=stored_name,
        file_path=stored_path,
        amount=amount,
        billing_date=billing_date,
    )
    db.add(new_record)
    db.commit()
    db.refresh(new_record)
    return new_record
@router.patch("/records/{record_id}", response_model=FinanceRecordRead)
async def update_finance_record(
    record_id: int,
    payload: FinanceRecordUpdate,
    db: Session = Depends(get_db),
):
    """Patch ``amount`` and/or ``billing_date`` of one finance record.

    Fields left as ``None`` in *payload* are not touched, so a field cannot be
    cleared through this endpoint. Intended e.g. for corrections after manual
    review of OCR results.

    Raises:
        HTTPException: 404 when no record with *record_id* exists.
    """
    record = db.query(models.FinanceRecord).get(record_id)
    if not record:
        raise HTTPException(404, "记录不存在")

    for field_name in ("amount", "billing_date"):
        new_value = getattr(payload, field_name)
        if new_value is not None:
            setattr(record, field_name, new_value)

    db.commit()
    db.refresh(record)
    return record
@router.delete("/records/{record_id}")
async def delete_finance_record(
    record_id: int,
    db: Session = Depends(get_db),
):
    """Delete one finance record and, best effort, its stored file.

    The database row is deleted and committed first; the file is removed only
    afterwards. A failed commit therefore never leaves a surviving record that
    points at an already-deleted file. A file that cannot be removed (missing,
    permissions, ...) does not fail the request; the problem is logged instead.
    Relative ``file_path`` values are resolved against the process working
    directory.

    Returns:
        ``{"status": "deleted", "id": record_id}``.

    Raises:
        HTTPException: 404 when no record with *record_id* exists.
    """
    import logging

    record = db.query(models.FinanceRecord).get(record_id)
    if not record:
        raise HTTPException(404, "记录不存在")

    # Capture the path before deleting: the instance is detached after commit.
    file_path = Path(record.file_path) if record.file_path else None

    db.delete(record)
    db.commit()

    if file_path is not None:
        try:
            file_path.unlink(missing_ok=True)
        except OSError:
            # Best effort: the record is already gone; keep the request successful
            # but make the leftover file visible in the logs.
            logging.getLogger(__name__).warning(
                "Could not remove file %s of deleted finance record %s",
                file_path,
                record_id,
                exc_info=True,
            )
    return {"status": "deleted", "id": record_id}
@router.post("/records/batch-delete")
async def batch_delete_finance_records(
    payload: FinanceBatchDeleteRequest,
    db: Session = Depends(get_db),
):
    """Delete several finance records and, best effort, their stored files.

    Ids that do not exist are ignored. All matching rows are deleted in one
    commit; files are removed only after that commit succeeded, so a failed
    commit cannot leave records pointing at deleted files. File-removal errors
    are logged and do not fail the request. Relative ``file_path`` values are
    resolved against the process working directory.

    Returns:
        ``{"status": "ok", "deleted": 0}`` for an empty id list, otherwise
        ``{"status": "deleted", "deleted": <number of records found>}``.
    """
    import logging

    # NOTE(review): the empty case reports status "ok" while the normal case
    # reports "deleted"; kept as-is in case the client relies on either value.
    if not payload.ids:
        return {"status": "ok", "deleted": 0}

    records = (
        db.query(models.FinanceRecord)
        .filter(models.FinanceRecord.id.in_(payload.ids))
        .all()
    )
    # Capture paths before deleting: the instances are detached after commit.
    file_paths = [Path(record.file_path) for record in records if record.file_path]

    for record in records:
        db.delete(record)
    db.commit()

    logger = logging.getLogger(__name__)
    for file_path in file_paths:
        try:
            file_path.unlink(missing_ok=True)
        except OSError:
            # Best effort: rows are already gone; log the leftover file instead
            # of silently ignoring it.
            logger.warning(
                "Could not remove file %s of a deleted finance record",
                file_path,
                exc_info=True,
            )
    return {"status": "deleted", "deleted": len(records)}
@router.get("/download/{month}")
async def download_finance_month(month: str):
    """Download the zip archive produced by ``create_monthly_zip`` for one month.

    Args:
        month: Month in ``YYYY-MM`` form, taken from the URL path.

    Returns:
        A ``FileResponse`` sending the archive as ``finance_<month>.zip``.

    Raises:
        HTTPException: 400 when *month* is not a valid ``YYYY-MM`` value; 404
            (with the original message) when ``create_monthly_zip`` raises
            ``FileNotFoundError``.
    """
    import re

    # ``month`` is untrusted URL input. It is passed to create_monthly_zip and
    # used in the download file name. Accept only YYYY-MM so values such as
    # ".." never reach code that may build filesystem paths from it.
    if not re.fullmatch(r"[0-9]{4}-(0[1-9]|1[0-2])", month):
        raise HTTPException(status_code=400, detail="月份格式应为 YYYY-MM")

    try:
        zip_path = await create_monthly_zip(month)
    except FileNotFoundError as exc:
        raise HTTPException(status_code=404, detail=str(exc)) from exc
    return FileResponse(
        path=zip_path,
        media_type="application/zip",
        filename=f"finance_{month}.zip",
    )
def _zip_arcname(file_path: Path) -> str:
    """Return the name under which *file_path* is stored in the export zip.

    Everything after the first ``data`` path component is kept (for example
    ``data/finance/<month>/<type>/x.pdf`` -> ``finance/<month>/<type>/x.pdf``),
    which preserves the month/type directory layout. Unlike
    ``relative_to(Path("data"))``, this also works for absolute paths and for
    paths where ``data`` is not the first component. When there is no usable
    ``data`` component, the bare file name is used.
    """
    parts = file_path.parts
    if "data" in parts:
        tail = parts[parts.index("data") + 1:]
        if tail:
            return Path(*tail).as_posix()
    return file_path.name


@router.get("/download-range")
async def download_finance_range(
    start_date: date = Query(..., description="起始日期 YYYY-MM-DD"),
    end_date: date = Query(..., description="结束日期 YYYY-MM-DD含当日"),
    only_invoices: bool = Query(True, description="是否仅包含发票类型"),
    db: Session = Depends(get_db),
):
    """Zip and download finance files whose billing date lies in a date range.

    Records are selected by ``billing_date`` in ``[start_date, end_date]``
    (both ends inclusive). Records without a billing date are never included.
    With *only_invoices* (the default), only records of type ``"invoices"`` are
    considered.

    The archive is written to ``data/finance/invoices_<start>_<end>.zip`` and
    returned as a download. Records whose file is missing on disk are skipped.

    Raises:
        HTTPException: 400 when *end_date* is before *start_date*; 404 when no
            record matches, or when none of the matching records has a file on
            disk.
    """
    import zipfile

    if end_date < start_date:
        raise HTTPException(status_code=400, detail="结束日期不能早于开始日期")

    q = db.query(models.FinanceRecord).filter(
        models.FinanceRecord.billing_date.isnot(None),
        models.FinanceRecord.billing_date >= start_date,
        models.FinanceRecord.billing_date <= end_date,
    )
    if only_invoices:
        # NOTE(review): the /upload endpoint stores manually uploaded invoices
        # with type="manual", so they are excluded here by default. Confirm
        # whether that is intended.
        q = q.filter(models.FinanceRecord.type == "invoices")
    records = q.all()
    if not records:
        raise HTTPException(status_code=404, detail="该时间段内没有可导出的记录")

    base_dir = Path("data/finance")
    base_dir.mkdir(parents=True, exist_ok=True)
    zip_name = f"invoices_{start_date.isoformat()}_{end_date.isoformat()}.zip"
    zip_path = base_dir / zip_name

    # NOTE(review): the archive is built synchronously inside this async
    # handler, which blocks the event loop while large ranges are compressed.
    used_names = set()
    with zipfile.ZipFile(zip_path, "w", compression=zipfile.ZIP_DEFLATED) as zf:
        for r in records:
            if not r.file_path:
                continue
            file_path = Path(r.file_path)
            if not file_path.is_file():
                continue
            arcname = _zip_arcname(file_path)
            if arcname in used_names:
                # Two records map to the same archive name (e.g. same file name
                # in different folders). Prefix the record id so neither entry
                # shadows the other when the zip is extracted.
                folder, _, base = arcname.rpartition("/")
                arcname = f"{folder}/{r.id}_{base}" if folder else f"{r.id}_{base}"
            used_names.add(arcname)
            zf.write(file_path, arcname=arcname)

    if not used_names:
        # Every matching record's file is missing: don't hand out an empty zip.
        zip_path.unlink(missing_ok=True)
        raise HTTPException(status_code=404, detail="该时间段内没有可导出的记录")

    return FileResponse(
        path=str(zip_path),
        media_type="application/zip",
        filename=zip_name,
    )