Add secret validation

This commit is contained in:
hlohaus
2025-06-25 21:30:58 +02:00
parent c466d414c9
commit 7b5ecaa4fb

View File

@@ -3,11 +3,13 @@ from __future__ import annotations
import json
import flask
import os
import time
import logging
import asyncio
import shutil
import random
import datetime
import base64
from urllib.parse import quote_plus
from flask import Flask, Response, redirect, request, jsonify, send_from_directory
from werkzeug.exceptions import NotFound
@@ -26,6 +28,12 @@ try:
has_markitdown = True
except ImportError as e:
has_markitdown = False
try:
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import padding, rsa
has_crypto = True
except ImportError:
has_crypto = False
from ...client.service import convert_to_provider
from ...providers.asyncio import to_sync_generator
@@ -39,7 +47,6 @@ from ...cookies import get_cookies_dir
from ...image.copy_images import secure_filename, get_source_url, get_media_dir, copy_media
from ... import ChatCompletion
from ... import models
from ... import debug
from .api import Api
logger = logging.getLogger(__name__)
@@ -72,6 +79,43 @@ class Backend_Api(Api):
self.app: Flask = app
self.chat_cache = {}
if has_crypto:
private_key_obj = rsa.generate_private_key(public_exponent=65537, key_size=2048)
public_key_obj = private_key_obj.public_key()
public_key_pem = public_key_obj.public_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PublicFormat.SubjectPublicKeyInfo
)
def decrypt_data(encrypted_data: str):
decrypted = private_key_obj.decrypt(
base64.b64decode(encrypted_data),
padding.PKCS1v15()
)
return decrypted.decode('utf-8')
def validate_secret(secret: str) -> bool:
"""
Validates the provided secret against the stored public key.
Args:
secret (str): The secret to validate.
Returns:
bool: True if the secret is valid, False otherwise.
"""
try:
decrypted_secret = decrypt_data(secret)
return int(decrypted_secret) >= time.time() - 2
except Exception as e:
logger.error(f"Secret validation failed: {e}")
return False
@app.route('/backend-api/v2/public-key', methods=['GET'])
def get_public_key():
# Send the public key to the client for encryption
return jsonify({"public_key": public_key_pem.decode('utf-8'), "data": str(int(time.time()))})
@app.route('/backend-api/v2/models', methods=['GET'])
def jsonify_models(**kwargs):
response = get_demo_models() if app.demo else self.get_models(**kwargs)
@@ -110,9 +154,18 @@ class Backend_Api(Api):
Response: A Flask response object for streaming.
"""
if "json" in request.form:
json_data = json.loads(request.form['json'])
json_data = request.form['json']
else:
json_data = request.json
json_data = request.data
try:
json_data = json.loads(json_data)
except json.JSONDecodeError as e:
logger.exception(e)
return jsonify({"error": {"message": "Invalid JSON data"}}), 400
if app.demo and has_crypto:
secret = request.headers.get("x_secret")
if not secret or not validate_secret(secret):
return jsonify({"error": {"message": "Invalid or missing secret"}}), 403
tempfiles = []
media = []
if "files" in request.files:
@@ -136,17 +189,10 @@ class Backend_Api(Api):
json_data["provider"] = models.HuggingFace
if app.demo:
ip = request.headers.get("X-Forwarded-For", "")
ip_bans = Path(get_cookies_dir()) / ".ip_bans"
if ip_bans.exists():
ip_bans = ip_bans.read_text().splitlines()
if (ip and ip in ip_bans):
return "You are banned from using this service.", 403
user = request.headers.get("Cf-Ipcountry", "")
json_data["user"] = request.headers.get("x_user", f"{user}:{ip}")
json_data["referer"] = request.headers.get("referer", "")
json_data["user-agent"] = request.headers.get("user-agent", "")
if not json_data.get("referer") or "python" in json_data.get("user-agent", "").lower():
return "Reduce your requests to 2 at the same time. I only have a budget of 4. More requests cause errors in the console.", 403
kwargs = self._prepare_conversation_kwargs(json_data)
return self.app.response_class(
safe_iter_generator(self._create_response_stream(