Mirror of https://github.com/PaddlePaddle/FastDeploy.git
[fix] adjust mctlass moe api (#4474)
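In brief, judging from the hunks below, this commit ports the Metax MoE path to a renamed mctlassEx API: data-type enums become MCTLASS_EX_DATATYPE_*, descriptor attributes become MCTLASS_EX_DESC_*, epilogue types become MCTLASS_EX_EPILOGUE_TYPE_*, and matrix orders become MCTLASS_EX_ORDER_*. The contiguous grouped GEMM now passes its segment and tile-count bookkeeping through a new mctlassExContiguousGroupedDesc_t object (with a matching create/destroy pair) instead of raw pointers, and GetBlocksizeM returns its result through an out-parameter. On the Python side, tensor_model_parallel_all_reduce now receives the tensor-parallel process group explicitly, and get_moe_method gains a MACA branch returning MetaxCutlassWeightOnlyMoEMethod.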
@@ -1,18 +1,3 @@
-// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//     http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-
 #include "mctlass/numeric_conversion.h"
 #include "mctlassEx/mctlassEx.h"
 #include "fused_moe_helper.h"
@@ -45,63 +30,75 @@ void mc_grouped_gemm_basic_kernel(
     mctlassExMatrixLayout_t matLayoutC;

     // mat A: (m, k)
-    mctlassExMatrixLayoutCreate(&matLayoutA, mctlassExDataType::MCTLASS_EX_BF16, m, k, k);
+    mctlassExMatrixLayoutCreate(&matLayoutA, mctlassExDataType::MCTLASS_EX_DATATYPE_BF16, m, k, k);
     mctlassExMatrixLayoutSetAttribute(matLayoutA, mctlassExMatrixLayoutAttribute_t::MCTLASS_EX_MATRIX_LAYOUT_ORDER,
                                       &majorA, sizeof(mctlassExOrder_t));
     mctlassExMatrixLayoutSetAttribute(matLayoutA, mctlassExMatrixLayoutAttribute_t::MCTLASS_EX_MATRIX_LAYOUT_BATCH_COUNT,
                                       &numExperts, sizeof(int));
     // mat B: (num_experts, n, k)
-    mctlassExMatrixLayoutCreate(&matLayoutB, mctlassExDataType::MCTLASS_EX_INT8, k, n, k);
+    mctlassExMatrixLayoutCreate(&matLayoutB, mctlassExDataType::MCTLASS_EX_DATATYPE_INT8, k, n, k);
     mctlassExMatrixLayoutSetAttribute(matLayoutB, mctlassExMatrixLayoutAttribute_t::MCTLASS_EX_MATRIX_LAYOUT_ORDER,
                                       &majorB, sizeof(mctlassExOrder_t));
     mctlassExMatrixLayoutSetAttribute(matLayoutB, mctlassExMatrixLayoutAttribute_t::MCTLASS_EX_MATRIX_LAYOUT_BATCH_COUNT,
                                       &numExperts, sizeof(int));
     // mat C: (m, n)
-    mctlassExMatrixLayoutCreate(&matLayoutC, mctlassExDataType::MCTLASS_EX_BF16, m, n, n);
+    mctlassExMatrixLayoutCreate(&matLayoutC, mctlassExDataType::MCTLASS_EX_DATATYPE_BF16, m, n, n);
     mctlassExMatrixLayoutSetAttribute(matLayoutC, mctlassExMatrixLayoutAttribute_t::MCTLASS_EX_MATRIX_LAYOUT_ORDER,
                                       &majorC, sizeof(mctlassExOrder_t));
     mctlassExMatrixLayoutSetAttribute(matLayoutC, mctlassExMatrixLayoutAttribute_t::MCTLASS_EX_MATRIX_LAYOUT_BATCH_COUNT,
                                       &numExperts, sizeof(int));
     // bias: (num_experts, n)
     // scale: (num, n)

     mctlassExDesc_t mctlass_desc;
     mctlassExCreateDesc(&mctlass_desc);
-    mctlassExDataType input_type = mctlassExDataType::MCTLASS_EX_BF16;
-    mctlassExDataType scale_type = mctlassExDataType::MCTLASS_EX_INT8;
-    mctlassExDataType compute_type = mctlassExDataType::MCTLASS_EX_FP32;
-    mctlassExEpilogueType epilogue_type = mctlassExEpilogueType::MCTLASS_EX_GEMM_DEFAULT;
+    mctlassExDataType input_type = mctlassExDataType::MCTLASS_EX_DATATYPE_BF16;
+    mctlassExDataType scale_type = mctlassExDataType::MCTLASS_EX_DATATYPE_INT8;
+    mctlassExDataType compute_type = mctlassExDataType::MCTLASS_EX_DATATYPE_FP32;
+    mctlassExEpilogueType epilogue_type = mctlassExEpilogueType::MCTLASS_EX_EPILOGUE_TYPE_DEFAULT;
     if (ptrBias) {
-        epilogue_type = mctlassExEpilogueType::MCTLASS_EX_GEMM_BIAS_PERGROUP;
+        epilogue_type = mctlassExEpilogueType::MCTLASS_EX_EPILOGUE_TYPE_BIAS;
     }
     // set scale
-    mctlassExDescSetAttribute(mctlass_desc, mctlassExDescAttributes_t::MCTLASS_EX_GEMM_DESC_B_SCALE_POINTER,
+    mctlassExDescSetAttribute(mctlass_desc, mctlassExDescAttributes_t::MCTLASS_EX_DESC_B_SCALE_POINTER,
                               &ptrScale, sizeof(ptrScale));
-    mctlassExDescSetAttribute(mctlass_desc, mctlassExDescAttributes_t::MCTLASS_EX_GEMM_DESC_B_SCALE_TYPE,
-                              &scale_type, sizeof(mctlassExDataType));
+    mctlassExDescSetAttribute(mctlass_desc, mctlassExDescAttributes_t::MCTLASS_EX_DESC_B_SCALE_TYPE,
+                              &input_type, sizeof(mctlassExDataType));
     // set bias
     if (ptrBias) {
-        mctlassExDescSetAttribute(mctlass_desc, mctlassExDescAttributes_t::MCTLASS_EX_GEMM_DESC_BIAS_POINTER,
+        mctlassExDescSetAttribute(mctlass_desc, mctlassExDescAttributes_t::MCTLASS_EX_DESC_BIAS_POINTER,
                                   &ptrBias, sizeof(ptrBias));
     }
     // set coumpute type
-    mctlassExDescSetAttribute(mctlass_desc, mctlassExDescAttributes_t::MCTLASS_EX_GEMM_DESC_COMPUTE_TYPE,
+    mctlassExDescSetAttribute(mctlass_desc, mctlassExDescAttributes_t::MCTLASS_EX_DESC_COMPUTE_TYPE,
                               &compute_type, sizeof(mctlassExDataType));
     // set epilogue type
-    mctlassExDescSetAttribute(mctlass_desc, mctlassExDescAttributes_t::MCTLASS_EX_GEMM_DESC_EPILOGUE_TYPE,
+    mctlassExDescSetAttribute(mctlass_desc, mctlassExDescAttributes_t::MCTLASS_EX_DESC_EPILOGUE_TYPE,
                               &epilogue_type, sizeof(mctlassExEpilogueType));

-    const mctlassExContiguousGroupedGemmAlgo_t algo = mctlassExContiguousGroupedGemmAlgo_t::MCTLASS_EX_CONTIGUOUS_GROUPED_ALGO_SEGPTR;
-    int blocksizeM = mctlassExContiguousGroupedGemmGetBlocksizeM(handle, mctlass_desc, matLayoutA, matLayoutB, matLayoutC, &algo);
-    mctlassExContiguousGroupedGemmComputeMNumTilesIndptr(handle, mctlass_desc, matLayoutA, matLayoutB, matLayoutC, &algo, ptrSegInd, ptrMNumTilesInd, numExperts, blocksizeM);
+    const mctlassExContiguousGroupedGemmAlgo_t algo = mctlassExContiguousGroupedGemmAlgo_t::MCTLASS_EX_CONTIGUOUS_GROUPED_ALGO_DEFAULT;
+    mctlassExContiguousGroupedDesc_t contiguous_group_desc;
+    mctlassExContiguousGroupedDescCreate(&contiguous_group_desc,
+                                         ptrSegInd,
+                                         nullptr,
+                                         ptrMNumTilesInd,
+                                         1);
+    int blocksizeM;
+    mctlassExContiguousGroupedGemmGetBlocksizeM(handle, mctlass_desc, matLayoutA, matLayoutB, matLayoutC, &algo, &blocksizeM);
+    mctlassExContiguousGroupedGemmComputeMNumTilesIndptr(handle, mctlass_desc, matLayoutA, matLayoutB, matLayoutC, &algo, contiguous_group_desc, numExperts, blocksizeM, stream);

     mctlassExContiguousGroupedGemmBasic(handle, mctlass_desc,
                                         ptrA, matLayoutA,
                                         ptrB, matLayoutB,
                                         ptrC, matLayoutC,
-                                        ptrSegInd, nullptr, ptrMNumTilesInd,
+                                        contiguous_group_desc,
                                         &algo, nullptr, 0, stream);

     mctlassExHandleDestroy(handle);
     mctlassExMatrixLayoutDestroy(matLayoutA);
     mctlassExMatrixLayoutDestroy(matLayoutB);
     mctlassExMatrixLayoutDestroy(matLayoutC);
+    mctlassExContiguousGroupedDescDestroy(contiguous_group_desc);
     mctlassExDestroyDesc(mctlass_desc);
     mcFreeAsync(ptrMNumTilesInd, stream);
 }
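To make the new call order easier to follow, here is a condensed sketch of the contiguous grouped GEMM sequence as it appears on the added side of the hunk above. The call names, argument order, and enum values are copied from the diff; the parameter type names (mctlassExHandle_t, mcStream_t, and the int pointers for the segment and tile-count indices) are assumptions, since the mctlassEx headers are not shown here.

// A minimal sketch of the new mctlassEx contiguous grouped GEMM call order,
// assembled only from the calls visible in this diff. Handle, stream, and
// index-pointer types are assumed; the authoritative signatures live in
// mctlassEx/mctlassEx.h.
#include "mctlassEx/mctlassEx.h"

void grouped_gemm_call_order(mctlassExHandle_t handle,              // assumed type name
                             mctlassExDesc_t mctlass_desc,
                             mctlassExMatrixLayout_t matLayoutA,
                             mctlassExMatrixLayout_t matLayoutB,
                             mctlassExMatrixLayout_t matLayoutC,
                             const void* ptrA, const void* ptrB, void* ptrC,
                             int* ptrSegInd, int* ptrMNumTilesInd,  // assumed element types
                             int numExperts, mcStream_t stream) {   // assumed type name
    // The algo enum moves from ..._ALGO_SEGPTR to ..._ALGO_DEFAULT.
    const mctlassExContiguousGroupedGemmAlgo_t algo =
        mctlassExContiguousGroupedGemmAlgo_t::MCTLASS_EX_CONTIGUOUS_GROUPED_ALGO_DEFAULT;

    // Segment indices and per-expert tile counts are now wrapped in a descriptor
    // instead of being passed to the GEMM as three raw-pointer arguments.
    mctlassExContiguousGroupedDesc_t group_desc;
    mctlassExContiguousGroupedDescCreate(&group_desc, ptrSegInd, nullptr, ptrMNumTilesInd, 1);

    // blocksizeM is now returned through an out-parameter, not the return value.
    int blocksizeM;
    mctlassExContiguousGroupedGemmGetBlocksizeM(handle, mctlass_desc, matLayoutA, matLayoutB,
                                                matLayoutC, &algo, &blocksizeM);

    // The indptr computation takes the descriptor and, newly, the stream.
    mctlassExContiguousGroupedGemmComputeMNumTilesIndptr(handle, mctlass_desc, matLayoutA,
                                                         matLayoutB, matLayoutC, &algo,
                                                         group_desc, numExperts, blocksizeM,
                                                         stream);

    // The GEMM itself consumes the descriptor in place of the raw pointers.
    mctlassExContiguousGroupedGemmBasic(handle, mctlass_desc,
                                        ptrA, matLayoutA,
                                        ptrB, matLayoutB,
                                        ptrC, matLayoutC,
                                        group_desc,
                                        &algo, nullptr, 0, stream);

    // The new descriptor gets a matching destroy call alongside the existing teardown.
    mctlassExContiguousGroupedDescDestroy(group_desc);
}

Bundling the bookkeeping into one descriptor presumably keeps the GemmBasic signature stable if more per-group metadata is added later, though the diff itself does not state the motivation.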
@@ -334,8 +331,8 @@ class McMoeHelper {
         total_rows_before_expert_,
         stream);

-    mctlassExOrder_t row_major = mctlassExOrder_t::MCTLASS_EX_ROWMAJOR_ORDER;
-    mctlassExOrder_t column_major = mctlassExOrder_t::MCTLASS_EX_COLUMNMAJOR_ORDER;
+    mctlassExOrder_t row_major = mctlassExOrder_t::MCTLASS_EX_ORDER_ROW_MAJOR;
+    mctlassExOrder_t column_major = mctlassExOrder_t::MCTLASS_EX_ORDER_COLUMN_MAJOR;

     mc_grouped_gemm_basic_kernel<ElementA, ElementB, ElementC>(
         reinterpret_cast<const ElementA *>(permuted_data_),
@@ -1,18 +1,18 @@
 // Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
 // You may obtain a copy of the License at
 //
 //     http://www.apache.org/licenses/LICENSE-2.0
 //
 // Unless required by applicable law or agreed to in writing, software
 // distributed under the License is distributed on an "AS IS" BASIS,
 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 // See the License for the specific language governing permissions and
 // limitations under the License.

 // BUILD_MARK
 #pragma once
 #include "mc_fused_moe_helper.h"
 #include "helper.h"
@@ -47,8 +47,8 @@ void McMoeFFNKernel(const paddle::Tensor& permute_input,
         {expanded_active_expert_rows, inter_size}, input_type, place);
     auto fc1_out_ptr = fc1_out_tensor.data<data_t>();

-    mctlassExOrder_t row_major = mctlassExOrder_t::MCTLASS_EX_ROWMAJOR_ORDER;
-    mctlassExOrder_t column_major = mctlassExOrder_t::MCTLASS_EX_COLUMNMAJOR_ORDER;
+    mctlassExOrder_t row_major = mctlassExOrder_t::MCTLASS_EX_ORDER_ROW_MAJOR;
+    mctlassExOrder_t column_major = mctlassExOrder_t::MCTLASS_EX_ORDER_COLUMN_MAJOR;

     // ffn1
     auto fc1_expert_biases =
@@ -131,7 +131,7 @@ std::vector<paddle::Tensor> MoeExpertFFN(
         //     ffn_out);
         //   break;
         default:
-            PD_THROW("Only support bf16 for MoeExpertFFN");
+            PD_THROW("Unsupported data type for MoeExpertFFN");
     }
     return {ffn_out};
 }
@@ -1,4 +1,5 @@
-# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+"""
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -11,6 +12,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+"""

 import paddle
 from paddle import nn
@@ -110,7 +112,7 @@ class MetaxCutlassMoEMethod(MoEMethodBase):
             False,
         )
         if layer.reduce_results and layer.tp_size > 1:
-            fused_moe_out = tensor_model_parallel_all_reduce(fused_moe_out)
+            tensor_model_parallel_all_reduce(fused_moe_out, layer.fd_config.parallel_config.tp_group)

         return fused_moe_out

@@ -1,3 +1,4 @@
+"""
 # Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -11,6 +12,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+"""

 import paddle
 from paddle import nn
@@ -44,6 +46,7 @@ class MetaxTritonWeightOnlyMoEMethod(QuantMethodBase):
         """process_prequanted_weights"""
         pass

+    @paddle.no_grad()
     def create_weights(self, layer: nn.Layer, **extra_weight_attrs):
         """
         Triton MoE create weight process.
@@ -100,6 +103,7 @@ class MetaxTritonWeightOnlyMoEMethod(QuantMethodBase):
             ),
         )

+    @paddle.no_grad()
     def process_loaded_weights(self, layer: nn.Layer, state_dict):
         """
         Triton MoE load weight process.
@@ -110,8 +114,6 @@ class MetaxTritonWeightOnlyMoEMethod(QuantMethodBase):

         algo = layer.quant_method.quant_config.name()

-        assert algo == "wint8"
-
         assert up_gate_proj_weights[0].shape == [
             layer.hidden_size,
             layer.moe_intermediate_size * 2,
@@ -151,31 +153,42 @@ class MetaxTritonWeightOnlyMoEMethod(QuantMethodBase):
         """
         Triton compute Fused MoE.
         """
-        gate_out = gate(x.cast("float32"))
         token_num = x.shape[0]
-        top_k = layer.top_k
         num_local_experts = layer.num_local_experts
+        top_k = layer.top_k
         moe_intermediate_size = layer.moe_intermediate_size
         hidden_size = layer.hidden_size

+        gate_out = gate(x.cast("float32"))
         topk_ids, topk_weights = fastdeploy.model_executor.ops.gpu.moe_topk_select(
             gate_out,
             layer.gate_correction_bias,
-            top_k,
-            True,  # apply_norm_weight,
+            layer.top_k,
+            True,  # apply_norm_weight
             False,
         )

         up_gate_proj_out = paddle.empty(
             [token_num * top_k, moe_intermediate_size * 2],
             dtype=x.dtype,
         )

-        config = {
-            "BLOCK_SIZE_M": 32,
-            "BLOCK_SIZE_N": 64,
-            "BLOCK_SIZE_K": 64,
-            "GROUP_SIZE_M": 4,
-        }
+        if self.quant_config is not None:
+            config = {
+                "BLOCK_SIZE_M": 32,
+                "BLOCK_SIZE_N": 64,
+                "BLOCK_SIZE_K": 64,
+                "GROUP_SIZE_M": 4,
+            }
+        else:
+            config = {
+                "BLOCK_SIZE_M": 32,
+                "BLOCK_SIZE_N": 64,
+                "BLOCK_SIZE_K": 64,
+                "GROUP_SIZE_M": 4,
+            }

         sorted_token_ids, expert_ids, num_tokens_post_padded = tritonmoe_preprocess(
             topk_ids, num_local_experts, config["BLOCK_SIZE_M"]
         )
@@ -282,5 +295,5 @@ class MetaxTritonWeightOnlyMoEMethod(QuantMethodBase):
         down_proj_out.reshape_([token_num, top_k, hidden_size])
         out = down_proj_out.sum(axis=1)
         if layer.tp_size > 1:
-            out = tensor_model_parallel_all_reduce(out)
+            tensor_model_parallel_all_reduce(out, layer.fd_config.parallel_config.tp_group)
         return out
@@ -56,6 +56,13 @@ def get_moe_method():

         return HpuMoEMethod(None)
         # return HpuTensorWiseFP8MoEMethod(None)
+    elif current_platform.is_maca():
+        from fastdeploy.model_executor.layers.backends import (
+            MetaxCutlassWeightOnlyMoEMethod,
+        )
+
+        return MetaxCutlassWeightOnlyMoEMethod(None)
     raise NotImplementedError