mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-12-24 13:28:13 +08:00
* add api key initial commit * add unit test * modify unit test * move middleware to a single file and add unit tests
56 lines
2.1 KiB
Python
56 lines
2.1 KiB
Python
import hashlib
|
|
import secrets
|
|
from collections.abc import Awaitable
|
|
|
|
from fastapi.responses import JSONResponse
|
|
from starlette.datastructures import URL, Headers
|
|
from starlette.types import ASGIApp, Receive, Scope, Send
|
|
|
|
|
|
class AuthenticationMiddleware:
    """
    Pure ASGI middleware that authenticates each request by checking
    that the Authorization Bearer token exists and equals any of the
    configured API keys.

    Notes
    -----
    There are two cases in which authentication is skipped:
        1. The HTTP method is OPTIONS.
        2. The request path doesn't start with /v1 (e.g. /health).
    """

    def __init__(self, app: ASGIApp, tokens: list[str]) -> None:
        """
        Wrap ``app`` and remember the accepted API keys.

        Parameters
        ----------
        app
            The downstream ASGI application that receives every request
            this middleware lets through.
        tokens
            Plain-text API keys; a request is accepted when its bearer
            token matches any one of them.
        """
        self.app = app
        # Keep only SHA-256 digests: the plain-text keys are not retained on
        # the instance, and all later comparisons are between byte strings of
        # identical length.
        self.api_tokens = [hashlib.sha256(t.encode("utf-8")).digest() for t in tokens]

    def verify_token(self, headers: Headers) -> bool:
        """
        Return True when ``headers`` carries an accepted bearer token.

        The ``Authorization`` header must be present and non-empty, its
        scheme must be ``bearer`` (compared case-insensitively), and the
        SHA-256 digest of the remainder must equal one of the stored
        digests. Anything else yields False.
        """
        authorization_header_value = headers.get("Authorization")
        if not authorization_header_value:
            return False

        # Split on the first space only: "<scheme> <credentials>".
        scheme, _, param = authorization_header_value.partition(" ")
        if scheme.lower() != "bearer":
            return False

        param_hash = hashlib.sha256(param.encode("utf-8")).digest()

        # Deliberately no early exit / any(): every stored digest is compared
        # on every call, so the work done does not depend on which key (if
        # any) matched.
        token_match = False
        for token_hash in self.api_tokens:
            token_match |= secrets.compare_digest(param_hash, token_hash)

        return token_match

    def __call__(self, scope: Scope, receive: Receive, send: Send) -> Awaitable[None]:
        """
        ASGI entry point.

        Forwards the call to the wrapped application unless the request
        targets a path starting with ``/v1`` (after removing the scope's
        ``root_path`` prefix) and fails ``verify_token``; in that case a
        401 JSON response ``{"error": "Unauthorized"}`` is returned
        instead.
        """
        # Scopes other than "http"/"websocket" (e.g. "lifespan") and OPTIONS
        # requests are passed through without authentication.
        # Use .get() for the method: WebSocket scopes carry no "method" key,
        # so indexing it directly would raise KeyError for every WebSocket
        # connection.
        if scope["type"] not in ("http", "websocket") or scope.get("method") == "OPTIONS":
            return self.app(scope, receive, send)

        root_path = scope.get("root_path", "")
        url_path = URL(scope=scope).path.removeprefix(root_path)
        headers = Headers(scope=scope)

        if url_path.startswith("/v1") and not self.verify_token(headers):
            response = JSONResponse(content={"error": "Unauthorized"}, status_code=401)
            return response(scope, receive, send)

        return self.app(scope, receive, send)
|