mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-12-24 13:28:13 +08:00
[Graph Optimization] Add the CUDAGraph usage switch for Draft Model (#4601)
* Add a switch that controls whether the draft model uses CUDAGraph
* Set the switch to false by default
* Capture the draft model in CI
* Fix bug
This commit is contained in:
@@ -1884,35 +1884,37 @@ class GPUModelRunner(ModelRunnerBase):
|
||||
logger.info(
|
||||
f"Warm up the Target model with the num_tokens:{batch_size}, expected_decode_len:{1}"
|
||||
)
|
||||
# Capture Draft Model without bsz 1
|
||||
# NOTE(liujundong): expected_decode_len = 1, will affect mtp capture in cudagraph
|
||||
for batch_size in sorted(capture_sizes, reverse=True):
|
||||
if batch_size == 1:
|
||||
logger.info("Skip token_num = 1, when capture Draft model for mtp")
|
||||
else:
|
||||
assert batch_size % 2 == 0
|
||||
if self.graph_opt_config.draft_model_use_cudagraph:
|
||||
# Capture Draft Model without bsz 1
|
||||
# NOTE(liujundong): expected_decode_len = 1, will affect mtp capture in cudagraph
|
||||
for batch_size in sorted(capture_sizes, reverse=True):
|
||||
if batch_size == 1:
|
||||
logger.info("Skip token_num = 1, when capture Draft model for mtp")
|
||||
else:
|
||||
assert batch_size % 2 == 0
|
||||
self._dummy_run(
|
||||
num_tokens=self.scheduler_config.max_num_batched_tokens,
|
||||
batch_size=int(batch_size / 2),
|
||||
in_capturing=True,
|
||||
expected_decode_len=3,
|
||||
accept_all_drafts=True,
|
||||
)
|
||||
logger.info(
|
||||
f"Warm up the Draft model with the num_tokens:{batch_size}, expected_decode_len:{3}"
|
||||
)
|
||||
# Capture Draft Model with bsz 1
|
||||
if 1 in capture_sizes:
|
||||
self._dummy_run(
|
||||
num_tokens=self.scheduler_config.max_num_batched_tokens,
|
||||
batch_size=int(batch_size / 2),
|
||||
batch_size=int(1),
|
||||
in_capturing=True,
|
||||
expected_decode_len=3,
|
||||
accept_all_drafts=True,
|
||||
accept_all_drafts=False,
|
||||
reject_all_drafts=True,
|
||||
)
|
||||
logger.info(
|
||||
f"Warm up the Draft model with the num_tokens:{batch_size}, expected_decode_len:{3}"
|
||||
)
|
||||
# Capture Draft Model with bsz 1
|
||||
if 1 in capture_sizes:
|
||||
self._dummy_run(
|
||||
num_tokens=self.scheduler_config.max_num_batched_tokens,
|
||||
batch_size=int(1),
|
||||
in_capturing=True,
|
||||
expected_decode_len=3,
|
||||
accept_all_drafts=False,
|
||||
reject_all_drafts=True,
|
||||
)
|
||||
logger.info(f"Warm up the Draft model with the num_tokens:{batch_size}, expected_decode_len:{3}")
|
||||
|
||||
else:
|
||||
for batch_size in sorted(capture_sizes, reverse=True):
|
||||
self._dummy_run(
|
||||
|
||||
Reference in New Issue
Block a user