import os
import uuid
from datetime import date, datetime, timedelta
from io import BytesIO
from pathlib import Path

import httpx
from docx import Document
from fastapi import Depends, FastAPI, File, HTTPException, Query, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from fastapi.staticfiles import StaticFiles
from reportlab.lib.pagesizes import A4
from reportlab.pdfbase import pdfmetrics
from reportlab.pdfbase.cidfonts import UnicodeCIDFont
from reportlab.pdfgen import canvas
from sqlalchemy import asc, desc, func, or_, select
from sqlalchemy.orm import Session

from .database import Base, engine, get_db
from .models import Mistake, Resource, ScoreRecord
from .schemas import (
    AiMistakeAnalysisOut,
    AiStudyPlanIn,
    AiStudyPlanOut,
    IdBatchPayload,
    MistakeCreate,
    MistakeOut,
    MistakeUpdate,
    ResourceBatchUpdate,
    ResourceCreate,
    ResourceOut,
    ResourceUpdate,
    ScoreRecordCreate,
    ScoreRecordOut,
    ScoreRecordUpdate,
    ScoreStats,
)

# Create any tables declared on Base that do not exist yet; runs at import time.
Base.metadata.create_all(bind=engine)

app = FastAPI(title="公考助手 API", version="1.0.0")

# Uploaded files are stored on local disk and served back under /uploads.
UPLOAD_DIR = Path(os.getenv("UPLOAD_DIR", "/app/uploads"))
UPLOAD_DIR.mkdir(parents=True, exist_ok=True)
app.mount("/uploads", StaticFiles(directory=str(UPLOAD_DIR)), name="uploads")

# Maximum accepted upload size in bytes (50 MiB).
MAX_UPLOAD_BYTES = 50 * 1024 * 1024
# Read uploads in 1 MiB pieces so oversized files are rejected early.
_UPLOAD_CHUNK_BYTES = 1024 * 1024
# Lower-cased file extensions accepted by /api/upload.
_ALLOWED_UPLOAD_SUFFIXES = frozenset({".pdf", ".doc", ".docx", ".jpg", ".jpeg", ".png", ".webp"})
# Maximum number of mistakes a single export may contain.
MAX_EXPORT_ITEMS = 200

