mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-12-24 13:28:13 +08:00
[Feature] block sparse attention (#3668)
* support sparse attn * fix bug * code style * fix moba attn get kv shape * fix A100 build * codestyle * code style * code style * code style * fix conflict * add unit test * code style * add eblite load time * fix bug * for ci * for ci * for ci * for ci * support mlp block size 128 * add unit tests for small kernels * fix mlp unit test * move the environment variables into the config * fix rollout config * fix GPU memory usage * add test server * add test server * fix mlp: use full attn for the last layer
@@ -845,15 +845,15 @@ void SpeculateStepPaddle(
    const int max_draft_tokens);

void MergePrefillDecodeOutput(
    const paddle::Tensor &encoder_res,
    const paddle::Tensor &decoder_res,
    const paddle::Tensor &seq_lens_encoder,
    const paddle::Tensor &seq_lens_decoder,
    const paddle::Tensor &seq_lens_this_time,
    const paddle::Tensor &cu_seq_q,
    const int head_num,
    const int head_dim,
    const int max_token);

std::vector<paddle::Tensor> TopPSamplingReject(const paddle::Tensor &probs,
                                               const paddle::Tensor &top_p,
330 custom_ops/gpu_ops/moba_attn/moba_attn.cu Normal file
@@ -0,0 +1,330 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/extension.h"
#include "moba_attn.h"

std::vector<paddle::Tensor> MobaAttention(
    const paddle::Tensor& qkv,
    const paddle::Tensor& q_input, const paddle::Tensor& k_input, const paddle::Tensor& v_input,
    const paddle::Tensor& cu_seq_q, const paddle::Tensor& cu_seq_k,
    const paddle::Tensor& cu_seq_q_pack, const paddle::Tensor& q_pack_tokens,
    const paddle::Tensor& seq_len_encoder, const paddle::Tensor& seq_len_decoder,
    const paddle::Tensor& key_cache, const paddle::Tensor& value_cache,
    const paddle::Tensor& block_tables, const paddle::Tensor& rope_sin_cos,
    const paddle::Tensor& k_block_means,
    const paddle::optional<paddle::Tensor>& attn_gate_weight,
    const paddle::optional<paddle::Tensor>& qkv_bias,
    const paddle::optional<paddle::Tensor>& cache_k_quant_scale,
    const paddle::optional<paddle::Tensor>& cache_v_quant_scale,
    const paddle::optional<paddle::Tensor>& cache_k_dequant_scale,
    const paddle::optional<paddle::Tensor>& cache_v_dequant_scale,
    const paddle::optional<paddle::Tensor>& cache_k_zero_points,
    const paddle::optional<paddle::Tensor>& cache_v_zero_points,
    const int head_num, const int kv_head_num, const int head_dim,
    const int max_seq_len, const int max_enc_len_this_time, const int max_dec_len_this_time,
    const int moba_encoder_top_k_left, const int moba_encoder_top_k_right,
    const int moba_use_encoder_seq_limit,
    const int moba_decoder_top_k_left, const int moba_decoder_top_k_right,
    const int moba_use_decoder_seq_limit,
    const bool moba_use_mlp,
    const std::string &cache_quant_type_str) {
    paddle::Tensor out = paddle::empty({qkv.dims()[0], head_num * head_dim}, qkv.dtype(), qkv.place());

    if (max_dec_len_this_time > 0) {
        MobaDecoderAttnWriteCacheKv(
            qkv, q_input, cu_seq_q, cu_seq_k,
            seq_len_encoder, seq_len_decoder,
            key_cache, value_cache, block_tables, rope_sin_cos, k_block_means, qkv_bias,
            cache_k_quant_scale, cache_v_quant_scale,
            cache_k_dequant_scale, cache_v_dequant_scale,
            cache_k_zero_points, cache_v_zero_points,
            head_num, kv_head_num, head_dim, max_seq_len, cache_quant_type_str);

        auto qk_gate_weight = MobaQKGemm(
            q_input, k_block_means, seq_len_encoder, seq_len_decoder, cu_seq_q, cu_seq_k,
            max_dec_len_this_time, max_dec_len_this_time,
            head_num, kv_head_num, true, moba_use_decoder_seq_limit)[0];

        auto qk_gate_topk_idx = QkSortDecoder(
            qk_gate_weight, seq_len_encoder, seq_len_decoder,
            head_num, kv_head_num,
            moba_decoder_top_k_left, moba_decoder_top_k_right, moba_use_decoder_seq_limit)[0];

        MobaDecoderAttn(
            q_input, seq_len_encoder, seq_len_decoder, cu_seq_q,
            key_cache, value_cache, block_tables, k_block_means, out, qk_gate_topk_idx,
            cache_k_quant_scale, cache_v_quant_scale,
            cache_k_dequant_scale, cache_v_dequant_scale,
            cache_k_zero_points, cache_v_zero_points,
            head_num, kv_head_num, head_dim, max_seq_len,
            moba_use_decoder_seq_limit, max_dec_len_this_time, max_dec_len_this_time,
            cache_quant_type_str);
    }
    if (max_enc_len_this_time > 0) {
        FusedBlockMeanAndRope(
            qkv, k_block_means, q_input, k_input, v_input, rope_sin_cos,
            seq_len_encoder, seq_len_decoder, cu_seq_q, cu_seq_k, qkv_bias,
            head_num, kv_head_num, head_dim, max_seq_len,
            max_enc_len_this_time, max_enc_len_this_time, cache_quant_type_str);

        MobaEncoderAttnWriteCacheKv(
            k_input, v_input, cu_seq_k, seq_len_encoder, seq_len_decoder,
            key_cache, value_cache, block_tables,
            cache_k_quant_scale, cache_v_quant_scale,
            cache_k_dequant_scale, cache_v_dequant_scale,
            cache_k_zero_points, cache_v_zero_points,
            head_num, kv_head_num, head_dim, max_enc_len_this_time, cache_quant_type_str);

        GetKVFromCache(
            k_input, v_input, cu_seq_k, seq_len_encoder, seq_len_decoder,
            key_cache, value_cache, block_tables,
            cache_k_dequant_scale, cache_v_dequant_scale,
            cache_k_zero_points, cache_v_zero_points,
            head_num, kv_head_num, head_dim, max_seq_len,
            max_enc_len_this_time + max_dec_len_this_time, cache_quant_type_str);

        // k_gate_mlp is declared in the enclosing scope so that k_gate_weight never points
        // at a tensor that has already gone out of scope when the MLP gate path is taken.
        paddle::Tensor k_gate_mlp;
        paddle::Tensor *k_gate_weight = const_cast<paddle::Tensor*>(&k_block_means);

        if (moba_use_mlp && attn_gate_weight) {
            k_gate_mlp = MobaMlpEinsum(
                k_input, attn_gate_weight.get(), seq_len_encoder, seq_len_decoder,
                cu_seq_k, max_seq_len, kv_head_num)[0];
            k_gate_weight = &k_gate_mlp;
        }

        auto qk_gate_weight = MobaQKGemm(
            q_input, *k_gate_weight, seq_len_encoder, seq_len_decoder, cu_seq_q, cu_seq_k,
            max_enc_len_this_time, max_enc_len_this_time + max_dec_len_this_time,
            head_num, kv_head_num, false, moba_use_encoder_seq_limit)[0];

        auto qk_gate_topk_idx = QkSortEncoder(
            qk_gate_weight, seq_len_encoder, seq_len_decoder,
            cu_seq_q, cu_seq_k, cu_seq_q_pack, q_pack_tokens,
            max_enc_len_this_time, max_enc_len_this_time + max_dec_len_this_time,
            head_num, kv_head_num,
            moba_encoder_top_k_left, moba_encoder_top_k_right,
            moba_use_mlp && !attn_gate_weight ? max_seq_len : moba_use_encoder_seq_limit)[0];

        MobaEncoderAttn(
            q_input, k_input, v_input, qk_gate_topk_idx,
            cu_seq_q, cu_seq_k, cu_seq_q_pack, seq_len_encoder, seq_len_decoder, out,
            max_enc_len_this_time, max_enc_len_this_time + max_dec_len_this_time,
            head_num, kv_head_num, head_dim, max_seq_len);
    }

    return {out};
}
PD_BUILD_OP(moba_attention)
    .Inputs({"qkv", "q_input", "k_input", "v_input",
             "cu_seq_q", "cu_seq_k", "cu_seq_q_pack", "q_pack_tokens",
             "seq_len_encoder", "seq_len_decoder",
             "key_cache", "value_cache", "block_tables",
             "rope_sin_cos", "k_block_means",
             paddle::Optional("attn_gate_weight"),
             paddle::Optional("qkv_bias"),
             paddle::Optional("cache_k_quant_scale"),
             paddle::Optional("cache_v_quant_scale"),
             paddle::Optional("cache_k_dequant_scale"),
             paddle::Optional("cache_v_dequant_scale"),
             paddle::Optional("cache_k_zero_points"),
             paddle::Optional("cache_v_zero_points")})
    .Attrs({"head_num: int", "kv_head_num: int", "head_dim: int", "max_seq_len: int",
            "max_enc_len_this_time: int", "max_dec_len_this_time: int",
            "moba_encoder_top_k_left: int", "moba_encoder_top_k_right: int",
            "moba_use_encoder_seq_limit: int",
            "moba_decoder_top_k_left: int", "moba_decoder_top_k_right: int",
            "moba_use_decoder_seq_limit: int",
            "moba_use_mlp: bool",
            "cache_quant_type_str: std::string"})
    .Outputs({"out", "q_input_out", "k_input_out", "v_input_out",
              "key_cache_out", "value_cache_out", "k_block_means_out"})
    .SetInplaceMap({{"q_input", "q_input_out"},
                    {"k_input", "k_input_out"},
                    {"v_input", "v_input_out"},
                    {"key_cache", "key_cache_out"},
                    {"value_cache", "value_cache_out"},
                    {"k_block_means", "k_block_means_out"}})
    .SetKernelFn(PD_KERNEL(MobaAttention));
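The op above follows the MoBA (mixture-of-block-attention) recipe: keys are summarized per fixed-size block (k_block_means, or an MLP gate produced by MobaMlpEinsum), each query is scored against those block summaries (MobaQKGemm), and QkSortEncoder / QkSortDecoder keep only the top-scoring blocks (qk_gate_topk_idx) for the actual attention kernels. A minimal scalar sketch of that selection step, illustrative only; the names below are hypothetical and unrelated to the real kernel layouts:

#include <algorithm>
#include <utility>
#include <vector>

// Illustrative only: score one query head against per-block key means and keep the top_k blocks.
std::vector<int> select_blocks(const std::vector<std::vector<float>>& k_block_means,  // [num_blocks][head_dim]
                               const std::vector<float>& q,                           // [head_dim]
                               int top_k) {
    std::vector<std::pair<float, int>> scored;
    for (int b = 0; b < static_cast<int>(k_block_means.size()); ++b) {
        float s = 0.f;
        for (size_t d = 0; d < q.size(); ++d) s += q[d] * k_block_means[b][d];  // q . mean(K_block)
        scored.emplace_back(s, b);
    }
    const int keep_n = std::min<int>(top_k, static_cast<int>(scored.size()));
    std::partial_sort(scored.begin(), scored.begin() + keep_n, scored.end(),
                      [](const auto& a, const auto& b) { return a.first > b.first; });
    std::vector<int> keep;
    for (int i = 0; i < keep_n; ++i) keep.push_back(scored[i].second);
    return keep;  // block ids whose K/V tiles the sparse attention kernel will actually visit
}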
204 custom_ops/gpu_ops/moba_attn/moba_attn.h Normal file
@@ -0,0 +1,204 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "paddle/extension.h"

void MobaDecoderAttnWriteCacheKv(
    const paddle::Tensor& qkv_out, const paddle::Tensor& q_input,
    const paddle::Tensor& cu_seq_q, const paddle::Tensor& cu_seq_k,
    const paddle::Tensor& seq_len_encoder, const paddle::Tensor& seq_len_decoder,
    const paddle::Tensor& cache_k, const paddle::Tensor& cache_v,
    const paddle::Tensor& block_tables, const paddle::Tensor& rope_sin_cos,
    const paddle::Tensor& k_block_means,
    const paddle::optional<paddle::Tensor>& qkv_bias,
    const paddle::optional<paddle::Tensor>& cache_k_quant_scale,
    const paddle::optional<paddle::Tensor>& cache_v_quant_scale,
    const paddle::optional<paddle::Tensor>& cache_k_dequant_scale,
    const paddle::optional<paddle::Tensor>& cache_v_dequant_scale,
    const paddle::optional<paddle::Tensor>& cache_k_zero_points,
    const paddle::optional<paddle::Tensor>& cache_v_zero_points,
    const int head_num, const int kv_head_num, const int head_dim,
    const int max_input_length,
    const std::string &cache_quant_type_str);

void MobaEncoderAttnWriteCacheKv(
    const paddle::Tensor& k_input, const paddle::Tensor& v_input,
    const paddle::Tensor& cu_seq_k,
    const paddle::Tensor& seq_len_encoder, const paddle::Tensor& seq_len_decoder,
    const paddle::Tensor& cache_k, const paddle::Tensor& cache_v,
    const paddle::Tensor& block_tables,
    const paddle::optional<paddle::Tensor>& cache_k_quant_scale,
    const paddle::optional<paddle::Tensor>& cache_v_quant_scale,
    const paddle::optional<paddle::Tensor>& cache_k_dequant_scale,
    const paddle::optional<paddle::Tensor>& cache_v_dequant_scale,
    const paddle::optional<paddle::Tensor>& cache_k_zero_points,
    const paddle::optional<paddle::Tensor>& cache_v_zero_points,
    const int head_num, const int kv_head_num, const int head_dim,
    const int max_seq_q,
    const std::string &cache_quant_type_str);

void MobaDecoderAttn(
    const paddle::Tensor& q_input,
    const paddle::Tensor& seq_len_encoder, const paddle::Tensor& seq_len_decoder,
    const paddle::Tensor& cu_seq_q,
    const paddle::Tensor& cache_k, const paddle::Tensor& cache_v,
    const paddle::Tensor& block_tables, const paddle::Tensor& k_block_means,
    const paddle::Tensor& out, const paddle::Tensor& qk_gate_topk_idx,
    const paddle::optional<paddle::Tensor>& cache_k_quant_scale,
    const paddle::optional<paddle::Tensor>& cache_v_quant_scale,
    const paddle::optional<paddle::Tensor>& cache_k_dequant_scale,
    const paddle::optional<paddle::Tensor>& cache_v_dequant_scale,
    const paddle::optional<paddle::Tensor>& cache_k_zero_points,
    const paddle::optional<paddle::Tensor>& cache_v_zero_points,
    const int head_num, const int kv_head_num, const int head_dim,
    const int max_input_length, const int use_moba_seq_limit,
    const int max_seq_q, const int max_seq_k,
    const std::string &cache_quant_type_str);

void FusedBlockMeanAndRope(
    const paddle::Tensor& qkv_out, const paddle::Tensor& k_block_means,
    const paddle::Tensor& q_input, const paddle::Tensor& k_input, const paddle::Tensor& v_input,
    const paddle::Tensor& rotary_embs,
    const paddle::Tensor& seq_len_encoder, const paddle::Tensor& seq_len_decoder,
    const paddle::Tensor& cu_seq_q, const paddle::Tensor& cu_seq_k,
    const paddle::optional<paddle::Tensor>& qkv_bias,
    const int head_num, const int kv_head_num, const int head_dim,
    const int max_input_length, const int max_seq_q, const int max_seq_k,
    const std::string &cache_quant_type_str);

std::vector<paddle::Tensor> GetCurCuSeqLenk(
    const paddle::Tensor& seq_lens_encoder,
    const paddle::Tensor& seq_lens_decoder,
    const paddle::Tensor& seq_lens_this_time,
    const int pack_size);
std::vector<paddle::Tensor> MobaQKGemm(
    const paddle::Tensor& q_input, const paddle::Tensor& k_block_means,
    const paddle::Tensor& seq_len_encoder, const paddle::Tensor& seq_len_decoder,
    const paddle::Tensor& cu_seq_q, const paddle::Tensor& cu_seq_k,
    const int max_seq_q, const int max_seq_k,
    const int head_num, const int kv_head_num,
    const bool is_split_kv, const int use_moba_seq_limit);

std::vector<paddle::Tensor> QkSortDecoder(
    const paddle::Tensor& qk_gate_weight,
    const paddle::Tensor& seq_len_encoder, const paddle::Tensor& seq_len_decoder,
    const int head_num, const int kv_head_num,
    const int top_k_left, const int top_k_right,
    const int use_moba_seq_limit);

void GetKVFromCache(
    const paddle::Tensor& k_input, const paddle::Tensor& v_input,
    const paddle::Tensor& cu_seq_k,
    const paddle::Tensor& seq_len_encoder, const paddle::Tensor& seq_len_decoder,
    const paddle::Tensor& cache_k, const paddle::Tensor& cache_v,
    const paddle::Tensor& block_tables,
    const paddle::optional<paddle::Tensor>& cache_k_dequant_scale,
    const paddle::optional<paddle::Tensor>& cache_v_dequant_scale,
    const paddle::optional<paddle::Tensor>& cache_k_zero_points,
    const paddle::optional<paddle::Tensor>& cache_v_zero_points,
    const int head_num, const int kv_head_num, const int head_dim,
    const int max_input_length, const int max_seq_k,
    const std::string &cache_quant_type_str);

void MobaEncoderAttn(
    const paddle::Tensor& q_input, const paddle::Tensor& k_input, const paddle::Tensor& v_input,
    const paddle::Tensor& qk_gate_topk_idx,
    const paddle::Tensor& cu_seq_q, const paddle::Tensor& cu_seq_k,
    const paddle::Tensor& cu_seq_q_pack,
    const paddle::Tensor& seq_len_encoder, const paddle::Tensor& seq_len_decoder,
    const paddle::Tensor& out,
    const int max_seq_q, const int max_seq_k,
    const int head_num, const int kv_head_num, const int head_dim,
    const int max_input_length);

std::vector<paddle::Tensor> QkSortEncoder(
    const paddle::Tensor& qk_gate_weight,
    const paddle::Tensor& seq_len_encoder, const paddle::Tensor& seq_len_decoder,
    const paddle::Tensor& cu_seq_q, const paddle::Tensor& cu_seq_k,
    const paddle::Tensor& cu_seq_q_pack, const paddle::Tensor& q_pack_tokens,
    const int max_seq_q, const int max_seq_k,
    const int head_num, const int kv_head_num,
    const int top_k_left, const int top_k_right,
    const int use_moba_seq_limit);

std::vector<paddle::Tensor> MobaMlpEinsum(
    const paddle::Tensor& k_input, const paddle::Tensor& attn_gate_weight,
    const paddle::Tensor& seq_lens_encoder, const paddle::Tensor& seq_lens_decoder,
    const paddle::Tensor& cu_seq_k,
    const int max_seq_len, const int kv_head_num);
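Throughout these declarations the cu_seq_q / cu_seq_k tensors are used as flat token offsets (for example params.cu_seq_q[bi] * head_num * head_dim in the decoder kernel further down), which matches the usual varlen-attention convention of exclusive prefix sums over per-sequence token counts, presumably what GetCurCuSeqLenk produces. A tiny host-side sketch of that convention, illustrative only and not the custom op's implementation:

#include <vector>

// cu_seq[bi] is the offset of sequence bi inside the flattened token dimension,
// and cu_seq.back() is the total number of tokens.
std::vector<int> build_cu_seq(const std::vector<int>& seq_lens) {
    std::vector<int> cu_seq(seq_lens.size() + 1, 0);
    for (size_t i = 0; i < seq_lens.size(); ++i) cu_seq[i + 1] = cu_seq[i] + seq_lens[i];
    return cu_seq;  // e.g. {3, 5, 2} -> {0, 3, 8, 10}
}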
748 custom_ops/gpu_ops/moba_attn/moba_attn_utils.hpp Normal file
@@ -0,0 +1,748 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once
#include <assert.h>
#include <stdint.h>
#include <stdlib.h>
#include <cuda_fp16.h>
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
#include <cuda_bf16.h>
#endif
#include <cutlass/array.h>
#include <cutlass/cutlass.h>
#include <cutlass/numeric_conversion.h>
#include <cutlass/numeric_types.h>
#include <cub/cub.cuh>
#include "cute/tensor.hpp"
#include "cute/algorithm/copy.hpp"
#include "cute/algorithm/gemm.hpp"
#include "cute/int_tuple.hpp"
#include <cute/arch/cluster_sm90.hpp>
#include <cub/cub.cuh>
#include "cutlass/layout/layout.h"
#include "cutlass/numeric_types.h"
#include "cutlass/pipeline/pipeline.hpp"
#include "cutlass/cluster_launch.hpp"
#include "cutlass/arch/reg_reconfig.h"
#include "cute/atom/mma_atom.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"

using namespace cute;
template<typename T>
struct PackedHalf;

template<> struct PackedHalf<cutlass::half_t>      { using Type = __half2; };
template<> struct PackedHalf<cutlass::bfloat16_t>  { using Type = nv_bfloat162; };
template<> struct PackedHalf<phi::dtype::float16>  { using Type = __half2; };
template<> struct PackedHalf<phi::dtype::bfloat16> { using Type = nv_bfloat162; };

template<typename T>
struct HalfSub;

template<>
struct HalfSub<cutlass::half_t> {
    inline __device__ void operator()(uint32_t* result_ptr, const uint32_t magic_num) {
        asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(*result_ptr) : "r"(*result_ptr), "r"(magic_num));
    }
};

template<>
struct HalfSub<cutlass::bfloat16_t> {
    inline __device__ void operator()(uint32_t* result_ptr, const uint32_t magic_num) {
        *reinterpret_cast<nv_bfloat162*>(result_ptr) -= *reinterpret_cast<const nv_bfloat162*>(&magic_num);
    }
};

template<typename T>
struct HalfMul;

template<>
struct HalfMul<cutlass::half_t> {
    inline __device__ void operator()(uint32_t* result_ptr, const uint32_t magic_num) {
        asm volatile("mul.f16x2 %0, %1, %2;\n" : "=r"(*result_ptr) : "r"(*result_ptr), "r"(magic_num));
    }
};

template<>
struct HalfMul<cutlass::bfloat16_t> {
    inline __device__ void operator()(uint32_t* result_ptr, const uint32_t magic_num) {
        *reinterpret_cast<nv_bfloat162*>(result_ptr) *= *reinterpret_cast<const nv_bfloat162*>(&magic_num);
    }
};

template<typename T>
struct HalfMax;

template<>
struct HalfMax<cutlass::half_t> {
    inline __device__ __half2 operator()(const __half2 x, const __half2 y) {
        __half2 res;
        asm volatile("max.f16x2 %0, %1, %2;\n"
                     : "=r"(*reinterpret_cast<uint32_t*>(&res))
                     : "r"(*reinterpret_cast<const uint32_t*>(&x)),
                       "r"(*reinterpret_cast<const uint32_t*>(&y)));
        return res;
    }
};

template<>
struct HalfMax<cutlass::bfloat16_t> {
    inline __device__ nv_bfloat162 operator()(const nv_bfloat162 x, const nv_bfloat162 y) {
        nv_bfloat162 res;
        asm volatile("max.bf16x2 %0, %1, %2;\n"
                     : "=r"(*reinterpret_cast<uint32_t*>(&res))
                     : "r"(*reinterpret_cast<const uint32_t*>(&x)),
                       "r"(*reinterpret_cast<const uint32_t*>(&y)));
        return res;
    }
};

template<typename T>
struct HalfMin;

template<>
struct HalfMin<cutlass::half_t> {
    inline __device__ __half2 operator()(const __half2 x, const __half2 y) {
        __half2 res;
        asm volatile("min.f16x2 %0, %1, %2;\n"
                     : "=r"(*reinterpret_cast<uint32_t*>(&res))
                     : "r"(*reinterpret_cast<const uint32_t*>(&x)),
                       "r"(*reinterpret_cast<const uint32_t*>(&y)));
        return res;
    }
};

template<>
struct HalfMin<cutlass::bfloat16_t> {
    inline __device__ nv_bfloat162 operator()(const nv_bfloat162 x, const nv_bfloat162 y) {
        nv_bfloat162 res;
        asm volatile("min.bf16x2 %0, %1, %2;\n"
                     : "=r"(*reinterpret_cast<uint32_t*>(&res))
                     : "r"(*reinterpret_cast<const uint32_t*>(&x)),
                       "r"(*reinterpret_cast<const uint32_t*>(&y)));
        return res;
    }
};

template<typename T>
struct MaxOp {
    __device__ __forceinline__ T operator()(T const & x, T const & y) { return x > y ? x : y; }
};

template <>
struct MaxOp<float> {
    __device__ __forceinline__ float operator()(float const &x, float const &y) { return max(x, y); }
};

template<typename T>
struct MinOp {
    __device__ __forceinline__ T operator()(T const & x, T const & y) { return x < y ? x : y; }
};

template <>
struct MinOp<float> {
    __device__ __forceinline__ float operator()(float const &x, float const &y) { return min(x, y); }
};

template<typename T>
struct SumOp {
    __device__ __forceinline__ T operator()(T const & x, T const & y) { return x + y; }
};
template<typename T, bool Is_K>
inline __device__ static void convert_c8_2_half(uint32_t *src, T *dst, const T *cache_scale, const T* cache_zp) {
    uint32_t* half_result_ptr = reinterpret_cast<uint32_t*>(dst);
    if constexpr (std::is_same_v<T, cutlass::bfloat16_t>) {
        static constexpr uint32_t fp32_base = 0x4B000000;
        float fp32_intermediates[4];

        uint32_t* fp32_intermediates_casted = reinterpret_cast<uint32_t*>(fp32_intermediates);
        fp32_intermediates_casted[0] = __byte_perm(*src, fp32_base, 0x7650);
        fp32_intermediates_casted[1] = __byte_perm(*src, fp32_base, 0x7651);
        fp32_intermediates_casted[2] = __byte_perm(*src, fp32_base, 0x7652);
        fp32_intermediates_casted[3] = __byte_perm(*src, fp32_base, 0x7653);

#pragma unroll
        for (int ii = 0; ii < 4; ++ii) {
            fp32_intermediates[ii] -= 8388608.f;
        }

#pragma unroll
        for (int ii = 0; ii < 2; ++ii) {
            half_result_ptr[ii] = __byte_perm(fp32_intermediates_casted[2 * ii + 0], fp32_intermediates_casted[2 * ii + 1], 0x7632);
        }
    } else {
        static constexpr uint32_t head_for_fp16 = 0x64006400;
        half_result_ptr[0] = __byte_perm(*src, head_for_fp16, 0x7150);
        half_result_ptr[1] = __byte_perm(*src, head_for_fp16, 0x7352);
    }

    using pack_half = typename PackedHalf<T>::Type;
#pragma unroll
    for (int i = 0; i < 2; i++){
        if constexpr (Is_K) {
            HalfSub<T>()(half_result_ptr + i, *reinterpret_cast<const uint32_t*>(cache_zp + i * 2));
            HalfMul<T>()(half_result_ptr + i, *reinterpret_cast<const uint32_t*>(cache_scale + i * 2));
        } else {
            pack_half bias;
            pack_half scale;
            bias.x = cache_zp[0];
            bias.y = cache_zp[0];
            scale.x = cache_scale[0];
            scale.y = cache_scale[0];
            HalfSub<T>()(half_result_ptr + i, *reinterpret_cast<const uint32_t*>(&bias));
            HalfMul<T>()(half_result_ptr + i, *reinterpret_cast<const uint32_t*>(&scale));
        }
    }
}

template<typename T, bool Is_K>
inline __device__ static void convert_c4_2_half(uint32_t *src, T *dst, const T *cache_scale, const T* cache_zp) {
    using pack_half = typename PackedHalf<T>::Type;
    static constexpr uint32_t MASK = 0x0f0f0f0f;
    static constexpr uint32_t head_for_fp16 = std::is_same_v<T, cutlass::bfloat16_t> ? 0x43004300 : 0x64006400;
    static constexpr uint32_t mask_for_c42fp16_one = 0x7253;
    static constexpr uint32_t mask_for_c42fp16_two = 0x7051;
    uint32_t* result_ptr = reinterpret_cast<uint32_t*>(dst);
    uint32_t source = *reinterpret_cast<uint32_t const*>(src);
    // source = {e0 e4 e1 e5 e2 e6 e3 e7}
    uint32_t bottom_i4s = source & MASK;
    // bottom_i4s = {0 e4 0 e5 0 e6 0 e7}
    uint32_t top_i4s = (source >> 4) & MASK;
    // top_i4s = {0 e0 0 e1 0 e2 0 e3}
    asm volatile("prmt.b32 %0,%1,%2,%3;\n" : "=r"(result_ptr[0]) : "r"(top_i4s), "n"(head_for_fp16), "n"(mask_for_c42fp16_one));
    // result_ptr[0] = {e0 e1}
    asm volatile("prmt.b32 %0,%1,%2,%3;\n" : "=r"(result_ptr[1]) : "r"(top_i4s), "n"(head_for_fp16), "n"(mask_for_c42fp16_two));
    asm volatile("prmt.b32 %0,%1,%2,%3;\n" : "=r"(result_ptr[2]) : "r"(bottom_i4s), "n"(head_for_fp16), "n"(mask_for_c42fp16_one));
    asm volatile("prmt.b32 %0,%1,%2,%3;\n" : "=r"(result_ptr[3]) : "r"(bottom_i4s), "n"(head_for_fp16), "n"(mask_for_c42fp16_two));

#pragma unroll
    for (int i = 0; i < 4; ++i) {
        if constexpr (Is_K) {
            const int ith_col = i % 2 * 2;
            HalfSub<T>()(result_ptr + i, *reinterpret_cast<const uint32_t*>(cache_zp + ith_col));
            HalfMul<T>()(result_ptr + i, *reinterpret_cast<const uint32_t*>(cache_scale + ith_col));
        } else {
            const int ith_col = i / 2;
            pack_half bias;
            pack_half scale;
            bias.x = cache_zp[ith_col];
            bias.y = cache_zp[ith_col];
            scale.x = cache_scale[ith_col];
            scale.y = cache_scale[ith_col];
            HalfSub<T>()(result_ptr + i, *reinterpret_cast<const uint32_t*>(&bias));
            HalfMul<T>()(result_ptr + i, *reinterpret_cast<const uint32_t*>(&scale));
        }
    }
}
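Both converters above are byte-shuffle fast paths for the same arithmetic: widen each 8-bit (or 4-bit) cache value to half precision, subtract the per-channel zero point, and multiply by the per-channel scale. A scalar reference of that mapping, illustrative only; the real code works on packed pairs and folds the widening bias into the zero point it subtracts:

#include <cstdint>

// Reference dequantization: real_value = (quantized - zero_point) * scale.
inline float dequant_ref(uint8_t q, float zero_point, float scale) {
    return (static_cast<float>(q) - zero_point) * scale;
}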
template<typename CacheKV_traits, typename T, int kHeadDim, int kDataNumPer2Byte, bool A_in_regs=false,
         typename Tensor0, typename Tensor1, typename Tensor2, typename Tensor3, typename Tensor4,
         typename TiledMma, typename ThrCopy0, typename TiledCopy0>
inline __device__ void gemm_qk_quant(
        Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCsA, Tensor3 &tCrB,
        Tensor4 const& sB, TiledMma tiled_mma,
        ThrCopy0 smem_thr_copy_A,
        TiledCopy0 smem_tiled_copy_A,
        const int32_t tidx,
        const T * cache_scale, const T * cache_zp) {
    CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(acc));
    CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(acc));
    Tensor tCrA_copy_view = smem_thr_copy_A.retile_D(tCrA);
    CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(tCrA_copy_view));
    if (!A_in_regs) {
        copy(smem_tiled_copy_A, tCsA(_, _, _0{}), tCrA_copy_view(_, _, _0{}));
    }
    uint32_t *sBdata = reinterpret_cast<uint32_t *>(sB.data().get()) + tidx * (kDataNumPer2Byte / 4);

#pragma unroll
    for (int i = 0; i < size<2>(tCrA); ++i) {
        if (i < size<2>(tCrA) - 1) {
            if (!A_in_regs) {
                copy(smem_tiled_copy_A, tCsA(_, _, i + 1), tCrA_copy_view(_, _, i + 1));
            }
        }
        if constexpr (kDataNumPer2Byte == 4) {
            convert_c4_2_half<T, true>(sBdata + i * kHeadDim, tCrB.data(), cache_scale + i * 4, cache_zp + i * 4);
        } else {
            convert_c8_2_half<T, true>(sBdata + i * (kHeadDim * 2), tCrB.data(), cache_scale + i * 4, cache_zp + i * 4);
            convert_c8_2_half<T, true>(sBdata + i * (kHeadDim * 2) + 1, tCrB.data() + 4, cache_scale + i * 4, cache_zp + i * 4);
        }

        cute::gemm(tiled_mma, tCrA(_, _, i), tCrB, acc);
    }
}

template<typename CacheKV_traits, typename T, int kHeadDim, int kDataNumPer2Byte, bool A_in_regs=false,
         typename Tensor0, typename Tensor1, typename Tensor2, typename Tensor3, typename Tensor4,
         typename TiledMma, typename ThrCopy0, typename TiledCopy0>
inline __device__ void gemm_value_quant(
        Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCsA, Tensor3 &tCrB,
        Tensor4 const& sB, TiledMma tiled_mma,
        ThrCopy0 smem_thr_copy_A,
        TiledCopy0 smem_tiled_copy_A,
        int32_t tidx,
        const T * cache_scale, const T * cache_zp) {
    CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(acc));
    CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(acc));
    Tensor tCrA_copy_view = smem_thr_copy_A.retile_D(tCrA);
    CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(tCrA_copy_view));
    if (!A_in_regs) {
        copy(smem_tiled_copy_A, tCsA(_, _, _0{}), tCrA_copy_view(_, _, _0{}));
    }
    uint32_t *sBdata = reinterpret_cast<uint32_t *>(sB.data().get()) + tidx * (2 * kDataNumPer2Byte / 4);

#pragma unroll
    for (int i = 0; i < size<2>(tCrA); ++i) {
        const int cur_idx = i * kHeadDim * (2 * kDataNumPer2Byte / 4);

        if (i < size<2>(tCrA) - 1) {
            if (!A_in_regs) {
                copy(smem_tiled_copy_A, tCsA(_, _, i + 1), tCrA_copy_view(_, _, i + 1));
            }
        }
        if constexpr (kDataNumPer2Byte == 4) {
            convert_c4_2_half<T, false>(sBdata + cur_idx, tCrB.data(), cache_scale, cache_zp);
            convert_c4_2_half<T, false>(sBdata + cur_idx + 1, tCrB.data() + 8, cache_scale + 2, cache_zp + 2);
        } else {
            convert_c8_2_half<T, false>(sBdata + cur_idx, tCrB.data(), cache_scale, cache_zp);
            convert_c8_2_half<T, false>(sBdata + cur_idx + 1, tCrB.data() + 4, cache_scale + 1, cache_zp + 1);
            convert_c8_2_half<T, false>(sBdata + cur_idx + 2, tCrB.data() + 8, cache_scale + 2, cache_zp + 2);
            convert_c8_2_half<T, false>(sBdata + cur_idx + 3, tCrB.data() + 12, cache_scale + 3, cache_zp + 3);
        }
        cute::gemm(tiled_mma, tCrA(_, _, i), tCrB, acc);
    }
}
template<int kMiLen, typename Engine, typename Layout>
inline __device__ void apply_mask(Tensor<Engine, Layout> &scores, const uint32_t warp_id, const uint32_t col, const uint32_t reamin_seq_len) {
    const int cols = size<1>(scores) / 2;
#pragma unroll
    for (int mi = 0; mi < kMiLen; ++mi) {
#pragma unroll
        for (int ni = 0; ni < cols; ++ni) {
            const int col_index = warp_id * 8 + ni * 32 + col * 2;
            if (col_index >= reamin_seq_len) {
                scores(mi, ni * 2) = -INFINITY;
            }
            if (col_index + 1 >= reamin_seq_len) {
                scores(mi, ni * 2 + 1) = -INFINITY;
            }
        }
    }
}

template<int THREADS>
struct Allreduce {
    static_assert(THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4);
    template<typename T, typename Operator>
    static __device__ inline T run(T x, Operator &op) {
        constexpr int OFFSET = THREADS / 2;
        x = op(x, __shfl_xor_sync(uint32_t(-1), x, OFFSET));
        return Allreduce<OFFSET>::run(x, op);
    }
};

template<>
struct Allreduce<2> {
    template<typename T, typename Operator>
    static __device__ inline T run(T x, Operator &op) {
        x = op(x, __shfl_xor_sync(uint32_t(-1), x, 1));
        return x;
    }
};

template<int kMiLen, typename Engine0, typename Layout0, typename T>
__device__ inline void reduce_max(Tensor<Engine0, Layout0> const& tensor, T *scores_max){
    static_assert(Layout0::rank == 2, "Only support 2D Tensor");
    MaxOp<T> max_op;
#pragma unroll
    for (int mi = 0; mi < kMiLen; ++mi) {
#pragma unroll
        for (int ni = 0; ni < size<1>(tensor); ni++) {
            scores_max[mi] = max_op(scores_max[mi], tensor(mi, ni));
        }
        scores_max[mi] = Allreduce<4>::run(scores_max[mi], max_op);
    }
}

template <int kMiLen, typename Engine0, typename Layout0, typename T>
inline __device__ void scale_apply_exp2(Tensor<Engine0, Layout0> &tensor, T const *max, T *sum, const float scale) {
    static_assert(Layout0::rank == 2, "Only support 2D Tensor");
#pragma unroll
    for (int mi = 0; mi < kMiLen; ++mi) {
        const float max_scaled = max[mi] * scale;
#pragma unroll
        for (int ni = 0; ni < size<1>(tensor); ++ni) {
            tensor(mi, ni) = expf(tensor(mi, ni) * scale - max_scaled);
            sum[mi] += tensor(mi, ni);
        }
    }
}
template <typename paddle_type>
struct cuteType;

template <> struct cuteType<phi::dtype::float16>  { using type = cutlass::half_t; };
template <> struct cuteType<phi::dtype::bfloat16> { using type = cutlass::bfloat16_t; };

template<typename T>
__forceinline__ __device__ auto float_2_half2(const float x) {
    if constexpr (std::is_same<T, cutlass::half_t>::value) {
        return __float2half2_rn(x);
    } else {
        return __float2bfloat162_rn(x);
    }
}

struct uint16 {
    uint4 u;
    uint4 v;
    uint4 s;
    uint4 t;
};

struct uint8 {
    uint4 u;
    uint4 v;
};

template<int BYTES>
struct BytesToType {};

template<> struct BytesToType<64> { using Type = uint16;   static_assert(sizeof(Type) == 64); };
template<> struct BytesToType<32> { using Type = uint8;    static_assert(sizeof(Type) == 32); };
template<> struct BytesToType<16> { using Type = uint4;    static_assert(sizeof(Type) == 16); };
template<> struct BytesToType<8>  { using Type = uint64_t; static_assert(sizeof(Type) == 8); };
template<> struct BytesToType<4>  { using Type = uint32_t; static_assert(sizeof(Type) == 4); };
template<> struct BytesToType<2>  { using Type = uint16_t; static_assert(sizeof(Type) == 2); };
template<> struct BytesToType<1>  { using Type = uint8_t;  static_assert(sizeof(Type) == 1); };
template<typename Elt_type, uint32_t NUM_ELT>
struct Vec {

    enum { BYTES = NUM_ELT * sizeof(Elt_type) };

    using Vec_type = typename BytesToType<BYTES>::Type;

    using Alias_type = union {
        Vec_type vec;
        Elt_type elt[NUM_ELT];
    };

    Alias_type data;

    inline __device__ Vec() {}

    template<typename S>
    inline __device__ void to(Vec<S, NUM_ELT> &other) {
#pragma unroll
        for( int it = 0; it < NUM_ELT; it++ ) {
            other.data.elt[it] = S(this->data.elt[it]);
        }
    }

    template<typename Op>
    inline __device__ void assign(const Op &op) {
#pragma unroll
        for( int it = 0; it < NUM_ELT; it++ ) {
            this->data.elt[it] = op(it);
        }
    }

    inline __device__ void load_from(const void *base_ptr) {
        this->data.vec = *reinterpret_cast<const Vec_type *>(base_ptr);
    }

    inline __device__ void store_to(void *base_ptr) {
        *reinterpret_cast<Vec_type *>(base_ptr) = this->data.vec;
    }

    inline __device__ void add(const Vec<Elt_type, NUM_ELT> &other) {
        static_assert(NUM_ELT % 2 == 0);
        using type = typename PackedHalf<Elt_type>::Type;
#pragma unroll
        for (int it = 0; it < NUM_ELT / 2; it++) {
            type b = *reinterpret_cast<const type *>(other.data.elt + it * 2);
            *reinterpret_cast<type *>(this->data.elt + it * 2) += b;
        }
    }

    inline __device__ void set_zero() {
        constexpr int size = sizeof(Vec_type) / sizeof(int);
#pragma unroll
        for (int i = 0; i < size; ++i) {
            (reinterpret_cast<int *>(this->data.elt))[i] = 0;
        }
    }

    inline __device__ void fma(const Vec<Elt_type, NUM_ELT> &scale, const Vec<Elt_type, NUM_ELT> &bias) {
        static_assert(NUM_ELT % 2 == 0);
        using type = typename PackedHalf<Elt_type>::Type;
#pragma unroll
        for (int it = 0; it < NUM_ELT / 2; it++) {
            type a = *reinterpret_cast<const type *>(scale.data.elt + it * 2);
            type b = *reinterpret_cast<const type *>(bias.data.elt + it * 2);
            *reinterpret_cast<type *>(this->data.elt + it * 2) += a * b;
        }
    }
};

template<typename T, int PackSize>
inline __device__ void apply_rotary_embedding(Vec<T, PackSize>& vec, Vec<float, PackSize / 2>& cos, Vec<float, PackSize / 2>& sin) {
    static_assert(PackSize % 2 == 0);
#pragma unroll
    for (int i = 0; i < PackSize / 2; i++) {
        const float cos_inv_freq = cos.data.elt[i];
        const float sin_inv_freq = sin.data.elt[i];
        const float v1 = static_cast<float>(vec.data.elt[2 * i]);
        const float v2 = static_cast<float>(vec.data.elt[2 * i + 1]);
        vec.data.elt[2 * i] = static_cast<T>(cos_inv_freq * v1 - sin_inv_freq * v2);
        vec.data.elt[2 * i + 1] = static_cast<T>(sin_inv_freq * v1 + cos_inv_freq * v2);
    }
}
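A minimal device-side sketch of how the Vec<> wrapper and apply_rotary_embedding above compose; the helper name and pack size are hypothetical (the real kernels derive the pack from their copy layouts): load a pack of head elements, load the matching cos/sin packs, rotate the pairs in place, and store back.

// Hypothetical helper, not part of the diff.
template <typename T, int kPack>
__device__ void rope_one_pack(T *qk_ptr, const float *cos_ptr, const float *sin_ptr) {
    Vec<T, kPack> v;               // kPack consecutive elements of one head
    Vec<float, kPack / 2> cos_v;   // one cos/sin per rotated pair
    Vec<float, kPack / 2> sin_v;
    v.load_from(qk_ptr);
    cos_v.load_from(cos_ptr);
    sin_v.load_from(sin_ptr);
    apply_rotary_embedding<T, kPack>(v, cos_v, sin_v);  // (x1, x2) -> (x1*c - x2*s, x1*s + x2*c)
    v.store_to(qk_ptr);
}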
template <bool Is_even_MN=true, typename TiledCopy, typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename Engine2, typename Layout2>
__forceinline__ __device__ void copy(
        TiledCopy tiled_copy, Tensor<Engine0, Layout0> const &S,
        Tensor<Engine1, Layout1> &D,
        Tensor<Engine2, Layout2> const &identity_MN,
        const int max_MN = 0) {
    CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{});
    CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{});
    CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D));  // MMA
    CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D));  // MMA_M
    CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D));  // MMA_K
#pragma unroll
    for (int m = 0; m < size<1>(S); ++m) {
        if (Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_MN) {
#pragma unroll
            for (int k = 0; k < size<2>(S); ++k) {
                cute::copy(tiled_copy, S(_, m, k), D(_, m, k));
            }
        }
    }
}

template<bool A_in_regs=false, bool B_in_regs=false, typename Tensor0, typename Tensor1,
         typename Tensor2, typename Tensor3, typename Tensor4,
         typename TiledMma, typename ThrCopy0, typename ThrCopy1,
         typename TiledCopy0, typename TiledCopy1>
inline __device__ void gemm(
        Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tensor3 const& tCsA,
        Tensor4 const& tCsB, TiledMma tiled_mma,
        ThrCopy0 &smem_thr_copy_A, ThrCopy1 &smem_thr_copy_B,
        TiledCopy0 &smem_tiled_copy_A, TiledCopy1 &smem_tiled_copy_B) {
    CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(acc));
    CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(acc));
    CUTE_STATIC_ASSERT_V(size<2>(tCrA) == size<2>(tCrB));
    Tensor tCrA_copy_view = smem_thr_copy_A.retile_D(tCrA);
    CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(tCrA_copy_view));
    Tensor tCrB_copy_view = smem_thr_copy_B.retile_D(tCrB);
    CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<1>(tCrB_copy_view));

    if (!A_in_regs) { copy(smem_tiled_copy_A, tCsA(_, _, _0{}), tCrA_copy_view(_, _, _0{})); }
    if (!B_in_regs) { copy(smem_tiled_copy_B, tCsB(_, _, _0{}), tCrB_copy_view(_, _, _0{})); }

#pragma unroll
    for (int i = 0; i < size<2>(tCrA); ++i) {
        if (i < size<2>(tCrA) - 1) {
            if (!A_in_regs) { copy(smem_tiled_copy_A, tCsA(_, _, i + 1), tCrA_copy_view(_, _, i + 1)); }
            if (!B_in_regs) { copy(smem_tiled_copy_B, tCsB(_, _, i + 1), tCrB_copy_view(_, _, i + 1)); }
        }
        cute::gemm(tiled_mma, tCrA(_, _, i), tCrB(_, _, i), acc);
    }
}

template <typename To_type, typename Engine, typename Layout>
inline __device__ auto convert_type(Tensor<Engine, Layout> const &tensor) {
    using From_type = typename Engine::value_type;
    constexpr int numel = decltype(size(tensor))::value;
    cutlass::NumericArrayConverter<To_type, From_type, numel> convert_op;
    auto frag = convert_op(*reinterpret_cast<const cutlass::Array<From_type, numel> *>(tensor.data()));
    return make_tensor(make_rmem_ptr<To_type>(&frag), tensor.layout());
}

template<typename T, typename ReductionOp, int block_size>
__inline__ __device__ T BlockAllReduce(T val) {
    typedef cub::BlockReduce<T, block_size> BlockReduce;
    __shared__ typename BlockReduce::TempStorage temp_storage;
    __shared__ T result_broadcast;
    T result = BlockReduce(temp_storage).Reduce(val, ReductionOp());
    if (threadIdx.x == 0) { result_broadcast = result; }
    __syncthreads();
    return result_broadcast;
}

template<typename MMA_traits, typename Layout>
__forceinline__ __device__ auto convert_layout_acc_Aregs(Layout acc_layout) {
    using X = Underscore;
    if constexpr (decltype(rank<0>(acc_layout))::value == 3) {  // SM90
        static_assert(decltype(size<0, 0>(acc_layout))::value == 2);
        static_assert(decltype(size<0, 1>(acc_layout))::value == 2);
        static_assert(decltype(rank(acc_layout))::value == 3);
        static_assert(decltype(rank(get<0>(acc_layout)))::value == 3);
        auto l = logical_divide(get<0>(acc_layout), Shape<X, X, _2>{});  // (2, 2, (2, N / 16)))
        return make_layout(make_layout(get<0>(l), get<1>(l), get<2, 0>(l)), get<1>(acc_layout), make_layout(get<2, 1>(l), get<2>(acc_layout)));
    } else {  // SM80
        static_assert(decltype(size<0>(acc_layout))::value == 4);
        static_assert(decltype(rank(acc_layout))::value == 3);
        constexpr int mma_shape_K = get<2>(typename MMA_traits::Shape_MNK{});
        static_assert(mma_shape_K == 8 || mma_shape_K == 16);
        if constexpr (mma_shape_K == 8) {
            return acc_layout;
        } else {
            auto l = logical_divide(acc_layout, Shape<X, X, _2>{});  // (4, MMA_M, (2, MMA_N / 2)))
            return make_layout(make_layout(get<0>(l), get<2, 0>(l)), get<1>(l), get<2, 1>(l));
        }
    }
};

template <bool zero_init=false, int wg_wait=0, bool arrive=true, bool commit=true, typename Tensor0, typename Tensor1, typename Tensor2, typename TiledMma>
__forceinline__ __device__ void gemm(TiledMma &tiled_mma, Tensor0 const &tCrA, Tensor1 const &tCrB, Tensor2 &tCrC) {
    constexpr bool Is_RS = !cute::is_base_of<cute::GMMA::DescriptorIterator, typename TiledMma::FrgTypeA>::value;
    if constexpr (Is_RS) { warpgroup_fence_operand(const_cast<Tensor0 &>(tCrA)); }
    warpgroup_fence_operand(tCrC);
    if constexpr (arrive) {
        warpgroup_arrive();
    }
    if constexpr (zero_init) {
        tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;
        CUTLASS_PRAGMA_UNROLL
        for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {
            cute::gemm(tiled_mma, tCrA(_,_,k_block), tCrB(_,_,k_block), tCrC);
            tiled_mma.accumulate_ = GMMA::ScaleOut::One;
        }
    } else {
        CUTLASS_PRAGMA_UNROLL
        for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {
            cute::gemm(tiled_mma, tCrA(_,_,k_block), tCrB(_,_,k_block), tCrC);
            tiled_mma.accumulate_ = GMMA::ScaleOut::One;
        }
    }
    if constexpr (commit) {
        warpgroup_commit_batch();
    }
    if constexpr (wg_wait >= 0) { warpgroup_wait<wg_wait>(); }
    warpgroup_fence_operand(tCrC);
    if constexpr (Is_RS) { warpgroup_fence_operand(const_cast<Tensor0 &>(tCrA)); }
}

template<typename Layout>
__forceinline__ __device__ auto convert_layout_acc_rowcol(Layout acc_layout) {
    if constexpr (decltype(rank<0>(acc_layout))::value == 3) {  // SM90
        static_assert(decltype(size<0, 0>(acc_layout))::value == 2);
        static_assert(decltype(size<0, 1>(acc_layout))::value == 2);
        static_assert(decltype(rank(acc_layout))::value == 3);
        auto l = acc_layout;
        return make_layout(make_layout(get<0, 1>(l), get<1>(l)), make_layout(get<0, 0>(l), get<0, 2>(l), get<2>(l)));
    } else {  // SM80
        static_assert(decltype(size<0>(acc_layout))::value == 4);
        static_assert(decltype(rank(acc_layout))::value == 3);
        auto l = logical_divide(acc_layout, Shape<_2>{});  // ((2, 2), MMA_M, MMA_N)
        return make_layout(make_layout(get<0, 1>(l), get<1>(l)), make_layout(get<0, 0>(l), get<2>(l)));
    }
};

template<typename T, typename ReductionOp, int thread_group_width = 32>
__inline__ __device__ T WarpAllReduce(T val) {
    ReductionOp op;
#pragma unroll
    for (int mask = thread_group_width / 2; mask > 0; mask /= 2) {
        val = op(val, __shfl_xor_sync(0xffffffff, val, mask));
    }
    return val;
}

template <int kPackSize, int knthreads>
__device__ inline int get_data_count(const float * src, const float limit_value) {
    int count = 0;
#pragma unroll
    for (int i = 0; i < kPackSize; i++) {
        if (src[i] >= limit_value) {
            count++;
        }
    }
    count = BlockAllReduce<int, SumOp<int>, knthreads>(count);
    return count;
}
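A small usage sketch for the reduction helpers above (hypothetical demo kernel, launched with one 32-thread warp per block): every lane contributes one value, and after WarpAllReduce every lane of the group holds the same maximum.

// Hypothetical demo kernel, not part of the diff.
__global__ void warp_max_demo(const float *in, float *out) {
    float v = in[blockIdx.x * 32 + threadIdx.x];
    v = WarpAllReduce<float, MaxOp<float>, 32>(v);  // butterfly __shfl_xor_sync reduction
    if (threadIdx.x == 0) out[blockIdx.x] = v;      // every lane holds the max; lane 0 writes it
}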
@@ -0,0 +1,802 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once
#include "paddle/extension.h"
#include "moba_decoder_attn_kernel.h"
#include "moba_attn/moba_attn.h"

template<bool Is_first, int kMiLen, typename Tensor0, typename Tensor1, typename T>
inline __device__ void softmax_rescale_o(Tensor0 &scores, Tensor1 &acc_o, const T *scores_max, const T *scores_max_prev, T * scores_sum, const float softmax_scale) {
    if (Is_first) {
        scale_apply_exp2<kMiLen>(scores, scores_max, scores_sum, softmax_scale);
    } else {
        Tensor acc_o_rowcol = make_tensor(acc_o.data(), convert_layout_acc_rowcol(acc_o.layout()));
#pragma unroll
        for (int mi = 0; mi < kMiLen; ++mi) {
            const float scores_scale = expf((scores_max_prev[mi] - scores_max[mi]) * softmax_scale);
            scores_sum[mi] *= scores_scale;
#pragma unroll
            for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) {
                acc_o_rowcol(mi, ni) *= scores_scale;
            }
        }
        scale_apply_exp2<kMiLen>(scores, scores_max, scores_sum, softmax_scale);
    }
};
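softmax_rescale_o implements the standard running-softmax update used by flash-style kernels: when a new KV tile raises a row's maximum, the partial sum and the partial output accumulator are first rescaled by exp((m_old - m_new) * scale), and only then are the new tile's probabilities folded in. A scalar reference of the same bookkeeping, illustrative only (the probability-times-V accumulation itself happens in the GEMM that follows in the kernel):

#include <algorithm>
#include <cmath>
#include <vector>

// One online-softmax step for a single row: update running max/sum and rescale the accumulator.
void online_softmax_step(const std::vector<float>& tile_scores, float scale,
                         float& row_max, float& row_sum, std::vector<float>& acc_o) {
    float new_max = row_max;
    for (float s : tile_scores) new_max = std::max(new_max, s);
    const float correction = std::exp((row_max - new_max) * scale);
    row_sum *= correction;
    for (float& o : acc_o) o *= correction;   // rescale the already-accumulated output row
    for (float s : tile_scores) row_sum += std::exp((s - new_max) * scale);
    row_max = new_max;
}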
template<typename Kernel_traits, typename ParamType>
|
||||
__global__ __launch_bounds__(Kernel_traits::kNThreads) void moba_decoder_attention_kernel(ParamType params) {
|
||||
using cuteType = typename Kernel_traits::cuteType;
|
||||
using ElementAccum = typename Kernel_traits::ElementAccum;
|
||||
using CacheKV_traits = typename Kernel_traits::CacheKV_traits;
|
||||
constexpr int32_t kHeadDim = Kernel_traits::kHeadDim;
|
||||
constexpr int32_t kHeadDimKV = Kernel_traits::kHeadDimKV;
|
||||
constexpr int32_t kBlockM = Kernel_traits::kBlockM;
|
||||
constexpr int32_t kBlockSize = Kernel_traits::kBlockSize;
|
||||
constexpr int32_t kGqaGroupSize = Kernel_traits::kGqaGroupSize;
|
||||
constexpr int32_t kNWarps = Kernel_traits::kNWarps;
|
||||
constexpr int32_t kTileN = Kernel_traits::kTileN;
|
||||
constexpr int32_t kBlockN = kTileN * kBlockSize;
|
||||
constexpr int32_t kDataBits = Kernel_traits::kDataBits;
|
||||
constexpr int32_t kMiLen = (kGqaGroupSize + 7) / 8;
|
||||
|
||||
const int32_t bi = blockIdx.y;
|
||||
const int32_t tidx = threadIdx.x;
|
||||
const int32_t partition_idx = blockIdx.x;
|
||||
const int32_t kv_head_idx = blockIdx.z;
|
||||
const int32_t q_head_idx = kv_head_idx * kGqaGroupSize;
|
||||
|
||||
const int32_t seq_len = params.seq_lens_decoder[bi] == 0 ? 0 : params.seq_lens_decoder[bi] + 1;
|
||||
|
||||
const int32_t head_num = params.head_num;
|
||||
const int32_t kv_head_num = params.kv_head_num;
|
||||
|
||||
const int32_t partition_num = (seq_len + kBlockN - 1) / kBlockN;
|
||||
|
||||
if (seq_len == 0 || partition_idx >= partition_num) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (seq_len >= params.use_moba_seq_limit && params.qk_gate_topk_idx_ptr[(bi * kv_head_num + kv_head_idx) * Kernel_traits::kMaxN + partition_idx] == 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
|
||||
const int q_bias_offset = q_head_idx * kHeadDim;
|
||||
|
||||
cuteType * q_input = reinterpret_cast<cuteType *>(params.q_input) + params.cu_seq_q[bi] * head_num * kHeadDim;
|
||||
|
||||
Tensor gQ = make_tensor(
|
||||
make_gmem_ptr(reinterpret_cast<const cuteType *>(q_input) + q_bias_offset),
|
||||
Shape<Int<kBlockM>, Int<kHeadDim>>{},
|
||||
Stride<Int<kHeadDim>, _1>{});
|
||||
|
||||
const int32_t block_idx = partition_idx * kTileN;
|
||||
const int* block_table = params.block_table + bi * params.max_num_blocks_per_seq + block_idx;
|
||||
const int32_t physical_block_number = block_table[0];
|
||||
|
||||
const int32_t cache_offset = (physical_block_number * kv_head_num + kv_head_idx) * kBlockSize * kHeadDimKV;
|
||||
|
||||
Tensor gK = make_tensor(
|
||||
make_gmem_ptr(reinterpret_cast<const cuteType *>(params.cache_k) + cache_offset),
|
||||
Shape<Int<kBlockSize>, Int<kHeadDimKV>>{},
|
||||
Stride<Int<kHeadDimKV>, _1>{});
|
||||
|
||||
Tensor gV = make_tensor(
|
||||
make_gmem_ptr(reinterpret_cast<const cuteType *>(params.cache_v) + cache_offset),
|
||||
Shape<Int<kBlockSize>, Int<kHeadDimKV>>{},
|
||||
Stride<Int<kHeadDimKV>, _1>{});
|
||||
|
||||
extern __shared__ char smem_[];
|
||||
Tensor sQ = make_tensor(
|
||||
make_smem_ptr(reinterpret_cast<cuteType *>(smem_)),
|
||||
typename Kernel_traits::SmemLayoutQ{});
|
||||
Tensor sQK = make_tensor(
|
||||
sQ.data() + size(sQ),
|
||||
typename Kernel_traits::SmemLayoutQK{});
|
||||
|
||||
Tensor sK = make_tensor(sQK.data() + size(sQK), typename Kernel_traits::SmemLayoutKV{});
|
||||
Tensor sV = make_tensor(sK.data() + size(sK), typename Kernel_traits::SmemLayoutKV{});
|
||||
Tensor sVt = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposed{});
|
||||
Tensor sVtNoSwizzle = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposedNoSwizzle{});
|
||||
__shared__ ElementAccum scores_warp[kNWarps][kMiLen * kBlockM];
|
||||
|
||||
auto gmem_tiled_copy_Q = typename Kernel_traits::GmemTiledCopyQ{};
|
||||
auto gmem_thr_copy_Q = gmem_tiled_copy_Q.get_thread_slice(tidx);
|
||||
|
||||
auto gmem_tiled_copy_KV = typename Kernel_traits::GmemTiledCopyKV{};
|
||||
auto gmem_thr_copy_KV = gmem_tiled_copy_KV.get_thread_slice(tidx);
|
||||
|
||||
Tensor tQgQ = gmem_thr_copy_Q.partition_S(gQ);
|
||||
Tensor tQsQ = gmem_thr_copy_Q.partition_D(sQ);
|
||||
|
||||
Tensor tKgK = gmem_thr_copy_KV.partition_S(gK);
|
||||
Tensor tKsK = gmem_thr_copy_KV.partition_D(sK);
|
||||
Tensor tVgV = gmem_thr_copy_KV.partition_S(gV);
|
||||
Tensor tVsV = gmem_thr_copy_KV.partition_D(sV);
|
||||
|
||||
Tensor cQ = make_identity_tensor(make_shape(kBlockM, kHeadDim));
|
||||
Tensor tQcQ = gmem_thr_copy_Q.partition_S(cQ);
|
||||
|
||||
Tensor cKV = make_identity_tensor(make_shape(kBlockSize, kHeadDim));
|
||||
Tensor tKVcKV = gmem_thr_copy_KV.partition_S(cKV);
|
||||
|
||||
typename Kernel_traits::TiledMma tiled_mma;
|
||||
|
||||
auto thr_mma = tiled_mma.get_thread_slice(tidx);
|
||||
using SmemCopyAtom = typename Kernel_traits::SmemCopyAtom;
|
||||
auto smem_tiled_copy_Q = make_tiled_copy_A(SmemCopyAtom{}, tiled_mma);
|
||||
auto smem_thr_copy_Q = smem_tiled_copy_Q.get_thread_slice(tidx);
|
||||
auto smem_tiled_copy_K = make_tiled_copy_B(SmemCopyAtom{}, tiled_mma);
|
||||
auto smem_thr_copy_K = smem_tiled_copy_K.get_thread_slice(tidx);
|
||||
auto smem_tiled_copy_V = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtomTransposed{}, tiled_mma);
|
||||
auto smem_thr_copy_V = smem_tiled_copy_V.get_thread_slice(tidx);
|
||||
auto smem_tiled_copy_O = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomO{}, tiled_mma);
|
||||
auto smem_thr_copy_O = smem_tiled_copy_O.get_thread_slice(tidx);
|
||||
|
||||
Tensor tSsQ = smem_thr_copy_Q.partition_S(sQ);
|
||||
Tensor tSrQ = thr_mma.partition_fragment_A(sQ);
|
||||
|
||||
Tensor tSsQK = smem_thr_copy_Q.partition_S(sQK);
|
||||
Tensor tSrQK = thr_mma.partition_fragment_A(sQK);
|
||||
|
||||
Tensor tSsK = smem_thr_copy_K.partition_S(sK);
|
||||
Tensor tSrK = thr_mma.partition_fragment_B(sK);
|
||||
Tensor tOsVt = smem_thr_copy_V.partition_S(sVt);
|
||||
Tensor tOrVt = thr_mma.partition_fragment_B(sVtNoSwizzle);
|
||||
|
||||
copy<false>(gmem_tiled_copy_Q, tQgQ, tQsQ, tQcQ, kGqaGroupSize);
|
||||
|
||||
|
||||
cute::cp_async_fence();
|
||||
cp_async_wait<0>();
|
||||
|
||||
const int32_t remain_seq_len = seq_len - partition_idx * kTileN * kBlockSize;
|
||||
|
||||
copy(gmem_tiled_copy_KV, tKgK, tKsK, tKVcKV);
|
||||
|
||||
cute::cp_async_fence();
|
||||
|
||||
const int32_t warp_id = tidx / 32;
|
||||
const int32_t lane_id = tidx % 32;
|
||||
const int32_t row = lane_id / 4;
|
||||
const int32_t col = lane_id % 4;
|
||||
const int row_idx = tidx / 4;
|
||||
|
||||
    using scale_k_vec = Vec<cuteType, 32>;
    using scale_v_vec = Vec<cuteType, 4>;

    scale_k_vec scale_k;
    scale_k_vec zp_k;
    scale_v_vec scale_v;
    scale_v_vec zp_v;
    if constexpr (kDataBits == 4) {
        scale_k = *reinterpret_cast<const scale_k_vec*>(params.cache_k_dequant_scale + kv_head_idx * kHeadDim + col * 32);
        zp_k = *reinterpret_cast<const scale_k_vec*>(params.cache_k_zp + kv_head_idx * kHeadDim + col * 32);
        scale_v = *reinterpret_cast<const scale_v_vec*>(params.cache_v_dequant_scale + kv_head_idx * kHeadDim + row_idx * 4);
        zp_v = *reinterpret_cast<const scale_v_vec*>(params.cache_v_zp + kv_head_idx * kHeadDim + row_idx * 4);
    }
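    // 4-bit cache path only: the dequant scale / zero-point vectors loaded above stay in
    // registers and are reused to dequantize every K / V tile inside gemm_qk_quant and
    // gemm_value_quant in the main loop below.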
    Tensor acc_o = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kHeadDim>>{});
    clear(acc_o);
    Tensor acc_s = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kBlockSize>>{});

    ElementAccum scores_max[kMiLen];
    ElementAccum scores_max_prev[kMiLen];
    ElementAccum scores_sum[kMiLen];

    #pragma unroll
    for (int mi = 0; mi < kMiLen; ++mi) {
        scores_max[mi] = -INFINITY;
        scores_sum[mi] = 0;
    }
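    // Online-softmax running state: the per-row max starts at -inf and the per-row sum at 0,
    // and both are updated as each key block of this partition is processed.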
|
||||
|
||||
const int cache_offset_step = kv_head_num * kBlockSize * kHeadDimKV;
|
||||
|
||||
#pragma unroll
|
||||
for (int n = 0; n < kTileN; ++n) {
|
||||
const int cur_remain_seq_len = remain_seq_len - n * kBlockSize;
|
||||
|
||||
if (cur_remain_seq_len <= 0) {
|
||||
break;
|
||||
}
|
||||
|
||||
clear(acc_s);
|
||||
cp_async_wait<0>();
|
||||
__syncthreads();
|
||||
|
||||
if (n > 0) {
|
||||
tVgV.data() = tVgV.data() + (block_table[n] - block_table[n - 1]) * cache_offset_step;
|
||||
}
|
||||
|
||||
copy(gmem_tiled_copy_KV, tVgV, tVsV, tKVcKV);
|
||||
|
||||
cute::cp_async_fence();
|
||||
|
||||
if constexpr (kDataBits == 16) {
|
||||
if (n == 0) {
|
||||
gemm(acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_thr_copy_Q, smem_thr_copy_K, smem_tiled_copy_Q, smem_tiled_copy_K);
|
||||
} else {
|
||||
gemm<true>(acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_thr_copy_Q, smem_thr_copy_K, smem_tiled_copy_Q, smem_tiled_copy_K);
|
||||
}
|
||||
} else {
|
||||
Tensor tSrKQuant = make_tensor<cuteType>(
|
||||
Layout<
|
||||
Shape<Shape<_2, _2>, Int<kBlockSize / 32>>,
|
||||
Stride<Shape<_1, _2>, _4>>{});
|
||||
if (n == 0) {
|
||||
gemm_qk_quant<CacheKV_traits, cuteType, kHeadDim, kDataBits>(acc_s, tSrQ, tSsQ, tSrKQuant, sK, tiled_mma, smem_thr_copy_Q, smem_tiled_copy_Q, tidx, scale_k.data.elt, zp_k.data.elt);
|
||||
} else {
|
||||
gemm_qk_quant<CacheKV_traits, cuteType, kHeadDim, kDataBits, true>(acc_s, tSrQ, tSsQ, tSrKQuant, sK, tiled_mma, smem_thr_copy_Q, smem_tiled_copy_Q, tidx, scale_k.data.elt, zp_k.data.elt);
|
||||
}
|
||||
}
|
||||
Tensor scores = make_tensor(acc_s.data(), convert_layout_acc_rowcol(acc_s.layout()));
|
||||
|
||||
if (partition_idx == partition_num - 1 && cur_remain_seq_len < kBlockSize) {
|
||||
apply_mask<kMiLen>(scores, warp_id, col, cur_remain_seq_len);
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int mi = 0; mi < kMiLen; ++mi) {
|
||||
scores_max_prev[mi] = scores_max[mi];
|
||||
}
|
||||
|
||||
reduce_max<kMiLen>(scores, scores_max);
|
||||
|
||||
if (col == 0) {
|
||||
scores_warp[warp_id][row] = scores_max[0];
|
||||
if constexpr (kMiLen > 1) {
|
||||
scores_warp[warp_id][row + 8] = scores_max[1];
|
||||
}
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
MaxOp<ElementAccum> max_op;
|
||||
|
||||
if (tidx < kGqaGroupSize) {
|
||||
float cur_max = scores_warp[0][tidx];
|
||||
#pragma unroll
|
||||
for (uint32_t i = 1; i < kNWarps; ++i) {
|
||||
cur_max = max_op(scores_warp[i][tidx], cur_max);
|
||||
}
|
||||
scores_warp[0][tidx] = cur_max;
|
||||
}
|
||||
|
||||
        cp_async_wait<0>();
        __syncthreads();

        if (cur_remain_seq_len > kBlockSize && n < kTileN - 1) {
            tKgK.data() = tKgK.data() + (block_table[n + 1] - block_table[n]) * cache_offset_step;
            copy(gmem_tiled_copy_KV, tKgK, tKsK, tKVcKV);
            cute::cp_async_fence();
        }
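        // Software pipelining: the K tile of block n + 1 is already being fetched with
        // cp.async (stepping through block_table to the next physical cache block) while
        // the softmax rescale and the P * V GEMM of block n run below.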
|
||||
|
||||
#pragma unroll
|
||||
for (int mi = 0; mi < kMiLen; ++mi) {
|
||||
scores_max[mi] = scores_warp[0][row + mi * 8];
|
||||
}
|
||||
|
||||
if (n == 0) {
|
||||
softmax_rescale_o<true, kMiLen>(scores, acc_o, scores_max, scores_max_prev, scores_sum, params.inv_sqrt_dh);
|
||||
} else {
|
||||
softmax_rescale_o<false, kMiLen>(scores, acc_o, scores_max, scores_max_prev, scores_sum, params.inv_sqrt_dh);
|
||||
}
|
||||
|
||||
Tensor rS = convert_type<cuteType>(acc_s);
|
||||
|
||||
Tensor trQK = smem_thr_copy_O.retile_S(rS);
|
||||
Tensor tsQK = smem_thr_copy_O.partition_D(sQK);
|
||||
cute::copy(smem_tiled_copy_O, trQK, tsQK);
|
||||
|
||||
__syncthreads();
|
||||
|
||||
        if constexpr (kDataBits == 16) {
            gemm(acc_o, tSrQK, tOrVt, tSsQK, tOsVt, tiled_mma, smem_thr_copy_Q, smem_thr_copy_V, smem_tiled_copy_Q, smem_tiled_copy_V);
        } else {
            Tensor tSrVQuant = make_tensor<cuteType>(
                Layout<
                    Shape<_4, Shape<_2, _2>>,
                    Stride<_1, Shape<_4, _8>>>{});
            gemm_value_quant<CacheKV_traits, cuteType, kHeadDim, kDataBits>(acc_o, tSrQK, tSsQK, tSrVQuant, sV, tiled_mma, smem_thr_copy_Q, smem_tiled_copy_Q, tidx, scale_v.data.elt, zp_v.data.elt);
        }
    }
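    // After the tile loop this thread block holds a flash-decoding style partial result for
    // one sequence partition: the unnormalized output acc_o plus the partition's row max and
    // row sum. The merge kernel later combines partitions as
    //   out = sum_p exp(m_p - m) * o_p / sum_p exp(m_p - m) * s_p,   m = max_p m_p,
    // (m_p, s_p, o_p = this partition's max, sum and output), which is why only maxs, sums
    // and partition_attn_out are written out below.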
|
||||
|
||||
const uint32_t pack_max_partition_num = (params.max_num_partitions + 3) / 4 * 4;
|
||||
uint32_t max_sum_offset = bi * pack_max_partition_num * head_num + (tidx + q_head_idx) * pack_max_partition_num + partition_idx;
|
||||
|
||||
if (tidx < kGqaGroupSize) {
|
||||
params.maxs[max_sum_offset] = scores_warp[0][tidx] * params.inv_sqrt_dh;
|
||||
}
|
||||
|
||||
SumOp<ElementAccum> sum_op;
|
||||
#pragma unroll
|
||||
for (int mi = 0; mi < kMiLen; ++mi) {
|
||||
scores_sum[mi] = Allreduce<4>::run(scores_sum[mi], sum_op);
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
if (col == 0) {
|
||||
scores_warp[warp_id][row] = scores_sum[0];
|
||||
if constexpr (kMiLen > 1) {
|
||||
scores_warp[warp_id][row + 8] = scores_sum[1];
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Tensor rO = convert_type<cuteType>(acc_o);
|
||||
Tensor taccOrO = smem_thr_copy_O.retile_S(rO);
|
||||
Tensor taccOsO = smem_thr_copy_O.partition_D(sQ);
|
||||
|
||||
cute::copy(smem_tiled_copy_O, taccOrO, taccOsO);
|
||||
|
||||
__syncthreads();
|
||||
|
||||
if (tidx < kGqaGroupSize) {
|
||||
float cur_sum = scores_warp[0][tidx];
|
||||
#pragma unroll
|
||||
for (uint32_t i = 1; i < kNWarps; ++i) {
|
||||
cur_sum = sum_op(scores_warp[i][tidx], cur_sum);
|
||||
}
|
||||
scores_warp[0][tidx] = cur_sum;
|
||||
}
|
||||
|
||||
Tensor gO = make_tensor(
|
||||
make_gmem_ptr(reinterpret_cast<cuteType *>(params.partition_attn_out) + ((bi * params.max_num_partitions + partition_idx) * head_num + q_head_idx)* kHeadDim),
|
||||
Shape<Int<kBlockM>, Int<kHeadDim>>{},
|
||||
Stride<Int<kHeadDim>, _1>{});
|
||||
|
||||
auto gmem_tiled_copy_O = typename Kernel_traits::GmemTiledCopyO{};
|
||||
auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(tidx);
|
||||
Tensor tOsO = gmem_thr_copy_O.partition_S(sQ);
|
||||
Tensor tOgO = gmem_thr_copy_O.partition_D(gO);
|
||||
constexpr int32_t copy_size = kGqaGroupSize * 16;
|
||||
__syncthreads();
|
||||
|
||||
if (tidx < copy_size) {
|
||||
cute::copy(gmem_tiled_copy_O, tOsO(_, 0, _), tOgO(_, 0, _));
|
||||
}
|
||||
|
||||
if constexpr (kMiLen > 1) {
|
||||
if (tidx < copy_size - 128) {
|
||||
cute::copy(gmem_tiled_copy_O, tOsO(_, 1, _), tOgO(_, 1, _));
|
||||
}
|
||||
}
|
||||
|
||||
if (tidx < kGqaGroupSize) {
|
||||
params.sums[max_sum_offset] = scores_warp[0][tidx];
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
template<typename Kernel_traits, typename ParamType>
|
||||
inline __device__ float caluate_logit_scale(const int partition_num, const int pack_max_partition_num, ParamType ¶ms, char * shared_mem, const int seq_len, const int *qk_gate_topk_idx_ptr) {
|
||||
constexpr int32_t kNFloatPacksize = 16 / sizeof(float);
|
||||
constexpr int32_t kNReduceThreads = Kernel_traits::kNReduceThreads;
|
||||
const int32_t bi = blockIdx.z;
|
||||
const int32_t tidx = threadIdx.x;
|
||||
const int32_t head_idx = blockIdx.y;
|
||||
const int32_t head_num = params.head_num;
|
||||
|
||||
using float_vec = Vec<float, kNFloatPacksize>;
|
||||
const int32_t offset = bi * head_num * pack_max_partition_num + head_idx * pack_max_partition_num;
|
||||
|
||||
float* shared_max_logits = reinterpret_cast<float*>(shared_mem);
|
||||
const float* max_logits_ptr = params.maxs + offset;
|
||||
float global_max_logit = -FLT_MAX;
|
||||
|
||||
int32_t idx = tidx * kNFloatPacksize;
|
||||
#pragma unroll
|
||||
for (; idx <= partition_num - kNFloatPacksize; idx += kNReduceThreads * kNFloatPacksize) {
|
||||
float_vec cur_max = *reinterpret_cast<const float_vec*>(max_logits_ptr + idx);
|
||||
#pragma unroll
|
||||
for (int32_t j = 0; j < kNFloatPacksize; ++j) {
|
||||
if (seq_len >= params.use_moba_seq_limit) {
|
||||
if (qk_gate_topk_idx_ptr[idx + j] != 0) {
|
||||
global_max_logit = fmaxf(global_max_logit, cur_max.data.elt[j]);
|
||||
}
|
||||
} else {
|
||||
global_max_logit = fmaxf(global_max_logit, cur_max.data.elt[j]);
|
||||
}
|
||||
}
|
||||
cur_max.store_to(shared_max_logits + idx);
|
||||
}
|
||||
|
||||
const int32_t packed_data_num = partition_num / kNFloatPacksize * kNFloatPacksize;
|
||||
|
||||
idx = packed_data_num + tidx;
|
||||
#pragma unroll
|
||||
for (; idx < partition_num; idx += kNReduceThreads) {
|
||||
if (seq_len >= params.use_moba_seq_limit) {
|
||||
if (qk_gate_topk_idx_ptr[idx] != 0) {
|
||||
float cur_max = max_logits_ptr[idx];
|
||||
global_max_logit = fmaxf(global_max_logit, cur_max);
|
||||
shared_max_logits[idx] = cur_max;
|
||||
}
|
||||
} else {
|
||||
float cur_max = max_logits_ptr[idx];
|
||||
global_max_logit = fmaxf(global_max_logit, cur_max);
|
||||
shared_max_logits[idx] = cur_max;
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
global_max_logit = BlockAllReduce<float, MaxOp<float>, kNReduceThreads>(global_max_logit);
|
||||
|
||||
float* share_sum_scale = reinterpret_cast<float*>(shared_mem + sizeof(float) * pack_max_partition_num);
|
||||
const float* exp_sums_ptr = params.sums + offset;
|
||||
float global_exp_sum = 0.0f;
|
||||
|
||||
idx = tidx * kNFloatPacksize;
|
||||
#pragma unroll
|
||||
for (; idx <= partition_num - kNFloatPacksize; idx += kNReduceThreads * kNFloatPacksize) {
|
||||
float_vec share_max = *reinterpret_cast<const float_vec*>(shared_max_logits + idx);
|
||||
#pragma unroll
|
||||
for (int32_t j = 0; j < kNFloatPacksize; ++j) {
|
||||
if (seq_len >= params.use_moba_seq_limit) {
|
||||
if (qk_gate_topk_idx_ptr[idx + j] != 0) {
|
||||
float exp_sub_max = expf(share_max.data.elt[j] - global_max_logit);
|
||||
float rescaled_exp_sum = exp_sums_ptr[idx + j] * exp_sub_max;
|
||||
global_exp_sum += rescaled_exp_sum;
|
||||
share_max.data.elt[j] = exp_sub_max;
|
||||
}
|
||||
} else {
|
||||
float exp_sub_max = expf(share_max.data.elt[j] - global_max_logit);
|
||||
float rescaled_exp_sum = exp_sums_ptr[idx + j] * exp_sub_max;
|
||||
global_exp_sum += rescaled_exp_sum;
|
||||
share_max.data.elt[j] = exp_sub_max;
|
||||
}
|
||||
}
|
||||
share_max.store_to(share_sum_scale + idx);
|
||||
}
|
||||
|
||||
idx = packed_data_num + tidx;
|
||||
#pragma unroll
|
||||
for (; idx < partition_num; idx += kNReduceThreads) {
|
||||
if (seq_len >= params.use_moba_seq_limit) {
|
||||
if (qk_gate_topk_idx_ptr[idx] != 0) {
|
||||
float share_max = shared_max_logits[idx];
|
||||
float exp_sub_max = expf(share_max - global_max_logit);
|
||||
float rescaled_exp_sum = exp_sums_ptr[idx] * exp_sub_max;
|
||||
global_exp_sum += rescaled_exp_sum;
|
||||
share_sum_scale[idx] = exp_sub_max;
|
||||
}
|
||||
} else {
|
||||
float share_max = shared_max_logits[idx];
|
||||
float exp_sub_max = expf(share_max - global_max_logit);
|
||||
float rescaled_exp_sum = exp_sums_ptr[idx] * exp_sub_max;
|
||||
global_exp_sum += rescaled_exp_sum;
|
||||
share_sum_scale[idx] = exp_sub_max;
|
||||
}
|
||||
}
|
||||
    __syncthreads();

    global_exp_sum = BlockAllReduce<float, SumOp<float>, kNReduceThreads>(global_exp_sum);

    const float inv_global_exp_sum = fdividef(1.0f, global_exp_sum + 1e-6f);
    return inv_global_exp_sum;
}
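// caluate_logit_scale is the reduction half of that merge: it scans the per-partition
// (max, sum) pairs, skips partitions whose MoBA gate entry is 0 once the sequence exceeds
// use_moba_seq_limit, leaves exp(max_p - global_max) per partition in shared memory, and
// returns 1 / (rescaled global sum). A scalar reference of the same math, ignoring the
// gating (illustrative sketch only, not part of the kernel):
//
//   float merge_scale(const float* maxs, const float* sums, float* scale, int n) {
//       float m = -FLT_MAX;
//       for (int i = 0; i < n; ++i) m = fmaxf(m, maxs[i]);
//       float s = 0.f;
//       for (int i = 0; i < n; ++i) { scale[i] = expf(maxs[i] - m); s += sums[i] * scale[i]; }
//       return 1.f / (s + 1e-6f);
//   }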
|
||||
|
||||
template<typename Kernel_traits, typename ParamType>
|
||||
__global__ void __launch_bounds__(Kernel_traits::kNReduceThreads) moba_decoder_attention_merge_kernel(ParamType params) {
|
||||
using cuteType = typename Kernel_traits::cuteType;
|
||||
constexpr int32_t kBlockN = Kernel_traits::kTileN * Kernel_traits::kBlockSize;
|
||||
constexpr int32_t kNReducePacksize = 16 / sizeof(cuteType);
|
||||
constexpr int32_t kNFloatPacksize = 16 / sizeof(float);
|
||||
constexpr int32_t kNReduceWarps = Kernel_traits::kNReduceWarps;
|
||||
constexpr int32_t kHeadDim = Kernel_traits::kHeadDim;
|
||||
const int32_t bi = blockIdx.z;
|
||||
const int32_t headdim_idx = kNReducePacksize * kNReduceWarps * blockIdx.x;
|
||||
const int32_t tidx = threadIdx.x;
|
||||
const int32_t head_idx = blockIdx.y;
|
||||
const int32_t warp_id = tidx / 32;
|
||||
const int32_t lane_id = tidx % 32;
|
||||
const int32_t seq_len = params.seq_lens_decoder[bi] + 1;
|
||||
const int32_t head_num = params.head_num;
|
||||
using pack_half = typename PackedHalf<cuteType>::Type;
|
||||
|
||||
|
||||
if (params.seq_lens_decoder[bi] == 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
    extern __shared__ char shared_mem[];

    const int32_t partition_num = (seq_len + kBlockN - 1) / kBlockN;
    const int32_t pack_max_partition_num = (params.max_num_partitions + kNFloatPacksize - 1) / kNFloatPacksize * kNFloatPacksize;
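    // Dynamic shared memory holds two float arrays of pack_max_partition_num entries each:
    // the per-partition maxes first, then the exp(max_p - global_max) scales that
    // caluate_logit_scale leaves behind for the weighted sum below.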
|
||||
|
||||
float* share_sum_scale = reinterpret_cast<float*>(shared_mem + sizeof(float) * pack_max_partition_num);
|
||||
|
||||
constexpr int32_t kGqaGroupSize = Kernel_traits::kGqaGroupSize;
|
||||
const int kv_head_idx = head_idx / Kernel_traits::kGqaGroupSize;
|
||||
const int * qk_gate_topk_idx_ptr = params.qk_gate_topk_idx_ptr + (bi * params.kv_head_num + kv_head_idx) * Kernel_traits::kMaxN;
|
||||
|
||||
float inv_global_exp_sum = caluate_logit_scale<Kernel_traits>(partition_num, pack_max_partition_num, params, shared_mem, seq_len, qk_gate_topk_idx_ptr);
|
||||
|
||||
|
||||
using T_vec = Vec<cuteType, kNReducePacksize>;
|
||||
|
||||
cuteType* partition_attn_out = reinterpret_cast<cuteType*>(params.partition_attn_out) + bi * head_num * params.max_num_partitions * kHeadDim + head_idx * kHeadDim + headdim_idx;
|
||||
|
||||
Vec<float, kNReducePacksize> acc;
|
||||
acc.set_zero();
|
||||
#pragma unroll
|
||||
for (int idx = lane_id; idx < partition_num; idx += 32) {
|
||||
if (seq_len >= params.use_moba_seq_limit && qk_gate_topk_idx_ptr[idx] == 0) {
|
||||
continue;
|
||||
}
|
||||
T_vec sub_logits = *reinterpret_cast<T_vec*>(&partition_attn_out[idx * head_num * kHeadDim + warp_id * kNReducePacksize]);
|
||||
float scale = share_sum_scale[idx];
|
||||
#pragma unroll
|
||||
for (int k = 0; k < kNReducePacksize; ++k) {
|
||||
acc.data.elt[k] += static_cast<float>(sub_logits.data.elt[k]) * scale;
|
||||
}
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
T_vec out;
|
||||
#pragma unroll
|
||||
for (int k = 0; k < kNReducePacksize; ++k) {
|
||||
out.data.elt[k] = static_cast<cuteType>(WarpAllReduce<float, SumOp<float>>(acc.data.elt[k]) * inv_global_exp_sum);
|
||||
}
|
||||
|
||||
const int ori_token_idx = params.cu_seq_q[bi];
|
||||
cuteType * attn_out = reinterpret_cast<cuteType *>(params.attn_out) + ori_token_idx * head_num * kHeadDim + head_idx * kHeadDim + headdim_idx + warp_id * kNReducePacksize;
|
||||
|
||||
if (lane_id == 0) {
|
||||
out.store_to(attn_out);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
template<typename Kernel_traits, typename ParamType>
void run_moba_decoder_attn(ParamType &params, cudaStream_t stream) {
    dim3 grid;
    grid.x = params.max_num_partitions;
    grid.y = params.batch_size;
    grid.z = params.kv_head_num;
    constexpr int smem_size = Kernel_traits::kShareMemSize;
    constexpr auto kernel = &moba_decoder_attention_kernel<Kernel_traits, ParamType>;
    if (smem_size >= 48 * 1024) {
        cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
    }
    kernel<<<grid, Kernel_traits::kNThreads, smem_size, stream>>>(params);

    int32_t reduce_shared_mem_size = 2 * (params.max_num_partitions + 4) * sizeof(float);
    constexpr int32_t pack_size = 16 / sizeof(typename Kernel_traits::cuteType);
    static_assert(Kernel_traits::kHeadDim % pack_size == 0);
    static_assert((Kernel_traits::kHeadDim / Kernel_traits::kNReduceWarps) % pack_size == 0);
    grid.x = Kernel_traits::kHeadDim / Kernel_traits::kNReduceWarps / pack_size;
    grid.y = params.head_num;
    grid.z = params.batch_size;
    auto reduce_kernel = &moba_decoder_attention_merge_kernel<Kernel_traits, ParamType>;

    if (reduce_shared_mem_size >= 48 * 1024) {
        cudaFuncSetAttribute(
            reduce_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, reduce_shared_mem_size);
    }
    reduce_kernel<<<grid, Kernel_traits::kNReduceThreads, reduce_shared_mem_size, stream>>>(params);
}
|
||||
|
||||
|
||||
template<typename cute_type, int kCacheBits, int kBlockN, int kMaxN, typename ParamType>
void run_moba_decoder_attn_hdim128(ParamType &params, cudaStream_t stream) {
    const int gqaGroupSize = params.head_num / params.kv_head_num;
    using CacheKVTraits = CacheKV_quant_traits<cute_type, kCacheBits>;
    constexpr int kTileN = kBlockN / CacheKVTraits::kBlockSize;
    switch (gqaGroupSize) {
        case 12: {
            run_moba_decoder_attn<moba_decoder_attn_kernel_traits<12, kTileN, kMaxN, CacheKVTraits>>(params, stream);
            break;
        }
        case 8: {
            run_moba_decoder_attn<moba_decoder_attn_kernel_traits<8, kTileN, kMaxN, CacheKVTraits>>(params, stream);
            break;
        }
        case 7: {
            run_moba_decoder_attn<moba_decoder_attn_kernel_traits<7, kTileN, kMaxN, CacheKVTraits>>(params, stream);
            break;
        }
        case 6: {
            run_moba_decoder_attn<moba_decoder_attn_kernel_traits<6, kTileN, kMaxN, CacheKVTraits>>(params, stream);
            break;
        }
        case 5: {
            run_moba_decoder_attn<moba_decoder_attn_kernel_traits<5, kTileN, kMaxN, CacheKVTraits>>(params, stream);
            break;
        }
        case 4: {
            run_moba_decoder_attn<moba_decoder_attn_kernel_traits<4, kTileN, kMaxN, CacheKVTraits>>(params, stream);
            break;
        }
        default: {
            PADDLE_THROW(phi::errors::Unimplemented(
                "DecoderBlockAttention not implemented for gqaGroupSize = %d", gqaGroupSize));
        }
    }
}
|
||||
|
||||
|
||||
template <typename T>
|
||||
void DispatchMobaDecoderAttn(
|
||||
const paddle::Tensor& q_input,
|
||||
const paddle::Tensor& seq_len_encoder,
|
||||
const paddle::Tensor& seq_len_decoder,
|
||||
const paddle::Tensor& cu_seq_q,
|
||||
const paddle::Tensor& cache_k,
|
||||
const paddle::Tensor& cache_v,
|
||||
const paddle::Tensor& block_tables,
|
||||
const paddle::Tensor& k_block_means,
|
||||
const paddle::Tensor& out,
|
||||
const paddle::Tensor& qk_gate_topk_idx,
|
||||
const paddle::optional<paddle::Tensor>& cache_k_quant_scale,
|
||||
const paddle::optional<paddle::Tensor>& cache_v_quant_scale,
|
||||
const paddle::optional<paddle::Tensor>& cache_k_dequant_scale,
|
||||
const paddle::optional<paddle::Tensor>& cache_v_dequant_scale,
|
||||
const paddle::optional<paddle::Tensor>& cache_k_zero_points,
|
||||
const paddle::optional<paddle::Tensor>& cache_v_zero_points,
|
||||
const int head_num,
|
||||
const int kv_head_num,
|
||||
const int head_dim,
|
||||
const int max_seq_q,
|
||||
const int max_seq_k,
|
||||
const int batch_size,
|
||||
const int max_input_length,
|
||||
const int use_moba_seq_limit,
|
||||
const std::string &cache_quant_type_str) {
|
||||
|
||||
using cute_type = typename cuteType<T>::type;
|
||||
const int kMobaBlockSize = 128;
|
||||
const int kMaxN = 1024;
|
||||
|
||||
    constexpr int max_seq_per_block = kMobaBlockSize;
    moba_decoder_attn_params<cute_type> params;
    memset(&params, 0, sizeof(params));
    const uint32_t max_num_partitions = (max_seq_k + max_seq_per_block) / max_seq_per_block;
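    // One partition per max_seq_per_block (= 128) cached keys; the "+ max_seq_per_block"
    // (rather than "+ max_seq_per_block - 1") appears to reserve an extra partition for the
    // token currently being decoded.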
|
||||
assert(head_dim == 128);
|
||||
|
||||
paddle::Tensor maxs = paddle::empty({batch_size, head_num, (max_num_partitions + 3) / 4 * 4}, paddle::DataType::FLOAT32, q_input.place());
|
||||
paddle::Tensor sums = paddle::empty({batch_size, head_num, (max_num_partitions + 3) / 4 * 4}, paddle::DataType::FLOAT32, q_input.place());
|
||||
paddle::Tensor partition_attn_out = paddle::empty({batch_size, max_num_partitions, head_num, head_dim}, q_input.dtype(), q_input.place());
|
||||
|
||||
params.q_input = reinterpret_cast<cute_type *>(const_cast<T*>(q_input.data<T>()));
|
||||
params.attn_out = reinterpret_cast<cute_type *>(const_cast<T*>(out.data<T>()));
|
||||
params.seq_lens_encoder = const_cast<int*>(seq_len_encoder.data<int>());
|
||||
params.seq_lens_decoder = const_cast<int*>(seq_len_decoder.data<int>());
|
||||
params.block_table = const_cast<int*>(block_tables.data<int>());
|
||||
params.max_input_length = max_input_length;
|
||||
params.head_num = head_num;
|
||||
params.kv_head_num = kv_head_num;
|
||||
params.max_num_blocks_per_seq = block_tables.dims()[1];
|
||||
params.batch_size = batch_size;
|
||||
params.inv_sqrt_dh = 1.0f / std::sqrt(head_dim);
|
||||
params.max_num_partitions = max_num_partitions;
|
||||
params.maxs = reinterpret_cast<float*>(maxs.data<float>());
|
||||
params.sums = reinterpret_cast<float*>(sums.data<float>());
|
||||
params.partition_attn_out = reinterpret_cast<cute_type *>(partition_attn_out.data<T>());
|
||||
params.qk_gate_topk_idx_ptr = const_cast<int*>(qk_gate_topk_idx.data<int>());
|
||||
params.use_moba_seq_limit = use_moba_seq_limit;
|
||||
params.cu_seq_q = const_cast<int*>(cu_seq_q.data<int>());
|
||||
|
||||
|
||||
if (cache_quant_type_str == "none") {
|
||||
params.cache_k = reinterpret_cast<cute_type *>(const_cast<T*>(cache_k.data<T>()));
|
||||
params.cache_v = reinterpret_cast<cute_type *>(const_cast<T*>(cache_v.data<T>()));
|
||||
run_moba_decoder_attn_hdim128<cute_type, 16, max_seq_per_block, kMaxN>(params, q_input.stream());
|
||||
} else {
|
||||
params.cache_k = const_cast<uint8_t*>(cache_k.data<uint8_t>());
|
||||
params.cache_v = const_cast<uint8_t*>(cache_v.data<uint8_t>());
|
||||
params.cache_k_quant_scale = reinterpret_cast<cute_type *>(const_cast<T*>(cache_k_quant_scale.get().data<T>()));
|
||||
params.cache_v_quant_scale = reinterpret_cast<cute_type *>(const_cast<T*>(cache_v_quant_scale.get().data<T>()));
|
||||
params.cache_k_dequant_scale = reinterpret_cast<cute_type *>(const_cast<T*>(cache_k_dequant_scale.get().data<T>()));
|
||||
params.cache_v_dequant_scale = reinterpret_cast<cute_type *>(const_cast<T*>(cache_v_dequant_scale.get().data<T>()));
|
||||
params.cache_k_zp = reinterpret_cast<cute_type *>(const_cast<T*>(cache_k_zero_points.get().data<T>()));
|
||||
params.cache_v_zp = reinterpret_cast<cute_type *>(const_cast<T*>(cache_v_zero_points.get().data<T>()));
|
||||
if (cache_quant_type_str == "cache_int8_zp") {
|
||||
run_moba_decoder_attn_hdim128<cute_type, 8, max_seq_per_block, kMaxN>(params, q_input.stream());
|
||||
} else if (cache_quant_type_str == "cache_int4_zp") {
|
||||
run_moba_decoder_attn_hdim128<cute_type, 4, max_seq_per_block, kMaxN>(params, q_input.stream());
|
||||
} else {
|
||||
PADDLE_THROW(phi::errors::Unimplemented(
|
||||
"GQA Attention not implemented for cache_quant_type_str = %s", cache_quant_type_str.c_str()));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void MobaDecoderAttn(
|
||||
const paddle::Tensor& q_input,
|
||||
const paddle::Tensor& seq_len_encoder,
|
||||
const paddle::Tensor& seq_len_decoder,
|
||||
const paddle::Tensor& cu_seq_q,
|
||||
const paddle::Tensor& cache_k,
|
||||
const paddle::Tensor& cache_v,
|
||||
const paddle::Tensor& block_tables,
|
||||
const paddle::Tensor& k_block_means,
|
||||
const paddle::Tensor& out,
|
||||
const paddle::Tensor& qk_gate_topk_idx,
|
||||
const paddle::optional<paddle::Tensor>& cache_k_quant_scale,
|
||||
const paddle::optional<paddle::Tensor>& cache_v_quant_scale,
|
||||
const paddle::optional<paddle::Tensor>& cache_k_dequant_scale,
|
||||
const paddle::optional<paddle::Tensor>& cache_v_dequant_scale,
|
||||
const paddle::optional<paddle::Tensor>& cache_k_zero_points,
|
||||
const paddle::optional<paddle::Tensor>& cache_v_zero_points,
|
||||
const int head_num,
|
||||
const int kv_head_num,
|
||||
const int head_dim,
|
||||
const int max_input_length,
|
||||
const int use_moba_seq_limit,
|
||||
const int max_seq_q,
|
||||
const int max_seq_k,
|
||||
const std::string &cache_quant_type_str) {
|
||||
|
||||
const int batch_size = block_tables.dims()[0];
|
||||
if (q_input.dtype() == paddle::DataType::FLOAT16) {
|
||||
return DispatchMobaDecoderAttn<phi::dtype::float16>(
|
||||
q_input,
|
||||
seq_len_encoder,
|
||||
seq_len_decoder,
|
||||
cu_seq_q,
|
||||
cache_k,
|
||||
cache_v,
|
||||
block_tables,
|
||||
k_block_means,
|
||||
out,
|
||||
qk_gate_topk_idx,
|
||||
cache_k_quant_scale,
|
||||
cache_v_quant_scale,
|
||||
cache_k_dequant_scale,
|
||||
cache_v_dequant_scale,
|
||||
cache_k_zero_points,
|
||||
cache_v_zero_points,
|
||||
head_num,
|
||||
kv_head_num,
|
||||
head_dim,
|
||||
max_seq_q,
|
||||
max_seq_k,
|
||||
batch_size,
|
||||
max_input_length,
|
||||
use_moba_seq_limit,
|
||||
cache_quant_type_str);
|
||||
} else if (q_input.dtype() == paddle::DataType::BFLOAT16) {
|
||||
return DispatchMobaDecoderAttn<phi::dtype::bfloat16>(
|
||||
q_input,
|
||||
seq_len_encoder,
|
||||
seq_len_decoder,
|
||||
cu_seq_q,
|
||||
cache_k,
|
||||
cache_v,
|
||||
block_tables,
|
||||
k_block_means,
|
||||
out,
|
||||
qk_gate_topk_idx,
|
||||
cache_k_quant_scale,
|
||||
cache_v_quant_scale,
|
||||
cache_k_dequant_scale,
|
||||
cache_v_dequant_scale,
|
||||
cache_k_zero_points,
|
||||
cache_v_zero_points,
|
||||
head_num,
|
||||
kv_head_num,
|
||||
head_dim,
|
||||
max_seq_q,
|
||||
max_seq_k,
|
||||
batch_size,
|
||||
max_input_length,
|
||||
use_moba_seq_limit,
|
||||
cache_quant_type_str);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,225 @@
|
||||
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
#include "paddle/extension.h"
|
||||
#include "cute/tensor.hpp"
|
||||
#include "cute/algorithm/copy.hpp"
|
||||
#include "cute/algorithm/gemm.hpp"
|
||||
#include "../moba_attn_utils.hpp"
|
||||
|
||||
using namespace cute;
|
||||
template <typename T>
struct moba_decoder_attn_params {
    T *__restrict__ q_input;
    void *__restrict__ cache_k;
    void *__restrict__ cache_v;

    T *__restrict__ attn_out;
    T *__restrict__ partition_attn_out;
    T *__restrict__ cache_k_dequant_scale;
    T *__restrict__ cache_v_dequant_scale;
    T *__restrict__ cache_k_quant_scale;
    T *__restrict__ cache_v_quant_scale;
    T *__restrict__ cache_k_zp;
    T *__restrict__ cache_v_zp;
    int *__restrict__ cu_seq_q;
    float *sums;
    float *maxs;
    int *seq_lens_encoder;
    int *seq_lens_decoder;
    int *block_table;
    int max_input_length;
    int max_seq_len;
    int head_num;
    int kv_head_num;
    int max_num_blocks_per_seq;
    float scale_softmax;
    int batch_size;
    int max_num_partitions;
    float inv_sqrt_dh;
    int *qk_gate_topk_idx_ptr;
    int use_moba_seq_limit;
};
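// Plain POD parameter block passed to the decoder kernels by value: raw device pointers for
// the query, the paged KV cache and its quantization tables, the flash-decoding scratch
// buffers (partition_attn_out / maxs / sums), and the MoBA gate index plus its
// sequence-length cutoff (use_moba_seq_limit).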
|
||||
|
||||
template <typename cute_type_, int DataBits_>
|
||||
struct CacheKV_quant_traits {
|
||||
using cuteType = cute_type_;
|
||||
static constexpr int kDataBits = DataBits_;
|
||||
static constexpr int kBlockSize = 64;
|
||||
static constexpr int kHeadDim = 128;
|
||||
static constexpr int kBlockKSmem = 64;
|
||||
using SmemLayoutAtomQ = decltype(
|
||||
composition(Swizzle<3, 3, 3>{},
|
||||
Layout<
|
||||
Shape<Int<8>, Int<kBlockKSmem>>,
|
||||
Stride<Int<kBlockKSmem>, _1>>{}));
|
||||
|
||||
using SmemLayoutKV = decltype(tile_to_shape(
|
||||
SmemLayoutAtomQ{},
|
||||
Shape<Int<kBlockSize>, Int<kHeadDim>>{}));
|
||||
|
||||
static constexpr int kNWarps = 4;
|
||||
static constexpr int kNThreads = kNWarps * 32;
|
||||
|
||||
|
||||
static constexpr int kThreadPerValue = 16 / sizeof(cuteType);
|
||||
static constexpr int kThreadsPerRow = kHeadDim / kThreadPerValue;
|
||||
|
||||
using GmemLayoutAtom = Layout<
|
||||
Shape <Int<kNThreads / kThreadsPerRow>, Int<kThreadsPerRow>>,
|
||||
Stride<Int<kThreadsPerRow>, _1>>;
|
||||
|
||||
using GmemTiledCopyQ = decltype(
|
||||
make_tiled_copy(Copy_Atom<
|
||||
SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>, cuteType>{},
|
||||
GmemLayoutAtom{},
|
||||
Layout<Shape<_1, Int<kThreadPerValue>>>{}));
|
||||
|
||||
using MMA_Atom_Arch = std::conditional_t<
|
||||
std::is_same_v<cuteType, cutlass::half_t>,
|
||||
MMA_Atom<SM80_16x8x16_F32F16F16F32_TN>,
|
||||
MMA_Atom<SM80_16x8x16_F32BF16BF16F32_TN>
|
||||
>;
|
||||
|
||||
using ValLayoutMNK = Layout<Shape<_1,_4,_1>>;
|
||||
|
||||
using PermutationMNK = Tile<_16, Int<16 * kNWarps>, _16>;
|
||||
|
||||
using TiledMma = TiledMMA<
|
||||
MMA_Atom_Arch,
|
||||
ValLayoutMNK,
|
||||
PermutationMNK>;
|
||||
|
||||
using SmemCopyAtom = Copy_Atom<SM75_U32x4_LDSM_N, cuteType>;
|
||||
|
||||
using SmemLayoutAtomVtransposed = decltype(
|
||||
composition(Swizzle<3, 3, 3>{},
|
||||
Layout<Shape<Int<kBlockKSmem>, Int<kBlockSize>>,
|
||||
Stride<_1, Int<kBlockKSmem>>>{}));
|
||||
|
||||
using SmemLayoutVtransposed = decltype(tile_to_shape(
|
||||
SmemLayoutAtomVtransposed{},
|
||||
Shape<Int<kHeadDim>, Int<kBlockSize>>{}));
|
||||
|
||||
using SmemLayoutVtransposedNoSwizzle = decltype(get_nonswizzle_portion(SmemLayoutVtransposed{}));
|
||||
|
||||
using SmemCopyAtomTransposed = Copy_Atom<SM75_U16x8_LDSM_T, cuteType>;
|
||||
|
||||
static constexpr int kShareMemSize = size(SmemLayoutKV{}) * 2 * sizeof(cuteType);
|
||||
};
|
||||
|
||||
template <int kGqaGroupSize_, int kTileN_, int kMaxN_, typename CacheKV_traits_>
|
||||
struct moba_decoder_attn_kernel_traits {
|
||||
using ElementAccum = float;
|
||||
using CacheKV_traits = CacheKV_traits_;
|
||||
using cuteType = typename CacheKV_traits::cuteType;
|
||||
static constexpr int kDataBits = CacheKV_traits::kDataBits;
|
||||
static constexpr int kTileN = kTileN_;
|
||||
static constexpr int kMaxN = kMaxN_;
|
||||
static constexpr int kGqaGroupSize = kGqaGroupSize_;
|
||||
static constexpr int kHeadDim = CacheKV_traits::kHeadDim;
|
||||
static constexpr int kHeadDimKV = kHeadDim / (16 / kDataBits);
|
||||
static constexpr int kMinGemmM = 16;
|
||||
static constexpr int kBlockM = (kGqaGroupSize + kMinGemmM - 1) / kMinGemmM * kMinGemmM;
|
||||
static constexpr int kBlockSize = CacheKV_traits::kBlockSize;
|
||||
static_assert(kGqaGroupSize <= 16);
|
||||
static constexpr int32_t kNWarps = CacheKV_traits::kNWarps;
|
||||
|
||||
static constexpr int kBlockKSmem = CacheKV_traits::kBlockKSmem;
|
||||
static constexpr int kBlockKVSmem = kHeadDimKV <= 64 ? kHeadDimKV : 64;
|
||||
static_assert(kHeadDim % kBlockKSmem == 0);
|
||||
static constexpr int kNReduceWarps = 4;
|
||||
static constexpr int kNReduceThreads = kNReduceWarps * 32;
|
||||
|
||||
|
||||
using SmemLayoutAtomQ = typename CacheKV_traits::SmemLayoutAtomQ;
|
||||
|
||||
using SmemLayoutQ = decltype(tile_to_shape(
|
||||
SmemLayoutAtomQ{},
|
||||
Shape<Int<kBlockM>, Int<kHeadDim>>{}));
|
||||
|
||||
using SmemLayoutQK = decltype(tile_to_shape(
|
||||
SmemLayoutAtomQ{},
|
||||
Shape<Int<kBlockM>, Int<kBlockSize>>{}));
|
||||
|
||||
using SmemLayoutAtomKV = decltype(
|
||||
composition(Swizzle<3, 3, 3>{},
|
||||
Layout<
|
||||
Shape<Int<8>, Int<kBlockKVSmem>>,
|
||||
Stride<Int<kBlockKVSmem>, _1>>{}));
|
||||
|
||||
using SmemLayoutKV_ = decltype(tile_to_shape(
|
||||
SmemLayoutAtomKV{},
|
||||
Shape<Int<kBlockSize>, Int<kHeadDimKV>>{}));
|
||||
|
||||
using SmemLayoutKV = std::conditional_t<
|
||||
kDataBits == 16,
|
||||
SmemLayoutKV_,
|
||||
decltype(get_nonswizzle_portion(SmemLayoutKV_{}))
|
||||
>;
|
||||
|
||||
constexpr static int kBlockKVSize = kDataBits == 4 ? 32 : kBlockSize;
|
||||
using SmemLayoutAtomVtransposed = decltype(
|
||||
composition(Swizzle<3, 3, 3>{},
|
||||
Layout<Shape<Int<kBlockKSmem>, Int<kBlockKVSize>>,
|
||||
Stride<_1, Int<kBlockKSmem>>>{}));
|
||||
|
||||
using SmemLayoutVtransposed = decltype(tile_to_shape(
|
||||
SmemLayoutAtomVtransposed{},
|
||||
Shape<Int<kHeadDim>, Int<kBlockKVSize>>{}));
|
||||
|
||||
using SmemLayoutVtransposedNoSwizzle = decltype(get_nonswizzle_portion(SmemLayoutVtransposed{}));
|
||||
|
||||
static constexpr int kThreadsPerRow = CacheKV_traits::kThreadsPerRow;
|
||||
static constexpr int kThreadsKVPerRow = kThreadsPerRow / (16 / kDataBits);
|
||||
static constexpr int kNThreads = CacheKV_traits::kNThreads;
|
||||
|
||||
using GmemKVLayoutAtom = Layout<
|
||||
Shape<Int<kNThreads / kThreadsKVPerRow>, Int<kThreadsKVPerRow>>,
|
||||
Stride<Int<kThreadsKVPerRow>, _1>>;
|
||||
|
||||
using SmemCopyAtom = typename CacheKV_traits::SmemCopyAtom;
|
||||
using TiledMma = typename CacheKV_traits::TiledMma;
|
||||
|
||||
static constexpr int kThreadPerValue = CacheKV_traits::kThreadPerValue;
|
||||
|
||||
using GmemTiledCopyQ = typename CacheKV_traits::GmemTiledCopyQ;
|
||||
using GmemLayoutAtom = typename CacheKV_traits::GmemLayoutAtom;
|
||||
using GmemTiledCopyKV = decltype(
|
||||
make_tiled_copy(Copy_Atom<SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>, cuteType>{},
|
||||
GmemKVLayoutAtom{},
|
||||
Layout<Shape<_1, Int<kThreadPerValue>>>{}));
|
||||
|
||||
|
||||
using SmemCopyAtomTransposed = typename CacheKV_traits::SmemCopyAtomTransposed;
|
||||
|
||||
using GmemTiledCopyO = decltype(
|
||||
make_tiled_copy(Copy_Atom<DefaultCopy, cuteType>{},
|
||||
GmemLayoutAtom{},
|
||||
Layout<Shape<_1, Int<kThreadPerValue>>>{}));
|
||||
using SmemCopyAtomO = Copy_Atom<DefaultCopy, cuteType>;
|
||||
|
||||
using SmemLayoutAtomO = decltype(
|
||||
composition(Swizzle<3, 3, 3>{},
|
||||
Layout<
|
||||
Shape<Int<8>, Int<kBlockKSmem>>,
|
||||
Stride<Int<kBlockKSmem>, _1>>{}));
|
||||
|
||||
    using SmemLayoutO = decltype(tile_to_shape(
        SmemLayoutAtomO{},
        Shape<Int<kBlockM>, Int<kHeadDim>>{}));

    static constexpr int kShareMemSize = (size(SmemLayoutQ{}) + size(SmemLayoutQK{}) + size(SmemLayoutKV{}) * 2) * sizeof(cuteType);
};
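// kBlockM rounds the GQA group size up to the 16-row MMA tile (kMinGemmM), so one decoder
// query block covers every query head that shares a KV head; kShareMemSize accounts for the
// Q, QK and K/V tiles the kernel stages in shared memory.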
|
||||
@@ -0,0 +1,189 @@
|
||||
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "paddle/extension.h"
|
||||
#include "../moba_attn_utils.hpp"
|
||||
#include "moba_attn/moba_attn.h"
|
||||
|
||||
template <typename T, int kBlockSize, int kHeadDim, int moba_block_size, int kMaxN>
|
||||
__global__ void moba_decoder_attn_write_c16(
|
||||
const T * qkv_out,
|
||||
const T * qkv_bias,
|
||||
T * q_input,
|
||||
const int * cu_seq_q,
|
||||
const int * cu_seq_k,
|
||||
const int * seq_len_encoder,
|
||||
const int * seq_len_decoder,
|
||||
T * cache_k,
|
||||
T * cache_v,
|
||||
const int * block_tables,
|
||||
const float * rope_sin_cos,
|
||||
T *k_block_means,
|
||||
const int head_num,
|
||||
const int kv_head_num,
|
||||
const int max_blocks_per_seq,
|
||||
const int max_input_length) {
|
||||
|
||||
int bidh = blockIdx.x;
|
||||
const int bidb = blockIdx.y;
|
||||
const int tidx = threadIdx.x;
|
||||
const int seq_len = seq_len_decoder[bidb];
|
||||
|
||||
if (seq_len == 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
constexpr int kPackSize = 4;
|
||||
using SrcType = Vec<T, kPackSize>;
|
||||
using rope_type = Vec<float, kPackSize / 2>;
|
||||
SrcType src, bias, k_prev;
|
||||
rope_type sin, cos;
|
||||
const int bias_idx = bidh * kHeadDim + tidx * kPackSize;
|
||||
const int ori_token_idx = cu_seq_q[bidb];
|
||||
src.load_from(qkv_out + ori_token_idx * (head_num + 2 * kv_head_num) * kHeadDim + bias_idx);
|
||||
if (qkv_bias != nullptr) {
|
||||
bias.load_from(qkv_bias + bias_idx);
|
||||
src.add(bias);
|
||||
}
|
||||
|
||||
const int32_t *block_table_now = block_tables + bidb * max_blocks_per_seq;
|
||||
const int32_t physical_block_number = block_table_now[seq_len / kBlockSize];
|
||||
|
||||
|
||||
if (bidh < head_num) {
|
||||
const float * cos_rope = rope_sin_cos + seq_len * (kHeadDim / 2) + tidx * (kPackSize / 2);
|
||||
const float * sin_rope = cos_rope + max_input_length * (kHeadDim / 2);
|
||||
sin.load_from(sin_rope);
|
||||
cos.load_from(cos_rope);
|
||||
apply_rotary_embedding<T, kPackSize>(src, cos, sin);
|
||||
|
||||
src.store_to(q_input + cu_seq_q[bidb] * head_num * kHeadDim + bias_idx);
|
||||
} else if (bidh < head_num + kv_head_num) {
|
||||
bidh -= head_num;
|
||||
const int token_in_blocks = seq_len % kBlockSize;
|
||||
const float * cos_rope = rope_sin_cos + seq_len * (kHeadDim / 2) + tidx * (kPackSize / 2);
|
||||
const float * sin_rope = cos_rope + max_input_length * (kHeadDim / 2);
|
||||
sin.load_from(sin_rope);
|
||||
cos.load_from(cos_rope);
|
||||
apply_rotary_embedding<T, kPackSize>(src, cos, sin);
|
||||
|
||||
T * cache = cache_k + physical_block_number * kv_head_num * kBlockSize * kHeadDim + bidh * kBlockSize * kHeadDim + tidx * kPackSize + token_in_blocks * kHeadDim;
|
||||
src.store_to(cache);
|
||||
|
||||
const int seq_len_block = seq_len / moba_block_size;
|
||||
|
||||
const int store_mean_idx = (bidb * kMaxN + seq_len_block) * kv_head_num * kHeadDim + bidh * kHeadDim + tidx * kPackSize;
|
||||
|
||||
        if (seq_len % moba_block_size != 0) {
            const int token_num_prev = seq_len % moba_block_size;
            const float inv_tokens_sum = fdividef(1.0f, token_num_prev + 1);
            k_prev.load_from(k_block_means + store_mean_idx);

            #pragma unroll
            for (int i = 0; i < kPackSize; i++) {
                src.data.elt[i] = T(inv_tokens_sum * (float(src.data.elt[i]) + float(k_prev.data.elt[i]) * token_num_prev));
            }
        }

        src.store_to(k_block_means + store_mean_idx);
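        // k_block_means keeps a running mean of the K vectors in the current MoBA block:
        //   mean_new = (mean_prev * t + k_new) / (t + 1), with t = seq_len % moba_block_size,
        // so the decoder gate can update the block centroid one token at a time.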
|
||||
|
||||
} else {
|
||||
bidh -= head_num + kv_head_num;
|
||||
const int token_in_blocks = seq_len % kBlockSize;
|
||||
T * cache = cache_v + physical_block_number * kv_head_num * kBlockSize * kHeadDim + bidh * kBlockSize * kHeadDim + tidx * kPackSize + token_in_blocks * kHeadDim;
|
||||
src.store_to(cache);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
void MobaDecoderAttnWriteCacheKv(
|
||||
const paddle::Tensor& qkv_out,
|
||||
const paddle::Tensor& q_input,
|
||||
const paddle::Tensor& cu_seq_q,
|
||||
const paddle::Tensor& cu_seq_k,
|
||||
const paddle::Tensor& seq_len_encoder,
|
||||
const paddle::Tensor& seq_len_decoder,
|
||||
const paddle::Tensor& cache_k,
|
||||
const paddle::Tensor& cache_v,
|
||||
const paddle::Tensor& block_tables,
|
||||
const paddle::Tensor& rope_sin_cos,
|
||||
const paddle::Tensor& k_block_means,
|
||||
const paddle::optional<paddle::Tensor>& qkv_bias,
|
||||
const paddle::optional<paddle::Tensor>& cache_k_quant_scale,
|
||||
const paddle::optional<paddle::Tensor>& cache_v_quant_scale,
|
||||
const paddle::optional<paddle::Tensor>& cache_k_dequant_scale,
|
||||
const paddle::optional<paddle::Tensor>& cache_v_dequant_scale,
|
||||
const paddle::optional<paddle::Tensor>& cache_k_zero_points,
|
||||
const paddle::optional<paddle::Tensor>& cache_v_zero_points,
|
||||
const int head_num,
|
||||
const int kv_head_num,
|
||||
const int head_dim,
|
||||
const int max_input_length,
|
||||
const std::string &cache_quant_type_str) {
|
||||
|
||||
constexpr int kThreads = 32;
|
||||
constexpr int kHeadDim = 128;
|
||||
constexpr int kMobaBlockSize = 128;
|
||||
constexpr int kMaxN = 1024;
|
||||
assert(kHeadDim == head_dim);
|
||||
constexpr int kBlockSize = 64;
|
||||
const int max_blocks_per_seq = block_tables.dims()[1];
|
||||
const int batch_size = block_tables.dims()[0];
|
||||
if (cache_quant_type_str == "none") {
|
||||
dim3 grid_dims;
|
||||
grid_dims.x = head_num + kv_head_num * 2;
|
||||
grid_dims.y = batch_size;
|
||||
if (qkv_out.dtype() == paddle::DataType::FLOAT16) {
|
||||
using T = phi::dtype::float16;
|
||||
moba_decoder_attn_write_c16<T, kBlockSize, kHeadDim, kMobaBlockSize, kMaxN><<<grid_dims, kThreads, 0, qkv_out.stream()>>>(
|
||||
qkv_out.data<T>(),
|
||||
qkv_bias ? qkv_bias.get().data<T>() : nullptr,
|
||||
const_cast<T*>(q_input.data<T>()),
|
||||
cu_seq_q.data<int>(),
|
||||
cu_seq_k.data<int>(),
|
||||
seq_len_encoder.data<int>(),
|
||||
seq_len_decoder.data<int>(),
|
||||
const_cast<T *>(cache_k.data<T>()),
|
||||
const_cast<T *>(cache_v.data<T>()),
|
||||
block_tables.data<int>(),
|
||||
rope_sin_cos.data<float>(),
|
||||
const_cast<T*>(k_block_means.data<T>()),
|
||||
head_num,
|
||||
kv_head_num,
|
||||
max_blocks_per_seq,
|
||||
max_input_length);
|
||||
} else if (qkv_out.dtype() == paddle::DataType::BFLOAT16) {
|
||||
using T = phi::dtype::bfloat16;
|
||||
moba_decoder_attn_write_c16<T, kBlockSize, kHeadDim, kMobaBlockSize, kMaxN><<<grid_dims, kThreads, 0, qkv_out.stream()>>>(
|
||||
qkv_out.data<T>(),
|
||||
qkv_bias ? qkv_bias.get().data<T>() : nullptr,
|
||||
const_cast<T*>(q_input.data<T>()),
|
||||
cu_seq_q.data<int>(),
|
||||
cu_seq_k.data<int>(),
|
||||
seq_len_encoder.data<int>(),
|
||||
seq_len_decoder.data<int>(),
|
||||
const_cast<T *>(cache_k.data<T>()),
|
||||
const_cast<T *>(cache_v.data<T>()),
|
||||
block_tables.data<int>(),
|
||||
rope_sin_cos.data<float>(),
|
||||
const_cast<T*>(k_block_means.data<T>()),
|
||||
head_num,
|
||||
kv_head_num,
|
||||
max_blocks_per_seq,
|
||||
max_input_length);
|
||||
}
|
||||
} else {
|
||||
PD_THROW("Only supported cache_quant_type_str in ['none'].");
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,236 @@
|
||||
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "paddle/extension.h"
|
||||
#include "moba_attn/moba_attn_utils.hpp"
|
||||
#include "moba_attn/moba_attn.h"
|
||||
|
||||
|
||||
template <typename T, int knthreads, int moba_block_size, int kBlockMaxN, int searchtimes>
|
||||
__global__ void qk_gate_sort_decoder_kernel(
|
||||
const T* qk_gate_weight,
|
||||
int * qk_gate_topk_idx,
|
||||
const int *decoder_seq_lens,
|
||||
const int head_num,
|
||||
const int kv_head_num,
|
||||
const int kGqaGroupSize,
|
||||
const int top_k_left,
|
||||
const int top_k_right,
|
||||
const int use_moba_seq_limit) {
|
||||
|
||||
const int bidb = blockIdx.x;
|
||||
const int bidh = blockIdx.y;
|
||||
const int tidx = threadIdx.x;
|
||||
const int bidh_kv = bidh / kGqaGroupSize;
|
||||
|
||||
if (decoder_seq_lens[bidb] == 0 || decoder_seq_lens[bidb] < use_moba_seq_limit) {
|
||||
return;
|
||||
}
|
||||
const int seq_len = (decoder_seq_lens[bidb] + moba_block_size - 1) / moba_block_size;
|
||||
|
||||
constexpr int kPackSize = kBlockMaxN / knthreads;
|
||||
|
||||
static_assert(kBlockMaxN % knthreads == 0);
|
||||
|
||||
T token_mean[kPackSize];
|
||||
|
||||
using SrcType = Vec<T, kPackSize>;
|
||||
using SrcType_f = Vec<float, kPackSize>;
|
||||
using SrcType_i = Vec<int, kPackSize>;
|
||||
|
||||
SrcType src;
|
||||
SrcType_f src_f;
|
||||
SrcType_i select_idx;
|
||||
|
||||
select_idx.set_zero();
|
||||
|
||||
const int load_offset = bidb * head_num * kBlockMaxN + bidh * kBlockMaxN + tidx * kPackSize;
|
||||
|
||||
src.load_from(qk_gate_weight + load_offset);
|
||||
|
||||
float max_global = -FLT_MAX;
|
||||
float min_global = FLT_MAX;
|
||||
|
||||
const int data_len = seq_len - tidx * kPackSize;
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < kPackSize; i++) {
|
||||
if (i < data_len) {
|
||||
src_f.data.elt[i] = float(src.data.elt[i]);
|
||||
min_global = min(min_global, src_f.data.elt[i]);
|
||||
} else {
|
||||
src_f.data.elt[i] = -FLT_MAX;
|
||||
}
|
||||
max_global = max(max_global, src_f.data.elt[i]);
|
||||
}
|
||||
|
||||
|
||||
max_global = BlockAllReduce<float, MaxOp<float>, knthreads>(max_global);
|
||||
min_global = BlockAllReduce<float, MinOp<float>, knthreads>(min_global);
|
||||
|
||||
|
||||
float right_limit = max_global;
|
||||
float left_limit = min_global;
|
||||
|
||||
float mid_limit;
|
||||
int count;
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < searchtimes; i++) {
|
||||
mid_limit = (left_limit + right_limit) * 0.5f;
|
||||
count = get_data_count<kPackSize, knthreads>(src_f.data.elt, mid_limit);
|
||||
if (count < top_k_left) {
|
||||
right_limit = mid_limit;
|
||||
} else if (count > top_k_right) {
|
||||
left_limit = mid_limit;
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
const int store_idx = bidb * kv_head_num * kBlockMaxN + bidh_kv * kBlockMaxN + tidx * kPackSize;
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < kPackSize; i++) {
|
||||
if (src_f.data.elt[i] >= mid_limit) {
|
||||
qk_gate_topk_idx[store_idx + i] = 1;
|
||||
}
|
||||
}

    if (tidx == 0) {
        qk_gate_topk_idx[store_idx] = 1;
        qk_gate_topk_idx[store_idx + seq_len - 1] = 1;
        qk_gate_topk_idx[store_idx + seq_len - 2] = 1;
    }
}
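// Instead of a full sort, the decoder gate bisects the score threshold for a fixed number of
// steps until the count of selected blocks lands in [top_k_left, top_k_right]; blocks scoring
// at or above the final threshold get qk_gate_topk_idx = 1, and the first block plus the last
// two blocks are always kept. For example, with block scores {0.9, 0.1, 0.8, 0.3} and
// top_k_left = top_k_right = 2, the first midpoint 0.5 already selects blocks 0 and 2
// (an approximate top-k, not an exact one).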
|
||||
|
||||
template <int kBlockMaxN, int moba_block_size, typename T>
|
||||
void qk_gate_sort_decoder(
|
||||
const T* qk_gate_weight,
|
||||
int * qk_gate_topk_idx,
|
||||
const int *decoder_seq_lens,
|
||||
const int head_num,
|
||||
const int kv_head_num,
|
||||
const int batch_size,
|
||||
const int top_k_left,
|
||||
const int top_k_right,
|
||||
const int use_moba_seq_limit,
|
||||
cudaStream_t stream) {
|
||||
|
||||
const int gqa_group_size = head_num / kv_head_num;
|
||||
constexpr int kPackSize = 16 / sizeof(T);
|
||||
const int knthreads = kBlockMaxN / kPackSize;
|
||||
dim3 grid_dims;
|
||||
grid_dims.x = batch_size;
|
||||
grid_dims.y = head_num;
|
||||
const int searchtimes = 6;
|
||||
|
||||
constexpr auto kernel = qk_gate_sort_decoder_kernel<T, knthreads, moba_block_size, kBlockMaxN, searchtimes>;
|
||||
|
||||
kernel<<<grid_dims, knthreads, 0, stream>>>(
|
||||
qk_gate_weight,
|
||||
qk_gate_topk_idx,
|
||||
decoder_seq_lens,
|
||||
head_num,
|
||||
kv_head_num,
|
||||
gqa_group_size,
|
||||
top_k_left,
|
||||
top_k_right,
|
||||
use_moba_seq_limit);
|
||||
}
|
||||
|
||||
|
||||
template <typename T>
|
||||
std::vector<paddle::Tensor> DispatchQkSortDecoder(
|
||||
const paddle::Tensor& qk_gate_weight,
|
||||
const paddle::Tensor& seq_len_encoder,
|
||||
const paddle::Tensor& seq_len_decoder,
|
||||
const int head_num,
|
||||
const int kv_head_num,
|
||||
const int top_k_left,
|
||||
const int top_k_right,
|
||||
const int use_moba_seq_limit) {
|
||||
|
||||
constexpr int kMobaBlockSize = 128;
|
||||
constexpr int kMaxN = 1024;
|
||||
|
||||
const int batch_size = seq_len_decoder.dims()[0];
|
||||
paddle::Tensor qk_gate_topk_idx = paddle::empty({batch_size, kv_head_num, kMaxN}, paddle::DataType::INT32, qk_gate_weight.place());
|
||||
|
||||
qk_gate_sort_decoder<kMaxN, kMobaBlockSize, T>(
|
||||
qk_gate_weight.data<T>(),
|
||||
qk_gate_topk_idx.data<int>(),
|
||||
seq_len_decoder.data<int>(),
|
||||
head_num,
|
||||
kv_head_num,
|
||||
batch_size,
|
||||
top_k_left,
|
||||
top_k_right,
|
||||
use_moba_seq_limit,
|
||||
qk_gate_weight.stream()
|
||||
);
|
||||
|
||||
return {qk_gate_topk_idx};
|
||||
}
|
||||
|
||||
std::vector<paddle::Tensor> QkSortDecoder(
|
||||
const paddle::Tensor& qk_gate_weight,
|
||||
const paddle::Tensor& seq_len_encoder,
|
||||
const paddle::Tensor& seq_len_decoder,
|
||||
const int head_num,
|
||||
const int kv_head_num,
|
||||
const int top_k_left,
|
||||
const int top_k_right,
|
||||
const int use_moba_seq_limit) {
|
||||
|
||||
if (qk_gate_weight.dtype() == paddle::DataType::FLOAT16) {
|
||||
return std::move(
|
||||
DispatchQkSortDecoder<phi::dtype::float16>(
|
||||
qk_gate_weight,
|
||||
seq_len_encoder,
|
||||
seq_len_decoder,
|
||||
head_num,
|
||||
kv_head_num,
|
||||
top_k_left,
|
||||
top_k_right,
|
||||
use_moba_seq_limit)
|
||||
);
|
||||
} else if (qk_gate_weight.dtype() == paddle::DataType::BFLOAT16) {
|
||||
return std::move(
|
||||
DispatchQkSortDecoder<phi::dtype::bfloat16>(
|
||||
qk_gate_weight,
|
||||
seq_len_encoder,
|
||||
seq_len_decoder,
|
||||
head_num,
|
||||
kv_head_num,
|
||||
top_k_left,
|
||||
top_k_right,
|
||||
use_moba_seq_limit)
|
||||
);
|
||||
    } else {
        PD_THROW("QkSortDecoder only supports float16 and bfloat16 for qk_gate_weight.");
    }
}
|
||||
|
||||
PD_BUILD_OP(moba_qk_sort_decoder)
    .Inputs({
        "qk_gate_weight",
        "seq_len_encoder",
        "seq_len_decoder"})
    .Attrs({
        "head_num: int",
        "kv_head_num: int",
        "top_k_left: int",
        "top_k_right: int",
        "use_moba_seq_limit: int"})
    .Outputs({"qk_gate_topk_idx"})
    .SetKernelFn(PD_KERNEL(QkSortDecoder));
|
||||
143
custom_ops/gpu_ops/moba_attn/moba_encoder_attn/kernel_traits.h
Normal file
@@ -0,0 +1,143 @@
|
||||
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
/******************************************************************************
|
||||
* Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
|
||||
******************************************************************************/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cute/algorithm/copy.hpp"
|
||||
#include "cute/atom/mma_atom.hpp"
|
||||
#include "cutlass/gemm/collective/collective_builder.hpp"
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/layout/layout.h"
|
||||
#include "cutlass/numeric_types.h"
|
||||
#include "cutlass/pipeline/pipeline.hpp"
|
||||
|
||||
using namespace cute;
|
||||
|
||||
struct moba_encoder_attn_params {
|
||||
void *__restrict__ q_ptr;
|
||||
void *__restrict__ k_ptr;
|
||||
void *__restrict__ v_ptr;
|
||||
void * __restrict__ o_ptr;
|
||||
int * __restrict__ cu_seq_q;
|
||||
int * __restrict__ cu_seq_k;
|
||||
int * __restrict__ qk_gate_topk_idx;
|
||||
int * __restrict__ seq_len_encoder;
|
||||
int * __restrict__ cu_seq_q_pack;
|
||||
int head_num;
|
||||
int kv_head_num;
|
||||
int max_seq_q;
|
||||
int max_seq_k;
|
||||
int batch_size;
|
||||
int gqa_group_size;
|
||||
float scale_softmax_log2;
|
||||
int use_moba_seq_limit;
|
||||
};
|
||||
|
||||
template <int kStages, class Gemm1Type, class Gemm2Type, class OutputType, class SmemLayoutQ,
|
||||
class SmemLayoutK, class SmemLayoutV, class SmemLayoutO>
struct SharedStorageQKVO {
    cute::array_aligned<Gemm1Type, cute::cosize_v<SmemLayoutQ>> smem_q;
    cute::array_aligned<Gemm1Type, cute::cosize_v<SmemLayoutK>> smem_k;
    union {
        cute::array_aligned<Gemm2Type, cute::cosize_v<SmemLayoutV>> smem_v;
        cute::array_aligned<OutputType, cute::cosize_v<SmemLayoutO>> smem_o;
    };
    struct {
        cutlass::arch::ClusterTransactionBarrier barrier_Q;
        typename cutlass::PipelineTmaAsync<kStages>::SharedStorage pipeline_k;
        typename cutlass::PipelineTmaAsync<kStages>::SharedStorage pipeline_v;
    };
};

template<int kHeadDim_, int kBlockM_, int kBlockN_, int kNWarps_, int kStages_, int kMaxN_, bool UseMoba_, typename elem_type=cutlass::half_t>
struct moba_encoder_attn_kernel_traits {
    using Element = elem_type;
    using ElementAccum = float;
    using index_t = int32_t;

    static constexpr int kNWarps = kNWarps_;
    static constexpr int kNThreads = kNWarps * cutlass::NumThreadsPerWarp;

    static constexpr int UseMoba = UseMoba_;

    static constexpr int kBlockM = kBlockM_;
    static constexpr int kBlockN = kBlockN_;
    static constexpr int kHeadDim = kHeadDim_;
    static constexpr int kMaxN = kMaxN_;
    static_assert(kHeadDim % 32 == 0);
    using TileShape_MNK = Shape<Int<kBlockM>, Int<kBlockN>, Int<kHeadDim>>;
    using ClusterShape_MNK = Shape<Int<1>, Int<1>, Int<1>>;
    static constexpr int kStages = kStages_;

    using AtomLayoutMNK = Layout<Shape<Int<kBlockM / 64>, _1, _1>>;
    using TiledMma0 = decltype(cute::make_tiled_mma(
        cute::GMMA::ss_op_selector<Element, Element, ElementAccum, TileShape_MNK>(),
        AtomLayoutMNK{}));
    using TiledMma1 = decltype(cute::make_tiled_mma(
        cute::GMMA::rs_op_selector<Element, Element, ElementAccum, decltype(select<0, 2, 1>(TileShape_MNK{})),
                                   GMMA::Major::K, GMMA::Major::MN>(),
        AtomLayoutMNK{}));

    using SmemLayoutAtomQ = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
        decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
    using SmemLayoutQ = decltype(tile_to_shape(SmemLayoutAtomQ{}, select<0, 2>(TileShape_MNK{})));

    using SmemLayoutAtomK = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
        decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
    using SmemLayoutK =
        decltype(tile_to_shape(SmemLayoutAtomK{},
                 make_shape(shape<1>(TileShape_MNK{}), shape<2>(TileShape_MNK{}), Int<kStages>{})));

    using SmemLayoutAtomV = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
        decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
    using SmemLayoutV =
        decltype(tile_to_shape(SmemLayoutAtomV{},
                 make_shape(shape<1>(TileShape_MNK{}), shape<2>(TileShape_MNK{}), Int<kStages>{})));

    using SmemLayoutAtomO = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
        decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
    using SmemLayoutO = decltype(tile_to_shape(SmemLayoutAtomO{}, select<0, 2>(TileShape_MNK{})));

    using SmemCopyAtomQ = Copy_Atom<cute::SM75_U32x4_LDSM_N, Element>;
    using SmemCopyAtomO = Copy_Atom<cute::SM90_U32x4_STSM_N, Element>;

    using SharedStorage = SharedStorageQKVO<kStages, Element, Element, Element, SmemLayoutQ, SmemLayoutK, SmemLayoutV, SmemLayoutO>;

    static constexpr int NumProducerThreads = cutlass::NumThreadsPerWarpGroup;
    static constexpr int NumMmaThreads = kNThreads - NumProducerThreads;
    static constexpr int kNumVecElem = ceil_div(128, sizeof_bits_v<Element>);
    static constexpr int kNumThreadsPerRow = kHeadDim / kNumVecElem;
    static_assert(NumMmaThreads % kNumThreadsPerRow == 0);
    static constexpr int kNumRows = NumMmaThreads / kNumThreadsPerRow;
    using TiledCopyOAtom = cute::Copy_Atom<cute::UniversalCopy<cutlass::uint128_t>, Element>;
    using TiledCopyOThrLayout = decltype(cute::make_layout(
        cute::make_shape(Int<kNumRows>{}, Int<kNumThreadsPerRow>{}),
        LayoutRight{}));
    using TiledCopyOValLayout = decltype(cute::make_layout(
        cute::make_shape(_1{}, Int<kNumVecElem>{}),
        LayoutRight{}));
    using GmemTiledCopyO = decltype(make_tiled_copy(
        TiledCopyOAtom{},
        TiledCopyOThrLayout{}, // Thr layout
        TiledCopyOValLayout{}  // Val layout
    ));

    using MainloopPipeline = typename cutlass::PipelineTmaAsync<kStages>;
    using PipelineState = typename cutlass::PipelineState<kStages>;
};
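As an illustrative aside (not part of the patch): for the configuration instantiated later in this diff (kBlockM = kBlockN = 128, head_dim = 128, kStages = 2, 16-bit elements), the traits above work out roughly as in the sketch below. The numbers are back-of-the-envelope estimates, not values asserted by the source.

#include <cstddef>

// Assumed: kBlockM = 128, so kNWarps = 128 / 16 + 4 = 12 (see run_moba_encoder_attn_hdim128 below).
constexpr int kNWarpsExample          = 12;
constexpr int kNThreadsExample        = kNWarpsExample * 32;   // 384 threads per CTA
constexpr int kProducerThreadsExample = 128;                   // one TMA (producer) warp group
constexpr int kMmaThreadsExample      = kNThreadsExample - kProducerThreadsExample;  // 256, two MMA warp groups

// Shared-memory footprint with 2-byte elements (estimate only):
constexpr size_t kSmemQ  = 128 * 128 * 2;       // 32 KiB for Q
constexpr size_t kSmemK  = 2 * 128 * 128 * 2;   // 64 KiB for K (kStages = 2)
constexpr size_t kSmemVO = 2 * 128 * 128 * 2;   // 64 KiB for V, shared with O through the union
// Roughly 160 KiB plus barriers, which is why the launcher below raises
// cudaFuncAttributeMaxDynamicSharedMemorySize before launching on Hopper.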
473
custom_ops/gpu_ops/moba_attn/moba_encoder_attn/mainloop_attn.hpp
Normal file
@@ -0,0 +1,473 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

/******************************************************************************
 * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
 ******************************************************************************/

#pragma once

#include <cutlass/cutlass.h>
#include <cutlass/array.h>
#include <cutlass/numeric_types.h>
#include <cutlass/numeric_conversion.h>
#include "cutlass/pipeline/pipeline.hpp"

#include "cute/tensor.hpp"

#include "cutlass/gemm/collective/collective_builder.hpp"

using namespace cute;

enum class AttnNamedBarriers {
    QueryEmpty = 0,
    ValueEmpty = 1,
    TileCountSmemEmpty = 2,
    TileCountSmemFull = 3,
    WarpSchedulerWG1 = 4,
    WarpSchedulerWG2 = 5,
    WarpSchedulerWG3 = 6,
};

template <typename Ktraits>
struct CollectiveMainloopAttn {

    using Element = typename Ktraits::Element;
    using TileShape_MNK = typename Ktraits::TileShape_MNK;
    using ClusterShape = typename Ktraits::ClusterShape_MNK;

    static constexpr int kStages = Ktraits::kStages;
    static constexpr int kHeadDim = Ktraits::kHeadDim;
    static constexpr int kBlockM = Ktraits::kBlockM;
    static constexpr int kBlockN = Ktraits::kBlockN;

    using ShapeT = cute::Shape<int32_t, int32_t, int32_t>;
    using StrideT = cute::Shape<int32_t, _1, int32_t>;
    using LayoutT = cute::Layout<ShapeT, StrideT>;

    using GmemTiledCopyQ = cute::SM90_TMA_LOAD;
    using GmemTiledCopyKV = decltype(cutlass::gemm::collective::detail::sm90_cluster_shape_to_tma_atom(shape<0>(ClusterShape{})));
    using GmemTiledCopyO = typename Ktraits::GmemTiledCopyO;

    using SmemLayoutAtomQ = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
        decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
    using SmemLayoutQ = decltype(tile_to_shape(SmemLayoutAtomQ{}, select<0, 2>(TileShape_MNK{})));

    using SmemLayoutAtomK = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
        decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
    using SmemLayoutK =
        decltype(tile_to_shape(SmemLayoutAtomK{},
                 make_shape(shape<1>(TileShape_MNK{}), shape<2>(TileShape_MNK{}), Int<kStages>{})));
    using SmemLayoutV = SmemLayoutK;
    // Note this is the transpose in terms of the view, not in terms of memory.
    using SmemLayoutVt =
        decltype(cute::composition(SmemLayoutV{},
                 make_layout(make_shape(get<2>(TileShape_MNK{}), get<1>(TileShape_MNK{}), Int<kStages>{}),
                             make_stride(get<1>(TileShape_MNK{}), _1{}, Int<size(SmemLayoutV{}(_, _, _0{}))>{}))));
    using SmemLayoutO = typename Ktraits::SmemLayoutO;
    using SmemCopyAtomO = typename Ktraits::SmemCopyAtomO;

    using TMA_Q = decltype(make_tma_copy(
        GmemTiledCopyQ{},
        make_tensor(
            make_gmem_ptr(static_cast<Element const*>(nullptr)),
            repeat_like(StrideT{}, int32_t(0)),
            StrideT{}
        ),
        SmemLayoutQ{},
        select<0, 2>(TileShape_MNK{}),
        _1{})); // no mcast for Q

    using TMA_KV = decltype(make_tma_copy(
        GmemTiledCopyKV{},
        make_tensor(
            make_gmem_ptr(static_cast<Element const*>(nullptr)),
            repeat_like(StrideT{}, int32_t(0)),
            StrideT{}
        ),
        take<0, 2>(SmemLayoutK{}),
        select<1, 2>(TileShape_MNK{}),
        size<0>(ClusterShape{}))); // mcast along M mode for this N load, if any

    static constexpr int NumMmaThreads = size(typename Ktraits::TiledMma0{});
    using MainloopPipeline = typename Ktraits::MainloopPipeline;
    using PipelineParams = typename MainloopPipeline::Params;
    using PipelineState = typename MainloopPipeline::PipelineState;

    // Set the bytes transferred in this TMA transaction (may involve multiple issues)
    static constexpr uint32_t TmaTransactionBytesQ = static_cast<uint32_t>(size(SmemLayoutQ{}) * cutlass::sizeof_bits_v<Element> / 8);
    static constexpr uint32_t TmaTransactionBytesK = static_cast<uint32_t>(size(take<0, 2>(SmemLayoutK{})) * cutlass::sizeof_bits_v<Element> / 8);

    static constexpr bool UseSchedulerBarrier = kHeadDim <= 128;

    // Host side kernel arguments
    struct Arguments {
        Element const* ptr_Q;
        LayoutT layout_Q;
        Element const* ptr_K;
        LayoutT layout_K;
        Element const* ptr_V;
        LayoutT layout_V;
        float const softmax_scale_log2;
    };

    // Device side kernel params
    struct Params {
        LayoutT layout_Q;
        LayoutT layout_K;
        LayoutT layout_V;
        cutlass::FastDivmod qhead_per_khead_divmod;
        TMA_Q tma_load_Q;
        TMA_KV tma_load_K, tma_load_V;
        float const softmax_scale_log2;
    };

    static Params
    to_underlying_arguments(Arguments const& args) {
        Tensor mQ = make_tensor(make_gmem_ptr(args.ptr_Q), args.layout_Q);
        TMA_Q tma_load_Q = make_tma_copy(
            GmemTiledCopyQ{},
            mQ,
            SmemLayoutQ{},
            select<0, 2>(TileShape_MNK{}),
            _1{}); // no mcast for Q
        Tensor mK = make_tensor(make_gmem_ptr(args.ptr_K), args.layout_K);
        TMA_KV tma_load_K = make_tma_copy(
            GmemTiledCopyKV{},
            mK,
            SmemLayoutK{}(_, _, _0{}),
            select<1, 2>(TileShape_MNK{}),
            size<0>(ClusterShape{})); // mcast along M mode for this N load, if any
        Tensor mV = make_tensor(make_gmem_ptr(args.ptr_V), args.layout_V);
        TMA_KV tma_load_V = make_tma_copy(
            GmemTiledCopyKV{},
            mV,
            SmemLayoutV{}(_, _, _0{}),
            select<1, 2>(TileShape_MNK{}),
            size<0>(ClusterShape{})); // mcast along M mode for this N load, if any
        return {args.layout_Q, args.layout_K, args.layout_V,
                cutlass::FastDivmod(cute::ceil_div(get<2>(args.layout_Q.shape()), get<2>(args.layout_K.shape()))),
                tma_load_Q, tma_load_K, tma_load_V,
                args.softmax_scale_log2};
    }

    /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance
    CUTLASS_DEVICE
    static void prefetch_tma_descriptors(Params const& mainloop_params) {
        cute::prefetch_tma_descriptor(mainloop_params.tma_load_Q.get_tma_descriptor());
        cute::prefetch_tma_descriptor(mainloop_params.tma_load_K.get_tma_descriptor());
        cute::prefetch_tma_descriptor(mainloop_params.tma_load_V.get_tma_descriptor());
    }

    template <typename MTensor, typename Shape>
    CUTLASS_DEVICE auto get_local_tile_tensor(
            const MTensor &m_tensor,
            const Shape &tile_shape,
            const int *cu_seq_len,
            const int bidh,
            const int bidb,
            const int actual_seq_len) const {
        auto g_offset = local_tile(
            m_tensor(_, _, bidh),
            cute::make_shape(1, get<1>(tile_shape)),
            make_coord(cu_seq_len[bidb], _0{}));
        auto g_sequence = make_tensor(
            g_offset.data(),
            make_layout(
                cute::make_shape(actual_seq_len, get<1>(tile_shape)),
                g_offset.stride()
            ));
        auto g_tensor = local_tile(g_sequence, tile_shape, make_coord(_, _0{}));
        return g_tensor;
    }

    template <bool UseMoba, typename SharedStorage>
    CUTLASS_DEVICE void
    load(Params const& mainloop_params,
         MainloopPipeline pipeline_k,
         MainloopPipeline pipeline_v,
         PipelineState& smem_pipe_write_k,
         PipelineState& smem_pipe_write_v,
         SharedStorage &shared_storage,
         const int *qk_gate_topk_idx,
         const int n_block_max,
         const int m_block,
         const int bidh,
         const int bidb,
         const int *cu_seq_q,
         const int *cu_seq_k,
         const int seq_len_q,
         const int seq_len_k) {

        Tensor sQ = make_tensor(make_smem_ptr(shared_storage.smem_q.data()), SmemLayoutQ{});
        Tensor sK = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), SmemLayoutK{});
        Tensor sV = make_tensor(make_smem_ptr(shared_storage.smem_v.data()), SmemLayoutV{});

        Tensor mQ = mainloop_params.tma_load_Q.get_tma_tensor(mainloop_params.layout_Q.shape());
        Tensor mK = mainloop_params.tma_load_K.get_tma_tensor(mainloop_params.layout_K.shape());
        Tensor mV = mainloop_params.tma_load_V.get_tma_tensor(mainloop_params.layout_V.shape());
        int bidh_kv = mainloop_params.qhead_per_khead_divmod.divide(bidh);

        Tensor gQ = get_local_tile_tensor(
            mQ, select<0, 2>(TileShape_MNK{}), cu_seq_q, bidh, bidb, seq_len_q)(_, _, m_block);
        Tensor gK = get_local_tile_tensor(
            mK, select<1, 2>(TileShape_MNK{}), cu_seq_k, bidh_kv, bidb, seq_len_k);
        Tensor gV = get_local_tile_tensor(
            mV, select<1, 2>(TileShape_MNK{}), cu_seq_k, bidh_kv, bidb, seq_len_k);

        Tensor sQ_x = make_tensor(sQ.data(), make_layout(sQ.layout(), Layout<_1>{}));
        Tensor gQ_x = make_tensor(gQ.data(), make_layout(gQ.layout(), Layout<_1>{}));
        auto [tQgQ, tQsQ] = tma_partition(mainloop_params.tma_load_Q, _0{}, Layout<_1>{}, group_modes<0, 2>(sQ_x), group_modes<0, 2>(gQ_x));
        auto [tKgK, tKsK] = tma_partition(mainloop_params.tma_load_K, _0{}, Layout<_1>{}, group_modes<0, 2>(sK), group_modes<0, 2>(gK));
        auto [tVgV, tVsV] = tma_partition(mainloop_params.tma_load_V, _0{}, Layout<_1>{}, group_modes<0, 2>(sV), group_modes<0, 2>(gV));

        uint16_t mcast_mask_kv = 0;

        int n_block = n_block_max - 1;

        int lane_predicate = cute::elect_one_sync();

        if (lane_predicate) {
            shared_storage.barrier_Q.arrive_and_expect_tx(TmaTransactionBytesQ);
            copy(mainloop_params.tma_load_Q.with(reinterpret_cast<cutlass::arch::ClusterTransactionBarrier::ValueType&>(shared_storage.barrier_Q), 0 /*mcast_mask*/), tQgQ, tQsQ);
        }

        if (lane_predicate) {
            pipeline_k.producer_acquire(smem_pipe_write_k);
            copy(mainloop_params.tma_load_K.with(*pipeline_k.producer_get_barrier(smem_pipe_write_k), mcast_mask_kv), tKgK(_, n_block), tKsK(_, smem_pipe_write_k.index()));
            ++smem_pipe_write_k;
        }

        if (lane_predicate) {
            int idx = 0;
            #pragma unroll 2
            for (; n_block > 0; ) {
                pipeline_k.producer_acquire(smem_pipe_write_k);
                int pre_idx = 1;
                if constexpr (UseMoba) {
                    pre_idx = qk_gate_topk_idx[idx];
                }
                copy(mainloop_params.tma_load_K.with(*pipeline_k.producer_get_barrier(smem_pipe_write_k), mcast_mask_kv), tKgK(_, n_block - pre_idx), tKsK(_, smem_pipe_write_k.index()));

                ++smem_pipe_write_k;
                pipeline_v.producer_acquire(smem_pipe_write_v);
                copy(mainloop_params.tma_load_V.with(*pipeline_v.producer_get_barrier(smem_pipe_write_v), mcast_mask_kv), tVgV(_, n_block), tVsV(_, smem_pipe_write_v.index()));
                ++smem_pipe_write_v;
                n_block -= pre_idx;
                idx += 1;
            }
        }
        if (lane_predicate) {
            pipeline_v.producer_acquire(smem_pipe_write_v);
            copy(mainloop_params.tma_load_V.with(*pipeline_v.producer_get_barrier(smem_pipe_write_v), mcast_mask_kv), tVgV(_, n_block), tVsV(_, smem_pipe_write_v.index()));
            ++smem_pipe_write_v;
        }
    }

    CUTLASS_DEVICE void
    warp_scheduler_barrier_sync() {
        if constexpr (UseSchedulerBarrier) {
            cutlass::arch::NamedBarrier::sync(NumMmaThreads, static_cast<int>(AttnNamedBarriers::WarpSchedulerWG1) - 1 + cutlass::canonical_warp_group_idx() /*id*/);
        }
    }

    CUTLASS_DEVICE void
    mma_init() {
        if constexpr (!UseSchedulerBarrier) { return; }
        static_assert(NumMmaThreads == 2 * cutlass::NumThreadsPerWarpGroup || NumMmaThreads == 3 * cutlass::NumThreadsPerWarpGroup);
        if (cutlass::canonical_warp_group_idx() > 1) {
            cutlass::arch::NamedBarrier::arrive(NumMmaThreads, static_cast<int>(AttnNamedBarriers::WarpSchedulerWG1) - 1 + 1 /*id*/);
        }
        if constexpr (NumMmaThreads == 3 * cutlass::NumThreadsPerWarpGroup) {
            if (cutlass::canonical_warp_group_idx() > 2) {
                cutlass::arch::NamedBarrier::arrive(NumMmaThreads, static_cast<int>(AttnNamedBarriers::WarpSchedulerWG1) - 1 + 2 /*id*/);
            }
        }
    }

    CUTLASS_DEVICE void
    warp_scheduler_barrier_arrive() {
        if constexpr (!UseSchedulerBarrier) { return; }
        static_assert(NumMmaThreads == 2 * cutlass::NumThreadsPerWarpGroup || NumMmaThreads == 3 * cutlass::NumThreadsPerWarpGroup);
        if constexpr (NumMmaThreads == 2 * cutlass::NumThreadsPerWarpGroup) {
            cutlass::arch::NamedBarrier::arrive(NumMmaThreads, static_cast<int>(AttnNamedBarriers::WarpSchedulerWG1) - 1 + (3 - cutlass::canonical_warp_group_idx()) /*id*/);
        } else {
            cutlass::arch::NamedBarrier::arrive(NumMmaThreads, static_cast<int>(AttnNamedBarriers::WarpSchedulerWG1) - 1 + (cutlass::canonical_warp_group_idx() <= 2 ? cutlass::canonical_warp_group_idx() + 1 : cutlass::canonical_warp_group_idx() + 1 - 3) /*id*/);
            cutlass::arch::NamedBarrier::arrive(NumMmaThreads, static_cast<int>(AttnNamedBarriers::WarpSchedulerWG1) - 1 + (cutlass::canonical_warp_group_idx() <= 1 ? cutlass::canonical_warp_group_idx() + 2 : cutlass::canonical_warp_group_idx() + 2 - 3) /*id*/);
        }
    }

    template <bool UseMoba, typename SharedStorage, typename FrgTensorO, typename Softmax>
    CUTLASS_DEVICE void
    mma(Params const& mainloop_params,
        MainloopPipeline pipeline_k,
        MainloopPipeline pipeline_v,
        PipelineState& smem_pipe_read_k,
        PipelineState& smem_pipe_read_v,
        FrgTensorO& tOrO,
        Softmax& softmax,
        const int *qk_gate_topk_idx,
        const int n_block_max,
        const int thread_idx,
        const int m_block,
        const int seq_len_q,
        const int seq_len_k,
        SharedStorage& shared_storage) {

        Tensor sQ = make_tensor(make_smem_ptr(shared_storage.smem_q.data()), SmemLayoutQ{});
        Tensor sK = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), SmemLayoutK{});
        Tensor sVt = make_tensor(make_smem_ptr(shared_storage.smem_v.data()), SmemLayoutVt{});

        typename Ktraits::TiledMma0 tiled_mma0;
        typename Ktraits::TiledMma1 tiled_mma1;
        auto threadMma0 = tiled_mma0.get_thread_slice(thread_idx);
        auto threadMma1 = tiled_mma1.get_thread_slice(thread_idx);

        Tensor tSrQ = threadMma0.partition_fragment_A(sQ);
        Tensor tSrK = threadMma0.partition_fragment_B(sK);
        Tensor tOrV = threadMma1.partition_fragment_B(sVt);

        auto consumer_wait = [](auto& pipeline, auto& smem_pipe_read) {
            auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read);
            pipeline.consumer_wait(smem_pipe_read, barrier_token);
        };

        tiled_mma1.accumulate_ = GMMA::ScaleOut::Zero;

        int n_block = n_block_max - 1;

        cutlass::ConsumerToken barrier_token = static_cast<cutlass::BarrierStatus>(shared_storage.barrier_Q.try_wait(0));
        if (barrier_token == cutlass::BarrierStatus::WaitAgain) { shared_storage.barrier_Q.wait(0); }

        Tensor tSrS = partition_fragment_C(tiled_mma0, select<0, 1>(TileShape_MNK{}));
        consumer_wait(pipeline_k, smem_pipe_read_k);
        warp_scheduler_barrier_sync();
        gemm</*zero_init=*/true, /*wg_wait=*/-1>(tiled_mma0, tSrQ, tSrK(_, _, _, smem_pipe_read_k.index()), tSrS);
        warp_scheduler_barrier_arrive();
        warpgroup_wait<0>();
        pipeline_k.consumer_release(smem_pipe_read_k);
        ++smem_pipe_read_k;

        auto col_limit_causal = [&](int row, int n_block) {
            return row + 1 + seq_len_k - n_block * kBlockN - seq_len_q + m_block * kBlockM;
        };
        Tensor cS = cute::make_identity_tensor(select<0, 1>(TileShape_MNK{}));
        Tensor tScS = threadMma0.partition_C(cS);
        #pragma unroll
        for (int i = 0; i < size(tSrS); ++i) {
            if (int(get<1>(tScS(i))) >=
                    std::min(seq_len_k - n_block * kBlockN, col_limit_causal(int(get<0>(tScS(i))), n_block))) {
                tSrS(i) = -INFINITY;
            }
        }

        softmax.template online_softmax</*Is_first=*/true>(tSrS, mainloop_params.softmax_scale_log2);

        Tensor tOrP = make_tensor(convert_type<Element>(tSrS).data(), convert_layout_acc_Aregs<typename Ktraits::TiledMma1>(tSrS.layout()));
        Tensor scores_scale = make_fragment_like(softmax.row_max);
        clear(scores_scale);

        int idx = 0;
        #pragma unroll 2
        for (; n_block > 0; ) {
            Tensor tSrS = partition_fragment_C(tiled_mma0, select<0, 1>(TileShape_MNK{}));
            consumer_wait(pipeline_k, smem_pipe_read_k);
            warp_scheduler_barrier_sync();
            gemm</*zero_init=*/true, /*wg_wait=*/-1>(tiled_mma0, tSrQ, tSrK(_, _, _, smem_pipe_read_k.index()), tSrS);
            softmax.rescale_o(tOrO, scores_scale);
            consumer_wait(pipeline_v, smem_pipe_read_v);
            gemm</*zero_init=*/false, /*wg_wait=*/-1>(tiled_mma1, tOrP, tOrV(_, _, _, smem_pipe_read_v.index()), tOrO);
            warp_scheduler_barrier_arrive();
            warpgroup_wait<1>();
            pipeline_k.consumer_release(smem_pipe_read_k); // release K
            cute::copy(softmax.template max</*Is_first=*/false>(tSrS, mainloop_params.softmax_scale_log2), scores_scale);
            softmax.template online_softmax</*Is_first=*/false>(tSrS, mainloop_params.softmax_scale_log2);
            warpgroup_wait<0>();
            pipeline_v.consumer_release(smem_pipe_read_v); // release V
            ++smem_pipe_read_k;
            ++smem_pipe_read_v;
            cute::copy(make_tensor(convert_type<Element>(tSrS).data(), convert_layout_acc_Aregs<typename Ktraits::TiledMma1>(tSrS.layout())), tOrP);
            if constexpr (UseMoba) {
                n_block -= qk_gate_topk_idx[idx];
                idx += 1;
            } else {
                n_block -= 1;
            }
        }

        softmax.rescale_o(tOrO, scores_scale);
        consumer_wait(pipeline_v, smem_pipe_read_v);
        gemm</*zero_init=*/false, /*wg_wait=*/-1>(tiled_mma1, tOrP, tOrV(_, _, _, smem_pipe_read_v.index()), tOrO);
        cute::copy(softmax.finalize(mainloop_params.softmax_scale_log2), scores_scale);
        warpgroup_wait<0>();
        pipeline_v.consumer_release(smem_pipe_read_v);
        ++smem_pipe_read_v;

        softmax.rescale_o(tOrO, scores_scale);
    }

    template <int NumMmaThreads, typename SharedStorage, typename FrgTensorO, typename TiledMma, typename T>
    CUTLASS_DEVICE void
    store(Params const& mainloop_params,
          FrgTensorO const& tOrO,
          SharedStorage& shared_storage,
          TiledMma tiled_mma,
          int thread_idx,
          const int o_head_stride,
          const int real_seq,
          T * out_ptr) {

        Tensor sO = make_tensor(make_smem_ptr(shared_storage.smem_o.data()), SmemLayoutO{});
        auto smem_tiled_copy_O = make_tiled_copy_C(SmemCopyAtomO{}, tiled_mma);
        auto smem_thr_copy_O = smem_tiled_copy_O.get_thread_slice(thread_idx);

        Tensor tOrO_out = convert_type<Element>(tOrO);
        Tensor taccOrO = smem_thr_copy_O.retile_S(tOrO_out);
        Tensor taccOsO = smem_thr_copy_O.partition_D(sO);

        cutlass::arch::NamedBarrier::sync(NumMmaThreads, static_cast<int>(AttnNamedBarriers::ValueEmpty) /*id*/);
        cute::copy(smem_tiled_copy_O, taccOrO, taccOsO);
        cutlass::arch::fence_view_async_shared(); // ensure smem writes are visible to TMA
        cutlass::arch::NamedBarrier::arrive(NumMmaThreads + cutlass::NumThreadsPerWarp, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier);

        Tensor gO = make_tensor(make_gmem_ptr(out_ptr),
                                Shape<Int<kBlockM>, Int<kHeadDim>>{},
                                make_stride(o_head_stride, _1{}));

        GmemTiledCopyO gmem_tiled_copy_O;
        auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(thread_idx);

        Tensor tOsO = gmem_thr_copy_O.partition_S(sO);
        Tensor tOgO = gmem_thr_copy_O.partition_D(gO);

        Tensor cO = make_identity_tensor(Shape<Int<kBlockM>, Int<kHeadDim>>{});

        Tensor tOcO = gmem_thr_copy_O.partition_S(cO);

        if (real_seq >= kBlockM) {
            copy<true>(gmem_tiled_copy_O, tOsO, tOgO, tOcO);
        } else {
            copy<false>(gmem_tiled_copy_O, tOsO, tOgO, tOcO, real_seq);
        }
    }

};
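As an illustrative aside (not part of the patch): with UseMoba enabled, the producer in load() above does not walk every KV tile. qk_gate_topk_idx stores, for each selected block, the distance to the next selected block (a run-length skip list produced by moba_qk_sort_encoder further down in this diff), so TMA loads are only issued for the selected tiles. A host-side sketch of the same traversal, under those assumptions:

#include <vector>

// skip[i] plays the role of qk_gate_topk_idx: the distance from the current
// selected KV block to the next selected one, walking from the causal diagonal
// back toward block 0 (which the gate kernel always keeps selected).
std::vector<int> selected_kv_blocks(int n_block_max, const std::vector<int>& skip, bool use_moba) {
    std::vector<int> visited;
    int n_block = n_block_max - 1;   // newest block, always attended
    visited.push_back(n_block);
    int idx = 0;
    while (n_block > 0) {
        int step = use_moba ? skip[idx++] : 1;  // dense attention steps by 1
        n_block -= step;
        visited.push_back(n_block);
    }
    return visited;
}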
@@ -0,0 +1,384 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/extension.h"
#include <cute/tensor.hpp>

#include "cutlass/util/print_error.hpp"
#include "cutlass/util/GPU_Clock.hpp"
#if defined(CUTLASS_ENABLE_CUBLAS) && CUTLASS_ENABLE_CUBLAS != 0
# include "cutlass/util/cublas_wrappers.hpp"
#endif
#include "moba_attn/moba_attn_utils.hpp"
#include "moba_attn/moba_attn.h"
#include "kernel_traits.h"
#include "mainloop_attn.hpp"
#include "softmax.hpp"
#include "cutlass/arch/reg_reconfig.h"

template <int kHeadDim>
auto get_gmem_layout(int token_num, int head_num) {
    return make_layout(
        make_shape(token_num, kHeadDim, head_num),
        make_stride(head_num * kHeadDim, _1{}, kHeadDim));
}

template <typename Ktraits>
__global__ void __launch_bounds__(Ktraits::kNWarps * cutlass::NumThreadsPerWarp, 1)
moba_encoder_attention_kernel(
        CUTE_GRID_CONSTANT typename CollectiveMainloopAttn<Ktraits>::Params const mainloop_params,
        CUTE_GRID_CONSTANT moba_encoder_attn_params const data_params) {

    using Element = typename Ktraits::Element;
    using ElementAccum = typename Ktraits::ElementAccum;
    using SoftType = ElementAccum;
    using TileShape_MNK = typename Ktraits::TileShape_MNK;
    using ClusterShape = typename Ktraits::ClusterShape_MNK;

    static constexpr int NumMmaThreads = size(typename Ktraits::TiledMma0{});
    static constexpr int NumCopyThreads = cutlass::NumThreadsPerWarpGroup;
    static constexpr int kBlockM = Ktraits::kBlockM;
    static constexpr int kBlockN = Ktraits::kBlockN;
    constexpr int kHeadDim = Ktraits::kHeadDim;
    constexpr int kMaxN = Ktraits::kMaxN;

    using CollectiveMainloop = CollectiveMainloopAttn<Ktraits>;

    using MainloopPipeline = typename Ktraits::MainloopPipeline;
    using PipelineParams = typename MainloopPipeline::Params;
    using PipelineState = typename MainloopPipeline::PipelineState;

    extern __shared__ char shared_memory[];
    auto &shared_storage = *reinterpret_cast<typename Ktraits::SharedStorage*>(shared_memory);

    const int m_block = blockIdx.x;
    const int bidh = blockIdx.y;
    const int bidb = blockIdx.z;

    const int seq_len_q = data_params.seq_len_encoder[bidb];
    const int seq_len_k = data_params.cu_seq_k[bidb + 1] - data_params.cu_seq_k[bidb];

    if (seq_len_q == 0) {
        return;
    }

    __align__(16) __shared__ int qk_gate_topk_idx[kMaxN];
    const int *qk_gate_idx_cur_offset = data_params.qk_gate_topk_idx + data_params.cu_seq_q_pack[bidb] / kBlockM * data_params.head_num * kMaxN + (m_block * data_params.head_num + bidh) * kMaxN;

    #pragma unroll
    for (int i = threadIdx.x; i < kMaxN / 4; i += Ktraits::kNWarps * cutlass::NumThreadsPerWarp) {
        reinterpret_cast<int4*>(qk_gate_topk_idx)[i] = reinterpret_cast<const int4*>(qk_gate_idx_cur_offset)[i];
    }

    const int n_block_max = min(cute::ceil_div((m_block + 1) * kBlockM + seq_len_k - seq_len_q, kBlockN), cute::ceil_div(seq_len_k, kBlockN));

    if (m_block * kBlockM >= seq_len_q) {
        return;
    }

    int const lane_predicate = cute::elect_one_sync();
    int const warp_idx = cutlass::canonical_warp_idx_sync();

    if (warp_idx == 0 && lane_predicate) {
        CollectiveMainloop::prefetch_tma_descriptors(mainloop_params);
    }

    // Obtain warp index
    int const warp_group_thread_idx = threadIdx.x % cutlass::NumThreadsPerWarpGroup;

    PipelineParams pipeline_params;
    pipeline_params.transaction_bytes = CollectiveMainloop::TmaTransactionBytesK;
    int warp_group_idx = cutlass::canonical_warp_group_idx();
    pipeline_params.role = warp_group_idx == 0
        ? MainloopPipeline::ThreadCategory::Producer
        : MainloopPipeline::ThreadCategory::Consumer;
    pipeline_params.is_leader = warp_group_thread_idx == 0;
    pipeline_params.num_consumers = NumMmaThreads;

    if (warp_idx == 0 && lane_predicate) {
        shared_storage.barrier_Q.init(1);
    }

    MainloopPipeline pipeline_k(shared_storage.pipeline_k, pipeline_params, ClusterShape{});
    MainloopPipeline pipeline_v(shared_storage.pipeline_v, pipeline_params, ClusterShape{});

    __syncthreads();

    CollectiveMainloop collective_mainloop;

    if (warp_group_idx == 0) {
        cutlass::arch::warpgroup_reg_dealloc<Ktraits::kNWarps == 8 ? 56 : 24>();

        int warp_idx_in_warpgroup = __shfl_sync(0xffffffff, (threadIdx.x / 32) % 4, 0);
        if (warp_idx_in_warpgroup == 0) {
            PipelineState smem_pipe_write_k = cutlass::make_producer_start_state<MainloopPipeline>();
            PipelineState smem_pipe_write_v = cutlass::make_producer_start_state<MainloopPipeline>();

            collective_mainloop.load<Ktraits::UseMoba>(
                mainloop_params,
                pipeline_k,
                pipeline_v,
                smem_pipe_write_k,
                smem_pipe_write_v,
                shared_storage,
                qk_gate_topk_idx,
                n_block_max,
                m_block,
                bidh,
                bidb,
                data_params.cu_seq_q,
                data_params.cu_seq_k,
                seq_len_q,
                seq_len_k);
        }
    } else {
        cutlass::arch::warpgroup_reg_alloc<Ktraits::kNWarps == 8 ? 256 : 240>();
        typename Ktraits::TiledMma1 tiled_mma1;

        collective_mainloop.mma_init();

        PipelineState smem_pipe_read_k, smem_pipe_read_v;

        Tensor tOrO = partition_fragment_C(tiled_mma1, select<0, 2>(TileShape_MNK{}));
        Softmax<2 * (2 * kBlockM / NumMmaThreads)> softmax;

        collective_mainloop.mma<Ktraits::UseMoba>(
            mainloop_params,
            pipeline_k,
            pipeline_v,
            smem_pipe_read_k,
            smem_pipe_read_v,
            tOrO,
            softmax,
            qk_gate_topk_idx,
            n_block_max,
            threadIdx.x - NumCopyThreads,
            m_block,
            seq_len_q,
            seq_len_k,
            shared_storage);

        const int o_head_stride = data_params.head_num * kHeadDim;
        const int store_offset = (data_params.cu_seq_q[bidb] + m_block * kBlockM) * o_head_stride + bidh * kHeadDim;

        const int real_seq = seq_len_q - m_block * kBlockM;

        collective_mainloop.store<NumMmaThreads>(
            mainloop_params,
            tOrO,
            shared_storage,
            tiled_mma1,
            threadIdx.x - NumCopyThreads,
            o_head_stride,
            real_seq,
            reinterpret_cast<Element*>(data_params.o_ptr) + store_offset);
    }

}

template<typename Kernel_traits>
void run_moba_decoder_attn(moba_encoder_attn_params &params, cudaStream_t stream) {
    using Element = typename Kernel_traits::Element;
    using TileShape_MNK = typename Kernel_traits::TileShape_MNK;
    using ClusterShape = typename Kernel_traits::ClusterShape_MNK;

    using CollectiveMainloop = CollectiveMainloopAttn<Kernel_traits>;
    constexpr int kHeadDim = Kernel_traits::kHeadDim;

    typename CollectiveMainloop::Params mainloop_params =
        CollectiveMainloop::to_underlying_arguments({
            static_cast<Element const*>(params.q_ptr),
            get_gmem_layout<kHeadDim>(params.max_seq_q * params.batch_size, params.head_num),
            static_cast<Element const*>(params.k_ptr),
            get_gmem_layout<kHeadDim>(params.max_seq_k * params.batch_size, params.kv_head_num),
            static_cast<Element const*>(params.v_ptr),
            get_gmem_layout<kHeadDim>(params.max_seq_k * params.batch_size, params.kv_head_num),
            params.scale_softmax_log2
        });

    int num_blocks_m = cutlass::ceil_div(params.max_seq_q, Kernel_traits::kBlockM);

    num_blocks_m = cutlass::ceil_div(num_blocks_m, size<0>(ClusterShape{})) * size<0>(ClusterShape{});

    void *kernel;
    kernel = (void *)moba_encoder_attention_kernel<Kernel_traits>;
    int smem_size = sizeof(typename Kernel_traits::SharedStorage);

    if (smem_size >= 48 * 1024) {
        cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
    }

    dim3 grid_dims;
    grid_dims.x = num_blocks_m;
    grid_dims.y = params.head_num;
    grid_dims.z = params.batch_size;

    static constexpr int ctaSize = Kernel_traits::kNWarps * 32;
    dim3 block_dims(ctaSize);
    dim3 cluster_dims(size<0>(ClusterShape{}), size<1>(ClusterShape{}), size<2>(ClusterShape{}));
    cutlass::ClusterLaunchParams launch_params{grid_dims, block_dims, cluster_dims, smem_size, stream};
    cutlass::launch_kernel_on_cluster(launch_params, kernel, mainloop_params, params);
}

template <int kBlockM, int kBlockN, int kMaxN, typename InputType>
void run_moba_encoder_attn_hdim128(moba_encoder_attn_params &params, cudaStream_t stream) {

    constexpr static int Headdim = 128;
    constexpr static int kNWarps = kBlockM / 16 + 4;
    constexpr static int kStages = 2;

    using Ktraits = moba_encoder_attn_kernel_traits<Headdim, kBlockM, kBlockN, kNWarps, kStages, kMaxN, true, InputType>;
    run_moba_decoder_attn<Ktraits>(params, stream);
}

template <typename T>
void DispatchMobaEncoderAttn(
        const paddle::Tensor& q_input,
        const paddle::Tensor& k_input,
        const paddle::Tensor& v_input,
        const paddle::Tensor& qk_gate_topk_idx,
        const paddle::Tensor& cu_seq_q,
        const paddle::Tensor& cu_seq_k,
        const paddle::Tensor& cu_seq_q_pack,
        const paddle::Tensor& seq_len_encoder,
        const paddle::Tensor& seq_len_decoder,
        const paddle::Tensor& out,
        const int max_seq_q,
        const int max_seq_k,
        const int head_num,
        const int kv_head_num,
        const int head_dim,
        const int batch_size,
        const int max_input_length) {

    constexpr int kBlockM = 128;
    constexpr int kBlockN = 128;
    constexpr int kMobaBlockSize = 128;
    constexpr int kMaxN = 1024;

    using cute_type = typename cuteType<T>::type;

    moba_encoder_attn_params params;
    memset(&params, 0, sizeof(moba_encoder_attn_params));

    params.q_ptr = reinterpret_cast<cute_type*>(const_cast<T*>(q_input.data<T>()));
    params.k_ptr = reinterpret_cast<cute_type*>(const_cast<T*>(k_input.data<T>()));
    params.v_ptr = reinterpret_cast<cute_type*>(const_cast<T*>(v_input.data<T>()));
    params.o_ptr = reinterpret_cast<cute_type*>(const_cast<T*>(out.data<T>()));
    params.cu_seq_q = const_cast<int*>(cu_seq_q.data<int>());
    params.cu_seq_k = const_cast<int*>(cu_seq_k.data<int>());
    params.head_num = head_num;
    params.kv_head_num = kv_head_num;
    params.max_seq_q = max_seq_q;
    params.max_seq_k = max_seq_k;
    params.batch_size = batch_size;
    params.gqa_group_size = head_num / kv_head_num;
    constexpr float kLog2e = 1.4426950408889634074;
    params.scale_softmax_log2 = 1.0f / std::sqrt(head_dim) * kLog2e;
    params.qk_gate_topk_idx = const_cast<int*>(qk_gate_topk_idx.data<int>());
    params.seq_len_encoder = const_cast<int*>(seq_len_encoder.data<int>());
    params.cu_seq_q_pack = const_cast<int*>(cu_seq_q_pack.data<int>());

    run_moba_encoder_attn_hdim128<kBlockM, kBlockN, kMaxN, cute_type>(params, out.stream());
}

void MobaEncoderAttn(
        const paddle::Tensor& q_input,
        const paddle::Tensor& k_input,
        const paddle::Tensor& v_input,
        const paddle::Tensor& qk_gate_topk_idx,
        const paddle::Tensor& cu_seq_q,
        const paddle::Tensor& cu_seq_k,
        const paddle::Tensor& cu_seq_q_pack,
        const paddle::Tensor& seq_len_encoder,
        const paddle::Tensor& seq_len_decoder,
        const paddle::Tensor& out,
        const int max_seq_q,
        const int max_seq_k,
        const int head_num,
        const int kv_head_num,
        const int head_dim,
        const int max_input_length) {

    const int batch_size = seq_len_encoder.dims()[0];
    if (q_input.dtype() == paddle::DataType::FLOAT16) {
        return
            DispatchMobaEncoderAttn<phi::dtype::float16>(
                q_input,
                k_input,
                v_input,
                qk_gate_topk_idx,
                cu_seq_q,
                cu_seq_k,
                cu_seq_q_pack,
                seq_len_encoder,
                seq_len_decoder,
                out,
                max_seq_q,
                max_seq_k,
                head_num,
                kv_head_num,
                head_dim,
                batch_size,
                max_input_length);
    } else if (q_input.dtype() == paddle::DataType::BFLOAT16) {
        return
            DispatchMobaEncoderAttn<phi::dtype::bfloat16>(
                q_input,
                k_input,
                v_input,
                qk_gate_topk_idx,
                cu_seq_q,
                cu_seq_k,
                cu_seq_q_pack,
                seq_len_encoder,
                seq_len_decoder,
                out,
                max_seq_q,
                max_seq_k,
                head_num,
                kv_head_num,
                head_dim,
                batch_size,
                max_input_length);
    }
}

PD_BUILD_OP(moba_encoder_attn)
    .Inputs({
        "q_input",
        "k_input",
        "v_input",
        "qk_gate_topk_idx",
        "cu_seq_q",
        "cu_seq_k",
        "cu_seq_q_pack",
        "seq_len_encoder",
        "seq_len_decoder",
        "out"})
    .Attrs({
        "max_seq_q: int",
        "max_seq_k: int",
        "head_num: int",
        "kv_head_num: int",
        "head_dim: int",
        "max_input_length: int"})
    .Outputs({"attn_out"})
    .SetInplaceMap({{"out", "attn_out"}})
    .SetKernelFn(PD_KERNEL(MobaEncoderAttn));
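As an illustrative aside (not part of the patch): the consumer warp groups in the kernel above accumulate the output with a streaming (online) softmax, which is what the Softmax object's online_softmax / rescale_o / finalize calls implement at the warp level. A scalar reference sketch of the underlying identity:

#include <cmath>

// Each new score s (attending value v) updates the running max m, running sum l
// and partial output o; previous contributions are rescaled by exp(m_old - m_new),
// so o / l equals the exact softmax-weighted sum after the last block.
void streaming_softmax_step(float s, float v, float& m, float& l, float& o) {
    float m_new = fmaxf(m, s);
    float scale = expf(m - m_new);   // rescale what was accumulated so far
    float p = expf(s - m_new);       // weight of the new score
    o = o * scale + p * v;
    l = l * scale + p;
    m = m_new;
}
// Initialize with m = -INFINITY, l = 0, o = 0; the normalized result is o / l.

The kernel itself works in base 2 (exp2f with scale_softmax_log2 = scale * log2(e)), which is numerically equivalent up to that constant factor.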
@@ -0,0 +1,163 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/extension.h"
#include "moba_attn/moba_attn.h"

template <typename T, int kBlockSize, int kHeadDim>
__global__ void write_encoder_cachekv_c16(
        const T * k_input,
        const T * v_input,
        const int * cu_seq_k,
        const int * seq_len_encoder,
        const int * seq_len_decoder,
        T * cache_k,
        T * cache_v,
        const int * block_tables,
        const int kv_head_num,
        const int max_blocks_per_seq) {

    constexpr int kPackSize = 16 / sizeof(T);
    const int block_idx = blockIdx.x * kBlockSize;
    int bidh = blockIdx.y;
    const int bidb = blockIdx.z;
    const int tidx = threadIdx.x;
    const int row_idx = tidx / (kHeadDim / kPackSize);
    const int col_idx = tidx % (kHeadDim / kPackSize) * kPackSize;
    const int seq_len = seq_len_encoder[bidb];

    if (seq_len == 0) return;

    const int remain_tokens = seq_len - block_idx;

    const int32_t *block_table_now = block_tables + bidb * max_blocks_per_seq;
    const uint32_t physical_block_number = block_table_now[blockIdx.x + seq_len_decoder[bidb] / kBlockSize];

    if (bidh < kv_head_num) {
        T * cache = cache_k + physical_block_number * kv_head_num * kBlockSize * kHeadDim + bidh * kBlockSize * kHeadDim + col_idx;
        const int base_load_idx = (block_idx + cu_seq_k[bidb]) * kv_head_num * kHeadDim + bidh * kHeadDim + col_idx;

        #pragma unroll
        for (int i = row_idx; i < kBlockSize; i += 128 / (kHeadDim / kPackSize)) {
            if (i < remain_tokens) {
                *reinterpret_cast<float4*>(cache + i * kHeadDim) = *reinterpret_cast<const float4*>(k_input + base_load_idx + i * kv_head_num * kHeadDim);
            }
        }
    } else {
        bidh -= kv_head_num;
        const int base_load_idx = (block_idx + cu_seq_k[bidb]) * kv_head_num * kHeadDim + bidh * kHeadDim + col_idx;
        T * cache = cache_v + physical_block_number * kv_head_num * kBlockSize * kHeadDim + bidh * kBlockSize * kHeadDim + col_idx;

        #pragma unroll
        for (int i = row_idx; i < kBlockSize; i += 128 / (kHeadDim / kPackSize)) {
            if (i < remain_tokens) {
                *reinterpret_cast<float4*>(cache + i * kHeadDim) = *reinterpret_cast<const float4*>(v_input + base_load_idx + i * kv_head_num * kHeadDim);
            }
        }

    }
}

void MobaEncoderAttnWriteCacheKv(
        const paddle::Tensor& k_input,
        const paddle::Tensor& v_input,
        const paddle::Tensor& cu_seq_k,
        const paddle::Tensor& seq_len_encoder,
        const paddle::Tensor& seq_len_decoder,
        const paddle::Tensor& cache_k,
        const paddle::Tensor& cache_v,
        const paddle::Tensor& block_tables,
        const paddle::optional<paddle::Tensor>& cache_k_quant_scale,
        const paddle::optional<paddle::Tensor>& cache_v_quant_scale,
        const paddle::optional<paddle::Tensor>& cache_k_dequant_scale,
        const paddle::optional<paddle::Tensor>& cache_v_dequant_scale,
        const paddle::optional<paddle::Tensor>& cache_k_zero_points,
        const paddle::optional<paddle::Tensor>& cache_v_zero_points,
        const int head_num,
        const int kv_head_num,
        const int head_dim,
        const int max_seq_q,
        const std::string &cache_quant_type_str) {

    constexpr int kThreads = 128;
    constexpr int kHeadDim = 128;
    assert(kHeadDim == head_dim);
    constexpr int kBlockSize = 64;
    const int batch_size = block_tables.dims()[0];
    const int max_blocks_per_seq = block_tables.dims()[1];
    if (cache_quant_type_str == "none") {
        dim3 grid_dims;
        grid_dims.x = (max_seq_q + kBlockSize - 1) / kBlockSize;
        grid_dims.y = kv_head_num * 2;
        grid_dims.z = batch_size;
        if (k_input.dtype() == paddle::DataType::FLOAT16) {
            using T = phi::dtype::float16;
            write_encoder_cachekv_c16<T, kBlockSize, kHeadDim><<<grid_dims, kThreads, 0, k_input.stream()>>>(
                const_cast<T*>(k_input.data<T>()),
                const_cast<T*>(v_input.data<T>()),
                cu_seq_k.data<int>(),
                seq_len_encoder.data<int>(),
                seq_len_decoder.data<int>(),
                const_cast<T*>(cache_k.data<T>()),
                const_cast<T*>(cache_v.data<T>()),
                block_tables.data<int>(),
                kv_head_num,
                max_blocks_per_seq);
        } else if (k_input.dtype() == paddle::DataType::BFLOAT16) {
            using T = phi::dtype::bfloat16;
            write_encoder_cachekv_c16<T, kBlockSize, kHeadDim><<<grid_dims, kThreads, 0, k_input.stream()>>>(
                const_cast<T*>(k_input.data<T>()),
                const_cast<T*>(v_input.data<T>()),
                cu_seq_k.data<int>(),
                seq_len_encoder.data<int>(),
                seq_len_decoder.data<int>(),
                const_cast<T*>(cache_k.data<T>()),
                const_cast<T*>(cache_v.data<T>()),
                block_tables.data<int>(),
                kv_head_num,
                max_blocks_per_seq);
        }
    } else {
        PADDLE_THROW(phi::errors::Unimplemented(
            "Quantized cache not implemented for cache_quant_type = %s", cache_quant_type_str.c_str()));
    }
}

PD_BUILD_OP(moba_encoder_attn_write_cache_kv)
    .Inputs({
        "k_input",
        "v_input",
        "cu_seq_k",
        "seq_len_encoder",
        "seq_len_decoder",
        "cache_k",
        "cache_v",
        "block_tables",
        paddle::Optional("cache_k_quant_scale"),
        paddle::Optional("cache_v_quant_scale"),
        paddle::Optional("cache_k_dequant_scale"),
        paddle::Optional("cache_v_dequant_scale"),
        paddle::Optional("cache_k_zero_points"),
        paddle::Optional("cache_v_zero_points")})
    .Attrs({
        "head_num: int",
        "kv_head_num: int",
        "head_dim: int",
        "max_seq_q: int",
        "cache_quant_type_str: std::string"})
    .Outputs({"cache_k_out", "cache_v_out"})
    .SetInplaceMap({{"cache_k", "cache_k_out"},
                    {"cache_v", "cache_v_out"}})
    .SetKernelFn(PD_KERNEL(MobaEncoderAttnWriteCacheKv));
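As an illustrative aside (not part of the patch): write_encoder_cachekv_c16 scatters the new K/V rows into a paged cache laid out as [num_blocks, kv_head_num, block_size, head_dim], with block_tables acting as the per-sequence page table. A sketch of the addressing it uses, with hypothetical parameter names:

#include <cstdint>

// logical_block is the KV block index within this sequence (the encoder block
// index plus any already-decoded prefix blocks, as in the kernel above).
inline int64_t cache_offset(const int* block_table_row, int logical_block,
                            int token_in_block, int kv_head, int dim,
                            int kv_head_num, int block_size, int head_dim) {
    const int64_t physical_block = block_table_row[logical_block];
    return ((physical_block * kv_head_num + kv_head) * int64_t(block_size)
            + token_in_block) * head_dim + dim;
}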
@@ -0,0 +1,341 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/extension.h"
#include "moba_attn/moba_attn_utils.hpp"

template <typename T, int knthreads, int moba_block_size, int kBlockM, int kBlockMaxN, int searchtimes>
__global__ void qk_gate_sort_encoder_kernel(
        const T* qk_gate_weight,
        int * qk_gate_topk_idx,
        const int *seq_len_encoder,
        const int *seq_len_decoder,
        const int* cu_seq_q,
        const int* cu_seq_k,
        const int* cu_seq_q_pack,
        const int use_moba_seq_limit,
        const int max_seq_q,
        const int max_seq_k,
        const int head_num,
        const int kv_head_num,
        const int kGqaGroupSize,
        const int top_k_left,
        const int top_k_right) {

    const int bidt = blockIdx.x * kBlockM;
    const int bidh = blockIdx.y;
    const int bidb = blockIdx.z;
    const int tidx = threadIdx.x;

    constexpr int kPackSize = kBlockMaxN / knthreads;

    static_assert(kBlockMaxN % knthreads == 0);

    const int seq_len_q = seq_len_encoder[bidb];

    if (seq_len_q == 0 || bidt >= seq_len_q) {
        return;
    }

    const int seq_len_k = (bidt + kBlockM + seq_len_decoder[bidb]);

    const int seq_len_moba = seq_len_k / moba_block_size;

    using SrcType = Vec<T, kPackSize>;
    using SrcType_f = Vec<float, kPackSize>;
    using SrcType_i = Vec<int, kPackSize>;

    SrcType src;
    SrcType_f src_f;

    SrcType_i select_idx;

    select_idx.set_zero();

    const int store_idx = cu_seq_q_pack[bidb] / kBlockM * head_num * kBlockMaxN + bidh * kBlockMaxN + blockIdx.x * head_num * kBlockMaxN + tidx * kPackSize;

    if (seq_len_k < use_moba_seq_limit) {
        #pragma unroll
        for (int i = 0; i < kPackSize; i++) {
            select_idx.data.elt[i] = 1;
        }
        select_idx.store_to(qk_gate_topk_idx + store_idx);
        return;
    }

    const int load_offset = (cu_seq_q[bidb] + bidt) * head_num * kBlockMaxN + bidh * kBlockMaxN + tidx * kPackSize;
    const int data_len = seq_len_moba - tidx * kPackSize;

    #pragma unroll
    for (int t = 0; t < kBlockM; t++) {
        if (bidt + t >= seq_len_q) {
            break;
        }
        src.load_from(qk_gate_weight + load_offset + t * head_num * kBlockMaxN);
        float min_global = FLT_MAX;
        float max_global = -FLT_MAX;
        #pragma unroll
        for (int i = 0; i < kPackSize; i++) {
            if (i < data_len) {
                src_f.data.elt[i] = float(src.data.elt[i]);
                min_global = min(min_global, src_f.data.elt[i]);
            } else {
                src_f.data.elt[i] = -FLT_MAX;
            }
            max_global = max(max_global, src_f.data.elt[i]);
        }

        max_global = BlockAllReduce<float, MaxOp<float>, knthreads>(max_global);
        min_global = BlockAllReduce<float, MinOp<float>, knthreads>(min_global);

        float right_limit = max_global;
        float left_limit = min_global;

        float mid_limit;
        int count;

        if (right_limit == left_limit) {
            mid_limit = (left_limit + right_limit) * 0.5f;
        } else {
            #pragma unroll
            for (int i = 0; i < searchtimes; i++) {
                mid_limit = (left_limit + right_limit) * 0.5f;
                count = get_data_count<kPackSize, knthreads>(src_f.data.elt, mid_limit);
                if (count < top_k_left) {
                    right_limit = mid_limit;
                } else if (count > top_k_right) {
                    left_limit = mid_limit;
                } else {
                    break;
                }
            }
        }

        #pragma unroll
        for (int i = 0; i < kPackSize; i++) {
            if (src_f.data.elt[i] >= mid_limit) {
                select_idx.data.elt[i] = 1;
            }
        }
    }

    if (tidx == 0) {
        select_idx.data.elt[0] = 1;
    }

    __align__(16) __shared__ int qk_gate_mem[kBlockMaxN];
    __align__(16) __shared__ int qk_continue_idx_mem[kBlockMaxN];
    select_idx.store_to(qk_gate_mem + tidx * kPackSize);

    __syncthreads();

    if (tidx == 0) {
        int cur_idx = 0;
        int idx = -1;
        const int last_idx = seq_len_moba - 1;
        while (last_idx + idx >= 0 && qk_gate_mem[last_idx + idx] == 0) {
            idx--;
        }
        qk_continue_idx_mem[cur_idx] = -idx;
        cur_idx++;

        for (int i = last_idx - 1; i >= 0; --i) {
            if (qk_gate_mem[i] == 1) {
                int idx = -1;
                while (i + idx >= 0 && qk_gate_mem[i + idx] == 0) {
                    idx--;
                }
                qk_continue_idx_mem[cur_idx] = -idx;
                cur_idx++;
            }
        }
        qk_continue_idx_mem[cur_idx] = 10000000;
    }

    __syncthreads();

    *reinterpret_cast<SrcType_i *>(qk_gate_topk_idx + store_idx) = reinterpret_cast<SrcType_i *>(qk_continue_idx_mem)[tidx];
}

template <int kBlockM, int kMaxN, int moba_block_size, typename T>
void qk_gate_sort_encoder(
        const T* qk_gate_weight,
        int * qk_gate_topk_idx,
        const int *seq_len_encoder,
        const int *seq_len_decoder,
        const int* cu_seq_q,
        const int* cu_seq_k,
        const int* cu_seq_q_pack,
        const int use_moba_seq_limit,
        const int max_seq_q,
        const int max_seq_k,
        const int head_num,
        const int kv_head_num,
        const int batch_size,
        const int top_k_left,
        const int top_k_right,
        cudaStream_t stream) {

    constexpr int kPackSize = 16 / sizeof(T);

    const int gqa_group_size = head_num / kv_head_num;
    const int knthreads = kMaxN / kPackSize;
    const int searchtimes = 6;

    dim3 grid_dims;
    grid_dims.x = (max_seq_q + kBlockM - 1) / kBlockM;
    grid_dims.y = head_num;
    grid_dims.z = batch_size;

    constexpr auto kernel = qk_gate_sort_encoder_kernel<T, knthreads, moba_block_size, kBlockM, kMaxN, searchtimes>;

    kernel<<<grid_dims, knthreads, 0, stream>>>(
        qk_gate_weight,
        qk_gate_topk_idx,
        seq_len_encoder,
        seq_len_decoder,
        cu_seq_q,
        cu_seq_k,
        cu_seq_q_pack,
        use_moba_seq_limit,
        max_seq_q,
        max_seq_k,
        head_num,
        kv_head_num,
        gqa_group_size,
        top_k_left,
        top_k_right);
}

template <typename T>
std::vector<paddle::Tensor> DispatchQkSortEncoder(
        const paddle::Tensor& qk_gate_weight,
        const paddle::Tensor& seq_len_encoder,
        const paddle::Tensor& seq_len_decoder,
        const paddle::Tensor& cu_seq_q,
        const paddle::Tensor& cu_seq_k,
        const paddle::Tensor& cu_seq_q_pack,
        const paddle::Tensor& q_pack_tokens,
        const int max_seq_q,
        const int max_seq_k,
        const int head_num,
        const int kv_head_num,
        const int top_k_left,
        const int top_k_right,
        const int use_moba_seq_limit) {
    constexpr int kBlockM = 128;
    constexpr int kBlockN = 128;
    constexpr int kMobaBlockSize = 128;
    constexpr int kMaxN = 1024;
    using cute_type = typename cuteType<T>::type;
    const int batch_size = seq_len_encoder.dims()[0];

    paddle::Tensor qk_gate_topk_idx = paddle::empty({q_pack_tokens.data<int>()[0] / kBlockM, head_num, kMaxN}, paddle::DataType::INT32, qk_gate_weight.place());

    qk_gate_sort_encoder<kBlockM, kMaxN, kMobaBlockSize, cute_type>(
        reinterpret_cast<const cute_type *>(qk_gate_weight.data<T>()),
        qk_gate_topk_idx.data<int>(),
        seq_len_encoder.data<int>(),
        seq_len_decoder.data<int>(),
        cu_seq_q.data<int>(),
        cu_seq_k.data<int>(),
        cu_seq_q_pack.data<int>(),
        use_moba_seq_limit,
        max_seq_q,
        max_seq_k,
        head_num,
        kv_head_num,
        batch_size,
        top_k_left,
        top_k_right,
        qk_gate_weight.stream());

    return {qk_gate_topk_idx};
}

std::vector<paddle::Tensor> QkSortEncoder(
        const paddle::Tensor& qk_gate_weight,
        const paddle::Tensor& seq_len_encoder,
        const paddle::Tensor& seq_len_decoder,
        const paddle::Tensor& cu_seq_q,
        const paddle::Tensor& cu_seq_k,
        const paddle::Tensor& cu_seq_q_pack,
        const paddle::Tensor& q_pack_tokens,
        const int max_seq_q,
        const int max_seq_k,
        const int head_num,
        const int kv_head_num,
        const int top_k_left,
        const int top_k_right,
        const int use_moba_seq_limit) {
    if (qk_gate_weight.dtype() == paddle::DataType::FLOAT16) {
        return std::move(
            DispatchQkSortEncoder<phi::dtype::float16>(
                qk_gate_weight,
                seq_len_encoder,
                seq_len_decoder,
                cu_seq_q,
                cu_seq_k,
                cu_seq_q_pack,
                q_pack_tokens,
                max_seq_q,
                max_seq_k,
                head_num,
                kv_head_num,
                top_k_left,
                top_k_right,
                use_moba_seq_limit
            )
        );
    } else if (qk_gate_weight.dtype() == paddle::DataType::BFLOAT16) {
        return std::move(
            DispatchQkSortEncoder<phi::dtype::bfloat16>(
                qk_gate_weight,
                seq_len_encoder,
                seq_len_decoder,
                cu_seq_q,
                cu_seq_k,
                cu_seq_q_pack,
                q_pack_tokens,
                max_seq_q,
                max_seq_k,
                head_num,
                kv_head_num,
                top_k_left,
                top_k_right,
                use_moba_seq_limit
            )
        );
    } else {
        PADDLE_THROW(phi::errors::Unimplemented(
            "moba_qk_sort_encoder only supports float16 and bfloat16 inputs."));
    }
}

PD_BUILD_OP(moba_qk_sort_encoder)
    .Inputs({
        "qk_gate_weight",
        "seq_len_encoder",
        "seq_len_decoder",
        "cu_seq_q",
        "cu_seq_k",
        "cu_seq_q_pack",
        "q_pack_tokens"})
    .Attrs({
        "max_seq_q: int",
        "max_seq_k: int",
        "head_num: int",
        "kv_head_num: int",
        "top_k_left: int",
        "top_k_right: int",
        "use_moba_seq_limit: int"})
    .Outputs({"qk_gate_topk_idx"})
    .SetKernelFn(PD_KERNEL(QkSortEncoder));
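As an illustrative aside (not part of the patch): qk_gate_sort_encoder_kernel above keeps, per query block and head, the KV blocks whose gate score clears a threshold; the threshold is found with a short binary search so that the number of kept blocks lands in [top_k_left, top_k_right], and the resulting mask is run-length encoded into the skip distances consumed by the attention mainloop. A host-side sketch of the threshold search, under those assumptions:

#include <algorithm>
#include <vector>

float find_topk_threshold(const std::vector<float>& scores,
                          int top_k_left, int top_k_right, int search_times = 6) {
    float lo = *std::min_element(scores.begin(), scores.end());
    float hi = *std::max_element(scores.begin(), scores.end());
    float mid = 0.5f * (lo + hi);
    for (int i = 0; i < search_times && lo < hi; ++i) {
        mid = 0.5f * (lo + hi);
        const int count = static_cast<int>(std::count_if(
            scores.begin(), scores.end(), [&](float s) { return s >= mid; }));
        if (count < top_k_left) {
            hi = mid;   // too few blocks kept: lower the threshold
        } else if (count > top_k_right) {
            lo = mid;   // too many blocks kept: raise the threshold
        } else {
            break;      // kept count already inside [top_k_left, top_k_right]
        }
    }
    return mid;
}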
194
custom_ops/gpu_ops/moba_attn/moba_encoder_attn/softmax.hpp
Normal file
@@ -0,0 +1,194 @@
|
||||
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cmath>
|
||||
|
||||
#include <cute/tensor.hpp>
|
||||
|
||||
#include <cutlass/numeric_types.h>
|
||||
|
||||
#include "../moba_attn_utils.hpp"
|
||||
|
||||
using namespace cute;
|
||||
|
||||
template<bool zero_init=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename Operator>
|
||||
__device__ __forceinline__ void thread_reduce_(Tensor<Engine0, Layout0> const &tensor, Tensor<Engine1, Layout1> &summary, Operator &op) {
|
||||
static_assert(Layout0::rank == 2, "Only support 2D Tensor");
|
||||
static_assert(Layout1::rank == 1, "Only support 1D Tensor");
|
||||
CUTE_STATIC_ASSERT_V(size<0>(summary) == size<0>(tensor));
|
||||
#pragma unroll
|
||||
for (int mi = 0; mi < size<0>(tensor); mi++) {
|
||||
summary(mi) = zero_init ? tensor(mi, 0) : op(summary(mi), tensor(mi, 0));
|
||||
#pragma unroll
|
||||
for (int ni = 1; ni < size<1>(tensor); ni++) {
|
||||
summary(mi) = op(summary(mi), tensor(mi, ni));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template<typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename Operator>
|
||||
__device__ __forceinline__ void quad_allreduce_(Tensor<Engine0, Layout0> &dst, Tensor<Engine1, Layout1> &src, Operator &op) {
|
||||
CUTE_STATIC_ASSERT_V(size(dst) == size(src));
|
||||
#pragma unroll
|
||||
for (int i = 0; i < size(dst); i++){
|
||||
dst(i) = Allreduce<4>::run(src(i), op);
|
||||
}
|
||||
}
|
||||
|
||||
template<bool zero_init=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename Operator>
|
||||
__device__ __forceinline__ void reduce_(Tensor<Engine0, Layout0> const& tensor, Tensor<Engine1, Layout1> &summary, Operator &op) {
|
||||
thread_reduce_<zero_init>(tensor, summary, op);
|
||||
quad_allreduce_(summary, summary, op);
|
||||
}
|
||||
|
||||
template<bool zero_init=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1>
|
||||
__device__ __forceinline__ void reduce_max(Tensor<Engine0, Layout0> const& tensor, Tensor<Engine1, Layout1> &max){
|
||||
MaxOp<float> max_op;
|
||||
reduce_<zero_init>(tensor, max, max_op);
|
||||
}
|
||||
|
||||
template<bool zero_init=true, bool warp_reduce=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1>
|
||||
__device__ __forceinline__ void reduce_sum(Tensor<Engine0, Layout0> const& tensor, Tensor<Engine1, Layout1> &sum){
|
||||
SumOp<float> sum_op;
|
||||
thread_reduce_<zero_init>(tensor, sum, sum_op);
|
||||
if constexpr (warp_reduce) { quad_allreduce_(sum, sum, sum_op); }
|
||||
}
|
||||
|
||||
__forceinline__ __device__ __half2 half_exp(__half2 x) {
|
||||
uint32_t tmp_out, tmp_in;
|
||||
tmp_in = reinterpret_cast<uint32_t&>(x);
|
||||
asm ("ex2.approx.f16x2 %0, %1;\n"
|
||||
: "=r"(tmp_out)
|
||||
: "r"(tmp_in));
|
||||
__half2 out = reinterpret_cast<__half2&>(tmp_out);
|
||||
return out;
|
||||
}
|
||||
|
||||
// Apply the exp to all the elements.
|
||||
template <bool zero_init=false, typename Engine0, typename Layout0, typename Engine1, typename Layout1>
|
||||
__forceinline__ __device__ void max_scale_exp2_sum(Tensor<Engine0, Layout0> &tensor, Tensor<Engine1, Layout1> &max, Tensor<Engine1, Layout1> &sum, const float scale) {
|
||||
static_assert(Layout0::rank == 2, "Only support 2D Tensor");
static_assert(Layout1::rank == 1, "Only support 1D Tensor");
CUTE_STATIC_ASSERT_V(size<0>(max) == size<0>(tensor));
|
||||
#pragma unroll
|
||||
for (int mi = 0; mi < size<0>(tensor); ++mi) {
|
||||
MaxOp<float> max_op;
|
||||
max(mi) = zero_init ? tensor(mi, 0) : max_op(max(mi), tensor(mi, 0));
|
||||
#pragma unroll
|
||||
for (int ni = 1; ni < size<1>(tensor); ni++) {
|
||||
max(mi) = max_op(max(mi), tensor(mi, ni));
|
||||
}
|
||||
max(mi) = Allreduce<4>::run(max(mi), max_op);
|
||||
const float max_scaled = max(mi) == -INFINITY ? 0.f : max(mi) * scale;
|
||||
sum(mi) = 0;
|
||||
#pragma unroll
|
||||
for (int ni = 0; ni < size<1>(tensor); ++ni) {
|
||||
tensor(mi, ni) = exp2f(tensor(mi, ni) * scale - max_scaled);
|
||||
sum(mi) += tensor(mi, ni);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
template <typename Engine0, typename Layout0, typename Engine1, typename Layout1>
|
||||
__forceinline__ __device__ void scale_apply_exp2(Tensor<Engine0, Layout0> &tensor, Tensor<Engine1, Layout1> const &max, const float scale) {
|
||||
static_assert(Layout0::rank == 2, "Only support 2D Tensor");
|
||||
static_assert(Layout1::rank == 1, "Only support 1D Tensor");
|
||||
CUTE_STATIC_ASSERT_V(size<0>(max) == size<0>(tensor));
|
||||
#pragma unroll
|
||||
for (int mi = 0; mi < size<0>(tensor); ++mi) {
|
||||
const float max_scaled = max(mi) * scale;
|
||||
#pragma unroll
|
||||
for (int ni = 0; ni < size<1>(tensor); ++ni) {
|
||||
tensor(mi, ni) = exp2f(tensor(mi, ni) * scale - max_scaled);
|
||||
}
|
||||
}
|
||||
}
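Note on the exp2 convention above (inferred from these formulas, not stated elsewhere in the diff): since exp(scale·x − scale·m) = 2^((x − m)·scale·log2 e), a caller that passes softmax_scale_log2 = scale·log2(e) gets the cheaper exp2f path, and finalize below recovers the natural-log log-sum-exp as row_max·(softmax_scale_log2·ln 2) + log(row_sum).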
|
||||
|
||||
|
||||
template <int kNRows>
|
||||
struct Softmax {
|
||||
|
||||
using TensorT = decltype(make_tensor<float>(Shape<Int<kNRows>>{}));
|
||||
TensorT row_max, row_sum;
|
||||
|
||||
CUTLASS_DEVICE Softmax() {};
|
||||
|
||||
template<bool Is_first, bool Check_inf=false, typename Tensor0>
|
||||
__forceinline__ __device__ TensorT max(Tensor0 &acc_s, float softmax_scale_log2) {
|
||||
Tensor scores = make_tensor(acc_s.data(), convert_layout_acc_rowcol(acc_s.layout()));
|
||||
static_assert(decltype(size<0>(scores))::value == kNRows);
|
||||
TensorT scores_scale;
|
||||
if constexpr (Is_first) {
|
||||
reduce_max</*zero_init=*/true>(scores, row_max);
|
||||
cute::fill(scores_scale, 1.f);
|
||||
} else {
|
||||
Tensor scores_max_prev = make_fragment_like(row_max);
|
||||
cute::copy(row_max, scores_max_prev);
|
||||
reduce_max</*zero_init=*/false>(scores, row_max);
|
||||
#pragma unroll
|
||||
for (int mi = 0; mi < size(row_max); ++mi) {
|
||||
float scores_max_cur = row_max(mi);
|
||||
scores_scale(mi) = exp2f((scores_max_prev(mi) - scores_max_cur) * softmax_scale_log2);
|
||||
row_sum(mi) *= scores_scale(mi);
|
||||
}
|
||||
}
|
||||
return scores_scale;
|
||||
};
|
||||
|
||||
template<bool Is_first, typename Tensor0>
|
||||
__forceinline__ __device__ TensorT online_softmax(Tensor0 &acc_s, float softmax_scale_log2) {
|
||||
Tensor scores = make_tensor(acc_s.data(), convert_layout_acc_rowcol(acc_s.layout()));
|
||||
static_assert(decltype(size<0>(scores))::value == kNRows);
|
||||
TensorT scores_scale;
|
||||
if constexpr (Is_first) {
|
||||
reduce_max</*zero_init=*/true>(scores, row_max);
|
||||
scale_apply_exp2(scores, row_max, softmax_scale_log2);
|
||||
reduce_sum</*zero_init=*/true, /*warp_reduce=*/false>(scores, row_sum);
|
||||
cute::fill(scores_scale, 1.f);
|
||||
} else {
|
||||
scale_apply_exp2(scores, row_max, softmax_scale_log2);
|
||||
reduce_sum</*zero_init=*/false, /*warp_reduce=*/false>(scores, row_sum);
|
||||
}
|
||||
return scores_scale;
|
||||
};
|
||||
|
||||
__forceinline__ __device__ TensorT finalize(float softmax_scale_log2) {
|
||||
SumOp<float> sum_op;
|
||||
quad_allreduce_(row_sum, row_sum, sum_op);
|
||||
TensorT scores_scale;
|
||||
#pragma unroll
|
||||
for (int mi = 0; mi < size(row_max); ++mi) {
|
||||
float sum = row_sum(mi);
|
||||
float inv_sum = 1.0f / sum;
|
||||
row_sum(mi) = row_max(mi) * (softmax_scale_log2 * float(M_LN2)) + __logf(sum);
|
||||
scores_scale(mi) = inv_sum;
|
||||
}
|
||||
return scores_scale;
|
||||
};
|
||||
|
||||
template<typename Tensor1>
|
||||
__forceinline__ __device__ void rescale_o(Tensor1 &acc_o, TensorT const &scores_scale) {
|
||||
Tensor acc_o_rowcol = make_tensor(acc_o.data(), convert_layout_acc_rowcol(acc_o.layout()));
|
||||
static_assert(decltype(size<0>(acc_o_rowcol))::value == kNRows);
|
||||
#pragma unroll
|
||||
for (int mi = 0; mi < size(row_max); ++mi) {
|
||||
#pragma unroll
|
||||
for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) {
|
||||
acc_o_rowcol(mi, ni) *= scores_scale(mi);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
};
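The struct above implements the standard online (streaming) softmax recurrence. Below is a minimal scalar sketch of the same recurrence, runnable on the host; it is illustrative only and uses expf instead of the kernel's exp2f/log2-scale trick.

// Illustrative host-side sketch of the online-softmax recurrence (not part of this diff).
#include <cmath>
#include <cstdio>
#include <vector>

int main() {
    // Toy attention scores and values for a single query row.
    std::vector<float> scores = {0.5f, 2.0f, -1.0f, 3.0f, 0.0f, 1.5f};
    std::vector<float> values = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f};
    const size_t tile = 2;  // scores arrive tile by tile, like K/V blocks in the kernel

    float row_max = -INFINITY, row_sum = 0.0f, acc_o = 0.0f;
    for (size_t start = 0; start < scores.size(); start += tile) {
        // 1. update the running max (Softmax::max)
        float prev_max = row_max;
        for (size_t i = start; i < start + tile; ++i) row_max = fmaxf(row_max, scores[i]);
        // 2. rescale previous partial sums when the max grows (rescale_o)
        float scale = (prev_max == -INFINITY) ? 1.0f : expf(prev_max - row_max);
        row_sum *= scale;
        acc_o *= scale;
        // 3. accumulate the new tile (online_softmax)
        for (size_t i = start; i < start + tile; ++i) {
            float p = expf(scores[i] - row_max);
            row_sum += p;
            acc_o += p * values[i];
        }
    }
    // 4. finalize: divide once by the accumulated denominator
    printf("softmax-weighted sum = %f\n", acc_o / row_sum);
    return 0;
}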
|
||||
@@ -0,0 +1,288 @@
|
||||
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "paddle/extension.h"
|
||||
#include "moba_attn/moba_attn_utils.hpp"
|
||||
#include "moba_attn/moba_attn.h"
|
||||
|
||||
template <typename T, int kBlockSize, int kHeadDim>
|
||||
__global__ void get_kv_from_cache_c16_kernel(
|
||||
T * k_input,
|
||||
T * v_input,
|
||||
const int * seq_len_encoder,
|
||||
const int * seq_len_decoder,
|
||||
const int * cu_seq_k,
|
||||
const T * cache_k,
|
||||
const T * cache_v,
|
||||
const int * block_tables,
|
||||
const int kv_head_num,
|
||||
const int head_dim,
|
||||
const int batch_size,
|
||||
const int max_input_length,
|
||||
const int max_blocks_per_seq) {
|
||||
|
||||
const int block_idx = blockIdx.x;
|
||||
int bidh = blockIdx.y;
|
||||
const int bidb = blockIdx.z;
|
||||
const int seq_len = seq_len_decoder[bidb] + seq_len_encoder[bidb];
|
||||
const int tidx = threadIdx.x;
|
||||
const int base_token_idx = block_idx * kBlockSize;
|
||||
|
||||
if (base_token_idx >= seq_len || seq_len_encoder[bidb] == 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
constexpr int kPackSize = 16 / sizeof(T);
|
||||
|
||||
const int row_idx = tidx / (kHeadDim / kPackSize);
|
||||
const int col_idx = tidx % (kHeadDim / kPackSize) * kPackSize;
|
||||
const int physical_block_number = block_tables[bidb * max_blocks_per_seq + block_idx];
|
||||
|
||||
|
||||
const int remain_tokens = seq_len - base_token_idx;
|
||||
|
||||
if (bidh < kv_head_num) {
|
||||
const int cache_offset = physical_block_number * kv_head_num * kBlockSize * kHeadDim + bidh * kBlockSize * kHeadDim + col_idx;
|
||||
const int base_store_idx = (base_token_idx + cu_seq_k[bidb]) * kv_head_num * kHeadDim + bidh * kHeadDim + col_idx;
|
||||
#pragma unroll
|
||||
for (int i = row_idx; i < kBlockSize; i += 128 / (kHeadDim / kPackSize)) {
|
||||
if (i < remain_tokens) {
|
||||
*reinterpret_cast<float4*>(k_input + base_store_idx + i * kv_head_num * kHeadDim) = *reinterpret_cast<const float4*>(cache_k + cache_offset + i * kHeadDim);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
bidh -= kv_head_num;
|
||||
const int cache_offset = physical_block_number * kv_head_num * kBlockSize * kHeadDim + bidh * kBlockSize * kHeadDim + col_idx;
|
||||
const int base_store_idx = (base_token_idx + cu_seq_k[bidb]) * kv_head_num * kHeadDim + bidh * kHeadDim + col_idx;
|
||||
#pragma unroll
|
||||
for (int i = row_idx; i < kBlockSize; i += 128 / (kHeadDim / kPackSize)) {
|
||||
if (i < remain_tokens) {
|
||||
*reinterpret_cast<float4*>(v_input + base_store_idx + i * kv_head_num * kHeadDim) = *reinterpret_cast<const float4*>(cache_v + cache_offset + i * kHeadDim);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
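From the indexing above, the kernel gathers a paged cache laid out as [num_blocks, kv_head_num, block_size, head_dim] into the contiguous [total_tokens, kv_head_num, head_dim] buffers k_input / v_input; the launch below sets grid.y = 2 * kv_head_num so that the lower half of the y-blocks copies K and the upper half copies V for the same head.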
|
||||
|
||||
template <typename T>
|
||||
void get_kv_from_cache(
|
||||
T * k_input,
|
||||
T * v_input,
|
||||
const int * seq_len_encoder,
|
||||
const int * seq_len_decoder,
|
||||
const int * cu_seq_k,
|
||||
const void * cache_k,
|
||||
const void * cache_v,
|
||||
const int * block_tables,
|
||||
const T * cache_k_dequant_scale,
|
||||
const T * cache_v_dequant_scale,
|
||||
const T * cache_k_zero_points,
|
||||
const T * cache_v_zero_points,
|
||||
const int kv_head_num,
|
||||
const int head_dim,
|
||||
const int max_seq_k,
|
||||
const int batch_size,
|
||||
const int max_input_length,
|
||||
const int max_blocks_per_seq,
|
||||
const std::string &cache_quant_type_str,
|
||||
cudaStream_t stream) {
|
||||
|
||||
constexpr int kThreads = 128;
|
||||
constexpr int kHeadDim = 128;
|
||||
assert(kHeadDim == head_dim);
|
||||
constexpr int kBlockSize = 64;
|
||||
if (cache_quant_type_str == "none") {
|
||||
dim3 grid_dims;
|
||||
grid_dims.x = (max_seq_k + kBlockSize - 1) / kBlockSize;
|
||||
grid_dims.y = kv_head_num * 2;
|
||||
grid_dims.z = batch_size;
|
||||
get_kv_from_cache_c16_kernel<T, kBlockSize, kHeadDim><<<grid_dims, kThreads, 0, stream>>>(
|
||||
k_input,
|
||||
v_input,
|
||||
seq_len_encoder,
|
||||
seq_len_decoder,
|
||||
cu_seq_k,
|
||||
reinterpret_cast<const T*>(cache_k),
|
||||
reinterpret_cast<const T*>(cache_v),
|
||||
block_tables,
|
||||
kv_head_num,
|
||||
head_dim,
|
||||
batch_size,
|
||||
max_input_length,
|
||||
max_blocks_per_seq);
|
||||
} else {
|
||||
PD_THROW("Only supported cache_quant_type_str in ['none'].");
|
||||
}
|
||||
}
|
||||
|
||||
void GetKVFromCache(
|
||||
const paddle::Tensor& k_input,
|
||||
const paddle::Tensor& v_input,
|
||||
const paddle::Tensor& cu_seq_k,
|
||||
const paddle::Tensor& seq_len_encoder,
|
||||
const paddle::Tensor& seq_len_decoder,
|
||||
const paddle::Tensor& cache_k,
|
||||
const paddle::Tensor& cache_v,
|
||||
const paddle::Tensor& block_tables,
|
||||
const paddle::optional<paddle::Tensor>& cache_k_dequant_scale,
|
||||
const paddle::optional<paddle::Tensor>& cache_v_dequant_scale,
|
||||
const paddle::optional<paddle::Tensor>& cache_k_zero_points,
|
||||
const paddle::optional<paddle::Tensor>& cache_v_zero_points,
|
||||
const int head_num,
|
||||
const int kv_head_num,
|
||||
const int head_dim,
|
||||
const int max_input_length,
|
||||
const int max_seq_k,
|
||||
const std::string &cache_quant_type_str) {
|
||||
|
||||
if (k_input.dtype() == paddle::DataType::FLOAT16) {
|
||||
using T = phi::dtype::float16;
|
||||
using cute_type = typename cuteType<T>::type;
|
||||
get_kv_from_cache<cute_type>(
|
||||
reinterpret_cast<cute_type*>(const_cast<T*>(k_input.data<T>())),
|
||||
reinterpret_cast<cute_type*>(const_cast<T*>(v_input.data<T>())),
|
||||
seq_len_encoder.data<int>(),
|
||||
seq_len_decoder.data<int>(),
|
||||
cu_seq_k.data<int>(),
|
||||
cache_k.data(),
|
||||
cache_v.data(),
|
||||
block_tables.data<int>(),
|
||||
cache_k_dequant_scale ? reinterpret_cast<cute_type*>(const_cast<T*>(cache_k_dequant_scale.get().data<T>())) : nullptr,
|
||||
cache_v_dequant_scale ? reinterpret_cast<cute_type*>(const_cast<T*>(cache_v_dequant_scale.get().data<T>())) : nullptr,
|
||||
cache_k_zero_points ? reinterpret_cast<cute_type*>(const_cast<T*>(cache_k_zero_points.get().data<T>())) : nullptr,
|
||||
cache_v_zero_points ? reinterpret_cast<cute_type*>(const_cast<T*>(cache_v_zero_points.get().data<T>())) : nullptr,
|
||||
kv_head_num,
|
||||
head_dim,
|
||||
max_seq_k,
|
||||
seq_len_encoder.dims()[0],
|
||||
max_input_length,
|
||||
block_tables.dims()[1],
|
||||
cache_quant_type_str,
|
||||
k_input.stream());
|
||||
} else if (k_input.dtype() == paddle::DataType::BFLOAT16) {
|
||||
using T = phi::dtype::bfloat16;
|
||||
using cute_type = typename cuteType<T>::type;
|
||||
get_kv_from_cache<cute_type>(
|
||||
reinterpret_cast<cute_type*>(const_cast<T*>(k_input.data<T>())),
|
||||
reinterpret_cast<cute_type*>(const_cast<T*>(v_input.data<T>())),
|
||||
seq_len_encoder.data<int>(),
|
||||
seq_len_decoder.data<int>(),
|
||||
cu_seq_k.data<int>(),
|
||||
cache_k.data(),
|
||||
cache_v.data(),
|
||||
block_tables.data<int>(),
|
||||
cache_k_dequant_scale ? reinterpret_cast<cute_type*>(const_cast<T*>(cache_k_dequant_scale.get().data<T>())) : nullptr,
|
||||
cache_v_dequant_scale ? reinterpret_cast<cute_type*>(const_cast<T*>(cache_v_dequant_scale.get().data<T>())) : nullptr,
|
||||
cache_k_zero_points ? reinterpret_cast<cute_type*>(const_cast<T*>(cache_k_zero_points.get().data<T>())) : nullptr,
|
||||
cache_v_zero_points ? reinterpret_cast<cute_type*>(const_cast<T*>(cache_v_zero_points.get().data<T>())) : nullptr,
|
||||
kv_head_num,
|
||||
head_dim,
|
||||
max_seq_k,
|
||||
seq_len_encoder.dims()[0],
|
||||
max_input_length,
|
||||
block_tables.dims()[1],
|
||||
cache_quant_type_str,
|
||||
k_input.stream());
|
||||
}
|
||||
}
|
||||
|
||||
__global__ void get_cur_cu_seq_len_k_kernel(
|
||||
const int* __restrict__ seq_lens_encoder,
|
||||
const int* __restrict__ seq_lens_decoder,
|
||||
const int* __restrict__ seq_lens_this_time,
|
||||
int* __restrict__ cu_seqlens_k,
|
||||
int* __restrict__ cu_seq_q_pack,
|
||||
int* __restrict__ q_pack_tokens,
|
||||
const int pack_size,
|
||||
const int bsz) {
|
||||
|
||||
int total_tokens = 0;
|
||||
cu_seqlens_k[0] = 0;
|
||||
cu_seq_q_pack[0] = 0;
|
||||
|
||||
for (uint32_t bid = 0; bid < bsz; bid++) {
|
||||
int cache_len = seq_lens_decoder[bid];
|
||||
const int q_len = seq_lens_encoder[bid];
|
||||
if (q_len <= 0) {
|
||||
cache_len = 0;
|
||||
}
|
||||
total_tokens += (cache_len + q_len);
|
||||
cu_seqlens_k[bid + 1] = total_tokens;
|
||||
cu_seq_q_pack[bid + 1] = cu_seq_q_pack[bid] + (q_len + pack_size -1) / pack_size * pack_size;
|
||||
}
|
||||
q_pack_tokens[0] = cu_seq_q_pack[bsz];
|
||||
}
|
||||
|
||||
std::vector<paddle::Tensor> GetCurCuSeqLenk(
|
||||
const paddle::Tensor& seq_lens_encoder,
|
||||
const paddle::Tensor& seq_lens_decoder,
|
||||
const paddle::Tensor& seq_lens_this_time,
|
||||
const int pack_size) {
|
||||
auto stream = seq_lens_decoder.stream();
|
||||
auto place = seq_lens_decoder.place();
|
||||
int bsz = seq_lens_this_time.shape()[0];
|
||||
|
||||
paddle::Tensor cu_seq_q_pack = paddle::empty({bsz + 1}, paddle::DataType::INT32, place);
|
||||
paddle::Tensor cu_seqlens_k = paddle::empty({bsz + 1}, paddle::DataType::INT32, place);
|
||||
paddle::Tensor q_pack_tokens = paddle::empty({1}, paddle::DataType::INT32, place);
|
||||
|
||||
get_cur_cu_seq_len_k_kernel<<<1, 1, 0, stream>>>(
|
||||
seq_lens_encoder.data<int>(),
|
||||
seq_lens_decoder.data<int>(),
|
||||
seq_lens_this_time.data<int>(),
|
||||
cu_seqlens_k.data<int>(),
|
||||
cu_seq_q_pack.data<int>(),
|
||||
q_pack_tokens.data<int>(),
|
||||
pack_size,
|
||||
bsz
|
||||
);
|
||||
|
||||
auto q_pack_tokens_cpu = q_pack_tokens.copy_to(paddle::CPUPlace(), true);
|
||||
return {cu_seq_q_pack, cu_seqlens_k, q_pack_tokens_cpu};
|
||||
}
|
||||
|
||||
PD_BUILD_OP(get_kv_from_cache)
|
||||
.Inputs({
|
||||
"k_input",
|
||||
"v_input",
|
||||
"cu_seq_k",
|
||||
"seq_len_encoder",
|
||||
"seq_len_decoder",
|
||||
"cache_k",
|
||||
"cache_v",
|
||||
"block_tables",
|
||||
paddle::Optional("cache_k_dequant_scale"),
|
||||
paddle::Optional("cache_v_dequant_scale"),
|
||||
paddle::Optional("cache_k_zero_points"),
|
||||
paddle::Optional("cache_v_zero_points")})
|
||||
.Attrs({
|
||||
"head_num: int",
|
||||
"kv_head_num: int",
|
||||
"head_dim: int",
|
||||
"max_input_length: int",
|
||||
"max_seq_k: int",
|
||||
"cache_quant_type_str: std::string"})
|
||||
.Outputs({"k_input_out", "v_input_out"})
|
||||
.SetInplaceMap({{"k_input", "k_input_out"},
|
||||
{"v_input", "v_input_out"}})
|
||||
.SetKernelFn(PD_KERNEL(GetKVFromCache));
|
||||
|
||||
PD_BUILD_OP(get_cur_cu_seq_len_k)
|
||||
.Inputs({
|
||||
"seq_lens_encoder",
|
||||
"seq_lens_decoder",
|
||||
"seq_lens_this_time"})
|
||||
.Attrs({
|
||||
"pack_size: int"})
|
||||
.Outputs({"cu_seq_q_pack", "cu_seqlens_k", "q_pack_tokens"})
|
||||
.SetKernelFn(PD_KERNEL(GetCurCuSeqLenk));
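As a concrete illustration of what get_cur_cu_seq_len_k produces, the following host-side sketch (toy sizes, not part of the diff) reproduces the prefix sums computed by the single-thread kernel above.

// Host-side sketch of the prefix sums produced by get_cur_cu_seq_len_k_kernel
// (toy inputs, illustrative only).
#include <cstdio>
#include <vector>

int main() {
    std::vector<int> seq_lens_encoder = {3, 0, 5};
    std::vector<int> seq_lens_decoder = {2, 7, 0};
    const int pack_size = 128;
    const int bsz = 3;
    std::vector<int> cu_seqlens_k(bsz + 1, 0), cu_seq_q_pack(bsz + 1, 0);

    int total_tokens = 0;
    for (int bid = 0; bid < bsz; ++bid) {
        int cache_len = seq_lens_decoder[bid];
        const int q_len = seq_lens_encoder[bid];
        if (q_len <= 0) cache_len = 0;  // requests without prefill contribute no K here
        total_tokens += cache_len + q_len;
        cu_seqlens_k[bid + 1] = total_tokens;
        // query lengths are rounded up to the pack size
        cu_seq_q_pack[bid + 1] = cu_seq_q_pack[bid] + (q_len + pack_size - 1) / pack_size * pack_size;
    }
    // Expected: cu_seqlens_k = {0, 5, 5, 10}, cu_seq_q_pack = {0, 128, 128, 256}, q_pack_tokens = 256
    printf("q_pack_tokens = %d\n", cu_seq_q_pack[bsz]);
    return 0;
}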
|
||||
221
custom_ops/gpu_ops/moba_attn/moba_process/moba_mlp_einsum.cu
Normal file
@@ -0,0 +1,221 @@
|
||||
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "paddle/extension.h"
|
||||
#include "moba_attn/moba_attn_utils.hpp"
|
||||
#include "moba_attn/moba_attn.h"
|
||||
|
||||
|
||||
template <typename T, int moba_block_size, int kHeadDim, int kMaxN>
|
||||
__global__ void moba_mlp_einsum_kernel(
|
||||
const T * src_data,
|
||||
const T * weight_data,
|
||||
const int * seq_lens_encoder,
|
||||
const int * seq_lens_decoder,
|
||||
const int * cu_seq_k,
|
||||
T * dst_data,
|
||||
const int head_num) {
|
||||
|
||||
constexpr int kPackSize = 16 / sizeof(T);
|
||||
const int block_idx = blockIdx.x;
|
||||
const int bidh = blockIdx.y;
|
||||
const int bidb = blockIdx.z;
|
||||
const int tidx = threadIdx.x;
|
||||
const int lane_id = tidx % 32;
|
||||
const int warp_id = tidx / 32;
|
||||
|
||||
__align__(16) __shared__ T local_sum_mem[128 / 32 * kHeadDim];
|
||||
|
||||
const int seq_len_encoder = seq_lens_encoder[bidb];
|
||||
const int seq_len_decoder = seq_len_encoder + seq_lens_decoder[bidb];
|
||||
|
||||
const int seq_len_this_block = seq_len_decoder - block_idx * moba_block_size;
|
||||
|
||||
if (seq_len_encoder == 0 || seq_len_this_block <= 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
|
||||
using SrcType = Vec<T, kPackSize>;
|
||||
|
||||
constexpr int tidx_per_row = kHeadDim / kPackSize;
|
||||
|
||||
const int row_idx = tidx / tidx_per_row;
|
||||
const int col_idx = tidx % tidx_per_row * kPackSize;
|
||||
|
||||
const int src_base_idx = cu_seq_k[bidb] * head_num * kHeadDim + block_idx * moba_block_size * head_num * kHeadDim + bidh * kHeadDim + row_idx * head_num * kHeadDim + col_idx;
|
||||
const int weight_base_idx = bidh * kHeadDim * moba_block_size + row_idx * kHeadDim + col_idx;
|
||||
|
||||
constexpr int step = 128 / tidx_per_row;
|
||||
|
||||
SrcType sums, src, weight;
|
||||
|
||||
sums.set_zero();
|
||||
|
||||
for (int i = 0; i < moba_block_size; i += step) {
|
||||
if (i >= seq_len_this_block) {
|
||||
break;
|
||||
}
|
||||
src.load_from(src_data + src_base_idx + i * head_num * kHeadDim);
|
||||
weight.load_from(weight_data + weight_base_idx + i * kHeadDim);
|
||||
sums.fma(src, weight);
|
||||
}
|
||||
|
||||
SrcType neighbor;
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < kPackSize; i+=2) {
|
||||
*reinterpret_cast<int32_t*>(neighbor.data.elt + i) = __shfl_down_sync(0xffffffff, *reinterpret_cast<int32_t*>(sums.data.elt + i), 16);
|
||||
}
|
||||
|
||||
sums.add(neighbor);
|
||||
|
||||
if (lane_id < 16) {
|
||||
sums.store_to(local_sum_mem + warp_id * kHeadDim + lane_id * kPackSize);
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
using pack_half = std::conditional_t<std::is_same<T, phi::dtype::float16>::value, __half2, nv_bfloat162>;
|
||||
pack_half * local_sum_mem_half = reinterpret_cast<pack_half*>(local_sum_mem);
|
||||
|
||||
if (tidx < kHeadDim / 2) {
|
||||
pack_half local_sum_half = local_sum_mem_half[tidx];
|
||||
#pragma unroll
|
||||
for (int i = 1; i < 4; i++) {
|
||||
local_sum_half += local_sum_mem_half[tidx + i * (kHeadDim / 2)];
|
||||
}
|
||||
local_sum_mem_half[tidx] = local_sum_half;
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
const int store_row_id = tidx / (kHeadDim / kPackSize);
|
||||
const int store_col_id = tidx % (kHeadDim / kPackSize) * kPackSize;
|
||||
|
||||
sums.load_from(local_sum_mem + store_col_id);
|
||||
|
||||
const int base_store_idx = bidb * kMaxN * head_num * kHeadDim + (block_idx * (moba_block_size / 128) + store_row_id) * head_num * kHeadDim + bidh * kHeadDim + store_col_id;
|
||||
|
||||
if (store_row_id < moba_block_size / 128) {
|
||||
sums.store_to(dst_data + base_store_idx);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
template <typename T, int kHeadDim, int kMaxN>
|
||||
void moba_mlp_einsum(
|
||||
const T * src_data,
|
||||
const T * weight_data,
|
||||
const int * seq_lens_encoder,
|
||||
const int * seq_lens_decoder,
|
||||
const int * cu_seq_k,
|
||||
T * dst_data,
|
||||
const int moba_block_size,
|
||||
const int max_seq_len,
|
||||
const int head_num,
|
||||
const int batch_size,
|
||||
cudaStream_t stream) {
|
||||
|
||||
dim3 grid_dims;
|
||||
grid_dims.x = (max_seq_len + moba_block_size - 1) / moba_block_size;
|
||||
grid_dims.y = head_num;
|
||||
grid_dims.z = batch_size;
|
||||
|
||||
if (moba_block_size == 1024) {
|
||||
moba_mlp_einsum_kernel<T, 1024, kHeadDim, kMaxN><<<grid_dims, 128, 0, stream>>>(
|
||||
src_data,
|
||||
weight_data,
|
||||
seq_lens_encoder,
|
||||
seq_lens_decoder,
|
||||
cu_seq_k,
|
||||
dst_data,
|
||||
head_num);
|
||||
} else if (moba_block_size == 128) {
|
||||
moba_mlp_einsum_kernel<T, 128, kHeadDim, kMaxN><<<grid_dims, 128, 0, stream>>>(
|
||||
src_data,
|
||||
weight_data,
|
||||
seq_lens_encoder,
|
||||
seq_lens_decoder,
|
||||
cu_seq_k,
|
||||
dst_data,
|
||||
head_num);
|
||||
} else {
|
||||
PADDLE_THROW(phi::errors::Unimplemented(
|
||||
"MobaMlpEinsum not implemented for moba_block_size = %d", moba_block_size));
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
|
||||
std::vector<paddle::Tensor> MobaMlpEinsum(
|
||||
const paddle::Tensor& k_input,
|
||||
const paddle::Tensor& attn_gate_weight,
|
||||
const paddle::Tensor& seq_lens_encoder,
|
||||
const paddle::Tensor& seq_lens_decoder,
|
||||
const paddle::Tensor& cu_seq_k,
|
||||
const int max_seq_len,
|
||||
const int kv_head_num) {
|
||||
|
||||
const int kHeadDim = 128;
|
||||
const int kMaxN = 1024;
|
||||
const int moba_block_size = attn_gate_weight.dims()[1];
|
||||
const int batch_size = seq_lens_encoder.dims()[0];
|
||||
paddle::Tensor k_gate_weight = paddle::zeros({batch_size, kMaxN, kv_head_num, kHeadDim}, k_input.dtype(), k_input.place());
|
||||
|
||||
if (k_input.dtype() == paddle::DataType::FLOAT16) {
|
||||
using T = phi::dtype::float16;
|
||||
moba_mlp_einsum<T, kHeadDim, kMaxN>(
|
||||
const_cast<T*>(k_input.data<T>()),
|
||||
const_cast<T*>(attn_gate_weight.data<T>()),
|
||||
const_cast<int*>(seq_lens_encoder.data<int>()),
|
||||
const_cast<int*>(seq_lens_decoder.data<int>()),
|
||||
const_cast<int*>(cu_seq_k.data<int>()),
|
||||
k_gate_weight.data<T>(),
|
||||
moba_block_size,
|
||||
max_seq_len,
|
||||
kv_head_num,
|
||||
batch_size,
|
||||
k_input.stream()
|
||||
);
|
||||
} else if (k_input.dtype() == paddle::DataType::BFLOAT16) {
|
||||
using T = phi::dtype::bfloat16;
|
||||
moba_mlp_einsum<T, kHeadDim, kMaxN>(
|
||||
const_cast<T*>(k_input.data<T>()),
|
||||
const_cast<T*>(attn_gate_weight.data<T>()),
|
||||
const_cast<int*>(seq_lens_encoder.data<int>()),
|
||||
const_cast<int*>(seq_lens_decoder.data<int>()),
|
||||
const_cast<int*>(cu_seq_k.data<int>()),
|
||||
k_gate_weight.data<T>(),
|
||||
moba_block_size,
|
||||
max_seq_len,
|
||||
kv_head_num,
|
||||
batch_size,
|
||||
k_input.stream()
|
||||
);
|
||||
}
|
||||
return {k_gate_weight};
|
||||
}
|
||||
|
||||
PD_BUILD_OP(moba_mlp_einsum)
|
||||
.Inputs({
|
||||
"k_input",
|
||||
"attn_gate_weight",
|
||||
"seq_lens_encoder",
|
||||
"seq_lens_decoder",
|
||||
"cu_seq_k"})
|
||||
.Attrs({
|
||||
"max_seq_len: int",
|
||||
"kv_head_num: int"})
|
||||
.Outputs({"k_gate"})
|
||||
.SetKernelFn(PD_KERNEL(MobaMlpEinsum));
|
||||
465
custom_ops/gpu_ops/moba_attn/moba_process/moba_qk_gemm.cu
Normal file
@@ -0,0 +1,465 @@
|
||||
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "paddle/extension.h"
|
||||
#include "moba_attn/moba_attn_utils.hpp"
|
||||
#include "moba_attn/moba_attn.h"
|
||||
#include "cute/atom/mma_atom.hpp"
|
||||
#include "cutlass/gemm/collective/collective_builder.hpp"
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/layout/layout.h"
|
||||
#include "cutlass/numeric_types.h"
|
||||
#include "cutlass/pipeline/pipeline.hpp"
|
||||
#include "cutlass/cluster_launch.hpp"
|
||||
#include "cutlass/arch/reg_reconfig.h"
|
||||
|
||||
template <typename input_type, int kBlockM, int kBlockN, int kMobaBlockSize, int kMaxN, int kHeadDim, bool is_split_kv>
|
||||
__global__ void qk_gemm_kernel(
|
||||
const input_type *q_input,
|
||||
const input_type *k_gate_mean,
|
||||
input_type *qk_gate_weight,
|
||||
const int *seq_len_encoder,
|
||||
const int *seq_len_decoder,
|
||||
const int *cu_seq_q,
|
||||
const int *cu_seq_k,
|
||||
const int use_moba_seq_limit,
|
||||
const int max_seq_q,
|
||||
const int max_seq_k,
|
||||
const int head_num,
|
||||
const int kv_head_num,
|
||||
const int kGQA_groupsize) {
|
||||
|
||||
using TileShape_MNK = Shape<Int<kBlockM>, Int<kBlockN>, Int<kHeadDim>>;
|
||||
|
||||
using SmemLayoutAtomQ = decltype(
|
||||
cutlass::gemm::collective::detail::ss_smem_selector<
|
||||
GMMA::Major::K, input_type,
|
||||
decltype(cute::get<0>(TileShape_MNK{})),
|
||||
decltype(cute::get<2>(TileShape_MNK{}))>());
|
||||
using SmemLayoutQ = decltype(tile_to_shape(SmemLayoutAtomQ{}, select<0, 2>(TileShape_MNK{})));
|
||||
|
||||
using SmemLayoutAtomK = decltype(
|
||||
cutlass::gemm::collective::detail::ss_smem_selector<
|
||||
GMMA::Major::K, input_type, decltype(cute::get<1>(TileShape_MNK{})),
|
||||
decltype(cute::get<2>(TileShape_MNK{}))>());
|
||||
using SmemLayoutK = decltype(tile_to_shape(SmemLayoutAtomK{}, select<1, 2>(TileShape_MNK{})));
|
||||
|
||||
using SmemLayoutAtomQK = decltype(
|
||||
cutlass::gemm::collective::detail::ss_smem_selector<
|
||||
GMMA::Major::K, input_type,
|
||||
decltype(cute::get<0>(TileShape_MNK{})),
|
||||
decltype(cute::get<1>(TileShape_MNK{}))>());
|
||||
|
||||
using SmemLayoutQK = decltype(tile_to_shape(SmemLayoutAtomQK{}, select<0, 1>(TileShape_MNK{})));
|
||||
|
||||
|
||||
using MMA_Atom_Arch = std::conditional_t<
|
||||
std::is_same_v<input_type, cutlass::half_t>,
|
||||
MMA_Atom<SM80_16x8x16_F32F16F16F32_TN>,
|
||||
MMA_Atom<SM80_16x8x16_F32BF16BF16F32_TN>
|
||||
>;
|
||||
|
||||
using ValLayoutMNK = std::conditional_t<
|
||||
is_split_kv,
|
||||
Layout<Shape<_1,_4,_1>>,
|
||||
Layout<Shape<_4,_1,_1>>
|
||||
>;
|
||||
|
||||
using PermutationMNK = std::conditional_t<
|
||||
is_split_kv,
|
||||
Tile<_16,_64,_16>,
|
||||
Tile<_64,_16,_16>
|
||||
>;
|
||||
|
||||
using TiledMma = TiledMMA<
|
||||
MMA_Atom_Arch,
|
||||
ValLayoutMNK,
|
||||
PermutationMNK>;
|
||||
|
||||
using SmemCopyAtom = Copy_Atom<SM75_U32x4_LDSM_N, input_type>;
|
||||
using SmemCopyAtomQK = Copy_Atom<cute::SM90_U32x4_STSM_N, input_type>;
|
||||
|
||||
constexpr int kNThreads = 128;
|
||||
constexpr int kThreadPerValue = 16 / sizeof(input_type);
|
||||
constexpr int kThreadsPerRow = kHeadDim / kThreadPerValue;
|
||||
constexpr int kThreadsPerRowQK = kBlockN / kThreadPerValue;
|
||||
|
||||
using GmemLayoutAtom = Layout<
|
||||
Shape <Int<kNThreads / kThreadsPerRow>, Int<kThreadsPerRow>>,
|
||||
Stride<Int<kThreadsPerRow>, _1>>;
|
||||
|
||||
using GmemTiledCopy = decltype(
|
||||
make_tiled_copy(Copy_Atom<
|
||||
SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>, input_type>{},
|
||||
GmemLayoutAtom{},
|
||||
Layout<Shape<_1, Int<kThreadPerValue>>>{}));
|
||||
|
||||
using GmemLayoutAtomQK = Layout<
|
||||
Shape <Int<kNThreads / kThreadsPerRowQK>, Int<kThreadsPerRowQK>>,
|
||||
Stride<Int<kThreadsPerRowQK>, _1>>;
|
||||
|
||||
using GmemTiledCopyQK = decltype(
|
||||
make_tiled_copy(Copy_Atom<
|
||||
UniversalCopy<cutlass::uint128_t>, input_type>{},
|
||||
GmemLayoutAtomQK{},
|
||||
Layout<Shape<_1, Int<kThreadPerValue>>>{}));
|
||||
|
||||
int mn_block = blockIdx.x;
|
||||
const int bidb = blockIdx.y;
|
||||
const int bidh = blockIdx.z;
|
||||
const int bidh_k = bidh / kGQA_groupsize;
|
||||
const int tidx = threadIdx.x;
|
||||
|
||||
extern __shared__ char smem_[];
|
||||
|
||||
const int seq_len_q = seq_len_encoder[bidb];
|
||||
const int seq_len_k = seq_len_decoder[bidb];
|
||||
const int seq_len_qk = seq_len_q + seq_len_k;
|
||||
|
||||
int q_head_stride;
|
||||
const int k_head_stride = kv_head_num * kHeadDim;
|
||||
int qk_head_stride;
|
||||
int offset_q;
|
||||
int offset_k;
|
||||
int offset_qk;
|
||||
int remain_q_seq;
|
||||
|
||||
if constexpr (is_split_kv) {
|
||||
if (seq_len_k < use_moba_seq_limit || seq_len_k == 0) {
|
||||
return;
|
||||
}
|
||||
mn_block *= kBlockN;
|
||||
q_head_stride = kHeadDim;
|
||||
qk_head_stride = kMaxN;
|
||||
if (mn_block >= (seq_len_k + kMobaBlockSize - 1) / kMobaBlockSize) {
|
||||
return;
|
||||
}
|
||||
offset_q = cu_seq_q[bidb] * head_num * kHeadDim + bidh * kGQA_groupsize * kHeadDim;
|
||||
offset_k = (bidb * kMaxN + mn_block) * k_head_stride + bidh * kHeadDim;
|
||||
offset_qk = bidb * head_num * kMaxN + bidh * kGQA_groupsize * kMaxN + mn_block;
|
||||
remain_q_seq = kGQA_groupsize;
|
||||
} else {
|
||||
if (seq_len_q == 0 || seq_len_qk < use_moba_seq_limit) {
|
||||
return;
|
||||
}
|
||||
q_head_stride = head_num * kHeadDim;
|
||||
qk_head_stride = head_num * kMaxN;
|
||||
mn_block *= kBlockM;
|
||||
if (mn_block >= seq_len_q) {
|
||||
return;
|
||||
}
|
||||
offset_q = (cu_seq_q[bidb] + mn_block) * q_head_stride + bidh * kHeadDim;
|
||||
offset_k = bidb * kMaxN * k_head_stride + bidh_k * kHeadDim;
|
||||
offset_qk = (cu_seq_q[bidb] + mn_block) * qk_head_stride + bidh * kMaxN;
|
||||
remain_q_seq = seq_len_q - mn_block;
|
||||
}
|
||||
|
||||
Tensor gQ = make_tensor(make_gmem_ptr(q_input + offset_q),
|
||||
Shape<Int<kBlockM>, Int<kHeadDim>>{},
|
||||
make_stride(q_head_stride, _1{}));
|
||||
Tensor gK = make_tensor(make_gmem_ptr(k_gate_mean + offset_k),
|
||||
Shape<Int<kBlockN>, Int<kHeadDim>>{},
|
||||
make_stride(k_head_stride, _1{}));
|
||||
Tensor gQK = make_tensor(make_gmem_ptr(qk_gate_weight + offset_qk),
|
||||
Shape<Int<kBlockM>, Int<kBlockN>>{},
|
||||
make_stride(qk_head_stride, _1{}));
|
||||
|
||||
Tensor sK = make_tensor(make_smem_ptr(reinterpret_cast<input_type *>(smem_)), SmemLayoutK{});
|
||||
Tensor sQ = make_tensor(sK.data() + size(sK), SmemLayoutQ{});
|
||||
Tensor sQK = make_tensor(sK.data() + size(sK), SmemLayoutQK{});
|
||||
|
||||
auto gmem_tiled_copy = GmemTiledCopy{};
|
||||
auto gmem_tiled_copy_qk = GmemTiledCopyQK{};
|
||||
auto gmem_thr_copy = gmem_tiled_copy.get_thread_slice(tidx);
|
||||
auto gmem_thr_copy_qk = gmem_tiled_copy_qk.get_thread_slice(tidx);
|
||||
|
||||
|
||||
Tensor tQgQ = gmem_thr_copy.partition_S(gQ);
|
||||
Tensor tQsQ = gmem_thr_copy.partition_D(sQ);
|
||||
|
||||
Tensor tKgK = gmem_thr_copy.partition_S(gK);
|
||||
Tensor tKsK = gmem_thr_copy.partition_D(sK);
|
||||
|
||||
Tensor tQKgQK = gmem_thr_copy_qk.partition_S(gQK);
|
||||
Tensor tQKsQK = gmem_thr_copy_qk.partition_D(sQK);
|
||||
|
||||
|
||||
Tensor cQ = make_identity_tensor(make_shape(kBlockM, kHeadDim));
|
||||
Tensor tQcQ = gmem_thr_copy.partition_S(cQ);
|
||||
|
||||
Tensor cK = make_identity_tensor(make_shape(kBlockN, kHeadDim));
|
||||
Tensor tKcK = gmem_thr_copy.partition_S(cK);
|
||||
|
||||
Tensor cQK = make_identity_tensor(make_shape(kBlockM, kBlockN));
|
||||
Tensor tQKcQK = gmem_thr_copy.partition_S(cQK);
|
||||
|
||||
if (remain_q_seq >= kBlockM) {
|
||||
copy(gmem_tiled_copy, tQgQ, tQsQ, tQcQ);
|
||||
} else {
|
||||
copy<false>(gmem_tiled_copy, tQgQ, tQsQ, tQcQ, remain_q_seq);
|
||||
}
|
||||
copy(gmem_tiled_copy, tKgK, tKsK, tKcK);
|
||||
|
||||
cute::cp_async_fence();
|
||||
|
||||
TiledMma tiled_mma;
|
||||
auto thr_mma = tiled_mma.get_thread_slice(tidx);
|
||||
Tensor tSrQ = thr_mma.partition_fragment_A(sQ);
|
||||
Tensor tSrK = thr_mma.partition_fragment_B(sK);
|
||||
|
||||
Tensor acc_s = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kBlockN>>{});
|
||||
|
||||
auto smem_tiled_copy_Q = make_tiled_copy_A(SmemCopyAtom{}, tiled_mma);
|
||||
auto smem_thr_copy_Q = smem_tiled_copy_Q.get_thread_slice(tidx);
|
||||
Tensor tSsQ = smem_thr_copy_Q.partition_S(sQ);
|
||||
|
||||
auto smem_tiled_copy_K = make_tiled_copy_B(SmemCopyAtom{}, tiled_mma);
|
||||
auto smem_thr_copy_K = smem_tiled_copy_K.get_thread_slice(tidx);
|
||||
Tensor tSsK = smem_thr_copy_K.partition_S(sK);
|
||||
|
||||
auto smem_tiled_copy_QK = make_tiled_copy_C(SmemCopyAtomQK{}, tiled_mma);
|
||||
auto smem_thr_copy_QK = smem_tiled_copy_QK.get_thread_slice(tidx);
|
||||
Tensor tsQK = smem_thr_copy_QK.partition_D(sQK);
|
||||
|
||||
const int n_blocks = is_split_kv ? 1 : cute::ceil_div(cute::ceil_div(seq_len_qk, kMobaBlockSize), kBlockN);
|
||||
|
||||
#pragma unroll
|
||||
for (int n_block = 0; n_block < n_blocks; ++n_block) {
|
||||
clear(acc_s);
|
||||
cp_async_wait<0>();
|
||||
__syncthreads();
|
||||
if (n_block == 0) {
|
||||
gemm(acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_thr_copy_Q, smem_thr_copy_K, smem_tiled_copy_Q, smem_tiled_copy_K);
|
||||
} else {
|
||||
gemm<true>(acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_thr_copy_Q, smem_thr_copy_K, smem_tiled_copy_Q, smem_tiled_copy_K);
|
||||
}
|
||||
if constexpr (!is_split_kv) {
|
||||
if (n_block < n_blocks - 1) {
|
||||
__syncthreads();
|
||||
tKgK.data() = tKgK.data() + kBlockN * k_head_stride;
|
||||
copy(gmem_tiled_copy, tKgK, tKsK, tKcK);
|
||||
cute::cp_async_fence();
|
||||
}
|
||||
}
|
||||
|
||||
Tensor rS = convert_type<input_type>(acc_s);
|
||||
Tensor trQK = smem_thr_copy_QK.retile_S(rS);
|
||||
cute::copy(smem_tiled_copy_QK, trQK, tsQK);
|
||||
|
||||
__syncthreads();
|
||||
if (remain_q_seq >= kBlockM) {
|
||||
copy(gmem_tiled_copy_qk, tQKsQK, tQKgQK, tQKcQK);
|
||||
} else {
|
||||
copy<false>(gmem_tiled_copy_qk, tQKsQK, tQKgQK, tQKcQK, remain_q_seq);
|
||||
}
|
||||
if constexpr (!is_split_kv) {
|
||||
__syncthreads();
|
||||
tQKgQK.data() = tQKgQK.data() + kBlockN;
|
||||
}
|
||||
}
|
||||
}
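In effect (following the pointer offsets above), the kernel computes gate logits qk_gate_weight[t, h, n] = Σ_d q[t, h, d] · k_block_means[b, n, kv_head(h), d] for each query token t against the per-MoBA-block key means (on the split-KV path it appears to iterate the heads of a GQA group instead of tokens); these are the scores that moba_qk_sort_encoder later ranks to select blocks.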
|
||||
|
||||
template <typename input_type, int kBlockM, int kBlockN, int kMobaBlockSize, int kMaxN, bool is_split_kv>
|
||||
void qk_gemm(
|
||||
const input_type *q_input,
|
||||
const input_type *k_gate_mean,
|
||||
input_type *qk_gate_weight,
|
||||
const int *seq_len_encoder,
|
||||
const int *seq_len_decoder,
|
||||
const int *cu_seq_q,
|
||||
const int *cu_seq_k,
|
||||
const int use_moba_seq_limit,
|
||||
const int max_seq_q,
|
||||
const int max_seq_k,
|
||||
const int head_num,
|
||||
const int kv_head_num,
|
||||
const int bsz,
|
||||
cudaStream_t stream) {
|
||||
|
||||
const int gqa_group_size = head_num / kv_head_num;
|
||||
|
||||
dim3 grid_dims;
|
||||
const int num_m_block = (max_seq_q + kBlockM - 1) / kBlockM;
|
||||
const int num_n_block = ((max_seq_k + kMobaBlockSize - 1) / kMobaBlockSize + kBlockN - 1) / kBlockN;
|
||||
|
||||
if (is_split_kv) {
|
||||
grid_dims.x = num_n_block;
|
||||
grid_dims.z = kv_head_num;
|
||||
} else {
|
||||
grid_dims.x = num_m_block;
|
||||
grid_dims.z = head_num;
|
||||
}
|
||||
grid_dims.y = bsz;
|
||||
|
||||
constexpr int kHeadDim = 128;
|
||||
constexpr int smemq = kBlockM * kHeadDim * sizeof(input_type);
|
||||
constexpr int smemk = kBlockN * kHeadDim * sizeof(input_type);
|
||||
constexpr int smemqk = kBlockM * kBlockN * sizeof(input_type);
|
||||
const int smem_size = smemk + max(smemq, smemqk);
|
||||
|
||||
auto kernel = &qk_gemm_kernel<input_type, kBlockM, kBlockN, kMobaBlockSize, kMaxN, kHeadDim, is_split_kv>;
|
||||
|
||||
if (smem_size >= 48 * 1024) {
|
||||
cudaFuncSetAttribute(
|
||||
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
|
||||
}
|
||||
|
||||
kernel<<<grid_dims, 128, smem_size, stream>>>(
|
||||
q_input,
|
||||
k_gate_mean,
|
||||
qk_gate_weight,
|
||||
seq_len_encoder,
|
||||
seq_len_decoder,
|
||||
cu_seq_q,
|
||||
cu_seq_k,
|
||||
use_moba_seq_limit,
|
||||
max_seq_q,
|
||||
max_seq_k,
|
||||
head_num,
|
||||
kv_head_num,
|
||||
gqa_group_size);
|
||||
}
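A quick check of the shared-memory budget above: with kHeadDim = 128 and fp16/bf16 inputs, the non-split path (kBlockM = kBlockN = 128) needs smemq = smemk = smemqk = 128·128·2 B = 32 KB, so smem_size = 32 KB + max(32 KB, 32 KB) = 64 KB, which exceeds the 48 KB default and is why the cudaFuncSetAttribute opt-in is needed; the split-KV path (kBlockM = 16) only needs 32 KB + 4 KB = 36 KB.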
|
||||
|
||||
|
||||
template <typename T>
|
||||
std::vector<paddle::Tensor> DispatchMobaQKGemm(
|
||||
const paddle::Tensor& q_input,
|
||||
const paddle::Tensor& k_block_means,
|
||||
const paddle::Tensor& seq_len_encoder,
|
||||
const paddle::Tensor& seq_len_decoder,
|
||||
const paddle::Tensor& cu_seq_q,
|
||||
const paddle::Tensor& cu_seq_k,
|
||||
const int max_seq_q,
|
||||
const int max_seq_k,
|
||||
const int head_num,
|
||||
const int kv_head_num,
|
||||
const bool is_split_kv,
|
||||
const int use_moba_seq_limit) {
|
||||
|
||||
constexpr int kMobaBlockSize = 128;
|
||||
constexpr int kMaxN = 1024;
|
||||
const int batch_size = seq_len_encoder.dims()[0];
|
||||
using cute_type = typename cuteType<T>::type;
|
||||
if (is_split_kv) {
|
||||
paddle::Tensor qk_gate_weight = paddle::empty({batch_size, head_num, kMaxN}, q_input.dtype(), q_input.place());
|
||||
qk_gemm<cute_type, 16, kMobaBlockSize, kMobaBlockSize, kMaxN, true>(
|
||||
reinterpret_cast<const cute_type*>(q_input.data<T>()),
|
||||
reinterpret_cast<const cute_type*>(k_block_means.data<T>()),
|
||||
reinterpret_cast<cute_type*>(qk_gate_weight.data<T>()),
|
||||
seq_len_encoder.data<int>(),
|
||||
seq_len_decoder.data<int>(),
|
||||
cu_seq_q.data<int>(),
|
||||
cu_seq_k.data<int>(),
|
||||
use_moba_seq_limit,
|
||||
max_seq_q,
|
||||
max_seq_k,
|
||||
head_num,
|
||||
kv_head_num,
|
||||
batch_size,
|
||||
q_input.stream()
|
||||
);
|
||||
return {qk_gate_weight};
|
||||
} else {
|
||||
constexpr int kBlockM = 128;
|
||||
constexpr int kBlockN = 128;
|
||||
const int token_num = q_input.dims()[0];
|
||||
paddle::Tensor qk_gate_weight = paddle::empty({token_num, head_num, kMaxN}, q_input.dtype(), q_input.place());
|
||||
qk_gemm<cute_type, kBlockM, kBlockN, kMobaBlockSize, kMaxN, false>(
|
||||
reinterpret_cast<cute_type *>(const_cast<T*>(q_input.data<T>())),
|
||||
reinterpret_cast<cute_type *>(const_cast<T*>(k_block_means.data<T>())),
|
||||
reinterpret_cast<cute_type *>(qk_gate_weight.data<T>()),
|
||||
seq_len_encoder.data<int>(),
|
||||
seq_len_decoder.data<int>(),
|
||||
cu_seq_q.data<int>(),
|
||||
cu_seq_k.data<int>(),
|
||||
use_moba_seq_limit,
|
||||
max_seq_q,
|
||||
max_seq_k,
|
||||
head_num,
|
||||
kv_head_num,
|
||||
batch_size,
|
||||
q_input.stream());
|
||||
return {qk_gate_weight};
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<paddle::Tensor> MobaQKGemm(
|
||||
const paddle::Tensor& q_input,
|
||||
const paddle::Tensor& k_block_means,
|
||||
const paddle::Tensor& seq_len_encoder,
|
||||
const paddle::Tensor& seq_len_decoder,
|
||||
const paddle::Tensor& cu_seq_q,
|
||||
const paddle::Tensor& cu_seq_k,
|
||||
const int max_seq_q,
|
||||
const int max_seq_k,
|
||||
const int head_num,
|
||||
const int kv_head_num,
|
||||
const bool is_split_kv,
|
||||
const int use_moba_seq_limit) {
|
||||
|
||||
if (q_input.dtype() == paddle::DataType::FLOAT16) {
|
||||
return std::move(
|
||||
DispatchMobaQKGemm<phi::dtype::float16>(
|
||||
q_input,
|
||||
k_block_means,
|
||||
seq_len_encoder,
|
||||
seq_len_decoder,
|
||||
cu_seq_q,
|
||||
cu_seq_k,
|
||||
max_seq_q,
|
||||
max_seq_k,
|
||||
head_num,
|
||||
kv_head_num,
|
||||
is_split_kv,
|
||||
use_moba_seq_limit
|
||||
)
|
||||
);
|
||||
} else if (q_input.dtype() == paddle::DataType::BFLOAT16) {
|
||||
return std::move(
|
||||
DispatchMobaQKGemm<phi::dtype::bfloat16>(
|
||||
q_input,
|
||||
k_block_means,
|
||||
seq_len_encoder,
|
||||
seq_len_decoder,
|
||||
cu_seq_q,
|
||||
cu_seq_k,
|
||||
max_seq_q,
|
||||
max_seq_k,
|
||||
head_num,
|
||||
kv_head_num,
|
||||
is_split_kv,
|
||||
use_moba_seq_limit
|
||||
)
|
||||
);
|
||||
} else {
PD_THROW("MobaQKGemm only supports float16 and bfloat16 inputs.");
}
}
|
||||
|
||||
PD_BUILD_OP(moba_qk_gemm)
|
||||
.Inputs({
|
||||
"q_input",
|
||||
"k_block_means",
|
||||
"seq_len_encoder",
|
||||
"seq_len_decoder",
|
||||
"cu_seq_q",
|
||||
"cu_seq_k"})
|
||||
.Attrs({
|
||||
"max_seq_q: int",
|
||||
"max_seq_k: int",
|
||||
"head_num: int",
|
||||
"kv_head_num: int",
|
||||
"is_split_kv: bool",
|
||||
"use_moba_seq_limit: int"})
|
||||
.Outputs({"qk_gate_weight"})
|
||||
.SetKernelFn(PD_KERNEL(MobaQKGemm));
|
||||
370
custom_ops/gpu_ops/moba_attn/moba_process/split_qkv_and_rope.cu
Normal file
@@ -0,0 +1,370 @@
|
||||
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "paddle/extension.h"
|
||||
#include "moba_attn/moba_attn_utils.hpp"
|
||||
#include "moba_attn/moba_attn.h"
|
||||
|
||||
template <typename input_type, int moba_block_size, int kBlockM, int kMaxN, int tokens_per_block, bool need_k_mean>
|
||||
__global__ void fused_block_mean_and_rope_kernel(
|
||||
const input_type *qkv_input,
|
||||
const input_type *qkv_bias,
|
||||
input_type *k_gate_mean,
|
||||
input_type *q_input,
|
||||
input_type *k_input,
|
||||
input_type *v_input,
|
||||
const float *rope_sin_cos,
|
||||
const int *seq_len_encoder,
|
||||
const int *seq_len_decoder,
|
||||
const int *cu_seq_q,
|
||||
const int *cu_seq_k,
|
||||
const int max_seq_q,
|
||||
const int max_seq_k,
|
||||
const int head_num,
|
||||
const int kv_head_num,
|
||||
const int max_input_length) {
|
||||
|
||||
constexpr int kPackSize = 16 / sizeof(input_type);
|
||||
constexpr int kHeadDim = 128;
|
||||
|
||||
using src_type = Vec<input_type, kPackSize>;
|
||||
|
||||
using rope_type = Vec<float, kPackSize / 2>;
|
||||
using pack_half = std::conditional_t<std::is_same<input_type, cutlass::half_t>::value, __half2, nv_bfloat162>;
|
||||
|
||||
__align__(16) __shared__ input_type local_sum_mem[128 / 32 * kHeadDim];
|
||||
|
||||
const int bidb = blockIdx.x;
|
||||
const int bidh = blockIdx.y;
|
||||
const int bidt_q = blockIdx.z * tokens_per_block;
|
||||
const int bidt_v = blockIdx.z * tokens_per_block;
|
||||
const int bidt_k = need_k_mean ? blockIdx.z * moba_block_size : blockIdx.z * tokens_per_block;
|
||||
const int tidx = threadIdx.x;
|
||||
const int lane_id = tidx % 32;
|
||||
const int warp_id = tidx / 32;
|
||||
const int seq_len = seq_len_encoder[bidb];
|
||||
const int seq_len_start = seq_len_decoder[bidb];
|
||||
|
||||
if (seq_len == 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
const int all_head_num = head_num + 2 * kv_head_num;
|
||||
const int hidden = all_head_num * kHeadDim;
|
||||
|
||||
const int row_idx = tidx / (kHeadDim / kPackSize);
|
||||
const int col_idx = tidx % (kHeadDim / kPackSize);
|
||||
|
||||
const int bias_idx = bidh * kHeadDim + col_idx * kPackSize;
|
||||
|
||||
src_type src, src_bias;
|
||||
rope_type sin, cos;
|
||||
|
||||
const bool need_add_bias = qkv_bias != nullptr;
|
||||
|
||||
if (need_add_bias) {
|
||||
src_bias.load_from(qkv_bias + bias_idx);
|
||||
}
|
||||
|
||||
if (bidh < head_num) {
|
||||
const int cur_token = bidt_q + row_idx;
|
||||
const float * cos_rope = rope_sin_cos + (cur_token + seq_len_start) * (kHeadDim / 2) + col_idx * (kPackSize / 2);
|
||||
const float * sin_rope = cos_rope + max_input_length * (kHeadDim / 2);
|
||||
|
||||
if (cur_token < seq_len) {
|
||||
src.load_from(qkv_input + cu_seq_q[bidb] * hidden + bias_idx + cur_token * hidden);
|
||||
|
||||
if (need_add_bias) {
|
||||
src.add(src_bias);
|
||||
}
|
||||
|
||||
sin.load_from(sin_rope);
|
||||
cos.load_from(cos_rope);
|
||||
apply_rotary_embedding<input_type, kPackSize>(src, cos, sin);
|
||||
|
||||
src.store_to(q_input + (cu_seq_q[bidb] + cur_token) * head_num * kHeadDim + bias_idx);
|
||||
}
|
||||
} else if (bidh < head_num + kv_head_num) {
|
||||
if constexpr (!need_k_mean) {
|
||||
const int cur_token = bidt_k + row_idx;
|
||||
const float * cos_rope = rope_sin_cos + (cur_token + seq_len_start) * (kHeadDim / 2) + col_idx * (kPackSize / 2);
|
||||
const float * sin_rope = cos_rope + max_input_length * (kHeadDim / 2);
|
||||
|
||||
if (cur_token < seq_len) {
|
||||
src.load_from(qkv_input + cu_seq_q[bidb] * hidden + bias_idx + cur_token * hidden);
|
||||
|
||||
if (need_add_bias) {
|
||||
src.add(src_bias);
|
||||
}
|
||||
|
||||
sin.load_from(sin_rope);
|
||||
cos.load_from(cos_rope);
|
||||
apply_rotary_embedding<input_type, kPackSize>(src, cos, sin);
|
||||
|
||||
src.store_to(k_input + (cu_seq_k[bidb] + cur_token) * kv_head_num * kHeadDim + bias_idx - head_num * kHeadDim);
|
||||
}
|
||||
} else {
|
||||
if (bidt_k >= seq_len) {
|
||||
return;
|
||||
}
|
||||
|
||||
src_type local_sum;
|
||||
local_sum.set_zero();
|
||||
|
||||
const input_type* qkv = qkv_input + cu_seq_q[bidb] * hidden + bias_idx;
|
||||
|
||||
for (int i = 0; i < moba_block_size; i += tokens_per_block) {
|
||||
const int cur_token = bidt_k + i + row_idx;
|
||||
if (cur_token < seq_len) {
|
||||
src.load_from(qkv + cur_token * hidden);
|
||||
|
||||
if (need_add_bias) {
|
||||
src.add(src_bias);
|
||||
}
|
||||
const float * cos_rope = rope_sin_cos + (cur_token + seq_len_start) * (kHeadDim / 2) + col_idx * (kPackSize / 2);
|
||||
const float * sin_rope = cos_rope + max_input_length * (kHeadDim / 2);
|
||||
sin.load_from(sin_rope);
|
||||
cos.load_from(cos_rope);
|
||||
|
||||
apply_rotary_embedding<input_type, kPackSize>(src, cos, sin);
|
||||
|
||||
src.store_to(k_input + (cu_seq_k[bidb] + cur_token) * kv_head_num * kHeadDim + bias_idx - head_num * kHeadDim);
|
||||
|
||||
local_sum.add(src);
|
||||
}
|
||||
}
|
||||
|
||||
src_type neighbor;
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < kPackSize; i+=2) {
|
||||
*reinterpret_cast<int32_t*>(neighbor.data.elt + i) = __shfl_down_sync(0xffffffff, *reinterpret_cast<int32_t*>(local_sum.data.elt + i), 16);
|
||||
}
|
||||
|
||||
local_sum.add(neighbor);
|
||||
|
||||
if (lane_id < 16) {
|
||||
local_sum.store_to(local_sum_mem + warp_id * kHeadDim + lane_id * kPackSize);
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
pack_half * local_sum_mem_half = reinterpret_cast<pack_half*>(local_sum_mem);
|
||||
|
||||
pack_half local_sum_half = local_sum_mem_half[tidx];
|
||||
|
||||
|
||||
if (tidx < kHeadDim / 2) {
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 1; i < 4; i++) {
|
||||
local_sum_half += local_sum_mem_half[tidx + i * (kHeadDim / 2)];
|
||||
}
|
||||
|
||||
float inv_tokens_sum = fdividef(1.0f, min(seq_len - bidt_k, moba_block_size));
|
||||
|
||||
local_sum_half *= float_2_half2<input_type>(inv_tokens_sum);
|
||||
|
||||
const int store_mean_idx = ((bidb * kMaxN + blockIdx.z + seq_len_start / moba_block_size) * kv_head_num * kHeadDim + (bidh - head_num) * kHeadDim) / 2 + tidx;
|
||||
|
||||
reinterpret_cast<pack_half*>(k_gate_mean)[store_mean_idx] = local_sum_half;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
const int cur_token = bidt_v + row_idx;
|
||||
|
||||
if (cur_token < seq_len) {
|
||||
src.load_from(qkv_input + cu_seq_q[bidb] * hidden + bias_idx + cur_token * hidden);
|
||||
if (need_add_bias) {
|
||||
src.add(src_bias);
|
||||
}
|
||||
|
||||
src.store_to(v_input + (cu_seq_k[bidb] + cur_token) * kv_head_num * kHeadDim + bias_idx - (head_num + kv_head_num) * kHeadDim);
|
||||
}
|
||||
}
|
||||
}
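Taken together, the kernel above splits the fused QKV buffer into separate q/k/v layouts, adds the optional bias, applies rotary embedding to Q and K, and, when k_gate_mean is provided, also accumulates the mean of the post-RoPE keys over each MoBA block, roughly k_mean[b, n, h, :] = (1/|B_n|) Σ_{t ∈ B_n} RoPE(k[t, h, :]), which then serves as that block's gate key in the qk gate GEMM.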
|
||||
|
||||
template <typename input_type, int moba_block_size, int kBlockM, int kMaxN>
|
||||
void fused_block_mean_and_rope(
|
||||
const input_type *qkv_input,
|
||||
const input_type *qkv_bias,
|
||||
input_type *k_gate_mean,
|
||||
input_type *q_input,
|
||||
input_type *k_input,
|
||||
input_type *v_input,
|
||||
const float *rope_sin_cos,
|
||||
const int *seq_len_encoder,
|
||||
const int *seq_len_decoder,
|
||||
const int *cu_seq_q,
|
||||
const int *cu_seq_k,
|
||||
const int max_seq_q,
|
||||
const int max_seq_k,
|
||||
const int head_num,
|
||||
const int kv_head_num,
|
||||
const int bsz,
|
||||
const int max_input_length,
|
||||
cudaStream_t stream) {
|
||||
|
||||
static_assert(moba_block_size >= 64, "moba_block_size must be at least 64");
|
||||
constexpr int kPackSize = 16 / sizeof(input_type);
|
||||
constexpr int kHeadDim = 128;
|
||||
constexpr int kThreads = 128;
|
||||
constexpr int tokens_per_block = kThreads / (kHeadDim / kPackSize);
|
||||
dim3 grid_dims;
|
||||
grid_dims.x = bsz;
|
||||
grid_dims.y = head_num + 2 * kv_head_num;
|
||||
grid_dims.z = (max_seq_q + tokens_per_block - 1) / tokens_per_block;
|
||||
|
||||
if (k_gate_mean != nullptr) {
|
||||
fused_block_mean_and_rope_kernel<input_type, moba_block_size, kBlockM, kMaxN, tokens_per_block, true>
|
||||
<<<grid_dims, kThreads, 0, stream>>>(
|
||||
qkv_input,
|
||||
qkv_bias,
|
||||
k_gate_mean,
|
||||
q_input,
|
||||
k_input,
|
||||
v_input,
|
||||
rope_sin_cos,
|
||||
seq_len_encoder,
|
||||
seq_len_decoder,
|
||||
cu_seq_q,
|
||||
cu_seq_k,
|
||||
max_seq_q,
|
||||
max_seq_k,
|
||||
head_num,
|
||||
kv_head_num,
|
||||
max_input_length);
|
||||
} else {
|
||||
fused_block_mean_and_rope_kernel<input_type, moba_block_size, kBlockM, kMaxN, tokens_per_block, false>
|
||||
<<<grid_dims, kThreads, 0, stream>>>(
|
||||
qkv_input,
|
||||
qkv_bias,
|
||||
k_gate_mean,
|
||||
q_input,
|
||||
k_input,
|
||||
v_input,
|
||||
rope_sin_cos,
|
||||
seq_len_encoder,
|
||||
seq_len_decoder,
|
||||
cu_seq_q,
|
||||
cu_seq_k,
|
||||
max_seq_q,
|
||||
max_seq_k,
|
||||
head_num,
|
||||
kv_head_num,
|
||||
max_input_length);
|
||||
}
|
||||
}
|
||||
|
||||
void FusedBlockMeanAndRope(
|
||||
const paddle::Tensor& qkv_out,
|
||||
const paddle::Tensor& k_block_means,
|
||||
const paddle::Tensor& q_input,
|
||||
const paddle::Tensor& k_input,
|
||||
const paddle::Tensor& v_input,
|
||||
const paddle::Tensor& rotary_embs,
|
||||
const paddle::Tensor& seq_len_encoder,
|
||||
const paddle::Tensor& seq_len_decoder,
|
||||
const paddle::Tensor& cu_seq_q,
|
||||
const paddle::Tensor& cu_seq_k,
|
||||
const paddle::optional<paddle::Tensor>& qkv_bias,
|
||||
const int head_num,
|
||||
const int kv_head_num,
|
||||
const int head_dim,
|
||||
const int max_input_length,
|
||||
const int max_seq_q,
|
||||
const int max_seq_k,
|
||||
const std::string &cache_quant_type_str) {
|
||||
|
||||
constexpr int kBlockM = 128;
|
||||
constexpr int kBlockN = 128;
|
||||
constexpr int kMobaBlockSize = 128;
|
||||
constexpr int kMaxN = 1024;
|
||||
|
||||
if (k_input.dtype() == paddle::DataType::FLOAT16) {
|
||||
using T = phi::dtype::float16;
|
||||
using cute_type = typename cuteType<T>::type;
|
||||
fused_block_mean_and_rope<cute_type, kMobaBlockSize, kBlockM, kMaxN>(
|
||||
reinterpret_cast<cute_type *>(const_cast<T*>(qkv_out.data<T>())),
|
||||
qkv_bias ? reinterpret_cast<cute_type *>(const_cast<T*>(qkv_bias.get().data<T>())) : nullptr,
|
||||
reinterpret_cast<cute_type *>(const_cast<T*>(k_block_means.data<T>())),
|
||||
reinterpret_cast<cute_type*>(const_cast<T*>(q_input.data<T>())),
|
||||
reinterpret_cast<cute_type*>(const_cast<T*>(k_input.data<T>())),
|
||||
reinterpret_cast<cute_type*>(const_cast<T*>(v_input.data<T>())),
|
||||
rotary_embs.data<float>(),
|
||||
seq_len_encoder.data<int>(),
|
||||
seq_len_decoder.data<int>(),
|
||||
cu_seq_q.data<int>(),
|
||||
cu_seq_k.data<int>(),
|
||||
max_seq_q,
|
||||
max_seq_k,
|
||||
head_num,
|
||||
kv_head_num,
|
||||
seq_len_encoder.dims()[0],
|
||||
max_input_length,
|
||||
qkv_out.stream());
|
||||
} else if (k_input.dtype() == paddle::DataType::BFLOAT16) {
|
||||
using T = phi::dtype::bfloat16;
|
||||
using cute_type = typename cuteType<T>::type;
|
||||
fused_block_mean_and_rope<cute_type, kMobaBlockSize, kBlockM, kMaxN>(
|
||||
reinterpret_cast<cute_type *>(const_cast<T*>(qkv_out.data<T>())),
|
||||
qkv_bias ? reinterpret_cast<cute_type *>(const_cast<T*>(qkv_bias.get().data<T>())) : nullptr,
|
||||
reinterpret_cast<cute_type *>(const_cast<T*>(k_block_means.data<T>())),
|
||||
reinterpret_cast<cute_type*>(const_cast<T*>(q_input.data<T>())),
|
||||
reinterpret_cast<cute_type*>(const_cast<T*>(k_input.data<T>())),
|
||||
reinterpret_cast<cute_type*>(const_cast<T*>(v_input.data<T>())),
|
||||
rotary_embs.data<float>(),
|
||||
seq_len_encoder.data<int>(),
|
||||
seq_len_decoder.data<int>(),
|
||||
cu_seq_q.data<int>(),
|
||||
cu_seq_k.data<int>(),
|
||||
max_seq_q,
|
||||
max_seq_k,
|
||||
head_num,
|
||||
kv_head_num,
|
||||
seq_len_encoder.dims()[0],
|
||||
max_input_length,
|
||||
qkv_out.stream());
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
PD_BUILD_OP(fused_block_mean_and_rope)
|
||||
.Inputs({
|
||||
"qkv_out",
|
||||
"k_block_means",
|
||||
"q_input",
|
||||
"k_input",
|
||||
"v_input",
|
||||
"rotary_embs",
|
||||
"seq_len_encoder",
|
||||
"seq_len_decoder",
|
||||
"cu_seq_q",
|
||||
"cu_seq_k",
|
||||
paddle::Optional("qkv_bias")})
|
||||
.Attrs({
|
||||
"head_num: int",
|
||||
"kv_head_num: int",
|
||||
"head_dim: int",
|
||||
"max_input_length: int",
|
||||
"max_seq_q: int",
|
||||
"max_seq_k: int",
|
||||
"cache_quant_type_str: std::string"})
|
||||
.Outputs({"q_input_out", "k_input_out", "v_input_out", "k_block_means_out"})
|
||||
.SetInplaceMap({{"q_input", "q_input_out"},
|
||||
{"k_input", "k_input_out"},
|
||||
{"v_input", "v_input_out"},
|
||||
{"k_block_means", "k_block_means_out"}})
|
||||
.SetKernelFn(PD_KERNEL(FusedBlockMeanAndRope));