[Metax] support default_v1 loader based on #4988 (#5001)

MingkunZhang
2025-11-18 09:44:30 +08:00
committed by GitHub
parent 5d7516dc8c
commit a36c958c66
3 changed files with 257 additions and 83 deletions

View File

@@ -14,7 +14,10 @@
from .attention.flash_attn_backend import FlashAttentionBackend
from .attention.mla_attn_metax_backend import MetaxMLAAttentionBackend
from .moe.fused_moe_cutlass_metax_backend import MetaxCutlassWeightOnlyMoEMethod
from .moe.fused_moe_cutlass_metax_backend import (
MetaxCutlassUnquantizedFusedMoEMethod,
MetaxCutlassWeightOnlyMoEMethod,
)
from .moe.fused_moe_triton_metax_backend import MetaxTritonWeightOnlyMoEMethod
__all__ = [
@@ -22,4 +25,5 @@ __all__ = [
"MetaxMLAAttentionBackend",
"MetaxTritonWeightOnlyMoEMethod",
"MetaxCutlassWeightOnlyMoEMethod",
"MetaxCutlassUnquantizedFusedMoEMethod",
]
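
For reference, a minimal usage sketch (not part of the commit): the new __all__ entry lets platform dispatch code import the unquantized MoE method alongside the existing weight-only one. The import path below is the one used in the third changed file of this commit.

# Hedged sketch, not FastDeploy source:
from fastdeploy.model_executor.layers.backends import (
    MetaxCutlassUnquantizedFusedMoEMethod,
    MetaxCutlassWeightOnlyMoEMethod,
)

moe_method = MetaxCutlassUnquantizedFusedMoEMethod(None)  # as returned by get_moe_method() on the maca platform in this commit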

View File

@@ -21,8 +21,12 @@ from paddle import nn
from paddle.nn.quant import weight_quantize
from fastdeploy.distributed.communication import tensor_model_parallel_all_reduce
from fastdeploy.model_executor.layers.moe.fused_moe_backend_base import MoEMethodBase
from fastdeploy.model_executor.layers.moe.fused_moe_backend_base import (
MoEMethodBase,
UnquantizedFusedMoEMethod,
)
from fastdeploy.model_executor.layers.moe.moe import get_moe_scores
from fastdeploy.model_executor.layers.quantization.weight_only import WeightOnlyConfig
from fastdeploy.model_executor.layers.utils import get_tensor
from fastdeploy.model_executor.ops.gpu import (
fused_expert_moe,
@@ -30,7 +34,147 @@ from fastdeploy.model_executor.ops.gpu import (
moe_expert_ffn,
moe_expert_reduce,
)
from fastdeploy.model_executor.utils import TensorTracker, free_tensor, set_weight_attrs
from fastdeploy.model_executor.utils import (
TensorTracker,
free_tensor,
process_weight_transpose,
set_weight_attrs,
weight_fully_copied,
)
class MetaxCutlassUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
"""
Use Cutlass Group Gemm to compute Fused MoE.
This method is the oldest way to compute MoE in Paddle.
"""
def process_loaded_weights(self, layer: nn.Layer, state_dict):
up_gate_proj_weights, down_proj_weights, logical_expert_ids, ep_rank_to_expert_id_list = (
layer.extract_moe_ffn_weights(state_dict)
)
stacked_up_gate_proj_weights = paddle.stack(up_gate_proj_weights, axis=0)
stacked_down_proj_weights = paddle.stack(down_proj_weights, axis=0)
layer.up_gate_proj_weight.set_value(stacked_up_gate_proj_weights)
layer.down_proj_weight.set_value(stacked_down_proj_weights)
if layer.with_bias:
up_gate_proj_bias, down_proj_bias = layer.extract_moe_ffn_bias(state_dict)
stacked_up_gate_proj_bias = paddle.stack(up_gate_proj_bias, axis=0)
stacked_down_proj_bias = paddle.stack(down_proj_bias, axis=0)
layer.up_gate_proj_bias.set_value(stacked_up_gate_proj_bias)
layer.down_proj_bias.set_value(stacked_down_proj_bias)
def compute_ffn(
self,
layer: nn.Layer,
permute_input: paddle.Tensor,
token_nums_per_expert: paddle.Tensor,
expert_idx_per_token: paddle.Tensor,
used_in_ep_low_latency: bool = False,
estimate_total_token_nums: int = -1,
):
"""
Paddle Cutlass compute Fused MoE.
"""
raise NotImplementedError
def apply_ep_prefill(
self,
layer: nn.Layer,
x: paddle.Tensor,
gate: nn.Layer,
) -> paddle.Tensor:
"""
Apply the EP prefill method.
"""
raise NotImplementedError
def apply_ep_decode(
self,
layer: nn.Layer,
x: paddle.Tensor,
gate: nn.Layer,
) -> paddle.Tensor:
"""
Apply the EP decoder method.
"""
raise NotImplementedError
def apply_tp(
self,
layer: nn.Layer,
x: paddle.Tensor,
gate: nn.Layer,
) -> paddle.Tensor:
"""
Paddle Cutlass compute Fused MoE.
"""
"""
Paddle Cutlass compute Fused MoE.
"""
if layer.topk_method == "noaux_tc":
gate_out = gate(x.cast("float32"))
gate_out, topk_weights, topk_idx = get_moe_scores(
gate_out,
layer.n_group,
layer.topk_group,
layer.top_k,
layer.routed_scaling_factor,
layer.gate_correction_bias,
getattr(layer, "renormalize", True),
)
(
permute_input,
token_nums_per_expert,
permute_indices_per_token,
topk_weights,
topk_idx,
) = moe_expert_dispatch(
x,
gate_out,
layer.top_k,
False,
True,
)
ffn_out = self.compute_ffn(layer, permute_input, token_nums_per_expert, None)
fused_moe_out = moe_expert_reduce(
ffn_out,
topk_weights,
permute_indices_per_token,
topk_idx,
None,
False,
1.0,
)
else:
raise NotImplementedError
fused_moe_out = fused_expert_moe(
x,
gate.weight,
getattr(layer, self.added_weight_attrs[0]),
getattr(layer, self.added_weight_attrs[1]),
None,
(layer.up_gate_proj_weight_scale if hasattr(layer, "up_gate_proj_weight_scale") else None),
None,
(layer.down_proj_weight_scale if hasattr(layer, "down_proj_weight_scale") else None),
"weight_only_int8",
layer.top_k,
True,
False,
)
if layer.reduce_results and layer.tp_size > 1:
fused_moe_out = tensor_model_parallel_all_reduce(fused_moe_out, layer.fd_config.parallel_config.tp_group)
return fused_moe_out
class MetaxCutlassMoEMethod(MoEMethodBase):
@@ -142,18 +286,11 @@ class MetaxCutlassMoEMethod(MoEMethodBase):
1.0,
)
else:
added_weight_attrs0 = getattr(layer, self.added_weight_attrs[0])
added_weight_attrs1 = getattr(layer, self.added_weight_attrs[1])
if self.quant_config.is_checkpoint_bf16 and layer.fd_config.load_config.load_choices == "default_v1":
added_weight_attrs0 = paddle.transpose(added_weight_attrs0, perm=[0, 2, 1])
added_weight_attrs1 = paddle.transpose(added_weight_attrs1, perm=[0, 2, 1])
fused_moe_out = fused_expert_moe(
x,
gate.weight,
added_weight_attrs0,
added_weight_attrs1,
getattr(layer, self.added_weight_attrs[0]),
getattr(layer, self.added_weight_attrs[1]),
None,
(layer.up_gate_proj_weight_scale if hasattr(layer, "up_gate_proj_weight_scale") else None),
None,
@@ -177,7 +314,10 @@ class MetaxCutlassWeightOnlyMoEMethod(MetaxCutlassMoEMethod):
def __init__(self, quant_config):
super().__init__(quant_config)
self.quant_config = quant_config
if quant_config is None:
self.quant_config = WeightOnlyConfig(algo="weight_only_int8", is_checkpoint_bf16=True)
else:
self.quant_config = quant_config
self.moe_quant_type = self.quant_config.algo
self.pack_num = 1
self.weight_only_linear_arch = os.getenv("FLAGS_weight_only_linear_arch")
@@ -252,33 +392,61 @@ class MetaxCutlassWeightOnlyMoEMethod(MetaxCutlassMoEMethod):
]
self.up_gate_proj_scale_shape = [layer.num_local_experts, layer.moe_intermediate_size * 2]
self.down_proj_scale_shape = [layer.num_local_experts, layer.hidden_size]
self.model_format = extra_weight_attrs.get("model_format")
# TODO(bukejiyu): remove v1 loader check when v0 loader is removed
if self.quant_config.is_checkpoint_bf16 and layer.fd_config.load_config.load_choices == "default_v1":
if self.model_format != "torch":
up_gate_proj_weight_shape = [
layer.num_local_experts,
layer.hidden_size,
layer.moe_intermediate_size * 2,
]
down_proj_weight_shape = [layer.num_local_experts, layer.moe_intermediate_size, layer.hidden_size]
up_gate_proj_attrs = {
**extra_weight_attrs,
"tensor_track": TensorTracker(shape=up_gate_proj_weight_shape, output_dim=True),
}
down_proj_attrs = {
**extra_weight_attrs,
"tensor_track": TensorTracker(shape=down_proj_weight_shape, output_dim=False),
}
else:
up_gate_proj_weight_shape = [
layer.num_local_experts,
layer.moe_intermediate_size * 2,
layer.hidden_size,
]
down_proj_weight_shape = [layer.num_local_experts, layer.hidden_size, layer.moe_intermediate_size]
up_gate_proj_attrs = {
**extra_weight_attrs,
"tensor_track": TensorTracker(shape=up_gate_proj_weight_shape, output_dim=False),
"SHARD_ID_TO_SHARDED_DIM": {"gate": 0, "down": 1, "up": 0},
}
down_proj_attrs = {
**extra_weight_attrs,
"tensor_track": TensorTracker(shape=down_proj_weight_shape, output_dim=True),
"SHARD_ID_TO_SHARDED_DIM": {"gate": 0, "down": 1, "up": 0},
}
layer.up_gate_proj_weight = layer.create_parameter(
shape=[layer.num_local_experts, layer.hidden_size, layer.moe_intermediate_size * 2],
shape=up_gate_proj_weight_shape,
dtype=layer.weight_dtype,
default_initializer=paddle.nn.initializer.Constant(0),
)
layer.down_proj_weight = layer.create_parameter(
shape=[layer.num_local_experts, layer.moe_intermediate_size, layer.hidden_size],
shape=down_proj_weight_shape,
dtype=layer.weight_dtype,
default_initializer=paddle.nn.initializer.Constant(0),
)
extra_weight_attrs["weight_need_transpose"] = extra_weight_attrs.get("model_format") == "torch"
# extra_weight_attrs["weight_need_transpose"] = extra_weight_attrs.get("model_format") == "torch"
set_weight_attrs(
layer.up_gate_proj_weight,
{
**extra_weight_attrs,
"tensor_track": TensorTracker(shape=layer.up_gate_proj_weight.shape, output_dim=True),
},
up_gate_proj_attrs,
)
set_weight_attrs(
layer.down_proj_weight,
{
**extra_weight_attrs,
"tensor_track": TensorTracker(shape=layer.down_proj_weight.shape, output_dim=False),
},
down_proj_attrs,
)
else:
self.weight_dtype = "int8"
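
As a toy illustration of the two checkpoint layouts handled in the default_v1 branch above (sizes are invented; only the axis order is the point):

# Hedged sketch, not FastDeploy source:
E, hidden, inter = 2, 8, 16
paddle_layout = [E, hidden, inter * 2]   # model_format != "torch": [experts, hidden, 2 * intermediate]
torch_layout = [E, inter * 2, hidden]    # model_format == "torch": last two axes transposed
assert torch_layout == [E, paddle_layout[2], paddle_layout[1]]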
@@ -325,7 +493,7 @@ class MetaxCutlassWeightOnlyMoEMethod(MetaxCutlassMoEMethod):
default_initializer=paddle.nn.initializer.Constant(0),
),
)
extra_weight_attrs["weight_need_transpose"] = not extra_weight_attrs.get("model_format") == "torch"
# extra_weight_attrs["weight_need_transpose"] = not extra_weight_attrs.get("model_format") == "torch"
moe_extra_weight_attrs = {**extra_weight_attrs, "SHARD_ID_TO_SHARDED_DIM": {"gate": 0, "down": 1, "up": 0}}
set_weight_attrs(layer.up_gate_proj_weight, moe_extra_weight_attrs)
set_weight_attrs(layer.down_proj_weight, moe_extra_weight_attrs)
@@ -337,69 +505,71 @@ class MetaxCutlassWeightOnlyMoEMethod(MetaxCutlassMoEMethod):
set_weight_attrs(layer.down_proj_weight_scale, scale_extra_weight_attrs)
def process_weights_after_loading(self, layer):
""" """
if not self.quant_config.is_checkpoint_bf16:
return
weight_id_map = {"gate_up": 0, "down": 1}
if (
hasattr(layer.up_gate_proj_weight, "tensor_track")
and layer.up_gate_proj_weight.tensor_track is not None
and layer.up_gate_proj_weight.tensor_track.is_fully_copied()
):
weight_type = "gate_up"
else:
weight_type = "down"
def _process_quantize(weight_idx):
# 1.init shape and type
weight_name = self.added_weight_attrs[weight_idx]
unquantized_weight_name = weight_name.replace("quant_weight", "weight")
weight_shape = self.up_gate_proj_weight_shape if weight_type == "gate_up" else self.down_proj_weight_shape
transposed_weight_shape = [weight_shape[0], weight_shape[2], weight_shape[1]]
weight_dtype = "int8"
# scale
scale_name = self.added_scale_attrs[weight_idx]
scale_shape = self.up_gate_proj_scale_shape if weight_type == "gate_up" else self.down_proj_scale_shape
scale_dtype = self.default_dtype
# 1.init shape and type
# weight
weight_name = self.added_weight_attrs[weight_id_map[weight_type]]
unquantized_weight_name = weight_name.replace("quant_weight", "weight")
weight_shape = self.up_gate_proj_weight_shape if weight_type == "gate_up" else self.down_proj_weight_shape
weight_shape[1], weight_shape[2] = weight_shape[2], weight_shape[1]
weight_dtype = "int8"
# scale
scale_name = self.added_scale_attrs[weight_id_map[weight_type]]
scale_shape = self.up_gate_proj_scale_shape if weight_type == "gate_up" else self.down_proj_scale_shape
scale_dtype = self.default_dtype
# 2.create tmp tensor
# 2.create tmp tensor
weight = paddle.empty(transposed_weight_shape, dtype=weight_dtype)
scale = paddle.empty(scale_shape, dtype=scale_dtype)
weight = paddle.empty(weight_shape, dtype=weight_dtype)
scale = paddle.empty(scale_shape, dtype=scale_dtype)
# 3.quantize weight
# 3.quantize weight
for expert_id in range(layer.num_local_experts):
weight[expert_id], scale[expert_id] = weight_quantize(
getattr(layer, unquantized_weight_name)[expert_id],
algo=self.moe_quant_type,
arch=self.weight_only_linear_arch,
)
for expert_id in range(layer.num_local_experts):
weight[expert_id], scale[expert_id] = weight_quantize(
getattr(layer, unquantized_weight_name)[expert_id],
algo=self.moe_quant_type,
arch=self.weight_only_linear_arch,
)
free_tensor(getattr(layer, unquantized_weight_name))
setattr(
layer,
weight_name,
layer.create_parameter(
shape=weight_shape,
dtype=weight_dtype,
default_initializer=paddle.nn.initializer.Constant(0),
),
)
# create scale
setattr(
layer,
scale_name,
layer.create_parameter(
shape=scale_shape,
dtype=scale_dtype,
default_initializer=paddle.nn.initializer.Constant(0),
),
)
getattr(layer, weight_name).copy_(weight.transpose([0, 2, 1]), False)
getattr(layer, scale_name).copy_(scale, False)
free_tensor(getattr(layer, unquantized_weight_name))
if self.quant_config.is_checkpoint_bf16:
weight_id_map = {"gate_up": 0, "down": 1}
if weight_fully_copied(layer.up_gate_proj_weight):
weight_type = "gate_up"
else:
weight_type = "down"
# create weight
setattr(
layer,
weight_name,
layer.create_parameter(
shape=weight_shape,
dtype=weight_dtype,
default_initializer=paddle.nn.initializer.Constant(0),
),
)
# create scale
setattr(
layer,
scale_name,
layer.create_parameter(
shape=scale_shape,
dtype=scale_dtype,
default_initializer=paddle.nn.initializer.Constant(0),
),
)
getattr(layer, weight_name).copy_(weight, False)
getattr(layer, scale_name).copy_(scale, False)
if self.model_format == "torch":
unquantized_weight_name = self.added_weight_attrs[weight_id_map[weight_type]].replace(
"quant_weight", "weight"
)
process_weight_transpose(layer, unquantized_weight_name)
_process_quantize(weight_id_map[weight_type])
else:
return
def process_loaded_weights(self, layer: nn.Layer, state_dict):
"""

View File

@@ -59,10 +59,10 @@ def get_moe_method():
elif current_platform.is_maca():
from fastdeploy.model_executor.layers.backends import (
MetaxCutlassWeightOnlyMoEMethod,
MetaxCutlassUnquantizedFusedMoEMethod,
)
return MetaxCutlassWeightOnlyMoEMethod(None)
return MetaxCutlassUnquantizedFusedMoEMethod(None)
raise NotImplementedError
@@ -227,7 +227,7 @@ class FusedMoE(nn.Layer):
return
if hasattr(param, "SHARD_ID_TO_SHARDED_DIM"):
SHARD_ID_TO_SHARDED_DIM = param.SHARD_ID_TO_SHARDED_DIM
elif current_platform.is_cuda() or current_platform.is_iluvatar():
elif current_platform.is_cuda() or current_platform.is_iluvatar() or current_platform.is_maca():
SHARD_ID_TO_SHARDED_DIM = {"gate": 1, "down": 0, "up": 1}
else:
SHARD_ID_TO_SHARDED_DIM = {"gate": 0, "down": 1, "up": 0}
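
For context, a generic illustration (toy code, not FastDeploy's loader) of what the sharded-dim mapping means for tensor-parallel weight loading: the value selects which axis of the full checkpoint tensor is sliced for the local rank.

# Hedged sketch, not FastDeploy source:
import paddle

SHARD_ID_TO_SHARDED_DIM = {"gate": 1, "down": 0, "up": 1}   # the cuda / iluvatar / maca default above
full_weight = paddle.randn([64, 128])                        # invented checkpoint tensor
tp_rank, tp_size = 0, 4

shard_dim = SHARD_ID_TO_SHARDED_DIM["gate"]
shard = paddle.split(full_weight, tp_size, axis=shard_dim)[tp_rank]
assert shard.shape == [64, 32]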