Files
AIcreat/app/services/ai_rewriter.py
Daniel babf24a0b0 fix
2026-04-01 18:49:09 +08:00

446 lines
18 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
from __future__ import annotations
import difflib
import json
import logging
import re
import time
from textwrap import shorten
from openai import OpenAI
from app.config import settings
from app.schemas import RewriteRequest, RewriteResponse
logger = logging.getLogger(__name__)
SYSTEM_PROMPT = """
你是顶级中文公众号主编,擅长把 X/Twitter 的观点型内容改写成高质量公众号文章。
你的目标不是“同义替换”,而是“重构表达”,保证可读性、逻辑性和可发布性。
硬性规则:
1) 保留核心事实与关键观点,不编造数据,不夸大结论;
2) 文章结构必须完整:导语 -> 核心观点 -> 深度分析 -> 落地建议 -> 结语;
3) 风格自然,避免 AI 套话(如“首先其次最后”“赋能”“闭环”等空话);
4) 每节都要有信息增量,不要重复原文句式;
5) 输出必须是合法 JSON字段title, summary, body_markdown。
""".strip()
REWRITE_SCHEMA_HINT = """
请输出 JSON
{
"title": "20字内中文标题明确价值点",
"summary": "80-120字中文摘要说明读者收获",
"body_markdown": "完整Markdown正文"
}
正文格式要求(必须遵循):
## 导语
2-3段交代背景、冲突与阅读价值。
## 核心观点
- 3~5条要点每条是完整信息句不要口号。
## 深度分析
### 1) 现象背后的原因
2-3段
### 2) 对行业/团队的影响
2-3段
### 3) 关键风险与边界
2-3段
## 落地建议
1. 三到五条可执行动作,尽量包含“谁在什么场景做什么”。
## 结语
1段收束观点并给出下一步建议。
""".strip()
class AIRewriter:
def __init__(self) -> None:
self._client = None
self._prefer_chat_first = False
if settings.openai_api_key:
base_url = settings.openai_base_url or ""
self._prefer_chat_first = "dashscope.aliyuncs.com" in base_url
self._client = OpenAI(
api_key=settings.openai_api_key,
base_url=settings.openai_base_url,
timeout=settings.openai_timeout,
max_retries=1,
)
def rewrite(self, req: RewriteRequest) -> RewriteResponse:
cleaned_source = self._clean_source(req.source_text)
started = time.monotonic()
# Primary: model rewrite + quality gate + optional second-pass polish.
if self._client:
# DashScope/Qwen works better with a single stable call.
if self._prefer_chat_first:
first_pass_timeout = max(18.0, min(30.0, settings.openai_timeout))
else:
first_pass_timeout = max(20.0, min(50.0, settings.openai_timeout))
draft = self._model_rewrite(req, cleaned_source, timeout_sec=first_pass_timeout)
if draft:
normalized = self._normalize_result(draft)
issues = self._quality_issues(req, cleaned_source, normalized)
elapsed = time.monotonic() - started
remaining_budget = max(0.0, (first_pass_timeout + 20.0) - elapsed)
if issues and (not self._prefer_chat_first) and remaining_budget >= 10.0:
polished = self._model_polish(
req,
cleaned_source,
normalized,
issues,
timeout_sec=min(30.0, remaining_budget),
)
if polished:
normalized = self._normalize_result(polished)
final_issues = self._quality_issues(req, cleaned_source, normalized)
if not final_issues:
return RewriteResponse(**normalized, mode="ai", quality_notes=[])
logger.warning("rewrite quality gate fallback triggered: %s", final_issues)
# Secondary: deterministic fallback with publishable structure.
return self._fallback_rewrite(req, cleaned_source, reason="模型超时或质量未达标,已使用结构化保底稿")
def _model_rewrite(self, req: RewriteRequest, cleaned_source: str, timeout_sec: float) -> dict | None:
user_prompt = self._build_user_prompt(req, cleaned_source)
return self._call_model_json(user_prompt, timeout_sec=timeout_sec)
def _model_polish(
self,
req: RewriteRequest,
cleaned_source: str,
normalized: dict,
issues: list[str],
timeout_sec: float,
) -> dict | None:
issue_text = "\n".join([f"- {i}" for i in issues])
user_prompt = f"""
你上一次的改写稿质量未达标,请基于下面问题做彻底重写,不要只改几个词:
{issue_text}
原始内容:
{cleaned_source}
上一次草稿:
标题:{normalized.get('title', '')}
摘要:{normalized.get('summary', '')}
正文:
{normalized.get('body_markdown', '')}
用户改写偏好:
- 标题参考:{req.title_hint or '自动生成'}
- 语气风格:{req.tone}
- 目标读者:{req.audience}
- 必须保留观点:{req.keep_points or ''}
- 避免词汇:{req.avoid_words or ''}
请输出一个全新且高质量版本。{REWRITE_SCHEMA_HINT}
""".strip()
return self._call_model_json(user_prompt, timeout_sec=timeout_sec)
def _build_user_prompt(self, req: RewriteRequest, cleaned_source: str) -> str:
return f"""
原始内容(已清洗):
{cleaned_source}
用户改写偏好:
- 标题参考:{req.title_hint or '自动生成'}
- 语气风格:{req.tone}
- 目标读者:{req.audience}
- 必须保留观点:{req.keep_points or ''}
- 避免词汇:{req.avoid_words or ''}
任务:请输出可直接用于公众号发布的文章。{REWRITE_SCHEMA_HINT}
""".strip()
def _fallback_rewrite(self, req: RewriteRequest, cleaned_source: str, reason: str) -> RewriteResponse:
sentences = self._extract_sentences(cleaned_source)
points = self._pick_key_points(sentences, limit=5)
title = req.title_hint.strip() or self._build_fallback_title(sentences)
summary = self._build_fallback_summary(points, cleaned_source)
intro = self._build_intro(points, cleaned_source)
analysis = self._build_analysis(points)
actions = self._build_actions(points)
conclusion = "如果你准备把这类内容持续做成栏目,建议建立固定模板:观点来源、关键证据、执行建议、复盘结论。"
body = (
"## 导语\n"
f"{intro}\n\n"
"## 核心观点\n"
+ "\n".join([f"- {p}" for p in points])
+ "\n\n"
"## 深度分析\n"
"### 1) 现象背后的原因\n"
f"{analysis['cause']}\n\n"
"### 2) 对行业/团队的影响\n"
f"{analysis['impact']}\n\n"
"### 3) 关键风险与边界\n"
f"{analysis['risk']}\n\n"
"## 落地建议\n"
+ "\n".join([f"{i + 1}. {a}" for i, a in enumerate(actions)])
+ "\n\n"
"## 结语\n"
f"{conclusion}"
)
normalized = {
"title": title,
"summary": summary,
"body_markdown": self._format_markdown(body),
}
return RewriteResponse(**normalized, mode="fallback", quality_notes=[reason])
def _build_fallback_title(self, sentences: list[str]) -> str:
seed = sentences[0] if sentences else "内容改写"
seed = shorten(seed, width=16, placeholder="")
return f"{seed}:给内容创作者的实战拆解"
def _build_fallback_summary(self, points: list[str], source: str) -> str:
if len(points) >= 2:
return f"本文提炼了{points[0]},并进一步分析{points[1]},最后给出可直接执行的发布建议,帮助你把观点内容做成高质量公众号文章。"
return shorten(re.sub(r"\s+", " ", source), width=110, placeholder="...")
def _build_intro(self, points: list[str], source: str) -> str:
focus = points[0] if points else shorten(source, width=42, placeholder="...")
return (
f"这篇内容的价值不在“信息多”,而在于它点出了一个真正值得关注的问题:{focus}\n\n"
"对公众号读者来说,最关心的是这件事会带来什么变化、现在能做什么。"
"因此本文不做逐句复述,而是按“观点-影响-动作”重组,方便直接落地。"
)
def _build_analysis(self, points: list[str]) -> dict[str, str]:
p1 = points[0] if points else "行业正在从信息堆叠转向结果导向"
p2 = points[1] if len(points) > 1 else "团队协作方式被自动化流程重塑"
p3 = points[2] if len(points) > 2 else "内容质量会成为真正分水岭"
return {
"cause": (
f"从表面看是工具迭代,实质是生产逻辑变化。{p1},意味着过去依赖经验的环节,正在被标准化流程替代。"
"谁先完成流程化改造,谁就更容易稳定产出。"
),
"impact": (
f"短期影响体现在效率,中长期影响体现在品牌认知。{p2}"
"如果只追求发布速度,内容会快速同质化;如果把洞察和表达打磨成体系,内容资产会持续增值。"
),
"risk": (
f"最大的风险不是‘不用 AI而是只用 AI{p3}"
"没有事实校对与人工观点把关,文章容易出现空泛表达、错误引用和结论过度。"
),
}
def _build_actions(self, points: list[str]) -> list[str]:
anchor = points[0] if points else "核心观点"
return [
f"先确定本篇唯一主线:围绕“{anchor}”展开,删除与主线无关的段落。",
"按“导语-观点-分析-建议-结语”五段式重排正文,每段只解决一个问题。",
"为每个核心观点补一条可验证依据(数据、案例或公开来源),提升可信度。",
"发布前做一次反 AI 味检查:删掉空话,替换为具体动作和明确对象。",
"将高表现文章沉淀为模板,下次复用同样结构提高稳定性。",
]
def _clean_source(self, text: str) -> str:
src = (text or "").replace("\r\n", "\n").strip()
src = re.sub(r"https?://\S+", "", src)
src = re.sub(r"(?m)^\s*>+\s*", "", src)
src = re.sub(r"(?m)^\s*[@#][^\s]+\s*$", "", src)
src = re.sub(r"\n{3,}", "\n\n", src)
src = re.sub(r"\s+", " ", src)
src = src.strip()
max_chars = max(1200, settings.openai_source_max_chars)
if len(src) > max_chars:
src = src[:max_chars] + " ...(原文过长,已截断后改写)"
return src
def _extract_sentences(self, text: str) -> list[str]:
parts = re.split(r"[。!?;;.!?\n]+", text)
cleaned = [p.strip(" ,;::。") for p in parts if p.strip()]
return cleaned
def _pick_key_points(self, sentences: list[str], limit: int) -> list[str]:
points: list[str] = []
templates = [
"核心变化:{}",
"关键问题:{}",
"方法调整:{}",
"结果反馈:{}",
"结论启示:{}",
]
for s in sentences:
if len(s) < 12:
continue
if len(points) >= limit:
break
normalized = re.sub(r"^(第一|第二|第三|第四|第五)[,:]?", "", s).strip()
normalized = re.sub(r"^[-•\\d\\.\\)\\s]+", "", normalized)
text = shorten(normalized, width=50, placeholder="...")
points.append(templates[len(points) % len(templates)].format(text))
if not points:
points = ["原始内容信息密度较高,建议先聚焦一个核心问题再展开"]
return points
def _parse_response_json(self, text: str) -> dict:
raw = (text or "").strip()
if not raw:
raise ValueError("empty model output")
try:
return json.loads(raw)
except json.JSONDecodeError:
pass
fenced = re.sub(r"^```(?:json)?\s*|\s*```$", "", raw, flags=re.IGNORECASE).strip()
if fenced != raw:
try:
return json.loads(fenced)
except json.JSONDecodeError:
pass
start = raw.find("{")
end = raw.rfind("}")
if start != -1 and end != -1 and end > start:
return json.loads(raw[start : end + 1])
raise ValueError("model output is not valid JSON")
def _call_model_json(self, user_prompt: str, timeout_sec: float) -> dict | None:
logger.info(
"AI request start model=%s timeout=%.1fs prefer_chat_first=%s prompt_chars=%d",
settings.openai_model,
timeout_sec,
self._prefer_chat_first,
len(user_prompt),
)
methods = ["chat", "responses"] if self._prefer_chat_first else ["responses", "chat"]
for method in methods:
if method == "responses":
try:
completion = self._client.responses.create(
model=settings.openai_model,
input=[
{"role": "system", "content": SYSTEM_PROMPT},
{"role": "user", "content": user_prompt},
],
text={"format": {"type": "json_object"}},
timeout=timeout_sec,
)
output_text = completion.output_text or ""
logger.info("AI raw output (responses): %s", output_text)
return self._parse_response_json(output_text)
except Exception as exc:
logger.warning("responses API failed: %s", exc)
continue
if method == "chat":
try:
completion = self._client.chat.completions.create(
model=settings.openai_model,
messages=[
{"role": "system", "content": SYSTEM_PROMPT},
{"role": "user", "content": user_prompt},
],
response_format={"type": "json_object"},
max_tokens=1800,
temperature=0.4,
extra_body={"enable_thinking": False},
timeout=timeout_sec,
)
msg = completion.choices[0].message.content if completion.choices else ""
logger.info("AI raw output (chat.completions): %s", msg or "")
return self._parse_response_json(msg or "")
except Exception as exc:
logger.warning("chat.completions API failed: %s", exc)
# DashScope compatibility path: don't spend extra time on responses fallback.
if self._prefer_chat_first:
return None
continue
return None
def _normalize_result(self, data: dict) -> dict:
title = str(data.get("title", "")).strip()
summary = str(data.get("summary", "")).strip()
body = str(data.get("body_markdown", "")).strip()
if not title:
title = "公众号改写稿"
if not summary:
summary = shorten(re.sub(r"\s+", " ", body), width=110, placeholder="...")
body = self._ensure_sections(body)
body = self._format_markdown(body)
return {"title": title, "summary": summary, "body_markdown": body}
def _ensure_sections(self, body: str) -> str:
text = (body or "").strip()
required = ["## 导语", "## 核心观点", "## 深度分析", "## 落地建议", "## 结语"]
missing = [h for h in required if h not in text]
if not text:
text = "## 导语\n\n内容生成失败,请重试。\n"
if missing:
# Light touch: append missing sections to keep publish structure stable.
pads = "\n\n".join([f"{h}\n\n(待补充)" for h in missing])
text = f"{text}\n\n{pads}"
return text
def _quality_issues(self, req: RewriteRequest, source: str, normalized: dict) -> list[str]:
issues: list[str] = []
title = normalized.get("title", "")
summary = normalized.get("summary", "")
body = normalized.get("body_markdown", "")
if len(title) < 8 or len(title) > 34:
issues.append("标题长度不理想(建议 8-34 字)")
if len(summary) < 60:
issues.append("摘要过短,信息量不足")
headings = re.findall(r"(?m)^##\s+.+$", body)
if len(headings) < 5:
issues.append("二级标题不足,结构不完整")
paragraphs = [p.strip() for p in body.split("\n\n") if p.strip()]
if len(paragraphs) < 10:
issues.append("正文段落偏少,展开不充分")
if len(body) < 900:
issues.append("正文过短,无法支撑公众号发布")
if self._looks_like_raw_copy(source, body):
issues.append("改写与原文相似度过高,疑似未充分重写")
if req.avoid_words:
bad_words = [w.strip() for w in re.split(r"[,]\s*", req.avoid_words) if w.strip()]
hit = [w for w in bad_words if w in body or w in summary or w in title]
if hit:
issues.append(f"命中禁用词: {', '.join(hit)}")
ai_phrases = ["首先", "其次", "最后", "总而言之", "赋能", "闭环", "颠覆"]
hit_ai = [w for w in ai_phrases if body.count(w) >= 3]
if hit_ai:
issues.append("存在明显 AI 套话堆叠")
return issues
def _looks_like_raw_copy(self, source: str, rewritten: str) -> bool:
src = re.sub(r"\s+", "", source or "")
dst = re.sub(r"\s+", "", rewritten or "")
if not src or not dst:
return True
if dst in src or src in dst:
return True
ratio = difflib.SequenceMatcher(a=src[:3500], b=dst[:3500]).ratio()
return ratio >= 0.80
def _format_markdown(self, text: str) -> str:
body = text.replace("\r\n", "\n").strip()
body = re.sub(r"\n{3,}", "\n\n", body)
body = re.sub(r"(?m)^(#{1,3}\s[^\n]+)\n(?!\n)", r"\1\n\n", body)
return body.strip() + "\n"