supports internode_ll_two_stage (#4143)

* supports internode_ll_two_stage
Author: lzy
Date: 2025-09-22 14:55:06 +08:00
Committed by: GitHub
Parent: f75697c2d1
Commit: be98f6e950
6 changed files with 144 additions and 19 deletions

View File

@@ -294,6 +294,8 @@ class ParallelConfig:
         self.engine_pid: Optional[int] = None
         # Do profile or not
         self.do_profile: bool = False
+        # Use internode_ll_two_stage or not
+        self.use_internode_ll_two_stage: bool = False
         self.max_num_batched_tokens: int = 2048
         # splitwise role

View File

@@ -200,6 +200,11 @@ class EngineArgs:
     Flag to enable the custom all-reduce kernel.
     """
+    use_internode_ll_two_stage: bool = False
+    """
+    Flag to use the internode_ll_two_stage kernel.
+    """
     engine_worker_queue_port: str = "8002"
     """
     Port for worker queue communication.
@@ -629,6 +634,12 @@ class EngineArgs:
             default=EngineArgs.disable_custom_all_reduce,
             help="Flag to disable custom all-reduce.",
         )
+        parallel_group.add_argument(
+            "--use-internode-ll-two-stage",
+            action="store_true",
+            default=EngineArgs.use_internode_ll_two_stage,
+            help="Flag to use the internode_ll_two_stage kernel.",
+        )
         parallel_group.add_argument(
             "--max-num-seqs",
             type=int,
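
For readers skimming the diff: the new option follows the existing opt-in pattern, where the dataclass field provides the default and the store_true argument flips it on. A minimal, self-contained sketch of that behaviour (the _Args dataclass and the argument-group name below are illustrative stand-ins, not FastDeploy's actual EngineArgs):

import argparse
from dataclasses import dataclass


@dataclass
class _Args:  # illustrative stand-in for EngineArgs
    use_internode_ll_two_stage: bool = False


parser = argparse.ArgumentParser()
parallel_group = parser.add_argument_group("parallel")  # hypothetical group name
parallel_group.add_argument(
    "--use-internode-ll-two-stage",
    action="store_true",
    default=_Args.use_internode_ll_two_stage,
    help="Flag to use the internode_ll_two_stage kernel.",
)

# argparse maps the dashed flag onto the underscored attribute name.
assert parser.parse_args([]).use_internode_ll_two_stage is False
assert parser.parse_args(["--use-internode-ll-two-stage"]).use_internode_ll_two_stage is True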

View File

@@ -483,6 +483,7 @@ class LLMEngine:
"dynamic_load_weight": self.cfg.load_config.dynamic_load_weight, "dynamic_load_weight": self.cfg.load_config.dynamic_load_weight,
"disable_any_whitespace": self.cfg.disable_any_whitespace, "disable_any_whitespace": self.cfg.disable_any_whitespace,
"disable_custom_all_reduce": self.cfg.parallel_config.disable_custom_all_reduce, "disable_custom_all_reduce": self.cfg.parallel_config.disable_custom_all_reduce,
"use_internode_ll_two_stage": self.cfg.parallel_config.use_internode_ll_two_stage,
"enable_logprob": self.cfg.model_config.enable_logprob, "enable_logprob": self.cfg.model_config.enable_logprob,
"lm_head_fp32": self.cfg.model_config.lm_head_fp32, "lm_head_fp32": self.cfg.model_config.lm_head_fp32,
} }

View File

@@ -64,6 +64,8 @@ class DeepEPBuffer:
         num_max_dispatch_tokens_per_rank: int,
         splitwise_role: str,
         moe_phase: MoEPhase,
+        use_internode_ll_two_stage: bool = False,
+        top_k: int = 8,
     ):
         self.group = group
         self.hidden_size = hidden_size
@@ -72,6 +74,8 @@ class DeepEPBuffer:
         self.num_max_dispatch_tokens_per_rank = num_max_dispatch_tokens_per_rank
         self.splitwise_role = splitwise_role
         self.moe_phase = moe_phase
+        self.use_internode_ll_two_stage = use_internode_ll_two_stage
+        self.top_k = top_k
         self.deepep_buffer = None
         self.num_nvl_bytes = 0
@@ -95,12 +99,26 @@ class DeepEPBuffer:
             )
         if self.splitwise_role == "mixed" or self.moe_phase.phase == "decode":
+            if not self.use_internode_ll_two_stage:
                 num_rdma_bytes = deep_ep.Buffer.get_low_latency_rdma_size_hint(
                     self.num_max_dispatch_tokens_per_rank,
                     self.hidden_size,
                     self.ep_size,
                     self.num_experts,
                 )
+            else:
+                num_rdma_bytes = deep_ep.Buffer.get_low_latency_rdma_size_hint_two_stage(
+                    self.num_max_dispatch_tokens_per_rank, self.hidden_size, self.ep_size, self.num_experts, self.top_k
+                )
+                num_nvl_bytes = deep_ep.Buffer.get_low_latency_nvl_size_hint_two_stage(
+                    self.num_max_dispatch_tokens_per_rank,
+                    self.hidden_size,
+                    self.ep_size,
+                    self.num_experts,
+                    self.top_k,
+                    True,  # just supports dispatch_use_fp8 = True now!
+                )
+                self.num_nvl_bytes = max(self.num_nvl_bytes, num_nvl_bytes)
             self.num_rdma_bytes = max(self.num_rdma_bytes, num_rdma_bytes)
         logger.info(f"DeepEP num nvl bytes : {self.num_nvl_bytes}, num rdma bytes : {self.num_rdma_bytes}")
@@ -172,11 +190,21 @@ class DeepEPBuffer:
     def clean_low_latency_buffer(self):
         if self.deepep_buffer is not None:
+            if not self.use_internode_ll_two_stage:
                 self.deepep_buffer.clean_low_latency_buffer(
                     self.num_max_dispatch_tokens_per_rank,
                     self.hidden_size,
                     self.num_experts,
                 )
+            else:
+                self.deepep_buffer.clean_low_latency_two_stage_buffer(
+                    self.num_max_dispatch_tokens_per_rank,
+                    self.hidden_size,
+                    self.num_experts,
+                    self.top_k,
+                    self.ep_size,
+                    True,  # just supports dispatch_use_fp8 = True now!
+                )

     def barrier_all(self):
         if self.deepep_buffer is not None:
@@ -201,6 +229,8 @@ class DeepEPEngine:
         moe_phase: MoEPhase,
         async_finish: bool = False,
         group=None,
+        use_internode_ll_two_stage: bool = False,
+        top_k: int = 8,
     ):
         if group is None:
             group = paddle.distributed.new_group(range(ep_size))
@@ -210,10 +240,10 @@ class DeepEPEngine:
         self.hidden_size = hidden_size
         self.num_experts = num_experts
         self.num_local_experts = num_experts // ep_size
+        self.top_k = top_k
         self.async_finish = async_finish
-        from paddle.base.core import Config

-        self.ep_config = Config(24, 6, 256)
+        self.ep_config = None

         # Store phase and role for buffer management
         self._splitwise_role = splitwise_role
@@ -228,6 +258,8 @@ class DeepEPEngine:
             num_max_dispatch_tokens_per_rank=num_max_dispatch_tokens_per_rank,
             splitwise_role=splitwise_role,
             moe_phase=moe_phase,
+            use_internode_ll_two_stage=use_internode_ll_two_stage,
+            top_k=self.top_k,
         )

         self.buffer.create_buffer()
@@ -274,6 +306,37 @@ class DeepEPEngine:
         return packed_recv_x, recv_expert_count, handle, dispatch_hook

+    def low_latency_dispatch_two_stage(
+        self,
+        hidden_states: paddle.Tensor,
+        topk_idx: paddle.Tensor,
+        topk_weights: paddle.Tensor,
+        expertwise_scale,
+        use_fp8: bool = False,
+    ):
+        if self.deepep_engine is None:
+            raise RuntimeError("DeepEP buffer not initialized!")
+        (
+            packed_recv_x,
+            packed_recv_count,
+            _,
+            handle,
+            _,
+            dispatch_hook,
+        ) = self.deepep_engine.low_latency_dispatch_two_stage(
+            hidden_states,
+            topk_idx,
+            topk_weights,
+            self.buffer.num_max_dispatch_tokens_per_rank,
+            self.num_experts,
+            use_fp8=use_fp8,
+            async_finish=False,
+            return_recv_hook=True,
+        )
+        return packed_recv_x, packed_recv_count, handle, dispatch_hook

     def low_latency_combine(
         self,
         hidden_states: paddle.Tensor,
@@ -300,6 +363,28 @@ class DeepEPEngine:
         )

         return combined_hidden_states, combine_hook

+    def low_latency_combine_two_stage(
+        self,
+        hidden_states: paddle.Tensor,
+        topk_idx: paddle.Tensor,
+        topk_weights: paddle.Tensor,
+        dispatch_use_fp8: bool,
+        handle,
+    ):
+        if self.deepep_engine is None:
+            raise RuntimeError("DeepEP buffer not initialized!")
+        combined_hidden_states, _, combine_hook = self.deepep_engine.low_latency_combine_two_stage(
+            hidden_states,
+            topk_idx,
+            topk_weights,
+            handle,
+            async_finish=False,
+            dispatch_use_fp8=dispatch_use_fp8,
+            return_recv_hook=True,
+        )
+        return combined_hidden_states, combine_hook

     def clean_low_latency_buffer(self):
         self.buffer.clean_low_latency_buffer()
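
Both two-stage wrappers are called with return_recv_hook=True, so the caller gets the data plus a hook and decides when to block on the transfer. A hedged usage sketch of that flow around an engine like the one above (run_experts is a stand-in for the per-expert FFN and is not part of this diff):

def moe_decode_step_two_stage(engine, hidden_states, topk_idx, topk_weights, run_experts):
    # Dispatch tokens to their experts; the returned hook completes the receive.
    recv_x, recv_count, handle, dispatch_hook = engine.low_latency_dispatch_two_stage(
        hidden_states, topk_idx, topk_weights, expertwise_scale=None, use_fp8=True
    )
    if dispatch_hook is not None:
        dispatch_hook()  # wait until the dispatched tokens have landed

    ffn_out = run_experts(recv_x, recv_count)  # per-expert computation

    # Route expert outputs back to their source ranks, then wait on the combine hook.
    combined, combine_hook = engine.low_latency_combine_two_stage(
        ffn_out, topk_idx, topk_weights, dispatch_use_fp8=True, handle=handle
    )
    if combine_hook is not None:
        combine_hook()
    return combined
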
@@ -324,10 +409,12 @@ class EPRunner:
         ep_rank: int = 0,
         redundant_experts_num: int = 0,
         ep_group=None,
+        use_internode_ll_two_stage: bool = False,
     ):
         self.top_k = top_k
         self.num_experts = num_experts
         self.redundant_experts_num = redundant_experts_num
+        self.use_internode_ll_two_stage = use_internode_ll_two_stage
         self.ep_engine = DeepEPEngine(
             num_max_dispatch_tokens_per_rank=num_max_dispatch_tokens_per_rank,
             hidden_size=hidden_size,
@@ -337,6 +424,8 @@ class EPRunner:
             splitwise_role=splitwise_role,
             moe_phase=moe_phase,
             group=ep_group,
+            use_internode_ll_two_stage=self.use_internode_ll_two_stage,
+            top_k=self.top_k,
         )

     def moe_select(self, layer: nn.Layer, gate_out: paddle.Tensor):
@@ -416,6 +505,7 @@ class EPPrefillRunner(EPRunner):
         redundant_experts_num: int = 0,
         moe_phase: MoEPhase = MoEPhase("prefill"),
         ep_group=None,
+        use_internode_ll_two_stage: bool = False,
     ):
         super().__init__(
             top_k,
@@ -428,6 +518,7 @@ class EPPrefillRunner(EPRunner):
             ep_rank=ep_rank,
             redundant_experts_num=redundant_experts_num,
             ep_group=ep_group,
+            use_internode_ll_two_stage=use_internode_ll_two_stage,
         )

     def dispatch(
@@ -502,6 +593,7 @@ class EPDecoderRunner(EPRunner):
         redundant_experts_num: int = 0,
         ep_group=None,
         moe_phase: MoEPhase = MoEPhase("decode"),
+        use_internode_ll_two_stage: bool = False,
     ):
         super().__init__(
             top_k,
@@ -514,6 +606,7 @@ class EPDecoderRunner(EPRunner):
             ep_rank=ep_rank,
             redundant_experts_num=redundant_experts_num,
             ep_group=ep_group,
+            use_internode_ll_two_stage=use_internode_ll_two_stage,
         )

     def dispatch(
@@ -527,18 +620,30 @@ class EPDecoderRunner(EPRunner):
         expertwise_scale = kwargs.get("expertwise_scale", None)
         use_fp8 = kwargs.get("use_fp8", False)
+        if not self.use_internode_ll_two_stage:
             recv_hidden_states, recv_expert_count, handle, dispatch_hook = self.ep_engine.low_latency_dispatch(
                 x, topk_idx, expertwise_scale, use_fp8
             )
+        else:
+            # just supports dispatch_use_fp8 = True now!
+            assert use_fp8 is True
+            recv_hidden_states, recv_expert_count, handle, dispatch_hook = (
+                self.ep_engine.low_latency_dispatch_two_stage(x, topk_idx, topk_weights, expertwise_scale, use_fp8)
+            )
         if dispatch_hook is not None:
             dispatch_hook()

         return recv_hidden_states, recv_expert_count, handle

     def combine(self, ffn_out, topk_idx, topk_weights, handle):
+        if not self.use_internode_ll_two_stage:
             combined_hidden_states, combine_hook = self.ep_engine.low_latency_combine(
                 ffn_out, topk_idx, topk_weights, handle
             )
+        else:
+            combined_hidden_states, combine_hook = self.ep_engine.low_latency_combine_two_stage(
+                ffn_out, topk_idx, topk_weights, True, handle  # just supports dispatch_use_fp8 = True now!
+            )
         if combine_hook is not None:
             combine_hook()
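
Condensing the decoder-runner branch above: call sites keep using dispatch()/combine() unchanged, and the flag only selects the kernel family underneath, with the extra constraints that two-stage dispatch currently requires FP8 and consumes topk_weights already at dispatch time. A minimal sketch of that contract (the helper name is illustrative):

def dispatch_for_decode(ep_engine, use_internode_ll_two_stage, x, topk_idx,
                        topk_weights, expertwise_scale=None, use_fp8=False):
    if not use_internode_ll_two_stage:
        recv_x, recv_count, handle, hook = ep_engine.low_latency_dispatch(
            x, topk_idx, expertwise_scale, use_fp8
        )
    else:
        # Two-stage dispatch only supports dispatch_use_fp8=True for now
        # and needs the routing weights up front.
        assert use_fp8 is True
        recv_x, recv_count, handle, hook = ep_engine.low_latency_dispatch_two_stage(
            x, topk_idx, topk_weights, expertwise_scale, use_fp8
        )
    if hook is not None:
        hook()
    return recv_x, recv_count, handle

On the combine side the same choice must be echoed back, which is why the runner passes a literal True for dispatch_use_fp8 in the two-stage branch.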

View File

@@ -64,6 +64,7 @@ class MoEMethodBase(QuantMethodBase):
"ep_rank": layer.ep_rank, "ep_rank": layer.ep_rank,
"redundant_experts_num": layer.fd_config.model_config.redundant_experts_num, "redundant_experts_num": layer.fd_config.model_config.redundant_experts_num,
"ep_group": layer.fd_config.parallel_config.ep_group, "ep_group": layer.fd_config.parallel_config.ep_group,
"use_internode_ll_two_stage": layer.fd_config.parallel_config.use_internode_ll_two_stage,
} }
config = layer.fd_config config = layer.fd_config

View File

@@ -506,6 +506,11 @@ def parse_args():
action="store_true", action="store_true",
help="enable chunked prefill", help="enable chunked prefill",
) )
parser.add_argument(
"--use_internode_ll_two_stage",
action="store_true",
help="enable internode_ll_two_stage",
)
parser.add_argument( parser.add_argument(
"--speculative_config", "--speculative_config",
type=json.loads, type=json.loads,