mirror of
https://github.com/xtekky/gpt4free.git
synced 2025-10-05 16:26:57 +08:00
Add audio transcribing example and support
Add Grok Chat provider Rename images parameter to media Update demo homepage
This commit is contained in:
@@ -19,6 +19,7 @@ The G4F AsyncClient API is designed to be compatible with the OpenAI API, making
|
||||
- [Text Completions](#text-completions)
|
||||
- [Streaming Completions](#streaming-completions)
|
||||
- [Using a Vision Model](#using-a-vision-model)
|
||||
- **[Transcribing Audio with Chat Completions](#transcribing-audio-with-chat-completions)** *(New Section)*
|
||||
- [Image Generation](#image-generation)
|
||||
- [Advanced Usage](#advanced-usage)
|
||||
- [Conversation Memory](#conversation-memory)
|
||||
@@ -203,6 +204,54 @@ async def main():
|
||||
asyncio.run(main())
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### Transcribing Audio with Chat Completions
|
||||
|
||||
Some providers in G4F support audio inputs in chat completions, allowing you to transcribe audio files by instructing the model accordingly. This example demonstrates how to use the `AsyncClient` to transcribe an audio file asynchronously:
|
||||
|
||||
```python
|
||||
import asyncio
|
||||
from g4f.client import AsyncClient
|
||||
import g4f.Provider
|
||||
import g4f.models
|
||||
|
||||
async def main():
|
||||
client = AsyncClient(provider=g4f.Provider.PollinationsAI) # or g4f.Provider.Microsoft_Phi_4
|
||||
|
||||
with open("audio.wav", "rb") as audio_file:
|
||||
response = await client.chat.completions.create(
|
||||
model=g4f.models.default,
|
||||
messages=[{"role": "user", "content": "Transcribe this audio"}],
|
||||
media=[[audio_file, "audio.wav"]],
|
||||
modalities=["text"],
|
||||
)
|
||||
|
||||
print(response.choices[0].message.content)
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
```
|
||||
|
||||
#### Explanation
|
||||
- **Client Initialization**: An `AsyncClient` instance is created with a provider that supports audio inputs, such as `PollinationsAI` or `Microsoft_Phi_4`.
|
||||
- **File Handling**: The audio file (`audio.wav`) is opened in binary read mode (`"rb"`) using a context manager (`with` statement) to ensure proper file closure after use.
|
||||
- **API Call**: The `chat.completions.create` method is called with:
|
||||
- `model=g4f.models.default`: Uses the default model for the selected provider.
|
||||
- `messages`: A list containing a user message instructing the model to transcribe the audio.
|
||||
- `media`: A list of lists, where each inner list contains the file object and its name (`[[audio_file, "audio.wav"]]`).
|
||||
- `modalities=["text"]`: Specifies that the output should be text (the transcription).
|
||||
- **Response**: The transcription is extracted from `response.choices[0].message.content` and printed.
|
||||
|
||||
#### Notes
|
||||
- **Provider Support**: Ensure the chosen provider (e.g., `PollinationsAI` or `Microsoft_Phi_4`) supports audio inputs in chat completions. Not all providers may offer this functionality.
|
||||
- **File Path**: Replace `"audio.wav"` with the path to your own audio file. The file format (e.g., WAV) should be compatible with the provider.
|
||||
- **Model Selection**: If `g4f.models.default` does not support audio transcription, you may need to specify a model that does (consult the provider's documentation for supported models).
|
||||
|
||||
This example complements the guide by showcasing how to handle audio inputs asynchronously, expanding on the multimodal capabilities of the G4F AsyncClient API.
|
||||
|
||||
---
|
||||
|
||||
### Image Generation
|
||||
**The `response_format` parameter is optional and can have the following values:**
|
||||
- **If not specified (default):** The image will be saved locally, and a local path will be returned (e.g., "/images/1733331238_cf9d6aa9-f606-4fea-ba4b-f06576cba309.jpg").
|
||||
|
@@ -6,10 +6,10 @@ from g4f.models import __models__
|
||||
from g4f.providers.base_provider import BaseProvider, ProviderModelMixin
|
||||
from g4f.errors import MissingRequirementsError, MissingAuthError
|
||||
|
||||
class TestProviderHasModel(unittest.IsolatedAsyncioTestCase):
|
||||
class TestProviderHasModel(unittest.TestCase):
|
||||
cache: dict = {}
|
||||
|
||||
async def test_provider_has_model(self):
|
||||
def test_provider_has_model(self):
|
||||
for model, providers in __models__.values():
|
||||
for provider in providers:
|
||||
if issubclass(provider, ProviderModelMixin):
|
||||
@@ -17,9 +17,9 @@ class TestProviderHasModel(unittest.IsolatedAsyncioTestCase):
|
||||
model_name = provider.model_aliases[model.name]
|
||||
else:
|
||||
model_name = model.name
|
||||
await asyncio.wait_for(self.provider_has_model(provider, model_name), 10)
|
||||
self.provider_has_model(provider, model_name)
|
||||
|
||||
async def provider_has_model(self, provider: Type[BaseProvider], model: str):
|
||||
def provider_has_model(self, provider: Type[BaseProvider], model: str):
|
||||
if provider.__name__ not in self.cache:
|
||||
try:
|
||||
self.cache[provider.__name__] = provider.get_models()
|
||||
@@ -28,7 +28,7 @@ class TestProviderHasModel(unittest.IsolatedAsyncioTestCase):
|
||||
if self.cache[provider.__name__]:
|
||||
self.assertIn(model, self.cache[provider.__name__], provider.__name__)
|
||||
|
||||
async def test_all_providers_working(self):
|
||||
def test_all_providers_working(self):
|
||||
for model, providers in __models__.values():
|
||||
for provider in providers:
|
||||
self.assertTrue(provider.working, f"{provider.__name__} in {model.name}")
|
@@ -48,7 +48,7 @@ class AllenAI(AsyncGeneratorProvider, ProviderModelMixin):
|
||||
"olmo-2-13b": "OLMo-2-1124-13B-Instruct",
|
||||
"tulu-3-1-8b": "tulu-3-1-8b",
|
||||
"tulu-3-70b": "Llama-3-1-Tulu-3-70B",
|
||||
"llama-3.1-405b": "tulu-3-405b",
|
||||
"llama-3.1-405b": "tulu3-405b",
|
||||
"llama-3.1-8b": "tulu-3-1-8b",
|
||||
"llama-3.1-70b": "Llama-3-1-Tulu-3-70B",
|
||||
}
|
||||
|
@@ -11,7 +11,7 @@ from pathlib import Path
|
||||
from typing import Optional
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
from ..typing import AsyncResult, Messages, ImagesType
|
||||
from ..typing import AsyncResult, Messages, MediaListType
|
||||
from ..requests.raise_for_status import raise_for_status
|
||||
from .base_provider import AsyncGeneratorProvider, ProviderModelMixin
|
||||
from ..image import to_data_uri
|
||||
@@ -444,7 +444,7 @@ class Blackbox(AsyncGeneratorProvider, ProviderModelMixin):
|
||||
messages: Messages,
|
||||
prompt: str = None,
|
||||
proxy: str = None,
|
||||
images: ImagesType = None,
|
||||
media: MediaListType = None,
|
||||
top_p: float = None,
|
||||
temperature: float = None,
|
||||
max_tokens: int = None,
|
||||
@@ -479,14 +479,14 @@ class Blackbox(AsyncGeneratorProvider, ProviderModelMixin):
|
||||
}
|
||||
current_messages.append(current_msg)
|
||||
|
||||
if images is not None:
|
||||
if media is not None:
|
||||
current_messages[-1]['data'] = {
|
||||
"imagesData": [
|
||||
{
|
||||
"filePath": f"/{image_name}",
|
||||
"contents": to_data_uri(image)
|
||||
}
|
||||
for image, image_name in images
|
||||
for image, image_name in media
|
||||
],
|
||||
"fileText": "",
|
||||
"title": ""
|
||||
|
@@ -21,7 +21,7 @@ except ImportError:
|
||||
from .base_provider import AbstractProvider, ProviderModelMixin
|
||||
from .helper import format_prompt_max_length
|
||||
from .openai.har_file import get_headers, get_har_files
|
||||
from ..typing import CreateResult, Messages, ImagesType
|
||||
from ..typing import CreateResult, Messages, MediaListType
|
||||
from ..errors import MissingRequirementsError, NoValidHarFileError, MissingAuthError
|
||||
from ..requests.raise_for_status import raise_for_status
|
||||
from ..providers.response import BaseConversation, JsonConversation, RequestLogin, Parameters, ImageResponse
|
||||
@@ -66,7 +66,7 @@ class Copilot(AbstractProvider, ProviderModelMixin):
|
||||
proxy: str = None,
|
||||
timeout: int = 900,
|
||||
prompt: str = None,
|
||||
images: ImagesType = None,
|
||||
media: MediaListType = None,
|
||||
conversation: BaseConversation = None,
|
||||
return_conversation: bool = False,
|
||||
api_key: str = None,
|
||||
@@ -77,7 +77,7 @@ class Copilot(AbstractProvider, ProviderModelMixin):
|
||||
|
||||
websocket_url = cls.websocket_url
|
||||
headers = None
|
||||
if cls.needs_auth or images is not None:
|
||||
if cls.needs_auth or media is not None:
|
||||
if api_key is not None:
|
||||
cls._access_token = api_key
|
||||
if cls._access_token is None:
|
||||
@@ -142,8 +142,8 @@ class Copilot(AbstractProvider, ProviderModelMixin):
|
||||
debug.log(f"Copilot: Use conversation: {conversation_id}")
|
||||
|
||||
uploaded_images = []
|
||||
if images is not None:
|
||||
for image, _ in images:
|
||||
if media is not None:
|
||||
for image, _ in media:
|
||||
data = to_bytes(image)
|
||||
response = session.post(
|
||||
"https://copilot.microsoft.com/c/api/attachments",
|
||||
|
@@ -30,7 +30,7 @@ class DDG(AsyncGeneratorProvider, ProviderModelMixin):
|
||||
api_endpoint = "https://duckduckgo.com/duckchat/v1/chat"
|
||||
status_url = "https://duckduckgo.com/duckchat/v1/status"
|
||||
|
||||
working = True
|
||||
working = False
|
||||
supports_stream = True
|
||||
supports_system_message = True
|
||||
supports_message_history = True
|
||||
|
@@ -1,6 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from ..typing import AsyncResult, Messages, ImagesType
|
||||
from ..typing import AsyncResult, Messages, MediaListType
|
||||
from .template import OpenaiTemplate
|
||||
from ..image import to_data_uri
|
||||
|
||||
@@ -70,7 +70,6 @@ class DeepInfraChat(OpenaiTemplate):
|
||||
temperature: float = 0.7,
|
||||
max_tokens: int = None,
|
||||
headers: dict = {},
|
||||
images: ImagesType = None,
|
||||
**kwargs
|
||||
) -> AsyncResult:
|
||||
headers = {
|
||||
@@ -82,23 +81,6 @@ class DeepInfraChat(OpenaiTemplate):
|
||||
**headers
|
||||
}
|
||||
|
||||
if images is not None:
|
||||
if not model or model not in cls.models:
|
||||
model = cls.default_vision_model
|
||||
if messages:
|
||||
last_message = messages[-1].copy()
|
||||
last_message["content"] = [
|
||||
*[{
|
||||
"type": "image_url",
|
||||
"image_url": {"url": to_data_uri(image)}
|
||||
} for image, _ in images],
|
||||
{
|
||||
"type": "text",
|
||||
"text": last_message["content"]
|
||||
}
|
||||
]
|
||||
messages[-1] = last_message
|
||||
|
||||
async for chunk in super().create_async_generator(
|
||||
model,
|
||||
messages,
|
||||
|
@@ -3,13 +3,12 @@ from __future__ import annotations
|
||||
import json
|
||||
from aiohttp import ClientSession, FormData
|
||||
|
||||
from ..typing import AsyncResult, Messages, ImagesType
|
||||
from ..typing import AsyncResult, Messages, MediaListType
|
||||
from .base_provider import AsyncGeneratorProvider, ProviderModelMixin
|
||||
from ..requests.raise_for_status import raise_for_status
|
||||
from ..image import to_data_uri, to_bytes, is_accepted_format
|
||||
from ..image import to_bytes, is_accepted_format
|
||||
from .helper import format_prompt
|
||||
|
||||
|
||||
class Dynaspark(AsyncGeneratorProvider, ProviderModelMixin):
|
||||
url = "https://dynaspark.onrender.com"
|
||||
login_url = None
|
||||
@@ -38,7 +37,7 @@ class Dynaspark(AsyncGeneratorProvider, ProviderModelMixin):
|
||||
model: str,
|
||||
messages: Messages,
|
||||
proxy: str = None,
|
||||
images: ImagesType = None,
|
||||
media: MediaListType = None,
|
||||
**kwargs
|
||||
) -> AsyncResult:
|
||||
headers = {
|
||||
@@ -49,14 +48,13 @@ class Dynaspark(AsyncGeneratorProvider, ProviderModelMixin):
|
||||
'user-agent': 'Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/133.0.0.0 Safari/537.36',
|
||||
'x-requested-with': 'XMLHttpRequest'
|
||||
}
|
||||
|
||||
async with ClientSession(headers=headers) as session:
|
||||
form = FormData()
|
||||
form.add_field('user_input', format_prompt(messages))
|
||||
form.add_field('ai_model', model)
|
||||
|
||||
if images is not None and len(images) > 0:
|
||||
image, image_name = images[0]
|
||||
if media is not None and len(media) > 0:
|
||||
image, image_name = media[0]
|
||||
image_bytes = to_bytes(image)
|
||||
form.add_field('file', image_bytes, filename=image_name, content_type=is_accepted_format(image_bytes))
|
||||
|
||||
|
@@ -14,6 +14,7 @@ class LambdaChat(HuggingChat):
|
||||
default_model = "deepseek-llama3.3-70b"
|
||||
reasoning_model = "deepseek-r1"
|
||||
image_models = []
|
||||
models = []
|
||||
fallback_models = [
|
||||
default_model,
|
||||
reasoning_model,
|
||||
|
@@ -9,8 +9,8 @@ from aiohttp import ClientSession
|
||||
|
||||
from .helper import filter_none, format_image_prompt
|
||||
from .base_provider import AsyncGeneratorProvider, ProviderModelMixin
|
||||
from ..typing import AsyncResult, Messages, ImagesType
|
||||
from ..image import to_data_uri
|
||||
from ..typing import AsyncResult, Messages, MediaListType
|
||||
from ..image import to_data_uri, is_data_an_audio, to_input_audio
|
||||
from ..errors import ModelNotFoundError
|
||||
from ..requests.raise_for_status import raise_for_status
|
||||
from ..requests.aiohttp import get_connector
|
||||
@@ -146,17 +146,24 @@ class PollinationsAI(AsyncGeneratorProvider, ProviderModelMixin):
|
||||
enhance: bool = False,
|
||||
safe: bool = False,
|
||||
# Text generation parameters
|
||||
images: ImagesType = None,
|
||||
media: MediaListType = None,
|
||||
temperature: float = None,
|
||||
presence_penalty: float = None,
|
||||
top_p: float = 1,
|
||||
frequency_penalty: float = None,
|
||||
response_format: Optional[dict] = None,
|
||||
extra_parameters: list[str] = ["tools", "parallel_tool_calls", "tool_choice", "reasoning_effort", "logit_bias", "voice"],
|
||||
extra_parameters: list[str] = ["tools", "parallel_tool_calls", "tool_choice", "reasoning_effort", "logit_bias", "voice", "modalities"],
|
||||
**kwargs
|
||||
) -> AsyncResult:
|
||||
# Load model list
|
||||
cls.get_models()
|
||||
if not model and media is not None:
|
||||
has_audio = False
|
||||
for media_data, filename in media:
|
||||
if is_data_an_audio(media_data, filename):
|
||||
has_audio = True
|
||||
break
|
||||
model = next(iter(cls.audio_models)) if has_audio else cls.default_vision_model
|
||||
try:
|
||||
model = cls.get_model(model)
|
||||
except ModelNotFoundError:
|
||||
@@ -182,7 +189,7 @@ class PollinationsAI(AsyncGeneratorProvider, ProviderModelMixin):
|
||||
async for result in cls._generate_text(
|
||||
model=model,
|
||||
messages=messages,
|
||||
images=images,
|
||||
media=media,
|
||||
proxy=proxy,
|
||||
temperature=temperature,
|
||||
presence_penalty=presence_penalty,
|
||||
@@ -239,7 +246,7 @@ class PollinationsAI(AsyncGeneratorProvider, ProviderModelMixin):
|
||||
cls,
|
||||
model: str,
|
||||
messages: Messages,
|
||||
images: Optional[ImagesType],
|
||||
media: MediaListType,
|
||||
proxy: str,
|
||||
temperature: float,
|
||||
presence_penalty: float,
|
||||
@@ -258,14 +265,18 @@ class PollinationsAI(AsyncGeneratorProvider, ProviderModelMixin):
|
||||
if response_format and response_format.get("type") == "json_object":
|
||||
json_mode = True
|
||||
|
||||
if images and messages:
|
||||
if media and messages:
|
||||
last_message = messages[-1].copy()
|
||||
image_content = [
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {"url": to_data_uri(image)}
|
||||
"type": "input_audio",
|
||||
"input_audio": to_input_audio(media_data, filename)
|
||||
}
|
||||
for image, _ in images
|
||||
if is_data_an_audio(media_data, filename) else {
|
||||
"type": "image_url",
|
||||
"image_url": {"url": to_data_uri(media_data)}
|
||||
}
|
||||
for media_data, filename in media
|
||||
]
|
||||
last_message["content"] = image_content + [{"type": "text", "text": last_message["content"]}]
|
||||
messages[-1] = last_message
|
||||
|
@@ -17,7 +17,7 @@ except ImportError:
|
||||
|
||||
from ..base_provider import ProviderModelMixin, AsyncAuthedProvider, AuthResult
|
||||
from ..helper import format_prompt, format_image_prompt, get_last_user_message
|
||||
from ...typing import AsyncResult, Messages, Cookies, ImagesType
|
||||
from ...typing import AsyncResult, Messages, Cookies, MediaListType
|
||||
from ...errors import MissingRequirementsError, MissingAuthError, ResponseError
|
||||
from ...image import to_bytes
|
||||
from ...requests import get_args_from_nodriver, DEFAULT_HEADERS
|
||||
@@ -99,7 +99,7 @@ class HuggingChat(AsyncAuthedProvider, ProviderModelMixin):
|
||||
messages: Messages,
|
||||
auth_result: AuthResult,
|
||||
prompt: str = None,
|
||||
images: ImagesType = None,
|
||||
media: MediaListType = None,
|
||||
return_conversation: bool = False,
|
||||
conversation: Conversation = None,
|
||||
web_search: bool = False,
|
||||
@@ -108,7 +108,7 @@ class HuggingChat(AsyncAuthedProvider, ProviderModelMixin):
|
||||
if not has_curl_cffi:
|
||||
raise MissingRequirementsError('Install "curl_cffi" package | pip install -U curl_cffi')
|
||||
if model == llama_models["name"]:
|
||||
model = llama_models["text"] if images is None else llama_models["vision"]
|
||||
model = llama_models["text"] if media is None else llama_models["vision"]
|
||||
model = cls.get_model(model)
|
||||
|
||||
session = Session(**auth_result.get_dict())
|
||||
@@ -145,8 +145,8 @@ class HuggingChat(AsyncAuthedProvider, ProviderModelMixin):
|
||||
}
|
||||
data = CurlMime()
|
||||
data.addpart('data', data=json.dumps(settings, separators=(',', ':')))
|
||||
if images is not None:
|
||||
for image, filename in images:
|
||||
if media is not None:
|
||||
for image, filename in media:
|
||||
data.addpart(
|
||||
"files",
|
||||
filename=f"base64;{filename}",
|
||||
|
@@ -3,7 +3,7 @@ from __future__ import annotations
|
||||
import requests
|
||||
|
||||
from ...providers.types import Messages
|
||||
from ...typing import ImagesType
|
||||
from ...typing import MediaListType
|
||||
from ...requests import StreamSession, raise_for_status
|
||||
from ...errors import ModelNotSupportedError
|
||||
from ...providers.helper import get_last_user_message
|
||||
@@ -75,11 +75,11 @@ class HuggingFaceAPI(OpenaiTemplate):
|
||||
api_key: str = None,
|
||||
max_tokens: int = 2048,
|
||||
max_inputs_lenght: int = 10000,
|
||||
images: ImagesType = None,
|
||||
media: MediaListType = None,
|
||||
**kwargs
|
||||
):
|
||||
if model == llama_models["name"]:
|
||||
model = llama_models["text"] if images is None else llama_models["vision"]
|
||||
model = llama_models["text"] if media is None else llama_models["vision"]
|
||||
if model in cls.model_aliases:
|
||||
model = cls.model_aliases[model]
|
||||
provider_mapping = await cls.get_mapping(model, api_key)
|
||||
@@ -103,7 +103,7 @@ class HuggingFaceAPI(OpenaiTemplate):
|
||||
if len(messages) > 1 and calculate_lenght(messages) > max_inputs_lenght:
|
||||
messages = last_user_message
|
||||
debug.log(f"Messages trimmed from: {start} to: {calculate_lenght(messages)}")
|
||||
async for chunk in super().create_async_generator(model, messages, api_base=api_base, api_key=api_key, max_tokens=max_tokens, images=images, **kwargs):
|
||||
async for chunk in super().create_async_generator(model, messages, api_base=api_base, api_key=api_key, max_tokens=max_tokens, media=media, **kwargs):
|
||||
yield chunk
|
||||
|
||||
def calculate_lenght(messages: Messages) -> int:
|
||||
|
@@ -7,7 +7,7 @@ import random
|
||||
from datetime import datetime, timezone, timedelta
|
||||
import urllib.parse
|
||||
|
||||
from ...typing import AsyncResult, Messages, Cookies, ImagesType
|
||||
from ...typing import AsyncResult, Messages, Cookies, MediaListType
|
||||
from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin
|
||||
from ..helper import format_prompt, format_image_prompt
|
||||
from ...providers.response import JsonConversation, ImageResponse, Reasoning
|
||||
@@ -68,7 +68,7 @@ class DeepseekAI_JanusPro7b(AsyncGeneratorProvider, ProviderModelMixin):
|
||||
cls,
|
||||
model: str,
|
||||
messages: Messages,
|
||||
images: ImagesType = None,
|
||||
media: MediaListType = None,
|
||||
prompt: str = None,
|
||||
proxy: str = None,
|
||||
cookies: Cookies = None,
|
||||
@@ -98,27 +98,27 @@ class DeepseekAI_JanusPro7b(AsyncGeneratorProvider, ProviderModelMixin):
|
||||
if return_conversation:
|
||||
yield conversation
|
||||
|
||||
if images is not None:
|
||||
if media is not None:
|
||||
data = FormData()
|
||||
for i in range(len(images)):
|
||||
images[i] = (to_bytes(images[i][0]), images[i][1])
|
||||
for image, image_name in images:
|
||||
for i in range(len(media)):
|
||||
media[i] = (to_bytes(media[i][0]), media[i][1])
|
||||
for image, image_name in media:
|
||||
data.add_field(f"files", image, filename=image_name)
|
||||
async with session.post(f"{cls.api_url}/gradio_api/upload", params={"upload_id": session_hash}, data=data) as response:
|
||||
await raise_for_status(response)
|
||||
image_files = await response.json()
|
||||
images = [{
|
||||
media = [{
|
||||
"path": image_file,
|
||||
"url": f"{cls.api_url}/gradio_api/file={image_file}",
|
||||
"orig_name": images[i][1],
|
||||
"size": len(images[i][0]),
|
||||
"mime_type": is_accepted_format(images[i][0]),
|
||||
"orig_name": media[i][1],
|
||||
"size": len(media[i][0]),
|
||||
"mime_type": is_accepted_format(media[i][0]),
|
||||
"meta": {
|
||||
"_type": "gradio.FileData"
|
||||
}
|
||||
} for i, image_file in enumerate(image_files)]
|
||||
|
||||
async with cls.run(method, session, prompt, conversation, None if images is None else images.pop(), seed) as response:
|
||||
async with cls.run(method, session, prompt, conversation, None if media is None else media.pop(), seed) as response:
|
||||
await raise_for_status(response)
|
||||
|
||||
async with cls.run("get", session, prompt, conversation, None, seed) as response:
|
||||
|
@@ -3,13 +3,13 @@ from __future__ import annotations
|
||||
import json
|
||||
import uuid
|
||||
|
||||
from ...typing import AsyncResult, Messages, Cookies, ImagesType
|
||||
from ...typing import AsyncResult, Messages, Cookies, MediaListType
|
||||
from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin
|
||||
from ..helper import format_prompt, format_image_prompt
|
||||
from ...providers.response import JsonConversation
|
||||
from ...requests.aiohttp import StreamSession, StreamResponse, FormData
|
||||
from ...requests.raise_for_status import raise_for_status
|
||||
from ...image import to_bytes, is_accepted_format, is_data_an_wav
|
||||
from ...image import to_bytes, is_accepted_format, is_data_an_audio
|
||||
from ...errors import ResponseError
|
||||
from ... import debug
|
||||
from .DeepseekAI_JanusPro7b import get_zerogpu_token
|
||||
@@ -32,7 +32,7 @@ class Microsoft_Phi_4(AsyncGeneratorProvider, ProviderModelMixin):
|
||||
models = [default_model]
|
||||
|
||||
@classmethod
|
||||
def run(cls, method: str, session: StreamSession, prompt: str, conversation: JsonConversation, images: list = None):
|
||||
def run(cls, method: str, session: StreamSession, prompt: str, conversation: JsonConversation, media: list = None):
|
||||
headers = {
|
||||
"content-type": "application/json",
|
||||
"x-zerogpu-token": conversation.zerogpu_token,
|
||||
@@ -47,7 +47,7 @@ class Microsoft_Phi_4(AsyncGeneratorProvider, ProviderModelMixin):
|
||||
[],
|
||||
{
|
||||
"text": prompt,
|
||||
"files": images,
|
||||
"files": media,
|
||||
},
|
||||
None
|
||||
],
|
||||
@@ -70,7 +70,7 @@ class Microsoft_Phi_4(AsyncGeneratorProvider, ProviderModelMixin):
|
||||
{
|
||||
"role": "user",
|
||||
"content": {"file": image}
|
||||
} for image in images
|
||||
} for image in media
|
||||
]],
|
||||
"event_data": None,
|
||||
"fn_index": 11,
|
||||
@@ -91,7 +91,7 @@ class Microsoft_Phi_4(AsyncGeneratorProvider, ProviderModelMixin):
|
||||
cls,
|
||||
model: str,
|
||||
messages: Messages,
|
||||
images: ImagesType = None,
|
||||
media: MediaListType = None,
|
||||
prompt: str = None,
|
||||
proxy: str = None,
|
||||
cookies: Cookies = None,
|
||||
@@ -115,23 +115,23 @@ class Microsoft_Phi_4(AsyncGeneratorProvider, ProviderModelMixin):
|
||||
if return_conversation:
|
||||
yield conversation
|
||||
|
||||
if images is not None:
|
||||
if media is not None:
|
||||
data = FormData()
|
||||
mime_types = [None for i in range(len(images))]
|
||||
for i in range(len(images)):
|
||||
mime_types[i] = is_data_an_wav(images[i][0], images[i][1])
|
||||
images[i] = (to_bytes(images[i][0]), images[i][1])
|
||||
mime_types[i] = is_accepted_format(images[i][0]) if mime_types[i] is None else mime_types[i]
|
||||
for image, image_name in images:
|
||||
mime_types = [None for i in range(len(media))]
|
||||
for i in range(len(media)):
|
||||
mime_types[i] = is_data_an_audio(media[i][0], media[i][1])
|
||||
media[i] = (to_bytes(media[i][0]), media[i][1])
|
||||
mime_types[i] = is_accepted_format(media[i][0]) if mime_types[i] is None else mime_types[i]
|
||||
for image, image_name in media:
|
||||
data.add_field(f"files", to_bytes(image), filename=image_name)
|
||||
async with session.post(f"{cls.api_url}/gradio_api/upload", params={"upload_id": session_hash}, data=data) as response:
|
||||
await raise_for_status(response)
|
||||
image_files = await response.json()
|
||||
images = [{
|
||||
media = [{
|
||||
"path": image_file,
|
||||
"url": f"{cls.api_url}/gradio_api/file={image_file}",
|
||||
"orig_name": images[i][1],
|
||||
"size": len(images[i][0]),
|
||||
"orig_name": media[i][1],
|
||||
"size": len(media[i][0]),
|
||||
"mime_type": mime_types[i],
|
||||
"meta": {
|
||||
"_type": "gradio.FileData"
|
||||
@@ -139,10 +139,10 @@ class Microsoft_Phi_4(AsyncGeneratorProvider, ProviderModelMixin):
|
||||
} for i, image_file in enumerate(image_files)]
|
||||
|
||||
|
||||
async with cls.run("predict", session, prompt, conversation, images) as response:
|
||||
async with cls.run("predict", session, prompt, conversation, media) as response:
|
||||
await raise_for_status(response)
|
||||
|
||||
async with cls.run("post", session, prompt, conversation, images) as response:
|
||||
async with cls.run("post", session, prompt, conversation, media) as response:
|
||||
await raise_for_status(response)
|
||||
|
||||
async with cls.run("get", session, prompt, conversation) as response:
|
||||
|
@@ -3,7 +3,7 @@ from __future__ import annotations
|
||||
import json
|
||||
from aiohttp import ClientSession, FormData
|
||||
|
||||
from ...typing import AsyncResult, Messages, ImagesType
|
||||
from ...typing import AsyncResult, Messages, MediaListType
|
||||
from ...requests import raise_for_status
|
||||
from ...errors import ResponseError
|
||||
from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin
|
||||
@@ -25,7 +25,7 @@ class Qwen_QVQ_72B(AsyncGeneratorProvider, ProviderModelMixin):
|
||||
@classmethod
|
||||
async def create_async_generator(
|
||||
cls, model: str, messages: Messages,
|
||||
images: ImagesType = None,
|
||||
media: MediaListType = None,
|
||||
api_key: str = None,
|
||||
proxy: str = None,
|
||||
**kwargs
|
||||
@@ -36,10 +36,10 @@ class Qwen_QVQ_72B(AsyncGeneratorProvider, ProviderModelMixin):
|
||||
if api_key is not None:
|
||||
headers["Authorization"] = f"Bearer {api_key}"
|
||||
async with ClientSession(headers=headers) as session:
|
||||
if images:
|
||||
if media:
|
||||
data = FormData()
|
||||
data_bytes = to_bytes(images[0][0])
|
||||
data.add_field("files", data_bytes, content_type=is_accepted_format(data_bytes), filename=images[0][1])
|
||||
data_bytes = to_bytes(media[0][0])
|
||||
data.add_field("files", data_bytes, content_type=is_accepted_format(data_bytes), filename=media[0][1])
|
||||
url = f"{cls.url}/gradio_api/upload?upload_id={get_random_string()}"
|
||||
async with session.post(url, data=data, proxy=proxy) as response:
|
||||
await raise_for_status(response)
|
||||
|
@@ -2,7 +2,7 @@ from __future__ import annotations
|
||||
|
||||
import random
|
||||
|
||||
from ...typing import AsyncResult, Messages, ImagesType
|
||||
from ...typing import AsyncResult, Messages, MediaListType
|
||||
from ...errors import ResponseError
|
||||
from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin
|
||||
|
||||
@@ -67,15 +67,15 @@ class HuggingSpace(AsyncGeneratorProvider, ProviderModelMixin):
|
||||
|
||||
@classmethod
|
||||
async def create_async_generator(
|
||||
cls, model: str, messages: Messages, images: ImagesType = None, **kwargs
|
||||
cls, model: str, messages: Messages, media: MediaListType = None, **kwargs
|
||||
) -> AsyncResult:
|
||||
if not model and images is not None:
|
||||
if not model and media is not None:
|
||||
model = cls.default_vision_model
|
||||
is_started = False
|
||||
random.shuffle(cls.providers)
|
||||
for provider in cls.providers:
|
||||
if model in provider.model_aliases:
|
||||
async for chunk in provider.create_async_generator(provider.model_aliases[model], messages, images=images, **kwargs):
|
||||
async for chunk in provider.create_async_generator(provider.model_aliases[model], messages, media=media, **kwargs):
|
||||
is_started = True
|
||||
yield chunk
|
||||
if is_started:
|
||||
|
@@ -6,7 +6,7 @@ import base64
|
||||
from typing import Optional
|
||||
|
||||
from ..helper import filter_none
|
||||
from ...typing import AsyncResult, Messages, ImagesType
|
||||
from ...typing import AsyncResult, Messages, MediaListType
|
||||
from ...requests import StreamSession, raise_for_status
|
||||
from ...providers.response import FinishReason, ToolCalls, Usage
|
||||
from ...errors import MissingAuthError
|
||||
@@ -62,7 +62,7 @@ class Anthropic(OpenaiAPI):
|
||||
messages: Messages,
|
||||
proxy: str = None,
|
||||
timeout: int = 120,
|
||||
images: ImagesType = None,
|
||||
media: MediaListType = None,
|
||||
api_key: str = None,
|
||||
temperature: float = None,
|
||||
max_tokens: int = 4096,
|
||||
@@ -79,9 +79,9 @@ class Anthropic(OpenaiAPI):
|
||||
if api_key is None:
|
||||
raise MissingAuthError('Add a "api_key"')
|
||||
|
||||
if images is not None:
|
||||
if media is not None:
|
||||
insert_images = []
|
||||
for image, _ in images:
|
||||
for image, _ in media:
|
||||
data = to_bytes(image)
|
||||
insert_images.append({
|
||||
"type": "image",
|
||||
|
@@ -19,7 +19,7 @@ except ImportError:
|
||||
has_nodriver = False
|
||||
|
||||
from ... import debug
|
||||
from ...typing import Messages, Cookies, ImagesType, AsyncResult, AsyncIterator
|
||||
from ...typing import Messages, Cookies, MediaListType, AsyncResult, AsyncIterator
|
||||
from ...providers.response import JsonConversation, Reasoning, RequestLogin, ImageResponse, YouTube
|
||||
from ...requests.raise_for_status import raise_for_status
|
||||
from ...requests.aiohttp import get_connector
|
||||
@@ -149,7 +149,7 @@ class Gemini(AsyncGeneratorProvider, ProviderModelMixin):
|
||||
proxy: str = None,
|
||||
cookies: Cookies = None,
|
||||
connector: BaseConnector = None,
|
||||
images: ImagesType = None,
|
||||
media: MediaListType = None,
|
||||
return_conversation: bool = False,
|
||||
conversation: Conversation = None,
|
||||
language: str = "en",
|
||||
@@ -186,7 +186,7 @@ class Gemini(AsyncGeneratorProvider, ProviderModelMixin):
|
||||
cls.start_auto_refresh()
|
||||
)
|
||||
|
||||
images = await cls.upload_images(base_connector, images) if images else None
|
||||
uploads = None if media is None else await cls.upload_images(base_connector, media)
|
||||
async with ClientSession(
|
||||
cookies=cls._cookies,
|
||||
headers=REQUEST_HEADERS,
|
||||
@@ -205,7 +205,7 @@ class Gemini(AsyncGeneratorProvider, ProviderModelMixin):
|
||||
prompt,
|
||||
language=language,
|
||||
conversation=conversation,
|
||||
images=images
|
||||
uploads=uploads
|
||||
))])
|
||||
}
|
||||
async with client.post(
|
||||
@@ -327,10 +327,10 @@ class Gemini(AsyncGeneratorProvider, ProviderModelMixin):
|
||||
prompt: str,
|
||||
language: str,
|
||||
conversation: Conversation = None,
|
||||
images: list[list[str, str]] = None,
|
||||
uploads: list[list[str, str]] = None,
|
||||
tools: list[list[str]] = []
|
||||
) -> list:
|
||||
image_list = [[[image_url, 1], image_name] for image_url, image_name in images] if images else []
|
||||
image_list = [[[image_url, 1], image_name] for image_url, image_name in uploads] if uploads else []
|
||||
return [
|
||||
[prompt, 0, None, image_list, None, None, 0],
|
||||
[language],
|
||||
@@ -353,7 +353,7 @@ class Gemini(AsyncGeneratorProvider, ProviderModelMixin):
|
||||
0,
|
||||
]
|
||||
|
||||
async def upload_images(connector: BaseConnector, images: ImagesType) -> list:
|
||||
async def upload_images(connector: BaseConnector, media: MediaListType) -> list:
|
||||
async def upload_image(image: bytes, image_name: str = None):
|
||||
async with ClientSession(
|
||||
headers=UPLOAD_IMAGE_HEADERS,
|
||||
@@ -385,7 +385,7 @@ class Gemini(AsyncGeneratorProvider, ProviderModelMixin):
|
||||
) as response:
|
||||
await raise_for_status(response)
|
||||
return [await response.text(), image_name]
|
||||
return await asyncio.gather(*[upload_image(image, image_name) for image, image_name in images])
|
||||
return await asyncio.gather(*[upload_image(image, image_name) for image, image_name in media])
|
||||
|
||||
@classmethod
|
||||
async def fetch_snlm0e(cls, session: ClientSession, cookies: Cookies):
|
||||
|
@@ -6,8 +6,8 @@ import requests
|
||||
from typing import Optional
|
||||
from aiohttp import ClientSession, BaseConnector
|
||||
|
||||
from ...typing import AsyncResult, Messages, ImagesType
|
||||
from ...image import to_bytes, is_accepted_format
|
||||
from ...typing import AsyncResult, Messages, MediaListType
|
||||
from ...image import to_bytes, is_data_an_media
|
||||
from ...errors import MissingAuthError
|
||||
from ...requests.raise_for_status import raise_for_status
|
||||
from ...providers.response import Usage, FinishReason
|
||||
@@ -67,7 +67,7 @@ class GeminiPro(AsyncGeneratorProvider, ProviderModelMixin):
|
||||
api_key: str = None,
|
||||
api_base: str = api_base,
|
||||
use_auth_header: bool = False,
|
||||
images: ImagesType = None,
|
||||
media: MediaListType = None,
|
||||
tools: Optional[list] = None,
|
||||
connector: BaseConnector = None,
|
||||
**kwargs
|
||||
@@ -94,13 +94,13 @@ class GeminiPro(AsyncGeneratorProvider, ProviderModelMixin):
|
||||
for message in messages
|
||||
if message["role"] != "system"
|
||||
]
|
||||
if images is not None:
|
||||
for image, _ in images:
|
||||
if media is not None:
|
||||
for media_data, filename in media:
|
||||
image = to_bytes(image)
|
||||
contents[-1]["parts"].append({
|
||||
"inline_data": {
|
||||
"mime_type": is_accepted_format(image),
|
||||
"data": base64.b64encode(image).decode()
|
||||
"mime_type": is_data_an_media(image, filename),
|
||||
"data": base64.b64encode(media_data).decode()
|
||||
}
|
||||
})
|
||||
data = {
|
||||
|
@@ -1,51 +1,56 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import json
|
||||
import random
|
||||
import re
|
||||
import base64
|
||||
import asyncio
|
||||
import time
|
||||
from urllib.parse import quote_plus, unquote_plus
|
||||
from pathlib import Path
|
||||
from aiohttp import ClientSession, BaseConnector
|
||||
from typing import Dict, Any, Optional, AsyncIterator, List
|
||||
from typing import Dict, Any, AsyncIterator
|
||||
|
||||
from ... import debug
|
||||
from ...typing import Messages, Cookies, ImagesType, AsyncResult
|
||||
from ...providers.response import JsonConversation, Reasoning, ImagePreview, ImageResponse, TitleGeneration
|
||||
from ...typing import Messages, Cookies, AsyncResult
|
||||
from ...providers.response import JsonConversation, Reasoning, ImagePreview, ImageResponse, TitleGeneration, AuthResult, RequestLogin
|
||||
from ...requests import StreamSession, get_args_from_nodriver, DEFAULT_HEADERS
|
||||
from ...requests.raise_for_status import raise_for_status
|
||||
from ...requests.aiohttp import get_connector
|
||||
from ...requests import get_nodriver
|
||||
from ...errors import MissingAuthError
|
||||
from ...cookies import get_cookies_dir
|
||||
from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin
|
||||
from ..base_provider import AsyncAuthedProvider, ProviderModelMixin
|
||||
from ..helper import format_prompt, get_cookies, get_last_user_message
|
||||
|
||||
class Conversation(JsonConversation):
|
||||
def __init__(self,
|
||||
conversation_id: str,
|
||||
response_id: str,
|
||||
choice_id: str,
|
||||
model: str
|
||||
conversation_id: str
|
||||
) -> None:
|
||||
self.conversation_id = conversation_id
|
||||
self.response_id = response_id
|
||||
self.choice_id = choice_id
|
||||
self.model = model
|
||||
|
||||
class Grok(AsyncGeneratorProvider, ProviderModelMixin):
|
||||
class Grok(AsyncAuthedProvider, ProviderModelMixin):
|
||||
label = "Grok AI"
|
||||
url = "https://grok.com"
|
||||
cookie_domain = ".grok.com"
|
||||
assets_url = "https://assets.grok.com"
|
||||
conversation_url = "https://grok.com/rest/app-chat/conversations"
|
||||
|
||||
needs_auth = True
|
||||
working = False
|
||||
working = True
|
||||
|
||||
default_model = "grok-3"
|
||||
models = [default_model, "grok-3-thinking", "grok-2"]
|
||||
|
||||
_cookies: Cookies = None
|
||||
@classmethod
|
||||
async def on_auth_async(cls, cookies: Cookies = None, proxy: str = None, **kwargs) -> AsyncIterator:
|
||||
if cookies is None:
|
||||
cookies = get_cookies(cls.cookie_domain, False, True, False)
|
||||
if cookies is not None and "sso" in cookies:
|
||||
yield AuthResult(
|
||||
cookies=cookies,
|
||||
impersonate="chrome",
|
||||
proxy=proxy,
|
||||
headers=DEFAULT_HEADERS
|
||||
)
|
||||
return
|
||||
yield RequestLogin(cls.__name__, os.environ.get("G4F_LOGIN_URL") or "")
|
||||
yield AuthResult(
|
||||
**await get_args_from_nodriver(
|
||||
cls.url,
|
||||
proxy=proxy,
|
||||
wait_for='[href="/chat#private"]'
|
||||
)
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def _prepare_payload(cls, model: str, message: str) -> Dict[str, Any]:
|
||||
@@ -72,63 +77,43 @@ class Grok(AsyncGeneratorProvider, ProviderModelMixin):
|
||||
}
|
||||
|
||||
@classmethod
|
||||
async def create_async_generator(
|
||||
async def create_authed(
|
||||
cls,
|
||||
model: str,
|
||||
messages: Messages,
|
||||
proxy: str = None,
|
||||
auth_result: AuthResult,
|
||||
cookies: Cookies = None,
|
||||
connector: BaseConnector = None,
|
||||
images: ImagesType = None,
|
||||
return_conversation: bool = False,
|
||||
conversation: Optional[Conversation] = None,
|
||||
conversation: Conversation = None,
|
||||
**kwargs
|
||||
) -> AsyncResult:
|
||||
cls._cookies = cookies or cls._cookies or get_cookies(".grok.com", False, True)
|
||||
if not cls._cookies:
|
||||
raise MissingAuthError("Missing required cookies")
|
||||
|
||||
prompt = format_prompt(messages) if conversation is None else get_last_user_message(messages)
|
||||
base_connector = get_connector(connector, proxy)
|
||||
|
||||
headers = {
|
||||
"accept": "*/*",
|
||||
"accept-language": "en-GB,en;q=0.9",
|
||||
"content-type": "application/json",
|
||||
"origin": "https://grok.com",
|
||||
"priority": "u=1, i",
|
||||
"referer": "https://grok.com/",
|
||||
"sec-ch-ua": '"Not/A)Brand";v="8", "Chromium";v="126", "Brave";v="126"',
|
||||
"sec-ch-ua-mobile": "?0",
|
||||
"sec-ch-ua-platform": '"macOS"',
|
||||
"sec-fetch-dest": "empty",
|
||||
"sec-fetch-mode": "cors",
|
||||
"sec-fetch-site": "same-origin",
|
||||
"sec-gpc": "1",
|
||||
"user-agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/126.0.0.0 Safari/537.36"
|
||||
}
|
||||
|
||||
async with ClientSession(
|
||||
headers=headers,
|
||||
cookies=cls._cookies,
|
||||
connector=base_connector
|
||||
conversation_id = None if conversation is None else conversation.conversation_id
|
||||
prompt = format_prompt(messages) if conversation_id is None else get_last_user_message(messages)
|
||||
async with StreamSession(
|
||||
**auth_result.get_dict()
|
||||
) as session:
|
||||
payload = await cls._prepare_payload(model, prompt)
|
||||
response = await session.post(f"{cls.conversation_url}/new", json=payload)
|
||||
if conversation_id is None:
|
||||
url = f"{cls.conversation_url}/new"
|
||||
else:
|
||||
url = f"{cls.conversation_url}/{conversation_id}/responses"
|
||||
async with session.post(url, json=payload) as response:
|
||||
await raise_for_status(response)
|
||||
|
||||
thinking_duration = None
|
||||
async for line in response.content:
|
||||
async for line in response.iter_lines():
|
||||
if line:
|
||||
try:
|
||||
json_data = json.loads(line)
|
||||
result = json_data.get("result", {})
|
||||
if conversation_id is None:
|
||||
conversation_id = result.get("conversation", {}).get("conversationId")
|
||||
response_data = result.get("response", {})
|
||||
image = response_data.get("streamingImageGenerationResponse", None)
|
||||
if image is not None:
|
||||
yield ImagePreview(f'{cls.assets_url}/{image["imageUrl"]}', "", {"cookies": cookies, "headers": headers})
|
||||
token = response_data.get("token", "")
|
||||
is_thinking = response_data.get("isThinking", False)
|
||||
token = response_data.get("token", result.get("token"))
|
||||
is_thinking = response_data.get("isThinking", result.get("isThinking"))
|
||||
if token:
|
||||
if is_thinking:
|
||||
if thinking_duration is None:
|
||||
@@ -145,9 +130,12 @@ class Grok(AsyncGeneratorProvider, ProviderModelMixin):
|
||||
generated_images = response_data.get("modelResponse", {}).get("generatedImageUrls", None)
|
||||
if generated_images:
|
||||
yield ImageResponse([f'{cls.assets_url}/{image}' for image in generated_images], "", {"cookies": cookies, "headers": headers})
|
||||
title = response_data.get("title", {}).get("newTitle", "")
|
||||
title = result.get("title", {}).get("newTitle", "")
|
||||
if title:
|
||||
yield TitleGeneration(title)
|
||||
|
||||
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
if return_conversation and conversation_id is not None:
|
||||
yield Conversation(conversation_id)
|
@@ -18,7 +18,7 @@ except ImportError:
|
||||
has_nodriver = False
|
||||
|
||||
from ..base_provider import AsyncAuthedProvider, ProviderModelMixin
|
||||
from ...typing import AsyncResult, Messages, Cookies, ImagesType
|
||||
from ...typing import AsyncResult, Messages, Cookies, MediaListType
|
||||
from ...requests.raise_for_status import raise_for_status
|
||||
from ...requests import StreamSession
|
||||
from ...requests import get_nodriver
|
||||
@@ -127,7 +127,7 @@ class OpenaiChat(AsyncAuthedProvider, ProviderModelMixin):
|
||||
cls,
|
||||
session: StreamSession,
|
||||
auth_result: AuthResult,
|
||||
images: ImagesType,
|
||||
media: MediaListType,
|
||||
) -> ImageRequest:
|
||||
"""
|
||||
Upload an image to the service and get the download URL
|
||||
@@ -135,7 +135,7 @@ class OpenaiChat(AsyncAuthedProvider, ProviderModelMixin):
|
||||
Args:
|
||||
session: The StreamSession object to use for requests
|
||||
headers: The headers to include in the requests
|
||||
images: The images to upload, either a PIL Image object or a bytes object
|
||||
media: The images to upload, either a PIL Image object or a bytes object
|
||||
|
||||
Returns:
|
||||
An ImageRequest object that contains the download URL, file name, and other data
|
||||
@@ -187,9 +187,9 @@ class OpenaiChat(AsyncAuthedProvider, ProviderModelMixin):
|
||||
await raise_for_status(response, "Get download url failed")
|
||||
image_data["download_url"] = (await response.json())["download_url"]
|
||||
return ImageRequest(image_data)
|
||||
if not images:
|
||||
if not media:
|
||||
return
|
||||
return [await upload_image(image, image_name) for image, image_name in images]
|
||||
return [await upload_image(image, image_name) for image, image_name in media]
|
||||
|
||||
@classmethod
|
||||
def create_messages(cls, messages: Messages, image_requests: ImageRequest = None, system_hints: list = None):
|
||||
@@ -268,7 +268,7 @@ class OpenaiChat(AsyncAuthedProvider, ProviderModelMixin):
|
||||
auto_continue: bool = False,
|
||||
action: str = "next",
|
||||
conversation: Conversation = None,
|
||||
images: ImagesType = None,
|
||||
media: MediaListType = None,
|
||||
return_conversation: bool = False,
|
||||
web_search: bool = False,
|
||||
**kwargs
|
||||
@@ -285,7 +285,7 @@ class OpenaiChat(AsyncAuthedProvider, ProviderModelMixin):
|
||||
auto_continue (bool): Flag to automatically continue the conversation.
|
||||
action (str): Type of action ('next', 'continue', 'variant').
|
||||
conversation_id (str): ID of the conversation.
|
||||
images (ImagesType): Images to include in the conversation.
|
||||
media (MediaListType): Images to include in the conversation.
|
||||
return_conversation (bool): Flag to include response fields in the output.
|
||||
**kwargs: Additional keyword arguments.
|
||||
|
||||
@@ -316,7 +316,7 @@ class OpenaiChat(AsyncAuthedProvider, ProviderModelMixin):
|
||||
cls._update_request_args(auth_result, session)
|
||||
await raise_for_status(response)
|
||||
try:
|
||||
image_requests = await cls.upload_images(session, auth_result, images) if images else None
|
||||
image_requests = None if media is None else await cls.upload_images(session, auth_result, media)
|
||||
except Exception as e:
|
||||
debug.error("OpenaiChat: Upload image failed")
|
||||
debug.error(e)
|
||||
|
@@ -2,7 +2,7 @@ from __future__ import annotations
|
||||
|
||||
import json
|
||||
|
||||
from ...typing import Messages, AsyncResult, ImagesType
|
||||
from ...typing import Messages, AsyncResult, MediaListType
|
||||
from ...requests import StreamSession
|
||||
from ...image import to_data_uri
|
||||
from ...providers.base_provider import AsyncGeneratorProvider, ProviderModelMixin
|
||||
@@ -18,21 +18,21 @@ class BackendApi(AsyncGeneratorProvider, ProviderModelMixin):
|
||||
cls,
|
||||
model: str,
|
||||
messages: Messages,
|
||||
images: ImagesType = None,
|
||||
media: MediaListType = None,
|
||||
api_key: str = None,
|
||||
**kwargs
|
||||
) -> AsyncResult:
|
||||
debug.log(f"{cls.__name__}: {api_key}")
|
||||
if images is not None:
|
||||
for i in range(len(images)):
|
||||
images[i] = (to_data_uri(images[i][0]), images[i][1])
|
||||
if media is not None:
|
||||
for i in range(len(media)):
|
||||
media[i] = (to_data_uri(media[i][0]), media[i][1])
|
||||
async with StreamSession(
|
||||
headers={"Accept": "text/event-stream", **cls.headers},
|
||||
) as session:
|
||||
async with session.post(f"{cls.url}/backend-api/v2/conversation", json={
|
||||
"model": model,
|
||||
"messages": messages,
|
||||
"images": images,
|
||||
"media": media,
|
||||
"api_key": api_key,
|
||||
**kwargs
|
||||
}, ssl=cls.ssl) as response:
|
||||
|
@@ -5,11 +5,11 @@ import requests
|
||||
|
||||
from ..helper import filter_none, format_image_prompt
|
||||
from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin, RaiseErrorMixin
|
||||
from ...typing import Union, Optional, AsyncResult, Messages, ImagesType
|
||||
from ...typing import Union, AsyncResult, Messages, MediaListType
|
||||
from ...requests import StreamSession, raise_for_status
|
||||
from ...providers.response import FinishReason, ToolCalls, Usage, ImageResponse
|
||||
from ...errors import MissingAuthError, ResponseError
|
||||
from ...image import to_data_uri
|
||||
from ...image import to_data_uri, is_data_an_audio, to_input_audio
|
||||
from ... import debug
|
||||
|
||||
class OpenaiTemplate(AsyncGeneratorProvider, ProviderModelMixin, RaiseErrorMixin):
|
||||
@@ -54,7 +54,7 @@ class OpenaiTemplate(AsyncGeneratorProvider, ProviderModelMixin, RaiseErrorMixin
|
||||
messages: Messages,
|
||||
proxy: str = None,
|
||||
timeout: int = 120,
|
||||
images: ImagesType = None,
|
||||
media: MediaListType = None,
|
||||
api_key: str = None,
|
||||
api_endpoint: str = None,
|
||||
api_base: str = None,
|
||||
@@ -66,7 +66,7 @@ class OpenaiTemplate(AsyncGeneratorProvider, ProviderModelMixin, RaiseErrorMixin
|
||||
prompt: str = None,
|
||||
headers: dict = None,
|
||||
impersonate: str = None,
|
||||
extra_parameters: list[str] = ["tools", "parallel_tool_calls", "tool_choice", "reasoning_effort", "logit_bias"],
|
||||
extra_parameters: list[str] = ["tools", "parallel_tool_calls", "tool_choice", "reasoning_effort", "logit_bias", "modalities", "audio"],
|
||||
extra_data: dict = {},
|
||||
**kwargs
|
||||
) -> AsyncResult:
|
||||
@@ -98,19 +98,26 @@ class OpenaiTemplate(AsyncGeneratorProvider, ProviderModelMixin, RaiseErrorMixin
|
||||
yield ImageResponse([image["url"] for image in data["data"]], prompt)
|
||||
return
|
||||
|
||||
if images is not None and messages:
|
||||
if media is not None and messages:
|
||||
if not model and hasattr(cls, "default_vision_model"):
|
||||
model = cls.default_vision_model
|
||||
last_message = messages[-1].copy()
|
||||
last_message["content"] = [
|
||||
*[{
|
||||
*[
|
||||
{
|
||||
"type": "input_audio",
|
||||
"input_audio": to_input_audio(media_data, filename)
|
||||
}
|
||||
if is_data_an_audio(media_data, filename) else {
|
||||
"type": "image_url",
|
||||
"image_url": {"url": to_data_uri(image)}
|
||||
} for image, _ in images],
|
||||
"image_url": {"url": to_data_uri(media_data, filename)}
|
||||
}
|
||||
for media_data, filename in media
|
||||
],
|
||||
{
|
||||
"type": "text",
|
||||
"text": messages[-1]["content"]
|
||||
}
|
||||
"text": last_message["content"]
|
||||
} if isinstance(last_message["content"], str) else last_message["content"]
|
||||
]
|
||||
messages[-1] = last_message
|
||||
extra_parameters = {key: kwargs[key] for key in extra_parameters if key in kwargs}
|
||||
|
@@ -34,12 +34,14 @@ class ChatCompletion:
|
||||
ignore_stream: bool = False,
|
||||
**kwargs) -> Union[CreateResult, str]:
|
||||
if image is not None:
|
||||
kwargs["images"] = [(image, image_name)]
|
||||
kwargs["media"] = [(image, image_name)]
|
||||
elif "images" in kwargs:
|
||||
kwargs["media"] = kwargs.pop("images")
|
||||
model, provider = get_model_and_provider(
|
||||
model, provider, stream,
|
||||
ignore_working,
|
||||
ignore_stream,
|
||||
has_images="images" in kwargs,
|
||||
has_images="media" in kwargs,
|
||||
)
|
||||
if "proxy" not in kwargs:
|
||||
proxy = os.environ.get("G4F_PROXY")
|
||||
@@ -63,8 +65,10 @@ class ChatCompletion:
|
||||
ignore_working: bool = False,
|
||||
**kwargs) -> Union[AsyncResult, Coroutine[str]]:
|
||||
if image is not None:
|
||||
kwargs["images"] = [(image, image_name)]
|
||||
model, provider = get_model_and_provider(model, provider, False, ignore_working, has_images="images" in kwargs)
|
||||
kwargs["media"] = [(image, image_name)]
|
||||
elif "images" in kwargs:
|
||||
kwargs["media"] = kwargs.pop("images")
|
||||
model, provider = get_model_and_provider(model, provider, False, ignore_working, has_images="media" in kwargs)
|
||||
if "proxy" not in kwargs:
|
||||
proxy = os.environ.get("G4F_PROXY")
|
||||
if proxy:
|
||||
|
@@ -38,7 +38,7 @@ import g4f.debug
|
||||
from g4f.client import AsyncClient, ChatCompletion, ImagesResponse, convert_to_provider
|
||||
from g4f.providers.response import BaseConversation, JsonConversation
|
||||
from g4f.client.helper import filter_none
|
||||
from g4f.image import is_data_uri_an_media
|
||||
from g4f.image import is_data_an_media
|
||||
from g4f.image.copy_images import images_dir, copy_images, get_source_url
|
||||
from g4f.errors import ProviderNotFoundError, ModelNotFoundError, MissingAuthError, NoValidHarFileError
|
||||
from g4f.cookies import read_cookie_files, get_cookies_dir
|
||||
@@ -320,16 +320,18 @@ class Api:
|
||||
|
||||
if config.image is not None:
|
||||
try:
|
||||
is_data_uri_an_media(config.image)
|
||||
is_data_an_media(config.image)
|
||||
except ValueError as e:
|
||||
return ErrorResponse.from_message(f"The image you send must be a data URI. Example: data:image/jpeg;base64,...", status_code=HTTP_422_UNPROCESSABLE_ENTITY)
|
||||
if config.images is not None:
|
||||
for image in config.images:
|
||||
if config.media is None:
|
||||
config.media = config.images
|
||||
if config.media is not None:
|
||||
for image in config.media:
|
||||
try:
|
||||
is_data_uri_an_media(image[0])
|
||||
is_data_an_media(image[0])
|
||||
except ValueError as e:
|
||||
example = json.dumps({"images": [["data:image/jpeg;base64,...", "filename"]]})
|
||||
return ErrorResponse.from_message(f'The image you send must be a data URI. Example: {example}', status_code=HTTP_422_UNPROCESSABLE_ENTITY)
|
||||
example = json.dumps({"media": [["data:image/jpeg;base64,...", "filename.jpg"]]})
|
||||
return ErrorResponse.from_message(f'The media you send must be a data URIs. Example: {example}', status_code=HTTP_422_UNPROCESSABLE_ENTITY)
|
||||
|
||||
# Create the completion response
|
||||
response = self.client.chat.completions.create(
|
||||
|
@@ -17,6 +17,7 @@ class ChatCompletionsConfig(BaseModel):
|
||||
image: Optional[str] = None
|
||||
image_name: Optional[str] = None
|
||||
images: Optional[list[tuple[str, str]]] = None
|
||||
media: Optional[list[tuple[str, str]]] = None
|
||||
temperature: Optional[float] = None
|
||||
presence_penalty: Optional[float] = None
|
||||
frequency_penalty: Optional[float] = None
|
||||
|
@@ -293,14 +293,16 @@ class Completions:
|
||||
if isinstance(messages, str):
|
||||
messages = [{"role": "user", "content": messages}]
|
||||
if image is not None:
|
||||
kwargs["images"] = [(image, image_name)]
|
||||
kwargs["media"] = [(image, image_name)]
|
||||
elif "images" in kwargs:
|
||||
kwargs["media"] = kwargs.pop("images")
|
||||
model, provider = get_model_and_provider(
|
||||
model,
|
||||
self.provider if provider is None else provider,
|
||||
stream,
|
||||
ignore_working,
|
||||
ignore_stream,
|
||||
has_images="images" in kwargs
|
||||
has_images="media" in kwargs
|
||||
)
|
||||
stop = [stop] if isinstance(stop, str) else stop
|
||||
if ignore_stream:
|
||||
@@ -481,7 +483,7 @@ class Images:
|
||||
proxy = self.client.proxy
|
||||
prompt = "create a variation of this image"
|
||||
if image is not None:
|
||||
kwargs["images"] = [(image, None)]
|
||||
kwargs["media"] = [(image, None)]
|
||||
|
||||
error = None
|
||||
response = None
|
||||
@@ -581,14 +583,16 @@ class AsyncCompletions:
|
||||
if isinstance(messages, str):
|
||||
messages = [{"role": "user", "content": messages}]
|
||||
if image is not None:
|
||||
kwargs["images"] = [(image, image_name)]
|
||||
kwargs["media"] = [(image, image_name)]
|
||||
elif "images" in kwargs:
|
||||
kwargs["media"] = kwargs.pop("images")
|
||||
model, provider = get_model_and_provider(
|
||||
model,
|
||||
self.provider if provider is None else provider,
|
||||
stream,
|
||||
ignore_working,
|
||||
ignore_stream,
|
||||
has_images="images" in kwargs,
|
||||
has_images="media" in kwargs,
|
||||
)
|
||||
stop = [stop] if isinstance(stop, str) else stop
|
||||
if ignore_stream:
|
||||
|
@@ -67,7 +67,7 @@ DOMAINS = [
|
||||
if has_browser_cookie3 and os.environ.get('DBUS_SESSION_BUS_ADDRESS') == "/dev/null":
|
||||
_LinuxPasswordManager.get_password = lambda a, b: b"secret"
|
||||
|
||||
def get_cookies(domain_name: str = '', raise_requirements_error: bool = True, single_browser: bool = False) -> Dict[str, str]:
|
||||
def get_cookies(domain_name: str, raise_requirements_error: bool = True, single_browser: bool = False, cache_result: bool = True) -> Dict[str, str]:
|
||||
"""
|
||||
Load cookies for a given domain from all supported browsers and cache the results.
|
||||
|
||||
@@ -77,10 +77,11 @@ def get_cookies(domain_name: str = '', raise_requirements_error: bool = True, si
|
||||
Returns:
|
||||
Dict[str, str]: A dictionary of cookie names and values.
|
||||
"""
|
||||
if domain_name in CookiesConfig.cookies:
|
||||
if cache_result and domain_name in CookiesConfig.cookies:
|
||||
return CookiesConfig.cookies[domain_name]
|
||||
|
||||
cookies = load_cookies_from_browsers(domain_name, raise_requirements_error, single_browser)
|
||||
if cache_result:
|
||||
CookiesConfig.cookies[domain_name] = cookies
|
||||
return cookies
|
||||
|
||||
@@ -108,8 +109,8 @@ def load_cookies_from_browsers(domain_name: str, raise_requirements_error: bool
|
||||
for cookie_fn in browsers:
|
||||
try:
|
||||
cookie_jar = cookie_fn(domain_name=domain_name)
|
||||
if len(cookie_jar) and debug.logging:
|
||||
print(f"Read cookies from {cookie_fn.__name__} for {domain_name}")
|
||||
if len(cookie_jar):
|
||||
debug.log(f"Read cookies from {cookie_fn.__name__} for {domain_name}")
|
||||
for cookie in cookie_jar:
|
||||
if cookie.name not in cookies:
|
||||
if not cookie.expires or cookie.expires > time.time():
|
||||
@@ -119,8 +120,7 @@ def load_cookies_from_browsers(domain_name: str, raise_requirements_error: bool
|
||||
except BrowserCookieError:
|
||||
pass
|
||||
except Exception as e:
|
||||
if debug.logging:
|
||||
print(f"Error reading cookies from {cookie_fn.__name__} for {domain_name}: {e}")
|
||||
debug.error(f"Error reading cookies from {cookie_fn.__name__} for {domain_name}: {e}")
|
||||
return cookies
|
||||
|
||||
def set_cookies_dir(dir: str) -> None:
|
||||
|
@@ -86,6 +86,23 @@
|
||||
height: 100%;
|
||||
text-align: center;
|
||||
z-index: 1;
|
||||
transition: transform 0.25s ease-in;
|
||||
}
|
||||
|
||||
.container.slide {
|
||||
transform: translateX(-100%);
|
||||
transition: transform 0.15s ease-out;
|
||||
}
|
||||
|
||||
.slide-button {
|
||||
position: absolute;
|
||||
top: 20px;
|
||||
left: 20px;
|
||||
background: var(--colour-4);
|
||||
border: none;
|
||||
border-radius: 4px;
|
||||
cursor: pointer;
|
||||
padding: 6px 8px;
|
||||
}
|
||||
|
||||
header {
|
||||
@@ -105,7 +122,8 @@
|
||||
height: 100%;
|
||||
position: absolute;
|
||||
z-index: -1;
|
||||
object-fit: contain;
|
||||
object-fit: cover;
|
||||
object-position: center;
|
||||
width: 100%;
|
||||
background: black;
|
||||
}
|
||||
@@ -181,6 +199,7 @@
|
||||
}
|
||||
}
|
||||
</style>
|
||||
<link rel="stylesheet" href="/static/css/all.min.css">
|
||||
<script>
|
||||
(async () => {
|
||||
const isIframe = window.self !== window.top;
|
||||
@@ -208,6 +227,10 @@
|
||||
<!-- Gradient Background Circle -->
|
||||
<div class="gradient"></div>
|
||||
|
||||
<button class="slide-button">
|
||||
<i class="fa-solid fa-arrow-left"></i>
|
||||
</button>
|
||||
|
||||
<!-- Main Content -->
|
||||
<div class="container">
|
||||
<header>
|
||||
@@ -277,8 +300,11 @@
|
||||
if (oauthResult) {
|
||||
try {
|
||||
oauthResult = JSON.parse(oauthResult);
|
||||
user = await hub.whoAmI({accessToken: oauthResult.accessToken});
|
||||
} catch {
|
||||
oauthResult = null;
|
||||
localStorage.removeItem("oauth");
|
||||
localStorage.removeItem("HuggingFace-api_key");
|
||||
}
|
||||
}
|
||||
oauthResult ||= await oauthHandleRedirectIfPresent();
|
||||
@@ -339,7 +365,7 @@
|
||||
return;
|
||||
}
|
||||
const lower = data.prompt.toLowerCase();
|
||||
const tags = ["nsfw", "timeline", "blood", "soap", "orally", "heel", "latex", "bathroom", "boobs", "charts", " text ", "gel", "logo", "infographic", "warts", " bra ", "prostitute", "curvy", "breasts", "written", "bodies", "naked", "classroom", "malone", "dirty", "shoes", "shower", "banner", "fat", "nipples", "couple", "sexual", "sandal", "supplier", "overlord", "succubus", "platinum", "cracy", "crazy", "lamic", "ropes", "cables", "wires", "dirty", "messy", "cluttered", "chaotic", "disorganized", "disorderly", "untidy", "unorganized", "unorderly", "unsystematic", "disarranged", "disarrayed", "disheveled", "disordered", "jumbled", "muddled", "scattered", "shambolic", "sloppy", "unkept", "unruly"];
|
||||
const tags = ["nsfw", "timeline", "feet", "blood", "soap", "orally", "heel", "latex", "bathroom", "boobs", "charts", " text ", "gel", "logo", "infographic", "warts", " bra ", "prostitute", "curvy", "breasts", "written", "bodies", "naked", "classroom", "malone", "dirty", "shoes", "shower", "banner", "fat", "nipples", "couple", "sexual", "sandal", "supplier", "overlord", "succubus", "platinum", "cracy", "crazy", "hemale", "oprah", "lamic", "ropes", "cables", "wires", "dirty", "messy", "cluttered", "chaotic", "disorganized", "disorderly", "untidy", "unorganized", "unorderly", "unsystematic", "disarranged", "disarrayed", "disheveled", "disordered", "jumbled", "muddled", "scattered", "shambolic", "sloppy", "unkept", "unruly"];
|
||||
for (i in tags) {
|
||||
if (lower.indexOf(tags[i]) != -1) {
|
||||
console.log("Skipping image with tag: " + tags[i]);
|
||||
@@ -363,6 +389,21 @@
|
||||
imageFeed.remove();
|
||||
}
|
||||
}, 7000);
|
||||
|
||||
const container = document.querySelector('.container');
|
||||
const button = document.querySelector('.slide-button');
|
||||
const slideIcon = button.querySelector('i');
|
||||
button.onclick = () => {
|
||||
if (container.classList.contains('slide')) {
|
||||
container.classList.remove('slide');
|
||||
slideIcon.classList.remove('fa-arrow-right');
|
||||
slideIcon.classList.add('fa-arrow-left');
|
||||
} else {
|
||||
container.classList.add('slide');
|
||||
slideIcon.classList.remove('fa-arrow-left');
|
||||
slideIcon.classList.add('fa-arrow-right');
|
||||
}
|
||||
}
|
||||
})();
|
||||
</script>
|
||||
</body>
|
||||
|
@@ -141,7 +141,7 @@ class Api:
|
||||
stream=True,
|
||||
ignore_stream=True,
|
||||
logging=False,
|
||||
has_images="images" in kwargs,
|
||||
has_images="media" in kwargs,
|
||||
)
|
||||
except Exception as e:
|
||||
debug.error(e)
|
||||
|
@@ -8,6 +8,7 @@ import asyncio
|
||||
import shutil
|
||||
import random
|
||||
import datetime
|
||||
import tempfile
|
||||
from flask import Flask, Response, request, jsonify, render_template
|
||||
from typing import Generator
|
||||
from pathlib import Path
|
||||
@@ -111,11 +112,13 @@ class Backend_Api(Api):
|
||||
else:
|
||||
json_data = request.json
|
||||
if "files" in request.files:
|
||||
images = []
|
||||
media = []
|
||||
for file in request.files.getlist('files'):
|
||||
if file.filename != '' and is_allowed_extension(file.filename):
|
||||
images.append((to_image(file.stream, file.filename.endswith('.svg')), file.filename))
|
||||
json_data['images'] = images
|
||||
newfile = tempfile.TemporaryFile()
|
||||
shutil.copyfileobj(file.stream, newfile)
|
||||
media.append((newfile, file.filename))
|
||||
json_data['media'] = media
|
||||
|
||||
if app.demo and not json_data.get("provider"):
|
||||
model = json_data.get("model")
|
||||
|
@@ -76,24 +76,24 @@ def is_allowed_extension(filename: str) -> bool:
|
||||
return '.' in filename and \
|
||||
filename.rsplit('.', 1)[1].lower() in ALLOWED_EXTENSIONS
|
||||
|
||||
def is_data_uri_an_media(data_uri: str) -> str:
|
||||
return is_data_an_wav(data_uri) or is_data_uri_an_image(data_uri)
|
||||
def is_data_an_media(data, filename: str = None) -> str:
|
||||
content_type = is_data_an_audio(data, filename)
|
||||
if content_type is not None:
|
||||
return content_type
|
||||
if isinstance(data, bytes):
|
||||
return is_accepted_format(data)
|
||||
return is_data_uri_an_image(data)
|
||||
|
||||
def is_data_an_wav(data_uri: str, filename: str = None) -> str:
|
||||
"""
|
||||
Checks if the given data URI represents an image.
|
||||
|
||||
Args:
|
||||
data_uri (str): The data URI to check.
|
||||
|
||||
Raises:
|
||||
ValueError: If the data URI is invalid or the image format is not allowed.
|
||||
"""
|
||||
if filename and filename.endswith(".wav"):
|
||||
return "audio/wav"
|
||||
# Check if the data URI starts with 'data:image' and contains an image format (e.g., jpeg, png, gif)
|
||||
if isinstance(data_uri, str) and re.match(r'data:audio/wav;base64,', data_uri):
|
||||
def is_data_an_audio(data_uri: str, filename: str = None) -> str:
|
||||
if filename:
|
||||
if filename.endswith(".wav"):
|
||||
return "audio/wav"
|
||||
elif filename.endswith(".mp3"):
|
||||
return "audio/mpeg"
|
||||
if isinstance(data_uri, str):
|
||||
audio_format = re.match(r'^data:(audio/\w+);base64,', data_uri)
|
||||
if audio_format:
|
||||
return audio_format.group(1)
|
||||
|
||||
def is_data_uri_an_image(data_uri: str) -> bool:
|
||||
"""
|
||||
@@ -218,7 +218,7 @@ def to_bytes(image: ImageType) -> bytes:
|
||||
if isinstance(image, bytes):
|
||||
return image
|
||||
elif isinstance(image, str) and image.startswith("data:"):
|
||||
is_data_uri_an_media(image)
|
||||
is_data_an_media(image)
|
||||
return extract_data_uri(image)
|
||||
elif isinstance(image, Image):
|
||||
bytes_io = BytesIO()
|
||||
@@ -236,13 +236,29 @@ def to_bytes(image: ImageType) -> bytes:
|
||||
pass
|
||||
return image.read()
|
||||
|
||||
def to_data_uri(image: ImageType) -> str:
|
||||
def to_data_uri(image: ImageType, filename: str = None) -> str:
|
||||
if not isinstance(image, str):
|
||||
data = to_bytes(image)
|
||||
data_base64 = base64.b64encode(data).decode()
|
||||
return f"data:{is_accepted_format(data)};base64,{data_base64}"
|
||||
return f"data:{is_data_an_media(data, filename)};base64,{data_base64}"
|
||||
return image
|
||||
|
||||
def to_input_audio(audio: ImageType, filename: str = None) -> str:
|
||||
if not isinstance(audio, str):
|
||||
if filename is not None and (filename.endswith(".wav") or filename.endswith(".mp3")):
|
||||
return {
|
||||
"data": base64.b64encode(to_bytes(audio)).decode(),
|
||||
"format": "wav" if filename.endswith(".wav") else "mpeg"
|
||||
}
|
||||
raise ValueError("Invalid input audio")
|
||||
audio = re.match(r'^data:audio/(\w+);base64,(.+?)', audio)
|
||||
if audio:
|
||||
return {
|
||||
"data": audio.group(2),
|
||||
"format": audio.group(1),
|
||||
}
|
||||
raise ValueError("Invalid input audio")
|
||||
|
||||
class ImageDataResponse():
|
||||
def __init__(
|
||||
self,
|
||||
|
@@ -109,7 +109,7 @@ async def copy_images(
|
||||
f.write(chunk)
|
||||
|
||||
# Verify file format
|
||||
if not os.path.splitext(target_path)[1]:
|
||||
if target is None and not os.path.splitext(target_path)[1]:
|
||||
with open(target_path, "rb") as f:
|
||||
file_header = f.read(12)
|
||||
detected_type = is_accepted_format(file_header)
|
||||
@@ -120,7 +120,7 @@ async def copy_images(
|
||||
|
||||
# Build URL with safe encoding
|
||||
url_filename = quote(os.path.basename(target_path))
|
||||
return f"/images/{url_filename}{'?url=' + quote(image) if add_url and not image.startswith('data:') else ''}"
|
||||
return f"/images/{url_filename}" + (('?url=' + quote(image)) if add_url and not image.startswith('data:') else '')
|
||||
|
||||
except (ClientError, IOError, OSError) as e:
|
||||
debug.error(f"Image copying failed: {type(e).__name__}: {e}")
|
||||
|
@@ -25,7 +25,7 @@ from .. import debug
|
||||
|
||||
SAFE_PARAMETERS = [
|
||||
"model", "messages", "stream", "timeout",
|
||||
"proxy", "images", "response_format",
|
||||
"proxy", "media", "response_format",
|
||||
"prompt", "negative_prompt", "tools", "conversation",
|
||||
"history_disabled",
|
||||
"temperature", "top_k", "top_p",
|
||||
@@ -56,7 +56,7 @@ PARAMETER_EXAMPLES = {
|
||||
"frequency_penalty": 1,
|
||||
"presence_penalty": 1,
|
||||
"messages": [{"role": "system", "content": ""}, {"role": "user", "content": ""}],
|
||||
"images": [["data:image/jpeg;base64,...", "filename.jpg"]],
|
||||
"media": [["data:image/jpeg;base64,...", "filename.jpg"]],
|
||||
"response_format": {"type": "json_object"},
|
||||
"conversation": {"conversation_id": "550e8400-e29b-11d4-a716-...", "message_id": "550e8400-e29b-11d4-a716-..."},
|
||||
"seed": 42,
|
||||
|
@@ -2,7 +2,7 @@ from __future__ import annotations
|
||||
|
||||
import json
|
||||
|
||||
from ..typing import AsyncResult, Messages, ImagesType
|
||||
from ..typing import AsyncResult, Messages, MediaListType
|
||||
from ..client.service import get_model_and_provider
|
||||
from ..client.helper import filter_json
|
||||
from .base_provider import AsyncGeneratorProvider
|
||||
@@ -17,7 +17,7 @@ class ToolSupportProvider(AsyncGeneratorProvider):
|
||||
model: str,
|
||||
messages: Messages,
|
||||
stream: bool = True,
|
||||
images: ImagesType = None,
|
||||
media: MediaListType = None,
|
||||
tools: list[str] = None,
|
||||
response_format: dict = None,
|
||||
**kwargs
|
||||
@@ -28,7 +28,7 @@ class ToolSupportProvider(AsyncGeneratorProvider):
|
||||
model, provider = get_model_and_provider(
|
||||
model, provider,
|
||||
stream, logging=False,
|
||||
has_images=images is not None
|
||||
has_images=media is not None
|
||||
)
|
||||
if tools is not None:
|
||||
if len(tools) > 1:
|
||||
@@ -49,7 +49,7 @@ class ToolSupportProvider(AsyncGeneratorProvider):
|
||||
model,
|
||||
messages,
|
||||
stream=stream,
|
||||
images=images,
|
||||
media=media,
|
||||
response_format=response_format,
|
||||
**kwargs
|
||||
):
|
||||
|
@@ -21,7 +21,7 @@ AsyncResult = AsyncIterator[Union[str, ResponseType]]
|
||||
Messages = List[Dict[str, Union[str, List[Dict[str, Union[str, Dict[str, str]]]]]]]
|
||||
Cookies = Dict[str, str]
|
||||
ImageType = Union[str, bytes, IO, Image, os.PathLike]
|
||||
ImagesType = List[Tuple[ImageType, Optional[str]]]
|
||||
MediaListType = List[Tuple[ImageType, Optional[str]]]
|
||||
|
||||
__all__ = [
|
||||
'Any',
|
||||
@@ -44,5 +44,5 @@ __all__ = [
|
||||
'Cookies',
|
||||
'Image',
|
||||
'ImageType',
|
||||
'ImagesType'
|
||||
'MediaListType'
|
||||
]
|
||||
|
Reference in New Issue
Block a user