Mirror of https://github.com/xtekky/gpt4free.git, synced 2025-10-26 01:30:25 +08:00
Add HuggingFaceMedia provider with Video Generation
Add Support for Video Response in UI
Improve Support for Audio Response in UI
Fix ModelNotSupported errors in HuggingSpace providers
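As a rough, hypothetical usage sketch (not part of this commit), the new HuggingFaceMedia provider added below can be driven through its create_async_generator interface; the model id and token here are placeholders:

import asyncio
from g4f.Provider.hf.HuggingFaceMedia import HuggingFaceMedia

async def main():
    # get_models() queries the warm-inference model list from huggingface.co (see the new file below)
    print(HuggingFaceMedia.get_models()[:5])
    async for chunk in HuggingFaceMedia.create_async_generator(
        model="<a text-to-image or text-to-video model id>",  # placeholder
        messages=[{"role": "user", "content": "a lighthouse at night"}],
        api_key="hf_...",  # placeholder; the provider sets needs_auth = True
    ):
        # Chunks are ProviderInfo, ImageResponse or VideoResponse objects
        print(type(chunk).__name__, chunk)

asyncio.run(main())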
@@ -14,7 +14,8 @@ from ..image import to_data_uri, is_data_an_audio, to_input_audio
from ..errors import ModelNotFoundError
from ..requests.raise_for_status import raise_for_status
from ..requests.aiohttp import get_connector
from ..providers.response import ImageResponse, ImagePreview, FinishReason, Usage, Audio, ToolCalls
from ..image.copy_images import save_response_media
from ..providers.response import FinishReason, Usage, ToolCalls
from .. import debug

DEFAULT_HEADERS = {
@@ -239,8 +240,9 @@ class PollinationsAI(AsyncGeneratorProvider, ProviderModelMixin):
async with ClientSession(headers=DEFAULT_HEADERS, connector=get_connector(proxy=proxy)) as session:
async with session.get(url, allow_redirects=True) as response:
await raise_for_status(response)
image_url = str(response.url)
yield ImageResponse(image_url, prompt)
async for chunk in save_response_media(response, prompt):
yield chunk
return

@classmethod
async def _generate_text(
@@ -305,10 +307,10 @@ class PollinationsAI(AsyncGeneratorProvider, ProviderModelMixin):
})
async with session.post(url, json=data) as response:
await raise_for_status(response)
if response.headers["content-type"] == "audio/mpeg":
yield Audio(await response.read())
async for chunk in save_response_media(response, messages[-1]["content"]):
yield chunk
return
elif response.headers["content-type"].startswith("text/plain"):
if response.headers["content-type"].startswith("text/plain"):
yield await response.text()
return
elif response.headers["content-type"].startswith("text/event-stream"):

@@ -9,7 +9,7 @@ from .deprecated import *
from .needs_auth import *
from .not_working import *
from .local import *
from .hf import HuggingFace, HuggingChat, HuggingFaceAPI, HuggingFaceInference
from .hf import HuggingFace, HuggingChat, HuggingFaceAPI, HuggingFaceInference, HuggingFaceMedia
from .hf_space import *
from .mini_max import HailuoAI, MiniMax
from .template import OpenaiTemplate, BackendApi

@@ -24,7 +24,7 @@ from ...requests import get_args_from_nodriver, DEFAULT_HEADERS
from ...requests.raise_for_status import raise_for_status
from ...providers.response import JsonConversation, ImageResponse, Sources, TitleGeneration, Reasoning, RequestLogin
from ...cookies import get_cookies
from .models import default_model, fallback_models, image_models, model_aliases, llama_models
from .models import default_model, default_vision_model, fallback_models, image_models, model_aliases
from ... import debug

class Conversation(JsonConversation):
@@ -41,6 +41,7 @@ class HuggingChat(AsyncAuthedProvider, ProviderModelMixin):
supports_stream = True
needs_auth = True
default_model = default_model
default_vision_model = default_vision_model
model_aliases = model_aliases
image_models = image_models
text_models = fallback_models
@@ -107,8 +108,8 @@ class HuggingChat(AsyncAuthedProvider, ProviderModelMixin):
) -> AsyncResult:
if not has_curl_cffi:
raise MissingRequirementsError('Install "curl_cffi" package | pip install -U curl_cffi')
if model == llama_models["name"]:
model = llama_models["text"] if media is None else llama_models["vision"]
if not model and media is not None:
model = cls.default_vision_model
model = cls.get_model(model)

session = Session(**auth_result.get_dict())

@@ -6,27 +6,30 @@ from ...providers.types import Messages
from ...typing import MediaListType
from ...requests import StreamSession, raise_for_status
from ...errors import ModelNotSupportedError
from ...providers.helper import get_last_user_message
from ...providers.response import ProviderInfo
from ..template.OpenaiTemplate import OpenaiTemplate
from .models import model_aliases, vision_models, default_vision_model, llama_models, text_models
from .models import model_aliases, vision_models, default_llama_model, default_vision_model, text_models
from ... import debug

class HuggingFaceAPI(OpenaiTemplate):
label = "HuggingFace (Inference API)"
label = "HuggingFace (Text Generation)"
parent = "HuggingFace"
url = "https://api-inference.huggingface.com"
api_base = "https://api-inference.huggingface.co/v1"
working = True
needs_auth = True

default_model = default_vision_model
default_model = default_llama_model
default_vision_model = default_vision_model
vision_models = vision_models
model_aliases = model_aliases
fallback_models = text_models + vision_models

provider_mapping: dict[str, dict] = {}
provider_mapping: dict[str, dict] = {
"google/gemma-3-27b-it": {
"hf-inference/models/google/gemma-3-27b-it": {
"task": "conversational",
"providerId": "google/gemma-3-27b-it"}}}

@classmethod
def get_model(cls, model: str, **kwargs) -> str:
@@ -47,7 +50,9 @@ class HuggingFaceAPI(OpenaiTemplate):
if [
provider
for provider in model.get("inferenceProviderMapping")
if provider.get("task") == "conversational"]]
if provider.get("status") == "live" and provider.get("task") == "conversational"
]
] + list(cls.provider_mapping.keys())
else:
cls.models = cls.fallback_models
return cls.models
@@ -78,11 +83,12 @@ class HuggingFaceAPI(OpenaiTemplate):
media: MediaListType = None,
**kwargs
):
if model == llama_models["name"]:
model = llama_models["text"] if media is None else llama_models["vision"]
if model in cls.model_aliases:
model = cls.model_aliases[model]
if not model and media is not None:
model = cls.default_vision_model
model = cls.get_model(model)
provider_mapping = await cls.get_mapping(model, api_key)
if not provider_mapping:
raise ModelNotSupportedError(f"Model is not supported: {model} in: {cls.__name__}")
for provider_key in provider_mapping:
api_path = provider_key if provider_key == "novita" else f"{provider_key}/v1"
api_base = f"https://router.huggingface.co/{api_path}"
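Illustrative only: the api_base values produced by the routing rule directly above, for a few assumed provider keys:

for provider_key in ("novita", "together", "hf-inference"):
    api_path = provider_key if provider_key == "novita" else f"{provider_key}/v1"
    print(f"https://router.huggingface.co/{api_path}")
# https://router.huggingface.co/novita
# https://router.huggingface.co/together/v1
# https://router.huggingface.co/hf-inference/v1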
@@ -10,6 +10,7 @@ from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin, format_p
from ...errors import ModelNotSupportedError, ResponseError
from ...requests import StreamSession, raise_for_status
from ...providers.response import FinishReason, ImageResponse
from ...image.copy_images import save_response_media
from ..helper import format_image_prompt, get_last_user_message
from .models import default_model, default_image_model, model_aliases, text_models, image_models, vision_models
from ... import debug
@@ -176,11 +177,9 @@ class HuggingFaceInference(AsyncGeneratorProvider, ProviderModelMixin):
debug.log(f"Special token: {is_special}")
yield FinishReason("stop" if is_special else "length")
else:
if response.headers["content-type"].startswith("image/"):
base64_data = base64.b64encode(b"".join([chunk async for chunk in response.iter_content()]))
url = f"data:{response.headers['content-type']};base64,{base64_data.decode()}"
yield ImageResponse(url, inputs)
else:
async for chunk in save_response_media(response, prompt):
yield chunk
return
yield (await response.json())[0]["generated_text"].strip()

def format_prompt_mistral(messages: Messages, do_continue: bool = False) -> str:

g4f/Provider/hf/HuggingFaceMedia.py (new file, 175 lines added)
@@ -0,0 +1,175 @@
from __future__ import annotations

import random
import requests

from ...providers.types import Messages
from ...requests import StreamSession, raise_for_status
from ...errors import ModelNotSupportedError
from ...providers.helper import format_image_prompt
from ...providers.base_provider import AsyncGeneratorProvider, ProviderModelMixin
from ...providers.response import ProviderInfo, ImageResponse, VideoResponse
from ...image.copy_images import save_response_media
from ... import debug

class HuggingFaceMedia(AsyncGeneratorProvider, ProviderModelMixin):
label = "HuggingFace (Image / Video Generation)"
parent = "HuggingFace"
url = "https://huggingface.co"
working = True
needs_auth = True

tasks = ["text-to-image", "text-to-video"]
provider_mapping: dict[str, dict] = {}
task_mapping: dict[str, str] = {}

@classmethod
def get_models(cls, **kwargs) -> list[str]:
if not cls.models:
url = "https://huggingface.co/api/models?inference=warm&expand[]=inferenceProviderMapping"
response = requests.get(url)
if response.ok:
models = response.json()
cls.models = [
model["id"]
for model in models
if [
provider
for provider in model.get("inferenceProviderMapping")
if provider.get("status") == "live" and provider.get("task") in cls.tasks
]
]
cls.task_mapping = {
model["id"]: [
provider.get("task")
for provider in model.get("inferenceProviderMapping")
].pop()
for model in models
}
else:
cls.models = []
return cls.models

@classmethod
async def get_mapping(cls, model: str, api_key: str = None):
if model in cls.provider_mapping:
return cls.provider_mapping[model]
headers = {
'Content-Type': 'application/json',
}
if api_key is not None:
headers["Authorization"] = f"Bearer {api_key}"
async with StreamSession(
timeout=30,
headers=headers,
) as session:
async with session.get(f"https://huggingface.co/api/models/{model}?expand[]=inferenceProviderMapping") as response:
await raise_for_status(response)
model_data = await response.json()
cls.provider_mapping[model] = {key: value for key, value in model_data.get("inferenceProviderMapping").items() if value["status"] == "live"}
return cls.provider_mapping[model]

@classmethod
async def create_async_generator(
cls,
model: str,
messages: Messages,
api_key: str = None,
extra_data: dict = {},
prompt: str = None,
proxy: str = None,
timeout: int = 0,
**kwargs
):
provider_mapping = await cls.get_mapping(model, api_key)
headers = {
'Accept-Encoding': 'gzip, deflate',
'Content-Type': 'application/json',
}
new_mapping = {
"hf-free" if key == "hf-inference" else key: value for key, value in provider_mapping.items()
if key in ["replicate", "together", "hf-inference"]
}
provider_mapping = {**new_mapping, **provider_mapping}
last_response = None
for provider_key, provider in provider_mapping.items():
yield ProviderInfo(**{**cls.get_dict(), "label": f"HuggingFace ({provider_key})", "url": f"{cls.url}/{model}"})

api_base = f"https://router.huggingface.co/{provider_key}"
task = provider["task"]
provider_id = provider["providerId"]
if task not in cls.tasks:
raise ModelNotSupportedError(f"Model is not supported: {model} in: {cls.__name__} task: {task}")

prompt = format_image_prompt(messages, prompt)
if task == "text-to-video":
extra_data = {
"num_inference_steps": 20,
"video_size": "landscape_16_9",
**extra_data
}
else:
extra_data = {
"width": 1024,
"height": 1024,
**extra_data
}
if provider_key == "fal-ai":
url = f"{api_base}/{provider_id}"
data = {
"prompt": prompt,
"image_size": "square_hd",
**extra_data
}
elif provider_key == "replicate":
url = f"{api_base}/v1/models/{provider_id}/prediction"
data = {
"input": {
"prompt": prompt,
**extra_data
}
}
elif provider_key in ("hf-inference", "hf-free"):
api_base = "https://api-inference.huggingface.co"
url = f"{api_base}/models/{provider_id}"
data = {
"inputs": prompt,
"parameters": {
"seed": random.randint(0, 2**32),
**extra_data
}
}
elif task == "text-to-image":
url = f"{api_base}/v1/images/generations"
data = {
"response_format": "url",
"prompt": prompt,
"model": provider_id,
**extra_data
}

async with StreamSession(
headers=headers if provider_key == "free" or api_key is None else {**headers, "Authorization": f"Bearer {api_key}"},
proxy=proxy,
timeout=timeout
) as session:
async with session.post(url, json=data) as response:
if response.status in (400, 401, 402):
last_response = response
debug.error(f"{cls.__name__}: Error {response.status} with {provider_key} and {provider_id}")
continue
if response.status == 404:
raise ModelNotSupportedError(f"Model is not supported: {model}")
await raise_for_status(response)
async for chunk in save_response_media(response, prompt):
yield chunk
return
result = await response.json()
if "video" in result:
yield VideoResponse(result["video"]["url"], prompt)
elif task == "text-to-image":
yield ImageResponse([item["url"] for item in result.get("images", result.get("data"))], prompt)
elif task == "text-to-video":
yield VideoResponse(result["output"], prompt)
return
await raise_for_status(last_response)
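Illustrative only: how the new_mapping merge in create_async_generator above orders the fallback attempts; the mapping values are assumed, only the keys matter here:

provider_mapping = {
    "hf-inference": {"task": "text-to-image", "providerId": "some/model"},   # assumed
    "fal-ai": {"task": "text-to-image", "providerId": "fal-ai/some-model"},  # assumed
}
new_mapping = {
    "hf-free" if key == "hf-inference" else key: value
    for key, value in provider_mapping.items()
    if key in ["replicate", "together", "hf-inference"]
}
# new_mapping == {"hf-free": {...}}
merged = {**new_mapping, **provider_mapping}
# Keys are tried in insertion order: "hf-free", "hf-inference", "fal-ai",
# so the "hf-free" alias for hf-inference is attempted before the original keys.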
@@ -9,6 +9,7 @@ from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin
from .HuggingChat import HuggingChat
from .HuggingFaceAPI import HuggingFaceAPI
from .HuggingFaceInference import HuggingFaceInference
from .HuggingFaceMedia import HuggingFaceMedia
from .models import model_aliases, vision_models, default_vision_model
from ... import debug

@@ -51,6 +52,12 @@ class HuggingFace(AsyncGeneratorProvider, ProviderModelMixin):
debug.error(f"{cls.__name__} {type(e).__name__}; {e}")
if not cls.image_models:
cls.get_models()
try:
async for chunk in HuggingFaceMedia.create_async_generator(model, messages, **kwargs):
yield chunk
return
except ModelNotSupportedError:
pass
if model in cls.image_models:
if "api_key" not in kwargs:
async for chunk in HuggingChat.create_async_generator(model, messages, **kwargs):

@@ -47,9 +47,5 @@ extra_models = [
"NousResearch/Hermes-3-Llama-3.1-8B",
]
default_vision_model = "meta-llama/Llama-3.2-11B-Vision-Instruct"
default_llama_model = "meta-llama/Llama-3.3-70B-Instruct"
vision_models = [default_vision_model, "Qwen/Qwen2-VL-7B-Instruct"]
llama_models = {
"name": "llama-3",
"text": "meta-llama/Llama-3.3-70B-Instruct",
"vision": "meta-llama/Llama-3.2-11B-Vision-Instruct",
}
@@ -67,7 +67,6 @@ class BlackForestLabs_Flux1Dev(AsyncGeneratorProvider, ProviderModelMixin):
zerogpu_uuid: str = "[object Object]",
**kwargs
) -> AsyncResult:
model = cls.get_model(model)
async with StreamSession(impersonate="chrome", proxy=proxy) as session:
prompt = format_image_prompt(messages, prompt)
data = [prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps]

@@ -37,8 +37,6 @@ class BlackForestLabs_Flux1Schnell(AsyncGeneratorProvider, ProviderModelMixin):
randomize_seed: bool = True,
**kwargs
) -> AsyncResult:

model = cls.get_model(model)
width = max(32, width - (width % 8))
height = max(32, height - (height % 8))
prompt = format_image_prompt(messages, prompt)

@@ -24,9 +24,14 @@ class CohereForAI_C4AI_Command(AsyncGeneratorProvider, ProviderModelMixin):
"command-r": "command-r",
"command-r7b": "command-r7b-12-2024",
}

models = list(model_aliases.keys())

@classmethod
def get_model(cls, model: str, **kwargs) -> str:
if model in cls.model_aliases.values():
return model
return super().get_model(model, **kwargs)

@classmethod
async def create_async_generator(
cls, model: str, messages: Messages,

@@ -203,7 +203,21 @@ class OpenaiChat(AsyncAuthedProvider, ProviderModelMixin):
Returns:
A list of messages with the user input and the image, if any
"""
# Create a message object with the user role and the content
# merged_messages = []
# last_message = None
# for message in messages:
# current_message = last_message
# if current_message is not None:
# if current_message["role"] == message["role"]:
# current_message["content"] += "\n" + message["content"]
# else:
# merged_messages.append(current_message)
# last_message = message.copy()
# else:
# last_message = message.copy()
# if last_message is not None:
# merged_messages.append(last_message)

messages = [{
"id": str(uuid.uuid4()),
"author": {"role": message["role"]},

@@ -39,7 +39,7 @@ from g4f.client import AsyncClient, ChatCompletion, ImagesResponse, convert_to_p
from g4f.providers.response import BaseConversation, JsonConversation
from g4f.client.helper import filter_none
from g4f.image import is_data_an_media
from g4f.image.copy_images import images_dir, copy_images, get_source_url
from g4f.image.copy_images import images_dir, copy_media, get_source_url
from g4f.errors import ProviderNotFoundError, ModelNotFoundError, MissingAuthError, NoValidHarFileError
from g4f.cookies import read_cookie_files, get_cookies_dir
from g4f.Provider import ProviderType, ProviderUtils, __providers__
@@ -594,10 +594,10 @@ class Api:
ssl = False
if source_url is not None:
try:
await copy_images(
await copy_media(
[source_url],
target=target, ssl=ssl)
debug.log(f"Image copied from {source_url}")
debug.log(f"File copied from {source_url}")
except Exception as e:
debug.error(f"Download failed: {source_url}\n{type(e).__name__}: {e}")
return RedirectResponse(url=source_url)

@@ -9,7 +9,7 @@ import aiohttp
import base64
from typing import Union, AsyncIterator, Iterator, Awaitable, Optional

from ..image.copy_images import copy_images
from ..image.copy_images import copy_media
from ..typing import Messages, ImageType
from ..providers.types import ProviderType, BaseRetryProvider
from ..providers.response import *
@@ -532,7 +532,7 @@ class Images:
images = await asyncio.gather(*[get_b64_from_url(image) for image in response.get_list()])
else:
# Save locally for None (default) case
images = await copy_images(response.get_list(), response.get("cookies"), proxy)
images = await copy_media(response.get_list(), response.get("cookies"), proxy)
images = [Image.model_construct(url=f"/images/{os.path.basename(image)}", revised_prompt=response.alt) for image in images]

return ImagesResponse.model_construct(

@@ -15,7 +15,7 @@
</head>
<body>

<h1>QR Scanner & QR Code Generator</h1>
<h1>QR Scanner & QR Code</h1>

<h2>QR Code Scanner</h2>
<video id="video"></video>
@@ -25,9 +25,8 @@
<button id="toggleFlash">Toggle Flash</button>
<p id="cam-status"></p>

<h2>Generate QR Code</h2>
<h2>QR Code</h2>
<div id="qrcode"></div>

<button id="generateQRCode">Generate QR Code</button>

<script type="module">

@@ -80,7 +80,11 @@ if (window.markdownit) {
.replaceAll('<code>', '<code class="language-plaintext">')
.replaceAll('<i class="', '<i class="')
.replaceAll('"></i>', '"></i>')
.replaceAll('<iframe type="text/html" src="', '<iframe type="text/html" frameborder="0" allow="fullscreen" src="')
.replaceAll('<video controls src="', '<video controls width="400" src="')
.replaceAll('"></video>', '"></video>')
.replaceAll('<audio controls src="', '<audio controls src="')
.replaceAll('"></audio>', '"></audio>')
.replaceAll('<iframe type="text/html" src="', '<iframe type="text/html" frameborder="0" allow="fullscreen" height="390" width="640" src="')
.replaceAll('"></iframe>', `?enablejsapi=1&origin=${new URL(location.href).origin}"></iframe>`)
}
}
@@ -229,7 +233,7 @@ function register_message_images() {
let seed = Math.floor(Date.now() / 1000);
newPath = `https://image.pollinations.ai/prompt/${newPath}?seed=${seed}&nologo=true`;
let downloadUrl = newPath;
if (document.getElementById("download_images")?.checked) {
if (document.getElementById("download_media")?.checked) {
downloadUrl = `/images/${filename}?url=${escapeHtml(newPath)}`;
}
const link = document.createElement("a");
@@ -862,12 +866,6 @@ async function add_message_chunk(message, message_id, provider, scroll, finish_m
content_map.inner.innerHTML = markdown_render(message.preview);
await register_message_images();
}
} else if (message.type == "audio") {
audio = new Audio(message.audio);
audio.controls = true;
content_map.inner.appendChild(audio);
audio.play();
reloadConversation = false;
} else if (message.type == "content") {
message_storage[message_id] += message.content;
update_message(content_map, message_id, null, scroll);
@@ -1090,7 +1088,7 @@ const ask_gpt = async (message_id, message_index = -1, regenerate = false, provi
} else {
api_key = get_api_key_by_provider(provider);
}
const download_images = document.getElementById("download_images")?.checked;
const download_media = document.getElementById("download_media")?.checked;
let api_base;
if (provider == "Custom") {
api_base = document.getElementById("api_base")?.value;
@@ -1119,7 +1117,7 @@ const ask_gpt = async (message_id, message_index = -1, regenerate = false, provi
provider: provider,
messages: messages,
action: action,
download_images: download_images,
download_media: download_media,
api_key: api_key,
api_base: api_base,
ignored: ignored,
@@ -2723,6 +2721,9 @@ async function load_provider_models(provider=null) {
option.value = model.model;
option.dataset.label = model.model;
option.text = `${model.model}${model.image ? " (Image Generation)" : ""}${model.vision ? " (Image Upload)" : ""}`;
if (model.task) {
option.text += ` (${model.task})`;
}
modelProvider.appendChild(option);
if (model.default) {
defaultIndex = i;

@@ -8,7 +8,7 @@ from flask import send_from_directory
from inspect import signature

from ...errors import VersionNotFoundError
from ...image.copy_images import copy_images, ensure_images_dir, images_dir
from ...image.copy_images import copy_media, ensure_images_dir, images_dir
from ...tools.run_tools import iter_run_tools
from ...Provider import ProviderUtils, __providers__
from ...providers.base_provider import ProviderModelMixin
@@ -53,6 +53,7 @@ class Api:
"default": model == provider.default_model,
"vision": getattr(provider, "default_vision_model", None) == model or model in getattr(provider, "vision_models", []),
"image": False if provider.image_models is None else model in provider.image_models,
"task": None if not hasattr(provider, "task_mapping") else provider.task_mapping[model] if model in provider.task_mapping else None
}
for model in models
]
@@ -127,7 +128,7 @@ class Api:
**kwargs
}

def _create_response_stream(self, kwargs: dict, conversation_id: str, provider: str, download_images: bool = True) -> Iterator:
def _create_response_stream(self, kwargs: dict, conversation_id: str, provider: str, download_media: bool = True) -> Iterator:
def decorated_log(text: str, file = None):
debug.logs.append(text)
if debug.logging:
@@ -154,7 +155,7 @@ class Api:
if hasattr(provider_handler, "get_parameters"):
yield self._format_json("parameters", provider_handler.get_parameters(as_json=True))
try:
result = iter_run_tools(ChatCompletion.create, **{**kwargs, "model": model, "provider": provider_handler})
result = iter_run_tools(ChatCompletion.create, **{**kwargs, "model": model, "provider": provider_handler, "download_media": download_media})
for chunk in result:
if isinstance(chunk, ProviderInfo):
yield self.handle_provider(chunk, model)
@@ -182,13 +183,13 @@ class Api:
yield self._format_json("preview", chunk.to_string())
elif isinstance(chunk, ImagePreview):
yield self._format_json("preview", chunk.to_string(), images=chunk.images, alt=chunk.alt)
elif isinstance(chunk, ImageResponse):
images = chunk
if download_images or chunk.get("cookies"):
elif isinstance(chunk, (ImageResponse, VideoResponse)):
media = chunk
if download_media or chunk.get("cookies"):
chunk.alt = format_image_prompt(kwargs.get("messages"), chunk.alt)
images = asyncio.run(copy_images(chunk.get_list(), chunk.get("cookies"), chunk.get("headers"), proxy=proxy, alt=chunk.alt))
images = ImageResponse(images, chunk.alt)
yield self._format_json("content", str(images), images=chunk.get_list(), alt=chunk.alt)
media = asyncio.run(copy_media(chunk.get_list(), chunk.get("cookies"), chunk.get("headers"), proxy=proxy, alt=chunk.alt))
media = ImageResponse(media, chunk.alt) if isinstance(chunk, ImageResponse) else VideoResponse(media, chunk.alt)
yield self._format_json("content", str(media), images=chunk.get_list(), alt=chunk.alt)
elif isinstance(chunk, SynthesizeData):
yield self._format_json("synthesize", chunk.get_dict())
elif isinstance(chunk, TitleGeneration):
@@ -205,8 +206,8 @@ class Api:
yield self._format_json("reasoning", **chunk.get_dict())
elif isinstance(chunk, YouTube):
yield self._format_json("content", chunk.to_string())
elif isinstance(chunk, Audio):
yield self._format_json("audio", str(chunk))
elif isinstance(chunk, AudioResponse):
yield self._format_json("content", str(chunk))
elif isinstance(chunk, DebugResponse):
yield self._format_json("log", chunk.log)
elif isinstance(chunk, RawResponse):

@@ -14,9 +14,8 @@ from typing import Generator
from pathlib import Path
from urllib.parse import quote_plus
from hashlib import sha256
from werkzeug.utils import secure_filename

from ...image import is_allowed_extension, to_image
from ...image import is_allowed_extension
from ...client.service import convert_to_provider
from ...providers.asyncio import to_sync_generator
from ...client.helper import filter_markdown
@@ -25,6 +24,7 @@ from ...tools.run_tools import iter_run_tools
from ...errors import ProviderNotFoundError
from ...image import is_allowed_extension
from ...cookies import get_cookies_dir
from ...image.copy_images import secure_filename
from ... import ChatCompletion
from ... import models
from .api import Api
@@ -130,16 +130,14 @@ class Backend_Api(Api):
if model != "default" and model in models.demo_models:
json_data["provider"] = random.choice(models.demo_models[model][1])
else:
if not model or model == "default":
json_data["model"] = models.demo_models["default"][0].name
json_data["provider"] = random.choice(models.demo_models["default"][1])
json_data["provider"] = models.HuggingFace
kwargs = self._prepare_conversation_kwargs(json_data)
return self.app.response_class(
self._create_response_stream(
kwargs,
json_data.get("conversation_id"),
json_data.get("provider"),
json_data.get("download_images", True),
json_data.get("download_media", True),
),
mimetype='text/event-stream'
)
@@ -331,21 +329,26 @@ class Backend_Api(Api):
[f.write(f"{filename}\n") for filename in filenames]
return {"bucket_id": bucket_id, "files": filenames, "media": media}

@app.route('/backend-api/v2/files/<bucket_id>/media/<filename>', methods=['GET'])
def get_media(bucket_id, filename):
bucket_id = secure_filename(bucket_id)
bucket_dir = get_bucket_dir(bucket_id)
@app.route('/files/<bucket_id>/media/<filename>', methods=['GET'])
def get_media(bucket_id, filename, dirname: str = None):
bucket_dir = get_bucket_dir(secure_filename(bucket_id), secure_filename(dirname))
media_dir = os.path.join(bucket_dir, "media")
if os.path.exists(media_dir):
return send_from_directory(os.path.abspath(media_dir), filename)
return "File not found", 404
return "Not found", 404

@app.route('/files/<dirname>/<bucket_id>/media/<filename>', methods=['GET'])
def get_media_sub(dirname, bucket_id, filename):
return get_media(bucket_id, filename, dirname)

@app.route('/backend-api/v2/files/<bucket_id>/<filename>', methods=['PUT'])
def upload_file(bucket_id, filename):
bucket_id = secure_filename(bucket_id)
bucket_dir = get_bucket_dir(bucket_id)
def upload_file(bucket_id, filename, dirname: str = None):
bucket_dir = secure_filename(bucket_id if dirname is None else dirname)
bucket_dir = get_bucket_dir(bucket_dir)
filename = secure_filename(filename)
bucket_path = Path(bucket_dir)
if dirname is not None:
bucket_path = bucket_path / secure_filename(bucket_id)

if not supports_filename(filename):
return jsonify({"error": {"message": f"File type not allowed"}}), 400
@@ -366,6 +369,10 @@ class Backend_Api(Api):
except Exception as e:
return jsonify({"error": {"message": f"Error uploading file: {str(e)}"}}), 500

@app.route('/backend-api/v2/files/<bucket_id>/<dirname>/<filename>', methods=['PUT'])
def upload_file_sub(bucket_id, filename, dirname):
return upload_file(bucket_id, filename, dirname)

@app.route('/backend-api/v2/upload_cookies', methods=['POST'])
def upload_cookies():
file = None

@@ -10,7 +10,10 @@ from urllib.parse import quote, unquote
from aiohttp import ClientSession, ClientError

from ..typing import Optional, Cookies
from ..requests.aiohttp import get_connector
from ..requests.aiohttp import get_connector, StreamResponse
from ..image import EXTENSIONS_MAP
from ..tools.files import get_bucket_dir
from ..providers.response import ImageResponse, AudioResponse, VideoResponse
from ..Provider.template import BackendApi
from . import is_accepted_format, extract_data_uri
from .. import debug
@@ -18,10 +21,10 @@ from .. import debug
# Directory for storing generated images
images_dir = "./generated_images"

def get_image_extension(image: str) -> str:
def get_media_extension(image: str) -> str:
"""Extract image extension from URL or filename, default to .jpg"""
match = re.search(r"\.(jpe?g|png|webp)$", image, re.IGNORECASE)
return f".{match.group(1).lower()}" if match else ".jpg"
match = re.search(r"\.(jpe?g|png|webp|mp4|mp3|wav)[?$]", image, re.IGNORECASE)
return f".{match.group(1).lower()}" if match else ""

def ensure_images_dir():
"""Create images directory if it doesn't exist"""
@@ -35,7 +38,42 @@ def get_source_url(image: str, default: str = None) -> str:
return decoded_url
return default

async def copy_images(
def secure_filename(filename: str) -> str:
# Keep letters, numbers, basic punctuation and all Unicode chars
filename = re.sub(
r'[^\w.,_-]+',
'_',
unquote(filename).strip(),
flags=re.UNICODE
)
filename = filename[:100].strip(".,_-")
return filename

async def save_response_media(response: StreamResponse, prompt: str):
content_type = response.headers["content-type"]
if content_type in EXTENSIONS_MAP or content_type.startswith("audio/"):
extension = EXTENSIONS_MAP[content_type] if content_type in EXTENSIONS_MAP else content_type[6:].replace("mpeg", "mp3")
bucket_id = str(uuid.uuid4())
dirname = str(int(time.time()))
bucket_dir = get_bucket_dir(bucket_id, dirname)
media_dir = os.path.join(bucket_dir, "media")
os.makedirs(media_dir, exist_ok=True)
filename = secure_filename(f"{content_type[0:5] if prompt is None else prompt}.{extension}")
newfile = os.path.join(media_dir, filename)
with open(newfile, 'wb') as f:
async for chunk in response.iter_content() if hasattr(response, "iter_content") else response.content.iter_any():
f.write(chunk)
media_url = f"/files/{dirname}/{bucket_id}/media/{filename}"
if response.method == "GET":
media_url = f"{media_url}?url={str(response.url)}"
if content_type.startswith("audio/"):
yield AudioResponse(media_url)
elif content_type.startswith("video/"):
yield VideoResponse(media_url, prompt)
else:
yield ImageResponse(media_url, prompt)

async def copy_media(
images: list[str],
cookies: Optional[Cookies] = None,
headers: Optional[dict] = None,
@@ -60,33 +98,18 @@ async def copy_images(
) as session:
async def copy_image(image: str, target: str = None) -> str:
"""Process individual image and return its local URL"""
# Skip if image is already local
if image.startswith("/"):
return image
target_path = target
if target_path is None:
# Generate filename components
file_hash = hashlib.sha256(image.encode()).hexdigest()[:16]
timestamp = int(time.time())

# Sanitize alt text for filename (Unicode-safe)
if alt:
# Keep letters, numbers, basic punctuation and all Unicode chars
clean_alt = re.sub(
r'[^\w\s.-]', # Allow all Unicode word chars
'_',
unquote(alt).strip(),
flags=re.UNICODE
)
clean_alt = re.sub(r'[\s_]+', '_', clean_alt)[:100]
else:
clean_alt = "image"

# Build safe filename with full Unicode support
extension = get_image_extension(image)
filename = (
f"{timestamp}_"
f"{clean_alt}_"
f"{file_hash}"
f"{extension}"
)
filename = secure_filename("".join((
f"{int(time.time())}_",
(f"{alt}_" if alt else ""),
f"{hashlib.sha256(image.encode()).hexdigest()[:16]}",
f"{get_media_extension(image)}"
)))
target_path = os.path.join(images_dir, filename)
try:
# Handle different image types
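Illustrative only: a self-contained copy of the secure_filename helper added above, showing its effect on a couple of sample inputs (outputs worked out by hand):

import re
from urllib.parse import unquote

def secure_filename(filename: str) -> str:
    # Keep letters, numbers, basic punctuation and all Unicode chars
    filename = re.sub(r'[^\w.,_-]+', '_', unquote(filename).strip(), flags=re.UNICODE)
    return filename[:100].strip(".,_-")

print(secure_filename("night%20sky.png"))          # night_sky.png
print(secure_filename("  a red fox / snow.jpg "))  # a_red_fox_snow.jpg
# The rewritten copy_image() above then builds local names of the form
# f"{timestamp}_{alt}_{sha256[:16]}{extension}" from this helper.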
@@ -234,11 +234,6 @@ llama_3_1_405b = Model(
)

# llama 3.2
llama_3 = VisionModel(
name = "llama-3",
base_provider = "Meta Llama",
best_provider = IterListProvider([HuggingChat, HuggingFace])
)

llama_3_2_1b = Model(
name = "llama-3.2-1b",
@@ -977,7 +972,6 @@ class ModelUtils:

demo_models = {
"default": [llama_3, [HuggingFace]],
llama_3_2_11b.name: [llama_3_2_11b, [HuggingChat]],
qwen_2_vl_7b.name: [qwen_2_vl_7b, [HuggingFaceAPI]],
deepseek_r1.name: [deepseek_r1, [HuggingFace, PollinationsAI]],

@@ -25,15 +25,12 @@ async def raise_for_status_async(response: Union[StreamResponse, ClientResponse]
return
if message is None:
content_type = response.headers.get("content-type", "")
# if content_type.startswith("application/json"):
# try:
# data = await response.json()
# message = data.get("error")
# if isinstance(message, dict):
# message = data.get("message")
# except Exception:
# pass
# else:
if content_type.startswith("application/json"):
message = await response.json()
message = message.get("error", message)
if isinstance(message, dict):
message = message.get("message", message)
else:
text = (await response.text()).strip()
is_html = content_type.startswith("text/html") or text.startswith("<!DOCTYPE")
message = "HTML content" if is_html else text
@@ -47,7 +44,9 @@ async def raise_for_status_async(response: Union[StreamResponse, ClientResponse]
elif response.status == 403 and is_openai(text):
raise ResponseStatusError(f"Response {response.status}: OpenAI Bot detected")
elif response.status == 502:
raise ResponseStatusError(f"Response {response.status}: Bad gateway")
raise ResponseStatusError(f"Response {response.status}: Bad Gateway")
elif response.status == 504:
raise RateLimitError(f"Response {response.status}: Gateway Timeout ")
else:
raise ResponseStatusError(f"Response {response.status}: {message}")

@@ -70,6 +69,8 @@ def raise_for_status(response: Union[Response, StreamResponse, ClientResponse, R
elif response.status_code == 403 and is_openai(response.text):
raise ResponseStatusError(f"Response {response.status_code}: OpenAI Bot detected")
elif response.status_code == 502:
raise ResponseStatusError(f"Response {response.status_code}: Bad gateway")
raise ResponseStatusError(f"Response {response.status_code}: Bad Gateway")
elif response.status_code == 504:
raise RateLimitError(f"Response {response.status_code}: Gateway Timeout ")
else:
raise ResponseStatusError(f"Response {response.status_code}: {message}")
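Illustrative only: how the un-commented JSON branch in raise_for_status_async above reduces a typical error payload to a message string (the payload shape is an assumption):

data = {"error": {"message": "Model is currently loading"}}  # assumed upstream error body
message = data.get("error", data)
if isinstance(message, dict):
    message = message.get("message", message)
print(message)  # Model is currently loading
# A flat {"error": "..."} body yields the string directly.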