Files
FastDeploy/tests/entrypoints/openai/test_api_authentication.py
kxz2002 87911b7cf1 [Feature] Enable FastDeploy to support adding the “--api-key” authentication parameter. (#4806)
* add api key initial commit

* add unit test

* modify unit test

* move middleware to a single file and add unit tests
2025-11-08 18:24:02 +08:00

156 lines
5.5 KiB
Python

import asyncio
import hashlib
import secrets
from unittest.mock import AsyncMock, Mock
import pytest
from starlette.datastructures import Headers
from starlette.types import ASGIApp, Receive, Scope, Send
from fastdeploy.entrypoints.openai.middleware import AuthenticationMiddleware
def mock_asgi_app() -> tuple[ASGIApp, Mock]:
mock_send = Mock()
async def mock_app(scope: Scope, receive: Receive, send: Send) -> None:
await asyncio.sleep(0)
mock_send(scope=scope, receive=receive, send=send)
return mock_app, mock_send
def create_test_scope(
path: str = "/v1/chat/completions", method: str = "POST", headers: dict = None, scope_type: str = "http"
) -> Scope:
headers = headers or {}
scope_headers = []
for k, v in headers.items():
key_bytes = str(k).lower().encode("latin-1")
value_bytes = str(v).lower().encode("latin-1")
scope_headers.append((key_bytes, value_bytes))
return {
"type": scope_type,
"method": method,
"headers": scope_headers,
"path": path,
"root_path": "",
}
class TestAuthenticationMiddleware:
VALID_TOKENS = ["test_key_123", "another_valid_key_456"]
INVALID_TOKEN = "wrong_key_789"
EXPECTED_TOKEN_HASHES = [hashlib.sha256(t.encode("utf-8")).digest() for t in VALID_TOKENS]
@pytest.fixture(autouse=True)
def setup_middleware(self):
self.mock_app, self.mock_send = mock_asgi_app()
self.middleware = AuthenticationMiddleware(app=self.mock_app, tokens=self.VALID_TOKENS)
def test_verify_token_no_authorization_header(self):
headers = Headers()
assert self.middleware.verify_token(headers) is False
def test_verify_token_invalid_scheme(self):
headers = Headers({"Authorization": "Basic wrong_scheme"})
assert self.middleware.verify_token(headers) is False
def test_verify_token_valid_token(self):
for valid_token in self.VALID_TOKENS:
headers = Headers({"Authorization": f"Bearer {valid_token}"})
assert self.middleware.verify_token(headers) is True
def test_verify_token_invalid_token(self):
headers = Headers({"Authorization": f"Bearer {self.INVALID_TOKEN}"})
assert self.middleware.verify_token(headers) is False
def test_verify_token_hash_comparison(self):
valid_token = self.VALID_TOKENS[0]
param_hash = hashlib.sha256(valid_token.encode("utf-8")).digest()
assert self.middleware.api_tokens == self.EXPECTED_TOKEN_HASHES
with pytest.MonkeyPatch.context() as mp:
mock_compare = Mock(return_value=True)
mp.setattr(secrets, "compare_digest", mock_compare)
headers = Headers({"Authorization": f"Bearer {valid_token}"})
self.middleware.verify_token(headers)
assert mock_compare.call_count == len(self.EXPECTED_TOKEN_HASHES)
mock_compare.assert_any_call(param_hash, self.EXPECTED_TOKEN_HASHES[0])
@pytest.mark.asyncio
async def test_call_skip_non_v1_path(self):
for path in ["/health", "/metrics", "/docs"]:
scope = create_test_scope(path=path)
receive = AsyncMock()
send = AsyncMock()
await self.middleware(scope, receive, send)
self.mock_send.assert_called_once_with(scope=scope, receive=receive, send=send)
self.mock_send.reset_mock()
@pytest.mark.asyncio
async def test_call_skip_options_method(self):
scope = create_test_scope(method="OPTIONS", path="/v1/chat/completions")
receive = AsyncMock()
send = AsyncMock()
await self.middleware(scope, receive, send)
self.mock_send.assert_called_once_with(scope=scope, receive=receive, send=send)
@pytest.mark.asyncio
async def test_call_skip_non_http_websocket_scope(self):
for scope_type in ["lifespan", "startup", "shutdown"]:
scope = create_test_scope(scope_type=scope_type)
receive = AsyncMock()
send = AsyncMock()
await self.middleware(scope, receive, send)
self.mock_send.assert_called_once_with(scope=scope, receive=receive, send=send)
self.mock_send.reset_mock()
@pytest.mark.asyncio
async def test_call_v1_path_valid_token(self):
scope = create_test_scope(headers={"Authorization": f"Bearer {self.VALID_TOKENS[0]}"})
print(scope)
receive = AsyncMock()
send = AsyncMock()
headers = Headers(scope=scope)
assert self.middleware.verify_token(headers) is True
await self.middleware(scope, receive, send)
self.mock_send.assert_called_once_with(scope=scope, receive=receive, send=send)
@pytest.mark.asyncio
async def test_call_v1_path_invalid_token(self):
scope = create_test_scope(headers={"Authorization": f"Bearer {self.INVALID_TOKEN}"})
receive = AsyncMock()
send = AsyncMock()
await self.middleware(scope, receive, send)
self.mock_send.assert_not_called()
assert send.called
send_call = send.call_args[0][0]
assert isinstance(send_call, dict)
assert "Unauthorized" in send_call["body"].decode("utf-8")
@pytest.mark.asyncio
async def test_call_v1_path_no_token(self):
scope = create_test_scope(headers={})
receive = AsyncMock()
send = AsyncMock()
await self.middleware(scope, receive, send)
self.mock_send.assert_not_called()
send_call = send.call_args[0][0]
assert isinstance(send_call, dict)
assert "Unauthorized" in send_call["body"].decode("utf-8")