feat: add EdgeTTS audio provider and global image→media refactor

- **Docs**
  - `docs/file.md`: update upload instructions to use inline `bucket` content parts instead of `tool_calls/bucket_tool`.
  - `docs/media.md`: add asynchronous audio transcription example, detailed explanation, and notes.

- **New audio provider**
  - Add `g4f/Provider/audio/EdgeTTS.py` implementing Edge Text‑to‑Speech (`EdgeTTS`).
  - Create `g4f/Provider/audio/__init__.py` for provider export.
  - Register provider in `g4f/Provider/__init__.py`.

- **Refactor image → media**
  - Introduce `generated_media/` directory and `get_media_dir()` helper in `g4f/image/copy_images.py`; add `ensure_media_dir()`; keep back‑compat with legacy `generated_images/`.
  - Replace `images_dir` references with `get_media_dir()` across:
    - `g4f/api/__init__.py`
    - `g4f/client/stubs.py`
    - `g4f/gui/server/api.py`
    - `g4f/gui/server/backend_api.py`
    - `g4f/image/copy_images.py`
  - Rename CLI/API config field/flag from `image_provider` to `media_provider` (`g4f/cli.py`, `g4f/api/__init__.py`, `g4f/client/__init__.py`).
  - Extend `g4f/image/__init__.py`
    - add `MEDIA_TYPE_MAP`, `get_extension()`
    - revise `is_allowed_extension()`, `to_input_audio()` to support wider media types.

- **Provider adjustments**
  - `g4f/Provider/ARTA.py`: swap `raise_error()` parameter order.
  - `g4f/Provider/Cloudflare.py`: drop unused `MissingRequirementsError` import; move `get_args_from_nodriver()` inside try; handle `FileNotFoundError`.

- **Core enhancements**
  - `g4f/providers/any_provider.py`: use `default_model` instead of literal `"default"`; broaden model/provider matching; update model list cleanup.
  - `g4f/models.py`: safeguard provider count logic when model name is falsy.
  - `g4f/providers/base_provider.py`: catch `json.JSONDecodeError` when reading auth cache, delete corrupted file.
  - `g4f/providers/response.py`: allow `AudioResponse` to accept extra kwargs.

- **Misc**
  - Remove obsolete `g4f/image.py`.
  - `g4f/Provider/Cloudflare.py`, `g4f/client/types.py`: minor whitespace and import tidy‑ups.
This commit is contained in:
hlohaus
2025-04-19 03:20:57 +02:00
parent 0a070bdf10
commit e83282fc4b
23 changed files with 253 additions and 387 deletions

View File

