Add HuggingFaceMedia provider with Video Generation

Add Support for Video Response in UI
Improve Support for Audio Response in UI
Fix ModelNotSupported errors in HuggingSpace providers
Author: hlohaus
Date: 2025-03-23 05:27:52 +01:00
Parent: 97f1964bb6
Commit: 8eaaf5db95
21 changed files with 356 additions and 128 deletions

View File

@@ -14,7 +14,8 @@ from ..image import to_data_uri, is_data_an_audio, to_input_audio
 from ..errors import ModelNotFoundError
 from ..requests.raise_for_status import raise_for_status
 from ..requests.aiohttp import get_connector
-from ..providers.response import ImageResponse, ImagePreview, FinishReason, Usage, Audio, ToolCalls
+from ..image.copy_images import save_response_media
+from ..providers.response import FinishReason, Usage, ToolCalls
 from .. import debug

 DEFAULT_HEADERS = {
@@ -239,8 +240,9 @@ class PollinationsAI(AsyncGeneratorProvider, ProviderModelMixin):
         async with ClientSession(headers=DEFAULT_HEADERS, connector=get_connector(proxy=proxy)) as session:
             async with session.get(url, allow_redirects=True) as response:
                 await raise_for_status(response)
-                image_url = str(response.url)
-                yield ImageResponse(image_url, prompt)
+                async for chunk in save_response_media(response, prompt):
+                    yield chunk
+                return

     @classmethod
     async def _generate_text(
@@ -305,10 +307,10 @@ class PollinationsAI(AsyncGeneratorProvider, ProviderModelMixin):
                 })
             async with session.post(url, json=data) as response:
                 await raise_for_status(response)
-                if response.headers["content-type"] == "audio/mpeg":
-                    yield Audio(await response.read())
+                async for chunk in save_response_media(response, messages[-1]["content"]):
+                    yield chunk
                     return
-                elif response.headers["content-type"].startswith("text/plain"):
+                if response.headers["content-type"].startswith("text/plain"):
                     yield await response.text()
                     return
                 elif response.headers["content-type"].startswith("text/event-stream"):
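For context, both hunks above switch PollinationsAI to the same delegation pattern: the raw HTTP response is handed to save_response_media, which picks an ImageResponse, VideoResponse or AudioResponse based on the content type. A minimal sketch of that pattern, assuming the surrounding session setup (not a verbatim excerpt from the provider):

    from aiohttp import ClientSession

    from ..image.copy_images import save_response_media
    from ..requests.raise_for_status import raise_for_status

    async def _generate(url: str, prompt: str):
        # Stream the upstream response and let save_response_media decide whether
        # it is an image, video or audio payload; re-yield whatever it produces.
        async with ClientSession() as session:
            async with session.get(url, allow_redirects=True) as response:
                await raise_for_status(response)
                async for chunk in save_response_media(response, prompt):
                    yield chunk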

View File

@@ -9,7 +9,7 @@ from .deprecated import *
 from .needs_auth import *
 from .not_working import *
 from .local import *
-from .hf import HuggingFace, HuggingChat, HuggingFaceAPI, HuggingFaceInference
+from .hf import HuggingFace, HuggingChat, HuggingFaceAPI, HuggingFaceInference, HuggingFaceMedia
 from .hf_space import *
 from .mini_max import HailuoAI, MiniMax
 from .template import OpenaiTemplate, BackendApi

View File

@@ -24,7 +24,7 @@ from ...requests import get_args_from_nodriver, DEFAULT_HEADERS
 from ...requests.raise_for_status import raise_for_status
 from ...providers.response import JsonConversation, ImageResponse, Sources, TitleGeneration, Reasoning, RequestLogin
 from ...cookies import get_cookies
-from .models import default_model, fallback_models, image_models, model_aliases, llama_models
+from .models import default_model, default_vision_model, fallback_models, image_models, model_aliases
 from ... import debug

 class Conversation(JsonConversation):
@@ -41,6 +41,7 @@ class HuggingChat(AsyncAuthedProvider, ProviderModelMixin):
     supports_stream = True
     needs_auth = True
     default_model = default_model
+    default_vision_model = default_vision_model
     model_aliases = model_aliases
     image_models = image_models
     text_models = fallback_models
@@ -107,8 +108,8 @@ class HuggingChat(AsyncAuthedProvider, ProviderModelMixin):
     ) -> AsyncResult:
         if not has_curl_cffi:
             raise MissingRequirementsError('Install "curl_cffi" package | pip install -U curl_cffi')
-        if model == llama_models["name"]:
-            model = llama_models["text"] if media is None else llama_models["vision"]
+        if not model and media is not None:
+            model = cls.default_vision_model
         model = cls.get_model(model)
         session = Session(**auth_result.get_dict())

View File

@@ -6,27 +6,30 @@ from ...providers.types import Messages
 from ...typing import MediaListType
 from ...requests import StreamSession, raise_for_status
 from ...errors import ModelNotSupportedError
+from ...providers.helper import get_last_user_message
 from ...providers.response import ProviderInfo
 from ..template.OpenaiTemplate import OpenaiTemplate
-from .models import model_aliases, vision_models, default_vision_model, llama_models, text_models
+from .models import model_aliases, vision_models, default_llama_model, default_vision_model, text_models
 from ... import debug

 class HuggingFaceAPI(OpenaiTemplate):
-    label = "HuggingFace (Inference API)"
+    label = "HuggingFace (Text Generation)"
     parent = "HuggingFace"
     url = "https://api-inference.huggingface.com"
     api_base = "https://api-inference.huggingface.co/v1"
     working = True
     needs_auth = True

-    default_model = default_vision_model
+    default_model = default_llama_model
     default_vision_model = default_vision_model
     vision_models = vision_models
     model_aliases = model_aliases
     fallback_models = text_models + vision_models

-    provider_mapping: dict[str, dict] = {}
+    provider_mapping: dict[str, dict] = {
+        "google/gemma-3-27b-it": {
+            "hf-inference/models/google/gemma-3-27b-it": {
+                "task": "conversational",
+                "providerId": "google/gemma-3-27b-it"}}}

     @classmethod
     def get_model(cls, model: str, **kwargs) -> str:
@@ -47,7 +50,9 @@ class HuggingFaceAPI(OpenaiTemplate):
                 if [
                     provider
                     for provider in model.get("inferenceProviderMapping")
-                    if provider.get("task") == "conversational"]]
+                    if provider.get("status") == "live" and provider.get("task") == "conversational"
+                ]
+            ] + list(cls.provider_mapping.keys())
         else:
             cls.models = cls.fallback_models
         return cls.models
@@ -78,11 +83,12 @@ class HuggingFaceAPI(OpenaiTemplate):
         media: MediaListType = None,
         **kwargs
     ):
-        if model == llama_models["name"]:
-            model = llama_models["text"] if media is None else llama_models["vision"]
-        if model in cls.model_aliases:
-            model = cls.model_aliases[model]
+        if not model and media is not None:
+            model = cls.default_vision_model
+        model = cls.get_model(model)
         provider_mapping = await cls.get_mapping(model, api_key)
+        if not provider_mapping:
+            raise ModelNotSupportedError(f"Model is not supported: {model} in: {cls.__name__}")
         for provider_key in provider_mapping:
             api_path = provider_key if provider_key == "novita" else f"{provider_key}/v1"
             api_base = f"https://router.huggingface.co/{api_path}"
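The new status filter keys off the inferenceProviderMapping entries returned by the Hub API. A hedged sketch of the same filtering outside the class, using any() instead of the truthy list comprehension from the diff (the payload shape is assumed from the fields used above):

    import requests

    url = "https://huggingface.co/api/models?inference=warm&expand[]=inferenceProviderMapping"
    models = requests.get(url).json()

    # Keep only models that have at least one live conversational provider,
    # mirroring the condition added in get_models().
    chat_models = [
        model["id"]
        for model in models
        if any(
            provider.get("status") == "live" and provider.get("task") == "conversational"
            for provider in model.get("inferenceProviderMapping", [])
        )
    ]
    print(len(chat_models), chat_models[:3])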

View File

@@ -10,6 +10,7 @@ from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin, format_p
 from ...errors import ModelNotSupportedError, ResponseError
 from ...requests import StreamSession, raise_for_status
 from ...providers.response import FinishReason, ImageResponse
+from ...image.copy_images import save_response_media
 from ..helper import format_image_prompt, get_last_user_message
 from .models import default_model, default_image_model, model_aliases, text_models, image_models, vision_models
 from ... import debug
@@ -176,11 +177,9 @@ class HuggingFaceInference(AsyncGeneratorProvider, ProviderModelMixin):
                             debug.log(f"Special token: {is_special}")
                         yield FinishReason("stop" if is_special else "length")
                 else:
-                    if response.headers["content-type"].startswith("image/"):
-                        base64_data = base64.b64encode(b"".join([chunk async for chunk in response.iter_content()]))
-                        url = f"data:{response.headers['content-type']};base64,{base64_data.decode()}"
-                        yield ImageResponse(url, inputs)
-                    else:
-                        yield (await response.json())[0]["generated_text"].strip()
+                    async for chunk in save_response_media(response, prompt):
+                        yield chunk
+                        return
+                    yield (await response.json())[0]["generated_text"].strip()

 def format_prompt_mistral(messages: Messages, do_continue: bool = False) -> str:

View File

@@ -0,0 +1,175 @@
+from __future__ import annotations
+
+import random
+import requests
+
+from ...providers.types import Messages
+from ...requests import StreamSession, raise_for_status
+from ...errors import ModelNotSupportedError
+from ...providers.helper import format_image_prompt
+from ...providers.base_provider import AsyncGeneratorProvider, ProviderModelMixin
+from ...providers.response import ProviderInfo, ImageResponse, VideoResponse
+from ...image.copy_images import save_response_media
+from ... import debug
+
+class HuggingFaceMedia(AsyncGeneratorProvider, ProviderModelMixin):
+    label = "HuggingFace (Image / Video Generation)"
+    parent = "HuggingFace"
+    url = "https://huggingface.co"
+    working = True
+    needs_auth = True
+
+    tasks = ["text-to-image", "text-to-video"]
+    provider_mapping: dict[str, dict] = {}
+    task_mapping: dict[str, str] = {}
+
+    @classmethod
+    def get_models(cls, **kwargs) -> list[str]:
+        if not cls.models:
+            url = "https://huggingface.co/api/models?inference=warm&expand[]=inferenceProviderMapping"
+            response = requests.get(url)
+            if response.ok:
+                models = response.json()
+                cls.models = [
+                    model["id"]
+                    for model in models
+                    if [
+                        provider
+                        for provider in model.get("inferenceProviderMapping")
+                        if provider.get("status") == "live" and provider.get("task") in cls.tasks
+                    ]
+                ]
+                cls.task_mapping = {
+                    model["id"]: [
+                        provider.get("task")
+                        for provider in model.get("inferenceProviderMapping")
+                    ].pop()
+                    for model in models
+                }
+            else:
+                cls.models = []
+        return cls.models
+
+    @classmethod
+    async def get_mapping(cls, model: str, api_key: str = None):
+        if model in cls.provider_mapping:
+            return cls.provider_mapping[model]
+        headers = {
+            'Content-Type': 'application/json',
+        }
+        if api_key is not None:
+            headers["Authorization"] = f"Bearer {api_key}"
+        async with StreamSession(
+            timeout=30,
+            headers=headers,
+        ) as session:
+            async with session.get(f"https://huggingface.co/api/models/{model}?expand[]=inferenceProviderMapping") as response:
+                await raise_for_status(response)
+                model_data = await response.json()
+                cls.provider_mapping[model] = {key: value for key, value in model_data.get("inferenceProviderMapping").items() if value["status"] == "live"}
+        return cls.provider_mapping[model]
+
+    @classmethod
+    async def create_async_generator(
+        cls,
+        model: str,
+        messages: Messages,
+        api_key: str = None,
+        extra_data: dict = {},
+        prompt: str = None,
+        proxy: str = None,
+        timeout: int = 0,
+        **kwargs
+    ):
+        provider_mapping = await cls.get_mapping(model, api_key)
+        headers = {
+            'Accept-Encoding': 'gzip, deflate',
+            'Content-Type': 'application/json',
+        }
+        new_mapping = {
+            "hf-free" if key == "hf-inference" else key: value for key, value in provider_mapping.items()
+            if key in ["replicate", "together", "hf-inference"]
+        }
+        provider_mapping = {**new_mapping, **provider_mapping}
+        last_response = None
+        for provider_key, provider in provider_mapping.items():
+            yield ProviderInfo(**{**cls.get_dict(), "label": f"HuggingFace ({provider_key})", "url": f"{cls.url}/{model}"})
+
+            api_base = f"https://router.huggingface.co/{provider_key}"
+            task = provider["task"]
+            provider_id = provider["providerId"]
+            if task not in cls.tasks:
+                raise ModelNotSupportedError(f"Model is not supported: {model} in: {cls.__name__} task: {task}")
+
+            prompt = format_image_prompt(messages, prompt)
+            if task == "text-to-video":
+                extra_data = {
+                    "num_inference_steps": 20,
+                    "video_size": "landscape_16_9",
+                    **extra_data
+                }
+            else:
+                extra_data = {
+                    "width": 1024,
+                    "height": 1024,
+                    **extra_data
+                }
+            if provider_key == "fal-ai":
+                url = f"{api_base}/{provider_id}"
+                data = {
+                    "prompt": prompt,
+                    "image_size": "square_hd",
+                    **extra_data
+                }
+            elif provider_key == "replicate":
+                url = f"{api_base}/v1/models/{provider_id}/prediction"
+                data = {
+                    "input": {
+                        "prompt": prompt,
+                        **extra_data
+                    }
+                }
+            elif provider_key in ("hf-inference", "hf-free"):
+                api_base = "https://api-inference.huggingface.co"
+                url = f"{api_base}/models/{provider_id}"
+                data = {
+                    "inputs": prompt,
+                    "parameters": {
+                        "seed": random.randint(0, 2**32),
+                        **extra_data
+                    }
+                }
+            elif task == "text-to-image":
+                url = f"{api_base}/v1/images/generations"
+                data = {
+                    "response_format": "url",
+                    "prompt": prompt,
+                    "model": provider_id,
+                    **extra_data
+                }
+            async with StreamSession(
+                headers=headers if provider_key == "free" or api_key is None else {**headers, "Authorization": f"Bearer {api_key}"},
+                proxy=proxy,
+                timeout=timeout
+            ) as session:
+                async with session.post(url, json=data) as response:
+                    if response.status in (400, 401, 402):
+                        last_response = response
+                        debug.error(f"{cls.__name__}: Error {response.status} with {provider_key} and {provider_id}")
+                        continue
+                    if response.status == 404:
+                        raise ModelNotSupportedError(f"Model is not supported: {model}")
+                    await raise_for_status(response)
+                    async for chunk in save_response_media(response, prompt):
+                        yield chunk
+                        return
+                    result = await response.json()
+                    if "video" in result:
+                        yield VideoResponse(result["video"]["url"], prompt)
+                    elif task == "text-to-image":
+                        yield ImageResponse([item["url"] for item in result.get("images", result.get("data"))], prompt)
+                    elif task == "text-to-video":
+                        yield VideoResponse(result["output"], prompt)
+                    return
+        await raise_for_status(last_response)
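A hedged usage sketch for the new provider; the model id and prompt are illustrative only, and a HuggingFace token with inference access is assumed:

    import asyncio
    from g4f.Provider import HuggingFaceMedia

    async def main():
        messages = [{"role": "user", "content": "A lighthouse in a storm"}]
        async for chunk in HuggingFaceMedia.create_async_generator(
            "black-forest-labs/FLUX.1-dev",  # example text-to-image model
            messages,
            api_key="hf_xxx",                # placeholder token
        ):
            # Expect a ProviderInfo first, then an ImageResponse/VideoResponse
            # (or a saved-media response produced by save_response_media).
            print(type(chunk).__name__, chunk)

    asyncio.run(main())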

View File

@@ -9,6 +9,7 @@ from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin
 from .HuggingChat import HuggingChat
 from .HuggingFaceAPI import HuggingFaceAPI
 from .HuggingFaceInference import HuggingFaceInference
+from .HuggingFaceMedia import HuggingFaceMedia
 from .models import model_aliases, vision_models, default_vision_model
 from ... import debug
@@ -51,6 +52,12 @@ class HuggingFace(AsyncGeneratorProvider, ProviderModelMixin):
             debug.error(f"{cls.__name__} {type(e).__name__}; {e}")
         if not cls.image_models:
             cls.get_models()
+        try:
+            async for chunk in HuggingFaceMedia.create_async_generator(model, messages, **kwargs):
+                yield chunk
+            return
+        except ModelNotSupportedError:
+            pass
         if model in cls.image_models:
             if "api_key" not in kwargs:
                 async for chunk in HuggingChat.create_async_generator(model, messages, **kwargs):

View File

@@ -47,9 +47,5 @@ extra_models = [
     "NousResearch/Hermes-3-Llama-3.1-8B",
 ]
 default_vision_model = "meta-llama/Llama-3.2-11B-Vision-Instruct"
+default_llama_model = "meta-llama/Llama-3.3-70B-Instruct"
 vision_models = [default_vision_model, "Qwen/Qwen2-VL-7B-Instruct"]
-llama_models = {
-    "name": "llama-3",
-    "text": "meta-llama/Llama-3.3-70B-Instruct",
-    "vision": "meta-llama/Llama-3.2-11B-Vision-Instruct",
-}

View File

@@ -67,7 +67,6 @@ class BlackForestLabs_Flux1Dev(AsyncGeneratorProvider, ProviderModelMixin):
         zerogpu_uuid: str = "[object Object]",
         **kwargs
     ) -> AsyncResult:
-        model = cls.get_model(model)
         async with StreamSession(impersonate="chrome", proxy=proxy) as session:
             prompt = format_image_prompt(messages, prompt)
             data = [prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps]

View File

@@ -37,8 +37,6 @@ class BlackForestLabs_Flux1Schnell(AsyncGeneratorProvider, ProviderModelMixin):
         randomize_seed: bool = True,
         **kwargs
     ) -> AsyncResult:
-        model = cls.get_model(model)
         width = max(32, width - (width % 8))
         height = max(32, height - (height % 8))
         prompt = format_image_prompt(messages, prompt)

View File

@@ -24,9 +24,14 @@ class CohereForAI_C4AI_Command(AsyncGeneratorProvider, ProviderModelMixin):
         "command-r": "command-r",
         "command-r7b": "command-r7b-12-2024",
     }
     models = list(model_aliases.keys())

+    @classmethod
+    def get_model(cls, model: str, **kwargs) -> str:
+        if model in cls.model_aliases.values():
+            return model
+        return super().get_model(model, **kwargs)
+
     @classmethod
     async def create_async_generator(
         cls, model: str, messages: Messages,

View File

@@ -203,7 +203,21 @@ class OpenaiChat(AsyncAuthedProvider, ProviderModelMixin):
         Returns:
             A list of messages with the user input and the image, if any
         """
-        # Create a message object with the user role and the content
+        # merged_messages = []
+        # last_message = None
+        # for message in messages:
+        #     current_message = last_message
+        #     if current_message is not None:
+        #         if current_message["role"] == message["role"]:
+        #             current_message["content"] += "\n" + message["content"]
+        #         else:
+        #             merged_messages.append(current_message)
+        #             last_message = message.copy()
+        #     else:
+        #         last_message = message.copy()
+        # if last_message is not None:
+        #     merged_messages.append(last_message)
         messages = [{
             "id": str(uuid.uuid4()),
             "author": {"role": message["role"]},

View File

@@ -39,7 +39,7 @@ from g4f.client import AsyncClient, ChatCompletion, ImagesResponse, convert_to_p
 from g4f.providers.response import BaseConversation, JsonConversation
 from g4f.client.helper import filter_none
 from g4f.image import is_data_an_media
-from g4f.image.copy_images import images_dir, copy_images, get_source_url
+from g4f.image.copy_images import images_dir, copy_media, get_source_url
 from g4f.errors import ProviderNotFoundError, ModelNotFoundError, MissingAuthError, NoValidHarFileError
 from g4f.cookies import read_cookie_files, get_cookies_dir
 from g4f.Provider import ProviderType, ProviderUtils, __providers__
@@ -594,10 +594,10 @@ class Api:
             ssl = False
         if source_url is not None:
             try:
-                await copy_images(
+                await copy_media(
                     [source_url],
                     target=target, ssl=ssl)
-                debug.log(f"Image copied from {source_url}")
+                debug.log(f"File copied from {source_url}")
             except Exception as e:
                 debug.error(f"Download failed: {source_url}\n{type(e).__name__}: {e}")
                 return RedirectResponse(url=source_url)

View File

@@ -9,7 +9,7 @@ import aiohttp
 import base64
 from typing import Union, AsyncIterator, Iterator, Awaitable, Optional

-from ..image.copy_images import copy_images
+from ..image.copy_images import copy_media
 from ..typing import Messages, ImageType
 from ..providers.types import ProviderType, BaseRetryProvider
 from ..providers.response import *
@@ -532,7 +532,7 @@ class Images:
             images = await asyncio.gather(*[get_b64_from_url(image) for image in response.get_list()])
         else:
             # Save locally for None (default) case
-            images = await copy_images(response.get_list(), response.get("cookies"), proxy)
+            images = await copy_media(response.get_list(), response.get("cookies"), proxy)
             images = [Image.model_construct(url=f"/images/{os.path.basename(image)}", revised_prompt=response.alt) for image in images]

         return ImagesResponse.model_construct(

View File

@@ -15,7 +15,7 @@
 </head>
 <body>
-    <h1>QR Scanner & QR Code Generator</h1>
+    <h1>QR Scanner & QR Code</h1>

     <h2>QR Code Scanner</h2>
     <video id="video"></video>
@@ -25,9 +25,8 @@
     <button id="toggleFlash">Toggle Flash</button>
     <p id="cam-status"></p>

-    <h2>Generate QR Code</h2>
+    <h2>QR Code</h2>
     <div id="qrcode"></div>
-    <button id="generateQRCode">Generate QR Code</button>

     <script type="module">

View File

@@ -80,7 +80,11 @@ if (window.markdownit) {
             .replaceAll('<code>', '<code class="language-plaintext">')
             .replaceAll('&lt;i class=&quot;', '<i class="')
             .replaceAll('&quot;&gt;&lt;/i&gt;', '"></i>')
-            .replaceAll('&lt;iframe type=&quot;text/html&quot; src=&quot;', '<iframe type="text/html" frameborder="0" allow="fullscreen" src="')
+            .replaceAll('&lt;video controls src=&quot;', '<video controls width="400" src="')
+            .replaceAll('&quot;&gt;&lt;/video&gt;', '"></video>')
+            .replaceAll('&lt;audio controls src=&quot;', '<audio controls src="')
+            .replaceAll('&quot;&gt;&lt;/audio&gt;', '"></audio>')
+            .replaceAll('&lt;iframe type=&quot;text/html&quot; src=&quot;', '<iframe type="text/html" frameborder="0" allow="fullscreen" height="390" width="640" src="')
             .replaceAll('&quot;&gt;&lt;/iframe&gt;', `?enablejsapi=1&origin=${new URL(location.href).origin}"></iframe>`)
     }
 }
@@ -229,7 +233,7 @@ function register_message_images() {
         let seed = Math.floor(Date.now() / 1000);
         newPath = `https://image.pollinations.ai/prompt/${newPath}?seed=${seed}&nologo=true`;
         let downloadUrl = newPath;
-        if (document.getElementById("download_images")?.checked) {
+        if (document.getElementById("download_media")?.checked) {
            downloadUrl = `/images/${filename}?url=${escapeHtml(newPath)}`;
         }
         const link = document.createElement("a");
@@ -862,12 +866,6 @@ async function add_message_chunk(message, message_id, provider, scroll, finish_m
                 content_map.inner.innerHTML = markdown_render(message.preview);
                 await register_message_images();
             }
-    } else if (message.type == "audio") {
-        audio = new Audio(message.audio);
-        audio.controls = true;
-        content_map.inner.appendChild(audio);
-        audio.play();
-        reloadConversation = false;
     } else if (message.type == "content") {
         message_storage[message_id] += message.content;
         update_message(content_map, message_id, null, scroll);
@@ -1090,7 +1088,7 @@ const ask_gpt = async (message_id, message_index = -1, regenerate = false, provi
     } else {
         api_key = get_api_key_by_provider(provider);
     }
-    const download_images = document.getElementById("download_images")?.checked;
+    const download_media = document.getElementById("download_media")?.checked;
     let api_base;
     if (provider == "Custom") {
         api_base = document.getElementById("api_base")?.value;
@@ -1119,7 +1117,7 @@ const ask_gpt = async (message_id, message_index = -1, regenerate = false, provi
         provider: provider,
         messages: messages,
         action: action,
-        download_images: download_images,
+        download_media: download_media,
         api_key: api_key,
         api_base: api_base,
         ignored: ignored,
@@ -2723,6 +2721,9 @@ async function load_provider_models(provider=null) {
            option.value = model.model;
            option.dataset.label = model.model;
            option.text = `${model.model}${model.image ? " (Image Generation)" : ""}${model.vision ? " (Image Upload)" : ""}`;
+            if (model.task) {
+                option.text += ` (${model.task})`;
+            }
            modelProvider.appendChild(option);
            if (model.default) {
                defaultIndex = i;

View File

@@ -8,7 +8,7 @@ from flask import send_from_directory
 from inspect import signature

 from ...errors import VersionNotFoundError
-from ...image.copy_images import copy_images, ensure_images_dir, images_dir
+from ...image.copy_images import copy_media, ensure_images_dir, images_dir
 from ...tools.run_tools import iter_run_tools
 from ...Provider import ProviderUtils, __providers__
 from ...providers.base_provider import ProviderModelMixin
@@ -53,6 +53,7 @@ class Api:
                 "default": model == provider.default_model,
                 "vision": getattr(provider, "default_vision_model", None) == model or model in getattr(provider, "vision_models", []),
                 "image": False if provider.image_models is None else model in provider.image_models,
+                "task": None if not hasattr(provider, "task_mapping") else provider.task_mapping[model] if model in provider.task_mapping else None
             }
             for model in models
         ]
@@ -127,7 +128,7 @@ class Api:
             **kwargs
         }

-    def _create_response_stream(self, kwargs: dict, conversation_id: str, provider: str, download_images: bool = True) -> Iterator:
+    def _create_response_stream(self, kwargs: dict, conversation_id: str, provider: str, download_media: bool = True) -> Iterator:
         def decorated_log(text: str, file = None):
             debug.logs.append(text)
             if debug.logging:
@@ -154,7 +155,7 @@ class Api:
         if hasattr(provider_handler, "get_parameters"):
             yield self._format_json("parameters", provider_handler.get_parameters(as_json=True))
         try:
-            result = iter_run_tools(ChatCompletion.create, **{**kwargs, "model": model, "provider": provider_handler})
+            result = iter_run_tools(ChatCompletion.create, **{**kwargs, "model": model, "provider": provider_handler, "download_media": download_media})
             for chunk in result:
                 if isinstance(chunk, ProviderInfo):
                     yield self.handle_provider(chunk, model)
@@ -182,13 +183,13 @@ class Api:
                     yield self._format_json("preview", chunk.to_string())
                 elif isinstance(chunk, ImagePreview):
                     yield self._format_json("preview", chunk.to_string(), images=chunk.images, alt=chunk.alt)
-                elif isinstance(chunk, ImageResponse):
-                    images = chunk
-                    if download_images or chunk.get("cookies"):
+                elif isinstance(chunk, (ImageResponse, VideoResponse)):
+                    media = chunk
+                    if download_media or chunk.get("cookies"):
                         chunk.alt = format_image_prompt(kwargs.get("messages"), chunk.alt)
-                        images = asyncio.run(copy_images(chunk.get_list(), chunk.get("cookies"), chunk.get("headers"), proxy=proxy, alt=chunk.alt))
-                        images = ImageResponse(images, chunk.alt)
-                    yield self._format_json("content", str(images), images=chunk.get_list(), alt=chunk.alt)
+                        media = asyncio.run(copy_media(chunk.get_list(), chunk.get("cookies"), chunk.get("headers"), proxy=proxy, alt=chunk.alt))
+                        media = ImageResponse(media, chunk.alt) if isinstance(chunk, ImageResponse) else VideoResponse(media, chunk.alt)
+                    yield self._format_json("content", str(media), images=chunk.get_list(), alt=chunk.alt)
                 elif isinstance(chunk, SynthesizeData):
                     yield self._format_json("synthesize", chunk.get_dict())
                 elif isinstance(chunk, TitleGeneration):
@@ -205,8 +206,8 @@ class Api:
                     yield self._format_json("reasoning", **chunk.get_dict())
                 elif isinstance(chunk, YouTube):
                     yield self._format_json("content", chunk.to_string())
-                elif isinstance(chunk, Audio):
-                    yield self._format_json("audio", str(chunk))
+                elif isinstance(chunk, AudioResponse):
+                    yield self._format_json("content", str(chunk))
                 elif isinstance(chunk, DebugResponse):
                     yield self._format_json("log", chunk.log)
                 elif isinstance(chunk, RawResponse):

View File

@@ -14,9 +14,8 @@ from typing import Generator
 from pathlib import Path
 from urllib.parse import quote_plus
 from hashlib import sha256
-from werkzeug.utils import secure_filename

-from ...image import is_allowed_extension, to_image
+from ...image import is_allowed_extension
 from ...client.service import convert_to_provider
 from ...providers.asyncio import to_sync_generator
 from ...client.helper import filter_markdown
@@ -25,6 +24,7 @@ from ...tools.run_tools import iter_run_tools
 from ...errors import ProviderNotFoundError
 from ...image import is_allowed_extension
 from ...cookies import get_cookies_dir
+from ...image.copy_images import secure_filename
 from ... import ChatCompletion
 from ... import models
 from .api import Api
@@ -130,16 +130,14 @@ class Backend_Api(Api):
                 if model != "default" and model in models.demo_models:
                     json_data["provider"] = random.choice(models.demo_models[model][1])
                 else:
-                    if not model or model == "default":
-                        json_data["model"] = models.demo_models["default"][0].name
-                    json_data["provider"] = random.choice(models.demo_models["default"][1])
+                    json_data["provider"] = models.HuggingFace
             kwargs = self._prepare_conversation_kwargs(json_data)
             return self.app.response_class(
                 self._create_response_stream(
                     kwargs,
                     json_data.get("conversation_id"),
                     json_data.get("provider"),
-                    json_data.get("download_images", True),
+                    json_data.get("download_media", True),
                 ),
                 mimetype='text/event-stream'
             )
@@ -331,21 +329,26 @@ class Backend_Api(Api):
                 [f.write(f"{filename}\n") for filename in filenames]
             return {"bucket_id": bucket_id, "files": filenames, "media": media}

-        @app.route('/backend-api/v2/files/<bucket_id>/media/<filename>', methods=['GET'])
-        def get_media(bucket_id, filename):
-            bucket_id = secure_filename(bucket_id)
-            bucket_dir = get_bucket_dir(bucket_id)
+        @app.route('/files/<bucket_id>/media/<filename>', methods=['GET'])
+        def get_media(bucket_id, filename, dirname: str = None):
+            bucket_dir = get_bucket_dir(secure_filename(bucket_id), secure_filename(dirname))
             media_dir = os.path.join(bucket_dir, "media")
             if os.path.exists(media_dir):
                 return send_from_directory(os.path.abspath(media_dir), filename)
-            return "File not found", 404
+            return "Not found", 404
+
+        @app.route('/files/<dirname>/<bucket_id>/media/<filename>', methods=['GET'])
+        def get_media_sub(dirname, bucket_id, filename):
+            return get_media(bucket_id, filename, dirname)

         @app.route('/backend-api/v2/files/<bucket_id>/<filename>', methods=['PUT'])
-        def upload_file(bucket_id, filename):
-            bucket_id = secure_filename(bucket_id)
-            bucket_dir = get_bucket_dir(bucket_id)
+        def upload_file(bucket_id, filename, dirname: str = None):
+            bucket_dir = secure_filename(bucket_id if dirname is None else dirname)
+            bucket_dir = get_bucket_dir(bucket_dir)
             filename = secure_filename(filename)
             bucket_path = Path(bucket_dir)
+            if dirname is not None:
+                bucket_path = bucket_path / secure_filename(bucket_id)

             if not supports_filename(filename):
                 return jsonify({"error": {"message": f"File type not allowed"}}), 400
@@ -366,6 +369,10 @@ class Backend_Api(Api):
             except Exception as e:
                 return jsonify({"error": {"message": f"Error uploading file: {str(e)}"}}), 500

+        @app.route('/backend-api/v2/files/<bucket_id>/<dirname>/<filename>', methods=['PUT'])
+        def upload_file_sub(bucket_id, filename, dirname):
+            return upload_file(bucket_id, filename, dirname)
+
         @app.route('/backend-api/v2/upload_cookies', methods=['POST'])
         def upload_cookies():
             file = None
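To connect the dots: the save_response_media helper added later in this commit writes files under a bucket directory and emits URLs of the form /files/<dirname>/<bucket_id>/media/<filename>, which is exactly what the new get_media_sub route serves. A rough illustration with placeholder values (the uuid and timestamp are examples, not from the source):

    # URL built by save_response_media
    bucket_id = "0f3c2f6e-0000-4000-8000-example"   # str(uuid.uuid4())
    dirname = "1742702872"                          # str(int(time.time()))
    filename = "a_lighthouse_in_a_storm.mp4"
    media_url = f"/files/{dirname}/{bucket_id}/media/{filename}"

    # Request path resolution as wired above:
    # GET /files/<dirname>/<bucket_id>/media/<filename>
    #   -> get_media_sub(dirname, bucket_id, filename)
    #   -> get_media(bucket_id, filename, dirname)
    #   -> send_from_directory(get_bucket_dir(secure_filename(bucket_id), secure_filename(dirname)) + "/media", filename)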

View File

@@ -10,7 +10,10 @@ from urllib.parse import quote, unquote
 from aiohttp import ClientSession, ClientError

 from ..typing import Optional, Cookies
-from ..requests.aiohttp import get_connector
+from ..requests.aiohttp import get_connector, StreamResponse
+from ..image import EXTENSIONS_MAP
+from ..tools.files import get_bucket_dir
+from ..providers.response import ImageResponse, AudioResponse, VideoResponse
 from ..Provider.template import BackendApi
 from . import is_accepted_format, extract_data_uri
 from .. import debug
@@ -18,10 +21,10 @@ from .. import debug
 # Directory for storing generated images
 images_dir = "./generated_images"

-def get_image_extension(image: str) -> str:
+def get_media_extension(image: str) -> str:
     """Extract image extension from URL or filename, default to .jpg"""
-    match = re.search(r"\.(jpe?g|png|webp)$", image, re.IGNORECASE)
-    return f".{match.group(1).lower()}" if match else ".jpg"
+    match = re.search(r"\.(jpe?g|png|webp|mp4|mp3|wav)[?$]", image, re.IGNORECASE)
+    return f".{match.group(1).lower()}" if match else ""

 def ensure_images_dir():
     """Create images directory if it doesn't exist"""
@@ -35,7 +38,42 @@ def get_source_url(image: str, default: str = None) -> str:
             return decoded_url
     return default

-async def copy_images(
+def secure_filename(filename: str) -> str:
+    # Keep letters, numbers, basic punctuation and all Unicode chars
+    filename = re.sub(
+        r'[^\w.,_-]+',
+        '_',
+        unquote(filename).strip(),
+        flags=re.UNICODE
+    )
+    filename = filename[:100].strip(".,_-")
+    return filename
+
+async def save_response_media(response: StreamResponse, prompt: str):
+    content_type = response.headers["content-type"]
+    if content_type in EXTENSIONS_MAP or content_type.startswith("audio/"):
+        extension = EXTENSIONS_MAP[content_type] if content_type in EXTENSIONS_MAP else content_type[6:].replace("mpeg", "mp3")
+        bucket_id = str(uuid.uuid4())
+        dirname = str(int(time.time()))
+        bucket_dir = get_bucket_dir(bucket_id, dirname)
+        media_dir = os.path.join(bucket_dir, "media")
+        os.makedirs(media_dir, exist_ok=True)
+        filename = secure_filename(f"{content_type[0:5] if prompt is None else prompt}.{extension}")
+        newfile = os.path.join(media_dir, filename)
+        with open(newfile, 'wb') as f:
+            async for chunk in response.iter_content() if hasattr(response, "iter_content") else response.content.iter_any():
+                f.write(chunk)
+        media_url = f"/files/{dirname}/{bucket_id}/media/{filename}"
+        if response.method == "GET":
+            media_url = f"{media_url}?url={str(response.url)}"
+        if content_type.startswith("audio/"):
+            yield AudioResponse(media_url)
+        elif content_type.startswith("video/"):
+            yield VideoResponse(media_url, prompt)
+        else:
+            yield ImageResponse(media_url, prompt)
+
+async def copy_media(
     images: list[str],
     cookies: Optional[Cookies] = None,
     headers: Optional[dict] = None,
@@ -60,33 +98,18 @@ async def copy_images(
     ) as session:
         async def copy_image(image: str, target: str = None) -> str:
             """Process individual image and return its local URL"""
-            # Skip if image is already local
-            if image.startswith("/"):
-                return image
             target_path = target
             if target_path is None:
-                # Generate filename components
-                file_hash = hashlib.sha256(image.encode()).hexdigest()[:16]
-                timestamp = int(time.time())
-                # Sanitize alt text for filename (Unicode-safe)
-                if alt:
-                    # Keep letters, numbers, basic punctuation and all Unicode chars
-                    clean_alt = re.sub(
-                        r'[^\w\s.-]', # Allow all Unicode word chars
-                        '_',
-                        unquote(alt).strip(),
-                        flags=re.UNICODE
-                    )
-                    clean_alt = re.sub(r'[\s_]+', '_', clean_alt)[:100]
-                else:
-                    clean_alt = "image"
                 # Build safe filename with full Unicode support
-                extension = get_image_extension(image)
-                filename = (
-                    f"{timestamp}_"
-                    f"{clean_alt}_"
-                    f"{file_hash}"
-                    f"{extension}"
-                )
+                filename = secure_filename("".join((
+                    f"{int(time.time())}_",
+                    (f"{alt}_" if alt else ""),
+                    f"{hashlib.sha256(image.encode()).hexdigest()[:16]}",
+                    f"{get_media_extension(image)}"
+                )))
                 target_path = os.path.join(images_dir, filename)
             try:
                 # Handle different image types
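For a feel of the new secure_filename helper, here is a hedged standalone example reproducing the regex from the hunk above; exact output can vary with the Python version's Unicode \w semantics:

    import re
    from urllib.parse import unquote

    def secure_filename(filename: str) -> str:
        # Keep letters, numbers, basic punctuation and all Unicode chars
        filename = re.sub(r'[^\w.,_-]+', '_', unquote(filename).strip(), flags=re.UNICODE)
        return filename[:100].strip(".,_-")

    print(secure_filename("A lighthouse / in a storm?.jpg"))  # A_lighthouse_in_a_storm_.jpg
    print(secure_filename("caf%C3%A9%20menu.png"))            # café_menu.png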

View File

@@ -234,11 +234,6 @@ llama_3_1_405b = Model(
 )

 # llama 3.2
-llama_3 = VisionModel(
-    name = "llama-3",
-    base_provider = "Meta Llama",
-    best_provider = IterListProvider([HuggingChat, HuggingFace])
-)
 llama_3_2_1b = Model(
     name = "llama-3.2-1b",
@@ -977,7 +972,6 @@ class ModelUtils:
 demo_models = {
-    "default": [llama_3, [HuggingFace]],
     llama_3_2_11b.name: [llama_3_2_11b, [HuggingChat]],
     qwen_2_vl_7b.name: [qwen_2_vl_7b, [HuggingFaceAPI]],
     deepseek_r1.name: [deepseek_r1, [HuggingFace, PollinationsAI]],

View File

@@ -25,15 +25,12 @@ async def raise_for_status_async(response: Union[StreamResponse, ClientResponse]
         return
     if message is None:
         content_type = response.headers.get("content-type", "")
-        # if content_type.startswith("application/json"):
-        #     try:
-        #         data = await response.json()
-        #         message = data.get("error")
-        #         if isinstance(message, dict):
-        #             message = data.get("message")
-        #     except Exception:
-        #         pass
-        # else:
+        if content_type.startswith("application/json"):
+            message = await response.json()
+            message = message.get("error", message)
+            if isinstance(message, dict):
+                message = message.get("message", message)
+        else:
             text = (await response.text()).strip()
             is_html = content_type.startswith("text/html") or text.startswith("<!DOCTYPE")
             message = "HTML content" if is_html else text
@@ -47,7 +44,9 @@ async def raise_for_status_async(response: Union[StreamResponse, ClientResponse]
     elif response.status == 403 and is_openai(text):
         raise ResponseStatusError(f"Response {response.status}: OpenAI Bot detected")
     elif response.status == 502:
-        raise ResponseStatusError(f"Response {response.status}: Bad gateway")
+        raise ResponseStatusError(f"Response {response.status}: Bad Gateway")
+    elif response.status == 504:
+        raise RateLimitError(f"Response {response.status}: Gateway Timeout ")
     else:
         raise ResponseStatusError(f"Response {response.status}: {message}")
@@ -70,6 +69,8 @@ def raise_for_status(response: Union[Response, StreamResponse, ClientResponse, R
     elif response.status_code == 403 and is_openai(response.text):
         raise ResponseStatusError(f"Response {response.status_code}: OpenAI Bot detected")
     elif response.status_code == 502:
-        raise ResponseStatusError(f"Response {response.status_code}: Bad gateway")
+        raise ResponseStatusError(f"Response {response.status_code}: Bad Gateway")
+    elif response.status_code == 504:
+        raise RateLimitError(f"Response {response.status_code}: Gateway Timeout ")
     else:
         raise ResponseStatusError(f"Response {response.status_code}: {message}")
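One behavioral note on the 504 branch: raising RateLimitError instead of a generic ResponseStatusError lets callers treat gateway timeouts the same way as rate limits. A hedged, simplified sketch of the resulting mapping outside the real response handling (the classify helper is illustrative only):

    from g4f.errors import RateLimitError, ResponseStatusError

    def classify(status_code: int, message: str = "") -> Exception:
        # Mirrors the branches added above; other codes fall through to the generic error.
        if status_code == 502:
            return ResponseStatusError(f"Response {status_code}: Bad Gateway")
        if status_code == 504:
            return RateLimitError(f"Response {status_code}: Gateway Timeout")
        return ResponseStatusError(f"Response {status_code}: {message}")

    print(type(classify(504)).__name__)  # RateLimitError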