Add audio transcribing example and support
- Add Grok Chat provider
- Rename images parameter to media
- Update demo homepage
@@ -19,6 +19,7 @@ The G4F AsyncClient API is designed to be compatible with the OpenAI API, making
   - [Text Completions](#text-completions)
   - [Streaming Completions](#streaming-completions)
   - [Using a Vision Model](#using-a-vision-model)
+  - [Transcribing Audio with Chat Completions](#transcribing-audio-with-chat-completions)
   - [Image Generation](#image-generation)
   - [Advanced Usage](#advanced-usage)
   - [Conversation Memory](#conversation-memory)
@@ -203,6 +204,54 @@ async def main():
 asyncio.run(main())
 ```
 
+---
+
+### Transcribing Audio with Chat Completions
+
+Some providers in G4F support audio inputs in chat completions, allowing you to transcribe audio files by instructing the model accordingly. This example demonstrates how to use the `AsyncClient` to transcribe an audio file asynchronously:
+
+```python
+import asyncio
+from g4f.client import AsyncClient
+import g4f.Provider
+import g4f.models
+
+async def main():
+    client = AsyncClient(provider=g4f.Provider.PollinationsAI) # or g4f.Provider.Microsoft_Phi_4
+
+    with open("audio.wav", "rb") as audio_file:
+        response = await client.chat.completions.create(
+            model=g4f.models.default,
+            messages=[{"role": "user", "content": "Transcribe this audio"}],
+            media=[[audio_file, "audio.wav"]],
+            modalities=["text"],
+        )
+
+    print(response.choices[0].message.content)
+
+if __name__ == "__main__":
+    asyncio.run(main())
+```
+
+#### Explanation
+- **Client Initialization**: An `AsyncClient` instance is created with a provider that supports audio inputs, such as `PollinationsAI` or `Microsoft_Phi_4`.
+- **File Handling**: The audio file (`audio.wav`) is opened in binary read mode (`"rb"`) using a context manager (`with` statement) to ensure proper file closure after use.
+- **API Call**: The `chat.completions.create` method is called with:
+  - `model=g4f.models.default`: Uses the default model for the selected provider.
+  - `messages`: A list containing a user message instructing the model to transcribe the audio.
+  - `media`: A list of lists, where each inner list contains the file object and its name (`[[audio_file, "audio.wav"]]`).
+  - `modalities=["text"]`: Specifies that the output should be text (the transcription).
+- **Response**: The transcription is extracted from `response.choices[0].message.content` and printed.
+
+#### Notes
+- **Provider Support**: Ensure the chosen provider (e.g., `PollinationsAI` or `Microsoft_Phi_4`) supports audio inputs in chat completions. Not all providers offer this functionality.
+- **File Path**: Replace `"audio.wav"` with the path to your own audio file. The file format (e.g., WAV) must be compatible with the provider.
+- **Model Selection**: If `g4f.models.default` does not support audio transcription, you may need to specify a model that does (consult the provider's documentation for supported models).
+
+This example complements the guide by showcasing how to handle audio inputs asynchronously, expanding on the multimodal capabilities of the G4F AsyncClient API.
+
+---
+
 ### Image Generation
 **The `response_format` parameter is optional and can have the following values:**
 - **If not specified (default):** The image will be saved locally, and a local path will be returned (e.g., "/images/1733331238_cf9d6aa9-f606-4fea-ba4b-f06576cba309.jpg").
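
For reference, the same transcription request also works with the blocking client. A minimal sketch, assuming the synchronous `Client` from `g4f.client` mirrors the `AsyncClient` parameters shown in the new section:

```python
from g4f.client import Client
import g4f.Provider
import g4f.models

# Same provider and parameters as the async example in the section above
client = Client(provider=g4f.Provider.PollinationsAI)

with open("audio.wav", "rb") as audio_file:
    response = client.chat.completions.create(
        model=g4f.models.default,
        messages=[{"role": "user", "content": "Transcribe this audio"}],
        media=[[audio_file, "audio.wav"]],  # (file object, filename) pair
        modalities=["text"],
    )

print(response.choices[0].message.content)
```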
@@ -6,10 +6,10 @@ from g4f.models import __models__
 from g4f.providers.base_provider import BaseProvider, ProviderModelMixin
 from g4f.errors import MissingRequirementsError, MissingAuthError
 
-class TestProviderHasModel(unittest.IsolatedAsyncioTestCase):
+class TestProviderHasModel(unittest.TestCase):
     cache: dict = {}
 
-    async def test_provider_has_model(self):
+    def test_provider_has_model(self):
         for model, providers in __models__.values():
             for provider in providers:
                 if issubclass(provider, ProviderModelMixin):
@@ -17,9 +17,9 @@ class TestProviderHasModel(unittest.IsolatedAsyncioTestCase):
                     model_name = provider.model_aliases[model.name]
                 else:
                     model_name = model.name
-                await asyncio.wait_for(self.provider_has_model(provider, model_name), 10)
+                self.provider_has_model(provider, model_name)
 
-    async def provider_has_model(self, provider: Type[BaseProvider], model: str):
+    def provider_has_model(self, provider: Type[BaseProvider], model: str):
         if provider.__name__ not in self.cache:
             try:
                 self.cache[provider.__name__] = provider.get_models()
@@ -28,7 +28,7 @@ class TestProviderHasModel(unittest.IsolatedAsyncioTestCase):
         if self.cache[provider.__name__]:
             self.assertIn(model, self.cache[provider.__name__], provider.__name__)
 
-    async def test_all_providers_working(self):
+    def test_all_providers_working(self):
         for model, providers in __models__.values():
             for provider in providers:
                 self.assertTrue(provider.working, f"{provider.__name__} in {model.name}")
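
The test changes above drop the asyncio scaffolding (`IsolatedAsyncioTestCase`, `asyncio.wait_for`) because `get_models()` is a blocking call. A sketch of running the checks programmatically; the module path is an assumption, adjust it to wherever `TestProviderHasModel` lives in your checkout:

```python
import unittest

# Assumed module path (hypothetical); point it at the file that defines
# TestProviderHasModel from the hunks above.
from etc.unittest.models import TestProviderHasModel

suite = unittest.defaultTestLoader.loadTestsFromTestCase(TestProviderHasModel)
unittest.TextTestRunner(verbosity=2).run(suite)
```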
@@ -48,7 +48,7 @@ class AllenAI(AsyncGeneratorProvider, ProviderModelMixin):
         "olmo-2-13b": "OLMo-2-1124-13B-Instruct",
         "tulu-3-1-8b": "tulu-3-1-8b",
         "tulu-3-70b": "Llama-3-1-Tulu-3-70B",
-        "llama-3.1-405b": "tulu-3-405b",
+        "llama-3.1-405b": "tulu3-405b",
         "llama-3.1-8b": "tulu-3-1-8b",
         "llama-3.1-70b": "Llama-3-1-Tulu-3-70B",
     }
@@ -11,7 +11,7 @@ from pathlib import Path
 from typing import Optional
 from datetime import datetime, timedelta
 
-from ..typing import AsyncResult, Messages, ImagesType
+from ..typing import AsyncResult, Messages, MediaListType
 from ..requests.raise_for_status import raise_for_status
 from .base_provider import AsyncGeneratorProvider, ProviderModelMixin
 from ..image import to_data_uri
@@ -444,7 +444,7 @@ class Blackbox(AsyncGeneratorProvider, ProviderModelMixin):
         messages: Messages,
         prompt: str = None,
         proxy: str = None,
-        images: ImagesType = None,
+        media: MediaListType = None,
         top_p: float = None,
         temperature: float = None,
         max_tokens: int = None,
@@ -479,14 +479,14 @@ class Blackbox(AsyncGeneratorProvider, ProviderModelMixin):
             }
             current_messages.append(current_msg)
 
-        if images is not None:
+        if media is not None:
             current_messages[-1]['data'] = {
                 "imagesData": [
                     {
                         "filePath": f"/{image_name}",
                         "contents": to_data_uri(image)
                     }
-                    for image, image_name in images
+                    for image, image_name in media
                 ],
                 "fileText": "",
                 "title": ""
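
After the rename, a direct provider call passes attachments through the `media` keyword instead of `images`. A rough usage sketch (the model name is illustrative; each `media` entry is a `(data, filename)` pair, exactly what the comprehension above unpacks):

```python
import asyncio
from g4f.Provider import Blackbox

async def demo():
    with open("photo.jpg", "rb") as f:
        image_bytes = f.read()
    # The filename feeds "filePath"; the data goes through to_data_uri()
    async for chunk in Blackbox.create_async_generator(
        "blackboxai",  # assumption: example model name
        [{"role": "user", "content": "Describe this image"}],
        media=[(image_bytes, "photo.jpg")],
    ):
        print(chunk, end="")

asyncio.run(demo())
```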
@@ -21,7 +21,7 @@ except ImportError:
 from .base_provider import AbstractProvider, ProviderModelMixin
 from .helper import format_prompt_max_length
 from .openai.har_file import get_headers, get_har_files
-from ..typing import CreateResult, Messages, ImagesType
+from ..typing import CreateResult, Messages, MediaListType
 from ..errors import MissingRequirementsError, NoValidHarFileError, MissingAuthError
 from ..requests.raise_for_status import raise_for_status
 from ..providers.response import BaseConversation, JsonConversation, RequestLogin, Parameters, ImageResponse
@@ -66,7 +66,7 @@ class Copilot(AbstractProvider, ProviderModelMixin):
         proxy: str = None,
         timeout: int = 900,
         prompt: str = None,
-        images: ImagesType = None,
+        media: MediaListType = None,
         conversation: BaseConversation = None,
         return_conversation: bool = False,
         api_key: str = None,
@@ -77,7 +77,7 @@ class Copilot(AbstractProvider, ProviderModelMixin):
 
         websocket_url = cls.websocket_url
         headers = None
-        if cls.needs_auth or images is not None:
+        if cls.needs_auth or media is not None:
             if api_key is not None:
                 cls._access_token = api_key
             if cls._access_token is None:
@@ -142,8 +142,8 @@ class Copilot(AbstractProvider, ProviderModelMixin):
             debug.log(f"Copilot: Use conversation: {conversation_id}")
 
         uploaded_images = []
-        if images is not None:
-            for image, _ in images:
+        if media is not None:
+            for image, _ in media:
                 data = to_bytes(image)
                 response = session.post(
                     "https://copilot.microsoft.com/c/api/attachments",
@@ -30,7 +30,7 @@ class DDG(AsyncGeneratorProvider, ProviderModelMixin):
     api_endpoint = "https://duckduckgo.com/duckchat/v1/chat"
     status_url = "https://duckduckgo.com/duckchat/v1/status"
 
-    working = True
+    working = False
     supports_stream = True
     supports_system_message = True
     supports_message_history = True
@@ -1,6 +1,6 @@
 from __future__ import annotations
 
-from ..typing import AsyncResult, Messages, ImagesType
+from ..typing import AsyncResult, Messages, MediaListType
 from .template import OpenaiTemplate
 from ..image import to_data_uri
 
@@ -70,7 +70,6 @@ class DeepInfraChat(OpenaiTemplate):
         temperature: float = 0.7,
         max_tokens: int = None,
         headers: dict = {},
-        images: ImagesType = None,
         **kwargs
     ) -> AsyncResult:
         headers = {
@@ -82,23 +81,6 @@ class DeepInfraChat(OpenaiTemplate):
             **headers
         }
 
-        if images is not None:
-            if not model or model not in cls.models:
-                model = cls.default_vision_model
-            if messages:
-                last_message = messages[-1].copy()
-                last_message["content"] = [
-                    *[{
-                        "type": "image_url",
-                        "image_url": {"url": to_data_uri(image)}
-                    } for image, _ in images],
-                    {
-                        "type": "text",
-                        "text": last_message["content"]
-                    }
-                ]
-                messages[-1] = last_message
-
         async for chunk in super().create_async_generator(
             model,
             messages,
@@ -3,13 +3,12 @@ from __future__ import annotations
 import json
 from aiohttp import ClientSession, FormData
 
-from ..typing import AsyncResult, Messages, ImagesType
+from ..typing import AsyncResult, Messages, MediaListType
 from .base_provider import AsyncGeneratorProvider, ProviderModelMixin
 from ..requests.raise_for_status import raise_for_status
-from ..image import to_data_uri, to_bytes, is_accepted_format
+from ..image import to_bytes, is_accepted_format
 from .helper import format_prompt
 
 
 class Dynaspark(AsyncGeneratorProvider, ProviderModelMixin):
     url = "https://dynaspark.onrender.com"
     login_url = None
@@ -38,7 +37,7 @@ class Dynaspark(AsyncGeneratorProvider, ProviderModelMixin):
         model: str,
         messages: Messages,
         proxy: str = None,
-        images: ImagesType = None,
+        media: MediaListType = None,
         **kwargs
     ) -> AsyncResult:
         headers = {
@@ -49,14 +48,13 @@ class Dynaspark(AsyncGeneratorProvider, ProviderModelMixin):
             'user-agent': 'Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/133.0.0.0 Safari/537.36',
             'x-requested-with': 'XMLHttpRequest'
         }
 
         async with ClientSession(headers=headers) as session:
             form = FormData()
             form.add_field('user_input', format_prompt(messages))
             form.add_field('ai_model', model)
 
-            if images is not None and len(images) > 0:
-                image, image_name = images[0]
+            if media is not None and len(media) > 0:
+                image, image_name = media[0]
                 image_bytes = to_bytes(image)
                 form.add_field('file', image_bytes, filename=image_name, content_type=is_accepted_format(image_bytes))
 
@@ -14,6 +14,7 @@ class LambdaChat(HuggingChat):
     default_model = "deepseek-llama3.3-70b"
     reasoning_model = "deepseek-r1"
     image_models = []
+    models = []
     fallback_models = [
         default_model,
         reasoning_model,
@@ -9,8 +9,8 @@ from aiohttp import ClientSession
 
 from .helper import filter_none, format_image_prompt
 from .base_provider import AsyncGeneratorProvider, ProviderModelMixin
-from ..typing import AsyncResult, Messages, ImagesType
-from ..image import to_data_uri
+from ..typing import AsyncResult, Messages, MediaListType
+from ..image import to_data_uri, is_data_an_audio, to_input_audio
 from ..errors import ModelNotFoundError
 from ..requests.raise_for_status import raise_for_status
 from ..requests.aiohttp import get_connector
@@ -146,17 +146,24 @@ class PollinationsAI(AsyncGeneratorProvider, ProviderModelMixin):
         enhance: bool = False,
         safe: bool = False,
         # Text generation parameters
-        images: ImagesType = None,
+        media: MediaListType = None,
         temperature: float = None,
         presence_penalty: float = None,
         top_p: float = 1,
         frequency_penalty: float = None,
         response_format: Optional[dict] = None,
-        extra_parameters: list[str] = ["tools", "parallel_tool_calls", "tool_choice", "reasoning_effort", "logit_bias", "voice"],
+        extra_parameters: list[str] = ["tools", "parallel_tool_calls", "tool_choice", "reasoning_effort", "logit_bias", "voice", "modalities"],
         **kwargs
     ) -> AsyncResult:
         # Load model list
         cls.get_models()
+        if not model and media is not None:
+            has_audio = False
+            for media_data, filename in media:
+                if is_data_an_audio(media_data, filename):
+                    has_audio = True
+                    break
+            model = next(iter(cls.audio_models)) if has_audio else cls.default_vision_model
         try:
             model = cls.get_model(model)
         except ModelNotFoundError:
@@ -182,7 +189,7 @@ class PollinationsAI(AsyncGeneratorProvider, ProviderModelMixin):
             async for result in cls._generate_text(
                 model=model,
                 messages=messages,
-                images=images,
+                media=media,
                 proxy=proxy,
                 temperature=temperature,
                 presence_penalty=presence_penalty,
@@ -239,7 +246,7 @@ class PollinationsAI(AsyncGeneratorProvider, ProviderModelMixin):
         cls,
         model: str,
         messages: Messages,
-        images: Optional[ImagesType],
+        media: MediaListType,
         proxy: str,
         temperature: float,
         presence_penalty: float,
@@ -258,14 +265,18 @@ class PollinationsAI(AsyncGeneratorProvider, ProviderModelMixin):
         if response_format and response_format.get("type") == "json_object":
             json_mode = True
 
-        if images and messages:
+        if media and messages:
             last_message = messages[-1].copy()
             image_content = [
                 {
-                    "type": "image_url",
-                    "image_url": {"url": to_data_uri(image)}
+                    "type": "input_audio",
+                    "input_audio": to_input_audio(media_data, filename)
                 }
-                for image, _ in images
+                if is_data_an_audio(media_data, filename) else {
+                    "type": "image_url",
+                    "image_url": {"url": to_data_uri(media_data)}
+                }
+                for media_data, filename in media
             ]
             last_message["content"] = image_content + [{"type": "text", "text": last_message["content"]}]
             messages[-1] = last_message
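
PollinationsAI is where the rename earns its keep: one `media` list can now mix images and audio, and each entry is routed to either an `image_url` or an `input_audio` content part. The alias below is an assumption modeled on the old `ImagesType` (the real definition lives in `g4f/typing.py`); the consuming code above treats entries as `(data, filename)` pairs:

```python
from typing import IO, List, Tuple, Union

# Assumed shape: data may be bytes, a file object, or anything that
# to_bytes()/to_data_uri() accepts, paired with a filename.
MediaData = Union[bytes, str, IO]
MediaListType = List[Tuple[MediaData, str]]

media: MediaListType = [
    (open("chart.png", "rb").read(), "chart.png"),  # routed to an image_url part
    (open("voice.wav", "rb").read(), "voice.wav"),  # routed to an input_audio part
]
```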
@@ -17,7 +17,7 @@ except ImportError:
 
 from ..base_provider import ProviderModelMixin, AsyncAuthedProvider, AuthResult
 from ..helper import format_prompt, format_image_prompt, get_last_user_message
-from ...typing import AsyncResult, Messages, Cookies, ImagesType
+from ...typing import AsyncResult, Messages, Cookies, MediaListType
 from ...errors import MissingRequirementsError, MissingAuthError, ResponseError
 from ...image import to_bytes
 from ...requests import get_args_from_nodriver, DEFAULT_HEADERS
@@ -99,7 +99,7 @@ class HuggingChat(AsyncAuthedProvider, ProviderModelMixin):
         messages: Messages,
         auth_result: AuthResult,
         prompt: str = None,
-        images: ImagesType = None,
+        media: MediaListType = None,
         return_conversation: bool = False,
         conversation: Conversation = None,
         web_search: bool = False,
@@ -108,7 +108,7 @@ class HuggingChat(AsyncAuthedProvider, ProviderModelMixin):
         if not has_curl_cffi:
             raise MissingRequirementsError('Install "curl_cffi" package | pip install -U curl_cffi')
         if model == llama_models["name"]:
-            model = llama_models["text"] if images is None else llama_models["vision"]
+            model = llama_models["text"] if media is None else llama_models["vision"]
         model = cls.get_model(model)
 
         session = Session(**auth_result.get_dict())
@@ -145,8 +145,8 @@ class HuggingChat(AsyncAuthedProvider, ProviderModelMixin):
         }
         data = CurlMime()
         data.addpart('data', data=json.dumps(settings, separators=(',', ':')))
-        if images is not None:
-            for image, filename in images:
+        if media is not None:
+            for image, filename in media:
                 data.addpart(
                     "files",
                     filename=f"base64;{filename}",
|
@@ -3,7 +3,7 @@ from __future__ import annotations
|
|||||||
import requests
|
import requests
|
||||||
|
|
||||||
from ...providers.types import Messages
|
from ...providers.types import Messages
|
||||||
from ...typing import ImagesType
|
from ...typing import MediaListType
|
||||||
from ...requests import StreamSession, raise_for_status
|
from ...requests import StreamSession, raise_for_status
|
||||||
from ...errors import ModelNotSupportedError
|
from ...errors import ModelNotSupportedError
|
||||||
from ...providers.helper import get_last_user_message
|
from ...providers.helper import get_last_user_message
|
||||||
@@ -75,11 +75,11 @@ class HuggingFaceAPI(OpenaiTemplate):
|
|||||||
api_key: str = None,
|
api_key: str = None,
|
||||||
max_tokens: int = 2048,
|
max_tokens: int = 2048,
|
||||||
max_inputs_lenght: int = 10000,
|
max_inputs_lenght: int = 10000,
|
||||||
images: ImagesType = None,
|
media: MediaListType = None,
|
||||||
**kwargs
|
**kwargs
|
||||||
):
|
):
|
||||||
if model == llama_models["name"]:
|
if model == llama_models["name"]:
|
||||||
model = llama_models["text"] if images is None else llama_models["vision"]
|
model = llama_models["text"] if media is None else llama_models["vision"]
|
||||||
if model in cls.model_aliases:
|
if model in cls.model_aliases:
|
||||||
model = cls.model_aliases[model]
|
model = cls.model_aliases[model]
|
||||||
provider_mapping = await cls.get_mapping(model, api_key)
|
provider_mapping = await cls.get_mapping(model, api_key)
|
||||||
@@ -103,7 +103,7 @@ class HuggingFaceAPI(OpenaiTemplate):
|
|||||||
if len(messages) > 1 and calculate_lenght(messages) > max_inputs_lenght:
|
if len(messages) > 1 and calculate_lenght(messages) > max_inputs_lenght:
|
||||||
messages = last_user_message
|
messages = last_user_message
|
||||||
debug.log(f"Messages trimmed from: {start} to: {calculate_lenght(messages)}")
|
debug.log(f"Messages trimmed from: {start} to: {calculate_lenght(messages)}")
|
||||||
async for chunk in super().create_async_generator(model, messages, api_base=api_base, api_key=api_key, max_tokens=max_tokens, images=images, **kwargs):
|
async for chunk in super().create_async_generator(model, messages, api_base=api_base, api_key=api_key, max_tokens=max_tokens, media=media, **kwargs):
|
||||||
yield chunk
|
yield chunk
|
||||||
|
|
||||||
def calculate_lenght(messages: Messages) -> int:
|
def calculate_lenght(messages: Messages) -> int:
|
||||||
|
@@ -7,7 +7,7 @@ import random
 from datetime import datetime, timezone, timedelta
 import urllib.parse
 
-from ...typing import AsyncResult, Messages, Cookies, ImagesType
+from ...typing import AsyncResult, Messages, Cookies, MediaListType
 from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin
 from ..helper import format_prompt, format_image_prompt
 from ...providers.response import JsonConversation, ImageResponse, Reasoning
@@ -68,7 +68,7 @@ class DeepseekAI_JanusPro7b(AsyncGeneratorProvider, ProviderModelMixin):
         cls,
         model: str,
         messages: Messages,
-        images: ImagesType = None,
+        media: MediaListType = None,
         prompt: str = None,
         proxy: str = None,
         cookies: Cookies = None,
@@ -98,27 +98,27 @@ class DeepseekAI_JanusPro7b(AsyncGeneratorProvider, ProviderModelMixin):
             if return_conversation:
                 yield conversation
 
-            if images is not None:
+            if media is not None:
                 data = FormData()
-                for i in range(len(images)):
-                    images[i] = (to_bytes(images[i][0]), images[i][1])
-                for image, image_name in images:
+                for i in range(len(media)):
+                    media[i] = (to_bytes(media[i][0]), media[i][1])
+                for image, image_name in media:
                     data.add_field(f"files", image, filename=image_name)
                 async with session.post(f"{cls.api_url}/gradio_api/upload", params={"upload_id": session_hash}, data=data) as response:
                     await raise_for_status(response)
                     image_files = await response.json()
-                images = [{
+                media = [{
                     "path": image_file,
                     "url": f"{cls.api_url}/gradio_api/file={image_file}",
-                    "orig_name": images[i][1],
-                    "size": len(images[i][0]),
-                    "mime_type": is_accepted_format(images[i][0]),
+                    "orig_name": media[i][1],
+                    "size": len(media[i][0]),
+                    "mime_type": is_accepted_format(media[i][0]),
                     "meta": {
                         "_type": "gradio.FileData"
                     }
                 } for i, image_file in enumerate(image_files)]
 
-            async with cls.run(method, session, prompt, conversation, None if images is None else images.pop(), seed) as response:
+            async with cls.run(method, session, prompt, conversation, None if media is None else media.pop(), seed) as response:
                 await raise_for_status(response)
 
             async with cls.run("get", session, prompt, conversation, None, seed) as response:
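
Both Hugging Face Space providers follow the same Gradio upload handshake: POST the raw files to `/gradio_api/upload` with an `upload_id`, then reference the returned server paths as `gradio.FileData` dicts. Condensed into a helper (a sketch; `session` is an aiohttp-style session as in the diff):

```python
from aiohttp import FormData

async def gradio_upload(session, api_url: str, session_hash: str, media: list) -> list:
    """Upload (bytes, filename) pairs to a Gradio Space and build FileData dicts."""
    data = FormData()
    for file_bytes, filename in media:
        data.add_field("files", file_bytes, filename=filename)
    async with session.post(f"{api_url}/gradio_api/upload",
                            params={"upload_id": session_hash}, data=data) as response:
        server_paths = await response.json()  # Gradio answers with server-side paths
    return [{
        "path": path,
        "url": f"{api_url}/gradio_api/file={path}",
        "orig_name": media[i][1],
        "size": len(media[i][0]),
        "meta": {"_type": "gradio.FileData"},
    } for i, path in enumerate(server_paths)]
```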
@@ -3,13 +3,13 @@ from __future__ import annotations
 import json
 import uuid
 
-from ...typing import AsyncResult, Messages, Cookies, ImagesType
+from ...typing import AsyncResult, Messages, Cookies, MediaListType
 from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin
 from ..helper import format_prompt, format_image_prompt
 from ...providers.response import JsonConversation
 from ...requests.aiohttp import StreamSession, StreamResponse, FormData
 from ...requests.raise_for_status import raise_for_status
-from ...image import to_bytes, is_accepted_format, is_data_an_wav
+from ...image import to_bytes, is_accepted_format, is_data_an_audio
 from ...errors import ResponseError
 from ... import debug
 from .DeepseekAI_JanusPro7b import get_zerogpu_token
@@ -32,7 +32,7 @@ class Microsoft_Phi_4(AsyncGeneratorProvider, ProviderModelMixin):
     models = [default_model]
 
     @classmethod
-    def run(cls, method: str, session: StreamSession, prompt: str, conversation: JsonConversation, images: list = None):
+    def run(cls, method: str, session: StreamSession, prompt: str, conversation: JsonConversation, media: list = None):
         headers = {
             "content-type": "application/json",
             "x-zerogpu-token": conversation.zerogpu_token,
@@ -47,7 +47,7 @@ class Microsoft_Phi_4(AsyncGeneratorProvider, ProviderModelMixin):
                 [],
                 {
                     "text": prompt,
-                    "files": images,
+                    "files": media,
                 },
                 None
             ],
@@ -70,7 +70,7 @@ class Microsoft_Phi_4(AsyncGeneratorProvider, ProviderModelMixin):
                 {
                     "role": "user",
                     "content": {"file": image}
-                } for image in images
+                } for image in media
                 ]],
                 "event_data": None,
                 "fn_index": 11,
@@ -91,7 +91,7 @@ class Microsoft_Phi_4(AsyncGeneratorProvider, ProviderModelMixin):
         cls,
         model: str,
         messages: Messages,
-        images: ImagesType = None,
+        media: MediaListType = None,
         prompt: str = None,
         proxy: str = None,
         cookies: Cookies = None,
@@ -115,23 +115,23 @@ class Microsoft_Phi_4(AsyncGeneratorProvider, ProviderModelMixin):
         if return_conversation:
             yield conversation
 
-        if images is not None:
+        if media is not None:
             data = FormData()
-            mime_types = [None for i in range(len(images))]
-            for i in range(len(images)):
-                mime_types[i] = is_data_an_wav(images[i][0], images[i][1])
-                images[i] = (to_bytes(images[i][0]), images[i][1])
-                mime_types[i] = is_accepted_format(images[i][0]) if mime_types[i] is None else mime_types[i]
-            for image, image_name in images:
+            mime_types = [None for i in range(len(media))]
+            for i in range(len(media)):
+                mime_types[i] = is_data_an_audio(media[i][0], media[i][1])
+                media[i] = (to_bytes(media[i][0]), media[i][1])
+                mime_types[i] = is_accepted_format(media[i][0]) if mime_types[i] is None else mime_types[i]
+            for image, image_name in media:
                 data.add_field(f"files", to_bytes(image), filename=image_name)
             async with session.post(f"{cls.api_url}/gradio_api/upload", params={"upload_id": session_hash}, data=data) as response:
                 await raise_for_status(response)
                 image_files = await response.json()
-            images = [{
+            media = [{
                 "path": image_file,
                 "url": f"{cls.api_url}/gradio_api/file={image_file}",
-                "orig_name": images[i][1],
-                "size": len(images[i][0]),
+                "orig_name": media[i][1],
+                "size": len(media[i][0]),
                 "mime_type": mime_types[i],
                 "meta": {
                     "_type": "gradio.FileData"
@@ -139,10 +139,10 @@ class Microsoft_Phi_4(AsyncGeneratorProvider, ProviderModelMixin):
             } for i, image_file in enumerate(image_files)]
 
 
-        async with cls.run("predict", session, prompt, conversation, images) as response:
+        async with cls.run("predict", session, prompt, conversation, media) as response:
             await raise_for_status(response)
 
-        async with cls.run("post", session, prompt, conversation, images) as response:
+        async with cls.run("post", session, prompt, conversation, media) as response:
             await raise_for_status(response)
 
         async with cls.run("get", session, prompt, conversation) as response:
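
The `is_data_an_wav` to `is_data_an_audio` rename generalizes the probe beyond WAV. Its implementation is not part of this diff; judging from the call sites above, it returns an audio MIME type or `None`, roughly along these lines (a hypothetical sketch; the real helper lives in `g4f/image.py`):

```python
def is_data_an_audio(data, filename: str = None):
    """Return an audio MIME type if the input looks like audio, else None."""
    if filename:
        if filename.endswith(".wav"):
            return "audio/wav"
        if filename.endswith(".mp3"):
            return "audio/mpeg"
    if isinstance(data, bytes) and data[:4] == b"RIFF" and data[8:12] == b"WAVE":
        return "audio/wav"  # WAV files start with a RIFF/WAVE header
    return None
```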
@@ -3,7 +3,7 @@ from __future__ import annotations
 import json
 from aiohttp import ClientSession, FormData
 
-from ...typing import AsyncResult, Messages, ImagesType
+from ...typing import AsyncResult, Messages, MediaListType
 from ...requests import raise_for_status
 from ...errors import ResponseError
 from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin
@@ -25,7 +25,7 @@ class Qwen_QVQ_72B(AsyncGeneratorProvider, ProviderModelMixin):
     @classmethod
     async def create_async_generator(
         cls, model: str, messages: Messages,
-        images: ImagesType = None,
+        media: MediaListType = None,
         api_key: str = None,
         proxy: str = None,
         **kwargs
@@ -36,10 +36,10 @@ class Qwen_QVQ_72B(AsyncGeneratorProvider, ProviderModelMixin):
         if api_key is not None:
             headers["Authorization"] = f"Bearer {api_key}"
         async with ClientSession(headers=headers) as session:
-            if images:
+            if media:
                 data = FormData()
-                data_bytes = to_bytes(images[0][0])
-                data.add_field("files", data_bytes, content_type=is_accepted_format(data_bytes), filename=images[0][1])
+                data_bytes = to_bytes(media[0][0])
+                data.add_field("files", data_bytes, content_type=is_accepted_format(data_bytes), filename=media[0][1])
                 url = f"{cls.url}/gradio_api/upload?upload_id={get_random_string()}"
                 async with session.post(url, data=data, proxy=proxy) as response:
                     await raise_for_status(response)
@@ -2,7 +2,7 @@ from __future__ import annotations
 
 import random
 
-from ...typing import AsyncResult, Messages, ImagesType
+from ...typing import AsyncResult, Messages, MediaListType
 from ...errors import ResponseError
 from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin
 
@@ -67,15 +67,15 @@ class HuggingSpace(AsyncGeneratorProvider, ProviderModelMixin):
 
     @classmethod
     async def create_async_generator(
-        cls, model: str, messages: Messages, images: ImagesType = None, **kwargs
+        cls, model: str, messages: Messages, media: MediaListType = None, **kwargs
     ) -> AsyncResult:
-        if not model and images is not None:
+        if not model and media is not None:
             model = cls.default_vision_model
         is_started = False
         random.shuffle(cls.providers)
         for provider in cls.providers:
             if model in provider.model_aliases:
-                async for chunk in provider.create_async_generator(provider.model_aliases[model], messages, images=images, **kwargs):
+                async for chunk in provider.create_async_generator(provider.model_aliases[model], messages, media=media, **kwargs):
                     is_started = True
                     yield chunk
                 if is_started:
@@ -6,7 +6,7 @@ import base64
 from typing import Optional
 
 from ..helper import filter_none
-from ...typing import AsyncResult, Messages, ImagesType
+from ...typing import AsyncResult, Messages, MediaListType
 from ...requests import StreamSession, raise_for_status
 from ...providers.response import FinishReason, ToolCalls, Usage
 from ...errors import MissingAuthError
@@ -62,7 +62,7 @@ class Anthropic(OpenaiAPI):
         messages: Messages,
         proxy: str = None,
         timeout: int = 120,
-        images: ImagesType = None,
+        media: MediaListType = None,
         api_key: str = None,
         temperature: float = None,
         max_tokens: int = 4096,
@@ -79,9 +79,9 @@ class Anthropic(OpenaiAPI):
         if api_key is None:
             raise MissingAuthError('Add a "api_key"')
 
-        if images is not None:
+        if media is not None:
             insert_images = []
-            for image, _ in images:
+            for image, _ in media:
                 data = to_bytes(image)
                 insert_images.append({
                     "type": "image",
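
For context, the `insert_images` entries the Anthropic loop builds follow Anthropic's documented base64 image content-block format, roughly:

```python
import base64

def image_block(data: bytes, media_type: str = "image/jpeg") -> dict:
    # Anthropic messages API image block (sketch; media_type should describe
    # the actual bytes, e.g. via is_accepted_format(data))
    return {
        "type": "image",
        "source": {
            "type": "base64",
            "media_type": media_type,
            "data": base64.b64encode(data).decode(),
        },
    }
```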
@@ -19,7 +19,7 @@ except ImportError:
     has_nodriver = False
 
 from ... import debug
-from ...typing import Messages, Cookies, ImagesType, AsyncResult, AsyncIterator
+from ...typing import Messages, Cookies, MediaListType, AsyncResult, AsyncIterator
 from ...providers.response import JsonConversation, Reasoning, RequestLogin, ImageResponse, YouTube
 from ...requests.raise_for_status import raise_for_status
 from ...requests.aiohttp import get_connector
@@ -149,7 +149,7 @@ class Gemini(AsyncGeneratorProvider, ProviderModelMixin):
         proxy: str = None,
         cookies: Cookies = None,
         connector: BaseConnector = None,
-        images: ImagesType = None,
+        media: MediaListType = None,
         return_conversation: bool = False,
         conversation: Conversation = None,
         language: str = "en",
@@ -186,7 +186,7 @@ class Gemini(AsyncGeneratorProvider, ProviderModelMixin):
                 cls.start_auto_refresh()
             )
 
-        images = await cls.upload_images(base_connector, images) if images else None
+        uploads = None if media is None else await cls.upload_images(base_connector, media)
         async with ClientSession(
             cookies=cls._cookies,
             headers=REQUEST_HEADERS,
@@ -205,7 +205,7 @@ class Gemini(AsyncGeneratorProvider, ProviderModelMixin):
                 prompt,
                 language=language,
                 conversation=conversation,
-                images=images
+                uploads=uploads
             ))])
         }
         async with client.post(
@@ -327,10 +327,10 @@ class Gemini(AsyncGeneratorProvider, ProviderModelMixin):
         prompt: str,
         language: str,
         conversation: Conversation = None,
-        images: list[list[str, str]] = None,
+        uploads: list[list[str, str]] = None,
         tools: list[list[str]] = []
     ) -> list:
-        image_list = [[[image_url, 1], image_name] for image_url, image_name in images] if images else []
+        image_list = [[[image_url, 1], image_name] for image_url, image_name in uploads] if uploads else []
         return [
             [prompt, 0, None, image_list, None, None, 0],
             [language],
@@ -353,7 +353,7 @@ class Gemini(AsyncGeneratorProvider, ProviderModelMixin):
             0,
         ]
 
-    async def upload_images(connector: BaseConnector, images: ImagesType) -> list:
+    async def upload_images(connector: BaseConnector, media: MediaListType) -> list:
         async def upload_image(image: bytes, image_name: str = None):
             async with ClientSession(
                 headers=UPLOAD_IMAGE_HEADERS,
@@ -385,7 +385,7 @@ class Gemini(AsyncGeneratorProvider, ProviderModelMixin):
                 ) as response:
                     await raise_for_status(response)
                     return [await response.text(), image_name]
-        return await asyncio.gather(*[upload_image(image, image_name) for image, image_name in images])
+        return await asyncio.gather(*[upload_image(image, image_name) for image, image_name in media])
 
     @classmethod
     async def fetch_snlm0e(cls, session: ClientSession, cookies: Cookies):
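
The Gemini change is slightly more than a rename: instead of rebinding `images` mid-function, the incoming `media` list and the uploaded result now live in separate variables. Condensed (a sketch; `upload_images` stands in for the bound classmethod):

```python
from typing import List, Optional, Tuple

async def prepare_uploads(media: Optional[List[Tuple[bytes, str]]], upload_images) -> list:
    # media holds (data, filename) pairs; uploads holds [url, name] pairs
    uploads = None if media is None else await upload_images(media)
    # build_request() consumes uploads in Gemini's nested list wire format
    return [[[image_url, 1], image_name] for image_url, image_name in uploads] if uploads else []
```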
@@ -6,8 +6,8 @@ import requests
 from typing import Optional
 from aiohttp import ClientSession, BaseConnector
 
-from ...typing import AsyncResult, Messages, ImagesType
-from ...image import to_bytes, is_accepted_format
+from ...typing import AsyncResult, Messages, MediaListType
+from ...image import to_bytes, is_data_an_media
 from ...errors import MissingAuthError
 from ...requests.raise_for_status import raise_for_status
 from ...providers.response import Usage, FinishReason
@@ -67,7 +67,7 @@ class GeminiPro(AsyncGeneratorProvider, ProviderModelMixin):
         api_key: str = None,
         api_base: str = api_base,
         use_auth_header: bool = False,
-        images: ImagesType = None,
+        media: MediaListType = None,
         tools: Optional[list] = None,
         connector: BaseConnector = None,
         **kwargs
@@ -94,13 +94,13 @@ class GeminiPro(AsyncGeneratorProvider, ProviderModelMixin):
             for message in messages
             if message["role"] != "system"
         ]
-        if images is not None:
-            for image, _ in images:
+        if media is not None:
+            for media_data, filename in media:
                 image = to_bytes(image)
                 contents[-1]["parts"].append({
                     "inline_data": {
-                        "mime_type": is_accepted_format(image),
-                        "data": base64.b64encode(image).decode()
+                        "mime_type": is_data_an_media(image, filename),
+                        "data": base64.b64encode(media_data).decode()
                     }
                 })
         data = {
@@ -1,51 +1,56 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import json
|
import json
|
||||||
import random
|
|
||||||
import re
|
|
||||||
import base64
|
|
||||||
import asyncio
|
|
||||||
import time
|
import time
|
||||||
from urllib.parse import quote_plus, unquote_plus
|
from typing import Dict, Any, AsyncIterator
|
||||||
from pathlib import Path
|
|
||||||
from aiohttp import ClientSession, BaseConnector
|
|
||||||
from typing import Dict, Any, Optional, AsyncIterator, List
|
|
||||||
|
|
||||||
from ... import debug
|
from ...typing import Messages, Cookies, AsyncResult
|
||||||
from ...typing import Messages, Cookies, ImagesType, AsyncResult
|
from ...providers.response import JsonConversation, Reasoning, ImagePreview, ImageResponse, TitleGeneration, AuthResult, RequestLogin
|
||||||
from ...providers.response import JsonConversation, Reasoning, ImagePreview, ImageResponse, TitleGeneration
|
from ...requests import StreamSession, get_args_from_nodriver, DEFAULT_HEADERS
|
||||||
from ...requests.raise_for_status import raise_for_status
|
from ...requests.raise_for_status import raise_for_status
|
||||||
from ...requests.aiohttp import get_connector
|
from ..base_provider import AsyncAuthedProvider, ProviderModelMixin
|
||||||
from ...requests import get_nodriver
|
|
||||||
from ...errors import MissingAuthError
|
|
||||||
from ...cookies import get_cookies_dir
|
|
||||||
from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin
|
|
||||||
from ..helper import format_prompt, get_cookies, get_last_user_message
|
from ..helper import format_prompt, get_cookies, get_last_user_message
|
||||||
|
|
||||||
class Conversation(JsonConversation):
|
class Conversation(JsonConversation):
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
conversation_id: str,
|
conversation_id: str
|
||||||
response_id: str,
|
|
||||||
choice_id: str,
|
|
||||||
model: str
|
|
||||||
) -> None:
|
) -> None:
|
||||||
self.conversation_id = conversation_id
|
self.conversation_id = conversation_id
|
||||||
self.response_id = response_id
|
|
||||||
self.choice_id = choice_id
|
|
||||||
self.model = model
|
|
||||||
|
|
||||||
class Grok(AsyncGeneratorProvider, ProviderModelMixin):
|
class Grok(AsyncAuthedProvider, ProviderModelMixin):
|
||||||
label = "Grok AI"
|
label = "Grok AI"
|
||||||
url = "https://grok.com"
|
url = "https://grok.com"
|
||||||
|
cookie_domain = ".grok.com"
|
||||||
assets_url = "https://assets.grok.com"
|
assets_url = "https://assets.grok.com"
|
||||||
conversation_url = "https://grok.com/rest/app-chat/conversations"
|
conversation_url = "https://grok.com/rest/app-chat/conversations"
|
||||||
|
|
||||||
needs_auth = True
|
needs_auth = True
|
||||||
working = False
|
working = True
|
||||||
|
|
||||||
default_model = "grok-3"
|
default_model = "grok-3"
|
||||||
models = [default_model, "grok-3-thinking", "grok-2"]
|
models = [default_model, "grok-3-thinking", "grok-2"]
|
||||||
|
|
||||||
_cookies: Cookies = None
|
@classmethod
|
||||||
|
async def on_auth_async(cls, cookies: Cookies = None, proxy: str = None, **kwargs) -> AsyncIterator:
|
||||||
|
if cookies is None:
|
||||||
|
cookies = get_cookies(cls.cookie_domain, False, True, False)
|
||||||
|
if cookies is not None and "sso" in cookies:
|
||||||
|
yield AuthResult(
|
||||||
|
cookies=cookies,
|
||||||
|
impersonate="chrome",
|
||||||
|
proxy=proxy,
|
||||||
|
headers=DEFAULT_HEADERS
|
||||||
|
)
|
||||||
|
return
|
||||||
|
yield RequestLogin(cls.__name__, os.environ.get("G4F_LOGIN_URL") or "")
|
||||||
|
yield AuthResult(
|
||||||
|
**await get_args_from_nodriver(
|
||||||
|
cls.url,
|
||||||
|
proxy=proxy,
|
||||||
|
wait_for='[href="/chat#private"]'
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def _prepare_payload(cls, model: str, message: str) -> Dict[str, Any]:
|
```diff
     async def _prepare_payload(cls, model: str, message: str) -> Dict[str, Any]:
@@ -72,63 +77,43 @@ class Grok(AsyncGeneratorProvider, ProviderModelMixin):
         }

     @classmethod
-    async def create_async_generator(
+    async def create_authed(
         cls,
         model: str,
         messages: Messages,
-        proxy: str = None,
+        auth_result: AuthResult,
         cookies: Cookies = None,
-        connector: BaseConnector = None,
-        images: ImagesType = None,
         return_conversation: bool = False,
-        conversation: Optional[Conversation] = None,
+        conversation: Conversation = None,
         **kwargs
     ) -> AsyncResult:
-        cls._cookies = cookies or cls._cookies or get_cookies(".grok.com", False, True)
-        if not cls._cookies:
-            raise MissingAuthError("Missing required cookies")
-        prompt = format_prompt(messages) if conversation is None else get_last_user_message(messages)
-        base_connector = get_connector(connector, proxy)
-
-        headers = {
-            "accept": "*/*",
-            "accept-language": "en-GB,en;q=0.9",
-            "content-type": "application/json",
-            "origin": "https://grok.com",
-            "priority": "u=1, i",
-            "referer": "https://grok.com/",
-            "sec-ch-ua": '"Not/A)Brand";v="8", "Chromium";v="126", "Brave";v="126"',
-            "sec-ch-ua-mobile": "?0",
-            "sec-ch-ua-platform": '"macOS"',
-            "sec-fetch-dest": "empty",
-            "sec-fetch-mode": "cors",
-            "sec-fetch-site": "same-origin",
-            "sec-gpc": "1",
-            "user-agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/126.0.0.0 Safari/537.36"
-        }
-
-        async with ClientSession(
-            headers=headers,
-            cookies=cls._cookies,
-            connector=base_connector
+        conversation_id = None if conversation is None else conversation.conversation_id
+        prompt = format_prompt(messages) if conversation_id is None else get_last_user_message(messages)
+        async with StreamSession(
+            **auth_result.get_dict()
         ) as session:
             payload = await cls._prepare_payload(model, prompt)
-            response = await session.post(f"{cls.conversation_url}/new", json=payload)
-            await raise_for_status(response)
+            if conversation_id is None:
+                url = f"{cls.conversation_url}/new"
+            else:
+                url = f"{cls.conversation_url}/{conversation_id}/responses"
+            async with session.post(url, json=payload) as response:
+                await raise_for_status(response)

                 thinking_duration = None
-                async for line in response.content:
+                async for line in response.iter_lines():
                     if line:
                         try:
                             json_data = json.loads(line)
                             result = json_data.get("result", {})
+                            if conversation_id is None:
+                                conversation_id = result.get("conversation", {}).get("conversationId")
                             response_data = result.get("response", {})
                             image = response_data.get("streamingImageGenerationResponse", None)
                             if image is not None:
                                 yield ImagePreview(f'{cls.assets_url}/{image["imageUrl"]}', "", {"cookies": cookies, "headers": headers})
-                            token = response_data.get("token", "")
-                            is_thinking = response_data.get("isThinking", False)
+                            token = response_data.get("token", result.get("token"))
+                            is_thinking = response_data.get("isThinking", result.get("isThinking"))
                             if token:
                                 if is_thinking:
                                     if thinking_duration is None:
@@ -145,9 +130,12 @@ class Grok(AsyncGeneratorProvider, ProviderModelMixin):
                             generated_images = response_data.get("modelResponse", {}).get("generatedImageUrls", None)
                             if generated_images:
                                 yield ImageResponse([f'{cls.assets_url}/{image}' for image in generated_images], "", {"cookies": cookies, "headers": headers})
-                            title = response_data.get("title", {}).get("newTitle", "")
+                            title = result.get("title", {}).get("newTitle", "")
                             if title:
                                 yield TitleGeneration(title)

                         except json.JSONDecodeError:
                             continue
+            if return_conversation and conversation_id is not None:
+                yield Conversation(conversation_id)
```
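The hunk above switches Grok to the authenticated flow and adds conversation reuse: the first request is posted to `{conversation_url}/new`, follow-ups to `{conversation_url}/{conversation_id}/responses`, and the id is surfaced via a `Conversation` object when `return_conversation` is set. A minimal sketch of that endpoint selection (the base URL is illustrative, not a value taken from this diff):

```python
from typing import Optional

# Sketch of the endpoint selection implemented above; the base URL is
# an assumption for illustration only.
def pick_endpoint(conversation_url: str, conversation_id: Optional[str]) -> str:
    if conversation_id is None:
        # First message: create a new conversation
        return f"{conversation_url}/new"
    # Follow-up messages reuse the existing conversation
    return f"{conversation_url}/{conversation_id}/responses"

assert pick_endpoint("https://example.invalid/conversations", None).endswith("/new")
assert pick_endpoint("https://example.invalid/conversations", "abc").endswith("/abc/responses")
```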
```diff
@@ -18,7 +18,7 @@ except ImportError:
     has_nodriver = False

 from ..base_provider import AsyncAuthedProvider, ProviderModelMixin
-from ...typing import AsyncResult, Messages, Cookies, ImagesType
+from ...typing import AsyncResult, Messages, Cookies, MediaListType
 from ...requests.raise_for_status import raise_for_status
 from ...requests import StreamSession
 from ...requests import get_nodriver
@@ -127,7 +127,7 @@ class OpenaiChat(AsyncAuthedProvider, ProviderModelMixin):
         cls,
         session: StreamSession,
         auth_result: AuthResult,
-        images: ImagesType,
+        media: MediaListType,
     ) -> ImageRequest:
         """
         Upload an image to the service and get the download URL
@@ -135,7 +135,7 @@ class OpenaiChat(AsyncAuthedProvider, ProviderModelMixin):
         Args:
             session: The StreamSession object to use for requests
             headers: The headers to include in the requests
-            images: The images to upload, either a PIL Image object or a bytes object
+            media: The images to upload, either a PIL Image object or a bytes object

         Returns:
             An ImageRequest object that contains the download URL, file name, and other data
@@ -187,9 +187,9 @@ class OpenaiChat(AsyncAuthedProvider, ProviderModelMixin):
             await raise_for_status(response, "Get download url failed")
             image_data["download_url"] = (await response.json())["download_url"]
             return ImageRequest(image_data)
-        if not images:
+        if not media:
             return
-        return [await upload_image(image, image_name) for image, image_name in images]
+        return [await upload_image(image, image_name) for image, image_name in media]

     @classmethod
     def create_messages(cls, messages: Messages, image_requests: ImageRequest = None, system_hints: list = None):
@@ -268,7 +268,7 @@ class OpenaiChat(AsyncAuthedProvider, ProviderModelMixin):
         auto_continue: bool = False,
         action: str = "next",
         conversation: Conversation = None,
-        images: ImagesType = None,
+        media: MediaListType = None,
         return_conversation: bool = False,
         web_search: bool = False,
         **kwargs
@@ -285,7 +285,7 @@ class OpenaiChat(AsyncAuthedProvider, ProviderModelMixin):
             auto_continue (bool): Flag to automatically continue the conversation.
             action (str): Type of action ('next', 'continue', 'variant').
             conversation_id (str): ID of the conversation.
-            images (ImagesType): Images to include in the conversation.
+            media (MediaListType): Images to include in the conversation.
             return_conversation (bool): Flag to include response fields in the output.
             **kwargs: Additional keyword arguments.

@@ -316,7 +316,7 @@ class OpenaiChat(AsyncAuthedProvider, ProviderModelMixin):
                 cls._update_request_args(auth_result, session)
                 await raise_for_status(response)
             try:
-                image_requests = await cls.upload_images(session, auth_result, images) if images else None
+                image_requests = None if media is None else await cls.upload_images(session, auth_result, media)
             except Exception as e:
                 debug.error("OpenaiChat: Upload image failed")
                 debug.error(e)
```
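With `upload_images` now taking `media: MediaListType`, callers pass a list of `(data, filename)` tuples; the same shape is accepted through the public client. A hedged usage sketch (the file name is illustrative):

```python
import asyncio
from g4f.client import AsyncClient
import g4f.Provider
import g4f.models

async def main():
    client = AsyncClient(provider=g4f.Provider.OpenaiChat)
    # media replaces the old images parameter: a list of (data, filename) tuples.
    with open("photo.jpg", "rb") as image_file:  # illustrative local file
        response = await client.chat.completions.create(
            model=g4f.models.default,
            messages=[{"role": "user", "content": "Describe this image"}],
            media=[(image_file, "photo.jpg")],
        )
    print(response.choices[0].message.content)

asyncio.run(main())
```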
```diff
@@ -2,7 +2,7 @@ from __future__ import annotations

 import json

-from ...typing import Messages, AsyncResult, ImagesType
+from ...typing import Messages, AsyncResult, MediaListType
 from ...requests import StreamSession
 from ...image import to_data_uri
 from ...providers.base_provider import AsyncGeneratorProvider, ProviderModelMixin
@@ -18,21 +18,21 @@ class BackendApi(AsyncGeneratorProvider, ProviderModelMixin):
         cls,
         model: str,
         messages: Messages,
-        images: ImagesType = None,
+        media: MediaListType = None,
         api_key: str = None,
         **kwargs
     ) -> AsyncResult:
         debug.log(f"{cls.__name__}: {api_key}")
-        if images is not None:
-            for i in range(len(images)):
-                images[i] = (to_data_uri(images[i][0]), images[i][1])
+        if media is not None:
+            for i in range(len(media)):
+                media[i] = (to_data_uri(media[i][0]), media[i][1])
         async with StreamSession(
             headers={"Accept": "text/event-stream", **cls.headers},
         ) as session:
             async with session.post(f"{cls.url}/backend-api/v2/conversation", json={
                 "model": model,
                 "messages": messages,
-                "images": images,
+                "media": media,
                 "api_key": api_key,
                 **kwargs
             }, ssl=cls.ssl) as response:
```
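`BackendApi` now flattens each media entry with `to_data_uri(...)` before posting, so binary audio or image data travels as a data URI. A small sketch of that conversion, relying on the filename-aware `to_data_uri` introduced later in this diff (the .wav file is illustrative):

```python
from g4f.image import to_data_uri

# Illustrative local file; the filename lets the helper pick the audio MIME type.
with open("audio.wav", "rb") as f:
    uri = to_data_uri(f.read(), "audio.wav")
assert uri.startswith("data:audio/wav;base64,")
```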
```diff
@@ -5,11 +5,11 @@ import requests

 from ..helper import filter_none, format_image_prompt
 from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin, RaiseErrorMixin
-from ...typing import Union, Optional, AsyncResult, Messages, ImagesType
+from ...typing import Union, AsyncResult, Messages, MediaListType
 from ...requests import StreamSession, raise_for_status
 from ...providers.response import FinishReason, ToolCalls, Usage, ImageResponse
 from ...errors import MissingAuthError, ResponseError
-from ...image import to_data_uri
+from ...image import to_data_uri, is_data_an_audio, to_input_audio
 from ... import debug

 class OpenaiTemplate(AsyncGeneratorProvider, ProviderModelMixin, RaiseErrorMixin):
@@ -54,7 +54,7 @@ class OpenaiTemplate(AsyncGeneratorProvider, ProviderModelMixin, RaiseErrorMixin
         messages: Messages,
         proxy: str = None,
         timeout: int = 120,
-        images: ImagesType = None,
+        media: MediaListType = None,
         api_key: str = None,
         api_endpoint: str = None,
         api_base: str = None,
@@ -66,7 +66,7 @@ class OpenaiTemplate(AsyncGeneratorProvider, ProviderModelMixin, RaiseErrorMixin
         prompt: str = None,
         headers: dict = None,
         impersonate: str = None,
-        extra_parameters: list[str] = ["tools", "parallel_tool_calls", "tool_choice", "reasoning_effort", "logit_bias"],
+        extra_parameters: list[str] = ["tools", "parallel_tool_calls", "tool_choice", "reasoning_effort", "logit_bias", "modalities", "audio"],
         extra_data: dict = {},
         **kwargs
     ) -> AsyncResult:
@@ -98,19 +98,26 @@ class OpenaiTemplate(AsyncGeneratorProvider, ProviderModelMixin, RaiseErrorMixin
                 yield ImageResponse([image["url"] for image in data["data"]], prompt)
                 return

-        if images is not None and messages:
+        if media is not None and messages:
             if not model and hasattr(cls, "default_vision_model"):
                 model = cls.default_vision_model
             last_message = messages[-1].copy()
             last_message["content"] = [
-                *[{
-                    "type": "image_url",
-                    "image_url": {"url": to_data_uri(image)}
-                } for image, _ in images],
+                *[
+                    {
+                        "type": "input_audio",
+                        "input_audio": to_input_audio(media_data, filename)
+                    }
+                    if is_data_an_audio(media_data, filename) else {
+                        "type": "image_url",
+                        "image_url": {"url": to_data_uri(media_data, filename)}
+                    }
+                    for media_data, filename in media
+                ],
                 {
                     "type": "text",
-                    "text": messages[-1]["content"]
-                }
+                    "text": last_message["content"]
+                } if isinstance(last_message["content"], str) else last_message["content"]
             ]
             messages[-1] = last_message
         extra_parameters = {key: kwargs[key] for key in extra_parameters if key in kwargs}
```
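The rewritten comprehension emits an `input_audio` part for audio attachments and an `image_url` part otherwise, then appends the original text. A sketch of the resulting last-message shape (all values are placeholders):

```python
# Placeholder values; the part shapes follow the comprehension above.
last_message = {
    "role": "user",
    "content": [
        {"type": "input_audio", "input_audio": {"data": "<base64>", "format": "wav"}},
        {"type": "image_url", "image_url": {"url": "data:image/jpeg;base64,<base64>"}},
        {"type": "text", "text": "Transcribe the audio and describe the image"},
    ],
}
```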
```diff
@@ -34,12 +34,14 @@ class ChatCompletion:
                ignore_stream: bool = False,
                **kwargs) -> Union[CreateResult, str]:
         if image is not None:
-            kwargs["images"] = [(image, image_name)]
+            kwargs["media"] = [(image, image_name)]
+        elif "images" in kwargs:
+            kwargs["media"] = kwargs.pop("images")
         model, provider = get_model_and_provider(
             model, provider, stream,
             ignore_working,
             ignore_stream,
-            has_images="images" in kwargs,
+            has_images="media" in kwargs,
         )
         if "proxy" not in kwargs:
             proxy = os.environ.get("G4F_PROXY")
@@ -63,8 +65,10 @@ class ChatCompletion:
                ignore_working: bool = False,
                **kwargs) -> Union[AsyncResult, Coroutine[str]]:
         if image is not None:
-            kwargs["images"] = [(image, image_name)]
-        model, provider = get_model_and_provider(model, provider, False, ignore_working, has_images="images" in kwargs)
+            kwargs["media"] = [(image, image_name)]
+        elif "images" in kwargs:
+            kwargs["media"] = kwargs.pop("images")
+        model, provider = get_model_and_provider(model, provider, False, ignore_working, has_images="media" in kwargs)
         if "proxy" not in kwargs:
             proxy = os.environ.get("G4F_PROXY")
             if proxy:
```
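Both entry points now normalize legacy keyword arguments, so existing `images=` callers keep working. The rewrite step in isolation:

```python
# Isolated sketch of the normalization above: image/image_name and the
# legacy images= kwarg both end up as media=.
def normalize_kwargs(image=None, image_name=None, **kwargs):
    if image is not None:
        kwargs["media"] = [(image, image_name)]
    elif "images" in kwargs:
        kwargs["media"] = kwargs.pop("images")
    return kwargs

assert "media" in normalize_kwargs(images=[("data:image/png;base64,...", "a.png")])
assert normalize_kwargs(image=b"bytes", image_name="b.png")["media"] == [(b"bytes", "b.png")]
```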
```diff
@@ -38,7 +38,7 @@ import g4f.debug
 from g4f.client import AsyncClient, ChatCompletion, ImagesResponse, convert_to_provider
 from g4f.providers.response import BaseConversation, JsonConversation
 from g4f.client.helper import filter_none
-from g4f.image import is_data_uri_an_media
+from g4f.image import is_data_an_media
 from g4f.image.copy_images import images_dir, copy_images, get_source_url
 from g4f.errors import ProviderNotFoundError, ModelNotFoundError, MissingAuthError, NoValidHarFileError
 from g4f.cookies import read_cookie_files, get_cookies_dir
@@ -320,16 +320,18 @@ class Api:

         if config.image is not None:
             try:
-                is_data_uri_an_media(config.image)
+                is_data_an_media(config.image)
             except ValueError as e:
                 return ErrorResponse.from_message(f"The image you send must be a data URI. Example: data:image/jpeg;base64,...", status_code=HTTP_422_UNPROCESSABLE_ENTITY)
-        if config.images is not None:
-            for image in config.images:
+        if config.media is None:
+            config.media = config.images
+        if config.media is not None:
+            for image in config.media:
                 try:
-                    is_data_uri_an_media(image[0])
+                    is_data_an_media(image[0])
                 except ValueError as e:
-                    example = json.dumps({"images": [["data:image/jpeg;base64,...", "filename"]]})
-                    return ErrorResponse.from_message(f'The image you send must be a data URI. Example: {example}', status_code=HTTP_422_UNPROCESSABLE_ENTITY)
+                    example = json.dumps({"media": [["data:image/jpeg;base64,...", "filename.jpg"]]})
+                    return ErrorResponse.from_message(f'The media you send must be a data URIs. Example: {example}', status_code=HTTP_422_UNPROCESSABLE_ENTITY)

         # Create the completion response
         response = self.client.chat.completions.create(
```
```diff
@@ -17,6 +17,7 @@ class ChatCompletionsConfig(BaseModel):
     image: Optional[str] = None
     image_name: Optional[str] = None
     images: Optional[list[tuple[str, str]]] = None
+    media: Optional[list[tuple[str, str]]] = None
     temperature: Optional[float] = None
     presence_penalty: Optional[float] = None
     frequency_penalty: Optional[float] = None
```
```diff
@@ -293,14 +293,16 @@ class Completions:
         if isinstance(messages, str):
             messages = [{"role": "user", "content": messages}]
         if image is not None:
-            kwargs["images"] = [(image, image_name)]
+            kwargs["media"] = [(image, image_name)]
+        elif "images" in kwargs:
+            kwargs["media"] = kwargs.pop("images")
         model, provider = get_model_and_provider(
             model,
             self.provider if provider is None else provider,
             stream,
             ignore_working,
             ignore_stream,
-            has_images="images" in kwargs
+            has_images="media" in kwargs
         )
         stop = [stop] if isinstance(stop, str) else stop
         if ignore_stream:
@@ -481,7 +483,7 @@ class Images:
         proxy = self.client.proxy
         prompt = "create a variation of this image"
         if image is not None:
-            kwargs["images"] = [(image, None)]
+            kwargs["media"] = [(image, None)]

         error = None
         response = None
@@ -581,14 +583,16 @@ class AsyncCompletions:
         if isinstance(messages, str):
             messages = [{"role": "user", "content": messages}]
         if image is not None:
-            kwargs["images"] = [(image, image_name)]
+            kwargs["media"] = [(image, image_name)]
+        elif "images" in kwargs:
+            kwargs["media"] = kwargs.pop("images")
         model, provider = get_model_and_provider(
             model,
             self.provider if provider is None else provider,
             stream,
             ignore_working,
             ignore_stream,
-            has_images="images" in kwargs,
+            has_images="media" in kwargs,
         )
         stop = [stop] if isinstance(stop, str) else stop
         if ignore_stream:
```
```diff
@@ -67,7 +67,7 @@ DOMAINS = [
 if has_browser_cookie3 and os.environ.get('DBUS_SESSION_BUS_ADDRESS') == "/dev/null":
     _LinuxPasswordManager.get_password = lambda a, b: b"secret"

-def get_cookies(domain_name: str = '', raise_requirements_error: bool = True, single_browser: bool = False) -> Dict[str, str]:
+def get_cookies(domain_name: str, raise_requirements_error: bool = True, single_browser: bool = False, cache_result: bool = True) -> Dict[str, str]:
     """
     Load cookies for a given domain from all supported browsers and cache the results.

@@ -77,10 +77,11 @@ def get_cookies(domain_name: str = '', raise_requirements_error: bool = True, si
     Returns:
         Dict[str, str]: A dictionary of cookie names and values.
     """
-    if domain_name in CookiesConfig.cookies:
+    if cache_result and domain_name in CookiesConfig.cookies:
         return CookiesConfig.cookies[domain_name]

     cookies = load_cookies_from_browsers(domain_name, raise_requirements_error, single_browser)
+    if cache_result:
         CookiesConfig.cookies[domain_name] = cookies
     return cookies

@@ -108,8 +109,8 @@ def load_cookies_from_browsers(domain_name: str, raise_requirements_error: bool
     for cookie_fn in browsers:
         try:
             cookie_jar = cookie_fn(domain_name=domain_name)
-            if len(cookie_jar) and debug.logging:
-                print(f"Read cookies from {cookie_fn.__name__} for {domain_name}")
+            if len(cookie_jar):
+                debug.log(f"Read cookies from {cookie_fn.__name__} for {domain_name}")
             for cookie in cookie_jar:
                 if cookie.name not in cookies:
                     if not cookie.expires or cookie.expires > time.time():
@@ -119,8 +120,7 @@ def load_cookies_from_browsers(domain_name: str, raise_requirements_error: bool
         except BrowserCookieError:
             pass
         except Exception as e:
-            if debug.logging:
-                print(f"Error reading cookies from {cookie_fn.__name__} for {domain_name}: {e}")
+            debug.error(f"Error reading cookies from {cookie_fn.__name__} for {domain_name}: {e}")
     return cookies

 def set_cookies_dir(dir: str) -> None:
```
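`get_cookies` gains a `cache_result` flag; passing `False` forces a fresh browser read instead of returning the memoized dictionary:

```python
# Usage sketch of the new cache_result flag. ".grok.com" appears in the
# Grok hunk above; any domain works.
from g4f.cookies import get_cookies

cached = get_cookies(".grok.com", raise_requirements_error=False)
fresh = get_cookies(".grok.com", raise_requirements_error=False, cache_result=False)
```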
```diff
@@ -86,6 +86,23 @@
            height: 100%;
            text-align: center;
            z-index: 1;
+            transition: transform 0.25s ease-in;
+        }
+
+        .container.slide {
+            transform: translateX(-100%);
+            transition: transform 0.15s ease-out;
+        }
+
+        .slide-button {
+            position: absolute;
+            top: 20px;
+            left: 20px;
+            background: var(--colour-4);
+            border: none;
+            border-radius: 4px;
+            cursor: pointer;
+            padding: 6px 8px;
        }

        header {
@@ -105,7 +122,8 @@
            height: 100%;
            position: absolute;
            z-index: -1;
-            object-fit: contain;
+            object-fit: cover;
+            object-position: center;
            width: 100%;
            background: black;
        }
@@ -181,6 +199,7 @@
            }
        }
    </style>
+    <link rel="stylesheet" href="/static/css/all.min.css">
    <script>
        (async () => {
            const isIframe = window.self !== window.top;
@@ -208,6 +227,10 @@
    <!-- Gradient Background Circle -->
    <div class="gradient"></div>

+    <button class="slide-button">
+        <i class="fa-solid fa-arrow-left"></i>
+    </button>
+
    <!-- Main Content -->
    <div class="container">
        <header>
@@ -277,8 +300,11 @@
            if (oauthResult) {
                try {
                    oauthResult = JSON.parse(oauthResult);
+                    user = await hub.whoAmI({accessToken: oauthResult.accessToken});
                } catch {
                    oauthResult = null;
+                    localStorage.removeItem("oauth");
+                    localStorage.removeItem("HuggingFace-api_key");
                }
            }
            oauthResult ||= await oauthHandleRedirectIfPresent();
@@ -339,7 +365,7 @@
                    return;
                }
                const lower = data.prompt.toLowerCase();
-                const tags = ["nsfw", "timeline", "blood", "soap", "orally", "heel", "latex", "bathroom", "boobs", "charts", " text ", "gel", "logo", "infographic", "warts", " bra ", "prostitute", "curvy", "breasts", "written", "bodies", "naked", "classroom", "malone", "dirty", "shoes", "shower", "banner", "fat", "nipples", "couple", "sexual", "sandal", "supplier", "overlord", "succubus", "platinum", "cracy", "crazy", "lamic", "ropes", "cables", "wires", "dirty", "messy", "cluttered", "chaotic", "disorganized", "disorderly", "untidy", "unorganized", "unorderly", "unsystematic", "disarranged", "disarrayed", "disheveled", "disordered", "jumbled", "muddled", "scattered", "shambolic", "sloppy", "unkept", "unruly"];
+                const tags = ["nsfw", "timeline", "feet", "blood", "soap", "orally", "heel", "latex", "bathroom", "boobs", "charts", " text ", "gel", "logo", "infographic", "warts", " bra ", "prostitute", "curvy", "breasts", "written", "bodies", "naked", "classroom", "malone", "dirty", "shoes", "shower", "banner", "fat", "nipples", "couple", "sexual", "sandal", "supplier", "overlord", "succubus", "platinum", "cracy", "crazy", "hemale", "oprah", "lamic", "ropes", "cables", "wires", "dirty", "messy", "cluttered", "chaotic", "disorganized", "disorderly", "untidy", "unorganized", "unorderly", "unsystematic", "disarranged", "disarrayed", "disheveled", "disordered", "jumbled", "muddled", "scattered", "shambolic", "sloppy", "unkept", "unruly"];
                for (i in tags) {
                    if (lower.indexOf(tags[i]) != -1) {
                        console.log("Skipping image with tag: " + tags[i]);
@@ -363,6 +389,21 @@
                        imageFeed.remove();
                    }
                }, 7000);
+
+                const container = document.querySelector('.container');
+                const button = document.querySelector('.slide-button');
+                const slideIcon = button.querySelector('i');
+                button.onclick = () => {
+                    if (container.classList.contains('slide')) {
+                        container.classList.remove('slide');
+                        slideIcon.classList.remove('fa-arrow-right');
+                        slideIcon.classList.add('fa-arrow-left');
+                    } else {
+                        container.classList.add('slide');
+                        slideIcon.classList.remove('fa-arrow-left');
+                        slideIcon.classList.add('fa-arrow-right');
+                    }
+                }
            })();
    </script>
</body>
```
```diff
@@ -141,7 +141,7 @@ class Api:
                 stream=True,
                 ignore_stream=True,
                 logging=False,
-                has_images="images" in kwargs,
+                has_images="media" in kwargs,
             )
         except Exception as e:
             debug.error(e)
```
```diff
@@ -8,6 +8,7 @@ import asyncio
 import shutil
 import random
 import datetime
+import tempfile
 from flask import Flask, Response, request, jsonify, render_template
 from typing import Generator
 from pathlib import Path
@@ -111,11 +112,13 @@ class Backend_Api(Api):
         else:
             json_data = request.json
         if "files" in request.files:
-            images = []
+            media = []
             for file in request.files.getlist('files'):
                 if file.filename != '' and is_allowed_extension(file.filename):
-                    images.append((to_image(file.stream, file.filename.endswith('.svg')), file.filename))
-            json_data['images'] = images
+                    newfile = tempfile.TemporaryFile()
+                    shutil.copyfileobj(file.stream, newfile)
+                    media.append((newfile, file.filename))
+            json_data['media'] = media

         if app.demo and not json_data.get("provider"):
             model = json_data.get("model")
```
```diff
@@ -76,24 +76,24 @@ def is_allowed_extension(filename: str) -> bool:
     return '.' in filename and \
         filename.rsplit('.', 1)[1].lower() in ALLOWED_EXTENSIONS

-def is_data_uri_an_media(data_uri: str) -> str:
-    return is_data_an_wav(data_uri) or is_data_uri_an_image(data_uri)
+def is_data_an_media(data, filename: str = None) -> str:
+    content_type = is_data_an_audio(data, filename)
+    if content_type is not None:
+        return content_type
+    if isinstance(data, bytes):
+        return is_accepted_format(data)
+    return is_data_uri_an_image(data)

-def is_data_an_wav(data_uri: str, filename: str = None) -> str:
-    """
-    Checks if the given data URI represents an image.
-
-    Args:
-        data_uri (str): The data URI to check.
-
-    Raises:
-        ValueError: If the data URI is invalid or the image format is not allowed.
-    """
-    if filename and filename.endswith(".wav"):
-        return "audio/wav"
-    # Check if the data URI starts with 'data:image' and contains an image format (e.g., jpeg, png, gif)
-    if isinstance(data_uri, str) and re.match(r'data:audio/wav;base64,', data_uri):
-        return "audio/wav"
+def is_data_an_audio(data_uri: str, filename: str = None) -> str:
+    if filename:
+        if filename.endswith(".wav"):
+            return "audio/wav"
+        elif filename.endswith(".mp3"):
+            return "audio/mpeg"
+    if isinstance(data_uri, str):
+        audio_format = re.match(r'^data:(audio/\w+);base64,', data_uri)
+        if audio_format:
+            return audio_format.group(1)

 def is_data_uri_an_image(data_uri: str) -> bool:
     """
@@ -218,7 +218,7 @@ def to_bytes(image: ImageType) -> bytes:
     if isinstance(image, bytes):
         return image
     elif isinstance(image, str) and image.startswith("data:"):
-        is_data_uri_an_media(image)
+        is_data_an_media(image)
         return extract_data_uri(image)
     elif isinstance(image, Image):
         bytes_io = BytesIO()
@@ -236,13 +236,29 @@ def to_bytes(image: ImageType) -> bytes:
             pass
         return image.read()

-def to_data_uri(image: ImageType) -> str:
+def to_data_uri(image: ImageType, filename: str = None) -> str:
     if not isinstance(image, str):
         data = to_bytes(image)
         data_base64 = base64.b64encode(data).decode()
-        return f"data:{is_accepted_format(data)};base64,{data_base64}"
+        return f"data:{is_data_an_media(data, filename)};base64,{data_base64}"
     return image

+def to_input_audio(audio: ImageType, filename: str = None) -> str:
+    if not isinstance(audio, str):
+        if filename is not None and (filename.endswith(".wav") or filename.endswith(".mp3")):
+            return {
+                "data": base64.b64encode(to_bytes(audio)).decode(),
+                "format": "wav" if filename.endswith(".wav") else "mpeg"
+            }
+        raise ValueError("Invalid input audio")
+    audio = re.match(r'^data:audio/(\w+);base64,(.+?)', audio)
+    if audio:
+        return {
+            "data": audio.group(2),
+            "format": audio.group(1),
+        }
+    raise ValueError("Invalid input audio")
+
 class ImageDataResponse():
     def __init__(
         self,
```
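Together, `is_data_an_audio` and `to_input_audio` let providers detect audio attachments and package them as OpenAI-style `input_audio` content parts. A usage sketch (the .wav file is illustrative):

```python
from g4f.image import is_data_an_audio, to_input_audio

with open("audio.wav", "rb") as f:  # illustrative local file
    data = f.read()

# Filename-based detection returns a MIME type (or None for non-audio):
assert is_data_an_audio(data, "audio.wav") == "audio/wav"

# Packaging for the input_audio content part; base64 bytes live in payload["data"]:
payload = to_input_audio(data, "audio.wav")
assert payload["format"] == "wav"
```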
```diff
@@ -109,7 +109,7 @@ async def copy_images(
                         f.write(chunk)

                 # Verify file format
-                if not os.path.splitext(target_path)[1]:
+                if target is None and not os.path.splitext(target_path)[1]:
                     with open(target_path, "rb") as f:
                         file_header = f.read(12)
                     detected_type = is_accepted_format(file_header)
@@ -120,7 +120,7 @@ async def copy_images(

                 # Build URL with safe encoding
                 url_filename = quote(os.path.basename(target_path))
-                return f"/images/{url_filename}{'?url=' + quote(image) if add_url and not image.startswith('data:') else ''}"
+                return f"/images/{url_filename}" + (('?url=' + quote(image)) if add_url and not image.startswith('data:') else '')

             except (ClientError, IOError, OSError) as e:
                 debug.error(f"Image copying failed: {type(e).__name__}: {e}")
```
```diff
@@ -25,7 +25,7 @@ from .. import debug

 SAFE_PARAMETERS = [
     "model", "messages", "stream", "timeout",
-    "proxy", "images", "response_format",
+    "proxy", "media", "response_format",
     "prompt", "negative_prompt", "tools", "conversation",
     "history_disabled",
     "temperature", "top_k", "top_p",
@@ -56,7 +56,7 @@ PARAMETER_EXAMPLES = {
     "frequency_penalty": 1,
     "presence_penalty": 1,
     "messages": [{"role": "system", "content": ""}, {"role": "user", "content": ""}],
-    "images": [["data:image/jpeg;base64,...", "filename.jpg"]],
+    "media": [["data:image/jpeg;base64,...", "filename.jpg"]],
     "response_format": {"type": "json_object"},
     "conversation": {"conversation_id": "550e8400-e29b-11d4-a716-...", "message_id": "550e8400-e29b-11d4-a716-..."},
     "seed": 42,
```
```diff
@@ -2,7 +2,7 @@ from __future__ import annotations

 import json

-from ..typing import AsyncResult, Messages, ImagesType
+from ..typing import AsyncResult, Messages, MediaListType
 from ..client.service import get_model_and_provider
 from ..client.helper import filter_json
 from .base_provider import AsyncGeneratorProvider
@@ -17,7 +17,7 @@ class ToolSupportProvider(AsyncGeneratorProvider):
         model: str,
         messages: Messages,
         stream: bool = True,
-        images: ImagesType = None,
+        media: MediaListType = None,
         tools: list[str] = None,
         response_format: dict = None,
         **kwargs
@@ -28,7 +28,7 @@ class ToolSupportProvider(AsyncGeneratorProvider):
         model, provider = get_model_and_provider(
             model, provider,
             stream, logging=False,
-            has_images=images is not None
+            has_images=media is not None
         )
         if tools is not None:
             if len(tools) > 1:
@@ -49,7 +49,7 @@ class ToolSupportProvider(AsyncGeneratorProvider):
             model,
             messages,
             stream=stream,
-            images=images,
+            media=media,
             response_format=response_format,
             **kwargs
         ):
```
```diff
@@ -21,7 +21,7 @@ AsyncResult = AsyncIterator[Union[str, ResponseType]]
 Messages = List[Dict[str, Union[str, List[Dict[str, Union[str, Dict[str, str]]]]]]]
 Cookies = Dict[str, str]
 ImageType = Union[str, bytes, IO, Image, os.PathLike]
-ImagesType = List[Tuple[ImageType, Optional[str]]]
+MediaListType = List[Tuple[ImageType, Optional[str]]]

 __all__ = [
     'Any',
@@ -44,5 +44,5 @@ __all__ = [
     'Cookies',
     'Image',
     'ImageType',
-    'ImagesType'
+    'MediaListType'
 ]
```
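`MediaListType` is a straight rename of `ImagesType`: a list of `(data, optional filename)` tuples accepted by the `media` parameter. Annotations elsewhere can adopt it without behavior changes:

```python
# Minimal sketch: annotating a helper with the renamed alias.
from g4f.typing import MediaListType

def count_attachments(media: MediaListType) -> int:
    return len(media)

print(count_attachments([(b"RIFF...", "clip.wav"), ("data:image/png;base64,...", None)]))
```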