mirror of
https://github.com/xtekky/gpt4free.git
synced 2025-11-02 22:24:03 +08:00
feat: Update environment variables and modify model mappings
- Added `OPENROUTER_API_KEY` and `AZURE_API_KEYS` to `example.env`. - Updated `AZURE_DEFAULT_MODEL` to "model-router" in `example.env`. - Added `AZURE_ROUTES` with multiple model URLs in `example.env`. - Changed the mapping for `"phi-4-multimodal"` in `DeepInfraChat.py` to `"microsoft/Phi-4-multimodal-instruct"`. - Added `media` parameter to `GptOss.create_completion` method and raised a `ValueError` if `media` is provided. - Updated `model_aliases` in `any_model_map.py` to include new mappings for various models. - Removed several model aliases from `PollinationsAI` in `any_model_map.py`. - Added new models and updated existing models in `model_map` across various files, including `any_model_map.py` and `__init__.py`. - Refactored `AnyModelProviderMixin` to include `model_aliases` and updated the logic for handling model aliases.
This commit is contained in:
20
example.env
20
example.env
@@ -8,6 +8,20 @@ TOGETHER_API_KEY=
|
||||
DEEPINFRA_API_KEY=
|
||||
OPENAI_API_KEY=
|
||||
GROQ_API_KEY=
|
||||
AZURE_API_KEY=
|
||||
AZURE_API_ENDPOINT=
|
||||
AZURE_DEFAULT_MODEL=
|
||||
OPENROUTER_API_KEY=
|
||||
AZURE_API_KEYS='{
|
||||
"default": "",
|
||||
"flux-1.1-pro": "",
|
||||
"flux.1-kontext-pro": ""
|
||||
}'
|
||||
AZURE_DEFAULT_MODEL="model-router"
|
||||
AZURE_ROUTES='{
|
||||
"model-router": "https://HOST.cognitiveservices.azure.com/openai/deployments/model-router/chat/completions?api-version=2025-01-01-preview",
|
||||
"deepseek-r1": "https://HOST.services.ai.azure.com/models/chat/completions?api-version=2024-05-01-preview",
|
||||
"gpt-4.1": "https://HOST.cognitiveservices.azure.com/openai/deployments/gpt-4.1/chat/completions?api-version=2025-01-01-preview",
|
||||
"gpt-4o-mini-audio-preview": "https://HOST.cognitiveservices.azure.com/openai/deployments/gpt-4o-mini-audio-preview/chat/completions?api-version=2025-01-01-preview",
|
||||
"o4-mini": "https://HOST.cognitiveservices.azure.com/openai/deployments/o4-mini/chat/completions?api-version=2025-01-01-preview",
|
||||
"grok-3": "https://HOST.services.ai.azure.com/models/chat/completions?api-version=2024-05-01-preview",
|
||||
"flux-1.1-pro": "https://HOST.cognitiveservices.azure.com/openai/deployments/FLUX-1.1-pro/images/generations?api-version=2025-04-01-preview",
|
||||
"flux.1-kontext-pro": "https://HOST.services.ai.azure.com/openai/deployments/FLUX.1-Kontext-pro/images/edits?api-version=2025-04-01-preview"
|
||||
}'
|
||||
@@ -2,10 +2,7 @@ from __future__ import annotations
|
||||
|
||||
import requests
|
||||
from .template import OpenaiTemplate
|
||||
from ..errors import ModelNotFoundError
|
||||
from ..config import DEFAULT_MODEL
|
||||
from .. import debug
|
||||
|
||||
|
||||
class DeepInfraChat(OpenaiTemplate):
|
||||
parent = "DeepInfra"
|
||||
@@ -86,7 +83,7 @@ class DeepInfraChat(OpenaiTemplate):
|
||||
|
||||
# microsoft
|
||||
"phi-4": "microsoft/phi-4",
|
||||
"phi-4-multimodal": default_vision_model,
|
||||
"phi-4-multimodal": "microsoft/Phi-4-multimodal-instruct",
|
||||
"phi-4-reasoning-plus": "microsoft/phi-4-reasoning-plus",
|
||||
"wizardlm-2-7b": "microsoft/WizardLM-2-7B",
|
||||
"wizardlm-2-8x22b": "microsoft/WizardLM-2-8x22B",
|
||||
@@ -101,28 +98,3 @@ class DeepInfraChat(OpenaiTemplate):
|
||||
"qwen-3-235b": "Qwen/Qwen3-235B-A22B",
|
||||
"qwq-32b": "Qwen/QwQ-32B",
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def get_model(cls, model: str, **kwargs) -> str:
|
||||
"""Get the internal model name from the user-provided model name."""
|
||||
# kwargs can contain api_key, api_base, etc. but we don't need them for model selection
|
||||
if not model:
|
||||
return cls.default_model
|
||||
|
||||
# Check if the model exists directly in our models list
|
||||
if model in cls.models:
|
||||
return model
|
||||
|
||||
# Check if there's an alias for this model
|
||||
if model in cls.model_aliases:
|
||||
alias = cls.model_aliases[model]
|
||||
# If the alias is a list, randomly select one of the options
|
||||
if isinstance(alias, list):
|
||||
import random
|
||||
selected_model = random.choice(alias)
|
||||
debug.log(f"DeepInfraChat: Selected model '{selected_model}' from alias '{model}'")
|
||||
return selected_model
|
||||
debug.log(f"DeepInfraChat: Using model '{alias}' for alias '{model}'")
|
||||
return alias
|
||||
|
||||
raise ModelNotFoundError(f"Model {model} not found")
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
|
||||
from ..typing import AsyncResult, Messages
|
||||
from ..typing import AsyncResult, Messages, MediaListType
|
||||
from ..providers.response import JsonConversation, Reasoning, TitleGeneration
|
||||
from ..requests import StreamSession, raise_for_status
|
||||
from ..config import DEFAULT_MODEL
|
||||
@@ -26,11 +26,14 @@ class GptOss(AsyncGeneratorProvider, ProviderModelMixin):
|
||||
cls,
|
||||
model: str,
|
||||
messages: Messages,
|
||||
media: MediaListType = None,
|
||||
conversation: JsonConversation = None,
|
||||
reasoning_effort: str = "high",
|
||||
proxy: str = None,
|
||||
**kwargs
|
||||
) -> AsyncResult:
|
||||
if media:
|
||||
raise ValueError("Media is not supported by gpt-oss")
|
||||
model = cls.get_model(model)
|
||||
user_message = get_last_user_message(messages)
|
||||
cookies = {}
|
||||
|
||||
@@ -88,11 +88,7 @@ class PollinationsAI(AsyncGeneratorProvider, ProviderModelMixin):
|
||||
vision_models = [default_vision_model]
|
||||
_models_loaded = False
|
||||
model_aliases = {
|
||||
"gpt-4": "openai",
|
||||
"gpt-4o": "openai",
|
||||
"gpt-4.1-mini": "openai",
|
||||
"gpt-4o-mini": "openai",
|
||||
"gpt-4.1-nano": "openai-fast",
|
||||
"gpt-4.1-nano": "openai",
|
||||
"gpt-4.1": "openai-large",
|
||||
"o4-mini": "openai-reasoning",
|
||||
"qwen-2.5-coder-32b": "qwen-coder",
|
||||
@@ -106,7 +102,6 @@ class PollinationsAI(AsyncGeneratorProvider, ProviderModelMixin):
|
||||
"grok-3-mini": "grok",
|
||||
"grok-3-mini-high": "grok",
|
||||
"gpt-4o-mini-audio": "openai-audio",
|
||||
"gpt-4o-audio": "openai-audio",
|
||||
"sdxl-turbo": "turbo",
|
||||
"gpt-image": "gptimage",
|
||||
"flux-dev": "flux",
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from ..providers.types import BaseProvider, ProviderType
|
||||
from ..providers.retry_provider import RetryProvider, IterListProvider
|
||||
from ..providers.retry_provider import RetryProvider, IterListProvider, RotatedProvider
|
||||
from ..providers.base_provider import AsyncProvider, AsyncGeneratorProvider
|
||||
from ..providers.create_images import CreateImagesProvider
|
||||
from .. import debug
|
||||
|
||||
@@ -79,7 +79,7 @@ class Azure(OpenaiTemplate):
|
||||
raise ModelNotFoundError(f"No API endpoint found for model: {model}")
|
||||
if not api_endpoint:
|
||||
api_endpoint = os.environ.get("AZURE_API_ENDPOINT")
|
||||
if not api_key:
|
||||
if cls.api_keys:
|
||||
api_key = cls.api_keys.get(model, cls.api_keys.get("default"))
|
||||
if not api_key:
|
||||
raise ValueError(f"API key is required for Azure provider. Ask for API key in the {cls.login_url} Discord server.")
|
||||
|
||||
@@ -3,8 +3,6 @@ from __future__ import annotations
|
||||
|
||||
from ..template import OpenaiTemplate
|
||||
from ...config import DEFAULT_MODEL
|
||||
from ...errors import ModelNotFoundError
|
||||
from ... import debug
|
||||
|
||||
class Together(OpenaiTemplate):
|
||||
label = "Together"
|
||||
@@ -142,27 +140,3 @@ class Together(OpenaiTemplate):
|
||||
"flux-kontext-pro": "black-forest-labs/FLUX.1-kontext-pro",
|
||||
"flux-kontext-dev": "black-forest-labs/FLUX.1-kontext-dev",
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def get_model(cls, model: str, api_key: str = None, api_base: str = None) -> str:
|
||||
"""Get the internal model name from the user-provided model name."""
|
||||
if not model:
|
||||
return cls.default_model
|
||||
|
||||
# Check if the model exists directly in our models list
|
||||
if model in cls.models:
|
||||
return model
|
||||
|
||||
# Check if there's an alias for this model
|
||||
if model in cls.model_aliases:
|
||||
alias = cls.model_aliases[model]
|
||||
# If the alias is a list, randomly select one of the options
|
||||
if isinstance(alias, list):
|
||||
import random # Add this import at the top of the file
|
||||
selected_model = random.choice(alias)
|
||||
debug.log(f"Together: Selected model '{selected_model}' from alias '{model}'")
|
||||
return selected_model
|
||||
debug.log(f"Together: Using model '{alias}' for alias '{model}'")
|
||||
return alias
|
||||
|
||||
raise ModelNotFoundError(f"Together: Model {model} not found")
|
||||
@@ -42,7 +42,7 @@ class OpenaiTemplate(AsyncGeneratorProvider, ProviderModelMixin, RaiseErrorMixin
|
||||
raise_for_status(response)
|
||||
data = response.json()
|
||||
data = data.get("data") if isinstance(data, dict) else data
|
||||
cls.image_models = [model.get("id", model.get("name")) for model in data if model.get("image")]
|
||||
cls.image_models = [model.get("id", model.get("name")) for model in data if model.get("image") or model.get("type") == "image"]
|
||||
cls.vision_models = cls.vision_models.copy()
|
||||
cls.vision_models += [model.get("id", model.get("name")) for model in data if model.get("vision")]
|
||||
cls.models = [model.get("id", model.get("name")) for model in data]
|
||||
|
||||
@@ -363,6 +363,7 @@ class Api:
|
||||
"owned_by": getattr(provider, "label", provider.__name__),
|
||||
"image": model in getattr(provider, "image_models", []),
|
||||
"vision": model in getattr(provider, "vision_models", []),
|
||||
"type": "image" if model in getattr(provider, "image_models", []) else "text",
|
||||
} for model in models]
|
||||
}
|
||||
|
||||
|
||||
@@ -127,6 +127,7 @@ class Backend_Api(Api):
|
||||
except MissingAuthError as e:
|
||||
return jsonify({"error": {"message": f"{type(e).__name__}: {e}"}}), 401
|
||||
except Exception as e:
|
||||
logger.exception(e)
|
||||
return jsonify({"error": {"message": f"{type(e).__name__}: {e}"}}), 500
|
||||
return jsonify(response)
|
||||
|
||||
|
||||
File diff suppressed because one or more lines are too long
@@ -6,7 +6,7 @@ import json
|
||||
from ..typing import AsyncResult, Messages, MediaListType, Union
|
||||
from ..errors import ModelNotFoundError
|
||||
from ..image import is_data_an_audio
|
||||
from ..providers.retry_provider import IterListProvider
|
||||
from ..providers.retry_provider import RotatedProvider
|
||||
from ..Provider.needs_auth import OpenaiChat, CopilotAccount
|
||||
from ..Provider.hf_space import HuggingSpace
|
||||
from ..Provider import Copilot, Cloudflare, Gemini, GeminiPro, Grok, DeepSeekAPI, PerplexityLabs, LambdaChat, PollinationsAI, PuterJS
|
||||
@@ -18,7 +18,7 @@ from .base_provider import AsyncGeneratorProvider, ProviderModelMixin
|
||||
from .. import Provider
|
||||
from .. import models
|
||||
from .. import debug
|
||||
from .any_model_map import audio_models, image_models, vision_models, video_models, model_map, models_count, parents
|
||||
from .any_model_map import audio_models, image_models, vision_models, video_models, model_map, models_count, parents, model_aliases
|
||||
|
||||
PROVIERS_LIST_1 = [
|
||||
CopilotAccount, OpenaiChat, Cloudflare, PerplexityLabs, Gemini, Grok, DeepSeekAPI, Blackbox, OpenAIFM,
|
||||
@@ -73,6 +73,7 @@ class AnyModelProviderMixin(ProviderModelMixin):
|
||||
models_count = models_count
|
||||
models = list(model_map.keys())
|
||||
model_map: dict[str, dict[str, str]] = model_map
|
||||
model_aliases: dict[str, str] = model_aliases
|
||||
|
||||
@classmethod
|
||||
def extend_ignored(cls, ignored: list[str]) -> list[str]:
|
||||
@@ -102,7 +103,7 @@ class AnyModelProviderMixin(ProviderModelMixin):
|
||||
cls.create_model_map()
|
||||
file = os.path.join(os.path.dirname(__file__), "any_model_map.py")
|
||||
with open(file, "w", encoding="utf-8") as f:
|
||||
for key in ["audio_models", "image_models", "vision_models", "video_models", "model_map", "models_count", "parents"]:
|
||||
for key in ["audio_models", "image_models", "vision_models", "video_models", "model_map", "models_count", "parents", "model_aliases"]:
|
||||
value = getattr(cls, key)
|
||||
f.write(f"{key} = {json.dumps(value, indent=2) if isinstance(value, dict) else repr(value)}\n")
|
||||
|
||||
@@ -118,11 +119,14 @@ class AnyModelProviderMixin(ProviderModelMixin):
|
||||
"default": {provider.__name__: "" for provider in models.default.best_provider.providers},
|
||||
}
|
||||
cls.model_map.update({
|
||||
model.name: {
|
||||
name: {
|
||||
provider.__name__: model.get_long_name() for provider in providers
|
||||
if provider.working
|
||||
} for _, (model, providers) in models.__models__.items()
|
||||
} for name, (model, providers) in models.__models__.items()
|
||||
})
|
||||
for name, (model, providers) in models.__models__.items():
|
||||
if isinstance(model, models.ImageModel):
|
||||
cls.image_models.append(name)
|
||||
|
||||
# Process special providers
|
||||
for provider in PROVIERS_LIST_2:
|
||||
@@ -234,6 +238,11 @@ class AnyModelProviderMixin(ProviderModelMixin):
|
||||
elif provider.__name__ not in cls.parents[provider.get_parent()]:
|
||||
cls.parents[provider.get_parent()].append(provider.__name__)
|
||||
|
||||
for model, providers in cls.model_map.items():
|
||||
for provider, alias in providers.items():
|
||||
if alias != model and isinstance(alias, str) and alias not in cls.model_map:
|
||||
cls.model_aliases[alias] = model
|
||||
|
||||
@classmethod
|
||||
def get_grouped_models(cls, ignored: list[str] = []) -> dict[str, list[str]]:
|
||||
unsorted_models = cls.get_models(ignored=ignored)
|
||||
@@ -299,7 +308,7 @@ class AnyModelProviderMixin(ProviderModelMixin):
|
||||
groups["image"].append(model)
|
||||
added = True
|
||||
# Check for OpenAI models
|
||||
elif model.startswith(("gpt-", "chatgpt-", "o1", "o1-", "o3-", "o4-")) or model in ("auto", "searchgpt"):
|
||||
elif model.startswith(("gpt-", "chatgpt-", "o1", "o1", "o3", "o4")) or model in ("auto", "searchgpt"):
|
||||
groups["openai"].append(model)
|
||||
added = True
|
||||
# Check for video models
|
||||
@@ -371,9 +380,13 @@ class AnyProvider(AsyncGeneratorProvider, AnyModelProviderMixin):
|
||||
providers.append(provider)
|
||||
model = submodel
|
||||
else:
|
||||
if model not in cls.model_map:
|
||||
if model in cls.model_aliases:
|
||||
model = cls.model_aliases[model]
|
||||
if model in cls.model_map:
|
||||
for provider, alias in cls.model_map[model].items():
|
||||
provider = Provider.__map__[provider]
|
||||
if model not in provider.model_aliases:
|
||||
provider.model_aliases[model] = alias
|
||||
providers.append(provider)
|
||||
if not providers:
|
||||
@@ -390,7 +403,7 @@ class AnyProvider(AsyncGeneratorProvider, AnyModelProviderMixin):
|
||||
|
||||
debug.log(f"AnyProvider: Using providers: {[provider.__name__ for provider in providers]} for model '{model}'")
|
||||
|
||||
async for chunk in IterListProvider(providers).create_async_generator(
|
||||
async for chunk in RotatedProvider(providers).create_async_generator(
|
||||
model,
|
||||
messages,
|
||||
stream=stream,
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
|
||||
import random
|
||||
from asyncio import AbstractEventLoop
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from abc import abstractmethod
|
||||
@@ -21,6 +21,7 @@ from .response import BaseConversation, AuthResult
|
||||
from .helper import concat_chunks
|
||||
from ..cookies import get_cookies_dir
|
||||
from ..errors import ModelNotFoundError, ResponseError, MissingAuthError, NoValidHarFileError, PaymentRequiredError, CloudflareError
|
||||
from .. import debug
|
||||
|
||||
SAFE_PARAMETERS = [
|
||||
"model", "messages", "stream", "timeout",
|
||||
@@ -368,7 +369,13 @@ class ProviderModelMixin:
|
||||
if not model and cls.default_model is not None:
|
||||
model = cls.default_model
|
||||
if model in cls.model_aliases:
|
||||
model = cls.model_aliases[model]
|
||||
alias = cls.model_aliases[model]
|
||||
if isinstance(alias, list):
|
||||
selected_model = random.choice(alias)
|
||||
debug.log(f"{cls.__name__}: Selected model '{selected_model}' from alias '{model}'")
|
||||
return selected_model
|
||||
debug.log(f"{cls.__name__}: Using model '{alias}' for alias '{model}'")
|
||||
return alias
|
||||
if model not in cls.model_aliases.values():
|
||||
if model not in cls.get_models(**kwargs) and cls.models:
|
||||
raise ModelNotFoundError(f"Model not found: {model} in: {cls.__name__} Valid models: {cls.models}")
|
||||
|
||||
@@ -2,13 +2,176 @@ from __future__ import annotations
|
||||
|
||||
import random
|
||||
|
||||
from ..typing import Type, List, CreateResult, Messages, AsyncResult
|
||||
from ..typing import Dict, Type, List, CreateResult, Messages, AsyncResult
|
||||
from .types import BaseProvider, BaseRetryProvider, ProviderType
|
||||
from .response import ProviderInfo, JsonConversation, is_content
|
||||
from .. import debug
|
||||
from ..tools.run_tools import AuthManager
|
||||
from ..errors import RetryProviderError, RetryNoProviderError, MissingAuthError, NoValidHarFileError
|
||||
|
||||
class RotatedProvider(BaseRetryProvider):
|
||||
"""
|
||||
A provider that rotates through a list of providers, attempting one provider per
|
||||
request and advancing to the next one upon failure. This distributes load and
|
||||
retries across multiple providers in a round-robin fashion.
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
providers: List[Type[BaseProvider]],
|
||||
shuffle: bool = True
|
||||
) -> None:
|
||||
"""
|
||||
Initialize the RotatedProvider.
|
||||
Args:
|
||||
providers (List[Type[BaseProvider]]): A non-empty list of providers to rotate through.
|
||||
shuffle (bool): If True, shuffles the provider list once at initialization
|
||||
to randomize the rotation order.
|
||||
"""
|
||||
if not isinstance(providers, list) or len(providers) == 0:
|
||||
raise ValueError('RotatedProvider requires a non-empty list of providers.')
|
||||
|
||||
self.providers = providers
|
||||
if shuffle:
|
||||
random.shuffle(self.providers)
|
||||
|
||||
self.current_index = 0
|
||||
self.last_provider: Type[BaseProvider] = None
|
||||
|
||||
def _get_current_provider(self) -> Type[BaseProvider]:
|
||||
"""Gets the provider at the current index."""
|
||||
return self.providers[self.current_index]
|
||||
|
||||
def _rotate_provider(self) -> None:
|
||||
"""Rotates to the next provider in the list."""
|
||||
self.current_index = (self.current_index + 1) % len(self.providers)
|
||||
#new_provider_name = self.providers[self.current_index].__name__
|
||||
#debug.log(f"Rotated to next provider: {new_provider_name}")
|
||||
|
||||
def create_completion(
|
||||
self,
|
||||
model: str,
|
||||
messages: Messages,
|
||||
ignored: list[str] = [], # 'ignored' is less relevant now but kept for compatibility
|
||||
api_key: str = None,
|
||||
**kwargs,
|
||||
) -> CreateResult:
|
||||
"""
|
||||
Create a completion using the current provider and rotating on failure.
|
||||
|
||||
It will try each provider in the list once per call, rotating after each
|
||||
failed attempt, until one succeeds or all have failed.
|
||||
"""
|
||||
exceptions: Dict[str, Exception] = {}
|
||||
|
||||
# Loop over the number of providers, giving each one a chance
|
||||
for _ in range(len(self.providers)):
|
||||
provider = self._get_current_provider()
|
||||
self.last_provider = provider
|
||||
self._rotate_provider()
|
||||
|
||||
# Skip if provider is in the ignored list
|
||||
if provider.get_parent() in ignored:
|
||||
continue
|
||||
|
||||
alias = model or getattr(provider, "default_model", None)
|
||||
if hasattr(provider, "model_aliases"):
|
||||
alias = provider.model_aliases.get(model, model)
|
||||
if isinstance(alias, list):
|
||||
alias = random.choice(alias)
|
||||
|
||||
debug.log(f"Attempting provider: {provider.__name__} with model: {alias}")
|
||||
yield ProviderInfo(**provider.get_dict(), model=alias, alias=model)
|
||||
|
||||
extra_body = kwargs.copy()
|
||||
current_api_key = api_key.get(provider.get_parent()) if isinstance(api_key, dict) else api_key
|
||||
if not current_api_key:
|
||||
current_api_key = AuthManager.load_api_key(provider)
|
||||
if current_api_key:
|
||||
extra_body["api_key"] = current_api_key
|
||||
|
||||
try:
|
||||
# Attempt to get a response from the current provider
|
||||
response = provider.create_function(alias, messages, **extra_body)
|
||||
started = False
|
||||
for chunk in response:
|
||||
if chunk:
|
||||
yield chunk
|
||||
if is_content(chunk):
|
||||
started = True
|
||||
if started:
|
||||
# Success, so we return and do not rotate
|
||||
return
|
||||
except Exception as e:
|
||||
exceptions[provider.__name__] = e
|
||||
debug.error(f"{provider.__name__} failed: {e}")
|
||||
|
||||
# If the loop completes, all providers have failed
|
||||
raise_exceptions(exceptions)
|
||||
|
||||
async def create_async_generator(
|
||||
self,
|
||||
model: str,
|
||||
messages: Messages,
|
||||
ignored: list[str] = [],
|
||||
api_key: str = None,
|
||||
conversation: JsonConversation = None,
|
||||
**kwargs
|
||||
) -> AsyncResult:
|
||||
"""
|
||||
Asynchronously create a completion, rotating through providers on failure.
|
||||
"""
|
||||
exceptions: Dict[str, Exception] = {}
|
||||
|
||||
for _ in range(len(self.providers)):
|
||||
provider = self._get_current_provider()
|
||||
self._rotate_provider()
|
||||
self.last_provider = provider
|
||||
|
||||
if provider.get_parent() in ignored:
|
||||
continue
|
||||
|
||||
alias = model or getattr(provider, "default_model", None)
|
||||
if hasattr(provider, "model_aliases"):
|
||||
alias = provider.model_aliases.get(model, model)
|
||||
if isinstance(alias, list):
|
||||
alias = random.choice(alias)
|
||||
|
||||
debug.log(f"Attempting provider: {provider.__name__} with model: {alias}")
|
||||
yield ProviderInfo(**provider.get_dict(), model=alias)
|
||||
|
||||
extra_body = kwargs.copy()
|
||||
current_api_key = api_key.get(provider.get_parent()) if isinstance(api_key, dict) else api_key
|
||||
if not current_api_key:
|
||||
current_api_key = AuthManager.load_api_key(provider)
|
||||
if current_api_key:
|
||||
extra_body["api_key"] = current_api_key
|
||||
if conversation and hasattr(conversation, provider.__name__):
|
||||
extra_body["conversation"] = JsonConversation(**getattr(conversation, provider.__name__))
|
||||
|
||||
try:
|
||||
response = provider.async_create_function(alias, messages, **extra_body)
|
||||
started = False
|
||||
async for chunk in response:
|
||||
if isinstance(chunk, JsonConversation):
|
||||
if conversation is None: conversation = JsonConversation()
|
||||
setattr(conversation, provider.__name__, chunk.get_dict())
|
||||
yield conversation
|
||||
elif chunk:
|
||||
yield chunk
|
||||
if is_content(chunk):
|
||||
started = True
|
||||
if started:
|
||||
return # Success
|
||||
except Exception as e:
|
||||
exceptions[provider.__name__] = e
|
||||
debug.error(f"{provider.__name__} failed: {e}")
|
||||
|
||||
raise_exceptions(exceptions)
|
||||
|
||||
# Maintain API compatibility
|
||||
create_function = create_completion
|
||||
async_create_function = create_async_generator
|
||||
|
||||
class IterListProvider(BaseRetryProvider):
|
||||
def __init__(
|
||||
self,
|
||||
|
||||
Reference in New Issue
Block a user