Add chat share function

This commit is contained in:
hlohaus
2025-03-25 01:46:57 +01:00
parent 5ae71adbae
commit ae1fae7ef0
14 changed files with 239 additions and 88 deletions

View File

@@ -176,6 +176,7 @@ class ARTA(AsyncGeneratorProvider, ProviderModelMixin):
# Step 3: Check Generation Status # Step 3: Check Generation Status
status_url = cls.status_check_url.format(record_id=record_id) status_url = cls.status_check_url.format(record_id=record_id)
counter = 0 counter = 0
start_time = time.time()
while True: while True:
async with session.get(status_url, headers=headers, proxy=proxy) as status_response: async with session.get(status_url, headers=headers, proxy=proxy) as status_response:
status_data = await status_response.json() status_data = await status_response.json()
@@ -183,14 +184,15 @@ class ARTA(AsyncGeneratorProvider, ProviderModelMixin):
if status == "DONE": if status == "DONE":
image_urls = [image["url"] for image in status_data.get("response", [])] image_urls = [image["url"] for image in status_data.get("response", [])]
yield Reasoning(status="Finished") duration = time.time() - start_time
yield Reasoning(label="Generated", status=f"{n} image(s) in {duration:.2f}s")
yield ImageResponse(images=image_urls, alt=prompt) yield ImageResponse(images=image_urls, alt=prompt)
return return
elif status in ("IN_QUEUE", "IN_PROGRESS"): elif status in ("IN_QUEUE", "IN_PROGRESS"):
yield Reasoning(status=("Waiting" if status == "IN_QUEUE" else "Generating") + "." * counter) yield Reasoning(label=("Waiting" if status == "IN_QUEUE" else "Generating"), status="." * counter)
await asyncio.sleep(2) # Poll every 5 seconds await asyncio.sleep(2) # Poll every 2 seconds
counter += 1 counter += 1
if counter > 3: if counter > 3:
counter = 0 counter = 1
else: else:
raise ResponseError(f"Image generation failed with status: {status}") raise ResponseError(f"Image generation failed with status: {status}")

View File

@@ -7,7 +7,7 @@ import uuid
from ..typing import AsyncResult, Messages, ImageType, Cookies from ..typing import AsyncResult, Messages, ImageType, Cookies
from .base_provider import AsyncGeneratorProvider, ProviderModelMixin from .base_provider import AsyncGeneratorProvider, ProviderModelMixin
from .helper import format_prompt from .helper import format_prompt
from ..image import EXTENSIONS_MAP, to_bytes, is_accepted_format from ..image import MEDIA_TYPE_MAP, to_bytes, is_accepted_format
from ..requests import StreamSession, FormData, raise_for_status, get_nodriver from ..requests import StreamSession, FormData, raise_for_status, get_nodriver
from ..providers.response import ImagePreview, ImageResponse from ..providers.response import ImagePreview, ImageResponse
from ..cookies import get_cookies from ..cookies import get_cookies
@@ -159,7 +159,7 @@ class You(AsyncGeneratorProvider, ProviderModelMixin):
upload_nonce = await response.text() upload_nonce = await response.text()
data = FormData() data = FormData()
content_type = is_accepted_format(file) content_type = is_accepted_format(file)
filename = f"image.{EXTENSIONS_MAP[content_type]}" if filename is None else filename filename = f"image.{MEDIA_TYPE_MAP[content_type]}" if filename is None else filename
data.add_field('file', file, content_type=content_type, filename=filename) data.add_field('file', file, content_type=content_type, filename=filename)
async with client.post( async with client.post(
f"{cls.url}/api/upload", f"{cls.url}/api/upload",

View File

@@ -202,7 +202,9 @@ class HuggingFaceMedia(AsyncGeneratorProvider, ProviderModelMixin):
background_tasks.add(task) background_tasks.add(task)
task.add_done_callback(background_tasks.discard) task.add_done_callback(background_tasks.discard)
while background_tasks: while background_tasks:
yield Reasoning(label="Generating", status=f"{time.time() - started:.2f}s") diff = time.time() - started
if diff > 1:
yield Reasoning(label="Generating", status=f"{diff:.2f}s")
await asyncio.sleep(0.2) await asyncio.sleep(0.2)
provider_info, media_response = await task provider_info, media_response = await task
yield Reasoning(label="Finished", status=f"{time.time() - started:.2f}s") yield Reasoning(label="Finished", status=f"{time.time() - started:.2f}s")

View File

@@ -72,12 +72,12 @@ class DeepSeekAPI(AsyncAuthedProvider, ProviderModelMixin):
): ):
if chunk['type'] == 'thinking': if chunk['type'] == 'thinking':
if not is_thinking: if not is_thinking:
yield Reasoning(None, "Is thinking...") yield Reasoning(status="Is thinking...")
is_thinking = time.time() is_thinking = time.time()
yield Reasoning(chunk['content']) yield Reasoning(chunk['content'])
elif chunk['type'] == 'text': elif chunk['type'] == 'text':
if is_thinking: if is_thinking:
yield Reasoning(None, f"Thought for {time.time() - is_thinking:.2f}s") yield Reasoning(status=f"Thought for {time.time() - is_thinking:.2f}s")
is_thinking = 0 is_thinking = 0
if chunk['content']: if chunk['content']:
yield chunk['content'] yield chunk['content']

View File

