Files
AI-Testing/backend/app/main.py
2026-04-18 20:20:38 +08:00

1073 lines
39 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import os
import json
import uuid
import zipfile
import base64
from typing import Any
from datetime import date, datetime, timedelta
from io import BytesIO
from pathlib import Path
import httpx
from docx import Document
from fastapi import Depends, FastAPI, File, HTTPException, Query, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from fastapi.staticfiles import StaticFiles
from reportlab.lib.pagesizes import A4
from reportlab.pdfbase import pdfmetrics
from reportlab.pdfbase.cidfonts import UnicodeCIDFont
from reportlab.pdfgen import canvas
from sqlalchemy import asc, desc, func, inspect, or_, select, text
from sqlalchemy.orm import Session
from .database import Base, engine, get_db
from .models import Mistake, Resource, ScoreRecord
from .schemas import (
AiMistakeAnalysisOut,
AiStudyPlanIn,
AiStudyPlanOut,
IdBatchPayload,
MistakeCreate,
MistakeOut,
MistakeUpdate,
OcrParseIn,
OcrParseOut,
ResourceBatchUpdate,
ResourceCreate,
ResourceOut,
ResourceUpdate,
ScoreRecordCreate,
ScoreRecordOut,
ScoreRecordUpdate,
ScoreStats,
)
# Create any tables that do not exist yet. create_all does not add new columns to
# tables that already exist, which is why _migrate_mistake_columns runs next.
Base.metadata.create_all(bind=engine)
def _migrate_mistake_columns() -> None:
    """Add the question/answer/explanation TEXT columns to an existing ``mistakes`` table.

    This is a lightweight in-place migration for databases created before these
    columns existed. It does nothing when the table is absent. Each column is
    added only if it is missing.
    """
    insp = inspect(engine)
    if "mistakes" not in insp.get_table_names():
        return
    present = {column["name"] for column in insp.get_columns("mistakes")}
    pending = (
        ("question_content", "ALTER TABLE mistakes ADD COLUMN question_content TEXT"),
        ("answer", "ALTER TABLE mistakes ADD COLUMN answer TEXT"),
        ("explanation", "ALTER TABLE mistakes ADD COLUMN explanation TEXT"),
    )
    with engine.begin() as conn:
        for column_name, ddl in pending:
            if column_name not in present:
                conn.execute(text(ddl))
