Set default model in HuggingFaceMedia

Improve handling of shared chats
Show api_key input if required
This commit is contained in:
hlohaus
2025-03-26 01:32:05 +01:00
parent 6767000604
commit ce500f0d49
18 changed files with 206 additions and 111 deletions

19
etc/examples/video.py Normal file
View File

@@ -0,0 +1,19 @@
import g4f.Provider
from g4f.client import Client
client = Client(
provider=g4f.Provider.HuggingFaceMedia,
api_key="hf_***" # Your API key here
)
video_models = client.models.get_video()
print(video_models)
result = client.media.generate(
model=video_models[0],
prompt="G4F AI technology is the best in the world.",
response_format="url"
)
print(result.data[0].url)

View File

@@ -66,6 +66,8 @@ class HuggingFaceMedia(AsyncGeneratorProvider, ProviderModelMixin):
for provider_data in provider_keys:
prepend_models.append(f"{model}:{provider_data.get('provider')}")
cls.models = prepend_models + [model for model in new_models if model not in prepend_models]
cls.image_models = [model for model, task in cls.task_mapping.items() if task == "text-to-image"]
cls.video_models = [model for model, task in cls.task_mapping.items() if task == "text-to-video"]
else:
cls.models = []
return cls.models
@@ -99,12 +101,14 @@ class HuggingFaceMedia(AsyncGeneratorProvider, ProviderModelMixin):
prompt: str = None,
proxy: str = None,
timeout: int = 0,
aspect_ratio: str = "1:1",
aspect_ratio: str = None,
**kwargs
):
selected_provider = None
if ":" in model:
if model and ":" in model:
model, selected_provider = model.split(":", 1)
elif not model:
model = cls.get_models()[0]
provider_mapping = await cls.get_mapping(model, api_key)
headers = {
'Accept-Encoding': 'gzip, deflate',
@@ -133,11 +137,11 @@ class HuggingFaceMedia(AsyncGeneratorProvider, ProviderModelMixin):
extra_data = {
"num_inference_steps": 20,
"resolution": "480p",
"aspect_ratio": aspect_ratio,
"aspect_ratio": "16:9" if aspect_ratio is None else aspect_ratio,
**extra_data
}
else:
extra_data = use_aspect_ratio(extra_data, aspect_ratio)
extra_data = use_aspect_ratio(extra_data, "1:1" if aspect_ratio is None else aspect_ratio)
if provider_key == "fal-ai":
url = f"{api_base}/{provider_id}"
data = {

View File

@@ -30,6 +30,7 @@ class Grok(AsyncAuthedProvider, ProviderModelMixin):
default_model = "grok-3"
models = [default_model, "grok-3-thinking", "grok-2"]
model_aliases = {"grok-3-r1": "grok-3-thinking"}
@classmethod
async def on_auth_async(cls, cookies: Cookies = None, proxy: str = None, **kwargs) -> AsyncIterator:
@@ -73,7 +74,7 @@ class Grok(AsyncAuthedProvider, ProviderModelMixin):
"sendFinalMetadata": True,
"customInstructions": "",
"deepsearchPreset": "",
"isReasoning": model.endswith("-thinking"),
"isReasoning": model.endswith("-thinking") or model.endswith("-r1"),
}
@classmethod

View File

@@ -92,7 +92,7 @@ class OpenaiTemplate(AsyncGeneratorProvider, ProviderModelMixin, RaiseErrorMixin
}
async with session.post(f"{api_base.rstrip('/')}/images/generations", json=data, ssl=cls.ssl) as response:
data = await response.json()
cls.raise_error(data)
cls.raise_error(data, response.status)
await raise_for_status(response)
yield ImageResponse([image["url"] for image in data["data"]], prompt)
return
@@ -135,7 +135,7 @@ class OpenaiTemplate(AsyncGeneratorProvider, ProviderModelMixin, RaiseErrorMixin
content_type = response.headers.get("content-type", "text/event-stream" if stream else "application/json")
if content_type.startswith("application/json"):
data = await response.json()
cls.raise_error(data)
cls.raise_error(data, response.status)
await raise_for_status(response)
choice = data["choices"][0]
if "content" in choice["message"] and choice["message"]["content"]:

View File

@@ -10,6 +10,7 @@ from email.utils import formatdate
import os.path
import hashlib
import asyncio
from urllib.parse import quote_plus
from fastapi import FastAPI, Response, Request, UploadFile, Depends
from fastapi.middleware.wsgi import WSGIMiddleware
from fastapi.responses import StreamingResponse, RedirectResponse, HTMLResponse, JSONResponse
@@ -562,6 +563,10 @@ class Api:
})
async def get_media(filename, request: Request):
target = os.path.join(images_dir, os.path.basename(filename))
if not os.path.isfile(target):
other_name = os.path.join(images_dir, os.path.basename(quote_plus(filename)))
if os.path.isfile(other_name):
target = other_name
ext = os.path.splitext(filename)[1][1:]
mime_type = EXTENSIONS_MAP.get(ext)
stat_result = SimpleNamespace()

