From b68b9ff6be4fb44543bb63e5ec9d3b5b14b64fcd Mon Sep 17 00:00:00 2001
From: hlohaus <983577+hlohaus@users.noreply.github.com>
Date: Sat, 19 Apr 2025 06:23:46 +0200
Subject: [PATCH] feat: add audio generation support for multiple providers

- Added new examples for `client.media.generate` with `PollinationsAI`, `EdgeTTS`, and `Gemini` in `docs/media.md`
- Modified `PollinationsAI.py` to default to `default_audio_model` when audio data is present
- Adjusted `PollinationsAI.py` to conditionally construct the message list from `prompt` when media is being generated
- Rearranged `PollinationsAI.py` response handling to yield `save_response_media` after checking for non-JSON content types
- Added support in `EdgeTTS.py` to use default values for `language`, `locale`, and `format` from class attributes
- Improved voice selection logic in `EdgeTTS.py` to fall back to the default locale or language when not explicitly provided
- Updated `EdgeTTS.py` to yield `AudioResponse` with the `text` field included
- Modified `Gemini.py` to support `.ogx` audio generation when `model == "gemini-audio"` or `audio` is passed
- Used `format_image_prompt` in `Gemini.py` to create the audio prompt and save the audio file using `synthesize`
- Appended `AudioResponse` to `Gemini.py` for the audio generation flow
- Added a `save()` method to the `Image` class in `stubs.py` to support saving `/media/` files locally
- Changed `client/__init__.py` to fall back to `options["text"]` if `alt` is missing in `Images.create`
- Ensured `AudioResponse` in `copy_images.py` includes the `text` (prompt) field
- Added an `Annotated` fallback definition in `api/__init__.py` for compatibility with older Python versions
---
 docs/media.md                     | 24 ++++++++++++++++
 g4f/Provider/PollinationsAI.py    | 46 +++++++++++++++++--------------
 g4f/Provider/audio/EdgeTTS.py     | 19 +++++++------
 g4f/Provider/needs_auth/Gemini.py | 18 ++++++++++--
 g4f/api/__init__.py               |  8 +++++-
 g4f/client/__init__.py            |  3 +-
 g4f/client/stubs.py               |  4 +++
 g4f/image/copy_images.py          |  2 +-
 8 files changed, 90 insertions(+), 34 deletions(-)

diff --git a/docs/media.md b/docs/media.md
index 0083018a..b7979f3f 100644
--- a/docs/media.md
+++ b/docs/media.md
@@ -28,6 +28,30 @@ async def main():
 asyncio.run(main())
 ```
 
+#### **More examples for Generate Audio:**
+
+```python
+from g4f.client import Client
+
+from g4f.Provider import EdgeTTS, Gemini, PollinationsAI
+
+client = Client(provider=PollinationsAI)
+response = client.media.generate("Hello", audio={"voice": "alloy", "format": "mp3"})
+response.data[0].save("openai.mp3")
+
+client = Client(provider=PollinationsAI)
+response = client.media.generate("Hello", model="hypnosis-tracy")
+response.data[0].save("hypnosis.mp3")
+
+client = Client(provider=Gemini)
+response = client.media.generate("Hello", model="gemini-audio")
+response.data[0].save("gemini.ogx")
+
+client = Client(provider=EdgeTTS)
+response = client.media.generate("Hello", audio={"locale": "en-US"})
+response.data[0].save("edge-tts.mp3")
+```
+
 #### **Transcribe an Audio File:**
 
 Some providers in G4F support audio inputs in chat completions, allowing you to transcribe audio files by instructing the model accordingly. This example demonstrates how to use the `AsyncClient` to transcribe an audio file asynchronously:
diff --git a/g4f/Provider/PollinationsAI.py b/g4f/Provider/PollinationsAI.py
index 42e05017..d5f2af4f 100644
--- a/g4f/Provider/PollinationsAI.py
+++ b/g4f/Provider/PollinationsAI.py
@@ -177,7 +177,7 @@ class PollinationsAI(AsyncGeneratorProvider, ProviderModelMixin):
                 if is_data_an_audio(media_data, filename):
                     has_audio = True
                     break
-            model = next(iter(cls.audio_models)) if has_audio else model
+            model = cls.default_audio_model if has_audio else model
         try:
             model = cls.get_model(model)
         except ModelNotFoundError:
@@ -202,6 +202,11 @@ class PollinationsAI(AsyncGeneratorProvider, ProviderModelMixin):
             ):
                 yield chunk
         else:
+            if prompt is not None and len(messages) == 1:
+                messages = [{
+                    "role": "user",
+                    "content": prompt
+                }]
             async for result in cls._generate_text(
                 model=model,
                 messages=messages,
@@ -315,9 +320,6 @@ class PollinationsAI(AsyncGeneratorProvider, ProviderModelMixin):
                 })
             async with session.post(url, json=data) as response:
                 await raise_for_status(response)
-                async for chunk in save_response_media(response, format_image_prompt(messages), [model]):
-                    yield chunk
-                    return
                 if response.headers["content-type"].startswith("text/plain"):
                     yield await response.text()
                     return
@@ -339,20 +341,24 @@ class PollinationsAI(AsyncGeneratorProvider, ProviderModelMixin):
                             finish_reason = choice.get("finish_reason")
                             if finish_reason:
                                 yield FinishReason(finish_reason)
-                return
-                result = await response.json()
-                if "choices" in result:
-                    choice = result["choices"][0]
-                    message = choice.get("message", {})
-                    content = message.get("content", "")
-                    if content:
-                        yield content
-                    if "tool_calls" in message:
-                        yield ToolCalls(message["tool_calls"])
+                elif response.headers["content-type"].startswith("application/json"):
+                    result = await response.json()
+                    if "choices" in result:
+                        choice = result["choices"][0]
+                        message = choice.get("message", {})
+                        content = message.get("content", "")
+                        if content:
+                            yield content
+                        if "tool_calls" in message:
+                            yield ToolCalls(message["tool_calls"])
+                    else:
+                        raise ResponseError(result)
+                    if result.get("usage") is not None:
+                        yield Usage(**result["usage"])
+                    finish_reason = choice.get("finish_reason")
+                    if finish_reason:
+                        yield FinishReason(finish_reason)
                 else:
-                    raise ResponseError(result)
-                if result.get("usage") is not None:
-                    yield Usage(**result["usage"])
-                finish_reason = choice.get("finish_reason")
-                if finish_reason:
-                    yield FinishReason(finish_reason)
+                    async for chunk in save_response_media(response, format_image_prompt(messages), [model]):
+                        yield chunk
+                        return
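
The rearranged `_generate_text` handler now dispatches on the response `content-type`: plain text first, then SSE, then JSON, with `save_response_media` as the final fallback. The sketch below only illustrates that dispatch order under assumed conditions; `handle_response` is a hypothetical helper and the bare `json.loads` SSE handling is an assumption, not code from this patch.

```python
import json

async def handle_response(response):
    """Minimal sketch of the content-type dispatch order used after this patch.

    `response` is assumed to be an aiohttp ClientResponse; `handle_response`
    itself is an illustrative helper, not part of PollinationsAI.py.
    """
    content_type = response.headers["content-type"]
    if content_type.startswith("text/plain"):
        # Plain text: yield the whole body once.
        yield await response.text()
    elif content_type.startswith("text/event-stream"):
        # Server-sent events: decode each "data:" line as a JSON chunk.
        async for line in response.content:
            if line.startswith(b"data:") and b"[DONE]" not in line:
                yield json.loads(line[len(b"data:"):])
    elif content_type.startswith("application/json"):
        # Non-streaming completion: a single JSON document.
        yield await response.json()
    else:
        # Anything else (audio/*, image/*, ...) is treated as media;
        # the real code hands the response to save_response_media here.
        yield await response.read()
```
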
"edge-tts" else None) if not voice: voices = await VoicesManager.create() if "locale" in audio: voices = voices.find(Locale=audio["locale"]) - elif "language" in audio: - if "-" in audio["language"]: - voices = voices.find(Locale=audio["language"]) + elif audio.get("language", cls.default_language) != cls.default_language: + if "-" in audio.get("language"): + voices = voices.find(Locale=audio.get("language")) else: - voices = voices.find(Language=audio["language"]) + voices = voices.find(Language=audio.get("language")) else: voices = voices.find(Locale=cls.default_locale) if not voices: raise ValueError(f"No voices found for language '{audio.get('language')}' and locale '{audio.get('locale')}'.") voice = random.choice(voices)["Name"] - format = audio.get("format", "mp3") + format = audio.get("format", cls.default_format) filename = get_filename([cls.default_model], prompt, f".{format}", prompt) target_path = os.path.join(get_media_dir(), filename) ensure_media_dir() @@ -69,4 +70,4 @@ class EdgeTTS(AsyncGeneratorProvider, ProviderModelMixin): communicate = edge_tts.Communicate(prompt, voice=voice, proxy=proxy, **extra_parameters) await communicate.save(target_path) - yield AudioResponse(f"/media/{filename}", voice=voice, prompt=prompt) \ No newline at end of file + yield AudioResponse(f"/media/{filename}", voice=voice, text=prompt) \ No newline at end of file diff --git a/g4f/Provider/needs_auth/Gemini.py b/g4f/Provider/needs_auth/Gemini.py index 554ecf70..78c27dbb 100644 --- a/g4f/Provider/needs_auth/Gemini.py +++ b/g4f/Provider/needs_auth/Gemini.py @@ -20,16 +20,17 @@ except ImportError: from ... import debug from ...typing import Messages, Cookies, MediaListType, AsyncResult, AsyncIterator -from ...providers.response import JsonConversation, Reasoning, RequestLogin, ImageResponse, YouTube +from ...providers.response import JsonConversation, Reasoning, RequestLogin, ImageResponse, YouTube, AudioResponse from ...requests.raise_for_status import raise_for_status from ...requests.aiohttp import get_connector from ...requests import get_nodriver +from ...image.copy_images import get_filename, get_media_dir, ensure_media_dir from ...errors import MissingAuthError from ...image import to_bytes from ...cookies import get_cookies_dir from ...tools.media import merge_media from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin -from ..helper import format_prompt, get_cookies, get_last_user_message +from ..helper import format_prompt, get_cookies, get_last_user_message, format_image_prompt from ... 
diff --git a/g4f/Provider/needs_auth/Gemini.py b/g4f/Provider/needs_auth/Gemini.py
index 554ecf70..78c27dbb 100644
--- a/g4f/Provider/needs_auth/Gemini.py
+++ b/g4f/Provider/needs_auth/Gemini.py
@@ -20,16 +20,17 @@ except ImportError:
 
 from ... import debug
 from ...typing import Messages, Cookies, MediaListType, AsyncResult, AsyncIterator
-from ...providers.response import JsonConversation, Reasoning, RequestLogin, ImageResponse, YouTube
+from ...providers.response import JsonConversation, Reasoning, RequestLogin, ImageResponse, YouTube, AudioResponse
 from ...requests.raise_for_status import raise_for_status
 from ...requests.aiohttp import get_connector
 from ...requests import get_nodriver
+from ...image.copy_images import get_filename, get_media_dir, ensure_media_dir
 from ...errors import MissingAuthError
 from ...image import to_bytes
 from ...cookies import get_cookies_dir
 from ...tools.media import merge_media
 from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin
-from ..helper import format_prompt, get_cookies, get_last_user_message
+from ..helper import format_prompt, get_cookies, get_last_user_message, format_image_prompt
 from ... import debug
 
 REQUEST_HEADERS = {
@@ -68,6 +69,7 @@ models = {
     "gemini-2.0-flash-exp": {"x-goog-ext-525001261-jspb": '[null,null,null,null,"f299729663a2343f"]'},
     "gemini-2.0-flash-thinking": {"x-goog-ext-525001261-jspb": '[null,null,null,null,"9c17b1863f581b8a"]'},
     "gemini-2.0-flash-thinking-with-apps": {"x-goog-ext-525001261-jspb": '[null,null,null,null,"f8f8f5ea629f5d37"]'},
+    "gemini-audio": {}
 }
 
 class Gemini(AsyncGeneratorProvider, ProviderModelMixin):
@@ -153,8 +155,20 @@ class Gemini(AsyncGeneratorProvider, ProviderModelMixin):
         return_conversation: bool = False,
         conversation: Conversation = None,
         language: str = "en",
+        prompt: str = None,
+        audio: dict = None,
         **kwargs
     ) -> AsyncResult:
+        if audio is not None or model == "gemini-audio":
+            prompt = format_image_prompt(messages, prompt)
+            filename = get_filename(["gemini"], prompt, ".ogx", prompt)
+            ensure_media_dir()
+            path = os.path.join(get_media_dir(), filename)
+            with open(path, "wb") as f:
+                async for chunk in cls.synthesize({"text": prompt}, proxy):
+                    f.write(chunk)
+            yield AudioResponse(f"/media/{filename}", text=prompt)
+            return
         cls._cookies = cookies or cls._cookies or get_cookies(GOOGLE_COOKIE_DOMAIN, False, True)
         if conversation is not None and getattr(conversation, "model", None) != model:
             conversation = None
diff --git a/g4f/api/__init__.py b/g4f/api/__init__.py
index 81d970c5..28587023 100644
--- a/g4f/api/__init__.py
+++ b/g4f/api/__init__.py
@@ -33,6 +33,12 @@ from starlette.responses import FileResponse
 from types import SimpleNamespace
 from typing import Union, Optional, List
 
+try:
+    from typing import Annotated
+except ImportError:
+    class Annotated:
+        pass
+
 import g4f
 import g4f.Provider
 import g4f.debug
@@ -52,7 +58,7 @@
 from .stubs import (
     ChatCompletionsConfig, ImageGenerationConfig,
     ProviderResponseModel, ModelResponseModel,
     ErrorResponseModel, ProviderResponseDetailModel,
-    FileResponseModel, UploadResponseModel, Annotated
+    FileResponseModel, UploadResponseModel
 )
 from g4f import debug
diff --git a/g4f/client/__init__.py b/g4f/client/__init__.py
index fc967bca..e813d0e7 100644
--- a/g4f/client/__init__.py
+++ b/g4f/client/__init__.py
@@ -463,7 +463,8 @@ class Images:
                 urls.extend(item.urls)
         if not urls:
             return None
-        return MediaResponse(urls, items[0].alt, items[0].options)
+        alt = getattr(items[0], "alt", items[0].options.get("text"))
+        return MediaResponse(urls, alt, items[0].options)
 
     def create_variation(
         self,
diff --git a/g4f/client/stubs.py b/g4f/client/stubs.py
index 8e183566..4175c92e 100644
--- a/g4f/client/stubs.py
+++ b/g4f/client/stubs.py
@@ -217,6 +217,10 @@ class Image(BaseModel):
             revised_prompt=revised_prompt
         ))
 
+    def save(self, path: str):
+        if self.url is not None and self.url.startswith("/media/"):
+            os.rename(self.url.replace("/media", get_media_dir()), path)
+
 class ImagesResponse(BaseModel):
     data: List[Image]
     model: str
diff --git a/g4f/image/copy_images.py b/g4f/image/copy_images.py
index 9405aaa0..2f400459 100644
--- a/g4f/image/copy_images.py
+++ b/g4f/image/copy_images.py
@@ -70,7 +70,7 @@ async def save_response_media(response: StreamResponse, prompt: str, tags: list[
         if response.method == "GET":
             media_url = f"{media_url}?url={str(response.url)}"
         if content_type.startswith("audio/"):
-            yield AudioResponse(media_url)
+            yield AudioResponse(media_url, text=prompt)
         elif content_type.startswith("video/"):
             yield VideoResponse(media_url, prompt)
         else:
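
Taken together with the new `docs/media.md` examples, the client-side pieces (the `gemini-audio` model entry, the `Image.save()` stub, and the `text` field on `AudioResponse`) can be exercised end to end. The snippet below simply recaps the documented `gemini-audio` example; it assumes valid Gemini cookies are available and that generated files are written under the local `/media/` directory.

```python
from g4f.client import Client
from g4f.Provider import Gemini

# Requires a logged-in Gemini session (cookies); see the provider's auth requirements.
client = Client(provider=Gemini)
response = client.media.generate("Hello", model="gemini-audio")

# Image.save() (added in stubs.py) renames the file out of /media/,
# so the copy in the media directory is gone after this call.
response.data[0].save("gemini.ogx")
```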