diff --git a/test/graph_optimization/test_cuda_graph.py b/test/graph_optimization/test_cuda_graph.py
deleted file mode 100644
index 597901357..000000000
--- a/test/graph_optimization/test_cuda_graph.py
+++ /dev/null
@@ -1,134 +0,0 @@
-"""
-# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-"""
-
-import paddle
-
-from fastdeploy.config import FDConfig, GraphOptimizationConfig
-from fastdeploy.model_executor.forward_meta import ForwardMeta
-from fastdeploy.model_executor.graph_optimization.decorator import (
-    support_graph_optimization,
-)
-
-
-@support_graph_optimization
-class TestCase1SubLayer1(paddle.nn.Layer):
-    """Sub layer 1 of test case 1"""
-
-    def __init__(self, fd_config: FDConfig, **kwargs):
-        super().__init__()
-
-    def forward(self, _, forward_meta: ForwardMeta):
-        """Sub layer1 forward pass"""
-
-        output = paddle.add(forward_meta.input_ids, forward_meta.input_ids)
-        print(" SubLayer1 Output: {output}")
-        return output
-
-
-class TestCase1SubLayer2(paddle.nn.Layer):
-    """ """
-
-    def __init__(self, fd_config: FDConfig, **kwargs):
-        super().__init__()
-
-    def forward(self, _, forward_meta: ForwardMeta):
-        """Sub layer2 forward pass"""
-        x = paddle.ones_like(forward_meta.input_ids)
-        y = paddle.ones_like(forward_meta.input_ids)
-        output = x + y
-        print(" SubLayer2 Output: {output}")
-        return output
-
-
-@support_graph_optimization
-class TestCase1SubLayer3(paddle.nn.Layer):
-    """ """
-
-    def __init__(self, fd_config: FDConfig, **kwargs):
-        super().__init__()
-
-    def forward(self, _, forward_meta: ForwardMeta):
-        """Sub layer3 forward pass"""
-        output = paddle.add(forward_meta.input_ids, forward_meta.input_ids)
-        print(" SubLayer3 Output: {output}")
-        return output
-
-
-class TestModel1(paddle.nn.Layer):
-    """Tast Model"""
-
-    def __init__(self, fd_config: FDConfig, **kwargs):
-        super().__init__()
-        self.fd_config = fd_config
-
-    def forward(self, _, forward_meta: ForwardMeta):
-        """Test model for ward pass"""
-        self.sublayer1 = TestCase1SubLayer1(self.fd_config)
-        self.sublayer2 = TestCase1SubLayer2(self.fd_config)
-        self.sublayer3 = TestCase1SubLayer3(self.fd_config)
-
-        # sublayer1 use cuda graph
-        sub_meta1 = forward_meta
-        sublayer1_output = self.sublayer1(_=None, forward_meta=sub_meta1)
-
-        # sublayer2 not use cuda garph
-        sub_meta2 = ForwardMeta(input_ids=sublayer1_output)
-        sublayer2_output = self.sublayer2(_=None, forward_meta=sub_meta2)
-
-        # sublayer3 use cuda graph
-        sub_meta3 = ForwardMeta(input_ids=sublayer2_output)
-        sublayer3_output = self.sublayer3(_=None, forward_meta=sub_meta3)
-
-        return sublayer3_output
-
-
-@support_graph_optimization
-class TestModel2(paddle.nn.Layer):
-    """Tast Model"""
-
-    def __init__(self, fd_config: FDConfig, **kwargs):
-        super().__init__()
-
-    def forward(self, _, forward_meta: ForwardMeta):
-        """Test model for ward pass"""
-        return forward_meta.input_ids + forward_meta.input_ids
-
-
-def run_test_case():
-    """Run test case"""
-    # Set llm config1
-    graph_opt_config = GraphOptimizationConfig()
-    graph_opt_config.use_cudagraph = True
-    graph_opt_config.cudagraph_capture_sizes = [1]
-    fd_config = FDConfig(graph_opt_config=graph_opt_config)
-
-    # Run Test Case1
-    test_model1 = TestModel1(fd_config=fd_config)
-    input_tensor1 = paddle.zeros([1, 8])
-    forward_meta1 = ForwardMeta(input_ids=input_tensor1)
-    output1 = test_model1(_=None, forward_meta=forward_meta1)
-    print(output1)
-
-    # Run Test Case2
-    test_model2 = TestModel2(fd_config=fd_config)
-    input_tensor2 = paddle.zeros([1, 8])
-    forward_meta2 = ForwardMeta(input_ids=input_tensor2)
-    output2 = test_model2(_=None, forward_meta=forward_meta2)
-    print(output2)
-
-
-if __name__ == "__main__":
-    run_test_case()
diff --git a/test/graph_optimization/test_cuda_graph_dynamic_subgraph.py b/test/graph_optimization/test_cuda_graph_dynamic_subgraph.py
new file mode 100644
index 000000000..9e28240bf
--- /dev/null
+++ b/test/graph_optimization/test_cuda_graph_dynamic_subgraph.py
@@ -0,0 +1,175 @@
+"""
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+
+import paddle
+
+from fastdeploy.config import FDConfig, GraphOptimizationConfig, ParallelConfig
+from fastdeploy.model_executor.forward_meta import ForwardMeta
+from fastdeploy.model_executor.graph_optimization.decorator import (
+    support_graph_optimization,
+)
+
+
+@support_graph_optimization
+class TestCase1SubLayer1(paddle.nn.Layer):
+    """Sub layer 1 of test case 1"""
+
+    def __init__(self, fd_config: FDConfig, **kwargs):
+        super().__init__()
+
+    def forward(self, ids_remove_padding, forward_meta: ForwardMeta):
+        """Sub layer1 forward pass"""
+
+        output = paddle.add(forward_meta.input_ids, forward_meta.input_ids)
+        return output
+
+    def forward_correct(self, ids_remove_padding, forward_meta: ForwardMeta):
+        """Sub layer1 Correct forward pass"""
+        output = paddle.add(forward_meta.input_ids, forward_meta.input_ids)
+        return output
+
+
+class TestCase1SubLayer2(paddle.nn.Layer):
+    """Sub layer 2 of test case 1"""
+
+    def __init__(self, fd_config: FDConfig, **kwargs):
+        super().__init__()
+
+    def forward(self, ids_remove_padding, forward_meta: ForwardMeta):
+        """Sub layer2 forward pass"""
+        x = forward_meta.input_ids
+        y = forward_meta.input_ids
+        output = x + y
+        return output
+
+    def forward_correct(self, ids_remove_padding, forward_meta: ForwardMeta):
+        """Sub layer2 Correct forward pass"""
+        x = forward_meta.input_ids
+        y = forward_meta.input_ids
+        output = x + y
+        return output
+
+
+@support_graph_optimization
+class TestCase1SubLayer3(paddle.nn.Layer):
+    """Sub layer 3 of test case 1"""
+
+    def __init__(self, fd_config: FDConfig, **kwargs):
+        super().__init__()
+
+    def forward(self, ids_remove_padding, forward_meta: ForwardMeta):
+        """Sub layer3 forward pass"""
+        output = paddle.matmul(forward_meta.input_ids, forward_meta.input_ids)
+        return output
+
+    def forward_correct(self, ids_remove_padding, forward_meta: ForwardMeta):
+        """Sub layer3 Correct forward pass"""
+        output = paddle.matmul(forward_meta.input_ids, forward_meta.input_ids)
+        return output
+
+
+class TestModel1(paddle.nn.Layer):
+    """Test Model"""
+
+    def __init__(self, fd_config: FDConfig, **kwargs):
+        super().__init__()
+        self.fd_config = fd_config
+
+        self.sublayer1 = TestCase1SubLayer1(self.fd_config)
+        self.sublayer2 = TestCase1SubLayer2(self.fd_config)
+        self.sublayer3 = TestCase1SubLayer3(self.fd_config)
+
+        self.sublayer2_output_buffer = paddle.zeros([1])
+
+    def forward(self, ids_remove_padding, forward_meta: ForwardMeta):
+        """Test model forward pass"""
+        # sublayer1 uses cuda graph
+        sub_meta1 = forward_meta
+        sublayer1_output = self.sublayer1(ids_remove_padding=ids_remove_padding, forward_meta=sub_meta1)
+
+        # sublayer2 does not use cuda graph
+        sub_meta2 = ForwardMeta(
+            input_ids=sublayer1_output, ids_remove_padding=sublayer1_output, step_use_cudagraph=False
+        )
+        sublayer2_output = self.sublayer2(ids_remove_padding=sublayer1_output, forward_meta=sub_meta2)
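+        # CUDA graph replay reads from fixed device addresses, so the dynamic
+        # sublayer's output is staged into the pre-allocated buffer that the
+        # captured sublayer3 graph consumes.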
+        self.sublayer2_output_buffer.copy_(sublayer2_output, False)
+
+        # sublayer3 uses cuda graph
+        sub_meta3 = ForwardMeta(
+            input_ids=self.sublayer2_output_buffer,
+            ids_remove_padding=self.sublayer2_output_buffer,
+            step_use_cudagraph=True,
+        )
+        sublayer3_output = self.sublayer3(ids_remove_padding=self.sublayer2_output_buffer, forward_meta=sub_meta3)
+
+        return sublayer3_output
+
+    def forward_correct(self, ids_remove_padding, forward_meta: ForwardMeta):
+        """Test model correct forward pass"""
+
+        # sublayer1 does not use cuda graph
+        sub_meta1 = forward_meta
+        sublayer1_output = self.sublayer1.forward_correct(
+            ids_remove_padding=ids_remove_padding, forward_meta=sub_meta1
+        )
+
+        # sublayer2 does not use cuda graph
+        sub_meta2 = ForwardMeta(input_ids=sublayer1_output, ids_remove_padding=sublayer1_output)
+        sublayer2_output = self.sublayer2.forward_correct(ids_remove_padding=sublayer1_output, forward_meta=sub_meta2)
+
+        # sublayer3 does not use cuda graph
+        sub_meta3 = ForwardMeta(input_ids=sublayer2_output, ids_remove_padding=sublayer2_output)
+        sublayer3_output = self.sublayer3.forward_correct(ids_remove_padding=sublayer2_output, forward_meta=sub_meta3)
+
+        return sublayer3_output
+
+
+def run_test_case():
+    """Run test case"""
+    # Set FastDeploy config
+    graph_opt_config = GraphOptimizationConfig(args={})
+    graph_opt_config.use_cudagraph = True
+    parallel_config = ParallelConfig(args={})
+    parallel_config.max_num_seqs = 1
+    fd_config = FDConfig(graph_opt_config=graph_opt_config, parallel_config=parallel_config)
+
+    # Run Test Case1
+    test_model1 = TestModel1(fd_config=fd_config)
+    input_tensor1 = paddle.ones([1])
+    forward_meta1 = ForwardMeta(input_ids=input_tensor1, ids_remove_padding=input_tensor1, step_use_cudagraph=True)
+
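+    # With use_cudagraph enabled, the decorated sublayers capture a CUDA
+    # graph on the first call and replay it on subsequent calls.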
+    # Trigger capture
+    _ = test_model1(ids_remove_padding=input_tensor1, forward_meta=forward_meta1)
+
+    # Replay
+    _ = test_model1(ids_remove_padding=input_tensor1, forward_meta=forward_meta1)
+    output1 = test_model1(ids_remove_padding=input_tensor1, forward_meta=forward_meta1)
+
+    # Correct output
+    output1_correct = test_model1.forward_correct(ids_remove_padding=input_tensor1, forward_meta=forward_meta1)
+
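+    # The outputs are single-element tensors, so `==` yields a tensor that
+    # assert can evaluate as a bool directly.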
+    assert output1 == output1_correct
+
+
+if __name__ == "__main__":
+    run_test_case()
diff --git a/test/graph_optimization/test_cuda_graph_spec_decode.py b/test/graph_optimization/test_cuda_graph_spec_decode.py
new file mode 100644
index 000000000..8e8fcf488
--- /dev/null
+++ b/test/graph_optimization/test_cuda_graph_spec_decode.py
@@ -0,0 +1,117 @@
+"""
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+
+import paddle
+
+from fastdeploy.config import FDConfig, GraphOptimizationConfig, ParallelConfig
+from fastdeploy.model_executor.forward_meta import ForwardMeta
+from fastdeploy.model_executor.graph_optimization.decorator import (
+    support_graph_optimization,
+)
+
+
+@support_graph_optimization
+class TestCase1SubLayer1(paddle.nn.Layer):
+    """Sub layer 1 of test case 1"""
+
+    def __init__(self, fd_config: FDConfig, **kwargs):
+        super().__init__()
+
+    def forward(self, ids_remove_padding, forward_meta: ForwardMeta):
+        """Sub layer1 forward pass"""
+
+        output = paddle.add(forward_meta.input_ids, forward_meta.input_ids)
+        return output
+
+    def forward_correct(self, ids_remove_padding, forward_meta: ForwardMeta):
+        """Sub layer1 Correct forward pass"""
+
+        output = paddle.add(forward_meta.input_ids, forward_meta.input_ids)
+        return output
+
+
+class TestModel1(paddle.nn.Layer):
+    """Test Model"""
+
+    def __init__(self, fd_config: FDConfig, **kwargs):
+        super().__init__()
+        self.fd_config = fd_config
+
+        self.sublayer1 = TestCase1SubLayer1(self.fd_config)
+        sublayer1_copy = TestCase1SubLayer1(self.fd_config)
+        self.sublayer2 = sublayer1_copy
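+        # Two instances of the same decorated class each capture and replay
+        # their own CUDA graph, mimicking the multiple captured graphs used
+        # in speculative decoding.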
+
+    def forward(self, ids_remove_padding, forward_meta: ForwardMeta):
+        """Test model forward pass"""
+        # sublayer1 uses cuda graph
+        sub_meta1 = forward_meta
+        sublayer1_output = self.sublayer1(ids_remove_padding=ids_remove_padding, forward_meta=sub_meta1)
+
+        # sublayer2 uses cuda graph
+        sub_meta2 = ForwardMeta(
+            input_ids=sublayer1_output, ids_remove_padding=sublayer1_output, step_use_cudagraph=True
+        )
+        sublayer2_output = self.sublayer2(ids_remove_padding=sublayer1_output, forward_meta=sub_meta2)
+
+        return sublayer2_output
+
+    def forward_correct(self, ids_remove_padding, forward_meta: ForwardMeta):
+        """Test model Correct forward pass"""
+        # sublayer1 does not use cuda graph
+        sub_meta1 = forward_meta
+        sublayer1_output = self.sublayer1.forward_correct(
+            ids_remove_padding=ids_remove_padding, forward_meta=sub_meta1
+        )
+
+        # sublayer2 does not use cuda graph
+        sub_meta2 = ForwardMeta(input_ids=sublayer1_output, ids_remove_padding=sublayer1_output)
+        sublayer2_output = self.sublayer2.forward_correct(ids_remove_padding=sublayer1_output, forward_meta=sub_meta2)
+
+        return sublayer2_output
+
+
+def run_test_case():
+    """Run test case"""
+    # Set FastDeploy config
+    graph_opt_config = GraphOptimizationConfig(args={})
+    graph_opt_config.use_cudagraph = True
+    parallel_config = ParallelConfig(args={})
+    parallel_config.max_num_seqs = 1
+    fd_config = FDConfig(graph_opt_config=graph_opt_config, parallel_config=parallel_config)
+
+    # Run Test Case1
+    test_model1 = TestModel1(fd_config=fd_config)
+    input_tensor1 = paddle.ones([1])
+    forward_meta1 = ForwardMeta(input_ids=input_tensor1, ids_remove_padding=input_tensor1, step_use_cudagraph=True)
+
+    # Trigger capture
+    _ = test_model1(ids_remove_padding=input_tensor1, forward_meta=forward_meta1)
+
+    # Replay
+    _ = test_model1(ids_remove_padding=input_tensor1, forward_meta=forward_meta1)
+    output1 = test_model1(ids_remove_padding=input_tensor1, forward_meta=forward_meta1)
+
+    # Correct output
+    output1_correct = test_model1.forward_correct(ids_remove_padding=input_tensor1, forward_meta=forward_meta1)
+
+    assert output1 == output1_correct
+
+
+if __name__ == "__main__":
+    run_test_case()
diff --git a/test/layers/test_attention.py b/test/layers/test_attention.py
deleted file mode 100644
index 5a9816454..000000000
--- a/test/layers/test_attention.py
+++ /dev/null
@@ -1,304 +0,0 @@
-# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-# Adapt from
-# https://github.com/sgl-project/sglang/blob/main/python/sglang/test/attention/test_flashattn_backend.py
-
-import unittest
-
-import paddle
-
-from fastdeploy.model_executor.forward_meta import ForwardMeta, ForwardMode
-from fastdeploy.model_executor.layers.attention import (
-    Attention,
-    PaddleNativeAttnBackend,
-)
-
-
-class MockModelRunner:
-    def __init__(
-        self,
-        page_size=1,
-        num_heads=2,
-        head_dim=8,
-    ):
-        self.device = "cuda"
-        self.dtype = paddle.float16
-        # Max batch size for the test.
-        max_batch_size = 160
-        # Total tokens(prefix + extend + decode) in the test should not exceed this length.
-        max_context_len = 2048
-        self.model_config = type(
-            "ModelConfig",
-            (),
-            {
-                "context_len": max_context_len,
-            },
-        )
-        self.sliding_window_size = None
-        self.device = self.device
-        # Create a large enough req_to_token_pool to fit the test usage.
-        self.req_to_token_pool = type(
-            "TokenPool",
-            (),
-            {
-                # A typical max_bs * max_context_len for cuda graph decode
-                "size": max_batch_size,
-                # Add req_to_token attribute
-                "req_to_token": paddle.zeros([max_batch_size, max_context_len], dtype=paddle.int32),
-            },
-        )
-        self.page_size = page_size
-
-
-class TestNativePaddleAttentionBackend(unittest.TestCase):
-    def setUp(self):
-        # Test parameters
-        self.batch_size = 2
-        self.seq_len = 256
-        self.num_heads = 2
-        self.head_dim = 128
-        self.device = "gpu"
-        self.dtype = paddle.float16
-
-    def _init_model_runner(self, page_size=1):
-        self.model_runner = MockModelRunner(
-            page_size=page_size,
-            num_heads=self.num_heads,
-            head_dim=self.head_dim,
-        )
-        self.backend = PaddleNativeAttnBackend(self.model_runner)
-        self.ref_backend = PaddleNativeAttnBackend(self.model_runner)
-        self.model_runner.model_config.num_attention_heads = self.num_heads
-
-    def _mock_write_to_req_to_token_pool(self, batch_size, seq_len, page_size):
-        # if page_size > 1, the token pool stores the index to the page.
-        # so we need to multiply the index by page_size.
-        self.req_to_token = (
-            paddle.arange(0, batch_size, dtype=paddle.int32)[:, None] * seq_len
-            + paddle.arange(0, seq_len, dtype=paddle.int32)[None, :]
-            + page_size
-        )
-        self.model_runner.req_to_token_pool.req_to_token[:batch_size, :seq_len] = self.req_to_token
-
-    def _create_attention_layer(self):
-        """Create attention layer for testing."""
-        return Attention(
-            num_heads=self.num_heads,
-            head_dim=self.head_dim,
-            num_kv_heads=self.num_heads,
-            layer_id=0,
-        )
-
-    def _create_qkv_tensors(self, tokens_len):
-        """Create q, k, v tensors for testing."""
-        shape = (tokens_len, self.num_heads, self.head_dim)
-        return (
-            paddle.randn(shape, dtype=self.dtype),
-            paddle.randn(shape, dtype=self.dtype),
-            paddle.randn(shape, dtype=self.dtype),
-        )
-
-    def _run_reference_forward(self, mode, q, k, v, layer, forward_batch, expected_shape):
-        """Run reference forward pass using native backend."""
-        if mode == ForwardMode.EXTEND:
-            output = self.ref_backend.forward_extend(q, k, v, layer, forward_batch)
-        else:  # ForwardMode.DECODE
-            output = self.ref_backend.forward_decode(q, k, v, layer, forward_batch)
-        return output.view(expected_shape)
-
-    def _verify_output(self, output, expected_shape, output_ref=None):
-        """Verify output tensor shape, dtype, and values."""
-        self.assertEqual(
-            output.shape,
-            expected_shape,
-            f"Expected shape {expected_shape}, got {output.shape}",
-        )
-        self.assertEqual(output.dtype, self.dtype)
-        self.assertEqual(paddle.isnan(output).sum().item(), 0, "Output contains NaN values")
-
-        if output_ref is not None:
-            if not paddle.allclose(output, output_ref, atol=1e-1, rtol=0.0):
-                # Check where the values differ beyond the given tolerances
-                diff_mask = ~paddle.isclose(output, output_ref, atol=1e-1, rtol=0.0)
-
-                # Find the first index where the difference occurs
-                if diff_mask.any():
-                    first_mismatch_idx = diff_mask.nonzero()[0]
-                    print(
-                        "First mismatch at index:",
-                        tuple(first_mismatch_idx.tolist()),
-                    )
-                    print("output:", output[tuple(first_mismatch_idx.tolist())])
-                    print(
-                        "output_ref:",
-                        output_ref[tuple(first_mismatch_idx.tolist())],
-                    )
-                raise AssertionError("Attention output is not close to the torch native backend output")
-
-    def _create_forward_batch(self, mode, q_len=None, prefix_len=0, page_size=1):
-        """Create a forward batch for testing based on mode and lengths."""
-        self._init_model_runner(page_size=page_size)
-
-        # Default to self.seq_len if not specified
-        q_len = q_len or self.seq_len
-
-        if mode == ForwardMode.EXTEND:
-            total_len = prefix_len + q_len
-            out_cache_start = prefix_len * self.batch_size
-            out_cache_end = total_len * self.batch_size
-
-            forward_batch = ForwardMeta(
-                batch_size=self.batch_size,
-                input_ids=paddle.randint(0, 100, (self.batch_size, q_len)),
-                out_cache_loc=paddle.arange(out_cache_start, out_cache_end),
-                seq_lens_sum=self.batch_size * total_len,  # need to be real
-                forward_mode=mode,
-                req_pool_indices=paddle.arange(self.batch_size),
-                seq_lens=paddle.to_tensor([total_len] * self.batch_size),
-                extend_prefix_lens=paddle.to_tensor([prefix_len] * self.batch_size),
-                extend_seq_lens=paddle.to_tensor([q_len] * self.batch_size),
-                seq_lens_cpu=paddle.to_tensor([total_len] * self.batch_size, place="cpu"),
-                extend_prefix_lens_cpu=paddle.to_tensor([prefix_len] * self.batch_size, place="cpu"),
-                extend_seq_lens_cpu=paddle.to_tensor([q_len] * self.batch_size, place="cpu"),
-                attn_backend=self.backend,
-            )
-        else:  # ForwardMode.DECODE
-            decode_len = q_len  # Assuming 1 for decode testing
-            total_len = self.seq_len + decode_len
-            if mode == ForwardMode.DECODE and page_size > 1:
-                # Get next page_size multiple of self.seq_len
-                out_cache_start = (self.batch_size * self.seq_len // page_size + 1) * page_size
-                # out_cache_end is the start of the next block
-                out_cache_end = out_cache_start + decode_len * page_size
-            else:
-                out_cache_start = self.batch_size * self.seq_len
-                out_cache_end = self.batch_size * total_len
-
-            forward_batch = ForwardMeta(
-                batch_size=self.batch_size,
-                input_ids=paddle.randint(0, 100, (self.batch_size, decode_len)),
-                out_cache_loc=paddle.to_tensor([out_cache_start, out_cache_end]),
-                seq_lens_sum=self.batch_size * total_len,
-                forward_mode=mode,
-                req_pool_indices=paddle.arange(self.batch_size),
-                seq_lens=paddle.to_tensor([total_len] * self.batch_size),
-                seq_lens_cpu=paddle.to_tensor([total_len] * self.batch_size, place="cpu"),
-                attn_backend=self.backend,
-            )
-
-        # Add token pool
-        forward_batch.req_to_token_pool = self.model_runner.req_to_token_pool
-
-        # Write current batch's req_to_token to req_to_token_pool
-        self._mock_write_to_req_to_token_pool(self.batch_size, total_len, page_size)
-        # Add kv pool for this forward batch
-        forward_batch.token_to_kv_pool = self.model_runner.token_to_kv_pool
-
-        return forward_batch
-
-    def _setup_kv_cache(self, forward_batch, layer, cache_len):
-        # Create constant values for the prefix cache for easy debugging
-        cache_k = paddle.ones(
-            [self.batch_size * cache_len, self.num_heads, self.head_dim],
-            dtype=self.dtype,
-        )
-        cache_v = (
-            paddle.ones(
-                [self.batch_size * cache_len, self.num_heads, self.head_dim],
-                dtype=self.dtype,
-            )
-            * 2
-        )
-
-        # Set the prefix KV cache
-        forward_batch.token_to_kv_pool.set_kv_buffer(
-            layer,
-            paddle.arange(self.batch_size * cache_len),
-            cache_k,
-            cache_v,
-            layer.k_scale,
-            layer.v_scale,
-        )
-
-    def _run_attention_test(self, mode, q_len, prefix_len=0, page_size=1):
-        """
-        Run an attention test with the specified parameters.
-        Args:
-            mode: ForwardMode.EXTEND or ForwardMode.DECODE
-            q_len: Length of the query sequence. For decode mode, q_len is 1.
-            prefix_len: Length of the prefix sequence for extend mode
-            page_size: Page size for the KV cache
-        """
-        layer = self._create_attention_layer()
-
-        # Create forward batch and set up
-        forward_batch = self._create_forward_batch(mode, q_len, prefix_len, page_size)
-
-        # Create QKV tensors for the input
-        q, k, v = self._create_qkv_tensors(self.batch_size * q_len)
-
-        # KV cache for prefixed extend is prefix_len
-        # KV cache for decode is same as seq_len
-        # No KV cache for extend without prefix
-        if mode == ForwardMode.EXTEND:
-            if prefix_len > 0:
-                self._setup_kv_cache(forward_batch, layer, prefix_len)
-        else:
-            self._setup_kv_cache(forward_batch, layer, self.seq_len)
-
-        self.backend.init_attention_metadata(forward_batch)
-
-        if mode == ForwardMode.EXTEND:
-            expected_shape = [
-                self.batch_size * q_len,
-                self.num_heads,
-                self.head_dim,
-            ]
-            output = self.backend.forward_extend(q, k, v, layer, forward_batch)
-        else:
-            expected_shape = [self.batch_size, self.num_heads * self.head_dim]
-            output = self.backend.forward_decode(q, k, v, layer, forward_batch)
-
-        output_ref = self._run_reference_forward(mode, q, k, v, layer, forward_batch, expected_shape)
-
-        self._verify_output(output, expected_shape, output_ref)
-
-        return output
-
-    def test_forward_extend(self):
-        """Test the standard extend operation."""
-        self._run_attention_test(ForwardMode.EXTEND, q_len=self.seq_len)
-
-    def test_forward_decode(self):
-        """Test the decode operation with cached tokens."""
-        self._run_attention_test(ForwardMode.DECODE, q_len=1)
-
-    def test_forward_extend_with_prefix(self):
-        """Test extending from cached prefix tokens."""
-        prefix_len = self.seq_len // 2
-        extend_len = self.seq_len - prefix_len
-        self._run_attention_test(ForwardMode.EXTEND, q_len=extend_len, prefix_len=prefix_len)
-
-    def test_forward_extend_with_page_size_greater_than_1(self):
-        """Test extending from cached prefix tokens with page size greater than 1."""
-        self._run_attention_test(ForwardMode.EXTEND, q_len=self.seq_len, page_size=64)
-
-    def test_forward_decode_with_page_size_greater_than_1(self):
-        """Test decode operation with page size greater than 1."""
-        self._run_attention_test(ForwardMode.DECODE, q_len=1, page_size=64)
-
-
-if __name__ == "__main__":
-    unittest.main()