From 1bf4fc7f366c31042a48ce9b0e2c1504594d1e7e Mon Sep 17 00:00:00 2001 From: xiaoxiaohehe001 <49090790+xiaoxiaohehe001@users.noreply.github.com> Date: Fri, 29 Aug 2025 14:43:06 +0800 Subject: [PATCH] support w4afp8 eplb (#3680) --- .../layers/moe/fused_moe_cutlass_backend.py | 97 ++++++++++++++++--- 1 file changed, 81 insertions(+), 16 deletions(-) diff --git a/fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py b/fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py index be38b56cb..01049e1e5 100644 --- a/fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py +++ b/fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py @@ -661,7 +661,7 @@ class CutlassW4AFP8MoEMethod(CutlassMoEMethod): self.moe_quant_type = "w4afp8" self.pack_num = 2 - def process_prequanted_weights(self, layer: nn.Layer, state_dict): + def process_prequanted_weights(self, layer: nn.Layer, state_dict, is_rearrange: bool = False): """ Paddle cutlass process prequanted weights. """ @@ -677,6 +677,7 @@ class CutlassW4AFP8MoEMethod(CutlassMoEMethod): state_dict, up_gate_proj_expert_weight_key, down_proj_expert_weight_key, + is_rearrange, ) ) @@ -686,22 +687,62 @@ class CutlassW4AFP8MoEMethod(CutlassMoEMethod): up_gate_proj_in_scale = [] down_proj_in_scale = [] + if isinstance(state_dict, list): + state_dict = dict(state_dict) + if layer.ep_size > 1: for expert_idx in ep_rank_to_expert_id_list: - scale_tensor = get_tensor(state_dict[up_gate_proj_expert_in_scale_key.format(expert_idx)]) + scale_tensor = get_tensor( + ( + state_dict[up_gate_proj_expert_in_scale_key.format(expert_idx)] + if up_gate_proj_expert_in_scale_key.format(expert_idx) in state_dict + else up_gate_proj_expert_in_scale_key.format(expert_idx) + ), + layer.fd_config.model_config.model, + ) up_gate_proj_in_scale_all_experts.append(scale_tensor) for expert_idx in logical_expert_ids: up_gate_proj_weight_scale.append( - get_tensor(state_dict.pop(up_gate_proj_expert_weight_scale_key.format(expert_idx))) + get_tensor( + ( + state_dict.pop(up_gate_proj_expert_weight_scale_key.format(expert_idx)) + if up_gate_proj_expert_weight_scale_key.format(expert_idx) in state_dict + else up_gate_proj_expert_weight_scale_key.format(expert_idx) + ), + layer.fd_config.model_config.model, + ) ) down_proj_weight_scale.append( - get_tensor(state_dict.pop(down_proj_expert_weight_scale_key.format(expert_idx))) + get_tensor( + ( + state_dict.pop(down_proj_expert_weight_scale_key.format(expert_idx)) + if down_proj_expert_weight_scale_key.format(expert_idx) in state_dict + else down_proj_expert_weight_scale_key.format(expert_idx) + ), + layer.fd_config.model_config.model, + ) ) up_gate_proj_in_scale.append( - get_tensor(state_dict.pop(up_gate_proj_expert_in_scale_key.format(expert_idx))) + get_tensor( + ( + state_dict.pop(up_gate_proj_expert_in_scale_key.format(expert_idx)) + if up_gate_proj_expert_in_scale_key.format(expert_idx) in state_dict + else up_gate_proj_expert_in_scale_key.format(expert_idx) + ), + layer.fd_config.model_config.model, + ) + ) + down_proj_in_scale.append( + get_tensor( + ( + state_dict.pop(down_proj_expert_in_scale_key.format(expert_idx)) + if down_proj_expert_in_scale_key.format(expert_idx) in state_dict + else down_proj_expert_in_scale_key.format(expert_idx) + ), + layer.fd_config.model_config.model, + ) ) - down_proj_in_scale.append(get_tensor(state_dict.pop(down_proj_expert_in_scale_key.format(expert_idx)))) up_gate_proj_weight = paddle.stack(up_gate_proj_weights, axis=0) down_proj_weight = paddle.stack(down_proj_weights, axis=0) @@ -763,7 +804,9 @@ class CutlassW4AFP8MoEMethod(CutlassMoEMethod): """ Paddle cutlass load weight process. """ - up_gate_proj_weights, down_proj_weights = layer.extract_moe_ffn_weights(state_dict) + up_gate_proj_weights, down_proj_weights, logical_expert_ids, ep_rank_to_expert_id_list = ( + layer.extract_moe_ffn_weights(state_dict) + ) self.check(layer, up_gate_proj_weights, down_proj_weights) for idx, weight_tensor in enumerate([up_gate_proj_weights, down_proj_weights]): weight_name = self.added_weight_attrs[idx] @@ -774,7 +817,9 @@ class CutlassW4AFP8MoEMethod(CutlassMoEMethod): quanted_weight = paddle.stack(weight_list, axis=0) getattr(layer, weight_name).set_value(quanted_weight) - self.load_w4afp8_scale_weights(layer, layer.weight_key_map, state_dict) + self.load_w4afp8_scale_weights( + layer, layer.weight_key_map, state_dict, logical_expert_ids, ep_rank_to_expert_id_list + ) def create_w4afp8_scale_weights(self, layer: nn.Layer, weight_key_map: dict): """ @@ -828,7 +873,14 @@ class CutlassW4AFP8MoEMethod(CutlassMoEMethod): ), ) - def load_w4afp8_scale_weights(self, layer: nn.Layer, weight_key_map: dict, state_dict: dict): + def load_w4afp8_scale_weights( + self, + layer: nn.Layer, + weight_key_map: dict, + state_dict: dict, + logical_expert_ids: paddle.Tensor, + ep_rank_to_expert_id_list: list, + ): """ Get w4afp8 weights from state dict and process them. Args: @@ -837,8 +889,15 @@ class CutlassW4AFP8MoEMethod(CutlassMoEMethod): state_dict (dict): The state dict. """ - def _extract_scale_tensor(state_dict, key_template, expert_idx): - return get_tensor(state_dict.pop(key_template.format(expert_idx))) + def _extract_scale_tensor(layer: nn.Layer, state_dict, key_template, expert_idx): + return get_tensor( + ( + state_dict.pop(key_template.format(expert_idx)) + if key_template.format(expert_idx) in state_dict + else key_template.format(expert_idx) + ), + layer.fd_config.model_config.model, + ) def _process_in_scale(name: str, in_scales: list[paddle.Tensor]): processed_in_scale = 1 / paddle.concat(in_scales) @@ -881,17 +940,23 @@ class CutlassW4AFP8MoEMethod(CutlassMoEMethod): # 2. Extract scale tensor from state dict if layer.ep_size > 1: - for expert_idx in range(layer.num_experts): - scale_tensor = get_tensor(state_dict[scale_key_map["up_gate_proj_in_scale"].format(expert_idx)]) + for expert_idx in ep_rank_to_expert_id_list: + scale_tensor = get_tensor( + ( + state_dict[scale_key_map["up_gate_proj_in_scale"].format(expert_idx)] + if scale_key_map["up_gate_proj_in_scale"].format(expert_idx) in state_dict + else scale_key_map["up_gate_proj_in_scale"].format(expert_idx) + ), + layer.fd_config.model_config.model, + ) up_gate_proj_in_scales_all_experts.append(1 / scale_tensor) getattr(layer, "up_gate_proj_in_scale_all_experts").set_value( paddle.concat(up_gate_proj_in_scales_all_experts) ) - for local_expert_idx in range(layer.num_local_experts): - expert_idx = local_expert_idx + layer.expert_id_offset + for expert_idx in logical_expert_ids: for name, scale_key_template in scale_key_map.items(): - scale_tensor = _extract_scale_tensor(state_dict, scale_key_template, expert_idx) + scale_tensor = _extract_scale_tensor(layer, state_dict, scale_key_template, expert_idx) scale_weight_map[name].append(scale_tensor) # 3. Process scale tensor and set to layer