feat: Update environment variables and modify model mappings

- Added `OPENROUTER_API_KEY` and `AZURE_API_KEYS` to `example.env`.
- Updated `AZURE_DEFAULT_MODEL` to "model-router" in `example.env`.
- Added `AZURE_ROUTES` with multiple model URLs in `example.env`.
- Changed the mapping for `"phi-4-multimodal"` in `DeepInfraChat.py` to `"microsoft/Phi-4-multimodal-instruct"`.
- Added `media` parameter to `GptOss.create_completion` method and raised a `ValueError` if `media` is provided.
- Updated `model_aliases` in `any_model_map.py` to include new mappings for various models.
- Removed several model aliases from `PollinationsAI` in `any_model_map.py`.
- Added new models and updated existing models in `model_map` across various files, including `any_model_map.py` and `__init__.py`.
- Refactored `AnyModelProviderMixin` to include `model_aliases` and updated the logic for handling model aliases.
This commit is contained in:
hlohaus
2025-08-07 01:21:22 +02:00
parent 2c3a437c75
commit 9563f8df3a
14 changed files with 1317 additions and 781 deletions

View File

@@ -8,6 +8,20 @@ TOGETHER_API_KEY=
DEEPINFRA_API_KEY= DEEPINFRA_API_KEY=
OPENAI_API_KEY= OPENAI_API_KEY=
GROQ_API_KEY= GROQ_API_KEY=
AZURE_API_KEY= OPENROUTER_API_KEY=
AZURE_API_ENDPOINT= AZURE_API_KEYS='{
AZURE_DEFAULT_MODEL= "default": "",
"flux-1.1-pro": "",
"flux.1-kontext-pro": ""
}'
AZURE_DEFAULT_MODEL="model-router"
AZURE_ROUTES='{
"model-router": "https://HOST.cognitiveservices.azure.com/openai/deployments/model-router/chat/completions?api-version=2025-01-01-preview",
"deepseek-r1": "https://HOST.services.ai.azure.com/models/chat/completions?api-version=2024-05-01-preview",
"gpt-4.1": "https://HOST.cognitiveservices.azure.com/openai/deployments/gpt-4.1/chat/completions?api-version=2025-01-01-preview",
"gpt-4o-mini-audio-preview": "https://HOST.cognitiveservices.azure.com/openai/deployments/gpt-4o-mini-audio-preview/chat/completions?api-version=2025-01-01-preview",
"o4-mini": "https://HOST.cognitiveservices.azure.com/openai/deployments/o4-mini/chat/completions?api-version=2025-01-01-preview",
"grok-3": "https://HOST.services.ai.azure.com/models/chat/completions?api-version=2024-05-01-preview",
"flux-1.1-pro": "https://HOST.cognitiveservices.azure.com/openai/deployments/FLUX-1.1-pro/images/generations?api-version=2025-04-01-preview",
"flux.1-kontext-pro": "https://HOST.services.ai.azure.com/openai/deployments/FLUX.1-Kontext-pro/images/edits?api-version=2025-04-01-preview"
}'

View File

