diff --git a/tests/graph_optimization/test_cuda_graph_dynamic_subgraph.py b/tests/graph_optimization/test_cuda_graph_dynamic_subgraph.py
index 166b2d3ca..82b8a27ac 100644
--- a/tests/graph_optimization/test_cuda_graph_dynamic_subgraph.py
+++ b/tests/graph_optimization/test_cuda_graph_dynamic_subgraph.py
@@ -153,7 +153,7 @@ class TestCUDAGrpahSubgraph(unittest.TestCase):
         graph_opt_config = GraphOptimizationConfig(args={})
         graph_opt_config.use_cudagraph = True
         parallel_config = ParallelConfig(args={})
-        parallel_config.max_num_seqs = 1
+        parallel_config.max_num_seqs = 8
         cache_config = CacheConfig({})
         # Initialize cuda graph capture list
         graph_opt_config._set_cudagraph_sizes(max_num_seqs=parallel_config.max_num_seqs)
@@ -167,7 +167,7 @@ class TestCUDAGrpahSubgraph(unittest.TestCase):
 
         # Run Test Case1
         test_model1 = TestModel1(fd_config=fd_config)
-        input_tensor1 = paddle.ones([32768])
+        input_tensor1 = paddle.ones([8])
         forward_meta1 = ForwardMeta(input_ids=input_tensor1, ids_remove_padding=input_tensor1, step_use_cudagraph=True)
 
         # Triger Capture
@@ -180,7 +180,7 @@ class TestCUDAGrpahSubgraph(unittest.TestCase):
         # Corrent output
         output1_correct = test_model1.forward_correct(ids_remove_padding=input_tensor1, forward_meta=forward_meta1)
 
-        assert sum(output1 - output1_correct) == 0
+        assert (output1 == output1_correct).all()
 
 
 if __name__ == "__main__":
diff --git a/tests/graph_optimization/test_cuda_graph_recapture.py b/tests/graph_optimization/test_cuda_graph_recapture.py
index f3b3ff214..6198e1eb4 100644
--- a/tests/graph_optimization/test_cuda_graph_recapture.py
+++ b/tests/graph_optimization/test_cuda_graph_recapture.py
@@ -98,33 +98,50 @@ class TestCUDAGrpahRecapture(unittest.TestCase):
         )
 
         # Run Test Case1
-        test_model1 = TestModel1(fd_config=fd_config)
-        input_tensor1 = paddle.ones([32768])
+        self.test_model1 = TestModel1(fd_config=fd_config)
+        input_tensor1 = paddle.ones([1, 32768])
         forward_meta1 = ForwardMeta(input_ids=input_tensor1, ids_remove_padding=input_tensor1, step_use_cudagraph=True)
 
+        # Corrent output
+        self.output_correct = self.test_model1.forward_correct(
+            ids_remove_padding=input_tensor1, forward_meta=forward_meta1
+        )
+
+        # Capture and Destory
+        self.capture_and_replay(input_tensor1, forward_meta1)
+        self.recapture_and_replay(input_tensor1, forward_meta1)
+
+    def capture_and_replay(self, input_tensor1, forward_meta1):
+        """ """
         # Triger Capture
         print_gpu_memory_use(0, "before capture")
-        _ = test_model1(ids_remove_padding=input_tensor1, forward_meta=forward_meta1)
+        output1 = self.test_model1(ids_remove_padding=input_tensor1, forward_meta=forward_meta1)
         print_gpu_memory_use(0, "after capture")
+
         # Reaplay
-        output1 = test_model1(ids_remove_padding=input_tensor1, forward_meta=forward_meta1)
+        output1 = self.test_model1(ids_remove_padding=input_tensor1, forward_meta=forward_meta1)
+        assert (output1 == self.output_correct).all()
+
         # Destory
         print_gpu_memory_use(0, "before destory")
-        test_model1.clear_grpah_opt_backend()
+        self.test_model1.clear_grpah_opt_backend()
         print_gpu_memory_use(0, "after destory")
 
+    def recapture_and_replay(self, input_tensor1, forward_meta1):
+        """ """
         # Triger Capture
         print_gpu_memory_use(0, "before recapture")
-        _ = test_model1(ids_remove_padding=input_tensor1, forward_meta=forward_meta1)
+        output2 = self.test_model1(ids_remove_padding=input_tensor1, forward_meta=forward_meta1)
         print_gpu_memory_use(0, "after recapture")
+
         # Reaplay
-        output2 = test_model1(ids_remove_padding=input_tensor1, forward_meta=forward_meta1)
+        output2 = self.test_model1(ids_remove_padding=input_tensor1, forward_meta=forward_meta1)
+        assert (output2 == self.output_correct).all()
 
-        # Corrent output
-        output1_correct = test_model1.forward_correct(ids_remove_padding=input_tensor1, forward_meta=forward_meta1)
-
-        assert sum(output1 - output2) == 0
-        assert sum(output1_correct - output1) == 0
+        # Destory
+        print_gpu_memory_use(0, "before destory")
+        self.test_model1.clear_grpah_opt_backend()
+        print_gpu_memory_use(0, "after destory")
 
 
 if __name__ == "__main__":
diff --git a/tests/graph_optimization/test_cuda_graph_spec_decode.py b/tests/graph_optimization/test_cuda_graph_spec_decode.py
index 0ae5e5fc6..2fc685bbf 100644
--- a/tests/graph_optimization/test_cuda_graph_spec_decode.py
+++ b/tests/graph_optimization/test_cuda_graph_spec_decode.py
@@ -114,7 +114,7 @@ class TestCUDAGrpahSpecDecode(unittest.TestCase):
 
         # Run Test Case1
         test_model1 = TestModel1(fd_config=fd_config)
-        input_tensor1 = paddle.ones([32768])
+        input_tensor1 = paddle.ones([1, 32768])
         forward_meta1 = ForwardMeta(input_ids=input_tensor1, ids_remove_padding=input_tensor1, step_use_cudagraph=True)
 
         # Triger Capture
@@ -127,7 +127,7 @@ class TestCUDAGrpahSpecDecode(unittest.TestCase):
         # Corrent output
         output1_correct = test_model1.forward_correct(ids_remove_padding=input_tensor1, forward_meta=forward_meta1)
 
-        assert sum(output1 - output1_correct) == 0
+        assert (output1 == output1_correct).all()
 
 
 if __name__ == "__main__":
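
Note on the assertion change made in all three tests: the old checks summed the element-wise difference, which can report zero even when the outputs disagree, because positive and negative errors cancel; the new (output == output_correct).all() form only passes when every element matches. A minimal sketch of the failure mode (the tensors below are hypothetical and only illustrate the point; they are not taken from the tests):

    import paddle

    # Hypothetical outputs whose element-wise differences (+1 and -1) cancel.
    output = paddle.to_tensor([1.0, 3.0])
    output_correct = paddle.to_tensor([2.0, 2.0])

    # Sum-of-differences check: cancels to zero, so the mismatch goes unnoticed.
    assert float((output - output_correct).sum()) == 0.0

    # Element-wise check: correctly reports that the outputs differ.
    assert not bool((output == output_correct).all())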