Add audio transcribing example and support

Add Grok Chat provider
Rename images parameter to media
Update demo homepage
Author: hlohaus
Date: 2025-03-21 03:17:45 +01:00
Parent: 10d32a4c5f
Commit: c97ba0c88e
36 changed files with 407 additions and 300 deletions

View File

@@ -19,6 +19,7 @@ The G4F AsyncClient API is designed to be compatible with the OpenAI API, making
   - [Text Completions](#text-completions)
   - [Streaming Completions](#streaming-completions)
   - [Using a Vision Model](#using-a-vision-model)
+  - **[Transcribing Audio with Chat Completions](#transcribing-audio-with-chat-completions)** *(New Section)*
   - [Image Generation](#image-generation)
   - [Advanced Usage](#advanced-usage)
   - [Conversation Memory](#conversation-memory)
@@ -203,6 +204,54 @@ async def main():
asyncio.run(main())
```

---

### Transcribing Audio with Chat Completions

Some providers in G4F support audio inputs in chat completions, allowing you to transcribe audio files by instructing the model accordingly. This example demonstrates how to use the `AsyncClient` to transcribe an audio file asynchronously:
```python
import asyncio
from g4f.client import AsyncClient
import g4f.Provider
import g4f.models

async def main():
    client = AsyncClient(provider=g4f.Provider.PollinationsAI)  # or g4f.Provider.Microsoft_Phi_4

    with open("audio.wav", "rb") as audio_file:
        response = await client.chat.completions.create(
            model=g4f.models.default,
            messages=[{"role": "user", "content": "Transcribe this audio"}],
            media=[[audio_file, "audio.wav"]],
            modalities=["text"],
        )

    print(response.choices[0].message.content)

if __name__ == "__main__":
    asyncio.run(main())
```
#### Explanation
- **Client Initialization**: An `AsyncClient` instance is created with a provider that supports audio inputs, such as `PollinationsAI` or `Microsoft_Phi_4`.
- **File Handling**: The audio file (`audio.wav`) is opened in binary read mode (`"rb"`) using a context manager (`with` statement) to ensure the file is closed after use.
- **API Call**: The `chat.completions.create` method is called with:
  - `model=g4f.models.default`: Uses the default model for the selected provider.
  - `messages`: A list containing a user message instructing the model to transcribe the audio.
  - `media`: A list of lists, where each inner list contains the file object and its name (`[[audio_file, "audio.wav"]]`).
  - `modalities=["text"]`: Specifies that the output should be text (the transcription).
- **Response**: The transcription is extracted from `response.choices[0].message.content` and printed.

#### Notes
- **Provider Support**: Ensure the chosen provider (e.g., `PollinationsAI` or `Microsoft_Phi_4`) supports audio inputs in chat completions; not all providers offer this functionality.
- **File Path**: Replace `"audio.wav"` with the path to your own audio file. The file format (e.g., WAV) must be compatible with the provider.
- **Model Selection**: If `g4f.models.default` does not support audio transcription, specify a model that does (consult the provider's documentation for supported models).

This example complements the guide by showcasing how to handle audio inputs asynchronously, expanding on the multimodal capabilities of the G4F AsyncClient API.
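
Before sending a request, you can check which of a provider's models accept audio. This is a hedged sketch, assuming the provider populates an `audio_models` list once `get_models()` has been called (as `PollinationsAI` does):

```python
import g4f.Provider

provider = g4f.Provider.PollinationsAI
provider.get_models()  # populate the provider's cached model lists

# `audio_models` lists the models that accept audio input for this provider.
print("Audio-capable models:", provider.audio_models)
```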
---

### Image Generation

**The `response_format` parameter is optional and can have the following values:**
- **If not specified (default):** The image will be saved locally, and a local path will be returned (e.g., "/images/1733331238_cf9d6aa9-f606-4fea-ba4b-f06576cba309.jpg").

View File

@@ -6,10 +6,10 @@ from g4f.models import __models__
 from g4f.providers.base_provider import BaseProvider, ProviderModelMixin
 from g4f.errors import MissingRequirementsError, MissingAuthError

-class TestProviderHasModel(unittest.IsolatedAsyncioTestCase):
+class TestProviderHasModel(unittest.TestCase):
     cache: dict = {}

-    async def test_provider_has_model(self):
+    def test_provider_has_model(self):
         for model, providers in __models__.values():
             for provider in providers:
                 if issubclass(provider, ProviderModelMixin):
@@ -17,9 +17,9 @@ class TestProviderHasModel(unittest.IsolatedAsyncioTestCase):
                         model_name = provider.model_aliases[model.name]
                     else:
                         model_name = model.name
-                    await asyncio.wait_for(self.provider_has_model(provider, model_name), 10)
+                    self.provider_has_model(provider, model_name)

-    async def provider_has_model(self, provider: Type[BaseProvider], model: str):
+    def provider_has_model(self, provider: Type[BaseProvider], model: str):
         if provider.__name__ not in self.cache:
             try:
                 self.cache[provider.__name__] = provider.get_models()
@@ -28,7 +28,7 @@ class TestProviderHasModel(unittest.IsolatedAsyncioTestCase):
         if self.cache[provider.__name__]:
             self.assertIn(model, self.cache[provider.__name__], provider.__name__)

-    async def test_all_providers_working(self):
+    def test_all_providers_working(self):
         for model, providers in __models__.values():
             for provider in providers:
                 self.assertTrue(provider.working, f"{provider.__name__} in {model.name}")

View File

@@ -48,7 +48,7 @@ class AllenAI(AsyncGeneratorProvider, ProviderModelMixin):
"olmo-2-13b": "OLMo-2-1124-13B-Instruct", "olmo-2-13b": "OLMo-2-1124-13B-Instruct",
"tulu-3-1-8b": "tulu-3-1-8b", "tulu-3-1-8b": "tulu-3-1-8b",
"tulu-3-70b": "Llama-3-1-Tulu-3-70B", "tulu-3-70b": "Llama-3-1-Tulu-3-70B",
"llama-3.1-405b": "tulu-3-405b", "llama-3.1-405b": "tulu3-405b",
"llama-3.1-8b": "tulu-3-1-8b", "llama-3.1-8b": "tulu-3-1-8b",
"llama-3.1-70b": "Llama-3-1-Tulu-3-70B", "llama-3.1-70b": "Llama-3-1-Tulu-3-70B",
} }

View File

@@ -11,7 +11,7 @@ from pathlib import Path
 from typing import Optional
 from datetime import datetime, timedelta

-from ..typing import AsyncResult, Messages, ImagesType
+from ..typing import AsyncResult, Messages, MediaListType
 from ..requests.raise_for_status import raise_for_status
 from .base_provider import AsyncGeneratorProvider, ProviderModelMixin
 from ..image import to_data_uri
@@ -444,7 +444,7 @@ class Blackbox(AsyncGeneratorProvider, ProviderModelMixin):
         messages: Messages,
         prompt: str = None,
         proxy: str = None,
-        images: ImagesType = None,
+        media: MediaListType = None,
         top_p: float = None,
         temperature: float = None,
         max_tokens: int = None,
@@ -479,14 +479,14 @@ class Blackbox(AsyncGeneratorProvider, ProviderModelMixin):
                 }
                 current_messages.append(current_msg)

-        if images is not None:
+        if media is not None:
             current_messages[-1]['data'] = {
                 "imagesData": [
                     {
                         "filePath": f"/{image_name}",
                         "contents": to_data_uri(image)
                     }
-                    for image, image_name in images
+                    for image, image_name in media
                 ],
                 "fileText": "",
                 "title": ""

View File

@@ -21,7 +21,7 @@ except ImportError:
 from .base_provider import AbstractProvider, ProviderModelMixin
 from .helper import format_prompt_max_length
 from .openai.har_file import get_headers, get_har_files
-from ..typing import CreateResult, Messages, ImagesType
+from ..typing import CreateResult, Messages, MediaListType
 from ..errors import MissingRequirementsError, NoValidHarFileError, MissingAuthError
 from ..requests.raise_for_status import raise_for_status
 from ..providers.response import BaseConversation, JsonConversation, RequestLogin, Parameters, ImageResponse
@@ -66,7 +66,7 @@ class Copilot(AbstractProvider, ProviderModelMixin):
         proxy: str = None,
         timeout: int = 900,
         prompt: str = None,
-        images: ImagesType = None,
+        media: MediaListType = None,
         conversation: BaseConversation = None,
         return_conversation: bool = False,
         api_key: str = None,
@@ -77,7 +77,7 @@ class Copilot(AbstractProvider, ProviderModelMixin):
         websocket_url = cls.websocket_url
         headers = None
-        if cls.needs_auth or images is not None:
+        if cls.needs_auth or media is not None:
             if api_key is not None:
                 cls._access_token = api_key
             if cls._access_token is None:
@@ -142,8 +142,8 @@ class Copilot(AbstractProvider, ProviderModelMixin):
             debug.log(f"Copilot: Use conversation: {conversation_id}")

         uploaded_images = []
-        if images is not None:
-            for image, _ in images:
+        if media is not None:
+            for image, _ in media:
                 data = to_bytes(image)
                 response = session.post(
                     "https://copilot.microsoft.com/c/api/attachments",

View File

@@ -30,7 +30,7 @@ class DDG(AsyncGeneratorProvider, ProviderModelMixin):
     api_endpoint = "https://duckduckgo.com/duckchat/v1/chat"
     status_url = "https://duckduckgo.com/duckchat/v1/status"

-    working = True
+    working = False
     supports_stream = True
     supports_system_message = True
     supports_message_history = True

View File

@@ -1,6 +1,6 @@
 from __future__ import annotations

-from ..typing import AsyncResult, Messages, ImagesType
+from ..typing import AsyncResult, Messages, MediaListType
 from .template import OpenaiTemplate
 from ..image import to_data_uri
@@ -70,7 +70,6 @@ class DeepInfraChat(OpenaiTemplate):
         temperature: float = 0.7,
         max_tokens: int = None,
         headers: dict = {},
-        images: ImagesType = None,
         **kwargs
     ) -> AsyncResult:
         headers = {
@@ -82,23 +81,6 @@ class DeepInfraChat(OpenaiTemplate):
             **headers
         }
-        if images is not None:
-            if not model or model not in cls.models:
-                model = cls.default_vision_model
-            if messages:
-                last_message = messages[-1].copy()
-                last_message["content"] = [
-                    *[{
-                        "type": "image_url",
-                        "image_url": {"url": to_data_uri(image)}
-                    } for image, _ in images],
-                    {
-                        "type": "text",
-                        "text": last_message["content"]
-                    }
-                ]
-                messages[-1] = last_message
         async for chunk in super().create_async_generator(
             model,
             messages,

View File

@@ -3,13 +3,12 @@ from __future__ import annotations
 import json
 from aiohttp import ClientSession, FormData

-from ..typing import AsyncResult, Messages, ImagesType
+from ..typing import AsyncResult, Messages, MediaListType
 from .base_provider import AsyncGeneratorProvider, ProviderModelMixin
 from ..requests.raise_for_status import raise_for_status
-from ..image import to_data_uri, to_bytes, is_accepted_format
+from ..image import to_bytes, is_accepted_format
 from .helper import format_prompt

 class Dynaspark(AsyncGeneratorProvider, ProviderModelMixin):
     url = "https://dynaspark.onrender.com"
     login_url = None
@@ -38,7 +37,7 @@ class Dynaspark(AsyncGeneratorProvider, ProviderModelMixin):
         model: str,
         messages: Messages,
         proxy: str = None,
-        images: ImagesType = None,
+        media: MediaListType = None,
         **kwargs
     ) -> AsyncResult:
         headers = {
@@ -49,14 +48,13 @@ class Dynaspark(AsyncGeneratorProvider, ProviderModelMixin):
             'user-agent': 'Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/133.0.0.0 Safari/537.36',
             'x-requested-with': 'XMLHttpRequest'
         }
         async with ClientSession(headers=headers) as session:
             form = FormData()
             form.add_field('user_input', format_prompt(messages))
             form.add_field('ai_model', model)
-            if images is not None and len(images) > 0:
-                image, image_name = images[0]
+            if media is not None and len(media) > 0:
+                image, image_name = media[0]
                 image_bytes = to_bytes(image)
                 form.add_field('file', image_bytes, filename=image_name, content_type=is_accepted_format(image_bytes))

View File

@@ -14,6 +14,7 @@ class LambdaChat(HuggingChat):
     default_model = "deepseek-llama3.3-70b"
     reasoning_model = "deepseek-r1"
     image_models = []
+    models = []
     fallback_models = [
         default_model,
         reasoning_model,

View File

@@ -9,8 +9,8 @@ from aiohttp import ClientSession
 from .helper import filter_none, format_image_prompt
 from .base_provider import AsyncGeneratorProvider, ProviderModelMixin
-from ..typing import AsyncResult, Messages, ImagesType
-from ..image import to_data_uri
+from ..typing import AsyncResult, Messages, MediaListType
+from ..image import to_data_uri, is_data_an_audio, to_input_audio
 from ..errors import ModelNotFoundError
 from ..requests.raise_for_status import raise_for_status
 from ..requests.aiohttp import get_connector
@@ -146,17 +146,24 @@ class PollinationsAI(AsyncGeneratorProvider, ProviderModelMixin):
         enhance: bool = False,
         safe: bool = False,
         # Text generation parameters
-        images: ImagesType = None,
+        media: MediaListType = None,
         temperature: float = None,
         presence_penalty: float = None,
         top_p: float = 1,
         frequency_penalty: float = None,
         response_format: Optional[dict] = None,
-        extra_parameters: list[str] = ["tools", "parallel_tool_calls", "tool_choice", "reasoning_effort", "logit_bias", "voice"],
+        extra_parameters: list[str] = ["tools", "parallel_tool_calls", "tool_choice", "reasoning_effort", "logit_bias", "voice", "modalities"],
         **kwargs
     ) -> AsyncResult:
         # Load model list
         cls.get_models()
+        if not model and media is not None:
+            has_audio = False
+            for media_data, filename in media:
+                if is_data_an_audio(media_data, filename):
+                    has_audio = True
+                    break
+            model = next(iter(cls.audio_models)) if has_audio else cls.default_vision_model
         try:
             model = cls.get_model(model)
         except ModelNotFoundError:
@@ -182,7 +189,7 @@ class PollinationsAI(AsyncGeneratorProvider, ProviderModelMixin):
             async for result in cls._generate_text(
                 model=model,
                 messages=messages,
-                images=images,
+                media=media,
                 proxy=proxy,
                 temperature=temperature,
                 presence_penalty=presence_penalty,
@@ -239,7 +246,7 @@ class PollinationsAI(AsyncGeneratorProvider, ProviderModelMixin):
         cls,
         model: str,
         messages: Messages,
-        images: Optional[ImagesType],
+        media: MediaListType,
         proxy: str,
         temperature: float,
         presence_penalty: float,
@@ -258,14 +265,18 @@ class PollinationsAI(AsyncGeneratorProvider, ProviderModelMixin):
         if response_format and response_format.get("type") == "json_object":
             json_mode = True

-        if images and messages:
+        if media and messages:
             last_message = messages[-1].copy()
             image_content = [
                 {
-                    "type": "image_url",
-                    "image_url": {"url": to_data_uri(image)}
+                    "type": "input_audio",
+                    "input_audio": to_input_audio(media_data, filename)
                 }
-                for image, _ in images
+                if is_data_an_audio(media_data, filename) else {
+                    "type": "image_url",
+                    "image_url": {"url": to_data_uri(media_data)}
+                }
+                for media_data, filename in media
             ]
             last_message["content"] = image_content + [{"type": "text", "text": last_message["content"]}]
             messages[-1] = last_message
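
The comprehension above builds OpenAI-style content parts, emitting an `input_audio` part for audio entries and an `image_url` part otherwise. A hedged sketch of the resulting message, assuming `to_input_audio` returns the `{"data": ..., "format": ...}` dict used by OpenAI-compatible APIs:

```python
# Hypothetical result for media=[(wav_bytes, "audio.wav"), (png_bytes, "chart.png")]
last_message = {
    "role": "user",
    "content": [
        {"type": "input_audio", "input_audio": {"data": "<base64 wav>", "format": "wav"}},
        {"type": "image_url", "image_url": {"url": "data:image/png;base64,<base64 png>"}},
        {"type": "text", "text": "Transcribe the audio and describe the chart"},
    ],
}
```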

View File

@@ -17,7 +17,7 @@ except ImportError:
 from ..base_provider import ProviderModelMixin, AsyncAuthedProvider, AuthResult
 from ..helper import format_prompt, format_image_prompt, get_last_user_message
-from ...typing import AsyncResult, Messages, Cookies, ImagesType
+from ...typing import AsyncResult, Messages, Cookies, MediaListType
 from ...errors import MissingRequirementsError, MissingAuthError, ResponseError
 from ...image import to_bytes
 from ...requests import get_args_from_nodriver, DEFAULT_HEADERS
@@ -99,7 +99,7 @@ class HuggingChat(AsyncAuthedProvider, ProviderModelMixin):
         messages: Messages,
         auth_result: AuthResult,
         prompt: str = None,
-        images: ImagesType = None,
+        media: MediaListType = None,
         return_conversation: bool = False,
         conversation: Conversation = None,
         web_search: bool = False,
@@ -108,7 +108,7 @@ class HuggingChat(AsyncAuthedProvider, ProviderModelMixin):
         if not has_curl_cffi:
             raise MissingRequirementsError('Install "curl_cffi" package | pip install -U curl_cffi')
         if model == llama_models["name"]:
-            model = llama_models["text"] if images is None else llama_models["vision"]
+            model = llama_models["text"] if media is None else llama_models["vision"]
         model = cls.get_model(model)
         session = Session(**auth_result.get_dict())
@@ -145,8 +145,8 @@ class HuggingChat(AsyncAuthedProvider, ProviderModelMixin):
         }
         data = CurlMime()
         data.addpart('data', data=json.dumps(settings, separators=(',', ':')))
-        if images is not None:
-            for image, filename in images:
+        if media is not None:
+            for image, filename in media:
                 data.addpart(
                     "files",
                     filename=f"base64;{filename}",

View File

@@ -3,7 +3,7 @@ from __future__ import annotations
 import requests

 from ...providers.types import Messages
-from ...typing import ImagesType
+from ...typing import MediaListType
 from ...requests import StreamSession, raise_for_status
 from ...errors import ModelNotSupportedError
 from ...providers.helper import get_last_user_message
@@ -75,11 +75,11 @@ class HuggingFaceAPI(OpenaiTemplate):
         api_key: str = None,
         max_tokens: int = 2048,
         max_inputs_lenght: int = 10000,
-        images: ImagesType = None,
+        media: MediaListType = None,
         **kwargs
     ):
         if model == llama_models["name"]:
-            model = llama_models["text"] if images is None else llama_models["vision"]
+            model = llama_models["text"] if media is None else llama_models["vision"]
         if model in cls.model_aliases:
             model = cls.model_aliases[model]
         provider_mapping = await cls.get_mapping(model, api_key)
@@ -103,7 +103,7 @@ class HuggingFaceAPI(OpenaiTemplate):
             if len(messages) > 1 and calculate_lenght(messages) > max_inputs_lenght:
                 messages = last_user_message
                 debug.log(f"Messages trimmed from: {start} to: {calculate_lenght(messages)}")
-        async for chunk in super().create_async_generator(model, messages, api_base=api_base, api_key=api_key, max_tokens=max_tokens, images=images, **kwargs):
+        async for chunk in super().create_async_generator(model, messages, api_base=api_base, api_key=api_key, max_tokens=max_tokens, media=media, **kwargs):
             yield chunk

 def calculate_lenght(messages: Messages) -> int:

View File

@@ -7,7 +7,7 @@ import random
 from datetime import datetime, timezone, timedelta
 import urllib.parse

-from ...typing import AsyncResult, Messages, Cookies, ImagesType
+from ...typing import AsyncResult, Messages, Cookies, MediaListType
 from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin
 from ..helper import format_prompt, format_image_prompt
 from ...providers.response import JsonConversation, ImageResponse, Reasoning
@@ -68,7 +68,7 @@ class DeepseekAI_JanusPro7b(AsyncGeneratorProvider, ProviderModelMixin):
         cls,
         model: str,
         messages: Messages,
-        images: ImagesType = None,
+        media: MediaListType = None,
         prompt: str = None,
         proxy: str = None,
         cookies: Cookies = None,
@@ -98,27 +98,27 @@ class DeepseekAI_JanusPro7b(AsyncGeneratorProvider, ProviderModelMixin):
             if return_conversation:
                 yield conversation

-            if images is not None:
+            if media is not None:
                 data = FormData()
-                for i in range(len(images)):
-                    images[i] = (to_bytes(images[i][0]), images[i][1])
-                for image, image_name in images:
+                for i in range(len(media)):
+                    media[i] = (to_bytes(media[i][0]), media[i][1])
+                for image, image_name in media:
                     data.add_field(f"files", image, filename=image_name)
                 async with session.post(f"{cls.api_url}/gradio_api/upload", params={"upload_id": session_hash}, data=data) as response:
                     await raise_for_status(response)
                     image_files = await response.json()
-                images = [{
+                media = [{
                     "path": image_file,
                     "url": f"{cls.api_url}/gradio_api/file={image_file}",
-                    "orig_name": images[i][1],
-                    "size": len(images[i][0]),
-                    "mime_type": is_accepted_format(images[i][0]),
+                    "orig_name": media[i][1],
+                    "size": len(media[i][0]),
+                    "mime_type": is_accepted_format(media[i][0]),
                     "meta": {
                         "_type": "gradio.FileData"
                     }
                 } for i, image_file in enumerate(image_files)]

-            async with cls.run(method, session, prompt, conversation, None if images is None else images.pop(), seed) as response:
+            async with cls.run(method, session, prompt, conversation, None if media is None else media.pop(), seed) as response:
                 await raise_for_status(response)

             async with cls.run("get", session, prompt, conversation, None, seed) as response:

View File

@@ -3,13 +3,13 @@ from __future__ import annotations
 import json
 import uuid

-from ...typing import AsyncResult, Messages, Cookies, ImagesType
+from ...typing import AsyncResult, Messages, Cookies, MediaListType
 from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin
 from ..helper import format_prompt, format_image_prompt
 from ...providers.response import JsonConversation
 from ...requests.aiohttp import StreamSession, StreamResponse, FormData
 from ...requests.raise_for_status import raise_for_status
-from ...image import to_bytes, is_accepted_format, is_data_an_wav
+from ...image import to_bytes, is_accepted_format, is_data_an_audio
 from ...errors import ResponseError
 from ... import debug
 from .DeepseekAI_JanusPro7b import get_zerogpu_token
@@ -32,7 +32,7 @@ class Microsoft_Phi_4(AsyncGeneratorProvider, ProviderModelMixin):
     models = [default_model]

     @classmethod
-    def run(cls, method: str, session: StreamSession, prompt: str, conversation: JsonConversation, images: list = None):
+    def run(cls, method: str, session: StreamSession, prompt: str, conversation: JsonConversation, media: list = None):
         headers = {
             "content-type": "application/json",
             "x-zerogpu-token": conversation.zerogpu_token,
@@ -47,7 +47,7 @@ class Microsoft_Phi_4(AsyncGeneratorProvider, ProviderModelMixin):
                 [],
                 {
                     "text": prompt,
-                    "files": images,
+                    "files": media,
                 },
                 None
             ],
@@ -70,7 +70,7 @@ class Microsoft_Phi_4(AsyncGeneratorProvider, ProviderModelMixin):
                 {
                     "role": "user",
                     "content": {"file": image}
-                } for image in images
+                } for image in media
             ]],
             "event_data": None,
             "fn_index": 11,
@@ -91,7 +91,7 @@ class Microsoft_Phi_4(AsyncGeneratorProvider, ProviderModelMixin):
         cls,
         model: str,
         messages: Messages,
-        images: ImagesType = None,
+        media: MediaListType = None,
         prompt: str = None,
         proxy: str = None,
         cookies: Cookies = None,
@@ -115,23 +115,23 @@ class Microsoft_Phi_4(AsyncGeneratorProvider, ProviderModelMixin):
             if return_conversation:
                 yield conversation

-            if images is not None:
+            if media is not None:
                 data = FormData()
-                mime_types = [None for i in range(len(images))]
-                for i in range(len(images)):
-                    mime_types[i] = is_data_an_wav(images[i][0], images[i][1])
-                    images[i] = (to_bytes(images[i][0]), images[i][1])
-                    mime_types[i] = is_accepted_format(images[i][0]) if mime_types[i] is None else mime_types[i]
-                for image, image_name in images:
+                mime_types = [None for i in range(len(media))]
+                for i in range(len(media)):
+                    mime_types[i] = is_data_an_audio(media[i][0], media[i][1])
+                    media[i] = (to_bytes(media[i][0]), media[i][1])
+                    mime_types[i] = is_accepted_format(media[i][0]) if mime_types[i] is None else mime_types[i]
+                for image, image_name in media:
                     data.add_field(f"files", to_bytes(image), filename=image_name)
                 async with session.post(f"{cls.api_url}/gradio_api/upload", params={"upload_id": session_hash}, data=data) as response:
                     await raise_for_status(response)
                     image_files = await response.json()
-                images = [{
+                media = [{
                     "path": image_file,
                     "url": f"{cls.api_url}/gradio_api/file={image_file}",
-                    "orig_name": images[i][1],
-                    "size": len(images[i][0]),
+                    "orig_name": media[i][1],
+                    "size": len(media[i][0]),
                     "mime_type": mime_types[i],
                     "meta": {
                         "_type": "gradio.FileData"
@@ -139,10 +139,10 @@ class Microsoft_Phi_4(AsyncGeneratorProvider, ProviderModelMixin):
                 } for i, image_file in enumerate(image_files)]

-            async with cls.run("predict", session, prompt, conversation, images) as response:
+            async with cls.run("predict", session, prompt, conversation, media) as response:
                 await raise_for_status(response)
-            async with cls.run("post", session, prompt, conversation, images) as response:
+            async with cls.run("post", session, prompt, conversation, media) as response:
                 await raise_for_status(response)
             async with cls.run("get", session, prompt, conversation) as response:

View File

@@ -3,7 +3,7 @@ from __future__ import annotations
 import json
 from aiohttp import ClientSession, FormData

-from ...typing import AsyncResult, Messages, ImagesType
+from ...typing import AsyncResult, Messages, MediaListType
 from ...requests import raise_for_status
 from ...errors import ResponseError
 from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin
@@ -25,7 +25,7 @@ class Qwen_QVQ_72B(AsyncGeneratorProvider, ProviderModelMixin):
     @classmethod
     async def create_async_generator(
         cls, model: str, messages: Messages,
-        images: ImagesType = None,
+        media: MediaListType = None,
         api_key: str = None,
         proxy: str = None,
         **kwargs
@@ -36,10 +36,10 @@ class Qwen_QVQ_72B(AsyncGeneratorProvider, ProviderModelMixin):
         if api_key is not None:
             headers["Authorization"] = f"Bearer {api_key}"
         async with ClientSession(headers=headers) as session:
-            if images:
+            if media:
                 data = FormData()
-                data_bytes = to_bytes(images[0][0])
-                data.add_field("files", data_bytes, content_type=is_accepted_format(data_bytes), filename=images[0][1])
+                data_bytes = to_bytes(media[0][0])
+                data.add_field("files", data_bytes, content_type=is_accepted_format(data_bytes), filename=media[0][1])
                 url = f"{cls.url}/gradio_api/upload?upload_id={get_random_string()}"
                 async with session.post(url, data=data, proxy=proxy) as response:
                     await raise_for_status(response)

View File

@@ -2,7 +2,7 @@ from __future__ import annotations
 import random

-from ...typing import AsyncResult, Messages, ImagesType
+from ...typing import AsyncResult, Messages, MediaListType
 from ...errors import ResponseError
 from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin
@@ -67,15 +67,15 @@ class HuggingSpace(AsyncGeneratorProvider, ProviderModelMixin):
     @classmethod
     async def create_async_generator(
-        cls, model: str, messages: Messages, images: ImagesType = None, **kwargs
+        cls, model: str, messages: Messages, media: MediaListType = None, **kwargs
     ) -> AsyncResult:
-        if not model and images is not None:
+        if not model and media is not None:
             model = cls.default_vision_model
         is_started = False
         random.shuffle(cls.providers)
         for provider in cls.providers:
             if model in provider.model_aliases:
-                async for chunk in provider.create_async_generator(provider.model_aliases[model], messages, images=images, **kwargs):
+                async for chunk in provider.create_async_generator(provider.model_aliases[model], messages, media=media, **kwargs):
                     is_started = True
                     yield chunk
                 if is_started:

View File

@@ -6,7 +6,7 @@ import base64
 from typing import Optional

 from ..helper import filter_none
-from ...typing import AsyncResult, Messages, ImagesType
+from ...typing import AsyncResult, Messages, MediaListType
 from ...requests import StreamSession, raise_for_status
 from ...providers.response import FinishReason, ToolCalls, Usage
 from ...errors import MissingAuthError
@@ -62,7 +62,7 @@ class Anthropic(OpenaiAPI):
         messages: Messages,
         proxy: str = None,
         timeout: int = 120,
-        images: ImagesType = None,
+        media: MediaListType = None,
         api_key: str = None,
         temperature: float = None,
         max_tokens: int = 4096,
@@ -79,9 +79,9 @@ class Anthropic(OpenaiAPI):
         if api_key is None:
             raise MissingAuthError('Add a "api_key"')

-        if images is not None:
+        if media is not None:
             insert_images = []
-            for image, _ in images:
+            for image, _ in media:
                 data = to_bytes(image)
                 insert_images.append({
                     "type": "image",

View File

@@ -19,7 +19,7 @@ except ImportError:
     has_nodriver = False

 from ... import debug
-from ...typing import Messages, Cookies, ImagesType, AsyncResult, AsyncIterator
+from ...typing import Messages, Cookies, MediaListType, AsyncResult, AsyncIterator
 from ...providers.response import JsonConversation, Reasoning, RequestLogin, ImageResponse, YouTube
 from ...requests.raise_for_status import raise_for_status
 from ...requests.aiohttp import get_connector
@@ -149,7 +149,7 @@ class Gemini(AsyncGeneratorProvider, ProviderModelMixin):
         proxy: str = None,
         cookies: Cookies = None,
         connector: BaseConnector = None,
-        images: ImagesType = None,
+        media: MediaListType = None,
         return_conversation: bool = False,
         conversation: Conversation = None,
         language: str = "en",
@@ -186,7 +186,7 @@ class Gemini(AsyncGeneratorProvider, ProviderModelMixin):
                 cls.start_auto_refresh()
             )

-        images = await cls.upload_images(base_connector, images) if images else None
+        uploads = None if media is None else await cls.upload_images(base_connector, media)
         async with ClientSession(
             cookies=cls._cookies,
             headers=REQUEST_HEADERS,
@@ -205,7 +205,7 @@ class Gemini(AsyncGeneratorProvider, ProviderModelMixin):
                     prompt,
                     language=language,
                     conversation=conversation,
-                    images=images
+                    uploads=uploads
                 ))])
             }
             async with client.post(
@@ -327,10 +327,10 @@ class Gemini(AsyncGeneratorProvider, ProviderModelMixin):
         prompt: str,
         language: str,
         conversation: Conversation = None,
-        images: list[list[str, str]] = None,
+        uploads: list[list[str, str]] = None,
         tools: list[list[str]] = []
     ) -> list:
-        image_list = [[[image_url, 1], image_name] for image_url, image_name in images] if images else []
+        image_list = [[[image_url, 1], image_name] for image_url, image_name in uploads] if uploads else []
         return [
             [prompt, 0, None, image_list, None, None, 0],
             [language],
@@ -353,7 +353,7 @@ class Gemini(AsyncGeneratorProvider, ProviderModelMixin):
             0,
         ]

-    async def upload_images(connector: BaseConnector, images: ImagesType) -> list:
+    async def upload_images(connector: BaseConnector, media: MediaListType) -> list:
         async def upload_image(image: bytes, image_name: str = None):
             async with ClientSession(
                 headers=UPLOAD_IMAGE_HEADERS,
@@ -385,7 +385,7 @@ class Gemini(AsyncGeneratorProvider, ProviderModelMixin):
                 ) as response:
                     await raise_for_status(response)
                     return [await response.text(), image_name]
-        return await asyncio.gather(*[upload_image(image, image_name) for image, image_name in images])
+        return await asyncio.gather(*[upload_image(image, image_name) for image, image_name in media])

     @classmethod
     async def fetch_snlm0e(cls, session: ClientSession, cookies: Cookies):

View File

@@ -6,8 +6,8 @@ import requests
 from typing import Optional
 from aiohttp import ClientSession, BaseConnector

-from ...typing import AsyncResult, Messages, ImagesType
-from ...image import to_bytes, is_accepted_format
+from ...typing import AsyncResult, Messages, MediaListType
+from ...image import to_bytes, is_data_an_media
 from ...errors import MissingAuthError
 from ...requests.raise_for_status import raise_for_status
 from ...providers.response import Usage, FinishReason
@@ -67,7 +67,7 @@ class GeminiPro(AsyncGeneratorProvider, ProviderModelMixin):
         api_key: str = None,
         api_base: str = api_base,
         use_auth_header: bool = False,
-        images: ImagesType = None,
+        media: MediaListType = None,
         tools: Optional[list] = None,
         connector: BaseConnector = None,
         **kwargs
@@ -94,13 +94,13 @@ class GeminiPro(AsyncGeneratorProvider, ProviderModelMixin):
             for message in messages
             if message["role"] != "system"
         ]
-        if images is not None:
-            for image, _ in images:
-                image = to_bytes(image)
+        if media is not None:
+            for media_data, filename in media:
+                media_data = to_bytes(media_data)
                 contents[-1]["parts"].append({
                     "inline_data": {
-                        "mime_type": is_accepted_format(image),
-                        "data": base64.b64encode(image).decode()
+                        "mime_type": is_data_an_media(media_data, filename),
+                        "data": base64.b64encode(media_data).decode()
                     }
                 })
         data = {

View File

@@ -1,51 +1,56 @@
 from __future__ import annotations

 import os
 import json
-import random
-import re
-import base64
-import asyncio
 import time
-from urllib.parse import quote_plus, unquote_plus
-from pathlib import Path
-from aiohttp import ClientSession, BaseConnector
-from typing import Dict, Any, Optional, AsyncIterator, List
+from typing import Dict, Any, AsyncIterator

-from ... import debug
-from ...typing import Messages, Cookies, ImagesType, AsyncResult
-from ...providers.response import JsonConversation, Reasoning, ImagePreview, ImageResponse, TitleGeneration
+from ...typing import Messages, Cookies, AsyncResult
+from ...providers.response import JsonConversation, Reasoning, ImagePreview, ImageResponse, TitleGeneration, AuthResult, RequestLogin
+from ...requests import StreamSession, get_args_from_nodriver, DEFAULT_HEADERS
 from ...requests.raise_for_status import raise_for_status
-from ...requests.aiohttp import get_connector
-from ...requests import get_nodriver
-from ...errors import MissingAuthError
-from ...cookies import get_cookies_dir
-from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin
+from ..base_provider import AsyncAuthedProvider, ProviderModelMixin
 from ..helper import format_prompt, get_cookies, get_last_user_message

 class Conversation(JsonConversation):
     def __init__(self,
-        conversation_id: str,
-        response_id: str,
-        choice_id: str,
-        model: str
+        conversation_id: str
     ) -> None:
         self.conversation_id = conversation_id
-        self.response_id = response_id
-        self.choice_id = choice_id
-        self.model = model

-class Grok(AsyncGeneratorProvider, ProviderModelMixin):
+class Grok(AsyncAuthedProvider, ProviderModelMixin):
     label = "Grok AI"
     url = "https://grok.com"
+    cookie_domain = ".grok.com"
     assets_url = "https://assets.grok.com"
     conversation_url = "https://grok.com/rest/app-chat/conversations"

     needs_auth = True
-    working = False
+    working = True

     default_model = "grok-3"
     models = [default_model, "grok-3-thinking", "grok-2"]

-    _cookies: Cookies = None
+    @classmethod
+    async def on_auth_async(cls, cookies: Cookies = None, proxy: str = None, **kwargs) -> AsyncIterator:
+        if cookies is None:
+            cookies = get_cookies(cls.cookie_domain, False, True, False)
+        if cookies is not None and "sso" in cookies:
+            yield AuthResult(
+                cookies=cookies,
+                impersonate="chrome",
+                proxy=proxy,
+                headers=DEFAULT_HEADERS
+            )
+            return
+        yield RequestLogin(cls.__name__, os.environ.get("G4F_LOGIN_URL") or "")
+        yield AuthResult(
+            **await get_args_from_nodriver(
+                cls.url,
+                proxy=proxy,
+                wait_for='[href="/chat#private"]'
+            )
+        )

     @classmethod
     async def _prepare_payload(cls, model: str, message: str) -> Dict[str, Any]:
@@ -72,63 +77,43 @@ class Grok(AsyncGeneratorProvider, ProviderModelMixin):
         }

     @classmethod
-    async def create_async_generator(
+    async def create_authed(
         cls,
         model: str,
         messages: Messages,
-        proxy: str = None,
+        auth_result: AuthResult,
         cookies: Cookies = None,
-        connector: BaseConnector = None,
-        images: ImagesType = None,
         return_conversation: bool = False,
-        conversation: Optional[Conversation] = None,
+        conversation: Conversation = None,
         **kwargs
     ) -> AsyncResult:
-        cls._cookies = cookies or cls._cookies or get_cookies(".grok.com", False, True)
-        if not cls._cookies:
-            raise MissingAuthError("Missing required cookies")
-
-        prompt = format_prompt(messages) if conversation is None else get_last_user_message(messages)
-        base_connector = get_connector(connector, proxy)
-
-        headers = {
-            "accept": "*/*",
-            "accept-language": "en-GB,en;q=0.9",
-            "content-type": "application/json",
-            "origin": "https://grok.com",
-            "priority": "u=1, i",
-            "referer": "https://grok.com/",
-            "sec-ch-ua": '"Not/A)Brand";v="8", "Chromium";v="126", "Brave";v="126"',
-            "sec-ch-ua-mobile": "?0",
-            "sec-ch-ua-platform": '"macOS"',
-            "sec-fetch-dest": "empty",
-            "sec-fetch-mode": "cors",
-            "sec-fetch-site": "same-origin",
-            "sec-gpc": "1",
-            "user-agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/126.0.0.0 Safari/537.36"
-        }
-
-        async with ClientSession(
-            headers=headers,
-            cookies=cls._cookies,
-            connector=base_connector
+        conversation_id = None if conversation is None else conversation.conversation_id
+        prompt = format_prompt(messages) if conversation_id is None else get_last_user_message(messages)
+        async with StreamSession(
+            **auth_result.get_dict()
         ) as session:
             payload = await cls._prepare_payload(model, prompt)
-            response = await session.post(f"{cls.conversation_url}/new", json=payload)
+            if conversation_id is None:
+                url = f"{cls.conversation_url}/new"
+            else:
+                url = f"{cls.conversation_url}/{conversation_id}/responses"
+            async with session.post(url, json=payload) as response:
                 await raise_for_status(response)

                 thinking_duration = None
-                async for line in response.content:
+                async for line in response.iter_lines():
                     if line:
                         try:
                             json_data = json.loads(line)
                             result = json_data.get("result", {})
+                            if conversation_id is None:
+                                conversation_id = result.get("conversation", {}).get("conversationId")
                             response_data = result.get("response", {})
                             image = response_data.get("streamingImageGenerationResponse", None)
                             if image is not None:
                                 yield ImagePreview(f'{cls.assets_url}/{image["imageUrl"]}', "", {"cookies": cookies, "headers": headers})
-                            token = response_data.get("token", "")
-                            is_thinking = response_data.get("isThinking", False)
+                            token = response_data.get("token", result.get("token"))
+                            is_thinking = response_data.get("isThinking", result.get("isThinking"))
                             if token:
                                 if is_thinking:
                                     if thinking_duration is None:
@@ -145,9 +130,12 @@ class Grok(AsyncGeneratorProvider, ProviderModelMixin):
                             generated_images = response_data.get("modelResponse", {}).get("generatedImageUrls", None)
                             if generated_images:
                                 yield ImageResponse([f'{cls.assets_url}/{image}' for image in generated_images], "", {"cookies": cookies, "headers": headers})
-                            title = response_data.get("title", {}).get("newTitle", "")
+                            title = result.get("title", {}).get("newTitle", "")
                             if title:
                                 yield TitleGeneration(title)
                         except json.JSONDecodeError:
                             continue
+                if return_conversation and conversation_id is not None:
+                    yield Conversation(conversation_id)
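
With the move to `AsyncAuthedProvider`, authentication is resolved once (from an `sso` cookie on `.grok.com`, or via an interactive nodriver login) and reused across calls. A hedged usage sketch:

```python
from g4f.client import Client
import g4f.Provider

# Assumes valid grok.com cookies exist in a local browser profile;
# otherwise the provider falls back to an interactive login window.
client = Client(provider=g4f.Provider.Grok)
response = client.chat.completions.create(
    model="grok-3",
    messages=[{"role": "user", "content": "Hello, Grok!"}],
)
print(response.choices[0].message.content)
```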

View File

@@ -18,7 +18,7 @@ except ImportError:
     has_nodriver = False

 from ..base_provider import AsyncAuthedProvider, ProviderModelMixin
-from ...typing import AsyncResult, Messages, Cookies, ImagesType
+from ...typing import AsyncResult, Messages, Cookies, MediaListType
 from ...requests.raise_for_status import raise_for_status
 from ...requests import StreamSession
 from ...requests import get_nodriver
@@ -127,7 +127,7 @@ class OpenaiChat(AsyncAuthedProvider, ProviderModelMixin):
         cls,
         session: StreamSession,
         auth_result: AuthResult,
-        images: ImagesType,
+        media: MediaListType,
     ) -> ImageRequest:
         """
         Upload an image to the service and get the download URL
@@ -135,7 +135,7 @@ class OpenaiChat(AsyncAuthedProvider, ProviderModelMixin):
         Args:
             session: The StreamSession object to use for requests
             headers: The headers to include in the requests
-            images: The images to upload, either a PIL Image object or a bytes object
+            media: The images to upload, either a PIL Image object or a bytes object

         Returns:
             An ImageRequest object that contains the download URL, file name, and other data
@@ -187,9 +187,9 @@ class OpenaiChat(AsyncAuthedProvider, ProviderModelMixin):
             await raise_for_status(response, "Get download url failed")
             image_data["download_url"] = (await response.json())["download_url"]
             return ImageRequest(image_data)
-        if not images:
+        if not media:
             return
-        return [await upload_image(image, image_name) for image, image_name in images]
+        return [await upload_image(image, image_name) for image, image_name in media]

     @classmethod
     def create_messages(cls, messages: Messages, image_requests: ImageRequest = None, system_hints: list = None):
@@ -268,7 +268,7 @@ class OpenaiChat(AsyncAuthedProvider, ProviderModelMixin):
         auto_continue: bool = False,
         action: str = "next",
         conversation: Conversation = None,
-        images: ImagesType = None,
+        media: MediaListType = None,
         return_conversation: bool = False,
         web_search: bool = False,
         **kwargs
@@ -285,7 +285,7 @@ class OpenaiChat(AsyncAuthedProvider, ProviderModelMixin):
             auto_continue (bool): Flag to automatically continue the conversation.
             action (str): Type of action ('next', 'continue', 'variant').
             conversation_id (str): ID of the conversation.
-            images (ImagesType): Images to include in the conversation.
+            media (MediaListType): Images to include in the conversation.
             return_conversation (bool): Flag to include response fields in the output.
             **kwargs: Additional keyword arguments.
@@ -316,7 +316,7 @@ class OpenaiChat(AsyncAuthedProvider, ProviderModelMixin):
                 cls._update_request_args(auth_result, session)
                 await raise_for_status(response)
             try:
-                image_requests = await cls.upload_images(session, auth_result, images) if images else None
+                image_requests = None if media is None else await cls.upload_images(session, auth_result, media)
             except Exception as e:
                 debug.error("OpenaiChat: Upload image failed")
                 debug.error(e)

View File

@@ -2,7 +2,7 @@ from __future__ import annotations
 import json

-from ...typing import Messages, AsyncResult, ImagesType
+from ...typing import Messages, AsyncResult, MediaListType
 from ...requests import StreamSession
 from ...image import to_data_uri
 from ...providers.base_provider import AsyncGeneratorProvider, ProviderModelMixin
@@ -18,21 +18,21 @@ class BackendApi(AsyncGeneratorProvider, ProviderModelMixin):
         cls,
         model: str,
         messages: Messages,
-        images: ImagesType = None,
+        media: MediaListType = None,
         api_key: str = None,
         **kwargs
     ) -> AsyncResult:
         debug.log(f"{cls.__name__}: {api_key}")
-        if images is not None:
-            for i in range(len(images)):
-                images[i] = (to_data_uri(images[i][0]), images[i][1])
+        if media is not None:
+            for i in range(len(media)):
+                media[i] = (to_data_uri(media[i][0]), media[i][1])
         async with StreamSession(
             headers={"Accept": "text/event-stream", **cls.headers},
         ) as session:
             async with session.post(f"{cls.url}/backend-api/v2/conversation", json={
                 "model": model,
                 "messages": messages,
-                "images": images,
+                "media": media,
                 "api_key": api_key,
                 **kwargs
             }, ssl=cls.ssl) as response:

View File

@@ -5,11 +5,11 @@ import requests
 from ..helper import filter_none, format_image_prompt
 from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin, RaiseErrorMixin
-from ...typing import Union, Optional, AsyncResult, Messages, ImagesType
+from ...typing import Union, AsyncResult, Messages, MediaListType
 from ...requests import StreamSession, raise_for_status
 from ...providers.response import FinishReason, ToolCalls, Usage, ImageResponse
 from ...errors import MissingAuthError, ResponseError
-from ...image import to_data_uri
+from ...image import to_data_uri, is_data_an_audio, to_input_audio
 from ... import debug

 class OpenaiTemplate(AsyncGeneratorProvider, ProviderModelMixin, RaiseErrorMixin):
@@ -54,7 +54,7 @@ class OpenaiTemplate(AsyncGeneratorProvider, ProviderModelMixin, RaiseErrorMixin
messages: Messages, messages: Messages,
proxy: str = None, proxy: str = None,
timeout: int = 120, timeout: int = 120,
images: ImagesType = None, media: MediaListType = None,
api_key: str = None, api_key: str = None,
api_endpoint: str = None, api_endpoint: str = None,
api_base: str = None, api_base: str = None,
@@ -66,7 +66,7 @@ class OpenaiTemplate(AsyncGeneratorProvider, ProviderModelMixin, RaiseErrorMixin
prompt: str = None, prompt: str = None,
headers: dict = None, headers: dict = None,
impersonate: str = None, impersonate: str = None,
extra_parameters: list[str] = ["tools", "parallel_tool_calls", "tool_choice", "reasoning_effort", "logit_bias"], extra_parameters: list[str] = ["tools", "parallel_tool_calls", "tool_choice", "reasoning_effort", "logit_bias", "modalities", "audio"],
extra_data: dict = {}, extra_data: dict = {},
**kwargs **kwargs
) -> AsyncResult: ) -> AsyncResult:
@@ -98,19 +98,26 @@ class OpenaiTemplate(AsyncGeneratorProvider, ProviderModelMixin, RaiseErrorMixin
yield ImageResponse([image["url"] for image in data["data"]], prompt) yield ImageResponse([image["url"] for image in data["data"]], prompt)
return return
if images is not None and messages: if media is not None and messages:
if not model and hasattr(cls, "default_vision_model"): if not model and hasattr(cls, "default_vision_model"):
model = cls.default_vision_model model = cls.default_vision_model
last_message = messages[-1].copy() last_message = messages[-1].copy()
last_message["content"] = [ last_message["content"] = [
*[{ *[
{
"type": "input_audio",
"input_audio": to_input_audio(media_data, filename)
}
if is_data_an_audio(media_data, filename) else {
"type": "image_url", "type": "image_url",
"image_url": {"url": to_data_uri(image)} "image_url": {"url": to_data_uri(media_data, filename)}
} for image, _ in images], }
for media_data, filename in media
],
{ {
"type": "text", "type": "text",
"text": messages[-1]["content"] "text": last_message["content"]
} } if isinstance(last_message["content"], str) else last_message["content"]
] ]
messages[-1] = last_message messages[-1] = last_message
extra_parameters = {key: kwargs[key] for key in extra_parameters if key in kwargs} extra_parameters = {key: kwargs[key] for key in extra_parameters if key in kwargs}
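
The rewritten content builder turns each `(media_data, filename)` pair into an `input_audio` part when `is_data_an_audio` matches, and an `image_url` part otherwise, then appends the original text. A sketch of what `last_message["content"]` looks like for a single WAV attachment plus a text prompt (base64 truncated for readability):

```python
# Hypothetical value after the new code runs on one WAV file.
last_message_content = [
    {
        "type": "input_audio",
        "input_audio": {"data": "UklGR...", "format": "wav"},
    },
    {
        "type": "text",
        "text": "Transcribe this audio",
    },
]
```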


@@ -34,12 +34,14 @@ class ChatCompletion:
                ignore_stream: bool = False,
                **kwargs) -> Union[CreateResult, str]:
         if image is not None:
-            kwargs["images"] = [(image, image_name)]
+            kwargs["media"] = [(image, image_name)]
+        elif "images" in kwargs:
+            kwargs["media"] = kwargs.pop("images")
         model, provider = get_model_and_provider(
             model, provider, stream,
             ignore_working,
             ignore_stream,
-            has_images="images" in kwargs,
+            has_images="media" in kwargs,
         )
         if "proxy" not in kwargs:
             proxy = os.environ.get("G4F_PROXY")
@@ -63,8 +65,10 @@ class ChatCompletion:
                      ignore_working: bool = False,
                      **kwargs) -> Union[AsyncResult, Coroutine[str]]:
         if image is not None:
-            kwargs["images"] = [(image, image_name)]
-        model, provider = get_model_and_provider(model, provider, False, ignore_working, has_images="images" in kwargs)
+            kwargs["media"] = [(image, image_name)]
+        elif "images" in kwargs:
+            kwargs["media"] = kwargs.pop("images")
+        model, provider = get_model_and_provider(model, provider, False, ignore_working, has_images="media" in kwargs)
         if "proxy" not in kwargs:
             proxy = os.environ.get("G4F_PROXY")
         if proxy:
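
Note the backward-compatibility shim: a legacy `images=...` keyword is popped into `media` before provider selection, so older call sites keep working unchanged. A minimal sketch, assuming a vision-capable default provider and a local `cat.jpg`:

```python
import g4f

# The deprecated `images` kwarg is remapped to `media` internally, so this
# behaves exactly like passing media=[(f, "cat.jpg")].
with open("cat.jpg", "rb") as f:
    response = g4f.ChatCompletion.create(
        model=g4f.models.default,
        messages=[{"role": "user", "content": "Describe this image"}],
        images=[(f, "cat.jpg")],
    )
print(response)
```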


@@ -38,7 +38,7 @@ import g4f.debug
 from g4f.client import AsyncClient, ChatCompletion, ImagesResponse, convert_to_provider
 from g4f.providers.response import BaseConversation, JsonConversation
 from g4f.client.helper import filter_none
-from g4f.image import is_data_uri_an_media
+from g4f.image import is_data_an_media
 from g4f.image.copy_images import images_dir, copy_images, get_source_url
 from g4f.errors import ProviderNotFoundError, ModelNotFoundError, MissingAuthError, NoValidHarFileError
 from g4f.cookies import read_cookie_files, get_cookies_dir
@@ -320,16 +320,18 @@ class Api:
         if config.image is not None:
             try:
-                is_data_uri_an_media(config.image)
+                is_data_an_media(config.image)
             except ValueError as e:
                 return ErrorResponse.from_message(f"The image you send must be a data URI. Example: data:image/jpeg;base64,...", status_code=HTTP_422_UNPROCESSABLE_ENTITY)
-        if config.images is not None:
-            for image in config.images:
+        if config.media is None:
+            config.media = config.images
+        if config.media is not None:
+            for image in config.media:
                 try:
-                    is_data_uri_an_media(image[0])
+                    is_data_an_media(image[0])
                 except ValueError as e:
-                    example = json.dumps({"images": [["data:image/jpeg;base64,...", "filename"]]})
-                    return ErrorResponse.from_message(f'The image you send must be a data URI. Example: {example}', status_code=HTTP_422_UNPROCESSABLE_ENTITY)
+                    example = json.dumps({"media": [["data:image/jpeg;base64,...", "filename.jpg"]]})
+                    return ErrorResponse.from_message(f'The media you send must be data URIs. Example: {example}', status_code=HTTP_422_UNPROCESSABLE_ENTITY)

         # Create the completion response
         response = self.client.chat.completions.create(
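
On the API side, `media` now falls back to the legacy `images` field, and each entry must be a `[data_uri, filename]` pair. A sketch of a request body that passes the new validation; everything except the `media` shape is an assumption, and the data URI is truncated:

```python
# Hypothetical chat completions request body.
body = {
    "model": "default",
    "messages": [{"role": "user", "content": "Transcribe this audio"}],
    "media": [["data:audio/wav;base64,UklGR...", "audio.wav"]],
}
```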


@@ -17,6 +17,7 @@ class ChatCompletionsConfig(BaseModel):
     image: Optional[str] = None
     image_name: Optional[str] = None
     images: Optional[list[tuple[str, str]]] = None
+    media: Optional[list[tuple[str, str]]] = None
     temperature: Optional[float] = None
     presence_penalty: Optional[float] = None
     frequency_penalty: Optional[float] = None


@@ -293,14 +293,16 @@ class Completions:
         if isinstance(messages, str):
             messages = [{"role": "user", "content": messages}]
         if image is not None:
-            kwargs["images"] = [(image, image_name)]
+            kwargs["media"] = [(image, image_name)]
+        elif "images" in kwargs:
+            kwargs["media"] = kwargs.pop("images")
         model, provider = get_model_and_provider(
             model,
             self.provider if provider is None else provider,
             stream,
             ignore_working,
             ignore_stream,
-            has_images="images" in kwargs
+            has_images="media" in kwargs
         )
         stop = [stop] if isinstance(stop, str) else stop
         if ignore_stream:
@@ -481,7 +483,7 @@ class Images:
         proxy = self.client.proxy
         prompt = "create a variation of this image"
         if image is not None:
-            kwargs["images"] = [(image, None)]
+            kwargs["media"] = [(image, None)]
         error = None
         response = None
@@ -581,14 +583,16 @@ class AsyncCompletions:
         if isinstance(messages, str):
             messages = [{"role": "user", "content": messages}]
         if image is not None:
-            kwargs["images"] = [(image, image_name)]
+            kwargs["media"] = [(image, image_name)]
+        elif "images" in kwargs:
+            kwargs["media"] = kwargs.pop("images")
         model, provider = get_model_and_provider(
             model,
             self.provider if provider is None else provider,
             stream,
             ignore_working,
             ignore_stream,
-            has_images="images" in kwargs,
+            has_images="media" in kwargs,
         )
         stop = [stop] if isinstance(stop, str) else stop
         if ignore_stream:


@@ -67,7 +67,7 @@ DOMAINS = [
 if has_browser_cookie3 and os.environ.get('DBUS_SESSION_BUS_ADDRESS') == "/dev/null":
     _LinuxPasswordManager.get_password = lambda a, b: b"secret"

-def get_cookies(domain_name: str = '', raise_requirements_error: bool = True, single_browser: bool = False) -> Dict[str, str]:
+def get_cookies(domain_name: str, raise_requirements_error: bool = True, single_browser: bool = False, cache_result: bool = True) -> Dict[str, str]:
     """
     Load cookies for a given domain from all supported browsers and cache the results.
@@ -77,10 +77,11 @@ def get_cookies(domain_name: str = '', raise_requirements_error: bool = True, si
     Returns:
         Dict[str, str]: A dictionary of cookie names and values.
     """
-    if domain_name in CookiesConfig.cookies:
+    if cache_result and domain_name in CookiesConfig.cookies:
         return CookiesConfig.cookies[domain_name]
     cookies = load_cookies_from_browsers(domain_name, raise_requirements_error, single_browser)
-    CookiesConfig.cookies[domain_name] = cookies
+    if cache_result:
+        CookiesConfig.cookies[domain_name] = cookies
     return cookies
@@ -108,8 +109,8 @@ def load_cookies_from_browsers(domain_name: str, raise_requirements_error: bool
     for cookie_fn in browsers:
         try:
             cookie_jar = cookie_fn(domain_name=domain_name)
-            if len(cookie_jar) and debug.logging:
-                print(f"Read cookies from {cookie_fn.__name__} for {domain_name}")
+            if len(cookie_jar):
+                debug.log(f"Read cookies from {cookie_fn.__name__} for {domain_name}")
             for cookie in cookie_jar:
                 if cookie.name not in cookies:
                     if not cookie.expires or cookie.expires > time.time():
@@ -119,8 +120,7 @@ def load_cookies_from_browsers(domain_name: str, raise_requirements_error: bool
         except BrowserCookieError:
             pass
         except Exception as e:
-            if debug.logging:
-                print(f"Error reading cookies from {cookie_fn.__name__} for {domain_name}: {e}")
+            debug.error(f"Error reading cookies from {cookie_fn.__name__} for {domain_name}: {e}")
     return cookies

 def set_cookies_dir(dir: str) -> None:
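
`get_cookies` now takes a required domain and an optional `cache_result` flag; with `cache_result=False` the in-memory cache is neither consulted nor updated, forcing a fresh read from the installed browsers. A minimal sketch, with a placeholder domain:

```python
from g4f.cookies import get_cookies

# Default: the first call reads the browsers, later calls hit the cache.
cookies = get_cookies(".example.com")

# Bypass the cache entirely: read fresh cookies without storing them.
fresh = get_cookies(".example.com", cache_result=False)
```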


@@ -86,6 +86,23 @@
             height: 100%;
             text-align: center;
             z-index: 1;
+            transition: transform 0.25s ease-in;
+        }
+
+        .container.slide {
+            transform: translateX(-100%);
+            transition: transform 0.15s ease-out;
+        }
+
+        .slide-button {
+            position: absolute;
+            top: 20px;
+            left: 20px;
+            background: var(--colour-4);
+            border: none;
+            border-radius: 4px;
+            cursor: pointer;
+            padding: 6px 8px;
         }
         header {
@@ -105,7 +122,8 @@
             height: 100%;
             position: absolute;
             z-index: -1;
-            object-fit: contain;
+            object-fit: cover;
+            object-position: center;
             width: 100%;
             background: black;
         }
@@ -181,6 +199,7 @@
         }
     }
     </style>
+    <link rel="stylesheet" href="/static/css/all.min.css">
     <script>
         (async () => {
             const isIframe = window.self !== window.top;
@@ -208,6 +227,10 @@
     <!-- Gradient Background Circle -->
     <div class="gradient"></div>

+    <button class="slide-button">
+        <i class="fa-solid fa-arrow-left"></i>
+    </button>
+
     <!-- Main Content -->
     <div class="container">
         <header>
@@ -277,8 +300,11 @@
             if (oauthResult) {
                 try {
                     oauthResult = JSON.parse(oauthResult);
+                    user = await hub.whoAmI({accessToken: oauthResult.accessToken});
                 } catch {
                     oauthResult = null;
+                    localStorage.removeItem("oauth");
+                    localStorage.removeItem("HuggingFace-api_key");
                 }
             }
             oauthResult ||= await oauthHandleRedirectIfPresent();
@@ -339,7 +365,7 @@
                 return;
             }
             const lower = data.prompt.toLowerCase();
-            const tags = ["nsfw", "timeline", "blood", "soap", "orally", "heel", "latex", "bathroom", "boobs", "charts", " text ", "gel", "logo", "infographic", "warts", " bra ", "prostitute", "curvy", "breasts", "written", "bodies", "naked", "classroom", "malone", "dirty", "shoes", "shower", "banner", "fat", "nipples", "couple", "sexual", "sandal", "supplier", "overlord", "succubus", "platinum", "cracy", "crazy", "lamic", "ropes", "cables", "wires", "dirty", "messy", "cluttered", "chaotic", "disorganized", "disorderly", "untidy", "unorganized", "unorderly", "unsystematic", "disarranged", "disarrayed", "disheveled", "disordered", "jumbled", "muddled", "scattered", "shambolic", "sloppy", "unkept", "unruly"];
+            const tags = ["nsfw", "timeline", "feet", "blood", "soap", "orally", "heel", "latex", "bathroom", "boobs", "charts", " text ", "gel", "logo", "infographic", "warts", " bra ", "prostitute", "curvy", "breasts", "written", "bodies", "naked", "classroom", "malone", "dirty", "shoes", "shower", "banner", "fat", "nipples", "couple", "sexual", "sandal", "supplier", "overlord", "succubus", "platinum", "cracy", "crazy", "hemale", "oprah", "lamic", "ropes", "cables", "wires", "dirty", "messy", "cluttered", "chaotic", "disorganized", "disorderly", "untidy", "unorganized", "unorderly", "unsystematic", "disarranged", "disarrayed", "disheveled", "disordered", "jumbled", "muddled", "scattered", "shambolic", "sloppy", "unkept", "unruly"];
             for (i in tags) {
                 if (lower.indexOf(tags[i]) != -1) {
                     console.log("Skipping image with tag: " + tags[i]);
@@ -363,6 +389,21 @@
                     imageFeed.remove();
                 }
             }, 7000);
+
+            const container = document.querySelector('.container');
+            const button = document.querySelector('.slide-button');
+            const slideIcon = button.querySelector('i');
+            button.onclick = () => {
+                if (container.classList.contains('slide')) {
+                    container.classList.remove('slide');
+                    slideIcon.classList.remove('fa-arrow-right');
+                    slideIcon.classList.add('fa-arrow-left');
+                } else {
+                    container.classList.add('slide');
+                    slideIcon.classList.remove('fa-arrow-left');
+                    slideIcon.classList.add('fa-arrow-right');
+                }
+            }
         })();
     </script>
</body>


@@ -141,7 +141,7 @@ class Api:
                 stream=True,
                 ignore_stream=True,
                 logging=False,
-                has_images="images" in kwargs,
+                has_images="media" in kwargs,
             )
         except Exception as e:
             debug.error(e)


@@ -8,6 +8,7 @@ import asyncio
 import shutil
 import random
 import datetime
+import tempfile
 from flask import Flask, Response, request, jsonify, render_template
 from typing import Generator
 from pathlib import Path
@@ -111,11 +112,13 @@ class Backend_Api(Api):
             else:
                 json_data = request.json
             if "files" in request.files:
-                images = []
+                media = []
                 for file in request.files.getlist('files'):
                     if file.filename != '' and is_allowed_extension(file.filename):
-                        images.append((to_image(file.stream, file.filename.endswith('.svg')), file.filename))
-                json_data['images'] = images
+                        newfile = tempfile.TemporaryFile()
+                        shutil.copyfileobj(file.stream, newfile)
+                        media.append((newfile, file.filename))
+                json_data['media'] = media
             if app.demo and not json_data.get("provider"):
                 model = json_data.get("model")
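
Uploads are no longer decoded into PIL images on arrival; each file is spooled into an anonymous temporary file and forwarded as a `(file, filename)` tuple, which is what lets audio uploads through a code path that previously assumed images. The pattern in isolation, as a hypothetical helper rather than a function from the codebase:

```python
import shutil
import tempfile

def spool_upload(stream, filename: str) -> tuple:
    # Copy the incoming stream into an anonymous temporary file so the
    # payload can be re-read later, whether it is an image or an audio file.
    newfile = tempfile.TemporaryFile()
    shutil.copyfileobj(stream, newfile)
    return (newfile, filename)
```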


@@ -76,24 +76,24 @@ def is_allowed_extension(filename: str) -> bool:
     return '.' in filename and \
         filename.rsplit('.', 1)[1].lower() in ALLOWED_EXTENSIONS

-def is_data_uri_an_media(data_uri: str) -> str:
-    return is_data_an_wav(data_uri) or is_data_uri_an_image(data_uri)
+def is_data_an_media(data, filename: str = None) -> str:
+    content_type = is_data_an_audio(data, filename)
+    if content_type is not None:
+        return content_type
+    if isinstance(data, bytes):
+        return is_accepted_format(data)
+    return is_data_uri_an_image(data)

-def is_data_an_wav(data_uri: str, filename: str = None) -> str:
-    """
-    Checks if the given data URI represents an image.
-
-    Args:
-        data_uri (str): The data URI to check.
-
-    Raises:
-        ValueError: If the data URI is invalid or the image format is not allowed.
-    """
-    if filename and filename.endswith(".wav"):
-        return "audio/wav"
-    # Check if the data URI starts with 'data:image' and contains an image format (e.g., jpeg, png, gif)
-    if isinstance(data_uri, str) and re.match(r'data:audio/wav;base64,', data_uri):
-        return "audio/wav"
+def is_data_an_audio(data_uri: str, filename: str = None) -> str:
+    if filename:
+        if filename.endswith(".wav"):
+            return "audio/wav"
+        elif filename.endswith(".mp3"):
+            return "audio/mpeg"
+    if isinstance(data_uri, str):
+        audio_format = re.match(r'^data:(audio/\w+);base64,', data_uri)
+        if audio_format:
+            return audio_format.group(1)

 def is_data_uri_an_image(data_uri: str) -> bool:
     """
@@ -218,7 +218,7 @@ def to_bytes(image: ImageType) -> bytes:
     if isinstance(image, bytes):
         return image
     elif isinstance(image, str) and image.startswith("data:"):
-        is_data_uri_an_media(image)
+        is_data_an_media(image)
         return extract_data_uri(image)
     elif isinstance(image, Image):
         bytes_io = BytesIO()
@@ -236,13 +236,29 @@ def to_bytes(image: ImageType) -> bytes:
             pass
         return image.read()

-def to_data_uri(image: ImageType) -> str:
+def to_data_uri(image: ImageType, filename: str = None) -> str:
     if not isinstance(image, str):
         data = to_bytes(image)
         data_base64 = base64.b64encode(data).decode()
-        return f"data:{is_accepted_format(data)};base64,{data_base64}"
+        return f"data:{is_data_an_media(data, filename)};base64,{data_base64}"
     return image

+def to_input_audio(audio: ImageType, filename: str = None) -> str:
+    if not isinstance(audio, str):
+        if filename is not None and (filename.endswith(".wav") or filename.endswith(".mp3")):
+            return {
+                "data": base64.b64encode(to_bytes(audio)).decode(),
+                "format": "wav" if filename.endswith(".wav") else "mpeg"
+            }
+        raise ValueError("Invalid input audio")
+    audio = re.match(r'^data:audio/(\w+);base64,(.+?)', audio)
+    if audio:
+        return {
+            "data": audio.group(2),
+            "format": audio.group(1),
+        }
+    raise ValueError("Invalid input audio")
+
 class ImageDataResponse():
     def __init__(
         self,
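
Together, `is_data_an_audio` and `to_input_audio` give providers a uniform way to detect audio attachments and convert them into the OpenAI-style `input_audio` payload. A usage sketch, assuming a local `audio.wav`:

```python
from g4f.image import is_data_an_audio, to_input_audio

with open("audio.wav", "rb") as f:
    data = f.read()

# Detection: returns a MIME type such as "audio/wav", or None for non-audio.
print(is_data_an_audio(data, "audio.wav"))  # "audio/wav"

# Conversion: yields {"data": "<base64>", "format": "wav"} for the API call.
audio_part = to_input_audio(data, "audio.wav")
```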


@@ -109,7 +109,7 @@ async def copy_images(
                             f.write(chunk)

                 # Verify file format
-                if not os.path.splitext(target_path)[1]:
+                if target is None and not os.path.splitext(target_path)[1]:
                     with open(target_path, "rb") as f:
                         file_header = f.read(12)
                     detected_type = is_accepted_format(file_header)
@@ -120,7 +120,7 @@ async def copy_images(
                 # Build URL with safe encoding
                 url_filename = quote(os.path.basename(target_path))
-                return f"/images/{url_filename}{'?url=' + quote(image) if add_url and not image.startswith('data:') else ''}"
+                return f"/images/{url_filename}" + (('?url=' + quote(image)) if add_url and not image.startswith('data:') else '')
             except (ClientError, IOError, OSError) as e:
                 debug.error(f"Image copying failed: {type(e).__name__}: {e}")


@@ -25,7 +25,7 @@ from .. import debug
 SAFE_PARAMETERS = [
     "model", "messages", "stream", "timeout",
-    "proxy", "images", "response_format",
+    "proxy", "media", "response_format",
     "prompt", "negative_prompt", "tools", "conversation",
     "history_disabled",
     "temperature", "top_k", "top_p",
@@ -56,7 +56,7 @@ PARAMETER_EXAMPLES = {
     "frequency_penalty": 1,
     "presence_penalty": 1,
     "messages": [{"role": "system", "content": ""}, {"role": "user", "content": ""}],
-    "images": [["data:image/jpeg;base64,...", "filename.jpg"]],
+    "media": [["data:image/jpeg;base64,...", "filename.jpg"]],
     "response_format": {"type": "json_object"},
     "conversation": {"conversation_id": "550e8400-e29b-11d4-a716-...", "message_id": "550e8400-e29b-11d4-a716-..."},
     "seed": 42,


@@ -2,7 +2,7 @@ from __future__ import annotations

 import json

-from ..typing import AsyncResult, Messages, ImagesType
+from ..typing import AsyncResult, Messages, MediaListType
 from ..client.service import get_model_and_provider
 from ..client.helper import filter_json
 from .base_provider import AsyncGeneratorProvider
@@ -17,7 +17,7 @@ class ToolSupportProvider(AsyncGeneratorProvider):
         model: str,
         messages: Messages,
         stream: bool = True,
-        images: ImagesType = None,
+        media: MediaListType = None,
         tools: list[str] = None,
         response_format: dict = None,
         **kwargs
@@ -28,7 +28,7 @@ class ToolSupportProvider(AsyncGeneratorProvider):
         model, provider = get_model_and_provider(
             model, provider,
             stream, logging=False,
-            has_images=images is not None
+            has_images=media is not None
         )
         if tools is not None:
             if len(tools) > 1:
@@ -49,7 +49,7 @@ class ToolSupportProvider(AsyncGeneratorProvider):
             model,
             messages,
             stream=stream,
-            images=images,
+            media=media,
             response_format=response_format,
             **kwargs
         ):


@@ -21,7 +21,7 @@ AsyncResult = AsyncIterator[Union[str, ResponseType]]
 Messages = List[Dict[str, Union[str, List[Dict[str, Union[str, Dict[str, str]]]]]]]
 Cookies = Dict[str, str]
 ImageType = Union[str, bytes, IO, Image, os.PathLike]
-ImagesType = List[Tuple[ImageType, Optional[str]]]
+MediaListType = List[Tuple[ImageType, Optional[str]]]

 __all__ = [
     'Any',
@@ -44,5 +44,5 @@ __all__ = [
     'Cookies',
     'Image',
     'ImageType',
-    'ImagesType'
+    'MediaListType'
 ]
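
The typing change is a pure rename: `MediaListType` is still a list of `(file, filename)` tuples. A sketch of how a provider signature looks after adopting it; the body is elided:

```python
from __future__ import annotations

from g4f.typing import AsyncResult, Messages, MediaListType

async def create_async_generator(
    model: str,
    messages: Messages,
    media: MediaListType = None,  # e.g. [(open("audio.wav", "rb"), "audio.wav")]
    **kwargs,
) -> AsyncResult:
    # Provider implementation elided; a real provider would stream chunks here.
    yield "..."
```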