[Metax] optimize cutlass moe and flash attention backend (#5128)

Neil Zhu
2025-11-20 16:12:35 +08:00
committed by GitHub
parent f1e36ff2f7
commit 0edda75a56
5 changed files with 469 additions and 161 deletions

View File

@@ -0,0 +1,291 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <cuda_runtime.h>
#include <paddle/extension.h>
#include <algorithm>
#include "helper.h"
#define THREADS_PER_BLOCK 128
template <typename T>
struct Converter;
template <>
struct Converter<__half> {
// __half -> float
__device__ static float to_float(__half val) { return __half2float(val); }
// float -> __half
__device__ static __half from_float(float val) {
return __float2half_rn(val);
}
// int -> __half
__device__ static __half from_int(int val) { return __int2half_rn(val); }
};
template <>
struct Converter<__nv_bfloat16> {
// __nv_bfloat16 -> float
__device__ static float to_float(__nv_bfloat16 val) {
return __bfloat162float(val);
}
// float -> __nv_bfloat16
__device__ static __nv_bfloat16 from_float(float val) {
return __float2bfloat16_rn(val);
}
// int -> __nv_bfloat16
__device__ static __nv_bfloat16 from_int(int val) {
return __int2bfloat16_rn(val);
}
};
template <typename T>
__device__ void RotateQKVec4(const T* qk_ptr,
const T* rot_cos_ptr,
const T* rot_sin_ptr,
const int head_num,
const int base_idx,
const int rot_base_idx,
T* out) {
using VecT = AlignedVector<T, 4>;
VecT qk_vec;
Load(qk_ptr + base_idx, &qk_vec);
VecT rot_half_vec = {-qk_vec[1], qk_vec[0], -qk_vec[3], qk_vec[2]};
VecT cos_vec, sin_vec;
Load(rot_cos_ptr + rot_base_idx, &cos_vec);
Load(rot_sin_ptr + rot_base_idx, &sin_vec);
#pragma unroll
for (int i = 0; i < 4; ++i) {
*(out + base_idx + i) =
qk_vec[i] * cos_vec[i] + rot_half_vec[i] * sin_vec[i];
}
}
template <typename T>
__device__ void RotateQKVec4(const T* qk_ptr,
const float* rot_cos_ptr,
const float* rot_sin_ptr,
const int head_num,
const int base_idx,
const int rot_base_idx,
T* out) {
using VecT = AlignedVector<T, 4>;
using VecF = AlignedVector<float, 4>;
auto to_float = [] __device__(T val) -> float {
return Converter<T>::to_float(val);
};
auto from_float = [] __device__(float val) -> T {
return Converter<T>::from_float(val);
};
VecT qk_vec;
Load(qk_ptr + base_idx, &qk_vec);
VecF rot_half_vec = {-to_float(qk_vec[1]),
to_float(qk_vec[0]),
-to_float(qk_vec[3]),
to_float(qk_vec[2])};
VecF cos_vec, sin_vec;
Load(rot_cos_ptr + rot_base_idx, &cos_vec);
Load(rot_sin_ptr + rot_base_idx, &sin_vec);
#pragma unroll
for (int i = 0; i < 4; ++i) {
*(out + base_idx + i) = from_float(to_float(qk_vec[i]) * cos_vec[i] +
rot_half_vec[i] * sin_vec[i]);
}
}
// qk and rope share the same dtype
template <typename T>
__global__ void DispatchApplyRopeVec4Kernel(const T* q,
const T* k,
const T* rot_cos,
const T* rot_sin,
const int q_num_elements,
const int k_num_elements,
const int q_head_num,
const int k_head_num,
const int head_dim,
T* q_out,
T* k_out) {
int idx = (blockIdx.x * blockDim.x + threadIdx.x) * 4;
int head_dim_idx = idx % head_dim;
if (idx < q_num_elements) {
int rot_idx = idx / (q_head_num * head_dim) * head_dim + head_dim_idx;
RotateQKVec4(q, rot_cos, rot_sin, q_head_num, idx, rot_idx, q_out);
}
if (idx < k_num_elements) {
int rot_idx = idx / (k_head_num * head_dim) * head_dim + head_dim_idx;
RotateQKVec4(k, rot_cos, rot_sin, k_head_num, idx, rot_idx, k_out);
}
}
// rope dtype is float32
template <typename T>
__global__ void DispatchApplyRopeVec4Kernel(const T* q,
const T* k,
const float* rot_cos,
const float* rot_sin,
const int q_num_elements,
const int k_num_elements,
const int q_head_num,
const int k_head_num,
const int head_dim,
T* q_out,
T* k_out) {
int idx = (blockIdx.x * blockDim.x + threadIdx.x) * 4;
int head_dim_idx = idx % head_dim;
if (idx < q_num_elements) {
int rot_idx = idx / (q_head_num * head_dim) * head_dim + head_dim_idx;
RotateQKVec4(q, rot_cos, rot_sin, q_head_num, idx, rot_idx, q_out);
}
if (idx < k_num_elements) {
int rot_idx = idx / (k_head_num * head_dim) * head_dim + head_dim_idx;
RotateQKVec4(k, rot_cos, rot_sin, k_head_num, idx, rot_idx, k_out);
}
}
template <paddle::DataType D>
void ApplyRopeKernel(const paddle::Tensor& q,
const paddle::Tensor& k,
const paddle::Tensor& rot_cos,
const paddle::Tensor& rot_sin,
paddle::Tensor& q_out,
paddle::Tensor& k_out) {
typedef PDTraits<D> traits_;
typedef typename traits_::DataType DataType_;
typedef typename traits_::data_t data_t;
const auto q_num_elements = q.numel();
const auto k_num_elements = k.numel();
const auto q_shape = q.shape();
const auto k_shape = k.shape();
const auto dims = q_shape.size();
const auto q_head_num = q_shape[dims - 2];
const auto k_head_num = k_shape[dims - 2];
const auto head_dim = q_shape.back();
int block_num =
(std::max(q_num_elements, k_num_elements) + (THREADS_PER_BLOCK * 4) - 1) /
(THREADS_PER_BLOCK * 4);
auto stream = q.stream();
if (q.dtype() == rot_cos.dtype()) {
DispatchApplyRopeVec4Kernel<DataType_>
<<<block_num, THREADS_PER_BLOCK, 0, stream>>>(
reinterpret_cast<const DataType_*>(q.data<data_t>()),
reinterpret_cast<const DataType_*>(k.data<data_t>()),
reinterpret_cast<const DataType_*>(rot_cos.data<data_t>()),
reinterpret_cast<const DataType_*>(rot_sin.data<data_t>()),
q_num_elements,
k_num_elements,
q_head_num,
k_head_num,
head_dim,
reinterpret_cast<DataType_*>(q_out.data<data_t>()),
reinterpret_cast<DataType_*>(k_out.data<data_t>()));
} else if (rot_cos.dtype() == paddle::DataType::FLOAT32) {
DispatchApplyRopeVec4Kernel<DataType_>
<<<block_num, THREADS_PER_BLOCK, 0, stream>>>(
reinterpret_cast<const DataType_*>(q.data<data_t>()),
reinterpret_cast<const DataType_*>(k.data<data_t>()),
reinterpret_cast<const float*>(rot_cos.data<float>()),
reinterpret_cast<const float*>(rot_sin.data<float>()),
q_num_elements,
k_num_elements,
q_head_num,
k_head_num,
head_dim,
reinterpret_cast<DataType_*>(q_out.data<data_t>()),
reinterpret_cast<DataType_*>(k_out.data<data_t>()));
} else {
PD_THROW("Unsupported qk dtype and rope dtype.");
}
}
std::vector<paddle::Tensor> ApplyRope(const paddle::Tensor& q,
const paddle::Tensor& k,
const paddle::Tensor& rot_cos,
const paddle::Tensor& rot_sin) {
auto q_shape = q.shape();
auto cos_shape = rot_cos.shape();
auto q_out = paddle::empty_like(q);
auto k_out = paddle::empty_like(k);
if (q.numel() == 0 || k.numel() == 0) {
return {q_out, k_out};
}
PADDLE_ENFORCE_EQ(
q_shape.back() % 2,
0,
"The last dimension (head_dim) of qk must be an even number "
"for RoPE, but got %d",
q_shape.back());
PADDLE_ENFORCE_EQ(q_shape.size(),
cos_shape.size(),
"The rank of cos must match the rank of q, "
"expected %d but got %d",
q_shape.size(),
cos_shape.size());
PADDLE_ENFORCE_EQ(q_shape.back(),
cos_shape.back(),
"The last dimension of cos must match the last dimension of q, "
"expected %d but got %d",
q_shape.back(),
cos_shape.back());
auto input_type = q.dtype();
switch (input_type) {
case paddle::DataType::BFLOAT16:
ApplyRopeKernel<paddle::DataType::BFLOAT16>(
q, k, rot_cos, rot_sin, q_out, k_out);
break;
case paddle::DataType::FLOAT16:
ApplyRopeKernel<paddle::DataType::FLOAT16>(
q, k, rot_cos, rot_sin, q_out, k_out);
break;
default:
PD_THROW("Only support qk dtype of BF16 and F16");
}
return {q_out, k_out};
}
std::vector<std::vector<int64_t>> ApplyRopeInferShape(
const std::vector<int64_t>& q_shape,
const std::vector<int64_t>& k_shape,
const std::vector<int64_t>& cos_shape,
const std::vector<int64_t>& sin_shape) {
return {q_shape, k_shape};
}
std::vector<paddle::DataType> ApplyRopeInferDtype(
const paddle::DataType& q_dtype,
const paddle::DataType& k_dtype,
const paddle::DataType& cos_dtype,
const paddle::DataType& sin_dtype) {
return {q_dtype, k_dtype};
}
PD_BUILD_OP(apply_rope)
.Inputs({"q", "k", "rot_cos", "rot_sin"})
.Outputs({"q_out", "k_out"})
.SetKernelFn(PD_KERNEL(ApplyRope))
.SetInferShapeFn(PD_INFER_SHAPE(ApplyRopeInferShape))
.SetInferDtypeFn(PD_INFER_DTYPE(ApplyRopeInferDtype));
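
For reference, a minimal NumPy sketch of the interleaved rotate-half RoPE that DispatchApplyRopeVec4Kernel applies four elements per thread; function and variable names here are illustrative and not part of the extension.

import numpy as np

def apply_rope_reference(qk, rot_cos, rot_sin):
    # qk: [..., head_dim] with head_dim even; adjacent pairs (x0, x1), (x2, x3), ...
    # rot_cos / rot_sin: broadcastable to qk and already repeat-interleaved to head_dim,
    # so cos[2i] == cos[2i + 1] (matching how the Python backend builds them).
    # The kernel computes out[i] = qk[i] * cos[i] + rot_half[i] * sin[i]
    # with rot_half = (-x1, x0, -x3, x2, ...).
    rot_half = np.empty_like(qk)
    rot_half[..., 0::2] = -qk[..., 1::2]
    rot_half[..., 1::2] = qk[..., 0::2]
    return qk * rot_cos + rot_half * rot_sin

# Toy shapes: 3 tokens, 2 heads, head_dim 8; cos/sin broadcast over the head axis.
qk = np.random.rand(3, 2, 8).astype(np.float32)
cos = np.repeat(np.random.rand(3, 1, 4).astype(np.float32), 2, axis=-1)
sin = np.repeat(np.random.rand(3, 1, 4).astype(np.float32), 2, axis=-1)
out = apply_rope_reference(qk, cos, sin)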

View File

@@ -17,15 +17,16 @@
#include "mctlassEx/mctlassEx.h"
template <typename ElementA, typename ElementB, typename ElementC>
void mc_grouped_gemm_basic_kernel(const ElementA *ptrA,
void mc_grouped_gemm_basic_kernel(const ElementA* ptrA,
mctlassExOrder_t majorA,
const ElementB *ptrB,
const ElementB* ptrB,
mctlassExOrder_t majorB,
const ElementA *ptrScale,
const ElementA *ptrBias,
ElementC *ptrC,
const ElementA* ptrScale,
const ElementA* ptrBias,
ElementC* ptrC,
mctlassExOrder_t majorC,
const int *ptrSegInd,
const int* ptrSegInd,
int* ptrMNumTilesInd,
int numExperts,
int m, // expanded_active_expert_rows
int n, // inter_dim
@@ -34,9 +35,6 @@ void mc_grouped_gemm_basic_kernel(const ElementA *ptrA,
mctlassExHandle_t handle;
mctlassExHandleCreate(&handle);
int *ptrMNumTilesInd;
mcMallocAsync((void **)&ptrMNumTilesInd, sizeof(int) * numExperts, stream);
mctlassExMatrixLayout_t matLayoutA;
mctlassExMatrixLayout_t matLayoutB;
mctlassExMatrixLayout_t matLayoutC;
@@ -170,7 +168,6 @@ void mc_grouped_gemm_basic_kernel(const ElementA *ptrA,
mctlassExMatrixLayoutDestroy(matLayoutC);
mctlassExContiguousGroupedDescDestroy(contiguous_group_desc);
mctlassExDestroyDesc(mctlass_desc);
mcFreeAsync(ptrMNumTilesInd, stream);
}
template <typename T, typename ElementA, typename ElementB, typename ElementC>
@@ -227,27 +224,27 @@ class McMoeHelper {
return total_ws_bytes;
}
void computeFFN(const paddle::Tensor *input,
const paddle::Tensor *gate_weight,
const paddle::Tensor *ffn1_weight,
const paddle::Tensor *ffn1_scale,
const paddle::Tensor *ffn1_bias,
const paddle::Tensor *ffn2_weight,
const paddle::Tensor *ffn2_scale,
const paddle::Tensor *ffn2_bias,
const paddle::Tensor *moe_token_type_ids,
void computeFFN(const paddle::Tensor* input,
const paddle::Tensor* gate_weight,
const paddle::Tensor* ffn1_weight,
const paddle::Tensor* ffn1_scale,
const paddle::Tensor* ffn1_bias,
const paddle::Tensor* ffn2_weight,
const paddle::Tensor* ffn2_scale,
const paddle::Tensor* ffn2_bias,
const paddle::Tensor* moe_token_type_ids,
const int moe_topk,
const bool group_moe,
const bool norm_topk_prob,
const float routed_scaling_factor,
const std::string moe_type,
paddle::Tensor *output) {
auto *input_activations = input->data<T>();
auto *gating_weights = gate_weight->data<float>();
const T *fc1_expert_biases = ffn1_bias ? ffn1_bias->data<T>() : nullptr;
const T *fc2_expert_biases = ffn2_bias ? ffn2_bias->data<T>() : nullptr;
paddle::Tensor* output) {
auto* input_activations = input->data<T>();
auto* gating_weights = gate_weight->data<float>();
const T* fc1_expert_biases = ffn1_bias ? ffn1_bias->data<T>() : nullptr;
const T* fc2_expert_biases = ffn2_bias ? ffn2_bias->data<T>() : nullptr;
auto *output_ = output->data<T>();
auto* output_ = output->data<T>();
auto stream = input->stream();
auto place = input->place();
auto input_type = input->dtype();
@@ -282,52 +279,52 @@ class McMoeHelper {
getWorkspaceSize<T>(num_rows, hidden_size, inter_size, num_experts, k);
// Pointers
int *expert_for_source_row;
int *source_rows_;
int *permuted_rows_;
int *permuted_experts_;
int *expanded_source_row_to_expanded_dest_row;
int* expert_for_source_row;
int* source_rows_;
int* permuted_rows_;
int* permuted_experts_;
int* expanded_source_row_to_expanded_dest_row;
T *permuted_data_;
int32_t *total_rows_before_expert_;
T *fc1_result_;
float *softmax_out_;
T* permuted_data_;
int32_t* total_rows_before_expert_;
T* fc1_result_;
float* softmax_out_;
paddle::Tensor ws_ptr_tensor =
GetEmptyTensor({bytes}, paddle::DataType::INT8, place);
int8_t *ws_ptr = ws_ptr_tensor.data<int8_t>();
int8_t* ws_ptr = ws_ptr_tensor.data<int8_t>();
const int64_t buf_size = AlignTo16(k * num_rows * hidden_size);
const int64_t interbuf_size = AlignTo16(k * num_rows * inter_size);
const int64_t padded_experts = AlignTo16(num_experts);
const int64_t num_moe_inputs = AlignTo16(k * num_rows);
expert_for_source_row = reinterpret_cast<int *>(ws_ptr);
expert_for_source_row = reinterpret_cast<int*>(ws_ptr);
source_rows_ = expert_for_source_row + num_moe_inputs;
permuted_rows_ = source_rows_ + num_moe_inputs;
permuted_experts_ = permuted_rows_ + num_moe_inputs;
expanded_source_row_to_expanded_dest_row =
permuted_experts_ + num_moe_inputs;
permuted_data_ = reinterpret_cast<T *>(
permuted_data_ = reinterpret_cast<T*>(
expanded_source_row_to_expanded_dest_row + num_moe_inputs);
total_rows_before_expert_ =
reinterpret_cast<int32_t *>(permuted_data_ + buf_size);
reinterpret_cast<int32_t*>(permuted_data_ + buf_size);
fc1_result_ =
reinterpret_cast<T *>(total_rows_before_expert_ + padded_experts);
reinterpret_cast<T*>(total_rows_before_expert_ + padded_experts);
const bool is_pow_2 =
(num_experts != 0) && ((num_experts & (num_experts - 1)) == 0);
if (!is_pow_2 || num_experts > 256) {
softmax_out_ = reinterpret_cast<float *>(fc1_result_ + interbuf_size);
softmax_out_ = reinterpret_cast<float*>(fc1_result_ + interbuf_size);
} else {
softmax_out_ = nullptr;
}
paddle::Tensor expert_scales_float_tensor =
GetEmptyTensor({num_rows, moe_topk}, paddle::DataType::FLOAT32, place);
float *expert_scales_float = expert_scales_float_tensor.data<float>();
float* expert_scales_float = expert_scales_float_tensor.data<float>();
float *softmax_max_prob = nullptr;
float* softmax_max_prob = nullptr;
if (group_moe) {
paddle::Tensor softmax_max_prob_tensor = GetEmptyTensor(
{num_rows, moe_topk}, paddle::DataType::FLOAT32, place);
@@ -338,16 +335,16 @@ class McMoeHelper {
paddle::Tensor fc1_out_tensor =
GetEmptyTensor({num_rows * k, inter_size}, input_type, place);
T *fc1_out = fc1_out_tensor.data<T>();
T* fc1_out = fc1_out_tensor.data<T>();
auto input_cast_tensor =
paddle::experimental::cast(*input, paddle::DataType::FLOAT32);
auto gate_tensor =
paddle::experimental::matmul(input_cast_tensor, *gate_weight);
float *gating_output = gate_tensor.data<float>();
float* gating_output = gate_tensor.data<float>();
if (moe_token_type_ids) {
auto *moe_token_type_ids_out = moe_token_type_ids->data<int>();
auto* moe_token_type_ids_out = moe_token_type_ids->data<int>();
moe_token_type_ids_kernelLauncher<float>(gating_output,
moe_token_type_ids_out,
num_rows,
@@ -403,17 +400,21 @@ class McMoeHelper {
mctlassExOrder_t row_major = mctlassExOrder_t::MCTLASS_EX_ORDER_ROW_MAJOR;
mctlassExOrder_t column_major =
mctlassExOrder_t::MCTLASS_EX_ORDER_COLUMN_MAJOR;
auto m_num_tile =
GetEmptyTensor({num_experts}, paddle::DataType::INT32, place);
int* m_num_tile_ptr = reinterpret_cast<int*>(m_num_tile.data<int>());
mc_grouped_gemm_basic_kernel<ElementA, ElementB, ElementC>(
reinterpret_cast<const ElementA *>(permuted_data_),
reinterpret_cast<const ElementA*>(permuted_data_),
row_major,
reinterpret_cast<const ElementB *>(ffn1_weight->data<ElementB>()),
reinterpret_cast<const ElementB*>(ffn1_weight->data<ElementB>()),
column_major,
reinterpret_cast<const ElementA *>(ffn1_scale->data<T>()),
reinterpret_cast<const ElementA *>(fc1_expert_biases),
reinterpret_cast<ElementC *>(fc1_out),
reinterpret_cast<const ElementA*>(ffn1_scale->data<T>()),
reinterpret_cast<const ElementA*>(fc1_expert_biases),
reinterpret_cast<ElementC*>(fc1_out),
row_major,
total_rows_before_expert_,
m_num_tile_ptr,
num_experts,
expanded_active_expert_rows,
inter_size,
@@ -427,18 +428,19 @@ class McMoeHelper {
paddle::Tensor fc2_output_tensor =
GetEmptyTensor({k * num_rows, hidden_size}, input_type, place);
T *fc2_result = fc2_output_tensor.data<T>();
T* fc2_result = fc2_output_tensor.data<T>();
mc_grouped_gemm_basic_kernel<ElementA, ElementB, ElementC>(
reinterpret_cast<const ElementA *>(act_out),
reinterpret_cast<const ElementA*>(act_out),
row_major,
reinterpret_cast<const ElementB *>(ffn2_weight->data<ElementB>()),
reinterpret_cast<const ElementB*>(ffn2_weight->data<ElementB>()),
column_major,
reinterpret_cast<const ElementA *>(ffn2_scale->data<T>()),
reinterpret_cast<const ElementA*>(ffn2_scale->data<T>()),
nullptr,
reinterpret_cast<ElementC *>(fc2_result),
reinterpret_cast<ElementC*>(fc2_result),
row_major,
total_rows_before_expert_,
m_num_tile_ptr,
num_experts,
expanded_active_expert_rows,
hidden_size,
@@ -449,7 +451,7 @@ class McMoeHelper {
fc2_result,
output_,
fc2_expert_biases,
reinterpret_cast<float *>(expert_scales_float),
reinterpret_cast<float*>(expert_scales_float),
expanded_source_row_to_expanded_dest_row,
expert_for_source_row,
num_rows,
@@ -465,7 +467,7 @@ class McMoeHelper {
fc1_out,
output_,
fc1_expert_biases, // fc2_expert_biases,
reinterpret_cast<float *>(expert_scales_float),
reinterpret_cast<float*>(expert_scales_float),
expanded_source_row_to_expanded_dest_row,
expert_for_source_row,
num_rows,
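
The functional change in this file is that the per-expert tile-count scratch buffer (ptrMNumTilesInd) is no longer allocated with mcMallocAsync and released with mcFreeAsync inside mc_grouped_gemm_basic_kernel on every call; computeFFN now creates it once as an INT32 tensor (m_num_tile) and passes the same pointer to both grouped GEMMs. A rough Python-level sketch of the pattern, purely illustrative (grouped_gemm is a stand-in, not an actual binding):

import paddle

def moe_ffn_sketch(permuted_data, ffn1_weight, ffn2_weight, num_experts, grouped_gemm):
    # One INT32 scratch buffer per call, instead of a mallocAsync/freeAsync pair
    # inside every grouped-GEMM invocation.
    m_num_tile = paddle.empty([num_experts], dtype="int32")
    fc1_out = grouped_gemm(permuted_data, ffn1_weight, m_num_tile)
    fc2_out = grouped_gemm(fc1_out, ffn2_weight, m_num_tile)  # same buffer, no extra alloc/free
    return fc2_out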

View File

@@ -21,20 +21,18 @@ template <paddle::DataType T,
typename ElementA,
typename ElementB,
typename ElementC>
void McMoeFFNKernel(const paddle::Tensor& permute_input,
void McMoeFFNKernel(paddle::Tensor& permute_input,
const paddle::Tensor& tokens_expert_prefix_sum,
const paddle::Tensor& ffn1_weight,
const paddle::Tensor& ffn2_weight,
const paddle::optional<paddle::Tensor>& ffn1_bias,
const paddle::optional<paddle::Tensor>& ffn1_scale,
const paddle::optional<paddle::Tensor>& ffn2_scale,
const std::string& quant_method,
paddle::Tensor ffn_out) {
const std::string& quant_method) {
typedef PDTraits<T> traits_;
typedef typename traits_::DataType DataType_;
typedef typename traits_::data_t data_t;
auto ffn_out_ptr = ffn_out.data<data_t>();
auto permuted_input_ptr = permute_input.data<data_t>();
auto place = permute_input.place();
auto input_type = permute_input.dtype();
@@ -54,6 +52,9 @@ void McMoeFFNKernel(const paddle::Tensor& permute_input,
mctlassExOrder_t row_major = mctlassExOrder_t::MCTLASS_EX_ORDER_ROW_MAJOR;
mctlassExOrder_t column_major =
mctlassExOrder_t::MCTLASS_EX_ORDER_COLUMN_MAJOR;
auto m_num_tile =
GetEmptyTensor({num_experts}, paddle::DataType::INT32, place);
int* m_num_tile_ptr = reinterpret_cast<int*>(m_num_tile.data<int>());
// ffn1
auto fc1_expert_biases =
@@ -72,6 +73,7 @@ void McMoeFFNKernel(const paddle::Tensor& permute_input,
reinterpret_cast<ElementC*>(fc1_out_ptr),
row_major,
tokens_expert_prefix_sum.data<int>(),
m_num_tile_ptr,
num_experts,
expanded_active_expert_rows,
inter_dim,
@@ -91,9 +93,10 @@ void McMoeFFNKernel(const paddle::Tensor& permute_input,
column_major,
reinterpret_cast<const ElementA*>(fc2_expert_scales),
nullptr,
reinterpret_cast<ElementC*>(ffn_out_ptr),
reinterpret_cast<ElementC*>(permuted_input_ptr),
row_major,
tokens_expert_prefix_sum.data<int>(),
m_num_tile_ptr,
num_experts,
expanded_active_expert_rows,
hidden_size,
@@ -102,7 +105,7 @@ void McMoeFFNKernel(const paddle::Tensor& permute_input,
}
std::vector<paddle::Tensor> MoeExpertFFN(
const paddle::Tensor& permute_input,
paddle::Tensor& permute_input,
const paddle::Tensor& tokens_expert_prefix_sum,
const paddle::Tensor& ffn1_weight,
const paddle::Tensor& ffn2_weight,
@@ -112,10 +115,9 @@ std::vector<paddle::Tensor> MoeExpertFFN(
const std::string& quant_method) {
assert(quant_method == "weight_only_int8");
const auto input_type = permute_input.dtype();
auto ffn_out = paddle::empty_like(permute_input);
if (permute_input.numel() == 0) {
return {ffn_out};
return {permute_input};
}
switch (input_type) {
@@ -130,8 +132,7 @@ std::vector<paddle::Tensor> MoeExpertFFN(
ffn1_bias,
ffn1_scale,
ffn2_scale,
quant_method,
ffn_out);
quant_method);
break;
// case paddle::DataType::FLOAT16:
// MoeFFNKernel<paddle::DataType::FLOAT16>(permute_input,
@@ -147,7 +148,7 @@ std::vector<paddle::Tensor> MoeExpertFFN(
default:
PD_THROW("Unsupported data type for MoeExpertFFN");
}
return {ffn_out};
return {permute_input};
}
std::vector<std::vector<int64_t>> MoeExpertFFNInferShape(

View File

@@ -629,6 +629,7 @@ elif paddle.device.is_compiled_with_custom_device("metax_gpu"):
"metax_ops/moe_ffn.cu",
"metax_ops/moe_reduce.cu",
"metax_ops/fused_moe.cu",
"metax_ops/apply_rope.cu",
]
sources += find_end_files("gpu_ops/speculate_decoding", ".cu")

View File

@@ -31,6 +31,7 @@ from fastdeploy.model_executor.layers.backends.metax.attention.flash_attention_i
flash_attn_kvcache_func,
flash_attn_unpadded_func,
)
from fastdeploy.model_executor.ops.gpu import apply_rope
@dataclass
@@ -50,12 +51,15 @@ class FlashAttentionMetadata(AttentionMetadata):
decoder_batch_ids: paddle.Tensor = None
decoder_tile_ids_per_batch: paddle.Tensor = None
decoder_num_blocks: paddle.Tensor = None
rotary_cos_prefill: paddle.Tensor = None
rotary_sin_prefill: paddle.Tensor = None
rotary_cos_decode: paddle.Tensor = None
rotary_sin_decode: paddle.Tensor = None
_dtype: paddle.dtype = paddle.bfloat16
encoder_max_partition_size: int = 32768
max_partition_size: int = 32768
block_tables: Optional[paddle.Tensor] = None
rotary_embs: Optional[paddle.Tensor] = None
attn_mask: Optional[paddle.Tensor] = None
encoder_block_shape_q: int = -1
decoder_block_shape_q: int = -1
@@ -87,7 +91,7 @@ class FlashAttentionBackend(AttentionBackend):
FlashAttentionBackend __init__
"""
super().__init__()
self.attention_metadata: FlashAttentionMetadata = None
self.attention_metadata: FlashAttentionMetadata = FlashAttentionMetadata()
self.record_block_table_metadata = {}
self.block_size: int = fd_config.cache_config.block_size
self.max_seq_len: int = fd_config.model_config.max_model_len
@@ -122,6 +126,16 @@ class FlashAttentionBackend(AttentionBackend):
self.rank, self.device_id = init_rank_and_device_id(fd_config)
self.enable_mm = fd_config.model_config.enable_mm
max_num_seqs = fd_config.scheduler_config.max_num_seqs
if self.enable_mm:
self.attention_metadata.rotary_cos_decode = paddle.empty(
shape=[max_num_seqs, 1, 1, self.head_dim],
dtype="float32",
)
self.attention_metadata.rotary_sin_decode = paddle.empty(
shape=[max_num_seqs, 1, 1, self.head_dim],
dtype="float32",
)
def init_attention_metadata(self, forward_meta: ForwardMeta):
"""Initialize attntion metadata hence all layers in the forward pass can reuse it."""
@@ -218,6 +232,72 @@ class FlashAttentionBackend(AttentionBackend):
self.batch_ids_decode = paddle.to_tensor(self.decode_info_dict["batch_ids"])
self.seq_lens_dec = forward_meta.seq_lens_decoder[self.batch_ids_decode, 0]
self.block_table_dec = forward_meta.block_tables[self.batch_ids_decode, :]
# update prefilling rope
self.update_rotary_embs_prefill(forward_meta)
# update decoding rope
self.update_rotary_embs_decoder(forward_meta)
def update_rotary_embs_prefill(self, forward_meta: ForwardMeta):
if self.batch_ids_prefill.shape[0] == 0 or forward_meta.rotary_embs is None:
return
batch_ids = self.batch_ids_prefill
seq_lens_this_time = forward_meta.seq_lens_this_time[batch_ids]
cached_kv_lens = forward_meta.seq_lens_decoder[batch_ids, 0]
all_indices = []
for i in range(len(batch_ids)):
start_pos = cached_kv_lens[i]
seq_len_i = seq_lens_this_time[i]
if seq_len_i > 0:
indices_i = paddle.arange(start_pos, start_pos + seq_len_i, dtype="int64")
all_indices.append(indices_i)
if not all_indices:
return
all_indices = paddle.concat(all_indices) # [token_num]
if self.enable_mm:
gather_nd_indices = paddle.stack(
[ # [token_num, 2]
paddle.repeat_interleave(batch_ids, repeats=seq_lens_this_time, axis=0),
all_indices,
],
axis=1,
)
gathered_embs = paddle.gather_nd(
forward_meta.rotary_embs.squeeze([2]).transpose(
[0, 2, 1, 3, 4]
), # [B, 2, 1, S, 1, D // 2] -> [B, S, 2, 1, D // 2]
gather_nd_indices,
) # [token_num, 2, 1, D // 2]
rot_cos = gathered_embs[:, 0, :, :] # [token_num, 1, D // 2]
rot_sin = gathered_embs[:, 1, :, :]
else:
gathered_embs = paddle.gather(
forward_meta.rotary_embs.squeeze([1]), all_indices, axis=1 # [2, 1, S, 1, D // 2] -> [2, S, 1, D // 2]
) # [2, token_num, 1, D // 2]
rot_cos = gathered_embs[0, :, :, :] # [token_num, 1, D // 2]
rot_sin = gathered_embs[1, :, :, :]
self.attention_metadata.rotary_cos_prefill = paddle.repeat_interleave(
rot_cos, repeats=2, axis=-1
) # [token_num, 1, D]
self.attention_metadata.rotary_sin_prefill = paddle.repeat_interleave(rot_sin, repeats=2, axis=-1)
def update_rotary_embs_decoder(self, forward_meta: ForwardMeta):
if not self.enable_mm: # only initialize once for text-only model
if self.attention_metadata.rotary_cos_decode is None or self.attention_metadata.rotary_sin_decode is None:
self.attention_metadata.rotary_cos_decode = forward_meta.rotary_embs[0, 0, :, 0, :].astype(self.dtype)
self.attention_metadata.rotary_sin_decode = forward_meta.rotary_embs[1, 0, :, 0, :].astype(self.dtype)
elif self.batch_ids_decode.shape[0] > 0:
bs = self.batch_ids_decode.shape[0]
index = paddle.concat(
[self.batch_ids_decode.view([-1, 1]), self.seq_lens_dec.to("int64").view([-1, 1])], axis=1
)
rot_cos = paddle.gather_nd(forward_meta.rotary_embs[:, 0, 0, :, 0, :], index).view([bs, 1, 1, -1])
rot_sin = paddle.gather_nd(forward_meta.rotary_embs[:, 1, 0, :, 0, :], index).view([bs, 1, 1, -1])
self.attention_metadata.rotary_cos_decode[:bs].copy_(paddle.repeat_interleave(rot_cos, repeats=2, axis=-1))
self.attention_metadata.rotary_sin_decode[:bs].copy_(paddle.repeat_interleave(rot_sin, repeats=2, axis=-1))
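
Both rope-update helpers above rely on paddle.gather_nd to select one (batch, position) row of the cos/sin tables per token. A toy example of that indexing with reduced shapes; the [B, S, D // 2] table stands in for the rotary_embs slices used above:

import paddle

cos_table = paddle.arange(2 * 5 * 4, dtype="float32").reshape([2, 5, 4])  # [B=2, S=5, D//2=4]
batch_ids = paddle.to_tensor([0, 1], dtype="int64")   # request index per decoding sequence
positions = paddle.to_tensor([3, 1], dtype="int64")   # current position (cached kv length)
index = paddle.stack([batch_ids, positions], axis=1)  # [bs, 2] -> selects rows (0, 3) and (1, 1)
rot_cos = paddle.gather_nd(cos_table, index)          # [bs, D//2]
# The backend then reshapes to [bs, 1, 1, -1] and repeat_interleaves by 2 to reach head_dim.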
def get_attntion_meta(self) -> AttentionMetadata:
"""get_attntion_meta"""
@@ -231,22 +311,19 @@ class FlashAttentionBackend(AttentionBackend):
"""
Calculate kv cache shape
"""
key_cache_shape = value_cache_shape = [max_num_blocks, self.block_size, self.kv_num_heads, self.head_dim]
if kv_cache_quant_type is not None and kv_cache_quant_type == "int4_zp":
return (
key_cache_shape = value_cache_shape = [
max_num_blocks,
self.kv_num_heads,
self.block_size,
self.head_dim // 2,
)
else:
return (
max_num_blocks,
self.block_size,
self.kv_num_heads,
self.head_dim,
)
]
def apply_rope(self, qk, cos, sin):
return key_cache_shape, value_cache_shape
def apply_rope_native(self, qk, cos, sin):
rotate_half = paddle.reshape(
paddle.stack([-qk[..., 1::2], qk[..., 0::2]], axis=-1),
paddle.shape(qk),
@@ -254,66 +331,6 @@ class FlashAttentionBackend(AttentionBackend):
out = paddle.add(paddle.multiply(qk, cos), paddle.multiply(rotate_half, sin))
return paddle.cast(out, qk.dtype)
def apply_rope_dec(self, q, k, forward_meta: ForwardMeta):
batch_ids = self.batch_ids_decode
bs = batch_ids.shape[0]
index = paddle.concat([batch_ids.view([-1, 1]), self.seq_lens_dec.to("int64").view([-1, 1])], axis=1)
rot_cos = paddle.gather_nd(forward_meta.rotary_embs[:, 0, 0, :, 0, :], index).view([bs, 1, 1, -1])
rot_sin = paddle.gather_nd(forward_meta.rotary_embs[:, 1, 0, :, 0, :], index).view([bs, 1, 1, -1])
rot_cos = paddle.repeat_interleave(rot_cos, repeats=2, axis=-1)
rot_sin = paddle.repeat_interleave(rot_sin, repeats=2, axis=-1)
q = self.apply_rope(q, rot_cos, rot_sin)
k = self.apply_rope(k, rot_cos, rot_sin)
return q, k
def get_splited_qkv(
self,
qkv: paddle.Tensor,
forward_meta: ForwardMeta,
cu_seqlens_q: paddle.Tensor,
batch_ids=None,
):
qkv = qkv.view([-1, self.num_heads + self.kv_num_heads * 2, self.head_dim])
q, k, v = qkv.split(num_or_sections=[self.num_heads, self.kv_num_heads, self.kv_num_heads], axis=-2)
for idx in range(len(cu_seqlens_q) - 1):
batch_idx = batch_ids[idx]
seq_len_i = forward_meta.seq_lens_this_time[batch_idx]
if seq_len_i == 0:
continue
cached_kv_len = forward_meta.seq_lens_decoder[batch_idx][0]
cu_seq_start_q = cu_seqlens_q[idx]
cu_seq_end_q = cu_seqlens_q[idx + 1]
if forward_meta.rotary_embs is not None:
if self.enable_mm: # vl: forward_meta.rotary_embs is [2, 1, S, 1, D // 2]
cos = paddle.repeat_interleave(
forward_meta.rotary_embs[batch_idx, 0, 0, cached_kv_len : cached_kv_len + seq_len_i, :, :],
repeats=2,
axis=-1,
) # [Si, D]
sin = paddle.repeat_interleave(
forward_meta.rotary_embs[batch_idx, 1, 0, cached_kv_len : cached_kv_len + seq_len_i, :, :],
repeats=2,
axis=-1,
) # [Si, D]
else: # text: forward_meta.rotary_embs is [2, 1, S, 1, D // 2]
cos = paddle.repeat_interleave(
forward_meta.rotary_embs[0, 0, cached_kv_len : cached_kv_len + seq_len_i, :, :],
repeats=2,
axis=-1,
) # [Si, D]
sin = paddle.repeat_interleave(
forward_meta.rotary_embs[1, 0, cached_kv_len : cached_kv_len + seq_len_i, :, :],
repeats=2,
axis=-1,
) # [Si, D]
q[cu_seq_start_q:cu_seq_end_q] = self.apply_rope(q[cu_seq_start_q:cu_seq_end_q], cos, sin)
k[cu_seq_start_q:cu_seq_end_q] = self.apply_rope(k[cu_seq_start_q:cu_seq_end_q], cos, sin)
return q, k, v
def split_pd_qkv(self, qkv):
for ids, reverse_ids in zip(self.prefill_info_dict["id_group"], self.prefill_info_dict["reverse_id_group"]):
@@ -386,18 +403,14 @@ class FlashAttentionBackend(AttentionBackend):
tensor_start = tensor_end
def forward_prefill(self, prefill_qkv, layer_id, k_cache_id, v_cache_id, forward_meta: ForwardMeta):
prefill_q, prefill_k, prefill_v = self.get_splited_qkv(
prefill_qkv,
forward_meta,
self.prefill_info_dict["cu_seqlens_q"],
batch_ids=self.batch_ids_prefill,
)
qkv = prefill_qkv.view([-1, self.num_heads + self.kv_num_heads * 2, self.head_dim])
q, k, v = qkv.split(num_or_sections=[self.num_heads, self.kv_num_heads, self.kv_num_heads], axis=-2)
q, k = apply_rope(q, k, self.attention_metadata.rotary_cos_prefill, self.attention_metadata.rotary_sin_prefill)
prefill_out = flash_attn_unpadded_func(
prefill_q,
prefill_k,
prefill_v,
q,
k,
v,
self.prefill_info_dict["cu_seqlens_q"],
self.prefill_info_dict["cu_seqlens_q"],
max_seqlen_q=self.max_seq_len,
@@ -406,9 +419,7 @@ class FlashAttentionBackend(AttentionBackend):
causal=self.causal,
)[0]
self.update_kv_cache(
prefill_k, prefill_v, k_cache_id, v_cache_id, layer_id, forward_meta, self.batch_ids_prefill
)
self.update_kv_cache(k, v, k_cache_id, v_cache_id, layer_id, forward_meta, self.batch_ids_prefill)
return prefill_out
@@ -417,12 +428,14 @@ class FlashAttentionBackend(AttentionBackend):
q, k, v = qkv.split(num_or_sections=[self.num_heads, self.kv_num_heads, self.kv_num_heads], axis=-2)
if self.enable_mm: # vl
q, k = self.apply_rope_dec(q, k, forward_meta)
rot_cos = None
rot_sin = None
else: # text
rot_cos = forward_meta.rotary_embs[0, 0, :, 0, :].astype(q.dtype)
rot_sin = forward_meta.rotary_embs[1, 0, :, 0, :].astype(q.dtype)
q, k = apply_rope(
q, k, self.attention_metadata.rotary_cos_decode, self.attention_metadata.rotary_sin_decode
)
rotary_cos = None
rotary_sin = None
else:
rotary_cos = self.attention_metadata.rotary_cos_decode
rotary_sin = self.attention_metadata.rotary_sin_decode
decode_out = flash_attn_kvcache_func(
q,
@@ -432,8 +445,8 @@ class FlashAttentionBackend(AttentionBackend):
self.block_table_dec,
k,
v,
rotary_cos=rot_cos,
rotary_sin=rot_sin,
rotary_cos=rotary_cos,
rotary_sin=rotary_sin,
causal=self.causal,
is_rotary_interleaved=True,
)[0].squeeze(1)
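
Taken together, the fused op registered in apply_rope.cu has a small contract that this backend relies on: q and k must be bfloat16 or float16 with an even head_dim as the last axis; rot_cos and rot_sin must match q's rank and last dimension and be either the same dtype as q/k or float32; and the op returns the two rotated tensors. A condensed usage sketch with illustrative shapes:

from fastdeploy.model_executor.ops.gpu import apply_rope  # imported at the top of this backend

# q:  [..., q_num_heads, head_dim]   (bfloat16 / float16, head_dim even)
# k:  [..., kv_num_heads, head_dim]
# rot_cos / rot_sin: same rank and last dim as q, one row per token,
# e.g. [token_num, 1, head_dim], in q's dtype or float32
q_out, k_out = apply_rope(q, k, rot_cos, rot_sin)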