Files
FastDeploy/tests/entrypoints/openai/test_api_authentication.py
Echo-Nie 1b1bfab341 [CI] Add unittest (#5328)
* add test_worker_eplb

* remove tesnsor_wise_fp8

* add copyright
2025-12-09 19:19:42 +08:00

172 lines
6.1 KiB
Python

"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
import asyncio
import hashlib
import secrets
from unittest.mock import AsyncMock, Mock
import pytest
from starlette.datastructures import Headers
from starlette.types import ASGIApp, Receive, Scope, Send
from fastdeploy.entrypoints.openai.middleware import AuthenticationMiddleware
def mock_asgi_app() -> tuple[ASGIApp, Mock]:
    """Build a stand-in ASGI application that records how it was invoked.

    Returns:
        A ``(app, recorder)`` pair. ``app`` is an async ASGI callable. It
        yields to the event loop once, then calls ``recorder`` with the
        ``scope``, ``receive`` and ``send`` it received, passed as keyword
        arguments. Tests use ``recorder`` to check whether (and with what)
        a request was forwarded downstream.
    """
    recorder = Mock()

    async def downstream_app(scope: Scope, receive: Receive, send: Send) -> None:
        # Suspend once so the app behaves like a real coroutine that awaits.
        await asyncio.sleep(0)
        recorder(scope=scope, receive=receive, send=send)

    return downstream_app, recorder
def create_test_scope(
    path: str = "/v1/chat/completions", method: str = "POST", headers: dict = None, scope_type: str = "http"
) -> Scope:
    """Build a minimal ASGI ``scope`` dict for driving the middleware in tests.

    Args:
        path: Request path stored under ``"path"``.
        method: HTTP method stored under ``"method"``.
        headers: Optional mapping of header name to value. Both are converted
            with ``str()`` and encoded as latin-1. Names are lower-cased,
            matching the form ASGI servers deliver. Values are kept verbatim
            because header values such as bearer tokens are case-sensitive.
        scope_type: Value stored under the scope's ``"type"`` key.

    Returns:
        A dict with ``type``, ``method``, ``headers`` (a list of
        ``(name_bytes, value_bytes)`` tuples), ``path`` and an empty
        ``root_path``.
    """
    # Previously the value was lower-cased too, which silently rewrote
    # tokens like "Bearer AbC" into "bearer abc".
    scope_headers = [
        (str(name).lower().encode("latin-1"), str(value).encode("latin-1"))
        for name, value in (headers or {}).items()
    ]
    return {
        "type": scope_type,
        "method": method,
        "headers": scope_headers,
        "path": path,
        "root_path": "",
    }
class TestAuthenticationMiddleware:
    """Tests for ``AuthenticationMiddleware``.

    Covers ``verify_token`` directly (Authorization header parsing and
    SHA-256 digest comparison). Also covers the ASGI ``__call__`` path:
    which requests are forwarded to the wrapped app and which are answered
    with an "Unauthorized" response instead.
    """

    # Tokens the middleware under test is configured to accept.
    VALID_TOKENS = ["test_key_123", "another_valid_key_456"]
    # A token absent from VALID_TOKENS; it must be rejected.
    INVALID_TOKEN = "wrong_key_789"
    # SHA-256 digests of VALID_TOKENS, in the same order. The middleware is
    # expected to keep these in ``api_tokens`` rather than the raw tokens.
    EXPECTED_TOKEN_HASHES = [hashlib.sha256(t.encode("utf-8")).digest() for t in VALID_TOKENS]

    @pytest.fixture(autouse=True)
    def setup_middleware(self):
        """Give every test a fresh recording downstream app and a middleware wrapping it."""
        self.mock_app, self.mock_send = mock_asgi_app()
        self.middleware = AuthenticationMiddleware(app=self.mock_app, tokens=self.VALID_TOKENS)

    @staticmethod
    def _assert_rejected(send):
        """Assert a response was sent via ``send`` and its last message body mentions "Unauthorized"."""
        # Check this first: ``send.call_args`` is None when nothing was sent,
        # which would otherwise surface as an unhelpful TypeError.
        assert send.called, "middleware did not send any response"
        last_message = send.call_args.args[0]
        assert isinstance(last_message, dict)
        assert "Unauthorized" in last_message["body"].decode("utf-8")

    def test_verify_token_no_authorization_header(self):
        """Headers without Authorization are rejected."""
        headers = Headers()
        assert self.middleware.verify_token(headers) is False

    def test_verify_token_invalid_scheme(self):
        """A non-Bearer scheme (here Basic) is rejected."""
        headers = Headers({"Authorization": "Basic wrong_scheme"})
        assert self.middleware.verify_token(headers) is False

    def test_verify_token_valid_token(self):
        """Every configured token is accepted with the Bearer scheme."""
        for valid_token in self.VALID_TOKENS:
            headers = Headers({"Authorization": f"Bearer {valid_token}"})
            assert self.middleware.verify_token(headers) is True

    def test_verify_token_invalid_token(self):
        """A Bearer token that is not configured is rejected."""
        headers = Headers({"Authorization": f"Bearer {self.INVALID_TOKEN}"})
        assert self.middleware.verify_token(headers) is False

    def test_verify_token_hash_comparison(self):
        """Stored hashes match; ``secrets.compare_digest`` is called once per stored hash."""
        valid_token = self.VALID_TOKENS[0]
        param_hash = hashlib.sha256(valid_token.encode("utf-8")).digest()
        assert self.middleware.api_tokens == self.EXPECTED_TOKEN_HASHES
        with pytest.MonkeyPatch.context() as mp:
            mock_compare = Mock(return_value=True)
            mp.setattr(secrets, "compare_digest", mock_compare)
            headers = Headers({"Authorization": f"Bearer {valid_token}"})
            self.middleware.verify_token(headers)
        # The mock reports a match on the first call, yet every stored hash
        # must still be compared. Presumably this avoids an early exit whose
        # timing could reveal which token matched.
        assert mock_compare.call_count == len(self.EXPECTED_TOKEN_HASHES)
        mock_compare.assert_any_call(param_hash, self.EXPECTED_TOKEN_HASHES[0])

    @pytest.mark.asyncio
    async def test_call_skip_non_v1_path(self):
        """Requests to paths outside /v1 are forwarded to the app without any token."""
        for path in ["/health", "/metrics", "/docs"]:
            scope = create_test_scope(path=path)
            receive = AsyncMock()
            send = AsyncMock()
            await self.middleware(scope, receive, send)
            self.mock_send.assert_called_once_with(scope=scope, receive=receive, send=send)
            self.mock_send.reset_mock()

    @pytest.mark.asyncio
    async def test_call_skip_options_method(self):
        """OPTIONS requests to a /v1 path are forwarded without any token."""
        scope = create_test_scope(method="OPTIONS", path="/v1/chat/completions")
        receive = AsyncMock()
        send = AsyncMock()
        await self.middleware(scope, receive, send)
        self.mock_send.assert_called_once_with(scope=scope, receive=receive, send=send)

    @pytest.mark.asyncio
    async def test_call_skip_non_http_websocket_scope(self):
        """Scopes whose type is neither http nor websocket are forwarded without any token."""
        for scope_type in ["lifespan", "startup", "shutdown"]:
            scope = create_test_scope(scope_type=scope_type)
            receive = AsyncMock()
            send = AsyncMock()
            await self.middleware(scope, receive, send)
            self.mock_send.assert_called_once_with(scope=scope, receive=receive, send=send)
            self.mock_send.reset_mock()

    @pytest.mark.asyncio
    async def test_call_v1_path_valid_token(self):
        """A /v1 request carrying a configured Bearer token is forwarded to the app."""
        scope = create_test_scope(headers={"Authorization": f"Bearer {self.VALID_TOKENS[0]}"})
        receive = AsyncMock()
        send = AsyncMock()
        # Sanity check: the header built into the scope verifies on its own.
        headers = Headers(scope=scope)
        assert self.middleware.verify_token(headers) is True
        await self.middleware(scope, receive, send)
        self.mock_send.assert_called_once_with(scope=scope, receive=receive, send=send)

    @pytest.mark.asyncio
    async def test_call_v1_path_invalid_token(self):
        """A /v1 request with an unknown token is answered directly and never reaches the app."""
        scope = create_test_scope(headers={"Authorization": f"Bearer {self.INVALID_TOKEN}"})
        receive = AsyncMock()
        send = AsyncMock()
        await self.middleware(scope, receive, send)
        self.mock_send.assert_not_called()
        self._assert_rejected(send)

    @pytest.mark.asyncio
    async def test_call_v1_path_no_token(self):
        """A /v1 request without an Authorization header is answered directly and never reaches the app."""
        scope = create_test_scope(headers={})
        receive = AsyncMock()
        send = AsyncMock()
        await self.middleware(scope, receive, send)
        self.mock_send.assert_not_called()
        self._assert_rejected(send)