Use dynamic aspect_ratio for image and video size

This commit is contained in:
hlohaus
2025-03-23 14:47:26 +01:00
parent 8eaaf5db95
commit e76e5f7835
16 changed files with 276 additions and 187 deletions

View File

@@ -15,6 +15,7 @@ from ..errors import ModelNotFoundError
from ..requests.raise_for_status import raise_for_status
from ..requests.aiohttp import get_connector
from ..image.copy_images import save_response_media
from ..image import use_aspect_ratio
from ..providers.response import FinishReason, Usage, ToolCalls
from .. import debug
@@ -139,8 +140,9 @@ class PollinationsAI(AsyncGeneratorProvider, ProviderModelMixin):
cache: bool = False,
# Image generation parameters
prompt: str = None,
width: int = 1024,
height: int = 1024,
aspect_ratio: str = "1:1",
width: int = None,
height: int = None,
seed: Optional[int] = None,
nologo: bool = True,
private: bool = False,
@@ -177,6 +179,7 @@ class PollinationsAI(AsyncGeneratorProvider, ProviderModelMixin):
model=model,
prompt=format_image_prompt(messages, prompt),
proxy=proxy,
aspect_ratio=aspect_ratio,
width=width,
height=height,
seed=seed,
@@ -212,6 +215,7 @@ class PollinationsAI(AsyncGeneratorProvider, ProviderModelMixin):
model: str,
prompt: str,
proxy: str,
aspect_ratio: str,
width: int,
height: int,
seed: Optional[int],
@@ -223,17 +227,17 @@ class PollinationsAI(AsyncGeneratorProvider, ProviderModelMixin):
) -> AsyncResult:
if not cache and seed is None:
seed = random.randint(9999, 99999999)
params = {
"seed": str(seed) if seed is not None else None,
"width": str(width),
"height": str(height),
params = use_aspect_ratio({
"seed": seed,
"width": width,
"height": height,
"model": model,
"nologo": str(nologo).lower(),
"private": str(private).lower(),
"enhance": str(enhance).lower(),
"safe": str(safe).lower()
}
query = "&".join(f"{k}={quote_plus(v)}" for k, v in params.items() if v is not None)
}, aspect_ratio)
query = "&".join(f"{k}={quote_plus(str(v))}" for k, v in params.items() if v is not None)
url = f"{cls.image_api_endpoint}prompt/{quote_plus(prompt)}?{query}"
#yield ImagePreview(url, prompt)

View File

@@ -36,8 +36,9 @@ class PollinationsImage(PollinationsAI):
messages: Messages,
proxy: str = None,
prompt: str = None,
width: int = 1024,
height: int = 1024,
aspect_ratio: str = "1:1",
width: int = None,
height: int = None,
seed: Optional[int] = None,
cache: bool = False,
nologo: bool = True,
@@ -52,6 +53,7 @@ class PollinationsImage(PollinationsAI):
model=model,
prompt=format_image_prompt(messages, prompt),
proxy=proxy,
aspect_ratio=aspect_ratio,
width=width,
height=height,
seed=seed,

View File

@@ -11,6 +11,7 @@ from ...errors import ModelNotSupportedError, ResponseError
from ...requests import StreamSession, raise_for_status
from ...providers.response import FinishReason, ImageResponse
from ...image.copy_images import save_response_media
from ...image import use_aspect_ratio
from ..helper import format_image_prompt, get_last_user_message
from .models import default_model, default_image_model, model_aliases, text_models, image_models, vision_models
from ... import debug
@@ -78,8 +79,9 @@ class HuggingFaceInference(AsyncGeneratorProvider, ProviderModelMixin):
action: str = None,
extra_data: dict = {},
seed: int = None,
width: int = 1024,
height: int = 1024,
aspect_ratio: str = None,
width: int = None,
height: int = None,
**kwargs
) -> AsyncResult:
try:
@@ -99,14 +101,14 @@ class HuggingFaceInference(AsyncGeneratorProvider, ProviderModelMixin):
) as session:
try:
if model in provider_together_urls:
data = {
data = use_aspect_ratio({
"response_format": "url",
"prompt": format_image_prompt(messages, prompt),
"model": model,
"width": width,
"height": height,
**extra_data
}
}, aspect_ratio)
async with session.post(provider_together_urls[model], json=data) as response:
if response.status == 404:
raise ModelNotSupportedError(f"Model is not supported: {model}")

View File

