Mirror of https://github.com/xtekky/gpt4free.git, synced 2025-09-27 04:36:17 +08:00
fix: resolve model duplication and improve provider handling
- Fixed duplicate model entries in Blackbox provider model_aliases
- Added meta-llama- to llama- name cleaning in Cloudflare provider
- Enhanced PollinationsAI provider with improved vision model detection
- Added reasoning support to PollinationsAI provider
- Fixed HuggingChat authentication to include headers and impersonate
- Removed unused max_inputs_length parameter from HuggingFaceAPI
- Renamed extra_data to extra_body for consistency across providers
- Added Puter provider with grouped model support
- Enhanced AnyProvider with grouped model display and better model organization
- Fixed model cleaning in AnyProvider to handle more model name variations
- Added api_key handling for HuggingFace providers in AnyProvider
- Added see_stream helper function to parse event streams
- Updated GUI server to handle JsonConversation properly
- Fixed aspect ratio handling in image generation functions
- Added ResponsesConfig and ClientResponse for new API endpoint
- Updated requirements to include markitdown
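For orientation, here is a minimal sketch of how the new Responses API introduced in this commit could be called from the async client. It follows the `AsyncResponses.create` and `ClientResponse` code added in the diff below; the model name is illustrative, not something this commit guarantees to exist.

```python
import asyncio
from g4f.client import AsyncClient

async def main():
    client = AsyncClient()
    # `input` may be a plain string or a message list; a non-None
    # `instructions` value is prepended as a developer message.
    response = await client.responses.create(
        input="Explain server-sent event streams in one sentence.",
        model="gpt-4o-mini",  # illustrative model name
        instructions="Answer briefly.",
    )
    # ClientResponse.output holds ResponseMessage objects whose content
    # parts are ResponseMessageContent items carrying output_text.
    for message in response.output:
        for part in message.content:
            print(part.text)

asyncio.run(main())
```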
@@ -278,16 +278,11 @@ class Blackbox(AsyncGeneratorProvider, ProviderModelMixin):
         "qwerky-72b": "Qwerky 72B",
         "qwq-32b": "QwQ 32B",
         "qwq-32b-preview": "QwQ 32B Preview",
-        "qwq-32b": "QwQ 32B Preview",
         "qwq-32b-arliai": "QwQ 32B RpR v1",
-        "qwq-32b": "QwQ 32B RpR v1",
         "deepseek-r1": "R1",
         "deepseek-r1-distill-llama-70b": "R1 Distill Llama 70B",
-        "deepseek-r1": "R1 Distill Llama 70B",
         "deepseek-r1-distill-qwen-14b": "R1 Distill Qwen 14B",
-        "deepseek-r1": "R1 Distill Qwen 14B",
         "deepseek-r1-distill-qwen-32b": "R1 Distill Qwen 32B",
-        "deepseek-r1": "R1 Distill Qwen 32B",
     }

     @classmethod
@@ -60,7 +60,8 @@ class Cloudflare(AsyncGeneratorProvider, ProviderModelMixin, AuthFileMixin):
                     "-int8", "").replace(
                     "-awq", "").replace(
                     "-qvq", "").replace(
-                    "-r1", "")
+                    "-r1", "").replace(
+                    "meta-llama-", "llama-")
                 model_map = {clean_name(model.get("name")): model.get("name") for model in json_data.get("models")}
                 cls.models = list(model_map.keys())
                 cls.model_aliases = {**cls.model_aliases, **model_map}
@@ -13,11 +13,12 @@ from .base_provider import AsyncGeneratorProvider, ProviderModelMixin
 from ..typing import AsyncResult, Messages, MediaListType
 from ..image import is_data_an_audio
 from ..errors import ModelNotFoundError, ResponseError
+from ..requests import see_stream
 from ..requests.raise_for_status import raise_for_status
 from ..requests.aiohttp import get_connector
 from ..image.copy_images import save_response_media
 from ..image import use_aspect_ratio
-from ..providers.response import FinishReason, Usage, ToolCalls, ImageResponse
+from ..providers.response import FinishReason, Usage, ToolCalls, ImageResponse, Reasoning
 from ..tools.media import render_messages
 from .. import debug

@@ -73,7 +74,7 @@ class PollinationsAI(AsyncGeneratorProvider, ProviderModelMixin):
         "mistral-small-3.1-24b": "mistral",
         "deepseek-r1": "deepseek-reasoning-large",
         "deepseek-r1-distill-llama-70b": "deepseek-reasoning-large",
-        "deepseek-r1-distill-llama-70b": "deepseek-r1-llama",
+        #"deepseek-r1-distill-llama-70b": "deepseek-r1-llama",
         #"mistral-small-3.1-24b": "unity", # Personas
         #"mirexa": "mirexa", # Personas
         #"midijourney": "midijourney", # Personas
@@ -90,10 +91,10 @@ class PollinationsAI(AsyncGeneratorProvider, ProviderModelMixin):
         "deepseek-v3": "deepseek",
         "deepseek-v3-0324": "deepseek",
         #"bidara": "bidara", # Personas

         ### Audio Models ###
         "gpt-4o-audio": "openai-audio",

         ### Image Models ###
         "sdxl-turbo": "turbo",
     }
@@ -146,10 +147,18 @@ class PollinationsAI(AsyncGeneratorProvider, ProviderModelMixin):
             for model in models
             if "output_modalities" in model and "audio" in model["output_modalities"] and model.get("name") != "gemini"
         }

         if cls.default_audio_model in cls.audio_models:
             cls.audio_models = {**cls.audio_models, **{voice: {} for voice in cls.audio_models[cls.default_audio_model]}}

+        cls.vision_models.extend([
+            model.get("name")
+            for model in models
+            if model.get("vision") and model not in cls.vision_models
+        ])
+        for alias, model in cls.model_aliases.items():
+            if model in cls.vision_models and alias not in cls.vision_models:
+                cls.vision_models.append(alias)
+
         # Create a set of unique text models starting with default model
         unique_text_models = cls.text_models.copy()

@@ -193,6 +202,7 @@ class PollinationsAI(AsyncGeneratorProvider, ProviderModelMixin):
         proxy: str = None,
         cache: bool = False,
         referrer: str = "https://gpt4free.github.io/",
+        extra_body: dict = {},
         # Image generation parameters
         prompt: str = None,
         aspect_ratio: str = "1:1",
@@ -244,7 +254,8 @@ class PollinationsAI(AsyncGeneratorProvider, ProviderModelMixin):
                 enhance=enhance,
                 safe=safe,
                 n=n,
-                referrer=referrer
+                referrer=referrer,
+                extra_body=extra_body
             ):
                 yield chunk
         else:
@@ -273,6 +284,7 @@ class PollinationsAI(AsyncGeneratorProvider, ProviderModelMixin):
                 stream=stream,
                 extra_parameters=extra_parameters,
                 referrer=referrer,
+                extra_body=extra_body,
                 **kwargs
             ):
                 yield result
@@ -293,18 +305,20 @@ class PollinationsAI(AsyncGeneratorProvider, ProviderModelMixin):
         enhance: bool,
         safe: bool,
         n: int,
-        referrer: str
+        referrer: str,
+        extra_body: dict
     ) -> AsyncResult:
-        params = use_aspect_ratio({
+        extra_body = use_aspect_ratio({
             "width": width,
             "height": height,
             "model": model,
             "nologo": str(nologo).lower(),
             "private": str(private).lower(),
             "enhance": str(enhance).lower(),
-            "safe": str(safe).lower()
+            "safe": str(safe).lower(),
+            **extra_body
         }, aspect_ratio)
-        query = "&".join(f"{k}={quote_plus(str(v))}" for k, v in params.items() if v is not None)
+        query = "&".join(f"{k}={quote_plus(str(v))}" for k, v in extra_body.items() if v is not None)
         prompt = quote_plus(prompt)[:2048-len(cls.image_api_endpoint)-len(query)-8]
         url = f"{cls.image_api_endpoint}prompt/{prompt}?{query}"
         def get_image_url(i: int, seed: Optional[int] = None):
@@ -344,6 +358,7 @@ class PollinationsAI(AsyncGeneratorProvider, ProviderModelMixin):
         stream: bool,
         extra_parameters: list[str],
         referrer: str,
+        extra_body: dict,
         **kwargs
     ) -> AsyncResult:
         if not cache and seed is None:
@@ -357,43 +372,46 @@ class PollinationsAI(AsyncGeneratorProvider, ProviderModelMixin):
                 stream = False
         else:
             url = cls.openai_endpoint
-        extra_parameters = {param: kwargs[param] for param in extra_parameters if param in kwargs}
-        data = filter_none(**{
-            "messages": list(render_messages(messages, media)),
-            "model": model,
-            "temperature": temperature,
-            "presence_penalty": presence_penalty,
-            "top_p": top_p,
-            "frequency_penalty": frequency_penalty,
-            "response_format": response_format,
-            "stream": stream,
-            "seed": seed,
-            "cache": cache,
-            **extra_parameters
-        })
+        extra_body.update({param: kwargs[param] for param in extra_parameters if param in kwargs})
+        data = filter_none(
+            messages=list(render_messages(messages, media)),
+            model=model,
+            temperature=temperature,
+            presence_penalty=presence_penalty,
+            top_p=top_p,
+            frequency_penalty=frequency_penalty,
+            response_format=response_format,
+            stream=stream,
+            seed=seed,
+            cache=cache,
+            **extra_body
+        )
         async with session.post(url, json=data, headers={"referer": referrer}) as response:
             await raise_for_status(response)
             if response.headers["content-type"].startswith("text/plain"):
                 yield await response.text()
                 return
             elif response.headers["content-type"].startswith("text/event-stream"):
-                async for line in response.content:
-                    if line.startswith(b"data: "):
-                        if line[6:].startswith(b"[DONE]"):
-                            break
-                        result = json.loads(line[6:])
-                        if "error" in result:
-                            raise ResponseError(result["error"].get("message", result["error"]))
-                        if result.get("usage") is not None:
-                            yield Usage(**result["usage"])
-                        choices = result.get("choices", [{}])
-                        choice = choices.pop() if choices else {}
-                        content = choice.get("delta", {}).get("content")
-                        if content:
-                            yield content
-                        finish_reason = choice.get("finish_reason")
-                        if finish_reason:
-                            yield FinishReason(finish_reason)
+                reasoning = False
+                async for result in see_stream(response.content):
+                    if "error" in result:
+                        raise ResponseError(result["error"].get("message", result["error"]))
+                    if result.get("usage") is not None:
+                        yield Usage(**result["usage"])
+                    choices = result.get("choices", [{}])
+                    choice = choices.pop() if choices else {}
+                    content = choice.get("delta", {}).get("content")
+                    if content:
+                        yield content
+                    reasoning_content = choice.get("delta", {}).get("reasoning_content")
+                    if reasoning_content:
+                        reasoning = True
+                        yield Reasoning(reasoning_content)
+                    finish_reason = choice.get("finish_reason")
+                    if finish_reason:
+                        yield FinishReason(finish_reason)
+                if reasoning:
+                    yield Reasoning(status="Done")
             elif response.headers["content-type"].startswith("application/json"):
                 result = await response.json()
                 if "choices" in result:
@@ -10,7 +10,7 @@ try:
 except ImportError as e:
     debug.error("Deprecated providers not loaded:", e)
 from .needs_auth import *
-from .template import OpenaiTemplate, BackendApi
+from .template import OpenaiTemplate, BackendApi, Puter
 from .hf import HuggingFace, HuggingChat, HuggingFaceAPI, HuggingFaceInference, HuggingFaceMedia
 from .har import HarProvider
 try:
@@ -12,6 +12,7 @@ from ..helper import get_last_user_message
 from ..openai.har_file import get_headers

 class HarProvider(AsyncGeneratorProvider, ProviderModelMixin):
+    label = "LM Arena"
     url = "https://lmarena.ai"
     working = True
     default_model = "chatgpt-4o-latest-20250326"
@@ -74,8 +74,8 @@ class HuggingChat(AsyncAuthedProvider, ProviderModelMixin):
         if "hf-chat" in cookies:
             yield AuthResult(
                 cookies=cookies,
-                impersonate="chrome",
-                headers=DEFAULT_HEADERS
+                headers=DEFAULT_HEADERS,
+                impersonate="chrome"
             )
             return
         if cls.needs_auth:
@@ -89,9 +89,11 @@ class HuggingChat(AsyncAuthedProvider, ProviderModelMixin):
             )
         else:
             yield AuthResult(
-                cookies = {
+                cookies={
                     "hf-chat": str(uuid.uuid4()) # Generate a session ID
-                }
+                },
+                headers=DEFAULT_HEADERS,
+                impersonate="chrome"
             )

     @classmethod
@@ -78,7 +78,7 @@ class HuggingFaceAPI(OpenaiTemplate):
         api_base: str = None,
         api_key: str = None,
         max_tokens: int = 2048,
-        max_inputs_lenght: int = 10000,
+        # max_inputs_lenght: int = 10000,
         media: MediaListType = None,
         **kwargs
     ):
@@ -117,5 +117,6 @@ class HuggingFaceAPI(OpenaiTemplate):
                     continue
         if error is not None:
             raise error
-def calculate_lenght(messages: Messages) -> int:
-    return sum([len(message["content"]) + 16 for message in messages])
+
+# def calculate_lenght(messages: Messages) -> int:
+#     return sum([len(message["content"]) + 16 for message in messages])
@@ -77,7 +77,7 @@ class HuggingFaceInference(AsyncGeneratorProvider, ProviderModelMixin):
         temperature: float = None,
         prompt: str = None,
         action: str = None,
-        extra_data: dict = {},
+        extra_body: dict = {},
         seed: int = None,
         aspect_ratio: str = None,
         width: int = None,
@@ -94,10 +94,10 @@ class HuggingFaceInference(AsyncGeneratorProvider, ProviderModelMixin):
         }
         if api_key is not None:
             headers["Authorization"] = f"Bearer {api_key}"
-        image_extra_data = use_aspect_ratio({
+        image_extra_body = use_aspect_ratio({
             "width": width,
             "height": height,
-            **extra_data
+            **extra_body
         }, aspect_ratio)
         async with StreamSession(
             headers=headers,
@@ -110,7 +110,7 @@ class HuggingFaceInference(AsyncGeneratorProvider, ProviderModelMixin):
                     "response_format": "url",
                     "prompt": format_image_prompt(messages, prompt),
                     "model": model,
-                    **image_extra_data
+                    **image_extra_body
                 }
                 async with session.post(provider_together_urls[model], json=data) as response:
                     if response.status == 404:
@@ -126,7 +126,7 @@ class HuggingFaceInference(AsyncGeneratorProvider, ProviderModelMixin):
                 "return_full_text": False,
                 "max_new_tokens": max_tokens,
                 "temperature": temperature,
-                **extra_data
+                **extra_body
             }
             do_continue = action == "continue"
             if payload is None:
@@ -135,7 +135,7 @@ class HuggingFaceInference(AsyncGeneratorProvider, ProviderModelMixin):
                 if pipeline_tag == "text-to-image":
                     stream = False
                     inputs = format_image_prompt(messages, prompt)
-                    payload = {"inputs": inputs, "parameters": {"seed": random.randint(0, 2**32) if seed is None else seed, **image_extra_data}}
+                    payload = {"inputs": inputs, "parameters": {"seed": random.randint(0, 2**32) if seed is None else seed, **image_extra_body}}
                 elif pipeline_tag in ("text-generation", "image-text-to-text"):
                     model_type = None
                     if "config" in model_data and "model_type" in model_data["config"]:
@@ -55,9 +55,10 @@ class HuggingFaceMedia(AsyncGeneratorProvider, ProviderModelMixin):
             model["id"]: [
                 provider.get("task")
                 for provider in model.get("inferenceProviderMapping")
-            ].pop()
+            ]
             for model in models
         }
+        cls.task_mapping = {model: task[0] for model, task in cls.task_mapping.items() if task}
         prepend_models = []
         for model, provider_keys in providers.items():
             task = cls.task_mapping.get(model)
@@ -97,7 +98,7 @@ class HuggingFaceMedia(AsyncGeneratorProvider, ProviderModelMixin):
         model: str,
         messages: Messages,
         api_key: str = None,
-        extra_data: dict = {},
+        extra_body: dict = {},
         prompt: str = None,
         proxy: str = None,
         timeout: int = 0,
@@ -128,7 +129,7 @@ class HuggingFaceMedia(AsyncGeneratorProvider, ProviderModelMixin):
             if key in ["replicate", "together", "hf-inference"]
         }
         provider_mapping = {**new_mapping, **provider_mapping}
-        async def generate(extra_data: dict, aspect_ratio: str = None):
+        async def generate(extra_body: dict, aspect_ratio: str = None):
             last_response = None
             for provider_key, provider in provider_mapping.items():
                 if selected_provider is not None and selected_provider != provider_key:
@@ -143,24 +144,24 @@ class HuggingFaceMedia(AsyncGeneratorProvider, ProviderModelMixin):

                 if aspect_ratio is None:
                     aspect_ratio = "1:1" if task == "text-to-image" else "16:9"
-                extra_data_image = use_aspect_ratio({
-                    **extra_data,
+                extra_body_image = use_aspect_ratio({
+                    **extra_body,
                     "height": height,
                     "width": width,
                 }, aspect_ratio)
-                extra_data_video = {}
+                extra_body_video = {}
                 if task == "text-to-video" and provider_key != "novita":
-                    extra_data_video = {
+                    extra_body_video = {
                         "num_inference_steps": 20,
                         "resolution": resolution,
                         "aspect_ratio": aspect_ratio,
-                        **extra_data
+                        **extra_body
                     }
                 url = f"{api_base}/{provider_id}"
                 data = {
                     "prompt": prompt,
                     **{"width": width, "height": height},
-                    **(extra_data_video if task == "text-to-video" else extra_data_image),
+                    **(extra_body_video if task == "text-to-video" else extra_body_image),
                 }
                 if provider_key == "fal-ai" and task == "text-to-image":
                     data = {
@@ -168,7 +169,7 @@ class HuggingFaceMedia(AsyncGeneratorProvider, ProviderModelMixin):
                             "height": height,
                             "width": width,
                         }, aspect_ratio),
-                        **extra_data
+                        **extra_body
                     }
                 elif provider_key == "novita":
                     url = f"{api_base}/v3/hf/{provider_id}"
@@ -213,7 +214,13 @@ class HuggingFaceMedia(AsyncGeneratorProvider, ProviderModelMixin):
                     if "video" in result:
                         return provider_info, VideoResponse(result.get("video").get("url", result.get("video").get("video_url")), prompt)
                     elif task == "text-to-image":
-                        return provider_info, ImageResponse([item["url"] for item in result.get("images", result.get("data"))], prompt)
+                        try:
+                            return provider_info, ImageResponse([
+                                item["url"] if isinstance(item, dict) else item
+                                for item in result.get("images", result.get("data", result.get("output")))
+                            ], prompt)
+                        except:
+                            raise ValueError(f"Unexpected response: {result}")
                     elif task == "text-to-video" and result.get("output") is not None:
                         return provider_info, VideoResponse(result["output"], prompt)
                     raise ValueError(f"Unexpected response: {result}")
@@ -227,7 +234,7 @@ class HuggingFaceMedia(AsyncGeneratorProvider, ProviderModelMixin):
         started = time.time()
         while n > 0:
             n -= 1
-            task = asyncio.create_task(generate(extra_data, aspect_ratio))
+            task = asyncio.create_task(generate(extra_body, aspect_ratio))
             background_tasks.add(task)
             running_tasks.add(task)
             task.add_done_callback(running_tasks.discard)
@@ -6,6 +6,7 @@ import uuid

 from ...typing import AsyncResult, Messages
 from ...providers.response import Reasoning, JsonConversation
+from ...requests.raise_for_status import raise_for_status
 from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin
 from ..helper import get_last_user_message
 from ... import debug
@@ -82,6 +83,7 @@ class Qwen_Qwen_3(AsyncGeneratorProvider, ProviderModelMixin):
         async with aiohttp.ClientSession() as session:
             # Send join request
             async with session.post(cls.api_endpoint, headers=headers_join, json=payload_join) as response:
+                await raise_for_status(response)
                 (await response.json())['event_id']

             # Prepare data stream request
@@ -73,7 +73,7 @@ class Anthropic(OpenaiAPI):
         headers: dict = None,
         impersonate: str = None,
         tools: Optional[list] = None,
-        extra_data: dict = {},
+        extra_body: dict = {},
         **kwargs
     ) -> AsyncResult:
         if api_key is None:
@@ -121,7 +121,7 @@ class Anthropic(OpenaiAPI):
             system=system,
             stream=stream,
             tools=tools,
-            **extra_data
+            **extra_body
         )
         async with session.post(f"{cls.api_base}/messages", json=data) as response:
             await raise_for_status(response)
@@ -89,7 +89,7 @@ class DeepInfra(OpenaiTemplate):
         api_base: str = "https://api.deepinfra.com/v1/inference",
         proxy: str = None,
         timeout: int = 180,
-        extra_data: dict = {},
+        extra_body: dict = {},
         **kwargs
     ) -> ImageResponse:
         headers = {
@@ -115,7 +115,7 @@ class DeepInfra(OpenaiTemplate):
             timeout=timeout
         ) as session:
             model = cls.get_model(model)
-            data = {"prompt": prompt, **extra_data}
+            data = {"prompt": prompt, **extra_body}
             data = {"input": data} if model == cls.default_model else data
             async with session.post(f"{api_base.rstrip('/')}/{model}", json=data) as response:
                 await raise_for_status(response)
@@ -29,7 +29,7 @@ class Replicate(AsyncGeneratorProvider, ProviderModelMixin):
         top_p: float = None,
         top_k: float = None,
         stop: list = None,
-        extra_data: dict = {},
+        extra_body: dict = {},
         headers: dict = {
             "accept": "application/json",
         },
@@ -60,7 +60,7 @@ class Replicate(AsyncGeneratorProvider, ProviderModelMixin):
                     top_k=top_k,
                     stop_sequences=",".join(stop) if stop else None
                 ),
-                **extra_data
+                **extra_body
             },
         }
         url = f"{api_base.rstrip('/')}/{model}/predictions"
@@ -48,4 +48,4 @@ class ThebApi(OpenaiTemplate):
                 top_p=top_p,
             )
         }
-        return super().create_async_generator(model, messages, extra_data=data, **kwargs)
+        return super().create_async_generator(model, messages, extra_body=data, **kwargs)
@@ -66,7 +66,7 @@ class OpenaiTemplate(AsyncGeneratorProvider, ProviderModelMixin, RaiseErrorMixin
         headers: dict = None,
         impersonate: str = None,
         extra_parameters: list[str] = ["tools", "parallel_tool_calls", "tool_choice", "reasoning_effort", "logit_bias", "modalities", "audio"],
-        extra_data: dict = {},
+        extra_body: dict = {},
         **kwargs
     ) -> AsyncResult:
         if api_key is None and cls.api_key is not None:
@@ -107,7 +107,7 @@ class OpenaiTemplate(AsyncGeneratorProvider, ProviderModelMixin, RaiseErrorMixin
             stop=stop,
             stream=stream,
             **extra_parameters,
-            **extra_data
+            **extra_body
         )
         if api_endpoint is None:
             api_endpoint = cls.api_endpoint
g4f/Provider/template/Puter.py (new file, 72 lines)
@@ -0,0 +1,72 @@
+from __future__ import annotations
+
+from ...typing import Messages, AsyncResult
+from ...providers.base_provider import AsyncGeneratorProvider, ProviderModelMixin
+
+class Puter(AsyncGeneratorProvider, ProviderModelMixin):
+    label = "Puter.js AI (live)"
+    working = True
+    models = [
+        {"group": "ChatGPT", "models": [
+            "gpt-4o-mini",
+            "gpt-4o",
+            "gpt-4.1",
+            "gpt-4.1-mini",
+            "gpt-4.1-nano",
+            "gpt-4.5-preview"
+        ]},
+        {"group": "O Models", "models": [
+            "o1",
+            "o1-mini",
+            "o1-pro",
+            "o3",
+            "o3-mini",
+            "o4-mini"
+        ]},
+        {"group": "Anthropic Claude", "models": [
+            "claude-3-7-sonnet",
+            "claude-3-5-sonnet"
+        ]},
+        {"group": "Deepseek", "models": [
+            "deepseek-chat",
+            "deepseek-reasoner"
+        ]},
+        {"group": "Google Gemini", "models": [
+            "gemini-2.0-flash",
+            "gemini-1.5-flash"
+        ]},
+        {"group": "Meta Llama", "models": [
+            "meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo",
+            "meta-llama/Meta-Llama--70B-Instruct-Turbo",
+            "meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo"
+        ]},
+        {"group": "Other Models", "models": [
+            "mistral-large-latest",
+            "pixtral-large-latest",
+            "codestral-latest",
+            "google/gemma-2-27b-it",
+            "grok-beta"
+        ]}
+    ]
+
+    @classmethod
+    def get_grouped_models(cls) -> dict[str, list[str]]:
+        return cls.models
+
+    def get_models(cls) -> list[str]:
+        models = []
+        for model in cls.models:
+            if "models" in model:
+                models.extend(model["models"])
+            else:
+                models.append(model)
+        return models
+
+    @classmethod
+    async def create_async_generator(
+        cls,
+        model: str,
+        messages: Messages,
+        **kwargs
+    ) -> AsyncResult:
+        raise NotImplementedError()
@@ -1,2 +1,3 @@
 from .BackendApi import BackendApi
 from .OpenaiTemplate import OpenaiTemplate
+from .Puter import Puter
@@ -42,7 +42,7 @@ except ImportError:

 import g4f
 import g4f.debug
-from g4f.client import AsyncClient, ChatCompletion, ImagesResponse, convert_to_provider
+from g4f.client import AsyncClient, ChatCompletion, ImagesResponse, ClientResponse, convert_to_provider
 from g4f.providers.response import BaseConversation, JsonConversation
 from g4f.client.helper import filter_none
 from g4f.image import is_data_an_media, EXTENSIONS_MAP
@@ -60,7 +60,8 @@ from .stubs import (
     ProviderResponseModel, ModelResponseModel,
     ErrorResponseModel, ProviderResponseDetailModel,
     FileResponseModel, UploadResponseModel,
-    TranscriptionResponseModel, AudioSpeechConfig
+    TranscriptionResponseModel, AudioSpeechConfig,
+    ResponsesConfig
 )
 from g4f import debug

@@ -257,7 +258,7 @@ class Api:
                 "image": bool(getattr(provider, "image_models", False)),
                 "provider": True,
             } for provider_name, provider in Provider.ProviderUtils.convert.items()
-            if provider.working and provider_name != "Custom"
+            if provider.working and provider_name not in ("Custom", "Puter")
         ]
     }

@@ -324,12 +325,9 @@ class Api:
                 config.api_key = credentials.credentials

             conversation = config.conversation
-            return_conversation = config.return_conversation
             if conversation:
                 conversation = JsonConversation(**conversation)
-                return_conversation = True
             elif config.conversation_id is not None and config.provider is not None:
-                return_conversation = True
                 if config.conversation_id in self.conversations:
                     if config.provider in self.conversations[config.conversation_id]:
                         conversation = self.conversations[config.conversation_id][config.provider]
@@ -359,7 +357,6 @@ class Api:
                         **config.dict(exclude_none=True),
                         **{
                             "conversation_id": None,
-                            "return_conversation": return_conversation,
                             "conversation": conversation
                         }
                     },
@@ -407,6 +404,58 @@ class Api:
         ):
             return await chat_completions(config, credentials, provider)

+        responses = {
+            HTTP_200_OK: {"model": ClientResponse},
+            HTTP_401_UNAUTHORIZED: {"model": ErrorResponseModel},
+            HTTP_404_NOT_FOUND: {"model": ErrorResponseModel},
+            HTTP_422_UNPROCESSABLE_ENTITY: {"model": ErrorResponseModel},
+            HTTP_500_INTERNAL_SERVER_ERROR: {"model": ErrorResponseModel},
+        }
+        @self.app.post("/v1/responses", responses=responses)
+        async def v1_responses(
+            config: ResponsesConfig,
+            credentials: Annotated[HTTPAuthorizationCredentials, Depends(Api.security)] = None,
+            provider: str = None
+        ):
+            try:
+                if config.provider is None:
+                    config.provider = AppConfig.provider if provider is None else provider
+                if credentials is not None and credentials.credentials != "secret":
+                    config.api_key = credentials.credentials
+
+                conversation = None
+                if config.conversation is not None:
+                    conversation = JsonConversation(**config.conversation)
+
+                return await self.client.responses.create(
+                    **filter_none(
+                        **{
+                            "model": AppConfig.model,
+                            "proxy": AppConfig.proxy,
+                            **config.dict(exclude_none=True),
+                            "conversation": conversation
+                        },
+                        ignored=AppConfig.ignored_providers
+                    ),
+                )
+            except (ModelNotFoundError, ProviderNotFoundError) as e:
+                logger.exception(e)
+                return ErrorResponse.from_exception(e, config, HTTP_404_NOT_FOUND)
+            except (MissingAuthError, NoValidHarFileError) as e:
+                logger.exception(e)
+                return ErrorResponse.from_exception(e, config, HTTP_401_UNAUTHORIZED)
+            except Exception as e:
+                logger.exception(e)
+                return ErrorResponse.from_exception(e, config, HTTP_500_INTERNAL_SERVER_ERROR)
+
+        @self.app.post("/api/{provider}/responses", responses=responses)
+        async def provider_responses(
+            provider: str,
+            config: ChatCompletionsConfig,
+            credentials: Annotated[HTTPAuthorizationCredentials, Depends(Api.security)] = None,
+        ):
+            return await v1_responses(config, credentials, provider)
+
         @self.app.post("/api/{provider}/{conversation_id}/chat/completions", responses=responses)
         async def provider_chat_completions(
             provider: str,
@@ -563,7 +612,7 @@ class Api:
                 model=config.model,
                 provider=config.provider if provider is None else provider,
                 prompt=config.input,
-                audio=filter_none(voice=config.voice, format=config.response_format),
+                audio=filter_none(voice=config.voice, format=config.response_format, language=config.language),
                 **filter_none(
                     api_key=api_key,
                 )
@@ -5,16 +5,11 @@ from typing import Union, Optional

 from ..typing import Messages

-class ChatCompletionsConfig(BaseModel):
-    messages: Messages = Field(examples=[[{"role": "system", "content": ""}, {"role": "user", "content": ""}]])
+class RequestConfig(BaseModel):
     model: str = Field(default="")
     provider: Optional[str] = None
-    stream: bool = False
-    image: Optional[str] = None
-    image_name: Optional[str] = None
-    images: Optional[list[tuple[str, str]]] = None
     media: Optional[list[tuple[str, str]]] = None
-    modalities: Optional[list[str]] = ["text", "audio"]
+    modalities: Optional[list[str]] = None
     temperature: Optional[float] = None
     presence_penalty: Optional[float] = None
     frequency_penalty: Optional[float] = None
@@ -25,10 +20,7 @@ class ChatCompletionsConfig(BaseModel):
     api_base: str = None
     web_search: Optional[bool] = None
     proxy: Optional[str] = None
-    conversation_id: Optional[str] = None
     conversation: Optional[dict] = None
-    return_conversation: bool = True
-    history_disabled: Optional[bool] = None
     timeout: Optional[int] = None
     tool_calls: list = Field(default=[], examples=[[
         {
@@ -39,15 +31,26 @@ class ChatCompletionsConfig(BaseModel):
             "type": "function"
         }
     ]])
-    tools: list = None
-    parallel_tool_calls: bool = None
-    tool_choice: Optional[str] = None
     reasoning_effort: Optional[str] = None
     logit_bias: Optional[dict] = None
     modalities: Optional[list[str]] = None
     audio: Optional[dict] = None
     response_format: Optional[dict] = None
-    extra_data: Optional[dict] = None
+    extra_body: Optional[dict] = None

+class ChatCompletionsConfig(RequestConfig):
+    messages: Messages = Field(examples=[[{"role": "system", "content": ""}, {"role": "user", "content": ""}]])
+    stream: bool = False
+    image: Optional[str] = None
+    image_name: Optional[str] = None
+    images: Optional[list[tuple[str, str]]] = None
+    tools: list = None
+    parallel_tool_calls: bool = None
+    tool_choice: Optional[str] = None
+    conversation_id: Optional[str] = None
+
+class ResponsesConfig(RequestConfig):
+    input: Union[Messages, str]
+
 class ImageGenerationConfig(BaseModel):
     prompt: str
@@ -126,4 +129,5 @@ class AudioSpeechConfig(BaseModel):
     provider: Optional[str] = None
     voice: Optional[str] = None
     instrcutions: str = "Speech this text in a natural way."
     response_format: Optional[str] = None
+    language: Optional[str] = None
@@ -19,7 +19,7 @@ from ..providers.asyncio import to_sync_generator
 from ..providers.any_provider import AnyProvider
 from ..Provider import OpenaiAccount, PollinationsImage
 from ..tools.run_tools import async_iter_run_tools, iter_run_tools
-from .stubs import ChatCompletion, ChatCompletionChunk, Image, ImagesResponse, UsageModel, ToolCallModel
+from .stubs import ChatCompletion, ChatCompletionChunk, Image, ImagesResponse, UsageModel, ToolCallModel, ClientResponse
 from .models import ClientModels
 from .types import IterResponse, Client as BaseClient
 from .service import convert_to_provider
@@ -79,7 +79,7 @@ def iter_response(
     for chunk in response:
         if isinstance(chunk, FinishReason):
             finish_reason = chunk.reason
-            break
+            continue
         elif isinstance(chunk, JsonConversation):
             conversation = chunk
             continue
@@ -180,7 +180,7 @@ async def async_iter_response(
     async for chunk in response:
         if isinstance(chunk, FinishReason):
             finish_reason = chunk.reason
-            break
+            continue
        elif isinstance(chunk, JsonConversation):
             conversation = chunk
             continue
@@ -250,6 +250,53 @@ async def async_iter_response(
     finally:
         await safe_aclose(response)

+async def async_response(
+    response: AsyncIterator[Union[str, ResponseType]]
+) -> ClientResponse:
+    content = ""
+    response_id = ''.join(random.choices(string.ascii_letters + string.digits, k=28))
+    idx = 0
+    usage = None
+    provider: ProviderInfo = None
+    conversation: JsonConversation = None
+
+    async for chunk in response:
+        if isinstance(chunk, FinishReason):
+            continue
+        elif isinstance(chunk, JsonConversation):
+            conversation = chunk
+            continue
+        elif isinstance(chunk, ToolCalls):
+            continue
+        elif isinstance(chunk, Usage):
+            usage = chunk
+            continue
+        elif isinstance(chunk, ProviderInfo):
+            provider = chunk
+            continue
+        elif isinstance(chunk, HiddenResponse):
+            continue
+        elif isinstance(chunk, Exception):
+            continue
+
+        content = add_chunk(content, chunk)
+        if not content:
+            continue
+        idx += 1
+
+    if usage is None:
+        usage = UsageModel.model_construct(completion_tokens=idx, total_tokens=idx)
+    else:
+        usage = UsageModel.model_construct(**usage.get_dict())
+
+    response = ClientResponse.model_construct(
+        content, response_id, int(time.time()), usage=usage, conversation=conversation
+    )
+    if provider is not None:
+        response.provider = provider.name
+        response.model = provider.model
+    return response
+
 async def async_iter_append_model_and_provider(
     response: AsyncChatCompletionResponseType,
     last_model: str,
@@ -574,6 +621,7 @@ class AsyncClient(BaseClient):
         self.models: ClientModels = ClientModels(self, provider, media_provider)
         self.images: AsyncImages = AsyncImages(self, media_provider)
         self.media: AsyncImages = self.images
+        self.responses: AsyncResponses = AsyncResponses(self, provider)

 class AsyncChat:
     completions: AsyncCompletions
@@ -673,3 +721,51 @@ class AsyncImages(Images):
         return await self.async_create_variation(
             image, model, provider, response_format, **kwargs
         )
+
+class AsyncResponses():
+    def __init__(self, client: AsyncClient, provider: Optional[ProviderType] = None):
+        self.client: AsyncClient = client
+        self.provider: ProviderType = provider
+
+    async def create(
+        self,
+        input: str,
+        model: str = "",
+        provider: Optional[ProviderType] = None,
+        instructions: Optional[str] = None,
+        proxy: Optional[str] = None,
+        api_key: Optional[str] = None,
+        **kwargs
+    ) -> ClientResponse:
+        if isinstance(input, str):
+            input = [{"role": "user", "content": input}]
+        if instructions is not None:
+            input = [{"role": "developer", "content": instructions}] + input
+        for idx, message in enumerate(input):
+            if isinstance(message["content"], list):
+                for key, value in enumerate(message["content"]):
+                    if isinstance(value, dict) and value.get("type") == "input_text":
+                        message["content"][key] = {"type": "text", "text": value.get("text")}
+                input[idx] = {"role": message["role"], "content": message["content"]}
+        resolve_media(kwargs)
+        if hasattr(model, "name"):
+            model = model.name
+        if provider is None:
+            provider = self.provider
+        if provider is None:
+            provider = AnyProvider
+        if isinstance(provider, str):
+            provider = convert_to_provider(provider)
+
+        response = async_iter_run_tools(
+            provider,
+            model=model,
+            messages=input,
+            **filter_none(
+                proxy=self.client.proxy if proxy is None else proxy,
+                api_key=self.client.api_key if api_key is None else api_key
+            ),
+            **kwargs
+        )
+
+        return await async_response(response)
@@ -108,11 +108,33 @@ class ChatCompletionChunk(BaseModel):
             return conversation.get_dict()
         return conversation

+class ResponseMessage(BaseModel):
+    type: str = "message"
+    role: str
+    content: list[ResponseMessageContent]
+
+    @classmethod
+    def model_construct(cls, content: str):
+        return super().model_construct(role="assistant", content=[ResponseMessageContent.model_construct(content)])
+
+class ResponseMessageContent(BaseModel):
+    type: str
+    text: str
+
+    @classmethod
+    def model_construct(cls, text: str):
+        return super().model_construct(type="output_text", text=text)
+
+    @field_serializer('text')
+    def serialize_text(self, text: str):
+        return str(text)
+
 class ChatCompletionMessage(BaseModel):
     role: str
     content: str
     tool_calls: list[ToolCallModel] = None

+    @classmethod
     @classmethod
     def model_construct(cls, content: str, tool_calls: list = None):
         return super().model_construct(role="assistant", content=content, **filter_none(tool_calls=tool_calls))
@@ -134,6 +156,7 @@ class ChatCompletionMessage(BaseModel):
         with open(filepath, "w") as f:
             f.write(content)

+
 class ChatCompletionChoice(BaseModel):
     index: int
     message: ChatCompletionMessage
@@ -183,6 +206,43 @@ class ChatCompletion(BaseModel):
             return conversation.get_dict()
         return conversation

+class ClientResponse(BaseModel):
+    id: str
+    object: str
+    created_at: int
+    model: str
+    provider: Optional[str]
+    output: list[ResponseMessage]
+    usage: UsageModel
+    conversation: dict
+
+    @classmethod
+    def model_construct(
+        cls,
+        content: str,
+        response_id: str = None,
+        created_at: int = None,
+        usage: UsageModel = None,
+        conversation: dict = None
+    ) -> ClientResponse:
+        return super().model_construct(
+            id=f"resp-{response_id}" if response_id else None,
+            object="response",
+            created_at=created_at,
+            model=None,
+            provider=None,
+            output=[
+                ResponseMessage.model_construct(content),
+            ],
+            **filter_none(usage=usage, conversation=conversation)
+        )
+
+    @field_serializer('conversation')
+    def serialize_conversation(self, conversation: dict):
+        if hasattr(conversation, "get_dict"):
+            return conversation.get_dict()
+        return conversation
+
 class ChatCompletionDelta(BaseModel):
     role: str
     content: Optional[str]
@@ -21,8 +21,6 @@ from ... import debug

 logger = logging.getLogger(__name__)

-conversations: dict[dict[str, BaseConversation]] = {}
-
 class Api:
     @staticmethod
     def get_models():
@@ -42,31 +40,49 @@ class Api:

     @staticmethod
     def get_provider_models(provider: str, api_key: str = None, api_base: str = None, ignored: list = None):
+        def get_model_data(provider: ProviderModelMixin, model: str):
+            return {
+                "model": model,
+                "label": model.split(":")[-1] if provider.__name__ == "AnyProvider" else model,
+                "default": model == provider.default_model,
+                "vision": model in provider.vision_models,
+                "audio": model in provider.audio_models,
+                "video": model in provider.video_models,
+                "image": model in provider.image_models,
+                "count": provider.models_count.get(model),
+            }
         if provider in Provider.__map__:
             provider = Provider.__map__[provider]
             if issubclass(provider, ProviderModelMixin):
+                has_grouped_models = hasattr(provider, "get_grouped_models")
+                method = provider.get_grouped_models if has_grouped_models else provider.get_models
                 if "api_key" in signature(provider.get_models).parameters:
-                    models = provider.get_models(api_key=api_key, api_base=api_base)
+                    models = method(api_key=api_key, api_base=api_base)
                 elif "ignored" in signature(provider.get_models).parameters:
-                    models = provider.get_models(ignored=ignored)
+                    models = method(ignored=ignored)
                 else:
-                    models = provider.get_models()
+                    models = method()
+                if has_grouped_models:
+                    return [{
+                        "group": model["group"],
+                        "models": [get_model_data(provider, name) for name in model["models"]]
+                    } for model in models]
                 return [
-                    {
-                        "model": model,
-                        "default": model == provider.default_model,
-                        "vision": getattr(provider, "default_vision_model", None) == model or model in getattr(provider, "vision_models", []),
-                        "audio": getattr(provider, "default_audio_model", None) == model or model in getattr(provider, "audio_models", []),
-                        "video": getattr(provider, "default_video_model", None) == model or model in getattr(provider, "video_models", []),
-                        "image": False if provider.image_models is None else model in provider.image_models,
-                        "count": getattr(provider, "models_count", {}).get(model),
-                    }
+                    get_model_data(provider, model)
                     for model in models
                 ]
         return []

     @staticmethod
     def get_providers() -> dict[str, str]:
+        def safe_get_models(provider: ProviderModelMixin):
+            if not isinstance(provider, ProviderModelMixin):
+                return True
+            try:
+                return provider.get_models()
+            except Exception as e:
+                logger.exception(e)
+                return True
         return [{
             "name": provider.__name__,
             "label": provider.label if hasattr(provider, "label") else provider.__name__,
@@ -79,7 +95,7 @@ class Api:
|
|||||||
"hf_space": getattr(provider, "hf_space", False),
|
"hf_space": getattr(provider, "hf_space", False),
|
||||||
"auth": provider.needs_auth,
|
"auth": provider.needs_auth,
|
||||||
"login_url": getattr(provider, "login_url", None),
|
"login_url": getattr(provider, "login_url", None),
|
||||||
} for provider in Provider.__providers__ if provider.working]
|
} for provider in Provider.__providers__ if provider.working and safe_get_models(provider)]
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_version() -> dict:
|
def get_version() -> dict:
|
||||||
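
When a provider implements get_grouped_models (currently AnyProvider), the endpoint returns the per-model entries nested under group headers instead of a flat list. Roughly, with illustrative model names and values:

[
    {
        "group": "OpenAI: ChatGPT",
        "models": [
            {"model": "gpt-4o", "label": "gpt-4o", "default": False,
             "vision": True, "audio": False, "video": False,
             "image": False, "count": 3}
        ]
    }
    # ... one entry per group defined in AnyProvider's LABELS
]
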
@@ -121,11 +137,6 @@ class Api:
        conversation = json_data.get("conversation")
        if isinstance(conversation, dict):
            kwargs["conversation"] = JsonConversation(**conversation)
-        else:
-            conversation_id = json_data.get("conversation_id")
-            if conversation_id and provider:
-                if provider in conversations and conversation_id in conversations[provider]:
-                    kwargs["conversation"] = conversations[provider][conversation_id]
        return {
            "model": model,
            "provider": provider,
@@ -168,19 +179,13 @@ class Api:
                if isinstance(chunk, ProviderInfo):
                    yield self.handle_provider(chunk, model)
                    provider = chunk.name
-                elif isinstance(chunk, BaseConversation):
+                elif isinstance(chunk, JsonConversation):
                    if provider is not None:
                        if hasattr(provider, "__name__"):
                            provider = provider.__name__
-                        if provider not in conversations:
-                            conversations[provider] = {}
-                        conversations[provider][conversation_id] = chunk
-                        if isinstance(chunk, JsonConversation):
-                            yield self._format_json("conversation", {
-                                provider: chunk.get_dict()
-                            })
-                        else:
-                            yield self._format_json("conversation_id", conversation_id)
+                        yield self._format_json("conversation", {
+                            provider: chunk.get_dict()
+                        })
                elif isinstance(chunk, Exception):
                    logger.exception(chunk)
                    yield self._format_json('message', get_error_message(chunk), error=type(chunk).__name__)
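
Server-side conversation caching by conversation_id is gone: every JsonConversation chunk is forwarded to the client, which stores it and posts it back as "conversation" on the next request. Assuming _format_json keeps its {"type": ..., <type>: ...} envelope, the emitted event looks roughly like:

{"type": "conversation", "conversation": {"OpenaiChat": {"conversation_id": "..."}}}
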
@@ -285,27 +285,27 @@ def to_input_audio(audio: ImageType, filename: str = None) -> str:
        }
    raise ValueError("Invalid input audio")
 
-def use_aspect_ratio(extra_data: dict, aspect_ratio: str) -> Image:
-    extra_data = {key: value for key, value in extra_data.items() if value is not None}
+def use_aspect_ratio(extra_body: dict, aspect_ratio: str) -> Image:
+    extra_body = {key: value for key, value in extra_body.items() if value is not None}
    if aspect_ratio == "1:1":
-        extra_data = {
+        extra_body = {
            "width": 1024,
            "height": 1024,
-            **extra_data
+            **extra_body
        }
    elif aspect_ratio == "16:9":
-        extra_data = {
+        extra_body = {
            "width": 832,
            "height": 480,
-            **extra_data
+            **extra_body
        }
    elif aspect_ratio == "9:16":
-        extra_data = {
+        extra_body = {
            "width": 480,
            "height": 832,
-            **extra_data
+            **extra_body
        }
-    return extra_data
+    return extra_body
 
 class ImageDataResponse():
    def __init__(
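
Besides the extra_data -> extra_body rename, note the merge order: the per-ratio defaults come first and **extra_body is spread last, so caller-supplied width/height always win. A condensed standalone sketch of the same logic:

def use_aspect_ratio(extra_body: dict, aspect_ratio: str) -> dict:
    # Drop unset values, then fill in default dimensions for the requested ratio.
    extra_body = {key: value for key, value in extra_body.items() if value is not None}
    defaults = {
        "1:1": {"width": 1024, "height": 1024},
        "16:9": {"width": 832, "height": 480},
        "9:16": {"width": 480, "height": 832},
    }.get(aspect_ratio)
    if defaults:
        # Caller values override the defaults because **extra_body comes last.
        extra_body = {**defaults, **extra_body}
    return extra_body

print(use_aspect_ratio({"width": 512, "seed": None}, "1:1"))
# {'width': 512, 'height': 1024}
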
@@ -2,24 +2,103 @@ from __future__ import annotations
 
 from ..typing import AsyncResult, Messages, MediaListType
 from ..errors import ModelNotFoundError
-from ..providers.retry_provider import IterListProvider
 from ..image import is_data_an_audio
+from ..providers.retry_provider import IterListProvider
+from ..providers.types import ProviderType
 from ..providers.response import JsonConversation, ProviderInfo
 from ..Provider.needs_auth import OpenaiChat, CopilotAccount
 from ..Provider.hf_space import HuggingSpace
+from ..Provider import Cloudflare, Gemini, Grok, DeepSeekAPI, PerplexityLabs, LambdaChat, PollinationsAI, FreeRouter
+from ..Provider import Microsoft_Phi_4_Multimodal, DeepInfraChat, Blackbox, EdgeTTS, gTTS, MarkItDown
+from ..Provider import HarProvider, DDG, HuggingFace, HuggingFaceMedia
+from .base_provider import AsyncGeneratorProvider, ProviderModelMixin
 from .. import Provider
 from .. import models
-from ..Provider import Cloudflare, Gemini, Grok, DeepSeekAPI, PerplexityLabs, LambdaChat, PollinationsAI, FreeRouter
-from ..Provider import Microsoft_Phi_4_Multimodal, DeepInfraChat, Blackbox, EdgeTTS, gTTS, MarkItDown, HarProvider
-from .base_provider import AsyncGeneratorProvider, ProviderModelMixin
+from .. import debug
+
+LABELS = {
+    "openai": "OpenAI: ChatGPT",
+    "llama": "Meta: LLaMA",
+    "deepseek": "DeepSeek",
+    "qwen": "Alibaba: Qwen",
+    "google": "Google: Gemini / Gemma",
+    "grok": "xAI: Grok",
+    "claude": "Anthropic: Claude",
+    "command": "Cohere: Command",
+    "phi": "Microsoft: Phi",
+    "PollinationsAI": "Pollinations AI",
+    "perplexity": "Perplexity Labs",
+    "video": "Video Generation",
+    "image": "Image Generation",
+    "other": "Other Models",
+}
 
 class AnyProvider(AsyncGeneratorProvider, ProviderModelMixin):
    default_model = "default"
    working = True
+    models_storage: dict[str, list[str]] = {}
+
+    @classmethod
+    def get_grouped_models(cls, ignored: list[str] = []) -> dict[str, list[str]]:
+        unsorted_models = cls.get_models(ignored=ignored)
+        groups = {key: [] for key in LABELS.keys()}
+        for model in unsorted_models:
+            added = False
+            for group in groups:
+                if group == "qwen":
+                    if model.startswith("qwen") or model.startswith("qwq") or model.startswith("qvq"):
+                        groups[group].append(model)
+                        added = True
+                        break
+                elif group == "perplexity":
+                    if model.startswith("sonar") or model == "r1-1776":
+                        groups[group].append(model)
+                        added = True
+                        break
+                elif group == "google":
+                    if model.startswith("gemini-") or model.startswith("gemma-"):
+                        groups[group].append(model)
+                        added = True
+                        break
+                elif group == "openai":
+                    if model.startswith(
+                        "gpt-") or model.startswith(
+                        "chatgpt-") or model.startswith(
+                        "o1") or model.startswith(
+                        "o3") or model.startswith(
+                        "o4-") or model in ("auto", "dall-e-3", "searchgpt"):
+                        groups[group].append(model)
+                        added = True
+                        break
+                elif model.startswith(group):
+                    groups[group].append(model)
+                    added = True
+                    break
+                elif group == "video":
+                    if model in cls.video_models:
+                        groups[group].append(model)
+                        added = True
+                        break
+                elif group == "image":
+                    if model in cls.image_models:
+                        groups[group].append(model)
+                        added = True
+                        break
+            if not added:
+                if model.startswith("janus"):
+                    groups["deepseek"].append(model)
+                elif model == "meta-ai":
+                    groups["llama"].append(model)
+                else:
+                    groups["other"].append(model)
+        return [
+            {"group": LABELS[group], "models": names} for group, names in groups.items()
+        ]
+
    @classmethod
    def get_models(cls, ignored: list[str] = []) -> list[str]:
-        if not cls.models:
+        ignored_key = " ".join(ignored)
+        if not cls.models_storage.get(ignored_key):
            cls.audio_models = {}
            cls.image_models = []
            cls.vision_models = []
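
get_grouped_models buckets the flat model list by name prefix, with special cases for QwQ/QVQ, sonar/r1-1776, janus, meta-ai, and the video/image model sets, and returns one {group, models} dict per LABELS entry. Illustrative output:

AnyProvider.get_grouped_models()
# [
#     {"group": "OpenAI: ChatGPT", "models": ["auto", "gpt-4o", "o3-mini"]},
#     {"group": "Meta: LLaMA", "models": ["meta-ai", "llama-3.3-70b"]},
#     ...
#     {"group": "Other Models", "models": ["mixtral-8x7b"]},
# ]
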
@@ -27,7 +106,7 @@ class AnyProvider(AsyncGeneratorProvider, ProviderModelMixin):
            model_with_providers = {
                model: [
                    provider for provider in providers
-                    if provider.working and getattr(provider, "parent", provider.__name__) not in ignored
+                    if provider.working and provider.get_parent() not in ignored
                ] for model, (_, providers) in models.__models__.items()
            }
            model_with_providers = {
@@ -38,30 +117,29 @@ class AnyProvider(AsyncGeneratorProvider, ProviderModelMixin):
                model: len(providers) for model, providers in model_with_providers.items() if len(providers) > 1
            }
            all_models = [cls.default_model] + list(model_with_providers.keys())
-            for provider in [OpenaiChat, PollinationsAI, HuggingSpace, Cloudflare, PerplexityLabs, Gemini, Grok]:
-                if not provider.working or getattr(provider, "parent", provider.__name__) in ignored:
+            for provider in [OpenaiChat, CopilotAccount, PollinationsAI, HuggingSpace, Cloudflare, PerplexityLabs, Gemini, Grok, DDG]:
+                provider: ProviderType = provider
+                if not provider.working or provider.get_parent() in ignored:
                    continue
-                if provider == PollinationsAI:
+                if provider == CopilotAccount:
+                    all_models.extend(list(provider.model_aliases.keys()))
+                elif provider == PollinationsAI:
                    all_models.extend([f"{provider.__name__}:{model}" for model in provider.get_models() if model not in all_models])
                    cls.audio_models.update({f"{provider.__name__}:{model}": [] for model in provider.get_models() if model in provider.audio_models})
                    cls.image_models.extend([f"{provider.__name__}:{model}" for model in provider.get_models() if model in provider.image_models])
                    cls.vision_models.extend([f"{provider.__name__}:{model}" for model in provider.get_models() if model in provider.vision_models])
+                    all_models.extend(list(provider.model_aliases.keys()))
                else:
                    all_models.extend(provider.get_models())
                    cls.image_models.extend(provider.image_models)
                    cls.vision_models.extend(provider.vision_models)
                    cls.video_models.extend(provider.video_models)
-            if CopilotAccount.working and CopilotAccount.parent not in ignored:
-                all_models.extend(list(CopilotAccount.model_aliases.keys()))
-            if PollinationsAI.working and PollinationsAI.__name__ not in ignored:
-                all_models.extend(list(PollinationsAI.model_aliases.keys()))
            def clean_name(name: str) -> str:
                return name.split("/")[-1].split(":")[0].lower(
                    ).replace("-instruct", ""
                    ).replace("-chat", ""
                    ).replace("-08-2024", ""
                    ).replace("-03-2025", ""
-                    ).replace("-20250219", ""
                    ).replace("-20241022", ""
                    ).replace("-20240904", ""
                    ).replace("-2025-04-16", ""
@@ -88,11 +166,16 @@ class AnyProvider(AsyncGeneratorProvider, ProviderModelMixin):
                    ).replace("-fp8", ""
                    ).replace("-bf16", ""
                    ).replace("-hf", ""
-                    ).replace("llama3", "llama-3")
-            for provider in [HarProvider, LambdaChat, DeepInfraChat]:
-                if not provider.working or getattr(provider, "parent", provider.__name__) in ignored:
+                    ).replace("flux.1-", "flux-"
+                    ).replace("llama3", "llama-3"
+                    ).replace("meta-llama-", "llama-")
+            for provider in [HarProvider, LambdaChat, DeepInfraChat, HuggingFace, HuggingFaceMedia]:
+                if not provider.working or provider.get_parent() in ignored:
                    continue
-                model_map = {clean_name(model): model for model in provider.get_models()}
+                new_models = provider.get_models()
+                if provider == HuggingFaceMedia:
+                    new_models = provider.video_models
+                model_map = {clean_name(model): model for model in new_models}
                if not provider.model_aliases:
                    provider.model_aliases = {}
                provider.model_aliases.update(model_map)
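
The extended clean_name chain now also folds FLUX.1 and Meta-prefixed Llama ids onto the short aliases, e.g. (illustrative inputs):

clean_name("black-forest-labs/FLUX.1-dev")           # -> "flux-dev"
clean_name("meta-llama/Meta-Llama-3.1-8B-Instruct")  # -> "llama-3.1-8b"
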
@@ -101,11 +184,11 @@ class AnyProvider(AsyncGeneratorProvider, ProviderModelMixin):
                cls.vision_models.extend([clean_name(model) for model in provider.vision_models])
                cls.video_models.extend([clean_name(model) for model in provider.video_models])
            for provider in [Microsoft_Phi_4_Multimodal, PollinationsAI]:
-                if provider.working and getattr(provider, "parent", provider.__name__) not in ignored:
+                if provider.working and provider.get_parent() not in ignored:
                    cls.audio_models.update(provider.audio_models)
            cls.models_count.update({model: all_models.count(model) for model in all_models if all_models.count(model) > cls.models_count.get(model, 0)})
-            cls.models = list(dict.fromkeys([model if model else cls.default_model for model in all_models]))
-        return cls.models
+            cls.models_storage[ignored_key] = list(dict.fromkeys([model if model else cls.default_model for model in all_models]))
+        return cls.models_storage[ignored_key]
 
    @classmethod
    async def create_async_generator(
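
Model lists are now cached per ignored set instead of in a single class-level cls.models, so requests with different ignore lists no longer overwrite each other. The cache key is simply the joined list:

ignored_key = " ".join(["PollinationsAI", "DDG"])  # -> "PollinationsAI DDG"
# the first call computes and stores the list; later calls with the
# same ignored set return AnyProvider.models_storage[ignored_key]
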
@@ -116,6 +199,7 @@ class AnyProvider(AsyncGeneratorProvider, ProviderModelMixin):
        media: MediaListType = None,
        ignored: list[str] = [],
        conversation: JsonConversation = None,
+        api_key: str = None,
        **kwargs
    ) -> AsyncResult:
        cls.get_models(ignored=ignored)
@@ -144,13 +228,13 @@ class AnyProvider(AsyncGeneratorProvider, ProviderModelMixin):
            providers = models.default.best_provider.providers
        elif model in Provider.__map__:
            provider = Provider.__map__[model]
-            if provider.working and getattr(provider, "parent", provider.__name__) not in ignored:
+            if provider.working and provider.get_parent() not in ignored:
                model = None
                providers.append(provider)
        else:
            for provider in [
                OpenaiChat, Cloudflare, HarProvider, PerplexityLabs, Gemini, Grok, DeepSeekAPI, FreeRouter, Blackbox,
-                HuggingSpace, LambdaChat, CopilotAccount, PollinationsAI, DeepInfraChat
+                HuggingSpace, LambdaChat, CopilotAccount, PollinationsAI, DeepInfraChat, DDG, HuggingFace, HuggingFaceMedia,
            ]:
                if provider.working:
                    if not model or model in provider.get_models() or model in provider.model_aliases:
@@ -158,7 +242,8 @@ class AnyProvider(AsyncGeneratorProvider, ProviderModelMixin):
        if model in models.__models__:
            for provider in models.__models__[model][1]:
                providers.append(provider)
-        providers = [provider for provider in providers if provider.working and getattr(provider, "parent", provider.__name__) not in ignored]
+        providers = [provider for provider in providers if provider.working and provider.get_parent() not in ignored]
+        providers = list({provider.__name__: provider for provider in providers}.values())
        if len(providers) == 0:
            raise ModelNotFoundError(f"Model {model} not found in any provider.")
        if len(providers) == 1:
@@ -167,7 +252,10 @@ class AnyProvider(AsyncGeneratorProvider, ProviderModelMixin):
            child_conversation = getattr(conversation, provider.__name__, None)
            if child_conversation is not None:
                kwargs["conversation"] = JsonConversation(**child_conversation)
+            debug.log(f"Using {provider.__name__} provider" + f" and model {model}" if model else "")
            yield ProviderInfo(**provider.get_dict(), model=model)
+            if provider in (HuggingFace, HuggingFaceMedia):
+                kwargs["api_key"] = api_key
            async for chunk in provider.get_async_create_function()(
                model,
                messages,
@@ -183,6 +271,7 @@ class AnyProvider(AsyncGeneratorProvider, ProviderModelMixin):
            else:
                yield chunk
            return
+        kwargs["api_key"] = api_key
        async for chunk in IterListProvider(providers).get_async_create_function()(
            model,
            messages,
@@ -194,4 +283,4 @@ class AnyProvider(AsyncGeneratorProvider, ProviderModelMixin):
 
 setattr(Provider, "AnyProvider", AnyProvider)
 Provider.__map__["AnyProvider"] = AnyProvider
 Provider.__providers__.append(AnyProvider)
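
Since the class registers itself into Provider.__map__, it can be selected like any other provider. A minimal usage sketch (the model name is illustrative; anything AnyProvider can resolve works):

from g4f.client import Client
import g4f.Provider as Provider

client = Client(provider=Provider.AnyProvider)
completion = client.chat.completions.create(
    model="gpt-4o-mini",
    messages=[{"role": "user", "content": "Hello"}],
)
print(completion.choices[0].message.content)
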
@@ -340,6 +340,7 @@ class ProviderModelMixin:
    default_model: str = None
    models: list[str] = []
    model_aliases: dict[str, str] = {}
+    models_count: dict = {}
    image_models: list = []
    vision_models: list = []
    video_models: list = []
@@ -52,7 +52,7 @@ def format_prompt(messages: Messages, add_special_tokens: bool = False, do_conti
    messages = [
        (message["role"], to_string(message["content"]))
        for message in messages
-        if include_system or message.get("role") != "system"
+        if include_system or message.get("role") not in ("developer", "system")
    ]
    formatted = "\n".join([
        f'{role.capitalize()}: {content}'
@@ -64,7 +64,7 @@ def format_prompt(messages: Messages, add_special_tokens: bool = False, do_conti
    return f"{formatted}\nAssistant:"
 
 def get_system_prompt(messages: Messages) -> str:
-    return "\n".join([m["content"] for m in messages if m["role"] == "system"])
+    return "\n".join([m["content"] for m in messages if m["role"] in ("developer", "system")])
 
 def get_last_user_message(messages: Messages) -> str:
    user_messages = []
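
Both helpers now treat OpenAI's newer "developer" role as a system message. A quick check of the updated behavior:

messages = [
    {"role": "developer", "content": "Answer tersely."},
    {"role": "system", "content": "You are a bot."},
    {"role": "user", "content": "Hi"},
]
get_system_prompt(messages)  # -> "Answer tersely.\nYou are a bot."
# format_prompt(messages, include_system=False) now drops the
# "developer" message as well as the "system" one
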
@@ -4,12 +4,12 @@ import random
 
 from ..typing import Type, List, CreateResult, Messages, AsyncResult
 from .types import BaseProvider, BaseRetryProvider, ProviderType
-from .response import MediaResponse, AudioResponse, ProviderInfo
+from .response import MediaResponse, AudioResponse, ProviderInfo, Reasoning
 from .. import debug
-from ..errors import RetryProviderError, RetryNoProviderError
+from ..errors import RetryProviderError, RetryNoProviderError, MissingAuthError, NoValidHarFileError
 
 def is_content(chunk):
-    return isinstance(chunk, (str, MediaResponse, AudioResponse))
+    return isinstance(chunk, (str, MediaResponse, AudioResponse, Reasoning))
 
 class IterListProvider(BaseRetryProvider):
    def __init__(
@@ -29,6 +29,7 @@ class IterListProvider(BaseRetryProvider):
        self.shuffle = shuffle
        self.working = True
        self.last_provider: Type[BaseProvider] = None
+        self.add_api_key = False
 
    def create_completion(
        self,
@@ -37,6 +38,7 @@ class IterListProvider(BaseRetryProvider):
        stream: bool = False,
        ignore_stream: bool = False,
        ignored: list[str] = [],
+        api_key: str = None,
        **kwargs,
    ) -> CreateResult:
        """
@@ -57,6 +59,8 @@ class IterListProvider(BaseRetryProvider):
            self.last_provider = provider
            debug.log(f"Using {provider.__name__} provider")
            yield ProviderInfo(**provider.get_dict(), model=model if model else getattr(provider, "default_model"))
+            if self.add_api_key or provider.__name__ in ["HuggingFace", "HuggingFaceMedia"]:
+                kwargs["api_key"] = api_key
            try:
                response = provider.get_create_function()(model, messages, stream=stream, **kwargs)
                for chunk in response:
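
The gate means a plain provider list only injects the caller's api_key for the HuggingFace pair, while RetryProvider (below) opts every provider in, so a key meant for one backend is not sprayed across unrelated fallbacks. The condition in isolation:

def forwards_key(add_api_key: bool, provider_name: str) -> bool:
    # Mirrors the check added above.
    return add_api_key or provider_name in ["HuggingFace", "HuggingFaceMedia"]

assert forwards_key(False, "HuggingFace")   # implicit lists: HF providers only
assert not forwards_key(False, "Blackbox")
assert forwards_key(True, "Blackbox")       # RetryProvider: every provider
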
@@ -146,6 +150,7 @@ class RetryProvider(IterListProvider):
        super().__init__(providers, shuffle)
        self.single_provider_retry = single_provider_retry
        self.max_retries = max_retries
+        self.add_api_key = True
 
    def create_completion(
        self,
@@ -238,6 +243,9 @@ def raise_exceptions(exceptions: dict) -> None:
        RetryNoProviderError: If no provider is found.
    """
    if exceptions:
+        for provider_name, e in exceptions.items():
+            if isinstance(e, (MissingAuthError, NoValidHarFileError)):
+                raise e
        raise RetryProviderError("RetryProvider failed:\n" + "\n".join([
            f"{p}: {type(exception).__name__}: {exception}" for p, exception in exceptions.items()
        ])) from list(exceptions.values())[0]
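
raise_exceptions now re-raises authentication failures directly instead of burying them in the aggregated RetryProviderError, so callers can react (for example by asking for a key). A toy run using g4f's error classes:

from g4f.errors import MissingAuthError, RetryProviderError

try:
    raise_exceptions({"HuggingFace": MissingAuthError("api_key required")})
except MissingAuthError as e:
    print("auth problem:", e)  # re-raised as-is, not wrapped

try:
    raise_exceptions({"Blackbox": ValueError("bad response")})
except RetryProviderError as e:
    print(e)  # RetryProvider failed:\nBlackbox: ValueError: bad response
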
@@ -56,6 +56,10 @@ class BaseProvider(ABC):
        """
        return {'name': cls.__name__, 'url': cls.url, 'label': getattr(cls, 'label', None)}
 
+    @classmethod
+    def get_parent(cls) -> str:
+        return getattr(cls, "parent", cls.__name__)
+
 class BaseRetryProvider(BaseProvider):
    """
    Base class for a provider that implements retry logic.
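
get_parent centralizes the getattr(provider, "parent", provider.__name__) pattern scattered through AnyProvider, so account variants can be filtered by their parent's name in one check. A sketch, assuming the classes above:

class Copilot(BaseProvider):
    pass

class CopilotAccount(Copilot):
    parent = "Copilot"

Copilot.get_parent()         # -> "Copilot" (falls back to the class name)
CopilotAccount.get_parent()  # -> "Copilot" (explicit parent attribute)
# so `provider.get_parent() not in ignored` skips the whole family at once
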
@@ -3,8 +3,9 @@ from __future__ import annotations
 import os
 import time
 import random
+import json
 from urllib.parse import urlparse
-from typing import Iterator
+from typing import Iterator, AsyncIterator
 from http.cookies import Morsel
 from pathlib import Path
 import asyncio
@@ -198,4 +199,11 @@ async def get_nodriver(
            browser.stop()
    finally:
        lock_file.unlink(missing_ok=True)
    return browser, on_stop
+
+async def see_stream(iter_lines: Iterator[bytes]) -> AsyncIterator[dict]:
+    async for line in iter_lines:
+        if line.startswith(b"data: "):
+            if line[6:].startswith(b"[DONE]"):
+                break
+            yield json.loads(line[6:])
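
see_stream turns a raw SSE byte stream into parsed JSON events and stops at the [DONE] sentinel. A self-contained usage sketch (fake_lines stands in for a response's line iterator):

import asyncio
from g4f.requests import see_stream

async def fake_lines():
    # What an SSE response body yields line by line.
    yield b'data: {"choices": [{"delta": {"content": "Hel"}}]}'
    yield b'data: {"choices": [{"delta": {"content": "lo"}}]}'
    yield b'data: [DONE]'

async def main():
    async for event in see_stream(fake_lines()):
        print(event["choices"][0]["delta"]["content"])  # prints "Hel", then "lo"

asyncio.run(main())
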
@@ -19,4 +19,5 @@ cryptography
 nodriver
 python-multipart
 pypdf2
 python-docx
+markitdown[all]
setup.py
@@ -3,14 +3,16 @@ import os
 
 from setuptools import find_packages, setup
 
+STATIC_HOST = "gpt4free.github.io"
+
 here = os.path.abspath(os.path.dirname(__file__))
 
 with codecs.open(os.path.join(here, 'README.md'), encoding='utf-8') as fh:
    long_description = '\n' + fh.read()
 
 long_description = long_description.replace("[!NOTE]", "")
-long_description = long_description.replace("(docs/images/", "(https://raw.githubusercontent.com/xtekky/gpt4free/refs/heads/main/docs/images/")
-long_description = long_description.replace("(docs/", "(https://github.com/xtekky/gpt4free/blob/main/docs/")
+long_description = long_description.replace("(docs/images/", f"(https://{STATIC_HOST}/docs/images/")
+long_description = long_description.replace("(docs/", f"(https://github.com/gpt4free/{STATIC_HOST}/blob/main/docs/")
 
 INSTALL_REQUIRE = [
    "requests",
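
With STATIC_HOST in place the README image links are rewritten against the static host, e.g.:

"(docs/images/demo.png"   ->   "(https://gpt4free.github.io/docs/images/demo.png"
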
@@ -39,11 +41,10 @@ EXTRA_REQUIRE = {
        "pywebview",
        "plyer",
        "setuptools",
-        "pypdf2", # files
-        "python-docx",
-        "odfpy",
+        "odfpy", # files
        "ebooklib",
        "openpyxl",
+        "markitdown[all]"
    ],
    'slim': [
        "curl_cffi>=0.6.2",
@@ -57,8 +58,7 @@ EXTRA_REQUIRE = {
        "fastapi", # api
        "uvicorn", # api
        "python-multipart",
-        "pypdf2", # files
-        "python-docx",
+        "markitdown[pdf, docx, pptx]"
    ],
    "image": [
        "pillow",
@@ -90,13 +90,11 @@ EXTRA_REQUIRE = {
        "gpt4all"
    ],
    "files": [
-        "spacy",
        "beautifulsoup4",
-        "pypdf2",
-        "python-docx",
        "odfpy",
        "ebooklib",
        "openpyxl",
+        "markitdown[all]"
    ]
 }
 