@@ -2,10 +2,7 @@ from __future__ import annotations
import requests import requests
from .template import OpenaiTemplate from .template import OpenaiTemplate
from ..errors import ModelNotFoundError
from ..config import DEFAULT_MODEL from ..config import DEFAULT_MODEL
from .. import debug
class DeepInfraChat(OpenaiTemplate): class DeepInfraChat(OpenaiTemplate):
parent = "DeepInfra" parent = "DeepInfra"
@@ -86,7 +83,7 @@ class DeepInfraChat(OpenaiTemplate):
# microsoft # microsoft
"phi-4": "microsoft/phi-4", "phi-4": "microsoft/phi-4",
"phi-4-multimodal": default_vision_model, "phi-4-multimodal": "microsoft/Phi-4-multimodal-instruct",
"phi-4-reasoning-plus": "microsoft/phi-4-reasoning-plus", "phi-4-reasoning-plus": "microsoft/phi-4-reasoning-plus",
"wizardlm-2-7b": "microsoft/WizardLM-2-7B", "wizardlm-2-7b": "microsoft/WizardLM-2-7B",
"wizardlm-2-8x22b": "microsoft/WizardLM-2-8x22B", "wizardlm-2-8x22b": "microsoft/WizardLM-2-8x22B",
@@ -100,29 +97,4 @@ class DeepInfraChat(OpenaiTemplate):
"qwen-3-32b": "Qwen/Qwen3-32B", "qwen-3-32b": "Qwen/Qwen3-32B",
"qwen-3-235b": "Qwen/Qwen3-235B-A22B", "qwen-3-235b": "Qwen/Qwen3-235B-A22B",
"qwq-32b": "Qwen/QwQ-32B", "qwq-32b": "Qwen/QwQ-32B",
} }
@classmethod
def get_model(cls, model: str, **kwargs) -> str:
"""Get the internal model name from the user-provided model name."""
# kwargs can contain api_key, api_base, etc. but we don't need them for model selection
if not model:
return cls.default_model
# Check if the model exists directly in our models list
if model in cls.models:
return model
# Check if there's an alias for this model
if model in cls.model_aliases:
alias = cls.model_aliases[model]
# If the alias is a list, randomly select one of the options
if isinstance(alias, list):
import random
selected_model = random.choice(alias)
debug.log(f"DeepInfraChat: Selected model '{selected_model}' from alias '{model}'")
return selected_model
debug.log(f"DeepInfraChat: Using model '{alias}' for alias '{model}'")
return alias
raise ModelNotFoundError(f"Model {model} not found")

View File

@@ -1,7 +1,7 @@
from __future__ import annotations from __future__ import annotations
from ..typing import AsyncResult, Messages from ..typing import AsyncResult, Messages, MediaListType
from ..providers.response import JsonConversation, Reasoning, TitleGeneration from ..providers.response import JsonConversation, Reasoning, TitleGeneration
from ..requests import StreamSession, raise_for_status from ..requests import StreamSession, raise_for_status
from ..config import DEFAULT_MODEL from ..config import DEFAULT_MODEL
@@ -26,11 +26,14 @@ class GptOss(AsyncGeneratorProvider, ProviderModelMixin):
cls, cls,
model: str, model: str,
messages: Messages, messages: Messages,
media: MediaListType = None,
conversation: JsonConversation = None, conversation: JsonConversation = None,
reasoning_effort: str = "high", reasoning_effort: str = "high",
proxy: str = None, proxy: str = None,
**kwargs **kwargs
) -> AsyncResult: ) -> AsyncResult:
if media:
raise ValueError("Media is not supported by gpt-oss")
model = cls.get_model(model) model = cls.get_model(model)
user_message = get_last_user_message(messages) user_message = get_last_user_message(messages)
cookies = {} cookies = {}

View File