View File

@@ -19,7 +19,7 @@ from ..providers.asyncio import to_sync_generator
from ..Provider.needs_auth import BingCreateImages, OpenaiAccount
from ..tools.run_tools import async_iter_run_tools, iter_run_tools
from .stubs import ChatCompletion, ChatCompletionChunk, Image, ImagesResponse, UsageModel, ToolCallModel
from .image_models import ImageModels
from .image_models import MediaModels
from .types import IterResponse, ImageProvider, Client as BaseClient
from .service import get_model_and_provider, convert_to_provider
from .helper import find_stop, filter_json, filter_none, safe_aclose
@@ -267,8 +267,11 @@ class Client(BaseClient):
) -> None:
super().__init__(**kwargs)
self.chat: Chat = Chat(self, provider)
if image_provider is None:
image_provider = provider
self.models: MediaModels = MediaModels(self, image_provider)
self.images: Images = Images(self, image_provider)
self.media: Images = Images(self, image_provider)
self.media: Images = self.images
class Completions:
def __init__(self, client: Client, provider: Optional[ProviderType] = None):
@@ -349,7 +352,6 @@ class Images:
def __init__(self, client: Client, provider: Optional[ProviderType] = None):
self.client: Client = client
self.provider: Optional[ProviderType] = provider
self.models: ImageModels = ImageModels(client)
def generate(
self,
@@ -369,7 +371,7 @@ class Images:
if provider is None:
provider_handler = self.provider
if provider_handler is None:
provider_handler = self.models.get(model, default)
provider_handler = self.client.models.get(model, default)
elif isinstance(provider, str):
provider_handler = convert_to_provider(provider)
else:
@@ -385,19 +387,21 @@ class Images:
provider: Optional[ProviderType] = None,
response_format: Optional[str] = None,
proxy: Optional[str] = None,
api_key: Optional[str] = None,
**kwargs
) -> ImagesResponse:
provider_handler = await self.get_provider_handler(model, provider, BingCreateImages)
provider_name = provider_handler.__name__ if hasattr(provider_handler, "__name__") else type(provider_handler).__name__
if proxy is None:
proxy = self.client.proxy
if api_key is None:
api_key = self.client.api_key
error = None
response = None
if isinstance(provider_handler, IterListProvider):
for provider in provider_handler.providers:
try:
response = await self._generate_image_response(provider, provider.__name__, model, prompt, **kwargs)
response = await self._generate_image_response(provider, provider.__name__, model, prompt, proxy=proxy, **kwargs)
if response is not None:
provider_name = provider.__name__
break
@@ -405,7 +409,7 @@ class Images:
error = e
debug.error(f"{provider.__name__} {type(e).__name__}: {e}")
else:
response = await self._generate_image_response(provider_handler, provider_name, model, prompt, **kwargs)
response = await self._generate_image_response(provider_handler, provider_name, model, prompt, proxy=proxy, api_key=api_key, **kwargs)
if isinstance(response, MediaResponse):
return await self._process_image_response(
@@ -534,7 +538,7 @@ class Images:
else:
# Save locally for None (default) case
images = await copy_media(response.get_list(), response.get("cookies"), proxy)
images = [Image.model_construct(url=f"/media/{os.path.basename(image)}", revised_prompt=response.alt) for image in images]
images = [Image.model_construct(url=image, revised_prompt=response.alt) for image in images]
return ImagesResponse.model_construct(
created=int(time.time()),
@@ -552,6 +556,9 @@ class AsyncClient(BaseClient):
) -> None:
super().__init__(**kwargs)
self.chat: AsyncChat = AsyncChat(self, provider)
if image_provider is None:
image_provider = provider
self.models: MediaModels = MediaModels(self, image_provider)
self.images: AsyncImages = AsyncImages(self, image_provider)
self.media: AsyncImages = self.images
@@ -635,7 +642,6 @@ class AsyncImages(Images):
def __init__(self, client: AsyncClient, provider: Optional[ProviderType] = None):
self.client: AsyncClient = client
self.provider: Optional[ProviderType] = provider
self.models: ImageModels = ImageModels(client)
async def generate(
self,

View File

@@ -1,15 +1,43 @@
from __future__ import annotations
from ..models import ModelUtils
from ..models import ModelUtils, ImageModel
from ..Provider import ProviderUtils
from ..providers.types import ProviderType
class ImageModels():
def __init__(self, client):
class MediaModels():
def __init__(self, client, provider: ProviderType = None):
self.client = client
self.provider = provider
def get(self, name, default=None):
def get(self, name, default=None) -> ProviderType:
if name in ModelUtils.convert:
return ModelUtils.convert[name].best_provider
if name in ProviderUtils.convert:
return ProviderUtils.convert[name]
return default
def get_all(self, api_key: str = None, **kwargs) -> list[str]:
if self.provider is None:
return []
if api_key is None:
api_key = self.client.api_key
return self.provider.get_models(
**kwargs,
**{} if api_key is None else {"api_key": api_key}
)
def get_image(self, **kwargs) -> list[str]:
if self.provider is None:
return [model_id for model_id, model in ModelUtils.convert.items() if isinstance(model, ImageModel)]
self.get_all(**kwargs)
if hasattr(self.provider, "image_models"):
return self.provider.image_models
return []
def get_video(self, **kwargs) -> list[str]:
if self.provider is None:
return []
self.get_all(**kwargs)
if hasattr(self.provider, "video_models"):
return self.provider.video_models
return []

View File

@@ -61,9 +61,12 @@
const gpt_image = '<img src="/static/img/gpt.png" alt="your avatar">';
</script>
<script src="/static/js/highlight.min.js" async></script>
<script>window.conversation_id = "{{conversation_id}}"</script>
<script>window.chat_id = "{{chat_id}}"; window.share_url = "{{share_url}}";</script>
<script>
window.conversation_id = "{{conversation_id}}";
window.chat_id = "{{chat_id}}";
window.share_url = "{{share_url}}";
window.start_id = "{{conversation_id}}";
</script>
<title>G4F Chat</title>
</head>
<body>

View File

@@ -7,9 +7,9 @@
<script src="https://cdn.jsdelivr.net/npm/qrcodejs/qrcode.min.js"></script>
<style>
body { font-family: Arial, sans-serif; text-align: center; margin: 20px; }
video { width: 400px; height: 400px; border: 1px solid black; display: block; margin: auto; object-fit: cover;}
video { width: 400px; height: 400px; border: 1px solid black; display: block; margin: auto; object-fit: cover; max-width: 100%;}
#qrcode { margin-top: 20px; }
#qrcode img, #qrcode canvas { margin: 0 auto; width: 400px; height: 400px; }
#qrcode img, #qrcode canvas { margin: 0 auto; width: 400px; height: 400px; max-width: 100%;}
button { margin: 5px; padding: 10px; }
</style>
</head>

View File

@@ -881,7 +881,7 @@ input.model:hover
padding: var(--inner-gap) 28px;
}
#systemPrompt, #chatPrompt, .settings textarea, form textarea {
#systemPrompt, #chatPrompt, .settings textarea, form textarea, .chat-body textarea {
font-size: 15px;
color: var(--colour-3);
outline: none;
@@ -1305,7 +1305,7 @@ form textarea {
padding: 0;
}
.settings textarea {
.settings textarea, .chat-body textarea {
height: 30px;
min-height: 30px;
padding: 6px;
@@ -1315,7 +1315,7 @@ form textarea {
text-wrap: nowrap;
}
form .field .fa-xmark {
.field .fa-xmark {
line-height: 20px;
cursor: pointer;
margin-left: auto;
@@ -1323,11 +1323,11 @@ form .field .fa-xmark {
margin-top: 0;
}
form .field.saved .fa-xmark {
.field.saved .fa-xmark {
color: var(--accent)
}
.settings .field, form .field {
.settings .field, form .field, .chat-body .field {
padding: var(--inner-gap) var(--inner-gap) var(--inner-gap) 0;
}
@@ -1359,7 +1359,7 @@ form .field.saved .fa-xmark {
border: none;
}
.settings input, form input {
.settings input, form input, .chat-body input {
background-color: transparent;
padding: 2px;
border: none;
@@ -1368,11 +1368,11 @@ form .field.saved .fa-xmark {
color: var(--colour-3);
}
.settings input:focus, form input:focus {
.settings input:focus, form input:focus, .chat-body input:focus {
outline: none;
}
.settings .label, form .label, .settings label, form label {
.settings .label, form .label, .settings label, form label, .chat-body label {
font-size: 15px;
margin-left: var(--inner-gap);
}

View File

@@ -28,7 +28,7 @@ const switchInput = document.getElementById("switch");
const searchButton = document.getElementById("search");
const paperclip = document.querySelector(".user-input .fa-paperclip");
const optionElementsSelector = ".settings input, .settings textarea, #model, #model2, #provider";
const optionElementsSelector = ".settings input, .settings textarea, .chat-body input, #model, #model2, #provider";
let provider_storage = {};
let message_storage = {};
@@ -153,7 +153,7 @@ const iframe_close = Object.assign(document.createElement("button"), {
});
iframe_close.onclick = () => iframe_container.classList.add("hidden");
iframe_container.appendChild(iframe_close);
chat.appendChild(iframe_container);
document.body.appendChild(iframe_container);
class HtmlRenderPlugin {
constructor(options = {}) {
@@ -843,6 +843,16 @@ async function add_message_chunk(message, message_id, provider, scroll, finish_m
conversation.data[key] = value;
}
await save_conversation(conversation_id, conversation);
} else if (message.type == "auth") {
error_storage[message_id] = message.message
content_map.inner.innerHTML += markdown_render(`**An error occured:** ${message.message}`);
let provider = provider_storage[message_id]?.name;
let configEl = document.querySelector(`.settings .${provider}-api_key`);
if (configEl) {
configEl = configEl.parentElement.cloneNode(true);
content_map.content.appendChild(configEl);
await register_settings_storage();
}
} else if (message.type == "provider") {
provider_storage[message_id] = message.provider;
let provider_el = content_map.content.querySelector('.provider');
@@ -1122,10 +1132,6 @@ const ask_gpt = async (message_id, message_index = -1, regenerate = false, provi
let api_key;
if (is_demo && !provider) {
api_key = localStorage.getItem("HuggingFace-api_key");
if (!api_key) {
location.href = "/";
return;
}
} else {
api_key = get_api_key_by_provider(provider);
}
@@ -1221,6 +1227,7 @@ function sanitize(input, replacement) {
}
async function set_conversation_title(conversation_id, title) {
window.chat_id = null;
conversation = await get_conversation(conversation_id)
conversation.new_title = title;
const new_id = sanitize(title, " ");
@@ -1742,12 +1749,22 @@ const load_conversations = async () => {
let html = [];
conversations.forEach((conversation) => {
// const length = conversation.items.map((item) => (
// !item.content.toLowerCase().includes("hello") &&
// !item.content.toLowerCase().includes("hi") &&
// item.content
// ) ? 1 : 0).reduce((a,b)=>a+b, 0);
// if (!length) {
// appStorage.removeItem(`conversation:${conversation.id}`);
// return;
// }
const shareIcon = (conversation.id == window.start_id && window.chat_id) ? '<i class="fa-solid fa-qrcode"></i>': '';
html.push(`
<div class="convo" id="convo-${conversation.id}">
<div class="left" onclick="set_conversation('${conversation.id}')">
<i class="fa-regular fa-comments"></i>
<span class="datetime">${conversation.updated ? toLocaleDateString(conversation.updated) : ""}</span>
<span class="convo-title">${escapeHtml(conversation.new_title ? conversation.new_title : conversation.title)}</span>
<span class="convo-title">${shareIcon} ${escapeHtml(conversation.new_title ? conversation.new_title : conversation.title)}</span>
</div>
<i onclick="show_option('${conversation.id}')" class="fa-solid fa-ellipsis-vertical" id="conv-${conversation.id}"></i>
<div id="cho-${conversation.id}" class="choise" style="display:none;">
@@ -2060,7 +2077,6 @@ window.addEventListener('load', async function() {
if (!window.conversation_id) {
window.conversation_id = window.chat_id;
}
window.start_id = window.conversation_id
const response = await fetch(`${window.share_url}/backend-api/v2/chat/${window.chat_id ? window.chat_id : window.conversation_id}`, {
headers: {'accept': 'application/json'},
});
@@ -2075,6 +2091,7 @@ window.addEventListener('load', async function() {
`conversation:${conversation.id}`,
JSON.stringify(conversation)
);
await load_conversations();
let refreshOnHide = true;
document.addEventListener("visibilitychange", () => {
if (document.hidden) {
@@ -2091,6 +2108,9 @@ window.addEventListener('load', async function() {
if (!refreshOnHide) {
return;
}
if (window.conversation_id != window.start_id) {
return;
}
const response = await fetch(`${window.share_url}/backend-api/v2/chat/${window.chat_id}`, {
headers: {'accept': 'application/json', 'if-none-match': conversation.updated},
});
@@ -2102,6 +2122,7 @@ window.addEventListener('load', async function() {
`conversation:${conversation.id}`,
JSON.stringify(conversation)
);
await load_conversations();
await load_conversation(conversation);
}
}
@@ -2284,7 +2305,7 @@ async function on_api() {
}
} else if (provider.login_url) {
if (!login_urls[provider.name]) {
login_urls[provider.name] = [provider.label, provider.login_url, [], provider.auth];
login_urls[provider.name] = [provider.label, provider.login_url, [provider.name], provider.auth];
} else {
login_urls[provider.name][0] = provider.label;
login_urls[provider.name][1] = provider.login_url;

View File

@@ -7,7 +7,7 @@ from typing import Iterator
from flask import send_from_directory
from inspect import signature
from ...errors import VersionNotFoundError
from ...errors import VersionNotFoundError, MissingAuthError
from ...image.copy_images import copy_media, ensure_images_dir, images_dir
from ...tools.run_tools import iter_run_tools
from ...Provider import ProviderUtils, __providers__
@@ -187,7 +187,8 @@ class Api:
media = chunk
if download_media or chunk.get("cookies"):
chunk.alt = format_image_prompt(kwargs.get("messages"), chunk.alt)
media = asyncio.run(copy_media(chunk.get_list(), chunk.get("cookies"), chunk.get("headers"), proxy=proxy, alt=chunk.alt))
tags = [tag for tag in [model, kwargs.get("aspect_ratio")] if tag]
media = asyncio.run(copy_media(chunk.get_list(), chunk.get("cookies"), chunk.get("headers"), proxy=proxy, alt=chunk.alt, tags=tags))
media = ImageResponse(media, chunk.alt) if isinstance(chunk, ImageResponse) else VideoResponse(media, chunk.alt)
yield self._format_json("content", str(media), images=chunk.get_list(), alt=chunk.alt)
elif isinstance(chunk, SynthesizeData):
@@ -214,6 +215,8 @@ class Api:
yield self._format_json(chunk.type, **chunk.get_dict())
else:
yield self._format_json("content", str(chunk))
except MissingAuthError as e:
yield self._format_json('auth', type(e).__name__, message=get_error_message(e))
except Exception as e:
logger.exception(e)
debug.error(e)

View File

@@ -16,7 +16,6 @@ from pathlib import Path
from urllib.parse import quote_plus
from hashlib import sha256
from ...image import is_allowed_extension
from ...client.service import convert_to_provider
from ...providers.asyncio import to_sync_generator
from ...client.helper import filter_markdown
@@ -25,7 +24,7 @@ from ...tools.run_tools import iter_run_tools
from ...errors import ProviderNotFoundError
from ...image import is_allowed_extension
from ...cookies import get_cookies_dir
from ...image.copy_images import secure_filename, get_source_url
from ...image.copy_images import secure_filename, get_source_url, images_dir
from ... import ChatCompletion
from ... import models
from .api import Api
@@ -351,9 +350,30 @@ class Backend_Api(Api):
return redirect(source_url)
raise
@app.route('/files/<dirname>/<bucket_id>/media/<filename>', methods=['GET'])
def get_media_sub(dirname, bucket_id, filename):
return get_media(bucket_id, filename, dirname)
@app.route('/search/<search>', methods=['GET'])
def find_media(search: str, min: int = None):
search = [secure_filename(chunk.lower()) for chunk in search.split("+")]
if min is None:
min = len(search)
if not os.access(images_dir, os.R_OK):
return jsonify({"error": {"message": "Not found"}}), 404
match_files = {}
for root, _, files in os.walk(images_dir):
for file in files:
mime_type = is_allowed_extension(file)
if mime_type is not None:
mime_type = secure_filename(mime_type)
for tag in search:
if tag in mime_type:
match_files[file] = match_files.get(file, 0) + 1
break
for tag in search:
if tag in file.lower():
match_files[file] = match_files.get(file, 0) + 1
match_files = [file for file, count in match_files.items() if count >= min]
if not match_files:
return jsonify({"error": {"message": "Not found"}}), 404
return redirect(f"/media/{random.choice(match_files)}")
@app.route('/backend-api/v2/upload_cookies', methods=['POST'])
def upload_cookies():
@@ -371,7 +391,7 @@ class Backend_Api(Api):
@self.app.route('/backend-api/v2/chat/<chat_id>', methods=['GET'])
def get_chat(chat_id: str) -> str:
chat_id = secure_filename(chat_id)
if int(self.chat_cache.get(chat_id, -1)) == int(request.headers.get("if-none-match", 0)):
if self.chat_cache.get(chat_id, 0) == request.headers.get("if-none-match", 0):
return jsonify({"error": {"message": "Not modified"}}), 304
bucket_dir = get_bucket_dir(chat_id)
file = os.path.join(bucket_dir, "chat.json")
@@ -379,7 +399,7 @@ class Backend_Api(Api):
return jsonify({"error": {"message": "Not found"}}), 404
with open(file, 'r') as f:
chat_data = json.load(f)
if int(chat_data.get("updated", 0)) == int(request.headers.get("if-none-match", 0)):
if chat_data.get("updated", 0) == request.headers.get("if-none-match", 0):
return jsonify({"error": {"message": "Not modified"}}), 304
self.chat_cache[chat_id] = chat_data.get("updated", 0)
return jsonify(chat_data), 200
@@ -387,12 +407,16 @@ class Backend_Api(Api):
@self.app.route('/backend-api/v2/chat/<chat_id>', methods=['POST'])
def upload_chat(chat_id: str) -> dict:
chat_data = {**request.json}
updated = chat_data.get("updated", 0)
cache_value = self.chat_cache.get(chat_id, 0)
if updated == cache_value:
return jsonify({"error": {"message": "invalid date"}}), 400
chat_id = secure_filename(chat_id)
bucket_dir = get_bucket_dir(chat_id)
os.makedirs(bucket_dir, exist_ok=True)
with open(os.path.join(bucket_dir, "chat.json"), 'w') as f:
json.dump(chat_data, f)
self.chat_cache[chat_id] = chat_data.get("updated", 0)
self.chat_cache[chat_id] = updated
return {"chat_id": chat_id}
def handle_synthesize(self, provider: str):

View File

@@ -6,6 +6,7 @@ import io
import base64
from io import BytesIO
from pathlib import Path
from typing import Optional
try:
from PIL.Image import open as open_image, new as new_image
from PIL.Image import FLIP_LEFT_RIGHT, ROTATE_180, ROTATE_270, ROTATE_90
@@ -17,15 +18,6 @@ from ..providers.helper import filter_none
from ..typing import ImageType, Union, Image
from ..errors import MissingRequirementsError
ALLOWED_EXTENSIONS = {
# Image
'png', 'jpg', 'jpeg', 'gif', 'webp',
# Audio
'wav', 'mp3', 'flac', 'opus', 'ogg',
# Video
'mkv', 'webm', 'mp4'
}
MEDIA_TYPE_MAP: dict[str, str] = {
"image/png": "png",
"image/jpeg": "jpg",
@@ -90,7 +82,7 @@ def to_image(image: ImageType, is_svg: bool = False) -> Image:
return image
def is_allowed_extension(filename: str) -> bool:
def is_allowed_extension(filename: str) -> Optional[str]:
"""
Checks if the given filename has an allowed extension.
@@ -100,8 +92,8 @@ def is_allowed_extension(filename: str) -> bool:
Returns:
bool: True if the extension is allowed, False otherwise.
"""
return '.' in filename and \
filename.rsplit('.', 1)[1].lower() in ALLOWED_EXTENSIONS
ext = os.path.splitext(filename)[1][1:].lower() if '.' in filename else None
return EXTENSIONS_MAP[ext] if ext in EXTENSIONS_MAP else None
def is_data_an_media(data, filename: str = None) -> str:
content_type = is_data_an_audio(data, filename)
@@ -138,7 +130,7 @@ def is_data_uri_an_image(data_uri: str) -> bool:
# Extract the image format from the data URI
image_format = re.match(r'data:image/(\w+);base64,', data_uri).group(1).lower()
# Check if the image format is one of the allowed formats (jpg, jpeg, png, gif)
if image_format not in ALLOWED_EXTENSIONS and image_format != "svg+xml":
if image_format not in EXTENSIONS_MAP and image_format != "svg+xml":
raise ValueError("Invalid image format (from mime file type).")
def is_accepted_format(binary_data: bytes) -> str:

View File

@@ -11,7 +11,7 @@ from aiohttp import ClientSession, ClientError
from ..typing import Optional, Cookies
from ..requests.aiohttp import get_connector, StreamResponse
from ..image import MEDIA_TYPE_MAP, ALLOWED_EXTENSIONS
from ..image import MEDIA_TYPE_MAP, EXTENSIONS_MAP
from ..tools.files import get_bucket_dir
from ..providers.response import ImageResponse, AudioResponse, VideoResponse
from ..Provider.template import BackendApi
@@ -58,7 +58,7 @@ async def save_response_media(response: StreamResponse, prompt: str):
content_type = response.headers["content-type"]
if is_valid_media_type(content_type):
extension = MEDIA_TYPE_MAP[content_type] if content_type in MEDIA_TYPE_MAP else content_type[6:].replace("mpeg", "mp3")
if extension not in ALLOWED_EXTENSIONS:
if extension not in EXTENSIONS_MAP:
raise ValueError(f"Unsupported media type: {content_type}")
bucket_id = str(uuid.uuid4())
dirname = str(int(time.time()))
@@ -86,6 +86,7 @@ async def copy_media(
headers: Optional[dict] = None,
proxy: Optional[str] = None,
alt: str = None,
tags: list[str] = None,
add_url: bool = True,
target: str = None,
ssl: bool = None
@@ -113,6 +114,7 @@ async def copy_media(
# Build safe filename with full Unicode support
filename = secure_filename("".join((
f"{int(time.time())}_",
(f"{''.join(tags, '_')}_" if tags else ""),
(f"{alt}_" if alt else ""),
f"{hashlib.sha256(image.encode()).hexdigest()[:16]}",
f"{get_media_extension(image)}"

View File

@@ -18,6 +18,7 @@ from .Provider import (
FreeGpt,
HuggingSpace,
G4F,
Grok,
DeepseekAI_JanusPro7b,
Glider,
Goabror,
@@ -356,19 +357,19 @@ gemini_1_5_pro = Model(
gemini_2_0_flash = Model(
name = 'gemini-2.0-flash',
base_provider = 'Google DeepMind',
best_provider = IterListProvider([Dynaspark, GeminiPro, Liaobots])
best_provider = IterListProvider([Dynaspark, GeminiPro, Gemini])
)
gemini_2_0_flash_thinking = Model(
name = 'gemini-2.0-flash-thinking',
base_provider = 'Google DeepMind',
best_provider = Liaobots
best_provider = Gemini
)
gemini_2_0_pro = Model(
name = 'gemini-2.0-pro',
gemini_2_0_flash_thinking_with_apps = Model(
name = 'gemini-2.0-flash-thinking-with-apps',
base_provider = 'Google DeepMind',
best_provider = Liaobots
best_provider = Gemini
)
### Anthropic ###
@@ -379,19 +380,6 @@ claude_3_haiku = Model(
best_provider = IterListProvider([DDG, Jmuz])
)
claude_3_sonnet = Model(
name = 'claude-3-sonnet',
base_provider = 'Anthropic',
best_provider = Liaobots
)
claude_3_opus = Model(
name = 'claude-3-opus',
base_provider = 'Anthropic',
best_provider = IterListProvider([Jmuz, Liaobots])
)
# claude 3.5
claude_3_5_sonnet = Model(
name = 'claude-3.5-sonnet',
@@ -406,12 +394,6 @@ claude_3_7_sonnet = Model(
best_provider = IterListProvider([Blackbox, Liaobots])
)
claude_3_7_sonnet_thinking = Model(
name = 'claude-3.7-sonnet-thinking',
base_provider = 'Anthropic',
best_provider = Liaobots
)
### Reka AI ###
reka_core = Model(
name = 'reka-core',
@@ -548,13 +530,13 @@ janus_pro_7b = VisionModel(
grok_3 = Model(
name = 'grok-3',
base_provider = 'x.ai',
best_provider = Liaobots
best_provider = Grok
)
grok_3_r1 = Model(
name = 'grok-3-r1',
base_provider = 'x.ai',
best_provider = Liaobots
best_provider = Grok
)
### Perplexity AI ###
@@ -841,12 +823,10 @@ class ModelUtils:
gemini_1_5_flash.name: gemini_1_5_flash,
gemini_2_0_flash.name: gemini_2_0_flash,
gemini_2_0_flash_thinking.name: gemini_2_0_flash_thinking,
gemini_2_0_pro.name: gemini_2_0_pro,
gemini_2_0_flash_thinking_with_apps.name: gemini_2_0_flash_thinking_with_apps,
### Anthropic ###
# claude 3
claude_3_opus.name: claude_3_opus,
claude_3_sonnet.name: claude_3_sonnet,
claude_3_haiku.name: claude_3_haiku,
# claude 3.5
@@ -854,7 +834,6 @@ class ModelUtils:
# claude 3.7
claude_3_7_sonnet.name: claude_3_7_sonnet,
claude_3_7_sonnet_thinking.name: claude_3_7_sonnet_thinking,
### Reka AI ###
reka_core.name: reka_core,

View File

@@ -366,11 +366,15 @@ class ProviderModelMixin:
class RaiseErrorMixin():
@staticmethod
def raise_error(data: dict):
def raise_error(data: dict, status: int = None):
if "error_message" in data:
raise ResponseError(data["error_message"])
elif "error" in data:
if isinstance(data["error"], str):
if status is not None:
if status in (401, 402):
raise MissingAuthError(f"Error {status}: {data['error']}")
raise ResponseError(f"Error {status}: {data['error']}")
raise ResponseError(data["error"])
elif "code" in data["error"]:
raise ResponseError("\n".join(

View File

@@ -4,7 +4,7 @@ from typing import Union
from aiohttp import ClientResponse
from requests import Response as RequestsResponse
from ..errors import ResponseStatusError, RateLimitError
from ..errors import ResponseStatusError, RateLimitError, MissingAuthError
from . import Response, StreamResponse
class CloudflareError(ResponseStatusError):
@@ -23,6 +23,7 @@ def is_openai(text: str) -> bool:
async def raise_for_status_async(response: Union[StreamResponse, ClientResponse], message: str = None):
if response.ok:
return
is_html = False
if message is None:
content_type = response.headers.get("content-type", "")
if content_type.startswith("application/json"):
@@ -31,39 +32,42 @@ async def raise_for_status_async(response: Union[StreamResponse, ClientResponse]
if isinstance(message, dict):
message = message.get("message", message)
else:
text = (await response.text()).strip()
message = (await response.text()).strip()
is_html = content_type.startswith("text/html") or text.startswith("<!DOCTYPE")
message = "HTML content" if is_html else text
if message is None or message == "HTML content":
if message is None or is_html:
if response.status == 520:
message = "Unknown error (Cloudflare)"
elif response.status in (429, 402):
message = "Rate limit"
if response.status == 403 and is_cloudflare(text):
if response.status in (401, 402):
raise MissingAuthError(f"Response {response.status}: {message}")
if response.status == 403 and is_cloudflare(message):
raise CloudflareError(f"Response {response.status}: Cloudflare detected")
elif response.status == 403 and is_openai(text):
elif response.status == 403 and is_openai(message):
raise ResponseStatusError(f"Response {response.status}: OpenAI Bot detected")
elif response.status == 502:
raise ResponseStatusError(f"Response {response.status}: Bad Gateway")
elif response.status == 504:
raise RateLimitError(f"Response {response.status}: Gateway Timeout ")
else:
raise ResponseStatusError(f"Response {response.status}: {message}")
raise ResponseStatusError(f"Response {response.status}: {"HTML content" if is_html else message}")
def raise_for_status(response: Union[Response, StreamResponse, ClientResponse, RequestsResponse], message: str = None):
if hasattr(response, "status"):
return raise_for_status_async(response, message)
if response.ok:
return
is_html = False
if message is None:
is_html = response.headers.get("content-type", "").startswith("text/html") or response.text.startswith("<!DOCTYPE")
message = "HTML content" if is_html else response.text
if message == "HTML content":
message = response.text
if message is None or is_html:
if response.status_code == 520:
message = "Unknown error (Cloudflare)"
elif response.status_code in (429, 402):
message = "Rate limit"
raise RateLimitError(f"Response {response.status_code}: {message}")
raise RateLimitError(f"Response {response.status_code}: Rate Limit")
if response.status_code in (401, 402):
raise MissingAuthError(f"Response {response.status_code}: {message}")
if response.status_code == 403 and is_cloudflare(response.text):
raise CloudflareError(f"Response {response.status_code}: Cloudflare detected")
elif response.status_code == 403 and is_openai(response.text):
@@ -73,4 +77,4 @@ def raise_for_status(response: Union[Response, StreamResponse, ClientResponse, R
elif response.status_code == 504:
raise RateLimitError(f"Response {response.status_code}: Gateway Timeout ")
else:
raise ResponseStatusError(f"Response {response.status_code}: {message}")
raise ResponseStatusError(f"Response {response.status_code}: {"HTML content" if is_html else message}")