Mirror of https://github.com/xtekky/gpt4free.git (synced 2025-10-16 13:20:43 +08:00)
Set default model in HuggingFaceMedia.
Improve handling of shared chats; show the api_key input if it is required.
etc/examples/video.py (new file, 19 lines)
@@ -0,0 +1,19 @@
+import g4f.Provider
+from g4f.client import Client
+
+client = Client(
+    provider=g4f.Provider.HuggingFaceMedia,
+    api_key="hf_***" # Your API key here
+)
+
+video_models = client.models.get_video()
+
+print(video_models)
+
+result = client.media.generate(
+    model=video_models[0],
+    prompt="G4F AI technology is the best in the world.",
+    response_format="url"
+)
+
+print(result.data[0].url)
@@ -66,6 +66,8 @@ class HuggingFaceMedia(AsyncGeneratorProvider, ProviderModelMixin):
                 for provider_data in provider_keys:
                     prepend_models.append(f"{model}:{provider_data.get('provider')}")
             cls.models = prepend_models + [model for model in new_models if model not in prepend_models]
+            cls.image_models = [model for model, task in cls.task_mapping.items() if task == "text-to-image"]
+            cls.video_models = [model for model, task in cls.task_mapping.items() if task == "text-to-video"]
         else:
             cls.models = []
         return cls.models
@@ -99,12 +101,14 @@ class HuggingFaceMedia(AsyncGeneratorProvider, ProviderModelMixin):
        prompt: str = None,
        proxy: str = None,
        timeout: int = 0,
-       aspect_ratio: str = "1:1",
+       aspect_ratio: str = None,
        **kwargs
    ):
        selected_provider = None
-       if ":" in model:
+       if model and ":" in model:
            model, selected_provider = model.split(":", 1)
+       elif not model:
+           model = cls.get_models()[0]
        provider_mapping = await cls.get_mapping(model, api_key)
        headers = {
            'Accept-Encoding': 'gzip, deflate',
@@ -133,11 +137,11 @@ class HuggingFaceMedia(AsyncGeneratorProvider, ProviderModelMixin):
            extra_data = {
                "num_inference_steps": 20,
                "resolution": "480p",
-               "aspect_ratio": aspect_ratio,
+               "aspect_ratio": "16:9" if aspect_ratio is None else aspect_ratio,
                **extra_data
            }
        else:
-           extra_data = use_aspect_ratio(extra_data, aspect_ratio)
+           extra_data = use_aspect_ratio(extra_data, "1:1" if aspect_ratio is None else aspect_ratio)
        if provider_key == "fal-ai":
            url = f"{api_base}/{provider_id}"
            data = {
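A minimal usage sketch (an editor addition, not part of the commit): with these defaults, HuggingFaceMedia can be called without an explicit model or aspect ratio, falling back to the first model from get_models() and to "16:9" for text-to-video tasks or "1:1" otherwise. The provider choice and the "hf_***" token below are placeholders.

    # Sketch only: model and aspect_ratio are deliberately omitted.
    import g4f.Provider
    from g4f.client import Client

    client = Client(provider=g4f.Provider.HuggingFaceMedia, api_key="hf_***")  # placeholder token
    result = client.media.generate(prompt="A rocket launch at sunrise", response_format="url")
    print(result.data[0].url)  # the provider falls back to its first available model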
@@ -30,6 +30,7 @@ class Grok(AsyncAuthedProvider, ProviderModelMixin):

    default_model = "grok-3"
    models = [default_model, "grok-3-thinking", "grok-2"]
+   model_aliases = {"grok-3-r1": "grok-3-thinking"}

    @classmethod
    async def on_auth_async(cls, cookies: Cookies = None, proxy: str = None, **kwargs) -> AsyncIterator:
@@ -73,7 +74,7 @@ class Grok(AsyncAuthedProvider, ProviderModelMixin):
            "sendFinalMetadata": True,
            "customInstructions": "",
            "deepsearchPreset": "",
-           "isReasoning": model.endswith("-thinking"),
+           "isReasoning": model.endswith("-thinking") or model.endswith("-r1"),
        }

    @classmethod
@@ -92,7 +92,7 @@ class OpenaiTemplate(AsyncGeneratorProvider, ProviderModelMixin, RaiseErrorMixin
            }
            async with session.post(f"{api_base.rstrip('/')}/images/generations", json=data, ssl=cls.ssl) as response:
                data = await response.json()
-               cls.raise_error(data)
+               cls.raise_error(data, response.status)
                await raise_for_status(response)
                yield ImageResponse([image["url"] for image in data["data"]], prompt)
                return
@@ -135,7 +135,7 @@ class OpenaiTemplate(AsyncGeneratorProvider, ProviderModelMixin, RaiseErrorMixin
                content_type = response.headers.get("content-type", "text/event-stream" if stream else "application/json")
                if content_type.startswith("application/json"):
                    data = await response.json()
-                   cls.raise_error(data)
+                   cls.raise_error(data, response.status)
                    await raise_for_status(response)
                    choice = data["choices"][0]
                    if "content" in choice["message"] and choice["message"]["content"]:
@@ -10,6 +10,7 @@ from email.utils import formatdate
 import os.path
 import hashlib
 import asyncio
+from urllib.parse import quote_plus
 from fastapi import FastAPI, Response, Request, UploadFile, Depends
 from fastapi.middleware.wsgi import WSGIMiddleware
 from fastapi.responses import StreamingResponse, RedirectResponse, HTMLResponse, JSONResponse
@@ -562,6 +563,10 @@ class Api:
        })
        async def get_media(filename, request: Request):
            target = os.path.join(images_dir, os.path.basename(filename))
+           if not os.path.isfile(target):
+               other_name = os.path.join(images_dir, os.path.basename(quote_plus(filename)))
+               if os.path.isfile(other_name):
+                   target = other_name
            ext = os.path.splitext(filename)[1][1:]
            mime_type = EXTENSIONS_MAP.get(ext)
            stat_result = SimpleNamespace()
@@ -19,7 +19,7 @@ from ..providers.asyncio import to_sync_generator
 from ..Provider.needs_auth import BingCreateImages, OpenaiAccount
 from ..tools.run_tools import async_iter_run_tools, iter_run_tools
 from .stubs import ChatCompletion, ChatCompletionChunk, Image, ImagesResponse, UsageModel, ToolCallModel
-from .image_models import ImageModels
+from .image_models import MediaModels
 from .types import IterResponse, ImageProvider, Client as BaseClient
 from .service import get_model_and_provider, convert_to_provider
 from .helper import find_stop, filter_json, filter_none, safe_aclose
@@ -267,8 +267,11 @@ class Client(BaseClient):
    ) -> None:
        super().__init__(**kwargs)
        self.chat: Chat = Chat(self, provider)
+       if image_provider is None:
+           image_provider = provider
+       self.models: MediaModels = MediaModels(self, image_provider)
        self.images: Images = Images(self, image_provider)
-       self.media: Images = Images(self, image_provider)
+       self.media: Images = self.images

class Completions:
    def __init__(self, client: Client, provider: Optional[ProviderType] = None):
@@ -349,7 +352,6 @@ class Images:
    def __init__(self, client: Client, provider: Optional[ProviderType] = None):
        self.client: Client = client
        self.provider: Optional[ProviderType] = provider
-       self.models: ImageModels = ImageModels(client)

    def generate(
        self,
@@ -369,7 +371,7 @@ class Images:
        if provider is None:
            provider_handler = self.provider
            if provider_handler is None:
-               provider_handler = self.models.get(model, default)
+               provider_handler = self.client.models.get(model, default)
        elif isinstance(provider, str):
            provider_handler = convert_to_provider(provider)
        else:
@@ -385,19 +387,21 @@ class Images:
        provider: Optional[ProviderType] = None,
        response_format: Optional[str] = None,
        proxy: Optional[str] = None,
+       api_key: Optional[str] = None,
        **kwargs
    ) -> ImagesResponse:
        provider_handler = await self.get_provider_handler(model, provider, BingCreateImages)
        provider_name = provider_handler.__name__ if hasattr(provider_handler, "__name__") else type(provider_handler).__name__
        if proxy is None:
            proxy = self.client.proxy
+       if api_key is None:
+           api_key = self.client.api_key
        error = None
        response = None
        if isinstance(provider_handler, IterListProvider):
            for provider in provider_handler.providers:
                try:
-                   response = await self._generate_image_response(provider, provider.__name__, model, prompt, **kwargs)
+                   response = await self._generate_image_response(provider, provider.__name__, model, prompt, proxy=proxy, **kwargs)
                    if response is not None:
                        provider_name = provider.__name__
                        break
@@ -405,7 +409,7 @@ class Images:
                    error = e
                    debug.error(f"{provider.__name__} {type(e).__name__}: {e}")
        else:
-           response = await self._generate_image_response(provider_handler, provider_name, model, prompt, **kwargs)
+           response = await self._generate_image_response(provider_handler, provider_name, model, prompt, proxy=proxy, api_key=api_key, **kwargs)

        if isinstance(response, MediaResponse):
            return await self._process_image_response(
@@ -534,7 +538,7 @@ class Images:
        else:
            # Save locally for None (default) case
            images = await copy_media(response.get_list(), response.get("cookies"), proxy)
-           images = [Image.model_construct(url=f"/media/{os.path.basename(image)}", revised_prompt=response.alt) for image in images]
+           images = [Image.model_construct(url=image, revised_prompt=response.alt) for image in images]

        return ImagesResponse.model_construct(
            created=int(time.time()),
@@ -552,6 +556,9 @@ class AsyncClient(BaseClient):
    ) -> None:
        super().__init__(**kwargs)
        self.chat: AsyncChat = AsyncChat(self, provider)
+       if image_provider is None:
+           image_provider = provider
+       self.models: MediaModels = MediaModels(self, image_provider)
        self.images: AsyncImages = AsyncImages(self, image_provider)
        self.media: AsyncImages = self.images

@@ -635,7 +642,6 @@ class AsyncImages(Images):
    def __init__(self, client: AsyncClient, provider: Optional[ProviderType] = None):
        self.client: AsyncClient = client
        self.provider: Optional[ProviderType] = provider
-       self.models: ImageModels = ImageModels(client)

    async def generate(
        self,
@@ -1,15 +1,43 @@
 from __future__ import annotations

-from ..models import ModelUtils
+from ..models import ModelUtils, ImageModel
 from ..Provider import ProviderUtils
+from ..providers.types import ProviderType

-class ImageModels():
-    def __init__(self, client):
+class MediaModels():
+    def __init__(self, client, provider: ProviderType = None):
         self.client = client
+        self.provider = provider

-    def get(self, name, default=None):
+    def get(self, name, default=None) -> ProviderType:
         if name in ModelUtils.convert:
             return ModelUtils.convert[name].best_provider
         if name in ProviderUtils.convert:
             return ProviderUtils.convert[name]
         return default
+
+    def get_all(self, api_key: str = None, **kwargs) -> list[str]:
+        if self.provider is None:
+            return []
+        if api_key is None:
+            api_key = self.client.api_key
+        return self.provider.get_models(
+            **kwargs,
+            **{} if api_key is None else {"api_key": api_key}
+        )
+
+    def get_image(self, **kwargs) -> list[str]:
+        if self.provider is None:
+            return [model_id for model_id, model in ModelUtils.convert.items() if isinstance(model, ImageModel)]
+        self.get_all(**kwargs)
+        if hasattr(self.provider, "image_models"):
+            return self.provider.image_models
+        return []
+
+    def get_video(self, **kwargs) -> list[str]:
+        if self.provider is None:
+            return []
+        self.get_all(**kwargs)
+        if hasattr(self.provider, "video_models"):
+            return self.provider.video_models
+        return []
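A short sketch (editor addition, not taken from the diff) of how the renamed helper is reached through the client; the provider and the "hf_***" token are placeholders.

    # Sketch: list the media models a provider reports, via client.models.
    import g4f.Provider
    from g4f.client import Client

    client = Client(provider=g4f.Provider.HuggingFaceMedia, api_key="hf_***")  # placeholder token
    print(client.models.get_image())  # image models reported by the provider
    print(client.models.get_video())  # video models reported by the provider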
@@ -61,9 +61,12 @@
    const gpt_image = '<img src="/static/img/gpt.png" alt="your avatar">';
    </script>
    <script src="/static/js/highlight.min.js" async></script>
-   <script>window.conversation_id = "{{conversation_id}}"</script>
-   <script>window.chat_id = "{{chat_id}}"; window.share_url = "{{share_url}}";</script>
+   <script>
+       window.conversation_id = "{{conversation_id}}";
+       window.chat_id = "{{chat_id}}";
+       window.share_url = "{{share_url}}";
+       window.start_id = "{{conversation_id}}";
+   </script>
    <title>G4F Chat</title>
</head>
<body>
@@ -7,9 +7,9 @@
    <script src="https://cdn.jsdelivr.net/npm/qrcodejs/qrcode.min.js"></script>
    <style>
        body { font-family: Arial, sans-serif; text-align: center; margin: 20px; }
-       video { width: 400px; height: 400px; border: 1px solid black; display: block; margin: auto; object-fit: cover;}
+       video { width: 400px; height: 400px; border: 1px solid black; display: block; margin: auto; object-fit: cover; max-width: 100%;}
        #qrcode { margin-top: 20px; }
-       #qrcode img, #qrcode canvas { margin: 0 auto; width: 400px; height: 400px; }
+       #qrcode img, #qrcode canvas { margin: 0 auto; width: 400px; height: 400px; max-width: 100%;}
        button { margin: 5px; padding: 10px; }
    </style>
</head>
@@ -881,7 +881,7 @@ input.model:hover
    padding: var(--inner-gap) 28px;
}

-#systemPrompt, #chatPrompt, .settings textarea, form textarea {
+#systemPrompt, #chatPrompt, .settings textarea, form textarea, .chat-body textarea {
    font-size: 15px;
    color: var(--colour-3);
    outline: none;
@@ -1305,7 +1305,7 @@ form textarea {
    padding: 0;
}

-.settings textarea {
+.settings textarea, .chat-body textarea {
    height: 30px;
    min-height: 30px;
    padding: 6px;
@@ -1315,7 +1315,7 @@ form textarea {
    text-wrap: nowrap;
}

-form .field .fa-xmark {
+.field .fa-xmark {
    line-height: 20px;
    cursor: pointer;
    margin-left: auto;
@@ -1323,11 +1323,11 @@ form .field .fa-xmark {
    margin-top: 0;
}

-form .field.saved .fa-xmark {
+.field.saved .fa-xmark {
    color: var(--accent)
}

-.settings .field, form .field {
+.settings .field, form .field, .chat-body .field {
    padding: var(--inner-gap) var(--inner-gap) var(--inner-gap) 0;
}

@@ -1359,7 +1359,7 @@ form .field.saved .fa-xmark {
    border: none;
}

-.settings input, form input {
+.settings input, form input, .chat-body input {
    background-color: transparent;
    padding: 2px;
    border: none;
@@ -1368,11 +1368,11 @@ form .field.saved .fa-xmark {
    color: var(--colour-3);
}

-.settings input:focus, form input:focus {
+.settings input:focus, form input:focus, .chat-body input:focus {
    outline: none;
}

-.settings .label, form .label, .settings label, form label {
+.settings .label, form .label, .settings label, form label, .chat-body label {
    font-size: 15px;
    margin-left: var(--inner-gap);
}
@@ -28,7 +28,7 @@ const switchInput = document.getElementById("switch");
const searchButton = document.getElementById("search");
const paperclip = document.querySelector(".user-input .fa-paperclip");

-const optionElementsSelector = ".settings input, .settings textarea, #model, #model2, #provider";
+const optionElementsSelector = ".settings input, .settings textarea, .chat-body input, #model, #model2, #provider";

let provider_storage = {};
let message_storage = {};
@@ -153,7 +153,7 @@ const iframe_close = Object.assign(document.createElement("button"), {
});
iframe_close.onclick = () => iframe_container.classList.add("hidden");
iframe_container.appendChild(iframe_close);
-chat.appendChild(iframe_container);
+document.body.appendChild(iframe_container);

class HtmlRenderPlugin {
    constructor(options = {}) {
@@ -843,6 +843,16 @@ async function add_message_chunk(message, message_id, provider, scroll, finish_m
            conversation.data[key] = value;
        }
        await save_conversation(conversation_id, conversation);
+   } else if (message.type == "auth") {
+       error_storage[message_id] = message.message
+       content_map.inner.innerHTML += markdown_render(`**An error occured:** ${message.message}`);
+       let provider = provider_storage[message_id]?.name;
+       let configEl = document.querySelector(`.settings .${provider}-api_key`);
+       if (configEl) {
+           configEl = configEl.parentElement.cloneNode(true);
+           content_map.content.appendChild(configEl);
+           await register_settings_storage();
+       }
    } else if (message.type == "provider") {
        provider_storage[message_id] = message.provider;
        let provider_el = content_map.content.querySelector('.provider');
@@ -1122,10 +1132,6 @@ const ask_gpt = async (message_id, message_index = -1, regenerate = false, provi
    let api_key;
    if (is_demo && !provider) {
        api_key = localStorage.getItem("HuggingFace-api_key");
-       if (!api_key) {
-           location.href = "/";
-           return;
-       }
    } else {
        api_key = get_api_key_by_provider(provider);
    }
@@ -1221,6 +1227,7 @@ function sanitize(input, replacement) {
}

async function set_conversation_title(conversation_id, title) {
+   window.chat_id = null;
    conversation = await get_conversation(conversation_id)
    conversation.new_title = title;
    const new_id = sanitize(title, " ");
@@ -1742,12 +1749,22 @@ const load_conversations = async () => {

    let html = [];
    conversations.forEach((conversation) => {
+       // const length = conversation.items.map((item) => (
+       //     !item.content.toLowerCase().includes("hello") &&
+       //     !item.content.toLowerCase().includes("hi") &&
+       //     item.content
+       // ) ? 1 : 0).reduce((a,b)=>a+b, 0);
+       // if (!length) {
+       //     appStorage.removeItem(`conversation:${conversation.id}`);
+       //     return;
+       // }
+       const shareIcon = (conversation.id == window.start_id && window.chat_id) ? '<i class="fa-solid fa-qrcode"></i>': '';
        html.push(`
            <div class="convo" id="convo-${conversation.id}">
                <div class="left" onclick="set_conversation('${conversation.id}')">
                    <i class="fa-regular fa-comments"></i>
                    <span class="datetime">${conversation.updated ? toLocaleDateString(conversation.updated) : ""}</span>
-                   <span class="convo-title">${escapeHtml(conversation.new_title ? conversation.new_title : conversation.title)}</span>
+                   <span class="convo-title">${shareIcon} ${escapeHtml(conversation.new_title ? conversation.new_title : conversation.title)}</span>
                </div>
                <i onclick="show_option('${conversation.id}')" class="fa-solid fa-ellipsis-vertical" id="conv-${conversation.id}"></i>
                <div id="cho-${conversation.id}" class="choise" style="display:none;">
@@ -2060,7 +2077,6 @@ window.addEventListener('load', async function() {
    if (!window.conversation_id) {
        window.conversation_id = window.chat_id;
    }
-   window.start_id = window.conversation_id
    const response = await fetch(`${window.share_url}/backend-api/v2/chat/${window.chat_id ? window.chat_id : window.conversation_id}`, {
        headers: {'accept': 'application/json'},
    });
@@ -2075,6 +2091,7 @@ window.addEventListener('load', async function() {
        `conversation:${conversation.id}`,
        JSON.stringify(conversation)
    );
+   await load_conversations();
    let refreshOnHide = true;
    document.addEventListener("visibilitychange", () => {
        if (document.hidden) {
@@ -2091,6 +2108,9 @@ window.addEventListener('load', async function() {
        if (!refreshOnHide) {
            return;
        }
+       if (window.conversation_id != window.start_id) {
+           return;
+       }
        const response = await fetch(`${window.share_url}/backend-api/v2/chat/${window.chat_id}`, {
            headers: {'accept': 'application/json', 'if-none-match': conversation.updated},
        });
@@ -2102,6 +2122,7 @@ window.addEventListener('load', async function() {
        `conversation:${conversation.id}`,
        JSON.stringify(conversation)
    );
+   await load_conversations();
    await load_conversation(conversation);
        }
    }
@@ -2284,7 +2305,7 @@ async function on_api() {
        }
    } else if (provider.login_url) {
        if (!login_urls[provider.name]) {
-           login_urls[provider.name] = [provider.label, provider.login_url, [], provider.auth];
+           login_urls[provider.name] = [provider.label, provider.login_url, [provider.name], provider.auth];
        } else {
            login_urls[provider.name][0] = provider.label;
            login_urls[provider.name][1] = provider.login_url;
@@ -7,7 +7,7 @@ from typing import Iterator
 from flask import send_from_directory
 from inspect import signature

-from ...errors import VersionNotFoundError
+from ...errors import VersionNotFoundError, MissingAuthError
 from ...image.copy_images import copy_media, ensure_images_dir, images_dir
 from ...tools.run_tools import iter_run_tools
 from ...Provider import ProviderUtils, __providers__
@@ -187,7 +188,8 @@ class Api:
                    media = chunk
                    if download_media or chunk.get("cookies"):
                        chunk.alt = format_image_prompt(kwargs.get("messages"), chunk.alt)
-                       media = asyncio.run(copy_media(chunk.get_list(), chunk.get("cookies"), chunk.get("headers"), proxy=proxy, alt=chunk.alt))
+                       tags = [tag for tag in [model, kwargs.get("aspect_ratio")] if tag]
+                       media = asyncio.run(copy_media(chunk.get_list(), chunk.get("cookies"), chunk.get("headers"), proxy=proxy, alt=chunk.alt, tags=tags))
                        media = ImageResponse(media, chunk.alt) if isinstance(chunk, ImageResponse) else VideoResponse(media, chunk.alt)
                    yield self._format_json("content", str(media), images=chunk.get_list(), alt=chunk.alt)
                elif isinstance(chunk, SynthesizeData):
@@ -214,6 +215,8 @@ class Api:
                    yield self._format_json(chunk.type, **chunk.get_dict())
                else:
                    yield self._format_json("content", str(chunk))
+       except MissingAuthError as e:
+           yield self._format_json('auth', type(e).__name__, message=get_error_message(e))
        except Exception as e:
            logger.exception(e)
            debug.error(e)
@@ -16,7 +16,6 @@ from pathlib import Path
 from urllib.parse import quote_plus
 from hashlib import sha256

-from ...image import is_allowed_extension
 from ...client.service import convert_to_provider
 from ...providers.asyncio import to_sync_generator
 from ...client.helper import filter_markdown
@@ -25,7 +24,7 @@ from ...tools.run_tools import iter_run_tools
 from ...errors import ProviderNotFoundError
 from ...image import is_allowed_extension
 from ...cookies import get_cookies_dir
-from ...image.copy_images import secure_filename, get_source_url
+from ...image.copy_images import secure_filename, get_source_url, images_dir
 from ... import ChatCompletion
 from ... import models
 from .api import Api
@@ -351,9 +350,30 @@ class Backend_Api(Api):
                    return redirect(source_url)
                raise

-       @app.route('/files/<dirname>/<bucket_id>/media/<filename>', methods=['GET'])
-       def get_media_sub(dirname, bucket_id, filename):
-           return get_media(bucket_id, filename, dirname)
+       @app.route('/search/<search>', methods=['GET'])
+       def find_media(search: str, min: int = None):
+           search = [secure_filename(chunk.lower()) for chunk in search.split("+")]
+           if min is None:
+               min = len(search)
+           if not os.access(images_dir, os.R_OK):
+               return jsonify({"error": {"message": "Not found"}}), 404
+           match_files = {}
+           for root, _, files in os.walk(images_dir):
+               for file in files:
+                   mime_type = is_allowed_extension(file)
+                   if mime_type is not None:
+                       mime_type = secure_filename(mime_type)
+                       for tag in search:
+                           if tag in mime_type:
+                               match_files[file] = match_files.get(file, 0) + 1
+                               break
+                   for tag in search:
+                       if tag in file.lower():
+                           match_files[file] = match_files.get(file, 0) + 1
+           match_files = [file for file, count in match_files.items() if count >= min]
+           if not match_files:
+               return jsonify({"error": {"message": "Not found"}}), 404
+           return redirect(f"/media/{random.choice(match_files)}")

        @app.route('/backend-api/v2/upload_cookies', methods=['POST'])
        def upload_cookies():
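A hedged usage sketch (editor addition, not part of the diff): the new /search/<search> route scores locally saved media against the "+"-separated tags, matching either the detected media type or the filename, and redirects to a random match under /media/. The host, port, and tag values below are assumptions.

    # Sketch: ask the local GUI server for a saved file matching two tags.
    import requests

    resp = requests.get("http://localhost:8080/search/video+grok-3", allow_redirects=False)
    print(resp.status_code)              # 302 redirect on a match, 404 otherwise
    print(resp.headers.get("Location"))  # e.g. /media/<matching filename>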
@@ -371,7 +391,7 @@ class Backend_Api(Api):
        @self.app.route('/backend-api/v2/chat/<chat_id>', methods=['GET'])
        def get_chat(chat_id: str) -> str:
            chat_id = secure_filename(chat_id)
-           if int(self.chat_cache.get(chat_id, -1)) == int(request.headers.get("if-none-match", 0)):
+           if self.chat_cache.get(chat_id, 0) == request.headers.get("if-none-match", 0):
                return jsonify({"error": {"message": "Not modified"}}), 304
            bucket_dir = get_bucket_dir(chat_id)
            file = os.path.join(bucket_dir, "chat.json")
@@ -379,7 +399,7 @@ class Backend_Api(Api):
                return jsonify({"error": {"message": "Not found"}}), 404
            with open(file, 'r') as f:
                chat_data = json.load(f)
-               if int(chat_data.get("updated", 0)) == int(request.headers.get("if-none-match", 0)):
+               if chat_data.get("updated", 0) == request.headers.get("if-none-match", 0):
                    return jsonify({"error": {"message": "Not modified"}}), 304
                self.chat_cache[chat_id] = chat_data.get("updated", 0)
                return jsonify(chat_data), 200
@@ -387,12 +407,16 @@ class Backend_Api(Api):
        @self.app.route('/backend-api/v2/chat/<chat_id>', methods=['POST'])
        def upload_chat(chat_id: str) -> dict:
            chat_data = {**request.json}
+           updated = chat_data.get("updated", 0)
+           cache_value = self.chat_cache.get(chat_id, 0)
+           if updated == cache_value:
+               return jsonify({"error": {"message": "invalid date"}}), 400
            chat_id = secure_filename(chat_id)
            bucket_dir = get_bucket_dir(chat_id)
            os.makedirs(bucket_dir, exist_ok=True)
            with open(os.path.join(bucket_dir, "chat.json"), 'w') as f:
                json.dump(chat_data, f)
-           self.chat_cache[chat_id] = chat_data.get("updated", 0)
+           self.chat_cache[chat_id] = updated
            return {"chat_id": chat_id}

        def handle_synthesize(self, provider: str):
@@ -6,6 +6,7 @@ import io
 import base64
 from io import BytesIO
 from pathlib import Path
+from typing import Optional
 try:
     from PIL.Image import open as open_image, new as new_image
     from PIL.Image import FLIP_LEFT_RIGHT, ROTATE_180, ROTATE_270, ROTATE_90
@@ -17,15 +18,6 @@ from ..providers.helper import filter_none
 from ..typing import ImageType, Union, Image
 from ..errors import MissingRequirementsError

-ALLOWED_EXTENSIONS = {
-    # Image
-    'png', 'jpg', 'jpeg', 'gif', 'webp',
-    # Audio
-    'wav', 'mp3', 'flac', 'opus', 'ogg',
-    # Video
-    'mkv', 'webm', 'mp4'
-}
-
 MEDIA_TYPE_MAP: dict[str, str] = {
     "image/png": "png",
     "image/jpeg": "jpg",
@@ -90,7 +82,7 @@ def to_image(image: ImageType, is_svg: bool = False) -> Image:

    return image

-def is_allowed_extension(filename: str) -> bool:
+def is_allowed_extension(filename: str) -> Optional[str]:
    """
    Checks if the given filename has an allowed extension.

@@ -100,8 +92,8 @@ def is_allowed_extension(filename: str) -> bool:
    Returns:
        bool: True if the extension is allowed, False otherwise.
    """
-   return '.' in filename and \
-       filename.rsplit('.', 1)[1].lower() in ALLOWED_EXTENSIONS
+   ext = os.path.splitext(filename)[1][1:].lower() if '.' in filename else None
+   return EXTENSIONS_MAP[ext] if ext in EXTENSIONS_MAP else None

def is_data_an_media(data, filename: str = None) -> str:
    content_type = is_data_an_audio(data, filename)
@@ -138,7 +130,7 @@ def is_data_uri_an_image(data_uri: str) -> bool:
    # Extract the image format from the data URI
    image_format = re.match(r'data:image/(\w+);base64,', data_uri).group(1).lower()
    # Check if the image format is one of the allowed formats (jpg, jpeg, png, gif)
-   if image_format not in ALLOWED_EXTENSIONS and image_format != "svg+xml":
+   if image_format not in EXTENSIONS_MAP and image_format != "svg+xml":
        raise ValueError("Invalid image format (from mime file type).")

def is_accepted_format(binary_data: bytes) -> str:
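A small sketch (editor addition, not part of the diff): after this change is_allowed_extension returns the media type mapped in EXTENSIONS_MAP, or None for unknown extensions, instead of a bool; the exact strings returned depend on that mapping.

    # Sketch: the helper now yields a media type string (or None) rather than True/False.
    from g4f.image import is_allowed_extension

    print(is_allowed_extension("clip.mp4"))   # a video media type, assuming "mp4" is in EXTENSIONS_MAP
    print(is_allowed_extension("notes.txt"))  # None: extension not in the map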
@@ -11,7 +11,7 @@ from aiohttp import ClientSession, ClientError

 from ..typing import Optional, Cookies
 from ..requests.aiohttp import get_connector, StreamResponse
-from ..image import MEDIA_TYPE_MAP, ALLOWED_EXTENSIONS
+from ..image import MEDIA_TYPE_MAP, EXTENSIONS_MAP
 from ..tools.files import get_bucket_dir
 from ..providers.response import ImageResponse, AudioResponse, VideoResponse
 from ..Provider.template import BackendApi
@@ -58,7 +58,7 @@ async def save_response_media(response: StreamResponse, prompt: str):
    content_type = response.headers["content-type"]
    if is_valid_media_type(content_type):
        extension = MEDIA_TYPE_MAP[content_type] if content_type in MEDIA_TYPE_MAP else content_type[6:].replace("mpeg", "mp3")
-       if extension not in ALLOWED_EXTENSIONS:
+       if extension not in EXTENSIONS_MAP:
            raise ValueError(f"Unsupported media type: {content_type}")
        bucket_id = str(uuid.uuid4())
        dirname = str(int(time.time()))
@@ -86,6 +86,7 @@ async def copy_media(
    headers: Optional[dict] = None,
    proxy: Optional[str] = None,
    alt: str = None,
+   tags: list[str] = None,
    add_url: bool = True,
    target: str = None,
    ssl: bool = None
@@ -113,6 +114,7 @@ async def copy_media(
        # Build safe filename with full Unicode support
        filename = secure_filename("".join((
            f"{int(time.time())}_",
+           (f"{''.join(tags, '_')}_" if tags else ""),
            (f"{alt}_" if alt else ""),
            f"{hashlib.sha256(image.encode()).hexdigest()[:16]}",
            f"{get_media_extension(image)}"
@@ -18,6 +18,7 @@ from .Provider import (
    FreeGpt,
    HuggingSpace,
    G4F,
+   Grok,
    DeepseekAI_JanusPro7b,
    Glider,
    Goabror,
@@ -356,19 +357,19 @@ gemini_1_5_pro = Model(
 gemini_2_0_flash = Model(
    name = 'gemini-2.0-flash',
    base_provider = 'Google DeepMind',
-   best_provider = IterListProvider([Dynaspark, GeminiPro, Liaobots])
+   best_provider = IterListProvider([Dynaspark, GeminiPro, Gemini])
 )

 gemini_2_0_flash_thinking = Model(
    name = 'gemini-2.0-flash-thinking',
    base_provider = 'Google DeepMind',
-   best_provider = Liaobots
+   best_provider = Gemini
 )

-gemini_2_0_pro = Model(
-   name = 'gemini-2.0-pro',
+gemini_2_0_flash_thinking_with_apps = Model(
+   name = 'gemini-2.0-flash-thinking-with-apps',
    base_provider = 'Google DeepMind',
-   best_provider = Liaobots
+   best_provider = Gemini
 )

 ### Anthropic ###
@@ -379,19 +380,6 @@ claude_3_haiku = Model(
    best_provider = IterListProvider([DDG, Jmuz])
 )

-claude_3_sonnet = Model(
-   name = 'claude-3-sonnet',
-   base_provider = 'Anthropic',
-   best_provider = Liaobots
-)
-
-claude_3_opus = Model(
-   name = 'claude-3-opus',
-   base_provider = 'Anthropic',
-   best_provider = IterListProvider([Jmuz, Liaobots])
-)
-
 # claude 3.5
 claude_3_5_sonnet = Model(
    name = 'claude-3.5-sonnet',
@@ -406,12 +394,6 @@ claude_3_7_sonnet = Model(
    best_provider = IterListProvider([Blackbox, Liaobots])
 )

-claude_3_7_sonnet_thinking = Model(
-   name = 'claude-3.7-sonnet-thinking',
-   base_provider = 'Anthropic',
-   best_provider = Liaobots
-)
-
 ### Reka AI ###
 reka_core = Model(
    name = 'reka-core',
@@ -548,13 +530,13 @@ janus_pro_7b = VisionModel(
 grok_3 = Model(
    name = 'grok-3',
    base_provider = 'x.ai',
-   best_provider = Liaobots
+   best_provider = Grok
 )

 grok_3_r1 = Model(
    name = 'grok-3-r1',
    base_provider = 'x.ai',
-   best_provider = Liaobots
+   best_provider = Grok
 )

 ### Perplexity AI ###
@@ -841,12 +823,10 @@ class ModelUtils:
        gemini_1_5_flash.name: gemini_1_5_flash,
        gemini_2_0_flash.name: gemini_2_0_flash,
        gemini_2_0_flash_thinking.name: gemini_2_0_flash_thinking,
-       gemini_2_0_pro.name: gemini_2_0_pro,
+       gemini_2_0_flash_thinking_with_apps.name: gemini_2_0_flash_thinking_with_apps,

        ### Anthropic ###
        # claude 3
-       claude_3_opus.name: claude_3_opus,
-       claude_3_sonnet.name: claude_3_sonnet,
        claude_3_haiku.name: claude_3_haiku,

        # claude 3.5
@@ -854,7 +834,6 @@ class ModelUtils:

        # claude 3.7
        claude_3_7_sonnet.name: claude_3_7_sonnet,
-       claude_3_7_sonnet_thinking.name: claude_3_7_sonnet_thinking,

        ### Reka AI ###
        reka_core.name: reka_core,
@@ -366,11 +366,15 @@ class ProviderModelMixin:
 class RaiseErrorMixin():

    @staticmethod
-   def raise_error(data: dict):
+   def raise_error(data: dict, status: int = None):
        if "error_message" in data:
            raise ResponseError(data["error_message"])
        elif "error" in data:
            if isinstance(data["error"], str):
+               if status is not None:
+                   if status in (401, 402):
+                       raise MissingAuthError(f"Error {status}: {data['error']}")
+                   raise ResponseError(f"Error {status}: {data['error']}")
                raise ResponseError(data["error"])
            elif "code" in data["error"]:
                raise ResponseError("\n".join(
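An illustrative sketch (editor addition, not part of the diff; the exact call is an assumption): with the status-aware raise_error and the 401/402 branches in raise_for_status, a missing or rejected key now surfaces as MissingAuthError, which callers such as the GUI backend catch and forward as an "auth" message.

    # Sketch: catching the authentication error raised for 401/402 responses.
    from g4f.client import Client
    from g4f.errors import MissingAuthError

    client = Client()  # provider and credentials left unset on purpose
    try:
        client.chat.completions.create(
            model="gpt-4o-mini",
            messages=[{"role": "user", "content": "hello"}],
        )
    except MissingAuthError as error:
        print("API key required:", error)  # e.g. prompt the user for a key here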
@@ -4,7 +4,7 @@ from typing import Union
 from aiohttp import ClientResponse
 from requests import Response as RequestsResponse

-from ..errors import ResponseStatusError, RateLimitError
+from ..errors import ResponseStatusError, RateLimitError, MissingAuthError
 from . import Response, StreamResponse

 class CloudflareError(ResponseStatusError):
@@ -23,6 +23,7 @@ def is_openai(text: str) -> bool:
 async def raise_for_status_async(response: Union[StreamResponse, ClientResponse], message: str = None):
    if response.ok:
        return
+   is_html = False
    if message is None:
        content_type = response.headers.get("content-type", "")
        if content_type.startswith("application/json"):
@@ -31,39 +32,42 @@ async def raise_for_status_async(response: Union[StreamResponse, ClientResponse]
            if isinstance(message, dict):
                message = message.get("message", message)
        else:
-           text = (await response.text()).strip()
+           message = (await response.text()).strip()
            is_html = content_type.startswith("text/html") or text.startswith("<!DOCTYPE")
-           message = "HTML content" if is_html else text
-   if message is None or message == "HTML content":
+   if message is None or is_html:
        if response.status == 520:
            message = "Unknown error (Cloudflare)"
        elif response.status in (429, 402):
            message = "Rate limit"
-   if response.status == 403 and is_cloudflare(text):
+   if response.status in (401, 402):
+       raise MissingAuthError(f"Response {response.status}: {message}")
+   if response.status == 403 and is_cloudflare(message):
        raise CloudflareError(f"Response {response.status}: Cloudflare detected")
-   elif response.status == 403 and is_openai(text):
+   elif response.status == 403 and is_openai(message):
        raise ResponseStatusError(f"Response {response.status}: OpenAI Bot detected")
    elif response.status == 502:
        raise ResponseStatusError(f"Response {response.status}: Bad Gateway")
    elif response.status == 504:
        raise RateLimitError(f"Response {response.status}: Gateway Timeout ")
    else:
-       raise ResponseStatusError(f"Response {response.status}: {message}")
+       raise ResponseStatusError(f"Response {response.status}: {"HTML content" if is_html else message}")

def raise_for_status(response: Union[Response, StreamResponse, ClientResponse, RequestsResponse], message: str = None):
    if hasattr(response, "status"):
        return raise_for_status_async(response, message)
    if response.ok:
        return
+   is_html = False
    if message is None:
        is_html = response.headers.get("content-type", "").startswith("text/html") or response.text.startswith("<!DOCTYPE")
-       message = "HTML content" if is_html else response.text
-   if message == "HTML content":
+       message = response.text
+   if message is None or is_html:
        if response.status_code == 520:
            message = "Unknown error (Cloudflare)"
        elif response.status_code in (429, 402):
-           message = "Rate limit"
-           raise RateLimitError(f"Response {response.status_code}: {message}")
+           raise RateLimitError(f"Response {response.status_code}: Rate Limit")
+   if response.status_code in (401, 402):
+       raise MissingAuthError(f"Response {response.status_code}: {message}")
    if response.status_code == 403 and is_cloudflare(response.text):
        raise CloudflareError(f"Response {response.status_code}: Cloudflare detected")
    elif response.status_code == 403 and is_openai(response.text):
@@ -73,4 +77,4 @@ def raise_for_status(response: Union[Response, StreamResponse, ClientResponse, R
    elif response.status_code == 504:
        raise RateLimitError(f"Response {response.status_code}: Gateway Timeout ")
    else:
-       raise ResponseStatusError(f"Response {response.status_code}: {message}")
+       raise ResponseStatusError(f"Response {response.status_code}: {"HTML content" if is_html else message}")