Files
AiVideo/engine/script_gen.py
2026-03-25 13:33:48 +08:00

152 lines
5.8 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
from __future__ import annotations
import json
import os
from typing import Any
from openai import OpenAI
from .config import AppConfig
from .types import Scene
def _system_prompt(scene_count: int, min_chars: int, max_chars: int) -> str:
    """Build the system prompt that asks the model for a storyboard JSON.

    Args:
        scene_count: Number of scenes requested; also the number of
            placeholder entries rendered in the example schema.
        min_chars: Lower bound of the narration length quoted in the prompt.
        max_chars: Upper bound of the narration length quoted in the prompt.

    Returns:
        The prompt text, ending with a newline.
    """
    # Render one placeholder entry per requested scene so the example schema
    # never contradicts the scene count stated in the prompt's opening lines.
    example_scene = (
        '{"image_prompt":"...","video_motion":"...","narration":"..."}'
    )
    example_scenes = ",\n".join([example_scene] * scene_count)
    return f"""你是一个专业短视频编剧与分镜师。
请把用户的创意扩展为 {scene_count} 个分镜(Scene) 的 JSON。
硬性约束:
1) 全部 {scene_count} 个分镜的主角描述Character Description必须保持一致姓名/外观/服饰/风格不可前后矛盾。
2) 每个分镜必须包含字段image_prompt, video_motion, narration。
3) narration 为中文旁白,每段严格控制在约 {min_chars}-{max_chars} 字左右(宁可略短,不要超过太多)。
4) 画面描述要具体可视化video_motion 描述镜头运动/人物动作。
5) 只输出 JSON不要输出任何解释、markdown、代码块。
输出 JSON Schema示例结构
{{
"character_description": "...一致的主角设定...",
"scenes": [
{example_scenes}
]
}}
"""
def _refine_system_prompt(min_chars: int, max_chars: int) -> str:
    """Build the system prompt used when polishing a single scene.

    The prompt does not assume a fixed number of scenes and tells the model
    that ``target_index`` is 1-based, matching the range check performed by
    ``refine_scene`` before it sends the request.

    Args:
        min_chars: Lower bound of the narration length quoted in the prompt.
        max_chars: Upper bound of the narration length quoted in the prompt.

    Returns:
        The prompt text, ending with a newline.
    """
    return f"""你是短视频分镜润色助手。
你会收到用户的原始创意 prompt、以及一组分镜其中主角设定需一致
你的任务:只润色指定的一个 Scene使其更具体、更镜头化、更适合生成视频同时保持主角描述与其它分镜一致。
硬性约束:
1) 只修改目标 Scene不要改其它 Scene。目标 Scene 由 target_index 指定,从 1 开始计数(1 表示 scenes 中的第一个)。
2) 目标 Scene 必须包含image_prompt, video_motion, narration。
3) narration 为中文旁白,每段控制在约 {min_chars}-{max_chars} 字左右。
4) 输出只允许 JSON不要解释、不要 markdown。
输出 JSON Schema
{{
"scene": {{"image_prompt":"...","video_motion":"...","narration":"..."}}
}}
"""
def generate_scenes(user_prompt: str, cfg: AppConfig) -> list[Scene]:
    """Expand a user idea into storyboard scenes via the chat completions API.

    Args:
        user_prompt: The user's idea, sent unchanged as the user message.
        cfg: Configuration object, read through ``cfg.get(key, default)`` for
            the ``script_gen.*`` and ``openai.*`` settings.

    Returns:
        The parsed scenes, in the order the model returned them; the list
        length equals the configured ``script_gen.scene_count``.

    Raises:
        RuntimeError: If the API-key environment variable is unset or empty.
        ValueError: If the reply is not valid JSON, is not a JSON object,
            does not contain the expected number of scenes, or a scene is
            missing ``image_prompt`` or ``narration``.
    """
    scene_count = int(cfg.get("script_gen.scene_count", 3))
    min_chars = int(cfg.get("script_gen.narration_min_chars", 15))
    max_chars = int(cfg.get("script_gen.narration_max_chars", 20))
    # The config stores the *names* of the environment variables, not the
    # secrets themselves.
    api_key_env = str(cfg.get("openai.api_key_env", "OPENAI_API_KEY"))
    base_url_env = str(cfg.get("openai.base_url_env", "OPENAI_BASE_URL"))
    model = str(cfg.get("openai.model", "gpt-4o-mini"))

    api_key = os.environ.get(api_key_env)
    if not api_key:
        raise RuntimeError(f"Missing env var {api_key_env} for OpenAI API key")

    client = OpenAI(
        api_key=api_key,
        # An unset or empty variable is passed to the client as None.
        base_url=os.environ.get(base_url_env) or None,
    )
    resp = client.chat.completions.create(
        model=model,
        messages=[
            {
                "role": "system",
                "content": _system_prompt(scene_count, min_chars, max_chars),
            },
            {"role": "user", "content": user_prompt},
        ],
        response_format={"type": "json_object"},
        temperature=0.6,
    )

    # Empty/None content is parsed as an empty object so it fails the
    # validation below instead of crashing inside json.loads.
    content = resp.choices[0].message.content or "{}"
    data: Any = json.loads(content)
    if not isinstance(data, dict):
        raise ValueError(
            f"Model returned {type(data).__name__}, expected a JSON object"
        )

    scenes_raw = data.get("scenes")
    if not isinstance(scenes_raw, list):
        raise ValueError(
            "Model returned invalid scenes: expected a list, "
            f"got {type(scenes_raw).__name__}"
        )
    if len(scenes_raw) != scene_count:
        raise ValueError(
            "Model returned invalid scenes length: "
            f"expected {scene_count}, got {len(scenes_raw)}"
        )

    scenes: list[Scene] = []
    for i, raw in enumerate(scenes_raw):
        if not isinstance(raw, dict):
            raise ValueError(f"Scene[{i}] must be object, got {type(raw)}")
        # A JSON null must count as missing; str(None) would otherwise yield
        # the literal text "None" and slip past the required-field check.
        fields = {
            key: "" if raw.get(key) is None else str(raw[key]).strip()
            for key in ("image_prompt", "video_motion", "narration")
        }
        # video_motion is allowed to be empty; only these two are mandatory.
        if not fields["image_prompt"] or not fields["narration"]:
            raise ValueError(f"Scene[{i}] missing required fields")
        scenes.append(Scene(**fields))
    return scenes
