mirror of https://github.com/PaddlePaddle/FastDeploy.git
[FDConfig]Remove max_num_batched_tokens/max_num_seqs in parallel config (#4116)
* remove max_num_batched_tokens in parallel config
* remove max_num_seqs
* update test case
* fix test
* fix

Co-authored-by: Jiang-Jia-Jun <163579578+Jiang-Jia-Jun@users.noreply.github.com>
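For readers skimming the diff below: the two batching limits move off the parallel config and onto the scheduler config. A minimal sketch of the intended split, where the field names max_num_seqs and max_num_batched_tokens come from the diff and everything else (defaults, remaining fields) is assumed, not taken from FastDeploy:

# Hypothetical sketch of the config split after #4116; not the real FDConfig classes.
from dataclasses import dataclass


@dataclass
class SchedulerConfig:
    # Batching limits now live with the scheduler.
    max_num_seqs: int = 256            # assumed default
    max_num_batched_tokens: int = 8192  # assumed default


@dataclass
class ParallelConfig:
    # Parallelism-only knobs remain; the two batching fields are gone.
    tensor_parallel_size: int = 1
    data_parallel_size: int = 1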
@@ -121,9 +121,9 @@ class MetaxModelRunner(ModelRunnerBase):
         self.sot_warmup_sizes = self.graph_opt_config.sot_warmup_sizes

         # Initialize share inputs
-        self._init_share_inputs(self.parallel_config.max_num_seqs)
+        self._init_share_inputs(self.scheduler_config.max_num_seqs)
         self.infer_seed_increment = paddle.full(
-            shape=[self.parallel_config.max_num_seqs, 1],
+            shape=[self.scheduler_config.max_num_seqs, 1],
             fill_value=4,
             dtype="int64",
         ).cpu()
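For context, infer_seed_increment is just a per-sequence column of constant increments, now sized by the scheduler's sequence cap rather than the parallel config. A standalone sketch, assuming a max_num_seqs of 4:

import paddle

max_num_seqs = 4  # assumed value; the runner reads this from scheduler_config
infer_seed_increment = paddle.full(
    shape=[max_num_seqs, 1],
    fill_value=4,
    dtype="int64",
).cpu()
print(infer_seed_increment.shape)  # [4, 1], every entry equal to 4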
@@ -995,7 +995,7 @@ class MetaxModelRunner(ModelRunnerBase):
         encoder_block_shape_q = 64
         decoder_block_shape_q = 16
         decoder_step_token_num = self.speculative_config.num_speculative_tokens + 1
-        decode_max_tile_size = self.parallel_config.max_num_seqs * np.ceil(
+        decode_max_tile_size = self.scheduler_config.max_num_seqs * np.ceil(
             (decoder_step_token_num * np.ceil(num_heads / self.model_config.kv_num_heads)) / decoder_block_shape_q
         )
         self.share_inputs["decoder_batch_ids"] = paddle.full([int(decode_max_tile_size)], 0, dtype="int32")
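The decode_max_tile_size expression is easiest to read as a worked example. Assuming 128 max sequences, no speculative tokens (so one decode-step token), and 32 query heads over 8 KV heads, with the decoder_block_shape_q of 16 from the hunk above, the buffer ends up one tile per sequence; all concrete numbers here are illustrative, not FastDeploy defaults:

import numpy as np

# Illustrative values; only decoder_block_shape_q = 16 is taken from the diff.
max_num_seqs = 128
num_speculative_tokens = 0
num_heads, kv_num_heads = 32, 8
decoder_block_shape_q = 16

decoder_step_token_num = num_speculative_tokens + 1   # 1
group_size = np.ceil(num_heads / kv_num_heads)        # 4 query heads per KV head
decode_max_tile_size = max_num_seqs * np.ceil(
    (decoder_step_token_num * group_size) / decoder_block_shape_q
)
print(int(decode_max_tile_size))  # 128 -> length of share_inputs["decoder_batch_ids"]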
@@ -1242,7 +1242,7 @@ class MetaxModelRunner(ModelRunnerBase):
         capture_sizes = self.cudagraph_capture_sizes.copy()
         for batch_size in sorted(capture_sizes, reverse=True):
             self._dummy_run(
-                num_tokens=self.parallel_config.max_num_batched_tokens,
+                num_tokens=self.scheduler_config.max_num_batched_tokens,
                 batch_size=batch_size,
                 in_capturing=True,
                 expected_decode_len=expected_decode_len,
@@ -1257,7 +1257,7 @@ class MetaxModelRunner(ModelRunnerBase):
         start_time = time.perf_counter()
         for batch_size in self.sot_warmup_sizes:
             self._dummy_run(
-                num_tokens=self.parallel_config.max_num_batched_tokens,
+                num_tokens=self.scheduler_config.max_num_batched_tokens,
                 batch_size=batch_size,
             )
             logger.info(f"SOT warmup the model with the batch size:{batch_size}")
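Both the CUDA graph capture loop (previous hunk) and the SOT warmup loop above now pull the same token budget from scheduler_config, so the limit is defined in one place. A standalone sketch of that call pattern, with a stub _dummy_run and assumed sizes; nothing here is the real runner:

# Stub standing in for MetaxModelRunner._dummy_run; the real method builds fake batches.
def _dummy_run(num_tokens: int, batch_size: int) -> None:
    print(f"warmup batch_size={batch_size} with token budget {num_tokens}")


class _SchedulerConfigStub:
    max_num_batched_tokens = 2048  # assumed value


scheduler_config = _SchedulerConfigStub()
for batch_size in [1, 8, 32]:  # assumed warmup sizes
    _dummy_run(
        num_tokens=scheduler_config.max_num_batched_tokens,
        batch_size=batch_size,
    )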
@@ -1489,8 +1489,8 @@ class MetaxModelRunner(ModelRunnerBase):

         # 2. Dummy run
         self._dummy_run(
-            num_tokens=self.parallel_config.max_num_batched_tokens,
-            batch_size=min(self.parallel_config.max_num_seqs, 3),
+            num_tokens=self.scheduler_config.max_num_batched_tokens,
+            batch_size=min(self.scheduler_config.max_num_seqs, 3),
         )

         # 3. gc
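Downstream code and tests that read these limits need the matching rename. A hedged before/after sketch of such a call-site migration, using a stand-in config object rather than a specific FastDeploy API:

from types import SimpleNamespace

# Stand-in FDConfig with only the fields used here; values are assumed.
fd_config = SimpleNamespace(
    scheduler_config=SimpleNamespace(max_num_seqs=256, max_num_batched_tokens=8192),
)

# Before #4116, call sites read fd_config.parallel_config.max_num_seqs and
# fd_config.parallel_config.max_num_batched_tokens; after the change they read:
max_seqs = fd_config.scheduler_config.max_num_seqs
max_tokens = fd_config.scheduler_config.max_num_batched_tokens
print(max_seqs, max_tokens)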