fix:优化试题内容和样式排版

This commit is contained in:
Daniel
2026-04-18 20:20:38 +08:00
parent 15e71a9231
commit 7cb9b89cb0
644 changed files with 152784 additions and 621 deletions

View File

@@ -1,5 +1,9 @@
import os
import json
import uuid
import zipfile
import base64
from typing import Any
from datetime import date, datetime, timedelta
from io import BytesIO
from pathlib import Path
@@ -14,7 +18,7 @@ from reportlab.lib.pagesizes import A4
from reportlab.pdfbase import pdfmetrics
from reportlab.pdfbase.cidfonts import UnicodeCIDFont
from reportlab.pdfgen import canvas
from sqlalchemy import asc, desc, func, or_, select
from sqlalchemy import asc, desc, func, inspect, or_, select, text
from sqlalchemy.orm import Session
from .database import Base, engine, get_db
@@ -27,6 +31,8 @@ from .schemas import (
MistakeCreate,
MistakeOut,
MistakeUpdate,
OcrParseIn,
OcrParseOut,
ResourceBatchUpdate,
ResourceCreate,
ResourceOut,
@@ -39,6 +45,23 @@ from .schemas import (
Base.metadata.create_all(bind=engine)
def _migrate_mistake_columns() -> None:
inspector = inspect(engine)
if "mistakes" not in inspector.get_table_names():
return
existed = {col["name"] for col in inspector.get_columns("mistakes")}
with engine.begin() as conn:
if "question_content" not in existed:
conn.execute(text("ALTER TABLE mistakes ADD COLUMN question_content TEXT"))
if "answer" not in existed:
conn.execute(text("ALTER TABLE mistakes ADD COLUMN answer TEXT"))
if "explanation" not in existed:
conn.execute(text("ALTER TABLE mistakes ADD COLUMN explanation TEXT"))
_migrate_mistake_columns()
app = FastAPI(title="公考助手 API", version="1.0.0")
UPLOAD_DIR = Path(os.getenv("UPLOAD_DIR", "/app/uploads"))
UPLOAD_DIR.mkdir(parents=True, exist_ok=True)
@@ -69,6 +92,7 @@ def _query_mistakes_for_export(
category: str | None,
start_date: date | None,
end_date: date | None,
ids: list[int] | None = None,
):
stmt = select(Mistake)
if category:
@@ -77,12 +101,175 @@ def _query_mistakes_for_export(
stmt = stmt.where(Mistake.created_at >= datetime.combine(start_date, datetime.min.time()))
if end_date:
stmt = stmt.where(Mistake.created_at <= datetime.combine(end_date, datetime.max.time()))
if ids:
stmt = stmt.where(Mistake.id.in_(ids))
items = db.scalars(stmt.order_by(desc(Mistake.created_at))).all()
if len(items) > 200:
raise HTTPException(status_code=400, detail="单次最多导出 200 题")
return items
def _validate_mistake_payload(payload: MistakeCreate | MistakeUpdate) -> None:
has_image = bool((payload.image_url or "").strip())
has_question = bool((payload.question_content or "").strip())
has_answer = bool((payload.answer or "").strip())
if not has_image and not has_question and not has_answer:
raise HTTPException(status_code=400, detail="请上传题目图片或填写试题/答案后再保存")
def _normalize_multiline_text(value: str | None) -> str:
if not value:
return ""
text_value = value.replace("\r\n", "\n").replace("\r", "\n")
lines = [line.strip() for line in text_value.split("\n")]
compact = [line for line in lines if line]
return "\n".join(compact).strip()
def _wrap_pdf_text(text: str, max_width: float, font_name: str = "STSong-Light", font_size: int = 12) -> list[str]:
normalized = _normalize_multiline_text(text)
if not normalized:
return []
wrapped: list[str] = []
for raw_line in normalized.split("\n"):
current = ""
for ch in raw_line:
candidate = f"{current}{ch}"
if pdfmetrics.stringWidth(candidate, font_name, font_size) <= max_width:
current = candidate
else:
if current:
wrapped.append(current)
current = ch
if current:
wrapped.append(current)
return wrapped
def _mistake_export_blocks(item: Mistake, content_mode: str) -> list[str]:
question = _normalize_multiline_text(item.question_content)
answer = _normalize_multiline_text(item.answer)
explanation = _normalize_multiline_text(item.explanation)
if not question:
question = "无题干与选项内容"
blocks: list[str] = [question]
if content_mode == "full":
# 「答案:」「解析:」与正文同一行开头,避免标签单独成行(与题号+题干规则一致)
blocks.append(f"答案: {answer or ''}")
blocks.append(f"解析: {explanation or ''}")
return blocks
def _extract_upload_filename(url: str | None) -> str | None:
if not url or not url.startswith("/uploads/"):
return None
return Path(url).name
def _safe_datetime(value: str | None) -> datetime:
if not value:
return datetime.utcnow()
try:
return datetime.fromisoformat(value.replace("Z", "+00:00")).replace(tzinfo=None)
except ValueError:
return datetime.utcnow()
def _safe_date(value: str | None) -> date:
if not value:
return date.today()
try:
return date.fromisoformat(value)
except ValueError:
return date.today()
def _extract_json_text(raw_text: str) -> str:
content = raw_text.strip()
if content.startswith("```"):
lines = content.splitlines()
if lines:
lines = lines[1:]
if lines and lines[-1].strip() == "```":
lines = lines[:-1]
content = "\n".join(lines).strip()
return content
def _dump_all_data(db: Session) -> dict:
resources = db.scalars(select(Resource).order_by(asc(Resource.id))).all()
mistakes = db.scalars(select(Mistake).order_by(asc(Mistake.id))).all()
scores = db.scalars(select(ScoreRecord).order_by(asc(ScoreRecord.id))).all()
return {
"meta": {
"exported_at": datetime.utcnow().isoformat(),
"version": "1.1.0",
},
"resources": [
{
"id": item.id,
"title": item.title,
"resource_type": item.resource_type,
"url": item.url,
"file_name": item.file_name,
"category": item.category,
"tags": item.tags,
"created_at": item.created_at.isoformat() if item.created_at else None,
}
for item in resources
],
"mistakes": [
{
"id": item.id,
"title": item.title,
"image_url": item.image_url,
"category": item.category,
"difficulty": item.difficulty,
"question_content": item.question_content,
"answer": item.answer,
"explanation": item.explanation,
"note": item.note,
"wrong_count": item.wrong_count,
"created_at": item.created_at.isoformat() if item.created_at else None,
}
for item in mistakes
],
"scores": [
{
"id": item.id,
"exam_name": item.exam_name,
"exam_date": item.exam_date.isoformat() if item.exam_date else None,
"total_score": item.total_score,
"module_scores": item.module_scores,
"created_at": item.created_at.isoformat() if item.created_at else None,
}
for item in scores
],
}
def _restore_upload_url_from_zip(url: str | None, zip_ref: zipfile.ZipFile) -> str | None:
if not url:
return None
file_name = _extract_upload_filename(url)
if not file_name:
return url
zip_path = f"uploads/{file_name}"
if zip_path not in zip_ref.namelist():
return url
data = zip_ref.read(zip_path)
target_name = file_name
target_path = UPLOAD_DIR / target_name
if target_path.exists():
target_name = f"{uuid.uuid4().hex}_{file_name}"
target_path = UPLOAD_DIR / target_name
target_path.write_bytes(data)
return f"/uploads/{target_name}"
@app.post("/api/upload")
async def upload_file(file: UploadFile = File(...)):
suffix = Path(file.filename or "").suffix.lower()
@@ -197,7 +384,16 @@ def list_mistakes(
if category:
stmt = stmt.where(Mistake.category == category)
if keyword:
stmt = stmt.where(or_(Mistake.note.ilike(f"%{keyword}%"), Mistake.title.ilike(f"%{keyword}%")))
stmt = stmt.where(
or_(
Mistake.note.ilike(f"%{keyword}%"),
Mistake.title.ilike(f"%{keyword}%"),
Mistake.question_content.ilike(f"%{keyword}%"),
Mistake.answer.ilike(f"%{keyword}%"),
Mistake.explanation.ilike(f"%{keyword}%"),
Mistake.image_url.ilike(f"%{keyword}%"),
)
)
sort_col = Mistake.created_at if sort_by == "created_at" else Mistake.wrong_count
stmt = stmt.order_by(desc(sort_col) if order == "desc" else asc(sort_col))
return db.scalars(stmt).all()
@@ -205,6 +401,7 @@ def list_mistakes(
@app.post("/api/mistakes", response_model=MistakeOut)
def create_mistake(payload: MistakeCreate, db: Session = Depends(get_db)):
_validate_mistake_payload(payload)
item = Mistake(**payload.model_dump())
db.add(item)
db.commit()
@@ -214,6 +411,7 @@ def create_mistake(payload: MistakeCreate, db: Session = Depends(get_db)):
@app.put("/api/mistakes/{item_id}", response_model=MistakeOut)
def update_mistake(item_id: int, payload: MistakeUpdate, db: Session = Depends(get_db)):
_validate_mistake_payload(payload)
item = db.get(Mistake, item_id)
if not item:
raise HTTPException(status_code=404, detail="Mistake not found")
@@ -239,32 +437,44 @@ def export_mistakes_pdf(
category: str | None = None,
start_date: date | None = None,
end_date: date | None = None,
ids: str | None = None,
content_mode: str = Query("full", pattern="^(full|question_only)$"),
db: Session = Depends(get_db),
):
items = _query_mistakes_for_export(db, category, start_date, end_date)
id_list = [int(x) for x in ids.split(",") if x.strip().isdigit()] if ids else None
items = _query_mistakes_for_export(db, category, start_date, end_date, id_list)
buf = BytesIO()
pdf = canvas.Canvas(buf, pagesize=A4)
pdfmetrics.registerFont(UnicodeCIDFont("STSong-Light"))
pdf.setFont("STSong-Light", 12)
y = 800
pdf.drawString(50, y, "公考助手 - 错题导出")
y -= 30
left = 48
right = 560
max_width = right - left
pdf.drawString(left, y, "公考助手 - 错题导出")
y -= 28
for idx, item in enumerate(items, start=1):
lines = [
f"{idx}. {item.title}",
f"分类: {item.category} 难度: {item.difficulty or '未设置'} 错误频次: {item.wrong_count}",
f"备注: {item.note or ''}",
"答题区: _______________________________",
]
for line in lines:
if y < 70:
pdf.showPage()
pdf.setFont("STSong-Light", 12)
y = 800
pdf.drawString(50, y, line[:90])
y -= 22
y -= 6
if y < 90:
pdf.showPage()
pdf.setFont("STSong-Light", 12)
y = 800
blocks = _mistake_export_blocks(item, content_mode)
for bi, block in enumerate(blocks):
# 题号与题干同一行开头避免「1.」单独成行
text = f"{idx}. {block}" if bi == 0 else block
lines = _wrap_pdf_text(text, max_width=max_width)
if not lines:
continue
for line in lines:
if y < 70:
pdf.showPage()
pdf.setFont("STSong-Light", 12)
y = 800
pdf.drawString(left, y, line)
y -= 18
y -= 6
y -= 8
pdf.save()
buf.seek(0)
@@ -280,16 +490,20 @@ def export_mistakes_docx(
category: str | None = None,
start_date: date | None = None,
end_date: date | None = None,
ids: str | None = None,
content_mode: str = Query("full", pattern="^(full|question_only)$"),
db: Session = Depends(get_db),
):
items = _query_mistakes_for_export(db, category, start_date, end_date)
id_list = [int(x) for x in ids.split(",") if x.strip().isdigit()] if ids else None
items = _query_mistakes_for_export(db, category, start_date, end_date, id_list)
doc = Document()
doc.add_heading("公考助手 - 错题导出", level=1)
for idx, item in enumerate(items, start=1):
doc.add_paragraph(f"{idx}. {item.title}")
doc.add_paragraph(f"分类: {item.category} | 难度: {item.difficulty or '未设置'} | 错误频次: {item.wrong_count}")
doc.add_paragraph(f"备注: {item.note or ''}")
doc.add_paragraph("答题区: ________________________________________")
blocks = _mistake_export_blocks(item, content_mode)
for bi, block in enumerate(blocks):
# 题号与题干同段避免单独一行只有「1.」
para = f"{idx}. {block}" if bi == 0 else block
doc.add_paragraph(para)
buf = BytesIO()
doc.save(buf)
@@ -362,8 +576,108 @@ def score_stats(db: Session = Depends(get_db)):
return ScoreStats(highest=highest, lowest=lowest, average=round(float(avg), 2), improvement=improvement)
def _qwen_base_url() -> str:
return os.getenv("QWEN_BASE_URL", "https://dashscope.aliyuncs.com/compatible-mode/v1").strip().rstrip("/")
def _get_qwen_api_key() -> str:
"""去除首尾空白与常见误加的引号,避免 .env 里写成 'sk-xxx' 导致鉴权失败。"""
raw = os.getenv("QWEN_API_KEY", "") or ""
return raw.strip().strip('"').strip("'").strip()
def _raise_for_qwen_http_error(resp: httpx.Response, prefix: str) -> None:
"""HTTP 非 2xx 时解析 DashScope 错误体,对 invalid_api_key 返回 401 + 明确说明。"""
if resp.status_code < 300:
return
text = resp.text
try:
data = resp.json()
err = data.get("error")
if isinstance(err, dict):
code = str(err.get("code") or "")
if code == "invalid_api_key":
raise HTTPException(
status_code=401,
detail=(
"阿里云 DashScope API Key 无效或未生效。"
"请到阿里云百炼 / Model Studio 控制台创建 API Key通常以 sk- 开头),"
"写入项目根目录 .env 的 QWEN_API_KEY=,勿加引号;"
"修改后执行: docker compose up -d --build backend"
),
)
msg = err.get("message") or text
raise HTTPException(status_code=502, detail=f"{prefix}: {msg}")
except HTTPException:
raise
except (ValueError, TypeError, KeyError):
pass
raise HTTPException(status_code=502, detail=f"{prefix}: {text[:1200]}")
def _httpx_trust_env() -> bool:
"""默认不信任环境变量中的代理,避免 Docker/IDE 注入空代理导致 ConnectError需走系统代理时设 HTTPX_TRUST_ENV=1。"""
return os.getenv("HTTPX_TRUST_ENV", "0").lower() in ("1", "true", "yes")
def _qwen_http_client(timeout_sec: float = 60.0) -> httpx.AsyncClient:
return httpx.AsyncClient(
timeout=httpx.Timeout(timeout_sec, connect=20.0),
trust_env=_httpx_trust_env(),
limits=httpx.Limits(max_keepalive_connections=5, max_connections=10),
)
def _message_content_to_str(content: Any) -> str:
"""OpenAI 兼容接口里 message.content 可能是 str 或多段结构。"""
if content is None:
return ""
if isinstance(content, str):
return content
if isinstance(content, list):
parts: list[str] = []
for part in content:
if isinstance(part, dict):
if part.get("type") == "text" and "text" in part:
parts.append(str(part["text"]))
elif "text" in part:
parts.append(str(part["text"]))
elif isinstance(part, str):
parts.append(part)
return "".join(parts)
return str(content)
def _openai_completion_assistant_text(data: dict) -> str:
"""从 chat/completions JSON 中取出助手文本;若含 error 或无 choices 则抛错。"""
err = data.get("error")
if err is not None:
if isinstance(err, dict):
code = str(err.get("code") or "")
if code == "invalid_api_key":
raise HTTPException(
status_code=401,
detail=(
"阿里云 DashScope API Key 无效。"
"请在 .env 中填写正确的 QWEN_API_KEY 并重启 backend。"
),
)
msg = err.get("message") or err.get("code") or json.dumps(err, ensure_ascii=False)
else:
msg = str(err)
raise HTTPException(status_code=502, detail=f"千问接口错误: {msg}")
choices = data.get("choices")
if not choices:
raise HTTPException(
status_code=502,
detail=f"千问返回异常(无 choices请检查模型名与权限。原始片段: {json.dumps(data, ensure_ascii=False)[:800]}",
)
msg = choices[0].get("message") or {}
return _message_content_to_str(msg.get("content"))
async def _call_qwen(system_prompt: str, user_prompt: str) -> str:
api_key = os.getenv("QWEN_API_KEY", "")
api_key = _get_qwen_api_key()
if not api_key:
return (
"当前未配置千问 API Key已返回本地降级提示。\n"
@@ -373,7 +687,7 @@ async def _call_qwen(system_prompt: str, user_prompt: str) -> str:
"QWEN_BASE_URL=https://dashscope.aliyuncs.com/compatible-mode/v1\n"
"QWEN_MODEL=qwen-plus"
)
base_url = os.getenv("QWEN_BASE_URL", "https://dashscope.aliyuncs.com/compatible-mode/v1")
base_url = _qwen_base_url()
model = os.getenv("QWEN_MODEL", "qwen-plus")
headers = {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"}
payload = {
@@ -384,12 +698,333 @@ async def _call_qwen(system_prompt: str, user_prompt: str) -> str:
],
"temperature": 0.4,
}
async with httpx.AsyncClient(timeout=40) as client:
resp = await client.post(f"{base_url}/chat/completions", headers=headers, json=payload)
if resp.status_code >= 300:
raise HTTPException(status_code=502, detail=f"千问请求失败: {resp.text}")
data = resp.json()
return data["choices"][0]["message"]["content"]
url = f"{base_url}/chat/completions"
try:
async with _qwen_http_client(40.0) as client:
resp = await client.post(url, headers=headers, json=payload)
except httpx.ConnectError as e:
raise HTTPException(
status_code=502,
detail=(
f"无法连接千问接口({url})。请检查本机/容器能否访问外网、DNS 是否正常;"
"若在 Docker 中可尝试为 backend 配置 dns 或关闭错误代理。"
"默认已忽略 HTTP(S)_PROXY若需代理请设置 HTTPX_TRUST_ENV=1。"
f" 原始错误: {e!s}"
),
) from e
except httpx.TimeoutException as e:
raise HTTPException(status_code=504, detail=f"千问请求超时: {e!s}") from e
_raise_for_qwen_http_error(resp, "千问请求失败")
try:
data = resp.json()
except ValueError:
raise HTTPException(status_code=502, detail=f"千问返回非 JSON: {resp.text[:600]}")
return _openai_completion_assistant_text(data)
async def _call_qwen_vision(system_prompt: str, user_prompt: str, image_data_url: str) -> str:
api_key = _get_qwen_api_key()
if not api_key:
return (
"当前未配置千问 API Key无法执行 OCR。\n"
"请在 .env 中配置 QWEN_API_KEY 后重试。"
)
base_url = _qwen_base_url()
model = os.getenv("QWEN_VL_MODEL", "qwen-vl-plus")
headers = {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"}
# 与 DashScope 文档一致:先图后文,利于多模态路由
payload = {
"model": model,
"messages": [
{"role": "system", "content": system_prompt},
{
"role": "user",
"content": [
{"type": "image_url", "image_url": {"url": image_data_url}},
{"type": "text", "text": user_prompt},
],
},
],
"temperature": 0.2,
}
url = f"{base_url}/chat/completions"
try:
async with _qwen_http_client(60.0) as client:
resp = await client.post(url, headers=headers, json=payload)
except httpx.ConnectError as e:
raise HTTPException(
status_code=502,
detail=(
f"无法连接千问接口OCR{url})。请检查网络与 DNS"
"默认已忽略 HTTP(S)_PROXY若需代理请设置 HTTPX_TRUST_ENV=1。"
f" 原始错误: {e!s}"
),
) from e
except httpx.TimeoutException as e:
raise HTTPException(status_code=504, detail=f"OCR 请求超时: {e!s}") from e
except httpx.RequestError as e:
raise HTTPException(status_code=502, detail=f"OCR 网络请求失败: {e!s}") from e
_raise_for_qwen_http_error(resp, "OCR 请求失败")
try:
data = resp.json()
except ValueError:
raise HTTPException(status_code=502, detail=f"OCR 返回非 JSON: {resp.text[:600]}")
return _openai_completion_assistant_text(data)
@app.get("/api/data/export")
def export_user_data(
format: str = Query("zip", pattern="^(zip|json)$"),
include_files: bool = True,
db: Session = Depends(get_db),
):
payload = _dump_all_data(db)
timestamp = datetime.utcnow().strftime("%Y%m%d_%H%M%S")
if format == "json":
buf = BytesIO(json.dumps(payload, ensure_ascii=False, indent=2).encode("utf-8"))
return StreamingResponse(
buf,
media_type="application/json",
headers={"Content-Disposition": f'attachment; filename="exam_helper_backup_{timestamp}.json"'},
)
used_upload_files: set[str] = set()
for item in payload["resources"]:
name = _extract_upload_filename(item.get("url"))
if name:
used_upload_files.add(name)
for item in payload["mistakes"]:
name = _extract_upload_filename(item.get("image_url"))
if name:
used_upload_files.add(name)
buf = BytesIO()
with zipfile.ZipFile(buf, "w", zipfile.ZIP_DEFLATED) as zip_ref:
zip_ref.writestr("data.json", json.dumps(payload, ensure_ascii=False, indent=2))
if include_files:
for file_name in sorted(used_upload_files):
path = UPLOAD_DIR / file_name
if path.exists() and path.is_file():
zip_ref.write(path, arcname=f"uploads/{file_name}")
buf.seek(0)
return StreamingResponse(
buf,
media_type="application/zip",
headers={"Content-Disposition": f'attachment; filename="exam_helper_backup_{timestamp}.zip"'},
)
@app.post("/api/data/import")
async def import_user_data(
file: UploadFile = File(...),
mode: str = Query("merge", pattern="^(merge|replace)$"),
db: Session = Depends(get_db),
):
content = await file.read()
if not content:
raise HTTPException(status_code=400, detail="导入文件为空")
if len(content) > 100 * 1024 * 1024:
raise HTTPException(status_code=400, detail="导入文件不能超过 100MB")
suffix = Path(file.filename or "").suffix.lower()
payload: dict
zip_ref: zipfile.ZipFile | None = None
if suffix == ".json":
try:
payload = json.loads(content.decode("utf-8"))
except (UnicodeDecodeError, json.JSONDecodeError) as exc:
raise HTTPException(status_code=400, detail=f"JSON 解析失败: {exc}") from exc
elif suffix == ".zip":
try:
zip_ref = zipfile.ZipFile(BytesIO(content))
except zipfile.BadZipFile as exc:
raise HTTPException(status_code=400, detail="ZIP 文件损坏或格式错误") from exc
if "data.json" not in zip_ref.namelist():
raise HTTPException(status_code=400, detail="ZIP 中缺少 data.json")
try:
payload = json.loads(zip_ref.read("data.json").decode("utf-8"))
except (UnicodeDecodeError, json.JSONDecodeError) as exc:
raise HTTPException(status_code=400, detail=f"data.json 解析失败: {exc}") from exc
else:
raise HTTPException(status_code=400, detail="仅支持 .json 或 .zip 导入")
resources = payload.get("resources", [])
mistakes = payload.get("mistakes", [])
scores = payload.get("scores", [])
if not isinstance(resources, list) or not isinstance(mistakes, list) or not isinstance(scores, list):
raise HTTPException(status_code=400, detail="导入文件结构错误")
if mode == "replace":
for item in db.scalars(select(Resource)).all():
db.delete(item)
for item in db.scalars(select(Mistake)).all():
db.delete(item)
for item in db.scalars(select(ScoreRecord)).all():
db.delete(item)
db.commit()
imported = {"resources": 0, "mistakes": 0, "scores": 0}
for item in resources:
url = item.get("url")
if zip_ref is not None:
url = _restore_upload_url_from_zip(url, zip_ref)
obj = Resource(
title=item.get("title") or "未命名资源",
resource_type=item.get("resource_type") if item.get("resource_type") in {"link", "file"} else "link",
url=url,
file_name=item.get("file_name"),
category=item.get("category") or "未分类",
tags=item.get("tags"),
created_at=_safe_datetime(item.get("created_at")),
)
db.add(obj)
imported["resources"] += 1
for item in mistakes:
image_url = item.get("image_url")
if zip_ref is not None:
image_url = _restore_upload_url_from_zip(image_url, zip_ref)
difficulty = item.get("difficulty")
obj = Mistake(
title=item.get("title") or "未命名错题",
image_url=image_url,
category=item.get("category") or "其他",
difficulty=difficulty if difficulty in {"easy", "medium", "hard"} else None,
question_content=item.get("question_content"),
answer=item.get("answer"),
explanation=item.get("explanation"),
note=item.get("note"),
wrong_count=max(int(item.get("wrong_count") or 1), 1),
created_at=_safe_datetime(item.get("created_at")),
)
db.add(obj)
imported["mistakes"] += 1
for item in scores:
score = float(item.get("total_score") or 0)
obj = ScoreRecord(
exam_name=item.get("exam_name") or "未命名考试",
exam_date=_safe_date(item.get("exam_date")),
total_score=max(min(score, 200), 0),
module_scores=item.get("module_scores"),
created_at=_safe_datetime(item.get("created_at")),
)
db.add(obj)
imported["scores"] += 1
db.commit()
return {"success": True, "mode": mode, "imported": imported}
@app.post("/api/ocr/parse", response_model=OcrParseOut)
async def parse_ocr(payload: OcrParseIn):
file_name = _extract_upload_filename(payload.image_url)
if not file_name:
raise HTTPException(status_code=400, detail="仅支持 /uploads 下的图片做 OCR")
target = UPLOAD_DIR / file_name
if not target.exists() or not target.is_file():
raise HTTPException(status_code=404, detail="图片不存在或已删除")
suffix = target.suffix.lower()
mime = {
".jpg": "image/jpeg",
".jpeg": "image/jpeg",
".png": "image/png",
".webp": "image/webp",
}.get(suffix)
if not mime:
raise HTTPException(status_code=400, detail="仅支持 JPG/PNG/WebP OCR")
b64 = base64.b64encode(target.read_bytes()).decode("utf-8")
image_data_url = f"data:{mime};base64,{b64}"
ocr_prompt = (
"请识别图片中的题目,返回严格 JSON。"
"字段说明text 为整题完整纯文本(含材料、提问句、全部选项);"
"question_content 必须与 text 一致地表示「完整题干」,须包含阅读材料、填空/提问句、所有选项A B C D 等),"
"禁止只填写「依次填入…」等短提示句而省略材料和选项。"
"另含 title_suggestion、category_suggestion、difficulty_suggestion、answer、explanation。"
"无法确认的字段可填空字符串。"
)
if payload.prompt:
ocr_prompt = f"{ocr_prompt}\n补充要求:{payload.prompt}"
raw_text = await _call_qwen_vision(
"你是公考题目OCR与结构化助手。输出必须是 JSON不要额外解释。",
ocr_prompt,
image_data_url,
)
try:
parsed = json.loads(_extract_json_text(raw_text))
data = (
parsed
if isinstance(parsed, dict)
else {
"text": raw_text.strip(),
"title_suggestion": None,
"category_suggestion": None,
"difficulty_suggestion": None,
"question_content": raw_text.strip(),
"answer": "",
"explanation": "",
}
)
except json.JSONDecodeError:
data = {
"text": raw_text.strip(),
"title_suggestion": None,
"category_suggestion": None,
"difficulty_suggestion": None,
"question_content": raw_text.strip(),
"answer": "",
"explanation": "",
}
def _opt_str(val: Any) -> str | None:
if val is None:
return None
if isinstance(val, (dict, list)):
return None
s = str(val).strip()
return s if s else None
def _merge_question_body(text_raw: str, qc_raw: str | None) -> str | None:
"""模型常把全文放在 text却只把短问句放在 question_content合并时以更长、更完整的文本为准。"""
t = (text_raw or "").strip()
q = (qc_raw or "").strip()
if not t and not q:
return None
if not q:
return t or None
if not t:
return q or None
if len(t) > len(q):
return t
if len(q) > len(t):
return q
# 长度接近或相等:若一方包含另一方,取更长;否则保留 text整页 OCR 通常更全)
if t in q:
return q
if q in t:
return t
return t
text_out = str(data.get("text", "") or "").strip()
qc_model = _opt_str(data.get("question_content"))
question_merged = _merge_question_body(text_out, qc_model)
return OcrParseOut(
text=text_out,
title_suggestion=_opt_str(data.get("title_suggestion")),
category_suggestion=_opt_str(data.get("category_suggestion")),
difficulty_suggestion=_opt_str(data.get("difficulty_suggestion")),
question_content=_opt_str(question_merged),
answer=_opt_str(data.get("answer")),
explanation=_opt_str(data.get("explanation")),
)
@app.post("/api/ai/mistakes/{item_id}/analyze", response_model=AiMistakeAnalysisOut)
@@ -403,6 +1038,9 @@ async def ai_analyze_mistake(item_id: int, db: Session = Depends(get_db)):
f"错题标题: {item.title}\n"
f"分类: {item.category}\n"
f"难度: {item.difficulty or '未设置'}\n"
f"题目内容: {item.question_content or ''}\n"
f"答案: {item.answer or ''}\n"
f"解析: {item.explanation or ''}\n"
f"错误频次: {item.wrong_count}\n"
f"备注: {item.note or ''}\n\n"
"请按以下结构输出:\n"

View File

@@ -27,6 +27,9 @@ class Mistake(Base):
image_url: Mapped[str | None] = mapped_column(String(1024), nullable=True)
category: Mapped[str] = mapped_column(String(50), nullable=False)
difficulty: Mapped[str | None] = mapped_column(String(20), nullable=True) # easy/medium/hard
question_content: Mapped[str | None] = mapped_column(Text, nullable=True)
answer: Mapped[str | None] = mapped_column(Text, nullable=True)
explanation: Mapped[str | None] = mapped_column(Text, nullable=True)
note: Mapped[str | None] = mapped_column(Text, nullable=True)
wrong_count: Mapped[int] = mapped_column(Integer, default=1)
created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.utcnow)

View File

@@ -31,7 +31,10 @@ class MistakeBase(BaseModel):
image_url: str | None = None
category: str
difficulty: str | None = Field(None, pattern="^(easy|medium|hard)$")
note: str | None = Field(None, max_length=500)
question_content: str | None = Field(None, max_length=8000)
answer: str | None = Field(None, max_length=4000)
explanation: str | None = Field(None, max_length=8000)
note: str | None = Field(None, max_length=4000)
wrong_count: int = Field(1, ge=1)
@@ -99,3 +102,18 @@ class AiStudyPlanIn(BaseModel):
class AiStudyPlanOut(BaseModel):
plan: str
class OcrParseIn(BaseModel):
image_url: str = Field(..., max_length=1024)
prompt: str | None = Field(None, max_length=500)
class OcrParseOut(BaseModel):
text: str
title_suggestion: str | None = None
category_suggestion: str | None = None
difficulty_suggestion: str | None = None
question_content: str | None = None
answer: str | None = None
explanation: str | None = None