@@ -88,11 +88,7 @@ class PollinationsAI(AsyncGeneratorProvider, ProviderModelMixin):
vision_models = [default_vision_model] vision_models = [default_vision_model]
_models_loaded = False _models_loaded = False
model_aliases = { model_aliases = {
"gpt-4": "openai", "gpt-4.1-nano": "openai",
"gpt-4o": "openai",
"gpt-4.1-mini": "openai",
"gpt-4o-mini": "openai",
"gpt-4.1-nano": "openai-fast",
"gpt-4.1": "openai-large", "gpt-4.1": "openai-large",
"o4-mini": "openai-reasoning", "o4-mini": "openai-reasoning",
"qwen-2.5-coder-32b": "qwen-coder", "qwen-2.5-coder-32b": "qwen-coder",
@@ -106,7 +102,6 @@ class PollinationsAI(AsyncGeneratorProvider, ProviderModelMixin):
"grok-3-mini": "grok", "grok-3-mini": "grok",
"grok-3-mini-high": "grok", "grok-3-mini-high": "grok",
"gpt-4o-mini-audio": "openai-audio", "gpt-4o-mini-audio": "openai-audio",
"gpt-4o-audio": "openai-audio",
"sdxl-turbo": "turbo", "sdxl-turbo": "turbo",
"gpt-image": "gptimage", "gpt-image": "gptimage",
"flux-dev": "flux", "flux-dev": "flux",

View File

@@ -1,7 +1,7 @@
from __future__ import annotations from __future__ import annotations
from ..providers.types import BaseProvider, ProviderType from ..providers.types import BaseProvider, ProviderType
from ..providers.retry_provider import RetryProvider, IterListProvider from ..providers.retry_provider import RetryProvider, IterListProvider, RotatedProvider
from ..providers.base_provider import AsyncProvider, AsyncGeneratorProvider from ..providers.base_provider import AsyncProvider, AsyncGeneratorProvider
from ..providers.create_images import CreateImagesProvider from ..providers.create_images import CreateImagesProvider
from .. import debug from .. import debug

View File

@@ -79,7 +79,7 @@ class Azure(OpenaiTemplate):
raise ModelNotFoundError(f"No API endpoint found for model: {model}") raise ModelNotFoundError(f"No API endpoint found for model: {model}")
if not api_endpoint: if not api_endpoint:
api_endpoint = os.environ.get("AZURE_API_ENDPOINT") api_endpoint = os.environ.get("AZURE_API_ENDPOINT")
if not api_key: if cls.api_keys:
api_key = cls.api_keys.get(model, cls.api_keys.get("default")) api_key = cls.api_keys.get(model, cls.api_keys.get("default"))
if not api_key: if not api_key:
raise ValueError(f"API key is required for Azure provider. Ask for API key in the {cls.login_url} Discord server.") raise ValueError(f"API key is required for Azure provider. Ask for API key in the {cls.login_url} Discord server.")

View File

@@ -3,8 +3,6 @@ from __future__ import annotations
from ..template import OpenaiTemplate from ..template import OpenaiTemplate
from ...config import DEFAULT_MODEL from ...config import DEFAULT_MODEL
from ...errors import ModelNotFoundError
from ... import debug
class Together(OpenaiTemplate): class Together(OpenaiTemplate):
label = "Together" label = "Together"
@@ -141,28 +139,4 @@ class Together(OpenaiTemplate):
"flux-dev": ["black-forest-labs/FLUX.1-dev", "black-forest-labs/FLUX.1-dev-lora"], "flux-dev": ["black-forest-labs/FLUX.1-dev", "black-forest-labs/FLUX.1-dev-lora"],
"flux-kontext-pro": "black-forest-labs/FLUX.1-kontext-pro", "flux-kontext-pro": "black-forest-labs/FLUX.1-kontext-pro",
"flux-kontext-dev": "black-forest-labs/FLUX.1-kontext-dev", "flux-kontext-dev": "black-forest-labs/FLUX.1-kontext-dev",
} }
@classmethod
def get_model(cls, model: str, api_key: str = None, api_base: str = None) -> str:
"""Get the internal model name from the user-provided model name."""
if not model:
return cls.default_model
# Check if the model exists directly in our models list
if model in cls.models:
return model
# Check if there's an alias for this model
if model in cls.model_aliases:
alias = cls.model_aliases[model]
# If the alias is a list, randomly select one of the options
if isinstance(alias, list):
import random # Add this import at the top of the file
selected_model = random.choice(alias)
debug.log(f"Together: Selected model '{selected_model}' from alias '{model}'")
return selected_model
debug.log(f"Together: Using model '{alias}' for alias '{model}'")
return alias
raise ModelNotFoundError(f"Together: Model {model} not found")

View File

