mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-12-24 13:28:13 +08:00
[Feature] Enable FastDeploy to support adding the “--api-key” authentication parameter. (#4806)
* add api key initial commit * add unit test * modify unit test * move middleware to a single file and add unit tests
This commit is contained in:
@@ -1,5 +1,4 @@
|
||||
"""
|
||||
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
"""# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"
|
||||
# you may not use this file except in compliance with the License.
|
||||
@@ -38,6 +37,7 @@ from fastdeploy.engine.engine import LLMEngine
|
||||
from fastdeploy.engine.expert_service import ExpertService
|
||||
from fastdeploy.entrypoints.chat_utils import load_chat_template
|
||||
from fastdeploy.entrypoints.engine_client import EngineClient
|
||||
from fastdeploy.entrypoints.openai.middleware import AuthenticationMiddleware
|
||||
from fastdeploy.entrypoints.openai.protocol import (
|
||||
ChatCompletionRequest,
|
||||
ChatCompletionResponse,
|
||||
@@ -266,6 +266,12 @@ app.add_exception_handler(Exception, ExceptionHandler.handle_exception)
|
||||
instrument(app)
|
||||
|
||||
|
||||
env_api_key_func = environment_variables.get("FD_API_KEY")
|
||||
env_tokens = env_api_key_func() if env_api_key_func else []
|
||||
if tokens := [key for key in (args.api_key or env_tokens) if key]:
|
||||
app.add_middleware(AuthenticationMiddleware, tokens)
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def connection_manager():
|
||||
"""
|
||||
|
||||
55
fastdeploy/entrypoints/openai/middleware.py
Normal file
55
fastdeploy/entrypoints/openai/middleware.py
Normal file
@@ -0,0 +1,55 @@
|
||||
import hashlib
|
||||
import secrets
|
||||
from collections.abc import Awaitable
|
||||
|
||||
from fastapi.responses import JSONResponse
|
||||
from starlette.datastructures import URL, Headers
|
||||
from starlette.types import ASGIApp, Receive, Scope, Send
|
||||
|
||||
|
||||
class AuthenticationMiddleware:
|
||||
"""
|
||||
Pure ASGI middleware that authenticates each request by checking
|
||||
if the Authorization Bearer token exists and equals anyof "{api_key}".
|
||||
|
||||
Notes
|
||||
-----
|
||||
There are two cases in which authentication is skipped:
|
||||
1. The HTTP method is OPTIONS.
|
||||
2. The request path doesn't start with /v1 (e.g. /health).
|
||||
"""
|
||||
|
||||
def __init__(self, app: ASGIApp, tokens: list[str]) -> None:
|
||||
self.app = app
|
||||
self.api_tokens = [hashlib.sha256(t.encode("utf-8")).digest() for t in tokens]
|
||||
|
||||
def verify_token(self, headers: Headers) -> bool:
|
||||
authorization_header_value = headers.get("Authorization")
|
||||
if not authorization_header_value:
|
||||
return False
|
||||
|
||||
scheme, _, param = authorization_header_value.partition(" ")
|
||||
if scheme.lower() != "bearer":
|
||||
return False
|
||||
|
||||
param_hash = hashlib.sha256(param.encode("utf-8")).digest()
|
||||
|
||||
token_match = False
|
||||
for token_hash in self.api_tokens:
|
||||
token_match |= secrets.compare_digest(param_hash, token_hash)
|
||||
|
||||
return token_match
|
||||
|
||||
def __call__(self, scope: Scope, receive: Receive, send: Send) -> Awaitable[None]:
|
||||
if scope["type"] not in ("http", "websocket") or scope["method"] == "OPTIONS":
|
||||
# scope["type"] can be "lifespan" or "startup" for example,
|
||||
# in which case we don't need to do anything
|
||||
return self.app(scope, receive, send)
|
||||
root_path = scope.get("root_path", "")
|
||||
url_path = URL(scope=scope).path.removeprefix(root_path)
|
||||
headers = Headers(scope=scope)
|
||||
# Type narrow to satisfy mypy.
|
||||
if url_path.startswith("/v1") and not self.verify_token(headers):
|
||||
response = JSONResponse(content={"error": "Unauthorized"}, status_code=401)
|
||||
return response(scope, receive, send)
|
||||
return self.app(scope, receive, send)
|
||||
@@ -239,5 +239,7 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
|
||||
help="Workers silent for more than this many seconds are killed and restarted.Value is a positive number or 0. Setting it to 0 has the effect of infinite timeouts by disabling timeouts for all workers entirely.",
|
||||
)
|
||||
|
||||
parser.add_argument("--api-key", type=str, action="append", help="API_KEY required for service authentication")
|
||||
|
||||
parser = EngineArgs.add_cli_args(parser)
|
||||
return parser
|
||||
|
||||
@@ -130,6 +130,8 @@ environment_variables: dict[str, Callable[[], Any]] = {
|
||||
"FD_CACHE_PROC_EXIT_TIMEOUT": lambda: int(os.getenv("FD_CACHE_PROC_EXIT_TIMEOUT", "600")),
|
||||
# Count for cache_transfer_manager process error
|
||||
"FD_CACHE_PROC_ERROR_COUNT": lambda: int(os.getenv("FD_CACHE_PROC_ERROR_COUNT", "10")),
|
||||
# API_KEY required for service authentication
|
||||
"FD_API_KEY": lambda: [] if "FD_API_KEY" not in os.environ else os.environ["FD_API_KEY"].split(","),
|
||||
# EPLB related
|
||||
"FD_ENABLE_REDUNDANT_EXPERTS": lambda: int(os.getenv("FD_ENABLE_REDUNDANT_EXPERTS", "0")) == 1,
|
||||
"FD_REDUNDANT_EXPERTS_NUM": lambda: int(os.getenv("FD_REDUNDANT_EXPERTS_NUM", "0")),
|
||||
|
||||
231
tests/e2e/test_api_key.py
Normal file
231
tests/e2e/test_api_key.py
Normal file
@@ -0,0 +1,231 @@
|
||||
import os
|
||||
import signal
|
||||
import socket
|
||||
import subprocess
|
||||
import sys
|
||||
import time
|
||||
from typing import Optional
|
||||
|
||||
import pytest
|
||||
import requests
|
||||
|
||||
FD_API_PORT = int(os.getenv("FD_API_PORT", 8188))
|
||||
FD_ENGINE_QUEUE_PORT = int(os.getenv("FD_ENGINE_QUEUE_PORT", 8133))
|
||||
FD_METRICS_PORT = int(os.getenv("FD_METRICS_PORT", 8233))
|
||||
FD_CACHE_QUEUE_PORT = int(os.getenv("FD_CACHE_QUEUE_PORT", 8333))
|
||||
PORTS_TO_CLEAN = [FD_API_PORT, FD_ENGINE_QUEUE_PORT, FD_METRICS_PORT, FD_CACHE_QUEUE_PORT]
|
||||
|
||||
current_server_process: Optional[subprocess.Popen] = None
|
||||
|
||||
|
||||
def is_port_open(host: str, port: int, timeout=1.0):
|
||||
"""
|
||||
Check if a TCP port is open on the given host.
|
||||
Returns True if connection succeeds, False otherwise.
|
||||
"""
|
||||
try:
|
||||
with socket.create_connection((host, port), timeout):
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
def kill_process_on_port(port: int):
|
||||
"""
|
||||
Kill processes that are listening on the given port.
|
||||
Uses `lsof` to find process ids and sends SIGKILL.
|
||||
"""
|
||||
try:
|
||||
output = subprocess.check_output(f"lsof -i:{port} -t", shell=True).decode().strip()
|
||||
current_pid = os.getpid()
|
||||
parent_pid = os.getppid()
|
||||
for pid in output.splitlines():
|
||||
pid = int(pid)
|
||||
if pid in (current_pid, parent_pid):
|
||||
print(f"Skip killing current process (pid={pid}) on port {port}")
|
||||
continue
|
||||
os.kill(pid, signal.SIGKILL)
|
||||
print(f"Killed process on port {port}, pid={pid}")
|
||||
except subprocess.CalledProcessError:
|
||||
pass
|
||||
|
||||
|
||||
def clean_ports():
|
||||
"""
|
||||
Kill all processes occupying the ports listed in PORTS_TO_CLEAN.
|
||||
"""
|
||||
for port in PORTS_TO_CLEAN:
|
||||
kill_process_on_port(port)
|
||||
time.sleep(2)
|
||||
|
||||
|
||||
def start_api_server(api_key_cli: Optional[list[str]] = None, api_key_env: Optional[str] = None):
|
||||
global current_server_process
|
||||
clean_ports()
|
||||
|
||||
env = os.environ.copy()
|
||||
if api_key_env is not None:
|
||||
env["FD_API_KEY"] = api_key_env
|
||||
else:
|
||||
env.pop("FD_API_KEY", None)
|
||||
base_path = os.getenv("MODEL_PATH")
|
||||
if base_path:
|
||||
model_path = os.path.join(base_path, "ERNIE-4.5-0.3B-Paddle")
|
||||
else:
|
||||
model_path = "./ERNIE-4.5-0.3B-Paddle"
|
||||
log_path = "server.log"
|
||||
|
||||
cmd = [
|
||||
sys.executable,
|
||||
"-m",
|
||||
"fastdeploy.entrypoints.openai.api_server",
|
||||
"--model",
|
||||
model_path,
|
||||
"--port",
|
||||
str(FD_API_PORT),
|
||||
"--tensor-parallel-size",
|
||||
"1",
|
||||
"--engine-worker-queue-port",
|
||||
str(FD_ENGINE_QUEUE_PORT),
|
||||
"--metrics-port",
|
||||
str(FD_METRICS_PORT),
|
||||
"--cache-queue-port",
|
||||
str(FD_CACHE_QUEUE_PORT),
|
||||
"--max-model-len",
|
||||
"32768",
|
||||
"--max-num-seqs",
|
||||
"128",
|
||||
"--quantization",
|
||||
"wint4",
|
||||
"--graph-optimization-config",
|
||||
'{"cudagraph_capture_sizes": [1], "use_cudagraph":true}',
|
||||
]
|
||||
|
||||
if api_key_cli is not None:
|
||||
for key in api_key_cli:
|
||||
cmd.extend(["--api-key", key])
|
||||
|
||||
with open(log_path, "w") as logfile:
|
||||
process = subprocess.Popen(cmd, stdout=logfile, stderr=subprocess.STDOUT, start_new_session=True, env=env)
|
||||
|
||||
for _ in range(300):
|
||||
if is_port_open("127.0.0.1", FD_API_PORT):
|
||||
print(f"API server started (port: {FD_API_PORT}, cli_key: {api_key_cli}, env_key: {api_key_env})")
|
||||
current_server_process = process
|
||||
return process
|
||||
time.sleep(1)
|
||||
else:
|
||||
if process.poll() is None:
|
||||
os.killpg(process.pid, signal.SIGTERM)
|
||||
raise RuntimeError(f"API server failed to start in 5 minutes (port: {FD_API_PORT})")
|
||||
|
||||
|
||||
def stop_api_server():
|
||||
global current_server_process
|
||||
if current_server_process and current_server_process.poll() is None:
|
||||
try:
|
||||
os.killpg(current_server_process.pid, signal.SIGTERM)
|
||||
current_server_process.wait(timeout=10)
|
||||
print(f"API server stopped (pid: {current_server_process.pid})")
|
||||
except Exception as e:
|
||||
print(f"Failed to stop server: {e}")
|
||||
current_server_process = None
|
||||
clean_ports()
|
||||
|
||||
|
||||
@pytest.fixture(scope="function", autouse=True)
|
||||
def teardown_server():
|
||||
yield
|
||||
stop_api_server()
|
||||
os.environ.pop("FD_API_KEY", None)
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def api_url():
|
||||
return f"http://0.0.0.0:{FD_API_PORT}/v1/chat/completions"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def common_headers():
|
||||
return {"Content-Type": "application/json"}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def valid_auth_headers():
|
||||
return {"Content-Type": "application/json", "Authorization": "Bearer {api_key}"}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def test_payload():
|
||||
return {"messages": [{"role": "user", "content": "hello"}], "temperature": 0.9, "max_tokens": 100}
|
||||
|
||||
|
||||
def test_api_key_cli_only(api_url, common_headers, valid_auth_headers, test_payload):
|
||||
test_api_key = ["cli_test_key_123", "cli_test_key_456"]
|
||||
start_api_server(api_key_cli=test_api_key)
|
||||
|
||||
response = requests.post(api_url, json=test_payload, headers=common_headers)
|
||||
assert response.status_code == 401
|
||||
assert "error" in response.json()
|
||||
assert "unauthorized" in response.json()["error"].lower()
|
||||
|
||||
invalid_headers = valid_auth_headers.copy()
|
||||
invalid_headers["Authorization"] = invalid_headers["Authorization"].format(api_key="wrong_key")
|
||||
response = requests.post(api_url, json=test_payload, headers=invalid_headers)
|
||||
assert response.status_code == 401
|
||||
|
||||
valid_headers = valid_auth_headers.copy()
|
||||
valid_headers["Authorization"] = valid_headers["Authorization"].format(api_key=test_api_key[0])
|
||||
response = requests.post(api_url, json=test_payload, headers=valid_headers)
|
||||
assert response.status_code == 200
|
||||
|
||||
valid_headers = valid_auth_headers.copy()
|
||||
valid_headers["Authorization"] = valid_headers["Authorization"].format(api_key=test_api_key[1])
|
||||
response = requests.post(api_url, json=test_payload, headers=valid_headers)
|
||||
assert response.status_code == 200
|
||||
|
||||
|
||||
def test_api_key_env_only(api_url, common_headers, valid_auth_headers, test_payload):
|
||||
test_api_key = "env_test_key_456,env_test_key_789"
|
||||
start_api_server(api_key_env=test_api_key)
|
||||
|
||||
response = requests.post(api_url, json=test_payload, headers=common_headers)
|
||||
assert response.status_code == 401
|
||||
|
||||
valid_headers = valid_auth_headers.copy()
|
||||
valid_headers["Authorization"] = valid_headers["Authorization"].format(api_key="env_test_key_456")
|
||||
response = requests.post(api_url, json=test_payload, headers=valid_headers)
|
||||
assert response.status_code == 200
|
||||
|
||||
valid_headers = valid_auth_headers.copy()
|
||||
valid_headers["Authorization"] = valid_headers["Authorization"].format(api_key="env_test_key_789")
|
||||
response = requests.post(api_url, json=test_payload, headers=valid_headers)
|
||||
assert response.status_code == 200
|
||||
|
||||
|
||||
def test_api_key_cli_priority_over_env(api_url, valid_auth_headers, test_payload):
|
||||
cli_key = ["cli_priority_key_789"]
|
||||
env_key = "env_low_priority_key_000"
|
||||
start_api_server(api_key_cli=cli_key, api_key_env=env_key)
|
||||
|
||||
env_headers = valid_auth_headers.copy()
|
||||
env_headers["Authorization"] = env_headers["Authorization"].format(api_key=env_key)
|
||||
response = requests.post(api_url, json=test_payload, headers=env_headers)
|
||||
assert response.status_code == 401
|
||||
|
||||
cli_headers = valid_auth_headers.copy()
|
||||
cli_headers["Authorization"] = cli_headers["Authorization"].format(api_key=cli_key[0])
|
||||
response = requests.post(api_url, json=test_payload, headers=cli_headers)
|
||||
assert response.status_code == 200
|
||||
|
||||
|
||||
def test_api_key_not_set(api_url, common_headers, valid_auth_headers, test_payload):
|
||||
start_api_server(api_key_cli=None, api_key_env=None)
|
||||
|
||||
response = requests.post(api_url, json=test_payload, headers=common_headers)
|
||||
assert response.status_code == 200
|
||||
|
||||
cli_headers = valid_auth_headers.copy()
|
||||
cli_headers["Authorization"] = cli_headers["Authorization"].format(api_key="some_api_key")
|
||||
response = requests.post(api_url, json=test_payload, headers=cli_headers)
|
||||
assert response.status_code == 200
|
||||
155
tests/entrypoints/openai/test_api_authentication.py
Normal file
155
tests/entrypoints/openai/test_api_authentication.py
Normal file
@@ -0,0 +1,155 @@
|
||||
import asyncio
|
||||
import hashlib
|
||||
import secrets
|
||||
from unittest.mock import AsyncMock, Mock
|
||||
|
||||
import pytest
|
||||
from starlette.datastructures import Headers
|
||||
from starlette.types import ASGIApp, Receive, Scope, Send
|
||||
|
||||
from fastdeploy.entrypoints.openai.middleware import AuthenticationMiddleware
|
||||
|
||||
|
||||
def mock_asgi_app() -> tuple[ASGIApp, Mock]:
|
||||
mock_send = Mock()
|
||||
|
||||
async def mock_app(scope: Scope, receive: Receive, send: Send) -> None:
|
||||
await asyncio.sleep(0)
|
||||
mock_send(scope=scope, receive=receive, send=send)
|
||||
|
||||
return mock_app, mock_send
|
||||
|
||||
|
||||
def create_test_scope(
|
||||
path: str = "/v1/chat/completions", method: str = "POST", headers: dict = None, scope_type: str = "http"
|
||||
) -> Scope:
|
||||
headers = headers or {}
|
||||
scope_headers = []
|
||||
for k, v in headers.items():
|
||||
key_bytes = str(k).lower().encode("latin-1")
|
||||
value_bytes = str(v).lower().encode("latin-1")
|
||||
scope_headers.append((key_bytes, value_bytes))
|
||||
return {
|
||||
"type": scope_type,
|
||||
"method": method,
|
||||
"headers": scope_headers,
|
||||
"path": path,
|
||||
"root_path": "",
|
||||
}
|
||||
|
||||
|
||||
class TestAuthenticationMiddleware:
|
||||
VALID_TOKENS = ["test_key_123", "another_valid_key_456"]
|
||||
INVALID_TOKEN = "wrong_key_789"
|
||||
EXPECTED_TOKEN_HASHES = [hashlib.sha256(t.encode("utf-8")).digest() for t in VALID_TOKENS]
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup_middleware(self):
|
||||
self.mock_app, self.mock_send = mock_asgi_app()
|
||||
self.middleware = AuthenticationMiddleware(app=self.mock_app, tokens=self.VALID_TOKENS)
|
||||
|
||||
def test_verify_token_no_authorization_header(self):
|
||||
headers = Headers()
|
||||
assert self.middleware.verify_token(headers) is False
|
||||
|
||||
def test_verify_token_invalid_scheme(self):
|
||||
headers = Headers({"Authorization": "Basic wrong_scheme"})
|
||||
assert self.middleware.verify_token(headers) is False
|
||||
|
||||
def test_verify_token_valid_token(self):
|
||||
for valid_token in self.VALID_TOKENS:
|
||||
headers = Headers({"Authorization": f"Bearer {valid_token}"})
|
||||
assert self.middleware.verify_token(headers) is True
|
||||
|
||||
def test_verify_token_invalid_token(self):
|
||||
headers = Headers({"Authorization": f"Bearer {self.INVALID_TOKEN}"})
|
||||
assert self.middleware.verify_token(headers) is False
|
||||
|
||||
def test_verify_token_hash_comparison(self):
|
||||
valid_token = self.VALID_TOKENS[0]
|
||||
param_hash = hashlib.sha256(valid_token.encode("utf-8")).digest()
|
||||
|
||||
assert self.middleware.api_tokens == self.EXPECTED_TOKEN_HASHES
|
||||
|
||||
with pytest.MonkeyPatch.context() as mp:
|
||||
mock_compare = Mock(return_value=True)
|
||||
mp.setattr(secrets, "compare_digest", mock_compare)
|
||||
|
||||
headers = Headers({"Authorization": f"Bearer {valid_token}"})
|
||||
self.middleware.verify_token(headers)
|
||||
assert mock_compare.call_count == len(self.EXPECTED_TOKEN_HASHES)
|
||||
mock_compare.assert_any_call(param_hash, self.EXPECTED_TOKEN_HASHES[0])
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_call_skip_non_v1_path(self):
|
||||
for path in ["/health", "/metrics", "/docs"]:
|
||||
scope = create_test_scope(path=path)
|
||||
receive = AsyncMock()
|
||||
send = AsyncMock()
|
||||
|
||||
await self.middleware(scope, receive, send)
|
||||
|
||||
self.mock_send.assert_called_once_with(scope=scope, receive=receive, send=send)
|
||||
self.mock_send.reset_mock()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_call_skip_options_method(self):
|
||||
scope = create_test_scope(method="OPTIONS", path="/v1/chat/completions")
|
||||
receive = AsyncMock()
|
||||
send = AsyncMock()
|
||||
|
||||
await self.middleware(scope, receive, send)
|
||||
|
||||
self.mock_send.assert_called_once_with(scope=scope, receive=receive, send=send)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_call_skip_non_http_websocket_scope(self):
|
||||
for scope_type in ["lifespan", "startup", "shutdown"]:
|
||||
scope = create_test_scope(scope_type=scope_type)
|
||||
receive = AsyncMock()
|
||||
send = AsyncMock()
|
||||
|
||||
await self.middleware(scope, receive, send)
|
||||
|
||||
self.mock_send.assert_called_once_with(scope=scope, receive=receive, send=send)
|
||||
self.mock_send.reset_mock()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_call_v1_path_valid_token(self):
|
||||
scope = create_test_scope(headers={"Authorization": f"Bearer {self.VALID_TOKENS[0]}"})
|
||||
print(scope)
|
||||
receive = AsyncMock()
|
||||
send = AsyncMock()
|
||||
|
||||
headers = Headers(scope=scope)
|
||||
assert self.middleware.verify_token(headers) is True
|
||||
await self.middleware(scope, receive, send)
|
||||
|
||||
self.mock_send.assert_called_once_with(scope=scope, receive=receive, send=send)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_call_v1_path_invalid_token(self):
|
||||
scope = create_test_scope(headers={"Authorization": f"Bearer {self.INVALID_TOKEN}"})
|
||||
receive = AsyncMock()
|
||||
send = AsyncMock()
|
||||
|
||||
await self.middleware(scope, receive, send)
|
||||
|
||||
self.mock_send.assert_not_called()
|
||||
assert send.called
|
||||
send_call = send.call_args[0][0]
|
||||
assert isinstance(send_call, dict)
|
||||
assert "Unauthorized" in send_call["body"].decode("utf-8")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_call_v1_path_no_token(self):
|
||||
scope = create_test_scope(headers={})
|
||||
receive = AsyncMock()
|
||||
send = AsyncMock()
|
||||
|
||||
await self.middleware(scope, receive, send)
|
||||
|
||||
self.mock_send.assert_not_called()
|
||||
send_call = send.call_args[0][0]
|
||||
assert isinstance(send_call, dict)
|
||||
assert "Unauthorized" in send_call["body"].decode("utf-8")
|
||||
@@ -35,6 +35,7 @@ with (
|
||||
scheduler_db=None,
|
||||
scheduler_password=None,
|
||||
scheduler_topic=None,
|
||||
api_key=None,
|
||||
)
|
||||
mock_parse_args.return_value = mock_args
|
||||
mock_retrive_model.return_value = "test-model" # Just return the model name without downloading
|
||||
|
||||
Reference in New Issue
Block a user