feat: add EdgeTTS audio provider and global image→media refactor
- **Docs**
  - `docs/file.md`: update upload instructions to use inline `bucket` content parts instead of `tool_calls`/`bucket_tool`.
  - `docs/media.md`: add an asynchronous audio transcription example with a detailed explanation and notes.
- **New audio provider**
  - Add `g4f/Provider/audio/EdgeTTS.py` implementing Edge Text-to-Speech (`EdgeTTS`).
  - Create `g4f/Provider/audio/__init__.py` for the provider export.
  - Register the provider in `g4f/Provider/__init__.py`.
- **Refactor image → media**
  - Introduce a `generated_media/` directory and a `get_media_dir()` helper in `g4f/image/copy_images.py`; add `ensure_media_dir()`; keep backward compatibility with the legacy `generated_images/` directory.
  - Replace `images_dir` references with `get_media_dir()` across:
    - `g4f/api/__init__.py`
    - `g4f/client/stubs.py`
    - `g4f/gui/server/api.py`
    - `g4f/gui/server/backend_api.py`
    - `g4f/image/copy_images.py`
  - Rename the CLI/API config field and flag from `image_provider` to `media_provider` (`g4f/cli.py`, `g4f/api/__init__.py`, `g4f/client/__init__.py`).
  - Extend `g4f/image/__init__.py`:
    - add `MEDIA_TYPE_MAP` and `get_extension()`
    - revise `is_allowed_extension()` and `to_input_audio()` to support a wider range of media types.
- **Provider adjustments**
  - `g4f/Provider/ARTA.py`: swap the `raise_error()` parameter order.
  - `g4f/Provider/Blackbox.py`: import `NoValidHarFileError` and ignore it when searching HAR files.
  - `g4f/Provider/Cloudflare.py`: drop the unused `MissingRequirementsError` import; move `get_args_from_nodriver()` inside the try block; handle `FileNotFoundError`.
- **Core enhancements**
  - `g4f/providers/any_provider.py`: use `default_model` instead of the literal `"default"`; broaden model/provider matching; update the model list cleanup.
  - `g4f/models.py`: safeguard the provider count logic when a model name is falsy.
  - `g4f/providers/base_provider.py`: catch `json.JSONDecodeError` when reading the auth cache and delete the corrupted file.
  - `g4f/providers/response.py`: allow `AudioResponse` to accept extra keyword arguments.
- **Misc**
  - Remove the obsolete `g4f/image.py`.
  - `g4f/Provider/Cloudflare.py`, `g4f/client/types.py`: minor whitespace and import tidy-ups.
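For orientation, here is a rough usage sketch tying the new pieces together (untested; it assumes the `media` alias and `media_provider` argument introduced below behave like the existing `images`/`image_provider` path):

```python
import asyncio

from g4f.client import AsyncClient
from g4f.Provider import EdgeTTS  # exported via the new g4f/Provider/audio package

async def main():
    # media_provider supersedes image_provider; client.media aliases client.images.
    client = AsyncClient(media_provider=EdgeTTS)
    response = await client.media.generate("Hello from Edge TTS")
    print(response.data[0].url)  # e.g. "/media/<generated name>.mp3"

asyncio.run(main())
```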
docs/file.md
@@ -180,23 +180,17 @@ fileInput.addEventListener('change', () => {

**Integrating with `ChatCompletion`:**

To incorporate file uploads into your client applications, include the `tool_calls` parameter in your chat completion requests, using the `bucket_tool` function. The `bucket_id` is passed as a JSON object within your prompt.
To incorporate file uploads into your client applications, include the `bucket` in your chat completion requests, using inline content parts.

```json
{
    "messages": [
        {
            "role": "user",
            "content": "Answer this question using the files in the specified bucket: ...your question...\n{\"bucket_id\": \"your_actual_bucket_id\"}"
        }
    ],
    "tool_calls": [
        {
            "function": {
                "name": "bucket_tool"
            },
            "type": "function"
            "content": [
                {"type": "text", "text": "Answer this question using the files in the specified bucket: ...your question..."},
                {"bucket_id": "your_actual_bucket_id"}
            ]
        }
    ]
}
```
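To show the new request shape end to end, the same inline `bucket` part could be sent from Python roughly like this (a sketch; the model choice and provider support are assumptions, not part of this change):

```python
from g4f.client import Client

client = Client()
response = client.chat.completions.create(
    model="gpt-4o-mini",  # any model whose provider supports bucket files
    messages=[{
        "role": "user",
        "content": [
            {"type": "text", "text": "Answer this question using the files in the specified bucket: ...your question..."},
            {"bucket_id": "your_actual_bucket_id"},
        ],
    }],
)
print(response.choices[0].message.content)
```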
docs/media.md

@@ -30,6 +30,8 @@ asyncio.run(main())

#### **Transcribe an Audio File:**

Some providers in G4F support audio inputs in chat completions, allowing you to transcribe audio files by instructing the model accordingly. This example demonstrates how to use the `AsyncClient` to transcribe an audio file asynchronously:

```python
import asyncio
from g4f.client import AsyncClient

@@ -41,15 +43,32 @@ async def main():
    with open("audio.wav", "rb") as audio_file:
        response = await client.chat.completions.create(
            messages="Transcribe this audio",
            provider=g4f.Provider.Microsoft_Phi_4,
            media=[[audio_file, "audio.wav"]],
            modalities=["text"],
        )
        print(response.choices[0].message.content)

asyncio.run(main())
        print(response.choices[0].message.content)

if __name__ == "__main__":
    asyncio.run(main())
```

#### Explanation
- **Client Initialization**: An `AsyncClient` instance is created with a provider that supports audio inputs, such as `PollinationsAI` or `Microsoft_Phi_4`.
- **File Handling**: The audio file (`audio.wav`) is opened in binary read mode (`"rb"`) using a context manager (`with` statement) to ensure proper file closure after use.
- **API Call**: The `chat.completions.create` method is called with:
  - `messages`: Containing a user message instructing the model to transcribe the audio.
  - `media`: A list of lists, where each inner list contains the file object and its name (`[[audio_file, "audio.wav"]]`).
  - `modalities=["text"]`: Specifies that the output should be text (the transcription).
- **Response**: The transcription is extracted from `response.choices[0].message.content` and printed.

#### Notes
- **Provider Support**: Ensure the chosen provider (e.g., `PollinationsAI` or `Microsoft_Phi_4`) supports audio inputs in chat completions. Not all providers may offer this functionality.
- **File Path**: Replace `"audio.wav"` with the path to your own audio file. The file format (e.g., WAV) should be compatible with the provider.
- **Model Selection**: If `g4f.models.default` does not support audio transcription, you may need to specify a model that does (consult the provider's documentation for supported models).

This example complements the guide by showcasing how to handle audio inputs asynchronously, expanding on the multimodal capabilities of the G4F AsyncClient API.

---

### 2. **Image Generation**
g4f/Provider/ARTA.py

@@ -203,7 +203,7 @@ class ARTA(AsyncGeneratorProvider, ProviderModelMixin):
        else:
            raise ResponseError(f"Image generation failed with status: {status}")

async def raise_error(response: ClientResponse, message: str):
async def raise_error(message: str, response: ClientResponse):
    if response.ok:
        return
    error_text = await response.text()
g4f/Provider/Blackbox.py

@@ -20,7 +20,7 @@ from ..cookies import get_cookies_dir
from .helper import format_image_prompt, render_messages
from ..providers.response import JsonConversation, ImageResponse
from ..tools.media import merge_media
from ..errors import RateLimitError
from ..errors import RateLimitError, NoValidHarFileError
from .. import debug

class Conversation(JsonConversation):
@@ -470,6 +470,8 @@ class Blackbox(AsyncGeneratorProvider, ProviderModelMixin):
                except Exception as e:
                    debug.log(f"Blackbox: Error reading HAR file {file}: {e}")
                    return None
        except NoValidHarFileError:
            pass
        except Exception as e:
            debug.log(f"Blackbox: Error searching HAR files: {e}")
        return None
g4f/Provider/Cloudflare.py

@@ -8,7 +8,7 @@ from .base_provider import AsyncGeneratorProvider, ProviderModelMixin, AuthFileM
from ..requests import Session, StreamSession, get_args_from_nodriver, raise_for_status, merge_cookies
from ..requests import DEFAULT_HEADERS, has_nodriver, has_curl_cffi
from ..providers.response import FinishReason, Usage
from ..errors import ResponseStatusError, ModelNotFoundError, MissingRequirementsError
from ..errors import ResponseStatusError, ModelNotFoundError
from .. import debug
from .helper import render_messages

@@ -72,11 +72,11 @@ class Cloudflare(AsyncGeneratorProvider, ProviderModelMixin, AuthFileMixin):
        except ResponseStatusError as f:
            if has_nodriver:
                get_running_loop(check_nested=True)
                args = get_args_from_nodriver(cls.url)
                try:
                    args = get_args_from_nodriver(cls.url)
                    cls._args = asyncio.run(args)
                    read_models()
                except RuntimeError as e:
                except (RuntimeError, FileNotFoundError) as e:
                    cls.models = cls.fallback_models
                    debug.log(f"Nodriver is not available: {type(e).__name__}: {e}")
            else:
g4f/Provider/__init__.py

@@ -28,6 +28,10 @@ try:
    from .mini_max import HailuoAI, MiniMax
except ImportError as e:
    debug.error("MiniMax providers not loaded:", e)
try:
    from .audio import EdgeTTS
except ImportError as e:
    debug.error("Audio providers not loaded:", e)

try:
    from .AllenAI import AllenAI
g4f/Provider/audio/EdgeTTS.py (new file)
@@ -0,0 +1,75 @@
from __future__ import annotations

import os
import random
import asyncio

try:
    import edge_tts
    from edge_tts import VoicesManager
    has_edge_tts = True
except ImportError:
    has_edge_tts = False

from ...typing import AsyncResult, Messages
from ...providers.response import AudioResponse
from ...image.copy_images import get_filename, get_media_dir, ensure_media_dir
from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin
from ..helper import format_image_prompt

class EdgeTTS(AsyncGeneratorProvider, ProviderModelMixin):
    label = "Edge TTS"
    working = has_edge_tts
    default_model = "edge-tts"
    default_locale = "en-US"

    @classmethod
    def get_models(cls) -> list[str]:
        if not cls.models:
            voices = asyncio.run(VoicesManager.create())
            cls.default_model = voices.find(Locale=cls.default_locale)[0]["Name"]
            cls.models = [voice["Name"] for voice in voices.voices]
        return cls.models

    @classmethod
    async def create_async_generator(
        cls,
        model: str,
        messages: Messages,
        proxy: str = None,
        prompt: str = None,
        language: str = None,
        locale: str = None,
        audio: dict = {"voice": None, "format": "mp3"},
        extra_parameters: list[str] = ["rate", "volume", "pitch"],
        **kwargs
    ) -> AsyncResult:
        prompt = format_image_prompt(messages, prompt)
        if not prompt:
            raise ValueError("Prompt is empty.")
        voice = audio.get("voice", model)
        if not voice:
            voices = await VoicesManager.create()
            if locale is None:
                if language is None:
                    voices = voices.find(Locale=cls.default_locale)
                elif "-" in language:
                    voices = voices.find(Locale=language)
                else:
                    voices = voices.find(Language=language)
            else:
                voices = voices.find(Locale=locale)
            if not voices:
                raise ValueError(f"No voices found for language '{language}' and locale '{locale}'.")
            voice = random.choice(voices)["Name"]

        format = audio.get("format", "mp3")
        filename = get_filename([cls.default_model], prompt, f".{format}", prompt)
        target_path = os.path.join(get_media_dir(), filename)
        ensure_media_dir()

        extra_parameters = {param: kwargs[param] for param in extra_parameters if param in kwargs}
        communicate = edge_tts.Communicate(prompt, voice=voice, proxy=proxy, **extra_parameters)

        await communicate.save(target_path)
        yield AudioResponse(f"/media/{filename}", voice=voice, prompt=prompt)
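A rough sketch of exercising the provider directly with its audio options (untested; it assumes the `audio` dict and the `rate`/`volume`/`pitch` keyword arguments are forwarded exactly as in `create_async_generator` above):

```python
import asyncio

from g4f.Provider import EdgeTTS

async def main():
    # With no explicit voice, a voice matching the language/locale is chosen at random.
    async for chunk in EdgeTTS.create_async_generator(
        model="",
        messages=[{"role": "user", "content": "Hello from Edge TTS"}],
        audio={"voice": None, "format": "mp3"},
        language="en",
        rate="+10%",
    ):
        print(chunk)  # AudioResponse pointing at /media/<filename>.mp3

asyncio.run(main())
```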
g4f/Provider/audio/__init__.py (new file)
@@ -0,0 +1 @@
from .EdgeTTS import EdgeTTS
g4f/api/__init__.py

@@ -40,7 +40,7 @@ from g4f.client import AsyncClient, ChatCompletion, ImagesResponse, convert_to_p
from g4f.providers.response import BaseConversation, JsonConversation
from g4f.client.helper import filter_none
from g4f.image import is_data_an_media, EXTENSIONS_MAP
from g4f.image.copy_images import images_dir, copy_media, get_source_url
from g4f.image.copy_images import get_media_dir, copy_media, get_source_url
from g4f.errors import ProviderNotFoundError, ModelNotFoundError, MissingAuthError, NoValidHarFileError
from g4f.cookies import read_cookie_files, get_cookies_dir
from g4f.providers.types import ProviderType
@@ -130,7 +130,7 @@ class AppConfig:
    ignore_cookie_files: bool = False
    model: str = None
    provider: str = None
    image_provider: str = None
    media_provider: str = None
    proxy: str = None
    gui: bool = False
    demo: bool = False
@@ -419,12 +419,13 @@ class Api:
    ):
        if config.provider is None:
            config.provider = provider
        if config.provider is None:
            config.provider = AppConfig.media_provider
        if credentials is not None and credentials.credentials != "secret":
            config.api_key = credentials.credentials
        try:
            response = await self.client.images.generate(
                **config.dict(exclude_none=True),
                provider=AppConfig.image_provider if config.provider is None else config.provider
            )
            for image in response.data:
                if hasattr(image, "url") and image.url.startswith("/"):
@@ -562,9 +563,9 @@ class Api:
        HTTP_404_NOT_FOUND: {}
    })
    async def get_media(filename, request: Request):
        target = os.path.join(images_dir, os.path.basename(filename))
        target = os.path.join(get_media_dir(), os.path.basename(filename))
        if not os.path.isfile(target):
            other_name = os.path.join(images_dir, os.path.basename(quote_plus(filename)))
            other_name = os.path.join(get_media_dir(), os.path.basename(quote_plus(filename)))
            if os.path.isfile(other_name):
                target = other_name
        ext = os.path.splitext(filename)[1][1:]
@@ -627,7 +628,7 @@ class Api:

def format_exception(e: Union[Exception, str], config: Union[ChatCompletionsConfig, ImageGenerationConfig] = None, image: bool = False) -> str:
    last_provider = {} if not image else g4f.get_last_provider(True)
    provider = (AppConfig.image_provider if image else AppConfig.provider)
    provider = (AppConfig.media_provider if image else AppConfig.provider)
    model = AppConfig.model
    if config is not None:
        if config.provider is not None:
g4f/cli.py

@@ -16,7 +16,7 @@ def get_api_parser():
    api_parser.add_argument("--model", default=None, help="Default model for chat completion. (incompatible with --reload and --workers)")
    api_parser.add_argument("--provider", choices=[provider.__name__ for provider in Provider.__providers__ if provider.working],
                            default=None, help="Default provider for chat completion. (incompatible with --reload and --workers)")
    api_parser.add_argument("--image-provider", choices=[provider.__name__ for provider in Provider.__providers__ if provider.working and hasattr(provider, "image_models")],
    api_parser.add_argument("--media-provider", choices=[provider.__name__ for provider in Provider.__providers__ if provider.working and bool(getattr(provider, "image_models", False))],
                            default=None, help="Default provider for image generation. (incompatible with --reload and --workers)"),
    api_parser.add_argument("--proxy", default=None, help="Default used proxy. (incompatible with --reload and --workers)")
    api_parser.add_argument("--workers", type=int, default=None, help="Number of workers.")
@@ -59,7 +59,7 @@ def run_api_args(args):
        ignored_providers=args.ignored_providers,
        g4f_api_key=args.g4f_api_key,
        provider=args.provider,
        image_provider=args.image_provider,
        media_provider=args.media_provider,
        proxy=args.proxy,
        model=args.model,
        gui=args.gui,
g4f/client/__init__.py

@@ -1,5 +1,6 @@
from __future__ import annotations

import os
import time
import random
import string
@@ -8,7 +9,7 @@ import aiohttp
import base64
from typing import Union, AsyncIterator, Iterator, Awaitable, Optional

from ..image.copy_images import copy_media
from ..image.copy_images import copy_media, get_media_dir
from ..typing import Messages, ImageType
from ..providers.types import ProviderType, BaseRetryProvider
from ..providers.response import *
@@ -16,11 +17,11 @@ from ..errors import NoMediaResponseError
from ..providers.retry_provider import IterListProvider
from ..providers.asyncio import to_sync_generator
from ..providers.any_provider import AnyProvider
from ..Provider.needs_auth import BingCreateImages, OpenaiAccount
from ..Provider import OpenaiAccount, PollinationsImage
from ..tools.run_tools import async_iter_run_tools, iter_run_tools
from .stubs import ChatCompletion, ChatCompletionChunk, Image, ImagesResponse, UsageModel, ToolCallModel
from .models import ClientModels
from .types import IterResponse, ImageProvider, Client as BaseClient
from .types import IterResponse, Client as BaseClient
from .service import convert_to_provider
from .helper import find_stop, filter_json, filter_none, safe_aclose
from .. import debug
@@ -261,15 +262,15 @@ class Client(BaseClient):
    def __init__(
        self,
        provider: Optional[ProviderType] = None,
        image_provider: Optional[ImageProvider] = None,
        media_provider: Optional[ProviderType] = None,
        **kwargs
    ) -> None:
        super().__init__(**kwargs)
        self.chat: Chat = Chat(self, provider)
        if image_provider is None:
            image_provider = provider
        self.models: ClientModels = ClientModels(self, provider, image_provider)
        self.images: Images = Images(self, image_provider)
        if media_provider is None:
            media_provider = kwargs.get("image_provider", provider)
        self.models: ClientModels = ClientModels(self, provider, media_provider)
        self.images: Images = Images(self, media_provider)
        self.media: Images = self.images

class Completions:
@@ -364,7 +365,7 @@ class Images:
        """
        return asyncio.run(self.async_generate(prompt, model, provider, response_format, proxy, **kwargs))

    async def get_provider_handler(self, model: Optional[str], provider: Optional[ImageProvider], default: ImageProvider) -> ImageProvider:
    async def get_provider_handler(self, model: Optional[str], provider: Optional[ProviderType], default: ProviderType) -> ProviderType:
        if provider is None:
            provider_handler = self.provider
            if provider_handler is None:
@@ -387,7 +388,7 @@ class Images:
        api_key: Optional[str] = None,
        **kwargs
    ) -> ImagesResponse:
        provider_handler = await self.get_provider_handler(model, provider, BingCreateImages)
        provider_handler = await self.get_provider_handler(model, provider, PollinationsImage)
        provider_name = provider_handler.__name__ if hasattr(provider_handler, "__name__") else type(provider_handler).__name__
        if proxy is None:
            proxy = self.client.proxy
@@ -407,20 +408,17 @@ class Images:
                    debug.error(f"{provider.__name__} {type(e).__name__}: {e}")
        else:
            response = await self._generate_image_response(provider_handler, provider_name, model, prompt, proxy=proxy, api_key=api_key, **kwargs)

        if isinstance(response, MediaResponse):
            return await self._process_image_response(
                response,
                model,
                provider_name,
                response_format,
                proxy
            )
        if response is None:
            if error is not None:
                raise error
            raise NoMediaResponseError(f"No image response from {provider_name}")
        raise NoMediaResponseError(f"Unexpected response type: {type(response)}")
            raise NoMediaResponseError(f"No media response from {provider_name}")
        return await self._process_image_response(
            response,
            model,
            provider_name,
            response_format,
            proxy
        )

    async def _generate_image_response(
        self,
@@ -441,7 +439,7 @@ class Images:
            prompt=prompt,
            **kwargs
        ):
            if isinstance(item, MediaResponse):
            if isinstance(item, (MediaResponse, AudioResponse)):
                items.append(item)
        elif hasattr(provider_handler, "create_completion"):
            for item in provider_handler.create_completion(
@@ -451,13 +449,15 @@ class Images:
                prompt=prompt,
                **kwargs
            ):
                if isinstance(item, MediaResponse):
                if isinstance(item, (MediaResponse, AudioResponse)):
                    items.append(item)
        else:
            raise ValueError(f"Provider {provider_name} does not support image generation")
        urls = []
        for item in items:
            if isinstance(item.urls, str):
            if isinstance(item, AudioResponse):
                urls.append(item.to_uri())
            elif isinstance(item.urls, str):
                urls.append(item.urls)
            elif isinstance(item.urls, list):
                urls.extend(item.urls)
@@ -508,14 +508,11 @@ class Images:
                    debug.error(f"{provider.__name__} {type(e).__name__}: {e}")
        else:
            response = await self._generate_image_response(provider_handler, provider_name, model, prompt, **kwargs)

        if isinstance(response, MediaResponse):
            return await self._process_image_response(response, model, provider_name, response_format, proxy)
        if response is None:
            if error is not None:
                raise error
            raise NoMediaResponseError(f"No image response from {provider_name}")
        raise NoMediaResponseError(f"Unexpected response type: {type(response)}")
            raise NoMediaResponseError(f"No media response from {provider_name}")
        return await self._process_image_response(response, model, provider_name, response_format, proxy)

    async def _process_image_response(
        self,
@@ -531,12 +528,16 @@ class Images:
        elif response_format == "b64_json":
            # Convert URLs directly to base64 without saving
            async def get_b64_from_url(url: str) -> Image:
                if url.startswith("/media/"):
                    with open(os.path.join(get_media_dir(), os.path.basename(url)), "wb") as f:
                        b64_data = base64.b64encode(f.read()).decode()
                        return Image.model_construct(b64_json=b64_data, revised_prompt=response.alt)
                async with aiohttp.ClientSession(cookies=response.get("cookies")) as session:
                    async with session.get(url, proxy=proxy) as resp:
                        if resp.status == 200:
                            image_data = await resp.read()
                            b64_data = base64.b64encode(image_data).decode()
                            b64_data = base64.b64encode(await resp.read()).decode()
                            return Image.model_construct(b64_json=b64_data, revised_prompt=response.alt)
                return Image.model_construct(url=url, revised_prompt=response.alt)
            images = await asyncio.gather(*[get_b64_from_url(image) for image in response.get_list()])
        else:
            # Save locally for None (default) case
@@ -554,15 +555,15 @@ class AsyncClient(BaseClient):
    def __init__(
        self,
        provider: Optional[ProviderType] = None,
        image_provider: Optional[ImageProvider] = None,
        media_provider: Optional[ProviderType] = None,
        **kwargs
    ) -> None:
        super().__init__(**kwargs)
        self.chat: AsyncChat = AsyncChat(self, provider)
        if image_provider is None:
            image_provider = provider
        self.models: ClientModels = ClientModels(self, provider, image_provider)
        self.images: AsyncImages = AsyncImages(self, image_provider)
        if media_provider is None:
            media_provider = kwargs.get("image_provider", provider)
        self.models: ClientModels = ClientModels(self, provider, media_provider)
        self.images: AsyncImages = AsyncImages(self, media_provider)
        self.media: AsyncImages = self.images

class AsyncChat:
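For comparison with the old constructor, a minimal sketch of the renamed argument (the `image_provider` fallback shown above keeps old call sites working; `PollinationsImage` is now the default handler instead of `BingCreateImages`):

```python
from g4f.client import Client
from g4f.Provider import PollinationsImage

client = Client(media_provider=PollinationsImage)  # image_provider= is still read from kwargs
image = client.images.generate("a red fox in the snow", response_format="url")
print(image.data[0].url)
```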
g4f/client/stubs.py

@@ -5,7 +5,7 @@ from typing import Optional, List
from time import time

from ..image import extract_data_uri
from ..image.copy_images import images_dir
from ..image.copy_images import get_media_dir
from ..client.helper import filter_markdown
from .helper import filter_none

@@ -123,7 +123,7 @@ class ChatCompletionMessage(BaseModel):

    def save(self, filepath: str, allowd_types = None):
        if hasattr(self.content, "data"):
            os.rename(self.content.data.replace("/media", images_dir), filepath)
            os.rename(self.content.data.replace("/media", get_media_dir()), filepath)
            return
        if self.content.startswith("data:"):
            with open(filepath, "wb") as f:
g4f/client/types.py

@@ -6,7 +6,6 @@ from .stubs import ChatCompletion, ChatCompletionChunk
from ..providers.types import BaseProvider
from typing import Union, Iterator, AsyncIterator

ImageProvider = Union[BaseProvider, object]
Proxies = Union[dict, str]
IterResponse = Iterator[Union[ChatCompletion, ChatCompletionChunk]]
AsyncIterResponse = AsyncIterator[Union[ChatCompletion, ChatCompletionChunk]]
@@ -19,7 +18,7 @@ class Client():
        **kwargs
    ) -> None:
        self.api_key: str = api_key
        self.proxies= proxies
        self.proxies = proxies
        self.proxy: str = self.get_proxy()

    def get_proxy(self) -> Union[str, None]:
g4f/gui/server/api.py

@@ -8,7 +8,7 @@ from flask import send_from_directory
from inspect import signature

from ...errors import VersionNotFoundError, MissingAuthError
from ...image.copy_images import copy_media, ensure_images_dir, images_dir
from ...image.copy_images import copy_media, ensure_media_dir, get_media_dir
from ...tools.run_tools import iter_run_tools
from ... import Provider
from ...providers.base_provider import ProviderModelMixin
@@ -96,8 +96,8 @@ class Api:
        }

    def serve_images(self, name):
        ensure_images_dir()
        return send_from_directory(os.path.abspath(images_dir), name)
        ensure_media_dir()
        return send_from_directory(os.path.abspath(get_media_dir()), name)

    def _prepare_conversation_kwargs(self, json_data: dict):
        kwargs = {**json_data}
g4f/gui/server/backend_api.py

@@ -25,7 +25,7 @@ from ...tools.run_tools import iter_run_tools
from ...errors import ProviderNotFoundError
from ...image import is_allowed_extension
from ...cookies import get_cookies_dir
from ...image.copy_images import secure_filename, get_source_url, images_dir
from ...image.copy_images import secure_filename, get_source_url, get_media_dir
from ... import ChatCompletion
from ... import models
from .api import Api
@@ -346,11 +346,12 @@ class Backend_Api(Api):
        @app.route('/search/<search>', methods=['GET'])
        def find_media(search: str):
            safe_search = [secure_filename(chunk.lower()) for chunk in search.split("+")]
            if not os.access(images_dir, os.R_OK):
            media_dir = get_media_dir()
            if not os.access(media_dir, os.R_OK):
                return jsonify({"error": {"message": "Not found"}}), 404
            if search not in self.match_files:
                self.match_files[search] = {}
                for root, _, files in os.walk(images_dir):
                for root, _, files in os.walk(media_dir):
                    for file in files:
                        mime_type = is_allowed_extension(file)
                        if mime_type is not None:
@@ -438,7 +439,7 @@ class Backend_Api(Api):
    def get_provider_models(self, provider: str):
        api_key = request.headers.get("x_api_key")
        api_base = request.headers.get("x_api_base")
        ignored = request.headers.get("x_ignored").split()
        ignored = request.headers.get("x_ignored", "").split()
        models = super().get_provider_models(provider, api_key, api_base, ignored)
        if models is None:
            return "Provider not found", 404
g4f/image.py (deleted)
@@ -1,253 +0,0 @@
from __future__ import annotations

import os
import re
import io
import base64
from urllib.parse import quote_plus
from io import BytesIO
from pathlib import Path
try:
    from PIL.Image import open as open_image, new as new_image
    from PIL.Image import FLIP_LEFT_RIGHT, ROTATE_180, ROTATE_270, ROTATE_90
    has_requirements = True
except ImportError:
    has_requirements = False

from .typing import ImageType, Union, Image, Optional, Cookies
from .errors import MissingRequirementsError
from .requests.aiohttp import get_connector

ALLOWED_EXTENSIONS = {'png', 'jpg', 'jpeg', 'gif', 'webp', 'svg'}

EXTENSIONS_MAP: dict[str, str] = {
    "image/png": "png",
    "image/jpeg": "jpg",
    "image/gif": "gif",
    "image/webp": "webp",
}

# Define the directory for generated images
images_dir = "./generated_images"

def to_image(image: ImageType, is_svg: bool = False) -> Image:
    """
    Converts the input image to a PIL Image object.

    Args:
        image (Union[str, bytes, Image]): The input image.

    Returns:
        Image: The converted PIL Image object.
    """
    if not has_requirements:
        raise MissingRequirementsError('Install "pillow" package for images')

    if isinstance(image, str) and image.startswith("data:"):
        is_data_uri_an_image(image)
        image = extract_data_uri(image)

    if is_svg:
        try:
            import cairosvg
        except ImportError:
            raise MissingRequirementsError('Install "cairosvg" package for svg images')
        if not isinstance(image, bytes):
            image = image.read()
        buffer = BytesIO()
        cairosvg.svg2png(image, write_to=buffer)
        return open_image(buffer)

    if isinstance(image, bytes):
        is_accepted_format(image)
        return open_image(BytesIO(image))
    elif not isinstance(image, Image):
        image = open_image(image)
        image.load()
        return image

    return image

def is_allowed_extension(filename: str) -> bool:
    """
    Checks if the given filename has an allowed extension.

    Args:
        filename (str): The filename to check.

    Returns:
        bool: True if the extension is allowed, False otherwise.
    """
    return '.' in filename and \
        filename.rsplit('.', 1)[1].lower() in ALLOWED_EXTENSIONS

def is_data_uri_an_image(data_uri: str) -> bool:
    """
    Checks if the given data URI represents an image.

    Args:
        data_uri (str): The data URI to check.

    Raises:
        ValueError: If the data URI is invalid or the image format is not allowed.
    """
    # Check if the data URI starts with 'data:image' and contains an image format (e.g., jpeg, png, gif)
    if not re.match(r'data:image/(\w+);base64,', data_uri):
        raise ValueError("Invalid data URI image.")
    # Extract the image format from the data URI
    image_format = re.match(r'data:image/(\w+);base64,', data_uri).group(1).lower()
    # Check if the image format is one of the allowed formats (jpg, jpeg, png, gif)
    if image_format not in ALLOWED_EXTENSIONS and image_format != "svg+xml":
        raise ValueError("Invalid image format (from mime file type).")

def is_accepted_format(binary_data: bytes) -> str:
    """
    Checks if the given binary data represents an image with an accepted format.

    Args:
        binary_data (bytes): The binary data to check.

    Raises:
        ValueError: If the image format is not allowed.
    """
    if binary_data.startswith(b'\xFF\xD8\xFF'):
        return "image/jpeg"
    elif binary_data.startswith(b'\x89PNG\r\n\x1a\n'):
        return "image/png"
    elif binary_data.startswith(b'GIF87a') or binary_data.startswith(b'GIF89a'):
        return "image/gif"
    elif binary_data.startswith(b'\x89JFIF') or binary_data.startswith(b'JFIF\x00'):
        return "image/jpeg"
    elif binary_data.startswith(b'\xFF\xD8'):
        return "image/jpeg"
    elif binary_data.startswith(b'RIFF') and binary_data[8:12] == b'WEBP':
        return "image/webp"
    else:
        raise ValueError("Invalid image format (from magic code).")

def extract_data_uri(data_uri: str) -> bytes:
    """
    Extracts the binary data from the given data URI.

    Args:
        data_uri (str): The data URI.

    Returns:
        bytes: The extracted binary data.
    """
    data = data_uri.split(",")[-1]
    data = base64.b64decode(data)
    return data

def get_orientation(image: Image) -> int:
    """
    Gets the orientation of the given image.

    Args:
        image (Image): The image.

    Returns:
        int: The orientation value.
    """
    exif_data = image.getexif() if hasattr(image, 'getexif') else image._getexif()
    if exif_data is not None:
        orientation = exif_data.get(274)  # 274 corresponds to the orientation tag in EXIF
        if orientation is not None:
            return orientation

def process_image(image: Image, new_width: int, new_height: int) -> Image:
    """
    Processes the given image by adjusting its orientation and resizing it.

    Args:
        image (Image): The image to process.
        new_width (int): The new width of the image.
        new_height (int): The new height of the image.

    Returns:
        Image: The processed image.
    """
    # Fix orientation
    orientation = get_orientation(image)
    if orientation:
        if orientation > 4:
            image = image.transpose(FLIP_LEFT_RIGHT)
        if orientation in [3, 4]:
            image = image.transpose(ROTATE_180)
        if orientation in [5, 6]:
            image = image.transpose(ROTATE_270)
        if orientation in [7, 8]:
            image = image.transpose(ROTATE_90)
    # Resize image
    image.thumbnail((new_width, new_height))
    # Remove transparency
    if image.mode == "RGBA":
        image.load()
        white = new_image('RGB', image.size, (255, 255, 255))
        white.paste(image, mask=image.split()[-1])
        return white
    # Convert to RGB for jpg format
    elif image.mode != "RGB":
        image = image.convert("RGB")
    return image

def to_bytes(image: ImageType) -> bytes:
    """
    Converts the given image to bytes.

    Args:
        image (ImageType): The image to convert.

    Returns:
        bytes: The image as bytes.
    """
    if isinstance(image, bytes):
        return image
    elif isinstance(image, str) and image.startswith("data:"):
        is_data_uri_an_image(image)
        return extract_data_uri(image)
    elif isinstance(image, Image):
        bytes_io = BytesIO()
        image.save(bytes_io, image.format)
        image.seek(0)
        return bytes_io.getvalue()
    elif isinstance(image, (str, os.PathLike)):
        return Path(image).read_bytes()
    elif isinstance(image, Path):
        return image.read_bytes()
    else:
        try:
            image.seek(0)
        except (AttributeError, io.UnsupportedOperation):
            pass
        return image.read()

def to_data_uri(image: ImageType) -> str:
    if not isinstance(image, str):
        data = to_bytes(image)
        data_base64 = base64.b64encode(data).decode()
        return f"data:{is_accepted_format(data)};base64,{data_base64}"
    return image

class ImageDataResponse():
    def __init__(
        self,
        images: Union[str, list],
        alt: str,
    ):
        self.images = images
        self.alt = alt

    def get_list(self) -> list[str]:
        return [self.images] if isinstance(self.images, str) else self.images

class ImageRequest:
    def __init__(
        self,
        options: dict = {}
    ):
        self.options = options

    def get(self, key: str):
        return self.options.get(key)
g4f/image/__init__.py

@@ -17,13 +17,6 @@ except ImportError:
from ..typing import ImageType, Union, Image
from ..errors import MissingRequirementsError

MEDIA_TYPE_MAP: dict[str, str] = {
    "image/png": "png",
    "image/jpeg": "jpg",
    "image/gif": "gif",
    "image/webp": "webp",
}

EXTENSIONS_MAP: dict[str, str] = {
    # Image
    "png": "image/png",
@@ -44,6 +37,8 @@ EXTENSIONS_MAP: dict[str, str] = {
    "mp4": "video/mp4",
}

MEDIA_TYPE_MAP: dict[str, str] = {value: key for key, value in EXTENSIONS_MAP.items()}

def to_image(image: ImageType, is_svg: bool = False) -> Image:
    """
    Converts the input image to a PIL Image object.
@@ -82,6 +77,12 @@ def to_image(image: ImageType, is_svg: bool = False) -> Image:

    return image

def get_extension(filename: str) -> Optional[str]:
    if '.' in filename:
        ext = os.path.splitext(filename)[1][1:].lower()
        return ext if ext in EXTENSIONS_MAP else None
    return None

def is_allowed_extension(filename: str) -> Optional[str]:
    """
    Checks if the given filename has an allowed extension.
@@ -92,8 +93,10 @@ def is_allowed_extension(filename: str) -> Optional[str]:
    Returns:
        bool: True if the extension is allowed, False otherwise.
    """
    ext = os.path.splitext(filename)[1][1:].lower() if '.' in filename else None
    return EXTENSIONS_MAP[ext] if ext in EXTENSIONS_MAP else None
    extension = get_extension(filename)
    if extension is None:
        return None
    return EXTENSIONS_MAP[extension]

def is_data_an_media(data, filename: str = None) -> str:
    content_type = is_data_an_audio(data, filename)
@@ -105,12 +108,11 @@ def is_data_an_media(data, filename: str = None) -> str:

def is_data_an_audio(data_uri: str = None, filename: str = None) -> str:
    if filename:
        if filename.endswith(".wav"):
            return "audio/wav"
        elif filename.endswith(".mp3"):
            return "audio/mpeg"
        elif filename.endswith(".m4a"):
            return "audio/m4a"
        extension = get_extension(filename)
        if extension is not None:
            media_type = EXTENSIONS_MAP[extension]
            if media_type.startswith("audio/"):
                return media_type
    if isinstance(data_uri, str):
        audio_format = re.match(r'^data:(audio/\w+);base64,', data_uri)
        if audio_format:
@@ -266,10 +268,13 @@ def to_data_uri(image: ImageType, filename: str = None) -> str:

def to_input_audio(audio: ImageType, filename: str = None) -> str:
    if not isinstance(audio, str):
        if filename is not None and (filename.endswith(".wav") or filename.endswith(".mp3")):
        if filename is not None:
            format = get_extension(filename)
            if format is None:
                raise ValueError("Invalid input audio")
            return {
                "data": base64.b64encode(to_bytes(audio)).decode(),
                "format": "wav" if filename.endswith(".wav") else "mp3"
                "format": format
            }
        raise ValueError("Invalid input audio")
    audio = re.match(r'^data:audio/(\w+);base64,(.+?)', audio)
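A quick sketch of how the reworked helpers behave (the return values follow the `EXTENSIONS_MAP` entries shown above; the `.xyz` case is only an illustration):

```python
from g4f.image import get_extension, is_allowed_extension, is_data_an_audio

print(get_extension("photo.png"))             # "png"
print(is_allowed_extension("photo.png"))      # "image/png"
print(is_allowed_extension("notes.xyz"))      # None (extension not in EXTENSIONS_MAP)
print(is_data_an_audio(filename="clip.wav"))  # the mapped audio type, e.g. "audio/wav"
```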
g4f/image/copy_images.py

@@ -21,6 +21,13 @@ from .. import debug

# Directory for storing generated images
images_dir = "./generated_images"
media_dir = "./generated_media"

def get_media_dir() -> str:
    """Get the directory for storing generated media files"""
    if os.access(images_dir, os.R_OK):
        return images_dir
    return media_dir

def get_media_extension(media: str) -> str:
    """Extract media file extension from URL or filename"""
@@ -34,9 +41,10 @@ def get_media_extension(media: str) -> str:
        raise ValueError(f"Unsupported media extension: {extension} in: {media}")
    return extension

def ensure_images_dir():
def ensure_media_dir():
    """Create images directory if it doesn't exist"""
    os.makedirs(images_dir, exist_ok=True)
    if not os.access(images_dir, os.R_OK):
        os.makedirs(media_dir, exist_ok=True)

def get_source_url(image: str, default: str = None) -> str:
    """Extract original URL from image parameter if present"""
@@ -46,30 +54,27 @@ def get_source_url(image: str, default: str = None) -> str:
            return decoded_url
    return default

def is_valid_media_type(content_type: str) -> bool:
    return content_type in MEDIA_TYPE_MAP or content_type.startswith("audio/") or content_type.startswith("video/")

async def save_response_media(response: StreamResponse, prompt: str, tags: list[str]) -> AsyncIterator:
    """Save media from response to local file and return URL"""
    content_type = response.headers["content-type"]
    if is_valid_media_type(content_type):
        extension = MEDIA_TYPE_MAP[content_type] if content_type in MEDIA_TYPE_MAP else content_type[6:].replace("mpeg", "mp3")
        if extension not in EXTENSIONS_MAP:
            raise ValueError(f"Unsupported media type: {content_type}")
        filename = get_filename(tags, prompt, f".{extension}", prompt)
        target_path = os.path.join(images_dir, filename)
        with open(target_path, 'wb') as f:
            async for chunk in response.iter_content() if hasattr(response, "iter_content") else response.content.iter_any():
                f.write(chunk)
        media_url = f"/media/{filename}"
        if response.method == "GET":
            media_url = f"{media_url}?url={str(response.url)}"
        if content_type.startswith("audio/"):
            yield AudioResponse(media_url)
        elif content_type.startswith("video/"):
            yield VideoResponse(media_url, prompt)
        else:
            yield ImageResponse(media_url, prompt)
    extension = MEDIA_TYPE_MAP.get(content_type)
    if extension is None:
        raise ValueError(f"Unsupported media type: {content_type}")
    filename = get_filename(tags, prompt, f".{extension}", prompt)
    target_path = os.path.join(get_media_dir(), filename)
    ensure_media_dir()
    with open(target_path, 'wb') as f:
        async for chunk in response.iter_content() if hasattr(response, "iter_content") else response.content.iter_any():
            f.write(chunk)
    media_url = f"/media/{filename}"
    if response.method == "GET":
        media_url = f"{media_url}?url={str(response.url)}"
    if content_type.startswith("audio/"):
        yield AudioResponse(media_url)
    elif content_type.startswith("video/"):
        yield VideoResponse(media_url, prompt)
    else:
        yield ImageResponse(media_url, prompt)

def get_filename(tags: list[str], alt: str, extension: str, image: str) -> str:
    return "".join((
@@ -97,7 +102,7 @@ async def copy_media(
    """
    if add_url:
        add_url = not cookies
    ensure_images_dir()
    ensure_media_dir()

    async with ClientSession(
        connector=get_connector(proxy=proxy),
@@ -113,7 +118,7 @@ async def copy_media(
            if target_path is None:
                # Build safe filename with full Unicode support
                filename = get_filename(tags, alt, get_media_extension(image), image)
                target_path = os.path.join(images_dir, filename)
                target_path = os.path.join(get_media_dir(), filename)
            try:
                # Handle different image types
                if image.startswith("data:"):
@@ -132,7 +137,7 @@ async def copy_media(
                    response.raise_for_status()
                    media_type = response.headers.get("content-type", "application/octet-stream")
                    if media_type not in ("application/octet-stream", "binary/octet-stream"):
                        if not is_valid_media_type(media_type):
                        if media_type not in MEDIA_TYPE_MAP:
                            raise ValueError(f"Unsupported media type: {media_type}")
                    with open(target_path, "wb") as f:
                        async for chunk in response.content.iter_any():
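The back-compat rule from `get_media_dir()`/`ensure_media_dir()` above, in one small sketch:

```python
from g4f.image.copy_images import ensure_media_dir, get_media_dir

ensure_media_dir()      # creates ./generated_media only when the legacy ./generated_images is not accessible
print(get_media_dir())  # "./generated_images" if the legacy dir is readable, otherwise "./generated_media"
```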
g4f/models.py

@@ -1006,6 +1006,6 @@ __models__ = {
        if model.best_provider is not None and model.best_provider.working
        else [])
    for model in ModelUtils.convert.values()]
    if [p for p in providers if p.working]
    if model.name and [True for provider in providers if provider.working]
}
_all_models = list(__models__.keys())
g4f/providers/any_provider.py

@@ -37,7 +37,7 @@ class AnyProvider(AsyncGeneratorProvider, ProviderModelMixin):
        cls.models_count = {
            model: len(providers) for model, providers in model_with_providers.items() if len(providers) > 1
        }
        all_models = ["default"] + list(model_with_providers.keys())
        all_models = [cls.default_model] + list(model_with_providers.keys())
        for provider in [OpenaiChat, PollinationsAI, HuggingSpace, Cloudflare, PerplexityLabs, Gemini, Grok]:
            if not provider.working or getattr(provider, "parent", provider.__name__) in ignored:
                continue
@@ -63,6 +63,7 @@ class AnyProvider(AsyncGeneratorProvider, ProviderModelMixin):
                ).replace("-03-2025", ""
                ).replace("-20250219", ""
                ).replace("-20241022", ""
                ).replace("-20240904", ""
                ).replace("-2025-04-16", ""
                ).replace("-2025-04-14", ""
                ).replace("-0125", ""
@@ -72,10 +73,13 @@ class AnyProvider(AsyncGeneratorProvider, ProviderModelMixin):
                ).replace("-2409", ""
                ).replace("-2410", ""
                ).replace("-2411", ""
                ).replace("-1119", ""
                ).replace("-0919", ""
                ).replace("-02-24", ""
                ).replace("-03-25", ""
                ).replace("-03-26", ""
                ).replace("-01-21", ""
                ).replace("-002", ""
                ).replace(".1-", "-"
                ).replace("_", "."
                ).replace("c4ai-", ""
@@ -98,8 +102,8 @@ class AnyProvider(AsyncGeneratorProvider, ProviderModelMixin):
        for provider in [Microsoft_Phi_4, PollinationsAI]:
            if provider.working and getattr(provider, "parent", provider.__name__) not in ignored:
                cls.audio_models.update(provider.audio_models)
        cls.models_count.update({model: all_models.count(model) + cls.models_count.get(model, 0) for model in all_models})
        return list(dict.fromkeys([model if model else "default" for model in all_models]))
        cls.models_count.update({model: all_models.count(model) for model in all_models if all_models.count(model) > cls.models_count.get(model, 0)})
        return list(dict.fromkeys([model if model else cls.default_model for model in all_models]))

    @classmethod
    async def create_async_generator(
@@ -117,7 +121,8 @@ class AnyProvider(AsyncGeneratorProvider, ProviderModelMixin):
            providers = model.split(":")
            model = providers.pop()
            providers = [getattr(Provider, provider) for provider in providers]
        elif not model or model == "default":
        elif not model or model == cls.default_model:
            model = ""
        has_image = False
        has_audio = "audio" in kwargs
        if not has_audio and media is not None:
@@ -133,11 +138,11 @@ class AnyProvider(AsyncGeneratorProvider, ProviderModelMixin):
            else:
                providers = models.default.best_provider.providers
        else:
            for provider in [OpenaiChat, HuggingSpace, Cloudflare, LMArenaProvider, PerplexityLabs, Gemini, Grok, DeepSeekAPI, FreeRouter, Blackbox]:
                if provider.working and (model if model else "auto") in provider.get_models():
                    providers.append(provider)
            for provider in [HuggingFace, HuggingFaceMedia, LambdaChat, LMArenaProvider, CopilotAccount, PollinationsAI, DeepInfraChat]:
                if model in provider.model_aliases:
            for provider in [
                OpenaiChat, Cloudflare, LMArenaProvider, PerplexityLabs, Gemini, Grok, DeepSeekAPI, FreeRouter, Blackbox,
                HuggingFace, HuggingFaceMedia, HuggingSpace, LambdaChat, CopilotAccount, PollinationsAI, DeepInfraChat
            ]:
                if not model or model in provider.get_models() or model in provider.model_aliases:
                    providers.append(provider)
            if model in models.__models__:
                for provider in models.__models__[model][1]:
g4f/providers/base_provider.py

@@ -449,10 +449,12 @@ class AsyncAuthedProvider(AsyncGeneratorProvider, AuthFileMixin):
        cache_file = cls.get_cache_file()
        try:
            if cache_file.exists():
                with cache_file.open("r") as f:
                    data = f.read()
                if data:
                    auth_result = AuthResult(**json.loads(data))
                try:
                    with cache_file.open("r") as f:
                        auth_result = AuthResult(**json.load(f))
                except json.JSONDecodeError:
                    cache_file.unlink()
                    raise MissingAuthError(f"Invalid auth file: {cache_file}")
            else:
                raise MissingAuthError
            yield from to_sync_generator(cls.create_authed(model, messages, auth_result, **kwargs))
@@ -478,8 +480,12 @@ class AsyncAuthedProvider(AsyncGeneratorProvider, AuthFileMixin):
        cache_file = cls.get_cache_file()
        try:
            if cache_file.exists():
                with cache_file.open("r") as f:
                    auth_result = AuthResult(**json.load(f))
                try:
                    with cache_file.open("r") as f:
                        auth_result = AuthResult(**json.load(f))
                except json.JSONDecodeError:
                    cache_file.unlink()
                    raise MissingAuthError(f"Invalid auth file: {cache_file}")
            else:
                raise MissingAuthError
            response = to_async_iterator(cls.create_authed(model, messages, **kwargs, auth_result=auth_result))
g4f/providers/response.py

@@ -264,9 +264,10 @@ class YouTube(HiddenResponse):
        ]))

class AudioResponse(ResponseType):
    def __init__(self, data: Union[bytes, str]) -> None:
    def __init__(self, data: Union[bytes, str], **kwargs) -> None:
        """Initialize with audio data bytes."""
        self.data = data
        self.options = kwargs

    def to_uri(self) -> str:
        if isinstance(self.data, str):
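A tiny sketch of the widened constructor (the extra keyword arguments simply land in `options`, which is how EdgeTTS attaches `voice` and `prompt` above):

```python
from g4f.providers.response import AudioResponse

resp = AudioResponse("/media/hello.mp3", voice="en-US-AriaNeural", prompt="Hello")
print(resp.data)              # "/media/hello.mp3"
print(resp.options["voice"])  # "en-US-AriaNeural"
```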