This commit is contained in:
Daniel
2026-04-28 11:50:55 +08:00
parent 1bbabc2a78
commit 2724e69b4f
20 changed files with 3881 additions and 554 deletions

View File

@@ -0,0 +1,473 @@
from __future__ import annotations
import asyncio
import base64
import logging
import re
import textwrap
from io import BytesIO
from pathlib import Path
import httpx
from openai import OpenAI
from PIL import Image, ImageDraw, ImageFont
from app.config import settings
from app.schemas import (
CoverGenerateResponse,
PosterGenerateRequest,
PosterGenerateResponse,
PosterPreviewItem,
WechatCoverGenerateRequest,
)
from app.services.wechat import WechatPublisher
logger = logging.getLogger(__name__)
_FONT_CANDIDATES = [
"/System/Library/Fonts/PingFang.ttc",
"/System/Library/Fonts/Hiragino Sans GB.ttc",
"/Library/Fonts/Arial Unicode.ttf",
]
def _split_paragraphs(body_markdown: str) -> list[str]:
raw = (body_markdown or "").replace("\r\n", "\n").strip()
if not raw:
return []
return [p.strip() for p in re.split(r"\n\s*\n+", raw) if p.strip()]
def _pick_font(size: int) -> ImageFont.ImageFont:
for p in _FONT_CANDIDATES:
if Path(p).is_file():
try:
return ImageFont.truetype(p, size=size)
except Exception:
continue
return ImageFont.load_default()
def _to_jpeg_under_limit(content: bytes, max_bytes: int) -> bytes:
im = Image.open(BytesIO(content)).convert("RGB")
widths = [1080, 1024, 960, 900, 840, 780, 720, 660]
qualities = [88, 82, 76, 70, 64, 58, 52]
for w in widths:
if im.width > w:
h = max(1, int(im.height * (w / im.width)))
cur = im.resize((w, h), Image.Resampling.LANCZOS)
else:
cur = im
for q in qualities:
buf = BytesIO()
cur.save(buf, format="JPEG", quality=q, optimize=True)
out = buf.getvalue()
if len(out) <= max_bytes:
return out
buf = BytesIO()
h = max(1, int(im.height * (640 / im.width)))
im.resize((640, h), Image.Resampling.LANCZOS).save(buf, format="JPEG", quality=48, optimize=True)
return buf.getvalue()
def _cover_to_jpeg(
content: bytes,
max_bytes: int,
size: tuple[int, int] = (900, 383),
title: str = "",
summary: str = "",
overlay_title: bool = False,
) -> bytes:
im = Image.open(BytesIO(content)).convert("RGB")
target_w, target_h = size
src_ratio = im.width / max(1, im.height)
dst_ratio = target_w / target_h
if src_ratio > dst_ratio:
new_w = int(im.height * dst_ratio)
x0 = max(0, (im.width - new_w) // 2)
im = im.crop((x0, 0, x0 + new_w, im.height))
elif src_ratio < dst_ratio:
new_h = int(im.width / dst_ratio)
y0 = max(0, (im.height - new_h) // 2)
im = im.crop((0, y0, im.width, y0 + new_h))
im = im.resize(size, Image.Resampling.LANCZOS)
if overlay_title:
im = _draw_cover_text_overlay(im, title, summary)
for q in [92, 88, 84, 80, 76, 72, 68, 62]:
buf = BytesIO()
im.save(buf, format="JPEG", quality=q, optimize=True, progressive=True)
out = buf.getvalue()
if len(out) <= max_bytes:
return out
buf = BytesIO()
im.save(buf, format="JPEG", quality=58, optimize=True)
return buf.getvalue()
def _draw_cover_text_overlay(im: Image.Image, title: str, summary: str) -> Image.Image:
im = im.convert("RGBA")
overlay = Image.new("RGBA", im.size, (0, 0, 0, 0))
draw = ImageDraw.Draw(overlay)
draw.rounded_rectangle([42, 46, 572, 337], radius=30, fill=(255, 255, 255, 228))
draw.rounded_rectangle([70, 74, 222, 119], radius=18, fill=(31, 111, 91, 245))
tag_font = _pick_font(26)
title_font = _pick_font(46)
summary_font = _pick_font(24)
draw.text((94, 83), "公众号封面", font=tag_font, fill=(255, 255, 255, 255))
clean_title = re.sub(r"\s+", "", title or "公众号文章")
title_lines = textwrap.wrap(clean_title, width=12)[:2]
y = 146
for line in title_lines:
draw.text((72, y), line, font=title_font, fill=(23, 32, 51, 255))
y += 56
clean_summary = re.sub(r"\s+", "", summary or "一眼看懂主题,明确文章价值。")
summary_lines = textwrap.wrap(clean_summary, width=23)[:2]
sy = max(y + 10, 270)
for line in summary_lines:
draw.text((74, sy), line, font=summary_font, fill=(104, 115, 133, 255))
sy += 30
return Image.alpha_composite(im, overlay).convert("RGB")
class PosterMaterialService:
def __init__(self, wechat: WechatPublisher) -> None:
self._wechat = wechat
self._image_client = None
if settings.openai_api_key:
self._image_client = OpenAI(
api_key=settings.openai_api_key,
base_url=settings.openai_base_url,
timeout=settings.openai_timeout,
max_retries=max(0, int(settings.openai_max_retries)),
)
async def generate(
self,
req: PosterGenerateRequest,
request_id: str = "",
account: dict | None = None,
) -> PosterGenerateResponse:
rid = request_id or "-"
paragraphs = _split_paragraphs(req.body_markdown)
if len(paragraphs) <= 1:
return PosterGenerateResponse(
ok=True,
detail="正文不足两段:按规则首段不生成图片,因此无需海报。",
posters=[],
body_markdown_with_posters=req.body_markdown,
warnings=[],
)
max_images = max(1, min(int(req.max_images or settings.poster_max_images), 12))
posters: list[PosterPreviewItem] = []
warnings: list[str] = []
wechat_urls_by_para: dict[int, str] = {}
for idx, paragraph in enumerate(paragraphs):
if idx == 0:
continue
if len(posters) >= max_images:
break
prompt = self._build_prompt(req, paragraph, idx, len(paragraphs))
jpeg_bytes, note = await asyncio.to_thread(self._create_poster_jpeg, prompt, paragraph, idx)
preview_data_url = "data:image/jpeg;base64," + base64.b64encode(jpeg_bytes).decode("ascii")
wechat_url = ""
uploaded = False
if req.upload_to_wechat:
if not account:
warnings.append("未绑定公众号:已生成海报预览,但未上传微信素材 URL。")
else:
filename = f"poster_p{idx + 1}.jpg"
out = await self._wechat.upload_article_image(
filename,
jpeg_bytes,
request_id=rid,
account=account,
)
if out.ok:
wechat_url = ((out.data or {}).get("url") or "").strip()
uploaded = bool(wechat_url)
if uploaded:
wechat_urls_by_para[idx] = wechat_url
else:
warnings.append(f"{idx + 1} 段海报上传失败:{out.detail}")
posters.append(
PosterPreviewItem(
paragraph_index=idx,
paragraph_excerpt=textwrap.shorten(paragraph.replace("\n", " "), width=80, placeholder=""),
prompt=prompt,
preview_data_url=preview_data_url,
wechat_url=wechat_url,
uploaded=uploaded,
note=note,
)
)
merged = req.body_markdown
if wechat_urls_by_para:
merged = self._merge_body_with_posters(paragraphs, wechat_urls_by_para)
detail = f"已生成 {len(posters)} 张段落海报(首段跳过)"
if req.upload_to_wechat:
detail += f",成功上传 {sum(1 for p in posters if p.uploaded)}"
logger.info(
"poster_generate rid=%s posters=%d upload_to_wechat=%s uploaded=%d warnings=%d",
rid,
len(posters),
req.upload_to_wechat,
sum(1 for p in posters if p.uploaded),
len(warnings),
)
return PosterGenerateResponse(
ok=True,
detail=detail,
posters=posters,
body_markdown_with_posters=merged,
warnings=warnings,
)
async def generate_cover(
self,
req: WechatCoverGenerateRequest,
request_id: str = "",
account: dict | None = None,
) -> CoverGenerateResponse:
rid = request_id or "-"
title = (req.title or "").strip()
summary = (req.summary or "").strip()
if not title:
return CoverGenerateResponse(ok=False, detail="请先填写标题,或先完成改写生成标题")
prompt = self._build_cover_prompt(req)
jpeg_bytes, note = await asyncio.to_thread(self._create_cover_jpeg, prompt, title, summary)
preview_data_url = "data:image/jpeg;base64," + base64.b64encode(jpeg_bytes).decode("ascii")
warnings: list[str] = []
thumb_media_id = ""
if req.upload_to_wechat:
if not account:
warnings.append("未绑定公众号:已生成封面预览,但未上传为微信封面素材。")
else:
out = await self._wechat.upload_cover("wechat_cover_900x383.jpg", jpeg_bytes, request_id=rid, account=account)
if out.ok:
thumb_media_id = ((out.data or {}).get("thumb_media_id") or "").strip()
else:
warnings.append(f"封面上传失败:{out.detail}")
detail = "已生成公众号封面900×383"
if thumb_media_id:
detail += ",并已绑定 thumb_media_id"
elif warnings:
detail += ",但未完成微信绑定"
logger.info(
"cover_generate rid=%s title_chars=%d upload_to_wechat=%s uploaded=%s note=%s warnings=%d",
rid,
len(title),
req.upload_to_wechat,
bool(thumb_media_id),
note,
len(warnings),
)
return CoverGenerateResponse(
ok=True,
detail=detail,
preview_data_url=preview_data_url,
thumb_media_id=thumb_media_id,
width=900,
height=383,
note=note,
warnings=warnings,
)
def _build_cover_prompt(self, req: WechatCoverGenerateRequest) -> str:
title = (req.title or "公众号文章").strip()
summary = (req.summary or "").strip()
style_hint = (req.style_hint or "").strip() or "成熟公众号封面,清晰、克制、信息强,适合作为文章列表首图"
return (
"生成一张微信公众号文章封面图,最终会裁切为 900x383 横版比例。"
f"封面主标题:{title}"
f"文章摘要:{summary}"
f"风格要求:{style_hint}"
"画面要突出封面的点击引导作用:主题明确、视觉焦点强、留出标题安全区、中文字少且清晰。"
"不要出现二维码、水印、品牌 logo、真人肖像、杂乱小字和侵权素材。"
)
def _build_prompt(self, req: PosterGenerateRequest, paragraph: str, idx: int, total: int) -> str:
title = (req.title or "公众号内容").strip()
summary = (req.summary or "").strip()
style_hint = (req.style_hint or "").strip() or "现代、干净、中文可读、公众号海报风格"
para = paragraph.strip()
return (
"请生成一张中文竖版海报,适合公众号正文插图。"
f"主题标题:{title}"
f"这是第 {idx + 1}/{total} 段对应海报(首段不配图)。"
f"段落核心内容:{para}"
f"摘要参考:{summary}"
f"风格要求:{style_hint}"
"画面需信息聚焦、可读性强不要出现水印、二维码、logo、真人肖像。"
)
def _create_poster_jpeg(self, prompt: str, paragraph: str, idx: int) -> tuple[bytes, str]:
max_bytes = max(300_000, int(settings.poster_upload_max_bytes or 950_000))
if self._image_client:
try:
raw = self._generate_with_model(prompt)
if raw:
return _to_jpeg_under_limit(raw, max_bytes), "ai"
except Exception as exc:
logger.warning("poster_ai_failed detail=%s", str(exc)[:240])
fallback = self._generate_fallback_poster(paragraph, idx)
return _to_jpeg_under_limit(fallback, max_bytes), "fallback"
def _create_cover_jpeg(self, prompt: str, title: str, summary: str) -> tuple[bytes, str]:
max_bytes = max(300_000, int(settings.poster_upload_max_bytes or 950_000))
if self._image_client:
try:
raw = self._generate_with_model(prompt)
if raw:
return _cover_to_jpeg(raw, max_bytes, title=title, summary=summary, overlay_title=True), "ai_900x383"
except Exception as exc:
logger.warning("cover_ai_failed detail=%s", str(exc)[:240])
fallback = self._generate_fallback_cover(title, summary)
return _cover_to_jpeg(fallback, max_bytes), "fallback_900x383"
def _generate_with_model(self, prompt: str) -> bytes | None:
rsp = self._image_client.images.generate(
model=settings.openai_image_model,
prompt=prompt,
size=settings.poster_image_size,
)
data = getattr(rsp, "data", None) or []
if not data:
return None
first = data[0]
b64 = ""
image_url = ""
if isinstance(first, dict):
b64 = (first.get("b64_json") or "").strip()
image_url = (first.get("url") or "").strip()
else:
b64 = (getattr(first, "b64_json", "") or "").strip()
image_url = (getattr(first, "url", "") or "").strip()
if b64:
return base64.b64decode(b64)
if image_url:
with httpx.Client(timeout=30) as client:
r = client.get(image_url)
r.raise_for_status()
return r.content
return None
def _generate_fallback_poster(self, paragraph: str, idx: int) -> bytes:
w, h = 1080, 1520
im = Image.new("RGB", (w, h), (240, 246, 255))
draw = ImageDraw.Draw(im)
for y in range(h):
c = int(240 - (y / h) * 36)
draw.line([(0, y), (w, y)], fill=(c, c + 6, 255), width=1)
for i in range(8):
x0 = int(w * 0.08) + i * 54
y0 = int(h * 0.66) + i * 22
x1 = x0 + 260
y1 = y0 + 100
color = (160 - i * 8, 190 - i * 9, 230 - i * 8)
draw.rounded_rectangle([x0, y0, x1, y1], radius=24, outline=color, width=2)
tag_font = _pick_font(36)
title_font = _pick_font(58)
body_font = _pick_font(42)
draw.rounded_rectangle([70, 70, 340, 142], radius=20, fill=(31, 77, 185))
draw.text((102, 90), f"段落 {idx + 1}", font=tag_font, fill=(255, 255, 255))
draw.text((70, 190), "AI 图文海报", font=title_font, fill=(16, 42, 102))
words = re.sub(r"\s+", "", paragraph)
if len(words) > 120:
words = words[:120] + ""
wrapped = textwrap.fill(words, width=19)
draw.multiline_text(
(72, 330),
wrapped,
font=body_font,
fill=(35, 54, 92),
spacing=14,
align="left",
)
buf = BytesIO()
im.save(buf, format="PNG")
return buf.getvalue()
def _generate_fallback_cover(self, title: str, summary: str) -> bytes:
w, h = 900, 383
im = Image.new("RGB", (w, h), (247, 249, 252))
draw = ImageDraw.Draw(im)
for y in range(h):
t = y / h
r = int(252 - t * 28)
g = int(250 - t * 18)
b = int(241 - t * 8)
draw.line([(0, y), (w, y)], fill=(r, g, b), width=1)
draw.rounded_rectangle([36, 34, 864, 349], radius=34, fill=(255, 255, 255), outline=(223, 229, 238), width=2)
draw.rounded_rectangle([604, 60, 830, 290], radius=34, fill=(238, 244, 241))
draw.ellipse([660, 95, 810, 245], fill=(229, 196, 122))
draw.ellipse([690, 126, 780, 216], fill=(255, 250, 229))
draw.arc([684, 118, 784, 226], start=20, end=168, fill=(199, 159, 81), width=5)
tag_font = _pick_font(28)
title_font = _pick_font(48)
summary_font = _pick_font(24)
small_font = _pick_font(20)
draw.rounded_rectangle([72, 70, 226, 118], radius=18, fill=(31, 111, 91))
draw.text((96, 80), "公众号封面", font=tag_font, fill=(255, 255, 255))
clean_title = re.sub(r"\s+", "", title or "公众号文章")
title_lines = textwrap.wrap(clean_title, width=12)[:2]
y = 146
for line in title_lines:
draw.text((72, y), line, font=title_font, fill=(23, 32, 51))
y += 58
clean_summary = re.sub(r"\s+", "", summary or "清晰表达主题,让读者一眼知道文章价值。")
summary_lines = textwrap.wrap(clean_summary, width=24)[:2]
sy = max(y + 12, 268)
for line in summary_lines:
draw.text((74, sy), line, font=summary_font, fill=(104, 115, 133))
sy += 32
draw.text((72, 320), "900 x 383", font=small_font, fill=(140, 150, 166))
buf = BytesIO()
im.save(buf, format="PNG")
return buf.getvalue()
def _merge_body_with_posters(self, paragraphs: list[str], wechat_urls_by_para: dict[int, str]) -> str:
merged: list[str] = []
for idx, para in enumerate(paragraphs):
if idx > 0:
url = (wechat_urls_by_para.get(idx) or "").strip()
if url:
merged.append(f"![段落配图 {idx + 1}]({url})")
merged.append(para)
return "\n\n".join(merged)

View File

@@ -66,6 +66,27 @@ class UserStore:
)
"""
)
pref_cols = self._table_columns(c, "user_prefs")
if "active_ai_model_id" not in pref_cols:
c.execute("ALTER TABLE user_prefs ADD COLUMN active_ai_model_id INTEGER")
c.execute(
"""
CREATE TABLE IF NOT EXISTS ai_models (
id INTEGER PRIMARY KEY AUTOINCREMENT,
user_id INTEGER NOT NULL,
model_name TEXT NOT NULL,
api_key TEXT NOT NULL,
base_url TEXT NOT NULL DEFAULT '',
model TEXT NOT NULL,
timeout_sec REAL NOT NULL DEFAULT 120.0,
max_output_tokens INTEGER NOT NULL DEFAULT 8192,
max_retries INTEGER NOT NULL DEFAULT 0,
updated_at INTEGER NOT NULL,
UNIQUE(user_id, model_name),
FOREIGN KEY(user_id) REFERENCES users(id)
)
"""
)
# 兼容历史单绑定结构,自动迁移为默认账号
rows = c.execute(
"SELECT user_id, appid, secret, author, thumb_media_id, thumb_image_path, updated_at FROM wechat_bindings"
@@ -111,7 +132,16 @@ class UserStore:
return {str(r["name"]) for r in rows}
def _ensure_users_table(self, c: sqlite3.Connection) -> None:
required = {"id", "username", "password_hash", "password_salt", "created_at"}
required = {
"id",
"username",
"password_hash",
"password_salt",
"reset_code_hash",
"reset_code_salt",
"created_at",
"deleted_at",
}
c.execute(
"""
CREATE TABLE IF NOT EXISTS users (
@@ -119,7 +149,10 @@ class UserStore:
username TEXT NOT NULL UNIQUE,
password_hash TEXT NOT NULL,
password_salt TEXT NOT NULL,
created_at INTEGER NOT NULL
reset_code_hash TEXT NOT NULL DEFAULT '',
reset_code_salt TEXT NOT NULL DEFAULT '',
created_at INTEGER NOT NULL,
deleted_at INTEGER
)
"""
)
@@ -137,29 +170,44 @@ class UserStore:
username TEXT NOT NULL UNIQUE,
password_hash TEXT NOT NULL,
password_salt TEXT NOT NULL,
created_at INTEGER NOT NULL
reset_code_hash TEXT NOT NULL,
reset_code_salt TEXT NOT NULL,
created_at INTEGER NOT NULL,
deleted_at INTEGER
)
"""
)
if {"username", "password_hash", "password_salt"}.issubset(cols):
if "created_at" in cols:
rows = c.execute("SELECT * FROM users").fetchall()
for r in rows:
username = (r["username"] or "").strip()
if not username:
continue
reset_salt = r["reset_code_salt"] if "reset_code_salt" in cols else secrets.token_hex(8)
reset_hash = r["reset_code_hash"] if "reset_code_hash" in cols else ""
if not reset_hash:
legacy_code = self._generate_reset_code()
reset_hash = self._hash_reset_code(legacy_code, reset_salt)
created_at = int(r["created_at"]) if "created_at" in cols and r["created_at"] else now
deleted_at = int(r["deleted_at"]) if "deleted_at" in cols and r["deleted_at"] else None
c.execute(
"""
INSERT OR IGNORE INTO users_new(id, username, password_hash, password_salt, created_at)
SELECT id, username, password_hash, password_salt, COALESCE(created_at, ?)
FROM users
INSERT OR IGNORE INTO users_new(
id, username, password_hash, password_salt, reset_code_hash, reset_code_salt, created_at, deleted_at
)
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
""",
(now,),
)
else:
c.execute(
"""
INSERT OR IGNORE INTO users_new(id, username, password_hash, password_salt, created_at)
SELECT id, username, password_hash, password_salt, ?
FROM users
""",
(now,),
(
int(r["id"]),
username,
r["password_hash"],
r["password_salt"],
reset_hash,
reset_salt,
created_at,
deleted_at,
),
)
elif {"username", "password"}.issubset(cols):
if "created_at" in cols:
@@ -173,13 +221,18 @@ class UserStore:
continue
salt = secrets.token_hex(16)
pwd_hash = self._hash_password(raw_pwd, salt)
reset_code = self._generate_reset_code()
reset_salt = secrets.token_hex(8)
reset_hash = self._hash_reset_code(reset_code, reset_salt)
created_at = int(r["created_at"]) if "created_at" in r.keys() and r["created_at"] else now
c.execute(
"""
INSERT OR IGNORE INTO users_new(id, username, password_hash, password_salt, created_at)
VALUES (?, ?, ?, ?, ?)
INSERT OR IGNORE INTO users_new(
id, username, password_hash, password_salt, reset_code_hash, reset_code_salt, created_at, deleted_at
)
VALUES (?, ?, ?, ?, ?, ?, ?, NULL)
""",
(int(r["id"]), username, pwd_hash, salt, created_at),
(int(r["id"]), username, pwd_hash, salt, reset_hash, reset_salt, created_at),
)
c.execute("DROP TABLE users")
@@ -222,18 +275,31 @@ class UserStore:
def _hash_token(self, token: str) -> str:
return hashlib.sha256(token.encode("utf-8")).hexdigest()
def _generate_reset_code(self) -> str:
return secrets.token_urlsafe(9).replace("-", "").replace("_", "")[:12]
def _hash_reset_code(self, reset_code: str, salt: str) -> str:
return hashlib.sha256(f"{salt}:{reset_code}".encode("utf-8")).hexdigest()
def create_user(self, username: str, password: str) -> dict | None:
now = int(time.time())
salt = secrets.token_hex(16)
pwd_hash = self._hash_password(password, salt)
reset_code = self._generate_reset_code()
reset_salt = secrets.token_hex(8)
reset_hash = self._hash_reset_code(reset_code, reset_salt)
try:
with self._conn() as c:
cur = c.execute(
"INSERT INTO users(username, password_hash, password_salt, created_at) VALUES (?, ?, ?, ?)",
(username, pwd_hash, salt, now),
"""
INSERT INTO users(
username, password_hash, password_salt, reset_code_hash, reset_code_salt, created_at, deleted_at
) VALUES (?, ?, ?, ?, ?, ?, NULL)
""",
(username, pwd_hash, salt, reset_hash, reset_salt, now),
)
uid = int(cur.lastrowid)
return {"id": uid, "username": username}
return {"id": uid, "username": username, "reset_code": reset_code}
except sqlite3.IntegrityError:
return None
except sqlite3.Error as exc:
@@ -243,7 +309,11 @@ class UserStore:
try:
with self._conn() as c:
row = c.execute(
"SELECT id, username, password_hash, password_salt FROM users WHERE username=?",
"""
SELECT id, username, password_hash, password_salt
FROM users
WHERE username=? AND deleted_at IS NULL
""",
(username,),
).fetchone()
except sqlite3.Error as exc:
@@ -274,14 +344,27 @@ class UserStore:
)
return True
def reset_password_by_username(self, username: str, new_password: str) -> bool:
def reset_password_by_username(self, username: str, reset_code: str, new_password: str) -> bool:
uname = (username or "").strip()
rcode = (reset_code or "").strip()
if not uname:
return False
with self._conn() as c:
row = c.execute("SELECT id FROM users WHERE username=?", (uname,)).fetchone()
row = c.execute(
"""
SELECT id, reset_code_hash, reset_code_salt
FROM users
WHERE username=? AND deleted_at IS NULL
""",
(uname,),
).fetchone()
if not row:
return False
if not rcode:
return False
calc = self._hash_reset_code(rcode, row["reset_code_salt"] or "")
if not hmac.compare_digest(calc, row["reset_code_hash"] or ""):
return False
new_salt = secrets.token_hex(16)
new_hash = self._hash_password(new_password, new_salt)
c.execute(
@@ -324,7 +407,7 @@ class UserStore:
SELECT u.id, u.username
FROM sessions s
JOIN users u ON u.id=s.user_id
WHERE s.token_hash=? AND s.expires_at>=?
WHERE s.token_hash=? AND s.expires_at>=? AND u.deleted_at IS NULL
""",
(th, now),
).fetchone()
@@ -332,6 +415,37 @@ class UserStore:
return None
return {"id": int(row["id"]), "username": row["username"]}
def delete_user_logically(self, user_id: int, password: str, reset_code: str) -> bool:
now = int(time.time())
with self._conn() as c:
row = c.execute(
"""
SELECT id, password_hash, password_salt, reset_code_hash, reset_code_salt
FROM users
WHERE id=? AND deleted_at IS NULL
""",
(user_id,),
).fetchone()
if not row:
return False
calc_pwd = self._hash_password(password or "", row["password_salt"] or "")
if not hmac.compare_digest(calc_pwd, row["password_hash"] or ""):
return False
calc_reset = self._hash_reset_code(reset_code or "", row["reset_code_salt"] or "")
if not hmac.compare_digest(calc_reset, row["reset_code_hash"] or ""):
return False
c.execute("DELETE FROM sessions WHERE user_id=?", (user_id,))
c.execute("DELETE FROM wechat_accounts WHERE user_id=?", (user_id,))
c.execute("DELETE FROM ai_models WHERE user_id=?", (user_id,))
c.execute("DELETE FROM user_prefs WHERE user_id=?", (user_id,))
c.execute("DELETE FROM wechat_bindings WHERE user_id=?", (user_id,))
c.execute(
"UPDATE users SET deleted_at=?, username=username || '#deleted' || ? WHERE id=?",
(now, str(now), user_id),
)
return True
def save_wechat_binding(
self,
user_id: int,
@@ -510,3 +624,230 @@ class UserStore:
"thumb_image_path": row["thumb_image_path"] or "",
"updated_at": int(row["updated_at"] or 0),
}
def list_ai_models(self, user_id: int) -> list[dict]:
with self._conn() as c:
rows = c.execute(
"""
SELECT id, model_name, base_url, model, timeout_sec, max_output_tokens, max_retries, updated_at
FROM ai_models
WHERE user_id=?
ORDER BY updated_at DESC, id DESC
""",
(user_id,),
).fetchall()
pref = c.execute(
"SELECT active_ai_model_id FROM user_prefs WHERE user_id=?",
(user_id,),
).fetchone()
active_id = int(pref["active_ai_model_id"]) if pref and pref["active_ai_model_id"] else None
out: list[dict] = []
for r in rows:
out.append(
{
"id": int(r["id"]),
"model_name": r["model_name"] or "",
"base_url": r["base_url"] or "",
"model": r["model"] or "",
"timeout_sec": float(r["timeout_sec"] or 120.0),
"max_output_tokens": int(r["max_output_tokens"] or 8192),
"max_retries": int(r["max_retries"] or 0),
"updated_at": int(r["updated_at"] or 0),
"active": int(r["id"]) == active_id,
}
)
return out
def add_ai_model(
self,
user_id: int,
model_name: str,
api_key: str,
base_url: str,
model: str,
timeout_sec: float = 120.0,
max_output_tokens: int = 8192,
max_retries: int = 0,
) -> dict:
now = int(time.time())
name = model_name.strip() or f"模型{now % 10000}"
with self._conn() as c:
try:
cur = c.execute(
"""
INSERT INTO ai_models
(user_id, model_name, api_key, base_url, model, timeout_sec, max_output_tokens, max_retries, updated_at)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
""",
(user_id, name, api_key, base_url, model, timeout_sec, max_output_tokens, max_retries, now),
)
except sqlite3.IntegrityError:
name = f"{name}-{now % 1000}"
cur = c.execute(
"""
INSERT INTO ai_models
(user_id, model_name, api_key, base_url, model, timeout_sec, max_output_tokens, max_retries, updated_at)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
""",
(user_id, name, api_key, base_url, model, timeout_sec, max_output_tokens, max_retries, now),
)
aid = int(cur.lastrowid)
c.execute(
"""
INSERT INTO user_prefs(user_id, active_ai_model_id, updated_at)
VALUES (?, ?, ?)
ON CONFLICT(user_id) DO UPDATE SET
active_ai_model_id=excluded.active_ai_model_id,
updated_at=excluded.updated_at
""",
(user_id, aid, now),
)
return {"id": aid, "model_name": name}
def switch_active_ai_model(self, user_id: int, model_id: int) -> bool:
now = int(time.time())
with self._conn() as c:
row = c.execute(
"SELECT id FROM ai_models WHERE id=? AND user_id=?",
(model_id, user_id),
).fetchone()
if not row:
return False
c.execute(
"""
INSERT INTO user_prefs(user_id, active_ai_model_id, updated_at)
VALUES (?, ?, ?)
ON CONFLICT(user_id) DO UPDATE SET
active_ai_model_id=excluded.active_ai_model_id,
updated_at=excluded.updated_at
""",
(user_id, model_id, now),
)
return True
def delete_ai_model(self, user_id: int, model_id: int) -> bool:
now = int(time.time())
with self._conn() as c:
row = c.execute(
"SELECT id FROM ai_models WHERE id=? AND user_id=?",
(model_id, user_id),
).fetchone()
if not row:
return False
c.execute("DELETE FROM ai_models WHERE id=? AND user_id=?", (model_id, user_id))
pref = c.execute(
"SELECT active_ai_model_id FROM user_prefs WHERE user_id=?",
(user_id,),
).fetchone()
active_id = int(pref["active_ai_model_id"]) if pref and pref["active_ai_model_id"] else None
if active_id == model_id:
replacement = c.execute(
"""
SELECT id FROM ai_models WHERE user_id=?
ORDER BY updated_at DESC, id DESC
LIMIT 1
""",
(user_id,),
).fetchone()
replacement_id = int(replacement["id"]) if replacement else None
c.execute(
"""
INSERT INTO user_prefs(user_id, active_ai_model_id, updated_at)
VALUES (?, ?, ?)
ON CONFLICT(user_id) DO UPDATE SET
active_ai_model_id=excluded.active_ai_model_id,
updated_at=excluded.updated_at
""",
(user_id, replacement_id, now),
)
return True
def get_active_ai_model(self, user_id: int) -> dict | None:
with self._conn() as c:
pref = c.execute(
"SELECT active_ai_model_id FROM user_prefs WHERE user_id=?",
(user_id,),
).fetchone()
aid = int(pref["active_ai_model_id"]) if pref and pref["active_ai_model_id"] else None
row = None
if aid:
row = c.execute(
"""
SELECT id, model_name, api_key, base_url, model, timeout_sec, max_output_tokens, max_retries, updated_at
FROM ai_models
WHERE id=? AND user_id=?
""",
(aid, user_id),
).fetchone()
if not row:
row = c.execute(
"""
SELECT id, model_name, api_key, base_url, model, timeout_sec, max_output_tokens, max_retries, updated_at
FROM ai_models
WHERE user_id=?
ORDER BY updated_at DESC, id DESC
LIMIT 1
""",
(user_id,),
).fetchone()
if row:
c.execute(
"""
INSERT INTO user_prefs(user_id, active_ai_model_id, updated_at)
VALUES (?, ?, ?)
ON CONFLICT(user_id) DO UPDATE SET
active_ai_model_id=excluded.active_ai_model_id,
updated_at=excluded.updated_at
""",
(user_id, int(row["id"]), int(time.time())),
)
if not row:
return None
return {
"id": int(row["id"]),
"model_name": row["model_name"] or "",
"api_key": row["api_key"] or "",
"base_url": row["base_url"] or "",
"model": row["model"] or "",
"timeout_sec": float(row["timeout_sec"] or 120.0),
"max_output_tokens": int(row["max_output_tokens"] or 8192),
"max_retries": int(row["max_retries"] or 0),
"updated_at": int(row["updated_at"] or 0),
}
def delete_wechat_binding(self, user_id: int, account_id: int) -> bool:
now = int(time.time())
with self._conn() as c:
row = c.execute(
"SELECT id FROM wechat_accounts WHERE id=? AND user_id=?",
(account_id, user_id),
).fetchone()
if not row:
return False
c.execute("DELETE FROM wechat_accounts WHERE id=? AND user_id=?", (account_id, user_id))
pref = c.execute(
"SELECT active_wechat_account_id FROM user_prefs WHERE user_id=?",
(user_id,),
).fetchone()
active_id = int(pref["active_wechat_account_id"]) if pref and pref["active_wechat_account_id"] else None
if active_id == account_id:
replacement = c.execute(
"""
SELECT id FROM wechat_accounts WHERE user_id=?
ORDER BY updated_at DESC, id DESC
LIMIT 1
""",
(user_id,),
).fetchone()
replacement_id = int(replacement["id"]) if replacement else None
c.execute(
"""
INSERT INTO user_prefs(user_id, active_wechat_account_id, updated_at)
VALUES (?, ?, ?)
ON CONFLICT(user_id) DO UPDATE SET
active_wechat_account_id=excluded.active_wechat_account_id,
updated_at=excluded.updated_at
""",
(user_id, replacement_id, now),
)
return True

View File

@@ -69,11 +69,11 @@ class WechatPublisher:
def _resolve_account(self, account: dict | None = None) -> dict[str, str]:
src = account or {}
appid = (src.get("appid") or settings.wechat_appid or "").strip()
secret = (src.get("secret") or settings.wechat_secret or "").strip()
author = (src.get("author") or settings.wechat_author or "").strip()
thumb_media_id = (src.get("thumb_media_id") or settings.wechat_thumb_media_id or "").strip()
thumb_image_path = (src.get("thumb_image_path") or settings.wechat_thumb_image_path or "").strip()
appid = (src.get("appid") or "").strip()
secret = (src.get("secret") or "").strip()
author = (src.get("author") or "").strip()
thumb_media_id = (src.get("thumb_media_id") or "").strip()
thumb_image_path = (src.get("thumb_image_path") or "").strip()
return {
"appid": appid,
"secret": secret,
@@ -132,7 +132,7 @@ class WechatPublisher:
{
"article_type": "news",
"title": req.title[:32] if len(req.title) > 32 else req.title,
"author": (req.author or acct["author"] or settings.wechat_author)[:16],
"author": (req.author or acct["author"] or "AI发糕")[:16],
"digest": (req.summary or "")[:128],
"content": html,
"content_source_url": "",
@@ -246,6 +246,37 @@ class WechatPublisher:
)
return PublishResponse(ok=True, detail="素材上传成功", data=material)
async def upload_article_image(
self, filename: str, content: bytes, request_id: str = "", account: dict | None = None
) -> PublishResponse:
"""上传图文正文图片uploadimg返回可直接插入正文 HTML/Markdown 的 URL。"""
rid = request_id or "-"
acct = self._resolve_account(account)
if not acct["appid"] or not acct["secret"]:
return PublishResponse(ok=False, detail="缺少 WECHAT_APPID / WECHAT_SECRET 配置")
if not content:
return PublishResponse(ok=False, detail="素材文件为空")
token, _, token_err_body = await self._get_access_token(acct["appid"], acct["secret"])
if not token:
return PublishResponse(ok=False, detail=_detail_for_token_error(token_err_body), data=token_err_body)
async with httpx.AsyncClient(timeout=60) as client:
out = await self._upload_article_image_url(client, token, content, filename)
if not out:
return PublishResponse(
ok=False,
detail="正文配图上传失败请检查图片格式与大小jpg/png建议小于 1MB或查看日志 wechat_uploadimg_failed",
)
logger.info(
"wechat_uploadimg_ok rid=%s filename=%s url=%s",
rid,
filename,
out.get("url"),
)
return PublishResponse(ok=True, detail="正文配图上传成功", data=out)
async def _upload_permanent_image(
self, client: httpx.AsyncClient, token: str, content: bytes, filename: str
) -> dict[str, str] | None:
@@ -263,6 +294,23 @@ class WechatPublisher:
return None
return {"media_id": mid, "url": data.get("url") or ""}
async def _upload_article_image_url(
self, client: httpx.AsyncClient, token: str, content: bytes, filename: str
) -> dict[str, str] | None:
url = f"https://api.weixin.qq.com/cgi-bin/media/uploadimg?access_token={token}"
ctype = "image/png" if filename.lower().endswith(".png") else "image/jpeg"
files = {"media": (filename, content, ctype)}
r = await client.post(url, files=files)
data = r.json() if r.content else {}
if isinstance(data, dict) and data.get("errcode"):
logger.warning("wechat_uploadimg_failed body=%s", data)
return None
image_url = (data.get("url") if isinstance(data, dict) else "") or ""
if not image_url:
logger.warning("wechat_uploadimg_no_url body=%s", data)
return None
return {"url": image_url}
async def _resolve_thumb_media_id(
self, token: str, rid: str, *, force_skip_explicit: bool = False, account: dict | None = None
) -> str | None: