mirror of
https://github.com/xtekky/gpt4free.git
synced 2025-10-05 08:16:58 +08:00
Add secret validation
This commit is contained in:
@@ -3,11 +3,13 @@ from __future__ import annotations
|
||||
import json
|
||||
import flask
|
||||
import os
|
||||
import time
|
||||
import logging
|
||||
import asyncio
|
||||
import shutil
|
||||
import random
|
||||
import datetime
|
||||
import base64
|
||||
from urllib.parse import quote_plus
|
||||
from flask import Flask, Response, redirect, request, jsonify, send_from_directory
|
||||
from werkzeug.exceptions import NotFound
|
||||
@@ -26,6 +28,12 @@ try:
|
||||
has_markitdown = True
|
||||
except ImportError as e:
|
||||
has_markitdown = False
|
||||
try:
|
||||
from cryptography.hazmat.primitives import serialization
|
||||
from cryptography.hazmat.primitives.asymmetric import padding, rsa
|
||||
has_crypto = True
|
||||
except ImportError:
|
||||
has_crypto = False
|
||||
|
||||
from ...client.service import convert_to_provider
|
||||
from ...providers.asyncio import to_sync_generator
|
||||
@@ -39,7 +47,6 @@ from ...cookies import get_cookies_dir
|
||||
from ...image.copy_images import secure_filename, get_source_url, get_media_dir, copy_media
|
||||
from ... import ChatCompletion
|
||||
from ... import models
|
||||
from ... import debug
|
||||
from .api import Api
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -72,6 +79,43 @@ class Backend_Api(Api):
|
||||
self.app: Flask = app
|
||||
self.chat_cache = {}
|
||||
|
||||
if has_crypto:
|
||||
private_key_obj = rsa.generate_private_key(public_exponent=65537, key_size=2048)
|
||||
public_key_obj = private_key_obj.public_key()
|
||||
public_key_pem = public_key_obj.public_bytes(
|
||||
encoding=serialization.Encoding.PEM,
|
||||
format=serialization.PublicFormat.SubjectPublicKeyInfo
|
||||
)
|
||||
|
||||
def decrypt_data(encrypted_data: str):
|
||||
decrypted = private_key_obj.decrypt(
|
||||
base64.b64decode(encrypted_data),
|
||||
padding.PKCS1v15()
|
||||
)
|
||||
return decrypted.decode('utf-8')
|
||||
|
||||
def validate_secret(secret: str) -> bool:
|
||||
"""
|
||||
Validates the provided secret against the stored public key.
|
||||
|
||||
Args:
|
||||
secret (str): The secret to validate.
|
||||
|
||||
Returns:
|
||||
bool: True if the secret is valid, False otherwise.
|
||||
"""
|
||||
try:
|
||||
decrypted_secret = decrypt_data(secret)
|
||||
return int(decrypted_secret) >= time.time() - 2
|
||||
except Exception as e:
|
||||
logger.error(f"Secret validation failed: {e}")
|
||||
return False
|
||||
|
||||
@app.route('/backend-api/v2/public-key', methods=['GET'])
|
||||
def get_public_key():
|
||||
# Send the public key to the client for encryption
|
||||
return jsonify({"public_key": public_key_pem.decode('utf-8'), "data": str(int(time.time()))})
|
||||
|
||||
@app.route('/backend-api/v2/models', methods=['GET'])
|
||||
def jsonify_models(**kwargs):
|
||||
response = get_demo_models() if app.demo else self.get_models(**kwargs)
|
||||
@@ -110,9 +154,18 @@ class Backend_Api(Api):
|
||||
Response: A Flask response object for streaming.
|
||||
"""
|
||||
if "json" in request.form:
|
||||
json_data = json.loads(request.form['json'])
|
||||
json_data = request.form['json']
|
||||
else:
|
||||
json_data = request.json
|
||||
json_data = request.data
|
||||
try:
|
||||
json_data = json.loads(json_data)
|
||||
except json.JSONDecodeError as e:
|
||||
logger.exception(e)
|
||||
return jsonify({"error": {"message": "Invalid JSON data"}}), 400
|
||||
if app.demo and has_crypto:
|
||||
secret = request.headers.get("x_secret")
|
||||
if not secret or not validate_secret(secret):
|
||||
return jsonify({"error": {"message": "Invalid or missing secret"}}), 403
|
||||
tempfiles = []
|
||||
media = []
|
||||
if "files" in request.files:
|
||||
@@ -136,17 +189,10 @@ class Backend_Api(Api):
|
||||
json_data["provider"] = models.HuggingFace
|
||||
if app.demo:
|
||||
ip = request.headers.get("X-Forwarded-For", "")
|
||||
ip_bans = Path(get_cookies_dir()) / ".ip_bans"
|
||||
if ip_bans.exists():
|
||||
ip_bans = ip_bans.read_text().splitlines()
|
||||
if (ip and ip in ip_bans):
|
||||
return "You are banned from using this service.", 403
|
||||
user = request.headers.get("Cf-Ipcountry", "")
|
||||
json_data["user"] = request.headers.get("x_user", f"{user}:{ip}")
|
||||
json_data["referer"] = request.headers.get("referer", "")
|
||||
json_data["user-agent"] = request.headers.get("user-agent", "")
|
||||
if not json_data.get("referer") or "python" in json_data.get("user-agent", "").lower():
|
||||
return "Reduce your requests to 2 at the same time. I only have a budget of 4. More requests cause errors in the console.", 403
|
||||
kwargs = self._prepare_conversation_kwargs(json_data)
|
||||
return self.app.response_class(
|
||||
safe_iter_generator(self._create_response_stream(
|
||||
|
Reference in New Issue
Block a user