[Quantization][Cherry-Pick] Support w4afp8 moe weight offline permute & load and DeepEP low latency two stage(#5613 #5608) (#5677)

* support w4afp8 moe offline permute & load (#5613)

* support w4afp8 two stage (#5608)

* fix
This commit is contained in:
Sunny-bot1
2025-12-23 16:04:08 +08:00
committed by GitHub
parent 52280bee61
commit cfddec7142
5 changed files with 92 additions and 42 deletions

View File

@@ -307,6 +307,7 @@ class DeepEPEngine:
topk_weights: paddle.Tensor,
expertwise_scale,
use_fp8: bool = False,
quant_group_size: int = 128,
):
if self.deepep_engine is None:
raise RuntimeError("DeepEP buffer not initialized!")
@@ -327,6 +328,7 @@ class DeepEPEngine:
use_fp8=use_fp8,
async_finish=False,
return_recv_hook=True,
num_per_channel=quant_group_size,
)
return packed_recv_x, packed_recv_count, handle, dispatch_hook
@@ -363,6 +365,7 @@ class DeepEPEngine:
topk_idx: paddle.Tensor,
topk_weights: paddle.Tensor,
dispatch_use_fp8: bool,
quant_group_size: int,
handle,
):
if self.deepep_engine is None:
@@ -376,6 +379,7 @@ class DeepEPEngine:
async_finish=False,
dispatch_use_fp8=dispatch_use_fp8,
return_recv_hook=True,
num_per_channel=quant_group_size,
)
return combined_hidden_states, combine_hook
@@ -644,21 +648,29 @@ class EPDecoderRunner(EPRunner):
# just supports dispatch_use_fp8 = True now!
assert use_fp8 is True
recv_hidden_states, recv_expert_count, handle, dispatch_hook = (
self.ep_engine.low_latency_dispatch_two_stage(x, topk_idx, topk_weights, expertwise_scale, use_fp8)
self.ep_engine.low_latency_dispatch_two_stage(
x, topk_idx, topk_weights, expertwise_scale, use_fp8, quant_group_size
)
)
if dispatch_hook is not None:
dispatch_hook()
return recv_hidden_states, recv_expert_count, handle
def combine(self, ffn_out, topk_idx, topk_weights, handle):
def combine(self, ffn_out, topk_idx, topk_weights, handle, **kwargs):
quant_group_size = kwargs.get("quant_group_size", 128)
if not self.use_internode_ll_two_stage:
combined_hidden_states, combine_hook = self.ep_engine.low_latency_combine(
ffn_out, topk_idx, topk_weights, handle
)
else:
combined_hidden_states, combine_hook = self.ep_engine.low_latency_combine_two_stage(
ffn_out, topk_idx, topk_weights, True, handle # just supports dispatch_use_fp8 = True now!
ffn_out,
topk_idx,
topk_weights,
True,
quant_group_size,
handle, # just supports dispatch_use_fp8 = True now!
)
if combine_hook is not None:
combine_hook()

View File

