mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-04 16:22:57 +08:00
[Executor] CUDAGraph support RL training (#3265)
Some checks failed
CE Compile Job / ce_job_pre_check (push) Has been cancelled
CE Compile Job / print_ce_job_pre_check_outputs (push) Has been cancelled
CE Compile Job / FD-Clone-Linux (push) Has been cancelled
CE Compile Job / Show Code Archive Output (push) Has been cancelled
CE Compile Job / BUILD_SM8090 (push) Has been cancelled
CE Compile Job / BUILD_SM8689 (push) Has been cancelled
CE Compile Job / CE_UPLOAD (push) Has been cancelled
Deploy GitHub Pages / deploy (push) Has been cancelled
Publish Job / publish_pre_check (push) Has been cancelled
Publish Job / print_publish_pre_check_outputs (push) Has been cancelled
Publish Job / FD-Clone-Linux (push) Has been cancelled
Publish Job / Show Code Archive Output (push) Has been cancelled
Publish Job / BUILD_SM8090 (push) Has been cancelled
Publish Job / BUILD_SM8689 (push) Has been cancelled
Publish Job / PADDLE_PYPI_UPLOAD_8090 (push) Has been cancelled
Publish Job / PADDLE_PYPI_UPLOAD_8689 (push) Has been cancelled
Publish Job / Run FastDeploy Unit Tests and Coverage (push) Has been cancelled
Publish Job / Run FastDeploy LogProb Tests (push) Has been cancelled
Publish Job / Extracted partial CE model tasks to run in CI. (push) Has been cancelled
Publish Job / Run Base Tests (push) Has been cancelled
Publish Job / Run Accuracy Tests (push) Has been cancelled
* add clear graph opt backend
* cuda graph support rl
* add branch
* 1. fix dynamic_weight_manager bug 2. add clear api for CasualLM
* open test case
* fix typo
* update mkdocs.yaml
* [Docs] Update mkdocs.yml
* update test case
* use unittest in graph test case
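Taken together, these changes let an RL trainer swap model weights without leaking captured CUDA graphs. A minimal sketch of the intended driver loop, assuming a GPUModelRunner-like `runner` and the `pid` used by DynamicWeightManager (both names come from the diff below; the loop itself is illustrative, not FastDeploy code):

    # Illustrative RL weight-update cycle built on the new APIs (sketch).
    def rl_weight_update(runner, pid):
        # Teardown: frees weights and KV cache, then destroys captured CUDA graphs
        # (see GPUModelRunner.clear_parameters below).
        runner.clear_parameters(pid)
        # ... the training framework writes the new weights here ...
        # Bring-up: reloads weights, rebuilds the KV cache, then recaptures graphs
        # (see GPUModelRunner.update_parameters below).
        runner.update_parameters(pid)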
@@ -62,15 +62,7 @@ class CudaGraphPiecewiseBackend:
         self.warm_up_size = fd_config.graph_opt_config.cudagraph_num_of_warmups
         self.real_shape_to_captured_size = fd_config.graph_opt_config.real_shape_to_captured_size
 
-        # Runtime real shape -> ConcreteSizeEntry
-        self.concrete_size_entries: Dict[int, ConcreteSizeEntry] = {}
-
-        for shape in self.cudagraph_capture_sizes:
-            self.concrete_size_entries[shape] = ConcreteSizeEntry(runtime_bs=shape)
-
-        logger.info(
-            f"[CUDA GRAPH] CUDAGraph capture list {self.cudagraph_capture_sizes}, " "Created all real shape entry."
-        )
+        self._create_entry_dict()
 
     def __call__(self, **kwargs):
         # Get real shape(all num tokens)
@@ -128,3 +120,29 @@ class CudaGraphPiecewiseBackend:
             entry.cuda_graph.replay()
             logger.debug(f"[CUDA GRAPH] CUDAGraph replayed for real shape {padding_real_shape}")
         return entry.output_buffer
+
+    def _create_entry_dict(self):
+        """ """
+        # Runtime real shape -> ConcreteSizeEntry
+        self.concrete_size_entries: Dict[int, ConcreteSizeEntry] = {}
+
+        for shape in self.cudagraph_capture_sizes:
+            self.concrete_size_entries[shape] = ConcreteSizeEntry(runtime_bs=shape)
+
+        logger.info(
+            f"[CUDA GRAPH] CUDAGraph capture list {self.cudagraph_capture_sizes}, " "Created all real shape entry."
+        )
+
+    def clear_graph(self):
+        """ """
+        # Clear graphs
+        for id, entry in self.concrete_size_entries.items():
+            if entry.cuda_graph:
+                del entry.cuda_graph
+                logger.debug(f"[CUDA GRAPH] The CUDAGraph with shape {id} has been cleared.")
+
+        del self.concrete_size_entries
+        paddle.device.cuda.empty_cache()
+
+        # Create new entries
+        self._create_entry_dict()
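The clear path drops every captured graph, releases the entry mapping, returns cached blocks to the allocator, and immediately rebuilds empty entries so the next call can recapture. The same pattern in a self-contained sketch (ToyEntry stands in for ConcreteSizeEntry; assumes a CUDA build of Paddle):

    import paddle

    class ToyEntry:
        """Stand-in for ConcreteSizeEntry: one slot per capture shape."""
        def __init__(self, runtime_bs):
            self.runtime_bs = runtime_bs
            self.cuda_graph = None  # would hold the captured CUDA graph object

    entries = {bs: ToyEntry(bs) for bs in (1, 2, 4)}
    # Destroy: drop each captured graph, then the mapping itself.
    for bs, entry in entries.items():
        if entry.cuda_graph:
            del entry.cuda_graph
    del entries
    paddle.device.cuda.empty_cache()  # hand freed blocks back to the driver
    # Recreate: fresh, uncaptured entries, ready for the next capture pass.
    entries = {bs: ToyEntry(bs) for bs in (1, 2, 4)}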
@@ -91,3 +91,11 @@ class GraphOptWrapper:
 
     def __call__(self, **kwargs):
         return self.graph_opt_backend(**kwargs)
+
+    def clear_grpah_opt_backend(self, fd_config):
+        """ """
+        # TODO(gongshaotian): Resolve the bug of static graphs not being able to update weights
+        assert (
+            fd_config.graph_opt_config.graph_opt_level < 1
+        ), "Currently unable to update weights in static graph mode."
+        self.graph_opt_backend.clear_cudagraph_piecewise_backend()
@@ -144,3 +144,7 @@ class GraphOptBackend:
             return self.runnable(**kwargs)
         else:
            return self.cudagraph_piecewise_backend.__call__(**kwargs)
+
+    def clear_cudagraph_piecewise_backend(self):
+        """ """
+        self.cudagraph_piecewise_backend.clear_graph()
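With that in place, a clear request fans down a three-level delegation chain; schematically (method names exactly as in this diff):

    # model.clear_grpah_opt_backend()                          # per-model wrappers below
    #   -> GraphOptWrapper.clear_grpah_opt_backend(fd_config)  # asserts graph_opt_level < 1
    #     -> GraphOptBackend.clear_cudagraph_piecewise_backend()
    #       -> CudaGraphPiecewiseBackend.clear_graph()         # frees and recreates entries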
@@ -737,6 +737,10 @@ class DeepseekV3ForCausalLM(ModelForCasualLM):
         )
         return hidden_states
 
+    def clear_grpah_opt_backend(self):
+        """Clear graph optimization backend, the captured cuda graph will be cleaned"""
+        self.model.clear_grpah_opt_backend(fd_config=self.fd_config)
+
 
 class DeepSeekV3PretrainedModel(PretrainedModel):
     """
@@ -545,6 +545,10 @@ class Ernie4_5_MoeForCausalLM(ModelForCasualLM):
 
         return hidden_states
 
+    def clear_grpah_opt_backend(self):
+        """Clear graph optimization backend, the captured cuda graph will be cleaned"""
+        self.ernie.clear_grpah_opt_backend(fd_config=self.fd_config)
+
 
 class Ernie4_5_ForCausalLM(Ernie4_5_MoeForCausalLM):
     """
@@ -703,6 +703,10 @@ class Ernie4_5_VLMoeForConditionalGeneration(ModelForCasualLM):
 
         return hidden_states
 
+    def clear_grpah_opt_backend(self):
+        """Clear graph optimization backend, the captured cuda graph will be cleaned"""
+        self.ernie.clear_grpah_opt_backend(fd_config=self.fd_config)
+
 
 class Ernie4_5_VLPretrainedModel(PretrainedModel):
     """
@@ -390,6 +390,10 @@ class Qwen2ForCausalLM(ModelForCasualLM):
 
         return hidden_states
 
+    def clear_grpah_opt_backend(self):
+        """Clear graph optimization backend, the captured cuda graph will be cleaned"""
+        self.qwen2.clear_grpah_opt_backend(fd_config=self.fd_config)
+
 
 class Qwen2PretrainedModel(PretrainedModel):
     """
@@ -330,6 +330,10 @@ class Qwen3ForCausalLM(ModelForCasualLM):
 
         return hidden_states
 
+    def clear_grpah_opt_backend(self):
+        """Clear graph optimization backend, the captured cuda graph will be cleaned"""
+        self.model.clear_grpah_opt_backend(fd_config=self.fd_config)
+
 
 class Qwen3PretrainedModel(PretrainedModel):
     """
@@ -420,6 +420,10 @@ class Qwen3MoeForCausalLM(ModelForCasualLM):
 
         return hidden_states
 
+    def clear_grpah_opt_backend(self):
+        """Clear graph optimization backend, the captured cuda graph will be cleaned"""
+        self.model.clear_grpah_opt_backend(fd_config=self.fd_config)
+
 
 class Qwen3MoePretrainedModel(PretrainedModel):
     """
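Each ForCausalLM class above wires the same two lines to its own backbone attribute (self.model, self.ernie, or self.qwen2). Adding support to another model would look like this hypothetical sketch; MyModelForCausalLM and self.backbone are placeholders, not names from the diff:

    class MyModelForCausalLM(ModelForCasualLM):  # hypothetical example
        def clear_grpah_opt_backend(self):
            """Clear graph optimization backend, the captured cuda graph will be cleaned"""
            # self.backbone is whichever decorated layer owns the GraphOptWrapper.
            self.backbone.clear_grpah_opt_backend(fd_config=self.fd_config)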
@@ -497,11 +497,20 @@ def print_gpu_memory_use(gpu_id: int, title: str) -> None:
     meminfo = pynvml.nvmlDeviceGetMemoryInfo(handle)
     pynvml.nvmlShutdown()
 
+    paddle_max_reserved = paddle.device.cuda.max_memory_reserved(gpu_id)
+    paddle_max_allocated = paddle.device.cuda.max_memory_allocated(gpu_id)
+    paddle_reserved = paddle.device.cuda.memory_reserved(gpu_id)
+    paddle_allocated = paddle.device.cuda.memory_allocated(gpu_id)
+
     print(
         f"\n{title}:",
         f"\n\tDevice Total memory: {meminfo.total}",
         f"\n\tDevice Used memory: {meminfo.used}",
         f"\n\tDevice Free memory: {meminfo.free}",
+        f"\n\tPaddle max memory Reserved: {paddle_max_reserved}",
+        f"\n\tPaddle max memory Allocated: {paddle_max_allocated}",
+        f"\n\tPaddle memory Reserved: {paddle_reserved}",
+        f"\n\tPaddle memory Allocated: {paddle_allocated}",
     )
 
 
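The added Paddle allocator counters sit next to the NVML device totals, which makes capture and destroy visible in the logs; the new tests call the helper exactly this way. A usage sketch (model, x, and meta are placeholders):

    print_gpu_memory_use(0, "before capture")
    _ = model(ids_remove_padding=x, forward_meta=meta)  # first call triggers capture
    print_gpu_memory_use(0, "after capture")            # reserved/allocated go up
    model.clear_grpah_opt_backend()                     # model-level wrapper, as in the tests below
    print_gpu_memory_use(0, "after destroy")            # allocated should fall back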
@@ -876,6 +876,7 @@ class GPUModelRunner(ModelRunnerBase):
         # 1. Load original model
         model_loader = get_model_loader(load_config=self.fd_config.load_config)
         self.model = model_loader.load_model(fd_config=self.fd_config)
+
         # 1.1 Load RL dynamic model
         if self.fd_config.load_config.dynamic_load_weight:
             from fastdeploy.rl.dynamic_weight_manager import DynamicWeightManager
@@ -1595,12 +1596,22 @@ class GPUModelRunner(ModelRunnerBase):
         self.dynamic_weight_manager.clear_parameters(pid)
         self.clear_cache()
         paddle.device.cuda.empty_cache()
+
+        # Clear CudaGraph
+        if self.use_cudagraph:
+            self.model.clear_grpah_opt_backend()
+
         self.dynamic_weight_manager._log_memory("dynamic weight manager clear all memory")
 
     def update_parameters(self, pid):
         """Dynamic model loader used to update parameters for RL"""
         self.dynamic_weight_manager.update_parameters(pid)
         self.initialize_kv_cache()
+
+        # Recapture CudaGraph
+        if self.use_cudagraph:
+            self.capture_model()
+
         self.dynamic_weight_manager._log_memory("dynamic weight manager update all memory")
 
     def padding_cudagraph_inputs(self) -> None:
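Note the ordering: teardown releases weights and KV cache before destroying the graphs, and bring-up rebuilds the KV cache before capture_model() recaptures. A quick check that teardown actually returns memory, using the same counters reported above (sketch; runner and pid are placeholders):

    import paddle

    before = paddle.device.cuda.memory_allocated(0)
    runner.clear_parameters(pid)  # frees weights, KV cache, and captured graphs
    after = paddle.device.cuda.memory_allocated(0)
    assert after < before, f"expected teardown to free memory ({before} -> {after})"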
@@ -14,6 +14,8 @@
 # limitations under the License.
 """
 
+import unittest
+
 import paddle
 
 from fastdeploy.config import (
@@ -94,7 +96,7 @@ class TestModel1(paddle.nn.Layer):
         self.fd_config = fd_config
 
         self.sublayer1 = TestCase1SubLayer1(self.fd_config)
-        self.sublayer2 = TestCase1SubLayer2(self.fd_config)
+        self.sublayer2 = TestCase1SubLayer2(self.fd_config)  # Attention
         self.sublayer3 = TestCase1SubLayer3(self.fd_config)
 
         self.sublayer2_output_buffer = paddle.zeros([1])
@@ -106,9 +108,7 @@ class TestModel1(paddle.nn.Layer):
         sublayer1_output = self.sublayer1(ids_remove_padding=ids_remove_padding, forward_meta=sub_meta1)
 
         # sublayer2 does not use cuda graph
-        sub_meta2 = ForwardMeta(
-            input_ids=sublayer1_output, ids_remove_padding=sublayer1_output, step_use_cudagraph=False
-        )
+        sub_meta2 = ForwardMeta(input_ids=sublayer1_output, ids_remove_padding=sublayer1_output)
         sublayer2_output = self.sublayer2(ids_remove_padding=sublayer1_output, forward_meta=sub_meta2)
         self.sublayer2_output_buffer.copy_(sublayer2_output, False)
 
@@ -142,7 +142,12 @@ class TestModel1(paddle.nn.Layer):
         return sublayer3_output
 
 
-def run_test_case():
-    """Run test case"""
+class TestCUDAGrpahSubgraph(unittest.TestCase):
+    """
+    Test CUDAGraph Memory change
+    """
+
+    def test_cuda_graph_subgraph(self):
+        """Run test case"""
         # Set FastDeploy config
         graph_opt_config = GraphOptimizationConfig(args={})
@@ -154,12 +159,15 @@ def run_test_case():
         graph_opt_config._set_cudagraph_sizes(max_num_seqs=parallel_config.max_num_seqs)
         graph_opt_config.init_with_cudagrpah_size(max_num_seqs=parallel_config.max_num_seqs)
         fd_config = FDConfig(
-            graph_opt_config=graph_opt_config, parallel_config=parallel_config, cache_config=cache_config, test_mode=True
+            graph_opt_config=graph_opt_config,
+            parallel_config=parallel_config,
+            cache_config=cache_config,
+            test_mode=True,
         )
 
         # Run Test Case1
         test_model1 = TestModel1(fd_config=fd_config)
-        input_tensor1 = paddle.ones([1])
+        input_tensor1 = paddle.ones([32768])
         forward_meta1 = ForwardMeta(input_ids=input_tensor1, ids_remove_padding=input_tensor1, step_use_cudagraph=True)
 
         # Trigger Capture
@@ -172,8 +180,8 @@ def run_test_case():
         # Correct output
         output1_correct = test_model1.forward_correct(ids_remove_padding=input_tensor1, forward_meta=forward_meta1)
 
-        assert output1 == output1_correct
+        assert sum(output1 - output1_correct) == 0
 
 
 if __name__ == "__main__":
-    run_test_case()
+    unittest.main()
tests/graph_optimization/test_cuda_graph_recapture.py (new file, 131 lines)
@@ -0,0 +1,131 @@
+import unittest
+
+import paddle
+
+from fastdeploy.config import (
+    CacheConfig,
+    FDConfig,
+    GraphOptimizationConfig,
+    ParallelConfig,
+)
+from fastdeploy.model_executor.forward_meta import ForwardMeta
+from fastdeploy.model_executor.graph_optimization.decorator import (
+    support_graph_optimization,
+)
+from fastdeploy.utils import print_gpu_memory_use
+
+
+@support_graph_optimization
+class TestCase1SubLayer1(paddle.nn.Layer):
+    """Sub layer 1 of test case 1"""
+
+    def __init__(self, fd_config: FDConfig, **kwargs):
+        super().__init__()
+
+    def forward(self, ids_remove_padding, forward_meta: ForwardMeta):
+        """Sub layer1 forward pass"""
+        output = paddle.add(forward_meta.input_ids, forward_meta.input_ids)
+        return output
+
+    def forward_correct(self, ids_remove_padding, forward_meta: ForwardMeta):
+        """Sub layer1 correct forward pass"""
+        output = paddle.add(forward_meta.input_ids, forward_meta.input_ids)
+        return output
+
+
+class TestModel1(paddle.nn.Layer):
+    """Test Model"""
+
+    def __init__(self, fd_config: FDConfig, **kwargs):
+        super().__init__()
+        self.fd_config = fd_config
+
+        self.sublayer1 = TestCase1SubLayer1(self.fd_config)
+        sublayer1_copy = TestCase1SubLayer1(self.fd_config)
+        self.sublayer2 = sublayer1_copy
+
+    def forward(self, ids_remove_padding, forward_meta: ForwardMeta):
+        """Test model forward pass"""
+        # sublayer1 uses cuda graph
+        sub_meta1 = forward_meta
+        sublayer1_output = self.sublayer1(ids_remove_padding=ids_remove_padding, forward_meta=sub_meta1)
+
+        # sublayer2 uses cuda graph
+        sub_meta2 = ForwardMeta(
+            input_ids=sublayer1_output, ids_remove_padding=sublayer1_output, step_use_cudagraph=True
+        )
+        sublayer2_output = self.sublayer2(ids_remove_padding=sublayer1_output, forward_meta=sub_meta2)
+
+        return sublayer2_output
+
+    def forward_correct(self, ids_remove_padding, forward_meta: ForwardMeta):
+        """Test model correct forward pass"""
+        # sublayer1 does not use cuda graph
+        sub_meta1 = forward_meta
+        sublayer1_output = self.sublayer1.forward_correct(
+            ids_remove_padding=ids_remove_padding, forward_meta=sub_meta1
+        )
+
+        # sublayer2 does not use cuda graph
+        sub_meta2 = ForwardMeta(input_ids=sublayer1_output, ids_remove_padding=sublayer1_output)
+        sublayer2_output = self.sublayer2.forward_correct(ids_remove_padding=sublayer1_output, forward_meta=sub_meta2)
+
+        return sublayer2_output
+
+    def clear_grpah_opt_backend(self):
+        """ """
+        self.sublayer1.clear_grpah_opt_backend(fd_config=self.fd_config)
+        self.sublayer2.clear_grpah_opt_backend(fd_config=self.fd_config)
+
+
+class TestCUDAGrpahRecapture(unittest.TestCase):
+    """
+    Test CUDAGraph Memory change
+    """
+
+    def test_cuda_graph_recapture(self):
+        """Run test case"""
+        # Set FastDeploy config
+        graph_opt_config = GraphOptimizationConfig(args={})
+        graph_opt_config.use_cudagraph = True
+        parallel_config = ParallelConfig(args={})
+        cache_config = CacheConfig(args={})
+        parallel_config.max_num_seqs = 1
+        fd_config = FDConfig(
+            graph_opt_config=graph_opt_config, parallel_config=parallel_config, cache_config=cache_config
+        )
+
+        # Run Test Case1
+        test_model1 = TestModel1(fd_config=fd_config)
+        input_tensor1 = paddle.ones([32768])
+        forward_meta1 = ForwardMeta(input_ids=input_tensor1, ids_remove_padding=input_tensor1, step_use_cudagraph=True)
+
+        # Trigger Capture
+        print_gpu_memory_use(0, "before capture")
+        _ = test_model1(ids_remove_padding=input_tensor1, forward_meta=forward_meta1)
+        print_gpu_memory_use(0, "after capture")
+        # Replay
+        output1 = test_model1(ids_remove_padding=input_tensor1, forward_meta=forward_meta1)
+        # Destroy
+        print_gpu_memory_use(0, "before destroy")
+        test_model1.clear_grpah_opt_backend()
+        print_gpu_memory_use(0, "after destroy")
+
+        # Trigger Capture again
+        print_gpu_memory_use(0, "before recapture")
+        _ = test_model1(ids_remove_padding=input_tensor1, forward_meta=forward_meta1)
+        print_gpu_memory_use(0, "after recapture")
+        # Replay
+        output2 = test_model1(ids_remove_padding=input_tensor1, forward_meta=forward_meta1)
+
+        # Correct output
+        output1_correct = test_model1.forward_correct(ids_remove_padding=input_tensor1, forward_meta=forward_meta1)
+
+        assert sum(output1 - output2) == 0
+        assert sum(output1_correct - output1) == 0
+
+
+if __name__ == "__main__":
+    unittest.main()
@@ -14,6 +14,8 @@
 # limitations under the License.
 """
 
+import unittest
+
 import paddle
 
 from fastdeploy.config import (
@@ -88,9 +90,13 @@ class TestModel1(paddle.nn.Layer):
         return sublayer2_output
 
 
-def run_test_case():
-    """Run test case"""
+class TestCUDAGrpahSpecDecode(unittest.TestCase):
+    """
+    Test CUDAGraph Memory change
+    """
+
+    def test_cuda_graph_spec_decode(self):
+        """Run test case"""
         # Set FastDeploy config
         graph_opt_config = GraphOptimizationConfig(args={})
         graph_opt_config.use_cudagraph = True
         parallel_config = ParallelConfig(args={})
@@ -100,12 +106,15 @@ def run_test_case():
         graph_opt_config._set_cudagraph_sizes(max_num_seqs=parallel_config.max_num_seqs)
         graph_opt_config.init_with_cudagrpah_size(max_num_seqs=parallel_config.max_num_seqs)
         fd_config = FDConfig(
-            graph_opt_config=graph_opt_config, parallel_config=parallel_config, cache_config=cache_config, test_mode=True
+            graph_opt_config=graph_opt_config,
+            parallel_config=parallel_config,
+            cache_config=cache_config,
+            test_mode=True,
         )
 
         # Run Test Case1
         test_model1 = TestModel1(fd_config=fd_config)
-        input_tensor1 = paddle.ones([1])
+        input_tensor1 = paddle.ones([32768])
         forward_meta1 = ForwardMeta(input_ids=input_tensor1, ids_remove_padding=input_tensor1, step_use_cudagraph=True)
 
         # Trigger Capture
@@ -118,8 +127,8 @@ def run_test_case():
         # Correct output
         output1_correct = test_model1.forward_correct(ids_remove_padding=input_tensor1, forward_meta=forward_meta1)
 
-        assert output1 == output1_correct
+        assert sum(output1 - output1_correct) == 0
 
 
 if __name__ == "__main__":
-    run_test_case()
+    unittest.main()