mirror of
https://github.com/xtekky/gpt4free.git
synced 2025-10-07 09:11:38 +08:00
43 lines
1.5 KiB
Python
43 lines
1.5 KiB
Python
from __future__ import annotations
|
|
|
|
from ..models import ModelUtils, ImageModel
|
|
from ..Provider import ProviderUtils
|
|
from ..providers.types import ProviderType
|
|
|
|
class MediaModels:
    """Resolve providers by name and list the media models a provider offers.

    Wraps an optional ``provider``. When it is ``None``, the listing methods
    return a fallback answer instead of querying a provider.
    """

    def __init__(self, client, provider: ProviderType = None):
        """Store the owning ``client`` and the optional media ``provider``.

        Args:
            client: object whose ``api_key`` attribute, if it has one, is used
                as the default key in :meth:`get_all`.
            provider: provider to query for models; ``None`` means "no provider".
        """
        self.client = client
        self.provider = provider

    def get(self, name, default=None) -> ProviderType | None:
        """Resolve ``name`` to a provider.

        A known model name resolves to that model's ``best_provider``. Failing
        that, a known provider name resolves to the provider registered under it.

        Returns:
            The resolved provider, or ``default`` when ``name`` matches neither.
        """
        # Model names are checked before provider names.
        if name in ModelUtils.convert:
            return ModelUtils.convert[name].best_provider
        if name in ProviderUtils.convert:
            return ProviderUtils.convert[name]
        return default

    def get_all(self, api_key: str = None, **kwargs) -> list[str]:
        """Return the model ids reported by the wrapped provider.

        Args:
            api_key: key forwarded to ``provider.get_models``. When omitted, the
                client's ``api_key`` attribute is used if the client defines it.
            **kwargs: forwarded unchanged to ``provider.get_models``.

        Returns:
            Whatever ``provider.get_models`` returns, or ``[]`` when no provider
            is set.
        """
        if self.provider is None:
            return []
        if api_key is None:
            # Tolerate client objects that do not define ``api_key`` at all.
            api_key = getattr(self.client, "api_key", None)
        if api_key is not None:
            # ``kwargs`` is a fresh dict per call, so adding to it is safe.
            # ``api_key`` is a named parameter and can never already be in it.
            kwargs["api_key"] = api_key
        return self.provider.get_models(**kwargs)

    def _get_media_models(self, attribute: str, kwargs: dict) -> list[str]:
        """Call :meth:`get_all`, then return ``provider.<attribute>`` (or ``[]``).

        ``kwargs`` is taken as a dict rather than ``**kwargs`` so a caller
        keyword can never collide with ``attribute``. Requires ``self.provider``
        to be set.
        """
        # The result is deliberately discarded. NOTE(review): the call presumably
        # lets the provider (re)populate its media model attributes; confirm
        # against the provider implementations.
        self.get_all(**kwargs)
        models = getattr(self.provider, attribute, None)
        # A missing attribute and an attribute explicitly set to None both
        # mean "no models".
        return [] if models is None else models

    def get_image(self, **kwargs) -> list[str]:
        """List image model ids.

        Without a provider, returns every registered model id whose model is an
        ``ImageModel``. Otherwise returns the provider's ``image_models``
        (``[]`` if absent). ``kwargs`` are forwarded to :meth:`get_all`.
        """
        if self.provider is None:
            return [
                model_id
                for model_id, model in ModelUtils.convert.items()
                if isinstance(model, ImageModel)
            ]
        return self._get_media_models("image_models", kwargs)

    def get_video(self, **kwargs) -> list[str]:
        """List video model ids.

        Without a provider, returns ``[]``. Otherwise returns the provider's
        ``video_models`` (``[]`` if absent). ``kwargs`` are forwarded to
        :meth:`get_all`.
        """
        if self.provider is None:
            return []
        return self._get_media_models("video_models", kwargs)