def refine_scene(*, prompt: str, scenes: list[Scene], target_index: int, cfg: AppConfig) -> Scene:
    """Ask the model to polish one scene of an existing storyboard.

    Args:
        prompt: The user's original idea; forwarded in the request payload.
        scenes: All current scenes; forwarded so the model sees the context.
        target_index: 1-based position of the scene to refine; must satisfy
            ``1 <= target_index <= len(scenes)``.
        cfg: Configuration object, read through ``cfg.get(key, default)`` for
            the ``script_gen.*`` and ``openai.*`` settings.

    Returns:
        A new ``Scene`` built from the model's reply. The ``scenes`` list
        passed in is not modified.

    Raises:
        ValueError: If ``target_index`` is out of range, the reply is not
            valid JSON, is not a JSON object, has no ``scene`` object, or
            the scene is missing ``image_prompt`` or ``narration``.
        RuntimeError: If the API-key environment variable is unset or empty.
    """
    if not (1 <= target_index <= len(scenes)):
        raise ValueError("target_index out of range")

    min_chars = int(cfg.get("script_gen.narration_min_chars", 15))
    max_chars = int(cfg.get("script_gen.narration_max_chars", 20))
    # The config stores the *names* of the environment variables, not the
    # secrets themselves.
    api_key_env = str(cfg.get("openai.api_key_env", "OPENAI_API_KEY"))
    base_url_env = str(cfg.get("openai.base_url_env", "OPENAI_BASE_URL"))
    model = str(cfg.get("openai.model", "gpt-4o-mini"))

    api_key = os.environ.get(api_key_env)
    if not api_key:
        raise RuntimeError(f"Missing env var {api_key_env} for OpenAI API key")

    client = OpenAI(
        api_key=api_key,
        # An unset or empty variable is passed to the client as None.
        base_url=os.environ.get(base_url_env) or None,
    )

    user_payload = {
        "prompt": prompt,
        "target_index": target_index,
        "scenes": [
            {
                "image_prompt": scene.image_prompt,
                "video_motion": scene.video_motion,
                "narration": scene.narration,
            }
            for scene in scenes
        ],
    }
    resp = client.chat.completions.create(
        model=model,
        messages=[
            {
                "role": "system",
                "content": _refine_system_prompt(min_chars, max_chars),
            },
            # ensure_ascii=False keeps the Chinese text readable in the
            # request instead of \uXXXX escapes.
            {
                "role": "user",
                "content": json.dumps(user_payload, ensure_ascii=False),
            },
        ],
        response_format={"type": "json_object"},
        temperature=0.6,
    )

    # Empty/None content is parsed as an empty object so it fails the
    # validation below instead of crashing inside json.loads.
    content = resp.choices[0].message.content or "{}"
    data: Any = json.loads(content)
    if not isinstance(data, dict):
        raise ValueError(
            f"Model refine output is {type(data).__name__}, expected a JSON object"
        )

    raw = data.get("scene")
    if not isinstance(raw, dict):
        raise ValueError("Model refine output missing scene")

    # A JSON null must count as missing; str(None) would otherwise yield the
    # literal text "None" and slip past the required-field check.
    fields = {
        key: "" if raw.get(key) is None else str(raw[key]).strip()
        for key in ("image_prompt", "video_motion", "narration")
    }
    # video_motion is allowed to be empty; only these two are mandatory.
    if not fields["image_prompt"] or not fields["narration"]:
        raise ValueError("Refined scene missing required fields")
    return Scene(**fields)