Revert "[Feature] block sparse attention (#3209)" (#3647)

This reverts commit 646a0c2fd8.
This commit is contained in:
Jiang-Jia-Jun
2025-08-27 17:35:04 +08:00
committed by GitHub
parent b2afdf4fc6
commit c694fa2879
31 changed files with 10 additions and 6507 deletions

View File

@@ -784,15 +784,15 @@ void SpeculateStepPaddle(
const int max_draft_tokens);
void MergePrefillDecodeOutput(
const paddle::Tensor &encoder_res,
const paddle::Tensor &decoder_res,
const paddle::Tensor &seq_lens_encoder,
const paddle::Tensor &seq_lens_decoder,
const paddle::Tensor &seq_lens_this_time,
const paddle::Tensor &cu_seq_q,
const int head_num,
const int head_dim,
const int max_token);
const paddle::Tensor &encoder_res,
const paddle::Tensor &decoder_res,
const paddle::Tensor &seq_lens_encoder,
const paddle::Tensor &seq_lens_decoder,
const paddle::Tensor &seq_lens_this_time,
const paddle::Tensor &cu_seq_q,
const int head_num,
const int head_dim,
const int max_token);
std::vector<paddle::Tensor> TopPSamplingReject(const paddle::Tensor &probs,
const paddle::Tensor &top_p,

View File

@@ -1,330 +0,0 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/extension.h"
#include "moba_attn.h"
std::vector<paddle::Tensor> MobaAttention(
const paddle::Tensor& qkv,
const paddle::Tensor& q_input,
const paddle::Tensor& k_input,
const paddle::Tensor& v_input,
const paddle::Tensor& cu_seq_q,
const paddle::Tensor& cu_seq_k,
const paddle::Tensor& cu_seq_q_pack,
const paddle::Tensor& q_pack_tokens,
const paddle::Tensor& seq_len_encoder,
const paddle::Tensor& seq_len_decoder,
const paddle::Tensor& key_cache,
const paddle::Tensor& value_cache,
const paddle::Tensor& block_tables,
const paddle::Tensor& rope_sin_cos,
const paddle::Tensor& k_block_means,
const paddle::optional<paddle::Tensor>& attn_gate_weight,
const paddle::optional<paddle::Tensor>& qkv_bias,
const paddle::optional<paddle::Tensor>& cache_k_quant_scale,
const paddle::optional<paddle::Tensor>& cache_v_quant_scale,
const paddle::optional<paddle::Tensor>& cache_k_dequant_scale,
const paddle::optional<paddle::Tensor>& cache_v_dequant_scale,
const paddle::optional<paddle::Tensor>& cache_k_zero_points,
const paddle::optional<paddle::Tensor>& cache_v_zero_points,
const int head_num,
const int kv_head_num,
const int head_dim,
const int max_seq_len,
const int max_enc_len_this_time,
const int max_dec_len_this_time,
const int moba_encoder_top_k_left,
const int moba_encoder_top_k_right,
const int moba_use_encoder_seq_limit,
const int moba_decoder_top_k_left,
const int moba_decoder_top_k_right,
const int moba_use_decoder_seq_limit,
const bool moba_use_mlp,
const std::string &cache_quant_type_str) {
paddle::Tensor out = paddle::empty({qkv.dims()[0], head_num * head_dim}, qkv.dtype(), qkv.place());
if (max_dec_len_this_time > 0) {
MobaDecoderAttnWriteCacheKv(
qkv,
q_input,
cu_seq_q,
cu_seq_k,
seq_len_encoder,
seq_len_decoder,
key_cache,
value_cache,
block_tables,
rope_sin_cos,
k_block_means,
qkv_bias,
cache_k_quant_scale,
cache_v_quant_scale,
cache_k_dequant_scale,
cache_v_dequant_scale,
cache_k_zero_points,
cache_v_zero_points,
head_num,
kv_head_num,
head_dim,
max_seq_len,
cache_quant_type_str);
auto qk_gate_weight = MobaQKGemm(
q_input,
k_block_means,
seq_len_encoder,
seq_len_decoder,
cu_seq_q,
cu_seq_k,
max_dec_len_this_time,
max_dec_len_this_time,
head_num,
kv_head_num,
true,
moba_use_decoder_seq_limit
)[0];
auto qk_gate_topk_idx = QkSortDecoder(
qk_gate_weight,
seq_len_encoder,
seq_len_decoder,
head_num,
kv_head_num,
moba_decoder_top_k_left,
moba_decoder_top_k_right,
moba_use_decoder_seq_limit
)[0];
MobaDecoderAttn(
q_input,
seq_len_encoder,
seq_len_decoder,
cu_seq_q,
key_cache,
value_cache,
block_tables,
k_block_means,
out,
qk_gate_topk_idx,
cache_k_quant_scale,
cache_v_quant_scale,
cache_k_dequant_scale,
cache_v_dequant_scale,
cache_k_zero_points,
cache_v_zero_points,
head_num,
kv_head_num,
head_dim,
max_seq_len,
moba_use_decoder_seq_limit,
max_dec_len_this_time,
max_dec_len_this_time,
cache_quant_type_str
);
}
if (max_enc_len_this_time > 0) {
FusedBlockMeanAndRope(
qkv,
k_block_means,
q_input,
k_input,
v_input,
rope_sin_cos,
seq_len_encoder,
seq_len_decoder,
cu_seq_q,
cu_seq_k,
qkv_bias,
head_num,
kv_head_num,
head_dim,
max_seq_len,
max_enc_len_this_time,
max_enc_len_this_time,
cache_quant_type_str
);
MobaEncoderAttnWriteCacheKv(
k_input,
v_input,
cu_seq_k,
seq_len_encoder,
seq_len_decoder,
key_cache,
value_cache,
block_tables,
cache_k_quant_scale,
cache_v_quant_scale,
cache_k_dequant_scale,
cache_v_dequant_scale,
cache_k_zero_points,
cache_v_zero_points,
head_num,
kv_head_num,
head_dim,
max_enc_len_this_time,
cache_quant_type_str
);
GetKVFromCache(
k_input,
v_input,
cu_seq_k,
seq_len_encoder,
seq_len_decoder,
key_cache,
value_cache,
block_tables,
cache_k_dequant_scale,
cache_v_dequant_scale,
cache_k_zero_points,
cache_v_zero_points,
head_num,
kv_head_num,
head_dim,
max_seq_len,
max_enc_len_this_time + max_dec_len_this_time,
cache_quant_type_str
);
paddle::Tensor *k_gate_weight = const_cast<paddle::Tensor*>(&k_block_means);
if (moba_use_mlp && attn_gate_weight) {
paddle::Tensor k_gate_mlp = MobaMlpEinsum(
k_input,
attn_gate_weight.get(),
seq_len_encoder,
seq_len_decoder,
cu_seq_k,
max_seq_len,
kv_head_num
)[0];
k_gate_weight = &k_gate_mlp;
}
auto qk_gate_weight = MobaQKGemm(
q_input,
k_block_means,
seq_len_encoder,
seq_len_decoder,
cu_seq_q,
cu_seq_k,
max_enc_len_this_time,
max_enc_len_this_time + max_dec_len_this_time,
head_num,
kv_head_num,
false,
moba_use_encoder_seq_limit
)[0];
auto qk_gate_topk_idx = QkSortEncoder(
qk_gate_weight,
seq_len_encoder,
seq_len_decoder,
cu_seq_q,
cu_seq_k,
cu_seq_q_pack,
q_pack_tokens,
max_enc_len_this_time,
max_enc_len_this_time + max_dec_len_this_time,
head_num,
kv_head_num,
moba_encoder_top_k_left,
moba_encoder_top_k_right,
moba_use_encoder_seq_limit)[0];
MobaEncoderAttn(
q_input,
k_input,
v_input,
qk_gate_topk_idx,
cu_seq_q,
cu_seq_k,
cu_seq_q_pack,
seq_len_encoder,
seq_len_decoder,
out,
max_enc_len_this_time,
max_enc_len_this_time + max_dec_len_this_time,
head_num,
kv_head_num,
head_dim,
max_seq_len
);
}
return {out};
}
PD_BUILD_OP(moba_attention)
.Inputs({
"qkv",
"q_input",
"k_input",
"v_input",
"cu_seq_q",
"cu_seq_k",
"cu_seq_q_pack",
"q_pack_tokens",
"seq_len_encoder",
"seq_len_decoder",
"key_cache",
"value_cache",
"block_tables",
"rope_sin_cos",
"k_block_means",
paddle::Optional("attn_gate_weight"),
paddle::Optional("qkv_bias"),
paddle::Optional("cache_k_quant_scale"),
paddle::Optional("cache_v_quant_scale"),
paddle::Optional("cache_k_dequant_scale"),
paddle::Optional("cache_v_dequant_scale"),
paddle::Optional("cache_k_zero_points"),
paddle::Optional("cache_v_zero_points")})
.Attrs({
"head_num: int",
"kv_head_num: int",
"head_dim: int",
"max_seq_len: int",
"max_enc_len_this_time: int",
"max_dec_len_this_time: int",
"moba_encoder_top_k_left: int",
"moba_encoder_top_k_right: int",
"moba_use_encoder_seq_limit: int",
"moba_decoder_top_k_left: int",
"moba_decoder_top_k_right: int",
"moba_use_decoder_seq_limit: int",
"moba_use_mlp: bool",
"cache_quant_type_str: std::string"})
.Outputs({
"out",
"q_input_out",
"k_input_out",
"v_input_out",
"key_cache_out",
"value_cache_out",
"k_block_means_out"})
.SetInplaceMap({{
"q_input", "q_input_out"},
{"k_input", "k_input_out"},
{"v_input", "v_input_out"},
{"key_cache", "key_cache_out"},
{"value_cache", "value_cache_out"},
{"k_block_means", "k_block_means_out"}})
.SetKernelFn(PD_KERNEL(MobaAttention));

View File

@@ -1,204 +0,0 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/extension.h"
void MobaDecoderAttnWriteCacheKv(
const paddle::Tensor& qkv_out,
const paddle::Tensor& q_input,
const paddle::Tensor& cu_seq_q,
const paddle::Tensor& cu_seq_k,
const paddle::Tensor& seq_len_encoder,
const paddle::Tensor& seq_len_decoder,
const paddle::Tensor& cache_k,
const paddle::Tensor& cache_v,
const paddle::Tensor& block_tables,
const paddle::Tensor& rope_sin_cos,
const paddle::Tensor& k_block_means,
const paddle::optional<paddle::Tensor>& qkv_bias,
const paddle::optional<paddle::Tensor>& cache_k_quant_scale,
const paddle::optional<paddle::Tensor>& cache_v_quant_scale,
const paddle::optional<paddle::Tensor>& cache_k_dequant_scale,
const paddle::optional<paddle::Tensor>& cache_v_dequant_scale,
const paddle::optional<paddle::Tensor>& cache_k_zero_points,
const paddle::optional<paddle::Tensor>& cache_v_zero_points,
const int head_num,
const int kv_head_num,
const int head_dim,
const int max_input_length,
const std::string &cache_quant_type_str);
void MobaEncoderAttnWriteCacheKv(
const paddle::Tensor& k_input,
const paddle::Tensor& v_input,
const paddle::Tensor& cu_seq_k,
const paddle::Tensor& seq_len_encoder,
const paddle::Tensor& seq_len_decoder,
const paddle::Tensor& cache_k,
const paddle::Tensor& cache_v,
const paddle::Tensor& block_tables,
const paddle::optional<paddle::Tensor>& cache_k_quant_scale,
const paddle::optional<paddle::Tensor>& cache_v_quant_scale,
const paddle::optional<paddle::Tensor>& cache_k_dequant_scale,
const paddle::optional<paddle::Tensor>& cache_v_dequant_scale,
const paddle::optional<paddle::Tensor>& cache_k_zero_points,
const paddle::optional<paddle::Tensor>& cache_v_zero_points,
const int head_num,
const int kv_head_num,
const int head_dim,
const int max_seq_q,
const std::string &cache_quant_type_str);
void MobaDecoderAttn(
const paddle::Tensor& q_input,
const paddle::Tensor& seq_len_encoder,
const paddle::Tensor& seq_len_decoder,
const paddle::Tensor& cu_seq_q,
const paddle::Tensor& cache_k,
const paddle::Tensor& cache_v,
const paddle::Tensor& block_tables,
const paddle::Tensor& k_block_means,
const paddle::Tensor& out,
const paddle::Tensor& qk_gate_topk_idx,
const paddle::optional<paddle::Tensor>& cache_k_quant_scale,
const paddle::optional<paddle::Tensor>& cache_v_quant_scale,
const paddle::optional<paddle::Tensor>& cache_k_dequant_scale,
const paddle::optional<paddle::Tensor>& cache_v_dequant_scale,
const paddle::optional<paddle::Tensor>& cache_k_zero_points,
const paddle::optional<paddle::Tensor>& cache_v_zero_points,
const int head_num,
const int kv_head_num,
const int head_dim,
const int max_input_length,
const int use_moba_seq_limit,
const int max_seq_q,
const int max_seq_k,
const std::string &cache_quant_type_str);
void FusedBlockMeanAndRope(
const paddle::Tensor& qkv_out,
const paddle::Tensor& k_block_means,
const paddle::Tensor& q_input,
const paddle::Tensor& k_input,
const paddle::Tensor& v_input,
const paddle::Tensor& rotary_embs,
const paddle::Tensor& seq_len_encoder,
const paddle::Tensor& seq_len_decoder,
const paddle::Tensor& cu_seq_q,
const paddle::Tensor& cu_seq_k,
const paddle::optional<paddle::Tensor>& qkv_bias,
const int head_num,
const int kv_head_num,
const int head_dim,
const int max_input_length,
const int max_seq_q,
const int max_seq_k,
const std::string &cache_quant_type_str);
std::vector<paddle::Tensor> GetCurCuSeqLenk(
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& seq_lens_this_time,
const int pack_size);
std::vector<paddle::Tensor> MobaQKGemm(
const paddle::Tensor& q_input,
const paddle::Tensor& k_block_means,
const paddle::Tensor& seq_len_encoder,
const paddle::Tensor& seq_len_decoder,
const paddle::Tensor& cu_seq_q,
const paddle::Tensor& cu_seq_k,
const int max_seq_q,
const int max_seq_k,
const int head_num,
const int kv_head_num,
const bool is_split_kv,
const int use_moba_seq_limit);
std::vector<paddle::Tensor> QkSortDecoder(
const paddle::Tensor& qk_gate_weight,
const paddle::Tensor& seq_len_encoder,
const paddle::Tensor& seq_len_decoder,
const int head_num,
const int kv_head_num,
const int top_k_left,
const int top_k_right,
const int use_moba_seq_limit);
void GetKVFromCache(
const paddle::Tensor& k_input,
const paddle::Tensor& v_input,
const paddle::Tensor& cu_seq_k,
const paddle::Tensor& seq_len_encoder,
const paddle::Tensor& seq_len_decoder,
const paddle::Tensor& cache_k,
const paddle::Tensor& cache_v,
const paddle::Tensor& block_tables,
const paddle::optional<paddle::Tensor>& cache_k_dequant_scale,
const paddle::optional<paddle::Tensor>& cache_v_dequant_scale,
const paddle::optional<paddle::Tensor>& cache_k_zero_points,
const paddle::optional<paddle::Tensor>& cache_v_zero_points,
const int head_num,
const int kv_head_num,
const int head_dim,
const int max_input_length,
const int max_seq_k,
const std::string &cache_quant_type_str);
void MobaEncoderAttn(
const paddle::Tensor& q_input,
const paddle::Tensor& k_input,
const paddle::Tensor& v_input,
const paddle::Tensor& qk_gate_topk_idx,
const paddle::Tensor& cu_seq_q,
const paddle::Tensor& cu_seq_k,
const paddle::Tensor& cu_seq_q_pack,
const paddle::Tensor& seq_len_encoder,
const paddle::Tensor& seq_len_decoder,
const paddle::Tensor& out,
const int max_seq_q,
const int max_seq_k,
const int head_num,
const int kv_head_num,
const int head_dim,
const int max_input_length);
std::vector<paddle::Tensor> QkSortEncoder(
const paddle::Tensor& qk_gate_weight,
const paddle::Tensor& seq_len_encoder,
const paddle::Tensor& seq_len_decoder,
const paddle::Tensor& cu_seq_q,
const paddle::Tensor& cu_seq_k,
const paddle::Tensor& cu_seq_q_pack,
const paddle::Tensor& q_pack_tokens,
const int max_seq_q,
const int max_seq_k,
const int head_num,
const int kv_head_num,
const int top_k_left,
const int top_k_right,
const int use_moba_seq_limit);
std::vector<paddle::Tensor> MobaMlpEinsum(
const paddle::Tensor& k_input,
const paddle::Tensor& attn_gate_weight,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& cu_seq_k,
const int max_seq_len,
const int kv_head_num);

View File

@@ -1,748 +0,0 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <assert.h>
#include <stdint.h>
#include <stdlib.h>
#include <cuda_fp16.h>
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
#include <cuda_bf16.h>
#endif
#include <cutlass/array.h>
#include <cutlass/cutlass.h>
#include <cutlass/numeric_conversion.h>
#include <cutlass/numeric_types.h>
#include <cub/cub.cuh>
#include "cute/tensor.hpp"
#include "cute/algorithm/copy.hpp"
#include "cute/algorithm/gemm.hpp"
#include "cute/int_tuple.hpp"
#include <cute/arch/cluster_sm90.hpp>
#include <cub/cub.cuh>
#include "cutlass/layout/layout.h"
#include "cutlass/numeric_types.h"
#include "cutlass/pipeline/pipeline.hpp"
#include "cutlass/cluster_launch.hpp"
#include "cutlass/arch/reg_reconfig.h"
#include "cute/atom/mma_atom.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
using namespace cute;
template<typename T>
struct PackedHalf;
template<>
struct PackedHalf<cutlass::half_t> {
using Type = __half2;
};
template<>
struct PackedHalf<cutlass::bfloat16_t> {
using Type = nv_bfloat162;
};
template<>
struct PackedHalf<phi::dtype::float16> {
using Type = __half2;
};
template<>
struct PackedHalf<phi::dtype::bfloat16> {
using Type = nv_bfloat162;
};
template<typename T>
struct HalfSub;
template<>
struct HalfSub<cutlass::half_t> {
inline __device__ void operator()(uint32_t* result_ptr, const uint32_t magic_num) {
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(*result_ptr) : "r"(*result_ptr), "r"(magic_num));
}
};
template<>
struct HalfSub<cutlass::bfloat16_t> {
inline __device__ void operator()(uint32_t* result_ptr, const uint32_t magic_num) {
*reinterpret_cast<nv_bfloat162*>(result_ptr) -= *reinterpret_cast<const nv_bfloat162*>(&magic_num);
}
};
template<typename T>
struct HalfMul;
template<>
struct HalfMul<cutlass::half_t> {
inline __device__ void operator()(uint32_t* result_ptr, const uint32_t magic_num) {
asm volatile("mul.f16x2 %0, %1, %2;\n" : "=r"(*result_ptr) : "r"(*result_ptr), "r"(magic_num));
}
};
template<>
struct HalfMul<cutlass::bfloat16_t> {
inline __device__ void operator()(uint32_t* result_ptr, const uint32_t magic_num) {
*reinterpret_cast<nv_bfloat162*>(result_ptr) *= *reinterpret_cast<const nv_bfloat162*>(&magic_num);
}
};
template<typename T>
struct HalfMax;
template<>
struct HalfMax<cutlass::half_t> {
inline __device__ __half2 operator()(const __half2 x, const __half2 y) {
__half2 res;
asm volatile("max.f16x2 %0, %1, %2;\n" :
"=r"(*reinterpret_cast<uint32_t*>(&res)) :
"r"(*reinterpret_cast<const uint32_t*>(&x)),
"r"(*reinterpret_cast<const uint32_t*>(&y)));
return res;
}
};
template<>
struct HalfMax<cutlass::bfloat16_t> {
inline __device__ nv_bfloat162 operator()(const nv_bfloat162 x, const nv_bfloat162 y) {
nv_bfloat162 res;
asm volatile("max.bf16x2 %0, %1, %2;\n" :
"=r"(*reinterpret_cast<uint32_t*>(&res)) :
"r"(*reinterpret_cast<const uint32_t*>(&x)),
"r"(*reinterpret_cast<const uint32_t*>(&y)));
return res;
}
};
template<typename T>
struct HalfMin;
template<>
struct HalfMin<cutlass::half_t> {
inline __device__ __half2 operator()(const __half2 x, const __half2 y) {
__half2 res;
asm volatile("min.f16x2 %0, %1, %2;\n" :
"=r"(*reinterpret_cast<uint32_t*>(&res)) :
"r"(*reinterpret_cast<const uint32_t*>(&x)),
"r"(*reinterpret_cast<const uint32_t*>(&y)));
return res;
}
};
template<>
struct HalfMin<cutlass::bfloat16_t> {
inline __device__ nv_bfloat162 operator()(const nv_bfloat162 x, const nv_bfloat162 y) {
nv_bfloat162 res;
asm volatile("min.bf16x2 %0, %1, %2;\n" :
"=r"(*reinterpret_cast<uint32_t*>(&res)) :
"r"(*reinterpret_cast<const uint32_t*>(&x)),
"r"(*reinterpret_cast<const uint32_t*>(&y)));
return res;
}
};
template<typename T>
struct MaxOp {
__device__ __forceinline__ T operator()(T const & x, T const & y) { return x > y ? x : y; }
};
template <>
struct MaxOp<float> {
__device__ __forceinline__ float operator()(float const &x, float const &y) { return max(x, y); }
};
template<typename T>
struct MinOp {
__device__ __forceinline__ T operator()(T const & x, T const & y) { return x < y ? x : y; }
};
template <>
struct MinOp<float> {
__device__ __forceinline__ float operator()(float const &x, float const &y) { return min(x, y); }
};
template<typename T>
struct SumOp {
__device__ __forceinline__ T operator()(T const & x, T const & y) { return x + y; }
};
template<typename T, bool Is_K>
inline __device__ static void convert_c8_2_half(uint32_t *src, T *dst, const T *cache_scale, const T* cache_zp) {
uint32_t* half_result_ptr = reinterpret_cast<uint32_t*>(dst);
if constexpr (std::is_same_v<T, cutlass::bfloat16_t>) {
static constexpr uint32_t fp32_base = 0x4B000000;
float fp32_intermediates[4];
uint32_t* fp32_intermediates_casted = reinterpret_cast<uint32_t*>(fp32_intermediates);
fp32_intermediates_casted[0] = __byte_perm(*src, fp32_base, 0x7650);
fp32_intermediates_casted[1] = __byte_perm(*src, fp32_base, 0x7651);
fp32_intermediates_casted[2] = __byte_perm(*src, fp32_base, 0x7652);
fp32_intermediates_casted[3] = __byte_perm(*src, fp32_base, 0x7653);
#pragma unroll
for (int ii = 0; ii < 4; ++ii) {
fp32_intermediates[ii] -= 8388608.f;
}
#pragma unroll
for (int ii = 0; ii < 2; ++ii) {
half_result_ptr[ii] = __byte_perm(fp32_intermediates_casted[2 * ii + 0], fp32_intermediates_casted[2 * ii + 1], 0x7632);
}
} else {
static constexpr uint32_t head_for_fp16 = 0x64006400;
half_result_ptr[0] = __byte_perm(*src, head_for_fp16, 0x7150);
half_result_ptr[1] = __byte_perm(*src, head_for_fp16, 0x7352);
}
using pack_half = typename PackedHalf<T>::Type;
#pragma unroll
for (int i = 0; i < 2; i++){
if constexpr (Is_K) {
HalfSub<T>()(half_result_ptr + i, *reinterpret_cast<const uint32_t*>(cache_zp + i * 2));
HalfMul<T>()(half_result_ptr + i, *reinterpret_cast<const uint32_t*>(cache_scale + i * 2));
} else {
pack_half bias;
pack_half scale;
bias.x = cache_zp[0];
bias.y = cache_zp[0];
scale.x = cache_scale[0];
scale.y = cache_scale[0];
HalfSub<T>()(half_result_ptr + i, *reinterpret_cast<const uint32_t*>(&bias));
HalfMul<T>()(half_result_ptr + i, *reinterpret_cast<const uint32_t*>(&scale));
}
}
}
template<typename T, bool Is_K>
inline __device__ static void convert_c4_2_half(uint32_t *src, T *dst, const T *cache_scale, const T* cache_zp) {
using pack_half = typename PackedHalf<T>::Type;
static constexpr uint32_t MASK = 0x0f0f0f0f;
static constexpr uint32_t head_for_fp16 = std::is_same_v<T, cutlass::bfloat16_t> ? 0x43004300 : 0x64006400;
static constexpr uint32_t mask_for_c42fp16_one = 0x7253;
static constexpr uint32_t mask_for_c42fp16_two = 0x7051;
uint32_t* result_ptr = reinterpret_cast<uint32_t*>(dst);
uint32_t source = *reinterpret_cast<uint32_t const*>(src);
// source = {e0 e4 e1 e5 e2 e6 e3 e7}
uint32_t bottom_i4s = source & MASK;
// bottom_i4s = {0 e4 0 e5 0 e6 0 e7}
uint32_t top_i4s = (source >> 4) & MASK;
// top_i4s = {0 e0 0 e1 0 e2 0 e3}
asm volatile("prmt.b32 %0,%1,%2,%3;\n" : "=r"(result_ptr[0]) : "r"(top_i4s), "n"(head_for_fp16), "n"(mask_for_c42fp16_one));
// result_ptr[0] = {e0 e1}
asm volatile("prmt.b32 %0,%1,%2,%3;\n" : "=r"(result_ptr[1]) : "r"(top_i4s), "n"(head_for_fp16), "n"(mask_for_c42fp16_two));
asm volatile("prmt.b32 %0,%1,%2,%3;\n" : "=r"(result_ptr[2]) : "r"(bottom_i4s), "n"(head_for_fp16), "n"(mask_for_c42fp16_one));
asm volatile("prmt.b32 %0,%1,%2,%3;\n" : "=r"(result_ptr[3]) : "r"(bottom_i4s), "n"(head_for_fp16), "n"(mask_for_c42fp16_two));
#pragma unroll
for (int i = 0; i < 4; ++i) {
if constexpr (Is_K) {
const int ith_col = i % 2 * 2;
HalfSub<T>()(result_ptr + i, *reinterpret_cast<const uint32_t*>(cache_zp + ith_col));
HalfMul<T>()(result_ptr + i, *reinterpret_cast<const uint32_t*>(cache_scale + ith_col));
} else {
const int ith_col = i / 2;
pack_half bias;
pack_half scale;
bias.x = cache_zp[ith_col];
bias.y = cache_zp[ith_col];
scale.x = cache_scale[ith_col];
scale.y = cache_scale[ith_col];
HalfSub<T>()(result_ptr + i, *reinterpret_cast<const uint32_t*>(&bias));
HalfMul<T>()(result_ptr + i, *reinterpret_cast<const uint32_t*>(&scale));
}
}
}
template<typename CacheKV_traits, typename T, int kHeadDim, int kDataNumPer2Byte, bool A_in_regs=false, typename Tensor0, typename Tensor1,
typename Tensor2, typename Tensor3, typename Tensor4,
typename TiledMma, typename ThrCopy0, typename TiledCopy0>
inline __device__ void gemm_qk_quant(
Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCsA, Tensor3 &tCrB,
Tensor4 const& sB, TiledMma tiled_mma,
ThrCopy0 smem_thr_copy_A,
TiledCopy0 smem_tiled_copy_A,
const int32_t tidx,
const T * cache_scale, const T * cache_zp) {
CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(acc));
CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(acc));
Tensor tCrA_copy_view = smem_thr_copy_A.retile_D(tCrA);
CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(tCrA_copy_view));
if (!A_in_regs) {
copy(smem_tiled_copy_A, tCsA(_, _, _0{}), tCrA_copy_view(_, _, _0{}));
}
uint32_t *sBdata = reinterpret_cast<uint32_t *>(sB.data().get()) + tidx * (kDataNumPer2Byte / 4);
#pragma unroll
for (int i = 0; i < size<2>(tCrA); ++i) {
if (i < size<2>(tCrA) - 1) {
if (!A_in_regs) {
copy(smem_tiled_copy_A, tCsA(_, _, i + 1), tCrA_copy_view(_, _, i + 1));
}
}
if constexpr (kDataNumPer2Byte == 4) {
convert_c4_2_half<T, true>(sBdata + i * kHeadDim, tCrB.data(), cache_scale + i * 4, cache_zp + i * 4);
} else {
convert_c8_2_half<T, true>(sBdata + i * (kHeadDim * 2), tCrB.data(), cache_scale + i * 4, cache_zp + i * 4);
convert_c8_2_half<T, true>(sBdata + i * (kHeadDim * 2) + 1, tCrB.data() + 4, cache_scale + i * 4, cache_zp + i * 4);
}
cute::gemm(tiled_mma, tCrA(_, _, i), tCrB, acc);
}
}
template<typename CacheKV_traits, typename T, int kHeadDim, int kDataNumPer2Byte, bool A_in_regs=false, typename Tensor0, typename Tensor1,
typename Tensor2, typename Tensor3, typename Tensor4,
typename TiledMma, typename ThrCopy0, typename TiledCopy0>
inline __device__ void gemm_value_quant(
Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCsA, Tensor3 &tCrB,
Tensor4 const& sB, TiledMma tiled_mma,
ThrCopy0 smem_thr_copy_A,
TiledCopy0 smem_tiled_copy_A,
int32_t tidx,
const T * cache_scale, const T * cache_zp) {
CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(acc));
CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(acc));
Tensor tCrA_copy_view = smem_thr_copy_A.retile_D(tCrA);
CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(tCrA_copy_view));
if (!A_in_regs) {
copy(smem_tiled_copy_A, tCsA(_, _, _0{}), tCrA_copy_view(_, _, _0{}));
}
uint32_t *sBdata = reinterpret_cast<uint32_t *>(sB.data().get()) + tidx * (2 * kDataNumPer2Byte / 4);
#pragma unroll
for (int i = 0; i < size<2>(tCrA); ++i) {
const int cur_idx = i * kHeadDim * (2 * kDataNumPer2Byte / 4);
if (i < size<2>(tCrA) - 1) {
if (!A_in_regs) {
copy(smem_tiled_copy_A, tCsA(_, _, i + 1), tCrA_copy_view(_, _, i + 1));
}
}
if constexpr (kDataNumPer2Byte == 4) {
convert_c4_2_half<T, false>(sBdata + cur_idx, tCrB.data(), cache_scale, cache_zp);
convert_c4_2_half<T, false>(sBdata + cur_idx + 1, tCrB.data() + 8, cache_scale + 2, cache_zp + 2);
} else {
convert_c8_2_half<T, false>(sBdata + cur_idx, tCrB.data(), cache_scale, cache_zp);
convert_c8_2_half<T, false>(sBdata + cur_idx + 1, tCrB.data() + 4, cache_scale + 1, cache_zp + 1);
convert_c8_2_half<T, false>(sBdata + cur_idx + 2, tCrB.data() + 8, cache_scale + 2, cache_zp + 2);
convert_c8_2_half<T, false>(sBdata + cur_idx + 3, tCrB.data() + 12, cache_scale + 3, cache_zp + 3);
}
cute::gemm(tiled_mma, tCrA(_, _, i), tCrB, acc);
}
}
template<int kMiLen, typename Engine, typename Layout>
inline __device__ void apply_mask(Tensor<Engine, Layout> &scores, const uint32_t warp_id, const uint32_t col, const uint32_t reamin_seq_len) {
const int cols = size<1>(scores) / 2;
#pragma unroll
for (int mi = 0; mi < kMiLen; ++mi) {
#pragma unroll
for (int ni = 0; ni < cols; ++ni) {
const int col_index = warp_id * 8 + ni * 32 + col * 2;
if (col_index >= reamin_seq_len) {
scores(mi, ni * 2) = -INFINITY;
}
if (col_index + 1 >= reamin_seq_len) {
scores(mi, ni * 2 + 1) = -INFINITY;
}
}
}
}
template<int THREADS>
struct Allreduce {
static_assert(THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4);
template<typename T, typename Operator>
static __device__ inline T run(T x, Operator &op) {
constexpr int OFFSET = THREADS / 2;
x = op(x, __shfl_xor_sync(uint32_t(-1), x, OFFSET));
return Allreduce<OFFSET>::run(x, op);
}
};
template<>
struct Allreduce<2> {
template<typename T, typename Operator>
static __device__ inline T run(T x, Operator &op) {
x = op(x, __shfl_xor_sync(uint32_t(-1), x, 1));
return x;
}
};
template<int kMiLen, typename Engine0, typename Layout0, typename T>
__device__ inline void reduce_max(Tensor<Engine0, Layout0> const& tensor, T *scores_max){
static_assert(Layout0::rank == 2, "Only support 2D Tensor");
MaxOp<T> max_op;
#pragma unroll
for (int mi = 0; mi < kMiLen; ++mi) {
#pragma unroll
for (int ni = 0; ni < size<1>(tensor); ni++) {
scores_max[mi] = max_op(scores_max[mi], tensor(mi, ni));
}
scores_max[mi] = Allreduce<4>::run(scores_max[mi], max_op);
}
}
template <int kMiLen, typename Engine0, typename Layout0, typename T>
inline __device__ void scale_apply_exp2(Tensor<Engine0, Layout0> &tensor, T const *max, T *sum, const float scale) {
static_assert(Layout0::rank == 2, "Only support 2D Tensor");
#pragma unroll
for (int mi = 0; mi < kMiLen; ++mi) {
const float max_scaled = max[mi] * scale;
#pragma unroll
for (int ni = 0; ni < size<1>(tensor); ++ni) {
tensor(mi, ni) = expf(tensor(mi, ni) * scale - max_scaled);
sum[mi] += tensor(mi, ni);
}
}
}
template <typename paddle_type>
struct cuteType;
template <>
struct cuteType<phi::dtype::float16> {
using type = cutlass::half_t;
};
template <>
struct cuteType<phi::dtype::bfloat16> {
using type = cutlass::bfloat16_t;
};
template<typename T>
__forceinline__ __device__ auto float_2_half2(const float x) {
if constexpr (std::is_same<T, cutlass::half_t>::value) {
return __float2half2_rn(x);
} else {
return __float2bfloat162_rn(x);
}
}
struct uint16 {
uint4 u;
uint4 v;
uint4 s;
uint4 t;
};
struct uint8 {
uint4 u;
uint4 v;
};
template<int BYTES>
struct BytesToType {};
template<>
struct BytesToType<64> {
using Type = uint16;
static_assert(sizeof(Type) == 64);
};
template<>
struct BytesToType<32> {
using Type = uint8;
static_assert(sizeof(Type) == 32);
};
template<>
struct BytesToType<16> {
using Type = uint4;
static_assert(sizeof(Type) == 16);
};
template<>
struct BytesToType<8> {
using Type = uint64_t;
static_assert(sizeof(Type) == 8);
};
template<>
struct BytesToType<4> {
using Type = uint32_t;
static_assert(sizeof(Type) == 4);
};
template<>
struct BytesToType<2> {
using Type = uint16_t;
static_assert(sizeof(Type) == 2);
};
template<>
struct BytesToType<1> {
using Type = uint8_t;
static_assert(sizeof(Type) == 1);
};
template<typename Elt_type, uint32_t NUM_ELT>
struct Vec {
enum { BYTES = NUM_ELT * sizeof(Elt_type) };
using Vec_type = typename BytesToType<BYTES>::Type;
using Alias_type = union {
Vec_type vec;
Elt_type elt[NUM_ELT];
};
Alias_type data;
inline __device__ Vec() {}
template<typename S>
inline __device__ void to(Vec<S, NUM_ELT> &other) {
#pragma unroll
for( int it = 0; it < NUM_ELT; it++ ) {
other.data.elt[it] = S(this->data.elt[it]);
}
}
template<typename Op>
inline __device__ void assign(const Op &op) {
#pragma unroll
for( int it = 0; it < NUM_ELT; it++ ) {
this->data.elt[it] = op(it);
}
}
inline __device__ void load_from(const void *base_ptr) {
this->data.vec = *reinterpret_cast<const Vec_type *>(base_ptr);
}
inline __device__ void store_to(void *base_ptr) {
*reinterpret_cast<Vec_type *>(base_ptr) = this->data.vec;
}
inline __device__ void add(const Vec<Elt_type, NUM_ELT> &other) {
static_assert(NUM_ELT % 2 == 0);
using type = typename PackedHalf<Elt_type>::Type;
#pragma unroll
for (int it = 0; it < NUM_ELT / 2; it++) {
type b = *reinterpret_cast<const type *>(other.data.elt + it * 2);
*reinterpret_cast<type *>(this->data.elt + it * 2) += b;
}
}
inline __device__ void set_zero() {
constexpr int size = sizeof(Vec_type) / sizeof(int);
#pragma unroll
for (int i = 0; i < size; ++i) {
(reinterpret_cast<int *>(this->data.elt))[i] = 0;
}
}
inline __device__ void fma(const Vec<Elt_type, NUM_ELT> &scale, const Vec<Elt_type, NUM_ELT> &bias) {
static_assert(NUM_ELT % 2 == 0);
using type = typename PackedHalf<Elt_type>::Type;
#pragma unroll
for (int it = 0; it < NUM_ELT / 2; it++) {
type a = *reinterpret_cast<const type *>(scale.data.elt + it * 2);
type b = *reinterpret_cast<const type *>(bias.data.elt + it * 2);
*reinterpret_cast<type *>(this->data.elt + it * 2) += a * b;
}
}
};
template<typename T, int PackSize>
inline __device__ void apply_rotary_embedding(Vec<T, PackSize>& vec, Vec<float, PackSize / 2>& cos, Vec<float, PackSize / 2>& sin) {
static_assert(PackSize % 2 == 0);
#pragma unroll
for (int i = 0; i < PackSize / 2; i++) {
const float cos_inv_freq = cos.data.elt[i];
const float sin_inv_freq = sin.data.elt[i];
const float v1 = static_cast<float>(vec.data.elt[2 * i]);
const float v2 = static_cast<float>(vec.data.elt[2 * i + 1]);
vec.data.elt[2 * i] = static_cast<T>(cos_inv_freq * v1 - sin_inv_freq * v2);
vec.data.elt[2 * i + 1] = static_cast<T>(sin_inv_freq * v1 + cos_inv_freq * v2);
}
}
template <bool Is_even_MN=true, typename TiledCopy, typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename Engine2, typename Layout2>
__forceinline__ __device__ void copy(
TiledCopy tiled_copy, Tensor<Engine0, Layout0> const &S,
Tensor<Engine1, Layout1> &D,
Tensor<Engine2, Layout2> const &identity_MN,
const int max_MN = 0) {
CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{});
CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{});
CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D)); // MMA
CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D)); // MMA_M
CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D)); // MMA_K
#pragma unroll
for (int m = 0; m < size<1>(S); ++m) {
if (Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_MN) {
#pragma unroll
for (int k = 0; k < size<2>(S); ++k) {
cute::copy(tiled_copy, S(_, m, k), D(_, m, k));
}
}
}
}
template<bool A_in_regs=false, bool B_in_regs=false, typename Tensor0, typename Tensor1,
typename Tensor2, typename Tensor3, typename Tensor4,
typename TiledMma, typename ThrCopy0, typename ThrCopy1,
typename TiledCopy0, typename TiledCopy1>
inline __device__ void gemm(
Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tensor3 const& tCsA,
Tensor4 const& tCsB, TiledMma tiled_mma,
ThrCopy0 &smem_thr_copy_A, ThrCopy1 &smem_thr_copy_B,
TiledCopy0 &smem_tiled_copy_A, TiledCopy1 &smem_tiled_copy_B) {
CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(acc));
CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(acc));
CUTE_STATIC_ASSERT_V(size<2>(tCrA) == size<2>(tCrB));
Tensor tCrA_copy_view = smem_thr_copy_A.retile_D(tCrA);
CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(tCrA_copy_view));
Tensor tCrB_copy_view = smem_thr_copy_B.retile_D(tCrB);
CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<1>(tCrB_copy_view));
if (!A_in_regs) { copy(smem_tiled_copy_A, tCsA(_, _, _0{}), tCrA_copy_view(_, _, _0{})); }
if (!B_in_regs) { copy(smem_tiled_copy_B, tCsB(_, _, _0{}), tCrB_copy_view(_, _, _0{})); }
#pragma unroll
for (int i = 0; i < size<2>(tCrA); ++i) {
if (i < size<2>(tCrA) - 1) {
if (!A_in_regs) { copy(smem_tiled_copy_A, tCsA(_, _, i + 1), tCrA_copy_view(_, _, i + 1)); }
if (!B_in_regs) { copy(smem_tiled_copy_B, tCsB(_, _, i + 1), tCrB_copy_view(_, _, i + 1)); }
}
cute::gemm(tiled_mma, tCrA(_, _, i), tCrB(_, _, i), acc);
}
}
template <typename To_type, typename Engine, typename Layout>
inline __device__ auto convert_type(Tensor<Engine, Layout> const &tensor) {
using From_type = typename Engine::value_type;
constexpr int numel = decltype(size(tensor))::value;
cutlass::NumericArrayConverter<To_type, From_type, numel> convert_op;
auto frag = convert_op(*reinterpret_cast<const cutlass::Array<From_type, numel> *>(tensor.data()));
return make_tensor(make_rmem_ptr<To_type>(&frag), tensor.layout());
}
template<typename T, typename ReductionOp, int block_size>
__inline__ __device__ T BlockAllReduce(T val) {
typedef cub::BlockReduce<T, block_size> BlockReduce;
__shared__ typename BlockReduce::TempStorage temp_storage;
__shared__ T result_broadcast;
T result = BlockReduce(temp_storage).Reduce(val, ReductionOp());
if (threadIdx.x == 0) { result_broadcast = result; }
__syncthreads();
return result_broadcast;
}
template<typename MMA_traits, typename Layout>
__forceinline__ __device__ auto convert_layout_acc_Aregs(Layout acc_layout) {
using X = Underscore;
if constexpr (decltype(rank<0>(acc_layout))::value == 3) { // SM90
static_assert(decltype(size<0, 0>(acc_layout))::value == 2);
static_assert(decltype(size<0, 1>(acc_layout))::value == 2);
static_assert(decltype(rank(acc_layout))::value == 3);
static_assert(decltype(rank(get<0>(acc_layout)))::value == 3);
auto l = logical_divide(get<0>(acc_layout), Shape<X, X, _2>{}); // (2, 2, (2, N / 16)))
return make_layout(make_layout(get<0>(l), get<1>(l), get<2, 0>(l)), get<1>(acc_layout), make_layout(get<2, 1>(l), get<2>(acc_layout)));
} else { // SM80
static_assert(decltype(size<0>(acc_layout))::value == 4);
static_assert(decltype(rank(acc_layout))::value == 3);
constexpr int mma_shape_K = get<2>(typename MMA_traits::Shape_MNK{});
static_assert(mma_shape_K == 8 || mma_shape_K == 16);
if constexpr (mma_shape_K == 8) {
return acc_layout;
} else {
auto l = logical_divide(acc_layout, Shape<X, X, _2>{}); // (4, MMA_M, (2, MMA_N / 2)))
return make_layout(make_layout(get<0>(l), get<2, 0>(l)), get<1>(l), get<2, 1>(l));
}
}
};
template <bool zero_init=false, int wg_wait=0, bool arrive=true, bool commit=true, typename Tensor0, typename Tensor1, typename Tensor2, typename TiledMma>
__forceinline__ __device__ void gemm(TiledMma &tiled_mma, Tensor0 const &tCrA, Tensor1 const &tCrB, Tensor2 &tCrC) {
constexpr bool Is_RS = !cute::is_base_of<cute::GMMA::DescriptorIterator, typename TiledMma::FrgTypeA>::value;
if constexpr (Is_RS) { warpgroup_fence_operand(const_cast<Tensor0 &>(tCrA)); }
warpgroup_fence_operand(tCrC);
if constexpr (arrive) {
warpgroup_arrive();
}
if constexpr (zero_init) {
tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;
CUTLASS_PRAGMA_UNROLL
for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {
cute::gemm(tiled_mma, tCrA(_,_,k_block), tCrB(_,_,k_block), tCrC);
tiled_mma.accumulate_ = GMMA::ScaleOut::One;
}
} else {
CUTLASS_PRAGMA_UNROLL
for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {
cute::gemm(tiled_mma, tCrA(_,_,k_block), tCrB(_,_,k_block), tCrC);
tiled_mma.accumulate_ = GMMA::ScaleOut::One;
}
}
if constexpr (commit) {
warpgroup_commit_batch();
}
if constexpr (wg_wait >= 0) { warpgroup_wait<wg_wait>(); }
warpgroup_fence_operand(tCrC);
if constexpr (Is_RS) { warpgroup_fence_operand(const_cast<Tensor0 &>(tCrA)); }
}
template<typename Layout>
__forceinline__ __device__ auto convert_layout_acc_rowcol(Layout acc_layout) {
if constexpr (decltype(rank<0>(acc_layout))::value == 3) { // SM90
static_assert(decltype(size<0, 0>(acc_layout))::value == 2);
static_assert(decltype(size<0, 1>(acc_layout))::value == 2);
static_assert(decltype(rank(acc_layout))::value == 3);
auto l = acc_layout;
return make_layout(make_layout(get<0, 1>(l), get<1>(l)), make_layout(get<0, 0>(l), get<0, 2>(l), get<2>(l)));
} else { // SM80
static_assert(decltype(size<0>(acc_layout))::value == 4);
static_assert(decltype(rank(acc_layout))::value == 3);
auto l = logical_divide(acc_layout, Shape<_2>{}); // ((2, 2), MMA_M, MMA_N)
return make_layout(make_layout(get<0, 1>(l), get<1>(l)), make_layout(get<0, 0>(l), get<2>(l)));
}
};
template<typename T, typename ReductionOp, int thread_group_width = 32>
__inline__ __device__ T WarpAllReduce(T val) {
ReductionOp op;
#pragma unroll
for (int mask = thread_group_width / 2; mask > 0; mask /= 2) {
val = op(val, __shfl_xor_sync(0xffffffff, val, mask));
}
return val;
}
template <int kPackSize, int knthreads>
__device__ inline int get_data_count(const float * src, const float limit_value) {
int count = 0;
#pragma unroll
for (int i = 0; i < kPackSize; i++) {
if (src[i] >= limit_value) {
count++;
}
}
count = BlockAllReduce<int, SumOp<int>, knthreads>(count);
return count;
}