@@ -1,5 +1,7 @@
from __future__ import annotations
import time
import asyncio
import random
import requests
@@ -8,12 +10,13 @@ from ...requests import StreamSession, raise_for_status
from ...errors import ModelNotSupportedError
from ...providers.helper import format_image_prompt
from ...providers.base_provider import AsyncGeneratorProvider, ProviderModelMixin
from ...providers.response import ProviderInfo, ImageResponse, VideoResponse
from ...providers.response import ProviderInfo, ImageResponse, VideoResponse, Reasoning
from ...image.copy_images import save_response_media
from ...image import use_aspect_ratio
from ... import debug
class HuggingFaceMedia(AsyncGeneratorProvider, ProviderModelMixin):
label = "HuggingFace (Image / Video Generation)"
label = "HuggingFace (Image/Video Generation)"
parent = "HuggingFace"
url = "https://huggingface.co"
working = True
@@ -79,6 +82,7 @@ class HuggingFaceMedia(AsyncGeneratorProvider, ProviderModelMixin):
prompt: str = None,
proxy: str = None,
timeout: int = 0,
aspect_ratio: str = "1:1",
**kwargs
):
provider_mapping = await cls.get_mapping(model, api_key)
@@ -91,85 +95,94 @@ class HuggingFaceMedia(AsyncGeneratorProvider, ProviderModelMixin):
if key in ["replicate", "together", "hf-inference"]
}
provider_mapping = {**new_mapping, **provider_mapping}
last_response = None
for provider_key, provider in provider_mapping.items():
yield ProviderInfo(**{**cls.get_dict(), "label": f"HuggingFace ({provider_key})", "url": f"{cls.url}/{model}"})
async def generate(extra_data: dict, prompt: str):
last_response = None
for provider_key, provider in provider_mapping.items():
provider_info = ProviderInfo(**{**cls.get_dict(), "label": f"HuggingFace ({provider_key})", "url": f"{cls.url}/{model}"})
api_base = f"https://router.huggingface.co/{provider_key}"
task = provider["task"]
provider_id = provider["providerId"]
if task not in cls.tasks:
raise ModelNotSupportedError(f"Model is not supported: {model} in: {cls.__name__} task: {task}")
api_base = f"https://router.huggingface.co/{provider_key}"
task = provider["task"]
provider_id = provider["providerId"]
if task not in cls.tasks:
raise ModelNotSupportedError(f"Model is not supported: {model} in: {cls.__name__} task: {task}")
prompt = format_image_prompt(messages, prompt)
if task == "text-to-video":
extra_data = {
"num_inference_steps": 20,
"video_size": "landscape_16_9",
**extra_data
}
else:
extra_data = {
"width": 1024,
"height": 1024,
**extra_data
}
if provider_key == "fal-ai":
url = f"{api_base}/{provider_id}"
data = {
"prompt": prompt,
"image_size": "square_hd",
**extra_data
}
elif provider_key == "replicate":
url = f"{api_base}/v1/models/{provider_id}/prediction"
data = {
"input": {
prompt = format_image_prompt(messages, prompt)
if task == "text-to-video":
extra_data = {
"num_inference_steps": 20,
"resolution": "480p",
"aspect_ratio": aspect_ratio,
**extra_data
}
else:
extra_data = use_aspect_ratio(extra_data, aspect_ratio)
if provider_key == "fal-ai":
url = f"{api_base}/{provider_id}"
data = {
"prompt": prompt,
"image_size": "square_hd",
**extra_data
}
}
elif provider_key in ("hf-inference", "hf-free"):
api_base = "https://api-inference.huggingface.co"
url = f"{api_base}/models/{provider_id}"
data = {
"inputs": prompt,
"parameters": {
"seed": random.randint(0, 2**32),
elif provider_key == "replicate":
url = f"{api_base}/v1/models/{provider_id}/prediction"
data = {
"input": {
"prompt": prompt,
**extra_data
}
}
elif provider_key in ("hf-inference", "hf-free"):
api_base = "https://api-inference.huggingface.co"
url = f"{api_base}/models/{provider_id}"
data = {
"inputs": prompt,
"parameters": {
"seed": random.randint(0, 2**32),
**extra_data
}
}
elif task == "text-to-image":
url = f"{api_base}/v1/images/generations"
data = {
"response_format": "url",
"prompt": prompt,
"model": provider_id,
**extra_data
}
}
elif task == "text-to-image":
url = f"{api_base}/v1/images/generations"
data = {
"response_format": "url",
"prompt": prompt,
"model": provider_id,
**extra_data
}
async with StreamSession(
headers=headers if provider_key == "free" or api_key is None else {**headers, "Authorization": f"Bearer {api_key}"},
proxy=proxy,
timeout=timeout
) as session:
async with session.post(url, json=data) as response:
if response.status in (400, 401, 402):
last_response = response
debug.error(f"{cls.__name__}: Error {response.status} with {provider_key} and {provider_id}")
continue
if response.status == 404:
raise ModelNotSupportedError(f"Model is not supported: {model}")
await raise_for_status(response)
async for chunk in save_response_media(response, prompt):
yield chunk
return
result = await response.json()
if "video" in result:
yield VideoResponse(result["video"]["url"], prompt)
elif task == "text-to-image":
yield ImageResponse([item["url"] for item in result.get("images", result.get("data"))], prompt)
elif task == "text-to-video":
yield VideoResponse(result["output"], prompt)
return
await raise_for_status(last_response)
async with StreamSession(
headers=headers if provider_key == "free" or api_key is None else {**headers, "Authorization": f"Bearer {api_key}"},
proxy=proxy,
timeout=timeout
) as session:
async with session.post(url, json=data) as response:
if response.status in (400, 401, 402):
last_response = response
debug.error(f"{cls.__name__}: Error {response.status} with {provider_key} and {provider_id}")
continue
if response.status == 404:
raise ModelNotSupportedError(f"Model is not supported: {model}")
await raise_for_status(response)
async for chunk in save_response_media(response, prompt):
return provider_info, chunk
result = await response.json()
if "video" in result:
return provider_info, VideoResponse(result["video"]["url"], prompt)
elif task == "text-to-image":
return provider_info, ImageResponse([item["url"] for item in result.get("images", result.get("data"))], prompt)
elif task == "text-to-video":
return provider_info, VideoResponse(result["output"], prompt)
await raise_for_status(last_response)
background_tasks = set()
started = time.time()
task = asyncio.create_task(generate(extra_data, prompt))
background_tasks.add(task)
task.add_done_callback(background_tasks.discard)
while background_tasks:
yield Reasoning(label="Generating", status=f"{time.time() - started:.2f}s")
await asyncio.sleep(0.2)
provider_info, media_response = await task
yield Reasoning(label="Finished", status=f"{time.time() - started:.2f}s")
yield provider_info
yield media_response

View File

@@ -6,6 +6,7 @@ import uuid
from ...typing import AsyncResult, Messages
from ...providers.response import ImageResponse, ImagePreview, JsonConversation, Reasoning
from ...requests import StreamSession
from ...image import use_aspect_ratio
from ...errors import ResponseError
from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin
from ..helper import format_image_prompt
@@ -56,8 +57,9 @@ class BlackForestLabs_Flux1Dev(AsyncGeneratorProvider, ProviderModelMixin):
messages: Messages,
prompt: str = None,
proxy: str = None,
width: int = 1024,
height: int = 1024,
aspect_ratio: str = "1:1",
width: int = None,
height: int = None,
guidance_scale: float = 3.5,
num_inference_steps: int = 28,
seed: int = 0,
@@ -69,7 +71,8 @@ class BlackForestLabs_Flux1Dev(AsyncGeneratorProvider, ProviderModelMixin):
) -> AsyncResult:
async with StreamSession(impersonate="chrome", proxy=proxy) as session:
prompt = format_image_prompt(messages, prompt)
data = [prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps]
data = use_aspect_ratio({"width": width, "height": height}, aspect_ratio)
data = [prompt, seed, randomize_seed, data.get("width"), data.get("height"), guidance_scale, num_inference_steps]
conversation = JsonConversation(zerogpu_token=api_key, zerogpu_uuid=zerogpu_uuid, session_hash=uuid.uuid4().hex)
if conversation.zerogpu_token is None:
conversation.zerogpu_uuid, conversation.zerogpu_token = await get_zerogpu_token(cls.space, session, conversation, cookies)

View File

@@ -37,8 +37,9 @@ class G4F(DeepseekAI_JanusPro7b):
messages: Messages,
proxy: str = None,
prompt: str = None,
width: int = 1024,
height: int = 1024,
aspect_ratio: str = "1:1",
width: int = None,
height: int = None,
seed: int = None,
cookies: dict = None,
api_key: str = None,
@@ -50,6 +51,7 @@ class G4F(DeepseekAI_JanusPro7b):
model, messages,
proxy=proxy,
prompt=prompt,
aspect_ratio=aspect_ratio,
width=width,
height=height,
seed=seed,

View File

@@ -5,6 +5,7 @@ from aiohttp import ClientSession
from ...typing import AsyncResult, Messages
from ...providers.response import ImageResponse, ImagePreview
from ...image import use_aspect_ratio
from ...errors import ResponseError
from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin
from ..helper import format_image_prompt
@@ -29,8 +30,9 @@ class StabilityAI_SD35Large(AsyncGeneratorProvider, ProviderModelMixin):
negative_prompt: str = None,
api_key: str = None,
proxy: str = None,
width: int = 1024,
height: int = 1024,
aspect_ratio: str = "1:1",
width: int = None,
height: int = None,
guidance_scale: float = 4.5,
num_inference_steps: int = 50,
seed: int = 0,
@@ -45,8 +47,9 @@ class StabilityAI_SD35Large(AsyncGeneratorProvider, ProviderModelMixin):
headers["Authorization"] = f"Bearer {api_key}"
async with ClientSession(headers=headers) as session:
prompt = format_image_prompt(messages, prompt)
data = use_aspect_ratio({"width": width, "height": height}, aspect_ratio)
data = {
"data": [prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps]
"data": [prompt, negative_prompt, seed, randomize_seed, data.get("width"), data.get("height"), guidance_scale, num_inference_steps]
}
async with session.post(f"{cls.url}{cls.api_endpoint}", json=data, proxy=proxy) as response:
response.raise_for_status()

View File

@@ -358,37 +358,48 @@
gradient.classList.add('hidden');
const url = "https://image.pollinations.ai/feed";
const eventSource = new EventSource(url);
const imageFeed = document.getElementById("image-feed");
const images = []
eventSource.onmessage = (event) => {
const data = JSON.parse(event.data);
if (data.nsfw || !data.nologo || data.width < 1024 || !data.imageURL || data.isChild) {
return;
}
const lower = data.prompt.toLowerCase();
const tags = ["nsfw", "timeline", "feet", "blood", "soap", "orally", "heel", "latex", "bathroom", "boobs", "charts", "gel", "logo", "infographic", "warts", " bra ", "prostitute", "curvy", "breasts", "written", "bodies", "naked", "classroom", "malone", "dirty", "shoes", "shower", "banner", "fat", "nipples", "couple", "sexual", "sandal", "supplier", "overlord", "succubus", "platinum", "cracy", "crazy", "hemale", "oprah", "lamic", "ropes", "cables", "wires", "dirty", "messy", "cluttered", "chaotic", "disorganized", "disorderly", "untidy", "unorganized", "unorderly", "unsystematic", "disarranged", "disarrayed", "disheveled", "disordered", "jumbled", "muddled", "scattered", "shambolic", "sloppy", "unkept", "unruly"];
for (i in tags) {
if (lower.indexOf(tags[i]) != -1) {
console.log("Skipping image with tag: " + tags[i]);
console.debug("Skipping image:", data.imageURL);
return;
let es = null;
function initES() {
if (es == null || es.readyState == EventSource.CLOSED) {
const eventSource = new EventSource(url);
eventSource.onmessage = (event) => {
const data = JSON.parse(event.data);
if (data.nsfw || !data.nologo || data.width < 512 || !data.imageURL || data.isChild || data.status != "end_generating") {
return;
}
const lower = data.prompt.toLowerCase();
const tags = ["nsfw", "timeline", "feet", "blood", "soap", "orally", "heel", "latex", "bathroom", "boobs", "charts", "gel", "logo", "infographic", "warts", " bra ", "prostitute", "curvy", "breasts", "written", "bodies", "naked", "classroom", "malone", "dirty", "shoes", "shower", "banner", "fat", "nipples", "couple", "sexual", "sandal", "supplier", "overlord", "succubus", "platinum", "cracy", "crazy", "hemale", "oprah", "lamic", "ropes", "cables", "wires", "dirty", "messy", "cluttered", "chaotic", "disorganized", "disorderly", "untidy", "unorganized", "unorderly", "unsystematic", "disarranged", "disarrayed", "disheveled", "disordered", "jumbled", "muddled", "scattered", "shambolic", "sloppy", "unkept", "unruly"];
for (i in tags) {
if (lower.indexOf(tags[i]) != -1) {
console.log("Skipping image with tag: " + tags[i]);
console.debug("Skipping image:", data.imageURL);
return;
}
}
const landscape = window.innerWidth > window.innerHeight;
if (landscape && data.width > data.height) {
images.push(data.imageURL);
} else if (!landscape && data.width < data.height) {
images.push(data.imageURL);
}
};
eventSource.onerror = (event) => {
eventSource.close();
}
imageFeed.onerror = () => {
imageFeed.classList.add("hidden");
}
}
images.push(data.imageURL);
};
eventSource.onerror = (event) => {
eventSource.close();
}
imageFeed.onerror = () => {
imageFeed.classList.add("hidden");
}
initES();
setInterval(() => {
if (images.length > 0) {
imageFeed.classList.remove("hidden");
imageFeed.src = images.shift();
} else if(imageFeed) {
imageFeed.remove();
initES();
}
}, 7000);

View File

@@ -84,8 +84,8 @@ if (window.markdownit) {
.replaceAll('&quot;&gt;&lt;/video&gt;', '"></video>')
.replaceAll('&lt;audio controls src=&quot;', '<audio controls src="')
.replaceAll('&quot;&gt;&lt;/audio&gt;', '"></audio>')
.replaceAll('&lt;iframe type=&quot;text/html&quot; src=&quot;', '<iframe type="text/html" frameborder="0" allow="fullscreen" height="390" width="640" src="')
.replaceAll('&quot;&gt;&lt;/iframe&gt;', `?enablejsapi=1&origin=${new URL(location.href).origin}"></iframe>`)
.replaceAll('&lt;iframe type=&quot;text/html&quot; src=&quot;', '<iframe type="text/html" frameborder="0" allow="fullscreen" height="224" width="400" src="')
.replaceAll('&quot;&gt;&lt;/iframe&gt;', `?enablejsapi=1"></iframe>`)
}
}
@@ -95,7 +95,7 @@ function render_reasoning(reasoning, final = false) {
</div>` : "";
return `<div class="reasoning_body">
<div class="reasoning_title">
<strong>Reasoning <i class="brain">🧠</i>:</strong> ${escapeHtml(reasoning.status)}
<strong>${reasoning.label ? reasoning.label :'Reasoning <i class="brain">🧠</i>'}:</strong> ${escapeHtml(reasoning.status)}
</div>
${inner_text}
</div>`;
@@ -893,6 +893,8 @@ async function add_message_chunk(message, message_id, provider, scroll, finish_m
message_storage[message_id] = "";
} else if (message.status) {
reasoning_storage[message_id].status = message.status;
} if (message.label) {
reasoning_storage[message_id].label = message.label;
} if (message.token) {
reasoning_storage[message_id].text += message.token;
}
@@ -999,7 +1001,7 @@ const ask_gpt = async (message_id, message_index = -1, regenerate = false, provi
content_map.inner.innerHTML = html;
highlight(content_map.inner);
}
if (message_storage[message_id]) {
if (message_storage[message_id] || reasoning_storage[message_id]) {
const message_provider = message_id in provider_storage ? provider_storage[message_id] : null;
let usage = {};
if (usage_storage[message_id]) {
@@ -1064,7 +1066,35 @@ const ask_gpt = async (message_id, message_index = -1, regenerate = false, provi
}
// Reload conversation if no error
if (!error_storage[message_id] && reloadConversation) {
await safe_load_conversation(window.conversation_id, scroll);
if(await safe_load_conversation(window.conversation_id, scroll)) {
const new_message = Array.from(document.querySelectorAll(".message")).at(-1);
const new_media = new_message?.querySelector("audio, video, iframe");
if (new_media) {
if (new_media.tagName == "IFRAME") {
if (YT) {
async function onPlayerReady(event) {
if (scroll) {
await lazy_scroll_to_bottom();
}
event.target.setVolume(100);
event.target.playVideo();
}
player = new YT.Player(new_media, {
events: {
'onReady': onPlayerReady,
}
});
}
} else {
setTimeout(async () => {
if (scroll) {
await lazy_scroll_to_bottom();
}
new_media.play();
}, 2000);
}
}
}
}
let cursorDiv = message_el.querySelector(".cursor");
if (cursorDiv) cursorDiv.parentNode.removeChild(cursorDiv);
@@ -1121,6 +1151,7 @@ const ask_gpt = async (message_id, message_index = -1, regenerate = false, provi
api_key: api_key,
api_base: api_base,
ignored: ignored,
aspect_ratio: window.innerHeight > window.innerWidth ? "9:16" : "16:9",
...extra_parameters
}, Object.values(image_storage), message_id, scroll, finish_message);
} catch (e) {
@@ -1494,7 +1525,7 @@ const load_conversation = async (conversation_id, scroll=true) => {
[...new Set(providers)].forEach(async (provider) => {
await load_provider_parameters(provider);
});
register_message_buttons();
await register_message_buttons();
highlight(message_box);
regenerate_button.classList.remove("regenerate-hidden");
@@ -1504,6 +1535,7 @@ const load_conversation = async (conversation_id, scroll=true) => {
setTimeout(() => {
message_box.scrollTop = message_box.scrollHeight;
}, 500);
return true;
}
};
@@ -1516,7 +1548,7 @@ async function safe_load_conversation(conversation_id, scroll=true) {
}
}
if (!is_running) {
load_conversation(conversation_id, scroll);
return await load_conversation(conversation_id, scroll);
}
}

View File

@@ -9,7 +9,8 @@ import shutil
import random
import datetime
import tempfile
from flask import Flask, Response, request, jsonify, render_template, send_from_directory
from flask import Flask, Response, redirect, request, jsonify, render_template, send_from_directory
from werkzeug.exceptions import NotFound
from typing import Generator
from pathlib import Path
from urllib.parse import quote_plus
@@ -24,7 +25,7 @@ from ...tools.run_tools import iter_run_tools
from ...errors import ProviderNotFoundError
from ...image import is_allowed_extension
from ...cookies import get_cookies_dir
from ...image.copy_images import secure_filename
from ...image.copy_images import secure_filename, get_source_url
from ... import ChatCompletion
from ... import models
from .api import Api
@@ -333,46 +334,18 @@ class Backend_Api(Api):
def get_media(bucket_id, filename, dirname: str = None):
bucket_dir = get_bucket_dir(secure_filename(bucket_id), secure_filename(dirname))
media_dir = os.path.join(bucket_dir, "media")
if os.path.exists(media_dir):
try:
return send_from_directory(os.path.abspath(media_dir), filename)
return "Not found", 404
except NotFound:
source_url = get_source_url(request.query_string.decode())
if source_url is not None:
return redirect(source_url)
raise
@app.route('/files/<dirname>/<bucket_id>/media/<filename>', methods=['GET'])
def get_media_sub(dirname, bucket_id, filename):
return get_media(bucket_id, filename, dirname)
@app.route('/backend-api/v2/files/<bucket_id>/<filename>', methods=['PUT'])
def upload_file(bucket_id, filename, dirname: str = None):
    """Store the raw PUT request body as a file inside a bucket.

    Returns 201 on success, 400 for unsupported file types or an empty
    payload, and 500 when writing the file fails.
    """
    bucket_dir = secure_filename(bucket_id if dirname is None else dirname)
    bucket_dir = get_bucket_dir(bucket_dir)
    filename = secure_filename(filename)
    bucket_path = Path(bucket_dir)
    if dirname is not None:
        # Nested layout: buckets/<dirname>/<bucket_id>/<filename>
        bucket_path = bucket_path / secure_filename(bucket_id)
    if not supports_filename(filename):
        return jsonify({"error": {"message": f"File type not allowed"}}), 400
    if not bucket_path.exists():
        bucket_path.mkdir(parents=True, exist_ok=True)
    try:
        file_path = bucket_path / filename
        file_data = request.get_data()
        if not file_data:
            return jsonify({"error": {"message": "No file data received"}}), 400
        with file_path.open('wb') as f:
            f.write(file_data)
        # Report the stored filename; the previous message contained a
        # garbled "(unknown)" placeholder instead of the filename.
        return jsonify({"message": f"File '{filename}' uploaded successfully to bucket '{bucket_id}'"}), 201
    except Exception as e:
        return jsonify({"error": {"message": f"Error uploading file: {str(e)}"}}), 500
@app.route('/backend-api/v2/files/<bucket_id>/<dirname>/<filename>', methods=['PUT'])
def upload_file_sub(bucket_id, filename, dirname):
return upload_file(bucket_id, filename, dirname)
@app.route('/backend-api/v2/upload_cookies', methods=['POST'])
def upload_cookies():
file = None

View File

@@ -16,7 +16,7 @@ except ImportError:
from ..typing import ImageType, Union, Image
from ..errors import MissingRequirementsError
ALLOWED_EXTENSIONS = {'png', 'jpg', 'jpeg', 'gif', 'webp', 'svg'}
ALLOWED_EXTENSIONS = {'png', 'jpg', 'jpeg', 'gif', 'webp', 'webm', 'svg', 'mp3', 'wav', 'mp4', 'flac', 'opus', 'ogg', 'mkv'}
EXTENSIONS_MAP: dict[str, str] = {
"image/png": "png",
@@ -259,6 +259,27 @@ def to_input_audio(audio: ImageType, filename: str = None) -> str:
}
raise ValueError("Invalid input audio")
def use_aspect_ratio(extra_data: dict, aspect_ratio: str) -> dict:
    """Fill in default "width"/"height" values for a known aspect ratio.

    Values already present in extra_data take precedence over the
    defaults; an unrecognized (or None) aspect_ratio leaves extra_data
    untouched.

    Args:
        extra_data: Request parameters that may already contain explicit
            "width"/"height" entries.
        aspect_ratio: One of "1:1", "16:9" or "9:16"; anything else is
            ignored.

    Returns:
        A dict with width/height defaults merged in (the original
        annotation said ``Image``, which was incorrect — a plain dict
        is returned).
    """
    # Default pixel sizes per supported aspect ratio.
    defaults = {
        "1:1": {"width": 1024, "height": 1024},
        "16:9": {"width": 800, "height": 512},
        "9:16": {"width": 512, "height": 800},
    }.get(aspect_ratio)
    if defaults is not None:
        # extra_data is unpacked last so caller-supplied values win.
        extra_data = {**defaults, **extra_data}
    return extra_data
class ImageDataResponse():
def __init__(
self,

View File

@@ -11,7 +11,7 @@ from aiohttp import ClientSession, ClientError
from ..typing import Optional, Cookies
from ..requests.aiohttp import get_connector, StreamResponse
from ..image import EXTENSIONS_MAP
from ..image import EXTENSIONS_MAP, ALLOWED_EXTENSIONS
from ..tools.files import get_bucket_dir
from ..providers.response import ImageResponse, AudioResponse, VideoResponse
from ..Provider.template import BackendApi
@@ -21,9 +21,9 @@ from .. import debug
# Directory for storing generated images
images_dir = "./generated_images"
def get_media_extension(image: str) -> str:
"""Extract image extension from URL or filename, default to .jpg"""
match = re.search(r"\.(jpe?g|png|webp|mp4|mp3|wav)[?$]", image, re.IGNORECASE)
def get_media_extension(media: str) -> str:
    """Extract the media file extension from a URL or filename."""
    # Accept a known extension immediately before a query string or
    # at the end of the input, case-insensitively.
    found = re.search(
        r"\.(jpe?g|png|gif|svg|webp|webm|mp4|mp3|wav|flac|opus|ogg|mkv)(?:\?|$)",
        media,
        re.IGNORECASE,
    )
    if found is None:
        return ""
    return "." + found.group(1).lower()
def ensure_images_dir():
@@ -51,8 +51,10 @@ def secure_filename(filename: str) -> str:
async def save_response_media(response: StreamResponse, prompt: str):
content_type = response.headers["content-type"]
if content_type in EXTENSIONS_MAP or content_type.startswith("audio/"):
if content_type in EXTENSIONS_MAP or content_type.startswith("audio/") or content_type.startswith("video/"):
extension = EXTENSIONS_MAP[content_type] if content_type in EXTENSIONS_MAP else content_type[6:].replace("mpeg", "mp3")
if extension not in ALLOWED_EXTENSIONS:
raise ValueError(f"Unsupported media type: {content_type}")
bucket_id = str(uuid.uuid4())
dirname = str(int(time.time()))
bucket_dir = get_bucket_dir(bucket_id, dirname)
@@ -135,11 +137,14 @@ async def copy_media(
if target is None and not os.path.splitext(target_path)[1]:
with open(target_path, "rb") as f:
file_header = f.read(12)
detected_type = is_accepted_format(file_header)
if detected_type:
new_ext = f".{detected_type.split('/')[-1]}"
os.rename(target_path, f"{target_path}{new_ext}")
target_path = f"{target_path}{new_ext}"
try:
detected_type = is_accepted_format(file_header)
if detected_type:
new_ext = f".{detected_type.split('/')[-1]}"
os.rename(target_path, f"{target_path}{new_ext}")
target_path = f"{target_path}{new_ext}"
except ValueError:
pass
# Build URL with safe encoding
url_filename = quote(os.path.basename(target_path))

View File

@@ -178,11 +178,13 @@ class Reasoning(ResponseType):
def __init__(
self,
token: Optional[str] = None,
label: Optional[str] = None,
status: Optional[str] = None,
is_thinking: Optional[str] = None
) -> None:
"""Initialize with token, status, and thinking state."""
self.token = token
self.label = label
self.status = status
self.is_thinking = is_thinking
@@ -203,6 +205,8 @@ class Reasoning(ResponseType):
def get_dict(self) -> Dict:
"""Return a dictionary representation of the reasoning."""
if self.label is not None:
return {"label": self.label, "status": self.status}
if self.is_thinking is None:
if self.status is None:
return {"token": self.token}
@@ -248,16 +252,22 @@ class YouTube(HiddenResponse):
for id in self.ids
]))
class Audio(ResponseType):
def __init__(self, data: bytes) -> None:
class AudioResponse(ResponseType):
    def __init__(self, data: Union[bytes, str]) -> None:
        """Initialize with raw audio bytes or an already-usable URI string."""
        self.data = data

    def to_uri(self) -> str:
        """Return the audio as a URI.

        String data is assumed to already be a URL / data URI and is
        returned unchanged; bytes are encoded into a base64 data URI.
        (The original had this docstring after the isinstance guard,
        where it was a dead string expression, not a docstring.)
        """
        if isinstance(self.data, str):
            return self.data
        data_base64 = base64.b64encode(self.data).decode()
        return f"data:audio/mpeg;base64,{data_base64}"

    def __str__(self) -> str:
        """Return the audio as an HTML <audio> element."""
        return f'<audio controls src="{self.to_uri()}"></audio>'
class BaseConversation(ResponseType):
def __str__(self) -> str:
"""Return an empty string by default."""
@@ -282,7 +292,7 @@ class RequestLogin(HiddenResponse):
"""Return formatted login link as a string."""
return format_link(self.login_url, f"[Login to {self.label}]") + "\n\n"
class ImageResponse(ResponseType):
class MediaResponse(ResponseType):
def __init__(
self,
images: Union[str, List[str]],
@@ -294,10 +304,6 @@ class ImageResponse(ResponseType):
self.alt = alt
self.options = options
def __str__(self) -> str:
"""Return images as markdown."""
return format_images_markdown(self.images, self.alt, self.get("preview"))
def get(self, key: str) -> any:
"""Get an option value by key."""
return self.options.get(key)
@@ -306,6 +312,16 @@ class ImageResponse(ResponseType):
"""Return images as a list."""
return [self.images] if isinstance(self.images, str) else self.images
class ImageResponse(MediaResponse):
    def __str__(self) -> str:
        """Render the contained images as markdown."""
        preview = self.get("preview")
        return format_images_markdown(self.images, self.alt, preview)
class VideoResponse(MediaResponse):
    def __str__(self) -> str:
        """Render each video URL as an HTML <video> element, one per line."""
        tags = (f'<video controls src="{video}"></video>' for video in self.get_list())
        return "\n".join(tags)
class ImagePreview(ImageResponse):
def __str__(self) -> str:
"""Return an empty string for preview."""

View File

@@ -4,7 +4,7 @@ import random
from ..typing import Type, List, CreateResult, Messages, AsyncResult
from .types import BaseProvider, BaseRetryProvider, ProviderType
from .response import ImageResponse, ProviderInfo
from .response import MediaResponse, ProviderInfo
from .. import debug
from ..errors import RetryProviderError, RetryNoProviderError
@@ -59,7 +59,7 @@ class IterListProvider(BaseRetryProvider):
for chunk in response:
if chunk:
yield chunk
if isinstance(chunk, (str, ImageResponse)):
if isinstance(chunk, (str, MediaResponse)):
started = True
if started:
return
@@ -94,7 +94,7 @@ class IterListProvider(BaseRetryProvider):
async for chunk in response:
if chunk:
yield chunk
if isinstance(chunk, (str, ImageResponse)):
if isinstance(chunk, (str, MediaResponse)):
started = True
elif response:
response = await response

View File

@@ -64,6 +64,7 @@ class StreamResponse:
inner: Response = await self.inner
self.inner = inner
self.url = inner.url
self.method = inner.request.method
self.request = inner.request
self.status: int = inner.status_code
self.reason: str = inner.reason

View File

@@ -118,9 +118,10 @@ def supports_filename(filename: str):
return True
return False
def get_bucket_dir(bucket_id: str):
bucket_dir = os.path.join(get_cookies_dir(), "buckets", bucket_id)
return bucket_dir
def get_bucket_dir(bucket_id: str, dirname: str = None):
    """Build the on-disk path of a bucket below the cookies directory.

    When dirname is given, the bucket is nested one level deeper:
    buckets/<dirname>/<bucket_id>.
    """
    parts = [bucket_id] if dirname is None else [dirname, bucket_id]
    return os.path.join(get_cookies_dir(), "buckets", *parts)
def get_buckets():
buckets_dir = os.path.join(get_cookies_dir(), "buckets")