From 04035e4ebf278cd07f01246cde66a7df395831d0 Mon Sep 17 00:00:00 2001
From: Sunny-bot1 <68891411+Sunny-bot1@users.noreply.github.com>
Date: Mon, 22 Dec 2025 15:13:05 +0800
Subject: [PATCH] support w4afp8 two stage (#5608)

---
 fastdeploy/model_executor/layers/moe/ep.py  | 18 +++++++++++++++---
 .../layers/moe/fused_moe_cutlass_backend.py |  4 +++-
 2 files changed, 18 insertions(+), 4 deletions(-)

diff --git a/fastdeploy/model_executor/layers/moe/ep.py b/fastdeploy/model_executor/layers/moe/ep.py
index 4065de51f..49c4b4f4a 100644
--- a/fastdeploy/model_executor/layers/moe/ep.py
+++ b/fastdeploy/model_executor/layers/moe/ep.py
@@ -307,6 +307,7 @@ class DeepEPEngine:
         topk_weights: paddle.Tensor,
         expertwise_scale,
         use_fp8: bool = False,
+        quant_group_size: int = 128,
     ):
         if self.deepep_engine is None:
             raise RuntimeError("DeepEP buffer not initialized!")
@@ -327,6 +328,7 @@ class DeepEPEngine:
             use_fp8=use_fp8,
             async_finish=False,
             return_recv_hook=True,
+            num_per_channel=quant_group_size,
         )
         return packed_recv_x, packed_recv_count, handle, dispatch_hook
 
@@ -363,6 +365,7 @@ class DeepEPEngine:
         topk_idx: paddle.Tensor,
         topk_weights: paddle.Tensor,
         dispatch_use_fp8: bool,
+        quant_group_size: int,
         handle,
     ):
         if self.deepep_engine is None:
@@ -376,6 +379,7 @@ class DeepEPEngine:
             async_finish=False,
             dispatch_use_fp8=dispatch_use_fp8,
             return_recv_hook=True,
+            num_per_channel=quant_group_size,
         )
         return combined_hidden_states, combine_hook
 
@@ -663,21 +667,29 @@ class EPDecoderRunner(EPRunner):
         # just supports dispatch_use_fp8 = True now!
         assert use_fp8 is True
         recv_hidden_states, recv_expert_count, handle, dispatch_hook = (
-            self.ep_engine.low_latency_dispatch_two_stage(x, topk_idx, topk_weights, expertwise_scale, use_fp8)
+            self.ep_engine.low_latency_dispatch_two_stage(
+                x, topk_idx, topk_weights, expertwise_scale, use_fp8, quant_group_size
+            )
         )
         if dispatch_hook is not None:
             dispatch_hook()
 
         return recv_hidden_states, recv_expert_count, handle
 
-    def combine(self, ffn_out, topk_idx, topk_weights, handle):
+    def combine(self, ffn_out, topk_idx, topk_weights, handle, **kwargs):
+        quant_group_size = kwargs.get("quant_group_size", 128)
         if not self.use_internode_ll_two_stage:
             combined_hidden_states, combine_hook = self.ep_engine.low_latency_combine(
                 ffn_out, topk_idx, topk_weights, handle
             )
         else:
             combined_hidden_states, combine_hook = self.ep_engine.low_latency_combine_two_stage(
-                ffn_out, topk_idx, topk_weights, True, handle  # just supports dispatch_use_fp8 = True now!
+                ffn_out,
+                topk_idx,
+                topk_weights,
+                True,
+                quant_group_size,
+                handle,  # just supports dispatch_use_fp8 = True now!
             )
         if combine_hook is not None:
             combine_hook()
diff --git a/fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py b/fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py
index 999599673..06e4591be 100644
--- a/fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py
+++ b/fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py
@@ -275,7 +275,9 @@ class CutlassMoEMethod(UnquantizedFusedMoEMethod):
         )
 
         # 4. EP combine
-        return self.ep_decoder_runner.combine(ffn_out, topk_idx, topk_weights, handle)
+        return self.ep_decoder_runner.combine(
+            ffn_out, topk_idx, topk_weights, handle, quant_group_size=quant_group_size
+        )
 
     def apply_tp(
         self,
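
Note (not part of the patch): a minimal, self-contained sketch of the parameter plumbing this change introduces. `EPDecoderRunner.combine` now accepts `quant_group_size` via `**kwargs` (default 128) and forwards it to the two-stage DeepEP combine, which the engine passes on as `num_per_channel`. The `_FakeEngine` and `_FakeDecoderRunner` classes below are illustrative stand-ins, not the real `DeepEPEngine`/`EPDecoderRunner`; only the parameter flow mirrors the patched code.

    class _FakeEngine:
        """Stand-in for DeepEPEngine; prints instead of launching kernels."""

        def low_latency_combine_two_stage(
            self, ffn_out, topk_idx, topk_weights, dispatch_use_fp8, quant_group_size, handle
        ):
            # The patched engine forwards this as `num_per_channel=quant_group_size`.
            print(f"two-stage combine, num_per_channel={quant_group_size}")
            return ffn_out, None  # (combined_hidden_states, combine_hook)

    class _FakeDecoderRunner:
        use_internode_ll_two_stage = True
        ep_engine = _FakeEngine()

        def combine(self, ffn_out, topk_idx, topk_weights, handle, **kwargs):
            # Same default as the patch (128); callers such as the cutlass
            # backend pass the actual w4afp8 group size explicitly.
            quant_group_size = kwargs.get("quant_group_size", 128)
            out, hook = self.ep_engine.low_latency_combine_two_stage(
                ffn_out, topk_idx, topk_weights, True, quant_group_size, handle
            )
            if hook is not None:
                hook()
            return out

    # Caller side, mirroring the fused_moe_cutlass_backend.py change:
    runner = _FakeDecoderRunner()
    runner.combine("ffn_out", "topk_idx", "topk_weights", handle=None, quant_group_size=64)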