From a5b4866ff111f8569ff4e3092f647b12acfbfba4 Mon Sep 17 00:00:00 2001 From: Ryan Date: Tue, 26 Aug 2025 11:25:04 +0800 Subject: [PATCH] [CudaGraph][SOT] Add unit tests for splitting the static graph into piecewise graphs that support cuda_graph (#3590) * add unitest * change sot_warmup_sizes * wtf; add missed commit --- tests/ce/deploy/21b_sot.yaml | 1 + .../test_static_graph_cuda_graph_split.py | 121 ++++++++++++++++++ 2 files changed, 122 insertions(+) create mode 100644 tests/graph_optimization/test_static_graph_cuda_graph_split.py diff --git a/tests/ce/deploy/21b_sot.yaml b/tests/ce/deploy/21b_sot.yaml index db1bcec52..6ead6cb7b 100644 --- a/tests/ce/deploy/21b_sot.yaml +++ b/tests/ce/deploy/21b_sot.yaml @@ -5,3 +5,4 @@ quantization: wint4 use_cudagraph: True graph_optimization_config: graph_opt_level: 1 + sot_warmup_sizes: [2,16,32,64] diff --git a/tests/graph_optimization/test_static_graph_cuda_graph_split.py b/tests/graph_optimization/test_static_graph_cuda_graph_split.py new file mode 100644 index 000000000..7421333a5 --- /dev/null +++ b/tests/graph_optimization/test_static_graph_cuda_graph_split.py @@ -0,0 +1,121 @@ +""" +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
import os

# The op names below are handed to Paddle through an environment flag.
# Presumably these ops are excluded from CUDA graph capture, which is what
# forces the static graph to be split into pieces -- TODO confirm against
# Paddle's flag documentation.
# The assignment comes before `import paddle` in the original; presumably the
# flag has to be in the environment when Paddle initialises, so keep this order.
os.environ["FLAGS_cuda_graph_blacklist"] = "pd_op.matmul,pd_op.transpose"


import unittest

import paddle
import paddle.nn as nn

from fastdeploy.config import (
    CacheConfig,
    FDConfig,
    GraphOptimizationConfig,
    ParallelConfig,
)
from fastdeploy.model_executor.forward_meta import ForwardMeta
from fastdeploy.model_executor.graph_optimization.decorator import (
    support_graph_optimization,
)


@support_graph_optimization
class Attention(nn.Layer):
    """Toy layer under test: embedding -> linear -> multi-head attention -> linear.

    The class is wrapped by ``support_graph_optimization``; ``forward`` and
    ``forward_dynamic`` contain the same computation so that the test can
    compare the two call paths.
    """

    def __init__(self, fd_config: FDConfig) -> None:
        """Build the sub-layers.

        Args:
            fd_config: FastDeploy configuration. It is not read here; it is
                part of the constructor signature (presumably required by the
                ``support_graph_optimization`` wrapper -- verify against the
                decorator).
        """
        super().__init__()
        self.embed_tokens = nn.Embedding(num_embeddings=100, embedding_dim=32)
        self.qkv_proj = nn.Linear(32, 64)
        self.attn = nn.MultiHeadAttention(embed_dim=64, num_heads=1)
        self.o_proj = nn.Linear(64, 32)

    def forward(
        self,
        ids_remove_padding,
        forward_meta: ForwardMeta,
    ):
        """Run the layer stack on the token ids carried by ``forward_meta``.

        Args:
            ids_remove_padding: Accepted for interface compatibility; the ids
                actually consumed are ``forward_meta.ids_remove_padding``.
            forward_meta: Metadata object supplying ``ids_remove_padding``.

        Returns:
            Output of the final projection layer.
        """
        hidden_states = self.embed_tokens(forward_meta.ids_remove_padding)
        qkv_out = self.qkv_proj(hidden_states)
        attn_out = self.attn(qkv_out)
        output = self.o_proj(attn_out)

        return output

    def forward_dynamic(
        self,
        ids_remove_padding,
        forward_meta: ForwardMeta,
    ):
        """Reference path: the same computation as ``forward``.

        NOTE(review): kept as a separate, duplicated body -- presumably so the
        reference result does not pass through whatever the decorator does to
        ``forward``; confirm before merging the two.

        Args:
            ids_remove_padding: Accepted for interface compatibility; the ids
                actually consumed are ``forward_meta.ids_remove_padding``.
            forward_meta: Metadata object supplying ``ids_remove_padding``.

        Returns:
            Output of the final projection layer.
        """
        hidden_states = self.embed_tokens(forward_meta.ids_remove_padding)
        qkv_out = self.qkv_proj(hidden_states)
        attn_out = self.attn(qkv_out)
        output = self.o_proj(attn_out)

        return output


class TestModel(nn.Layer):
    """Thin wrapper exposing both call paths of ``Attention``."""

    def __init__(self, fd_config: FDConfig, **kwargs):
        """Create the wrapped ``Attention`` layer; extra ``kwargs`` are ignored."""
        super().__init__()
        self.model = Attention(fd_config)

    def forward(self, ids_remove_padding, forward_meta: ForwardMeta):
        """Invoke the wrapped layer through its ``__call__`` (the decorated path)."""
        return self.model(ids_remove_padding=ids_remove_padding, forward_meta=forward_meta)

    def forward_correct(self, ids_remove_padding, forward_meta: ForwardMeta):
        """Invoke ``Attention.forward_dynamic`` directly to obtain the reference output."""
        return self.model.forward_dynamic(ids_remove_padding=ids_remove_padding, forward_meta=forward_meta)


class TestStaticGraphCUDAGraphSplit(unittest.TestCase):
    """Compare the decorated call path with the ``forward_dynamic`` reference."""

    def test(self):
        """Call the model three times, then check the last output against the reference."""
        # Set FastDeploy config
        graph_opt_config = GraphOptimizationConfig({"use_cudagraph": True, "graph_opt_level": 1})
        parallel_config = ParallelConfig({"max_num_seqs": 1})
        graph_opt_config._set_cudagraph_sizes(max_num_seqs=parallel_config.max_num_seqs)
        # NOTE: "cudagrpah" is the spelling of the project API being called; do not "fix" it here.
        graph_opt_config.init_with_cudagrpah_size(max_num_seqs=parallel_config.max_num_seqs)
        cache_config = CacheConfig({})

        fd_config = FDConfig(
            graph_opt_config=graph_opt_config,
            parallel_config=parallel_config,
            cache_config=cache_config,
            test_mode=True,
        )

        test_model1 = TestModel(fd_config=fd_config)
        x = paddle.randint(32, shape=[1, 8])
        forward_meta1 = ForwardMeta(input_ids=x, ids_remove_padding=x, step_use_cudagraph=True)

        # Trigger capture
        _ = test_model1(x, forward_meta=forward_meta1)

        # Replay
        _ = test_model1(x, forward_meta=forward_meta1)
        output1 = test_model1(x, forward_meta=forward_meta1)

        # Correct (reference) output
        output1_correct = test_model1.forward_correct(x, forward_meta=forward_meta1)

        # Use a TestCase assertion instead of a bare `assert`: it is not
        # stripped under `python -O` and reports a readable failure message.
        self.assertTrue(
            (output1 == output1_correct).all(),
            msg="output of the decorated call path differs from the forward_dynamic reference",
        )


if __name__ == "__main__":
    unittest.main()