View File

@@ -1,802 +0,0 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/extension.h"
#include "moba_decoder_attn_kernel.h"
#include "moba_attn/moba_attn.h"
template<bool Is_first, int kMiLen, typename Tensor0, typename Tensor1, typename T>
inline __device__ void softmax_rescale_o(Tensor0 &scores, Tensor1 &acc_o, const T *scores_max, const T *scores_max_prev, T * scores_sum, const float softmax_scale) {
if (Is_first) {
scale_apply_exp2<kMiLen>(scores, scores_max, scores_sum, softmax_scale);
} else {
Tensor acc_o_rowcol = make_tensor(acc_o.data(), convert_layout_acc_rowcol(acc_o.layout()));
#pragma unroll
for (int mi = 0; mi < kMiLen; ++mi) {
const float scores_scale = expf((scores_max_prev[mi] - scores_max[mi]) * softmax_scale);
scores_sum[mi] *= scores_scale;
#pragma unroll
for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) {
acc_o_rowcol(mi, ni) *= scores_scale;
}
}
scale_apply_exp2<kMiLen>(scores, scores_max, scores_sum, softmax_scale);
}
};
template<typename Kernel_traits, typename ParamType>
__global__ __launch_bounds__(Kernel_traits::kNThreads) void moba_decoder_attention_kernel(ParamType params) {
using cuteType = typename Kernel_traits::cuteType;
using ElementAccum = typename Kernel_traits::ElementAccum;
using CacheKV_traits = typename Kernel_traits::CacheKV_traits;
constexpr int32_t kHeadDim = Kernel_traits::kHeadDim;
constexpr int32_t kHeadDimKV = Kernel_traits::kHeadDimKV;
constexpr int32_t kBlockM = Kernel_traits::kBlockM;
constexpr int32_t kBlockSize = Kernel_traits::kBlockSize;
constexpr int32_t kGqaGroupSize = Kernel_traits::kGqaGroupSize;
constexpr int32_t kNWarps = Kernel_traits::kNWarps;
constexpr int32_t kTileN = Kernel_traits::kTileN;
constexpr int32_t kBlockN = kTileN * kBlockSize;
constexpr int32_t kDataBits = Kernel_traits::kDataBits;
constexpr int32_t kMiLen = (kGqaGroupSize + 7) / 8;
const int32_t bi = blockIdx.y;
const int32_t tidx = threadIdx.x;
const int32_t partition_idx = blockIdx.x;
const int32_t kv_head_idx = blockIdx.z;
const int32_t q_head_idx = kv_head_idx * kGqaGroupSize;
const int32_t seq_len = params.seq_lens_decoder[bi] == 0 ? 0 : params.seq_lens_decoder[bi] + 1;
const int32_t head_num = params.head_num;
const int32_t kv_head_num = params.kv_head_num;
const int32_t partition_num = (seq_len + kBlockN - 1) / kBlockN;
if (seq_len == 0 || partition_idx >= partition_num) {
return;
}
if (seq_len >= params.use_moba_seq_limit && params.qk_gate_topk_idx_ptr[(bi * kv_head_num + kv_head_idx) * Kernel_traits::kMaxN + partition_idx] == 0) {
return;
}
const int q_bias_offset = q_head_idx * kHeadDim;
cuteType * q_input = reinterpret_cast<cuteType *>(params.q_input) + params.cu_seq_q[bi] * head_num * kHeadDim;
Tensor gQ = make_tensor(
make_gmem_ptr(reinterpret_cast<const cuteType *>(q_input) + q_bias_offset),
Shape<Int<kBlockM>, Int<kHeadDim>>{},
Stride<Int<kHeadDim>, _1>{});
const int32_t block_idx = partition_idx * kTileN;
const int* block_table = params.block_table + bi * params.max_num_blocks_per_seq + block_idx;
const int32_t physical_block_number = block_table[0];
const int32_t cache_offset = (physical_block_number * kv_head_num + kv_head_idx) * kBlockSize * kHeadDimKV;
Tensor gK = make_tensor(
make_gmem_ptr(reinterpret_cast<const cuteType *>(params.cache_k) + cache_offset),
Shape<Int<kBlockSize>, Int<kHeadDimKV>>{},
Stride<Int<kHeadDimKV>, _1>{});
Tensor gV = make_tensor(
make_gmem_ptr(reinterpret_cast<const cuteType *>(params.cache_v) + cache_offset),
Shape<Int<kBlockSize>, Int<kHeadDimKV>>{},
Stride<Int<kHeadDimKV>, _1>{});
extern __shared__ char smem_[];
Tensor sQ = make_tensor(
make_smem_ptr(reinterpret_cast<cuteType *>(smem_)),
typename Kernel_traits::SmemLayoutQ{});
Tensor sQK = make_tensor(
sQ.data() + size(sQ),
typename Kernel_traits::SmemLayoutQK{});
Tensor sK = make_tensor(sQK.data() + size(sQK), typename Kernel_traits::SmemLayoutKV{});
Tensor sV = make_tensor(sK.data() + size(sK), typename Kernel_traits::SmemLayoutKV{});
Tensor sVt = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposed{});
Tensor sVtNoSwizzle = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposedNoSwizzle{});
__shared__ ElementAccum scores_warp[kNWarps][kMiLen * kBlockM];
auto gmem_tiled_copy_Q = typename Kernel_traits::GmemTiledCopyQ{};
auto gmem_thr_copy_Q = gmem_tiled_copy_Q.get_thread_slice(tidx);
auto gmem_tiled_copy_KV = typename Kernel_traits::GmemTiledCopyKV{};
auto gmem_thr_copy_KV = gmem_tiled_copy_KV.get_thread_slice(tidx);
Tensor tQgQ = gmem_thr_copy_Q.partition_S(gQ);
Tensor tQsQ = gmem_thr_copy_Q.partition_D(sQ);
Tensor tKgK = gmem_thr_copy_KV.partition_S(gK);
Tensor tKsK = gmem_thr_copy_KV.partition_D(sK);
Tensor tVgV = gmem_thr_copy_KV.partition_S(gV);
Tensor tVsV = gmem_thr_copy_KV.partition_D(sV);
Tensor cQ = make_identity_tensor(make_shape(kBlockM, kHeadDim));
Tensor tQcQ = gmem_thr_copy_Q.partition_S(cQ);
Tensor cKV = make_identity_tensor(make_shape(kBlockSize, kHeadDim));
Tensor tKVcKV = gmem_thr_copy_KV.partition_S(cKV);
typename Kernel_traits::TiledMma tiled_mma;
auto thr_mma = tiled_mma.get_thread_slice(tidx);
using SmemCopyAtom = typename Kernel_traits::SmemCopyAtom;
auto smem_tiled_copy_Q = make_tiled_copy_A(SmemCopyAtom{}, tiled_mma);
auto smem_thr_copy_Q = smem_tiled_copy_Q.get_thread_slice(tidx);
auto smem_tiled_copy_K = make_tiled_copy_B(SmemCopyAtom{}, tiled_mma);
auto smem_thr_copy_K = smem_tiled_copy_K.get_thread_slice(tidx);
auto smem_tiled_copy_V = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtomTransposed{}, tiled_mma);
auto smem_thr_copy_V = smem_tiled_copy_V.get_thread_slice(tidx);
auto smem_tiled_copy_O = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomO{}, tiled_mma);
auto smem_thr_copy_O = smem_tiled_copy_O.get_thread_slice(tidx);
Tensor tSsQ = smem_thr_copy_Q.partition_S(sQ);
Tensor tSrQ = thr_mma.partition_fragment_A(sQ);
Tensor tSsQK = smem_thr_copy_Q.partition_S(sQK);
Tensor tSrQK = thr_mma.partition_fragment_A(sQK);
Tensor tSsK = smem_thr_copy_K.partition_S(sK);
Tensor tSrK = thr_mma.partition_fragment_B(sK);
Tensor tOsVt = smem_thr_copy_V.partition_S(sVt);
Tensor tOrVt = thr_mma.partition_fragment_B(sVtNoSwizzle);
copy<false>(gmem_tiled_copy_Q, tQgQ, tQsQ, tQcQ, kGqaGroupSize);
cute::cp_async_fence();
cp_async_wait<0>();
const int32_t remain_seq_len = seq_len - partition_idx * kTileN * kBlockSize;
copy(gmem_tiled_copy_KV, tKgK, tKsK, tKVcKV);
cute::cp_async_fence();
const int32_t warp_id = tidx / 32;
const int32_t lane_id = tidx % 32;
const int32_t row = lane_id / 4;
const int32_t col = lane_id % 4;
const int row_idx = tidx / 4;
using scale_k_vec = Vec<cuteType, 32>;
using scale_v_vec = Vec<cuteType, 4>;
scale_k_vec scale_k;
scale_k_vec zp_k;
scale_v_vec scale_v;
scale_v_vec zp_v;
if constexpr (kDataBits == 4) {
scale_k = *reinterpret_cast<const scale_k_vec*>(params.cache_k_dequant_scale + kv_head_idx * kHeadDim + col * 32);
zp_k = *reinterpret_cast<const scale_k_vec*>(params.cache_k_zp + kv_head_idx * kHeadDim + col * 32);
scale_v = *reinterpret_cast<const scale_v_vec*>(params.cache_v_dequant_scale + kv_head_idx * kHeadDim + row_idx * 4);
zp_v = *reinterpret_cast<const scale_v_vec*>(params.cache_v_zp + kv_head_idx * kHeadDim + row_idx * 4);
}
Tensor acc_o = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kHeadDim>>{});
clear(acc_o);
Tensor acc_s = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kBlockSize>>{});
ElementAccum scores_max[kMiLen];
ElementAccum scores_max_prev[kMiLen];
ElementAccum scores_sum[kMiLen];
#pragma unroll
for (int mi = 0; mi < kMiLen; ++mi) {
scores_max[mi] = -INFINITY;
scores_sum[mi] = 0;
}
const int cache_offset_step = kv_head_num * kBlockSize * kHeadDimKV;
#pragma unroll
for (int n = 0; n < kTileN; ++n) {
const int cur_remain_seq_len = remain_seq_len - n * kBlockSize;
if (cur_remain_seq_len <= 0) {
break;
}
clear(acc_s);
cp_async_wait<0>();
__syncthreads();
if (n > 0) {
tVgV.data() = tVgV.data() + (block_table[n] - block_table[n - 1]) * cache_offset_step;
}
copy(gmem_tiled_copy_KV, tVgV, tVsV, tKVcKV);
cute::cp_async_fence();
if constexpr (kDataBits == 16) {
if (n == 0) {
gemm(acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_thr_copy_Q, smem_thr_copy_K, smem_tiled_copy_Q, smem_tiled_copy_K);
} else {
gemm<true>(acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_thr_copy_Q, smem_thr_copy_K, smem_tiled_copy_Q, smem_tiled_copy_K);
}
} else {
Tensor tSrKQuant = make_tensor<cuteType>(
Layout<
Shape<Shape<_2, _2>, Int<kBlockSize / 32>>,
Stride<Shape<_1, _2>, _4>>{});
if (n == 0) {
gemm_qk_quant<CacheKV_traits, cuteType, kHeadDim, kDataBits>(acc_s, tSrQ, tSsQ, tSrKQuant, sK, tiled_mma, smem_thr_copy_Q, smem_tiled_copy_Q, tidx, scale_k.data.elt, zp_k.data.elt);
} else {
gemm_qk_quant<CacheKV_traits, cuteType, kHeadDim, kDataBits, true>(acc_s, tSrQ, tSsQ, tSrKQuant, sK, tiled_mma, smem_thr_copy_Q, smem_tiled_copy_Q, tidx, scale_k.data.elt, zp_k.data.elt);
}
}
Tensor scores = make_tensor(acc_s.data(), convert_layout_acc_rowcol(acc_s.layout()));
if (partition_idx == partition_num - 1 && cur_remain_seq_len < kBlockSize) {
apply_mask<kMiLen>(scores, warp_id, col, cur_remain_seq_len);
}
#pragma unroll
for (int mi = 0; mi < kMiLen; ++mi) {
scores_max_prev[mi] = scores_max[mi];
}
reduce_max<kMiLen>(scores, scores_max);
if (col == 0) {
scores_warp[warp_id][row] = scores_max[0];
if constexpr (kMiLen > 1) {
scores_warp[warp_id][row + 8] = scores_max[1];
}
}
__syncthreads();
MaxOp<ElementAccum> max_op;
if (tidx < kGqaGroupSize) {
float cur_max = scores_warp[0][tidx];
#pragma unroll
for (uint32_t i = 1; i < kNWarps; ++i) {
cur_max = max_op(scores_warp[i][tidx], cur_max);
}
scores_warp[0][tidx] = cur_max;
}
cp_async_wait<0>();
__syncthreads();
if (cur_remain_seq_len > kBlockSize && n < kTileN - 1) {
tKgK.data() = tKgK.data() + (block_table[n + 1] - block_table[n]) * cache_offset_step;
copy(gmem_tiled_copy_KV, tKgK, tKsK, tKVcKV);
cute::cp_async_fence();
}
#pragma unroll
for (int mi = 0; mi < kMiLen; ++mi) {
scores_max[mi] = scores_warp[0][row + mi * 8];
}
if (n == 0) {
softmax_rescale_o<true, kMiLen>(scores, acc_o, scores_max, scores_max_prev, scores_sum, params.inv_sqrt_dh);
} else {
softmax_rescale_o<false, kMiLen>(scores, acc_o, scores_max, scores_max_prev, scores_sum, params.inv_sqrt_dh);
}
Tensor rS = convert_type<cuteType>(acc_s);
Tensor trQK = smem_thr_copy_O.retile_S(rS);
Tensor tsQK = smem_thr_copy_O.partition_D(sQK);
cute::copy(smem_tiled_copy_O, trQK, tsQK);
__syncthreads();
if constexpr (kDataBits == 16) {
gemm(acc_o, tSrQK, tOrVt, tSsQK, tOsVt, tiled_mma, smem_thr_copy_Q, smem_thr_copy_V, smem_tiled_copy_Q, smem_tiled_copy_V);
} else {
Tensor tSrVQuant = make_tensor<cuteType>(
Layout<
Shape<_4, Shape<_2, _2>>,
Stride<_1, Shape<_4, _8>>>{});
gemm_value_quant<CacheKV_traits, cuteType, kHeadDim, kDataBits>(acc_o, tSrQK, tSsQK, tSrVQuant, sV, tiled_mma, smem_thr_copy_Q, smem_tiled_copy_Q, tidx, scale_v.data.elt, zp_v.data.elt);
}
}
const uint32_t pack_max_partition_num = (params.max_num_partitions + 3) / 4 * 4;
uint32_t max_sum_offset = bi * pack_max_partition_num * head_num + (tidx + q_head_idx) * pack_max_partition_num + partition_idx;
if (tidx < kGqaGroupSize) {
params.maxs[max_sum_offset] = scores_warp[0][tidx] * params.inv_sqrt_dh;
}
SumOp<ElementAccum> sum_op;
#pragma unroll
for (int mi = 0; mi < kMiLen; ++mi) {
scores_sum[mi] = Allreduce<4>::run(scores_sum[mi], sum_op);
}
__syncthreads();
if (col == 0) {
scores_warp[warp_id][row] = scores_sum[0];
if constexpr (kMiLen > 1) {
scores_warp[warp_id][row + 8] = scores_sum[1];
}
}
Tensor rO = convert_type<cuteType>(acc_o);
Tensor taccOrO = smem_thr_copy_O.retile_S(rO);
Tensor taccOsO = smem_thr_copy_O.partition_D(sQ);
cute::copy(smem_tiled_copy_O, taccOrO, taccOsO);
__syncthreads();
if (tidx < kGqaGroupSize) {
float cur_sum = scores_warp[0][tidx];
#pragma unroll
for (uint32_t i = 1; i < kNWarps; ++i) {
cur_sum = sum_op(scores_warp[i][tidx], cur_sum);
}
scores_warp[0][tidx] = cur_sum;
}
Tensor gO = make_tensor(
make_gmem_ptr(reinterpret_cast<cuteType *>(params.partition_attn_out) + ((bi * params.max_num_partitions + partition_idx) * head_num + q_head_idx)* kHeadDim),
Shape<Int<kBlockM>, Int<kHeadDim>>{},
Stride<Int<kHeadDim>, _1>{});
auto gmem_tiled_copy_O = typename Kernel_traits::GmemTiledCopyO{};
auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(tidx);
Tensor tOsO = gmem_thr_copy_O.partition_S(sQ);
Tensor tOgO = gmem_thr_copy_O.partition_D(gO);
constexpr int32_t copy_size = kGqaGroupSize * 16;
__syncthreads();
if (tidx < copy_size) {
cute::copy(gmem_tiled_copy_O, tOsO(_, 0, _), tOgO(_, 0, _));
}
if constexpr (kMiLen > 1) {
if (tidx < copy_size - 128) {
cute::copy(gmem_tiled_copy_O, tOsO(_, 1, _), tOgO(_, 1, _));
}
}
if (tidx < kGqaGroupSize) {
params.sums[max_sum_offset] = scores_warp[0][tidx];
}
}
template<typename Kernel_traits, typename ParamType>
inline __device__ float caluate_logit_scale(const int partition_num, const int pack_max_partition_num, ParamType &params, char * shared_mem, const int seq_len, const int *qk_gate_topk_idx_ptr) {
constexpr int32_t kNFloatPacksize = 16 / sizeof(float);
constexpr int32_t kNReduceThreads = Kernel_traits::kNReduceThreads;
const int32_t bi = blockIdx.z;
const int32_t tidx = threadIdx.x;
const int32_t head_idx = blockIdx.y;
const int32_t head_num = params.head_num;
using float_vec = Vec<float, kNFloatPacksize>;
const int32_t offset = bi * head_num * pack_max_partition_num + head_idx * pack_max_partition_num;
float* shared_max_logits = reinterpret_cast<float*>(shared_mem);
const float* max_logits_ptr = params.maxs + offset;
float global_max_logit = -FLT_MAX;
int32_t idx = tidx * kNFloatPacksize;
#pragma unroll
for (; idx <= partition_num - kNFloatPacksize; idx += kNReduceThreads * kNFloatPacksize) {
float_vec cur_max = *reinterpret_cast<const float_vec*>(max_logits_ptr + idx);
#pragma unroll
for (int32_t j = 0; j < kNFloatPacksize; ++j) {
if (seq_len >= params.use_moba_seq_limit) {
if (qk_gate_topk_idx_ptr[idx + j] != 0) {
global_max_logit = fmaxf(global_max_logit, cur_max.data.elt[j]);
}
} else {
global_max_logit = fmaxf(global_max_logit, cur_max.data.elt[j]);
}
}
cur_max.store_to(shared_max_logits + idx);
}
const int32_t packed_data_num = partition_num / kNFloatPacksize * kNFloatPacksize;
idx = packed_data_num + tidx;
#pragma unroll
for (; idx < partition_num; idx += kNReduceThreads) {
if (seq_len >= params.use_moba_seq_limit) {
if (qk_gate_topk_idx_ptr[idx] != 0) {
float cur_max = max_logits_ptr[idx];
global_max_logit = fmaxf(global_max_logit, cur_max);
shared_max_logits[idx] = cur_max;
}
} else {
float cur_max = max_logits_ptr[idx];
global_max_logit = fmaxf(global_max_logit, cur_max);
shared_max_logits[idx] = cur_max;
}
}
__syncthreads();
global_max_logit = BlockAllReduce<float, MaxOp<float>, kNReduceThreads>(global_max_logit);
float* share_sum_scale = reinterpret_cast<float*>(shared_mem + sizeof(float) * pack_max_partition_num);
const float* exp_sums_ptr = params.sums + offset;
float global_exp_sum = 0.0f;
idx = tidx * kNFloatPacksize;
#pragma unroll
for (; idx <= partition_num - kNFloatPacksize; idx += kNReduceThreads * kNFloatPacksize) {
float_vec share_max = *reinterpret_cast<const float_vec*>(shared_max_logits + idx);
#pragma unroll
for (int32_t j = 0; j < kNFloatPacksize; ++j) {
if (seq_len >= params.use_moba_seq_limit) {
if (qk_gate_topk_idx_ptr[idx + j] != 0) {
float exp_sub_max = expf(share_max.data.elt[j] - global_max_logit);
float rescaled_exp_sum = exp_sums_ptr[idx + j] * exp_sub_max;
global_exp_sum += rescaled_exp_sum;
share_max.data.elt[j] = exp_sub_max;
}
} else {
float exp_sub_max = expf(share_max.data.elt[j] - global_max_logit);
float rescaled_exp_sum = exp_sums_ptr[idx + j] * exp_sub_max;
global_exp_sum += rescaled_exp_sum;
share_max.data.elt[j] = exp_sub_max;
}
}
share_max.store_to(share_sum_scale + idx);
}
idx = packed_data_num + tidx;
#pragma unroll
for (; idx < partition_num; idx += kNReduceThreads) {
if (seq_len >= params.use_moba_seq_limit) {
if (qk_gate_topk_idx_ptr[idx] != 0) {
float share_max = shared_max_logits[idx];
float exp_sub_max = expf(share_max - global_max_logit);
float rescaled_exp_sum = exp_sums_ptr[idx] * exp_sub_max;
global_exp_sum += rescaled_exp_sum;
share_sum_scale[idx] = exp_sub_max;
}
} else {
float share_max = shared_max_logits[idx];
float exp_sub_max = expf(share_max - global_max_logit);
float rescaled_exp_sum = exp_sums_ptr[idx] * exp_sub_max;
global_exp_sum += rescaled_exp_sum;
share_sum_scale[idx] = exp_sub_max;
}
}
__syncthreads();
global_exp_sum = BlockAllReduce<float, SumOp<float>, kNReduceThreads>(global_exp_sum);
const float inv_global_exp_sum = fdividef(1.0f, global_exp_sum + 1e-6f);
return inv_global_exp_sum;
}
template<typename Kernel_traits, typename ParamType>
__global__ void __launch_bounds__(Kernel_traits::kNReduceThreads) moba_decoder_attention_merge_kernel(ParamType params) {
using cuteType = typename Kernel_traits::cuteType;
constexpr int32_t kBlockN = Kernel_traits::kTileN * Kernel_traits::kBlockSize;
constexpr int32_t kNReducePacksize = 16 / sizeof(cuteType);
constexpr int32_t kNFloatPacksize = 16 / sizeof(float);
constexpr int32_t kNReduceWarps = Kernel_traits::kNReduceWarps;
constexpr int32_t kHeadDim = Kernel_traits::kHeadDim;
const int32_t bi = blockIdx.z;
const int32_t headdim_idx = kNReducePacksize * kNReduceWarps * blockIdx.x;
const int32_t tidx = threadIdx.x;
const int32_t head_idx = blockIdx.y;
const int32_t warp_id = tidx / 32;
const int32_t lane_id = tidx % 32;
const int32_t seq_len = params.seq_lens_decoder[bi] + 1;
const int32_t head_num = params.head_num;
using pack_half = typename PackedHalf<cuteType>::Type;
if (params.seq_lens_decoder[bi] == 0) {
return;
}
extern __shared__ char shared_mem[];
const int32_t partition_num = (seq_len + kBlockN - 1) / kBlockN;
const int32_t pack_max_partition_num = (params.max_num_partitions + kNFloatPacksize - 1) / kNFloatPacksize * kNFloatPacksize;
float* share_sum_scale = reinterpret_cast<float*>(shared_mem + sizeof(float) * pack_max_partition_num);
constexpr int32_t kGqaGroupSize = Kernel_traits::kGqaGroupSize;
const int kv_head_idx = head_idx / Kernel_traits::kGqaGroupSize;
const int * qk_gate_topk_idx_ptr = params.qk_gate_topk_idx_ptr + (bi * params.kv_head_num + kv_head_idx) * Kernel_traits::kMaxN;
float inv_global_exp_sum = caluate_logit_scale<Kernel_traits>(partition_num, pack_max_partition_num, params, shared_mem, seq_len, qk_gate_topk_idx_ptr);
using T_vec = Vec<cuteType, kNReducePacksize>;
cuteType* partition_attn_out = reinterpret_cast<cuteType*>(params.partition_attn_out) + bi * head_num * params.max_num_partitions * kHeadDim + head_idx * kHeadDim + headdim_idx;
Vec<float, kNReducePacksize> acc;
acc.set_zero();
#pragma unroll
for (int idx = lane_id; idx < partition_num; idx += 32) {
if (seq_len >= params.use_moba_seq_limit && qk_gate_topk_idx_ptr[idx] == 0) {
continue;
}
T_vec sub_logits = *reinterpret_cast<T_vec*>(&partition_attn_out[idx * head_num * kHeadDim + warp_id * kNReducePacksize]);
float scale = share_sum_scale[idx];
#pragma unroll
for (int k = 0; k < kNReducePacksize; ++k) {
acc.data.elt[k] += static_cast<float>(sub_logits.data.elt[k]) * scale;
}
}
__syncthreads();
T_vec out;
#pragma unroll
for (int k = 0; k < kNReducePacksize; ++k) {
out.data.elt[k] = static_cast<cuteType>(WarpAllReduce<float, SumOp<float>>(acc.data.elt[k]) * inv_global_exp_sum);
}
const int ori_token_idx = params.cu_seq_q[bi];
cuteType * attn_out = reinterpret_cast<cuteType *>(params.attn_out) + ori_token_idx * head_num * kHeadDim + head_idx * kHeadDim + headdim_idx + warp_id * kNReducePacksize;
if (lane_id == 0) {
out.store_to(attn_out);
}
}
template<typename Kernel_traits, typename ParamType>
void run_moba_decoder_attn(ParamType &params, cudaStream_t stream) {
dim3 grid;
grid.x = params.max_num_partitions;
grid.y = params.batch_size;
grid.z = params.kv_head_num;
constexpr int smem_size = Kernel_traits::kShareMemSize;
constexpr auto kernel = &moba_decoder_attention_kernel<Kernel_traits, ParamType>;
if (smem_size >= 48 * 1024) {
cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
}
kernel<<<grid, Kernel_traits::kNThreads, smem_size, stream>>>(params);
int32_t reduce_shared_mem_size = 2 * (params.max_num_partitions + 4) * sizeof(float);
constexpr int32_t pack_size = 16 / sizeof(typename Kernel_traits::cuteType);
static_assert(Kernel_traits::kHeadDim % pack_size == 0);
static_assert((Kernel_traits::kHeadDim / Kernel_traits::kNReduceWarps) % pack_size == 0);
grid.x = Kernel_traits::kHeadDim / Kernel_traits::kNReduceWarps / pack_size;
grid.y = params.head_num;
grid.z = params.batch_size;
auto reduce_kernel = &moba_decoder_attention_merge_kernel<Kernel_traits, ParamType>;
if (reduce_shared_mem_size >= 48 * 1024) {
cudaFuncSetAttribute(
reduce_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, reduce_shared_mem_size);
}
reduce_kernel<<<grid, Kernel_traits::kNReduceThreads, reduce_shared_mem_size, stream>>>(params);
}
template<typename cute_type, int kCacheBits, int kBlockN, int kMaxN, typename ParamType>
void run_moba_decoder_attn_hdim128(ParamType &params, cudaStream_t stream) {
const int gqaGroupSize = params.head_num / params.kv_head_num;
using CacheKVTraits = CacheKV_quant_traits<cute_type, kCacheBits>;
constexpr int kTileN = kBlockN / CacheKVTraits::kBlockSize;
switch (gqaGroupSize) {
case 12: {
run_moba_decoder_attn<moba_decoder_attn_kernel_traits<12, kTileN, kMaxN,CacheKVTraits>>(params, stream);
break;
}
case 8: {
run_moba_decoder_attn<moba_decoder_attn_kernel_traits<8, kTileN, kMaxN,CacheKVTraits>>(params, stream);
break;
}
case 7: {
run_moba_decoder_attn<moba_decoder_attn_kernel_traits<7, kTileN, kMaxN,CacheKVTraits>>(params, stream);
break;
}
case 6: {
run_moba_decoder_attn<moba_decoder_attn_kernel_traits<6, kTileN, kMaxN,CacheKVTraits>>(params, stream);
break;
}
case 5: {
run_moba_decoder_attn<moba_decoder_attn_kernel_traits<5, kTileN, kMaxN,CacheKVTraits>>(params, stream);
break;
}
case 4: {
run_moba_decoder_attn<moba_decoder_attn_kernel_traits<4, kTileN, kMaxN,CacheKVTraits>>(params, stream);
break;
}
default: {
PADDLE_THROW(phi::errors::Unimplemented(
"DecoderBlockAttention not implemented for gqaGroupSize = %d", gqaGroupSize));
}
}
}
template <typename T>
void DispatchMobaDecoderAttn(
const paddle::Tensor& q_input,
const paddle::Tensor& seq_len_encoder,
const paddle::Tensor& seq_len_decoder,
const paddle::Tensor& cu_seq_q,
const paddle::Tensor& cache_k,
const paddle::Tensor& cache_v,
const paddle::Tensor& block_tables,
const paddle::Tensor& k_block_means,
const paddle::Tensor& out,
const paddle::Tensor& qk_gate_topk_idx,
const paddle::optional<paddle::Tensor>& cache_k_quant_scale,
const paddle::optional<paddle::Tensor>& cache_v_quant_scale,
const paddle::optional<paddle::Tensor>& cache_k_dequant_scale,
const paddle::optional<paddle::Tensor>& cache_v_dequant_scale,
const paddle::optional<paddle::Tensor>& cache_k_zero_points,
const paddle::optional<paddle::Tensor>& cache_v_zero_points,
const int head_num,
const int kv_head_num,
const int head_dim,
const int max_seq_q,
const int max_seq_k,
const int batch_size,
const int max_input_length,
const int use_moba_seq_limit,
const std::string &cache_quant_type_str) {
using cute_type = typename cuteType<T>::type;
const int kMobaBlockSize = 128;
const int kMaxN = 1024;
constexpr int max_seq_per_block = kMobaBlockSize;
moba_decoder_attn_params<cute_type> params;
memset(&params, 0, sizeof(params));
const uint32_t max_num_partitions = (max_seq_k + max_seq_per_block) / max_seq_per_block;
assert(head_dim == 128);
paddle::Tensor maxs = paddle::empty({batch_size, head_num, (max_num_partitions + 3) / 4 * 4}, paddle::DataType::FLOAT32, q_input.place());
paddle::Tensor sums = paddle::empty({batch_size, head_num, (max_num_partitions + 3) / 4 * 4}, paddle::DataType::FLOAT32, q_input.place());
paddle::Tensor partition_attn_out = paddle::empty({batch_size, max_num_partitions, head_num, head_dim}, q_input.dtype(), q_input.place());
params.q_input = reinterpret_cast<cute_type *>(const_cast<T*>(q_input.data<T>()));
params.attn_out = reinterpret_cast<cute_type *>(const_cast<T*>(out.data<T>()));
params.seq_lens_encoder = const_cast<int*>(seq_len_encoder.data<int>());
params.seq_lens_decoder = const_cast<int*>(seq_len_decoder.data<int>());
params.block_table = const_cast<int*>(block_tables.data<int>());
params.max_input_length = max_input_length;
params.head_num = head_num;
params.kv_head_num = kv_head_num;
params.max_num_blocks_per_seq = block_tables.dims()[1];
params.batch_size = batch_size;
params.inv_sqrt_dh = 1.0f / std::sqrt(head_dim);
params.max_num_partitions = max_num_partitions;
params.maxs = reinterpret_cast<float*>(maxs.data<float>());
params.sums = reinterpret_cast<float*>(sums.data<float>());
params.partition_attn_out = reinterpret_cast<cute_type *>(partition_attn_out.data<T>());
params.qk_gate_topk_idx_ptr = const_cast<int*>(qk_gate_topk_idx.data<int>());
params.use_moba_seq_limit = use_moba_seq_limit;
params.cu_seq_q = const_cast<int*>(cu_seq_q.data<int>());
if (cache_quant_type_str == "none") {
params.cache_k = reinterpret_cast<cute_type *>(const_cast<T*>(cache_k.data<T>()));
params.cache_v = reinterpret_cast<cute_type *>(const_cast<T*>(cache_v.data<T>()));
run_moba_decoder_attn_hdim128<cute_type, 16, max_seq_per_block, kMaxN>(params, q_input.stream());
} else {
params.cache_k = const_cast<uint8_t*>(cache_k.data<uint8_t>());
params.cache_v = const_cast<uint8_t*>(cache_v.data<uint8_t>());
params.cache_k_quant_scale = reinterpret_cast<cute_type *>(const_cast<T*>(cache_k_quant_scale.get().data<T>()));
params.cache_v_quant_scale = reinterpret_cast<cute_type *>(const_cast<T*>(cache_v_quant_scale.get().data<T>()));
params.cache_k_dequant_scale = reinterpret_cast<cute_type *>(const_cast<T*>(cache_k_dequant_scale.get().data<T>()));
params.cache_v_dequant_scale = reinterpret_cast<cute_type *>(const_cast<T*>(cache_v_dequant_scale.get().data<T>()));
params.cache_k_zp = reinterpret_cast<cute_type *>(const_cast<T*>(cache_k_zero_points.get().data<T>()));
params.cache_v_zp = reinterpret_cast<cute_type *>(const_cast<T*>(cache_v_zero_points.get().data<T>()));
if (cache_quant_type_str == "cache_int8_zp") {
run_moba_decoder_attn_hdim128<cute_type, 8, max_seq_per_block, kMaxN>(params, q_input.stream());
} else if (cache_quant_type_str == "cache_int4_zp") {
run_moba_decoder_attn_hdim128<cute_type, 4, max_seq_per_block, kMaxN>(params, q_input.stream());
} else {
PADDLE_THROW(phi::errors::Unimplemented(
"GQA Attention not implemented for cache_quant_type_str = %s", cache_quant_type_str.c_str()));
}
}
}
void MobaDecoderAttn(
const paddle::Tensor& q_input,
const paddle::Tensor& seq_len_encoder,
const paddle::Tensor& seq_len_decoder,
const paddle::Tensor& cu_seq_q,
const paddle::Tensor& cache_k,
const paddle::Tensor& cache_v,
const paddle::Tensor& block_tables,
const paddle::Tensor& k_block_means,
const paddle::Tensor& out,
const paddle::Tensor& qk_gate_topk_idx,
const paddle::optional<paddle::Tensor>& cache_k_quant_scale,
const paddle::optional<paddle::Tensor>& cache_v_quant_scale,
const paddle::optional<paddle::Tensor>& cache_k_dequant_scale,
const paddle::optional<paddle::Tensor>& cache_v_dequant_scale,
const paddle::optional<paddle::Tensor>& cache_k_zero_points,
const paddle::optional<paddle::Tensor>& cache_v_zero_points,
const int head_num,
const int kv_head_num,
const int head_dim,
const int max_input_length,
const int use_moba_seq_limit,
const int max_seq_q,
const int max_seq_k,
const std::string &cache_quant_type_str) {
const int batch_size = block_tables.dims()[0];
if (q_input.dtype() == paddle::DataType::FLOAT16) {
return DispatchMobaDecoderAttn<phi::dtype::float16>(
q_input,
seq_len_encoder,
seq_len_decoder,
cu_seq_q,
cache_k,
cache_v,
block_tables,
k_block_means,
out,
qk_gate_topk_idx,
cache_k_quant_scale,
cache_v_quant_scale,
cache_k_dequant_scale,
cache_v_dequant_scale,
cache_k_zero_points,
cache_v_zero_points,
head_num,
kv_head_num,
head_dim,
max_seq_q,
max_seq_k,
batch_size,
max_input_length,
use_moba_seq_limit,
cache_quant_type_str);
} else if (q_input.dtype() == paddle::DataType::BFLOAT16) {
return DispatchMobaDecoderAttn<phi::dtype::bfloat16>(
q_input,
seq_len_encoder,
seq_len_decoder,
cu_seq_q,
cache_k,
cache_v,
block_tables,
k_block_means,
out,
qk_gate_topk_idx,
cache_k_quant_scale,
cache_v_quant_scale,
cache_k_dequant_scale,
cache_v_dequant_scale,
cache_k_zero_points,
cache_v_zero_points,
head_num,
kv_head_num,
head_dim,
max_seq_q,
max_seq_k,
batch_size,
max_input_length,
use_moba_seq_limit,
cache_quant_type_str);
}
}

