diff --git a/fastdeploy/config.py b/fastdeploy/config.py index 287532922..ba9db317e 100644 --- a/fastdeploy/config.py +++ b/fastdeploy/config.py @@ -541,6 +541,8 @@ class ParallelConfig: self.engine_pid: Optional[int] = None # Do profile or not self.do_profile: bool = False + # Use internode_ll_two_stage or not + self.use_internode_ll_two_stage: bool = False self.pod_ip: str = None # enable the custom all-reduce kernel and fall back to NCCL(dist.all_reduce). diff --git a/fastdeploy/engine/args_utils.py b/fastdeploy/engine/args_utils.py index 2ce7482d0..ccd82d7b2 100644 --- a/fastdeploy/engine/args_utils.py +++ b/fastdeploy/engine/args_utils.py @@ -237,6 +237,11 @@ class EngineArgs: Flag to enable the custom all-reduce kernel. """ + use_internode_ll_two_stage: bool = False + """ + Flag to use the internode_ll_two_stage kernel. + """ + engine_worker_queue_port: str = "0" """ Port for worker queue communication. @@ -721,6 +726,12 @@ class EngineArgs: default=EngineArgs.disable_custom_all_reduce, help="Flag to disable custom all-reduce.", ) + parallel_group.add_argument( + "--use-internode-ll-two-stage", + action="store_true", + default=EngineArgs.use_internode_ll_two_stage, + help="Flag to use the internode_ll_two_stage kernel.", + ) parallel_group.add_argument( "--max-num-seqs", type=int, diff --git a/fastdeploy/engine/engine.py b/fastdeploy/engine/engine.py index e4c0b717a..1c6ec8787 100644 --- a/fastdeploy/engine/engine.py +++ b/fastdeploy/engine/engine.py @@ -563,6 +563,7 @@ class LLMEngine: "dynamic_load_weight": self.cfg.load_config.dynamic_load_weight, "disable_any_whitespace": self.cfg.structured_outputs_config.disable_any_whitespace, "disable_custom_all_reduce": self.cfg.parallel_config.disable_custom_all_reduce, + "use_internode_ll_two_stage": self.cfg.parallel_config.use_internode_ll_two_stage, "enable_logprob": self.cfg.model_config.enable_logprob, "lm_head_fp32": self.cfg.model_config.lm_head_fp32, } diff --git a/fastdeploy/model_executor/layers/moe/ep.py b/fastdeploy/model_executor/layers/moe/ep.py index eb5742d98..1733ad5de 100644 --- a/fastdeploy/model_executor/layers/moe/ep.py +++ b/fastdeploy/model_executor/layers/moe/ep.py @@ -64,6 +64,8 @@ class DeepEPBuffer: num_max_dispatch_tokens_per_rank: int, splitwise_role: str, moe_phase: MoEPhase, + use_internode_ll_two_stage: bool = False, + top_k: int = 8, ): self.group = group self.hidden_size = hidden_size @@ -72,6 +74,8 @@ class DeepEPBuffer: self.num_max_dispatch_tokens_per_rank = num_max_dispatch_tokens_per_rank self.splitwise_role = splitwise_role self.moe_phase = moe_phase + self.use_internode_ll_two_stage = use_internode_ll_two_stage + self.top_k = top_k self.deepep_buffer = None self.num_nvl_bytes = 0 @@ -95,12 +99,26 @@ class DeepEPBuffer: ) if self.splitwise_role == "mixed" or self.moe_phase.phase == "decode": - num_rdma_bytes = deep_ep.Buffer.get_low_latency_rdma_size_hint( - self.num_max_dispatch_tokens_per_rank, - self.hidden_size, - self.ep_size, - self.num_experts, - ) + if not self.use_internode_ll_two_stage: + num_rdma_bytes = deep_ep.Buffer.get_low_latency_rdma_size_hint( + self.num_max_dispatch_tokens_per_rank, + self.hidden_size, + self.ep_size, + self.num_experts, + ) + else: + num_rdma_bytes = deep_ep.Buffer.get_low_latency_rdma_size_hint_two_stage( + self.num_max_dispatch_tokens_per_rank, self.hidden_size, self.ep_size, self.num_experts, self.top_k + ) + num_nvl_bytes = deep_ep.Buffer.get_low_latency_nvl_size_hint_two_stage( + self.num_max_dispatch_tokens_per_rank, + self.hidden_size, + self.ep_size, + self.num_experts, + self.top_k, + True, # just supports dispatch_use_fp8 = True now! + ) + self.num_nvl_bytes = max(self.num_nvl_bytes, num_nvl_bytes) self.num_rdma_bytes = max(self.num_rdma_bytes, num_rdma_bytes) logger.info(f"DeepEP num nvl bytes : {self.num_nvl_bytes}, num rdma bytes : {self.num_rdma_bytes}") @@ -128,9 +146,9 @@ class DeepEPBuffer: self.deepep_buffer = deep_ep.Buffer( self.group, self.num_nvl_bytes, - 0, - low_latency_mode=False, - num_qps_per_rank=1, + self.num_rdma_bytes, + low_latency_mode=True, + num_qps_per_rank=24, ) else: raise ValueError(f"Unknown generation phase: {self.moe_phase.phase}") @@ -138,26 +156,18 @@ class DeepEPBuffer: logger.info("DeepEP buffer created successfully.") def _create_low_latency_buffer(self): - num_rdma_bytes = deep_ep.Buffer.get_low_latency_rdma_size_hint( - self.num_max_dispatch_tokens_per_rank, - self.hidden_size, - self.ep_size, - self.num_experts, - ) - - if ( - self.deepep_buffer is None - or self.deepep_buffer.group != self.group - or not self.deepep_buffer.low_latency_mode - or self.deepep_buffer.num_rdma_bytes < num_rdma_bytes - ): + if self.deepep_buffer is None: assert self.num_experts % self.ep_size == 0 + if self.ep_size // 8 > 1: + num_qps_per_rank_now = self.ep_size // 8 + else: + num_qps_per_rank_now = 1 self.deepep_buffer = deep_ep.Buffer( self.group, - 0, - num_rdma_bytes, + self.num_nvl_bytes, + self.num_rdma_bytes, low_latency_mode=True, - num_qps_per_rank=self.num_experts // self.ep_size, + num_qps_per_rank=num_qps_per_rank_now, ) def clear_buffer(self): @@ -172,11 +182,21 @@ class DeepEPBuffer: def clean_low_latency_buffer(self): if self.deepep_buffer is not None: - self.deepep_buffer.clean_low_latency_buffer( - self.num_max_dispatch_tokens_per_rank, - self.hidden_size, - self.num_experts, - ) + if not self.use_internode_ll_two_stage: + self.deepep_buffer.clean_low_latency_buffer( + self.num_max_dispatch_tokens_per_rank, + self.hidden_size, + self.num_experts, + ) + else: + self.deepep_buffer.clean_low_latency_two_stage_buffer( + self.num_max_dispatch_tokens_per_rank, + self.hidden_size, + self.num_experts, + self.top_k, + self.ep_size, + True, # just supports dispatch_use_fp8 = True now! + ) def barrier_all(self): if self.deepep_buffer is not None: @@ -201,6 +221,8 @@ class DeepEPEngine: moe_phase: MoEPhase, async_finish: bool = False, group=None, + use_internode_ll_two_stage: bool = False, + top_k: int = 8, ): if group is None: group = paddle.distributed.new_group(range(ep_size)) @@ -210,6 +232,7 @@ class DeepEPEngine: self.hidden_size = hidden_size self.num_experts = num_experts self.num_local_experts = num_experts // ep_size + self.top_k = top_k self.async_finish = async_finish self.ep_config = None @@ -227,6 +250,8 @@ class DeepEPEngine: num_max_dispatch_tokens_per_rank=num_max_dispatch_tokens_per_rank, splitwise_role=splitwise_role, moe_phase=moe_phase, + use_internode_ll_two_stage=use_internode_ll_two_stage, + top_k=self.top_k, ) self.buffer.create_buffer() @@ -273,6 +298,37 @@ class DeepEPEngine: return packed_recv_x, recv_expert_count, handle, dispatch_hook + def low_latency_dispatch_two_stage( + self, + hidden_states: paddle.Tensor, + topk_idx: paddle.Tensor, + topk_weights: paddle.Tensor, + expertwise_scale, + use_fp8: bool = False, + ): + if self.deepep_engine is None: + raise RuntimeError("DeepEP buffer not initialized!") + + ( + packed_recv_x, + packed_recv_count, + _, + handle, + _, + dispatch_hook, + ) = self.deepep_engine.low_latency_dispatch_two_stage( + hidden_states, + topk_idx, + topk_weights, + self.buffer.num_max_dispatch_tokens_per_rank, + self.num_experts, + use_fp8=use_fp8, + async_finish=False, + return_recv_hook=True, + ) + + return packed_recv_x, packed_recv_count, handle, dispatch_hook + def low_latency_combine( self, hidden_states: paddle.Tensor, @@ -299,6 +355,28 @@ class DeepEPEngine: ) return combined_hidden_states, combine_hook + def low_latency_combine_two_stage( + self, + hidden_states: paddle.Tensor, + topk_idx: paddle.Tensor, + topk_weights: paddle.Tensor, + dispatch_use_fp8: bool, + handle, + ): + if self.deepep_engine is None: + raise RuntimeError("DeepEP buffer not initialized!") + + combined_hidden_states, _, combine_hook = self.deepep_engine.low_latency_combine_two_stage( + hidden_states, + topk_idx, + topk_weights, + handle, + async_finish=False, + dispatch_use_fp8=dispatch_use_fp8, + return_recv_hook=True, + ) + return combined_hidden_states, combine_hook + def clean_low_latency_buffer(self): self.buffer.clean_low_latency_buffer() @@ -323,10 +401,12 @@ class EPRunner: ep_rank: int = 0, redundant_experts_num: int = 0, ep_group=None, + use_internode_ll_two_stage: bool = False, ): self.top_k = top_k self.num_experts = num_experts self.redundant_experts_num = redundant_experts_num + self.use_internode_ll_two_stage = use_internode_ll_two_stage self.ep_engine = DeepEPEngine( num_max_dispatch_tokens_per_rank=num_max_dispatch_tokens_per_rank, hidden_size=hidden_size, @@ -336,6 +416,8 @@ class EPRunner: splitwise_role=splitwise_role, moe_phase=moe_phase, group=ep_group, + use_internode_ll_two_stage=self.use_internode_ll_two_stage, + top_k=self.top_k, ) def moe_select(self, layer: nn.Layer, gate_out: paddle.Tensor): @@ -416,6 +498,7 @@ class EPPrefillRunner(EPRunner): redundant_experts_num: int = 0, moe_phase: MoEPhase = MoEPhase("prefill"), ep_group=None, + use_internode_ll_two_stage: bool = False, ): super().__init__( top_k, @@ -428,6 +511,7 @@ class EPPrefillRunner(EPRunner): ep_rank=ep_rank, redundant_experts_num=redundant_experts_num, ep_group=ep_group, + use_internode_ll_two_stage=use_internode_ll_two_stage, ) def dispatch( @@ -504,6 +588,7 @@ class EPDecoderRunner(EPRunner): redundant_experts_num: int = 0, ep_group=None, moe_phase: MoEPhase = MoEPhase("decode"), + use_internode_ll_two_stage: bool = False, ): super().__init__( top_k, @@ -516,6 +601,7 @@ class EPDecoderRunner(EPRunner): ep_rank=ep_rank, redundant_experts_num=redundant_experts_num, ep_group=ep_group, + use_internode_ll_two_stage=use_internode_ll_two_stage, ) def dispatch( @@ -529,18 +615,30 @@ class EPDecoderRunner(EPRunner): expertwise_scale = kwargs.get("expertwise_scale", None) use_fp8 = kwargs.get("use_fp8", False) - recv_hidden_states, recv_expert_count, handle, dispatch_hook = self.ep_engine.low_latency_dispatch( - x, topk_idx, expertwise_scale, use_fp8 - ) + if not self.use_internode_ll_two_stage: + recv_hidden_states, recv_expert_count, handle, dispatch_hook = self.ep_engine.low_latency_dispatch( + x, topk_idx, expertwise_scale, use_fp8 + ) + else: + # just supports dispatch_use_fp8 = True now! + assert use_fp8 is True + recv_hidden_states, recv_expert_count, handle, dispatch_hook = ( + self.ep_engine.low_latency_dispatch_two_stage(x, topk_idx, topk_weights, expertwise_scale, use_fp8) + ) if dispatch_hook is not None: dispatch_hook() return recv_hidden_states, recv_expert_count, handle def combine(self, ffn_out, topk_idx, topk_weights, handle): - combined_hidden_states, combine_hook = self.ep_engine.low_latency_combine( - ffn_out, topk_idx, topk_weights, handle - ) + if not self.use_internode_ll_two_stage: + combined_hidden_states, combine_hook = self.ep_engine.low_latency_combine( + ffn_out, topk_idx, topk_weights, handle + ) + else: + combined_hidden_states, combine_hook = self.ep_engine.low_latency_combine_two_stage( + ffn_out, topk_idx, topk_weights, True, handle # just supports dispatch_use_fp8 = True now! + ) if combine_hook is not None: combine_hook() diff --git a/fastdeploy/model_executor/layers/moe/fused_moe_backend_base.py b/fastdeploy/model_executor/layers/moe/fused_moe_backend_base.py index 8c35fa83a..41b06962d 100644 --- a/fastdeploy/model_executor/layers/moe/fused_moe_backend_base.py +++ b/fastdeploy/model_executor/layers/moe/fused_moe_backend_base.py @@ -92,8 +92,18 @@ class MoEMethodBase(QuantMethodBase): # for RL init model without deepep buff return else: - self.ep_prefill_runner = self.EPPrefillRunner(**common_args) - self.ep_decoder_runner = self.EPDecoderRunner(**common_args) + if current_platform.is_cuda(): + self.ep_prefill_runner = self.EPPrefillRunner( + **common_args, + use_internode_ll_two_stage=layer.fd_config.parallel_config.use_internode_ll_two_stage, + ) + self.ep_decoder_runner = self.EPDecoderRunner( + **common_args, + use_internode_ll_two_stage=layer.fd_config.parallel_config.use_internode_ll_two_stage, + ) + else: + self.ep_prefill_runner = self.EPPrefillRunner(**common_args) + self.ep_decoder_runner = self.EPDecoderRunner(**common_args) return # For non-mixed ep diff --git a/fastdeploy/worker/worker_process.py b/fastdeploy/worker/worker_process.py index 8cedfea37..fb55cddf2 100644 --- a/fastdeploy/worker/worker_process.py +++ b/fastdeploy/worker/worker_process.py @@ -636,6 +636,11 @@ def parse_args(): action="store_true", help="enable chunked prefill", ) + parser.add_argument( + "--use_internode_ll_two_stage", + action="store_true", + help="enable internode_ll_two_stage", + ) parser.add_argument( "--speculative_config", type=json.loads,