mirror of https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-12-24 13:28:13 +08:00
support w4afp8 two stage (#5608)
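This change threads a configurable quantization group size, quant_group_size (default 128), from the MoE layer down through EPDecoderRunner into the DeepEP two-stage low-latency dispatch and combine calls, where it is forwarded as num_per_channel.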
@@ -307,6 +307,7 @@ class DeepEPEngine:
         topk_weights: paddle.Tensor,
         expertwise_scale,
         use_fp8: bool = False,
+        quant_group_size: int = 128,
     ):
         if self.deepep_engine is None:
             raise RuntimeError("DeepEP buffer not initialized!")

@@ -327,6 +328,7 @@ class DeepEPEngine:
             use_fp8=use_fp8,
             async_finish=False,
             return_recv_hook=True,
+            num_per_channel=quant_group_size,
         )

         return packed_recv_x, packed_recv_count, handle, dispatch_hook

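Note: the two hunks above are plumbing only; quant_group_size is forwarded to the DeepEP dispatch call as num_per_channel, i.e. the number of contiguous channels that share one activation quantization scale. The following is an illustrative NumPy sketch of groupwise quantization with a configurable group size, assuming the FP8 E4M3 maximum of 448; it is not the DeepEP kernel, and the helper name quantize_groupwise is made up for this example.

import numpy as np

def quantize_groupwise(x: np.ndarray, group_size: int = 128):
    # Split the hidden dimension into groups of group_size channels and compute
    # one scale per group so each group fits the FP8 E4M3 range (max 448).
    tokens, hidden = x.shape
    assert hidden % group_size == 0
    groups = x.reshape(tokens, hidden // group_size, group_size)
    scales = np.abs(groups).max(axis=-1, keepdims=True) / 448.0
    scales = np.maximum(scales, 1e-8)             # guard against all-zero groups
    q = np.clip(groups / scales, -448.0, 448.0)   # values now representable in FP8
    return q.reshape(tokens, hidden), scales.squeeze(-1)

x = np.random.randn(4, 512).astype(np.float32)
q, s = quantize_groupwise(x, group_size=128)      # 512 / 128 = 4 scales per token
print(q.shape, s.shape)                           # (4, 512) (4, 4)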
@@ -363,6 +365,7 @@ class DeepEPEngine:
         topk_idx: paddle.Tensor,
         topk_weights: paddle.Tensor,
         dispatch_use_fp8: bool,
+        quant_group_size: int,
         handle,
     ):
         if self.deepep_engine is None:

@@ -376,6 +379,7 @@ class DeepEPEngine:
             async_finish=False,
             dispatch_use_fp8=dispatch_use_fp8,
             return_recv_hook=True,
+            num_per_channel=quant_group_size,
         )
         return combined_hidden_states, combine_hook

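The combine-side hunks mirror the dispatch-side ones: the same group size must reach DeepEP as num_per_channel on both calls. A hypothetical end-to-end use of the two engine methods (engine, x, expertwise_scale, run_experts and the index/weight tensors are assumed to exist; this is not code from the commit) could look like:

group_size = 128  # must be identical for dispatch and combine

recv_x, recv_count, handle, dispatch_hook = engine.low_latency_dispatch_two_stage(
    x, topk_idx, topk_weights, expertwise_scale, True, group_size
)
if dispatch_hook is not None:
    dispatch_hook()

ffn_out = run_experts(recv_x, recv_count)  # expert FFN computation, outside this diff

combined, combine_hook = engine.low_latency_combine_two_stage(
    ffn_out, topk_idx, topk_weights, True, group_size, handle
)
if combine_hook is not None:
    combine_hook()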
@@ -663,21 +667,29 @@ class EPDecoderRunner(EPRunner):
         # just supports dispatch_use_fp8 = True now!
         assert use_fp8 is True
         recv_hidden_states, recv_expert_count, handle, dispatch_hook = (
-            self.ep_engine.low_latency_dispatch_two_stage(x, topk_idx, topk_weights, expertwise_scale, use_fp8)
+            self.ep_engine.low_latency_dispatch_two_stage(
+                x, topk_idx, topk_weights, expertwise_scale, use_fp8, quant_group_size
+            )
         )
         if dispatch_hook is not None:
             dispatch_hook()

         return recv_hidden_states, recv_expert_count, handle

-    def combine(self, ffn_out, topk_idx, topk_weights, handle):
+    def combine(self, ffn_out, topk_idx, topk_weights, handle, **kwargs):
+        quant_group_size = kwargs.get("quant_group_size", 128)
         if not self.use_internode_ll_two_stage:
             combined_hidden_states, combine_hook = self.ep_engine.low_latency_combine(
                 ffn_out, topk_idx, topk_weights, handle
             )
         else:
             combined_hidden_states, combine_hook = self.ep_engine.low_latency_combine_two_stage(
-                ffn_out, topk_idx, topk_weights, True, handle  # just supports dispatch_use_fp8 = True now!
+                ffn_out,
+                topk_idx,
+                topk_weights,
+                True,
+                quant_group_size,
+                handle,  # just supports dispatch_use_fp8 = True now!
             )
         if combine_hook is not None:
             combine_hook()

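The **kwargs form keeps EPDecoderRunner.combine backward compatible: call sites that do not pass quant_group_size fall back to 128 via kwargs.get. A minimal sketch of the two call shapes, assuming a runner instance:

out = runner.combine(ffn_out, topk_idx, topk_weights, handle)                       # implicit group size 128
out = runner.combine(ffn_out, topk_idx, topk_weights, handle, quant_group_size=64)  # explicit group size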
@@ -275,7 +275,9 @@ class CutlassMoEMethod(UnquantizedFusedMoEMethod):
         )

         # 4. EP combine
-        return self.ep_decoder_runner.combine(ffn_out, topk_idx, topk_weights, handle)
+        return self.ep_decoder_runner.combine(
+            ffn_out, topk_idx, topk_weights, handle, quant_group_size=quant_group_size
+        )

     def apply_tp(
         self,

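At this call site quant_group_size is a value already in scope in the surrounding method (its definition is not part of these hunks), presumably taken from the layer's w4afp8 quantization configuration; from here it flows through EPDecoderRunner.combine and the DeepEP engine down to the kernel as num_per_channel.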