1089 lines
40 KiB
Python
1089 lines
40 KiB
Python
import os
|
||
import json
|
||
import uuid
|
||
import zipfile
|
||
import base64
|
||
from typing import Any
|
||
from datetime import date, datetime, timedelta
|
||
from io import BytesIO
|
||
from pathlib import Path
|
||
|
||
import httpx
|
||
from docx import Document
|
||
from fastapi import Depends, FastAPI, File, HTTPException, Query, UploadFile
|
||
from fastapi.middleware.cors import CORSMiddleware
|
||
from fastapi.responses import StreamingResponse
|
||
from fastapi.staticfiles import StaticFiles
|
||
from reportlab.lib.pagesizes import A4
|
||
from reportlab.pdfbase import pdfmetrics
|
||
from reportlab.pdfbase.cidfonts import UnicodeCIDFont
|
||
from reportlab.pdfgen import canvas
|
||
from sqlalchemy import asc, desc, func, inspect, or_, select, text
|
||
from sqlalchemy.orm import Session
|
||
|
||
from .database import Base, engine, get_db
|
||
from .models import Mistake, Resource, ScoreRecord
|
||
from .schemas import (
|
||
AiMistakeAnalysisOut,
|
||
AiStudyPlanIn,
|
||
AiStudyPlanOut,
|
||
IdBatchPayload,
|
||
MistakeCreate,
|
||
MistakeOut,
|
||
MistakeUpdate,
|
||
OcrParseIn,
|
||
OcrParseOut,
|
||
ResourceBatchUpdate,
|
||
ResourceCreate,
|
||
ResourceOut,
|
||
ResourceUpdate,
|
||
ScoreRecordCreate,
|
||
ScoreRecordOut,
|
||
ScoreRecordUpdate,
|
||
ScoreStats,
|
||
)
|
||
|
||
# Create any missing tables declared on the ORM models at import time.
Base.metadata.create_all(bind=engine)
|
||
|
||
|
||
def _migrate_mistake_columns() -> None:
    """Lightweight in-place migration: add text columns introduced after v1.

    Only runs when a ``mistakes`` table already exists; each missing column
    is added via raw ALTER TABLE (no data backfill needed for nullable TEXT).
    """
    inspector = inspect(engine)
    if "mistakes" not in inspector.get_table_names():
        return
    existing = {column["name"] for column in inspector.get_columns("mistakes")}
    with engine.begin() as conn:
        for column_name in ("question_content", "answer", "explanation"):
            if column_name not in existing:
                conn.execute(text(f"ALTER TABLE mistakes ADD COLUMN {column_name} TEXT"))
|
||
|
||
|
||
# Apply the column migration once, right after table creation at import time.
_migrate_mistake_columns()
|
||
|
||
app = FastAPI(title="学习伙伴 API", version="1.0.0")

# Uploaded files live on disk and are served back as static /uploads/* URLs.
UPLOAD_DIR = Path(os.getenv("UPLOAD_DIR", "/app/uploads"))
UPLOAD_DIR.mkdir(parents=True, exist_ok=True)
app.mount("/uploads", StaticFiles(directory=str(UPLOAD_DIR)), name="uploads")

# Comma-separated allow-list of browser origins; default is the Vite dev server.
origins = os.getenv("CORS_ORIGINS", "http://localhost:5173").split(",")
app.add_middleware(
    CORSMiddleware,
    allow_origins=[origin.strip() for origin in origins if origin.strip()],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
||
|
||
|
||
@app.get("/health")
def health():
    """Liveness probe: always reports the service as up."""
    payload = {"status": "ok"}
    return payload
|
||
|
||
|
||
def _validate_score_date(exam_date: date) -> None:
    """Reject exam dates that lie in the future (400 error)."""
    today = date.today()
    if exam_date > today:
        raise HTTPException(status_code=400, detail="考试时间不能晚于今天")
|
||
|
||
|
||
def _query_mistakes_for_export(
    db: Session,
    category: str | None,
    start_date: date | None,
    end_date: date | None,
    ids: list[int] | None = None,
):
    """Fetch mistakes matching the export filters, newest first.

    Raises a 400 when more than 200 rows match, to keep exports bounded.
    """
    stmt = select(Mistake)
    if category:
        stmt = stmt.where(Mistake.category == category)
    if start_date:
        # Include the whole start day.
        day_start = datetime.combine(start_date, datetime.min.time())
        stmt = stmt.where(Mistake.created_at >= day_start)
    if end_date:
        # Include the whole end day.
        day_end = datetime.combine(end_date, datetime.max.time())
        stmt = stmt.where(Mistake.created_at <= day_end)
    if ids:
        stmt = stmt.where(Mistake.id.in_(ids))
    matched = db.scalars(stmt.order_by(desc(Mistake.created_at))).all()
    if len(matched) > 200:
        raise HTTPException(status_code=400, detail="单次最多导出 200 题")
    return matched
|
||
|
||
|
||
def _validate_mistake_payload(payload: MistakeCreate | MistakeUpdate) -> None:
    """Require at least one of: image, question text, or answer (else 400)."""
    candidates = (payload.image_url, payload.question_content, payload.answer)
    if not any((field or "").strip() for field in candidates):
        raise HTTPException(status_code=400, detail="请上传题目图片或填写试题/答案后再保存")
|
||
|
||
|
||
def _normalize_multiline_text(value: str | None) -> str:
|
||
if not value:
|
||
return ""
|
||
text_value = value.replace("\r\n", "\n").replace("\r", "\n")
|
||
lines = [line.strip() for line in text_value.split("\n")]
|
||
compact = [line for line in lines if line]
|
||
return "\n".join(compact).strip()
|
||
|
||
|
||
def _wrap_pdf_text(text: str, max_width: float, font_name: str = "STSong-Light", font_size: int = 12) -> list[str]:
    """Greedily wrap `text` into lines no wider than `max_width` points.

    Wraps per character (suits CJK text, which has no word spaces). A single
    character wider than `max_width` still gets its own line rather than
    being dropped.

    Perf fix: the previous version re-measured the whole accumulated line for
    every appended character (quadratic stringWidth work per line). Glyph
    widths are additive for stringWidth (it does not apply kerning), so we
    accumulate per-character widths instead — O(n) measurements.
    """
    normalized = _normalize_multiline_text(text)
    if not normalized:
        return []
    wrapped: list[str] = []
    for raw_line in normalized.split("\n"):
        current = ""
        current_width = 0.0
        for ch in raw_line:
            ch_width = pdfmetrics.stringWidth(ch, font_name, font_size)
            if current and current_width + ch_width > max_width:
                # Flush the full line and start a new one with this char.
                wrapped.append(current)
                current = ch
                current_width = ch_width
            else:
                current += ch
                current_width += ch_width
        if current:
            wrapped.append(current)
    return wrapped
|
||
|
||
|
||
def _mistake_export_blocks(item: Mistake, content_mode: str) -> list[str]:
    """Build the text blocks (question [, answer, explanation]) for one mistake.

    content_mode "full" appends answer and explanation blocks; any other mode
    exports the question alone.
    """
    question = _normalize_multiline_text(item.question_content) or "无题干与选项内容"
    blocks: list[str] = [question]
    if content_mode == "full":
        answer = _normalize_multiline_text(item.answer)
        explanation = _normalize_multiline_text(item.explanation)
        # The "答案:"/"解析:" labels share a line with their body so a label
        # never sits alone (mirrors the question-number formatting rule).
        blocks.append(f"答案: {answer or '无'}")
        blocks.append(f"解析: {explanation or '无'}")
    return blocks
|
||
|
||
|
||
def _extract_upload_filename(url: str | None) -> str | None:
|
||
if not url or not url.startswith("/uploads/"):
|
||
return None
|
||
return Path(url).name
|
||
|
||
|
||
def _safe_datetime(value: str | None) -> datetime:
|
||
if not value:
|
||
return datetime.utcnow()
|
||
try:
|
||
return datetime.fromisoformat(value.replace("Z", "+00:00")).replace(tzinfo=None)
|
||
except ValueError:
|
||
return datetime.utcnow()
|
||
|
||
|
||
def _safe_date(value: str | None) -> date:
|
||
if not value:
|
||
return date.today()
|
||
try:
|
||
return date.fromisoformat(value)
|
||
except ValueError:
|
||
return date.today()
|
||
|
||
|
||
def _extract_json_text(raw_text: str) -> str:
|
||
content = raw_text.strip()
|
||
if content.startswith("```"):
|
||
lines = content.splitlines()
|
||
if lines:
|
||
lines = lines[1:]
|
||
if lines and lines[-1].strip() == "```":
|
||
lines = lines[:-1]
|
||
content = "\n".join(lines).strip()
|
||
return content
|
||
|
||
|
||
def _dump_all_data(db: Session) -> dict:
    """Serialize every resource, mistake and score record into a JSON-safe dict.

    Used by the data-export endpoint. Datetimes are rendered as ISO strings and
    rows are ordered by id so exports are stable and diffable.
    """
    resources = db.scalars(select(Resource).order_by(asc(Resource.id))).all()
    mistakes = db.scalars(select(Mistake).order_by(asc(Mistake.id))).all()
    scores = db.scalars(select(ScoreRecord).order_by(asc(ScoreRecord.id))).all()
    return {
        "meta": {
            "exported_at": datetime.utcnow().isoformat(),
            # Backup format version; not currently read back by the importer.
            "version": "1.1.0",
        },
        "resources": [
            {
                "id": item.id,
                "title": item.title,
                "resource_type": item.resource_type,
                "url": item.url,
                "file_name": item.file_name,
                "category": item.category,
                "tags": item.tags,
                "created_at": item.created_at.isoformat() if item.created_at else None,
            }
            for item in resources
        ],
        "mistakes": [
            {
                "id": item.id,
                "title": item.title,
                "image_url": item.image_url,
                "category": item.category,
                "difficulty": item.difficulty,
                "question_content": item.question_content,
                "answer": item.answer,
                "explanation": item.explanation,
                "note": item.note,
                "wrong_count": item.wrong_count,
                "created_at": item.created_at.isoformat() if item.created_at else None,
            }
            for item in mistakes
        ],
        "scores": [
            {
                "id": item.id,
                "exam_name": item.exam_name,
                "exam_date": item.exam_date.isoformat() if item.exam_date else None,
                "total_score": item.total_score,
                "module_scores": item.module_scores,
                "created_at": item.created_at.isoformat() if item.created_at else None,
            }
            for item in scores
        ],
    }
|
||
|
||
|
||
def _restore_upload_url_from_zip(url: str | None, zip_ref: zipfile.ZipFile) -> str | None:
    """Re-materialize a /uploads/ file referenced by `url` from a backup ZIP.

    Returns the (possibly renamed) local URL, or the input unchanged when the
    URL is external or the ZIP does not bundle the file. None stays None.
    """
    if not url:
        return None
    file_name = _extract_upload_filename(url)
    if not file_name:
        return url

    member = f"uploads/{file_name}"
    if member not in zip_ref.namelist():
        return url

    destination = UPLOAD_DIR / file_name
    if destination.exists():
        # Never clobber an existing file: fall back to a fresh unique name.
        destination = UPLOAD_DIR / f"{uuid.uuid4().hex}_{file_name}"
    destination.write_bytes(zip_ref.read(member))
    return f"/uploads/{destination.name}"
|
||
|
||
|
||
@app.post("/api/upload")
async def upload_file(file: UploadFile = File(...)):
    """Store an uploaded document/image under UPLOAD_DIR and return its URL.

    Accepts PDF/Word and common image formats, up to 50 MB. The stored name is
    a random hex id so uploads never collide; the original name is echoed back.
    """
    allowed = {".pdf", ".doc", ".docx", ".jpg", ".jpeg", ".png", ".webp", ".heic", ".heif"}
    mime_to_suffix = {
        "application/pdf": ".pdf",
        "application/msword": ".doc",
        "application/vnd.openxmlformats-officedocument.wordprocessingml.document": ".docx",
        "image/jpeg": ".jpg",
        "image/png": ".png",
        "image/webp": ".webp",
        "image/heic": ".heic",
        "image/heif": ".heif",
    }
    suffix = Path(file.filename or "").suffix.lower()
    if suffix not in allowed:
        # Fall back to the declared MIME type when the filename has no usable
        # extension (e.g. camera uploads).
        suffix = mime_to_suffix.get((file.content_type or "").lower())
        if suffix is None:
            raise HTTPException(status_code=400, detail="不支持的文件类型")
    content = await file.read()
    if len(content) > 50 * 1024 * 1024:
        raise HTTPException(status_code=400, detail="文件不能超过 50MB")
    stored_name = f"{uuid.uuid4().hex}{suffix}"
    (UPLOAD_DIR / stored_name).write_bytes(content)
    return {"url": f"/uploads/{stored_name}", "original_name": file.filename}
|
||
|
||
|
||
@app.get("/api/resources", response_model=list[ResourceOut])
def list_resources(
    q: str | None = None,
    category: str | None = None,
    tags: str | None = None,
    resource_type: str | None = None,
    sort_by: str = Query("created_at", pattern="^(created_at|title|name)$"),
    order: str = Query("desc", pattern="^(asc|desc)$"),
    db: Session = Depends(get_db),
):
    """List resources with optional keyword search, filters, and sorting."""
    stmt = select(Resource)
    if q:
        # Free-text search over title, tags and URL.
        keyword_filter = or_(
            Resource.title.ilike(f"%{q}%"),
            Resource.tags.ilike(f"%{q}%"),
            Resource.url.ilike(f"%{q}%"),
        )
        stmt = stmt.where(keyword_filter)
    if category:
        stmt = stmt.where(Resource.category == category)
    if tags:
        stmt = stmt.where(Resource.tags.ilike(f"%{tags}%"))
    if resource_type:
        stmt = stmt.where(Resource.resource_type == resource_type)

    # Both "title" and "name" sort by title; only "created_at" is date-sorted.
    sort_col = Resource.created_at if sort_by == "created_at" else Resource.title
    direction = desc if order == "desc" else asc
    return db.scalars(stmt.order_by(direction(sort_col))).all()
|
||
|
||
|
||
@app.post("/api/resources", response_model=ResourceOut)
def create_resource(payload: ResourceCreate, db: Session = Depends(get_db)):
    """Persist a new learning resource and return it with DB-generated fields."""
    record = Resource(**payload.model_dump())
    db.add(record)
    db.commit()
    db.refresh(record)
    return record
|
||
|
||
|
||
@app.put("/api/resources/{item_id}", response_model=ResourceOut)
def update_resource(item_id: int, payload: ResourceUpdate, db: Session = Depends(get_db)):
    """Overwrite all fields of an existing resource; 404 for unknown ids."""
    record = db.get(Resource, item_id)
    if record is None:
        raise HTTPException(status_code=404, detail="Resource not found")
    for field, value in payload.model_dump().items():
        setattr(record, field, value)
    db.commit()
    db.refresh(record)
    return record
|
||
|
||
|
||
@app.delete("/api/resources/{item_id}")
def delete_resource(item_id: int, db: Session = Depends(get_db)):
    """Delete one resource by id; 404 for unknown ids."""
    record = db.get(Resource, item_id)
    if record is None:
        raise HTTPException(status_code=404, detail="Resource not found")
    db.delete(record)
    db.commit()
    return {"success": True}
|
||
|
||
|
||
@app.patch("/api/resources/batch")
def batch_update_resources(payload: ResourceBatchUpdate, db: Session = Depends(get_db)):
    """Apply category/tags to every resource in `ids` (None fields untouched)."""
    if not payload.ids:
        raise HTTPException(status_code=400, detail="ids 不能为空")
    matched = db.scalars(select(Resource).where(Resource.id.in_(payload.ids))).all()
    for record in matched:
        if payload.category is not None:
            record.category = payload.category
        if payload.tags is not None:
            record.tags = payload.tags
    db.commit()
    # Unknown ids are silently skipped; count reflects rows actually touched.
    return {"success": True, "count": len(matched)}
|
||
|
||
|
||
@app.post("/api/resources/batch-delete")
def batch_delete_resources(payload: IdBatchPayload, db: Session = Depends(get_db)):
    """Delete every resource in `ids`; unknown ids are skipped silently."""
    if not payload.ids:
        raise HTTPException(status_code=400, detail="ids 不能为空")
    doomed = db.scalars(select(Resource).where(Resource.id.in_(payload.ids))).all()
    for record in doomed:
        db.delete(record)
    db.commit()
    return {"success": True, "count": len(doomed)}
|
||
|
||
|
||
@app.get("/api/mistakes", response_model=list[MistakeOut])
def list_mistakes(
    category: str | None = None,
    keyword: str | None = None,
    sort_by: str = Query("created_at", pattern="^(created_at|wrong_count)$"),
    order: str = Query("desc", pattern="^(asc|desc)$"),
    db: Session = Depends(get_db),
):
    """List mistakes, with optional category filter and keyword search
    across all text fields (note, title, question, answer, explanation, URL).
    """
    stmt = select(Mistake)
    if category:
        stmt = stmt.where(Mistake.category == category)
    if keyword:
        like_pattern = f"%{keyword}%"
        stmt = stmt.where(
            or_(
                Mistake.note.ilike(like_pattern),
                Mistake.title.ilike(like_pattern),
                Mistake.question_content.ilike(like_pattern),
                Mistake.answer.ilike(like_pattern),
                Mistake.explanation.ilike(like_pattern),
                Mistake.image_url.ilike(like_pattern),
            )
        )
    sort_col = Mistake.created_at if sort_by == "created_at" else Mistake.wrong_count
    direction = desc if order == "desc" else asc
    return db.scalars(stmt.order_by(direction(sort_col))).all()
|
||
|
||
|
||
@app.post("/api/mistakes", response_model=MistakeOut)
def create_mistake(payload: MistakeCreate, db: Session = Depends(get_db)):
    """Create a mistake entry after checking it carries image or text content."""
    _validate_mistake_payload(payload)
    record = Mistake(**payload.model_dump())
    db.add(record)
    db.commit()
    db.refresh(record)
    return record
|
||
|
||
|
||
@app.put("/api/mistakes/{item_id}", response_model=MistakeOut)
def update_mistake(item_id: int, payload: MistakeUpdate, db: Session = Depends(get_db)):
    """Overwrite all fields of an existing mistake; 404 for unknown ids.

    The payload is validated before the row lookup, matching create_mistake.
    """
    _validate_mistake_payload(payload)
    record = db.get(Mistake, item_id)
    if record is None:
        raise HTTPException(status_code=404, detail="Mistake not found")
    for field, value in payload.model_dump().items():
        setattr(record, field, value)
    db.commit()
    db.refresh(record)
    return record
|
||
|
||
|
||
@app.delete("/api/mistakes/{item_id}")
def delete_mistake(item_id: int, db: Session = Depends(get_db)):
    """Delete one mistake by id; 404 for unknown ids."""
    record = db.get(Mistake, item_id)
    if record is None:
        raise HTTPException(status_code=404, detail="Mistake not found")
    db.delete(record)
    db.commit()
    return {"success": True}
|
||
|
||
|
||
@app.get("/api/mistakes/export/pdf")
def export_mistakes_pdf(
    category: str | None = None,
    start_date: date | None = None,
    end_date: date | None = None,
    ids: str | None = None,
    content_mode: str = Query("full", pattern="^(full|question_only)$"),
    db: Session = Depends(get_db),
):
    """Export filtered mistakes as a paginated A4 PDF download.

    `ids` is a comma-separated id list; non-numeric entries are ignored.
    `content_mode=full` includes answers and explanations.

    Fix: the local variable previously named `text` shadowed sqlalchemy's
    `text` import; renamed to `block_text`.
    """
    id_list = [int(x) for x in ids.split(",") if x.strip().isdigit()] if ids else None
    items = _query_mistakes_for_export(db, category, start_date, end_date, id_list)
    buf = BytesIO()
    pdf = canvas.Canvas(buf, pagesize=A4)
    # CID font gives CJK glyph support without embedding a TTF file.
    pdfmetrics.registerFont(UnicodeCIDFont("STSong-Light"))
    pdf.setFont("STSong-Light", 12)

    y = 800
    left = 48
    right = 560
    max_width = right - left
    pdf.drawString(left, y, "学习伙伴 - 错题导出")
    y -= 28
    for idx, item in enumerate(items, start=1):
        if y < 90:
            # Not enough room to start a new question: begin a fresh page.
            pdf.showPage()
            pdf.setFont("STSong-Light", 12)  # font state resets on page break
            y = 800
        blocks = _mistake_export_blocks(item, content_mode)
        for bi, block in enumerate(blocks):
            # The question number shares a line with the stem so "1." never
            # sits alone on its own line.
            block_text = f"{idx}. {block}" if bi == 0 else block
            lines = _wrap_pdf_text(block_text, max_width=max_width)
            if not lines:
                continue
            for line in lines:
                if y < 70:
                    pdf.showPage()
                    pdf.setFont("STSong-Light", 12)
                    y = 800
                pdf.drawString(left, y, line)
                y -= 18
            y -= 6  # gap between blocks of the same question
        y -= 8  # gap between questions

    pdf.save()
    buf.seek(0)
    return StreamingResponse(
        buf,
        media_type="application/pdf",
        headers={"Content-Disposition": 'attachment; filename="mistakes.pdf"'},
    )
|
||
|
||
|
||
@app.get("/api/mistakes/export/docx")
def export_mistakes_docx(
    category: str | None = None,
    start_date: date | None = None,
    end_date: date | None = None,
    ids: str | None = None,
    content_mode: str = Query("full", pattern="^(full|question_only)$"),
    db: Session = Depends(get_db),
):
    """Export filtered mistakes as a Word (.docx) download.

    Shares the filter semantics of the PDF export; one paragraph per block.
    """
    id_list = [int(x) for x in ids.split(",") if x.strip().isdigit()] if ids else None
    items = _query_mistakes_for_export(db, category, start_date, end_date, id_list)
    doc = Document()
    doc.add_heading("学习伙伴 - 错题导出", level=1)
    for number, item in enumerate(items, start=1):
        for block_index, block in enumerate(_mistake_export_blocks(item, content_mode)):
            # The question number prefixes the first block so "1." never
            # becomes a lone paragraph.
            prefix = f"{number}. " if block_index == 0 else ""
            doc.add_paragraph(f"{prefix}{block}")

    out = BytesIO()
    doc.save(out)
    out.seek(0)
    return StreamingResponse(
        out,
        media_type="application/vnd.openxmlformats-officedocument.wordprocessingml.document",
        headers={"Content-Disposition": 'attachment; filename="mistakes.docx"'},
    )
|
||
|
||
|
||
@app.get("/api/scores", response_model=list[ScoreRecordOut])
def list_scores(
    start_date: date | None = None,
    end_date: date | None = None,
    db: Session = Depends(get_db),
):
    """List score records, optionally bounded by exam date, oldest first."""
    stmt = select(ScoreRecord)
    if start_date is not None:
        stmt = stmt.where(ScoreRecord.exam_date >= start_date)
    if end_date is not None:
        stmt = stmt.where(ScoreRecord.exam_date <= end_date)
    return db.scalars(stmt.order_by(asc(ScoreRecord.exam_date))).all()
|
||
|
||
|
||
@app.post("/api/scores", response_model=ScoreRecordOut)
def create_score(payload: ScoreRecordCreate, db: Session = Depends(get_db)):
    """Record a new exam score; rejects future-dated exams."""
    _validate_score_date(payload.exam_date)
    record = ScoreRecord(**payload.model_dump())
    db.add(record)
    db.commit()
    db.refresh(record)
    return record
|
||
|
||
|
||
@app.put("/api/scores/{item_id}", response_model=ScoreRecordOut)
def update_score(item_id: int, payload: ScoreRecordUpdate, db: Session = Depends(get_db)):
    """Overwrite an existing score record; 404 for unknown ids, 400 for future dates."""
    _validate_score_date(payload.exam_date)
    record = db.get(ScoreRecord, item_id)
    if record is None:
        raise HTTPException(status_code=404, detail="Score record not found")
    for field, value in payload.model_dump().items():
        setattr(record, field, value)
    db.commit()
    db.refresh(record)
    return record
|
||
|
||
|
||
@app.delete("/api/scores/{item_id}")
def delete_score(item_id: int, db: Session = Depends(get_db)):
    """Delete one score record by id; 404 for unknown ids."""
    record = db.get(ScoreRecord, item_id)
    if record is None:
        raise HTTPException(status_code=404, detail="Score record not found")
    db.delete(record)
    db.commit()
    return {"success": True}
|
||
|
||
|
||
@app.get("/api/scores/stats", response_model=ScoreStats)
def score_stats(db: Session = Depends(get_db)):
    """Aggregate score statistics: highest, lowest, average, improvement.

    Improvement is the latest exam's total minus the earliest exam's total
    (the query is date-ascending). Returns all-zero stats when no scores exist.

    Fix: the average is now computed from the rows already fetched instead of
    issuing a second AVG query against the database.
    """
    scores = db.scalars(select(ScoreRecord).order_by(asc(ScoreRecord.exam_date))).all()
    if not scores:
        return ScoreStats(highest=0, lowest=0, average=0, improvement=0)

    totals = [item.total_score for item in scores]
    average = round(float(sum(totals)) / len(totals), 2)
    return ScoreStats(
        highest=max(totals),
        lowest=min(totals),
        average=average,
        improvement=totals[-1] - totals[0],
    )
|
||
|
||
|
||
def _qwen_base_url() -> str:
|
||
return os.getenv("QWEN_BASE_URL", "https://dashscope.aliyuncs.com/compatible-mode/v1").strip().rstrip("/")
|
||
|
||
|
||
def _get_qwen_api_key() -> str:
|
||
"""去除首尾空白与常见误加的引号,避免 .env 里写成 'sk-xxx' 导致鉴权失败。"""
|
||
raw = os.getenv("QWEN_API_KEY", "") or ""
|
||
return raw.strip().strip('"').strip("'").strip()
|
||
|
||
|
||
def _raise_for_qwen_http_error(resp: httpx.Response, prefix: str) -> None:
    """Translate a non-2xx DashScope response into an HTTPException.

    invalid_api_key is surfaced as a 401 with setup instructions; any other
    upstream error becomes a 502 carrying the upstream message. 2xx responses
    return without raising.

    Fix: the local variable previously named `text` shadowed sqlalchemy's
    `text` import; renamed to `body`.
    """
    if resp.status_code < 300:
        return
    body = resp.text
    try:
        data = resp.json()
        err = data.get("error")
        if isinstance(err, dict):
            code = str(err.get("code") or "")
            if code == "invalid_api_key":
                raise HTTPException(
                    status_code=401,
                    detail=(
                        "阿里云 DashScope API Key 无效或未生效。"
                        "请到阿里云百炼 / Model Studio 控制台创建 API Key(通常以 sk- 开头),"
                        "写入项目根目录 .env 的 QWEN_API_KEY=,勿加引号;"
                        "修改后执行: docker compose up -d --build backend"
                    ),
                )
            msg = err.get("message") or body
            raise HTTPException(status_code=502, detail=f"{prefix}: {msg}")
    except HTTPException:
        raise
    except (ValueError, TypeError, KeyError):
        # Body was not JSON / not the expected shape: fall back to raw text.
        pass
    raise HTTPException(status_code=502, detail=f"{prefix}: {body[:1200]}")
|
||
|
||
|
||
def _httpx_trust_env() -> bool:
|
||
"""默认不信任环境变量中的代理,避免 Docker/IDE 注入空代理导致 ConnectError;需走系统代理时设 HTTPX_TRUST_ENV=1。"""
|
||
return os.getenv("HTTPX_TRUST_ENV", "0").lower() in ("1", "true", "yes")
|
||
|
||
|
||
def _qwen_http_client(timeout_sec: float = 60.0) -> httpx.AsyncClient:
    """Build an AsyncClient for DashScope calls with bounded timeouts and pools."""
    timeout = httpx.Timeout(timeout_sec, connect=20.0)
    limits = httpx.Limits(max_keepalive_connections=5, max_connections=10)
    return httpx.AsyncClient(
        timeout=timeout,
        trust_env=_httpx_trust_env(),
        limits=limits,
    )
|
||
|
||
|
||
def _message_content_to_str(content: Any) -> str:
|
||
"""OpenAI 兼容接口里 message.content 可能是 str 或多段结构。"""
|
||
if content is None:
|
||
return ""
|
||
if isinstance(content, str):
|
||
return content
|
||
if isinstance(content, list):
|
||
parts: list[str] = []
|
||
for part in content:
|
||
if isinstance(part, dict):
|
||
if part.get("type") == "text" and "text" in part:
|
||
parts.append(str(part["text"]))
|
||
elif "text" in part:
|
||
parts.append(str(part["text"]))
|
||
elif isinstance(part, str):
|
||
parts.append(part)
|
||
return "".join(parts)
|
||
return str(content)
|
||
|
||
|
||
def _openai_completion_assistant_text(data: dict) -> str:
    """Extract the assistant's text from a chat/completions response body.

    Raises HTTPException when the body carries an ``error`` object (401 for
    invalid_api_key, 502 otherwise) or when it has no ``choices``.
    """
    err = data.get("error")
    if err is not None:
        if isinstance(err, dict):
            code = str(err.get("code") or "")
            if code == "invalid_api_key":
                raise HTTPException(
                    status_code=401,
                    detail=(
                        "阿里云 DashScope API Key 无效。"
                        "请在 .env 中填写正确的 QWEN_API_KEY 并重启 backend。"
                    ),
                )
            # Prefer the upstream message; fall back to code, then raw JSON.
            msg = err.get("message") or err.get("code") or json.dumps(err, ensure_ascii=False)
        else:
            msg = str(err)
        raise HTTPException(status_code=502, detail=f"千问接口错误: {msg}")
    choices = data.get("choices")
    if not choices:
        raise HTTPException(
            status_code=502,
            detail=f"千问返回异常(无 choices),请检查模型名与权限。原始片段: {json.dumps(data, ensure_ascii=False)[:800]}",
        )
    msg = choices[0].get("message") or {}
    return _message_content_to_str(msg.get("content"))
|
||
|
||
|
||
async def _call_qwen(system_prompt: str, user_prompt: str) -> str:
    """Call the Qwen chat-completions endpoint and return the assistant text.

    Degrades gracefully to a local hint string when no API key is configured.
    Network failures are mapped to 502/504 HTTPExceptions.

    Fix: added an ``except httpx.RequestError`` handler (matching
    _call_qwen_vision) so transport errors other than connect/timeout no
    longer escape as unhandled 500s.
    """
    api_key = _get_qwen_api_key()
    if not api_key:
        # Local fallback so the UI still gets a useful message without a key.
        return (
            "当前未配置千问 API Key,已返回本地降级提示。\n"
            "请在项目根目录创建 .env 并配置 QWEN_API_KEY 后重启服务。\n"
            "示例:\n"
            "QWEN_API_KEY=your_qwen_api_key_here\n"
            "QWEN_BASE_URL=https://dashscope.aliyuncs.com/compatible-mode/v1\n"
            "QWEN_MODEL=qwen-plus"
        )
    base_url = _qwen_base_url()
    model = os.getenv("QWEN_MODEL", "qwen-plus")
    headers = {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"}
    payload = {
        "model": model,
        "messages": [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ],
        "temperature": 0.4,
    }
    url = f"{base_url}/chat/completions"
    try:
        async with _qwen_http_client(40.0) as client:
            resp = await client.post(url, headers=headers, json=payload)
    except httpx.ConnectError as e:
        raise HTTPException(
            status_code=502,
            detail=(
                f"无法连接千问接口({url})。请检查本机/容器能否访问外网、DNS 是否正常;"
                "若在 Docker 中可尝试为 backend 配置 dns 或关闭错误代理。"
                "默认已忽略 HTTP(S)_PROXY,若需代理请设置 HTTPX_TRUST_ENV=1。"
                f" 原始错误: {e!s}"
            ),
        ) from e
    except httpx.TimeoutException as e:
        raise HTTPException(status_code=504, detail=f"千问请求超时: {e!s}") from e
    except httpx.RequestError as e:
        # Catch-all for other transport errors, consistent with _call_qwen_vision.
        raise HTTPException(status_code=502, detail=f"千问网络请求失败: {e!s}") from e
    _raise_for_qwen_http_error(resp, "千问请求失败")
    try:
        data = resp.json()
    except ValueError as exc:
        raise HTTPException(status_code=502, detail=f"千问返回非 JSON: {resp.text[:600]}") from exc
    return _openai_completion_assistant_text(data)
|
||
|
||
|
||
async def _call_qwen_vision(system_prompt: str, user_prompt: str, image_data_url: str) -> str:
    """Call the Qwen vision model with one image (as a data URL) plus a prompt.

    Returns the assistant's text. Without a configured API key, returns a
    local fallback message instead of raising; network failures are mapped
    to 502/504 HTTPExceptions.
    """
    api_key = _get_qwen_api_key()
    if not api_key:
        return (
            "当前未配置千问 API Key,无法执行 OCR。\n"
            "请在 .env 中配置 QWEN_API_KEY 后重试。"
        )
    base_url = _qwen_base_url()
    model = os.getenv("QWEN_VL_MODEL", "qwen-vl-plus")
    headers = {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"}
    # Matches the DashScope docs: image part before text helps multimodal routing.
    payload = {
        "model": model,
        "messages": [
            {"role": "system", "content": system_prompt},
            {
                "role": "user",
                "content": [
                    {"type": "image_url", "image_url": {"url": image_data_url}},
                    {"type": "text", "text": user_prompt},
                ],
            },
        ],
        "temperature": 0.2,
    }
    url = f"{base_url}/chat/completions"
    try:
        # Vision calls are slower than text-only ones: allow 60s overall.
        async with _qwen_http_client(60.0) as client:
            resp = await client.post(url, headers=headers, json=payload)
    except httpx.ConnectError as e:
        raise HTTPException(
            status_code=502,
            detail=(
                f"无法连接千问接口(OCR,{url})。请检查网络与 DNS;"
                "默认已忽略 HTTP(S)_PROXY,若需代理请设置 HTTPX_TRUST_ENV=1。"
                f" 原始错误: {e!s}"
            ),
        ) from e
    except httpx.TimeoutException as e:
        raise HTTPException(status_code=504, detail=f"OCR 请求超时: {e!s}") from e
    except httpx.RequestError as e:
        # Catch-all for other transport errors.
        raise HTTPException(status_code=502, detail=f"OCR 网络请求失败: {e!s}") from e
    _raise_for_qwen_http_error(resp, "OCR 请求失败")
    try:
        data = resp.json()
    except ValueError:
        raise HTTPException(status_code=502, detail=f"OCR 返回非 JSON: {resp.text[:600]}")
    return _openai_completion_assistant_text(data)
|
||
|
||
|
||
@app.get("/api/data/export")
def export_user_data(
    format: str = Query("zip", pattern="^(zip|json)$"),
    include_files: bool = True,
    db: Session = Depends(get_db),
):
    """Download a full backup as JSON, or as a ZIP that also bundles uploads.

    format="json" streams just the serialized data; format="zip" adds every
    /uploads file still referenced by a resource or mistake (when
    include_files is true).
    """
    payload = _dump_all_data(db)
    timestamp = datetime.utcnow().strftime("%Y%m%d_%H%M%S")

    if format == "json":
        buf = BytesIO(json.dumps(payload, ensure_ascii=False, indent=2).encode("utf-8"))
        return StreamingResponse(
            buf,
            media_type="application/json",
            headers={"Content-Disposition": f'attachment; filename="exam_helper_backup_{timestamp}.json"'},
        )

    # Collect only upload files actually referenced, so orphans are skipped.
    used_upload_files: set[str] = set()
    for item in payload["resources"]:
        name = _extract_upload_filename(item.get("url"))
        if name:
            used_upload_files.add(name)
    for item in payload["mistakes"]:
        name = _extract_upload_filename(item.get("image_url"))
        if name:
            used_upload_files.add(name)

    buf = BytesIO()
    with zipfile.ZipFile(buf, "w", zipfile.ZIP_DEFLATED) as zip_ref:
        zip_ref.writestr("data.json", json.dumps(payload, ensure_ascii=False, indent=2))
        if include_files:
            for file_name in sorted(used_upload_files):
                path = UPLOAD_DIR / file_name
                # Tolerate files deleted from disk since the DB row was written.
                if path.exists() and path.is_file():
                    zip_ref.write(path, arcname=f"uploads/{file_name}")
    buf.seek(0)
    return StreamingResponse(
        buf,
        media_type="application/zip",
        headers={"Content-Disposition": f'attachment; filename="exam_helper_backup_{timestamp}.zip"'},
    )
|
||
|
||
|
||
@app.post("/api/data/import")
async def import_user_data(
    file: UploadFile = File(...),
    mode: str = Query("merge", pattern="^(merge|replace)$"),
    db: Session = Depends(get_db),
):
    """Restore a backup produced by /api/data/export.

    Accepts either a bare .json file or a .zip containing data.json plus the
    referenced upload files. mode="replace" wipes all existing rows first;
    mode="merge" appends. Ids from the backup are NOT preserved — new rows
    get fresh ids.
    """
    content = await file.read()
    if not content:
        raise HTTPException(status_code=400, detail="导入文件为空")
    if len(content) > 100 * 1024 * 1024:
        raise HTTPException(status_code=400, detail="导入文件不能超过 100MB")

    suffix = Path(file.filename or "").suffix.lower()
    payload: dict
    zip_ref: zipfile.ZipFile | None = None  # set only for ZIP imports

    if suffix == ".json":
        try:
            payload = json.loads(content.decode("utf-8"))
        except (UnicodeDecodeError, json.JSONDecodeError) as exc:
            raise HTTPException(status_code=400, detail=f"JSON 解析失败: {exc}") from exc
    elif suffix == ".zip":
        try:
            zip_ref = zipfile.ZipFile(BytesIO(content))
        except zipfile.BadZipFile as exc:
            raise HTTPException(status_code=400, detail="ZIP 文件损坏或格式错误") from exc
        if "data.json" not in zip_ref.namelist():
            raise HTTPException(status_code=400, detail="ZIP 中缺少 data.json")
        try:
            payload = json.loads(zip_ref.read("data.json").decode("utf-8"))
        except (UnicodeDecodeError, json.JSONDecodeError) as exc:
            raise HTTPException(status_code=400, detail=f"data.json 解析失败: {exc}") from exc
    else:
        raise HTTPException(status_code=400, detail="仅支持 .json 或 .zip 导入")

    resources = payload.get("resources", [])
    mistakes = payload.get("mistakes", [])
    scores = payload.get("scores", [])
    if not isinstance(resources, list) or not isinstance(mistakes, list) or not isinstance(scores, list):
        raise HTTPException(status_code=400, detail="导入文件结构错误")

    if mode == "replace":
        # Row-by-row deletion through the ORM — NOTE(review): presumably so
        # ORM-level delete behaviour applies rather than a bulk DELETE; confirm.
        for item in db.scalars(select(Resource)).all():
            db.delete(item)
        for item in db.scalars(select(Mistake)).all():
            db.delete(item)
        for item in db.scalars(select(ScoreRecord)).all():
            db.delete(item)
        db.commit()

    imported = {"resources": 0, "mistakes": 0, "scores": 0}

    for item in resources:
        url = item.get("url")
        if zip_ref is not None:
            # Re-extract the referenced upload file from the ZIP if bundled.
            url = _restore_upload_url_from_zip(url, zip_ref)
        obj = Resource(
            title=item.get("title") or "未命名资源",
            # Unknown types degrade to "link".
            resource_type=item.get("resource_type") if item.get("resource_type") in {"link", "file"} else "link",
            url=url,
            file_name=item.get("file_name"),
            category=item.get("category") or "未分类",
            tags=item.get("tags"),
            created_at=_safe_datetime(item.get("created_at")),
        )
        db.add(obj)
        imported["resources"] += 1

    for item in mistakes:
        image_url = item.get("image_url")
        if zip_ref is not None:
            image_url = _restore_upload_url_from_zip(image_url, zip_ref)
        difficulty = item.get("difficulty")
        obj = Mistake(
            title=item.get("title") or "未命名错题",
            image_url=image_url,
            category=item.get("category") or "其他",
            # Unrecognized difficulty values are dropped rather than rejected.
            difficulty=difficulty if difficulty in {"easy", "medium", "hard"} else None,
            question_content=item.get("question_content"),
            answer=item.get("answer"),
            explanation=item.get("explanation"),
            note=item.get("note"),
            # Clamp to at least 1.
            wrong_count=max(int(item.get("wrong_count") or 1), 1),
            created_at=_safe_datetime(item.get("created_at")),
        )
        db.add(obj)
        imported["mistakes"] += 1

    for item in scores:
        score = float(item.get("total_score") or 0)
        obj = ScoreRecord(
            exam_name=item.get("exam_name") or "未命名考试",
            exam_date=_safe_date(item.get("exam_date")),
            # Clamp into the 0–200 range.
            total_score=max(min(score, 200), 0),
            module_scores=item.get("module_scores"),
            created_at=_safe_datetime(item.get("created_at")),
        )
        db.add(obj)
        imported["scores"] += 1

    db.commit()
    return {"success": True, "mode": mode, "imported": imported}
|
||
|
||
|
||
@app.post("/api/ocr/parse", response_model=OcrParseOut)
async def parse_ocr(payload: OcrParseIn):
    """OCR an uploaded question image into structured fields via the Qwen vision model.

    Accepts only images under the local /uploads directory, encodes them as a
    data URL, asks the vision model for strict JSON, and falls back to the raw
    model text when the reply is not a JSON object.

    Raises:
        HTTPException 400: URL is not an /uploads path, or unsupported image type.
        HTTPException 404: the referenced file no longer exists on disk.
    """
    file_name = _extract_upload_filename(payload.image_url)
    if not file_name:
        raise HTTPException(status_code=400, detail="仅支持 /uploads 下的图片做 OCR")
    target = UPLOAD_DIR / file_name
    if not target.exists() or not target.is_file():
        raise HTTPException(status_code=404, detail="图片不存在或已删除")

    suffix = target.suffix.lower()
    mime = {
        ".jpg": "image/jpeg",
        ".jpeg": "image/jpeg",
        ".png": "image/png",
        ".webp": "image/webp",
        ".heic": "image/heic",
        ".heif": "image/heif",
    }.get(suffix)
    if not mime:
        raise HTTPException(status_code=400, detail="仅支持 JPG/PNG/WebP/HEIC OCR")

    # Inline the image as a base64 data URL so no file hosting is needed.
    b64 = base64.b64encode(target.read_bytes()).decode("utf-8")
    image_data_url = f"data:{mime};base64,{b64}"
    ocr_prompt = (
        "请识别图片中的题目,返回严格 JSON。"
        "字段说明:text 为整题完整纯文本(含材料、提问句、全部选项);"
        "question_content 必须与 text 一致地表示「完整题干」,须包含阅读材料、填空/提问句、所有选项(A B C D 等),"
        "禁止只填写「依次填入…」等短提示句而省略材料和选项。"
        "另含 title_suggestion、category_suggestion、difficulty_suggestion、answer、explanation。"
        "无法确认的字段可填空字符串。"
    )
    if payload.prompt:
        ocr_prompt = f"{ocr_prompt}\n补充要求:{payload.prompt}"

    raw_text = await _call_qwen_vision(
        "你是题目 OCR 与结构化助手。输出必须是 JSON,不要额外解释。",
        ocr_prompt,
        image_data_url,
    )

    def _fallback_payload() -> dict[str, Any]:
        # Single source of truth for the "model did not return a JSON object"
        # shape (previously duplicated in two branches, risking drift).
        stripped = raw_text.strip()
        return {
            "text": stripped,
            "title_suggestion": None,
            "category_suggestion": None,
            "difficulty_suggestion": None,
            "question_content": stripped,
            "answer": "",
            "explanation": "",
        }

    try:
        parsed = json.loads(_extract_json_text(raw_text))
        data = parsed if isinstance(parsed, dict) else _fallback_payload()
    except json.JSONDecodeError:
        data = _fallback_payload()

    def _opt_str(val: Any) -> str | None:
        # Normalize a model-provided field to a stripped string, or None when
        # absent/empty/structured (dicts and lists are not usable as text).
        if val is None:
            return None
        if isinstance(val, (dict, list)):
            return None
        s = str(val).strip()
        return s if s else None

    def _merge_question_body(text_raw: str, qc_raw: str | None) -> str | None:
        """The model often puts the full text in `text` but only a short prompt
        sentence in `question_content`; prefer the longer, more complete one."""
        t = (text_raw or "").strip()
        q = (qc_raw or "").strip()
        if not t and not q:
            return None
        if not q:
            return t or None
        if not t:
            return q or None
        if len(t) > len(q):
            return t
        if len(q) > len(t):
            return q
        # Equal length: if one contains the other take the container; otherwise
        # keep `text` (full-page OCR is usually more complete).
        if t in q:
            return q
        if q in t:
            return t
        return t

    text_out = str(data.get("text", "") or "").strip()
    qc_model = _opt_str(data.get("question_content"))
    question_merged = _merge_question_body(text_out, qc_model)

    return OcrParseOut(
        text=text_out,
        title_suggestion=_opt_str(data.get("title_suggestion")),
        category_suggestion=_opt_str(data.get("category_suggestion")),
        difficulty_suggestion=_opt_str(data.get("difficulty_suggestion")),
        question_content=_opt_str(question_merged),
        answer=_opt_str(data.get("answer")),
        explanation=_opt_str(data.get("explanation")),
    )
|
||
|
||
|
||
@app.post("/api/ai/mistakes/{item_id}/analyze", response_model=AiMistakeAnalysisOut)
async def ai_analyze_mistake(item_id: int, db: Session = Depends(get_db)):
    """Ask the AI model for a structured, actionable review of one stored mistake."""
    record = db.get(Mistake, item_id)
    if record is None:
        raise HTTPException(status_code=404, detail="Mistake not found")

    # Assemble the user prompt line by line from the record's fields.
    prompt_lines = [
        f"错题标题: {record.title}",
        f"分类: {record.category}",
        f"难度: {record.difficulty or '未设置'}",
        f"题目内容: {record.question_content or '无'}",
        f"答案: {record.answer or '无'}",
        f"解析: {record.explanation or '无'}",
        f"错误频次: {record.wrong_count}",
        f"备注: {record.note or '无'}",
        "",
        "请按以下结构输出:",
        "1) 错误根因\n2) 关键知识点\n3) 3步复盘法\n4) 3道同类训练建议\n5) 明日复习安排",
    ]
    analysis = await _call_qwen(
        "你是学习教练,请输出结构化、可执行的错题分析。",
        "\n".join(prompt_lines),
    )
    return AiMistakeAnalysisOut(analysis=analysis)
|
||
|
||
|
||
@app.post("/api/ai/study-plan", response_model=AiStudyPlanOut)
async def ai_study_plan(payload: AiStudyPlanIn, db: Session = Depends(get_db)):
    """Generate a personalized study plan from the last 30 days of scores and recent mistakes."""
    window_start = date.today() - timedelta(days=30)
    recent_score_rows = db.scalars(
        select(ScoreRecord)
        .where(ScoreRecord.exam_date >= window_start)
        .order_by(asc(ScoreRecord.exam_date))
    ).all()
    recent_mistake_rows = db.scalars(
        select(Mistake).order_by(desc(Mistake.created_at)).limit(20)
    ).all()

    # Flatten both datasets into short comma-joined summaries for the prompt.
    score_text = ", ".join(
        f"{row.exam_date}:{row.total_score}" for row in recent_score_rows
    ) or "暂无成绩数据"
    mistake_text = ", ".join(
        f"{row.category}-{row.title}" for row in recent_mistake_rows
    ) or "暂无错题数据"

    user_prompt = (
        f"目标: {payload.goal}\n"
        f"剩余天数: {payload.days_left}\n"
        f"每天可学习小时数: {payload.daily_hours}\n"
        f"近30天分数: {score_text}\n"
        f"近期错题: {mistake_text}\n\n"
        "请输出: 周计划表、每日任务模板、错题复盘节奏、模考安排、风险提醒。"
    )
    plan = await _call_qwen("你是学习规划师,请给出可执行计划并尽量量化。", user_prompt)
    return AiStudyPlanOut(plan=plan)
|