View File

@@ -1,225 +0,0 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/extension.h"
#include "cute/tensor.hpp"
#include "cute/algorithm/copy.hpp"
#include "cute/algorithm/gemm.hpp"
#include "../moba_attn_utils.hpp"
using namespace cute;
template <typename T>
struct moba_decoder_attn_params {
T *__restrict__ q_input;
void *__restrict__ cache_k;
void *__restrict__ cache_v;
T *__restrict__ attn_out;
T *__restrict__ partition_attn_out;
T *__restrict__ cache_k_dequant_scale;
T *__restrict__ cache_v_dequant_scale;
T *__restrict__ cache_k_quant_scale;
T *__restrict__ cache_v_quant_scale;
T *__restrict__ cache_k_zp;
T *__restrict__ cache_v_zp;
int * __restrict__ cu_seq_q;
float * sums;
float * maxs;
int * seq_lens_encoder;
int * seq_lens_decoder;
int * block_table;
int max_input_length;
int max_seq_len;
int head_num;
int kv_head_num;
int max_num_blocks_per_seq;
float scale_softmax;
int batch_size;
int max_num_partitions;
float inv_sqrt_dh;
int *qk_gate_topk_idx_ptr;
int use_moba_seq_limit;
};
template <typename cute_type_, int DataBits_>
struct CacheKV_quant_traits {
using cuteType = cute_type_;
static constexpr int kDataBits = DataBits_;
static constexpr int kBlockSize = 64;
static constexpr int kHeadDim = 128;
static constexpr int kBlockKSmem = 64;
using SmemLayoutAtomQ = decltype(
composition(Swizzle<3, 3, 3>{},
Layout<
Shape<Int<8>, Int<kBlockKSmem>>,
Stride<Int<kBlockKSmem>, _1>>{}));
using SmemLayoutKV = decltype(tile_to_shape(
SmemLayoutAtomQ{},
Shape<Int<kBlockSize>, Int<kHeadDim>>{}));
static constexpr int kNWarps = 4;
static constexpr int kNThreads = kNWarps * 32;
static constexpr int kThreadPerValue = 16 / sizeof(cuteType);
static constexpr int kThreadsPerRow = kHeadDim / kThreadPerValue;
using GmemLayoutAtom = Layout<
Shape <Int<kNThreads / kThreadsPerRow>, Int<kThreadsPerRow>>,
Stride<Int<kThreadsPerRow>, _1>>;
using GmemTiledCopyQ = decltype(
make_tiled_copy(Copy_Atom<
SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>, cuteType>{},
GmemLayoutAtom{},
Layout<Shape<_1, Int<kThreadPerValue>>>{}));
using MMA_Atom_Arch = std::conditional_t<
std::is_same_v<cuteType, cutlass::half_t>,
MMA_Atom<SM80_16x8x16_F32F16F16F32_TN>,
MMA_Atom<SM80_16x8x16_F32BF16BF16F32_TN>
>;
using ValLayoutMNK = Layout<Shape<_1,_4,_1>>;
using PermutationMNK = Tile<_16, Int<16 * kNWarps>, _16>;
using TiledMma = TiledMMA<
MMA_Atom_Arch,
ValLayoutMNK,
PermutationMNK>;
using SmemCopyAtom = Copy_Atom<SM75_U32x4_LDSM_N, cuteType>;
using SmemLayoutAtomVtransposed = decltype(
composition(Swizzle<3, 3, 3>{},
Layout<Shape<Int<kBlockKSmem>, Int<kBlockSize>>,
Stride<_1, Int<kBlockKSmem>>>{}));
using SmemLayoutVtransposed = decltype(tile_to_shape(
SmemLayoutAtomVtransposed{},
Shape<Int<kHeadDim>, Int<kBlockSize>>{}));
using SmemLayoutVtransposedNoSwizzle = decltype(get_nonswizzle_portion(SmemLayoutVtransposed{}));
using SmemCopyAtomTransposed = Copy_Atom<SM75_U16x8_LDSM_T, cuteType>;
static constexpr int kShareMemSize = size(SmemLayoutKV{}) * 2 * sizeof(cuteType);
};
template <int kGqaGroupSize_, int kTileN_, int kMaxN_, typename CacheKV_traits_>
struct moba_decoder_attn_kernel_traits {
using ElementAccum = float;
using CacheKV_traits = CacheKV_traits_;
using cuteType = typename CacheKV_traits::cuteType;
static constexpr int kDataBits = CacheKV_traits::kDataBits;
static constexpr int kTileN = kTileN_;
static constexpr int kMaxN = kMaxN_;
static constexpr int kGqaGroupSize = kGqaGroupSize_;
static constexpr int kHeadDim = CacheKV_traits::kHeadDim;
static constexpr int kHeadDimKV = kHeadDim / (16 / kDataBits);
static constexpr int kMinGemmM = 16;
static constexpr int kBlockM = (kGqaGroupSize + kMinGemmM - 1) / kMinGemmM * kMinGemmM;
static constexpr int kBlockSize = CacheKV_traits::kBlockSize;
static_assert(kGqaGroupSize <= 16);
static constexpr int32_t kNWarps = CacheKV_traits::kNWarps;
static constexpr int kBlockKSmem = CacheKV_traits::kBlockKSmem;
static constexpr int kBlockKVSmem = kHeadDimKV <= 64 ? kHeadDimKV : 64;
static_assert(kHeadDim % kBlockKSmem == 0);
static constexpr int kNReduceWarps = 4;
static constexpr int kNReduceThreads = kNReduceWarps * 32;
using SmemLayoutAtomQ = typename CacheKV_traits::SmemLayoutAtomQ;
using SmemLayoutQ = decltype(tile_to_shape(
SmemLayoutAtomQ{},
Shape<Int<kBlockM>, Int<kHeadDim>>{}));
using SmemLayoutQK = decltype(tile_to_shape(
SmemLayoutAtomQ{},
Shape<Int<kBlockM>, Int<kBlockSize>>{}));
using SmemLayoutAtomKV = decltype(
composition(Swizzle<3, 3, 3>{},
Layout<
Shape<Int<8>, Int<kBlockKVSmem>>,
Stride<Int<kBlockKVSmem>, _1>>{}));
using SmemLayoutKV_ = decltype(tile_to_shape(
SmemLayoutAtomKV{},
Shape<Int<kBlockSize>, Int<kHeadDimKV>>{}));
using SmemLayoutKV = std::conditional_t<
kDataBits == 16,
SmemLayoutKV_,
decltype(get_nonswizzle_portion(SmemLayoutKV_{}))
>;
constexpr static int kBlockKVSize = kDataBits == 4 ? 32 : kBlockSize;
using SmemLayoutAtomVtransposed = decltype(
composition(Swizzle<3, 3, 3>{},
Layout<Shape<Int<kBlockKSmem>, Int<kBlockKVSize>>,
Stride<_1, Int<kBlockKSmem>>>{}));
using SmemLayoutVtransposed = decltype(tile_to_shape(
SmemLayoutAtomVtransposed{},
Shape<Int<kHeadDim>, Int<kBlockKVSize>>{}));
using SmemLayoutVtransposedNoSwizzle = decltype(get_nonswizzle_portion(SmemLayoutVtransposed{}));
static constexpr int kThreadsPerRow = CacheKV_traits::kThreadsPerRow;
static constexpr int kThreadsKVPerRow = kThreadsPerRow / (16 / kDataBits);
static constexpr int kNThreads = CacheKV_traits::kNThreads;
using GmemKVLayoutAtom = Layout<
Shape<Int<kNThreads / kThreadsKVPerRow>, Int<kThreadsKVPerRow>>,
Stride<Int<kThreadsKVPerRow>, _1>>;
using SmemCopyAtom = typename CacheKV_traits::SmemCopyAtom;
using TiledMma = typename CacheKV_traits::TiledMma;
static constexpr int kThreadPerValue = CacheKV_traits::kThreadPerValue;
using GmemTiledCopyQ = typename CacheKV_traits::GmemTiledCopyQ;
using GmemLayoutAtom = typename CacheKV_traits::GmemLayoutAtom;
using GmemTiledCopyKV = decltype(
make_tiled_copy(Copy_Atom<SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>, cuteType>{},
GmemKVLayoutAtom{},
Layout<Shape<_1, Int<kThreadPerValue>>>{}));
using SmemCopyAtomTransposed = typename CacheKV_traits::SmemCopyAtomTransposed;
using GmemTiledCopyO = decltype(
make_tiled_copy(Copy_Atom<DefaultCopy, cuteType>{},
GmemLayoutAtom{},
Layout<Shape<_1, Int<kThreadPerValue>>>{}));
using SmemCopyAtomO = Copy_Atom<DefaultCopy, cuteType>;
using SmemLayoutAtomO = decltype(
composition(Swizzle<3, 3, 3>{},
Layout<
Shape<Int<8>, Int<kBlockKSmem>>,
Stride<Int<kBlockKSmem>, _1>>{}));
using SmemLayoutO = decltype(tile_to_shape(
SmemLayoutAtomO{},
Shape<Int<kBlockM>, Int<kHeadDim>>{}));
static constexpr int kShareMemSize = (size(SmemLayoutQ{}) + size(SmemLayoutQK{}) + size(SmemLayoutKV{}) * 2) * sizeof(cuteType);
};

View File

@@ -1,189 +0,0 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/extension.h"
#include "../moba_attn_utils.hpp"
#include "moba_attn/moba_attn.h"
template <typename T, int kBlockSize, int kHeadDim, int moba_block_size, int kMaxN>
__global__ void moba_decoder_attn_write_c16(
const T * qkv_out,
const T * qkv_bias,
T * q_input,
const int * cu_seq_q,
const int * cu_seq_k,
const int * seq_len_encoder,
const int * seq_len_decoder,
T * cache_k,
T * cache_v,
const int * block_tables,
const float * rope_sin_cos,
T *k_block_means,
const int head_num,
const int kv_head_num,
const int max_blocks_per_seq,
const int max_input_length) {
int bidh = blockIdx.x;
const int bidb = blockIdx.y;
const int tidx = threadIdx.x;
const int seq_len = seq_len_decoder[bidb];
if (seq_len == 0) {
return;
}
constexpr int kPackSize = 4;
using SrcType = Vec<T, kPackSize>;
using rope_type = Vec<float, kPackSize / 2>;
SrcType src, bias, k_prev;
rope_type sin, cos;
const int bias_idx = bidh * kHeadDim + tidx * kPackSize;
const int ori_token_idx = cu_seq_q[bidb];
src.load_from(qkv_out + ori_token_idx * (head_num + 2 * kv_head_num) * kHeadDim + bias_idx);
if (qkv_bias != nullptr) {
bias.load_from(qkv_bias + bias_idx);
src.add(bias);
}
const int32_t *block_table_now = block_tables + bidb * max_blocks_per_seq;
const int32_t physical_block_number = block_table_now[seq_len / kBlockSize];
if (bidh < head_num) {
const float * cos_rope = rope_sin_cos + seq_len * (kHeadDim / 2) + tidx * (kPackSize / 2);
const float * sin_rope = cos_rope + max_input_length * (kHeadDim / 2);
sin.load_from(sin_rope);
cos.load_from(cos_rope);
apply_rotary_embedding<T, kPackSize>(src, cos, sin);
src.store_to(q_input + cu_seq_q[bidb] * head_num * kHeadDim + bias_idx);
} else if (bidh < head_num + kv_head_num) {
bidh -= head_num;
const int token_in_blocks = seq_len % kBlockSize;
const float * cos_rope = rope_sin_cos + seq_len * (kHeadDim / 2) + tidx * (kPackSize / 2);
const float * sin_rope = cos_rope + max_input_length * (kHeadDim / 2);
sin.load_from(sin_rope);
cos.load_from(cos_rope);
apply_rotary_embedding<T, kPackSize>(src, cos, sin);
T * cache = cache_k + physical_block_number * kv_head_num * kBlockSize * kHeadDim + bidh * kBlockSize * kHeadDim + tidx * kPackSize + token_in_blocks * kHeadDim;
src.store_to(cache);
const int seq_len_block = seq_len / moba_block_size;
const int store_mean_idx = (bidb * kMaxN + seq_len_block) * kv_head_num * kHeadDim + bidh * kHeadDim + tidx * kPackSize;
if (seq_len % moba_block_size != 0) {
const int token_num_prev = seq_len % moba_block_size;
const float inv_tokens_sum = fdividef(1.0f, token_num_prev + 1);
k_prev.load_from(k_block_means + store_mean_idx);
#pragma unroll
for (int i = 0; i < kPackSize; i++) {
src.data.elt[i] = T(inv_tokens_sum * (float(src.data.elt[i]) + float(k_prev.data.elt[i]) * token_num_prev));
}
}
src.store_to(k_block_means + store_mean_idx);
} else {
bidh -= head_num + kv_head_num;
const int token_in_blocks = seq_len % kBlockSize;
T * cache = cache_v + physical_block_number * kv_head_num * kBlockSize * kHeadDim + bidh * kBlockSize * kHeadDim + tidx * kPackSize + token_in_blocks * kHeadDim;
src.store_to(cache);
}
}
void MobaDecoderAttnWriteCacheKv(
const paddle::Tensor& qkv_out,
const paddle::Tensor& q_input,
const paddle::Tensor& cu_seq_q,
const paddle::Tensor& cu_seq_k,
const paddle::Tensor& seq_len_encoder,
const paddle::Tensor& seq_len_decoder,
const paddle::Tensor& cache_k,
const paddle::Tensor& cache_v,
const paddle::Tensor& block_tables,
const paddle::Tensor& rope_sin_cos,
const paddle::Tensor& k_block_means,
const paddle::optional<paddle::Tensor>& qkv_bias,
const paddle::optional<paddle::Tensor>& cache_k_quant_scale,
const paddle::optional<paddle::Tensor>& cache_v_quant_scale,
const paddle::optional<paddle::Tensor>& cache_k_dequant_scale,
const paddle::optional<paddle::Tensor>& cache_v_dequant_scale,
const paddle::optional<paddle::Tensor>& cache_k_zero_points,
const paddle::optional<paddle::Tensor>& cache_v_zero_points,
const int head_num,
const int kv_head_num,
const int head_dim,
const int max_input_length,
const std::string &cache_quant_type_str) {
constexpr int kThreads = 32;
constexpr int kHeadDim = 128;
constexpr int kMobaBlockSize = 128;
constexpr int kMaxN = 1024;
assert(kHeadDim == head_dim);
constexpr int kBlockSize = 64;
const int max_blocks_per_seq = block_tables.dims()[1];
const int batch_size = block_tables.dims()[0];
if (cache_quant_type_str == "none") {
dim3 grid_dims;
grid_dims.x = head_num + kv_head_num * 2;
grid_dims.y = batch_size;
if (qkv_out.dtype() == paddle::DataType::FLOAT16) {
using T = phi::dtype::float16;
moba_decoder_attn_write_c16<T, kBlockSize, kHeadDim, kMobaBlockSize, kMaxN><<<grid_dims, kThreads, 0, qkv_out.stream()>>>(
qkv_out.data<T>(),
qkv_bias ? qkv_bias.get().data<T>() : nullptr,
const_cast<T*>(q_input.data<T>()),
cu_seq_q.data<int>(),
cu_seq_k.data<int>(),
seq_len_encoder.data<int>(),
seq_len_decoder.data<int>(),
const_cast<T *>(cache_k.data<T>()),
const_cast<T *>(cache_v.data<T>()),
block_tables.data<int>(),
rope_sin_cos.data<float>(),
const_cast<T*>(k_block_means.data<T>()),
head_num,
kv_head_num,
max_blocks_per_seq,
max_input_length);
} else if (qkv_out.dtype() == paddle::DataType::BFLOAT16) {
using T = phi::dtype::bfloat16;
moba_decoder_attn_write_c16<T, kBlockSize, kHeadDim, kMobaBlockSize, kMaxN><<<grid_dims, kThreads, 0, qkv_out.stream()>>>(
qkv_out.data<T>(),
qkv_bias ? qkv_bias.get().data<T>() : nullptr,
const_cast<T*>(q_input.data<T>()),
cu_seq_q.data<int>(),
cu_seq_k.data<int>(),
seq_len_encoder.data<int>(),
seq_len_decoder.data<int>(),
const_cast<T *>(cache_k.data<T>()),
const_cast<T *>(cache_v.data<T>()),
block_tables.data<int>(),
rope_sin_cos.data<float>(),
const_cast<T*>(k_block_means.data<T>()),
head_num,
kv_head_num,
max_blocks_per_seq,
max_input_length);
}
} else {
PD_THROW("Only supported cache_quant_type_str in ['none'].");
}
}

