* Fix arm v7 build / improve api

* Update stubs.py

* Fix unit tests
This commit is contained in:
H Lohaus
2024-11-24 17:43:45 +01:00
committed by GitHub
parent 4744d0b77d
commit 804a80bc7c
12 changed files with 248 additions and 219 deletions

View File

@@ -22,11 +22,11 @@ conversations: dict[dict[str, BaseConversation]] = {}
class Api:
@staticmethod
def get_models() -> list[str]:
def get_models():
return models._all_models
@staticmethod
def get_provider_models(provider: str, api_key: str = None) -> list[dict]:
def get_provider_models(provider: str, api_key: str = None):
if provider in __map__:
provider: ProviderType = __map__[provider]
if issubclass(provider, ProviderModelMixin):
@@ -46,39 +46,7 @@ class Api:
return []
@staticmethod
def get_image_models() -> list[dict]:
image_models = []
index = []
for provider in __providers__:
if hasattr(provider, "image_models"):
if hasattr(provider, "get_models"):
provider.get_models()
parent = provider
if hasattr(provider, "parent"):
parent = __map__[provider.parent]
if parent.__name__ not in index:
for model in provider.image_models:
image_models.append({
"provider": parent.__name__,
"url": parent.url,
"label": parent.label if hasattr(parent, "label") else None,
"image_model": model,
"vision_model": getattr(parent, "default_vision_model", None)
})
index.append(parent.__name__)
elif hasattr(provider, "default_vision_model") and provider.__name__ not in index:
image_models.append({
"provider": provider.__name__,
"url": provider.url,
"label": provider.label if hasattr(provider, "label") else None,
"image_model": None,
"vision_model": provider.default_vision_model
})
index.append(provider.__name__)
return image_models
@staticmethod
def get_providers() -> list[str]:
def get_providers() -> dict[str, str]:
return {
provider.__name__: (provider.label if hasattr(provider, "label") else provider.__name__)
+ (" (Image Generation)" if getattr(provider, "image_models", None) else "")
@@ -90,7 +58,7 @@ class Api:
}
@staticmethod
def get_version():
def get_version() -> dict:
try:
current_version = version.utils.current_version
except VersionNotFoundError: