mirror of
https://github.com/xtekky/gpt4free.git
synced 2025-10-06 00:36:57 +08:00
Add example for video generation
Add support for images in messages
This commit is contained in:
@@ -21,6 +21,7 @@ The G4F AsyncClient API is designed to be compatible with the OpenAI API, making
|
||||
- [Using a Vision Model](#using-a-vision-model)
|
||||
- **[Transcribing Audio with Chat Completions](#transcribing-audio-with-chat-completions)** *(New Section)*
|
||||
- [Image Generation](#image-generation)
|
||||
- **[Video Generation](#video-generation)** *(New Section)*
|
||||
- [Advanced Usage](#advanced-usage)
|
||||
- [Conversation Memory](#conversation-memory)
|
||||
- [Search Tool Support](#search-tool-support)
|
||||
@@ -327,6 +328,46 @@ asyncio.run(main())
|
||||
|
||||
---
|
||||
|
||||
### Video Generation
|
||||
|
||||
The G4F `AsyncClient` can also **generate videos** through providers such as `HuggingFaceMedia`. You can list the video models a provider offers and then generate a video from a text prompt.
|
||||
|
||||
**Example: Generate a video using a prompt**
|
||||
|
||||
```python
|
||||
import asyncio
|
||||
from g4f.client import AsyncClient
|
||||
from g4f.Provider import HuggingFaceMedia
|
||||
|
||||
async def main():
|
||||
client = AsyncClient(
|
||||
provider=HuggingFaceMedia,
|
||||
api_key="hf_***" # Your API key here
|
||||
)
|
||||
|
||||
# Get available video models
|
||||
video_models = client.models.get_video()
|
||||
print("Available Video Models:", video_models)
|
||||
|
||||
# Generate video
|
||||
result = await client.media.generate(
|
||||
model=video_models[0],
|
||||
prompt="G4F AI technology is the best in the world.",
|
||||
response_format="url"
|
||||
)
|
||||
|
||||
print("Generated Video URL:", result.data[0].url)
|
||||
|
||||
asyncio.run(main())
|
||||
```
|
||||
|
||||
#### Explanation
|
||||
- **Client Initialization**: An `AsyncClient` is initialized using the `HuggingFaceMedia` provider with an API key.
|
||||
- **Model Discovery**: `client.models.get_video()` fetches a list of supported video models.
|
||||
- **Video Generation**: A prompt is submitted to generate a video using `await client.media.generate(...)`.
|
||||
- **Output**: The result includes a URL to the generated video, accessed via `result.data[0].url`.
|
||||
|
||||
> Make sure the provider you select supports media generation and that your API key has the appropriate permissions.
|
||||
|
||||
## Advanced Usage
|
||||
|
||||
|
@@ -16,9 +16,9 @@ from ..requests.raise_for_status import raise_for_status
|
||||
from .base_provider import AsyncGeneratorProvider, ProviderModelMixin
|
||||
from ..image import to_data_uri
|
||||
from ..cookies import get_cookies_dir
|
||||
from .helper import format_prompt, format_image_prompt
|
||||
from .helper import format_image_prompt
|
||||
from ..providers.response import JsonConversation, ImageResponse
|
||||
from ..errors import ModelNotSupportedError
|
||||
from ..tools.media import merge_media
|
||||
from .. import debug
|
||||
|
||||
class Conversation(JsonConversation):
|
||||
@@ -488,7 +488,7 @@ class Blackbox(AsyncGeneratorProvider, ProviderModelMixin):
|
||||
"filePath": f"/{image_name}",
|
||||
"contents": to_data_uri(image)
|
||||
}
|
||||
for image, image_name in media
|
||||
for image, image_name in merge_media(media, messages)
|
||||
],
|
||||
"fileText": "",
|
||||
"title": ""
|
||||
|
@@ -24,8 +24,9 @@ from .openai.har_file import get_headers, get_har_files
|
||||
from ..typing import CreateResult, Messages, MediaListType
|
||||
from ..errors import MissingRequirementsError, NoValidHarFileError, MissingAuthError
|
||||
from ..requests.raise_for_status import raise_for_status
|
||||
from ..providers.response import BaseConversation, JsonConversation, RequestLogin, Parameters, ImageResponse
|
||||
from ..providers.response import BaseConversation, JsonConversation, RequestLogin, ImageResponse
|
||||
from ..providers.asyncio import get_running_loop
|
||||
from ..tools.media import merge_media
|
||||
from ..requests import get_nodriver
|
||||
from ..image import to_bytes, is_accepted_format
|
||||
from .helper import get_last_user_message
|
||||
@@ -142,17 +143,18 @@ class Copilot(AbstractProvider, ProviderModelMixin):
|
||||
debug.log(f"Copilot: Use conversation: {conversation_id}")
|
||||
|
||||
uploaded_images = []
|
||||
if media is not None:
|
||||
for image, _ in media:
|
||||
data = to_bytes(image)
|
||||
media, _ = [(None, None), *merge_media(media, messages)].pop()
|
||||
if media:
|
||||
if not isinstance(media, str):
|
||||
data = to_bytes(media)
|
||||
response = session.post(
|
||||
"https://copilot.microsoft.com/c/api/attachments",
|
||||
headers={"content-type": is_accepted_format(data)},
|
||||
data=data
|
||||
)
|
||||
raise_for_status(response)
|
||||
uploaded_images.append({"type":"image", "url": response.json().get("url")})
|
||||
break
|
||||
media = response.json().get("url")
|
||||
uploaded_images.append({"type":"image", "url": media})
|
||||
|
||||
wss = session.ws_connect(cls.websocket_url)
|
||||
# if clarity_token is not None:
|
||||
|
@@ -11,13 +11,14 @@ from aiohttp import ClientSession
|
||||
from .helper import filter_none, format_image_prompt
|
||||
from .base_provider import AsyncGeneratorProvider, ProviderModelMixin
|
||||
from ..typing import AsyncResult, Messages, MediaListType
|
||||
from ..image import to_data_uri, is_data_an_audio, to_input_audio
|
||||
from ..image import is_data_an_audio
|
||||
from ..errors import ModelNotFoundError
|
||||
from ..requests.raise_for_status import raise_for_status
|
||||
from ..requests.aiohttp import get_connector
|
||||
from ..image.copy_images import save_response_media
|
||||
from ..image import use_aspect_ratio
|
||||
from ..providers.response import FinishReason, Usage, ToolCalls, ImageResponse
|
||||
from ..tools.media import render_messages
|
||||
from .. import debug
|
||||
|
||||
DEFAULT_HEADERS = {
|
||||
@@ -285,32 +286,15 @@ class PollinationsAI(AsyncGeneratorProvider, ProviderModelMixin):
|
||||
if response_format and response_format.get("type") == "json_object":
|
||||
json_mode = True
|
||||
|
||||
if media and messages:
|
||||
last_message = messages[-1].copy()
|
||||
image_content = [
|
||||
{
|
||||
"type": "input_audio",
|
||||
"input_audio": to_input_audio(media_data, filename)
|
||||
}
|
||||
if is_data_an_audio(media_data, filename) else {
|
||||
"type": "image_url",
|
||||
"image_url": {"url": to_data_uri(media_data)}
|
||||
}
|
||||
for media_data, filename in media
|
||||
]
|
||||
last_message["content"] = image_content + ([{"type": "text", "text": last_message["content"]}] if isinstance(last_message["content"], str) else image_content)
|
||||
messages[-1] = last_message
|
||||
|
||||
async with ClientSession(headers=DEFAULT_HEADERS, connector=get_connector(proxy=proxy)) as session:
|
||||
if model in cls.audio_models:
|
||||
#data["voice"] = random.choice(cls.audio_models[model])
|
||||
url = cls.text_api_endpoint
|
||||
stream = False
|
||||
else:
|
||||
url = cls.openai_endpoint
|
||||
extra_parameters = {param: kwargs[param] for param in extra_parameters if param in kwargs}
|
||||
data = filter_none(**{
|
||||
"messages": messages,
|
||||
"messages": list(render_messages(messages, media)),
|
||||
"model": model,
|
||||
"temperature": temperature,
|
||||
"presence_penalty": presence_penalty,
|
||||
@@ -324,7 +308,7 @@ class PollinationsAI(AsyncGeneratorProvider, ProviderModelMixin):
|
||||
})
|
||||
async with session.post(url, json=data) as response:
|
||||
await raise_for_status(response)
|
||||
async for chunk in save_response_media(response, messages[-1]["content"], [model]):
|
||||
async for chunk in save_response_media(response, format_image_prompt(messages), [model]):
|
||||
yield chunk
|
||||
return
|
||||
if response.headers["content-type"].startswith("text/plain"):
|
||||
|
@@ -24,6 +24,7 @@ from ...requests import get_args_from_nodriver, DEFAULT_HEADERS
|
||||
from ...requests.raise_for_status import raise_for_status
|
||||
from ...providers.response import JsonConversation, ImageResponse, Sources, TitleGeneration, Reasoning, RequestLogin
|
||||
from ...cookies import get_cookies
|
||||
from ...tools.media import merge_media
|
||||
from .models import default_model, default_vision_model, fallback_models, image_models, model_aliases
|
||||
from ... import debug
|
||||
|
||||
@@ -146,8 +147,7 @@ class HuggingChat(AsyncAuthedProvider, ProviderModelMixin):
|
||||
}
|
||||
data = CurlMime()
|
||||
data.addpart('data', data=json.dumps(settings, separators=(',', ':')))
|
||||
if media is not None:
|
||||
for image, filename in media:
|
||||
for image, filename in merge_media(media, messages):
|
||||
data.addpart(
|
||||
"files",
|
||||
filename=f"base64;{filename}",
|
||||
|
@@ -142,20 +142,30 @@ class HuggingFaceMedia(AsyncGeneratorProvider, ProviderModelMixin):
|
||||
}
|
||||
else:
|
||||
extra_data = use_aspect_ratio(extra_data, "1:1" if aspect_ratio is None else aspect_ratio)
|
||||
if provider_key == "fal-ai":
|
||||
url = f"{api_base}/{provider_id}"
|
||||
data = {
|
||||
"prompt": prompt,
|
||||
"image_size": "square_hd",
|
||||
**extra_data
|
||||
}
|
||||
if provider_key == "fal-ai" and task == "text-to-image":
|
||||
if aspect_ratio is None or aspect_ratio == "1:1":
|
||||
image_size = "square_hd",
|
||||
elif aspect_ratio == "16:9":
|
||||
image_size = "landscape_hd",
|
||||
elif aspect_ratio == "9:16":
|
||||
image_size = "portrait_16_9"
|
||||
else:
|
||||
image_size = extra_data # width, height
|
||||
data = {
|
||||
"image_size": image_size,
|
||||
**data
|
||||
}
|
||||
elif provider_key == "novita":
|
||||
url = f"{api_base}/v3/hf/{provider_id}"
|
||||
elif provider_key == "replicate":
|
||||
url = f"{api_base}/v1/models/{provider_id}/predictions"
|
||||
data = {
|
||||
"input": {
|
||||
"prompt": prompt,
|
||||
**extra_data
|
||||
}
|
||||
"input": data
|
||||
}
|
||||
elif provider_key in ("hf-inference", "hf-free"):
|
||||
api_base = "https://api-inference.huggingface.co"
|
||||
@@ -171,9 +181,8 @@ class HuggingFaceMedia(AsyncGeneratorProvider, ProviderModelMixin):
|
||||
url = f"{api_base}/v1/images/generations"
|
||||
data = {
|
||||
"response_format": "url",
|
||||
"prompt": prompt,
|
||||
"model": provider_id,
|
||||
**extra_data
|
||||
**data
|
||||
}
|
||||
|
||||
async with StreamSession(
|
||||
@@ -193,7 +202,7 @@ class HuggingFaceMedia(AsyncGeneratorProvider, ProviderModelMixin):
|
||||
return provider_info, chunk
|
||||
result = await response.json()
|
||||
if "video" in result:
|
||||
return provider_info, VideoResponse(result["video"]["url"], prompt)
|
||||
return provider_info, VideoResponse(result.get("video").get("url", result.get("video").get("url")), prompt)#video_url
|
||||
elif task == "text-to-image":
|
||||
return provider_info, ImageResponse([item["url"] for item in result.get("images", result.get("data"))], prompt)
|
||||
elif task == "text-to-video":
|
||||
|
@@ -20,7 +20,7 @@ class HuggingFace(AsyncGeneratorProvider, ProviderModelMixin):
|
||||
supports_message_history = True
|
||||
|
||||
@classmethod
|
||||
def get_models(cls) -> list[str]:
|
||||
def get_models(cls, **kwargs) -> list[str]:
|
||||
if not cls.models:
|
||||
cls.models = HuggingFaceInference.get_models()
|
||||
cls.image_models = HuggingFaceInference.image_models
|
||||
|
@@ -27,6 +27,7 @@ from ...requests import get_nodriver
|
||||
from ...errors import MissingAuthError
|
||||
from ...image import to_bytes
|
||||
from ...cookies import get_cookies_dir
|
||||
from ...tools.media import merge_media
|
||||
from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin
|
||||
from ..helper import format_prompt, get_cookies, get_last_user_message
|
||||
from ... import debug
|
||||
@@ -186,7 +187,7 @@ class Gemini(AsyncGeneratorProvider, ProviderModelMixin):
|
||||
cls.start_auto_refresh()
|
||||
)
|
||||
|
||||
uploads = None if media is None else await cls.upload_images(base_connector, media)
|
||||
uploads = await cls.upload_images(base_connector, merge_media(media, messages))
|
||||
async with ClientSession(
|
||||
cookies=cls._cookies,
|
||||
headers=REQUEST_HEADERS,
|
||||
|
@@ -25,7 +25,8 @@ from ...requests import get_nodriver
|
||||
from ...image import ImageRequest, to_image, to_bytes, is_accepted_format
|
||||
from ...errors import MissingAuthError, NoValidHarFileError
|
||||
from ...providers.response import JsonConversation, FinishReason, SynthesizeData, AuthResult, ImageResponse
|
||||
from ...providers.response import Sources, TitleGeneration, RequestLogin, Parameters, Reasoning
|
||||
from ...providers.response import Sources, TitleGeneration, RequestLogin, Reasoning
|
||||
from ...tools.media import merge_media
|
||||
from ..helper import format_cookies, get_last_user_message
|
||||
from ..openai.models import default_model, default_image_model, models, image_models, text_models
|
||||
from ..openai.har_file import get_request_config
|
||||
@@ -187,8 +188,6 @@ class OpenaiChat(AsyncAuthedProvider, ProviderModelMixin):
|
||||
await raise_for_status(response, "Get download url failed")
|
||||
image_data["download_url"] = (await response.json())["download_url"]
|
||||
return ImageRequest(image_data)
|
||||
if not media:
|
||||
return
|
||||
return [await upload_image(image, image_name) for image, image_name in media]
|
||||
|
||||
@classmethod
|
||||
@@ -330,7 +329,7 @@ class OpenaiChat(AsyncAuthedProvider, ProviderModelMixin):
|
||||
cls._update_request_args(auth_result, session)
|
||||
await raise_for_status(response)
|
||||
try:
|
||||
image_requests = None if media is None else await cls.upload_images(session, auth_result, media)
|
||||
image_requests = await cls.upload_images(session, auth_result, merge_media(media, messages))
|
||||
except Exception as e:
|
||||
debug.error("OpenaiChat: Upload image failed")
|
||||
debug.error(e)
|
||||
|
@@ -7,8 +7,8 @@ from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin, RaiseErr
|
||||
from ...typing import Union, AsyncResult, Messages, MediaListType
|
||||
from ...requests import StreamSession, raise_for_status
|
||||
from ...providers.response import FinishReason, ToolCalls, Usage, ImageResponse
|
||||
from ...tools.media import render_messages
|
||||
from ...errors import MissingAuthError, ResponseError
|
||||
from ...image import to_data_uri, is_data_an_audio, to_input_audio
|
||||
from ... import debug
|
||||
|
||||
class OpenaiTemplate(AsyncGeneratorProvider, ProviderModelMixin, RaiseErrorMixin):
|
||||
@@ -97,27 +97,9 @@ class OpenaiTemplate(AsyncGeneratorProvider, ProviderModelMixin, RaiseErrorMixin
|
||||
yield ImageResponse([image["url"] for image in data["data"]], prompt)
|
||||
return
|
||||
|
||||
if media is not None and messages:
|
||||
if not model and hasattr(cls, "default_vision_model"):
|
||||
model = cls.default_vision_model
|
||||
last_message = messages[-1].copy()
|
||||
image_content = [
|
||||
{
|
||||
"type": "input_audio",
|
||||
"input_audio": to_input_audio(media_data, filename)
|
||||
}
|
||||
if is_data_an_audio(media_data, filename) else {
|
||||
"type": "image_url",
|
||||
"image_url": {"url": to_data_uri(media_data)}
|
||||
}
|
||||
for media_data, filename in media
|
||||
]
|
||||
last_message["content"] = image_content + ([{"type": "text", "text": last_message["content"]}] if isinstance(last_message["content"], str) else image_content)
|
||||
|
||||
messages[-1] = last_message
|
||||
extra_parameters = {key: kwargs[key] for key in extra_parameters if key in kwargs}
|
||||
data = filter_none(
|
||||
messages=messages,
|
||||
messages=list(render_messages(messages, media)),
|
||||
model=model,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
|
@@ -19,7 +19,7 @@ from ..providers.asyncio import to_sync_generator
|
||||
from ..Provider.needs_auth import BingCreateImages, OpenaiAccount
|
||||
from ..tools.run_tools import async_iter_run_tools, iter_run_tools
|
||||
from .stubs import ChatCompletion, ChatCompletionChunk, Image, ImagesResponse, UsageModel, ToolCallModel
|
||||
from .image_models import MediaModels
|
||||
from .models import ClientModels
|
||||
from .types import IterResponse, ImageProvider, Client as BaseClient
|
||||
from .service import get_model_and_provider, convert_to_provider
|
||||
from .helper import find_stop, filter_json, filter_none, safe_aclose
|
||||
@@ -269,7 +269,7 @@ class Client(BaseClient):
|
||||
self.chat: Chat = Chat(self, provider)
|
||||
if image_provider is None:
|
||||
image_provider = provider
|
||||
self.models: MediaModels = MediaModels(self, image_provider)
|
||||
self.models: ClientModels = ClientModels(self, provider, image_provider)
|
||||
self.images: Images = Images(self, image_provider)
|
||||
self.media: Images = self.images
|
||||
|
||||
@@ -558,7 +558,7 @@ class AsyncClient(BaseClient):
|
||||
self.chat: AsyncChat = AsyncChat(self, provider)
|
||||
if image_provider is None:
|
||||
image_provider = provider
|
||||
self.models: MediaModels = MediaModels(self, image_provider)
|
||||
self.models: ClientModels = ClientModels(self, provider, image_provider)
|
||||
self.images: AsyncImages = AsyncImages(self, image_provider)
|
||||
self.media: AsyncImages = self.images
|
||||
|
||||
|
@@ -1,43 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from ..models import ModelUtils, ImageModel
|
||||
from ..Provider import ProviderUtils
|
||||
from ..providers.types import ProviderType
|
||||
|
||||
class MediaModels():
    """Model lookup helper bound to a client and an optional media provider.

    The provider is queried through its ``get_models`` callable and its
    optional ``image_models`` / ``video_models`` attributes.
    """

    def __init__(self, client, provider: ProviderType = None):
        """Store the owning client and the provider used for model listings."""
        self.client = client
        self.provider = provider

    def _list_models(self, api_key: str = None, **kwargs) -> list[str]:
        """Call ``provider.get_models``; return ``[]`` when no provider is set.

        ``api_key`` falls back to ``client.api_key`` and is only forwarded
        when a key is actually available.
        """
        if self.provider is None:
            return []
        if api_key is None:
            api_key = self.client.api_key
        if api_key is not None:
            kwargs["api_key"] = api_key
        return self.provider.get_models(**kwargs)

    def get(self, name, default=None) -> ProviderType:
        """Resolve ``name`` to a provider.

        A known model name yields that model's ``best_provider``; a known
        provider name yields the provider itself; otherwise ``default``.
        """
        if name in ModelUtils.convert:
            return ModelUtils.convert[name].best_provider
        if name in ProviderUtils.convert:
            return ProviderUtils.convert[name]
        return default

    def get_all(self, api_key: str = None, **kwargs) -> list[str]:
        """Return every model reported by the provider (``[]`` without one)."""
        return self._list_models(api_key, **kwargs)

    def get_image(self, **kwargs) -> list[str]:
        """Return the image models.

        Without a provider, this is every id in ``ModelUtils.convert`` whose
        model is an ``ImageModel``. With a provider, its ``image_models``
        attribute is returned, or ``[]`` when the attribute is missing.
        """
        if self.provider is None:
            return [
                model_id
                for model_id, model in ModelUtils.convert.items()
                if isinstance(model, ImageModel)
            ]
        # Result is discarded on purpose; presumably the call lets the
        # provider fill its model lists first - TODO confirm.
        self.get_all(**kwargs)
        return getattr(self.provider, "image_models", [])

    def get_video(self, **kwargs) -> list[str]:
        """Return the provider's ``video_models`` (``[]`` if unavailable)."""
        if self.provider is None:
            return []
        # Same discarded refresh call as in get_image.
        self.get_all(**kwargs)
        return getattr(self.provider, "video_models", [])
|
62
g4f/client/models.py
Normal file
62
g4f/client/models.py
Normal file
@@ -0,0 +1,62 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from ..models import ModelUtils, ImageModel, VisionModel
|
||||
from ..Provider import ProviderUtils
|
||||
from ..providers.types import ProviderType
|
||||
|
||||
class ClientModels():
    """Model lookup helper bound to a client, a chat provider and a media provider.

    ``provider`` backs ``get_all`` / ``get_vision``; ``media_provider`` backs
    ``get_media`` / ``get_image`` / ``get_video``. Both are queried through
    their ``get_models`` callable and optional ``*_models`` attributes.
    """

    def __init__(self, client, provider: ProviderType = None, media_provider: ProviderType = None):
        """Store the owning client and the two providers used for listings."""
        self.client = client
        self.provider = provider
        self.media_provider = media_provider

    def _list_models(self, provider: ProviderType, api_key: str = None, **kwargs) -> list[str]:
        """Call ``provider.get_models``; return ``[]`` when ``provider`` is None.

        Shared by ``get_all`` and ``get_media`` so the api-key fallback lives
        in one place: ``api_key`` defaults to ``client.api_key`` and is only
        forwarded when a key is actually available.
        """
        if provider is None:
            return []
        if api_key is None:
            api_key = self.client.api_key
        if api_key is not None:
            kwargs["api_key"] = api_key
        return provider.get_models(**kwargs)

    def get(self, name, default=None) -> ProviderType:
        """Resolve ``name`` to a provider.

        A known model name yields that model's ``best_provider``; a known
        provider name yields the provider itself; otherwise ``default``.
        """
        if name in ModelUtils.convert:
            return ModelUtils.convert[name].best_provider
        if name in ProviderUtils.convert:
            return ProviderUtils.convert[name]
        return default

    def get_all(self, api_key: str = None, **kwargs) -> list[str]:
        """Return every model reported by ``provider`` (``[]`` without one)."""
        return self._list_models(self.provider, api_key, **kwargs)

    def get_vision(self, **kwargs) -> list[str]:
        """Return the vision models.

        Without a provider, this is every id in ``ModelUtils.convert`` whose
        model is a ``VisionModel``. With a provider, its ``vision_models``
        attribute is returned, or ``[]`` when the attribute is missing.
        """
        if self.provider is None:
            return [
                model_id
                for model_id, model in ModelUtils.convert.items()
                if isinstance(model, VisionModel)
            ]
        # Result is discarded on purpose; presumably the call lets the
        # provider fill its model lists first - TODO confirm.
        self.get_all(**kwargs)
        return getattr(self.provider, "vision_models", [])

    def get_media(self, api_key: str = None, **kwargs) -> list[str]:
        """Return every model reported by ``media_provider`` (``[]`` without one)."""
        return self._list_models(self.media_provider, api_key, **kwargs)

    def get_image(self, **kwargs) -> list[str]:
        """Return the image models.

        Without a media provider, this is every id in ``ModelUtils.convert``
        whose model is an ``ImageModel``. With one, its ``image_models``
        attribute is returned, or ``[]`` when the attribute is missing.
        """
        if self.media_provider is None:
            return [
                model_id
                for model_id, model in ModelUtils.convert.items()
                if isinstance(model, ImageModel)
            ]
        # Same discarded refresh call as in get_vision.
        self.get_media(**kwargs)
        return getattr(self.media_provider, "image_models", [])

    def get_video(self, **kwargs) -> list[str]:
        """Return the media provider's ``video_models`` (``[]`` if unavailable)."""
        if self.media_provider is None:
            return []
        # Same discarded refresh call as in get_vision.
        self.get_media(**kwargs)
        return getattr(self.media_provider, "video_models", [])
|
@@ -89,7 +89,7 @@
|
||||
</head>
|
||||
<body>
|
||||
<img id="image-feed" class="hidden" alt="Image Feed">
|
||||
<video id="video-feed" class="hidden" alt="Video Feed" src="/search/video" autoplay></video>
|
||||
<video id="video-feed" class="hidden" alt="Video Feed" src="/search/video+g4f" autoplay></video>
|
||||
|
||||
<!-- Gradient Background Circle -->
|
||||
<div class="gradient"></div>
|
||||
@@ -105,6 +105,7 @@
|
||||
let skipImage = 0;
|
||||
let errorVideo = 0;
|
||||
let errorImage = 0;
|
||||
let skipRefresh = 0;
|
||||
videoFeed.onloadeddata = () => {
|
||||
videoFeed.classList.remove("hidden");
|
||||
gradient.classList.add("hidden");
|
||||
@@ -116,15 +117,15 @@
|
||||
gradient.classList.remove("hidden");
|
||||
return;
|
||||
}
|
||||
videoFeed.src = "/search/video?skip=" + skipVideo;
|
||||
videoFeed.src = "/search/video+g4f?skip=" + skipVideo;
|
||||
skipVideo++;
|
||||
};
|
||||
videoFeed.onended = () => {
|
||||
videoFeed.src = "/search/video?skip=" + skipVideo;
|
||||
videoFeed.src = "/search/video+g4f?skip=" + skipVideo;
|
||||
skipVideo++;
|
||||
};
|
||||
videoFeed.onclick = () => {
|
||||
videoFeed.src = "/search/video?skip=" + skipVideo;
|
||||
videoFeed.src = "/search/video+g4f?skip=" + skipVideo;
|
||||
skipVideo++;
|
||||
};
|
||||
function initES() {
|
||||
@@ -173,11 +174,15 @@
|
||||
skipImage++;
|
||||
return;
|
||||
}
|
||||
if (skipRefresh) {
|
||||
skipRefresh = 0;
|
||||
return;
|
||||
}
|
||||
if (images.length > 0) {
|
||||
imageFeed.classList.remove("hidden");
|
||||
imageFeed.src = images.shift();
|
||||
gradient.classList.add("hidden");
|
||||
} else if(imageFeed) {
|
||||
} else {
|
||||
initES();
|
||||
}
|
||||
}, 7000);
|
||||
@@ -192,6 +197,7 @@
|
||||
};
|
||||
imageFeed.onclick = () => {
|
||||
imageFeed.src = "/search/image?random=" + Math.random();
|
||||
skipRefresh = 1;
|
||||
};
|
||||
})();
|
||||
</script>
|
||||
|
@@ -81,16 +81,19 @@
|
||||
border: none;
|
||||
}
|
||||
|
||||
#background, #image-feed {
|
||||
#background {
|
||||
height: 100%;
|
||||
position: absolute;
|
||||
z-index: -1;
|
||||
object-fit: cover;
|
||||
object-position: center;
|
||||
width: 100%;
|
||||
background: black;
|
||||
}
|
||||
|
||||
.container * {
|
||||
z-index: 2;
|
||||
}
|
||||
|
||||
.description, form p a {
|
||||
font-size: 1.2rem;
|
||||
margin-bottom: 30px;
|
||||
@@ -176,9 +179,6 @@
|
||||
<body>
|
||||
<iframe id="background" src="/background"></iframe>
|
||||
|
||||
<!-- Gradient Background Circle -->
|
||||
<div class="gradient"></div>
|
||||
|
||||
<button class="slide-button">
|
||||
<i class="fa-solid fa-arrow-left"></i>
|
||||
</button>
|
||||
|
@@ -48,7 +48,6 @@
|
||||
align-items: center;
|
||||
height: 100%;
|
||||
text-align: center;
|
||||
z-index: 1;
|
||||
}
|
||||
|
||||
header {
|
||||
@@ -67,7 +66,11 @@
|
||||
#background {
|
||||
height: 100%;
|
||||
position: absolute;
|
||||
z-index: -1;
|
||||
top: 0;
|
||||
}
|
||||
|
||||
.container * {
|
||||
z-index: 2;
|
||||
}
|
||||
|
||||
.stream-widget {
|
||||
|
@@ -270,7 +270,7 @@
|
||||
<i class="fa-regular fa-image"></i>
|
||||
</label>
|
||||
<label class="file-label" for="file">
|
||||
<input type="file" id="file" name="file" accept=".txt, .html, .xml, .json, .js, .har, .sh, .py, .php, .css, .yaml, .sql, .log, .csv, .twig, .md, .pdf, .docx, .odt, .epub, .xlsx, .zip" required multiple/>
|
||||
<input type="file" id="file" name="file" accept="*/*" required multiple/>
|
||||
<i class="fa-solid fa-paperclip"></i>
|
||||
</label>
|
||||
<label class="micro-label" for="micro">
|
||||
|
@@ -72,9 +72,15 @@
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
let share_id = null;
|
||||
document.getElementById('generateQRCode').addEventListener('click', async () => {
|
||||
const share_id = generate_uuid();
|
||||
if (share_id) {
|
||||
const delete_url = `${share_url}/backend-api/v2/files/${encodeURI(share_id)}`;
|
||||
await fetch(delete_url, {
|
||||
method: 'DELETE'
|
||||
});
|
||||
}
|
||||
share_id = generate_uuid();
|
||||
|
||||
const url = `${share_url}/backend-api/v2/chat/${encodeURI(share_id)}`;
|
||||
const response = await fetch(url, {
|
||||
|
@@ -67,6 +67,17 @@ let markdown_render = (content) => escapeHtml(content);
|
||||
if (window.markdownit) {
|
||||
const markdown = window.markdownit();
|
||||
markdown_render = (content) => {
|
||||
if (Array.isArray(content)) {
|
||||
content = content.map((item) => {
|
||||
if (item.name.endsWith(".wav") || item.name.endsWith(".mp3")) {
|
||||
return `<audio controls src="${item.url}"></audio>`;
|
||||
}
|
||||
if (item.name.endsWith(".mp4") || item.name.endsWith(".webm")) {
|
||||
return `<video controls src="${item.url}"></video>`;
|
||||
}
|
||||
return `[]()`;
|
||||
}).join("\n");
|
||||
}
|
||||
return markdown.render(content
|
||||
.replaceAll(/<!-- generated images start -->|<!-- generated images end -->/gm, "")
|
||||
.replaceAll(/<img data-prompt="[^>]+">/gm, "")
|
||||
@@ -95,7 +106,7 @@ function render_reasoning(reasoning, final = false) {
|
||||
return `<div class="reasoning_body">
|
||||
<div class="reasoning_title">
|
||||
<strong>${reasoning.label ? reasoning.label :'Reasoning <i class="brain">🧠</i>'}: </strong>
|
||||
${reasoning.status ? escapeHtml(reasoning.status) : ' <i class="fas fa-spinner fa-spin"></i>'}
|
||||
${reasoning.status ? escapeHtml(reasoning.status) : '<i class="fas fa-spinner fa-spin"></i>'}
|
||||
</div>
|
||||
${inner_text}
|
||||
</div>`;
|
||||
@@ -106,12 +117,18 @@ function render_reasoning_text(reasoning) {
|
||||
}
|
||||
|
||||
/**
 * Strip UI-only decorations from a message.
 *
 * Array content is returned untouched. For strings, every
 * "generated images" section (start marker through end marker) is removed,
 * followed by a trailing " [aborted]" and a trailing " [error]" suffix.
 *
 * @param {string|Array} text - Message content.
 * @returns {string|Array} The cleaned string, or the original array.
 */
function filter_message(text) {
    if (Array.isArray(text)) {
        return text;
    }
    // Lazy quantifier: each start marker pairs with the NEAREST end marker,
    // so text lying between two separate image sections is preserved.
    return text.replaceAll(
        /<!-- generated images start -->[\s\S]+?<!-- generated images end -->/gm, ""
    ).replace(/ \[aborted\]$/g, "").replace(/ \[error\]$/g, "");
}
|
||||
|
||||
/**
 * Remove a trailing " [aborted]" and then a trailing " [error]" suffix
 * from string content. Array content is returned unchanged.
 *
 * @param {string|Array} text - Message content.
 * @returns {string|Array} The content without the status suffixes.
 */
function filter_message_content(text) {
    if (Array.isArray(text)) return text;
    const without_aborted = text.replace(/ \[aborted\]$/g, "");
    return without_aborted.replace(/ \[error\]$/g, "");
}
|
||||
|
||||
@@ -269,11 +286,12 @@ const register_message_buttons = async () => {
|
||||
return
|
||||
}
|
||||
el.dataset.click = true;
|
||||
const provider_forms = document.querySelector(".provider_forms");
|
||||
const provider_form = provider_forms.querySelector(`#${el.dataset.provider}-form`);
|
||||
const provider_link = el.querySelector("a");
|
||||
provider_link?.addEventListener("click", async (event) => {
|
||||
event.preventDefault();
|
||||
await load_provider_parameters(el.dataset.provider);
|
||||
const provider_forms = document.querySelector(".provider_forms");
|
||||
const provider_form = provider_forms.querySelector(`#${el.dataset.provider}-form`);
|
||||
if (provider_form) {
|
||||
provider_form.classList.remove("hidden");
|
||||
provider_forms.classList.remove("hidden");
|
||||
@@ -281,11 +299,6 @@ const register_message_buttons = async () => {
|
||||
}
|
||||
return false;
|
||||
});
|
||||
document.getElementById("close_provider_forms").addEventListener("click", async () => {
|
||||
provider_form.classList.add("hidden");
|
||||
provider_forms.classList.add("hidden");
|
||||
chat.classList.remove("hidden");
|
||||
});
|
||||
});
|
||||
|
||||
chatBody.querySelectorAll(".message .fa-xmark").forEach(async (el) => {
|
||||
@@ -479,23 +492,24 @@ const delete_conversations = async () => {
|
||||
await new_conversation();
|
||||
};
|
||||
|
||||
const handle_ask = async (do_ask_gpt = true) => {
|
||||
const handle_ask = async (do_ask_gpt = true, message = null) => {
|
||||
userInput.style.height = "82px";
|
||||
userInput.focus();
|
||||
await scroll_to_bottom();
|
||||
|
||||
let message = userInput.value.trim();
|
||||
if (message.length <= 0) {
|
||||
if (!message) {
|
||||
message = userInput.value.trim();
|
||||
if (!message) {
|
||||
return;
|
||||
}
|
||||
userInput.value = "";
|
||||
await count_input()
|
||||
await add_conversation(window.conversation_id);
|
||||
}
|
||||
|
||||
// Is message a url?
|
||||
const expression = /^https?:\/\/(www\.)?[-a-zA-Z0-9@:%._\+~#=]{1,256}\.[a-zA-Z0-9()]{1,6}\b([-a-zA-Z0-9()@:%_\+.~#?&//=]*)$/gi;
|
||||
const regex = new RegExp(expression);
|
||||
if (message.match(regex)) {
|
||||
if (!Array.isArray(message) && message.match(regex)) {
|
||||
paperclip.classList.add("blink");
|
||||
const blob = new Blob([JSON.stringify([{url: message}])], { type: 'application/json' });
|
||||
const file = new File([blob], 'downloads.json', { type: 'application/json' }); // Create File object
|
||||
@@ -509,6 +523,8 @@ const handle_ask = async (do_ask_gpt = true) => {
|
||||
connectToSSE(`/backend-api/v2/files/${bucket_id}`, false, bucket_id); //Retrieve and refine
|
||||
return;
|
||||
}
|
||||
|
||||
await add_conversation(window.conversation_id);
|
||||
let message_index = await add_message(window.conversation_id, "user", message);
|
||||
let message_id = get_message_id();
|
||||
|
||||
@@ -602,6 +618,12 @@ document.querySelector(".media-player .fa-x").addEventListener("click", ()=>{
|
||||
media_player.removeChild(audio);
|
||||
});
|
||||
|
||||
document.getElementById("close_provider_forms").addEventListener("click", async () => {
|
||||
const provider_forms = document.querySelector(".provider_forms");
|
||||
provider_forms.classList.add("hidden");
|
||||
chat.classList.remove("hidden");
|
||||
});
|
||||
|
||||
const prepare_messages = (messages, message_index = -1, do_continue = false, do_filter = true) => {
|
||||
messages = [ ...messages ]
|
||||
if (message_index != null) {
|
||||
@@ -930,7 +952,6 @@ async function add_message_chunk(message, message_id, provider, scroll, finish_m
|
||||
Object.entries(message.parameters).forEach(([key, value]) => {
|
||||
parameters_storage[provider][key] = value;
|
||||
});
|
||||
await load_provider_parameters(provider);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1332,6 +1353,9 @@ const new_conversation = async () => {
|
||||
};
|
||||
|
||||
function merge_messages(message1, message2) {
|
||||
if (Array.isArray(message2)) {
|
||||
return message2;
|
||||
}
|
||||
let newContent = message2;
|
||||
// Remove start tokens
|
||||
if (newContent.startsWith("```")) {
|
||||
@@ -1530,6 +1554,8 @@ const load_conversation = async (conversation, scroll=true) => {
|
||||
});
|
||||
|
||||
if (countTokensEnabled && window.GPTTokenizer_cl100k_base) {
|
||||
const has_media = messages.filter((item)=>Array.isArray(item.content)).length > 0;
|
||||
if (!has_media) {
|
||||
const filtered = prepare_messages(messages, null, true, false);
|
||||
if (filtered.length > 0) {
|
||||
last_model = last_model?.startsWith("gpt-3") ? "gpt-3.5-turbo" : "gpt-4"
|
||||
@@ -1539,11 +1565,9 @@ const load_conversation = async (conversation, scroll=true) => {
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
chatBody.innerHTML = elements.join("");
|
||||
[...new Set(providers)].forEach(async (provider) => {
|
||||
await load_provider_parameters(provider);
|
||||
});
|
||||
await register_message_buttons();
|
||||
highlight(chatBody);
|
||||
regenerate_button.classList.remove("regenerate-hidden");
|
||||
@@ -1674,7 +1698,7 @@ const add_message = async (
|
||||
}
|
||||
if (title) {
|
||||
conversation.title = title;
|
||||
} else if (!conversation.title) {
|
||||
} else if (!conversation.title && !Array.isArray(content)) {
|
||||
let new_value = content.trim();
|
||||
let new_lenght = new_value.indexOf("\n");
|
||||
new_lenght = new_lenght > 200 || new_lenght < 0 ? 200 : new_lenght;
|
||||
@@ -1728,8 +1752,10 @@ const add_message = async (
|
||||
return conversation.items.length - 1;
|
||||
};
|
||||
|
||||
const escapeHtml = (unsafe) => {
|
||||
return unsafe+"".replaceAll('&', '&').replaceAll('<', '<').replaceAll('>', '>').replaceAll('"', '"').replaceAll("'", ''');
|
||||
/**
 * Escape a value so it can be embedded in HTML markup.
 *
 * Produces the same escapes a DOM text node serializes to (`&`, `<`, `>`,
 * non-breaking space) and additionally escapes double and single quotes, so
 * the result is also safe inside quoted attribute values. It needs no DOM
 * access, so it also works outside a document context.
 *
 * @param {*} str Value to escape; it is converted to a string first.
 * @returns {string} The escaped text.
 */
function escapeHtml(str) {
    const entities = {
        '&': '&amp;',
        '<': '&lt;',
        '>': '&gt;',
        '"': '&quot;',
        "'": '&#39;',
        '\u00A0': '&nbsp;',
    };
    // Single pass: every special character is replaced exactly once, so the
    // ampersands introduced by the entities themselves are never re-escaped.
    return String(str).replace(/[&<>"'\u00A0]/g, (character) => entities[character]);
}
|
||||
|
||||
const toLocaleDateString = (date) => {
|
||||
@@ -1746,8 +1772,7 @@ const load_conversations = async () => {
|
||||
}
|
||||
}
|
||||
conversations.sort((a, b) => (b.updated||0)-(a.updated||0));
|
||||
|
||||
let html = [];
|
||||
await clear_conversations();
|
||||
conversations.forEach((conversation) => {
|
||||
// const length = conversation.items.map((item) => (
|
||||
// !item.content.toLowerCase().includes("hello") &&
|
||||
@@ -1759,8 +1784,10 @@ const load_conversations = async () => {
|
||||
// return;
|
||||
// }
|
||||
const shareIcon = (conversation.id == window.start_id && window.share_id) ? '<i class="fa-solid fa-qrcode"></i>': '';
|
||||
html.push(`
|
||||
<div class="convo" id="convo-${conversation.id}">
|
||||
let convo = document.createElement("div");
|
||||
convo.classList.add("convo");
|
||||
convo.id = `convo-${conversation.id}`;
|
||||
convo.innerHTML = `
|
||||
<div class="left" onclick="set_conversation('${conversation.id}')">
|
||||
<i class="fa-regular fa-comments"></i>
|
||||
<span class="datetime">${conversation.updated ? toLocaleDateString(conversation.updated) : ""}</span>
|
||||
@@ -1771,11 +1798,9 @@ const load_conversations = async () => {
|
||||
<i onclick="delete_conversation('${conversation.id}')" class="fa-solid fa-trash"></i>
|
||||
<i onclick="hide_option('${conversation.id}')" class="fa-regular fa-x"></i>
|
||||
</div>
|
||||
</div>
|
||||
`);
|
||||
`;
|
||||
box_conversations.appendChild(convo);
|
||||
});
|
||||
await clear_conversations();
|
||||
box_conversations.innerHTML += html.join("");
|
||||
};
|
||||
|
||||
const hide_input = document.querySelector(".chat-toolbar .hide-input");
|
||||
@@ -1800,6 +1825,13 @@ const uuid = () => {
|
||||
);
|
||||
};
|
||||
|
||||
/**
 * Generate a random string of lowercase letters and digits from the
 * cryptographically secure Web Crypto random source.
 *
 * Characters are chosen with rejection sampling so that every character of
 * the alphabet is equally likely (a plain `byte % alphabetSize` favours the
 * first few characters because 256 is not a multiple of 36).
 *
 * @param {number} [length=128] Number of characters to generate.
 * @returns {string} Random string of exactly `length` characters.
 * @throws {RangeError} If `length` is negative or not finite.
 */
function generateSecureRandomString(length = 128) {
    const chars = 'abcdefghijklmnopqrstuvwxyz0123456789';
    const count = Math.trunc(Number(length)) || 0;
    if (!Number.isFinite(count) || count < 0) {
        throw new RangeError(`Invalid length: ${length}`);
    }
    // Largest multiple of the alphabet size that fits into one byte; bytes at
    // or above it are discarded to avoid modulo bias.
    const limit = 256 - (256 % chars.length);
    // getRandomValues() refuses requests above 65536 bytes, so refill in chunks.
    const buffer = new Uint8Array(Math.min(count, 65536));
    const result = [];
    while (result.length < count) {
        crypto.getRandomValues(buffer);
        for (const byte of buffer) {
            if (byte >= limit) {
                continue;
            }
            result.push(chars[byte % chars.length]);
            if (result.length === count) {
                break;
            }
        }
    }
    return result.join('');
}
|
||||
|
||||
function get_message_id() {
|
||||
random_bytes = (Math.floor(Math.random() * 1338377565) + 2956589730).toString(
|
||||
2
|
||||
@@ -2003,6 +2035,9 @@ function count_chars(text) {
|
||||
}
|
||||
|
||||
/**
 * Build the "(N words, N chars, N tokens)" summary shown for a message.
 *
 * Multi-part (array) content is not summarized and yields an empty string.
 * When `completion_tokens` is truthy it is reported as the token count;
 * otherwise the count comes from `count_tokens`.
 */
function count_words_and_tokens(text, model, completion_tokens, prompt_tokens) {
    if (Array.isArray(text)) {
        return "";
    }
    const filtered = filter_message(text);
    const words = count_words(filtered);
    const chars = count_chars(filtered);
    const tokens = completion_tokens || count_tokens(model, filtered, prompt_tokens);
    return `(${words} words, ${chars} chars, ${tokens} tokens)`;
}
|
||||
@@ -2626,12 +2661,12 @@ async function upload_files(fileInput) {
|
||||
fileInput.value = "";
|
||||
}
|
||||
if (result.media) {
|
||||
const media = [];
|
||||
result.media.forEach((filename)=> {
|
||||
const url = `/files/${bucket_id}/media/${filename}`;
|
||||
image_storage[url] = {bucket_id: bucket_id, name: filename};
|
||||
media.push({bucket_id: bucket_id, name: filename, url: url});
|
||||
});
|
||||
mediaSelect.classList.remove("hidden");
|
||||
renderMediaSelect();
|
||||
await handle_ask(false, media);
|
||||
}
|
||||
}
|
||||
|
||||
|
@@ -340,8 +340,7 @@ class Backend_Api(Api):
|
||||
|
||||
@app.route('/files/<bucket_id>/media/<filename>', methods=['GET'])
|
||||
def get_media(bucket_id, filename, dirname: str = None):
|
||||
bucket_dir = get_bucket_dir(secure_filename(bucket_id), secure_filename(dirname))
|
||||
media_dir = os.path.join(bucket_dir, "media")
|
||||
media_dir = get_bucket_dir(dirname, bucket_id, "media")
|
||||
try:
|
||||
return send_from_directory(os.path.abspath(media_dir), filename)
|
||||
except NotFound:
|
||||
@@ -391,15 +390,14 @@ class Backend_Api(Api):
|
||||
@self.app.route('/backend-api/v2/chat/<share_id>', methods=['GET'])
|
||||
def get_chat(share_id: str) -> str:
|
||||
share_id = secure_filename(share_id)
|
||||
if self.chat_cache.get(share_id, 0) == request.headers.get("if-none-match", 0):
|
||||
if self.chat_cache.get(share_id, 0) == int(request.headers.get("if-none-match", 0)):
|
||||
return jsonify({"error": {"message": "Not modified"}}), 304
|
||||
bucket_dir = get_bucket_dir(share_id)
|
||||
file = os.path.join(bucket_dir, "chat.json")
|
||||
file = get_bucket_dir(share_id, "chat.json")
|
||||
if not os.path.isfile(file):
|
||||
return jsonify({"error": {"message": "Not found"}}), 404
|
||||
with open(file, 'r') as f:
|
||||
chat_data = json.load(f)
|
||||
if chat_data.get("updated", 0) == request.headers.get("if-none-match", 0):
|
||||
if chat_data.get("updated", 0) == int(request.headers.get("if-none-match", 0)):
|
||||
return jsonify({"error": {"message": "Not modified"}}), 304
|
||||
self.chat_cache[share_id] = chat_data.get("updated", 0)
|
||||
return jsonify(chat_data), 200
|
||||
|
@@ -103,7 +103,7 @@ def is_data_an_media(data, filename: str = None) -> str:
|
||||
return is_accepted_format(data)
|
||||
return is_data_uri_an_image(data)
|
||||
|
||||
def is_data_an_audio(data_uri: str, filename: str = None) -> str:
|
||||
def is_data_an_audio(data_uri: str = None, filename: str = None) -> str:
|
||||
if filename:
|
||||
if filename.endswith(".wav"):
|
||||
return "audio/wav"
|
||||
|
@@ -2,17 +2,18 @@ from __future__ import annotations
|
||||
|
||||
import os
|
||||
import time
|
||||
import uuid
|
||||
import asyncio
|
||||
import hashlib
|
||||
import re
|
||||
from typing import AsyncIterator
|
||||
from urllib.parse import quote, unquote
|
||||
from aiohttp import ClientSession, ClientError
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from ..typing import Optional, Cookies
|
||||
from ..requests.aiohttp import get_connector, StreamResponse
|
||||
from ..image import MEDIA_TYPE_MAP, EXTENSIONS_MAP
|
||||
from ..tools.files import secure_filename
|
||||
from ..providers.response import ImageResponse, AudioResponse, VideoResponse
|
||||
from ..Provider.template import BackendApi
|
||||
from . import is_accepted_format, extract_data_uri
|
||||
@@ -23,13 +24,15 @@ images_dir = "./generated_images"
|
||||
|
||||
def get_media_extension(media: str) -> str:
|
||||
"""Extract media file extension from URL or filename"""
|
||||
match = re.search(r"\.(j?[a-z]{3})(?:\?|$)", media, re.IGNORECASE)
|
||||
extension = match.group(1).lower() if match else ""
|
||||
path = urlparse(media).path
|
||||
extension = os.path.splitext(path)[1]
|
||||
if not extension:
|
||||
extension = os.path.splitext(media)[1]
|
||||
if not extension:
|
||||
return ""
|
||||
if extension not in EXTENSIONS_MAP:
|
||||
if extension[1:] not in EXTENSIONS_MAP:
|
||||
raise ValueError(f"Unsupported media extension: {extension} in: {media}")
|
||||
return f".{extension}"
|
||||
return extension
|
||||
|
||||
def ensure_images_dir():
|
||||
"""Create images directory if it doesn't exist"""
|
||||
@@ -43,19 +46,6 @@ def get_source_url(image: str, default: str = None) -> str:
|
||||
return decoded_url
|
||||
return default
|
||||
|
||||
def secure_filename(filename: str) -> str:
|
||||
if filename is None:
|
||||
return None
|
||||
# Keep letters, numbers, basic punctuation and all Unicode chars
|
||||
filename = re.sub(
|
||||
r'[^\w.,_-]+',
|
||||
'_',
|
||||
unquote(filename).strip(),
|
||||
flags=re.UNICODE
|
||||
)
|
||||
filename = filename[:100].strip(".,_-")
|
||||
return filename
|
||||
|
||||
def is_valid_media_type(content_type: str) -> bool:
|
||||
return content_type in MEDIA_TYPE_MAP or content_type.startswith("audio/") or content_type.startswith("video/")
|
||||
|
||||
|
@@ -72,9 +72,8 @@ async def to_async_iterator(iterator) -> AsyncIterator:
|
||||
if hasattr(iterator, '__aiter__'):
|
||||
async for item in iterator:
|
||||
yield item
|
||||
return
|
||||
try:
|
||||
elif asyncio.iscoroutine(iterator):
|
||||
yield await iterator
|
||||
else:
|
||||
for item in iterator:
|
||||
yield item
|
||||
except TypeError:
|
||||
yield await iterator
|
@@ -6,6 +6,15 @@ import string
|
||||
from ..typing import Messages, Cookies, AsyncIterator, Iterator
|
||||
from .. import debug
|
||||
|
||||
def to_string(value) -> str:
    """
    Flatten a message content value into plain text.

    Args:
        value: A string, a content part dict (its "text" entry is used),
            a list of content parts, or any other object.

    Returns:
        str: The text of the value. Dicts without a "text" entry yield an
        empty string. For lists, plain strings and parts whose "type" is
        "text" are concatenated in order and every other part is skipped.
        Any other object is converted with str().
    """
    if isinstance(value, str):
        return value
    if isinstance(value, dict):
        text = value.get("text")
        # Never return None: callers strip() or join() the result.
        return "" if text is None else to_string(text)
    if isinstance(value, list):
        return "".join(
            to_string(item)
            for item in value
            if isinstance(item, str)
            or (isinstance(item, dict) and item.get("type") == "text")
        )
    return str(value)
|
||||
|
||||
def format_prompt(messages: Messages, add_special_tokens: bool = False, do_continue: bool = False, include_system: bool = True) -> str:
|
||||
"""
|
||||
Format a series of messages into a single string, optionally adding special tokens.
|
||||
@@ -18,11 +27,16 @@ def format_prompt(messages: Messages, add_special_tokens: bool = False, do_conti
|
||||
str: A formatted string containing all messages.
|
||||
"""
|
||||
if not add_special_tokens and len(messages) <= 1:
|
||||
return messages[0]["content"]
|
||||
formatted = "\n".join([
|
||||
f'{message["role"].capitalize()}: {message["content"]}'
|
||||
return to_string(messages[0]["content"])
|
||||
messages = [
|
||||
(message["role"], to_string(message["content"]))
|
||||
for message in messages
|
||||
if include_system or message["role"] != "system"
|
||||
if include_system or message.get("role") != "system"
|
||||
]
|
||||
formatted = "\n".join([
|
||||
f'{role.capitalize()}: {content}'
|
||||
for role, content in messages
|
||||
if content.strip()
|
||||
])
|
||||
if do_continue:
|
||||
return formatted
|
||||
@@ -34,11 +48,13 @@ def get_system_prompt(messages: Messages) -> str:
|
||||
def get_last_user_message(messages: Messages) -> str:
|
||||
user_messages = []
|
||||
last_message = None if len(messages) == 0 else messages[-1]
|
||||
messages = messages.copy()
|
||||
while last_message is not None and messages:
|
||||
last_message = messages.pop()
|
||||
if last_message["role"] == "user":
|
||||
if isinstance(last_message["content"], str):
|
||||
user_messages.append(last_message["content"].strip())
|
||||
content = to_string(last_message["content"]).strip()
|
||||
if content:
|
||||
user_messages.append(content)
|
||||
else:
|
||||
return "\n".join(user_messages[::-1])
|
||||
return "\n".join(user_messages[::-1])
|
||||
|
@@ -1,22 +1,19 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
import os
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Iterator, Optional, AsyncIterator
|
||||
from aiohttp import ClientSession, ClientError, ClientResponse, ClientTimeout
|
||||
import urllib.parse
|
||||
from urllib.parse import unquote
|
||||
import time
|
||||
import zipfile
|
||||
import asyncio
|
||||
import hashlib
|
||||
import base64
|
||||
|
||||
try:
|
||||
from werkzeug.utils import secure_filename
|
||||
except ImportError:
|
||||
secure_filename = os.path.basename
|
||||
|
||||
try:
|
||||
import PyPDF2
|
||||
from PyPDF2.errors import PdfReadError
|
||||
@@ -83,6 +80,19 @@ PLAIN_CACHE = "plain.cache"
|
||||
DOWNLOADS_FILE = "downloads.json"
|
||||
FILE_LIST = "files.txt"
|
||||
|
||||
def secure_filename(filename: str) -> str:
    """
    Reduce a (possibly URL-encoded) name to a conservative file name.

    The name is URL-decoded and trimmed, every run of characters other than
    word characters (Unicode aware), ".", ",", "_" and "-" is replaced by a
    single underscore, the result is cut to 100 characters and leading and
    trailing ".", ",", "_" and "-" are removed.

    Args:
        filename: The raw name, or None.

    Returns:
        str: The sanitized name (may be empty), or None if None was passed.
    """
    if filename is None:
        return None
    decoded = unquote(filename).strip()
    # Keep letters, numbers, basic punctuation and all Unicode word chars
    sanitized = re.sub(r'[^\w.,_-]+', '_', decoded, flags=re.UNICODE)
    truncated = sanitized[:100]
    return truncated.strip(".,_-")
|
||||
|
||||
def supports_filename(filename: str):
|
||||
if filename.endswith(".pdf"):
|
||||
if has_pypdf2:
|
||||
@@ -118,10 +128,8 @@ def supports_filename(filename: str):
|
||||
return True
|
||||
return False
|
||||
|
||||
def get_bucket_dir(bucket_id: str, dirname: str = None):
|
||||
if dirname is None:
|
||||
return os.path.join(get_cookies_dir(), "buckets", bucket_id)
|
||||
return os.path.join(get_cookies_dir(), "buckets", dirname, bucket_id)
|
||||
def get_bucket_dir(*parts):
    """
    Build a path below the "buckets" folder of the cookies directory.

    Args:
        *parts: Path components; falsy components (None, "") are skipped and
            every remaining one is passed through secure_filename().

    Returns:
        str: The joined path.
    """
    base_dir = get_cookies_dir()
    safe_parts = []
    for part in parts:
        if part:
            safe_parts.append(secure_filename(part))
    return os.path.join(base_dir, "buckets", *safe_parts)
|
||||
|
||||
def get_buckets():
|
||||
buckets_dir = os.path.join(get_cookies_dir(), "buckets")
|
||||
|
82
g4f/tools/media.py
Normal file
82
g4f/tools/media.py
Normal file
@@ -0,0 +1,82 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import base64
|
||||
from typing import Iterator, Union
|
||||
from pathlib import Path
|
||||
|
||||
from ..typing import Messages
|
||||
from ..image import is_data_an_media, is_data_an_audio, to_input_audio, to_data_uri
|
||||
from .files import get_bucket_dir
|
||||
|
||||
def render_media(bucket_id: str, name: str, url: str, as_path: bool = False, as_base64: bool = False) -> Union[str, Path]:
|
||||
if (not as_base64 or url.startswith("/")):
|
||||
file = Path(get_bucket_dir(bucket_id, "media", name))
|
||||
if as_path:
|
||||
return file
|
||||
data = file.read_bytes()
|
||||
data_base64 = base64.b64encode(data).decode()
|
||||
if as_base64:
|
||||
return data_base64
|
||||
return f"data:{is_data_an_media(data, name)};base64,{data_base64}"
|
||||
|
||||
def render_part(part: dict) -> dict:
    """
    Convert a stored media reference into a typed message content part.

    Parts that already have a "type" key are returned unchanged. Otherwise
    the part is passed to render_media(): names ending in ".wav" or ".mp3"
    become an "input_audio" part carrying base64 data, everything else
    becomes an "image_url" part.

    Args:
        part: A content part; untyped parts must contain the keyword
            arguments render_media() expects.

    Returns:
        dict: The typed content part.
    """
    if "type" in part:
        return part
    filename = part.get("name")
    is_wav = filename.endswith(".wav")
    if is_wav or filename.endswith(".mp3"):
        audio = {
            "data": render_media(**part, as_base64=True),
            "format": "wav" if is_wav else "mp3",
        }
        return {"type": "input_audio", "input_audio": audio}
    image = {"url": render_media(**part)}
    return {"type": "image_url", "image_url": image}
|
||||
|
||||
def merge_media(media: list, messages: list) -> Iterator:
|
||||
buffer = []
|
||||
for message in messages:
|
||||
if message.get("role") == "user":
|
||||
content = message.get("content")
|
||||
if isinstance(content, list):
|
||||
for part in content:
|
||||
if "type" not in part:
|
||||
path = render_media(**part, as_path=True)
|
||||
buffer.append((path, os.path.basename(path)))
|
||||
elif part.get("type") == "image_url":
|
||||
buffer.append((part.get("image_url"), None))
|
||||
else:
|
||||
buffer = []
|
||||
yield from buffer
|
||||
if media is not None:
|
||||
yield from media
|
||||
|
||||
def render_messages(messages: Messages, media: list = None) -> Iterator:
    """
    Yield the messages with their media rendered into content parts.

    Messages whose content is a list get every non-empty part passed through
    render_part(). If media is given, the last message (when its content is
    not a list) is rebuilt as a list of audio/image parts, one per media
    entry, followed by its original text. All other messages are yielded
    unchanged.

    Args:
        messages: The conversation messages.
        media: Optional iterable of (data, filename) pairs to attach to the
            last message.

    Yields:
        dict: One rendered message per input message.
    """
    for position, message in enumerate(messages):
        content = message["content"]
        if isinstance(content, list):
            rendered = [render_part(part) for part in content if part]
            yield {**message, "content": rendered}
            continue
        if media is None or position != len(messages) - 1:
            yield message
            continue
        parts = []
        for media_data, filename in media:
            if is_data_an_audio(media_data, filename):
                parts.append({
                    "type": "input_audio",
                    "input_audio": to_input_audio(media_data, filename)
                })
            else:
                parts.append({
                    "type": "image_url",
                    "image_url": {"url": to_data_uri(media_data)}
                })
        # The original text goes after the media parts.
        if isinstance(content, str):
            parts.append({"type": "text", "text": content})
        yield {**message, "content": parts}
|
Reference in New Issue
Block a user