View File

@@ -1,236 +0,0 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/extension.h"
#include "moba_attn/moba_attn_utils.hpp"
#include "moba_attn/moba_attn.h"
template <typename T, int knthreads, int moba_block_size, int kBlockMaxN, int searchtimes>
__global__ void qk_gate_sort_decoder_kernel(
const T* qk_gate_weight,
int * qk_gate_topk_idx,
const int *decoder_seq_lens,
const int head_num,
const int kv_head_num,
const int kGqaGroupSize,
const int top_k_left,
const int top_k_right,
const int use_moba_seq_limit) {
const int bidb = blockIdx.x;
const int bidh = blockIdx.y;
const int tidx = threadIdx.x;
const int bidh_kv = bidh / kGqaGroupSize;
if (decoder_seq_lens[bidb] == 0 || decoder_seq_lens[bidb] < use_moba_seq_limit) {
return;
}
const int seq_len = (decoder_seq_lens[bidb] + moba_block_size - 1) / moba_block_size;
constexpr int kPackSize = kBlockMaxN / knthreads;
static_assert(kBlockMaxN % knthreads == 0);
T token_mean[kPackSize];
using SrcType = Vec<T, kPackSize>;
using SrcType_f = Vec<float, kPackSize>;
using SrcType_i = Vec<int, kPackSize>;
SrcType src;
SrcType_f src_f;
SrcType_i select_idx;
select_idx.set_zero();
const int load_offset = bidb * head_num * kBlockMaxN + bidh * kBlockMaxN + tidx * kPackSize;
src.load_from(qk_gate_weight + load_offset);
float max_global = -FLT_MAX;
float min_global = FLT_MAX;
const int data_len = seq_len - tidx * kPackSize;
#pragma unroll
for (int i = 0; i < kPackSize; i++) {
if (i < data_len) {
src_f.data.elt[i] = float(src.data.elt[i]);
min_global = min(min_global, src_f.data.elt[i]);
} else {
src_f.data.elt[i] = -FLT_MAX;
}
max_global = max(max_global, src_f.data.elt[i]);
}
max_global = BlockAllReduce<float, MaxOp<float>, knthreads>(max_global);
min_global = BlockAllReduce<float, MinOp<float>, knthreads>(min_global);
float right_limit = max_global;
float left_limit = min_global;
float mid_limit;
int count;
#pragma unroll
for (int i = 0; i < searchtimes; i++) {
mid_limit = (left_limit + right_limit) * 0.5f;
count = get_data_count<kPackSize, knthreads>(src_f.data.elt, mid_limit);
if (count < top_k_left) {
right_limit = mid_limit;
} else if (count > top_k_right) {
left_limit = mid_limit;
} else {
break;
}
}
const int store_idx = bidb * kv_head_num * kBlockMaxN + bidh_kv * kBlockMaxN + tidx * kPackSize;
#pragma unroll
for (int i = 0; i < kPackSize; i++) {
if (src_f.data.elt[i] >= mid_limit) {
qk_gate_topk_idx[store_idx + i] = 1;
}
}
if (tidx == 0) {
qk_gate_topk_idx[store_idx] = 1;
qk_gate_topk_idx[store_idx + seq_len - 1] = 1;
qk_gate_topk_idx[store_idx + seq_len - 2] = 1;
}
}
template <int kBlockMaxN, int moba_block_size, typename T>
void qk_gate_sort_decoder(
const T* qk_gate_weight,
int * qk_gate_topk_idx,
const int *decoder_seq_lens,
const int head_num,
const int kv_head_num,
const int batch_size,
const int top_k_left,
const int top_k_right,
const int use_moba_seq_limit,
cudaStream_t stream) {
const int gqa_group_size = head_num / kv_head_num;
constexpr int kPackSize = 16 / sizeof(T);
const int knthreads = kBlockMaxN / kPackSize;
dim3 grid_dims;
grid_dims.x = batch_size;
grid_dims.y = head_num;
const int searchtimes = 6;
constexpr auto kernel = qk_gate_sort_decoder_kernel<T, knthreads, moba_block_size, kBlockMaxN, searchtimes>;
kernel<<<grid_dims, knthreads, 0, 0>>>(
qk_gate_weight,
qk_gate_topk_idx,
decoder_seq_lens,
head_num,
kv_head_num,
gqa_group_size,
top_k_left,
top_k_right,
use_moba_seq_limit);
}
template <typename T>
std::vector<paddle::Tensor> DispatchQkSortDecoder(
const paddle::Tensor& qk_gate_weight,
const paddle::Tensor& seq_len_encoder,
const paddle::Tensor& seq_len_decoder,
const int head_num,
const int kv_head_num,
const int top_k_left,
const int top_k_right,
const int use_moba_seq_limit) {
constexpr int kMobaBlockSize = 128;
constexpr int kMaxN = 1024;
const int batch_size = seq_len_decoder.dims()[0];
paddle::Tensor qk_gate_topk_idx = paddle::empty({batch_size, kv_head_num, kMaxN}, paddle::DataType::INT32, qk_gate_weight.place());
qk_gate_sort_decoder<kMaxN, kMobaBlockSize, T>(
qk_gate_weight.data<T>(),
qk_gate_topk_idx.data<int>(),
seq_len_decoder.data<int>(),
head_num,
kv_head_num,
batch_size,
top_k_left,
top_k_right,
use_moba_seq_limit,
qk_gate_weight.stream()
);
return {qk_gate_topk_idx};
}
std::vector<paddle::Tensor> QkSortDecoder(
const paddle::Tensor& qk_gate_weight,
const paddle::Tensor& seq_len_encoder,
const paddle::Tensor& seq_len_decoder,
const int head_num,
const int kv_head_num,
const int top_k_left,
const int top_k_right,
const int use_moba_seq_limit) {
if (qk_gate_weight.dtype() == paddle::DataType::FLOAT16) {
return std::move(
DispatchQkSortDecoder<phi::dtype::float16>(
qk_gate_weight,
seq_len_encoder,
seq_len_decoder,
head_num,
kv_head_num,
top_k_left,
top_k_right,
use_moba_seq_limit)
);
} else if (qk_gate_weight.dtype() == paddle::DataType::BFLOAT16) {
return std::move(
DispatchQkSortDecoder<phi::dtype::bfloat16>(
qk_gate_weight,
seq_len_encoder,
seq_len_decoder,
head_num,
kv_head_num,
top_k_left,
top_k_right,
use_moba_seq_limit)
);
}
}
PD_BUILD_OP(moba_qk_sort_decoder)
.Inputs({
"qk_gate_weight",
"seq_len_encoder",
"seq_len_decoder"})
.Attrs({
"head_num: int",
"kv_head_num: int",
"top_k_left: int",
"top_k_right: int",
"use_moba_seq_limit: int"})
.Outputs({"qk_gate_topk_idx"})
.SetKernelFn(PD_KERNEL(QkSortDecoder));

View File

@@ -1,143 +0,0 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
/******************************************************************************
* Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
******************************************************************************/
#pragma once
#include "cute/algorithm/copy.hpp"
#include "cute/atom/mma_atom.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/cutlass.h"
#include "cutlass/layout/layout.h"
#include "cutlass/numeric_types.h"
#include "cutlass/pipeline/pipeline.hpp"
using namespace cute;
struct moba_encoder_attn_params {
void *__restrict__ q_ptr;
void *__restrict__ k_ptr;
void *__restrict__ v_ptr;
void * __restrict__ o_ptr;
int * __restrict__ cu_seq_q;
int * __restrict__ cu_seq_k;
int * __restrict__ qk_gate_topk_idx;
int * __restrict__ seq_len_encoder;
int * __restrict__ cu_seq_q_pack;
int head_num;
int kv_head_num;
int max_seq_q;
int max_seq_k;
int batch_size;
int gqa_group_size;
float scale_softmax_log2;
int use_moba_seq_limit;
};
template <int kStages, class Gemm1Type, class Gemm2Type, class OutputType, class SmemLayoutQ,
class SmemLayoutK, class SmemLayoutV, class SmemLayoutO>
struct SharedStorageQKVO {
cute::array_aligned<Gemm1Type, cute::cosize_v<SmemLayoutQ>> smem_q;
cute::array_aligned<Gemm1Type, cute::cosize_v<SmemLayoutK>> smem_k;
union {
cute::array_aligned<Gemm2Type, cute::cosize_v<SmemLayoutV>> smem_v;
cute::array_aligned<OutputType, cute::cosize_v<SmemLayoutO>> smem_o;
};
struct {
cutlass::arch::ClusterTransactionBarrier barrier_Q;
typename cutlass::PipelineTmaAsync<kStages>::SharedStorage pipeline_k;
typename cutlass::PipelineTmaAsync<kStages>::SharedStorage pipeline_v;
};
};
template<int kHeadDim_, int kBlockM_, int kBlockN_, int kNWarps_, int kStages_, int kMaxN_, bool UseMoba_, typename elem_type=cutlass::half_t>
struct moba_encoder_attn_kernel_traits {
using Element = elem_type;
using ElementAccum = float;
using index_t = int32_t;
static constexpr int kNWarps = kNWarps_;
static constexpr int kNThreads = kNWarps * cutlass::NumThreadsPerWarp;
static constexpr int UseMoba = UseMoba_;
static constexpr int kBlockM = kBlockM_;
static constexpr int kBlockN = kBlockN_;
static constexpr int kHeadDim = kHeadDim_;
static constexpr int kMaxN = kMaxN_;
static_assert(kHeadDim % 32 == 0);
using TileShape_MNK = Shape<Int<kBlockM>, Int<kBlockN>, Int<kHeadDim>>;
using ClusterShape_MNK = Shape<Int<1>, Int<1>, Int<1>>;
static constexpr int kStages = kStages_;
using AtomLayoutMNK = Layout<Shape<Int<kBlockM / 64>, _1, _1>>;
using TiledMma0 = decltype(cute::make_tiled_mma(
cute::GMMA::ss_op_selector<Element, Element, ElementAccum, TileShape_MNK>(),
AtomLayoutMNK{}));
using TiledMma1 = decltype(cute::make_tiled_mma(
cute::GMMA::rs_op_selector<Element, Element, ElementAccum, decltype(select<0, 2, 1>(TileShape_MNK{})),
GMMA::Major::K, GMMA::Major::MN>(),
AtomLayoutMNK{}));
using SmemLayoutAtomQ = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
using SmemLayoutQ = decltype(tile_to_shape(SmemLayoutAtomQ{}, select<0, 2>(TileShape_MNK{})));
using SmemLayoutAtomK = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
using SmemLayoutK =
decltype(tile_to_shape(SmemLayoutAtomK{},
make_shape(shape<1>(TileShape_MNK{}), shape<2>(TileShape_MNK{}), Int<kStages>{})));
using SmemLayoutAtomV = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
using SmemLayoutV =
decltype(tile_to_shape(SmemLayoutAtomV{},
make_shape(shape<1>(TileShape_MNK{}), shape<2>(TileShape_MNK{}), Int<kStages>{})));
using SmemLayoutAtomO = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
using SmemLayoutO = decltype(tile_to_shape(SmemLayoutAtomO{}, select<0, 2>(TileShape_MNK{})));
using SmemCopyAtomQ = Copy_Atom<cute::SM75_U32x4_LDSM_N, Element>;
using SmemCopyAtomO = Copy_Atom<cute::SM90_U32x4_STSM_N, Element>;
using SharedStorage = SharedStorageQKVO<kStages, Element, Element, Element, SmemLayoutQ, SmemLayoutK, SmemLayoutV, SmemLayoutO>;
static constexpr int NumProducerThreads = cutlass::NumThreadsPerWarpGroup;
static constexpr int NumMmaThreads = kNThreads - NumProducerThreads;
static constexpr int kNumVecElem = ceil_div(128, sizeof_bits_v<Element>);
static constexpr int kNumThreadsPerRow = kHeadDim / kNumVecElem;
static_assert(NumMmaThreads % kNumThreadsPerRow == 0);
static constexpr int kNumRows = NumMmaThreads / kNumThreadsPerRow;
using TiledCopyOAtom = cute::Copy_Atom<cute::UniversalCopy<cutlass::uint128_t>, Element>;
using TiledCopyOThrLayout = decltype(cute::make_layout(
cute::make_shape(Int<kNumRows>{}, Int<kNumThreadsPerRow>{}),
LayoutRight{}));
using TiledCopyOValLayout = decltype(cute::make_layout(
cute::make_shape(_1{}, Int<kNumVecElem>{}),
LayoutRight{}));
using GmemTiledCopyO = decltype(make_tiled_copy(
TiledCopyOAtom{},
TiledCopyOThrLayout{}, // Thr layout
TiledCopyOValLayout{} // Val layout
));
using MainloopPipeline = typename cutlass::PipelineTmaAsync<kStages>;
using PipelineState = typename cutlass::PipelineState<kStages>;
};

View File

@@ -1,473 +0,0 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
/******************************************************************************
* Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
******************************************************************************/
#pragma once
#include <cutlass/cutlass.h>
#include <cutlass/array.h>
#include <cutlass/numeric_types.h>
#include <cutlass/numeric_conversion.h>
#include "cutlass/pipeline/pipeline.hpp"
#include "cute/tensor.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
using namespace cute;
enum class AttnNamedBarriers {
QueryEmpty = 0,
ValueEmpty = 1,
TileCountSmemEmpty = 2,
TileCountSmemFull = 3,
WarpSchedulerWG1 = 4,
WarpSchedulerWG2 = 5,
WarpSchedulerWG3 = 6,
};
template <typename Ktraits>
struct CollectiveMainloopAttn {
using Element = typename Ktraits::Element;
using TileShape_MNK = typename Ktraits::TileShape_MNK;
using ClusterShape = typename Ktraits::ClusterShape_MNK;
static constexpr int kStages = Ktraits::kStages;
static constexpr int kHeadDim = Ktraits::kHeadDim;
static constexpr int kBlockM = Ktraits::kBlockM;
static constexpr int kBlockN = Ktraits::kBlockN;
using ShapeT = cute::Shape<int32_t, int32_t, int32_t>;
using StrideT = cute::Shape<int32_t, _1, int32_t>;
using LayoutT = cute::Layout<ShapeT, StrideT>;
using GmemTiledCopyQ = cute::SM90_TMA_LOAD;
using GmemTiledCopyKV = decltype(cutlass::gemm::collective::detail::sm90_cluster_shape_to_tma_atom(shape<0>(ClusterShape{})));
using GmemTiledCopyO = typename Ktraits::GmemTiledCopyO;
using SmemLayoutAtomQ = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
using SmemLayoutQ = decltype(tile_to_shape(SmemLayoutAtomQ{}, select<0, 2>(TileShape_MNK{})));
using SmemLayoutAtomK = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
using SmemLayoutK =
decltype(tile_to_shape(SmemLayoutAtomK{},
make_shape(shape<1>(TileShape_MNK{}), shape<2>(TileShape_MNK{}), Int<kStages>{})));
using SmemLayoutV = SmemLayoutK;
// Note this is the transpose in terms of the view, not in terms of memory.
using SmemLayoutVt =
decltype(cute::composition(SmemLayoutV{},
make_layout(make_shape(get<2>(TileShape_MNK{}), get<1>(TileShape_MNK{}), Int<kStages>{}),
make_stride(get<1>(TileShape_MNK{}), _1{}, Int<size(SmemLayoutV{}(_, _, _0{}))>{}))));
using SmemLayoutO = typename Ktraits::SmemLayoutO;
using SmemCopyAtomO = typename Ktraits::SmemCopyAtomO;
using TMA_Q = decltype(make_tma_copy(
GmemTiledCopyQ{},
make_tensor(
make_gmem_ptr(static_cast<Element const*>(nullptr)),
repeat_like(StrideT{}, int32_t(0)),
StrideT{}
),
SmemLayoutQ{},
select<0, 2>(TileShape_MNK{}),
_1{})); // no mcast for Q
using TMA_KV = decltype(make_tma_copy(
GmemTiledCopyKV{},
make_tensor(
make_gmem_ptr(static_cast<Element const*>(nullptr)),
repeat_like(StrideT{}, int32_t(0)),
StrideT{}
),
take<0, 2>(SmemLayoutK{}),
select<1, 2>(TileShape_MNK{}),
size<0>(ClusterShape{}))); // mcast along M mode for this N load, if any
static constexpr int NumMmaThreads = size(typename Ktraits::TiledMma0{});
using MainloopPipeline = typename Ktraits::MainloopPipeline;
using PipelineParams = typename MainloopPipeline::Params;
using PipelineState = typename MainloopPipeline::PipelineState;
// Set the bytes transferred in this TMA transaction (may involve multiple issues)
static constexpr uint32_t TmaTransactionBytesQ = static_cast<uint32_t>(size(SmemLayoutQ{}) * cutlass::sizeof_bits_v<Element> / 8);
static constexpr uint32_t TmaTransactionBytesK = static_cast<uint32_t>(size(take<0, 2>(SmemLayoutK{})) * cutlass::sizeof_bits_v<Element> / 8);
static constexpr bool UseSchedulerBarrier = kHeadDim <= 128;
// Host side kernel arguments
struct Arguments {
Element const* ptr_Q;
LayoutT layout_Q;
Element const* ptr_K;
LayoutT layout_K;
Element const* ptr_V;
LayoutT layout_V;
float const softmax_scale_log2;
};
// Device side kernel params
struct Params {
LayoutT layout_Q;
LayoutT layout_K;
LayoutT layout_V;
cutlass::FastDivmod qhead_per_khead_divmod;
TMA_Q tma_load_Q;
TMA_KV tma_load_K, tma_load_V;
float const softmax_scale_log2;
};
static Params
to_underlying_arguments(Arguments const& args) {
Tensor mQ = make_tensor(make_gmem_ptr(args.ptr_Q), args.layout_Q);
TMA_Q tma_load_Q = make_tma_copy(
GmemTiledCopyQ{},
mQ,
SmemLayoutQ{},
select<0, 2>(TileShape_MNK{}),
_1{}); // no mcast for Q
Tensor mK = make_tensor(make_gmem_ptr(args.ptr_K), args.layout_K);
TMA_KV tma_load_K = make_tma_copy(
GmemTiledCopyKV{},
mK,
SmemLayoutK{}(_, _, _0{}),
select<1, 2>(TileShape_MNK{}),
size<0>(ClusterShape{})); // mcast along M mode for this N load, if any
Tensor mV = make_tensor(make_gmem_ptr(args.ptr_V), args.layout_V);
TMA_KV tma_load_V = make_tma_copy(
GmemTiledCopyKV{},
mV,
SmemLayoutV{}(_, _, _0{}),
select<1, 2>(TileShape_MNK{}),
size<0>(ClusterShape{})); // mcast along M mode for this N load, if any
return {args.layout_Q, args.layout_K, args.layout_V,
cutlass::FastDivmod(cute::ceil_div(get<2>(args.layout_Q.shape()), get<2>(args.layout_K.shape()))),
tma_load_Q, tma_load_K, tma_load_V,
args.softmax_scale_log2};
}
/// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance
CUTLASS_DEVICE
static void prefetch_tma_descriptors(Params const& mainloop_params) {
cute::prefetch_tma_descriptor(mainloop_params.tma_load_Q.get_tma_descriptor());
cute::prefetch_tma_descriptor(mainloop_params.tma_load_K.get_tma_descriptor());
cute::prefetch_tma_descriptor(mainloop_params.tma_load_V.get_tma_descriptor());
}
template <typename MTensor, typename Shape>
CUTLASS_DEVICE auto get_local_tile_tensor(
const MTensor &m_tensor,
const Shape &tile_shape,
const int *cu_seq_len,
const int bidh,
const int bidb,
const int actual_seq_len) const {
auto g_offset = local_tile(
m_tensor(_, _, bidh),
cute::make_shape(1, get<1>(tile_shape)),
make_coord(cu_seq_len[bidb], _0{}));
auto g_sequence = make_tensor(
g_offset.data(),
make_layout(
cute::make_shape(actual_seq_len, get<1>(tile_shape)),
g_offset.stride()
));
auto g_tensor = local_tile(g_sequence, tile_shape, make_coord(_, _0{}));
return g_tensor;
}
template <bool UseMoba, typename SharedStorage>
CUTLASS_DEVICE void
load(Params const& mainloop_params,
MainloopPipeline pipeline_k,
MainloopPipeline pipeline_v,
PipelineState& smem_pipe_write_k,
PipelineState& smem_pipe_write_v,
SharedStorage &shared_storage,
const int *qk_gate_topk_idx,
const int n_block_max,
const int m_block,
const int bidh,
const int bidb,
const int *cu_seq_q,
const int *cu_seq_k,
const int seq_len_q,
const int seq_len_k) {
Tensor sQ = make_tensor(make_smem_ptr(shared_storage.smem_q.data()), SmemLayoutQ{});
Tensor sK = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), SmemLayoutK{});
Tensor sV = make_tensor(make_smem_ptr(shared_storage.smem_v.data()), SmemLayoutV{});
Tensor mQ = mainloop_params.tma_load_Q.get_tma_tensor(mainloop_params.layout_Q.shape());
Tensor mK = mainloop_params.tma_load_K.get_tma_tensor(mainloop_params.layout_K.shape());
Tensor mV = mainloop_params.tma_load_V.get_tma_tensor(mainloop_params.layout_V.shape());
int bidh_kv = mainloop_params.qhead_per_khead_divmod.divide(bidh);
Tensor gQ = get_local_tile_tensor(
mQ, select<0, 2>(TileShape_MNK{}), cu_seq_q, bidh, bidb, seq_len_q)(_, _, m_block);
Tensor gK = get_local_tile_tensor(
mK, select<1, 2>(TileShape_MNK{}), cu_seq_k, bidh_kv, bidb, seq_len_k);
Tensor gV = get_local_tile_tensor(
mV, select<1, 2>(TileShape_MNK{}), cu_seq_k, bidh_kv, bidb, seq_len_k);
Tensor sQ_x = make_tensor(sQ.data(), make_layout(sQ.layout(), Layout<_1>{}));
Tensor gQ_x = make_tensor(gQ.data(), make_layout(gQ.layout(), Layout<_1>{}));
auto [tQgQ, tQsQ] = tma_partition(mainloop_params.tma_load_Q, _0{}, Layout<_1>{},group_modes<0, 2>(sQ_x), group_modes<0, 2>(gQ_x));
auto [tKgK, tKsK] = tma_partition(mainloop_params.tma_load_K, _0{}, Layout<_1>{},group_modes<0, 2>(sK), group_modes<0, 2>(gK));
auto [tVgV, tVsV] = tma_partition(mainloop_params.tma_load_V, _0{}, Layout<_1>{},group_modes<0, 2>(sV), group_modes<0, 2>(gV));
uint16_t mcast_mask_kv = 0;
int n_block = n_block_max - 1;
int lane_predicate = cute::elect_one_sync();
if (lane_predicate) {
shared_storage.barrier_Q.arrive_and_expect_tx(TmaTransactionBytesQ);
copy(mainloop_params.tma_load_Q.with(reinterpret_cast<cutlass::arch::ClusterTransactionBarrier::ValueType&>(shared_storage.barrier_Q), 0 /*mcast_mask*/), tQgQ, tQsQ);
}
if (lane_predicate) {
pipeline_k.producer_acquire(smem_pipe_write_k);
copy(mainloop_params.tma_load_K.with(*pipeline_k.producer_get_barrier(smem_pipe_write_k), mcast_mask_kv), tKgK(_, n_block), tKsK(_, smem_pipe_write_k.index()));
++smem_pipe_write_k;
}
if (lane_predicate) {
int idx = 0;
#pragma unroll 2
for (; n_block > 0; ) {
pipeline_k.producer_acquire(smem_pipe_write_k);
int pre_idx = 1;
if constexpr (UseMoba) {
pre_idx = qk_gate_topk_idx[idx];
}
copy(mainloop_params.tma_load_K.with(*pipeline_k.producer_get_barrier(smem_pipe_write_k), mcast_mask_kv), tKgK(_, n_block - pre_idx), tKsK(_, smem_pipe_write_k.index()));
++smem_pipe_write_k;
pipeline_v.producer_acquire(smem_pipe_write_v);
copy(mainloop_params.tma_load_V.with(*pipeline_v.producer_get_barrier(smem_pipe_write_v), mcast_mask_kv), tVgV(_, n_block), tVsV(_, smem_pipe_write_v.index()));
++smem_pipe_write_v;
n_block -= pre_idx;
idx += 1;
}
}
if (lane_predicate) {
pipeline_v.producer_acquire(smem_pipe_write_v);
copy(mainloop_params.tma_load_V.with(*pipeline_v.producer_get_barrier(smem_pipe_write_v), mcast_mask_kv), tVgV(_, n_block), tVsV(_, smem_pipe_write_v.index()));
++smem_pipe_write_v;
}
}
CUTLASS_DEVICE void
warp_scheduler_barrier_sync() {
if constexpr (UseSchedulerBarrier) {
cutlass::arch::NamedBarrier::sync(NumMmaThreads, static_cast<int>(AttnNamedBarriers::WarpSchedulerWG1) - 1 + cutlass::canonical_warp_group_idx() /*id*/);
}
}
CUTLASS_DEVICE void
mma_init() {
if constexpr (!UseSchedulerBarrier) { return; }
static_assert(NumMmaThreads == 2 * cutlass::NumThreadsPerWarpGroup || NumMmaThreads == 3 * cutlass::NumThreadsPerWarpGroup);
if (cutlass::canonical_warp_group_idx() > 1) {
cutlass::arch::NamedBarrier::arrive(NumMmaThreads, static_cast<int>(AttnNamedBarriers::WarpSchedulerWG1) - 1 + 1 /*id*/);
}
if constexpr (NumMmaThreads == 3 * cutlass::NumThreadsPerWarpGroup) {
if (cutlass::canonical_warp_group_idx() > 2) {
cutlass::arch::NamedBarrier::arrive(NumMmaThreads, static_cast<int>(AttnNamedBarriers::WarpSchedulerWG1) - 1 + 2 /*id*/);
}
}
}
CUTLASS_DEVICE void
warp_scheduler_barrier_arrive() {
if constexpr (!UseSchedulerBarrier) { return; }
static_assert(NumMmaThreads == 2 * cutlass::NumThreadsPerWarpGroup || NumMmaThreads == 3 * cutlass::NumThreadsPerWarpGroup);
if constexpr (NumMmaThreads == 2 * cutlass::NumThreadsPerWarpGroup) {
cutlass::arch::NamedBarrier::arrive(NumMmaThreads, static_cast<int>(AttnNamedBarriers::WarpSchedulerWG1) - 1 + (3 - cutlass::canonical_warp_group_idx()) /*id*/);
} else {
cutlass::arch::NamedBarrier::arrive(NumMmaThreads, static_cast<int>(AttnNamedBarriers::WarpSchedulerWG1) - 1 + (cutlass::canonical_warp_group_idx() <= 2 ? cutlass::canonical_warp_group_idx() + 1 : cutlass::canonical_warp_group_idx() + 1 - 3) /*id*/);
cutlass::arch::NamedBarrier::arrive(NumMmaThreads, static_cast<int>(AttnNamedBarriers::WarpSchedulerWG1) - 1 + (cutlass::canonical_warp_group_idx() <= 1 ? cutlass::canonical_warp_group_idx() + 2 : cutlass::canonical_warp_group_idx() + 2 - 3) /*id*/);
}
}
template <bool UseMoba, typename SharedStorage, typename FrgTensorO, typename Softmax>
CUTLASS_DEVICE void
mma(Params const& mainloop_params,
MainloopPipeline pipeline_k,
MainloopPipeline pipeline_v,
PipelineState& smem_pipe_read_k,
PipelineState& smem_pipe_read_v,
FrgTensorO& tOrO,
Softmax& softmax,
const int *qk_gate_topk_idx,
const int n_block_max,
const int thread_idx,
const int m_block,
const int seq_len_q,
const int seq_len_k,
SharedStorage& shared_storage) {
Tensor sQ = make_tensor(make_smem_ptr(shared_storage.smem_q.data()), SmemLayoutQ{});
Tensor sK = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), SmemLayoutK{});
Tensor sVt = make_tensor(make_smem_ptr(shared_storage.smem_v.data()), SmemLayoutVt{});
typename Ktraits::TiledMma0 tiled_mma0;
typename Ktraits::TiledMma1 tiled_mma1;
auto threadMma0 = tiled_mma0.get_thread_slice(thread_idx);
auto threadMma1 = tiled_mma1.get_thread_slice(thread_idx);
Tensor tSrQ = threadMma0.partition_fragment_A(sQ);
Tensor tSrK = threadMma0.partition_fragment_B(sK);
Tensor tOrV = threadMma1.partition_fragment_B(sVt);
auto consumer_wait = [](auto& pipeline, auto& smem_pipe_read) {
auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read);
pipeline.consumer_wait(smem_pipe_read, barrier_token);
};
tiled_mma1.accumulate_ = GMMA::ScaleOut::Zero;
int n_block = n_block_max - 1;
cutlass::ConsumerToken barrier_token = static_cast<cutlass::BarrierStatus>(shared_storage.barrier_Q.try_wait(0));
if (barrier_token == cutlass::BarrierStatus::WaitAgain) { shared_storage.barrier_Q.wait(0); }
Tensor tSrS = partition_fragment_C(tiled_mma0, select<0, 1>(TileShape_MNK{}));
consumer_wait(pipeline_k, smem_pipe_read_k);
warp_scheduler_barrier_sync();
gemm</*zero_init=*/true, /*wg_wait=*/-1>(tiled_mma0, tSrQ, tSrK(_, _, _, smem_pipe_read_k.index()), tSrS);
warp_scheduler_barrier_arrive();
warpgroup_wait<0>();
pipeline_k.consumer_release(smem_pipe_read_k);
++smem_pipe_read_k;
auto col_limit_causal = [&](int row, int n_block) {
return row + 1 + seq_len_k - n_block * kBlockN - seq_len_q + m_block * kBlockM;
};
Tensor cS = cute::make_identity_tensor(select<0, 1>(TileShape_MNK{}));
Tensor tScS = threadMma0.partition_C(cS);
#pragma unroll
for (int i = 0; i < size(tSrS); ++i) {
if (int(get<1>(tScS(i))) >=
std::min(seq_len_k - n_block * kBlockN, col_limit_causal(int(get<0>(tScS(i))), n_block))) {
tSrS(i) = -INFINITY;
}
}
softmax.template online_softmax</*Is_first=*/true>(tSrS, mainloop_params.softmax_scale_log2);
Tensor tOrP = make_tensor(convert_type<Element>(tSrS).data(), convert_layout_acc_Aregs<typename Ktraits::TiledMma1>(tSrS.layout()));
Tensor scores_scale = make_fragment_like(softmax.row_max);
clear(scores_scale);
int idx = 0;
#pragma unroll 2
for (; n_block > 0; ) {
Tensor tSrS = partition_fragment_C(tiled_mma0, select<0, 1>(TileShape_MNK{}));
consumer_wait(pipeline_k, smem_pipe_read_k);
warp_scheduler_barrier_sync();
gemm</*zero_init=*/true, /*wg_wait=*/-1>(tiled_mma0, tSrQ, tSrK(_, _, _, smem_pipe_read_k.index()), tSrS);
softmax.rescale_o(tOrO, scores_scale);
consumer_wait(pipeline_v, smem_pipe_read_v);
gemm</*zero_init=*/false, /*wg_wait=*/-1>(tiled_mma1, tOrP, tOrV(_, _, _, smem_pipe_read_v.index()), tOrO);
warp_scheduler_barrier_arrive();
warpgroup_wait<1>();
pipeline_k.consumer_release(smem_pipe_read_k); // release K
cute::copy(softmax.template max</*Is_first=*/false>(tSrS, mainloop_params.softmax_scale_log2), scores_scale);
softmax.template online_softmax</*Is_first=*/false>(tSrS, mainloop_params.softmax_scale_log2);
warpgroup_wait<0>();
pipeline_v.consumer_release(smem_pipe_read_v); // release V
++smem_pipe_read_k;
++smem_pipe_read_v;
cute::copy(make_tensor(convert_type<Element>(tSrS).data(), convert_layout_acc_Aregs<typename Ktraits::TiledMma1>(tSrS.layout())), tOrP);
if constexpr (UseMoba) {
n_block -= qk_gate_topk_idx[idx];
idx += 1;
} else {
n_block -= 1;
}
}
softmax.rescale_o(tOrO, scores_scale);
consumer_wait(pipeline_v, smem_pipe_read_v);
gemm</*zero_init=*/false, /*wg_wait=*/-1>(tiled_mma1, tOrP, tOrV(_, _, _, smem_pipe_read_v.index()), tOrO);
cute::copy(softmax.finalize(mainloop_params.softmax_scale_log2), scores_scale);
warpgroup_wait<0>();
pipeline_v.consumer_release(smem_pipe_read_v);
++smem_pipe_read_v;
softmax.rescale_o(tOrO, scores_scale);
}
template <int NumMmaThreads, typename SharedStorage, typename FrgTensorO, typename TiledMma, typename T>
CUTLASS_DEVICE void
store(Params const& mainloop_params,
FrgTensorO const& tOrO,
SharedStorage& shared_storage,
TiledMma tiled_mma,
int thread_idx,
const int o_head_stride,
const int real_seq,
T * out_ptr) {
Tensor sO = make_tensor(make_smem_ptr(shared_storage.smem_o.data()), SmemLayoutO{});
auto smem_tiled_copy_O = make_tiled_copy_C(SmemCopyAtomO{}, tiled_mma);
auto smem_thr_copy_O = smem_tiled_copy_O.get_thread_slice(thread_idx);
Tensor tOrO_out = convert_type<Element>(tOrO);
Tensor taccOrO = smem_thr_copy_O.retile_S(tOrO_out);
Tensor taccOsO = smem_thr_copy_O.partition_D(sO);
cutlass::arch::NamedBarrier::sync(NumMmaThreads, static_cast<int>(AttnNamedBarriers::ValueEmpty) /*id*/);
cute::copy(smem_tiled_copy_O, taccOrO, taccOsO);
cutlass::arch::fence_view_async_shared(); // ensure smem writes are visible to TMA
cutlass::arch::NamedBarrier::arrive(NumMmaThreads + cutlass::NumThreadsPerWarp,cutlass::arch::ReservedNamedBarriers::EpilogueBarrier);
Tensor gO = make_tensor(make_gmem_ptr(out_ptr),
Shape<Int<kBlockM>, Int<kHeadDim>>{},
make_stride(o_head_stride, _1{}));
GmemTiledCopyO gmem_tiled_copy_O;
auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(thread_idx);
Tensor tOsO = gmem_thr_copy_O.partition_S(sO);
Tensor tOgO = gmem_thr_copy_O.partition_D(gO);
Tensor cO = make_identity_tensor(Shape<Int<kBlockM>, Int<kHeadDim>>{});
Tensor tOcO = gmem_thr_copy_O.partition_S(cO);
if (real_seq >= kBlockM) {
copy<true>(gmem_tiled_copy_O, tOsO, tOgO, tOcO);
} else {
copy<false>(gmem_tiled_copy_O, tOsO, tOgO, tOcO, real_seq);
}
}
};

