Files
AiVideo/engine/script_gen.py
2026-03-18 17:36:07 +08:00

81 lines
3.1 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
from __future__ import annotations
import json
import os
from typing import Any
from openai import OpenAI
from .config import AppConfig
from .types import Scene
def _system_prompt(scene_count: int, min_chars: int, max_chars: int) -> str:
    """Build the system prompt asking the model for storyboard JSON.

    The prompt text is Chinese and ends with an example JSON structure whose
    ``scenes`` array holds one placeholder entry per requested scene.

    Args:
        scene_count: Number of scenes requested from the model; also the
            number of entries rendered in the example ``scenes`` array.
        min_chars: Lower bound of the narration length range quoted in the
            prompt.
        max_chars: Upper bound of the narration length range quoted in the
            prompt.

    Returns:
        The fully formatted system prompt.
    """
    # Render the example from scene_count so the sample structure never
    # contradicts the requested number of scenes (it used to always show 3).
    example_scene = '{"image_prompt":"...","video_motion":"...","narration":"..."}'
    example_scenes = ",\n".join([example_scene] * scene_count)
    # Constraint 1 says "all scenes" rather than a fixed "three scenes" for
    # the same reason.
    return f"""你是一个专业短视频编剧与分镜师。
请把用户的创意扩展为 {scene_count} 个分镜(Scene) 的 JSON。
硬性约束:
1) 所有分镜的主角描述Character Description必须保持一致姓名/外观/服饰/风格不可前后矛盾。
2) 每个分镜必须包含字段image_prompt, video_motion, narration。
3) narration 为中文旁白,每段严格控制在约 {min_chars}-{max_chars} 字左右(宁可略短,不要超过太多)。
4) 画面描述要具体可视化video_motion 描述镜头运动/人物动作。
5) 只输出 JSON不要输出任何解释、markdown、代码块。
输出 JSON Schema示例结构
{{
"character_description": "...一致的主角设定...",
"scenes": [
{example_scenes}
]
}}
"""
def generate_scenes(user_prompt: str, cfg: AppConfig) -> list[Scene]:
    """Expand a user idea into storyboard scenes via a chat completion.

    Settings are read with ``cfg.get(key, default)`` from the
    ``script_gen.*`` and ``openai.*`` keys. The API key and the optional base
    URL are taken from the environment variables those settings name.

    Args:
        user_prompt: The user's idea, sent as the user message.
        cfg: Configuration object queried through ``cfg.get``.

    Returns:
        One ``Scene`` per entry of the model's ``scenes`` array, in order;
        the array length must equal the configured scene count.

    Raises:
        RuntimeError: If the API-key environment variable is unset or empty.
        ValueError: If the reply has no choices, is not valid JSON
            (``json.JSONDecodeError`` is a ``ValueError``), is not a JSON
            object, has a ``scenes`` value that is not a list of the
            configured length, or contains a scene that is not an object or
            lacks ``image_prompt`` / ``narration``.
    """
    scene_count = int(cfg.get("script_gen.scene_count", 3))
    min_chars = int(cfg.get("script_gen.narration_min_chars", 15))
    max_chars = int(cfg.get("script_gen.narration_max_chars", 20))
    api_key_env = str(cfg.get("openai.api_key_env", "OPENAI_API_KEY"))
    base_url_env = str(cfg.get("openai.base_url_env", "OPENAI_BASE_URL"))
    model = str(cfg.get("openai.model", "gpt-4o-mini"))

    api_key = os.environ.get(api_key_env)
    if not api_key:
        raise RuntimeError(f"Missing env var {api_key_env} for OpenAI API key")

    client = OpenAI(
        api_key=api_key,
        # An unset or empty base-URL variable falls back to the client default.
        base_url=os.environ.get(base_url_env) or None,
    )
    resp = client.chat.completions.create(
        model=model,
        messages=[
            {"role": "system", "content": _system_prompt(scene_count, min_chars, max_chars)},
            {"role": "user", "content": user_prompt},
        ],
        response_format={"type": "json_object"},
        temperature=0.6,
    )

    if not resp.choices:
        raise ValueError("Model returned no choices")
    # Empty or missing content is parsed as an empty object so that it fails
    # the structural checks below with a descriptive error.
    content = resp.choices[0].message.content or "{}"
    data: Any = json.loads(content)
    if not isinstance(data, dict):
        raise ValueError(
            f"Model returned {type(data).__name__}, expected a JSON object"
        )

    scenes_raw = data.get("scenes")
    if not isinstance(scenes_raw, list):
        raise ValueError(
            f"Model returned invalid scenes: expected a list, got {type(scenes_raw).__name__}"
        )
    if len(scenes_raw) != scene_count:
        raise ValueError(
            f"Model returned invalid scenes length: expected {scene_count}, got {len(scenes_raw)}"
        )

    scenes: list[Scene] = []
    for i, s in enumerate(scenes_raw):
        if not isinstance(s, dict):
            raise ValueError(f"Scene[{i}] must be object, got {type(s)}")
        image_prompt = _scene_text(s, "image_prompt")
        video_motion = _scene_text(s, "video_motion")
        narration = _scene_text(s, "narration")
        # video_motion is allowed to be empty; the other two are mandatory.
        if not image_prompt or not narration:
            raise ValueError(f"Scene[{i}] missing required fields")
        scenes.append(Scene(image_prompt=image_prompt, video_motion=video_motion, narration=narration))
    return scenes


def _scene_text(scene: dict[str, Any], key: str) -> str:
    """Return ``scene[key]`` as stripped text; a missing or null value gives ``""``."""
    value = scene.get(key)
    # JSON null must count as missing; str(None) would yield the text "None".
    if value is None:
        return ""
    return str(value).strip()