Mirror of https://github.com/PaddlePaddle/FastDeploy.git, synced 2025-10-05 08:37:06 +08:00
Supports DP+TP+EP hybrid parallel deployment strategy (#3489)
* Support DP+TP+EP hybrid parallel deployment strategy
* Support DP+TP+EP hybrid parallel deployment strategy
* fix conflict
* add moe_tp_ep function split_allgather_out
* del tp_group in moe_cutlass_backend
* for ci
* fix parallel_config for ci
* del log
@@ -1117,7 +1117,11 @@ class GPUModelRunner(ModelRunnerBase):
             )
             sampler_output = self.sampler(logits, self.sampling_metadata)
             if self.parallel_config.tensor_parallel_size > 1:
-                paddle.distributed.broadcast(sampler_output.sampled_token_ids, 0)
+                paddle.distributed.broadcast(
+                    sampler_output.sampled_token_ids,
+                    self.parallel_config.data_parallel_rank * self.parallel_config.tensor_parallel_size,
+                    group=self.parallel_config.tp_group,
+                )
         else:
             self.sampler(
                 logits,
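With a single tensor-parallel group, broadcasting from global rank 0 was always correct; once several data-parallel replicas run side by side, each replica has to broadcast from its own TP root, which is what the new source rank data_parallel_rank * tensor_parallel_size expresses. A minimal sketch of that arithmetic follows, assuming a DP-major rank layout with contiguous TP ranks per replica; the layout is inferred from the formula, not taken verbatim from the FastDeploy sources.

# Sketch of the rank layout implied by the new broadcast root (assumption:
# replica d owns the contiguous global ranks [d * tp_size, (d + 1) * tp_size)).
def tp_broadcast_root(data_parallel_rank: int, tensor_parallel_size: int) -> int:
    # Global rank that acts as the broadcast source inside one TP group.
    return data_parallel_rank * tensor_parallel_size

# With dp_size = 2 and tp_size = 4, replica 0 broadcasts from global rank 0 and
# replica 1 from global rank 4, instead of every replica pulling from rank 0.
assert tp_broadcast_root(0, 4) == 0
assert tp_broadcast_root(1, 4) == 4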
@@ -1127,10 +1131,26 @@ class GPUModelRunner(ModelRunnerBase):
             )
             sampler_output = None
             if self.parallel_config.tensor_parallel_size > 1:
-                paddle.distributed.broadcast(self.share_inputs["accept_tokens"], 0)
-                paddle.distributed.broadcast(self.share_inputs["accept_num"], 0)
-                paddle.distributed.broadcast(self.share_inputs["step_idx"], 0)
-                paddle.distributed.broadcast(self.share_inputs["stop_flags"], 0)
+                paddle.distributed.broadcast(
+                    self.share_inputs["accept_tokens"],
+                    self.parallel_config.data_parallel_rank * self.parallel_config.tensor_parallel_size,
+                    group=self.parallel_config.tp_group,
+                )
+                paddle.distributed.broadcast(
+                    self.share_inputs["accept_num"],
+                    self.parallel_config.data_parallel_rank * self.parallel_config.tensor_parallel_size,
+                    group=self.parallel_config.tp_group,
+                )
+                paddle.distributed.broadcast(
+                    self.share_inputs["step_idx"],
+                    self.parallel_config.data_parallel_rank * self.parallel_config.tensor_parallel_size,
+                    group=self.parallel_config.tp_group,
+                )
+                paddle.distributed.broadcast(
+                    self.share_inputs["stop_flags"],
+                    self.parallel_config.data_parallel_rank * self.parallel_config.tensor_parallel_size,
+                    group=self.parallel_config.tp_group,
+                )

         # 5. post process
         model_output_data = ModelOutputData(
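The speculative-decoding branch broadcasts four share_inputs buffers with the same source rank and the same TP group. A hedged, compact equivalent is sketched below; the buffer names come from the diff, but the loop form is only an illustration of the pattern, not the code that was merged.

# Compact, illustrative form of the four broadcasts above (not the merged code).
import paddle.distributed as dist

def broadcast_spec_decode_buffers(share_inputs, parallel_config):
    src = parallel_config.data_parallel_rank * parallel_config.tensor_parallel_size
    for name in ("accept_tokens", "accept_num", "step_idx", "stop_flags"):
        dist.broadcast(share_inputs[name], src, group=parallel_config.tp_group)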
@@ -1149,7 +1169,7 @@ class GPUModelRunner(ModelRunnerBase):
             is_block_step=self.share_inputs["is_block_step"],
             full_hidden_states=model_output,
             msg_queue_id=self.parallel_config.msg_queue_id,
-            mp_rank=self.local_rank,
+            mp_rank=self.parallel_config.tensor_parallel_rank,
             use_ep=self.parallel_config.use_ep,
             draft_tokens=(self.share_inputs["draft_tokens"] if self.speculative_decoding else None),
             actual_draft_token_num=(
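mp_rank now records the rank inside the TP group rather than local_rank. The two coincide when a node hosts exactly one TP group, but they diverge once several DP replicas share a node; the toy layout below shows the difference, again assuming a DP-major ordering and a single 8-GPU node (both assumptions, chosen only for illustration).

# Toy layout: one 8-GPU node, dp_size = 2, tp_size = 4, DP-major rank order.
tp_size, dp_size = 4, 2
for global_rank in range(tp_size * dp_size):
    local_rank = global_rank                      # single node: local == global
    data_parallel_rank = global_rank // tp_size
    tensor_parallel_rank = global_rank % tp_size
    print(global_rank, local_rank, data_parallel_rank, tensor_parallel_rank)
# Global rank 5 has local_rank 5 but tensor_parallel_rank 1, which is the value
# the output metadata should carry as mp_rank.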
@@ -1200,13 +1220,15 @@ class GPUModelRunner(ModelRunnerBase):
         """
         if not self.cache_config.enable_chunked_prefill:
             return
-        for task in tasks:
-            if task.get("prefill_chunk_info", None) is None:
-                continue
-
-            if task.chunk_idx > len(task.prefill_chunk_info):
-                continue
-            self.restore_chunked_prefill_request[task.request_id] = task
+        if tasks is not None:
+            for task in tasks:
+                if task.get("prefill_chunk_info", None) is None:
+                    continue
+
+                if task.chunk_idx > len(task.prefill_chunk_info):
+                    continue
+                self.restore_chunked_prefill_request[task.request_id] = task

         for id, task in list(self.restore_chunked_prefill_request.items()):
             idx = task.idx
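The chunked-prefill restore step previously iterated over tasks unconditionally; the new guard tolerates calls in which tasks is None, presumably on steps where a rank receives no new requests under the hybrid deployment (inferred from the guard, not verified against the scheduler). A toy reproduction with plain dicts standing in for the request objects:

# Plain-dict stand-in for the request objects, only to show the None guard.
def collect_restorable(tasks):
    restorable = {}
    if tasks is not None:                       # new guard: tolerate tasks=None
        for task in tasks:
            if task.get("prefill_chunk_info") is None:
                continue
            restorable[task["request_id"]] = task
    return restorable

assert collect_restorable(None) == {}           # old code: TypeError when iterating None
assert collect_restorable([{"request_id": "r0", "prefill_chunk_info": [0, 1]}])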
@@ -1384,7 +1406,11 @@ class GPUModelRunner(ModelRunnerBase):
                 skip_idx_list,
             )
             if self.parallel_config.tensor_parallel_size > 1:
-                paddle.distributed.broadcast(sampler_output.sampled_token_ids, 0)
+                paddle.distributed.broadcast(
+                    sampler_output.sampled_token_ids,
+                    self.parallel_config.data_parallel_rank * self.parallel_config.tensor_parallel_size,
+                    group=self.parallel_config.tp_group,
+                )

         else:
             self.sampler(
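All of these broadcasts are scoped to parallel_config.tp_group, one communication group per DP replica. The diff does not show how that group is built; the sketch below illustrates one conventional way to construct such groups with Paddle's collective API, under the same contiguous-ranks-per-replica assumption as above.

# One TP group per DP replica, built with Paddle's collective API. This is an
# illustration only; FastDeploy's actual tp_group construction is not in the diff.
import paddle.distributed as dist

def build_tp_groups(dp_size: int, tp_size: int):
    groups = []
    for d in range(dp_size):
        ranks = list(range(d * tp_size, (d + 1) * tp_size))
        # new_group is a collective call: every process executes all iterations.
        groups.append(dist.new_group(ranks))
    return groups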
@@ -1395,10 +1421,26 @@ class GPUModelRunner(ModelRunnerBase):
             )
             sampler_output = None
             if self.parallel_config.tensor_parallel_size > 1:
-                paddle.distributed.broadcast(self.share_inputs["accept_tokens"], 0)
-                paddle.distributed.broadcast(self.share_inputs["accept_num"], 0)
-                paddle.distributed.broadcast(self.share_inputs["step_idx"], 0)
-                paddle.distributed.broadcast(self.share_inputs["stop_flags"], 0)
+                paddle.distributed.broadcast(
+                    self.share_inputs["accept_tokens"],
+                    self.parallel_config.data_parallel_rank * self.parallel_config.tensor_parallel_size,
+                    group=self.parallel_config.tp_group,
+                )
+                paddle.distributed.broadcast(
+                    self.share_inputs["accept_num"],
+                    self.parallel_config.data_parallel_rank * self.parallel_config.tensor_parallel_size,
+                    group=self.parallel_config.tp_group,
+                )
+                paddle.distributed.broadcast(
+                    self.share_inputs["step_idx"],
+                    self.parallel_config.data_parallel_rank * self.parallel_config.tensor_parallel_size,
+                    group=self.parallel_config.tp_group,
+                )
+                paddle.distributed.broadcast(
+                    self.share_inputs["stop_flags"],
+                    self.parallel_config.data_parallel_rank * self.parallel_config.tensor_parallel_size,
+                    group=self.parallel_config.tp_group,
+                )

         # 5. Post Process
         model_output_data = ModelOutputData(
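As in the prefill path, the decode-path buffers are synchronized in place: paddle.distributed.broadcast overwrites the tensor on every rank other than the source, so the share_inputs entries need no reassignment afterwards. A shape-only sketch of that call, assuming an already-initialized process group:

# In-place synchronization of one buffer; receivers get their tensor overwritten.
import paddle
import paddle.distributed as dist

def sync_decode_buffer(buf: paddle.Tensor, data_parallel_rank: int,
                       tensor_parallel_size: int, tp_group) -> None:
    src = data_parallel_rank * tensor_parallel_size
    dist.broadcast(buf, src, group=tp_group)    # mutates buf on non-source ranks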
@@ -1417,7 +1459,7 @@ class GPUModelRunner(ModelRunnerBase):
             is_block_step=self.share_inputs["is_block_step"],
             full_hidden_states=model_output,
             msg_queue_id=self.parallel_config.msg_queue_id,
-            mp_rank=self.local_rank,
+            mp_rank=self.parallel_config.tensor_parallel_rank,
             use_ep=self.parallel_config.use_ep,
             draft_tokens=(self.share_inputs["draft_tokens"] if self.speculative_decoding else None),
             actual_draft_token_num=(
@@ -1454,7 +1496,7 @@ class GPUModelRunner(ModelRunnerBase):
         else:
             self.proposer.run(share_inputs=self.share_inputs)

-        # 7. Updata 'infer_seed' and step_cuda()
+        # 7. Update 'infer_seed' and step_cuda()
         self.share_inputs["infer_seed"].add_(self.infer_seed_increment)
         self.share_inputs["infer_seed"][:] %= self.MAX_INFER_SEED
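The step touched by the typo fix bumps every per-slot inference seed in place and wraps it with MAX_INFER_SEED so it stays bounded. A toy reproduction of the same arithmetic with plain Paddle tensors is shown below; the constant and increment values are placeholders, only the add-then-wrap pattern matches the source.

# Add-then-wrap, mirroring share_inputs["infer_seed"]; values are placeholders.
import paddle

MAX_INFER_SEED = 2**63 - 2            # placeholder bound, not the real constant
infer_seed = paddle.to_tensor([3, 7, 11], dtype="int64")
increment = paddle.to_tensor([1, 1, 1], dtype="int64")

infer_seed.add_(increment)            # in-place add, like .add_(self.infer_seed_increment)
infer_seed[:] %= MAX_INFER_SEED       # wrap so the seed never overflows int64
print(infer_seed.numpy())             # [ 4  8 12]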