Support synthesize in Openai generator (#2394)

* Improve download of generated images, serve images in the api

* Add support for conversation handling in the api

* Add original prompt to image response

* Add download images option in gui, fix loading model list in Airforce

* Support speech synthesize in Openai generator
This commit is contained in:
H Lohaus
2024-11-21 05:00:08 +01:00
committed by GitHub
parent ffb4b0d162
commit eae317a166
8 changed files with 260 additions and 77 deletions

View File

@@ -7,7 +7,6 @@ import json
import base64
import time
import requests
from aiohttp import ClientWebSocketResponse
from copy import copy
try:
@@ -28,7 +27,7 @@ from ...requests.raise_for_status import raise_for_status
from ...requests.aiohttp import StreamSession
from ...image import ImageResponse, ImageRequest, to_image, to_bytes, is_accepted_format
from ...errors import MissingAuthError, ResponseError
from ...providers.response import BaseConversation
from ...providers.response import BaseConversation, FinishReason, SynthesizeData
from ..helper import format_cookies
from ..openai.har_file import get_request_config, NoValidHarFileError
from ..openai.har_file import RequestConfig, arkReq, arkose_url, start_url, conversation_url, backend_url, backend_anon_url
@@ -367,19 +366,13 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin):
Raises:
RuntimeError: If an error occurs during processing.
"""
await cls.login(proxy)
async with StreamSession(
proxy=proxy,
impersonate="chrome",
timeout=timeout
) as session:
if cls._expires is not None and cls._expires < time.time():
cls._headers = cls._api_key = None
try:
await get_request_config(proxy)
cls._create_request_args(RequestConfig.cookies, RequestConfig.headers)
cls._set_api_key(RequestConfig.access_token)
except NoValidHarFileError as e:
await cls.nodriver_auth(proxy)
try:
image_request = await cls.upload_image(session, cls._headers, image, image_name) if image else None
except Exception as e:
@@ -469,18 +462,25 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin):
await asyncio.sleep(5)
continue
await raise_for_status(response)
async for chunk in cls.iter_messages_chunk(response.iter_lines(), session, conversation):
if return_conversation:
history_disabled = False
return_conversation = False
yield conversation
async for line in response.iter_lines():
async for chunk in cls.iter_messages_line(session, line, conversation):
yield chunk
if not history_disabled:
yield SynthesizeData(cls.__name__, {
"conversation_id": conversation.conversation_id,
"message_id": conversation.message_id,
"voice": "maple",
})
if auto_continue and conversation.finish_reason == "max_tokens":
conversation.finish_reason = None
action = "continue"
await asyncio.sleep(5)
else:
break
yield FinishReason(conversation.finish_reason)
if history_disabled and auto_continue:
await cls.delete_conversation(session, cls._headers, conversation.conversation_id)
@@ -541,10 +541,38 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin):
if "error" in line and line.get("error"):
raise RuntimeError(line.get("error"))
@classmethod
async def synthesize(cls, params: dict) -> AsyncIterator[bytes]:
    """Stream synthesized speech audio for a generated message.

    Args:
        params: Query parameters for the synthesize endpoint
            (e.g. conversation_id, message_id, voice).

    Yields:
        Raw audio bytes as they arrive from the backend.
    """
    # Make sure headers / API key are valid before hitting the endpoint.
    await cls.login()
    async with StreamSession(impersonate="chrome", timeout=900) as session:
        async with session.get(
            f"{cls.url}/backend-api/synthesize",
            params=params,
            headers=cls._headers,
        ) as response:
            await raise_for_status(response)
            async for audio_chunk in response.iter_content():
                yield audio_chunk
@classmethod
async def login(cls, proxy: str = None):
    """Ensure the class holds usable request headers and an API key.

    Args:
        proxy: Optional proxy URL forwarded to the auth helpers.

    Raises:
        NoValidHarFileError: If no valid HAR file is found and the
            nodriver browser fallback is unavailable.
    """
    # Drop cached credentials once the stored expiry timestamp has passed.
    if cls._expires is not None and cls._expires < time.time():
        cls._headers = cls._api_key = None
    try:
        # Load cookies/headers/access token from a captured HAR file.
        await get_request_config(proxy)
        cls._create_request_args(RequestConfig.cookies, RequestConfig.headers)
        cls._set_api_key(RequestConfig.access_token)
    except NoValidHarFileError:
        # No usable HAR file: fall back to a browser-based login via
        # nodriver when available, otherwise propagate the error.
        if has_nodriver:
            await cls.nodriver_auth(proxy)
        else:
            raise
@classmethod
async def nodriver_auth(cls, proxy: str = None):
if not has_nodriver:
return
if has_platformdirs:
user_data_dir = user_config_dir("g4f-nodriver")
else:

View File

@@ -13,7 +13,7 @@ from ..providers.base_provider import AsyncGeneratorProvider
from ..image import ImageResponse, copy_images, images_dir
from ..typing import Messages, Image, ImageType
from ..providers.types import ProviderType
from ..providers.response import ResponseType, FinishReason, BaseConversation
from ..providers.response import ResponseType, FinishReason, BaseConversation, SynthesizeData
from ..errors import NoImageResponseError, ModelNotFoundError
from ..providers.retry_provider import IterListProvider
from ..providers.base_provider import get_running_loop
@@ -60,6 +60,8 @@ def iter_response(
elif isinstance(chunk, BaseConversation):
yield chunk
continue
elif isinstance(chunk, SynthesizeData):
continue
chunk = str(chunk)
content += chunk
@@ -121,6 +123,8 @@ async def async_iter_response(
elif isinstance(chunk, BaseConversation):
yield chunk
continue
elif isinstance(chunk, SynthesizeData):
continue
chunk = str(chunk)
content += chunk

View File

@@ -259,7 +259,6 @@ body {
flex-direction: column;
gap: var(--section-gap);
padding: var(--inner-gap) var(--section-gap);
padding-bottom: 0;
}
.message.print {
@@ -271,7 +270,11 @@ body {
}
.message.regenerate {
opacity: 0.75;
background-color: var(--colour-6);
}
.white .message.regenerate {
background-color: var(--colour-4);
}
.message:last-child {
@@ -407,6 +410,7 @@ body {
.message .count .fa-clipboard.clicked,
.message .count .fa-print.clicked,
.message .count .fa-rotate.clicked,
.message .count .fa-volume-high.active {
color: var(--accent);
}
@@ -430,6 +434,15 @@ body {
font-size: 12px;
}
.message audio {
display: none;
max-width: 400px;
}
.message audio.show {
display: block;
}
.count_total {
font-size: 12px;
padding-left: 25px;
@@ -1159,7 +1172,10 @@ a:-webkit-any-link {
.message .user {
display: none;
}
.message.regenerate {
opacity: 1;
body {
height: auto;
}
.box {
backdrop-filter: none;
}
}

View File

@@ -28,6 +28,7 @@ let message_storage = {};
let controller_storage = {};
let content_storage = {};
let error_storage = {};
let synthesize_storage = {};
messageInput.addEventListener("blur", () => {
window.scrollTo(0, 0);
@@ -134,6 +135,13 @@ const register_message_buttons = async () => {
if (!("click" in el.dataset)) {
el.dataset.click = "true";
el.addEventListener("click", async () => {
const content_el = el.parentElement.parentElement;
const audio = content_el.querySelector("audio");
if (audio) {
audio.classList.add("show");
audio.play();
return;
}
let playlist = [];
function play_next() {
const next = playlist.shift();
@@ -155,7 +163,6 @@ const register_message_buttons = async () => {
el.dataset.running = true;
el.classList.add("blink")
el.classList.add("active")
const content_el = el.parentElement.parentElement;
const message_el = content_el.parentElement;
let speechText = await get_message(window.conversation_id, message_el.dataset.index);
@@ -215,8 +222,8 @@ const register_message_buttons = async () => {
const message_el = el.parentElement.parentElement.parentElement;
el.classList.add("clicked");
setTimeout(() => el.classList.remove("clicked"), 1000);
await ask_gpt(message_el.dataset.index, get_message_id());
})
await ask_gpt(get_message_id(), message_el.dataset.index);
});
}
});
document.querySelectorAll(".message .fa-whatsapp").forEach(async (el) => {
@@ -301,25 +308,29 @@ const handle_ask = async () => {
<i class="fa-regular fa-clipboard"></i>
<a><i class="fa-brands fa-whatsapp"></i></a>
<i class="fa-solid fa-print"></i>
<i class="fa-solid fa-rotate"></i>
</div>
</div>
</div>
`;
highlight(message_box);
stop_generating.classList.remove("stop_generating-hidden");
await ask_gpt(-1, message_id);
await ask_gpt(message_id);
};
async function remove_cancel_button() {
async function safe_remove_cancel_button() {
    // Hide the stop button only once every tracked request has been aborted.
    const stillRunning = Object.values(controller_storage).some(
        (controller) => !controller.signal.aborted
    );
    if (stillRunning) {
        return;
    }
    stop_generating.classList.add("stop_generating-hidden");
}
regenerate.addEventListener("click", async () => {
regenerate.classList.add("regenerate-hidden");
setTimeout(()=>regenerate.classList.remove("regenerate-hidden"), 3000);
stop_generating.classList.remove("stop_generating-hidden");
await hide_message(window.conversation_id);
await ask_gpt(-1, get_message_id());
await ask_gpt(get_message_id());
});
stop_generating.addEventListener("click", async () => {
@@ -337,12 +348,15 @@ stop_generating.addEventListener("click", async () => {
}
}
}
await load_conversation(window.conversation_id);
await load_conversation(window.conversation_id, false);
});
const prepare_messages = (messages, message_index = -1) => {
if (message_index >= 0) {
messages = messages.filter((_, index) => message_index >= index);
}
// Removes none user messages at end
if (message_index == -1) {
let last_message;
while (last_message = messages.pop()) {
if (last_message["role"] == "user") {
@@ -350,9 +364,6 @@ const prepare_messages = (messages, message_index = -1) => {
break;
}
}
} else if (message_index >= 0) {
messages = messages.filter((_, index) => message_index >= index);
}
let new_messages = [];
if (systemPrompt?.value) {
@@ -377,9 +388,11 @@ const prepare_messages = (messages, message_index = -1) => {
// Remove generated images from history
new_message.content = filter_message(new_message.content);
delete new_message.provider;
delete new_message.synthesize;
new_messages.push(new_message)
}
});
return new_messages;
}
@@ -427,6 +440,8 @@ async function add_message_chunk(message, message_id) {
let p = document.createElement("p");
p.innerText = message.log;
log_storage.appendChild(p);
} else if (message.type == "synthesize") {
synthesize_storage[message_id] = message.synthesize;
}
let scroll_down = ()=>{
if (message_box.scrollTop >= message_box.scrollHeight - message_box.clientHeight - 100) {
@@ -434,8 +449,10 @@ async function add_message_chunk(message, message_id) {
message_box.scrollTo({ top: message_box.scrollHeight, behavior: "auto" });
}
}
if (!content_map.container.classList.contains("regenerate")) {
scroll_down();
setTimeout(scroll_down, 200);
setTimeout(scroll_down, 1000);
}
}
cameraInput?.addEventListener("click", (e) => {
@@ -452,22 +469,25 @@ imageInput?.addEventListener("click", (e) => {
}
});
const ask_gpt = async (message_index = -1, message_id) => {
const ask_gpt = async (message_id, message_index = -1) => {
let messages = await get_messages(window.conversation_id);
let total_messages = messages.length;
messages = prepare_messages(messages, message_index);
message_index = total_messages
message_storage[message_id] = "";
stop_generating.classList.remove(".stop_generating-hidden");
stop_generating.classList.remove("stop_generating-hidden");
message_box.scrollTop = message_box.scrollHeight;
window.scrollTo(0, 0);
if (message_index == -1) {
await scroll_to_bottom();
}
let count_total = message_box.querySelector('.count_total');
count_total ? count_total.parentElement.removeChild(count_total) : null;
message_box.innerHTML += `
<div class="message" data-index="${message_index}">
const message_el = document.createElement("div");
message_el.classList.add("message");
if (message_index != -1) {
message_el.classList.add("regenerate");
}
message_el.innerHTML += `
<div class="assistant">
${gpt_image}
<i class="fa-solid fa-xmark"></i>
@@ -478,19 +498,29 @@ const ask_gpt = async (message_index = -1, message_id) => {
<div class="content_inner"><span class="cursor"></span></div>
<div class="count"></div>
</div>
</div>
`;
if (message_index == -1) {
message_box.appendChild(message_el);
} else {
parent_message = message_box.querySelector(`.message[data-index="${message_index}"]`);
if (!parent_message) {
return;
}
parent_message.after(message_el);
}
controller_storage[message_id] = new AbortController();
let content_el = document.getElementById(`gpt_${message_id}`)
let content_map = content_storage[message_id] = {
container: message_el,
content: content_el,
inner: content_el.querySelector('.content_inner'),
count: content_el.querySelector('.count'),
}
if (message_index == -1) {
await scroll_to_bottom();
}
try {
const input = imageInput && imageInput.files.length > 0 ? imageInput : cameraInput;
const file = input && input.files.length > 0 ? input.files[0] : null;
@@ -527,14 +557,23 @@ const ask_gpt = async (message_index = -1, message_id) => {
delete controller_storage[message_id];
if (!error_storage[message_id] && message_storage[message_id]) {
const message_provider = message_id in provider_storage ? provider_storage[message_id] : null;
await add_message(window.conversation_id, "assistant", message_storage[message_id], message_provider);
await safe_load_conversation(window.conversation_id);
await add_message(
window.conversation_id,
"assistant",
message_storage[message_id],
message_provider,
message_index,
synthesize_storage[message_id]
);
await safe_load_conversation(window.conversation_id, message_index == -1);
} else {
let cursorDiv = message_box.querySelector(".cursor");
let cursorDiv = message_el.querySelector(".cursor");
if (cursorDiv) cursorDiv.parentNode.removeChild(cursorDiv);
}
if (message_index == -1) {
await scroll_to_bottom();
await remove_cancel_button();
}
await safe_remove_cancel_button();
await register_message_buttons();
await load_conversations();
regenerate.classList.remove("regenerate-hidden");
@@ -687,6 +726,15 @@ const load_conversation = async (conversation_id, scroll=true) => {
${item.provider.model ? ' with ' + item.provider.model : ''}
</div>
` : "";
let audio = "";
if (item.synthesize) {
const synthesize_params = (new URLSearchParams(item.synthesize.data)).toString();
audio = `
<audio controls preload="none">
<source src="/backend-api/v2/synthesize/${item.synthesize.provider}?${synthesize_params}" type="audio/mpeg">
</audio>
`;
}
elements += `
<div class="message${item.regenerate ? " regenerate": ""}" data-index="${i}">
<div class="${item.role}">
@@ -700,12 +748,14 @@ const load_conversation = async (conversation_id, scroll=true) => {
<div class="content">
${provider}
<div class="content_inner">${markdown_render(item.content)}</div>
${audio}
<div class="count">
${count_words_and_tokens(item.content, next_provider?.model)}
<i class="fa-solid fa-volume-high"></i>
<i class="fa-regular fa-clipboard"></i>
<a><i class="fa-brands fa-whatsapp"></i></a>
<i class="fa-solid fa-print"></i>
<i class="fa-solid fa-rotate"></i>
</div>
</div>
</div>
@@ -830,14 +880,35 @@ const get_message = async (conversation_id, index) => {
return messages[index]["content"];
};
const add_message = async (conversation_id, role, content, provider) => {
const add_message = async (
conversation_id, role, content,
provider = null,
message_index = -1,
synthesize_data = null
) => {
const conversation = await get_conversation(conversation_id);
if (!conversation) return;
conversation.items.push({
const new_message = {
role: role,
content: content,
provider: provider
provider: provider,
};
if (synthesize_data) {
new_message.synthesize = synthesize_data;
}
if (message_index == -1) {
conversation.items.push(new_message);
} else {
const new_messages = [];
conversation.items.forEach((item, index)=>{
new_messages.push(item);
if (index == message_index) {
new_message.regenerate = true;
new_messages.push(new_message);
}
});
conversation.items = new_messages;
}
await save_conversation(conversation_id, conversation);
return conversation.items.length - 1;
};

View File

@@ -13,7 +13,7 @@ from g4f.errors import VersionNotFoundError
from g4f.image import ImagePreview, ImageResponse, copy_images, ensure_images_dir, images_dir
from g4f.Provider import ProviderType, __providers__, __map__
from g4f.providers.base_provider import ProviderModelMixin
from g4f.providers.response import BaseConversation, FinishReason
from g4f.providers.response import BaseConversation, FinishReason, SynthesizeData
from g4f.client.service import convert_to_provider
from g4f import debug
@@ -177,6 +177,8 @@ class Api:
images = asyncio.run(copy_images(chunk.get_list(), chunk.options.get("cookies")))
images = ImageResponse(images, chunk.alt)
yield self._format_json("content", str(images))
elif isinstance(chunk, SynthesizeData):
yield self._format_json("synthesize", chunk.to_json())
elif not isinstance(chunk, FinishReason):
yield self._format_json("content", str(chunk))
if debug.logs:

View File

@@ -1,8 +1,36 @@
import json
import asyncio
import flask
from flask import request, Flask
from typing import AsyncGenerator, Generator
from g4f.image import is_allowed_extension, to_image
from g4f.client.service import convert_to_provider
from g4f.errors import ProviderNotFoundError
from .api import Api
def safe_iter_generator(generator: Generator) -> Generator:
    """Wrap a generator so its first item is produced eagerly.

    Pulling the first value up front surfaces any startup error in the
    caller's frame instead of inside a lazily consumed response stream.

    Args:
        generator: The generator to wrap.

    Returns:
        A generator yielding the same items as the original.
    """
    head = next(generator)

    def _resumed() -> Generator:
        yield head
        yield from generator

    return _resumed()
def to_sync_generator(gen: AsyncGenerator) -> Generator:
    """Drive an async generator to completion from synchronous code.

    Creates a private event loop and steps the async generator one item
    at a time, yielding each item synchronously.

    Args:
        gen: The async generator to consume.

    Yields:
        The items produced by ``gen``, in order.
    """
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    gen = gen.__aiter__()

    async def get_next():
        # Return (done, item); StopAsyncIteration cannot cross
        # run_until_complete, so it is translated into a flag here.
        try:
            obj = await gen.__anext__()
            return False, obj
        except StopAsyncIteration:
            return True, None

    try:
        while True:
            done, obj = loop.run_until_complete(get_next())
            if done:
                break
            yield obj
    finally:
        # Fix: the original leaked the loop — close it and uninstall it
        # as the thread's current loop once iteration ends or is abandoned.
        asyncio.set_event_loop(None)
        loop.close()
class Backend_Api(Api):
"""
Handles various endpoints in a Flask application for backend operations.
@@ -47,6 +75,10 @@ class Backend_Api(Api):
'function': self.handle_conversation,
'methods': ['POST']
},
'/backend-api/v2/synthesize/<provider>': {
'function': self.handle_synthesize,
'methods': ['GET']
},
'/backend-api/v2/error': {
'function': self.handle_error,
'methods': ['POST']
@@ -98,11 +130,28 @@ class Backend_Api(Api):
mimetype='text/event-stream'
)
def handle_synthesize(self, provider: str):
    """Stream synthesized audio from the given provider as an MP3 response.

    Args:
        provider: Name of the provider to route the request to.

    Returns:
        A streaming ``flask.Response`` on success, or an
        ``(error message, status code)`` tuple on failure.
    """
    try:
        provider_handler = convert_to_provider(provider)
    except ProviderNotFoundError:
        return "Provider not found", 404
    if not hasattr(provider_handler, "synthesize"):
        return "Provider doesn't support synthesize", 500
    try:
        chunks = provider_handler.synthesize({**request.args})
        # Async providers are bridged onto a blocking generator for Flask.
        if hasattr(chunks, "__aiter__"):
            chunks = to_sync_generator(chunks)
        response = flask.Response(safe_iter_generator(chunks), content_type="audio/mpeg")
        # Allow clients to cache the audio for one week.
        response.headers['Cache-Control'] = "max-age=604800"
        return response
    except Exception as e:
        return f"{e.__class__.__name__}: {e}", 500
