Mirror of https://github.com/PaddlePaddle/FastDeploy.git, synced 2025-12-24 13:28:13 +08:00

[Metax] optimize cutlass moe and flash attention backend (#5128)

custom_ops/metax_ops/apply_rope.cu (new file, 291 lines)
@@ -0,0 +1,291 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <cuda_runtime.h>
#include <paddle/extension.h>
#include <algorithm>
#include "helper.h"

#define THREADS_PER_BLOCK 128

template <typename T>
struct Converter;

template <>
struct Converter<__half> {
  // __half -> float
  __device__ static float to_float(__half val) { return __half2float(val); }
  // float -> __half
  __device__ static __half from_float(float val) {
    return __float2half_rn(val);
  }
  // int -> __half
  __device__ static __half from_int(int val) { return __int2half_rn(val); }
};

template <>
struct Converter<__nv_bfloat16> {
  // __nv_bfloat16 -> float
  __device__ static float to_float(__nv_bfloat16 val) {
    return __bfloat162float(val);
  }
  // float -> __nv_bfloat16
  __device__ static __nv_bfloat16 from_float(float val) {
    return __float2bfloat16_rn(val);
  }
  // int -> __nv_bfloat16
  __device__ static __nv_bfloat16 from_int(int val) {
    return __int2bfloat16_rn(val);
  }
};

template <typename T>
__device__ void RotateQKVec4(const T* qk_ptr,
                             const T* rot_cos_ptr,
                             const T* rot_sin_ptr,
                             const int head_num,
                             const int base_idx,
                             const int rot_base_idx,
                             T* out) {
  using VecT = AlignedVector<T, 4>;

  VecT qk_vec;
  Load(qk_ptr + base_idx, &qk_vec);
  VecT rot_half_vec = {-qk_vec[1], qk_vec[0], -qk_vec[3], qk_vec[2]};
  VecT cos_vec, sin_vec;
  Load(rot_cos_ptr + rot_base_idx, &cos_vec);
  Load(rot_sin_ptr + rot_base_idx, &sin_vec);
#pragma unroll
  for (int i = 0; i < 4; ++i) {
    *(out + base_idx + i) =
        qk_vec[i] * cos_vec[i] + rot_half_vec[i] * sin_vec[i];
  }
}

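// For reference: for each adjacent pair (x0, x1), the fused multiply-add
// above computes the interleaved RoPE rotation
//   out0 = x0 * cos - x1 * sin
//   out1 = x1 * cos + x0 * sin
// since rot_half_vec carries (-x1, x0, -x3, x2) lane-wise.
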
template <typename T>
__device__ void RotateQKVec4(const T* qk_ptr,
                             const float* rot_cos_ptr,
                             const float* rot_sin_ptr,
                             const int head_num,
                             const int base_idx,
                             const int rot_base_idx,
                             T* out) {
  using VecT = AlignedVector<T, 4>;
  using VecF = AlignedVector<float, 4>;
  auto to_float = [] __device__(T val) -> float {
    return Converter<T>::to_float(val);
  };
  auto from_float = [] __device__(float val) -> T {
    return Converter<T>::from_float(val);
  };

  VecT qk_vec;
  Load(qk_ptr + base_idx, &qk_vec);
  VecF rot_half_vec = {-to_float(qk_vec[1]),
                       to_float(qk_vec[0]),
                       -to_float(qk_vec[3]),
                       to_float(qk_vec[2])};
  VecF cos_vec, sin_vec;
  Load(rot_cos_ptr + rot_base_idx, &cos_vec);
  Load(rot_sin_ptr + rot_base_idx, &sin_vec);
#pragma unroll
  for (int i = 0; i < 4; ++i) {
    *(out + base_idx + i) = from_float(to_float(qk_vec[i]) * cos_vec[i] +
                                       rot_half_vec[i] * sin_vec[i]);
  }
}

// qk and rope share the same dtype
template <typename T>
__global__ void DispatchApplyRopeVec4Kernel(const T* q,
                                            const T* k,
                                            const T* rot_cos,
                                            const T* rot_sin,
                                            const int q_num_elements,
                                            const int k_num_elements,
                                            const int q_head_num,
                                            const int k_head_num,
                                            const int head_dim,
                                            T* q_out,
                                            T* k_out) {
  int idx = (blockIdx.x * blockDim.x + threadIdx.x) * 4;
  int head_dim_idx = idx % head_dim;

  if (idx < q_num_elements) {
    int rot_idx = idx / (q_head_num * head_dim) * head_dim + head_dim_idx;
    RotateQKVec4(q, rot_cos, rot_sin, q_head_num, idx, rot_idx, q_out);
  }

  if (idx < k_num_elements) {
    int rot_idx = idx / (k_head_num * head_dim) * head_dim + head_dim_idx;
    RotateQKVec4(k, rot_cos, rot_sin, k_head_num, idx, rot_idx, k_out);
  }
}

// rope dtype is float32
template <typename T>
__global__ void DispatchApplyRopeVec4Kernel(const T* q,
                                            const T* k,
                                            const float* rot_cos,
                                            const float* rot_sin,
                                            const int q_num_elements,
                                            const int k_num_elements,
                                            const int q_head_num,
                                            const int k_head_num,
                                            const int head_dim,
                                            T* q_out,
                                            T* k_out) {
  int idx = (blockIdx.x * blockDim.x + threadIdx.x) * 4;
  int head_dim_idx = idx % head_dim;

  if (idx < q_num_elements) {
    int rot_idx = idx / (q_head_num * head_dim) * head_dim + head_dim_idx;
    RotateQKVec4(q, rot_cos, rot_sin, q_head_num, idx, rot_idx, q_out);
  }

  if (idx < k_num_elements) {
    int rot_idx = idx / (k_head_num * head_dim) * head_dim + head_dim_idx;
    RotateQKVec4(k, rot_cos, rot_sin, k_head_num, idx, rot_idx, k_out);
  }
}

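// Indexing note: idx walks the flattened [token_num, head_num, head_dim]
// layout in steps of 4, while rot_idx = token_idx * head_dim + (idx % head_dim),
// so the cos/sin tables hold one row per token that is shared across all
// attention heads.
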
template <paddle::DataType D>
void ApplyRopeKernel(const paddle::Tensor& q,
                     const paddle::Tensor& k,
                     const paddle::Tensor& rot_cos,
                     const paddle::Tensor& rot_sin,
                     paddle::Tensor& q_out,
                     paddle::Tensor& k_out) {
  typedef PDTraits<D> traits_;
  typedef typename traits_::DataType DataType_;
  typedef typename traits_::data_t data_t;

  const auto q_num_elements = q.numel();
  const auto k_num_elements = k.numel();
  const auto q_shape = q.shape();
  const auto k_shape = k.shape();
  const auto dims = q_shape.size();
  const auto q_head_num = q_shape[dims - 2];
  const auto k_head_num = k_shape[dims - 2];
  const auto head_dim = q_shape.back();
  int block_num =
      (std::max(q_num_elements, k_num_elements) + (THREADS_PER_BLOCK * 4) - 1) /
      (THREADS_PER_BLOCK * 4);
  auto stream = q.stream();

  if (q.dtype() == rot_cos.dtype()) {
    DispatchApplyRopeVec4Kernel<DataType_>
        <<<block_num, THREADS_PER_BLOCK, 0, stream>>>(
            reinterpret_cast<const DataType_*>(q.data<data_t>()),
            reinterpret_cast<const DataType_*>(k.data<data_t>()),
            reinterpret_cast<const DataType_*>(rot_cos.data<data_t>()),
            reinterpret_cast<const DataType_*>(rot_sin.data<data_t>()),
            q_num_elements,
            k_num_elements,
            q_head_num,
            k_head_num,
            head_dim,
            reinterpret_cast<DataType_*>(q_out.data<data_t>()),
            reinterpret_cast<DataType_*>(k_out.data<data_t>()));
  } else if (rot_cos.dtype() == paddle::DataType::FLOAT32) {
    DispatchApplyRopeVec4Kernel<DataType_>
        <<<block_num, THREADS_PER_BLOCK, 0, stream>>>(
            reinterpret_cast<const DataType_*>(q.data<data_t>()),
            reinterpret_cast<const DataType_*>(k.data<data_t>()),
            rot_cos.data<float>(),
            rot_sin.data<float>(),
            q_num_elements,
            k_num_elements,
            q_head_num,
            k_head_num,
            head_dim,
            reinterpret_cast<DataType_*>(q_out.data<data_t>()),
            reinterpret_cast<DataType_*>(k_out.data<data_t>()));
  } else {
    PD_THROW("Unsupported qk dtype and rope dtype.");
  }
}

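// Launch-geometry note: each thread rotates one 4-element vector, so a
// 128-thread block covers 512 elements and block_num above is
// ceil(max(q_num_elements, k_num_elements) / 512).
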
std::vector<paddle::Tensor> ApplyRope(const paddle::Tensor& q,
                                      const paddle::Tensor& k,
                                      const paddle::Tensor& rot_cos,
                                      const paddle::Tensor& rot_sin) {
  auto q_shape = q.shape();
  auto cos_shape = rot_cos.shape();

  auto q_out = paddle::empty_like(q);
  auto k_out = paddle::empty_like(k);

  if (q.numel() == 0 || k.numel() == 0) {
    return {q_out, k_out};
  }

  PADDLE_ENFORCE_EQ(
      q_shape.back() % 2,
      0,
      "The last dimension (head_dim) of qk must be an even number "
      "for RoPE, but got %d",
      q_shape.back());
  PADDLE_ENFORCE_EQ(q_shape.size(),
                    cos_shape.size(),
                    "The shape size of cos mismatches the shape size of q, "
                    "expect %d but got %d",
                    q_shape.size(),
                    cos_shape.size());
  PADDLE_ENFORCE_EQ(q_shape.back(),
                    cos_shape.back(),
                    "The shape.back() of cos mismatches the shape.back() of q, "
                    "expect %d but got %d",
                    q_shape.back(),
                    cos_shape.back());

  auto input_type = q.dtype();
  switch (input_type) {
    case paddle::DataType::BFLOAT16:
      ApplyRopeKernel<paddle::DataType::BFLOAT16>(
          q, k, rot_cos, rot_sin, q_out, k_out);
      break;
    case paddle::DataType::FLOAT16:
      ApplyRopeKernel<paddle::DataType::FLOAT16>(
          q, k, rot_cos, rot_sin, q_out, k_out);
      break;
    default:
      PD_THROW("Only support qk dtype of BF16 and F16");
  }

  return {q_out, k_out};
}

std::vector<std::vector<int64_t>> ApplyRopeInferShape(
    const std::vector<int64_t>& q_shape,
    const std::vector<int64_t>& k_shape,
    const std::vector<int64_t>& cos_shape,
    const std::vector<int64_t>& sin_shape) {
  // The op declares two outputs (q_out, k_out), so return one shape each.
  return {q_shape, k_shape};
}

std::vector<paddle::DataType> ApplyRopeInferDtype(
    const paddle::DataType& q_dtype,
    const paddle::DataType& k_dtype,
    const paddle::DataType& cos_dtype,
    const paddle::DataType& sin_dtype) {
  return {q_dtype, k_dtype};
}

PD_BUILD_OP(apply_rope)
    .Inputs({"q", "k", "rot_cos", "rot_sin"})
    .Outputs({"q_out", "k_out"})
    .SetKernelFn(PD_KERNEL(ApplyRope))
    .SetInferShapeFn(PD_INFER_SHAPE(ApplyRopeInferShape))
    .SetInferDtypeFn(PD_INFER_DTYPE(ApplyRopeInferDtype));
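For orientation, a minimal usage sketch of the op registered above (a sketch, not part of the commit): per the enforce checks, q/k are [token_num, head_num, head_dim] with an even head_dim (and in practice a multiple of 4, given the vectorized loads), and cos/sin carry one row per token with the same rank and last dimension as q. The float32 cos/sin path is the one the flash attention backend exercises below.

# Hedged example; tensor sizes here are illustrative assumptions.
import paddle
from fastdeploy.model_executor.ops.gpu import apply_rope

token_num, q_heads, k_heads, head_dim = 8, 16, 2, 128
q = paddle.rand([token_num, q_heads, head_dim]).astype("bfloat16")
k = paddle.rand([token_num, k_heads, head_dim]).astype("bfloat16")
# One cos/sin row per token, broadcast across heads.
cos = paddle.rand([token_num, 1, head_dim], dtype="float32")
sin = paddle.rand([token_num, 1, head_dim], dtype="float32")

q_out, k_out = apply_rope(q, k, cos, sin)
print(q_out.shape, k_out.shape)  # [8, 16, 128] [8, 2, 128]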
@@ -17,15 +17,16 @@
 #include "mctlassEx/mctlassEx.h"

 template <typename ElementA, typename ElementB, typename ElementC>
-void mc_grouped_gemm_basic_kernel(const ElementA *ptrA,
+void mc_grouped_gemm_basic_kernel(const ElementA* ptrA,
                                   mctlassExOrder_t majorA,
-                                  const ElementB *ptrB,
+                                  const ElementB* ptrB,
                                   mctlassExOrder_t majorB,
-                                  const ElementA *ptrScale,
-                                  const ElementA *ptrBias,
-                                  ElementC *ptrC,
+                                  const ElementA* ptrScale,
+                                  const ElementA* ptrBias,
+                                  ElementC* ptrC,
                                   mctlassExOrder_t majorC,
-                                  const int *ptrSegInd,
+                                  const int* ptrSegInd,
+                                  int* ptrMNumTilesInd,
                                   int numExperts,
                                   int m,  // expanded_active_expert_rows
                                   int n,  // inter_dim
@@ -34,9 +35,6 @@ void mc_grouped_gemm_basic_kernel(const ElementA *ptrA,
   mctlassExHandle_t handle;
   mctlassExHandleCreate(&handle);

-  int *ptrMNumTilesInd;
-  mcMallocAsync((void **)&ptrMNumTilesInd, sizeof(int) * numExperts, stream);
-
   mctlassExMatrixLayout_t matLayoutA;
   mctlassExMatrixLayout_t matLayoutB;
   mctlassExMatrixLayout_t matLayoutC;
@@ -170,7 +168,6 @@ void mc_grouped_gemm_basic_kernel(const ElementA *ptrA,
   mctlassExMatrixLayoutDestroy(matLayoutC);
   mctlassExContiguousGroupedDescDestroy(contiguous_group_desc);
   mctlassExDestroyDesc(mctlass_desc);
-  mcFreeAsync(ptrMNumTilesInd, stream);
 }

 template <typename T, typename ElementA, typename ElementB, typename ElementC>
@@ -227,27 +224,27 @@ class McMoeHelper {
     return total_ws_bytes;
   }

-  void computeFFN(const paddle::Tensor *input,
-                  const paddle::Tensor *gate_weight,
-                  const paddle::Tensor *ffn1_weight,
-                  const paddle::Tensor *ffn1_scale,
-                  const paddle::Tensor *ffn1_bias,
-                  const paddle::Tensor *ffn2_weight,
-                  const paddle::Tensor *ffn2_scale,
-                  const paddle::Tensor *ffn2_bias,
-                  const paddle::Tensor *moe_token_type_ids,
+  void computeFFN(const paddle::Tensor* input,
+                  const paddle::Tensor* gate_weight,
+                  const paddle::Tensor* ffn1_weight,
+                  const paddle::Tensor* ffn1_scale,
+                  const paddle::Tensor* ffn1_bias,
+                  const paddle::Tensor* ffn2_weight,
+                  const paddle::Tensor* ffn2_scale,
+                  const paddle::Tensor* ffn2_bias,
+                  const paddle::Tensor* moe_token_type_ids,
                   const int moe_topk,
                   const bool group_moe,
                   const bool norm_topk_prob,
                   const float routed_scaling_factor,
                   const std::string moe_type,
-                  paddle::Tensor *output) {
-    auto *input_activations = input->data<T>();
-    auto *gating_weights = gate_weight->data<float>();
-    const T *fc1_expert_biases = ffn1_bias ? ffn1_bias->data<T>() : nullptr;
-    const T *fc2_expert_biases = ffn2_bias ? ffn2_bias->data<T>() : nullptr;
+                  paddle::Tensor* output) {
+    auto* input_activations = input->data<T>();
+    auto* gating_weights = gate_weight->data<float>();
+    const T* fc1_expert_biases = ffn1_bias ? ffn1_bias->data<T>() : nullptr;
+    const T* fc2_expert_biases = ffn2_bias ? ffn2_bias->data<T>() : nullptr;

-    auto *output_ = output->data<T>();
+    auto* output_ = output->data<T>();
     auto stream = input->stream();
     auto place = input->place();
     auto input_type = input->dtype();
@@ -282,52 +279,52 @@ class McMoeHelper {
     getWorkspaceSize<T>(num_rows, hidden_size, inter_size, num_experts, k);

     // Pointers
-    int *expert_for_source_row;
-    int *source_rows_;
-    int *permuted_rows_;
-    int *permuted_experts_;
-    int *expanded_source_row_to_expanded_dest_row;
+    int* expert_for_source_row;
+    int* source_rows_;
+    int* permuted_rows_;
+    int* permuted_experts_;
+    int* expanded_source_row_to_expanded_dest_row;

-    T *permuted_data_;
-    int32_t *total_rows_before_expert_;
-    T *fc1_result_;
-    float *softmax_out_;
+    T* permuted_data_;
+    int32_t* total_rows_before_expert_;
+    T* fc1_result_;
+    float* softmax_out_;

     paddle::Tensor ws_ptr_tensor =
         GetEmptyTensor({bytes}, paddle::DataType::INT8, place);
-    int8_t *ws_ptr = ws_ptr_tensor.data<int8_t>();
+    int8_t* ws_ptr = ws_ptr_tensor.data<int8_t>();

     const int64_t buf_size = AlignTo16(k * num_rows * hidden_size);
     const int64_t interbuf_size = AlignTo16(k * num_rows * inter_size);
     const int64_t padded_experts = AlignTo16(num_experts);
     const int64_t num_moe_inputs = AlignTo16(k * num_rows);

-    expert_for_source_row = reinterpret_cast<int *>(ws_ptr);
+    expert_for_source_row = reinterpret_cast<int*>(ws_ptr);
     source_rows_ = expert_for_source_row + num_moe_inputs;
     permuted_rows_ = source_rows_ + num_moe_inputs;
     permuted_experts_ = permuted_rows_ + num_moe_inputs;
     expanded_source_row_to_expanded_dest_row =
         permuted_experts_ + num_moe_inputs;
-    permuted_data_ = reinterpret_cast<T *>(
+    permuted_data_ = reinterpret_cast<T*>(
         expanded_source_row_to_expanded_dest_row + num_moe_inputs);
     total_rows_before_expert_ =
-        reinterpret_cast<int32_t *>(permuted_data_ + buf_size);
+        reinterpret_cast<int32_t*>(permuted_data_ + buf_size);
     fc1_result_ =
-        reinterpret_cast<T *>(total_rows_before_expert_ + padded_experts);
+        reinterpret_cast<T*>(total_rows_before_expert_ + padded_experts);

     const bool is_pow_2 =
         (num_experts != 0) && ((num_experts & (num_experts - 1)) == 0);
     if (!is_pow_2 || num_experts > 256) {
-      softmax_out_ = reinterpret_cast<float *>(fc1_result_ + interbuf_size);
+      softmax_out_ = reinterpret_cast<float*>(fc1_result_ + interbuf_size);
     } else {
       softmax_out_ = nullptr;
     }

     paddle::Tensor expert_scales_float_tensor =
         GetEmptyTensor({num_rows, moe_topk}, paddle::DataType::FLOAT32, place);
-    float *expert_scales_float = expert_scales_float_tensor.data<float>();
+    float* expert_scales_float = expert_scales_float_tensor.data<float>();

-    float *softmax_max_prob = nullptr;
+    float* softmax_max_prob = nullptr;
     if (group_moe) {
       paddle::Tensor softmax_max_prob_tensor = GetEmptyTensor(
           {num_rows, moe_topk}, paddle::DataType::FLOAT32, place);
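Aside (a sketch, not part of the diff): the softmax workspace rule in the hunk above keys off whether num_experts is a power of two no larger than 256. In Python terms:

# Mirrors the is_pow_2 / softmax_out_ logic above (assumed semantics: the
# fused softmax path only covers power-of-two expert counts <= 256).
def needs_softmax_workspace(num_experts: int) -> bool:
    is_pow_2 = num_experts != 0 and (num_experts & (num_experts - 1)) == 0
    return not is_pow_2 or num_experts > 256

assert not needs_softmax_workspace(64)   # fused path, no extra buffer
assert needs_softmax_workspace(96)       # 96 is not a power of two
assert needs_softmax_workspace(512)      # too many experts for the fused path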
@@ -338,16 +335,16 @@ class McMoeHelper {

     paddle::Tensor fc1_out_tensor =
         GetEmptyTensor({num_rows * k, inter_size}, input_type, place);
-    T *fc1_out = fc1_out_tensor.data<T>();
+    T* fc1_out = fc1_out_tensor.data<T>();

     auto input_cast_tensor =
         paddle::experimental::cast(*input, paddle::DataType::FLOAT32);
     auto gate_tensor =
         paddle::experimental::matmul(input_cast_tensor, *gate_weight);
-    float *gating_output = gate_tensor.data<float>();
+    float* gating_output = gate_tensor.data<float>();

     if (moe_token_type_ids) {
-      auto *moe_token_type_ids_out = moe_token_type_ids->data<int>();
+      auto* moe_token_type_ids_out = moe_token_type_ids->data<int>();
       moe_token_type_ids_kernelLauncher<float>(gating_output,
                                                moe_token_type_ids_out,
                                                num_rows,
@@ -403,17 +400,21 @@ class McMoeHelper {
     mctlassExOrder_t row_major = mctlassExOrder_t::MCTLASS_EX_ORDER_ROW_MAJOR;
     mctlassExOrder_t column_major =
         mctlassExOrder_t::MCTLASS_EX_ORDER_COLUMN_MAJOR;
+    auto m_num_tile =
+        GetEmptyTensor({num_experts}, paddle::DataType::INT32, place);
+    int* m_num_tile_ptr = reinterpret_cast<int*>(m_num_tile.data<int>());

     mc_grouped_gemm_basic_kernel<ElementA, ElementB, ElementC>(
-        reinterpret_cast<const ElementA *>(permuted_data_),
+        reinterpret_cast<const ElementA*>(permuted_data_),
         row_major,
-        reinterpret_cast<const ElementB *>(ffn1_weight->data<ElementB>()),
+        reinterpret_cast<const ElementB*>(ffn1_weight->data<ElementB>()),
         column_major,
-        reinterpret_cast<const ElementA *>(ffn1_scale->data<T>()),
-        reinterpret_cast<const ElementA *>(fc1_expert_biases),
-        reinterpret_cast<ElementC *>(fc1_out),
+        reinterpret_cast<const ElementA*>(ffn1_scale->data<T>()),
+        reinterpret_cast<const ElementA*>(fc1_expert_biases),
+        reinterpret_cast<ElementC*>(fc1_out),
         row_major,
         total_rows_before_expert_,
+        m_num_tile_ptr,
         num_experts,
         expanded_active_expert_rows,
         inter_size,
@@ -427,18 +428,19 @@ class McMoeHelper {

     paddle::Tensor fc2_output_tensor =
         GetEmptyTensor({k * num_rows, hidden_size}, input_type, place);
-    T *fc2_result = fc2_output_tensor.data<T>();
+    T* fc2_result = fc2_output_tensor.data<T>();

     mc_grouped_gemm_basic_kernel<ElementA, ElementB, ElementC>(
-        reinterpret_cast<const ElementA *>(act_out),
+        reinterpret_cast<const ElementA*>(act_out),
         row_major,
-        reinterpret_cast<const ElementB *>(ffn2_weight->data<ElementB>()),
+        reinterpret_cast<const ElementB*>(ffn2_weight->data<ElementB>()),
         column_major,
-        reinterpret_cast<const ElementA *>(ffn2_scale->data<T>()),
+        reinterpret_cast<const ElementA*>(ffn2_scale->data<T>()),
         nullptr,
-        reinterpret_cast<ElementC *>(fc2_result),
+        reinterpret_cast<ElementC*>(fc2_result),
         row_major,
         total_rows_before_expert_,
+        m_num_tile_ptr,
         num_experts,
         expanded_active_expert_rows,
         hidden_size,
@@ -449,7 +451,7 @@ class McMoeHelper {
         fc2_result,
         output_,
         fc2_expert_biases,
-        reinterpret_cast<float *>(expert_scales_float),
+        reinterpret_cast<float*>(expert_scales_float),
         expanded_source_row_to_expanded_dest_row,
         expert_for_source_row,
         num_rows,
@@ -465,7 +467,7 @@ class McMoeHelper {
         fc1_out,
         output_,
         fc1_expert_biases,  // fc2_expert_biases,
-        reinterpret_cast<float *>(expert_scales_float),
+        reinterpret_cast<float*>(expert_scales_float),
         expanded_source_row_to_expanded_dest_row,
         expert_for_source_row,
         num_rows,

@@ -21,20 +21,18 @@ template <paddle::DataType T,
           typename ElementA,
           typename ElementB,
           typename ElementC>
-void McMoeFFNKernel(const paddle::Tensor& permute_input,
+void McMoeFFNKernel(paddle::Tensor& permute_input,
                     const paddle::Tensor& tokens_expert_prefix_sum,
                     const paddle::Tensor& ffn1_weight,
                     const paddle::Tensor& ffn2_weight,
                     const paddle::optional<paddle::Tensor>& ffn1_bias,
                     const paddle::optional<paddle::Tensor>& ffn1_scale,
                     const paddle::optional<paddle::Tensor>& ffn2_scale,
-                    const std::string& quant_method,
-                    paddle::Tensor ffn_out) {
+                    const std::string& quant_method) {
   typedef PDTraits<T> traits_;
   typedef typename traits_::DataType DataType_;
   typedef typename traits_::data_t data_t;

-  auto ffn_out_ptr = ffn_out.data<data_t>();
   auto permuted_input_ptr = permute_input.data<data_t>();
   auto place = permute_input.place();
   auto input_type = permute_input.dtype();
@@ -54,6 +52,9 @@ void McMoeFFNKernel(const paddle::Tensor& permute_input,
   mctlassExOrder_t row_major = mctlassExOrder_t::MCTLASS_EX_ORDER_ROW_MAJOR;
   mctlassExOrder_t column_major =
       mctlassExOrder_t::MCTLASS_EX_ORDER_COLUMN_MAJOR;
+  auto m_num_tile =
+      GetEmptyTensor({num_experts}, paddle::DataType::INT32, place);
+  int* m_num_tile_ptr = reinterpret_cast<int*>(m_num_tile.data<int>());

   // ffn1
   auto fc1_expert_biases =
@@ -72,6 +73,7 @@ void McMoeFFNKernel(const paddle::Tensor& permute_input,
       reinterpret_cast<ElementC*>(fc1_out_ptr),
       row_major,
       tokens_expert_prefix_sum.data<int>(),
+      m_num_tile_ptr,
       num_experts,
       expanded_active_expert_rows,
       inter_dim,
@@ -91,9 +93,10 @@ void McMoeFFNKernel(const paddle::Tensor& permute_input,
       column_major,
       reinterpret_cast<const ElementA*>(fc2_expert_scales),
       nullptr,
-      reinterpret_cast<ElementC*>(ffn_out_ptr),
+      reinterpret_cast<ElementC*>(permuted_input_ptr),
       row_major,
       tokens_expert_prefix_sum.data<int>(),
+      m_num_tile_ptr,
      num_experts,
       expanded_active_expert_rows,
       hidden_size,
@@ -102,7 +105,7 @@ void McMoeFFNKernel(const paddle::Tensor& permute_input,
 }

 std::vector<paddle::Tensor> MoeExpertFFN(
-    const paddle::Tensor& permute_input,
+    paddle::Tensor& permute_input,
     const paddle::Tensor& tokens_expert_prefix_sum,
     const paddle::Tensor& ffn1_weight,
     const paddle::Tensor& ffn2_weight,
@@ -112,10 +115,9 @@ std::vector<paddle::Tensor> MoeExpertFFN(
     const std::string& quant_method) {
   assert(quant_method == "weight_only_int8");
   const auto input_type = permute_input.dtype();
-  auto ffn_out = paddle::empty_like(permute_input);

   if (permute_input.numel() == 0) {
-    return {ffn_out};
+    return {permute_input};
   }

   switch (input_type) {
@@ -130,8 +132,7 @@ std::vector<paddle::Tensor> MoeExpertFFN(
           ffn1_bias,
           ffn1_scale,
           ffn2_scale,
-          quant_method,
-          ffn_out);
+          quant_method);
       break;
     // case paddle::DataType::FLOAT16:
    //   MoeFFNKernel<paddle::DataType::FLOAT16>(permute_input,
@@ -147,7 +148,7 @@ std::vector<paddle::Tensor> MoeExpertFFN(
     default:
       PD_THROW("Unsupported data type for MoeExpertFFN");
   }
-  return {ffn_out};
+  return {permute_input};
 }

 std::vector<std::vector<int64_t>> MoeExpertFFNInferShape(

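Reading of the change above (a hedged interpretation, not stated in the commit): the second grouped GEMM now writes its result back into permute_input's own buffer and the op returns that tensor, so the separate ffn_out allocation disappears. A rough estimate of the memory saved per call, assuming bf16 activations:

# Illustrative arithmetic only; sizes are assumptions, not from the commit.
def ffn_out_bytes_saved(expanded_rows: int, hidden_size: int, elem_bytes: int = 2) -> int:
    """Bytes of the [expanded_rows, hidden_size] ffn_out tensor no longer allocated."""
    return expanded_rows * hidden_size * elem_bytes

print(ffn_out_bytes_saved(16384, 4096) / 2**20)  # 128.0 MiB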
@@ -629,6 +629,7 @@ elif paddle.device.is_compiled_with_custom_device("metax_gpu"):
         "metax_ops/moe_ffn.cu",
         "metax_ops/moe_reduce.cu",
         "metax_ops/fused_moe.cu",
+        "metax_ops/apply_rope.cu",
     ]

     sources += find_end_files("gpu_ops/speculate_decoding", ".cu")
@@ -31,6 +31,7 @@ from fastdeploy.model_executor.layers.backends.metax.attention.flash_attention_i
     flash_attn_kvcache_func,
     flash_attn_unpadded_func,
 )
+from fastdeploy.model_executor.ops.gpu import apply_rope


 @dataclass
@@ -50,12 +51,15 @@ class FlashAttentionMetadata(AttentionMetadata):
     decoder_batch_ids: paddle.Tensor = None
     decoder_tile_ids_per_batch: paddle.Tensor = None
     decoder_num_blocks: paddle.Tensor = None
-
+    rotary_cos_prefill: paddle.Tensor = None
+    rotary_sin_prefill: paddle.Tensor = None
+    rotary_cos_decode: paddle.Tensor = None
+    rotary_sin_decode: paddle.Tensor = None
     _dtype: paddle.dtype = paddle.bfloat16
     encoder_max_partition_size: int = 32768
     max_partition_size: int = 32768
     block_tables: Optional[paddle.Tensor] = None
     rotary_embs: Optional[paddle.Tensor] = None
     attn_mask: Optional[paddle.Tensor] = None
     encoder_block_shape_q: int = -1
     decoder_block_shape_q: int = -1
@@ -87,7 +91,7 @@ class FlashAttentionBackend(AttentionBackend):
         FlashAttentionBackend __init__
         """
         super().__init__()
-        self.attention_metadata: FlashAttentionMetadata = None
+        self.attention_metadata: FlashAttentionMetadata = FlashAttentionMetadata()
         self.record_block_table_metadata = {}
         self.block_size: int = fd_config.cache_config.block_size
         self.max_seq_len: int = fd_config.model_config.max_model_len
@@ -122,6 +126,16 @@ class FlashAttentionBackend(AttentionBackend):

         self.rank, self.device_id = init_rank_and_device_id(fd_config)
         self.enable_mm = fd_config.model_config.enable_mm
+        max_num_seqs = fd_config.scheduler_config.max_num_seqs
+        if self.enable_mm:
+            self.attention_metadata.rotary_cos_decode = paddle.empty(
+                shape=[max_num_seqs, 1, 1, self.head_dim],
+                dtype="float32",
+            )
+            self.attention_metadata.rotary_sin_decode = paddle.empty(
+                shape=[max_num_seqs, 1, 1, self.head_dim],
+                dtype="float32",
+            )

     def init_attention_metadata(self, forward_meta: ForwardMeta):
         """Initialize attention metadata so that all layers in the forward pass can reuse it."""
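A note on the buffers just added (a hedged reading of the hunk above): for multimodal models the decode-path cos/sin are preallocated once at the worst-case [max_num_seqs, 1, 1, head_dim] in float32, and update_rotary_embs_decoder later fills only the first bs rows in place, avoiding a per-step allocation. The pattern, in miniature:

# Sketch with assumed sizes; mirrors the preallocate-then-slice pattern above.
import paddle

max_num_seqs, head_dim = 128, 64
rotary_cos_decode = paddle.empty([max_num_seqs, 1, 1, head_dim], dtype="float32")

bs = 4  # decode batch size this step
step_cos = paddle.ones([bs, 1, 1, head_dim], dtype="float32")
rotary_cos_decode[:bs].copy_(step_cos)  # in-place update, no new allocation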
@@ -218,6 +232,72 @@ class FlashAttentionBackend(AttentionBackend):
         self.batch_ids_decode = paddle.to_tensor(self.decode_info_dict["batch_ids"])
         self.seq_lens_dec = forward_meta.seq_lens_decoder[self.batch_ids_decode, 0]
         self.block_table_dec = forward_meta.block_tables[self.batch_ids_decode, :]
+        # update prefilling rope
+        self.update_rotary_embs_prefill(forward_meta)
+        # update decoding rope
+        self.update_rotary_embs_decoder(forward_meta)
+
+    def update_rotary_embs_prefill(self, forward_meta: ForwardMeta):
+        if self.batch_ids_prefill.shape[0] == 0 or forward_meta.rotary_embs is None:
+            return
+
+        batch_ids = self.batch_ids_prefill
+        seq_lens_this_time = forward_meta.seq_lens_this_time[batch_ids]
+        cached_kv_lens = forward_meta.seq_lens_decoder[batch_ids, 0]
+
+        all_indices = []
+        for i in range(len(batch_ids)):
+            start_pos = cached_kv_lens[i]
+            seq_len_i = seq_lens_this_time[i]
+            if seq_len_i > 0:
+                indices_i = paddle.arange(start_pos, start_pos + seq_len_i, dtype="int64")
+                all_indices.append(indices_i)
+        if not all_indices:
+            return
+
+        all_indices = paddle.concat(all_indices)  # [token_num]
+        if self.enable_mm:
+            gather_nd_indices = paddle.stack(
+                [  # [token_num, 2]
+                    paddle.repeat_interleave(batch_ids, repeats=seq_lens_this_time, axis=0),
+                    all_indices,
+                ],
+                axis=1,
+            )
+            gathered_embs = paddle.gather_nd(
+                forward_meta.rotary_embs.squeeze([2]).transpose(
+                    [0, 2, 1, 3, 4]
+                ),  # [B, 2, 1, S, 1, D // 2] -> [B, S, 2, 1, D // 2]
+                gather_nd_indices,
+            )  # [token_num, 2, 1, D // 2]
+            rot_cos = gathered_embs[:, 0, :, :]  # [token_num, 1, D // 2]
+            rot_sin = gathered_embs[:, 1, :, :]
+        else:
+            gathered_embs = paddle.gather(
+                forward_meta.rotary_embs.squeeze([1]), all_indices, axis=1  # [2, 1, S, 1, D // 2] -> [2, S, 1, D // 2]
+            )  # [2, token_num, 1, D // 2]
+            rot_cos = gathered_embs[0, :, :, :]  # [token_num, 1, D // 2]
+            rot_sin = gathered_embs[1, :, :, :]
+
+        self.attention_metadata.rotary_cos_prefill = paddle.repeat_interleave(
+            rot_cos, repeats=2, axis=-1
+        )  # [token_num, 1, D]
+        self.attention_metadata.rotary_sin_prefill = paddle.repeat_interleave(rot_sin, repeats=2, axis=-1)
+
+    def update_rotary_embs_decoder(self, forward_meta: ForwardMeta):
+        if not self.enable_mm:  # only initialize once for text-only model
+            if self.attention_metadata.rotary_cos_decode is None or self.attention_metadata.rotary_sin_decode is None:
+                self.attention_metadata.rotary_cos_decode = forward_meta.rotary_embs[0, 0, :, 0, :].astype(self.dtype)
+                self.attention_metadata.rotary_sin_decode = forward_meta.rotary_embs[1, 0, :, 0, :].astype(self.dtype)
+        elif self.batch_ids_decode.shape[0] > 0:
+            bs = self.batch_ids_decode.shape[0]
+            index = paddle.concat(
+                [self.batch_ids_decode.view([-1, 1]), self.seq_lens_dec.to("int64").view([-1, 1])], axis=1
+            )
+            rot_cos = paddle.gather_nd(forward_meta.rotary_embs[:, 0, 0, :, 0, :], index).view([bs, 1, 1, -1])
+            rot_sin = paddle.gather_nd(forward_meta.rotary_embs[:, 1, 0, :, 0, :], index).view([bs, 1, 1, -1])
+            self.attention_metadata.rotary_cos_decode[:bs].copy_(paddle.repeat_interleave(rot_cos, repeats=2, axis=-1))
+            self.attention_metadata.rotary_sin_decode[:bs].copy_(paddle.repeat_interleave(rot_sin, repeats=2, axis=-1))

     def get_attntion_meta(self) -> AttentionMetadata:
         """get_attntion_meta"""
@@ -231,22 +311,19 @@ class FlashAttentionBackend(AttentionBackend):
         """
         Calculate kv cache shape
         """
+        key_cache_shape = value_cache_shape = [max_num_blocks, self.block_size, self.kv_num_heads, self.head_dim]
+
         if kv_cache_quant_type is not None and kv_cache_quant_type == "int4_zp":
-            return (
+            key_cache_shape = value_cache_shape = [
                 max_num_blocks,
                 self.kv_num_heads,
                 self.block_size,
                 self.head_dim // 2,
-            )
-        else:
-            return (
-                max_num_blocks,
-                self.block_size,
-                self.kv_num_heads,
-                self.head_dim,
-            )
+            ]
+
+        return key_cache_shape, value_cache_shape

-    def apply_rope(self, qk, cos, sin):
+    def apply_rope_native(self, qk, cos, sin):
         rotate_half = paddle.reshape(
             paddle.stack([-qk[..., 1::2], qk[..., 0::2]], axis=-1),
             paddle.shape(qk),
@@ -254,66 +331,6 @@ class FlashAttentionBackend(AttentionBackend):
         out = paddle.add(paddle.multiply(qk, cos), paddle.multiply(rotate_half, sin))
         return paddle.cast(out, qk.dtype)

-    def apply_rope_dec(self, q, k, forward_meta: ForwardMeta):
-        batch_ids = self.batch_ids_decode
-        bs = batch_ids.shape[0]
-        index = paddle.concat([batch_ids.view([-1, 1]), self.seq_lens_dec.to("int64").view([-1, 1])], axis=1)
-        rot_cos = paddle.gather_nd(forward_meta.rotary_embs[:, 0, 0, :, 0, :], index).view([bs, 1, 1, -1])
-        rot_sin = paddle.gather_nd(forward_meta.rotary_embs[:, 1, 0, :, 0, :], index).view([bs, 1, 1, -1])
-        rot_cos = paddle.repeat_interleave(rot_cos, repeats=2, axis=-1)
-        rot_sin = paddle.repeat_interleave(rot_sin, repeats=2, axis=-1)
-        q = self.apply_rope(q, rot_cos, rot_sin)
-        k = self.apply_rope(k, rot_cos, rot_sin)
-
-        return q, k
-
-    def get_splited_qkv(
-        self,
-        qkv: paddle.Tensor,
-        forward_meta: ForwardMeta,
-        cu_seqlens_q: paddle.Tensor,
-        batch_ids=None,
-    ):
-        qkv = qkv.view([-1, self.num_heads + self.kv_num_heads * 2, self.head_dim])
-        q, k, v = qkv.split(num_or_sections=[self.num_heads, self.kv_num_heads, self.kv_num_heads], axis=-2)
-
-        for idx in range(len(cu_seqlens_q) - 1):
-            batch_idx = batch_ids[idx]
-            seq_len_i = forward_meta.seq_lens_this_time[batch_idx]
-            if seq_len_i == 0:
-                continue
-            cached_kv_len = forward_meta.seq_lens_decoder[batch_idx][0]
-            cu_seq_start_q = cu_seqlens_q[idx]
-            cu_seq_end_q = cu_seqlens_q[idx + 1]
-
-            if forward_meta.rotary_embs is not None:
-                if self.enable_mm:  # vl: forward_meta.rotary_embs is [2, 1, S, 1, D // 2]
-                    cos = paddle.repeat_interleave(
-                        forward_meta.rotary_embs[batch_idx, 0, 0, cached_kv_len : cached_kv_len + seq_len_i, :, :],
-                        repeats=2,
-                        axis=-1,
-                    )  # [Si, D]
-                    sin = paddle.repeat_interleave(
-                        forward_meta.rotary_embs[batch_idx, 1, 0, cached_kv_len : cached_kv_len + seq_len_i, :, :],
-                        repeats=2,
-                        axis=-1,
-                    )  # [Si, D]
-                else:  # text: forward_meta.rotary_embs is [2, 1, S, 1, D // 2]
-                    cos = paddle.repeat_interleave(
-                        forward_meta.rotary_embs[0, 0, cached_kv_len : cached_kv_len + seq_len_i, :, :],
-                        repeats=2,
-                        axis=-1,
-                    )  # [Si, D]
-                    sin = paddle.repeat_interleave(
-                        forward_meta.rotary_embs[1, 0, cached_kv_len : cached_kv_len + seq_len_i, :, :],
-                        repeats=2,
-                        axis=-1,
-                    )  # [Si, D]
-            q[cu_seq_start_q:cu_seq_end_q] = self.apply_rope(q[cu_seq_start_q:cu_seq_end_q], cos, sin)
-            k[cu_seq_start_q:cu_seq_end_q] = self.apply_rope(k[cu_seq_start_q:cu_seq_end_q], cos, sin)
-
-        return q, k, v
-
     def split_pd_qkv(self, qkv):

         for ids, reverse_ids in zip(self.prefill_info_dict["id_group"], self.prefill_info_dict["reverse_id_group"]):
@@ -386,18 +403,14 @@ class FlashAttentionBackend(AttentionBackend):
             tensor_start = tensor_end

     def forward_prefill(self, prefill_qkv, layer_id, k_cache_id, v_cache_id, forward_meta: ForwardMeta):
-
-        prefill_q, prefill_k, prefill_v = self.get_splited_qkv(
-            prefill_qkv,
-            forward_meta,
-            self.prefill_info_dict["cu_seqlens_q"],
-            batch_ids=self.batch_ids_prefill,
-        )
+        qkv = prefill_qkv.view([-1, self.num_heads + self.kv_num_heads * 2, self.head_dim])
+        q, k, v = qkv.split(num_or_sections=[self.num_heads, self.kv_num_heads, self.kv_num_heads], axis=-2)
+        q, k = apply_rope(q, k, self.attention_metadata.rotary_cos_prefill, self.attention_metadata.rotary_sin_prefill)

         prefill_out = flash_attn_unpadded_func(
-            prefill_q,
-            prefill_k,
-            prefill_v,
+            q,
+            k,
+            v,
             self.prefill_info_dict["cu_seqlens_q"],
             self.prefill_info_dict["cu_seqlens_q"],
             max_seqlen_q=self.max_seq_len,
@@ -406,9 +419,7 @@ class FlashAttentionBackend(AttentionBackend):
             causal=self.causal,
         )[0]

-        self.update_kv_cache(
-            prefill_k, prefill_v, k_cache_id, v_cache_id, layer_id, forward_meta, self.batch_ids_prefill
-        )
+        self.update_kv_cache(k, v, k_cache_id, v_cache_id, layer_id, forward_meta, self.batch_ids_prefill)

         return prefill_out

@@ -417,12 +428,14 @@ class FlashAttentionBackend(AttentionBackend):
         q, k, v = qkv.split(num_or_sections=[self.num_heads, self.kv_num_heads, self.kv_num_heads], axis=-2)

         if self.enable_mm:  # vl
-            q, k = self.apply_rope_dec(q, k, forward_meta)
-            rot_cos = None
-            rot_sin = None
-        else:  # text
-            rot_cos = forward_meta.rotary_embs[0, 0, :, 0, :].astype(q.dtype)
-            rot_sin = forward_meta.rotary_embs[1, 0, :, 0, :].astype(q.dtype)
+            q, k = apply_rope(
+                q, k, self.attention_metadata.rotary_cos_decode, self.attention_metadata.rotary_sin_decode
+            )
+            rotary_cos = None
+            rotary_sin = None
+        else:
+            rotary_cos = self.attention_metadata.rotary_cos_decode
+            rotary_sin = self.attention_metadata.rotary_sin_decode

         decode_out = flash_attn_kvcache_func(
             q,
@@ -432,8 +445,8 @@ class FlashAttentionBackend(AttentionBackend):
             self.block_table_dec,
             k,
             v,
-            rotary_cos=rot_cos,
-            rotary_sin=rot_sin,
+            rotary_cos=rotary_cos,
+            rotary_sin=rotary_sin,
             causal=self.causal,
             is_rotary_interleaved=True,
         )[0].squeeze(1)