View File

@@ -1,384 +0,0 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/extension.h"
#include <cute/tensor.hpp>
#include "cutlass/util/print_error.hpp"
#include "cutlass/util/GPU_Clock.hpp"
#if defined(CUTLASS_ENABLE_CUBLAS) && CUTLASS_ENABLE_CUBLAS != 0
# include "cutlass/util/cublas_wrappers.hpp"
#endif
#include "moba_attn/moba_attn_utils.hpp"
#include "moba_attn/moba_attn.h"
#include "kernel_traits.h"
#include "mainloop_attn.hpp"
#include "softmax.hpp"
#include "cutlass/arch/reg_reconfig.h"
template <int kHeadDim>
auto get_gmem_layout(int token_num, int head_num) {
return make_layout(
make_shape(token_num, kHeadDim, head_num),
make_stride(head_num * kHeadDim, _1{}, kHeadDim));
}
template <typename Ktraits>
__global__ void __launch_bounds__(Ktraits::kNWarps * cutlass::NumThreadsPerWarp, 1)
moba_encoder_attention_kernel(
CUTE_GRID_CONSTANT typename CollectiveMainloopAttn<Ktraits>::Params const mainloop_params,
CUTE_GRID_CONSTANT moba_encoder_attn_params const data_params) {
using Element = typename Ktraits::Element;
using ElementAccum = typename Ktraits::ElementAccum;
using SoftType = ElementAccum;
using TileShape_MNK = typename Ktraits::TileShape_MNK;
using ClusterShape = typename Ktraits::ClusterShape_MNK;
static constexpr int NumMmaThreads = size(typename Ktraits::TiledMma0{});
static constexpr int NumCopyThreads = cutlass::NumThreadsPerWarpGroup;
static constexpr int kBlockM = Ktraits::kBlockM;
static constexpr int kBlockN = Ktraits::kBlockN;
constexpr int kHeadDim = Ktraits::kHeadDim;
constexpr int kMaxN = Ktraits::kMaxN;
using CollectiveMainloop = CollectiveMainloopAttn<Ktraits>;
using MainloopPipeline = typename Ktraits::MainloopPipeline;
using PipelineParams = typename MainloopPipeline::Params;
using PipelineState = typename MainloopPipeline::PipelineState;
extern __shared__ char shared_memory[];
auto &shared_storage = *reinterpret_cast<typename Ktraits::SharedStorage*>(shared_memory);
const int m_block = blockIdx.x;
const int bidh = blockIdx.y;
const int bidb = blockIdx.z;
const int seq_len_q = data_params.seq_len_encoder[bidb];
const int seq_len_k = data_params.cu_seq_k[bidb + 1] - data_params.cu_seq_k[bidb];
if (seq_len_q == 0) {
return;
}
__align__(16) __shared__ int qk_gate_topk_idx[kMaxN];
const int *qk_gate_idx_cur_offset = data_params.qk_gate_topk_idx + data_params.cu_seq_q_pack[bidb] / kBlockM * data_params.head_num * kMaxN + (m_block * data_params.head_num + bidh) * kMaxN;
#pragma unroll
for (int i = threadIdx.x; i < kMaxN / 4; i += Ktraits::kNWarps * cutlass::NumThreadsPerWarp) {
reinterpret_cast<int4*>(qk_gate_topk_idx)[i] = reinterpret_cast<const int4*>(qk_gate_idx_cur_offset)[i];
}
const int n_block_max = min(cute::ceil_div((m_block + 1) * kBlockM + seq_len_k - seq_len_q, kBlockN), cute::ceil_div(seq_len_k, kBlockN));
if (m_block * kBlockM >= seq_len_q) {
return;
}
int const lane_predicate = cute::elect_one_sync();
int const warp_idx = cutlass::canonical_warp_idx_sync();
if (warp_idx == 0 && lane_predicate) {
CollectiveMainloop::prefetch_tma_descriptors(mainloop_params);
}
// Obtain warp index
int const warp_group_thread_idx = threadIdx.x % cutlass::NumThreadsPerWarpGroup;
PipelineParams pipeline_params;
pipeline_params.transaction_bytes = CollectiveMainloop::TmaTransactionBytesK;
int warp_group_idx = cutlass::canonical_warp_group_idx();
pipeline_params.role = warp_group_idx == 0
? MainloopPipeline::ThreadCategory::Producer
: MainloopPipeline::ThreadCategory::Consumer;
pipeline_params.is_leader = warp_group_thread_idx == 0;
pipeline_params.num_consumers = NumMmaThreads;
if (warp_idx == 0 && lane_predicate) {
shared_storage.barrier_Q.init(1);
}
MainloopPipeline pipeline_k(shared_storage.pipeline_k, pipeline_params, ClusterShape{});
MainloopPipeline pipeline_v(shared_storage.pipeline_v, pipeline_params, ClusterShape{});
__syncthreads();
CollectiveMainloop collective_mainloop;
if (warp_group_idx == 0) {
cutlass::arch::warpgroup_reg_dealloc<Ktraits::kNWarps == 8 ? 56 : 24>();
int warp_idx_in_warpgroup = __shfl_sync(0xffffffff, (threadIdx.x / 32) % 4, 0);
if (warp_idx_in_warpgroup == 0) {
PipelineState smem_pipe_write_k = cutlass::make_producer_start_state<MainloopPipeline>();
PipelineState smem_pipe_write_v = cutlass::make_producer_start_state<MainloopPipeline>();
collective_mainloop.load<Ktraits::UseMoba>(
mainloop_params,
pipeline_k,
pipeline_v,
smem_pipe_write_k,
smem_pipe_write_v,
shared_storage,
qk_gate_topk_idx,
n_block_max,
m_block,
bidh,
bidb,
data_params.cu_seq_q,
data_params.cu_seq_k,
seq_len_q,
seq_len_k);
}
} else {
cutlass::arch::warpgroup_reg_alloc<Ktraits::kNWarps == 8 ? 256 : 240>();
typename Ktraits::TiledMma1 tiled_mma1;
collective_mainloop.mma_init();
PipelineState smem_pipe_read_k, smem_pipe_read_v;
Tensor tOrO = partition_fragment_C(tiled_mma1, select<0, 2>(TileShape_MNK{}));
Softmax<2 * (2 * kBlockM / NumMmaThreads)> softmax;
collective_mainloop.mma<Ktraits::UseMoba>(
mainloop_params,
pipeline_k,
pipeline_v,
smem_pipe_read_k,
smem_pipe_read_v,
tOrO,
softmax,
qk_gate_topk_idx,
n_block_max,
threadIdx.x - NumCopyThreads,
m_block,
seq_len_q,
seq_len_k,
shared_storage);
const int o_head_stride = data_params.head_num * kHeadDim;
const int store_offset = (data_params.cu_seq_q[bidb] + m_block * kBlockM) * o_head_stride + bidh * kHeadDim;
const int real_seq = seq_len_q - m_block * kBlockM;
collective_mainloop.store<NumMmaThreads>(
mainloop_params,
tOrO,
shared_storage,
tiled_mma1,
threadIdx.x - NumCopyThreads,
o_head_stride,
real_seq,
reinterpret_cast<Element*>(data_params.o_ptr) + store_offset);
}
}
template<typename Kernel_traits>
void run_moba_decoder_attn(moba_encoder_attn_params &params, cudaStream_t stream) {
using Element = typename Kernel_traits::Element;
using TileShape_MNK = typename Kernel_traits::TileShape_MNK;
using ClusterShape = typename Kernel_traits::ClusterShape_MNK;
using CollectiveMainloop = CollectiveMainloopAttn<Kernel_traits>;
constexpr int kHeadDim = Kernel_traits::kHeadDim;
typename CollectiveMainloop::Params mainloop_params =
CollectiveMainloop::to_underlying_arguments({
static_cast<Element const*>(params.q_ptr),
get_gmem_layout<kHeadDim>(params.max_seq_q * params.batch_size, params.head_num),
static_cast<Element const*>(params.k_ptr),
get_gmem_layout<kHeadDim>(params.max_seq_k * params.batch_size, params.kv_head_num),
static_cast<Element const*>(params.v_ptr),
get_gmem_layout<kHeadDim>(params.max_seq_k * params.batch_size, params.kv_head_num),
params.scale_softmax_log2
});
int num_blocks_m = cutlass::ceil_div(params.max_seq_q, Kernel_traits::kBlockM);
num_blocks_m = cutlass::ceil_div(num_blocks_m, size<0>(ClusterShape{})) * size<0>(ClusterShape{});
void *kernel;
kernel = (void *)moba_encoder_attention_kernel<Kernel_traits>;
int smem_size = sizeof(typename Kernel_traits::SharedStorage);
if (smem_size >= 48 * 1024) {
cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
}
dim3 grid_dims;
grid_dims.x = num_blocks_m;
grid_dims.y = params.head_num;
grid_dims.z = params.batch_size;
static constexpr int ctaSize = Kernel_traits::kNWarps * 32;
dim3 block_dims(ctaSize);
dim3 cluster_dims(size<0>(ClusterShape{}), size<1>(ClusterShape{}), size<2>(ClusterShape{}));
cutlass::ClusterLaunchParams launch_params{grid_dims, block_dims, cluster_dims, smem_size, stream};
cutlass::launch_kernel_on_cluster(launch_params, kernel, mainloop_params, params);
}
template <int kBlockM, int kBlockN, int kMaxN, typename InputType>
void run_moba_encoder_attn_hdim128(moba_encoder_attn_params &params, cudaStream_t stream) {
constexpr static int Headdim = 128;
constexpr static int kNWarps = kBlockM / 16 + 4;
constexpr static int kStages = 2;
using Ktraits = moba_encoder_attn_kernel_traits<Headdim, kBlockM, kBlockN, kNWarps, kStages, kMaxN, true, InputType>;
run_moba_decoder_attn<Ktraits>(params, stream);
}
template <typename T>
void DispatchMobaEncoderAttn(
const paddle::Tensor& q_input,
const paddle::Tensor& k_input,
const paddle::Tensor& v_input,
const paddle::Tensor& qk_gate_topk_idx,
const paddle::Tensor& cu_seq_q,
const paddle::Tensor& cu_seq_k,
const paddle::Tensor& cu_seq_q_pack,
const paddle::Tensor& seq_len_encoder,
const paddle::Tensor& seq_len_decoder,
const paddle::Tensor& out,
const int max_seq_q,
const int max_seq_k,
const int head_num,
const int kv_head_num,
const int head_dim,
const int batch_size,
const int max_input_length) {
constexpr int kBlockM = 128;
constexpr int kBlockN = 128;
constexpr int kMobaBlockSize = 128;
constexpr int kMaxN = 1024;
using cute_type = typename cuteType<T>::type;
moba_encoder_attn_params params;
memset(&params, 0, sizeof(moba_encoder_attn_params));
params.q_ptr = reinterpret_cast<cute_type*>(const_cast<T*>(q_input.data<T>()));
params.k_ptr = reinterpret_cast<cute_type*>(const_cast<T*>(k_input.data<T>()));
params.v_ptr = reinterpret_cast<cute_type*>(const_cast<T*>(v_input.data<T>()));
params.o_ptr = reinterpret_cast<cute_type*>(const_cast<T*>(out.data<T>()));
params.cu_seq_q = const_cast<int*>(cu_seq_q.data<int>());
params.cu_seq_k = const_cast<int*>(cu_seq_k.data<int>());
params.head_num = head_num;
params.kv_head_num = kv_head_num;
params.max_seq_q = max_seq_q;
params.max_seq_k = max_seq_k;
params.batch_size = batch_size;
params.gqa_group_size = head_num / kv_head_num;
constexpr float kLog2e = 1.4426950408889634074;
params.scale_softmax_log2 = 1.0f / std::sqrt(head_dim) * kLog2e;
params.qk_gate_topk_idx = const_cast<int*>(qk_gate_topk_idx.data<int>());
params.seq_len_encoder = const_cast<int*>(seq_len_encoder.data<int>());
params.cu_seq_q_pack = const_cast<int*>(cu_seq_q_pack.data<int>());
run_moba_encoder_attn_hdim128<kBlockM, kBlockN, kMaxN, cute_type>(params, out.stream());
}
void MobaEncoderAttn(
const paddle::Tensor& q_input,
const paddle::Tensor& k_input,
const paddle::Tensor& v_input,
const paddle::Tensor& qk_gate_topk_idx,
const paddle::Tensor& cu_seq_q,
const paddle::Tensor& cu_seq_k,
const paddle::Tensor& cu_seq_q_pack,
const paddle::Tensor& seq_len_encoder,
const paddle::Tensor& seq_len_decoder,
const paddle::Tensor& out,
const int max_seq_q,
const int max_seq_k,
const int head_num,
const int kv_head_num,
const int head_dim,
const int max_input_length) {
const int batch_size = seq_len_encoder.dims()[0];
if (q_input.dtype() == paddle::DataType::FLOAT16) {
return
DispatchMobaEncoderAttn<phi::dtype::float16>(
q_input,
k_input,
v_input,
qk_gate_topk_idx,
cu_seq_q,
cu_seq_k,
cu_seq_q_pack,
seq_len_encoder,
seq_len_decoder,
out,
max_seq_q,
max_seq_k,
head_num,
kv_head_num,
head_dim,
batch_size,
max_input_length);
} else if (q_input.dtype() == paddle::DataType::BFLOAT16) {
return
DispatchMobaEncoderAttn<phi::dtype::bfloat16>(
q_input,
k_input,
v_input,
qk_gate_topk_idx,
cu_seq_q,
cu_seq_k,
cu_seq_q_pack,
seq_len_encoder,
seq_len_decoder,
out,
max_seq_q,
max_seq_k,
head_num,
kv_head_num,
head_dim,
batch_size,
max_input_length);
}
}
PD_BUILD_OP(moba_encoder_attn)
.Inputs({
"q_input",
"k_input",
"v_input",
"qk_gate_topk_idx",
"cu_seq_q",
"cu_seq_k",
"cu_seq_q_pack",
"seq_len_encoder",
"seq_len_decoder",
"out"})
.Attrs({
"max_seq_q: int",
"max_seq_k: int",
"head_num: int",
"kv_head_num: int",
"head_dim: int",
"max_input_length: int"})
.Outputs({"attn_out"})
.SetInplaceMap({{"out", "attn_out"}})
.SetKernelFn(PD_KERNEL(MobaEncoderAttn));

View File

@@ -1,163 +0,0 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/extension.h"
#include "moba_attn/moba_attn.h"
template <typename T, int kBlockSize, int kHeadDim>
__global__ void write_encoder_cachekv_c16(
const T * k_input,
const T * v_input,
const int * cu_seq_k,
const int * seq_len_encoder,
const int * seq_len_decoder,
T * cache_k,
T * cache_v,
const int * block_tables,
const int kv_head_num,
const int max_blocks_per_seq) {
constexpr int kPackSize = 16 / sizeof(T);
const int block_idx = blockIdx.x * kBlockSize;
int bidh = blockIdx.y;
const int bidb = blockIdx.z;
const int tidx = threadIdx.x;
const int row_idx = tidx / (kHeadDim / kPackSize);
const int col_idx = tidx % (kHeadDim / kPackSize) * kPackSize;
const int seq_len = seq_len_encoder[bidb];
if (seq_len == 0) return;
const int ramian_tokens = seq_len - block_idx;
const int32_t *block_table_now = block_tables + bidb * max_blocks_per_seq;
const uint32_t physical_block_number = block_table_now[blockIdx.x + seq_len_decoder[bidb] / kBlockSize];
if (bidh < kv_head_num) {
T * cache = cache_k + physical_block_number * kv_head_num * kBlockSize * kHeadDim + bidh * kBlockSize * kHeadDim + col_idx;
const int base_load_idx = (block_idx + cu_seq_k[bidb]) * kv_head_num * kHeadDim + bidh * kHeadDim + col_idx;
#pragma unroll
for (int i = row_idx; i < kBlockSize; i += 128 / (kHeadDim / kPackSize)) {
if (i < ramian_tokens) {
*reinterpret_cast<float4*>(cache + i * kHeadDim) = *reinterpret_cast<const float4*>(k_input + base_load_idx + i * kv_head_num * kHeadDim);
}
}
} else {
bidh -= kv_head_num;
const int base_load_idx = (block_idx + cu_seq_k[bidb]) * kv_head_num * kHeadDim + bidh * kHeadDim + col_idx;
T * cache = cache_v + physical_block_number * kv_head_num * kBlockSize * kHeadDim + bidh * kBlockSize * kHeadDim + col_idx;
#pragma unroll
for (int i = row_idx; i < kBlockSize; i += 128 / (kHeadDim / kPackSize)) {
if (i < ramian_tokens) {
*reinterpret_cast<float4*>(cache + i * kHeadDim) = *reinterpret_cast<const float4*>(v_input + base_load_idx + i * kv_head_num * kHeadDim);
}
}
}
}
void MobaEncoderAttnWriteCacheKv(
const paddle::Tensor& k_input,
const paddle::Tensor& v_input,
const paddle::Tensor& cu_seq_k,
const paddle::Tensor& seq_len_encoder,
const paddle::Tensor& seq_len_decoder,
const paddle::Tensor& cache_k,
const paddle::Tensor& cache_v,
const paddle::Tensor& block_tables,
const paddle::optional<paddle::Tensor>& cache_k_quant_scale,
const paddle::optional<paddle::Tensor>& cache_v_quant_scale,
const paddle::optional<paddle::Tensor>& cache_k_dequant_scale,
const paddle::optional<paddle::Tensor>& cache_v_dequant_scale,
const paddle::optional<paddle::Tensor>& cache_k_zero_points,
const paddle::optional<paddle::Tensor>& cache_v_zero_points,
const int head_num,
const int kv_head_num,
const int head_dim,
const int max_seq_q,
const std::string &cache_quant_type_str) {
constexpr int kThreads = 128;
constexpr int kHeadDim = 128;
assert(kHeadDim == head_dim);
constexpr int kBlockSize = 64;
const int batch_size = block_tables.dims()[0];
const int max_blocks_per_seq = block_tables.dims()[1];
if (cache_quant_type_str == "none") {
dim3 grid_dims;
grid_dims.x = (max_seq_q + kBlockSize - 1) / kBlockSize;
grid_dims.y = kv_head_num * 2;
grid_dims.z = batch_size;
if (k_input.dtype() == paddle::DataType::FLOAT16) {
using T = phi::dtype::float16;
write_encoder_cachekv_c16<T, kBlockSize, kHeadDim><<<grid_dims, kThreads, 0, k_input.stream()>>>(
const_cast<T*>(k_input.data<T>()),
const_cast<T*>(v_input.data<T>()),
cu_seq_k.data<int>(),
seq_len_encoder.data<int>(),
seq_len_decoder.data<int>(),
const_cast<T*>(cache_k.data<T>()),
const_cast<T*>(cache_v.data<T>()),
block_tables.data<int>(),
kv_head_num,
max_blocks_per_seq);
} else if (k_input.dtype() == paddle::DataType::BFLOAT16) {
using T = phi::dtype::bfloat16;
write_encoder_cachekv_c16<T, kBlockSize, kHeadDim><<<grid_dims, kThreads, 0, k_input.stream()>>>(
const_cast<T*>(k_input.data<T>()),
const_cast<T*>(v_input.data<T>()),
cu_seq_k.data<int>(),
seq_len_encoder.data<int>(),
seq_len_decoder.data<int>(),
const_cast<T*>(cache_k.data<T>()),
const_cast<T*>(cache_v.data<T>()),
block_tables.data<int>(),
kv_head_num,
max_blocks_per_seq);
}
} else {
PADDLE_THROW(phi::errors::Unimplemented(
"Quantized cache not implemented for cache_quant_type = %s", cache_quant_type_str.c_str()));
}
}
PD_BUILD_OP(moba_encoder_attn_write_cache_kv)
.Inputs({
"k_input",
"v_input",
"cu_seq_k",
"seq_len_encoder",
"seq_len_decoder",
"cache_k",
"cache_v",
"block_tables",
paddle::Optional("cache_k_quant_scale"),
paddle::Optional("cache_v_quant_scale"),
paddle::Optional("cache_k_dequant_scale"),
paddle::Optional("cache_v_dequant_scale"),
paddle::Optional("cache_k_zero_points"),
paddle::Optional("cache_v_zero_points")})
.Attrs({
"head_num: int",
"kv_head_num: int",
"head_dim: int",
"max_seq_q: int",
"cache_quant_type_str: std::string"})
.Outputs({"cache_k_out", "cache_v_out"})
.SetInplaceMap({{"cache_k", "cache_k_out"},
{"cache_v", "cache_v_out"}})
.SetKernelFn(PD_KERNEL(MobaEncoderAttnWriteCacheKv));

View File

@@ -1,341 +0,0 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/extension.h"
#include "moba_attn/moba_attn_utils.hpp"
template <typename T, int knthreads, int moba_block_size, int kBlockM, int kBlockMaxN, int searchtimes>
__global__ void qk_gate_sort_encoder_kernel(
const T* qk_gate_weight,
int * qk_gate_topk_idx,
const int *seq_len_encoder,
const int *seq_len_decoder,
const int* cu_seq_q,
const int* cu_seq_k,
const int* cu_seq_q_pack,
const int use_moba_seq_limit,
const int max_seq_q,
const int max_seq_k,
const int head_num,
const int kv_head_num,
const int kGqaGroupSize,
const int top_k_left,
const int top_k_right) {
const int bidt = blockIdx.x * kBlockM;
const int bidh = blockIdx.y;
const int bidb = blockIdx.z;
const int tidx = threadIdx.x;
constexpr int kPackSize = kBlockMaxN / knthreads;
static_assert(kBlockMaxN % knthreads == 0);
const int seq_len_q = seq_len_encoder[bidb];
if (seq_len_q == 0 || bidt >= seq_len_q) {
return;
}
const int seq_len_k = (bidt + kBlockM + seq_len_decoder[bidb]);
const int seq_len_moba = seq_len_k / moba_block_size;
using SrcType = Vec<T, kPackSize>;
using SrcType_f = Vec<float, kPackSize>;
using SrcType_i = Vec<int, kPackSize>;
SrcType src;
SrcType_f src_f;
SrcType_i select_idx;
select_idx.set_zero();
const int store_idx = cu_seq_q_pack[bidb] / kBlockM * head_num * kBlockMaxN + bidh * kBlockMaxN + blockIdx.x * head_num * kBlockMaxN + tidx * kPackSize;
if (seq_len_k < use_moba_seq_limit) {
#pragma unroll
for (int i = 0; i < kPackSize; i++) {
select_idx.data.elt[i] = 1;
}
select_idx.store_to(qk_gate_topk_idx + store_idx);
return;
}
const int load_offset = (cu_seq_q[bidb] + bidt) * head_num * kBlockMaxN + bidh * kBlockMaxN + tidx * kPackSize;
const int data_len = seq_len_moba - tidx * kPackSize;
#pragma unroll
for (int t = 0; t < kBlockM; t++) {
if (bidt + t >= seq_len_q) {
break;
}
src.load_from(qk_gate_weight + load_offset + t * head_num * kBlockMaxN);
float min_global = FLT_MAX;
float max_global = -FLT_MAX;
#pragma unroll
for (int i = 0; i < kPackSize; i++) {
if (i < data_len) {
src_f.data.elt[i] = float(src.data.elt[i]);
min_global = min(min_global, src_f.data.elt[i]);
} else {
src_f.data.elt[i] = -FLT_MAX;
}
max_global = max(max_global, src_f.data.elt[i]);
}
max_global = BlockAllReduce<float, MaxOp<float>, knthreads>(max_global);
min_global = BlockAllReduce<float, MinOp<float>, knthreads>(min_global);
float right_limit = max_global;
float left_limit = min_global;
float mid_limit;
int count;
if (right_limit == left_limit) {
mid_limit = (left_limit + right_limit) * 0.5f;
} else {
#pragma unroll
for (int i = 0; i < searchtimes; i++) {
mid_limit = (left_limit + right_limit) * 0.5f;
count = get_data_count<kPackSize, knthreads>(src_f.data.elt, mid_limit);
if (count < top_k_left) {
right_limit = mid_limit;
} else if (count > top_k_right) {
left_limit = mid_limit;
} else {
break;
}
}
}
#pragma unroll
for (int i = 0; i < kPackSize; i++) {
if (src_f.data.elt[i] >= mid_limit) {
select_idx.data.elt[i] = 1;
}
}
}
if (tidx == 0) {
select_idx.data.elt[0] = 1;
}
__align__(16) __shared__ int qk_gate_mem[kBlockMaxN];
__align__(16) __shared__ int qk_continue_idx_mem[kBlockMaxN];
select_idx.store_to(qk_gate_mem + tidx * kPackSize);
__syncthreads();
if (tidx == 0) {
int cur_idx = 0;
int idx = -1;
const int last_idx = seq_len_moba - 1;
while (last_idx + idx >= 0 && qk_gate_mem[last_idx + idx] == 0) {
idx--;
}
qk_continue_idx_mem[cur_idx] = -idx;
cur_idx++;
for (int i = last_idx - 1; i >= 0; --i) {
if (qk_gate_mem[i] == 1) {
int idx = -1;
while (i + idx >= 0 && qk_gate_mem[i + idx] == 0) {
idx--;
}
qk_continue_idx_mem[cur_idx] = -idx;
cur_idx++;
}
}
qk_continue_idx_mem[cur_idx] = 10000000;
}
__syncthreads();
*reinterpret_cast<SrcType_i *>(qk_gate_topk_idx + store_idx) = reinterpret_cast<SrcType_i *>(qk_continue_idx_mem)[tidx];
}
template <int kBlockM, int kMaxN, int moba_block_size, typename T>
void qk_gate_sort_encoder(
const T* qk_gate_weight,
int * qk_gate_topk_idx,
const int *seq_len_encoder,
const int *seq_len_decoder,
const int* cu_seq_q,
const int* cu_seq_k,
const int* cu_seq_q_pack,
const int use_moba_seq_limit,
const int max_seq_q,
const int max_seq_k,
const int head_num,
const int kv_head_num,
const int batch_size,
const int top_k_left,
const int top_k_right,
cudaStream_t stream) {
constexpr int kPackSize = 16 / sizeof(T);
const int gqa_group_size = head_num / kv_head_num;
const int knthreads = kMaxN / kPackSize;
const int searchtimes = 6;
dim3 grid_dims;
grid_dims.x = (max_seq_q + kBlockM - 1) / kBlockM;
grid_dims.y = head_num;
grid_dims.z = batch_size;
constexpr auto kernel = qk_gate_sort_encoder_kernel<T, knthreads, moba_block_size, kBlockM, kMaxN, searchtimes>;
kernel<<<grid_dims, knthreads, 0, stream>>>(
qk_gate_weight,
qk_gate_topk_idx,
seq_len_encoder,
seq_len_decoder,
cu_seq_q,
cu_seq_k,
cu_seq_q_pack,
use_moba_seq_limit,
max_seq_q,
max_seq_k,
head_num,
kv_head_num,
gqa_group_size,
top_k_left,
top_k_right);
}
template <typename T>
std::vector<paddle::Tensor> DispatchQkSortEncoder(
const paddle::Tensor& qk_gate_weight,
const paddle::Tensor& seq_len_encoder,
const paddle::Tensor& seq_len_decoder,
const paddle::Tensor& cu_seq_q,
const paddle::Tensor& cu_seq_k,
const paddle::Tensor& cu_seq_q_pack,
const paddle::Tensor& q_pack_tokens,
const int max_seq_q,
const int max_seq_k,
const int head_num,
const int kv_head_num,
const int top_k_left,
const int top_k_right,
const int use_moba_seq_limit) {
constexpr int kBlockM = 128;
constexpr int kBlockN = 128;
constexpr int kMobaBlockSize = 128;
constexpr int kMaxN = 1024;
using cute_type = typename cuteType<T>::type;
const int batch_size = seq_len_encoder.dims()[0];
paddle::Tensor qk_gate_topk_idx = paddle::empty({q_pack_tokens.data<int>()[0] / kBlockM, head_num, kMaxN}, paddle::DataType::INT32, qk_gate_weight.place());
qk_gate_sort_encoder<kBlockM, kMaxN, kMobaBlockSize, cute_type>(
reinterpret_cast<const cute_type *>(qk_gate_weight.data<T>()),
qk_gate_topk_idx.data<int>(),
seq_len_encoder.data<int>(),
seq_len_decoder.data<int>(),
cu_seq_q.data<int>(),
cu_seq_k.data<int>(),
cu_seq_q_pack.data<int>(),
use_moba_seq_limit,
max_seq_q,
max_seq_k,
head_num,
kv_head_num,
batch_size,
top_k_left,
top_k_right,
qk_gate_weight.stream());
return {qk_gate_topk_idx};
}
std::vector<paddle::Tensor> QkSortEncoder(
const paddle::Tensor& qk_gate_weight,
const paddle::Tensor& seq_len_encoder,
const paddle::Tensor& seq_len_decoder,
const paddle::Tensor& cu_seq_q,
const paddle::Tensor& cu_seq_k,
const paddle::Tensor& cu_seq_q_pack,
const paddle::Tensor& q_pack_tokens,
const int max_seq_q,
const int max_seq_k,
const int head_num,
const int kv_head_num,
const int top_k_left,
const int top_k_right,
const int use_moba_seq_limit) {
if (qk_gate_weight.dtype() == paddle::DataType::FLOAT16) {
return std::move(
DispatchQkSortEncoder<phi::dtype::float16>(
qk_gate_weight,
seq_len_encoder,
seq_len_decoder,
cu_seq_q,
cu_seq_k,
cu_seq_q_pack,
q_pack_tokens,
max_seq_q,
max_seq_k,
head_num,
kv_head_num,
top_k_left,
top_k_right,
use_moba_seq_limit
)
);
} else if (qk_gate_weight.dtype() == paddle::DataType::BFLOAT16) {
return std::move(
DispatchQkSortEncoder<phi::dtype::bfloat16>(
qk_gate_weight,
seq_len_encoder,
seq_len_decoder,
cu_seq_q,
cu_seq_k,
cu_seq_q_pack,
q_pack_tokens,
max_seq_q,
max_seq_k,
head_num,
kv_head_num,
top_k_left,
top_k_right,
use_moba_seq_limit
)
);
}
}
PD_BUILD_OP(moba_qk_sort_encoder)
.Inputs({
"qk_gate_weight",
"seq_len_encoder",
"seq_len_decoder",
"cu_seq_q",
"cu_seq_k",
"cu_seq_q_pack",
"q_pack_tokens"})
.Attrs({
"max_seq_q: int",
"max_seq_k: int",
"head_num: int",
"kv_head_num: int",
"top_k_left: int",
"top_k_right: int",
"use_moba_seq_limit: int"})
.Outputs({"qk_gate_topk_idx"})
.SetKernelFn(PD_KERNEL(QkSortEncoder));

