mirror of
https://github.com/xtekky/gpt4free.git
synced 2025-10-18 06:10:44 +08:00
Set default model in HuggingFaceMedia
Improve handling of shared chats Show api_key input if required
This commit is contained in:
19
etc/examples/video.py
Normal file
19
etc/examples/video.py
Normal file
@@ -0,0 +1,19 @@
|
||||
import g4f.Provider
|
||||
from g4f.client import Client
|
||||
|
||||
client = Client(
|
||||
provider=g4f.Provider.HuggingFaceMedia,
|
||||
api_key="hf_***" # Your API key here
|
||||
)
|
||||
|
||||
video_models = client.models.get_video()
|
||||
|
||||
print(video_models)
|
||||
|
||||
result = client.media.generate(
|
||||
model=video_models[0],
|
||||
prompt="G4F AI technology is the best in the world.",
|
||||
response_format="url"
|
||||
)
|
||||
|
||||
print(result.data[0].url)
|
@@ -66,6 +66,8 @@ class HuggingFaceMedia(AsyncGeneratorProvider, ProviderModelMixin):
|
||||
for provider_data in provider_keys:
|
||||
prepend_models.append(f"{model}:{provider_data.get('provider')}")
|
||||
cls.models = prepend_models + [model for model in new_models if model not in prepend_models]
|
||||
cls.image_models = [model for model, task in cls.task_mapping.items() if task == "text-to-image"]
|
||||
cls.video_models = [model for model, task in cls.task_mapping.items() if task == "text-to-video"]
|
||||
else:
|
||||
cls.models = []
|
||||
return cls.models
|
||||
@@ -99,12 +101,14 @@ class HuggingFaceMedia(AsyncGeneratorProvider, ProviderModelMixin):
|
||||
prompt: str = None,
|
||||
proxy: str = None,
|
||||
timeout: int = 0,
|
||||
aspect_ratio: str = "1:1",
|
||||
aspect_ratio: str = None,
|
||||
**kwargs
|
||||
):
|
||||
selected_provider = None
|
||||
if ":" in model:
|
||||
if model and ":" in model:
|
||||
model, selected_provider = model.split(":", 1)
|
||||
elif not model:
|
||||
model = cls.get_models()[0]
|
||||
provider_mapping = await cls.get_mapping(model, api_key)
|
||||
headers = {
|
||||
'Accept-Encoding': 'gzip, deflate',
|
||||
@@ -133,11 +137,11 @@ class HuggingFaceMedia(AsyncGeneratorProvider, ProviderModelMixin):
|
||||
extra_data = {
|
||||
"num_inference_steps": 20,
|
||||
"resolution": "480p",
|
||||
"aspect_ratio": aspect_ratio,
|
||||
"aspect_ratio": "16:9" if aspect_ratio is None else aspect_ratio,
|
||||
**extra_data
|
||||
}
|
||||
else:
|
||||
extra_data = use_aspect_ratio(extra_data, aspect_ratio)
|
||||
extra_data = use_aspect_ratio(extra_data, "1:1" if aspect_ratio is None else aspect_ratio)
|
||||
if provider_key == "fal-ai":
|
||||
url = f"{api_base}/{provider_id}"
|
||||
data = {
|
||||
|
@@ -30,6 +30,7 @@ class Grok(AsyncAuthedProvider, ProviderModelMixin):
|
||||
|
||||
default_model = "grok-3"
|
||||
models = [default_model, "grok-3-thinking", "grok-2"]
|
||||
model_aliases = {"grok-3-r1": "grok-3-thinking"}
|
||||
|
||||
@classmethod
|
||||
async def on_auth_async(cls, cookies: Cookies = None, proxy: str = None, **kwargs) -> AsyncIterator:
|
||||
@@ -73,7 +74,7 @@ class Grok(AsyncAuthedProvider, ProviderModelMixin):
|
||||
"sendFinalMetadata": True,
|
||||
"customInstructions": "",
|
||||
"deepsearchPreset": "",
|
||||
"isReasoning": model.endswith("-thinking"),
|
||||
"isReasoning": model.endswith("-thinking") or model.endswith("-r1"),
|
||||
}
|
||||
|
||||
@classmethod
|
||||
|
@@ -92,7 +92,7 @@ class OpenaiTemplate(AsyncGeneratorProvider, ProviderModelMixin, RaiseErrorMixin
|
||||
}
|
||||
async with session.post(f"{api_base.rstrip('/')}/images/generations", json=data, ssl=cls.ssl) as response:
|
||||
data = await response.json()
|
||||
cls.raise_error(data)
|
||||
cls.raise_error(data, response.status)
|
||||
await raise_for_status(response)
|
||||
yield ImageResponse([image["url"] for image in data["data"]], prompt)
|
||||
return
|
||||
@@ -135,7 +135,7 @@ class OpenaiTemplate(AsyncGeneratorProvider, ProviderModelMixin, RaiseErrorMixin
|
||||
content_type = response.headers.get("content-type", "text/event-stream" if stream else "application/json")
|
||||
if content_type.startswith("application/json"):
|
||||
data = await response.json()
|
||||
cls.raise_error(data)
|
||||
cls.raise_error(data, response.status)
|
||||
await raise_for_status(response)
|
||||
choice = data["choices"][0]
|
||||
if "content" in choice["message"] and choice["message"]["content"]:
|
||||
|
@@ -10,6 +10,7 @@ from email.utils import formatdate
|
||||
import os.path
|
||||
import hashlib
|
||||
import asyncio
|
||||
from urllib.parse import quote_plus
|
||||
from fastapi import FastAPI, Response, Request, UploadFile, Depends
|
||||
from fastapi.middleware.wsgi import WSGIMiddleware
|
||||
from fastapi.responses import StreamingResponse, RedirectResponse, HTMLResponse, JSONResponse
|
||||
@@ -562,6 +563,10 @@ class Api:
|
||||
})
|
||||
async def get_media(filename, request: Request):
|
||||
target = os.path.join(images_dir, os.path.basename(filename))
|
||||
if not os.path.isfile(target):
|
||||
other_name = os.path.join(images_dir, os.path.basename(quote_plus(filename)))
|
||||
if os.path.isfile(other_name):
|
||||
target = other_name
|
||||
ext = os.path.splitext(filename)[1][1:]
|
||||
mime_type = EXTENSIONS_MAP.get(ext)
|
||||
stat_result = SimpleNamespace()
|
||||
|
@@ -19,7 +19,7 @@ from ..providers.asyncio import to_sync_generator
|
||||
from ..Provider.needs_auth import BingCreateImages, OpenaiAccount
|
||||
from ..tools.run_tools import async_iter_run_tools, iter_run_tools
|
||||
from .stubs import ChatCompletion, ChatCompletionChunk, Image, ImagesResponse, UsageModel, ToolCallModel
|
||||
from .image_models import ImageModels
|
||||
from .image_models import MediaModels
|
||||
from .types import IterResponse, ImageProvider, Client as BaseClient
|
||||
from .service import get_model_and_provider, convert_to_provider
|
||||
from .helper import find_stop, filter_json, filter_none, safe_aclose
|
||||
@@ -267,8 +267,11 @@ class Client(BaseClient):
|
||||
) -> None:
|
||||
super().__init__(**kwargs)
|
||||
self.chat: Chat = Chat(self, provider)
|
||||
if image_provider is None:
|
||||
image_provider = provider
|
||||
self.models: MediaModels = MediaModels(self, image_provider)
|
||||
self.images: Images = Images(self, image_provider)
|
||||
self.media: Images = Images(self, image_provider)
|
||||
self.media: Images = self.images
|
||||
|
||||
class Completions:
|
||||
def __init__(self, client: Client, provider: Optional[ProviderType] = None):
|
||||
@@ -349,7 +352,6 @@ class Images:
|
||||
def __init__(self, client: Client, provider: Optional[ProviderType] = None):
|
||||
self.client: Client = client
|
||||
self.provider: Optional[ProviderType] = provider
|
||||
self.models: ImageModels = ImageModels(client)
|
||||
|
||||
def generate(
|
||||
self,
|
||||
@@ -369,7 +371,7 @@ class Images:
|
||||
if provider is None:
|
||||
provider_handler = self.provider
|
||||
if provider_handler is None:
|
||||
provider_handler = self.models.get(model, default)
|
||||
provider_handler = self.client.models.get(model, default)
|
||||
elif isinstance(provider, str):
|
||||
provider_handler = convert_to_provider(provider)
|
||||
else:
|
||||
@@ -385,19 +387,21 @@ class Images:
|
||||
provider: Optional[ProviderType] = None,
|
||||
response_format: Optional[str] = None,
|
||||
proxy: Optional[str] = None,
|
||||
api_key: Optional[str] = None,
|
||||
**kwargs
|
||||
) -> ImagesResponse:
|
||||
provider_handler = await self.get_provider_handler(model, provider, BingCreateImages)
|
||||
provider_name = provider_handler.__name__ if hasattr(provider_handler, "__name__") else type(provider_handler).__name__
|
||||
if proxy is None:
|
||||
proxy = self.client.proxy
|
||||
|
||||
if api_key is None:
|
||||
api_key = self.client.api_key
|
||||
error = None
|
||||
response = None
|
||||
if isinstance(provider_handler, IterListProvider):
|
||||
for provider in provider_handler.providers:
|
||||
try:
|
||||
response = await self._generate_image_response(provider, provider.__name__, model, prompt, **kwargs)
|
||||
response = await self._generate_image_response(provider, provider.__name__, model, prompt, proxy=proxy, **kwargs)
|
||||
if response is not None:
|
||||
provider_name = provider.__name__
|
||||
break
|
||||
@@ -405,7 +409,7 @@ class Images:
|
||||
error = e
|
||||
debug.error(f"{provider.__name__} {type(e).__name__}: {e}")
|
||||
else:
|
||||
response = await self._generate_image_response(provider_handler, provider_name, model, prompt, **kwargs)
|
||||
response = await self._generate_image_response(provider_handler, provider_name, model, prompt, proxy=proxy, api_key=api_key, **kwargs)
|
||||
|
||||
if isinstance(response, MediaResponse):
|
||||
return await self._process_image_response(
|
||||
@@ -534,7 +538,7 @@ class Images:
|
||||
else:
|
||||
# Save locally for None (default) case
|
||||
images = await copy_media(response.get_list(), response.get("cookies"), proxy)
|
||||
images = [Image.model_construct(url=f"/media/{os.path.basename(image)}", revised_prompt=response.alt) for image in images]
|
||||
images = [Image.model_construct(url=image, revised_prompt=response.alt) for image in images]
|
||||
|
||||
return ImagesResponse.model_construct(
|
||||
created=int(time.time()),
|
||||
@@ -552,6 +556,9 @@ class AsyncClient(BaseClient):
|
||||
) -> None:
|
||||
super().__init__(**kwargs)
|
||||
self.chat: AsyncChat = AsyncChat(self, provider)
|
||||
if image_provider is None:
|
||||
image_provider = provider
|
||||
self.models: MediaModels = MediaModels(self, image_provider)
|
||||
self.images: AsyncImages = AsyncImages(self, image_provider)
|
||||
self.media: AsyncImages = self.images
|
||||
|
||||
@@ -635,7 +642,6 @@ class AsyncImages(Images):
|
||||
def __init__(self, client: AsyncClient, provider: Optional[ProviderType] = None):
|
||||
self.client: AsyncClient = client
|
||||
self.provider: Optional[ProviderType] = provider
|
||||
self.models: ImageModels = ImageModels(client)
|
||||
|
||||
async def generate(
|
||||
self,
|
||||
|
@@ -1,15 +1,43 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from ..models import ModelUtils
|
||||
from ..models import ModelUtils, ImageModel
|
||||
from ..Provider import ProviderUtils
|
||||
from ..providers.types import ProviderType
|
||||
|
||||
class ImageModels():
|
||||
def __init__(self, client):
|
||||
class MediaModels():
|
||||
def __init__(self, client, provider: ProviderType = None):
|
||||
self.client = client
|
||||
self.provider = provider
|
||||
|
||||
def get(self, name, default=None):
|
||||
def get(self, name, default=None) -> ProviderType:
|
||||
if name in ModelUtils.convert:
|
||||
return ModelUtils.convert[name].best_provider
|
||||
if name in ProviderUtils.convert:
|
||||
return ProviderUtils.convert[name]
|
||||
return default
|
||||
|
||||
def get_all(self, api_key: str = None, **kwargs) -> list[str]:
|
||||
if self.provider is None:
|
||||
return []
|
||||
if api_key is None:
|
||||
api_key = self.client.api_key
|
||||
return self.provider.get_models(
|
||||
**kwargs,
|
||||
**{} if api_key is None else {"api_key": api_key}
|
||||
)
|
||||
|
||||
def get_image(self, **kwargs) -> list[str]:
|
||||
if self.provider is None:
|
||||
return [model_id for model_id, model in ModelUtils.convert.items() if isinstance(model, ImageModel)]
|
||||
self.get_all(**kwargs)
|
||||
if hasattr(self.provider, "image_models"):
|
||||
return self.provider.image_models
|
||||
return []
|
||||
|
||||
def get_video(self, **kwargs) -> list[str]:
|
||||
if self.provider is None:
|
||||
return []
|
||||
self.get_all(**kwargs)
|
||||
if hasattr(self.provider, "video_models"):
|
||||
return self.provider.video_models
|
||||
return []
|
@@ -61,9 +61,12 @@
|
||||
const gpt_image = '<img src="/static/img/gpt.png" alt="your avatar">';
|
||||
</script>
|
||||
<script src="/static/js/highlight.min.js" async></script>
|
||||
|
||||
<script>window.conversation_id = "{{conversation_id}}"</script>
|
||||
<script>window.chat_id = "{{chat_id}}"; window.share_url = "{{share_url}}";</script>
|
||||
<script>
|
||||
window.conversation_id = "{{conversation_id}}";
|
||||
window.chat_id = "{{chat_id}}";
|
||||
window.share_url = "{{share_url}}";
|
||||
window.start_id = "{{conversation_id}}";
|
||||
</script>
|
||||
<title>G4F Chat</title>
|
||||
</head>
|
||||
<body>
|
||||
|
@@ -7,9 +7,9 @@
|
||||
<script src="https://cdn.jsdelivr.net/npm/qrcodejs/qrcode.min.js"></script>
|
||||
<style>
|
||||
body { font-family: Arial, sans-serif; text-align: center; margin: 20px; }
|
||||
video { width: 400px; height: 400px; border: 1px solid black; display: block; margin: auto; object-fit: cover;}
|
||||
video { width: 400px; height: 400px; border: 1px solid black; display: block; margin: auto; object-fit: cover; max-width: 100%;}
|
||||
#qrcode { margin-top: 20px; }
|
||||
#qrcode img, #qrcode canvas { margin: 0 auto; width: 400px; height: 400px; }
|
||||
#qrcode img, #qrcode canvas { margin: 0 auto; width: 400px; height: 400px; max-width: 100%;}
|
||||
button { margin: 5px; padding: 10px; }
|
||||
</style>
|
||||
</head>
|
||||
|
@@ -881,7 +881,7 @@ input.model:hover
|
||||
padding: var(--inner-gap) 28px;
|
||||
}
|
||||
|
||||
#systemPrompt, #chatPrompt, .settings textarea, form textarea {
|
||||
#systemPrompt, #chatPrompt, .settings textarea, form textarea, .chat-body textarea {
|
||||
font-size: 15px;
|
||||
color: var(--colour-3);
|
||||
outline: none;
|
||||
@@ -1305,7 +1305,7 @@ form textarea {
|
||||
padding: 0;
|
||||
}
|
||||
|
||||
.settings textarea {
|
||||
.settings textarea, .chat-body textarea {
|
||||
height: 30px;
|
||||
min-height: 30px;
|
||||
padding: 6px;
|
||||
@@ -1315,7 +1315,7 @@ form textarea {
|
||||
text-wrap: nowrap;
|
||||
}
|
||||
|
||||
form .field .fa-xmark {
|
||||
.field .fa-xmark {
|
||||
line-height: 20px;
|
||||
cursor: pointer;
|
||||
margin-left: auto;
|
||||
@@ -1323,11 +1323,11 @@ form .field .fa-xmark {
|
||||
margin-top: 0;
|
||||
}
|
||||
|
||||
form .field.saved .fa-xmark {
|
||||
.field.saved .fa-xmark {
|
||||
color: var(--accent)
|
||||
}
|
||||
|
||||
.settings .field, form .field {
|
||||
.settings .field, form .field, .chat-body .field {
|
||||
padding: var(--inner-gap) var(--inner-gap) var(--inner-gap) 0;
|
||||
}
|
||||
|
||||
@@ -1359,7 +1359,7 @@ form .field.saved .fa-xmark {
|
||||
border: none;
|
||||
}
|
||||
|
||||
.settings input, form input {
|
||||
.settings input, form input, .chat-body input {
|
||||
background-color: transparent;
|
||||
padding: 2px;
|
||||
border: none;
|
||||
@@ -1368,11 +1368,11 @@ form .field.saved .fa-xmark {
|
||||
color: var(--colour-3);
|
||||
}
|
||||
|
||||
.settings input:focus, form input:focus {
|
||||
.settings input:focus, form input:focus, .chat-body input:focus {
|
||||
outline: none;
|
||||
}
|
||||
|
||||
.settings .label, form .label, .settings label, form label {
|
||||
.settings .label, form .label, .settings label, form label, .chat-body label {
|
||||
font-size: 15px;
|
||||
margin-left: var(--inner-gap);
|
||||
}
|
||||
|
@@ -28,7 +28,7 @@ const switchInput = document.getElementById("switch");
|
||||
const searchButton = document.getElementById("search");
|
||||
const paperclip = document.querySelector(".user-input .fa-paperclip");
|
||||
|
||||
const optionElementsSelector = ".settings input, .settings textarea, #model, #model2, #provider";
|
||||
const optionElementsSelector = ".settings input, .settings textarea, .chat-body input, #model, #model2, #provider";
|
||||
|
||||
let provider_storage = {};
|
||||
let message_storage = {};
|
||||
@@ -153,7 +153,7 @@ const iframe_close = Object.assign(document.createElement("button"), {
|
||||
});
|
||||
iframe_close.onclick = () => iframe_container.classList.add("hidden");
|
||||
iframe_container.appendChild(iframe_close);
|
||||
chat.appendChild(iframe_container);
|
||||
document.body.appendChild(iframe_container);
|
||||
|
||||
class HtmlRenderPlugin {
|
||||
constructor(options = {}) {
|
||||
@@ -843,6 +843,16 @@ async function add_message_chunk(message, message_id, provider, scroll, finish_m
|
||||
conversation.data[key] = value;
|
||||
}
|
||||
await save_conversation(conversation_id, conversation);
|
||||
} else if (message.type == "auth") {
|
||||
error_storage[message_id] = message.message
|
||||
content_map.inner.innerHTML += markdown_render(`**An error occured:** ${message.message}`);
|
||||
let provider = provider_storage[message_id]?.name;
|
||||
let configEl = document.querySelector(`.settings .${provider}-api_key`);
|
||||
if (configEl) {
|
||||
configEl = configEl.parentElement.cloneNode(true);
|
||||
content_map.content.appendChild(configEl);
|
||||
await register_settings_storage();
|
||||
}
|
||||
} else if (message.type == "provider") {
|
||||
provider_storage[message_id] = message.provider;
|
||||
let provider_el = content_map.content.querySelector('.provider');
|
||||
@@ -1122,10 +1132,6 @@ const ask_gpt = async (message_id, message_index = -1, regenerate = false, provi
|
||||
let api_key;
|
||||
if (is_demo && !provider) {
|
||||
api_key = localStorage.getItem("HuggingFace-api_key");
|
||||
if (!api_key) {
|
||||
location.href = "/";
|
||||
return;
|
||||
}
|
||||
} else {
|
||||
api_key = get_api_key_by_provider(provider);
|
||||
}
|
||||
@@ -1221,6 +1227,7 @@ function sanitize(input, replacement) {
|
||||
}
|
||||
|
||||
async function set_conversation_title(conversation_id, title) {
|
||||
window.chat_id = null;
|
||||
conversation = await get_conversation(conversation_id)
|
||||
conversation.new_title = title;
|
||||
const new_id = sanitize(title, " ");
|
||||
@@ -1742,12 +1749,22 @@ const load_conversations = async () => {
|
||||
|
||||
let html = [];
|
||||
conversations.forEach((conversation) => {
|
||||
// const length = conversation.items.map((item) => (
|
||||
// !item.content.toLowerCase().includes("hello") &&
|
||||
// !item.content.toLowerCase().includes("hi") &&
|
||||
// item.content
|
||||
// ) ? 1 : 0).reduce((a,b)=>a+b, 0);
|
||||
// if (!length) {
|
||||
// appStorage.removeItem(`conversation:${conversation.id}`);
|
||||
// return;
|
||||
// }
|
||||
const shareIcon = (conversation.id == window.start_id && window.chat_id) ? '<i class="fa-solid fa-qrcode"></i>': '';
|
||||
html.push(`
|
||||
<div class="convo" id="convo-${conversation.id}">
|
||||
<div class="left" onclick="set_conversation('${conversation.id}')">
|
||||
<i class="fa-regular fa-comments"></i>
|
||||
<span class="datetime">${conversation.updated ? toLocaleDateString(conversation.updated) : ""}</span>
|
||||
<span class="convo-title">${escapeHtml(conversation.new_title ? conversation.new_title : conversation.title)}</span>
|
||||
<span class="convo-title">${shareIcon} ${escapeHtml(conversation.new_title ? conversation.new_title : conversation.title)}</span>
|
||||
</div>
|
||||
<i onclick="show_option('${conversation.id}')" class="fa-solid fa-ellipsis-vertical" id="conv-${conversation.id}"></i>
|
||||
<div id="cho-${conversation.id}" class="choise" style="display:none;">
|
||||
@@ -2060,7 +2077,6 @@ window.addEventListener('load', async function() {
|
||||
if (!window.conversation_id) {
|
||||
window.conversation_id = window.chat_id;
|
||||
}
|
||||
window.start_id = window.conversation_id
|
||||
const response = await fetch(`${window.share_url}/backend-api/v2/chat/${window.chat_id ? window.chat_id : window.conversation_id}`, {
|
||||
headers: {'accept': 'application/json'},
|
||||
});
|
||||
@@ -2075,6 +2091,7 @@ window.addEventListener('load', async function() {
|
||||
`conversation:${conversation.id}`,
|
||||
JSON.stringify(conversation)
|
||||
);
|
||||
await load_conversations();
|
||||
let refreshOnHide = true;
|
||||
document.addEventListener("visibilitychange", () => {
|
||||
if (document.hidden) {
|
||||
@@ -2091,6 +2108,9 @@ window.addEventListener('load', async function() {
|
||||
if (!refreshOnHide) {
|
||||
return;
|
||||
}
|
||||
if (window.conversation_id != window.start_id) {
|
||||
return;
|
||||
}
|
||||
const response = await fetch(`${window.share_url}/backend-api/v2/chat/${window.chat_id}`, {
|
||||
headers: {'accept': 'application/json', 'if-none-match': conversation.updated},
|
||||
});
|
||||
@@ -2102,6 +2122,7 @@ window.addEventListener('load', async function() {
|
||||
`conversation:${conversation.id}`,
|
||||
JSON.stringify(conversation)
|
||||
);
|
||||
await load_conversations();
|
||||
await load_conversation(conversation);
|
||||
}
|
||||
}
|
||||
@@ -2284,7 +2305,7 @@ async function on_api() {
|
||||
}
|
||||
} else if (provider.login_url) {
|
||||
if (!login_urls[provider.name]) {
|
||||
login_urls[provider.name] = [provider.label, provider.login_url, [], provider.auth];
|
||||
login_urls[provider.name] = [provider.label, provider.login_url, [provider.name], provider.auth];
|
||||
} else {
|
||||
login_urls[provider.name][0] = provider.label;
|
||||
login_urls[provider.name][1] = provider.login_url;
|
||||
|
@@ -7,7 +7,7 @@ from typing import Iterator
|
||||
from flask import send_from_directory
|
||||
from inspect import signature
|
||||
|
||||
from ...errors import VersionNotFoundError
|
||||
from ...errors import VersionNotFoundError, MissingAuthError
|
||||
from ...image.copy_images import copy_media, ensure_images_dir, images_dir
|
||||
from ...tools.run_tools import iter_run_tools
|
||||
from ...Provider import ProviderUtils, __providers__
|
||||
@@ -187,7 +187,8 @@ class Api:
|
||||
media = chunk
|
||||
if download_media or chunk.get("cookies"):
|
||||
chunk.alt = format_image_prompt(kwargs.get("messages"), chunk.alt)
|
||||
media = asyncio.run(copy_media(chunk.get_list(), chunk.get("cookies"), chunk.get("headers"), proxy=proxy, alt=chunk.alt))
|
||||
tags = [tag for tag in [model, kwargs.get("aspect_ratio")] if tag]
|
||||
media = asyncio.run(copy_media(chunk.get_list(), chunk.get("cookies"), chunk.get("headers"), proxy=proxy, alt=chunk.alt, tags=tags))
|
||||
media = ImageResponse(media, chunk.alt) if isinstance(chunk, ImageResponse) else VideoResponse(media, chunk.alt)
|
||||
yield self._format_json("content", str(media), images=chunk.get_list(), alt=chunk.alt)
|
||||
elif isinstance(chunk, SynthesizeData):
|
||||
@@ -214,6 +215,8 @@ class Api:
|
||||
yield self._format_json(chunk.type, **chunk.get_dict())
|
||||
else:
|
||||
yield self._format_json("content", str(chunk))
|
||||
except MissingAuthError as e:
|
||||
yield self._format_json('auth', type(e).__name__, message=get_error_message(e))
|
||||
except Exception as e:
|
||||
logger.exception(e)
|
||||
debug.error(e)
|
||||
|
@@ -16,7 +16,6 @@ from pathlib import Path
|
||||
from urllib.parse import quote_plus
|
||||
from hashlib import sha256
|
||||
|
||||
from ...image import is_allowed_extension
|
||||
from ...client.service import convert_to_provider
|
||||
from ...providers.asyncio import to_sync_generator
|
||||
from ...client.helper import filter_markdown
|
||||
@@ -25,7 +24,7 @@ from ...tools.run_tools import iter_run_tools
|
||||
from ...errors import ProviderNotFoundError
|
||||
from ...image import is_allowed_extension
|
||||
from ...cookies import get_cookies_dir
|
||||
from ...image.copy_images import secure_filename, get_source_url
|
||||
from ...image.copy_images import secure_filename, get_source_url, images_dir
|
||||
from ... import ChatCompletion
|
||||
from ... import models
|
||||
from .api import Api
|
||||
@@ -351,9 +350,30 @@ class Backend_Api(Api):
|
||||
return redirect(source_url)
|
||||
raise
|
||||
|
||||
@app.route('/files/<dirname>/<bucket_id>/media/<filename>', methods=['GET'])
|
||||
def get_media_sub(dirname, bucket_id, filename):
|
||||
return get_media(bucket_id, filename, dirname)
|
||||
@app.route('/search/<search>', methods=['GET'])
|
||||
def find_media(search: str, min: int = None):
|
||||
search = [secure_filename(chunk.lower()) for chunk in search.split("+")]
|
||||
if min is None:
|
||||
min = len(search)
|
||||
if not os.access(images_dir, os.R_OK):
|
||||
return jsonify({"error": {"message": "Not found"}}), 404
|
||||
match_files = {}
|
||||
for root, _, files in os.walk(images_dir):
|
||||
for file in files:
|
||||
mime_type = is_allowed_extension(file)
|
||||
if mime_type is not None:
|
||||
mime_type = secure_filename(mime_type)
|
||||
for tag in search:
|
||||
if tag in mime_type:
|
||||
match_files[file] = match_files.get(file, 0) + 1
|
||||
break
|
||||
for tag in search:
|
||||
if tag in file.lower():
|
||||
match_files[file] = match_files.get(file, 0) + 1
|
||||
match_files = [file for file, count in match_files.items() if count >= min]
|
||||
if not match_files:
|
||||
return jsonify({"error": {"message": "Not found"}}), 404
|
||||
return redirect(f"/media/{random.choice(match_files)}")
|
||||
|
||||
@app.route('/backend-api/v2/upload_cookies', methods=['POST'])
|
||||
def upload_cookies():
|
||||
@@ -371,7 +391,7 @@ class Backend_Api(Api):
|
||||
@self.app.route('/backend-api/v2/chat/<chat_id>', methods=['GET'])
|
||||
def get_chat(chat_id: str) -> str:
|
||||
chat_id = secure_filename(chat_id)
|
||||
if int(self.chat_cache.get(chat_id, -1)) == int(request.headers.get("if-none-match", 0)):
|
||||
if self.chat_cache.get(chat_id, 0) == request.headers.get("if-none-match", 0):
|
||||
return jsonify({"error": {"message": "Not modified"}}), 304
|
||||
bucket_dir = get_bucket_dir(chat_id)
|
||||
file = os.path.join(bucket_dir, "chat.json")
|
||||
@@ -379,7 +399,7 @@ class Backend_Api(Api):
|
||||
return jsonify({"error": {"message": "Not found"}}), 404
|
||||
with open(file, 'r') as f:
|
||||
chat_data = json.load(f)
|
||||
if int(chat_data.get("updated", 0)) == int(request.headers.get("if-none-match", 0)):
|
||||
if chat_data.get("updated", 0) == request.headers.get("if-none-match", 0):
|
||||
return jsonify({"error": {"message": "Not modified"}}), 304
|
||||
self.chat_cache[chat_id] = chat_data.get("updated", 0)
|
||||
return jsonify(chat_data), 200
|
||||
@@ -387,12 +407,16 @@ class Backend_Api(Api):
|
||||
@self.app.route('/backend-api/v2/chat/<chat_id>', methods=['POST'])
|
||||
def upload_chat(chat_id: str) -> dict:
|
||||
chat_data = {**request.json}
|
||||
updated = chat_data.get("updated", 0)
|
||||
cache_value = self.chat_cache.get(chat_id, 0)
|
||||
if updated == cache_value:
|
||||
return jsonify({"error": {"message": "invalid date"}}), 400
|
||||
chat_id = secure_filename(chat_id)
|
||||
bucket_dir = get_bucket_dir(chat_id)
|
||||
os.makedirs(bucket_dir, exist_ok=True)
|
||||
with open(os.path.join(bucket_dir, "chat.json"), 'w') as f:
|
||||
json.dump(chat_data, f)
|
||||
self.chat_cache[chat_id] = chat_data.get("updated", 0)
|
||||
self.chat_cache[chat_id] = updated
|
||||
return {"chat_id": chat_id}
|
||||
|
||||
def handle_synthesize(self, provider: str):
|
||||
|
@@ -6,6 +6,7 @@ import io
|
||||
import base64
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
try:
|
||||
from PIL.Image import open as open_image, new as new_image
|
||||
from PIL.Image import FLIP_LEFT_RIGHT, ROTATE_180, ROTATE_270, ROTATE_90
|
||||
@@ -17,15 +18,6 @@ from ..providers.helper import filter_none
|
||||
from ..typing import ImageType, Union, Image
|
||||
from ..errors import MissingRequirementsError
|
||||
|
||||
ALLOWED_EXTENSIONS = {
|
||||
# Image
|
||||
'png', 'jpg', 'jpeg', 'gif', 'webp',
|
||||
# Audio
|
||||
'wav', 'mp3', 'flac', 'opus', 'ogg',
|
||||
# Video
|
||||
'mkv', 'webm', 'mp4'
|
||||
}
|
||||
|
||||
MEDIA_TYPE_MAP: dict[str, str] = {
|
||||
"image/png": "png",
|
||||
"image/jpeg": "jpg",
|
||||
@@ -90,7 +82,7 @@ def to_image(image: ImageType, is_svg: bool = False) -> Image:
|
||||
|
||||
return image
|
||||
|
||||
def is_allowed_extension(filename: str) -> bool:
|
||||
def is_allowed_extension(filename: str) -> Optional[str]:
|
||||
"""
|
||||
Checks if the given filename has an allowed extension.
|
||||
|
||||
@@ -100,8 +92,8 @@ def is_allowed_extension(filename: str) -> bool:
|
||||
Returns:
|
||||
bool: True if the extension is allowed, False otherwise.
|
||||
"""
|
||||
return '.' in filename and \
|
||||
filename.rsplit('.', 1)[1].lower() in ALLOWED_EXTENSIONS
|
||||
ext = os.path.splitext(filename)[1][1:].lower() if '.' in filename else None
|
||||
return EXTENSIONS_MAP[ext] if ext in EXTENSIONS_MAP else None
|
||||
|
||||
def is_data_an_media(data, filename: str = None) -> str:
|
||||
content_type = is_data_an_audio(data, filename)
|
||||
@@ -138,7 +130,7 @@ def is_data_uri_an_image(data_uri: str) -> bool:
|
||||
# Extract the image format from the data URI
|
||||
image_format = re.match(r'data:image/(\w+);base64,', data_uri).group(1).lower()
|
||||
# Check if the image format is one of the allowed formats (jpg, jpeg, png, gif)
|
||||
if image_format not in ALLOWED_EXTENSIONS and image_format != "svg+xml":
|
||||
if image_format not in EXTENSIONS_MAP and image_format != "svg+xml":
|
||||
raise ValueError("Invalid image format (from mime file type).")
|
||||
|
||||
def is_accepted_format(binary_data: bytes) -> str:
|
||||
|
@@ -11,7 +11,7 @@ from aiohttp import ClientSession, ClientError
|
||||
|
||||
from ..typing import Optional, Cookies
|
||||
from ..requests.aiohttp import get_connector, StreamResponse
|
||||
from ..image import MEDIA_TYPE_MAP, ALLOWED_EXTENSIONS
|
||||
from ..image import MEDIA_TYPE_MAP, EXTENSIONS_MAP
|
||||
from ..tools.files import get_bucket_dir
|
||||
from ..providers.response import ImageResponse, AudioResponse, VideoResponse
|
||||
from ..Provider.template import BackendApi
|
||||
@@ -58,7 +58,7 @@ async def save_response_media(response: StreamResponse, prompt: str):
|
||||
content_type = response.headers["content-type"]
|
||||
if is_valid_media_type(content_type):
|
||||
extension = MEDIA_TYPE_MAP[content_type] if content_type in MEDIA_TYPE_MAP else content_type[6:].replace("mpeg", "mp3")
|
||||
if extension not in ALLOWED_EXTENSIONS:
|
||||
if extension not in EXTENSIONS_MAP:
|
||||
raise ValueError(f"Unsupported media type: {content_type}")
|
||||
bucket_id = str(uuid.uuid4())
|
||||
dirname = str(int(time.time()))
|
||||
@@ -86,6 +86,7 @@ async def copy_media(
|
||||
headers: Optional[dict] = None,
|
||||
proxy: Optional[str] = None,
|
||||
alt: str = None,
|
||||
tags: list[str] = None,
|
||||
add_url: bool = True,
|
||||
target: str = None,
|
||||
ssl: bool = None
|
||||
@@ -113,6 +114,7 @@ async def copy_media(
|
||||
# Build safe filename with full Unicode support
|
||||
filename = secure_filename("".join((
|
||||
f"{int(time.time())}_",
|
||||
(f"{''.join(tags, '_')}_" if tags else ""),
|
||||
(f"{alt}_" if alt else ""),
|
||||
f"{hashlib.sha256(image.encode()).hexdigest()[:16]}",
|
||||
f"{get_media_extension(image)}"
|
||||
|
@@ -18,6 +18,7 @@ from .Provider import (
|
||||
FreeGpt,
|
||||
HuggingSpace,
|
||||
G4F,
|
||||
Grok,
|
||||
DeepseekAI_JanusPro7b,
|
||||
Glider,
|
||||
Goabror,
|
||||
@@ -356,19 +357,19 @@ gemini_1_5_pro = Model(
|
||||
gemini_2_0_flash = Model(
|
||||
name = 'gemini-2.0-flash',
|
||||
base_provider = 'Google DeepMind',
|
||||
best_provider = IterListProvider([Dynaspark, GeminiPro, Liaobots])
|
||||
best_provider = IterListProvider([Dynaspark, GeminiPro, Gemini])
|
||||
)
|
||||
|
||||
gemini_2_0_flash_thinking = Model(
|
||||
name = 'gemini-2.0-flash-thinking',
|
||||
base_provider = 'Google DeepMind',
|
||||
best_provider = Liaobots
|
||||
best_provider = Gemini
|
||||
)
|
||||
|
||||
gemini_2_0_pro = Model(
|
||||
name = 'gemini-2.0-pro',
|
||||
gemini_2_0_flash_thinking_with_apps = Model(
|
||||
name = 'gemini-2.0-flash-thinking-with-apps',
|
||||
base_provider = 'Google DeepMind',
|
||||
best_provider = Liaobots
|
||||
best_provider = Gemini
|
||||
)
|
||||
|
||||
### Anthropic ###
|
||||
@@ -379,19 +380,6 @@ claude_3_haiku = Model(
|
||||
best_provider = IterListProvider([DDG, Jmuz])
|
||||
)
|
||||
|
||||
claude_3_sonnet = Model(
|
||||
name = 'claude-3-sonnet',
|
||||
base_provider = 'Anthropic',
|
||||
best_provider = Liaobots
|
||||
)
|
||||
|
||||
claude_3_opus = Model(
|
||||
name = 'claude-3-opus',
|
||||
base_provider = 'Anthropic',
|
||||
best_provider = IterListProvider([Jmuz, Liaobots])
|
||||
)
|
||||
|
||||
|
||||
# claude 3.5
|
||||
claude_3_5_sonnet = Model(
|
||||
name = 'claude-3.5-sonnet',
|
||||
@@ -406,12 +394,6 @@ claude_3_7_sonnet = Model(
|
||||
best_provider = IterListProvider([Blackbox, Liaobots])
|
||||
)
|
||||
|
||||
claude_3_7_sonnet_thinking = Model(
|
||||
name = 'claude-3.7-sonnet-thinking',
|
||||
base_provider = 'Anthropic',
|
||||
best_provider = Liaobots
|
||||
)
|
||||
|
||||
### Reka AI ###
|
||||
reka_core = Model(
|
||||
name = 'reka-core',
|
||||
@@ -548,13 +530,13 @@ janus_pro_7b = VisionModel(
|
||||
grok_3 = Model(
|
||||
name = 'grok-3',
|
||||
base_provider = 'x.ai',
|
||||
best_provider = Liaobots
|
||||
best_provider = Grok
|
||||
)
|
||||
|
||||
grok_3_r1 = Model(
|
||||
name = 'grok-3-r1',
|
||||
base_provider = 'x.ai',
|
||||
best_provider = Liaobots
|
||||
best_provider = Grok
|
||||
)
|
||||
|
||||
### Perplexity AI ###
|
||||
@@ -841,12 +823,10 @@ class ModelUtils:
|
||||
gemini_1_5_flash.name: gemini_1_5_flash,
|
||||
gemini_2_0_flash.name: gemini_2_0_flash,
|
||||
gemini_2_0_flash_thinking.name: gemini_2_0_flash_thinking,
|
||||
gemini_2_0_pro.name: gemini_2_0_pro,
|
||||
gemini_2_0_flash_thinking_with_apps.name: gemini_2_0_flash_thinking_with_apps,
|
||||
|
||||
### Anthropic ###
|
||||
# claude 3
|
||||
claude_3_opus.name: claude_3_opus,
|
||||
claude_3_sonnet.name: claude_3_sonnet,
|
||||
claude_3_haiku.name: claude_3_haiku,
|
||||
|
||||
# claude 3.5
|
||||
@@ -854,7 +834,6 @@ class ModelUtils:
|
||||
|
||||
# claude 3.7
|
||||
claude_3_7_sonnet.name: claude_3_7_sonnet,
|
||||
claude_3_7_sonnet_thinking.name: claude_3_7_sonnet_thinking,
|
||||
|
||||
### Reka AI ###
|
||||
reka_core.name: reka_core,
|
||||
|
@@ -366,11 +366,15 @@ class ProviderModelMixin:
|
||||
class RaiseErrorMixin():
|
||||
|
||||
@staticmethod
|
||||
def raise_error(data: dict):
|
||||
def raise_error(data: dict, status: int = None):
|
||||
if "error_message" in data:
|
||||
raise ResponseError(data["error_message"])
|
||||
elif "error" in data:
|
||||
if isinstance(data["error"], str):
|
||||
if status is not None:
|
||||
if status in (401, 402):
|
||||
raise MissingAuthError(f"Error {status}: {data['error']}")
|
||||
raise ResponseError(f"Error {status}: {data['error']}")
|
||||
raise ResponseError(data["error"])
|
||||
elif "code" in data["error"]:
|
||||
raise ResponseError("\n".join(
|
||||
|
@@ -4,7 +4,7 @@ from typing import Union
|
||||
from aiohttp import ClientResponse
|
||||
from requests import Response as RequestsResponse
|
||||
|
||||
from ..errors import ResponseStatusError, RateLimitError
|
||||
from ..errors import ResponseStatusError, RateLimitError, MissingAuthError
|
||||
from . import Response, StreamResponse
|
||||
|
||||
class CloudflareError(ResponseStatusError):
|
||||
@@ -23,6 +23,7 @@ def is_openai(text: str) -> bool:
|
||||
async def raise_for_status_async(response: Union[StreamResponse, ClientResponse], message: str = None):
|
||||
if response.ok:
|
||||
return
|
||||
is_html = False
|
||||
if message is None:
|
||||
content_type = response.headers.get("content-type", "")
|
||||
if content_type.startswith("application/json"):
|
||||
@@ -31,39 +32,42 @@ async def raise_for_status_async(response: Union[StreamResponse, ClientResponse]
|
||||
if isinstance(message, dict):
|
||||
message = message.get("message", message)
|
||||
else:
|
||||
text = (await response.text()).strip()
|
||||
message = (await response.text()).strip()
|
||||
is_html = content_type.startswith("text/html") or text.startswith("<!DOCTYPE")
|
||||
message = "HTML content" if is_html else text
|
||||
if message is None or message == "HTML content":
|
||||
if message is None or is_html:
|
||||
if response.status == 520:
|
||||
message = "Unknown error (Cloudflare)"
|
||||
elif response.status in (429, 402):
|
||||
message = "Rate limit"
|
||||
if response.status == 403 and is_cloudflare(text):
|
||||
if response.status in (401, 402):
|
||||
raise MissingAuthError(f"Response {response.status}: {message}")
|
||||
if response.status == 403 and is_cloudflare(message):
|
||||
raise CloudflareError(f"Response {response.status}: Cloudflare detected")
|
||||
elif response.status == 403 and is_openai(text):
|
||||
elif response.status == 403 and is_openai(message):
|
||||
raise ResponseStatusError(f"Response {response.status}: OpenAI Bot detected")
|
||||
elif response.status == 502:
|
||||
raise ResponseStatusError(f"Response {response.status}: Bad Gateway")
|
||||
elif response.status == 504:
|
||||
raise RateLimitError(f"Response {response.status}: Gateway Timeout ")
|
||||
else:
|
||||
raise ResponseStatusError(f"Response {response.status}: {message}")
|
||||
raise ResponseStatusError(f"Response {response.status}: {"HTML content" if is_html else message}")
|
||||
|
||||
def raise_for_status(response: Union[Response, StreamResponse, ClientResponse, RequestsResponse], message: str = None):
|
||||
if hasattr(response, "status"):
|
||||
return raise_for_status_async(response, message)
|
||||
if response.ok:
|
||||
return
|
||||
is_html = False
|
||||
if message is None:
|
||||
is_html = response.headers.get("content-type", "").startswith("text/html") or response.text.startswith("<!DOCTYPE")
|
||||
message = "HTML content" if is_html else response.text
|
||||
if message == "HTML content":
|
||||
message = response.text
|
||||
if message is None or is_html:
|
||||
if response.status_code == 520:
|
||||
message = "Unknown error (Cloudflare)"
|
||||
elif response.status_code in (429, 402):
|
||||
message = "Rate limit"
|
||||
raise RateLimitError(f"Response {response.status_code}: {message}")
|
||||
raise RateLimitError(f"Response {response.status_code}: Rate Limit")
|
||||
if response.status_code in (401, 402):
|
||||
raise MissingAuthError(f"Response {response.status_code}: {message}")
|
||||
if response.status_code == 403 and is_cloudflare(response.text):
|
||||
raise CloudflareError(f"Response {response.status_code}: Cloudflare detected")
|
||||
elif response.status_code == 403 and is_openai(response.text):
|
||||
@@ -73,4 +77,4 @@ def raise_for_status(response: Union[Response, StreamResponse, ClientResponse, R
|
||||
elif response.status_code == 504:
|
||||
raise RateLimitError(f"Response {response.status_code}: Gateway Timeout ")
|
||||
else:
|
||||
raise ResponseStatusError(f"Response {response.status_code}: {message}")
|
||||
raise ResponseStatusError(f"Response {response.status_code}: {"HTML content" if is_html else message}")
|
Reference in New Issue
Block a user