feat: Update environment variables and modify model mappings
- Added `OPENROUTER_API_KEY` and `AZURE_API_KEYS` to `example.env`.
- Updated `AZURE_DEFAULT_MODEL` to "model-router" in `example.env`.
- Added `AZURE_ROUTES` with multiple model URLs in `example.env`.
- Changed the mapping for `"phi-4-multimodal"` in `DeepInfraChat.py` to `"microsoft/Phi-4-multimodal-instruct"`.
- Added a `media` parameter to `GptOss.create_completion` that raises a `ValueError` if `media` is provided.
- Updated `model_aliases` in `any_model_map.py` to include new mappings for various models.
- Removed several model aliases from `PollinationsAI` in `any_model_map.py`.
- Added new models and updated existing models in `model_map` across various files, including `any_model_map.py` and `__init__.py`.
- Refactored `AnyModelProviderMixin` to include `model_aliases` and updated the logic for handling model aliases.
example.env
@@ -8,6 +8,20 @@ TOGETHER_API_KEY=
 DEEPINFRA_API_KEY=
 OPENAI_API_KEY=
 GROQ_API_KEY=
-AZURE_API_KEY=
-AZURE_API_ENDPOINT=
-AZURE_DEFAULT_MODEL=
+OPENROUTER_API_KEY=
+AZURE_API_KEYS='{
+  "default": "",
+  "flux-1.1-pro": "",
+  "flux.1-kontext-pro": ""
+}'
+AZURE_DEFAULT_MODEL="model-router"
+AZURE_ROUTES='{
+  "model-router": "https://HOST.cognitiveservices.azure.com/openai/deployments/model-router/chat/completions?api-version=2025-01-01-preview",
+  "deepseek-r1": "https://HOST.services.ai.azure.com/models/chat/completions?api-version=2024-05-01-preview",
+  "gpt-4.1": "https://HOST.cognitiveservices.azure.com/openai/deployments/gpt-4.1/chat/completions?api-version=2025-01-01-preview",
+  "gpt-4o-mini-audio-preview": "https://HOST.cognitiveservices.azure.com/openai/deployments/gpt-4o-mini-audio-preview/chat/completions?api-version=2025-01-01-preview",
+  "o4-mini": "https://HOST.cognitiveservices.azure.com/openai/deployments/o4-mini/chat/completions?api-version=2025-01-01-preview",
+  "grok-3": "https://HOST.services.ai.azure.com/models/chat/completions?api-version=2024-05-01-preview",
+  "flux-1.1-pro": "https://HOST.cognitiveservices.azure.com/openai/deployments/FLUX-1.1-pro/images/generations?api-version=2025-04-01-preview",
+  "flux.1-kontext-pro": "https://HOST.services.ai.azure.com/openai/deployments/FLUX.1-Kontext-pro/images/edits?api-version=2025-04-01-preview"
+}'
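The new variables hold JSON payloads. A minimal sketch of how they can be consumed, assuming the Azure provider json-decodes them at startup (the loading code itself is not part of this diff); the per-model key lookup with a "default" fallback mirrors the Azure hunk further down:

import json, os

# Assumption: both variables contain valid JSON objects as in example.env above.
api_keys = json.loads(os.environ.get("AZURE_API_KEYS", "{}"))
routes = json.loads(os.environ.get("AZURE_ROUTES", "{}"))

model = os.environ.get("AZURE_DEFAULT_MODEL", "model-router")
api_key = api_keys.get(model, api_keys.get("default"))  # per-model key, else default
endpoint = routes.get(model)                            # per-model deployment URL
print(model, endpoint)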
DeepInfraChat.py
@@ -2,10 +2,7 @@ from __future__ import annotations
 
 import requests
 from .template import OpenaiTemplate
-from ..errors import ModelNotFoundError
 from ..config import DEFAULT_MODEL
-from .. import debug
 
 
 class DeepInfraChat(OpenaiTemplate):
     parent = "DeepInfra"
@@ -86,7 +83,7 @@ class DeepInfraChat(OpenaiTemplate):
         # microsoft
         "phi-4": "microsoft/phi-4",
-        "phi-4-multimodal": default_vision_model,
+        "phi-4-multimodal": "microsoft/Phi-4-multimodal-instruct",
         "phi-4-reasoning-plus": "microsoft/phi-4-reasoning-plus",
         "wizardlm-2-7b": "microsoft/WizardLM-2-7B",
         "wizardlm-2-8x22b": "microsoft/WizardLM-2-8x22B",
@@ -100,29 +97,4 @@ class DeepInfraChat(OpenaiTemplate):
         "qwen-3-32b": "Qwen/Qwen3-32B",
         "qwen-3-235b": "Qwen/Qwen3-235B-A22B",
         "qwq-32b": "Qwen/QwQ-32B",
     }
-
-    @classmethod
-    def get_model(cls, model: str, **kwargs) -> str:
-        """Get the internal model name from the user-provided model name."""
-        # kwargs can contain api_key, api_base, etc. but we don't need them for model selection
-        if not model:
-            return cls.default_model
-
-        # Check if the model exists directly in our models list
-        if model in cls.models:
-            return model
-
-        # Check if there's an alias for this model
-        if model in cls.model_aliases:
-            alias = cls.model_aliases[model]
-            # If the alias is a list, randomly select one of the options
-            if isinstance(alias, list):
-                import random
-                selected_model = random.choice(alias)
-                debug.log(f"DeepInfraChat: Selected model '{selected_model}' from alias '{model}'")
-                return selected_model
-            debug.log(f"DeepInfraChat: Using model '{alias}' for alias '{model}'")
-            return alias
-
-        raise ModelNotFoundError(f"Model {model} not found")
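With this duplicated get_model override removed, DeepInfraChat falls back to the shared resolution in ProviderModelMixin.get_model, which this commit extends (see the base_provider.py hunk below). Expected behaviour, as a hedged sketch:

from g4f.Provider import DeepInfraChat

# The class-level model_aliases dict now drives resolution, so the remapped
# entry above should resolve through the inherited classmethod:
print(DeepInfraChat.get_model("phi-4-multimodal"))
# -> "microsoft/Phi-4-multimodal-instruct"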
GptOss.py
@@ -1,7 +1,7 @@
 from __future__ import annotations
 
 
-from ..typing import AsyncResult, Messages
+from ..typing import AsyncResult, Messages, MediaListType
 from ..providers.response import JsonConversation, Reasoning, TitleGeneration
 from ..requests import StreamSession, raise_for_status
 from ..config import DEFAULT_MODEL
@@ -26,11 +26,14 @@ class GptOss(AsyncGeneratorProvider, ProviderModelMixin):
         cls,
         model: str,
         messages: Messages,
+        media: MediaListType = None,
         conversation: JsonConversation = None,
         reasoning_effort: str = "high",
         proxy: str = None,
         **kwargs
     ) -> AsyncResult:
+        if media:
+            raise ValueError("Media is not supported by gpt-oss")
         model = cls.get_model(model)
         user_message = get_last_user_message(messages)
         cookies = {}
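The media guard is a fail-fast validation before any network call. A self-contained sketch of the same pattern (MediaListType here stands in for the real g4f typing alias; the tuple layout is an assumption):

from typing import List, Optional, Tuple

MediaListType = Optional[List[Tuple[bytes, str]]]  # assumed (data, filename) pairs

def validate_media(media: MediaListType) -> None:
    # Mirrors the guard added above: reject media before contacting the
    # backend, since gpt-oss is text-only.
    if media:
        raise ValueError("Media is not supported by gpt-oss")

validate_media(None)  # accepted
try:
    validate_media([(b"\x89PNG...", "image.png")])
except ValueError as e:
    print(e)  # Media is not supported by gpt-oss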
PollinationsAI.py
@@ -88,11 +88,7 @@ class PollinationsAI(AsyncGeneratorProvider, ProviderModelMixin):
     vision_models = [default_vision_model]
     _models_loaded = False
     model_aliases = {
-        "gpt-4": "openai",
-        "gpt-4o": "openai",
-        "gpt-4.1-mini": "openai",
-        "gpt-4o-mini": "openai",
-        "gpt-4.1-nano": "openai-fast",
+        "gpt-4.1-nano": "openai",
         "gpt-4.1": "openai-large",
         "o4-mini": "openai-reasoning",
         "qwen-2.5-coder-32b": "qwen-coder",
@@ -106,7 +102,6 @@ class PollinationsAI(AsyncGeneratorProvider, ProviderModelMixin):
         "grok-3-mini": "grok",
         "grok-3-mini-high": "grok",
         "gpt-4o-mini-audio": "openai-audio",
-        "gpt-4o-audio": "openai-audio",
         "sdxl-turbo": "turbo",
         "gpt-image": "gptimage",
         "flux-dev": "flux",
Provider/__init__.py
@@ -1,7 +1,7 @@
 from __future__ import annotations
 
 from ..providers.types import BaseProvider, ProviderType
-from ..providers.retry_provider import RetryProvider, IterListProvider
+from ..providers.retry_provider import RetryProvider, IterListProvider, RotatedProvider
 from ..providers.base_provider import AsyncProvider, AsyncGeneratorProvider
 from ..providers.create_images import CreateImagesProvider
 from .. import debug
Azure.py
@@ -79,7 +79,7 @@ class Azure(OpenaiTemplate):
             raise ModelNotFoundError(f"No API endpoint found for model: {model}")
         if not api_endpoint:
             api_endpoint = os.environ.get("AZURE_API_ENDPOINT")
-        if not api_key:
+        if cls.api_keys:
             api_key = cls.api_keys.get(model, cls.api_keys.get("default"))
         if not api_key:
             raise ValueError(f"API key is required for Azure provider. Ask for API key in the {cls.login_url} Discord server.")
Together.py
@@ -3,8 +3,6 @@ from __future__ import annotations
 
 from ..template import OpenaiTemplate
 from ...config import DEFAULT_MODEL
-from ...errors import ModelNotFoundError
-from ... import debug
 
 class Together(OpenaiTemplate):
     label = "Together"
@@ -141,28 +139,4 @@ class Together(OpenaiTemplate):
         "flux-dev": ["black-forest-labs/FLUX.1-dev", "black-forest-labs/FLUX.1-dev-lora"],
         "flux-kontext-pro": "black-forest-labs/FLUX.1-kontext-pro",
         "flux-kontext-dev": "black-forest-labs/FLUX.1-kontext-dev",
     }
-
-    @classmethod
-    def get_model(cls, model: str, api_key: str = None, api_base: str = None) -> str:
-        """Get the internal model name from the user-provided model name."""
-        if not model:
-            return cls.default_model
-
-        # Check if the model exists directly in our models list
-        if model in cls.models:
-            return model
-
-        # Check if there's an alias for this model
-        if model in cls.model_aliases:
-            alias = cls.model_aliases[model]
-            # If the alias is a list, randomly select one of the options
-            if isinstance(alias, list):
-                import random  # Add this import at the top of the file
-                selected_model = random.choice(alias)
-                debug.log(f"Together: Selected model '{selected_model}' from alias '{model}'")
-                return selected_model
-            debug.log(f"Together: Using model '{alias}' for alias '{model}'")
-            return alias
-
-        raise ModelNotFoundError(f"Together: Model {model} not found")
OpenaiTemplate.py
@@ -42,7 +42,7 @@ class OpenaiTemplate(AsyncGeneratorProvider, ProviderModelMixin, RaiseErrorMixin
                 raise_for_status(response)
                 data = response.json()
                 data = data.get("data") if isinstance(data, dict) else data
-                cls.image_models = [model.get("id", model.get("name")) for model in data if model.get("image")]
+                cls.image_models = [model.get("id", model.get("name")) for model in data if model.get("image") or model.get("type") == "image"]
                 cls.vision_models = cls.vision_models.copy()
                 cls.vision_models += [model.get("id", model.get("name")) for model in data if model.get("vision")]
                 cls.models = [model.get("id", model.get("name")) for model in data]
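The broadened check pairs with the "type" field that the Api hunk below starts emitting: a model entry now counts as an image model if it sets "image" truthy or reports type == "image". A sketch with made-up sample data:

# Sample /models payload entries (invented for illustration).
data = [
    {"id": "gpt-4.1", "type": "text"},
    {"id": "flux-1.1-pro", "type": "image"},
    {"id": "sdxl-turbo", "image": True},
]
image_models = [
    model.get("id", model.get("name")) for model in data
    if model.get("image") or model.get("type") == "image"
]
print(image_models)  # ['flux-1.1-pro', 'sdxl-turbo']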
@@ -363,6 +363,7 @@ class Api:
             "owned_by": getattr(provider, "label", provider.__name__),
             "image": model in getattr(provider, "image_models", []),
             "vision": model in getattr(provider, "vision_models", []),
+            "type": "image" if model in getattr(provider, "image_models", []) else "text",
         } for model in models]
     }
 
backend_api.py
@@ -127,6 +127,7 @@ class Backend_Api(Api):
         except MissingAuthError as e:
             return jsonify({"error": {"message": f"{type(e).__name__}: {e}"}}), 401
         except Exception as e:
+            logger.exception(e)
             return jsonify({"error": {"message": f"{type(e).__name__}: {e}"}}), 500
         return jsonify(response)
 
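The added logger.exception call records the full traceback server-side while the client still receives only the generic 500 payload. A minimal sketch of the pattern, assuming a conventional module-level logger (the logger definition is not shown in this diff):

import logging
from flask import Flask, jsonify

logger = logging.getLogger(__name__)  # assumption: standard module-level logger
app = Flask(__name__)

@app.route("/demo")
def demo():
    try:
        raise RuntimeError("boom")
    except Exception as e:
        logger.exception(e)  # traceback goes to the server log, not the response
        return jsonify({"error": {"message": f"{type(e).__name__}: {e}"}}), 500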
any_model_map.py: file diff suppressed because one or more lines are too long.
any_provider.py
@@ -6,7 +6,7 @@ import json
 from ..typing import AsyncResult, Messages, MediaListType, Union
 from ..errors import ModelNotFoundError
 from ..image import is_data_an_audio
-from ..providers.retry_provider import IterListProvider
+from ..providers.retry_provider import RotatedProvider
 from ..Provider.needs_auth import OpenaiChat, CopilotAccount
 from ..Provider.hf_space import HuggingSpace
 from ..Provider import Copilot, Cloudflare, Gemini, GeminiPro, Grok, DeepSeekAPI, PerplexityLabs, LambdaChat, PollinationsAI, PuterJS
@@ -18,7 +18,7 @@ from .base_provider import AsyncGeneratorProvider, ProviderModelMixin
 from .. import Provider
 from .. import models
 from .. import debug
-from .any_model_map import audio_models, image_models, vision_models, video_models, model_map, models_count, parents
+from .any_model_map import audio_models, image_models, vision_models, video_models, model_map, models_count, parents, model_aliases
 
 PROVIERS_LIST_1 = [
     CopilotAccount, OpenaiChat, Cloudflare, PerplexityLabs, Gemini, Grok, DeepSeekAPI, Blackbox, OpenAIFM,
@@ -73,6 +73,7 @@ class AnyModelProviderMixin(ProviderModelMixin):
     models_count = models_count
     models = list(model_map.keys())
     model_map: dict[str, dict[str, str]] = model_map
+    model_aliases: dict[str, str] = model_aliases
 
     @classmethod
     def extend_ignored(cls, ignored: list[str]) -> list[str]:
@@ -102,7 +103,7 @@ class AnyModelProviderMixin(ProviderModelMixin):
         cls.create_model_map()
         file = os.path.join(os.path.dirname(__file__), "any_model_map.py")
         with open(file, "w", encoding="utf-8") as f:
-            for key in ["audio_models", "image_models", "vision_models", "video_models", "model_map", "models_count", "parents"]:
+            for key in ["audio_models", "image_models", "vision_models", "video_models", "model_map", "models_count", "parents", "model_aliases"]:
                 value = getattr(cls, key)
                 f.write(f"{key} = {json.dumps(value, indent=2) if isinstance(value, dict) else repr(value)}\n")
 
@@ -118,11 +119,14 @@ class AnyModelProviderMixin(ProviderModelMixin):
             "default": {provider.__name__: "" for provider in models.default.best_provider.providers},
         }
         cls.model_map.update({
-            model.name: {
+            name: {
                 provider.__name__: model.get_long_name() for provider in providers
                 if provider.working
-            } for _, (model, providers) in models.__models__.items()
+            } for name, (model, providers) in models.__models__.items()
         })
+        for name, (model, providers) in models.__models__.items():
+            if isinstance(model, models.ImageModel):
+                cls.image_models.append(name)
 
         # Process special providers
         for provider in PROVIERS_LIST_2:
@@ -234,6 +238,11 @@ class AnyModelProviderMixin(ProviderModelMixin):
             elif provider.__name__ not in cls.parents[provider.get_parent()]:
                 cls.parents[provider.get_parent()].append(provider.__name__)
 
+        for model, providers in cls.model_map.items():
+            for provider, alias in providers.items():
+                if alias != model and isinstance(alias, str) and alias not in cls.model_map:
+                    cls.model_aliases[alias] = model
+
     @classmethod
     def get_grouped_models(cls, ignored: list[str] = []) -> dict[str, list[str]]:
         unsorted_models = cls.get_models(ignored=ignored)
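The new loop above builds a reverse index: every provider-specific long name that is not itself a top-level model becomes an alias pointing back to the canonical short name. A sketch with a toy model_map:

# Toy two-entry model_map (the real one lives in any_model_map.py).
model_map = {
    "phi-4-multimodal": {"DeepInfraChat": "microsoft/Phi-4-multimodal-instruct"},
    "qwq-32b": {"DeepInfraChat": "Qwen/QwQ-32B"},
}
model_aliases: dict[str, str] = {}
for model, providers in model_map.items():
    for provider, alias in providers.items():
        # Skip self-mappings, list-valued aliases, and names that are models themselves.
        if alias != model and isinstance(alias, str) and alias not in model_map:
            model_aliases[alias] = model
print(model_aliases)
# {'microsoft/Phi-4-multimodal-instruct': 'phi-4-multimodal', 'Qwen/QwQ-32B': 'qwq-32b'}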
@@ -299,7 +308,7 @@ class AnyModelProviderMixin(ProviderModelMixin):
                 groups["image"].append(model)
                 added = True
             # Check for OpenAI models
-            elif model.startswith(("gpt-", "chatgpt-", "o1", "o1-", "o3-", "o4-")) or model in ("auto", "searchgpt"):
+            elif model.startswith(("gpt-", "chatgpt-", "o1", "o1", "o3", "o4")) or model in ("auto", "searchgpt"):
                 groups["openai"].append(model)
                 added = True
             # Check for video models
@@ -371,10 +380,14 @@ class AnyProvider(AsyncGeneratorProvider, AnyModelProviderMixin):
                 providers.append(provider)
                 model = submodel
         else:
+            if model not in cls.model_map:
+                if model in cls.model_aliases:
+                    model = cls.model_aliases[model]
             if model in cls.model_map:
                 for provider, alias in cls.model_map[model].items():
                     provider = Provider.__map__[provider]
-                    provider.model_aliases[model] = alias
+                    if model not in provider.model_aliases:
+                        provider.model_aliases[model] = alias
                     providers.append(provider)
         if not providers:
             for provider in PROVIERS_LIST_1:
@@ -390,7 +403,7 @@ class AnyProvider(AsyncGeneratorProvider, AnyModelProviderMixin):
 
         debug.log(f"AnyProvider: Using providers: {[provider.__name__ for provider in providers]} for model '{model}'")
 
-        async for chunk in IterListProvider(providers).create_async_generator(
+        async for chunk in RotatedProvider(providers).create_async_generator(
             model,
             messages,
             stream=stream,
base_provider.py
@@ -1,7 +1,7 @@
 from __future__ import annotations
 
 import asyncio
+import random
 from asyncio import AbstractEventLoop
 from concurrent.futures import ThreadPoolExecutor
 from abc import abstractmethod
@@ -21,6 +21,7 @@ from .response import BaseConversation, AuthResult
 from .helper import concat_chunks
 from ..cookies import get_cookies_dir
 from ..errors import ModelNotFoundError, ResponseError, MissingAuthError, NoValidHarFileError, PaymentRequiredError, CloudflareError
+from .. import debug
 
 SAFE_PARAMETERS = [
     "model", "messages", "stream", "timeout",
@@ -368,7 +369,13 @@ class ProviderModelMixin:
         if not model and cls.default_model is not None:
             model = cls.default_model
         if model in cls.model_aliases:
-            model = cls.model_aliases[model]
+            alias = cls.model_aliases[model]
+            if isinstance(alias, list):
+                selected_model = random.choice(alias)
+                debug.log(f"{cls.__name__}: Selected model '{selected_model}' from alias '{model}'")
+                return selected_model
+            debug.log(f"{cls.__name__}: Using model '{alias}' for alias '{model}'")
+            return alias
         if model not in cls.model_aliases.values():
             if model not in cls.get_models(**kwargs) and cls.models:
                 raise ModelNotFoundError(f"Model not found: {model} in: {cls.__name__} Valid models: {cls.models}")
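This is the logic that replaced the near-identical get_model overrides deleted from DeepInfraChat.py and Together.py above: string aliases map directly, list aliases are sampled at random. A standalone sketch of the resolution rule:

import random

model_aliases = {
    "phi-4": "microsoft/phi-4",
    "flux-dev": ["black-forest-labs/FLUX.1-dev", "black-forest-labs/FLUX.1-dev-lora"],
}

def resolve(model: str) -> str:
    alias = model_aliases.get(model, model)
    if isinstance(alias, list):
        # A list alias load-balances across equivalent deployments.
        return random.choice(alias)
    return alias

print(resolve("phi-4"))     # microsoft/phi-4
print(resolve("flux-dev"))  # one of the two FLUX.1-dev variants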
retry_provider.py
@@ -2,13 +2,176 @@ from __future__ import annotations
 
 import random
 
-from ..typing import Type, List, CreateResult, Messages, AsyncResult
+from ..typing import Dict, Type, List, CreateResult, Messages, AsyncResult
 from .types import BaseProvider, BaseRetryProvider, ProviderType
 from .response import ProviderInfo, JsonConversation, is_content
 from .. import debug
 from ..tools.run_tools import AuthManager
 from ..errors import RetryProviderError, RetryNoProviderError, MissingAuthError, NoValidHarFileError
 
+class RotatedProvider(BaseRetryProvider):
+    """
+    A provider that rotates through a list of providers, attempting one provider per
+    request and advancing to the next one upon failure. This distributes load and
+    retries across multiple providers in a round-robin fashion.
+    """
+    def __init__(
+        self,
+        providers: List[Type[BaseProvider]],
+        shuffle: bool = True
+    ) -> None:
+        """
+        Initialize the RotatedProvider.
+        Args:
+            providers (List[Type[BaseProvider]]): A non-empty list of providers to rotate through.
+            shuffle (bool): If True, shuffles the provider list once at initialization
+                to randomize the rotation order.
+        """
+        if not isinstance(providers, list) or len(providers) == 0:
+            raise ValueError('RotatedProvider requires a non-empty list of providers.')
+
+        self.providers = providers
+        if shuffle:
+            random.shuffle(self.providers)
+
+        self.current_index = 0
+        self.last_provider: Type[BaseProvider] = None
+
+    def _get_current_provider(self) -> Type[BaseProvider]:
+        """Gets the provider at the current index."""
+        return self.providers[self.current_index]
+
+    def _rotate_provider(self) -> None:
+        """Rotates to the next provider in the list."""
+        self.current_index = (self.current_index + 1) % len(self.providers)
+        #new_provider_name = self.providers[self.current_index].__name__
+        #debug.log(f"Rotated to next provider: {new_provider_name}")
+
+    def create_completion(
+        self,
+        model: str,
+        messages: Messages,
+        ignored: list[str] = [],  # 'ignored' is less relevant now but kept for compatibility
+        api_key: str = None,
+        **kwargs,
+    ) -> CreateResult:
+        """
+        Create a completion using the current provider and rotating on failure.
+
+        It will try each provider in the list once per call, rotating after each
+        failed attempt, until one succeeds or all have failed.
+        """
+        exceptions: Dict[str, Exception] = {}
+
+        # Loop over the number of providers, giving each one a chance
+        for _ in range(len(self.providers)):
+            provider = self._get_current_provider()
+            self.last_provider = provider
+            self._rotate_provider()
+
+            # Skip if provider is in the ignored list
+            if provider.get_parent() in ignored:
+                continue
+
+            alias = model or getattr(provider, "default_model", None)
+            if hasattr(provider, "model_aliases"):
+                alias = provider.model_aliases.get(model, model)
+                if isinstance(alias, list):
+                    alias = random.choice(alias)
+
+            debug.log(f"Attempting provider: {provider.__name__} with model: {alias}")
+            yield ProviderInfo(**provider.get_dict(), model=alias, alias=model)
+
+            extra_body = kwargs.copy()
+            current_api_key = api_key.get(provider.get_parent()) if isinstance(api_key, dict) else api_key
+            if not current_api_key:
+                current_api_key = AuthManager.load_api_key(provider)
+            if current_api_key:
+                extra_body["api_key"] = current_api_key
+
+            try:
+                # Attempt to get a response from the current provider
+                response = provider.create_function(alias, messages, **extra_body)
+                started = False
+                for chunk in response:
+                    if chunk:
+                        yield chunk
+                        if is_content(chunk):
+                            started = True
+                if started:
+                    # Success, so we return and do not rotate
+                    return
+            except Exception as e:
+                exceptions[provider.__name__] = e
+                debug.error(f"{provider.__name__} failed: {e}")
+
+        # If the loop completes, all providers have failed
+        raise_exceptions(exceptions)
+
+    async def create_async_generator(
+        self,
+        model: str,
+        messages: Messages,
+        ignored: list[str] = [],
+        api_key: str = None,
+        conversation: JsonConversation = None,
+        **kwargs
+    ) -> AsyncResult:
+        """
+        Asynchronously create a completion, rotating through providers on failure.
+        """
+        exceptions: Dict[str, Exception] = {}
+
+        for _ in range(len(self.providers)):
+            provider = self._get_current_provider()
+            self._rotate_provider()
+            self.last_provider = provider
+
+            if provider.get_parent() in ignored:
+                continue
+
+            alias = model or getattr(provider, "default_model", None)
+            if hasattr(provider, "model_aliases"):
+                alias = provider.model_aliases.get(model, model)
+                if isinstance(alias, list):
+                    alias = random.choice(alias)
+
+            debug.log(f"Attempting provider: {provider.__name__} with model: {alias}")
+            yield ProviderInfo(**provider.get_dict(), model=alias)
+
+            extra_body = kwargs.copy()
+            current_api_key = api_key.get(provider.get_parent()) if isinstance(api_key, dict) else api_key
+            if not current_api_key:
+                current_api_key = AuthManager.load_api_key(provider)
+            if current_api_key:
+                extra_body["api_key"] = current_api_key
+            if conversation and hasattr(conversation, provider.__name__):
+                extra_body["conversation"] = JsonConversation(**getattr(conversation, provider.__name__))
+
+            try:
+                response = provider.async_create_function(alias, messages, **extra_body)
+                started = False
+                async for chunk in response:
+                    if isinstance(chunk, JsonConversation):
+                        if conversation is None: conversation = JsonConversation()
+                        setattr(conversation, provider.__name__, chunk.get_dict())
+                        yield conversation
+                    elif chunk:
+                        yield chunk
+                        if is_content(chunk):
+                            started = True
+                if started:
+                    return  # Success
+            except Exception as e:
+                exceptions[provider.__name__] = e
+                debug.error(f"{provider.__name__} failed: {e}")
+
+        raise_exceptions(exceptions)
+
+    # Maintain API compatibility
+    create_function = create_completion
+    async_create_function = create_async_generator
+
 class IterListProvider(BaseRetryProvider):
     def __init__(
         self,
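A hedged usage sketch of the new class, wired the same way AnyProvider now drives it (see the RotatedProvider hunk in any_provider.py above); the two providers are stand-in choices, not a recommendation:

import asyncio
from g4f.providers.retry_provider import RotatedProvider
from g4f.Provider import PollinationsAI, LambdaChat  # stand-in providers

async def main():
    provider = RotatedProvider([PollinationsAI, LambdaChat], shuffle=True)
    # Each call starts at the current index and rotates on failure, so repeated
    # failures spread load round-robin instead of always hitting the first
    # provider in the list, as IterListProvider did.
    async for chunk in provider.create_async_generator(
        "gpt-4.1", [{"role": "user", "content": "Hello"}]
    ):
        print(chunk, end="")

asyncio.run(main())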