Files
AiVideo/engine/model_factory.py
2026-03-25 19:35:37 +08:00

81 lines
2.6 KiB
Python

from __future__ import annotations
import os
from typing import Any
from engine.config import AppConfig
# Config path -> environment variable that overrides it. Built once at import
# time instead of on every _provider() call.
_PROVIDER_ENV_KEYS = {
    "llm.provider": "ENGINE_LLM_PROVIDER",
    "image.provider": "ENGINE_IMAGE_PROVIDER",
    "image_fallback.provider": "ENGINE_IMAGE_FALLBACK_PROVIDER",
    "video.provider": "ENGINE_VIDEO_PROVIDER",
    "tts.provider": "ENGINE_TTS_PROVIDER",
}


def _provider(cfg: AppConfig, path: str, default: str) -> str:
    """Resolve the provider name configured at *path*.

    Precedence:
      1. the ``ENGINE_*_PROVIDER`` environment variable mapped to *path*
         (only if it is set to a non-blank value);
      2. ``cfg.get(path, default)`` (falsy or blank values are ignored);
      3. *default*.

    The result is stripped and lower-cased so that values such as ``"Mock"``
    or ``" COMFY "`` match the lowercase provider names that ``get_model``
    compares against.

    Args:
        cfg: Application config; only its ``get(path, default)`` method is used.
        path: Dotted config key, e.g. ``"llm.provider"``.
        default: Provider name used when neither env nor config supplies one.

    Returns:
        The normalized provider name.
    """
    env_key = _PROVIDER_ENV_KEYS.get(path)
    if env_key:
        env_val = os.environ.get(env_key, "").strip()
        if env_val:
            return env_val.lower()
    value = cfg.get(path, default)
    # `value or default` treats None/""/0 as "unset"; a whitespace-only value
    # strips to "" and also falls back to the default.
    return (str(value or default).strip() or default).lower()
def get_model(name: str, cfg: AppConfig) -> Any:
    """Construct the adapter for model slot *name* from providers in *cfg*.

    Provider names are resolved via ``_provider`` (env override, then config,
    then the slot default). Per slot:

    * ``"llm"`` (default ``openai``): ``mock`` -> ``MockLLM``; anything else
      -> ``OpenAIAdapter``.
    * ``"image"`` / ``"image_fallback"`` (default ``mock`` for both):
      ``comfy`` -> ``ComfyAdapter``, ``replicate`` -> ``ReplicateAdapter``,
      ``openai`` -> ``OpenAIImageAdapter``; anything else -> ``MockImageGen``.
    * ``"video"`` (default ``moviepy``): ``ltx`` -> ``LTXVideoGen``; anything
      else -> ``MoviePyVideoGen``.
    * ``"tts"`` (default ``edge``): ``mock`` -> ``MockTTS``; anything else
      -> ``EdgeTTS``.

    Adapter modules are imported only inside the branch that selects them,
    so unselected backends are never imported (presumably to avoid pulling in
    optional heavy dependencies).

    Args:
        name: Model slot: ``"llm"``, ``"image"``, ``"image_fallback"``,
            ``"video"`` or ``"tts"``.
        cfg: Application config, passed to adapters that take one.

    Returns:
        The constructed adapter instance.

    Raises:
        ValueError: If *name* is not one of the supported slots.
    """
    if name == "llm":
        provider = _provider(cfg, "llm.provider", "openai")
        if provider == "mock":
            from engine.adapters.llm.mock_adapter import MockLLM
            return MockLLM()
        from engine.adapters.llm.openai_adapter import OpenAIAdapter
        return OpenAIAdapter(cfg)

    if name in ("image", "image_fallback"):
        # Each image slot reads its own "<name>.provider" key and defaults to
        # "mock". Important: the fallback must NOT inherit the primary image
        # provider when its own key is unset.
        provider = _provider(cfg, f"{name}.provider", "mock")
        if provider == "comfy":
            from engine.adapters.image.comfy_adapter import ComfyAdapter
            return ComfyAdapter(cfg)
        if provider == "replicate":
            from engine.adapters.image.replicate_adapter import ReplicateAdapter
            return ReplicateAdapter(cfg)
        if provider == "openai":
            from engine.adapters.image.openai_image_adapter import OpenAIImageAdapter
            return OpenAIImageAdapter(cfg)
        from engine.adapters.image.mock_adapter import MockImageGen
        return MockImageGen()

    if name == "video":
        provider = _provider(cfg, "video.provider", "moviepy")
        if provider == "ltx":
            from engine.adapters.video.ltx_adapter import LTXVideoGen
            return LTXVideoGen(cfg)
        from engine.adapters.video.moviepy_adapter import MoviePyVideoGen
        return MoviePyVideoGen(cfg)

    if name == "tts":
        provider = _provider(cfg, "tts.provider", "edge")
        if provider == "mock":
            from engine.adapters.tts.mock_adapter import MockTTS
            return MockTTS()
        from engine.adapters.tts.edge_adapter import EdgeTTS
        return EdgeTTS(cfg)

    raise ValueError(f"Unknown model adapter name: {name}")