support w4afp8 two stage (#5608)

This commit is contained in:
Sunny-bot1
2025-12-22 15:13:05 +08:00
committed by GitHub
parent 40f3897a4e
commit 04035e4ebf
2 changed files with 18 additions and 4 deletions

View File

@@ -307,6 +307,7 @@ class DeepEPEngine:
topk_weights: paddle.Tensor,
expertwise_scale,
use_fp8: bool = False,
quant_group_size: int = 128,
):
if self.deepep_engine is None:
raise RuntimeError("DeepEP buffer not initialized!")
@@ -327,6 +328,7 @@ class DeepEPEngine:
use_fp8=use_fp8,
async_finish=False,
return_recv_hook=True,
num_per_channel=quant_group_size,
)
return packed_recv_x, packed_recv_count, handle, dispatch_hook
@@ -363,6 +365,7 @@ class DeepEPEngine:
topk_idx: paddle.Tensor,
topk_weights: paddle.Tensor,
dispatch_use_fp8: bool,
quant_group_size: int,
handle,
):
if self.deepep_engine is None:
@@ -376,6 +379,7 @@ class DeepEPEngine:
async_finish=False,
dispatch_use_fp8=dispatch_use_fp8,
return_recv_hook=True,
num_per_channel=quant_group_size,
)
return combined_hidden_states, combine_hook
@@ -663,21 +667,29 @@ class EPDecoderRunner(EPRunner):
# just supports dispatch_use_fp8 = True now!
assert use_fp8 is True
recv_hidden_states, recv_expert_count, handle, dispatch_hook = (
self.ep_engine.low_latency_dispatch_two_stage(x, topk_idx, topk_weights, expertwise_scale, use_fp8)
self.ep_engine.low_latency_dispatch_two_stage(
x, topk_idx, topk_weights, expertwise_scale, use_fp8, quant_group_size
)
)
if dispatch_hook is not None:
dispatch_hook()
return recv_hidden_states, recv_expert_count, handle
def combine(self, ffn_out, topk_idx, topk_weights, handle):
def combine(self, ffn_out, topk_idx, topk_weights, handle, **kwargs):
quant_group_size = kwargs.get("quant_group_size", 128)
if not self.use_internode_ll_two_stage:
combined_hidden_states, combine_hook = self.ep_engine.low_latency_combine(
ffn_out, topk_idx, topk_weights, handle
)
else:
combined_hidden_states, combine_hook = self.ep_engine.low_latency_combine_two_stage(
ffn_out, topk_idx, topk_weights, True, handle # just supports dispatch_use_fp8 = True now!
ffn_out,
topk_idx,
topk_weights,
True,
quant_group_size,
handle, # just supports dispatch_use_fp8 = True now!
)
if combine_hook is not None:
combine_hook()

View File

@@ -275,7 +275,9 @@ class CutlassMoEMethod(UnquantizedFusedMoEMethod):
)
# 4. EP combine
return self.ep_decoder_runner.combine(ffn_out, topk_idx, topk_weights, handle)
return self.ep_decoder_runner.combine(
ffn_out, topk_idx, topk_weights, handle, quant_group_size=quant_group_size
)
def apply_tp(
self,