# Comma-separated allow-list; blank entries (e.g. from a trailing comma) are dropped.
origins = os.getenv("CORS_ORIGINS", "http://localhost:5173").split(",")
app.add_middleware(
    CORSMiddleware,
    allow_origins=[origin.strip() for origin in origins if origin.strip()],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


@app.get("/health")
def health():
    """Liveness probe; always reports ``{"status": "ok"}``."""
    return {"status": "ok"}


def _validate_score_date(exam_date: date) -> None:
    """Reject exam dates later than today's (server-local) date.

    Raises:
        HTTPException: 400 when ``exam_date`` is in the future.
    """
    if exam_date > date.today():
        raise HTTPException(status_code=400, detail="考试时间不能晚于今天")


def _query_mistakes_for_export(
    db: Session,
    category: str | None,
    start_date: date | None,
    end_date: date | None,
):
    """Return the mistakes selected for export, newest first.

    Args:
        db: Active database session.
        category: Exact category to match; falsy means "any".
        start_date: Inclusive lower bound on ``created_at`` (from 00:00 of that day).
        end_date: Inclusive upper bound on ``created_at`` (through the end of that day).

    Raises:
        HTTPException: 400 when more than ``MAX_EXPORT_ITEMS`` rows match.
    """
    stmt = select(Mistake)
    if category:
        stmt = stmt.where(Mistake.category == category)
    if start_date:
        stmt = stmt.where(Mistake.created_at >= datetime.combine(start_date, datetime.min.time()))
    if end_date:
        stmt = stmt.where(Mistake.created_at <= datetime.combine(end_date, datetime.max.time()))
    # Fetch one row beyond the cap. That is enough to detect an over-limit
    # request without loading the whole matching set into memory.
    stmt = stmt.order_by(desc(Mistake.created_at)).limit(MAX_EXPORT_ITEMS + 1)
    items = db.scalars(stmt).all()
    if len(items) > MAX_EXPORT_ITEMS:
        raise HTTPException(status_code=400, detail="单次最多导出 200 题")
    return items


@app.post("/api/upload")
async def upload_file(file: UploadFile = File(...)):
    """Store an uploaded document/image under a random name.

    Returns:
        ``{"url": "/uploads/<name>", "original_name": <client filename>}``.

    Raises:
        HTTPException: 400 for a disallowed extension or a file over 50 MB.
    """
    suffix = Path(file.filename or "").suffix.lower()
    if suffix not in _ALLOWED_UPLOAD_SUFFIXES:
        raise HTTPException(status_code=400, detail="不支持的文件类型")

    # Read in chunks and stop as soon as the limit is crossed, instead of
    # copying an arbitrarily large upload into memory before checking its size.
    content = bytearray()
    while chunk := await file.read(_UPLOAD_CHUNK_BYTES):
        content.extend(chunk)
        if len(content) > MAX_UPLOAD_BYTES:
            raise HTTPException(status_code=400, detail="文件不能超过 50MB")

    # A random name avoids collisions and never trusts the client's filename on disk.
    file_name = f"{uuid.uuid4().hex}{suffix}"
    target = UPLOAD_DIR / file_name
    target.write_bytes(content)
    return {"url": f"/uploads/{file_name}", "original_name": file.filename}


@app.get("/api/resources", response_model=list[ResourceOut])
def list_resources(
    q: str | None = None,
    category: str | None = None,
    tags: str | None = None,
    resource_type: str | None = None,
    sort_by: str = Query("created_at", pattern="^(created_at|title|name)$"),
    order: str = Query("desc", pattern="^(asc|desc)$"),
    db: Session = Depends(get_db),
):
    """List resources with optional filters and sorting.

    ``q`` is a case-insensitive substring match over title, tags and url.
    ``tags`` is a substring match on the tags column. ``sort_by="name"`` is
    accepted and sorts by title, exactly like ``sort_by="title"``.
    """
    stmt = select(Resource)
    if q:
        stmt = stmt.where(
            or_(
                Resource.title.ilike(f"%{q}%"),
                Resource.tags.ilike(f"%{q}%"),
                Resource.url.ilike(f"%{q}%"),
            )
        )
    if category:
        stmt = stmt.where(Resource.category == category)
    if tags:
        stmt = stmt.where(Resource.tags.ilike(f"%{tags}%"))
    if resource_type:
        stmt = stmt.where(Resource.resource_type == resource_type)
    sort_col = Resource.created_at if sort_by == "created_at" else Resource.title
    stmt = stmt.order_by(desc(sort_col) if order == "desc" else asc(sort_col))
    return db.scalars(stmt).all()


@app.post("/api/resources", response_model=ResourceOut)
def create_resource(payload: ResourceCreate, db: Session = Depends(get_db)):
    """Insert a new resource from the request payload and return the stored row."""
    item = Resource(**payload.model_dump())
    db.add(item)
    db.commit()
    db.refresh(item)
    return item
@app.put("/api/resources/{item_id}", response_model=ResourceOut)
def update_resource(item_id: int, payload: ResourceUpdate, db: Session = Depends(get_db)):
    """Overwrite a resource with every field from the payload.

    Raises:
        HTTPException: 404 when the resource does not exist.
    """
    item = db.get(Resource, item_id)
    if not item:
        raise HTTPException(status_code=404, detail="Resource not found")
    for k, v in payload.model_dump().items():
        setattr(item, k, v)
    db.commit()
    db.refresh(item)
    return item


@app.delete("/api/resources/{item_id}")
def delete_resource(item_id: int, db: Session = Depends(get_db)):
    """Delete one resource.

    Raises:
        HTTPException: 404 when the resource does not exist.
    """
    item = db.get(Resource, item_id)
    if not item:
        raise HTTPException(status_code=404, detail="Resource not found")
    db.delete(item)
    db.commit()
    return {"success": True}


@app.patch("/api/resources/batch")
def batch_update_resources(payload: ResourceBatchUpdate, db: Session = Depends(get_db)):
    """Set category and/or tags on many resources at once.

    Fields left as ``None`` in the payload are not touched. Unknown ids are
    ignored; ``count`` is the number of rows actually updated.

    Raises:
        HTTPException: 400 when ``ids`` is empty.
    """
    if not payload.ids:
        raise HTTPException(status_code=400, detail="ids 不能为空")
    items = db.scalars(select(Resource).where(Resource.id.in_(payload.ids))).all()
    for item in items:
        if payload.category is not None:
            item.category = payload.category
        if payload.tags is not None:
            item.tags = payload.tags
    db.commit()
    return {"success": True, "count": len(items)}


@app.post("/api/resources/batch-delete")
def batch_delete_resources(payload: IdBatchPayload, db: Session = Depends(get_db)):
    """Delete many resources by id.

    Unknown ids are ignored; ``count`` is the number of rows deleted.

    Raises:
        HTTPException: 400 when ``ids`` is empty.
    """
    if not payload.ids:
        raise HTTPException(status_code=400, detail="ids 不能为空")
    items = db.scalars(select(Resource).where(Resource.id.in_(payload.ids))).all()
    for item in items:
        db.delete(item)
    db.commit()
    return {"success": True, "count": len(items)}


@app.get("/api/mistakes", response_model=list[MistakeOut])
def list_mistakes(
    category: str | None = None,
    keyword: str | None = None,
    sort_by: str = Query("created_at", pattern="^(created_at|wrong_count)$"),
    order: str = Query("desc", pattern="^(asc|desc)$"),
    db: Session = Depends(get_db),
):
    """List mistakes, optionally filtered by category and keyword.

    ``keyword`` is a case-insensitive substring match over note and title.
    """
    stmt = select(Mistake)
    if category:
        stmt = stmt.where(Mistake.category == category)
    if keyword:
        stmt = stmt.where(or_(Mistake.note.ilike(f"%{keyword}%"), Mistake.title.ilike(f"%{keyword}%")))
    sort_col = Mistake.created_at if sort_by == "created_at" else Mistake.wrong_count
    stmt = stmt.order_by(desc(sort_col) if order == "desc" else asc(sort_col))
    return db.scalars(stmt).all()


@app.post("/api/mistakes", response_model=MistakeOut)
def create_mistake(payload: MistakeCreate, db: Session = Depends(get_db)):
    """Insert a new mistake from the request payload and return the stored row."""
    item = Mistake(**payload.model_dump())
    db.add(item)
    db.commit()
    db.refresh(item)
    return item


@app.put("/api/mistakes/{item_id}", response_model=MistakeOut)
def update_mistake(item_id: int, payload: MistakeUpdate, db: Session = Depends(get_db)):
    """Overwrite a mistake with every field from the payload.

    Raises:
        HTTPException: 404 when the mistake does not exist.
    """
    item = db.get(Mistake, item_id)
    if not item:
        raise HTTPException(status_code=404, detail="Mistake not found")
    for k, v in payload.model_dump().items():
        setattr(item, k, v)
    db.commit()
    db.refresh(item)
    return item


@app.delete("/api/mistakes/{item_id}")
def delete_mistake(item_id: int, db: Session = Depends(get_db)):
    """Delete one mistake.

    Raises:
        HTTPException: 404 when the mistake does not exist.
    """
    item = db.get(Mistake, item_id)
    if not item:
        raise HTTPException(status_code=404, detail="Mistake not found")
    db.delete(item)
    db.commit()
    return {"success": True}


# PDF export layout, in points. The CID font ships with reportlab and covers CJK text.
_PDF_FONT = "STSong-Light"
_PDF_FONT_SIZE = 12
_PDF_LEFT = 50
_PDF_TOP = 800
_PDF_BOTTOM = 70
_PDF_LINE_HEIGHT = 22


def _wrap_pdf_text(text: str, max_width: float) -> list[str]:
    """Split ``text`` into lines no wider than ``max_width`` in the export font.

    Breaks at character granularity, because CJK text has no spaces to break on.
    Embedded newlines always start a new line. The result always contains
    at least one (possibly empty) line.
    """
    lines: list[str] = []
    for paragraph in text.splitlines() or [""]:
        current = ""
        for ch in paragraph:
            candidate = current + ch
            if current and pdfmetrics.stringWidth(candidate, _PDF_FONT, _PDF_FONT_SIZE) > max_width:
                lines.append(current)
                current = ch
            else:
                current = candidate
        lines.append(current)
    return lines


@app.get("/api/mistakes/export/pdf")
def export_mistakes_pdf(
    category: str | None = None,
    start_date: date | None = None,
    end_date: date | None = None,
    db: Session = Depends(get_db),
):
    """Render the selected mistakes to an A4 PDF and stream it as a download.

    Long fields are wrapped to the printable width rather than truncated, so
    no text is lost or drawn past the page edge.
    """
    items = _query_mistakes_for_export(db, category, start_date, end_date)

    # Register the font once per process instead of on every request.
    if _PDF_FONT not in pdfmetrics.getRegisteredFontNames():
        pdfmetrics.registerFont(UnicodeCIDFont(_PDF_FONT))

    page_width, _ = A4
    max_width = page_width - 2 * _PDF_LEFT

    buf = BytesIO()
    pdf = canvas.Canvas(buf, pagesize=A4)
    pdf.setFont(_PDF_FONT, _PDF_FONT_SIZE)
    y = _PDF_TOP
    pdf.drawString(_PDF_LEFT, y, "公考助手 - 错题导出")
    y -= 30
    for idx, item in enumerate(items, start=1):
        fields = [
            f"{idx}. {item.title}",
            f"分类: {item.category} 难度: {item.difficulty or '未设置'} 错误频次: {item.wrong_count}",
            f"备注: {item.note or '无'}",
            "答题区: _______________________________",
        ]
        for field in fields:
            for line in _wrap_pdf_text(field, max_width):
                if y < _PDF_BOTTOM:
                    # Page break: a new page resets the font, so set it again.
                    pdf.showPage()
                    pdf.setFont(_PDF_FONT, _PDF_FONT_SIZE)
                    y = _PDF_TOP
                pdf.drawString(_PDF_LEFT, y, line)
                y -= _PDF_LINE_HEIGHT
        y -= 6  # extra gap between questions
    pdf.save()
    buf.seek(0)
    return StreamingResponse(
        buf,
        media_type="application/pdf",
        headers={"Content-Disposition": 'attachment; filename="mistakes.pdf"'},
    )


@app.get("/api/mistakes/export/docx")
def export_mistakes_docx(
    category: str | None = None,
    start_date: date | None = None,
    end_date: date | None = None,
    db: Session = Depends(get_db),
):
    """Render the selected mistakes to a Word document and stream it as a download."""
    items = _query_mistakes_for_export(db, category, start_date, end_date)
    doc = Document()
    doc.add_heading("公考助手 - 错题导出", level=1)
    for idx, item in enumerate(items, start=1):
        doc.add_paragraph(f"{idx}. {item.title}")
        doc.add_paragraph(f"分类: {item.category} | 难度: {item.difficulty or '未设置'} | 错误频次: {item.wrong_count}")
        doc.add_paragraph(f"备注: {item.note or '无'}")
        doc.add_paragraph("答题区: ________________________________________")
    buf = BytesIO()
    doc.save(buf)
    buf.seek(0)
    return StreamingResponse(
        buf,
        media_type="application/vnd.openxmlformats-officedocument.wordprocessingml.document",
        headers={"Content-Disposition": 'attachment; filename="mistakes.docx"'},
    )


@app.get("/api/scores", response_model=list[ScoreRecordOut])
def list_scores(
    start_date: date | None = None,
    end_date: date | None = None,
    db: Session = Depends(get_db),
):
    """List score records, oldest exam first, optionally bounded by inclusive dates."""
    stmt = select(ScoreRecord)
    if start_date:
        stmt = stmt.where(ScoreRecord.exam_date >= start_date)
    if end_date:
        stmt = stmt.where(ScoreRecord.exam_date <= end_date)
    stmt = stmt.order_by(asc(ScoreRecord.exam_date))
    return db.scalars(stmt).all()


@app.post("/api/scores", response_model=ScoreRecordOut)
def create_score(payload: ScoreRecordCreate, db: Session = Depends(get_db)):
    """Insert a new score record.

    Raises:
        HTTPException: 400 when ``exam_date`` is in the future.
    """
    _validate_score_date(payload.exam_date)
    item = ScoreRecord(**payload.model_dump())
    db.add(item)
    db.commit()
    db.refresh(item)
    return item
@app.put("/api/scores/{item_id}", response_model=ScoreRecordOut)
def update_score(item_id: int, payload: ScoreRecordUpdate, db: Session = Depends(get_db)):
    """Overwrite a score record with every field from the payload.

    Raises:
        HTTPException: 400 when ``exam_date`` is in the future; 404 when the record does not exist.
    """
    _validate_score_date(payload.exam_date)
    item = db.get(ScoreRecord, item_id)
    if not item:
        raise HTTPException(status_code=404, detail="Score record not found")
    for k, v in payload.model_dump().items():
        setattr(item, k, v)
    db.commit()
    db.refresh(item)
    return item


@app.delete("/api/scores/{item_id}")
def delete_score(item_id: int, db: Session = Depends(get_db)):
    """Delete one score record.

    Raises:
        HTTPException: 404 when the record does not exist.
    """
    item = db.get(ScoreRecord, item_id)
    if not item:
        raise HTTPException(status_code=404, detail="Score record not found")
    db.delete(item)
    db.commit()
    return {"success": True}


@app.get("/api/scores/stats", response_model=ScoreStats)
def score_stats(db: Session = Depends(get_db)):
    """Summarize all score records.

    ``improvement`` is the latest exam's total minus the earliest exam's total
    (ordered by ``exam_date``). Returns all zeros when there are no records.
    """
    scores = db.scalars(select(ScoreRecord).order_by(asc(ScoreRecord.exam_date))).all()
    if not scores:
        return ScoreStats(highest=0, lowest=0, average=0, improvement=0)
    highest = max(item.total_score for item in scores)
    lowest = min(item.total_score for item in scores)
    avg = db.scalar(select(func.avg(ScoreRecord.total_score))) or 0
    improvement = scores[-1].total_score - scores[0].total_score
    return ScoreStats(highest=highest, lowest=lowest, average=round(float(avg), 2), improvement=improvement)


_QWEN_DEFAULT_BASE_URL = "https://dashscope.aliyuncs.com/compatible-mode/v1"
# Cap on how much of an upstream error body is echoed back to the client.
_QWEN_ERROR_DETAIL_LIMIT = 500


async def _call_qwen(system_prompt: str, user_prompt: str) -> str:
    """Send one system+user chat turn to the Qwen OpenAI-compatible API.

    Configuration comes from QWEN_API_KEY, QWEN_BASE_URL and QWEN_MODEL. When
    no API key is set, a fixed setup hint is returned instead of calling out.

    Returns:
        The assistant message content of the first choice.

    Raises:
        HTTPException: 504 on timeout; 502 on a transport error, a non-2xx
            status, or a response body that is not the expected chat format.
    """
    api_key = os.getenv("QWEN_API_KEY", "")
    if not api_key:
        return (
            "当前未配置千问 API Key,已返回本地降级提示。\n"
            "请在项目根目录创建 .env 并配置 QWEN_API_KEY 后重启服务。\n"
            "示例:\n"
            "QWEN_API_KEY=your_qwen_api_key_here\n"
            "QWEN_BASE_URL=https://dashscope.aliyuncs.com/compatible-mode/v1\n"
            "QWEN_MODEL=qwen-plus"
        )
    # Strip a trailing slash so the joined URL never becomes ".../v1//chat/completions".
    base_url = os.getenv("QWEN_BASE_URL", _QWEN_DEFAULT_BASE_URL).rstrip("/")
    model = os.getenv("QWEN_MODEL", "qwen-plus")
    headers = {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"}
    payload = {
        "model": model,
        "messages": [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ],
        "temperature": 0.4,
    }
    try:
        async with httpx.AsyncClient(timeout=40) as client:
            resp = await client.post(f"{base_url}/chat/completions", headers=headers, json=payload)
    except httpx.TimeoutException as exc:
        raise HTTPException(status_code=504, detail="千问请求超时") from exc
    except httpx.RequestError as exc:
        raise HTTPException(status_code=502, detail=f"千问请求失败: {exc}") from exc

    if resp.status_code >= 300:
        raise HTTPException(
            status_code=502,
            detail=f"千问请求失败: {resp.text[:_QWEN_ERROR_DETAIL_LIMIT]}",
        )

    # Treat a non-JSON body or an unexpected shape as an upstream failure
    # (502) rather than letting it surface as an internal 500.
    try:
        content = resp.json()["choices"][0]["message"]["content"]
    except (ValueError, KeyError, IndexError, TypeError) as exc:
        raise HTTPException(status_code=502, detail="千问返回格式异常") from exc
    if not isinstance(content, str):
        raise HTTPException(status_code=502, detail="千问返回格式异常")
    return content


@app.post("/api/ai/mistakes/{item_id}/analyze", response_model=AiMistakeAnalysisOut)
async def ai_analyze_mistake(item_id: int, db: Session = Depends(get_db)):
    """Ask the model for a structured analysis of one mistake.

    Raises:
        HTTPException: 404 when the mistake does not exist, plus any error from ``_call_qwen``.
    """
    item = db.get(Mistake, item_id)
    if not item:
        raise HTTPException(status_code=404, detail="Mistake not found")
    content = await _call_qwen(
        "你是公考备考教练,请输出结构化、可执行的错题分析。",
        (
            f"错题标题: {item.title}\n"
            f"分类: {item.category}\n"
            f"难度: {item.difficulty or '未设置'}\n"
            f"错误频次: {item.wrong_count}\n"
            f"备注: {item.note or '无'}\n\n"
            "请按以下结构输出:\n"
            "1) 错误根因\n2) 关键知识点\n3) 3步复盘法\n4) 3道同类训练建议\n5) 明日复习安排"
        ),
    )
    return AiMistakeAnalysisOut(analysis=content)


@app.post("/api/ai/study-plan", response_model=AiStudyPlanOut)
async def ai_study_plan(payload: AiStudyPlanIn, db: Session = Depends(get_db)):
    """Ask the model for a study plan.

    The prompt is built from the user's goal, the scores from the last 30
    days and the 20 most recently created mistakes.
    """
    since = date.today() - timedelta(days=30)
    recent_scores = db.scalars(
        select(ScoreRecord).where(ScoreRecord.exam_date >= since).order_by(asc(ScoreRecord.exam_date))
    ).all()
    recent_mistakes = db.scalars(select(Mistake).order_by(desc(Mistake.created_at)).limit(20)).all()
    score_text = ", ".join([f"{s.exam_date}:{s.total_score}" for s in recent_scores]) or "暂无成绩数据"
    mistake_text = ", ".join([f"{m.category}-{m.title}" for m in recent_mistakes]) or "暂无错题数据"
    content = await _call_qwen(
        "你是公考学习规划师,请给出可执行计划并尽量量化。",
        (
            f"目标: {payload.goal}\n"
            f"剩余天数: {payload.days_left}\n"
            f"每天可学习小时数: {payload.daily_hours}\n"
            f"近30天分数: {score_text}\n"
            f"近期错题: {mistake_text}\n\n"
            "请输出: 周计划表、每日任务模板、错题复盘节奏、模考安排、风险提醒。"
        ),
    )
    return AiStudyPlanOut(plan=content)