Mirror of https://github.com/xtekky/gpt4free.git
feat: enhance audio model handling and improve image URL resolution
- Updated `PollinationsAI` to exclude the "gemini" model from `audio_models`
- Added logic in `PollinationsAI` to expand `audio_models` with the voices of `default_audio_model`
- Appended voice names to the `text_models` list in `PollinationsAI` if present in `default_audio_model`
- Modified `PollinationsAI._generate_text` to inject `audio` parameters when a voice model is used
- Updated the `save_response_media` call to include the voice name in the model list
- Changed `OpenaiChat.get_generated_image` to support both `file-service://` and `sediment://` URLs using `conversation_id`
- Modified `OpenaiChat.create_messages` to optionally pass `prompt`
- Adjusted `OpenaiChat.run` to determine `prompt` explicitly and set messages accordingly
- Updated `OpenaiChat.iter_messages_line` to handle `None` in `fields.p` safely
- Passed `prompt` and `conversation_id` to `OpenaiChat.get_generated_image` inside the image parsing loop
- Fixed redirect logic in `backend_api.py` to safely handle a missing or empty `skip` query param
- Enhanced the `render` function in `website.py` to support live file serving via the `live` query param
- Added new route `/dist/<path:name>` to serve static files from `DIST_DIR` in `website.py`
- Adjusted `render` to include a `.live` suffix in the cache filename when applicable
- Modified HTML replacements in `render` to preserve the local `dist/` path when `add_origion` is True
PollinationsAI.py

@@ -105,9 +105,12 @@ class PollinationsAI(AsyncGeneratorProvider, ProviderModelMixin):
                 cls.audio_models = {
                     model.get("name"): model.get("voices")
                     for model in models
-                    if "output_modalities" in model and "audio" in model["output_modalities"]
+                    if "output_modalities" in model and "audio" in model["output_modalities"] and model.get("name") != "gemini"
                 }
 
+                if cls.default_audio_model in cls.audio_models:
+                    cls.audio_models = {**cls.audio_models, **{voice: {} for voice in cls.audio_models[cls.default_audio_model]}}
+
                 # Create a set of unique text models starting with default model
                 unique_text_models = cls.text_models.copy()
 
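Note: to make the effect of the expansion concrete, here is a minimal, self-contained sketch; the model names, voices, and the `default_audio_model` value "openai-audio" are illustrative assumptions, not taken from the patch.

    # Hypothetical API payload; only the shape matters.
    models = [
        {"name": "openai-audio", "output_modalities": ["audio"], "voices": ["alloy", "echo"]},
        {"name": "gemini", "output_modalities": ["audio"], "voices": ["puck"]},
    ]
    default_audio_model = "openai-audio"

    audio_models = {
        m.get("name"): m.get("voices")
        for m in models
        if "output_modalities" in m and "audio" in m["output_modalities"] and m.get("name") != "gemini"
    }
    # -> {'openai-audio': ['alloy', 'echo']}; "gemini" is filtered out

    if default_audio_model in audio_models:
        audio_models = {**audio_models, **{voice: {} for voice in audio_models[default_audio_model]}}
    # -> each voice of the default audio model becomes a selectable pseudo-model:
    #    {'openai-audio': ['alloy', 'echo'], 'alloy': {}, 'echo': {}}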
@@ -120,6 +123,9 @@ class PollinationsAI(AsyncGeneratorProvider, ProviderModelMixin):
                     if model_name and "input_modalities" in model and "text" in model["input_modalities"]:
                         unique_text_models.append(model_name)
 
+                if cls.default_audio_model in cls.audio_models:
+                    unique_text_models.extend([voice for voice in cls.audio_models[cls.default_audio_model]])
+
                 # Convert to list and update text_models
                 cls.text_models = list(dict.fromkeys(unique_text_models))
 
@@ -207,6 +213,11 @@ class PollinationsAI(AsyncGeneratorProvider, ProviderModelMixin):
                 "role": "user",
                 "content": prompt
             }]
+            if model and model in cls.audio_models[cls.default_audio_model]:
+                kwargs["audio"] = {
+                    "voice": model,
+                }
+                model = cls.default_audio_model
             async for result in cls._generate_text(
                 model=model,
                 messages=messages,
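Note: this is the request-side counterpart of the model-list expansion above. When a caller selects a voice as the model, the patch rewrites it to the default audio model and moves the voice into the `audio` request parameter. A reduced sketch (values are examples):

    audio_models = {"openai-audio": ["alloy", "echo"]}
    default_audio_model = "openai-audio"

    model, kwargs = "alloy", {}
    if model and model in audio_models[default_audio_model]:
        kwargs["audio"] = {"voice": model}   # voice travels as a request parameter
        model = default_audio_model          # actual model name sent to the API
    # model == 'openai-audio', kwargs == {'audio': {'voice': 'alloy'}}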
@@ -359,6 +370,6 @@ class PollinationsAI(AsyncGeneratorProvider, ProviderModelMixin):
                     if finish_reason:
                         yield FinishReason(finish_reason)
             else:
-                async for chunk in save_response_media(response, format_image_prompt(messages), [model]):
+                async for chunk in save_response_media(response, format_image_prompt(messages), [model, extra_parameters.get("audio", {}).get("voice")]):
                     yield chunk
                 return
OpenaiChat.py

@@ -254,16 +254,25 @@ class OpenaiChat(AsyncAuthedProvider, ProviderModelMixin):
         return messages
 
     @classmethod
-    async def get_generated_image(cls, session: StreamSession, auth_result: AuthResult, element: dict, prompt: str = None) -> ImageResponse:
+    async def get_generated_image(cls, session: StreamSession, auth_result: AuthResult, element: dict, prompt: str, conversation_id: str) -> ImageResponse:
         try:
             prompt = element["metadata"]["dalle"]["prompt"]
-            file_id = element["asset_pointer"].split("file-service://", 1)[1]
+        except IndexError:
+            pass
+        try:
+            file_id = element["asset_pointer"]
+            if "file-service://" in file_id:
+                file_id = file_id.split("file-service://", 1)[-1]
+                url = f"{cls.url}/backend-api/files/{file_id}/download"
+            else:
+                file_id = file_id.split("sediment://")[-1]
+                url = f"{cls.url}/backend-api/conversation/{conversation_id}/attachment/{file_id}/download"
         except TypeError:
             return
         except Exception as e:
-            raise RuntimeError(f"No Image: {e.__class__.__name__}: {e}")
+            raise RuntimeError(f"No Image: {element} - {e}")
         try:
-            async with session.get(f"{cls.url}/backend-api/files/{file_id}/download", headers=auth_result.headers) as response:
+            async with session.get(url, headers=auth_result.headers) as response:
                 cls._update_request_args(auth_result, session)
                 await raise_for_status(response)
                 download_url = (await response.json())["download_url"]
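Note: the two asset-pointer schemes now resolve to different download endpoints. A standalone sketch of the branching; the helper name and file IDs are ours, the patch does this inline:

    def resolve_download_url(base_url: str, asset_pointer: str, conversation_id: str) -> str:
        # Older image pointers look like "file-service://file-AbC123"
        if "file-service://" in asset_pointer:
            file_id = asset_pointer.split("file-service://", 1)[-1]
            return f"{base_url}/backend-api/files/{file_id}/download"
        # Newer pointers like "sediment://file_XyZ789" need the conversation id
        file_id = asset_pointer.split("sediment://")[-1]
        return f"{base_url}/backend-api/conversation/{conversation_id}/attachment/{file_id}/download"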
@@ -285,6 +294,7 @@ class OpenaiChat(AsyncAuthedProvider, ProviderModelMixin):
         media: MediaListType = None,
         return_conversation: bool = False,
         web_search: bool = False,
+        prompt: str = None,
         **kwargs
     ) -> AsyncResult:
         """
@@ -403,10 +413,11 @@ class OpenaiChat(AsyncAuthedProvider, ProviderModelMixin):
                     if conversation.conversation_id is not None:
                         data["conversation_id"] = conversation.conversation_id
                         debug.log(f"OpenaiChat: Use conversation: {conversation.conversation_id}")
+                    prompt = get_last_user_message(messages) if prompt is None else prompt
                     if action != "continue":
                         data["parent_message_id"] = getattr(conversation, "parent_message_id", conversation.message_id)
                         conversation.parent_message_id = None
-                        messages = messages if conversation.conversation_id is None else [{"role": "user", "content": get_last_user_message(messages)}]
+                        messages = messages if conversation.conversation_id is None else [{"role": "user", "content": prompt}]
                         data["messages"] = cls.create_messages(messages, image_requests, ["search"] if web_search else None)
                     headers = {
                         **cls._headers,
@@ -433,7 +444,7 @@ class OpenaiChat(AsyncAuthedProvider, ProviderModelMixin):
                     await raise_for_status(response)
                     buffer = u""
                     async for line in response.iter_lines():
-                        async for chunk in cls.iter_messages_line(session, auth_result, line, conversation, sources):
+                        async for chunk in cls.iter_messages_line(session, auth_result, line, conversation, sources, prompt):
                             if isinstance(chunk, str):
                                 chunk = chunk.replace("\ue203", "").replace("\ue204", "").replace("\ue206", "")
                                 buffer += chunk
@@ -475,7 +486,7 @@ class OpenaiChat(AsyncAuthedProvider, ProviderModelMixin):
             yield FinishReason(conversation.finish_reason)
 
     @classmethod
-    async def iter_messages_line(cls, session: StreamSession, auth_result: AuthResult, line: bytes, fields: Conversation, sources: Sources) -> AsyncIterator:
+    async def iter_messages_line(cls, session: StreamSession, auth_result: AuthResult, line: bytes, fields: Conversation, sources: Sources, prompt: str) -> AsyncIterator:
         if not line.startswith(b"data: "):
             return
         elif line.startswith(b"data: [DONE]"):
@@ -490,7 +501,7 @@ class OpenaiChat(AsyncAuthedProvider, ProviderModelMixin):
         if line["type"] == "title_generation":
             yield TitleGeneration(line["title"])
         fields.p = line.get("p", fields.p)
-        if fields.p.startswith("/message/content/thoughts"):
+        if fields.p is not None and fields.p.startswith("/message/content/thoughts"):
             if fields.p.endswith("/content"):
                 if fields.thoughts_summary:
                     yield Reasoning(token="", status=fields.thoughts_summary)
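Note: the `None` guard is needed because `fields.p` starts unset and `dict.get` only applies its default when the key is missing, so a delta without a usable `p` leaves it `None`. Minimal illustration:

    fields_p = None                         # initial state of Conversation.p
    line = {"type": "message"}              # a delta that carries no "p"
    fields_p = line.get("p", fields_p)      # -> still None
    # unguarded: fields_p.startswith("/message/content/thoughts") raises AttributeError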
@@ -539,7 +550,7 @@ class OpenaiChat(AsyncAuthedProvider, ProviderModelMixin):
                 generated_images = []
                 for element in c.get("parts"):
                     if isinstance(element, dict) and element.get("content_type") == "image_asset_pointer":
-                        image = cls.get_generated_image(session, auth_result, element)
+                        image = cls.get_generated_image(session, auth_result, element, prompt, fields.conversation_id)
                         generated_images.append(image)
                 for image_response in await asyncio.gather(*generated_images):
                     if image_response is not None:
backend_api.py

@@ -360,7 +360,7 @@ class Backend_Api(Api):
             if seed not in ["true", "True", "1"]:
                 random.seed(seed)
                 return redirect(f"/media/{random.choice(match_files)}"), 302
-            return redirect(f"/media/{match_files[int(request.args.get('skip', 0))]}", 302)
+            return redirect(f"/media/{match_files[int(request.args.get('skip') or 0)]}", 302)
 
         @app.route('/backend-api/v2/upload_cookies', methods=['POST'])
         def upload_cookies():
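Note the subtlety here: `request.args.get('skip', 0)` only falls back when the parameter is absent, so a request like `?skip=` (present but empty) yields `""`, and `int("")` raises `ValueError`. The `or 0` form covers both the missing and the empty case. Simulating `request.args` with a plain dict:

    args = {"skip": ""}          # client sent "?skip=" with no value
    args.get("skip", 0)          # -> "" (key exists, so the default is not used)
    args.get("skip") or 0        # -> 0  ("" and None are both falsy)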
website.py

@@ -3,28 +3,36 @@ from __future__ import annotations
 import os
 import requests
 from datetime import datetime
-from flask import send_from_directory, redirect
+from flask import send_from_directory, redirect, request
 from ...image.copy_images import secure_filename, get_media_dir, ensure_media_dir
 from ...errors import VersionNotFoundError
 from ... import version
 
 GPT4FREE_URL = "https://gpt4free.github.io"
+DIST_DIR = "./gpt4free.github.io/dist"
 
 def redirect_home():
     return redirect('/chat')
 
-def render(filename = "chat"):
+def render(filename = "chat", add_origion = True):
+    if request.args.get("live"):
+        add_origion = False
+        if os.path.exists(DIST_DIR):
+            path = os.path.abspath(os.path.join(os.path.dirname(DIST_DIR), (filename + ("" if "." in filename else ".html"))))
+            print( f"Debug mode: {path}")
+            return send_from_directory(os.path.dirname(path), os.path.basename(path))
     try:
        latest_version = version.utils.latest_version
    except VersionNotFoundError:
        latest_version = version.utils.current_version
    today = datetime.today().strftime('%Y-%m-%d')
-    cache_file = os.path.join(get_media_dir(), f"{today}.{secure_filename(filename)}.{version.utils.current_version}-{latest_version}.html")
+    cache_file = os.path.join(get_media_dir(), f"{today}.{secure_filename(filename)}.{version.utils.current_version}-{latest_version}{'.live' if add_origion else ''}.html")
     if not os.path.exists(cache_file):
         ensure_media_dir()
         html = requests.get(f"{GPT4FREE_URL}/{filename}.html").text
-        html = html.replace("../dist/", f"{GPT4FREE_URL}/dist/")
-        html = html.replace('"dist/', f"\"{GPT4FREE_URL}/dist/")
+        if add_origion:
+            html = html.replace("../dist/", f"dist/")
+            html = html.replace("\"dist/", f"\"{GPT4FREE_URL}/dist/")
         with open(cache_file, 'w', encoding='utf-8') as f:
             f.write(html)
     return send_from_directory(os.path.abspath(get_media_dir()), os.path.basename(cache_file))
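Usage note (our reading of the patch, not part of it): requesting a GUI page with the new query param, e.g. `/chat?live=1`, sets `add_origion = False` and, when a `./gpt4free.github.io` checkout is present in the working directory, serves the HTML straight from that checkout instead of the copy cached from `https://gpt4free.github.io`. With `add_origion` disabled, the `dist/` asset links are also left untouched so they resolve against the local `/dist/` route added below.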
@@ -57,6 +65,10 @@ class Website:
                 'function': redirect_home,
                 'methods': ['GET', 'POST']
             },
+            '/dist/<path:name>': {
+                'function': self._dist,
+                'methods': ['GET']
+            },
         }
 
     def _index(self, filename = "index"):
@@ -71,3 +83,6 @@ class Website:
     def _chat(self, filename = "chat"):
         filename = "chat/index" if filename == 'chat' else secure_filename(filename)
         return render(filename)
+
+    def _dist(self, name: str):
+        return send_from_directory(os.path.abspath(DIST_DIR), name)
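Together with the `live` mode above, this route lets locally served pages load their static assets: a request such as `GET /dist/js/app.js` (path is an illustrative example) is answered from the local `DIST_DIR` checkout via `send_from_directory`.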