Move create_parameters to __init__ in FuseMOE for CultassBackend and TritonBackend (#3148)

* w4a8 bug

* fix w4a8 bug

* remove code

* modify the triton backend

* fix ep

* fix the bug with tensor_wise_fp8 in triton backend

* fix the RL

* fix bug by merge

* fix the bug in w4a8

* fix the tensor_wise_fp8 bug

* fix RL
This commit is contained in:
Zero Rains
2025-08-08 15:55:47 +08:00
committed by GitHub
parent d0e9a70380
commit ce1f353c70
10 changed files with 444 additions and 83 deletions

View File

@@ -59,8 +59,8 @@ class ExpertService:
self.cfg.disaggregate_info = None self.cfg.disaggregate_info = None
self.scheduler = cfg.scheduler_config.scheduler() self.scheduler = cfg.scheduler_config.scheduler()
if cfg.splitwise_role != "mixed":
self.scheduler.reset_nodeid(f"{self.scheduler.infer.nodeid}_{local_data_parallel_id!s}") self.scheduler.reset_nodeid(f"{self.scheduler.infer.nodeid}_{local_data_parallel_id!s}")
self.cfg.parallel_config.local_data_parallel_id = local_data_parallel_id self.cfg.parallel_config.local_data_parallel_id = local_data_parallel_id

View File

@@ -0,0 +1,31 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
from fastdeploy.model_executor.layers.moe.fused_moe_backend_base import MoEMethodBase
from fastdeploy.model_executor.layers.moe.fused_moe_cutlass_backend import (
CutlassMoEMethod,
)
from fastdeploy.model_executor.layers.moe.fused_moe_triton_backend import (
BlockWiseFP8MoEMethod,
TensorWiseFP8MoEMethod,
TritonWeightOnlyMoEMethod,
)
pre_create_weights_list = (CutlassMoEMethod, TensorWiseFP8MoEMethod, BlockWiseFP8MoEMethod, TritonWeightOnlyMoEMethod)
def is_supported_moe_backend(quant_method: MoEMethodBase):
return isinstance(quant_method, pre_create_weights_list)

View File

@@ -19,7 +19,7 @@ from abc import abstractmethod
import paddle import paddle
from paddle import nn from paddle import nn
from fastdeploy.model_executor.models.utils import set_weight_attrs from fastdeploy.model_executor.layers.utils import set_weight_attrs
from fastdeploy.platforms import current_platform from fastdeploy.platforms import current_platform
from ..quantization.quant_base import QuantMethodBase from ..quantization.quant_base import QuantMethodBase

View File

@@ -23,7 +23,7 @@ import fastdeploy
from fastdeploy.distributed.communication import tensor_model_parallel_all_reduce from fastdeploy.distributed.communication import tensor_model_parallel_all_reduce
from fastdeploy.platforms import current_platform from fastdeploy.platforms import current_platform
from ..utils import create_and_set_parameter, get_tensor from ..utils import get_tensor
from .fused_moe_backend_base import UnquantizedFusedMoEMethod from .fused_moe_backend_base import UnquantizedFusedMoEMethod
if current_platform.is_cuda(): if current_platform.is_cuda():
@@ -202,7 +202,10 @@ class CutlassMoEMethod(UnquantizedFusedMoEMethod):
gate_out = gate(x.cast("float32")) gate_out = gate(x.cast("float32"))
# 1. Select topk experts and weights # 1. Select topk experts and weights
topk_idx, topk_weights = self.ep_decoder_runner.moe_select(layer, gate_out) topk_idx, topk_weights = self.ep_decoder_runner.moe_select(layer, gate_out)
expertwise_scale = getattr(layer, "up_gate_proj_in_scale_all_experts", None) expertwise_scale = None
if hasattr(layer, "up_gate_proj_in_scale_all_experts"): # only use in w4a8
expertwise_scale = getattr(layer, "up_gate_proj_in_scale_all_experts", None)
# 2. EP Dispatch # 2. EP Dispatch
permute_input, token_nums_per_expert, handle = self.ep_decoder_runner.dispatch( permute_input, token_nums_per_expert, handle = self.ep_decoder_runner.dispatch(
x, topk_idx, topk_weights, expertwise_scale=expertwise_scale x, topk_idx, topk_weights, expertwise_scale=expertwise_scale
@@ -382,12 +385,48 @@ class CutlassW4A8MoEMethod(CutlassMoEMethod):
"down_proj_in_scale": down_proj_in_scale, "down_proj_in_scale": down_proj_in_scale,
} }
for name, tensor in name_tensor_map.items(): for name, tensor in name_tensor_map.items():
create_and_set_parameter(layer, name, tensor) getattr(layer, name).set_value(tensor)
def create_weights(self, layer: nn.Layer, state_dict): def create_weights(self, layer: nn.Layer, **extra_weight_attrs):
""" """
Paddle cutlass create weight process. Paddle cutlass create weight process.
""" """
self.weight_dtype = "int8"
self.ffn1_weight_shape = [
layer.num_local_experts,
layer.hidden_size // 2,
layer.moe_intermediate_size * 2,
]
self.ffn2_weight_shape = [
layer.num_local_experts,
layer.moe_intermediate_size // 2,
layer.hidden_size,
]
setattr(
layer,
self.added_weight_attrs[0],
layer.create_parameter(
shape=self.ffn1_weight_shape,
dtype=self.weight_dtype,
default_initializer=paddle.nn.initializer.Constant(0),
),
)
setattr(
layer,
self.added_weight_attrs[1],
layer.create_parameter(
shape=self.ffn2_weight_shape,
dtype=self.weight_dtype,
default_initializer=paddle.nn.initializer.Constant(0),
),
)
self.create_w4a8_scale_weights(layer, layer.weight_key_map)
def process_loaded_weights(self, layer: nn.Layer, state_dict):
"""
Paddle cutlass load weight process.
"""
up_gate_proj_weights, down_proj_weights = layer.extract_moe_ffn_weights(state_dict) up_gate_proj_weights, down_proj_weights = layer.extract_moe_ffn_weights(state_dict)
self.check(layer, up_gate_proj_weights, down_proj_weights) self.check(layer, up_gate_proj_weights, down_proj_weights)
for idx, weight_tensor in enumerate([up_gate_proj_weights, down_proj_weights]): for idx, weight_tensor in enumerate([up_gate_proj_weights, down_proj_weights]):
@@ -397,11 +436,63 @@ class CutlassW4A8MoEMethod(CutlassMoEMethod):
quant_weight, scale = weight_quantize(weight_tensor[i], algo=self.moe_quant_type, arch=80) quant_weight, scale = weight_quantize(weight_tensor[i], algo=self.moe_quant_type, arch=80)
weight_list.append(quant_weight) weight_list.append(quant_weight)
quanted_weight = paddle.stack(weight_list, axis=0) quanted_weight = paddle.stack(weight_list, axis=0)
create_and_set_parameter(layer, weight_name, quanted_weight) getattr(layer, weight_name).set_value(quanted_weight)
self.create_w4a8_scale_weights(layer, layer.weight_key_map, state_dict) self.load_w4a8_scale_weights(layer, layer.weight_key_map, state_dict)
def create_w4a8_scale_weights(self, layer: nn.Layer, weight_key_map: dict, state_dict: dict): def create_w4a8_scale_weights(self, layer: nn.Layer, weight_key_map: dict):
"""
Get w4a8 weights from state dict and process them.
Args:
layer (nn.Layer): The layer to add parameters to.
weight_key_map (dict): The weight key map.
state_dict (dict): The state dict.
"""
self.default_dtype = layer._helper.get_default_dtype()
if layer.ep_size > 1:
setattr(
layer,
"up_gate_proj_in_scale_all_experts",
layer.create_parameter(
shape=[layer.num_experts],
dtype="float32",
default_initializer=paddle.nn.initializer.Constant(0),
),
)
# in_scales
for in_scale_name in ["up_gate_proj_in_scale", "down_proj_in_scale"]:
setattr(
layer,
in_scale_name,
layer.create_parameter(
shape=[layer.num_local_experts],
dtype="float32",
default_initializer=paddle.nn.initializer.Constant(0),
),
)
# weight_scales
setattr(
layer,
"up_gate_proj_weight_scale",
layer.create_parameter(
shape=[layer.num_local_experts, layer.moe_intermediate_size * 2],
dtype=self.default_dtype,
default_initializer=paddle.nn.initializer.Constant(0),
),
)
setattr(
layer,
"down_proj_weight_scale",
layer.create_parameter(
shape=[layer.num_local_experts, layer.hidden_size],
dtype=self.default_dtype,
default_initializer=paddle.nn.initializer.Constant(0),
),
)
def load_w4a8_scale_weights(self, layer: nn.Layer, weight_key_map: dict, state_dict: dict):
""" """
Get w4a8 weights from state dict and process them. Get w4a8 weights from state dict and process them.
Args: Args:
@@ -415,7 +506,7 @@ class CutlassW4A8MoEMethod(CutlassMoEMethod):
def _process_in_scale(name: str, in_scales: list[paddle.Tensor]): def _process_in_scale(name: str, in_scales: list[paddle.Tensor]):
processed_in_scale = 1 / paddle.concat(in_scales) processed_in_scale = 1 / paddle.concat(in_scales)
create_and_set_parameter(layer, name, processed_in_scale) getattr(layer, name).set_value(processed_in_scale)
return processed_in_scale return processed_in_scale
def _process_weight_scale( def _process_weight_scale(
@@ -426,7 +517,7 @@ class CutlassW4A8MoEMethod(CutlassMoEMethod):
processed_weight_scale = ( processed_weight_scale = (
paddle.stack(weight_scales, axis=0) / (127 * 112) / processed_in_scale[:, None] paddle.stack(weight_scales, axis=0) / (127 * 112) / processed_in_scale[:, None]
).cast(paddle.get_default_dtype()) ).cast(paddle.get_default_dtype())
create_and_set_parameter(layer, name, processed_weight_scale) getattr(layer, name).set_value(processed_weight_scale)
# 1. Init scale containers and maps # 1. Init scale containers and maps
up_gate_proj_weight_scales = [] up_gate_proj_weight_scales = []
@@ -456,8 +547,8 @@ class CutlassW4A8MoEMethod(CutlassMoEMethod):
for expert_idx in range(layer.num_experts): for expert_idx in range(layer.num_experts):
scale_tensor = get_tensor(state_dict[scale_key_map["up_gate_proj_in_scale"].format(expert_idx)]) scale_tensor = get_tensor(state_dict[scale_key_map["up_gate_proj_in_scale"].format(expert_idx)])
up_gate_proj_in_scales_all_experts.append(1 / scale_tensor) up_gate_proj_in_scales_all_experts.append(1 / scale_tensor)
create_and_set_parameter( getattr(layer, "up_gate_proj_in_scale_all_experts").set_value(
layer, "up_gate_proj_in_scale_all_experts", paddle.concat(up_gate_proj_in_scales_all_experts) paddle.concat(up_gate_proj_in_scales_all_experts)
) )
for local_expert_idx in range(layer.num_local_experts): for local_expert_idx in range(layer.num_local_experts):
@@ -527,15 +618,85 @@ class CutlassWeightOnlyMoEMethod(CutlassMoEMethod):
"down_proj_weight_scale": down_proj_weight_scale, "down_proj_weight_scale": down_proj_weight_scale,
} }
for name, tensor in name_tensor_map.items(): for name, tensor in name_tensor_map.items():
create_and_set_parameter(layer, name, tensor) getattr(layer, name).set_value(tensor)
def create_weights(self, layer: nn.Layer, state_dict): def create_weights(self, layer: nn.Layer, **extra_weight_attrs):
""" """
Paddle cutlass create weight process. Paddle cutlass create weight process.
""" """
self.default_dtype = layer._helper.get_default_dtype()
self.weight_dtype = "int8"
up_gate_proj_weight_name = self.added_weight_attrs[0]
down_proj_weight_name = self.added_weight_attrs[1]
if self.moe_quant_type == "weight_only_int4":
self.ffn1_weight_shape = [
layer.num_local_experts,
layer.moe_intermediate_size,
layer.hidden_size,
]
else:
self.ffn1_weight_shape = [
layer.num_local_experts,
layer.moe_intermediate_size * 2,
layer.hidden_size,
]
if self.moe_quant_type == "weight_only_int4":
self.ffn2_weight_shape = [
layer.num_local_experts,
layer.hidden_size // 2,
layer.moe_intermediate_size,
]
else:
self.ffn2_weight_shape = [
layer.num_local_experts,
layer.hidden_size,
layer.moe_intermediate_size,
]
setattr(
layer,
up_gate_proj_weight_name,
layer.create_parameter(
shape=self.ffn1_weight_shape,
dtype=self.weight_dtype,
default_initializer=paddle.nn.initializer.Constant(0),
),
)
setattr(
layer,
down_proj_weight_name,
layer.create_parameter(
shape=self.ffn2_weight_shape,
dtype=self.weight_dtype,
default_initializer=paddle.nn.initializer.Constant(0),
),
)
# weight_scale
setattr(
layer,
self.added_scale_attrs[0],
layer.create_parameter(
shape=[layer.num_local_experts, layer.moe_intermediate_size * 2],
dtype=self.default_dtype,
default_initializer=paddle.nn.initializer.Constant(0),
),
)
setattr(
layer,
self.added_scale_attrs[1],
layer.create_parameter(
shape=[layer.num_local_experts, layer.hidden_size],
dtype=self.default_dtype,
default_initializer=paddle.nn.initializer.Constant(0),
),
)
def process_loaded_weights(self, layer: nn.Layer, state_dict):
"""
Paddle cutlass load weight process.
"""
up_gate_proj_weights, down_proj_weights = layer.extract_moe_ffn_weights(state_dict) up_gate_proj_weights, down_proj_weights = layer.extract_moe_ffn_weights(state_dict)
self.check(layer, up_gate_proj_weights, down_proj_weights) self.check(layer, up_gate_proj_weights, down_proj_weights)
for idx, weight_tensor in enumerate([up_gate_proj_weights, down_proj_weights]): for idx, weight_tensor in enumerate([up_gate_proj_weights, down_proj_weights]):
weight_name = self.added_weight_attrs[idx] weight_name = self.added_weight_attrs[idx]
scale_name = self.added_scale_attrs[idx] scale_name = self.added_scale_attrs[idx]
@@ -547,7 +708,7 @@ class CutlassWeightOnlyMoEMethod(CutlassMoEMethod):
weight_list.append(quant_weight) weight_list.append(quant_weight)
weight_scale_list.append(scale) weight_scale_list.append(scale)
quanted_weight = paddle.stack(weight_list, axis=0) quanted_weight = paddle.stack(weight_list, axis=0)
create_and_set_parameter(layer, weight_name, quanted_weight) getattr(layer, weight_name).set_value(quanted_weight)
quanted_weight_scale = paddle.stack(weight_scale_list, axis=0) quanted_weight_scale = paddle.stack(weight_scale_list, axis=0)
create_and_set_parameter(layer, scale_name, quanted_weight_scale) getattr(layer, scale_name).set_value(quanted_weight_scale)

View File

@@ -19,7 +19,7 @@ from paddle import nn
import fastdeploy import fastdeploy
from fastdeploy.distributed.communication import tensor_model_parallel_all_reduce from fastdeploy.distributed.communication import tensor_model_parallel_all_reduce
from fastdeploy.model_executor.layers.utils import create_and_set_parameter, get_tensor from fastdeploy.model_executor.layers.utils import get_tensor
from fastdeploy.utils import ceil_div from fastdeploy.utils import ceil_div
from ..quantization.quant_base import QuantMethodBase from ..quantization.quant_base import QuantMethodBase
@@ -52,10 +52,66 @@ class TritonWeightOnlyMoEMethod(QuantMethodBase):
"""process_prequanted_weights""" """process_prequanted_weights"""
pass pass
def create_weights(self, layer: nn.Layer, state_dict): def create_weights(self, layer: nn.Layer, **extra_weight_attrs):
""" """
Triton MoE create weight process. Triton MoE create weight process.
""" """
self.weight_dtype = "int8"
self.default_dtype = layer._helper.get_default_dtype()
up_gate_proj_weight_name = self.added_weight_attrs[0]
down_proj_weight_name = self.added_weight_attrs[1]
self.ffn1_weight_shape = [
layer.num_local_experts,
layer.hidden_size,
layer.moe_intermediate_size * 2,
]
self.ffn2_weight_shape = [
layer.num_local_experts,
layer.moe_intermediate_size,
layer.hidden_size,
]
setattr(
layer,
up_gate_proj_weight_name,
layer.create_parameter(
shape=self.ffn1_weight_shape,
dtype=self.weight_dtype,
default_initializer=paddle.nn.initializer.Constant(0),
),
)
setattr(
layer,
down_proj_weight_name,
layer.create_parameter(
shape=self.ffn2_weight_shape,
dtype=self.weight_dtype,
default_initializer=paddle.nn.initializer.Constant(0),
),
)
# weight_scale
setattr(
layer,
self.added_scale_attrs[0],
layer.create_parameter(
shape=[layer.num_local_experts, layer.moe_intermediate_size * 2],
dtype=self.default_dtype,
default_initializer=paddle.nn.initializer.Constant(0),
),
)
setattr(
layer,
self.added_scale_attrs[1],
layer.create_parameter(
shape=[layer.num_local_experts, layer.hidden_size],
dtype=self.default_dtype,
default_initializer=paddle.nn.initializer.Constant(0),
),
)
def process_loaded_weights(self, layer: nn.Layer, state_dict):
"""
Triton MoE load weight process.
"""
up_gate_proj_weights, down_proj_weights = layer.extract_moe_ffn_weights(state_dict) up_gate_proj_weights, down_proj_weights = layer.extract_moe_ffn_weights(state_dict)
assert len(up_gate_proj_weights) == layer.num_local_experts assert len(up_gate_proj_weights) == layer.num_local_experts
assert len(down_proj_weights) == layer.num_local_experts assert len(down_proj_weights) == layer.num_local_experts
@@ -90,25 +146,7 @@ class TritonWeightOnlyMoEMethod(QuantMethodBase):
quanted_weight = paddle.round(quanted_weight).astype("int8") quanted_weight = paddle.round(quanted_weight).astype("int8")
quanted_weight_scale = quanted_weight_scale / max_bound quanted_weight_scale = quanted_weight_scale / max_bound
setattr(
layer,
weight_name,
layer.create_parameter(
shape=quanted_weight.shape,
dtype=quanted_weight.dtype,
default_initializer=paddle.nn.initializer.Constant(0),
),
)
getattr(layer, weight_name).set_value(quanted_weight) getattr(layer, weight_name).set_value(quanted_weight)
setattr(
layer,
scale_name,
layer.create_parameter(
shape=quanted_weight_scale.shape,
dtype=quanted_weight_scale.dtype,
),
)
getattr(layer, scale_name).set_value(quanted_weight_scale) getattr(layer, scale_name).set_value(quanted_weight_scale)
def apply( def apply(
@@ -264,6 +302,14 @@ class TensorWiseFP8MoEMethod(QuantMethodBase):
Triton Group Gemm to compute Fused MoE. Triton Group Gemm to compute Fused MoE.
""" """
self.quant_method = quant_method self.quant_method = quant_method
self.added_wfp8afp8_attrs = [
"up_gate_proj_weight",
"down_proj_weight",
"up_gate_proj_weight_scale",
"down_proj_weight_scale",
"up_gate_proj_in_scale",
"down_proj_in_scale",
]
def process_prequanted_weights(self, layer: nn.Layer, state_dict) -> None: def process_prequanted_weights(self, layer: nn.Layer, state_dict) -> None:
"""process_prequanted_weights""" """process_prequanted_weights"""
@@ -281,15 +327,6 @@ class TensorWiseFP8MoEMethod(QuantMethodBase):
up_gate_proj_tensor = paddle.stack(up_gate_proj_tensor, axis=0).view(paddle.float8_e4m3fn) up_gate_proj_tensor = paddle.stack(up_gate_proj_tensor, axis=0).view(paddle.float8_e4m3fn)
down_proj_tensor = paddle.stack(down_proj_tensor, axis=0).view(paddle.float8_e4m3fn) down_proj_tensor = paddle.stack(down_proj_tensor, axis=0).view(paddle.float8_e4m3fn)
added_wfp8afp8_attrs = [
"up_gate_proj_weight",
"down_proj_weight",
"up_gate_proj_weight_scale",
"down_proj_weight_scale",
"up_gate_proj_in_scale",
"down_proj_in_scale",
]
def _extract_scale_tensor(key_template): def _extract_scale_tensor(key_template):
result = [] result = []
for i in range(layer.num_experts): for i in range(layer.num_experts):
@@ -312,26 +349,58 @@ class TensorWiseFP8MoEMethod(QuantMethodBase):
down_proj_in_scale, down_proj_in_scale,
] ]
): ):
name = added_wfp8afp8_attrs[idx] name = self.added_wfp8afp8_attrs[idx]
setattr(
layer,
name,
layer.create_parameter(
shape=weight_tensor.shape,
dtype=weight_tensor.dtype,
default_initializer=paddle.nn.initializer.Constant(0),
),
)
if weight_tensor.dtype == paddle.float8_e4m3fn: if weight_tensor.dtype == paddle.float8_e4m3fn:
getattr(layer, name).copy_(weight_tensor, False) getattr(layer, name).copy_(weight_tensor, False)
else: else:
getattr(layer, name).set_value(weight_tensor) getattr(layer, name).set_value(weight_tensor)
def create_weights(self, layer: nn.Layer, state_dict): def create_weights(self, layer: nn.Layer, **extra_weight_attrs):
""" """
Triton MoE create weight process. Triton MoE create weight process.
""" """
pass self.weight_dtype = paddle.float8_e4m3fn
self.default_dtype = layer._helper.get_default_dtype()
up_gate_proj_weight_name = self.added_wfp8afp8_attrs[0]
down_proj_weight_name = self.added_wfp8afp8_attrs[1]
self.ffn1_weight_shape = [
layer.num_local_experts,
layer.moe_intermediate_size * 2,
layer.hidden_size,
]
self.ffn2_weight_shape = [
layer.num_local_experts,
layer.hidden_size,
layer.moe_intermediate_size,
]
setattr(
layer,
up_gate_proj_weight_name,
layer.create_parameter(
shape=self.ffn1_weight_shape,
dtype=self.weight_dtype,
default_initializer=paddle.nn.initializer.Constant(0),
),
)
setattr(
layer,
down_proj_weight_name,
layer.create_parameter(
shape=self.ffn2_weight_shape,
dtype=self.weight_dtype,
default_initializer=paddle.nn.initializer.Constant(0),
),
)
for idx in range(2, len(self.added_wfp8afp8_attrs)):
setattr(
layer,
self.added_wfp8afp8_attrs[idx],
layer.create_parameter(
shape=[layer.num_local_experts],
dtype="float32",
default_initializer=paddle.nn.initializer.Constant(0),
),
)
def apply( def apply(
self, self,
@@ -531,14 +600,76 @@ class BlockWiseFP8MoEMethod(QuantMethodBase):
raise NotImplementedError raise NotImplementedError
def create_weights(self, layer: nn.Layer, state_dict): def create_weights(self, layer: nn.Layer, **extra_weight_attrs):
"""
Triton MoE create weight process.
"""
self.weight_dtype = paddle.float8_e4m3fn
up_gate_proj_weight_name = self.added_weight_attrs[0]
down_proj_weight_name = self.added_weight_attrs[1]
self.ffn1_weight_shape = [
layer.num_local_experts,
layer.moe_intermediate_size * 2,
layer.hidden_size,
]
self.ffn2_weight_shape = [
layer.num_local_experts,
layer.hidden_size,
layer.moe_intermediate_size,
]
setattr(
layer,
up_gate_proj_weight_name,
layer.create_parameter(
shape=self.ffn1_weight_shape,
dtype=self.weight_dtype,
default_initializer=paddle.nn.initializer.Constant(0),
),
)
setattr(
layer,
down_proj_weight_name,
layer.create_parameter(
shape=self.ffn1_weight_shape,
dtype=self.weight_dtype,
default_initializer=paddle.nn.initializer.Constant(0),
),
)
# weight_scale
setattr(
layer,
self.added_scale_attrs[0],
layer.create_parameter(
shape=[
layer.num_local_experts,
layer.moe_intermediate_size * 2 // self.quant_config.weight_block_size[0],
layer.hidden_size // self.quant_config.weight_block_size[1],
],
dtype="float32",
default_initializer=paddle.nn.initializer.Constant(0),
),
)
setattr(
layer,
self.added_scale_attrs[1],
layer.create_parameter(
shape=[
layer.num_local_experts,
layer.hidden_size // self.quant_config.weight_block_size[0],
layer.moe_intermediate_size // self.quant_config.weight_block_size[1],
],
dtype="float32",
default_initializer=paddle.nn.initializer.Constant(0),
),
)
def process_loaded_weights(self, layer: nn.Layer, state_dict):
""" """
Triton MoE create weight process. Triton MoE create weight process.
""" """
up_gate_proj_weights, down_proj_weights = layer.extract_moe_ffn_weights(state_dict) up_gate_proj_weights, down_proj_weights = layer.extract_moe_ffn_weights(state_dict)
self.check(layer, up_gate_proj_weights, down_proj_weights) self.check(layer, up_gate_proj_weights, down_proj_weights)
for idx, weight_tensor in enumerate([up_gate_proj_weights, down_proj_weights]): for idx, weight_tensor in enumerate([up_gate_proj_weights, down_proj_weights]):
weight_name = self.added_weight_attrs[idx] weight_name = self.added_weight_attrs[idx]
scale_name = self.added_scale_attrs[idx] scale_name = self.added_scale_attrs[idx]
@@ -554,11 +685,11 @@ class BlockWiseFP8MoEMethod(QuantMethodBase):
weight_scale_list.append(scale) weight_scale_list.append(scale)
quanted_weight = paddle.stack(weight_list, axis=0) quanted_weight = paddle.stack(weight_list, axis=0)
quanted_weight = quanted_weight.transpose([0, 2, 1]).contiguous().view(paddle.float8_e4m3fn) quanted_weight = quanted_weight.transpose([0, 2, 1]).contiguous().view(paddle.float8_e4m3fn)
create_and_set_parameter(layer, weight_name, quanted_weight) getattr(layer, weight_name).copy_(quanted_weight, False)
quanted_weight_scale = paddle.stack(weight_scale_list, axis=0) quanted_weight_scale = paddle.stack(weight_scale_list, axis=0)
quanted_weight_scale = quanted_weight_scale.transpose([0, 2, 1]).contiguous() quanted_weight_scale = quanted_weight_scale.transpose([0, 2, 1]).contiguous()
create_and_set_parameter(layer, scale_name, quanted_weight_scale) getattr(layer, scale_name).set_value(quanted_weight_scale)
def check(self, layer: nn.Layer, up_gate_proj_weights, down_proj_weights): def check(self, layer: nn.Layer, up_gate_proj_weights, down_proj_weights):
""" """

View File

@@ -22,8 +22,14 @@ from paddleformers.utils.log import logger
from fastdeploy import envs from fastdeploy import envs
from fastdeploy.model_executor.layers.utils import get_tensor from fastdeploy.model_executor.layers.utils import get_tensor
from fastdeploy.platforms import current_platform
from fastdeploy.worker.experts_manager import RedundantExpertManger from fastdeploy.worker.experts_manager import RedundantExpertManger
# TODO(lulinjun): remove this import after supporting all backends
is_supported_moe_backend = None
if current_platform.is_cuda():
from .check_backend_supported import is_supported_moe_backend
def get_moe_method(): def get_moe_method():
""" """
@@ -121,10 +127,7 @@ class FusedMoE(nn.Layer):
self.quant_method = moe_quant_config.get_quant_method(self) self.quant_method = moe_quant_config.get_quant_method(self)
self.moe_quant_type = moe_quant_config.name() self.moe_quant_type = moe_quant_config.name()
else: else:
# w_fp16 a_fp16
self.quant_method = get_moe_method() self.quant_method = get_moe_method()
self.quant_method.create_weights(self, weight_loader=self.weight_loader)
self.redundant_table_manger = None self.redundant_table_manger = None
if self.ep_size > 1: if self.ep_size > 1:
if fd_config.model_config.enable_redundant_experts is True: if fd_config.model_config.enable_redundant_experts is True:
@@ -139,6 +142,20 @@ class FusedMoE(nn.Layer):
if fd_config.load_config.dynamic_load_weight: if fd_config.load_config.dynamic_load_weight:
# It's for RL to build model # It's for RL to build model
self.init_moe_weights() self.init_moe_weights()
else:
self.gate_correction_bias_key = self.weight_key_map.get("gate_correction_bias_key", None)
if self.gate_correction_bias_key is not None:
self.gate_correction_bias = self.create_parameter(shape=[1, self.num_experts], dtype="float32")
if moe_quant_config:
if (
moe_quant_config
and is_supported_moe_backend is not None
and is_supported_moe_backend(self.quant_method)
):
self.quant_method.create_weights(self, weight_loader=self.weight_loader)
else:
# w_fp16 a_fp16
self.quant_method.create_weights(self, weight_loader=self.weight_loader)
logger.info( logger.info(
f"{moe_tag}MoE config is {num_experts=}[{expert_id_offset}, {expert_id_offset + self.num_local_experts}), \ f"{moe_tag}MoE config is {num_experts=}[{expert_id_offset}, {expert_id_offset + self.num_local_experts}), \
@@ -475,23 +492,33 @@ class FusedMoE(nn.Layer):
gate_correction_bias_tensor = self.extract_gate_correction_bias( gate_correction_bias_tensor = self.extract_gate_correction_bias(
self.gate_correction_bias_key, state_dict self.gate_correction_bias_key, state_dict
) )
self.gate_correction_bias = self.create_parameter(
shape=gate_correction_bias_tensor.shape,
dtype="float32",
)
self.gate_correction_bias.set_value(gate_correction_bias_tensor) self.gate_correction_bias.set_value(gate_correction_bias_tensor)
else:
self.gate_correction_bias = None
if self.fd_config.model_config.is_quantized:
if getattr(self.fd_config.quant_config, "is_permuted", True):
self.quant_method.process_prequanted_weights(self, state_dict)
else:
self.quant_method.create_weights(self, state_dict)
else: else:
if self.moe_quant_config: self.gate_correction_bias = None
self.quant_method.create_weights(self, state_dict)
if is_supported_moe_backend is not None and is_supported_moe_backend(self.quant_method):
if self.fd_config.model_config.is_quantized:
if getattr(self.fd_config.quant_config, "is_permuted", True):
self.quant_method.process_prequanted_weights(self, state_dict)
else:
self.quant_method.process_loaded_weights(self, state_dict)
else: else:
# w_fp16 a_fp16
self.quant_method.process_loaded_weights(self, state_dict) self.quant_method.process_loaded_weights(self, state_dict)
else:
if self.fd_config.model_config.is_quantized:
if getattr(self.fd_config.quant_config, "is_permuted", True):
self.quant_method.process_prequanted_weights(self, state_dict)
else:
self.quant_method.create_weights(self, state_dict)
else:
if self.moe_quant_config:
self.quant_method.create_weights(self, state_dict)
else:
# w_fp16 a_fp16
self.quant_method.process_loaded_weights(self, state_dict)
def forward(self, x: paddle.Tensor, gate: nn.Layer): def forward(self, x: paddle.Tensor, gate: nn.Layer):
""" """

View File

@@ -82,7 +82,7 @@ class TensorWiseFP8LinearMethod(QuantMethodBase):
self.weight_dtype = "float8_e4m3fn" self.weight_dtype = "float8_e4m3fn"
def create_weights(self, layer, **extra_weight_attrs): def create_weights(self, layer, **extra_weight_attrs):
layer.weight_dtype = "float8_e4m3fn"
layer.weight = layer.create_parameter( layer.weight = layer.create_parameter(
shape=layer.weight_shape, shape=layer.weight_shape,
dtype=layer.weight_dtype, dtype=layer.weight_dtype,

View File

@@ -15,7 +15,7 @@
""" """
import functools import functools
from typing import Tuple, Union from typing import Any, Optional, Tuple, Union
import numpy as np import numpy as np
import paddle import paddle
@@ -45,6 +45,14 @@ if cache_params != "none":
c8_state_dict = paddle.load(cache_params, return_numpy=True) c8_state_dict = paddle.load(cache_params, return_numpy=True)
# TODO(lulinjun): delete it, import from fastdeploy.model_executor.models.utils after supporting all backends
def set_weight_attrs(param, param_attr_map: Optional[dict[str, Any]]):
if param_attr_map is None:
return
for key, value in param_attr_map.items():
setattr(param, key, value)
def per_block_cast_to_fp8(x: Tensor, block_size: list = [128, 128]) -> Tuple[Tensor, Tensor]: def per_block_cast_to_fp8(x: Tensor, block_size: list = [128, 128]) -> Tuple[Tensor, Tensor]:
""" """
Only used in deep_gemm block wise quant weight. Only used in deep_gemm block wise quant weight.

View File

@@ -66,7 +66,7 @@ def load_ep_checkpoint(model_path: str, fd_config: FDConfig, return_numpy: bool
""" """
with open(os.path.join(model_path, "model.safetensors.index.json"), "r") as f: with open(os.path.join(model_path, "model.safetensors.index.json"), "r") as f:
weight_list = json.load(f)["weight_map"] weight_list = json.load(f)["weight_map"]
filtered_map = {k: v for k, v in weight_list.items() if "experts" not in k} filtered_map = {k: v for k, v in weight_list.items() if ".experts." not in k}
num_local_ffn_keys = [] num_local_ffn_keys = []
from itertools import chain from itertools import chain

View File

@@ -424,7 +424,10 @@ class Ernie4_5_MoeForCausalLM(ModelForCasualLM):
""" """
self.ernie.load_state_dict(state_dict) self.ernie.load_state_dict(state_dict)
if self.tie_word_embeddings: if self.tie_word_embeddings:
self.lm_head.linear.weight.set_value(self.ernie.embed_tokens.embeddings.weight.transpose([1, 0])) if hasattr(self.lm_head, "linear"):
self.lm_head.linear.weight.set_value(self.ernie.embed_tokens.embeddings.weight.transpose([1, 0]))
else: # ep
self.lm_head.weight.set_value(self.ernie.embed_tokens.embeddings.weight.transpose([1, 0]))
else: else:
self.lm_head.load_state_dict(state_dict) self.lm_head.load_state_dict(state_dict)