435 lines
15 KiB
Python
435 lines
15 KiB
Python
import os
|
||
import uuid
|
||
from datetime import date, datetime, timedelta
|
||
from io import BytesIO
|
||
from pathlib import Path
|
||
|
||
import httpx
|
||
from docx import Document
|
||
from fastapi import Depends, FastAPI, File, HTTPException, Query, UploadFile
|
||
from fastapi.middleware.cors import CORSMiddleware
|
||
from fastapi.responses import StreamingResponse
|
||
from fastapi.staticfiles import StaticFiles
|
||
from reportlab.lib.pagesizes import A4
|
||
from reportlab.pdfbase import pdfmetrics
|
||
from reportlab.pdfbase.cidfonts import UnicodeCIDFont
|
||
from reportlab.pdfgen import canvas
|
||
from sqlalchemy import asc, desc, func, or_, select
|
||
from sqlalchemy.orm import Session
|
||
|
||
from .database import Base, engine, get_db
|
||
from .models import Mistake, Resource, ScoreRecord
|
||
from .schemas import (
|
||
AiMistakeAnalysisOut,
|
||
AiStudyPlanIn,
|
||
AiStudyPlanOut,
|
||
IdBatchPayload,
|
||
MistakeCreate,
|
||
MistakeOut,
|
||
MistakeUpdate,
|
||
ResourceBatchUpdate,
|
||
ResourceCreate,
|
||
ResourceOut,
|
||
ResourceUpdate,
|
||
ScoreRecordCreate,
|
||
ScoreRecordOut,
|
||
ScoreRecordUpdate,
|
||
ScoreStats,
|
||
)
|
||
|
||
# Create any missing tables when this module is imported.
Base.metadata.create_all(bind=engine)

app = FastAPI(title="公考助手 API", version="1.0.0")

# Upload storage directory (overridable via the UPLOAD_DIR env var). It is
# created before the static mount so the handler has a directory to serve.
UPLOAD_DIR = Path(os.getenv("UPLOAD_DIR", "/app/uploads"))
UPLOAD_DIR.mkdir(parents=True, exist_ok=True)
app.mount("/uploads", StaticFiles(directory=str(UPLOAD_DIR)), name="uploads")

