mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-12-24 13:28:13 +08:00
small change in test_fusedmoe.py (#4538)
This commit is contained in:
@@ -185,12 +185,18 @@ class TestFusedMoE(unittest.TestCase):
|
||||
|
||||
cache_hidden_states[idx] = paddle.rand((num_tokens, self.model_config.hidden_size), dtype=paddle.bfloat16)
|
||||
|
||||
num_layers = 80
|
||||
|
||||
def fake_model_run():
|
||||
for _ in range(num_layers):
|
||||
out = fused_moe.fused_moe(cache_hidden_states[idx], gating)
|
||||
|
||||
return out
|
||||
|
||||
moe_cuda_graphs[idx] = graphs.CUDAGraph()
|
||||
moe_cuda_graphs[idx].capture_begin()
|
||||
|
||||
num_layers = 80
|
||||
for _ in range(num_layers):
|
||||
out = fused_moe.fused_moe(cache_hidden_states[idx], gating)
|
||||
fake_model_run()
|
||||
|
||||
moe_cuda_graphs[idx].capture_end()
|
||||
|
||||
@@ -213,7 +219,6 @@ class TestFusedMoE(unittest.TestCase):
|
||||
print(times[-1], round(GB / times_s, 1))
|
||||
|
||||
shutil.rmtree(self.model_name_or_path)
|
||||
return out
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
Reference in New Issue
Block a user