@@ -38,7 +38,7 @@ import g4f.debug
from g4f.client import AsyncClient, ChatCompletion, ImagesResponse, convert_to_provider from g4f.client import AsyncClient, ChatCompletion, ImagesResponse, convert_to_provider
from g4f.providers.response import BaseConversation, JsonConversation from g4f.providers.response import BaseConversation, JsonConversation
from g4f.client.helper import filter_none from g4f.client.helper import filter_none
from g4f.image import is_data_an_media from g4f.image import is_data_an_media, EXTENSIONS_MAP
from g4f.image.copy_images import images_dir, copy_media, get_source_url from g4f.image.copy_images import images_dir, copy_media, get_source_url
from g4f.errors import ProviderNotFoundError, ModelNotFoundError, MissingAuthError, NoValidHarFileError from g4f.errors import ProviderNotFoundError, ModelNotFoundError, MissingAuthError, NoValidHarFileError
from g4f.cookies import read_cookie_files, get_cookies_dir from g4f.cookies import read_cookie_files, get_cookies_dir
@@ -179,7 +179,7 @@ class Api:
return ErrorResponse.from_message("G4F API key required", HTTP_401_UNAUTHORIZED) return ErrorResponse.from_message("G4F API key required", HTTP_401_UNAUTHORIZED)
if AppConfig.g4f_api_key is None or not secrets.compare_digest(AppConfig.g4f_api_key, user_g4f_api_key): if AppConfig.g4f_api_key is None or not secrets.compare_digest(AppConfig.g4f_api_key, user_g4f_api_key):
return ErrorResponse.from_message("Invalid G4F API key", HTTP_403_FORBIDDEN) return ErrorResponse.from_message("Invalid G4F API key", HTTP_403_FORBIDDEN)
elif not AppConfig.demo and not path.startswith("/images/"): elif not AppConfig.demo and not path.startswith("/images/") and not path.startswith("/media/"):
if user_g4f_api_key is not None: if user_g4f_api_key is not None:
if not secrets.compare_digest(AppConfig.g4f_api_key, user_g4f_api_key): if not secrets.compare_digest(AppConfig.g4f_api_key, user_g4f_api_key):
return ErrorResponse.from_message("Invalid G4F API key", HTTP_403_FORBIDDEN) return ErrorResponse.from_message("Invalid G4F API key", HTTP_403_FORBIDDEN)
@@ -189,7 +189,7 @@ class Api:
except HTTPException as e: except HTTPException as e:
return ErrorResponse.from_message(e.detail, e.status_code, e.headers) return ErrorResponse.from_message(e.detail, e.status_code, e.headers)
response = await call_next(request) response = await call_next(request)
response.headers["X-Username"] = username response.headers["x-user"] = username
return response return response
return await call_next(request) return await call_next(request)
@@ -220,7 +220,7 @@ class Api:
return HTMLResponse('g4f API: Go to ' return HTMLResponse('g4f API: Go to '
'<a href="/v1/models">models</a>, ' '<a href="/v1/models">models</a>, '
'<a href="/v1/chat/completions">chat/completions</a>, or ' '<a href="/v1/chat/completions">chat/completions</a>, or '
'<a href="/v1/images/generate">images/generate</a> <br><br>' '<a href="/v1/media/generate">media/generate</a> <br><br>'
'Open Swagger UI at: ' 'Open Swagger UI at: '
'<a href="/docs">/docs</a>') '<a href="/docs">/docs</a>')
@@ -259,7 +259,7 @@ class Api:
provider: ProviderType = ProviderUtils.convert[provider] provider: ProviderType = ProviderUtils.convert[provider]
if not hasattr(provider, "get_models"): if not hasattr(provider, "get_models"):
models = [] models = []
elif credentials is not None: elif credentials is not None and credentials.credentials != "secret":
models = provider.get_models(api_key=credentials.credentials) models = provider.get_models(api_key=credentials.credentials)
else: else:
models = provider.get_models() models = provider.get_models()
@@ -404,6 +404,7 @@ class Api:
HTTP_404_NOT_FOUND: {"model": ErrorResponseModel}, HTTP_404_NOT_FOUND: {"model": ErrorResponseModel},
HTTP_500_INTERNAL_SERVER_ERROR: {"model": ErrorResponseModel}, HTTP_500_INTERNAL_SERVER_ERROR: {"model": ErrorResponseModel},
} }
@self.app.post("/v1/media/generate", responses=responses)
@self.app.post("/v1/images/generate", responses=responses) @self.app.post("/v1/images/generate", responses=responses)
@self.app.post("/v1/images/generations", responses=responses) @self.app.post("/v1/images/generations", responses=responses)
async def generate_image( async def generate_image(
@@ -555,9 +556,14 @@ class Api:
HTTP_200_OK: {"content": {"image/*": {}}}, HTTP_200_OK: {"content": {"image/*": {}}},
HTTP_404_NOT_FOUND: {} HTTP_404_NOT_FOUND: {}
}) })
async def get_image(filename, request: Request): @self.app.get("/media/{filename}", responses={
HTTP_200_OK: {"content": {"image/*": {}, "audio/*": {}}, "video/*": {}},
HTTP_404_NOT_FOUND: {}
})
async def get_media(filename, request: Request):
target = os.path.join(images_dir, os.path.basename(filename)) target = os.path.join(images_dir, os.path.basename(filename))
ext = os.path.splitext(filename)[1][1:] ext = os.path.splitext(filename)[1][1:]
mime_type = EXTENSIONS_MAP.get(ext)
stat_result = SimpleNamespace() stat_result = SimpleNamespace()
stat_result.st_size = 0 stat_result.st_size = 0
if os.path.isfile(target): if os.path.isfile(target):
@@ -565,12 +571,14 @@ class Api:
stat_result.st_mtime = int(f"{filename.split('_')[0]}") if filename.startswith("1") else 0 stat_result.st_mtime = int(f"{filename.split('_')[0]}") if filename.startswith("1") else 0
headers = { headers = {
"cache-control": "public, max-age=31536000", "cache-control": "public, max-age=31536000",
"content-type": f"image/{ext.replace('jpg', 'jpeg') or 'jpeg'}",
"last-modified": formatdate(stat_result.st_mtime, usegmt=True), "last-modified": formatdate(stat_result.st_mtime, usegmt=True),
"etag": f'"{hashlib.md5(filename.encode()).hexdigest()}"', "etag": f'"{hashlib.md5(filename.encode()).hexdigest()}"',
**({ **({
"content-length": str(stat_result.st_size), "content-length": str(stat_result.st_size),
} if stat_result.st_size else {}) } if stat_result.st_size else {}),
**({} if mime_type is None else {
"content-type": mime_type,
})
} }
response = FileResponse( response = FileResponse(
target, target,
@@ -584,13 +592,13 @@ class Api:
return NotModifiedResponse(response.headers) return NotModifiedResponse(response.headers)
except KeyError: except KeyError:
pass pass
if not os.path.isfile(target): if not os.path.isfile(target) and mime_type is not None:
source_url = get_source_url(str(request.query_params)) source_url = get_source_url(str(request.query_params))
ssl = None ssl = None
if source_url is None: if source_url is None:
backend_url = os.environ.get("G4F_BACKEND_URL") backend_url = os.environ.get("G4F_BACKEND_URL")
if backend_url: if backend_url:
source_url = f"{backend_url}/images/{filename}" source_url = f"{backend_url}/media/{filename}"
ssl = False ssl = False
if source_url is not None: if source_url is not None:
try: try:

View File

@@ -13,7 +13,7 @@ from ..image.copy_images import copy_media
from ..typing import Messages, ImageType from ..typing import Messages, ImageType
from ..providers.types import ProviderType, BaseRetryProvider from ..providers.types import ProviderType, BaseRetryProvider
from ..providers.response import * from ..providers.response import *
from ..errors import NoImageResponseError from ..errors import NoMediaResponseError
from ..providers.retry_provider import IterListProvider from ..providers.retry_provider import IterListProvider
from ..providers.asyncio import to_sync_generator from ..providers.asyncio import to_sync_generator
from ..Provider.needs_auth import BingCreateImages, OpenaiAccount from ..Provider.needs_auth import BingCreateImages, OpenaiAccount
@@ -268,6 +268,7 @@ class Client(BaseClient):
super().__init__(**kwargs) super().__init__(**kwargs)
self.chat: Chat = Chat(self, provider) self.chat: Chat = Chat(self, provider)
self.images: Images = Images(self, image_provider) self.images: Images = Images(self, image_provider)
self.media: Images = Images(self, image_provider)
class Completions: class Completions:
def __init__(self, client: Client, provider: Optional[ProviderType] = None): def __init__(self, client: Client, provider: Optional[ProviderType] = None):
@@ -406,7 +407,7 @@ class Images:
else: else:
response = await self._generate_image_response(provider_handler, provider_name, model, prompt, **kwargs) response = await self._generate_image_response(provider_handler, provider_name, model, prompt, **kwargs)
if isinstance(response, ImageResponse): if isinstance(response, MediaResponse):
return await self._process_image_response( return await self._process_image_response(
response, response,
model, model,
@@ -417,8 +418,8 @@ class Images:
if response is None: if response is None:
if error is not None: if error is not None:
raise error raise error
raise NoImageResponseError(f"No image response from {provider_name}") raise NoMediaResponseError(f"No image response from {provider_name}")
raise NoImageResponseError(f"Unexpected response type: {type(response)}") raise NoMediaResponseError(f"Unexpected response type: {type(response)}")
async def _generate_image_response( async def _generate_image_response(
self, self,
@@ -428,7 +429,7 @@ class Images:
prompt: str, prompt: str,
prompt_prefix: str = "Generate a image: ", prompt_prefix: str = "Generate a image: ",
**kwargs **kwargs
) -> ImageResponse: ) -> MediaResponse:
messages = [{"role": "user", "content": f"{prompt_prefix}{prompt}"}] messages = [{"role": "user", "content": f"{prompt_prefix}{prompt}"}]
response = None response = None
if hasattr(provider_handler, "create_async_generator"): if hasattr(provider_handler, "create_async_generator"):
@@ -439,7 +440,7 @@ class Images:
prompt=prompt, prompt=prompt,
**kwargs **kwargs
): ):
if isinstance(item, ImageResponse): if isinstance(item, MediaResponse):
response = item response = item
break break
elif hasattr(provider_handler, "create_completion"): elif hasattr(provider_handler, "create_completion"):
@@ -450,7 +451,7 @@ class Images:
prompt=prompt, prompt=prompt,
**kwargs **kwargs
): ):
if isinstance(item, ImageResponse): if isinstance(item, MediaResponse):
response = item response = item
break break
else: else:
@@ -501,17 +502,17 @@ class Images:
else: else:
response = await self._generate_image_response(provider_handler, provider_name, model, prompt, **kwargs) response = await self._generate_image_response(provider_handler, provider_name, model, prompt, **kwargs)
if isinstance(response, ImageResponse): if isinstance(response, MediaResponse):
return await self._process_image_response(response, model, provider_name, response_format, proxy) return await self._process_image_response(response, model, provider_name, response_format, proxy)
if response is None: if response is None:
if error is not None: if error is not None:
raise error raise error
raise NoImageResponseError(f"No image response from {provider_name}") raise NoMediaResponseError(f"No image response from {provider_name}")
raise NoImageResponseError(f"Unexpected response type: {type(response)}") raise NoMediaResponseError(f"Unexpected response type: {type(response)}")
async def _process_image_response( async def _process_image_response(
self, self,
response: ImageResponse, response: MediaResponse,
model: str, model: str,
provider: str, provider: str,
response_format: Optional[str] = None, response_format: Optional[str] = None,
@@ -533,7 +534,7 @@ class Images:
else: else:
# Save locally for None (default) case # Save locally for None (default) case
images = await copy_media(response.get_list(), response.get("cookies"), proxy) images = await copy_media(response.get_list(), response.get("cookies"), proxy)
images = [Image.model_construct(url=f"/images/{os.path.basename(image)}", revised_prompt=response.alt) for image in images] images = [Image.model_construct(url=f"/media/{os.path.basename(image)}", revised_prompt=response.alt) for image in images]
return ImagesResponse.model_construct( return ImagesResponse.model_construct(
created=int(time.time()), created=int(time.time()),
@@ -552,6 +553,7 @@ class AsyncClient(BaseClient):
super().__init__(**kwargs) super().__init__(**kwargs)
self.chat: AsyncChat = AsyncChat(self, provider) self.chat: AsyncChat = AsyncChat(self, provider)
self.images: AsyncImages = AsyncImages(self, image_provider) self.images: AsyncImages = AsyncImages(self, image_provider)
self.media: AsyncImages = self.images
class AsyncChat: class AsyncChat:
completions: AsyncCompletions completions: AsyncCompletions

View File

@@ -34,7 +34,7 @@ class NestAsyncioError(MissingRequirementsError):
class MissingAuthError(Exception): class MissingAuthError(Exception):
... ...
class NoImageResponseError(Exception): class NoMediaResponseError(Exception):
... ...
class ResponseError(Exception): class ResponseError(Exception):

View File

@@ -49,7 +49,7 @@
<script src="https://cdn.jsdelivr.net/npm/gpt-tokenizer/dist/cl100k_base.js" async></script> <script src="https://cdn.jsdelivr.net/npm/gpt-tokenizer/dist/cl100k_base.js" async></script>
<script src="https://cdn.jsdelivr.net/npm/gpt-tokenizer/dist/o200k_base.js" async></script> <script src="https://cdn.jsdelivr.net/npm/gpt-tokenizer/dist/o200k_base.js" async></script>
</template> </template>
<script> <script async>
if (localStorage.getItem("countTokens") != "false") { if (localStorage.getItem("countTokens") != "false") {
const template = document.head.querySelector('template'); const template = document.head.querySelector('template');
document.head.appendChild(template.content); document.head.appendChild(template.content);
@@ -61,10 +61,14 @@
const gpt_image = '<img src="/static/img/gpt.png" alt="your avatar">'; const gpt_image = '<img src="/static/img/gpt.png" alt="your avatar">';
</script> </script>
<script src="/static/js/highlight.min.js" async></script> <script src="/static/js/highlight.min.js" async></script>
<script>window.conversation_id = "{{chat_id}}"</script> <script>window.conversation_id = "{{conversation_id}}"</script>
<script>window.chat_id = "{{chat_id}}"</script>
<title>G4F Chat</title> <title>G4F Chat</title>
</head> </head>
<body> <body>
<script async>
localStorage.getItem("darkMode") == "false" ? document.body.classList.add("white") : null;
</script>
<div class="gradient"></div> <div class="gradient"></div>
<div class="sidebar shown"> <div class="sidebar shown">
<div class="top"> <div class="top">
@@ -118,7 +122,7 @@
<label for="hide-systemPrompt" class="toogle" title="For more space on phones"></label> <label for="hide-systemPrompt" class="toogle" title="For more space on phones"></label>
</div> </div>
<div class="field"> <div class="field">
<span class="label">Download generated images</span> <span class="label">Download generated images, audios and videos</span>
<input type="checkbox" id="download_media" checked/> <input type="checkbox" id="download_media" checked/>
<label for="download_media" class="toogle" title="Download and save generated images, audios and videos"></label> <label for="download_media" class="toogle" title="Download and save generated images, audios and videos"></label>
</div> </div>
@@ -212,7 +216,7 @@
G4F Chat G4F Chat
</div> </div>
<textarea id="chatPrompt" class="box" placeholder="System prompt"></textarea> <textarea id="chatPrompt" class="box" placeholder="System prompt"></textarea>
<button class="slide-systemPrompt"> <button class="slide-header">
<i class="fa-solid fa-angles-up"></i> <i class="fa-solid fa-angles-up"></i>
</button> </button>
<div class="chat-body" id="chatBody"></div> <div class="chat-body" id="chatBody"></div>

View File

@@ -321,8 +321,9 @@ body:not(.white) a:visited{
white-space: pre-wrap; white-space: pre-wrap;
} }
.message .content img{ .message .content img, .message .content video{
max-width: 400px; max-width: 400px;
max-height: 400px;
} }
.message .content .audio{ .message .content .audio{
@@ -667,8 +668,9 @@ input-count .text {
.micro-label { .micro-label {
cursor: pointer; cursor: pointer;
position: absolute; position: absolute;
top: 10px; top: 8px;
left: 10px; left: 8px;
padding: 2px;
} }
.file-label:has(> input:valid), .file-label:has(> input:valid),
@@ -683,11 +685,12 @@ input-count .text {
} }
label.image-label { label.image-label {
top: 32px; top: 30px;
} }
label[for="micro"] { label[for="micro"] {
top: 54px; top: 50px;
padding: 4px;
} }
@media (pointer:none), (pointer:coarse) { @media (pointer:none), (pointer:coarse) {
@@ -872,7 +875,7 @@ input.model:hover
min-height: 59px; min-height: 59px;
height: 59px; height: 59px;
resize: vertical; resize: vertical;
padding: var(--inner-gap) var(--section-gap); padding: var(--inner-gap) 28px;
} }
#systemPrompt, #chatPrompt, .settings textarea, form textarea { #systemPrompt, #chatPrompt, .settings textarea, form textarea {
@@ -935,15 +938,15 @@ input.model:hover
top: auto !important; top: auto !important;
} }
.slide-systemPrompt { .slide-header {
position: absolute; position: absolute;
top: 42px; top: 0;
z-index: 1; z-index: 1;
padding: var(--inner-gap) 10px; padding: 10px;
border: none; border: none;
background: transparent; background: transparent;
cursor: pointer; cursor: pointer;
height: 49px; height: 40px;
color: var(--colour-3); color: var(--colour-3);
} }
@@ -1159,7 +1162,7 @@ ul {
} }
.sidebar.shown { .sidebar.shown {
width: 300px; width: 400px;
padding: 15px; padding: 15px;
margin-right: 10px; margin-right: 10px;
} }
@@ -1475,7 +1478,7 @@ form .field.saved .fa-xmark {
.conversation .user-input, .conversation .user-input,
.conversation .chat-buttons, .conversation .chat-buttons,
.conversation .chat-toolbar, .conversation .chat-toolbar,
.conversation .slide-systemPrompt, .conversation .slide-header,
.message .count i, .message .count i,
.message .assistant, .message .assistant,
.message .user { .message .user {
@@ -1509,7 +1512,7 @@ form .field.saved .fa-xmark {
overflow: hidden; overflow: hidden;
} }
.chat-header { .chat-header {
padding: 10px; padding: 10px 28px;
font-weight: 500; font-weight: 500;
white-space: nowrap; white-space: nowrap;
text-overflow: ellipsis; text-overflow: ellipsis;
@@ -1549,10 +1552,12 @@ form .field.saved .fa-xmark {
.chat-footer .send-buttons button { .chat-footer .send-buttons button {
background: var(--blur-bg); background: var(--blur-bg);
color: white; color: white;
border: none;
padding: 12px 15px; padding: 12px 15px;
margin: 0 10px; margin: 0 10px;
border-radius: 5px; border-radius: 5px;
cursor: pointer; cursor: pointer;
border: 1px dashed #e4d4ffa6; border: 1px dashed #e4d4ffa6;
}
.chat-footer .send-buttons button:hover {
border-style: solid;
} }

View File

@@ -63,8 +63,6 @@ appStorage = window.localStorage || {
length: 0 length: 0
} }
appStorage.getItem("darkMode") == "false" ? document.body.classList.add("white") : null;
let markdown_render = (content) => escapeHtml(content); let markdown_render = (content) => escapeHtml(content);
if (window.markdownit) { if (window.markdownit) {
const markdown = window.markdownit(); const markdown = window.markdownit();
@@ -81,7 +79,7 @@ if (window.markdownit) {
.replaceAll('<code>', '<code class="language-plaintext">') .replaceAll('<code>', '<code class="language-plaintext">')
.replaceAll('&lt;i class=&quot;', '<i class="') .replaceAll('&lt;i class=&quot;', '<i class="')
.replaceAll('&quot;&gt;&lt;/i&gt;', '"></i>') .replaceAll('&quot;&gt;&lt;/i&gt;', '"></i>')
.replaceAll('&lt;video controls src=&quot;', '<video controls width="400" src="') .replaceAll('&lt;video controls src=&quot;', '<video controls loop src="')
.replaceAll('&quot;&gt;&lt;/video&gt;', '"></video>') .replaceAll('&quot;&gt;&lt;/video&gt;', '"></video>')
.replaceAll('&lt;audio controls src=&quot;', '<audio controls src="') .replaceAll('&lt;audio controls src=&quot;', '<audio controls src="')
.replaceAll('&quot;&gt;&lt;/audio&gt;', '"></audio>') .replaceAll('&quot;&gt;&lt;/audio&gt;', '"></audio>')
@@ -581,7 +579,7 @@ stop_generating.addEventListener("click", async () => {
} }
} }
} }
await load_conversation(window.conversation_id, false); await safe_load_conversation(window.conversation_id, false);
}); });
document.querySelector(".media-player .fa-x").addEventListener("click", ()=>{ document.querySelector(".media-player .fa-x").addEventListener("click", ()=>{
@@ -831,7 +829,7 @@ async function add_message_chunk(message, message_id, provider, scroll, finish_m
for (const [key, value] of Object.entries(message.conversation)) { for (const [key, value] of Object.entries(message.conversation)) {
conversation.data[key] = value; conversation.data[key] = value;
} }
await save_conversation(conversation_id, conversation); await save_conversation(conversation_id, conversation, false);
} else if (message.type == "provider") { } else if (message.type == "provider") {
provider_storage[message_id] = message.provider; provider_storage[message_id] = message.provider;
let provider_el = content_map.content.querySelector('.provider'); let provider_el = content_map.content.querySelector('.provider');
@@ -1289,6 +1287,7 @@ const delete_conversation = async (conversation_id) => {
}; };
const set_conversation = async (conversation_id) => { const set_conversation = async (conversation_id) => {
window.chat_id = null;
if (title_ids_storage[conversation_id]) { if (title_ids_storage[conversation_id]) {
conversation_id = title_ids_storage[conversation_id]; conversation_id = title_ids_storage[conversation_id];
} }
@@ -1300,7 +1299,7 @@ const set_conversation = async (conversation_id) => {
window.conversation_id = conversation_id; window.conversation_id = conversation_id;
await clear_conversation(); await clear_conversation();
await load_conversation(conversation_id); await load_conversation(await get_conversation(conversation_id));
load_conversations(); load_conversations();
hide_sidebar(true); hide_sidebar(true);
}; };
@@ -1364,15 +1363,14 @@ function merge_messages(message1, message2) {
// console.log(merge_messages("1 != 2", "```python\n1 != 2;")); // console.log(merge_messages("1 != 2", "```python\n1 != 2;"));
// console.log(merge_messages("1 != 2;\n1 != 3;\n", "1 != 2;\n1 != 3;\n")); // console.log(merge_messages("1 != 2;\n1 != 3;\n", "1 != 2;\n1 != 3;\n"));
const load_conversation = async (conversation_id, scroll=true) => { const load_conversation = async (conversation, scroll=true) => {
let conversation = await get_conversation(conversation_id);
let messages = conversation?.items || [];
console.debug("Conversation:", conversation)
if (!conversation) { if (!conversation) {
return; return;
} }
let title = conversation.title || conversation.new_title; let messages = conversation?.items || [];
console.debug("Conversation:", conversation.id)
let title = conversation.new_title || conversation.title;
title = title ? `${title} - G4F` : window.title; title = title ? `${title} - G4F` : window.title;
if (title) { if (title) {
document.title = title; document.title = title;
@@ -1550,7 +1548,8 @@ async function safe_load_conversation(conversation_id, scroll=true) {
} }
} }
if (!is_running) { if (!is_running) {
return await load_conversation(conversation_id, scroll); let conversation = await get_conversation(conversation_id);
return await load_conversation(conversation, scroll);
} }
} }
@@ -1563,9 +1562,10 @@ async function get_conversation(conversation_id) {
async function save_conversation(conversation_id, conversation) { async function save_conversation(conversation_id, conversation) {
conversation.updated = Date.now(); conversation.updated = Date.now();
const data = JSON.stringify(conversation)
appStorage.setItem( appStorage.setItem(
`conversation:${conversation_id}`, `conversation:${conversation_id}`,
JSON.stringify(conversation) data
); );
} }
@@ -1617,6 +1617,14 @@ const remove_message = async (conversation_id, index) => {
} }
conversation.items = new_items; conversation.items = new_items;
await save_conversation(conversation_id, conversation); await save_conversation(conversation_id, conversation);
if (window.chat_id) {
const url = `/backend-api/v2/chat/${window.chat_id}`;
response = await fetch(url, {
method: 'POST',
headers: {'content-type': 'application/json'},
body: data,
});
}
}; };
const get_message = async (conversation_id, index) => { const get_message = async (conversation_id, index) => {
@@ -1685,6 +1693,14 @@ const add_message = async (
conversation.items = new_messages; conversation.items = new_messages;
} }
await save_conversation(conversation_id, conversation); await save_conversation(conversation_id, conversation);
if (window.chat_id) {
const url = `/backend-api/v2/chat/${window.chat_id}`;
fetch(url, {
method: 'POST',
headers: {'content-type': 'application/json'},
body: JSON.stringify(conversation),
});
}
return conversation.items.length - 1; return conversation.items.length - 1;
}; };
@@ -2021,12 +2037,56 @@ chatPrompt.addEventListener("input", function() {
}); });
window.addEventListener('load', async function() { window.addEventListener('load', async function() {
if (!window.conversation_id) {
window.conversation_id = window.chat_id;
}
const response = await fetch(`/backend-api/v2/chat/${window.chat_id ? window.chat_id : window.conversation_id}`, {
headers: {'accept': 'application/json'},
});
if (response.ok) {
let conversation = await response.json();
if (window.chat_id && (!window.conversation_id || conversation.id == window.conversation_id)) {
window.conversation_id = conversation.id;
await load_conversation(conversation);
appStorage.setItem(
`conversation:${conversation.id}`,
JSON.stringify(conversation)
);
let refreshOnHide = true;
document.addEventListener("visibilitychange", () => {
if (document.hidden) {
refreshOnHide = false;
} else {
refreshOnHide = true;
}
});
return setInterval(async () => {
if (!refreshOnHide || !window.chat_id) {
return;
}
const response = await fetch(`/backend-api/v2/chat/${window.chat_id}`, {
headers: {'accept': 'application/json', 'if-none-match': conversation.updated},
});
if (response.status == 200) {
const new_conversation = await response.json();
if (conversation.id == window.conversation_id && new_conversation.updated != conversation.updated) {
conversation = new_conversation;
appStorage.setItem(
`conversation:${conversation.id}`,
JSON.stringify(conversation)
);
await load_conversation(conversation);
}
}
}, 5000);
}
}
await safe_load_conversation(window.conversation_id, false); await safe_load_conversation(window.conversation_id, false);
}); });
window.addEventListener('DOMContentLoaded', async function() { window.addEventListener('DOMContentLoaded', async function() {
await on_load(); await on_load();
if (window.conversation_id == "{{chat_id}}") { if (!window.conversation_id == "{{chat_id}}") {
window.conversation_id = uuid(); window.conversation_id = uuid();
} else { } else {
await on_api(); await on_api();
@@ -2289,11 +2349,9 @@ async function on_api() {
); );
const hide_systemPrompt = document.getElementById("hide-systemPrompt") const hide_systemPrompt = document.getElementById("hide-systemPrompt")
const slide_systemPrompt_icon = document.querySelector(".slide-systemPrompt i"); const slide_systemPrompt_icon = document.querySelector(".slide-header i");
if (hide_systemPrompt.checked) { if (hide_systemPrompt.checked) {
chatPrompt.classList.add("hidden"); chatPrompt.classList.add("hidden");
slide_systemPrompt_icon.classList.remove("fa-angles-up");
slide_systemPrompt_icon.classList.add("fa-angles-down");
} }
hide_systemPrompt.addEventListener('change', async (event) => { hide_systemPrompt.addEventListener('change', async (event) => {
if (event.target.checked) { if (event.target.checked) {
@@ -2302,10 +2360,10 @@ async function on_api() {
chatPrompt.classList.remove("hidden"); chatPrompt.classList.remove("hidden");
} }
}); });
document.querySelector(".slide-systemPrompt")?.addEventListener("click", () => { document.querySelector(".slide-header")?.addEventListener("click", () => {
hide_systemPrompt.click(); const checked = slide_systemPrompt_icon.classList.contains("fa-angles-up");
const checked = hide_systemPrompt.checked; document.querySelector(".chat-header").classList[checked ? "add": "remove"]("hidden");
chatPrompt.classList[checked ? "add": "remove"]("hidden"); chatPrompt.classList[checked || hide_systemPrompt.checked ? "add": "remove"]("hidden");
slide_systemPrompt_icon.classList[checked ? "remove": "add"]("fa-angles-up"); slide_systemPrompt_icon.classList[checked ? "remove": "add"]("fa-angles-up");
slide_systemPrompt_icon.classList[checked ? "add": "remove"]("fa-angles-down"); slide_systemPrompt_icon.classList[checked ? "add": "remove"]("fa-angles-down");
}); });
@@ -2361,7 +2419,6 @@ async function load_version() {
} }
function renderMediaSelect() { function renderMediaSelect() {
mediaSelect.classList.remove("hidden");
const oldImages = mediaSelect.querySelectorAll("a:has(img)"); const oldImages = mediaSelect.querySelectorAll("a:has(img)");
oldImages.forEach((el)=>el.remove()); oldImages.forEach((el)=>el.remove());
Object.entries(image_storage).forEach(([object_url, file]) => { Object.entries(image_storage).forEach(([object_url, file]) => {
@@ -2472,10 +2529,15 @@ function connectToSSE(url, do_refine, bucket_id) {
inputCount.innerText = `Download: ${data.count} files`; inputCount.innerText = `Download: ${data.count} files`;
} else if (data.action == "done") { } else if (data.action == "done") {
if (do_refine) { if (do_refine) {
do_refine = false; connectToSSE(`/backend-api/v2/files/${bucket_id}?refine_chunks_with_spacy=true`, false, bucket_id);
connectToSSE(`/backend-api/v2/files/${bucket_id}?refine_chunks_with_spacy=true`, do_refine, bucket_id);
return; return;
} }
fileInput.value = "";
paperclip.classList.remove("blink");
if (!data.size) {
inputCount.innerText = "No content found";
return
}
appStorage.setItem(`bucket:${bucket_id}`, data.size); appStorage.setItem(`bucket:${bucket_id}`, data.size);
inputCount.innerText = "Files are loaded successfully"; inputCount.innerText = "Files are loaded successfully";
if (!userInput.value) { if (!userInput.value) {
@@ -2483,8 +2545,6 @@ function connectToSSE(url, do_refine, bucket_id) {
handle_ask(false); handle_ask(false);
} else { } else {
userInput.value += (userInput.value ? "\n" : "") + JSON.stringify({bucket_id: bucket_id}) + "\n"; userInput.value += (userInput.value ? "\n" : "") + JSON.stringify({bucket_id: bucket_id}) + "\n";
paperclip.classList.remove("blink");
fileInput.value = "";
} }
} }
}; };
@@ -2518,9 +2578,10 @@ async function upload_files(fileInput) {
} }
if (result.media) { if (result.media) {
result.media.forEach((filename)=> { result.media.forEach((filename)=> {
const url = `/backend-api/v2/files/${bucket_id}/media/${filename}`; const url = `/files/${bucket_id}/media/${filename}`;
image_storage[url] = {bucket_id: bucket_id, name: filename}; image_storage[url] = {bucket_id: bucket_id, name: filename};
}); });
mediaSelect.classList.remove("hidden");
renderMediaSelect(); renderMediaSelect();
} }
} }

View File

@@ -58,6 +58,7 @@ class Backend_Api(Api):
app (Flask): Flask application instance to attach routes to. app (Flask): Flask application instance to attach routes to.
""" """
self.app: Flask = app self.app: Flask = app
self.chat_cache = {}
if app.demo: if app.demo:
@app.route('/', methods=['GET']) @app.route('/', methods=['GET'])
@@ -210,6 +211,10 @@ class Backend_Api(Api):
'/images/<path:name>': { '/images/<path:name>': {
'function': self.serve_images, 'function': self.serve_images,
'methods': ['GET'] 'methods': ['GET']
},
'/media/<path:name>': {
'function': self.serve_images,
'methods': ['GET']
} }
} }
@@ -359,6 +364,33 @@ class Backend_Api(Api):
return "File saved", 200 return "File saved", 200
return 'Not supported file', 400 return 'Not supported file', 400
@self.app.route('/backend-api/v2/chat/<chat_id>', methods=['GET'])
def get_chat(chat_id: str) -> str:
chat_id = secure_filename(chat_id)
if int(self.chat_cache.get(chat_id, -1)) == int(request.headers.get("if-none-match", 0)):
return jsonify({"error": {"message": "Not modified"}}), 304
bucket_dir = get_bucket_dir(chat_id)
file = os.path.join(bucket_dir, "chat.json")
if not os.path.isfile(file):
return jsonify({"error": {"message": "Not found"}}), 404
with open(file, 'r') as f:
chat_data = json.load(f)
if int(chat_data.get("updated", 0)) == int(request.headers.get("if-none-match", 0)):
return jsonify({"error": {"message": "Not modified"}}), 304
self.chat_cache[chat_id] = chat_data.get("updated", 0)
return jsonify(chat_data), 200
@self.app.route('/backend-api/v2/chat/<chat_id>', methods=['POST'])
def upload_chat(chat_id: str) -> dict:
chat_data = {**request.json}
chat_id = secure_filename(chat_id)
bucket_dir = get_bucket_dir(chat_id)
os.makedirs(bucket_dir, exist_ok=True)
with open(os.path.join(bucket_dir, "chat.json"), 'w') as f:
json.dump(chat_data, f)
self.chat_cache[chat_id] = chat_data.get("updated", 0)
return {"chat_id": chat_id}
def handle_synthesize(self, provider: str): def handle_synthesize(self, provider: str):
try: try:
provider_handler = convert_to_provider(provider) provider_handler = convert_to_provider(provider)

View File

@@ -16,6 +16,14 @@ class Website:
'function': self._chat, 'function': self._chat,
'methods': ['GET', 'POST'] 'methods': ['GET', 'POST']
}, },
'/chat/<chat_id>/': {
'function': self._chat_id,
'methods': ['GET', 'POST']
},
'/chat/<chat_id>/<conversation_id>': {
'function': self._chat_id,
'methods': ['GET', 'POST']
},
'/chat/menu/': { '/chat/menu/': {
'function': redirect_home, 'function': redirect_home,
'methods': ['GET', 'POST'] 'methods': ['GET', 'POST']
@@ -32,11 +40,14 @@ class Website:
def _chat(self, conversation_id): def _chat(self, conversation_id):
if conversation_id == "share": if conversation_id == "share":
return render_template('index.html', chat_id=str(uuid.uuid4())) return render_template('index.html', conversation_id=str(uuid.uuid4()))
return render_template('index.html', chat_id=conversation_id) return render_template('index.html', conversation_id=conversation_id)
def _chat_id(self, chat_id, conversation_id: str = ""):
return render_template('index.html', chat_id=chat_id, conversation_id=conversation_id)
def _index(self): def _index(self):
return render_template('index.html', chat_id=str(uuid.uuid4())) return render_template('index.html', conversation_id=str(uuid.uuid4()))
def _settings(self): def _settings(self):
return render_template('index.html', chat_id=str(uuid.uuid4())) return render_template('index.html', conversation_id=str(uuid.uuid4()))

View File

@@ -26,13 +26,32 @@ ALLOWED_EXTENSIONS = {
'mkv', 'webm', 'mp4' 'mkv', 'webm', 'mp4'
} }
EXTENSIONS_MAP: dict[str, str] = { MEDIA_TYPE_MAP: dict[str, str] = {
"image/png": "png", "image/png": "png",
"image/jpeg": "jpg", "image/jpeg": "jpg",
"image/gif": "gif", "image/gif": "gif",
"image/webp": "webp", "image/webp": "webp",
} }
EXTENSIONS_MAP: dict[str, str] = {
# Image
"png": "image/png",
"jpg": "image/jpeg",
"jpeg": "image/jpeg",
"gif": "image/gif",
"webp": "image/webp",
# Audio
"wav": "audio/wav",
"mp3": "audio/mpeg",
"flac": "audio/flac",
"opus": "audio/opus",
"ogg": "audio/ogg",
# Video
"mkv": "video/x-matroska",
"webm": "video/webm",
"mp4": "video/mp4",
}
def to_image(image: ImageType, is_svg: bool = False) -> Image: def to_image(image: ImageType, is_svg: bool = False) -> Image:
""" """
Converts the input image to a PIL Image object. Converts the input image to a PIL Image object.

View File

@@ -11,7 +11,7 @@ from aiohttp import ClientSession, ClientError
from ..typing import Optional, Cookies from ..typing import Optional, Cookies
from ..requests.aiohttp import get_connector, StreamResponse from ..requests.aiohttp import get_connector, StreamResponse
from ..image import EXTENSIONS_MAP, ALLOWED_EXTENSIONS from ..image import MEDIA_TYPE_MAP, ALLOWED_EXTENSIONS
from ..tools.files import get_bucket_dir from ..tools.files import get_bucket_dir
from ..providers.response import ImageResponse, AudioResponse, VideoResponse from ..providers.response import ImageResponse, AudioResponse, VideoResponse
from ..Provider.template import BackendApi from ..Provider.template import BackendApi
@@ -51,10 +51,13 @@ def secure_filename(filename: str) -> str:
filename = filename[:100].strip(".,_-") filename = filename[:100].strip(".,_-")
return filename return filename
def is_valid_media_type(content_type: str) -> bool:
return content_type in MEDIA_TYPE_MAP or content_type.startswith("audio/") or content_type.startswith("video/")
async def save_response_media(response: StreamResponse, prompt: str): async def save_response_media(response: StreamResponse, prompt: str):
content_type = response.headers["content-type"] content_type = response.headers["content-type"]
if content_type in EXTENSIONS_MAP or content_type.startswith("audio/") or content_type.startswith("video/"): if is_valid_media_type(content_type):
extension = EXTENSIONS_MAP[content_type] if content_type in EXTENSIONS_MAP else content_type[6:].replace("mpeg", "mp3") extension = MEDIA_TYPE_MAP[content_type] if content_type in MEDIA_TYPE_MAP else content_type[6:].replace("mpeg", "mp3")
if extension not in ALLOWED_EXTENSIONS: if extension not in ALLOWED_EXTENSIONS:
raise ValueError(f"Unsupported media type: {content_type}") raise ValueError(f"Unsupported media type: {content_type}")
bucket_id = str(uuid.uuid4()) bucket_id = str(uuid.uuid4())
@@ -131,6 +134,8 @@ async def copy_media(
async with session.get(image, ssl=request_ssl, headers=request_headers) as response: async with session.get(image, ssl=request_ssl, headers=request_headers) as response:
response.raise_for_status() response.raise_for_status()
if not is_valid_media_type(response.headers.get("content-type")):
raise ValueError(f"Unsupported media type: {response.headers.get('content-type')}")
with open(target_path, "wb") as f: with open(target_path, "wb") as f:
async for chunk in response.content.iter_chunked(4096): async for chunk in response.content.iter_chunked(4096):
f.write(chunk) f.write(chunk)
@@ -150,9 +155,9 @@ async def copy_media(
# Build URL with safe encoding # Build URL with safe encoding
url_filename = quote(os.path.basename(target_path)) url_filename = quote(os.path.basename(target_path))
return f"/images/{url_filename}" + (('?url=' + quote(image)) if add_url and not image.startswith('data:') else '') return f"/media/{url_filename}" + (('?url=' + quote(image)) if add_url and not image.startswith('data:') else '')
except (ClientError, IOError, OSError) as e: except (ClientError, IOError, OSError, ValueError) as e:
debug.error(f"Image copying failed: {type(e).__name__}: {e}") debug.error(f"Image copying failed: {type(e).__name__}: {e}")
if target_path and os.path.exists(target_path): if target_path and os.path.exists(target_path):
os.unlink(target_path) os.unlink(target_path)