[Metax] support cutlass moe & optimize flash attention (#4208)

Author: xiaozude
Date: 2025-09-29 11:22:43 +08:00 (committed by GitHub)
Parent: 2b2b645296
Commit: 7c919070f7
20 changed files with 2786 additions and 103 deletions

View File

@@ -14,6 +14,8 @@
#pragma once
#include <cuda_fp8.h>
#ifndef PADDLE_WITH_COREX
#include "glog/logging.h"
#endif

View File

@@ -0,0 +1,181 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "helper.h"
#include "mc_fused_moe_helper.h"
#include "fused_moe_op.h"
__global__ void compute_total_rows_before_expert_kernel(
int* sorted_experts,
const int64_t sorted_experts_len,
const int64_t num_experts,
int32_t* total_rows_before_expert) {
const int expert = blockIdx.x * blockDim.x + threadIdx.x;
if (expert >= num_experts) return;
total_rows_before_expert[expert] =
find_total_elts_leq_target(sorted_experts, sorted_experts_len, expert);
}
void compute_total_rows_before_expert(int* sorted_indices,
const int64_t total_indices,
const int64_t num_experts,
int32_t* total_rows_before_expert,
cudaStream_t stream) {
const int threads = std::min(int64_t(1024), num_experts);
const int blocks = (num_experts + threads - 1) / threads;
compute_total_rows_before_expert_kernel<<<blocks, threads, 0, stream>>>(
sorted_indices, total_indices, num_experts, total_rows_before_expert);
}
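
For illustration, a minimal host-side sketch (with assumed example data) of what the two routines above compute: for each expert id e, the number of entries in the expert-sorted row list that are <= e, i.e. the end offset of that expert's contiguous row segment.

#include <algorithm>
#include <cstdio>
#include <vector>
// Host-side sketch; sorted_experts and num_experts are assumed example values.
int main() {
  std::vector<int> sorted_experts = {0, 0, 1, 1, 1, 3};  // expert id per permuted row
  const int num_experts = 4;
  for (int e = 0; e < num_experts; ++e) {
    int total = static_cast<int>(
        std::upper_bound(sorted_experts.begin(), sorted_experts.end(), e) -
        sorted_experts.begin());
    printf("total_rows_before_expert[%d] = %d\n", e, total);  // prints 2 5 5 6
  }
  return 0;
}
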
template <paddle::DataType T, typename ElementA, typename ElementB, typename ElementC>
void FusedMoeKernel(const paddle::Tensor& input,
const paddle::Tensor& gate_weight,
const paddle::Tensor& ffn1_weight,
const paddle::optional<paddle::Tensor>& ffn1_scale,
const paddle::optional<paddle::Tensor>& ffn1_bias,
const paddle::Tensor& ffn2_weight,
const paddle::optional<paddle::Tensor>& ffn2_scale,
const paddle::optional<paddle::Tensor>& ffn2_bias,
const std::string& quant_method,
const int moe_topk,
const bool group_moe,
const bool norm_topk_prob,
paddle::Tensor* output) {
typedef PDTraits<T> traits_;
typedef typename traits_::DataType DataType_;
typedef typename traits_::data_t data_t;
auto* output_data = output->data<data_t>();
auto moe_compute = McMoeHelper<data_t, ElementA, ElementB, ElementC>(quant_method);
moe_compute.computeFFN(
&input,
&gate_weight,
&ffn1_weight,
ffn1_scale ? ffn1_scale.get_ptr() : nullptr,
ffn1_bias ? ffn1_bias.get_ptr() : nullptr,
&ffn2_weight,
ffn2_scale ? ffn2_scale.get_ptr() : nullptr,
ffn2_bias ? ffn2_bias.get_ptr() : nullptr,
nullptr,  // moe_token_type_ids (not used here)
moe_topk,
group_moe,
norm_topk_prob,
1.0,  // routed_scaling_factor
"ffn",
output);
}
std::vector<paddle::Tensor> FusedExpertMoe(
const paddle::Tensor& input,
const paddle::Tensor& gate_weight,
const paddle::Tensor& ffn1_weight,
const paddle::Tensor& ffn2_weight,
const paddle::optional<paddle::Tensor>& ffn1_bias,
const paddle::optional<paddle::Tensor>& ffn1_scale,
const paddle::optional<paddle::Tensor>& ffn2_bias,
const paddle::optional<paddle::Tensor>& ffn2_scale,
const std::string& quant_method,
const int moe_topk,
const bool norm_topk_prob,
const bool group_moe) {
const auto input_type = input.dtype();
auto output = paddle::empty_like(input);
switch (input_type) {
case paddle::DataType::BFLOAT16:
FusedMoeKernel<paddle::DataType::BFLOAT16, maca_bfloat16, int8_t, maca_bfloat16>(input,
gate_weight,
ffn1_weight,
ffn1_scale,
ffn1_bias,
ffn2_weight,
ffn2_scale,
ffn2_bias,
quant_method,
moe_topk,
group_moe,
norm_topk_prob,
&output);
break;
// case paddle::DataType::FLOAT16:
// FusedMoeKernel<paddle::DataType::FLOAT16>(input,
// gate_weight,
// ffn1_weight,
// ffn1_scale,
// ffn1_bias,
// ffn2_weight,
// ffn2_scale,
// ffn2_bias,
// quant_method,
// moe_topk,
// group_moe,
// norm_topk_prob,
// &output);
// break;
default:
PD_THROW("Only support bf16 for FusedMoeKernel");
}
return {output};
}
std::vector<std::vector<int64_t>> FusedExpertMoeInferShape(
const std::vector<int64_t>& input_shape,
const std::vector<int64_t>& gate_weight_shape,
const std::vector<int64_t>& ffn1_weight_shape,
const std::vector<int64_t>& ffn2_weight_shape,
const paddle::optional<std::vector<int64_t>>& ffn1_bias_shape,
const paddle::optional<std::vector<int64_t>>& ffn1_scale_shape,
const paddle::optional<std::vector<int64_t>>& ffn2_bias_shape,
const paddle::optional<std::vector<int64_t>>& ffn2_scale_shape) {
return {input_shape};
}
std::vector<paddle::DataType> FusedExpertMoeInferDtype(
const paddle::DataType& input_dtype,
const paddle::DataType& gate_weight_dtype,
const paddle::DataType& ffn1_weight_dtype,
const paddle::DataType& ffn2_weight_dtype,
const paddle::optional<paddle::DataType>& ffn1_bias_dtype,
const paddle::optional<paddle::DataType>& ffn1_scale_dtype,
const paddle::optional<paddle::DataType>& ffn2_bias_dtype,
const paddle::optional<paddle::DataType>& ffn2_scale_dtype) {
return {input_dtype};
}
PD_BUILD_OP(fused_expert_moe)
.Inputs({"input",
"gate_weight",
"ffn1_weight",
"ffn2_weight",
paddle::Optional("ffn1_bias"),
paddle::Optional("ffn1_scale"),
paddle::Optional("ffn2_bias"),
paddle::Optional("ffn2_scale")})
.Outputs({"output"})
.Attrs({"quant_method:std::string",
"moe_topk:int",
"norm_topk_prob:bool",
"group_moe:bool"})
.SetKernelFn(PD_KERNEL(FusedExpertMoe))
.SetInferShapeFn(PD_INFER_SHAPE(FusedExpertMoeInferShape))
.SetInferDtypeFn(PD_INFER_DTYPE(FusedExpertMoeInferDtype));

View File

@@ -0,0 +1,53 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "cutlass_kernels/moe_gemm/fused_moe_gemm_kernels.h"
#include "fused_moe_op.h"
using namespace phi;
template <typename T, int VecSize>
__global__ void moe_token_type_ids_kernel(T *gating_output,
const int *moe_token_type_ids_out,
const int num_rows,
const int num_experts,
const int k) {
const int moe_token_index = blockIdx.x * blockDim.x + threadIdx.x;
if (moe_token_index >= num_rows) {
return;
}
gating_output[moe_token_index * 2] =
gating_output[moe_token_index * 2] +
(moe_token_type_ids_out[moe_token_index]) * -1e10;
gating_output[moe_token_index * 2 + 1] =
gating_output[moe_token_index * 2 + 1] +
(1 - moe_token_type_ids_out[moe_token_index]) * -1e10;
}
template <typename T>
void moe_token_type_ids_kernelLauncher(T *gating_output,
const int *moe_token_type_ids_out,
const int num_rows,
const int num_experts,
const int k,
cudaStream_t stream) {
const int threads = 512;
const int blocks = num_rows * k / threads + 1;
moe_token_type_ids_kernel<T, 1><<<blocks, threads, 0, stream>>>(
gating_output, moe_token_type_ids_out, num_rows, num_experts, k);
}
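
The kernel above biases a two-column gating matrix (note the hard-coded "* 2" stride), pushing each token's disallowed column to a large negative logit so it collapses to ~0 probability after softmax. A minimal host-side sketch of that masking, with assumed example logits:

#include <cmath>
#include <cstdio>
// Host-side sketch; the logits and token_type are assumed example values.
int main() {
  float gating[2] = {0.3f, 0.7f};
  int token_type = 0;                      // type 0 keeps column 0, masks column 1
  gating[0] += token_type * -1e10f;
  gating[1] += (1 - token_type) * -1e10f;
  float m = std::fmax(gating[0], gating[1]);
  float e0 = std::exp(gating[0] - m), e1 = std::exp(gating[1] - m);
  printf("p0 = %f, p1 = %f\n", e0 / (e0 + e1), e1 / (e0 + e1));  // p0 ~= 1, p1 ~= 0
  return 0;
}
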

View File

@@ -0,0 +1,123 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION &
* AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <string>
#include <sstream>
#include "cub/cub.cuh"
static const float HALF_FLT_MAX = 65504.F;
static const float HALF_FLT_MIN = -65504.F;
static inline size_t AlignTo16(const size_t& input) {
static constexpr int ALIGNMENT = 16;
return ALIGNMENT * ((input + ALIGNMENT - 1) / ALIGNMENT);
}
class CubKeyValueSorter {
public:
CubKeyValueSorter() : num_experts_(0), num_bits_(sizeof(int) * 8) {}
explicit CubKeyValueSorter(const int num_experts)
: num_experts_(num_experts),
num_bits_(static_cast<int>(log2(num_experts)) + 1) {}
void update_num_experts(const int num_experts) {
num_experts_ = num_experts;
num_bits_ = static_cast<int>(log2(num_experts)) + 1;
}
size_t getWorkspaceSize(const size_t num_key_value_pairs,
bool descending = false) {
num_key_value_pairs_ = num_key_value_pairs;
size_t required_storage = 0;
int* null_int = nullptr;
if (descending) {
cub::DeviceRadixSort::SortPairsDescending(NULL,
required_storage,
null_int,
null_int,
null_int,
null_int,
num_key_value_pairs,
0,
32);
} else {
cub::DeviceRadixSort::SortPairs(NULL,
required_storage,
null_int,
null_int,
null_int,
null_int,
num_key_value_pairs,
0,
num_bits_);
}
return required_storage;
}
template <typename KeyT>
void run(void* workspace,
const size_t workspace_size,
const KeyT* keys_in,
KeyT* keys_out,
const int* values_in,
int* values_out,
const size_t num_key_value_pairs,
bool descending,
cudaStream_t stream) {
size_t expected_ws_size = getWorkspaceSize(num_key_value_pairs);
size_t actual_ws_size = workspace_size;
if (expected_ws_size > workspace_size) {
std::stringstream err_ss;
err_ss << "[Error][CubKeyValueSorter::run]\n";
err_ss << "Error. The allocated workspace is too small to run this "
"problem.\n";
err_ss << "Expected workspace size of at least " << expected_ws_size
<< " but got problem size " << workspace_size << "\n";
throw std::runtime_error(err_ss.str());
}
if (descending) {
cub::DeviceRadixSort::SortPairsDescending(workspace,
actual_ws_size,
keys_in,
keys_out,
values_in,
values_out,
num_key_value_pairs,
0,
32,
stream);
} else {
cub::DeviceRadixSort::SortPairs(workspace,
actual_ws_size,
keys_in,
keys_out,
values_in,
values_out,
num_key_value_pairs,
0,
num_bits_,
stream);
}
}
private:
size_t num_key_value_pairs_;
int num_experts_;
int num_bits_;
};
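
A minimal usage sketch of the sorter, assuming this header and <cuda_runtime.h> are included; the device pointers are placeholders and error checking is omitted:

#include <cuda_runtime.h>
// Sorts (expert id, row id) pairs by expert id so each expert's rows become contiguous.
void sort_rows_by_expert(const int* d_expert_ids,      // keys (device)
                         int*       d_sorted_experts,  // sorted keys out (device)
                         const int* d_row_ids,         // values (device)
                         int*       d_sorted_rows,     // sorted values out (device)
                         size_t n, int num_experts, cudaStream_t stream) {
  CubKeyValueSorter sorter(num_experts);          // num_bits_ derived from expert count
  size_t ws_bytes = sorter.getWorkspaceSize(n);   // query cub temp-storage requirement
  void* d_ws = nullptr;
  cudaMallocAsync(&d_ws, ws_bytes, stream);
  sorter.run(d_ws, ws_bytes, d_expert_ids, d_sorted_experts,
             d_row_ids, d_sorted_rows, n, /*descending=*/false, stream);
  cudaFreeAsync(d_ws, stream);
}
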

View File

@@ -0,0 +1,990 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION &
* AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <cuda.h>
#include <cuda_fp16.h>
#include "fused_moe_imp_op.h"
#include "fused_moe_helper.h"
#include "mctlass/numeric_conversion.h" // BUILD_MARK
// Ignore mctlass warnings about type punning
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
#pragma GCC diagnostic ignored "-Wunused-function"
// #include "paddle/phi/backends/gpu/gpu_info.h"
#pragma GCC diagnostic pop
#include "helper.h"
#define WARP_SIZE 32
struct GpuLaunchConfig {
dim3 block_per_grid;
dim3 thread_per_block;
};
inline GpuLaunchConfig Get1DBlocksAnd2DGridsMoe(const int64_t cols) {
int blocks_x = cols;
int blocks_y = 1;
int blocks_z = 1;
if (blocks_x > 1024) {
blocks_y = 256;
blocks_x = (blocks_x + blocks_y - 1) / blocks_y;
}
GpuLaunchConfig config;
config.block_per_grid.x = blocks_x;
config.block_per_grid.y = blocks_y;
config.block_per_grid.z = blocks_z;
return config;
}
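
A quick host-side sketch (the row count is an assumed example) of the grid shape this helper picks; kernels launched with it recover the row index as blockIdx.x + blockIdx.y * gridDim.x and bounds-check against the true row count:

#include <cstdio>
int main() {
  const long long rows = 1000000;
  long long blocks_x = rows, blocks_y = 1;
  if (blocks_x > 1024) {               // fold large counts into a 2-D grid
    blocks_y = 256;
    blocks_x = (blocks_x + blocks_y - 1) / blocks_y;
  }
  printf("grid = (%lld, %lld), covers %lld >= %lld rows\n",
         blocks_x, blocks_y, blocks_x * blocks_y, rows);
  return 0;
}
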
// ====================== Softmax things ===============================
// We have our own implementation of softmax here so we can support transposing
// the output in the softmax kernel when we extend this module to support
// expert-choice routing.
template <typename T, int TPB>
__launch_bounds__(TPB) __global__
void group_moe_softmax(const T* input,
T* output,
T* softmax_max_prob,
const int64_t num_cols,
const int64_t softmax_num_rows) {
using BlockReduce = cub::BlockReduce<float, TPB>;
__shared__ typename BlockReduce::TempStorage tmpStorage;
__shared__ float normalizing_factor;
__shared__ float float_max;
__shared__ float max_out;
int globalIdx = blockIdx.x + blockIdx.y * gridDim.x;
if (globalIdx >= softmax_num_rows) {
return;
}
const int64_t thread_row_offset = globalIdx * num_cols;
cub::Sum sum;
float threadData(-FLT_MAX);
for (int ii = threadIdx.x; ii < num_cols; ii += TPB) {
const int idx = thread_row_offset + ii;
threadData = max(static_cast<float>(input[idx]), threadData);
}
const float maxElem = BlockReduce(tmpStorage).Reduce(threadData, cub::Max());
if (threadIdx.x == 0) {
float_max = maxElem;
}
__syncthreads();
threadData = 0;
for (int ii = threadIdx.x; ii < num_cols; ii += TPB) {
const int idx = thread_row_offset + ii;
threadData += exp((static_cast<float>(input[idx]) - float_max));
}
const auto Z = BlockReduce(tmpStorage).Reduce(threadData, sum);
if (threadIdx.x == 0) {
normalizing_factor = 1.f / Z;
}
__syncthreads();
threadData = 0;
for (int ii = threadIdx.x; ii < num_cols; ii += TPB) {
const int idx = thread_row_offset + ii;
const float val =
exp((static_cast<float>(input[idx]) - float_max)) * normalizing_factor;
output[idx] = T(val);
threadData = max(static_cast<float>(T(val)), threadData);
}
const float maxOut = BlockReduce(tmpStorage).Reduce(threadData, cub::Max());
if (threadIdx.x == 0) {
// group max probs
max_out = 1.f / maxOut;
softmax_max_prob[globalIdx] = T(max_out);
}
__syncthreads();
for (int ii = threadIdx.x; ii < num_cols; ii += TPB) {
const int idx = thread_row_offset + ii;
// group softmax normalization
output[idx] = output[idx] * static_cast<T>(max_out);
}
}
template <typename T, int TPB>
__launch_bounds__(TPB) __global__ void moe_top_k(const T* inputs_after_softmax,
T* output,
int* indices,
int* source_rows,
T* softmax_max_prob,
const int64_t num_experts,
const int64_t k,
const int64_t num_rows) {
using cub_kvp = cub::KeyValuePair<int, T>;
using BlockReduce = cub::BlockReduce<cub_kvp, TPB>;
__shared__ typename BlockReduce::TempStorage tmpStorage;
cub_kvp thread_kvp;
cub::ArgMax arg_max;
const int block_row = blockIdx.x + blockIdx.y * gridDim.x;
if (block_row >= num_rows) {
return;
}
const bool should_process_row = true;
const int thread_read_offset = block_row * num_experts;
for (int k_idx = 0; k_idx < k; ++k_idx) {
thread_kvp.key = 0;
thread_kvp.value = T(-1.f); // This is OK because inputs are probabilities
cub_kvp inp_kvp;
for (int expert = threadIdx.x; expert < num_experts; expert += TPB) {
const int idx = thread_read_offset + expert;
inp_kvp.key = expert;
inp_kvp.value = inputs_after_softmax[idx];
for (int prior_k = 0; prior_k < k_idx; ++prior_k) {
const int prior_winning_expert = indices[k * block_row + prior_k];
if (prior_winning_expert == expert) {
inp_kvp = thread_kvp;
}
}
thread_kvp = arg_max(inp_kvp, thread_kvp);
}
const cub_kvp result_kvp =
BlockReduce(tmpStorage).Reduce(thread_kvp, arg_max);
if (threadIdx.x == 0) {
const int idx = k * block_row + k_idx;
// restore normalized probs
output[idx] = result_kvp.value / T(softmax_max_prob[idx]);
indices[idx] = should_process_row ? result_kvp.key : num_experts;
source_rows[idx] = k_idx * num_rows + block_row;
}
__syncthreads();
}
}
template <typename T, int TPB>
__launch_bounds__(TPB) __global__ void moe_softmax(const T* input,
T* output,
const int64_t num_cols,
const int64_t num_rows) {
using BlockReduce = cub::BlockReduce<float, TPB>;
__shared__ typename BlockReduce::TempStorage tmpStorage;
__shared__ float normalizing_factor;
__shared__ float float_max;
int globalIdx = blockIdx.x + blockIdx.y * gridDim.x;
if (globalIdx >= num_rows) {
return;
}
const int64_t thread_row_offset = globalIdx * num_cols;
cub::Sum sum;
float threadData(-FLT_MAX);
for (int ii = threadIdx.x; ii < num_cols; ii += TPB) {
const int idx = thread_row_offset + ii;
threadData = max(static_cast<float>(input[idx]), threadData);
}
const float maxElem = BlockReduce(tmpStorage).Reduce(threadData, cub::Max());
if (threadIdx.x == 0) {
float_max = maxElem;
}
__syncthreads();
threadData = 0;
for (int ii = threadIdx.x; ii < num_cols; ii += TPB) {
const int idx = thread_row_offset + ii;
threadData += exp((static_cast<float>(input[idx]) - float_max));
}
const auto Z = BlockReduce(tmpStorage).Reduce(threadData, sum);
if (threadIdx.x == 0) {
normalizing_factor = 1.f / Z;
}
__syncthreads();
for (int ii = threadIdx.x; ii < num_cols; ii += TPB) {
const int idx = thread_row_offset + ii;
const float val =
exp((static_cast<float>(input[idx]) - float_max)) * normalizing_factor;
output[idx] = T(val);
}
}
template <typename T, int TPB>
__launch_bounds__(TPB) __global__ void moe_top_k(const T* inputs_after_softmax,
T* output,
int* indices,
int* source_rows,
const int64_t num_experts,
const int64_t k,
const int64_t num_rows) {
using cub_kvp = cub::KeyValuePair<int, T>;
using BlockReduce = cub::BlockReduce<cub_kvp, TPB>;
__shared__ typename BlockReduce::TempStorage tmpStorage;
cub_kvp thread_kvp;
cub::ArgMax arg_max;
const int block_row = blockIdx.x + blockIdx.y * gridDim.x;
if (block_row >= num_rows) {
return;
}
const bool should_process_row = true;
const int thread_read_offset = block_row * num_experts;
for (int k_idx = 0; k_idx < k; ++k_idx) {
thread_kvp.key = 0;
thread_kvp.value = T(-1.f); // This is OK because inputs are probabilities
cub_kvp inp_kvp;
for (int expert = threadIdx.x; expert < num_experts; expert += TPB) {
const int idx = thread_read_offset + expert;
inp_kvp.key = expert;
inp_kvp.value = inputs_after_softmax[idx];
for (int prior_k = 0; prior_k < k_idx; ++prior_k) {
const int prior_winning_expert = indices[k * block_row + prior_k];
if (prior_winning_expert == expert) {
inp_kvp = thread_kvp;
}
}
thread_kvp = arg_max(inp_kvp, thread_kvp);
}
const cub_kvp result_kvp =
BlockReduce(tmpStorage).Reduce(thread_kvp, arg_max);
if (threadIdx.x == 0) {
const int idx = k * block_row + k_idx;
output[idx] = result_kvp.value;
indices[idx] = should_process_row ? result_kvp.key : num_experts;
source_rows[idx] = k_idx * num_rows + block_row;
}
__syncthreads();
}
}
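
Both moe_top_k variants above run k rounds of a block-wide arg-max, excluding experts already chosen in earlier rounds. A host-side sketch of that selection loop with assumed example probabilities:

#include <cstdio>
int main() {
  const int num_experts = 4, k = 2;
  float probs[num_experts] = {0.1f, 0.5f, 0.3f, 0.1f};  // assumed softmax row
  int chosen[k];
  for (int k_idx = 0; k_idx < k; ++k_idx) {
    int best = -1;
    float best_val = -1.f;
    for (int e = 0; e < num_experts; ++e) {
      bool taken = false;
      for (int prior = 0; prior < k_idx; ++prior) taken |= (chosen[prior] == e);
      if (!taken && probs[e] > best_val) { best_val = probs[e]; best = e; }
    }
    chosen[k_idx] = best;
    printf("round %d: expert %d (p = %.2f)\n", k_idx, best, best_val);  // expert 1, then 2
  }
  return 0;
}
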
// ====================== TopK softmax things ===============================
/*
A Top-K gating softmax written to exploit the case where the number of experts
in the MoE layers is a small power of 2. This allows us to cleanly share the rows
among the threads in a single warp and eliminate communication between warps
(so no need to use shared mem).
It fuses the softmax, max and argmax into a single kernel.
Limitations:
1) This implementation is intended for when the number of experts is a small
power of 2. 2) This implementation assumes k is small, but will work for any
k.
*/
template <typename T,
int VPT,
int NUM_EXPERTS,
int WARPS_PER_CTA,
int BYTES_PER_LDG>
__launch_bounds__(WARPS_PER_CTA* WARP_SIZE) __global__
void topk_gating_softmax(const T* input,
T* output,
const int64_t num_rows,
int* indices,
int* source_rows,
const int64_t k) {
// We begin by enforcing compile time assertions and setting up compile time
// constants.
static_assert(VPT == (VPT & -VPT), "VPT must be power of 2");
static_assert(NUM_EXPERTS == (NUM_EXPERTS & -NUM_EXPERTS),
"NUM_EXPERTS must be power of 2");
static_assert(BYTES_PER_LDG == (BYTES_PER_LDG & -BYTES_PER_LDG),
"BYTES_PER_LDG must be power of 2");
static_assert(BYTES_PER_LDG <= 16, "BYTES_PER_LDG must be leq 16");
// Number of bytes each thread pulls in per load
static constexpr int ELTS_PER_LDG = BYTES_PER_LDG / sizeof(T);
static constexpr int ELTS_PER_ROW = NUM_EXPERTS;
static constexpr int THREADS_PER_ROW = ELTS_PER_ROW / VPT;
static constexpr int LDG_PER_THREAD = VPT / ELTS_PER_LDG;
// Restrictions based on previous section.
static_assert(
VPT % ELTS_PER_LDG == 0,
"The elements per thread must be a multiple of the elements per ldg");
static_assert(WARP_SIZE % THREADS_PER_ROW == 0,
"The threads per row must cleanly divide the threads per warp");
static_assert(THREADS_PER_ROW == (THREADS_PER_ROW & -THREADS_PER_ROW),
"THREADS_PER_ROW must be power of 2");
static_assert(THREADS_PER_ROW <= WARP_SIZE,
"THREADS_PER_ROW can be at most warp size");
// We have NUM_EXPERTS elements per row. We specialize for small #experts
static constexpr int ELTS_PER_WARP = WARP_SIZE * VPT;
static constexpr int ROWS_PER_WARP = ELTS_PER_WARP / ELTS_PER_ROW;
static constexpr int ROWS_PER_CTA = WARPS_PER_CTA * ROWS_PER_WARP;
// Restrictions for previous section.
static_assert(ELTS_PER_WARP % ELTS_PER_ROW == 0,
"The elts per row must cleanly divide the total elt per warp");
// ===================== From this point, we finally start computing run-time
// variables. ========================
// Compute CTA and warp rows. We pack multiple rows into a single warp, and a
// block contains WARPS_PER_CTA warps. Thus, each block processes a chunk of
// rows. We start by computing the start row for each block.
const int cta_base_row = blockIdx.x * ROWS_PER_CTA;
// Now, using the base row per thread block, we compute the base row per warp.
const int warp_base_row = cta_base_row + threadIdx.y * ROWS_PER_WARP;
// The threads in a warp are split into sub-groups that will work on a row.
// We compute row offset for each thread sub-group
const int thread_row_in_warp = threadIdx.x / THREADS_PER_ROW;
const int thread_row = warp_base_row + thread_row_in_warp;
// Threads with indices out of bounds should early exit here.
if (thread_row >= num_rows) return;
const bool should_process_row = true;
// We finally start setting up the read pointers for each thread. First, each
// thread jumps to the start of the row it will read.
const T* thread_row_ptr = input + thread_row * ELTS_PER_ROW;
// Now, we compute the group each thread belongs to in order to determine the
// first column to start loads.
const int thread_group_idx = threadIdx.x % THREADS_PER_ROW;
const int first_elt_read_by_thread = thread_group_idx * ELTS_PER_LDG;
const T* thread_read_ptr = thread_row_ptr + first_elt_read_by_thread;
// Determine the pointer type to use to read in the data depending on the
// BYTES_PER_LDG template param. In theory, this can support all powers of 2
// up to 16.
using AccessType = mctlass::AlignedArray<T, ELTS_PER_LDG>;
// Finally, we pull in the data from global mem
mctlass::Array<T, VPT> row_chunk_input;
AccessType* row_chunk_vec_ptr =
reinterpret_cast<AccessType*>(&row_chunk_input);
const AccessType* vec_thread_read_ptr =
reinterpret_cast<const AccessType*>(thread_read_ptr);
#pragma unroll
for (int ii = 0; ii < LDG_PER_THREAD; ++ii) {
row_chunk_vec_ptr[ii] = vec_thread_read_ptr[ii * THREADS_PER_ROW];
}
using ComputeType = float;
using Converter = mctlass::NumericArrayConverter<ComputeType, T, VPT>;
Converter compute_type_converter;
mctlass::Array<ComputeType, VPT> row_chunk =
compute_type_converter(row_chunk_input);
// First, we perform a max reduce within the thread. We can do the max in fp16
// safely (I think) and just convert to float afterwards for the exp + sum
// reduction.
ComputeType thread_max = row_chunk[0];
#pragma unroll
for (int ii = 1; ii < VPT; ++ii) {
thread_max = max(thread_max, row_chunk[ii]);
}
// Now, we find the max within the thread group and distribute among the
// threads. We use a butterfly reduce.
#pragma unroll
for (int mask = THREADS_PER_ROW / 2; mask > 0; mask /= 2) {
thread_max =
max(thread_max,
__shfl_xor_sync(0xFFFFFFFF, thread_max, mask, THREADS_PER_ROW));
}
// From this point, thread max in all the threads have the max within the row.
// Now, we subtract the max from each element in the thread and take the exp.
// We also compute the thread local sum.
float row_sum = 0;
#pragma unroll
for (int ii = 0; ii < VPT; ++ii) {
row_chunk[ii] = expf(row_chunk[ii] - thread_max);
row_sum += row_chunk[ii];
}
// Now, we perform the sum reduce within each thread group. Similar to the max
// reduce, we use a butterfly pattern.
#pragma unroll
for (int mask = THREADS_PER_ROW / 2; mask > 0; mask /= 2) {
row_sum += __shfl_xor_sync(0xFFFFFFFF, row_sum, mask, THREADS_PER_ROW);
}
// From this point, all threads have the max and the sum for their rows in the
// thread_max and thread_sum variables respectively. Finally, we can scale the
// rows for the softmax. Technically, for top-k gating we don't need to
// compute the entire softmax row. We can likely look at the maxes and only
// compute for the top-k values in the row. However, this kernel will likely
// not be a bottleneck and it seems better to more closely match torch and find the
// argmax after computing the softmax.
const float reciprocal_row_sum = 1.f / row_sum;
#pragma unroll
for (int ii = 0; ii < VPT; ++ii) {
row_chunk[ii] = row_chunk[ii] * reciprocal_row_sum;
}
// Now, row_chunk contains the softmax of the row chunk. Next, we find the
// top-k elements in each row, along with their expert indices.
int start_col = first_elt_read_by_thread;
static constexpr int COLS_PER_GROUP_LDG = ELTS_PER_LDG * THREADS_PER_ROW;
for (int k_idx = 0; k_idx < k; ++k_idx) {
// First, each thread does the local argmax
float max_val = row_chunk[0];
int expert = start_col;
#pragma unroll
for (int ldg = 0, col = start_col; ldg < LDG_PER_THREAD;
++ldg, col += COLS_PER_GROUP_LDG) {
#pragma unroll
for (int ii = 0; ii < ELTS_PER_LDG; ++ii) {
float val = row_chunk[ldg * ELTS_PER_LDG + ii];
// No check on the experts here since columns with the smallest index
// are processed first and only updated if > (not >=)
if (val > max_val) {
max_val = val;
expert = col + ii;
}
}
}
// Now, we perform the argmax reduce. We use the butterfly pattern so threads
// reach consensus about the max. This will be useful for K > 1 so that the
// threads can agree on "who" had the max value. That thread can then blank out
// their max with -inf and the warp can run more iterations...
#pragma unroll
for (int mask = THREADS_PER_ROW / 2; mask > 0; mask /= 2) {
float other_max =
__shfl_xor_sync(0xFFFFFFFF, max_val, mask, THREADS_PER_ROW);
int other_expert =
__shfl_xor_sync(0xFFFFFFFF, expert, mask, THREADS_PER_ROW);
// We want lower indices to "win" in every thread so we break ties this
// way
if (other_max > max_val ||
(other_max == max_val && other_expert < expert)) {
max_val = other_max;
expert = other_expert;
}
}
// Write the max for this k iteration to global memory.
if (thread_group_idx == 0) {
// The lead thread from each sub-group will write out the final results to
// global memory. This will be a single thread per row of the
// input/output matrices.
const int idx = k * thread_row + k_idx;
output[idx] = T(max_val);
indices[idx] = should_process_row ? expert : NUM_EXPERTS;
source_rows[idx] = k_idx * num_rows + thread_row;
}
// Finally, we clear the value in the thread with the current max if there
// is another iteration to run.
if (k_idx + 1 < k) {
const int ldg_group_for_expert = expert / COLS_PER_GROUP_LDG;
const int thread_to_clear_in_group =
(expert / ELTS_PER_LDG) % THREADS_PER_ROW;
// Only the thread in the group which produced the max will reset the
// "winning" value to -inf.
if (thread_group_idx == thread_to_clear_in_group) {
const int offset_for_expert = expert % ELTS_PER_LDG;
// Safe to set to any negative value since row_chunk values must be
// between 0 and 1.
row_chunk[ldg_group_for_expert * ELTS_PER_LDG + offset_for_expert] =
ComputeType(-10000.f);
}
}
}
}
namespace detail {
// Constructs some constants needed to partition the work across threads at
// compile time.
template <typename T, int EXPERTS, int BYTES_PER_LDG>
struct TopkConstants {
static constexpr int ELTS_PER_LDG = BYTES_PER_LDG / sizeof(T);
static_assert(EXPERTS / (ELTS_PER_LDG * WARP_SIZE) == 0 ||
EXPERTS % (ELTS_PER_LDG * WARP_SIZE) == 0,
"");
static constexpr int VECs_PER_THREAD =
std::max(1, EXPERTS / (ELTS_PER_LDG * WARP_SIZE));
static constexpr int VPT = VECs_PER_THREAD * ELTS_PER_LDG;
static constexpr int THREADS_PER_ROW = EXPERTS / VPT;
static constexpr int ROWS_PER_WARP = WARP_SIZE / THREADS_PER_ROW;
};
} // namespace detail
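
To make the compile-time partitioning concrete, a host-side sketch that repeats the TopkConstants arithmetic for an assumed case: half-precision gating logits, 64 experts, and the 16-byte maximum load width.

#include <algorithm>
#include <cstdio>
int main() {
  const int bytes_per_ldg = 16, elem_bytes = 2 /* half */, experts = 64, warp_size = 32;
  const int elts_per_ldg = bytes_per_ldg / elem_bytes;                       // 8
  const int vecs_per_thread = std::max(1, experts / (elts_per_ldg * warp_size));
  const int vpt = vecs_per_thread * elts_per_ldg;                            // 8 experts per thread
  const int threads_per_row = experts / vpt;                                 // 8 threads per row
  const int rows_per_warp = warp_size / threads_per_row;                     // 4 rows per warp
  printf("VPT=%d THREADS_PER_ROW=%d ROWS_PER_WARP=%d\n",
         vpt, threads_per_row, rows_per_warp);
  return 0;
}
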
template <typename T, int EXPERTS, int WARPS_PER_TB>
void topk_gating_softmax_launcher_helper(const T* input,
T* output,
int* indices,
int* source_row,
const int64_t num_rows,
const int64_t num_experts,
const int64_t k,
cudaStream_t stream) {
static constexpr uint64_t MAX_BYTES_PER_LDG = 16;
static constexpr int BYTES_PER_LDG =
std::min(MAX_BYTES_PER_LDG, sizeof(T) * EXPERTS);
using Constants = detail::TopkConstants<T, EXPERTS, BYTES_PER_LDG>;
static constexpr int VPT = Constants::VPT;
static constexpr int ROWS_PER_WARP = Constants::ROWS_PER_WARP;
const int num_warps = (num_rows + ROWS_PER_WARP - 1) / ROWS_PER_WARP;
const int num_blocks = (num_warps + WARPS_PER_TB - 1) / WARPS_PER_TB;
dim3 block_dim(WARP_SIZE, WARPS_PER_TB);
topk_gating_softmax<T, VPT, EXPERTS, WARPS_PER_TB, BYTES_PER_LDG>
<<<num_blocks, block_dim, 0, stream>>>(
input, output, num_rows, indices, source_row, k);
}
template <typename T>
void topk_gating_softmax_kernelLauncher(const T* input,
T* output,
T* softmax,
int* indices,
int* source_row,
T* softmax_max_prob,
const int64_t num_rows,
const int64_t num_experts,
const int64_t k,
const bool group_moe,
cudaStream_t stream,
const bool topk_only_mode = false) {
if (topk_only_mode) {
static constexpr int TPB = 256;
const auto config_topk = Get1DBlocksAnd2DGridsMoe(num_rows);
moe_top_k<T, TPB><<<config_topk.block_per_grid, TPB, 0, stream>>>(
input, output, indices, source_row, num_experts, k, num_rows);
return;
}
static constexpr int WARPS_PER_TB = 4;
#define LAUNCH_TOPK_GATING_SOFTMAX_HELPER(N) \
case N: { \
topk_gating_softmax_launcher_helper<T, N, WARPS_PER_TB>( \
input, output, indices, source_row, num_rows, num_experts, k, stream); \
break; \
}
switch (num_experts) {
LAUNCH_TOPK_GATING_SOFTMAX_HELPER(2)
LAUNCH_TOPK_GATING_SOFTMAX_HELPER(4)
LAUNCH_TOPK_GATING_SOFTMAX_HELPER(8)
LAUNCH_TOPK_GATING_SOFTMAX_HELPER(16)
LAUNCH_TOPK_GATING_SOFTMAX_HELPER(32)
LAUNCH_TOPK_GATING_SOFTMAX_HELPER(64)
LAUNCH_TOPK_GATING_SOFTMAX_HELPER(128)
LAUNCH_TOPK_GATING_SOFTMAX_HELPER(256)
default: {
static constexpr int TPB = 256;
if (group_moe) {
const int group_experts = num_experts / k;
const int softmax_num_rows = num_rows * k;
const auto config_softmax = Get1DBlocksAnd2DGridsMoe(softmax_num_rows);
group_moe_softmax<T, TPB>
<<<config_softmax.block_per_grid, TPB, 0, stream>>>(
input,
softmax,
softmax_max_prob,
group_experts,
softmax_num_rows);
const auto config_topk = Get1DBlocksAnd2DGridsMoe(num_rows);
moe_top_k<T, TPB>
<<<config_topk.block_per_grid, TPB, 0, stream>>>(softmax,
output,
indices,
source_row,
softmax_max_prob,
num_experts,
k,
num_rows);
} else {
const auto config_topk = Get1DBlocksAnd2DGridsMoe(num_rows);
moe_softmax<T, TPB><<<config_topk.block_per_grid, TPB, 0, stream>>>(
input, softmax, num_experts, num_rows);
moe_top_k<T, TPB>
<<<config_topk.block_per_grid, TPB, 0, stream>>>(softmax,
output,
indices,
source_row,
num_experts,
k,
num_rows);
}
}
}
}
// ========================== Permutation things
// =======================================
// Duplicated and permutes rows for MoE. In addition, reverse the permutation
// map to help with finalizing routing.
// "expanded_x_row" simply means that the number of values is num_rows x k. It
// is "expanded" since we will have to duplicate some rows in the input matrix
// to match the dimensions. Duplicates will always get routed to separate
// experts in the end.
// Note that the expanded_dest_row_to_expanded_source_row map referred to here
// has indices in the range (0, k*rows_in_input - 1). However, it is set up so
// that index 0, rows_in_input, 2*rows_in_input ... (k-1)*rows_in_input all map
// to row 0 in the original matrix. Thus, to know where to read in the source
// matrix, we simply take the modulus of the expanded index.
template <typename T, int VecSize>
__global__ void initialize_moe_routing_kernel(
const T* unpermuted_input,
T* permuted_output,
const int* expanded_dest_row_to_expanded_source_row,
int* expanded_source_row_to_expanded_dest_row,
const int64_t num_rows,
const int64_t active_rows,
const int64_t cols,
const int64_t num_rows_k) {
using LoadT = AlignedVector<T, VecSize>;
LoadT src_vec;
// Reverse permutation map.
// I do this so that later, we can use the source -> dest map to do the k-way
// reduction and unpermuting. I need the reverse map for that reduction to
// allow each threadblock to do 1 k-way reduce without atomics later in MoE. 1
// thread block will be responsible for all k summations.
const int expanded_dest_row = blockIdx.x + blockIdx.y * gridDim.x;
if (expanded_dest_row >= num_rows_k) return;
const int expanded_source_row =
expanded_dest_row_to_expanded_source_row[expanded_dest_row];
if (threadIdx.x == 0) {
expanded_source_row_to_expanded_dest_row[expanded_source_row] =
expanded_dest_row;
}
if ((blockIdx.x + blockIdx.y * gridDim.x) < active_rows) {
// Duplicate and permute rows
const int source_row = expanded_source_row % num_rows;
const T* source_row_ptr = unpermuted_input + source_row * cols;
T* dest_row_ptr = permuted_output + expanded_dest_row * cols;
for (int tid = threadIdx.x * VecSize; tid < cols;
tid += blockDim.x * VecSize) {
// dest_row_ptr[tid] = source_row_ptr[tid];
Load<T, VecSize>(&source_row_ptr[tid], &src_vec);
Store<T, VecSize>(src_vec, &dest_row_ptr[tid]);
}
}
}
template <typename T>
void initialize_moe_routing_kernelLauncher(
const T* unpermuted_input,
T* permuted_output,
const int* expanded_dest_row_to_expanded_source_row,
int* expanded_source_row_to_expanded_dest_row,
const int64_t num_rows,
const int64_t active_rows,
const int64_t cols,
const int64_t k,
cudaStream_t stream) {
const int threads = std::min(cols, int64_t(1024));
constexpr int max_pack_size = 16 / sizeof(T);
const auto config_initialize = Get1DBlocksAnd2DGridsMoe(num_rows * k);
if (cols % max_pack_size == 0) {
initialize_moe_routing_kernel<T, max_pack_size>
<<<config_initialize.block_per_grid, threads, 0, stream>>>(
unpermuted_input,
permuted_output,
expanded_dest_row_to_expanded_source_row,
expanded_source_row_to_expanded_dest_row,
num_rows,
k * active_rows,
cols,
num_rows * k);
} else {
initialize_moe_routing_kernel<T, 1>
<<<config_initialize.block_per_grid, threads, 0, stream>>>(
unpermuted_input,
permuted_output,
expanded_dest_row_to_expanded_source_row,
expanded_source_row_to_expanded_dest_row,
num_rows,
k * active_rows,
cols,
num_rows * k);
}
}
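
A host-side sketch of the bookkeeping this launcher drives, with an assumed dest-to-source permutation for num_rows = 3 and k = 2: the reverse (source-to-dest) map is recorded for the later k-way reduction, and the row actually copied is the expanded source index modulo num_rows.

#include <cstdio>
#include <vector>
int main() {
  const int num_rows = 3, k = 2;
  // dest -> expanded source, e.g. as produced by sorting expanded rows by expert id.
  std::vector<int> dest_to_source = {3, 0, 4, 1, 5, 2};
  std::vector<int> source_to_dest(num_rows * k);
  for (int dest = 0; dest < num_rows * k; ++dest) {
    int src = dest_to_source[dest];
    source_to_dest[src] = dest;         // reverse map used when un-permuting
    int original_row = src % num_rows;  // input row that gets duplicated here
    printf("dest %d <- expanded src %d (original row %d)\n", dest, src, original_row);
  }
  return 0;
}
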
// ============================== Infer GEMM sizes
// =================================
__device__ inline int find_total_elts_leq_target(int* sorted_indices,
const int64_t arr_length,
const int64_t target) {
int64_t low = 0, high = arr_length - 1, target_location = -1;
while (low <= high) {
int64_t mid = (low + high) / 2;
if (sorted_indices[mid] > target) {
high = mid - 1;
} else {
low = mid + 1;
target_location = mid;
}
}
return target_location + 1;
}
void compute_total_rows_before_expert(int* sorted_indices,
const int64_t total_indices,
const int64_t num_experts,
int32_t* total_rows_before_expert,
cudaStream_t stream);
// Final kernel to unpermute and scale
// This kernel unpermutes the original data, does the k-way reduction and
// performs the final skip connection.
template <typename T, int RESIDUAL_NUM>
__global__ void finalize_moe_routing_kernel(
const T* expanded_permuted_rows,
T* reduced_unpermuted_output,
const T* bias,
const float* scales,
const int* expanded_source_row_to_expanded_dest_row,
const int* expert_for_source_row,
const int64_t cols,
const int64_t k,
const int64_t compute_bias,
const bool norm_topk_prob,
const float routed_scaling_factor,
const int64_t num_rows) {
const int original_row = blockIdx.x + blockIdx.y * gridDim.x;
// const int original_row = blockIdx.x;
// const int num_rows = gridDim.x;
if (original_row >= num_rows) return;
T* reduced_row_ptr = reduced_unpermuted_output + original_row * cols;
for (int tid = threadIdx.x; tid < cols; tid += blockDim.x) {
T thread_output{0.f};
float row_rescale{0.f};
for (int k_idx = 0; k_idx < k; ++k_idx) {
const int expanded_original_row = original_row + k_idx * num_rows;
const int expanded_permuted_row =
expanded_source_row_to_expanded_dest_row[expanded_original_row];
const int64_t k_offset = original_row * k + k_idx;
const float row_scale = scales[k_offset];
row_rescale = row_rescale + row_scale;
const T* expanded_permuted_rows_row_ptr =
expanded_permuted_rows + expanded_permuted_row * cols;
const int expert_idx = expert_for_source_row[k_offset];
const T* bias_ptr = bias ? bias + expert_idx * cols : nullptr;
const T bias_value = bias_ptr ? bias_ptr[tid] : T{0.f};
thread_output =
static_cast<float>(thread_output) +
row_scale * static_cast<float>(
expanded_permuted_rows_row_ptr[tid] +
bias_value *
static_cast<T>(static_cast<float>(compute_bias)));
}
thread_output = static_cast<float>(thread_output) /
(norm_topk_prob ? row_rescale : 1.0f) *
routed_scaling_factor;
reduced_row_ptr[tid] = thread_output;
}
}
template <typename T>
void finalize_moe_routing_kernelLauncher(
const T* expanded_permuted_rows,
T* reduced_unpermuted_output,
const T* bias,
const float* scales,
const int* expanded_source_row_to_expanded_dest_row,
const int* expert_for_source_row,
const int64_t num_rows,
const int64_t cols,
const int64_t k,
const int64_t compute_bias,
const bool norm_topk_prob,
const float routed_scaling_factor,
cudaStream_t stream) {
const int threads = std::min(cols, int64_t(1024));
const auto config_final = Get1DBlocksAnd2DGridsMoe(num_rows);
finalize_moe_routing_kernel<T, 1>
<<<config_final.block_per_grid, threads, 0, stream>>>(
expanded_permuted_rows,
reduced_unpermuted_output,
bias,
scales,
expanded_source_row_to_expanded_dest_row,
expert_for_source_row,
cols,
k,
compute_bias,
norm_topk_prob,
routed_scaling_factor,
num_rows);
}
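
A host-side sketch of the per-element reduction in finalize_moe_routing_kernel, using assumed values for one column of one row (k = 2, compute_bias = 1, norm_topk_prob = true):

#include <cstdio>
int main() {
  const int k = 2;
  float expert_out[k] = {1.0f, 3.0f};   // gathered expanded-row values at this column
  float bias[k]       = {0.1f, 0.2f};   // per-expert bias at this column
  float scales[k]     = {0.6f, 0.3f};   // top-k gating weights for this row
  const float routed_scaling_factor = 1.0f;
  float acc = 0.f, scale_sum = 0.f;
  for (int i = 0; i < k; ++i) {
    acc += scales[i] * (expert_out[i] + bias[i]);
    scale_sum += scales[i];
  }
  float out = acc / scale_sum * routed_scaling_factor;  // norm_topk_prob = true
  printf("reduced output = %f\n", out);                 // 1.62 / 0.9 = 1.8
  return 0;
}
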
// ========================= TopK Softmax specializations
// ===========================
template void topk_gating_softmax_kernelLauncher(const float*,
float*,
float*,
int*,
int*,
float*,
const int64_t,
const int64_t,
const int64_t,
const bool,
cudaStream_t,
const bool);
template void topk_gating_softmax_kernelLauncher(const half*,
half*,
half*,
int*,
int*,
half*,
const int64_t,
const int64_t,
const int64_t,
const bool,
cudaStream_t,
const bool);
#ifdef PADDLE_CUDA_BF16
template void topk_gating_softmax_kernelLauncher(const __nv_bfloat16*,
__nv_bfloat16*,
__nv_bfloat16*,
int*,
int*,
__nv_bfloat16*,
const int64_t,
const int64_t,
const int64_t,
const bool,
cudaStream_t,
const bool);
#endif
// ===================== Specializations for init routing
// =========================
template void initialize_moe_routing_kernelLauncher(const float*,
float*,
const int*,
int*,
const int64_t,
const int64_t,
const int64_t,
const int64_t,
cudaStream_t);
template void initialize_moe_routing_kernelLauncher(const half*,
half*,
const int*,
int*,
const int64_t,
const int64_t,
const int64_t,
const int64_t,
cudaStream_t);
#ifdef PADDLE_CUDA_BF16
template void initialize_moe_routing_kernelLauncher(const __nv_bfloat16*,
__nv_bfloat16*,
const int*,
int*,
const int64_t,
const int64_t,
const int64_t,
const int64_t,
cudaStream_t);
#endif
// ==================== Specializations for final routing
// ===================================
template void finalize_moe_routing_kernelLauncher(const float*,
float*,
const float*,
const float*,
const int*,
const int*,
const int64_t,
const int64_t,
const int64_t,
const int64_t,
const bool,
const float,
cudaStream_t);
template void finalize_moe_routing_kernelLauncher(const half*,
half*,
const half*,
const float*,
const int*,
const int*,
const int64_t,
const int64_t,
const int64_t,
const int64_t,
const bool,
const float,
cudaStream_t);
#ifdef PADDLE_CUDA_BF16
template void finalize_moe_routing_kernelLauncher(const __nv_bfloat16*,
__nv_bfloat16*,
const __nv_bfloat16*,
const float*,
const int*,
const int*,
const int64_t,
const int64_t,
const int64_t,
const int64_t,
const bool,
const float,
cudaStream_t);
#endif

View File

@@ -0,0 +1,417 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mctlass/numeric_conversion.h"
#include "mctlassEx/mctlassEx.h"
#include "fused_moe_helper.h"
template <typename ElementA, typename ElementB, typename ElementC>
void mc_grouped_gemm_basic_kernel(
const ElementA* ptrA,
mctlassExOrder_t majorA,
const ElementB* ptrB,
mctlassExOrder_t majorB,
const ElementA* ptrScale,
const ElementA* ptrBias,
ElementC* ptrC,
mctlassExOrder_t majorC,
const int *ptrSegInd,
int numExperts,
int m, // expanded_active_expert_rows
int n, // inter_dim
int k, // hidden_size
mcStream_t stream) {
mctlassExHandle_t handle;
mctlassExHandleCreate(&handle);
int* ptrMNumTilesInd;
mcMallocAsync((void**)&ptrMNumTilesInd, sizeof(int) * numExperts, stream);
mctlassExMatrixLayout_t matLayoutA;
mctlassExMatrixLayout_t matLayoutB;
mctlassExMatrixLayout_t matLayoutC;
// mat A: (m, k)
mctlassExMatrixLayoutCreate(&matLayoutA, mctlassExDataType::MCTLASS_EX_BF16, m, k, k);
mctlassExMatrixLayoutSetAttribute(matLayoutA, mctlassExMatrixLayoutAttribute_t::MCTLASS_EX_MATRIX_LAYOUT_ORDER,
&majorA, sizeof(mctlassExOrder_t));
// mat B: (num_experts, n, k)
mctlassExMatrixLayoutCreate(&matLayoutB, mctlassExDataType::MCTLASS_EX_INT8, k, n, k);
mctlassExMatrixLayoutSetAttribute(matLayoutB, mctlassExMatrixLayoutAttribute_t::MCTLASS_EX_MATRIX_LAYOUT_ORDER,
&majorB, sizeof(mctlassExOrder_t));
mctlassExMatrixLayoutSetAttribute(matLayoutB, mctlassExMatrixLayoutAttribute_t::MCTLASS_EX_MATRIX_LAYOUT_BATCH_COUNT,
&numExperts, sizeof(int));
// mat C: (m, n)
mctlassExMatrixLayoutCreate(&matLayoutC, mctlassExDataType::MCTLASS_EX_BF16, m, n, n);
mctlassExMatrixLayoutSetAttribute(matLayoutC, mctlassExMatrixLayoutAttribute_t::MCTLASS_EX_MATRIX_LAYOUT_ORDER,
&majorC, sizeof(mctlassExOrder_t));
// bias: (num_experts, n)
// scale: (num_experts, n)
mctlassExDesc_t mctlass_desc;
mctlassExCreateDesc(&mctlass_desc);
mctlassExDataType input_type = mctlassExDataType::MCTLASS_EX_BF16;
mctlassExDataType scale_type = mctlassExDataType::MCTLASS_EX_INT8;
mctlassExDataType compute_type = mctlassExDataType::MCTLASS_EX_FP32;
mctlassExEpilogueType epilogue_type = mctlassExEpilogueType::MCTLASS_EX_GEMM_DEFAULT;
if (ptrBias) {
epilogue_type = mctlassExEpilogueType::MCTLASS_EX_GEMM_BIAS_PERGROUP;
}
// set scale
mctlassExDescSetAttribute(mctlass_desc, mctlassExDescAttributes_t::MCTLASS_EX_GEMM_DESC_B_SCALE_POINTER,
&ptrScale, sizeof(ptrScale));
mctlassExDescSetAttribute(mctlass_desc, mctlassExDescAttributes_t::MCTLASS_EX_GEMM_DESC_B_SCALE_TYPE,
&scale_type, sizeof(mctlassExDataType));
// set bias
if (ptrBias) {
mctlassExDescSetAttribute(mctlass_desc, mctlassExDescAttributes_t::MCTLASS_EX_GEMM_DESC_BIAS_POINTER,
&ptrBias, sizeof(ptrBias));
}
// set compute type
mctlassExDescSetAttribute(mctlass_desc, mctlassExDescAttributes_t::MCTLASS_EX_GEMM_DESC_COMPUTE_TYPE,
&compute_type, sizeof(mctlassExDataType));
// set epilogue type
mctlassExDescSetAttribute(mctlass_desc, mctlassExDescAttributes_t::MCTLASS_EX_GEMM_DESC_EPILOGUE_TYPE,
&epilogue_type, sizeof(mctlassExEpilogueType));
const mctlassExContiguousGroupedGemmAlgo_t algo = mctlassExContiguousGroupedGemmAlgo_t::MCTLASS_EX_CONTIGUOUS_GROUPED_ALGO_SEGPTR;
int blocksizeM = mctlassExContiguousGroupedGemmGetBlocksizeM(handle, mctlass_desc, matLayoutA, matLayoutB, matLayoutC, &algo);
mctlassExContiguousGroupedGemmComputeMNumTilesIndptr(handle, mctlass_desc, matLayoutA, matLayoutB, matLayoutC, &algo, ptrSegInd, ptrMNumTilesInd, numExperts, blocksizeM);
mctlassExContiguousGroupedGemmBasic(handle, mctlass_desc,
ptrA, matLayoutA,
ptrB, matLayoutB,
ptrC, matLayoutC,
ptrSegInd, nullptr, ptrMNumTilesInd,
&algo, nullptr, 0, stream);
mctlassExHandleDestroy(handle);
mctlassExMatrixLayoutDestroy(matLayoutA);
mctlassExMatrixLayoutDestroy(matLayoutB);
mctlassExMatrixLayoutDestroy(matLayoutC);
mctlassExDestroyDesc(mctlass_desc);
mcFreeAsync(ptrMNumTilesInd, stream);
}
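
The "contiguous grouped" layout consumed above keeps the rows of A grouped by expert, and ptrSegInd carries the per-expert end offsets produced by compute_total_rows_before_expert. A host-side sketch with assumed offsets, showing which row slab of A each expert's weight slice multiplies:

#include <cstdio>
int main() {
  const int num_experts = 4;
  int seg[num_experts] = {2, 5, 5, 6};  // assumed end offsets; expert 2 has no rows
  int prev = 0;
  for (int e = 0; e < num_experts; ++e) {
    printf("expert %d multiplies A rows [%d, %d)\n", e, prev, seg[e]);
    prev = seg[e];
  }
  return 0;
}
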
template<typename T, typename ElementA, typename ElementB, typename ElementC>
class McMoeHelper {
public:
McMoeHelper(const std::string gemm_method): gemm_method_(gemm_method) {}
// -------- getWorkspaceSize -------- //
template <typename KeyT>
size_t getWorkspaceSize(const int64_t num_rows,
const int64_t hidden_size,
const int64_t inter_size,
const int64_t num_experts,
const int64_t k) {
const size_t buf_size = AlignTo16(k * num_rows * hidden_size);
const size_t interbuf_size = AlignTo16(k * num_rows * inter_size);
const size_t padded_experts = AlignTo16(num_experts);
const size_t num_moe_inputs = AlignTo16(k * num_rows);
// softmax output, permuted_rows and permuted_experts have moved to outside
// of moe kernel, allocate them in Encoder or Decoder before invoking
// FfnLayer forward.
size_t total_ws_bytes =
5 * num_moe_inputs *
sizeof(int); // expert_for_source_row, source_rows_, permuted_rows_,
             // permuted_experts_, expanded_source_row_to_expanded_dest_row
total_ws_bytes += buf_size * sizeof(KeyT); // permuted_data
total_ws_bytes +=
padded_experts * sizeof(int32_t); // Hold total_rows_before_expert_
const size_t bytes_for_fc1_result = interbuf_size * sizeof(KeyT);
const size_t sorter_ws_size_bytes =
AlignTo16(sorter_.getWorkspaceSize(num_rows));
sorter_.update_num_experts(num_experts);
int64_t bytes_for_intermediate_and_sorting = bytes_for_fc1_result;
if (sorter_ws_size_bytes > bytes_for_fc1_result) {
int64_t remaining_bytes =
AlignTo16(sorter_ws_size_bytes - bytes_for_fc1_result);
bytes_for_intermediate_and_sorting += remaining_bytes;
}
total_ws_bytes +=
bytes_for_intermediate_and_sorting; // intermediate (fc1) output + cub
// sorting workspace
int64_t num_softmax_outs = 0;
const bool is_pow_2 =
(num_experts != 0) && ((num_experts & (num_experts - 1)) == 0);
if (!is_pow_2 || num_experts > 256) {
num_softmax_outs = AlignTo16(num_rows * num_experts);
}
total_ws_bytes += num_softmax_outs * sizeof(float);
return total_ws_bytes;
}
void computeFFN(const paddle::Tensor *input,
const paddle::Tensor *gate_weight,
const paddle::Tensor *ffn1_weight,
const paddle::Tensor *ffn1_scale,
const paddle::Tensor *ffn1_bias,
const paddle::Tensor *ffn2_weight,
const paddle::Tensor *ffn2_scale,
const paddle::Tensor *ffn2_bias,
const paddle::Tensor *moe_token_type_ids,
const int moe_topk,
const bool group_moe,
const bool norm_topk_prob,
const float routed_scaling_factor,
const std::string moe_type,
paddle::Tensor *output) {
auto *input_activations = input->data<T>();
auto *gating_weights = gate_weight->data<float>();
const T *fc1_expert_biases = ffn1_bias ? ffn1_bias->data<T>() : nullptr;
const T *fc2_expert_biases = ffn2_bias ? ffn2_bias->data<T>() : nullptr;
auto *output_ = output->data<T>();
auto stream = input->stream();
auto place = input->place();
auto input_type = input->dtype();
auto input_dims = input->dims();
auto ffn1_dims = ffn1_weight->dims();
int64_t token_num = 0;
if (input_dims.size() == 3) {
token_num = input_dims[0] * input_dims[1];
} else {
token_num = input_dims[0];
}
const int64_t num_rows = token_num;
const int64_t hidden_size = ffn1_dims[2];
int64_t inter_dim = 0;
if (moe_type == "qkv") {
inter_dim = ffn1_dims[2] * ffn1_dims[3] * ffn1_dims[4];
} else {
inter_dim = ffn1_dims[1];
}
// if (gemm_method == "weight_only_int4") {
// inter_dim = inter_dim * 2;
// }
const int64_t inter_size = inter_dim;
const int64_t num_experts = ffn1_dims[0];
const int64_t k = moe_topk;
int64_t bytes =
getWorkspaceSize<T>(num_rows, hidden_size, inter_size, num_experts, k);
// Pointers
int *expert_for_source_row;
int *source_rows_;
int *permuted_rows_;
int *permuted_experts_;
int *expanded_source_row_to_expanded_dest_row;
T *permuted_data_;
int32_t *total_rows_before_expert_;
T *fc1_result_;
float *softmax_out_;
paddle::Tensor ws_ptr_tensor =
GetEmptyTensor({bytes}, paddle::DataType::INT8, place);
int8_t *ws_ptr = ws_ptr_tensor.data<int8_t>();
const int64_t buf_size = AlignTo16(k * num_rows * hidden_size);
const int64_t interbuf_size = AlignTo16(k * num_rows * inter_size);
const int64_t padded_experts = AlignTo16(num_experts);
const int64_t num_moe_inputs = AlignTo16(k * num_rows);
expert_for_source_row = reinterpret_cast<int *>(ws_ptr);
source_rows_ = expert_for_source_row + num_moe_inputs;
permuted_rows_ = source_rows_ + num_moe_inputs;
permuted_experts_ = permuted_rows_ + num_moe_inputs;
expanded_source_row_to_expanded_dest_row =
permuted_experts_ + num_moe_inputs;
permuted_data_ = reinterpret_cast<T *>(
expanded_source_row_to_expanded_dest_row + num_moe_inputs);
total_rows_before_expert_ =
reinterpret_cast<int32_t *>(permuted_data_ + buf_size);
fc1_result_ =
reinterpret_cast<T *>(total_rows_before_expert_ + padded_experts);
const bool is_pow_2 =
(num_experts != 0) && ((num_experts & (num_experts - 1)) == 0);
if (!is_pow_2 || num_experts > 256) {
softmax_out_ = reinterpret_cast<float *>(fc1_result_ + interbuf_size);
} else {
softmax_out_ = nullptr;
}
paddle::Tensor expert_scales_float_tensor =
GetEmptyTensor({num_rows, moe_topk}, paddle::DataType::FLOAT32, place);
float *expert_scales_float = expert_scales_float_tensor.data<float>();
float *softmax_max_prob = nullptr;
if (group_moe) {
paddle::Tensor softmax_max_prob_tensor = GetEmptyTensor(
{num_rows, moe_topk}, paddle::DataType::FLOAT32, place);
// (TODO: check fill success ?)
paddle::experimental::fill(softmax_max_prob_tensor, 0.f);
softmax_max_prob = softmax_max_prob_tensor.data<float>();
}
paddle::Tensor fc1_out_tensor =
GetEmptyTensor({num_rows * k, inter_size}, input_type, place);
T *fc1_out = fc1_out_tensor.data<T>();
auto input_cast_tensor =
paddle::experimental::cast(*input, paddle::DataType::FLOAT32);
auto gate_tensor =
paddle::experimental::matmul(input_cast_tensor, *gate_weight);
float *gating_output = gate_tensor.data<float>();
if (moe_token_type_ids) {
auto *moe_token_type_ids_out = moe_token_type_ids->data<int>();
moe_token_type_ids_kernelLauncher<float>(gating_output,
moe_token_type_ids_out,
num_rows,
num_experts,
k,
stream);
}
topk_gating_softmax_kernelLauncher<float>(gating_output,
expert_scales_float,
softmax_out_,
expert_for_source_row,
source_rows_,
softmax_max_prob,
num_rows,
num_experts,
k,
group_moe,
stream);
const int64_t sorter_ws_size_bytes =
AlignTo16(sorter_.getWorkspaceSize(int64_t(k * num_rows)));
sorter_.run(fc1_result_,
sorter_ws_size_bytes,
expert_for_source_row,
permuted_experts_,
source_rows_,
permuted_rows_,
k * num_rows,
false,
stream);
initialize_moe_routing_kernelLauncher(
input_activations,
permuted_data_,
permuted_rows_,
expanded_source_row_to_expanded_dest_row,
num_rows,
num_rows,
hidden_size,
k,
stream);
const int64_t expanded_active_expert_rows = k * num_rows;
compute_total_rows_before_expert(permuted_experts_,
expanded_active_expert_rows,
num_experts,
total_rows_before_expert_,
stream);
mctlassExOrder_t row_major = mctlassExOrder_t::MCTLASS_EX_ROWMAJOR_ORDER;
mctlassExOrder_t column_major = mctlassExOrder_t::MCTLASS_EX_COLUMNMAJOR_ORDER;
mc_grouped_gemm_basic_kernel<ElementA, ElementB, ElementC>(
reinterpret_cast<const ElementA *>(permuted_data_),
row_major,
reinterpret_cast<const ElementB *>(ffn1_weight->data<ElementB>()),
column_major,
reinterpret_cast<const ElementA *>(ffn1_scale->data<T>()),
reinterpret_cast<const ElementA *>(fc1_expert_biases),
reinterpret_cast<ElementC *>(fc1_out),
row_major,
total_rows_before_expert_,
num_experts,
expanded_active_expert_rows,
inter_size,
hidden_size,
stream);
if (moe_type == "ffn") {
auto act_out_tensor =
paddle::experimental::swiglu(fc1_out_tensor, nullptr);
auto act_out = act_out_tensor.data<T>();
paddle::Tensor fc2_output_tensor =
GetEmptyTensor({k * num_rows, hidden_size}, input_type, place);
T *fc2_result = fc2_output_tensor.data<T>();
mc_grouped_gemm_basic_kernel<ElementA, ElementB, ElementC>(
reinterpret_cast<const ElementA *>(act_out),
row_major,
reinterpret_cast<const ElementB *>(ffn2_weight->data<ElementB>()),
column_major,
reinterpret_cast<const ElementA *>(ffn2_scale->data<T>()),
nullptr,
reinterpret_cast<ElementC *>(fc2_result),
row_major,
total_rows_before_expert_,
num_experts,
expanded_active_expert_rows,
hidden_size,
inter_size / 2,
stream);
finalize_moe_routing_kernelLauncher(
fc2_result,
output_,
fc2_expert_biases,
reinterpret_cast<float *>(expert_scales_float),
expanded_source_row_to_expanded_dest_row,
expert_for_source_row,
num_rows,
hidden_size,
k,
static_cast<int>(1),
norm_topk_prob,
routed_scaling_factor,
stream);
} else {
finalize_moe_routing_kernelLauncher(
// fc2_result,
fc1_out,
output_,
fc1_expert_biases, // fc2_expert_biases,
reinterpret_cast<float *>(expert_scales_float),
expanded_source_row_to_expanded_dest_row,
expert_for_source_row,
num_rows,
inter_size,
k,
static_cast<int>(0),
norm_topk_prob,
routed_scaling_factor,
stream);
}
}
private:
std::string gemm_method_;
CubKeyValueSorter sorter_;
};
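
A simplified host-side sketch of the workspace budgeting in getWorkspaceSize, assuming bf16 activations and example sizes; it ignores the extra bytes reserved when the cub sorting workspace exceeds the fc1 buffer, and the optional softmax output used for non-power-of-two expert counts:

#include <cstdio>
static size_t align16(size_t x) { return 16 * ((x + 15) / 16); }  // mirrors AlignTo16
int main() {
  const size_t rows = 128, hidden = 1024, inter = 4096, experts = 8, k = 2;
  const size_t elem = 2;  // bf16
  size_t num_moe_inputs = align16(k * rows);
  size_t total = 5 * num_moe_inputs * sizeof(int)        // routing index arrays
               + align16(k * rows * hidden) * elem       // permuted_data_
               + align16(experts) * sizeof(int)          // total_rows_before_expert_
               + align16(k * rows * inter) * elem;       // fc1 result (reused by the sorter)
  printf("workspace >= %zu bytes\n", total);
  return 0;
}
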

View File

@@ -0,0 +1,274 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
#pragma GCC diagnostic ignored "-Wunused-function"
#pragma once
#include "fused_moe_helper.h"
#include "fused_moe_op.h"
#pragma GCC diagnostic pop
#include "helper.h"
template <paddle::DataType T>
void MoeDispatchKernel(const paddle::Tensor& input,
const paddle::Tensor& gating_output,
const int moe_topk,
const bool group_moe,
const bool topk_only_mode,
const int num_rows,
const int hidden_size,
const int expert_num,
paddle::Tensor* permute_input,
paddle::Tensor* tokens_expert_prefix_sum,
paddle::Tensor* permute_indices_per_token,
paddle::Tensor* top_k_weight,
paddle::Tensor* top_k_indices) {
typedef PDTraits<T> traits_;
typedef typename traits_::DataType DataType_;
typedef typename traits_::data_t data_t;
auto stream = input.stream();
auto place = input.place();
if (group_moe) {
// Check if expert_num is divisible by moe_topk, else throw an error
PADDLE_ENFORCE_EQ(expert_num % moe_topk,
0,
common::errors::InvalidArgument(
"The number of experts (expert_num) "
"must be divisible by moe_topk. "
"Got expert_num = %d and moe_topk = %d.",
expert_num,
moe_topk));
}
const int num_moe_inputs = AlignTo16(num_rows * moe_topk);
const int bytes = num_moe_inputs * sizeof(int);
CubKeyValueSorter sorter_;
sorter_.update_num_experts(expert_num);
const int sorter_ws_size_bytes =
AlignTo16(sorter_.getWorkspaceSize(moe_topk * num_rows));
const int sort_tmp_in_out_size = num_moe_inputs * 2 * sizeof(int);
paddle::Tensor ws_ptr_tensor =
GetEmptyTensor({bytes + sorter_ws_size_bytes + sort_tmp_in_out_size},
paddle::DataType::INT8,
place);
int8_t* ws_ptr = ws_ptr_tensor.data<int8_t>();
int* source_rows_ = reinterpret_cast<int*>(ws_ptr);
int8_t* sorter_ws_ptr = reinterpret_cast<int8_t*>(ws_ptr + bytes);
int* permuted_experts_ =
reinterpret_cast<int*>(sorter_ws_ptr + sorter_ws_size_bytes);
int* permuted_rows_ = permuted_experts_ + num_moe_inputs;
int* expert_for_source_row = top_k_indices->data<int>();
float* softmax_max_prob = nullptr;
if (group_moe) {
paddle::Tensor softmax_max_prob_tensor =
GetEmptyTensor({num_rows, moe_topk}, paddle::DataType::FLOAT32, place);
paddle::experimental::fill(softmax_max_prob_tensor, 0.f);
softmax_max_prob = softmax_max_prob_tensor.data<float>();
}
float* softmax_out_;
const bool is_pow_2 =
(expert_num != 0) && ((expert_num & (expert_num - 1)) == 0);
paddle::Tensor softmax_buffer;
if (!is_pow_2 || expert_num > 256 || group_moe) {
softmax_buffer = GetEmptyTensor(
{num_rows * expert_num}, paddle::DataType::FLOAT32, place);
softmax_out_ = softmax_buffer.data<float>();
} else {
softmax_out_ = nullptr;
}
topk_gating_softmax_kernelLauncher<float>(gating_output.data<float>(),
top_k_weight->data<float>(),
softmax_out_,
expert_for_source_row,
source_rows_,
softmax_max_prob,
num_rows,
expert_num,
moe_topk,
group_moe,
stream,
topk_only_mode);
sorter_.run(reinterpret_cast<void*>(sorter_ws_ptr),
sorter_ws_size_bytes,
expert_for_source_row,
permuted_experts_,
source_rows_,
permuted_rows_,
moe_topk * num_rows,
false,
stream);
initialize_moe_routing_kernelLauncher(
input.data<data_t>(),
permute_input->data<data_t>(),
permuted_rows_,
permute_indices_per_token->data<int32_t>(),
num_rows,
num_rows,
hidden_size,
moe_topk,
stream);
compute_total_rows_before_expert(
permuted_experts_,
moe_topk * num_rows,
expert_num,
tokens_expert_prefix_sum->data<int32_t>(),
stream);
}
std::vector<paddle::Tensor> MoeExpertDispatch(
const paddle::Tensor& input,
const paddle::Tensor& gating_output,
const int moe_topk,
const bool group_moe,
const bool topk_only_mode) {
const auto input_type = input.dtype();
auto place = input.place();
int token_rows = 0;
auto input_dims = input.dims();
auto gating_dims = gating_output.dims();
const int expert_num = gating_dims[gating_dims.size() - 1];
if (input_dims.size() == 3) {
token_rows = input_dims[0] * input_dims[1];
} else {
token_rows = input_dims[0];
}
const int num_rows = token_rows;
const int hidden_size = input.dims()[input_dims.size() - 1];
auto permute_input =
GetEmptyTensor({moe_topk * num_rows, hidden_size}, input_type, place);
// Weighting coefficients applied to each expert's output when combining results.
auto top_k_weight =
GetEmptyTensor({num_rows, moe_topk}, paddle::DataType::FLOAT32, place);
auto top_k_indices =
GetEmptyTensor({num_rows, moe_topk}, paddle::DataType::INT32, place);
auto tokens_expert_prefix_sum =
GetEmptyTensor({expert_num}, paddle::DataType::INT32, place);
auto permute_indices_per_token =
GetEmptyTensor({moe_topk, num_rows}, paddle::DataType::INT32, place);
switch (input_type) {
case paddle::DataType::BFLOAT16:
MoeDispatchKernel<paddle::DataType::BFLOAT16>(input,
gating_output,
moe_topk,
group_moe,
topk_only_mode,
num_rows,
hidden_size,
expert_num,
&permute_input,
&tokens_expert_prefix_sum,
&permute_indices_per_token,
&top_k_weight,
&top_k_indices);
break;
// case paddle::DataType::FLOAT16:
// MoeDispatchKernel<paddle::DataType::FLOAT16>(input,
// gating_output,
// moe_topk,
// group_moe,
// topk_only_mode,
// num_rows,
// hidden_size,
// expert_num,
// &permute_input,
// &tokens_expert_prefix_sum,
// &permute_indices_per_token,
// &top_k_weight,
// &top_k_indices);
// break;
default:
PD_THROW("Only support bf16 for MoeDispatchKernel");
}
return {permute_input,
tokens_expert_prefix_sum,
permute_indices_per_token,
top_k_weight,
top_k_indices};
}
std::vector<std::vector<int64_t>> MoeExpertDispatchInferShape(
const std::vector<int64_t>& input_shape,
const std::vector<int64_t>& gating_output_shape,
const int moe_topk) {
int token_rows = -1;
if (input_shape.size() == 3) {
token_rows = input_shape[0] * input_shape[1];
} else {
token_rows = input_shape[0];
}
const int expert_num = gating_output_shape[gating_output_shape.size() - 1];
const int num_rows = token_rows;
const int hidden_size = input_shape[input_shape.size() - 1];
return {{moe_topk * num_rows, hidden_size},
{expert_num},
{moe_topk, num_rows},
{num_rows, moe_topk},
{num_rows, moe_topk}};
}
std::vector<paddle::DataType> MoeExpertDispatchInferDtype(
const paddle::DataType& input_dtype,
const paddle::DataType& gating_output_dtype,
const int moe_topk) {
return {input_dtype,
paddle::DataType::INT64,
paddle::DataType::INT32,
paddle::DataType::FLOAT32,
paddle::DataType::INT32};
}
PD_BUILD_OP(moe_expert_dispatch)
.Inputs({"input", "gating_output"})
.Outputs({"permute_input",
"tokens_expert_prefix_sum",
"permute_indices_per_token",
"top_k_weight",
"top_k_indices"})
.Attrs({"moe_topk:int", "group_moe:bool", "topk_only_mode:bool"})
.SetKernelFn(PD_KERNEL(MoeExpertDispatch))
.SetInferShapeFn(PD_INFER_SHAPE(MoeExpertDispatchInferShape))
.SetInferDtypeFn(PD_INFER_DTYPE(MoeExpertDispatchInferDtype));
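For orientation, a minimal usage sketch of the op registered above, assuming the Metax custom-op extension is built and exposed under fastdeploy.model_executor.ops.gpu (tensor sizes are illustrative; output shapes follow MoeExpertDispatchInferShape):
```python
# Hypothetical sketch of driving moe_expert_dispatch from Python; assumes the
# extension is installed and importable from this module.
import paddle
from fastdeploy.model_executor.ops.gpu import moe_expert_dispatch

num_rows, hidden_size, expert_num, moe_topk = 8, 4096, 64, 6

x = paddle.randn([num_rows, hidden_size]).astype("bfloat16")     # only bf16 is accepted
gating = paddle.randn([num_rows, expert_num]).astype("float32")  # router logits

(permute_input,              # [moe_topk * num_rows, hidden_size]
 tokens_expert_prefix_sum,   # [expert_num]
 permute_indices_per_token,  # [moe_topk, num_rows]
 top_k_weight,               # [num_rows, moe_topk]
 top_k_indices,              # [num_rows, moe_topk]
 ) = moe_expert_dispatch(x, gating, moe_topk, False, False)  # group_moe, topk_only_mode
```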

View File

@@ -0,0 +1,173 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "mc_fused_moe_helper.h"
#include "helper.h"
template <paddle::DataType T, typename ElementA, typename ElementB, typename ElementC>
void McMoeFFNKernel(const paddle::Tensor& permute_input,
const paddle::Tensor& tokens_expert_prefix_sum,
const paddle::Tensor& ffn1_weight,
const paddle::Tensor& ffn2_weight,
const paddle::optional<paddle::Tensor>& ffn1_bias,
const paddle::optional<paddle::Tensor>& ffn1_scale,
const paddle::optional<paddle::Tensor>& ffn2_scale,
const std::string& quant_method,
paddle::Tensor ffn_out) {
typedef PDTraits<T> traits_;
typedef typename traits_::DataType DataType_;
typedef typename traits_::data_t data_t;
auto ffn_out_ptr = ffn_out.data<data_t>();
auto permuted_input_ptr = permute_input.data<data_t>();
auto place = permute_input.place();
auto input_type = permute_input.dtype();
auto stream = permute_input.stream();
const int expanded_active_expert_rows = permute_input.dims()[0]; // permute_input.dims(): m, k
const int num_experts = ffn1_weight.dims()[0]; // grouped GEMM batch size (number of experts)
const int hidden_size = ffn1_weight.dims()[2]; // n
int inter_dim = ffn1_weight.dims()[1]; // k
const int64_t inter_size = inter_dim; // no sub-byte packing since quant_method is weight_only_int8
paddle::Tensor fc1_out_tensor = GetEmptyTensor(
{expanded_active_expert_rows, inter_size}, input_type, place);
auto fc1_out_ptr = fc1_out_tensor.data<data_t>();
mctlassExOrder_t row_major = mctlassExOrder_t::MCTLASS_EX_ROWMAJOR_ORDER;
mctlassExOrder_t column_major = mctlassExOrder_t::MCTLASS_EX_COLUMNMAJOR_ORDER;
// ffn1
auto fc1_expert_biases =
ffn1_bias
? const_cast<paddle::Tensor*>(ffn1_bias.get_ptr())->data<data_t>()
: nullptr;
auto fc1_expert_scales = const_cast<paddle::Tensor*>(ffn1_scale.get_ptr())->data<data_t>();
mc_grouped_gemm_basic_kernel<ElementA, ElementB, ElementC>(
reinterpret_cast<const ElementA *>(permuted_input_ptr),
row_major,
reinterpret_cast<const ElementB *>(ffn1_weight.data<ElementB>()),
column_major,
reinterpret_cast<const ElementA *>(fc1_expert_scales),
reinterpret_cast<const ElementA *>(fc1_expert_biases),
reinterpret_cast<ElementC *>(fc1_out_ptr),
row_major,
tokens_expert_prefix_sum.data<int>(),
num_experts,
expanded_active_expert_rows,
inter_dim,
hidden_size,
stream);
// swiglu
auto act_out_tensor = paddle::experimental::swiglu(fc1_out_tensor, nullptr);
auto act_out = act_out_tensor.data<data_t>();
auto fc2_expert_scales = const_cast<paddle::Tensor*>(ffn2_scale.get_ptr())->data<data_t>();
mc_grouped_gemm_basic_kernel<ElementA, ElementB, ElementC>(
reinterpret_cast<const ElementA *>(act_out),
row_major,
reinterpret_cast<const ElementB *>(ffn2_weight.data<ElementB>()),
column_major,
reinterpret_cast<const ElementA *>(fc2_expert_scales),
nullptr,
reinterpret_cast<ElementC *>(ffn_out_ptr),
row_major,
tokens_expert_prefix_sum.data<int>(),
num_experts,
expanded_active_expert_rows,
hidden_size,
inter_dim / 2,
stream);
}
std::vector<paddle::Tensor> MoeExpertFFN(
const paddle::Tensor& permute_input,
const paddle::Tensor& tokens_expert_prefix_sum,
const paddle::Tensor& ffn1_weight,
const paddle::Tensor& ffn2_weight,
const paddle::optional<paddle::Tensor>& ffn1_bias,
const paddle::optional<paddle::Tensor>& ffn1_scale,
const paddle::optional<paddle::Tensor>& ffn2_scale,
const std::string& quant_method) {
assert(quant_method == "weight_only_int8");
const auto input_type = permute_input.dtype();
auto ffn_out = paddle::empty_like(permute_input);
switch (input_type) {
case paddle::DataType::BFLOAT16:
McMoeFFNKernel<paddle::DataType::BFLOAT16, maca_bfloat16, int8_t, maca_bfloat16>(permute_input,
tokens_expert_prefix_sum,
ffn1_weight,
ffn2_weight,
ffn1_bias,
ffn1_scale,
ffn2_scale,
quant_method,
ffn_out);
break;
// case paddle::DataType::FLOAT16:
// MoeFFNKernel<paddle::DataType::FLOAT16>(permute_input,
// tokens_expert_prefix_sum,
// ffn1_weight,
// ffn2_weight,
// ffn1_bias,
// ffn1_scale,
// ffn2_scale,
// quant_method,
// ffn_out);
// break;
default:
PD_THROW("Only support bf16 for MoeExpertFFN");
}
return {ffn_out};
}
std::vector<std::vector<int64_t>> MoeExpertFFNInferShape(
const std::vector<int64_t>& permute_input_shape,
const std::vector<int64_t>& tokens_expert_prefix_sum_shape,
const std::vector<int64_t>& ffn1_weight_shape,
const std::vector<int64_t>& ffn2_weight_shape,
const paddle::optional<std::vector<int64_t>>& ffn1_bias_shape,
const paddle::optional<std::vector<int64_t>>& ffn1_scale_shape,
const paddle::optional<std::vector<int64_t>>& ffn2_scale_shape) {
return {permute_input_shape};
}
std::vector<paddle::DataType> MoeExpertFFNInferDtype(
const paddle::DataType& permute_input_dtype,
const paddle::DataType& tokens_expert_prefix_sum_dtype,
const paddle::DataType& ffn1_weight_dtype,
const paddle::DataType& ffn2_weight_dtype,
const paddle::optional<paddle::DataType>& ffn1_bias_dtype,
const paddle::optional<paddle::DataType>& ffn1_scale_dtype,
const paddle::optional<paddle::DataType>& ffn2_scale_dtype) {
return {permute_input_dtype};
}
PD_BUILD_OP(moe_expert_ffn)
.Inputs({"permute_input",
"tokens_expert_prefix_sum",
"ffn1_weight",
"ffn2_weight",
paddle::Optional("ffn1_bias"),
paddle::Optional("ffn1_scale"),
paddle::Optional("ffn2_scale")})
.Outputs({"output_tensor"})
.Attrs({"quant_method:std::string"})
.SetKernelFn(PD_KERNEL(MoeExpertFFN))
.SetInferShapeFn(PD_INFER_SHAPE(MoeExpertFFNInferShape))
.SetInferDtypeFn(PD_INFER_DTYPE(MoeExpertFFNInferDtype));
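A companion sketch for moe_expert_ffn, assuming int8 weight-only expert weights laid out as in the Metax cutlass MoE backend later in this change ([experts, 2*inter, hidden] for ffn1, [experts, hidden, inter] for ffn2); permute_input and tokens_expert_prefix_sum are the outputs of moe_expert_dispatch from the previous sketch:
```python
# Hypothetical sketch; in practice the int8 weights and per-channel scales come
# from paddle.nn.quant.weight_quantize, not the zero/one-filled tensors used here.
import paddle
from fastdeploy.model_executor.ops.gpu import moe_expert_ffn

num_experts, hidden_size, inter_size = 64, 4096, 1024

ffn1_weight = paddle.zeros([num_experts, inter_size * 2, hidden_size], dtype="int8")
ffn2_weight = paddle.zeros([num_experts, hidden_size, inter_size], dtype="int8")
ffn1_scale = paddle.ones([num_experts, inter_size * 2], dtype="bfloat16")
ffn2_scale = paddle.ones([num_experts, hidden_size], dtype="bfloat16")

ffn_out = moe_expert_ffn(
    permute_input,             # [moe_topk * num_rows, hidden_size], bf16, from moe_expert_dispatch
    tokens_expert_prefix_sum,  # [num_experts], int32, from moe_expert_dispatch
    ffn1_weight,
    ffn2_weight,
    None,                      # ffn1_bias
    ffn1_scale,
    ffn2_scale,
    "weight_only_int8",        # the only quant_method this kernel accepts
)                              # -> same shape and dtype as permute_input
```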

View File

@@ -0,0 +1,143 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "helper.h"
#include "fused_moe_helper.h"
#include "fused_moe_op.h"
template <paddle::DataType T>
void MoeReduceKernel(const paddle::Tensor& ffn_out,
const paddle::Tensor& top_k_weight,
const paddle::Tensor& permute_indices_per_token,
const paddle::Tensor& top_k_indices,
const paddle::optional<paddle::Tensor>& ffn2_bias,
const bool norm_topk_prob,
const float routed_scaling_factor,
const int num_rows,
const int hidden_size,
const int topk,
paddle::Tensor* output) {
typedef PDTraits<T> traits_;
typedef typename traits_::DataType DataType_;
typedef typename traits_::data_t data_t;
auto stream = ffn_out.stream();
finalize_moe_routing_kernelLauncher(
ffn_out.data<data_t>(),
output->data<data_t>(),
ffn2_bias ? ffn2_bias->data<data_t>() : nullptr,
top_k_weight.data<float>(),
permute_indices_per_token.data<int32_t>(),
top_k_indices.data<int>(),
num_rows,
hidden_size,
topk,
static_cast<int>(1),
norm_topk_prob,
routed_scaling_factor,
stream);
}
std::vector<paddle::Tensor> MoeExpertReduce(
const paddle::Tensor& ffn_out,
const paddle::Tensor& top_k_weight,
const paddle::Tensor& permute_indices_per_token,
const paddle::Tensor& top_k_indices,
const paddle::optional<paddle::Tensor>& ffn2_bias,
const bool norm_topk_prob,
const float routed_scaling_factor) {
const auto input_type = ffn_out.dtype();
auto place = ffn_out.place();
const int topk = top_k_indices.dims()[1];
const int num_rows = ffn_out.dims()[0] / topk;
const int hidden_size = ffn_out.dims()[1];
auto output = GetEmptyTensor({num_rows, hidden_size}, input_type, place);
// Avoids invalid configuration argument when we launch the kernel.
if (ffn_out.dims()[0] == 0) return {output};
switch (input_type) {
case paddle::DataType::BFLOAT16:
MoeReduceKernel<paddle::DataType::BFLOAT16>(ffn_out,
top_k_weight,
permute_indices_per_token,
top_k_indices,
ffn2_bias,
norm_topk_prob,
routed_scaling_factor,
num_rows,
hidden_size,
topk,
&output);
break;
// case paddle::DataType::FLOAT16:
// MoeReduceKernel<paddle::DataType::FLOAT16>(ffn_out,
// top_k_weight,
// permute_indices_per_token,
// top_k_indices,
// ffn2_bias,
// norm_topk_prob,
// routed_scaling_factor,
// num_rows,
// hidden_size,
// topk,
// &output);
// break;
default:
PD_THROW("Only support bf16 for MoeDispatchKernel");
}
return {output};
}
std::vector<std::vector<int64_t>> MoeExpertReduceInferShape(
const std::vector<int64_t>& ffn_out_shape,
const std::vector<int64_t>& top_k_weight_shape,
const std::vector<int64_t>& permute_indices_per_token_shape,
const std::vector<int64_t>& top_k_indices_shape,
const paddle::optional<std::vector<int64_t>>& ffn2_bias_shape) {
const int topk = top_k_indices_shape[1];
std::vector<int64_t> fused_moe_out_shape = {ffn_out_shape[0] / topk,
ffn_out_shape[1]};
return {fused_moe_out_shape};
}
std::vector<paddle::DataType> MoeExpertReduceInferDtype(
const paddle::DataType& ffn_out_dtype,
const paddle::DataType& top_k_weight_dtype,
const paddle::DataType& permute_indices_per_token_dtype,
const paddle::DataType& top_k_indices_dtype,
const paddle::optional<paddle::DataType>& ffn2_bias_dtype) {
return {ffn_out_dtype};
}
PD_BUILD_OP(moe_expert_reduce)
.Inputs({"ffn_out",
"top_k_weight",
"permute_indices_per_token",
"top_k_indices",
paddle::Optional("ffn2_bias")})
.Outputs({"output"})
.Attrs({"norm_topk_prob:bool", "routed_scaling_factor:float"})
.SetKernelFn(PD_KERNEL(MoeExpertReduce))
.SetInferShapeFn(PD_INFER_SHAPE(MoeExpertReduceInferShape))
.SetInferDtypeFn(PD_INFER_DTYPE(MoeExpertReduceInferDtype));
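Finally, a hedged sketch of moe_expert_reduce closing the dispatch → expert FFN → reduce pipeline, fed by the outputs of the two sketches above:
```python
# Hypothetical final step: scatter expert outputs back to token order and apply
# the top-k routing weights.
from fastdeploy.model_executor.ops.gpu import moe_expert_reduce

out = moe_expert_reduce(
    ffn_out,                     # [moe_topk * num_rows, hidden_size], from moe_expert_ffn
    top_k_weight,                # [num_rows, moe_topk], from moe_expert_dispatch
    permute_indices_per_token,   # [moe_topk, num_rows], from moe_expert_dispatch
    top_k_indices,               # [num_rows, moe_topk], from moe_expert_dispatch
    None,                        # ffn2_bias
    True,                        # norm_topk_prob
    1.0,                         # routed_scaling_factor
)                                # -> [num_rows, hidden_size]
```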

View File

@@ -597,6 +597,10 @@ elif paddle.device.is_compiled_with_custom_device("metax_gpu"):
"gpu_ops/moe/tritonmoe_preprocess.cu",
"gpu_ops/moe/moe_topk_select.cu",
"gpu_ops/recover_decode_task.cu",
"metax_ops/moe_dispatch.cu",
"metax_ops/moe_ffn.cu",
"metax_ops/moe_reduce.cu",
"metax_ops/fused_moe.cu",
]
sources += find_end_files("gpu_ops/speculate_decoding", ".cu")
@@ -617,7 +621,7 @@ elif paddle.device.is_compiled_with_custom_device("metax_gpu"):
],
},
library_dirs=[os.path.join(maca_path, "lib")],
extra_link_args=["-lruntime_cu"],
extra_link_args=["-lruntime_cu", "-lmctlassEx"],
include_dirs=[
os.path.join(maca_path, "include"),
os.path.join(maca_path, "include/mcr"),

View File

@@ -19,8 +19,8 @@ docker login --username=cr_temp_user --password=eyJpbnN0YW5jZUlkIjoiY3JpLXpxYTIz
## 2. paddlepaddle and custom device installation
```shell
pip install paddlepaddle==3.0.0.dev20250729 -i https://www.paddlepaddle.org.cn/packages/nightly/cpu/
pip install paddle-metax-gpu==3.0.0.dev20250807 -i https://www.paddlepaddle.org.cn/packages/nightly/maca/
pip install paddlepaddle==3.0.0.dev20250825 -i https://www.paddlepaddle.org.cn/packages/nightly/cpu/
pip install paddle-metax-gpu==3.0.0.dev20250826 -i https://www.paddlepaddle.org.cn/packages/nightly/maca/
```
## 3. Build Wheel from Source
@@ -47,6 +47,8 @@ from fastdeploy.model_executor.ops.gpu import beam_search_softmax
If the above code executes successfully, the environment is ready.
## 5. Demo
```python
from fastdeploy import LLM, SamplingParams
prompts = [
@@ -68,7 +70,9 @@ for output in outputs:
print(prompt)
print(generated_text)
print("-" * 50)
```
```
Output
INFO 2025-08-18 10:54:18,455 416822 engine.py[line:202] Waiting worker processes ready...
Loading Weights: 100%|█████████████████████████████████████████████████████████████████████████| 100/100 [03:33<00:00, 2.14s/it]
@@ -81,3 +85,4 @@ Generated 1 outputs
Hello. My name is
Alice and I'm here to help you. What can I do for you today?
Hello Alice! I'm trying to organize a small party
```

View File

@@ -13,9 +13,11 @@
# limitations under the License.
from .attention.flash_attn_backend import FlashAttentionBackend
from .moe.fused_moe_cutlass_metax_backend import MetaxCutlassWeightOnlyMoEMethod
from .moe.fused_moe_triton_metax_backend import MetaxTritonWeightOnlyMoEMethod
__all__ = [
"FlashAttentionBackend",
"MetaxTritonWeightOnlyMoEMethod",
"MetaxCutlassWeightOnlyMoEMethod",
]

View File

@@ -1,3 +1,17 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from typing import Optional, Tuple, Union

View File

@@ -1,4 +1,3 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -12,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
from __future__ import annotations
@@ -261,27 +259,9 @@ class FlashAttentionBackend(AttentionBackend):
forward_meta: ForwardMeta,
cu_seqlens_q: paddle.Tensor,
batch_ids=None,
is_decode=False,
):
q_end = self.num_heads * self.head_dim
k_end = q_end + self.kv_num_heads * self.head_dim
v_end = k_end + self.kv_num_heads * self.head_dim
assert v_end == qkv.shape[-1], f"Shape mismatch: {v_end} vs {qkv.shape[-1]}"
assert qkv.shape[0] == cu_seqlens_q[-1], f"Shape mismatch: {qkv.shape[0]} vs {cu_seqlens_q[-1]}"
if batch_ids is None:
batch_ids = list(range(forward_meta.seq_lens_this_time.shape[0]))
q = qkv[..., 0:q_end]
k = qkv[..., q_end:k_end]
v = qkv[..., k_end:v_end]
q = q.view([-1, self.num_heads, self.head_dim])
k = k.view([-1, self.kv_num_heads, self.head_dim])
v = v.view([-1, self.kv_num_heads, self.head_dim])
if is_decode:
return q, k, v
qkv = qkv.view([-1, self.num_heads + self.kv_num_heads * 2, self.head_dim])
q, k, v = qkv.split(num_or_sections=[self.num_heads, self.kv_num_heads, self.kv_num_heads], axis=-2)
for idx in range(len(cu_seqlens_q) - 1):
batch_idx = batch_ids[idx]
@@ -375,41 +355,6 @@ class FlashAttentionBackend(AttentionBackend):
cache_start += self.block_size
tensor_start = tensor_end
def merge_output(self, prefill_out, decode_out, forward_meta: ForwardMeta):
assert not (prefill_out is None and decode_out is None), "prefill and decode output cannot both be None"
if prefill_out is None:
return decode_out
elif decode_out is None:
return prefill_out
else:
prefill_out = prefill_out
decode_out = decode_out
merged_output = []
prefill_tensor_start = 0
decode_tensor_start = 0
for seq_lens_this_time in forward_meta.seq_lens_this_time:
if seq_lens_this_time == 0:
continue
if seq_lens_this_time > 1:
tensor_end = prefill_tensor_start + seq_lens_this_time
merged_output.append(prefill_out[prefill_tensor_start:tensor_end, :, :])
prefill_tensor_start = tensor_end
else:
assert seq_lens_this_time == 1
tensor_end = decode_tensor_start + seq_lens_this_time
merged_output.append(decode_out[decode_tensor_start:tensor_end, :, :])
decode_tensor_start = tensor_end
assert (
prefill_tensor_start == prefill_out.shape[0]
), f"prefill merged unfinished: {prefill_tensor_start} vs {prefill_out.shape[0]}"
assert (
decode_tensor_start == decode_out.shape[0]
), f"decode merged unfinished: {decode_tensor_start} vs {decode_out.shape[0]}"
merged_output = paddle.concat(merged_output, axis=0)
return merged_output
def forward_prefill(self, prefill_qkv, layer_id, k_cache_id, v_cache_id, forward_meta: ForwardMeta):
prefill_q, prefill_k, prefill_v = self.get_splited_qkv(
@@ -438,23 +383,17 @@ class FlashAttentionBackend(AttentionBackend):
return prefill_out
def forward_decode(self, decode_qkv, k_cache_id, v_cache_id, forward_meta: ForwardMeta):
cache_k = forward_meta.caches[k_cache_id]
cache_v = forward_meta.caches[v_cache_id]
cu_seq_lens = list(range(self.decode_len + 1))
q, k, v = self.get_splited_qkv(decode_qkv, forward_meta, cu_seq_lens, self.batch_ids_decode, is_decode=True)
decoder_q = q.view([self.decode_len, 1, self.num_heads, self.head_dim])
decoder_k_ = k.view([self.decode_len, 1, self.kv_num_heads, self.head_dim])
decoder_v_ = v.view([self.decode_len, 1, self.kv_num_heads, self.head_dim])
qkv = decode_qkv.view([-1, 1, self.num_heads + self.kv_num_heads * 2, self.head_dim])
q, k, v = qkv.split(num_or_sections=[self.num_heads, self.kv_num_heads, self.kv_num_heads], axis=-2)
decode_out = flash_attn_kvcache_func(
decoder_q,
cache_k,
cache_v,
q,
forward_meta.caches[k_cache_id],
forward_meta.caches[v_cache_id],
self.seq_lens_dec,
self.block_table_dec,
decoder_k_,
decoder_v_,
k,
v,
rotary_cos=forward_meta.rotary_embs[0, 0, :, 0, :].astype("bfloat16"),
rotary_sin=forward_meta.rotary_embs[1, 0, :, 0, :].astype("bfloat16"),
causal=self.causal,
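The decode path above now reshapes the fused QKV tensor once and splits it along the head axis, instead of slicing channel ranges and re-viewing each piece; a self-contained sketch of the equivalent transform (head counts are illustrative):
```python
# Standalone illustration of the view-then-split pattern used in forward_decode.
import paddle

num_heads, kv_num_heads, head_dim, tokens = 32, 4, 128, 8
decode_qkv = paddle.randn([tokens, (num_heads + 2 * kv_num_heads) * head_dim])

# One reshape to [tokens, 1, total_heads, head_dim], then a single split on the head axis.
qkv = decode_qkv.reshape([-1, 1, num_heads + kv_num_heads * 2, head_dim])
q, k, v = qkv.split([num_heads, kv_num_heads, kv_num_heads], axis=-2)

assert q.shape == [tokens, 1, num_heads, head_dim]
assert k.shape == [tokens, 1, kv_num_heads, head_dim]
assert v.shape == [tokens, 1, kv_num_heads, head_dim]
```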

View File

@@ -0,0 +1,370 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
from paddle import nn
from paddle.nn.quant import weight_quantize
import fastdeploy
from fastdeploy.distributed.communication import tensor_model_parallel_all_reduce
from fastdeploy.model_executor.layers.moe.fused_moe_backend_base import MoEMethodBase
from fastdeploy.model_executor.layers.utils import get_tensor
from fastdeploy.model_executor.ops.gpu import fused_expert_moe
from fastdeploy.model_executor.utils import TensorTracker, free_tensor, set_weight_attrs
class MetaxCutlassMoEMethod(MoEMethodBase):
"""
Use Cutlass Group Gemm to compute Fused MoE.
This method is the oldest way to compute MoE in Paddle.
"""
def process_loaded_weights(self, layer: nn.Layer, state_dict):
up_gate_proj_weights, down_proj_weights, logical_expert_ids, ep_rank_to_expert_id_list = (
layer.extract_moe_ffn_weights(state_dict)
)
stacked_up_gate_proj_weights = paddle.stack(up_gate_proj_weights, axis=0)
stacked_down_proj_weights = paddle.stack(down_proj_weights, axis=0)
layer.up_gate_proj_weight.set_value(stacked_up_gate_proj_weights)
layer.down_proj_weight.set_value(stacked_down_proj_weights)
def compute_ffn(
self,
layer: nn.Layer,
permute_input: paddle.Tensor,
token_nums_per_expert: paddle.Tensor,
expert_idx_per_token: paddle.Tensor,
used_in_ep_low_latency: bool = False,
estimate_total_token_nums: int = -1,
):
"""
Paddle Cutlass compute Fused MoE.
"""
return fastdeploy.model_executor.ops.gpu.moe_expert_ffn(
permute_input,
token_nums_per_expert,
getattr(layer, self.added_weight_attrs[0]),
getattr(layer, self.added_weight_attrs[1]),
None,
(layer.up_gate_proj_weight_scale if hasattr(layer, "up_gate_proj_weight_scale") else None),
(layer.down_proj_weight_scale if hasattr(layer, "down_proj_weight_scale") else None),
"weight_only_int8",
)
def apply_ep_prefill(
self,
layer: nn.Layer,
x: paddle.Tensor,
gate: nn.Layer,
) -> paddle.Tensor:
"""
Apply the EP prefill method.
"""
raise NotImplementedError
def apply_ep_decode(
self,
layer: nn.Layer,
x: paddle.Tensor,
gate: nn.Layer,
) -> paddle.Tensor:
"""
Apply the EP decoder method.
"""
raise NotImplementedError
def apply_tp(
self,
layer: nn.Layer,
x: paddle.Tensor,
gate: nn.Layer,
) -> paddle.Tensor:
"""
Paddle Cutlass compute Fused MoE.
"""
fused_moe_out = fused_expert_moe(
x,
gate.weight,
getattr(layer, self.added_weight_attrs[0]),
getattr(layer, self.added_weight_attrs[1]),
None,
(layer.up_gate_proj_weight_scale if hasattr(layer, "up_gate_proj_weight_scale") else None),
None,
(layer.down_proj_weight_scale if hasattr(layer, "down_proj_weight_scale") else None),
"weight_only_int8",
layer.top_k,
True,
False,
)
if layer.reduce_results and layer.tp_size > 1:
tensor_model_parallel_all_reduce(fused_moe_out)
return fused_moe_out
class MetaxCutlassWeightOnlyMoEMethod(MetaxCutlassMoEMethod):
"""
weight only for moe
"""
def __init__(self, quant_config=None):
"""
weight only for moe
"""
super().__init__(quant_config)
# print(f"[DEBUG] quant_config: {quant_config}")
self.quant_config = quant_config
self.moe_quant_type = self.quant_config.algo
self.pack_num = 1
def process_prequanted_weights(self, layer: nn.Layer, state_dict, is_rearrange: bool = False):
"""
Paddle cutlass process prequanted weights.
"""
up_gate_proj_expert_weight_key = layer.weight_key_map.get("up_gate_proj_expert_weight_key", None)
down_proj_expert_weight_key = layer.weight_key_map.get("down_proj_expert_weight_key", None)
up_gate_proj_expert_weight_scale_key = layer.weight_key_map.get("up_gate_proj_expert_weight_scale_key", None)
down_proj_expert_weight_scale_key = layer.weight_key_map.get("down_proj_expert_weight_scale_key", None)
up_gate_proj_weights, down_proj_weights, logical_expert_ids, _ = layer.load_experts_weight(
state_dict, up_gate_proj_expert_weight_key, down_proj_expert_weight_key, is_rearrange
)
# self.check(layer, up_gate_proj_weights, down_proj_weights)
up_gate_proj_weight_scale = []
down_proj_weight_scale = []
for expert_idx in logical_expert_ids:
up_gate_proj_weight_scale.append(
get_tensor(state_dict.pop(up_gate_proj_expert_weight_scale_key.format(expert_idx)))
)
down_proj_weight_scale.append(
get_tensor(state_dict.pop(down_proj_expert_weight_scale_key.format(expert_idx)))
)
up_gate_proj_weight = paddle.stack(up_gate_proj_weights, axis=0)
down_proj_weight = paddle.stack(down_proj_weights, axis=0)
up_gate_proj_weight_scale = paddle.stack(up_gate_proj_weight_scale, axis=0)
down_proj_weight_scale = paddle.stack(down_proj_weight_scale, axis=0)
name_tensor_map = {
"up_gate_proj_weight": up_gate_proj_weight,
"down_proj_weight": down_proj_weight,
"up_gate_proj_weight_scale": up_gate_proj_weight_scale,
"down_proj_weight_scale": down_proj_weight_scale,
}
for name, tensor in name_tensor_map.items():
getattr(layer, name).set_value(tensor)
def create_weights(self, layer: nn.Layer, **extra_weight_attrs):
"""
Paddle cutlass create weight process.
"""
self.default_dtype = layer._helper.get_default_dtype()
if self.moe_quant_type == "weight_only_int4":
self.up_gate_proj_weight_shape = [
layer.num_local_experts,
layer.moe_intermediate_size,
layer.hidden_size,
]
else:
self.up_gate_proj_weight_shape = [
layer.num_local_experts,
layer.moe_intermediate_size * 2,
layer.hidden_size,
]
if self.moe_quant_type == "weight_only_int4":
self.down_proj_weight_shape = [
layer.num_local_experts,
layer.hidden_size // 2,
layer.moe_intermediate_size,
]
else:
self.down_proj_weight_shape = [
layer.num_local_experts,
layer.hidden_size,
layer.moe_intermediate_size,
]
self.up_gate_proj_scale_shape = [layer.num_local_experts, layer.moe_intermediate_size * 2]
self.down_proj_scale_shape = [layer.num_local_experts, layer.hidden_size]
if layer.fd_config.load_config.load_choices == "default_v1":
layer.up_gate_proj_weight = layer.create_parameter(
shape=[layer.num_experts, layer.hidden_size, layer.moe_intermediate_size * 2],
dtype=layer.weight_dtype,
default_initializer=paddle.nn.initializer.Constant(0),
)
layer.down_proj_weight = layer.create_parameter(
shape=[layer.num_experts, layer.moe_intermediate_size, layer.hidden_size],
dtype=layer.weight_dtype,
default_initializer=paddle.nn.initializer.Constant(0),
)
set_weight_attrs(
layer.up_gate_proj_weight,
{
**extra_weight_attrs,
"tensor_track": TensorTracker(shape=layer.up_gate_proj_weight.shape, output_dim=True),
},
)
set_weight_attrs(
layer.down_proj_weight,
{
**extra_weight_attrs,
"tensor_track": TensorTracker(shape=layer.down_proj_weight.shape, output_dim=False),
},
)
else:
self.weight_dtype = "int8"
up_gate_proj_weight_name = self.added_weight_attrs[0]
down_proj_weight_name = self.added_weight_attrs[1]
up_gate_proj_scale_name = self.added_scale_attrs[0]
down_proj_scale_name = self.added_scale_attrs[1]
setattr(
layer,
up_gate_proj_weight_name,
layer.create_parameter(
shape=self.up_gate_proj_weight_shape,
dtype=self.weight_dtype,
default_initializer=paddle.nn.initializer.Constant(0),
),
)
setattr(
layer,
down_proj_weight_name,
layer.create_parameter(
shape=self.down_proj_weight_shape,
dtype=self.weight_dtype,
default_initializer=paddle.nn.initializer.Constant(0),
),
)
# weight_scale
setattr(
layer,
up_gate_proj_scale_name,
layer.create_parameter(
shape=self.up_gate_proj_scale_shape,
dtype=self.default_dtype,
default_initializer=paddle.nn.initializer.Constant(0),
),
)
setattr(
layer,
down_proj_scale_name,
layer.create_parameter(
shape=self.down_proj_scale_shape,
dtype=self.default_dtype,
default_initializer=paddle.nn.initializer.Constant(0),
),
)
moe_extra_weight_attrs = {**extra_weight_attrs, "SHARD_ID_TO_SHARDED_DIM": {"gate": 0, "down": 1, "up": 0}}
set_weight_attrs(layer.up_gate_proj_weight, moe_extra_weight_attrs)
set_weight_attrs(layer.down_proj_weight, moe_extra_weight_attrs)
scale_extra_weight_attrs = {
**extra_weight_attrs,
"SHARD_ID_TO_SHARDED_DIM": {"gate": 0, "up": 0, "down": None},
}
set_weight_attrs(layer.up_gate_proj_weight_scale, scale_extra_weight_attrs)
set_weight_attrs(layer.down_proj_weight_scale, scale_extra_weight_attrs)
def process_weights_after_loading(self, layer):
""" """
if not layer.fd_config.load_config.load_choices == "default_v1":
return
weight_id_map = {"gate_up": 0, "down": 1}
if (
hasattr(layer.up_gate_proj_weight, "tensor_track")
and layer.up_gate_proj_weight.tensor_track is not None
and layer.up_gate_proj_weight.tensor_track.is_fully_copied()
):
weight_type = "gate_up"
else:
weight_type = "down"
# 1.init shape and type
# weight
weight_name = self.added_weight_attrs[weight_id_map[weight_type]]
unquantized_weight_name = weight_name.replace("quant_weight", "weight")
weight_shape = self.up_gate_proj_weight_shape if weight_type == "gate_up" else self.down_proj_weight_shape
weight_dtype = "int8"
# scale
scale_name = self.added_scale_attrs[weight_id_map[weight_type]]
scale_shape = self.up_gate_proj_scale_shape if weight_type == "gate_up" else self.down_proj_scale_shape
scale_dtype = self.default_dtype
# 2.create tmp tensors
weight = paddle.empty(weight_shape, dtype=weight_dtype)
scale = paddle.empty(scale_shape, dtype=scale_dtype)
# 3.quantize weight
for expert_id in range(layer.num_experts):
weight[expert_id], scale[expert_id] = weight_quantize(
getattr(layer, unquantized_weight_name)[expert_id], algo=self.moe_quant_type, arch=80, group_size=-1
)
free_tensor(getattr(layer, unquantized_weight_name))
# create weight
setattr(
layer,
weight_name,
layer.create_parameter(
shape=weight_shape,
dtype=weight_dtype,
default_initializer=paddle.nn.initializer.Constant(0),
),
)
# create scale
setattr(
layer,
scale_name,
layer.create_parameter(
shape=scale_shape,
dtype=scale_dtype,
default_initializer=paddle.nn.initializer.Constant(0),
),
)
getattr(layer, weight_name).copy_(weight, False)
getattr(layer, scale_name).copy_(scale, False)
def process_loaded_weights(self, layer: nn.Layer, state_dict):
"""
Paddle cutlass load weight process.
"""
up_gate_proj_weights, down_proj_weights, _, _ = layer.extract_moe_ffn_weights(state_dict)
self.check(layer, up_gate_proj_weights, down_proj_weights)
for idx, weight_tensor in enumerate([up_gate_proj_weights, down_proj_weights]):
weight_name = self.added_weight_attrs[idx]
scale_name = self.added_scale_attrs[idx]
weight_list = []
weight_scale_list = []
for i in range(layer.num_local_experts):
quant_weight, scale = weight_quantize(
weight_tensor[i], algo=self.moe_quant_type, arch=80, group_size=-1
)
quant_weight = paddle.transpose(quant_weight, [1, 0])
weight_list.append(quant_weight)
weight_scale_list.append(scale)
quanted_weight = paddle.stack(weight_list, axis=0)
getattr(layer, weight_name).set_value(quanted_weight)
quanted_weight_scale = paddle.stack(weight_scale_list, axis=0)
getattr(layer, scale_name).set_value(quanted_weight_scale)

View File

@@ -1,4 +1,3 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -12,12 +11,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
import paddle
from paddle import nn
import fastdeploy
from fastdeploy.distributed.communication import tensor_model_parallel_all_reduce
from fastdeploy.model_executor.layers.quantization.quant_base import QuantMethodBase
from fastdeploy.model_executor.ops.gpu import tritonmoe_preprocess
from fastdeploy.utils import ceil_div
@@ -153,7 +152,6 @@ class MetaxTritonWeightOnlyMoEMethod(QuantMethodBase):
Triton compute Fused MoE.
"""
token_num = x.shape[0]
top_k = layer.top_k
num_local_experts = layer.num_local_experts
top_k = layer.top_k
moe_intermediate_size = layer.moe_intermediate_size
@@ -172,21 +170,12 @@ class MetaxTritonWeightOnlyMoEMethod(QuantMethodBase):
dtype=x.dtype,
)
if self.quant_config is not None:
config = {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 8,
}
else:
config = {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 8,
}
config = {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 4,
}
sorted_token_ids, expert_ids, num_tokens_post_padded = tritonmoe_preprocess(
topk_ids, num_local_experts, config["BLOCK_SIZE_M"]
)
@@ -292,4 +281,6 @@ class MetaxTritonWeightOnlyMoEMethod(QuantMethodBase):
down_proj_out.reshape_([token_num, top_k, hidden_size])
out = down_proj_out.sum(axis=1)
if layer.tp_size > 1:
tensor_model_parallel_all_reduce(out)
return out

View File

@@ -1,5 +1,4 @@
"""
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -12,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
import triton
import triton.language as tl

View File

@@ -50,12 +50,7 @@ def get_moe_method():
from fastdeploy.model_executor.layers.backends import GCUFusedMoeMethod
return GCUFusedMoeMethod(None)
elif current_platform.is_maca():
from fastdeploy.model_executor.layers.backends import (
MetaxTritonWeightOnlyMoEMethod,
)
return MetaxTritonWeightOnlyMoEMethod(None)
elif current_platform.is_intel_hpu():
from fastdeploy.model_executor.layers.backends import HpuMoEMethod

View File

@@ -123,10 +123,18 @@ class WeightOnlyConfig(QuantConfigBase):
elif current_platform.is_maca():
if isinstance(layer, FusedMoE):
from fastdeploy.model_executor.layers.backends import (
MetaxCutlassWeightOnlyMoEMethod,
MetaxTritonWeightOnlyMoEMethod,
)
return MetaxTritonWeightOnlyMoEMethod(self)
if layer.use_method == "cutlass":
return MetaxCutlassWeightOnlyMoEMethod(self)
elif layer.use_method == "triton":
return MetaxTritonWeightOnlyMoEMethod(self)
else:
raise ValueError(f"Unsupported MOE backend {layer.use_method}")
else:
return GPUWeightOnlyLinearMethod(self)

View File

@@ -8,9 +8,9 @@ aiozmq
openai>=1.93.0
tqdm
pynvml
uvicorn
uvicorn==0.29.0
fastapi
paddleformers
paddleformers>=0.2
redis
etcd3
httpx
@@ -30,11 +30,12 @@ use-triton-in-paddle
crcmod
fastsafetensors==0.1.14
msgpack
modelscope
opentelemetry-api>=1.24.0
opentelemetry-sdk>=1.24.0
opentelemetry-instrumentation-redis
opentelemetry-instrumentation-mysql
opentelemetry-distro 
opentelemetry-distro
opentelemetry-exporter-otlp
opentelemetry-instrumentation-fastapi
partial_json_parser