feat: add new file
This commit is contained in:
434
backend/app/main.py
Normal file
434
backend/app/main.py
Normal file
@@ -0,0 +1,434 @@
|
||||
import os
|
||||
import uuid
|
||||
from datetime import date, datetime, timedelta
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
|
||||
import httpx
|
||||
from docx import Document
|
||||
from fastapi import Depends, FastAPI, File, HTTPException, Query, UploadFile
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import StreamingResponse
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
from reportlab.lib.pagesizes import A4
|
||||
from reportlab.pdfbase import pdfmetrics
|
||||
from reportlab.pdfbase.cidfonts import UnicodeCIDFont
|
||||
from reportlab.pdfgen import canvas
|
||||
from sqlalchemy import asc, desc, func, or_, select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from .database import Base, engine, get_db
|
||||
from .models import Mistake, Resource, ScoreRecord
|
||||
from .schemas import (
|
||||
AiMistakeAnalysisOut,
|
||||
AiStudyPlanIn,
|
||||
AiStudyPlanOut,
|
||||
IdBatchPayload,
|
||||
MistakeCreate,
|
||||
MistakeOut,
|
||||
MistakeUpdate,
|
||||
ResourceBatchUpdate,
|
||||
ResourceCreate,
|
||||
ResourceOut,
|
||||
ResourceUpdate,
|
||||
ScoreRecordCreate,
|
||||
ScoreRecordOut,
|
||||
ScoreRecordUpdate,
|
||||
ScoreStats,
|
||||
)
|
||||
|
||||
Base.metadata.create_all(bind=engine)
|
||||
|
||||
app = FastAPI(title="公考助手 API", version="1.0.0")
|
||||
UPLOAD_DIR = Path(os.getenv("UPLOAD_DIR", "/app/uploads"))
|
||||
UPLOAD_DIR.mkdir(parents=True, exist_ok=True)
|
||||
app.mount("/uploads", StaticFiles(directory=str(UPLOAD_DIR)), name="uploads")
|
||||
|
||||
origins = os.getenv("CORS_ORIGINS", "http://localhost:5173").split(",")
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=[origin.strip() for origin in origins if origin.strip()],
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
|
||||
@app.get("/health")
|
||||
def health():
|
||||
return {"status": "ok"}
|
||||
|
||||
|
||||
def _validate_score_date(exam_date: date) -> None:
|
||||
if exam_date > date.today():
|
||||
raise HTTPException(status_code=400, detail="考试时间不能晚于今天")
|
||||
|
||||
|
||||
def _query_mistakes_for_export(
|
||||
db: Session,
|
||||
category: str | None,
|
||||
start_date: date | None,
|
||||
end_date: date | None,
|
||||
):
|
||||
stmt = select(Mistake)
|
||||
if category:
|
||||
stmt = stmt.where(Mistake.category == category)
|
||||
if start_date:
|
||||
stmt = stmt.where(Mistake.created_at >= datetime.combine(start_date, datetime.min.time()))
|
||||
if end_date:
|
||||
stmt = stmt.where(Mistake.created_at <= datetime.combine(end_date, datetime.max.time()))
|
||||
items = db.scalars(stmt.order_by(desc(Mistake.created_at))).all()
|
||||
if len(items) > 200:
|
||||
raise HTTPException(status_code=400, detail="单次最多导出 200 题")
|
||||
return items
|
||||
|
||||
|
||||
@app.post("/api/upload")
|
||||
async def upload_file(file: UploadFile = File(...)):
|
||||
suffix = Path(file.filename or "").suffix.lower()
|
||||
allowed = {".pdf", ".doc", ".docx", ".jpg", ".jpeg", ".png", ".webp"}
|
||||
if suffix not in allowed:
|
||||
raise HTTPException(status_code=400, detail="不支持的文件类型")
|
||||
content = await file.read()
|
||||
if len(content) > 50 * 1024 * 1024:
|
||||
raise HTTPException(status_code=400, detail="文件不能超过 50MB")
|
||||
file_name = f"{uuid.uuid4().hex}{suffix}"
|
||||
target = UPLOAD_DIR / file_name
|
||||
target.write_bytes(content)
|
||||
return {"url": f"/uploads/{file_name}", "original_name": file.filename}
|
||||
|
||||
|
||||
@app.get("/api/resources", response_model=list[ResourceOut])
|
||||
def list_resources(
|
||||
q: str | None = None,
|
||||
category: str | None = None,
|
||||
tags: str | None = None,
|
||||
resource_type: str | None = None,
|
||||
sort_by: str = Query("created_at", pattern="^(created_at|title|name)$"),
|
||||
order: str = Query("desc", pattern="^(asc|desc)$"),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
stmt = select(Resource)
|
||||
if q:
|
||||
stmt = stmt.where(
|
||||
or_(
|
||||
Resource.title.ilike(f"%{q}%"),
|
||||
Resource.tags.ilike(f"%{q}%"),
|
||||
Resource.url.ilike(f"%{q}%"),
|
||||
)
|
||||
)
|
||||
if category:
|
||||
stmt = stmt.where(Resource.category == category)
|
||||
if tags:
|
||||
stmt = stmt.where(Resource.tags.ilike(f"%{tags}%"))
|
||||
if resource_type:
|
||||
stmt = stmt.where(Resource.resource_type == resource_type)
|
||||
|
||||
sort_col = Resource.created_at if sort_by == "created_at" else Resource.title
|
||||
stmt = stmt.order_by(desc(sort_col) if order == "desc" else asc(sort_col))
|
||||
return db.scalars(stmt).all()
|
||||
|
||||
|
||||
@app.post("/api/resources", response_model=ResourceOut)
|
||||
def create_resource(payload: ResourceCreate, db: Session = Depends(get_db)):
|
||||
item = Resource(**payload.model_dump())
|
||||
db.add(item)
|
||||
db.commit()
|
||||
db.refresh(item)
|
||||
return item
|
||||
|
||||
|
||||
@app.put("/api/resources/{item_id}", response_model=ResourceOut)
|
||||
def update_resource(item_id: int, payload: ResourceUpdate, db: Session = Depends(get_db)):
|
||||
item = db.get(Resource, item_id)
|
||||
if not item:
|
||||
raise HTTPException(status_code=404, detail="Resource not found")
|
||||
for k, v in payload.model_dump().items():
|
||||
setattr(item, k, v)
|
||||
db.commit()
|
||||
db.refresh(item)
|
||||
return item
|
||||
|
||||
|
||||
@app.delete("/api/resources/{item_id}")
|
||||
def delete_resource(item_id: int, db: Session = Depends(get_db)):
|
||||
item = db.get(Resource, item_id)
|
||||
if not item:
|
||||
raise HTTPException(status_code=404, detail="Resource not found")
|
||||
db.delete(item)
|
||||
db.commit()
|
||||
return {"success": True}
|
||||
|
||||
|
||||
@app.patch("/api/resources/batch")
|
||||
def batch_update_resources(payload: ResourceBatchUpdate, db: Session = Depends(get_db)):
|
||||
if not payload.ids:
|
||||
raise HTTPException(status_code=400, detail="ids 不能为空")
|
||||
items = db.scalars(select(Resource).where(Resource.id.in_(payload.ids))).all()
|
||||
for item in items:
|
||||
if payload.category is not None:
|
||||
item.category = payload.category
|
||||
if payload.tags is not None:
|
||||
item.tags = payload.tags
|
||||
db.commit()
|
||||
return {"success": True, "count": len(items)}
|
||||
|
||||
|
||||
@app.post("/api/resources/batch-delete")
|
||||
def batch_delete_resources(payload: IdBatchPayload, db: Session = Depends(get_db)):
|
||||
if not payload.ids:
|
||||
raise HTTPException(status_code=400, detail="ids 不能为空")
|
||||
items = db.scalars(select(Resource).where(Resource.id.in_(payload.ids))).all()
|
||||
for item in items:
|
||||
db.delete(item)
|
||||
db.commit()
|
||||
return {"success": True, "count": len(items)}
|
||||
|
||||
|
||||
@app.get("/api/mistakes", response_model=list[MistakeOut])
|
||||
def list_mistakes(
|
||||
category: str | None = None,
|
||||
keyword: str | None = None,
|
||||
sort_by: str = Query("created_at", pattern="^(created_at|wrong_count)$"),
|
||||
order: str = Query("desc", pattern="^(asc|desc)$"),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
stmt = select(Mistake)
|
||||
if category:
|
||||
stmt = stmt.where(Mistake.category == category)
|
||||
if keyword:
|
||||
stmt = stmt.where(or_(Mistake.note.ilike(f"%{keyword}%"), Mistake.title.ilike(f"%{keyword}%")))
|
||||
sort_col = Mistake.created_at if sort_by == "created_at" else Mistake.wrong_count
|
||||
stmt = stmt.order_by(desc(sort_col) if order == "desc" else asc(sort_col))
|
||||
return db.scalars(stmt).all()
|
||||
|
||||
|
||||
@app.post("/api/mistakes", response_model=MistakeOut)
|
||||
def create_mistake(payload: MistakeCreate, db: Session = Depends(get_db)):
|
||||
item = Mistake(**payload.model_dump())
|
||||
db.add(item)
|
||||
db.commit()
|
||||
db.refresh(item)
|
||||
return item
|
||||
|
||||
|
||||
@app.put("/api/mistakes/{item_id}", response_model=MistakeOut)
|
||||
def update_mistake(item_id: int, payload: MistakeUpdate, db: Session = Depends(get_db)):
|
||||
item = db.get(Mistake, item_id)
|
||||
if not item:
|
||||
raise HTTPException(status_code=404, detail="Mistake not found")
|
||||
for k, v in payload.model_dump().items():
|
||||
setattr(item, k, v)
|
||||
db.commit()
|
||||
db.refresh(item)
|
||||
return item
|
||||
|
||||
|
||||
@app.delete("/api/mistakes/{item_id}")
|
||||
def delete_mistake(item_id: int, db: Session = Depends(get_db)):
|
||||
item = db.get(Mistake, item_id)
|
||||
if not item:
|
||||
raise HTTPException(status_code=404, detail="Mistake not found")
|
||||
db.delete(item)
|
||||
db.commit()
|
||||
return {"success": True}
|
||||
|
||||
|
||||
@app.get("/api/mistakes/export/pdf")
|
||||
def export_mistakes_pdf(
|
||||
category: str | None = None,
|
||||
start_date: date | None = None,
|
||||
end_date: date | None = None,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
items = _query_mistakes_for_export(db, category, start_date, end_date)
|
||||
buf = BytesIO()
|
||||
pdf = canvas.Canvas(buf, pagesize=A4)
|
||||
pdfmetrics.registerFont(UnicodeCIDFont("STSong-Light"))
|
||||
pdf.setFont("STSong-Light", 12)
|
||||
|
||||
y = 800
|
||||
pdf.drawString(50, y, "公考助手 - 错题导出")
|
||||
y -= 30
|
||||
for idx, item in enumerate(items, start=1):
|
||||
lines = [
|
||||
f"{idx}. {item.title}",
|
||||
f"分类: {item.category} 难度: {item.difficulty or '未设置'} 错误频次: {item.wrong_count}",
|
||||
f"备注: {item.note or '无'}",
|
||||
"答题区: _______________________________",
|
||||
]
|
||||
for line in lines:
|
||||
if y < 70:
|
||||
pdf.showPage()
|
||||
pdf.setFont("STSong-Light", 12)
|
||||
y = 800
|
||||
pdf.drawString(50, y, line[:90])
|
||||
y -= 22
|
||||
y -= 6
|
||||
|
||||
pdf.save()
|
||||
buf.seek(0)
|
||||
return StreamingResponse(
|
||||
buf,
|
||||
media_type="application/pdf",
|
||||
headers={"Content-Disposition": 'attachment; filename="mistakes.pdf"'},
|
||||
)
|
||||
|
||||
|
||||
@app.get("/api/mistakes/export/docx")
|
||||
def export_mistakes_docx(
|
||||
category: str | None = None,
|
||||
start_date: date | None = None,
|
||||
end_date: date | None = None,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
items = _query_mistakes_for_export(db, category, start_date, end_date)
|
||||
doc = Document()
|
||||
doc.add_heading("公考助手 - 错题导出", level=1)
|
||||
for idx, item in enumerate(items, start=1):
|
||||
doc.add_paragraph(f"{idx}. {item.title}")
|
||||
doc.add_paragraph(f"分类: {item.category} | 难度: {item.difficulty or '未设置'} | 错误频次: {item.wrong_count}")
|
||||
doc.add_paragraph(f"备注: {item.note or '无'}")
|
||||
doc.add_paragraph("答题区: ________________________________________")
|
||||
|
||||
buf = BytesIO()
|
||||
doc.save(buf)
|
||||
buf.seek(0)
|
||||
return StreamingResponse(
|
||||
buf,
|
||||
media_type="application/vnd.openxmlformats-officedocument.wordprocessingml.document",
|
||||
headers={"Content-Disposition": 'attachment; filename="mistakes.docx"'},
|
||||
)
|
||||
|
||||
|
||||
@app.get("/api/scores", response_model=list[ScoreRecordOut])
|
||||
def list_scores(
|
||||
start_date: date | None = None,
|
||||
end_date: date | None = None,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
stmt = select(ScoreRecord)
|
||||
if start_date:
|
||||
stmt = stmt.where(ScoreRecord.exam_date >= start_date)
|
||||
if end_date:
|
||||
stmt = stmt.where(ScoreRecord.exam_date <= end_date)
|
||||
stmt = stmt.order_by(asc(ScoreRecord.exam_date))
|
||||
return db.scalars(stmt).all()
|
||||
|
||||
|
||||
@app.post("/api/scores", response_model=ScoreRecordOut)
|
||||
def create_score(payload: ScoreRecordCreate, db: Session = Depends(get_db)):
|
||||
_validate_score_date(payload.exam_date)
|
||||
item = ScoreRecord(**payload.model_dump())
|
||||
db.add(item)
|
||||
db.commit()
|
||||
db.refresh(item)
|
||||
return item
|
||||
|
||||
|
||||
@app.put("/api/scores/{item_id}", response_model=ScoreRecordOut)
|
||||
def update_score(item_id: int, payload: ScoreRecordUpdate, db: Session = Depends(get_db)):
|
||||
_validate_score_date(payload.exam_date)
|
||||
item = db.get(ScoreRecord, item_id)
|
||||
if not item:
|
||||
raise HTTPException(status_code=404, detail="Score record not found")
|
||||
for k, v in payload.model_dump().items():
|
||||
setattr(item, k, v)
|
||||
db.commit()
|
||||
db.refresh(item)
|
||||
return item
|
||||
|
||||
|
||||
@app.delete("/api/scores/{item_id}")
|
||||
def delete_score(item_id: int, db: Session = Depends(get_db)):
|
||||
item = db.get(ScoreRecord, item_id)
|
||||
if not item:
|
||||
raise HTTPException(status_code=404, detail="Score record not found")
|
||||
db.delete(item)
|
||||
db.commit()
|
||||
return {"success": True}
|
||||
|
||||
|
||||
@app.get("/api/scores/stats", response_model=ScoreStats)
|
||||
def score_stats(db: Session = Depends(get_db)):
|
||||
scores = db.scalars(select(ScoreRecord).order_by(asc(ScoreRecord.exam_date))).all()
|
||||
if not scores:
|
||||
return ScoreStats(highest=0, lowest=0, average=0, improvement=0)
|
||||
|
||||
highest = max(item.total_score for item in scores)
|
||||
lowest = min(item.total_score for item in scores)
|
||||
avg = db.scalar(select(func.avg(ScoreRecord.total_score))) or 0
|
||||
improvement = scores[-1].total_score - scores[0].total_score
|
||||
return ScoreStats(highest=highest, lowest=lowest, average=round(float(avg), 2), improvement=improvement)
|
||||
|
||||
|
||||
async def _call_qwen(system_prompt: str, user_prompt: str) -> str:
|
||||
api_key = os.getenv("QWEN_API_KEY", "")
|
||||
if not api_key:
|
||||
return (
|
||||
"当前未配置千问 API Key,已返回本地降级提示。\n"
|
||||
"请在项目根目录创建 .env 并配置 QWEN_API_KEY 后重启服务。\n"
|
||||
"示例:\n"
|
||||
"QWEN_API_KEY=your_qwen_api_key_here\n"
|
||||
"QWEN_BASE_URL=https://dashscope.aliyuncs.com/compatible-mode/v1\n"
|
||||
"QWEN_MODEL=qwen-plus"
|
||||
)
|
||||
base_url = os.getenv("QWEN_BASE_URL", "https://dashscope.aliyuncs.com/compatible-mode/v1")
|
||||
model = os.getenv("QWEN_MODEL", "qwen-plus")
|
||||
headers = {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"}
|
||||
payload = {
|
||||
"model": model,
|
||||
"messages": [
|
||||
{"role": "system", "content": system_prompt},
|
||||
{"role": "user", "content": user_prompt},
|
||||
],
|
||||
"temperature": 0.4,
|
||||
}
|
||||
async with httpx.AsyncClient(timeout=40) as client:
|
||||
resp = await client.post(f"{base_url}/chat/completions", headers=headers, json=payload)
|
||||
if resp.status_code >= 300:
|
||||
raise HTTPException(status_code=502, detail=f"千问请求失败: {resp.text}")
|
||||
data = resp.json()
|
||||
return data["choices"][0]["message"]["content"]
|
||||
|
||||
|
||||
@app.post("/api/ai/mistakes/{item_id}/analyze", response_model=AiMistakeAnalysisOut)
|
||||
async def ai_analyze_mistake(item_id: int, db: Session = Depends(get_db)):
|
||||
item = db.get(Mistake, item_id)
|
||||
if not item:
|
||||
raise HTTPException(status_code=404, detail="Mistake not found")
|
||||
content = await _call_qwen(
|
||||
"你是公考备考教练,请输出结构化、可执行的错题分析。",
|
||||
(
|
||||
f"错题标题: {item.title}\n"
|
||||
f"分类: {item.category}\n"
|
||||
f"难度: {item.difficulty or '未设置'}\n"
|
||||
f"错误频次: {item.wrong_count}\n"
|
||||
f"备注: {item.note or '无'}\n\n"
|
||||
"请按以下结构输出:\n"
|
||||
"1) 错误根因\n2) 关键知识点\n3) 3步复盘法\n4) 3道同类训练建议\n5) 明日复习安排"
|
||||
),
|
||||
)
|
||||
return AiMistakeAnalysisOut(analysis=content)
|
||||
|
||||
|
||||
@app.post("/api/ai/study-plan", response_model=AiStudyPlanOut)
|
||||
async def ai_study_plan(payload: AiStudyPlanIn, db: Session = Depends(get_db)):
|
||||
since = date.today() - timedelta(days=30)
|
||||
recent_scores = db.scalars(select(ScoreRecord).where(ScoreRecord.exam_date >= since).order_by(asc(ScoreRecord.exam_date))).all()
|
||||
recent_mistakes = db.scalars(select(Mistake).order_by(desc(Mistake.created_at)).limit(20)).all()
|
||||
score_text = ", ".join([f"{s.exam_date}:{s.total_score}" for s in recent_scores]) or "暂无成绩数据"
|
||||
mistake_text = ", ".join([f"{m.category}-{m.title}" for m in recent_mistakes]) or "暂无错题数据"
|
||||
|
||||
content = await _call_qwen(
|
||||
"你是公考学习规划师,请给出可执行计划并尽量量化。",
|
||||
(
|
||||
f"目标: {payload.goal}\n"
|
||||
f"剩余天数: {payload.days_left}\n"
|
||||
f"每天可学习小时数: {payload.daily_hours}\n"
|
||||
f"近30天分数: {score_text}\n"
|
||||
f"近期错题: {mistake_text}\n\n"
|
||||
"请输出: 周计划表、每日任务模板、错题复盘节奏、模考安排、风险提醒。"
|
||||
),
|
||||
)
|
||||
return AiStudyPlanOut(plan=content)
|
||||
Reference in New Issue
Block a user