mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-06 00:57:33 +08:00
[V1 Loader] Support MOE parameters create and load for DeepGemm and marlin backend (#3447)
* support deepgemm backend * support marlin backend * remove print * fix process_prequanted_weights
This commit is contained in:
@@ -18,13 +18,26 @@ from fastdeploy.model_executor.layers.moe.fused_moe_backend_base import MoEMetho
|
||||
from fastdeploy.model_executor.layers.moe.fused_moe_cutlass_backend import (
|
||||
CutlassMoEMethod,
|
||||
)
|
||||
from fastdeploy.model_executor.layers.moe.fused_moe_deepgemm_backend import (
|
||||
DeepGemmFusedMoeMethod,
|
||||
)
|
||||
from fastdeploy.model_executor.layers.moe.fused_moe_marlin_backend import (
|
||||
MarlinWeightOnlyMoEMethod,
|
||||
)
|
||||
from fastdeploy.model_executor.layers.moe.fused_moe_triton_backend import (
|
||||
BlockWiseFP8MoEMethod,
|
||||
TensorWiseFP8MoEMethod,
|
||||
TritonWeightOnlyMoEMethod,
|
||||
)
|
||||
|
||||
pre_create_weights_list = (CutlassMoEMethod, TensorWiseFP8MoEMethod, BlockWiseFP8MoEMethod, TritonWeightOnlyMoEMethod)
|
||||
pre_create_weights_list = (
|
||||
CutlassMoEMethod,
|
||||
TensorWiseFP8MoEMethod,
|
||||
BlockWiseFP8MoEMethod,
|
||||
TritonWeightOnlyMoEMethod,
|
||||
DeepGemmFusedMoeMethod,
|
||||
MarlinWeightOnlyMoEMethod,
|
||||
)
|
||||
|
||||
|
||||
def is_supported_moe_backend(quant_method: MoEMethodBase):
|
||||
|
@@ -23,7 +23,6 @@ from fastdeploy.distributed.communication import tensor_model_parallel_all_reduc
|
||||
from fastdeploy.model_executor.layers.utils import get_tensor
|
||||
from fastdeploy.model_executor.ops.gpu import count_tokens_per_expert_func, deep_gemm
|
||||
|
||||
from ..utils import create_and_set_parameter
|
||||
from .fused_moe_backend_base import MoEMethodBase
|
||||
|
||||
|
||||
@@ -32,11 +31,73 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
|
||||
DeepGemmFusedMoeMethod is a class that implements the MoEMethodBase interface for DeepGemm backend.
|
||||
"""
|
||||
|
||||
def create_weights(self, layer: nn.Layer, state_dict):
|
||||
def create_weights(self, layer: nn.Layer, **extra_weight_attrs):
|
||||
"""
|
||||
deepgemm create weight process.
|
||||
"""
|
||||
self.weight_dtype = paddle.float8_e4m3fn
|
||||
up_gate_proj_weight_name = self.added_weight_attrs[0]
|
||||
down_proj_weight_name = self.added_weight_attrs[1]
|
||||
self.ffn1_weight_shape = [
|
||||
layer.num_local_experts,
|
||||
layer.moe_intermediate_size * 2,
|
||||
layer.hidden_size,
|
||||
]
|
||||
self.ffn2_weight_shape = [
|
||||
layer.num_local_experts,
|
||||
layer.hidden_size,
|
||||
layer.moe_intermediate_size,
|
||||
]
|
||||
setattr(
|
||||
layer,
|
||||
up_gate_proj_weight_name,
|
||||
layer.create_parameter(
|
||||
shape=self.ffn1_weight_shape,
|
||||
dtype=self.weight_dtype,
|
||||
default_initializer=paddle.nn.initializer.Constant(0),
|
||||
),
|
||||
)
|
||||
setattr(
|
||||
layer,
|
||||
down_proj_weight_name,
|
||||
layer.create_parameter(
|
||||
shape=self.ffn2_weight_shape,
|
||||
dtype=self.weight_dtype,
|
||||
default_initializer=paddle.nn.initializer.Constant(0),
|
||||
),
|
||||
)
|
||||
# weight_scale
|
||||
setattr(
|
||||
layer,
|
||||
self.added_scale_attrs[0],
|
||||
layer.create_parameter(
|
||||
shape=[
|
||||
layer.num_local_experts,
|
||||
layer.moe_intermediate_size * 2 // self.quant_config.weight_block_size[0],
|
||||
layer.hidden_size // self.quant_config.weight_block_size[1],
|
||||
],
|
||||
dtype="float32",
|
||||
default_initializer=paddle.nn.initializer.Constant(0),
|
||||
),
|
||||
)
|
||||
setattr(
|
||||
layer,
|
||||
self.added_scale_attrs[1],
|
||||
layer.create_parameter(
|
||||
shape=[
|
||||
layer.num_local_experts,
|
||||
layer.hidden_size // self.quant_config.weight_block_size[0],
|
||||
layer.moe_intermediate_size // self.quant_config.weight_block_size[1],
|
||||
],
|
||||
dtype="float32",
|
||||
default_initializer=paddle.nn.initializer.Constant(0),
|
||||
),
|
||||
)
|
||||
|
||||
def process_loaded_weights(self, layer: nn.Layer, state_dict):
|
||||
"""
|
||||
deepgemm create weight process.
|
||||
"""
|
||||
up_gate_proj_weights, down_proj_weights = layer.extract_moe_ffn_weights(state_dict)
|
||||
|
||||
self.check(layer, up_gate_proj_weights, down_proj_weights)
|
||||
@@ -56,11 +117,11 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
|
||||
weight_scale_list.append(scale)
|
||||
quanted_weight = paddle.stack(weight_list, axis=0)
|
||||
quanted_weight = quanted_weight.transpose([0, 2, 1]).contiguous()
|
||||
create_and_set_parameter(layer, weight_name, quanted_weight)
|
||||
getattr(layer, weight_name).copy_(quanted_weight, False)
|
||||
|
||||
quanted_weight_scale = paddle.stack(weight_scale_list, axis=0)
|
||||
quanted_weight_scale = quanted_weight_scale.transpose([0, 2, 1]).contiguous()
|
||||
create_and_set_parameter(layer, scale_name, quanted_weight_scale)
|
||||
getattr(layer, scale_name).set_value(quanted_weight_scale)
|
||||
|
||||
def process_prequanted_weights(self, layer: nn.Layer, state_dict):
|
||||
"""
|
||||
@@ -120,7 +181,7 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
|
||||
"down_proj_weight_scale": down_proj_weight_scale,
|
||||
}
|
||||
for name, tensor in name_tensor_map.items():
|
||||
create_and_set_parameter(layer, name, tensor)
|
||||
getattr(layer, name).set_value(tensor)
|
||||
|
||||
def apply_ep_prefill(
|
||||
self,
|
||||
|
@@ -139,9 +139,63 @@ class MarlinWeightOnlyMoEMethod(QuantMethodBase):
|
||||
]
|
||||
self.added_zeros_attrs = ["zeros0", "zeros1"]
|
||||
|
||||
def create_weights(self, layer: nn.Layer, state_dict):
|
||||
def create_weights(self, layer: nn.Layer, **extra_weight_attrs):
|
||||
self.default_dtype = layer._helper.get_default_dtype()
|
||||
self.weight_dtype = "int32"
|
||||
|
||||
up_gate_proj_weight_name = self.added_weight_attrs[0]
|
||||
down_proj_weight_name = self.added_weight_attrs[1]
|
||||
self.ffn1_weight_shape = [
|
||||
layer.num_local_experts,
|
||||
layer.hidden_size // 16,
|
||||
layer.moe_intermediate_size * 4,
|
||||
]
|
||||
self.ffn2_weight_shape = [
|
||||
layer.num_local_experts,
|
||||
layer.moe_intermediate_size // 16,
|
||||
layer.hidden_size * 2,
|
||||
]
|
||||
setattr(
|
||||
layer,
|
||||
up_gate_proj_weight_name,
|
||||
layer.create_parameter(
|
||||
shape=self.ffn1_weight_shape,
|
||||
dtype=self.weight_dtype,
|
||||
default_initializer=paddle.nn.initializer.Constant(0),
|
||||
),
|
||||
)
|
||||
setattr(
|
||||
layer,
|
||||
down_proj_weight_name,
|
||||
layer.create_parameter(
|
||||
shape=self.ffn2_weight_shape,
|
||||
dtype=self.weight_dtype,
|
||||
default_initializer=paddle.nn.initializer.Constant(0),
|
||||
),
|
||||
)
|
||||
# weight_scale
|
||||
setattr(
|
||||
layer,
|
||||
self.added_scale_attrs[0],
|
||||
layer.create_parameter(
|
||||
shape=[layer.num_local_experts, 1, layer.moe_intermediate_size * 2],
|
||||
dtype=self.default_dtype,
|
||||
default_initializer=paddle.nn.initializer.Constant(0),
|
||||
),
|
||||
)
|
||||
setattr(
|
||||
layer,
|
||||
self.added_scale_attrs[1],
|
||||
layer.create_parameter(
|
||||
shape=[layer.num_local_experts, 1, layer.hidden_size],
|
||||
dtype=self.default_dtype,
|
||||
default_initializer=paddle.nn.initializer.Constant(0),
|
||||
),
|
||||
)
|
||||
|
||||
def process_loaded_weights(self, layer: nn.Layer, state_dict):
|
||||
"""
|
||||
Marlin MoE create weight process.
|
||||
Marlin MoE load weight process.
|
||||
"""
|
||||
up_gate_proj_weights, down_proj_weights = layer.extract_moe_ffn_weights(state_dict)
|
||||
assert len(up_gate_proj_weights) == layer.num_local_experts
|
||||
@@ -204,15 +258,6 @@ class MarlinWeightOnlyMoEMethod(QuantMethodBase):
|
||||
(weight_name, quanted_weight),
|
||||
(scale_name, weight_scale),
|
||||
]:
|
||||
setattr(
|
||||
layer,
|
||||
name,
|
||||
layer.create_parameter(
|
||||
shape=tensor.shape,
|
||||
dtype=tensor.dtype,
|
||||
default_initializer=paddle.nn.initializer.Constant(0),
|
||||
),
|
||||
)
|
||||
getattr(layer, name).set_value(tensor)
|
||||
|
||||
def apply(
|
||||
|
@@ -630,7 +630,7 @@ class BlockWiseFP8MoEMethod(QuantMethodBase):
|
||||
layer,
|
||||
down_proj_weight_name,
|
||||
layer.create_parameter(
|
||||
shape=self.ffn1_weight_shape,
|
||||
shape=self.ffn2_weight_shape,
|
||||
dtype=self.weight_dtype,
|
||||
default_initializer=paddle.nn.initializer.Constant(0),
|
||||
),
|
||||
|
Reference in New Issue
Block a user