"""Study-companion ("学习伙伴") FastAPI backend.

Serves CRUD APIs for study resources, mistake notes and exam scores, mistake
export (PDF/DOCX), full-data backup/restore (JSON/ZIP), image OCR and
Qwen-backed AI helpers. Importing this module creates missing tables and
applies lightweight column migrations as a side effect.
"""

import base64
import json
import math
import os
import uuid
import zipfile
from datetime import date, datetime, timedelta, timezone
from io import BytesIO
from pathlib import Path
from typing import Any

import httpx
from docx import Document
from fastapi import Depends, FastAPI, File, HTTPException, Query, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from fastapi.staticfiles import StaticFiles
from reportlab.lib.pagesizes import A4
from reportlab.pdfbase import pdfmetrics
from reportlab.pdfbase.cidfonts import UnicodeCIDFont
from reportlab.pdfgen import canvas
from sqlalchemy import asc, desc, func, inspect, or_, select, text
from sqlalchemy.orm import Session

from .database import Base, engine, get_db
from .models import Mistake, Resource, ScoreRecord
from .schemas import (
    AiMistakeAnalysisOut,
    AiStudyPlanIn,
    AiStudyPlanOut,
    IdBatchPayload,
    MistakeCreate,
    MistakeOut,
    MistakeUpdate,
    OcrParseIn,
    OcrParseOut,
    ResourceBatchUpdate,
    ResourceCreate,
    ResourceOut,
    ResourceUpdate,
    ScoreRecordCreate,
    ScoreRecordOut,
    ScoreRecordUpdate,
    ScoreStats,
)

# Size / count limits enforced by the upload, import and export endpoints.
MAX_UPLOAD_BYTES = 50 * 1024 * 1024
MAX_IMPORT_BYTES = 100 * 1024 * 1024
MAX_EXPORT_MISTAKES = 200

Base.metadata.create_all(bind=engine)


def _migrate_mistake_columns() -> None:
    """Add the optional text columns on ``mistakes`` that older databases lack.

    ``create_all`` only creates missing tables, so columns added to the model
    later are appended here with ``ALTER TABLE`` when absent.
    """
    inspector = inspect(engine)
    if "mistakes" not in inspector.get_table_names():
        return
    existed = {col["name"] for col in inspector.get_columns("mistakes")}
    with engine.begin() as conn:
        # Column names come from this fixed tuple, never from user input.
        for column in ("question_content", "answer", "explanation"):
            if column not in existed:
                conn.execute(text(f"ALTER TABLE mistakes ADD COLUMN {column} TEXT"))


_migrate_mistake_columns()

app = FastAPI(title="学习伙伴 API", version="1.0.0")

UPLOAD_DIR = Path(os.getenv("UPLOAD_DIR", "/app/uploads"))
UPLOAD_DIR.mkdir(parents=True, exist_ok=True)
app.mount("/uploads", StaticFiles(directory=str(UPLOAD_DIR)), name="uploads")

