from __future__ import annotations from ..models import ModelUtils, ImageModel from ..Provider import ProviderUtils from ..providers.types import ProviderType class MediaModels(): def __init__(self, client, provider: ProviderType = None): self.client = client self.provider = provider def get(self, name, default=None) -> ProviderType: if name in ModelUtils.convert: return ModelUtils.convert[name].best_provider if name in ProviderUtils.convert: return ProviderUtils.convert[name] return default def get_all(self, api_key: str = None, **kwargs) -> list[str]: if self.provider is None: return [] if api_key is None: api_key = self.client.api_key return self.provider.get_models( **kwargs, **{} if api_key is None else {"api_key": api_key} ) def get_image(self, **kwargs) -> list[str]: if self.provider is None: return [model_id for model_id, model in ModelUtils.convert.items() if isinstance(model, ImageModel)] self.get_all(**kwargs) if hasattr(self.provider, "image_models"): return self.provider.image_models return [] def get_video(self, **kwargs) -> list[str]: if self.provider is None: return [] self.get_all(**kwargs) if hasattr(self.provider, "video_models"): return self.provider.video_models return []