View File

@@ -1,194 +0,0 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <cmath>
#include <cute/tensor.hpp>
#include <cutlass/numeric_types.h>
#include "../moba_attn_utils.hpp"
using namespace cute;
template<bool zero_init=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename Operator>
__device__ __forceinline__ void thread_reduce_(Tensor<Engine0, Layout0> const &tensor, Tensor<Engine1, Layout1> &summary, Operator &op) {
static_assert(Layout0::rank == 2, "Only support 2D Tensor");
static_assert(Layout1::rank == 1, "Only support 1D Tensor");
CUTE_STATIC_ASSERT_V(size<0>(summary) == size<0>(tensor));
#pragma unroll
for (int mi = 0; mi < size<0>(tensor); mi++) {
summary(mi) = zero_init ? tensor(mi, 0) : op(summary(mi), tensor(mi, 0));
#pragma unroll
for (int ni = 1; ni < size<1>(tensor); ni++) {
summary(mi) = op(summary(mi), tensor(mi, ni));
}
}
}
template<typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename Operator>
__device__ __forceinline__ void quad_allreduce_(Tensor<Engine0, Layout0> &dst, Tensor<Engine1, Layout1> &src, Operator &op) {
CUTE_STATIC_ASSERT_V(size(dst) == size(src));
#pragma unroll
for (int i = 0; i < size(dst); i++){
dst(i) = Allreduce<4>::run(src(i), op);
}
}
template<bool zero_init=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename Operator>
__device__ __forceinline__ void reduce_(Tensor<Engine0, Layout0> const& tensor, Tensor<Engine1, Layout1> &summary, Operator &op) {
thread_reduce_<zero_init>(tensor, summary, op);
quad_allreduce_(summary, summary, op);
}
template<bool zero_init=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1>
__device__ __forceinline__ void reduce_max(Tensor<Engine0, Layout0> const& tensor, Tensor<Engine1, Layout1> &max){
MaxOp<float> max_op;
reduce_<zero_init>(tensor, max, max_op);
}
template<bool zero_init=true, bool warp_reduce=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1>
__device__ __forceinline__ void reduce_sum(Tensor<Engine0, Layout0> const& tensor, Tensor<Engine1, Layout1> &sum){
SumOp<float> sum_op;
thread_reduce_<zero_init>(tensor, sum, sum_op);
if constexpr (warp_reduce) { quad_allreduce_(sum, sum, sum_op); }
}
__forceinline__ __device__ __half2 half_exp(__half2 x) {
uint32_t tmp_out, tmp_in;
tmp_in = reinterpret_cast<uint32_t&>(x);
asm ("ex2.approx.f16x2 %0, %1;\n"
: "=r"(tmp_out)
: "r"(tmp_in));
__half2 out = reinterpret_cast<__half2&>(tmp_out);
return out;
}
// Apply the exp to all the elements.
template <bool zero_init=false, typename Engine0, typename Layout0, typename Engine1, typename Layout1>
__forceinline__ __device__ void max_scale_exp2_sum(Tensor<Engine0, Layout0> &tensor, Tensor<Engine1, Layout1> &max, Tensor<Engine1, Layout1> &sum, const float scale) {
static_assert(Layout0::rank == 2, "Only support 2D Tensor"); static_assert(Layout1::rank == 1, "Only support 1D Tensor"); CUTE_STATIC_ASSERT_V(size<0>(max) == size<0>(tensor));
#pragma unroll
for (int mi = 0; mi < size<0>(tensor); ++mi) {
MaxOp<float> max_op;
max(mi) = zero_init ? tensor(mi, 0) : max_op(max(mi), tensor(mi, 0));
#pragma unroll
for (int ni = 1; ni < size<1>(tensor); ni++) {
max(mi) = max_op(max(mi), tensor(mi, ni));
}
max(mi) = Allreduce<4>::run(max(mi), max_op);
const float max_scaled = max(mi) == -INFINITY ? 0.f : max(mi) * scale;
sum(mi) = 0;
#pragma unroll
for (int ni = 0; ni < size<1>(tensor); ++ni) {
tensor(mi, ni) = exp2f(tensor(mi, ni) * scale - max_scaled);
sum(mi) += tensor(mi, ni);
}
}
}
template <typename Engine0, typename Layout0, typename Engine1, typename Layout1>
__forceinline__ __device__ void scale_apply_exp2(Tensor<Engine0, Layout0> &tensor, Tensor<Engine1, Layout1> const &max, const float scale) {
static_assert(Layout0::rank == 2, "Only support 2D Tensor");
static_assert(Layout1::rank == 1, "Only support 1D Tensor");
CUTE_STATIC_ASSERT_V(size<0>(max) == size<0>(tensor));
#pragma unroll
for (int mi = 0; mi < size<0>(tensor); ++mi) {
const float max_scaled = max(mi) * scale;
#pragma unroll
for (int ni = 0; ni < size<1>(tensor); ++ni) {
tensor(mi, ni) = exp2f(tensor(mi, ni) * scale - max_scaled);
}
}
}
template <int kNRows>
struct Softmax {
using TensorT = decltype(make_tensor<float>(Shape<Int<kNRows>>{}));
TensorT row_max, row_sum;
CUTLASS_DEVICE Softmax() {};
template<bool Is_first, bool Check_inf=false, typename Tensor0>
__forceinline__ __device__ TensorT max(Tensor0 &acc_s, float softmax_scale_log2) {
Tensor scores = make_tensor(acc_s.data(), convert_layout_acc_rowcol(acc_s.layout()));
static_assert(decltype(size<0>(scores))::value == kNRows);
TensorT scores_scale;
if constexpr (Is_first) {
reduce_max</*zero_init=*/true>(scores, row_max);
cute::fill(scores_scale, 1.f);
} else {
Tensor scores_max_prev = make_fragment_like(row_max);
cute::copy(row_max, scores_max_prev);
reduce_max</*zero_init=*/false>(scores, row_max);
#pragma unroll
for (int mi = 0; mi < size(row_max); ++mi) {
float scores_max_cur = row_max(mi);
scores_scale(mi) = exp2f((scores_max_prev(mi) - scores_max_cur) * softmax_scale_log2);
row_sum(mi) *= scores_scale(mi);
}
}
return scores_scale;
};
template<bool Is_first, typename Tensor0>
__forceinline__ __device__ TensorT online_softmax(Tensor0 &acc_s, float softmax_scale_log2) {
Tensor scores = make_tensor(acc_s.data(), convert_layout_acc_rowcol(acc_s.layout()));
static_assert(decltype(size<0>(scores))::value == kNRows);
TensorT scores_scale;
if constexpr (Is_first) {
reduce_max</*zero_init=*/true>(scores, row_max);
scale_apply_exp2(scores, row_max, softmax_scale_log2);
reduce_sum</*zero_init=*/true, /*warp_reduce=*/false>(scores, row_sum);
cute::fill(scores_scale, 1.f);
} else {
scale_apply_exp2(scores, row_max, softmax_scale_log2);
reduce_sum</*zero_init=*/false, /*warp_reduce=*/false>(scores, row_sum);
}
return scores_scale;
};
__forceinline__ __device__ TensorT finalize(float softmax_scale_log2) {
SumOp<float> sum_op;
quad_allreduce_(row_sum, row_sum, sum_op);
TensorT scores_scale;
#pragma unroll
for (int mi = 0; mi < size(row_max); ++mi) {
float sum = row_sum(mi);
float inv_sum = 1.0f / sum;
row_sum(mi) = row_max(mi) * (softmax_scale_log2 * float(M_LN2)) + __logf(sum);
scores_scale(mi) = inv_sum;
}
return scores_scale;
};
template<typename Tensor1>
__forceinline__ __device__ void rescale_o(Tensor1 &acc_o, TensorT const &scores_scale) {
Tensor acc_o_rowcol = make_tensor(acc_o.data(), convert_layout_acc_rowcol(acc_o.layout()));
static_assert(decltype(size<0>(acc_o_rowcol))::value == kNRows);
#pragma unroll
for (int mi = 0; mi < size(row_max); ++mi) {
#pragma unroll
for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) {
acc_o_rowcol(mi, ni) *= scores_scale(mi);
}
}
};
};

View File

@@ -1,288 +0,0 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/extension.h"
#include "moba_attn/moba_attn_utils.hpp"
#include "moba_attn/moba_attn.h"
template <typename T, int kBlockSize, int kHeadDim>
__global__ void get_kv_from_cache_c16_kernel(
T * k_input,
T * v_input,
const int * seq_len_encoder,
const int * seq_len_decoder,
const int * cu_seq_k,
const T * cache_k,
const T * cache_v,
const int * block_tables,
const int kv_head_num,
const int head_dim,
const int batch_size,
const int max_input_length,
const int max_blocks_per_seq) {
const int block_idx = blockIdx.x;
int bidh = blockIdx.y;
const int bidb = blockIdx.z;
const int seq_len = seq_len_decoder[bidb] + seq_len_encoder[bidb];
const int tidx = threadIdx.x;
const int base_token_idx = block_idx * kBlockSize;
if (base_token_idx >= seq_len || seq_len_encoder[bidb] == 0) {
return;
}
constexpr int kPackSize = 16 / sizeof(T);
const int row_idx = tidx / (kHeadDim / kPackSize);
const int col_idx = tidx % (kHeadDim / kPackSize) * kPackSize;
const int physical_block_number = block_tables[bidb * max_blocks_per_seq + block_idx];
const int ramian_tokens = seq_len - base_token_idx;
if (bidh < kv_head_num) {
const int cache_offset = physical_block_number * kv_head_num * kBlockSize * kHeadDim + bidh * kBlockSize * kHeadDim + col_idx;
const int base_store_idx = (base_token_idx + cu_seq_k[bidb]) * kv_head_num * kHeadDim + bidh * kHeadDim + col_idx;
#pragma unroll
for (int i = row_idx; i < kBlockSize; i += 128 / (kHeadDim / kPackSize)) {
if (i < ramian_tokens) {
*reinterpret_cast<float4*>(k_input + base_store_idx + i * kv_head_num * kHeadDim) = *reinterpret_cast<const float4*>(cache_k + cache_offset + i * kHeadDim);
}
}
} else {
bidh -= kv_head_num;
const int cache_offset = physical_block_number * kv_head_num * kBlockSize * kHeadDim + bidh * kBlockSize * kHeadDim + col_idx;
const int base_store_idx = (base_token_idx + cu_seq_k[bidb]) * kv_head_num * kHeadDim + bidh * kHeadDim + col_idx;
#pragma unroll
for (int i = row_idx; i < kBlockSize; i += 128 / (kHeadDim / kPackSize)) {
if (i < ramian_tokens) {
*reinterpret_cast<float4*>(v_input + base_store_idx + i * kv_head_num * kHeadDim) = *reinterpret_cast<const float4*>(cache_v + cache_offset + i * kHeadDim);
}
}
}
}
template <typename T>
void get_kv_from_cache(
T * k_input,
T * v_input,
const int * seq_len_encoder,
const int * seq_len_decoder,
const int * cu_seq_k,
const void * cache_k,
const void * cache_v,
const int * block_tables,
const T * cache_k_dequant_scale,
const T * cache_v_dequant_scale,
const T * cache_k_zero_points,
const T * cache_v_zero_points,
const int kv_head_num,
const int head_dim,
const int max_seq_k,
const int batch_size,
const int max_input_length,
const int max_blocks_per_seq,
const std::string &cache_quant_type_str,
cudaStream_t stream) {
constexpr int kThreads = 128;
constexpr int kHeadDim = 128;
assert(kHeadDim == head_dim);
constexpr int kBlockSize = 64;
if (cache_quant_type_str == "none") {
dim3 grid_dims;
grid_dims.x = (max_seq_k + kBlockSize - 1) / kBlockSize;
grid_dims.y = kv_head_num * 2;
grid_dims.z = batch_size;
get_kv_from_cache_c16_kernel<T, kBlockSize, kHeadDim><<<grid_dims, kThreads, 0, stream>>>(
k_input,
v_input,
seq_len_encoder,
seq_len_decoder,
cu_seq_k,
reinterpret_cast<const T*>(cache_k),
reinterpret_cast<const T*>(cache_v),
block_tables,
kv_head_num,
head_dim,
batch_size,
max_input_length,
max_blocks_per_seq);
} else {
PD_THROW("Only supported cache_quant_type_str in ['none'].");
}
}
void GetKVFromCache(
const paddle::Tensor& k_input,
const paddle::Tensor& v_input,
const paddle::Tensor& cu_seq_k,
const paddle::Tensor& seq_len_encoder,
const paddle::Tensor& seq_len_decoder,
const paddle::Tensor& cache_k,
const paddle::Tensor& cache_v,
const paddle::Tensor& block_tables,
const paddle::optional<paddle::Tensor>& cache_k_dequant_scale,
const paddle::optional<paddle::Tensor>& cache_v_dequant_scale,
const paddle::optional<paddle::Tensor>& cache_k_zero_points,
const paddle::optional<paddle::Tensor>& cache_v_zero_points,
const int head_num,
const int kv_head_num,
const int head_dim,
const int max_input_length,
const int max_seq_k,
const std::string &cache_quant_type_str) {
if (k_input.dtype() == paddle::DataType::FLOAT16) {
using T = phi::dtype::float16;
using cute_type = typename cuteType<T>::type;
get_kv_from_cache<cute_type>(
reinterpret_cast<cute_type*>(const_cast<T*>(k_input.data<T>())),
reinterpret_cast<cute_type*>(const_cast<T*>(v_input.data<T>())),
seq_len_encoder.data<int>(),
seq_len_decoder.data<int>(),
cu_seq_k.data<int>(),
cache_k.data(),
cache_v.data(),
block_tables.data<int>(),
cache_k_dequant_scale ? reinterpret_cast<cute_type*>(const_cast<T*>(cache_k_dequant_scale.get().data<T>())) : nullptr,
cache_v_dequant_scale ? reinterpret_cast<cute_type*>(const_cast<T*>(cache_v_dequant_scale.get().data<T>())) : nullptr,
cache_k_zero_points ? reinterpret_cast<cute_type*>(const_cast<T*>(cache_k_zero_points.get().data<T>())) : nullptr,
cache_v_zero_points ? reinterpret_cast<cute_type*>(const_cast<T*>(cache_v_zero_points.get().data<T>())) : nullptr,
kv_head_num,
head_dim,
max_seq_k,
seq_len_encoder.dims()[0],
max_input_length,
block_tables.dims()[1],
cache_quant_type_str,
k_input.stream());
} else if (k_input.dtype() == paddle::DataType::BFLOAT16) {
using T = phi::dtype::bfloat16;
using cute_type = typename cuteType<T>::type;
get_kv_from_cache<cute_type>(
reinterpret_cast<cute_type*>(const_cast<T*>(k_input.data<T>())),
reinterpret_cast<cute_type*>(const_cast<T*>(v_input.data<T>())),
seq_len_encoder.data<int>(),
seq_len_decoder.data<int>(),
cu_seq_k.data<int>(),
cache_k.data(),
cache_v.data(),
block_tables.data<int>(),
cache_k_dequant_scale ? reinterpret_cast<cute_type*>(const_cast<T*>(cache_k_dequant_scale.get().data<T>())) : nullptr,
cache_v_dequant_scale ? reinterpret_cast<cute_type*>(const_cast<T*>(cache_v_dequant_scale.get().data<T>())) : nullptr,
cache_k_zero_points ? reinterpret_cast<cute_type*>(const_cast<T*>(cache_k_zero_points.get().data<T>())) : nullptr,
cache_v_zero_points ? reinterpret_cast<cute_type*>(const_cast<T*>(cache_v_zero_points.get().data<T>())) : nullptr,
kv_head_num,
head_dim,
max_seq_k,
seq_len_encoder.dims()[0],
max_input_length,
block_tables.dims()[1],
cache_quant_type_str,
k_input.stream());
}
}
__global__ void get_cur_cu_seq_len_k_kernel(
const int* __restrict__ seq_lens_encoder,
const int* __restrict__ seq_lens_decoder,
const int* __restrict__ seq_lens_this_time,
int* __restrict__ cu_seqlens_k,
int* __restrict__ cu_seq_q_pack,
int* __restrict__ q_pack_tokens,
const int pack_size,
const int bsz) {
int total_tokens = 0;
cu_seqlens_k[0] = 0;
cu_seq_q_pack[0] = 0;
for (uint32_t bid = 0; bid < bsz; bid++) {
int cache_len = seq_lens_decoder[bid];
const int q_len = seq_lens_encoder[bid];
if (q_len <= 0) {
cache_len = 0;
}
total_tokens += (cache_len + q_len);
cu_seqlens_k[bid + 1] = total_tokens;
cu_seq_q_pack[bid + 1] = cu_seq_q_pack[bid] + (q_len + pack_size -1) / pack_size * pack_size;
}
q_pack_tokens[0] = cu_seq_q_pack[bsz];
}
std::vector<paddle::Tensor> GetCurCuSeqLenk(
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& seq_lens_this_time,
const int pack_size) {
auto stream = seq_lens_decoder.stream();
auto place = seq_lens_decoder.place();
int bsz = seq_lens_this_time.shape()[0];
paddle::Tensor cu_seq_q_pack = paddle::empty({bsz + 1}, paddle::DataType::INT32, place);
paddle::Tensor cu_seqlens_k = paddle::empty({bsz + 1}, paddle::DataType::INT32, place);
paddle::Tensor q_pack_tokens = paddle::empty({1}, paddle::DataType::INT32, place);
get_cur_cu_seq_len_k_kernel<<<1, 1, 0, stream>>>(
seq_lens_encoder.data<int>(),
seq_lens_decoder.data<int>(),
seq_lens_this_time.data<int>(),
cu_seqlens_k.data<int>(),
cu_seq_q_pack.data<int>(),
q_pack_tokens.data<int>(),
pack_size,
bsz
);
auto q_pack_tokens_cpu = q_pack_tokens.copy_to(paddle::CPUPlace(), true);
return {cu_seq_q_pack, cu_seqlens_k, q_pack_tokens_cpu};
}
PD_BUILD_OP(get_kv_from_cache)
.Inputs({
"k_input",
"v_input",
"cu_seq_k",
"seq_len_encoder",
"seq_len_decoder",
"cache_k",
"cache_v",
"block_tables",
paddle::Optional("cache_k_dequant_scale"),
paddle::Optional("cache_v_dequant_scale"),
paddle::Optional("cache_k_zero_points"),
paddle::Optional("cache_v_zero_points")})
.Attrs({
"head_num: int",
"kv_head_num: int",
"head_dim: int",
"max_input_length: int",
"max_seq_k: int",
"cache_quant_type_str: std::string"})
.Outputs({"k_input_out", "v_input_out"})
.SetInplaceMap({{"k_input", "k_input_out"},
{"v_input", "v_input_out"}})
.SetKernelFn(PD_KERNEL(GetKVFromCache));
PD_BUILD_OP(get_cur_cu_seq_len_k)
.Inputs({
"seq_lens_encoder",
"seq_lens_decoder",
"seq_lens_this_time"})
.Attrs({
"pack_size: int"})
.Outputs({"cu_seq_q_pack", "cu_seqlens_k", "q_pack_tokens"})
.SetKernelFn(PD_KERNEL(GetCurCuSeqLenk));

View File

@@ -1,221 +0,0 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/extension.h"
#include "moba_attn/moba_attn_utils.hpp"
#include "moba_attn/moba_attn.h"
template <typename T, int moba_block_size, int kHeadDim, int kMaxN>
__global__ void moba_mlp_einsum_kernel(
const T * src_data,
const T * weight_data,
const int * seq_lens_encoder,
const int * seq_lens_decoder,
const int * cu_seq_k,
T * dst_data,
const int head_num) {
constexpr int kPackSize = 16 / sizeof(T);
const int block_idx = blockIdx.x;
const int bidh = blockIdx.y;
const int bidb = blockIdx.z;
const int tidx = threadIdx.x;
const int lane_id = tidx % 32;
const int warp_id = tidx / 32;
__align__(16) __shared__ T local_sum_mem[128 / 32 * kHeadDim];
const int seq_len_encoder = seq_lens_encoder[bidb];
const int seq_len_decoder = seq_len_encoder + seq_lens_decoder[bidb];
const int seq_len_this_block = seq_len_decoder - block_idx * moba_block_size;
if (seq_len_encoder == 0 || seq_len_this_block <= 0) {
return;
}
using SrcType = Vec<T, kPackSize>;
constexpr int tidx_per_row = kHeadDim / kPackSize;
const int row_idx = tidx / tidx_per_row;
const int col_idx = tidx % tidx_per_row * kPackSize;
const int src_base_idx = cu_seq_k[bidb] * head_num * kHeadDim + block_idx * moba_block_size * head_num * kHeadDim + bidh * kHeadDim + row_idx * head_num * kHeadDim + col_idx;
const int weight_base_idx = bidh * kHeadDim * moba_block_size + row_idx * kHeadDim + col_idx;
constexpr int step = 128 / tidx_per_row;
SrcType sums, src, weight;
sums.set_zero();
for (int i = 0; i < moba_block_size; i += step) {
if (i >= seq_len_this_block) {
break;
}
src.load_from(src_data + src_base_idx + i * head_num * kHeadDim);
weight.load_from(weight_data + weight_base_idx + i * kHeadDim);
sums.fma(src, weight);
}
SrcType neighbor;
#pragma unroll
for (int i = 0; i < kPackSize; i+=2) {
*reinterpret_cast<int32_t*>(neighbor.data.elt + i) = __shfl_down_sync(0xffffffff, *reinterpret_cast<int32_t*>(sums.data.elt + i), 16);
}
sums.add(neighbor);
if (lane_id < 16) {
sums.store_to(local_sum_mem + warp_id * kHeadDim + lane_id * kPackSize);
}
__syncthreads();
using pack_half = std::conditional_t<std::is_same<T, phi::dtype::float16>::value, __half2, nv_bfloat162>;
pack_half * local_sum_mem_half = reinterpret_cast<pack_half*>(local_sum_mem);
if (tidx < kHeadDim / 2) {
pack_half local_sum_half = local_sum_mem_half[tidx];
#pragma unroll
for (int i = 1; i < 4; i++) {
local_sum_half += local_sum_mem_half[tidx + i * (kHeadDim / 2)];
}
local_sum_mem_half[tidx] = local_sum_half;
}
__syncthreads();
const int store_row_id = tidx / (kHeadDim / kPackSize);
const int store_col_id = tidx % (kHeadDim / kPackSize) * kPackSize;
sums.load_from(local_sum_mem + store_col_id);
const int base_store_idx = bidb * kMaxN * head_num * kHeadDim + (block_idx * (moba_block_size / 128) + store_row_id) * head_num * kHeadDim + bidh * kHeadDim + store_col_id;
if (store_row_id < moba_block_size / 128) {
sums.store_to(dst_data + base_store_idx);
}
}
template <typename T, int kHeadDim, int kMaxN>
void moba_mlp_einsum(
const T * src_data,
const T * weight_data,
const int * seq_lens_encoder,
const int * seq_lens_decoder,
const int * cu_seq_k,
T * dst_data,
const int moba_block_size,
const int max_seq_len,
const int head_num,
const int batch_size,
cudaStream_t stream) {
dim3 grid_dims;
grid_dims.x = (max_seq_len + moba_block_size - 1) / moba_block_size;
grid_dims.y = head_num;
grid_dims.z = batch_size;
if (moba_block_size == 1024) {
moba_mlp_einsum_kernel<T, 1024, kHeadDim, kMaxN><<<grid_dims, 128, 0, stream>>>(
src_data,
weight_data,
seq_lens_encoder,
seq_lens_decoder,
cu_seq_k,
dst_data,
head_num);
} else if (moba_block_size == 128) {
moba_mlp_einsum_kernel<T, 128, kHeadDim, kMaxN><<<grid_dims, 128, 0, stream>>>(
src_data,
weight_data,
seq_lens_encoder,
seq_lens_decoder,
cu_seq_k,
dst_data,
head_num);
} else {
PADDLE_THROW(phi::errors::Unimplemented(
"MobaMlpEinsum not implemented for moba_block_size = %d", moba_block_size));
}
}
std::vector<paddle::Tensor> MobaMlpEinsum(
const paddle::Tensor& k_input,
const paddle::Tensor& attn_gate_weight,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& cu_seq_k,
const int max_seq_len,
const int kv_head_num) {
const int kHeadDim = 128;
const int kMaxN = 1024;
const int moba_block_size = attn_gate_weight.dims()[1];
const int batch_size = seq_lens_encoder.dims()[0];
paddle::Tensor k_gate_weight = paddle::zeros({batch_size, kMaxN, kv_head_num, kHeadDim}, k_input.dtype(), k_input.place());
if (k_input.dtype() == paddle::DataType::FLOAT16) {
using T = phi::dtype::float16;
moba_mlp_einsum<T, kHeadDim, kMaxN>(
const_cast<T*>(k_input.data<T>()),
const_cast<T*>(attn_gate_weight.data<T>()),
const_cast<int*>(seq_lens_encoder.data<int>()),
const_cast<int*>(seq_lens_decoder.data<int>()),
const_cast<int*>(cu_seq_k.data<int>()),
k_gate_weight.data<T>(),
moba_block_size,
max_seq_len,
kv_head_num,
batch_size,
k_input.stream()
);
} else if (k_input.dtype() == paddle::DataType::BFLOAT16) {
using T = phi::dtype::bfloat16;
moba_mlp_einsum<T, kHeadDim, kMaxN>(
const_cast<T*>(k_input.data<T>()),
const_cast<T*>(attn_gate_weight.data<T>()),
const_cast<int*>(seq_lens_encoder.data<int>()),
const_cast<int*>(seq_lens_decoder.data<int>()),
const_cast<int*>(cu_seq_k.data<int>()),
k_gate_weight.data<T>(),
moba_block_size,
max_seq_len,
kv_head_num,
batch_size,
k_input.stream()
);
}
return {k_gate_weight};
}
PD_BUILD_OP(moba_mlp_einsum)
.Inputs({
"k_input",
"attn_gate_weight",
"seq_lens_encoder",
"seq_lens_decoder",
"cu_seq_k"})
.Attrs({
"max_seq_len: int",
"kv_head_num: int"})
.Outputs({"k_gate"})
.SetKernelFn(PD_KERNEL(MobaMlpEinsum));

View File