@@ -42,7 +42,7 @@ class OpenaiTemplate(AsyncGeneratorProvider, ProviderModelMixin, RaiseErrorMixin
raise_for_status(response) raise_for_status(response)
data = response.json() data = response.json()
data = data.get("data") if isinstance(data, dict) else data data = data.get("data") if isinstance(data, dict) else data
cls.image_models = [model.get("id", model.get("name")) for model in data if model.get("image")] cls.image_models = [model.get("id", model.get("name")) for model in data if model.get("image") or model.get("type") == "image"]
cls.vision_models = cls.vision_models.copy() cls.vision_models = cls.vision_models.copy()
cls.vision_models += [model.get("id", model.get("name")) for model in data if model.get("vision")] cls.vision_models += [model.get("id", model.get("name")) for model in data if model.get("vision")]
cls.models = [model.get("id", model.get("name")) for model in data] cls.models = [model.get("id", model.get("name")) for model in data]

View File

@@ -363,6 +363,7 @@ class Api:
"owned_by": getattr(provider, "label", provider.__name__), "owned_by": getattr(provider, "label", provider.__name__),
"image": model in getattr(provider, "image_models", []), "image": model in getattr(provider, "image_models", []),
"vision": model in getattr(provider, "vision_models", []), "vision": model in getattr(provider, "vision_models", []),
"type": "image" if model in getattr(provider, "image_models", []) else "text",
} for model in models] } for model in models]
} }

View File

@@ -127,6 +127,7 @@ class Backend_Api(Api):
except MissingAuthError as e: except MissingAuthError as e:
return jsonify({"error": {"message": f"{type(e).__name__}: {e}"}}), 401 return jsonify({"error": {"message": f"{type(e).__name__}: {e}"}}), 401
except Exception as e: except Exception as e:
logger.exception(e)
return jsonify({"error": {"message": f"{type(e).__name__}: {e}"}}), 500 return jsonify({"error": {"message": f"{type(e).__name__}: {e}"}}), 500
return jsonify(response) return jsonify(response)

File diff suppressed because one or more lines are too long

View File

