fix: 优化内容
This commit is contained in:
@@ -33,6 +33,24 @@ def _system_prompt(scene_count: int, min_chars: int, max_chars: int) -> str:
|
||||
"""
|
||||
|
||||
|
||||
def _refine_system_prompt(min_chars: int, max_chars: int) -> str:
|
||||
return f"""你是短视频分镜润色助手。
|
||||
你会收到用户的原始创意 prompt、以及一组三分镜(其中主角设定需一致)。
|
||||
你的任务:只润色指定的一个 Scene,使其更具体、更镜头化、更适合生成视频,同时保持主角描述与其它分镜一致。
|
||||
|
||||
硬性约束:
|
||||
1) 只修改目标 Scene,不要改其它 Scene。
|
||||
2) 目标 Scene 必须包含:image_prompt, video_motion, narration。
|
||||
3) narration 为中文旁白,每段控制在约 {min_chars}-{max_chars} 字左右。
|
||||
4) 输出只允许 JSON,不要解释、不要 markdown。
|
||||
|
||||
输出 JSON Schema:
|
||||
{{
|
||||
"scene": {{"image_prompt":"...","video_motion":"...","narration":"..."}}
|
||||
}}
|
||||
"""
|
||||
|
||||
|
||||
def generate_scenes(user_prompt: str, cfg: AppConfig) -> list[Scene]:
|
||||
scene_count = int(cfg.get("script_gen.scene_count", 3))
|
||||
min_chars = int(cfg.get("script_gen.narration_min_chars", 15))
|
||||
@@ -78,3 +96,56 @@ def generate_scenes(user_prompt: str, cfg: AppConfig) -> list[Scene]:
|
||||
raise ValueError(f"Scene[{i}] missing required fields")
|
||||
scenes.append(Scene(image_prompt=image_prompt, video_motion=video_motion, narration=narration))
|
||||
return scenes
|
||||
|
||||
|
||||
def refine_scene(*, prompt: str, scenes: list[Scene], target_index: int, cfg: AppConfig) -> Scene:
|
||||
if not (1 <= target_index <= len(scenes)):
|
||||
raise ValueError("target_index out of range")
|
||||
|
||||
min_chars = int(cfg.get("script_gen.narration_min_chars", 15))
|
||||
max_chars = int(cfg.get("script_gen.narration_max_chars", 20))
|
||||
|
||||
api_key_env = str(cfg.get("openai.api_key_env", "OPENAI_API_KEY"))
|
||||
base_url_env = str(cfg.get("openai.base_url_env", "OPENAI_BASE_URL"))
|
||||
model = str(cfg.get("openai.model", "gpt-4o-mini"))
|
||||
|
||||
api_key = os.environ.get(api_key_env)
|
||||
if not api_key:
|
||||
raise RuntimeError(f"Missing env var {api_key_env} for OpenAI API key")
|
||||
|
||||
client = OpenAI(
|
||||
api_key=api_key,
|
||||
base_url=os.environ.get(base_url_env) or None,
|
||||
)
|
||||
|
||||
scenes_payload = [
|
||||
{"image_prompt": s.image_prompt, "video_motion": s.video_motion, "narration": s.narration}
|
||||
for s in scenes
|
||||
]
|
||||
user_payload = {
|
||||
"prompt": prompt,
|
||||
"target_index": target_index,
|
||||
"scenes": scenes_payload,
|
||||
}
|
||||
|
||||
resp = client.chat.completions.create(
|
||||
model=model,
|
||||
messages=[
|
||||
{"role": "system", "content": _refine_system_prompt(min_chars, max_chars)},
|
||||
{"role": "user", "content": json.dumps(user_payload, ensure_ascii=False)},
|
||||
],
|
||||
response_format={"type": "json_object"},
|
||||
temperature=0.6,
|
||||
)
|
||||
|
||||
content = resp.choices[0].message.content or "{}"
|
||||
data: Any = json.loads(content)
|
||||
s = data.get("scene")
|
||||
if not isinstance(s, dict):
|
||||
raise ValueError("Model refine output missing scene")
|
||||
image_prompt = str(s.get("image_prompt", "")).strip()
|
||||
video_motion = str(s.get("video_motion", "")).strip()
|
||||
narration = str(s.get("narration", "")).strip()
|
||||
if not image_prompt or not narration:
|
||||
raise ValueError("Refined scene missing required fields")
|
||||
return Scene(image_prompt=image_prompt, video_motion=video_motion, narration=narration)
|
||||
|
||||
Reference in New Issue
Block a user