@@ -1,5 +1,9 @@
import os
import json
import uuid
import zipfile
import base64
from typing import Any
from datetime import date , datetime , timedelta
from io import BytesIO
from pathlib import Path
@@ -14,7 +18,7 @@ from reportlab.lib.pagesizes import A4
from reportlab . pdfbase import pdfmetrics
from reportlab . pdfbase . cidfonts import UnicodeCIDFont
from reportlab . pdfgen import canvas
from sqlalchemy import asc , desc , func , or_ , select
from sqlalchemy import asc , desc , func , inspect , or_ , select , text
from sqlalchemy . orm import Session
from . database import Base , engine , get_db
@@ -27,6 +31,8 @@ from .schemas import (
MistakeCreate ,
MistakeOut ,
MistakeUpdate ,
OcrParseIn ,
OcrParseOut ,
ResourceBatchUpdate ,
ResourceCreate ,
ResourceOut ,
@@ -39,6 +45,23 @@ from .schemas import (
Base . metadata . create_all ( bind = engine )
def _migrate_mistake_columns ( ) - > None :
inspector = inspect ( engine )
if " mistakes " not in inspector . get_table_names ( ) :
return
existed = { col [ " name " ] for col in inspector . get_columns ( " mistakes " ) }
with engine . begin ( ) as conn :
if " question_content " not in existed :
conn . execute ( text ( " ALTER TABLE mistakes ADD COLUMN question_content TEXT " ) )
if " answer " not in existed :
conn . execute ( text ( " ALTER TABLE mistakes ADD COLUMN answer TEXT " ) )
if " explanation " not in existed :
conn . execute ( text ( " ALTER TABLE mistakes ADD COLUMN explanation TEXT " ) )
_migrate_mistake_columns ( )
app = FastAPI ( title = " 公考助手 API " , version = " 1.0.0 " )
UPLOAD_DIR = Path ( os . getenv ( " UPLOAD_DIR " , " /app/uploads " ) )
UPLOAD_DIR . mkdir ( parents = True , exist_ok = True )
@@ -69,6 +92,7 @@ def _query_mistakes_for_export(
category : str | None ,
start_date : date | None ,
end_date : date | None ,
ids : list [ int ] | None = None ,
) :
stmt = select ( Mistake )
if category :
@@ -77,12 +101,175 @@ def _query_mistakes_for_export(
stmt = stmt . where ( Mistake . created_at > = datetime . combine ( start_date , datetime . min . time ( ) ) )
if end_date :
stmt = stmt . where ( Mistake . created_at < = datetime . combine ( end_date , datetime . max . time ( ) ) )
if ids :
stmt = stmt . where ( Mistake . id . in_ ( ids ) )
items = db . scalars ( stmt . order_by ( desc ( Mistake . created_at ) ) ) . all ( )
if len ( items ) > 200 :
raise HTTPException ( status_code = 400 , detail = " 单次最多导出 200 题 " )
return items
def _validate_mistake_payload ( payload : MistakeCreate | MistakeUpdate ) - > None :
has_image = bool ( ( payload . image_url or " " ) . strip ( ) )
has_question = bool ( ( payload . question_content or " " ) . strip ( ) )
has_answer = bool ( ( payload . answer or " " ) . strip ( ) )
if not has_image and not has_question and not has_answer :
raise HTTPException ( status_code = 400 , detail = " 请上传题目图片或填写试题/答案后再保存 " )
def _normalize_multiline_text ( value : str | None ) - > str :
if not value :
return " "
text_value = value . replace ( " \r \n " , " \n " ) . replace ( " \r " , " \n " )
lines = [ line . strip ( ) for line in text_value . split ( " \n " ) ]
compact = [ line for line in lines if line ]
return " \n " . join ( compact ) . strip ( )
def _wrap_pdf_text ( text : str , max_width : float , font_name : str = " STSong-Light " , font_size : int = 12 ) - > list [ str ] :
normalized = _normalize_multiline_text ( text )
if not normalized :
return [ ]
wrapped : list [ str ] = [ ]
for raw_line in normalized . split ( " \n " ) :
current = " "
for ch in raw_line :
candidate = f " { current } { ch } "
if pdfmetrics . stringWidth ( candidate , font_name , font_size ) < = max_width :
current = candidate
else :
if current :
wrapped . append ( current )
current = ch
if current :
wrapped . append ( current )
return wrapped
def _mistake_export_blocks ( item : Mistake , content_mode : str ) - > list [ str ] :
question = _normalize_multiline_text ( item . question_content )
answer = _normalize_multiline_text ( item . answer )
explanation = _normalize_multiline_text ( item . explanation )
if not question :
question = " 无题干与选项内容 "
blocks : list [ str ] = [ question ]
if content_mode == " full " :
# 「答案:」「解析:」与正文同一行开头,避免标签单独成行(与题号+题干规则一致)
blocks . append ( f " 答案: { answer or ' 无 ' } " )
blocks . append ( f " 解析: { explanation or ' 无 ' } " )
return blocks
def _extract_upload_filename ( url : str | None ) - > str | None :
if not url or not url . startswith ( " /uploads/ " ) :
return None
return Path ( url ) . name
def _safe_datetime ( value : str | None ) - > datetime :
if not value :
return datetime . utcnow ( )
try :
return datetime . fromisoformat ( value . replace ( " Z " , " +00:00 " ) ) . replace ( tzinfo = None )
except ValueError :
return datetime . utcnow ( )
def _safe_date ( value : str | None ) - > date :
if not value :
return date . today ( )
try :
return date . fromisoformat ( value )
except ValueError :
return date . today ( )
def _extract_json_text ( raw_text : str ) - > str :
content = raw_text . strip ( )
if content . startswith ( " ``` " ) :
lines = content . splitlines ( )
if lines :
lines = lines [ 1 : ]
if lines and lines [ - 1 ] . strip ( ) == " ``` " :
lines = lines [ : - 1 ]
content = " \n " . join ( lines ) . strip ( )
return content
def _dump_all_data ( db : Session ) - > dict :
resources = db . scalars ( select ( Resource ) . order_by ( asc ( Resource . id ) ) ) . all ( )
mistakes = db . scalars ( select ( Mistake ) . order_by ( asc ( Mistake . id ) ) ) . all ( )
scores = db . scalars ( select ( ScoreRecord ) . order_by ( asc ( ScoreRecord . id ) ) ) . all ( )
return {
" meta " : {
" exported_at " : datetime . utcnow ( ) . isoformat ( ) ,
" version " : " 1.1.0 " ,
} ,
" resources " : [
{
" id " : item . id ,
" title " : item . title ,
" resource_type " : item . resource_type ,
" url " : item . url ,
" file_name " : item . file_name ,
" category " : item . category ,
" tags " : item . tags ,
" created_at " : item . created_at . isoformat ( ) if item . created_at else None ,
}
for item in resources
] ,
" mistakes " : [
{
" id " : item . id ,
" title " : item . title ,
" image_url " : item . image_url ,
" category " : item . category ,
" difficulty " : item . difficulty ,
" question_content " : item . question_content ,
" answer " : item . answer ,
" explanation " : item . explanation ,
" note " : item . note ,
" wrong_count " : item . wrong_count ,
" created_at " : item . created_at . isoformat ( ) if item . created_at else None ,
}
for item in mistakes
] ,
" scores " : [
{
" id " : item . id ,
" exam_name " : item . exam_name ,
" exam_date " : item . exam_date . isoformat ( ) if item . exam_date else None ,
" total_score " : item . total_score ,
" module_scores " : item . module_scores ,
" created_at " : item . created_at . isoformat ( ) if item . created_at else None ,
}
for item in scores
] ,
}
def _restore_upload_url_from_zip ( url : str | None , zip_ref : zipfile . ZipFile ) - > str | None :
if not url :
return None
file_name = _extract_upload_filename ( url )
if not file_name :
return url
zip_path = f " uploads/ { file_name } "
if zip_path not in zip_ref . namelist ( ) :
return url
data = zip_ref . read ( zip_path )
target_name = file_name
target_path = UPLOAD_DIR / target_name
if target_path . exists ( ) :
target_name = f " { uuid . uuid4 ( ) . hex } _ { file_name } "
target_path = UPLOAD_DIR / target_name
target_path . write_bytes ( data )
return f " /uploads/ { target_name } "
@app.post ( " /api/upload " )
async def upload_file ( file : UploadFile = File ( . . . ) ) :
suffix = Path ( file . filename or " " ) . suffix . lower ( )
@@ -197,7 +384,16 @@ def list_mistakes(
if category :
stmt = stmt . where ( Mistake . category == category )
if keyword :
stmt = stmt . where ( or_ ( Mistake . note . ilike ( f " % { keyword } % " ) , Mistake . title . ilike ( f " % { keyword } % " ) ) )
stmt = stmt . where (
or_ (
Mistake . note . ilike ( f " % { keyword } % " ) ,
Mistake . title . ilike ( f " % { keyword } % " ) ,
Mistake . question_content . ilike ( f " % { keyword } % " ) ,
Mistake . answer . ilike ( f " % { keyword } % " ) ,
Mistake . explanation . ilike ( f " % { keyword } % " ) ,
Mistake . image_url . ilike ( f " % { keyword } % " ) ,
)
)
sort_col = Mistake . created_at if sort_by == " created_at " else Mistake . wrong_count
stmt = stmt . order_by ( desc ( sort_col ) if order == " desc " else asc ( sort_col ) )
return db . scalars ( stmt ) . all ( )
@@ -205,6 +401,7 @@ def list_mistakes(
@app.post ( " /api/mistakes " , response_model = MistakeOut )
def create_mistake ( payload : MistakeCreate , db : Session = Depends ( get_db ) ) :
_validate_mistake_payload ( payload )
item = Mistake ( * * payload . model_dump ( ) )
db . add ( item )
db . commit ( )
@@ -214,6 +411,7 @@ def create_mistake(payload: MistakeCreate, db: Session = Depends(get_db)):
@app.put ( " /api/mistakes/ {item_id} " , response_model = MistakeOut )
def update_mistake ( item_id : int , payload : MistakeUpdate , db : Session = Depends ( get_db ) ) :
_validate_mistake_payload ( payload )
item = db . get ( Mistake , item_id )
if not item :
raise HTTPException ( status_code = 404 , detail = " Mistake not found " )
@@ -239,32 +437,44 @@ def export_mistakes_pdf(
category : str | None = None ,
start_date : date | None = None ,
end_date : date | None = None ,
ids : str | None = None ,
content_mode : str = Query ( " full " , pattern = " ^(full|question_only)$ " ) ,
db : Session = Depends ( get_db ) ,
) :
items = _query_mistakes_for_export ( db , category , start_date , end_date )
id_list = [ int ( x ) for x in ids . split ( " , " ) if x . strip ( ) . isdigit ( ) ] if ids else None
items = _query_mistakes_for_export ( db , category , start_date , end_date , id_list )
buf = BytesIO ( )
pdf = canvas . Canvas ( buf , pagesize = A4 )
pdfmetrics . registerFont ( UnicodeCIDFont ( " STSong-Light " ) )
pdf . setFont ( " STSong-Light " , 12 )
y = 800
pdf . drawString ( 50 , y , " 公考助手 - 错题导出 " )
y - = 3 0
left = 48
right = 56 0
max_width = right - left
pdf . drawString ( left , y , " 公考助手 - 错题导出 " )
y - = 28
for idx , item in enumerate ( items , start = 1 ) :
lines = [
f " { idx } . { item . title } " ,
f " 分类: { item . category } 难度: { item . difficulty or ' 未设置 ' } 错误频次: { item . wrong_count } " ,
f " 备注: { item . note or ' 无 ' } " ,
" 答题区: _______________________________ " ,
]
for line in lines :
if y < 7 0:
pdf . showPage ( )
pdf . setFont ( " STSong-Light " , 12 )
y = 800
pdf . drawString ( 50 , y , line [ : 90 ] )
y - = 22
y - = 6
if y < 90 :
pdf . showPage ( )
pdf . setFont ( " STSong-Light " , 12 )
y = 800
blocks = _mistake_export_blocks ( item , content_mode )
for bi , block in enumerate ( blocks ) :
# 题号与题干同一行开头, 避免「1.」单独成行
text = f " { idx } . { block } " if bi == 0 else block
lines = _wrap_pdf_text ( text , max_width = max_width )
if not lines :
continue
for line in lines :
if y < 70 :
pdf . showPage ( )
pdf . setFont ( " STSong-Light " , 12 )
y = 800
pdf . drawString ( left , y , line )
y - = 18
y - = 6
y - = 8
pdf . save ( )
buf . seek ( 0 )
@@ -280,16 +490,20 @@ def export_mistakes_docx(
category : str | None = None ,
start_date : date | None = None ,
end_date : date | None = None ,
ids : str | None = None ,
content_mode : str = Query ( " full " , pattern = " ^(full|question_only)$ " ) ,
db : Session = Depends ( get_db ) ,
) :
items = _query_mistakes_for_export ( db , category , start_date , end_date )
id_list = [ int ( x ) for x in ids . split ( " , " ) if x . strip ( ) . isdigit ( ) ] if ids else None
items = _query_mistakes_for_export ( db , category , start_date , end_date , id_list )
doc = Document ( )
doc . add_heading ( " 公考助手 - 错题导出 " , level = 1 )
for idx , item in enumerate ( items , start = 1 ) :
doc . add_paragraph ( f " { idx } . { item . title } " )
doc . add_paragraph ( f " 分类: { item . category } | 难度: { item . difficulty or ' 未设置 ' } | 错误频次: { item . wrong_count } " )
doc . add_paragraph ( f " 备注: { item . note or ' 无 ' } " )
doc . add_paragraph ( " 答题区: ________________________________________ " )
blocks = _mistake_export_blocks ( item , content_mode )
for bi , block in enumerate ( blocks ) :
# 题号与题干同段, 避免单独一行只有「1.」
para = f " { idx } . { block } " if bi == 0 else block
doc . add_paragraph ( para )
buf = BytesIO ( )
doc . save ( buf )
@@ -362,8 +576,108 @@ def score_stats(db: Session = Depends(get_db)):
return ScoreStats ( highest = highest , lowest = lowest , average = round ( float ( avg ) , 2 ) , improvement = improvement )
def _qwen_base_url ( ) - > str :
return os . getenv ( " QWEN_BASE_URL " , " https://dashscope.aliyuncs.com/compatible-mode/v1 " ) . strip ( ) . rstrip ( " / " )
def _get_qwen_api_key ( ) - > str :
""" 去除首尾空白与常见误加的引号,避免 .env 里写成 ' sk-xxx ' 导致鉴权失败。 """
raw = os . getenv ( " QWEN_API_KEY " , " " ) or " "
return raw . strip ( ) . strip ( ' " ' ) . strip ( " ' " ) . strip ( )
def _raise_for_qwen_http_error ( resp : httpx . Response , prefix : str ) - > None :
""" HTTP 非 2xx 时解析 DashScope 错误体,对 invalid_api_key 返回 401 + 明确说明。 """
if resp . status_code < 300 :
return
text = resp . text
try :
data = resp . json ( )
err = data . get ( " error " )
if isinstance ( err , dict ) :
code = str ( err . get ( " code " ) or " " )
if code == " invalid_api_key " :
raise HTTPException (
status_code = 401 ,
detail = (
" 阿里云 DashScope API Key 无效或未生效。 "
" 请到阿里云百炼 / Model Studio 控制台创建 API Key( 通常以 sk- 开头), "
" 写入项目根目录 .env 的 QWEN_API_KEY=,勿加引号; "
" 修改后执行: docker compose up -d --build backend "
) ,
)
msg = err . get ( " message " ) or text
raise HTTPException ( status_code = 502 , detail = f " { prefix } : { msg } " )
except HTTPException :
raise
except ( ValueError , TypeError , KeyError ) :
pass
raise HTTPException ( status_code = 502 , detail = f " { prefix } : { text [ : 1200 ] } " )
def _httpx_trust_env ( ) - > bool :
""" 默认不信任环境变量中的代理,避免 Docker/IDE 注入空代理导致 ConnectError; 需走系统代理时设 HTTPX_TRUST_ENV=1。 """
return os . getenv ( " HTTPX_TRUST_ENV " , " 0 " ) . lower ( ) in ( " 1 " , " true " , " yes " )
def _qwen_http_client ( timeout_sec : float = 60.0 ) - > httpx . AsyncClient :
return httpx . AsyncClient (
timeout = httpx . Timeout ( timeout_sec , connect = 20.0 ) ,
trust_env = _httpx_trust_env ( ) ,
limits = httpx . Limits ( max_keepalive_connections = 5 , max_connections = 10 ) ,
)
def _message_content_to_str ( content : Any ) - > str :
""" OpenAI 兼容接口里 message.content 可能是 str 或多段结构。 """
if content is None :
return " "
if isinstance ( content , str ) :
return content
if isinstance ( content , list ) :
parts : list [ str ] = [ ]
for part in content :
if isinstance ( part , dict ) :
if part . get ( " type " ) == " text " and " text " in part :
parts . append ( str ( part [ " text " ] ) )
elif " text " in part :
parts . append ( str ( part [ " text " ] ) )
elif isinstance ( part , str ) :
parts . append ( part )
return " " . join ( parts )
return str ( content )
def _openai_completion_assistant_text ( data : dict ) - > str :
""" 从 chat/completions JSON 中取出助手文本;若含 error 或无 choices 则抛错。 """
err = data . get ( " error " )
if err is not None :
if isinstance ( err , dict ) :
code = str ( err . get ( " code " ) or " " )
if code == " invalid_api_key " :
raise HTTPException (
status_code = 401 ,
detail = (
" 阿里云 DashScope API Key 无效。 "
" 请在 .env 中填写正确的 QWEN_API_KEY 并重启 backend。 "
) ,
)
msg = err . get ( " message " ) or err . get ( " code " ) or json . dumps ( err , ensure_ascii = False )
else :
msg = str ( err )
raise HTTPException ( status_code = 502 , detail = f " 千问接口错误: { msg } " )
choices = data . get ( " choices " )
if not choices :
raise HTTPException (
status_code = 502 ,
detail = f " 千问返回异常(无 choices) , 请检查模型名与权限。原始片段: { json . dumps ( data , ensure_ascii = False ) [ : 800 ] } " ,
)
msg = choices [ 0 ] . get ( " message " ) or { }
return _message_content_to_str ( msg . get ( " content " ) )
async def _call_qwen ( system_prompt : str , user_prompt : str ) - > str :
api_key = os . getenv ( " QWEN_API_KEY " , " " )
api_key = _get_qwen_api_key ( )
if not api_key :
return (
" 当前未配置千问 API Key, 已返回本地降级提示。 \n "
@@ -373,7 +687,7 @@ async def _call_qwen(system_prompt: str, user_prompt: str) -> str:
" QWEN_BASE_URL=https://dashscope.aliyuncs.com/compatible-mode/v1 \n "
" QWEN_MODEL=qwen-plus "
)
base_url = os . getenv ( " QWEN_BASE_URL " , " https://dashscope.aliyuncs.com/compatible-mode/v1 " )
base_url = _qwen_base_url ( )
model = os . getenv ( " QWEN_MODEL " , " qwen-plus " )
headers = { " Authorization " : f " Bearer { api_key } " , " Content-Type " : " application/json " }
payload = {
@@ -384,12 +698,333 @@ async def _call_qwen(system_prompt: str, user_prompt: str) -> str:
] ,
" temperature " : 0.4 ,
}
async with httpx . AsyncClient ( timeout = 40 ) as client :
resp = await client . post ( f " { base_url } /chat/completions " , headers = headers , json = payload )
if resp . status_code > = 300 :
raise HTTPException ( status_code = 502 , detail = f " 千问请求失败: { resp . text } " )
data = resp . json ( )
return data [ " choices " ] [ 0 ] [ " message " ] [ " content " ]
url = f " { base_url } /chat/completions "
try :
async with _qwen_http_client ( 40.0 ) as client :
resp = await client . post ( url , headers = headers , json = payload )
except httpx . ConnectError as e :
raise HTTPException (
status_code = 502 ,
detail = (
f " 无法连接千问接口( { url } )。请检查本机/容器能否访问外网、DNS 是否正常; "
" 若在 Docker 中可尝试为 backend 配置 dns 或关闭错误代理。 "
" 默认已忽略 HTTP(S)_PROXY, 若需代理请设置 HTTPX_TRUST_ENV=1。 "
f " 原始错误: { e !s} "
) ,
) from e
except httpx . TimeoutException as e :
raise HTTPException ( status_code = 504 , detail = f " 千问请求超时: { e !s} " ) from e
_raise_for_qwen_http_error ( resp , " 千问请求失败 " )
try :
data = resp . json ( )
except ValueError :
raise HTTPException ( status_code = 502 , detail = f " 千问返回非 JSON: { resp . text [ : 600 ] } " )
return _openai_completion_assistant_text ( data )
async def _call_qwen_vision ( system_prompt : str , user_prompt : str , image_data_url : str ) - > str :
api_key = _get_qwen_api_key ( )
if not api_key :
return (
" 当前未配置千问 API Key, 无法执行 OCR。 \n "
" 请在 .env 中配置 QWEN_API_KEY 后重试。 "
)
base_url = _qwen_base_url ( )
model = os . getenv ( " QWEN_VL_MODEL " , " qwen-vl-plus " )
headers = { " Authorization " : f " Bearer { api_key } " , " Content-Type " : " application/json " }
# 与 DashScope 文档一致:先图后文,利于多模态路由
payload = {
" model " : model ,
" messages " : [
{ " role " : " system " , " content " : system_prompt } ,
{
" role " : " user " ,
" content " : [
{ " type " : " image_url " , " image_url " : { " url " : image_data_url } } ,
{ " type " : " text " , " text " : user_prompt } ,
] ,
} ,
] ,
" temperature " : 0.2 ,
}
url = f " { base_url } /chat/completions "
try :
async with _qwen_http_client ( 60.0 ) as client :
resp = await client . post ( url , headers = headers , json = payload )
except httpx . ConnectError as e :
raise HTTPException (
status_code = 502 ,
detail = (
f " 无法连接千问接口( OCR, { url } )。请检查网络与 DNS; "
" 默认已忽略 HTTP(S)_PROXY, 若需代理请设置 HTTPX_TRUST_ENV=1。 "
f " 原始错误: { e !s} "
) ,
) from e
except httpx . TimeoutException as e :
raise HTTPException ( status_code = 504 , detail = f " OCR 请求超时: { e !s} " ) from e
except httpx . RequestError as e :
raise HTTPException ( status_code = 502 , detail = f " OCR 网络请求失败: { e !s} " ) from e
_raise_for_qwen_http_error ( resp , " OCR 请求失败 " )
try :
data = resp . json ( )
except ValueError :
raise HTTPException ( status_code = 502 , detail = f " OCR 返回非 JSON: { resp . text [ : 600 ] } " )
return _openai_completion_assistant_text ( data )
@app.get ( " /api/data/export " )
def export_user_data (
format : str = Query ( " zip " , pattern = " ^(zip|json)$ " ) ,
include_files : bool = True ,
db : Session = Depends ( get_db ) ,
) :
payload = _dump_all_data ( db )
timestamp = datetime . utcnow ( ) . strftime ( " % Y % m %d _ % H % M % S " )
if format == " json " :
buf = BytesIO ( json . dumps ( payload , ensure_ascii = False , indent = 2 ) . encode ( " utf-8 " ) )
return StreamingResponse (
buf ,
media_type = " application/json " ,
headers = { " Content-Disposition " : f ' attachment; filename= " exam_helper_backup_ { timestamp } .json " ' } ,
)
used_upload_files : set [ str ] = set ( )
for item in payload [ " resources " ] :
name = _extract_upload_filename ( item . get ( " url " ) )
if name :
used_upload_files . add ( name )
for item in payload [ " mistakes " ] :
name = _extract_upload_filename ( item . get ( " image_url " ) )
if name :
used_upload_files . add ( name )
buf = BytesIO ( )
with zipfile . ZipFile ( buf , " w " , zipfile . ZIP_DEFLATED ) as zip_ref :
zip_ref . writestr ( " data.json " , json . dumps ( payload , ensure_ascii = False , indent = 2 ) )
if include_files :
for file_name in sorted ( used_upload_files ) :
path = UPLOAD_DIR / file_name
if path . exists ( ) and path . is_file ( ) :
zip_ref . write ( path , arcname = f " uploads/ { file_name } " )
buf . seek ( 0 )
return StreamingResponse (
buf ,
media_type = " application/zip " ,
headers = { " Content-Disposition " : f ' attachment; filename= " exam_helper_backup_ { timestamp } .zip " ' } ,
)
@app.post ( " /api/data/import " )
async def import_user_data (
file : UploadFile = File ( . . . ) ,
mode : str = Query ( " merge " , pattern = " ^(merge|replace)$ " ) ,
db : Session = Depends ( get_db ) ,
) :
content = await file . read ( )
if not content :
raise HTTPException ( status_code = 400 , detail = " 导入文件为空 " )
if len ( content ) > 100 * 1024 * 1024 :
raise HTTPException ( status_code = 400 , detail = " 导入文件不能超过 100MB " )
suffix = Path ( file . filename or " " ) . suffix . lower ( )
payload : dict
zip_ref : zipfile . ZipFile | None = None
if suffix == " .json " :
try :
payload = json . loads ( content . decode ( " utf-8 " ) )
except ( UnicodeDecodeError , json . JSONDecodeError ) as exc :
raise HTTPException ( status_code = 400 , detail = f " JSON 解析失败: { exc } " ) from exc
elif suffix == " .zip " :
try :
zip_ref = zipfile . ZipFile ( BytesIO ( content ) )
except zipfile . BadZipFile as exc :
raise HTTPException ( status_code = 400 , detail = " ZIP 文件损坏或格式错误 " ) from exc
if " data.json " not in zip_ref . namelist ( ) :
raise HTTPException ( status_code = 400 , detail = " ZIP 中缺少 data.json " )
try :
payload = json . loads ( zip_ref . read ( " data.json " ) . decode ( " utf-8 " ) )
except ( UnicodeDecodeError , json . JSONDecodeError ) as exc :
raise HTTPException ( status_code = 400 , detail = f " data.json 解析失败: { exc } " ) from exc
else :
raise HTTPException ( status_code = 400 , detail = " 仅支持 .json 或 .zip 导入 " )
resources = payload . get ( " resources " , [ ] )
mistakes = payload . get ( " mistakes " , [ ] )
scores = payload . get ( " scores " , [ ] )
if not isinstance ( resources , list ) or not isinstance ( mistakes , list ) or not isinstance ( scores , list ) :
raise HTTPException ( status_code = 400 , detail = " 导入文件结构错误 " )
if mode == " replace " :
for item in db . scalars ( select ( Resource ) ) . all ( ) :
db . delete ( item )
for item in db . scalars ( select ( Mistake ) ) . all ( ) :
db . delete ( item )
for item in db . scalars ( select ( ScoreRecord ) ) . all ( ) :
db . delete ( item )
db . commit ( )
imported = { " resources " : 0 , " mistakes " : 0 , " scores " : 0 }
for item in resources :
url = item . get ( " url " )
if zip_ref is not None :
url = _restore_upload_url_from_zip ( url , zip_ref )
obj = Resource (
title = item . get ( " title " ) or " 未命名资源 " ,
resource_type = item . get ( " resource_type " ) if item . get ( " resource_type " ) in { " link " , " file " } else " link " ,
url = url ,
file_name = item . get ( " file_name " ) ,
category = item . get ( " category " ) or " 未分类 " ,
tags = item . get ( " tags " ) ,
created_at = _safe_datetime ( item . get ( " created_at " ) ) ,
)
db . add ( obj )
imported [ " resources " ] + = 1
for item in mistakes :
image_url = item . get ( " image_url " )
if zip_ref is not None :
image_url = _restore_upload_url_from_zip ( image_url , zip_ref )
difficulty = item . get ( " difficulty " )
obj = Mistake (
title = item . get ( " title " ) or " 未命名错题 " ,
image_url = image_url ,
category = item . get ( " category " ) or " 其他 " ,
difficulty = difficulty if difficulty in { " easy " , " medium " , " hard " } else None ,
question_content = item . get ( " question_content " ) ,
answer = item . get ( " answer " ) ,
explanation = item . get ( " explanation " ) ,
note = item . get ( " note " ) ,
wrong_count = max ( int ( item . get ( " wrong_count " ) or 1 ) , 1 ) ,
created_at = _safe_datetime ( item . get ( " created_at " ) ) ,
)
db . add ( obj )
imported [ " mistakes " ] + = 1
for item in scores :
score = float ( item . get ( " total_score " ) or 0 )
obj = ScoreRecord (
exam_name = item . get ( " exam_name " ) or " 未命名考试 " ,
exam_date = _safe_date ( item . get ( " exam_date " ) ) ,
total_score = max ( min ( score , 200 ) , 0 ) ,
module_scores = item . get ( " module_scores " ) ,
created_at = _safe_datetime ( item . get ( " created_at " ) ) ,
)
db . add ( obj )
imported [ " scores " ] + = 1
db . commit ( )
return { " success " : True , " mode " : mode , " imported " : imported }
@app.post ( " /api/ocr/parse " , response_model = OcrParseOut )
async def parse_ocr ( payload : OcrParseIn ) :
file_name = _extract_upload_filename ( payload . image_url )
if not file_name :
raise HTTPException ( status_code = 400 , detail = " 仅支持 /uploads 下的图片做 OCR " )
target = UPLOAD_DIR / file_name
if not target . exists ( ) or not target . is_file ( ) :
raise HTTPException ( status_code = 404 , detail = " 图片不存在或已删除 " )
suffix = target . suffix . lower ( )
mime = {
" .jpg " : " image/jpeg " ,
" .jpeg " : " image/jpeg " ,
" .png " : " image/png " ,
" .webp " : " image/webp " ,
} . get ( suffix )
if not mime :
raise HTTPException ( status_code = 400 , detail = " 仅支持 JPG/PNG/WebP OCR " )
b64 = base64 . b64encode ( target . read_bytes ( ) ) . decode ( " utf-8 " )
image_data_url = f " data: { mime } ;base64, { b64 } "
ocr_prompt = (
" 请识别图片中的题目,返回严格 JSON。 "
" 字段说明: text 为整题完整纯文本(含材料、提问句、全部选项); "
" question_content 必须与 text 一致地表示「完整题干」,须包含阅读材料、填空/提问句、所有选项( A B C D 等), "
" 禁止只填写「依次填入…」等短提示句而省略材料和选项。 "
" 另含 title_suggestion、category_suggestion、difficulty_suggestion、answer、explanation。 "
" 无法确认的字段可填空字符串。 "
)
if payload . prompt :
ocr_prompt = f " { ocr_prompt } \n 补充要求: { payload . prompt } "
raw_text = await _call_qwen_vision (
" 你是公考题目OCR与结构化助手。输出必须是 JSON, 不要额外解释。 " ,
ocr_prompt ,
image_data_url ,
)
try :
parsed = json . loads ( _extract_json_text ( raw_text ) )
data = (
parsed
if isinstance ( parsed , dict )
else {
" text " : raw_text . strip ( ) ,
" title_suggestion " : None ,
" category_suggestion " : None ,
" difficulty_suggestion " : None ,
" question_content " : raw_text . strip ( ) ,
" answer " : " " ,
" explanation " : " " ,
}
)
except json . JSONDecodeError :
data = {
" text " : raw_text . strip ( ) ,
" title_suggestion " : None ,
" category_suggestion " : None ,
" difficulty_suggestion " : None ,
" question_content " : raw_text . strip ( ) ,
" answer " : " " ,
" explanation " : " " ,
}
def _opt_str ( val : Any ) - > str | None :
if val is None :
return None
if isinstance ( val , ( dict , list ) ) :
return None
s = str ( val ) . strip ( )
return s if s else None
def _merge_question_body ( text_raw : str , qc_raw : str | None ) - > str | None :
""" 模型常把全文放在 text, 却只把短问句放在 question_content; 合并时以更长、更完整的文本为准。 """
t = ( text_raw or " " ) . strip ( )
q = ( qc_raw or " " ) . strip ( )
if not t and not q :
return None
if not q :
return t or None
if not t :
return q or None
if len ( t ) > len ( q ) :
return t
if len ( q ) > len ( t ) :
return q
# 长度接近或相等:若一方包含另一方,取更长;否则保留 text( 整页 OCR 通常更全)
if t in q :
return q
if q in t :
return t
return t
text_out = str ( data . get ( " text " , " " ) or " " ) . strip ( )
qc_model = _opt_str ( data . get ( " question_content " ) )
question_merged = _merge_question_body ( text_out , qc_model )
return OcrParseOut (
text = text_out ,
title_suggestion = _opt_str ( data . get ( " title_suggestion " ) ) ,
category_suggestion = _opt_str ( data . get ( " category_suggestion " ) ) ,
difficulty_suggestion = _opt_str ( data . get ( " difficulty_suggestion " ) ) ,
question_content = _opt_str ( question_merged ) ,
answer = _opt_str ( data . get ( " answer " ) ) ,
explanation = _opt_str ( data . get ( " explanation " ) ) ,
)
@app.post ( " /api/ai/mistakes/ {item_id} /analyze " , response_model = AiMistakeAnalysisOut )
@@ -403,6 +1038,9 @@ async def ai_analyze_mistake(item_id: int, db: Session = Depends(get_db)):
f " 错题标题: { item . title } \n "
f " 分类: { item . category } \n "
f " 难度: { item . difficulty or ' 未设置 ' } \n "
f " 题目内容: { item . question_content or ' 无 ' } \n "
f " 答案: { item . answer or ' 无 ' } \n "
f " 解析: { item . explanation or ' 无 ' } \n "
f " 错误频次: { item . wrong_count } \n "
f " 备注: { item . note or ' 无 ' } \n \n "
" 请按以下结构输出: \n "