diff --git a/fastdeploy/engine/expert_service.py b/fastdeploy/engine/expert_service.py index 63b1b15be..03a0ad9ad 100644 --- a/fastdeploy/engine/expert_service.py +++ b/fastdeploy/engine/expert_service.py @@ -59,8 +59,8 @@ class ExpertService: self.cfg.disaggregate_info = None self.scheduler = cfg.scheduler_config.scheduler() - - self.scheduler.reset_nodeid(f"{self.scheduler.infer.nodeid}_{local_data_parallel_id!s}") + if cfg.splitwise_role != "mixed": + self.scheduler.reset_nodeid(f"{self.scheduler.infer.nodeid}_{local_data_parallel_id!s}") self.cfg.parallel_config.local_data_parallel_id = local_data_parallel_id diff --git a/fastdeploy/model_executor/layers/moe/check_backend_supported.py b/fastdeploy/model_executor/layers/moe/check_backend_supported.py new file mode 100644 index 000000000..1255843d8 --- /dev/null +++ b/fastdeploy/model_executor/layers/moe/check_backend_supported.py @@ -0,0 +1,31 @@ +""" +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" + +from fastdeploy.model_executor.layers.moe.fused_moe_backend_base import MoEMethodBase +from fastdeploy.model_executor.layers.moe.fused_moe_cutlass_backend import ( + CutlassMoEMethod, +) +from fastdeploy.model_executor.layers.moe.fused_moe_triton_backend import ( + BlockWiseFP8MoEMethod, + TensorWiseFP8MoEMethod, + TritonWeightOnlyMoEMethod, +) + +pre_create_weights_list = (CutlassMoEMethod, TensorWiseFP8MoEMethod, BlockWiseFP8MoEMethod, TritonWeightOnlyMoEMethod) + + +def is_supported_moe_backend(quant_method: MoEMethodBase): + return isinstance(quant_method, pre_create_weights_list) diff --git a/fastdeploy/model_executor/layers/moe/fused_moe_backend_base.py b/fastdeploy/model_executor/layers/moe/fused_moe_backend_base.py index 914853e89..6a57b2007 100644 --- a/fastdeploy/model_executor/layers/moe/fused_moe_backend_base.py +++ b/fastdeploy/model_executor/layers/moe/fused_moe_backend_base.py @@ -19,7 +19,7 @@ from abc import abstractmethod import paddle from paddle import nn -from fastdeploy.model_executor.models.utils import set_weight_attrs +from fastdeploy.model_executor.layers.utils import set_weight_attrs from fastdeploy.platforms import current_platform from ..quantization.quant_base import QuantMethodBase diff --git a/fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py b/fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py index 8b1bf6a97..9c21fbb98 100644 --- a/fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py +++ b/fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py @@ -23,7 +23,7 @@ import fastdeploy from fastdeploy.distributed.communication import tensor_model_parallel_all_reduce from fastdeploy.platforms import current_platform -from ..utils import create_and_set_parameter, get_tensor +from ..utils import get_tensor from .fused_moe_backend_base import UnquantizedFusedMoEMethod if current_platform.is_cuda(): @@ -202,7 +202,10 @@ class CutlassMoEMethod(UnquantizedFusedMoEMethod): gate_out = 
gate(x.cast("float32")) # 1. Select topk experts and weights topk_idx, topk_weights = self.ep_decoder_runner.moe_select(layer, gate_out) - expertwise_scale = getattr(layer, "up_gate_proj_in_scale_all_experts", None) + expertwise_scale = None + if hasattr(layer, "up_gate_proj_in_scale_all_experts"): # only use in w4a8 + expertwise_scale = getattr(layer, "up_gate_proj_in_scale_all_experts", None) + # 2. EP Dispatch permute_input, token_nums_per_expert, handle = self.ep_decoder_runner.dispatch( x, topk_idx, topk_weights, expertwise_scale=expertwise_scale @@ -382,12 +385,48 @@ class CutlassW4A8MoEMethod(CutlassMoEMethod): "down_proj_in_scale": down_proj_in_scale, } for name, tensor in name_tensor_map.items(): - create_and_set_parameter(layer, name, tensor) + getattr(layer, name).set_value(tensor) - def create_weights(self, layer: nn.Layer, state_dict): + def create_weights(self, layer: nn.Layer, **extra_weight_attrs): """ Paddle cutlass create weight process. """ + self.weight_dtype = "int8" + self.ffn1_weight_shape = [ + layer.num_local_experts, + layer.hidden_size // 2, + layer.moe_intermediate_size * 2, + ] + self.ffn2_weight_shape = [ + layer.num_local_experts, + layer.moe_intermediate_size // 2, + layer.hidden_size, + ] + setattr( + layer, + self.added_weight_attrs[0], + layer.create_parameter( + shape=self.ffn1_weight_shape, + dtype=self.weight_dtype, + default_initializer=paddle.nn.initializer.Constant(0), + ), + ) + setattr( + layer, + self.added_weight_attrs[1], + layer.create_parameter( + shape=self.ffn2_weight_shape, + dtype=self.weight_dtype, + default_initializer=paddle.nn.initializer.Constant(0), + ), + ) + + self.create_w4a8_scale_weights(layer, layer.weight_key_map) + + def process_loaded_weights(self, layer: nn.Layer, state_dict): + """ + Paddle cutlass load weight process. + """ up_gate_proj_weights, down_proj_weights = layer.extract_moe_ffn_weights(state_dict) self.check(layer, up_gate_proj_weights, down_proj_weights) for idx, weight_tensor in enumerate([up_gate_proj_weights, down_proj_weights]): @@ -397,11 +436,63 @@ class CutlassW4A8MoEMethod(CutlassMoEMethod): quant_weight, scale = weight_quantize(weight_tensor[i], algo=self.moe_quant_type, arch=80) weight_list.append(quant_weight) quanted_weight = paddle.stack(weight_list, axis=0) - create_and_set_parameter(layer, weight_name, quanted_weight) + getattr(layer, weight_name).set_value(quanted_weight) - self.create_w4a8_scale_weights(layer, layer.weight_key_map, state_dict) + self.load_w4a8_scale_weights(layer, layer.weight_key_map, state_dict) - def create_w4a8_scale_weights(self, layer: nn.Layer, weight_key_map: dict, state_dict: dict): + def create_w4a8_scale_weights(self, layer: nn.Layer, weight_key_map: dict): + """ + Get w4a8 weights from state dict and process them. + Args: + layer (nn.Layer): The layer to add parameters to. + weight_key_map (dict): The weight key map. + state_dict (dict): The state dict. 
+ """ + self.default_dtype = layer._helper.get_default_dtype() + if layer.ep_size > 1: + setattr( + layer, + "up_gate_proj_in_scale_all_experts", + layer.create_parameter( + shape=[layer.num_experts], + dtype="float32", + default_initializer=paddle.nn.initializer.Constant(0), + ), + ) + + # in_scales + for in_scale_name in ["up_gate_proj_in_scale", "down_proj_in_scale"]: + setattr( + layer, + in_scale_name, + layer.create_parameter( + shape=[layer.num_local_experts], + dtype="float32", + default_initializer=paddle.nn.initializer.Constant(0), + ), + ) + + # weight_scales + setattr( + layer, + "up_gate_proj_weight_scale", + layer.create_parameter( + shape=[layer.num_local_experts, layer.moe_intermediate_size * 2], + dtype=self.default_dtype, + default_initializer=paddle.nn.initializer.Constant(0), + ), + ) + setattr( + layer, + "down_proj_weight_scale", + layer.create_parameter( + shape=[layer.num_local_experts, layer.hidden_size], + dtype=self.default_dtype, + default_initializer=paddle.nn.initializer.Constant(0), + ), + ) + + def load_w4a8_scale_weights(self, layer: nn.Layer, weight_key_map: dict, state_dict: dict): """ Get w4a8 weights from state dict and process them. Args: @@ -415,7 +506,7 @@ class CutlassW4A8MoEMethod(CutlassMoEMethod): def _process_in_scale(name: str, in_scales: list[paddle.Tensor]): processed_in_scale = 1 / paddle.concat(in_scales) - create_and_set_parameter(layer, name, processed_in_scale) + getattr(layer, name).set_value(processed_in_scale) return processed_in_scale def _process_weight_scale( @@ -426,7 +517,7 @@ class CutlassW4A8MoEMethod(CutlassMoEMethod): processed_weight_scale = ( paddle.stack(weight_scales, axis=0) / (127 * 112) / processed_in_scale[:, None] ).cast(paddle.get_default_dtype()) - create_and_set_parameter(layer, name, processed_weight_scale) + getattr(layer, name).set_value(processed_weight_scale) # 1. Init scale containers and maps up_gate_proj_weight_scales = [] @@ -456,8 +547,8 @@ class CutlassW4A8MoEMethod(CutlassMoEMethod): for expert_idx in range(layer.num_experts): scale_tensor = get_tensor(state_dict[scale_key_map["up_gate_proj_in_scale"].format(expert_idx)]) up_gate_proj_in_scales_all_experts.append(1 / scale_tensor) - create_and_set_parameter( - layer, "up_gate_proj_in_scale_all_experts", paddle.concat(up_gate_proj_in_scales_all_experts) + getattr(layer, "up_gate_proj_in_scale_all_experts").set_value( + paddle.concat(up_gate_proj_in_scales_all_experts) ) for local_expert_idx in range(layer.num_local_experts): @@ -527,15 +618,85 @@ class CutlassWeightOnlyMoEMethod(CutlassMoEMethod): "down_proj_weight_scale": down_proj_weight_scale, } for name, tensor in name_tensor_map.items(): - create_and_set_parameter(layer, name, tensor) + getattr(layer, name).set_value(tensor) - def create_weights(self, layer: nn.Layer, state_dict): + def create_weights(self, layer: nn.Layer, **extra_weight_attrs): """ Paddle cutlass create weight process. 
""" + self.default_dtype = layer._helper.get_default_dtype() + self.weight_dtype = "int8" + + up_gate_proj_weight_name = self.added_weight_attrs[0] + down_proj_weight_name = self.added_weight_attrs[1] + if self.moe_quant_type == "weight_only_int4": + self.ffn1_weight_shape = [ + layer.num_local_experts, + layer.moe_intermediate_size, + layer.hidden_size, + ] + else: + self.ffn1_weight_shape = [ + layer.num_local_experts, + layer.moe_intermediate_size * 2, + layer.hidden_size, + ] + if self.moe_quant_type == "weight_only_int4": + self.ffn2_weight_shape = [ + layer.num_local_experts, + layer.hidden_size // 2, + layer.moe_intermediate_size, + ] + else: + self.ffn2_weight_shape = [ + layer.num_local_experts, + layer.hidden_size, + layer.moe_intermediate_size, + ] + setattr( + layer, + up_gate_proj_weight_name, + layer.create_parameter( + shape=self.ffn1_weight_shape, + dtype=self.weight_dtype, + default_initializer=paddle.nn.initializer.Constant(0), + ), + ) + setattr( + layer, + down_proj_weight_name, + layer.create_parameter( + shape=self.ffn2_weight_shape, + dtype=self.weight_dtype, + default_initializer=paddle.nn.initializer.Constant(0), + ), + ) + # weight_scale + setattr( + layer, + self.added_scale_attrs[0], + layer.create_parameter( + shape=[layer.num_local_experts, layer.moe_intermediate_size * 2], + dtype=self.default_dtype, + default_initializer=paddle.nn.initializer.Constant(0), + ), + ) + setattr( + layer, + self.added_scale_attrs[1], + layer.create_parameter( + shape=[layer.num_local_experts, layer.hidden_size], + dtype=self.default_dtype, + default_initializer=paddle.nn.initializer.Constant(0), + ), + ) + + def process_loaded_weights(self, layer: nn.Layer, state_dict): + """ + Paddle cutlass load weight process. + """ up_gate_proj_weights, down_proj_weights = layer.extract_moe_ffn_weights(state_dict) self.check(layer, up_gate_proj_weights, down_proj_weights) - for idx, weight_tensor in enumerate([up_gate_proj_weights, down_proj_weights]): weight_name = self.added_weight_attrs[idx] scale_name = self.added_scale_attrs[idx] @@ -547,7 +708,7 @@ class CutlassWeightOnlyMoEMethod(CutlassMoEMethod): weight_list.append(quant_weight) weight_scale_list.append(scale) quanted_weight = paddle.stack(weight_list, axis=0) - create_and_set_parameter(layer, weight_name, quanted_weight) + getattr(layer, weight_name).set_value(quanted_weight) quanted_weight_scale = paddle.stack(weight_scale_list, axis=0) - create_and_set_parameter(layer, scale_name, quanted_weight_scale) + getattr(layer, scale_name).set_value(quanted_weight_scale) diff --git a/fastdeploy/model_executor/layers/moe/fused_moe_triton_backend.py b/fastdeploy/model_executor/layers/moe/fused_moe_triton_backend.py index a1ace0e61..b79d8c077 100644 --- a/fastdeploy/model_executor/layers/moe/fused_moe_triton_backend.py +++ b/fastdeploy/model_executor/layers/moe/fused_moe_triton_backend.py @@ -19,7 +19,7 @@ from paddle import nn import fastdeploy from fastdeploy.distributed.communication import tensor_model_parallel_all_reduce -from fastdeploy.model_executor.layers.utils import create_and_set_parameter, get_tensor +from fastdeploy.model_executor.layers.utils import get_tensor from fastdeploy.utils import ceil_div from ..quantization.quant_base import QuantMethodBase @@ -52,10 +52,66 @@ class TritonWeightOnlyMoEMethod(QuantMethodBase): """process_prequanted_weights""" pass - def create_weights(self, layer: nn.Layer, state_dict): + def create_weights(self, layer: nn.Layer, **extra_weight_attrs): """ Triton MoE create weight process. 
""" + self.weight_dtype = "int8" + self.default_dtype = layer._helper.get_default_dtype() + up_gate_proj_weight_name = self.added_weight_attrs[0] + down_proj_weight_name = self.added_weight_attrs[1] + self.ffn1_weight_shape = [ + layer.num_local_experts, + layer.hidden_size, + layer.moe_intermediate_size * 2, + ] + self.ffn2_weight_shape = [ + layer.num_local_experts, + layer.moe_intermediate_size, + layer.hidden_size, + ] + setattr( + layer, + up_gate_proj_weight_name, + layer.create_parameter( + shape=self.ffn1_weight_shape, + dtype=self.weight_dtype, + default_initializer=paddle.nn.initializer.Constant(0), + ), + ) + setattr( + layer, + down_proj_weight_name, + layer.create_parameter( + shape=self.ffn2_weight_shape, + dtype=self.weight_dtype, + default_initializer=paddle.nn.initializer.Constant(0), + ), + ) + # weight_scale + setattr( + layer, + self.added_scale_attrs[0], + layer.create_parameter( + shape=[layer.num_local_experts, layer.moe_intermediate_size * 2], + dtype=self.default_dtype, + default_initializer=paddle.nn.initializer.Constant(0), + ), + ) + setattr( + layer, + self.added_scale_attrs[1], + layer.create_parameter( + shape=[layer.num_local_experts, layer.hidden_size], + dtype=self.default_dtype, + default_initializer=paddle.nn.initializer.Constant(0), + ), + ) + + def process_loaded_weights(self, layer: nn.Layer, state_dict): + """ + Triton MoE load weight process. + """ up_gate_proj_weights, down_proj_weights = layer.extract_moe_ffn_weights(state_dict) assert len(up_gate_proj_weights) == layer.num_local_experts assert len(down_proj_weights) == layer.num_local_experts @@ -90,25 +146,7 @@ class TritonWeightOnlyMoEMethod(QuantMethodBase): quanted_weight = paddle.round(quanted_weight).astype("int8") quanted_weight_scale = quanted_weight_scale / max_bound - setattr( - layer, - weight_name, - layer.create_parameter( - shape=quanted_weight.shape, - dtype=quanted_weight.dtype, - default_initializer=paddle.nn.initializer.Constant(0), - ), - ) getattr(layer, weight_name).set_value(quanted_weight) - - setattr( - layer, - scale_name, - layer.create_parameter( - shape=quanted_weight_scale.shape, - dtype=quanted_weight_scale.dtype, - ), - ) getattr(layer, scale_name).set_value(quanted_weight_scale) def apply( @@ -264,6 +302,14 @@ class TensorWiseFP8MoEMethod(QuantMethodBase): Triton Group Gemm to compute Fused MoE. 
""" self.quant_method = quant_method + self.added_wfp8afp8_attrs = [ + "up_gate_proj_weight", + "down_proj_weight", + "up_gate_proj_weight_scale", + "down_proj_weight_scale", + "up_gate_proj_in_scale", + "down_proj_in_scale", + ] def process_prequanted_weights(self, layer: nn.Layer, state_dict) -> None: """process_prequanted_weights""" @@ -281,15 +327,6 @@ class TensorWiseFP8MoEMethod(QuantMethodBase): up_gate_proj_tensor = paddle.stack(up_gate_proj_tensor, axis=0).view(paddle.float8_e4m3fn) down_proj_tensor = paddle.stack(down_proj_tensor, axis=0).view(paddle.float8_e4m3fn) - added_wfp8afp8_attrs = [ - "up_gate_proj_weight", - "down_proj_weight", - "up_gate_proj_weight_scale", - "down_proj_weight_scale", - "up_gate_proj_in_scale", - "down_proj_in_scale", - ] - def _extract_scale_tensor(key_template): result = [] for i in range(layer.num_experts): @@ -312,26 +349,58 @@ class TensorWiseFP8MoEMethod(QuantMethodBase): down_proj_in_scale, ] ): - name = added_wfp8afp8_attrs[idx] - setattr( - layer, - name, - layer.create_parameter( - shape=weight_tensor.shape, - dtype=weight_tensor.dtype, - default_initializer=paddle.nn.initializer.Constant(0), - ), - ) + name = self.added_wfp8afp8_attrs[idx] if weight_tensor.dtype == paddle.float8_e4m3fn: getattr(layer, name).copy_(weight_tensor, False) else: getattr(layer, name).set_value(weight_tensor) - def create_weights(self, layer: nn.Layer, state_dict): + def create_weights(self, layer: nn.Layer, **extra_weight_attrs): """ Triton MoE create weight process. """ - pass + self.weight_dtype = paddle.float8_e4m3fn + self.default_dtype = layer._helper.get_default_dtype() + up_gate_proj_weight_name = self.added_wfp8afp8_attrs[0] + down_proj_weight_name = self.added_wfp8afp8_attrs[1] + self.ffn1_weight_shape = [ + layer.num_local_experts, + layer.moe_intermediate_size * 2, + layer.hidden_size, + ] + self.ffn2_weight_shape = [ + layer.num_local_experts, + layer.hidden_size, + layer.moe_intermediate_size, + ] + setattr( + layer, + up_gate_proj_weight_name, + layer.create_parameter( + shape=self.ffn1_weight_shape, + dtype=self.weight_dtype, + default_initializer=paddle.nn.initializer.Constant(0), + ), + ) + setattr( + layer, + down_proj_weight_name, + layer.create_parameter( + shape=self.ffn2_weight_shape, + dtype=self.weight_dtype, + default_initializer=paddle.nn.initializer.Constant(0), + ), + ) + for idx in range(2, len(self.added_wfp8afp8_attrs)): + setattr( + layer, + self.added_wfp8afp8_attrs[idx], + layer.create_parameter( + shape=[layer.num_local_experts], + dtype="float32", + default_initializer=paddle.nn.initializer.Constant(0), + ), + ) def apply( self, @@ -531,14 +600,76 @@ class BlockWiseFP8MoEMethod(QuantMethodBase): raise NotImplementedError - def create_weights(self, layer: nn.Layer, state_dict): + def create_weights(self, layer: nn.Layer, **extra_weight_attrs): + """ + Triton MoE create weight process. 
+ """ + self.weight_dtype = paddle.float8_e4m3fn + up_gate_proj_weight_name = self.added_weight_attrs[0] + down_proj_weight_name = self.added_weight_attrs[1] + self.ffn1_weight_shape = [ + layer.num_local_experts, + layer.moe_intermediate_size * 2, + layer.hidden_size, + ] + self.ffn2_weight_shape = [ + layer.num_local_experts, + layer.hidden_size, + layer.moe_intermediate_size, + ] + setattr( + layer, + up_gate_proj_weight_name, + layer.create_parameter( + shape=self.ffn1_weight_shape, + dtype=self.weight_dtype, + default_initializer=paddle.nn.initializer.Constant(0), + ), + ) + setattr( + layer, + down_proj_weight_name, + layer.create_parameter( + shape=self.ffn1_weight_shape, + dtype=self.weight_dtype, + default_initializer=paddle.nn.initializer.Constant(0), + ), + ) + # weight_scale + setattr( + layer, + self.added_scale_attrs[0], + layer.create_parameter( + shape=[ + layer.num_local_experts, + layer.moe_intermediate_size * 2 // self.quant_config.weight_block_size[0], + layer.hidden_size // self.quant_config.weight_block_size[1], + ], + dtype="float32", + default_initializer=paddle.nn.initializer.Constant(0), + ), + ) + setattr( + layer, + self.added_scale_attrs[1], + layer.create_parameter( + shape=[ + layer.num_local_experts, + layer.hidden_size // self.quant_config.weight_block_size[0], + layer.moe_intermediate_size // self.quant_config.weight_block_size[1], + ], + dtype="float32", + default_initializer=paddle.nn.initializer.Constant(0), + ), + ) + + def process_loaded_weights(self, layer: nn.Layer, state_dict): """ Triton MoE create weight process. """ up_gate_proj_weights, down_proj_weights = layer.extract_moe_ffn_weights(state_dict) self.check(layer, up_gate_proj_weights, down_proj_weights) - for idx, weight_tensor in enumerate([up_gate_proj_weights, down_proj_weights]): weight_name = self.added_weight_attrs[idx] scale_name = self.added_scale_attrs[idx] @@ -554,11 +685,11 @@ class BlockWiseFP8MoEMethod(QuantMethodBase): weight_scale_list.append(scale) quanted_weight = paddle.stack(weight_list, axis=0) quanted_weight = quanted_weight.transpose([0, 2, 1]).contiguous().view(paddle.float8_e4m3fn) - create_and_set_parameter(layer, weight_name, quanted_weight) + getattr(layer, weight_name).copy_(quanted_weight, False) quanted_weight_scale = paddle.stack(weight_scale_list, axis=0) quanted_weight_scale = quanted_weight_scale.transpose([0, 2, 1]).contiguous() - create_and_set_parameter(layer, scale_name, quanted_weight_scale) + getattr(layer, scale_name).set_value(quanted_weight_scale) def check(self, layer: nn.Layer, up_gate_proj_weights, down_proj_weights): """ diff --git a/fastdeploy/model_executor/layers/moe/moe.py b/fastdeploy/model_executor/layers/moe/moe.py index 22c2e3f7c..310f4d3df 100644 --- a/fastdeploy/model_executor/layers/moe/moe.py +++ b/fastdeploy/model_executor/layers/moe/moe.py @@ -22,8 +22,14 @@ from paddleformers.utils.log import logger from fastdeploy import envs from fastdeploy.model_executor.layers.utils import get_tensor +from fastdeploy.platforms import current_platform from fastdeploy.worker.experts_manager import RedundantExpertManger +# TODO(lulinjun): remove this import after supporting all backends +is_supported_moe_backend = None +if current_platform.is_cuda(): + from .check_backend_supported import is_supported_moe_backend + def get_moe_method(): """ @@ -121,10 +127,7 @@ class FusedMoE(nn.Layer): self.quant_method = moe_quant_config.get_quant_method(self) self.moe_quant_type = moe_quant_config.name() else: - # w_fp16 a_fp16 self.quant_method = 
get_moe_method() - self.quant_method.create_weights(self, weight_loader=self.weight_loader) - self.redundant_table_manger = None if self.ep_size > 1: if fd_config.model_config.enable_redundant_experts is True: @@ -139,6 +142,20 @@ class FusedMoE(nn.Layer): if fd_config.load_config.dynamic_load_weight: # It's for RL to build model self.init_moe_weights() + else: + self.gate_correction_bias_key = self.weight_key_map.get("gate_correction_bias_key", None) + if self.gate_correction_bias_key is not None: + self.gate_correction_bias = self.create_parameter(shape=[1, self.num_experts], dtype="float32") + if moe_quant_config: + if ( + moe_quant_config + and is_supported_moe_backend is not None + and is_supported_moe_backend(self.quant_method) + ): + self.quant_method.create_weights(self, weight_loader=self.weight_loader) + else: + # w_fp16 a_fp16 + self.quant_method.create_weights(self, weight_loader=self.weight_loader) logger.info( f"{moe_tag}MoE config is {num_experts=}[{expert_id_offset}, {expert_id_offset + self.num_local_experts}), \ @@ -475,23 +492,33 @@ class FusedMoE(nn.Layer): gate_correction_bias_tensor = self.extract_gate_correction_bias( self.gate_correction_bias_key, state_dict ) - self.gate_correction_bias = self.create_parameter( - shape=gate_correction_bias_tensor.shape, - dtype="float32", - ) self.gate_correction_bias.set_value(gate_correction_bias_tensor) + else: + self.gate_correction_bias = None - if self.fd_config.model_config.is_quantized: - if getattr(self.fd_config.quant_config, "is_permuted", True): - self.quant_method.process_prequanted_weights(self, state_dict) - else: - self.quant_method.create_weights(self, state_dict) else: - if self.moe_quant_config: - self.quant_method.create_weights(self, state_dict) + self.gate_correction_bias = None + + if is_supported_moe_backend is not None and is_supported_moe_backend(self.quant_method): + if self.fd_config.model_config.is_quantized: + if getattr(self.fd_config.quant_config, "is_permuted", True): + self.quant_method.process_prequanted_weights(self, state_dict) + else: + self.quant_method.process_loaded_weights(self, state_dict) else: - # w_fp16 a_fp16 self.quant_method.process_loaded_weights(self, state_dict) + else: + if self.fd_config.model_config.is_quantized: + if getattr(self.fd_config.quant_config, "is_permuted", True): + self.quant_method.process_prequanted_weights(self, state_dict) + else: + self.quant_method.create_weights(self, state_dict) + else: + if self.moe_quant_config: + self.quant_method.create_weights(self, state_dict) + else: + # w_fp16 a_fp16 + self.quant_method.process_loaded_weights(self, state_dict) def forward(self, x: paddle.Tensor, gate: nn.Layer): """ diff --git a/fastdeploy/model_executor/layers/quantization/tensor_wise_fp8.py b/fastdeploy/model_executor/layers/quantization/tensor_wise_fp8.py index 965695216..9576882ec 100644 --- a/fastdeploy/model_executor/layers/quantization/tensor_wise_fp8.py +++ b/fastdeploy/model_executor/layers/quantization/tensor_wise_fp8.py @@ -82,7 +82,7 @@ class TensorWiseFP8LinearMethod(QuantMethodBase): self.weight_dtype = "float8_e4m3fn" def create_weights(self, layer, **extra_weight_attrs): - + layer.weight_dtype = "float8_e4m3fn" layer.weight = layer.create_parameter( shape=layer.weight_shape, dtype=layer.weight_dtype, diff --git a/fastdeploy/model_executor/layers/utils.py b/fastdeploy/model_executor/layers/utils.py index e7a6c0137..b5e1c2ad0 100644 --- a/fastdeploy/model_executor/layers/utils.py +++ b/fastdeploy/model_executor/layers/utils.py @@ -15,7 +15,7 @@ """ 
import functools -from typing import Tuple, Union +from typing import Any, Optional, Tuple, Union import numpy as np import paddle @@ -45,6 +45,14 @@ if cache_params != "none": c8_state_dict = paddle.load(cache_params, return_numpy=True) +# TODO(lulinjun): delete it, import from fastdeploy.model_executor.models.utils after supporting all backends +def set_weight_attrs(param, param_attr_map: Optional[dict[str, Any]]): + if param_attr_map is None: + return + for key, value in param_attr_map.items(): + setattr(param, key, value) + + def per_block_cast_to_fp8(x: Tensor, block_size: list = [128, 128]) -> Tuple[Tensor, Tensor]: """ Only used in deep_gemm block wise quant weight. diff --git a/fastdeploy/model_executor/load_weight_utils.py b/fastdeploy/model_executor/load_weight_utils.py index 712cff972..6aacb3a59 100644 --- a/fastdeploy/model_executor/load_weight_utils.py +++ b/fastdeploy/model_executor/load_weight_utils.py @@ -66,7 +66,7 @@ def load_ep_checkpoint(model_path: str, fd_config: FDConfig, return_numpy: bool """ with open(os.path.join(model_path, "model.safetensors.index.json"), "r") as f: weight_list = json.load(f)["weight_map"] - filtered_map = {k: v for k, v in weight_list.items() if "experts" not in k} + filtered_map = {k: v for k, v in weight_list.items() if ".experts." not in k} num_local_ffn_keys = [] from itertools import chain diff --git a/fastdeploy/model_executor/models/ernie4_5_moe.py b/fastdeploy/model_executor/models/ernie4_5_moe.py index 29e71ed93..4a0250028 100644 --- a/fastdeploy/model_executor/models/ernie4_5_moe.py +++ b/fastdeploy/model_executor/models/ernie4_5_moe.py @@ -424,7 +424,10 @@ class Ernie4_5_MoeForCausalLM(ModelForCasualLM): """ self.ernie.load_state_dict(state_dict) if self.tie_word_embeddings: - self.lm_head.linear.weight.set_value(self.ernie.embed_tokens.embeddings.weight.transpose([1, 0])) + if hasattr(self.lm_head, "linear"): + self.lm_head.linear.weight.set_value(self.ernie.embed_tokens.embeddings.weight.transpose([1, 0])) + else: # ep + self.lm_head.weight.set_value(self.ernie.embed_tokens.embeddings.weight.transpose([1, 0])) else: self.lm_head.load_state_dict(state_dict)
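The core refactor above splits weight handling for the supported MoE quant backends into two phases: create_weights() now pre-allocates empty, correctly shaped parameters at model-build time, and the new process_loaded_weights() fills them from the checkpoint. The sketch below only illustrates that flow; it is not code from the patch, and the layer fields and state-dict key names are made up.

```python
# Illustrative sketch of the two-phase weight flow (hypothetical names).
# The real backends allocate quantized dtypes ("int8", float8_e4m3fn); float32
# is used here purely to keep the toy example simple.
import paddle
from paddle import nn


class ToyMoELayer(nn.Layer):
    def __init__(self, num_local_experts=2, hidden_size=8, moe_intermediate_size=16):
        super().__init__()
        self.num_local_experts = num_local_experts
        self.hidden_size = hidden_size
        self.moe_intermediate_size = moe_intermediate_size


def create_weights(layer: nn.Layer):
    # Phase 1 (model build): pre-create an empty parameter with its final shape.
    setattr(
        layer,
        "up_gate_proj_weight",
        layer.create_parameter(
            shape=[layer.num_local_experts, layer.hidden_size, layer.moe_intermediate_size * 2],
            dtype="float32",
            default_initializer=paddle.nn.initializer.Constant(0),
        ),
    )


def process_loaded_weights(layer: nn.Layer, state_dict: dict):
    # Phase 2 (checkpoint load): fill the pre-created parameter in place;
    # set_value() requires the shape to match what create_weights() allocated.
    stacked = paddle.stack(
        [state_dict[f"expert_{i}.up_gate_proj.weight"] for i in range(layer.num_local_experts)],
        axis=0,
    )
    getattr(layer, "up_gate_proj_weight").set_value(stacked)


layer = ToyMoELayer()
create_weights(layer)
fake_state_dict = {
    f"expert_{i}.up_gate_proj.weight": paddle.randn([layer.hidden_size, layer.moe_intermediate_size * 2])
    for i in range(layer.num_local_experts)
}
process_loaded_weights(layer, fake_state_dict)
```

Pre-creating the parameters up front appears to be what lets FusedMoE pass weight_loader=self.weight_loader into create_weights and gate the new path with is_supported_moe_backend, since the parameter objects now exist before any checkpoint is read.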
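The patch also adds a temporary set_weight_attrs helper to fastdeploy/model_executor/layers/utils.py (mirroring the one in models/utils.py, per the TODO). A quick standalone check of its behaviour follows; the weight_loader attribute is only an example of the kind of metadata a caller might attach.

```python
# Sanity check of the set_weight_attrs helper added in this patch.
from typing import Any, Optional

import paddle


def set_weight_attrs(param, param_attr_map: Optional[dict[str, Any]]):
    if param_attr_map is None:
        return
    for key, value in param_attr_map.items():
        setattr(param, key, value)


w = paddle.create_parameter(shape=[4, 4], dtype="float32")
set_weight_attrs(w, {"weight_loader": lambda param, tensor: param.set_value(tensor)})
set_weight_attrs(w, None)  # None is a deliberate no-op
assert hasattr(w, "weight_loader")
```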
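In load_ep_checkpoint, the expert filter is tightened from the substring "experts" to ".experts.". The snippet below uses made-up weight keys to show the intended effect: only routed per-expert weights are excluded from filtered_map, while keys that merely contain "experts" (for example shared-expert weights) stay in it and are loaded normally.

```python
# Hypothetical weight keys; only the filter expressions are taken from the patch.
weight_map = {
    "ernie.layers.1.mlp.experts.7.down_proj.weight": "model-00001.safetensors",
    "ernie.layers.1.mlp.shared_experts.down_proj.weight": "model-00001.safetensors",
    "ernie.layers.1.self_attn.qkv_proj.weight": "model-00001.safetensors",
}

old_filtered = {k: v for k, v in weight_map.items() if "experts" not in k}
new_filtered = {k: v for k, v in weight_map.items() if ".experts." not in k}

assert "ernie.layers.1.mlp.shared_experts.down_proj.weight" not in old_filtered  # dropped before
assert "ernie.layers.1.mlp.shared_experts.down_proj.weight" in new_filtered      # kept now
assert "ernie.layers.1.mlp.experts.7.down_proj.weight" not in new_filtered       # still excluded
```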