[V1 Loader] Support MOE parameters create and load for DeepGemm and marlin backend (#3447)

* support deepgemm backend

* support marlin backend

* remove print

* fix process_prequanted_weights
This commit is contained in:
Zero Rains
2025-08-19 14:15:53 +08:00
committed by GitHub
parent 6735626014
commit fef447e350
4 changed files with 137 additions and 18 deletions

View File

@@ -18,13 +18,26 @@ from fastdeploy.model_executor.layers.moe.fused_moe_backend_base import MoEMetho
from fastdeploy.model_executor.layers.moe.fused_moe_cutlass_backend import (
CutlassMoEMethod,
)
from fastdeploy.model_executor.layers.moe.fused_moe_deepgemm_backend import (
DeepGemmFusedMoeMethod,
)
from fastdeploy.model_executor.layers.moe.fused_moe_marlin_backend import (
MarlinWeightOnlyMoEMethod,
)
from fastdeploy.model_executor.layers.moe.fused_moe_triton_backend import (
BlockWiseFP8MoEMethod,
TensorWiseFP8MoEMethod,
TritonWeightOnlyMoEMethod,
)
pre_create_weights_list = (CutlassMoEMethod, TensorWiseFP8MoEMethod, BlockWiseFP8MoEMethod, TritonWeightOnlyMoEMethod)
pre_create_weights_list = (
CutlassMoEMethod,
TensorWiseFP8MoEMethod,
BlockWiseFP8MoEMethod,
TritonWeightOnlyMoEMethod,
DeepGemmFusedMoeMethod,
MarlinWeightOnlyMoEMethod,
)
def is_supported_moe_backend(quant_method: MoEMethodBase):

View File

@@ -23,7 +23,6 @@ from fastdeploy.distributed.communication import tensor_model_parallel_all_reduc
from fastdeploy.model_executor.layers.utils import get_tensor
from fastdeploy.model_executor.ops.gpu import count_tokens_per_expert_func, deep_gemm
from ..utils import create_and_set_parameter
from .fused_moe_backend_base import MoEMethodBase
@@ -32,11 +31,73 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
DeepGemmFusedMoeMethod is a class that implements the MoEMethodBase interface for DeepGemm backend.
"""
def create_weights(self, layer: nn.Layer, state_dict):
def create_weights(self, layer: nn.Layer, **extra_weight_attrs):
"""
deepgemm create weight process.
"""
self.weight_dtype = paddle.float8_e4m3fn
up_gate_proj_weight_name = self.added_weight_attrs[0]
down_proj_weight_name = self.added_weight_attrs[1]
self.ffn1_weight_shape = [
layer.num_local_experts,
layer.moe_intermediate_size * 2,
layer.hidden_size,
]
self.ffn2_weight_shape = [
layer.num_local_experts,
layer.hidden_size,
layer.moe_intermediate_size,
]
setattr(
layer,
up_gate_proj_weight_name,
layer.create_parameter(
shape=self.ffn1_weight_shape,
dtype=self.weight_dtype,
default_initializer=paddle.nn.initializer.Constant(0),
),
)
setattr(
layer,
down_proj_weight_name,
layer.create_parameter(
shape=self.ffn2_weight_shape,
dtype=self.weight_dtype,
default_initializer=paddle.nn.initializer.Constant(0),
),
)
# weight_scale
setattr(
layer,
self.added_scale_attrs[0],
layer.create_parameter(
shape=[
layer.num_local_experts,
layer.moe_intermediate_size * 2 // self.quant_config.weight_block_size[0],
layer.hidden_size // self.quant_config.weight_block_size[1],
],
dtype="float32",
default_initializer=paddle.nn.initializer.Constant(0),
),
)
setattr(
layer,
self.added_scale_attrs[1],
layer.create_parameter(
shape=[
layer.num_local_experts,
layer.hidden_size // self.quant_config.weight_block_size[0],
layer.moe_intermediate_size // self.quant_config.weight_block_size[1],
],
dtype="float32",
default_initializer=paddle.nn.initializer.Constant(0),
),
)
def process_loaded_weights(self, layer: nn.Layer, state_dict):
"""
deepgemm create weight process.
"""
up_gate_proj_weights, down_proj_weights = layer.extract_moe_ffn_weights(state_dict)
self.check(layer, up_gate_proj_weights, down_proj_weights)
@@ -56,11 +117,11 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
weight_scale_list.append(scale)
quanted_weight = paddle.stack(weight_list, axis=0)
quanted_weight = quanted_weight.transpose([0, 2, 1]).contiguous()
create_and_set_parameter(layer, weight_name, quanted_weight)
getattr(layer, weight_name).copy_(quanted_weight, False)
quanted_weight_scale = paddle.stack(weight_scale_list, axis=0)
quanted_weight_scale = quanted_weight_scale.transpose([0, 2, 1]).contiguous()
create_and_set_parameter(layer, scale_name, quanted_weight_scale)
getattr(layer, scale_name).set_value(quanted_weight_scale)
def process_prequanted_weights(self, layer: nn.Layer, state_dict):
"""
@@ -120,7 +181,7 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
"down_proj_weight_scale": down_proj_weight_scale,
}
for name, tensor in name_tensor_map.items():
create_and_set_parameter(layer, name, tensor)
getattr(layer, name).set_value(tensor)
def apply_ep_prefill(
self,

View File

@@ -139,9 +139,63 @@ class MarlinWeightOnlyMoEMethod(QuantMethodBase):
]
self.added_zeros_attrs = ["zeros0", "zeros1"]
def create_weights(self, layer: nn.Layer, state_dict):
def create_weights(self, layer: nn.Layer, **extra_weight_attrs):
self.default_dtype = layer._helper.get_default_dtype()
self.weight_dtype = "int32"
up_gate_proj_weight_name = self.added_weight_attrs[0]
down_proj_weight_name = self.added_weight_attrs[1]
self.ffn1_weight_shape = [
layer.num_local_experts,
layer.hidden_size // 16,
layer.moe_intermediate_size * 4,
]
self.ffn2_weight_shape = [
layer.num_local_experts,
layer.moe_intermediate_size // 16,
layer.hidden_size * 2,
]
setattr(
layer,
up_gate_proj_weight_name,
layer.create_parameter(
shape=self.ffn1_weight_shape,
dtype=self.weight_dtype,
default_initializer=paddle.nn.initializer.Constant(0),
),
)
setattr(
layer,
down_proj_weight_name,
layer.create_parameter(
shape=self.ffn2_weight_shape,
dtype=self.weight_dtype,
default_initializer=paddle.nn.initializer.Constant(0),
),
)
# weight_scale
setattr(
layer,
self.added_scale_attrs[0],
layer.create_parameter(
shape=[layer.num_local_experts, 1, layer.moe_intermediate_size * 2],
dtype=self.default_dtype,
default_initializer=paddle.nn.initializer.Constant(0),
),
)
setattr(
layer,
self.added_scale_attrs[1],
layer.create_parameter(
shape=[layer.num_local_experts, 1, layer.hidden_size],
dtype=self.default_dtype,
default_initializer=paddle.nn.initializer.Constant(0),
),
)
def process_loaded_weights(self, layer: nn.Layer, state_dict):
"""
Marlin MoE create weight process.
Marlin MoE load weight process.
"""
up_gate_proj_weights, down_proj_weights = layer.extract_moe_ffn_weights(state_dict)
assert len(up_gate_proj_weights) == layer.num_local_experts
@@ -204,15 +258,6 @@ class MarlinWeightOnlyMoEMethod(QuantMethodBase):
(weight_name, quanted_weight),
(scale_name, weight_scale),
]:
setattr(
layer,
name,
layer.create_parameter(
shape=tensor.shape,
dtype=tensor.dtype,
default_initializer=paddle.nn.initializer.Constant(0),
),
)
getattr(layer, name).set_value(tensor)
def apply(

View File

@@ -630,7 +630,7 @@ class BlockWiseFP8MoEMethod(QuantMethodBase):
layer,
down_proj_weight_name,
layer.create_parameter(
shape=self.ffn1_weight_shape,
shape=self.ffn2_weight_shape,
dtype=self.weight_dtype,
default_initializer=paddle.nn.initializer.Constant(0),
),