mirror of
				https://github.com/xtekky/gpt4free.git
				synced 2025-10-31 19:42:45 +08:00 
			
		
		
		
	feat: Update environment variables and modify model mappings
- Added `OPENROUTER_API_KEY` and `AZURE_API_KEYS` to `example.env`. - Updated `AZURE_DEFAULT_MODEL` to "model-router" in `example.env`. - Added `AZURE_ROUTES` with multiple model URLs in `example.env`. - Changed the mapping for `"phi-4-multimodal"` in `DeepInfraChat.py` to `"microsoft/Phi-4-multimodal-instruct"`. - Added `media` parameter to `GptOss.create_completion` method and raised a `ValueError` if `media` is provided. - Updated `model_aliases` in `any_model_map.py` to include new mappings for various models. - Removed several model aliases from `PollinationsAI` in `any_model_map.py`. - Added new models and updated existing models in `model_map` across various files, including `any_model_map.py` and `__init__.py`. - Refactored `AnyModelProviderMixin` to include `model_aliases` and updated the logic for handling model aliases.
This commit is contained in:
		
							
								
								
									
										20
									
								
								example.env
									
									
									
									
									
								
							
							
						
						
									
										20
									
								
								example.env
									
									
									
									
									
								
							| @@ -8,6 +8,20 @@ TOGETHER_API_KEY= | |||||||
