Mirror of https://github.com/PaddlePaddle/FastDeploy.git

[XPU] Supports BF16 for ERNIE-4.5-21B-A3B and ERNIE-4.5-0.3B (#2765)

* fix no quant xpu moe
* change dir of xpu moe weight only
@@ -16,6 +16,6 @@
 xpu backend methods
 """

-from .quantization.weight_only import XPUWeightOnlyLinearMethod, XPUWeightOnlyMoEMethod
+from .quantization.weight_only import XPUWeightOnlyLinearMethod

-__all__ = ['XPUWeightOnlyLinearMethod', 'XPUWeightOnlyMoEMethod']
+__all__ = ['XPUWeightOnlyLinearMethod']
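With the export removed from the backends package, XPUWeightOnlyMoEMethod now lives in the fused-MoE XPU backend module. A minimal sketch of the updated import paths, taken from the hunks further down (no new API is introduced here):

# Updated import locations after this commit (as shown in the weight_only.py hunk at the end)
from fastdeploy.model_executor.layers.backends import XPUWeightOnlyLinearMethod
from fastdeploy.model_executor.layers.moe.fused_moe_xpu_backend import XPUWeightOnlyMoEMethod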
@@ -13,14 +13,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""

from typing import Dict

import paddle
from paddle import nn

from fastdeploy.model_executor.layers.quantization.quant_base import \
    QuantMethodBase
from fastdeploy.model_executor.layers.quantization.weight_only import (
    WeightOnlyConfig, WeightOnlyLinearMethod)
from fastdeploy.model_executor.ops.xpu import weight_quantize_xpu
@@ -63,103 +58,3 @@ class XPUWeightOnlyLinearMethod(WeightOnlyLinearMethod):
        layer.linear_weight.set_value(
            paddle.transpose(quanted_weight_tensor, [1, 0]))
        layer.linear_weight_scale.set_value(weight_scale_tensor)


class XPUWeightOnlyMoEMethod(QuantMethodBase):
    """
    XPU Fused MoE Method.
    """

    def __init__(
        self,
        quant_config: WeightOnlyConfig,
    ) -> None:
        super().__init__()
        self.quant_config = quant_config
        self.moe_quant_type = self.quant_config.algo

    def create_weights(self, layer: nn.Layer, state_dict: Dict[str,
                                                                paddle.Tensor]):
        """
        Paddle cutlass create weight process.
        """
        ffn1_weights, ffn2_weights = layer.extract_moe_ffn_weights(state_dict)
        assert len(ffn1_weights) == layer.num_local_experts
        assert len(ffn2_weights) == layer.num_local_experts
        assert ffn1_weights[0].shape == [
            layer.hidden_size, layer.moe_intermediate_size * 2
        ]
        assert ffn2_weights[0].shape == [
            layer.moe_intermediate_size, layer.hidden_size
        ]

        added_weight_attrs = ["moe_ffn1_weight", "moe_ffn2_weight"]
        added_scale_attrs = ["moe_ffn1_weight_scale", "moe_ffn2_weight_scale"]

        for idx, weight_tensor in enumerate([ffn1_weights, ffn2_weights]):
            weight_name = added_weight_attrs[idx]
            scale_name = added_scale_attrs[idx]

            weight_list = []
            weight_scale_list = []
            for i in range(layer.num_local_experts):
                quant_weight, scale = weight_quantize_xpu(
                    weight_tensor[i], self.moe_quant_type, -1,
                    -1)  # weight is [k,n]
                weight_list.append(quant_weight.transpose(
                    [1, 0]))  # transpose weight to [n,k]
                weight_scale_list.append(scale)
            quanted_weight = paddle.stack(weight_list, axis=0)
            setattr(
                layer, weight_name,
                layer.create_parameter(
                    shape=quanted_weight.shape,
                    dtype=quanted_weight.dtype,
                    default_initializer=paddle.nn.initializer.Constant(0),
                ))
            getattr(layer, weight_name).set_value(quanted_weight)

            quanted_weight_scale = paddle.stack(weight_scale_list, axis=0)
            setattr(
                layer, scale_name,
                layer.create_parameter(
                    shape=quanted_weight_scale.shape,
                    dtype=quanted_weight_scale.dtype,
                ))
            getattr(layer, scale_name).set_value(quanted_weight_scale)

    def apply(
        self,
        layer: nn.Layer,
        x: paddle.Tensor,
        gate_out: paddle.Tensor,
    ) -> paddle.Tensor:
        """
        XPU compute Fused MoE.
        """
        from fastdeploy.model_executor.ops.xpu import xpu_moe_layer

        fused_moe_out = xpu_moe_layer(
            x,
            layer.gate_weight.transpose([1, 0]),
            layer.gate_correction_bias,
            layer.moe_ffn1_weight,
            layer.moe_ffn2_weight,
            None,  # ffn1 bias
            None,  # ffn2 bias
            (layer.moe_ffn1_weight_scale
             if hasattr(layer, "moe_ffn1_weight_scale") else None),
            (layer.moe_ffn2_weight_scale
             if hasattr(layer, "moe_ffn2_weight_scale") else None),
            (layer.moe_ffn2_in_scale
             if hasattr(layer, "moe_ffn2_in_scale") else None),
            self.moe_quant_type,
            layer.top_k,
            False,  # moe group, used in deepseek
        )
        if layer.tp_size > 1:
            from fastdeploy.distributed.communication_op import \
                tensor_model_parallel_all_reduce
            tensor_model_parallel_all_reduce(fused_moe_out)

        return fused_moe_out
fastdeploy/model_executor/layers/moe/fused_moe_xpu_backend.py  (new file, 211 lines)
@@ -0,0 +1,211 @@
"""
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""

from typing import Dict

import paddle
from paddle import nn

from fastdeploy.model_executor.layers.quantization.quant_base import \
    QuantMethodBase
from fastdeploy.model_executor.layers.quantization.weight_only import \
    WeightOnlyConfig
from fastdeploy.model_executor.ops.xpu import weight_quantize_xpu

from .fused_moe_backend_base import MoEMethodBase


class XPUMoEMethod(MoEMethodBase):
    """
    XPU MOE
    """

    def create_weights(self, layer: nn.Layer, state_dict):
        """
        Paddle cutlass create weight process.
        """
        # bf16
        ffn1_weights, ffn2_weights = layer.extract_moe_ffn_weights(state_dict)
        for weights in [ffn1_weights, ffn2_weights]:
            for idx, weight in enumerate(weights):
                weights[idx] = weight.transpose([1, 0])
        stacked_ffn1_weights = paddle.stack(ffn1_weights, axis=0)
        stacked_ffn2_weights = paddle.stack(ffn2_weights, axis=0)
        for idx, weight_tensor in enumerate(
                [stacked_ffn1_weights, stacked_ffn2_weights]):
            weight_name = self.added_weight_attrs[idx]
            setattr(
                layer, weight_name,
                layer.create_parameter(
                    shape=weight_tensor.shape,
                    dtype=weight_tensor.dtype,
                    default_initializer=paddle.nn.initializer.Constant(0),
                ))
            getattr(layer, weight_name).set_value(weight_tensor)
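In the BF16 path above, create_weights only transposes and stacks the per-expert FFN weights; nothing is quantized. A minimal shape walk-through, assuming the per-expert shapes asserted in XPUWeightOnlyMoEMethod.create_weights further down (variable names are illustrative only):

# Hypothetical shape sketch for one expert's ffn1 weight:
# ffn1_weights[i]            -> [hidden_size, moe_intermediate_size * 2]   # as extracted from state_dict
# after transpose([1, 0])    -> [moe_intermediate_size * 2, hidden_size]
# moe_ffn1_weight (stacked)  -> [num_local_experts, moe_intermediate_size * 2, hidden_size]
# ffn2 follows the same pattern, starting from [moe_intermediate_size, hidden_size] per expert.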
    def apply_tp(
        self,
        layer: nn.Layer,
        x: paddle.Tensor,
        gate_out: paddle.Tensor,
    ) -> paddle.Tensor:
        """
        Paddle Cutlass compute Fused MoE.
        """
        from fastdeploy.model_executor.ops.xpu import xpu_moe_layer

        fused_moe_out = xpu_moe_layer(
            x,
            layer.gate_weight.transpose([1, 0]),
            layer.gate_correction_bias,
            layer.moe_ffn1_weight,
            layer.moe_ffn2_weight,
            None,  # ffn1 bias
            None,  # ffn2 bias
            None,  # ffn1 scale
            None,  # ffn2 scale
            None,  # ffn1_in_scale
            "",  # moe_quant_type
            layer.top_k,
            False,  # moe group, used in deepseek
        )
        if layer.tp_size > 1:
            from fastdeploy.distributed.communication_op import \
                tensor_model_parallel_all_reduce
            tensor_model_parallel_all_reduce(fused_moe_out)

        return fused_moe_out

    def apply_ep_prefill(
        self,
        layer: nn.Layer,
        x: paddle.Tensor,
        gate_out: paddle.Tensor,
    ) -> paddle.Tensor:
        """
        Apply the EP prefill method.
        """
        raise NotImplementedError

    def apply_ep_decode(
        self,
        layer: nn.Layer,
        x: paddle.Tensor,
        gate_out: paddle.Tensor,
    ) -> paddle.Tensor:
        """
        Apply the EP decoder method.
        """
        raise NotImplementedError
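The weight-only class that follows calls the same xpu_moe_layer op but with different quantization arguments. Summarized from the two apply methods in this file (no new API, just a reading aid):

# XPUMoEMethod.apply_tp (BF16, above):
#   all scale arguments -> None, moe_quant_type -> ""
# XPUWeightOnlyMoEMethod.apply (below):
#   per-expert moe_ffn1_weight_scale / moe_ffn2_weight_scale (and moe_ffn2_in_scale when present),
#   moe_quant_type -> self.quant_config.algo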
class XPUWeightOnlyMoEMethod(QuantMethodBase):
    """
    XPU Fused MoE Method.
    """

    def __init__(
        self,
        quant_config: WeightOnlyConfig,
    ) -> None:
        super().__init__()
        self.quant_config = quant_config
        self.moe_quant_type = self.quant_config.algo

    def create_weights(self, layer: nn.Layer, state_dict: Dict[str,
                                                                paddle.Tensor]):
        """
        Paddle cutlass create weight process.
        """
        ffn1_weights, ffn2_weights = layer.extract_moe_ffn_weights(state_dict)
        assert len(ffn1_weights) == layer.num_local_experts
        assert len(ffn2_weights) == layer.num_local_experts
        assert ffn1_weights[0].shape == [
            layer.hidden_size, layer.moe_intermediate_size * 2
        ]
        assert ffn2_weights[0].shape == [
            layer.moe_intermediate_size, layer.hidden_size
        ]

        added_weight_attrs = ["moe_ffn1_weight", "moe_ffn2_weight"]
        added_scale_attrs = ["moe_ffn1_weight_scale", "moe_ffn2_weight_scale"]

        for idx, weight_tensor in enumerate([ffn1_weights, ffn2_weights]):
            weight_name = added_weight_attrs[idx]
            scale_name = added_scale_attrs[idx]

            weight_list = []
            weight_scale_list = []
            for i in range(layer.num_local_experts):
                quant_weight, scale = weight_quantize_xpu(
                    weight_tensor[i], self.moe_quant_type, -1,
                    -1)  # weight is [k,n]
                weight_list.append(quant_weight.transpose(
                    [1, 0]))  # transpose weight to [n,k]
                weight_scale_list.append(scale)
            quanted_weight = paddle.stack(weight_list, axis=0)
            setattr(
                layer, weight_name,
                layer.create_parameter(
                    shape=quanted_weight.shape,
                    dtype=quanted_weight.dtype,
                    default_initializer=paddle.nn.initializer.Constant(0),
                ))
            getattr(layer, weight_name).set_value(quanted_weight)

            quanted_weight_scale = paddle.stack(weight_scale_list, axis=0)
            setattr(
                layer, scale_name,
                layer.create_parameter(
                    shape=quanted_weight_scale.shape,
                    dtype=quanted_weight_scale.dtype,
                ))
            getattr(layer, scale_name).set_value(quanted_weight_scale)

    def apply(
        self,
        layer: nn.Layer,
        x: paddle.Tensor,
        gate_out: paddle.Tensor,
    ) -> paddle.Tensor:
        """
        XPU compute Fused MoE.
        """
        from fastdeploy.model_executor.ops.xpu import xpu_moe_layer

        fused_moe_out = xpu_moe_layer(
            x,
            layer.gate_weight.transpose([1, 0]),
            layer.gate_correction_bias,
            layer.moe_ffn1_weight,
            layer.moe_ffn2_weight,
            None,  # ffn1 bias
            None,  # ffn2 bias
            (layer.moe_ffn1_weight_scale
             if hasattr(layer, "moe_ffn1_weight_scale") else None),
            (layer.moe_ffn2_weight_scale
             if hasattr(layer, "moe_ffn2_weight_scale") else None),
            (layer.moe_ffn2_in_scale
             if hasattr(layer, "moe_ffn2_in_scale") else None),
            self.moe_quant_type,
            layer.top_k,
            False,  # moe group, used in deepseek
        )
        if layer.tp_size > 1:
            from fastdeploy.distributed.communication_op import \
                tensor_model_parallel_all_reduce
            tensor_model_parallel_all_reduce(fused_moe_out)

        return fused_moe_out
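A minimal usage sketch of the weight-only method; the real entry point is WeightOnlyConfig.get_quant_method (last hunk), and moe_layer, state_dict, hidden_states, and gate_out are placeholder names, not identifiers from this diff:

# Hypothetical wiring, for illustration only
method = XPUWeightOnlyMoEMethod(quant_config)            # quant_config.algo selects the XPU quant scheme
method.create_weights(moe_layer, state_dict)             # quantizes each expert's ffn1/ffn2 and registers stacked params + scales
out = method.apply(moe_layer, hidden_states, gate_out)   # runs xpu_moe_layer with the per-expert weight scales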
@@ -20,9 +20,24 @@ from paddleformers.utils.log import logger

from fastdeploy import envs
from fastdeploy.model_executor.layers.utils import get_tensor
from fastdeploy.platforms import current_platform


def get_moe_method():
    """
    return moe method based on device platform
    """
    from fastdeploy.platforms import current_platform
    if current_platform.is_cuda():
        from .fused_moe_cutlass_backend import CutlassMoEMethod
        return CutlassMoEMethod(None)
    elif current_platform.is_xpu():
        from .fused_moe_xpu_backend import XPUMoEMethod
        return XPUMoEMethod(None)
    elif current_platform.is_gcu():
        from fastdeploy.model_executor.layers.backends import GCUFusedMoeMethod
        return GCUFusedMoeMethod(None)
    raise NotImplementedError()

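get_moe_method() centralizes the platform dispatch that was previously inlined in FusedMoE.__init__ (see the hunk after the class definition below). A minimal sketch of its use on an XPU machine, where layer and state_dict stand in for a FusedMoE layer and its checkpoint weights:

# Sketch, assuming current_platform reports XPU
quant_method = get_moe_method()                 # -> XPUMoEMethod(None) on XPU
quant_method.create_weights(layer, state_dict)  # BF16 path: transpose + stack per-expert weights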
class FusedMoE(nn.Layer):
    """
    FusedMoE is a layer that performs MoE (Mixture of Experts) computation.
@@ -96,13 +111,7 @@ class FusedMoE(nn.Layer):
             self.moe_quant_type = moe_quant_config.name()
         else:
             # For now, the no-quant method (w_fp16, a_fp16) can't be obtained from quant_config; we will optimize this in the future.
-            if current_platform.is_cuda():
-                from .fused_moe_cutlass_backend import CutlassMoEMethod
-                self.quant_method = CutlassMoEMethod(None)
-            elif current_platform.is_gcu():
-                from fastdeploy.model_executor.layers.backends import \
-                    GCUFusedMoeMethod
-                self.quant_method = GCUFusedMoeMethod(None)
+            self.quant_method = get_moe_method()

         if self.ep_size > 1:
             self.quant_method.init_ep(self)
@@ -60,8 +60,10 @@ class WeightOnlyConfig(QuantConfigBase):

     def get_quant_method(self, layer) -> Optional[QuantMethodBase]:
         if current_platform.is_xpu():
-            from fastdeploy.model_executor.layers.backends import (
-                XPUWeightOnlyLinearMethod, XPUWeightOnlyMoEMethod)
+            from fastdeploy.model_executor.layers.backends import \
+                XPUWeightOnlyLinearMethod
+            from fastdeploy.model_executor.layers.moe.fused_moe_xpu_backend import \
+                XPUWeightOnlyMoEMethod
             if isinstance(layer, FusedMoE):
                 return XPUWeightOnlyMoEMethod(self)
             else:
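For completeness, a sketch of the resulting XPU dispatch in get_quant_method. The else-branch is truncated in this view, so the linear-method return shown here is an assumption rather than part of the diff:

# Assumed completion of the truncated branch (hypothetical, not shown in the hunk)
if isinstance(layer, FusedMoE):
    return XPUWeightOnlyMoEMethod(self)
else:
    return XPUWeightOnlyLinearMethod(self)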