Support synthesize in Openai generator (#2394)

* Improve download of generated images, serve images in the api

* Add support for conversation handling in the api

* Add orginal prompt to image response

* Add download images option in gui, fix loading model list in Airforce

* Support speech synthesize in Openai generator
This commit is contained in:
H Lohaus
2024-11-21 05:00:08 +01:00
committed by GitHub
parent ffb4b0d162
commit eae317a166
8 changed files with 260 additions and 77 deletions

View File

@@ -7,7 +7,6 @@ import json
import base64 import base64
import time import time
import requests import requests
from aiohttp import ClientWebSocketResponse
from copy import copy from copy import copy
try: try:
@@ -28,7 +27,7 @@ from ...requests.raise_for_status import raise_for_status
from ...requests.aiohttp import StreamSession from ...requests.aiohttp import StreamSession
from ...image import ImageResponse, ImageRequest, to_image, to_bytes, is_accepted_format from ...image import ImageResponse, ImageRequest, to_image, to_bytes, is_accepted_format
from ...errors import MissingAuthError, ResponseError from ...errors import MissingAuthError, ResponseError
from ...providers.response import BaseConversation from ...providers.response import BaseConversation, FinishReason, SynthesizeData
from ..helper import format_cookies from ..helper import format_cookies
from ..openai.har_file import get_request_config, NoValidHarFileError from ..openai.har_file import get_request_config, NoValidHarFileError
from ..openai.har_file import RequestConfig, arkReq, arkose_url, start_url, conversation_url, backend_url, backend_anon_url from ..openai.har_file import RequestConfig, arkReq, arkose_url, start_url, conversation_url, backend_url, backend_anon_url
@@ -367,19 +366,13 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin):
Raises: Raises:
RuntimeError: If an error occurs during processing. RuntimeError: If an error occurs during processing.
""" """
await cls.login(proxy)
async with StreamSession( async with StreamSession(
proxy=proxy, proxy=proxy,
impersonate="chrome", impersonate="chrome",
timeout=timeout timeout=timeout
) as session: ) as session:
if cls._expires is not None and cls._expires < time.time():
cls._headers = cls._api_key = None
try:
await get_request_config(proxy)
cls._create_request_args(RequestConfig.cookies, RequestConfig.headers)
cls._set_api_key(RequestConfig.access_token)
except NoValidHarFileError as e:
await cls.nodriver_auth(proxy)
try: try:
image_request = await cls.upload_image(session, cls._headers, image, image_name) if image else None image_request = await cls.upload_image(session, cls._headers, image, image_name) if image else None
except Exception as e: except Exception as e:
@@ -469,18 +462,25 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin):
await asyncio.sleep(5) await asyncio.sleep(5)
continue continue
await raise_for_status(response) await raise_for_status(response)
async for chunk in cls.iter_messages_chunk(response.iter_lines(), session, conversation): if return_conversation:
if return_conversation: history_disabled = False
history_disabled = False yield conversation
return_conversation = False async for line in response.iter_lines():
yield conversation async for chunk in cls.iter_messages_line(session, line, conversation):
yield chunk yield chunk
if not history_disabled:
yield SynthesizeData(cls.__name__, {
"conversation_id": conversation.conversation_id,
"message_id": conversation.message_id,
"voice": "maple",
})
if auto_continue and conversation.finish_reason == "max_tokens": if auto_continue and conversation.finish_reason == "max_tokens":
conversation.finish_reason = None conversation.finish_reason = None
action = "continue" action = "continue"
await asyncio.sleep(5) await asyncio.sleep(5)
else: else:
break break
yield FinishReason(conversation.finish_reason)
if history_disabled and auto_continue: if history_disabled and auto_continue:
await cls.delete_conversation(session, cls._headers, conversation.conversation_id) await cls.delete_conversation(session, cls._headers, conversation.conversation_id)
@@ -541,10 +541,38 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin):
if "error" in line and line.get("error"): if "error" in line and line.get("error"):
raise RuntimeError(line.get("error")) raise RuntimeError(line.get("error"))
@classmethod
async def synthesize(cls, params: dict) -> AsyncIterator[bytes]:
await cls.login()
async with StreamSession(
impersonate="chrome",
timeout=900
) as session:
async with session.get(
f"{cls.url}/backend-api/synthesize",
params=params,
headers=cls._headers
) as response:
await raise_for_status(response)
async for chunk in response.iter_content():
yield chunk
@classmethod
async def login(cls, proxy: str = None):
if cls._expires is not None and cls._expires < time.time():
cls._headers = cls._api_key = None
try:
await get_request_config(proxy)
cls._create_request_args(RequestConfig.cookies, RequestConfig.headers)
cls._set_api_key(RequestConfig.access_token)
except NoValidHarFileError:
if has_nodriver:
await cls.nodriver_auth(proxy)
else:
raise
@classmethod @classmethod
async def nodriver_auth(cls, proxy: str = None): async def nodriver_auth(cls, proxy: str = None):
if not has_nodriver:
return
if has_platformdirs: if has_platformdirs:
user_data_dir = user_config_dir("g4f-nodriver") user_data_dir = user_config_dir("g4f-nodriver")
else: else:

