diff --git a/fastdeploy/model_executor/layers/backends/metax/__init__.py b/fastdeploy/model_executor/layers/backends/metax/__init__.py
index 1cc5b8260..51650d0b9 100644
--- a/fastdeploy/model_executor/layers/backends/metax/__init__.py
+++ b/fastdeploy/model_executor/layers/backends/metax/__init__.py
@@ -14,7 +14,10 @@
 from .attention.flash_attn_backend import FlashAttentionBackend
 from .attention.mla_attn_metax_backend import MetaxMLAAttentionBackend
-from .moe.fused_moe_cutlass_metax_backend import MetaxCutlassWeightOnlyMoEMethod
+from .moe.fused_moe_cutlass_metax_backend import (
+    MetaxCutlassUnquantizedFusedMoEMethod,
+    MetaxCutlassWeightOnlyMoEMethod,
+)
 from .moe.fused_moe_triton_metax_backend import MetaxTritonWeightOnlyMoEMethod
 
 __all__ = [
@@ -22,4 +25,5 @@ __all__ = [
     "MetaxMLAAttentionBackend",
     "MetaxTritonWeightOnlyMoEMethod",
     "MetaxCutlassWeightOnlyMoEMethod",
+    "MetaxCutlassUnquantizedFusedMoEMethod",
 ]
diff --git a/fastdeploy/model_executor/layers/backends/metax/moe/fused_moe_cutlass_metax_backend.py b/fastdeploy/model_executor/layers/backends/metax/moe/fused_moe_cutlass_metax_backend.py
index eb090b486..3d354df99 100644
--- a/fastdeploy/model_executor/layers/backends/metax/moe/fused_moe_cutlass_metax_backend.py
+++ b/fastdeploy/model_executor/layers/backends/metax/moe/fused_moe_cutlass_metax_backend.py
@@ -21,8 +21,12 @@
 from paddle import nn
 from paddle.nn.quant import weight_quantize
 
 from fastdeploy.distributed.communication import tensor_model_parallel_all_reduce
-from fastdeploy.model_executor.layers.moe.fused_moe_backend_base import MoEMethodBase
+from fastdeploy.model_executor.layers.moe.fused_moe_backend_base import (
+    MoEMethodBase,
+    UnquantizedFusedMoEMethod,
+)
 from fastdeploy.model_executor.layers.moe.moe import get_moe_scores
+from fastdeploy.model_executor.layers.quantization.weight_only import WeightOnlyConfig
 from fastdeploy.model_executor.layers.utils import get_tensor
 from fastdeploy.model_executor.ops.gpu import (
     fused_expert_moe,
@@ -30,7 +34,147 @@ from fastdeploy.model_executor.ops.gpu import (
     moe_expert_ffn,
     moe_expert_reduce,
 )
-from fastdeploy.model_executor.utils import TensorTracker, free_tensor, set_weight_attrs
+from fastdeploy.model_executor.utils import (
+    TensorTracker,
+    free_tensor,
+    process_weight_transpose,
+    set_weight_attrs,
+    weight_fully_copied,
+)
+
+
+class MetaxCutlassUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
+    """
+    Use Cutlass Group Gemm to compute Fused MoE.
+    This method is the oldest way to compute MoE in Paddle.
+ """ + + def process_loaded_weights(self, layer: nn.Layer, state_dict): + up_gate_proj_weights, down_proj_weights, logical_expert_ids, ep_rank_to_expert_id_list = ( + layer.extract_moe_ffn_weights(state_dict) + ) + stacked_up_gate_proj_weights = paddle.stack(up_gate_proj_weights, axis=0) + stacked_down_proj_weights = paddle.stack(down_proj_weights, axis=0) + + layer.up_gate_proj_weight.set_value(stacked_up_gate_proj_weights) + layer.down_proj_weight.set_value(stacked_down_proj_weights) + + if layer.with_bias: + up_gate_proj_bias, down_proj_bias = layer.extract_moe_ffn_bias(state_dict) + stacked_up_gate_proj_bias = paddle.stack(up_gate_proj_bias, axis=0) + stacked_down_proj_bias = paddle.stack(down_proj_bias, axis=0) + + layer.up_gate_proj_bias.set_value(stacked_up_gate_proj_bias) + layer.down_proj_bias.set_value(stacked_down_proj_bias) + + def compute_ffn( + self, + layer: nn.Layer, + permute_input: paddle.Tensor, + token_nums_per_expert: paddle.Tensor, + expert_idx_per_token: paddle.Tensor, + used_in_ep_low_latency: bool = False, + estimate_total_token_nums: int = -1, + ): + """ + Paddle Cutlass compute Fused MoE. + """ + raise NotImplementedError + + def apply_ep_prefill( + self, + layer: nn.Layer, + x: paddle.Tensor, + gate: nn.Layer, + ) -> paddle.Tensor: + """ + Apply the EP prefill method. + """ + raise NotImplementedError + + def apply_ep_decode( + self, + layer: nn.Layer, + x: paddle.Tensor, + gate: nn.Layer, + ) -> paddle.Tensor: + """ + Apply the EP decoder method. + """ + raise NotImplementedError + + def apply_tp( + self, + layer: nn.Layer, + x: paddle.Tensor, + gate: nn.Layer, + ) -> paddle.Tensor: + """ + Paddle Cutlass compute Fused MoE. + """ + """ + Paddle Cutlass compute Fused MoE. + """ + if layer.topk_method == "noaux_tc": + gate_out = gate(x.cast("float32")) + + gate_out, topk_weights, topk_idx = get_moe_scores( + gate_out, + layer.n_group, + layer.topk_group, + layer.top_k, + layer.routed_scaling_factor, + layer.gate_correction_bias, + getattr(layer, "renormalize", True), + ) + + ( + permute_input, + token_nums_per_expert, + permute_indices_per_token, + topk_weights, + topk_idx, + ) = moe_expert_dispatch( + x, + gate_out, + layer.top_k, + False, + True, + ) + + ffn_out = self.compute_ffn(layer, permute_input, token_nums_per_expert, None) + + fused_moe_out = moe_expert_reduce( + ffn_out, + topk_weights, + permute_indices_per_token, + topk_idx, + None, + False, + 1.0, + ) + else: + raise NotImplementedError + + fused_moe_out = fused_expert_moe( + x, + gate.weight, + getattr(layer, self.added_weight_attrs[0]), + getattr(layer, self.added_weight_attrs[1]), + None, + (layer.up_gate_proj_weight_scale if hasattr(layer, "up_gate_proj_weight_scale") else None), + None, + (layer.down_proj_weight_scale if hasattr(layer, "down_proj_weight_scale") else None), + "weight_only_int8", + layer.top_k, + True, + False, + ) + + if layer.reduce_results and layer.tp_size > 1: + fused_moe_out = tensor_model_parallel_all_reduce(fused_moe_out, layer.fd_config.parallel_config.tp_group) + + return fused_moe_out class MetaxCutlassMoEMethod(MoEMethodBase): @@ -142,18 +286,11 @@ class MetaxCutlassMoEMethod(MoEMethodBase): 1.0, ) else: - added_weight_attrs0 = getattr(layer, self.added_weight_attrs[0]) - added_weight_attrs1 = getattr(layer, self.added_weight_attrs[1]) - - if self.quant_config.is_checkpoint_bf16 and layer.fd_config.load_config.load_choices == "default_v1": - added_weight_attrs0 = paddle.transpose(added_weight_attrs0, perm=[0, 2, 1]) - added_weight_attrs1 = 
-
             fused_moe_out = fused_expert_moe(
                 x,
                 gate.weight,
-                added_weight_attrs0,
-                added_weight_attrs1,
+                getattr(layer, self.added_weight_attrs[0]),
+                getattr(layer, self.added_weight_attrs[1]),
                 None,
                 (layer.up_gate_proj_weight_scale if hasattr(layer, "up_gate_proj_weight_scale") else None),
                 None,
                 (layer.down_proj_weight_scale if hasattr(layer, "down_proj_weight_scale") else None),
@@ -177,7 +314,10 @@ class MetaxCutlassWeightOnlyMoEMethod(MetaxCutlassMoEMethod):
     def __init__(self, quant_config):
         super().__init__(quant_config)
-        self.quant_config = quant_config
+        if quant_config is None:
+            self.quant_config = WeightOnlyConfig(algo="weight_only_int8", is_checkpoint_bf16=True)
+        else:
+            self.quant_config = quant_config
         self.moe_quant_type = self.quant_config.algo
         self.pack_num = 1
         self.weight_only_linear_arch = os.getenv("FLAGS_weight_only_linear_arch")
@@ -252,33 +392,61 @@ class MetaxCutlassWeightOnlyMoEMethod(MetaxCutlassMoEMethod):
         ]
         self.up_gate_proj_scale_shape = [layer.num_local_experts, layer.moe_intermediate_size * 2]
         self.down_proj_scale_shape = [layer.num_local_experts, layer.hidden_size]
+        self.model_format = extra_weight_attrs.get("model_format")
         # TODO(bukejiyu): remove v1 loader check when v0 loader is removed
         if self.quant_config.is_checkpoint_bf16 and layer.fd_config.load_config.load_choices == "default_v1":
+            if self.model_format != "torch":
+                up_gate_proj_weight_shape = [
+                    layer.num_local_experts,
+                    layer.hidden_size,
+                    layer.moe_intermediate_size * 2,
+                ]
+                down_proj_weight_shape = [layer.num_local_experts, layer.moe_intermediate_size, layer.hidden_size]
+                up_gate_proj_attrs = {
+                    **extra_weight_attrs,
+                    "tensor_track": TensorTracker(shape=up_gate_proj_weight_shape, output_dim=True),
+                }
+                down_proj_attrs = {
+                    **extra_weight_attrs,
+                    "tensor_track": TensorTracker(shape=down_proj_weight_shape, output_dim=False),
+                }
+            else:
+                up_gate_proj_weight_shape = [
+                    layer.num_local_experts,
+                    layer.moe_intermediate_size * 2,
+                    layer.hidden_size,
+                ]
+                down_proj_weight_shape = [layer.num_local_experts, layer.hidden_size, layer.moe_intermediate_size]
+                up_gate_proj_attrs = {
+                    **extra_weight_attrs,
+                    "tensor_track": TensorTracker(shape=up_gate_proj_weight_shape, output_dim=False),
+                    "SHARD_ID_TO_SHARDED_DIM": {"gate": 0, "down": 1, "up": 0},
+                }
+                down_proj_attrs = {
+                    **extra_weight_attrs,
+                    "tensor_track": TensorTracker(shape=down_proj_weight_shape, output_dim=True),
+                    "SHARD_ID_TO_SHARDED_DIM": {"gate": 0, "down": 1, "up": 0},
+                }
+
             layer.up_gate_proj_weight = layer.create_parameter(
-                shape=[layer.num_local_experts, layer.hidden_size, layer.moe_intermediate_size * 2],
+                shape=up_gate_proj_weight_shape,
                 dtype=layer.weight_dtype,
                 default_initializer=paddle.nn.initializer.Constant(0),
             )
             layer.down_proj_weight = layer.create_parameter(
-                shape=[layer.num_local_experts, layer.moe_intermediate_size, layer.hidden_size],
+                shape=down_proj_weight_shape,
                 dtype=layer.weight_dtype,
                 default_initializer=paddle.nn.initializer.Constant(0),
             )
-            extra_weight_attrs["weight_need_transpose"] = extra_weight_attrs.get("model_format") == "torch"
+            # extra_weight_attrs["weight_need_transpose"] = extra_weight_attrs.get("model_format") == "torch"
             set_weight_attrs(
                 layer.up_gate_proj_weight,
-                {
-                    **extra_weight_attrs,
-                    "tensor_track": TensorTracker(shape=layer.up_gate_proj_weight.shape, output_dim=True),
-                },
+                up_gate_proj_attrs,
             )
             set_weight_attrs(
                 layer.down_proj_weight,
-                {
-                    **extra_weight_attrs,
-                    "tensor_track": TensorTracker(shape=layer.down_proj_weight.shape, output_dim=False),
-                },
+                down_proj_attrs,
             )
         else:
             self.weight_dtype = "int8"
"int8" @@ -325,7 +493,7 @@ class MetaxCutlassWeightOnlyMoEMethod(MetaxCutlassMoEMethod): default_initializer=paddle.nn.initializer.Constant(0), ), ) - extra_weight_attrs["weight_need_transpose"] = not extra_weight_attrs.get("model_format") == "torch" + # extra_weight_attrs["weight_need_transpose"] = not extra_weight_attrs.get("model_format") == "torch" moe_extra_weight_attrs = {**extra_weight_attrs, "SHARD_ID_TO_SHARDED_DIM": {"gate": 0, "down": 1, "up": 0}} set_weight_attrs(layer.up_gate_proj_weight, moe_extra_weight_attrs) set_weight_attrs(layer.down_proj_weight, moe_extra_weight_attrs) @@ -337,69 +505,71 @@ class MetaxCutlassWeightOnlyMoEMethod(MetaxCutlassMoEMethod): set_weight_attrs(layer.down_proj_weight_scale, scale_extra_weight_attrs) def process_weights_after_loading(self, layer): - """ """ - if not self.quant_config.is_checkpoint_bf16: - return - weight_id_map = {"gate_up": 0, "down": 1} - if ( - hasattr(layer.up_gate_proj_weight, "tensor_track") - and layer.up_gate_proj_weight.tensor_track is not None - and layer.up_gate_proj_weight.tensor_track.is_fully_copied() - ): - weight_type = "gate_up" - else: - weight_type = "down" + def _process_quantize(weight_idx): + # 1.init shape and type + weight_name = self.added_weight_attrs[weight_idx] + unquantized_weight_name = weight_name.replace("quant_weight", "weight") + weight_shape = self.up_gate_proj_weight_shape if weight_type == "gate_up" else self.down_proj_weight_shape + transposed_weight_shape = [weight_shape[0], weight_shape[2], weight_shape[1]] + weight_dtype = "int8" + # scale + scale_name = self.added_scale_attrs[weight_idx] + scale_shape = self.up_gate_proj_scale_shape if weight_type == "gate_up" else self.down_proj_scale_shape + scale_dtype = self.default_dtype - # 1.init shape and type - # weight - weight_name = self.added_weight_attrs[weight_id_map[weight_type]] - unquantized_weight_name = weight_name.replace("quant_weight", "weight") - weight_shape = self.up_gate_proj_weight_shape if weight_type == "gate_up" else self.down_proj_weight_shape - weight_shape[1], weight_shape[2] = weight_shape[2], weight_shape[1] - weight_dtype = "int8" - # scale - scale_name = self.added_scale_attrs[weight_id_map[weight_type]] - scale_shape = self.up_gate_proj_scale_shape if weight_type == "gate_up" else self.down_proj_scale_shape - scale_dtype = self.default_dtype + # 2.crate tmp tensor - # 2.crate tmp tensor + weight = paddle.empty(transposed_weight_shape, dtype=weight_dtype) + scale = paddle.empty(scale_shape, dtype=scale_dtype) - weight = paddle.empty(weight_shape, dtype=weight_dtype) - scale = paddle.empty(scale_shape, dtype=scale_dtype) + # 3.quantize weight - # 3.quantize weight + for expert_id in range(layer.num_local_experts): + weight[expert_id], scale[expert_id] = weight_quantize( + getattr(layer, unquantized_weight_name)[expert_id], + algo=self.moe_quant_type, + arch=self.weight_only_linear_arch, + ) - for expert_id in range(layer.num_local_experts): - weight[expert_id], scale[expert_id] = weight_quantize( - getattr(layer, unquantized_weight_name)[expert_id], - algo=self.moe_quant_type, - arch=self.weight_only_linear_arch, + free_tensor(getattr(layer, unquantized_weight_name)) + + setattr( + layer, + weight_name, + layer.create_parameter( + shape=weight_shape, + dtype=weight_dtype, + default_initializer=paddle.nn.initializer.Constant(0), + ), ) + # create scale + setattr( + layer, + scale_name, + layer.create_parameter( + shape=scale_shape, + dtype=scale_dtype, + default_initializer=paddle.nn.initializer.Constant(0), + ), + ) + 
+            getattr(layer, scale_name).copy_(scale, False)
 
-        free_tensor(getattr(layer, unquantized_weight_name))
+        if self.quant_config.is_checkpoint_bf16:
+            weight_id_map = {"gate_up": 0, "down": 1}
+            if weight_fully_copied(layer.up_gate_proj_weight):
+                weight_type = "gate_up"
+            else:
+                weight_type = "down"
 
-        # create weight
-        setattr(
-            layer,
-            weight_name,
-            layer.create_parameter(
-                shape=weight_shape,
-                dtype=weight_dtype,
-                default_initializer=paddle.nn.initializer.Constant(0),
-            ),
-        )
-        # create scale
-        setattr(
-            layer,
-            scale_name,
-            layer.create_parameter(
-                shape=scale_shape,
-                dtype=scale_dtype,
-                default_initializer=paddle.nn.initializer.Constant(0),
-            ),
-        )
-        getattr(layer, weight_name).copy_(weight, False)
-        getattr(layer, scale_name).copy_(scale, False)
+            if self.model_format == "torch":
+                unquantized_weight_name = self.added_weight_attrs[weight_id_map[weight_type]].replace(
+                    "quant_weight", "weight"
+                )
+                process_weight_transpose(layer, unquantized_weight_name)
+            _process_quantize(weight_id_map[weight_type])
+        else:
+            return
 
     def process_loaded_weights(self, layer: nn.Layer, state_dict):
         """
diff --git a/fastdeploy/model_executor/layers/moe/moe.py b/fastdeploy/model_executor/layers/moe/moe.py
index c25263db1..ea3497478 100644
--- a/fastdeploy/model_executor/layers/moe/moe.py
+++ b/fastdeploy/model_executor/layers/moe/moe.py
@@ -59,10 +59,10 @@ def get_moe_method():
 
     elif current_platform.is_maca():
         from fastdeploy.model_executor.layers.backends import (
-            MetaxCutlassWeightOnlyMoEMethod,
+            MetaxCutlassUnquantizedFusedMoEMethod,
         )
 
-        return MetaxCutlassWeightOnlyMoEMethod(None)
+        return MetaxCutlassUnquantizedFusedMoEMethod(None)
     raise NotImplementedError
@@ -227,7 +227,7 @@ class FusedMoE(nn.Layer):
             return
         if hasattr(param, "SHARD_ID_TO_SHARDED_DIM"):
             SHARD_ID_TO_SHARDED_DIM = param.SHARD_ID_TO_SHARDED_DIM
-        elif current_platform.is_cuda() or current_platform.is_iluvatar():
+        elif current_platform.is_cuda() or current_platform.is_iluvatar() or current_platform.is_maca():
             SHARD_ID_TO_SHARDED_DIM = {"gate": 1, "down": 0, "up": 1}
         else:
             SHARD_ID_TO_SHARDED_DIM = {"gate": 0, "down": 1, "up": 0}
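
Usage sketch (not part of the patch): a minimal illustration of the selection and fallback behavior this diff introduces, assuming a FastDeploy build running on the Metax (MACA) platform; all names below are the ones touched above.

    from fastdeploy.model_executor.layers.backends import (
        MetaxCutlassUnquantizedFusedMoEMethod,
        MetaxCutlassWeightOnlyMoEMethod,
    )
    from fastdeploy.model_executor.layers.moe.moe import get_moe_method

    # On a MACA device, get_moe_method() now returns the unquantized Cutlass
    # method instead of the weight-only one.
    method = get_moe_method()
    assert isinstance(method, MetaxCutlassUnquantizedFusedMoEMethod)

    # The weight-only method now tolerates quant_config=None by falling back to
    # a default weight_only_int8 config (the new branch in __init__).
    wo_method = MetaxCutlassWeightOnlyMoEMethod(None)
    assert wo_method.moe_quant_type == "weight_only_int8"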