Set default model in HuggingFaceMedia

Improve handling of shared chats
Show api_key input if required
This commit is contained in:
hlohaus
2025-03-26 01:32:05 +01:00
parent 6767000604
commit ce500f0d49
18 changed files with 206 additions and 111 deletions

19
etc/examples/video.py Normal file
View File

@@ -0,0 +1,19 @@
import g4f.Provider
from g4f.client import Client
client = Client(
provider=g4f.Provider.HuggingFaceMedia,
api_key="hf_***" # Your API key here
)
video_models = client.models.get_video()
print(video_models)
result = client.media.generate(
model=video_models[0],
prompt="G4F AI technology is the best in the world.",
response_format="url"
)
print(result.data[0].url)

View File

@@ -66,6 +66,8 @@ class HuggingFaceMedia(AsyncGeneratorProvider, ProviderModelMixin):
for provider_data in provider_keys: for provider_data in provider_keys:
prepend_models.append(f"{model}:{provider_data.get('provider')}") prepend_models.append(f"{model}:{provider_data.get('provider')}")
cls.models = prepend_models + [model for model in new_models if model not in prepend_models] cls.models = prepend_models + [model for model in new_models if model not in prepend_models]
cls.image_models = [model for model, task in cls.task_mapping.items() if task == "text-to-image"]
cls.video_models = [model for model, task in cls.task_mapping.items() if task == "text-to-video"]
else: else:
cls.models = [] cls.models = []
return cls.models return cls.models
@@ -99,12 +101,14 @@ class HuggingFaceMedia(AsyncGeneratorProvider, ProviderModelMixin):
prompt: str = None, prompt: str = None,
proxy: str = None, proxy: str = None,
timeout: int = 0, timeout: int = 0,
aspect_ratio: str = "1:1", aspect_ratio: str = None,
**kwargs **kwargs
): ):
selected_provider = None selected_provider = None
if ":" in model: if model and ":" in model:
model, selected_provider = model.split(":", 1) model, selected_provider = model.split(":", 1)
elif not model:
model = cls.get_models()[0]
provider_mapping = await cls.get_mapping(model, api_key) provider_mapping = await cls.get_mapping(model, api_key)
headers = { headers = {
'Accept-Encoding': 'gzip, deflate', 'Accept-Encoding': 'gzip, deflate',
@@ -133,11 +137,11 @@ class HuggingFaceMedia(AsyncGeneratorProvider, ProviderModelMixin):
extra_data = { extra_data = {
"num_inference_steps": 20, "num_inference_steps": 20,
"resolution": "480p", "resolution": "480p",
"aspect_ratio": aspect_ratio, "aspect_ratio": "16:9" if aspect_ratio is None else aspect_ratio,
**extra_data **extra_data
} }
else: else:
extra_data = use_aspect_ratio(extra_data, aspect_ratio) extra_data = use_aspect_ratio(extra_data, "1:1" if aspect_ratio is None else aspect_ratio)
if provider_key == "fal-ai": if provider_key == "fal-ai":
url = f"{api_base}/{provider_id}" url = f"{api_base}/{provider_id}"
data = { data = {

View File

@@ -30,6 +30,7 @@ class Grok(AsyncAuthedProvider, ProviderModelMixin):
default_model = "grok-3" default_model = "grok-3"
models = [default_model, "grok-3-thinking", "grok-2"] models = [default_model, "grok-3-thinking", "grok-2"]
model_aliases = {"grok-3-r1": "grok-3-thinking"}
@classmethod @classmethod
async def on_auth_async(cls, cookies: Cookies = None, proxy: str = None, **kwargs) -> AsyncIterator: async def on_auth_async(cls, cookies: Cookies = None, proxy: str = None, **kwargs) -> AsyncIterator:
@@ -73,7 +74,7 @@ class Grok(AsyncAuthedProvider, ProviderModelMixin):
"sendFinalMetadata": True, "sendFinalMetadata": True,
"customInstructions": "", "customInstructions": "",
"deepsearchPreset": "", "deepsearchPreset": "",
"isReasoning": model.endswith("-thinking"), "isReasoning": model.endswith("-thinking") or model.endswith("-r1"),
} }
@classmethod @classmethod

View File

@@ -92,7 +92,7 @@ class OpenaiTemplate(AsyncGeneratorProvider, ProviderModelMixin, RaiseErrorMixin
} }
async with session.post(f"{api_base.rstrip('/')}/images/generations", json=data, ssl=cls.ssl) as response: async with session.post(f"{api_base.rstrip('/')}/images/generations", json=data, ssl=cls.ssl) as response:
data = await response.json() data = await response.json()
cls.raise_error(data) cls.raise_error(data, response.status)
await raise_for_status(response) await raise_for_status(response)
yield ImageResponse([image["url"] for image in data["data"]], prompt) yield ImageResponse([image["url"] for image in data["data"]], prompt)
return return
@@ -135,7 +135,7 @@ class OpenaiTemplate(AsyncGeneratorProvider, ProviderModelMixin, RaiseErrorMixin
content_type = response.headers.get("content-type", "text/event-stream" if stream else "application/json") content_type = response.headers.get("content-type", "text/event-stream" if stream else "application/json")
if content_type.startswith("application/json"): if content_type.startswith("application/json"):
data = await response.json() data = await response.json()
cls.raise_error(data) cls.raise_error(data, response.status)
await raise_for_status(response) await raise_for_status(response)
choice = data["choices"][0] choice = data["choices"][0]
if "content" in choice["message"] and choice["message"]["content"]: if "content" in choice["message"] and choice["message"]["content"]:

View File

@@ -10,6 +10,7 @@ from email.utils import formatdate
import os.path import os.path
import hashlib import hashlib
import asyncio import asyncio
from urllib.parse import quote_plus
from fastapi import FastAPI, Response, Request, UploadFile, Depends from fastapi import FastAPI, Response, Request, UploadFile, Depends
from fastapi.middleware.wsgi import WSGIMiddleware from fastapi.middleware.wsgi import WSGIMiddleware
from fastapi.responses import StreamingResponse, RedirectResponse, HTMLResponse, JSONResponse from fastapi.responses import StreamingResponse, RedirectResponse, HTMLResponse, JSONResponse
@@ -562,6 +563,10 @@ class Api:
}) })
async def get_media(filename, request: Request): async def get_media(filename, request: Request):
target = os.path.join(images_dir, os.path.basename(filename)) target = os.path.join(images_dir, os.path.basename(filename))
if not os.path.isfile(target):
other_name = os.path.join(images_dir, os.path.basename(quote_plus(filename)))
if os.path.isfile(other_name):
target = other_name
ext = os.path.splitext(filename)[1][1:] ext = os.path.splitext(filename)[1][1:]
mime_type = EXTENSIONS_MAP.get(ext) mime_type = EXTENSIONS_MAP.get(ext)
stat_result = SimpleNamespace() stat_result = SimpleNamespace()