View File

@@ -13,7 +13,7 @@ from ..providers.base_provider import AsyncGeneratorProvider
from ..image import ImageResponse, copy_images, images_dir from ..image import ImageResponse, copy_images, images_dir
from ..typing import Messages, Image, ImageType from ..typing import Messages, Image, ImageType
from ..providers.types import ProviderType from ..providers.types import ProviderType
from ..providers.response import ResponseType, FinishReason, BaseConversation from ..providers.response import ResponseType, FinishReason, BaseConversation, SynthesizeData
from ..errors import NoImageResponseError, ModelNotFoundError from ..errors import NoImageResponseError, ModelNotFoundError
from ..providers.retry_provider import IterListProvider from ..providers.retry_provider import IterListProvider
from ..providers.base_provider import get_running_loop from ..providers.base_provider import get_running_loop
@@ -60,6 +60,8 @@ def iter_response(
elif isinstance(chunk, BaseConversation): elif isinstance(chunk, BaseConversation):
yield chunk yield chunk
continue continue
elif isinstance(chunk, SynthesizeData):
continue
chunk = str(chunk) chunk = str(chunk)
content += chunk content += chunk
@@ -121,6 +123,8 @@ async def async_iter_response(
elif isinstance(chunk, BaseConversation): elif isinstance(chunk, BaseConversation):
yield chunk yield chunk
continue continue
elif isinstance(chunk, SynthesizeData):
continue
chunk = str(chunk) chunk = str(chunk)
content += chunk content += chunk

View File

@@ -259,7 +259,6 @@ body {
flex-direction: column; flex-direction: column;
gap: var(--section-gap); gap: var(--section-gap);
padding: var(--inner-gap) var(--section-gap); padding: var(--inner-gap) var(--section-gap);
padding-bottom: 0;
} }
.message.print { .message.print {
@@ -271,7 +270,11 @@ body {
} }
.message.regenerate { .message.regenerate {
opacity: 0.75; background-color: var(--colour-6);
}
.white .message.regenerate {
background-color: var(--colour-4);
} }
.message:last-child { .message:last-child {
@@ -407,6 +410,7 @@ body {
.message .count .fa-clipboard.clicked, .message .count .fa-clipboard.clicked,
.message .count .fa-print.clicked, .message .count .fa-print.clicked,
.message .count .fa-rotate.clicked,
.message .count .fa-volume-high.active { .message .count .fa-volume-high.active {
color: var(--accent); color: var(--accent);
} }
@@ -430,6 +434,15 @@ body {
font-size: 12px; font-size: 12px;
} }
.message audio {
display: none;
max-width: 400px;
}
.message audio.show {
display: block;
}
.count_total { .count_total {
font-size: 12px; font-size: 12px;
padding-left: 25px; padding-left: 25px;
@@ -1159,7 +1172,10 @@ a:-webkit-any-link {
.message .user { .message .user {
display: none; display: none;
} }
.message.regenerate { body {
opacity: 1; height: auto;
}
.box {
backdrop-filter: none;
} }
} }

View File

@@ -28,6 +28,7 @@ let message_storage = {};
let controller_storage = {}; let controller_storage = {};
let content_storage = {}; let content_storage = {};
let error_storage = {}; let error_storage = {};
let synthesize_storage = {};
messageInput.addEventListener("blur", () => { messageInput.addEventListener("blur", () => {
window.scrollTo(0, 0); window.scrollTo(0, 0);
@@ -134,6 +135,13 @@ const register_message_buttons = async () => {
if (!("click" in el.dataset)) { if (!("click" in el.dataset)) {
el.dataset.click = "true"; el.dataset.click = "true";
el.addEventListener("click", async () => { el.addEventListener("click", async () => {
const content_el = el.parentElement.parentElement;
const audio = content_el.querySelector("audio");
if (audio) {
audio.classList.add("show");
audio.play();
return;
}
let playlist = []; let playlist = [];
function play_next() { function play_next() {
const next = playlist.shift(); const next = playlist.shift();
@@ -155,7 +163,6 @@ const register_message_buttons = async () => {
el.dataset.running = true; el.dataset.running = true;
el.classList.add("blink") el.classList.add("blink")
el.classList.add("active") el.classList.add("active")
const content_el = el.parentElement.parentElement;
const message_el = content_el.parentElement; const message_el = content_el.parentElement;
let speechText = await get_message(window.conversation_id, message_el.dataset.index); let speechText = await get_message(window.conversation_id, message_el.dataset.index);
@@ -215,8 +222,8 @@ const register_message_buttons = async () => {
const message_el = el.parentElement.parentElement.parentElement; const message_el = el.parentElement.parentElement.parentElement;
el.classList.add("clicked"); el.classList.add("clicked");
setTimeout(() => el.classList.remove("clicked"), 1000); setTimeout(() => el.classList.remove("clicked"), 1000);
await ask_gpt(message_el.dataset.index, get_message_id()); await ask_gpt(get_message_id(), message_el.dataset.index);
}) });
} }
}); });
document.querySelectorAll(".message .fa-whatsapp").forEach(async (el) => { document.querySelectorAll(".message .fa-whatsapp").forEach(async (el) => {
@@ -301,25 +308,29 @@ const handle_ask = async () => {
<i class="fa-regular fa-clipboard"></i> <i class="fa-regular fa-clipboard"></i>
<a><i class="fa-brands fa-whatsapp"></i></a> <a><i class="fa-brands fa-whatsapp"></i></a>
<i class="fa-solid fa-print"></i> <i class="fa-solid fa-print"></i>
<i class="fa-solid fa-rotate"></i>
</div> </div>
</div> </div>
</div> </div>
`; `;
highlight(message_box); highlight(message_box);
stop_generating.classList.remove("stop_generating-hidden"); await ask_gpt(message_id);
await ask_gpt(-1, message_id);
}; };
async function remove_cancel_button() { async function safe_remove_cancel_button() {
for (let key in controller_storage) {
if (!controller_storage[key].signal.aborted) {
return;
}
}
stop_generating.classList.add("stop_generating-hidden"); stop_generating.classList.add("stop_generating-hidden");
} }
regenerate.addEventListener("click", async () => { regenerate.addEventListener("click", async () => {
regenerate.classList.add("regenerate-hidden"); regenerate.classList.add("regenerate-hidden");
setTimeout(()=>regenerate.classList.remove("regenerate-hidden"), 3000); setTimeout(()=>regenerate.classList.remove("regenerate-hidden"), 3000);
stop_generating.classList.remove("stop_generating-hidden");
await hide_message(window.conversation_id); await hide_message(window.conversation_id);
await ask_gpt(-1, get_message_id()); await ask_gpt(get_message_id());
}); });
stop_generating.addEventListener("click", async () => { stop_generating.addEventListener("click", async () => {
@@ -337,23 +348,23 @@ stop_generating.addEventListener("click", async () => {
} }
} }
} }
await load_conversation(window.conversation_id); await load_conversation(window.conversation_id, false);
}); });
const prepare_messages = (messages, message_index = -1) => { const prepare_messages = (messages, message_index = -1) => {
// Removes none user messages at end if (message_index >= 0) {
if (message_index == -1) {
let last_message;
while (last_message = messages.pop()) {
if (last_message["role"] == "user") {
messages.push(last_message);
break;
}
}
} else if (message_index >= 0) {
messages = messages.filter((_, index) => message_index >= index); messages = messages.filter((_, index) => message_index >= index);
} }
// Removes none user messages at end
let last_message;
while (last_message = messages.pop()) {
if (last_message["role"] == "user") {
messages.push(last_message);
break;
}
}
let new_messages = []; let new_messages = [];
if (systemPrompt?.value) { if (systemPrompt?.value) {
new_messages.push({ new_messages.push({
@@ -377,9 +388,11 @@ const prepare_messages = (messages, message_index = -1) => {
// Remove generated images from history // Remove generated images from history
new_message.content = filter_message(new_message.content); new_message.content = filter_message(new_message.content);
delete new_message.provider; delete new_message.provider;
delete new_message.synthesize;
new_messages.push(new_message) new_messages.push(new_message)
} }
}); });
return new_messages; return new_messages;
} }
@@ -427,6 +440,8 @@ async function add_message_chunk(message, message_id) {
let p = document.createElement("p"); let p = document.createElement("p");
p.innerText = message.log; p.innerText = message.log;
log_storage.appendChild(p); log_storage.appendChild(p);
} else if (message.type == "synthesize") {
synthesize_storage[message_id] = message.synthesize;
} }
let scroll_down = ()=>{ let scroll_down = ()=>{
if (message_box.scrollTop >= message_box.scrollHeight - message_box.clientHeight - 100) { if (message_box.scrollTop >= message_box.scrollHeight - message_box.clientHeight - 100) {
@@ -434,8 +449,10 @@ async function add_message_chunk(message, message_id) {
message_box.scrollTo({ top: message_box.scrollHeight, behavior: "auto" }); message_box.scrollTo({ top: message_box.scrollHeight, behavior: "auto" });
} }
} }
setTimeout(scroll_down, 200); if (!content_map.container.classList.contains("regenerate")) {
setTimeout(scroll_down, 1000); scroll_down();
setTimeout(scroll_down, 200);
}
} }
cameraInput?.addEventListener("click", (e) => { cameraInput?.addEventListener("click", (e) => {
@@ -452,45 +469,58 @@ imageInput?.addEventListener("click", (e) => {
} }
}); });
const ask_gpt = async (message_index = -1, message_id) => { const ask_gpt = async (message_id, message_index = -1) => {
let messages = await get_messages(window.conversation_id); let messages = await get_messages(window.conversation_id);
let total_messages = messages.length;
messages = prepare_messages(messages, message_index); messages = prepare_messages(messages, message_index);
message_index = total_messages
message_storage[message_id] = ""; message_storage[message_id] = "";
stop_generating.classList.remove(".stop_generating-hidden"); stop_generating.classList.remove("stop_generating-hidden");
message_box.scrollTop = message_box.scrollHeight; if (message_index == -1) {
window.scrollTo(0, 0); await scroll_to_bottom();
}
let count_total = message_box.querySelector('.count_total'); let count_total = message_box.querySelector('.count_total');
count_total ? count_total.parentElement.removeChild(count_total) : null; count_total ? count_total.parentElement.removeChild(count_total) : null;
message_box.innerHTML += ` const message_el = document.createElement("div");
<div class="message" data-index="${message_index}"> message_el.classList.add("message");
<div class="assistant"> if (message_index != -1) {
${gpt_image} message_el.classList.add("regenerate");
<i class="fa-solid fa-xmark"></i> }
<i class="fa-regular fa-phone-arrow-down-left"></i> message_el.innerHTML += `
</div> <div class="assistant">
<div class="content" id="gpt_${message_id}"> ${gpt_image}
<div class="provider"></div> <i class="fa-solid fa-xmark"></i>
<div class="content_inner"><span class="cursor"></span></div> <i class="fa-regular fa-phone-arrow-down-left"></i>
<div class="count"></div> </div>
</div> <div class="content" id="gpt_${message_id}">
<div class="provider"></div>
<div class="content_inner"><span class="cursor"></span></div>
<div class="count"></div>
</div> </div>
`; `;
if (message_index == -1) {
message_box.appendChild(message_el);
} else {
parent_message = message_box.querySelector(`.message[data-index="${message_index}"]`);
if (!parent_message) {
return;
}
parent_message.after(message_el);
}
controller_storage[message_id] = new AbortController(); controller_storage[message_id] = new AbortController();
let content_el = document.getElementById(`gpt_${message_id}`) let content_el = document.getElementById(`gpt_${message_id}`)
let content_map = content_storage[message_id] = { let content_map = content_storage[message_id] = {
container: message_el,
content: content_el, content: content_el,
inner: content_el.querySelector('.content_inner'), inner: content_el.querySelector('.content_inner'),
count: content_el.querySelector('.count'), count: content_el.querySelector('.count'),
} }
if (message_index == -1) {
await scroll_to_bottom(); await scroll_to_bottom();
}
try { try {
const input = imageInput && imageInput.files.length > 0 ? imageInput : cameraInput; const input = imageInput && imageInput.files.length > 0 ? imageInput : cameraInput;
const file = input && input.files.length > 0 ? input.files[0] : null; const file = input && input.files.length > 0 ? input.files[0] : null;
@@ -527,14 +557,23 @@ const ask_gpt = async (message_index = -1, message_id) => {
delete controller_storage[message_id]; delete controller_storage[message_id];
if (!error_storage[message_id] && message_storage[message_id]) { if (!error_storage[message_id] && message_storage[message_id]) {
const message_provider = message_id in provider_storage ? provider_storage[message_id] : null; const message_provider = message_id in provider_storage ? provider_storage[message_id] : null;
await add_message(window.conversation_id, "assistant", message_storage[message_id], message_provider); await add_message(
await safe_load_conversation(window.conversation_id); window.conversation_id,
"assistant",
message_storage[message_id],
message_provider,
message_index,
synthesize_storage[message_id]
);
await safe_load_conversation(window.conversation_id, message_index == -1);
} else { } else {
let cursorDiv = message_box.querySelector(".cursor"); let cursorDiv = message_el.querySelector(".cursor");
if (cursorDiv) cursorDiv.parentNode.removeChild(cursorDiv); if (cursorDiv) cursorDiv.parentNode.removeChild(cursorDiv);
} }
await scroll_to_bottom(); if (message_index == -1) {
await remove_cancel_button(); await scroll_to_bottom();
}
await safe_remove_cancel_button();
await register_message_buttons(); await register_message_buttons();
await load_conversations(); await load_conversations();
regenerate.classList.remove("regenerate-hidden"); regenerate.classList.remove("regenerate-hidden");
@@ -687,6 +726,15 @@ const load_conversation = async (conversation_id, scroll=true) => {
${item.provider.model ? ' with ' + item.provider.model : ''} ${item.provider.model ? ' with ' + item.provider.model : ''}
</div> </div>
` : ""; ` : "";
let audio = "";
if (item.synthesize) {
const synthesize_params = (new URLSearchParams(item.synthesize.data)).toString();
audio = `
<audio controls preload="none">
<source src="/backend-api/v2/synthesize/${item.synthesize.provider}?${synthesize_params}" type="audio/mpeg">
</audio>
`;
}
elements += ` elements += `
<div class="message${item.regenerate ? " regenerate": ""}" data-index="${i}"> <div class="message${item.regenerate ? " regenerate": ""}" data-index="${i}">
<div class="${item.role}"> <div class="${item.role}">
@@ -700,12 +748,14 @@ const load_conversation = async (conversation_id, scroll=true) => {
<div class="content"> <div class="content">
${provider} ${provider}
<div class="content_inner">${markdown_render(item.content)}</div> <div class="content_inner">${markdown_render(item.content)}</div>
${audio}
<div class="count"> <div class="count">
${count_words_and_tokens(item.content, next_provider?.model)} ${count_words_and_tokens(item.content, next_provider?.model)}
<i class="fa-solid fa-volume-high"></i> <i class="fa-solid fa-volume-high"></i>
<i class="fa-regular fa-clipboard"></i> <i class="fa-regular fa-clipboard"></i>
<a><i class="fa-brands fa-whatsapp"></i></a> <a><i class="fa-brands fa-whatsapp"></i></a>
<i class="fa-solid fa-print"></i> <i class="fa-solid fa-print"></i>
<i class="fa-solid fa-rotate"></i>
</div> </div>
</div> </div>
</div> </div>
@@ -830,14 +880,35 @@ const get_message = async (conversation_id, index) => {
return messages[index]["content"]; return messages[index]["content"];
}; };
const add_message = async (conversation_id, role, content, provider) => { const add_message = async (
conversation_id, role, content,
provider = null,
message_index = -1,
synthesize_data = null
) => {
const conversation = await get_conversation(conversation_id); const conversation = await get_conversation(conversation_id);
if (!conversation) return; if (!conversation) return;
conversation.items.push({ const new_message = {
role: role, role: role,
content: content, content: content,
provider: provider provider: provider,
}); };
if (synthesize_data) {
new_message.synthesize = synthesize_data;
}
if (message_index == -1) {
conversation.items.push(new_message);
} else {
const new_messages = [];
conversation.items.forEach((item, index)=>{
new_messages.push(item);
if (index == message_index) {
new_message.regenerate = true;
new_messages.push(new_message);
}
});
conversation.items = new_messages;
}
await save_conversation(conversation_id, conversation); await save_conversation(conversation_id, conversation);
return conversation.items.length - 1; return conversation.items.length - 1;
}; };

View File

@@ -13,7 +13,7 @@ from g4f.errors import VersionNotFoundError
from g4f.image import ImagePreview, ImageResponse, copy_images, ensure_images_dir, images_dir from g4f.image import ImagePreview, ImageResponse, copy_images, ensure_images_dir, images_dir
from g4f.Provider import ProviderType, __providers__, __map__ from g4f.Provider import ProviderType, __providers__, __map__
from g4f.providers.base_provider import ProviderModelMixin from g4f.providers.base_provider import ProviderModelMixin
from g4f.providers.response import BaseConversation, FinishReason from g4f.providers.response import BaseConversation, FinishReason, SynthesizeData
from g4f.client.service import convert_to_provider from g4f.client.service import convert_to_provider
from g4f import debug from g4f import debug
@@ -177,6 +177,8 @@ class Api:
images = asyncio.run(copy_images(chunk.get_list(), chunk.options.get("cookies"))) images = asyncio.run(copy_images(chunk.get_list(), chunk.options.get("cookies")))
images = ImageResponse(images, chunk.alt) images = ImageResponse(images, chunk.alt)
yield self._format_json("content", str(images)) yield self._format_json("content", str(images))
elif isinstance(chunk, SynthesizeData):
yield self._format_json("synthesize", chunk.to_json())
elif not isinstance(chunk, FinishReason): elif not isinstance(chunk, FinishReason):
yield self._format_json("content", str(chunk)) yield self._format_json("content", str(chunk))
if debug.logs: if debug.logs:

View File

@@ -1,8 +1,36 @@
import json import json
import asyncio
import flask
from flask import request, Flask from flask import request, Flask
from typing import AsyncGenerator, Generator
from g4f.image import is_allowed_extension, to_image from g4f.image import is_allowed_extension, to_image
from g4f.client.service import convert_to_provider
from g4f.errors import ProviderNotFoundError
from .api import Api from .api import Api
def safe_iter_generator(generator: Generator) -> Generator:
start = next(generator)
def iter_generator():
yield start
yield from generator
return iter_generator()
def to_sync_generator(gen: AsyncGenerator) -> Generator:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
gen = gen.__aiter__()
async def get_next():
try:
obj = await gen.__anext__()
return False, obj
except StopAsyncIteration: return True, None
while True:
done, obj = loop.run_until_complete(get_next())
if done:
break
yield obj
class Backend_Api(Api): class Backend_Api(Api):
""" """
Handles various endpoints in a Flask application for backend operations. Handles various endpoints in a Flask application for backend operations.
@@ -47,6 +75,10 @@ class Backend_Api(Api):
'function': self.handle_conversation, 'function': self.handle_conversation,
'methods': ['POST'] 'methods': ['POST']
}, },
'/backend-api/v2/synthesize/<provider>': {
'function': self.handle_synthesize,
'methods': ['GET']
},
'/backend-api/v2/error': { '/backend-api/v2/error': {
'function': self.handle_error, 'function': self.handle_error,
'methods': ['POST'] 'methods': ['POST']
@@ -98,11 +130,28 @@ class Backend_Api(Api):
mimetype='text/event-stream' mimetype='text/event-stream'
) )
def handle_synthesize(self, provider: str):
try:
provider_handler = convert_to_provider(provider)
except ProviderNotFoundError:
return "Provider not found", 404
if not hasattr(provider_handler, "synthesize"):
return "Provider doesn't support synthesize", 500
try:
response_generator = provider_handler.synthesize({**request.args})
if hasattr(response_generator, "__aiter__"):
response_generator = to_sync_generator(response_generator)
response = flask.Response(safe_iter_generator(response_generator), content_type="audio/mpeg")
response.headers['Cache-Control'] = "max-age=604800"
return response
except Exception as e:
return f"{e.__class__.__name__}: {e}", 500
def get_provider_models(self, provider: str): def get_provider_models(self, provider: str):
api_key = None if request.authorization is None else request.authorization.token api_key = None if request.authorization is None else request.authorization.token
models = super().get_provider_models(provider, api_key) models = super().get_provider_models(provider, api_key)
if models is None: if models is None:
return 404, "Provider not found" return "Provider not found", 404
return models return models
def _format_json(self, response_type: str, content) -> str: def _format_json(self, response_type: str, content) -> str:

View File

@@ -11,7 +11,7 @@ from typing import Callable, Union
from ..typing import CreateResult, AsyncResult, Messages from ..typing import CreateResult, AsyncResult, Messages
from .types import BaseProvider from .types import BaseProvider
from .response import FinishReason, BaseConversation from .response import FinishReason, BaseConversation, SynthesizeData
from ..errors import NestAsyncioError, ModelNotSupportedError from ..errors import NestAsyncioError, ModelNotSupportedError
from .. import debug from .. import debug
@@ -259,7 +259,7 @@ class AsyncGeneratorProvider(AsyncProvider):
""" """
return "".join([ return "".join([
str(chunk) async for chunk in cls.create_async_generator(model, messages, stream=False, **kwargs) str(chunk) async for chunk in cls.create_async_generator(model, messages, stream=False, **kwargs)
if not isinstance(chunk, (Exception, FinishReason, BaseConversation)) if not isinstance(chunk, (Exception, FinishReason, BaseConversation, SynthesizeData))
]) ])
@staticmethod @staticmethod

View File

@@ -22,5 +22,18 @@ class Sources(ResponseType):
return "\n\n" + ("\n".join([f"{idx+1}. [{link['title']}]({link['url']})" for idx, link in enumerate(self.list)])) return "\n\n" + ("\n".join([f"{idx+1}. [{link['title']}]({link['url']})" for idx, link in enumerate(self.list)]))
class BaseConversation(ResponseType): class BaseConversation(ResponseType):
def __str__(self) -> str:
return ""
class SynthesizeData(ResponseType):
def __init__(self, provider: str, data: dict):
self.provider = provider
self.data = data
def to_json(self) -> dict:
return {
**self.__dict__
}
def __str__(self) -> str: def __str__(self) -> str:
return "" return ""