@@ -6,7 +6,7 @@ import json
from ..typing import AsyncResult, Messages, MediaListType, Union from ..typing import AsyncResult, Messages, MediaListType, Union
from ..errors import ModelNotFoundError from ..errors import ModelNotFoundError
from ..image import is_data_an_audio from ..image import is_data_an_audio
from ..providers.retry_provider import IterListProvider from ..providers.retry_provider import RotatedProvider
from ..Provider.needs_auth import OpenaiChat, CopilotAccount from ..Provider.needs_auth import OpenaiChat, CopilotAccount
from ..Provider.hf_space import HuggingSpace from ..Provider.hf_space import HuggingSpace
from ..Provider import Copilot, Cloudflare, Gemini, GeminiPro, Grok, DeepSeekAPI, PerplexityLabs, LambdaChat, PollinationsAI, PuterJS from ..Provider import Copilot, Cloudflare, Gemini, GeminiPro, Grok, DeepSeekAPI, PerplexityLabs, LambdaChat, PollinationsAI, PuterJS
@@ -18,7 +18,7 @@ from .base_provider import AsyncGeneratorProvider, ProviderModelMixin
from .. import Provider from .. import Provider
from .. import models from .. import models
from .. import debug from .. import debug
from .any_model_map import audio_models, image_models, vision_models, video_models, model_map, models_count, parents from .any_model_map import audio_models, image_models, vision_models, video_models, model_map, models_count, parents, model_aliases
PROVIERS_LIST_1 = [ PROVIERS_LIST_1 = [
CopilotAccount, OpenaiChat, Cloudflare, PerplexityLabs, Gemini, Grok, DeepSeekAPI, Blackbox, OpenAIFM, CopilotAccount, OpenaiChat, Cloudflare, PerplexityLabs, Gemini, Grok, DeepSeekAPI, Blackbox, OpenAIFM,
@@ -73,6 +73,7 @@ class AnyModelProviderMixin(ProviderModelMixin):
models_count = models_count models_count = models_count
models = list(model_map.keys()) models = list(model_map.keys())
model_map: dict[str, dict[str, str]] = model_map model_map: dict[str, dict[str, str]] = model_map
model_aliases: dict[str, str] = model_aliases
@classmethod @classmethod
def extend_ignored(cls, ignored: list[str]) -> list[str]: def extend_ignored(cls, ignored: list[str]) -> list[str]:
@@ -102,7 +103,7 @@ class AnyModelProviderMixin(ProviderModelMixin):
cls.create_model_map() cls.create_model_map()
file = os.path.join(os.path.dirname(__file__), "any_model_map.py") file = os.path.join(os.path.dirname(__file__), "any_model_map.py")
with open(file, "w", encoding="utf-8") as f: with open(file, "w", encoding="utf-8") as f:
for key in ["audio_models", "image_models", "vision_models", "video_models", "model_map", "models_count", "parents"]: for key in ["audio_models", "image_models", "vision_models", "video_models", "model_map", "models_count", "parents", "model_aliases"]:
value = getattr(cls, key) value = getattr(cls, key)
f.write(f"{key} = {json.dumps(value, indent=2) if isinstance(value, dict) else repr(value)}\n") f.write(f"{key} = {json.dumps(value, indent=2) if isinstance(value, dict) else repr(value)}\n")
@@ -118,11 +119,14 @@ class AnyModelProviderMixin(ProviderModelMixin):
"default": {provider.__name__: "" for provider in models.default.best_provider.providers}, "default": {provider.__name__: "" for provider in models.default.best_provider.providers},
} }
cls.model_map.update({ cls.model_map.update({
model.name: { name: {
provider.__name__: model.get_long_name() for provider in providers provider.__name__: model.get_long_name() for provider in providers
if provider.working if provider.working
} for _, (model, providers) in models.__models__.items() } for name, (model, providers) in models.__models__.items()
}) })
for name, (model, providers) in models.__models__.items():
if isinstance(model, models.ImageModel):
cls.image_models.append(name)
# Process special providers # Process special providers
for provider in PROVIERS_LIST_2: for provider in PROVIERS_LIST_2:
@@ -234,6 +238,11 @@ class AnyModelProviderMixin(ProviderModelMixin):
elif provider.__name__ not in cls.parents[provider.get_parent()]: elif provider.__name__ not in cls.parents[provider.get_parent()]:
cls.parents[provider.get_parent()].append(provider.__name__) cls.parents[provider.get_parent()].append(provider.__name__)
for model, providers in cls.model_map.items():
for provider, alias in providers.items():
if alias != model and isinstance(alias, str) and alias not in cls.model_map:
cls.model_aliases[alias] = model
@classmethod @classmethod
def get_grouped_models(cls, ignored: list[str] = []) -> dict[str, list[str]]: def get_grouped_models(cls, ignored: list[str] = []) -> dict[str, list[str]]:
unsorted_models = cls.get_models(ignored=ignored) unsorted_models = cls.get_models(ignored=ignored)
@@ -299,7 +308,7 @@ class AnyModelProviderMixin(ProviderModelMixin):
groups["image"].append(model) groups["image"].append(model)
added = True added = True
# Check for OpenAI models # Check for OpenAI models
elif model.startswith(("gpt-", "chatgpt-", "o1", "o1-", "o3-", "o4-")) or model in ("auto", "searchgpt"): elif model.startswith(("gpt-", "chatgpt-", "o1", "o1", "o3", "o4")) or model in ("auto", "searchgpt"):
groups["openai"].append(model) groups["openai"].append(model)
added = True added = True
# Check for video models # Check for video models
@@ -371,10 +380,14 @@ class AnyProvider(AsyncGeneratorProvider, AnyModelProviderMixin):
providers.append(provider) providers.append(provider)
model = submodel model = submodel
else: else:
if model not in cls.model_map:
if model in cls.model_aliases:
model = cls.model_aliases[model]
if model in cls.model_map: if model in cls.model_map:
for provider, alias in cls.model_map[model].items(): for provider, alias in cls.model_map[model].items():
provider = Provider.__map__[provider] provider = Provider.__map__[provider]
provider.model_aliases[model] = alias if model not in provider.model_aliases:
provider.model_aliases[model] = alias
providers.append(provider) providers.append(provider)
if not providers: if not providers:
for provider in PROVIERS_LIST_1: for provider in PROVIERS_LIST_1:
@@ -390,7 +403,7 @@ class AnyProvider(AsyncGeneratorProvider, AnyModelProviderMixin):
debug.log(f"AnyProvider: Using providers: {[provider.__name__ for provider in providers]} for model '{model}'") debug.log(f"AnyProvider: Using providers: {[provider.__name__ for provider in providers]} for model '{model}'")
async for chunk in IterListProvider(providers).create_async_generator( async for chunk in RotatedProvider(providers).create_async_generator(
model, model,
messages, messages,
stream=stream, stream=stream,

View File

@@ -1,7 +1,7 @@
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
import random
from asyncio import AbstractEventLoop from asyncio import AbstractEventLoop
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
from abc import abstractmethod from abc import abstractmethod
@@ -21,6 +21,7 @@ from .response import BaseConversation, AuthResult
from .helper import concat_chunks from .helper import concat_chunks
from ..cookies import get_cookies_dir from ..cookies import get_cookies_dir
from ..errors import ModelNotFoundError, ResponseError, MissingAuthError, NoValidHarFileError, PaymentRequiredError, CloudflareError from ..errors import ModelNotFoundError, ResponseError, MissingAuthError, NoValidHarFileError, PaymentRequiredError, CloudflareError
from .. import debug
SAFE_PARAMETERS = [ SAFE_PARAMETERS = [
"model", "messages", "stream", "timeout", "model", "messages", "stream", "timeout",
@@ -368,7 +369,13 @@ class ProviderModelMixin:
if not model and cls.default_model is not None: if not model and cls.default_model is not None:
model = cls.default_model model = cls.default_model
if model in cls.model_aliases: if model in cls.model_aliases:
model = cls.model_aliases[model] alias = cls.model_aliases[model]
if isinstance(alias, list):
selected_model = random.choice(alias)
debug.log(f"{cls.__name__}: Selected model '{selected_model}' from alias '{model}'")
return selected_model
debug.log(f"{cls.__name__}: Using model '{alias}' for alias '{model}'")
return alias
if model not in cls.model_aliases.values(): if model not in cls.model_aliases.values():
if model not in cls.get_models(**kwargs) and cls.models: if model not in cls.get_models(**kwargs) and cls.models:
raise ModelNotFoundError(f"Model not found: {model} in: {cls.__name__} Valid models: {cls.models}") raise ModelNotFoundError(f"Model not found: {model} in: {cls.__name__} Valid models: {cls.models}")

View File

@@ -2,13 +2,176 @@ from __future__ import annotations
import random import random
from ..typing import Type, List, CreateResult, Messages, AsyncResult from ..typing import Dict, Type, List, CreateResult, Messages, AsyncResult
from .types import BaseProvider, BaseRetryProvider, ProviderType from .types import BaseProvider, BaseRetryProvider, ProviderType
from .response import ProviderInfo, JsonConversation, is_content from .response import ProviderInfo, JsonConversation, is_content
from .. import debug from .. import debug
from ..tools.run_tools import AuthManager from ..tools.run_tools import AuthManager
from ..errors import RetryProviderError, RetryNoProviderError, MissingAuthError, NoValidHarFileError from ..errors import RetryProviderError, RetryNoProviderError, MissingAuthError, NoValidHarFileError
class RotatedProvider(BaseRetryProvider):
"""
A provider that rotates through a list of providers, attempting one provider per
request and advancing to the next one upon failure. This distributes load and
retries across multiple providers in a round-robin fashion.
"""
def __init__(
self,
providers: List[Type[BaseProvider]],
shuffle: bool = True
) -> None:
"""
Initialize the RotatedProvider.
Args:
providers (List[Type[BaseProvider]]): A non-empty list of providers to rotate through.
shuffle (bool): If True, shuffles the provider list once at initialization
to randomize the rotation order.
"""
if not isinstance(providers, list) or len(providers) == 0:
raise ValueError('RotatedProvider requires a non-empty list of providers.')
self.providers = providers
if shuffle:
random.shuffle(self.providers)
self.current_index = 0
self.last_provider: Type[BaseProvider] = None
def _get_current_provider(self) -> Type[BaseProvider]:
"""Gets the provider at the current index."""
return self.providers[self.current_index]
def _rotate_provider(self) -> None:
"""Rotates to the next provider in the list."""
self.current_index = (self.current_index + 1) % len(self.providers)
#new_provider_name = self.providers[self.current_index].__name__
#debug.log(f"Rotated to next provider: {new_provider_name}")
def create_completion(
self,
model: str,
messages: Messages,
ignored: list[str] = [], # 'ignored' is less relevant now but kept for compatibility
api_key: str = None,
**kwargs,
) -> CreateResult:
"""
Create a completion using the current provider and rotating on failure.
It will try each provider in the list once per call, rotating after each
failed attempt, until one succeeds or all have failed.
"""
exceptions: Dict[str, Exception] = {}
# Loop over the number of providers, giving each one a chance
for _ in range(len(self.providers)):
provider = self._get_current_provider()
self.last_provider = provider
self._rotate_provider()
# Skip if provider is in the ignored list
if provider.get_parent() in ignored:
continue
alias = model or getattr(provider, "default_model", None)
if hasattr(provider, "model_aliases"):
alias = provider.model_aliases.get(model, model)
if isinstance(alias, list):
alias = random.choice(alias)
debug.log(f"Attempting provider: {provider.__name__} with model: {alias}")
yield ProviderInfo(**provider.get_dict(), model=alias, alias=model)
extra_body = kwargs.copy()
current_api_key = api_key.get(provider.get_parent()) if isinstance(api_key, dict) else api_key
if not current_api_key:
current_api_key = AuthManager.load_api_key(provider)
if current_api_key:
extra_body["api_key"] = current_api_key
try:
# Attempt to get a response from the current provider
response = provider.create_function(alias, messages, **extra_body)
started = False
for chunk in response:
if chunk:
yield chunk
if is_content(chunk):
started = True
if started:
# Success, so we return and do not rotate
return
except Exception as e:
exceptions[provider.__name__] = e
debug.error(f"{provider.__name__} failed: {e}")
# If the loop completes, all providers have failed
raise_exceptions(exceptions)
async def create_async_generator(
self,
model: str,
messages: Messages,
ignored: list[str] = [],
api_key: str = None,
conversation: JsonConversation = None,
**kwargs
) -> AsyncResult:
"""
Asynchronously create a completion, rotating through providers on failure.
"""
exceptions: Dict[str, Exception] = {}
for _ in range(len(self.providers)):
provider = self._get_current_provider()
self._rotate_provider()
self.last_provider = provider
if provider.get_parent() in ignored:
continue
alias = model or getattr(provider, "default_model", None)
if hasattr(provider, "model_aliases"):
alias = provider.model_aliases.get(model, model)
if isinstance(alias, list):
alias = random.choice(alias)
debug.log(f"Attempting provider: {provider.__name__} with model: {alias}")
yield ProviderInfo(**provider.get_dict(), model=alias)
extra_body = kwargs.copy()
current_api_key = api_key.get(provider.get_parent()) if isinstance(api_key, dict) else api_key
if not current_api_key:
current_api_key = AuthManager.load_api_key(provider)
if current_api_key:
extra_body["api_key"] = current_api_key
if conversation and hasattr(conversation, provider.__name__):
extra_body["conversation"] = JsonConversation(**getattr(conversation, provider.__name__))
try:
response = provider.async_create_function(alias, messages, **extra_body)
started = False
async for chunk in response:
if isinstance(chunk, JsonConversation):
if conversation is None: conversation = JsonConversation()
setattr(conversation, provider.__name__, chunk.get_dict())
yield conversation
elif chunk:
yield chunk
if is_content(chunk):
started = True
if started:
return # Success
except Exception as e:
exceptions[provider.__name__] = e
debug.error(f"{provider.__name__} failed: {e}")
raise_exceptions(exceptions)
# Maintain API compatibility
create_function = create_completion
async_create_function = create_async_generator
class IterListProvider(BaseRetryProvider): class IterListProvider(BaseRetryProvider):
def __init__( def __init__(
self, self,