@@ -1,465 +0,0 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/extension.h"
#include "moba_attn/moba_attn_utils.hpp"
#include "moba_attn/moba_attn.h"
#include "cute/atom/mma_atom.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/cutlass.h"
#include "cutlass/layout/layout.h"
#include "cutlass/numeric_types.h"
#include "cutlass/pipeline/pipeline.hpp"
#include "cutlass/cluster_launch.hpp"
#include "cutlass/arch/reg_reconfig.h"
template <typename input_type, int kBlockM, int kBlockN, int kMobaBlockSize, int kMaxN, int kHeadDim, bool is_split_kv>
__global__ void qk_gemm_kernel(
const input_type *q_input,
const input_type *k_gate_mean,
input_type *qk_gate_weight,
const int *seq_len_encoder,
const int *seq_len_decoder,
const int *cu_seq_q,
const int *cu_seq_k,
const int use_moba_seq_limit,
const int max_seq_q,
const int max_seq_k,
const int head_num,
const int kv_head_num,
const int kGQA_groupsize) {
using TileShape_MNK = Shape<Int<kBlockM>, Int<kBlockN>, Int<kHeadDim>>;
using SmemLayoutAtomQ = decltype(
cutlass::gemm::collective::detail::ss_smem_selector<
GMMA::Major::K, input_type,
decltype(cute::get<0>(TileShape_MNK{})),
decltype(cute::get<2>(TileShape_MNK{}))>());
using SmemLayoutQ = decltype(tile_to_shape(SmemLayoutAtomQ{}, select<0, 2>(TileShape_MNK{})));
using SmemLayoutAtomK = decltype(
cutlass::gemm::collective::detail::ss_smem_selector<
GMMA::Major::K, input_type, decltype(cute::get<1>(TileShape_MNK{})),
decltype(cute::get<2>(TileShape_MNK{}))>());
using SmemLayoutK = decltype(tile_to_shape(SmemLayoutAtomK{}, select<1, 2>(TileShape_MNK{})));
using SmemLayoutAtomQK = decltype(
cutlass::gemm::collective::detail::ss_smem_selector<
GMMA::Major::K, input_type,
decltype(cute::get<0>(TileShape_MNK{})),
decltype(cute::get<1>(TileShape_MNK{}))>());
using SmemLayoutQK = decltype(tile_to_shape(SmemLayoutAtomQK{}, select<0, 1>(TileShape_MNK{})));
using MMA_Atom_Arch = std::conditional_t<
std::is_same_v<input_type, cutlass::half_t>,
MMA_Atom<SM80_16x8x16_F32F16F16F32_TN>,
MMA_Atom<SM80_16x8x16_F32BF16BF16F32_TN>
>;
using ValLayoutMNK = std::conditional_t<
is_split_kv,
Layout<Shape<_1,_4,_1>>,
Layout<Shape<_4,_1,_1>>
>;
using PermutationMNK = std::conditional_t<
is_split_kv,
Tile<_16,_64,_16>,
Tile<_64,_16,_16>
>;
using TiledMma = TiledMMA<
MMA_Atom_Arch,
ValLayoutMNK,
PermutationMNK>;
using SmemCopyAtom = Copy_Atom<SM75_U32x4_LDSM_N, input_type>;
using SmemCopyAtomQK = Copy_Atom<cute::SM90_U32x4_STSM_N, input_type>;
constexpr int kNThreads = 128;
constexpr int kThreadPerValue = 16 / sizeof(input_type);
constexpr int kThreadsPerRow = kHeadDim / kThreadPerValue;
constexpr int kThreadsPerRowQK = kBlockN / kThreadPerValue;
using GmemLayoutAtom = Layout<
Shape <Int<kNThreads / kThreadsPerRow>, Int<kThreadsPerRow>>,
Stride<Int<kThreadsPerRow>, _1>>;
using GmemTiledCopy = decltype(
make_tiled_copy(Copy_Atom<
SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>, input_type>{},
GmemLayoutAtom{},
Layout<Shape<_1, Int<kThreadPerValue>>>{}));
using GmemLayoutAtomQK = Layout<
Shape <Int<kNThreads / kThreadsPerRowQK>, Int<kThreadsPerRowQK>>,
Stride<Int<kThreadsPerRowQK>, _1>>;
using GmemTiledCopyQK = decltype(
make_tiled_copy(Copy_Atom<
UniversalCopy<cutlass::uint128_t>, input_type>{},
GmemLayoutAtomQK{},
Layout<Shape<_1, Int<kThreadPerValue>>>{}));
int mn_block = blockIdx.x;
const int bidb = blockIdx.y;
const int bidh = blockIdx.z;
const int bidh_k = bidh / kGQA_groupsize;
const int tidx = threadIdx.x;
extern __shared__ char smem_[];
const int seq_len_q = seq_len_encoder[bidb];
const int seq_len_k = seq_len_decoder[bidb];
const int seq_len_qk = seq_len_q + seq_len_k;
int q_head_stride;
const int k_head_stride = kv_head_num * kHeadDim;
int qk_head_stride;
int offset_q;
int offset_k;
int offset_qk;
int remain_q_seq;
if constexpr (is_split_kv) {
if (seq_len_k < use_moba_seq_limit || seq_len_k == 0) {
return;
}
mn_block *= kBlockN;
q_head_stride = kHeadDim;
qk_head_stride = kMaxN;
if (mn_block >= (seq_len_k + kMobaBlockSize - 1) / kMobaBlockSize) {
return;
}
offset_q = cu_seq_q[bidb] * head_num * kHeadDim + bidh * kGQA_groupsize * kHeadDim;
offset_k = (bidb * kMaxN + mn_block) * k_head_stride + bidh * kHeadDim;
offset_qk = bidb * head_num * kMaxN + bidh * kGQA_groupsize * kMaxN + mn_block;
remain_q_seq = kGQA_groupsize;
} else {
if (seq_len_q == 0 || seq_len_qk < use_moba_seq_limit) {
return;
}
q_head_stride = head_num * kHeadDim;
qk_head_stride = head_num * kMaxN;
mn_block *= kBlockM;
if (mn_block >= seq_len_q) {
return;
}
offset_q = (cu_seq_q[bidb] + mn_block) * q_head_stride + bidh * kHeadDim;
offset_k = bidb * kMaxN * k_head_stride + bidh_k * kHeadDim;
offset_qk = (cu_seq_q[bidb] + mn_block) * qk_head_stride + bidh * kMaxN;
remain_q_seq = seq_len_q - mn_block;
}
Tensor gQ = make_tensor(make_gmem_ptr(q_input + offset_q),
Shape<Int<kBlockM>, Int<kHeadDim>>{},
make_stride(q_head_stride, _1{}));
Tensor gK = make_tensor(make_gmem_ptr(k_gate_mean + offset_k),
Shape<Int<kBlockN>, Int<kHeadDim>>{},
make_stride(k_head_stride, _1{}));
Tensor gQK = make_tensor(make_gmem_ptr(qk_gate_weight + offset_qk),
Shape<Int<kBlockM>, Int<kBlockN>>{},
make_stride(qk_head_stride, _1{}));
Tensor sK = make_tensor(make_smem_ptr(reinterpret_cast<input_type *>(smem_)), SmemLayoutK{});
Tensor sQ = make_tensor(sK.data() + size(sK), SmemLayoutQ{});
Tensor sQK = make_tensor(sK.data() + size(sK), SmemLayoutQK{});
auto gmem_tiled_copy = GmemTiledCopy{};
auto gmem_tiled_copy_qk = GmemTiledCopyQK{};
auto gmem_thr_copy = gmem_tiled_copy.get_thread_slice(tidx);
auto gmem_thr_copy_qk = gmem_tiled_copy_qk.get_thread_slice(tidx);
Tensor tQgQ = gmem_thr_copy.partition_S(gQ);
Tensor tQsQ = gmem_thr_copy.partition_D(sQ);
Tensor tKgK = gmem_thr_copy.partition_S(gK);
Tensor tKsK = gmem_thr_copy.partition_D(sK);
Tensor tQKgQK = gmem_thr_copy_qk.partition_S(gQK);
Tensor tQKsQK = gmem_thr_copy_qk.partition_D(sQK);
Tensor cQ = make_identity_tensor(make_shape(kBlockM, kHeadDim));
Tensor tQcQ = gmem_thr_copy.partition_S(cQ);
Tensor cK = make_identity_tensor(make_shape(kBlockN, kHeadDim));
Tensor tKcK = gmem_thr_copy.partition_S(cK);
Tensor cQK = make_identity_tensor(make_shape(kBlockM, kBlockN));
Tensor tQKcQK = gmem_thr_copy.partition_S(cQK);
if (remain_q_seq >= kBlockM) {
copy(gmem_tiled_copy, tQgQ, tQsQ, tQcQ);
} else {
copy<false>(gmem_tiled_copy, tQgQ, tQsQ, tQcQ, remain_q_seq);
}
copy(gmem_tiled_copy, tKgK, tKsK, tKcK);
cute::cp_async_fence();
TiledMma tiled_mma;
auto thr_mma = tiled_mma.get_thread_slice(tidx);
Tensor tSrQ = thr_mma.partition_fragment_A(sQ);
Tensor tSrK = thr_mma.partition_fragment_B(sK);
Tensor acc_s = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kBlockN>>{});
auto smem_tiled_copy_Q = make_tiled_copy_A(SmemCopyAtom{}, tiled_mma);
auto smem_thr_copy_Q = smem_tiled_copy_Q.get_thread_slice(tidx);
Tensor tSsQ = smem_thr_copy_Q.partition_S(sQ);
auto smem_tiled_copy_K = make_tiled_copy_B(SmemCopyAtom{}, tiled_mma);
auto smem_thr_copy_K = make_tiled_copy_B(SmemCopyAtom{}, tiled_mma).get_thread_slice(tidx);
Tensor tSsK = smem_thr_copy_K.partition_S(sK);
auto smem_tiled_copy_QK = make_tiled_copy_C(SmemCopyAtomQK{}, tiled_mma);
auto smem_thr_copy_QK = smem_tiled_copy_QK.get_thread_slice(tidx);
Tensor tsQK = smem_thr_copy_QK.partition_D(sQK);
const int n_blocks = is_split_kv ? 1 : cute::ceil_div(cute::ceil_div(seq_len_qk, kMobaBlockSize), kBlockN);
#pragma unroll
for (int n_block = 0; n_block < n_blocks; ++n_block) {
clear(acc_s);
cp_async_wait<0>();
__syncthreads();
if (n_block == 0) {
gemm(acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_thr_copy_Q, smem_thr_copy_K, smem_tiled_copy_Q, smem_tiled_copy_K);
} else {
gemm<true>(acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_thr_copy_Q, smem_thr_copy_K, smem_tiled_copy_Q, smem_tiled_copy_K);
}
if constexpr (!is_split_kv) {
if (n_block < n_blocks - 1) {
__syncthreads();
tKgK.data() = tKgK.data() + kBlockN * k_head_stride;
copy(gmem_tiled_copy, tKgK, tKsK, tKcK);
cute::cp_async_fence();
}
}
Tensor rS = convert_type<input_type>(acc_s);
Tensor trQK = smem_thr_copy_QK.retile_S(rS);
cute::copy(smem_tiled_copy_QK, trQK, tsQK);
__syncthreads();
if (remain_q_seq >= kBlockM) {
copy(gmem_tiled_copy_qk, tQKsQK, tQKgQK, tQKcQK);
} else {
copy<false>(gmem_tiled_copy_qk, tQKsQK, tQKgQK, tQKcQK, remain_q_seq);
}
if constexpr (!is_split_kv) {
__syncthreads();
tQKgQK.data() = tQKgQK.data() + kBlockN;
}
}
}
template <typename input_type, int kBlockM, int kBlockN, int kMobaBlockSize, int kMaxN, bool is_split_kv>
void qk_gemm(
const input_type *q_input,
const input_type *k_gate_mean,
input_type *qk_gate_weight,
const int *seq_len_encoder,
const int *seq_len_decoder,
const int *cu_seq_q,
const int *cu_seq_k,
const int use_moba_seq_limit,
const int max_seq_q,
const int max_seq_k,
const int head_num,
const int kv_head_num,
const int bsz,
cudaStream_t stream) {
const int gqa_group_size = head_num / kv_head_num;
dim3 grid_dims;
const int num_m_block = (max_seq_q + kBlockM - 1) / kBlockM;
const int num_n_block = ((max_seq_k + kMobaBlockSize - 1) / kMobaBlockSize + kBlockN - 1) / kBlockN;
if (is_split_kv) {
grid_dims.x = num_n_block;
grid_dims.z = kv_head_num;
} else {
grid_dims.x = num_m_block;
grid_dims.z = head_num;
}
grid_dims.y = bsz;
constexpr int kHeadDim = 128;
constexpr int smemq = kBlockM * kHeadDim * sizeof(input_type);
constexpr int smemk = kBlockN * kHeadDim * sizeof(input_type);
constexpr int smemqk = kBlockM * kBlockN * sizeof(input_type);
const int smem_size = smemk + max(smemq, smemqk);
auto kernel = &qk_gemm_kernel<input_type, kBlockM, kBlockN, kMobaBlockSize, kMaxN, kHeadDim, is_split_kv>;
if (smem_size >= 48 * 1024) {
cudaFuncSetAttribute(
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
}
kernel<<<grid_dims, 128, smem_size, stream>>>(
q_input,
k_gate_mean,
qk_gate_weight,
seq_len_encoder,
seq_len_decoder,
cu_seq_q,
cu_seq_k,
use_moba_seq_limit,
max_seq_q,
max_seq_k,
head_num,
kv_head_num,
gqa_group_size);
}
template <typename T>
std::vector<paddle::Tensor> DispatchMobaQKGemm(
const paddle::Tensor& q_input,
const paddle::Tensor& k_block_means,
const paddle::Tensor& seq_len_encoder,
const paddle::Tensor& seq_len_decoder,
const paddle::Tensor& cu_seq_q,
const paddle::Tensor& cu_seq_k,
const int max_seq_q,
const int max_seq_k,
const int head_num,
const int kv_head_num,
const bool is_split_kv,
const int use_moba_seq_limit) {
constexpr int kMobaBlockSize = 128;
constexpr int kMaxN = 1024;
const int batch_size = seq_len_encoder.dims()[0];
using cute_type = typename cuteType<T>::type;
if (is_split_kv) {
paddle::Tensor qk_gate_weight = paddle::empty({batch_size, head_num, kMaxN}, q_input.dtype(), q_input.place());
qk_gemm<cute_type, 16, kMobaBlockSize, kMobaBlockSize, kMaxN, true>(
reinterpret_cast<const cute_type*>(q_input.data<T>()),
reinterpret_cast<const cute_type*>(k_block_means.data<T>()),
reinterpret_cast<cute_type*>(qk_gate_weight.data<T>()),
seq_len_encoder.data<int>(),
seq_len_decoder.data<int>(),
cu_seq_q.data<int>(),
cu_seq_k.data<int>(),
use_moba_seq_limit,
max_seq_q,
max_seq_k,
head_num,
kv_head_num,
batch_size,
q_input.stream()
);
return {qk_gate_weight};
} else {
constexpr int kBlockM = 128;
constexpr int kBlockN = 128;
const int token_num = q_input.dims()[0];
paddle::Tensor qk_gate_weight = paddle::empty({token_num, head_num, kMaxN}, q_input.dtype(), q_input.place());
qk_gemm<cute_type, kBlockM, kBlockN, kMobaBlockSize, kMaxN, false>(
reinterpret_cast<cute_type *>(const_cast<T*>(q_input.data<T>())),
reinterpret_cast<cute_type *>(const_cast<T*>(k_block_means.data<T>())),
reinterpret_cast<cute_type *>(qk_gate_weight.data<T>()),
seq_len_encoder.data<int>(),
seq_len_decoder.data<int>(),
cu_seq_q.data<int>(),
cu_seq_k.data<int>(),
use_moba_seq_limit,
max_seq_q,
max_seq_k,
head_num,
kv_head_num,
batch_size,
q_input.stream());
return {qk_gate_weight};
}
}
std::vector<paddle::Tensor> MobaQKGemm(
const paddle::Tensor& q_input,
const paddle::Tensor& k_block_means,
const paddle::Tensor& seq_len_encoder,
const paddle::Tensor& seq_len_decoder,
const paddle::Tensor& cu_seq_q,
const paddle::Tensor& cu_seq_k,
const int max_seq_q,
const int max_seq_k,
const int head_num,
const int kv_head_num,
const bool is_split_kv,
const int use_moba_seq_limit) {
if (q_input.dtype() == paddle::DataType::FLOAT16) {
return std::move(
DispatchMobaQKGemm<phi::dtype::float16>(
q_input,
k_block_means,
seq_len_encoder,
seq_len_decoder,
cu_seq_q,
cu_seq_k,
max_seq_q,
max_seq_k,
head_num,
kv_head_num,
is_split_kv,
use_moba_seq_limit
)
);
} else if (q_input.dtype() == paddle::DataType::BFLOAT16) {
return std::move(
DispatchMobaQKGemm<phi::dtype::bfloat16>(
q_input,
k_block_means,
seq_len_encoder,
seq_len_decoder,
cu_seq_q,
cu_seq_k,
max_seq_q,
max_seq_k,
head_num,
kv_head_num,
is_split_kv,
use_moba_seq_limit
)
);
}
}
PD_BUILD_OP(moba_qk_gemm)
.Inputs({
"q_input",
"k_block_means",
"seq_len_encoder",
"seq_len_decoder",
"cu_seq_q",
"cu_seq_k"})
.Attrs({
"max_seq_q: int",
"max_seq_k: int",
"head_num: int",
"kv_head_num: int",
"is_split_kv: bool",
"use_moba_seq_limit: int"})
.Outputs({"qk_gate_weight"})
.SetKernelFn(PD_KERNEL(MobaQKGemm));

View File

@@ -1,370 +0,0 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/extension.h"
#include "moba_attn/moba_attn_utils.hpp"
#include "moba_attn/moba_attn.h"
template <typename input_type, int moba_block_size, int kBlockM, int kMaxN, int tokens_per_block, bool need_k_mean>
__global__ void fused_block_mean_and_rope_kernel(
const input_type *qkv_input,
const input_type *qkv_bias,
input_type *k_gate_mean,
input_type *q_input,
input_type *k_input,
input_type *v_input,
const float *rope_sin_cos,
const int *seq_len_encoder,
const int *seq_len_decoder,
const int *cu_seq_q,
const int *cu_seq_k,
const int max_seq_q,
const int max_seq_k,
const int head_num,
const int kv_head_num,
const int max_input_length) {
constexpr int kPackSize = 16 / sizeof(input_type);
constexpr int kHeadDim = 128;
using src_type = Vec<input_type, kPackSize>;
using rope_type = Vec<float, kPackSize / 2>;
using pack_half = std::conditional_t<std::is_same<input_type, cutlass::half_t>::value, __half2, nv_bfloat162>;
__align__(16) __shared__ input_type local_sum_mem[128 / 32 * kHeadDim];
const int bidb = blockIdx.x;
const int bidh = blockIdx.y;
const int bidt_q = blockIdx.z * tokens_per_block;
const int bidt_v = blockIdx.z * tokens_per_block;
const int bidt_k = need_k_mean ? blockIdx.z * moba_block_size : blockIdx.z * tokens_per_block;
const int tidx = threadIdx.x;
const int lane_id = tidx % 32;
const int warp_id = tidx / 32;
const int seq_len = seq_len_encoder[bidb];
const int seq_len_start = seq_len_decoder[bidb];
if (seq_len == 0) {
return;
}
const int all_head_num = head_num + 2 * kv_head_num;
const int hidden = all_head_num * kHeadDim;
const int row_idx = tidx / (kHeadDim / kPackSize);
const int col_idx = tidx % (kHeadDim / kPackSize);
const int bias_idx = bidh * kHeadDim + col_idx * kPackSize;
src_type src, src_bias;
rope_type sin, cos;
const bool need_add_bias = qkv_bias != nullptr;
if (need_add_bias) {
src_bias.load_from(qkv_bias + bias_idx);
}
if (bidh < head_num) {
const int cur_token = bidt_q + row_idx;
const float * cos_rope = rope_sin_cos + (cur_token + seq_len_start) * (kHeadDim / 2) + col_idx * (kPackSize / 2);
const float * sin_rope = cos_rope + max_input_length * (kHeadDim / 2);
if (cur_token < seq_len) {
src.load_from(qkv_input + cu_seq_q[bidb] * hidden + bias_idx + cur_token * hidden);
if (need_add_bias) {
src.add(src_bias);
}
sin.load_from(sin_rope);
cos.load_from(cos_rope);
apply_rotary_embedding<input_type, kPackSize>(src, cos, sin);
src.store_to(q_input + (cu_seq_q[bidb] + cur_token) * head_num * kHeadDim + bias_idx);
}
} else if (bidh < head_num + kv_head_num) {
if constexpr (!need_k_mean) {
const int cur_token = bidt_k + row_idx;
const float * cos_rope = rope_sin_cos + (cur_token + seq_len_start) * (kHeadDim / 2) + col_idx * (kPackSize / 2);
const float * sin_rope = cos_rope + max_input_length * (kHeadDim / 2);
if (cur_token < seq_len) {
src.load_from(qkv_input + cu_seq_q[bidb] * hidden + bias_idx + cur_token * hidden);
if (need_add_bias) {
src.add(src_bias);
}
sin.load_from(sin_rope);
cos.load_from(cos_rope);
apply_rotary_embedding<input_type, kPackSize>(src, cos, sin);
src.store_to(k_input + (cu_seq_k[bidb] + cur_token) * head_num * kHeadDim + bias_idx- head_num * kHeadDim);
}
} else {
if (bidt_k >= seq_len) {
return;
}
src_type local_sum;
local_sum.set_zero();
const input_type* qkv = qkv_input + cu_seq_q[bidb] * hidden + bias_idx;
for (int i = 0; i < moba_block_size; i += tokens_per_block) {
const int cur_token = bidt_k + i + row_idx;
if (cur_token < seq_len) {
src.load_from(qkv + cur_token * hidden);
if (need_add_bias) {
src.add(src_bias);
}
const float * cos_rope = rope_sin_cos + (cur_token + seq_len_start) * (kHeadDim / 2) + col_idx * (kPackSize / 2);
const float * sin_rope = cos_rope + max_input_length * (kHeadDim / 2);
sin.load_from(sin_rope);
cos.load_from(cos_rope);
apply_rotary_embedding<input_type, kPackSize>(src, cos, sin);
src.store_to(k_input + (cu_seq_k[bidb] + cur_token) * kv_head_num * kHeadDim + bias_idx - head_num * kHeadDim);
local_sum.add(src);
}
}
src_type neighbor;
#pragma unroll
for (int i = 0; i < kPackSize; i+=2) {
*reinterpret_cast<int32_t*>(neighbor.data.elt + i) = __shfl_down_sync(0xffffffff, *reinterpret_cast<int32_t*>(local_sum.data.elt + i), 16);
}
local_sum.add(neighbor);
if (lane_id < 16) {
local_sum.store_to(local_sum_mem + warp_id * kHeadDim + lane_id * kPackSize);
}
__syncthreads();
pack_half * local_sum_mem_half = reinterpret_cast<pack_half*>(local_sum_mem);
pack_half local_sum_half = local_sum_mem_half[tidx];
if (tidx < kHeadDim / 2) {
#pragma unroll
for (int i = 1; i < 4; i++) {
local_sum_half += local_sum_mem_half[tidx + i * (kHeadDim / 2)];
}
float inv_tokens_sum = fdividef(1.0f, min(seq_len - bidt_k, moba_block_size));
local_sum_half *= float_2_half2<input_type>(inv_tokens_sum);
const int store_mean_idx = ((bidb * kMaxN + blockIdx.z + seq_len_start / moba_block_size) * kv_head_num * kHeadDim + (bidh - head_num) * kHeadDim) / 2 + tidx;
reinterpret_cast<pack_half*>(k_gate_mean)[store_mean_idx] = local_sum_half;
}
}
} else {
const int cur_token = bidt_v + row_idx;
if (cur_token < seq_len) {
src.load_from(qkv_input + cu_seq_q[bidb] * hidden + bias_idx + cur_token * hidden);
if (need_add_bias) {
src.add(src_bias);
}
src.store_to(v_input + (cu_seq_k[bidb] + cur_token) * kv_head_num * kHeadDim + bias_idx - (head_num + kv_head_num) * kHeadDim);
}
}
}
template <typename input_type, int moba_block_size, int kBlockM, int kMaxN>
void fused_block_mean_and_rope(
const input_type *qkv_input,
const input_type *qkv_bias,
input_type *k_gate_mean,
input_type *q_input,
input_type *k_input,
input_type *v_input,
const float *rope_sin_cos,
const int *seq_len_encoder,
const int *seq_len_decoder,
const int *cu_seq_q,
const int *cu_seq_k,
const int max_seq_q,
const int max_seq_k,
const int head_num,
const int kv_head_num,
const int bsz,
const int max_input_length,
cudaStream_t stream) {
static_assert(moba_block_size >= 64, "moba_block_size must be at least 64");
constexpr int kPackSize = 16 / sizeof(input_type);
constexpr int kHeadDim = 128;
constexpr int kThreads = 128;
constexpr int tokens_per_block = kThreads / (kHeadDim / kPackSize);
dim3 grid_dims;
grid_dims.x = bsz;
grid_dims.y = head_num + 2 * kv_head_num;
grid_dims.z = (max_seq_q + tokens_per_block - 1) / tokens_per_block;
if (k_gate_mean != nullptr) {
fused_block_mean_and_rope_kernel<input_type, moba_block_size, kBlockM, kMaxN, tokens_per_block, true>
<<<grid_dims, kThreads, 0, stream>>>(
qkv_input,
qkv_bias,
k_gate_mean,
q_input,
k_input,
v_input,
rope_sin_cos,
seq_len_encoder,
seq_len_decoder,
cu_seq_q,
cu_seq_k,
max_seq_q,
max_seq_k,
head_num,
kv_head_num,
max_input_length);
} else {
fused_block_mean_and_rope_kernel<input_type, moba_block_size, kBlockM, kMaxN, tokens_per_block, false>
<<<grid_dims, kThreads, 0, stream>>>(
qkv_input,
qkv_bias,
k_gate_mean,
q_input,
k_input,
v_input,
rope_sin_cos,
seq_len_encoder,
seq_len_decoder,
cu_seq_q,
cu_seq_k,
max_seq_q,
max_seq_k,
head_num,
kv_head_num,
max_input_length);
}
}
void FusedBlockMeanAndRope(
const paddle::Tensor& qkv_out,
const paddle::Tensor& k_block_means,
const paddle::Tensor& q_input,
const paddle::Tensor& k_input,
const paddle::Tensor& v_input,
const paddle::Tensor& rotary_embs,
const paddle::Tensor& seq_len_encoder,
const paddle::Tensor& seq_len_decoder,
const paddle::Tensor& cu_seq_q,
const paddle::Tensor& cu_seq_k,
const paddle::optional<paddle::Tensor>& qkv_bias,
const int head_num,
const int kv_head_num,
const int head_dim,
const int max_input_length,
const int max_seq_q,
const int max_seq_k,
const std::string &cache_quant_type_str) {
constexpr int kBlockM = 128;
constexpr int kBlockN = 128;
constexpr int kMobaBlockSize = 128;
constexpr int kMaxN = 1024;
if (k_input.dtype() == paddle::DataType::FLOAT16) {
using T = phi::dtype::float16;
using cute_type = typename cuteType<T>::type;
fused_block_mean_and_rope<cute_type, kMobaBlockSize, kBlockM, kMaxN>(
reinterpret_cast<cute_type *>(const_cast<T*>(qkv_out.data<T>())),
qkv_bias ? reinterpret_cast<cute_type *>(const_cast<T*>(qkv_bias.get().data<T>())) : nullptr,
reinterpret_cast<cute_type *>(const_cast<T*>(k_block_means.data<T>())),
reinterpret_cast<cute_type*>(const_cast<T*>(q_input.data<T>())),
reinterpret_cast<cute_type*>(const_cast<T*>(k_input.data<T>())),
reinterpret_cast<cute_type*>(const_cast<T*>(v_input.data<T>())),
rotary_embs.data<float>(),
seq_len_encoder.data<int>(),
seq_len_decoder.data<int>(),
cu_seq_q.data<int>(),
cu_seq_k.data<int>(),
max_seq_q,
max_seq_k,
head_num,
kv_head_num,
seq_len_encoder.dims()[0],
max_input_length,
qkv_out.stream());
} else if (k_input.dtype() == paddle::DataType::BFLOAT16) {
using T = phi::dtype::bfloat16;
using cute_type = typename cuteType<T>::type;
fused_block_mean_and_rope<cute_type, kMobaBlockSize, kBlockM, kMaxN>(
reinterpret_cast<cute_type *>(const_cast<T*>(qkv_out.data<T>())),
qkv_bias ? reinterpret_cast<cute_type *>(const_cast<T*>(qkv_bias.get().data<T>())) : nullptr,
reinterpret_cast<cute_type *>(const_cast<T*>(k_block_means.data<T>())),
reinterpret_cast<cute_type*>(const_cast<T*>(q_input.data<T>())),
reinterpret_cast<cute_type*>(const_cast<T*>(k_input.data<T>())),
reinterpret_cast<cute_type*>(const_cast<T*>(v_input.data<T>())),
rotary_embs.data<float>(),
seq_len_encoder.data<int>(),
seq_len_decoder.data<int>(),
cu_seq_q.data<int>(),
cu_seq_k.data<int>(),
max_seq_q,
max_seq_k,
head_num,
kv_head_num,
seq_len_encoder.dims()[0],
max_input_length,
qkv_out.stream());
}
}
PD_BUILD_OP(fused_block_mean_and_rope)
.Inputs({
"qkv_out",
"k_block_means",
"q_input",
"k_input",
"v_input",
"rotary_embs",
"seq_len_encoder",
"seq_len_decoder",
"cu_seq_q",
"cu_seq_k",
paddle::Optional("qkv_bias")})
.Attrs({
"head_num: int",
"kv_head_num: int",
"head_dim: int",
"max_input_length: int",
"max_seq_q: int",
"max_seq_k: int",
"cache_quant_type_str: std::string"})
.Outputs({"q_input_out", "k_input_out", "v_input_out", "k_block_means_out"})
.SetInplaceMap({{"q_input", "q_input_out"},
{"k_input", "k_input_out"},
{"v_input", "v_input_out"},
{"k_block_means", "k_block_means_out"}})
.SetKernelFn(PD_KERNEL(FusedBlockMeanAndRope));

View File

@@ -385,7 +385,6 @@ elif paddle.is_compiled_with_cuda():
"-Igpu_ops",
"-Ithird_party/nlohmann_json/include",
]
nvcc_version = get_nvcc_version()
print(f"nvcc_version = {nvcc_version}")
if nvcc_version >= 12.0:
@@ -509,10 +508,6 @@ elif paddle.is_compiled_with_cuda():
# Hopper optmized mla
sources += find_end_files("gpu_ops/mla_attn", ".cu")
sources += ["gpu_ops/flash_mask_attn/flash_mask_attn.cu"]
sources += find_end_files("gpu_ops/moba_attn/moba_decoder_attn/", ".cu")
sources += find_end_files("gpu_ops/moba_attn/moba_encoder_attn/", ".cu")
sources += find_end_files("gpu_ops/moba_attn/moba_process/", ".cu")
sources += ["gpu_ops/moba_attn/moba_attn.cu"]
os.system("python utils/auto_gen_w4afp8_gemm_kernel.py")
sources += find_end_files("gpu_ops/w4afp8_gemm", ".cu")
os.system("python utils/auto_gen_wfp8afp8_sparse_gemm_kernel.py")

View File

