# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Adapted from
# https://github.com/sgl-project/sglang/blob/main/python/sglang/test/attention/test_flashattn_backend.py

import unittest

import paddle

from fastdeploy.model_executor.layers.attention import (
    Attention, PaddleNativeAttnBackend)
from fastdeploy.worker.forward_meta import (ForwardMeta, ForwardMode,
                                            MHATokenToKVPool)


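# MockModelRunner is a lightweight stand-in for the real model runner: it only
# carries the attributes the attention backend reads in this test
# (model_config, sliding_window_size, page_size, dtype, device,
# req_to_token_pool, token_to_kv_pool).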
class MockModelRunner:

    def __init__(
        self,
        page_size=1,
        num_heads=2,
        head_dim=8,
    ):
        self.device = "cuda"
        self.dtype = paddle.float16
        # Max batch size for the test.
        max_batch_size = 160
        # Total tokens (prefix + extend + decode) in the test should not exceed this length.
        max_context_len = 2048
        self.model_config = type(
            "ModelConfig",
            (),
            {
                "context_len": max_context_len,
            },
        )
        self.sliding_window_size = None
        self.device = self.device
        # Create a large enough req_to_token_pool to fit the test usage.
        self.req_to_token_pool = type(
            "TokenPool",
            (),
            {
                # A typical max_bs * max_context_len for cuda graph decode
                "size": max_batch_size,
                # Add req_to_token attribute
                "req_to_token": paddle.zeros([max_batch_size, max_context_len],
                                             dtype=paddle.int32),
            },
        )
        self.page_size = page_size
        max_total_num_tokens = max_batch_size * max_context_len
        self.token_to_kv_pool = MHATokenToKVPool(
            size=max_total_num_tokens,
            page_size=page_size,
            dtype=self.dtype,
            head_num=num_heads,
            head_dim=head_dim,
            layer_num=1,  # only consider layer=1 for unit test
            device=self.device)


class TestNativePaddleAttentionBackend(unittest.TestCase):

    def setUp(self):
        # Test parameters
        self.batch_size = 2
        self.seq_len = 256
        self.num_heads = 2
        self.head_dim = 128
        self.device = "gpu"
        self.dtype = paddle.float16

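    # Note: `backend` and `ref_backend` below are two independent
    # PaddleNativeAttnBackend instances; the tests run forward_extend /
    # forward_decode on one and cross-check the result against the other
    # in _verify_output.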
    def _init_model_runner(self, page_size=1):
        self.model_runner = MockModelRunner(
            page_size=page_size,
            num_heads=self.num_heads,
            head_dim=self.head_dim,
        )
        self.backend = PaddleNativeAttnBackend(self.model_runner)
        self.ref_backend = PaddleNativeAttnBackend(self.model_runner)
        self.model_runner.model_config.num_attention_heads = self.num_heads

    def _mock_write_to_req_to_token_pool(self, batch_size, seq_len, page_size):
        # if page_size > 1, the token pool stores the index to the page,
        # so we need to multiply the index by page_size.
        self.req_to_token = (
            paddle.arange(0, batch_size, dtype=paddle.int32)[:, None] * seq_len
            + paddle.arange(0, seq_len, dtype=paddle.int32)[None, :]
            + page_size)
        self.model_runner.req_to_token_pool.req_to_token[:batch_size, :seq_len] = (
            self.req_to_token)

    def _create_attention_layer(self):
        """Create attention layer for testing."""
        return Attention(
            num_heads=self.num_heads,
            head_dim=self.head_dim,
            num_kv_heads=self.num_heads,
            layer_id=0,
        )

    def _create_qkv_tensors(self, tokens_len):
        """Create q, k, v tensors for testing."""
        shape = (tokens_len, self.num_heads, self.head_dim)
        return (
            paddle.randn(shape, dtype=self.dtype),
            paddle.randn(shape, dtype=self.dtype),
            paddle.randn(shape, dtype=self.dtype),
        )

    def _run_reference_forward(self, mode, q, k, v, layer, forward_batch,
                               expected_shape):
        """Run reference forward pass using the native backend."""
        if mode == ForwardMode.EXTEND:
            output = self.ref_backend.forward_extend(q, k, v, layer,
                                                     forward_batch)
        else:  # ForwardMode.DECODE
            output = self.ref_backend.forward_decode(q, k, v, layer,
                                                     forward_batch)
        return output.view(expected_shape)

    def _verify_output(self, output, expected_shape, output_ref=None):
        """Verify output tensor shape, dtype, and values."""
        self.assertEqual(
            output.shape,
            expected_shape,
            f"Expected shape {expected_shape}, got {output.shape}",
        )
        self.assertEqual(output.dtype, self.dtype)
        self.assertEqual(
            paddle.isnan(output).sum().item(), 0, "Output contains NaN values")

        if output_ref is not None:
            if not paddle.allclose(output, output_ref, atol=1e-1, rtol=0.0):
                # Check where the values differ beyond the given tolerances
                diff_mask = ~paddle.isclose(
                    output, output_ref, atol=1e-1, rtol=0.0)

                # Find the first index where the difference occurs
                if diff_mask.any():
                    first_mismatch_idx = diff_mask.nonzero()[0]
                    print("First mismatch at index:",
                          tuple(first_mismatch_idx.tolist()))
                    print("output:",
                          output[tuple(first_mismatch_idx.tolist())])
                    print("output_ref:",
                          output_ref[tuple(first_mismatch_idx.tolist())])
                raise AssertionError(
                    "Attention output is not close to the reference native backend output"
                )

    def _create_forward_batch(self,
                              mode,
                              q_len=None,
                              prefix_len=0,
                              page_size=1):
        """Create a forward batch for testing based on mode and lengths."""
        self._init_model_runner(page_size=page_size)

        # Default to self.seq_len if not specified
        q_len = q_len or self.seq_len

        if mode == ForwardMode.EXTEND:
            total_len = prefix_len + q_len
            out_cache_start = prefix_len * self.batch_size
            out_cache_end = total_len * self.batch_size

            forward_batch = ForwardMeta(
                batch_size=self.batch_size,
                input_ids=paddle.randint(0, 100, (self.batch_size, q_len)),
                out_cache_loc=paddle.arange(out_cache_start, out_cache_end),
                seq_lens_sum=self.batch_size * total_len,  # need to be real
                forward_mode=mode,
                req_pool_indices=paddle.arange(self.batch_size),
                seq_lens=paddle.to_tensor([total_len] * self.batch_size),
                extend_prefix_lens=paddle.to_tensor([prefix_len] *
                                                    self.batch_size),
                extend_seq_lens=paddle.to_tensor([q_len] * self.batch_size),
                seq_lens_cpu=paddle.to_tensor([total_len] * self.batch_size,
                                              place="cpu"),
                extend_prefix_lens_cpu=paddle.to_tensor([prefix_len] *
                                                        self.batch_size,
                                                        place="cpu"),
                extend_seq_lens_cpu=paddle.to_tensor([q_len] * self.batch_size,
                                                     place="cpu"),
                attn_backend=self.backend,
            )
        else:  # ForwardMode.DECODE
            decode_len = q_len  # Assuming 1 for decode testing
            total_len = self.seq_len + decode_len
            if mode == ForwardMode.DECODE and page_size > 1:
                # Get next page_size multiple of self.seq_len
                out_cache_start = (self.batch_size * self.seq_len // page_size
                                   + 1) * page_size
                # out_cache_end is the start of the next block
                out_cache_end = out_cache_start + decode_len * page_size
            else:
                out_cache_start = self.batch_size * self.seq_len
                out_cache_end = self.batch_size * total_len

            forward_batch = ForwardMeta(
                batch_size=self.batch_size,
                input_ids=paddle.randint(0, 100,
                                         (self.batch_size, decode_len)),
                out_cache_loc=paddle.to_tensor(
                    [out_cache_start, out_cache_end]),
                seq_lens_sum=self.batch_size * total_len,
                forward_mode=mode,
                req_pool_indices=paddle.arange(self.batch_size),
                seq_lens=paddle.to_tensor([total_len] * self.batch_size),
                seq_lens_cpu=paddle.to_tensor([total_len] * self.batch_size,
                                              place="cpu"),
                attn_backend=self.backend,
            )

        # Add token pool
        forward_batch.req_to_token_pool = self.model_runner.req_to_token_pool

        # Write current batch's req_to_token to req_to_token_pool
        self._mock_write_to_req_to_token_pool(self.batch_size, total_len,
                                              page_size)
        # Add kv pool for this forward batch
        forward_batch.token_to_kv_pool = self.model_runner.token_to_kv_pool

        return forward_batch

    def _setup_kv_cache(self, forward_batch, layer, cache_len):
        # Create constant values for the prefix cache for easy debugging
        cache_k = paddle.ones(
            [self.batch_size * cache_len, self.num_heads, self.head_dim],
            dtype=self.dtype,
        )
        cache_v = (paddle.ones(
            [self.batch_size * cache_len, self.num_heads, self.head_dim],
            dtype=self.dtype,
        ) * 2)

        # Set the prefix KV cache
        forward_batch.token_to_kv_pool.set_kv_buffer(
            layer,
            paddle.arange(self.batch_size * cache_len),
            cache_k,
            cache_v,
            layer.k_scale,
            layer.v_scale,
        )

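    # _run_attention_test drives one end-to-end case: build the attention
    # layer and ForwardMeta, fill the prefix KV cache when needed, initialize
    # the backend metadata, run the backend under test, and compare its output
    # against the reference forward pass.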
    def _run_attention_test(self, mode, q_len, prefix_len=0, page_size=1):
        """
        Run an attention test with the specified parameters.

        Args:
            mode: ForwardMode.EXTEND or ForwardMode.DECODE
            q_len: Length of the query sequence. For decode mode, q_len is 1.
            prefix_len: Length of the prefix sequence for extend mode
            page_size: Page size for the KV cache
        """
        layer = self._create_attention_layer()

        # Create forward batch and set up
        forward_batch = self._create_forward_batch(mode, q_len, prefix_len,
                                                   page_size)

        # Create QKV tensors for the input
        q, k, v = self._create_qkv_tensors(self.batch_size * q_len)

        # KV cache for prefixed extend is prefix_len
        # KV cache for decode is same as seq_len
        # No KV cache for extend without prefix
        if mode == ForwardMode.EXTEND:
            if prefix_len > 0:
                self._setup_kv_cache(forward_batch, layer, prefix_len)
        else:
            self._setup_kv_cache(forward_batch, layer, self.seq_len)

        self.backend.init_attention_metadata(forward_batch)

        if mode == ForwardMode.EXTEND:
            expected_shape = [
                self.batch_size * q_len,
                self.num_heads,
                self.head_dim,
            ]
            output = self.backend.forward_extend(q, k, v, layer, forward_batch)
        else:
            expected_shape = [self.batch_size, self.num_heads * self.head_dim]
            output = self.backend.forward_decode(q, k, v, layer, forward_batch)

        output_ref = self._run_reference_forward(mode, q, k, v, layer,
                                                 forward_batch, expected_shape)

        self._verify_output(output, expected_shape, output_ref)

        return output

    def test_forward_extend(self):
        """Test the standard extend operation."""
        self._run_attention_test(ForwardMode.EXTEND, q_len=self.seq_len)

    def test_forward_decode(self):
        """Test the decode operation with cached tokens."""
        self._run_attention_test(ForwardMode.DECODE, q_len=1)

    def test_forward_extend_with_prefix(self):
        """Test extending from cached prefix tokens."""
        prefix_len = self.seq_len // 2
        extend_len = self.seq_len - prefix_len
        self._run_attention_test(ForwardMode.EXTEND,
                                 q_len=extend_len,
                                 prefix_len=prefix_len)

    def test_forward_extend_with_page_size_greater_than_1(self):
        """Test extending from cached prefix tokens with page size greater than 1."""
        self._run_attention_test(ForwardMode.EXTEND,
                                 q_len=self.seq_len,
                                 page_size=64)

    def test_forward_decode_with_page_size_greater_than_1(self):
        """Test decode operation with page size greater than 1."""
        self._run_attention_test(ForwardMode.DECODE, q_len=1, page_size=64)


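# Running this file directly executes all cases via unittest.main(); a
# GPU-enabled PaddlePaddle build is assumed, since the mock runner and the
# test tensors use "cuda"/"gpu" devices with float16.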
if __name__ == "__main__":
    unittest.main()