Add ip ban

This commit is contained in:
hlohaus
2025-06-24 22:25:42 +02:00
parent 3ce8edce27
commit 744dfeb957
4 changed files with 28 additions and 10 deletions

View File

@@ -3,7 +3,6 @@ from __future__ import annotations
import base64 import base64
import json import json
import requests import requests
import random
from typing import Optional from typing import Optional
from aiohttp import ClientSession, BaseConnector from aiohttp import ClientSession, BaseConnector
@@ -109,10 +108,10 @@ class GeminiPro(AsyncGeneratorProvider, ProviderModelMixin):
] ]
if media is not None: if media is not None:
for media_data, filename in media: for media_data, filename in media:
image = to_bytes(image) media_data = to_bytes(media_data)
contents[-1]["parts"].append({ contents[-1]["parts"].append({
"inline_data": { "inline_data": {
"mime_type": is_data_an_media(image, filename), "mime_type": is_data_an_media(media_data, filename),
"data": base64.b64encode(media_data).decode() "data": base64.b64encode(media_data).decode()
} }
}) })

View File

@@ -135,9 +135,14 @@ class Backend_Api(Api):
else: else:
json_data["provider"] = models.HuggingFace json_data["provider"] = models.HuggingFace
if app.demo: if app.demo:
ip = request.headers.get("X-Forwarded-For", "")
ip_bans = Path(get_cookies_dir()) / ".ip_bans"
if ip_bans.exists():
ip_bans = ip_bans.read_text().splitlines()
if (ip and ip in ip_bans):
return "You are banned from using this service.", 403
user = request.headers.get("Cf-Ipcountry", "") user = request.headers.get("Cf-Ipcountry", "")
ip = request.headers.get("X-Forwarded-For", "").split(":")[-1] json_data["user"] = request.headers.get("x_user", f"{user}:{ip.split(':')[-1]}")
json_data["user"] = request.headers.get("x_user", f"{user}:{ip}")
json_data["referer"] = request.headers.get("referer", "") json_data["referer"] = request.headers.get("referer", "")
json_data["user-agent"] = request.headers.get("user-agent", "") json_data["user-agent"] = request.headers.get("user-agent", "")
if not json_data.get("referer") or "python" in json_data.get("user-agent", "").lower(): if not json_data.get("referer") or "python" in json_data.get("user-agent", "").lower():

View File

@@ -7,6 +7,7 @@ import base64
from io import BytesIO from io import BytesIO
from pathlib import Path from pathlib import Path
from typing import Optional from typing import Optional
from urllib.parse import urlparse
try: try:
from PIL import Image, ImageOps from PIL import Image, ImageOps
@@ -16,6 +17,7 @@ except ImportError:
from ..typing import ImageType from ..typing import ImageType
from ..errors import MissingRequirementsError from ..errors import MissingRequirementsError
from ..tools.files import get_bucket_dir
EXTENSIONS_MAP: dict[str, str] = { EXTENSIONS_MAP: dict[str, str] = {
# Image # Image
@@ -241,15 +243,26 @@ def to_bytes(image: ImageType) -> bytes:
""" """
if isinstance(image, bytes): if isinstance(image, bytes):
return image return image
elif isinstance(image, str) and image.startswith("data:"): elif isinstance(image, str):
is_data_an_media(image) if image.startswith("data:"):
return extract_data_uri(image) is_data_uri_an_image(image)
return extract_data_uri(image)
elif image.startswith("http://") or image.startswith("https://"):
path: str = urlparse(image).path
if path.startswith("/files/"):
path = get_bucket_dir(path.split(path, "/")[1:])
if os.path.exists(path):
return Path(path).read_bytes()
else:
raise FileNotFoundError(f"File not found: {path}")
else:
raise ValueError("Invalid image format. Expected bytes, str, or PIL Image.")
elif isinstance(image, Image): elif isinstance(image, Image):
bytes_io = BytesIO() bytes_io = BytesIO()
image.save(bytes_io, image.format) image.save(bytes_io, image.format)
image.seek(0) image.seek(0)
return bytes_io.getvalue() return bytes_io.getvalue()
elif isinstance(image, (str, os.PathLike)): elif isinstance(image, os.PathLike):
return Path(image).read_bytes() return Path(image).read_bytes()
elif isinstance(image, Path): elif isinstance(image, Path):
return image.read_bytes() return image.read_bytes()

View File

@@ -70,7 +70,8 @@ async def save_response_media(response, prompt: str, tags: list[str]) -> AsyncIt
raise ValueError(f"Unsupported media type: {content_type}") raise ValueError(f"Unsupported media type: {content_type}")
filename = get_filename(tags, prompt, f".{extension}", prompt) filename = get_filename(tags, prompt, f".{extension}", prompt)
filename = update_filename(response, filename) if hasattr(response, "headers"):
filename = update_filename(response, filename)
target_path = os.path.join(get_media_dir(), filename) target_path = os.path.join(get_media_dir(), filename)
ensure_media_dir() ensure_media_dir()
with open(target_path, 'wb') as f: with open(target_path, 'wb') as f: