mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-12-24 13:28:13 +08:00
* add api key initial commit * add unit test * modify unit test * move middleware to a single file and add unit tests
156 lines
5.5 KiB
Python
156 lines
5.5 KiB
Python
import asyncio
|
|
import hashlib
|
|
import secrets
|
|
from unittest.mock import AsyncMock, Mock
|
|
|
|
import pytest
|
|
from starlette.datastructures import Headers
|
|
from starlette.types import ASGIApp, Receive, Scope, Send
|
|
|
|
from fastdeploy.entrypoints.openai.middleware import AuthenticationMiddleware
|
|
|
|
|
|
def mock_asgi_app() -> tuple[ASGIApp, Mock]:
|
|
mock_send = Mock()
|
|
|
|
async def mock_app(scope: Scope, receive: Receive, send: Send) -> None:
|
|
await asyncio.sleep(0)
|
|
mock_send(scope=scope, receive=receive, send=send)
|
|
|
|
return mock_app, mock_send
|
|
|
|
|
|
def create_test_scope(
|
|
path: str = "/v1/chat/completions", method: str = "POST", headers: dict = None, scope_type: str = "http"
|
|
) -> Scope:
|
|
headers = headers or {}
|
|
scope_headers = []
|
|
for k, v in headers.items():
|
|
key_bytes = str(k).lower().encode("latin-1")
|
|
value_bytes = str(v).lower().encode("latin-1")
|
|
scope_headers.append((key_bytes, value_bytes))
|
|
return {
|
|
"type": scope_type,
|
|
"method": method,
|
|
"headers": scope_headers,
|
|
"path": path,
|
|
"root_path": "",
|
|
}
|
|
|
|
|
|
class TestAuthenticationMiddleware:
|
|
VALID_TOKENS = ["test_key_123", "another_valid_key_456"]
|
|
INVALID_TOKEN = "wrong_key_789"
|
|
EXPECTED_TOKEN_HASHES = [hashlib.sha256(t.encode("utf-8")).digest() for t in VALID_TOKENS]
|
|
|
|
@pytest.fixture(autouse=True)
|
|
def setup_middleware(self):
|
|
self.mock_app, self.mock_send = mock_asgi_app()
|
|
self.middleware = AuthenticationMiddleware(app=self.mock_app, tokens=self.VALID_TOKENS)
|
|
|
|
def test_verify_token_no_authorization_header(self):
|
|
headers = Headers()
|
|
assert self.middleware.verify_token(headers) is False
|
|
|
|
def test_verify_token_invalid_scheme(self):
|
|
headers = Headers({"Authorization": "Basic wrong_scheme"})
|
|
assert self.middleware.verify_token(headers) is False
|
|
|
|
def test_verify_token_valid_token(self):
|
|
for valid_token in self.VALID_TOKENS:
|
|
headers = Headers({"Authorization": f"Bearer {valid_token}"})
|
|
assert self.middleware.verify_token(headers) is True
|
|
|
|
def test_verify_token_invalid_token(self):
|
|
headers = Headers({"Authorization": f"Bearer {self.INVALID_TOKEN}"})
|
|
assert self.middleware.verify_token(headers) is False
|
|
|
|
def test_verify_token_hash_comparison(self):
|
|
valid_token = self.VALID_TOKENS[0]
|
|
param_hash = hashlib.sha256(valid_token.encode("utf-8")).digest()
|
|
|
|
assert self.middleware.api_tokens == self.EXPECTED_TOKEN_HASHES
|
|
|
|
with pytest.MonkeyPatch.context() as mp:
|
|
mock_compare = Mock(return_value=True)
|
|
mp.setattr(secrets, "compare_digest", mock_compare)
|
|
|
|
headers = Headers({"Authorization": f"Bearer {valid_token}"})
|
|
self.middleware.verify_token(headers)
|
|
assert mock_compare.call_count == len(self.EXPECTED_TOKEN_HASHES)
|
|
mock_compare.assert_any_call(param_hash, self.EXPECTED_TOKEN_HASHES[0])
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_call_skip_non_v1_path(self):
|
|
for path in ["/health", "/metrics", "/docs"]:
|
|
scope = create_test_scope(path=path)
|
|
receive = AsyncMock()
|
|
send = AsyncMock()
|
|
|
|
await self.middleware(scope, receive, send)
|
|
|
|
self.mock_send.assert_called_once_with(scope=scope, receive=receive, send=send)
|
|
self.mock_send.reset_mock()
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_call_skip_options_method(self):
|
|
scope = create_test_scope(method="OPTIONS", path="/v1/chat/completions")
|
|
receive = AsyncMock()
|
|
send = AsyncMock()
|
|
|
|
await self.middleware(scope, receive, send)
|
|
|
|
self.mock_send.assert_called_once_with(scope=scope, receive=receive, send=send)
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_call_skip_non_http_websocket_scope(self):
|
|
for scope_type in ["lifespan", "startup", "shutdown"]:
|
|
scope = create_test_scope(scope_type=scope_type)
|
|
receive = AsyncMock()
|
|
send = AsyncMock()
|
|
|
|
await self.middleware(scope, receive, send)
|
|
|
|
self.mock_send.assert_called_once_with(scope=scope, receive=receive, send=send)
|
|
self.mock_send.reset_mock()
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_call_v1_path_valid_token(self):
|
|
scope = create_test_scope(headers={"Authorization": f"Bearer {self.VALID_TOKENS[0]}"})
|
|
print(scope)
|
|
receive = AsyncMock()
|
|
send = AsyncMock()
|
|
|
|
headers = Headers(scope=scope)
|
|
assert self.middleware.verify_token(headers) is True
|
|
await self.middleware(scope, receive, send)
|
|
|
|
self.mock_send.assert_called_once_with(scope=scope, receive=receive, send=send)
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_call_v1_path_invalid_token(self):
|
|
scope = create_test_scope(headers={"Authorization": f"Bearer {self.INVALID_TOKEN}"})
|
|
receive = AsyncMock()
|
|
send = AsyncMock()
|
|
|
|
await self.middleware(scope, receive, send)
|
|
|
|
self.mock_send.assert_not_called()
|
|
assert send.called
|
|
send_call = send.call_args[0][0]
|
|
assert isinstance(send_call, dict)
|
|
assert "Unauthorized" in send_call["body"].decode("utf-8")
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_call_v1_path_no_token(self):
|
|
scope = create_test_scope(headers={})
|
|
receive = AsyncMock()
|
|
send = AsyncMock()
|
|
|
|
await self.middleware(scope, receive, send)
|
|
|
|
self.mock_send.assert_not_called()
|
|
send_call = send.call_args[0][0]
|
|
assert isinstance(send_call, dict)
|
|
assert "Unauthorized" in send_call["body"].decode("utf-8")
|