# Run the column migration once at import time, after create_all has created missing tables.
_migrate_mistake_columns()
app = FastAPI(title="公考助手 API", version="1.0.0")
# Uploaded files are stored on local disk (default /app/uploads, overridable via
# UPLOAD_DIR). The directory is created up front so the static mount below has
# an existing directory to serve under the /uploads URL prefix.
UPLOAD_DIR = Path(os.getenv("UPLOAD_DIR", "/app/uploads"))
UPLOAD_DIR.mkdir(parents=True, exist_ok=True)
app.mount("/uploads", StaticFiles(directory=str(UPLOAD_DIR)), name="uploads")
# Comma-separated CORS allow-list; entries are trimmed and blank entries dropped below.
origins = os.getenv("CORS_ORIGINS", "http://localhost:5173").split(",")
app.add_middleware(
    CORSMiddleware,
    allow_origins=[origin.strip() for origin in origins if origin.strip()],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
@app.get("/health")
def health():
    """Liveness probe that always reports ``{"status": "ok"}``."""
    status_body = {"status": "ok"}
    return status_body
def _validate_score_date(exam_date: date) -> None:
if exam_date > date.today():
raise HTTPException(status_code=400, detail="考试时间不能晚于今天")
def _query_mistakes_for_export(
    db: Session,
    category: str | None,
    start_date: date | None,
    end_date: date | None,
    ids: list[int] | None = None,
):
    """Return the mistakes matching the export filters, newest first.

    Args:
        db: Active database session.
        category: Exact category to match; falsy means any category.
        start_date: Inclusive lower bound on ``created_at`` (from the start of that day).
        end_date: Inclusive upper bound on ``created_at`` (through the end of that day).
        ids: Restrict the export to these ids. ``None`` means no id filter. An
            empty list means nothing is selected and yields no rows.

    Raises:
        HTTPException: 400 when more than 200 rows match.
    """
    max_items = 200
    stmt = select(Mistake)
    if category:
        stmt = stmt.where(Mistake.category == category)
    if start_date:
        stmt = stmt.where(Mistake.created_at >= datetime.combine(start_date, datetime.min.time()))
    if end_date:
        stmt = stmt.where(Mistake.created_at <= datetime.combine(end_date, datetime.max.time()))
    if ids is not None:
        # An explicit but empty selection (e.g. "?ids=abc" parsed to []) must not
        # silently widen to exporting every mistake.
        stmt = stmt.where(Mistake.id.in_(ids))
    # Fetch one row past the cap: this is enough to detect overflow without
    # loading the whole table.
    items = db.scalars(stmt.order_by(desc(Mistake.created_at)).limit(max_items + 1)).all()
    if len(items) > max_items:
        raise HTTPException(status_code=400, detail="单次最多导出 200 题")
    return items
def _validate_mistake_payload(payload: MistakeCreate | MistakeUpdate) -> None:
    """Require a non-blank image URL, question text or answer before saving.

    Raises:
        HTTPException: 400 when all three fields are empty or whitespace-only.
    """
    candidates = (payload.image_url, payload.question_content, payload.answer)
    if any((value or "").strip() for value in candidates):
        return
    raise HTTPException(status_code=400, detail="请上传题目图片或填写试题/答案后再保存")
def _normalize_multiline_text(value: str | None) -> str:
if not value:
return ""
text_value = value.replace("\r\n", "\n").replace("\r", "\n")
lines = [line.strip() for line in text_value.split("\n")]
compact = [line for line in lines if line]
return "\n".join(compact).strip()
def _wrap_pdf_text(text: str, max_width: float, font_name: str = "STSong-Light", font_size: int = 12) -> list[str]:
    """Greedily wrap text, character by character, to fit within ``max_width`` points.

    The input is normalized first with ``_normalize_multiline_text``, so each
    source line wraps independently and blank lines disappear. A single
    character wider than ``max_width`` still gets its own line.

    Args:
        text: Text to wrap (may contain newlines).
        max_width: Available line width in PDF points.
        font_name: Registered font used to measure widths.
        font_size: Font size used to measure widths.

    Returns:
        The wrapped lines in order; empty when the text normalizes to "".
    """
    normalized = _normalize_multiline_text(text)
    if not normalized:
        return []
    out: list[str] = []
    for source_line in normalized.split("\n"):
        buffer = ""
        for char in source_line:
            trial = buffer + char
            fits = pdfmetrics.stringWidth(trial, font_name, font_size) <= max_width
            if fits or not buffer:
                # An empty buffer always accepts the character, even if it overflows.
                buffer = trial
            else:
                out.append(buffer)
                buffer = char
        if buffer:
            out.append(buffer)
    return out
def _mistake_export_blocks(item: Mistake, content_mode: str) -> list[str]:
    """Build the exported paragraphs for one mistake.

    The first paragraph is the normalized question, or a placeholder when it is
    empty. In ``"full"`` mode, answer and explanation paragraphs follow.
    """
    question = _normalize_multiline_text(item.question_content) or "无题干与选项内容"
    if content_mode != "full":
        return [question]
    answer = _normalize_multiline_text(item.answer)
    explanation = _normalize_multiline_text(item.explanation)
    # Each label starts the same line as its content so it never sits alone on a
    # line (the same rule as the "N." number prefixed to the question).
    return [
        question,
        f"答案: {answer or ''}",
        f"解析: {explanation or ''}",
    ]
def _extract_upload_filename(url: str | None) -> str | None:
if not url or not url.startswith("/uploads/"):
return None
return Path(url).name
def _safe_datetime(value: str | None) -> datetime:
if not value:
return datetime.utcnow()
try:
return datetime.fromisoformat(value.replace("Z", "+00:00")).replace(tzinfo=None)
except ValueError:
return datetime.utcnow()
def _safe_date(value: str | None) -> date:
if not value:
return date.today()
try:
return date.fromisoformat(value)
except ValueError:
return date.today()
def _extract_json_text(raw_text: str) -> str:
content = raw_text.strip()
if content.startswith("```"):
lines = content.splitlines()
if lines:
lines = lines[1:]
if lines and lines[-1].strip() == "```":
lines = lines[:-1]
content = "\n".join(lines).strip()
return content
def _dump_all_data(db: Session) -> dict:
    """Snapshot all resources, mistakes and score records as a JSON-ready dict.

    Each table is read in ascending id order. Date and datetime columns are
    rendered with ``isoformat()``, or ``None`` when unset. ``meta`` holds the
    export time (``datetime.utcnow()``) and the backup format version.
    """

    def iso(value):
        # Unset timestamps stay None instead of raising.
        return value.isoformat() if value else None

    def all_rows(model):
        return db.scalars(select(model).order_by(asc(model.id))).all()

    resource_rows = all_rows(Resource)
    mistake_rows = all_rows(Mistake)
    score_rows = all_rows(ScoreRecord)

    meta = {
        "exported_at": datetime.utcnow().isoformat(),
        "version": "1.1.0",
    }
    resources = [
        {
            "id": row.id,
            "title": row.title,
            "resource_type": row.resource_type,
            "url": row.url,
            "file_name": row.file_name,
            "category": row.category,
            "tags": row.tags,
            "created_at": iso(row.created_at),
        }
        for row in resource_rows
    ]
    mistakes = [
        {
            "id": row.id,
            "title": row.title,
            "image_url": row.image_url,
            "category": row.category,
            "difficulty": row.difficulty,
            "question_content": row.question_content,
            "answer": row.answer,
            "explanation": row.explanation,
            "note": row.note,
            "wrong_count": row.wrong_count,
            "created_at": iso(row.created_at),
        }
        for row in mistake_rows
    ]
    scores = [
        {
            "id": row.id,
            "exam_name": row.exam_name,
            "exam_date": iso(row.exam_date),
            "total_score": row.total_score,
            "module_scores": row.module_scores,
            "created_at": iso(row.created_at),
        }
        for row in score_rows
    ]
    return {"meta": meta, "resources": resources, "mistakes": mistakes, "scores": scores}
def _restore_upload_url_from_zip(url: str | None, zip_ref: zipfile.ZipFile) -> str | None:
    """Restore the upload referenced by ``url`` from a backup ZIP into UPLOAD_DIR.

    Returns the URL the imported row should store:

    - ``None`` for an empty URL.
    - ``url`` unchanged when it is not a local ``/uploads/`` URL, or when the
      ZIP has no matching ``uploads/<name>`` entry.
    - Otherwise ``/uploads/<name>`` for the restored file. If a file with that
      name already exists with identical bytes, it is reused, so re-importing
      the same backup does not pile up duplicates. If it exists with different
      bytes, the restored copy gets a random ``<hex>_`` prefix.
    """
    if not url:
        return None
    file_name = _extract_upload_filename(url)
    if not file_name:
        return url
    try:
        # getinfo is a dict lookup; namelist() would rebuild a full list on every call.
        data = zip_ref.read(zip_ref.getinfo(f"uploads/{file_name}"))
    except KeyError:
        return url
    target_path = UPLOAD_DIR / file_name
    if target_path.exists():
        same_content = (
            target_path.is_file()
            and target_path.stat().st_size == len(data)
            and target_path.read_bytes() == data
        )
        if same_content:
            return f"/uploads/{file_name}"
        file_name = f"{uuid.uuid4().hex}_{file_name}"
        target_path = UPLOAD_DIR / file_name
    target_path.write_bytes(data)
    return f"/uploads/{file_name}"
@app.post("/api/upload")
async def upload_file(file: UploadFile = File(...)):
    """Store an uploaded document or image under UPLOAD_DIR with a random name.

    Accepts PDF, Word and JPG/PNG/WebP files of at most 50MB.

    Returns:
        ``{"url": "/uploads/<hex><suffix>", "original_name": <client file name>}``.

    Raises:
        HTTPException: 400 for an unsupported extension or an oversized file.
    """
    max_bytes = 50 * 1024 * 1024
    suffix = Path(file.filename or "").suffix.lower()
    allowed = {".pdf", ".doc", ".docx", ".jpg", ".jpeg", ".png", ".webp"}
    if suffix not in allowed:
        raise HTTPException(status_code=400, detail="不支持的文件类型")
    # Read at most one byte past the limit, so an oversized upload is rejected
    # without loading the whole body into memory.
    content = await file.read(max_bytes + 1)
    if len(content) > max_bytes:
        raise HTTPException(status_code=400, detail="文件不能超过 50MB")
    file_name = f"{uuid.uuid4().hex}{suffix}"
    target = UPLOAD_DIR / file_name
    target.write_bytes(content)
    return {"url": f"/uploads/{file_name}", "original_name": file.filename}
@app.get("/api/resources", response_model=list[ResourceOut])
def list_resources(
    q: str | None = None,
    category: str | None = None,
    tags: str | None = None,
    resource_type: str | None = None,
    sort_by: str = Query("created_at", pattern="^(created_at|title|name)$"),
    order: str = Query("desc", pattern="^(asc|desc)$"),
    db: Session = Depends(get_db),
):
    """List resources with optional keyword, category, tag and type filters.

    ``q`` is a case-insensitive substring match on title, tags or URL. ``tags``
    is a substring match on the tags column. Sorting is by ``created_at`` or by
    title (``"title"`` and ``"name"`` both sort by title).
    """
    conditions = []
    if q:
        pattern = f"%{q}%"
        conditions.append(
            or_(
                Resource.title.ilike(pattern),
                Resource.tags.ilike(pattern),
                Resource.url.ilike(pattern),
            )
        )
    if category:
        conditions.append(Resource.category == category)
    if tags:
        conditions.append(Resource.tags.ilike(f"%{tags}%"))
    if resource_type:
        conditions.append(Resource.resource_type == resource_type)
    stmt = select(Resource)
    if conditions:
        stmt = stmt.where(*conditions)
    column = Resource.created_at if sort_by == "created_at" else Resource.title
    direction = desc if order == "desc" else asc
    return db.scalars(stmt.order_by(direction(column))).all()
@app.post("/api/resources", response_model=ResourceOut)
def create_resource(payload: ResourceCreate, db: Session = Depends(get_db)):
    """Insert a resource and return it with its database-generated fields loaded."""
    resource = Resource(**payload.model_dump())
    db.add(resource)
    db.commit()
    db.refresh(resource)
    return resource
@app.put("/api/resources/{item_id}", response_model=ResourceOut)
def update_resource(item_id: int, payload: ResourceUpdate, db: Session = Depends(get_db)):
    """Overwrite every field of resource ``item_id`` with the payload's values.

    Raises:
        HTTPException: 404 when the resource does not exist.
    """
    resource = db.get(Resource, item_id)
    if not resource:
        raise HTTPException(status_code=404, detail="Resource not found")
    for field_name, new_value in payload.model_dump().items():
        setattr(resource, field_name, new_value)
    db.commit()
    db.refresh(resource)
    return resource
@app.delete("/api/resources/{item_id}")
def delete_resource(item_id: int, db: Session = Depends(get_db)):
    """Delete resource ``item_id``. Any uploaded file it points to is left on disk.

    Raises:
        HTTPException: 404 when the resource does not exist.
    """
    resource = db.get(Resource, item_id)
    if not resource:
        raise HTTPException(status_code=404, detail="Resource not found")
    db.delete(resource)
    db.commit()
    return {"success": True}
@app.patch("/api/resources/batch")
def batch_update_resources(payload: ResourceBatchUpdate, db: Session = Depends(get_db)):
    """Set category and/or tags on every resource whose id is in ``payload.ids``.

    Fields left as ``None`` in the payload are not touched. The returned
    ``count`` is the number of resources actually found.

    Raises:
        HTTPException: 400 when ``ids`` is empty.
    """
    if not payload.ids:
        raise HTTPException(status_code=400, detail="ids 不能为空")
    changes = {}
    if payload.category is not None:
        changes["category"] = payload.category
    if payload.tags is not None:
        changes["tags"] = payload.tags
    matched = db.scalars(select(Resource).where(Resource.id.in_(payload.ids))).all()
    for resource in matched:
        for attr, value in changes.items():
            setattr(resource, attr, value)
    db.commit()
    return {"success": True, "count": len(matched)}
@app.post("/api/resources/batch-delete")
def batch_delete_resources(payload: IdBatchPayload, db: Session = Depends(get_db)):
    """Delete every resource whose id is in ``payload.ids``.

    Returns the number of resources actually found and deleted.

    Raises:
        HTTPException: 400 when ``ids`` is empty.
    """
    if not payload.ids:
        raise HTTPException(status_code=400, detail="ids 不能为空")
    doomed = db.scalars(select(Resource).where(Resource.id.in_(payload.ids))).all()
    for resource in doomed:
        db.delete(resource)
    db.commit()
    return {"success": True, "count": len(doomed)}
@app.get("/api/mistakes", response_model=list[MistakeOut])
def list_mistakes(
    category: str | None = None,
    keyword: str | None = None,
    sort_by: str = Query("created_at", pattern="^(created_at|wrong_count)$"),
    order: str = Query("desc", pattern="^(asc|desc)$"),
    db: Session = Depends(get_db),
):
    """List mistakes, optionally filtered by exact category and keyword.

    ``keyword`` is a case-insensitive substring match against the note, title,
    question, answer, explanation and image URL columns.
    """
    stmt = select(Mistake)
    if category:
        stmt = stmt.where(Mistake.category == category)
    if keyword:
        pattern = f"%{keyword}%"
        searchable = (
            Mistake.note,
            Mistake.title,
            Mistake.question_content,
            Mistake.answer,
            Mistake.explanation,
            Mistake.image_url,
        )
        stmt = stmt.where(or_(*(column.ilike(pattern) for column in searchable)))
    column = Mistake.created_at if sort_by == "created_at" else Mistake.wrong_count
    direction = desc if order == "desc" else asc
    return db.scalars(stmt.order_by(direction(column))).all()
@app.post("/api/mistakes", response_model=MistakeOut)
def create_mistake(payload: MistakeCreate, db: Session = Depends(get_db)):
    """Validate and insert a mistake, then return it with its generated fields loaded."""
    _validate_mistake_payload(payload)
    mistake = Mistake(**payload.model_dump())
    db.add(mistake)
    db.commit()
    db.refresh(mistake)
    return mistake
@app.put("/api/mistakes/{item_id}", response_model=MistakeOut)
def update_mistake(item_id: int, payload: MistakeUpdate, db: Session = Depends(get_db)):
    """Validate the payload, then overwrite every field of mistake ``item_id``.

    Raises:
        HTTPException: 400 for an empty payload (checked first), 404 when the
            mistake does not exist.
    """
    _validate_mistake_payload(payload)
    mistake = db.get(Mistake, item_id)
    if not mistake:
        raise HTTPException(status_code=404, detail="Mistake not found")
    for field_name, new_value in payload.model_dump().items():
        setattr(mistake, field_name, new_value)
    db.commit()
    db.refresh(mistake)
    return mistake
@app.delete("/api/mistakes/{item_id}")
def delete_mistake(item_id: int, db: Session = Depends(get_db)):
    """Delete mistake ``item_id``.

    Raises:
        HTTPException: 404 when the mistake does not exist.
    """
    mistake = db.get(Mistake, item_id)
    if not mistake:
        raise HTTPException(status_code=404, detail="Mistake not found")
    db.delete(mistake)
    db.commit()
    return {"success": True}
@app.get("/api/mistakes/export/pdf")
def export_mistakes_pdf(
    category: str | None = None,
    start_date: date | None = None,
    end_date: date | None = None,
    ids: str | None = None,
    content_mode: str = Query("full", pattern="^(full|question_only)$"),
    db: Session = Depends(get_db),
):
    """Export the filtered mistakes as a downloadable A4 PDF.

    ``ids`` is a comma-separated id list; non-numeric entries are ignored.
    Mistakes are numbered, and ``content_mode="full"`` adds the answer and
    explanation. The built-in ``STSong-Light`` CID font is used so Chinese text
    renders.
    """
    id_list = None
    if ids:
        id_list = [int(part) for part in ids.split(",") if part.strip().isdigit()]
    items = _query_mistakes_for_export(db, category, start_date, end_date, id_list)

    font_name, font_size = "STSong-Light", 12
    top_y, left, right, line_height = 800, 48, 560, 18
    max_width = right - left

    buf = BytesIO()
    pdf = canvas.Canvas(buf, pagesize=A4)
    pdfmetrics.registerFont(UnicodeCIDFont(font_name))
    pdf.setFont(font_name, font_size)

    def new_page() -> int:
        # showPage resets the graphics state, so the font has to be set again.
        pdf.showPage()
        pdf.setFont(font_name, font_size)
        return top_y

    y = top_y
    pdf.drawString(left, y, "公考助手 - 错题导出")
    y -= 28
    for number, item in enumerate(items, start=1):
        if y < 90:
            y = new_page()
        for position, block in enumerate(_mistake_export_blocks(item, content_mode)):
            # Prefix the number onto the question so "1." never sits alone on a line.
            paragraph = block if position else f"{number}. {block}"
            wrapped = _wrap_pdf_text(paragraph, max_width=max_width)
            if not wrapped:
                continue
            for line in wrapped:
                if y < 70:
                    y = new_page()
                pdf.drawString(left, y, line)
                y -= line_height
            y -= 6
        y -= 8
    pdf.save()
    buf.seek(0)
    return StreamingResponse(
        buf,
        media_type="application/pdf",
        headers={"Content-Disposition": 'attachment; filename="mistakes.pdf"'},
    )
@app.get("/api/mistakes/export/docx")
def export_mistakes_docx(
    category: str | None = None,
    start_date: date | None = None,
    end_date: date | None = None,
    ids: str | None = None,
    content_mode: str = Query("full", pattern="^(full|question_only)$"),
    db: Session = Depends(get_db),
):
    """Export the filtered mistakes as a downloadable Word (.docx) document.

    Filtering and ``ids`` parsing match the PDF export. Each exported block
    becomes one paragraph, and the question paragraph is prefixed with its number.
    """
    id_list = None
    if ids:
        id_list = [int(part) for part in ids.split(",") if part.strip().isdigit()]
    items = _query_mistakes_for_export(db, category, start_date, end_date, id_list)
    document = Document()
    document.add_heading("公考助手 - 错题导出", level=1)
    for number, item in enumerate(items, start=1):
        blocks = _mistake_export_blocks(item, content_mode)
        # The number shares a paragraph with the question so "1." never stands alone.
        numbered = [f"{number}. {blocks[0]}", *blocks[1:]] if blocks else []
        for paragraph in numbered:
            document.add_paragraph(paragraph)
    buf = BytesIO()
    document.save(buf)
    buf.seek(0)
    return StreamingResponse(
        buf,
        media_type="application/vnd.openxmlformats-officedocument.wordprocessingml.document",
        headers={"Content-Disposition": 'attachment; filename="mistakes.docx"'},
    )
@app.get("/api/scores", response_model=list[ScoreRecordOut])
def list_scores(
    start_date: date | None = None,
    end_date: date | None = None,
    db: Session = Depends(get_db),
):
    """List score records within an optional inclusive exam-date range, oldest first."""
    filters = []
    if start_date:
        filters.append(ScoreRecord.exam_date >= start_date)
    if end_date:
        filters.append(ScoreRecord.exam_date <= end_date)
    stmt = select(ScoreRecord)
    if filters:
        stmt = stmt.where(*filters)
    return db.scalars(stmt.order_by(asc(ScoreRecord.exam_date))).all()
@app.post("/api/scores", response_model=ScoreRecordOut)
def create_score(payload: ScoreRecordCreate, db: Session = Depends(get_db)):
    """Reject future exam dates, then insert and return a new score record."""
    _validate_score_date(payload.exam_date)
    record = ScoreRecord(**payload.model_dump())
    db.add(record)
    db.commit()
    db.refresh(record)
    return record
@app.put("/api/scores/{item_id}", response_model=ScoreRecordOut)
def update_score(item_id: int, payload: ScoreRecordUpdate, db: Session = Depends(get_db)):
    """Reject future exam dates, then overwrite every field of score ``item_id``.

    Raises:
        HTTPException: 400 for a future date (checked first), 404 when the
            record does not exist.
    """
    _validate_score_date(payload.exam_date)
    record = db.get(ScoreRecord, item_id)
    if not record:
        raise HTTPException(status_code=404, detail="Score record not found")
    for field_name, new_value in payload.model_dump().items():
        setattr(record, field_name, new_value)
    db.commit()
    db.refresh(record)
    return record
@app.delete("/api/scores/{item_id}")
def delete_score(item_id: int, db: Session = Depends(get_db)):
    """Delete score record ``item_id``.

    Raises:
        HTTPException: 404 when the record does not exist.
    """
    record = db.get(ScoreRecord, item_id)
    if not record:
        raise HTTPException(status_code=404, detail="Score record not found")
    db.delete(record)
    db.commit()
    return {"success": True}
@app.get("/api/scores/stats", response_model=ScoreStats)
def score_stats(db: Session = Depends(get_db)):
    """Summarize all score records.

    Returns the highest and lowest totals, the mean (rounded to 2 decimals) and
    ``improvement``: the latest exam's total minus the earliest exam's total by
    ``exam_date``. Records that share a date are ordered by id, so the first and
    last records are deterministic. Every field is 0 when there are no records.
    """
    scores = db.scalars(
        select(ScoreRecord).order_by(asc(ScoreRecord.exam_date), asc(ScoreRecord.id))
    ).all()
    if not scores:
        return ScoreStats(highest=0, lowest=0, average=0, improvement=0)
    totals = [item.total_score for item in scores]
    # Average the rows already loaded instead of issuing a second AVG() query.
    average = sum(totals) / len(totals)
    return ScoreStats(
        highest=max(totals),
        lowest=min(totals),
        average=round(float(average), 2),
        improvement=totals[-1] - totals[0],
    )
def _qwen_base_url() -> str:
return os.getenv("QWEN_BASE_URL", "https://dashscope.aliyuncs.com/compatible-mode/v1").strip().rstrip("/")
def _get_qwen_api_key() -> str:
"""去除首尾空白与常见误加的引号,避免 .env 里写成 'sk-xxx' 导致鉴权失败。"""
raw = os.getenv("QWEN_API_KEY", "") or ""
return raw.strip().strip('"').strip("'").strip()
def _raise_for_qwen_http_error(resp: httpx.Response, prefix: str) -> None:
    """Raise an HTTPException for a non-2xx DashScope response; do nothing otherwise.

    Raises:
        HTTPException:
            - 401 with setup guidance when the body's ``error.code`` is
              ``invalid_api_key``.
            - 502 ``"<prefix>: <error.message>"`` when the body is a JSON
              object with an ``error`` object.
            - 502 with the first 1200 characters of the raw body otherwise
              (non-JSON bodies, or JSON that is not an object).
    """
    if resp.status_code < 300:
        return
    body_text = resp.text
    try:
        data = resp.json()
    except ValueError:  # covers JSONDecodeError and UnicodeDecodeError
        data = None
    # Only a JSON object can carry a structured "error"; a list or string body
    # used to crash on .get() with an unhandled AttributeError.
    err = data.get("error") if isinstance(data, dict) else None
    if isinstance(err, dict):
        if str(err.get("code") or "") == "invalid_api_key":
            raise HTTPException(
                status_code=401,
                detail=(
                    "阿里云 DashScope API Key 无效或未生效。"
                    "请到阿里云百炼 / Model Studio 控制台创建 API Key通常以 sk- 开头),"
                    "写入项目根目录 .env 的 QWEN_API_KEY=,勿加引号;"
                    "修改后执行: docker compose up -d --build backend"
                ),
            )
        msg = err.get("message") or body_text
        raise HTTPException(status_code=502, detail=f"{prefix}: {msg}")
    raise HTTPException(status_code=502, detail=f"{prefix}: {body_text[:1200]}")
def _httpx_trust_env() -> bool:
"""默认不信任环境变量中的代理,避免 Docker/IDE 注入空代理导致 ConnectError需走系统代理时设 HTTPX_TRUST_ENV=1。"""
return os.getenv("HTTPX_TRUST_ENV", "0").lower() in ("1", "true", "yes")
def _qwen_http_client(timeout_sec: float = 60.0) -> httpx.AsyncClient:
    """Build the AsyncClient used for DashScope calls.

    ``timeout_sec`` is the default read/write/pool timeout, and connecting is
    capped at 20s. Proxy environment variables are honored only when
    ``_httpx_trust_env()`` is true. The pool allows up to 10 connections, with
    5 kept alive.
    """
    timeout = httpx.Timeout(timeout_sec, connect=20.0)
    pool_limits = httpx.Limits(max_keepalive_connections=5, max_connections=10)
    return httpx.AsyncClient(timeout=timeout, trust_env=_httpx_trust_env(), limits=pool_limits)
def _message_content_to_str(content: Any) -> str:
"""OpenAI 兼容接口里 message.content 可能是 str 或多段结构。"""
if content is None:
return ""
if isinstance(content, str):
return content
if isinstance(content, list):
parts: list[str] = []
for part in content:
if isinstance(part, dict):
if part.get("type") == "text" and "text" in part:
parts.append(str(part["text"]))
elif "text" in part:
parts.append(str(part["text"]))
elif isinstance(part, str):
parts.append(part)
return "".join(parts)
return str(content)
def _openai_completion_assistant_text(data: dict) -> str:
    """Extract the assistant text from a chat/completions JSON response.

    Returns the first choice's ``message.content`` flattened to text ("" when
    the message or its content is missing).

    Raises:
        HTTPException:
            - 401 for an ``invalid_api_key`` error.
            - 502 for any other ``error`` payload.
            - 502 when the body is not an object, ``choices`` is missing or
              empty, or the first choice is not an object.
    """
    body = data if isinstance(data, dict) else {}
    err = body.get("error")
    if err is not None:
        if isinstance(err, dict):
            code = str(err.get("code") or "")
            if code == "invalid_api_key":
                raise HTTPException(
                    status_code=401,
                    detail=(
                        "阿里云 DashScope API Key 无效。"
                        "请在 .env 中填写正确的 QWEN_API_KEY 并重启 backend。"
                    ),
                )
            msg = err.get("message") or err.get("code") or json.dumps(err, ensure_ascii=False)
        else:
            msg = str(err)
        raise HTTPException(status_code=502, detail=f"千问接口错误: {msg}")
    choices = body.get("choices")
    # Guard every level: a malformed body previously surfaced as an unhandled
    # KeyError/AttributeError (a 500) instead of a clear upstream error.
    first = choices[0] if isinstance(choices, list) and choices else None
    if not isinstance(first, dict):
        raise HTTPException(
            status_code=502,
            detail=f"千问返回异常(无 choices请检查模型名与权限。原始片段: {json.dumps(data, ensure_ascii=False)[:800]}",
        )
    message = first.get("message") or {}
    if not isinstance(message, dict):
        return ""
    return _message_content_to_str(message.get("content"))
async def _call_qwen(system_prompt: str, user_prompt: str) -> str:
    """Send a system + user chat completion to Qwen and return the assistant text.

    When ``QWEN_API_KEY`` is unset, no request is made; a local fallback
    message explaining how to configure the key is returned instead. The model
    comes from ``QWEN_MODEL`` (default ``qwen-plus``), with temperature 0.4 and
    a 40s timeout.

    Raises:
        HTTPException:
            - 502 when the endpoint is unreachable, the request fails at the
              transport level, the HTTP status is an error, or the body is not JSON.
            - 504 on timeout.
            - 401 for an invalid API key.
    """
    api_key = _get_qwen_api_key()
    if not api_key:
        return (
            "当前未配置千问 API Key已返回本地降级提示。\n"
            "请在项目根目录创建 .env 并配置 QWEN_API_KEY 后重启服务。\n"
            "示例:\n"
            "QWEN_API_KEY=your_qwen_api_key_here\n"
            "QWEN_BASE_URL=https://dashscope.aliyuncs.com/compatible-mode/v1\n"
            "QWEN_MODEL=qwen-plus"
        )
    base_url = _qwen_base_url()
    model = os.getenv("QWEN_MODEL", "qwen-plus")
    headers = {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"}
    payload = {
        "model": model,
        "messages": [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ],
        "temperature": 0.4,
    }
    url = f"{base_url}/chat/completions"
    try:
        async with _qwen_http_client(40.0) as client:
            resp = await client.post(url, headers=headers, json=payload)
    except httpx.ConnectError as e:
        raise HTTPException(
            status_code=502,
            detail=(
                f"无法连接千问接口({url})。请检查本机/容器能否访问外网、DNS 是否正常;"
                "若在 Docker 中可尝试为 backend 配置 dns 或关闭错误代理。"
                "默认已忽略 HTTP(S)_PROXY若需代理请设置 HTTPX_TRUST_ENV=1。"
                f" 原始错误: {e!s}"
            ),
        ) from e
    except httpx.TimeoutException as e:
        raise HTTPException(status_code=504, detail=f"千问请求超时: {e!s}") from e
    except httpx.RequestError as e:
        # Other transport failures (connection reset, protocol errors, ...) become
        # a 502 instead of an unhandled 500, matching _call_qwen_vision.
        raise HTTPException(status_code=502, detail=f"千问网络请求失败: {e!s}") from e
    _raise_for_qwen_http_error(resp, "千问请求失败")
    try:
        data = resp.json()
    except ValueError as exc:
        raise HTTPException(status_code=502, detail=f"千问返回非 JSON: {resp.text[:600]}") from exc
    return _openai_completion_assistant_text(data)
async def _call_qwen_vision(system_prompt: str, user_prompt: str, image_data_url: str) -> str:
    """Send an image + text prompt to the Qwen VL model and return the assistant text.

    When ``QWEN_API_KEY`` is unset, no request is made and a notice string is
    returned instead. The model comes from ``QWEN_VL_MODEL`` (default
    ``qwen-vl-plus``), with temperature 0.2 and a 60s timeout.
    ``image_data_url`` is passed through as the ``image_url`` content part.

    Raises:
        HTTPException:
            - 502 for connection failures, other transport errors, HTTP error
              statuses, or a non-JSON body.
            - 504 on timeout.
            - 401 for an invalid API key.
    """
    api_key = _get_qwen_api_key()
    if not api_key:
        return (
            "当前未配置千问 API Key无法执行 OCR。\n"
            "请在 .env 中配置 QWEN_API_KEY 后重试。"
        )
    url = f"{_qwen_base_url()}/chat/completions"
    headers = {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"}
    user_message = {
        "role": "user",
        # Image first, then text, as in the DashScope docs; this helps multimodal routing.
        "content": [
            {"type": "image_url", "image_url": {"url": image_data_url}},
            {"type": "text", "text": user_prompt},
        ],
    }
    request_body = {
        "model": os.getenv("QWEN_VL_MODEL", "qwen-vl-plus"),
        "messages": [{"role": "system", "content": system_prompt}, user_message],
        "temperature": 0.2,
    }
    try:
        async with _qwen_http_client(60.0) as client:
            resp = await client.post(url, headers=headers, json=request_body)
    except httpx.ConnectError as e:
        raise HTTPException(
            status_code=502,
            detail=(
                f"无法连接千问接口OCR{url})。请检查网络与 DNS"
                "默认已忽略 HTTP(S)_PROXY若需代理请设置 HTTPX_TRUST_ENV=1。"
                f" 原始错误: {e!s}"
            ),
        ) from e
    except httpx.TimeoutException as e:
        raise HTTPException(status_code=504, detail=f"OCR 请求超时: {e!s}") from e
    except httpx.RequestError as e:
        raise HTTPException(status_code=502, detail=f"OCR 网络请求失败: {e!s}") from e
    _raise_for_qwen_http_error(resp, "OCR 请求失败")
    try:
        data = resp.json()
    except ValueError:
        raise HTTPException(status_code=502, detail=f"OCR 返回非 JSON: {resp.text[:600]}")
    return _openai_completion_assistant_text(data)
@app.get("/api/data/export")
def export_user_data(
    format: str = Query("zip", pattern="^(zip|json)$"),
    include_files: bool = True,
    db: Session = Depends(get_db),
):
    """Download a full data backup.

    ``format="json"`` returns only the data dump. ``"zip"`` (the default)
    returns ``data.json`` plus, when ``include_files`` is true, every existing
    local upload referenced by a resource URL or mistake image, stored under
    ``uploads/``. Download names carry a UTC timestamp.
    """
    payload = _dump_all_data(db)
    timestamp = datetime.utcnow().strftime("%Y%m%d_%H%M%S")
    if format == "json":
        body = json.dumps(payload, ensure_ascii=False, indent=2).encode("utf-8")
        return StreamingResponse(
            BytesIO(body),
            media_type="application/json",
            headers={"Content-Disposition": f'attachment; filename="exam_helper_backup_{timestamp}.json"'},
        )
    candidate_urls = [entry.get("url") for entry in payload["resources"]]
    candidate_urls += [entry.get("image_url") for entry in payload["mistakes"]]
    referenced = {name for name in map(_extract_upload_filename, candidate_urls) if name}
    archive = BytesIO()
    with zipfile.ZipFile(archive, "w", zipfile.ZIP_DEFLATED) as zf:
        zf.writestr("data.json", json.dumps(payload, ensure_ascii=False, indent=2))
        if include_files:
            for file_name in sorted(referenced):
                source = UPLOAD_DIR / file_name
                # is_file() is False for missing paths, so deleted uploads are skipped.
                if source.is_file():
                    zf.write(source, arcname=f"uploads/{file_name}")
    archive.seek(0)
    return StreamingResponse(
        archive,
        media_type="application/zip",
        headers={"Content-Disposition": f'attachment; filename="exam_helper_backup_{timestamp}.zip"'},
    )
def _decode_backup_json(raw: bytes, error_label: str) -> Any:
    """Decode UTF-8 JSON backup bytes.

    Raises:
        HTTPException: 400 ``"<error_label>: <reason>"`` on bad encoding or invalid JSON.
    """
    try:
        return json.loads(raw.decode("utf-8"))
    except (UnicodeDecodeError, json.JSONDecodeError) as exc:
        raise HTTPException(status_code=400, detail=f"{error_label}: {exc}") from exc


def _coerce_backup_number(value: Any, cast: type, default: float) -> float:
    """Convert an imported value with ``cast`` (int or float).

    Falls back to ``default`` when the value cannot be converted or is not finite.
    """
    import math

    try:
        number = cast(value)
    except (TypeError, ValueError, OverflowError):
        return default
    return number if math.isfinite(number) else default


def _apply_backup_import(
    payload: Any,
    zip_ref: zipfile.ZipFile | None,
    mode: str,
    db: Session,
) -> dict[str, int]:
    """Validate a decoded backup and write its rows in a single transaction.

    In ``replace`` mode, all existing resources, mistakes and score records are
    deleted first, inside the same transaction. A failure part-way through
    therefore rolls back to the previous data instead of leaving the database
    empty. Files referenced by a ZIP backup are restored into UPLOAD_DIR as rows
    are built; files already written are not removed on rollback.

    Returns:
        The number of imported rows per section.

    Raises:
        HTTPException: 400 ``导入文件结构错误`` unless the payload is an object
            whose sections are lists of objects.
    """
    if not isinstance(payload, dict):
        raise HTTPException(status_code=400, detail="导入文件结构错误")
    resources = payload.get("resources", [])
    mistakes = payload.get("mistakes", [])
    scores = payload.get("scores", [])
    sections = (resources, mistakes, scores)
    if not all(isinstance(section, list) for section in sections):
        raise HTTPException(status_code=400, detail="导入文件结构错误")
    # Validate every row up front, before anything is deleted or written.
    if not all(isinstance(entry, dict) for section in sections for entry in section):
        raise HTTPException(status_code=400, detail="导入文件结构错误")

    imported = {"resources": 0, "mistakes": 0, "scores": 0}
    try:
        if mode == "replace":
            for model in (Resource, Mistake, ScoreRecord):
                for existing in db.scalars(select(model)).all():
                    db.delete(existing)
            # Emit the DELETEs now (still uncommitted) so they run before the INSERTs below.
            db.flush()
        for item in resources:
            url = item.get("url")
            if zip_ref is not None:
                url = _restore_upload_url_from_zip(url, zip_ref)
            # A tuple membership test tolerates unhashable JSON values; a set would raise.
            resource_type = item.get("resource_type")
            db.add(
                Resource(
                    title=item.get("title") or "未命名资源",
                    resource_type=resource_type if resource_type in ("link", "file") else "link",
                    url=url,
                    file_name=item.get("file_name"),
                    category=item.get("category") or "未分类",
                    tags=item.get("tags"),
                    created_at=_safe_datetime(item.get("created_at")),
                )
            )
            imported["resources"] += 1
        for item in mistakes:
            image_url = item.get("image_url")
            if zip_ref is not None:
                image_url = _restore_upload_url_from_zip(image_url, zip_ref)
            difficulty = item.get("difficulty")
            db.add(
                Mistake(
                    title=item.get("title") or "未命名错题",
                    image_url=image_url,
                    category=item.get("category") or "其他",
                    difficulty=difficulty if difficulty in ("easy", "medium", "hard") else None,
                    question_content=item.get("question_content"),
                    answer=item.get("answer"),
                    explanation=item.get("explanation"),
                    note=item.get("note"),
                    wrong_count=max(_coerce_backup_number(item.get("wrong_count") or 1, int, 1), 1),
                    created_at=_safe_datetime(item.get("created_at")),
                )
            )
            imported["mistakes"] += 1
        for item in scores:
            score = _coerce_backup_number(item.get("total_score") or 0, float, 0.0)
            db.add(
                ScoreRecord(
                    exam_name=item.get("exam_name") or "未命名考试",
                    exam_date=_safe_date(item.get("exam_date")),
                    total_score=max(min(score, 200), 0),
                    module_scores=item.get("module_scores"),
                    created_at=_safe_datetime(item.get("created_at")),
                )
            )
            imported["scores"] += 1
        db.commit()
    except Exception:
        db.rollback()
        raise
    return imported


@app.post("/api/data/import")
async def import_user_data(
    file: UploadFile = File(...),
    mode: str = Query("merge", pattern="^(merge|replace)$"),
    db: Session = Depends(get_db),
):
    """Import a backup produced by ``/api/data/export`` (.json or .zip, max 100MB).

    ``merge`` appends the backup's rows; ``replace`` wipes existing rows first,
    atomically with the import. Imported rows always receive new ids. For ZIP
    backups, files under ``uploads/`` are restored and URLs point at the
    restored names.

    Raises:
        HTTPException: 400 for an empty or oversized file, an unsupported
            extension, a corrupt ZIP, a missing or invalid ``data.json``, or a
            malformed structure.
    """
    max_bytes = 100 * 1024 * 1024
    # Read one byte past the limit so oversized files are rejected without buffering them fully.
    content = await file.read(max_bytes + 1)
    if not content:
        raise HTTPException(status_code=400, detail="导入文件为空")
    if len(content) > max_bytes:
        raise HTTPException(status_code=400, detail="导入文件不能超过 100MB")
    suffix = Path(file.filename or "").suffix.lower()
    if suffix == ".json":
        payload = _decode_backup_json(content, "JSON 解析失败")
        imported = _apply_backup_import(payload, None, mode, db)
    elif suffix == ".zip":
        try:
            zip_ref = zipfile.ZipFile(BytesIO(content))
        except zipfile.BadZipFile as exc:
            raise HTTPException(status_code=400, detail="ZIP 文件损坏或格式错误") from exc
        # The archive is closed on every exit path, including validation errors.
        with zip_ref:
            if "data.json" not in zip_ref.namelist():
                raise HTTPException(status_code=400, detail="ZIP 中缺少 data.json")
            payload = _decode_backup_json(zip_ref.read("data.json"), "data.json 解析失败")
            imported = _apply_backup_import(payload, zip_ref, mode, db)
    else:
        raise HTTPException(status_code=400, detail="仅支持 .json 或 .zip 导入")
    return {"success": True, "mode": mode, "imported": imported}
def _ocr_opt_str(val: Any) -> str | None:
    """Return ``val`` as a stripped string, or ``None`` for None, containers or blank values."""
    if val is None or isinstance(val, (dict, list)):
        return None
    stripped = str(val).strip()
    return stripped or None


def _ocr_merge_question_body(text_raw: str, qc_raw: str | None) -> str | None:
    """Pick the fuller of the OCR ``text`` and ``question_content`` fields.

    Models often put the whole question in ``text`` but only the short prompt
    sentence in ``question_content``, so the longer value wins. On equal
    length, ``text`` is kept, since full-page OCR is usually more complete.
    Returns ``None`` when both are blank.
    """
    text_part = (text_raw or "").strip()
    question_part = (qc_raw or "").strip()
    if not text_part and not question_part:
        return None
    return question_part if len(question_part) > len(text_part) else text_part


def _ocr_fallback_data(raw_text: str) -> dict:
    """Wrap a non-JSON (or non-object) model reply: the whole reply becomes text and question."""
    stripped = raw_text.strip()
    return {
        "text": stripped,
        "title_suggestion": None,
        "category_suggestion": None,
        "difficulty_suggestion": None,
        "question_content": stripped,
        "answer": "",
        "explanation": "",
    }


@app.post("/api/ocr/parse", response_model=OcrParseOut)
async def parse_ocr(payload: OcrParseIn):
    """OCR an uploaded question image with the Qwen VL model and return structured fields.

    The image must be a local ``/uploads/`` JPG/PNG/WebP file. The model is asked
    for strict JSON (``payload.prompt`` is appended as an extra requirement).
    A reply that is not a JSON object falls back to using the raw reply as both
    ``text`` and ``question_content``.

    Raises:
        HTTPException: 400 for a non-upload URL or unsupported image type; 404
            when the file is missing; upstream errors from ``_call_qwen_vision``.
    """
    file_name = _extract_upload_filename(payload.image_url)
    if not file_name:
        raise HTTPException(status_code=400, detail="仅支持 /uploads 下的图片做 OCR")
    target = UPLOAD_DIR / file_name
    if not target.exists() or not target.is_file():
        raise HTTPException(status_code=404, detail="图片不存在或已删除")
    mime = {
        ".jpg": "image/jpeg",
        ".jpeg": "image/jpeg",
        ".png": "image/png",
        ".webp": "image/webp",
    }.get(target.suffix.lower())
    if not mime:
        raise HTTPException(status_code=400, detail="仅支持 JPG/PNG/WebP OCR")
    b64 = base64.b64encode(target.read_bytes()).decode("utf-8")
    image_data_url = f"data:{mime};base64,{b64}"
    ocr_prompt = (
        "请识别图片中的题目,返回严格 JSON。"
        "字段说明text 为整题完整纯文本(含材料、提问句、全部选项);"
        "question_content 必须与 text 一致地表示「完整题干」,须包含阅读材料、填空/提问句、所有选项A B C D 等),"
        "禁止只填写「依次填入…」等短提示句而省略材料和选项。"
        "另含 title_suggestion、category_suggestion、difficulty_suggestion、answer、explanation。"
        "无法确认的字段可填空字符串。"
    )
    if payload.prompt:
        ocr_prompt = f"{ocr_prompt}\n补充要求:{payload.prompt}"
    raw_text = await _call_qwen_vision(
        "你是公考题目OCR与结构化助手。输出必须是 JSON不要额外解释。",
        ocr_prompt,
        image_data_url,
    )
    try:
        parsed = json.loads(_extract_json_text(raw_text))
    except json.JSONDecodeError:
        parsed = None
    data = parsed if isinstance(parsed, dict) else _ocr_fallback_data(raw_text)
    text_out = str(data.get("text", "") or "").strip()
    question_merged = _ocr_merge_question_body(text_out, _ocr_opt_str(data.get("question_content")))
    return OcrParseOut(
        text=text_out,
        title_suggestion=_ocr_opt_str(data.get("title_suggestion")),
        category_suggestion=_ocr_opt_str(data.get("category_suggestion")),
        difficulty_suggestion=_ocr_opt_str(data.get("difficulty_suggestion")),
        question_content=_ocr_opt_str(question_merged),
        answer=_ocr_opt_str(data.get("answer")),
        explanation=_ocr_opt_str(data.get("explanation")),
    )
@app.post("/api/ai/mistakes/{item_id}/analyze", response_model=AiMistakeAnalysisOut)
async def ai_analyze_mistake(item_id: int, db: Session = Depends(get_db)):
    """Ask Qwen for a structured analysis of one mistake.

    The prompt includes the mistake's fields and requests root cause, key
    knowledge points, a review method, practice suggestions and a next-day plan.

    Raises:
        HTTPException: 404 when the mistake does not exist; upstream errors
            from ``_call_qwen``.
    """
    mistake = db.get(Mistake, item_id)
    if not mistake:
        raise HTTPException(status_code=404, detail="Mistake not found")
    item = mistake
    user_prompt = (
        f"错题标题: {item.title}\n"
        f"分类: {item.category}\n"
        f"难度: {item.difficulty or '未设置'}\n"
        f"题目内容: {item.question_content or ''}\n"
        f"答案: {item.answer or ''}\n"
        f"解析: {item.explanation or ''}\n"
        f"错误频次: {item.wrong_count}\n"
        f"备注: {item.note or ''}\n\n"
        "请按以下结构输出:\n"
        "1) 错误根因\n2) 关键知识点\n3) 3步复盘法\n4) 3道同类训练建议\n5) 明日复习安排"
    )
    analysis = await _call_qwen("你是公考备考教练,请输出结构化、可执行的错题分析。", user_prompt)
    return AiMistakeAnalysisOut(analysis=analysis)
@app.post("/api/ai/study-plan", response_model=AiStudyPlanOut)
async def ai_study_plan(payload: AiStudyPlanIn, db: Session = Depends(get_db)):
    """Ask Qwen for a study plan based on the goal and recent history.

    The context sent is the scores from the last 30 days (oldest first) and
    the 20 most recently created mistakes, plus the goal, days left and daily hours.
    """
    since = date.today() - timedelta(days=30)
    score_stmt = (
        select(ScoreRecord)
        .where(ScoreRecord.exam_date >= since)
        .order_by(asc(ScoreRecord.exam_date))
    )
    recent_scores = db.scalars(score_stmt).all()
    mistake_stmt = select(Mistake).order_by(desc(Mistake.created_at)).limit(20)
    recent_mistakes = db.scalars(mistake_stmt).all()
    score_text = ", ".join(f"{s.exam_date}:{s.total_score}" for s in recent_scores) or "暂无成绩数据"
    mistake_text = ", ".join(f"{m.category}-{m.title}" for m in recent_mistakes) or "暂无错题数据"
    user_prompt = (
        f"目标: {payload.goal}\n"
        f"剩余天数: {payload.days_left}\n"
        f"每天可学习小时数: {payload.daily_hours}\n"
        f"近30天分数: {score_text}\n"
        f"近期错题: {mistake_text}\n\n"
        "请输出: 周计划表、每日任务模板、错题复盘节奏、模考安排、风险提醒。"
    )
    plan = await _call_qwen("你是公考学习规划师,请给出可执行计划并尽量量化。", user_prompt)
    return AiStudyPlanOut(plan=plan)