mirror of https://github.com/PaddlePaddle/FastDeploy.git (synced 2025-10-06 00:57:33 +08:00)
Sync v2.0 version of code to github repo
@@ -19,13 +19,14 @@ import unittest
 import paddle
 
-from fastdeploy.model_executor.layers.attention import PaddleNativeAttnBackend, Attention
-from fastdeploy.model_executor.model_runner import ReqToTokenPool, KVCache, MHATokenToKVPool
-from fastdeploy.model_executor.model_runner.model_runner_minimal_os import MinimalModelRunner
-from fastdeploy.model_executor.model_runner import ForwardMeta, ForwardMode
+from fastdeploy.model_executor.layers.attention import (
+    Attention, PaddleNativeAttnBackend)
+from fastdeploy.worker.forward_meta import (ForwardMeta, ForwardMode,
+                                            MHATokenToKVPool)
 
 
 class MockModelRunner:
 
     def __init__(
         self,
         page_size=1,
@@ -53,12 +54,12 @@ class MockModelRunner:
             (),
             {
                 # A typical max_bs * max_context_len for cuda graph decode
-                "size": max_batch_size,
+                "size":
+                max_batch_size,
                 # Add req_to_token attribute
-                "req_to_token": paddle.zeros(
-                    [max_batch_size, max_context_len],
-                    dtype=paddle.int32
-                ),
+                "req_to_token":
+                paddle.zeros([max_batch_size, max_context_len],
+                             dtype=paddle.int32),
             },
         )
         self.page_size = page_size
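Note: the bare (), and { ... } arguments in the hunk above are the bases tuple and attribute dict of what appears to be a three-argument type() call that builds the mock pool object. A minimal standalone sketch of that pattern, with an illustrative stub name and sizes (not taken from the test):

import paddle

max_batch_size, max_context_len = 2, 16  # illustrative values only

# Build a bare-bones pool object carrying just the attributes the backend
# touches, mirroring the type(name, bases, dict) pattern above.
ReqToTokenPoolStub = type(
    "ReqToTokenPoolStub",
    (),
    {
        "size": max_batch_size,
        "req_to_token": paddle.zeros([max_batch_size, max_context_len],
                                     dtype=paddle.int32),
    },
)

pool = ReqToTokenPoolStub()
print(pool.size, pool.req_to_token.shape)  # 2 [2, 16]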
@@ -70,11 +71,11 @@ class MockModelRunner:
             head_num=num_heads,
             head_dim=head_dim,
             layer_num=1,  # only consider layer=1 for unit test
-            device=self.device
-        )
+            device=self.device)
 
 
 class TestNativePaddleAttentionBackend(unittest.TestCase):
 
     def setUp(self):
         # Test parameters
         self.batch_size = 2
@@ -98,14 +99,12 @@ class TestNativePaddleAttentionBackend(unittest.TestCase):
         # if page_size > 1, the token pool stores the index to the page.
         # so we need to multiply the index by page_size.
         self.req_to_token = (
-            paddle.arange(0, batch_size, dtype=paddle.int32)[:, None]
-            * seq_len
-            + paddle.arange(0, seq_len, dtype=paddle.int32)[None, :]
-            + page_size
-        )
-        self.model_runner.req_to_token_pool.req_to_token[:batch_size, :seq_len] = (
-            self.req_to_token
-        )
+            paddle.arange(0, batch_size, dtype=paddle.int32)[:, None] * seq_len
+            + paddle.arange(0, seq_len, dtype=paddle.int32)[None, :] +
+            page_size)
+        self.model_runner.req_to_token_pool.req_to_token[:batch_size, :
+                                                         seq_len] = (
+                                                             self.req_to_token)
 
     def _create_attention_layer(self):
         """Create attention layer for testing."""
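For intuition, a small standalone sketch of the req_to_token index map built in the hunk above, using illustrative values (batch_size=2, seq_len=4, page_size=1) rather than the test's own setUp parameters:

import paddle

batch_size, seq_len, page_size = 2, 4, 1  # illustrative values only

# Row i holds the token-pool slots assigned to request i, offset by page_size
# exactly as in the hunk above.
req_to_token = (
    paddle.arange(0, batch_size, dtype=paddle.int32)[:, None] * seq_len
    + paddle.arange(0, seq_len, dtype=paddle.int32)[None, :] + page_size)

print(req_to_token.numpy())
# [[1 2 3 4]
#  [5 6 7 8]]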
@@ -125,16 +124,15 @@ class TestNativePaddleAttentionBackend(unittest.TestCase):
             paddle.randn(shape, dtype=self.dtype),
         )
 
-    def _run_reference_forward(
-        self, mode, q, k, v, layer, forward_batch, expected_shape
-    ):
+    def _run_reference_forward(self, mode, q, k, v, layer, forward_batch,
+                               expected_shape):
         """Run reference forward pass using native backend."""
         if mode == ForwardMode.EXTEND:
-            output = self.ref_backend.forward_extend(
-                q, k, v, layer, forward_batch)
+            output = self.ref_backend.forward_extend(q, k, v, layer,
+                                                     forward_batch)
         else:  # ForwardMode.DECODE
-            output = self.ref_backend.forward_decode(
-                q, k, v, layer, forward_batch)
+            output = self.ref_backend.forward_decode(q, k, v, layer,
+                                                     forward_batch)
         return output.view(expected_shape)
 
     def _verify_output(self, output, expected_shape, output_ref=None):
@@ -146,8 +144,7 @@ class TestNativePaddleAttentionBackend(unittest.TestCase):
         )
         self.assertEqual(output.dtype, self.dtype)
         self.assertEqual(
-            paddle.isnan(output).sum().item(), 0, "Output contains NaN values"
-        )
+            paddle.isnan(output).sum().item(), 0, "Output contains NaN values")
 
         if output_ref is not None:
             if not paddle.allclose(output, output_ref, atol=1e-1, rtol=0.0):
@@ -158,19 +155,21 @@ class TestNativePaddleAttentionBackend(unittest.TestCase):
                 # Find the first index where the difference occurs
                 if diff_mask.any():
                     first_mismatch_idx = diff_mask.nonzero()[0]
-                    print(
-                        "First mismatch at index:", tuple(
-                            first_mismatch_idx.tolist())
-                    )
-                    print("output:", output[tuple(
-                        first_mismatch_idx.tolist())])
-                    print("output_ref:", output_ref[tuple(
-                        first_mismatch_idx.tolist())])
+                    print("First mismatch at index:",
+                          tuple(first_mismatch_idx.tolist()))
+                    print("output:",
+                          output[tuple(first_mismatch_idx.tolist())])
+                    print("output_ref:",
+                          output_ref[tuple(first_mismatch_idx.tolist())])
                 raise AssertionError(
                     "Attention output is not close to the torch native backend output"
                 )
 
-    def _create_forward_batch(self, mode, q_len=None, prefix_len=0, page_size=1):
+    def _create_forward_batch(self,
+                              mode,
+                              q_len=None,
+                              prefix_len=0,
+                              page_size=1):
         """Create a forward batch for testing based on mode and lengths."""
         self._init_model_runner(page_size=page_size)
 
@@ -184,32 +183,22 @@ class TestNativePaddleAttentionBackend(unittest.TestCase):
 
             forward_batch = ForwardMeta(
                 batch_size=self.batch_size,
-                input_ids=paddle.randint(
-                    0, 100, (self.batch_size, q_len)
-                ),
-                out_cache_loc=paddle.arange(
-                    out_cache_start, out_cache_end
-                ),
+                input_ids=paddle.randint(0, 100, (self.batch_size, q_len)),
+                out_cache_loc=paddle.arange(out_cache_start, out_cache_end),
                 seq_lens_sum=self.batch_size * total_len,  # need to be real
                 forward_mode=mode,
                 req_pool_indices=paddle.arange(self.batch_size),
-                seq_lens=paddle.to_tensor(
-                    [total_len] * self.batch_size
-                ),
-                extend_prefix_lens=paddle.to_tensor(
-                    [prefix_len] * self.batch_size
-                ),
-                extend_seq_lens=paddle.to_tensor(
-                    [q_len] * self.batch_size
-                ),
-                seq_lens_cpu=paddle.to_tensor(
-                    [total_len] * self.batch_size, place="cpu"),
-                extend_prefix_lens_cpu=paddle.to_tensor(
-                    [prefix_len] * self.batch_size, place="cpu"
-                ),
-                extend_seq_lens_cpu=paddle.to_tensor(
-                    [q_len] * self.batch_size, place="cpu"
-                ),
+                seq_lens=paddle.to_tensor([total_len] * self.batch_size),
+                extend_prefix_lens=paddle.to_tensor([prefix_len] *
+                                                    self.batch_size),
+                extend_seq_lens=paddle.to_tensor([q_len] * self.batch_size),
+                seq_lens_cpu=paddle.to_tensor([total_len] * self.batch_size,
+                                              place="cpu"),
+                extend_prefix_lens_cpu=paddle.to_tensor([prefix_len] *
+                                                        self.batch_size,
+                                                        place="cpu"),
+                extend_seq_lens_cpu=paddle.to_tensor([q_len] * self.batch_size,
+                                                     place="cpu"),
                 attn_backend=self.backend,
             )
         else:  # ForwardMode.DECODE
@@ -217,9 +206,8 @@ class TestNativePaddleAttentionBackend(unittest.TestCase):
             total_len = self.seq_len + decode_len
             if mode == ForwardMode.DECODE and page_size > 1:
                 # Get next page_size multiple of self.seq_len
-                out_cache_start = (
-                    self.batch_size * self.seq_len // page_size + 1
-                ) * page_size
+                out_cache_start = (self.batch_size * self.seq_len // page_size
+                                   + 1) * page_size
                 # out_cache_end is the start of the next block
                 out_cache_end = out_cache_start + decode_len * page_size
             else:
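The decode branch above rounds the KV-cache write position up to the next page boundary. A quick worked example with assumed values (not the test's own):

batch_size, seq_len, page_size, decode_len = 2, 64, 64, 1  # assumed values

# 2 * 64 = 128 cached tokens fill exactly two pages, so the next free
# page-aligned slot is (128 // 64 + 1) * 64 = 192.
out_cache_start = (batch_size * seq_len // page_size + 1) * page_size
out_cache_end = out_cache_start + decode_len * page_size

print(out_cache_start, out_cache_end)  # 192 256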
@@ -228,20 +216,16 @@ class TestNativePaddleAttentionBackend(unittest.TestCase):
 
             forward_batch = ForwardMeta(
                 batch_size=self.batch_size,
-                input_ids=paddle.randint(
-                    0, 100, (self.batch_size, decode_len)
-                ),
+                input_ids=paddle.randint(0, 100,
+                                         (self.batch_size, decode_len)),
                 out_cache_loc=paddle.to_tensor(
-                    [out_cache_start, out_cache_end]
-                ),
+                    [out_cache_start, out_cache_end]),
                 seq_lens_sum=self.batch_size * total_len,
                 forward_mode=mode,
                 req_pool_indices=paddle.arange(self.batch_size),
-                seq_lens=paddle.to_tensor(
-                    [total_len] * self.batch_size
-                ),
-                seq_lens_cpu=paddle.to_tensor(
-                    [total_len] * self.batch_size, place="cpu"),
+                seq_lens=paddle.to_tensor([total_len] * self.batch_size),
+                seq_lens_cpu=paddle.to_tensor([total_len] * self.batch_size,
+                                              place="cpu"),
                 attn_backend=self.backend,
             )
 
@@ -249,8 +233,8 @@ class TestNativePaddleAttentionBackend(unittest.TestCase):
         forward_batch.req_to_token_pool = self.model_runner.req_to_token_pool
 
         # Write current batch's req_to_token to req_to_token_pool
-        self._mock_write_to_req_to_token_pool(
-            self.batch_size, total_len, page_size)
+        self._mock_write_to_req_to_token_pool(self.batch_size, total_len,
+                                              page_size)
         # Add kv pool for this forward batch
         forward_batch.token_to_kv_pool = self.model_runner.token_to_kv_pool
 
@@ -259,20 +243,13 @@ class TestNativePaddleAttentionBackend(unittest.TestCase):
     def _setup_kv_cache(self, forward_batch, layer, cache_len):
         # Create constant values for the prefix cache for easy debugging
         cache_k = paddle.ones(
-            [self.batch_size * cache_len,
-             self.num_heads,
-             self.head_dim],
+            [self.batch_size * cache_len, self.num_heads, self.head_dim],
             dtype=self.dtype,
         )
-        cache_v = (
-            paddle.ones(
-                [self.batch_size * cache_len,
-                 self.num_heads,
-                 self.head_dim],
-                dtype=self.dtype,
-            )
-            * 2
-        )
+        cache_v = (paddle.ones(
+            [self.batch_size * cache_len, self.num_heads, self.head_dim],
+            dtype=self.dtype,
+        ) * 2)
 
         # Set the prefix KV cache
         forward_batch.token_to_kv_pool.set_kv_buffer(
@@ -296,8 +273,8 @@ class TestNativePaddleAttentionBackend(unittest.TestCase):
         layer = self._create_attention_layer()
 
         # Create forward batch and set up
-        forward_batch = self._create_forward_batch(
-            mode, q_len, prefix_len, page_size)
+        forward_batch = self._create_forward_batch(mode, q_len, prefix_len,
+                                                   page_size)
 
         # Create QKV tensors for the input
         q, k, v = self._create_qkv_tensors(self.batch_size * q_len)
@@ -316,16 +293,16 @@ class TestNativePaddleAttentionBackend(unittest.TestCase):
         if mode == ForwardMode.EXTEND:
             expected_shape = [
                 self.batch_size * q_len,
-                self.num_heads, self.head_dim,
+                self.num_heads,
+                self.head_dim,
             ]
             output = self.backend.forward_extend(q, k, v, layer, forward_batch)
         else:
             expected_shape = [self.batch_size, self.num_heads * self.head_dim]
             output = self.backend.forward_decode(q, k, v, layer, forward_batch)
 
-        output_ref = self._run_reference_forward(
-            mode, q, k, v, layer, forward_batch, expected_shape
-        )
+        output_ref = self._run_reference_forward(mode, q, k, v, layer,
+                                                 forward_batch, expected_shape)
 
         self._verify_output(output, expected_shape, output_ref)
@@ -343,14 +320,15 @@ class TestNativePaddleAttentionBackend(unittest.TestCase):
         """Test extending from cached prefix tokens."""
         prefix_len = self.seq_len // 2
        extend_len = self.seq_len - prefix_len
-        self._run_attention_test(
-            ForwardMode.EXTEND, q_len=extend_len, prefix_len=prefix_len
-        )
+        self._run_attention_test(ForwardMode.EXTEND,
+                                 q_len=extend_len,
+                                 prefix_len=prefix_len)
 
     def test_forward_extend_with_page_size_greater_than_1(self):
         """Test extending from cached prefix tokens with page size greater than 1."""
-        self._run_attention_test(
-            ForwardMode.EXTEND, q_len=self.seq_len, page_size=64)
+        self._run_attention_test(ForwardMode.EXTEND,
+                                 q_len=self.seq_len,
+                                 page_size=64)
 
     def test_forward_decode_with_page_size_greater_than_1(self):
         """Test decode operation with page size greater than 1."""