Add HuggingFaceMedia provider with Video Generation

Add Support for Video Response in UI
Improve Support for Audio Response in UI
Fix ModelNotSupportedError in HuggingSpace providers
hlohaus
2025-03-23 05:27:52 +01:00
parent 97f1964bb6
commit 8eaaf5db95
21 changed files with 356 additions and 128 deletions
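
For context on the new HuggingFaceMedia provider described above, a minimal usage sketch (hedged: the model id and token below are illustrative assumptions, not taken from this commit). The provider needs a Hugging Face API key and yields ImageResponse or VideoResponse chunks depending on the model's task:

import asyncio

from g4f.Provider import HuggingFaceMedia
from g4f.providers.response import ImageResponse, VideoResponse

async def main():
    # The model id is an assumption for illustration; any "warm" text-to-video model
    # returned by HuggingFaceMedia.get_models() should behave the same way.
    async for chunk in HuggingFaceMedia.create_async_generator(
        model="tencent/HunyuanVideo",
        messages=[{"role": "user", "content": "A rocket lifting off at sunrise"}],
        api_key="hf_xxx",  # needs_auth is True, so a valid Hugging Face token is required
    ):
        if isinstance(chunk, (ImageResponse, VideoResponse)):
            print(chunk.get_list())  # URLs of the generated media

asyncio.run(main())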

View File

@@ -14,7 +14,8 @@ from ..image import to_data_uri, is_data_an_audio, to_input_audio
from ..errors import ModelNotFoundError
from ..requests.raise_for_status import raise_for_status
from ..requests.aiohttp import get_connector
from ..providers.response import ImageResponse, ImagePreview, FinishReason, Usage, Audio, ToolCalls
from ..image.copy_images import save_response_media
from ..providers.response import FinishReason, Usage, ToolCalls
from .. import debug
DEFAULT_HEADERS = {
@@ -239,8 +240,9 @@ class PollinationsAI(AsyncGeneratorProvider, ProviderModelMixin):
async with ClientSession(headers=DEFAULT_HEADERS, connector=get_connector(proxy=proxy)) as session:
async with session.get(url, allow_redirects=True) as response:
await raise_for_status(response)
image_url = str(response.url)
yield ImageResponse(image_url, prompt)
async for chunk in save_response_media(response, prompt):
yield chunk
return
@classmethod
async def _generate_text(
@@ -305,10 +307,10 @@ class PollinationsAI(AsyncGeneratorProvider, ProviderModelMixin):
})
async with session.post(url, json=data) as response:
await raise_for_status(response)
if response.headers["content-type"] == "audio/mpeg":
yield Audio(await response.read())
async for chunk in save_response_media(response, messages[-1]["content"]):
yield chunk
return
elif response.headers["content-type"].startswith("text/plain"):
if response.headers["content-type"].startswith("text/plain"):
yield await response.text()
return
elif response.headers["content-type"].startswith("text/event-stream"):

View File

@@ -9,7 +9,7 @@ from .deprecated import *
from .needs_auth import *
from .not_working import *
from .local import *
from .hf import HuggingFace, HuggingChat, HuggingFaceAPI, HuggingFaceInference
from .hf import HuggingFace, HuggingChat, HuggingFaceAPI, HuggingFaceInference, HuggingFaceMedia
from .hf_space import *
from .mini_max import HailuoAI, MiniMax
from .template import OpenaiTemplate, BackendApi

View File

@@ -24,7 +24,7 @@ from ...requests import get_args_from_nodriver, DEFAULT_HEADERS
from ...requests.raise_for_status import raise_for_status
from ...providers.response import JsonConversation, ImageResponse, Sources, TitleGeneration, Reasoning, RequestLogin
from ...cookies import get_cookies
from .models import default_model, fallback_models, image_models, model_aliases, llama_models
from .models import default_model, default_vision_model, fallback_models, image_models, model_aliases
from ... import debug
class Conversation(JsonConversation):
@@ -41,6 +41,7 @@ class HuggingChat(AsyncAuthedProvider, ProviderModelMixin):
supports_stream = True
needs_auth = True
default_model = default_model
default_vision_model = default_vision_model
model_aliases = model_aliases
image_models = image_models
text_models = fallback_models
@@ -107,8 +108,8 @@ class HuggingChat(AsyncAuthedProvider, ProviderModelMixin):
) -> AsyncResult:
if not has_curl_cffi:
raise MissingRequirementsError('Install "curl_cffi" package | pip install -U curl_cffi')
if model == llama_models["name"]:
model = llama_models["text"] if media is None else llama_models["vision"]
if not model and media is not None:
model = cls.default_vision_model
model = cls.get_model(model)
session = Session(**auth_result.get_dict())

View File

@@ -6,27 +6,30 @@ from ...providers.types import Messages
from ...typing import MediaListType
from ...requests import StreamSession, raise_for_status
from ...errors import ModelNotSupportedError
from ...providers.helper import get_last_user_message
from ...providers.response import ProviderInfo
from ..template.OpenaiTemplate import OpenaiTemplate
from .models import model_aliases, vision_models, default_vision_model, llama_models, text_models
from .models import model_aliases, vision_models, default_llama_model, default_vision_model, text_models
from ... import debug
class HuggingFaceAPI(OpenaiTemplate):
label = "HuggingFace (Inference API)"
label = "HuggingFace (Text Generation)"
parent = "HuggingFace"
url = "https://api-inference.huggingface.com"
api_base = "https://api-inference.huggingface.co/v1"
working = True
needs_auth = True
default_model = default_vision_model
default_model = default_llama_model
default_vision_model = default_vision_model
vision_models = vision_models
model_aliases = model_aliases
fallback_models = text_models + vision_models
provider_mapping: dict[str, dict] = {}
provider_mapping: dict[str, dict] = {
"google/gemma-3-27b-it": {
"hf-inference/models/google/gemma-3-27b-it": {
"task": "conversational",
"providerId": "google/gemma-3-27b-it"}}}
@classmethod
def get_model(cls, model: str, **kwargs) -> str:
@@ -47,7 +50,9 @@ class HuggingFaceAPI(OpenaiTemplate):
if [
provider
for provider in model.get("inferenceProviderMapping")
if provider.get("task") == "conversational"]]
if provider.get("status") == "live" and provider.get("task") == "conversational"
]
] + list(cls.provider_mapping.keys())
else:
cls.models = cls.fallback_models
return cls.models
@@ -78,11 +83,12 @@ class HuggingFaceAPI(OpenaiTemplate):
media: MediaListType = None,
**kwargs
):
if model == llama_models["name"]:
model = llama_models["text"] if media is None else llama_models["vision"]
if model in cls.model_aliases:
model = cls.model_aliases[model]
if not model and media is not None:
model = cls.default_vision_model
model = cls.get_model(model)
provider_mapping = await cls.get_mapping(model, api_key)
if not provider_mapping:
raise ModelNotSupportedError(f"Model is not supported: {model} in: {cls.__name__}")
for provider_key in provider_mapping:
api_path = provider_key if provider_key == "novita" else f"{provider_key}/v1"
api_base = f"https://router.huggingface.co/{api_path}"

View File

@@ -10,6 +10,7 @@ from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin, format_p
from ...errors import ModelNotSupportedError, ResponseError
from ...requests import StreamSession, raise_for_status
from ...providers.response import FinishReason, ImageResponse
from ...image.copy_images import save_response_media
from ..helper import format_image_prompt, get_last_user_message
from .models import default_model, default_image_model, model_aliases, text_models, image_models, vision_models
from ... import debug
@@ -176,11 +177,9 @@ class HuggingFaceInference(AsyncGeneratorProvider, ProviderModelMixin):
debug.log(f"Special token: {is_special}")
yield FinishReason("stop" if is_special else "length")
else:
if response.headers["content-type"].startswith("image/"):
base64_data = base64.b64encode(b"".join([chunk async for chunk in response.iter_content()]))
url = f"data:{response.headers['content-type']};base64,{base64_data.decode()}"
yield ImageResponse(url, inputs)
else:
async for chunk in save_response_media(response, prompt):
yield chunk
return
yield (await response.json())[0]["generated_text"].strip()
def format_prompt_mistral(messages: Messages, do_continue: bool = False) -> str:

View File

@@ -0,0 +1,175 @@
from __future__ import annotations
import random
import requests
from ...providers.types import Messages
from ...requests import StreamSession, raise_for_status
from ...errors import ModelNotSupportedError
from ...providers.helper import format_image_prompt
from ...providers.base_provider import AsyncGeneratorProvider, ProviderModelMixin
from ...providers.response import ProviderInfo, ImageResponse, VideoResponse
from ...image.copy_images import save_response_media
from ... import debug
class HuggingFaceMedia(AsyncGeneratorProvider, ProviderModelMixin):
label = "HuggingFace (Image / Video Generation)"
parent = "HuggingFace"
url = "https://huggingface.co"
working = True
needs_auth = True
tasks = ["text-to-image", "text-to-video"]
provider_mapping: dict[str, dict] = {}
task_mapping: dict[str, str] = {}
@classmethod
def get_models(cls, **kwargs) -> list[str]:
if not cls.models:
url = "https://huggingface.co/api/models?inference=warm&expand[]=inferenceProviderMapping"
response = requests.get(url)
if response.ok:
models = response.json()
cls.models = [
model["id"]
for model in models
if [
provider
for provider in model.get("inferenceProviderMapping")
if provider.get("status") == "live" and provider.get("task") in cls.tasks
]
]
cls.task_mapping = {
model["id"]: [
provider.get("task")
for provider in model.get("inferenceProviderMapping")
].pop()
for model in models
}
else:
cls.models = []
return cls.models
@classmethod
async def get_mapping(cls, model: str, api_key: str = None):
if model in cls.provider_mapping:
return cls.provider_mapping[model]
headers = {
'Content-Type': 'application/json',
}
if api_key is not None:
headers["Authorization"] = f"Bearer {api_key}"
async with StreamSession(
timeout=30,
headers=headers,
) as session:
async with session.get(f"https://huggingface.co/api/models/{model}?expand[]=inferenceProviderMapping") as response:
await raise_for_status(response)
model_data = await response.json()
cls.provider_mapping[model] = {key: value for key, value in model_data.get("inferenceProviderMapping").items() if value["status"] == "live"}
return cls.provider_mapping[model]
@classmethod
async def create_async_generator(
cls,
model: str,
messages: Messages,
api_key: str = None,
extra_data: dict = {},
prompt: str = None,
proxy: str = None,
timeout: int = 0,
**kwargs
):
provider_mapping = await cls.get_mapping(model, api_key)
headers = {
'Accept-Encoding': 'gzip, deflate',
'Content-Type': 'application/json',
}
new_mapping = {
"hf-free" if key == "hf-inference" else key: value for key, value in provider_mapping.items()
if key in ["replicate", "together", "hf-inference"]
}
provider_mapping = {**new_mapping, **provider_mapping}
last_response = None
for provider_key, provider in provider_mapping.items():
yield ProviderInfo(**{**cls.get_dict(), "label": f"HuggingFace ({provider_key})", "url": f"{cls.url}/{model}"})
api_base = f"https://router.huggingface.co/{provider_key}"
task = provider["task"]
provider_id = provider["providerId"]
if task not in cls.tasks:
raise ModelNotSupportedError(f"Model is not supported: {model} in: {cls.__name__} task: {task}")
prompt = format_image_prompt(messages, prompt)
if task == "text-to-video":
extra_data = {
"num_inference_steps": 20,
"video_size": "landscape_16_9",
**extra_data
}
else:
extra_data = {
"width": 1024,
"height": 1024,
**extra_data
}
if provider_key == "fal-ai":
url = f"{api_base}/{provider_id}"
data = {
"prompt": prompt,
"image_size": "square_hd",
**extra_data
}
elif provider_key == "replicate":
url = f"{api_base}/v1/models/{provider_id}/prediction"
data = {
"input": {
"prompt": prompt,
**extra_data
}
}
elif provider_key in ("hf-inference", "hf-free"):
api_base = "https://api-inference.huggingface.co"
url = f"{api_base}/models/{provider_id}"
data = {
"inputs": prompt,
"parameters": {
"seed": random.randint(0, 2**32),
**extra_data
}
}
elif task == "text-to-image":
url = f"{api_base}/v1/images/generations"
data = {
"response_format": "url",
"prompt": prompt,
"model": provider_id,
**extra_data
}
async with StreamSession(
headers=headers if provider_key == "free" or api_key is None else {**headers, "Authorization": f"Bearer {api_key}"},
proxy=proxy,
timeout=timeout
) as session:
async with session.post(url, json=data) as response:
if response.status in (400, 401, 402):
last_response = response
debug.error(f"{cls.__name__}: Error {response.status} with {provider_key} and {provider_id}")
continue
if response.status == 404:
raise ModelNotSupportedError(f"Model is not supported: {model}")
await raise_for_status(response)
async for chunk in save_response_media(response, prompt):
yield chunk
return
result = await response.json()
if "video" in result:
yield VideoResponse(result["video"]["url"], prompt)
elif task == "text-to-image":
yield ImageResponse([item["url"] for item in result.get("images", result.get("data"))], prompt)
elif task == "text-to-video":
yield VideoResponse(result["output"], prompt)
return
await raise_for_status(last_response)
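
A usage note on the provider loop above (an assumption, not part of the commit): values passed in extra_data override the per-task defaults, so generation parameters can be tuned per request. In the usage sketch near the top of this commit, the call would become, e.g.:

HuggingFaceMedia.create_async_generator(
    model="tencent/HunyuanVideo",               # still an illustrative model id
    messages=[{"role": "user", "content": "Waves rolling onto a beach at dusk"}],
    api_key="hf_xxx",
    extra_data={"num_inference_steps": 30},     # overrides the text-to-video default of 20
    timeout=600,                                # passed to StreamSession; the default of 0 means no limit
)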

View File

@@ -9,6 +9,7 @@ from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin
from .HuggingChat import HuggingChat
from .HuggingFaceAPI import HuggingFaceAPI
from .HuggingFaceInference import HuggingFaceInference
from .HuggingFaceMedia import HuggingFaceMedia
from .models import model_aliases, vision_models, default_vision_model
from ... import debug
@@ -51,6 +52,12 @@ class HuggingFace(AsyncGeneratorProvider, ProviderModelMixin):
debug.error(f"{cls.__name__} {type(e).__name__}; {e}")
if not cls.image_models:
cls.get_models()
try:
async for chunk in HuggingFaceMedia.create_async_generator(model, messages, **kwargs):
yield chunk
return
except ModelNotSupportedError:
pass
if model in cls.image_models:
if "api_key" not in kwargs:
async for chunk in HuggingChat.create_async_generator(model, messages, **kwargs):

View File

@@ -47,9 +47,5 @@ extra_models = [
"NousResearch/Hermes-3-Llama-3.1-8B",
]
default_vision_model = "meta-llama/Llama-3.2-11B-Vision-Instruct"
default_llama_model = "meta-llama/Llama-3.3-70B-Instruct"
vision_models = [default_vision_model, "Qwen/Qwen2-VL-7B-Instruct"]
llama_models = {
"name": "llama-3",
"text": "meta-llama/Llama-3.3-70B-Instruct",
"vision": "meta-llama/Llama-3.2-11B-Vision-Instruct",
}

View File

@@ -67,7 +67,6 @@ class BlackForestLabs_Flux1Dev(AsyncGeneratorProvider, ProviderModelMixin):
zerogpu_uuid: str = "[object Object]",
**kwargs
) -> AsyncResult:
model = cls.get_model(model)
async with StreamSession(impersonate="chrome", proxy=proxy) as session:
prompt = format_image_prompt(messages, prompt)
data = [prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps]

View File

@@ -37,8 +37,6 @@ class BlackForestLabs_Flux1Schnell(AsyncGeneratorProvider, ProviderModelMixin):
randomize_seed: bool = True,
**kwargs
) -> AsyncResult:
model = cls.get_model(model)
width = max(32, width - (width % 8))
height = max(32, height - (height % 8))
prompt = format_image_prompt(messages, prompt)

View File

@@ -24,9 +24,14 @@ class CohereForAI_C4AI_Command(AsyncGeneratorProvider, ProviderModelMixin):
"command-r": "command-r",
"command-r7b": "command-r7b-12-2024",
}
models = list(model_aliases.keys())
@classmethod
def get_model(cls, model: str, **kwargs) -> str:
if model in cls.model_aliases.values():
return model
return super().get_model(model, **kwargs)
@classmethod
async def create_async_generator(
cls, model: str, messages: Messages,

View File

@@ -203,7 +203,21 @@ class OpenaiChat(AsyncAuthedProvider, ProviderModelMixin):
Returns:
A list of messages with the user input and the image, if any
"""
# Create a message object with the user role and the content
# merged_messages = []
# last_message = None
# for message in messages:
# current_message = last_message
# if current_message is not None:
# if current_message["role"] == message["role"]:
# current_message["content"] += "\n" + message["content"]
# else:
# merged_messages.append(current_message)
# last_message = message.copy()
# else:
# last_message = message.copy()
# if last_message is not None:
# merged_messages.append(last_message)
messages = [{
"id": str(uuid.uuid4()),
"author": {"role": message["role"]},

View File

@@ -39,7 +39,7 @@ from g4f.client import AsyncClient, ChatCompletion, ImagesResponse, convert_to_p
from g4f.providers.response import BaseConversation, JsonConversation
from g4f.client.helper import filter_none
from g4f.image import is_data_an_media
from g4f.image.copy_images import images_dir, copy_images, get_source_url
from g4f.image.copy_images import images_dir, copy_media, get_source_url
from g4f.errors import ProviderNotFoundError, ModelNotFoundError, MissingAuthError, NoValidHarFileError
from g4f.cookies import read_cookie_files, get_cookies_dir
from g4f.Provider import ProviderType, ProviderUtils, __providers__
@@ -594,10 +594,10 @@ class Api:
ssl = False
if source_url is not None:
try:
await copy_images(
await copy_media(
[source_url],
target=target, ssl=ssl)
debug.log(f"Image copied from {source_url}")
debug.log(f"File copied from {source_url}")
except Exception as e:
debug.error(f"Download failed: {source_url}\n{type(e).__name__}: {e}")
return RedirectResponse(url=source_url)

View File

@@ -9,7 +9,7 @@ import aiohttp
import base64
from typing import Union, AsyncIterator, Iterator, Awaitable, Optional
from ..image.copy_images import copy_images
from ..image.copy_images import copy_media
from ..typing import Messages, ImageType
from ..providers.types import ProviderType, BaseRetryProvider
from ..providers.response import *
@@ -532,7 +532,7 @@ class Images:
images = await asyncio.gather(*[get_b64_from_url(image) for image in response.get_list()])
else:
# Save locally for None (default) case
images = await copy_images(response.get_list(), response.get("cookies"), proxy)
images = await copy_media(response.get_list(), response.get("cookies"), proxy)
images = [Image.model_construct(url=f"/images/{os.path.basename(image)}", revised_prompt=response.alt) for image in images]
return ImagesResponse.model_construct(

View File

@@ -15,7 +15,7 @@
</head>
<body>
<h1>QR Scanner & QR Code Generator</h1>
<h1>QR Scanner & QR Code</h1>
<h2>QR Code Scanner</h2>
<video id="video"></video>
@@ -25,9 +25,8 @@
<button id="toggleFlash">Toggle Flash</button>
<p id="cam-status"></p>
<h2>Generate QR Code</h2>
<h2>QR Code</h2>
<div id="qrcode"></div>
<button id="generateQRCode">Generate QR Code</button>
<script type="module">

View File

@@ -80,7 +80,11 @@ if (window.markdownit) {
.replaceAll('<code>', '<code class="language-plaintext">')
.replaceAll('&lt;i class=&quot;', '<i class="')
.replaceAll('&quot;&gt;&lt;/i&gt;', '"></i>')
.replaceAll('&lt;iframe type=&quot;text/html&quot; src=&quot;', '<iframe type="text/html" frameborder="0" allow="fullscreen" src="')
.replaceAll('&lt;video controls src=&quot;', '<video controls width="400" src="')
.replaceAll('&quot;&gt;&lt;/video&gt;', '"></video>')
.replaceAll('&lt;audio controls src=&quot;', '<audio controls src="')
.replaceAll('&quot;&gt;&lt;/audio&gt;', '"></audio>')
.replaceAll('&lt;iframe type=&quot;text/html&quot; src=&quot;', '<iframe type="text/html" frameborder="0" allow="fullscreen" height="390" width="640" src="')
.replaceAll('&quot;&gt;&lt;/iframe&gt;', `?enablejsapi=1&origin=${new URL(location.href).origin}"></iframe>`)
}
}
@@ -229,7 +233,7 @@ function register_message_images() {
let seed = Math.floor(Date.now() / 1000);
newPath = `https://image.pollinations.ai/prompt/${newPath}?seed=${seed}&nologo=true`;
let downloadUrl = newPath;
if (document.getElementById("download_images")?.checked) {
if (document.getElementById("download_media")?.checked) {
downloadUrl = `/images/${filename}?url=${escapeHtml(newPath)}`;
}
const link = document.createElement("a");
@@ -862,12 +866,6 @@ async function add_message_chunk(message, message_id, provider, scroll, finish_m
content_map.inner.innerHTML = markdown_render(message.preview);
await register_message_images();
}
} else if (message.type == "audio") {
audio = new Audio(message.audio);
audio.controls = true;
content_map.inner.appendChild(audio);
audio.play();
reloadConversation = false;
} else if (message.type == "content") {
message_storage[message_id] += message.content;
update_message(content_map, message_id, null, scroll);
@@ -1090,7 +1088,7 @@ const ask_gpt = async (message_id, message_index = -1, regenerate = false, provi
} else {
api_key = get_api_key_by_provider(provider);
}
const download_images = document.getElementById("download_images")?.checked;
const download_media = document.getElementById("download_media")?.checked;
let api_base;
if (provider == "Custom") {
api_base = document.getElementById("api_base")?.value;
@@ -1119,7 +1117,7 @@ const ask_gpt = async (message_id, message_index = -1, regenerate = false, provi
provider: provider,
messages: messages,
action: action,
download_images: download_images,
download_media: download_media,
api_key: api_key,
api_base: api_base,
ignored: ignored,
@@ -2723,6 +2721,9 @@ async function load_provider_models(provider=null) {
option.value = model.model;
option.dataset.label = model.model;
option.text = `${model.model}${model.image ? " (Image Generation)" : ""}${model.vision ? " (Image Upload)" : ""}`;
if (model.task) {
option.text += ` (${model.task})`;
}
modelProvider.appendChild(option);
if (model.default) {
defaultIndex = i;

View File

@@ -8,7 +8,7 @@ from flask import send_from_directory
from inspect import signature
from ...errors import VersionNotFoundError
from ...image.copy_images import copy_images, ensure_images_dir, images_dir
from ...image.copy_images import copy_media, ensure_images_dir, images_dir
from ...tools.run_tools import iter_run_tools
from ...Provider import ProviderUtils, __providers__
from ...providers.base_provider import ProviderModelMixin
@@ -53,6 +53,7 @@ class Api:
"default": model == provider.default_model,
"vision": getattr(provider, "default_vision_model", None) == model or model in getattr(provider, "vision_models", []),
"image": False if provider.image_models is None else model in provider.image_models,
"task": None if not hasattr(provider, "task_mapping") else provider.task_mapping[model] if model in provider.task_mapping else None
}
for model in models
]
@@ -127,7 +128,7 @@ class Api:
**kwargs
}
def _create_response_stream(self, kwargs: dict, conversation_id: str, provider: str, download_images: bool = True) -> Iterator:
def _create_response_stream(self, kwargs: dict, conversation_id: str, provider: str, download_media: bool = True) -> Iterator:
def decorated_log(text: str, file = None):
debug.logs.append(text)
if debug.logging:
@@ -154,7 +155,7 @@ class Api:
if hasattr(provider_handler, "get_parameters"):
yield self._format_json("parameters", provider_handler.get_parameters(as_json=True))
try:
result = iter_run_tools(ChatCompletion.create, **{**kwargs, "model": model, "provider": provider_handler})
result = iter_run_tools(ChatCompletion.create, **{**kwargs, "model": model, "provider": provider_handler, "download_media": download_media})
for chunk in result:
if isinstance(chunk, ProviderInfo):
yield self.handle_provider(chunk, model)
@@ -182,13 +183,13 @@ class Api:
yield self._format_json("preview", chunk.to_string())
elif isinstance(chunk, ImagePreview):
yield self._format_json("preview", chunk.to_string(), images=chunk.images, alt=chunk.alt)
elif isinstance(chunk, ImageResponse):
images = chunk
if download_images or chunk.get("cookies"):
elif isinstance(chunk, (ImageResponse, VideoResponse)):
media = chunk
if download_media or chunk.get("cookies"):
chunk.alt = format_image_prompt(kwargs.get("messages"), chunk.alt)
images = asyncio.run(copy_images(chunk.get_list(), chunk.get("cookies"), chunk.get("headers"), proxy=proxy, alt=chunk.alt))
images = ImageResponse(images, chunk.alt)
yield self._format_json("content", str(images), images=chunk.get_list(), alt=chunk.alt)
media = asyncio.run(copy_media(chunk.get_list(), chunk.get("cookies"), chunk.get("headers"), proxy=proxy, alt=chunk.alt))
media = ImageResponse(media, chunk.alt) if isinstance(chunk, ImageResponse) else VideoResponse(media, chunk.alt)
yield self._format_json("content", str(media), images=chunk.get_list(), alt=chunk.alt)
elif isinstance(chunk, SynthesizeData):
yield self._format_json("synthesize", chunk.get_dict())
elif isinstance(chunk, TitleGeneration):
@@ -205,8 +206,8 @@ class Api:
yield self._format_json("reasoning", **chunk.get_dict())
elif isinstance(chunk, YouTube):
yield self._format_json("content", chunk.to_string())
elif isinstance(chunk, Audio):
yield self._format_json("audio", str(chunk))
elif isinstance(chunk, AudioResponse):
yield self._format_json("content", str(chunk))
elif isinstance(chunk, DebugResponse):
yield self._format_json("log", chunk.log)
elif isinstance(chunk, RawResponse):

View File

@@ -14,9 +14,8 @@ from typing import Generator
from pathlib import Path
from urllib.parse import quote_plus
from hashlib import sha256
from werkzeug.utils import secure_filename
from ...image import is_allowed_extension, to_image
from ...image import is_allowed_extension
from ...client.service import convert_to_provider
from ...providers.asyncio import to_sync_generator
from ...client.helper import filter_markdown
@@ -25,6 +24,7 @@ from ...tools.run_tools import iter_run_tools
from ...errors import ProviderNotFoundError
from ...image import is_allowed_extension
from ...cookies import get_cookies_dir
from ...image.copy_images import secure_filename
from ... import ChatCompletion
from ... import models
from .api import Api
@@ -130,16 +130,14 @@ class Backend_Api(Api):
if model != "default" and model in models.demo_models:
json_data["provider"] = random.choice(models.demo_models[model][1])
else:
if not model or model == "default":
json_data["model"] = models.demo_models["default"][0].name
json_data["provider"] = random.choice(models.demo_models["default"][1])
json_data["provider"] = models.HuggingFace
kwargs = self._prepare_conversation_kwargs(json_data)
return self.app.response_class(
self._create_response_stream(
kwargs,
json_data.get("conversation_id"),
json_data.get("provider"),
json_data.get("download_images", True),
json_data.get("download_media", True),
),
mimetype='text/event-stream'
)
@@ -331,21 +329,26 @@ class Backend_Api(Api):
[f.write(f"{filename}\n") for filename in filenames]
return {"bucket_id": bucket_id, "files": filenames, "media": media}
@app.route('/backend-api/v2/files/<bucket_id>/media/<filename>', methods=['GET'])
def get_media(bucket_id, filename):
bucket_id = secure_filename(bucket_id)
bucket_dir = get_bucket_dir(bucket_id)
@app.route('/files/<bucket_id>/media/<filename>', methods=['GET'])
def get_media(bucket_id, filename, dirname: str = None):
bucket_dir = get_bucket_dir(secure_filename(bucket_id), secure_filename(dirname))
media_dir = os.path.join(bucket_dir, "media")
if os.path.exists(media_dir):
return send_from_directory(os.path.abspath(media_dir), filename)
return "File not found", 404
return "Not found", 404
@app.route('/files/<dirname>/<bucket_id>/media/<filename>', methods=['GET'])
def get_media_sub(dirname, bucket_id, filename):
return get_media(bucket_id, filename, dirname)
@app.route('/backend-api/v2/files/<bucket_id>/<filename>', methods=['PUT'])
def upload_file(bucket_id, filename):
bucket_id = secure_filename(bucket_id)
bucket_dir = get_bucket_dir(bucket_id)
def upload_file(bucket_id, filename, dirname: str = None):
bucket_dir = secure_filename(bucket_id if dirname is None else dirname)
bucket_dir = get_bucket_dir(bucket_dir)
filename = secure_filename(filename)
bucket_path = Path(bucket_dir)
if dirname is not None:
bucket_path = bucket_path / secure_filename(bucket_id)
if not supports_filename(filename):
return jsonify({"error": {"message": f"File type not allowed"}}), 400
@@ -366,6 +369,10 @@ class Backend_Api(Api):
except Exception as e:
return jsonify({"error": {"message": f"Error uploading file: {str(e)}"}}), 500
@app.route('/backend-api/v2/files/<bucket_id>/<dirname>/<filename>', methods=['PUT'])
def upload_file_sub(bucket_id, filename, dirname):
return upload_file(bucket_id, filename, dirname)
@app.route('/backend-api/v2/upload_cookies', methods=['POST'])
def upload_cookies():
file = None

View File

@@ -10,7 +10,10 @@ from urllib.parse import quote, unquote
from aiohttp import ClientSession, ClientError
from ..typing import Optional, Cookies
from ..requests.aiohttp import get_connector
from ..requests.aiohttp import get_connector, StreamResponse
from ..image import EXTENSIONS_MAP
from ..tools.files import get_bucket_dir
from ..providers.response import ImageResponse, AudioResponse, VideoResponse
from ..Provider.template import BackendApi
from . import is_accepted_format, extract_data_uri
from .. import debug
@@ -18,10 +21,10 @@ from .. import debug
# Directory for storing generated images
images_dir = "./generated_images"
def get_image_extension(image: str) -> str:
def get_media_extension(image: str) -> str:
"""Extract image extension from URL or filename, default to .jpg"""
match = re.search(r"\.(jpe?g|png|webp)$", image, re.IGNORECASE)
return f".{match.group(1).lower()}" if match else ".jpg"
match = re.search(r"\.(jpe?g|png|webp|mp4|mp3|wav)[?$]", image, re.IGNORECASE)
return f".{match.group(1).lower()}" if match else ""
def ensure_images_dir():
"""Create images directory if it doesn't exist"""
@@ -35,7 +38,42 @@ def get_source_url(image: str, default: str = None) -> str:
return decoded_url
return default
async def copy_images(
def secure_filename(filename: str) -> str:
# Keep letters, numbers, basic punctuation and all Unicode chars
filename = re.sub(
r'[^\w.,_-]+',
'_',
unquote(filename).strip(),
flags=re.UNICODE
)
filename = filename[:100].strip(".,_-")
return filename
async def save_response_media(response: StreamResponse, prompt: str):
content_type = response.headers["content-type"]
if content_type in EXTENSIONS_MAP or content_type.startswith("audio/"):
extension = EXTENSIONS_MAP[content_type] if content_type in EXTENSIONS_MAP else content_type[6:].replace("mpeg", "mp3")
bucket_id = str(uuid.uuid4())
dirname = str(int(time.time()))
bucket_dir = get_bucket_dir(bucket_id, dirname)
media_dir = os.path.join(bucket_dir, "media")
os.makedirs(media_dir, exist_ok=True)
filename = secure_filename(f"{content_type[0:5] if prompt is None else prompt}.{extension}")
newfile = os.path.join(media_dir, filename)
with open(newfile, 'wb') as f:
async for chunk in response.iter_content() if hasattr(response, "iter_content") else response.content.iter_any():
f.write(chunk)
media_url = f"/files/{dirname}/{bucket_id}/media/{filename}"
if response.method == "GET":
media_url = f"{media_url}?url={str(response.url)}"
if content_type.startswith("audio/"):
yield AudioResponse(media_url)
elif content_type.startswith("video/"):
yield VideoResponse(media_url, prompt)
else:
yield ImageResponse(media_url, prompt)
async def copy_media(
images: list[str],
cookies: Optional[Cookies] = None,
headers: Optional[dict] = None,
@@ -60,33 +98,18 @@ async def copy_images(
) as session:
async def copy_image(image: str, target: str = None) -> str:
"""Process individual image and return its local URL"""
# Skip if image is already local
if image.startswith("/"):
return image
target_path = target
if target_path is None:
# Generate filename components
file_hash = hashlib.sha256(image.encode()).hexdigest()[:16]
timestamp = int(time.time())
# Sanitize alt text for filename (Unicode-safe)
if alt:
# Keep letters, numbers, basic punctuation and all Unicode chars
clean_alt = re.sub(
r'[^\w\s.-]', # Allow all Unicode word chars
'_',
unquote(alt).strip(),
flags=re.UNICODE
)
clean_alt = re.sub(r'[\s_]+', '_', clean_alt)[:100]
else:
clean_alt = "image"
# Build safe filename with full Unicode support
extension = get_image_extension(image)
filename = (
f"{timestamp}_"
f"{clean_alt}_"
f"{file_hash}"
f"{extension}"
)
filename = secure_filename("".join((
f"{int(time.time())}_",
(f"{alt}_" if alt else ""),
f"{hashlib.sha256(image.encode()).hexdigest()[:16]}",
f"{get_media_extension(image)}"
)))
target_path = os.path.join(images_dir, filename)
try:
# Handle different image types

View File

@@ -234,11 +234,6 @@ llama_3_1_405b = Model(
)
# llama 3.2
llama_3 = VisionModel(
name = "llama-3",
base_provider = "Meta Llama",
best_provider = IterListProvider([HuggingChat, HuggingFace])
)
llama_3_2_1b = Model(
name = "llama-3.2-1b",
@@ -977,7 +972,6 @@ class ModelUtils:
demo_models = {
"default": [llama_3, [HuggingFace]],
llama_3_2_11b.name: [llama_3_2_11b, [HuggingChat]],
qwen_2_vl_7b.name: [qwen_2_vl_7b, [HuggingFaceAPI]],
deepseek_r1.name: [deepseek_r1, [HuggingFace, PollinationsAI]],

View File

@@ -25,15 +25,12 @@ async def raise_for_status_async(response: Union[StreamResponse, ClientResponse]
return
if message is None:
content_type = response.headers.get("content-type", "")
# if content_type.startswith("application/json"):
# try:
# data = await response.json()
# message = data.get("error")
# if isinstance(message, dict):
# message = data.get("message")
# except Exception:
# pass
# else:
if content_type.startswith("application/json"):
message = await response.json()
message = message.get("error", message)
if isinstance(message, dict):
message = message.get("message", message)
else:
text = (await response.text()).strip()
is_html = content_type.startswith("text/html") or text.startswith("<!DOCTYPE")
message = "HTML content" if is_html else text
@@ -47,7 +44,9 @@ async def raise_for_status_async(response: Union[StreamResponse, ClientResponse]
elif response.status == 403 and is_openai(text):
raise ResponseStatusError(f"Response {response.status}: OpenAI Bot detected")
elif response.status == 502:
raise ResponseStatusError(f"Response {response.status}: Bad gateway")
raise ResponseStatusError(f"Response {response.status}: Bad Gateway")
elif response.status == 504:
raise RateLimitError(f"Response {response.status}: Gateway Timeout ")
else:
raise ResponseStatusError(f"Response {response.status}: {message}")
@@ -70,6 +69,8 @@ def raise_for_status(response: Union[Response, StreamResponse, ClientResponse, R
elif response.status_code == 403 and is_openai(response.text):
raise ResponseStatusError(f"Response {response.status_code}: OpenAI Bot detected")
elif response.status_code == 502:
raise ResponseStatusError(f"Response {response.status_code}: Bad gateway")
raise ResponseStatusError(f"Response {response.status_code}: Bad Gateway")
elif response.status_code == 504:
raise RateLimitError(f"Response {response.status_code}: Gateway Timeout ")
else:
raise ResponseStatusError(f"Response {response.status_code}: {message}")