Mirror of https://github.com/PaddlePaddle/FastDeploy.git
[Quantization][Cherry-Pick] Support w4afp8 moe weight offline permute & load and DeepEP low latency two stage (#5613 #5608) (#5677)

* support w4afp8 moe offline permute & load (#5613)
* support w4afp8 two stage (#5608)
* fix

In short: the change threads a `quant_group_size` argument through the DeepEP low-latency two-stage dispatch/combine path (surfaced to DeepEP as `num_per_channel`), teaches the w4afp8 MoE loader to skip activation in-scale tensors and allocate block-shaped weight scales when `moe_dynamic_quant` is enabled, and adds an `is_moe_quantized` flag for checkpoints whose MoE weights are already quantized.
@@ -307,6 +307,7 @@ class DeepEPEngine:
         topk_weights: paddle.Tensor,
         expertwise_scale,
         use_fp8: bool = False,
+        quant_group_size: int = 128,
     ):
         if self.deepep_engine is None:
             raise RuntimeError("DeepEP buffer not initialized!")
@@ -327,6 +328,7 @@ class DeepEPEngine:
             use_fp8=use_fp8,
             async_finish=False,
             return_recv_hook=True,
+            num_per_channel=quant_group_size,
         )

         return packed_recv_x, packed_recv_count, handle, dispatch_hook
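The new `quant_group_size` parameter defaults to 128 and is forwarded to DeepEP as `num_per_channel`: it sets how many consecutive activation channels share one FP8 scale during the two-stage dispatch. A usage sketch, assuming an already-initialized `DeepEPEngine` and routing tensors (none of the setup is shown in this diff):

```python
# Sketch only: `engine`, the activation tensor and the routing tensors are assumed.
recv_x, recv_count, handle, dispatch_hook = engine.low_latency_dispatch_two_stage(
    hidden_states,         # [num_tokens, hidden_size] activations
    topk_idx,              # [num_tokens, top_k] selected expert ids
    topk_weights,          # [num_tokens, top_k] routing weights
    expertwise_scale,      # per-expert input scales (None under dynamic quant)
    use_fp8=True,          # dispatch activations in FP8
    quant_group_size=128,  # one scale per 128 channels -> DeepEP num_per_channel
)
if dispatch_hook is not None:
    dispatch_hook()  # with return_recv_hook=True, this triggers the actual receive
```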
@@ -363,6 +365,7 @@ class DeepEPEngine:
         topk_idx: paddle.Tensor,
         topk_weights: paddle.Tensor,
         dispatch_use_fp8: bool,
+        quant_group_size: int,
         handle,
     ):
         if self.deepep_engine is None:
@@ -376,6 +379,7 @@ class DeepEPEngine:
             async_finish=False,
             dispatch_use_fp8=dispatch_use_fp8,
             return_recv_hook=True,
+            num_per_channel=quant_group_size,
         )
         return combined_hidden_states, combine_hook

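The combine stage must see the same group size that dispatch used, so `quant_group_size` is threaded through here as well, again surfacing as DeepEP's `num_per_channel`. A usage sketch (engine and tensors assumed, as above):

```python
# Sketch only: `engine`, `ffn_out`, the routing tensors and `handle` are assumed.
combined, combine_hook = engine.low_latency_combine_two_stage(
    ffn_out,           # expert outputs to reduce back per token
    topk_idx,
    topk_weights,
    True,              # dispatch_use_fp8: only True is supported right now
    quant_group_size,  # must match the value passed at dispatch time
    handle,            # opaque routing handle returned by dispatch
)
# other work can overlap the network phase here before blocking on the result
if combine_hook is not None:
    combine_hook()
```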
@@ -644,21 +648,29 @@ class EPDecoderRunner(EPRunner):
         # just supports dispatch_use_fp8 = True now!
         assert use_fp8 is True
         recv_hidden_states, recv_expert_count, handle, dispatch_hook = (
-            self.ep_engine.low_latency_dispatch_two_stage(x, topk_idx, topk_weights, expertwise_scale, use_fp8)
+            self.ep_engine.low_latency_dispatch_two_stage(
+                x, topk_idx, topk_weights, expertwise_scale, use_fp8, quant_group_size
+            )
         )
         if dispatch_hook is not None:
             dispatch_hook()

         return recv_hidden_states, recv_expert_count, handle

-    def combine(self, ffn_out, topk_idx, topk_weights, handle):
+    def combine(self, ffn_out, topk_idx, topk_weights, handle, **kwargs):
+        quant_group_size = kwargs.get("quant_group_size", 128)
         if not self.use_internode_ll_two_stage:
             combined_hidden_states, combine_hook = self.ep_engine.low_latency_combine(
                 ffn_out, topk_idx, topk_weights, handle
             )
         else:
             combined_hidden_states, combine_hook = self.ep_engine.low_latency_combine_two_stage(
-                ffn_out, topk_idx, topk_weights, True, handle  # just supports dispatch_use_fp8 = True now!
+                ffn_out,
+                topk_idx,
+                topk_weights,
+                True,
+                quant_group_size,
+                handle,  # just supports dispatch_use_fp8 = True now!
             )
         if combine_hook is not None:
             combine_hook()
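Accepting the group size through `**kwargs` keeps `combine` backward compatible: old positional call sites still work and silently fall back to 128, while new callers opt in by keyword. A sketch of both call shapes (runner and tensors assumed):

```python
# Existing call sites keep working and use the default group size of 128:
runner.combine(ffn_out, topk_idx, topk_weights, handle)

# New call sites pass it explicitly; combine() reads it back with
# kwargs.get("quant_group_size", 128):
runner.combine(ffn_out, topk_idx, topk_weights, handle, quant_group_size=64)
```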
@@ -263,7 +263,9 @@ class CutlassMoEMethod(UnquantizedFusedMoEMethod):
         )

         # 4. EP combine
-        return self.ep_decoder_runner.combine(ffn_out, topk_idx, topk_weights, handle)
+        return self.ep_decoder_runner.combine(
+            ffn_out, topk_idx, topk_weights, handle, quant_group_size=quant_group_size
+        )

     def apply_tp(
         self,
@@ -759,8 +761,9 @@ class CutlassW4AFP8MoEMethod(CutlassMoEMethod):
         down_proj_expert_weight_key = layer.weight_key_map.get("down_proj_expert_weight_key", None)
         up_gate_proj_expert_weight_scale_key = layer.weight_key_map.get("up_gate_proj_expert_weight_scale_key", None)
         down_proj_expert_weight_scale_key = layer.weight_key_map.get("down_proj_expert_weight_scale_key", None)
-        up_gate_proj_expert_in_scale_key = layer.weight_key_map.get("up_gate_proj_expert_in_scale_key", None)
-        down_proj_expert_in_scale_key = layer.weight_key_map.get("down_proj_expert_in_scale_key", None)
+        if not layer.moe_quant_config.moe_dynamic_quant:
+            up_gate_proj_expert_in_scale_key = layer.weight_key_map.get("up_gate_proj_expert_in_scale_key", None)
+            down_proj_expert_in_scale_key = layer.weight_key_map.get("down_proj_expert_in_scale_key", None)

         up_gate_proj_weights, down_proj_weights, logical_expert_ids, ep_rank_to_expert_id_list = (
             layer.load_experts_weight(
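The `*_key` entries come from `layer.weight_key_map` and are per-expert format strings; under `moe_dynamic_quant` the activation in-scale keys are never looked up at all. A hypothetical map for illustration only (the real key patterns depend on the checkpoint layout):

```python
# Hypothetical key patterns -- invented for illustration, not taken from the PR.
weight_key_map = {
    "up_gate_proj_expert_weight_key": "mlp.experts.{}.up_gate_proj.weight",
    "up_gate_proj_expert_weight_scale_key": "mlp.experts.{}.up_gate_proj.weight_scale",
    "up_gate_proj_expert_in_scale_key": "mlp.experts.{}.up_gate_proj.activation_scale",
}
# .format(expert_idx) resolves the state-dict key for one expert, e.g. expert 7:
key_for_expert_7 = weight_key_map["up_gate_proj_expert_in_scale_key"].format(7)
```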
@@ -780,7 +783,7 @@ class CutlassW4AFP8MoEMethod(CutlassMoEMethod):
         if isinstance(state_dict, list):
             state_dict = dict(state_dict)

-        if layer.ep_size > 1:
+        if layer.ep_size > 1 and not layer.moe_quant_config.moe_dynamic_quant:
             for expert_idx in ep_rank_to_expert_id_list:
                 scale_tensor = get_tensor(
                     (
@@ -813,44 +816,54 @@ class CutlassW4AFP8MoEMethod(CutlassMoEMethod):
                     layer.fd_config.model_config.model,
                 )
             )
-            up_gate_proj_in_scale.append(
-                get_tensor(
-                    (
-                        state_dict.pop(up_gate_proj_expert_in_scale_key.format(expert_idx))
-                        if up_gate_proj_expert_in_scale_key.format(expert_idx) in state_dict
-                        else up_gate_proj_expert_in_scale_key.format(expert_idx)
-                    ),
-                    layer.fd_config.model_config.model,
-                )
-            )
-            down_proj_in_scale.append(
-                get_tensor(
-                    (
-                        state_dict.pop(down_proj_expert_in_scale_key.format(expert_idx))
-                        if down_proj_expert_in_scale_key.format(expert_idx) in state_dict
-                        else down_proj_expert_in_scale_key.format(expert_idx)
-                    ),
-                    layer.fd_config.model_config.model,
-                )
-            )
+            if not layer.moe_quant_config.moe_dynamic_quant:
+                up_gate_proj_in_scale.append(
+                    get_tensor(
+                        (
+                            state_dict.pop(up_gate_proj_expert_in_scale_key.format(expert_idx))
+                            if up_gate_proj_expert_in_scale_key.format(expert_idx) in state_dict
+                            else up_gate_proj_expert_in_scale_key.format(expert_idx)
+                        ),
+                        layer.fd_config.model_config.model,
+                    )
+                )
+                down_proj_in_scale.append(
+                    get_tensor(
+                        (
+                            state_dict.pop(down_proj_expert_in_scale_key.format(expert_idx))
+                            if down_proj_expert_in_scale_key.format(expert_idx) in state_dict
+                            else down_proj_expert_in_scale_key.format(expert_idx)
+                        ),
+                        layer.fd_config.model_config.model,
+                    )
+                )

         up_gate_proj_weight = paddle.stack(up_gate_proj_weights, axis=0)
         down_proj_weight = paddle.stack(down_proj_weights, axis=0)
         up_gate_proj_weight_scale = paddle.stack(up_gate_proj_weight_scale, axis=0)
         down_proj_weight_scale = paddle.stack(down_proj_weight_scale, axis=0)
-        up_gate_proj_in_scale_all_experts = paddle.stack(up_gate_proj_in_scale_all_experts, axis=0).squeeze()
-        up_gate_proj_in_scale = paddle.stack(up_gate_proj_in_scale, axis=0).squeeze()
-        down_proj_in_scale = paddle.stack(down_proj_in_scale, axis=0).squeeze()
+        if not layer.moe_quant_config.moe_dynamic_quant:
+            up_gate_proj_in_scale_all_experts = paddle.stack(up_gate_proj_in_scale_all_experts, axis=0).squeeze()
+            up_gate_proj_in_scale = paddle.stack(up_gate_proj_in_scale, axis=0).squeeze()
+            down_proj_in_scale = paddle.stack(down_proj_in_scale, axis=0).squeeze()

-        name_tensor_map = {
-            "up_gate_proj_weight": up_gate_proj_weight,
-            "down_proj_weight": down_proj_weight,
-            "up_gate_proj_weight_scale": up_gate_proj_weight_scale,
-            "down_proj_weight_scale": down_proj_weight_scale,
-            "up_gate_proj_in_scale_all_experts": up_gate_proj_in_scale_all_experts,
-            "up_gate_proj_in_scale": up_gate_proj_in_scale,
-            "down_proj_in_scale": down_proj_in_scale,
-        }
+        if not layer.moe_quant_config.moe_dynamic_quant:
+            name_tensor_map = {
+                "up_gate_proj_weight": up_gate_proj_weight,
+                "down_proj_weight": down_proj_weight,
+                "up_gate_proj_weight_scale": up_gate_proj_weight_scale,
+                "down_proj_weight_scale": down_proj_weight_scale,
+                "up_gate_proj_in_scale_all_experts": up_gate_proj_in_scale_all_experts,
+                "up_gate_proj_in_scale": up_gate_proj_in_scale,
+                "down_proj_in_scale": down_proj_in_scale,
+            }
+        else:
+            name_tensor_map = {
+                "up_gate_proj_weight": up_gate_proj_weight,
+                "down_proj_weight": down_proj_weight,
+                "up_gate_proj_weight_scale": up_gate_proj_weight_scale,
+                "down_proj_weight_scale": down_proj_weight_scale,
+            }
         for name, tensor in name_tensor_map.items():
             getattr(layer, name).set_value(tensor)

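Skipping the in-scale tensors is consistent with what dynamic quantization means here: activation scales are not checkpoint constants but are derived from the data at runtime. A self-contained sketch of per-group scale computation, for intuition only (the real work happens inside the fused kernels):

```python
import paddle

def dynamic_group_scales(x: paddle.Tensor, group_size: int = 128) -> paddle.Tensor:
    """Illustrative only: one scale per `group_size` consecutive channels.

    x: [num_tokens, hidden_size] -> returns [num_tokens, hidden_size // group_size]
    """
    num_tokens, hidden_size = x.shape
    groups = x.reshape([num_tokens, hidden_size // group_size, group_size])
    amax = groups.abs().max(axis=-1)      # max-abs per channel group
    return amax / 448.0                   # 448 = max finite value of float8 e4m3
```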
@@ -1007,11 +1020,27 @@ class CutlassW4AFP8MoEMethod(CutlassMoEMethod):

         # weight_scales
         if layer.is_quantized:
+            if not layer.moe_quant_config.moe_dynamic_quant:
+                up_gate_proj_weight_scale_shape = [layer.num_local_experts, layer.moe_intermediate_size * 2]
+                down_proj_weight_scale_shape = [layer.num_local_experts, layer.hidden_size]
+            else:
+                up_gate_proj_weight_scale_shape = [
+                    layer.num_local_experts,
+                    layer.moe_intermediate_size * 2 // 128,
+                    layer.hidden_size // 128,
+                    128,
+                ]
+                down_proj_weight_scale_shape = [
+                    layer.num_local_experts,
+                    layer.hidden_size // 128,
+                    layer.moe_intermediate_size // 128,
+                    128,
+                ]
             setattr(
                 layer,
                 "up_gate_proj_weight_scale",
                 layer.create_parameter(
-                    shape=[layer.num_local_experts, layer.moe_intermediate_size * 2],
+                    shape=up_gate_proj_weight_scale_shape,
                     dtype="float32",
                     default_initializer=paddle.nn.initializer.Constant(0),
                 ),
@@ -1020,7 +1049,7 @@ class CutlassW4AFP8MoEMethod(CutlassMoEMethod):
                 layer,
                 "down_proj_weight_scale",
                 layer.create_parameter(
-                    shape=[layer.num_local_experts, layer.hidden_size],
+                    shape=down_proj_weight_scale_shape,
                     dtype="float32",
                     default_initializer=paddle.nn.initializer.Constant(0),
                 ),
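To make the two shape regimes concrete, here is the arithmetic with illustrative sizes (the numbers are invented; only the formulas come from the diff). Static quantization keeps one fp32 scale per output channel, while dynamic quantization stores 128-grouped block scales:

```python
# Illustrative shape arithmetic only; the sizes below are made up.
num_local_experts = 8
hidden_size = 7168
moe_intermediate_size = 2048

# Static scales: one fp32 value per output channel.
up_gate_static = [num_local_experts, moe_intermediate_size * 2]  # [8, 4096]
down_static = [num_local_experts, hidden_size]                   # [8, 7168]

# Dynamic quant: block scales laid out as [experts, out_blocks, in_blocks, 128],
# where the trailing 128 matches the group size used in this diff.
up_gate_dynamic = [num_local_experts, moe_intermediate_size * 2 // 128,
                   hidden_size // 128, 128]                      # [8, 32, 56, 128]
down_dynamic = [num_local_experts, hidden_size // 128,
                moe_intermediate_size // 128, 128]               # [8, 56, 16, 128]
```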
@@ -204,8 +204,10 @@ class FusedMoE(nn.Layer):
         self._dtype = self._helper.get_default_dtype()
         self.weight_dtype = self._dtype

-        self.is_quantized = fd_config.model_config.is_quantized and not (
-            fd_config.quant_config.name() == "mix_quant" and fd_config.quant_config.moe_quant_type is None
+        self.is_moe_quantized = getattr(self.fd_config.model_config, "is_moe_quantized", False)
+        self.is_quantized = self.is_moe_quantized or (
+            fd_config.model_config.is_quantized
+            and not (fd_config.quant_config.name() == "mix_quant" and fd_config.quant_config.moe_quant_type is None)
         )
         moe_quant_config = fd_config.quant_config
         self.moe_quant_config = moe_quant_config
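In plain terms, a MoE layer now counts as quantized if either the checkpoint carries quantized MoE weights specifically, or the whole model is quantized and MoE is not excluded by a `mix_quant` config that leaves `moe_quant_type` unset. A hypothetical standalone restatement of the predicate, for readability:

```python
# Self-contained restatement of the new is_quantized logic (illustrative names).
def moe_layer_is_quantized(model_is_quantized: bool, quant_name: str,
                           moe_quant_type, is_moe_quantized: bool) -> bool:
    moe_excluded = quant_name == "mix_quant" and moe_quant_type is None
    return is_moe_quantized or (model_is_quantized and not moe_excluded)
```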
@@ -40,6 +40,7 @@ class MixQuantConfig(QuantConfigBase):
         is_quantized: bool = False,
         hadamard_block_size: int = 128,
         moe_dynamic_quant: bool = False,
+        is_moe_quantized: bool = False,
     ) -> None:
         super().__init__()
         self.dense_quant_type = dense_quant_type
@@ -59,6 +60,7 @@ class MixQuantConfig(QuantConfigBase):
         self.is_quantized = is_quantized
         self.hadamard_block_size = hadamard_block_size
         self.moe_dynamic_quant = moe_dynamic_quant
+        self.is_moe_quantized = is_moe_quantized

     def name(self) -> str:
         return "mix_quant"
@@ -76,6 +78,7 @@ class MixQuantConfig(QuantConfigBase):
             config.get("is_quantized", False),
             config.get("hadamard_block_size", 128),
             config.get("moe_dynamic_quant", False),
+            config.get("is_moe_quantized", False),
         )

     def get_quant_method(self, layer) -> Optional[QuantMethodBase]:
@@ -102,7 +105,7 @@ class MixQuantConfig(QuantConfigBase):
                 .from_config(
                     {
                         "is_permuted": self.is_permuted,
-                        "is_quantized": not self.is_checkpoint_bf16,
+                        "is_quantized": not self.is_checkpoint_bf16 or self.is_moe_quantized,
                         "hadamard_block_size": self.hadamard_block_size,
                     }
                 )
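For orientation, here is a hypothetical config dict exercising the two new fields. The two new keys follow the `config.get` calls above; the quant-type values and the shape of `from_config`'s input are assumptions for illustration:

```python
# Hypothetical mix_quant config; values are invented for illustration.
quant_config = {
    "dense_quant_type": "wfp8afp8",   # assumed value, not from this PR
    "moe_quant_type": "w4afp8",
    "moe_dynamic_quant": True,        # compute activation scales at runtime
    "is_moe_quantized": True,         # checkpoint stores quantized MoE weights
}
cfg = MixQuantConfig.from_config(quant_config)
```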
@@ -151,6 +151,7 @@ class _RecordingBuffer:
         use_fp8,
         async_finish,
         return_recv_hook,
+        num_per_channel,
     ):
         call = {
             "hidden_states": hidden_states,
@@ -161,6 +162,7 @@ class _RecordingBuffer:
             "use_fp8": use_fp8,
             "async_finish": async_finish,
             "return_recv_hook": return_recv_hook,
+            "num_per_channel": num_per_channel,
             "hook_called": False,
         }
         self.low_latency_dispatch_two_stage_calls.append(call)
@@ -204,6 +206,7 @@ class _RecordingBuffer:
         async_finish,
         dispatch_use_fp8,
         return_recv_hook,
+        num_per_channel,
     ):
         call = {
             "hidden_states": hidden_states,
@@ -213,6 +216,7 @@ class _RecordingBuffer:
             "async_finish": async_finish,
             "dispatch_use_fp8": dispatch_use_fp8,
             "return_recv_hook": return_recv_hook,
+            "num_per_channel": num_per_channel,
             "hook_called": False,
         }
         self.low_latency_combine_two_stage_calls.append(call)
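`_RecordingBuffer` is a test double that records each call's arguments so the tests can assert the new kwarg actually reaches DeepEP. A self-contained, simplified illustration of that recording pattern (the real class records more fields and returns real handles):

```python
# Simplified sketch of the recording-stub pattern used by the tests.
class RecordingBufferSketch:
    def __init__(self):
        self.low_latency_dispatch_two_stage_calls = []

    def low_latency_dispatch_two_stage(self, hidden_states, topk_idx, topk_weights,
                                       use_fp8, async_finish, return_recv_hook,
                                       num_per_channel):
        # Record every argument so assertions can inspect the call later.
        self.low_latency_dispatch_two_stage_calls.append({
            "use_fp8": use_fp8,
            "async_finish": async_finish,
            "return_recv_hook": return_recv_hook,
            "num_per_channel": num_per_channel,  # the new field under test
            "hook_called": False,
        })
        return None, None, None, None

buf = RecordingBufferSketch()
buf.low_latency_dispatch_two_stage("x", "idx", "w", True, False, True, 64)
assert buf.low_latency_dispatch_two_stage_calls[-1]["num_per_channel"] == 64
```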