Add example for video generation

Add support for images in messages
This commit is contained in:
hlohaus
2025-03-27 09:38:31 +01:00
parent db1cfc48bc
commit 46d0b87008
26 changed files with 410 additions and 230 deletions

View File

@@ -21,6 +21,7 @@ The G4F AsyncClient API is designed to be compatible with the OpenAI API, making
- [Using a Vision Model](#using-a-vision-model)
- **[Transcribing Audio with Chat Completions](#transcribing-audio-with-chat-completions)** *(New Section)*
- [Image Generation](#image-generation)
- **[Video Generation](#video-generation)** *(New Section)*
- [Advanced Usage](#advanced-usage)
- [Conversation Memory](#conversation-memory)
- [Search Tool Support](#search-tool-support)
@@ -327,6 +328,46 @@ asyncio.run(main())
---
### Video Generation
The G4F `AsyncClient` also supports **video generation** through supported providers like `HuggingFaceMedia`. You can retrieve the list of available video models and generate videos from prompts.
**Example: Generate a video using a prompt**
```python
import asyncio
from g4f.client import AsyncClient
from g4f.Provider import HuggingFaceMedia
async def main():
client = AsyncClient(
provider=HuggingFaceMedia,
api_key="hf_***" # Your API key here
)
# Get available video models
video_models = client.models.get_video()
print("Available Video Models:", video_models)
# Generate video
result = await client.media.generate(
model=video_models[0],
prompt="G4F AI technology is the best in the world.",
response_format="url"
)
print("Generated Video URL:", result.data[0].url)
asyncio.run(main())
```
#### Explanation
- **Client Initialization**: An `AsyncClient` is initialized using the `HuggingFaceMedia` provider with an API key.
- **Model Discovery**: `client.models.get_video()` fetches a list of supported video models.
- **Video Generation**: A prompt is submitted to generate a video using `await client.media.generate(...)`.
- **Output**: The result includes a URL to the generated video, accessed via `result.data[0].url`.
> Make sure your selected provider supports media generation and your API key has appropriate permissions.
## Advanced Usage

View File

@@ -16,9 +16,9 @@ from ..requests.raise_for_status import raise_for_status
from .base_provider import AsyncGeneratorProvider, ProviderModelMixin
from ..image import to_data_uri
from ..cookies import get_cookies_dir
from .helper import format_prompt, format_image_prompt
from .helper import format_image_prompt
from ..providers.response import JsonConversation, ImageResponse
from ..errors import ModelNotSupportedError
from ..tools.media import merge_media
from .. import debug
class Conversation(JsonConversation):
@@ -488,7 +488,7 @@ class Blackbox(AsyncGeneratorProvider, ProviderModelMixin):
"filePath": f"/{image_name}",
"contents": to_data_uri(image)
}
for image, image_name in media
for image, image_name in merge_media(media, messages)
],
"fileText": "",
"title": ""

View File

@@ -24,8 +24,9 @@ from .openai.har_file import get_headers, get_har_files
from ..typing import CreateResult, Messages, MediaListType
from ..errors import MissingRequirementsError, NoValidHarFileError, MissingAuthError
from ..requests.raise_for_status import raise_for_status
from ..providers.response import BaseConversation, JsonConversation, RequestLogin, Parameters, ImageResponse
from ..providers.response import BaseConversation, JsonConversation, RequestLogin, ImageResponse
from ..providers.asyncio import get_running_loop
from ..tools.media import merge_media
from ..requests import get_nodriver
from ..image import to_bytes, is_accepted_format
from .helper import get_last_user_message
@@ -142,17 +143,18 @@ class Copilot(AbstractProvider, ProviderModelMixin):
debug.log(f"Copilot: Use conversation: {conversation_id}")
uploaded_images = []
if media is not None:
for image, _ in media:
data = to_bytes(image)
media, _ = [(None, None), *merge_media(media, messages)].pop()
if media:
if not isinstance(media, str):
data = to_bytes(media)
response = session.post(
"https://copilot.microsoft.com/c/api/attachments",
headers={"content-type": is_accepted_format(data)},
data=data
)
raise_for_status(response)
uploaded_images.append({"type":"image", "url": response.json().get("url")})
break
media = response.json().get("url")
uploaded_images.append({"type":"image", "url": media})
wss = session.ws_connect(cls.websocket_url)
# if clarity_token is not None:

View File

@@ -11,13 +11,14 @@ from aiohttp import ClientSession
from .helper import filter_none, format_image_prompt
from .base_provider import AsyncGeneratorProvider, ProviderModelMixin
from ..typing import AsyncResult, Messages, MediaListType
from ..image import to_data_uri, is_data_an_audio, to_input_audio
from ..image import is_data_an_audio
from ..errors import ModelNotFoundError
from ..requests.raise_for_status import raise_for_status
from ..requests.aiohttp import get_connector
from ..image.copy_images import save_response_media
from ..image import use_aspect_ratio
from ..providers.response import FinishReason, Usage, ToolCalls, ImageResponse
from ..tools.media import render_messages
from .. import debug
DEFAULT_HEADERS = {
@@ -285,32 +286,15 @@ class PollinationsAI(AsyncGeneratorProvider, ProviderModelMixin):
if response_format and response_format.get("type") == "json_object":
json_mode = True
if media and messages:
last_message = messages[-1].copy()
image_content = [
{
"type": "input_audio",
"input_audio": to_input_audio(media_data, filename)
}
if is_data_an_audio(media_data, filename) else {
"type": "image_url",
"image_url": {"url": to_data_uri(media_data)}
}
for media_data, filename in media
]
last_message["content"] = image_content + ([{"type": "text", "text": last_message["content"]}] if isinstance(last_message["content"], str) else image_content)
messages[-1] = last_message
async with ClientSession(headers=DEFAULT_HEADERS, connector=get_connector(proxy=proxy)) as session:
if model in cls.audio_models:
#data["voice"] = random.choice(cls.audio_models[model])
url = cls.text_api_endpoint
stream = False
else:
url = cls.openai_endpoint
extra_parameters = {param: kwargs[param] for param in extra_parameters if param in kwargs}
data = filter_none(**{
"messages": messages,
"messages": list(render_messages(messages, media)),
"model": model,
"temperature": temperature,
"presence_penalty": presence_penalty,
@@ -324,7 +308,7 @@ class PollinationsAI(AsyncGeneratorProvider, ProviderModelMixin):
})
async with session.post(url, json=data) as response:
await raise_for_status(response)
async for chunk in save_response_media(response, messages[-1]["content"], [model]):
async for chunk in save_response_media(response, format_image_prompt(messages), [model]):
yield chunk
return
if response.headers["content-type"].startswith("text/plain"):

View File

@@ -24,6 +24,7 @@ from ...requests import get_args_from_nodriver, DEFAULT_HEADERS
from ...requests.raise_for_status import raise_for_status
from ...providers.response import JsonConversation, ImageResponse, Sources, TitleGeneration, Reasoning, RequestLogin
from ...cookies import get_cookies
from ...tools.media import merge_media
from .models import default_model, default_vision_model, fallback_models, image_models, model_aliases
from ... import debug
@@ -146,8 +147,7 @@ class HuggingChat(AsyncAuthedProvider, ProviderModelMixin):
}
data = CurlMime()
data.addpart('data', data=json.dumps(settings, separators=(',', ':')))
if media is not None:
for image, filename in media:
for image, filename in merge_media(media, messages):
data.addpart(
"files",
filename=f"base64;{filename}",

View File

@@ -142,20 +142,30 @@ class HuggingFaceMedia(AsyncGeneratorProvider, ProviderModelMixin):
}
else:
extra_data = use_aspect_ratio(extra_data, "1:1" if aspect_ratio is None else aspect_ratio)
if provider_key == "fal-ai":
url = f"{api_base}/{provider_id}"
data = {
"prompt": prompt,
"image_size": "square_hd",
**extra_data
}
if provider_key == "fal-ai" and task == "text-to-image":
if aspect_ratio is None or aspect_ratio == "1:1":
image_size = "square_hd",
elif aspect_ratio == "16:9":
image_size = "landscape_hd",
elif aspect_ratio == "9:16":
image_size = "portrait_16_9"
else:
image_size = extra_data # width, height
data = {
"image_size": image_size,
**data
}
elif provider_key == "novita":
url = f"{api_base}/v3/hf/{provider_id}"
elif provider_key == "replicate":
url = f"{api_base}/v1/models/{provider_id}/predictions"
data = {
"input": {
"prompt": prompt,
**extra_data
}
"input": data
}
elif provider_key in ("hf-inference", "hf-free"):
api_base = "https://api-inference.huggingface.co"
@@ -171,9 +181,8 @@ class HuggingFaceMedia(AsyncGeneratorProvider, ProviderModelMixin):
url = f"{api_base}/v1/images/generations"
data = {
"response_format": "url",
"prompt": prompt,
"model": provider_id,
**extra_data
**data
}
async with StreamSession(
@@ -193,7 +202,7 @@ class HuggingFaceMedia(AsyncGeneratorProvider, ProviderModelMixin):
return provider_info, chunk
result = await response.json()
if "video" in result:
return provider_info, VideoResponse(result["video"]["url"], prompt)
return provider_info, VideoResponse(result.get("video").get("url", result.get("video").get("url")), prompt)#video_url
elif task == "text-to-image":
return provider_info, ImageResponse([item["url"] for item in result.get("images", result.get("data"))], prompt)
elif task == "text-to-video":

View File

@@ -20,7 +20,7 @@ class HuggingFace(AsyncGeneratorProvider, ProviderModelMixin):
supports_message_history = True
@classmethod
def get_models(cls) -> list[str]:
def get_models(cls, **kwargs) -> list[str]:
if not cls.models:
cls.models = HuggingFaceInference.get_models()
cls.image_models = HuggingFaceInference.image_models

View File

@@ -27,6 +27,7 @@ from ...requests import get_nodriver
from ...errors import MissingAuthError
from ...image import to_bytes
from ...cookies import get_cookies_dir
from ...tools.media import merge_media
from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin
from ..helper import format_prompt, get_cookies, get_last_user_message
from ... import debug
@@ -186,7 +187,7 @@ class Gemini(AsyncGeneratorProvider, ProviderModelMixin):
cls.start_auto_refresh()
)
uploads = None if media is None else await cls.upload_images(base_connector, media)
uploads = await cls.upload_images(base_connector, merge_media(media, messages))
async with ClientSession(
cookies=cls._cookies,
headers=REQUEST_HEADERS,

View File

@@ -25,7 +25,8 @@ from ...requests import get_nodriver
from ...image import ImageRequest, to_image, to_bytes, is_accepted_format
from ...errors import MissingAuthError, NoValidHarFileError
from ...providers.response import JsonConversation, FinishReason, SynthesizeData, AuthResult, ImageResponse
from ...providers.response import Sources, TitleGeneration, RequestLogin, Parameters, Reasoning
from ...providers.response import Sources, TitleGeneration, RequestLogin, Reasoning
from ...tools.media import merge_media
from ..helper import format_cookies, get_last_user_message
from ..openai.models import default_model, default_image_model, models, image_models, text_models
from ..openai.har_file import get_request_config
@@ -187,8 +188,6 @@ class OpenaiChat(AsyncAuthedProvider, ProviderModelMixin):
await raise_for_status(response, "Get download url failed")
image_data["download_url"] = (await response.json())["download_url"]
return ImageRequest(image_data)
if not media:
return
return [await upload_image(image, image_name) for image, image_name in media]
@classmethod
@@ -330,7 +329,7 @@ class OpenaiChat(AsyncAuthedProvider, ProviderModelMixin):
cls._update_request_args(auth_result, session)
await raise_for_status(response)
try:
image_requests = None if media is None else await cls.upload_images(session, auth_result, media)
image_requests = await cls.upload_images(session, auth_result, merge_media(media, messages))
except Exception as e:
debug.error("OpenaiChat: Upload image failed")
debug.error(e)

View File

@@ -7,8 +7,8 @@ from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin, RaiseErr
from ...typing import Union, AsyncResult, Messages, MediaListType
from ...requests import StreamSession, raise_for_status
from ...providers.response import FinishReason, ToolCalls, Usage, ImageResponse
from ...tools.media import render_messages
from ...errors import MissingAuthError, ResponseError
from ...image import to_data_uri, is_data_an_audio, to_input_audio
from ... import debug
class OpenaiTemplate(AsyncGeneratorProvider, ProviderModelMixin, RaiseErrorMixin):
@@ -97,27 +97,9 @@ class OpenaiTemplate(AsyncGeneratorProvider, ProviderModelMixin, RaiseErrorMixin
yield ImageResponse([image["url"] for image in data["data"]], prompt)
return
if media is not None and messages:
if not model and hasattr(cls, "default_vision_model"):
model = cls.default_vision_model
last_message = messages[-1].copy()
image_content = [
{
"type": "input_audio",
"input_audio": to_input_audio(media_data, filename)
}
if is_data_an_audio(media_data, filename) else {
"type": "image_url",
"image_url": {"url": to_data_uri(media_data)}
}
for media_data, filename in media
]
last_message["content"] = image_content + ([{"type": "text", "text": last_message["content"]}] if isinstance(last_message["content"], str) else image_content)
messages[-1] = last_message
extra_parameters = {key: kwargs[key] for key in extra_parameters if key in kwargs}
data = filter_none(
messages=messages,
messages=list(render_messages(messages, media)),
model=model,
temperature=temperature,
max_tokens=max_tokens,

View File

@@ -19,7 +19,7 @@ from ..providers.asyncio import to_sync_generator
from ..Provider.needs_auth import BingCreateImages, OpenaiAccount
from ..tools.run_tools import async_iter_run_tools, iter_run_tools
from .stubs import ChatCompletion, ChatCompletionChunk, Image, ImagesResponse, UsageModel, ToolCallModel
from .image_models import MediaModels
from .models import ClientModels
from .types import IterResponse, ImageProvider, Client as BaseClient
from .service import get_model_and_provider, convert_to_provider
from .helper import find_stop, filter_json, filter_none, safe_aclose
@@ -269,7 +269,7 @@ class Client(BaseClient):
self.chat: Chat = Chat(self, provider)
if image_provider is None:
image_provider = provider
self.models: MediaModels = MediaModels(self, image_provider)
self.models: ClientModels = ClientModels(self, provider, image_provider)
self.images: Images = Images(self, image_provider)
self.media: Images = self.images
@@ -558,7 +558,7 @@ class AsyncClient(BaseClient):
self.chat: AsyncChat = AsyncChat(self, provider)
if image_provider is None:
image_provider = provider
self.models: MediaModels = MediaModels(self, image_provider)
self.models: ClientModels = ClientModels(self, provider, image_provider)
self.images: AsyncImages = AsyncImages(self, image_provider)
self.media: AsyncImages = self.images

View File

@@ -1,43 +0,0 @@
from __future__ import annotations
from ..models import ModelUtils, ImageModel
from ..Provider import ProviderUtils
from ..providers.types import ProviderType
class MediaModels():
def __init__(self, client, provider: ProviderType = None):
self.client = client
self.provider = provider
def get(self, name, default=None) -> ProviderType:
if name in ModelUtils.convert:
return ModelUtils.convert[name].best_provider
if name in ProviderUtils.convert:
return ProviderUtils.convert[name]
return default
def get_all(self, api_key: str = None, **kwargs) -> list[str]:
if self.provider is None:
return []
if api_key is None:
api_key = self.client.api_key
return self.provider.get_models(
**kwargs,
**{} if api_key is None else {"api_key": api_key}
)
def get_image(self, **kwargs) -> list[str]:
if self.provider is None:
return [model_id for model_id, model in ModelUtils.convert.items() if isinstance(model, ImageModel)]
self.get_all(**kwargs)
if hasattr(self.provider, "image_models"):
return self.provider.image_models
return []
def get_video(self, **kwargs) -> list[str]:
if self.provider is None:
return []
self.get_all(**kwargs)
if hasattr(self.provider, "video_models"):
return self.provider.video_models
return []

62
g4f/client/models.py Normal file
View File

@@ -0,0 +1,62 @@
from __future__ import annotations
from ..models import ModelUtils, ImageModel, VisionModel
from ..Provider import ProviderUtils
from ..providers.types import ProviderType
class ClientModels():
def __init__(self, client, provider: ProviderType = None, media_provider: ProviderType = None):
self.client = client
self.provider = provider
self.media_provider = media_provider
def get(self, name, default=None) -> ProviderType:
if name in ModelUtils.convert:
return ModelUtils.convert[name].best_provider
if name in ProviderUtils.convert:
return ProviderUtils.convert[name]
return default
def get_all(self, api_key: str = None, **kwargs) -> list[str]:
if self.provider is None:
return []
if api_key is None:
api_key = self.client.api_key
return self.provider.get_models(
**kwargs,
**{} if api_key is None else {"api_key": api_key}
)
def get_vision(self, **kwargs) -> list[str]:
if self.provider is None:
return [model_id for model_id, model in ModelUtils.convert.items() if isinstance(model, VisionModel)]
self.get_all(**kwargs)
if hasattr(self.provider, "vision_models"):
return self.provider.vision_models
return []
def get_media(self, api_key: str = None, **kwargs) -> list[str]:
if self.media_provider is None:
return []
if api_key is None:
api_key = self.client.api_key
return self.media_provider.get_models(
**kwargs,
**{} if api_key is None else {"api_key": api_key}
)
def get_image(self, **kwargs) -> list[str]:
if self.media_provider is None:
return [model_id for model_id, model in ModelUtils.convert.items() if isinstance(model, ImageModel)]
self.get_media(**kwargs)
if hasattr(self.media_provider, "image_models"):
return self.media_provider.image_models
return []
def get_video(self, **kwargs) -> list[str]:
if self.media_provider is None:
return []
self.get_media(**kwargs)
if hasattr(self.media_provider, "video_models"):
return self.media_provider.video_models
return []

View File

@@ -89,7 +89,7 @@
</head>
<body>
<img id="image-feed" class="hidden" alt="Image Feed">
<video id="video-feed" class="hidden" alt="Video Feed" src="/search/video" autoplay></video>
<video id="video-feed" class="hidden" alt="Video Feed" src="/search/video+g4f" autoplay></video>
<!-- Gradient Background Circle -->
<div class="gradient"></div>
@@ -105,6 +105,7 @@
let skipImage = 0;
let errorVideo = 0;
let errorImage = 0;
let skipRefresh = 0;
videoFeed.onloadeddata = () => {
videoFeed.classList.remove("hidden");
gradient.classList.add("hidden");
@@ -116,15 +117,15 @@
gradient.classList.remove("hidden");
return;
}
videoFeed.src = "/search/video?skip=" + skipVideo;
videoFeed.src = "/search/video+g4f?skip=" + skipVideo;
skipVideo++;
};
videoFeed.onended = () => {
videoFeed.src = "/search/video?skip=" + skipVideo;
videoFeed.src = "/search/video+g4f?skip=" + skipVideo;
skipVideo++;
};
videoFeed.onclick = () => {
videoFeed.src = "/search/video?skip=" + skipVideo;
videoFeed.src = "/search/video+g4f?skip=" + skipVideo;
skipVideo++;
};
function initES() {
@@ -173,11 +174,15 @@
skipImage++;
return;
}
if (skipRefresh) {
skipRefresh = 0;
return;
}
if (images.length > 0) {
imageFeed.classList.remove("hidden");
imageFeed.src = images.shift();
gradient.classList.add("hidden");
} else if(imageFeed) {
} else {
initES();
}
}, 7000);
@@ -192,6 +197,7 @@
};
imageFeed.onclick = () => {
imageFeed.src = "/search/image?random=" + Math.random();
skipRefresh = 1;
};
})();
</script>

View File

@@ -81,16 +81,19 @@
border: none;
}
#background, #image-feed {
#background {
height: 100%;
position: absolute;
z-index: -1;
object-fit: cover;
object-position: center;
width: 100%;
background: black;
}
.container * {
z-index: 2;
}
.description, form p a {
font-size: 1.2rem;
margin-bottom: 30px;
@@ -176,9 +179,6 @@
<body>
<iframe id="background" src="/background"></iframe>
<!-- Gradient Background Circle -->
<div class="gradient"></div>
<button class="slide-button">
<i class="fa-solid fa-arrow-left"></i>
</button>

View File

@@ -48,7 +48,6 @@
align-items: center;
height: 100%;
text-align: center;
z-index: 1;
}
header {
@@ -67,7 +66,11 @@
#background {
height: 100%;
position: absolute;
z-index: -1;
top: 0;
}
.container * {
z-index: 2;
}
.stream-widget {

View File

@@ -270,7 +270,7 @@
<i class="fa-regular fa-image"></i>
</label>
<label class="file-label" for="file">
<input type="file" id="file" name="file" accept=".txt, .html, .xml, .json, .js, .har, .sh, .py, .php, .css, .yaml, .sql, .log, .csv, .twig, .md, .pdf, .docx, .odt, .epub, .xlsx, .zip" required multiple/>
<input type="file" id="file" name="file" accept="*/*" required multiple/>
<i class="fa-solid fa-paperclip"></i>
</label>
<label class="micro-label" for="micro">

View File

@@ -72,9 +72,15 @@
}
}
});
let share_id = null;
document.getElementById('generateQRCode').addEventListener('click', async () => {
const share_id = generate_uuid();
if (share_id) {
const delete_url = `${share_url}/backend-api/v2/files/${encodeURI(share_id)}`;
await fetch(delete_url, {
method: 'DELETE'
});
}
share_id = generate_uuid();
const url = `${share_url}/backend-api/v2/chat/${encodeURI(share_id)}`;
const response = await fetch(url, {

View File

@@ -67,6 +67,17 @@ let markdown_render = (content) => escapeHtml(content);
if (window.markdownit) {
const markdown = window.markdownit();
markdown_render = (content) => {
if (Array.isArray(content)) {
content = content.map((item) => {
if (item.name.endsWith(".wav") || item.name.endsWith(".mp3")) {
return `<audio controls src="${item.url}"></audio>`;
}
if (item.name.endsWith(".mp4") || item.name.endsWith(".webm")) {
return `<video controls src="${item.url}"></video>`;
}
return `[![${item.name}](${item.url})]()`;
}).join("\n");
}
return markdown.render(content
.replaceAll(/<!-- generated images start -->|<!-- generated images end -->/gm, "")
.replaceAll(/<img data-prompt="[^>]+">/gm, "")
@@ -95,7 +106,7 @@ function render_reasoning(reasoning, final = false) {
return `<div class="reasoning_body">
<div class="reasoning_title">
<strong>${reasoning.label ? reasoning.label :'Reasoning <i class="brain">🧠</i>'}: </strong>
${reasoning.status ? escapeHtml(reasoning.status) : '&nbsp;<i class="fas fa-spinner fa-spin"></i>'}
${reasoning.status ? escapeHtml(reasoning.status) : '<i class="fas fa-spinner fa-spin"></i>'}
</div>
${inner_text}
</div>`;
@@ -106,12 +117,18 @@ function render_reasoning_text(reasoning) {
}
function filter_message(text) {
if (Array.isArray(text)) {
return text;
}
return text.replaceAll(
/<!-- generated images start -->[\s\S]+<!-- generated images end -->/gm, ""
).replace(/ \[aborted\]$/g, "").replace(/ \[error\]$/g, "");
}
function filter_message_content(text) {
if (Array.isArray(text)) {
return text;
}
return text.replace(/ \[aborted\]$/g, "").replace(/ \[error\]$/g, "")
}
@@ -269,11 +286,12 @@ const register_message_buttons = async () => {
return
}
el.dataset.click = true;
const provider_forms = document.querySelector(".provider_forms");
const provider_form = provider_forms.querySelector(`#${el.dataset.provider}-form`);
const provider_link = el.querySelector("a");
provider_link?.addEventListener("click", async (event) => {
event.preventDefault();
await load_provider_parameters(el.dataset.provider);
const provider_forms = document.querySelector(".provider_forms");
const provider_form = provider_forms.querySelector(`#${el.dataset.provider}-form`);
if (provider_form) {
provider_form.classList.remove("hidden");
provider_forms.classList.remove("hidden");
@@ -281,11 +299,6 @@ const register_message_buttons = async () => {
}
return false;
});
document.getElementById("close_provider_forms").addEventListener("click", async () => {
provider_form.classList.add("hidden");
provider_forms.classList.add("hidden");
chat.classList.remove("hidden");
});
});
chatBody.querySelectorAll(".message .fa-xmark").forEach(async (el) => {
@@ -479,23 +492,24 @@ const delete_conversations = async () => {
await new_conversation();
};
const handle_ask = async (do_ask_gpt = true) => {
const handle_ask = async (do_ask_gpt = true, message = null) => {
userInput.style.height = "82px";
userInput.focus();
await scroll_to_bottom();
let message = userInput.value.trim();
if (message.length <= 0) {
if (!message) {
message = userInput.value.trim();
if (!message) {
return;
}
userInput.value = "";
await count_input()
await add_conversation(window.conversation_id);
}
// Is message a url?
const expression = /^https?:\/\/(www\.)?[-a-zA-Z0-9@:%._\+~#=]{1,256}\.[a-zA-Z0-9()]{1,6}\b([-a-zA-Z0-9()@:%_\+.~#?&//=]*)$/gi;
const regex = new RegExp(expression);
if (message.match(regex)) {
if (!Array.isArray(message) && message.match(regex)) {
paperclip.classList.add("blink");
const blob = new Blob([JSON.stringify([{url: message}])], { type: 'application/json' });
const file = new File([blob], 'downloads.json', { type: 'application/json' }); // Create File object
@@ -509,6 +523,8 @@ const handle_ask = async (do_ask_gpt = true) => {
connectToSSE(`/backend-api/v2/files/${bucket_id}`, false, bucket_id); //Retrieve and refine
return;
}
await add_conversation(window.conversation_id);
let message_index = await add_message(window.conversation_id, "user", message);
let message_id = get_message_id();
@@ -602,6 +618,12 @@ document.querySelector(".media-player .fa-x").addEventListener("click", ()=>{
media_player.removeChild(audio);
});
document.getElementById("close_provider_forms").addEventListener("click", async () => {
const provider_forms = document.querySelector(".provider_forms");
provider_forms.classList.add("hidden");
chat.classList.remove("hidden");
});
const prepare_messages = (messages, message_index = -1, do_continue = false, do_filter = true) => {
messages = [ ...messages ]
if (message_index != null) {
@@ -930,7 +952,6 @@ async function add_message_chunk(message, message_id, provider, scroll, finish_m
Object.entries(message.parameters).forEach(([key, value]) => {
parameters_storage[provider][key] = value;
});
await load_provider_parameters(provider);
}
}
@@ -1332,6 +1353,9 @@ const new_conversation = async () => {
};
function merge_messages(message1, message2) {
if (Array.isArray(message2)) {
return message2;
}
let newContent = message2;
// Remove start tokens
if (newContent.startsWith("```")) {
@@ -1530,6 +1554,8 @@ const load_conversation = async (conversation, scroll=true) => {
});
if (countTokensEnabled && window.GPTTokenizer_cl100k_base) {
const has_media = messages.filter((item)=>Array.isArray(item.content)).length > 0;
if (!has_media) {
const filtered = prepare_messages(messages, null, true, false);
if (filtered.length > 0) {
last_model = last_model?.startsWith("gpt-3") ? "gpt-3.5-turbo" : "gpt-4"
@@ -1539,11 +1565,9 @@ const load_conversation = async (conversation, scroll=true) => {
}
}
}
}
chatBody.innerHTML = elements.join("");
[...new Set(providers)].forEach(async (provider) => {
await load_provider_parameters(provider);
});
await register_message_buttons();
highlight(chatBody);
regenerate_button.classList.remove("regenerate-hidden");
@@ -1674,7 +1698,7 @@ const add_message = async (
}
if (title) {
conversation.title = title;
} else if (!conversation.title) {
} else if (!conversation.title && !Array.isArray(content)) {
let new_value = content.trim();
let new_lenght = new_value.indexOf("\n");
new_lenght = new_lenght > 200 || new_lenght < 0 ? 200 : new_lenght;
@@ -1728,8 +1752,10 @@ const add_message = async (
return conversation.items.length - 1;
};
const escapeHtml = (unsafe) => {
return unsafe+"".replaceAll('&', '&amp;').replaceAll('<', '&lt;').replaceAll('>', '&gt;').replaceAll('"', '&quot;').replaceAll("'", '&#039;');
function escapeHtml(str) {
const div = document.createElement('div');
div.appendChild(document.createTextNode(str));
return div.innerHTML;
}
const toLocaleDateString = (date) => {
@@ -1746,8 +1772,7 @@ const load_conversations = async () => {
}
}
conversations.sort((a, b) => (b.updated||0)-(a.updated||0));
let html = [];
await clear_conversations();
conversations.forEach((conversation) => {
// const length = conversation.items.map((item) => (
// !item.content.toLowerCase().includes("hello") &&
@@ -1759,8 +1784,10 @@ const load_conversations = async () => {
// return;
// }
const shareIcon = (conversation.id == window.start_id && window.share_id) ? '<i class="fa-solid fa-qrcode"></i>': '';
html.push(`
<div class="convo" id="convo-${conversation.id}">
let convo = document.createElement("div");
convo.classList.add("convo");
convo.id = `convo-${conversation.id}`;
convo.innerHTML = `
<div class="left" onclick="set_conversation('${conversation.id}')">
<i class="fa-regular fa-comments"></i>
<span class="datetime">${conversation.updated ? toLocaleDateString(conversation.updated) : ""}</span>
@@ -1771,11 +1798,9 @@ const load_conversations = async () => {
<i onclick="delete_conversation('${conversation.id}')" class="fa-solid fa-trash"></i>
<i onclick="hide_option('${conversation.id}')" class="fa-regular fa-x"></i>
</div>
</div>
`);
`;
box_conversations.appendChild(convo);
});
await clear_conversations();
box_conversations.innerHTML += html.join("");
};
const hide_input = document.querySelector(".chat-toolbar .hide-input");
@@ -1800,6 +1825,13 @@ const uuid = () => {
);
};
function generateSecureRandomString(length = 128) {
const chars = 'abcdefghijklmnopqrstuvwxyz0123456789';
const array = new Uint8Array(length);
crypto.getRandomValues(array);
return Array.from(array, byte => chars[byte % chars.length]).join('');
}
function get_message_id() {
random_bytes = (Math.floor(Math.random() * 1338377565) + 2956589730).toString(
2
@@ -2003,6 +2035,9 @@ function count_chars(text) {
}
function count_words_and_tokens(text, model, completion_tokens, prompt_tokens) {
if (Array.isArray(text)) {
return "";
}
text = filter_message(text);
return `(${count_words(text)} words, ${count_chars(text)} chars, ${completion_tokens ? completion_tokens : count_tokens(model, text, prompt_tokens)} tokens)`;
}
@@ -2626,12 +2661,12 @@ async function upload_files(fileInput) {
fileInput.value = "";
}
if (result.media) {
const media = [];
result.media.forEach((filename)=> {
const url = `/files/${bucket_id}/media/${filename}`;
image_storage[url] = {bucket_id: bucket_id, name: filename};
media.push({bucket_id: bucket_id, name: filename, url: url});
});
mediaSelect.classList.remove("hidden");
renderMediaSelect();
await handle_ask(false, media);
}
}

View File

@@ -340,8 +340,7 @@ class Backend_Api(Api):
@app.route('/files/<bucket_id>/media/<filename>', methods=['GET'])
def get_media(bucket_id, filename, dirname: str = None):
bucket_dir = get_bucket_dir(secure_filename(bucket_id), secure_filename(dirname))
media_dir = os.path.join(bucket_dir, "media")
media_dir = get_bucket_dir(dirname, bucket_id, "media")
try:
return send_from_directory(os.path.abspath(media_dir), filename)
except NotFound:
@@ -391,15 +390,14 @@ class Backend_Api(Api):
@self.app.route('/backend-api/v2/chat/<share_id>', methods=['GET'])
def get_chat(share_id: str) -> str:
share_id = secure_filename(share_id)
if self.chat_cache.get(share_id, 0) == request.headers.get("if-none-match", 0):
if self.chat_cache.get(share_id, 0) == int(request.headers.get("if-none-match", 0)):
return jsonify({"error": {"message": "Not modified"}}), 304
bucket_dir = get_bucket_dir(share_id)
file = os.path.join(bucket_dir, "chat.json")
file = get_bucket_dir(share_id, "chat.json")
if not os.path.isfile(file):
return jsonify({"error": {"message": "Not found"}}), 404
with open(file, 'r') as f:
chat_data = json.load(f)
if chat_data.get("updated", 0) == request.headers.get("if-none-match", 0):
if chat_data.get("updated", 0) == int(request.headers.get("if-none-match", 0)):
return jsonify({"error": {"message": "Not modified"}}), 304
self.chat_cache[share_id] = chat_data.get("updated", 0)
return jsonify(chat_data), 200

View File

@@ -103,7 +103,7 @@ def is_data_an_media(data, filename: str = None) -> str:
return is_accepted_format(data)
return is_data_uri_an_image(data)
def is_data_an_audio(data_uri: str, filename: str = None) -> str:
def is_data_an_audio(data_uri: str = None, filename: str = None) -> str:
if filename:
if filename.endswith(".wav"):
return "audio/wav"

View File

@@ -2,17 +2,18 @@ from __future__ import annotations
import os
import time
import uuid
import asyncio
import hashlib
import re
from typing import AsyncIterator
from urllib.parse import quote, unquote
from aiohttp import ClientSession, ClientError
from urllib.parse import urlparse
from ..typing import Optional, Cookies
from ..requests.aiohttp import get_connector, StreamResponse
from ..image import MEDIA_TYPE_MAP, EXTENSIONS_MAP
from ..tools.files import secure_filename
from ..providers.response import ImageResponse, AudioResponse, VideoResponse
from ..Provider.template import BackendApi
from . import is_accepted_format, extract_data_uri
@@ -23,13 +24,15 @@ images_dir = "./generated_images"
def get_media_extension(media: str) -> str:
"""Extract media file extension from URL or filename"""
match = re.search(r"\.(j?[a-z]{3})(?:\?|$)", media, re.IGNORECASE)
extension = match.group(1).lower() if match else ""
path = urlparse(media).path
extension = os.path.splitext(path)[1]
if not extension:
extension = os.path.splitext(media)[1]
if not extension:
return ""
if extension not in EXTENSIONS_MAP:
if extension[1:] not in EXTENSIONS_MAP:
raise ValueError(f"Unsupported media extension: {extension} in: {media}")
return f".{extension}"
return extension
def ensure_images_dir():
"""Create images directory if it doesn't exist"""
@@ -43,19 +46,6 @@ def get_source_url(image: str, default: str = None) -> str:
return decoded_url
return default
def secure_filename(filename: str) -> str:
if filename is None:
return None
# Keep letters, numbers, basic punctuation and all Unicode chars
filename = re.sub(
r'[^\w.,_-]+',
'_',
unquote(filename).strip(),
flags=re.UNICODE
)
filename = filename[:100].strip(".,_-")
return filename
def is_valid_media_type(content_type: str) -> bool:
return content_type in MEDIA_TYPE_MAP or content_type.startswith("audio/") or content_type.startswith("video/")

View File

@@ -72,9 +72,8 @@ async def to_async_iterator(iterator) -> AsyncIterator:
if hasattr(iterator, '__aiter__'):
async for item in iterator:
yield item
return
try:
elif asyncio.iscoroutine(iterator):
yield await iterator
else:
for item in iterator:
yield item
except TypeError:
yield await iterator

View File

@@ -6,6 +6,15 @@ import string
from ..typing import Messages, Cookies, AsyncIterator, Iterator
from .. import debug
def to_string(value) -> str:
if isinstance(value, str):
return value
elif isinstance(value, dict):
return value.get("text")
elif isinstance(value, list):
return "".join([to_string(v) for v in value if v.get("type") == "text"])
return str(value)
def format_prompt(messages: Messages, add_special_tokens: bool = False, do_continue: bool = False, include_system: bool = True) -> str:
"""
Format a series of messages into a single string, optionally adding special tokens.
@@ -18,11 +27,16 @@ def format_prompt(messages: Messages, add_special_tokens: bool = False, do_conti
str: A formatted string containing all messages.
"""
if not add_special_tokens and len(messages) <= 1:
return messages[0]["content"]
formatted = "\n".join([
f'{message["role"].capitalize()}: {message["content"]}'
return to_string(messages[0]["content"])
messages = [
(message["role"], to_string(message["content"]))
for message in messages
if include_system or message["role"] != "system"
if include_system or message.get("role") != "system"
]
formatted = "\n".join([
f'{role.capitalize()}: {content}'
for role, content in messages
if content.strip()
])
if do_continue:
return formatted
@@ -34,11 +48,13 @@ def get_system_prompt(messages: Messages) -> str:
def get_last_user_message(messages: Messages) -> str:
user_messages = []
last_message = None if len(messages) == 0 else messages[-1]
messages = messages.copy()
while last_message is not None and messages:
last_message = messages.pop()
if last_message["role"] == "user":
if isinstance(last_message["content"], str):
user_messages.append(last_message["content"].strip())
content = to_string(last_message["content"]).strip()
if content:
user_messages.append(content)
else:
return "\n".join(user_messages[::-1])
return "\n".join(user_messages[::-1])

View File

@@ -1,22 +1,19 @@
from __future__ import annotations
import re
import os
import json
from pathlib import Path
from typing import Iterator, Optional, AsyncIterator
from aiohttp import ClientSession, ClientError, ClientResponse, ClientTimeout
import urllib.parse
from urllib.parse import unquote
import time
import zipfile
import asyncio
import hashlib
import base64
try:
from werkzeug.utils import secure_filename
except ImportError:
secure_filename = os.path.basename
try:
import PyPDF2
from PyPDF2.errors import PdfReadError
@@ -83,6 +80,19 @@ PLAIN_CACHE = "plain.cache"
DOWNLOADS_FILE = "downloads.json"
FILE_LIST = "files.txt"
def secure_filename(filename: str) -> str:
if filename is None:
return None
# Keep letters, numbers, basic punctuation and all Unicode chars
filename = re.sub(
r'[^\w.,_-]+',
'_',
unquote(filename).strip(),
flags=re.UNICODE
)
filename = filename[:100].strip(".,_-")
return filename
def supports_filename(filename: str):
if filename.endswith(".pdf"):
if has_pypdf2:
@@ -118,10 +128,8 @@ def supports_filename(filename: str):
return True
return False
def get_bucket_dir(bucket_id: str, dirname: str = None):
if dirname is None:
return os.path.join(get_cookies_dir(), "buckets", bucket_id)
return os.path.join(get_cookies_dir(), "buckets", dirname, bucket_id)
def get_bucket_dir(*parts):
return os.path.join(get_cookies_dir(), "buckets", *[secure_filename(part) for part in parts if part])
def get_buckets():
buckets_dir = os.path.join(get_cookies_dir(), "buckets")

82
g4f/tools/media.py Normal file
View File

@@ -0,0 +1,82 @@
from __future__ import annotations
import os
import base64
from typing import Iterator, Union
from pathlib import Path
from ..typing import Messages
from ..image import is_data_an_media, is_data_an_audio, to_input_audio, to_data_uri
from .files import get_bucket_dir
def render_media(bucket_id: str, name: str, url: str, as_path: bool = False, as_base64: bool = False) -> Union[str, Path]:
if (not as_base64 or url.startswith("/")):
file = Path(get_bucket_dir(bucket_id, "media", name))
if as_path:
return file
data = file.read_bytes()
data_base64 = base64.b64encode(data).decode()
if as_base64:
return data_base64
return f"data:{is_data_an_media(data, name)};base64,{data_base64}"
def render_part(part: dict) -> dict:
if "type" in part:
return part
filename = part.get("name")
if filename.endswith(".wav") or filename.endswith(".mp3"):
return {
"type": "input_audio",
"input_audio": {
"data": render_media(**part, as_base64=True),
"format": "wav" if filename.endswith(".wav") else "mp3"
}
}
return {
"type": "image_url",
"image_url": {"url": render_media(**part)}
}
def merge_media(media: list, messages: list) -> Iterator:
buffer = []
for message in messages:
if message.get("role") == "user":
content = message.get("content")
if isinstance(content, list):
for part in content:
if "type" not in part:
path = render_media(**part, as_path=True)
buffer.append((path, os.path.basename(path)))
elif part.get("type") == "image_url":
buffer.append((part.get("image_url"), None))
else:
buffer = []
yield from buffer
if media is not None:
yield from media
def render_messages(messages: Messages, media: list = None) -> Iterator:
for idx, message in enumerate(messages):
if isinstance(message["content"], list):
yield {
**message,
"content": [render_part(part) for part in message["content"] if part]
}
else:
if media is not None and idx == len(messages) - 1:
yield {
**message,
"content": [
{
"type": "input_audio",
"input_audio": to_input_audio(media_data, filename)
}
if is_data_an_audio(media_data, filename) else {
"type": "image_url",
"image_url": {"url": to_data_uri(media_data)}
}
for media_data, filename in media
] + ([{"type": "text", "text": message["content"]}] if isinstance(message["content"], str) else [])
}
else:
yield message