mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-06 09:07:10 +08:00
Move create_parameters to __init__ in FuseMOE for CultassBackend and TritonBackend (#3148)
* w4a8 bug * fix w4a8 bug * remove code * modify the triton backend * fix ep * fix the bug with tensor_wise_fp8 in triton backend * fix the RL * fix bug by merge * fix the bug in w4a8 * fix the tensor_wise_fp8 bug * fix RL
This commit is contained in:
@@ -59,7 +59,7 @@ class ExpertService:
|
||||
self.cfg.disaggregate_info = None
|
||||
|
||||
self.scheduler = cfg.scheduler_config.scheduler()
|
||||
|
||||
if cfg.splitwise_role != "mixed":
|
||||
self.scheduler.reset_nodeid(f"{self.scheduler.infer.nodeid}_{local_data_parallel_id!s}")
|
||||
|
||||
self.cfg.parallel_config.local_data_parallel_id = local_data_parallel_id
|
||||
|
@@ -0,0 +1,31 @@
|
||||
"""
|
||||
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
|
||||
from fastdeploy.model_executor.layers.moe.fused_moe_backend_base import MoEMethodBase
|
||||
from fastdeploy.model_executor.layers.moe.fused_moe_cutlass_backend import (
|
||||
CutlassMoEMethod,
|
||||
)
|
||||
from fastdeploy.model_executor.layers.moe.fused_moe_triton_backend import (
|
||||
BlockWiseFP8MoEMethod,
|
||||
TensorWiseFP8MoEMethod,
|
||||
TritonWeightOnlyMoEMethod,
|
||||
)
|
||||
|
||||
pre_create_weights_list = (CutlassMoEMethod, TensorWiseFP8MoEMethod, BlockWiseFP8MoEMethod, TritonWeightOnlyMoEMethod)
|
||||
|
||||
|
||||
def is_supported_moe_backend(quant_method: MoEMethodBase):
|
||||
return isinstance(quant_method, pre_create_weights_list)
|
@@ -19,7 +19,7 @@ from abc import abstractmethod
|
||||
import paddle
|
||||
from paddle import nn
|
||||
|
||||
from fastdeploy.model_executor.models.utils import set_weight_attrs
|
||||
from fastdeploy.model_executor.layers.utils import set_weight_attrs
|
||||
from fastdeploy.platforms import current_platform
|
||||
|
||||
from ..quantization.quant_base import QuantMethodBase
|
||||
|
@@ -23,7 +23,7 @@ import fastdeploy
|
||||
from fastdeploy.distributed.communication import tensor_model_parallel_all_reduce
|
||||
from fastdeploy.platforms import current_platform
|
||||
|
||||
from ..utils import create_and_set_parameter, get_tensor
|
||||
from ..utils import get_tensor
|
||||
from .fused_moe_backend_base import UnquantizedFusedMoEMethod
|
||||
|
||||
if current_platform.is_cuda():
|
||||
@@ -202,7 +202,10 @@ class CutlassMoEMethod(UnquantizedFusedMoEMethod):
|
||||
gate_out = gate(x.cast("float32"))
|
||||
# 1. Select topk experts and weights
|
||||
topk_idx, topk_weights = self.ep_decoder_runner.moe_select(layer, gate_out)
|
||||
expertwise_scale = None
|
||||
if hasattr(layer, "up_gate_proj_in_scale_all_experts"): # only use in w4a8
|
||||
expertwise_scale = getattr(layer, "up_gate_proj_in_scale_all_experts", None)
|
||||
|
||||
# 2. EP Dispatch
|
||||
permute_input, token_nums_per_expert, handle = self.ep_decoder_runner.dispatch(
|
||||
x, topk_idx, topk_weights, expertwise_scale=expertwise_scale
|
||||
@@ -382,12 +385,48 @@ class CutlassW4A8MoEMethod(CutlassMoEMethod):
|
||||
"down_proj_in_scale": down_proj_in_scale,
|
||||
}
|
||||
for name, tensor in name_tensor_map.items():
|
||||
create_and_set_parameter(layer, name, tensor)
|
||||
getattr(layer, name).set_value(tensor)
|
||||
|
||||
def create_weights(self, layer: nn.Layer, state_dict):
|
||||
def create_weights(self, layer: nn.Layer, **extra_weight_attrs):
|
||||
"""
|
||||
Paddle cutlass create weight process.
|
||||
"""
|
||||
self.weight_dtype = "int8"
|
||||
self.ffn1_weight_shape = [
|
||||
layer.num_local_experts,
|
||||
layer.hidden_size // 2,
|
||||
layer.moe_intermediate_size * 2,
|
||||
]
|
||||
self.ffn2_weight_shape = [
|
||||
layer.num_local_experts,
|
||||
layer.moe_intermediate_size // 2,
|
||||
layer.hidden_size,
|
||||
]
|
||||
setattr(
|
||||
layer,
|
||||
self.added_weight_attrs[0],
|
||||
layer.create_parameter(
|
||||
shape=self.ffn1_weight_shape,
|
||||
dtype=self.weight_dtype,
|
||||
default_initializer=paddle.nn.initializer.Constant(0),
|
||||
),
|
||||
)
|
||||
setattr(
|
||||
layer,
|
||||
self.added_weight_attrs[1],
|
||||
layer.create_parameter(
|
||||
shape=self.ffn2_weight_shape,
|
||||
dtype=self.weight_dtype,
|
||||
default_initializer=paddle.nn.initializer.Constant(0),
|
||||
),
|
||||
)
|
||||
|
||||
self.create_w4a8_scale_weights(layer, layer.weight_key_map)
|
||||
|
||||
def process_loaded_weights(self, layer: nn.Layer, state_dict):
|
||||
"""
|
||||
Paddle cutlass load weight process.
|
||||
"""
|
||||
up_gate_proj_weights, down_proj_weights = layer.extract_moe_ffn_weights(state_dict)
|
||||
self.check(layer, up_gate_proj_weights, down_proj_weights)
|
||||
for idx, weight_tensor in enumerate([up_gate_proj_weights, down_proj_weights]):
|
||||
@@ -397,11 +436,63 @@ class CutlassW4A8MoEMethod(CutlassMoEMethod):
|
||||
quant_weight, scale = weight_quantize(weight_tensor[i], algo=self.moe_quant_type, arch=80)
|
||||
weight_list.append(quant_weight)
|
||||
quanted_weight = paddle.stack(weight_list, axis=0)
|
||||
create_and_set_parameter(layer, weight_name, quanted_weight)
|
||||
getattr(layer, weight_name).set_value(quanted_weight)
|
||||
|
||||
self.create_w4a8_scale_weights(layer, layer.weight_key_map, state_dict)
|
||||
self.load_w4a8_scale_weights(layer, layer.weight_key_map, state_dict)
|
||||
|
||||
def create_w4a8_scale_weights(self, layer: nn.Layer, weight_key_map: dict, state_dict: dict):
|
||||
def create_w4a8_scale_weights(self, layer: nn.Layer, weight_key_map: dict):
|
||||
"""
|
||||
Get w4a8 weights from state dict and process them.
|
||||
Args:
|
||||
layer (nn.Layer): The layer to add parameters to.
|
||||
weight_key_map (dict): The weight key map.
|
||||
state_dict (dict): The state dict.
|
||||
"""
|
||||
self.default_dtype = layer._helper.get_default_dtype()
|
||||
if layer.ep_size > 1:
|
||||
setattr(
|
||||
layer,
|
||||
"up_gate_proj_in_scale_all_experts",
|
||||
layer.create_parameter(
|
||||
shape=[layer.num_experts],
|
||||
dtype="float32",
|
||||
default_initializer=paddle.nn.initializer.Constant(0),
|
||||
),
|
||||
)
|
||||
|
||||
# in_scales
|
||||
for in_scale_name in ["up_gate_proj_in_scale", "down_proj_in_scale"]:
|
||||
setattr(
|
||||
layer,
|
||||
in_scale_name,
|
||||
layer.create_parameter(
|
||||
shape=[layer.num_local_experts],
|
||||
dtype="float32",
|
||||
default_initializer=paddle.nn.initializer.Constant(0),
|
||||
),
|
||||
)
|
||||
|
||||
# weight_scales
|
||||
setattr(
|
||||
layer,
|
||||
"up_gate_proj_weight_scale",
|
||||
layer.create_parameter(
|
||||
shape=[layer.num_local_experts, layer.moe_intermediate_size * 2],
|
||||
dtype=self.default_dtype,
|
||||
default_initializer=paddle.nn.initializer.Constant(0),
|
||||
),
|
||||
)
|
||||
setattr(
|
||||
layer,
|
||||
"down_proj_weight_scale",
|
||||
layer.create_parameter(
|
||||
shape=[layer.num_local_experts, layer.hidden_size],
|
||||
dtype=self.default_dtype,
|
||||
default_initializer=paddle.nn.initializer.Constant(0),
|
||||
),
|
||||
)
|
||||
|
||||
def load_w4a8_scale_weights(self, layer: nn.Layer, weight_key_map: dict, state_dict: dict):
|
||||
"""
|
||||
Get w4a8 weights from state dict and process them.
|
||||
Args:
|
||||
@@ -415,7 +506,7 @@ class CutlassW4A8MoEMethod(CutlassMoEMethod):
|
||||
|
||||
def _process_in_scale(name: str, in_scales: list[paddle.Tensor]):
|
||||
processed_in_scale = 1 / paddle.concat(in_scales)
|
||||
create_and_set_parameter(layer, name, processed_in_scale)
|
||||
getattr(layer, name).set_value(processed_in_scale)
|
||||
return processed_in_scale
|
||||
|
||||
def _process_weight_scale(
|
||||
@@ -426,7 +517,7 @@ class CutlassW4A8MoEMethod(CutlassMoEMethod):
|
||||
processed_weight_scale = (
|
||||
paddle.stack(weight_scales, axis=0) / (127 * 112) / processed_in_scale[:, None]
|
||||
).cast(paddle.get_default_dtype())
|
||||
create_and_set_parameter(layer, name, processed_weight_scale)
|
||||
getattr(layer, name).set_value(processed_weight_scale)
|
||||
|
||||
# 1. Init scale containers and maps
|
||||
up_gate_proj_weight_scales = []
|
||||
@@ -456,8 +547,8 @@ class CutlassW4A8MoEMethod(CutlassMoEMethod):
|
||||
for expert_idx in range(layer.num_experts):
|
||||
scale_tensor = get_tensor(state_dict[scale_key_map["up_gate_proj_in_scale"].format(expert_idx)])
|
||||
up_gate_proj_in_scales_all_experts.append(1 / scale_tensor)
|
||||
create_and_set_parameter(
|
||||
layer, "up_gate_proj_in_scale_all_experts", paddle.concat(up_gate_proj_in_scales_all_experts)
|
||||
getattr(layer, "up_gate_proj_in_scale_all_experts").set_value(
|
||||
paddle.concat(up_gate_proj_in_scales_all_experts)
|
||||
)
|
||||
|
||||
for local_expert_idx in range(layer.num_local_experts):
|
||||
@@ -527,15 +618,85 @@ class CutlassWeightOnlyMoEMethod(CutlassMoEMethod):
|
||||
"down_proj_weight_scale": down_proj_weight_scale,
|
||||
}
|
||||
for name, tensor in name_tensor_map.items():
|
||||
create_and_set_parameter(layer, name, tensor)
|
||||
getattr(layer, name).set_value(tensor)
|
||||
|
||||
def create_weights(self, layer: nn.Layer, state_dict):
|
||||
def create_weights(self, layer: nn.Layer, **extra_weight_attrs):
|
||||
"""
|
||||
Paddle cutlass create weight process.
|
||||
"""
|
||||
self.default_dtype = layer._helper.get_default_dtype()
|
||||
self.weight_dtype = "int8"
|
||||
|
||||
up_gate_proj_weight_name = self.added_weight_attrs[0]
|
||||
down_proj_weight_name = self.added_weight_attrs[1]
|
||||
if self.moe_quant_type == "weight_only_int4":
|
||||
self.ffn1_weight_shape = [
|
||||
layer.num_local_experts,
|
||||
layer.moe_intermediate_size,
|
||||
layer.hidden_size,
|
||||
]
|
||||
else:
|
||||
self.ffn1_weight_shape = [
|
||||
layer.num_local_experts,
|
||||
layer.moe_intermediate_size * 2,
|
||||
layer.hidden_size,
|
||||
]
|
||||
if self.moe_quant_type == "weight_only_int4":
|
||||
self.ffn2_weight_shape = [
|
||||
layer.num_local_experts,
|
||||
layer.hidden_size // 2,
|
||||
layer.moe_intermediate_size,
|
||||
]
|
||||
else:
|
||||
self.ffn2_weight_shape = [
|
||||
layer.num_local_experts,
|
||||
layer.hidden_size,
|
||||
layer.moe_intermediate_size,
|
||||
]
|
||||
setattr(
|
||||
layer,
|
||||
up_gate_proj_weight_name,
|
||||
layer.create_parameter(
|
||||
shape=self.ffn1_weight_shape,
|
||||
dtype=self.weight_dtype,
|
||||
default_initializer=paddle.nn.initializer.Constant(0),
|
||||
),
|
||||
)
|
||||
setattr(
|
||||
layer,
|
||||
down_proj_weight_name,
|
||||
layer.create_parameter(
|
||||
shape=self.ffn2_weight_shape,
|
||||
dtype=self.weight_dtype,
|
||||
default_initializer=paddle.nn.initializer.Constant(0),
|
||||
),
|
||||
)
|
||||
# weight_scale
|
||||
setattr(
|
||||
layer,
|
||||
self.added_scale_attrs[0],
|
||||
layer.create_parameter(
|
||||
shape=[layer.num_local_experts, layer.moe_intermediate_size * 2],
|
||||
dtype=self.default_dtype,
|
||||
default_initializer=paddle.nn.initializer.Constant(0),
|
||||
),
|
||||
)
|
||||
setattr(
|
||||
layer,
|
||||
self.added_scale_attrs[1],
|
||||
layer.create_parameter(
|
||||
shape=[layer.num_local_experts, layer.hidden_size],
|
||||
dtype=self.default_dtype,
|
||||
default_initializer=paddle.nn.initializer.Constant(0),
|
||||
),
|
||||
)
|
||||
|
||||
def process_loaded_weights(self, layer: nn.Layer, state_dict):
|
||||
"""
|
||||
Paddle cutlass load weight process.
|
||||
"""
|
||||
up_gate_proj_weights, down_proj_weights = layer.extract_moe_ffn_weights(state_dict)
|
||||
self.check(layer, up_gate_proj_weights, down_proj_weights)
|
||||
|
||||
for idx, weight_tensor in enumerate([up_gate_proj_weights, down_proj_weights]):
|
||||
weight_name = self.added_weight_attrs[idx]
|
||||
scale_name = self.added_scale_attrs[idx]
|
||||
@@ -547,7 +708,7 @@ class CutlassWeightOnlyMoEMethod(CutlassMoEMethod):
|
||||
weight_list.append(quant_weight)
|
||||
weight_scale_list.append(scale)
|
||||
quanted_weight = paddle.stack(weight_list, axis=0)
|
||||
create_and_set_parameter(layer, weight_name, quanted_weight)
|
||||
getattr(layer, weight_name).set_value(quanted_weight)
|
||||
|
||||
quanted_weight_scale = paddle.stack(weight_scale_list, axis=0)
|
||||
create_and_set_parameter(layer, scale_name, quanted_weight_scale)
|
||||
getattr(layer, scale_name).set_value(quanted_weight_scale)
|
||||
|
@@ -19,7 +19,7 @@ from paddle import nn
|
||||
|
||||
import fastdeploy
|
||||
from fastdeploy.distributed.communication import tensor_model_parallel_all_reduce
|
||||
from fastdeploy.model_executor.layers.utils import create_and_set_parameter, get_tensor
|
||||
from fastdeploy.model_executor.layers.utils import get_tensor
|
||||
from fastdeploy.utils import ceil_div
|
||||
|
||||
from ..quantization.quant_base import QuantMethodBase
|
||||
@@ -52,10 +52,66 @@ class TritonWeightOnlyMoEMethod(QuantMethodBase):
|
||||
"""process_prequanted_weights"""
|
||||
pass
|
||||
|
||||
def create_weights(self, layer: nn.Layer, state_dict):
|
||||
def create_weights(self, layer: nn.Layer, **extra_weight_attrs):
|
||||
"""
|
||||
Triton MoE create weight process.
|
||||
"""
|
||||
self.weight_dtype = "int8"
|
||||
self.default_dtype = layer._helper.get_default_dtype()
|
||||
up_gate_proj_weight_name = self.added_weight_attrs[0]
|
||||
down_proj_weight_name = self.added_weight_attrs[1]
|
||||
self.ffn1_weight_shape = [
|
||||
layer.num_local_experts,
|
||||
layer.hidden_size,
|
||||
layer.moe_intermediate_size * 2,
|
||||
]
|
||||
self.ffn2_weight_shape = [
|
||||
layer.num_local_experts,
|
||||
layer.moe_intermediate_size,
|
||||
layer.hidden_size,
|
||||
]
|
||||
setattr(
|
||||
layer,
|
||||
up_gate_proj_weight_name,
|
||||
layer.create_parameter(
|
||||
shape=self.ffn1_weight_shape,
|
||||
dtype=self.weight_dtype,
|
||||
default_initializer=paddle.nn.initializer.Constant(0),
|
||||
),
|
||||
)
|
||||
setattr(
|
||||
layer,
|
||||
down_proj_weight_name,
|
||||
layer.create_parameter(
|
||||
shape=self.ffn2_weight_shape,
|
||||
dtype=self.weight_dtype,
|
||||
default_initializer=paddle.nn.initializer.Constant(0),
|
||||
),
|
||||
)
|
||||
# weight_scale
|
||||
setattr(
|
||||
layer,
|
||||
self.added_scale_attrs[0],
|
||||
layer.create_parameter(
|
||||
shape=[layer.num_local_experts, layer.moe_intermediate_size * 2],
|
||||
dtype=self.default_dtype,
|
||||
default_initializer=paddle.nn.initializer.Constant(0),
|
||||
),
|
||||
)
|
||||
setattr(
|
||||
layer,
|
||||
self.added_scale_attrs[1],
|
||||
layer.create_parameter(
|
||||
shape=[layer.num_local_experts, layer.hidden_size],
|
||||
dtype=self.default_dtype,
|
||||
default_initializer=paddle.nn.initializer.Constant(0),
|
||||
),
|
||||
)
|
||||
|
||||
def process_loaded_weights(self, layer: nn.Layer, state_dict):
|
||||
"""
|
||||
Triton MoE load weight process.
|
||||
"""
|
||||
up_gate_proj_weights, down_proj_weights = layer.extract_moe_ffn_weights(state_dict)
|
||||
assert len(up_gate_proj_weights) == layer.num_local_experts
|
||||
assert len(down_proj_weights) == layer.num_local_experts
|
||||
@@ -90,25 +146,7 @@ class TritonWeightOnlyMoEMethod(QuantMethodBase):
|
||||
quanted_weight = paddle.round(quanted_weight).astype("int8")
|
||||
quanted_weight_scale = quanted_weight_scale / max_bound
|
||||
|
||||
setattr(
|
||||
layer,
|
||||
weight_name,
|
||||
layer.create_parameter(
|
||||
shape=quanted_weight.shape,
|
||||
dtype=quanted_weight.dtype,
|
||||
default_initializer=paddle.nn.initializer.Constant(0),
|
||||
),
|
||||
)
|
||||
getattr(layer, weight_name).set_value(quanted_weight)
|
||||
|
||||
setattr(
|
||||
layer,
|
||||
scale_name,
|
||||
layer.create_parameter(
|
||||
shape=quanted_weight_scale.shape,
|
||||
dtype=quanted_weight_scale.dtype,
|
||||
),
|
||||
)
|
||||
getattr(layer, scale_name).set_value(quanted_weight_scale)
|
||||
|
||||
def apply(
|
||||
@@ -264,6 +302,14 @@ class TensorWiseFP8MoEMethod(QuantMethodBase):
|
||||
Triton Group Gemm to compute Fused MoE.
|
||||
"""
|
||||
self.quant_method = quant_method
|
||||
self.added_wfp8afp8_attrs = [
|
||||
"up_gate_proj_weight",
|
||||
"down_proj_weight",
|
||||
"up_gate_proj_weight_scale",
|
||||
"down_proj_weight_scale",
|
||||
"up_gate_proj_in_scale",
|
||||
"down_proj_in_scale",
|
||||
]
|
||||
|
||||
def process_prequanted_weights(self, layer: nn.Layer, state_dict) -> None:
|
||||
"""process_prequanted_weights"""
|
||||
@@ -281,15 +327,6 @@ class TensorWiseFP8MoEMethod(QuantMethodBase):
|
||||
up_gate_proj_tensor = paddle.stack(up_gate_proj_tensor, axis=0).view(paddle.float8_e4m3fn)
|
||||
down_proj_tensor = paddle.stack(down_proj_tensor, axis=0).view(paddle.float8_e4m3fn)
|
||||
|
||||
added_wfp8afp8_attrs = [
|
||||
"up_gate_proj_weight",
|
||||
"down_proj_weight",
|
||||
"up_gate_proj_weight_scale",
|
||||
"down_proj_weight_scale",
|
||||
"up_gate_proj_in_scale",
|
||||
"down_proj_in_scale",
|
||||
]
|
||||
|
||||
def _extract_scale_tensor(key_template):
|
||||
result = []
|
||||
for i in range(layer.num_experts):
|
||||
@@ -312,26 +349,58 @@ class TensorWiseFP8MoEMethod(QuantMethodBase):
|
||||
down_proj_in_scale,
|
||||
]
|
||||
):
|
||||
name = added_wfp8afp8_attrs[idx]
|
||||
setattr(
|
||||
layer,
|
||||
name,
|
||||
layer.create_parameter(
|
||||
shape=weight_tensor.shape,
|
||||
dtype=weight_tensor.dtype,
|
||||
default_initializer=paddle.nn.initializer.Constant(0),
|
||||
),
|
||||
)
|
||||
name = self.added_wfp8afp8_attrs[idx]
|
||||
if weight_tensor.dtype == paddle.float8_e4m3fn:
|
||||
getattr(layer, name).copy_(weight_tensor, False)
|
||||
else:
|
||||
getattr(layer, name).set_value(weight_tensor)
|
||||
|
||||
def create_weights(self, layer: nn.Layer, state_dict):
|
||||
def create_weights(self, layer: nn.Layer, **extra_weight_attrs):
|
||||
"""
|
||||
Triton MoE create weight process.
|
||||
"""
|
||||
pass
|
||||
self.weight_dtype = paddle.float8_e4m3fn
|
||||
self.default_dtype = layer._helper.get_default_dtype()
|
||||
up_gate_proj_weight_name = self.added_wfp8afp8_attrs[0]
|
||||
down_proj_weight_name = self.added_wfp8afp8_attrs[1]
|
||||
self.ffn1_weight_shape = [
|
||||
layer.num_local_experts,
|
||||
layer.moe_intermediate_size * 2,
|
||||
layer.hidden_size,
|
||||
]
|
||||
self.ffn2_weight_shape = [
|
||||
layer.num_local_experts,
|
||||
layer.hidden_size,
|
||||
layer.moe_intermediate_size,
|
||||
]
|
||||
setattr(
|
||||
layer,
|
||||
up_gate_proj_weight_name,
|
||||
layer.create_parameter(
|
||||
shape=self.ffn1_weight_shape,
|
||||
dtype=self.weight_dtype,
|
||||
default_initializer=paddle.nn.initializer.Constant(0),
|
||||
),
|
||||
)
|
||||
setattr(
|
||||
layer,
|
||||
down_proj_weight_name,
|
||||
layer.create_parameter(
|
||||
shape=self.ffn2_weight_shape,
|
||||
dtype=self.weight_dtype,
|
||||
default_initializer=paddle.nn.initializer.Constant(0),
|
||||
),
|
||||
)
|
||||
for idx in range(2, len(self.added_wfp8afp8_attrs)):
|
||||
setattr(
|
||||
layer,
|
||||
self.added_wfp8afp8_attrs[idx],
|
||||
layer.create_parameter(
|
||||
shape=[layer.num_local_experts],
|
||||
dtype="float32",
|
||||
default_initializer=paddle.nn.initializer.Constant(0),
|
||||
),
|
||||
)
|
||||
|
||||
def apply(
|
||||
self,
|
||||
@@ -531,14 +600,76 @@ class BlockWiseFP8MoEMethod(QuantMethodBase):
|
||||
|
||||
raise NotImplementedError
|
||||
|
||||
def create_weights(self, layer: nn.Layer, state_dict):
|
||||
def create_weights(self, layer: nn.Layer, **extra_weight_attrs):
|
||||
"""
|
||||
Triton MoE create weight process.
|
||||
"""
|
||||
self.weight_dtype = paddle.float8_e4m3fn
|
||||
up_gate_proj_weight_name = self.added_weight_attrs[0]
|
||||
down_proj_weight_name = self.added_weight_attrs[1]
|
||||
self.ffn1_weight_shape = [
|
||||
layer.num_local_experts,
|
||||
layer.moe_intermediate_size * 2,
|
||||
layer.hidden_size,
|
||||
]
|
||||
self.ffn2_weight_shape = [
|
||||
layer.num_local_experts,
|
||||
layer.hidden_size,
|
||||
layer.moe_intermediate_size,
|
||||
]
|
||||
setattr(
|
||||
layer,
|
||||
up_gate_proj_weight_name,
|
||||
layer.create_parameter(
|
||||
shape=self.ffn1_weight_shape,
|
||||
dtype=self.weight_dtype,
|
||||
default_initializer=paddle.nn.initializer.Constant(0),
|
||||
),
|
||||
)
|
||||
setattr(
|
||||
layer,
|
||||
down_proj_weight_name,
|
||||
layer.create_parameter(
|
||||
shape=self.ffn1_weight_shape,
|
||||
dtype=self.weight_dtype,
|
||||
default_initializer=paddle.nn.initializer.Constant(0),
|
||||
),
|
||||
)
|
||||
# weight_scale
|
||||
setattr(
|
||||
layer,
|
||||
self.added_scale_attrs[0],
|
||||
layer.create_parameter(
|
||||
shape=[
|
||||
layer.num_local_experts,
|
||||
layer.moe_intermediate_size * 2 // self.quant_config.weight_block_size[0],
|
||||
layer.hidden_size // self.quant_config.weight_block_size[1],
|
||||
],
|
||||
dtype="float32",
|
||||
default_initializer=paddle.nn.initializer.Constant(0),
|
||||
),
|
||||
)
|
||||
setattr(
|
||||
layer,
|
||||
self.added_scale_attrs[1],
|
||||
layer.create_parameter(
|
||||
shape=[
|
||||
layer.num_local_experts,
|
||||
layer.hidden_size // self.quant_config.weight_block_size[0],
|
||||
layer.moe_intermediate_size // self.quant_config.weight_block_size[1],
|
||||
],
|
||||
dtype="float32",
|
||||
default_initializer=paddle.nn.initializer.Constant(0),
|
||||
),
|
||||
)
|
||||
|
||||
def process_loaded_weights(self, layer: nn.Layer, state_dict):
|
||||
"""
|
||||
Triton MoE create weight process.
|
||||
"""
|
||||
up_gate_proj_weights, down_proj_weights = layer.extract_moe_ffn_weights(state_dict)
|
||||
|
||||
self.check(layer, up_gate_proj_weights, down_proj_weights)
|
||||
|
||||
for idx, weight_tensor in enumerate([up_gate_proj_weights, down_proj_weights]):
|
||||
weight_name = self.added_weight_attrs[idx]
|
||||
scale_name = self.added_scale_attrs[idx]
|
||||
@@ -554,11 +685,11 @@ class BlockWiseFP8MoEMethod(QuantMethodBase):
|
||||
weight_scale_list.append(scale)
|
||||
quanted_weight = paddle.stack(weight_list, axis=0)
|
||||
quanted_weight = quanted_weight.transpose([0, 2, 1]).contiguous().view(paddle.float8_e4m3fn)
|
||||
create_and_set_parameter(layer, weight_name, quanted_weight)
|
||||
getattr(layer, weight_name).copy_(quanted_weight, False)
|
||||
|
||||
quanted_weight_scale = paddle.stack(weight_scale_list, axis=0)
|
||||
quanted_weight_scale = quanted_weight_scale.transpose([0, 2, 1]).contiguous()
|
||||
create_and_set_parameter(layer, scale_name, quanted_weight_scale)
|
||||
getattr(layer, scale_name).set_value(quanted_weight_scale)
|
||||
|
||||
def check(self, layer: nn.Layer, up_gate_proj_weights, down_proj_weights):
|
||||
"""
|
||||
|
@@ -22,8 +22,14 @@ from paddleformers.utils.log import logger
|
||||
|
||||
from fastdeploy import envs
|
||||
from fastdeploy.model_executor.layers.utils import get_tensor
|
||||
from fastdeploy.platforms import current_platform
|
||||
from fastdeploy.worker.experts_manager import RedundantExpertManger
|
||||
|
||||
# TODO(lulinjun): remove this import after supporting all backends
|
||||
is_supported_moe_backend = None
|
||||
if current_platform.is_cuda():
|
||||
from .check_backend_supported import is_supported_moe_backend
|
||||
|
||||
|
||||
def get_moe_method():
|
||||
"""
|
||||
@@ -121,10 +127,7 @@ class FusedMoE(nn.Layer):
|
||||
self.quant_method = moe_quant_config.get_quant_method(self)
|
||||
self.moe_quant_type = moe_quant_config.name()
|
||||
else:
|
||||
# w_fp16 a_fp16
|
||||
self.quant_method = get_moe_method()
|
||||
self.quant_method.create_weights(self, weight_loader=self.weight_loader)
|
||||
|
||||
self.redundant_table_manger = None
|
||||
if self.ep_size > 1:
|
||||
if fd_config.model_config.enable_redundant_experts is True:
|
||||
@@ -139,6 +142,20 @@ class FusedMoE(nn.Layer):
|
||||
if fd_config.load_config.dynamic_load_weight:
|
||||
# It's for RL to build model
|
||||
self.init_moe_weights()
|
||||
else:
|
||||
self.gate_correction_bias_key = self.weight_key_map.get("gate_correction_bias_key", None)
|
||||
if self.gate_correction_bias_key is not None:
|
||||
self.gate_correction_bias = self.create_parameter(shape=[1, self.num_experts], dtype="float32")
|
||||
if moe_quant_config:
|
||||
if (
|
||||
moe_quant_config
|
||||
and is_supported_moe_backend is not None
|
||||
and is_supported_moe_backend(self.quant_method)
|
||||
):
|
||||
self.quant_method.create_weights(self, weight_loader=self.weight_loader)
|
||||
else:
|
||||
# w_fp16 a_fp16
|
||||
self.quant_method.create_weights(self, weight_loader=self.weight_loader)
|
||||
|
||||
logger.info(
|
||||
f"{moe_tag}MoE config is {num_experts=}[{expert_id_offset}, {expert_id_offset + self.num_local_experts}), \
|
||||
@@ -475,12 +492,22 @@ class FusedMoE(nn.Layer):
|
||||
gate_correction_bias_tensor = self.extract_gate_correction_bias(
|
||||
self.gate_correction_bias_key, state_dict
|
||||
)
|
||||
self.gate_correction_bias = self.create_parameter(
|
||||
shape=gate_correction_bias_tensor.shape,
|
||||
dtype="float32",
|
||||
)
|
||||
self.gate_correction_bias.set_value(gate_correction_bias_tensor)
|
||||
else:
|
||||
self.gate_correction_bias = None
|
||||
|
||||
else:
|
||||
self.gate_correction_bias = None
|
||||
|
||||
if is_supported_moe_backend is not None and is_supported_moe_backend(self.quant_method):
|
||||
if self.fd_config.model_config.is_quantized:
|
||||
if getattr(self.fd_config.quant_config, "is_permuted", True):
|
||||
self.quant_method.process_prequanted_weights(self, state_dict)
|
||||
else:
|
||||
self.quant_method.process_loaded_weights(self, state_dict)
|
||||
else:
|
||||
self.quant_method.process_loaded_weights(self, state_dict)
|
||||
else:
|
||||
if self.fd_config.model_config.is_quantized:
|
||||
if getattr(self.fd_config.quant_config, "is_permuted", True):
|
||||
self.quant_method.process_prequanted_weights(self, state_dict)
|
||||
|
@@ -82,7 +82,7 @@ class TensorWiseFP8LinearMethod(QuantMethodBase):
|
||||
self.weight_dtype = "float8_e4m3fn"
|
||||
|
||||
def create_weights(self, layer, **extra_weight_attrs):
|
||||
|
||||
layer.weight_dtype = "float8_e4m3fn"
|
||||
layer.weight = layer.create_parameter(
|
||||
shape=layer.weight_shape,
|
||||
dtype=layer.weight_dtype,
|
||||
|
@@ -15,7 +15,7 @@
|
||||
"""
|
||||
|
||||
import functools
|
||||
from typing import Tuple, Union
|
||||
from typing import Any, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import paddle
|
||||
@@ -45,6 +45,14 @@ if cache_params != "none":
|
||||
c8_state_dict = paddle.load(cache_params, return_numpy=True)
|
||||
|
||||
|
||||
# TODO(lulinjun): delete it, import from fastdeploy.model_executor.models.utils after supporting all backends
|
||||
def set_weight_attrs(param, param_attr_map: Optional[dict[str, Any]]):
|
||||
if param_attr_map is None:
|
||||
return
|
||||
for key, value in param_attr_map.items():
|
||||
setattr(param, key, value)
|
||||
|
||||
|
||||
def per_block_cast_to_fp8(x: Tensor, block_size: list = [128, 128]) -> Tuple[Tensor, Tensor]:
|
||||
"""
|
||||
Only used in deep_gemm block wise quant weight.
|
||||
|
@@ -66,7 +66,7 @@ def load_ep_checkpoint(model_path: str, fd_config: FDConfig, return_numpy: bool
|
||||
"""
|
||||
with open(os.path.join(model_path, "model.safetensors.index.json"), "r") as f:
|
||||
weight_list = json.load(f)["weight_map"]
|
||||
filtered_map = {k: v for k, v in weight_list.items() if "experts" not in k}
|
||||
filtered_map = {k: v for k, v in weight_list.items() if ".experts." not in k}
|
||||
num_local_ffn_keys = []
|
||||
|
||||
from itertools import chain
|
||||
|
@@ -424,7 +424,10 @@ class Ernie4_5_MoeForCausalLM(ModelForCasualLM):
|
||||
"""
|
||||
self.ernie.load_state_dict(state_dict)
|
||||
if self.tie_word_embeddings:
|
||||
if hasattr(self.lm_head, "linear"):
|
||||
self.lm_head.linear.weight.set_value(self.ernie.embed_tokens.embeddings.weight.transpose([1, 0]))
|
||||
else: # ep
|
||||
self.lm_head.weight.set_value(self.ernie.embed_tokens.embeddings.weight.transpose([1, 0]))
|
||||
else:
|
||||
self.lm_head.load_state_dict(state_dict)
|
||||
|
||||
|
Reference in New Issue
Block a user