| DEEPINFRA_API_KEY= | DEEPINFRA_API_KEY= | ||||||
| OPENAI_API_KEY= | OPENAI_API_KEY= | ||||||
| GROQ_API_KEY= | GROQ_API_KEY= | ||||||
| AZURE_API_KEY= | OPENROUTER_API_KEY= | ||||||
| AZURE_API_ENDPOINT= | AZURE_API_KEYS='{ | ||||||
| AZURE_DEFAULT_MODEL= |   "default": "", | ||||||
|  |   "flux-1.1-pro": "", | ||||||
|  |   "flux.1-kontext-pro": "" | ||||||
|  | }' | ||||||
|  | AZURE_DEFAULT_MODEL="model-router" | ||||||
|  | AZURE_ROUTES='{ | ||||||
|  |   "model-router": "https://HOST.cognitiveservices.azure.com/openai/deployments/model-router/chat/completions?api-version=2025-01-01-preview", | ||||||
|  |   "deepseek-r1": "https://HOST.services.ai.azure.com/models/chat/completions?api-version=2024-05-01-preview", | ||||||
|  |   "gpt-4.1": "https://HOST.cognitiveservices.azure.com/openai/deployments/gpt-4.1/chat/completions?api-version=2025-01-01-preview", | ||||||
|  |   "gpt-4o-mini-audio-preview": "https://HOST.cognitiveservices.azure.com/openai/deployments/gpt-4o-mini-audio-preview/chat/completions?api-version=2025-01-01-preview", | ||||||
|  |   "o4-mini": "https://HOST.cognitiveservices.azure.com/openai/deployments/o4-mini/chat/completions?api-version=2025-01-01-preview", | ||||||
|  |   "grok-3": "https://HOST.services.ai.azure.com/models/chat/completions?api-version=2024-05-01-preview", | ||||||
|  |   "flux-1.1-pro": "https://HOST.cognitiveservices.azure.com/openai/deployments/FLUX-1.1-pro/images/generations?api-version=2025-04-01-preview", | ||||||
|  |   "flux.1-kontext-pro": "https://HOST.services.ai.azure.com/openai/deployments/FLUX.1-Kontext-pro/images/edits?api-version=2025-04-01-preview" | ||||||
|  | }' | ||||||
| @@ -2,10 +2,7 @@ from __future__ import annotations | |||||||
|  |  | ||||||
| import requests | import requests | ||||||
| from .template import OpenaiTemplate | from .template import OpenaiTemplate | ||||||
| from ..errors import ModelNotFoundError |  | ||||||
| from ..config import DEFAULT_MODEL | from ..config import DEFAULT_MODEL | ||||||
| from .. import debug |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class DeepInfraChat(OpenaiTemplate): | class DeepInfraChat(OpenaiTemplate): | ||||||
|     parent = "DeepInfra" |     parent = "DeepInfra" | ||||||
| @@ -86,7 +83,7 @@ class DeepInfraChat(OpenaiTemplate): | |||||||
|  |  | ||||||
|         # microsoft |         # microsoft | ||||||
|         "phi-4": "microsoft/phi-4", |         "phi-4": "microsoft/phi-4", | ||||||
|         "phi-4-multimodal": default_vision_model, |         "phi-4-multimodal": "microsoft/Phi-4-multimodal-instruct", | ||||||
|         "phi-4-reasoning-plus": "microsoft/phi-4-reasoning-plus", |         "phi-4-reasoning-plus": "microsoft/phi-4-reasoning-plus", | ||||||
|         "wizardlm-2-7b": "microsoft/WizardLM-2-7B", |         "wizardlm-2-7b": "microsoft/WizardLM-2-7B", | ||||||
|         "wizardlm-2-8x22b": "microsoft/WizardLM-2-8x22B", |         "wizardlm-2-8x22b": "microsoft/WizardLM-2-8x22B", | ||||||
| @@ -101,28 +98,3 @@ class DeepInfraChat(OpenaiTemplate): | |||||||
|         "qwen-3-235b": "Qwen/Qwen3-235B-A22B", |         "qwen-3-235b": "Qwen/Qwen3-235B-A22B", | ||||||
|         "qwq-32b": "Qwen/QwQ-32B", |         "qwq-32b": "Qwen/QwQ-32B", | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     @classmethod |  | ||||||
|     def get_model(cls, model: str, **kwargs) -> str: |  | ||||||
|         """Get the internal model name from the user-provided model name.""" |  | ||||||
|         # kwargs can contain api_key, api_base, etc. but we don't need them for model selection |  | ||||||
|         if not model: |  | ||||||
|             return cls.default_model |  | ||||||
|          |  | ||||||
|         # Check if the model exists directly in our models list |  | ||||||
|         if model in cls.models: |  | ||||||
|             return model |  | ||||||
|          |  | ||||||
|         # Check if there's an alias for this model |  | ||||||
|         if model in cls.model_aliases: |  | ||||||
|             alias = cls.model_aliases[model] |  | ||||||
|             # If the alias is a list, randomly select one of the options |  | ||||||
|             if isinstance(alias, list): |  | ||||||
|                 import random |  | ||||||
|                 selected_model = random.choice(alias) |  | ||||||
|                 debug.log(f"DeepInfraChat: Selected model '{selected_model}' from alias '{model}'") |  | ||||||
|                 return selected_model |  | ||||||
|             debug.log(f"DeepInfraChat: Using model '{alias}' for alias '{model}'") |  | ||||||
|             return alias |  | ||||||
|          |  | ||||||
|         raise ModelNotFoundError(f"Model {model} not found") |  | ||||||
|   | |||||||
| @@ -1,7 +1,7 @@ | |||||||
| from __future__ import annotations | from __future__ import annotations | ||||||
|  |  | ||||||
|  |  | ||||||
| from ..typing import AsyncResult, Messages | from ..typing import AsyncResult, Messages, MediaListType | ||||||
| from ..providers.response import JsonConversation, Reasoning, TitleGeneration | from ..providers.response import JsonConversation, Reasoning, TitleGeneration | ||||||
| from ..requests import StreamSession, raise_for_status | from ..requests import StreamSession, raise_for_status | ||||||
| from ..config import DEFAULT_MODEL | from ..config import DEFAULT_MODEL | ||||||
| @@ -26,11 +26,14 @@ class GptOss(AsyncGeneratorProvider, ProviderModelMixin): | |||||||
|         cls, |         cls, | ||||||
|         model: str, |         model: str, | ||||||
|         messages: Messages, |         messages: Messages, | ||||||
|  |         media: MediaListType = None, | ||||||
|         conversation: JsonConversation = None, |         conversation: JsonConversation = None, | ||||||
|         reasoning_effort: str = "high", |         reasoning_effort: str = "high", | ||||||
|         proxy: str = None, |         proxy: str = None, | ||||||
|         **kwargs |         **kwargs | ||||||
|     ) -> AsyncResult: |     ) -> AsyncResult: | ||||||
|  |         if media: | ||||||
|  |             raise ValueError("Media is not supported by gpt-oss") | ||||||
|         model = cls.get_model(model) |         model = cls.get_model(model) | ||||||
|         user_message = get_last_user_message(messages) |         user_message = get_last_user_message(messages) | ||||||
|         cookies = {} |         cookies = {} | ||||||
|   | |||||||
| @@ -88,11 +88,7 @@ class PollinationsAI(AsyncGeneratorProvider, ProviderModelMixin): | |||||||
|     vision_models = [default_vision_model] |     vision_models = [default_vision_model] | ||||||
|     _models_loaded = False |     _models_loaded = False | ||||||
|     model_aliases = { |     model_aliases = { | ||||||
|         "gpt-4": "openai", |         "gpt-4.1-nano": "openai", | ||||||
|         "gpt-4o": "openai", |  | ||||||
|         "gpt-4.1-mini": "openai", |  | ||||||
|         "gpt-4o-mini": "openai", |  | ||||||
|         "gpt-4.1-nano": "openai-fast", |  | ||||||
|         "gpt-4.1": "openai-large", |         "gpt-4.1": "openai-large", | ||||||
|         "o4-mini": "openai-reasoning", |         "o4-mini": "openai-reasoning", | ||||||
|         "qwen-2.5-coder-32b": "qwen-coder", |         "qwen-2.5-coder-32b": "qwen-coder", | ||||||
| @@ -106,7 +102,6 @@ class PollinationsAI(AsyncGeneratorProvider, ProviderModelMixin): | |||||||
|         "grok-3-mini": "grok", |         "grok-3-mini": "grok", | ||||||
|         "grok-3-mini-high": "grok", |         "grok-3-mini-high": "grok", | ||||||
|         "gpt-4o-mini-audio": "openai-audio", |         "gpt-4o-mini-audio": "openai-audio", | ||||||
|         "gpt-4o-audio": "openai-audio", |  | ||||||
|         "sdxl-turbo": "turbo", |         "sdxl-turbo": "turbo", | ||||||
|         "gpt-image": "gptimage", |         "gpt-image": "gptimage", | ||||||
|         "flux-dev": "flux", |         "flux-dev": "flux", | ||||||
|   | |||||||
| @@ -1,7 +1,7 @@ | |||||||
| from __future__ import annotations | from __future__ import annotations | ||||||
|  |  | ||||||
| from ..providers.types          import BaseProvider, ProviderType | from ..providers.types          import BaseProvider, ProviderType | ||||||
| from ..providers.retry_provider import RetryProvider, IterListProvider | from ..providers.retry_provider import RetryProvider, IterListProvider, RotatedProvider | ||||||
| from ..providers.base_provider  import AsyncProvider, AsyncGeneratorProvider | from ..providers.base_provider  import AsyncProvider, AsyncGeneratorProvider | ||||||
| from ..providers.create_images  import CreateImagesProvider | from ..providers.create_images  import CreateImagesProvider | ||||||
| from .. import debug | from .. import debug | ||||||
|   | |||||||
| @@ -79,7 +79,7 @@ class Azure(OpenaiTemplate): | |||||||
|                 raise ModelNotFoundError(f"No API endpoint found for model: {model}") |                 raise ModelNotFoundError(f"No API endpoint found for model: {model}") | ||||||
|         if not api_endpoint: |         if not api_endpoint: | ||||||
|             api_endpoint = os.environ.get("AZURE_API_ENDPOINT") |             api_endpoint = os.environ.get("AZURE_API_ENDPOINT") | ||||||
|         if not api_key: |         if cls.api_keys: | ||||||
|             api_key = cls.api_keys.get(model, cls.api_keys.get("default")) |             api_key = cls.api_keys.get(model, cls.api_keys.get("default")) | ||||||
|             if not api_key: |             if not api_key: | ||||||
|                 raise ValueError(f"API key is required for Azure provider. Ask for API key in the {cls.login_url} Discord server.") |                 raise ValueError(f"API key is required for Azure provider. Ask for API key in the {cls.login_url} Discord server.") | ||||||
|   | |||||||
| @@ -3,8 +3,6 @@ from __future__ import annotations | |||||||
|  |  | ||||||
| from ..template import OpenaiTemplate | from ..template import OpenaiTemplate | ||||||
| from ...config import DEFAULT_MODEL | from ...config import DEFAULT_MODEL | ||||||
| from ...errors import ModelNotFoundError |  | ||||||
| from ... import debug |  | ||||||
|  |  | ||||||
| class Together(OpenaiTemplate): | class Together(OpenaiTemplate): | ||||||
|     label = "Together" |     label = "Together" | ||||||
| @@ -142,27 +140,3 @@ class Together(OpenaiTemplate): | |||||||
|         "flux-kontext-pro": "black-forest-labs/FLUX.1-kontext-pro", |         "flux-kontext-pro": "black-forest-labs/FLUX.1-kontext-pro", | ||||||
|         "flux-kontext-dev": "black-forest-labs/FLUX.1-kontext-dev", |         "flux-kontext-dev": "black-forest-labs/FLUX.1-kontext-dev", | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     @classmethod |  | ||||||
|     def get_model(cls, model: str, api_key: str = None, api_base: str = None) -> str: |  | ||||||
|         """Get the internal model name from the user-provided model name.""" |  | ||||||
|         if not model: |  | ||||||
|             return cls.default_model |  | ||||||
|          |  | ||||||
|         # Check if the model exists directly in our models list |  | ||||||
|         if model in cls.models: |  | ||||||
|             return model |  | ||||||
|          |  | ||||||
|         # Check if there's an alias for this model |  | ||||||
|         if model in cls.model_aliases: |  | ||||||
|             alias = cls.model_aliases[model] |  | ||||||
|             # If the alias is a list, randomly select one of the options |  | ||||||
|             if isinstance(alias, list): |  | ||||||
|                 import random  # Add this import at the top of the file |  | ||||||
|                 selected_model = random.choice(alias) |  | ||||||
|                 debug.log(f"Together: Selected model '{selected_model}' from alias '{model}'") |  | ||||||
|                 return selected_model |  | ||||||
|             debug.log(f"Together: Using model '{alias}' for alias '{model}'") |  | ||||||
|             return alias |  | ||||||
|          |  | ||||||
|         raise ModelNotFoundError(f"Together: Model {model} not found") |  | ||||||
| @@ -42,7 +42,7 @@ class OpenaiTemplate(AsyncGeneratorProvider, ProviderModelMixin, RaiseErrorMixin | |||||||
|                 raise_for_status(response) |                 raise_for_status(response) | ||||||
|                 data = response.json() |                 data = response.json() | ||||||
|                 data = data.get("data") if isinstance(data, dict) else data |                 data = data.get("data") if isinstance(data, dict) else data | ||||||
|                 cls.image_models = [model.get("id", model.get("name")) for model in data if model.get("image")] |                 cls.image_models = [model.get("id", model.get("name")) for model in data if model.get("image") or model.get("type") == "image"] | ||||||
|                 cls.vision_models = cls.vision_models.copy() |                 cls.vision_models = cls.vision_models.copy() | ||||||
|                 cls.vision_models += [model.get("id", model.get("name")) for model in data if model.get("vision")] |                 cls.vision_models += [model.get("id", model.get("name")) for model in data if model.get("vision")] | ||||||
|                 cls.models = [model.get("id", model.get("name")) for model in data] |                 cls.models = [model.get("id", model.get("name")) for model in data] | ||||||
|   | |||||||
| @@ -363,6 +363,7 @@ class Api: | |||||||
|                     "owned_by": getattr(provider, "label", provider.__name__), |                     "owned_by": getattr(provider, "label", provider.__name__), | ||||||
|                     "image": model in getattr(provider, "image_models", []), |                     "image": model in getattr(provider, "image_models", []), | ||||||
|                     "vision": model in getattr(provider, "vision_models", []), |                     "vision": model in getattr(provider, "vision_models", []), | ||||||
|  |                     "type": "image" if model in getattr(provider, "image_models", []) else "text", | ||||||
|                 } for model in models] |                 } for model in models] | ||||||
|             } |             } | ||||||
|  |  | ||||||
|   | |||||||
| @@ -127,6 +127,7 @@ class Backend_Api(Api): | |||||||
|             except MissingAuthError as e: |             except MissingAuthError as e: | ||||||
|                 return jsonify({"error": {"message": f"{type(e).__name__}: {e}"}}), 401 |                 return jsonify({"error": {"message": f"{type(e).__name__}: {e}"}}), 401 | ||||||
|             except Exception as e: |             except Exception as e: | ||||||
|  |                 logger.exception(e) | ||||||
|                 return jsonify({"error": {"message": f"{type(e).__name__}: {e}"}}), 500 |                 return jsonify({"error": {"message": f"{type(e).__name__}: {e}"}}), 500 | ||||||
|             return jsonify(response) |             return jsonify(response) | ||||||
|  |  | ||||||
|   | |||||||
										
											
												File diff suppressed because one or more lines are too long
											
										
									
								
							| @@ -6,7 +6,7 @@ import json | |||||||