@@ -180,23 +180,17 @@ fileInput.addEventListener('change', () => {
**Integrating with `ChatCompletion`:**
To incorporate file uploads into your client applications, include the `tool_calls` parameter in your chat completion requests, using the `bucket_tool` function. The `bucket_id` is passed as a JSON object within your prompt.
To incorporate file uploads into your client applications, include the `bucket` in your chat completion requests, using inline content parts.
```json
{
"messages": [
{
"role": "user",
"content": "Answer this question using the files in the specified bucket: ...your question...\n{\"bucket_id\": \"your_actual_bucket_id\"}"
}
],
"tool_calls": [
{
"function": {
"name": "bucket_tool"
},
"type": "function"
"content": [
{"type": "text", "text": "Answer this question using the files in the specified bucket: ...your question..."},
{"bucket_id": "your_actual_bucket_id"}
]
}
]
}

View File

@@ -30,6 +30,8 @@ asyncio.run(main())
#### **Transcribe an Audio File:**
Some providers in G4F support audio inputs in chat completions, allowing you to transcribe audio files by instructing the model accordingly. This example demonstrates how to use the `AsyncClient` to transcribe an audio file asynchronously:
```python
import asyncio
from g4f.client import AsyncClient
@@ -41,15 +43,32 @@ async def main():
with open("audio.wav", "rb") as audio_file:
response = await client.chat.completions.create(
messages="Transcribe this audio",
provider=g4f.Provider.Microsoft_Phi_4,
media=[[audio_file, "audio.wav"]],
modalities=["text"],
)
print(response.choices[0].message.content)
asyncio.run(main())
print(response.choices[0].message.content)
if __name__ == "__main__":
asyncio.run(main())
```
#### Explanation
- **Client Initialization**: An `AsyncClient` instance is created with a provider that supports audio inputs, such as `PollinationsAI` or `Microsoft_Phi_4`.
- **File Handling**: The audio file (`audio.wav`) is opened in binary read mode (`"rb"`) using a context manager (`with` statement) to ensure proper file closure after use.
- **API Call**: The `chat.completions.create` method is called with:
- `messages`: Containing a user message instructing the model to transcribe the audio.
- `media`: A list of lists, where each inner list contains the file object and its name (`[[audio_file, "audio.wav"]]`).
- `modalities=["text"]`: Specifies that the output should be text (the transcription).
- **Response**: The transcription is extracted from `response.choices[0].message.content` and printed.
#### Notes
- **Provider Support**: Ensure the chosen provider (e.g., `PollinationsAI` or `Microsoft_Phi_4`) supports audio inputs in chat completions. Not all providers may offer this functionality.
- **File Path**: Replace `"audio.wav"` with the path to your own audio file. The file format (e.g., WAV) should be compatible with the provider.
- **Model Selection**: If `g4f.models.default` does not support audio transcription, you may need to specify a model that does (consult the provider's documentation for supported models).
This example complements the guide by showcasing how to handle audio inputs asynchronously, expanding on the multimodal capabilities of the G4F AsyncClient API.
---
### 2. **Image Generation**

View File

@@ -203,7 +203,7 @@ class ARTA(AsyncGeneratorProvider, ProviderModelMixin):
else:
raise ResponseError(f"Image generation failed with status: {status}")
async def raise_error(response: ClientResponse, message: str):
async def raise_error(message: str, response: ClientResponse):
if response.ok:
return
error_text = await response.text()

View File

@@ -20,7 +20,7 @@ from ..cookies import get_cookies_dir
from .helper import format_image_prompt, render_messages
from ..providers.response import JsonConversation, ImageResponse
from ..tools.media import merge_media
from ..errors import RateLimitError
from ..errors import RateLimitError, NoValidHarFileError
from .. import debug
class Conversation(JsonConversation):
@@ -470,6 +470,8 @@ class Blackbox(AsyncGeneratorProvider, ProviderModelMixin):
except Exception as e:
debug.log(f"Blackbox: Error reading HAR file {file}: {e}")
return None
except NoValidHarFileError:
pass
except Exception as e:
debug.log(f"Blackbox: Error searching HAR files: {e}")
return None

View File

@@ -8,7 +8,7 @@ from .base_provider import AsyncGeneratorProvider, ProviderModelMixin, AuthFileM
from ..requests import Session, StreamSession, get_args_from_nodriver, raise_for_status, merge_cookies
from ..requests import DEFAULT_HEADERS, has_nodriver, has_curl_cffi
from ..providers.response import FinishReason, Usage
from ..errors import ResponseStatusError, ModelNotFoundError, MissingRequirementsError
from ..errors import ResponseStatusError, ModelNotFoundError
from .. import debug
from .helper import render_messages
@@ -72,11 +72,11 @@ class Cloudflare(AsyncGeneratorProvider, ProviderModelMixin, AuthFileMixin):
except ResponseStatusError as f:
if has_nodriver:
get_running_loop(check_nested=True)
args = get_args_from_nodriver(cls.url)
try:
args = get_args_from_nodriver(cls.url)
cls._args = asyncio.run(args)
read_models()
except RuntimeError as e:
except (RuntimeError, FileNotFoundError) as e:
cls.models = cls.fallback_models
debug.log(f"Nodriver is not available: {type(e).__name__}: {e}")
else:

View File

@@ -28,6 +28,10 @@ try:
from .mini_max import HailuoAI, MiniMax
except ImportError as e:
debug.error("MiniMax providers not loaded:", e)
try:
from .audio import EdgeTTS
except ImportError as e:
debug.error("Audio providers not loaded:", e)
try:
from .AllenAI import AllenAI

View File

@@ -0,0 +1,75 @@
from __future__ import annotations
import os
import random
import asyncio
try:
import edge_tts
from edge_tts import VoicesManager
has_edge_tts = True
except ImportError:
has_edge_tts = False
from ...typing import AsyncResult, Messages
from ...providers.response import AudioResponse
from ...image.copy_images import get_filename, get_media_dir, ensure_media_dir
from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin
from ..helper import format_image_prompt
class EdgeTTS(AsyncGeneratorProvider, ProviderModelMixin):
label = "Edge TTS"
working = has_edge_tts
default_model = "edge-tts"
default_locale = "en-US"
@classmethod
def get_models(cls) -> list[str]:
if not cls.models:
voices = asyncio.run(VoicesManager.create())
cls.default_model = voices.find(Locale=cls.default_locale)[0]["Name"]
cls.models = [voice["Name"] for voice in voices.voices]
return cls.models
@classmethod
async def create_async_generator(
cls,
model: str,
messages: Messages,
proxy: str = None,
prompt: str = None,
language: str = None,
locale: str = None,
audio: dict = {"voice": None, "format": "mp3"},
extra_parameters: list[str] = ["rate", "volume", "pitch"],
**kwargs
) -> AsyncResult:
prompt = format_image_prompt(messages, prompt)
if not prompt:
raise ValueError("Prompt is empty.")
voice = audio.get("voice", model)
if not voice:
voices = await VoicesManager.create()
if locale is None:
if language is None:
voices = voices.find(Locale=cls.default_locale)
elif "-" in language:
voices = voices.find(Locale=language)
else:
voices = voices.find(Language=language)
else:
voices = voices.find(Locale=locale)
if not voices:
raise ValueError(f"No voices found for language '{language}' and locale '{locale}'.")
voice = random.choice(voices)["Name"]
format = audio.get("format", "mp3")
filename = get_filename([cls.default_model], prompt, f".{format}", prompt)
target_path = os.path.join(get_media_dir(), filename)
ensure_media_dir()
extra_parameters = {param: kwargs[param] for param in extra_parameters if param in kwargs}
communicate = edge_tts.Communicate(prompt, voice=voice, proxy=proxy, **extra_parameters)
await communicate.save(target_path)
yield AudioResponse(f"/media/{filename}", voice=voice, prompt=prompt)

View File

@@ -0,0 +1 @@
from .EdgeTTS import EdgeTTS

View File

@@ -40,7 +40,7 @@ from g4f.client import AsyncClient, ChatCompletion, ImagesResponse, convert_to_p
from g4f.providers.response import BaseConversation, JsonConversation
from g4f.client.helper import filter_none
from g4f.image import is_data_an_media, EXTENSIONS_MAP
from g4f.image.copy_images import images_dir, copy_media, get_source_url
from g4f.image.copy_images import get_media_dir, copy_media, get_source_url
from g4f.errors import ProviderNotFoundError, ModelNotFoundError, MissingAuthError, NoValidHarFileError
from g4f.cookies import read_cookie_files, get_cookies_dir
from g4f.providers.types import ProviderType
@@ -130,7 +130,7 @@ class AppConfig:
ignore_cookie_files: bool = False
model: str = None
provider: str = None
image_provider: str = None
media_provider: str = None
proxy: str = None
gui: bool = False
demo: bool = False
@@ -419,12 +419,13 @@ class Api:
):
if config.provider is None:
config.provider = provider
if config.provider is None:
config.provider = AppConfig.media_provider
if credentials is not None and credentials.credentials != "secret":
config.api_key = credentials.credentials
try:
response = await self.client.images.generate(
**config.dict(exclude_none=True),
provider=AppConfig.image_provider if config.provider is None else config.provider
)
for image in response.data:
if hasattr(image, "url") and image.url.startswith("/"):
@@ -562,9 +563,9 @@ class Api:
HTTP_404_NOT_FOUND: {}
})
async def get_media(filename, request: Request):
target = os.path.join(images_dir, os.path.basename(filename))
target = os.path.join(get_media_dir(), os.path.basename(filename))
if not os.path.isfile(target):
other_name = os.path.join(images_dir, os.path.basename(quote_plus(filename)))
other_name = os.path.join(get_media_dir(), os.path.basename(quote_plus(filename)))
if os.path.isfile(other_name):
target = other_name
ext = os.path.splitext(filename)[1][1:]
@@ -627,7 +628,7 @@ class Api:
def format_exception(e: Union[Exception, str], config: Union[ChatCompletionsConfig, ImageGenerationConfig] = None, image: bool = False) -> str:
last_provider = {} if not image else g4f.get_last_provider(True)
provider = (AppConfig.image_provider if image else AppConfig.provider)
provider = (AppConfig.media_provider if image else AppConfig.provider)
model = AppConfig.model
if config is not None:
if config.provider is not None:

View File

@@ -16,7 +16,7 @@ def get_api_parser():
api_parser.add_argument("--model", default=None, help="Default model for chat completion. (incompatible with --reload and --workers)")
api_parser.add_argument("--provider", choices=[provider.__name__ for provider in Provider.__providers__ if provider.working],
default=None, help="Default provider for chat completion. (incompatible with --reload and --workers)")
api_parser.add_argument("--image-provider", choices=[provider.__name__ for provider in Provider.__providers__ if provider.working and hasattr(provider, "image_models")],
api_parser.add_argument("--media-provider", choices=[provider.__name__ for provider in Provider.__providers__ if provider.working and bool(getattr(provider, "image_models", False))],
default=None, help="Default provider for image generation. (incompatible with --reload and --workers)"),
api_parser.add_argument("--proxy", default=None, help="Default used proxy. (incompatible with --reload and --workers)")
api_parser.add_argument("--workers", type=int, default=None, help="Number of workers.")
@@ -59,7 +59,7 @@ def run_api_args(args):
ignored_providers=args.ignored_providers,
g4f_api_key=args.g4f_api_key,
provider=args.provider,
image_provider=args.image_provider,
media_provider=args.media_provider,
proxy=args.proxy,
model=args.model,
gui=args.gui,

View File

@@ -1,5 +1,6 @@
from __future__ import annotations
import os
import time
import random
import string
@@ -8,7 +9,7 @@ import aiohttp
import base64
from typing import Union, AsyncIterator, Iterator, Awaitable, Optional
from ..image.copy_images import copy_media
from ..image.copy_images import copy_media, get_media_dir
from ..typing import Messages, ImageType
from ..providers.types import ProviderType, BaseRetryProvider
from ..providers.response import *
@@ -16,11 +17,11 @@ from ..errors import NoMediaResponseError
from ..providers.retry_provider import IterListProvider
from ..providers.asyncio import to_sync_generator
from ..providers.any_provider import AnyProvider
from ..Provider.needs_auth import BingCreateImages, OpenaiAccount
from ..Provider import OpenaiAccount, PollinationsImage
from ..tools.run_tools import async_iter_run_tools, iter_run_tools
from .stubs import ChatCompletion, ChatCompletionChunk, Image, ImagesResponse, UsageModel, ToolCallModel
from .models import ClientModels
from .types import IterResponse, ImageProvider, Client as BaseClient
from .types import IterResponse, Client as BaseClient
from .service import convert_to_provider
from .helper import find_stop, filter_json, filter_none, safe_aclose
from .. import debug
@@ -261,15 +262,15 @@ class Client(BaseClient):
def __init__(
self,
provider: Optional[ProviderType] = None,
image_provider: Optional[ImageProvider] = None,
media_provider: Optional[ProviderType] = None,
**kwargs
) -> None:
super().__init__(**kwargs)
self.chat: Chat = Chat(self, provider)
if image_provider is None:
image_provider = provider
self.models: ClientModels = ClientModels(self, provider, image_provider)
self.images: Images = Images(self, image_provider)
if media_provider is None:
media_provider = kwargs.get("image_provider", provider)
self.models: ClientModels = ClientModels(self, provider, media_provider)
self.images: Images = Images(self, media_provider)
self.media: Images = self.images
class Completions:
@@ -364,7 +365,7 @@ class Images:
"""
return asyncio.run(self.async_generate(prompt, model, provider, response_format, proxy, **kwargs))
async def get_provider_handler(self, model: Optional[str], provider: Optional[ImageProvider], default: ImageProvider) -> ImageProvider:
async def get_provider_handler(self, model: Optional[str], provider: Optional[ProviderType], default: ProviderType) -> ProviderType:
if provider is None:
provider_handler = self.provider
if provider_handler is None:
@@ -387,7 +388,7 @@ class Images:
api_key: Optional[str] = None,
**kwargs
) -> ImagesResponse:
provider_handler = await self.get_provider_handler(model, provider, BingCreateImages)
provider_handler = await self.get_provider_handler(model, provider, PollinationsImage)
provider_name = provider_handler.__name__ if hasattr(provider_handler, "__name__") else type(provider_handler).__name__
if proxy is None:
proxy = self.client.proxy
@@ -407,20 +408,17 @@ class Images:
debug.error(f"{provider.__name__} {type(e).__name__}: {e}")
else:
response = await self._generate_image_response(provider_handler, provider_name, model, prompt, proxy=proxy, api_key=api_key, **kwargs)
if isinstance(response, MediaResponse):
return await self._process_image_response(
response,
model,
provider_name,
response_format,
proxy
)
if response is None:
if error is not None:
raise error
raise NoMediaResponseError(f"No image response from {provider_name}")
raise NoMediaResponseError(f"Unexpected response type: {type(response)}")
raise NoMediaResponseError(f"No media response from {provider_name}")
return await self._process_image_response(
response,
model,
provider_name,
response_format,
proxy
)
async def _generate_image_response(
self,
@@ -441,7 +439,7 @@ class Images:
prompt=prompt,
**kwargs
):
if isinstance(item, MediaResponse):
if isinstance(item, (MediaResponse, AudioResponse)):
items.append(item)
elif hasattr(provider_handler, "create_completion"):
for item in provider_handler.create_completion(
@@ -451,13 +449,15 @@ class Images:
prompt=prompt,
**kwargs
):
if isinstance(item, MediaResponse):
if isinstance(item, (MediaResponse, AudioResponse)):
items.append(item)
else:
raise ValueError(f"Provider {provider_name} does not support image generation")
urls = []
for item in items:
if isinstance(item.urls, str):
if isinstance(item, AudioResponse):
urls.append(item.to_uri())
elif isinstance(item.urls, str):
urls.append(item.urls)
elif isinstance(item.urls, list):
urls.extend(item.urls)
@@ -508,14 +508,11 @@ class Images:
debug.error(f"{provider.__name__} {type(e).__name__}: {e}")
else:
response = await self._generate_image_response(provider_handler, provider_name, model, prompt, **kwargs)
if isinstance(response, MediaResponse):
return await self._process_image_response(response, model, provider_name, response_format, proxy)
if response is None:
if error is not None:
raise error
raise NoMediaResponseError(f"No image response from {provider_name}")
raise NoMediaResponseError(f"Unexpected response type: {type(response)}")
raise NoMediaResponseError(f"No media response from {provider_name}")
return await self._process_image_response(response, model, provider_name, response_format, proxy)
async def _process_image_response(
self,
@@ -531,12 +528,16 @@ class Images:
elif response_format == "b64_json":
# Convert URLs directly to base64 without saving
async def get_b64_from_url(url: str) -> Image:
if url.startswith("/media/"):
with open(os.path.join(get_media_dir(), os.path.basename(url)), "wb") as f:
b64_data = base64.b64encode(f.read()).decode()
return Image.model_construct(b64_json=b64_data, revised_prompt=response.alt)
async with aiohttp.ClientSession(cookies=response.get("cookies")) as session:
async with session.get(url, proxy=proxy) as resp:
if resp.status == 200:
image_data = await resp.read()
b64_data = base64.b64encode(image_data).decode()
b64_data = base64.b64encode(await resp.read()).decode()
return Image.model_construct(b64_json=b64_data, revised_prompt=response.alt)
return Image.model_construct(url=url, revised_prompt=response.alt)
images = await asyncio.gather(*[get_b64_from_url(image) for image in response.get_list()])
else:
# Save locally for None (default) case
@@ -554,15 +555,15 @@ class AsyncClient(BaseClient):
def __init__(
self,
provider: Optional[ProviderType] = None,
image_provider: Optional[ImageProvider] = None,
media_provider: Optional[ProviderType] = None,
**kwargs
) -> None:
super().__init__(**kwargs)
self.chat: AsyncChat = AsyncChat(self, provider)
if image_provider is None:
image_provider = provider
self.models: ClientModels = ClientModels(self, provider, image_provider)
self.images: AsyncImages = AsyncImages(self, image_provider)
if media_provider is None:
media_provider = kwargs.get("image_provider", provider)
self.models: ClientModels = ClientModels(self, provider, media_provider)
self.images: AsyncImages = AsyncImages(self, media_provider)
self.media: AsyncImages = self.images
class AsyncChat:

View File

@@ -5,7 +5,7 @@ from typing import Optional, List
from time import time
from ..image import extract_data_uri
from ..image.copy_images import images_dir
from ..image.copy_images import get_media_dir
from ..client.helper import filter_markdown
from .helper import filter_none
@@ -123,7 +123,7 @@ class ChatCompletionMessage(BaseModel):
def save(self, filepath: str, allowd_types = None):
if hasattr(self.content, "data"):
os.rename(self.content.data.replace("/media", images_dir), filepath)
os.rename(self.content.data.replace("/media", get_media_dir()), filepath)
return
if self.content.startswith("data:"):
with open(filepath, "wb") as f:

View File

@@ -6,7 +6,6 @@ from .stubs import ChatCompletion, ChatCompletionChunk
from ..providers.types import BaseProvider
from typing import Union, Iterator, AsyncIterator
ImageProvider = Union[BaseProvider, object]
Proxies = Union[dict, str]
IterResponse = Iterator[Union[ChatCompletion, ChatCompletionChunk]]
AsyncIterResponse = AsyncIterator[Union[ChatCompletion, ChatCompletionChunk]]
@@ -19,7 +18,7 @@ class Client():
**kwargs
) -> None:
self.api_key: str = api_key
self.proxies= proxies
self.proxies = proxies
self.proxy: str = self.get_proxy()
def get_proxy(self) -> Union[str, None]:

View File

@@ -8,7 +8,7 @@ from flask import send_from_directory
from inspect import signature
from ...errors import VersionNotFoundError, MissingAuthError
from ...image.copy_images import copy_media, ensure_images_dir, images_dir
from ...image.copy_images import copy_media, ensure_media_dir, get_media_dir
from ...tools.run_tools import iter_run_tools
from ... import Provider
from ...providers.base_provider import ProviderModelMixin
@@ -96,8 +96,8 @@ class Api:
}
def serve_images(self, name):
ensure_images_dir()
return send_from_directory(os.path.abspath(images_dir), name)
ensure_media_dir()
return send_from_directory(os.path.abspath(get_media_dir()), name)
def _prepare_conversation_kwargs(self, json_data: dict):
kwargs = {**json_data}

View File

@@ -25,7 +25,7 @@ from ...tools.run_tools import iter_run_tools
from ...errors import ProviderNotFoundError
from ...image import is_allowed_extension
from ...cookies import get_cookies_dir
from ...image.copy_images import secure_filename, get_source_url, images_dir
from ...image.copy_images import secure_filename, get_source_url, get_media_dir
from ... import ChatCompletion
from ... import models
from .api import Api
@@ -346,11 +346,12 @@ class Backend_Api(Api):
@app.route('/search/<search>', methods=['GET'])
def find_media(search: str):
safe_search = [secure_filename(chunk.lower()) for chunk in search.split("+")]
if not os.access(images_dir, os.R_OK):
media_dir = get_media_dir()
if not os.access(media_dir, os.R_OK):
return jsonify({"error": {"message": "Not found"}}), 404
if search not in self.match_files:
self.match_files[search] = {}
for root, _, files in os.walk(images_dir):
for root, _, files in os.walk(media_dir):
for file in files:
mime_type = is_allowed_extension(file)
if mime_type is not None:
@@ -438,7 +439,7 @@ class Backend_Api(Api):
def get_provider_models(self, provider: str):
api_key = request.headers.get("x_api_key")
api_base = request.headers.get("x_api_base")
ignored = request.headers.get("x_ignored").split()
ignored = request.headers.get("x_ignored", "").split()
models = super().get_provider_models(provider, api_key, api_base, ignored)
if models is None:
return "Provider not found", 404

View File

@@ -1,253 +0,0 @@
from __future__ import annotations
import os
import re
import io
import base64
from urllib.parse import quote_plus
from io import BytesIO
from pathlib import Path
try:
from PIL.Image import open as open_image, new as new_image
from PIL.Image import FLIP_LEFT_RIGHT, ROTATE_180, ROTATE_270, ROTATE_90
has_requirements = True
except ImportError:
has_requirements = False
from .typing import ImageType, Union, Image, Optional, Cookies
from .errors import MissingRequirementsError
from .requests.aiohttp import get_connector
ALLOWED_EXTENSIONS = {'png', 'jpg', 'jpeg', 'gif', 'webp', 'svg'}
EXTENSIONS_MAP: dict[str, str] = {
"image/png": "png",
"image/jpeg": "jpg",
"image/gif": "gif",
"image/webp": "webp",
}
# Define the directory for generated images
images_dir = "./generated_images"
def to_image(image: ImageType, is_svg: bool = False) -> Image:
"""
Converts the input image to a PIL Image object.
Args:
image (Union[str, bytes, Image]): The input image.
Returns:
Image: The converted PIL Image object.
"""
if not has_requirements:
raise MissingRequirementsError('Install "pillow" package for images')
if isinstance(image, str) and image.startswith("data:"):
is_data_uri_an_image(image)
image = extract_data_uri(image)
if is_svg:
try:
import cairosvg
except ImportError:
raise MissingRequirementsError('Install "cairosvg" package for svg images')
if not isinstance(image, bytes):
image = image.read()
buffer = BytesIO()
cairosvg.svg2png(image, write_to=buffer)
return open_image(buffer)
if isinstance(image, bytes):
is_accepted_format(image)
return open_image(BytesIO(image))
elif not isinstance(image, Image):
image = open_image(image)
image.load()
return image
return image
def is_allowed_extension(filename: str) -> bool:
"""
Checks if the given filename has an allowed extension.
Args:
filename (str): The filename to check.
Returns:
bool: True if the extension is allowed, False otherwise.
"""
return '.' in filename and \
filename.rsplit('.', 1)[1].lower() in ALLOWED_EXTENSIONS
def is_data_uri_an_image(data_uri: str) -> bool:
"""
Checks if the given data URI represents an image.
Args:
data_uri (str): The data URI to check.
Raises:
ValueError: If the data URI is invalid or the image format is not allowed.
"""
# Check if the data URI starts with 'data:image' and contains an image format (e.g., jpeg, png, gif)
if not re.match(r'data:image/(\w+);base64,', data_uri):
raise ValueError("Invalid data URI image.")
# Extract the image format from the data URI
image_format = re.match(r'data:image/(\w+);base64,', data_uri).group(1).lower()
# Check if the image format is one of the allowed formats (jpg, jpeg, png, gif)
if image_format not in ALLOWED_EXTENSIONS and image_format != "svg+xml":
raise ValueError("Invalid image format (from mime file type).")
def is_accepted_format(binary_data: bytes) -> str:
"""
Checks if the given binary data represents an image with an accepted format.
Args:
binary_data (bytes): The binary data to check.
Raises:
ValueError: If the image format is not allowed.
"""
if binary_data.startswith(b'\xFF\xD8\xFF'):
return "image/jpeg"
elif binary_data.startswith(b'\x89PNG\r\n\x1a\n'):
return "image/png"
elif binary_data.startswith(b'GIF87a') or binary_data.startswith(b'GIF89a'):
return "image/gif"
elif binary_data.startswith(b'\x89JFIF') or binary_data.startswith(b'JFIF\x00'):
return "image/jpeg"
elif binary_data.startswith(b'\xFF\xD8'):
return "image/jpeg"
elif binary_data.startswith(b'RIFF') and binary_data[8:12] == b'WEBP':
return "image/webp"
else:
raise ValueError("Invalid image format (from magic code).")
def extract_data_uri(data_uri: str) -> bytes:
"""
Extracts the binary data from the given data URI.
Args:
data_uri (str): The data URI.
Returns:
bytes: The extracted binary data.
"""
data = data_uri.split(",")[-1]
data = base64.b64decode(data)
return data
def get_orientation(image: Image) -> int:
"""
Gets the orientation of the given image.
Args:
image (Image): The image.
Returns:
int: The orientation value.
"""
exif_data = image.getexif() if hasattr(image, 'getexif') else image._getexif()
if exif_data is not None:
orientation = exif_data.get(274) # 274 corresponds to the orientation tag in EXIF
if orientation is not None:
return orientation
def process_image(image: Image, new_width: int, new_height: int) -> Image:
"""
Processes the given image by adjusting its orientation and resizing it.
Args:
image (Image): The image to process.
new_width (int): The new width of the image.
new_height (int): The new height of the image.
Returns:
Image: The processed image.
"""
# Fix orientation
orientation = get_orientation(image)
if orientation:
if orientation > 4:
image = image.transpose(FLIP_LEFT_RIGHT)
if orientation in [3, 4]:
image = image.transpose(ROTATE_180)
if orientation in [5, 6]:
image = image.transpose(ROTATE_270)
if orientation in [7, 8]:
image = image.transpose(ROTATE_90)
# Resize image
image.thumbnail((new_width, new_height))
# Remove transparency
if image.mode == "RGBA":
image.load()
white = new_image('RGB', image.size, (255, 255, 255))
white.paste(image, mask=image.split()[-1])
return white
# Convert to RGB for jpg format
elif image.mode != "RGB":
image = image.convert("RGB")
return image
def to_bytes(image: ImageType) -> bytes:
"""
Converts the given image to bytes.
Args:
image (ImageType): The image to convert.
Returns:
bytes: The image as bytes.
"""
if isinstance(image, bytes):
return image
elif isinstance(image, str) and image.startswith("data:"):
is_data_uri_an_image(image)
return extract_data_uri(image)
elif isinstance(image, Image):
bytes_io = BytesIO()
image.save(bytes_io, image.format)
image.seek(0)
return bytes_io.getvalue()
elif isinstance(image, (str, os.PathLike)):
return Path(image).read_bytes()
elif isinstance(image, Path):
return image.read_bytes()
else:
try:
image.seek(0)
except (AttributeError, io.UnsupportedOperation):
pass
return image.read()
def to_data_uri(image: ImageType) -> str:
if not isinstance(image, str):
data = to_bytes(image)
data_base64 = base64.b64encode(data).decode()
return f"data:{is_accepted_format(data)};base64,{data_base64}"
return image
class ImageDataResponse():
def __init__(
self,
images: Union[str, list],
alt: str,
):
self.images = images
self.alt = alt
def get_list(self) -> list[str]:
return [self.images] if isinstance(self.images, str) else self.images
class ImageRequest:
def __init__(
self,
options: dict = {}
):
self.options = options
def get(self, key: str):
return self.options.get(key)

View File

@@ -17,13 +17,6 @@ except ImportError:
from ..typing import ImageType, Union, Image
from ..errors import MissingRequirementsError
MEDIA_TYPE_MAP: dict[str, str] = {
"image/png": "png",
"image/jpeg": "jpg",
"image/gif": "gif",
"image/webp": "webp",
}
EXTENSIONS_MAP: dict[str, str] = {
# Image
"png": "image/png",
@@ -44,6 +37,8 @@ EXTENSIONS_MAP: dict[str, str] = {
"mp4": "video/mp4",
}
MEDIA_TYPE_MAP: dict[str, str] = {value: key for key, value in EXTENSIONS_MAP.items()}
def to_image(image: ImageType, is_svg: bool = False) -> Image:
"""
Converts the input image to a PIL Image object.
@@ -82,6 +77,12 @@ def to_image(image: ImageType, is_svg: bool = False) -> Image:
return image
def get_extension(filename: str) -> Optional[str]:
if '.' in filename:
ext = os.path.splitext(filename)[1][1:].lower()
return ext if ext in EXTENSIONS_MAP else None
return None
def is_allowed_extension(filename: str) -> Optional[str]:
"""
Checks if the given filename has an allowed extension.
@@ -92,8 +93,10 @@ def is_allowed_extension(filename: str) -> Optional[str]:
Returns:
bool: True if the extension is allowed, False otherwise.
"""
ext = os.path.splitext(filename)[1][1:].lower() if '.' in filename else None
return EXTENSIONS_MAP[ext] if ext in EXTENSIONS_MAP else None
extension = get_extension(filename)
if extension is None:
return None
return EXTENSIONS_MAP[extension]
def is_data_an_media(data, filename: str = None) -> str:
content_type = is_data_an_audio(data, filename)
@@ -105,12 +108,11 @@ def is_data_an_media(data, filename: str = None) -> str:
def is_data_an_audio(data_uri: str = None, filename: str = None) -> str:
if filename:
if filename.endswith(".wav"):
return "audio/wav"
elif filename.endswith(".mp3"):
return "audio/mpeg"
elif filename.endswith(".m4a"):
return "audio/m4a"
extension = get_extension(filename)
if extension is not None:
media_type = EXTENSIONS_MAP[extension]
if media_type.startswith("audio/"):
return media_type
if isinstance(data_uri, str):
audio_format = re.match(r'^data:(audio/\w+);base64,', data_uri)
if audio_format:
@@ -266,10 +268,13 @@ def to_data_uri(image: ImageType, filename: str = None) -> str:
def to_input_audio(audio: ImageType, filename: str = None) -> str:
if not isinstance(audio, str):
if filename is not None and (filename.endswith(".wav") or filename.endswith(".mp3")):
if filename is not None:
format = get_extension(filename)
if format is None:
raise ValueError("Invalid input audio")
return {
"data": base64.b64encode(to_bytes(audio)).decode(),
"format": "wav" if filename.endswith(".wav") else "mp3"
"format": format
}
raise ValueError("Invalid input audio")
audio = re.match(r'^data:audio/(\w+);base64,(.+?)', audio)

View File

@@ -21,6 +21,13 @@ from .. import debug
# Directory for storing generated images
images_dir = "./generated_images"
media_dir = "./generated_media"
def get_media_dir() -> str:#
"""Get the directory for storing generated media files"""
if os.access(images_dir, os.R_OK):
return images_dir
return media_dir
def get_media_extension(media: str) -> str:
"""Extract media file extension from URL or filename"""
@@ -34,9 +41,10 @@ def get_media_extension(media: str) -> str:
raise ValueError(f"Unsupported media extension: {extension} in: {media}")
return extension
def ensure_images_dir():
def ensure_media_dir():
"""Create images directory if it doesn't exist"""
os.makedirs(images_dir, exist_ok=True)
if not os.access(images_dir, os.R_OK):
os.makedirs(media_dir, exist_ok=True)
def get_source_url(image: str, default: str = None) -> str:
"""Extract original URL from image parameter if present"""
@@ -46,30 +54,27 @@ def get_source_url(image: str, default: str = None) -> str:
return decoded_url
return default
def is_valid_media_type(content_type: str) -> bool:
return content_type in MEDIA_TYPE_MAP or content_type.startswith("audio/") or content_type.startswith("video/")
async def save_response_media(response: StreamResponse, prompt: str, tags: list[str]) -> AsyncIterator:
"""Save media from response to local file and return URL"""
content_type = response.headers["content-type"]
if is_valid_media_type(content_type):
extension = MEDIA_TYPE_MAP[content_type] if content_type in MEDIA_TYPE_MAP else content_type[6:].replace("mpeg", "mp3")
if extension not in EXTENSIONS_MAP:
raise ValueError(f"Unsupported media type: {content_type}")
filename = get_filename(tags, prompt, f".{extension}", prompt)
target_path = os.path.join(images_dir, filename)
with open(target_path, 'wb') as f:
async for chunk in response.iter_content() if hasattr(response, "iter_content") else response.content.iter_any():
f.write(chunk)
media_url = f"/media/{filename}"
if response.method == "GET":
media_url = f"{media_url}?url={str(response.url)}"
if content_type.startswith("audio/"):
yield AudioResponse(media_url)
elif content_type.startswith("video/"):
yield VideoResponse(media_url, prompt)
else:
yield ImageResponse(media_url, prompt)
extension = MEDIA_TYPE_MAP.get(content_type)
if extension is None:
raise ValueError(f"Unsupported media type: {content_type}")
filename = get_filename(tags, prompt, f".{extension}", prompt)
target_path = os.path.join(get_media_dir(), filename)
ensure_media_dir()
with open(target_path, 'wb') as f:
async for chunk in response.iter_content() if hasattr(response, "iter_content") else response.content.iter_any():
f.write(chunk)
media_url = f"/media/{filename}"
if response.method == "GET":
media_url = f"{media_url}?url={str(response.url)}"
if content_type.startswith("audio/"):
yield AudioResponse(media_url)
elif content_type.startswith("video/"):
yield VideoResponse(media_url, prompt)
else:
yield ImageResponse(media_url, prompt)
def get_filename(tags: list[str], alt: str, extension: str, image: str) -> str:
return "".join((
@@ -97,7 +102,7 @@ async def copy_media(
"""
if add_url:
add_url = not cookies
ensure_images_dir()
ensure_media_dir()
async with ClientSession(
connector=get_connector(proxy=proxy),
@@ -113,7 +118,7 @@ async def copy_media(
if target_path is None:
# Build safe filename with full Unicode support
filename = get_filename(tags, alt, get_media_extension(image), image)
target_path = os.path.join(images_dir, filename)
target_path = os.path.join(get_media_dir(), filename)
try:
# Handle different image types
if image.startswith("data:"):
@@ -132,7 +137,7 @@ async def copy_media(
response.raise_for_status()
media_type = response.headers.get("content-type", "application/octet-stream")
if media_type not in ("application/octet-stream", "binary/octet-stream"):
if not is_valid_media_type(media_type):
if media_type not in MEDIA_TYPE_MAP:
raise ValueError(f"Unsupported media type: {media_type}")
with open(target_path, "wb") as f:
async for chunk in response.content.iter_any():

View File

@@ -1006,6 +1006,6 @@ __models__ = {
if model.best_provider is not None and model.best_provider.working
else [])
for model in ModelUtils.convert.values()]
if [p for p in providers if p.working]
if model.name and [True for provider in providers if provider.working]
}
_all_models = list(__models__.keys())

View File

@@ -37,7 +37,7 @@ class AnyProvider(AsyncGeneratorProvider, ProviderModelMixin):
cls.models_count = {
model: len(providers) for model, providers in model_with_providers.items() if len(providers) > 1
}
all_models = ["default"] + list(model_with_providers.keys())
all_models = [cls.default_model] + list(model_with_providers.keys())
for provider in [OpenaiChat, PollinationsAI, HuggingSpace, Cloudflare, PerplexityLabs, Gemini, Grok]:
if not provider.working or getattr(provider, "parent", provider.__name__) in ignored:
continue
@@ -63,6 +63,7 @@ class AnyProvider(AsyncGeneratorProvider, ProviderModelMixin):
).replace("-03-2025", ""
).replace("-20250219", ""
).replace("-20241022", ""
).replace("-20240904", ""
).replace("-2025-04-16", ""
).replace("-2025-04-14", ""
).replace("-0125", ""
@@ -72,10 +73,13 @@ class AnyProvider(AsyncGeneratorProvider, ProviderModelMixin):
).replace("-2409", ""
).replace("-2410", ""
).replace("-2411", ""
).replace("-1119", ""
).replace("-0919", ""
).replace("-02-24", ""
).replace("-03-25", ""
).replace("-03-26", ""
).replace("-01-21", ""
).replace("-002", ""
).replace(".1-", "-"
).replace("_", "."
).replace("c4ai-", ""
@@ -98,8 +102,8 @@ class AnyProvider(AsyncGeneratorProvider, ProviderModelMixin):
for provider in [Microsoft_Phi_4, PollinationsAI]:
if provider.working and getattr(provider, "parent", provider.__name__) not in ignored:
cls.audio_models.update(provider.audio_models)
cls.models_count.update({model: all_models.count(model) + cls.models_count.get(model, 0) for model in all_models})
return list(dict.fromkeys([model if model else "default" for model in all_models]))
cls.models_count.update({model: all_models.count(model) for model in all_models if all_models.count(model) > cls.models_count.get(model, 0)})
return list(dict.fromkeys([model if model else cls.default_model for model in all_models]))
@classmethod
async def create_async_generator(
@@ -117,7 +121,8 @@ class AnyProvider(AsyncGeneratorProvider, ProviderModelMixin):
providers = model.split(":")
model = providers.pop()
providers = [getattr(Provider, provider) for provider in providers]
elif not model or model == "default":
elif not model or model == cls.default_model:
model = ""
has_image = False
has_audio = "audio" in kwargs
if not has_audio and media is not None:
@@ -133,11 +138,11 @@ class AnyProvider(AsyncGeneratorProvider, ProviderModelMixin):
else:
providers = models.default.best_provider.providers
else:
for provider in [OpenaiChat, HuggingSpace, Cloudflare, LMArenaProvider, PerplexityLabs, Gemini, Grok, DeepSeekAPI, FreeRouter, Blackbox]:
if provider.working and (model if model else "auto") in provider.get_models():
providers.append(provider)
for provider in [HuggingFace, HuggingFaceMedia, LambdaChat, LMArenaProvider, CopilotAccount, PollinationsAI, DeepInfraChat]:
if model in provider.model_aliases:
for provider in [
OpenaiChat, Cloudflare, LMArenaProvider, PerplexityLabs, Gemini, Grok, DeepSeekAPI, FreeRouter, Blackbox,
HuggingFace, HuggingFaceMedia, HuggingSpace, LambdaChat, CopilotAccount, PollinationsAI, DeepInfraChat
]:
if not model or model in provider.get_models() or model in provider.model_aliases:
providers.append(provider)
if model in models.__models__:
for provider in models.__models__[model][1]:

View File

@@ -449,10 +449,12 @@ class AsyncAuthedProvider(AsyncGeneratorProvider, AuthFileMixin):
cache_file = cls.get_cache_file()
try:
if cache_file.exists():
with cache_file.open("r") as f:
data = f.read()
if data:
auth_result = AuthResult(**json.loads(data))
try:
with cache_file.open("r") as f:
auth_result = AuthResult(**json.load(f))
except json.JSONDecodeError:
cache_file.unlink()
raise MissingAuthError(f"Invalid auth file: {cache_file}")
else:
raise MissingAuthError
yield from to_sync_generator(cls.create_authed(model, messages, auth_result, **kwargs))
@@ -478,8 +480,12 @@ class AsyncAuthedProvider(AsyncGeneratorProvider, AuthFileMixin):
cache_file = cls.get_cache_file()
try:
if cache_file.exists():
with cache_file.open("r") as f:
auth_result = AuthResult(**json.load(f))
try:
with cache_file.open("r") as f:
auth_result = AuthResult(**json.load(f))
except json.JSONDecodeError:
cache_file.unlink()
raise MissingAuthError(f"Invalid auth file: {cache_file}")
else:
raise MissingAuthError
response = to_async_iterator(cls.create_authed(model, messages, **kwargs, auth_result=auth_result))

View File

@@ -264,9 +264,10 @@ class YouTube(HiddenResponse):
]))
class AudioResponse(ResponseType):
def __init__(self, data: Union[bytes, str]) -> None:
def __init__(self, data: Union[bytes, str], **kwargs) -> None:
"""Initialize with audio data bytes."""
self.data = data
self.options = kwargs
def to_uri(self) -> str:
if isinstance(self.data, str):