feat: add audio generation support for multiple providers

- Added new examples for `client.media.generate` with `PollinationsAI`, `EdgeTTS`, and `Gemini` in `docs/media.md`
- Modified `PollinationsAI.py` to default to `default_audio_model` when audio data is present
- Adjusted `PollinationsAI.py` to conditionally construct message list from `prompt` when media is being generated
- Rearranged `PollinationsAI.py` response handling so that `save_response_media` is only reached after the plain-text, event-stream, and JSON content types have been handled
- Added support in `EdgeTTS.py` to use default values for `language`, `locale`, and `format` from class attributes
- Improved voice selection logic in `EdgeTTS.py` to fallback to default locale or language when not explicitly provided
- Updated `EdgeTTS.py` to yield `AudioResponse` with `text` field included
- Modified `Gemini.py` to support `.ogx` audio generation when `model == "gemini-audio"` or `audio` is passed
- Used `format_image_prompt` in `Gemini.py` to create audio prompt and saved audio file using `synthesize`
- Yielded an `AudioResponse` from `Gemini.py` in the audio generation flow
- Added `save()` method to `Image` class in `stubs.py` to support saving `/media/` files locally
- Changed `client/__init__.py` to fall back to `options["text"]` if `alt` is missing in `Images.create`
- Ensured `AudioResponse` in `copy_images.py` includes the `text` (prompt) field
- Added `Annotated` fallback definition in `api/__init__.py` for compatibility with older Python versions
Author: hlohaus
Date: 2025-04-19 06:23:46 +02:00
Parent: 2f46008228
Commit: b68b9ff6be
8 changed files with 90 additions and 34 deletions
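The examples added to `docs/media.md` below use the synchronous `Client`. As a minimal sketch of the same flow in async code, assuming `AsyncClient` exposes the same `media.generate` API as the sync client (provider choice and output filename here are illustrative):

```python
import asyncio

from g4f.client import AsyncClient
from g4f.Provider import PollinationsAI

async def main():
    # Same call shape as the sync examples in docs/media.md, just awaited.
    client = AsyncClient(provider=PollinationsAI)
    response = await client.media.generate(
        "Hello", audio={"voice": "alloy", "format": "mp3"}
    )
    # Image.save() (added in stubs.py below) moves the generated file out of /media/.
    response.data[0].save("hello.mp3")

asyncio.run(main())
```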

docs/media.md

@@ -28,6 +28,30 @@ async def main():
     asyncio.run(main())
 ```
+
+#### **More examples for Generate Audio:**
+
+```python
+from g4f.client import Client
+from g4f.Provider import EdgeTTS, Gemini, PollinationsAI
+
+client = Client(provider=PollinationsAI)
+response = client.media.generate("Hello", audio={"voice": "alloy", "format": "mp3"})
+response.data[0].save("openai.mp3")
+
+client = Client(provider=PollinationsAI)
+response = client.media.generate("Hello", model="hypnosis-tracy")
+response.data[0].save("hypnosis.mp3")
+
+client = Client(provider=Gemini)
+response = client.media.generate("Hello", model="gemini-audio")
+response.data[0].save("gemini.ogx")
+
+client = Client(provider=EdgeTTS)
+response = client.media.generate("Hello", audio={"locale": "en-US"})
+response.data[0].save("edge-tts.mp3")
+```
+
 #### **Transcribe an Audio File:**
 Some providers in G4F support audio inputs in chat completions, allowing you to transcribe audio files by instructing the model accordingly. This example demonstrates how to use the `AsyncClient` to transcribe an audio file asynchronously:

PollinationsAI.py

@@ -177,7 +177,7 @@ class PollinationsAI(AsyncGeneratorProvider, ProviderModelMixin):
                 if is_data_an_audio(media_data, filename):
                     has_audio = True
                     break
-        model = next(iter(cls.audio_models)) if has_audio else model
+        model = cls.default_audio_model if has_audio else model
         try:
             model = cls.get_model(model)
         except ModelNotFoundError:
@@ -202,6 +202,11 @@ class PollinationsAI(AsyncGeneratorProvider, ProviderModelMixin):
             ):
                 yield chunk
         else:
+            if prompt is not None and len(messages) == 1:
+                messages = [{
+                    "role": "user",
+                    "content": prompt
+                }]
             async for result in cls._generate_text(
                 model=model,
                 messages=messages,
@@ -315,9 +320,6 @@ class PollinationsAI(AsyncGeneratorProvider, ProviderModelMixin):
                 })
             async with session.post(url, json=data) as response:
                 await raise_for_status(response)
-                async for chunk in save_response_media(response, format_image_prompt(messages), [model]):
-                    yield chunk
-                return
                 if response.headers["content-type"].startswith("text/plain"):
                     yield await response.text()
                     return
@@ -339,20 +341,24 @@ class PollinationsAI(AsyncGeneratorProvider, ProviderModelMixin):
                             finish_reason = choice.get("finish_reason")
                             if finish_reason:
                                 yield FinishReason(finish_reason)
-                    return
-                result = await response.json()
-                if "choices" in result:
-                    choice = result["choices"][0]
-                    message = choice.get("message", {})
-                    content = message.get("content", "")
-                    if content:
-                        yield content
-                    if "tool_calls" in message:
-                        yield ToolCalls(message["tool_calls"])
-                else:
-                    raise ResponseError(result)
-                if result.get("usage") is not None:
-                    yield Usage(**result["usage"])
-                finish_reason = choice.get("finish_reason")
-                if finish_reason:
-                    yield FinishReason(finish_reason)
+                elif response.headers["content-type"].startswith("application/json"):
+                    result = await response.json()
+                    if "choices" in result:
+                        choice = result["choices"][0]
+                        message = choice.get("message", {})
+                        content = message.get("content", "")
+                        if content:
+                            yield content
+                        if "tool_calls" in message:
+                            yield ToolCalls(message["tool_calls"])
+                    else:
+                        raise ResponseError(result)
+                    if result.get("usage") is not None:
+                        yield Usage(**result["usage"])
+                    finish_reason = choice.get("finish_reason")
+                    if finish_reason:
+                        yield FinishReason(finish_reason)
+                else:
+                    async for chunk in save_response_media(response, format_image_prompt(messages), [model]):
+                        yield chunk
+                    return
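For readers who find the reshuffled hunk above hard to follow, here is a simplified, standalone restatement of the new dispatch order (a sketch, not the provider's actual code): plain text first, then the SSE stream, then JSON, and only as a last resort the binary/media path that previously ran before any content-type check.

```python
from typing import AsyncIterator

import aiohttp

async def dispatch_response(response: aiohttp.ClientResponse) -> AsyncIterator[str]:
    """Sketch of the content-type dispatch introduced in _generate_text."""
    content_type = response.headers.get("content-type", "")
    if content_type.startswith("text/plain"):
        yield await response.text()
    elif content_type.startswith("text/event-stream"):
        # Stream SSE lines as they arrive.
        async for line in response.content:
            if line.startswith(b"data: "):
                yield line[6:].decode(errors="ignore")
    elif content_type.startswith("application/json"):
        # Non-streaming completions: a single JSON body.
        result = await response.json()
        yield result.get("choices", [{}])[0].get("message", {}).get("content", "")
    else:
        # Anything else (audio, images) would be written to the media directory,
        # which is what save_response_media does in the real implementation.
        yield f"<media response: {content_type}>"
```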

EdgeTTS.py

@@ -20,8 +20,9 @@ from ..helper import format_image_prompt
 class EdgeTTS(AsyncGeneratorProvider, ProviderModelMixin):
     label = "Edge TTS"
     working = has_edge_tts
-    default_model = "edge-tts"
+    default_language = "en"
     default_locale = "en-US"
+    default_format = "mp3"
 
     @classmethod
     def get_models(cls) -> list[str]:
@@ -38,29 +39,29 @@ class EdgeTTS(AsyncGeneratorProvider, ProviderModelMixin):
         messages: Messages,
         proxy: str = None,
         prompt: str = None,
-        audio: dict = {"voice": None, "format": "mp3"},
+        audio: dict = {},
         **kwargs
     ) -> AsyncResult:
         prompt = format_image_prompt(messages, prompt)
         if not prompt:
             raise ValueError("Prompt is empty.")
-        voice = audio.get("voice", model)
+        voice = audio.get("voice", model if model and model != "edge-tts" else None)
         if not voice:
             voices = await VoicesManager.create()
             if "locale" in audio:
                 voices = voices.find(Locale=audio["locale"])
-            elif "language" in audio:
-                if "-" in audio["language"]:
-                    voices = voices.find(Locale=audio["language"])
+            elif audio.get("language", cls.default_language) != cls.default_language:
+                if "-" in audio.get("language"):
+                    voices = voices.find(Locale=audio.get("language"))
                 else:
-                    voices = voices.find(Language=audio["language"])
+                    voices = voices.find(Language=audio.get("language"))
             else:
                 voices = voices.find(Locale=cls.default_locale)
             if not voices:
                 raise ValueError(f"No voices found for language '{audio.get('language')}' and locale '{audio.get('locale')}'.")
             voice = random.choice(voices)["Name"]
-        format = audio.get("format", "mp3")
+        format = audio.get("format", cls.default_format)
         filename = get_filename([cls.default_model], prompt, f".{format}", prompt)
         target_path = os.path.join(get_media_dir(), filename)
         ensure_media_dir()
@@ -69,4 +70,4 @@ class EdgeTTS(AsyncGeneratorProvider, ProviderModelMixin):
         communicate = edge_tts.Communicate(prompt, voice=voice, proxy=proxy, **extra_parameters)
         await communicate.save(target_path)
-        yield AudioResponse(f"/media/{filename}", voice=voice, prompt=prompt)
+        yield AudioResponse(f"/media/{filename}", voice=voice, text=prompt)
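In practice the reworked `audio` dict gives three ways to pick a voice. A small usage sketch, assuming the voice and locale values shown here (taken from the edge-tts voice catalogue) are available:

```python
from g4f.client import Client
from g4f.Provider import EdgeTTS

client = Client(provider=EdgeTTS)

# 1. Name a voice explicitly; the VoicesManager lookup is skipped.
response = client.media.generate("Hallo Welt", audio={"voice": "de-DE-KatjaNeural", "format": "mp3"})
response.data[0].save("katja.mp3")

# 2. Give only a locale (or a language such as "de"); a random matching voice is chosen.
response = client.media.generate("Hello world", audio={"locale": "en-GB"})
response.data[0].save("en-gb.mp3")

# 3. An empty audio dict falls back to default_locale ("en-US") and default_format ("mp3").
response = client.media.generate("Hello world", audio={})
response.data[0].save("default.mp3")
```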

Gemini.py

@@ -20,16 +20,17 @@ except ImportError:
 from ... import debug
 from ...typing import Messages, Cookies, MediaListType, AsyncResult, AsyncIterator
-from ...providers.response import JsonConversation, Reasoning, RequestLogin, ImageResponse, YouTube
+from ...providers.response import JsonConversation, Reasoning, RequestLogin, ImageResponse, YouTube, AudioResponse
 from ...requests.raise_for_status import raise_for_status
 from ...requests.aiohttp import get_connector
 from ...requests import get_nodriver
+from ...image.copy_images import get_filename, get_media_dir, ensure_media_dir
 from ...errors import MissingAuthError
 from ...image import to_bytes
 from ...cookies import get_cookies_dir
 from ...tools.media import merge_media
 from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin
-from ..helper import format_prompt, get_cookies, get_last_user_message
+from ..helper import format_prompt, get_cookies, get_last_user_message, format_image_prompt
 from ... import debug
 
 REQUEST_HEADERS = {
@@ -68,6 +69,7 @@ models = {
"gemini-2.0-flash-exp": {"x-goog-ext-525001261-jspb": '[null,null,null,null,"f299729663a2343f"]'}, "gemini-2.0-flash-exp": {"x-goog-ext-525001261-jspb": '[null,null,null,null,"f299729663a2343f"]'},
"gemini-2.0-flash-thinking": {"x-goog-ext-525001261-jspb": '[null,null,null,null,"9c17b1863f581b8a"]'}, "gemini-2.0-flash-thinking": {"x-goog-ext-525001261-jspb": '[null,null,null,null,"9c17b1863f581b8a"]'},
"gemini-2.0-flash-thinking-with-apps": {"x-goog-ext-525001261-jspb": '[null,null,null,null,"f8f8f5ea629f5d37"]'}, "gemini-2.0-flash-thinking-with-apps": {"x-goog-ext-525001261-jspb": '[null,null,null,null,"f8f8f5ea629f5d37"]'},
"gemini-audio": {}
} }
class Gemini(AsyncGeneratorProvider, ProviderModelMixin): class Gemini(AsyncGeneratorProvider, ProviderModelMixin):
@@ -153,8 +155,20 @@ class Gemini(AsyncGeneratorProvider, ProviderModelMixin):
         return_conversation: bool = False,
         conversation: Conversation = None,
         language: str = "en",
+        prompt: str = None,
+        audio: dict = None,
         **kwargs
     ) -> AsyncResult:
+        if audio is not None or model == "gemini-audio":
+            prompt = format_image_prompt(messages, prompt)
+            filename = get_filename(["gemini"], prompt, ".ogx", prompt)
+            ensure_media_dir()
+            path = os.path.join(get_media_dir(), filename)
+            with open(path, "wb") as f:
+                async for chunk in cls.synthesize({"text": prompt}, proxy):
+                    f.write(chunk)
+            yield AudioResponse(f"/media/{filename}", text=prompt)
+            return
         cls._cookies = cookies or cls._cookies or get_cookies(GOOGLE_COOKIE_DOMAIN, False, True)
         if conversation is not None and getattr(conversation, "model", None) != model:
             conversation = None

api/__init__.py

@@ -33,6 +33,12 @@ from starlette.responses import FileResponse
 from types import SimpleNamespace
 from typing import Union, Optional, List
 
+try:
+    from typing import Annotated
+except ImportError:
+    class Annotated:
+        pass
+
 import g4f
 import g4f.Provider
 import g4f.debug
@@ -52,7 +58,7 @@ from .stubs import (
     ChatCompletionsConfig, ImageGenerationConfig,
     ProviderResponseModel, ModelResponseModel,
     ErrorResponseModel, ProviderResponseDetailModel,
-    FileResponseModel, UploadResponseModel, Annotated
+    FileResponseModel, UploadResponseModel
 )
 
 from g4f import debug

client/__init__.py

@@ -463,7 +463,8 @@ class Images:
                 urls.extend(item.urls)
         if not urls:
             return None
-        return MediaResponse(urls, items[0].alt, items[0].options)
+        alt = getattr(items[0], "alt", items[0].options.get("text"))
+        return MediaResponse(urls, alt, items[0].options)
 
     def create_variation(
         self,

stubs.py

@@ -217,6 +217,10 @@ class Image(BaseModel):
             revised_prompt=revised_prompt
         ))
 
+    def save(self, path: str):
+        if self.url is not None and self.url.startswith("/media/"):
+            os.rename(self.url.replace("/media", get_media_dir()), path)
+
 class ImagesResponse(BaseModel):
     data: List[Image]
     model: str
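One usage note on the new `save()` helper: because it calls `os.rename`, the generated file is moved out of g4f's media directory rather than copied, so the original `/media/...` URL stops resolving afterwards (and a target on another filesystem would need `shutil.move` instead). A minimal sketch with illustrative file names:

```python
from g4f.client import Client
from g4f.Provider import PollinationsAI

client = Client(provider=PollinationsAI)
response = client.media.generate("Hello", audio={"voice": "alloy", "format": "mp3"})

item = response.data[0]      # an Image stub wrapping a "/media/..." URL
item.save("hello.mp3")       # os.rename: the file is moved out of /media/
```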

copy_images.py

@@ -70,7 +70,7 @@ async def save_response_media(response: StreamResponse, prompt: str, tags: list[
     if response.method == "GET":
         media_url = f"{media_url}?url={str(response.url)}"
     if content_type.startswith("audio/"):
-        yield AudioResponse(media_url)
+        yield AudioResponse(media_url, text=prompt)
     elif content_type.startswith("video/"):
         yield VideoResponse(media_url, prompt)
     else: