diff --git a/custom_ops/metax_ops/apply_rope.cu b/custom_ops/metax_ops/apply_rope.cu new file mode 100644 index 000000000..4e820e425 --- /dev/null +++ b/custom_ops/metax_ops/apply_rope.cu @@ -0,0 +1,291 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include "helper.h" + +#define THREADS_PER_BLOCK 128 + +template +struct Converter; + +template <> +struct Converter<__half> { + // __half -> float + __device__ static float to_float(__half val) { return __half2float(val); } + // float -> __half + __device__ static __half from_float(float val) { + return __float2half_rn(val); + } + // int -> __half + __device__ static __half from_int(float val) { return __int2half_rn(val); } +}; + +template <> +struct Converter<__nv_bfloat16> { + // __nv_bfloat16 -> float + __device__ static float to_float(__nv_bfloat16 val) { + return __bfloat162float(val); + } + // float -> __nv_bfloat16 + __device__ static __nv_bfloat16 from_float(float val) { + return __float2bfloat16_rn(val); + } + // int -> __nv_bfloat16 + __device__ static __nv_bfloat16 from_int(int val) { + return __int2bfloat16_rn(val); + } +}; + +template +__device__ void RotateQKVec4(const T* qk_ptr, + const T* rot_cos_ptr, + const T* rot_sin_ptr, + const int head_num, + const int base_idx, + const int rot_base_idx, + T* out) { + using VecT = AlignedVector; + + VecT qk_vec; + Load(qk_ptr + base_idx, &qk_vec); + VecT rot_half_vec = {-qk_vec[1], qk_vec[0], -qk_vec[3], qk_vec[2]}; + VecT cos_vec, sin_vec; + Load(rot_cos_ptr + rot_base_idx, &cos_vec); + Load(rot_sin_ptr + rot_base_idx, &sin_vec); +#pragma unroll + for (int i = 0; i < 4; ++i) { + *(out + base_idx + i) = + qk_vec[i] * cos_vec[i] + rot_half_vec[i] * sin_vec[i]; + } +} + +template +__device__ void RotateQKVec4(const T* qk_ptr, + const float* rot_cos_ptr, + const float* rot_sin_ptr, + const int head_num, + const int base_idx, + const int rot_base_idx, + T* out) { + using VecT = AlignedVector; + using VecF = AlignedVector; + auto to_float = [] __device__(T val) -> float { + return Converter::to_float(val); + }; + auto from_float = [] __device__(float val) -> T { + return Converter::from_float(val); + }; + + VecT qk_vec; + Load(qk_ptr + base_idx, &qk_vec); + VecF rot_half_vec = {-to_float(qk_vec[1]), + to_float(qk_vec[0]), + -to_float(qk_vec[3]), + to_float(qk_vec[2])}; + VecF cos_vec, sin_vec; + Load(rot_cos_ptr + rot_base_idx, &cos_vec); + Load(rot_sin_ptr + rot_base_idx, &sin_vec); +#pragma unroll + for (int i = 0; i < 4; ++i) { + *(out + base_idx + i) = from_float(to_float(qk_vec[i]) * cos_vec[i] + + rot_half_vec[i] * sin_vec[i]); + } +} + +// qk and rope have a same type +template +__global__ void DispatchApplyRopeVec4Kernel(const T* q, + const T* k, + const T* rot_cos, + const T* rot_sin, + const int q_num_elements, + const int k_num_elements, + const int q_head_num, + const int k_head_num, + const int head_dim, + T* q_out, + T* k_out) { + int idx = (blockIdx.x * blockDim.x + threadIdx.x) * 4; + int head_dim_idx = idx % head_dim; + + if (idx < q_num_elements) { + int rot_idx = idx / (q_head_num * head_dim) * head_dim + head_dim_idx; + RotateQKVec4(q, rot_cos, rot_sin, q_head_num, idx, rot_idx, q_out); + } + + if (idx < k_num_elements) { + int rot_idx = idx / (k_head_num * head_dim) * head_dim + head_dim_idx; + RotateQKVec4(k, rot_cos, rot_sin, k_head_num, idx, rot_idx, k_out); + } +} + +// rope dtype is float32 +template +__global__ void DispatchApplyRopeVec4Kernel(const T* q, + const T* k, + const float* rot_cos, + const float* rot_sin, + const int q_num_elements, + const int k_num_elements, + const int q_head_num, + const int k_head_num, + const int head_dim, + T* q_out, + T* k_out) { + int idx = (blockIdx.x * blockDim.x + threadIdx.x) * 4; + int head_dim_idx = idx % head_dim; + + if (idx < q_num_elements) { + int rot_idx = idx / (q_head_num * head_dim) * head_dim + head_dim_idx; + RotateQKVec4(q, rot_cos, rot_sin, q_head_num, idx, rot_idx, q_out); + } + + if (idx < k_num_elements) { + int rot_idx = idx / (k_head_num * head_dim) * head_dim + head_dim_idx; + RotateQKVec4(k, rot_cos, rot_sin, k_head_num, idx, rot_idx, k_out); + } +} + +template +void ApplyRopeKernel(const paddle::Tensor& q, + const paddle::Tensor& k, + const paddle::Tensor& rot_cos, + const paddle::Tensor& rot_sin, + paddle::Tensor& q_out, + paddle::Tensor& k_out) { + typedef PDTraits traits_; + typedef typename traits_::DataType DataType_; + typedef typename traits_::data_t data_t; + + const auto q_num_elements = q.numel(); + const auto k_num_elements = k.numel(); + const auto q_shape = q.shape(); + const auto k_shape = k.shape(); + const auto dims = q_shape.size(); + const auto q_head_num = q_shape[dims - 2]; + const auto k_head_num = k_shape[dims - 2]; + const auto head_dim = q_shape.back(); + int block_num = + (std::max(q_num_elements, k_num_elements) + (THREADS_PER_BLOCK * 4) - 1) / + (THREADS_PER_BLOCK * 4); + auto stream = q.stream(); + + if (q.dtype() == rot_cos.dtype()) { + DispatchApplyRopeVec4Kernel + <<>>( + reinterpret_cast(q.data()), + reinterpret_cast(k.data()), + reinterpret_cast(rot_cos.data()), + reinterpret_cast(rot_sin.data()), + q_num_elements, + k_num_elements, + q_head_num, + k_head_num, + head_dim, + reinterpret_cast(q_out.data()), + reinterpret_cast(k_out.data())); + } else if (rot_cos.dtype() == paddle::DataType::FLOAT32) { + DispatchApplyRopeVec4Kernel + <<>>( + reinterpret_cast(q.data()), + reinterpret_cast(k.data()), + reinterpret_cast(rot_cos.data()), + reinterpret_cast(rot_sin.data()), + q_num_elements, + k_num_elements, + q_head_num, + k_head_num, + head_dim, + reinterpret_cast(q_out.data()), + reinterpret_cast(k_out.data())); + } else { + PD_THROW("Unsupported qk dtype and rope dtype."); + } +} + +std::vector ApplyRope(const paddle::Tensor& q, + const paddle::Tensor& k, + const paddle::Tensor& rot_cos, + const paddle::Tensor& rot_sin) { + auto q_shape = q.shape(); + auto cos_shape = rot_cos.shape(); + + auto q_out = paddle::empty_like(q); + auto k_out = paddle::empty_like(k); + + if (q.numel() == 0 || k.numel() == 0) { + return {q_out, k_out}; + } + + PADDLE_ENFORCE_EQ( + q_shape.back() % 2, + 0, + "The last dimension (head_dim) of qk must be an even number " + "for RoPE, but got %d", + q_shape.back()); + PADDLE_ENFORCE_EQ(q_shape.size(), + cos_shape.size(), + "The shape size of cos mismatches the shape size of q, " + "expect %d but got %d", + q_shape.size(), + cos_shape.size()); + PADDLE_ENFORCE_EQ(q_shape.back(), + cos_shape.back(), + "The shape.back() of cos mismatches the shape.back() of q, " + "expect %d but got %d", + q_shape.back(), + cos_shape.back()); + + auto input_type = q.dtype(); + switch (input_type) { + case paddle::DataType::BFLOAT16: + ApplyRopeKernel( + q, k, rot_cos, rot_sin, q_out, k_out); + break; + case paddle::DataType::FLOAT16: + ApplyRopeKernel( + q, k, rot_cos, rot_sin, q_out, k_out); + break; + default: + PD_THROW("Only support qk dtype of BF16 and F16"); + } + + return {q_out, k_out}; +} + +std::vector> ApplyRopeInferShape( + const std::vector& q_shape, + const std::vector& k_shape, + const std::vector& cos_shape, + const std::vector& sin_shape) { + return {q_shape, k_shape, cos_shape, sin_shape}; +} + +std::vector ApplyRopeInferDtype( + const paddle::DataType& q_dtype, + const paddle::DataType& k_dtype, + const paddle::DataType& cos_dtype, + const paddle::DataType& sin_dtype) { + return {q_dtype, k_dtype, cos_dtype, sin_dtype}; +} + +PD_BUILD_OP(apply_rope) + .Inputs({"q", "k", "rot_cos", "rot_sin"}) + .Outputs({"q_out", "k_out"}) + .SetKernelFn(PD_KERNEL(ApplyRope)) + .SetInferShapeFn(PD_INFER_SHAPE(ApplyRopeInferShape)) + .SetInferDtypeFn(PD_INFER_DTYPE(ApplyRopeInferDtype)); diff --git a/custom_ops/metax_ops/mc_fused_moe_helper.h b/custom_ops/metax_ops/mc_fused_moe_helper.h index 56213f80d..002c2b87e 100644 --- a/custom_ops/metax_ops/mc_fused_moe_helper.h +++ b/custom_ops/metax_ops/mc_fused_moe_helper.h @@ -17,15 +17,16 @@ #include "mctlassEx/mctlassEx.h" template -void mc_grouped_gemm_basic_kernel(const ElementA *ptrA, +void mc_grouped_gemm_basic_kernel(const ElementA* ptrA, mctlassExOrder_t majorA, - const ElementB *ptrB, + const ElementB* ptrB, mctlassExOrder_t majorB, - const ElementA *ptrScale, - const ElementA *ptrBias, - ElementC *ptrC, + const ElementA* ptrScale, + const ElementA* ptrBias, + ElementC* ptrC, mctlassExOrder_t majorC, - const int *ptrSegInd, + const int* ptrSegInd, + int* ptrMNumTilesInd, int numExperts, int m, // expanded_active_expert_rows int n, // inter_dim @@ -34,9 +35,6 @@ void mc_grouped_gemm_basic_kernel(const ElementA *ptrA, mctlassExHandle_t handle; mctlassExHandleCreate(&handle); - int *ptrMNumTilesInd; - mcMallocAsync((void **)&ptrMNumTilesInd, sizeof(int) * numExperts, stream); - mctlassExMatrixLayout_t matLayoutA; mctlassExMatrixLayout_t matLayoutB; mctlassExMatrixLayout_t matLayoutC; @@ -170,7 +168,6 @@ void mc_grouped_gemm_basic_kernel(const ElementA *ptrA, mctlassExMatrixLayoutDestroy(matLayoutC); mctlassExContiguousGroupedDescDestroy(contiguous_group_desc); mctlassExDestroyDesc(mctlass_desc); - mcFreeAsync(ptrMNumTilesInd, stream); } template @@ -227,27 +224,27 @@ class McMoeHelper { return total_ws_bytes; } - void computeFFN(const paddle::Tensor *input, - const paddle::Tensor *gate_weight, - const paddle::Tensor *ffn1_weight, - const paddle::Tensor *ffn1_scale, - const paddle::Tensor *ffn1_bias, - const paddle::Tensor *ffn2_weight, - const paddle::Tensor *ffn2_scale, - const paddle::Tensor *ffn2_bias, - const paddle::Tensor *moe_token_type_ids, + void computeFFN(const paddle::Tensor* input, + const paddle::Tensor* gate_weight, + const paddle::Tensor* ffn1_weight, + const paddle::Tensor* ffn1_scale, + const paddle::Tensor* ffn1_bias, + const paddle::Tensor* ffn2_weight, + const paddle::Tensor* ffn2_scale, + const paddle::Tensor* ffn2_bias, + const paddle::Tensor* moe_token_type_ids, const int moe_topk, const bool group_moe, const bool norm_topk_prob, const float routed_scaling_factor, const std::string moe_type, - paddle::Tensor *output) { - auto *input_activations = input->data(); - auto *gating_weights = gate_weight->data(); - const T *fc1_expert_biases = ffn1_bias ? ffn1_bias->data() : nullptr; - const T *fc2_expert_biases = ffn2_bias ? ffn2_bias->data() : nullptr; + paddle::Tensor* output) { + auto* input_activations = input->data(); + auto* gating_weights = gate_weight->data(); + const T* fc1_expert_biases = ffn1_bias ? ffn1_bias->data() : nullptr; + const T* fc2_expert_biases = ffn2_bias ? ffn2_bias->data() : nullptr; - auto *output_ = output->data(); + auto* output_ = output->data(); auto stream = input->stream(); auto place = input->place(); auto input_type = input->dtype(); @@ -282,52 +279,52 @@ class McMoeHelper { getWorkspaceSize(num_rows, hidden_size, inter_size, num_experts, k); // Pointers - int *expert_for_source_row; - int *source_rows_; - int *permuted_rows_; - int *permuted_experts_; - int *expanded_source_row_to_expanded_dest_row; + int* expert_for_source_row; + int* source_rows_; + int* permuted_rows_; + int* permuted_experts_; + int* expanded_source_row_to_expanded_dest_row; - T *permuted_data_; - int32_t *total_rows_before_expert_; - T *fc1_result_; - float *softmax_out_; + T* permuted_data_; + int32_t* total_rows_before_expert_; + T* fc1_result_; + float* softmax_out_; paddle::Tensor ws_ptr_tensor = GetEmptyTensor({bytes}, paddle::DataType::INT8, place); - int8_t *ws_ptr = ws_ptr_tensor.data(); + int8_t* ws_ptr = ws_ptr_tensor.data(); const int64_t buf_size = AlignTo16(k * num_rows * hidden_size); const int64_t interbuf_size = AlignTo16(k * num_rows * inter_size); const int64_t padded_experts = AlignTo16(num_experts); const int64_t num_moe_inputs = AlignTo16(k * num_rows); - expert_for_source_row = reinterpret_cast(ws_ptr); + expert_for_source_row = reinterpret_cast(ws_ptr); source_rows_ = expert_for_source_row + num_moe_inputs; permuted_rows_ = source_rows_ + num_moe_inputs; permuted_experts_ = permuted_rows_ + num_moe_inputs; expanded_source_row_to_expanded_dest_row = permuted_experts_ + num_moe_inputs; - permuted_data_ = reinterpret_cast( + permuted_data_ = reinterpret_cast( expanded_source_row_to_expanded_dest_row + num_moe_inputs); total_rows_before_expert_ = - reinterpret_cast(permuted_data_ + buf_size); + reinterpret_cast(permuted_data_ + buf_size); fc1_result_ = - reinterpret_cast(total_rows_before_expert_ + padded_experts); + reinterpret_cast(total_rows_before_expert_ + padded_experts); const bool is_pow_2 = (num_experts != 0) && ((num_experts & (num_experts - 1)) == 0); if (!is_pow_2 || num_experts > 256) { - softmax_out_ = reinterpret_cast(fc1_result_ + interbuf_size); + softmax_out_ = reinterpret_cast(fc1_result_ + interbuf_size); } else { softmax_out_ = nullptr; } paddle::Tensor expert_scales_float_tensor = GetEmptyTensor({num_rows, moe_topk}, paddle::DataType::FLOAT32, place); - float *expert_scales_float = expert_scales_float_tensor.data(); + float* expert_scales_float = expert_scales_float_tensor.data(); - float *softmax_max_prob = nullptr; + float* softmax_max_prob = nullptr; if (group_moe) { paddle::Tensor softmax_max_prob_tensor = GetEmptyTensor( {num_rows, moe_topk}, paddle::DataType::FLOAT32, place); @@ -338,16 +335,16 @@ class McMoeHelper { paddle::Tensor fc1_out_tensor = GetEmptyTensor({num_rows * k, inter_size}, input_type, place); - T *fc1_out = fc1_out_tensor.data(); + T* fc1_out = fc1_out_tensor.data(); auto input_cast_tensor = paddle::experimental::cast(*input, paddle::DataType::FLOAT32); auto gate_tensor = paddle::experimental::matmul(input_cast_tensor, *gate_weight); - float *gating_output = gate_tensor.data(); + float* gating_output = gate_tensor.data(); if (moe_token_type_ids) { - auto *moe_token_type_ids_out = moe_token_type_ids->data(); + auto* moe_token_type_ids_out = moe_token_type_ids->data(); moe_token_type_ids_kernelLauncher(gating_output, moe_token_type_ids_out, num_rows, @@ -403,17 +400,21 @@ class McMoeHelper { mctlassExOrder_t row_major = mctlassExOrder_t::MCTLASS_EX_ORDER_ROW_MAJOR; mctlassExOrder_t column_major = mctlassExOrder_t::MCTLASS_EX_ORDER_COLUMN_MAJOR; + auto m_num_tile = + GetEmptyTensor({num_experts}, paddle::DataType::INT32, place); + int* m_num_tile_ptr = reinterpret_cast(m_num_tile.data()); mc_grouped_gemm_basic_kernel( - reinterpret_cast(permuted_data_), + reinterpret_cast(permuted_data_), row_major, - reinterpret_cast(ffn1_weight->data()), + reinterpret_cast(ffn1_weight->data()), column_major, - reinterpret_cast(ffn1_scale->data()), - reinterpret_cast(fc1_expert_biases), - reinterpret_cast(fc1_out), + reinterpret_cast(ffn1_scale->data()), + reinterpret_cast(fc1_expert_biases), + reinterpret_cast(fc1_out), row_major, total_rows_before_expert_, + m_num_tile_ptr, num_experts, expanded_active_expert_rows, inter_size, @@ -427,18 +428,19 @@ class McMoeHelper { paddle::Tensor fc2_output_tensor = GetEmptyTensor({k * num_rows, hidden_size}, input_type, place); - T *fc2_result = fc2_output_tensor.data(); + T* fc2_result = fc2_output_tensor.data(); mc_grouped_gemm_basic_kernel( - reinterpret_cast(act_out), + reinterpret_cast(act_out), row_major, - reinterpret_cast(ffn2_weight->data()), + reinterpret_cast(ffn2_weight->data()), column_major, - reinterpret_cast(ffn2_scale->data()), + reinterpret_cast(ffn2_scale->data()), nullptr, - reinterpret_cast(fc2_result), + reinterpret_cast(fc2_result), row_major, total_rows_before_expert_, + m_num_tile_ptr, num_experts, expanded_active_expert_rows, hidden_size, @@ -449,7 +451,7 @@ class McMoeHelper { fc2_result, output_, fc2_expert_biases, - reinterpret_cast(expert_scales_float), + reinterpret_cast(expert_scales_float), expanded_source_row_to_expanded_dest_row, expert_for_source_row, num_rows, @@ -465,7 +467,7 @@ class McMoeHelper { fc1_out, output_, fc1_expert_biases, // fc2_expert_biases, - reinterpret_cast(expert_scales_float), + reinterpret_cast(expert_scales_float), expanded_source_row_to_expanded_dest_row, expert_for_source_row, num_rows, diff --git a/custom_ops/metax_ops/moe_ffn.cu b/custom_ops/metax_ops/moe_ffn.cu index b390f4e87..f4a5dbcd4 100644 --- a/custom_ops/metax_ops/moe_ffn.cu +++ b/custom_ops/metax_ops/moe_ffn.cu @@ -21,20 +21,18 @@ template -void McMoeFFNKernel(const paddle::Tensor& permute_input, +void McMoeFFNKernel(paddle::Tensor& permute_input, const paddle::Tensor& tokens_expert_prefix_sum, const paddle::Tensor& ffn1_weight, const paddle::Tensor& ffn2_weight, const paddle::optional& ffn1_bias, const paddle::optional& ffn1_scale, const paddle::optional& ffn2_scale, - const std::string& quant_method, - paddle::Tensor ffn_out) { + const std::string& quant_method) { typedef PDTraits traits_; typedef typename traits_::DataType DataType_; typedef typename traits_::data_t data_t; - auto ffn_out_ptr = ffn_out.data(); auto permuted_input_ptr = permute_input.data(); auto place = permute_input.place(); auto input_type = permute_input.dtype(); @@ -54,6 +52,9 @@ void McMoeFFNKernel(const paddle::Tensor& permute_input, mctlassExOrder_t row_major = mctlassExOrder_t::MCTLASS_EX_ORDER_ROW_MAJOR; mctlassExOrder_t column_major = mctlassExOrder_t::MCTLASS_EX_ORDER_COLUMN_MAJOR; + auto m_num_tile = + GetEmptyTensor({num_experts}, paddle::DataType::INT32, place); + int* m_num_tile_ptr = reinterpret_cast(m_num_tile.data()); // ffn1 auto fc1_expert_biases = @@ -72,6 +73,7 @@ void McMoeFFNKernel(const paddle::Tensor& permute_input, reinterpret_cast(fc1_out_ptr), row_major, tokens_expert_prefix_sum.data(), + m_num_tile_ptr, num_experts, expanded_active_expert_rows, inter_dim, @@ -91,9 +93,10 @@ void McMoeFFNKernel(const paddle::Tensor& permute_input, column_major, reinterpret_cast(fc2_expert_scales), nullptr, - reinterpret_cast(ffn_out_ptr), + reinterpret_cast(permuted_input_ptr), row_major, tokens_expert_prefix_sum.data(), + m_num_tile_ptr, num_experts, expanded_active_expert_rows, hidden_size, @@ -102,7 +105,7 @@ void McMoeFFNKernel(const paddle::Tensor& permute_input, } std::vector MoeExpertFFN( - const paddle::Tensor& permute_input, + paddle::Tensor& permute_input, const paddle::Tensor& tokens_expert_prefix_sum, const paddle::Tensor& ffn1_weight, const paddle::Tensor& ffn2_weight, @@ -112,10 +115,9 @@ std::vector MoeExpertFFN( const std::string& quant_method) { assert(quant_method == "weight_only_int8"); const auto input_type = permute_input.dtype(); - auto ffn_out = paddle::empty_like(permute_input); if (permute_input.numel() == 0) { - return {ffn_out}; + return {permute_input}; } switch (input_type) { @@ -130,8 +132,7 @@ std::vector MoeExpertFFN( ffn1_bias, ffn1_scale, ffn2_scale, - quant_method, - ffn_out); + quant_method); break; // case paddle::DataType::FLOAT16: // MoeFFNKernel(permute_input, @@ -147,7 +148,7 @@ std::vector MoeExpertFFN( default: PD_THROW("Unsupported data type for MoeExpertFFN"); } - return {ffn_out}; + return {permute_input}; } std::vector> MoeExpertFFNInferShape( diff --git a/custom_ops/setup_ops.py b/custom_ops/setup_ops.py index 3a34e5662..3460b077e 100644 --- a/custom_ops/setup_ops.py +++ b/custom_ops/setup_ops.py @@ -629,6 +629,7 @@ elif paddle.device.is_compiled_with_custom_device("metax_gpu"): "metax_ops/moe_ffn.cu", "metax_ops/moe_reduce.cu", "metax_ops/fused_moe.cu", + "metax_ops/apply_rope.cu", ] sources += find_end_files("gpu_ops/speculate_decoding", ".cu") diff --git a/fastdeploy/model_executor/layers/backends/metax/attention/flash_attn_backend.py b/fastdeploy/model_executor/layers/backends/metax/attention/flash_attn_backend.py index 3e9a307a3..306b11bf0 100644 --- a/fastdeploy/model_executor/layers/backends/metax/attention/flash_attn_backend.py +++ b/fastdeploy/model_executor/layers/backends/metax/attention/flash_attn_backend.py @@ -31,6 +31,7 @@ from fastdeploy.model_executor.layers.backends.metax.attention.flash_attention_i flash_attn_kvcache_func, flash_attn_unpadded_func, ) +from fastdeploy.model_executor.ops.gpu import apply_rope @dataclass @@ -50,12 +51,15 @@ class FlashAttentionMetadata(AttentionMetadata): decoder_batch_ids: paddle.Tensor = None decoder_tile_ids_per_batch: paddle.Tensor = None decoder_num_blocks: paddle.Tensor = None + rotary_cos_prefill: paddle.Tensor = None + rotary_sin_prefill: paddle.Tensor = None + rotary_cos_decode: paddle.Tensor = None + rotary_sin_decode: paddle.Tensor = None _dtype: paddle.dtype = paddle.bfloat16 encoder_max_partition_size: int = 32768 max_partition_size: int = 32768 block_tables: Optional[paddle.Tensor] = None - rotary_embs: Optional[paddle.Tensor] = None attn_mask: Optional[paddle.Tensor] = None encoder_block_shape_q: int = -1 decoder_block_shape_q: int = -1 @@ -87,7 +91,7 @@ class FlashAttentionBackend(AttentionBackend): FlashAttentionBackend __init__ """ super().__init__() - self.attention_metadata: FlashAttentionMetadata = None + self.attention_metadata: FlashAttentionMetadata = FlashAttentionMetadata() self.record_block_table_metadata = {} self.block_size: int = fd_config.cache_config.block_size self.max_seq_len: int = fd_config.model_config.max_model_len @@ -122,6 +126,16 @@ class FlashAttentionBackend(AttentionBackend): self.rank, self.device_id = init_rank_and_device_id(fd_config) self.enable_mm = fd_config.model_config.enable_mm + max_num_seqs = fd_config.scheduler_config.max_num_seqs + if self.enable_mm: + self.attention_metadata.rotary_cos_decode = paddle.empty( + shape=[max_num_seqs, 1, 1, self.head_dim], + dtype="float32", + ) + self.attention_metadata.rotary_sin_decode = paddle.empty( + shape=[max_num_seqs, 1, 1, self.head_dim], + dtype="float32", + ) def init_attention_metadata(self, forward_meta: ForwardMeta): """Initialize attntion metadata hence all layers in the forward pass can reuse it.""" @@ -218,6 +232,72 @@ class FlashAttentionBackend(AttentionBackend): self.batch_ids_decode = paddle.to_tensor(self.decode_info_dict["batch_ids"]) self.seq_lens_dec = forward_meta.seq_lens_decoder[self.batch_ids_decode, 0] self.block_table_dec = forward_meta.block_tables[self.batch_ids_decode, :] + # update prefilling rope + self.update_rotary_embs_prefill(forward_meta) + # update decoding rope + self.update_rotary_embs_decoder(forward_meta) + + def update_rotary_embs_prefill(self, forward_meta: ForwardMeta): + if self.batch_ids_prefill.shape[0] == 0 or forward_meta.rotary_embs is None: + return + + batch_ids = self.batch_ids_prefill + seq_lens_this_time = forward_meta.seq_lens_this_time[batch_ids] + cached_kv_lens = forward_meta.seq_lens_decoder[batch_ids, 0] + + all_indices = [] + for i in range(len(batch_ids)): + start_pos = cached_kv_lens[i] + seq_len_i = seq_lens_this_time[i] + if seq_len_i > 0: + indices_i = paddle.arange(start_pos, start_pos + seq_len_i, dtype="int64") + all_indices.append(indices_i) + if not all_indices: + return + + all_indices = paddle.concat(all_indices) # [token_num] + if self.enable_mm: + gather_nd_indices = paddle.stack( + [ # [token_num, 2] + paddle.repeat_interleave(batch_ids, repeats=seq_lens_this_time, axis=0), + all_indices, + ], + axis=1, + ) + gathered_embs = paddle.gather_nd( + forward_meta.rotary_embs.squeeze([2]).transpose( + [0, 2, 1, 3, 4] + ), # [B, 2, 1, S, 1, D // 2] -> [B, S, 2, 1, D // 2] + gather_nd_indices, + ) # [token_num, 2, 1, D // 2] + rot_cos = gathered_embs[:, 0, :, :] # [token_num, 1, D // 2] + rot_sin = gathered_embs[:, 1, :, :] + else: + gathered_embs = paddle.gather( + forward_meta.rotary_embs.squeeze([1]), all_indices, axis=1 # [2, 1, S, 1, D // 2] -> [2, S, 1, D // 2] + ) # [2, token_num, 1, D // 2] + rot_cos = gathered_embs[0, :, :, :] # [token_num, 1, D // 2] + rot_sin = gathered_embs[1, :, :, :] + + self.attention_metadata.rotary_cos_prefill = paddle.repeat_interleave( + rot_cos, repeats=2, axis=-1 + ) # [token_num, 1, D] + self.attention_metadata.rotary_sin_prefill = paddle.repeat_interleave(rot_sin, repeats=2, axis=-1) + + def update_rotary_embs_decoder(self, forward_meta: ForwardMeta): + if not self.enable_mm: # only initialize once for text-only model + if self.attention_metadata.rotary_cos_decode is None or self.attention_metadata.rotary_sin_decode is None: + self.attention_metadata.rotary_cos_decode = forward_meta.rotary_embs[0, 0, :, 0, :].astype(self.dtype) + self.attention_metadata.rotary_sin_decode = forward_meta.rotary_embs[1, 0, :, 0, :].astype(self.dtype) + elif self.batch_ids_decode.shape[0] > 0: + bs = self.batch_ids_decode.shape[0] + index = paddle.concat( + [self.batch_ids_decode.view([-1, 1]), self.seq_lens_dec.to("int64").view([-1, 1])], axis=1 + ) + rot_cos = paddle.gather_nd(forward_meta.rotary_embs[:, 0, 0, :, 0, :], index).view([bs, 1, 1, -1]) + rot_sin = paddle.gather_nd(forward_meta.rotary_embs[:, 1, 0, :, 0, :], index).view([bs, 1, 1, -1]) + self.attention_metadata.rotary_cos_decode[:bs].copy_(paddle.repeat_interleave(rot_cos, repeats=2, axis=-1)) + self.attention_metadata.rotary_sin_decode[:bs].copy_(paddle.repeat_interleave(rot_sin, repeats=2, axis=-1)) def get_attntion_meta(self) -> AttentionMetadata: """get_attntion_meta""" @@ -231,22 +311,19 @@ class FlashAttentionBackend(AttentionBackend): """ Calculate kv cache shape """ + key_cache_shape = value_cache_shape = [max_num_blocks, self.block_size, self.kv_num_heads, self.head_dim] + if kv_cache_quant_type is not None and kv_cache_quant_type == "int4_zp": - return ( + key_cache_shape = value_cache_shape = [ max_num_blocks, self.kv_num_heads, self.block_size, self.head_dim // 2, - ) - else: - return ( - max_num_blocks, - self.block_size, - self.kv_num_heads, - self.head_dim, - ) + ] - def apply_rope(self, qk, cos, sin): + return key_cache_shape, value_cache_shape + + def apply_rope_native(self, qk, cos, sin): rotate_half = paddle.reshape( paddle.stack([-qk[..., 1::2], qk[..., 0::2]], axis=-1), paddle.shape(qk), @@ -254,66 +331,6 @@ class FlashAttentionBackend(AttentionBackend): out = paddle.add(paddle.multiply(qk, cos), paddle.multiply(rotate_half, sin)) return paddle.cast(out, qk.dtype) - def apply_rope_dec(self, q, k, forward_meta: ForwardMeta): - batch_ids = self.batch_ids_decode - bs = batch_ids.shape[0] - index = paddle.concat([batch_ids.view([-1, 1]), self.seq_lens_dec.to("int64").view([-1, 1])], axis=1) - rot_cos = paddle.gather_nd(forward_meta.rotary_embs[:, 0, 0, :, 0, :], index).view([bs, 1, 1, -1]) - rot_sin = paddle.gather_nd(forward_meta.rotary_embs[:, 1, 0, :, 0, :], index).view([bs, 1, 1, -1]) - rot_cos = paddle.repeat_interleave(rot_cos, repeats=2, axis=-1) - rot_sin = paddle.repeat_interleave(rot_sin, repeats=2, axis=-1) - q = self.apply_rope(q, rot_cos, rot_sin) - k = self.apply_rope(k, rot_cos, rot_sin) - - return q, k - - def get_splited_qkv( - self, - qkv: paddle.Tensor, - forward_meta: ForwardMeta, - cu_seqlens_q: paddle.Tensor, - batch_ids=None, - ): - qkv = qkv.view([-1, self.num_heads + self.kv_num_heads * 2, self.head_dim]) - q, k, v = qkv.split(num_or_sections=[self.num_heads, self.kv_num_heads, self.kv_num_heads], axis=-2) - - for idx in range(len(cu_seqlens_q) - 1): - batch_idx = batch_ids[idx] - seq_len_i = forward_meta.seq_lens_this_time[batch_idx] - if seq_len_i == 0: - continue - cached_kv_len = forward_meta.seq_lens_decoder[batch_idx][0] - cu_seq_start_q = cu_seqlens_q[idx] - cu_seq_end_q = cu_seqlens_q[idx + 1] - - if forward_meta.rotary_embs is not None: - if self.enable_mm: # vl: forward_meta.rotary_embs is [2, 1, S, 1, D // 2] - cos = paddle.repeat_interleave( - forward_meta.rotary_embs[batch_idx, 0, 0, cached_kv_len : cached_kv_len + seq_len_i, :, :], - repeats=2, - axis=-1, - ) # [Si, D] - sin = paddle.repeat_interleave( - forward_meta.rotary_embs[batch_idx, 1, 0, cached_kv_len : cached_kv_len + seq_len_i, :, :], - repeats=2, - axis=-1, - ) # [Si, D] - else: # text: forward_meta.rotary_embs is [2, 1, S, 1, D // 2] - cos = paddle.repeat_interleave( - forward_meta.rotary_embs[0, 0, cached_kv_len : cached_kv_len + seq_len_i, :, :], - repeats=2, - axis=-1, - ) # [Si, D] - sin = paddle.repeat_interleave( - forward_meta.rotary_embs[1, 0, cached_kv_len : cached_kv_len + seq_len_i, :, :], - repeats=2, - axis=-1, - ) # [Si, D] - q[cu_seq_start_q:cu_seq_end_q] = self.apply_rope(q[cu_seq_start_q:cu_seq_end_q], cos, sin) - k[cu_seq_start_q:cu_seq_end_q] = self.apply_rope(k[cu_seq_start_q:cu_seq_end_q], cos, sin) - - return q, k, v - def split_pd_qkv(self, qkv): for ids, reverse_ids in zip(self.prefill_info_dict["id_group"], self.prefill_info_dict["reverse_id_group"]): @@ -386,18 +403,14 @@ class FlashAttentionBackend(AttentionBackend): tensor_start = tensor_end def forward_prefill(self, prefill_qkv, layer_id, k_cache_id, v_cache_id, forward_meta: ForwardMeta): - - prefill_q, prefill_k, prefill_v = self.get_splited_qkv( - prefill_qkv, - forward_meta, - self.prefill_info_dict["cu_seqlens_q"], - batch_ids=self.batch_ids_prefill, - ) + qkv = prefill_qkv.view([-1, self.num_heads + self.kv_num_heads * 2, self.head_dim]) + q, k, v = qkv.split(num_or_sections=[self.num_heads, self.kv_num_heads, self.kv_num_heads], axis=-2) + q, k = apply_rope(q, k, self.attention_metadata.rotary_cos_prefill, self.attention_metadata.rotary_sin_prefill) prefill_out = flash_attn_unpadded_func( - prefill_q, - prefill_k, - prefill_v, + q, + k, + v, self.prefill_info_dict["cu_seqlens_q"], self.prefill_info_dict["cu_seqlens_q"], max_seqlen_q=self.max_seq_len, @@ -406,9 +419,7 @@ class FlashAttentionBackend(AttentionBackend): causal=self.causal, )[0] - self.update_kv_cache( - prefill_k, prefill_v, k_cache_id, v_cache_id, layer_id, forward_meta, self.batch_ids_prefill - ) + self.update_kv_cache(k, v, k_cache_id, v_cache_id, layer_id, forward_meta, self.batch_ids_prefill) return prefill_out @@ -417,12 +428,14 @@ class FlashAttentionBackend(AttentionBackend): q, k, v = qkv.split(num_or_sections=[self.num_heads, self.kv_num_heads, self.kv_num_heads], axis=-2) if self.enable_mm: # vl - q, k = self.apply_rope_dec(q, k, forward_meta) - rot_cos = None - rot_sin = None - else: # text - rot_cos = forward_meta.rotary_embs[0, 0, :, 0, :].astype(q.dtype) - rot_sin = forward_meta.rotary_embs[1, 0, :, 0, :].astype(q.dtype) + q, k = apply_rope( + q, k, self.attention_metadata.rotary_cos_decode, self.attention_metadata.rotary_sin_decode + ) + rotary_cos = None + rotary_sin = None + else: + rotary_cos = self.attention_metadata.rotary_cos_decode + rotary_sin = self.attention_metadata.rotary_sin_decode decode_out = flash_attn_kvcache_func( q, @@ -432,8 +445,8 @@ class FlashAttentionBackend(AttentionBackend): self.block_table_dec, k, v, - rotary_cos=rot_cos, - rotary_sin=rot_sin, + rotary_cos=rotary_cos, + rotary_sin=rotary_sin, causal=self.causal, is_rotary_interleaved=True, )[0].squeeze(1)