fix:优化试题内容和样式排版
This commit is contained in:
@@ -1,5 +1,9 @@
|
||||
import os
|
||||
import json
|
||||
import uuid
|
||||
import zipfile
|
||||
import base64
|
||||
from typing import Any
|
||||
from datetime import date, datetime, timedelta
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
@@ -14,7 +18,7 @@ from reportlab.lib.pagesizes import A4
|
||||
from reportlab.pdfbase import pdfmetrics
|
||||
from reportlab.pdfbase.cidfonts import UnicodeCIDFont
|
||||
from reportlab.pdfgen import canvas
|
||||
from sqlalchemy import asc, desc, func, or_, select
|
||||
from sqlalchemy import asc, desc, func, inspect, or_, select, text
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from .database import Base, engine, get_db
|
||||
@@ -27,6 +31,8 @@ from .schemas import (
|
||||
MistakeCreate,
|
||||
MistakeOut,
|
||||
MistakeUpdate,
|
||||
OcrParseIn,
|
||||
OcrParseOut,
|
||||
ResourceBatchUpdate,
|
||||
ResourceCreate,
|
||||
ResourceOut,
|
||||
@@ -39,6 +45,23 @@ from .schemas import (
|
||||
|
||||
Base.metadata.create_all(bind=engine)
|
||||
|
||||
|
||||
def _migrate_mistake_columns() -> None:
    """Ad-hoc startup migration: add text columns to ``mistakes`` if missing.

    Idempotent — inspects the live table and only issues ``ALTER TABLE`` for
    columns that do not exist yet, so it is safe to run on every boot.
    """
    inspector = inspect(engine)
    # Nothing to migrate before create_all has ever created the table.
    if "mistakes" not in inspector.get_table_names():
        return
    existed = {col["name"] for col in inspector.get_columns("mistakes")}
    with engine.begin() as conn:
        # All three columns share the same TEXT type, so drive the ALTERs
        # from one list instead of three copy-pasted if-blocks.
        for column in ("question_content", "answer", "explanation"):
            if column not in existed:
                conn.execute(text(f"ALTER TABLE mistakes ADD COLUMN {column} TEXT"))
|
||||
|
||||
|
||||
# Apply the lightweight column migration before the app starts serving.
_migrate_mistake_columns()


app = FastAPI(title="公考助手 API", version="1.0.0")

# Uploaded images/files live here; created eagerly so the first upload
# cannot fail on a missing directory.
UPLOAD_DIR = Path(os.getenv("UPLOAD_DIR", "/app/uploads"))
UPLOAD_DIR.mkdir(parents=True, exist_ok=True)
|
||||
@@ -69,6 +92,7 @@ def _query_mistakes_for_export(
|
||||
category: str | None,
|
||||
start_date: date | None,
|
||||
end_date: date | None,
|
||||
ids: list[int] | None = None,
|
||||
):
|
||||
stmt = select(Mistake)
|
||||
if category:
|
||||
@@ -77,12 +101,175 @@ def _query_mistakes_for_export(
|
||||
stmt = stmt.where(Mistake.created_at >= datetime.combine(start_date, datetime.min.time()))
|
||||
if end_date:
|
||||
stmt = stmt.where(Mistake.created_at <= datetime.combine(end_date, datetime.max.time()))
|
||||
if ids:
|
||||
stmt = stmt.where(Mistake.id.in_(ids))
|
||||
items = db.scalars(stmt.order_by(desc(Mistake.created_at))).all()
|
||||
if len(items) > 200:
|
||||
raise HTTPException(status_code=400, detail="单次最多导出 200 题")
|
||||
return items
|
||||
|
||||
|
||||
def _validate_mistake_payload(payload: MistakeCreate | MistakeUpdate) -> None:
    """Reject a mistake payload carrying neither an image nor any text.

    Raises HTTP 400 unless at least one of image_url / question_content /
    answer is non-blank after stripping whitespace.
    """
    candidates = (payload.image_url, payload.question_content, payload.answer)
    if not any((field or "").strip() for field in candidates):
        raise HTTPException(status_code=400, detail="请上传题目图片或填写试题/答案后再保存")
|
||||
|
||||
|
||||
def _normalize_multiline_text(value: str | None) -> str:
|
||||
if not value:
|
||||
return ""
|
||||
text_value = value.replace("\r\n", "\n").replace("\r", "\n")
|
||||
lines = [line.strip() for line in text_value.split("\n")]
|
||||
compact = [line for line in lines if line]
|
||||
return "\n".join(compact).strip()
|
||||
|
||||
|
||||
def _wrap_pdf_text(text: str, max_width: float, font_name: str = "STSong-Light", font_size: int = 12) -> list[str]:
    """Greedily wrap *text* into lines whose rendered width fits *max_width*.

    Wrapping is character-by-character (CJK text has no word boundaries),
    measured with reportlab's font metrics. Returns [] for blank input.
    """
    cleaned = _normalize_multiline_text(text)
    if not cleaned:
        return []
    out: list[str] = []
    for paragraph in cleaned.split("\n"):
        line = ""
        for char in paragraph:
            grown = line + char
            if pdfmetrics.stringWidth(grown, font_name, font_size) > max_width:
                # Overflow: flush what we have and start a fresh line.
                if line:
                    out.append(line)
                line = char
            else:
                line = grown
        if line:
            out.append(line)
    return out
|
||||
|
||||
|
||||
def _mistake_export_blocks(item: Mistake, content_mode: str) -> list[str]:
    """Build the exportable text blocks for one mistake.

    Always starts with the question body (placeholder when empty); in "full"
    mode the answer and explanation follow as labelled blocks.
    """
    question = _normalize_multiline_text(item.question_content) or "无题干与选项内容"
    blocks = [question]
    if content_mode == "full":
        # Keep the 「答案:」/「解析:」 label on the same line as its text so a
        # label never ends up alone on a line (mirrors the numbering rule).
        answer = _normalize_multiline_text(item.answer)
        explanation = _normalize_multiline_text(item.explanation)
        blocks.extend([f"答案: {answer or '无'}", f"解析: {explanation or '无'}"])
    return blocks
|
||||
|
||||
|
||||
def _extract_upload_filename(url: str | None) -> str | None:
|
||||
if not url or not url.startswith("/uploads/"):
|
||||
return None
|
||||
return Path(url).name
|
||||
|
||||
|
||||
def _safe_datetime(value: str | None) -> datetime:
|
||||
if not value:
|
||||
return datetime.utcnow()
|
||||
try:
|
||||
return datetime.fromisoformat(value.replace("Z", "+00:00")).replace(tzinfo=None)
|
||||
except ValueError:
|
||||
return datetime.utcnow()
|
||||
|
||||
|
||||
def _safe_date(value: str | None) -> date:
|
||||
if not value:
|
||||
return date.today()
|
||||
try:
|
||||
return date.fromisoformat(value)
|
||||
except ValueError:
|
||||
return date.today()
|
||||
|
||||
|
||||
def _extract_json_text(raw_text: str) -> str:
|
||||
content = raw_text.strip()
|
||||
if content.startswith("```"):
|
||||
lines = content.splitlines()
|
||||
if lines:
|
||||
lines = lines[1:]
|
||||
if lines and lines[-1].strip() == "```":
|
||||
lines = lines[:-1]
|
||||
content = "\n".join(lines).strip()
|
||||
return content
|
||||
|
||||
|
||||
def _dump_all_data(db: Session) -> dict:
    """Serialize all resources, mistakes and score records into one dict.

    Rows are ordered by ascending id so exports are deterministic; datetime
    and date columns become ISO strings (None when unset).
    """
    resources = db.scalars(select(Resource).order_by(asc(Resource.id))).all()
    mistakes = db.scalars(select(Mistake).order_by(asc(Mistake.id))).all()
    scores = db.scalars(select(ScoreRecord).order_by(asc(ScoreRecord.id))).all()
    return {
        "meta": {
            "exported_at": datetime.utcnow().isoformat(),
            # Backup format version — bump when the payload shape changes.
            "version": "1.1.0",
        },
        "resources": [
            {
                "id": item.id,
                "title": item.title,
                "resource_type": item.resource_type,
                "url": item.url,
                "file_name": item.file_name,
                "category": item.category,
                "tags": item.tags,
                "created_at": item.created_at.isoformat() if item.created_at else None,
            }
            for item in resources
        ],
        "mistakes": [
            {
                "id": item.id,
                "title": item.title,
                "image_url": item.image_url,
                "category": item.category,
                "difficulty": item.difficulty,
                "question_content": item.question_content,
                "answer": item.answer,
                "explanation": item.explanation,
                "note": item.note,
                "wrong_count": item.wrong_count,
                "created_at": item.created_at.isoformat() if item.created_at else None,
            }
            for item in mistakes
        ],
        "scores": [
            {
                "id": item.id,
                "exam_name": item.exam_name,
                "exam_date": item.exam_date.isoformat() if item.exam_date else None,
                "total_score": item.total_score,
                "module_scores": item.module_scores,
                "created_at": item.created_at.isoformat() if item.created_at else None,
            }
            for item in scores
        ],
    }
|
||||
|
||||
|
||||
def _restore_upload_url_from_zip(url: str | None, zip_ref: zipfile.ZipFile) -> str | None:
    """Extract a backed-up upload out of *zip_ref* and return its new URL.

    External URLs, or entries missing from the archive, are returned
    unchanged. An existing file on disk is never overwritten — a
    uuid-prefixed name is used instead.
    """
    if not url:
        return None
    file_name = _extract_upload_filename(url)
    if file_name is None:
        # Not a local /uploads/ URL — nothing to restore.
        return url

    member = f"uploads/{file_name}"
    if member not in zip_ref.namelist():
        return url

    destination = UPLOAD_DIR / file_name
    if destination.exists():
        # Avoid clobbering an unrelated file that happens to share the name.
        destination = UPLOAD_DIR / f"{uuid.uuid4().hex}_{file_name}"
    destination.write_bytes(zip_ref.read(member))
    return f"/uploads/{destination.name}"
|
||||
|
||||
|
||||
@app.post("/api/upload")
|
||||
async def upload_file(file: UploadFile = File(...)):
|
||||
suffix = Path(file.filename or "").suffix.lower()
|
||||
@@ -197,7 +384,16 @@ def list_mistakes(
|
||||
if category:
|
||||
stmt = stmt.where(Mistake.category == category)
|
||||
if keyword:
|
||||
stmt = stmt.where(or_(Mistake.note.ilike(f"%{keyword}%"), Mistake.title.ilike(f"%{keyword}%")))
|
||||
stmt = stmt.where(
|
||||
or_(
|
||||
Mistake.note.ilike(f"%{keyword}%"),
|
||||
Mistake.title.ilike(f"%{keyword}%"),
|
||||
Mistake.question_content.ilike(f"%{keyword}%"),
|
||||
Mistake.answer.ilike(f"%{keyword}%"),
|
||||
Mistake.explanation.ilike(f"%{keyword}%"),
|
||||
Mistake.image_url.ilike(f"%{keyword}%"),
|
||||
)
|
||||
)
|
||||
sort_col = Mistake.created_at if sort_by == "created_at" else Mistake.wrong_count
|
||||
stmt = stmt.order_by(desc(sort_col) if order == "desc" else asc(sort_col))
|
||||
return db.scalars(stmt).all()
|
||||
@@ -205,6 +401,7 @@ def list_mistakes(
|
||||
|
||||
@app.post("/api/mistakes", response_model=MistakeOut)
|
||||
def create_mistake(payload: MistakeCreate, db: Session = Depends(get_db)):
|
||||
_validate_mistake_payload(payload)
|
||||
item = Mistake(**payload.model_dump())
|
||||
db.add(item)
|
||||
db.commit()
|
||||
@@ -214,6 +411,7 @@ def create_mistake(payload: MistakeCreate, db: Session = Depends(get_db)):
|
||||
|
||||
@app.put("/api/mistakes/{item_id}", response_model=MistakeOut)
|
||||
def update_mistake(item_id: int, payload: MistakeUpdate, db: Session = Depends(get_db)):
|
||||
_validate_mistake_payload(payload)
|
||||
item = db.get(Mistake, item_id)
|
||||
if not item:
|
||||
raise HTTPException(status_code=404, detail="Mistake not found")
|
||||
@@ -239,32 +437,44 @@ def export_mistakes_pdf(
|
||||
category: str | None = None,
|
||||
start_date: date | None = None,
|
||||
end_date: date | None = None,
|
||||
ids: str | None = None,
|
||||
content_mode: str = Query("full", pattern="^(full|question_only)$"),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
items = _query_mistakes_for_export(db, category, start_date, end_date)
|
||||
id_list = [int(x) for x in ids.split(",") if x.strip().isdigit()] if ids else None
|
||||
items = _query_mistakes_for_export(db, category, start_date, end_date, id_list)
|
||||
buf = BytesIO()
|
||||
pdf = canvas.Canvas(buf, pagesize=A4)
|
||||
pdfmetrics.registerFont(UnicodeCIDFont("STSong-Light"))
|
||||
pdf.setFont("STSong-Light", 12)
|
||||
|
||||
y = 800
|
||||
pdf.drawString(50, y, "公考助手 - 错题导出")
|
||||
y -= 30
|
||||
left = 48
|
||||
right = 560
|
||||
max_width = right - left
|
||||
pdf.drawString(left, y, "公考助手 - 错题导出")
|
||||
y -= 28
|
||||
for idx, item in enumerate(items, start=1):
|
||||
lines = [
|
||||
f"{idx}. {item.title}",
|
||||
f"分类: {item.category} 难度: {item.difficulty or '未设置'} 错误频次: {item.wrong_count}",
|
||||
f"备注: {item.note or '无'}",
|
||||
"答题区: _______________________________",
|
||||
]
|
||||
for line in lines:
|
||||
if y < 70:
|
||||
pdf.showPage()
|
||||
pdf.setFont("STSong-Light", 12)
|
||||
y = 800
|
||||
pdf.drawString(50, y, line[:90])
|
||||
y -= 22
|
||||
y -= 6
|
||||
if y < 90:
|
||||
pdf.showPage()
|
||||
pdf.setFont("STSong-Light", 12)
|
||||
y = 800
|
||||
blocks = _mistake_export_blocks(item, content_mode)
|
||||
for bi, block in enumerate(blocks):
|
||||
# 题号与题干同一行开头,避免「1.」单独成行
|
||||
text = f"{idx}. {block}" if bi == 0 else block
|
||||
lines = _wrap_pdf_text(text, max_width=max_width)
|
||||
if not lines:
|
||||
continue
|
||||
for line in lines:
|
||||
if y < 70:
|
||||
pdf.showPage()
|
||||
pdf.setFont("STSong-Light", 12)
|
||||
y = 800
|
||||
pdf.drawString(left, y, line)
|
||||
y -= 18
|
||||
y -= 6
|
||||
y -= 8
|
||||
|
||||
pdf.save()
|
||||
buf.seek(0)
|
||||
@@ -280,16 +490,20 @@ def export_mistakes_docx(
|
||||
category: str | None = None,
|
||||
start_date: date | None = None,
|
||||
end_date: date | None = None,
|
||||
ids: str | None = None,
|
||||
content_mode: str = Query("full", pattern="^(full|question_only)$"),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
items = _query_mistakes_for_export(db, category, start_date, end_date)
|
||||
id_list = [int(x) for x in ids.split(",") if x.strip().isdigit()] if ids else None
|
||||
items = _query_mistakes_for_export(db, category, start_date, end_date, id_list)
|
||||
doc = Document()
|
||||
doc.add_heading("公考助手 - 错题导出", level=1)
|
||||
for idx, item in enumerate(items, start=1):
|
||||
doc.add_paragraph(f"{idx}. {item.title}")
|
||||
doc.add_paragraph(f"分类: {item.category} | 难度: {item.difficulty or '未设置'} | 错误频次: {item.wrong_count}")
|
||||
doc.add_paragraph(f"备注: {item.note or '无'}")
|
||||
doc.add_paragraph("答题区: ________________________________________")
|
||||
blocks = _mistake_export_blocks(item, content_mode)
|
||||
for bi, block in enumerate(blocks):
|
||||
# 题号与题干同段,避免单独一行只有「1.」
|
||||
para = f"{idx}. {block}" if bi == 0 else block
|
||||
doc.add_paragraph(para)
|
||||
|
||||
buf = BytesIO()
|
||||
doc.save(buf)
|
||||
@@ -362,8 +576,108 @@ def score_stats(db: Session = Depends(get_db)):
|
||||
return ScoreStats(highest=highest, lowest=lowest, average=round(float(avg), 2), improvement=improvement)
|
||||
|
||||
|
||||
def _qwen_base_url() -> str:
|
||||
return os.getenv("QWEN_BASE_URL", "https://dashscope.aliyuncs.com/compatible-mode/v1").strip().rstrip("/")
|
||||
|
||||
|
||||
def _get_qwen_api_key() -> str:
|
||||
"""去除首尾空白与常见误加的引号,避免 .env 里写成 'sk-xxx' 导致鉴权失败。"""
|
||||
raw = os.getenv("QWEN_API_KEY", "") or ""
|
||||
return raw.strip().strip('"').strip("'").strip()
|
||||
|
||||
|
||||
def _raise_for_qwen_http_error(resp: httpx.Response, prefix: str) -> None:
    """Translate a non-2xx DashScope response into an HTTPException.

    ``invalid_api_key`` gets a dedicated 401 with remediation steps; any other
    recognisable error body becomes a 502 carrying the upstream message;
    unparseable bodies become a 502 with a truncated raw-text snippet.
    """
    if resp.status_code < 300:
        return
    text = resp.text
    try:
        data = resp.json()
        err = data.get("error")
        if isinstance(err, dict):
            code = str(err.get("code") or "")
            if code == "invalid_api_key":
                raise HTTPException(
                    status_code=401,
                    detail=(
                        "阿里云 DashScope API Key 无效或未生效。"
                        "请到阿里云百炼 / Model Studio 控制台创建 API Key(通常以 sk- 开头),"
                        "写入项目根目录 .env 的 QWEN_API_KEY=,勿加引号;"
                        "修改后执行: docker compose up -d --build backend"
                    ),
                )
            msg = err.get("message") or text
            raise HTTPException(status_code=502, detail=f"{prefix}: {msg}")
    except HTTPException:
        # Re-raise our own structured errors untouched.
        raise
    except (ValueError, TypeError, KeyError):
        # Body was not JSON / not the expected shape — fall through to raw text.
        pass
    raise HTTPException(status_code=502, detail=f"{prefix}: {text[:1200]}")
|
||||
|
||||
|
||||
def _httpx_trust_env() -> bool:
|
||||
"""默认不信任环境变量中的代理,避免 Docker/IDE 注入空代理导致 ConnectError;需走系统代理时设 HTTPX_TRUST_ENV=1。"""
|
||||
return os.getenv("HTTPX_TRUST_ENV", "0").lower() in ("1", "true", "yes")
|
||||
|
||||
|
||||
def _qwen_http_client(timeout_sec: float = 60.0) -> httpx.AsyncClient:
    """Build an AsyncClient for DashScope calls (bounded pool, 20s connect)."""
    timeout = httpx.Timeout(timeout_sec, connect=20.0)
    limits = httpx.Limits(max_keepalive_connections=5, max_connections=10)
    return httpx.AsyncClient(timeout=timeout, trust_env=_httpx_trust_env(), limits=limits)
|
||||
|
||||
|
||||
def _message_content_to_str(content: Any) -> str:
|
||||
"""OpenAI 兼容接口里 message.content 可能是 str 或多段结构。"""
|
||||
if content is None:
|
||||
return ""
|
||||
if isinstance(content, str):
|
||||
return content
|
||||
if isinstance(content, list):
|
||||
parts: list[str] = []
|
||||
for part in content:
|
||||
if isinstance(part, dict):
|
||||
if part.get("type") == "text" and "text" in part:
|
||||
parts.append(str(part["text"]))
|
||||
elif "text" in part:
|
||||
parts.append(str(part["text"]))
|
||||
elif isinstance(part, str):
|
||||
parts.append(part)
|
||||
return "".join(parts)
|
||||
return str(content)
|
||||
|
||||
|
||||
def _openai_completion_assistant_text(data: dict) -> str:
    """Extract the assistant's text from a chat/completions JSON payload.

    Raises 401 for ``invalid_api_key``, 502 for any other upstream error or
    for a response without ``choices``.
    """
    err = data.get("error")
    if err is not None:
        if isinstance(err, dict):
            code = str(err.get("code") or "")
            if code == "invalid_api_key":
                raise HTTPException(
                    status_code=401,
                    detail=(
                        "阿里云 DashScope API Key 无效。"
                        "请在 .env 中填写正确的 QWEN_API_KEY 并重启 backend。"
                    ),
                )
            msg = err.get("message") or err.get("code") or json.dumps(err, ensure_ascii=False)
        else:
            msg = str(err)
        raise HTTPException(status_code=502, detail=f"千问接口错误: {msg}")
    choices = data.get("choices")
    if not choices:
        raise HTTPException(
            status_code=502,
            detail=f"千问返回异常(无 choices),请检查模型名与权限。原始片段: {json.dumps(data, ensure_ascii=False)[:800]}",
        )
    # message may be None in malformed payloads; default to an empty dict.
    msg = choices[0].get("message") or {}
    return _message_content_to_str(msg.get("content"))
|
||||
|
||||
|
||||
async def _call_qwen(system_prompt: str, user_prompt: str) -> str:
|
||||
api_key = os.getenv("QWEN_API_KEY", "")
|
||||
api_key = _get_qwen_api_key()
|
||||
if not api_key:
|
||||
return (
|
||||
"当前未配置千问 API Key,已返回本地降级提示。\n"
|
||||
@@ -373,7 +687,7 @@ async def _call_qwen(system_prompt: str, user_prompt: str) -> str:
|
||||
"QWEN_BASE_URL=https://dashscope.aliyuncs.com/compatible-mode/v1\n"
|
||||
"QWEN_MODEL=qwen-plus"
|
||||
)
|
||||
base_url = os.getenv("QWEN_BASE_URL", "https://dashscope.aliyuncs.com/compatible-mode/v1")
|
||||
base_url = _qwen_base_url()
|
||||
model = os.getenv("QWEN_MODEL", "qwen-plus")
|
||||
headers = {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"}
|
||||
payload = {
|
||||
@@ -384,12 +698,333 @@ async def _call_qwen(system_prompt: str, user_prompt: str) -> str:
|
||||
],
|
||||
"temperature": 0.4,
|
||||
}
|
||||
async with httpx.AsyncClient(timeout=40) as client:
|
||||
resp = await client.post(f"{base_url}/chat/completions", headers=headers, json=payload)
|
||||
if resp.status_code >= 300:
|
||||
raise HTTPException(status_code=502, detail=f"千问请求失败: {resp.text}")
|
||||
data = resp.json()
|
||||
return data["choices"][0]["message"]["content"]
|
||||
url = f"{base_url}/chat/completions"
|
||||
try:
|
||||
async with _qwen_http_client(40.0) as client:
|
||||
resp = await client.post(url, headers=headers, json=payload)
|
||||
except httpx.ConnectError as e:
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail=(
|
||||
f"无法连接千问接口({url})。请检查本机/容器能否访问外网、DNS 是否正常;"
|
||||
"若在 Docker 中可尝试为 backend 配置 dns 或关闭错误代理。"
|
||||
"默认已忽略 HTTP(S)_PROXY,若需代理请设置 HTTPX_TRUST_ENV=1。"
|
||||
f" 原始错误: {e!s}"
|
||||
),
|
||||
) from e
|
||||
except httpx.TimeoutException as e:
|
||||
raise HTTPException(status_code=504, detail=f"千问请求超时: {e!s}") from e
|
||||
_raise_for_qwen_http_error(resp, "千问请求失败")
|
||||
try:
|
||||
data = resp.json()
|
||||
except ValueError:
|
||||
raise HTTPException(status_code=502, detail=f"千问返回非 JSON: {resp.text[:600]}")
|
||||
return _openai_completion_assistant_text(data)
|
||||
|
||||
|
||||
async def _call_qwen_vision(system_prompt: str, user_prompt: str, image_data_url: str) -> str:
    """Call the Qwen-VL multimodal endpoint with one image; return reply text.

    Returns a human-readable degradation message (instead of raising) when no
    API key is configured; connect/timeout/upstream failures are mapped to
    HTTPExceptions (502/504).
    """
    api_key = _get_qwen_api_key()
    if not api_key:
        return (
            "当前未配置千问 API Key,无法执行 OCR。\n"
            "请在 .env 中配置 QWEN_API_KEY 后重试。"
        )
    base_url = _qwen_base_url()
    model = os.getenv("QWEN_VL_MODEL", "qwen-vl-plus")
    headers = {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"}
    # Image part first, text second — matches DashScope docs and helps
    # multimodal routing.
    payload = {
        "model": model,
        "messages": [
            {"role": "system", "content": system_prompt},
            {
                "role": "user",
                "content": [
                    {"type": "image_url", "image_url": {"url": image_data_url}},
                    {"type": "text", "text": user_prompt},
                ],
            },
        ],
        "temperature": 0.2,
    }
    url = f"{base_url}/chat/completions"
    try:
        async with _qwen_http_client(60.0) as client:
            resp = await client.post(url, headers=headers, json=payload)
    except httpx.ConnectError as e:
        raise HTTPException(
            status_code=502,
            detail=(
                f"无法连接千问接口(OCR,{url})。请检查网络与 DNS;"
                "默认已忽略 HTTP(S)_PROXY,若需代理请设置 HTTPX_TRUST_ENV=1。"
                f" 原始错误: {e!s}"
            ),
        ) from e
    except httpx.TimeoutException as e:
        raise HTTPException(status_code=504, detail=f"OCR 请求超时: {e!s}") from e
    except httpx.RequestError as e:
        raise HTTPException(status_code=502, detail=f"OCR 网络请求失败: {e!s}") from e
    _raise_for_qwen_http_error(resp, "OCR 请求失败")
    try:
        data = resp.json()
    except ValueError:
        raise HTTPException(status_code=502, detail=f"OCR 返回非 JSON: {resp.text[:600]}")
    return _openai_completion_assistant_text(data)
|
||||
|
||||
|
||||
@app.get("/api/data/export")
def export_user_data(
    format: str = Query("zip", pattern="^(zip|json)$"),
    include_files: bool = True,
    db: Session = Depends(get_db),
):
    """Export all user data as a JSON document or a ZIP bundle.

    The ZIP variant embeds data.json plus (when include_files) every upload
    file still referenced by a resource or mistake.
    """
    payload = _dump_all_data(db)
    timestamp = datetime.utcnow().strftime("%Y%m%d_%H%M%S")

    if format == "json":
        buf = BytesIO(json.dumps(payload, ensure_ascii=False, indent=2).encode("utf-8"))
        return StreamingResponse(
            buf,
            media_type="application/json",
            headers={"Content-Disposition": f'attachment; filename="exam_helper_backup_{timestamp}.json"'},
        )

    # Collect only upload files that are still referenced, so stale files
    # never get bundled into the archive.
    used_upload_files: set[str] = set()
    for item in payload["resources"]:
        name = _extract_upload_filename(item.get("url"))
        if name:
            used_upload_files.add(name)
    for item in payload["mistakes"]:
        name = _extract_upload_filename(item.get("image_url"))
        if name:
            used_upload_files.add(name)

    buf = BytesIO()
    with zipfile.ZipFile(buf, "w", zipfile.ZIP_DEFLATED) as zip_ref:
        zip_ref.writestr("data.json", json.dumps(payload, ensure_ascii=False, indent=2))
        if include_files:
            for file_name in sorted(used_upload_files):
                path = UPLOAD_DIR / file_name
                # Skip dangling references whose file no longer exists on disk.
                if path.exists() and path.is_file():
                    zip_ref.write(path, arcname=f"uploads/{file_name}")
    buf.seek(0)
    return StreamingResponse(
        buf,
        media_type="application/zip",
        headers={"Content-Disposition": f'attachment; filename="exam_helper_backup_{timestamp}.zip"'},
    )
|
||||
|
||||
|
||||
@app.post("/api/data/import")
async def import_user_data(
    file: UploadFile = File(...),
    mode: str = Query("merge", pattern="^(merge|replace)$"),
    db: Session = Depends(get_db),
):
    """Import a backup produced by /api/data/export (.json or .zip).

    mode="replace" wipes existing rows first; "merge" appends. ZIP archives
    also restore the upload files referenced by resources/mistakes. Returns
    per-table import counts.
    """
    content = await file.read()
    if not content:
        raise HTTPException(status_code=400, detail="导入文件为空")
    # Hard cap so one request cannot exhaust memory/disk.
    if len(content) > 100 * 1024 * 1024:
        raise HTTPException(status_code=400, detail="导入文件不能超过 100MB")

    suffix = Path(file.filename or "").suffix.lower()
    payload: dict
    zip_ref: zipfile.ZipFile | None = None  # set only for ZIP imports

    if suffix == ".json":
        try:
            payload = json.loads(content.decode("utf-8"))
        except (UnicodeDecodeError, json.JSONDecodeError) as exc:
            raise HTTPException(status_code=400, detail=f"JSON 解析失败: {exc}") from exc
    elif suffix == ".zip":
        try:
            zip_ref = zipfile.ZipFile(BytesIO(content))
        except zipfile.BadZipFile as exc:
            raise HTTPException(status_code=400, detail="ZIP 文件损坏或格式错误") from exc
        if "data.json" not in zip_ref.namelist():
            raise HTTPException(status_code=400, detail="ZIP 中缺少 data.json")
        try:
            payload = json.loads(zip_ref.read("data.json").decode("utf-8"))
        except (UnicodeDecodeError, json.JSONDecodeError) as exc:
            raise HTTPException(status_code=400, detail=f"data.json 解析失败: {exc}") from exc
    else:
        raise HTTPException(status_code=400, detail="仅支持 .json 或 .zip 导入")

    resources = payload.get("resources", [])
    mistakes = payload.get("mistakes", [])
    scores = payload.get("scores", [])
    if not isinstance(resources, list) or not isinstance(mistakes, list) or not isinstance(scores, list):
        raise HTTPException(status_code=400, detail="导入文件结构错误")

    if mode == "replace":
        # Row-by-row ORM deletes; NOTE(review): presumably chosen over a bulk
        # DELETE so ORM-level cascades/events fire — confirm before changing.
        for item in db.scalars(select(Resource)).all():
            db.delete(item)
        for item in db.scalars(select(Mistake)).all():
            db.delete(item)
        for item in db.scalars(select(ScoreRecord)).all():
            db.delete(item)
        db.commit()

    imported = {"resources": 0, "mistakes": 0, "scores": 0}

    for item in resources:
        url = item.get("url")
        if zip_ref is not None:
            # Pull the referenced file out of the archive and rewrite the URL.
            url = _restore_upload_url_from_zip(url, zip_ref)
        obj = Resource(
            title=item.get("title") or "未命名资源",
            resource_type=item.get("resource_type") if item.get("resource_type") in {"link", "file"} else "link",
            url=url,
            file_name=item.get("file_name"),
            category=item.get("category") or "未分类",
            tags=item.get("tags"),
            created_at=_safe_datetime(item.get("created_at")),
        )
        db.add(obj)
        imported["resources"] += 1

    for item in mistakes:
        image_url = item.get("image_url")
        if zip_ref is not None:
            image_url = _restore_upload_url_from_zip(image_url, zip_ref)
        difficulty = item.get("difficulty")
        obj = Mistake(
            title=item.get("title") or "未命名错题",
            image_url=image_url,
            category=item.get("category") or "其他",
            # Unknown difficulty values are dropped rather than stored verbatim.
            difficulty=difficulty if difficulty in {"easy", "medium", "hard"} else None,
            question_content=item.get("question_content"),
            answer=item.get("answer"),
            explanation=item.get("explanation"),
            note=item.get("note"),
            wrong_count=max(int(item.get("wrong_count") or 1), 1),
            created_at=_safe_datetime(item.get("created_at")),
        )
        db.add(obj)
        imported["mistakes"] += 1

    for item in scores:
        score = float(item.get("total_score") or 0)
        obj = ScoreRecord(
            exam_name=item.get("exam_name") or "未命名考试",
            exam_date=_safe_date(item.get("exam_date")),
            # Clamp into [0, 200] as the code's accepted score range.
            total_score=max(min(score, 200), 0),
            module_scores=item.get("module_scores"),
            created_at=_safe_datetime(item.get("created_at")),
        )
        db.add(obj)
        imported["scores"] += 1

    db.commit()
    return {"success": True, "mode": mode, "imported": imported}
|
||||
|
||||
|
||||
@app.post("/api/ocr/parse", response_model=OcrParseOut)
async def parse_ocr(payload: OcrParseIn):
    """OCR a previously-uploaded image via Qwen-VL and return structured fields.

    Only local ``/uploads/...`` JPG/PNG/WebP images are accepted. The model
    reply is expected to be JSON; any non-JSON reply degrades gracefully into
    a payload that treats the whole reply as the question text.
    """
    file_name = _extract_upload_filename(payload.image_url)
    if not file_name:
        raise HTTPException(status_code=400, detail="仅支持 /uploads 下的图片做 OCR")
    target = UPLOAD_DIR / file_name
    if not target.exists() or not target.is_file():
        raise HTTPException(status_code=404, detail="图片不存在或已删除")

    suffix = target.suffix.lower()
    mime = {
        ".jpg": "image/jpeg",
        ".jpeg": "image/jpeg",
        ".png": "image/png",
        ".webp": "image/webp",
    }.get(suffix)
    if not mime:
        raise HTTPException(status_code=400, detail="仅支持 JPG/PNG/WebP OCR")

    # Inline the image as a data URL so no public file URL is required.
    b64 = base64.b64encode(target.read_bytes()).decode("utf-8")
    image_data_url = f"data:{mime};base64,{b64}"
    ocr_prompt = (
        "请识别图片中的题目,返回严格 JSON。"
        "字段说明:text 为整题完整纯文本(含材料、提问句、全部选项);"
        "question_content 必须与 text 一致地表示「完整题干」,须包含阅读材料、填空/提问句、所有选项(A B C D 等),"
        "禁止只填写「依次填入…」等短提示句而省略材料和选项。"
        "另含 title_suggestion、category_suggestion、difficulty_suggestion、answer、explanation。"
        "无法确认的字段可填空字符串。"
    )
    if payload.prompt:
        ocr_prompt = f"{ocr_prompt}\n补充要求:{payload.prompt}"

    raw_text = await _call_qwen_vision(
        "你是公考题目OCR与结构化助手。输出必须是 JSON,不要额外解释。",
        ocr_prompt,
        image_data_url,
    )

    def _fallback_payload() -> dict:
        # Single source for the degraded shape (previously duplicated in both
        # the non-dict branch and the JSONDecodeError handler): treat the whole
        # reply as the question text, leave every suggestion empty.
        stripped = raw_text.strip()
        return {
            "text": stripped,
            "title_suggestion": None,
            "category_suggestion": None,
            "difficulty_suggestion": None,
            "question_content": stripped,
            "answer": "",
            "explanation": "",
        }

    try:
        parsed = json.loads(_extract_json_text(raw_text))
        data = parsed if isinstance(parsed, dict) else _fallback_payload()
    except json.JSONDecodeError:
        data = _fallback_payload()

    def _opt_str(val: Any) -> str | None:
        # Coerce a scalar to a trimmed string; None for empty values and
        # container types the schema does not accept.
        if val is None or isinstance(val, (dict, list)):
            return None
        s = str(val).strip()
        return s if s else None

    def _merge_question_body(text_raw: str, qc_raw: str | None) -> str | None:
        """The model often puts the full text in ``text`` but only a short
        question sentence in ``question_content``; keep the longer, more
        complete of the two."""
        t = (text_raw or "").strip()
        q = (qc_raw or "").strip()
        if not t and not q:
            return None
        if not q:
            return t or None
        if not t:
            return q or None
        if len(t) > len(q):
            return t
        if len(q) > len(t):
            return q
        # Equal lengths: prefer whichever contains the other, else keep text
        # (whole-page OCR is usually the more complete one).
        if t in q:
            return q
        if q in t:
            return t
        return t

    text_out = str(data.get("text", "") or "").strip()
    qc_model = _opt_str(data.get("question_content"))
    question_merged = _merge_question_body(text_out, qc_model)

    return OcrParseOut(
        text=text_out,
        title_suggestion=_opt_str(data.get("title_suggestion")),
        category_suggestion=_opt_str(data.get("category_suggestion")),
        difficulty_suggestion=_opt_str(data.get("difficulty_suggestion")),
        question_content=_opt_str(question_merged),
        answer=_opt_str(data.get("answer")),
        explanation=_opt_str(data.get("explanation")),
    )
|
||||
|
||||
|
||||
@app.post("/api/ai/mistakes/{item_id}/analyze", response_model=AiMistakeAnalysisOut)
|
||||
@@ -403,6 +1038,9 @@ async def ai_analyze_mistake(item_id: int, db: Session = Depends(get_db)):
|
||||
f"错题标题: {item.title}\n"
|
||||
f"分类: {item.category}\n"
|
||||
f"难度: {item.difficulty or '未设置'}\n"
|
||||
f"题目内容: {item.question_content or '无'}\n"
|
||||
f"答案: {item.answer or '无'}\n"
|
||||
f"解析: {item.explanation or '无'}\n"
|
||||
f"错误频次: {item.wrong_count}\n"
|
||||
f"备注: {item.note or '无'}\n\n"
|
||||
"请按以下结构输出:\n"
|
||||
|
||||
Reference in New Issue
Block a user