def get_provider_models(self, provider: str):
api_key = None if request.authorization is None else request.authorization.token
models = super().get_provider_models(provider, api_key)
if models is None:
return 404, "Provider not found"
return "Provider not found", 404
return models
def _format_json(self, response_type: str, content) -> str:

View File

@@ -11,7 +11,7 @@ from typing import Callable, Union
from ..typing import CreateResult, AsyncResult, Messages
from .types import BaseProvider
from .response import FinishReason, BaseConversation
from .response import FinishReason, BaseConversation, SynthesizeData
from ..errors import NestAsyncioError, ModelNotSupportedError
from .. import debug
@@ -259,7 +259,7 @@ class AsyncGeneratorProvider(AsyncProvider):
"""
return "".join([
str(chunk) async for chunk in cls.create_async_generator(model, messages, stream=False, **kwargs)
if not isinstance(chunk, (Exception, FinishReason, BaseConversation))
if not isinstance(chunk, (Exception, FinishReason, BaseConversation, SynthesizeData))
])
@staticmethod

View File

@@ -24,3 +24,16 @@ class Sources(ResponseType):
class BaseConversation(ResponseType):
def __str__(self) -> str:
return ""
class SynthesizeData(ResponseType):
    """Response chunk carrying the parameters for speech synthesis.

    Renders as an empty string so it stays invisible in plain-text output.
    """

    def __init__(self, provider: str, data: dict):
        # Provider name used to route the synthesize request.
        self.provider = provider
        # Query parameters forwarded to the provider's synthesize endpoint.
        self.data = data

    def to_json(self) -> dict:
        """Return a plain-dict representation of this object."""
        return dict(self.__dict__)

    def __str__(self) -> str:
        return ""