From 9d9f5df8d01a14604eabc2d546be331d68ae66f6 Mon Sep 17 00:00:00 2001
From: MingkunZhang <39252862+StareAtYou@users.noreply.github.com>
Date: Wed, 12 Nov 2025 16:32:26 +0800
Subject: [PATCH] [Metax] support default_v1 loader & thinking model (#4956)

Co-authored-by: plusNew001 <95567040+plusNew001@users.noreply.github.com>
---
 .../metax/moe/fused_moe_cutlass_metax_backend.py | 12 ++++++++++--
 fastdeploy/model_executor/layers/embeddings.py   |  4 +++-
 .../layers/quantization/weight_only.py           |  3 +++
 fastdeploy/model_executor/utils.py               |  9 +++++++--
 4 files changed, 23 insertions(+), 5 deletions(-)

diff --git a/fastdeploy/model_executor/layers/backends/metax/moe/fused_moe_cutlass_metax_backend.py b/fastdeploy/model_executor/layers/backends/metax/moe/fused_moe_cutlass_metax_backend.py
index a9a47291d..eb090b486 100644
--- a/fastdeploy/model_executor/layers/backends/metax/moe/fused_moe_cutlass_metax_backend.py
+++ b/fastdeploy/model_executor/layers/backends/metax/moe/fused_moe_cutlass_metax_backend.py
@@ -142,11 +142,18 @@ class MetaxCutlassMoEMethod(MoEMethodBase):
                 1.0,
             )
         else:
+            added_weight_attrs0 = getattr(layer, self.added_weight_attrs[0])
+            added_weight_attrs1 = getattr(layer, self.added_weight_attrs[1])
+
+            if self.quant_config.is_checkpoint_bf16 and layer.fd_config.load_config.load_choices == "default_v1":
+                added_weight_attrs0 = paddle.transpose(added_weight_attrs0, perm=[0, 2, 1])
+                added_weight_attrs1 = paddle.transpose(added_weight_attrs1, perm=[0, 2, 1])
+
             fused_moe_out = fused_expert_moe(
                 x,
                 gate.weight,
-                getattr(layer, self.added_weight_attrs[0]),
-                getattr(layer, self.added_weight_attrs[1]),
+                added_weight_attrs0,
+                added_weight_attrs1,
                 None,
                 (layer.up_gate_proj_weight_scale if hasattr(layer, "up_gate_proj_weight_scale") else None),
                 None,
@@ -348,6 +355,7 @@ class MetaxCutlassWeightOnlyMoEMethod(MetaxCutlassMoEMethod):
             weight_name = self.added_weight_attrs[weight_id_map[weight_type]]
             unquantized_weight_name = weight_name.replace("quant_weight", "weight")
             weight_shape = self.up_gate_proj_weight_shape if weight_type == "gate_up" else self.down_proj_weight_shape
+            weight_shape[1], weight_shape[2] = weight_shape[2], weight_shape[1]
             weight_dtype = "int8"
             # scale
             scale_name = self.added_scale_attrs[weight_id_map[weight_type]]
diff --git a/fastdeploy/model_executor/layers/embeddings.py b/fastdeploy/model_executor/layers/embeddings.py
index 0e66268ff..76ea13f30 100644
--- a/fastdeploy/model_executor/layers/embeddings.py
+++ b/fastdeploy/model_executor/layers/embeddings.py
@@ -24,6 +24,7 @@ from paddle.distributed import fleet
 
 from fastdeploy.config import FDConfig
 from fastdeploy.model_executor.utils import h2d_copy, set_weight_attrs, slice_fn
+from fastdeploy.platforms import current_platform
 
 from .utils import (
     DEFAULT_VOCAB_PADDING_SIZE,
@@ -274,7 +275,8 @@ class VocabParallelEmbedding(nn.Layer):
             if output_dim == 0:
                 h2d_copy(param[: shard_weight.shape[0]], shard_weight)
-                param[shard_weight.shape[0] :].fill_(0)
+                if not current_platform.is_maca():
+                    param[shard_weight.shape[0] :].fill_(0)
             else:
                 h2d_copy(param[:, : shard_weight.shape[1]], shard_weight)
                 param[:, shard_weight.shape[1] :].fill_(0)
diff --git a/fastdeploy/model_executor/layers/quantization/weight_only.py b/fastdeploy/model_executor/layers/quantization/weight_only.py
index a248775cd..38ced01b2 100644
--- a/fastdeploy/model_executor/layers/quantization/weight_only.py
+++ b/fastdeploy/model_executor/layers/quantization/weight_only.py
@@ -326,6 +326,9 @@ class WeightOnlyLinearMethod(QuantMethodBase):
                 arch=self.quant_config.weight_only_linear_arch,
             )
 
+        if current_platform.is_maca():
+            quanted_weight_tensor = paddle.transpose(quanted_weight_tensor, [1, 0])
+
         free_tensor(layer.weight)
 
         layer.weight = layer.create_parameter(
diff --git a/fastdeploy/model_executor/utils.py b/fastdeploy/model_executor/utils.py
index de7d87cbc..e04341061 100644
--- a/fastdeploy/model_executor/utils.py
+++ b/fastdeploy/model_executor/utils.py
@@ -363,8 +363,13 @@ def v1_loader_support(fd_config):
     def _err_msg(msg: str) -> str:
         logger.info(msg + "; fallback to the v0 loader for model loading.")
 
-    if not (current_platform.is_cuda() or current_platform.is_xpu() or current_platform.is_iluvatar()):
-        _err_msg("v1loader currently only support backends gpu, xpu and iluvatar")
+    if not (
+        current_platform.is_cuda()
+        or current_platform.is_xpu()
+        or current_platform.is_iluvatar()
+        or current_platform.is_maca()
+    ):
+        _err_msg("v1loader currently only support backends gpu, xpu, iluvatar and maca")
         return False
 
     if is_pre_sliced_weight(fd_config.model_config.model):
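
Note (not part of the patch): a minimal sketch of the weight-layout adjustment made in
fused_moe_cutlass_metax_backend.py above. When the checkpoint is bf16 and the
"default_v1" loader is selected, the Metax backend transposes the trailing two axes of
each fused-MoE expert weight with perm=[0, 2, 1] before calling fused_expert_moe. The
variable names and shapes below are toy values chosen for illustration only, and the
axis order attributed to the default_v1 loader is an assumption.

    import paddle

    num_experts, hidden_size, moe_intermediate_size = 2, 4, 8  # toy shapes (assumed)
    # Layout assumed to come out of the default_v1 loader: [experts, inter, hidden]
    up_gate_proj_weight = paddle.randn([num_experts, moe_intermediate_size, hidden_size])

    # Same operation as the patch: swap the last two axes per expert, giving the
    # [experts, hidden, inter] layout assumed here for the fused kernel.
    transposed = paddle.transpose(up_gate_proj_weight, perm=[0, 2, 1])
    assert transposed.shape == [num_experts, hidden_size, moe_intermediate_size]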