Files
AiVideo/engine/script_gen.py
2026-03-25 19:35:37 +08:00

176 lines
6.5 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
from __future__ import annotations
import json
import os
from typing import Any
from openai import OpenAI
from .config import AppConfig
from .types import Scene
def _looks_like_api_key(v: str) -> bool:
vv = (v or "").strip()
# Common prefixes: DashScope uses "sk-..."; we keep it minimal and permissive.
return bool(vv) and vv.startswith("sk-")
def _looks_like_url(v: str) -> bool:
vv = (v or "").strip()
return vv.startswith("http://") or vv.startswith("https://")
def _resolve_openai_credentials(cfg: AppConfig) -> tuple[str, str | None]:
    """Resolve ``(api_key, base_url)`` for an OpenAI-compatible client.

    Each config value may name an environment variable *or* hold the
    literal value itself; the environment variable takes precedence.

    Raises:
        RuntimeError: when no API key can be resolved.
    """
    key_setting = str(cfg.get("openai.api_key_env", "OPENAI_API_KEY") or "").strip()
    url_setting = str(cfg.get("openai.base_url_env", "OPENAI_BASE_URL") or "").strip()

    # API key: prefer the environment; fall back to treating the setting
    # itself as a literal key when it looks like one ("sk-...").
    api_key = os.environ.get(key_setting) if key_setting else None
    if not api_key and key_setting and _looks_like_api_key(key_setting):
        api_key = key_setting
    if not api_key:
        raise RuntimeError(f"Missing OpenAI compatible API key (env={key_setting})")

    # Base URL: same env-var-or-literal resolution; blank values collapse to None.
    base_url = os.environ.get(url_setting) if url_setting else None
    if not base_url and url_setting and _looks_like_url(url_setting):
        base_url = url_setting
    if base_url:
        base_url = str(base_url).strip() or None
    return str(api_key), base_url
def _system_prompt(scene_count: int, min_chars: int, max_chars: int) -> str:
return f"""你是一个专业短视频编剧与分镜师。
请把用户的创意扩展为 {scene_count} 个分镜(Scene) 的 JSON。
硬性约束:
1) 三个分镜的主角描述Character Description必须保持一致姓名/外观/服饰/风格不可前后矛盾。
2) 每个分镜必须包含字段image_prompt, video_motion, narration。
3) narration 为中文旁白,每段严格控制在约 {min_chars}-{max_chars} 字左右(宁可略短,不要超过太多)。
4) 画面描述要具体可视化video_motion 描述镜头运动/人物动作。
5) 只输出 JSON不要输出任何解释、markdown、代码块。
输出 JSON Schema示例结构
{{
"character_description": "...一致的主角设定...",
"scenes": [
{{"image_prompt":"...","video_motion":"...","narration":"..."}},
{{"image_prompt":"...","video_motion":"...","narration":"..."}},
{{"image_prompt":"...","video_motion":"...","narration":"..."}}
]
}}
"""
def _refine_system_prompt(min_chars: int, max_chars: int) -> str:
return f"""你是短视频分镜润色助手。
你会收到用户的原始创意 prompt、以及一组三分镜其中主角设定需一致
你的任务:只润色指定的一个 Scene使其更具体、更镜头化、更适合生成视频同时保持主角描述与其它分镜一致。
硬性约束:
1) 只修改目标 Scene不要改其它 Scene。
2) 目标 Scene 必须包含image_prompt, video_motion, narration。
3) narration 为中文旁白,每段控制在约 {min_chars}-{max_chars} 字左右。
4) 输出只允许 JSON不要解释、不要 markdown。
输出 JSON Schema
{{
"scene": {{"image_prompt":"...","video_motion":"...","narration":"..."}}
}}
"""
def generate_scenes(user_prompt: str, cfg: AppConfig) -> list[Scene]:
    """Expand a user's idea into a storyboard of ``scene_count`` scenes.

    Calls an OpenAI-compatible chat-completions endpoint in JSON mode and
    validates the response into :class:`Scene` objects.

    Args:
        user_prompt: The user's creative idea, free text.
        cfg: App config supplying the model name, credentials and
            script-generation limits.

    Returns:
        Exactly ``scene_count`` scenes (default 3).

    Raises:
        RuntimeError: if API credentials cannot be resolved.
        ValueError: if the model output is not the expected JSON shape.
    """
    scene_count = int(cfg.get("script_gen.scene_count", 3))
    min_chars = int(cfg.get("script_gen.narration_min_chars", 15))
    max_chars = int(cfg.get("script_gen.narration_max_chars", 20))
    model = str(cfg.get("openai.model", "gpt-4o-mini"))

    api_key, base_url = _resolve_openai_credentials(cfg)
    client = OpenAI(
        api_key=api_key,
        base_url=base_url,
    )

    resp = client.chat.completions.create(
        model=model,
        messages=[
            {"role": "system", "content": _system_prompt(scene_count, min_chars, max_chars)},
            {"role": "user", "content": user_prompt},
        ],
        # JSON mode: the endpoint must return a single JSON object.
        response_format={"type": "json_object"},
        temperature=0.6,
    )

    content = resp.choices[0].message.content or "{}"
    data: Any = json.loads(content)
    scenes_raw = data.get("scenes")
    # Bug fix: the old single check reported only type(scenes_raw), which
    # was useless when the model returned a *list* of the wrong length.
    # Split the check so the error names the actual problem.
    if not isinstance(scenes_raw, list):
        raise ValueError(f"Model returned invalid scenes type: {type(scenes_raw)}")
    if len(scenes_raw) != scene_count:
        raise ValueError(
            f"Model returned {len(scenes_raw)} scenes, expected {scene_count}"
        )

    scenes: list[Scene] = []
    for i, s in enumerate(scenes_raw):
        if not isinstance(s, dict):
            raise ValueError(f"Scene[{i}] must be object, got {type(s)}")
        image_prompt = str(s.get("image_prompt", "")).strip()
        video_motion = str(s.get("video_motion", "")).strip()
        narration = str(s.get("narration", "")).strip()
        # video_motion may legitimately be empty; the other two may not.
        if not image_prompt or not narration:
            raise ValueError(f"Scene[{i}] missing required fields")
        scenes.append(Scene(image_prompt=image_prompt, video_motion=video_motion, narration=narration))
    return scenes