# Comma-separated browser-origin allow-list from CORS_ORIGINS; blank entries
# (e.g. from a trailing comma) are dropped after trimming whitespace.
origins = os.getenv("CORS_ORIGINS", "http://localhost:5173").split(",")
app.add_middleware(
    CORSMiddleware,
    allow_origins=[cleaned for cleaned in map(str.strip, origins) if cleaned],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
||
|
||
|
||
@app.get("/health")
def health():
    """Liveness probe: report that the API process is up."""
    status_payload = {"status": "ok"}
    return status_payload
|
||
|
||
|
||
def _validate_score_date(exam_date: date) -> None:
|
||
if exam_date > date.today():
|
||
raise HTTPException(status_code=400, detail="考试时间不能晚于今天")
|
||
|
||
|
||
def _query_mistakes_for_export(
    db: Session,
    category: str | None,
    start_date: date | None,
    end_date: date | None,
    max_items: int = 200,
):
    """Load the mistakes selected for a PDF/DOCX export, newest first.

    Args:
        db: Active SQLAlchemy session.
        category: Exact category to filter on; falsy means every category.
        start_date: First creation day to include (inclusive), or None.
        end_date: Last creation day to include (inclusive), or None.
        max_items: Largest number of mistakes a single export may contain.

    Returns:
        Matching ``Mistake`` rows ordered by ``created_at`` descending.

    Raises:
        HTTPException: 400 when more than ``max_items`` mistakes match.
    """
    stmt = select(Mistake)
    if category:
        stmt = stmt.where(Mistake.category == category)
    if start_date:
        stmt = stmt.where(Mistake.created_at >= datetime.combine(start_date, datetime.min.time()))
    if end_date:
        # End-of-day bound so the whole end_date is included.
        stmt = stmt.where(Mistake.created_at <= datetime.combine(end_date, datetime.max.time()))
    # Fetch one row past the cap: that is enough to detect an oversized export
    # without loading every matching row into memory first.
    items = db.scalars(stmt.order_by(desc(Mistake.created_at)).limit(max_items + 1)).all()
    if len(items) > max_items:
        raise HTTPException(status_code=400, detail=f"单次最多导出 {max_items} 题")
    return items
|
||
|
||
|
||
@app.post("/api/upload")
async def upload_file(file: UploadFile = File(...)):
    """Store an uploaded document or image under UPLOAD_DIR with a random name.

    Only PDF, Word and common image extensions are accepted (matched
    case-insensitively), and files larger than 50 MB are rejected.

    Returns:
        ``{"url": ..., "original_name": ...}``; ``url`` is the path under the
        ``/uploads`` static mount.

    Raises:
        HTTPException: 400 for an unsupported extension or an oversized file.
    """
    max_bytes = 50 * 1024 * 1024
    suffix = Path(file.filename or "").suffix.lower()
    allowed = {".pdf", ".doc", ".docx", ".jpg", ".jpeg", ".png", ".webp"}
    if suffix not in allowed:
        raise HTTPException(status_code=400, detail="不支持的文件类型")
    # Read at most one byte past the limit: enough to detect an oversized
    # upload without pulling an arbitrarily large body into memory.
    content = await file.read(max_bytes + 1)
    if len(content) > max_bytes:
        raise HTTPException(status_code=400, detail="文件不能超过 50MB")
    # A random stored name avoids collisions and keeps the client-supplied
    # filename out of the filesystem path.
    file_name = f"{uuid.uuid4().hex}{suffix}"
    target = UPLOAD_DIR / file_name
    target.write_bytes(content)
    return {"url": f"/uploads/{file_name}", "original_name": file.filename}
|
||
|
||
|
||
@app.get("/api/resources", response_model=list[ResourceOut])
def list_resources(
    q: str | None = None,
    category: str | None = None,
    tags: str | None = None,
    resource_type: str | None = None,
    sort_by: str = Query("created_at", pattern="^(created_at|title|name)$"),
    order: str = Query("desc", pattern="^(asc|desc)$"),
    db: Session = Depends(get_db),
):
    """List resources with optional search, filters and sorting.

    ``q`` is a case-insensitive substring match on title, tags or URL.
    ``tags`` is a case-insensitive substring match on the tags column.
    ``category`` and ``resource_type`` must match exactly.
    Any ``sort_by`` other than ``"created_at"`` (i.e. ``"title"`` or
    ``"name"``) sorts by title.
    """
    conditions = []
    if q:
        needle = f"%{q}%"
        conditions.append(
            or_(
                Resource.title.ilike(needle),
                Resource.tags.ilike(needle),
                Resource.url.ilike(needle),
            )
        )
    if category:
        conditions.append(Resource.category == category)
    if tags:
        conditions.append(Resource.tags.ilike(f"%{tags}%"))
    if resource_type:
        conditions.append(Resource.resource_type == resource_type)

    stmt = select(Resource)
    if conditions:
        stmt = stmt.where(*conditions)

    column = Resource.created_at if sort_by == "created_at" else Resource.title
    direction = desc if order == "desc" else asc
    return db.scalars(stmt.order_by(direction(column))).all()
|
||
|
||
|
||
@app.post("/api/resources", response_model=ResourceOut)
def create_resource(payload: ResourceCreate, db: Session = Depends(get_db)):
    """Insert a resource built from the request body and return the stored row."""
    resource = Resource(**payload.model_dump())
    db.add(resource)
    db.commit()
    # Reload from the database so values it assigned appear in the response.
    db.refresh(resource)
    return resource
|
||
|
||
|
||
@app.put("/api/resources/{item_id}", response_model=ResourceOut)
def update_resource(item_id: int, payload: ResourceUpdate, db: Session = Depends(get_db)):
    """Copy every field of the ``ResourceUpdate`` body onto resource ``item_id``.

    Raises:
        HTTPException: 404 when no resource has that id.
    """
    resource = db.get(Resource, item_id)
    if not resource:
        raise HTTPException(status_code=404, detail="Resource not found")
    changes = payload.model_dump()
    for field_name in changes:
        setattr(resource, field_name, changes[field_name])
    db.commit()
    db.refresh(resource)
    return resource
|
||
|
||
|
||
@app.delete("/api/resources/{item_id}")
def delete_resource(item_id: int, db: Session = Depends(get_db)):
    """Delete resource ``item_id``.

    Raises:
        HTTPException: 404 when no resource has that id.
    """
    doomed = db.get(Resource, item_id)
    if not doomed:
        raise HTTPException(status_code=404, detail="Resource not found")
    db.delete(doomed)
    db.commit()
    return {"success": True}
|
||
|
||
|
||
@app.patch("/api/resources/batch")
def batch_update_resources(payload: ResourceBatchUpdate, db: Session = Depends(get_db)):
    """Apply the same category and/or tags to many resources at once.

    Payload fields left as ``None`` are not touched. Ids that do not exist
    are skipped; ``count`` is the number of resources actually found.

    Raises:
        HTTPException: 400 when ``ids`` is empty.
    """
    if not payload.ids:
        raise HTTPException(status_code=400, detail="ids 不能为空")
    updates = {
        name: value
        for name, value in (("category", payload.category), ("tags", payload.tags))
        if value is not None
    }
    matched = db.scalars(select(Resource).where(Resource.id.in_(payload.ids))).all()
    for resource in matched:
        for name, value in updates.items():
            setattr(resource, name, value)
    db.commit()
    return {"success": True, "count": len(matched)}
|
||
|
||
|
||
@app.post("/api/resources/batch-delete")
def batch_delete_resources(payload: IdBatchPayload, db: Session = Depends(get_db)):
    """Delete every resource whose id appears in ``payload.ids``.

    Unknown ids are ignored; ``count`` is the number of resources deleted.

    Raises:
        HTTPException: 400 when ``ids`` is empty.
    """
    if not payload.ids:
        raise HTTPException(status_code=400, detail="ids 不能为空")
    query = select(Resource).where(Resource.id.in_(payload.ids))
    doomed = db.scalars(query).all()
    # Deleted one ORM object at a time via the session (not a bulk DELETE).
    for resource in doomed:
        db.delete(resource)
    db.commit()
    return {"success": True, "count": len(doomed)}
|
||
|
||
|
||
@app.get("/api/mistakes", response_model=list[MistakeOut])
def list_mistakes(
    category: str | None = None,
    keyword: str | None = None,
    sort_by: str = Query("created_at", pattern="^(created_at|wrong_count)$"),
    order: str = Query("desc", pattern="^(asc|desc)$"),
    db: Session = Depends(get_db),
):
    """List mistakes, optionally filtered by exact category and by keyword.

    ``keyword`` is a case-insensitive substring match on the note or title.
    Results are sorted by ``created_at`` or ``wrong_count``.
    """
    stmt = select(Mistake)
    if category:
        stmt = stmt.where(Mistake.category == category)
    if keyword:
        pattern = f"%{keyword}%"
        stmt = stmt.where(or_(Mistake.note.ilike(pattern), Mistake.title.ilike(pattern)))
    column = {"created_at": Mistake.created_at}.get(sort_by, Mistake.wrong_count)
    ordering = asc(column) if order != "desc" else desc(column)
    return db.scalars(stmt.order_by(ordering)).all()
|
||
|
||
|
||
@app.post("/api/mistakes", response_model=MistakeOut)
def create_mistake(payload: MistakeCreate, db: Session = Depends(get_db)):
    """Insert a mistake built from the request body and return the stored row."""
    mistake = Mistake(**payload.model_dump())
    db.add(mistake)
    db.commit()
    # Reload from the database so values it assigned appear in the response.
    db.refresh(mistake)
    return mistake
|
||
|
||
|
||
@app.put("/api/mistakes/{item_id}", response_model=MistakeOut)
def update_mistake(item_id: int, payload: MistakeUpdate, db: Session = Depends(get_db)):
    """Copy every field of the ``MistakeUpdate`` body onto mistake ``item_id``.

    Raises:
        HTTPException: 404 when no mistake has that id.
    """
    mistake = db.get(Mistake, item_id)
    if not mistake:
        raise HTTPException(status_code=404, detail="Mistake not found")
    changes = payload.model_dump()
    for field_name in changes:
        setattr(mistake, field_name, changes[field_name])
    db.commit()
    db.refresh(mistake)
    return mistake
|
||
|
||
|
||
@app.delete("/api/mistakes/{item_id}")
def delete_mistake(item_id: int, db: Session = Depends(get_db)):
    """Delete mistake ``item_id``.

    Raises:
        HTTPException: 404 when no mistake has that id.
    """
    doomed = db.get(Mistake, item_id)
    if not doomed:
        raise HTTPException(status_code=404, detail="Mistake not found")
    db.delete(doomed)
    db.commit()
    return {"success": True}
|
||
|
||
|
||
def _wrap_pdf_text(text: str, font_name: str, font_size: float, max_width: float) -> list[str]:
    """Split ``text`` into lines whose rendered width fits within ``max_width`` points.

    Wrapping is per character, since Chinese text has no spaces to break on.
    Embedded newlines start a new line, and an empty string yields one empty line.
    """
    wrapped: list[str] = []
    for paragraph in text.splitlines() or [""]:
        current = ""
        for char in paragraph:
            candidate = current + char
            if current and pdfmetrics.stringWidth(candidate, font_name, font_size) > max_width:
                wrapped.append(current)
                current = char
            else:
                current = candidate
        wrapped.append(current)
    return wrapped


@app.get("/api/mistakes/export/pdf")
def export_mistakes_pdf(
    category: str | None = None,
    start_date: date | None = None,
    end_date: date | None = None,
    db: Session = Depends(get_db),
):
    """Export the selected mistakes as a printable A4 PDF worksheet.

    Each mistake is rendered as its title, a metadata line, its note and a
    blank answer line. Long or multi-line text is wrapped to the printable
    width instead of being cut off or running past the page edge.

    Raises:
        HTTPException: 400 (via ``_query_mistakes_for_export``) when too many
        mistakes match.
    """
    items = _query_mistakes_for_export(db, category, start_date, end_date)

    font_name = "STSong-Light"
    font_size = 12
    left = 50
    top = 800
    bottom = 70
    line_height = 22
    max_width = A4[0] - 2 * left

    buf = BytesIO()
    pdf = canvas.Canvas(buf, pagesize=A4)
    # Built-in CJK CID font so Chinese text renders without an embedded TTF.
    pdfmetrics.registerFont(UnicodeCIDFont(font_name))
    pdf.setFont(font_name, font_size)

    y = top
    pdf.drawString(left, y, "公考助手 - 错题导出")
    y -= 30
    for idx, item in enumerate(items, start=1):
        entries = [
            f"{idx}. {item.title}",
            f"分类: {item.category} 难度: {item.difficulty or '未设置'} 错误频次: {item.wrong_count}",
            f"备注: {item.note or '无'}",
            "答题区: _______________________________",
        ]
        for entry in entries:
            for line in _wrap_pdf_text(entry, font_name, font_size, max_width):
                if y < bottom:
                    pdf.showPage()
                    # showPage() resets the graphics state, so the font must be set again.
                    pdf.setFont(font_name, font_size)
                    y = top
                pdf.drawString(left, y, line)
                y -= line_height
        # Extra gap between consecutive mistakes.
        y -= 6

    pdf.save()
    buf.seek(0)
    return StreamingResponse(
        buf,
        media_type="application/pdf",
        headers={"Content-Disposition": 'attachment; filename="mistakes.pdf"'},
    )
|
||
|
||
|
||
@app.get("/api/mistakes/export/docx")
def export_mistakes_docx(
    category: str | None = None,
    start_date: date | None = None,
    end_date: date | None = None,
    db: Session = Depends(get_db),
):
    """Export the selected mistakes as a Word (.docx) worksheet.

    Each mistake produces four paragraphs: title, metadata, note and a blank
    answer line.

    Raises:
        HTTPException: 400 (via ``_query_mistakes_for_export``) when too many
        mistakes match.
    """
    mistakes = _query_mistakes_for_export(db, category, start_date, end_date)
    document = Document()
    document.add_heading("公考助手 - 错题导出", level=1)
    for number, mistake in enumerate(mistakes, start=1):
        paragraphs = (
            f"{number}. {mistake.title}",
            f"分类: {mistake.category} | 难度: {mistake.difficulty or '未设置'} | 错误频次: {mistake.wrong_count}",
            f"备注: {mistake.note or '无'}",
            "答题区: ________________________________________",
        )
        for text in paragraphs:
            document.add_paragraph(text)

    stream = BytesIO()
    document.save(stream)
    stream.seek(0)
    return StreamingResponse(
        stream,
        media_type="application/vnd.openxmlformats-officedocument.wordprocessingml.document",
        headers={"Content-Disposition": 'attachment; filename="mistakes.docx"'},
    )
|
||
|
||
|
||
@app.get("/api/scores", response_model=list[ScoreRecordOut])
def list_scores(
    start_date: date | None = None,
    end_date: date | None = None,
    db: Session = Depends(get_db),
):
    """Return score records in chronological order within an optional inclusive date range."""
    filters = []
    if start_date:
        filters.append(ScoreRecord.exam_date >= start_date)
    if end_date:
        filters.append(ScoreRecord.exam_date <= end_date)
    stmt = select(ScoreRecord)
    if filters:
        stmt = stmt.where(*filters)
    return db.scalars(stmt.order_by(asc(ScoreRecord.exam_date))).all()
|
||
|
||
|
||
@app.post("/api/scores", response_model=ScoreRecordOut)
def create_score(payload: ScoreRecordCreate, db: Session = Depends(get_db)):
    """Record a new exam score.

    Raises:
        HTTPException: 400 when the exam date is after today.
    """
    _validate_score_date(payload.exam_date)
    record = ScoreRecord(**payload.model_dump())
    db.add(record)
    db.commit()
    db.refresh(record)
    return record
|
||
|
||
|
||
@app.put("/api/scores/{item_id}", response_model=ScoreRecordOut)
def update_score(item_id: int, payload: ScoreRecordUpdate, db: Session = Depends(get_db)):
    """Copy every field of the ``ScoreRecordUpdate`` body onto score record ``item_id``.

    The date check runs before the lookup, so a future date is reported
    (400) even when the id does not exist.

    Raises:
        HTTPException: 400 for a future exam date; 404 when no record has that id.
    """
    _validate_score_date(payload.exam_date)
    record = db.get(ScoreRecord, item_id)
    if not record:
        raise HTTPException(status_code=404, detail="Score record not found")
    changes = payload.model_dump()
    for field_name in changes:
        setattr(record, field_name, changes[field_name])
    db.commit()
    db.refresh(record)
    return record
|
||
|
||
|
||
@app.delete("/api/scores/{item_id}")
def delete_score(item_id: int, db: Session = Depends(get_db)):
    """Delete score record ``item_id``.

    Raises:
        HTTPException: 404 when no record has that id.
    """
    doomed = db.get(ScoreRecord, item_id)
    if not doomed:
        raise HTTPException(status_code=404, detail="Score record not found")
    db.delete(doomed)
    db.commit()
    return {"success": True}
|
||
|
||
|
||
@app.get("/api/scores/stats", response_model=ScoreStats)
def score_stats(db: Session = Depends(get_db)):
    """Summarise every score record.

    Returns:
        ``ScoreStats`` with the highest and lowest ``total_score``, the mean
        rounded to 2 decimals, and ``improvement`` = latest exam's score minus
        earliest exam's score. All fields are 0 when there are no records.
    """
    # Secondary sort on id keeps "earliest"/"latest" deterministic when
    # several exams share the same date.
    scores = db.scalars(
        select(ScoreRecord).order_by(asc(ScoreRecord.exam_date), asc(ScoreRecord.id))
    ).all()
    if not scores:
        return ScoreStats(highest=0, lowest=0, average=0, improvement=0)

    totals = [record.total_score for record in scores]
    # Compute the mean from the rows already loaded instead of a second AVG
    # query, so all four figures describe the same snapshot of the table.
    average = sum(totals) / len(totals)
    return ScoreStats(
        highest=max(totals),
        lowest=min(totals),
        average=round(float(average), 2),
        improvement=totals[-1] - totals[0],
    )
|
||
|
||
|
||
async def _call_qwen(system_prompt: str, user_prompt: str) -> str:
|
||
api_key = os.getenv("QWEN_API_KEY", "")
|
||
if not api_key:
|
||
return (
|
||
"当前未配置千问 API Key,已返回本地降级提示。\n"
|
||
"请在项目根目录创建 .env 并配置 QWEN_API_KEY 后重启服务。\n"
|
||
"示例:\n"
|
||
"QWEN_API_KEY=your_qwen_api_key_here\n"
|
||
"QWEN_BASE_URL=https://dashscope.aliyuncs.com/compatible-mode/v1\n"
|
||
"QWEN_MODEL=qwen-plus"
|
||
)
|
||
base_url = os.getenv("QWEN_BASE_URL", "https://dashscope.aliyuncs.com/compatible-mode/v1")
|
||
model = os.getenv("QWEN_MODEL", "qwen-plus")
|
||
headers = {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"}
|
||
payload = {
|
||
"model": model,
|
||
"messages": [
|
||
{"role": "system", "content": system_prompt},
|
||
{"role": "user", "content": user_prompt},
|
||
],
|
||
"temperature": 0.4,
|
||
}
|
||
async with httpx.AsyncClient(timeout=40) as client:
|
||
resp = await client.post(f"{base_url}/chat/completions", headers=headers, json=payload)
|
||
if resp.status_code >= 300:
|
||
raise HTTPException(status_code=502, detail=f"千问请求失败: {resp.text}")
|
||
data = resp.json()
|
||
return data["choices"][0]["message"]["content"]
|
||
|
||
|
||
@app.post("/api/ai/mistakes/{item_id}/analyze", response_model=AiMistakeAnalysisOut)
async def ai_analyze_mistake(item_id: int, db: Session = Depends(get_db)):
    """Ask Qwen for a structured analysis of mistake ``item_id``.

    Raises:
        HTTPException: 404 when the mistake does not exist; any HTTPException
        raised by ``_call_qwen`` propagates unchanged.
    """
    mistake = db.get(Mistake, item_id)
    if not mistake:
        raise HTTPException(status_code=404, detail="Mistake not found")
    system_prompt = "你是公考备考教练,请输出结构化、可执行的错题分析。"
    user_prompt = "".join(
        [
            f"错题标题: {mistake.title}\n",
            f"分类: {mistake.category}\n",
            f"难度: {mistake.difficulty or '未设置'}\n",
            f"错误频次: {mistake.wrong_count}\n",
            f"备注: {mistake.note or '无'}\n\n",
            "请按以下结构输出:\n",
            "1) 错误根因\n2) 关键知识点\n3) 3步复盘法\n4) 3道同类训练建议\n5) 明日复习安排",
        ]
    )
    analysis = await _call_qwen(system_prompt, user_prompt)
    return AiMistakeAnalysisOut(analysis=analysis)
|
||
|
||
|
||
@app.post("/api/ai/study-plan", response_model=AiStudyPlanOut)
async def ai_study_plan(payload: AiStudyPlanIn, db: Session = Depends(get_db)):
    """Ask Qwen for a study plan based on the user's goal and recent data.

    The prompt includes the goal, remaining days and daily hours from the
    payload, scores from the last 30 days (oldest first) and the 20 most
    recently created mistakes.
    """
    cutoff = date.today() - timedelta(days=30)
    score_rows = db.scalars(
        select(ScoreRecord)
        .where(ScoreRecord.exam_date >= cutoff)
        .order_by(asc(ScoreRecord.exam_date))
    ).all()
    mistake_rows = db.scalars(select(Mistake).order_by(desc(Mistake.created_at)).limit(20)).all()

    score_text = ", ".join(f"{row.exam_date}:{row.total_score}" for row in score_rows) or "暂无成绩数据"
    mistake_text = ", ".join(f"{row.category}-{row.title}" for row in mistake_rows) or "暂无错题数据"

    user_prompt = "\n".join(
        [
            f"目标: {payload.goal}",
            f"剩余天数: {payload.days_left}",
            f"每天可学习小时数: {payload.daily_hours}",
            f"近30天分数: {score_text}",
            f"近期错题: {mistake_text}",
            "",
            "请输出: 周计划表、每日任务模板、错题复盘节奏、模考安排、风险提醒。",
        ]
    )
    plan = await _call_qwen("你是公考学习规划师,请给出可执行计划并尽量量化。", user_prompt)
    return AiStudyPlanOut(plan=plan)
|