# Comma-separated list of allowed browser origins; blank entries are ignored.
origins = os.getenv("CORS_ORIGINS", "http://localhost:5173").split(",")
app.add_middleware(
    CORSMiddleware,
    allow_origins=[origin.strip() for origin in origins if origin.strip()],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


@app.get("/health")
def health():
    """Liveness probe."""
    return {"status": "ok"}


# ---------------------------------------------------------------------------
# Generic helpers
# ---------------------------------------------------------------------------


def _utcnow() -> datetime:
    """Return the current UTC time as a naive datetime.

    Same value the deprecated ``datetime.utcnow()`` produced.
    """
    return datetime.now(timezone.utc).replace(tzinfo=None)


def _validate_score_date(exam_date: date) -> None:
    """Reject exam dates in the future with HTTP 400."""
    if exam_date > date.today():
        raise HTTPException(status_code=400, detail="考试时间不能晚于今天")


def _parse_id_list(ids: str | None) -> list[int] | None:
    """Parse a comma-separated ``ids`` query value.

    Returns None when ``ids`` is absent or empty, meaning "no id filter".
    Tokens that are not plain decimal integers are ignored.
    ``isdecimal`` is used rather than ``isdigit`` because characters such as
    "²" satisfy ``isdigit`` yet make ``int()`` raise.
    """
    if not ids:
        return None
    return [int(token) for token in ids.split(",") if token.strip().isdecimal()]


def _query_mistakes_for_export(
    db: Session,
    category: str | None,
    start_date: date | None,
    end_date: date | None,
    ids: list[int] | None = None,
):
    """Select mistakes to export, newest first.

    Args:
        db: Active session.
        category: Exact category match, if given.
        start_date / end_date: Inclusive bounds on ``created_at`` (whole days).
        ids: When not None, restrict to these ids. An empty list matches
            nothing rather than silently exporting everything.

    Raises:
        HTTPException(400): more than ``MAX_EXPORT_MISTAKES`` rows match.
    """
    stmt = select(Mistake)
    if category:
        stmt = stmt.where(Mistake.category == category)
    if start_date:
        stmt = stmt.where(Mistake.created_at >= datetime.combine(start_date, datetime.min.time()))
    if end_date:
        stmt = stmt.where(Mistake.created_at <= datetime.combine(end_date, datetime.max.time()))
    if ids is not None:
        stmt = stmt.where(Mistake.id.in_(ids))
    # Fetching one row past the cap is enough to detect an oversized export.
    stmt = stmt.order_by(desc(Mistake.created_at)).limit(MAX_EXPORT_MISTAKES + 1)
    items = db.scalars(stmt).all()
    if len(items) > MAX_EXPORT_MISTAKES:
        raise HTTPException(status_code=400, detail=f"单次最多导出 {MAX_EXPORT_MISTAKES} 题")
    return items


def _validate_mistake_payload(payload: MistakeCreate | MistakeUpdate) -> None:
    """Require at least one of image / question / answer to be non-blank (HTTP 400 otherwise)."""
    has_image = bool((payload.image_url or "").strip())
    has_question = bool((payload.question_content or "").strip())
    has_answer = bool((payload.answer or "").strip())
    if not has_image and not has_question and not has_answer:
        raise HTTPException(status_code=400, detail="请上传题目图片或填写试题/答案后再保存")


def _normalize_multiline_text(value: str | None) -> str:
    """Normalize newlines, strip each line and drop blank lines."""
    if not value:
        return ""
    text_value = value.replace("\r\n", "\n").replace("\r", "\n")
    lines = [line.strip() for line in text_value.split("\n")]
    compact = [line for line in lines if line]
    return "\n".join(compact).strip()


def _wrap_pdf_text(text: str, max_width: float, font_name: str = "STSong-Light", font_size: int = 12) -> list[str]:
    """Greedily wrap ``text`` so every output line fits within ``max_width`` points.

    Wrapping is per character rather than per word, which suits CJK text that
    has no spaces. Blank lines are removed by ``_normalize_multiline_text``.
    """
    normalized = _normalize_multiline_text(text)
    if not normalized:
        return []
    wrapped: list[str] = []
    for raw_line in normalized.split("\n"):
        current = ""
        for ch in raw_line:
            candidate = f"{current}{ch}"
            if pdfmetrics.stringWidth(candidate, font_name, font_size) <= max_width:
                current = candidate
            else:
                if current:
                    wrapped.append(current)
                current = ch
        if current:
            wrapped.append(current)
    return wrapped


def _mistake_export_blocks(item: Mistake, content_mode: str) -> list[str]:
    """Build the text blocks exported for one mistake.

    Returns the question block, then the answer and explanation blocks when
    ``content_mode == "full"``.
    """
    question = _normalize_multiline_text(item.question_content)
    answer = _normalize_multiline_text(item.answer)
    explanation = _normalize_multiline_text(item.explanation)
    if not question:
        question = "无题干与选项内容"
    blocks: list[str] = [question]
    if content_mode == "full":
        # The "答案:" / "解析:" labels share a line with their text so a label
        # never ends up alone on a line (same rule as number + question stem).
        blocks.append(f"答案: {answer or '无'}")
        blocks.append(f"解析: {explanation or '无'}")
    return blocks


def _extract_upload_filename(url: str | None) -> str | None:
    """Return the bare file name of a ``/uploads/...`` URL, or None for anything else.

    Only the final path component is kept, so the result never points outside
    ``UPLOAD_DIR``. Non-string values (possible in imported JSON) yield None.
    """
    if not isinstance(url, str) or not url.startswith("/uploads/"):
        return None
    return Path(url).name


def _safe_datetime(value: str | None) -> datetime:
    """Parse an ISO-8601 timestamp into a naive UTC datetime; fall back to now.

    Timezone-aware inputs, including a trailing "Z", are converted to UTC
    before the tzinfo is dropped. Naive inputs are kept as-is.
    NOTE(review): assumes stored naive datetimes are UTC, matching the
    ``_utcnow`` fallback used here. Confirm against the model defaults.
    """
    if not value or not isinstance(value, str):
        return _utcnow()
    try:
        parsed = datetime.fromisoformat(value.replace("Z", "+00:00"))
    except ValueError:
        return _utcnow()
    if parsed.tzinfo is not None:
        parsed = parsed.astimezone(timezone.utc)
    return parsed.replace(tzinfo=None)


def _safe_date(value: str | None) -> date:
    """Parse an ISO date (YYYY-MM-DD); fall back to today on missing or bad input."""
    if not value or not isinstance(value, str):
        return date.today()
    try:
        return date.fromisoformat(value)
    except ValueError:
        return date.today()


def _safe_int(value: Any, default: int) -> int:
    """``int(value)``, or ``default`` when the value cannot be converted."""
    try:
        return int(value)
    except (TypeError, ValueError, OverflowError):
        return default


def _safe_float(value: Any, default: float) -> float:
    """``float(value)``, or ``default`` when unconvertible or not finite (NaN/inf)."""
    try:
        number = float(value)
    except (TypeError, ValueError):
        return default
    return number if math.isfinite(number) else default


def _extract_json_text(raw_text: str) -> str:
    """Strip a surrounding Markdown code fence (e.g. ```json ... ```) from model output."""
    content = raw_text.strip()
    if content.startswith("```"):
        lines = content.splitlines()
        if lines:
            lines = lines[1:]
        if lines and lines[-1].strip() == "```":
            lines = lines[:-1]
        content = "\n".join(lines).strip()
    return content


def _dump_all_data(db: Session) -> dict:
    """Serialize all resources, mistakes and scores into the backup payload."""
    resources = db.scalars(select(Resource).order_by(asc(Resource.id))).all()
    mistakes = db.scalars(select(Mistake).order_by(asc(Mistake.id))).all()
    scores = db.scalars(select(ScoreRecord).order_by(asc(ScoreRecord.id))).all()
    return {
        "meta": {
            "exported_at": _utcnow().isoformat(),
            "version": "1.1.0",
        },
        "resources": [
            {
                "id": item.id,
                "title": item.title,
                "resource_type": item.resource_type,
                "url": item.url,
                "file_name": item.file_name,
                "category": item.category,
                "tags": item.tags,
                "created_at": item.created_at.isoformat() if item.created_at else None,
            }
            for item in resources
        ],
        "mistakes": [
            {
                "id": item.id,
                "title": item.title,
                "image_url": item.image_url,
                "category": item.category,
                "difficulty": item.difficulty,
                "question_content": item.question_content,
                "answer": item.answer,
                "explanation": item.explanation,
                "note": item.note,
                "wrong_count": item.wrong_count,
                "created_at": item.created_at.isoformat() if item.created_at else None,
            }
            for item in mistakes
        ],
        "scores": [
            {
                "id": item.id,
                "exam_name": item.exam_name,
                "exam_date": item.exam_date.isoformat() if item.exam_date else None,
                "total_score": item.total_score,
                "module_scores": item.module_scores,
                "created_at": item.created_at.isoformat() if item.created_at else None,
            }
            for item in scores
        ],
    }


def _restore_upload_url_from_zip(url: str | None, zip_ref: zipfile.ZipFile) -> str | None:
    """Copy ``uploads/<name>`` from a backup archive into UPLOAD_DIR and return its new URL.

    URLs that are not ``/uploads/...``, or whose file is missing from the
    archive, are returned unchanged. An existing file with the same name is
    never overwritten; a uuid prefix is added instead.
    """
    if not url:
        return None
    file_name = _extract_upload_filename(url)
    if not file_name:
        return url
    zip_path = f"uploads/{file_name}"
    if zip_path not in zip_ref.namelist():
        return url
    data = zip_ref.read(zip_path)
    target_name = file_name
    target_path = UPLOAD_DIR / target_name
    if target_path.exists():
        target_name = f"{uuid.uuid4().hex}_{file_name}"
        target_path = UPLOAD_DIR / target_name
    target_path.write_bytes(data)
    return f"/uploads/{target_name}"


# ---------------------------------------------------------------------------
# Uploads
# ---------------------------------------------------------------------------


@app.post("/api/upload")
async def upload_file(file: UploadFile = File(...)):
    """Store an uploaded document or image under a random name and return its URL.

    The extension comes from the filename, or from the MIME type when the
    extension is not allowed.

    Raises:
        HTTPException(400): unsupported type, or larger than 50 MB.
    """
    suffix = Path(file.filename or "").suffix.lower()
    allowed = {".pdf", ".doc", ".docx", ".jpg", ".jpeg", ".png", ".webp", ".heic", ".heif"}
    mime_to_suffix = {
        "application/pdf": ".pdf",
        "application/msword": ".doc",
        "application/vnd.openxmlformats-officedocument.wordprocessingml.document": ".docx",
        "image/jpeg": ".jpg",
        "image/png": ".png",
        "image/webp": ".webp",
        "image/heic": ".heic",
        "image/heif": ".heif",
    }
    if suffix not in allowed:
        guessed = mime_to_suffix.get((file.content_type or "").lower())
        if not guessed:
            raise HTTPException(status_code=400, detail="不支持的文件类型")
        suffix = guessed
    # Read at most one byte past the limit so oversized files are rejected
    # without being buffered entirely in memory.
    content = await file.read(MAX_UPLOAD_BYTES + 1)
    if len(content) > MAX_UPLOAD_BYTES:
        raise HTTPException(status_code=400, detail="文件不能超过 50MB")
    file_name = f"{uuid.uuid4().hex}{suffix}"
    target = UPLOAD_DIR / file_name
    target.write_bytes(content)
    return {"url": f"/uploads/{file_name}", "original_name": file.filename}


# ---------------------------------------------------------------------------
# Resources
# ---------------------------------------------------------------------------


@app.get("/api/resources", response_model=list[ResourceOut])
def list_resources(
    q: str | None = None,
    category: str | None = None,
    tags: str | None = None,
    resource_type: str | None = None,
    sort_by: str = Query("created_at", pattern="^(created_at|title|name)$"),
    order: str = Query("desc", pattern="^(asc|desc)$"),
    db: Session = Depends(get_db),
):
    """List resources with optional search and filters.

    ``q`` matches title, tags or url case-insensitively. ``sort_by=name`` sorts
    by title.
    """
    stmt = select(Resource)
    if q:
        stmt = stmt.where(
            or_(
                Resource.title.ilike(f"%{q}%"),
                Resource.tags.ilike(f"%{q}%"),
                Resource.url.ilike(f"%{q}%"),
            )
        )
    if category:
        stmt = stmt.where(Resource.category == category)
    if tags:
        stmt = stmt.where(Resource.tags.ilike(f"%{tags}%"))
    if resource_type:
        stmt = stmt.where(Resource.resource_type == resource_type)
    sort_col = Resource.created_at if sort_by == "created_at" else Resource.title
    stmt = stmt.order_by(desc(sort_col) if order == "desc" else asc(sort_col))
    return db.scalars(stmt).all()


@app.post("/api/resources", response_model=ResourceOut)
def create_resource(payload: ResourceCreate, db: Session = Depends(get_db)):
    """Create a resource."""
    item = Resource(**payload.model_dump())
    db.add(item)
    db.commit()
    db.refresh(item)
    return item


@app.put("/api/resources/{item_id}", response_model=ResourceOut)
def update_resource(item_id: int, payload: ResourceUpdate, db: Session = Depends(get_db)):
    """Overwrite every field of a resource from the payload (404 if missing)."""
    item = db.get(Resource, item_id)
    if not item:
        raise HTTPException(status_code=404, detail="Resource not found")
    for k, v in payload.model_dump().items():
        setattr(item, k, v)
    db.commit()
    db.refresh(item)
    return item


@app.delete("/api/resources/{item_id}")
def delete_resource(item_id: int, db: Session = Depends(get_db)):
    """Delete a resource (404 if missing)."""
    item = db.get(Resource, item_id)
    if not item:
        raise HTTPException(status_code=404, detail="Resource not found")
    db.delete(item)
    db.commit()
    return {"success": True}


@app.patch("/api/resources/batch")
def batch_update_resources(payload: ResourceBatchUpdate, db: Session = Depends(get_db)):
    """Set category and/or tags on many resources. Fields left as None are untouched."""
    if not payload.ids:
        raise HTTPException(status_code=400, detail="ids 不能为空")
    items = db.scalars(select(Resource).where(Resource.id.in_(payload.ids))).all()
    for item in items:
        if payload.category is not None:
            item.category = payload.category
        if payload.tags is not None:
            item.tags = payload.tags
    db.commit()
    return {"success": True, "count": len(items)}


@app.post("/api/resources/batch-delete")
def batch_delete_resources(payload: IdBatchPayload, db: Session = Depends(get_db)):
    """Delete the resources whose ids are listed; unknown ids are ignored."""
    if not payload.ids:
        raise HTTPException(status_code=400, detail="ids 不能为空")
    items = db.scalars(select(Resource).where(Resource.id.in_(payload.ids))).all()
    for item in items:
        db.delete(item)
    db.commit()
    return {"success": True, "count": len(items)}


# ---------------------------------------------------------------------------
# Mistakes
# ---------------------------------------------------------------------------


@app.get("/api/mistakes", response_model=list[MistakeOut])
def list_mistakes(
    category: str | None = None,
    keyword: str | None = None,
    sort_by: str = Query("created_at", pattern="^(created_at|wrong_count)$"),
    order: str = Query("desc", pattern="^(asc|desc)$"),
    db: Session = Depends(get_db),
):
    """List mistakes.

    ``keyword`` matches note, title, question, answer, explanation or image URL
    case-insensitively.
    """
    stmt = select(Mistake)
    if category:
        stmt = stmt.where(Mistake.category == category)
    if keyword:
        stmt = stmt.where(
            or_(
                Mistake.note.ilike(f"%{keyword}%"),
                Mistake.title.ilike(f"%{keyword}%"),
                Mistake.question_content.ilike(f"%{keyword}%"),
                Mistake.answer.ilike(f"%{keyword}%"),
                Mistake.explanation.ilike(f"%{keyword}%"),
                Mistake.image_url.ilike(f"%{keyword}%"),
            )
        )
    sort_col = Mistake.created_at if sort_by == "created_at" else Mistake.wrong_count
    stmt = stmt.order_by(desc(sort_col) if order == "desc" else asc(sort_col))
    return db.scalars(stmt).all()


@app.post("/api/mistakes", response_model=MistakeOut)
def create_mistake(payload: MistakeCreate, db: Session = Depends(get_db)):
    """Create a mistake; requires an image, question or answer."""
    _validate_mistake_payload(payload)
    item = Mistake(**payload.model_dump())
    db.add(item)
    db.commit()
    db.refresh(item)
    return item


@app.put("/api/mistakes/{item_id}", response_model=MistakeOut)
def update_mistake(item_id: int, payload: MistakeUpdate, db: Session = Depends(get_db)):
    """Overwrite every field of a mistake (400 on empty content, 404 if missing)."""
    _validate_mistake_payload(payload)
    item = db.get(Mistake, item_id)
    if not item:
        raise HTTPException(status_code=404, detail="Mistake not found")
    for k, v in payload.model_dump().items():
        setattr(item, k, v)
    db.commit()
    db.refresh(item)
    return item


@app.delete("/api/mistakes/{item_id}")
def delete_mistake(item_id: int, db: Session = Depends(get_db)):
    """Delete a mistake (404 if missing)."""
    item = db.get(Mistake, item_id)
    if not item:
        raise HTTPException(status_code=404, detail="Mistake not found")
    db.delete(item)
    db.commit()
    return {"success": True}


@app.get("/api/mistakes/export/pdf")
def export_mistakes_pdf(
    category: str | None = None,
    start_date: date | None = None,
    end_date: date | None = None,
    ids: str | None = None,
    content_mode: str = Query("full", pattern="^(full|question_only)$"),
    db: Session = Depends(get_db),
):
    """Export the selected mistakes as an A4 PDF using the STSong-Light CID font."""
    items = _query_mistakes_for_export(db, category, start_date, end_date, _parse_id_list(ids))

    buf = BytesIO()
    pdf = canvas.Canvas(buf, pagesize=A4)
    pdfmetrics.registerFont(UnicodeCIDFont("STSong-Light"))
    pdf.setFont("STSong-Light", 12)
    y = 800
    left = 48
    right = 560
    max_width = right - left
    pdf.drawString(left, y, "学习伙伴 - 错题导出")
    y -= 28
    for idx, item in enumerate(items, start=1):
        if y < 90:
            pdf.showPage()
            pdf.setFont("STSong-Light", 12)
            y = 800
        blocks = _mistake_export_blocks(item, content_mode)
        for bi, block in enumerate(blocks):
            # The question number shares the stem's first line so "1." never
            # sits alone on a line.
            block_text = f"{idx}. {block}" if bi == 0 else block
            lines = _wrap_pdf_text(block_text, max_width=max_width)
            if not lines:
                continue
            for line in lines:
                if y < 70:
                    pdf.showPage()
                    pdf.setFont("STSong-Light", 12)
                    y = 800
                pdf.drawString(left, y, line)
                y -= 18
            y -= 6
        y -= 8
    pdf.save()
    buf.seek(0)
    return StreamingResponse(
        buf,
        media_type="application/pdf",
        headers={"Content-Disposition": 'attachment; filename="mistakes.pdf"'},
    )


@app.get("/api/mistakes/export/docx")
def export_mistakes_docx(
    category: str | None = None,
    start_date: date | None = None,
    end_date: date | None = None,
    ids: str | None = None,
    content_mode: str = Query("full", pattern="^(full|question_only)$"),
    db: Session = Depends(get_db),
):
    """Export the selected mistakes as a Word (.docx) document."""
    items = _query_mistakes_for_export(db, category, start_date, end_date, _parse_id_list(ids))
    doc = Document()
    doc.add_heading("学习伙伴 - 错题导出", level=1)
    for idx, item in enumerate(items, start=1):
        blocks = _mistake_export_blocks(item, content_mode)
        for bi, block in enumerate(blocks):
            # The question number shares the stem's paragraph so "1." never
            # sits alone on a line.
            para = f"{idx}. {block}" if bi == 0 else block
            doc.add_paragraph(para)
    buf = BytesIO()
    doc.save(buf)
    buf.seek(0)
    return StreamingResponse(
        buf,
        media_type="application/vnd.openxmlformats-officedocument.wordprocessingml.document",
        headers={"Content-Disposition": 'attachment; filename="mistakes.docx"'},
    )


# ---------------------------------------------------------------------------
# Scores
# ---------------------------------------------------------------------------


@app.get("/api/scores", response_model=list[ScoreRecordOut])
def list_scores(
    start_date: date | None = None,
    end_date: date | None = None,
    db: Session = Depends(get_db),
):
    """List score records (inclusive date range), oldest exam first."""
    stmt = select(ScoreRecord)
    if start_date:
        stmt = stmt.where(ScoreRecord.exam_date >= start_date)
    if end_date:
        stmt = stmt.where(ScoreRecord.exam_date <= end_date)
    stmt = stmt.order_by(asc(ScoreRecord.exam_date))
    return db.scalars(stmt).all()


@app.post("/api/scores", response_model=ScoreRecordOut)
def create_score(payload: ScoreRecordCreate, db: Session = Depends(get_db)):
    """Create a score record; future exam dates are rejected."""
    _validate_score_date(payload.exam_date)
    item = ScoreRecord(**payload.model_dump())
    db.add(item)
    db.commit()
    db.refresh(item)
    return item


@app.put("/api/scores/{item_id}", response_model=ScoreRecordOut)
def update_score(item_id: int, payload: ScoreRecordUpdate, db: Session = Depends(get_db)):
    """Overwrite every field of a score record (400 on future date, 404 if missing)."""
    _validate_score_date(payload.exam_date)
    item = db.get(ScoreRecord, item_id)
    if not item:
        raise HTTPException(status_code=404, detail="Score record not found")
    for k, v in payload.model_dump().items():
        setattr(item, k, v)
    db.commit()
    db.refresh(item)
    return item


@app.delete("/api/scores/{item_id}")
def delete_score(item_id: int, db: Session = Depends(get_db)):
    """Delete a score record (404 if missing)."""
    item = db.get(ScoreRecord, item_id)
    if not item:
        raise HTTPException(status_code=404, detail="Score record not found")
    db.delete(item)
    db.commit()
    return {"success": True}


@app.get("/api/scores/stats", response_model=ScoreStats)
def score_stats(db: Session = Depends(get_db)):
    """Return highest, lowest and average scores, and latest-minus-earliest improvement."""
    scores = db.scalars(select(ScoreRecord).order_by(asc(ScoreRecord.exam_date))).all()
    if not scores:
        return ScoreStats(highest=0, lowest=0, average=0, improvement=0)
    highest = max(item.total_score for item in scores)
    lowest = min(item.total_score for item in scores)
    avg = db.scalar(select(func.avg(ScoreRecord.total_score))) or 0
    improvement = scores[-1].total_score - scores[0].total_score
    return ScoreStats(highest=highest, lowest=lowest, average=round(float(avg), 2), improvement=improvement)


# ---------------------------------------------------------------------------
# Qwen (DashScope, OpenAI-compatible) client helpers
# ---------------------------------------------------------------------------


def _qwen_base_url() -> str:
    """Base URL of the OpenAI-compatible endpoint, without a trailing slash."""
    return os.getenv("QWEN_BASE_URL", "https://dashscope.aliyuncs.com/compatible-mode/v1").strip().rstrip("/")


def _get_qwen_api_key() -> str:
    """Read QWEN_API_KEY, stripping whitespace and stray quotes.

    A .env value written as 'sk-xxx' (with quotes) still authenticates.
    """
    raw = os.getenv("QWEN_API_KEY", "") or ""
    return raw.strip().strip('"').strip("'").strip()


def _raise_for_qwen_http_error(resp: httpx.Response, prefix: str) -> None:
    """Raise an HTTPException for a non-2xx DashScope response; do nothing otherwise.

    An ``invalid_api_key`` error maps to 401 with setup instructions. Other
    errors map to 502 with the upstream message, or a truncated raw body when
    the body is not a JSON object carrying an ``error`` object.
    """
    if resp.status_code < 300:
        return
    body = resp.text
    try:
        data = resp.json()
    except ValueError:
        data = None
    # Guard against JSON bodies that are lists/strings rather than objects.
    err = data.get("error") if isinstance(data, dict) else None
    if isinstance(err, dict):
        code = str(err.get("code") or "")
        if code == "invalid_api_key":
            raise HTTPException(
                status_code=401,
                detail=(
                    "阿里云 DashScope API Key 无效或未生效。"
                    "请到阿里云百炼 / Model Studio 控制台创建 API Key(通常以 sk- 开头),"
                    "写入项目根目录 .env 的 QWEN_API_KEY=,勿加引号;"
                    "修改后执行: docker compose up -d --build backend"
                ),
            )
        msg = err.get("message") or body
        raise HTTPException(status_code=502, detail=f"{prefix}: {msg}")
    raise HTTPException(status_code=502, detail=f"{prefix}: {body[:1200]}")


def _httpx_trust_env() -> bool:
    """Whether httpx should honour proxy environment variables.

    Off by default because Docker/IDE environments may inject empty proxies
    that cause ConnectError. Set HTTPX_TRUST_ENV=1 to use system proxies.
    """
    return os.getenv("HTTPX_TRUST_ENV", "0").lower() in ("1", "true", "yes")


def _qwen_http_client(timeout_sec: float = 60.0) -> httpx.AsyncClient:
    """Create an AsyncClient with the given total timeout and a 20 s connect timeout."""
    return httpx.AsyncClient(
        timeout=httpx.Timeout(timeout_sec, connect=20.0),
        trust_env=_httpx_trust_env(),
        limits=httpx.Limits(max_keepalive_connections=5, max_connections=10),
    )


def _message_content_to_str(content: Any) -> str:
    """Flatten ``message.content``, which may be a string or a list of parts.

    List items contribute their ``"text"`` value (dict parts) or themselves
    (string parts); other items are skipped.
    """
    if content is None:
        return ""
    if isinstance(content, str):
        return content
    if isinstance(content, list):
        parts: list[str] = []
        for part in content:
            if isinstance(part, dict):
                if "text" in part:
                    parts.append(str(part["text"]))
            elif isinstance(part, str):
                parts.append(part)
        return "".join(parts)
    return str(content)


def _openai_completion_assistant_text(data: dict) -> str:
    """Return the assistant text of a chat/completions response.

    Raises:
        HTTPException(401): the body reports ``invalid_api_key``.
        HTTPException(502): the body carries another error, or has no usable choices.
    """
    err = data.get("error") if isinstance(data, dict) else None
    if err is not None:
        if isinstance(err, dict):
            code = str(err.get("code") or "")
            if code == "invalid_api_key":
                raise HTTPException(
                    status_code=401,
                    detail=(
                        "阿里云 DashScope API Key 无效。"
                        "请在 .env 中填写正确的 QWEN_API_KEY 并重启 backend。"
                    ),
                )
            msg = err.get("message") or err.get("code") or json.dumps(err, ensure_ascii=False)
        else:
            msg = str(err)
        raise HTTPException(status_code=502, detail=f"千问接口错误: {msg}")
    choices = data.get("choices") if isinstance(data, dict) else None
    first = choices[0] if isinstance(choices, list) and choices else None
    if not isinstance(first, dict):
        raise HTTPException(
            status_code=502,
            detail=f"千问返回异常(无 choices),请检查模型名与权限。原始片段: {json.dumps(data, ensure_ascii=False)[:800]}",
        )
    message = first.get("message") or {}
    return _message_content_to_str(message.get("content") if isinstance(message, dict) else None)


async def _call_qwen(system_prompt: str, user_prompt: str) -> str:
    """Send a text-only chat completion to Qwen and return the assistant text.

    When no API key is configured, a local hint string is returned instead of
    calling the API. Network failures map to 502, timeouts to 504.
    """
    api_key = _get_qwen_api_key()
    if not api_key:
        return (
            "当前未配置千问 API Key,已返回本地降级提示。\n"
            "请在项目根目录创建 .env 并配置 QWEN_API_KEY 后重启服务。\n"
            "示例:\n"
            "QWEN_API_KEY=your_qwen_api_key_here\n"
            "QWEN_BASE_URL=https://dashscope.aliyuncs.com/compatible-mode/v1\n"
            "QWEN_MODEL=qwen-plus"
        )
    base_url = _qwen_base_url()
    model = os.getenv("QWEN_MODEL", "qwen-plus")
    headers = {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"}
    payload = {
        "model": model,
        "messages": [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ],
        "temperature": 0.4,
    }
    url = f"{base_url}/chat/completions"
    try:
        async with _qwen_http_client(40.0) as client:
            resp = await client.post(url, headers=headers, json=payload)
    except httpx.ConnectError as e:
        raise HTTPException(
            status_code=502,
            detail=(
                f"无法连接千问接口({url})。请检查本机/容器能否访问外网、DNS 是否正常;"
                "若在 Docker 中可尝试为 backend 配置 dns 或关闭错误代理。"
                "默认已忽略 HTTP(S)_PROXY,若需代理请设置 HTTPX_TRUST_ENV=1。"
                f" 原始错误: {e!s}"
            ),
        ) from e
    except httpx.TimeoutException as e:
        raise HTTPException(status_code=504, detail=f"千问请求超时: {e!s}") from e
    except httpx.RequestError as e:
        # Other transport failures (read/write/protocol errors), as in the OCR call.
        raise HTTPException(status_code=502, detail=f"千问网络请求失败: {e!s}") from e
    _raise_for_qwen_http_error(resp, "千问请求失败")
    try:
        data = resp.json()
    except ValueError:
        raise HTTPException(status_code=502, detail=f"千问返回非 JSON: {resp.text[:600]}") from None
    return _openai_completion_assistant_text(data)


async def _call_qwen_vision(system_prompt: str, user_prompt: str, image_data_url: str) -> str:
    """Send an image + text chat completion to the Qwen VL model and return the assistant text.

    When no API key is configured, a hint string is returned instead of
    calling the API.
    """
    api_key = _get_qwen_api_key()
    if not api_key:
        return (
            "当前未配置千问 API Key,无法执行 OCR。\n"
            "请在 .env 中配置 QWEN_API_KEY 后重试。"
        )
    base_url = _qwen_base_url()
    model = os.getenv("QWEN_VL_MODEL", "qwen-vl-plus")
    headers = {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"}
    # Image part first, then text; the original author notes this follows the
    # DashScope docs and helps multimodal routing.
    payload = {
        "model": model,
        "messages": [
            {"role": "system", "content": system_prompt},
            {
                "role": "user",
                "content": [
                    {"type": "image_url", "image_url": {"url": image_data_url}},
                    {"type": "text", "text": user_prompt},
                ],
            },
        ],
        "temperature": 0.2,
    }
    url = f"{base_url}/chat/completions"
    try:
        async with _qwen_http_client(60.0) as client:
            resp = await client.post(url, headers=headers, json=payload)
    except httpx.ConnectError as e:
        raise HTTPException(
            status_code=502,
            detail=(
                f"无法连接千问接口(OCR,{url})。请检查网络与 DNS;"
                "默认已忽略 HTTP(S)_PROXY,若需代理请设置 HTTPX_TRUST_ENV=1。"
                f" 原始错误: {e!s}"
            ),
        ) from e
    except httpx.TimeoutException as e:
        raise HTTPException(status_code=504, detail=f"OCR 请求超时: {e!s}") from e
    except httpx.RequestError as e:
        raise HTTPException(status_code=502, detail=f"OCR 网络请求失败: {e!s}") from e
    _raise_for_qwen_http_error(resp, "OCR 请求失败")
    try:
        data = resp.json()
    except ValueError:
        raise HTTPException(status_code=502, detail=f"OCR 返回非 JSON: {resp.text[:600]}") from None
    return _openai_completion_assistant_text(data)


# ---------------------------------------------------------------------------
# Backup export / import
# ---------------------------------------------------------------------------


@app.get("/api/data/export")
def export_user_data(
    format: str = Query("zip", pattern="^(zip|json)$"),
    include_files: bool = True,
    db: Session = Depends(get_db),
):
    """Download a full backup.

    ``format=json`` returns the bare payload. ``format=zip`` returns
    ``data.json`` plus (when ``include_files``) every referenced file under
    ``uploads/``.
    """
    payload = _dump_all_data(db)
    timestamp = _utcnow().strftime("%Y%m%d_%H%M%S")
    if format == "json":
        buf = BytesIO(json.dumps(payload, ensure_ascii=False, indent=2).encode("utf-8"))
        return StreamingResponse(
            buf,
            media_type="application/json",
            headers={"Content-Disposition": f'attachment; filename="exam_helper_backup_{timestamp}.json"'},
        )

    referenced_urls = [item.get("url") for item in payload["resources"]] + [
        item.get("image_url") for item in payload["mistakes"]
    ]
    used_upload_files = {name for url in referenced_urls if (name := _extract_upload_filename(url))}

    buf = BytesIO()
    with zipfile.ZipFile(buf, "w", zipfile.ZIP_DEFLATED) as zip_ref:
        zip_ref.writestr("data.json", json.dumps(payload, ensure_ascii=False, indent=2))
        if include_files:
            for file_name in sorted(used_upload_files):
                path = UPLOAD_DIR / file_name
                if path.exists() and path.is_file():
                    zip_ref.write(path, arcname=f"uploads/{file_name}")
    buf.seek(0)
    return StreamingResponse(
        buf,
        media_type="application/zip",
        headers={"Content-Disposition": f'attachment; filename="exam_helper_backup_{timestamp}.zip"'},
    )


def _load_import_payload(content: bytes, filename: str | None) -> tuple[Any, zipfile.ZipFile | None]:
    """Decode an uploaded backup (.json or .zip) into its JSON payload.

    Returns:
        ``(payload, zip_ref)``. ``zip_ref`` is the open archive for ZIP uploads
        (the caller must close it), else None.

    Raises:
        HTTPException(400): unsupported extension, corrupt ZIP, missing
            data.json, or undecodable JSON.
    """
    suffix = Path(filename or "").suffix.lower()
    if suffix == ".json":
        try:
            return json.loads(content.decode("utf-8")), None
        except (UnicodeDecodeError, json.JSONDecodeError) as exc:
            raise HTTPException(status_code=400, detail=f"JSON 解析失败: {exc}") from exc
    if suffix == ".zip":
        try:
            zip_ref = zipfile.ZipFile(BytesIO(content))
        except zipfile.BadZipFile as exc:
            raise HTTPException(status_code=400, detail="ZIP 文件损坏或格式错误") from exc
        try:
            if "data.json" not in zip_ref.namelist():
                raise HTTPException(status_code=400, detail="ZIP 中缺少 data.json")
            try:
                payload = json.loads(zip_ref.read("data.json").decode("utf-8"))
            except (UnicodeDecodeError, json.JSONDecodeError) as exc:
                raise HTTPException(status_code=400, detail=f"data.json 解析失败: {exc}") from exc
        except BaseException:
            # Do not leak the archive when validation fails.
            zip_ref.close()
            raise
        return payload, zip_ref
    raise HTTPException(status_code=400, detail="仅支持 .json 或 .zip 导入")


def _is_list_of_dicts(value: Any) -> bool:
    """True when ``value`` is a list whose entries are all dicts."""
    return isinstance(value, list) and all(isinstance(entry, dict) for entry in value)


def _resource_from_backup(item: dict, zip_ref: zipfile.ZipFile | None) -> Resource:
    """Build a Resource from one backup entry, restoring its file from the archive if present."""
    url = item.get("url")
    if zip_ref is not None:
        url = _restore_upload_url_from_zip(url, zip_ref)
    resource_type = item.get("resource_type")
    return Resource(
        title=item.get("title") or "未命名资源",
        # Tuple membership (not a set) so unhashable junk values don't raise.
        resource_type=resource_type if resource_type in ("link", "file") else "link",
        url=url,
        file_name=item.get("file_name"),
        category=item.get("category") or "未分类",
        tags=item.get("tags"),
        created_at=_safe_datetime(item.get("created_at")),
    )


def _mistake_from_backup(item: dict, zip_ref: zipfile.ZipFile | None) -> Mistake:
    """Build a Mistake from one backup entry, restoring its image from the archive if present."""
    image_url = item.get("image_url")
    if zip_ref is not None:
        image_url = _restore_upload_url_from_zip(image_url, zip_ref)
    difficulty = item.get("difficulty")
    return Mistake(
        title=item.get("title") or "未命名错题",
        image_url=image_url,
        category=item.get("category") or "其他",
        difficulty=difficulty if difficulty in ("easy", "medium", "hard") else None,
        question_content=item.get("question_content"),
        answer=item.get("answer"),
        explanation=item.get("explanation"),
        note=item.get("note"),
        wrong_count=max(_safe_int(item.get("wrong_count") or 1, 1), 1),
        created_at=_safe_datetime(item.get("created_at")),
    )


def _score_from_backup(item: dict) -> ScoreRecord:
    """Build a ScoreRecord from one backup entry; total_score is clamped to [0, 200]."""
    score = _safe_float(item.get("total_score") or 0, 0.0)
    return ScoreRecord(
        exam_name=item.get("exam_name") or "未命名考试",
        exam_date=_safe_date(item.get("exam_date")),
        total_score=max(min(score, 200), 0),
        module_scores=item.get("module_scores"),
        created_at=_safe_datetime(item.get("created_at")),
    )


def _apply_import(db: Session, payload: Any, zip_ref: zipfile.ZipFile | None, mode: str) -> dict[str, int]:
    """Validate the payload and write it to the database in a single transaction.

    In ``replace`` mode existing rows are deleted first. Deletion and insertion
    commit together, so a failure part-way through rolls back instead of
    leaving an emptied database.
    NOTE(review): files already copied into UPLOAD_DIR are not removed on rollback.

    Returns:
        Counts of imported resources, mistakes and scores.

    Raises:
        HTTPException(400): the payload does not have the expected structure.
    """
    if not isinstance(payload, dict):
        raise HTTPException(status_code=400, detail="导入文件结构错误")
    resources = payload.get("resources", [])
    mistakes = payload.get("mistakes", [])
    scores = payload.get("scores", [])
    if not all(_is_list_of_dicts(section) for section in (resources, mistakes, scores)):
        raise HTTPException(status_code=400, detail="导入文件结构错误")

    imported = {"resources": 0, "mistakes": 0, "scores": 0}
    try:
        if mode == "replace":
            for model in (Resource, Mistake, ScoreRecord):
                for existing in db.scalars(select(model)).all():
                    db.delete(existing)
            db.flush()

        for item in resources:
            db.add(_resource_from_backup(item, zip_ref))
            imported["resources"] += 1
        for item in mistakes:
            db.add(_mistake_from_backup(item, zip_ref))
            imported["mistakes"] += 1
        for item in scores:
            db.add(_score_from_backup(item))
            imported["scores"] += 1

        db.commit()
    except Exception:
        db.rollback()
        raise
    return imported


@app.post("/api/data/import")
async def import_user_data(
    file: UploadFile = File(...),
    mode: str = Query("merge", pattern="^(merge|replace)$"),
    db: Session = Depends(get_db),
):
    """Import a backup produced by ``/api/data/export``.

    ``merge`` appends the imported rows; ``replace`` wipes existing data first.
    """
    # Read at most one byte past the limit so oversized uploads are rejected
    # without being buffered entirely in memory.
    content = await file.read(MAX_IMPORT_BYTES + 1)
    if not content:
        raise HTTPException(status_code=400, detail="导入文件为空")
    if len(content) > MAX_IMPORT_BYTES:
        raise HTTPException(status_code=400, detail="导入文件不能超过 100MB")

    payload, zip_ref = _load_import_payload(content, file.filename)
    try:
        imported = _apply_import(db, payload, zip_ref, mode)
    finally:
        if zip_ref is not None:
            zip_ref.close()
    return {"success": True, "mode": mode, "imported": imported}


# ---------------------------------------------------------------------------
# OCR
# ---------------------------------------------------------------------------


def _ocr_fallback(raw_text: str) -> dict:
    """Structured OCR result used when the model output is not a JSON object."""
    stripped = raw_text.strip()
    return {
        "text": stripped,
        "title_suggestion": None,
        "category_suggestion": None,
        "difficulty_suggestion": None,
        "question_content": stripped,
        "answer": "",
        "explanation": "",
    }


def _opt_str(val: Any) -> str | None:
    """Return ``val`` as a stripped string, or None for None/blank/dict/list values."""
    if val is None or isinstance(val, (dict, list)):
        return None
    stripped = str(val).strip()
    return stripped or None


def _merge_question_body(text_raw: str, qc_raw: str | None) -> str | None:
    """Choose the fuller of the model's ``text`` and ``question_content``.

    The model often puts the whole question in ``text`` but only a short prompt
    in ``question_content``. The longer one wins; on a tie ``text`` is kept,
    since whole-page OCR is usually more complete.
    """
    t = (text_raw or "").strip()
    q = (qc_raw or "").strip()
    if not t and not q:
        return None
    return q if len(q) > len(t) else t


@app.post("/api/ocr/parse", response_model=OcrParseOut)
async def parse_ocr(payload: OcrParseIn):
    """Run vision OCR on an uploaded image and return structured question fields.

    Raises:
        HTTPException(400): the URL is not under /uploads, or the type is unsupported.
        HTTPException(404): the file does not exist.
    """
    file_name = _extract_upload_filename(payload.image_url)
    if not file_name:
        raise HTTPException(status_code=400, detail="仅支持 /uploads 下的图片做 OCR")
    target = UPLOAD_DIR / file_name
    if not target.exists() or not target.is_file():
        raise HTTPException(status_code=404, detail="图片不存在或已删除")
    mime = {
        ".jpg": "image/jpeg",
        ".jpeg": "image/jpeg",
        ".png": "image/png",
        ".webp": "image/webp",
        ".heic": "image/heic",
        ".heif": "image/heif",
    }.get(target.suffix.lower())
    if not mime:
        raise HTTPException(status_code=400, detail="仅支持 JPG/PNG/WebP/HEIC OCR")
    b64 = base64.b64encode(target.read_bytes()).decode("utf-8")
    image_data_url = f"data:{mime};base64,{b64}"

    ocr_prompt = (
        "请识别图片中的题目,返回严格 JSON。"
        "字段说明:text 为整题完整纯文本(含材料、提问句、全部选项);"
        "question_content 必须与 text 一致地表示「完整题干」,须包含阅读材料、填空/提问句、所有选项(A B C D 等),"
        "禁止只填写「依次填入…」等短提示句而省略材料和选项。"
        "另含 title_suggestion、category_suggestion、difficulty_suggestion、answer、explanation。"
        "无法确认的字段可填空字符串。"
    )
    if payload.prompt:
        ocr_prompt = f"{ocr_prompt}\n补充要求:{payload.prompt}"
    raw_text = await _call_qwen_vision(
        "你是题目 OCR 与结构化助手。输出必须是 JSON,不要额外解释。",
        ocr_prompt,
        image_data_url,
    )

    try:
        parsed = json.loads(_extract_json_text(raw_text))
    except json.JSONDecodeError:
        parsed = None
    data = parsed if isinstance(parsed, dict) else _ocr_fallback(raw_text)

    text_out = str(data.get("text", "") or "").strip()
    question_merged = _merge_question_body(text_out, _opt_str(data.get("question_content")))
    return OcrParseOut(
        text=text_out,
        title_suggestion=_opt_str(data.get("title_suggestion")),
        category_suggestion=_opt_str(data.get("category_suggestion")),
        difficulty_suggestion=_opt_str(data.get("difficulty_suggestion")),
        question_content=_opt_str(question_merged),
        answer=_opt_str(data.get("answer")),
        explanation=_opt_str(data.get("explanation")),
    )


# ---------------------------------------------------------------------------
# AI helpers
# ---------------------------------------------------------------------------


@app.post("/api/ai/mistakes/{item_id}/analyze", response_model=AiMistakeAnalysisOut)
async def ai_analyze_mistake(item_id: int, db: Session = Depends(get_db)):
    """Ask Qwen for a structured analysis of one mistake (404 if missing)."""
    item = db.get(Mistake, item_id)
    if not item:
        raise HTTPException(status_code=404, detail="Mistake not found")
    content = await _call_qwen(
        "你是学习教练,请输出结构化、可执行的错题分析。",
        (
            f"错题标题: {item.title}\n"
            f"分类: {item.category}\n"
            f"难度: {item.difficulty or '未设置'}\n"
            f"题目内容: {item.question_content or '无'}\n"
            f"答案: {item.answer or '无'}\n"
            f"解析: {item.explanation or '无'}\n"
            f"错误频次: {item.wrong_count}\n"
            f"备注: {item.note or '无'}\n\n"
            "请按以下结构输出:\n"
            "1) 错误根因\n2) 关键知识点\n3) 3步复盘法\n4) 3道同类训练建议\n5) 明日复习安排"
        ),
    )
    return AiMistakeAnalysisOut(analysis=content)


@app.post("/api/ai/study-plan", response_model=AiStudyPlanOut)
async def ai_study_plan(payload: AiStudyPlanIn, db: Session = Depends(get_db)):
    """Ask Qwen for a study plan.

    The prompt includes the user's goal, the last 30 days of scores and the 20
    most recent mistakes.
    """
    since = date.today() - timedelta(days=30)
    recent_scores = db.scalars(
        select(ScoreRecord).where(ScoreRecord.exam_date >= since).order_by(asc(ScoreRecord.exam_date))
    ).all()
    recent_mistakes = db.scalars(select(Mistake).order_by(desc(Mistake.created_at)).limit(20)).all()
    score_text = ", ".join([f"{s.exam_date}:{s.total_score}" for s in recent_scores]) or "暂无成绩数据"
    mistake_text = ", ".join([f"{m.category}-{m.title}" for m in recent_mistakes]) or "暂无错题数据"
    content = await _call_qwen(
        "你是学习规划师,请给出可执行计划并尽量量化。",
        (
            f"目标: {payload.goal}\n"
            f"剩余天数: {payload.days_left}\n"
            f"每天可学习小时数: {payload.daily_hours}\n"
            f"近30天分数: {score_text}\n"
            f"近期错题: {mistake_text}\n\n"
            "请输出: 周计划表、每日任务模板、错题复盘节奏、模考安排、风险提醒。"
        ),
    )
    return AiStudyPlanOut(plan=content)