def refine_scene(*, prompt: str, scenes: list[Scene], target_index: int, cfg: AppConfig) -> Scene:
    """Ask the LLM to polish a single scene of an existing storyboard.

    Sends the original prompt plus every current scene as a JSON payload
    and expects the model to rewrite only the scene at ``target_index``,
    keeping the protagonist consistent with the other scenes.

    Args:
        prompt: The user's original creative prompt.
        scenes: The current full list of scenes (not modified here).
        target_index: 1-based index of the scene to refine.
        cfg: App config supplying model name, credentials and narration
            length limits.

    Returns:
        The refined :class:`Scene`.

    Raises:
        ValueError: if ``target_index`` is out of range, or the model
            output lacks the ``scene`` object or its required fields.
        RuntimeError: if API credentials cannot be resolved.
    """
    # target_index is 1-based; it is forwarded verbatim to the model.
    if not (1 <= target_index <= len(scenes)):
        raise ValueError("target_index out of range")
    min_chars = int(cfg.get("script_gen.narration_min_chars", 15))
    max_chars = int(cfg.get("script_gen.narration_max_chars", 20))
    model = str(cfg.get("openai.model", "gpt-4o-mini"))
    api_key, base_url = _resolve_openai_credentials(cfg)
    client = OpenAI(
        api_key=api_key,
        base_url=base_url,
    )
    # Serialize every scene so the model can keep the protagonist
    # description consistent across the set while editing one scene.
    scenes_payload = [
        {"image_prompt": s.image_prompt, "video_motion": s.video_motion, "narration": s.narration}
        for s in scenes
    ]
    user_payload = {
        "prompt": prompt,
        "target_index": target_index,
        "scenes": scenes_payload,
    }
    resp = client.chat.completions.create(
        model=model,
        messages=[
            {"role": "system", "content": _refine_system_prompt(min_chars, max_chars)},
            # ensure_ascii=False keeps the Chinese text readable in the request.
            {"role": "user", "content": json.dumps(user_payload, ensure_ascii=False)},
        ],
        # JSON mode: the endpoint must return a single JSON object.
        response_format={"type": "json_object"},
        temperature=0.6,
    )
    content = resp.choices[0].message.content or "{}"
    data: Any = json.loads(content)
    s = data.get("scene")
    if not isinstance(s, dict):
        raise ValueError("Model refine output missing scene")
    image_prompt = str(s.get("image_prompt", "")).strip()
    video_motion = str(s.get("video_motion", "")).strip()
    narration = str(s.get("narration", "")).strip()
    # video_motion may be empty; image_prompt and narration are required.
    if not image_prompt or not narration:
        raise ValueError("Refined scene missing required fields")
    return Scene(image_prompt=image_prompt, video_motion=video_motion, narration=narration)