Add audio transcribing example and support

Add Grok Chat provider
Rename images parameter to media
Update demo homepage
This commit is contained in:
hlohaus
2025-03-21 03:17:45 +01:00
parent 10d32a4c5f
commit c97ba0c88e
36 changed files with 407 additions and 300 deletions

View File

@@ -19,6 +19,7 @@ The G4F AsyncClient API is designed to be compatible with the OpenAI API, making
- [Text Completions](#text-completions)
- [Streaming Completions](#streaming-completions)
- [Using a Vision Model](#using-a-vision-model)
- **[Transcribing Audio with Chat Completions](#transcribing-audio-with-chat-completions)** *(New Section)*
- [Image Generation](#image-generation)
- [Advanced Usage](#advanced-usage)
- [Conversation Memory](#conversation-memory)
@@ -203,6 +204,54 @@ async def main():
asyncio.run(main())
```
---
### Transcribing Audio with Chat Completions
Some providers in G4F support audio inputs in chat completions, allowing you to transcribe audio files by instructing the model accordingly. This example demonstrates how to use the `AsyncClient` to transcribe an audio file asynchronously:
```python
import asyncio
from g4f.client import AsyncClient
import g4f.Provider
import g4f.models
async def main():
client = AsyncClient(provider=g4f.Provider.PollinationsAI) # or g4f.Provider.Microsoft_Phi_4
with open("audio.wav", "rb") as audio_file:
response = await client.chat.completions.create(
model=g4f.models.default,
messages=[{"role": "user", "content": "Transcribe this audio"}],
media=[[audio_file, "audio.wav"]],
modalities=["text"],
)
print(response.choices[0].message.content)
if __name__ == "__main__":
asyncio.run(main())
```
#### Explanation
- **Client Initialization**: An `AsyncClient` instance is created with a provider that supports audio inputs, such as `PollinationsAI` or `Microsoft_Phi_4`.
- **File Handling**: The audio file (`audio.wav`) is opened in binary read mode (`"rb"`) using a context manager (`with` statement) to ensure proper file closure after use.
- **API Call**: The `chat.completions.create` method is called with:
- `model=g4f.models.default`: Uses the default model for the selected provider.
- `messages`: A list containing a user message instructing the model to transcribe the audio.
- `media`: A list of lists, where each inner list contains the file object and its name (`[[audio_file, "audio.wav"]]`).
- `modalities=["text"]`: Specifies that the output should be text (the transcription).
- **Response**: The transcription is extracted from `response.choices[0].message.content` and printed.
#### Notes
- **Provider Support**: Ensure the chosen provider (e.g., `PollinationsAI` or `Microsoft_Phi_4`) supports audio inputs in chat completions. Not all providers may offer this functionality.
- **File Path**: Replace `"audio.wav"` with the path to your own audio file. The file format (e.g., WAV) should be compatible with the provider.
- **Model Selection**: If `g4f.models.default` does not support audio transcription, you may need to specify a model that does (consult the provider's documentation for supported models).
This example complements the guide by showcasing how to handle audio inputs asynchronously, expanding on the multimodal capabilities of the G4F AsyncClient API.
---
### Image Generation
**The `response_format` parameter is optional and can have the following values:**
- **If not specified (default):** The image will be saved locally, and a local path will be returned (e.g., "/images/1733331238_cf9d6aa9-f606-4fea-ba4b-f06576cba309.jpg").

View File

@@ -6,10 +6,10 @@ from g4f.models import __models__
from g4f.providers.base_provider import BaseProvider, ProviderModelMixin
from g4f.errors import MissingRequirementsError, MissingAuthError
class TestProviderHasModel(unittest.IsolatedAsyncioTestCase):
class TestProviderHasModel(unittest.TestCase):
cache: dict = {}
async def test_provider_has_model(self):
def test_provider_has_model(self):
for model, providers in __models__.values():
for provider in providers:
if issubclass(provider, ProviderModelMixin):
@@ -17,9 +17,9 @@ class TestProviderHasModel(unittest.IsolatedAsyncioTestCase):
model_name = provider.model_aliases[model.name]
else:
model_name = model.name
await asyncio.wait_for(self.provider_has_model(provider, model_name), 10)
self.provider_has_model(provider, model_name)
async def provider_has_model(self, provider: Type[BaseProvider], model: str):
def provider_has_model(self, provider: Type[BaseProvider], model: str):
if provider.__name__ not in self.cache:
try:
self.cache[provider.__name__] = provider.get_models()
@@ -28,7 +28,7 @@ class TestProviderHasModel(unittest.IsolatedAsyncioTestCase):
if self.cache[provider.__name__]:
self.assertIn(model, self.cache[provider.__name__], provider.__name__)
async def test_all_providers_working(self):
def test_all_providers_working(self):
for model, providers in __models__.values():
for provider in providers:
self.assertTrue(provider.working, f"{provider.__name__} in {model.name}")

View File

@@ -48,7 +48,7 @@ class AllenAI(AsyncGeneratorProvider, ProviderModelMixin):
"olmo-2-13b": "OLMo-2-1124-13B-Instruct",
"tulu-3-1-8b": "tulu-3-1-8b",
"tulu-3-70b": "Llama-3-1-Tulu-3-70B",
"llama-3.1-405b": "tulu-3-405b",
"llama-3.1-405b": "tulu3-405b",
"llama-3.1-8b": "tulu-3-1-8b",
"llama-3.1-70b": "Llama-3-1-Tulu-3-70B",
}

View File

@@ -11,7 +11,7 @@ from pathlib import Path
from typing import Optional
from datetime import datetime, timedelta
from ..typing import AsyncResult, Messages, ImagesType
from ..typing import AsyncResult, Messages, MediaListType
from ..requests.raise_for_status import raise_for_status
from .base_provider import AsyncGeneratorProvider, ProviderModelMixin
from ..image import to_data_uri
@@ -444,7 +444,7 @@ class Blackbox(AsyncGeneratorProvider, ProviderModelMixin):
messages: Messages,
prompt: str = None,
proxy: str = None,
images: ImagesType = None,
media: MediaListType = None,
top_p: float = None,
temperature: float = None,
max_tokens: int = None,
@@ -479,14 +479,14 @@ class Blackbox(AsyncGeneratorProvider, ProviderModelMixin):
}
current_messages.append(current_msg)
if images is not None:
if media is not None:
current_messages[-1]['data'] = {
"imagesData": [
{
"filePath": f"/{image_name}",
"contents": to_data_uri(image)
}
for image, image_name in images
for image, image_name in media
],
"fileText": "",
"title": ""

View File

@@ -21,7 +21,7 @@ except ImportError:
from .base_provider import AbstractProvider, ProviderModelMixin
from .helper import format_prompt_max_length
from .openai.har_file import get_headers, get_har_files
from ..typing import CreateResult, Messages, ImagesType
from ..typing import CreateResult, Messages, MediaListType
from ..errors import MissingRequirementsError, NoValidHarFileError, MissingAuthError
from ..requests.raise_for_status import raise_for_status
from ..providers.response import BaseConversation, JsonConversation, RequestLogin, Parameters, ImageResponse
@@ -66,7 +66,7 @@ class Copilot(AbstractProvider, ProviderModelMixin):
proxy: str = None,
timeout: int = 900,
prompt: str = None,
images: ImagesType = None,
media: MediaListType = None,
conversation: BaseConversation = None,
return_conversation: bool = False,
api_key: str = None,
@@ -77,7 +77,7 @@ class Copilot(AbstractProvider, ProviderModelMixin):
websocket_url = cls.websocket_url
headers = None
if cls.needs_auth or images is not None:
if cls.needs_auth or media is not None:
if api_key is not None:
cls._access_token = api_key
if cls._access_token is None:
@@ -142,8 +142,8 @@ class Copilot(AbstractProvider, ProviderModelMixin):
debug.log(f"Copilot: Use conversation: {conversation_id}")
uploaded_images = []
if images is not None:
for image, _ in images:
if media is not None:
for image, _ in media:
data = to_bytes(image)
response = session.post(
"https://copilot.microsoft.com/c/api/attachments",

View File

@@ -30,7 +30,7 @@ class DDG(AsyncGeneratorProvider, ProviderModelMixin):
api_endpoint = "https://duckduckgo.com/duckchat/v1/chat"
status_url = "https://duckduckgo.com/duckchat/v1/status"
working = True
working = False
supports_stream = True
supports_system_message = True
supports_message_history = True

View File

@@ -1,6 +1,6 @@
from __future__ import annotations
from ..typing import AsyncResult, Messages, ImagesType
from ..typing import AsyncResult, Messages, MediaListType
from .template import OpenaiTemplate
from ..image import to_data_uri
@@ -70,7 +70,6 @@ class DeepInfraChat(OpenaiTemplate):
temperature: float = 0.7,
max_tokens: int = None,
headers: dict = {},
images: ImagesType = None,
**kwargs
) -> AsyncResult:
headers = {
@@ -82,23 +81,6 @@ class DeepInfraChat(OpenaiTemplate):
**headers
}
if images is not None:
if not model or model not in cls.models:
model = cls.default_vision_model
if messages:
last_message = messages[-1].copy()
last_message["content"] = [
*[{
"type": "image_url",
"image_url": {"url": to_data_uri(image)}
} for image, _ in images],
{
"type": "text",
"text": last_message["content"]
}
]
messages[-1] = last_message
async for chunk in super().create_async_generator(
model,
messages,

View File

@@ -3,13 +3,12 @@ from __future__ import annotations
import json
from aiohttp import ClientSession, FormData
from ..typing import AsyncResult, Messages, ImagesType
from ..typing import AsyncResult, Messages, MediaListType
from .base_provider import AsyncGeneratorProvider, ProviderModelMixin
from ..requests.raise_for_status import raise_for_status
from ..image import to_data_uri, to_bytes, is_accepted_format
from ..image import to_bytes, is_accepted_format
from .helper import format_prompt
class Dynaspark(AsyncGeneratorProvider, ProviderModelMixin):
url = "https://dynaspark.onrender.com"
login_url = None
@@ -38,7 +37,7 @@ class Dynaspark(AsyncGeneratorProvider, ProviderModelMixin):
model: str,
messages: Messages,
proxy: str = None,
images: ImagesType = None,
media: MediaListType = None,
**kwargs
) -> AsyncResult:
headers = {
@@ -49,14 +48,13 @@ class Dynaspark(AsyncGeneratorProvider, ProviderModelMixin):
'user-agent': 'Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/133.0.0.0 Safari/537.36',
'x-requested-with': 'XMLHttpRequest'
}
async with ClientSession(headers=headers) as session:
form = FormData()
form.add_field('user_input', format_prompt(messages))
form.add_field('ai_model', model)
if images is not None and len(images) > 0:
image, image_name = images[0]
if media is not None and len(media) > 0:
image, image_name = media[0]
image_bytes = to_bytes(image)
form.add_field('file', image_bytes, filename=image_name, content_type=is_accepted_format(image_bytes))

View File

@@ -14,6 +14,7 @@ class LambdaChat(HuggingChat):
default_model = "deepseek-llama3.3-70b"
reasoning_model = "deepseek-r1"
image_models = []
models = []
fallback_models = [
default_model,
reasoning_model,

View File

@@ -9,8 +9,8 @@ from aiohttp import ClientSession
from .helper import filter_none, format_image_prompt
from .base_provider import AsyncGeneratorProvider, ProviderModelMixin
from ..typing import AsyncResult, Messages, ImagesType
from ..image import to_data_uri
from ..typing import AsyncResult, Messages, MediaListType
from ..image import to_data_uri, is_data_an_audio, to_input_audio
from ..errors import ModelNotFoundError
from ..requests.raise_for_status import raise_for_status
from ..requests.aiohttp import get_connector
@@ -146,17 +146,24 @@ class PollinationsAI(AsyncGeneratorProvider, ProviderModelMixin):
enhance: bool = False,
safe: bool = False,
# Text generation parameters
images: ImagesType = None,
media: MediaListType = None,
temperature: float = None,
presence_penalty: float = None,
top_p: float = 1,
frequency_penalty: float = None,
response_format: Optional[dict] = None,
extra_parameters: list[str] = ["tools", "parallel_tool_calls", "tool_choice", "reasoning_effort", "logit_bias", "voice"],
extra_parameters: list[str] = ["tools", "parallel_tool_calls", "tool_choice", "reasoning_effort", "logit_bias", "voice", "modalities"],
**kwargs
) -> AsyncResult:
# Load model list
cls.get_models()
if not model and media is not None:
has_audio = False
for media_data, filename in media:
if is_data_an_audio(media_data, filename):
has_audio = True
break
model = next(iter(cls.audio_models)) if has_audio else cls.default_vision_model
try:
model = cls.get_model(model)
except ModelNotFoundError:
@@ -182,7 +189,7 @@ class PollinationsAI(AsyncGeneratorProvider, ProviderModelMixin):
async for result in cls._generate_text(
model=model,
messages=messages,
images=images,
media=media,
proxy=proxy,
temperature=temperature,
presence_penalty=presence_penalty,
@@ -239,7 +246,7 @@ class PollinationsAI(AsyncGeneratorProvider, ProviderModelMixin):
cls,
model: str,
messages: Messages,
images: Optional[ImagesType],
media: MediaListType,
proxy: str,
temperature: float,
presence_penalty: float,
@@ -258,14 +265,18 @@ class PollinationsAI(AsyncGeneratorProvider, ProviderModelMixin):
if response_format and response_format.get("type") == "json_object":
json_mode = True
if images and messages:
if media and messages:
last_message = messages[-1].copy()
image_content = [
{
"type": "image_url",
"image_url": {"url": to_data_uri(image)}
"type": "input_audio",
"input_audio": to_input_audio(media_data, filename)
}
for image, _ in images
if is_data_an_audio(media_data, filename) else {
"type": "image_url",
"image_url": {"url": to_data_uri(media_data)}
}
for media_data, filename in media
]
last_message["content"] = image_content + [{"type": "text", "text": last_message["content"]}]
messages[-1] = last_message

View File

@@ -17,7 +17,7 @@ except ImportError:
from ..base_provider import ProviderModelMixin, AsyncAuthedProvider, AuthResult
from ..helper import format_prompt, format_image_prompt, get_last_user_message
from ...typing import AsyncResult, Messages, Cookies, ImagesType
from ...typing import AsyncResult, Messages, Cookies, MediaListType
from ...errors import MissingRequirementsError, MissingAuthError, ResponseError
from ...image import to_bytes
from ...requests import get_args_from_nodriver, DEFAULT_HEADERS
@@ -99,7 +99,7 @@ class HuggingChat(AsyncAuthedProvider, ProviderModelMixin):
messages: Messages,
auth_result: AuthResult,
prompt: str = None,
images: ImagesType = None,
media: MediaListType = None,
return_conversation: bool = False,
conversation: Conversation = None,
web_search: bool = False,
@@ -108,7 +108,7 @@ class HuggingChat(AsyncAuthedProvider, ProviderModelMixin):
if not has_curl_cffi:
raise MissingRequirementsError('Install "curl_cffi" package | pip install -U curl_cffi')
if model == llama_models["name"]:
model = llama_models["text"] if images is None else llama_models["vision"]
model = llama_models["text"] if media is None else llama_models["vision"]
model = cls.get_model(model)
session = Session(**auth_result.get_dict())
@@ -145,8 +145,8 @@ class HuggingChat(AsyncAuthedProvider, ProviderModelMixin):
}
data = CurlMime()
data.addpart('data', data=json.dumps(settings, separators=(',', ':')))
if images is not None:
for image, filename in images:
if media is not None:
for image, filename in media:
data.addpart(
"files",
filename=f"base64;{filename}",

View File

@@ -3,7 +3,7 @@ from __future__ import annotations
import requests
from ...providers.types import Messages
from ...typing import ImagesType
from ...typing import MediaListType
from ...requests import StreamSession, raise_for_status
from ...errors import ModelNotSupportedError
from ...providers.helper import get_last_user_message
@@ -75,11 +75,11 @@ class HuggingFaceAPI(OpenaiTemplate):
api_key: str = None,
max_tokens: int = 2048,
max_inputs_lenght: int = 10000,
images: ImagesType = None,
media: MediaListType = None,
**kwargs
):
if model == llama_models["name"]:
model = llama_models["text"] if images is None else llama_models["vision"]
model = llama_models["text"] if media is None else llama_models["vision"]
if model in cls.model_aliases:
model = cls.model_aliases[model]
provider_mapping = await cls.get_mapping(model, api_key)
@@ -103,7 +103,7 @@ class HuggingFaceAPI(OpenaiTemplate):
if len(messages) > 1 and calculate_lenght(messages) > max_inputs_lenght:
messages = last_user_message
debug.log(f"Messages trimmed from: {start} to: {calculate_lenght(messages)}")
async for chunk in super().create_async_generator(model, messages, api_base=api_base, api_key=api_key, max_tokens=max_tokens, images=images, **kwargs):
async for chunk in super().create_async_generator(model, messages, api_base=api_base, api_key=api_key, max_tokens=max_tokens, media=media, **kwargs):
yield chunk
def calculate_lenght(messages: Messages) -> int:

View File

@@ -7,7 +7,7 @@ import random
from datetime import datetime, timezone, timedelta
import urllib.parse
from ...typing import AsyncResult, Messages, Cookies, ImagesType
from ...typing import AsyncResult, Messages, Cookies, MediaListType
from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin
from ..helper import format_prompt, format_image_prompt
from ...providers.response import JsonConversation, ImageResponse, Reasoning
@@ -68,7 +68,7 @@ class DeepseekAI_JanusPro7b(AsyncGeneratorProvider, ProviderModelMixin):
cls,
model: str,
messages: Messages,
images: ImagesType = None,
media: MediaListType = None,
prompt: str = None,
proxy: str = None,
cookies: Cookies = None,
@@ -98,27 +98,27 @@ class DeepseekAI_JanusPro7b(AsyncGeneratorProvider, ProviderModelMixin):
if return_conversation:
yield conversation
if images is not None:
if media is not None:
data = FormData()
for i in range(len(images)):
images[i] = (to_bytes(images[i][0]), images[i][1])
for image, image_name in images:
for i in range(len(media)):
media[i] = (to_bytes(media[i][0]), media[i][1])
for image, image_name in media:
data.add_field(f"files", image, filename=image_name)
async with session.post(f"{cls.api_url}/gradio_api/upload", params={"upload_id": session_hash}, data=data) as response:
await raise_for_status(response)
image_files = await response.json()
images = [{
media = [{
"path": image_file,
"url": f"{cls.api_url}/gradio_api/file={image_file}",
"orig_name": images[i][1],
"size": len(images[i][0]),
"mime_type": is_accepted_format(images[i][0]),
"orig_name": media[i][1],
"size": len(media[i][0]),
"mime_type": is_accepted_format(media[i][0]),
"meta": {
"_type": "gradio.FileData"
}
} for i, image_file in enumerate(image_files)]
async with cls.run(method, session, prompt, conversation, None if images is None else images.pop(), seed) as response:
async with cls.run(method, session, prompt, conversation, None if media is None else media.pop(), seed) as response:
await raise_for_status(response)
async with cls.run("get", session, prompt, conversation, None, seed) as response:

View File

@@ -3,13 +3,13 @@ from __future__ import annotations
import json
import uuid
from ...typing import AsyncResult, Messages, Cookies, ImagesType
from ...typing import AsyncResult, Messages, Cookies, MediaListType
from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin
from ..helper import format_prompt, format_image_prompt
from ...providers.response import JsonConversation
from ...requests.aiohttp import StreamSession, StreamResponse, FormData
from ...requests.raise_for_status import raise_for_status
from ...image import to_bytes, is_accepted_format, is_data_an_wav
from ...image import to_bytes, is_accepted_format, is_data_an_audio
from ...errors import ResponseError
from ... import debug
from .DeepseekAI_JanusPro7b import get_zerogpu_token
@@ -32,7 +32,7 @@ class Microsoft_Phi_4(AsyncGeneratorProvider, ProviderModelMixin):
models = [default_model]
@classmethod
def run(cls, method: str, session: StreamSession, prompt: str, conversation: JsonConversation, images: list = None):
def run(cls, method: str, session: StreamSession, prompt: str, conversation: JsonConversation, media: list = None):
headers = {
"content-type": "application/json",
"x-zerogpu-token": conversation.zerogpu_token,
@@ -47,7 +47,7 @@ class Microsoft_Phi_4(AsyncGeneratorProvider, ProviderModelMixin):
[],
{
"text": prompt,
"files": images,
"files": media,
},
None
],
@@ -70,7 +70,7 @@ class Microsoft_Phi_4(AsyncGeneratorProvider, ProviderModelMixin):
{
"role": "user",
"content": {"file": image}
} for image in images
} for image in media
]],
"event_data": None,
"fn_index": 11,
@@ -91,7 +91,7 @@ class Microsoft_Phi_4(AsyncGeneratorProvider, ProviderModelMixin):
cls,
model: str,
messages: Messages,
images: ImagesType = None,
media: MediaListType = None,
prompt: str = None,
proxy: str = None,
cookies: Cookies = None,
@@ -115,23 +115,23 @@ class Microsoft_Phi_4(AsyncGeneratorProvider, ProviderModelMixin):
if return_conversation:
yield conversation
if images is not None:
if media is not None:
data = FormData()
mime_types = [None for i in range(len(images))]
for i in range(len(images)):
mime_types[i] = is_data_an_wav(images[i][0], images[i][1])
images[i] = (to_bytes(images[i][0]), images[i][1])
mime_types[i] = is_accepted_format(images[i][0]) if mime_types[i] is None else mime_types[i]
for image, image_name in images:
mime_types = [None for i in range(len(media))]
for i in range(len(media)):
mime_types[i] = is_data_an_audio(media[i][0], media[i][1])
media[i] = (to_bytes(media[i][0]), media[i][1])
mime_types[i] = is_accepted_format(media[i][0]) if mime_types[i] is None else mime_types[i]
for image, image_name in media:
data.add_field(f"files", to_bytes(image), filename=image_name)
async with session.post(f"{cls.api_url}/gradio_api/upload", params={"upload_id": session_hash}, data=data) as response:
await raise_for_status(response)
image_files = await response.json()
images = [{
media = [{
"path": image_file,
"url": f"{cls.api_url}/gradio_api/file={image_file}",
"orig_name": images[i][1],
"size": len(images[i][0]),
"orig_name": media[i][1],
"size": len(media[i][0]),
"mime_type": mime_types[i],
"meta": {
"_type": "gradio.FileData"
@@ -139,10 +139,10 @@ class Microsoft_Phi_4(AsyncGeneratorProvider, ProviderModelMixin):
} for i, image_file in enumerate(image_files)]
async with cls.run("predict", session, prompt, conversation, images) as response:
async with cls.run("predict", session, prompt, conversation, media) as response:
await raise_for_status(response)
async with cls.run("post", session, prompt, conversation, images) as response:
async with cls.run("post", session, prompt, conversation, media) as response:
await raise_for_status(response)
async with cls.run("get", session, prompt, conversation) as response:

View File

@@ -3,7 +3,7 @@ from __future__ import annotations
import json
from aiohttp import ClientSession, FormData
from ...typing import AsyncResult, Messages, ImagesType
from ...typing import AsyncResult, Messages, MediaListType
from ...requests import raise_for_status
from ...errors import ResponseError
from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin
@@ -25,7 +25,7 @@ class Qwen_QVQ_72B(AsyncGeneratorProvider, ProviderModelMixin):
@classmethod
async def create_async_generator(
cls, model: str, messages: Messages,
images: ImagesType = None,
media: MediaListType = None,
api_key: str = None,
proxy: str = None,
**kwargs
@@ -36,10 +36,10 @@ class Qwen_QVQ_72B(AsyncGeneratorProvider, ProviderModelMixin):
if api_key is not None:
headers["Authorization"] = f"Bearer {api_key}"
async with ClientSession(headers=headers) as session:
if images:
if media:
data = FormData()
data_bytes = to_bytes(images[0][0])
data.add_field("files", data_bytes, content_type=is_accepted_format(data_bytes), filename=images[0][1])
data_bytes = to_bytes(media[0][0])
data.add_field("files", data_bytes, content_type=is_accepted_format(data_bytes), filename=media[0][1])
url = f"{cls.url}/gradio_api/upload?upload_id={get_random_string()}"
async with session.post(url, data=data, proxy=proxy) as response:
await raise_for_status(response)

View File

@@ -2,7 +2,7 @@ from __future__ import annotations
import random
from ...typing import AsyncResult, Messages, ImagesType
from ...typing import AsyncResult, Messages, MediaListType
from ...errors import ResponseError
from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin
@@ -67,15 +67,15 @@ class HuggingSpace(AsyncGeneratorProvider, ProviderModelMixin):
@classmethod
async def create_async_generator(
cls, model: str, messages: Messages, images: ImagesType = None, **kwargs
cls, model: str, messages: Messages, media: MediaListType = None, **kwargs
) -> AsyncResult:
if not model and images is not None:
if not model and media is not None:
model = cls.default_vision_model
is_started = False
random.shuffle(cls.providers)
for provider in cls.providers:
if model in provider.model_aliases:
async for chunk in provider.create_async_generator(provider.model_aliases[model], messages, images=images, **kwargs):
async for chunk in provider.create_async_generator(provider.model_aliases[model], messages, media=media, **kwargs):
is_started = True
yield chunk
if is_started:

View File

@@ -6,7 +6,7 @@ import base64
from typing import Optional
from ..helper import filter_none
from ...typing import AsyncResult, Messages, ImagesType
from ...typing import AsyncResult, Messages, MediaListType
from ...requests import StreamSession, raise_for_status
from ...providers.response import FinishReason, ToolCalls, Usage
from ...errors import MissingAuthError
@@ -62,7 +62,7 @@ class Anthropic(OpenaiAPI):
messages: Messages,
proxy: str = None,
timeout: int = 120,
images: ImagesType = None,
media: MediaListType = None,
api_key: str = None,
temperature: float = None,
max_tokens: int = 4096,
@@ -79,9 +79,9 @@ class Anthropic(OpenaiAPI):
if api_key is None:
raise MissingAuthError('Add a "api_key"')
if images is not None:
if media is not None:
insert_images = []
for image, _ in images:
for image, _ in media:
data = to_bytes(image)
insert_images.append({
"type": "image",

View File

@@ -19,7 +19,7 @@ except ImportError:
has_nodriver = False
from ... import debug
from ...typing import Messages, Cookies, ImagesType, AsyncResult, AsyncIterator
from ...typing import Messages, Cookies, MediaListType, AsyncResult, AsyncIterator
from ...providers.response import JsonConversation, Reasoning, RequestLogin, ImageResponse, YouTube
from ...requests.raise_for_status import raise_for_status
from ...requests.aiohttp import get_connector
@@ -149,7 +149,7 @@ class Gemini(AsyncGeneratorProvider, ProviderModelMixin):
proxy: str = None,
cookies: Cookies = None,
connector: BaseConnector = None,
images: ImagesType = None,
media: MediaListType = None,
return_conversation: bool = False,
conversation: Conversation = None,
language: str = "en",
@@ -186,7 +186,7 @@ class Gemini(AsyncGeneratorProvider, ProviderModelMixin):
cls.start_auto_refresh()
)
images = await cls.upload_images(base_connector, images) if images else None
uploads = None if media is None else await cls.upload_images(base_connector, media)
async with ClientSession(
cookies=cls._cookies,
headers=REQUEST_HEADERS,
@@ -205,7 +205,7 @@ class Gemini(AsyncGeneratorProvider, ProviderModelMixin):
prompt,
language=language,
conversation=conversation,
images=images
uploads=uploads
))])
}
async with client.post(
@@ -327,10 +327,10 @@ class Gemini(AsyncGeneratorProvider, ProviderModelMixin):
prompt: str,
language: str,
conversation: Conversation = None,
images: list[list[str, str]] = None,
uploads: list[list[str, str]] = None,
tools: list[list[str]] = []
) -> list:
image_list = [[[image_url, 1], image_name] for image_url, image_name in images] if images else []
image_list = [[[image_url, 1], image_name] for image_url, image_name in uploads] if uploads else []
return [
[prompt, 0, None, image_list, None, None, 0],
[language],
@@ -353,7 +353,7 @@ class Gemini(AsyncGeneratorProvider, ProviderModelMixin):
0,
]
async def upload_images(connector: BaseConnector, images: ImagesType) -> list:
async def upload_images(connector: BaseConnector, media: MediaListType) -> list:
async def upload_image(image: bytes, image_name: str = None):
async with ClientSession(
headers=UPLOAD_IMAGE_HEADERS,
@@ -385,7 +385,7 @@ class Gemini(AsyncGeneratorProvider, ProviderModelMixin):
) as response:
await raise_for_status(response)
return [await response.text(), image_name]
return await asyncio.gather(*[upload_image(image, image_name) for image, image_name in images])
return await asyncio.gather(*[upload_image(image, image_name) for image, image_name in media])
@classmethod
async def fetch_snlm0e(cls, session: ClientSession, cookies: Cookies):

View File

@@ -6,8 +6,8 @@ import requests
from typing import Optional
from aiohttp import ClientSession, BaseConnector
from ...typing import AsyncResult, Messages, ImagesType
from ...image import to_bytes, is_accepted_format
from ...typing import AsyncResult, Messages, MediaListType
from ...image import to_bytes, is_data_an_media
from ...errors import MissingAuthError
from ...requests.raise_for_status import raise_for_status
from ...providers.response import Usage, FinishReason
@@ -67,7 +67,7 @@ class GeminiPro(AsyncGeneratorProvider, ProviderModelMixin):
api_key: str = None,
api_base: str = api_base,
use_auth_header: bool = False,
images: ImagesType = None,
media: MediaListType = None,
tools: Optional[list] = None,
connector: BaseConnector = None,
**kwargs
@@ -94,13 +94,13 @@ class GeminiPro(AsyncGeneratorProvider, ProviderModelMixin):
for message in messages
if message["role"] != "system"
]
if images is not None:
for image, _ in images:
if media is not None:
for media_data, filename in media:
image = to_bytes(image)
contents[-1]["parts"].append({
"inline_data": {
"mime_type": is_accepted_format(image),
"data": base64.b64encode(image).decode()
"mime_type": is_data_an_media(image, filename),
"data": base64.b64encode(media_data).decode()
}
})
data = {

View File

@@ -1,51 +1,56 @@
from __future__ import annotations
import os
import json
import random
import re
import base64
import asyncio
import time
from urllib.parse import quote_plus, unquote_plus
from pathlib import Path
from aiohttp import ClientSession, BaseConnector
from typing import Dict, Any, Optional, AsyncIterator, List
from typing import Dict, Any, AsyncIterator
from ... import debug
from ...typing import Messages, Cookies, ImagesType, AsyncResult
from ...providers.response import JsonConversation, Reasoning, ImagePreview, ImageResponse, TitleGeneration
from ...typing import Messages, Cookies, AsyncResult
from ...providers.response import JsonConversation, Reasoning, ImagePreview, ImageResponse, TitleGeneration, AuthResult, RequestLogin
from ...requests import StreamSession, get_args_from_nodriver, DEFAULT_HEADERS
from ...requests.raise_for_status import raise_for_status
from ...requests.aiohttp import get_connector
from ...requests import get_nodriver
from ...errors import MissingAuthError
from ...cookies import get_cookies_dir
from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin
from ..base_provider import AsyncAuthedProvider, ProviderModelMixin
from ..helper import format_prompt, get_cookies, get_last_user_message
class Conversation(JsonConversation):
def __init__(self,
conversation_id: str,
response_id: str,
choice_id: str,
model: str
conversation_id: str
) -> None:
self.conversation_id = conversation_id
self.response_id = response_id
self.choice_id = choice_id
self.model = model
class Grok(AsyncGeneratorProvider, ProviderModelMixin):
class Grok(AsyncAuthedProvider, ProviderModelMixin):
label = "Grok AI"
url = "https://grok.com"
cookie_domain = ".grok.com"
assets_url = "https://assets.grok.com"
conversation_url = "https://grok.com/rest/app-chat/conversations"
needs_auth = True
working = False
working = True
default_model = "grok-3"
models = [default_model, "grok-3-thinking", "grok-2"]
_cookies: Cookies = None
@classmethod
async def on_auth_async(cls, cookies: Cookies = None, proxy: str = None, **kwargs) -> AsyncIterator:
if cookies is None:
cookies = get_cookies(cls.cookie_domain, False, True, False)
if cookies is not None and "sso" in cookies:
yield AuthResult(
cookies=cookies,
impersonate="chrome",
proxy=proxy,
headers=DEFAULT_HEADERS
)
return
yield RequestLogin(cls.__name__, os.environ.get("G4F_LOGIN_URL") or "")
yield AuthResult(
**await get_args_from_nodriver(
cls.url,
proxy=proxy,
wait_for='[href="/chat#private"]'
)
)
@classmethod
async def _prepare_payload(cls, model: str, message: str) -> Dict[str, Any]:
@@ -72,63 +77,43 @@ class Grok(AsyncGeneratorProvider, ProviderModelMixin):
}
@classmethod
async def create_async_generator(
async def create_authed(
cls,
model: str,
messages: Messages,
proxy: str = None,
auth_result: AuthResult,
cookies: Cookies = None,
connector: BaseConnector = None,
images: ImagesType = None,
return_conversation: bool = False,
conversation: Optional[Conversation] = None,
conversation: Conversation = None,
**kwargs
) -> AsyncResult:
cls._cookies = cookies or cls._cookies or get_cookies(".grok.com", False, True)
if not cls._cookies:
raise MissingAuthError("Missing required cookies")
prompt = format_prompt(messages) if conversation is None else get_last_user_message(messages)
base_connector = get_connector(connector, proxy)
headers = {
"accept": "*/*",
"accept-language": "en-GB,en;q=0.9",
"content-type": "application/json",
"origin": "https://grok.com",
"priority": "u=1, i",
"referer": "https://grok.com/",
"sec-ch-ua": '"Not/A)Brand";v="8", "Chromium";v="126", "Brave";v="126"',
"sec-ch-ua-mobile": "?0",
"sec-ch-ua-platform": '"macOS"',
"sec-fetch-dest": "empty",
"sec-fetch-mode": "cors",
"sec-fetch-site": "same-origin",
"sec-gpc": "1",
"user-agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/126.0.0.0 Safari/537.36"
}
async with ClientSession(
headers=headers,
cookies=cls._cookies,
connector=base_connector
conversation_id = None if conversation is None else conversation.conversation_id
prompt = format_prompt(messages) if conversation_id is None else get_last_user_message(messages)
async with StreamSession(
**auth_result.get_dict()
) as session:
payload = await cls._prepare_payload(model, prompt)
response = await session.post(f"{cls.conversation_url}/new", json=payload)
if conversation_id is None:
url = f"{cls.conversation_url}/new"
else:
url = f"{cls.conversation_url}/{conversation_id}/responses"
async with session.post(url, json=payload) as response:
await raise_for_status(response)
thinking_duration = None
async for line in response.content:
async for line in response.iter_lines():
if line:
try:
json_data = json.loads(line)
result = json_data.get("result", {})
if conversation_id is None:
conversation_id = result.get("conversation", {}).get("conversationId")
response_data = result.get("response", {})
image = response_data.get("streamingImageGenerationResponse", None)
if image is not None:
yield ImagePreview(f'{cls.assets_url}/{image["imageUrl"]}', "", {"cookies": cookies, "headers": headers})
token = response_data.get("token", "")
is_thinking = response_data.get("isThinking", False)
token = response_data.get("token", result.get("token"))
is_thinking = response_data.get("isThinking", result.get("isThinking"))
if token:
if is_thinking:
if thinking_duration is None:
@@ -145,9 +130,12 @@ class Grok(AsyncGeneratorProvider, ProviderModelMixin):
generated_images = response_data.get("modelResponse", {}).get("generatedImageUrls", None)
if generated_images:
yield ImageResponse([f'{cls.assets_url}/{image}' for image in generated_images], "", {"cookies": cookies, "headers": headers})
title = response_data.get("title", {}).get("newTitle", "")
title = result.get("title", {}).get("newTitle", "")
if title:
yield TitleGeneration(title)
except json.JSONDecodeError:
continue
if return_conversation and conversation_id is not None:
yield Conversation(conversation_id)

View File

@@ -18,7 +18,7 @@ except ImportError:
has_nodriver = False
from ..base_provider import AsyncAuthedProvider, ProviderModelMixin
from ...typing import AsyncResult, Messages, Cookies, ImagesType
from ...typing import AsyncResult, Messages, Cookies, MediaListType
from ...requests.raise_for_status import raise_for_status
from ...requests import StreamSession
from ...requests import get_nodriver
@@ -127,7 +127,7 @@ class OpenaiChat(AsyncAuthedProvider, ProviderModelMixin):
cls,
session: StreamSession,
auth_result: AuthResult,
images: ImagesType,
media: MediaListType,
) -> ImageRequest:
"""
Upload an image to the service and get the download URL
@@ -135,7 +135,7 @@ class OpenaiChat(AsyncAuthedProvider, ProviderModelMixin):
Args:
session: The StreamSession object to use for requests
headers: The headers to include in the requests
images: The images to upload, either a PIL Image object or a bytes object
media: The images to upload, either a PIL Image object or a bytes object
Returns:
An ImageRequest object that contains the download URL, file name, and other data
@@ -187,9 +187,9 @@ class OpenaiChat(AsyncAuthedProvider, ProviderModelMixin):
await raise_for_status(response, "Get download url failed")
image_data["download_url"] = (await response.json())["download_url"]
return ImageRequest(image_data)
if not images:
if not media:
return
return [await upload_image(image, image_name) for image, image_name in images]
return [await upload_image(image, image_name) for image, image_name in media]
@classmethod
def create_messages(cls, messages: Messages, image_requests: ImageRequest = None, system_hints: list = None):
@@ -268,7 +268,7 @@ class OpenaiChat(AsyncAuthedProvider, ProviderModelMixin):
auto_continue: bool = False,
action: str = "next",
conversation: Conversation = None,
images: ImagesType = None,
media: MediaListType = None,
return_conversation: bool = False,
web_search: bool = False,
**kwargs
@@ -285,7 +285,7 @@ class OpenaiChat(AsyncAuthedProvider, ProviderModelMixin):
auto_continue (bool): Flag to automatically continue the conversation.
action (str): Type of action ('next', 'continue', 'variant').
conversation_id (str): ID of the conversation.
images (ImagesType): Images to include in the conversation.
media (MediaListType): Images to include in the conversation.
return_conversation (bool): Flag to include response fields in the output.
**kwargs: Additional keyword arguments.
@@ -316,7 +316,7 @@ class OpenaiChat(AsyncAuthedProvider, ProviderModelMixin):
cls._update_request_args(auth_result, session)
await raise_for_status(response)
try:
image_requests = await cls.upload_images(session, auth_result, images) if images else None
image_requests = None if media is None else await cls.upload_images(session, auth_result, media)
except Exception as e:
debug.error("OpenaiChat: Upload image failed")
debug.error(e)

View File

@@ -2,7 +2,7 @@ from __future__ import annotations
import json
from ...typing import Messages, AsyncResult, ImagesType
from ...typing import Messages, AsyncResult, MediaListType
from ...requests import StreamSession
from ...image import to_data_uri
from ...providers.base_provider import AsyncGeneratorProvider, ProviderModelMixin
@@ -18,21 +18,21 @@ class BackendApi(AsyncGeneratorProvider, ProviderModelMixin):
cls,
model: str,
messages: Messages,
images: ImagesType = None,
media: MediaListType = None,
api_key: str = None,
**kwargs
) -> AsyncResult:
debug.log(f"{cls.__name__}: {api_key}")
if images is not None:
for i in range(len(images)):
images[i] = (to_data_uri(images[i][0]), images[i][1])
if media is not None:
for i in range(len(media)):
media[i] = (to_data_uri(media[i][0]), media[i][1])
async with StreamSession(
headers={"Accept": "text/event-stream", **cls.headers},
) as session:
async with session.post(f"{cls.url}/backend-api/v2/conversation", json={
"model": model,
"messages": messages,
"images": images,
"media": media,
"api_key": api_key,
**kwargs
}, ssl=cls.ssl) as response:

View File

@@ -5,11 +5,11 @@ import requests
from ..helper import filter_none, format_image_prompt
from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin, RaiseErrorMixin
from ...typing import Union, Optional, AsyncResult, Messages, ImagesType
from ...typing import Union, AsyncResult, Messages, MediaListType
from ...requests import StreamSession, raise_for_status
from ...providers.response import FinishReason, ToolCalls, Usage, ImageResponse
from ...errors import MissingAuthError, ResponseError
from ...image import to_data_uri
from ...image import to_data_uri, is_data_an_audio, to_input_audio
from ... import debug
class OpenaiTemplate(AsyncGeneratorProvider, ProviderModelMixin, RaiseErrorMixin):
@@ -54,7 +54,7 @@ class OpenaiTemplate(AsyncGeneratorProvider, ProviderModelMixin, RaiseErrorMixin
messages: Messages,
proxy: str = None,
timeout: int = 120,
images: ImagesType = None,
media: MediaListType = None,
api_key: str = None,
api_endpoint: str = None,
api_base: str = None,
@@ -66,7 +66,7 @@ class OpenaiTemplate(AsyncGeneratorProvider, ProviderModelMixin, RaiseErrorMixin
prompt: str = None,
headers: dict = None,
impersonate: str = None,
extra_parameters: list[str] = ["tools", "parallel_tool_calls", "tool_choice", "reasoning_effort", "logit_bias"],
extra_parameters: list[str] = ["tools", "parallel_tool_calls", "tool_choice", "reasoning_effort", "logit_bias", "modalities", "audio"],
extra_data: dict = {},
**kwargs
) -> AsyncResult:
@@ -98,19 +98,26 @@ class OpenaiTemplate(AsyncGeneratorProvider, ProviderModelMixin, RaiseErrorMixin
yield ImageResponse([image["url"] for image in data["data"]], prompt)
return
if images is not None and messages:
if media is not None and messages:
if not model and hasattr(cls, "default_vision_model"):
model = cls.default_vision_model
last_message = messages[-1].copy()
last_message["content"] = [
*[{
*[
{
"type": "input_audio",
"input_audio": to_input_audio(media_data, filename)
}
if is_data_an_audio(media_data, filename) else {
"type": "image_url",
"image_url": {"url": to_data_uri(image)}
} for image, _ in images],
"image_url": {"url": to_data_uri(media_data, filename)}
}
for media_data, filename in media
],
{
"type": "text",
"text": messages[-1]["content"]
}
"text": last_message["content"]
} if isinstance(last_message["content"], str) else last_message["content"]
]
messages[-1] = last_message
extra_parameters = {key: kwargs[key] for key in extra_parameters if key in kwargs}

View File

@@ -34,12 +34,14 @@ class ChatCompletion:
ignore_stream: bool = False,
**kwargs) -> Union[CreateResult, str]:
if image is not None:
kwargs["images"] = [(image, image_name)]
kwargs["media"] = [(image, image_name)]
elif "images" in kwargs:
kwargs["media"] = kwargs.pop("images")
model, provider = get_model_and_provider(
model, provider, stream,
ignore_working,
ignore_stream,
has_images="images" in kwargs,
has_images="media" in kwargs,
)
if "proxy" not in kwargs:
proxy = os.environ.get("G4F_PROXY")
@@ -63,8 +65,10 @@ class ChatCompletion:
ignore_working: bool = False,
**kwargs) -> Union[AsyncResult, Coroutine[str]]:
if image is not None:
kwargs["images"] = [(image, image_name)]
model, provider = get_model_and_provider(model, provider, False, ignore_working, has_images="images" in kwargs)
kwargs["media"] = [(image, image_name)]
elif "images" in kwargs:
kwargs["media"] = kwargs.pop("images")
model, provider = get_model_and_provider(model, provider, False, ignore_working, has_images="media" in kwargs)
if "proxy" not in kwargs:
proxy = os.environ.get("G4F_PROXY")
if proxy:

View File

@@ -38,7 +38,7 @@ import g4f.debug
from g4f.client import AsyncClient, ChatCompletion, ImagesResponse, convert_to_provider
from g4f.providers.response import BaseConversation, JsonConversation
from g4f.client.helper import filter_none
from g4f.image import is_data_uri_an_media
from g4f.image import is_data_an_media
from g4f.image.copy_images import images_dir, copy_images, get_source_url
from g4f.errors import ProviderNotFoundError, ModelNotFoundError, MissingAuthError, NoValidHarFileError
from g4f.cookies import read_cookie_files, get_cookies_dir
@@ -320,16 +320,18 @@ class Api:
if config.image is not None:
try:
is_data_uri_an_media(config.image)
is_data_an_media(config.image)
except ValueError as e:
return ErrorResponse.from_message(f"The image you send must be a data URI. Example: data:image/jpeg;base64,...", status_code=HTTP_422_UNPROCESSABLE_ENTITY)
if config.images is not None:
for image in config.images:
if config.media is None:
config.media = config.images
if config.media is not None:
for image in config.media:
try:
is_data_uri_an_media(image[0])
is_data_an_media(image[0])
except ValueError as e:
example = json.dumps({"images": [["data:image/jpeg;base64,...", "filename"]]})
return ErrorResponse.from_message(f'The image you send must be a data URI. Example: {example}', status_code=HTTP_422_UNPROCESSABLE_ENTITY)
example = json.dumps({"media": [["data:image/jpeg;base64,...", "filename.jpg"]]})
return ErrorResponse.from_message(f'The media you send must be a data URIs. Example: {example}', status_code=HTTP_422_UNPROCESSABLE_ENTITY)
# Create the completion response
response = self.client.chat.completions.create(

View File

@@ -17,6 +17,7 @@ class ChatCompletionsConfig(BaseModel):
image: Optional[str] = None
image_name: Optional[str] = None
images: Optional[list[tuple[str, str]]] = None
media: Optional[list[tuple[str, str]]] = None
temperature: Optional[float] = None
presence_penalty: Optional[float] = None
frequency_penalty: Optional[float] = None

View File

@@ -293,14 +293,16 @@ class Completions:
if isinstance(messages, str):
messages = [{"role": "user", "content": messages}]
if image is not None:
kwargs["images"] = [(image, image_name)]
kwargs["media"] = [(image, image_name)]
elif "images" in kwargs:
kwargs["media"] = kwargs.pop("images")
model, provider = get_model_and_provider(
model,
self.provider if provider is None else provider,
stream,
ignore_working,
ignore_stream,
has_images="images" in kwargs
has_images="media" in kwargs
)
stop = [stop] if isinstance(stop, str) else stop
if ignore_stream:
@@ -481,7 +483,7 @@ class Images:
proxy = self.client.proxy
prompt = "create a variation of this image"
if image is not None:
kwargs["images"] = [(image, None)]
kwargs["media"] = [(image, None)]
error = None
response = None
@@ -581,14 +583,16 @@ class AsyncCompletions:
if isinstance(messages, str):
messages = [{"role": "user", "content": messages}]
if image is not None:
kwargs["images"] = [(image, image_name)]
kwargs["media"] = [(image, image_name)]
elif "images" in kwargs:
kwargs["media"] = kwargs.pop("images")
model, provider = get_model_and_provider(
model,
self.provider if provider is None else provider,
stream,
ignore_working,
ignore_stream,
has_images="images" in kwargs,
has_images="media" in kwargs,
)
stop = [stop] if isinstance(stop, str) else stop
if ignore_stream:

View File

@@ -67,7 +67,7 @@ DOMAINS = [
if has_browser_cookie3 and os.environ.get('DBUS_SESSION_BUS_ADDRESS') == "/dev/null":
_LinuxPasswordManager.get_password = lambda a, b: b"secret"
def get_cookies(domain_name: str = '', raise_requirements_error: bool = True, single_browser: bool = False) -> Dict[str, str]:
def get_cookies(domain_name: str, raise_requirements_error: bool = True, single_browser: bool = False, cache_result: bool = True) -> Dict[str, str]:
"""
Load cookies for a given domain from all supported browsers and cache the results.
@@ -77,10 +77,11 @@ def get_cookies(domain_name: str = '', raise_requirements_error: bool = True, si
Returns:
Dict[str, str]: A dictionary of cookie names and values.
"""
if domain_name in CookiesConfig.cookies:
if cache_result and domain_name in CookiesConfig.cookies:
return CookiesConfig.cookies[domain_name]
cookies = load_cookies_from_browsers(domain_name, raise_requirements_error, single_browser)
if cache_result:
CookiesConfig.cookies[domain_name] = cookies
return cookies
@@ -108,8 +109,8 @@ def load_cookies_from_browsers(domain_name: str, raise_requirements_error: bool
for cookie_fn in browsers:
try:
cookie_jar = cookie_fn(domain_name=domain_name)
if len(cookie_jar) and debug.logging:
print(f"Read cookies from {cookie_fn.__name__} for {domain_name}")
if len(cookie_jar):
debug.log(f"Read cookies from {cookie_fn.__name__} for {domain_name}")
for cookie in cookie_jar:
if cookie.name not in cookies:
if not cookie.expires or cookie.expires > time.time():
@@ -119,8 +120,7 @@ def load_cookies_from_browsers(domain_name: str, raise_requirements_error: bool
except BrowserCookieError:
pass
except Exception as e:
if debug.logging:
print(f"Error reading cookies from {cookie_fn.__name__} for {domain_name}: {e}")
debug.error(f"Error reading cookies from {cookie_fn.__name__} for {domain_name}: {e}")
return cookies
def set_cookies_dir(dir: str) -> None:

View File

@@ -86,6 +86,23 @@
height: 100%;
text-align: center;
z-index: 1;
transition: transform 0.25s ease-in;
}
.container.slide {
transform: translateX(-100%);
transition: transform 0.15s ease-out;
}
.slide-button {
position: absolute;
top: 20px;
left: 20px;
background: var(--colour-4);
border: none;
border-radius: 4px;
cursor: pointer;
padding: 6px 8px;
}
header {
@@ -105,7 +122,8 @@
height: 100%;
position: absolute;
z-index: -1;
object-fit: contain;
object-fit: cover;
object-position: center;
width: 100%;
background: black;
}
@@ -181,6 +199,7 @@
}
}
</style>
<link rel="stylesheet" href="/static/css/all.min.css">
<script>
(async () => {
const isIframe = window.self !== window.top;
@@ -208,6 +227,10 @@
<!-- Gradient Background Circle -->
<div class="gradient"></div>
<button class="slide-button">
<i class="fa-solid fa-arrow-left"></i>
</button>
<!-- Main Content -->
<div class="container">
<header>
@@ -277,8 +300,11 @@
if (oauthResult) {
try {
oauthResult = JSON.parse(oauthResult);
user = await hub.whoAmI({accessToken: oauthResult.accessToken});
} catch {
oauthResult = null;
localStorage.removeItem("oauth");
localStorage.removeItem("HuggingFace-api_key");
}
}
oauthResult ||= await oauthHandleRedirectIfPresent();
@@ -339,7 +365,7 @@
return;
}
const lower = data.prompt.toLowerCase();
const tags = ["nsfw", "timeline", "blood", "soap", "orally", "heel", "latex", "bathroom", "boobs", "charts", " text ", "gel", "logo", "infographic", "warts", " bra ", "prostitute", "curvy", "breasts", "written", "bodies", "naked", "classroom", "malone", "dirty", "shoes", "shower", "banner", "fat", "nipples", "couple", "sexual", "sandal", "supplier", "overlord", "succubus", "platinum", "cracy", "crazy", "lamic", "ropes", "cables", "wires", "dirty", "messy", "cluttered", "chaotic", "disorganized", "disorderly", "untidy", "unorganized", "unorderly", "unsystematic", "disarranged", "disarrayed", "disheveled", "disordered", "jumbled", "muddled", "scattered", "shambolic", "sloppy", "unkept", "unruly"];
const tags = ["nsfw", "timeline", "feet", "blood", "soap", "orally", "heel", "latex", "bathroom", "boobs", "charts", " text ", "gel", "logo", "infographic", "warts", " bra ", "prostitute", "curvy", "breasts", "written", "bodies", "naked", "classroom", "malone", "dirty", "shoes", "shower", "banner", "fat", "nipples", "couple", "sexual", "sandal", "supplier", "overlord", "succubus", "platinum", "cracy", "crazy", "hemale", "oprah", "lamic", "ropes", "cables", "wires", "dirty", "messy", "cluttered", "chaotic", "disorganized", "disorderly", "untidy", "unorganized", "unorderly", "unsystematic", "disarranged", "disarrayed", "disheveled", "disordered", "jumbled", "muddled", "scattered", "shambolic", "sloppy", "unkept", "unruly"];
for (i in tags) {
if (lower.indexOf(tags[i]) != -1) {
console.log("Skipping image with tag: " + tags[i]);
@@ -363,6 +389,21 @@
imageFeed.remove();
}
}, 7000);
const container = document.querySelector('.container');
const button = document.querySelector('.slide-button');
const slideIcon = button.querySelector('i');
button.onclick = () => {
if (container.classList.contains('slide')) {
container.classList.remove('slide');
slideIcon.classList.remove('fa-arrow-right');
slideIcon.classList.add('fa-arrow-left');
} else {
container.classList.add('slide');
slideIcon.classList.remove('fa-arrow-left');
slideIcon.classList.add('fa-arrow-right');
}
}
})();
</script>
</body>

View File

@@ -141,7 +141,7 @@ class Api:
stream=True,
ignore_stream=True,
logging=False,
has_images="images" in kwargs,
has_images="media" in kwargs,
)
except Exception as e:
debug.error(e)

View File

@@ -8,6 +8,7 @@ import asyncio
import shutil
import random
import datetime
import tempfile
from flask import Flask, Response, request, jsonify, render_template
from typing import Generator
from pathlib import Path
@@ -111,11 +112,13 @@ class Backend_Api(Api):
else:
json_data = request.json
if "files" in request.files:
images = []
media = []
for file in request.files.getlist('files'):
if file.filename != '' and is_allowed_extension(file.filename):
images.append((to_image(file.stream, file.filename.endswith('.svg')), file.filename))
json_data['images'] = images
newfile = tempfile.TemporaryFile()
shutil.copyfileobj(file.stream, newfile)
media.append((newfile, file.filename))
json_data['media'] = media
if app.demo and not json_data.get("provider"):
model = json_data.get("model")

View File

@@ -76,24 +76,24 @@ def is_allowed_extension(filename: str) -> bool:
return '.' in filename and \
filename.rsplit('.', 1)[1].lower() in ALLOWED_EXTENSIONS
def is_data_uri_an_media(data_uri: str) -> str:
return is_data_an_wav(data_uri) or is_data_uri_an_image(data_uri)
def is_data_an_media(data, filename: str = None) -> str:
content_type = is_data_an_audio(data, filename)
if content_type is not None:
return content_type
if isinstance(data, bytes):
return is_accepted_format(data)
return is_data_uri_an_image(data)
def is_data_an_wav(data_uri: str, filename: str = None) -> str:
"""
Checks if the given data URI represents an image.
Args:
data_uri (str): The data URI to check.
Raises:
ValueError: If the data URI is invalid or the image format is not allowed.
"""
if filename and filename.endswith(".wav"):
return "audio/wav"
# Check if the data URI starts with 'data:image' and contains an image format (e.g., jpeg, png, gif)
if isinstance(data_uri, str) and re.match(r'data:audio/wav;base64,', data_uri):
def is_data_an_audio(data_uri: str, filename: str = None) -> str:
if filename:
if filename.endswith(".wav"):
return "audio/wav"
elif filename.endswith(".mp3"):
return "audio/mpeg"
if isinstance(data_uri, str):
audio_format = re.match(r'^data:(audio/\w+);base64,', data_uri)
if audio_format:
return audio_format.group(1)
def is_data_uri_an_image(data_uri: str) -> bool:
"""
@@ -218,7 +218,7 @@ def to_bytes(image: ImageType) -> bytes:
if isinstance(image, bytes):
return image
elif isinstance(image, str) and image.startswith("data:"):
is_data_uri_an_media(image)
is_data_an_media(image)
return extract_data_uri(image)
elif isinstance(image, Image):
bytes_io = BytesIO()
@@ -236,13 +236,29 @@ def to_bytes(image: ImageType) -> bytes:
pass
return image.read()
def to_data_uri(image: ImageType) -> str:
def to_data_uri(image: ImageType, filename: str = None) -> str:
if not isinstance(image, str):
data = to_bytes(image)
data_base64 = base64.b64encode(data).decode()
return f"data:{is_accepted_format(data)};base64,{data_base64}"
return f"data:{is_data_an_media(data, filename)};base64,{data_base64}"
return image
def to_input_audio(audio: ImageType, filename: str = None) -> str:
if not isinstance(audio, str):
if filename is not None and (filename.endswith(".wav") or filename.endswith(".mp3")):
return {
"data": base64.b64encode(to_bytes(audio)).decode(),
"format": "wav" if filename.endswith(".wav") else "mpeg"
}
raise ValueError("Invalid input audio")
audio = re.match(r'^data:audio/(\w+);base64,(.+?)', audio)
if audio:
return {
"data": audio.group(2),
"format": audio.group(1),
}
raise ValueError("Invalid input audio")
class ImageDataResponse():
def __init__(
self,

View File

@@ -109,7 +109,7 @@ async def copy_images(
f.write(chunk)
# Verify file format
if not os.path.splitext(target_path)[1]:
if target is None and not os.path.splitext(target_path)[1]:
with open(target_path, "rb") as f:
file_header = f.read(12)
detected_type = is_accepted_format(file_header)
@@ -120,7 +120,7 @@ async def copy_images(
# Build URL with safe encoding
url_filename = quote(os.path.basename(target_path))
return f"/images/{url_filename}{'?url=' + quote(image) if add_url and not image.startswith('data:') else ''}"
return f"/images/{url_filename}" + (('?url=' + quote(image)) if add_url and not image.startswith('data:') else '')
except (ClientError, IOError, OSError) as e:
debug.error(f"Image copying failed: {type(e).__name__}: {e}")

View File

@@ -25,7 +25,7 @@ from .. import debug
SAFE_PARAMETERS = [
"model", "messages", "stream", "timeout",
"proxy", "images", "response_format",
"proxy", "media", "response_format",
"prompt", "negative_prompt", "tools", "conversation",
"history_disabled",
"temperature", "top_k", "top_p",
@@ -56,7 +56,7 @@ PARAMETER_EXAMPLES = {
"frequency_penalty": 1,
"presence_penalty": 1,
"messages": [{"role": "system", "content": ""}, {"role": "user", "content": ""}],
"images": [["data:image/jpeg;base64,...", "filename.jpg"]],
"media": [["data:image/jpeg;base64,...", "filename.jpg"]],
"response_format": {"type": "json_object"},
"conversation": {"conversation_id": "550e8400-e29b-11d4-a716-...", "message_id": "550e8400-e29b-11d4-a716-..."},
"seed": 42,

View File

@@ -2,7 +2,7 @@ from __future__ import annotations
import json
from ..typing import AsyncResult, Messages, ImagesType
from ..typing import AsyncResult, Messages, MediaListType
from ..client.service import get_model_and_provider
from ..client.helper import filter_json
from .base_provider import AsyncGeneratorProvider
@@ -17,7 +17,7 @@ class ToolSupportProvider(AsyncGeneratorProvider):
model: str,
messages: Messages,
stream: bool = True,
images: ImagesType = None,
media: MediaListType = None,
tools: list[str] = None,
response_format: dict = None,
**kwargs
@@ -28,7 +28,7 @@ class ToolSupportProvider(AsyncGeneratorProvider):
model, provider = get_model_and_provider(
model, provider,
stream, logging=False,
has_images=images is not None
has_images=media is not None
)
if tools is not None:
if len(tools) > 1:
@@ -49,7 +49,7 @@ class ToolSupportProvider(AsyncGeneratorProvider):
model,
messages,
stream=stream,
images=images,
media=media,
response_format=response_format,
**kwargs
):

View File

@@ -21,7 +21,7 @@ AsyncResult = AsyncIterator[Union[str, ResponseType]]
Messages = List[Dict[str, Union[str, List[Dict[str, Union[str, Dict[str, str]]]]]]]
Cookies = Dict[str, str]
ImageType = Union[str, bytes, IO, Image, os.PathLike]
ImagesType = List[Tuple[ImageType, Optional[str]]]
MediaListType = List[Tuple[ImageType, Optional[str]]]
__all__ = [
'Any',
@@ -44,5 +44,5 @@ __all__ = [
'Cookies',
'Image',
'ImageType',
'ImagesType'
'MediaListType'
]