@@ -263,7 +263,9 @@ class CutlassMoEMethod(UnquantizedFusedMoEMethod):
)
# 4. EP combine
return self.ep_decoder_runner.combine(ffn_out, topk_idx, topk_weights, handle)
return self.ep_decoder_runner.combine(
ffn_out, topk_idx, topk_weights, handle, quant_group_size=quant_group_size
)
def apply_tp(
self,
@@ -759,8 +761,9 @@ class CutlassW4AFP8MoEMethod(CutlassMoEMethod):
down_proj_expert_weight_key = layer.weight_key_map.get("down_proj_expert_weight_key", None)
up_gate_proj_expert_weight_scale_key = layer.weight_key_map.get("up_gate_proj_expert_weight_scale_key", None)
down_proj_expert_weight_scale_key = layer.weight_key_map.get("down_proj_expert_weight_scale_key", None)
up_gate_proj_expert_in_scale_key = layer.weight_key_map.get("up_gate_proj_expert_in_scale_key", None)
down_proj_expert_in_scale_key = layer.weight_key_map.get("down_proj_expert_in_scale_key", None)
if not layer.moe_quant_config.moe_dynamic_quant:
up_gate_proj_expert_in_scale_key = layer.weight_key_map.get("up_gate_proj_expert_in_scale_key", None)
down_proj_expert_in_scale_key = layer.weight_key_map.get("down_proj_expert_in_scale_key", None)
up_gate_proj_weights, down_proj_weights, logical_expert_ids, ep_rank_to_expert_id_list = (
layer.load_experts_weight(
@@ -780,7 +783,7 @@ class CutlassW4AFP8MoEMethod(CutlassMoEMethod):
if isinstance(state_dict, list):
state_dict = dict(state_dict)
if layer.ep_size > 1:
if layer.ep_size > 1 and not layer.moe_quant_config.moe_dynamic_quant:
for expert_idx in ep_rank_to_expert_id_list:
scale_tensor = get_tensor(
(
@@ -813,44 +816,54 @@ class CutlassW4AFP8MoEMethod(CutlassMoEMethod):
layer.fd_config.model_config.model,
)
)
up_gate_proj_in_scale.append(
get_tensor(
(
state_dict.pop(up_gate_proj_expert_in_scale_key.format(expert_idx))
if up_gate_proj_expert_in_scale_key.format(expert_idx) in state_dict
else up_gate_proj_expert_in_scale_key.format(expert_idx)
),
layer.fd_config.model_config.model,
if not layer.moe_quant_config.moe_dynamic_quant:
up_gate_proj_in_scale.append(
get_tensor(
(
state_dict.pop(up_gate_proj_expert_in_scale_key.format(expert_idx))
if up_gate_proj_expert_in_scale_key.format(expert_idx) in state_dict
else up_gate_proj_expert_in_scale_key.format(expert_idx)
),
layer.fd_config.model_config.model,
)
)
)
down_proj_in_scale.append(
get_tensor(
(
state_dict.pop(down_proj_expert_in_scale_key.format(expert_idx))
if down_proj_expert_in_scale_key.format(expert_idx) in state_dict
else down_proj_expert_in_scale_key.format(expert_idx)
),
layer.fd_config.model_config.model,
down_proj_in_scale.append(
get_tensor(
(
state_dict.pop(down_proj_expert_in_scale_key.format(expert_idx))
if down_proj_expert_in_scale_key.format(expert_idx) in state_dict
else down_proj_expert_in_scale_key.format(expert_idx)
),
layer.fd_config.model_config.model,
)
)
)
up_gate_proj_weight = paddle.stack(up_gate_proj_weights, axis=0)
down_proj_weight = paddle.stack(down_proj_weights, axis=0)
up_gate_proj_weight_scale = paddle.stack(up_gate_proj_weight_scale, axis=0)
down_proj_weight_scale = paddle.stack(down_proj_weight_scale, axis=0)
up_gate_proj_in_scale_all_experts = paddle.stack(up_gate_proj_in_scale_all_experts, axis=0).squeeze()
up_gate_proj_in_scale = paddle.stack(up_gate_proj_in_scale, axis=0).squeeze()
down_proj_in_scale = paddle.stack(down_proj_in_scale, axis=0).squeeze()
if not layer.moe_quant_config.moe_dynamic_quant:
up_gate_proj_in_scale_all_experts = paddle.stack(up_gate_proj_in_scale_all_experts, axis=0).squeeze()
up_gate_proj_in_scale = paddle.stack(up_gate_proj_in_scale, axis=0).squeeze()
down_proj_in_scale = paddle.stack(down_proj_in_scale, axis=0).squeeze()
name_tensor_map = {
"up_gate_proj_weight": up_gate_proj_weight,
"down_proj_weight": down_proj_weight,
"up_gate_proj_weight_scale": up_gate_proj_weight_scale,
"down_proj_weight_scale": down_proj_weight_scale,
"up_gate_proj_in_scale_all_experts": up_gate_proj_in_scale_all_experts,
"up_gate_proj_in_scale": up_gate_proj_in_scale,
"down_proj_in_scale": down_proj_in_scale,
}
if not layer.moe_quant_config.moe_dynamic_quant:
name_tensor_map = {
"up_gate_proj_weight": up_gate_proj_weight,
"down_proj_weight": down_proj_weight,
"up_gate_proj_weight_scale": up_gate_proj_weight_scale,
"down_proj_weight_scale": down_proj_weight_scale,
"up_gate_proj_in_scale_all_experts": up_gate_proj_in_scale_all_experts,
"up_gate_proj_in_scale": up_gate_proj_in_scale,
"down_proj_in_scale": down_proj_in_scale,
}
else:
name_tensor_map = {
"up_gate_proj_weight": up_gate_proj_weight,
"down_proj_weight": down_proj_weight,
"up_gate_proj_weight_scale": up_gate_proj_weight_scale,
"down_proj_weight_scale": down_proj_weight_scale,
}
for name, tensor in name_tensor_map.items():
getattr(layer, name).set_value(tensor)
@@ -1007,11 +1020,27 @@ class CutlassW4AFP8MoEMethod(CutlassMoEMethod):
# weight_scales
if layer.is_quantized:
if not layer.moe_quant_config.moe_dynamic_quant:
up_gate_proj_weight_scale_shape = [layer.num_local_experts, layer.moe_intermediate_size * 2]
down_proj_weight_scale_shape = [layer.num_local_experts, layer.hidden_size]
else:
up_gate_proj_weight_scale_shape = [
layer.num_local_experts,
layer.moe_intermediate_size * 2 // 128,
layer.hidden_size // 128,
128,
]
down_proj_weight_scale_shape = [
layer.num_local_experts,
layer.hidden_size // 128,
layer.moe_intermediate_size // 128,
128,
]
setattr(
layer,
"up_gate_proj_weight_scale",
layer.create_parameter(
shape=[layer.num_local_experts, layer.moe_intermediate_size * 2],
shape=up_gate_proj_weight_scale_shape,
dtype="float32",
default_initializer=paddle.nn.initializer.Constant(0),
),
@@ -1020,7 +1049,7 @@ class CutlassW4AFP8MoEMethod(CutlassMoEMethod):
layer,
"down_proj_weight_scale",
layer.create_parameter(
shape=[layer.num_local_experts, layer.hidden_size],
shape=down_proj_weight_scale_shape,
dtype="float32",
default_initializer=paddle.nn.initializer.Constant(0),
),

View File

@@ -204,8 +204,10 @@ class FusedMoE(nn.Layer):
self._dtype = self._helper.get_default_dtype()
self.weight_dtype = self._dtype
self.is_quantized = fd_config.model_config.is_quantized and not (
fd_config.quant_config.name() == "mix_quant" and fd_config.quant_config.moe_quant_type is None
self.is_moe_quantized = getattr(self.fd_config.model_config, "is_moe_quantized", False)
self.is_quantized = self.is_moe_quantized or (
fd_config.model_config.is_quantized
and not (fd_config.quant_config.name() == "mix_quant" and fd_config.quant_config.moe_quant_type is None)
)
moe_quant_config = fd_config.quant_config
self.moe_quant_config = moe_quant_config

View File

@@ -40,6 +40,7 @@ class MixQuantConfig(QuantConfigBase):
is_quantized: bool = False,
hadamard_block_size: int = 128,
moe_dynamic_quant: bool = False,
is_moe_quantized: bool = False,
) -> None:
super().__init__()
self.dense_quant_type = dense_quant_type
@@ -59,6 +60,7 @@ class MixQuantConfig(QuantConfigBase):
self.is_quantized = is_quantized
self.hadamard_block_size = hadamard_block_size
self.moe_dynamic_quant = moe_dynamic_quant
self.is_moe_quantized = is_moe_quantized
def name(self) -> str:
return "mix_quant"
@@ -76,6 +78,7 @@ class MixQuantConfig(QuantConfigBase):
config.get("is_quantized", False),
config.get("hadamard_block_size", 128),
config.get("moe_dynamic_quant", False),
config.get("is_moe_quantized", False),
)
def get_quant_method(self, layer) -> Optional[QuantMethodBase]:
@@ -102,7 +105,7 @@ class MixQuantConfig(QuantConfigBase):
.from_config(
{
"is_permuted": self.is_permuted,
"is_quantized": not self.is_checkpoint_bf16,
"is_quantized": not self.is_checkpoint_bf16 or self.is_moe_quantized,
"hadamard_block_size": self.hadamard_block_size,
}
)

View File

@@ -151,6 +151,7 @@ class _RecordingBuffer:
use_fp8,
async_finish,
return_recv_hook,
num_per_channel,
):
call = {
"hidden_states": hidden_states,
@@ -161,6 +162,7 @@ class _RecordingBuffer:
"use_fp8": use_fp8,
"async_finish": async_finish,
"return_recv_hook": return_recv_hook,
"num_per_channel": num_per_channel,
"hook_called": False,
}
self.low_latency_dispatch_two_stage_calls.append(call)
@@ -204,6 +206,7 @@ class _RecordingBuffer:
async_finish,
dispatch_use_fp8,
return_recv_hook,
num_per_channel,
):
call = {
"hidden_states": hidden_states,
@@ -213,6 +216,7 @@ class _RecordingBuffer:
"async_finish": async_finish,
"dispatch_use_fp8": dispatch_use_fp8,
"return_recv_hook": return_recv_hook,
"num_per_channel": num_per_channel,
"hook_called": False,
}
self.low_latency_combine_two_stage_calls.append(call)