Mirror of https://github.com/PaddlePaddle/FastDeploy.git (synced 2025-10-05 16:48:03 +08:00)
[Feat] support mixed ep (#2969)
* Support mixed ep
* fix comment
* fix comment
* update mixep
* fix conflict
* fix typo
* update
* fix typo
* fix code style
* fix conflict
@@ -18,7 +18,6 @@ from __future__ import annotations
 
 import os
 from dataclasses import dataclass, field
-from enum import Enum
 from typing import Literal, Optional
 
 from paddleformers.transformers.configuration_utils import PretrainedConfig
@@ -30,13 +29,24 @@ from fastdeploy.utils import get_logger
 logger = get_logger("config", "config.log")
 
 
-class MoEPhase(Enum):
+class MoEPhase:
     """
     The generation phase of the moe.
     """
 
-    PREFILL = 1
-    DECODER = 2
+    def __init__(self, phase="prefill"):
+        self._phase = phase
+
+    @property
+    def phase(self):
+        return self._phase
+
+    @phase.setter
+    def phase(self, value):
+        if value not in ["prefill", "decode"]:
+            raise ValueError(f"The moe_phase is invalid, only support prefill and decode, but got {value}")
+        else:
+            self._phase = value
 
 
 class ErnieArchitectures:
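
Illustrative usage (not part of the diff): the Enum is replaced by a small mutable holder so the phase can be flipped in place at runtime, assuming the class exactly as shown in the hunk above.

    from fastdeploy.config import MoEPhase

    phase = MoEPhase()            # defaults to "prefill"
    assert phase.phase == "prefill"

    phase.phase = "decode"        # mixed EP flips this per step instead of rebuilding config objects
    assert phase.phase == "decode"

    try:
        phase.phase = "extend"    # anything other than "prefill"/"decode" is rejected
    except ValueError:
        pass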
@@ -146,7 +156,7 @@ class ParallelConfig:
     ):
         self.sequence_parallel = False  # Whether to enable sequence parallelism.
         self.use_ep = False  # Whether to enable Expert Parallelism
-        self.moe_phase = MoEPhase.PREFILL  # Generation phase
+        self.moe_phase = MoEPhase("prefill")  # Generation phase
         self.msg_queue_id = 1  # mesage queue id
 
         self.tensor_parallel_rank = 0  # TP rank ID
@@ -210,11 +220,11 @@ class ParallelConfig:
                 setattr(self, key, value)
         self.use_ep = args["expert_parallel_size"] > 1
         if self.splitwise_role == "mixed":
-            self.moe_phase = MoEPhase.PREFILL
+            self.moe_phase = MoEPhase(phase="prefill")
         elif self.splitwise_role == "prefill":
-            self.moe_phase = MoEPhase.PREFILL
+            self.moe_phase = MoEPhase(phase="prefill")
         elif self.splitwise_role == "decode":
-            self.moe_phase = MoEPhase.DECODER
+            self.moe_phase = MoEPhase(phase="decode")
         else:
             raise NotImplementedError
 
@@ -43,9 +43,10 @@ class DeepEPEngine:
         num_max_dispatch_tokens_per_rank: int,
         hidden: int,
         num_experts: int,
-        moe_phase: MoEPhase,
         ep_size: int,
         ep_rank: int,
+        splitwise_role: str,
+        moe_phase: MoEPhase,
         async_finish: bool = False,
     ):
         """
@@ -65,24 +66,42 @@ class DeepEPEngine:
         self.hidden = hidden
         self.num_experts = num_experts
         self.num_local_experts = num_experts // ep_size
-        self.moe_phase = moe_phase
         self.async_finish = async_finish
 
-        self.deepep_engine = None
+        self.prefill_deepep_engine = None
+        self.decode_deepep_engine = None
 
-        if moe_phase == MoEPhase.DECODER:
-            logger.info("Initializing Low Latency Buffer")
+        self.ep_config = Config(24, 6, 256)
         self.num_max_dispatch_tokens_per_rank = num_max_dispatch_tokens_per_rank
+
+        # In mixed EP mode on a single node, we dynamically switch between
+        # high throughput and low latency modes.
+        if splitwise_role == "mixed":
+            # decode engine
+            logger.info("Initializing Low Latency Buffer")
             self.get_low_latency_buffer()
-        elif moe_phase == MoEPhase.PREFILL:
-            self.deepep_engine = deep_ep.Buffer(
+            # prefill engine
+            self.prefill_deepep_engine = deep_ep.Buffer(
+                self.group,
+                int(5e8),
+                0,
+                low_latency_mode=False,
+                num_qps_per_rank=1,
+            )
+        # In disaggregated mode on mutiple nodes, we either use
+        # high throughput mode or low latency mode.
+        else:
+            if moe_phase.phase == "decode":
+                logger.info("Initializing Low Latency Buffer")
+                self.get_low_latency_buffer()
+            elif moe_phase.phase == "prefill":
+                self.prefill_deepep_engine = deep_ep.Buffer(
                     self.group,
                     int(5e8),
                     0,
                     low_latency_mode=False,
                     num_qps_per_rank=1,
                 )
-            self.ep_config = Config(24, 6, 256)
             else:
                 raise ValueError(f"Unknown generation phase {moe_phase}")
 
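
Illustrative only (not from the commit): with both buffers held on one engine in mixed mode, per-step routing reduces to a phase check. The attribute names follow the diff above; the helper itself is hypothetical.

    def _pick_buffer(engine, moe_phase):
        # Prefill traffic uses the normal (high-throughput) buffer,
        # decode traffic uses the low-latency RDMA buffer.
        if moe_phase.phase == "prefill":
            return engine.prefill_deepep_engine
        return engine.decode_deepep_engine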
@@ -105,14 +124,14 @@ class DeepEPEngine:
         )
         # Allocate a buffer if not existed or not enough buffer size
         if (
-            self.deepep_engine is None
-            or self.deepep_engine.group != self.group
-            or not self.deepep_engine.low_latency_mode
-            or self.deepep_engine.num_rdma_bytes < num_rdma_bytes
+            self.decode_deepep_engine is None
+            or self.decode_deepep_engine.group != self.group
+            or not self.decode_deepep_engine.low_latency_mode
+            or self.decode_deepep_engine.num_rdma_bytes < num_rdma_bytes
         ):
             # NOTES: for best performance, the QP number **must** be equal to the number of the local experts
             assert self.num_experts % self.ep_size == 0
-            self.deepep_engine = deep_ep.Buffer(
+            self.decode_deepep_engine = deep_ep.Buffer(
                 self.group,
                 0,
                 num_rdma_bytes,
@@ -149,7 +168,7 @@ class DeepEPEngine:
             handle,
             _,
             dispatch_hook,
-        ) = self.deepep_engine.low_latency_dispatch(
+        ) = self.decode_deepep_engine.low_latency_dispatch(
             hidden_states,
             topk_idx,
             expertwise_scale,
@@ -174,8 +193,22 @@ class DeepEPEngine:
         Return:
             combined_hidden_states: [num_tokens, hidden]
         """
+        # TODO(@wufeisheng): Delete them when deepep in PaddlePaddle is fixed
+        (
+            src_info,
+            layout_range,
+            num_max_dispatch_tokens_per_rank,
+            num_experts,
+        ) = handle
+        handle = (
+            src_info,
+            layout_range,
+            num_max_dispatch_tokens_per_rank,
+            None,
+            num_experts,
+        )
 
-        combined_hidden_states, _, combine_hook = self.deepep_engine.low_latency_combine(
+        combined_hidden_states, _, combine_hook = self.decode_deepep_engine.low_latency_combine(
             hidden_states,
             topk_idx,
             topk_weights,
@@ -189,7 +222,7 @@ class DeepEPEngine:
         """
         clean_low_latency_buffer
         """
-        self.deepep_engine.clean_low_latency_buffer(
+        self.decode_deepep_engine.clean_low_latency_buffer(
             self.num_max_dispatch_tokens_per_rank, self.hidden, self.num_experts
         )
 
@@ -197,7 +230,11 @@ class DeepEPEngine:
         """
         barrier_all
         """
-        self.deepep_engine.barrier_all()
+        if self.prefill_deepep_engine is not None:
+            self.prefill_deepep_engine.barrier_all()
+
+        if self.decode_deepep_engine is not None:
+            self.decode_deepep_engine.barrier_all()
 
 
 class EPRunner:
@@ -210,6 +247,7 @@ class EPRunner:
         top_k: int,
         hidden: int,
         num_experts: int,
+        splitwise_role: str,
         moe_phase: MoEPhase,
         num_max_dispatch_tokens_per_rank: int = 1,
         ep_size: int = 1,
@@ -223,9 +261,10 @@ class EPRunner:
             num_max_dispatch_tokens_per_rank=num_max_dispatch_tokens_per_rank,
             hidden=hidden,
             num_experts=num_experts + redundant_experts_num,
-            moe_phase=moe_phase,
             ep_size=ep_size,
             ep_rank=ep_rank,
+            splitwise_role=splitwise_role,
+            moe_phase=moe_phase,
         )
 
     def moe_select(self, layer: nn.Layer, gate_out: paddle.Tensor):
@@ -286,15 +325,19 @@ class EPPrefillRunner(EPRunner):
         top_k: int,
         hidden: int,
         num_experts: int,
+        splitwise_role: str,
         ep_size: int = 1,
         ep_rank: int = 0,
         redundant_experts_num: int = 0,
+        moe_phase: MoEPhase = MoEPhase("prefill"),
     ):
         super().__init__(
             top_k,
             hidden,
             num_experts,
-            MoEPhase.PREFILL,
+            splitwise_role,
+            moe_phase,
+            num_max_dispatch_tokens_per_rank=256,
             ep_size=ep_size,
             ep_rank=ep_rank,
             redundant_experts_num=redundant_experts_num,
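
Hypothetical instantiation (not from the commit) showing the updated EPPrefillRunner surface; the module path and all numeric values below are assumptions for illustration only.

    from fastdeploy.model_executor.layers.moe.ep import EPPrefillRunner  # assumed path

    prefill_runner = EPPrefillRunner(
        top_k=8,                  # placeholder values, not taken from the commit
        hidden=7168,
        num_experts=64,
        splitwise_role="mixed",
        ep_size=8,
        ep_rank=0,
        redundant_experts_num=0,
    )  # moe_phase defaults to MoEPhase("prefill")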
@@ -314,7 +357,7 @@ class EPPrefillRunner(EPRunner):
             num_tokens_per_expert,
             is_token_in_rank,
             _,
-        ) = self.ep_engine.deepep_engine.get_dispatch_layout(topk_idx, self.num_experts)
+        ) = self.ep_engine.prefill_deepep_engine.get_dispatch_layout(topk_idx, self.num_experts)
 
         x_scale_tensor = kwargs.get("x_scale_tensor", None)
         dispatch_args = {
@@ -327,7 +370,7 @@ class EPPrefillRunner(EPRunner):
             "topk_idx": topk_idx,
             "topk_weights": topk_weights,
         }
-        return self.ep_engine.deepep_engine.dispatch(**dispatch_args)
+        return self.ep_engine.prefill_deepep_engine.dispatch(**dispatch_args)
 
     def combine(
         self,
@@ -342,7 +385,7 @@ class EPPrefillRunner(EPRunner):
             "async_finish": self.ep_engine.async_finish,
             "topk_weights": recv_topk_weights,
         }
-        fused_moe_out, _, _ = self.ep_engine.deepep_engine.combine(**combine_args)
+        fused_moe_out, _, _ = self.ep_engine.prefill_deepep_engine.combine(**combine_args)
 
         return fused_moe_out
 
@@ -357,16 +400,19 @@ class EPDecoderRunner(EPRunner):
         top_k: int,
         hidden: int,
         num_experts: int,
+        splitwise_role: str,
         num_max_dispatch_tokens_per_rank: int,
         ep_size: int = 1,
         ep_rank: int = 0,
         redundant_experts_num: int = 0,
+        moe_phase: MoEPhase = MoEPhase("decode"),
     ):
         super().__init__(
             top_k,
             hidden,
             num_experts,
-            MoEPhase.DECODER,
+            splitwise_role,
+            moe_phase,
             num_max_dispatch_tokens_per_rank,
             ep_size=ep_size,
             ep_rank=ep_rank,
@@ -19,8 +19,6 @@ from abc import abstractmethod
 import paddle
 from paddle import nn
 
-from fastdeploy.config import MoEPhase
-
 from ..quantization.quant_base import QuantMethodBase
 
 
@@ -45,25 +43,50 @@ class MoEMethodBase(QuantMethodBase):
         Init EP related module
         """
         if layer.ep_size > 1:
-            if layer.fd_config.parallel_config.moe_phase == MoEPhase.DECODER:
-                from .ep import EPDecoderRunner
+            if layer.fd_config.parallel_config.splitwise_role == "mixed":
+                from .ep import EPDecoderRunner, EPPrefillRunner
 
+                self.ep_prefill_runner = EPPrefillRunner(
+                    layer.top_k,
+                    layer.hidden_size,
+                    layer.num_experts,
+                    layer.fd_config.parallel_config.splitwise_role,
+                    layer.ep_size,
+                    layer.ep_rank,
+                    layer.fd_config.model_config.redundant_experts_num,
+                )
                 self.ep_decoder_runner = EPDecoderRunner(
                     layer.top_k,
                     layer.hidden_size,
                     layer.num_experts,
+                    layer.fd_config.parallel_config.splitwise_role,
                     layer.fd_config.model_config.num_max_dispatch_tokens_per_rank,
                     layer.ep_size,
                     layer.ep_rank,
                     layer.fd_config.model_config.redundant_experts_num,
                 )
             else:
+                if layer.fd_config.parallel_config.moe_phase == "prefill":
                     from .ep import EPPrefillRunner
 
                     self.ep_prefill_runner = EPPrefillRunner(
                         layer.top_k,
                         layer.hidden_size,
                         layer.num_experts,
+                        layer.fd_config.parallel_config.splitwise_role,
+                        layer.ep_size,
+                        layer.ep_rank,
+                        layer.fd_config.model_config.redundant_experts_num,
+                    )
+                else:
+                    from .ep import EPDecoderRunner
+
+                    self.ep_decoder_runner = EPDecoderRunner(
+                        layer.top_k,
+                        layer.hidden_size,
+                        layer.num_experts,
+                        layer.moe_config.num_max_dispatch_tokens_per_rank,
+                        layer.fd_config.parallel_config.splitwise_role,
                         layer.ep_size,
                         layer.ep_rank,
                         layer.fd_config.model_config.redundant_experts_num,
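
A compact way to read the branching above (informal sketch, not code from the commit; the real code keys off parallel_config as shown in the hunk):

    def runners_created(splitwise_role, phase):
        # Mixed deployments build both runners and switch between them per step;
        # disaggregated deployments build only the runner matching their fixed phase.
        if splitwise_role == "mixed":
            return ["ep_prefill_runner", "ep_decoder_runner"]
        return ["ep_prefill_runner"] if phase == "prefill" else ["ep_decoder_runner"]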
@@ -141,7 +164,7 @@ class MoEMethodBase(QuantMethodBase):
         Paddle Cutlass compute Fused MoE.
         """
         if layer.ep_size > 1:
-            if layer.fd_config.parallel_config.moe_phase == MoEPhase.PREFILL:
+            if layer.fd_config.parallel_config.moe_phase.phase == "prefill":
                 return self.apply_ep_prefill(layer, x, gate_out)
             else:
                 return self.apply_ep_decode(layer, x, gate_out)
@@ -794,6 +794,14 @@ class GPUModelRunner(ModelRunnerBase):
         # Update Batch type for cuda graph
         # TODO(gongshaotian): Use seq_lens_encoder to set is_decode_batch
         is_decode_batch = not ((self.share_inputs["seq_lens_this_time"] > 1).sum() > 0)
+
+        # mix ep in single node
+        if self.fd_config.parallel_config.use_ep and self.fd_config.parallel_config.splitwise_role == "mixed":
+            is_decode_batch_list = []
+            paddle.distributed.all_gather_object(is_decode_batch_list, is_decode_batch)
+            is_decode_batch = all(is_decode_batch_list)
+            self.fd_config.parallel_config.moe_phase.phase = "decode" if is_decode_batch else "prefill"
+
         self.forward_meta.step_use_cudagraph = self.use_cudagraph and is_decode_batch
 
         # Initialzie attention meta data
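
A concrete walk-through of the consensus rule above (illustrative, not part of the commit): each EP rank contributes its local batch type, and only a unanimous decode vote switches the group to the low-latency path.

    # Example with 4 EP ranks where rank 3 still has a prefill chunk this step.
    # After paddle.distributed.all_gather_object, every rank sees the same list:
    is_decode_batch_list = [True, True, True, False]
    is_decode_batch = all(is_decode_batch_list)   # False
    # -> moe_phase.phase stays "prefill", so every rank selects the same high-throughput buffer.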
@@ -1163,16 +1171,18 @@ class GPUModelRunner(ModelRunnerBase):
         We plan to replace it with 'ModelForwardBatch'.
             intermediate_tensors:
         """
-        # NOTE(wufeisheng): For Expert Parallelism
-        if not self.not_need_stop():
-            self._execute_empty_input()
-            return None
-
         # 1. Prepare inputs of model and sampler.
         skip_idx_list = self._get_skip_idx(model_forward_batch)
         self._prepare_inputs()
         self.sampler.pre_process(skip_idx_list)
 
+        # NOTE(wufeisheng): If `not_need_stop`` is False, it means the current worker is in an idle state.
+        # This logic is not used in TP (Tensor Parallelism) mode. However, in EP (Expert Parallelism) mode,
+        # when there is data on other runner, the current runner is required to execute part of the model.
+        if not self.not_need_stop():
+            self._execute_empty_input()
+            return None
+
         # 2. Padding inputs for cuda graph
         self.padding_cudagraph_inputs()
 