81 lines
2.6 KiB
Python
81 lines
2.6 KiB
Python
from __future__ import annotations
|
|
|
|
import os
|
|
from typing import Any
|
|
|
|
from engine.config import AppConfig
|
|
|
|
|
|
def _provider(cfg: AppConfig, path: str, default: str) -> str:
    """Resolve the provider name configured at the dotted config key *path*.

    Lookup order:
      1. An environment-variable override. Only the five known
         ``*.provider`` keys have one (e.g. ``llm.provider`` ->
         ``ENGINE_LLM_PROVIDER``). It is used when set to a non-blank value.
      2. ``cfg.get(path, default)``, when that yields a truthy value that is
         non-blank after stripping.
      3. *default*.

    Args:
        cfg: Application configuration queried via ``cfg.get(path, default)``.
        path: Dotted key such as ``"image.provider"``.
        default: Value used when neither the environment nor ``cfg`` supplies
            a usable one.

    Returns:
        The provider name. Environment and config values are returned with
        surrounding whitespace stripped.
    """
    override_vars = {
        "llm.provider": "ENGINE_LLM_PROVIDER",
        "image.provider": "ENGINE_IMAGE_PROVIDER",
        "image_fallback.provider": "ENGINE_IMAGE_FALLBACK_PROVIDER",
        "video.provider": "ENGINE_VIDEO_PROVIDER",
        "tts.provider": "ENGINE_TTS_PROVIDER",
    }
    var_name = override_vars.get(path)
    # Paths without a mapped variable skip the environment entirely.
    override = str(os.environ.get(var_name, "")).strip() if var_name else ""
    if override:
        return override

    configured = cfg.get(path, default)
    # Falsy config values (None, "", ...) fall back to the default, and so
    # does a value that is blank once stripped.
    resolved = str(configured if configured else default).strip()
    return resolved if resolved else default
|
|
|
|
|
|
def get_model(name: str, cfg: AppConfig) -> Any:
    """Construct the adapter for the model slot *name*.

    Each slot's provider is resolved by ``_provider`` (environment override,
    then ``cfg``, then a per-slot default) and matched case-insensitively.
    Adapter modules are imported lazily inside each branch, so only the
    module for the selected provider is imported.

    Unrecognised provider values are not rejected. They fall through to the
    slot's default adapter (OpenAI for ``llm``, mock for the image slots,
    MoviePy for ``video``, Edge for ``tts``).

    Args:
        name: One of ``"llm"``, ``"image"``, ``"image_fallback"``,
            ``"video"`` or ``"tts"``.
        cfg: Application configuration. It is passed to the adapters whose
            constructors take it; the mock adapters are built without it.

    Returns:
        The constructed adapter instance.

    Raises:
        ValueError: If *name* is not one of the known slots.
    """
    if name == "llm":
        provider = _provider(cfg, "llm.provider", "openai").lower()
        if provider == "mock":
            from engine.adapters.llm.mock_adapter import MockLLM

            return MockLLM()
        # Any other value, including unrecognised ones, selects OpenAI.
        from engine.adapters.llm.openai_adapter import OpenAIAdapter

        return OpenAIAdapter(cfg)

    if name in ("image", "image_fallback"):
        # Important: both slots default to "mock"; the fallback must NOT
        # follow the primary image provider. It only changes when
        # image_fallback.provider / ENGINE_IMAGE_FALLBACK_PROVIDER is set.
        # The config section name is the slot name itself.
        provider = _provider(cfg, f"{name}.provider", "mock").lower()
        if provider == "comfy":
            from engine.adapters.image.comfy_adapter import ComfyAdapter

            return ComfyAdapter(cfg)
        if provider == "replicate":
            from engine.adapters.image.replicate_adapter import ReplicateAdapter

            return ReplicateAdapter(cfg)
        if provider == "openai":
            from engine.adapters.image.openai_image_adapter import OpenAIImageAdapter

            return OpenAIImageAdapter(cfg)
        # "mock" and any unrecognised value.
        from engine.adapters.image.mock_adapter import MockImageGen

        return MockImageGen()

    if name == "video":
        provider = _provider(cfg, "video.provider", "moviepy").lower()
        if provider == "ltx":
            from engine.adapters.video.ltx_adapter import LTXVideoGen

            return LTXVideoGen(cfg)
        # "moviepy" and any unrecognised value.
        from engine.adapters.video.moviepy_adapter import MoviePyVideoGen

        return MoviePyVideoGen(cfg)

    if name == "tts":
        provider = _provider(cfg, "tts.provider", "edge").lower()
        if provider == "mock":
            from engine.adapters.tts.mock_adapter import MockTTS

            return MockTTS()
        # "edge" and any unrecognised value.
        from engine.adapters.tts.edge_adapter import EdgeTTS

        return EdgeTTS(cfg)

    raise ValueError(f"Unknown model adapter name: {name}")
|
|
|