| from ..typing import AsyncResult, Messages, MediaListType, Union | from ..typing import AsyncResult, Messages, MediaListType, Union | ||||||
| from ..errors import ModelNotFoundError | from ..errors import ModelNotFoundError | ||||||
| from ..image import is_data_an_audio | from ..image import is_data_an_audio | ||||||
| from ..providers.retry_provider import IterListProvider | from ..providers.retry_provider import RotatedProvider | ||||||
| from ..Provider.needs_auth import OpenaiChat, CopilotAccount | from ..Provider.needs_auth import OpenaiChat, CopilotAccount | ||||||
| from ..Provider.hf_space import HuggingSpace | from ..Provider.hf_space import HuggingSpace | ||||||
| from ..Provider import Copilot, Cloudflare, Gemini, GeminiPro, Grok, DeepSeekAPI, PerplexityLabs, LambdaChat, PollinationsAI, PuterJS | from ..Provider import Copilot, Cloudflare, Gemini, GeminiPro, Grok, DeepSeekAPI, PerplexityLabs, LambdaChat, PollinationsAI, PuterJS | ||||||
| @@ -18,7 +18,7 @@ from .base_provider import AsyncGeneratorProvider, ProviderModelMixin | |||||||
| from .. import Provider | from .. import Provider | ||||||
| from .. import models | from .. import models | ||||||
| from .. import debug | from .. import debug | ||||||
| from .any_model_map import audio_models, image_models, vision_models, video_models, model_map, models_count, parents | from .any_model_map import audio_models, image_models, vision_models, video_models, model_map, models_count, parents, model_aliases | ||||||
|  |  | ||||||
| PROVIERS_LIST_1 = [ | PROVIERS_LIST_1 = [ | ||||||
|     CopilotAccount, OpenaiChat, Cloudflare, PerplexityLabs, Gemini, Grok, DeepSeekAPI, Blackbox, OpenAIFM, |     CopilotAccount, OpenaiChat, Cloudflare, PerplexityLabs, Gemini, Grok, DeepSeekAPI, Blackbox, OpenAIFM, | ||||||
| @@ -73,6 +73,7 @@ class AnyModelProviderMixin(ProviderModelMixin): | |||||||
|     models_count = models_count |     models_count = models_count | ||||||
|     models = list(model_map.keys()) |     models = list(model_map.keys()) | ||||||
|     model_map: dict[str, dict[str, str]] = model_map |     model_map: dict[str, dict[str, str]] = model_map | ||||||
|  |     model_aliases: dict[str, str] = model_aliases | ||||||
|  |  | ||||||
|     @classmethod |     @classmethod | ||||||
|     def extend_ignored(cls, ignored: list[str]) -> list[str]: |     def extend_ignored(cls, ignored: list[str]) -> list[str]: | ||||||
| @@ -102,7 +103,7 @@ class AnyModelProviderMixin(ProviderModelMixin): | |||||||
|         cls.create_model_map() |         cls.create_model_map() | ||||||
|         file = os.path.join(os.path.dirname(__file__), "any_model_map.py") |         file = os.path.join(os.path.dirname(__file__), "any_model_map.py") | ||||||
|         with open(file, "w", encoding="utf-8") as f: |         with open(file, "w", encoding="utf-8") as f: | ||||||
|             for key in ["audio_models", "image_models", "vision_models", "video_models", "model_map", "models_count", "parents"]: |             for key in ["audio_models", "image_models", "vision_models", "video_models", "model_map", "models_count", "parents", "model_aliases"]: | ||||||
|                 value = getattr(cls, key) |                 value = getattr(cls, key) | ||||||
|                 f.write(f"{key} = {json.dumps(value, indent=2) if isinstance(value, dict) else repr(value)}\n") |                 f.write(f"{key} = {json.dumps(value, indent=2) if isinstance(value, dict) else repr(value)}\n") | ||||||
|  |  | ||||||
| @@ -118,11 +119,14 @@ class AnyModelProviderMixin(ProviderModelMixin): | |||||||
|             "default": {provider.__name__: "" for provider in models.default.best_provider.providers}, |             "default": {provider.__name__: "" for provider in models.default.best_provider.providers}, | ||||||
|         } |         } | ||||||
|         cls.model_map.update({  |         cls.model_map.update({  | ||||||
|             model.name: { |             name: { | ||||||
|                 provider.__name__: model.get_long_name() for provider in providers |                 provider.__name__: model.get_long_name() for provider in providers | ||||||
|                 if provider.working |                 if provider.working | ||||||
|             } for _, (model, providers) in models.__models__.items() |             } for name, (model, providers) in models.__models__.items() | ||||||
|         }) |         }) | ||||||
|  |         for name, (model, providers) in models.__models__.items(): | ||||||
|  |             if isinstance(model, models.ImageModel): | ||||||
|  |                 cls.image_models.append(name) | ||||||
|  |  | ||||||
|         # Process special providers |         # Process special providers | ||||||
|         for provider in PROVIERS_LIST_2: |         for provider in PROVIERS_LIST_2: | ||||||
| @@ -234,6 +238,11 @@ class AnyModelProviderMixin(ProviderModelMixin): | |||||||
|                 elif provider.__name__ not in cls.parents[provider.get_parent()]: |                 elif provider.__name__ not in cls.parents[provider.get_parent()]: | ||||||
|                     cls.parents[provider.get_parent()].append(provider.__name__) |                     cls.parents[provider.get_parent()].append(provider.__name__) | ||||||
|  |  | ||||||
|  |         for model, providers in cls.model_map.items(): | ||||||
|  |             for provider, alias in providers.items(): | ||||||
|  |                 if alias != model and isinstance(alias, str) and alias not in cls.model_map: | ||||||
|  |                     cls.model_aliases[alias] = model | ||||||
|  |  | ||||||
|     @classmethod |     @classmethod | ||||||
|     def get_grouped_models(cls, ignored: list[str] = []) -> dict[str, list[str]]: |     def get_grouped_models(cls, ignored: list[str] = []) -> dict[str, list[str]]: | ||||||
|         unsorted_models = cls.get_models(ignored=ignored) |         unsorted_models = cls.get_models(ignored=ignored) | ||||||
| @@ -299,7 +308,7 @@ class AnyModelProviderMixin(ProviderModelMixin): | |||||||
|                 groups["image"].append(model) |                 groups["image"].append(model) | ||||||
|                 added = True |                 added = True | ||||||
|             # Check for OpenAI models |             # Check for OpenAI models | ||||||
|             elif model.startswith(("gpt-", "chatgpt-", "o1", "o1-", "o3-", "o4-")) or model in ("auto", "searchgpt"): |             elif model.startswith(("gpt-", "chatgpt-", "o1", "o1", "o3", "o4")) or model in ("auto", "searchgpt"): | ||||||
|                 groups["openai"].append(model) |                 groups["openai"].append(model) | ||||||
|                 added = True |                 added = True | ||||||
|             # Check for video models |             # Check for video models | ||||||
| @@ -371,9 +380,13 @@ class AnyProvider(AsyncGeneratorProvider, AnyModelProviderMixin): | |||||||
|                     providers.append(provider) |                     providers.append(provider) | ||||||
|                     model = submodel |                     model = submodel | ||||||
|         else: |         else: | ||||||
|  |             if model not in cls.model_map: | ||||||
|  |                 if model in cls.model_aliases: | ||||||
|  |                     model = cls.model_aliases[model] | ||||||
|             if model in cls.model_map: |             if model in cls.model_map: | ||||||
|                 for provider, alias in cls.model_map[model].items(): |                 for provider, alias in cls.model_map[model].items(): | ||||||
|                     provider = Provider.__map__[provider] |                     provider = Provider.__map__[provider] | ||||||
|  |                     if model not in provider.model_aliases: | ||||||
|                         provider.model_aliases[model] = alias |                         provider.model_aliases[model] = alias | ||||||
|                     providers.append(provider) |                     providers.append(provider) | ||||||
|         if not providers: |         if not providers: | ||||||
| @@ -390,7 +403,7 @@ class AnyProvider(AsyncGeneratorProvider, AnyModelProviderMixin): | |||||||
|  |  | ||||||
|         debug.log(f"AnyProvider: Using providers: {[provider.__name__ for provider in providers]} for model '{model}'") |         debug.log(f"AnyProvider: Using providers: {[provider.__name__ for provider in providers]} for model '{model}'") | ||||||
|  |  | ||||||
|         async for chunk in IterListProvider(providers).create_async_generator( |         async for chunk in RotatedProvider(providers).create_async_generator( | ||||||
|             model, |             model, | ||||||
|             messages, |             messages, | ||||||
|             stream=stream, |             stream=stream, | ||||||
|   | |||||||
| @@ -1,7 +1,7 @@ | |||||||
| from __future__ import annotations | from __future__ import annotations | ||||||
|  |  | ||||||
| import asyncio | import asyncio | ||||||
|  | import random | ||||||
| from asyncio import AbstractEventLoop | from asyncio import AbstractEventLoop | ||||||
| from concurrent.futures import ThreadPoolExecutor | from concurrent.futures import ThreadPoolExecutor | ||||||
| from abc import abstractmethod | from abc import abstractmethod | ||||||
| @@ -21,6 +21,7 @@ from .response import BaseConversation, AuthResult | |||||||
| from .helper import concat_chunks | from .helper import concat_chunks | ||||||
| from ..cookies import get_cookies_dir | from ..cookies import get_cookies_dir | ||||||
| from ..errors import ModelNotFoundError, ResponseError, MissingAuthError, NoValidHarFileError, PaymentRequiredError, CloudflareError | from ..errors import ModelNotFoundError, ResponseError, MissingAuthError, NoValidHarFileError, PaymentRequiredError, CloudflareError | ||||||
|  | from .. import debug | ||||||
|  |  | ||||||
| SAFE_PARAMETERS = [ | SAFE_PARAMETERS = [ | ||||||
|     "model", "messages", "stream", "timeout", |     "model", "messages", "stream", "timeout", | ||||||
| @@ -368,7 +369,13 @@ class ProviderModelMixin: | |||||||
|         if not model and cls.default_model is not None: |         if not model and cls.default_model is not None: | ||||||
|             model = cls.default_model |             model = cls.default_model | ||||||
|         if model in cls.model_aliases: |         if model in cls.model_aliases: | ||||||
|             model = cls.model_aliases[model] |             alias = cls.model_aliases[model] | ||||||
|  |             if isinstance(alias, list): | ||||||
|  |                 selected_model = random.choice(alias) | ||||||
|  |                 debug.log(f"{cls.__name__}: Selected model '{selected_model}' from alias '{model}'") | ||||||
|  |                 return selected_model | ||||||
|  |             debug.log(f"{cls.__name__}: Using model '{alias}' for alias '{model}'") | ||||||
|  |             return alias | ||||||
|         if model not in cls.model_aliases.values(): |         if model not in cls.model_aliases.values(): | ||||||
|             if model not in cls.get_models(**kwargs) and cls.models: |             if model not in cls.get_models(**kwargs) and cls.models: | ||||||
|                 raise ModelNotFoundError(f"Model not found: {model} in: {cls.__name__} Valid models: {cls.models}") |                 raise ModelNotFoundError(f"Model not found: {model} in: {cls.__name__} Valid models: {cls.models}") | ||||||
|   | |||||||
| @@ -2,13 +2,176 @@ from __future__ import annotations | |||||||
|  |  | ||||||
| import random | import random | ||||||
|  |  | ||||||
| from ..typing import Type, List, CreateResult, Messages, AsyncResult | from ..typing import Dict, Type, List, CreateResult, Messages, AsyncResult | ||||||
| from .types import BaseProvider, BaseRetryProvider, ProviderType | from .types import BaseProvider, BaseRetryProvider, ProviderType | ||||||
| from .response import ProviderInfo, JsonConversation, is_content | from .response import ProviderInfo, JsonConversation, is_content | ||||||
| from .. import debug | from .. import debug | ||||||
| from ..tools.run_tools import AuthManager | from ..tools.run_tools import AuthManager | ||||||
| from ..errors import RetryProviderError, RetryNoProviderError, MissingAuthError, NoValidHarFileError | from ..errors import RetryProviderError, RetryNoProviderError, MissingAuthError, NoValidHarFileError | ||||||
|  |  | ||||||
|  | class RotatedProvider(BaseRetryProvider): | ||||||
|  |     """ | ||||||
|  |     A provider that rotates through a list of providers, attempting one provider per | ||||||
|  |     request and advancing to the next one upon failure. This distributes load and | ||||||
|  |     retries across multiple providers in a round-robin fashion. | ||||||
|  |     """ | ||||||
|  |     def __init__( | ||||||
|  |         self, | ||||||
|  |         providers: List[Type[BaseProvider]], | ||||||
|  |         shuffle: bool = True | ||||||
|  |     ) -> None: | ||||||
|  |         """ | ||||||
|  |         Initialize the RotatedProvider. | ||||||
|  |         Args: | ||||||
|  |             providers (List[Type[BaseProvider]]): A non-empty list of providers to rotate through. | ||||||
|  |             shuffle (bool): If True, shuffles the provider list once at initialization | ||||||
|  |                             to randomize the rotation order. | ||||||
|  |         """ | ||||||
|  |         if not isinstance(providers, list) or len(providers) == 0: | ||||||
|  |             raise ValueError('RotatedProvider requires a non-empty list of providers.') | ||||||
|  |          | ||||||
|  |         self.providers = providers | ||||||
|  |         if shuffle: | ||||||
|  |             random.shuffle(self.providers) | ||||||
|  |              | ||||||
|  |         self.current_index = 0 | ||||||
|  |         self.last_provider: Type[BaseProvider] = None | ||||||
|  |  | ||||||
|  |     def _get_current_provider(self) -> Type[BaseProvider]: | ||||||
|  |         """Gets the provider at the current index.""" | ||||||
|  |         return self.providers[self.current_index] | ||||||
|  |  | ||||||
|  |     def _rotate_provider(self) -> None: | ||||||
|  |         """Rotates to the next provider in the list.""" | ||||||
|  |         self.current_index = (self.current_index + 1) % len(self.providers) | ||||||
|  |         #new_provider_name = self.providers[self.current_index].__name__ | ||||||
|  |         #debug.log(f"Rotated to next provider: {new_provider_name}") | ||||||
|  |  | ||||||
|  |     def create_completion( | ||||||
|  |         self, | ||||||
|  |         model: str, | ||||||
|  |         messages: Messages, | ||||||
|  |         ignored: list[str] = [], # 'ignored' is less relevant now but kept for compatibility | ||||||
|  |         api_key: str = None, | ||||||
|  |         **kwargs, | ||||||
|  |     ) -> CreateResult: | ||||||
|  |         """ | ||||||
|  |         Create a completion using the current provider and rotating on failure. | ||||||
|  |          | ||||||
|  |         It will try each provider in the list once per call, rotating after each | ||||||
|  |         failed attempt, until one succeeds or all have failed. | ||||||
|  |         """ | ||||||
|  |         exceptions: Dict[str, Exception] = {} | ||||||
|  |          | ||||||
|  |         # Loop over the number of providers, giving each one a chance | ||||||
|  |         for _ in range(len(self.providers)): | ||||||
|  |             provider = self._get_current_provider() | ||||||
|  |             self.last_provider = provider | ||||||
|  |             self._rotate_provider() | ||||||
|  |  | ||||||
|  |             # Skip if provider is in the ignored list | ||||||
|  |             if provider.get_parent() in ignored: | ||||||
|  |                 continue | ||||||
|  |              | ||||||
|  |             alias = model or getattr(provider, "default_model", None) | ||||||
|  |             if hasattr(provider, "model_aliases"): | ||||||
|  |                 alias = provider.model_aliases.get(model, model) | ||||||
|  |             if isinstance(alias, list): | ||||||
|  |                 alias = random.choice(alias) | ||||||
|  |              | ||||||
|  |             debug.log(f"Attempting provider: {provider.__name__} with model: {alias}") | ||||||
|  |             yield ProviderInfo(**provider.get_dict(), model=alias, alias=model) | ||||||
|  |              | ||||||
|  |             extra_body = kwargs.copy() | ||||||
|  |             current_api_key = api_key.get(provider.get_parent()) if isinstance(api_key, dict) else api_key | ||||||
|  |             if not current_api_key: | ||||||
|  |                 current_api_key = AuthManager.load_api_key(provider) | ||||||
|  |             if current_api_key: | ||||||
|  |                 extra_body["api_key"] = current_api_key | ||||||
|  |              | ||||||
|  |             try: | ||||||
|  |                 # Attempt to get a response from the current provider | ||||||
|  |                 response = provider.create_function(alias, messages, **extra_body) | ||||||
|  |                 started = False | ||||||
|  |                 for chunk in response: | ||||||
|  |                     if chunk: | ||||||
|  |                         yield chunk | ||||||
|  |                         if is_content(chunk): | ||||||
|  |                             started = True | ||||||
|  |                 if started: | ||||||
|  |                     # Success, so we return and do not rotate | ||||||
|  |                     return | ||||||
|  |             except Exception as e: | ||||||
|  |                 exceptions[provider.__name__] = e | ||||||
|  |                 debug.error(f"{provider.__name__} failed: {e}") | ||||||
|  |          | ||||||
|  |         # If the loop completes, all providers have failed | ||||||
|  |         raise_exceptions(exceptions) | ||||||
|  |  | ||||||
|  |     async def create_async_generator( | ||||||
|  |         self, | ||||||
|  |         model: str, | ||||||
|  |         messages: Messages, | ||||||
|  |         ignored: list[str] = [], | ||||||
|  |         api_key: str = None, | ||||||
|  |         conversation: JsonConversation = None, | ||||||
|  |         **kwargs | ||||||
|  |     ) -> AsyncResult: | ||||||
|  |         """ | ||||||
|  |         Asynchronously create a completion, rotating through providers on failure. | ||||||
|  |         """ | ||||||
|  |         exceptions: Dict[str, Exception] = {} | ||||||
|  |  | ||||||
|  |         for _ in range(len(self.providers)): | ||||||
|  |             provider = self._get_current_provider() | ||||||
|  |             self._rotate_provider() | ||||||
|  |             self.last_provider = provider | ||||||
|  |  | ||||||
|  |             if provider.get_parent() in ignored: | ||||||
|  |                 continue | ||||||
|  |  | ||||||
|  |             alias = model or getattr(provider, "default_model", None) | ||||||
|  |             if hasattr(provider, "model_aliases"): | ||||||
|  |                 alias = provider.model_aliases.get(model, model) | ||||||
|  |             if isinstance(alias, list): | ||||||
|  |                 alias = random.choice(alias) | ||||||
|  |              | ||||||
|  |             debug.log(f"Attempting provider: {provider.__name__} with model: {alias}") | ||||||
|  |             yield ProviderInfo(**provider.get_dict(), model=alias) | ||||||
|  |              | ||||||
|  |             extra_body = kwargs.copy() | ||||||
|  |             current_api_key = api_key.get(provider.get_parent()) if isinstance(api_key, dict) else api_key | ||||||
|  |             if not current_api_key: | ||||||
|  |                 current_api_key = AuthManager.load_api_key(provider) | ||||||
|  |             if current_api_key: | ||||||
|  |                 extra_body["api_key"] = current_api_key | ||||||
|  |             if conversation and hasattr(conversation, provider.__name__): | ||||||
|  |                 extra_body["conversation"] = JsonConversation(**getattr(conversation, provider.__name__)) | ||||||
|  |              | ||||||
|  |             try: | ||||||
|  |                 response = provider.async_create_function(alias, messages, **extra_body) | ||||||
|  |                 started = False | ||||||
|  |                 async for chunk in response: | ||||||
|  |                     if isinstance(chunk, JsonConversation): | ||||||
|  |                         if conversation is None: conversation = JsonConversation() | ||||||
|  |                         setattr(conversation, provider.__name__, chunk.get_dict()) | ||||||
|  |                         yield conversation | ||||||
|  |                     elif chunk: | ||||||
|  |                         yield chunk | ||||||
|  |                         if is_content(chunk): | ||||||
|  |                             started = True | ||||||
|  |                 if started: | ||||||
|  |                     return # Success | ||||||
|  |             except Exception as e: | ||||||
|  |                 exceptions[provider.__name__] = e | ||||||
|  |                 debug.error(f"{provider.__name__} failed: {e}") | ||||||
|  |                  | ||||||
|  |         raise_exceptions(exceptions) | ||||||
|  |  | ||||||
|  |     # Maintain API compatibility | ||||||
|  |     create_function = create_completion | ||||||
|  |     async_create_function = create_async_generator | ||||||
|  |  | ||||||
| class IterListProvider(BaseRetryProvider): | class IterListProvider(BaseRetryProvider): | ||||||
|     def __init__( |     def __init__( | ||||||
|         self, |         self, | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user
	 hlohaus
					hlohaus