@@ -1,31 +0,0 @@
# moba_sparse_attention
## Introduction
We propose Lite MoBA and improve it based on MoBA. Specifically, we still draw on the MoE structure to divide KV into multiple blocks, introduce a learnable MLP layer to adaptively select important blocks. We use Full Attention's 1D Max Pooling Attention Map as Ground Truth. Then, we employ KLDivLoss to distill and train the MLP layer weights. Lite MoBA can be directly applied to post - training, where only the weights of the MLP are learnable and the weights of the original model remain unchanged.
Compared to NSA or MoBA, our Lite MoBA is more scalable and pluggable, without the need to change traditional attention architectures or interfere with model weight training in the Pre - training and Post - training stages. It only requires a small amount of training on the MLP layer in the final stage of the model to achieve almost lossless accuracy. Since MoBA updates the weights of the entire model, even when Full Attention is automatically invoked for inputs shorter than BlockSize x BlockNum, it still cannot avoid the impact of model updates on the model's effectiveness in text processing. In contrast, our pluggable Lite MoBA can achieve Full Attention that is truly equivalent to that of the original model in short text scenarios.
Compared with MoBA, in terms of effectiveness, its use of Average Pooling to represent inter - block relationships appears relatively limited and has poor handling of outlier representations. Our ablation experiments also demonstrated that the effectiveness of Average Pooling is inferior to that of the learnable MLP. In terms of training performance, since only the MLP weights need to be updated and the model weights do not need to be updated, a large amount of video memory will be saved during training (which needs to be tested). In terms of inference performance, when the input length is 128K, Block Size = 1024, and Block Num = 16, the performance is improved by 322% compared to Flash Attention 3.
## Usage
```bash
export FD_ATTENTION_BACKEND="MOBA_ATTN"
python -m fastdeploy.entrypoints.openai.api_server
--model baidu/ERNIE-4.5-300B-A47B-Paddle \
--port 8188 \
--tensor-parallel-size 4 \
--quantization wint4 \
--enable-chunked-prefill \
--max-num-batched-tokens 8192 \
--max-model-len 131072 \
--max-num-seqs 32 \
--moba-attention-config '{"moba_encoder_top_k_left": 60, "moba_encoder_top_k_right": 80, "moba_decoder_top_k_left": 100, "moba_decoder_top_k_right": 120}'
```
## Environmental Variables Description
* Setting `FD_ATTENTION_BACKEND="MOBA_ATTN"` enables MOBA sparse attention.
* `moba_encoder_top_k_left=60, moba_encoder_top_k_right=80` indicates that the range of top - k is between 80 and 100 when the encoder is sparse.
* `moba_decoder_top_k_left=100, moba_decoder_top_k_right=100` indicates that the range of top - k is between 120 and 140 when the decoder is sparse.

View File

@@ -682,67 +682,6 @@ class GraphOptimizationConfig:
argument = self.use_cudagraph
class MobaAttentionConfig:
def __init__(
self,
args,
):
self.moba_encoder_top_k_left: int = None
self.moba_encoder_top_k_right: int = None
"The sparse topk of encoder attention is located at [moba_encoder_top_k_left, moba_encoder top_k_right]"
self.moba_decoder_top_k_left: int = None
self.moba_decoder_top_k_right: int = None
"The sparse topk of decoder attention is located at [moba_decoder_top_k_left, moba_decoder top_k_right]"
self.moba_use_encoder_seq_limit: int = None
"When the number of encdoer token is less than moba_use_encoder_seq_limit, it is not sparse"
self.moba_use_decoder_seq_limit: int = None
"When the number of decdoer token is less than moba_use_decoder_seq_limit, it is not sparse"
self.moba_block_size: int = 128
self.mlp_weight_name: str = "moba_mlp_weight.safetensors"
self.moba_max_seq_length: int = 128 * 1024
if args is not None:
for key, value in args.items():
if hasattr(self, key):
setattr(self, key, value)
if self.moba_use_encoder_seq_limit is None and self.moba_encoder_top_k_left is not None:
self.moba_use_encoder_seq_limit = self.moba_encoder_top_k_left * self.moba_block_size
if self.moba_use_decoder_seq_limit is None and self.moba_decoder_top_k_left is not None:
self.moba_use_decoder_seq_limit = self.moba_decoder_top_k_left * self.moba_block_size
self.check_legality_parameters()
def check_legality_parameters(
self,
) -> None:
if self.moba_encoder_top_k_left is not None:
assert self.moba_encoder_top_k_left > 0, "moba_encoder_top_k_left must large than 0"
if self.moba_encoder_top_k_right is not None:
assert self.moba_encoder_top_k_right > 0, "moba_encoder_top_k_right must large than 0"
assert (
self.moba_encoder_top_k_right >= self.moba_encoder_top_k_left
), "moba_encoder_top_k_right must large than moba_encoder_top_k_left"
if self.moba_decoder_top_k_left is not None:
assert self.moba_decoder_top_k_left > 0, "moba_decoder_top_k_left must large than 0"
if self.moba_decoder_top_k_right is not None:
assert self.moba_decoder_top_k_right > 0, "moba_decoder_top_k_right must large than 0"
assert (
self.moba_decoder_top_k_right >= self.moba_decoder_top_k_left
), "moba_decoder_top_k_right must large than moba_decoder_top_k_left"
if self.moba_use_encoder_seq_limit is not None and self.moba_encoder_top_k_left is not None:
assert self.moba_use_encoder_seq_limit >= self.moba_encoder_top_k_left * self.moba_block_size
if self.moba_use_decoder_seq_limit is not None and self.moba_decoder_top_k_left is not None:
assert self.moba_use_decoder_seq_limit >= self.moba_decoder_top_k_left * self.moba_block_size
def to_json_string(self):
"""
Convert moba_attention_config to json string.
"""
return json.dumps({key: value for key, value in self.__dict__.items() if value is not None})
class EarlyStopConfig:
def __init__(
self,
@@ -1097,7 +1036,6 @@ class FDConfig:
decoding_config: DecodingConfig = None,
quant_config: QuantConfigBase = None,
graph_opt_config: GraphOptimizationConfig = None,
moba_attention_config: MobaAttentionConfig = None,
speculative_config: SpeculativeConfig = None,
tokenizer: str = None,
max_model_len: int = 8192,
@@ -1132,7 +1070,7 @@ class FDConfig:
self.early_stop_config: Optional[EarlyStopConfig] = early_stop_config
self.decoding_config: DecodingConfig = decoding_config # type: ignore
self.cache_config: CacheConfig = cache_config # type: ignore
self.moba_attention_config: Optional[MobaAttentionConfig] = moba_attention_config
# Initialize cuda graph capture list
if self.graph_opt_config.cudagraph_capture_sizes is None:
self.graph_opt_config._set_cudagraph_sizes(max_num_seqs=self.parallel_config.max_num_seqs)

View File

@@ -26,7 +26,6 @@ from fastdeploy.config import (
FDConfig,
GraphOptimizationConfig,
LoadConfig,
MobaAttentionConfig,
ModelConfig,
ParallelConfig,
SpeculativeConfig,
@@ -337,10 +336,6 @@ class EngineArgs:
"""
Configuration for graph optimization backend execution.
"""
moba_attention_config: Optional[Dict[str, Any]] = None
"""
Configuration for moba attention.
"""
enable_logprob: bool = False
"""
@@ -539,12 +534,6 @@ class EngineArgs:
default=EngineArgs.graph_optimization_config,
help="",
)
model_group.add_argument(
"--moba-attention-config",
type=json.loads,
default=EngineArgs.moba_attention_config,
help="",
)
model_group.add_argument(
"--guided-decoding-backend",
type=str,
@@ -940,18 +929,6 @@ class EngineArgs:
graph_optimization_args[k] = v
return GraphOptimizationConfig(graph_optimization_args)
def create_moba_attention_config(self) -> MobaAttentionConfig:
"""
Create and retuan a MobaAttentionConfig object based on the current settings.
"""
attention_args = asdict(self)
if self.moba_attention_config is not None:
for k, v in self.moba_attention_config.items():
attention_args[k] = v
return MobaAttentionConfig(attention_args)
else:
return MobaAttentionConfig(None)
def create_early_stop_config(self) -> EarlyStopConfig:
"""
Create and retuan an EarlyStopConfig object based on the current settings.
@@ -989,7 +966,6 @@ class EngineArgs:
speculative_cfg = self.create_speculative_config()
graph_opt_cfg = self.create_graph_optimization_config()
graph_opt_cfg.update_use_cudagraph(self.use_cudagraph)
moba_attention_config = self.create_moba_attention_config()
early_stop_cfg = self.create_early_stop_config()
early_stop_cfg.update_enable_early_stop(self.enable_early_stop)
@@ -1027,7 +1003,6 @@ class EngineArgs:
max_long_partial_prefills=self.max_long_partial_prefills,
long_prefill_token_threshold=self.long_prefill_token_threshold,
graph_opt_config=graph_opt_cfg,
moba_attention_config=moba_attention_config,
guided_decoding_backend=self.guided_decoding_backend,
disable_any_whitespace=self.guided_decoding_disable_any_whitespace,
early_stop_config=early_stop_cfg,

View File

@@ -464,7 +464,6 @@ class LLMEngine:
f" --load_strategy {self.cfg.load_config.load_strategy}"
f" --early_stop_config '{self.cfg.early_stop_config.to_json_string()}'"
f" --load_choices {self.cfg.load_config.load_choices}"
f" --moba_attention_config '{self.cfg.moba_attention_config.to_json_string()}'"
f" --ips {ips}"
)

View File

@@ -20,7 +20,6 @@ from .block_multihead_attn_backend import BlockAttentionBackend
from .flash_attn_backend import FlashAttentionBackend
from .iluvatar_attn_backend import IluvatarAttnBackend
from .mla_attention_backend import MLAAttentionBackend
from .moba_attention_backend import MobaAttentionBackend
from .native_paddle_backend import PaddleNativeAttnBackend
from .xpu_attn_backend import XPUAttentionBackend
@@ -35,5 +34,4 @@ __all__ = [
"IluvatarAttnBackend",
"BlockAttentionBackend",
"Attention",
"MobaAttentionBackend",
]

View File

@@ -28,11 +28,6 @@ from fastdeploy.model_executor.layers.quantization.quant_base import QuantMethod
if TYPE_CHECKING:
from fastdeploy.model_executor.forward_meta import ForwardMeta
import os
from safetensors import safe_open
from fastdeploy.model_executor.layers.utils import get_tensor
@@ -118,42 +113,6 @@ class Attention(nn.Layer):
self.k_norm_key = f"{self.prefix}.k_norm"
self.init_weight()
if fd_config.moba_attention_config is not None:
mlp_weight_path = os.path.join(
fd_config.model_config.model, fd_config.moba_attention_config.mlp_weight_name
)
self.moba_use_mlp = mlp_weight_path is not None and os.path.exists(mlp_weight_path)
moba_block_size = fd_config.moba_attention_config.moba_block_size
moba_max_seq_length = fd_config.moba_attention_config.moba_max_seq_length
if self.moba_use_mlp:
mlp_weight = {}
with safe_open(mlp_weight_path, framework="np", device="cpu") as f:
for key_name in f.keys():
weight = f.get_tensor(key_name)
weight = paddle.Tensor(weight, zero_copy=True)
weight = weight._copy_to(paddle.framework._current_expected_place(), False)
mlp_weight[key_name] = weight
if self.layer_id < fd_config.model_config.num_hidden_layers - 1:
self.attn_gate_weight = mlp_weight[
f"ernie.layers.{self.layer_id}.self_attn.attn_gate.weight"
].astype(paddle.get_default_dtype())[
fd_config.parallel_config.tensor_parallel_rank
* self.kv_num_heads : (fd_config.parallel_config.tensor_parallel_rank + 1)
* self.kv_num_heads
]
assert self.attn_gate_weight.shape[1] % moba_block_size == 0
self.cache_k_block_means = paddle.zeros(
[
fd_config.parallel_config.max_num_seqs,
moba_max_seq_length // moba_block_size,
self.kv_num_heads,
self.head_dim,
],
dtype=paddle.get_default_dtype(),
)
def init_weight(self):
self.q_norm_weight = self.create_parameter(
shape=[self.qk_head_dim],

View File

@@ -1,198 +0,0 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
from __future__ import annotations
from dataclasses import dataclass
from typing import TYPE_CHECKING
import paddle
try:
from fastdeploy.model_executor.ops.gpu import get_cur_cu_seq_len_k, moba_attention
except:
moba_attention = None
get_cur_cu_seq_len_k = None
if TYPE_CHECKING:
from fastdeploy.model_executor.forward_meta import ForwardMeta
from fastdeploy.config import FDConfig
from fastdeploy.model_executor.layers.attention.attention import Attention
from fastdeploy.model_executor.layers.attention.base_attention_backend import (
AttentionBackend,
AttentionMetadata,
)
@dataclass
class MobaAttentionMetadata(AttentionMetadata):
"""
AppendAttentionMetadata
"""
q_input: paddle.Tensor = None
k_input: paddle.Tensor = None
v_input: paddle.Tensor = None
cu_seq_q_pack: paddle.Tensor = None
cu_seqlens_k: paddle.Tensor = None
q_pack_tokens: paddle.Tensor = None
max_enc_len_this_time: int = 0
max_dec_len_this_time: int = 0
class MobaAttentionBackend(AttentionBackend):
"""
The backend class that uses paddle native attention implementation.
Which is used only for testing purpose.
"""
def __init__(
self,
fd_config: FDConfig,
kv_num_heads: int,
num_heads: int,
head_dim: int,
encoder_block_shape_q: int = -1,
decoder_block_shape_q: int = -1,
) -> None:
"""
MobaAttentionBackend __init__
"""
super().__init__()
self.attention_metadata: MobaAttentionMetadata = None
assert fd_config.moba_attention_config is not None, "moba_attention_config is None"
self.block_size = fd_config.parallel_config.block_size
self.max_seq_len = fd_config.parallel_config.max_model_len
self.max_num_seqs = fd_config.parallel_config.max_num_seqs
self.kv_num_heads = kv_num_heads
self.num_heads = num_heads
self.head_dim = fd_config.model_config.head_dim
self.num_layers: int = fd_config.model_config.num_hidden_layers
self.attn_block_m = 128
self.moba_block_size = fd_config.moba_attention_config.moba_block_size
self.moba_encoder_top_k_left = int(fd_config.moba_attention_config.moba_encoder_top_k_left)
self.moba_encoder_top_k_right = int(fd_config.moba_attention_config.moba_encoder_top_k_right)
self.moba_use_encoder_seq_limit = int(fd_config.moba_attention_config.moba_use_encoder_seq_limit)
self.moba_decoder_top_k_left = int(fd_config.moba_attention_config.moba_decoder_top_k_left)
self.moba_decoder_top_k_right = int(fd_config.moba_attention_config.moba_decoder_top_k_right)
self.moba_use_decoder_seq_limit = int(fd_config.moba_attention_config.moba_use_decoder_seq_limit)
self.moba_max_seq_length = fd_config.moba_attention_config.moba_max_seq_length
def init_attention_metadata(self, forward_meta: ForwardMeta):
"""Init the metadata for a forward pass."""
metadata = MobaAttentionMetadata()
metadata._dtype = paddle.get_default_dtype()
metadata.cu_seq_q_pack, metadata.cu_seqlens_k, metadata.q_pack_tokens = get_cur_cu_seq_len_k(
forward_meta.seq_lens_encoder,
forward_meta.seq_lens_decoder,
forward_meta.seq_lens_this_time,
int(self.attn_block_m),
)
metadata.max_enc_len_this_time = forward_meta.seq_lens_encoder.max().cpu()
metadata.max_dec_len_this_time = forward_meta.seq_lens_decoder.max().cpu()
q_token_num = int(forward_meta.cu_seqlens_q[-1])
k_token_num = int(metadata.cu_seqlens_k[-1])
metadata.q_input = paddle.zeros(
[q_token_num + self.attn_block_m, self.num_heads * self.head_dim], dtype=metadata._dtype
)
metadata.k_input = paddle.zeros(
[k_token_num + self.attn_block_m, self.kv_num_heads * self.head_dim], dtype=metadata._dtype
)
metadata.v_input = paddle.zeros(
[k_token_num + self.attn_block_m, self.kv_num_heads * self.head_dim], dtype=metadata._dtype
)
self.attention_metadata = metadata
assert self.max_seq_len <= self.moba_max_seq_length
def get_kv_cache_shape(
self,
max_num_blocks: int,
kv_cache_quant_type: str = None,
):
"""
Caculate kv cache shape
"""
if kv_cache_quant_type is not None and kv_cache_quant_type == "int4_zp":
return (
max_num_blocks,
self.kv_num_heads,
self.block_size,
self.head_dim // 2,
)
else:
return (
max_num_blocks,
self.kv_num_heads,
self.block_size,
self.head_dim,
)
def forward_mixed(
self,
q: paddle.Tensor,
k: paddle.Tensor,
v: paddle.Tensor,
qkv: paddle.Tensor,
compressed_kv: paddle.Tensor,
k_pe: paddle.Tensor,
layer: Attention,
forward_meta: ForwardMeta,
) -> paddle.Tensor:
"""
Mixed模式的前向传播
"""
attention_metadata = self.attention_metadata
out = moba_attention(
qkv,
attention_metadata.q_input,
attention_metadata.k_input,
attention_metadata.v_input,
forward_meta.cu_seqlens_q,
attention_metadata.cu_seqlens_k,
attention_metadata.cu_seq_q_pack,
attention_metadata.q_pack_tokens,
forward_meta.seq_lens_encoder,
forward_meta.seq_lens_decoder,
forward_meta.caches[2 * layer.layer_id],
forward_meta.caches[2 * layer.layer_id + 1],
forward_meta.block_tables,
forward_meta.rotary_embs,
layer.cache_k_block_means,
getattr(layer, "attn_gate_weight", None),
layer.qkv_bias,
getattr(layer, "cache_k_scale", None),
getattr(layer, "cache_v_scale", None),
getattr(layer, "cache_k_out_scale", None),
getattr(layer, "cache_v_out_scale", None),
getattr(layer, "cache_k_zp", None),
getattr(layer, "cache_v_zp", None),
self.num_heads,
self.kv_num_heads,
self.head_dim,
self.max_seq_len,
attention_metadata.max_enc_len_this_time,
attention_metadata.max_dec_len_this_time,
self.moba_encoder_top_k_left,
self.moba_encoder_top_k_right,
self.moba_use_encoder_seq_limit,
self.moba_decoder_top_k_left,
self.moba_decoder_top_k_right,
self.moba_use_decoder_seq_limit,
layer.moba_use_mlp,
getattr(layer, "cache_quant_type_str", "none"),
)[0]
return out

View File

@@ -26,7 +26,6 @@ class _Backend(enum.Enum):
MLA_ATTN = enum.auto()
FLASH_ATTN = enum.auto()
BLOCK_ATTN = enum.auto()
MOBA_ATTN = enum.auto()
class Platform:

View File

@@ -64,9 +64,6 @@ class CUDAPlatform(Platform):
elif selected_backend == _Backend.FLASH_ATTN:
logger.info("Using FLASH ATTN backend.")
return "fastdeploy.model_executor.layers.attention.FlashAttentionBackend"
elif selected_backend == _Backend.MOBA_ATTN:
logger.info("Using MOBA ATTN backend.")
return "fastdeploy.model_executor.layers.attention.MobaAttentionBackend"
else:
raise ValueError(
"Invalid attention backend you specified.\n"

View File

@@ -59,7 +59,6 @@ class RolloutModelConfig:
graph_optimization_config: str = None,
early_stop_config: str = None,
local_rank: int = 0,
moba_attention_config: str = None,
):
# Required parameters
self.model = model_name_or_path
@@ -104,7 +103,6 @@ class RolloutModelConfig:
self.local_rank = local_rank
self.early_stop_config = early_stop_config
self.ips = None
self.moba_attention_config = moba_attention_config
def __str__(self):
return "\n".join(f"{k}: {v}" for k, v in self.__dict__.items())

View File

@@ -34,7 +34,6 @@ from fastdeploy.config import (
FDConfig,
GraphOptimizationConfig,
LoadConfig,
MobaAttentionConfig,
ModelConfig,
ParallelConfig,
SpeculativeConfig,
@@ -542,12 +541,6 @@ def parse_args():
default=None,
help="Configation of Graph optimization backend.",
)
parser.add_argument(
"--moba_attention_config",
type=json.loads,
default=None,
help="Configation of moba attention.",
)
parser.add_argument(
"--guided_decoding_backend",
type=str,
@@ -653,8 +646,6 @@ def initialize_fd_config(args, ranks: int = 1, local_rank: int = 0) -> FDConfig:
graph_opt_config = GraphOptimizationConfig(args.graph_optimization_config)
moba_attention_config = MobaAttentionConfig(args.moba_attention_config)
early_stop_config = EarlyStopConfig(args.early_stop_config)
# Note(tangbinhan): used for load_checkpoint
@@ -734,7 +725,6 @@ def initialize_fd_config(args, ranks: int = 1, local_rank: int = 0) -> FDConfig:
cache_config=cache_config,
engine_worker_queue_port=args.engine_worker_queue_port,
ips=args.ips,
moba_attention_config=moba_attention_config,
)
update_fd_config_for_mm(fd_config)

View File

@@ -1,340 +0,0 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
try:
from fastdeploy.model_executor.ops.gpu import (
fused_block_mean_and_rope,
get_cur_cu_seq_len_k,
moba_encoder_attn,
moba_mlp_einsum,
moba_qk_gemm,
moba_qk_sort_encoder,
)
except:
moba_attention = None
get_cur_cu_seq_len_k = None
import unittest
import numpy as np
def naive_attn(q_input, k_input, v_input, mask):
gqa_group_size = q_input.shape[2] // k_input.shape[2]
q_cur = q_input.transpose([0, 2, 1, 3])
k_cur = k_input.transpose([0, 2, 1, 3])
v_cur = v_input.transpose([0, 2, 1, 3])
out = paddle.zeros(q_cur.shape, dtype=q_input.dtype)
for bsz in range(0, q_cur.shape[0]):
for hi in range(0, q_cur.shape[1]):
qk = paddle.matmul(q_cur[bsz, hi], k_cur[bsz, hi // gqa_group_size].T) * (1.0 / np.sqrt(q_cur.shape[3]))
qk += mask
qk_max = qk.max(axis=-1).unsqueeze(-1)
qk -= qk_max
qk = qk.exp()
exp_sum = qk.sum(axis=-1).unsqueeze(-1)
exp_sum_inv = 1.0 / exp_sum
out[bsz, hi] = (paddle.matmul(qk, v_cur[bsz, hi // gqa_group_size]) * exp_sum_inv).astype(q_input.dtype)
return out
class TestMobaAttention(unittest.TestCase):
def setUp(self):
paddle.seed(0)
self.seq_len = int(8 * 1024)
self.num_heads = int(8)
self.num_kv_heads = int(1)
self.head_dim = int(128)
self.max_num_seqs = 1
self.moba_max_seq_length = int(128 * 1024)
self.moba_block_size = int(128)
self.moba_encoder_top_k_left = 2
self.moba_encoder_top_k_right = 3
self.moba_use_encoder_seq_limit = int(4 * 1024)
self.cache_k_block_means = paddle.zeros(
[
self.max_num_seqs,
self.moba_max_seq_length // self.moba_block_size,
self.num_kv_heads,
self.head_dim,
],
dtype="bfloat16",
)
self.attn_block_m = 128
self.tokens = self.seq_len * self.max_num_seqs
self.q_input = paddle.zeros(
[self.tokens + self.attn_block_m, self.num_heads, self.head_dim],
dtype="bfloat16",
)
self.k_input = paddle.zeros(
[self.tokens + self.attn_block_m, self.num_kv_heads, self.head_dim],
dtype="bfloat16",
)
self.v_input = paddle.zeros(
[self.tokens + self.attn_block_m, self.num_kv_heads, self.head_dim],
dtype="bfloat16",
)
self.rotary_embs = paddle.ones([2, self.seq_len, self.head_dim // 2], dtype="float32")
self.attn_gate_weight = paddle.randn(
[self.num_kv_heads, self.moba_block_size, self.head_dim], dtype="bfloat16"
)
self.gqa_group_size = self.num_heads // self.num_kv_heads
self.num_blocks = (self.seq_len + self.moba_block_size - 1) // self.moba_block_size
self.sparse_step = 4
def compare_split_qkv_rope(self, qkv_out):
assert (qkv_out[:, 0 : self.num_heads, :] - self.q_input[0 : self.tokens]).abs().max() < 1e-3
assert (
qkv_out[:, self.num_heads : self.num_heads + self.num_kv_heads, :] - self.k_input[0 : self.tokens]
).abs().max() < 1e-3
assert (qkv_out[:, self.num_heads + self.num_kv_heads :, :] - self.v_input[0 : self.tokens]).abs().max() < 1e-3
for i in range(self.max_num_seqs):
k_padding = paddle.zeros(
[
(self.seq_len + self.moba_block_size - 1) // self.moba_block_size * self.moba_block_size,
self.num_kv_heads,
self.head_dim,
],
dtype="bfloat16",
)
k_padding[0 : self.seq_len] = self.k_input[i * self.seq_len : (i + 1) * self.seq_len]
real_k_block_means = k_padding.reshape([-1, self.moba_block_size, self.num_kv_heads, self.head_dim])
real_k_block_means = real_k_block_means.mean(axis=1)
compute_k_block_means = self.cache_k_block_means[i, 0 : real_k_block_means.shape[0]]
assert (compute_k_block_means - real_k_block_means).abs().max() < 0.003
print("[consistency]Moba attention: split_qkv_rope matches.")
def compare_mlp_einsum(self, k_gate_weight):
for i in range(self.max_num_seqs):
k_padding = paddle.zeros(
[
(self.seq_len + self.moba_block_size - 1) // self.moba_block_size * self.moba_block_size,
self.num_kv_heads,
self.head_dim,
],
dtype="bfloat16",
)
k_padding[0 : self.seq_len] = self.k_input[i * self.seq_len : (i + 1) * self.seq_len]
k_padding = k_padding.reshape([-1, self.moba_block_size, self.num_kv_heads, self.head_dim])
real_result = paddle.einsum("nbhd,hbd->nhd", k_padding, self.attn_gate_weight)
compute_result = k_gate_weight[i][0 : real_result.shape[0]]
assert (real_result - compute_result).abs().max() < 0.5
print("[consistency]Moba attention: MLP einsum matches.")
def compare_qk_gemm(self, qk_gate_weight):
for i in range(self.max_num_seqs):
q_input = self.q_input[i * self.seq_len : (i + 1) * self.seq_len]
k_input_mean = self.cache_k_block_means[i][0 : self.num_blocks]
qk_gemm_out = paddle.zeros(
[
self.seq_len,
self.num_heads,
self.num_blocks,
],
dtype="bfloat16",
)
for j in range(self.num_heads):
qk_gemm_out[:, j, :] = paddle.matmul(
q_input[:, j, :], k_input_mean[:, j // self.gqa_group_size, :], transpose_y=True
)
conpute_result = qk_gate_weight[i * self.seq_len : (i + 1) * self.seq_len, :, 0 : self.num_blocks]
assert (qk_gemm_out - conpute_result).abs().max() < 1e-4
print("[consistency]Moba attention: qk_gemm matches.")
def compare_qk_gate_topk(self, qk_gate_topk_idx):
limit_topk = self.moba_use_encoder_seq_limit // self.moba_block_size
for i in range(self.max_num_seqs):
qk_gate_topk_idx_batch = qk_gate_topk_idx[i * self.num_blocks : (i + 1) * self.num_blocks]
qk_gate_topk_idx_batch_no_sparse = qk_gate_topk_idx_batch[0 : limit_topk - 1]
assert (
qk_gate_topk_idx_batch_no_sparse
- paddle.ones(qk_gate_topk_idx_batch_no_sparse.shape, qk_gate_topk_idx_batch_no_sparse.dtype)
).abs().max() < 1e-6
for j in range(limit_topk, self.num_blocks):
qk_gate_topk_idx_batch_sparse = qk_gate_topk_idx_batch[j, :, 1 : (j + 1) // self.sparse_step]
assert (
qk_gate_topk_idx_batch_sparse
- paddle.ones(qk_gate_topk_idx_batch_sparse.shape, qk_gate_topk_idx_batch_sparse.dtype)
* self.sparse_step
).abs().max() < 1e-6
print("[consistency]Moba attention: qk_gate_topk matches.")
def compare_attn(self, attn_out, qk_gate_topk_idx):
x = (
paddle.tensor.triu(paddle.ones([self.moba_block_size, self.moba_block_size], dtype="bfloat16"), 1)
* -1000000
)
limit_topk = self.moba_use_encoder_seq_limit // self.moba_block_size
for i in range(self.max_num_seqs):
q_input = self.q_input[i * self.seq_len : (i + 1) * self.seq_len].unsqueeze(axis=0)
k_input = self.k_input[i * self.seq_len : (i + 1) * self.seq_len].unsqueeze(axis=0)
v_input = self.v_input[i * self.seq_len : (i + 1) * self.seq_len].unsqueeze(axis=0)
mask = paddle.tensor.triu(paddle.ones([self.seq_len, self.seq_len], dtype="bfloat16"), 1) * -1000000
mask[self.moba_use_encoder_seq_limit - self.moba_block_size :] = -1000000
for i in range(limit_topk - 1, self.num_blocks):
n_block = i
mask[
i * self.moba_block_size : i * self.moba_block_size + self.moba_block_size,
n_block * self.moba_block_size : n_block * self.moba_block_size + self.moba_block_size,
] = x
idx = 0
n_block -= int(qk_gate_topk_idx[i, 0, idx])
idx += 1
while n_block >= 0:
mask[
i * self.moba_block_size : i * self.moba_block_size + self.moba_block_size,
n_block * self.moba_block_size : n_block * self.moba_block_size + self.moba_block_size,
] = 0
n_block -= int(qk_gate_topk_idx[i, 0, idx])
idx += 1
naive_attn_out = naive_attn(q_input, k_input, v_input, mask).squeeze(axis=0).transpose([1, 0, 2])
assert (attn_out - naive_attn_out).abs().max() < 0.016
def test_moba_attention(self):
qkv_out = paddle.randn([self.tokens, self.num_heads + 2 * self.num_kv_heads, self.head_dim], dtype="bfloat16")
seq_len_encoder = paddle.to_tensor([self.seq_len] * self.max_num_seqs, dtype="int32")
seq_len_decoder = paddle.to_tensor([0] * self.max_num_seqs, dtype="int32")
cu_seq_q = paddle.arange(self.max_num_seqs + 1).astype("int32") * self.seq_len
cu_seq_k = paddle.arange(self.max_num_seqs + 1).astype("int32") * self.seq_len
seq_lens_this_time = paddle.to_tensor([self.seq_len] * self.max_num_seqs, dtype="int32")
cu_seq_q_pack, cu_seqlens_k, q_pack_tokens = get_cur_cu_seq_len_k(
seq_len_encoder,
seq_len_decoder,
seq_lens_this_time,
int(self.attn_block_m),
)
fused_block_mean_and_rope(
qkv_out,
self.cache_k_block_means,
self.q_input,
self.k_input,
self.v_input,
self.rotary_embs,
seq_len_encoder,
seq_len_decoder,
cu_seq_q,
cu_seq_k,
None,
self.num_heads,
self.num_kv_heads,
self.head_dim,
self.moba_max_seq_length,
self.seq_len,
self.seq_len,
"none",
)
self.compare_split_qkv_rope(qkv_out)
k_gate_weight = moba_mlp_einsum(
self.k_input,
self.attn_gate_weight,
seq_len_encoder,
seq_len_decoder,
cu_seq_k,
self.seq_len,
self.num_kv_heads,
)
self.compare_mlp_einsum(k_gate_weight)
qk_gate_weight = moba_qk_gemm(
self.q_input,
self.cache_k_block_means,
seq_len_encoder,
seq_len_decoder,
cu_seq_q,
cu_seq_k,
self.seq_len,
self.seq_len,
self.num_heads,
self.num_kv_heads,
False,
self.max_num_seqs,
)
self.compare_qk_gemm(qk_gate_weight)
for i in range(0, self.num_blocks, self.sparse_step):
qk_gate_weight[:, :, i] = 100
qk_gate_topk_idx = moba_qk_sort_encoder(
qk_gate_weight,
seq_len_encoder,
seq_len_decoder,
cu_seq_q,
cu_seq_k,
cu_seq_q_pack,
q_pack_tokens,
self.seq_len,
self.seq_len,
self.num_heads,
self.num_kv_heads,
self.moba_encoder_top_k_left,
self.moba_encoder_top_k_right,
self.moba_use_encoder_seq_limit,
)
self.compare_qk_gate_topk(qk_gate_topk_idx)
attn_out = paddle.zeros([self.tokens, self.num_heads, self.head_dim], dtype="bfloat16")
moba_encoder_attn(
self.q_input,
self.k_input,
self.v_input,
qk_gate_topk_idx,
cu_seq_q,
cu_seq_k,
cu_seq_q_pack,
seq_len_encoder,
seq_len_decoder,
attn_out,
self.seq_len,
self.seq_len,
self.num_heads,
self.num_kv_heads,
self.head_dim,
self.moba_max_seq_length,
)
self.compare_attn(attn_out, qk_gate_topk_idx)
if __name__ == "__main__":
if paddle.is_compiled_with_cuda():
unittest.main()