From 2fa173e327c77b29e7554920329b124aa91f4dd5 Mon Sep 17 00:00:00 2001
From: RAM
Date: Mon, 25 Aug 2025 20:59:30 +0800
Subject: [PATCH] [Executor] CUDAGraph support RL training (#3265)

* add clear graph opt backend

* cuda graph support rl

* add branch

* 1. fix dynamic_weight_manager bug 2. add clear api for CasualLM

* open test case

* fix typo

* update mkdocs.yaml

* [Docs]Update mkdocs.yml

* update test case

* use unittest in graph test case

---
 .../cudagraph_piecewise_backend.py            |  36 +++--
 .../graph_optimization/decorator.py           |   8 ++
 .../graph_optimization_backend.py             |   4 +
 .../model_executor/models/deepseek_v3.py      |   4 +
 .../model_executor/models/ernie4_5_moe.py     |   4 +
 .../models/ernie4_5_vl/ernie4_5_vl_moe.py     |   4 +
 fastdeploy/model_executor/models/qwen2.py     |   4 +
 fastdeploy/model_executor/models/qwen3.py     |   4 +
 fastdeploy/model_executor/models/qwen3moe.py  |   4 +
 fastdeploy/utils.py                           |   9 ++
 fastdeploy/worker/gpu_model_runner.py         |  11 ++
 .../test_cuda_graph_dynamic_subgraph.py       |  70 +++++-----
 .../test_cuda_graph_recapture.py              | 131 ++++++++++++++++++
 .../test_cuda_graph_spec_decode.py            |  63 +++++----
 14 files changed, 289 insertions(+), 67 deletions(-)
 create mode 100644 tests/graph_optimization/test_cuda_graph_recapture.py

diff --git a/fastdeploy/model_executor/graph_optimization/cudagraph_piecewise_backend.py b/fastdeploy/model_executor/graph_optimization/cudagraph_piecewise_backend.py
index 1b1bebebc..a5c149d04 100644
--- a/fastdeploy/model_executor/graph_optimization/cudagraph_piecewise_backend.py
+++ b/fastdeploy/model_executor/graph_optimization/cudagraph_piecewise_backend.py
@@ -62,15 +62,7 @@ class CudaGraphPiecewiseBackend:
         self.warm_up_size = fd_config.graph_opt_config.cudagraph_num_of_warmups
         self.real_shape_to_captured_size = fd_config.graph_opt_config.real_shape_to_captured_size
 
-        # Runtime real shape -> ConcreteSizeEntry
-        self.concrete_size_entries: Dict[int, ConcreteSizeEntry] = {}
-
-        for shape in self.cudagraph_capture_sizes:
-            self.concrete_size_entries[shape] = ConcreteSizeEntry(runtime_bs=shape)
-
-        logger.info(
-            f"[CUDA GRAPH] CUDAGraph capture list {self.cudagraph_capture_sizes}, " "Created all real shape entry."
-        )
+        self._create_entry_dict()
 
     def __call__(self, **kwargs):
         # Get real shape(all num tokens)
@@ -128,3 +120,29 @@ class CudaGraphPiecewiseBackend:
             entry.cuda_graph.replay()
             logger.debug(f"[CUDA GRAPH] CUDAGraph replayed for real shape {padding_real_shape}")
         return entry.output_buffer
+
+    def _create_entry_dict(self):
+        """Create the mapping from runtime real shape to ConcreteSizeEntry."""
+        # Runtime real shape -> ConcreteSizeEntry
+        self.concrete_size_entries: Dict[int, ConcreteSizeEntry] = {}
+
+        for shape in self.cudagraph_capture_sizes:
+            self.concrete_size_entries[shape] = ConcreteSizeEntry(runtime_bs=shape)
+
+        logger.info(
+            f"[CUDA GRAPH] CUDAGraph capture list {self.cudagraph_capture_sizes}, " "Created all real shape entry."
+        )
+
+    def clear_graph(self):
+        """Release every captured CUDA graph, then rebuild the entry dict."""
+        # Clear graphs
+        for id, entry in self.concrete_size_entries.items():
+            if entry.cuda_graph:
+                del entry.cuda_graph
+                logger.debug(f"[CUDA GRAPH] The CUDAGraph with shape {id} has been cleared.")
+
+        del self.concrete_size_entries
+        paddle.device.cuda.empty_cache()
+
+        # Create new entries
+        self._create_entry_dict()
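The first call through CudaGraphPiecewiseBackend captures one graph per entry in concrete_size_entries; later calls replay, and clear_graph releases every captured graph so the next call re-captures against freshly loaded weights. A condensed sketch of that lifecycle, mirroring test_cuda_graph_recapture.py added at the end of this patch (TinyLayer is illustrative and not part of the patch; a CUDA device is assumed):

    import paddle

    from fastdeploy.config import (
        CacheConfig,
        FDConfig,
        GraphOptimizationConfig,
        ParallelConfig,
    )
    from fastdeploy.model_executor.forward_meta import ForwardMeta
    from fastdeploy.model_executor.graph_optimization.decorator import (
        support_graph_optimization,
    )


    @support_graph_optimization
    class TinyLayer(paddle.nn.Layer):
        """Illustrative layer; anything decorated this way gains the clear API."""

        def __init__(self, fd_config: FDConfig, **kwargs):
            super().__init__()

        def forward(self, ids_remove_padding, forward_meta: ForwardMeta):
            return paddle.add(forward_meta.input_ids, forward_meta.input_ids)


    graph_opt_config = GraphOptimizationConfig(args={})
    graph_opt_config.use_cudagraph = True
    parallel_config = ParallelConfig(args={})
    parallel_config.max_num_seqs = 1
    fd_config = FDConfig(
        graph_opt_config=graph_opt_config,
        parallel_config=parallel_config,
        cache_config=CacheConfig(args={}),
    )

    layer = TinyLayer(fd_config=fd_config)
    x = paddle.ones([32768])
    meta = ForwardMeta(input_ids=x, ids_remove_padding=x, step_use_cudagraph=True)

    _ = layer(ids_remove_padding=x, forward_meta=meta)  # first call: capture
    _ = layer(ids_remove_padding=x, forward_meta=meta)  # later calls: replay
    layer.clear_grpah_opt_backend(fd_config=fd_config)  # free captured graphs
    _ = layer(ids_remove_padding=x, forward_meta=meta)  # re-capture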
diff --git a/fastdeploy/model_executor/graph_optimization/decorator.py b/fastdeploy/model_executor/graph_optimization/decorator.py
index 49b92feb4..ef4f54f98 100644
--- a/fastdeploy/model_executor/graph_optimization/decorator.py
+++ b/fastdeploy/model_executor/graph_optimization/decorator.py
@@ -91,3 +91,11 @@ class GraphOptWrapper:
 
     def __call__(self, **kwargs):
         return self.graph_opt_backend(**kwargs)
+
+    def clear_grpah_opt_backend(self, fd_config):
+        """Clear all CUDA graphs captured by the graph optimization backend."""
+        # TODO(gongshaotian): Resolve the bug of static graphs not being able to update weights
+        assert (
+            fd_config.graph_opt_config.graph_opt_level < 1
+        ), "Currently unable to update weights in static graph mode."
+        self.graph_opt_backend.clear_cudagraph_piecewise_backend()
diff --git a/fastdeploy/model_executor/graph_optimization/graph_optimization_backend.py b/fastdeploy/model_executor/graph_optimization/graph_optimization_backend.py
index 9f56d313c..e843753e8 100644
--- a/fastdeploy/model_executor/graph_optimization/graph_optimization_backend.py
+++ b/fastdeploy/model_executor/graph_optimization/graph_optimization_backend.py
@@ -144,3 +144,7 @@ class GraphOptBackend:
             return self.runnable(**kwargs)
         else:
             return self.cudagraph_piecewise_backend.__call__(**kwargs)
+
+    def clear_cudagraph_piecewise_backend(self):
+        """Clear all graphs captured by the piecewise CUDA graph backend."""
+        self.cudagraph_piecewise_backend.clear_graph()
diff --git a/fastdeploy/model_executor/models/deepseek_v3.py b/fastdeploy/model_executor/models/deepseek_v3.py
index f240e760f..6b28226ed 100644
--- a/fastdeploy/model_executor/models/deepseek_v3.py
+++ b/fastdeploy/model_executor/models/deepseek_v3.py
@@ -737,6 +737,10 @@ class DeepseekV3ForCausalLM(ModelForCasualLM):
         )
         return hidden_states
 
+    def clear_grpah_opt_backend(self):
+        """Clear the graph optimization backend; captured CUDA graphs are released."""
+        self.model.clear_grpah_opt_backend(fd_config=self.fd_config)
+
 
 class DeepSeekV3PretrainedModel(PretrainedModel):
     """
diff --git a/fastdeploy/model_executor/models/ernie4_5_moe.py b/fastdeploy/model_executor/models/ernie4_5_moe.py
index abae0347d..41d46b7d0 100644
--- a/fastdeploy/model_executor/models/ernie4_5_moe.py
+++ b/fastdeploy/model_executor/models/ernie4_5_moe.py
@@ -545,6 +545,10 @@ class Ernie4_5_MoeForCausalLM(ModelForCasualLM):
 
         return hidden_states
 
+    def clear_grpah_opt_backend(self):
+        """Clear the graph optimization backend; captured CUDA graphs are released."""
+        self.ernie.clear_grpah_opt_backend(fd_config=self.fd_config)
+
 
 class Ernie4_5_ForCausalLM(Ernie4_5_MoeForCausalLM):
     """
diff --git a/fastdeploy/model_executor/models/ernie4_5_vl/ernie4_5_vl_moe.py b/fastdeploy/model_executor/models/ernie4_5_vl/ernie4_5_vl_moe.py
index 600811ff3..950109c93 100644
--- a/fastdeploy/model_executor/models/ernie4_5_vl/ernie4_5_vl_moe.py
+++ b/fastdeploy/model_executor/models/ernie4_5_vl/ernie4_5_vl_moe.py
@@ -703,6 +703,10 @@ class Ernie4_5_VLMoeForConditionalGeneration(ModelForCasualLM):
 
         return hidden_states
 
+    def clear_grpah_opt_backend(self):
+        """Clear the graph optimization backend; captured CUDA graphs are released."""
+        self.ernie.clear_grpah_opt_backend(fd_config=self.fd_config)
+
 
 class Ernie4_5_VLPretrainedModel(PretrainedModel):
     """
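Each *ForCausalLM above delegates into its decorated backbone, which lands in GraphOptWrapper.clear_grpah_opt_backend; the wrapper asserts that the model runs in dynamic-graph mode before clearing anything. Reusing layer and fd_config from the sketch above, the guard can be exercised directly (a sketch; graph_opt_level is toggled only to show the assertion):

    # Static-graph mode (graph_opt_level >= 1) cannot update weights yet, so
    # the wrapper refuses to clear captured graphs and raises AssertionError.
    fd_config.graph_opt_config.graph_opt_level = 1
    try:
        layer.clear_grpah_opt_backend(fd_config=fd_config)
    except AssertionError as err:
        print(err)  # Currently unable to update weights in static graph mode.
    fd_config.graph_opt_config.graph_opt_level = 0  # back to dynamic-graph mode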
diff --git a/fastdeploy/model_executor/models/qwen2.py b/fastdeploy/model_executor/models/qwen2.py
index 3682b5dc1..eaa1e26a8 100644
--- a/fastdeploy/model_executor/models/qwen2.py
+++ b/fastdeploy/model_executor/models/qwen2.py
@@ -390,6 +390,10 @@ class Qwen2ForCausalLM(ModelForCasualLM):
 
         return hidden_states
 
+    def clear_grpah_opt_backend(self):
+        """Clear the graph optimization backend; captured CUDA graphs are released."""
+        self.qwen2.clear_grpah_opt_backend(fd_config=self.fd_config)
+
 
 class Qwen2PretrainedModel(PretrainedModel):
     """
diff --git a/fastdeploy/model_executor/models/qwen3.py b/fastdeploy/model_executor/models/qwen3.py
index 6d4553dc1..c44a1e127 100644
--- a/fastdeploy/model_executor/models/qwen3.py
+++ b/fastdeploy/model_executor/models/qwen3.py
@@ -330,6 +330,10 @@ class Qwen3ForCausalLM(ModelForCasualLM):
 
         return hidden_states
 
+    def clear_grpah_opt_backend(self):
+        """Clear the graph optimization backend; captured CUDA graphs are released."""
+        self.model.clear_grpah_opt_backend(fd_config=self.fd_config)
+
 
 class Qwen3PretrainedModel(PretrainedModel):
     """
diff --git a/fastdeploy/model_executor/models/qwen3moe.py b/fastdeploy/model_executor/models/qwen3moe.py
index 3dce5c976..f74d49d56 100644
--- a/fastdeploy/model_executor/models/qwen3moe.py
+++ b/fastdeploy/model_executor/models/qwen3moe.py
@@ -420,6 +420,10 @@ class Qwen3MoeForCausalLM(ModelForCasualLM):
 
         return hidden_states
 
+    def clear_grpah_opt_backend(self):
+        """Clear the graph optimization backend; captured CUDA graphs are released."""
+        self.model.clear_grpah_opt_backend(fd_config=self.fd_config)
+
 
 class Qwen3MoePretrainedModel(PretrainedModel):
     """
diff --git a/fastdeploy/utils.py b/fastdeploy/utils.py
index 141c2a4ab..f60a96468 100644
--- a/fastdeploy/utils.py
+++ b/fastdeploy/utils.py
@@ -497,11 +497,20 @@ def print_gpu_memory_use(gpu_id: int, title: str) -> None:
     meminfo = pynvml.nvmlDeviceGetMemoryInfo(handle)
     pynvml.nvmlShutdown()
 
+    paddle_max_reserved = paddle.device.cuda.max_memory_reserved(gpu_id)
+    paddle_max_allocated = paddle.device.cuda.max_memory_allocated(gpu_id)
+    paddle_reserved = paddle.device.cuda.memory_reserved(gpu_id)
+    paddle_allocated = paddle.device.cuda.memory_allocated(gpu_id)
+
     print(
         f"\n{title}:",
         f"\n\tDevice Total memory: {meminfo.total}",
         f"\n\tDevice Used memory: {meminfo.used}",
         f"\n\tDevice Free memory: {meminfo.free}",
+        f"\n\tPaddle max memory Reserved: {paddle_max_reserved}",
+        f"\n\tPaddle max memory Allocated: {paddle_max_allocated}",
+        f"\n\tPaddle memory Reserved: {paddle_reserved}",
+        f"\n\tPaddle memory Allocated: {paddle_allocated}",
     )
 
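The extra allocator counters make the CUDA graph pool visible from Python: NVML reports device-wide numbers, while the paddle counters isolate what the framework allocator holds. Continuing the sketch above (device 0 assumed; absolute values depend on the GPU):

    from fastdeploy.utils import print_gpu_memory_use

    print_gpu_memory_use(0, "before capture")
    _ = layer(ids_remove_padding=x, forward_meta=meta)  # first call captures, later ones replay
    print_gpu_memory_use(0, "after capture")            # paddle reserved memory grows
    layer.clear_grpah_opt_backend(fd_config=fd_config)  # destroy captured graphs
    print_gpu_memory_use(0, "after clear")              # reserved memory drops back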
diff --git a/fastdeploy/worker/gpu_model_runner.py b/fastdeploy/worker/gpu_model_runner.py
index 7e7165c74..2f466e3eb 100644
--- a/fastdeploy/worker/gpu_model_runner.py
+++ b/fastdeploy/worker/gpu_model_runner.py
@@ -876,6 +876,7 @@ class GPUModelRunner(ModelRunnerBase):
         # 1. Load original model
         model_loader = get_model_loader(load_config=self.fd_config.load_config)
         self.model = model_loader.load_model(fd_config=self.fd_config)
+
         # 1.1 Load RL dynamic model
         if self.fd_config.load_config.dynamic_load_weight:
             from fastdeploy.rl.dynamic_weight_manager import DynamicWeightManager
@@ -1595,12 +1596,22 @@
         self.dynamic_weight_manager.clear_parameters(pid)
         self.clear_cache()
         paddle.device.cuda.empty_cache()
+
+        # Clear captured CUDA graphs
+        if self.use_cudagraph:
+            self.model.clear_grpah_opt_backend()
+
         self.dynamic_weight_manager._log_memory("dynamic weight manager clear all memory")
 
     def update_parameters(self, pid):
         """ " Dynamic model loader use to update parameters use for RL"""
         self.dynamic_weight_manager.update_parameters(pid)
         self.initialize_kv_cache()
+
+        # Recapture CUDA graphs
+        if self.use_cudagraph:
+            self.capture_model()
+
         self.dynamic_weight_manager._log_memory("dynamic weight manager update all memory")
 
     def padding_cudagraph_inputs(self) -> None:
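Taken together, the two hunks above define the RL weight-update cycle: clearing drops the weights, the KV cache, and every captured graph; updating reloads the weights, rebuilds the cache, and recaptures. A minimal sketch from the caller's side, assuming a GPUModelRunner created with use_cudagraph and dynamic_load_weight enabled (runner and pid are stand-ins; the name of the clearing hook, clear_parameters, lies outside the hunk and is assumed):

    def rl_weight_update_cycle(runner, pid):
        # Drop old weights, KV cache, and every captured CUDA graph.
        runner.clear_parameters(pid)
        # ... the training side publishes updated weights here ...
        # Reload weights, rebuild the KV cache, and recapture CUDA graphs
        # (update_parameters calls capture_model when use_cudagraph is set).
        runner.update_parameters(pid)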
diff --git a/tests/graph_optimization/test_cuda_graph_dynamic_subgraph.py b/tests/graph_optimization/test_cuda_graph_dynamic_subgraph.py
index 9b5f2b4c8..166b2d3ca 100644
--- a/tests/graph_optimization/test_cuda_graph_dynamic_subgraph.py
+++ b/tests/graph_optimization/test_cuda_graph_dynamic_subgraph.py
@@ -14,6 +14,8 @@
 # limitations under the License.
 """
 
+import unittest
+
 import paddle
 
 from fastdeploy.config import (
@@ -94,7 +96,7 @@ class TestModel1(paddle.nn.Layer):
         self.fd_config = fd_config
 
         self.sublayer1 = TestCase1SubLayer1(self.fd_config)
-        self.sublayer2 = TestCase1SubLayer2(self.fd_config)
+        self.sublayer2 = TestCase1SubLayer2(self.fd_config)  # Attention
         self.sublayer3 = TestCase1SubLayer3(self.fd_config)
 
         self.sublayer2_output_buffer = paddle.zeros([1])
@@ -106,9 +108,7 @@ class TestModel1(paddle.nn.Layer):
         sublayer1_output = self.sublayer1(ids_remove_padding=ids_remove_padding, forward_meta=sub_meta1)
 
         # sublayer2 not use cuda garph
-        sub_meta2 = ForwardMeta(
-            input_ids=sublayer1_output, ids_remove_padding=sublayer1_output, step_use_cudagraph=False
-        )
+        sub_meta2 = ForwardMeta(input_ids=sublayer1_output, ids_remove_padding=sublayer1_output)
         sublayer2_output = self.sublayer2(ids_remove_padding=sublayer1_output, forward_meta=sub_meta2)
         self.sublayer2_output_buffer.copy_(sublayer2_output, False)
 
@@ -142,38 +142,46 @@ class TestModel1(paddle.nn.Layer):
         return sublayer3_output
 
 
-def run_test_case():
-    """Run test case"""
-    # Set FastDeploy config
-    graph_opt_config = GraphOptimizationConfig(args={})
-    graph_opt_config.use_cudagraph = True
-    parallel_config = ParallelConfig(args={})
-    parallel_config.max_num_seqs = 1
-    cache_config = CacheConfig({})
-    # Initialize cuda graph capture list
-    graph_opt_config._set_cudagraph_sizes(max_num_seqs=parallel_config.max_num_seqs)
-    graph_opt_config.init_with_cudagrpah_size(max_num_seqs=parallel_config.max_num_seqs)
-    fd_config = FDConfig(
-        graph_opt_config=graph_opt_config, parallel_config=parallel_config, cache_config=cache_config, test_mode=True
-    )
+class TestCUDAGrpahSubgraph(unittest.TestCase):
+    """
+    Test CUDAGraph with a dynamic (non-captured) subgraph in the middle.
+    """
 
+    def test_cuda_graph_subgraph(self):
+        """Run test case"""
+        # Set FastDeploy config
+        graph_opt_config = GraphOptimizationConfig(args={})
+        graph_opt_config.use_cudagraph = True
+        parallel_config = ParallelConfig(args={})
+        parallel_config.max_num_seqs = 1
+        cache_config = CacheConfig({})
+        # Initialize cuda graph capture list
+        graph_opt_config._set_cudagraph_sizes(max_num_seqs=parallel_config.max_num_seqs)
+        graph_opt_config.init_with_cudagrpah_size(max_num_seqs=parallel_config.max_num_seqs)
+        fd_config = FDConfig(
+            graph_opt_config=graph_opt_config,
+            parallel_config=parallel_config,
+            cache_config=cache_config,
+            test_mode=True,
+        )
 
-    # Run Test Case1
-    test_model1 = TestModel1(fd_config=fd_config)
-    input_tensor1 = paddle.ones([1])
-    forward_meta1 = ForwardMeta(input_ids=input_tensor1, ids_remove_padding=input_tensor1, step_use_cudagraph=True)
+        # Run Test Case1
+        test_model1 = TestModel1(fd_config=fd_config)
+        input_tensor1 = paddle.ones([32768])
+        forward_meta1 = ForwardMeta(input_ids=input_tensor1, ids_remove_padding=input_tensor1, step_use_cudagraph=True)
 
-    # Triger Capture
-    _ = test_model1(ids_remove_padding=input_tensor1, forward_meta=forward_meta1)
+        # Trigger capture
+        _ = test_model1(ids_remove_padding=input_tensor1, forward_meta=forward_meta1)
 
-    # Reaplay
-    _ = test_model1(ids_remove_padding=input_tensor1, forward_meta=forward_meta1)
-    output1 = test_model1(ids_remove_padding=input_tensor1, forward_meta=forward_meta1)
+        # Replay
+        _ = test_model1(ids_remove_padding=input_tensor1, forward_meta=forward_meta1)
+        output1 = test_model1(ids_remove_padding=input_tensor1, forward_meta=forward_meta1)
 
-    # Corrent output
-    output1_correct = test_model1.forward_correct(ids_remove_padding=input_tensor1, forward_meta=forward_meta1)
+        # Correct output
+        output1_correct = test_model1.forward_correct(ids_remove_padding=input_tensor1, forward_meta=forward_meta1)
 
-    assert output1 == output1_correct
+        assert sum(output1 - output1_correct) == 0
 
 
 if __name__ == "__main__":
-    run_test_case()
+    unittest.main()
diff --git a/tests/graph_optimization/test_cuda_graph_recapture.py b/tests/graph_optimization/test_cuda_graph_recapture.py
new file mode 100644
index 000000000..f3b3ff214
--- /dev/null
+++ b/tests/graph_optimization/test_cuda_graph_recapture.py
@@ -0,0 +1,131 @@
+import unittest
+
+import paddle
+
+from fastdeploy.config import (
+    CacheConfig,
+    FDConfig,
+    GraphOptimizationConfig,
+    ParallelConfig,
+)
+from fastdeploy.model_executor.forward_meta import ForwardMeta
+from fastdeploy.model_executor.graph_optimization.decorator import (
+    support_graph_optimization,
+)
+from fastdeploy.utils import print_gpu_memory_use
+
+
+@support_graph_optimization
+class TestCase1SubLayer1(paddle.nn.Layer):
+    """Sub layer 1 of test case 1"""
+
+    def __init__(self, fd_config: FDConfig, **kwargs):
+        super().__init__()
+
+    def forward(self, ids_remove_padding, forward_meta: ForwardMeta):
+        """Sub layer1 forward pass"""
+
+        output = paddle.add(forward_meta.input_ids, forward_meta.input_ids)
+        return output
+
+    def forward_correct(self, ids_remove_padding, forward_meta: ForwardMeta):
+        """Sub layer1 correct forward pass"""
+
+        output = paddle.add(forward_meta.input_ids, forward_meta.input_ids)
+        return output
+
+
+class TestModel1(paddle.nn.Layer):
+    """Test model"""
+
+    def __init__(self, fd_config: FDConfig, **kwargs):
+        super().__init__()
+        self.fd_config = fd_config
+
+        self.sublayer1 = TestCase1SubLayer1(self.fd_config)
+        sublayer1_copy = TestCase1SubLayer1(self.fd_config)
+        self.sublayer2 = sublayer1_copy
+
+    def forward(self, ids_remove_padding, forward_meta: ForwardMeta):
+        """Test model forward pass"""
+        # sublayer1 use cuda graph
+        sub_meta1 = forward_meta
+        sublayer1_output = self.sublayer1(ids_remove_padding=ids_remove_padding, forward_meta=sub_meta1)
+
+        # sublayer2 use cuda graph
+        sub_meta2 = ForwardMeta(
+            input_ids=sublayer1_output, ids_remove_padding=sublayer1_output, step_use_cudagraph=True
+        )
+        sublayer2_output = self.sublayer2(ids_remove_padding=sublayer1_output, forward_meta=sub_meta2)
+
+        return sublayer2_output
+
+    def forward_correct(self, ids_remove_padding, forward_meta: ForwardMeta):
+        """Test model correct forward pass"""
+        # sublayer1 not use cuda graph
+        sub_meta1 = forward_meta
+        sublayer1_output = self.sublayer1.forward_correct(
+            ids_remove_padding=ids_remove_padding, forward_meta=sub_meta1
+        )
+
+        # sublayer2 not use cuda graph
+        sub_meta2 = ForwardMeta(input_ids=sublayer1_output, ids_remove_padding=sublayer1_output)
+        sublayer2_output = self.sublayer2.forward_correct(ids_remove_padding=sublayer1_output, forward_meta=sub_meta2)
+
+        return sublayer2_output
+
+    def clear_grpah_opt_backend(self):
+        """Clear the captured graphs of both sublayers."""
+        self.sublayer1.clear_grpah_opt_backend(fd_config=self.fd_config)
+        self.sublayer2.clear_grpah_opt_backend(fd_config=self.fd_config)
+
+
+class TestCUDAGrpahRecapture(unittest.TestCase):
+    """
+    Test CUDAGraph memory change across destroy and recapture.
+    """
+
+    def test_cuda_graph_recapture(self):
+        """Run test case"""
+        # Set FastDeploy config
+        graph_opt_config = GraphOptimizationConfig(args={})
+        graph_opt_config.use_cudagraph = True
+        parallel_config = ParallelConfig(args={})
+        cache_config = CacheConfig(args={})
+        parallel_config.max_num_seqs = 1
+        fd_config = FDConfig(
+            graph_opt_config=graph_opt_config, parallel_config=parallel_config, cache_config=cache_config
+        )
+
+        # Run Test Case1
+        test_model1 = TestModel1(fd_config=fd_config)
+        input_tensor1 = paddle.ones([32768])
+        forward_meta1 = ForwardMeta(input_ids=input_tensor1, ids_remove_padding=input_tensor1, step_use_cudagraph=True)
+
+        # Trigger capture
+        print_gpu_memory_use(0, "before capture")
+        _ = test_model1(ids_remove_padding=input_tensor1, forward_meta=forward_meta1)
+        print_gpu_memory_use(0, "after capture")
+        # Replay
+        output1 = test_model1(ids_remove_padding=input_tensor1, forward_meta=forward_meta1)
+        # Destroy
+        print_gpu_memory_use(0, "before destroy")
+        test_model1.clear_grpah_opt_backend()
+        print_gpu_memory_use(0, "after destroy")
+
+        # Trigger capture
+        print_gpu_memory_use(0, "before recapture")
+        _ = test_model1(ids_remove_padding=input_tensor1, forward_meta=forward_meta1)
+        print_gpu_memory_use(0, "after recapture")
+        # Replay
+        output2 = test_model1(ids_remove_padding=input_tensor1, forward_meta=forward_meta1)
+
+        # Correct output
+        output1_correct = test_model1.forward_correct(ids_remove_padding=input_tensor1, forward_meta=forward_meta1)
+
+        assert sum(output1 - output2) == 0
+        assert sum(output1_correct - output1) == 0
+
+
+if __name__ == "__main__":
+    unittest.main()
""" +import unittest + import paddle from fastdeploy.config import ( @@ -88,38 +90,45 @@ class TestModel1(paddle.nn.Layer): return sublayer2_output -def run_test_case(): - """Run test case""" - # Set FastDeploy config - graph_opt_config = GraphOptimizationConfig(args={}) - graph_opt_config.use_cudagraph = True - parallel_config = ParallelConfig(args={}) - parallel_config.max_num_seqs = 1 - cache_config = CacheConfig({}) - # Initialize cuda graph capture list - graph_opt_config._set_cudagraph_sizes(max_num_seqs=parallel_config.max_num_seqs) - graph_opt_config.init_with_cudagrpah_size(max_num_seqs=parallel_config.max_num_seqs) - fd_config = FDConfig( - graph_opt_config=graph_opt_config, parallel_config=parallel_config, cache_config=cache_config, test_mode=True - ) +class TestCUDAGrpahSpecDecode(unittest.TestCase): + """ + Test CUDAGraph Memory change + """ - # Run Test Case1 - test_model1 = TestModel1(fd_config=fd_config) - input_tensor1 = paddle.ones([1]) - forward_meta1 = ForwardMeta(input_ids=input_tensor1, ids_remove_padding=input_tensor1, step_use_cudagraph=True) + def test_cuda_graph_spec_decode(self): + """Run test case""" + graph_opt_config = GraphOptimizationConfig(args={}) + graph_opt_config.use_cudagraph = True + parallel_config = ParallelConfig(args={}) + parallel_config.max_num_seqs = 1 + cache_config = CacheConfig({}) + # Initialize cuda graph capture list + graph_opt_config._set_cudagraph_sizes(max_num_seqs=parallel_config.max_num_seqs) + graph_opt_config.init_with_cudagrpah_size(max_num_seqs=parallel_config.max_num_seqs) + fd_config = FDConfig( + graph_opt_config=graph_opt_config, + parallel_config=parallel_config, + cache_config=cache_config, + test_mode=True, + ) - # Triger Capture - _ = test_model1(ids_remove_padding=input_tensor1, forward_meta=forward_meta1) + # Run Test Case1 + test_model1 = TestModel1(fd_config=fd_config) + input_tensor1 = paddle.ones([32768]) + forward_meta1 = ForwardMeta(input_ids=input_tensor1, ids_remove_padding=input_tensor1, step_use_cudagraph=True) - # Reaplay - _ = test_model1(ids_remove_padding=input_tensor1, forward_meta=forward_meta1) - output1 = test_model1(ids_remove_padding=input_tensor1, forward_meta=forward_meta1) + # Triger Capture + _ = test_model1(ids_remove_padding=input_tensor1, forward_meta=forward_meta1) - # Corrent output - output1_correct = test_model1.forward_correct(ids_remove_padding=input_tensor1, forward_meta=forward_meta1) + # Reaplay + _ = test_model1(ids_remove_padding=input_tensor1, forward_meta=forward_meta1) + output1 = test_model1(ids_remove_padding=input_tensor1, forward_meta=forward_meta1) - assert output1 == output1_correct + # Corrent output + output1_correct = test_model1.forward_correct(ids_remove_padding=input_tensor1, forward_meta=forward_meta1) + + assert sum(output1 - output1_correct) == 0 if __name__ == "__main__": - run_test_case() + unittest.main()