Mirror of https://github.com/PaddlePaddle/FastDeploy.git (synced 2025-10-06 17:17:14 +08:00)
[GCU] Support gcu platform (#2702)
baseline: e7fa57ebae
Co-authored-by: yongqiangma <xing.wo@163.com>
@@ -0,0 +1,16 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
gcu moe
"""
@@ -0,0 +1,402 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""

import multiprocessing
import os

import numpy as np
import paddle
from paddle import nn
from paddleformers.utils.log import logger

from fastdeploy.model_executor.layers.moe.fused_moe_backend_base import \
    MoEMethodBase
from fastdeploy.model_executor.layers.utils import (CpuGuard,
                                                    create_and_set_parameter,
                                                    get_tensor)
from fastdeploy.model_executor.ops.gcu import (invoke_fused_moe_kernel,
                                               moe_align_block_size,
                                               topk_softmax,
                                               weight_quantize_custom_rtn,
                                               weight_quantize_rtn)

class GCUFusedMoeMethod(MoEMethodBase):
    """
    Use GCU to compute Fused MoE.
    """

    def __init__(self, quant_config):
        super().__init__(quant_config)
        self.group_size = -1

    def create_weights(self, layer: nn.Layer, state_dict):
        """
        Paddle gcu create weight process.
        """
        # bf16
        ffn1_weights, ffn2_weights = layer.extract_moe_ffn_weights(state_dict)
        stacked_ffn1_weights = paddle.stack(ffn1_weights, axis=0)
        stacked_ffn2_weights = paddle.stack(ffn2_weights, axis=0)
        for idx, weight_tensor in enumerate(
                [stacked_ffn1_weights, stacked_ffn2_weights]):
            # shape [E, K, N] -> [E, N, K]
            weight_tensor = paddle.transpose(weight_tensor, [0, 2, 1])
            weight_name = self.added_weight_attrs[idx]
            setattr(
                layer, weight_name,
                layer.create_parameter(
                    shape=weight_tensor.shape,
                    dtype=weight_tensor.dtype,
                    default_initializer=paddle.nn.initializer.Constant(0),
                ))
            getattr(layer, weight_name).set_value(weight_tensor)

    @paddle.no_grad()
    def compute_ffn(
        self,
        layer: nn.Layer,
        x: paddle.Tensor,
        gate_out: paddle.Tensor,
        enable_quant=False,
    ) -> paddle.Tensor:
        """
        Paddle gcu compute Fused MoE.
        """
        token_num, hidden_size = x.shape
        top_k = layer.top_k
        moe_intermediate_size = layer.moe_intermediate_size
        num_experts = layer.num_local_experts

        topk_weights = paddle.empty([token_num, top_k], dtype=gate_out.dtype)
        topk_indices = paddle.empty([token_num, top_k], dtype="int32")
        token_expert_indices = paddle.empty([token_num, top_k], dtype="int32")
        topk_softmax(topk_weights, topk_indices, token_expert_indices, gate_out, norm_topk_prob=True)
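        # Routing step: softmax over the per-token gate logits, then select the
        # top_k experts per token; norm_topk_prob=True renormalizes the selected
        # weights so they sum to 1 for each token.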

        config = {
            "BLOCK_SIZE_M": 32,
            "BLOCK_SIZE_N": 128,
            "BLOCK_SIZE_K": 128,
            "GROUP_SIZE_M": 1,
        }
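        # Tile sizes for the fused MoE GEMM kernel: M tiles over the (padded)
        # token rows, N over output features, K over the reduction dimension.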

        block_size = config["BLOCK_SIZE_M"]
        max_num_tokens_padded = np.prod(topk_indices.shape) + num_experts * (block_size - 1)
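        # Worst case after per-expert padding: every expert rounds its token
        # count up to a multiple of BLOCK_SIZE_M, adding at most
        # (block_size - 1) padded slots per expert.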
        max_num_m_blocks = max_num_tokens_padded // block_size
        sorted_token_ids = paddle.empty([max_num_tokens_padded], dtype="int32")
        expert_ids = paddle.zeros(shape=[max_num_m_blocks], dtype="int32")
        num_tokens_post_pad = paddle.empty([1], dtype="int32")
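        # moe_align_block_size sorts the (token, expert) assignments by expert
        # and pads each expert's segment to a multiple of block_size, so each
        # kernel block only processes tokens routed to a single expert.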

        sorted_token_ids, expert_ids, num_tokens_post_pad = moe_align_block_size(
            sorted_token_ids,
            expert_ids,
            num_tokens_post_pad,
            topk_indices,
            num_experts,
            block_size,
        )

        intermediate_cache1 = paddle.empty(
            [token_num, top_k, moe_intermediate_size * 2],
            dtype=x.dtype,
        )

        ffn1_B_scale = layer.moe_ffn1_weight_scale if enable_quant else None
        ffn1_B_zeros = layer.moe_ffn1_weight_zeros if enable_quant else None

        invoke_fused_moe_kernel(
            x,  # input
            layer.moe_ffn1_weight,  # weight
            intermediate_cache1,  # output
            None,  # A_scale
            ffn1_B_scale,  # B_scale
            ffn1_B_zeros,  # B_zp
            topk_weights,
            topk_indices,
            sorted_token_ids,
            expert_ids,
            num_tokens_post_pad,
            False,  # mul_routed_weight
            top_k,
            config,
            enable_quant,  # use_int4_w4a16
            [0, self.group_size],  # block_shape
        )
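        # First grouped GEMM (ffn1, the gate/up projection): each token is
        # expanded to its top_k experts, producing
        # [token_num, top_k, 2 * moe_intermediate_size]; the routing weights
        # are not applied yet (mul_routed_weight=False).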

        intermediate_cache2 = paddle.empty(
            (token_num, top_k, moe_intermediate_size),
            dtype=x.dtype,
        )

        intermediate_cache2 = paddle.incubate.nn.functional.swiglu(
            intermediate_cache1)
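        # swiglu splits the last dimension of intermediate_cache1 into gate and
        # up halves and computes silu(gate) * up, yielding
        # [token_num, top_k, moe_intermediate_size].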

        intermediate_cache2 = intermediate_cache2.reshape([-1, moe_intermediate_size])

        intermediate_cache3 = paddle.empty(
            (token_num, top_k, hidden_size),
            dtype=x.dtype,
        )

        ffn2_B_scale = layer.moe_ffn2_weight_scale if enable_quant else None
        ffn2_B_zeros = layer.moe_ffn2_weight_zeros if enable_quant else None

        invoke_fused_moe_kernel(
            intermediate_cache2,  # input
            layer.moe_ffn2_weight,  # weight
            intermediate_cache3,  # output
            None,  # A_scale
            ffn2_B_scale,  # B_scale
            ffn2_B_zeros,  # B_zp
            topk_weights,
            topk_indices,
            sorted_token_ids,
            expert_ids,
            num_tokens_post_pad,
            True,  # mul_routed_weight
            1,
            config,
            enable_quant,  # use_int4_w4a16
            [0, self.group_size],  # block_shape
        )
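        # Second grouped GEMM (ffn2, the down projection): the rows are already
        # expanded to token_num * top_k, so top_k is passed as 1 and the
        # routing weights are multiplied in here (mul_routed_weight=True).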

        intermediate_cache3.reshape_([token_num, top_k, hidden_size])
        fused_moe_out = intermediate_cache3.sum(axis=1)
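        # Combine the weighted expert outputs over the top_k dimension to get
        # one hidden vector per token.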
        fused_moe_out = fused_moe_out.reshape_([token_num, hidden_size])

        if layer.tp_size > 1:
            from fastdeploy.distributed.communication_op import \
                tensor_model_parallel_all_reduce
            tensor_model_parallel_all_reduce(fused_moe_out)

        return fused_moe_out

    def apply(
        self,
        layer: nn.Layer,
        x: paddle.Tensor,
        gate_out: paddle.Tensor,
    ) -> paddle.Tensor:
        """
        Paddle gcu compute Fused MoE.
        """
        return self.compute_ffn(layer, x, gate_out, enable_quant=False)

    def apply_ep_prefill(
        self,
        layer: nn.Layer,
        x: paddle.Tensor,
        gate_out: paddle.Tensor,
    ) -> paddle.Tensor:
        """
        Apply the EP prefill method.
        """
        raise NotImplementedError

    def apply_ep_decode(
        self,
        layer: nn.Layer,
        x: paddle.Tensor,
        gate_out: paddle.Tensor,
    ) -> paddle.Tensor:
        """
        Apply the EP decode method.
        """
        raise NotImplementedError

    def apply_tp(
        self,
        layer: nn.Layer,
        x: paddle.Tensor,
        gate_out: paddle.Tensor,
    ) -> paddle.Tensor:
        """
        Paddle gcu compute Fused MoE (tensor parallel).
        """
        raise NotImplementedError


class GCUWeightOnlyMoEMethod(GCUFusedMoeMethod):
    """
    Weight-only quantization method for MoE.
    """

    def __init__(self, quant_config):
        super().__init__(quant_config)
        self.quant_config = quant_config
        self.moe_quant_type = self.quant_config.algo
        self.pack_num = 1

        assert self.quant_config.algo == "weight_only_int4", \
            f"GCUWeightOnlyMoEMethod only supports weight_only_int4, but got: {self.quant_config.algo}"

        self.added_qzeros_attrs = [
            "moe_ffn1_weight_zeros", "moe_ffn2_weight_zeros"
        ]
        self.group_size = 64
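        # Group-wise int4 quantization: one scale (and zero point) is kept per
        # group of 64 weights along the quantization axis.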

        self.quant_multi_process_group_size = int(
            os.getenv("FD_MOE_QUANT_MULTI_PROCESS_GROUP_SIZE", 8)
        )
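        # Number of experts handled per quantization subprocess; values <= 0
        # disable multi-process quantization and fall back to the serial path.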
        logger.info(f"GCUWeightOnlyMoEMethod quant_multi_process_group_size: {self.quant_multi_process_group_size}")

    def process_prequanted_weights(self, layer: nn.Layer, state_dict):
        """
        Paddle gcu process prequanted weights.
        """
        ffn1_expert_weight_key = layer.weight_key_map.get(
            "ffn1_expert_weight_key", None)
        ffn2_expert_weight_key = layer.weight_key_map.get(
            "ffn2_expert_weight_key", None)
        ffn1_expert_weight_scale_key = layer.weight_key_map.get(
            "ffn1_expert_weight_scale_key", None)
        ffn2_expert_weight_scale_key = layer.weight_key_map.get(
            "ffn2_expert_weight_scale_key", None)

        ffn1_weights, ffn2_weights = layer.load_experts_weight(
            state_dict, ffn1_expert_weight_key, ffn2_expert_weight_key)
        # self.check(layer, ffn1_weights, ffn2_weights)
        ffn1_weight_scale = []
        ffn2_weight_scale = []
        for i in range(layer.num_experts):
            expert_idx = layer.expert_id_offset + i
            ffn1_weight_scale.append(
                get_tensor(
                    state_dict.pop(
                        ffn1_expert_weight_scale_key.format(expert_idx))))
            ffn2_weight_scale.append(
                get_tensor(
                    state_dict.pop(
                        ffn2_expert_weight_scale_key.format(expert_idx))))

        ffn1_weight = paddle.stack(ffn1_weights, axis=0)
        ffn2_weight = paddle.stack(ffn2_weights, axis=0)
        ffn1_weight_scale = paddle.stack(ffn1_weight_scale, axis=0)
        ffn2_weight_scale = paddle.stack(ffn2_weight_scale, axis=0)

        name_tensor_map = {
            "moe_ffn1_weight": ffn1_weight,
            "moe_ffn2_weight": ffn2_weight,
            "moe_ffn1_weight_scale": ffn1_weight_scale,
            "moe_ffn2_weight_scale": ffn2_weight_scale,
        }
        for name, tensor in name_tensor_map.items():
            create_and_set_parameter(layer, name, tensor)

    @paddle.no_grad()
    def create_weights(self, layer: nn.Layer, state_dict):
        """
        Paddle gcu create weight process.
        """
        ffn1_weights, ffn2_weights = layer.extract_moe_ffn_weights(state_dict)
        self.check(layer, ffn1_weights, ffn2_weights)
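
        # Each expert weight is quantized to int4 with round-to-nearest (RTN).
        # quant_worker runs in a separate process and quantizes a group of
        # experts on the CPU, writing (quant_weight, scale) into a shared dict
        # keyed by the expert's global index within the group-processed range.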
        def quant_worker(p_group_idx, shared_dict, weights, moe_quant_type, group_size):
            with CpuGuard():
                p_group_size = len(weights)
                for group_j in range(p_group_size):
                    # weight shape [K, N] -> [N/2, K] -> [N, K/2]
                    quant_weight, scale = weight_quantize_custom_rtn(
                        weights[group_j],
                        moe_quant_type,
                        group_size  # group_size
                    )
                    shared_dict[p_group_size * p_group_idx + group_j] = (
                        quant_weight, scale
                    )

        for idx, weight_tensor in enumerate([ffn1_weights, ffn2_weights]):
            weight_name = self.added_weight_attrs[idx]
            scale_name = self.added_scale_attrs[idx]
            zeros_name = self.added_qzeros_attrs[idx]

            if self.quant_multi_process_group_size > 0:
                process_group_size = self.quant_multi_process_group_size
                process_group_num = layer.num_local_experts // process_group_size
                grouped_weights_num = process_group_num * process_group_size
                remain_weights_start_idx = grouped_weights_num
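                # Experts beyond the largest multiple of the process-group size
                # are not covered by the subprocesses and are quantized
                # serially further below.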

                weight_list = [None] * grouped_weights_num
                weight_scale_list = [None] * grouped_weights_num

                with multiprocessing.Manager() as manager:
                    shared_dict = manager.dict({})
                    processes = []

                    for i in range(process_group_num):
                        w = []
                        for j in range(process_group_size):
                            w.append(weight_tensor[process_group_size * i + j].to("cpu"))

                        p = multiprocessing.Process(
                            target=quant_worker,
                            args=(i, shared_dict, w, self.moe_quant_type, self.group_size)
                        )
                        p.start()
                        processes.append(p)

                    for p in processes:
                        p.join()

                    dict_ = dict(shared_dict)

                    for k, v in dict_.items():
                        weight_list[k] = v[0].to(ffn1_weights[0].place)
                        weight_scale_list[k] = v[1].to(ffn1_weights[0].place)
            else:
                # Serial fallback: all experts are quantized in the loop below.
                weight_list = []
                weight_scale_list = []
                remain_weights_start_idx = 0

            if remain_weights_start_idx < layer.num_local_experts:
                for i in range(remain_weights_start_idx, layer.num_local_experts):
                    # weight shape [K, N] -> [N/2, K] -> [N, K/2]
                    quant_weight, scale = weight_quantize_rtn(
                        weight_tensor[i],
                        self.moe_quant_type,
                        self.group_size  # group_size
                    )
                    weight_list.append(quant_weight)
                    weight_scale_list.append(scale)
            quanted_weight = paddle.stack(weight_list, axis=0)
            create_and_set_parameter(layer, weight_name, quanted_weight)

            quanted_weight_scale = paddle.stack(weight_scale_list, axis=0)
            create_and_set_parameter(layer, scale_name, quanted_weight_scale)

            quanted_weight_zeros = quanted_weight_scale * 8
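            # Zero points for the int4 layout: assuming the quantized values
            # occupy the unsigned range [0, 15] with midpoint 8, the dequant
            # zero offset is scale * 8 (interpretation of the factor above).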
            create_and_set_parameter(layer, zeros_name, quanted_weight_zeros)

    def apply(
        self,
        layer: nn.Layer,
        x: paddle.Tensor,
        gate_out: paddle.Tensor,
    ) -> paddle.Tensor:
        """
        Paddle gcu compute Fused MoE.
        """
        return self.compute_ffn(layer, x, gate_out, enable_quant=True)