View File

@@ -19,7 +19,7 @@ from ..providers.asyncio import to_sync_generator
from ..Provider.needs_auth import BingCreateImages, OpenaiAccount from ..Provider.needs_auth import BingCreateImages, OpenaiAccount
from ..tools.run_tools import async_iter_run_tools, iter_run_tools from ..tools.run_tools import async_iter_run_tools, iter_run_tools
from .stubs import ChatCompletion, ChatCompletionChunk, Image, ImagesResponse, UsageModel, ToolCallModel from .stubs import ChatCompletion, ChatCompletionChunk, Image, ImagesResponse, UsageModel, ToolCallModel
from .image_models import ImageModels from .image_models import MediaModels
from .types import IterResponse, ImageProvider, Client as BaseClient from .types import IterResponse, ImageProvider, Client as BaseClient
from .service import get_model_and_provider, convert_to_provider from .service import get_model_and_provider, convert_to_provider
from .helper import find_stop, filter_json, filter_none, safe_aclose from .helper import find_stop, filter_json, filter_none, safe_aclose
@@ -267,8 +267,11 @@ class Client(BaseClient):
) -> None: ) -> None:
super().__init__(**kwargs) super().__init__(**kwargs)
self.chat: Chat = Chat(self, provider) self.chat: Chat = Chat(self, provider)
if image_provider is None:
image_provider = provider
self.models: MediaModels = MediaModels(self, image_provider)
self.images: Images = Images(self, image_provider) self.images: Images = Images(self, image_provider)
self.media: Images = Images(self, image_provider) self.media: Images = self.images
class Completions: class Completions:
def __init__(self, client: Client, provider: Optional[ProviderType] = None): def __init__(self, client: Client, provider: Optional[ProviderType] = None):
@@ -349,7 +352,6 @@ class Images:
def __init__(self, client: Client, provider: Optional[ProviderType] = None): def __init__(self, client: Client, provider: Optional[ProviderType] = None):
self.client: Client = client self.client: Client = client
self.provider: Optional[ProviderType] = provider self.provider: Optional[ProviderType] = provider
self.models: ImageModels = ImageModels(client)
def generate( def generate(
self, self,
@@ -369,7 +371,7 @@ class Images:
if provider is None: if provider is None:
provider_handler = self.provider provider_handler = self.provider
if provider_handler is None: if provider_handler is None:
provider_handler = self.models.get(model, default) provider_handler = self.client.models.get(model, default)
elif isinstance(provider, str): elif isinstance(provider, str):
provider_handler = convert_to_provider(provider) provider_handler = convert_to_provider(provider)
else: else:
@@ -385,19 +387,21 @@ class Images:
provider: Optional[ProviderType] = None, provider: Optional[ProviderType] = None,
response_format: Optional[str] = None, response_format: Optional[str] = None,
proxy: Optional[str] = None, proxy: Optional[str] = None,
api_key: Optional[str] = None,
**kwargs **kwargs
) -> ImagesResponse: ) -> ImagesResponse:
provider_handler = await self.get_provider_handler(model, provider, BingCreateImages) provider_handler = await self.get_provider_handler(model, provider, BingCreateImages)
provider_name = provider_handler.__name__ if hasattr(provider_handler, "__name__") else type(provider_handler).__name__ provider_name = provider_handler.__name__ if hasattr(provider_handler, "__name__") else type(provider_handler).__name__
if proxy is None: if proxy is None:
proxy = self.client.proxy proxy = self.client.proxy
if api_key is None:
api_key = self.client.api_key
error = None error = None
response = None response = None
if isinstance(provider_handler, IterListProvider): if isinstance(provider_handler, IterListProvider):
for provider in provider_handler.providers: for provider in provider_handler.providers:
try: try:
response = await self._generate_image_response(provider, provider.__name__, model, prompt, **kwargs) response = await self._generate_image_response(provider, provider.__name__, model, prompt, proxy=proxy, **kwargs)
if response is not None: if response is not None:
provider_name = provider.__name__ provider_name = provider.__name__
break break
@@ -405,7 +409,7 @@ class Images:
error = e error = e
debug.error(f"{provider.__name__} {type(e).__name__}: {e}") debug.error(f"{provider.__name__} {type(e).__name__}: {e}")
else: else:
response = await self._generate_image_response(provider_handler, provider_name, model, prompt, **kwargs) response = await self._generate_image_response(provider_handler, provider_name, model, prompt, proxy=proxy, api_key=api_key, **kwargs)
if isinstance(response, MediaResponse): if isinstance(response, MediaResponse):
return await self._process_image_response( return await self._process_image_response(
@@ -534,7 +538,7 @@ class Images:
else: else:
# Save locally for None (default) case # Save locally for None (default) case
images = await copy_media(response.get_list(), response.get("cookies"), proxy) images = await copy_media(response.get_list(), response.get("cookies"), proxy)
images = [Image.model_construct(url=f"/media/{os.path.basename(image)}", revised_prompt=response.alt) for image in images] images = [Image.model_construct(url=image, revised_prompt=response.alt) for image in images]
return ImagesResponse.model_construct( return ImagesResponse.model_construct(
created=int(time.time()), created=int(time.time()),
@@ -552,6 +556,9 @@ class AsyncClient(BaseClient):
) -> None: ) -> None:
super().__init__(**kwargs) super().__init__(**kwargs)
self.chat: AsyncChat = AsyncChat(self, provider) self.chat: AsyncChat = AsyncChat(self, provider)
if image_provider is None:
image_provider = provider
self.models: MediaModels = MediaModels(self, image_provider)
self.images: AsyncImages = AsyncImages(self, image_provider) self.images: AsyncImages = AsyncImages(self, image_provider)
self.media: AsyncImages = self.images self.media: AsyncImages = self.images
@@ -635,7 +642,6 @@ class AsyncImages(Images):
def __init__(self, client: AsyncClient, provider: Optional[ProviderType] = None): def __init__(self, client: AsyncClient, provider: Optional[ProviderType] = None):
self.client: AsyncClient = client self.client: AsyncClient = client
self.provider: Optional[ProviderType] = provider self.provider: Optional[ProviderType] = provider
self.models: ImageModels = ImageModels(client)
async def generate( async def generate(
self, self,

View File

@@ -1,15 +1,43 @@
from __future__ import annotations from __future__ import annotations
from ..models import ModelUtils from ..models import ModelUtils, ImageModel
from ..Provider import ProviderUtils from ..Provider import ProviderUtils
from ..providers.types import ProviderType
class ImageModels(): class MediaModels():
def __init__(self, client): def __init__(self, client, provider: ProviderType = None):
self.client = client self.client = client
self.provider = provider
def get(self, name, default=None): def get(self, name, default=None) -> ProviderType:
if name in ModelUtils.convert: if name in ModelUtils.convert:
return ModelUtils.convert[name].best_provider return ModelUtils.convert[name].best_provider
if name in ProviderUtils.convert: if name in ProviderUtils.convert:
return ProviderUtils.convert[name] return ProviderUtils.convert[name]
return default return default
def get_all(self, api_key: str = None, **kwargs) -> list[str]:
if self.provider is None:
return []
if api_key is None:
api_key = self.client.api_key
return self.provider.get_models(
**kwargs,
**{} if api_key is None else {"api_key": api_key}
)
def get_image(self, **kwargs) -> list[str]:
if self.provider is None:
return [model_id for model_id, model in ModelUtils.convert.items() if isinstance(model, ImageModel)]
self.get_all(**kwargs)
if hasattr(self.provider, "image_models"):
return self.provider.image_models
return []
def get_video(self, **kwargs) -> list[str]:
if self.provider is None:
return []
self.get_all(**kwargs)
if hasattr(self.provider, "video_models"):
return self.provider.video_models
return []

View File

@@ -61,9 +61,12 @@
const gpt_image = '<img src="/static/img/gpt.png" alt="your avatar">'; const gpt_image = '<img src="/static/img/gpt.png" alt="your avatar">';
</script> </script>
<script src="/static/js/highlight.min.js" async></script> <script src="/static/js/highlight.min.js" async></script>
<script>
<script>window.conversation_id = "{{conversation_id}}"</script> window.conversation_id = "{{conversation_id}}";
<script>window.chat_id = "{{chat_id}}"; window.share_url = "{{share_url}}";</script> window.chat_id = "{{chat_id}}";
window.share_url = "{{share_url}}";
window.start_id = "{{conversation_id}}";
</script>
<title>G4F Chat</title> <title>G4F Chat</title>
</head> </head>
<body> <body>

View File

@@ -7,9 +7,9 @@
<script src="https://cdn.jsdelivr.net/npm/qrcodejs/qrcode.min.js"></script> <script src="https://cdn.jsdelivr.net/npm/qrcodejs/qrcode.min.js"></script>
<style> <style>
body { font-family: Arial, sans-serif; text-align: center; margin: 20px; } body { font-family: Arial, sans-serif; text-align: center; margin: 20px; }
video { width: 400px; height: 400px; border: 1px solid black; display: block; margin: auto; object-fit: cover;} video { width: 400px; height: 400px; border: 1px solid black; display: block; margin: auto; object-fit: cover; max-width: 100%;}
#qrcode { margin-top: 20px; } #qrcode { margin-top: 20px; }
#qrcode img, #qrcode canvas { margin: 0 auto; width: 400px; height: 400px; } #qrcode img, #qrcode canvas { margin: 0 auto; width: 400px; height: 400px; max-width: 100%;}
button { margin: 5px; padding: 10px; } button { margin: 5px; padding: 10px; }
</style> </style>
</head> </head>

View File

@@ -881,7 +881,7 @@ input.model:hover
padding: var(--inner-gap) 28px; padding: var(--inner-gap) 28px;
} }
#systemPrompt, #chatPrompt, .settings textarea, form textarea { #systemPrompt, #chatPrompt, .settings textarea, form textarea, .chat-body textarea {
font-size: 15px; font-size: 15px;
color: var(--colour-3); color: var(--colour-3);
outline: none; outline: none;
@@ -1305,7 +1305,7 @@ form textarea {
padding: 0; padding: 0;
} }
.settings textarea { .settings textarea, .chat-body textarea {
height: 30px; height: 30px;
min-height: 30px; min-height: 30px;
padding: 6px; padding: 6px;
@@ -1315,7 +1315,7 @@ form textarea {
text-wrap: nowrap; text-wrap: nowrap;
} }
form .field .fa-xmark { .field .fa-xmark {
line-height: 20px; line-height: 20px;
cursor: pointer; cursor: pointer;
margin-left: auto; margin-left: auto;
@@ -1323,11 +1323,11 @@ form .field .fa-xmark {
margin-top: 0; margin-top: 0;
} }
form .field.saved .fa-xmark { .field.saved .fa-xmark {
color: var(--accent) color: var(--accent)
} }
.settings .field, form .field { .settings .field, form .field, .chat-body .field {
padding: var(--inner-gap) var(--inner-gap) var(--inner-gap) 0; padding: var(--inner-gap) var(--inner-gap) var(--inner-gap) 0;
} }
@@ -1359,7 +1359,7 @@ form .field.saved .fa-xmark {
border: none; border: none;
} }
.settings input, form input { .settings input, form input, .chat-body input {
background-color: transparent; background-color: transparent;
padding: 2px; padding: 2px;
border: none; border: none;
@@ -1368,11 +1368,11 @@ form .field.saved .fa-xmark {
color: var(--colour-3); color: var(--colour-3);
} }
.settings input:focus, form input:focus { .settings input:focus, form input:focus, .chat-body input:focus {
outline: none; outline: none;
} }
.settings .label, form .label, .settings label, form label { .settings .label, form .label, .settings label, form label, .chat-body label {
font-size: 15px; font-size: 15px;
margin-left: var(--inner-gap); margin-left: var(--inner-gap);
} }

View File

@@ -28,7 +28,7 @@ const switchInput = document.getElementById("switch");
const searchButton = document.getElementById("search"); const searchButton = document.getElementById("search");
const paperclip = document.querySelector(".user-input .fa-paperclip"); const paperclip = document.querySelector(".user-input .fa-paperclip");
const optionElementsSelector = ".settings input, .settings textarea, #model, #model2, #provider"; const optionElementsSelector = ".settings input, .settings textarea, .chat-body input, #model, #model2, #provider";
let provider_storage = {}; let provider_storage = {};
let message_storage = {}; let message_storage = {};
@@ -153,7 +153,7 @@ const iframe_close = Object.assign(document.createElement("button"), {
}); });
iframe_close.onclick = () => iframe_container.classList.add("hidden"); iframe_close.onclick = () => iframe_container.classList.add("hidden");
iframe_container.appendChild(iframe_close); iframe_container.appendChild(iframe_close);
chat.appendChild(iframe_container); document.body.appendChild(iframe_container);
class HtmlRenderPlugin { class HtmlRenderPlugin {
constructor(options = {}) { constructor(options = {}) {
@@ -843,6 +843,16 @@ async function add_message_chunk(message, message_id, provider, scroll, finish_m
conversation.data[key] = value; conversation.data[key] = value;
} }
await save_conversation(conversation_id, conversation); await save_conversation(conversation_id, conversation);
} else if (message.type == "auth") {
error_storage[message_id] = message.message
content_map.inner.innerHTML += markdown_render(`**An error occured:** ${message.message}`);
let provider = provider_storage[message_id]?.name;
let configEl = document.querySelector(`.settings .${provider}-api_key`);
if (configEl) {
configEl = configEl.parentElement.cloneNode(true);
content_map.content.appendChild(configEl);
await register_settings_storage();
}
} else if (message.type == "provider") { } else if (message.type == "provider") {
provider_storage[message_id] = message.provider; provider_storage[message_id] = message.provider;
let provider_el = content_map.content.querySelector('.provider'); let provider_el = content_map.content.querySelector('.provider');
@@ -1122,10 +1132,6 @@ const ask_gpt = async (message_id, message_index = -1, regenerate = false, provi
let api_key; let api_key;
if (is_demo && !provider) { if (is_demo && !provider) {
api_key = localStorage.getItem("HuggingFace-api_key"); api_key = localStorage.getItem("HuggingFace-api_key");
if (!api_key) {
location.href = "/";
return;
}
} else { } else {
api_key = get_api_key_by_provider(provider); api_key = get_api_key_by_provider(provider);
} }
@@ -1221,6 +1227,7 @@ function sanitize(input, replacement) {
} }
async function set_conversation_title(conversation_id, title) { async function set_conversation_title(conversation_id, title) {
window.chat_id = null;
conversation = await get_conversation(conversation_id) conversation = await get_conversation(conversation_id)
conversation.new_title = title; conversation.new_title = title;
const new_id = sanitize(title, " "); const new_id = sanitize(title, " ");
@@ -1742,12 +1749,22 @@ const load_conversations = async () => {
let html = []; let html = [];
conversations.forEach((conversation) => { conversations.forEach((conversation) => {
// const length = conversation.items.map((item) => (
// !item.content.toLowerCase().includes("hello") &&
// !item.content.toLowerCase().includes("hi") &&
// item.content
// ) ? 1 : 0).reduce((a,b)=>a+b, 0);
// if (!length) {
// appStorage.removeItem(`conversation:${conversation.id}`);
// return;
// }
const shareIcon = (conversation.id == window.start_id && window.chat_id) ? '<i class="fa-solid fa-qrcode"></i>': '';
html.push(` html.push(`
<div class="convo" id="convo-${conversation.id}"> <div class="convo" id="convo-${conversation.id}">
<div class="left" onclick="set_conversation('${conversation.id}')"> <div class="left" onclick="set_conversation('${conversation.id}')">
<i class="fa-regular fa-comments"></i> <i class="fa-regular fa-comments"></i>
<span class="datetime">${conversation.updated ? toLocaleDateString(conversation.updated) : ""}</span> <span class="datetime">${conversation.updated ? toLocaleDateString(conversation.updated) : ""}</span>
<span class="convo-title">${escapeHtml(conversation.new_title ? conversation.new_title : conversation.title)}</span> <span class="convo-title">${shareIcon} ${escapeHtml(conversation.new_title ? conversation.new_title : conversation.title)}</span>
</div> </div>
<i onclick="show_option('${conversation.id}')" class="fa-solid fa-ellipsis-vertical" id="conv-${conversation.id}"></i> <i onclick="show_option('${conversation.id}')" class="fa-solid fa-ellipsis-vertical" id="conv-${conversation.id}"></i>
<div id="cho-${conversation.id}" class="choise" style="display:none;"> <div id="cho-${conversation.id}" class="choise" style="display:none;">
@@ -2060,7 +2077,6 @@ window.addEventListener('load', async function() {
if (!window.conversation_id) { if (!window.conversation_id) {
window.conversation_id = window.chat_id; window.conversation_id = window.chat_id;
} }
window.start_id = window.conversation_id
const response = await fetch(`${window.share_url}/backend-api/v2/chat/${window.chat_id ? window.chat_id : window.conversation_id}`, { const response = await fetch(`${window.share_url}/backend-api/v2/chat/${window.chat_id ? window.chat_id : window.conversation_id}`, {
headers: {'accept': 'application/json'}, headers: {'accept': 'application/json'},
}); });
@@ -2075,6 +2091,7 @@ window.addEventListener('load', async function() {
`conversation:${conversation.id}`, `conversation:${conversation.id}`,
JSON.stringify(conversation) JSON.stringify(conversation)
); );
await load_conversations();
let refreshOnHide = true; let refreshOnHide = true;
document.addEventListener("visibilitychange", () => { document.addEventListener("visibilitychange", () => {
if (document.hidden) { if (document.hidden) {
@@ -2091,6 +2108,9 @@ window.addEventListener('load', async function() {
if (!refreshOnHide) { if (!refreshOnHide) {
return; return;
} }
if (window.conversation_id != window.start_id) {
return;
}
const response = await fetch(`${window.share_url}/backend-api/v2/chat/${window.chat_id}`, { const response = await fetch(`${window.share_url}/backend-api/v2/chat/${window.chat_id}`, {
headers: {'accept': 'application/json', 'if-none-match': conversation.updated}, headers: {'accept': 'application/json', 'if-none-match': conversation.updated},
}); });
@@ -2102,6 +2122,7 @@ window.addEventListener('load', async function() {
`conversation:${conversation.id}`, `conversation:${conversation.id}`,
JSON.stringify(conversation) JSON.stringify(conversation)
); );
await load_conversations();
await load_conversation(conversation); await load_conversation(conversation);
} }
} }
@@ -2284,7 +2305,7 @@ async function on_api() {
} }
} else if (provider.login_url) { } else if (provider.login_url) {
if (!login_urls[provider.name]) { if (!login_urls[provider.name]) {
login_urls[provider.name] = [provider.label, provider.login_url, [], provider.auth]; login_urls[provider.name] = [provider.label, provider.login_url, [provider.name], provider.auth];
} else { } else {
login_urls[provider.name][0] = provider.label; login_urls[provider.name][0] = provider.label;
login_urls[provider.name][1] = provider.login_url; login_urls[provider.name][1] = provider.login_url;

View File

@@ -7,7 +7,7 @@ from typing import Iterator
from flask import send_from_directory from flask import send_from_directory
from inspect import signature from inspect import signature
from ...errors import VersionNotFoundError from ...errors import VersionNotFoundError, MissingAuthError
from ...image.copy_images import copy_media, ensure_images_dir, images_dir from ...image.copy_images import copy_media, ensure_images_dir, images_dir
from ...tools.run_tools import iter_run_tools from ...tools.run_tools import iter_run_tools
from ...Provider import ProviderUtils, __providers__ from ...Provider import ProviderUtils, __providers__
@@ -187,7 +187,8 @@ class Api:
media = chunk media = chunk
if download_media or chunk.get("cookies"): if download_media or chunk.get("cookies"):
chunk.alt = format_image_prompt(kwargs.get("messages"), chunk.alt) chunk.alt = format_image_prompt(kwargs.get("messages"), chunk.alt)
media = asyncio.run(copy_media(chunk.get_list(), chunk.get("cookies"), chunk.get("headers"), proxy=proxy, alt=chunk.alt)) tags = [tag for tag in [model, kwargs.get("aspect_ratio")] if tag]
media = asyncio.run(copy_media(chunk.get_list(), chunk.get("cookies"), chunk.get("headers"), proxy=proxy, alt=chunk.alt, tags=tags))
media = ImageResponse(media, chunk.alt) if isinstance(chunk, ImageResponse) else VideoResponse(media, chunk.alt) media = ImageResponse(media, chunk.alt) if isinstance(chunk, ImageResponse) else VideoResponse(media, chunk.alt)
yield self._format_json("content", str(media), images=chunk.get_list(), alt=chunk.alt) yield self._format_json("content", str(media), images=chunk.get_list(), alt=chunk.alt)
elif isinstance(chunk, SynthesizeData): elif isinstance(chunk, SynthesizeData):
@@ -214,6 +215,8 @@ class Api:
yield self._format_json(chunk.type, **chunk.get_dict()) yield self._format_json(chunk.type, **chunk.get_dict())
else: else:
yield self._format_json("content", str(chunk)) yield self._format_json("content", str(chunk))
except MissingAuthError as e:
yield self._format_json('auth', type(e).__name__, message=get_error_message(e))
except Exception as e: except Exception as e:
logger.exception(e) logger.exception(e)
debug.error(e) debug.error(e)

View File

@@ -16,7 +16,6 @@ from pathlib import Path
from urllib.parse import quote_plus from urllib.parse import quote_plus
from hashlib import sha256 from hashlib import sha256
from ...image import is_allowed_extension
from ...client.service import convert_to_provider from ...client.service import convert_to_provider
from ...providers.asyncio import to_sync_generator from ...providers.asyncio import to_sync_generator
from ...client.helper import filter_markdown from ...client.helper import filter_markdown
@@ -25,7 +24,7 @@ from ...tools.run_tools import iter_run_tools
from ...errors import ProviderNotFoundError from ...errors import ProviderNotFoundError
from ...image import is_allowed_extension from ...image import is_allowed_extension
from ...cookies import get_cookies_dir from ...cookies import get_cookies_dir
from ...image.copy_images import secure_filename, get_source_url from ...image.copy_images import secure_filename, get_source_url, images_dir
from ... import ChatCompletion from ... import ChatCompletion
from ... import models from ... import models
from .api import Api from .api import Api
@@ -351,9 +350,30 @@ class Backend_Api(Api):
return redirect(source_url) return redirect(source_url)
raise raise
@app.route('/files/<dirname>/<bucket_id>/media/<filename>', methods=['GET']) @app.route('/search/<search>', methods=['GET'])
def get_media_sub(dirname, bucket_id, filename): def find_media(search: str, min: int = None):
return get_media(bucket_id, filename, dirname) search = [secure_filename(chunk.lower()) for chunk in search.split("+")]
if min is None:
min = len(search)
if not os.access(images_dir, os.R_OK):
return jsonify({"error": {"message": "Not found"}}), 404
match_files = {}
for root, _, files in os.walk(images_dir):
for file in files:
mime_type = is_allowed_extension(file)
if mime_type is not None:
mime_type = secure_filename(mime_type)
for tag in search:
if tag in mime_type:
match_files[file] = match_files.get(file, 0) + 1
break
for tag in search:
if tag in file.lower():
match_files[file] = match_files.get(file, 0) + 1
match_files = [file for file, count in match_files.items() if count >= min]
if not match_files:
return jsonify({"error": {"message": "Not found"}}), 404
return redirect(f"/media/{random.choice(match_files)}")
@app.route('/backend-api/v2/upload_cookies', methods=['POST']) @app.route('/backend-api/v2/upload_cookies', methods=['POST'])
def upload_cookies(): def upload_cookies():
@@ -371,7 +391,7 @@ class Backend_Api(Api):
@self.app.route('/backend-api/v2/chat/<chat_id>', methods=['GET']) @self.app.route('/backend-api/v2/chat/<chat_id>', methods=['GET'])
def get_chat(chat_id: str) -> str: def get_chat(chat_id: str) -> str:
chat_id = secure_filename(chat_id) chat_id = secure_filename(chat_id)
if int(self.chat_cache.get(chat_id, -1)) == int(request.headers.get("if-none-match", 0)): if self.chat_cache.get(chat_id, 0) == request.headers.get("if-none-match", 0):
return jsonify({"error": {"message": "Not modified"}}), 304 return jsonify({"error": {"message": "Not modified"}}), 304
bucket_dir = get_bucket_dir(chat_id) bucket_dir = get_bucket_dir(chat_id)
file = os.path.join(bucket_dir, "chat.json") file = os.path.join(bucket_dir, "chat.json")
@@ -379,7 +399,7 @@ class Backend_Api(Api):
return jsonify({"error": {"message": "Not found"}}), 404 return jsonify({"error": {"message": "Not found"}}), 404
with open(file, 'r') as f: with open(file, 'r') as f:
chat_data = json.load(f) chat_data = json.load(f)
if int(chat_data.get("updated", 0)) == int(request.headers.get("if-none-match", 0)): if chat_data.get("updated", 0) == request.headers.get("if-none-match", 0):
return jsonify({"error": {"message": "Not modified"}}), 304 return jsonify({"error": {"message": "Not modified"}}), 304
self.chat_cache[chat_id] = chat_data.get("updated", 0) self.chat_cache[chat_id] = chat_data.get("updated", 0)
return jsonify(chat_data), 200 return jsonify(chat_data), 200
@@ -387,12 +407,16 @@ class Backend_Api(Api):
@self.app.route('/backend-api/v2/chat/<chat_id>', methods=['POST']) @self.app.route('/backend-api/v2/chat/<chat_id>', methods=['POST'])
def upload_chat(chat_id: str) -> dict: def upload_chat(chat_id: str) -> dict:
chat_data = {**request.json} chat_data = {**request.json}
updated = chat_data.get("updated", 0)
cache_value = self.chat_cache.get(chat_id, 0)
if updated == cache_value:
return jsonify({"error": {"message": "invalid date"}}), 400
chat_id = secure_filename(chat_id) chat_id = secure_filename(chat_id)
bucket_dir = get_bucket_dir(chat_id) bucket_dir = get_bucket_dir(chat_id)
os.makedirs(bucket_dir, exist_ok=True) os.makedirs(bucket_dir, exist_ok=True)
with open(os.path.join(bucket_dir, "chat.json"), 'w') as f: with open(os.path.join(bucket_dir, "chat.json"), 'w') as f:
json.dump(chat_data, f) json.dump(chat_data, f)
self.chat_cache[chat_id] = chat_data.get("updated", 0) self.chat_cache[chat_id] = updated
return {"chat_id": chat_id} return {"chat_id": chat_id}
def handle_synthesize(self, provider: str): def handle_synthesize(self, provider: str):

View File

@@ -6,6 +6,7 @@ import io
import base64 import base64
from io import BytesIO from io import BytesIO
from pathlib import Path from pathlib import Path
from typing import Optional
try: try:
from PIL.Image import open as open_image, new as new_image from PIL.Image import open as open_image, new as new_image
from PIL.Image import FLIP_LEFT_RIGHT, ROTATE_180, ROTATE_270, ROTATE_90 from PIL.Image import FLIP_LEFT_RIGHT, ROTATE_180, ROTATE_270, ROTATE_90
@@ -17,15 +18,6 @@ from ..providers.helper import filter_none
from ..typing import ImageType, Union, Image from ..typing import ImageType, Union, Image
from ..errors import MissingRequirementsError from ..errors import MissingRequirementsError
ALLOWED_EXTENSIONS = {
# Image
'png', 'jpg', 'jpeg', 'gif', 'webp',
# Audio
'wav', 'mp3', 'flac', 'opus', 'ogg',
# Video
'mkv', 'webm', 'mp4'
}
MEDIA_TYPE_MAP: dict[str, str] = { MEDIA_TYPE_MAP: dict[str, str] = {
"image/png": "png", "image/png": "png",
"image/jpeg": "jpg", "image/jpeg": "jpg",
@@ -90,7 +82,7 @@ def to_image(image: ImageType, is_svg: bool = False) -> Image:
return image return image
def is_allowed_extension(filename: str) -> bool: def is_allowed_extension(filename: str) -> Optional[str]:
""" """
Checks if the given filename has an allowed extension. Checks if the given filename has an allowed extension.
@@ -100,8 +92,8 @@ def is_allowed_extension(filename: str) -> bool:
Returns: Returns:
bool: True if the extension is allowed, False otherwise. bool: True if the extension is allowed, False otherwise.
""" """
return '.' in filename and \ ext = os.path.splitext(filename)[1][1:].lower() if '.' in filename else None
filename.rsplit('.', 1)[1].lower() in ALLOWED_EXTENSIONS return EXTENSIONS_MAP[ext] if ext in EXTENSIONS_MAP else None
def is_data_an_media(data, filename: str = None) -> str: def is_data_an_media(data, filename: str = None) -> str:
content_type = is_data_an_audio(data, filename) content_type = is_data_an_audio(data, filename)
@@ -138,7 +130,7 @@ def is_data_uri_an_image(data_uri: str) -> bool:
# Extract the image format from the data URI # Extract the image format from the data URI
image_format = re.match(r'data:image/(\w+);base64,', data_uri).group(1).lower() image_format = re.match(r'data:image/(\w+);base64,', data_uri).group(1).lower()
# Check if the image format is one of the allowed formats (jpg, jpeg, png, gif) # Check if the image format is one of the allowed formats (jpg, jpeg, png, gif)
if image_format not in ALLOWED_EXTENSIONS and image_format != "svg+xml": if image_format not in EXTENSIONS_MAP and image_format != "svg+xml":
raise ValueError("Invalid image format (from mime file type).") raise ValueError("Invalid image format (from mime file type).")
def is_accepted_format(binary_data: bytes) -> str: def is_accepted_format(binary_data: bytes) -> str:

View File

@@ -11,7 +11,7 @@ from aiohttp import ClientSession, ClientError
from ..typing import Optional, Cookies from ..typing import Optional, Cookies
from ..requests.aiohttp import get_connector, StreamResponse from ..requests.aiohttp import get_connector, StreamResponse
from ..image import MEDIA_TYPE_MAP, ALLOWED_EXTENSIONS from ..image import MEDIA_TYPE_MAP, EXTENSIONS_MAP
from ..tools.files import get_bucket_dir from ..tools.files import get_bucket_dir
from ..providers.response import ImageResponse, AudioResponse, VideoResponse from ..providers.response import ImageResponse, AudioResponse, VideoResponse
from ..Provider.template import BackendApi from ..Provider.template import BackendApi
@@ -58,7 +58,7 @@ async def save_response_media(response: StreamResponse, prompt: str):
content_type = response.headers["content-type"] content_type = response.headers["content-type"]
if is_valid_media_type(content_type): if is_valid_media_type(content_type):
extension = MEDIA_TYPE_MAP[content_type] if content_type in MEDIA_TYPE_MAP else content_type[6:].replace("mpeg", "mp3") extension = MEDIA_TYPE_MAP[content_type] if content_type in MEDIA_TYPE_MAP else content_type[6:].replace("mpeg", "mp3")
if extension not in ALLOWED_EXTENSIONS: if extension not in EXTENSIONS_MAP:
raise ValueError(f"Unsupported media type: {content_type}") raise ValueError(f"Unsupported media type: {content_type}")
bucket_id = str(uuid.uuid4()) bucket_id = str(uuid.uuid4())
dirname = str(int(time.time())) dirname = str(int(time.time()))
@@ -86,6 +86,7 @@ async def copy_media(
headers: Optional[dict] = None, headers: Optional[dict] = None,
proxy: Optional[str] = None, proxy: Optional[str] = None,
alt: str = None, alt: str = None,
tags: list[str] = None,
add_url: bool = True, add_url: bool = True,
target: str = None, target: str = None,
ssl: bool = None ssl: bool = None
@@ -113,6 +114,7 @@ async def copy_media(
# Build safe filename with full Unicode support # Build safe filename with full Unicode support
filename = secure_filename("".join(( filename = secure_filename("".join((
f"{int(time.time())}_", f"{int(time.time())}_",
(f"{''.join(tags, '_')}_" if tags else ""),
(f"{alt}_" if alt else ""), (f"{alt}_" if alt else ""),
f"{hashlib.sha256(image.encode()).hexdigest()[:16]}", f"{hashlib.sha256(image.encode()).hexdigest()[:16]}",
f"{get_media_extension(image)}" f"{get_media_extension(image)}"

View File

@@ -18,6 +18,7 @@ from .Provider import (
FreeGpt, FreeGpt,
HuggingSpace, HuggingSpace,
G4F, G4F,
Grok,
DeepseekAI_JanusPro7b, DeepseekAI_JanusPro7b,
Glider, Glider,
Goabror, Goabror,
@@ -356,19 +357,19 @@ gemini_1_5_pro = Model(
gemini_2_0_flash = Model( gemini_2_0_flash = Model(
name = 'gemini-2.0-flash', name = 'gemini-2.0-flash',
base_provider = 'Google DeepMind', base_provider = 'Google DeepMind',
best_provider = IterListProvider([Dynaspark, GeminiPro, Liaobots]) best_provider = IterListProvider([Dynaspark, GeminiPro, Gemini])
) )
gemini_2_0_flash_thinking = Model( gemini_2_0_flash_thinking = Model(
name = 'gemini-2.0-flash-thinking', name = 'gemini-2.0-flash-thinking',
base_provider = 'Google DeepMind', base_provider = 'Google DeepMind',
best_provider = Liaobots best_provider = Gemini
) )
gemini_2_0_pro = Model( gemini_2_0_flash_thinking_with_apps = Model(
name = 'gemini-2.0-pro', name = 'gemini-2.0-flash-thinking-with-apps',
base_provider = 'Google DeepMind', base_provider = 'Google DeepMind',
best_provider = Liaobots best_provider = Gemini
) )
### Anthropic ### ### Anthropic ###
@@ -379,19 +380,6 @@ claude_3_haiku = Model(
best_provider = IterListProvider([DDG, Jmuz]) best_provider = IterListProvider([DDG, Jmuz])
) )
claude_3_sonnet = Model(
name = 'claude-3-sonnet',
base_provider = 'Anthropic',
best_provider = Liaobots
)
claude_3_opus = Model(
name = 'claude-3-opus',
base_provider = 'Anthropic',
best_provider = IterListProvider([Jmuz, Liaobots])
)
# claude 3.5 # claude 3.5
claude_3_5_sonnet = Model( claude_3_5_sonnet = Model(
name = 'claude-3.5-sonnet', name = 'claude-3.5-sonnet',
@@ -406,12 +394,6 @@ claude_3_7_sonnet = Model(
best_provider = IterListProvider([Blackbox, Liaobots]) best_provider = IterListProvider([Blackbox, Liaobots])
) )
claude_3_7_sonnet_thinking = Model(
name = 'claude-3.7-sonnet-thinking',
base_provider = 'Anthropic',
best_provider = Liaobots
)
### Reka AI ### ### Reka AI ###
reka_core = Model( reka_core = Model(
name = 'reka-core', name = 'reka-core',
@@ -548,13 +530,13 @@ janus_pro_7b = VisionModel(
grok_3 = Model( grok_3 = Model(
name = 'grok-3', name = 'grok-3',
base_provider = 'x.ai', base_provider = 'x.ai',
best_provider = Liaobots best_provider = Grok
) )
grok_3_r1 = Model( grok_3_r1 = Model(
name = 'grok-3-r1', name = 'grok-3-r1',
base_provider = 'x.ai', base_provider = 'x.ai',
best_provider = Liaobots best_provider = Grok
) )
### Perplexity AI ### ### Perplexity AI ###
@@ -841,12 +823,10 @@ class ModelUtils:
gemini_1_5_flash.name: gemini_1_5_flash, gemini_1_5_flash.name: gemini_1_5_flash,
gemini_2_0_flash.name: gemini_2_0_flash, gemini_2_0_flash.name: gemini_2_0_flash,
gemini_2_0_flash_thinking.name: gemini_2_0_flash_thinking, gemini_2_0_flash_thinking.name: gemini_2_0_flash_thinking,
gemini_2_0_pro.name: gemini_2_0_pro, gemini_2_0_flash_thinking_with_apps.name: gemini_2_0_flash_thinking_with_apps,
### Anthropic ### ### Anthropic ###
# claude 3 # claude 3
claude_3_opus.name: claude_3_opus,
claude_3_sonnet.name: claude_3_sonnet,
claude_3_haiku.name: claude_3_haiku, claude_3_haiku.name: claude_3_haiku,
# claude 3.5 # claude 3.5
@@ -854,7 +834,6 @@ class ModelUtils:
# claude 3.7 # claude 3.7
claude_3_7_sonnet.name: claude_3_7_sonnet, claude_3_7_sonnet.name: claude_3_7_sonnet,
claude_3_7_sonnet_thinking.name: claude_3_7_sonnet_thinking,
### Reka AI ### ### Reka AI ###
reka_core.name: reka_core, reka_core.name: reka_core,

View File

@@ -366,11 +366,15 @@ class ProviderModelMixin:
class RaiseErrorMixin(): class RaiseErrorMixin():
@staticmethod @staticmethod
def raise_error(data: dict): def raise_error(data: dict, status: int = None):
if "error_message" in data: if "error_message" in data:
raise ResponseError(data["error_message"]) raise ResponseError(data["error_message"])
elif "error" in data: elif "error" in data:
if isinstance(data["error"], str): if isinstance(data["error"], str):
if status is not None:
if status in (401, 402):
raise MissingAuthError(f"Error {status}: {data['error']}")
raise ResponseError(f"Error {status}: {data['error']}")
raise ResponseError(data["error"]) raise ResponseError(data["error"])
elif "code" in data["error"]: elif "code" in data["error"]:
raise ResponseError("\n".join( raise ResponseError("\n".join(

View File

@@ -4,7 +4,7 @@ from typing import Union
from aiohttp import ClientResponse from aiohttp import ClientResponse
from requests import Response as RequestsResponse from requests import Response as RequestsResponse
from ..errors import ResponseStatusError, RateLimitError from ..errors import ResponseStatusError, RateLimitError, MissingAuthError
from . import Response, StreamResponse from . import Response, StreamResponse
class CloudflareError(ResponseStatusError): class CloudflareError(ResponseStatusError):
@@ -23,6 +23,7 @@ def is_openai(text: str) -> bool:
async def raise_for_status_async(response: Union[StreamResponse, ClientResponse], message: str = None): async def raise_for_status_async(response: Union[StreamResponse, ClientResponse], message: str = None):
if response.ok: if response.ok:
return return
is_html = False
if message is None: if message is None:
content_type = response.headers.get("content-type", "") content_type = response.headers.get("content-type", "")
if content_type.startswith("application/json"): if content_type.startswith("application/json"):
@@ -31,39 +32,42 @@ async def raise_for_status_async(response: Union[StreamResponse, ClientResponse]
if isinstance(message, dict): if isinstance(message, dict):
message = message.get("message", message) message = message.get("message", message)
else: else:
text = (await response.text()).strip() message = (await response.text()).strip()
is_html = content_type.startswith("text/html") or text.startswith("<!DOCTYPE") is_html = content_type.startswith("text/html") or text.startswith("<!DOCTYPE")
message = "HTML content" if is_html else text if message is None or is_html:
if message is None or message == "HTML content":
if response.status == 520: if response.status == 520:
message = "Unknown error (Cloudflare)" message = "Unknown error (Cloudflare)"
elif response.status in (429, 402): elif response.status in (429, 402):
message = "Rate limit" message = "Rate limit"
if response.status == 403 and is_cloudflare(text): if response.status in (401, 402):
raise MissingAuthError(f"Response {response.status}: {message}")
if response.status == 403 and is_cloudflare(message):
raise CloudflareError(f"Response {response.status}: Cloudflare detected") raise CloudflareError(f"Response {response.status}: Cloudflare detected")
elif response.status == 403 and is_openai(text): elif response.status == 403 and is_openai(message):
raise ResponseStatusError(f"Response {response.status}: OpenAI Bot detected") raise ResponseStatusError(f"Response {response.status}: OpenAI Bot detected")
elif response.status == 502: elif response.status == 502:
raise ResponseStatusError(f"Response {response.status}: Bad Gateway") raise ResponseStatusError(f"Response {response.status}: Bad Gateway")
elif response.status == 504: elif response.status == 504:
raise RateLimitError(f"Response {response.status}: Gateway Timeout ") raise RateLimitError(f"Response {response.status}: Gateway Timeout ")
else: else:
raise ResponseStatusError(f"Response {response.status}: {message}") raise ResponseStatusError(f"Response {response.status}: {"HTML content" if is_html else message}")
def raise_for_status(response: Union[Response, StreamResponse, ClientResponse, RequestsResponse], message: str = None): def raise_for_status(response: Union[Response, StreamResponse, ClientResponse, RequestsResponse], message: str = None):
if hasattr(response, "status"): if hasattr(response, "status"):
return raise_for_status_async(response, message) return raise_for_status_async(response, message)
if response.ok: if response.ok:
return return
is_html = False
if message is None: if message is None:
is_html = response.headers.get("content-type", "").startswith("text/html") or response.text.startswith("<!DOCTYPE") is_html = response.headers.get("content-type", "").startswith("text/html") or response.text.startswith("<!DOCTYPE")
message = "HTML content" if is_html else response.text message = response.text
if message == "HTML content": if message is None or is_html:
if response.status_code == 520: if response.status_code == 520:
message = "Unknown error (Cloudflare)" message = "Unknown error (Cloudflare)"
elif response.status_code in (429, 402): elif response.status_code in (429, 402):
message = "Rate limit" raise RateLimitError(f"Response {response.status_code}: Rate Limit")
raise RateLimitError(f"Response {response.status_code}: {message}") if response.status_code in (401, 402):
raise MissingAuthError(f"Response {response.status_code}: {message}")
if response.status_code == 403 and is_cloudflare(response.text): if response.status_code == 403 and is_cloudflare(response.text):
raise CloudflareError(f"Response {response.status_code}: Cloudflare detected") raise CloudflareError(f"Response {response.status_code}: Cloudflare detected")
elif response.status_code == 403 and is_openai(response.text): elif response.status_code == 403 and is_openai(response.text):
@@ -73,4 +77,4 @@ def raise_for_status(response: Union[Response, StreamResponse, ClientResponse, R
elif response.status_code == 504: elif response.status_code == 504:
raise RateLimitError(f"Response {response.status_code}: Gateway Timeout ") raise RateLimitError(f"Response {response.status_code}: Gateway Timeout ")
else: else:
raise ResponseStatusError(f"Response {response.status_code}: {message}") raise ResponseStatusError(f"Response {response.status_code}: {"HTML content" if is_html else message}")