[fix] adjust mctlass moe api (#4474)

This commit is contained in:
SuperNova
2025-10-20 14:23:54 +08:00
committed by GitHub
parent 1e59905e34
commit 80a16c4c87
5 changed files with 76 additions and 57 deletions

View File

@@ -1,18 +1,3 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mctlass/numeric_conversion.h"
#include "mctlassEx/mctlassEx.h"
#include "fused_moe_helper.h"
@@ -45,63 +30,75 @@ void mc_grouped_gemm_basic_kernel(
mctlassExMatrixLayout_t matLayoutC;
// mat A: (m, k)
mctlassExMatrixLayoutCreate(&matLayoutA, mctlassExDataType::MCTLASS_EX_BF16, m, k, k);
mctlassExMatrixLayoutSetAttribute(matLayoutA, mctlassExMatrixLayoutAttribute_t::MCTLASS_EX_MATRIX_LAYOUT_ORDER,
mctlassExMatrixLayoutCreate(&matLayoutA, mctlassExDataType::MCTLASS_EX_DATATYPE_BF16, m, k, k);
mctlassExMatrixLayoutSetAttribute(matLayoutA, mctlassExMatrixLayoutAttribute_t::MCTLASS_EX_MATRIX_LAYOUT_ORDER,
&majorA, sizeof(mctlassExOrder_t));
mctlassExMatrixLayoutSetAttribute(matLayoutA, mctlassExMatrixLayoutAttribute_t::MCTLASS_EX_MATRIX_LAYOUT_BATCH_COUNT,
&numExperts, sizeof(int));
// mat B: (num_experts, n, k)
mctlassExMatrixLayoutCreate(&matLayoutB, mctlassExDataType::MCTLASS_EX_INT8, k, n, k);
mctlassExMatrixLayoutCreate(&matLayoutB, mctlassExDataType::MCTLASS_EX_DATATYPE_INT8, k, n, k);
mctlassExMatrixLayoutSetAttribute(matLayoutB, mctlassExMatrixLayoutAttribute_t::MCTLASS_EX_MATRIX_LAYOUT_ORDER,
&majorB, sizeof(mctlassExOrder_t));
mctlassExMatrixLayoutSetAttribute(matLayoutB, mctlassExMatrixLayoutAttribute_t::MCTLASS_EX_MATRIX_LAYOUT_BATCH_COUNT,
&numExperts, sizeof(int));
// mat C: (m, n)
mctlassExMatrixLayoutCreate(&matLayoutC, mctlassExDataType::MCTLASS_EX_BF16, m, n, n);
mctlassExMatrixLayoutCreate(&matLayoutC, mctlassExDataType::MCTLASS_EX_DATATYPE_BF16, m, n, n);
mctlassExMatrixLayoutSetAttribute(matLayoutC, mctlassExMatrixLayoutAttribute_t::MCTLASS_EX_MATRIX_LAYOUT_ORDER,
&majorC, sizeof(mctlassExOrder_t));
mctlassExMatrixLayoutSetAttribute(matLayoutC, mctlassExMatrixLayoutAttribute_t::MCTLASS_EX_MATRIX_LAYOUT_BATCH_COUNT,
&numExperts, sizeof(int));
// bias: (num_experts, n)
// scale: (num, n)
mctlassExDesc_t mctlass_desc;
mctlassExCreateDesc(&mctlass_desc);
mctlassExDataType input_type = mctlassExDataType::MCTLASS_EX_BF16;
mctlassExDataType scale_type = mctlassExDataType::MCTLASS_EX_INT8;
mctlassExDataType compute_type = mctlassExDataType::MCTLASS_EX_FP32;
mctlassExEpilogueType epilogue_type = mctlassExEpilogueType::MCTLASS_EX_GEMM_DEFAULT;
mctlassExDataType input_type = mctlassExDataType::MCTLASS_EX_DATATYPE_BF16;
mctlassExDataType scale_type = mctlassExDataType::MCTLASS_EX_DATATYPE_INT8;
mctlassExDataType compute_type = mctlassExDataType::MCTLASS_EX_DATATYPE_FP32;
mctlassExEpilogueType epilogue_type = mctlassExEpilogueType::MCTLASS_EX_EPILOGUE_TYPE_DEFAULT;
if (ptrBias) {
epilogue_type = mctlassExEpilogueType::MCTLASS_EX_GEMM_BIAS_PERGROUP;
epilogue_type = mctlassExEpilogueType::MCTLASS_EX_EPILOGUE_TYPE_BIAS;
}
// set scale
mctlassExDescSetAttribute(mctlass_desc, mctlassExDescAttributes_t::MCTLASS_EX_GEMM_DESC_B_SCALE_POINTER,
mctlassExDescSetAttribute(mctlass_desc, mctlassExDescAttributes_t::MCTLASS_EX_DESC_B_SCALE_POINTER,
&ptrScale, sizeof(ptrScale));
mctlassExDescSetAttribute(mctlass_desc, mctlassExDescAttributes_t::MCTLASS_EX_GEMM_DESC_B_SCALE_TYPE,
&scale_type, sizeof(mctlassExDataType));
mctlassExDescSetAttribute(mctlass_desc, mctlassExDescAttributes_t::MCTLASS_EX_DESC_B_SCALE_TYPE,
&input_type, sizeof(mctlassExDataType));
// set bias
if (ptrBias) {
mctlassExDescSetAttribute(mctlass_desc, mctlassExDescAttributes_t::MCTLASS_EX_GEMM_DESC_BIAS_POINTER,
mctlassExDescSetAttribute(mctlass_desc, mctlassExDescAttributes_t::MCTLASS_EX_DESC_BIAS_POINTER,
&ptrBias, sizeof(ptrBias));
}
// set compute type
mctlassExDescSetAttribute(mctlass_desc, mctlassExDescAttributes_t::MCTLASS_EX_GEMM_DESC_COMPUTE_TYPE,
mctlassExDescSetAttribute(mctlass_desc, mctlassExDescAttributes_t::MCTLASS_EX_DESC_COMPUTE_TYPE,
&compute_type, sizeof(mctlassExDataType));
// set epilogue type
mctlassExDescSetAttribute(mctlass_desc, mctlassExDescAttributes_t::MCTLASS_EX_GEMM_DESC_EPILOGUE_TYPE,
mctlassExDescSetAttribute(mctlass_desc, mctlassExDescAttributes_t::MCTLASS_EX_DESC_EPILOGUE_TYPE,
&epilogue_type, sizeof(mctlassExEpilogueType));
const mctlassExContiguousGroupedGemmAlgo_t algo = mctlassExContiguousGroupedGemmAlgo_t::MCTLASS_EX_CONTIGUOUS_GROUPED_ALGO_SEGPTR;
int blocksizeM = mctlassExContiguousGroupedGemmGetBlocksizeM(handle, mctlass_desc, matLayoutA, matLayoutB, matLayoutC, &algo);
mctlassExContiguousGroupedGemmComputeMNumTilesIndptr(handle, mctlass_desc, matLayoutA, matLayoutB, matLayoutC, &algo, ptrSegInd, ptrMNumTilesInd, numExperts, blocksizeM);
const mctlassExContiguousGroupedGemmAlgo_t algo = mctlassExContiguousGroupedGemmAlgo_t::MCTLASS_EX_CONTIGUOUS_GROUPED_ALGO_DEFAULT;
mctlassExContiguousGroupedDesc_t contiguous_group_desc;
mctlassExContiguousGroupedDescCreate(&contiguous_group_desc,
ptrSegInd,
nullptr,
ptrMNumTilesInd,
1);
int blocksizeM;
mctlassExContiguousGroupedGemmGetBlocksizeM(handle, mctlass_desc, matLayoutA, matLayoutB, matLayoutC, &algo, &blocksizeM);
mctlassExContiguousGroupedGemmComputeMNumTilesIndptr(handle, mctlass_desc, matLayoutA, matLayoutB, matLayoutC, &algo, contiguous_group_desc, numExperts, blocksizeM, stream);
mctlassExContiguousGroupedGemmBasic(handle, mctlass_desc,
ptrA, matLayoutA,
ptrB, matLayoutB,
ptrC, matLayoutC,
ptrSegInd, nullptr, ptrMNumTilesInd,
contiguous_group_desc,
&algo, nullptr, 0, stream);
mctlassExHandleDestroy(handle);
mctlassExMatrixLayoutDestroy(matLayoutA);
mctlassExMatrixLayoutDestroy(matLayoutB);
mctlassExMatrixLayoutDestroy(matLayoutC);
mctlassExContiguousGroupedDescDestroy(contiguous_group_desc);
mctlassExDestroyDesc(mctlass_desc);
mcFreeAsync(ptrMNumTilesInd, stream);
}
@@ -334,8 +331,8 @@ class McMoeHelper {
total_rows_before_expert_,
stream);
mctlassExOrder_t row_major = mctlassExOrder_t::MCTLASS_EX_ROWMAJOR_ORDER;
mctlassExOrder_t column_major = mctlassExOrder_t::MCTLASS_EX_COLUMNMAJOR_ORDER;
mctlassExOrder_t row_major = mctlassExOrder_t::MCTLASS_EX_ORDER_ROW_MAJOR;
mctlassExOrder_t column_major = mctlassExOrder_t::MCTLASS_EX_ORDER_COLUMN_MAJOR;
mc_grouped_gemm_basic_kernel<ElementA, ElementB, ElementC>(
reinterpret_cast<const ElementA *>(permuted_data_),

View File

@@ -1,18 +1,18 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// BUILD_MARK
#pragma once
#include "mc_fused_moe_helper.h"
#include "helper.h"
@@ -47,8 +47,8 @@ void McMoeFFNKernel(const paddle::Tensor& permute_input,
{expanded_active_expert_rows, inter_size}, input_type, place);
auto fc1_out_ptr = fc1_out_tensor.data<data_t>();
mctlassExOrder_t row_major = mctlassExOrder_t::MCTLASS_EX_ROWMAJOR_ORDER;
mctlassExOrder_t column_major = mctlassExOrder_t::MCTLASS_EX_COLUMNMAJOR_ORDER;
mctlassExOrder_t row_major = mctlassExOrder_t::MCTLASS_EX_ORDER_ROW_MAJOR;
mctlassExOrder_t column_major = mctlassExOrder_t::MCTLASS_EX_ORDER_COLUMN_MAJOR;
// ffn1
auto fc1_expert_biases =
@@ -131,7 +131,7 @@ std::vector<paddle::Tensor> MoeExpertFFN(
// ffn_out);
// break;
default:
PD_THROW("Only support bf16 for MoeExpertFFN");
PD_THROW("Unsupported data type for MoeExpertFFN");
}
return {ffn_out};
}

View File

@@ -1,4 +1,5 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
"""
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -11,6 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
import paddle
from paddle import nn
@@ -110,7 +112,7 @@ class MetaxCutlassMoEMethod(MoEMethodBase):
False,
)
if layer.reduce_results and layer.tp_size > 1:
fused_moe_out = tensor_model_parallel_all_reduce(fused_moe_out)
tensor_model_parallel_all_reduce(fused_moe_out, layer.fd_config.parallel_config.tp_group)
return fused_moe_out

View File

@@ -1,3 +1,4 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -11,6 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
import paddle
from paddle import nn
@@ -44,6 +46,7 @@ class MetaxTritonWeightOnlyMoEMethod(QuantMethodBase):
"""process_prequanted_weights"""
pass
@paddle.no_grad()
def create_weights(self, layer: nn.Layer, **extra_weight_attrs):
"""
Triton MoE create weight process.
@@ -100,6 +103,7 @@ class MetaxTritonWeightOnlyMoEMethod(QuantMethodBase):
),
)
@paddle.no_grad()
def process_loaded_weights(self, layer: nn.Layer, state_dict):
"""
Triton MoE load weight process.
@@ -110,8 +114,6 @@ class MetaxTritonWeightOnlyMoEMethod(QuantMethodBase):
algo = layer.quant_method.quant_config.name()
assert algo == "wint8"
assert up_gate_proj_weights[0].shape == [
layer.hidden_size,
layer.moe_intermediate_size * 2,
@@ -151,31 +153,42 @@ class MetaxTritonWeightOnlyMoEMethod(QuantMethodBase):
"""
Triton compute Fused MoE.
"""
gate_out = gate(x.cast("float32"))
token_num = x.shape[0]
top_k = layer.top_k
num_local_experts = layer.num_local_experts
top_k = layer.top_k
moe_intermediate_size = layer.moe_intermediate_size
hidden_size = layer.hidden_size
gate_out = gate(x.cast("float32"))
topk_ids, topk_weights = fastdeploy.model_executor.ops.gpu.moe_topk_select(
gate_out,
layer.gate_correction_bias,
top_k,
True, # apply_norm_weight,
layer.top_k,
True, # apply_norm_weight
False,
)
up_gate_proj_out = paddle.empty(
[token_num * top_k, moe_intermediate_size * 2],
dtype=x.dtype,
)
config = {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 4,
}
if self.quant_config is not None:
config = {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 4,
}
else:
config = {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 4,
}
sorted_token_ids, expert_ids, num_tokens_post_padded = tritonmoe_preprocess(
topk_ids, num_local_experts, config["BLOCK_SIZE_M"]
)
@@ -282,5 +295,5 @@ class MetaxTritonWeightOnlyMoEMethod(QuantMethodBase):
down_proj_out.reshape_([token_num, top_k, hidden_size])
out = down_proj_out.sum(axis=1)
if layer.tp_size > 1:
out = tensor_model_parallel_all_reduce(out)
tensor_model_parallel_all_reduce(out, layer.fd_config.parallel_config.tp_group)
return out

View File

@@ -56,6 +56,13 @@ def get_moe_method():
return HpuMoEMethod(None)
# return HpuTensorWiseFP8MoEMethod(None)
elif current_platform.is_maca():
from fastdeploy.model_executor.layers.backends import (
MetaxCutlassWeightOnlyMoEMethod,
)
return MetaxCutlassWeightOnlyMoEMethod(None)
raise NotImplementedError