From c694fa2879f118a793c830cf19f58cb3ba7ea788 Mon Sep 17 00:00:00 2001
From: Jiang-Jia-Jun <163579578+Jiang-Jia-Jun@users.noreply.github.com>
Date: Wed, 27 Aug 2025 17:35:04 +0800
Subject: [PATCH] Revert "[Feature] block sparse attention (#3209)" (#3647)

This reverts commit 646a0c2fd8a2557e80c17b4c5323cd3833ae6dfd.
---
 custom_ops/gpu_ops/cpp_extensions.cc          |  18 +-
 custom_ops/gpu_ops/moba_attn/moba_attn.cu     | 330 -------
 custom_ops/gpu_ops/moba_attn/moba_attn.h      | 204 -----
 .../gpu_ops/moba_attn/moba_attn_utils.hpp     | 748 ----------------
 .../moba_decoder_attn/moba_decoder_attn.cu    | 802 ------------------
 .../moba_decoder_attn_kernel.h                | 225 -----
 .../moba_decoder_write_cache.cu               | 189 -----
 .../moba_decoder_attn/moba_qk_sort_decoder.cu | 236 ------
 .../moba_encoder_attn/kernel_traits.h         | 143 ----
 .../moba_encoder_attn/mainloop_attn.hpp       | 473 -----------
 .../moba_encoder_attn/moba_encoder_attn.cu    | 384 ---------
 .../moba_encoder_write_cache.cu               | 163 ----
 .../moba_encoder_attn/moba_qk_sort_encoder.cu | 341 --------
 .../moba_attn/moba_encoder_attn/softmax.hpp   | 194 -----
 .../moba_process/moba_get_kv_from_cache.cu    | 288 -------
 .../moba_attn/moba_process/moba_mlp_einsum.cu | 221 -----
 .../moba_attn/moba_process/moba_qk_gemm.cu    | 465 ----------
 .../moba_process/split_qkv_and_rope.cu        | 370 --------
 custom_ops/setup_ops.py                       |   5 -
 docs/features/moba_sparse_attention.md        |  31 -
 fastdeploy/config.py                          |  64 +-
 fastdeploy/engine/args_utils.py               |  25 -
 fastdeploy/engine/engine.py                   |   1 -
 .../layers/attention/__init__.py              |   2 -
 .../layers/attention/attention.py             |  41 -
 .../attention/moba_attention_backend.py       | 198 -----
 fastdeploy/platforms/base.py                  |   1 -
 fastdeploy/platforms/cuda.py                  |   3 -
 fastdeploy/rl/rollout_config.py               |   2 -
 fastdeploy/worker/worker_process.py           |  10 -
 tests/layers/test_moba_attention.py           | 340 --------
 31 files changed, 10 insertions(+), 6507 deletions(-)
 delete mode 100644 custom_ops/gpu_ops/moba_attn/moba_attn.cu
 delete mode 100644 custom_ops/gpu_ops/moba_attn/moba_attn.h
 delete mode 100644 custom_ops/gpu_ops/moba_attn/moba_attn_utils.hpp
 delete mode 100644 custom_ops/gpu_ops/moba_attn/moba_decoder_attn/moba_decoder_attn.cu
 delete mode 100644 custom_ops/gpu_ops/moba_attn/moba_decoder_attn/moba_decoder_attn_kernel.h
 delete mode 100644 custom_ops/gpu_ops/moba_attn/moba_decoder_attn/moba_decoder_write_cache.cu
 delete mode 100644 custom_ops/gpu_ops/moba_attn/moba_decoder_attn/moba_qk_sort_decoder.cu
 delete mode 100644 custom_ops/gpu_ops/moba_attn/moba_encoder_attn/kernel_traits.h
 delete mode 100644 custom_ops/gpu_ops/moba_attn/moba_encoder_attn/mainloop_attn.hpp
 delete mode 100644 custom_ops/gpu_ops/moba_attn/moba_encoder_attn/moba_encoder_attn.cu
 delete mode 100644 custom_ops/gpu_ops/moba_attn/moba_encoder_attn/moba_encoder_write_cache.cu
 delete mode 100644 custom_ops/gpu_ops/moba_attn/moba_encoder_attn/moba_qk_sort_encoder.cu
 delete mode 100644 custom_ops/gpu_ops/moba_attn/moba_encoder_attn/softmax.hpp
 delete mode 100644 custom_ops/gpu_ops/moba_attn/moba_process/moba_get_kv_from_cache.cu
 delete mode 100644 custom_ops/gpu_ops/moba_attn/moba_process/moba_mlp_einsum.cu
 delete mode 100644 custom_ops/gpu_ops/moba_attn/moba_process/moba_qk_gemm.cu
 delete mode 100644 custom_ops/gpu_ops/moba_attn/moba_process/split_qkv_and_rope.cu
 delete mode 100644 docs/features/moba_sparse_attention.md
 delete mode 100644 fastdeploy/model_executor/layers/attention/moba_attention_backend.py
 delete mode 100644 tests/layers/test_moba_attention.py

diff --git a/custom_ops/gpu_ops/cpp_extensions.cc b/custom_ops/gpu_ops/cpp_extensions.cc
index 75470b906..609fc65d3 100644
--- a/custom_ops/gpu_ops/cpp_extensions.cc
+++ b/custom_ops/gpu_ops/cpp_extensions.cc
@@ -784,15 +784,15 @@ void SpeculateStepPaddle(
     const int max_draft_tokens);
 
 void MergePrefillDecodeOutput(
-    const paddle::Tensor &encoder_res,
-    const paddle::Tensor &decoder_res,
-    const paddle::Tensor &seq_lens_encoder,
-    const paddle::Tensor &seq_lens_decoder,
-    const paddle::Tensor &seq_lens_this_time,
-    const paddle::Tensor &cu_seq_q,
-    const int head_num,
-    const int head_dim,
-    const int max_token);
+    const paddle::Tensor &encoder_res,
+    const paddle::Tensor &decoder_res,
+    const paddle::Tensor &seq_lens_encoder,
+    const paddle::Tensor &seq_lens_decoder,
+    const paddle::Tensor &seq_lens_this_time,
+    const paddle::Tensor &cu_seq_q,
+    const int head_num,
+    const int head_dim,
+    const int max_token);
 
 std::vector<paddle::Tensor> TopPSamplingReject(const paddle::Tensor &probs,
                                                const paddle::Tensor &top_p,
diff --git a/custom_ops/gpu_ops/moba_attn/moba_attn.cu b/custom_ops/gpu_ops/moba_attn/moba_attn.cu
deleted file mode 100644
index 847db9c3e..000000000
--- a/custom_ops/gpu_ops/moba_attn/moba_attn.cu
+++ /dev/null
@@ -1,330 +0,0 @@
-// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//     http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
- -#include "paddle/extension.h" -#include "moba_attn.h" - - -std::vector MobaAttention( - const paddle::Tensor& qkv, - const paddle::Tensor& q_input, - const paddle::Tensor& k_input, - const paddle::Tensor& v_input, - const paddle::Tensor& cu_seq_q, - const paddle::Tensor& cu_seq_k, - const paddle::Tensor& cu_seq_q_pack, - const paddle::Tensor& q_pack_tokens, - const paddle::Tensor& seq_len_encoder, - const paddle::Tensor& seq_len_decoder, - const paddle::Tensor& key_cache, - const paddle::Tensor& value_cache, - const paddle::Tensor& block_tables, - const paddle::Tensor& rope_sin_cos, - const paddle::Tensor& k_block_means, - const paddle::optional& attn_gate_weight, - const paddle::optional& qkv_bias, - const paddle::optional& cache_k_quant_scale, - const paddle::optional& cache_v_quant_scale, - const paddle::optional& cache_k_dequant_scale, - const paddle::optional& cache_v_dequant_scale, - const paddle::optional& cache_k_zero_points, - const paddle::optional& cache_v_zero_points, - const int head_num, - const int kv_head_num, - const int head_dim, - const int max_seq_len, - const int max_enc_len_this_time, - const int max_dec_len_this_time, - const int moba_encoder_top_k_left, - const int moba_encoder_top_k_right, - const int moba_use_encoder_seq_limit, - const int moba_decoder_top_k_left, - const int moba_decoder_top_k_right, - const int moba_use_decoder_seq_limit, - const bool moba_use_mlp, - const std::string &cache_quant_type_str) { - - paddle::Tensor out = paddle::empty({qkv.dims()[0], head_num * head_dim}, qkv.dtype(), qkv.place()); - if (max_dec_len_this_time > 0) { - MobaDecoderAttnWriteCacheKv( - qkv, - q_input, - cu_seq_q, - cu_seq_k, - seq_len_encoder, - seq_len_decoder, - key_cache, - value_cache, - block_tables, - rope_sin_cos, - k_block_means, - qkv_bias, - cache_k_quant_scale, - cache_v_quant_scale, - cache_k_dequant_scale, - cache_v_dequant_scale, - cache_k_zero_points, - cache_v_zero_points, - head_num, - kv_head_num, - head_dim, - max_seq_len, - cache_quant_type_str); - - auto qk_gate_weight = MobaQKGemm( - q_input, - k_block_means, - seq_len_encoder, - seq_len_decoder, - cu_seq_q, - cu_seq_k, - max_dec_len_this_time, - max_dec_len_this_time, - head_num, - kv_head_num, - true, - moba_use_decoder_seq_limit - )[0]; - - auto qk_gate_topk_idx = QkSortDecoder( - qk_gate_weight, - seq_len_encoder, - seq_len_decoder, - head_num, - kv_head_num, - moba_decoder_top_k_left, - moba_decoder_top_k_right, - moba_use_decoder_seq_limit - )[0]; - - MobaDecoderAttn( - q_input, - seq_len_encoder, - seq_len_decoder, - cu_seq_q, - key_cache, - value_cache, - block_tables, - k_block_means, - out, - qk_gate_topk_idx, - cache_k_quant_scale, - cache_v_quant_scale, - cache_k_dequant_scale, - cache_v_dequant_scale, - cache_k_zero_points, - cache_v_zero_points, - head_num, - kv_head_num, - head_dim, - max_seq_len, - moba_use_decoder_seq_limit, - max_dec_len_this_time, - max_dec_len_this_time, - cache_quant_type_str - ); - } - - if (max_enc_len_this_time > 0) { - FusedBlockMeanAndRope( - qkv, - k_block_means, - q_input, - k_input, - v_input, - rope_sin_cos, - seq_len_encoder, - seq_len_decoder, - cu_seq_q, - cu_seq_k, - qkv_bias, - head_num, - kv_head_num, - head_dim, - max_seq_len, - max_enc_len_this_time, - max_enc_len_this_time, - cache_quant_type_str - ); - - MobaEncoderAttnWriteCacheKv( - k_input, - v_input, - cu_seq_k, - seq_len_encoder, - seq_len_decoder, - key_cache, - value_cache, - block_tables, - cache_k_quant_scale, - cache_v_quant_scale, - cache_k_dequant_scale, - 
cache_v_dequant_scale, - cache_k_zero_points, - cache_v_zero_points, - head_num, - kv_head_num, - head_dim, - max_enc_len_this_time, - cache_quant_type_str - ); - - GetKVFromCache( - k_input, - v_input, - cu_seq_k, - seq_len_encoder, - seq_len_decoder, - key_cache, - value_cache, - block_tables, - cache_k_dequant_scale, - cache_v_dequant_scale, - cache_k_zero_points, - cache_v_zero_points, - head_num, - kv_head_num, - head_dim, - max_seq_len, - max_enc_len_this_time + max_dec_len_this_time, - cache_quant_type_str - ); - - paddle::Tensor *k_gate_weight = const_cast(&k_block_means); - - if (moba_use_mlp && attn_gate_weight) { - paddle::Tensor k_gate_mlp = MobaMlpEinsum( - k_input, - attn_gate_weight.get(), - seq_len_encoder, - seq_len_decoder, - cu_seq_k, - max_seq_len, - kv_head_num - )[0]; - k_gate_weight = &k_gate_mlp; - } - - auto qk_gate_weight = MobaQKGemm( - q_input, - k_block_means, - seq_len_encoder, - seq_len_decoder, - cu_seq_q, - cu_seq_k, - max_enc_len_this_time, - max_enc_len_this_time + max_dec_len_this_time, - head_num, - kv_head_num, - false, - moba_use_encoder_seq_limit - )[0]; - - - auto qk_gate_topk_idx = QkSortEncoder( - qk_gate_weight, - seq_len_encoder, - seq_len_decoder, - cu_seq_q, - cu_seq_k, - cu_seq_q_pack, - q_pack_tokens, - max_enc_len_this_time, - max_enc_len_this_time + max_dec_len_this_time, - head_num, - kv_head_num, - moba_encoder_top_k_left, - moba_encoder_top_k_right, - moba_use_encoder_seq_limit)[0]; - - MobaEncoderAttn( - q_input, - k_input, - v_input, - qk_gate_topk_idx, - cu_seq_q, - cu_seq_k, - cu_seq_q_pack, - seq_len_encoder, - seq_len_decoder, - out, - max_enc_len_this_time, - max_enc_len_this_time + max_dec_len_this_time, - head_num, - kv_head_num, - head_dim, - max_seq_len - ); - } - - return {out}; -} - - -PD_BUILD_OP(moba_attention) - .Inputs({ - "qkv", - "q_input", - "k_input", - "v_input", - "cu_seq_q", - "cu_seq_k", - "cu_seq_q_pack", - "q_pack_tokens", - "seq_len_encoder", - "seq_len_decoder", - "key_cache", - "value_cache", - "block_tables", - "rope_sin_cos", - "k_block_means", - paddle::Optional("attn_gate_weight"), - paddle::Optional("qkv_bias"), - paddle::Optional("cache_k_quant_scale"), - paddle::Optional("cache_v_quant_scale"), - paddle::Optional("cache_k_dequant_scale"), - paddle::Optional("cache_v_dequant_scale"), - paddle::Optional("cache_k_zero_points"), - paddle::Optional("cache_v_zero_points")}) - .Attrs({ - "head_num: int", - "kv_head_num: int", - "head_dim: int", - "max_seq_len: int", - "max_enc_len_this_time: int", - "max_dec_len_this_time: int", - "moba_encoder_top_k_left: int", - "moba_encoder_top_k_right: int", - "moba_use_encoder_seq_limit: int", - "moba_decoder_top_k_left: int", - "moba_decoder_top_k_right: int", - "moba_use_decoder_seq_limit: int", - "moba_use_mlp: bool", - "cache_quant_type_str: std::string"}) - .Outputs({ - "out", - "q_input_out", - "k_input_out", - "v_input_out", - "key_cache_out", - "value_cache_out", - "k_block_means_out"}) - .SetInplaceMap({{ - "q_input", "q_input_out"}, - {"k_input", "k_input_out"}, - {"v_input", "v_input_out"}, - {"key_cache", "key_cache_out"}, - {"value_cache", "value_cache_out"}, - {"k_block_means", "k_block_means_out"}}) - .SetKernelFn(PD_KERNEL(MobaAttention)); diff --git a/custom_ops/gpu_ops/moba_attn/moba_attn.h b/custom_ops/gpu_ops/moba_attn/moba_attn.h deleted file mode 100644 index 1b4611022..000000000 --- a/custom_ops/gpu_ops/moba_attn/moba_attn.h +++ /dev/null @@ -1,204 +0,0 @@ -// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. 
-// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once - -#include "paddle/extension.h" - -void MobaDecoderAttnWriteCacheKv( - const paddle::Tensor& qkv_out, - const paddle::Tensor& q_input, - const paddle::Tensor& cu_seq_q, - const paddle::Tensor& cu_seq_k, - const paddle::Tensor& seq_len_encoder, - const paddle::Tensor& seq_len_decoder, - const paddle::Tensor& cache_k, - const paddle::Tensor& cache_v, - const paddle::Tensor& block_tables, - const paddle::Tensor& rope_sin_cos, - const paddle::Tensor& k_block_means, - const paddle::optional& qkv_bias, - const paddle::optional& cache_k_quant_scale, - const paddle::optional& cache_v_quant_scale, - const paddle::optional& cache_k_dequant_scale, - const paddle::optional& cache_v_dequant_scale, - const paddle::optional& cache_k_zero_points, - const paddle::optional& cache_v_zero_points, - const int head_num, - const int kv_head_num, - const int head_dim, - const int max_input_length, - const std::string &cache_quant_type_str); - -void MobaEncoderAttnWriteCacheKv( - const paddle::Tensor& k_input, - const paddle::Tensor& v_input, - const paddle::Tensor& cu_seq_k, - const paddle::Tensor& seq_len_encoder, - const paddle::Tensor& seq_len_decoder, - const paddle::Tensor& cache_k, - const paddle::Tensor& cache_v, - const paddle::Tensor& block_tables, - const paddle::optional& cache_k_quant_scale, - const paddle::optional& cache_v_quant_scale, - const paddle::optional& cache_k_dequant_scale, - const paddle::optional& cache_v_dequant_scale, - const paddle::optional& cache_k_zero_points, - const paddle::optional& cache_v_zero_points, - const int head_num, - const int kv_head_num, - const int head_dim, - const int max_seq_q, - const std::string &cache_quant_type_str); - -void MobaDecoderAttn( - const paddle::Tensor& q_input, - const paddle::Tensor& seq_len_encoder, - const paddle::Tensor& seq_len_decoder, - const paddle::Tensor& cu_seq_q, - const paddle::Tensor& cache_k, - const paddle::Tensor& cache_v, - const paddle::Tensor& block_tables, - const paddle::Tensor& k_block_means, - const paddle::Tensor& out, - const paddle::Tensor& qk_gate_topk_idx, - const paddle::optional& cache_k_quant_scale, - const paddle::optional& cache_v_quant_scale, - const paddle::optional& cache_k_dequant_scale, - const paddle::optional& cache_v_dequant_scale, - const paddle::optional& cache_k_zero_points, - const paddle::optional& cache_v_zero_points, - const int head_num, - const int kv_head_num, - const int head_dim, - const int max_input_length, - const int use_moba_seq_limit, - const int max_seq_q, - const int max_seq_k, - const std::string &cache_quant_type_str); - - -void FusedBlockMeanAndRope( - const paddle::Tensor& qkv_out, - const paddle::Tensor& k_block_means, - const paddle::Tensor& q_input, - const paddle::Tensor& k_input, - const paddle::Tensor& v_input, - const paddle::Tensor& rotary_embs, - const paddle::Tensor& seq_len_encoder, - const paddle::Tensor& seq_len_decoder, - const paddle::Tensor& cu_seq_q, - const paddle::Tensor& cu_seq_k, - const 
paddle::optional& qkv_bias, - const int head_num, - const int kv_head_num, - const int head_dim, - const int max_input_length, - const int max_seq_q, - const int max_seq_k, - const std::string &cache_quant_type_str); - -std::vector GetCurCuSeqLenk( - const paddle::Tensor& seq_lens_encoder, - const paddle::Tensor& seq_lens_decoder, - const paddle::Tensor& seq_lens_this_time, - const int pack_size); - -std::vector MobaQKGemm( - const paddle::Tensor& q_input, - const paddle::Tensor& k_block_means, - const paddle::Tensor& seq_len_encoder, - const paddle::Tensor& seq_len_decoder, - const paddle::Tensor& cu_seq_q, - const paddle::Tensor& cu_seq_k, - const int max_seq_q, - const int max_seq_k, - const int head_num, - const int kv_head_num, - const bool is_split_kv, - const int use_moba_seq_limit); - -std::vector QkSortDecoder( - const paddle::Tensor& qk_gate_weight, - const paddle::Tensor& seq_len_encoder, - const paddle::Tensor& seq_len_decoder, - const int head_num, - const int kv_head_num, - const int top_k_left, - const int top_k_right, - const int use_moba_seq_limit); - -void GetKVFromCache( - const paddle::Tensor& k_input, - const paddle::Tensor& v_input, - const paddle::Tensor& cu_seq_k, - const paddle::Tensor& seq_len_encoder, - const paddle::Tensor& seq_len_decoder, - const paddle::Tensor& cache_k, - const paddle::Tensor& cache_v, - const paddle::Tensor& block_tables, - const paddle::optional& cache_k_dequant_scale, - const paddle::optional& cache_v_dequant_scale, - const paddle::optional& cache_k_zero_points, - const paddle::optional& cache_v_zero_points, - const int head_num, - const int kv_head_num, - const int head_dim, - const int max_input_length, - const int max_seq_k, - const std::string &cache_quant_type_str); - - -void MobaEncoderAttn( - const paddle::Tensor& q_input, - const paddle::Tensor& k_input, - const paddle::Tensor& v_input, - const paddle::Tensor& qk_gate_topk_idx, - const paddle::Tensor& cu_seq_q, - const paddle::Tensor& cu_seq_k, - const paddle::Tensor& cu_seq_q_pack, - const paddle::Tensor& seq_len_encoder, - const paddle::Tensor& seq_len_decoder, - const paddle::Tensor& out, - const int max_seq_q, - const int max_seq_k, - const int head_num, - const int kv_head_num, - const int head_dim, - const int max_input_length); - -std::vector QkSortEncoder( - const paddle::Tensor& qk_gate_weight, - const paddle::Tensor& seq_len_encoder, - const paddle::Tensor& seq_len_decoder, - const paddle::Tensor& cu_seq_q, - const paddle::Tensor& cu_seq_k, - const paddle::Tensor& cu_seq_q_pack, - const paddle::Tensor& q_pack_tokens, - const int max_seq_q, - const int max_seq_k, - const int head_num, - const int kv_head_num, - const int top_k_left, - const int top_k_right, - const int use_moba_seq_limit); - -std::vector MobaMlpEinsum( - const paddle::Tensor& k_input, - const paddle::Tensor& attn_gate_weight, - const paddle::Tensor& seq_lens_encoder, - const paddle::Tensor& seq_lens_decoder, - const paddle::Tensor& cu_seq_k, - const int max_seq_len, - const int kv_head_num); diff --git a/custom_ops/gpu_ops/moba_attn/moba_attn_utils.hpp b/custom_ops/gpu_ops/moba_attn/moba_attn_utils.hpp deleted file mode 100644 index 824b2e0f9..000000000 --- a/custom_ops/gpu_ops/moba_attn/moba_attn_utils.hpp +++ /dev/null @@ -1,748 +0,0 @@ -// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once -#include -#include -#include -#include -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 -#include -#endif -#include -#include -#include -#include -#include -#include "cute/tensor.hpp" -#include "cute/algorithm/copy.hpp" -#include "cute/algorithm/gemm.hpp" -#include "cute/int_tuple.hpp" -#include -#include -#include "cutlass/layout/layout.h" -#include "cutlass/numeric_types.h" -#include "cutlass/pipeline/pipeline.hpp" -#include "cutlass/cluster_launch.hpp" -#include "cutlass/arch/reg_reconfig.h" -#include "cute/atom/mma_atom.hpp" -#include "cutlass/gemm/collective/collective_builder.hpp" - -using namespace cute; - -template -struct PackedHalf; - -template<> -struct PackedHalf { - using Type = __half2; -}; - -template<> -struct PackedHalf { - using Type = nv_bfloat162; -}; - -template<> -struct PackedHalf { - using Type = __half2; -}; - -template<> -struct PackedHalf { - using Type = nv_bfloat162; -}; - - -template -struct HalfSub; - -template<> -struct HalfSub { - inline __device__ void operator()(uint32_t* result_ptr, const uint32_t magic_num) { - asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(*result_ptr) : "r"(*result_ptr), "r"(magic_num)); - } -}; - -template<> -struct HalfSub { - inline __device__ void operator()(uint32_t* result_ptr, const uint32_t magic_num) { - *reinterpret_cast(result_ptr) -= *reinterpret_cast(&magic_num); - } -}; - -template -struct HalfMul; - -template<> -struct HalfMul { - inline __device__ void operator()(uint32_t* result_ptr, const uint32_t magic_num) { - asm volatile("mul.f16x2 %0, %1, %2;\n" : "=r"(*result_ptr) : "r"(*result_ptr), "r"(magic_num)); - } -}; - -template<> -struct HalfMul { - inline __device__ void operator()(uint32_t* result_ptr, const uint32_t magic_num) { - *reinterpret_cast(result_ptr) *= *reinterpret_cast(&magic_num); - } -}; - - -template -struct HalfMax; -template<> -struct HalfMax { - inline __device__ __half2 operator()(const __half2 x, const __half2 y) { - __half2 res; - asm volatile("max.f16x2 %0, %1, %2;\n" : - "=r"(*reinterpret_cast(&res)) : - "r"(*reinterpret_cast(&x)), - "r"(*reinterpret_cast(&y))); - return res; - } -}; - -template<> -struct HalfMax { - inline __device__ nv_bfloat162 operator()(const nv_bfloat162 x, const nv_bfloat162 y) { - nv_bfloat162 res; - asm volatile("max.bf16x2 %0, %1, %2;\n" : - "=r"(*reinterpret_cast(&res)) : - "r"(*reinterpret_cast(&x)), - "r"(*reinterpret_cast(&y))); - return res; - } -}; - - -template -struct HalfMin; -template<> -struct HalfMin { - inline __device__ __half2 operator()(const __half2 x, const __half2 y) { - __half2 res; - asm volatile("min.f16x2 %0, %1, %2;\n" : - "=r"(*reinterpret_cast(&res)) : - "r"(*reinterpret_cast(&x)), - "r"(*reinterpret_cast(&y))); - return res; - } -}; - -template<> -struct HalfMin { - inline __device__ nv_bfloat162 operator()(const nv_bfloat162 x, const nv_bfloat162 y) { - nv_bfloat162 res; - asm volatile("min.bf16x2 %0, %1, %2;\n" : - "=r"(*reinterpret_cast(&res)) : - "r"(*reinterpret_cast(&x)), - "r"(*reinterpret_cast(&y))); - return res; - } -}; - - -template -struct MaxOp { -__device__ __forceinline__ T operator()(T 
const & x, T const & y) { return x > y ? x : y; } -}; - -template <> -struct MaxOp { -__device__ __forceinline__ float operator()(float const &x, float const &y) { return max(x, y); } -}; - -template -struct MinOp { -__device__ __forceinline__ T operator()(T const & x, T const & y) { return x < y ? x : y; } -}; - -template <> -struct MinOp { -__device__ __forceinline__ float operator()(float const &x, float const &y) { return min(x, y); } -}; - - -template -struct SumOp { -__device__ __forceinline__ T operator()(T const & x, T const & y) { return x + y; } -}; - -template -inline __device__ static void convert_c8_2_half(uint32_t *src, T *dst, const T *cache_scale, const T* cache_zp) { - uint32_t* half_result_ptr = reinterpret_cast(dst); - if constexpr (std::is_same_v) { - static constexpr uint32_t fp32_base = 0x4B000000; - float fp32_intermediates[4]; - - uint32_t* fp32_intermediates_casted = reinterpret_cast(fp32_intermediates); - fp32_intermediates_casted[0] = __byte_perm(*src, fp32_base, 0x7650); - fp32_intermediates_casted[1] = __byte_perm(*src, fp32_base, 0x7651); - fp32_intermediates_casted[2] = __byte_perm(*src, fp32_base, 0x7652); - fp32_intermediates_casted[3] = __byte_perm(*src, fp32_base, 0x7653); - - #pragma unroll - for (int ii = 0; ii < 4; ++ii) { - fp32_intermediates[ii] -= 8388608.f; - } - - #pragma unroll - for (int ii = 0; ii < 2; ++ii) { - half_result_ptr[ii] = __byte_perm(fp32_intermediates_casted[2 * ii + 0], fp32_intermediates_casted[2 * ii + 1], 0x7632); - } - } else { - static constexpr uint32_t head_for_fp16 = 0x64006400; - half_result_ptr[0] = __byte_perm(*src, head_for_fp16, 0x7150); - half_result_ptr[1] = __byte_perm(*src, head_for_fp16, 0x7352); - } - - using pack_half = typename PackedHalf::Type; - #pragma unroll - for (int i = 0; i < 2; i++){ - if constexpr (Is_K) { - HalfSub()(half_result_ptr + i, *reinterpret_cast(cache_zp + i * 2)); - HalfMul()(half_result_ptr + i, *reinterpret_cast(cache_scale + i * 2)); - } else { - pack_half bias; - pack_half scale; - bias.x = cache_zp[0]; - bias.y = cache_zp[0]; - scale.x = cache_scale[0]; - scale.y = cache_scale[0]; - HalfSub()(half_result_ptr + i, *reinterpret_cast(&bias)); - HalfMul()(half_result_ptr + i, *reinterpret_cast(&scale)); - } - } -} - -template -inline __device__ static void convert_c4_2_half(uint32_t *src, T *dst, const T *cache_scale, const T* cache_zp) { - using pack_half = typename PackedHalf::Type; - static constexpr uint32_t MASK = 0x0f0f0f0f; - static constexpr uint32_t head_for_fp16 = std::is_same_v ? 
0x43004300 : 0x64006400; - static constexpr uint32_t mask_for_c42fp16_one = 0x7253; - static constexpr uint32_t mask_for_c42fp16_two = 0x7051; - uint32_t* result_ptr = reinterpret_cast(dst); - uint32_t source = *reinterpret_cast(src); - // source = {e0 e4 e1 e5 e2 e6 e3 e7} - uint32_t bottom_i4s = source & MASK; - // bottom_i4s = {0 e4 0 e5 0 e6 0 e7} - uint32_t top_i4s = (source >> 4) & MASK; - // top_i4s = {0 e0 0 e1 0 e2 0 e3} - asm volatile("prmt.b32 %0,%1,%2,%3;\n" : "=r"(result_ptr[0]) : "r"(top_i4s), "n"(head_for_fp16), "n"(mask_for_c42fp16_one)); - // result_ptr[0] = {e0 e1} - asm volatile("prmt.b32 %0,%1,%2,%3;\n" : "=r"(result_ptr[1]) : "r"(top_i4s), "n"(head_for_fp16), "n"(mask_for_c42fp16_two)); - asm volatile("prmt.b32 %0,%1,%2,%3;\n" : "=r"(result_ptr[2]) : "r"(bottom_i4s), "n"(head_for_fp16), "n"(mask_for_c42fp16_one)); - asm volatile("prmt.b32 %0,%1,%2,%3;\n" : "=r"(result_ptr[3]) : "r"(bottom_i4s), "n"(head_for_fp16), "n"(mask_for_c42fp16_two)); - - #pragma unroll - for (int i = 0; i < 4; ++i) { - if constexpr (Is_K) { - const int ith_col = i % 2 * 2; - HalfSub()(result_ptr + i, *reinterpret_cast(cache_zp + ith_col)); - HalfMul()(result_ptr + i, *reinterpret_cast(cache_scale + ith_col)); - } else { - const int ith_col = i / 2; - pack_half bias; - pack_half scale; - bias.x = cache_zp[ith_col]; - bias.y = cache_zp[ith_col]; - scale.x = cache_scale[ith_col]; - scale.y = cache_scale[ith_col]; - HalfSub()(result_ptr + i, *reinterpret_cast(&bias)); - HalfMul()(result_ptr + i, *reinterpret_cast(&scale)); - } - } -} - -template -inline __device__ void gemm_qk_quant( - Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCsA, Tensor3 &tCrB, - Tensor4 const& sB, TiledMma tiled_mma, - ThrCopy0 smem_thr_copy_A, - TiledCopy0 smem_tiled_copy_A, - const int32_t tidx, - const T * cache_scale, const T * cache_zp) { - CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(acc)); - CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(acc)); - Tensor tCrA_copy_view = smem_thr_copy_A.retile_D(tCrA); - CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(tCrA_copy_view)); - if (!A_in_regs) { - copy(smem_tiled_copy_A, tCsA(_, _, _0{}), tCrA_copy_view(_, _, _0{})); - } - uint32_t *sBdata = reinterpret_cast(sB.data().get()) + tidx * (kDataNumPer2Byte / 4); - - #pragma unroll - for (int i = 0; i < size<2>(tCrA); ++i) { - if (i < size<2>(tCrA) - 1) { - if (!A_in_regs) { - copy(smem_tiled_copy_A, tCsA(_, _, i + 1), tCrA_copy_view(_, _, i + 1)); - } - } - if constexpr (kDataNumPer2Byte == 4) { - convert_c4_2_half(sBdata + i * kHeadDim, tCrB.data(), cache_scale + i * 4, cache_zp + i * 4); - } else { - convert_c8_2_half(sBdata + i * (kHeadDim * 2), tCrB.data(), cache_scale + i * 4, cache_zp + i * 4); - convert_c8_2_half(sBdata + i * (kHeadDim * 2) + 1, tCrB.data() + 4, cache_scale + i * 4, cache_zp + i * 4); - } - - cute::gemm(tiled_mma, tCrA(_, _, i), tCrB, acc); - } -} - -template -inline __device__ void gemm_value_quant( - Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCsA, Tensor3 &tCrB, - Tensor4 const& sB, TiledMma tiled_mma, - ThrCopy0 smem_thr_copy_A, - TiledCopy0 smem_tiled_copy_A, - int32_t tidx, - const T * cache_scale, const T * cache_zp) { - CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(acc)); - CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(acc)); - Tensor tCrA_copy_view = smem_thr_copy_A.retile_D(tCrA); - CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(tCrA_copy_view)); - if (!A_in_regs) { - copy(smem_tiled_copy_A, tCsA(_, _, _0{}), tCrA_copy_view(_, _, _0{})); - } - uint32_t *sBdata = reinterpret_cast(sB.data().get()) + tidx * (2 * 
kDataNumPer2Byte / 4); - - #pragma unroll - for (int i = 0; i < size<2>(tCrA); ++i) { - const int cur_idx = i * kHeadDim * (2 * kDataNumPer2Byte / 4); - - if (i < size<2>(tCrA) - 1) { - if (!A_in_regs) { - copy(smem_tiled_copy_A, tCsA(_, _, i + 1), tCrA_copy_view(_, _, i + 1)); - } - } - if constexpr (kDataNumPer2Byte == 4) { - convert_c4_2_half(sBdata + cur_idx, tCrB.data(), cache_scale, cache_zp); - convert_c4_2_half(sBdata + cur_idx + 1, tCrB.data() + 8, cache_scale + 2, cache_zp + 2); - } else { - convert_c8_2_half(sBdata + cur_idx, tCrB.data(), cache_scale, cache_zp); - convert_c8_2_half(sBdata + cur_idx + 1, tCrB.data() + 4, cache_scale + 1, cache_zp + 1); - convert_c8_2_half(sBdata + cur_idx + 2, tCrB.data() + 8, cache_scale + 2, cache_zp + 2); - convert_c8_2_half(sBdata + cur_idx + 3, tCrB.data() + 12, cache_scale + 3, cache_zp + 3); - } - cute::gemm(tiled_mma, tCrA(_, _, i), tCrB, acc); - } -} - - -template -inline __device__ void apply_mask(Tensor &scores, const uint32_t warp_id, const uint32_t col, const uint32_t reamin_seq_len) { - const int cols = size<1>(scores) / 2; - #pragma unroll - for (int mi = 0; mi < kMiLen; ++mi) { - #pragma unroll - for (int ni = 0; ni < cols; ++ni) { - const int col_index = warp_id * 8 + ni * 32 + col * 2; - if (col_index >= reamin_seq_len) { - scores(mi, ni * 2) = -INFINITY; - } - if (col_index + 1 >= reamin_seq_len) { - scores(mi, ni * 2 + 1) = -INFINITY; - } - } - } -} - - -template -struct Allreduce { - static_assert(THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4); - template - static __device__ inline T run(T x, Operator &op) { - constexpr int OFFSET = THREADS / 2; - x = op(x, __shfl_xor_sync(uint32_t(-1), x, OFFSET)); - return Allreduce::run(x, op); - } -}; - -template<> -struct Allreduce<2> { -template -static __device__ inline T run(T x, Operator &op) { - x = op(x, __shfl_xor_sync(uint32_t(-1), x, 1)); - return x; -} -}; - -template -__device__ inline void reduce_max(Tensor const& tensor, T *scores_max){ - static_assert(Layout0::rank == 2, "Only support 2D Tensor"); - MaxOp max_op; - #pragma unroll - for (int mi = 0; mi < kMiLen; ++mi) { - #pragma unroll - for (int ni = 0; ni < size<1>(tensor); ni++) { - scores_max[mi] = max_op(scores_max[mi], tensor(mi, ni)); - } - scores_max[mi] = Allreduce<4>::run(scores_max[mi], max_op); - } -} - -template -inline __device__ void scale_apply_exp2(Tensor &tensor, T const *max, T *sum, const float scale) { - static_assert(Layout0::rank == 2, "Only support 2D Tensor"); - #pragma unroll - for (int mi = 0; mi < kMiLen; ++mi) { - const float max_scaled = max[mi] * scale; - #pragma unroll - for (int ni = 0; ni < size<1>(tensor); ++ni) { - tensor(mi, ni) = expf(tensor(mi, ni) * scale - max_scaled); - sum[mi] += tensor(mi, ni); - } - } -} - - -template -struct cuteType; - -template <> -struct cuteType { - using type = cutlass::half_t; -}; - -template <> -struct cuteType { - using type = cutlass::bfloat16_t; -}; - -template -__forceinline__ __device__ auto float_2_half2(const float x) { - if constexpr (std::is_same::value) { - return __float2half2_rn(x); - } else { - return __float2bfloat162_rn(x); - } -} - - -struct uint16 { - uint4 u; - uint4 v; - uint4 s; - uint4 t; -}; - - -struct uint8 { - uint4 u; - uint4 v; -}; - -template -struct BytesToType {}; - -template<> -struct BytesToType<64> { - using Type = uint16; - static_assert(sizeof(Type) == 64); -}; - -template<> -struct BytesToType<32> { - using Type = uint8; - static_assert(sizeof(Type) == 32); -}; - -template<> -struct BytesToType<16> { - 
using Type = uint4; - static_assert(sizeof(Type) == 16); -}; - -template<> -struct BytesToType<8> { - using Type = uint64_t; - static_assert(sizeof(Type) == 8); -}; - -template<> -struct BytesToType<4> { - using Type = uint32_t; - static_assert(sizeof(Type) == 4); -}; - -template<> -struct BytesToType<2> { - using Type = uint16_t; - static_assert(sizeof(Type) == 2); -}; - -template<> -struct BytesToType<1> { - using Type = uint8_t; - static_assert(sizeof(Type) == 1); -}; - -template -struct Vec { - - enum { BYTES = NUM_ELT * sizeof(Elt_type) }; - - using Vec_type = typename BytesToType::Type; - - using Alias_type = union { - Vec_type vec; - Elt_type elt[NUM_ELT]; - }; - - Alias_type data; - - inline __device__ Vec() {} - - template - inline __device__ void to(Vec &other) { - #pragma unroll - for( int it = 0; it < NUM_ELT; it++ ) { - other.data.elt[it] = S(this->data.elt[it]); - } - } - - template - inline __device__ void assign(const Op &op) { - #pragma unroll - for( int it = 0; it < NUM_ELT; it++ ) { - this->data.elt[it] = op(it); - } - } - - inline __device__ void load_from(const void *base_ptr) { - this->data.vec = *reinterpret_cast(base_ptr); - } - - - inline __device__ void store_to(void *base_ptr) { - *reinterpret_cast(base_ptr) = this->data.vec; - } - - inline __device__ void add(const Vec &other) { - static_assert(NUM_ELT % 2 == 0); - using type = typename PackedHalf::Type; - #pragma unroll - for (int it = 0; it < NUM_ELT / 2; it++) { - type b = *reinterpret_cast(other.data.elt + it * 2); - *reinterpret_cast(this->data.elt + it * 2) += b; - } - } - - inline __device__ void set_zero() { - constexpr int size = sizeof(Vec_type) / sizeof(int); - #pragma unroll - for (int i = 0; i < size; ++i) { - (reinterpret_cast(this->data.elt))[i] = 0; - } - } - - inline __device__ void fma(const Vec &scale, const Vec &bias) { - static_assert(NUM_ELT % 2 == 0); - using type = typename PackedHalf::Type; - #pragma unroll - for (int it = 0; it < NUM_ELT / 2; it++) { - type a = *reinterpret_cast(scale.data.elt + it * 2); - type b = *reinterpret_cast(bias.data.elt + it * 2); - *reinterpret_cast(this->data.elt + it * 2) += a * b; - } - } -}; - -template -inline __device__ void apply_rotary_embedding(Vec& vec, Vec& cos, Vec& sin) { - static_assert(PackSize % 2 == 0); - #pragma unroll - for (int i = 0; i < PackSize / 2; i++) { - const float cos_inv_freq = cos.data.elt[i]; - const float sin_inv_freq = sin.data.elt[i]; - const float v1 = static_cast(vec.data.elt[2 * i]); - const float v2 = static_cast(vec.data.elt[2 * i + 1]); - vec.data.elt[2 * i] = static_cast(cos_inv_freq * v1 - sin_inv_freq * v2); - vec.data.elt[2 * i + 1] = static_cast(sin_inv_freq * v1 + cos_inv_freq * v2); - } -} - -template -__forceinline__ __device__ void copy( - TiledCopy tiled_copy, Tensor const &S, - Tensor &D, - Tensor const &identity_MN, - const int max_MN = 0) { - CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{}); - CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{}); - CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D)); // MMA - CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D)); // MMA_M - CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D)); // MMA_K - #pragma unroll - for (int m = 0; m < size<1>(S); ++m) { - if (Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_MN) { - #pragma unroll - for (int k = 0; k < size<2>(S); ++k) { - cute::copy(tiled_copy, S(_, m, k), D(_, m, k)); - } - } - } -} - -template -inline __device__ void gemm( - Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tensor3 const& tCsA, - Tensor4 const& tCsB, TiledMma tiled_mma, - ThrCopy0 
&smem_thr_copy_A, ThrCopy1 &smem_thr_copy_B, - TiledCopy0 &smem_tiled_copy_A, TiledCopy1 &smem_tiled_copy_B) { - CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(acc)); - CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(acc)); - CUTE_STATIC_ASSERT_V(size<2>(tCrA) == size<2>(tCrB)); - Tensor tCrA_copy_view = smem_thr_copy_A.retile_D(tCrA); - CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(tCrA_copy_view)); - Tensor tCrB_copy_view = smem_thr_copy_B.retile_D(tCrB); - CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<1>(tCrB_copy_view)); - - if (!A_in_regs) { copy(smem_tiled_copy_A, tCsA(_, _, _0{}), tCrA_copy_view(_, _, _0{})); } - if (!B_in_regs) { copy(smem_tiled_copy_B, tCsB(_, _, _0{}), tCrB_copy_view(_, _, _0{})); } - - #pragma unroll - for (int i = 0; i < size<2>(tCrA); ++i) { - if (i < size<2>(tCrA) - 1) { - if (!A_in_regs) { copy(smem_tiled_copy_A, tCsA(_, _, i + 1), tCrA_copy_view(_, _, i + 1)); } - if (!B_in_regs) { copy(smem_tiled_copy_B, tCsB(_, _, i + 1), tCrB_copy_view(_, _, i + 1)); } - } - cute::gemm(tiled_mma, tCrA(_, _, i), tCrB(_, _, i), acc); - } -} - -template -inline __device__ auto convert_type(Tensor const &tensor) { - using From_type = typename Engine::value_type; - constexpr int numel = decltype(size(tensor))::value; - cutlass::NumericArrayConverter convert_op; - auto frag = convert_op(*reinterpret_cast *>(tensor.data())); - return make_tensor(make_rmem_ptr(&frag), tensor.layout()); -} - -template -__inline__ __device__ T BlockAllReduce(T val) { - typedef cub::BlockReduce BlockReduce; - __shared__ typename BlockReduce::TempStorage temp_storage; - __shared__ T result_broadcast; - T result = BlockReduce(temp_storage).Reduce(val, ReductionOp()); - if (threadIdx.x == 0) { result_broadcast = result; } - __syncthreads(); - return result_broadcast; -} - -template -__forceinline__ __device__ auto convert_layout_acc_Aregs(Layout acc_layout) { - using X = Underscore; - if constexpr (decltype(rank<0>(acc_layout))::value == 3) { // SM90 - static_assert(decltype(size<0, 0>(acc_layout))::value == 2); - static_assert(decltype(size<0, 1>(acc_layout))::value == 2); - static_assert(decltype(rank(acc_layout))::value == 3); - static_assert(decltype(rank(get<0>(acc_layout)))::value == 3); - auto l = logical_divide(get<0>(acc_layout), Shape{}); // (2, 2, (2, N / 16))) - return make_layout(make_layout(get<0>(l), get<1>(l), get<2, 0>(l)), get<1>(acc_layout), make_layout(get<2, 1>(l), get<2>(acc_layout))); - } else { // SM80 - static_assert(decltype(size<0>(acc_layout))::value == 4); - static_assert(decltype(rank(acc_layout))::value == 3); - constexpr int mma_shape_K = get<2>(typename MMA_traits::Shape_MNK{}); - static_assert(mma_shape_K == 8 || mma_shape_K == 16); - if constexpr (mma_shape_K == 8) { - return acc_layout; - } else { - auto l = logical_divide(acc_layout, Shape{}); // (4, MMA_M, (2, MMA_N / 2))) - return make_layout(make_layout(get<0>(l), get<2, 0>(l)), get<1>(l), get<2, 1>(l)); - } - } -}; - -template -__forceinline__ __device__ void gemm(TiledMma &tiled_mma, Tensor0 const &tCrA, Tensor1 const &tCrB, Tensor2 &tCrC) { - constexpr bool Is_RS = !cute::is_base_of::value; - if constexpr (Is_RS) { warpgroup_fence_operand(const_cast(tCrA)); } - warpgroup_fence_operand(tCrC); - if constexpr (arrive) { - warpgroup_arrive(); - } - if constexpr (zero_init) { - tiled_mma.accumulate_ = GMMA::ScaleOut::Zero; - CUTLASS_PRAGMA_UNROLL - for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { - cute::gemm(tiled_mma, tCrA(_,_,k_block), tCrB(_,_,k_block), tCrC); - tiled_mma.accumulate_ = 
GMMA::ScaleOut::One; - } - } else { - CUTLASS_PRAGMA_UNROLL - for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { - cute::gemm(tiled_mma, tCrA(_,_,k_block), tCrB(_,_,k_block), tCrC); - tiled_mma.accumulate_ = GMMA::ScaleOut::One; - } - } - if constexpr (commit) { - warpgroup_commit_batch(); - } - if constexpr (wg_wait >= 0) { warpgroup_wait(); } - warpgroup_fence_operand(tCrC); - if constexpr (Is_RS) { warpgroup_fence_operand(const_cast(tCrA)); } -} - - -template -__forceinline__ __device__ auto convert_layout_acc_rowcol(Layout acc_layout) { - if constexpr (decltype(rank<0>(acc_layout))::value == 3) { // SM90 - static_assert(decltype(size<0, 0>(acc_layout))::value == 2); - static_assert(decltype(size<0, 1>(acc_layout))::value == 2); - static_assert(decltype(rank(acc_layout))::value == 3); - auto l = acc_layout; - return make_layout(make_layout(get<0, 1>(l), get<1>(l)), make_layout(get<0, 0>(l), get<0, 2>(l), get<2>(l))); - } else { // SM80 - static_assert(decltype(size<0>(acc_layout))::value == 4); - static_assert(decltype(rank(acc_layout))::value == 3); - auto l = logical_divide(acc_layout, Shape<_2>{}); // ((2, 2), MMA_M, MMA_N) - return make_layout(make_layout(get<0, 1>(l), get<1>(l)), make_layout(get<0, 0>(l), get<2>(l))); - } -}; - -template -__inline__ __device__ T WarpAllReduce(T val) { - ReductionOp op; - #pragma unroll - for (int mask = thread_group_width / 2; mask > 0; mask /= 2) { - val = op(val, __shfl_xor_sync(0xffffffff, val, mask)); - } - return val; -} - - -template -__device__ inline int get_data_count(const float * src, const float limit_value) { - int count = 0; - #pragma unroll - for (int i = 0; i < kPackSize; i++) { - if (src[i] >= limit_value) { - count++; - } - } - count = BlockAllReduce, knthreads>(count); - return count; -} diff --git a/custom_ops/gpu_ops/moba_attn/moba_decoder_attn/moba_decoder_attn.cu b/custom_ops/gpu_ops/moba_attn/moba_decoder_attn/moba_decoder_attn.cu deleted file mode 100644 index 82af1586a..000000000 --- a/custom_ops/gpu_ops/moba_attn/moba_decoder_attn/moba_decoder_attn.cu +++ /dev/null @@ -1,802 +0,0 @@ -// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#pragma once -#include "paddle/extension.h" -#include "moba_decoder_attn_kernel.h" -#include "moba_attn/moba_attn.h" - - -template -inline __device__ void softmax_rescale_o(Tensor0 &scores, Tensor1 &acc_o, const T *scores_max, const T *scores_max_prev, T * scores_sum, const float softmax_scale) { - if (Is_first) { - scale_apply_exp2(scores, scores_max, scores_sum, softmax_scale); - } else { - Tensor acc_o_rowcol = make_tensor(acc_o.data(), convert_layout_acc_rowcol(acc_o.layout())); - #pragma unroll - for (int mi = 0; mi < kMiLen; ++mi) { - const float scores_scale = expf((scores_max_prev[mi] - scores_max[mi]) * softmax_scale); - scores_sum[mi] *= scores_scale; - #pragma unroll - for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { - acc_o_rowcol(mi, ni) *= scores_scale; - } - } - scale_apply_exp2(scores, scores_max, scores_sum, softmax_scale); - } -}; - -template -__global__ __launch_bounds__(Kernel_traits::kNThreads) void moba_decoder_attention_kernel(ParamType params) { - using cuteType = typename Kernel_traits::cuteType; - using ElementAccum = typename Kernel_traits::ElementAccum; - using CacheKV_traits = typename Kernel_traits::CacheKV_traits; - constexpr int32_t kHeadDim = Kernel_traits::kHeadDim; - constexpr int32_t kHeadDimKV = Kernel_traits::kHeadDimKV; - constexpr int32_t kBlockM = Kernel_traits::kBlockM; - constexpr int32_t kBlockSize = Kernel_traits::kBlockSize; - constexpr int32_t kGqaGroupSize = Kernel_traits::kGqaGroupSize; - constexpr int32_t kNWarps = Kernel_traits::kNWarps; - constexpr int32_t kTileN = Kernel_traits::kTileN; - constexpr int32_t kBlockN = kTileN * kBlockSize; - constexpr int32_t kDataBits = Kernel_traits::kDataBits; - constexpr int32_t kMiLen = (kGqaGroupSize + 7) / 8; - - const int32_t bi = blockIdx.y; - const int32_t tidx = threadIdx.x; - const int32_t partition_idx = blockIdx.x; - const int32_t kv_head_idx = blockIdx.z; - const int32_t q_head_idx = kv_head_idx * kGqaGroupSize; - - const int32_t seq_len = params.seq_lens_decoder[bi] == 0 ? 
0 : params.seq_lens_decoder[bi] + 1; - - const int32_t head_num = params.head_num; - const int32_t kv_head_num = params.kv_head_num; - - const int32_t partition_num = (seq_len + kBlockN - 1) / kBlockN; - - if (seq_len == 0 || partition_idx >= partition_num) { - return; - } - - if (seq_len >= params.use_moba_seq_limit && params.qk_gate_topk_idx_ptr[(bi * kv_head_num + kv_head_idx) * Kernel_traits::kMaxN + partition_idx] == 0) { - return; - } - - - const int q_bias_offset = q_head_idx * kHeadDim; - - cuteType * q_input = reinterpret_cast(params.q_input) + params.cu_seq_q[bi] * head_num * kHeadDim; - - Tensor gQ = make_tensor( - make_gmem_ptr(reinterpret_cast(q_input) + q_bias_offset), - Shape, Int>{}, - Stride, _1>{}); - - const int32_t block_idx = partition_idx * kTileN; - const int* block_table = params.block_table + bi * params.max_num_blocks_per_seq + block_idx; - const int32_t physical_block_number = block_table[0]; - - const int32_t cache_offset = (physical_block_number * kv_head_num + kv_head_idx) * kBlockSize * kHeadDimKV; - - Tensor gK = make_tensor( - make_gmem_ptr(reinterpret_cast(params.cache_k) + cache_offset), - Shape, Int>{}, - Stride, _1>{}); - - Tensor gV = make_tensor( - make_gmem_ptr(reinterpret_cast(params.cache_v) + cache_offset), - Shape, Int>{}, - Stride, _1>{}); - - extern __shared__ char smem_[]; - Tensor sQ = make_tensor( - make_smem_ptr(reinterpret_cast(smem_)), - typename Kernel_traits::SmemLayoutQ{}); - Tensor sQK = make_tensor( - sQ.data() + size(sQ), - typename Kernel_traits::SmemLayoutQK{}); - - Tensor sK = make_tensor(sQK.data() + size(sQK), typename Kernel_traits::SmemLayoutKV{}); - Tensor sV = make_tensor(sK.data() + size(sK), typename Kernel_traits::SmemLayoutKV{}); - Tensor sVt = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposed{}); - Tensor sVtNoSwizzle = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposedNoSwizzle{}); - __shared__ ElementAccum scores_warp[kNWarps][kMiLen * kBlockM]; - - auto gmem_tiled_copy_Q = typename Kernel_traits::GmemTiledCopyQ{}; - auto gmem_thr_copy_Q = gmem_tiled_copy_Q.get_thread_slice(tidx); - - auto gmem_tiled_copy_KV = typename Kernel_traits::GmemTiledCopyKV{}; - auto gmem_thr_copy_KV = gmem_tiled_copy_KV.get_thread_slice(tidx); - - Tensor tQgQ = gmem_thr_copy_Q.partition_S(gQ); - Tensor tQsQ = gmem_thr_copy_Q.partition_D(sQ); - - Tensor tKgK = gmem_thr_copy_KV.partition_S(gK); - Tensor tKsK = gmem_thr_copy_KV.partition_D(sK); - Tensor tVgV = gmem_thr_copy_KV.partition_S(gV); - Tensor tVsV = gmem_thr_copy_KV.partition_D(sV); - - Tensor cQ = make_identity_tensor(make_shape(kBlockM, kHeadDim)); - Tensor tQcQ = gmem_thr_copy_Q.partition_S(cQ); - - Tensor cKV = make_identity_tensor(make_shape(kBlockSize, kHeadDim)); - Tensor tKVcKV = gmem_thr_copy_KV.partition_S(cKV); - - typename Kernel_traits::TiledMma tiled_mma; - - auto thr_mma = tiled_mma.get_thread_slice(tidx); - using SmemCopyAtom = typename Kernel_traits::SmemCopyAtom; - auto smem_tiled_copy_Q = make_tiled_copy_A(SmemCopyAtom{}, tiled_mma); - auto smem_thr_copy_Q = smem_tiled_copy_Q.get_thread_slice(tidx); - auto smem_tiled_copy_K = make_tiled_copy_B(SmemCopyAtom{}, tiled_mma); - auto smem_thr_copy_K = smem_tiled_copy_K.get_thread_slice(tidx); - auto smem_tiled_copy_V = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtomTransposed{}, tiled_mma); - auto smem_thr_copy_V = smem_tiled_copy_V.get_thread_slice(tidx); - auto smem_tiled_copy_O = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomO{}, tiled_mma); - auto smem_thr_copy_O = 
smem_tiled_copy_O.get_thread_slice(tidx); - - Tensor tSsQ = smem_thr_copy_Q.partition_S(sQ); - Tensor tSrQ = thr_mma.partition_fragment_A(sQ); - - Tensor tSsQK = smem_thr_copy_Q.partition_S(sQK); - Tensor tSrQK = thr_mma.partition_fragment_A(sQK); - - Tensor tSsK = smem_thr_copy_K.partition_S(sK); - Tensor tSrK = thr_mma.partition_fragment_B(sK); - Tensor tOsVt = smem_thr_copy_V.partition_S(sVt); - Tensor tOrVt = thr_mma.partition_fragment_B(sVtNoSwizzle); - - copy(gmem_tiled_copy_Q, tQgQ, tQsQ, tQcQ, kGqaGroupSize); - - - cute::cp_async_fence(); - cp_async_wait<0>(); - - const int32_t remain_seq_len = seq_len - partition_idx * kTileN * kBlockSize; - - copy(gmem_tiled_copy_KV, tKgK, tKsK, tKVcKV); - - cute::cp_async_fence(); - - const int32_t warp_id = tidx / 32; - const int32_t lane_id = tidx % 32; - const int32_t row = lane_id / 4; - const int32_t col = lane_id % 4; - const int row_idx = tidx / 4; - - using scale_k_vec = Vec; - using scale_v_vec = Vec; - - scale_k_vec scale_k; - scale_k_vec zp_k; - scale_v_vec scale_v; - scale_v_vec zp_v; - if constexpr (kDataBits == 4) { - scale_k = *reinterpret_cast(params.cache_k_dequant_scale + kv_head_idx * kHeadDim + col * 32); - zp_k = *reinterpret_cast(params.cache_k_zp + kv_head_idx * kHeadDim + col * 32); - scale_v = *reinterpret_cast(params.cache_v_dequant_scale + kv_head_idx * kHeadDim + row_idx * 4); - zp_v = *reinterpret_cast(params.cache_v_zp + kv_head_idx * kHeadDim + row_idx * 4); - } - - Tensor acc_o = partition_fragment_C(tiled_mma, Shape, Int>{}); - clear(acc_o); - Tensor acc_s = partition_fragment_C(tiled_mma, Shape, Int>{}); - - ElementAccum scores_max[kMiLen]; - ElementAccum scores_max_prev[kMiLen]; - ElementAccum scores_sum[kMiLen]; - - #pragma unroll - for (int mi = 0; mi < kMiLen; ++mi) { - scores_max[mi] = -INFINITY; - scores_sum[mi] = 0; - } - - const int cache_offset_step = kv_head_num * kBlockSize * kHeadDimKV; - - #pragma unroll - for (int n = 0; n < kTileN; ++n) { - const int cur_remain_seq_len = remain_seq_len - n * kBlockSize; - - if (cur_remain_seq_len <= 0) { - break; - } - - clear(acc_s); - cp_async_wait<0>(); - __syncthreads(); - - if (n > 0) { - tVgV.data() = tVgV.data() + (block_table[n] - block_table[n - 1]) * cache_offset_step; - } - - copy(gmem_tiled_copy_KV, tVgV, tVsV, tKVcKV); - - cute::cp_async_fence(); - - if constexpr (kDataBits == 16) { - if (n == 0) { - gemm(acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_thr_copy_Q, smem_thr_copy_K, smem_tiled_copy_Q, smem_tiled_copy_K); - } else { - gemm(acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_thr_copy_Q, smem_thr_copy_K, smem_tiled_copy_Q, smem_tiled_copy_K); - } - } else { - Tensor tSrKQuant = make_tensor( - Layout< - Shape, Int>, - Stride, _4>>{}); - if (n == 0) { - gemm_qk_quant(acc_s, tSrQ, tSsQ, tSrKQuant, sK, tiled_mma, smem_thr_copy_Q, smem_tiled_copy_Q, tidx, scale_k.data.elt, zp_k.data.elt); - } else { - gemm_qk_quant(acc_s, tSrQ, tSsQ, tSrKQuant, sK, tiled_mma, smem_thr_copy_Q, smem_tiled_copy_Q, tidx, scale_k.data.elt, zp_k.data.elt); - } - } - Tensor scores = make_tensor(acc_s.data(), convert_layout_acc_rowcol(acc_s.layout())); - - if (partition_idx == partition_num - 1 && cur_remain_seq_len < kBlockSize) { - apply_mask(scores, warp_id, col, cur_remain_seq_len); - } - - #pragma unroll - for (int mi = 0; mi < kMiLen; ++mi) { - scores_max_prev[mi] = scores_max[mi]; - } - - reduce_max(scores, scores_max); - - if (col == 0) { - scores_warp[warp_id][row] = scores_max[0]; - if constexpr (kMiLen > 1) { - scores_warp[warp_id][row + 8] = scores_max[1]; - } - 
} - - __syncthreads(); - - MaxOp max_op; - - if (tidx < kGqaGroupSize) { - float cur_max = scores_warp[0][tidx]; - #pragma unroll - for (uint32_t i = 1; i < kNWarps; ++i) { - cur_max = max_op(scores_warp[i][tidx], cur_max); - } - scores_warp[0][tidx] = cur_max; - } - - cp_async_wait<0>(); - __syncthreads(); - - if (cur_remain_seq_len > kBlockSize && n < kTileN - 1) { - tKgK.data() = tKgK.data() + (block_table[n + 1] - block_table[n]) * cache_offset_step; - copy(gmem_tiled_copy_KV, tKgK, tKsK, tKVcKV); - cute::cp_async_fence(); - } - - #pragma unroll - for (int mi = 0; mi < kMiLen; ++mi) { - scores_max[mi] = scores_warp[0][row + mi * 8]; - } - - if (n == 0) { - softmax_rescale_o(scores, acc_o, scores_max, scores_max_prev, scores_sum, params.inv_sqrt_dh); - } else { - softmax_rescale_o(scores, acc_o, scores_max, scores_max_prev, scores_sum, params.inv_sqrt_dh); - } - - Tensor rS = convert_type(acc_s); - - Tensor trQK = smem_thr_copy_O.retile_S(rS); - Tensor tsQK = smem_thr_copy_O.partition_D(sQK); - cute::copy(smem_tiled_copy_O, trQK, tsQK); - - __syncthreads(); - - if constexpr (kDataBits == 16) { - gemm(acc_o, tSrQK, tOrVt, tSsQK, tOsVt, tiled_mma, smem_thr_copy_Q, smem_thr_copy_V, smem_tiled_copy_Q, smem_tiled_copy_V); - } else { - Tensor tSrVQuant = make_tensor( - Layout< - Shape<_4, Shape<_2, _2>>, - Stride<_1, Shape<_4, _8>>>{}); - gemm_value_quant(acc_o, tSrQK, tSsQK, tSrVQuant, sV, tiled_mma, smem_thr_copy_Q, smem_tiled_copy_Q, tidx, scale_v.data.elt, zp_v.data.elt); - } - } - - const uint32_t pack_max_partition_num = (params.max_num_partitions + 3) / 4 * 4; - uint32_t max_sum_offset = bi * pack_max_partition_num * head_num + (tidx + q_head_idx) * pack_max_partition_num + partition_idx; - - if (tidx < kGqaGroupSize) { - params.maxs[max_sum_offset] = scores_warp[0][tidx] * params.inv_sqrt_dh; - } - - SumOp sum_op; - #pragma unroll - for (int mi = 0; mi < kMiLen; ++mi) { - scores_sum[mi] = Allreduce<4>::run(scores_sum[mi], sum_op); - } - __syncthreads(); - - if (col == 0) { - scores_warp[warp_id][row] = scores_sum[0]; - if constexpr (kMiLen > 1) { - scores_warp[warp_id][row + 8] = scores_sum[1]; - } - } - - - Tensor rO = convert_type(acc_o); - Tensor taccOrO = smem_thr_copy_O.retile_S(rO); - Tensor taccOsO = smem_thr_copy_O.partition_D(sQ); - - cute::copy(smem_tiled_copy_O, taccOrO, taccOsO); - - __syncthreads(); - - if (tidx < kGqaGroupSize) { - float cur_sum = scores_warp[0][tidx]; - #pragma unroll - for (uint32_t i = 1; i < kNWarps; ++i) { - cur_sum = sum_op(scores_warp[i][tidx], cur_sum); - } - scores_warp[0][tidx] = cur_sum; - } - - Tensor gO = make_tensor( - make_gmem_ptr(reinterpret_cast(params.partition_attn_out) + ((bi * params.max_num_partitions + partition_idx) * head_num + q_head_idx)* kHeadDim), - Shape, Int>{}, - Stride, _1>{}); - - auto gmem_tiled_copy_O = typename Kernel_traits::GmemTiledCopyO{}; - auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(tidx); - Tensor tOsO = gmem_thr_copy_O.partition_S(sQ); - Tensor tOgO = gmem_thr_copy_O.partition_D(gO); - constexpr int32_t copy_size = kGqaGroupSize * 16; - __syncthreads(); - - if (tidx < copy_size) { - cute::copy(gmem_tiled_copy_O, tOsO(_, 0, _), tOgO(_, 0, _)); - } - - if constexpr (kMiLen > 1) { - if (tidx < copy_size - 128) { - cute::copy(gmem_tiled_copy_O, tOsO(_, 1, _), tOgO(_, 1, _)); - } - } - - if (tidx < kGqaGroupSize) { - params.sums[max_sum_offset] = scores_warp[0][tidx]; - } -} - - -template -inline __device__ float caluate_logit_scale(const int partition_num, const int pack_max_partition_num, ParamType 
¶ms, char * shared_mem, const int seq_len, const int *qk_gate_topk_idx_ptr) { - constexpr int32_t kNFloatPacksize = 16 / sizeof(float); - constexpr int32_t kNReduceThreads = Kernel_traits::kNReduceThreads; - const int32_t bi = blockIdx.z; - const int32_t tidx = threadIdx.x; - const int32_t head_idx = blockIdx.y; - const int32_t head_num = params.head_num; - - using float_vec = Vec; - const int32_t offset = bi * head_num * pack_max_partition_num + head_idx * pack_max_partition_num; - - float* shared_max_logits = reinterpret_cast(shared_mem); - const float* max_logits_ptr = params.maxs + offset; - float global_max_logit = -FLT_MAX; - - int32_t idx = tidx * kNFloatPacksize; - #pragma unroll - for (; idx <= partition_num - kNFloatPacksize; idx += kNReduceThreads * kNFloatPacksize) { - float_vec cur_max = *reinterpret_cast(max_logits_ptr + idx); - #pragma unroll - for (int32_t j = 0; j < kNFloatPacksize; ++j) { - if (seq_len >= params.use_moba_seq_limit) { - if (qk_gate_topk_idx_ptr[idx + j] != 0) { - global_max_logit = fmaxf(global_max_logit, cur_max.data.elt[j]); - } - } else { - global_max_logit = fmaxf(global_max_logit, cur_max.data.elt[j]); - } - } - cur_max.store_to(shared_max_logits + idx); - } - - const int32_t packed_data_num = partition_num / kNFloatPacksize * kNFloatPacksize; - - idx = packed_data_num + tidx; - #pragma unroll - for (; idx < partition_num; idx += kNReduceThreads) { - if (seq_len >= params.use_moba_seq_limit) { - if (qk_gate_topk_idx_ptr[idx] != 0) { - float cur_max = max_logits_ptr[idx]; - global_max_logit = fmaxf(global_max_logit, cur_max); - shared_max_logits[idx] = cur_max; - } - } else { - float cur_max = max_logits_ptr[idx]; - global_max_logit = fmaxf(global_max_logit, cur_max); - shared_max_logits[idx] = cur_max; - } - } - __syncthreads(); - - global_max_logit = BlockAllReduce, kNReduceThreads>(global_max_logit); - - float* share_sum_scale = reinterpret_cast(shared_mem + sizeof(float) * pack_max_partition_num); - const float* exp_sums_ptr = params.sums + offset; - float global_exp_sum = 0.0f; - - idx = tidx * kNFloatPacksize; - #pragma unroll - for (; idx <= partition_num - kNFloatPacksize; idx += kNReduceThreads * kNFloatPacksize) { - float_vec share_max = *reinterpret_cast(shared_max_logits + idx); - #pragma unroll - for (int32_t j = 0; j < kNFloatPacksize; ++j) { - if (seq_len >= params.use_moba_seq_limit) { - if (qk_gate_topk_idx_ptr[idx + j] != 0) { - float exp_sub_max = expf(share_max.data.elt[j] - global_max_logit); - float rescaled_exp_sum = exp_sums_ptr[idx + j] * exp_sub_max; - global_exp_sum += rescaled_exp_sum; - share_max.data.elt[j] = exp_sub_max; - } - } else { - float exp_sub_max = expf(share_max.data.elt[j] - global_max_logit); - float rescaled_exp_sum = exp_sums_ptr[idx + j] * exp_sub_max; - global_exp_sum += rescaled_exp_sum; - share_max.data.elt[j] = exp_sub_max; - } - } - share_max.store_to(share_sum_scale + idx); - } - - idx = packed_data_num + tidx; - #pragma unroll - for (; idx < partition_num; idx += kNReduceThreads) { - if (seq_len >= params.use_moba_seq_limit) { - if (qk_gate_topk_idx_ptr[idx] != 0) { - float share_max = shared_max_logits[idx]; - float exp_sub_max = expf(share_max - global_max_logit); - float rescaled_exp_sum = exp_sums_ptr[idx] * exp_sub_max; - global_exp_sum += rescaled_exp_sum; - share_sum_scale[idx] = exp_sub_max; - } - } else { - float share_max = shared_max_logits[idx]; - float exp_sub_max = expf(share_max - global_max_logit); - float rescaled_exp_sum = exp_sums_ptr[idx] * exp_sub_max; - global_exp_sum += 
rescaled_exp_sum; - share_sum_scale[idx] = exp_sub_max; - } - } - __syncthreads(); - - global_exp_sum = BlockAllReduce, kNReduceThreads>(global_exp_sum); - - const float inv_global_exp_sum = fdividef(1.0f, global_exp_sum + 1e-6f); - return inv_global_exp_sum; -} - -template -__global__ void __launch_bounds__(Kernel_traits::kNReduceThreads) moba_decoder_attention_merge_kernel(ParamType params) { - using cuteType = typename Kernel_traits::cuteType; - constexpr int32_t kBlockN = Kernel_traits::kTileN * Kernel_traits::kBlockSize; - constexpr int32_t kNReducePacksize = 16 / sizeof(cuteType); - constexpr int32_t kNFloatPacksize = 16 / sizeof(float); - constexpr int32_t kNReduceWarps = Kernel_traits::kNReduceWarps; - constexpr int32_t kHeadDim = Kernel_traits::kHeadDim; - const int32_t bi = blockIdx.z; - const int32_t headdim_idx = kNReducePacksize * kNReduceWarps * blockIdx.x; - const int32_t tidx = threadIdx.x; - const int32_t head_idx = blockIdx.y; - const int32_t warp_id = tidx / 32; - const int32_t lane_id = tidx % 32; - const int32_t seq_len = params.seq_lens_decoder[bi] + 1; - const int32_t head_num = params.head_num; - using pack_half = typename PackedHalf::Type; - - - if (params.seq_lens_decoder[bi] == 0) { - return; - } - - extern __shared__ char shared_mem[]; - - const int32_t partition_num = (seq_len + kBlockN - 1) / kBlockN; - const int32_t pack_max_partition_num = (params.max_num_partitions + kNFloatPacksize - 1) / kNFloatPacksize * kNFloatPacksize; - - float* share_sum_scale = reinterpret_cast(shared_mem + sizeof(float) * pack_max_partition_num); - - constexpr int32_t kGqaGroupSize = Kernel_traits::kGqaGroupSize; - const int kv_head_idx = head_idx / Kernel_traits::kGqaGroupSize; - const int * qk_gate_topk_idx_ptr = params.qk_gate_topk_idx_ptr + (bi * params.kv_head_num + kv_head_idx) * Kernel_traits::kMaxN; - - float inv_global_exp_sum = caluate_logit_scale(partition_num, pack_max_partition_num, params, shared_mem, seq_len, qk_gate_topk_idx_ptr); - - - using T_vec = Vec; - - cuteType* partition_attn_out = reinterpret_cast(params.partition_attn_out) + bi * head_num * params.max_num_partitions * kHeadDim + head_idx * kHeadDim + headdim_idx; - - Vec acc; - acc.set_zero(); - #pragma unroll - for (int idx = lane_id; idx < partition_num; idx += 32) { - if (seq_len >= params.use_moba_seq_limit && qk_gate_topk_idx_ptr[idx] == 0) { - continue; - } - T_vec sub_logits = *reinterpret_cast(&partition_attn_out[idx * head_num * kHeadDim + warp_id * kNReducePacksize]); - float scale = share_sum_scale[idx]; - #pragma unroll - for (int k = 0; k < kNReducePacksize; ++k) { - acc.data.elt[k] += static_cast(sub_logits.data.elt[k]) * scale; - } - } - - __syncthreads(); - - T_vec out; - #pragma unroll - for (int k = 0; k < kNReducePacksize; ++k) { - out.data.elt[k] = static_cast(WarpAllReduce>(acc.data.elt[k]) * inv_global_exp_sum); - } - - const int ori_token_idx = params.cu_seq_q[bi]; - cuteType * attn_out = reinterpret_cast(params.attn_out) + ori_token_idx * head_num * kHeadDim + head_idx * kHeadDim + headdim_idx + warp_id * kNReducePacksize; - - if (lane_id == 0) { - out.store_to(attn_out); - } -} - - -template -void run_moba_decoder_attn(ParamType ¶ms, cudaStream_t stream) { - dim3 grid; - grid.x = params.max_num_partitions; - grid.y = params.batch_size; - grid.z = params.kv_head_num; - constexpr int smem_size = Kernel_traits::kShareMemSize; - constexpr auto kernel = &moba_decoder_attention_kernel; - if (smem_size >= 48 * 1024) { - cudaFuncSetAttribute(kernel, 
cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); - } - kernel<<>>(params); - - int32_t reduce_shared_mem_size = 2 * (params.max_num_partitions + 4) * sizeof(float); - constexpr int32_t pack_size = 16 / sizeof(typename Kernel_traits::cuteType); - static_assert(Kernel_traits::kHeadDim % pack_size == 0); - static_assert((Kernel_traits::kHeadDim / Kernel_traits::kNReduceWarps) % pack_size == 0); - grid.x = Kernel_traits::kHeadDim / Kernel_traits::kNReduceWarps / pack_size; - grid.y = params.head_num; - grid.z = params.batch_size; - auto reduce_kernel = &moba_decoder_attention_merge_kernel; - - if (reduce_shared_mem_size >= 48 * 1024) { - cudaFuncSetAttribute( - reduce_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, reduce_shared_mem_size); - } - reduce_kernel<<>>(params); -} - - -template -void run_moba_decoder_attn_hdim128(ParamType ¶ms, cudaStream_t stream) { - const int gqaGroupSize = params.head_num / params.kv_head_num; - using CacheKVTraits = CacheKV_quant_traits; - constexpr int kTileN = kBlockN / CacheKVTraits::kBlockSize; - switch (gqaGroupSize) { - case 12: { - run_moba_decoder_attn>(params, stream); - break; - } - case 8: { - run_moba_decoder_attn>(params, stream); - break; - } - case 7: { - run_moba_decoder_attn>(params, stream); - break; - } - case 6: { - run_moba_decoder_attn>(params, stream); - break; - } - case 5: { - run_moba_decoder_attn>(params, stream); - break; - } - case 4: { - run_moba_decoder_attn>(params, stream); - break; - } - default: { - PADDLE_THROW(phi::errors::Unimplemented( - "DecoderBlockAttention not implemented for gqaGroupSize = %d", gqaGroupSize)); - } - } -} - - -template -void DispatchMobaDecoderAttn( - const paddle::Tensor& q_input, - const paddle::Tensor& seq_len_encoder, - const paddle::Tensor& seq_len_decoder, - const paddle::Tensor& cu_seq_q, - const paddle::Tensor& cache_k, - const paddle::Tensor& cache_v, - const paddle::Tensor& block_tables, - const paddle::Tensor& k_block_means, - const paddle::Tensor& out, - const paddle::Tensor& qk_gate_topk_idx, - const paddle::optional& cache_k_quant_scale, - const paddle::optional& cache_v_quant_scale, - const paddle::optional& cache_k_dequant_scale, - const paddle::optional& cache_v_dequant_scale, - const paddle::optional& cache_k_zero_points, - const paddle::optional& cache_v_zero_points, - const int head_num, - const int kv_head_num, - const int head_dim, - const int max_seq_q, - const int max_seq_k, - const int batch_size, - const int max_input_length, - const int use_moba_seq_limit, - const std::string &cache_quant_type_str) { - - using cute_type = typename cuteType::type; - const int kMobaBlockSize = 128; - const int kMaxN = 1024; - - constexpr int max_seq_per_block = kMobaBlockSize; - moba_decoder_attn_params params; - memset(¶ms, 0, sizeof(params)); - const uint32_t max_num_partitions = (max_seq_k + max_seq_per_block) / max_seq_per_block; - assert(head_dim == 128); - - paddle::Tensor maxs = paddle::empty({batch_size, head_num, (max_num_partitions + 3) / 4 * 4}, paddle::DataType::FLOAT32, q_input.place()); - paddle::Tensor sums = paddle::empty({batch_size, head_num, (max_num_partitions + 3) / 4 * 4}, paddle::DataType::FLOAT32, q_input.place()); - paddle::Tensor partition_attn_out = paddle::empty({batch_size, max_num_partitions, head_num, head_dim}, q_input.dtype(), q_input.place()); - - params.q_input = reinterpret_cast(const_cast(q_input.data())); - params.attn_out = reinterpret_cast(const_cast(out.data())); - params.seq_lens_encoder = const_cast(seq_len_encoder.data()); - 
params.seq_lens_decoder = const_cast(seq_len_decoder.data()); - params.block_table = const_cast(block_tables.data()); - params.max_input_length = max_input_length; - params.head_num = head_num; - params.kv_head_num = kv_head_num; - params.max_num_blocks_per_seq = block_tables.dims()[1]; - params.batch_size = batch_size; - params.inv_sqrt_dh = 1.0f / std::sqrt(head_dim); - params.max_num_partitions = max_num_partitions; - params.maxs = reinterpret_cast(maxs.data()); - params.sums = reinterpret_cast(sums.data()); - params.partition_attn_out = reinterpret_cast(partition_attn_out.data()); - params.qk_gate_topk_idx_ptr = const_cast(qk_gate_topk_idx.data()); - params.use_moba_seq_limit = use_moba_seq_limit; - params.cu_seq_q = const_cast(cu_seq_q.data()); - - - if (cache_quant_type_str == "none") { - params.cache_k = reinterpret_cast(const_cast(cache_k.data())); - params.cache_v = reinterpret_cast(const_cast(cache_v.data())); - run_moba_decoder_attn_hdim128(params, q_input.stream()); - } else { - params.cache_k = const_cast(cache_k.data()); - params.cache_v = const_cast(cache_v.data()); - params.cache_k_quant_scale = reinterpret_cast(const_cast(cache_k_quant_scale.get().data())); - params.cache_v_quant_scale = reinterpret_cast(const_cast(cache_v_quant_scale.get().data())); - params.cache_k_dequant_scale = reinterpret_cast(const_cast(cache_k_dequant_scale.get().data())); - params.cache_v_dequant_scale = reinterpret_cast(const_cast(cache_v_dequant_scale.get().data())); - params.cache_k_zp = reinterpret_cast(const_cast(cache_k_zero_points.get().data())); - params.cache_v_zp = reinterpret_cast(const_cast(cache_v_zero_points.get().data())); - if (cache_quant_type_str == "cache_int8_zp") { - run_moba_decoder_attn_hdim128(params, q_input.stream()); - } else if (cache_quant_type_str == "cache_int4_zp") { - run_moba_decoder_attn_hdim128(params, q_input.stream()); - } else { - PADDLE_THROW(phi::errors::Unimplemented( - "GQA Attention not implemented for cache_quant_type_str = %s", cache_quant_type_str.c_str())); - } - } -} - -void MobaDecoderAttn( - const paddle::Tensor& q_input, - const paddle::Tensor& seq_len_encoder, - const paddle::Tensor& seq_len_decoder, - const paddle::Tensor& cu_seq_q, - const paddle::Tensor& cache_k, - const paddle::Tensor& cache_v, - const paddle::Tensor& block_tables, - const paddle::Tensor& k_block_means, - const paddle::Tensor& out, - const paddle::Tensor& qk_gate_topk_idx, - const paddle::optional& cache_k_quant_scale, - const paddle::optional& cache_v_quant_scale, - const paddle::optional& cache_k_dequant_scale, - const paddle::optional& cache_v_dequant_scale, - const paddle::optional& cache_k_zero_points, - const paddle::optional& cache_v_zero_points, - const int head_num, - const int kv_head_num, - const int head_dim, - const int max_input_length, - const int use_moba_seq_limit, - const int max_seq_q, - const int max_seq_k, - const std::string &cache_quant_type_str) { - - const int batch_size = block_tables.dims()[0]; - if (q_input.dtype() == paddle::DataType::FLOAT16) { - return DispatchMobaDecoderAttn( - q_input, - seq_len_encoder, - seq_len_decoder, - cu_seq_q, - cache_k, - cache_v, - block_tables, - k_block_means, - out, - qk_gate_topk_idx, - cache_k_quant_scale, - cache_v_quant_scale, - cache_k_dequant_scale, - cache_v_dequant_scale, - cache_k_zero_points, - cache_v_zero_points, - head_num, - kv_head_num, - head_dim, - max_seq_q, - max_seq_k, - batch_size, - max_input_length, - use_moba_seq_limit, - cache_quant_type_str); - } else if (q_input.dtype() == 
paddle::DataType::BFLOAT16) { - return DispatchMobaDecoderAttn( - q_input, - seq_len_encoder, - seq_len_decoder, - cu_seq_q, - cache_k, - cache_v, - block_tables, - k_block_means, - out, - qk_gate_topk_idx, - cache_k_quant_scale, - cache_v_quant_scale, - cache_k_dequant_scale, - cache_v_dequant_scale, - cache_k_zero_points, - cache_v_zero_points, - head_num, - kv_head_num, - head_dim, - max_seq_q, - max_seq_k, - batch_size, - max_input_length, - use_moba_seq_limit, - cache_quant_type_str); - } -} diff --git a/custom_ops/gpu_ops/moba_attn/moba_decoder_attn/moba_decoder_attn_kernel.h b/custom_ops/gpu_ops/moba_attn/moba_decoder_attn/moba_decoder_attn_kernel.h deleted file mode 100644 index 1b21f9b32..000000000 --- a/custom_ops/gpu_ops/moba_attn/moba_decoder_attn/moba_decoder_attn_kernel.h +++ /dev/null @@ -1,225 +0,0 @@ -// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once -#include "paddle/extension.h" -#include "cute/tensor.hpp" -#include "cute/algorithm/copy.hpp" -#include "cute/algorithm/gemm.hpp" -#include "../moba_attn_utils.hpp" - -using namespace cute; -template -struct moba_decoder_attn_params { - T *__restrict__ q_input; - void *__restrict__ cache_k; - void *__restrict__ cache_v; - - T *__restrict__ attn_out; - T *__restrict__ partition_attn_out; - T *__restrict__ cache_k_dequant_scale; - T *__restrict__ cache_v_dequant_scale; - T *__restrict__ cache_k_quant_scale; - T *__restrict__ cache_v_quant_scale; - T *__restrict__ cache_k_zp; - T *__restrict__ cache_v_zp; - int * __restrict__ cu_seq_q; - float * sums; - float * maxs; - int * seq_lens_encoder; - int * seq_lens_decoder; - int * block_table; - int max_input_length; - int max_seq_len; - int head_num; - int kv_head_num; - int max_num_blocks_per_seq; - float scale_softmax; - int batch_size; - int max_num_partitions; - float inv_sqrt_dh; - int *qk_gate_topk_idx_ptr; - int use_moba_seq_limit; -}; - -template -struct CacheKV_quant_traits { - using cuteType = cute_type_; - static constexpr int kDataBits = DataBits_; - static constexpr int kBlockSize = 64; - static constexpr int kHeadDim = 128; - static constexpr int kBlockKSmem = 64; - using SmemLayoutAtomQ = decltype( - composition(Swizzle<3, 3, 3>{}, - Layout< - Shape, Int>, - Stride, _1>>{})); - - using SmemLayoutKV = decltype(tile_to_shape( - SmemLayoutAtomQ{}, - Shape, Int>{})); - - static constexpr int kNWarps = 4; - static constexpr int kNThreads = kNWarps * 32; - - - static constexpr int kThreadPerValue = 16 / sizeof(cuteType); - static constexpr int kThreadsPerRow = kHeadDim / kThreadPerValue; - - using GmemLayoutAtom = Layout< - Shape , Int>, - Stride, _1>>; - - using GmemTiledCopyQ = decltype( - make_tiled_copy(Copy_Atom< - SM80_CP_ASYNC_CACHEGLOBAL, cuteType>{}, - GmemLayoutAtom{}, - Layout>>{})); - - using MMA_Atom_Arch = std::conditional_t< - std::is_same_v, - MMA_Atom, - MMA_Atom - >; - - using ValLayoutMNK = Layout>; - - using PermutationMNK = Tile<_16, Int<16 * kNWarps>, _16>; - - using 
TiledMma = TiledMMA< - MMA_Atom_Arch, - ValLayoutMNK, - PermutationMNK>; - - using SmemCopyAtom = Copy_Atom; - - using SmemLayoutAtomVtransposed = decltype( - composition(Swizzle<3, 3, 3>{}, - Layout, Int>, - Stride<_1, Int>>{})); - - using SmemLayoutVtransposed = decltype(tile_to_shape( - SmemLayoutAtomVtransposed{}, - Shape, Int>{})); - - using SmemLayoutVtransposedNoSwizzle = decltype(get_nonswizzle_portion(SmemLayoutVtransposed{})); - - using SmemCopyAtomTransposed = Copy_Atom; - - static constexpr int kShareMemSize = size(SmemLayoutKV{}) * 2 * sizeof(cuteType); -}; - -template -struct moba_decoder_attn_kernel_traits { - using ElementAccum = float; - using CacheKV_traits = CacheKV_traits_; - using cuteType = typename CacheKV_traits::cuteType; - static constexpr int kDataBits = CacheKV_traits::kDataBits; - static constexpr int kTileN = kTileN_; - static constexpr int kMaxN = kMaxN_; - static constexpr int kGqaGroupSize = kGqaGroupSize_; - static constexpr int kHeadDim = CacheKV_traits::kHeadDim; - static constexpr int kHeadDimKV = kHeadDim / (16 / kDataBits); - static constexpr int kMinGemmM = 16; - static constexpr int kBlockM = (kGqaGroupSize + kMinGemmM - 1) / kMinGemmM * kMinGemmM; - static constexpr int kBlockSize = CacheKV_traits::kBlockSize; - static_assert(kGqaGroupSize <= 16); - static constexpr int32_t kNWarps = CacheKV_traits::kNWarps; - - static constexpr int kBlockKSmem = CacheKV_traits::kBlockKSmem; - static constexpr int kBlockKVSmem = kHeadDimKV <= 64 ? kHeadDimKV : 64; - static_assert(kHeadDim % kBlockKSmem == 0); - static constexpr int kNReduceWarps = 4; - static constexpr int kNReduceThreads = kNReduceWarps * 32; - - - using SmemLayoutAtomQ = typename CacheKV_traits::SmemLayoutAtomQ; - - using SmemLayoutQ = decltype(tile_to_shape( - SmemLayoutAtomQ{}, - Shape, Int>{})); - - using SmemLayoutQK = decltype(tile_to_shape( - SmemLayoutAtomQ{}, - Shape, Int>{})); - - using SmemLayoutAtomKV = decltype( - composition(Swizzle<3, 3, 3>{}, - Layout< - Shape, Int>, - Stride, _1>>{})); - - using SmemLayoutKV_ = decltype(tile_to_shape( - SmemLayoutAtomKV{}, - Shape, Int>{})); - - using SmemLayoutKV = std::conditional_t< - kDataBits == 16, - SmemLayoutKV_, - decltype(get_nonswizzle_portion(SmemLayoutKV_{})) - >; - - constexpr static int kBlockKVSize = kDataBits == 4 ? 
32 : kBlockSize; - using SmemLayoutAtomVtransposed = decltype( - composition(Swizzle<3, 3, 3>{}, - Layout, Int>, - Stride<_1, Int>>{})); - - using SmemLayoutVtransposed = decltype(tile_to_shape( - SmemLayoutAtomVtransposed{}, - Shape, Int>{})); - - using SmemLayoutVtransposedNoSwizzle = decltype(get_nonswizzle_portion(SmemLayoutVtransposed{})); - - static constexpr int kThreadsPerRow = CacheKV_traits::kThreadsPerRow; - static constexpr int kThreadsKVPerRow = kThreadsPerRow / (16 / kDataBits); - static constexpr int kNThreads = CacheKV_traits::kNThreads; - - using GmemKVLayoutAtom = Layout< - Shape, Int>, - Stride, _1>>; - - using SmemCopyAtom = typename CacheKV_traits::SmemCopyAtom; - using TiledMma = typename CacheKV_traits::TiledMma; - - static constexpr int kThreadPerValue = CacheKV_traits::kThreadPerValue; - - using GmemTiledCopyQ = typename CacheKV_traits::GmemTiledCopyQ; - using GmemLayoutAtom = typename CacheKV_traits::GmemLayoutAtom; - using GmemTiledCopyKV = decltype( - make_tiled_copy(Copy_Atom, cuteType>{}, - GmemKVLayoutAtom{}, - Layout>>{})); - - - using SmemCopyAtomTransposed = typename CacheKV_traits::SmemCopyAtomTransposed; - - using GmemTiledCopyO = decltype( - make_tiled_copy(Copy_Atom{}, - GmemLayoutAtom{}, - Layout>>{})); - using SmemCopyAtomO = Copy_Atom; - - using SmemLayoutAtomO = decltype( - composition(Swizzle<3, 3, 3>{}, - Layout< - Shape, Int>, - Stride, _1>>{})); - - using SmemLayoutO = decltype(tile_to_shape( - SmemLayoutAtomO{}, - Shape, Int>{})); - - static constexpr int kShareMemSize = (size(SmemLayoutQ{}) + size(SmemLayoutQK{}) + size(SmemLayoutKV{}) * 2) * sizeof(cuteType); -}; diff --git a/custom_ops/gpu_ops/moba_attn/moba_decoder_attn/moba_decoder_write_cache.cu b/custom_ops/gpu_ops/moba_attn/moba_decoder_attn/moba_decoder_write_cache.cu deleted file mode 100644 index b1cc5b1ef..000000000 --- a/custom_ops/gpu_ops/moba_attn/moba_decoder_attn/moba_decoder_write_cache.cu +++ /dev/null @@ -1,189 +0,0 @@ -// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
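// A minimal, single-threaded sketch of the partition-merge math that the
// moba_decoder_attention_merge_kernel above performs per (batch, head),
// assuming each partition p produced (max_p, exp_sum_p, out_p) with
// out_p[d] = sum_j exp(s_j - max_p) * v_j[d]. All names below are
// illustrative reference code, not part of the reverted kernels.
#include <cmath>
#include <vector>

std::vector<float> merge_partitions(const std::vector<float>& maxs,
                                    const std::vector<float>& sums,
                                    const std::vector<std::vector<float>>& outs,
                                    int head_dim) {
    // Global max over all attended partitions.
    float global_max = -INFINITY;
    for (float m : maxs) global_max = std::fmax(global_max, m);

    // Rescale every partial exp-sum into the global max's frame of reference.
    float global_exp_sum = 0.f;
    std::vector<float> scale(maxs.size());
    for (std::size_t p = 0; p < maxs.size(); ++p) {
        scale[p] = std::exp(maxs[p] - global_max);
        global_exp_sum += sums[p] * scale[p];
    }

    // Combine the partial outputs with the same scales, then normalize once
    // (the kernel uses the same 1e-6 epsilon to guard the division).
    std::vector<float> out(head_dim, 0.f);
    const float inv = 1.f / (global_exp_sum + 1e-6f);
    for (std::size_t p = 0; p < outs.size(); ++p)
        for (int d = 0; d < head_dim; ++d)
            out[d] += outs[p][d] * scale[p];
    for (int d = 0; d < head_dim; ++d) out[d] *= inv;
    return out;
}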
- -#include "paddle/extension.h" -#include "../moba_attn_utils.hpp" -#include "moba_attn/moba_attn.h" - -template -__global__ void moba_decoder_attn_write_c16( - const T * qkv_out, - const T * qkv_bias, - T * q_input, - const int * cu_seq_q, - const int * cu_seq_k, - const int * seq_len_encoder, - const int * seq_len_decoder, - T * cache_k, - T * cache_v, - const int * block_tables, - const float * rope_sin_cos, - T *k_block_means, - const int head_num, - const int kv_head_num, - const int max_blocks_per_seq, - const int max_input_length) { - - int bidh = blockIdx.x; - const int bidb = blockIdx.y; - const int tidx = threadIdx.x; - const int seq_len = seq_len_decoder[bidb]; - - if (seq_len == 0) { - return; - } - - constexpr int kPackSize = 4; - using SrcType = Vec; - using rope_type = Vec; - SrcType src, bias, k_prev; - rope_type sin, cos; - const int bias_idx = bidh * kHeadDim + tidx * kPackSize; - const int ori_token_idx = cu_seq_q[bidb]; - src.load_from(qkv_out + ori_token_idx * (head_num + 2 * kv_head_num) * kHeadDim + bias_idx); - if (qkv_bias != nullptr) { - bias.load_from(qkv_bias + bias_idx); - src.add(bias); - } - - const int32_t *block_table_now = block_tables + bidb * max_blocks_per_seq; - const int32_t physical_block_number = block_table_now[seq_len / kBlockSize]; - - - if (bidh < head_num) { - const float * cos_rope = rope_sin_cos + seq_len * (kHeadDim / 2) + tidx * (kPackSize / 2); - const float * sin_rope = cos_rope + max_input_length * (kHeadDim / 2); - sin.load_from(sin_rope); - cos.load_from(cos_rope); - apply_rotary_embedding(src, cos, sin); - - src.store_to(q_input + cu_seq_q[bidb] * head_num * kHeadDim + bias_idx); - } else if (bidh < head_num + kv_head_num) { - bidh -= head_num; - const int token_in_blocks = seq_len % kBlockSize; - const float * cos_rope = rope_sin_cos + seq_len * (kHeadDim / 2) + tidx * (kPackSize / 2); - const float * sin_rope = cos_rope + max_input_length * (kHeadDim / 2); - sin.load_from(sin_rope); - cos.load_from(cos_rope); - apply_rotary_embedding(src, cos, sin); - - T * cache = cache_k + physical_block_number * kv_head_num * kBlockSize * kHeadDim + bidh * kBlockSize * kHeadDim + tidx * kPackSize + token_in_blocks * kHeadDim; - src.store_to(cache); - - const int seq_len_block = seq_len / moba_block_size; - - const int store_mean_idx = (bidb * kMaxN + seq_len_block) * kv_head_num * kHeadDim + bidh * kHeadDim + tidx * kPackSize; - - if (seq_len % moba_block_size != 0) { - const int token_num_prev = seq_len % moba_block_size; - const float inv_tokens_sum = fdividef(1.0f, token_num_prev + 1); - k_prev.load_from(k_block_means + store_mean_idx); - - #pragma unroll - for (int i = 0; i < kPackSize; i++) { - src.data.elt[i] = T(inv_tokens_sum * (float(src.data.elt[i]) + float(k_prev.data.elt[i]) * token_num_prev)); - } - } - - src.store_to(k_block_means + store_mean_idx); - - } else { - bidh -= head_num + kv_head_num; - const int token_in_blocks = seq_len % kBlockSize; - T * cache = cache_v + physical_block_number * kv_head_num * kBlockSize * kHeadDim + bidh * kBlockSize * kHeadDim + tidx * kPackSize + token_in_blocks * kHeadDim; - src.store_to(cache); - } - -} - -void MobaDecoderAttnWriteCacheKv( - const paddle::Tensor& qkv_out, - const paddle::Tensor& q_input, - const paddle::Tensor& cu_seq_q, - const paddle::Tensor& cu_seq_k, - const paddle::Tensor& seq_len_encoder, - const paddle::Tensor& seq_len_decoder, - const paddle::Tensor& cache_k, - const paddle::Tensor& cache_v, - const paddle::Tensor& block_tables, - const paddle::Tensor& rope_sin_cos, - 
const paddle::Tensor& k_block_means, - const paddle::optional& qkv_bias, - const paddle::optional& cache_k_quant_scale, - const paddle::optional& cache_v_quant_scale, - const paddle::optional& cache_k_dequant_scale, - const paddle::optional& cache_v_dequant_scale, - const paddle::optional& cache_k_zero_points, - const paddle::optional& cache_v_zero_points, - const int head_num, - const int kv_head_num, - const int head_dim, - const int max_input_length, - const std::string &cache_quant_type_str) { - - constexpr int kThreads = 32; - constexpr int kHeadDim = 128; - constexpr int kMobaBlockSize = 128; - constexpr int kMaxN = 1024; - assert(kHeadDim == head_dim); - constexpr int kBlockSize = 64; - const int max_blocks_per_seq = block_tables.dims()[1]; - const int batch_size = block_tables.dims()[0]; - if (cache_quant_type_str == "none") { - dim3 grid_dims; - grid_dims.x = head_num + kv_head_num * 2; - grid_dims.y = batch_size; - if (qkv_out.dtype() == paddle::DataType::FLOAT16) { - using T = phi::dtype::float16; - moba_decoder_attn_write_c16<<>>( - qkv_out.data(), - qkv_bias ? qkv_bias.get().data() : nullptr, - const_cast(q_input.data()), - cu_seq_q.data(), - cu_seq_k.data(), - seq_len_encoder.data(), - seq_len_decoder.data(), - const_cast(cache_k.data()), - const_cast(cache_v.data()), - block_tables.data(), - rope_sin_cos.data(), - const_cast(k_block_means.data()), - head_num, - kv_head_num, - max_blocks_per_seq, - max_input_length); - } else if (qkv_out.dtype() == paddle::DataType::BFLOAT16) { - using T = phi::dtype::bfloat16; - moba_decoder_attn_write_c16<<>>( - qkv_out.data(), - qkv_bias ? qkv_bias.get().data() : nullptr, - const_cast(q_input.data()), - cu_seq_q.data(), - cu_seq_k.data(), - seq_len_encoder.data(), - seq_len_decoder.data(), - const_cast(cache_k.data()), - const_cast(cache_v.data()), - block_tables.data(), - rope_sin_cos.data(), - const_cast(k_block_means.data()), - head_num, - kv_head_num, - max_blocks_per_seq, - max_input_length); - } - } else { - PD_THROW("Only supported cache_quant_type_str in ['none']."); - } -} diff --git a/custom_ops/gpu_ops/moba_attn/moba_decoder_attn/moba_qk_sort_decoder.cu b/custom_ops/gpu_ops/moba_attn/moba_decoder_attn/moba_qk_sort_decoder.cu deleted file mode 100644 index 3575cf5b8..000000000 --- a/custom_ops/gpu_ops/moba_attn/moba_decoder_attn/moba_qk_sort_decoder.cu +++ /dev/null @@ -1,236 +0,0 @@ -// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
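// A minimal scalar sketch of the running per-block key mean that
// moba_decoder_attn_write_c16 above maintains in k_block_means: each newly
// decoded token's key is folded into the mean of its moba block of size
// moba_block_size. Names are illustrative only, not the reverted kernel.
#include <vector>

void update_block_mean(std::vector<float>& block_mean,   // [head_dim], mean of keys already in this block
                       const std::vector<float>& k_new,  // [head_dim], key of the newly decoded token
                       int tokens_in_block) {            // tokens already stored; 0 means a fresh block
    if (tokens_in_block == 0) {
        block_mean = k_new;  // first token of a new moba block simply becomes the mean
        return;
    }
    const float inv = 1.f / static_cast<float>(tokens_in_block + 1);
    for (std::size_t d = 0; d < block_mean.size(); ++d)
        block_mean[d] = inv * (k_new[d] + block_mean[d] * tokens_in_block);
}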
- -#include "paddle/extension.h" -#include "moba_attn/moba_attn_utils.hpp" -#include "moba_attn/moba_attn.h" - - -template -__global__ void qk_gate_sort_decoder_kernel( - const T* qk_gate_weight, - int * qk_gate_topk_idx, - const int *decoder_seq_lens, - const int head_num, - const int kv_head_num, - const int kGqaGroupSize, - const int top_k_left, - const int top_k_right, - const int use_moba_seq_limit) { - - const int bidb = blockIdx.x; - const int bidh = blockIdx.y; - const int tidx = threadIdx.x; - const int bidh_kv = bidh / kGqaGroupSize; - - if (decoder_seq_lens[bidb] == 0 || decoder_seq_lens[bidb] < use_moba_seq_limit) { - return; - } - const int seq_len = (decoder_seq_lens[bidb] + moba_block_size - 1) / moba_block_size; - - constexpr int kPackSize = kBlockMaxN / knthreads; - - static_assert(kBlockMaxN % knthreads == 0); - - T token_mean[kPackSize]; - - using SrcType = Vec; - using SrcType_f = Vec; - using SrcType_i = Vec; - - SrcType src; - SrcType_f src_f; - SrcType_i select_idx; - - select_idx.set_zero(); - - const int load_offset = bidb * head_num * kBlockMaxN + bidh * kBlockMaxN + tidx * kPackSize; - - src.load_from(qk_gate_weight + load_offset); - - float max_global = -FLT_MAX; - float min_global = FLT_MAX; - - const int data_len = seq_len - tidx * kPackSize; - - #pragma unroll - for (int i = 0; i < kPackSize; i++) { - if (i < data_len) { - src_f.data.elt[i] = float(src.data.elt[i]); - min_global = min(min_global, src_f.data.elt[i]); - } else { - src_f.data.elt[i] = -FLT_MAX; - } - max_global = max(max_global, src_f.data.elt[i]); - } - - - max_global = BlockAllReduce, knthreads>(max_global); - min_global = BlockAllReduce, knthreads>(min_global); - - - float right_limit = max_global; - float left_limit = min_global; - - float mid_limit; - int count; - - #pragma unroll - for (int i = 0; i < searchtimes; i++) { - mid_limit = (left_limit + right_limit) * 0.5f; - count = get_data_count(src_f.data.elt, mid_limit); - if (count < top_k_left) { - right_limit = mid_limit; - } else if (count > top_k_right) { - left_limit = mid_limit; - } else { - break; - } - } - - const int store_idx = bidb * kv_head_num * kBlockMaxN + bidh_kv * kBlockMaxN + tidx * kPackSize; - - #pragma unroll - for (int i = 0; i < kPackSize; i++) { - if (src_f.data.elt[i] >= mid_limit) { - qk_gate_topk_idx[store_idx + i] = 1; - } - } - - if (tidx == 0) { - qk_gate_topk_idx[store_idx] = 1; - qk_gate_topk_idx[store_idx + seq_len - 1] = 1; - qk_gate_topk_idx[store_idx + seq_len - 2] = 1; - } -} - -template -void qk_gate_sort_decoder( - const T* qk_gate_weight, - int * qk_gate_topk_idx, - const int *decoder_seq_lens, - const int head_num, - const int kv_head_num, - const int batch_size, - const int top_k_left, - const int top_k_right, - const int use_moba_seq_limit, - cudaStream_t stream) { - - const int gqa_group_size = head_num / kv_head_num; - constexpr int kPackSize = 16 / sizeof(T); - const int knthreads = kBlockMaxN / kPackSize; - dim3 grid_dims; - grid_dims.x = batch_size; - grid_dims.y = head_num; - const int searchtimes = 6; - - constexpr auto kernel = qk_gate_sort_decoder_kernel; - - kernel<<>>( - qk_gate_weight, - qk_gate_topk_idx, - decoder_seq_lens, - head_num, - kv_head_num, - gqa_group_size, - top_k_left, - top_k_right, - use_moba_seq_limit); -} - - -template -std::vector DispatchQkSortDecoder( - const paddle::Tensor& qk_gate_weight, - const paddle::Tensor& seq_len_encoder, - const paddle::Tensor& seq_len_decoder, - const int head_num, - const int kv_head_num, - const int top_k_left, - const int 
top_k_right, - const int use_moba_seq_limit) { - - constexpr int kMobaBlockSize = 128; - constexpr int kMaxN = 1024; - - const int batch_size = seq_len_decoder.dims()[0]; - paddle::Tensor qk_gate_topk_idx = paddle::empty({batch_size, kv_head_num, kMaxN}, paddle::DataType::INT32, qk_gate_weight.place()); - - qk_gate_sort_decoder( - qk_gate_weight.data(), - qk_gate_topk_idx.data(), - seq_len_decoder.data(), - head_num, - kv_head_num, - batch_size, - top_k_left, - top_k_right, - use_moba_seq_limit, - qk_gate_weight.stream() - ); - - return {qk_gate_topk_idx}; -} - -std::vector QkSortDecoder( - const paddle::Tensor& qk_gate_weight, - const paddle::Tensor& seq_len_encoder, - const paddle::Tensor& seq_len_decoder, - const int head_num, - const int kv_head_num, - const int top_k_left, - const int top_k_right, - const int use_moba_seq_limit) { - - if (qk_gate_weight.dtype() == paddle::DataType::FLOAT16) { - return std::move( - DispatchQkSortDecoder( - qk_gate_weight, - seq_len_encoder, - seq_len_decoder, - head_num, - kv_head_num, - top_k_left, - top_k_right, - use_moba_seq_limit) - ); - } else if (qk_gate_weight.dtype() == paddle::DataType::BFLOAT16) { - return std::move( - DispatchQkSortDecoder( - qk_gate_weight, - seq_len_encoder, - seq_len_decoder, - head_num, - kv_head_num, - top_k_left, - top_k_right, - use_moba_seq_limit) - ); - } -} - -PD_BUILD_OP(moba_qk_sort_decoder) - .Inputs({ - "qk_gate_weight", - "seq_len_encoder", - "seq_len_decoder"}) - .Attrs({ - "head_num: int", - "kv_head_num: int", - "top_k_left: int", - "top_k_right: int", - "use_moba_seq_limit: int"}) - .Outputs({"qk_gate_topk_idx"}) - .SetKernelFn(PD_KERNEL(QkSortDecoder)); diff --git a/custom_ops/gpu_ops/moba_attn/moba_encoder_attn/kernel_traits.h b/custom_ops/gpu_ops/moba_attn/moba_encoder_attn/kernel_traits.h deleted file mode 100644 index 2673708e6..000000000 --- a/custom_ops/gpu_ops/moba_attn/moba_encoder_attn/kernel_traits.h +++ /dev/null @@ -1,143 +0,0 @@ -// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -/****************************************************************************** - * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. 
- ******************************************************************************/ - -#pragma once - -#include "cute/algorithm/copy.hpp" -#include "cute/atom/mma_atom.hpp" -#include "cutlass/gemm/collective/collective_builder.hpp" - -#include "cutlass/cutlass.h" -#include "cutlass/layout/layout.h" -#include "cutlass/numeric_types.h" -#include "cutlass/pipeline/pipeline.hpp" - -using namespace cute; - -struct moba_encoder_attn_params { - void *__restrict__ q_ptr; - void *__restrict__ k_ptr; - void *__restrict__ v_ptr; - void * __restrict__ o_ptr; - int * __restrict__ cu_seq_q; - int * __restrict__ cu_seq_k; - int * __restrict__ qk_gate_topk_idx; - int * __restrict__ seq_len_encoder; - int * __restrict__ cu_seq_q_pack; - int head_num; - int kv_head_num; - int max_seq_q; - int max_seq_k; - int batch_size; - int gqa_group_size; - float scale_softmax_log2; - int use_moba_seq_limit; -}; - -template -struct SharedStorageQKVO { - cute::array_aligned> smem_q; - cute::array_aligned> smem_k; - union { - cute::array_aligned> smem_v; - cute::array_aligned> smem_o; - }; - struct { - cutlass::arch::ClusterTransactionBarrier barrier_Q; - typename cutlass::PipelineTmaAsync::SharedStorage pipeline_k; - typename cutlass::PipelineTmaAsync::SharedStorage pipeline_v; - }; -}; - -template -struct moba_encoder_attn_kernel_traits { - using Element = elem_type; - using ElementAccum = float; - using index_t = int32_t; - - static constexpr int kNWarps = kNWarps_; - static constexpr int kNThreads = kNWarps * cutlass::NumThreadsPerWarp; - - static constexpr int UseMoba = UseMoba_; - - static constexpr int kBlockM = kBlockM_; - static constexpr int kBlockN = kBlockN_; - static constexpr int kHeadDim = kHeadDim_; - static constexpr int kMaxN = kMaxN_; - static_assert(kHeadDim % 32 == 0); - using TileShape_MNK = Shape, Int, Int>; - using ClusterShape_MNK = Shape, Int<1>, Int<1>>; - static constexpr int kStages = kStages_; - - using AtomLayoutMNK = Layout, _1, _1>>; - using TiledMma0 = decltype(cute::make_tiled_mma( - cute::GMMA::ss_op_selector(), - AtomLayoutMNK{})); - using TiledMma1 = decltype(cute::make_tiled_mma( - cute::GMMA::rs_op_selector(TileShape_MNK{})), - GMMA::Major::K, GMMA::Major::MN>(), - AtomLayoutMNK{})); - - using SmemLayoutAtomQ = decltype(cutlass::gemm::collective::detail::ss_smem_selector(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>()); - using SmemLayoutQ = decltype(tile_to_shape(SmemLayoutAtomQ{}, select<0, 2>(TileShape_MNK{}))); - - using SmemLayoutAtomK = decltype(cutlass::gemm::collective::detail::ss_smem_selector(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>()); - using SmemLayoutK = - decltype(tile_to_shape(SmemLayoutAtomK{}, - make_shape(shape<1>(TileShape_MNK{}), shape<2>(TileShape_MNK{}), Int{}))); - - using SmemLayoutAtomV = decltype(cutlass::gemm::collective::detail::ss_smem_selector(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>()); - using SmemLayoutV = - decltype(tile_to_shape(SmemLayoutAtomV{}, - make_shape(shape<1>(TileShape_MNK{}), shape<2>(TileShape_MNK{}), Int{}))); - - using SmemLayoutAtomO = decltype(cutlass::gemm::collective::detail::ss_smem_selector(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>()); - using SmemLayoutO = decltype(tile_to_shape(SmemLayoutAtomO{}, select<0, 2>(TileShape_MNK{}))); - - using SmemCopyAtomQ = Copy_Atom; - using SmemCopyAtomO = Copy_Atom; - - using SharedStorage = SharedStorageQKVO; - - static constexpr int NumProducerThreads = cutlass::NumThreadsPerWarpGroup; - static constexpr int NumMmaThreads = 
kNThreads - NumProducerThreads; - static constexpr int kNumVecElem = ceil_div(128, sizeof_bits_v); - static constexpr int kNumThreadsPerRow = kHeadDim / kNumVecElem; - static_assert(NumMmaThreads % kNumThreadsPerRow == 0); - static constexpr int kNumRows = NumMmaThreads / kNumThreadsPerRow; - using TiledCopyOAtom = cute::Copy_Atom, Element>; - using TiledCopyOThrLayout = decltype(cute::make_layout( - cute::make_shape(Int{}, Int{}), - LayoutRight{})); - using TiledCopyOValLayout = decltype(cute::make_layout( - cute::make_shape(_1{}, Int{}), - LayoutRight{})); - using GmemTiledCopyO = decltype(make_tiled_copy( - TiledCopyOAtom{}, - TiledCopyOThrLayout{}, // Thr layout - TiledCopyOValLayout{} // Val layout - )); - - using MainloopPipeline = typename cutlass::PipelineTmaAsync; - using PipelineState = typename cutlass::PipelineState; -}; diff --git a/custom_ops/gpu_ops/moba_attn/moba_encoder_attn/mainloop_attn.hpp b/custom_ops/gpu_ops/moba_attn/moba_encoder_attn/mainloop_attn.hpp deleted file mode 100644 index 84d41af83..000000000 --- a/custom_ops/gpu_ops/moba_attn/moba_encoder_attn/mainloop_attn.hpp +++ /dev/null @@ -1,473 +0,0 @@ -// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -/****************************************************************************** - * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. 
- ******************************************************************************/ - -#pragma once - -#include -#include -#include -#include -#include "cutlass/pipeline/pipeline.hpp" - -#include "cute/tensor.hpp" - -#include "cutlass/gemm/collective/collective_builder.hpp" - -using namespace cute; - -enum class AttnNamedBarriers { - QueryEmpty = 0, - ValueEmpty = 1, - TileCountSmemEmpty = 2, - TileCountSmemFull = 3, - WarpSchedulerWG1 = 4, - WarpSchedulerWG2 = 5, - WarpSchedulerWG3 = 6, -}; - - - -template -struct CollectiveMainloopAttn { - - using Element = typename Ktraits::Element; - using TileShape_MNK = typename Ktraits::TileShape_MNK; - using ClusterShape = typename Ktraits::ClusterShape_MNK; - - static constexpr int kStages = Ktraits::kStages; - static constexpr int kHeadDim = Ktraits::kHeadDim; - static constexpr int kBlockM = Ktraits::kBlockM; - static constexpr int kBlockN = Ktraits::kBlockN; - - using ShapeT = cute::Shape; - using StrideT = cute::Shape; - using LayoutT = cute::Layout; - - - using GmemTiledCopyQ = cute::SM90_TMA_LOAD; - using GmemTiledCopyKV = decltype(cutlass::gemm::collective::detail::sm90_cluster_shape_to_tma_atom(shape<0>(ClusterShape{}))); - using GmemTiledCopyO = typename Ktraits::GmemTiledCopyO; - - using SmemLayoutAtomQ = decltype(cutlass::gemm::collective::detail::ss_smem_selector(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>()); - using SmemLayoutQ = decltype(tile_to_shape(SmemLayoutAtomQ{}, select<0, 2>(TileShape_MNK{}))); - - using SmemLayoutAtomK = decltype(cutlass::gemm::collective::detail::ss_smem_selector(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>()); - using SmemLayoutK = - decltype(tile_to_shape(SmemLayoutAtomK{}, - make_shape(shape<1>(TileShape_MNK{}), shape<2>(TileShape_MNK{}), Int{}))); - using SmemLayoutV = SmemLayoutK; - // Note this is the transpose in terms of the view, not in terms of memory. 
- using SmemLayoutVt = - decltype(cute::composition(SmemLayoutV{}, - make_layout(make_shape(get<2>(TileShape_MNK{}), get<1>(TileShape_MNK{}), Int{}), - make_stride(get<1>(TileShape_MNK{}), _1{}, Int{})))); - using SmemLayoutO = typename Ktraits::SmemLayoutO; - using SmemCopyAtomO = typename Ktraits::SmemCopyAtomO; - - using TMA_Q = decltype(make_tma_copy( - GmemTiledCopyQ{}, - make_tensor( - make_gmem_ptr(static_cast(nullptr)), - repeat_like(StrideT{}, int32_t(0)), - StrideT{} - ), - SmemLayoutQ{}, - select<0, 2>(TileShape_MNK{}), - _1{})); // no mcast for Q - - using TMA_KV = decltype(make_tma_copy( - GmemTiledCopyKV{}, - make_tensor( - make_gmem_ptr(static_cast(nullptr)), - repeat_like(StrideT{}, int32_t(0)), - StrideT{} - ), - take<0, 2>(SmemLayoutK{}), - select<1, 2>(TileShape_MNK{}), - size<0>(ClusterShape{}))); // mcast along M mode for this N load, if any - - static constexpr int NumMmaThreads = size(typename Ktraits::TiledMma0{}); - using MainloopPipeline = typename Ktraits::MainloopPipeline; - using PipelineParams = typename MainloopPipeline::Params; - using PipelineState = typename MainloopPipeline::PipelineState; - - // Set the bytes transferred in this TMA transaction (may involve multiple issues) - static constexpr uint32_t TmaTransactionBytesQ = static_cast(size(SmemLayoutQ{}) * cutlass::sizeof_bits_v / 8); - static constexpr uint32_t TmaTransactionBytesK = static_cast(size(take<0, 2>(SmemLayoutK{})) * cutlass::sizeof_bits_v / 8); - - static constexpr bool UseSchedulerBarrier = kHeadDim <= 128; - - // Host side kernel arguments - struct Arguments { - Element const* ptr_Q; - LayoutT layout_Q; - Element const* ptr_K; - LayoutT layout_K; - Element const* ptr_V; - LayoutT layout_V; - float const softmax_scale_log2; - }; - - // Device side kernel params - struct Params { - LayoutT layout_Q; - LayoutT layout_K; - LayoutT layout_V; - cutlass::FastDivmod qhead_per_khead_divmod; - TMA_Q tma_load_Q; - TMA_KV tma_load_K, tma_load_V; - float const softmax_scale_log2; - }; - - - static Params - to_underlying_arguments(Arguments const& args) { - Tensor mQ = make_tensor(make_gmem_ptr(args.ptr_Q), args.layout_Q); - TMA_Q tma_load_Q = make_tma_copy( - GmemTiledCopyQ{}, - mQ, - SmemLayoutQ{}, - select<0, 2>(TileShape_MNK{}), - _1{}); // no mcast for Q - Tensor mK = make_tensor(make_gmem_ptr(args.ptr_K), args.layout_K); - TMA_KV tma_load_K = make_tma_copy( - GmemTiledCopyKV{}, - mK, - SmemLayoutK{}(_, _, _0{}), - select<1, 2>(TileShape_MNK{}), - size<0>(ClusterShape{})); // mcast along M mode for this N load, if any - Tensor mV = make_tensor(make_gmem_ptr(args.ptr_V), args.layout_V); - TMA_KV tma_load_V = make_tma_copy( - GmemTiledCopyKV{}, - mV, - SmemLayoutV{}(_, _, _0{}), - select<1, 2>(TileShape_MNK{}), - size<0>(ClusterShape{})); // mcast along M mode for this N load, if any - return {args.layout_Q, args.layout_K, args.layout_V, - cutlass::FastDivmod(cute::ceil_div(get<2>(args.layout_Q.shape()), get<2>(args.layout_K.shape()))), - tma_load_Q, tma_load_K, tma_load_V, - args.softmax_scale_log2}; - } - - /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance - CUTLASS_DEVICE - static void prefetch_tma_descriptors(Params const& mainloop_params) { - cute::prefetch_tma_descriptor(mainloop_params.tma_load_Q.get_tma_descriptor()); - cute::prefetch_tma_descriptor(mainloop_params.tma_load_K.get_tma_descriptor()); - cute::prefetch_tma_descriptor(mainloop_params.tma_load_V.get_tma_descriptor()); - } - - template - CUTLASS_DEVICE auto get_local_tile_tensor( - const 
MTensor &m_tensor, - const Shape &tile_shape, - const int *cu_seq_len, - const int bidh, - const int bidb, - const int actual_seq_len) const { - auto g_offset = local_tile( - m_tensor(_, _, bidh), - cute::make_shape(1, get<1>(tile_shape)), - make_coord(cu_seq_len[bidb], _0{})); - auto g_sequence = make_tensor( - g_offset.data(), - make_layout( - cute::make_shape(actual_seq_len, get<1>(tile_shape)), - g_offset.stride() - )); - auto g_tensor = local_tile(g_sequence, tile_shape, make_coord(_, _0{})); - return g_tensor; - } - - - template - CUTLASS_DEVICE void - load(Params const& mainloop_params, - MainloopPipeline pipeline_k, - MainloopPipeline pipeline_v, - PipelineState& smem_pipe_write_k, - PipelineState& smem_pipe_write_v, - SharedStorage &shared_storage, - const int *qk_gate_topk_idx, - const int n_block_max, - const int m_block, - const int bidh, - const int bidb, - const int *cu_seq_q, - const int *cu_seq_k, - const int seq_len_q, - const int seq_len_k) { - - Tensor sQ = make_tensor(make_smem_ptr(shared_storage.smem_q.data()), SmemLayoutQ{}); - Tensor sK = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), SmemLayoutK{}); - Tensor sV = make_tensor(make_smem_ptr(shared_storage.smem_v.data()), SmemLayoutV{}); - - Tensor mQ = mainloop_params.tma_load_Q.get_tma_tensor(mainloop_params.layout_Q.shape()); - Tensor mK = mainloop_params.tma_load_K.get_tma_tensor(mainloop_params.layout_K.shape()); - Tensor mV = mainloop_params.tma_load_V.get_tma_tensor(mainloop_params.layout_V.shape()); - int bidh_kv = mainloop_params.qhead_per_khead_divmod.divide(bidh); - - Tensor gQ = get_local_tile_tensor( - mQ, select<0, 2>(TileShape_MNK{}), cu_seq_q, bidh, bidb, seq_len_q)(_, _, m_block); - Tensor gK = get_local_tile_tensor( - mK, select<1, 2>(TileShape_MNK{}), cu_seq_k, bidh_kv, bidb, seq_len_k); - Tensor gV = get_local_tile_tensor( - mV, select<1, 2>(TileShape_MNK{}), cu_seq_k, bidh_kv, bidb, seq_len_k); - - Tensor sQ_x = make_tensor(sQ.data(), make_layout(sQ.layout(), Layout<_1>{})); - Tensor gQ_x = make_tensor(gQ.data(), make_layout(gQ.layout(), Layout<_1>{})); - auto [tQgQ, tQsQ] = tma_partition(mainloop_params.tma_load_Q, _0{}, Layout<_1>{},group_modes<0, 2>(sQ_x), group_modes<0, 2>(gQ_x)); - auto [tKgK, tKsK] = tma_partition(mainloop_params.tma_load_K, _0{}, Layout<_1>{},group_modes<0, 2>(sK), group_modes<0, 2>(gK)); - auto [tVgV, tVsV] = tma_partition(mainloop_params.tma_load_V, _0{}, Layout<_1>{},group_modes<0, 2>(sV), group_modes<0, 2>(gV)); - - uint16_t mcast_mask_kv = 0; - - int n_block = n_block_max - 1; - - int lane_predicate = cute::elect_one_sync(); - - if (lane_predicate) { - shared_storage.barrier_Q.arrive_and_expect_tx(TmaTransactionBytesQ); - copy(mainloop_params.tma_load_Q.with(reinterpret_cast(shared_storage.barrier_Q), 0 /*mcast_mask*/), tQgQ, tQsQ); - } - - - if (lane_predicate) { - pipeline_k.producer_acquire(smem_pipe_write_k); - copy(mainloop_params.tma_load_K.with(*pipeline_k.producer_get_barrier(smem_pipe_write_k), mcast_mask_kv), tKgK(_, n_block), tKsK(_, smem_pipe_write_k.index())); - ++smem_pipe_write_k; - } - - if (lane_predicate) { - int idx = 0; - #pragma unroll 2 - for (; n_block > 0; ) { - pipeline_k.producer_acquire(smem_pipe_write_k); - int pre_idx = 1; - if constexpr (UseMoba) { - pre_idx = qk_gate_topk_idx[idx]; - } - copy(mainloop_params.tma_load_K.with(*pipeline_k.producer_get_barrier(smem_pipe_write_k), mcast_mask_kv), tKgK(_, n_block - pre_idx), tKsK(_, smem_pipe_write_k.index())); - - ++smem_pipe_write_k; - pipeline_v.producer_acquire(smem_pipe_write_v); 
- copy(mainloop_params.tma_load_V.with(*pipeline_v.producer_get_barrier(smem_pipe_write_v), mcast_mask_kv), tVgV(_, n_block), tVsV(_, smem_pipe_write_v.index())); - ++smem_pipe_write_v; - n_block -= pre_idx; - idx += 1; - } - } - if (lane_predicate) { - pipeline_v.producer_acquire(smem_pipe_write_v); - copy(mainloop_params.tma_load_V.with(*pipeline_v.producer_get_barrier(smem_pipe_write_v), mcast_mask_kv), tVgV(_, n_block), tVsV(_, smem_pipe_write_v.index())); - ++smem_pipe_write_v; - } - } - - CUTLASS_DEVICE void - warp_scheduler_barrier_sync() { - if constexpr (UseSchedulerBarrier) { - cutlass::arch::NamedBarrier::sync(NumMmaThreads, static_cast(AttnNamedBarriers::WarpSchedulerWG1) - 1 + cutlass::canonical_warp_group_idx() /*id*/); - } - } - - CUTLASS_DEVICE void - mma_init() { - if constexpr (!UseSchedulerBarrier) { return; } - static_assert(NumMmaThreads == 2 * cutlass::NumThreadsPerWarpGroup || NumMmaThreads == 3 * cutlass::NumThreadsPerWarpGroup); - if (cutlass::canonical_warp_group_idx() > 1) { - cutlass::arch::NamedBarrier::arrive(NumMmaThreads, static_cast(AttnNamedBarriers::WarpSchedulerWG1) - 1 + 1 /*id*/); - } - if constexpr (NumMmaThreads == 3 * cutlass::NumThreadsPerWarpGroup) { - if (cutlass::canonical_warp_group_idx() > 2) { - cutlass::arch::NamedBarrier::arrive(NumMmaThreads, static_cast(AttnNamedBarriers::WarpSchedulerWG1) - 1 + 2 /*id*/); - } - } - - } - - CUTLASS_DEVICE void - warp_scheduler_barrier_arrive() { - if constexpr (!UseSchedulerBarrier) { return; } - static_assert(NumMmaThreads == 2 * cutlass::NumThreadsPerWarpGroup || NumMmaThreads == 3 * cutlass::NumThreadsPerWarpGroup); - if constexpr (NumMmaThreads == 2 * cutlass::NumThreadsPerWarpGroup) { - cutlass::arch::NamedBarrier::arrive(NumMmaThreads, static_cast(AttnNamedBarriers::WarpSchedulerWG1) - 1 + (3 - cutlass::canonical_warp_group_idx()) /*id*/); - } else { - cutlass::arch::NamedBarrier::arrive(NumMmaThreads, static_cast(AttnNamedBarriers::WarpSchedulerWG1) - 1 + (cutlass::canonical_warp_group_idx() <= 2 ? cutlass::canonical_warp_group_idx() + 1 : cutlass::canonical_warp_group_idx() + 1 - 3) /*id*/); - cutlass::arch::NamedBarrier::arrive(NumMmaThreads, static_cast(AttnNamedBarriers::WarpSchedulerWG1) - 1 + (cutlass::canonical_warp_group_idx() <= 1 ? 
cutlass::canonical_warp_group_idx() + 2 : cutlass::canonical_warp_group_idx() + 2 - 3) /*id*/); - } - } - - - template - CUTLASS_DEVICE void - mma(Params const& mainloop_params, - MainloopPipeline pipeline_k, - MainloopPipeline pipeline_v, - PipelineState& smem_pipe_read_k, - PipelineState& smem_pipe_read_v, - FrgTensorO& tOrO, - Softmax& softmax, - const int *qk_gate_topk_idx, - const int n_block_max, - const int thread_idx, - const int m_block, - const int seq_len_q, - const int seq_len_k, - SharedStorage& shared_storage) { - - Tensor sQ = make_tensor(make_smem_ptr(shared_storage.smem_q.data()), SmemLayoutQ{}); - Tensor sK = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), SmemLayoutK{}); - Tensor sVt = make_tensor(make_smem_ptr(shared_storage.smem_v.data()), SmemLayoutVt{}); - - typename Ktraits::TiledMma0 tiled_mma0; - typename Ktraits::TiledMma1 tiled_mma1; - auto threadMma0 = tiled_mma0.get_thread_slice(thread_idx); - auto threadMma1 = tiled_mma1.get_thread_slice(thread_idx); - - Tensor tSrQ = threadMma0.partition_fragment_A(sQ); - Tensor tSrK = threadMma0.partition_fragment_B(sK); - Tensor tOrV = threadMma1.partition_fragment_B(sVt); - - auto consumer_wait = [](auto& pipeline, auto& smem_pipe_read) { - auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read); - pipeline.consumer_wait(smem_pipe_read, barrier_token); - }; - - tiled_mma1.accumulate_ = GMMA::ScaleOut::Zero; - - int n_block = n_block_max - 1; - - cutlass::ConsumerToken barrier_token = static_cast(shared_storage.barrier_Q.try_wait(0)); - if (barrier_token == cutlass::BarrierStatus::WaitAgain) { shared_storage.barrier_Q.wait(0); } - - Tensor tSrS = partition_fragment_C(tiled_mma0, select<0, 1>(TileShape_MNK{})); - consumer_wait(pipeline_k, smem_pipe_read_k); - warp_scheduler_barrier_sync(); - gemm(tiled_mma0, tSrQ, tSrK(_, _, _, smem_pipe_read_k.index()), tSrS); - warp_scheduler_barrier_arrive(); - warpgroup_wait<0>(); - pipeline_k.consumer_release(smem_pipe_read_k); - ++smem_pipe_read_k; - - auto col_limit_causal = [&](int row, int n_block) { - return row + 1 + seq_len_k - n_block * kBlockN - seq_len_q + m_block * kBlockM; - }; - Tensor cS = cute::make_identity_tensor(select<0, 1>(TileShape_MNK{})); - Tensor tScS = threadMma0.partition_C(cS); - #pragma unroll - for (int i = 0; i < size(tSrS); ++i) { - if (int(get<1>(tScS(i))) >= - std::min(seq_len_k - n_block * kBlockN, col_limit_causal(int(get<0>(tScS(i))), n_block))) { - tSrS(i) = -INFINITY; - } - } - - softmax.template online_softmax(tSrS, mainloop_params.softmax_scale_log2); - - Tensor tOrP = make_tensor(convert_type(tSrS).data(), convert_layout_acc_Aregs(tSrS.layout())); - Tensor scores_scale = make_fragment_like(softmax.row_max); - clear(scores_scale); - - int idx = 0; - #pragma unroll 2 - for (; n_block > 0; ) { - Tensor tSrS = partition_fragment_C(tiled_mma0, select<0, 1>(TileShape_MNK{})); - consumer_wait(pipeline_k, smem_pipe_read_k); - warp_scheduler_barrier_sync(); - gemm(tiled_mma0, tSrQ, tSrK(_, _, _, smem_pipe_read_k.index()), tSrS); - softmax.rescale_o(tOrO, scores_scale); - consumer_wait(pipeline_v, smem_pipe_read_v); - gemm(tiled_mma1, tOrP, tOrV(_, _, _, smem_pipe_read_v.index()), tOrO); - warp_scheduler_barrier_arrive(); - warpgroup_wait<1>(); - pipeline_k.consumer_release(smem_pipe_read_k); // release K - cute::copy(softmax.template max(tSrS, mainloop_params.softmax_scale_log2), scores_scale); - softmax.template online_softmax(tSrS, mainloop_params.softmax_scale_log2); - warpgroup_wait<0>(); - 
pipeline_v.consumer_release(smem_pipe_read_v); // release V - ++smem_pipe_read_k; - ++smem_pipe_read_v; - cute::copy(make_tensor(convert_type(tSrS).data(), convert_layout_acc_Aregs(tSrS.layout())), tOrP); - if constexpr (UseMoba) { - n_block -= qk_gate_topk_idx[idx]; - idx += 1; - } else { - n_block -= 1; - } - } - - softmax.rescale_o(tOrO, scores_scale); - consumer_wait(pipeline_v, smem_pipe_read_v); - gemm(tiled_mma1, tOrP, tOrV(_, _, _, smem_pipe_read_v.index()), tOrO); - cute::copy(softmax.finalize(mainloop_params.softmax_scale_log2), scores_scale); - warpgroup_wait<0>(); - pipeline_v.consumer_release(smem_pipe_read_v); - ++smem_pipe_read_v; - - softmax.rescale_o(tOrO, scores_scale); - } - - template - CUTLASS_DEVICE void - store(Params const& mainloop_params, - FrgTensorO const& tOrO, - SharedStorage& shared_storage, - TiledMma tiled_mma, - int thread_idx, - const int o_head_stride, - const int real_seq, - T * out_ptr) { - - Tensor sO = make_tensor(make_smem_ptr(shared_storage.smem_o.data()), SmemLayoutO{}); - auto smem_tiled_copy_O = make_tiled_copy_C(SmemCopyAtomO{}, tiled_mma); - auto smem_thr_copy_O = smem_tiled_copy_O.get_thread_slice(thread_idx); - - Tensor tOrO_out = convert_type(tOrO); - Tensor taccOrO = smem_thr_copy_O.retile_S(tOrO_out); - Tensor taccOsO = smem_thr_copy_O.partition_D(sO); - - cutlass::arch::NamedBarrier::sync(NumMmaThreads, static_cast(AttnNamedBarriers::ValueEmpty) /*id*/); - cute::copy(smem_tiled_copy_O, taccOrO, taccOsO); - cutlass::arch::fence_view_async_shared(); // ensure smem writes are visible to TMA - cutlass::arch::NamedBarrier::arrive(NumMmaThreads + cutlass::NumThreadsPerWarp,cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); - - Tensor gO = make_tensor(make_gmem_ptr(out_ptr), - Shape, Int>{}, - make_stride(o_head_stride, _1{})); - - GmemTiledCopyO gmem_tiled_copy_O; - auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(thread_idx); - - Tensor tOsO = gmem_thr_copy_O.partition_S(sO); - Tensor tOgO = gmem_thr_copy_O.partition_D(gO); - - Tensor cO = make_identity_tensor(Shape, Int>{}); - - Tensor tOcO = gmem_thr_copy_O.partition_S(cO); - - if (real_seq >= kBlockM) { - copy(gmem_tiled_copy_O, tOsO, tOgO, tOcO); - } else { - copy(gmem_tiled_copy_O, tOsO, tOgO, tOcO, real_seq); - } - } - -}; diff --git a/custom_ops/gpu_ops/moba_attn/moba_encoder_attn/moba_encoder_attn.cu b/custom_ops/gpu_ops/moba_attn/moba_encoder_attn/moba_encoder_attn.cu deleted file mode 100644 index 29d6564ff..000000000 --- a/custom_ops/gpu_ops/moba_attn/moba_encoder_attn/moba_encoder_attn.cu +++ /dev/null @@ -1,384 +0,0 @@ -// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
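// A minimal, single-threaded sketch of the online-softmax accumulation used by
// the warp-specialized mainloop above, combined with the MoBA block skipping
// driven by qk_gate_topk_idx (here `block_step`, whose entries give how many
// KV blocks to jump back to reach the next selected block, as in
// `n_block -= pre_idx`). Everything below is an illustrative host-side
// reference under those assumptions, not the reverted device code.
#include <cmath>
#include <vector>

std::vector<float> attend_selected_blocks(
    const std::vector<std::vector<float>>& block_scores,               // per KV block: q.k scores for one query row
    const std::vector<std::vector<std::vector<float>>>& block_values,  // per KV block: [block_n][head_dim]
    const std::vector<int>& block_step,                                // per-step strides, last block processed first
    float scale) {
    const int head_dim = static_cast<int>(block_values[0][0].size());
    std::vector<float> acc(head_dim, 0.f);
    float row_max = -INFINITY, row_sum = 0.f;

    int n_block = static_cast<int>(block_scores.size()) - 1;
    std::size_t idx = 0;
    while (n_block >= 0) {
        // Rescale the running accumulator whenever the row max grows.
        float new_max = row_max;
        for (float s : block_scores[n_block]) new_max = std::fmax(new_max, s * scale);
        const float rescale = (row_max == -INFINITY) ? 0.f : std::exp(row_max - new_max);
        for (float& a : acc) a *= rescale;
        row_sum *= rescale;

        // Accumulate this block's probabilities and probability-weighted values.
        for (std::size_t j = 0; j < block_scores[n_block].size(); ++j) {
            const float p = std::exp(block_scores[n_block][j] * scale - new_max);
            row_sum += p;
            for (int d = 0; d < head_dim; ++d) acc[d] += p * block_values[n_block][j][d];
        }
        row_max = new_max;

        // MoBA: jump back by the precomputed stride instead of always by one.
        int step = (idx < block_step.size()) ? block_step[idx++] : 1;
        n_block -= (step > 0 ? step : 1);  // strides are expected to be >= 1
    }

    for (float& a : acc) a /= row_sum;
    return acc;
}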
- -#include "paddle/extension.h" -#include - -#include "cutlass/util/print_error.hpp" -#include "cutlass/util/GPU_Clock.hpp" -#if defined(CUTLASS_ENABLE_CUBLAS) && CUTLASS_ENABLE_CUBLAS != 0 -# include "cutlass/util/cublas_wrappers.hpp" -#endif -#include "moba_attn/moba_attn_utils.hpp" -#include "moba_attn/moba_attn.h" -#include "kernel_traits.h" -#include "mainloop_attn.hpp" -#include "softmax.hpp" -#include "cutlass/arch/reg_reconfig.h" - -template -auto get_gmem_layout(int token_num, int head_num) { - return make_layout( - make_shape(token_num, kHeadDim, head_num), - make_stride(head_num * kHeadDim, _1{}, kHeadDim)); -} - -template -__global__ void __launch_bounds__(Ktraits::kNWarps * cutlass::NumThreadsPerWarp, 1) - moba_encoder_attention_kernel( - CUTE_GRID_CONSTANT typename CollectiveMainloopAttn::Params const mainloop_params, - CUTE_GRID_CONSTANT moba_encoder_attn_params const data_params) { - - using Element = typename Ktraits::Element; - using ElementAccum = typename Ktraits::ElementAccum; - using SoftType = ElementAccum; - using TileShape_MNK = typename Ktraits::TileShape_MNK; - using ClusterShape = typename Ktraits::ClusterShape_MNK; - - static constexpr int NumMmaThreads = size(typename Ktraits::TiledMma0{}); - static constexpr int NumCopyThreads = cutlass::NumThreadsPerWarpGroup; - static constexpr int kBlockM = Ktraits::kBlockM; - static constexpr int kBlockN = Ktraits::kBlockN; - constexpr int kHeadDim = Ktraits::kHeadDim; - constexpr int kMaxN = Ktraits::kMaxN; - - using CollectiveMainloop = CollectiveMainloopAttn; - - using MainloopPipeline = typename Ktraits::MainloopPipeline; - using PipelineParams = typename MainloopPipeline::Params; - using PipelineState = typename MainloopPipeline::PipelineState; - - extern __shared__ char shared_memory[]; - auto &shared_storage = *reinterpret_cast(shared_memory); - - const int m_block = blockIdx.x; - const int bidh = blockIdx.y; - const int bidb = blockIdx.z; - - const int seq_len_q = data_params.seq_len_encoder[bidb]; - const int seq_len_k = data_params.cu_seq_k[bidb + 1] - data_params.cu_seq_k[bidb]; - - - if (seq_len_q == 0) { - return; - } - - __align__(16) __shared__ int qk_gate_topk_idx[kMaxN]; - const int *qk_gate_idx_cur_offset = data_params.qk_gate_topk_idx + data_params.cu_seq_q_pack[bidb] / kBlockM * data_params.head_num * kMaxN + (m_block * data_params.head_num + bidh) * kMaxN; - - #pragma unroll - for (int i = threadIdx.x; i < kMaxN / 4; i += Ktraits::kNWarps * cutlass::NumThreadsPerWarp) { - reinterpret_cast(qk_gate_topk_idx)[i] = reinterpret_cast(qk_gate_idx_cur_offset)[i]; - } - - - const int n_block_max = min(cute::ceil_div((m_block + 1) * kBlockM + seq_len_k - seq_len_q, kBlockN), cute::ceil_div(seq_len_k, kBlockN)); - - if (m_block * kBlockM >= seq_len_q) { - return; - } - - int const lane_predicate = cute::elect_one_sync(); - int const warp_idx = cutlass::canonical_warp_idx_sync(); - - if (warp_idx == 0 && lane_predicate) { - CollectiveMainloop::prefetch_tma_descriptors(mainloop_params); - } - - // Obtain warp index - int const warp_group_thread_idx = threadIdx.x % cutlass::NumThreadsPerWarpGroup; - - PipelineParams pipeline_params; - pipeline_params.transaction_bytes = CollectiveMainloop::TmaTransactionBytesK; - int warp_group_idx = cutlass::canonical_warp_group_idx(); - pipeline_params.role = warp_group_idx == 0 - ? 
MainloopPipeline::ThreadCategory::Producer - : MainloopPipeline::ThreadCategory::Consumer; - pipeline_params.is_leader = warp_group_thread_idx == 0; - pipeline_params.num_consumers = NumMmaThreads; - - if (warp_idx == 0 && lane_predicate) { - shared_storage.barrier_Q.init(1); - } - - MainloopPipeline pipeline_k(shared_storage.pipeline_k, pipeline_params, ClusterShape{}); - MainloopPipeline pipeline_v(shared_storage.pipeline_v, pipeline_params, ClusterShape{}); - - __syncthreads(); - - CollectiveMainloop collective_mainloop; - - if (warp_group_idx == 0) { - cutlass::arch::warpgroup_reg_dealloc(); - - int warp_idx_in_warpgroup = __shfl_sync(0xffffffff, (threadIdx.x / 32) % 4, 0); - if (warp_idx_in_warpgroup == 0) { - PipelineState smem_pipe_write_k = cutlass::make_producer_start_state(); - PipelineState smem_pipe_write_v = cutlass::make_producer_start_state(); - - collective_mainloop.load( - mainloop_params, - pipeline_k, - pipeline_v, - smem_pipe_write_k, - smem_pipe_write_v, - shared_storage, - qk_gate_topk_idx, - n_block_max, - m_block, - bidh, - bidb, - data_params.cu_seq_q, - data_params.cu_seq_k, - seq_len_q, - seq_len_k); - } - } else { - cutlass::arch::warpgroup_reg_alloc(); - typename Ktraits::TiledMma1 tiled_mma1; - - collective_mainloop.mma_init(); - - PipelineState smem_pipe_read_k, smem_pipe_read_v; - - Tensor tOrO = partition_fragment_C(tiled_mma1, select<0, 2>(TileShape_MNK{})); - Softmax<2 * (2 * kBlockM / NumMmaThreads)> softmax; - - collective_mainloop.mma( - mainloop_params, - pipeline_k, - pipeline_v, - smem_pipe_read_k, - smem_pipe_read_v, - tOrO, - softmax, - qk_gate_topk_idx, - n_block_max, - threadIdx.x - NumCopyThreads, - m_block, - seq_len_q, - seq_len_k, - shared_storage); - - const int o_head_stride = data_params.head_num * kHeadDim; - const int store_offset = (data_params.cu_seq_q[bidb] + m_block * kBlockM) * o_head_stride + bidh * kHeadDim; - - const int real_seq = seq_len_q - m_block * kBlockM; - - collective_mainloop.store( - mainloop_params, - tOrO, - shared_storage, - tiled_mma1, - threadIdx.x - NumCopyThreads, - o_head_stride, - real_seq, - reinterpret_cast(data_params.o_ptr) + store_offset); - } - -} - - -template -void run_moba_decoder_attn(moba_encoder_attn_params ¶ms, cudaStream_t stream) { - using Element = typename Kernel_traits::Element; - using TileShape_MNK = typename Kernel_traits::TileShape_MNK; - using ClusterShape = typename Kernel_traits::ClusterShape_MNK; - - using CollectiveMainloop = CollectiveMainloopAttn; - constexpr int kHeadDim = Kernel_traits::kHeadDim; - - typename CollectiveMainloop::Params mainloop_params = - CollectiveMainloop::to_underlying_arguments({ - static_cast(params.q_ptr), - get_gmem_layout(params.max_seq_q * params.batch_size, params.head_num), - static_cast(params.k_ptr), - get_gmem_layout(params.max_seq_k * params.batch_size, params.kv_head_num), - static_cast(params.v_ptr), - get_gmem_layout(params.max_seq_k * params.batch_size, params.kv_head_num), - params.scale_softmax_log2 - }); - - int num_blocks_m = cutlass::ceil_div(params.max_seq_q, Kernel_traits::kBlockM); - - num_blocks_m = cutlass::ceil_div(num_blocks_m, size<0>(ClusterShape{})) * size<0>(ClusterShape{}); - - void *kernel; - kernel = (void *)moba_encoder_attention_kernel; - int smem_size = sizeof(typename Kernel_traits::SharedStorage); - - if (smem_size >= 48 * 1024) { - cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); - } - - dim3 grid_dims; - grid_dims.x = num_blocks_m; - grid_dims.y = params.head_num; - grid_dims.z = 
params.batch_size; - - static constexpr int ctaSize = Kernel_traits::kNWarps * 32; - dim3 block_dims(ctaSize); - dim3 cluster_dims(size<0>(ClusterShape{}), size<1>(ClusterShape{}), size<2>(ClusterShape{})); - cutlass::ClusterLaunchParams launch_params{grid_dims, block_dims, cluster_dims, smem_size, stream}; - cutlass::launch_kernel_on_cluster(launch_params, kernel, mainloop_params, params); -} - - -template -void run_moba_encoder_attn_hdim128(moba_encoder_attn_params ¶ms, cudaStream_t stream) { - - constexpr static int Headdim = 128; - constexpr static int kNWarps = kBlockM / 16 + 4; - constexpr static int kStages = 2; - - using Ktraits = moba_encoder_attn_kernel_traits; - run_moba_decoder_attn(params, stream); -} - -template -void DispatchMobaEncoderAttn( - const paddle::Tensor& q_input, - const paddle::Tensor& k_input, - const paddle::Tensor& v_input, - const paddle::Tensor& qk_gate_topk_idx, - const paddle::Tensor& cu_seq_q, - const paddle::Tensor& cu_seq_k, - const paddle::Tensor& cu_seq_q_pack, - const paddle::Tensor& seq_len_encoder, - const paddle::Tensor& seq_len_decoder, - const paddle::Tensor& out, - const int max_seq_q, - const int max_seq_k, - const int head_num, - const int kv_head_num, - const int head_dim, - const int batch_size, - const int max_input_length) { - - constexpr int kBlockM = 128; - constexpr int kBlockN = 128; - constexpr int kMobaBlockSize = 128; - constexpr int kMaxN = 1024; - - using cute_type = typename cuteType::type; - - moba_encoder_attn_params params; - memset(¶ms, 0, sizeof(moba_encoder_attn_params)); - - params.q_ptr = reinterpret_cast(const_cast(q_input.data())); - params.k_ptr = reinterpret_cast(const_cast(k_input.data())); - params.v_ptr = reinterpret_cast(const_cast(v_input.data())); - params.o_ptr = reinterpret_cast(const_cast(out.data())); - params.cu_seq_q = const_cast(cu_seq_q.data()); - params.cu_seq_k = const_cast(cu_seq_k.data()); - params.head_num = head_num; - params.kv_head_num = kv_head_num; - params.max_seq_q = max_seq_q; - params.max_seq_k = max_seq_k; - params.batch_size = batch_size; - params.gqa_group_size = head_num / kv_head_num; - constexpr float kLog2e = 1.4426950408889634074; - params.scale_softmax_log2 = 1.0f / std::sqrt(head_dim) * kLog2e; - params.qk_gate_topk_idx = const_cast(qk_gate_topk_idx.data()); - params.seq_len_encoder = const_cast(seq_len_encoder.data()); - params.cu_seq_q_pack = const_cast(cu_seq_q_pack.data()); - - run_moba_encoder_attn_hdim128(params, out.stream()); -} - -void MobaEncoderAttn( - const paddle::Tensor& q_input, - const paddle::Tensor& k_input, - const paddle::Tensor& v_input, - const paddle::Tensor& qk_gate_topk_idx, - const paddle::Tensor& cu_seq_q, - const paddle::Tensor& cu_seq_k, - const paddle::Tensor& cu_seq_q_pack, - const paddle::Tensor& seq_len_encoder, - const paddle::Tensor& seq_len_decoder, - const paddle::Tensor& out, - const int max_seq_q, - const int max_seq_k, - const int head_num, - const int kv_head_num, - const int head_dim, - const int max_input_length) { - - const int batch_size = seq_len_encoder.dims()[0]; - if (q_input.dtype() == paddle::DataType::FLOAT16) { - return - DispatchMobaEncoderAttn( - q_input, - k_input, - v_input, - qk_gate_topk_idx, - cu_seq_q, - cu_seq_k, - cu_seq_q_pack, - seq_len_encoder, - seq_len_decoder, - out, - max_seq_q, - max_seq_k, - head_num, - kv_head_num, - head_dim, - batch_size, - max_input_length); - } else if (q_input.dtype() == paddle::DataType::BFLOAT16) { - return - DispatchMobaEncoderAttn( - q_input, - k_input, - v_input, - 
qk_gate_topk_idx, - cu_seq_q, - cu_seq_k, - cu_seq_q_pack, - seq_len_encoder, - seq_len_decoder, - out, - max_seq_q, - max_seq_k, - head_num, - kv_head_num, - head_dim, - batch_size, - max_input_length); - } -} - - -PD_BUILD_OP(moba_encoder_attn) - .Inputs({ - "q_input", - "k_input", - "v_input", - "qk_gate_topk_idx", - "cu_seq_q", - "cu_seq_k", - "cu_seq_q_pack", - "seq_len_encoder", - "seq_len_decoder", - "out"}) - .Attrs({ - "max_seq_q: int", - "max_seq_k: int", - "head_num: int", - "kv_head_num: int", - "head_dim: int", - "max_input_length: int"}) - .Outputs({"attn_out"}) - .SetInplaceMap({{"out", "attn_out"}}) - .SetKernelFn(PD_KERNEL(MobaEncoderAttn)); diff --git a/custom_ops/gpu_ops/moba_attn/moba_encoder_attn/moba_encoder_write_cache.cu b/custom_ops/gpu_ops/moba_attn/moba_encoder_attn/moba_encoder_write_cache.cu deleted file mode 100644 index c40da4c24..000000000 --- a/custom_ops/gpu_ops/moba_attn/moba_encoder_attn/moba_encoder_write_cache.cu +++ /dev/null @@ -1,163 +0,0 @@ -// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "paddle/extension.h" -#include "moba_attn/moba_attn.h" - - -template -__global__ void write_encoder_cachekv_c16( - const T * k_input, - const T * v_input, - const int * cu_seq_k, - const int * seq_len_encoder, - const int * seq_len_decoder, - T * cache_k, - T * cache_v, - const int * block_tables, - const int kv_head_num, - const int max_blocks_per_seq) { - - constexpr int kPackSize = 16 / sizeof(T); - const int block_idx = blockIdx.x * kBlockSize; - int bidh = blockIdx.y; - const int bidb = blockIdx.z; - const int tidx = threadIdx.x; - const int row_idx = tidx / (kHeadDim / kPackSize); - const int col_idx = tidx % (kHeadDim / kPackSize) * kPackSize; - const int seq_len = seq_len_encoder[bidb]; - - if (seq_len == 0) return; - - const int ramian_tokens = seq_len - block_idx; - - const int32_t *block_table_now = block_tables + bidb * max_blocks_per_seq; - const uint32_t physical_block_number = block_table_now[blockIdx.x + seq_len_decoder[bidb] / kBlockSize]; - - if (bidh < kv_head_num) { - T * cache = cache_k + physical_block_number * kv_head_num * kBlockSize * kHeadDim + bidh * kBlockSize * kHeadDim + col_idx; - const int base_load_idx = (block_idx + cu_seq_k[bidb]) * kv_head_num * kHeadDim + bidh * kHeadDim + col_idx; - - #pragma unroll - for (int i = row_idx; i < kBlockSize; i += 128 / (kHeadDim / kPackSize)) { - if (i < ramian_tokens) { - *reinterpret_cast(cache + i * kHeadDim) = *reinterpret_cast(k_input + base_load_idx + i * kv_head_num * kHeadDim); - } - } - } else { - bidh -= kv_head_num; - const int base_load_idx = (block_idx + cu_seq_k[bidb]) * kv_head_num * kHeadDim + bidh * kHeadDim + col_idx; - T * cache = cache_v + physical_block_number * kv_head_num * kBlockSize * kHeadDim + bidh * kBlockSize * kHeadDim + col_idx; - - #pragma unroll - for (int i = row_idx; i < kBlockSize; i += 128 / (kHeadDim / kPackSize)) { - if (i < ramian_tokens) { - *reinterpret_cast(cache + i * 
kHeadDim) = *reinterpret_cast(v_input + base_load_idx + i * kv_head_num * kHeadDim); - } - } - - } -} - -void MobaEncoderAttnWriteCacheKv( - const paddle::Tensor& k_input, - const paddle::Tensor& v_input, - const paddle::Tensor& cu_seq_k, - const paddle::Tensor& seq_len_encoder, - const paddle::Tensor& seq_len_decoder, - const paddle::Tensor& cache_k, - const paddle::Tensor& cache_v, - const paddle::Tensor& block_tables, - const paddle::optional& cache_k_quant_scale, - const paddle::optional& cache_v_quant_scale, - const paddle::optional& cache_k_dequant_scale, - const paddle::optional& cache_v_dequant_scale, - const paddle::optional& cache_k_zero_points, - const paddle::optional& cache_v_zero_points, - const int head_num, - const int kv_head_num, - const int head_dim, - const int max_seq_q, - const std::string &cache_quant_type_str) { - - constexpr int kThreads = 128; - constexpr int kHeadDim = 128; - assert(kHeadDim == head_dim); - constexpr int kBlockSize = 64; - const int batch_size = block_tables.dims()[0]; - const int max_blocks_per_seq = block_tables.dims()[1]; - if (cache_quant_type_str == "none") { - dim3 grid_dims; - grid_dims.x = (max_seq_q + kBlockSize - 1) / kBlockSize; - grid_dims.y = kv_head_num * 2; - grid_dims.z = batch_size; - if (k_input.dtype() == paddle::DataType::FLOAT16) { - using T = phi::dtype::float16; - write_encoder_cachekv_c16<<>>( - const_cast(k_input.data()), - const_cast(v_input.data()), - cu_seq_k.data(), - seq_len_encoder.data(), - seq_len_decoder.data(), - const_cast(cache_k.data()), - const_cast(cache_v.data()), - block_tables.data(), - kv_head_num, - max_blocks_per_seq); - } else if (k_input.dtype() == paddle::DataType::BFLOAT16) { - using T = phi::dtype::bfloat16; - write_encoder_cachekv_c16<<>>( - const_cast(k_input.data()), - const_cast(v_input.data()), - cu_seq_k.data(), - seq_len_encoder.data(), - seq_len_decoder.data(), - const_cast(cache_k.data()), - const_cast(cache_v.data()), - block_tables.data(), - kv_head_num, - max_blocks_per_seq); - } - } else { - PADDLE_THROW(phi::errors::Unimplemented( - "Quantized cache not implemented for cache_quant_type = %s", cache_quant_type_str.c_str())); - } -} - -PD_BUILD_OP(moba_encoder_attn_write_cache_kv) - .Inputs({ - "k_input", - "v_input", - "cu_seq_k", - "seq_len_encoder", - "seq_len_decoder", - "cache_k", - "cache_v", - "block_tables", - paddle::Optional("cache_k_quant_scale"), - paddle::Optional("cache_v_quant_scale"), - paddle::Optional("cache_k_dequant_scale"), - paddle::Optional("cache_v_dequant_scale"), - paddle::Optional("cache_k_zero_points"), - paddle::Optional("cache_v_zero_points")}) - .Attrs({ - "head_num: int", - "kv_head_num: int", - "head_dim: int", - "max_seq_q: int", - "cache_quant_type_str: std::string"}) - .Outputs({"cache_k_out", "cache_v_out"}) - .SetInplaceMap({{"cache_k", "cache_k_out"}, - {"cache_v", "cache_v_out"}}) - .SetKernelFn(PD_KERNEL(MobaEncoderAttnWriteCacheKv)); diff --git a/custom_ops/gpu_ops/moba_attn/moba_encoder_attn/moba_qk_sort_encoder.cu b/custom_ops/gpu_ops/moba_attn/moba_encoder_attn/moba_qk_sort_encoder.cu deleted file mode 100644 index 2b190ede5..000000000 --- a/custom_ops/gpu_ops/moba_attn/moba_encoder_attn/moba_qk_sort_encoder.cu +++ /dev/null @@ -1,341 +0,0 @@ -// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "paddle/extension.h" -#include "moba_attn/moba_attn_utils.hpp" - -template -__global__ void qk_gate_sort_encoder_kernel( - const T* qk_gate_weight, - int * qk_gate_topk_idx, - const int *seq_len_encoder, - const int *seq_len_decoder, - const int* cu_seq_q, - const int* cu_seq_k, - const int* cu_seq_q_pack, - const int use_moba_seq_limit, - const int max_seq_q, - const int max_seq_k, - const int head_num, - const int kv_head_num, - const int kGqaGroupSize, - const int top_k_left, - const int top_k_right) { - - const int bidt = blockIdx.x * kBlockM; - const int bidh = blockIdx.y; - const int bidb = blockIdx.z; - const int tidx = threadIdx.x; - - constexpr int kPackSize = kBlockMaxN / knthreads; - - static_assert(kBlockMaxN % knthreads == 0); - - const int seq_len_q = seq_len_encoder[bidb]; - - if (seq_len_q == 0 || bidt >= seq_len_q) { - return; - } - - const int seq_len_k = (bidt + kBlockM + seq_len_decoder[bidb]); - - const int seq_len_moba = seq_len_k / moba_block_size; - - using SrcType = Vec; - using SrcType_f = Vec; - using SrcType_i = Vec; - - SrcType src; - SrcType_f src_f; - - SrcType_i select_idx; - - select_idx.set_zero(); - - const int store_idx = cu_seq_q_pack[bidb] / kBlockM * head_num * kBlockMaxN + bidh * kBlockMaxN + blockIdx.x * head_num * kBlockMaxN + tidx * kPackSize; - - if (seq_len_k < use_moba_seq_limit) { - #pragma unroll - for (int i = 0; i < kPackSize; i++) { - select_idx.data.elt[i] = 1; - } - select_idx.store_to(qk_gate_topk_idx + store_idx); - return; - } - - const int load_offset = (cu_seq_q[bidb] + bidt) * head_num * kBlockMaxN + bidh * kBlockMaxN + tidx * kPackSize; - const int data_len = seq_len_moba - tidx * kPackSize; - - #pragma unroll - for (int t = 0; t < kBlockM; t++) { - if (bidt + t >= seq_len_q) { - break; - } - src.load_from(qk_gate_weight + load_offset + t * head_num * kBlockMaxN); - float min_global = FLT_MAX; - float max_global = -FLT_MAX; - #pragma unroll - for (int i = 0; i < kPackSize; i++) { - if (i < data_len) { - src_f.data.elt[i] = float(src.data.elt[i]); - min_global = min(min_global, src_f.data.elt[i]); - } else { - src_f.data.elt[i] = -FLT_MAX; - } - max_global = max(max_global, src_f.data.elt[i]); - } - - max_global = BlockAllReduce, knthreads>(max_global); - min_global = BlockAllReduce, knthreads>(min_global); - - float right_limit = max_global; - float left_limit = min_global; - - float mid_limit; - int count; - - if (right_limit == left_limit) { - mid_limit = (left_limit + right_limit) * 0.5f; - } else { - #pragma unroll - for (int i = 0; i < searchtimes; i++) { - mid_limit = (left_limit + right_limit) * 0.5f; - count = get_data_count(src_f.data.elt, mid_limit); - if (count < top_k_left) { - right_limit = mid_limit; - } else if (count > top_k_right) { - left_limit = mid_limit; - } else { - break; - } - } - } - - #pragma unroll - for (int i = 0; i < kPackSize; i++) { - if (src_f.data.elt[i] >= mid_limit) { - select_idx.data.elt[i] = 1; - } - } - } - - if (tidx == 0) { - select_idx.data.elt[0] = 1; - } - - __align__(16) __shared__ int qk_gate_mem[kBlockMaxN]; - __align__(16) __shared__ int 
qk_continue_idx_mem[kBlockMaxN]; - select_idx.store_to(qk_gate_mem + tidx * kPackSize); - - __syncthreads(); - - if (tidx == 0) { - int cur_idx = 0; - int idx = -1; - const int last_idx = seq_len_moba - 1; - while (last_idx + idx >= 0 && qk_gate_mem[last_idx + idx] == 0) { - idx--; - } - qk_continue_idx_mem[cur_idx] = -idx; - cur_idx++; - - for (int i = last_idx - 1; i >= 0; --i) { - if (qk_gate_mem[i] == 1) { - int idx = -1; - while (i + idx >= 0 && qk_gate_mem[i + idx] == 0) { - idx--; - } - qk_continue_idx_mem[cur_idx] = -idx; - cur_idx++; - } - } - qk_continue_idx_mem[cur_idx] = 10000000; - } - - __syncthreads(); - - *reinterpret_cast(qk_gate_topk_idx + store_idx) = reinterpret_cast(qk_continue_idx_mem)[tidx]; -} - -template -void qk_gate_sort_encoder( - const T* qk_gate_weight, - int * qk_gate_topk_idx, - const int *seq_len_encoder, - const int *seq_len_decoder, - const int* cu_seq_q, - const int* cu_seq_k, - const int* cu_seq_q_pack, - const int use_moba_seq_limit, - const int max_seq_q, - const int max_seq_k, - const int head_num, - const int kv_head_num, - const int batch_size, - const int top_k_left, - const int top_k_right, - cudaStream_t stream) { - - constexpr int kPackSize = 16 / sizeof(T); - - const int gqa_group_size = head_num / kv_head_num; - const int knthreads = kMaxN / kPackSize; - const int searchtimes = 6; - - dim3 grid_dims; - grid_dims.x = (max_seq_q + kBlockM - 1) / kBlockM; - grid_dims.y = head_num; - grid_dims.z = batch_size; - - constexpr auto kernel = qk_gate_sort_encoder_kernel; - - kernel<<>>( - qk_gate_weight, - qk_gate_topk_idx, - seq_len_encoder, - seq_len_decoder, - cu_seq_q, - cu_seq_k, - cu_seq_q_pack, - use_moba_seq_limit, - max_seq_q, - max_seq_k, - head_num, - kv_head_num, - gqa_group_size, - top_k_left, - top_k_right); -} -template -std::vector DispatchQkSortEncoder( - const paddle::Tensor& qk_gate_weight, - const paddle::Tensor& seq_len_encoder, - const paddle::Tensor& seq_len_decoder, - const paddle::Tensor& cu_seq_q, - const paddle::Tensor& cu_seq_k, - const paddle::Tensor& cu_seq_q_pack, - const paddle::Tensor& q_pack_tokens, - const int max_seq_q, - const int max_seq_k, - const int head_num, - const int kv_head_num, - const int top_k_left, - const int top_k_right, - const int use_moba_seq_limit) { - constexpr int kBlockM = 128; - constexpr int kBlockN = 128; - constexpr int kMobaBlockSize = 128; - constexpr int kMaxN = 1024; - using cute_type = typename cuteType::type; - const int batch_size = seq_len_encoder.dims()[0]; - - paddle::Tensor qk_gate_topk_idx = paddle::empty({q_pack_tokens.data()[0] / kBlockM, head_num, kMaxN}, paddle::DataType::INT32, qk_gate_weight.place()); - - qk_gate_sort_encoder( - reinterpret_cast(qk_gate_weight.data()), - qk_gate_topk_idx.data(), - seq_len_encoder.data(), - seq_len_decoder.data(), - cu_seq_q.data(), - cu_seq_k.data(), - cu_seq_q_pack.data(), - use_moba_seq_limit, - max_seq_q, - max_seq_k, - head_num, - kv_head_num, - batch_size, - top_k_left, - top_k_right, - qk_gate_weight.stream()); - - return {qk_gate_topk_idx}; -} - - -std::vector QkSortEncoder( - const paddle::Tensor& qk_gate_weight, - const paddle::Tensor& seq_len_encoder, - const paddle::Tensor& seq_len_decoder, - const paddle::Tensor& cu_seq_q, - const paddle::Tensor& cu_seq_k, - const paddle::Tensor& cu_seq_q_pack, - const paddle::Tensor& q_pack_tokens, - const int max_seq_q, - const int max_seq_k, - const int head_num, - const int kv_head_num, - const int top_k_left, - const int top_k_right, - const int use_moba_seq_limit) { - if 
(qk_gate_weight.dtype() == paddle::DataType::FLOAT16) { - return std::move( - DispatchQkSortEncoder( - qk_gate_weight, - seq_len_encoder, - seq_len_decoder, - cu_seq_q, - cu_seq_k, - cu_seq_q_pack, - q_pack_tokens, - max_seq_q, - max_seq_k, - head_num, - kv_head_num, - top_k_left, - top_k_right, - use_moba_seq_limit - ) - ); - } else if (qk_gate_weight.dtype() == paddle::DataType::BFLOAT16) { - return std::move( - DispatchQkSortEncoder( - qk_gate_weight, - seq_len_encoder, - seq_len_decoder, - cu_seq_q, - cu_seq_k, - cu_seq_q_pack, - q_pack_tokens, - max_seq_q, - max_seq_k, - head_num, - kv_head_num, - top_k_left, - top_k_right, - use_moba_seq_limit - ) - ); - } -} - -PD_BUILD_OP(moba_qk_sort_encoder) - .Inputs({ - "qk_gate_weight", - "seq_len_encoder", - "seq_len_decoder", - "cu_seq_q", - "cu_seq_k", - "cu_seq_q_pack", - "q_pack_tokens"}) - .Attrs({ - "max_seq_q: int", - "max_seq_k: int", - "head_num: int", - "kv_head_num: int", - "top_k_left: int", - "top_k_right: int", - "use_moba_seq_limit: int"}) - .Outputs({"qk_gate_topk_idx"}) - .SetKernelFn(PD_KERNEL(QkSortEncoder)); diff --git a/custom_ops/gpu_ops/moba_attn/moba_encoder_attn/softmax.hpp b/custom_ops/gpu_ops/moba_attn/moba_encoder_attn/softmax.hpp deleted file mode 100644 index fc7832d15..000000000 --- a/custom_ops/gpu_ops/moba_attn/moba_encoder_attn/softmax.hpp +++ /dev/null @@ -1,194 +0,0 @@ -// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once - -#include - -#include - -#include - -#include "../moba_attn_utils.hpp" - -using namespace cute; - -template -__device__ __forceinline__ void thread_reduce_(Tensor const &tensor, Tensor &summary, Operator &op) { - static_assert(Layout0::rank == 2, "Only support 2D Tensor"); - static_assert(Layout1::rank == 1, "Only support 1D Tensor"); - CUTE_STATIC_ASSERT_V(size<0>(summary) == size<0>(tensor)); - #pragma unroll - for (int mi = 0; mi < size<0>(tensor); mi++) { - summary(mi) = zero_init ? 
tensor(mi, 0) : op(summary(mi), tensor(mi, 0)); - #pragma unroll - for (int ni = 1; ni < size<1>(tensor); ni++) { - summary(mi) = op(summary(mi), tensor(mi, ni)); - } - } -} - -template -__device__ __forceinline__ void quad_allreduce_(Tensor &dst, Tensor &src, Operator &op) { - CUTE_STATIC_ASSERT_V(size(dst) == size(src)); - #pragma unroll - for (int i = 0; i < size(dst); i++){ - dst(i) = Allreduce<4>::run(src(i), op); - } -} - -template -__device__ __forceinline__ void reduce_(Tensor const& tensor, Tensor &summary, Operator &op) { - thread_reduce_(tensor, summary, op); - quad_allreduce_(summary, summary, op); -} - -template -__device__ __forceinline__ void reduce_max(Tensor const& tensor, Tensor &max){ - MaxOp max_op; - reduce_(tensor, max, max_op); -} - -template -__device__ __forceinline__ void reduce_sum(Tensor const& tensor, Tensor &sum){ - SumOp sum_op; - thread_reduce_(tensor, sum, sum_op); - if constexpr (warp_reduce) { quad_allreduce_(sum, sum, sum_op); } -} - -__forceinline__ __device__ __half2 half_exp(__half2 x) { - uint32_t tmp_out, tmp_in; - tmp_in = reinterpret_cast(x); - asm ("ex2.approx.f16x2 %0, %1;\n" - : "=r"(tmp_out) - : "r"(tmp_in)); - __half2 out = reinterpret_cast<__half2&>(tmp_out); - return out; -} - -// Apply the exp to all the elements. -template -__forceinline__ __device__ void max_scale_exp2_sum(Tensor &tensor, Tensor &max, Tensor &sum, const float scale) { - static_assert(Layout0::rank == 2, "Only support 2D Tensor"); static_assert(Layout1::rank == 1, "Only support 1D Tensor"); CUTE_STATIC_ASSERT_V(size<0>(max) == size<0>(tensor)); - #pragma unroll - for (int mi = 0; mi < size<0>(tensor); ++mi) { - MaxOp max_op; - max(mi) = zero_init ? tensor(mi, 0) : max_op(max(mi), tensor(mi, 0)); - #pragma unroll - for (int ni = 1; ni < size<1>(tensor); ni++) { - max(mi) = max_op(max(mi), tensor(mi, ni)); - } - max(mi) = Allreduce<4>::run(max(mi), max_op); - const float max_scaled = max(mi) == -INFINITY ? 
0.f : max(mi) * scale; - sum(mi) = 0; - #pragma unroll - for (int ni = 0; ni < size<1>(tensor); ++ni) { - tensor(mi, ni) = exp2f(tensor(mi, ni) * scale - max_scaled); - sum(mi) += tensor(mi, ni); - } - } -} - - -template -__forceinline__ __device__ void scale_apply_exp2(Tensor &tensor, Tensor const &max, const float scale) { - static_assert(Layout0::rank == 2, "Only support 2D Tensor"); - static_assert(Layout1::rank == 1, "Only support 1D Tensor"); - CUTE_STATIC_ASSERT_V(size<0>(max) == size<0>(tensor)); - #pragma unroll - for (int mi = 0; mi < size<0>(tensor); ++mi) { - const float max_scaled = max(mi) * scale; - #pragma unroll - for (int ni = 0; ni < size<1>(tensor); ++ni) { - tensor(mi, ni) = exp2f(tensor(mi, ni) * scale - max_scaled); - } - } -} - - -template -struct Softmax { - - using TensorT = decltype(make_tensor(Shape>{})); - TensorT row_max, row_sum; - - CUTLASS_DEVICE Softmax() {}; - - template - __forceinline__ __device__ TensorT max(Tensor0 &acc_s, float softmax_scale_log2) { - Tensor scores = make_tensor(acc_s.data(), convert_layout_acc_rowcol(acc_s.layout())); - static_assert(decltype(size<0>(scores))::value == kNRows); - TensorT scores_scale; - if constexpr (Is_first) { - reduce_max(scores, row_max); - cute::fill(scores_scale, 1.f); - } else { - Tensor scores_max_prev = make_fragment_like(row_max); - cute::copy(row_max, scores_max_prev); - reduce_max(scores, row_max); - #pragma unroll - for (int mi = 0; mi < size(row_max); ++mi) { - float scores_max_cur = row_max(mi); - scores_scale(mi) = exp2f((scores_max_prev(mi) - scores_max_cur) * softmax_scale_log2); - row_sum(mi) *= scores_scale(mi); - } - } - return scores_scale; - }; - - template - __forceinline__ __device__ TensorT online_softmax(Tensor0 &acc_s, float softmax_scale_log2) { - Tensor scores = make_tensor(acc_s.data(), convert_layout_acc_rowcol(acc_s.layout())); - static_assert(decltype(size<0>(scores))::value == kNRows); - TensorT scores_scale; - if constexpr (Is_first) { - reduce_max(scores, row_max); - scale_apply_exp2(scores, row_max, softmax_scale_log2); - reduce_sum(scores, row_sum); - cute::fill(scores_scale, 1.f); - } else { - scale_apply_exp2(scores, row_max, softmax_scale_log2); - reduce_sum(scores, row_sum); - } - return scores_scale; - }; - - __forceinline__ __device__ TensorT finalize(float softmax_scale_log2) { - SumOp sum_op; - quad_allreduce_(row_sum, row_sum, sum_op); - TensorT scores_scale; - #pragma unroll - for (int mi = 0; mi < size(row_max); ++mi) { - float sum = row_sum(mi); - float inv_sum = 1.0f / sum; - row_sum(mi) = row_max(mi) * (softmax_scale_log2 * float(M_LN2)) + __logf(sum); - scores_scale(mi) = inv_sum; - } - return scores_scale; - }; - - template - __forceinline__ __device__ void rescale_o(Tensor1 &acc_o, TensorT const &scores_scale) { - Tensor acc_o_rowcol = make_tensor(acc_o.data(), convert_layout_acc_rowcol(acc_o.layout())); - static_assert(decltype(size<0>(acc_o_rowcol))::value == kNRows); - #pragma unroll - for (int mi = 0; mi < size(row_max); ++mi) { - #pragma unroll - for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { - acc_o_rowcol(mi, ni) *= scores_scale(mi); - } - } - }; - -}; diff --git a/custom_ops/gpu_ops/moba_attn/moba_process/moba_get_kv_from_cache.cu b/custom_ops/gpu_ops/moba_attn/moba_process/moba_get_kv_from_cache.cu deleted file mode 100644 index 712a8e62d..000000000 --- a/custom_ops/gpu_ops/moba_attn/moba_process/moba_get_kv_from_cache.cu +++ /dev/null @@ -1,288 +0,0 @@ -// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. 
-// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "paddle/extension.h" -#include "moba_attn/moba_attn_utils.hpp" -#include "moba_attn/moba_attn.h" - -template -__global__ void get_kv_from_cache_c16_kernel( - T * k_input, - T * v_input, - const int * seq_len_encoder, - const int * seq_len_decoder, - const int * cu_seq_k, - const T * cache_k, - const T * cache_v, - const int * block_tables, - const int kv_head_num, - const int head_dim, - const int batch_size, - const int max_input_length, - const int max_blocks_per_seq) { - - const int block_idx = blockIdx.x; - int bidh = blockIdx.y; - const int bidb = blockIdx.z; - const int seq_len = seq_len_decoder[bidb] + seq_len_encoder[bidb]; - const int tidx = threadIdx.x; - const int base_token_idx = block_idx * kBlockSize; - - if (base_token_idx >= seq_len || seq_len_encoder[bidb] == 0) { - return; - } - - constexpr int kPackSize = 16 / sizeof(T); - - const int row_idx = tidx / (kHeadDim / kPackSize); - const int col_idx = tidx % (kHeadDim / kPackSize) * kPackSize; - const int physical_block_number = block_tables[bidb * max_blocks_per_seq + block_idx]; - - - const int ramian_tokens = seq_len - base_token_idx; - - if (bidh < kv_head_num) { - const int cache_offset = physical_block_number * kv_head_num * kBlockSize * kHeadDim + bidh * kBlockSize * kHeadDim + col_idx; - const int base_store_idx = (base_token_idx + cu_seq_k[bidb]) * kv_head_num * kHeadDim + bidh * kHeadDim + col_idx; - #pragma unroll - for (int i = row_idx; i < kBlockSize; i += 128 / (kHeadDim / kPackSize)) { - if (i < ramian_tokens) { - *reinterpret_cast(k_input + base_store_idx + i * kv_head_num * kHeadDim) = *reinterpret_cast(cache_k + cache_offset + i * kHeadDim); - } - } - } else { - bidh -= kv_head_num; - const int cache_offset = physical_block_number * kv_head_num * kBlockSize * kHeadDim + bidh * kBlockSize * kHeadDim + col_idx; - const int base_store_idx = (base_token_idx + cu_seq_k[bidb]) * kv_head_num * kHeadDim + bidh * kHeadDim + col_idx; - #pragma unroll - for (int i = row_idx; i < kBlockSize; i += 128 / (kHeadDim / kPackSize)) { - if (i < ramian_tokens) { - *reinterpret_cast(v_input + base_store_idx + i * kv_head_num * kHeadDim) = *reinterpret_cast(cache_v + cache_offset + i * kHeadDim); - } - } - } -} - -template -void get_kv_from_cache( - T * k_input, - T * v_input, - const int * seq_len_encoder, - const int * seq_len_decoder, - const int * cu_seq_k, - const void * cache_k, - const void * cache_v, - const int * block_tables, - const T * cache_k_dequant_scale, - const T * cache_v_dequant_scale, - const T * cache_k_zero_points, - const T * cache_v_zero_points, - const int kv_head_num, - const int head_dim, - const int max_seq_k, - const int batch_size, - const int max_input_length, - const int max_blocks_per_seq, - const std::string &cache_quant_type_str, - cudaStream_t stream) { - - constexpr int kThreads = 128; - constexpr int kHeadDim = 128; - assert(kHeadDim == head_dim); - constexpr int kBlockSize = 64; - if (cache_quant_type_str == "none") { - dim3 
grid_dims; - grid_dims.x = (max_seq_k + kBlockSize - 1) / kBlockSize; - grid_dims.y = kv_head_num * 2; - grid_dims.z = batch_size; - get_kv_from_cache_c16_kernel<<>>( - k_input, - v_input, - seq_len_encoder, - seq_len_decoder, - cu_seq_k, - reinterpret_cast(cache_k), - reinterpret_cast(cache_v), - block_tables, - kv_head_num, - head_dim, - batch_size, - max_input_length, - max_blocks_per_seq); - } else { - PD_THROW("Only supported cache_quant_type_str in ['none']."); - } -} - -void GetKVFromCache( - const paddle::Tensor& k_input, - const paddle::Tensor& v_input, - const paddle::Tensor& cu_seq_k, - const paddle::Tensor& seq_len_encoder, - const paddle::Tensor& seq_len_decoder, - const paddle::Tensor& cache_k, - const paddle::Tensor& cache_v, - const paddle::Tensor& block_tables, - const paddle::optional& cache_k_dequant_scale, - const paddle::optional& cache_v_dequant_scale, - const paddle::optional& cache_k_zero_points, - const paddle::optional& cache_v_zero_points, - const int head_num, - const int kv_head_num, - const int head_dim, - const int max_input_length, - const int max_seq_k, - const std::string &cache_quant_type_str) { - - if (k_input.dtype() == paddle::DataType::FLOAT16) { - using T = phi::dtype::float16; - using cute_type = typename cuteType::type; - get_kv_from_cache( - reinterpret_cast(const_cast(k_input.data())), - reinterpret_cast(const_cast(v_input.data())), - seq_len_encoder.data(), - seq_len_decoder.data(), - cu_seq_k.data(), - cache_k.data(), - cache_v.data(), - block_tables.data(), - cache_k_dequant_scale ? reinterpret_cast(const_cast(cache_k_dequant_scale.get().data())) : nullptr, - cache_v_dequant_scale ? reinterpret_cast(const_cast(cache_v_dequant_scale.get().data())) : nullptr, - cache_k_zero_points ? reinterpret_cast(const_cast(cache_k_zero_points.get().data())) : nullptr, - cache_v_zero_points ? reinterpret_cast(const_cast(cache_v_zero_points.get().data())) : nullptr, - kv_head_num, - head_dim, - max_seq_k, - seq_len_encoder.dims()[0], - max_input_length, - block_tables.dims()[1], - cache_quant_type_str, - k_input.stream()); - } else if (k_input.dtype() == paddle::DataType::BFLOAT16) { - using T = phi::dtype::bfloat16; - using cute_type = typename cuteType::type; - get_kv_from_cache( - reinterpret_cast(const_cast(k_input.data())), - reinterpret_cast(const_cast(v_input.data())), - seq_len_encoder.data(), - seq_len_decoder.data(), - cu_seq_k.data(), - cache_k.data(), - cache_v.data(), - block_tables.data(), - cache_k_dequant_scale ? reinterpret_cast(const_cast(cache_k_dequant_scale.get().data())) : nullptr, - cache_v_dequant_scale ? reinterpret_cast(const_cast(cache_v_dequant_scale.get().data())) : nullptr, - cache_k_zero_points ? reinterpret_cast(const_cast(cache_k_zero_points.get().data())) : nullptr, - cache_v_zero_points ? 
reinterpret_cast(const_cast(cache_v_zero_points.get().data())) : nullptr, - kv_head_num, - head_dim, - max_seq_k, - seq_len_encoder.dims()[0], - max_input_length, - block_tables.dims()[1], - cache_quant_type_str, - k_input.stream()); - } -} - -__global__ void get_cur_cu_seq_len_k_kernel( - const int* __restrict__ seq_lens_encoder, - const int* __restrict__ seq_lens_decoder, - const int* __restrict__ seq_lens_this_time, - int* __restrict__ cu_seqlens_k, - int* __restrict__ cu_seq_q_pack, - int* __restrict__ q_pack_tokens, - const int pack_size, - const int bsz) { - - int total_tokens = 0; - cu_seqlens_k[0] = 0; - cu_seq_q_pack[0] = 0; - - for (uint32_t bid = 0; bid < bsz; bid++) { - int cache_len = seq_lens_decoder[bid]; - const int q_len = seq_lens_encoder[bid]; - if (q_len <= 0) { - cache_len = 0; - } - total_tokens += (cache_len + q_len); - cu_seqlens_k[bid + 1] = total_tokens; - cu_seq_q_pack[bid + 1] = cu_seq_q_pack[bid] + (q_len + pack_size -1) / pack_size * pack_size; - } - q_pack_tokens[0] = cu_seq_q_pack[bsz]; -} - -std::vector GetCurCuSeqLenk( - const paddle::Tensor& seq_lens_encoder, - const paddle::Tensor& seq_lens_decoder, - const paddle::Tensor& seq_lens_this_time, - const int pack_size) { - auto stream = seq_lens_decoder.stream(); - auto place = seq_lens_decoder.place(); - int bsz = seq_lens_this_time.shape()[0]; - - paddle::Tensor cu_seq_q_pack = paddle::empty({bsz + 1}, paddle::DataType::INT32, place); - paddle::Tensor cu_seqlens_k = paddle::empty({bsz + 1}, paddle::DataType::INT32, place); - paddle::Tensor q_pack_tokens = paddle::empty({1}, paddle::DataType::INT32, place); - - get_cur_cu_seq_len_k_kernel<<<1, 1, 0, stream>>>( - seq_lens_encoder.data(), - seq_lens_decoder.data(), - seq_lens_this_time.data(), - cu_seqlens_k.data(), - cu_seq_q_pack.data(), - q_pack_tokens.data(), - pack_size, - bsz - ); - - auto q_pack_tokens_cpu = q_pack_tokens.copy_to(paddle::CPUPlace(), true); - return {cu_seq_q_pack, cu_seqlens_k, q_pack_tokens_cpu}; -} - -PD_BUILD_OP(get_kv_from_cache) - .Inputs({ - "k_input", - "v_input", - "cu_seq_k", - "seq_len_encoder", - "seq_len_decoder", - "cache_k", - "cache_v", - "block_tables", - paddle::Optional("cache_k_dequant_scale"), - paddle::Optional("cache_v_dequant_scale"), - paddle::Optional("cache_k_zero_points"), - paddle::Optional("cache_v_zero_points")}) - .Attrs({ - "head_num: int", - "kv_head_num: int", - "head_dim: int", - "max_input_length: int", - "max_seq_k: int", - "cache_quant_type_str: std::string"}) - .Outputs({"k_input_out", "v_input_out"}) - .SetInplaceMap({{"k_input", "k_input_out"}, - {"v_input", "v_input_out"}}) - .SetKernelFn(PD_KERNEL(GetKVFromCache)); - -PD_BUILD_OP(get_cur_cu_seq_len_k) - .Inputs({ - "seq_lens_encoder", - "seq_lens_decoder", - "seq_lens_this_time"}) - .Attrs({ - "pack_size: int"}) - .Outputs({"cu_seq_q_pack", "cu_seqlens_k", "q_pack_tokens"}) - .SetKernelFn(PD_KERNEL(GetCurCuSeqLenk)); diff --git a/custom_ops/gpu_ops/moba_attn/moba_process/moba_mlp_einsum.cu b/custom_ops/gpu_ops/moba_attn/moba_process/moba_mlp_einsum.cu deleted file mode 100644 index 2354f4106..000000000 --- a/custom_ops/gpu_ops/moba_attn/moba_process/moba_mlp_einsum.cu +++ /dev/null @@ -1,221 +0,0 @@ -// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "paddle/extension.h" -#include "moba_attn/moba_attn_utils.hpp" -#include "moba_attn/moba_attn.h" - - -template -__global__ void moba_mlp_einsum_kernel( - const T * src_data, - const T * weight_data, - const int * seq_lens_encoder, - const int * seq_lens_decoder, - const int * cu_seq_k, - T * dst_data, - const int head_num) { - - constexpr int kPackSize = 16 / sizeof(T); - const int block_idx = blockIdx.x; - const int bidh = blockIdx.y; - const int bidb = blockIdx.z; - const int tidx = threadIdx.x; - const int lane_id = tidx % 32; - const int warp_id = tidx / 32; - - __align__(16) __shared__ T local_sum_mem[128 / 32 * kHeadDim]; - - const int seq_len_encoder = seq_lens_encoder[bidb]; - const int seq_len_decoder = seq_len_encoder + seq_lens_decoder[bidb]; - - const int seq_len_this_block = seq_len_decoder - block_idx * moba_block_size; - - if (seq_len_encoder == 0 || seq_len_this_block <= 0) { - return; - } - - - using SrcType = Vec; - - constexpr int tidx_per_row = kHeadDim / kPackSize; - - const int row_idx = tidx / tidx_per_row; - const int col_idx = tidx % tidx_per_row * kPackSize; - - const int src_base_idx = cu_seq_k[bidb] * head_num * kHeadDim + block_idx * moba_block_size * head_num * kHeadDim + bidh * kHeadDim + row_idx * head_num * kHeadDim + col_idx; - const int weight_base_idx = bidh * kHeadDim * moba_block_size + row_idx * kHeadDim + col_idx; - - constexpr int step = 128 / tidx_per_row; - - SrcType sums, src, weight; - - sums.set_zero(); - - for (int i = 0; i < moba_block_size; i += step) { - if (i >= seq_len_this_block) { - break; - } - src.load_from(src_data + src_base_idx + i * head_num * kHeadDim); - weight.load_from(weight_data + weight_base_idx + i * kHeadDim); - sums.fma(src, weight); - } - - SrcType neighbor; - - #pragma unroll - for (int i = 0; i < kPackSize; i+=2) { - *reinterpret_cast(neighbor.data.elt + i) = __shfl_down_sync(0xffffffff, *reinterpret_cast(sums.data.elt + i), 16); - } - - sums.add(neighbor); - - if (lane_id < 16) { - sums.store_to(local_sum_mem + warp_id * kHeadDim + lane_id * kPackSize); - } - - __syncthreads(); - using pack_half = std::conditional_t::value, __half2, nv_bfloat162>; - pack_half * local_sum_mem_half = reinterpret_cast(local_sum_mem); - - if (tidx < kHeadDim / 2) { - pack_half local_sum_half = local_sum_mem_half[tidx]; - #pragma unroll - for (int i = 1; i < 4; i++) { - local_sum_half += local_sum_mem_half[tidx + i * (kHeadDim / 2)]; - } - local_sum_mem_half[tidx] = local_sum_half; - } - - __syncthreads(); - - const int store_row_id = tidx / (kHeadDim / kPackSize); - const int store_col_id = tidx % (kHeadDim / kPackSize) * kPackSize; - - sums.load_from(local_sum_mem + store_col_id); - - const int base_store_idx = bidb * kMaxN * head_num * kHeadDim + (block_idx * (moba_block_size / 128) + store_row_id) * head_num * kHeadDim + bidh * kHeadDim + store_col_id; - - if (store_row_id < moba_block_size / 128) { - sums.store_to(dst_data + base_store_idx); - } -} - - -template -void moba_mlp_einsum( - const T * src_data, - const T * weight_data, - const int * seq_lens_encoder, - const int * seq_lens_decoder, - const int 
* cu_seq_k, - T * dst_data, - const int moba_block_size, - const int max_seq_len, - const int head_num, - const int batch_size, - cudaStream_t stream) { - - dim3 grid_dims; - grid_dims.x = (max_seq_len + moba_block_size - 1) / moba_block_size; - grid_dims.y = head_num; - grid_dims.z = batch_size; - - if (moba_block_size == 1024) { - moba_mlp_einsum_kernel<<>>( - src_data, - weight_data, - seq_lens_encoder, - seq_lens_decoder, - cu_seq_k, - dst_data, - head_num); - } else if (moba_block_size == 128) { - moba_mlp_einsum_kernel<<>>( - src_data, - weight_data, - seq_lens_encoder, - seq_lens_decoder, - cu_seq_k, - dst_data, - head_num); - } else { - PADDLE_THROW(phi::errors::Unimplemented( - "MobaMlpEinsum not implemented for moba_block_size = %d", moba_block_size)); - } - -} - - -std::vector MobaMlpEinsum( - const paddle::Tensor& k_input, - const paddle::Tensor& attn_gate_weight, - const paddle::Tensor& seq_lens_encoder, - const paddle::Tensor& seq_lens_decoder, - const paddle::Tensor& cu_seq_k, - const int max_seq_len, - const int kv_head_num) { - - const int kHeadDim = 128; - const int kMaxN = 1024; - const int moba_block_size = attn_gate_weight.dims()[1]; - const int batch_size = seq_lens_encoder.dims()[0]; - paddle::Tensor k_gate_weight = paddle::zeros({batch_size, kMaxN, kv_head_num, kHeadDim}, k_input.dtype(), k_input.place()); - - if (k_input.dtype() == paddle::DataType::FLOAT16) { - using T = phi::dtype::float16; - moba_mlp_einsum( - const_cast(k_input.data()), - const_cast(attn_gate_weight.data()), - const_cast(seq_lens_encoder.data()), - const_cast(seq_lens_decoder.data()), - const_cast(cu_seq_k.data()), - k_gate_weight.data(), - moba_block_size, - max_seq_len, - kv_head_num, - batch_size, - k_input.stream() - ); - } else if (k_input.dtype() == paddle::DataType::BFLOAT16) { - using T = phi::dtype::bfloat16; - moba_mlp_einsum( - const_cast(k_input.data()), - const_cast(attn_gate_weight.data()), - const_cast(seq_lens_encoder.data()), - const_cast(seq_lens_decoder.data()), - const_cast(cu_seq_k.data()), - k_gate_weight.data(), - moba_block_size, - max_seq_len, - kv_head_num, - batch_size, - k_input.stream() - ); - } - return {k_gate_weight}; -} - -PD_BUILD_OP(moba_mlp_einsum) - .Inputs({ - "k_input", - "attn_gate_weight", - "seq_lens_encoder", - "seq_lens_decoder", - "cu_seq_k"}) - .Attrs({ - "max_seq_len: int", - "kv_head_num: int"}) - .Outputs({"k_gate"}) - .SetKernelFn(PD_KERNEL(MobaMlpEinsum)); diff --git a/custom_ops/gpu_ops/moba_attn/moba_process/moba_qk_gemm.cu b/custom_ops/gpu_ops/moba_attn/moba_process/moba_qk_gemm.cu deleted file mode 100644 index 96409592a..000000000 --- a/custom_ops/gpu_ops/moba_attn/moba_process/moba_qk_gemm.cu +++ /dev/null @@ -1,465 +0,0 @@ -// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "paddle/extension.h" -#include "moba_attn/moba_attn_utils.hpp" -#include "moba_attn/moba_attn.h" -#include "cute/atom/mma_atom.hpp" -#include "cutlass/gemm/collective/collective_builder.hpp" - -#include "cutlass/cutlass.h" -#include "cutlass/layout/layout.h" -#include "cutlass/numeric_types.h" -#include "cutlass/pipeline/pipeline.hpp" -#include "cutlass/cluster_launch.hpp" -#include "cutlass/arch/reg_reconfig.h" - -template -__global__ void qk_gemm_kernel( - const input_type *q_input, - const input_type *k_gate_mean, - input_type *qk_gate_weight, - const int *seq_len_encoder, - const int *seq_len_decoder, - const int *cu_seq_q, - const int *cu_seq_k, - const int use_moba_seq_limit, - const int max_seq_q, - const int max_seq_k, - const int head_num, - const int kv_head_num, - const int kGQA_groupsize) { - - using TileShape_MNK = Shape, Int, Int>; - - using SmemLayoutAtomQ = decltype( - cutlass::gemm::collective::detail::ss_smem_selector< - GMMA::Major::K, input_type, - decltype(cute::get<0>(TileShape_MNK{})), - decltype(cute::get<2>(TileShape_MNK{}))>()); - using SmemLayoutQ = decltype(tile_to_shape(SmemLayoutAtomQ{}, select<0, 2>(TileShape_MNK{}))); - - using SmemLayoutAtomK = decltype( - cutlass::gemm::collective::detail::ss_smem_selector< - GMMA::Major::K, input_type, decltype(cute::get<1>(TileShape_MNK{})), - decltype(cute::get<2>(TileShape_MNK{}))>()); - using SmemLayoutK = decltype(tile_to_shape(SmemLayoutAtomK{}, select<1, 2>(TileShape_MNK{}))); - - using SmemLayoutAtomQK = decltype( - cutlass::gemm::collective::detail::ss_smem_selector< - GMMA::Major::K, input_type, - decltype(cute::get<0>(TileShape_MNK{})), - decltype(cute::get<1>(TileShape_MNK{}))>()); - - using SmemLayoutQK = decltype(tile_to_shape(SmemLayoutAtomQK{}, select<0, 1>(TileShape_MNK{}))); - - - using MMA_Atom_Arch = std::conditional_t< - std::is_same_v, - MMA_Atom, - MMA_Atom - >; - - using ValLayoutMNK = std::conditional_t< - is_split_kv, - Layout>, - Layout> - >; - - using PermutationMNK = std::conditional_t< - is_split_kv, - Tile<_16,_64,_16>, - Tile<_64,_16,_16> - >; - - using TiledMma = TiledMMA< - MMA_Atom_Arch, - ValLayoutMNK, - PermutationMNK>; - - using SmemCopyAtom = Copy_Atom; - using SmemCopyAtomQK = Copy_Atom; - - constexpr int kNThreads = 128; - constexpr int kThreadPerValue = 16 / sizeof(input_type); - constexpr int kThreadsPerRow = kHeadDim / kThreadPerValue; - constexpr int kThreadsPerRowQK = kBlockN / kThreadPerValue; - - using GmemLayoutAtom = Layout< - Shape , Int>, - Stride, _1>>; - - using GmemTiledCopy = decltype( - make_tiled_copy(Copy_Atom< - SM80_CP_ASYNC_CACHEGLOBAL, input_type>{}, - GmemLayoutAtom{}, - Layout>>{})); - - using GmemLayoutAtomQK = Layout< - Shape , Int>, - Stride, _1>>; - - using GmemTiledCopyQK = decltype( - make_tiled_copy(Copy_Atom< - UniversalCopy, input_type>{}, - GmemLayoutAtomQK{}, - Layout>>{})); - - int mn_block = blockIdx.x; - const int bidb = blockIdx.y; - const int bidh = blockIdx.z; - const int bidh_k = bidh / kGQA_groupsize; - const int tidx = threadIdx.x; - - extern __shared__ char smem_[]; - - const int seq_len_q = seq_len_encoder[bidb]; - const int seq_len_k = seq_len_decoder[bidb]; - const int seq_len_qk = seq_len_q + seq_len_k; - - int q_head_stride; - const int k_head_stride = kv_head_num * kHeadDim; - int qk_head_stride; - int offset_q; - int offset_k; - int offset_qk; - int remain_q_seq; - - if constexpr (is_split_kv) { - if (seq_len_k < use_moba_seq_limit || seq_len_k == 0) { - return; - } - mn_block *= kBlockN; - q_head_stride = kHeadDim; - 
qk_head_stride = kMaxN; - if (mn_block >= (seq_len_k + kMobaBlockSize - 1) / kMobaBlockSize) { - return; - } - offset_q = cu_seq_q[bidb] * head_num * kHeadDim + bidh * kGQA_groupsize * kHeadDim; - offset_k = (bidb * kMaxN + mn_block) * k_head_stride + bidh * kHeadDim; - offset_qk = bidb * head_num * kMaxN + bidh * kGQA_groupsize * kMaxN + mn_block; - remain_q_seq = kGQA_groupsize; - } else { - if (seq_len_q == 0 || seq_len_qk < use_moba_seq_limit) { - return; - } - q_head_stride = head_num * kHeadDim; - qk_head_stride = head_num * kMaxN; - mn_block *= kBlockM; - if (mn_block >= seq_len_q) { - return; - } - offset_q = (cu_seq_q[bidb] + mn_block) * q_head_stride + bidh * kHeadDim; - offset_k = bidb * kMaxN * k_head_stride + bidh_k * kHeadDim; - offset_qk = (cu_seq_q[bidb] + mn_block) * qk_head_stride + bidh * kMaxN; - remain_q_seq = seq_len_q - mn_block; - } - - Tensor gQ = make_tensor(make_gmem_ptr(q_input + offset_q), - Shape, Int>{}, - make_stride(q_head_stride, _1{})); - Tensor gK = make_tensor(make_gmem_ptr(k_gate_mean + offset_k), - Shape, Int>{}, - make_stride(k_head_stride, _1{})); - Tensor gQK = make_tensor(make_gmem_ptr(qk_gate_weight + offset_qk), - Shape, Int>{}, - make_stride(qk_head_stride, _1{})); - - Tensor sK = make_tensor(make_smem_ptr(reinterpret_cast(smem_)), SmemLayoutK{}); - Tensor sQ = make_tensor(sK.data() + size(sK), SmemLayoutQ{}); - Tensor sQK = make_tensor(sK.data() + size(sK), SmemLayoutQK{}); - - auto gmem_tiled_copy = GmemTiledCopy{}; - auto gmem_tiled_copy_qk = GmemTiledCopyQK{}; - auto gmem_thr_copy = gmem_tiled_copy.get_thread_slice(tidx); - auto gmem_thr_copy_qk = gmem_tiled_copy_qk.get_thread_slice(tidx); - - - Tensor tQgQ = gmem_thr_copy.partition_S(gQ); - Tensor tQsQ = gmem_thr_copy.partition_D(sQ); - - Tensor tKgK = gmem_thr_copy.partition_S(gK); - Tensor tKsK = gmem_thr_copy.partition_D(sK); - - Tensor tQKgQK = gmem_thr_copy_qk.partition_S(gQK); - Tensor tQKsQK = gmem_thr_copy_qk.partition_D(sQK); - - - Tensor cQ = make_identity_tensor(make_shape(kBlockM, kHeadDim)); - Tensor tQcQ = gmem_thr_copy.partition_S(cQ); - - Tensor cK = make_identity_tensor(make_shape(kBlockN, kHeadDim)); - Tensor tKcK = gmem_thr_copy.partition_S(cK); - - Tensor cQK = make_identity_tensor(make_shape(kBlockM, kBlockN)); - Tensor tQKcQK = gmem_thr_copy.partition_S(cQK); - - if (remain_q_seq >= kBlockM) { - copy(gmem_tiled_copy, tQgQ, tQsQ, tQcQ); - } else { - copy(gmem_tiled_copy, tQgQ, tQsQ, tQcQ, remain_q_seq); - } - copy(gmem_tiled_copy, tKgK, tKsK, tKcK); - - cute::cp_async_fence(); - - TiledMma tiled_mma; - auto thr_mma = tiled_mma.get_thread_slice(tidx); - Tensor tSrQ = thr_mma.partition_fragment_A(sQ); - Tensor tSrK = thr_mma.partition_fragment_B(sK); - - Tensor acc_s = partition_fragment_C(tiled_mma, Shape, Int>{}); - - auto smem_tiled_copy_Q = make_tiled_copy_A(SmemCopyAtom{}, tiled_mma); - auto smem_thr_copy_Q = smem_tiled_copy_Q.get_thread_slice(tidx); - Tensor tSsQ = smem_thr_copy_Q.partition_S(sQ); - - auto smem_tiled_copy_K = make_tiled_copy_B(SmemCopyAtom{}, tiled_mma); - auto smem_thr_copy_K = make_tiled_copy_B(SmemCopyAtom{}, tiled_mma).get_thread_slice(tidx); - Tensor tSsK = smem_thr_copy_K.partition_S(sK); - - auto smem_tiled_copy_QK = make_tiled_copy_C(SmemCopyAtomQK{}, tiled_mma); - auto smem_thr_copy_QK = smem_tiled_copy_QK.get_thread_slice(tidx); - Tensor tsQK = smem_thr_copy_QK.partition_D(sQK); - - const int n_blocks = is_split_kv ? 
1 : cute::ceil_div(cute::ceil_div(seq_len_qk, kMobaBlockSize), kBlockN); - - #pragma unroll - for (int n_block = 0; n_block < n_blocks; ++n_block) { - clear(acc_s); - cp_async_wait<0>(); - __syncthreads(); - if (n_block == 0) { - gemm(acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_thr_copy_Q, smem_thr_copy_K, smem_tiled_copy_Q, smem_tiled_copy_K); - } else { - gemm(acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_thr_copy_Q, smem_thr_copy_K, smem_tiled_copy_Q, smem_tiled_copy_K); - } - if constexpr (!is_split_kv) { - if (n_block < n_blocks - 1) { - __syncthreads(); - tKgK.data() = tKgK.data() + kBlockN * k_head_stride; - copy(gmem_tiled_copy, tKgK, tKsK, tKcK); - cute::cp_async_fence(); - } - } - - Tensor rS = convert_type(acc_s); - Tensor trQK = smem_thr_copy_QK.retile_S(rS); - cute::copy(smem_tiled_copy_QK, trQK, tsQK); - - __syncthreads(); - if (remain_q_seq >= kBlockM) { - copy(gmem_tiled_copy_qk, tQKsQK, tQKgQK, tQKcQK); - } else { - copy(gmem_tiled_copy_qk, tQKsQK, tQKgQK, tQKcQK, remain_q_seq); - } - if constexpr (!is_split_kv) { - __syncthreads(); - tQKgQK.data() = tQKgQK.data() + kBlockN; - } - } -} - -template -void qk_gemm( - const input_type *q_input, - const input_type *k_gate_mean, - input_type *qk_gate_weight, - const int *seq_len_encoder, - const int *seq_len_decoder, - const int *cu_seq_q, - const int *cu_seq_k, - const int use_moba_seq_limit, - const int max_seq_q, - const int max_seq_k, - const int head_num, - const int kv_head_num, - const int bsz, - cudaStream_t stream) { - - const int gqa_group_size = head_num / kv_head_num; - - dim3 grid_dims; - const int num_m_block = (max_seq_q + kBlockM - 1) / kBlockM; - const int num_n_block = ((max_seq_k + kMobaBlockSize - 1) / kMobaBlockSize + kBlockN - 1) / kBlockN; - - if (is_split_kv) { - grid_dims.x = num_n_block; - grid_dims.z = kv_head_num; - } else { - grid_dims.x = num_m_block; - grid_dims.z = head_num; - } - grid_dims.y = bsz; - - constexpr int kHeadDim = 128; - constexpr int smemq = kBlockM * kHeadDim * sizeof(input_type); - constexpr int smemk = kBlockN * kHeadDim * sizeof(input_type); - constexpr int smemqk = kBlockM * kBlockN * sizeof(input_type); - const int smem_size = smemk + max(smemq, smemqk); - - auto kernel = &qk_gemm_kernel; - - if (smem_size >= 48 * 1024) { - cudaFuncSetAttribute( - kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); - } - - kernel<<>>( - q_input, - k_gate_mean, - qk_gate_weight, - seq_len_encoder, - seq_len_decoder, - cu_seq_q, - cu_seq_k, - use_moba_seq_limit, - max_seq_q, - max_seq_k, - head_num, - kv_head_num, - gqa_group_size); -} - - -template -std::vector DispatchMobaQKGemm( - const paddle::Tensor& q_input, - const paddle::Tensor& k_block_means, - const paddle::Tensor& seq_len_encoder, - const paddle::Tensor& seq_len_decoder, - const paddle::Tensor& cu_seq_q, - const paddle::Tensor& cu_seq_k, - const int max_seq_q, - const int max_seq_k, - const int head_num, - const int kv_head_num, - const bool is_split_kv, - const int use_moba_seq_limit) { - - constexpr int kMobaBlockSize = 128; - constexpr int kMaxN = 1024; - const int batch_size = seq_len_encoder.dims()[0]; - using cute_type = typename cuteType::type; - if (is_split_kv) { - paddle::Tensor qk_gate_weight = paddle::empty({batch_size, head_num, kMaxN}, q_input.dtype(), q_input.place()); - qk_gemm( - reinterpret_cast(q_input.data()), - reinterpret_cast(k_block_means.data()), - reinterpret_cast(qk_gate_weight.data()), - seq_len_encoder.data(), - seq_len_decoder.data(), - cu_seq_q.data(), - cu_seq_k.data(), - 
use_moba_seq_limit, - max_seq_q, - max_seq_k, - head_num, - kv_head_num, - batch_size, - q_input.stream() - ); - return {qk_gate_weight}; - } else { - constexpr int kBlockM = 128; - constexpr int kBlockN = 128; - const int token_num = q_input.dims()[0]; - paddle::Tensor qk_gate_weight = paddle::empty({token_num, head_num, kMaxN}, q_input.dtype(), q_input.place()); - qk_gemm( - reinterpret_cast(const_cast(q_input.data())), - reinterpret_cast(const_cast(k_block_means.data())), - reinterpret_cast(qk_gate_weight.data()), - seq_len_encoder.data(), - seq_len_decoder.data(), - cu_seq_q.data(), - cu_seq_k.data(), - use_moba_seq_limit, - max_seq_q, - max_seq_k, - head_num, - kv_head_num, - batch_size, - q_input.stream()); - return {qk_gate_weight}; - } -} - -std::vector MobaQKGemm( - const paddle::Tensor& q_input, - const paddle::Tensor& k_block_means, - const paddle::Tensor& seq_len_encoder, - const paddle::Tensor& seq_len_decoder, - const paddle::Tensor& cu_seq_q, - const paddle::Tensor& cu_seq_k, - const int max_seq_q, - const int max_seq_k, - const int head_num, - const int kv_head_num, - const bool is_split_kv, - const int use_moba_seq_limit) { - - if (q_input.dtype() == paddle::DataType::FLOAT16) { - return std::move( - DispatchMobaQKGemm( - q_input, - k_block_means, - seq_len_encoder, - seq_len_decoder, - cu_seq_q, - cu_seq_k, - max_seq_q, - max_seq_k, - head_num, - kv_head_num, - is_split_kv, - use_moba_seq_limit - ) - ); - } else if (q_input.dtype() == paddle::DataType::BFLOAT16) { - return std::move( - DispatchMobaQKGemm( - q_input, - k_block_means, - seq_len_encoder, - seq_len_decoder, - cu_seq_q, - cu_seq_k, - max_seq_q, - max_seq_k, - head_num, - kv_head_num, - is_split_kv, - use_moba_seq_limit - ) - ); - } -} - -PD_BUILD_OP(moba_qk_gemm) - .Inputs({ - "q_input", - "k_block_means", - "seq_len_encoder", - "seq_len_decoder", - "cu_seq_q", - "cu_seq_k"}) - .Attrs({ - "max_seq_q: int", - "max_seq_k: int", - "head_num: int", - "kv_head_num: int", - "is_split_kv: bool", - "use_moba_seq_limit: int"}) - .Outputs({"qk_gate_weight"}) - .SetKernelFn(PD_KERNEL(MobaQKGemm)); diff --git a/custom_ops/gpu_ops/moba_attn/moba_process/split_qkv_and_rope.cu b/custom_ops/gpu_ops/moba_attn/moba_process/split_qkv_and_rope.cu deleted file mode 100644 index 3957ba8e8..000000000 --- a/custom_ops/gpu_ops/moba_attn/moba_process/split_qkv_and_rope.cu +++ /dev/null @@ -1,370 +0,0 @@ -// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "paddle/extension.h" -#include "moba_attn/moba_attn_utils.hpp" -#include "moba_attn/moba_attn.h" - -template -__global__ void fused_block_mean_and_rope_kernel( - const input_type *qkv_input, - const input_type *qkv_bias, - input_type *k_gate_mean, - input_type *q_input, - input_type *k_input, - input_type *v_input, - const float *rope_sin_cos, - const int *seq_len_encoder, - const int *seq_len_decoder, - const int *cu_seq_q, - const int *cu_seq_k, - const int max_seq_q, - const int max_seq_k, - const int head_num, - const int kv_head_num, - const int max_input_length) { - - constexpr int kPackSize = 16 / sizeof(input_type); - constexpr int kHeadDim = 128; - - using src_type = Vec; - - using rope_type = Vec; - using pack_half = std::conditional_t::value, __half2, nv_bfloat162>; - - __align__(16) __shared__ input_type local_sum_mem[128 / 32 * kHeadDim]; - - const int bidb = blockIdx.x; - const int bidh = blockIdx.y; - const int bidt_q = blockIdx.z * tokens_per_block; - const int bidt_v = blockIdx.z * tokens_per_block; - const int bidt_k = need_k_mean ? blockIdx.z * moba_block_size : blockIdx.z * tokens_per_block; - const int tidx = threadIdx.x; - const int lane_id = tidx % 32; - const int warp_id = tidx / 32; - const int seq_len = seq_len_encoder[bidb]; - const int seq_len_start = seq_len_decoder[bidb]; - - if (seq_len == 0) { - return; - } - - const int all_head_num = head_num + 2 * kv_head_num; - const int hidden = all_head_num * kHeadDim; - - const int row_idx = tidx / (kHeadDim / kPackSize); - const int col_idx = tidx % (kHeadDim / kPackSize); - - const int bias_idx = bidh * kHeadDim + col_idx * kPackSize; - - src_type src, src_bias; - rope_type sin, cos; - - const bool need_add_bias = qkv_bias != nullptr; - - if (need_add_bias) { - src_bias.load_from(qkv_bias + bias_idx); - } - - if (bidh < head_num) { - const int cur_token = bidt_q + row_idx; - const float * cos_rope = rope_sin_cos + (cur_token + seq_len_start) * (kHeadDim / 2) + col_idx * (kPackSize / 2); - const float * sin_rope = cos_rope + max_input_length * (kHeadDim / 2); - - if (cur_token < seq_len) { - src.load_from(qkv_input + cu_seq_q[bidb] * hidden + bias_idx + cur_token * hidden); - - if (need_add_bias) { - src.add(src_bias); - } - - sin.load_from(sin_rope); - cos.load_from(cos_rope); - apply_rotary_embedding(src, cos, sin); - - src.store_to(q_input + (cu_seq_q[bidb] + cur_token) * head_num * kHeadDim + bias_idx); - } - } else if (bidh < head_num + kv_head_num) { - if constexpr (!need_k_mean) { - const int cur_token = bidt_k + row_idx; - const float * cos_rope = rope_sin_cos + (cur_token + seq_len_start) * (kHeadDim / 2) + col_idx * (kPackSize / 2); - const float * sin_rope = cos_rope + max_input_length * (kHeadDim / 2); - - if (cur_token < seq_len) { - src.load_from(qkv_input + cu_seq_q[bidb] * hidden + bias_idx + cur_token * hidden); - - if (need_add_bias) { - src.add(src_bias); - } - - sin.load_from(sin_rope); - cos.load_from(cos_rope); - apply_rotary_embedding(src, cos, sin); - - src.store_to(k_input + (cu_seq_k[bidb] + cur_token) * head_num * kHeadDim + bias_idx- head_num * kHeadDim); - } - } else { - if (bidt_k >= seq_len) { - return; - } - - src_type local_sum; - local_sum.set_zero(); - - const input_type* qkv = qkv_input + cu_seq_q[bidb] * hidden + bias_idx; - - for (int i = 0; i < moba_block_size; i += tokens_per_block) { - const int cur_token = bidt_k + i + row_idx; - if (cur_token < seq_len) { - src.load_from(qkv + cur_token * hidden); - - if (need_add_bias) { - src.add(src_bias); - } - const float * 
cos_rope = rope_sin_cos + (cur_token + seq_len_start) * (kHeadDim / 2) + col_idx * (kPackSize / 2); - const float * sin_rope = cos_rope + max_input_length * (kHeadDim / 2); - sin.load_from(sin_rope); - cos.load_from(cos_rope); - - apply_rotary_embedding(src, cos, sin); - - src.store_to(k_input + (cu_seq_k[bidb] + cur_token) * kv_head_num * kHeadDim + bias_idx - head_num * kHeadDim); - - local_sum.add(src); - } - } - - src_type neighbor; - - #pragma unroll - for (int i = 0; i < kPackSize; i+=2) { - *reinterpret_cast(neighbor.data.elt + i) = __shfl_down_sync(0xffffffff, *reinterpret_cast(local_sum.data.elt + i), 16); - } - - local_sum.add(neighbor); - - if (lane_id < 16) { - local_sum.store_to(local_sum_mem + warp_id * kHeadDim + lane_id * kPackSize); - } - - __syncthreads(); - - pack_half * local_sum_mem_half = reinterpret_cast(local_sum_mem); - - pack_half local_sum_half = local_sum_mem_half[tidx]; - - - if (tidx < kHeadDim / 2) { - - #pragma unroll - for (int i = 1; i < 4; i++) { - local_sum_half += local_sum_mem_half[tidx + i * (kHeadDim / 2)]; - } - - float inv_tokens_sum = fdividef(1.0f, min(seq_len - bidt_k, moba_block_size)); - - local_sum_half *= float_2_half2(inv_tokens_sum); - - const int store_mean_idx = ((bidb * kMaxN + blockIdx.z + seq_len_start / moba_block_size) * kv_head_num * kHeadDim + (bidh - head_num) * kHeadDim) / 2 + tidx; - - reinterpret_cast(k_gate_mean)[store_mean_idx] = local_sum_half; - } - } - } else { - const int cur_token = bidt_v + row_idx; - - if (cur_token < seq_len) { - src.load_from(qkv_input + cu_seq_q[bidb] * hidden + bias_idx + cur_token * hidden); - if (need_add_bias) { - src.add(src_bias); - } - - src.store_to(v_input + (cu_seq_k[bidb] + cur_token) * kv_head_num * kHeadDim + bias_idx - (head_num + kv_head_num) * kHeadDim); - } - } -} - -template -void fused_block_mean_and_rope( - const input_type *qkv_input, - const input_type *qkv_bias, - input_type *k_gate_mean, - input_type *q_input, - input_type *k_input, - input_type *v_input, - const float *rope_sin_cos, - const int *seq_len_encoder, - const int *seq_len_decoder, - const int *cu_seq_q, - const int *cu_seq_k, - const int max_seq_q, - const int max_seq_k, - const int head_num, - const int kv_head_num, - const int bsz, - const int max_input_length, - cudaStream_t stream) { - - static_assert(moba_block_size >= 64, "moba_block_size must be at least 64"); - constexpr int kPackSize = 16 / sizeof(input_type); - constexpr int kHeadDim = 128; - constexpr int kThreads = 128; - constexpr int tokens_per_block = kThreads / (kHeadDim / kPackSize); - dim3 grid_dims; - grid_dims.x = bsz; - grid_dims.y = head_num + 2 * kv_head_num; - grid_dims.z = (max_seq_q + tokens_per_block - 1) / tokens_per_block; - - if (k_gate_mean != nullptr) { - fused_block_mean_and_rope_kernel - <<>>( - qkv_input, - qkv_bias, - k_gate_mean, - q_input, - k_input, - v_input, - rope_sin_cos, - seq_len_encoder, - seq_len_decoder, - cu_seq_q, - cu_seq_k, - max_seq_q, - max_seq_k, - head_num, - kv_head_num, - max_input_length); - } else { - fused_block_mean_and_rope_kernel - <<>>( - qkv_input, - qkv_bias, - k_gate_mean, - q_input, - k_input, - v_input, - rope_sin_cos, - seq_len_encoder, - seq_len_decoder, - cu_seq_q, - cu_seq_k, - max_seq_q, - max_seq_k, - head_num, - kv_head_num, - max_input_length); - } -} - -void FusedBlockMeanAndRope( - const paddle::Tensor& qkv_out, - const paddle::Tensor& k_block_means, - const paddle::Tensor& q_input, - const paddle::Tensor& k_input, - const paddle::Tensor& v_input, - const paddle::Tensor& rotary_embs, 
- const paddle::Tensor& seq_len_encoder, - const paddle::Tensor& seq_len_decoder, - const paddle::Tensor& cu_seq_q, - const paddle::Tensor& cu_seq_k, - const paddle::optional& qkv_bias, - const int head_num, - const int kv_head_num, - const int head_dim, - const int max_input_length, - const int max_seq_q, - const int max_seq_k, - const std::string &cache_quant_type_str) { - - constexpr int kBlockM = 128; - constexpr int kBlockN = 128; - constexpr int kMobaBlockSize = 128; - constexpr int kMaxN = 1024; - - if (k_input.dtype() == paddle::DataType::FLOAT16) { - using T = phi::dtype::float16; - using cute_type = typename cuteType::type; - fused_block_mean_and_rope( - reinterpret_cast(const_cast(qkv_out.data())), - qkv_bias ? reinterpret_cast(const_cast(qkv_bias.get().data())) : nullptr, - reinterpret_cast(const_cast(k_block_means.data())), - reinterpret_cast(const_cast(q_input.data())), - reinterpret_cast(const_cast(k_input.data())), - reinterpret_cast(const_cast(v_input.data())), - rotary_embs.data(), - seq_len_encoder.data(), - seq_len_decoder.data(), - cu_seq_q.data(), - cu_seq_k.data(), - max_seq_q, - max_seq_k, - head_num, - kv_head_num, - seq_len_encoder.dims()[0], - max_input_length, - qkv_out.stream()); - } else if (k_input.dtype() == paddle::DataType::BFLOAT16) { - using T = phi::dtype::bfloat16; - using cute_type = typename cuteType::type; - fused_block_mean_and_rope( - reinterpret_cast(const_cast(qkv_out.data())), - qkv_bias ? reinterpret_cast(const_cast(qkv_bias.get().data())) : nullptr, - reinterpret_cast(const_cast(k_block_means.data())), - reinterpret_cast(const_cast(q_input.data())), - reinterpret_cast(const_cast(k_input.data())), - reinterpret_cast(const_cast(v_input.data())), - rotary_embs.data(), - seq_len_encoder.data(), - seq_len_decoder.data(), - cu_seq_q.data(), - cu_seq_k.data(), - max_seq_q, - max_seq_k, - head_num, - kv_head_num, - seq_len_encoder.dims()[0], - max_input_length, - qkv_out.stream()); - } -} - - - -PD_BUILD_OP(fused_block_mean_and_rope) - .Inputs({ - "qkv_out", - "k_block_means", - "q_input", - "k_input", - "v_input", - "rotary_embs", - "seq_len_encoder", - "seq_len_decoder", - "cu_seq_q", - "cu_seq_k", - paddle::Optional("qkv_bias")}) - .Attrs({ - "head_num: int", - "kv_head_num: int", - "head_dim: int", - "max_input_length: int", - "max_seq_q: int", - "max_seq_k: int", - "cache_quant_type_str: std::string"}) - .Outputs({"q_input_out", "k_input_out", "v_input_out", "k_block_means_out"}) - .SetInplaceMap({{"q_input", "q_input_out"}, - {"k_input", "k_input_out"}, - {"v_input", "v_input_out"}, - {"k_block_means", "k_block_means_out"}}) - .SetKernelFn(PD_KERNEL(FusedBlockMeanAndRope)); diff --git a/custom_ops/setup_ops.py b/custom_ops/setup_ops.py index fcea2719e..5eeda4e6d 100644 --- a/custom_ops/setup_ops.py +++ b/custom_ops/setup_ops.py @@ -385,7 +385,6 @@ elif paddle.is_compiled_with_cuda(): "-Igpu_ops", "-Ithird_party/nlohmann_json/include", ] - nvcc_version = get_nvcc_version() print(f"nvcc_version = {nvcc_version}") if nvcc_version >= 12.0: @@ -509,10 +508,6 @@ elif paddle.is_compiled_with_cuda(): # Hopper optmized mla sources += find_end_files("gpu_ops/mla_attn", ".cu") sources += ["gpu_ops/flash_mask_attn/flash_mask_attn.cu"] - sources += find_end_files("gpu_ops/moba_attn/moba_decoder_attn/", ".cu") - sources += find_end_files("gpu_ops/moba_attn/moba_encoder_attn/", ".cu") - sources += find_end_files("gpu_ops/moba_attn/moba_process/", ".cu") - sources += ["gpu_ops/moba_attn/moba_attn.cu"] os.system("python utils/auto_gen_w4afp8_gemm_kernel.py") 
sources += find_end_files("gpu_ops/w4afp8_gemm", ".cu") os.system("python utils/auto_gen_wfp8afp8_sparse_gemm_kernel.py") diff --git a/docs/features/moba_sparse_attention.md b/docs/features/moba_sparse_attention.md deleted file mode 100644 index 8004bf4ac..000000000 --- a/docs/features/moba_sparse_attention.md +++ /dev/null @@ -1,31 +0,0 @@ -# moba_sparse_attention - -## Introduction - -We propose Lite MoBA and improve it based on MoBA. Specifically, we still draw on the MoE structure to divide KV into multiple blocks, introduce a learnable MLP layer to adaptively select important blocks. We use Full Attention's 1D Max Pooling Attention Map as Ground Truth. Then, we employ KLDivLoss to distill and train the MLP layer weights. Lite MoBA can be directly applied to post - training, where only the weights of the MLP are learnable and the weights of the original model remain unchanged. - -Compared to NSA or MoBA, our Lite MoBA is more scalable and pluggable, without the need to change traditional attention architectures or interfere with model weight training in the Pre - training and Post - training stages. It only requires a small amount of training on the MLP layer in the final stage of the model to achieve almost lossless accuracy. Since MoBA updates the weights of the entire model, even when Full Attention is automatically invoked for inputs shorter than BlockSize x BlockNum, it still cannot avoid the impact of model updates on the model's effectiveness in text processing. In contrast, our pluggable Lite MoBA can achieve Full Attention that is truly equivalent to that of the original model in short text scenarios. - -Compared with MoBA, in terms of effectiveness, its use of Average Pooling to represent inter - block relationships appears relatively limited and has poor handling of outlier representations. Our ablation experiments also demonstrated that the effectiveness of Average Pooling is inferior to that of the learnable MLP. In terms of training performance, since only the MLP weights need to be updated and the model weights do not need to be updated, a large amount of video memory will be saved during training (which needs to be tested). In terms of inference performance, when the input length is 128K, Block Size = 1024, and Block Num = 16, the performance is improved by 322% compared to Flash Attention 3. - -## Usage - -```bash -export FD_ATTENTION_BACKEND="MOBA_ATTN" - -python -m fastdeploy.entrypoints.openai.api_server - --model baidu/ERNIE-4.5-300B-A47B-Paddle \ - --port 8188 \ - --tensor-parallel-size 4 \ - --quantization wint4 \ - --enable-chunked-prefill \ - --max-num-batched-tokens 8192 \ - --max-model-len 131072 \ - --max-num-seqs 32 \ - --moba-attention-config '{"moba_encoder_top_k_left": 60, "moba_encoder_top_k_right": 80, "moba_decoder_top_k_left": 100, "moba_decoder_top_k_right": 120}' -``` -## Environmental Variables Description - -* Setting `FD_ATTENTION_BACKEND="MOBA_ATTN"` enables MOBA sparse attention. -* `moba_encoder_top_k_left=60, moba_encoder_top_k_right=80` indicates that the range of top - k is between 80 and 100 when the encoder is sparse. -* `moba_decoder_top_k_left=100, moba_decoder_top_k_right=100` indicates that the range of top - k is between 120 and 140 when the decoder is sparse. 
diff --git a/fastdeploy/config.py b/fastdeploy/config.py index 6ac148c9c..fe6e03ebd 100644 --- a/fastdeploy/config.py +++ b/fastdeploy/config.py @@ -682,67 +682,6 @@ class GraphOptimizationConfig: argument = self.use_cudagraph -class MobaAttentionConfig: - def __init__( - self, - args, - ): - self.moba_encoder_top_k_left: int = None - self.moba_encoder_top_k_right: int = None - "The sparse top-k of encoder attention lies in [moba_encoder_top_k_left, moba_encoder_top_k_right]" - self.moba_decoder_top_k_left: int = None - self.moba_decoder_top_k_right: int = None - "The sparse top-k of decoder attention lies in [moba_decoder_top_k_left, moba_decoder_top_k_right]" - self.moba_use_encoder_seq_limit: int = None - "When the number of encoder tokens is less than moba_use_encoder_seq_limit, attention is not sparse" - self.moba_use_decoder_seq_limit: int = None - "When the number of decoder tokens is less than moba_use_decoder_seq_limit, attention is not sparse" - self.moba_block_size: int = 128 - self.mlp_weight_name: str = "moba_mlp_weight.safetensors" - self.moba_max_seq_length: int = 128 * 1024 - if args is not None: - for key, value in args.items(): - if hasattr(self, key): - setattr(self, key, value) - if self.moba_use_encoder_seq_limit is None and self.moba_encoder_top_k_left is not None: - self.moba_use_encoder_seq_limit = self.moba_encoder_top_k_left * self.moba_block_size - if self.moba_use_decoder_seq_limit is None and self.moba_decoder_top_k_left is not None: - self.moba_use_decoder_seq_limit = self.moba_decoder_top_k_left * self.moba_block_size - self.check_legality_parameters() - - def check_legality_parameters( - self, - ) -> None: - if self.moba_encoder_top_k_left is not None: - assert self.moba_encoder_top_k_left > 0, "moba_encoder_top_k_left must be larger than 0" - - if self.moba_encoder_top_k_right is not None: - assert self.moba_encoder_top_k_right > 0, "moba_encoder_top_k_right must be larger than 0" - assert ( - self.moba_encoder_top_k_right >= self.moba_encoder_top_k_left - ), "moba_encoder_top_k_right must be no less than moba_encoder_top_k_left" - - if self.moba_decoder_top_k_left is not None: - assert self.moba_decoder_top_k_left > 0, "moba_decoder_top_k_left must be larger than 0" - - if self.moba_decoder_top_k_right is not None: - assert self.moba_decoder_top_k_right > 0, "moba_decoder_top_k_right must be larger than 0" - assert ( - self.moba_decoder_top_k_right >= self.moba_decoder_top_k_left - ), "moba_decoder_top_k_right must be no less than moba_decoder_top_k_left" - - if self.moba_use_encoder_seq_limit is not None and self.moba_encoder_top_k_left is not None: - assert self.moba_use_encoder_seq_limit >= self.moba_encoder_top_k_left * self.moba_block_size - if self.moba_use_decoder_seq_limit is not None and self.moba_decoder_top_k_left is not None: - assert self.moba_use_decoder_seq_limit >= self.moba_decoder_top_k_left * self.moba_block_size - - def to_json_string(self): - """ - Convert moba_attention_config to a JSON string. 
- """ - return json.dumps({key: value for key, value in self.__dict__.items() if value is not None}) - - class EarlyStopConfig: def __init__( self, @@ -1097,7 +1036,6 @@ class FDConfig: decoding_config: DecodingConfig = None, quant_config: QuantConfigBase = None, graph_opt_config: GraphOptimizationConfig = None, - moba_attention_config: MobaAttentionConfig = None, speculative_config: SpeculativeConfig = None, tokenizer: str = None, max_model_len: int = 8192, @@ -1132,7 +1070,7 @@ class FDConfig: self.early_stop_config: Optional[EarlyStopConfig] = early_stop_config self.decoding_config: DecodingConfig = decoding_config # type: ignore self.cache_config: CacheConfig = cache_config # type: ignore - self.moba_attention_config: Optional[MobaAttentionConfig] = moba_attention_config + # Initialize cuda graph capture list if self.graph_opt_config.cudagraph_capture_sizes is None: self.graph_opt_config._set_cudagraph_sizes(max_num_seqs=self.parallel_config.max_num_seqs) diff --git a/fastdeploy/engine/args_utils.py b/fastdeploy/engine/args_utils.py index 3cd8bda97..7797708f0 100644 --- a/fastdeploy/engine/args_utils.py +++ b/fastdeploy/engine/args_utils.py @@ -26,7 +26,6 @@ from fastdeploy.config import ( FDConfig, GraphOptimizationConfig, LoadConfig, - MobaAttentionConfig, ModelConfig, ParallelConfig, SpeculativeConfig, @@ -337,10 +336,6 @@ class EngineArgs: """ Configuration for graph optimization backend execution. """ - moba_attention_config: Optional[Dict[str, Any]] = None - """ - Configuration for moba attention. - """ enable_logprob: bool = False """ @@ -539,12 +534,6 @@ class EngineArgs: default=EngineArgs.graph_optimization_config, help="", ) - model_group.add_argument( - "--moba-attention-config", - type=json.loads, - default=EngineArgs.moba_attention_config, - help="", - ) model_group.add_argument( "--guided-decoding-backend", type=str, @@ -940,18 +929,6 @@ class EngineArgs: graph_optimization_args[k] = v return GraphOptimizationConfig(graph_optimization_args) - def create_moba_attention_config(self) -> MobaAttentionConfig: - """ - Create and retuan a MobaAttentionConfig object based on the current settings. - """ - attention_args = asdict(self) - if self.moba_attention_config is not None: - for k, v in self.moba_attention_config.items(): - attention_args[k] = v - return MobaAttentionConfig(attention_args) - else: - return MobaAttentionConfig(None) - def create_early_stop_config(self) -> EarlyStopConfig: """ Create and retuan an EarlyStopConfig object based on the current settings. 
@@ -989,7 +966,6 @@ class EngineArgs: speculative_cfg = self.create_speculative_config() graph_opt_cfg = self.create_graph_optimization_config() graph_opt_cfg.update_use_cudagraph(self.use_cudagraph) - moba_attention_config = self.create_moba_attention_config() early_stop_cfg = self.create_early_stop_config() early_stop_cfg.update_enable_early_stop(self.enable_early_stop) @@ -1027,7 +1003,6 @@ class EngineArgs: max_long_partial_prefills=self.max_long_partial_prefills, long_prefill_token_threshold=self.long_prefill_token_threshold, graph_opt_config=graph_opt_cfg, - moba_attention_config=moba_attention_config, guided_decoding_backend=self.guided_decoding_backend, disable_any_whitespace=self.guided_decoding_disable_any_whitespace, early_stop_config=early_stop_cfg, diff --git a/fastdeploy/engine/engine.py b/fastdeploy/engine/engine.py index 352cbc76c..a86ef2432 100644 --- a/fastdeploy/engine/engine.py +++ b/fastdeploy/engine/engine.py @@ -464,7 +464,6 @@ class LLMEngine: f" --load_strategy {self.cfg.load_config.load_strategy}" f" --early_stop_config '{self.cfg.early_stop_config.to_json_string()}'" f" --load_choices {self.cfg.load_config.load_choices}" - f" --moba_attention_config '{self.cfg.moba_attention_config.to_json_string()}'" f" --ips {ips}" ) diff --git a/fastdeploy/model_executor/layers/attention/__init__.py b/fastdeploy/model_executor/layers/attention/__init__.py index 7157ac63b..c4c1801d4 100644 --- a/fastdeploy/model_executor/layers/attention/__init__.py +++ b/fastdeploy/model_executor/layers/attention/__init__.py @@ -20,7 +20,6 @@ from .block_multihead_attn_backend import BlockAttentionBackend from .flash_attn_backend import FlashAttentionBackend from .iluvatar_attn_backend import IluvatarAttnBackend from .mla_attention_backend import MLAAttentionBackend -from .moba_attention_backend import MobaAttentionBackend from .native_paddle_backend import PaddleNativeAttnBackend from .xpu_attn_backend import XPUAttentionBackend @@ -35,5 +34,4 @@ __all__ = [ "IluvatarAttnBackend", "BlockAttentionBackend", "Attention", - "MobaAttentionBackend", ] diff --git a/fastdeploy/model_executor/layers/attention/attention.py b/fastdeploy/model_executor/layers/attention/attention.py index 286e2f796..98527571a 100644 --- a/fastdeploy/model_executor/layers/attention/attention.py +++ b/fastdeploy/model_executor/layers/attention/attention.py @@ -28,11 +28,6 @@ from fastdeploy.model_executor.layers.quantization.quant_base import QuantMethod if TYPE_CHECKING: from fastdeploy.model_executor.forward_meta import ForwardMeta - -import os - -from safetensors import safe_open - from fastdeploy.model_executor.layers.utils import get_tensor @@ -118,42 +113,6 @@ class Attention(nn.Layer): self.k_norm_key = f"{self.prefix}.k_norm" self.init_weight() - if fd_config.moba_attention_config is not None: - mlp_weight_path = os.path.join( - fd_config.model_config.model, fd_config.moba_attention_config.mlp_weight_name - ) - self.moba_use_mlp = mlp_weight_path is not None and os.path.exists(mlp_weight_path) - moba_block_size = fd_config.moba_attention_config.moba_block_size - moba_max_seq_length = fd_config.moba_attention_config.moba_max_seq_length - if self.moba_use_mlp: - mlp_weight = {} - with safe_open(mlp_weight_path, framework="np", device="cpu") as f: - for key_name in f.keys(): - weight = f.get_tensor(key_name) - weight = paddle.Tensor(weight, zero_copy=True) - weight = weight._copy_to(paddle.framework._current_expected_place(), False) - mlp_weight[key_name] = weight - - if self.layer_id < 
fd_config.model_config.num_hidden_layers - 1: - self.attn_gate_weight = mlp_weight[ - f"ernie.layers.{self.layer_id}.self_attn.attn_gate.weight" - ].astype(paddle.get_default_dtype())[ - fd_config.parallel_config.tensor_parallel_rank - * self.kv_num_heads : (fd_config.parallel_config.tensor_parallel_rank + 1) - * self.kv_num_heads - ] - assert self.attn_gate_weight.shape[1] % moba_block_size == 0 - - self.cache_k_block_means = paddle.zeros( - [ - fd_config.parallel_config.max_num_seqs, - moba_max_seq_length // moba_block_size, - self.kv_num_heads, - self.head_dim, - ], - dtype=paddle.get_default_dtype(), - ) - def init_weight(self): self.q_norm_weight = self.create_parameter( shape=[self.qk_head_dim], diff --git a/fastdeploy/model_executor/layers/attention/moba_attention_backend.py b/fastdeploy/model_executor/layers/attention/moba_attention_backend.py deleted file mode 100644 index 7ddba90d1..000000000 --- a/fastdeploy/model_executor/layers/attention/moba_attention_backend.py +++ /dev/null @@ -1,198 +0,0 @@ -""" -# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" - -from __future__ import annotations - -from dataclasses import dataclass -from typing import TYPE_CHECKING - -import paddle - -try: - from fastdeploy.model_executor.ops.gpu import get_cur_cu_seq_len_k, moba_attention -except: - moba_attention = None - get_cur_cu_seq_len_k = None - -if TYPE_CHECKING: - from fastdeploy.model_executor.forward_meta import ForwardMeta - -from fastdeploy.config import FDConfig -from fastdeploy.model_executor.layers.attention.attention import Attention -from fastdeploy.model_executor.layers.attention.base_attention_backend import ( - AttentionBackend, - AttentionMetadata, -) - - -@dataclass -class MobaAttentionMetadata(AttentionMetadata): - """ - AppendAttentionMetadata - """ - - q_input: paddle.Tensor = None - k_input: paddle.Tensor = None - v_input: paddle.Tensor = None - cu_seq_q_pack: paddle.Tensor = None - cu_seqlens_k: paddle.Tensor = None - q_pack_tokens: paddle.Tensor = None - max_enc_len_this_time: int = 0 - max_dec_len_this_time: int = 0 - - -class MobaAttentionBackend(AttentionBackend): - """ - The backend class that uses paddle native attention implementation. - Which is used only for testing purpose. 
- """ - - def __init__( - self, - fd_config: FDConfig, - kv_num_heads: int, - num_heads: int, - head_dim: int, - encoder_block_shape_q: int = -1, - decoder_block_shape_q: int = -1, - ) -> None: - """ - MobaAttentionBackend __init__ - """ - super().__init__() - self.attention_metadata: MobaAttentionMetadata = None - assert fd_config.moba_attention_config is not None, "moba_attention_config is None" - self.block_size = fd_config.parallel_config.block_size - self.max_seq_len = fd_config.parallel_config.max_model_len - self.max_num_seqs = fd_config.parallel_config.max_num_seqs - self.kv_num_heads = kv_num_heads - self.num_heads = num_heads - self.head_dim = fd_config.model_config.head_dim - self.num_layers: int = fd_config.model_config.num_hidden_layers - self.attn_block_m = 128 - self.moba_block_size = fd_config.moba_attention_config.moba_block_size - self.moba_encoder_top_k_left = int(fd_config.moba_attention_config.moba_encoder_top_k_left) - self.moba_encoder_top_k_right = int(fd_config.moba_attention_config.moba_encoder_top_k_right) - self.moba_use_encoder_seq_limit = int(fd_config.moba_attention_config.moba_use_encoder_seq_limit) - self.moba_decoder_top_k_left = int(fd_config.moba_attention_config.moba_decoder_top_k_left) - self.moba_decoder_top_k_right = int(fd_config.moba_attention_config.moba_decoder_top_k_right) - self.moba_use_decoder_seq_limit = int(fd_config.moba_attention_config.moba_use_decoder_seq_limit) - self.moba_max_seq_length = fd_config.moba_attention_config.moba_max_seq_length - - def init_attention_metadata(self, forward_meta: ForwardMeta): - """Init the metadata for a forward pass.""" - metadata = MobaAttentionMetadata() - metadata._dtype = paddle.get_default_dtype() - metadata.cu_seq_q_pack, metadata.cu_seqlens_k, metadata.q_pack_tokens = get_cur_cu_seq_len_k( - forward_meta.seq_lens_encoder, - forward_meta.seq_lens_decoder, - forward_meta.seq_lens_this_time, - int(self.attn_block_m), - ) - metadata.max_enc_len_this_time = forward_meta.seq_lens_encoder.max().cpu() - metadata.max_dec_len_this_time = forward_meta.seq_lens_decoder.max().cpu() - q_token_num = int(forward_meta.cu_seqlens_q[-1]) - k_token_num = int(metadata.cu_seqlens_k[-1]) - metadata.q_input = paddle.zeros( - [q_token_num + self.attn_block_m, self.num_heads * self.head_dim], dtype=metadata._dtype - ) - metadata.k_input = paddle.zeros( - [k_token_num + self.attn_block_m, self.kv_num_heads * self.head_dim], dtype=metadata._dtype - ) - metadata.v_input = paddle.zeros( - [k_token_num + self.attn_block_m, self.kv_num_heads * self.head_dim], dtype=metadata._dtype - ) - self.attention_metadata = metadata - assert self.max_seq_len <= self.moba_max_seq_length - - def get_kv_cache_shape( - self, - max_num_blocks: int, - kv_cache_quant_type: str = None, - ): - """ - Caculate kv cache shape - """ - if kv_cache_quant_type is not None and kv_cache_quant_type == "int4_zp": - return ( - max_num_blocks, - self.kv_num_heads, - self.block_size, - self.head_dim // 2, - ) - else: - return ( - max_num_blocks, - self.kv_num_heads, - self.block_size, - self.head_dim, - ) - - def forward_mixed( - self, - q: paddle.Tensor, - k: paddle.Tensor, - v: paddle.Tensor, - qkv: paddle.Tensor, - compressed_kv: paddle.Tensor, - k_pe: paddle.Tensor, - layer: Attention, - forward_meta: ForwardMeta, - ) -> paddle.Tensor: - """ - Mixed模式的前向传播 - """ - attention_metadata = self.attention_metadata - out = moba_attention( - qkv, - attention_metadata.q_input, - attention_metadata.k_input, - attention_metadata.v_input, - forward_meta.cu_seqlens_q, - 
attention_metadata.cu_seqlens_k, - attention_metadata.cu_seq_q_pack, - attention_metadata.q_pack_tokens, - forward_meta.seq_lens_encoder, - forward_meta.seq_lens_decoder, - forward_meta.caches[2 * layer.layer_id], - forward_meta.caches[2 * layer.layer_id + 1], - forward_meta.block_tables, - forward_meta.rotary_embs, - layer.cache_k_block_means, - getattr(layer, "attn_gate_weight", None), - layer.qkv_bias, - getattr(layer, "cache_k_scale", None), - getattr(layer, "cache_v_scale", None), - getattr(layer, "cache_k_out_scale", None), - getattr(layer, "cache_v_out_scale", None), - getattr(layer, "cache_k_zp", None), - getattr(layer, "cache_v_zp", None), - self.num_heads, - self.kv_num_heads, - self.head_dim, - self.max_seq_len, - attention_metadata.max_enc_len_this_time, - attention_metadata.max_dec_len_this_time, - self.moba_encoder_top_k_left, - self.moba_encoder_top_k_right, - self.moba_use_encoder_seq_limit, - self.moba_decoder_top_k_left, - self.moba_decoder_top_k_right, - self.moba_use_decoder_seq_limit, - layer.moba_use_mlp, - getattr(layer, "cache_quant_type_str", "none"), - )[0] - return out diff --git a/fastdeploy/platforms/base.py b/fastdeploy/platforms/base.py index a0e13f9c7..974ab60d7 100644 --- a/fastdeploy/platforms/base.py +++ b/fastdeploy/platforms/base.py @@ -26,7 +26,6 @@ class _Backend(enum.Enum): MLA_ATTN = enum.auto() FLASH_ATTN = enum.auto() BLOCK_ATTN = enum.auto() - MOBA_ATTN = enum.auto() class Platform: diff --git a/fastdeploy/platforms/cuda.py b/fastdeploy/platforms/cuda.py index a9e070755..38504134a 100644 --- a/fastdeploy/platforms/cuda.py +++ b/fastdeploy/platforms/cuda.py @@ -64,9 +64,6 @@ class CUDAPlatform(Platform): elif selected_backend == _Backend.FLASH_ATTN: logger.info("Using FLASH ATTN backend.") return "fastdeploy.model_executor.layers.attention.FlashAttentionBackend" - elif selected_backend == _Backend.MOBA_ATTN: - logger.info("Using MOBA ATTN backend.") - return "fastdeploy.model_executor.layers.attention.MobaAttentionBackend" else: raise ValueError( "Invalid attention backend you specified.\n" diff --git a/fastdeploy/rl/rollout_config.py b/fastdeploy/rl/rollout_config.py index 82074b70c..3db6f5b87 100644 --- a/fastdeploy/rl/rollout_config.py +++ b/fastdeploy/rl/rollout_config.py @@ -59,7 +59,6 @@ class RolloutModelConfig: graph_optimization_config: str = None, early_stop_config: str = None, local_rank: int = 0, - moba_attention_config: str = None, ): # Required parameters self.model = model_name_or_path @@ -104,7 +103,6 @@ class RolloutModelConfig: self.local_rank = local_rank self.early_stop_config = early_stop_config self.ips = None - self.moba_attention_config = moba_attention_config def __str__(self): return "\n".join(f"{k}: {v}" for k, v in self.__dict__.items()) diff --git a/fastdeploy/worker/worker_process.py b/fastdeploy/worker/worker_process.py index 0b0e9710a..d7d0e8e40 100644 --- a/fastdeploy/worker/worker_process.py +++ b/fastdeploy/worker/worker_process.py @@ -34,7 +34,6 @@ from fastdeploy.config import ( FDConfig, GraphOptimizationConfig, LoadConfig, - MobaAttentionConfig, ModelConfig, ParallelConfig, SpeculativeConfig, @@ -542,12 +541,6 @@ def parse_args(): default=None, help="Configation of Graph optimization backend.", ) - parser.add_argument( - "--moba_attention_config", - type=json.loads, - default=None, - help="Configation of moba attention.", - ) parser.add_argument( "--guided_decoding_backend", type=str, @@ -653,8 +646,6 @@ def initialize_fd_config(args, ranks: int = 1, local_rank: int = 0) -> FDConfig: graph_opt_config = 
GraphOptimizationConfig(args.graph_optimization_config) - moba_attention_config = MobaAttentionConfig(args.moba_attention_config) - early_stop_config = EarlyStopConfig(args.early_stop_config) # Note(tangbinhan): used for load_checkpoint @@ -734,7 +725,6 @@ def initialize_fd_config(args, ranks: int = 1, local_rank: int = 0) -> FDConfig: cache_config=cache_config, engine_worker_queue_port=args.engine_worker_queue_port, ips=args.ips, - moba_attention_config=moba_attention_config, ) update_fd_config_for_mm(fd_config) diff --git a/tests/layers/test_moba_attention.py b/tests/layers/test_moba_attention.py deleted file mode 100644 index a4cdc1e47..000000000 --- a/tests/layers/test_moba_attention.py +++ /dev/null @@ -1,340 +0,0 @@ -# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import paddle - -try: - from fastdeploy.model_executor.ops.gpu import ( - fused_block_mean_and_rope, - get_cur_cu_seq_len_k, - moba_encoder_attn, - moba_mlp_einsum, - moba_qk_gemm, - moba_qk_sort_encoder, - ) -except: - moba_attention = None - get_cur_cu_seq_len_k = None -import unittest - -import numpy as np - - -def naive_attn(q_input, k_input, v_input, mask): - gqa_group_size = q_input.shape[2] // k_input.shape[2] - - q_cur = q_input.transpose([0, 2, 1, 3]) - k_cur = k_input.transpose([0, 2, 1, 3]) - v_cur = v_input.transpose([0, 2, 1, 3]) - out = paddle.zeros(q_cur.shape, dtype=q_input.dtype) - - for bsz in range(0, q_cur.shape[0]): - for hi in range(0, q_cur.shape[1]): - qk = paddle.matmul(q_cur[bsz, hi], k_cur[bsz, hi // gqa_group_size].T) * (1.0 / np.sqrt(q_cur.shape[3])) - qk += mask - qk_max = qk.max(axis=-1).unsqueeze(-1) - qk -= qk_max - qk = qk.exp() - - exp_sum = qk.sum(axis=-1).unsqueeze(-1) - exp_sum_inv = 1.0 / exp_sum - - out[bsz, hi] = (paddle.matmul(qk, v_cur[bsz, hi // gqa_group_size]) * exp_sum_inv).astype(q_input.dtype) - return out - - -class TestMobaAttention(unittest.TestCase): - def setUp(self): - paddle.seed(0) - self.seq_len = int(8 * 1024) - self.num_heads = int(8) - self.num_kv_heads = int(1) - self.head_dim = int(128) - self.max_num_seqs = 1 - self.moba_max_seq_length = int(128 * 1024) - self.moba_block_size = int(128) - self.moba_encoder_top_k_left = 2 - self.moba_encoder_top_k_right = 3 - self.moba_use_encoder_seq_limit = int(4 * 1024) - self.cache_k_block_means = paddle.zeros( - [ - self.max_num_seqs, - self.moba_max_seq_length // self.moba_block_size, - self.num_kv_heads, - self.head_dim, - ], - dtype="bfloat16", - ) - self.attn_block_m = 128 - self.tokens = self.seq_len * self.max_num_seqs - self.q_input = paddle.zeros( - [self.tokens + self.attn_block_m, self.num_heads, self.head_dim], - dtype="bfloat16", - ) - self.k_input = paddle.zeros( - [self.tokens + self.attn_block_m, self.num_kv_heads, self.head_dim], - dtype="bfloat16", - ) - self.v_input = paddle.zeros( - [self.tokens + self.attn_block_m, self.num_kv_heads, self.head_dim], - dtype="bfloat16", - ) - self.rotary_embs = paddle.ones([2, self.seq_len, self.head_dim // 
2], dtype="float32") - - self.attn_gate_weight = paddle.randn( - [self.num_kv_heads, self.moba_block_size, self.head_dim], dtype="bfloat16" - ) - - self.gqa_group_size = self.num_heads // self.num_kv_heads - - self.num_blocks = (self.seq_len + self.moba_block_size - 1) // self.moba_block_size - - self.sparse_step = 4 - - def compare_split_qkv_rope(self, qkv_out): - assert (qkv_out[:, 0 : self.num_heads, :] - self.q_input[0 : self.tokens]).abs().max() < 1e-3 - assert ( - qkv_out[:, self.num_heads : self.num_heads + self.num_kv_heads, :] - self.k_input[0 : self.tokens] - ).abs().max() < 1e-3 - assert (qkv_out[:, self.num_heads + self.num_kv_heads :, :] - self.v_input[0 : self.tokens]).abs().max() < 1e-3 - - for i in range(self.max_num_seqs): - k_padding = paddle.zeros( - [ - (self.seq_len + self.moba_block_size - 1) // self.moba_block_size * self.moba_block_size, - self.num_kv_heads, - self.head_dim, - ], - dtype="bfloat16", - ) - k_padding[0 : self.seq_len] = self.k_input[i * self.seq_len : (i + 1) * self.seq_len] - real_k_block_means = k_padding.reshape([-1, self.moba_block_size, self.num_kv_heads, self.head_dim]) - real_k_block_means = real_k_block_means.mean(axis=1) - compute_k_block_means = self.cache_k_block_means[i, 0 : real_k_block_means.shape[0]] - assert (compute_k_block_means - real_k_block_means).abs().max() < 0.003 - - print("[consistency]Moba attention: split_qkv_rope matches.") - - def compare_mlp_einsum(self, k_gate_weight): - for i in range(self.max_num_seqs): - k_padding = paddle.zeros( - [ - (self.seq_len + self.moba_block_size - 1) // self.moba_block_size * self.moba_block_size, - self.num_kv_heads, - self.head_dim, - ], - dtype="bfloat16", - ) - k_padding[0 : self.seq_len] = self.k_input[i * self.seq_len : (i + 1) * self.seq_len] - k_padding = k_padding.reshape([-1, self.moba_block_size, self.num_kv_heads, self.head_dim]) - real_result = paddle.einsum("nbhd,hbd->nhd", k_padding, self.attn_gate_weight) - compute_result = k_gate_weight[i][0 : real_result.shape[0]] - - assert (real_result - compute_result).abs().max() < 0.5 - - print("[consistency]Moba attention: MLP einsum matches.") - - def compare_qk_gemm(self, qk_gate_weight): - for i in range(self.max_num_seqs): - q_input = self.q_input[i * self.seq_len : (i + 1) * self.seq_len] - k_input_mean = self.cache_k_block_means[i][0 : self.num_blocks] - - qk_gemm_out = paddle.zeros( - [ - self.seq_len, - self.num_heads, - self.num_blocks, - ], - dtype="bfloat16", - ) - - for j in range(self.num_heads): - qk_gemm_out[:, j, :] = paddle.matmul( - q_input[:, j, :], k_input_mean[:, j // self.gqa_group_size, :], transpose_y=True - ) - - conpute_result = qk_gate_weight[i * self.seq_len : (i + 1) * self.seq_len, :, 0 : self.num_blocks] - assert (qk_gemm_out - conpute_result).abs().max() < 1e-4 - - print("[consistency]Moba attention: qk_gemm matches.") - - def compare_qk_gate_topk(self, qk_gate_topk_idx): - limit_topk = self.moba_use_encoder_seq_limit // self.moba_block_size - for i in range(self.max_num_seqs): - qk_gate_topk_idx_batch = qk_gate_topk_idx[i * self.num_blocks : (i + 1) * self.num_blocks] - qk_gate_topk_idx_batch_no_sparse = qk_gate_topk_idx_batch[0 : limit_topk - 1] - - assert ( - qk_gate_topk_idx_batch_no_sparse - - paddle.ones(qk_gate_topk_idx_batch_no_sparse.shape, qk_gate_topk_idx_batch_no_sparse.dtype) - ).abs().max() < 1e-6 - - for j in range(limit_topk, self.num_blocks): - qk_gate_topk_idx_batch_sparse = qk_gate_topk_idx_batch[j, :, 1 : (j + 1) // self.sparse_step] - - assert ( - qk_gate_topk_idx_batch_sparse - - 
paddle.ones(qk_gate_topk_idx_batch_sparse.shape, qk_gate_topk_idx_batch_sparse.dtype) - * self.sparse_step - ).abs().max() < 1e-6 - print("[consistency]Moba attention: qk_gate_topk matches.") - - def compare_attn(self, attn_out, qk_gate_topk_idx): - x = ( - paddle.tensor.triu(paddle.ones([self.moba_block_size, self.moba_block_size], dtype="bfloat16"), 1) - * -1000000 - ) - limit_topk = self.moba_use_encoder_seq_limit // self.moba_block_size - for i in range(self.max_num_seqs): - q_input = self.q_input[i * self.seq_len : (i + 1) * self.seq_len].unsqueeze(axis=0) - k_input = self.k_input[i * self.seq_len : (i + 1) * self.seq_len].unsqueeze(axis=0) - v_input = self.v_input[i * self.seq_len : (i + 1) * self.seq_len].unsqueeze(axis=0) - mask = paddle.tensor.triu(paddle.ones([self.seq_len, self.seq_len], dtype="bfloat16"), 1) * -1000000 - mask[self.moba_use_encoder_seq_limit - self.moba_block_size :] = -1000000 - for i in range(limit_topk - 1, self.num_blocks): - n_block = i - mask[ - i * self.moba_block_size : i * self.moba_block_size + self.moba_block_size, - n_block * self.moba_block_size : n_block * self.moba_block_size + self.moba_block_size, - ] = x - idx = 0 - n_block -= int(qk_gate_topk_idx[i, 0, idx]) - idx += 1 - while n_block >= 0: - mask[ - i * self.moba_block_size : i * self.moba_block_size + self.moba_block_size, - n_block * self.moba_block_size : n_block * self.moba_block_size + self.moba_block_size, - ] = 0 - n_block -= int(qk_gate_topk_idx[i, 0, idx]) - idx += 1 - naive_attn_out = naive_attn(q_input, k_input, v_input, mask).squeeze(axis=0).transpose([1, 0, 2]) - assert (attn_out - naive_attn_out).abs().max() < 0.016 - - def test_moba_attention(self): - qkv_out = paddle.randn([self.tokens, self.num_heads + 2 * self.num_kv_heads, self.head_dim], dtype="bfloat16") - - seq_len_encoder = paddle.to_tensor([self.seq_len] * self.max_num_seqs, dtype="int32") - seq_len_decoder = paddle.to_tensor([0] * self.max_num_seqs, dtype="int32") - cu_seq_q = paddle.arange(self.max_num_seqs + 1).astype("int32") * self.seq_len - cu_seq_k = paddle.arange(self.max_num_seqs + 1).astype("int32") * self.seq_len - seq_lens_this_time = paddle.to_tensor([self.seq_len] * self.max_num_seqs, dtype="int32") - - cu_seq_q_pack, cu_seqlens_k, q_pack_tokens = get_cur_cu_seq_len_k( - seq_len_encoder, - seq_len_decoder, - seq_lens_this_time, - int(self.attn_block_m), - ) - - fused_block_mean_and_rope( - qkv_out, - self.cache_k_block_means, - self.q_input, - self.k_input, - self.v_input, - self.rotary_embs, - seq_len_encoder, - seq_len_decoder, - cu_seq_q, - cu_seq_k, - None, - self.num_heads, - self.num_kv_heads, - self.head_dim, - self.moba_max_seq_length, - self.seq_len, - self.seq_len, - "none", - ) - - self.compare_split_qkv_rope(qkv_out) - - k_gate_weight = moba_mlp_einsum( - self.k_input, - self.attn_gate_weight, - seq_len_encoder, - seq_len_decoder, - cu_seq_k, - self.seq_len, - self.num_kv_heads, - ) - - self.compare_mlp_einsum(k_gate_weight) - - qk_gate_weight = moba_qk_gemm( - self.q_input, - self.cache_k_block_means, - seq_len_encoder, - seq_len_decoder, - cu_seq_q, - cu_seq_k, - self.seq_len, - self.seq_len, - self.num_heads, - self.num_kv_heads, - False, - self.max_num_seqs, - ) - - self.compare_qk_gemm(qk_gate_weight) - - for i in range(0, self.num_blocks, self.sparse_step): - qk_gate_weight[:, :, i] = 100 - - qk_gate_topk_idx = moba_qk_sort_encoder( - qk_gate_weight, - seq_len_encoder, - seq_len_decoder, - cu_seq_q, - cu_seq_k, - cu_seq_q_pack, - q_pack_tokens, - self.seq_len, - self.seq_len, - 
self.num_heads, - self.num_kv_heads, - self.moba_encoder_top_k_left, - self.moba_encoder_top_k_right, - self.moba_use_encoder_seq_limit, - ) - - self.compare_qk_gate_topk(qk_gate_topk_idx) - - attn_out = paddle.zeros([self.tokens, self.num_heads, self.head_dim], dtype="bfloat16") - - moba_encoder_attn( - self.q_input, - self.k_input, - self.v_input, - qk_gate_topk_idx, - cu_seq_q, - cu_seq_k, - cu_seq_q_pack, - seq_len_encoder, - seq_len_decoder, - attn_out, - self.seq_len, - self.seq_len, - self.num_heads, - self.num_kv_heads, - self.head_dim, - self.moba_max_seq_length, - ) - - self.compare_attn(attn_out, qk_gate_topk_idx) - - -if __name__ == "__main__": - if paddle.is_compiled_with_cuda(): - unittest.main()
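The `compare_attn` check above rebuilds a dense mask from `qk_gate_topk_idx` (which stores block-index deltas) and compares the fused output against `naive_attn`. As a simplified illustration of what such a mask encodes, here is a hedged Paddle sketch that builds an additive block-sparse causal mask from an explicit per-query-block list of selected KV blocks; `block_sparse_mask` is an assumption for illustration and is not part of the test or the custom ops.

```python
# Simplified sketch: 0 where a query block may attend, a large negative value elsewhere.
import paddle


def block_sparse_mask(seq_len, block_size, selected_blocks, neg=-1e6):
    num_blocks = (seq_len + block_size - 1) // block_size
    mask = paddle.full([seq_len, seq_len], neg, dtype="float32")
    for qb in range(num_blocks):
        q0, q1 = qb * block_size, min((qb + 1) * block_size, seq_len)
        for kb in selected_blocks[qb]:
            k0, k1 = kb * block_size, min((kb + 1) * block_size, seq_len)
            mask[q0:q1, k0:k1] = 0.0
        # Keep the causal constraint inside the diagonal block.
        mask[q0:q1, q0:q1] = paddle.triu(
            paddle.full([q1 - q0, q1 - q0], neg, dtype="float32"), 1
        )
    return mask
```

The mask produced this way can be added to the raw attention scores before softmax, which is exactly how `naive_attn` consumes it in the reference comparison above.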