Support GPT-OSS-BF16 (#4240)

* [Feature] AppendAtten support sinks & HEAD_DIM=64

* fix bug

* fix bug

* fix bug

* fix bug

* [Feature] support gpt-oss

* fix bug

* add mask

* support-gpt-oss

* support-gpt-oss

* fix long seq

* support wint8

* support wint8

* support wint8

* update test

* change sliding window init position

---------

Co-authored-by: ming1753 <ideaminghp@163.com>
Co-authored-by: YuBaoku <49938469+EmmonsCurse@users.noreply.github.com>
Co-authored-by: ming1753 <61511741+ming1753@users.noreply.github.com>
This commit is contained in:
Haonan Luo
2025-10-20 14:44:58 +08:00
committed by GitHub
parent 80a16c4c87
commit 1b9f351d21
32 changed files with 1502 additions and 172 deletions

View File

@@ -72,10 +72,10 @@ void AppendAttentionKernel(
const paddle::optional<paddle::Tensor>& cache_v_zp,
const paddle::optional<paddle::Tensor>& out_linear_shifts,
const paddle::optional<paddle::Tensor>& out_linear_smooths,
const paddle::optional<paddle::Tensor>& mask_offset,
const paddle::optional<paddle::Tensor>& kv_signal_data,
const paddle::optional<paddle::Tensor>& q_norm_weight,
const paddle::optional<paddle::Tensor>& k_norm_weight,
const paddle::optional<paddle::Tensor>& sinks,
const float rms_norm_eps,
const std::string& cache_quant_type_str,
const bool use_neox_rotary_style,
@@ -90,7 +90,8 @@ void AppendAttentionKernel(
const int encoder_max_partition_size,
const int speculate_max_draft_token_num,
const bool causal,
const bool speculate_decoder) {
const bool speculate_decoder,
const int sliding_window) {
typedef PDTraits<D> traits_;
typedef typename traits_::DataType DataType_;
typedef typename traits_::data_t data_t;
@@ -146,6 +147,7 @@ void AppendAttentionKernel(
cache_v_zp,
out_linear_shifts,
out_linear_smooths,
sinks,
seq_lens_this_time,
seq_lens_decoder,
seq_lens_encoder,
@@ -169,7 +171,8 @@ void AppendAttentionKernel(
lambda_is_decoder,
lambda_enable_prefill,
lambda_stream,
&fmha_out);
&fmha_out,
sliding_window);
};
if (max_enc_len_this_time > 0) {
@@ -428,6 +431,7 @@ std::vector<paddle::Tensor> AppendAttention(
const paddle::optional<paddle::Tensor>& kv_signal_data,
const paddle::optional<paddle::Tensor>& q_norm_weight,
const paddle::optional<paddle::Tensor>& k_norm_weight,
const paddle::optional<paddle::Tensor>& sinks,
const float rms_norm_eps,
const std::string& compute_dtype,
const std::string& cache_quant_type_str,
@@ -443,7 +447,8 @@ std::vector<paddle::Tensor> AppendAttention(
const int encoder_max_partition_size,
const int speculate_max_draft_token_num,
const bool causal,
const bool speculate_decoder) {
const bool speculate_decoder,
const int sliding_window) {
AppendAttnMetaData meta_data;
const auto& qkv_dims = qkv.dims();
@@ -550,10 +555,10 @@ std::vector<paddle::Tensor> AppendAttention(
cache_v_zp,
out_linear_shifts,
out_linear_smooths,
mask_offset,
kv_signal_data,
q_norm_weight,
k_norm_weight,
sinks,
rms_norm_eps,
cache_quant_type_str,
use_neox_rotary_style,
@@ -568,7 +573,8 @@ std::vector<paddle::Tensor> AppendAttention(
encoder_max_partition_size,
speculate_max_draft_token_num,
causal,
speculate_decoder);
speculate_decoder,
sliding_window);
};
@@ -630,6 +636,7 @@ std::vector<paddle::Tensor> AppendAttentionWithOutput(
const paddle::optional<paddle::Tensor>& kv_signal_data,
const paddle::optional<paddle::Tensor>& q_norm_weight,
const paddle::optional<paddle::Tensor>& k_norm_weight,
const paddle::optional<paddle::Tensor>& sinks,
const float rms_norm_eps,
const std::string& compute_dtype,
const std::string& cache_quant_type_str,
@@ -645,7 +652,8 @@ std::vector<paddle::Tensor> AppendAttentionWithOutput(
const int encoder_max_partition_size,
const int speculate_max_draft_token_num,
const bool causal,
const bool speculate_decoder) {
const bool speculate_decoder,
const int sliding_window) {
AppendAttnMetaData meta_data;
const auto& qkv_dims = qkv.dims();
@@ -704,10 +712,10 @@ std::vector<paddle::Tensor> AppendAttentionWithOutput(
cache_v_zp,
out_linear_shifts,
out_linear_smooths,
mask_offset,
kv_signal_data,
q_norm_weight,
k_norm_weight,
sinks,
rms_norm_eps,
cache_quant_type_str,
use_neox_rotary_style,
@@ -722,7 +730,8 @@ std::vector<paddle::Tensor> AppendAttentionWithOutput(
encoder_max_partition_size,
speculate_max_draft_token_num,
causal,
speculate_decoder);
speculate_decoder,
sliding_window);
};
phi::dtype::float16 fp16_dtype;
@@ -797,6 +806,7 @@ std::vector<std::vector<int64_t>> AppendAttentionInferShape(
const paddle::optional<std::vector<int64_t>>& kv_signal_data_shape,
const paddle::optional<std::vector<int64_t>>& q_norm_weight_shape,
const paddle::optional<std::vector<int64_t>>& k_norm_weight_shape,
const paddle::optional<std::vector<int64_t>>& sinks_shape,
const float rms_norm_eps,
const std::string& compute_dtype,
const std::string& cache_quant_type_str,
@@ -812,7 +822,8 @@ std::vector<std::vector<int64_t>> AppendAttentionInferShape(
const int encoder_max_partition_size,
const int speculate_max_draft_token_num,
const bool causal,
const bool speculate_decoder) {
const bool speculate_decoder,
const int sliding_window) {
const int token_num = qkv_shape[0];
const int kv_num_heads = key_cache_shape[1];
int head_dim = key_cache_shape[3];
@@ -860,6 +871,7 @@ std::vector<paddle::DataType> AppendAttentionInferDtype(
const paddle::optional<paddle::DataType>& kv_signal_data_dtype,
const paddle::optional<paddle::DataType>& q_norm_weight_dtype,
const paddle::optional<paddle::DataType>& k_norm_weight_dtype,
const paddle::optional<paddle::DataType>& sinks_dtype,
const float rms_norm_eps,
const std::string& compute_dtype,
const std::string& cache_quant_type_str,
@@ -875,7 +887,8 @@ std::vector<paddle::DataType> AppendAttentionInferDtype(
const int encoder_max_partition_size,
const int speculate_max_draft_token_num,
const bool causal,
const bool speculate_decoder) {
const bool speculate_decoder,
const int sliding_window) {
if (compute_dtype == "bf16") {
if (out_linear_in_scale > 0.0) {
if (fabs(quant_max_bound - 127.0f) < 0.000001) {
@@ -942,6 +955,7 @@ std::vector<std::vector<int64_t>> AppendAttentionWithOutputInferShape(
const paddle::optional<std::vector<int64_t>>& kv_signal_data_shape,
const paddle::optional<std::vector<int64_t>>& q_norm_weight_shape,
const paddle::optional<std::vector<int64_t>>& k_norm_weight_shape,
const paddle::optional<std::vector<int64_t>>& sinks_shape,
const float rms_norm_eps,
const std::string& compute_dtype,
const std::string& cache_quant_type_str,
@@ -957,7 +971,8 @@ std::vector<std::vector<int64_t>> AppendAttentionWithOutputInferShape(
const int encoder_max_partition_size,
const int speculate_max_draft_token_num,
const bool causal,
const bool speculate_decoder) {
const bool speculate_decoder,
const int sliding_window) {
return {fmha_out_shape};
}
@@ -998,6 +1013,7 @@ std::vector<paddle::DataType> AppendAttentionWithOutputInferDtype(
const paddle::optional<paddle::DataType>& kv_signal_data_dtype,
const paddle::optional<paddle::DataType>& q_norm_weight_dtype,
const paddle::optional<paddle::DataType>& k_norm_weight_dtype,
const paddle::optional<paddle::DataType>& sinks_dtype,
const float rms_norm_eps,
const std::string& compute_dtype,
const std::string& cache_quant_type_str,
@@ -1013,7 +1029,8 @@ std::vector<paddle::DataType> AppendAttentionWithOutputInferDtype(
const int encoder_max_partition_size,
const int speculate_max_draft_token_num,
const bool causal,
const bool speculate_decoder) {
const bool speculate_decoder,
const int sliding_window) {
return {fmha_out_dtype};
}
@@ -1054,7 +1071,8 @@ PD_BUILD_STATIC_OP(append_attention)
paddle::Optional("mask_offset"),
paddle::Optional("kv_signal_data"),
paddle::Optional("q_norm_weight"),
paddle::Optional("k_norm_weight")})
paddle::Optional("k_norm_weight"),
paddle::Optional("sinks")})
.Outputs({"fmha_out"})
.Attrs({"rms_norm_eps: float",
"compute_type: std::string",
@@ -1072,6 +1090,7 @@ PD_BUILD_STATIC_OP(append_attention)
"speculate_max_draft_token_num: int",
"causal: bool",
"speculate_decoder: bool",
"sliding_window: int",
})
.SetKernelFn(PD_KERNEL(AppendAttention))
.SetInferShapeFn(PD_INFER_SHAPE(AppendAttentionInferShape))
@@ -1113,7 +1132,8 @@ PD_BUILD_STATIC_OP(append_attention_with_output)
paddle::Optional("mask_offset"),
paddle::Optional("kv_signal_data"),
paddle::Optional("q_norm_weight"),
paddle::Optional("k_norm_weight")})
paddle::Optional("k_norm_weight"),
paddle::Optional("sinks")})
.Outputs({"fmha_out_out"})
.SetInplaceMap({{"fmha_out", "fmha_out_out"}})
.Attrs({"rms_norm_eps: float",
@@ -1132,6 +1152,7 @@ PD_BUILD_STATIC_OP(append_attention_with_output)
"speculate_max_draft_token_num: int",
"causal: bool",
"speculate_decoder: bool",
"sliding_window: int",
})
.SetKernelFn(PD_KERNEL(AppendAttentionWithOutput))
.SetInferShapeFn(PD_INFER_SHAPE(AppendAttentionWithOutputInferShape))

View File

@@ -36,6 +36,8 @@ void CascadeAppendAttentionC16Kernel(
shift_bias, // [num_kv_heads, head_dim]
const paddle::optional<paddle::Tensor>&
smooth_weight, // [num_kv_heads, head_dim]
const paddle::optional<paddle::Tensor>&
sinks, // [num_heads]
const paddle::Tensor& seq_lens_q,
const paddle::Tensor& seq_lens_kv,
const paddle::Tensor& seq_lens_encoder,
@@ -58,7 +60,8 @@ void CascadeAppendAttentionC16Kernel(
const bool is_decoder,
const bool enable_prefill,
cudaStream_t& stream,
paddle::Tensor* out) {
paddle::Tensor* out,
const int sliding_window) {
const auto token_num = meta_data.token_nums;
const auto block_size = meta_data.block_size;
const auto bsz = meta_data.batch_size;
@@ -99,6 +102,7 @@ void CascadeAppendAttentionC16Kernel(
attn_mask,
shift_bias,
smooth_weight,
sinks,
seq_lens_q,
seq_lens_kv,
seq_lens_encoder,
@@ -118,7 +122,8 @@ void CascadeAppendAttentionC16Kernel(
speculate_max_draft_token_num,
is_decoder,
stream,
out);
out,
sliding_window);
})})})})})})
}
@@ -142,6 +147,8 @@ template void CascadeAppendAttentionC16Kernel<paddle::bfloat16, paddle::bfloat16
shift_bias, // [num_kv_heads, head_dim]
const paddle::optional<paddle::Tensor>&
smooth_weight, // [num_kv_heads, head_dim]
const paddle::optional<paddle::Tensor>&
sinks, // [num_heads]
const paddle::Tensor& seq_lens_q,
const paddle::Tensor& seq_lens_kv,
const paddle::Tensor& seq_lens_encoder,
@@ -164,7 +171,8 @@ template void CascadeAppendAttentionC16Kernel<paddle::bfloat16, paddle::bfloat16
const bool is_decoder,
const bool enable_prefill,
cudaStream_t& stream,
paddle::Tensor* out);
paddle::Tensor* out,
const int sliding_window);
template void CascadeAppendAttentionC16Kernel<paddle::bfloat16, paddle::float8_e4m3fn>(
const AppendAttnMetaData& meta_data,
@@ -186,6 +194,8 @@ template void CascadeAppendAttentionC16Kernel<paddle::bfloat16, paddle::float8_e
shift_bias, // [num_kv_heads, head_dim]
const paddle::optional<paddle::Tensor>&
smooth_weight, // [num_kv_heads, head_dim]
const paddle::optional<paddle::Tensor>&
sinks, // [num_heads]
const paddle::Tensor& seq_lens_q,
const paddle::Tensor& seq_lens_kv,
const paddle::Tensor& seq_lens_encoder,
@@ -208,7 +218,8 @@ template void CascadeAppendAttentionC16Kernel<paddle::bfloat16, paddle::float8_e
const bool is_decoder,
const bool enable_prefill,
cudaStream_t& stream,
paddle::Tensor* out);
paddle::Tensor* out,
const int sliding_window);
template void CascadeAppendAttentionC16Kernel<paddle::bfloat16, int8_t>(
const AppendAttnMetaData& meta_data,
@@ -230,6 +241,8 @@ template void CascadeAppendAttentionC16Kernel<paddle::bfloat16, int8_t>(
shift_bias, // [num_kv_heads, head_dim]
const paddle::optional<paddle::Tensor>&
smooth_weight, // [num_kv_heads, head_dim]
const paddle::optional<paddle::Tensor>&
sinks, // [num_heads]
const paddle::Tensor& seq_lens_q,
const paddle::Tensor& seq_lens_kv,
const paddle::Tensor& seq_lens_encoder,
@@ -252,7 +265,8 @@ template void CascadeAppendAttentionC16Kernel<paddle::bfloat16, int8_t>(
const bool is_decoder,
const bool enable_prefill,
cudaStream_t& stream,
paddle::Tensor* out);
paddle::Tensor* out,
const int sliding_window);
template void CascadeAppendAttentionC16Kernel<paddle::float16, paddle::float16>(
const AppendAttnMetaData& meta_data,
@@ -274,6 +288,8 @@ template void CascadeAppendAttentionC16Kernel<paddle::float16, paddle::float16>(
shift_bias, // [num_kv_heads, head_dim]
const paddle::optional<paddle::Tensor>&
smooth_weight, // [num_kv_heads, head_dim]
const paddle::optional<paddle::Tensor>&
sinks, // [num_heads]
const paddle::Tensor& seq_lens_q,
const paddle::Tensor& seq_lens_kv,
const paddle::Tensor& seq_lens_encoder,
@@ -296,7 +312,8 @@ template void CascadeAppendAttentionC16Kernel<paddle::float16, paddle::float16>(
const bool is_decoder,
const bool enable_prefill,
cudaStream_t& stream,
paddle::Tensor* out);
paddle::Tensor* out,
const int sliding_window);
template void CascadeAppendAttentionC16Kernel<paddle::float16, paddle::float8_e4m3fn>(
const AppendAttnMetaData& meta_data,
@@ -318,6 +335,8 @@ template void CascadeAppendAttentionC16Kernel<paddle::float16, paddle::float8_e4
shift_bias, // [num_kv_heads, head_dim]
const paddle::optional<paddle::Tensor>&
smooth_weight, // [num_kv_heads, head_dim]
const paddle::optional<paddle::Tensor>&
sinks, // [num_heads]
const paddle::Tensor& seq_lens_q,
const paddle::Tensor& seq_lens_kv,
const paddle::Tensor& seq_lens_encoder,
@@ -340,7 +359,8 @@ template void CascadeAppendAttentionC16Kernel<paddle::float16, paddle::float8_e4
const bool is_decoder,
const bool enable_prefill,
cudaStream_t& stream,
paddle::Tensor* out);
paddle::Tensor* out,
const int sliding_window);
template void CascadeAppendAttentionC16Kernel<paddle::float16, int8_t>(
const AppendAttnMetaData& meta_data,
@@ -362,6 +382,8 @@ template void CascadeAppendAttentionC16Kernel<paddle::float16, int8_t>(
shift_bias, // [num_kv_heads, head_dim]
const paddle::optional<paddle::Tensor>&
smooth_weight, // [num_kv_heads, head_dim]
const paddle::optional<paddle::Tensor>&
sinks, // [num_heads]
const paddle::Tensor& seq_lens_q,
const paddle::Tensor& seq_lens_kv,
const paddle::Tensor& seq_lens_encoder,
@@ -384,4 +406,5 @@ template void CascadeAppendAttentionC16Kernel<paddle::float16, int8_t>(
const bool is_decoder,
const bool enable_prefill,
cudaStream_t& stream,
paddle::Tensor* out);
paddle::Tensor* out,
const int sliding_window);

View File

@@ -36,6 +36,8 @@ void CascadeAppendAttentionC4Kernel(
shift_bias, // [num_kv_heads, head_dim]
const paddle::optional<paddle::Tensor>&
smooth_weight, // [num_kv_heads, head_dim]
const paddle::optional<paddle::Tensor>&
sinks, // [num_heads]
const paddle::Tensor& seq_lens_q,
const paddle::Tensor& seq_lens_kv,
const paddle::Tensor& seq_lens_encoder,
@@ -58,7 +60,8 @@ void CascadeAppendAttentionC4Kernel(
const bool is_decoder,
const bool enable_prefill,
cudaStream_t& stream,
paddle::Tensor* out) {
paddle::Tensor* out,
const int sliding_window) {
const auto token_num = meta_data.token_nums;
const auto block_size = meta_data.block_size;
const auto bsz = meta_data.batch_size;
@@ -103,6 +106,7 @@ void CascadeAppendAttentionC4Kernel(
cache_v_zp,
shift_bias,
smooth_weight,
sinks,
seq_lens_q,
seq_lens_kv,
seq_lens_encoder,
@@ -122,7 +126,8 @@ void CascadeAppendAttentionC4Kernel(
speculate_max_draft_token_num,
is_decoder,
stream,
out);
out,
sliding_window);
})})})})})})
}
@@ -146,6 +151,8 @@ template void CascadeAppendAttentionC4Kernel<paddle::bfloat16, paddle::bfloat16>
shift_bias, // [num_kv_heads, head_dim]
const paddle::optional<paddle::Tensor>&
smooth_weight, // [num_kv_heads, head_dim]
const paddle::optional<paddle::Tensor>&
sinks, // [num_heads]
const paddle::Tensor& seq_lens_q,
const paddle::Tensor& seq_lens_kv,
const paddle::Tensor& seq_lens_encoder,
@@ -168,7 +175,8 @@ template void CascadeAppendAttentionC4Kernel<paddle::bfloat16, paddle::bfloat16>
const bool is_decoder,
const bool enable_prefill,
cudaStream_t& stream,
paddle::Tensor* out);
paddle::Tensor* out,
const int sliding_window);
template void CascadeAppendAttentionC4Kernel<paddle::bfloat16, paddle::float8_e4m3fn>(
const AppendAttnMetaData& meta_data,
@@ -190,6 +198,8 @@ template void CascadeAppendAttentionC4Kernel<paddle::bfloat16, paddle::float8_e4
shift_bias, // [num_kv_heads, head_dim]
const paddle::optional<paddle::Tensor>&
smooth_weight, // [num_kv_heads, head_dim]
const paddle::optional<paddle::Tensor>&
sinks, // [num_heads]
const paddle::Tensor& seq_lens_q,
const paddle::Tensor& seq_lens_kv,
const paddle::Tensor& seq_lens_encoder,
@@ -212,7 +222,8 @@ template void CascadeAppendAttentionC4Kernel<paddle::bfloat16, paddle::float8_e4
const bool is_decoder,
const bool enable_prefill,
cudaStream_t& stream,
paddle::Tensor* out);
paddle::Tensor* out,
const int sliding_window);
template void CascadeAppendAttentionC4Kernel<paddle::bfloat16, int8_t>(
const AppendAttnMetaData& meta_data,
@@ -234,6 +245,8 @@ template void CascadeAppendAttentionC4Kernel<paddle::bfloat16, int8_t>(
shift_bias, // [num_kv_heads, head_dim]
const paddle::optional<paddle::Tensor>&
smooth_weight, // [num_kv_heads, head_dim]
const paddle::optional<paddle::Tensor>&
sinks, // [num_heads]
const paddle::Tensor& seq_lens_q,
const paddle::Tensor& seq_lens_kv,
const paddle::Tensor& seq_lens_encoder,
@@ -256,7 +269,8 @@ template void CascadeAppendAttentionC4Kernel<paddle::bfloat16, int8_t>(
const bool is_decoder,
const bool enable_prefill,
cudaStream_t& stream,
paddle::Tensor* out);
paddle::Tensor* out,
const int sliding_window);
template void CascadeAppendAttentionC4Kernel<paddle::float16, paddle::float16>(
const AppendAttnMetaData& meta_data,
@@ -278,6 +292,8 @@ template void CascadeAppendAttentionC4Kernel<paddle::float16, paddle::float16>(
shift_bias, // [num_kv_heads, head_dim]
const paddle::optional<paddle::Tensor>&
smooth_weight, // [num_kv_heads, head_dim]
const paddle::optional<paddle::Tensor>&
sinks, // [num_heads]
const paddle::Tensor& seq_lens_q,
const paddle::Tensor& seq_lens_kv,
const paddle::Tensor& seq_lens_encoder,
@@ -300,7 +316,8 @@ template void CascadeAppendAttentionC4Kernel<paddle::float16, paddle::float16>(
const bool is_decoder,
const bool enable_prefill,
cudaStream_t& stream,
paddle::Tensor* out);
paddle::Tensor* out,
const int sliding_window);
template void CascadeAppendAttentionC4Kernel<paddle::float16, paddle::float8_e4m3fn>(
const AppendAttnMetaData& meta_data,
@@ -322,6 +339,8 @@ template void CascadeAppendAttentionC4Kernel<paddle::float16, paddle::float8_e4m
shift_bias, // [num_kv_heads, head_dim]
const paddle::optional<paddle::Tensor>&
smooth_weight, // [num_kv_heads, head_dim]
const paddle::optional<paddle::Tensor>&
sinks, // [num_heads]
const paddle::Tensor& seq_lens_q,
const paddle::Tensor& seq_lens_kv,
const paddle::Tensor& seq_lens_encoder,
@@ -344,7 +363,8 @@ template void CascadeAppendAttentionC4Kernel<paddle::float16, paddle::float8_e4m
const bool is_decoder,
const bool enable_prefill,
cudaStream_t& stream,
paddle::Tensor* out);
paddle::Tensor* out,
const int sliding_window);
template void CascadeAppendAttentionC4Kernel<paddle::float16, int8_t>(
const AppendAttnMetaData& meta_data,
@@ -366,6 +386,8 @@ template void CascadeAppendAttentionC4Kernel<paddle::float16, int8_t>(
shift_bias, // [num_kv_heads, head_dim]
const paddle::optional<paddle::Tensor>&
smooth_weight, // [num_kv_heads, head_dim]
const paddle::optional<paddle::Tensor>&
sinks, // [num_heads]
const paddle::Tensor& seq_lens_q,
const paddle::Tensor& seq_lens_kv,
const paddle::Tensor& seq_lens_encoder,
@@ -388,4 +410,5 @@ template void CascadeAppendAttentionC4Kernel<paddle::float16, int8_t>(
const bool is_decoder,
const bool enable_prefill,
cudaStream_t& stream,
paddle::Tensor* out);
paddle::Tensor* out,
const int sliding_window);

View File

@@ -36,6 +36,8 @@ void CascadeAppendAttentionC8Kernel(
shift_bias, // [num_kv_heads, head_dim]
const paddle::optional<paddle::Tensor>&
smooth_weight, // [num_kv_heads, head_dim]
const paddle::optional<paddle::Tensor>&
sinks, // [num_heads]
const paddle::Tensor& seq_lens_q,
const paddle::Tensor& seq_lens_kv,
const paddle::Tensor& seq_lens_encoder,
@@ -59,7 +61,8 @@ void CascadeAppendAttentionC8Kernel(
const bool enable_prefill,
const std::string& cache_quant_type_str,
cudaStream_t& stream,
paddle::Tensor* out) {
paddle::Tensor* out,
const int sliding_window) {
const auto token_num = meta_data.token_nums;
const auto block_size = meta_data.block_size;
const auto bsz = meta_data.batch_size;
@@ -106,7 +109,8 @@ void CascadeAppendAttentionC8Kernel(
cache_v_scale.get(),
shift_bias,
smooth_weight,
seq_lens_q,
sinks,
seq_lens_q,
seq_lens_kv,
seq_lens_encoder,
batch_id_per_token,
@@ -125,7 +129,8 @@ void CascadeAppendAttentionC8Kernel(
speculate_max_draft_token_num,
is_decoder,
stream,
out);
out,
sliding_window);
})})})})})})})
}
@@ -141,6 +146,7 @@ template void CascadeAppendAttentionC8Kernel<paddle::float16, paddle::float16, f
const paddle::optional<paddle::Tensor>& cache_v_zp,
const paddle::optional<paddle::Tensor>& shift_bias,
const paddle::optional<paddle::Tensor>& smooth_weight,
const paddle::optional<paddle::Tensor>& sinks,
const paddle::Tensor& seq_lens_q,
const paddle::Tensor& seq_lens_kv,
const paddle::Tensor& seq_lens_encoder,
@@ -164,7 +170,8 @@ template void CascadeAppendAttentionC8Kernel<paddle::float16, paddle::float16, f
const bool enable_prefill,
const std::string& cache_quant_type_str,
cudaStream_t& stream,
paddle::Tensor* out);
paddle::Tensor* out,
const int sliding_window);
template void CascadeAppendAttentionC8Kernel<paddle::float16, paddle::float16, true>(
const AppendAttnMetaData& meta_data,
@@ -178,6 +185,7 @@ template void CascadeAppendAttentionC8Kernel<paddle::float16, paddle::float16, t
const paddle::optional<paddle::Tensor>& cache_v_zp,
const paddle::optional<paddle::Tensor>& shift_bias,
const paddle::optional<paddle::Tensor>& smooth_weight,
const paddle::optional<paddle::Tensor>& sinks,
const paddle::Tensor& seq_lens_q,
const paddle::Tensor& seq_lens_kv,
const paddle::Tensor& seq_lens_encoder,
@@ -201,7 +209,8 @@ template void CascadeAppendAttentionC8Kernel<paddle::float16, paddle::float16, t
const bool enable_prefill,
const std::string& cache_quant_type_str,
cudaStream_t& stream,
paddle::Tensor* out);
paddle::Tensor* out,
const int sliding_window);
template void CascadeAppendAttentionC8Kernel<paddle::float16, paddle::float8_e4m3fn, false>(
const AppendAttnMetaData& meta_data,
@@ -215,6 +224,7 @@ template void CascadeAppendAttentionC8Kernel<paddle::float16, paddle::float8_e4m
const paddle::optional<paddle::Tensor>& cache_v_zp,
const paddle::optional<paddle::Tensor>& shift_bias,
const paddle::optional<paddle::Tensor>& smooth_weight,
const paddle::optional<paddle::Tensor>& sinks,
const paddle::Tensor& seq_lens_q,
const paddle::Tensor& seq_lens_kv,
const paddle::Tensor& seq_lens_encoder,
@@ -238,7 +248,8 @@ template void CascadeAppendAttentionC8Kernel<paddle::float16, paddle::float8_e4m
const bool enable_prefill,
const std::string& cache_quant_type_str,
cudaStream_t& stream,
paddle::Tensor* out);
paddle::Tensor* out,
const int sliding_window);
template void CascadeAppendAttentionC8Kernel<paddle::float16, paddle::float8_e4m3fn, true>(
const AppendAttnMetaData& meta_data,
@@ -252,6 +263,7 @@ template void CascadeAppendAttentionC8Kernel<paddle::float16, paddle::float8_e4m
const paddle::optional<paddle::Tensor>& cache_v_zp,
const paddle::optional<paddle::Tensor>& shift_bias,
const paddle::optional<paddle::Tensor>& smooth_weight,
const paddle::optional<paddle::Tensor>& sinks,
const paddle::Tensor& seq_lens_q,
const paddle::Tensor& seq_lens_kv,
const paddle::Tensor& seq_lens_encoder,
@@ -275,7 +287,8 @@ template void CascadeAppendAttentionC8Kernel<paddle::float16, paddle::float8_e4m
const bool enable_prefill,
const std::string& cache_quant_type_str,
cudaStream_t& stream,
paddle::Tensor* out);
paddle::Tensor* out,
const int sliding_window);
template void CascadeAppendAttentionC8Kernel<paddle::float16, int8_t, false>(
const AppendAttnMetaData& meta_data,
@@ -289,6 +302,7 @@ template void CascadeAppendAttentionC8Kernel<paddle::float16, int8_t, false>(
const paddle::optional<paddle::Tensor>& cache_v_zp,
const paddle::optional<paddle::Tensor>& shift_bias,
const paddle::optional<paddle::Tensor>& smooth_weight,
const paddle::optional<paddle::Tensor>& sinks,
const paddle::Tensor& seq_lens_q,
const paddle::Tensor& seq_lens_kv,
const paddle::Tensor& seq_lens_encoder,
@@ -312,7 +326,8 @@ template void CascadeAppendAttentionC8Kernel<paddle::float16, int8_t, false>(
const bool enable_prefill,
const std::string& cache_quant_type_str,
cudaStream_t& stream,
paddle::Tensor* out);
paddle::Tensor* out,
const int sliding_window);
template void CascadeAppendAttentionC8Kernel<paddle::float16, int8_t, true>(
const AppendAttnMetaData& meta_data,
@@ -326,6 +341,7 @@ template void CascadeAppendAttentionC8Kernel<paddle::float16, int8_t, true>(
const paddle::optional<paddle::Tensor>& cache_v_zp,
const paddle::optional<paddle::Tensor>& shift_bias,
const paddle::optional<paddle::Tensor>& smooth_weight,
const paddle::optional<paddle::Tensor>& sinks,
const paddle::Tensor& seq_lens_q,
const paddle::Tensor& seq_lens_kv,
const paddle::Tensor& seq_lens_encoder,
@@ -349,7 +365,8 @@ template void CascadeAppendAttentionC8Kernel<paddle::float16, int8_t, true>(
const bool enable_prefill,
const std::string& cache_quant_type_str,
cudaStream_t& stream,
paddle::Tensor* out);
paddle::Tensor* out,
const int sliding_window);
template void CascadeAppendAttentionC8Kernel<paddle::bfloat16, paddle::bfloat16, false>(
const AppendAttnMetaData& meta_data,
@@ -363,6 +380,7 @@ template void CascadeAppendAttentionC8Kernel<paddle::bfloat16, paddle::bfloat16,
const paddle::optional<paddle::Tensor>& cache_v_zp,
const paddle::optional<paddle::Tensor>& shift_bias,
const paddle::optional<paddle::Tensor>& smooth_weight,
const paddle::optional<paddle::Tensor>& sinks,
const paddle::Tensor& seq_lens_q,
const paddle::Tensor& seq_lens_kv,
const paddle::Tensor& seq_lens_encoder,
@@ -386,7 +404,8 @@ template void CascadeAppendAttentionC8Kernel<paddle::bfloat16, paddle::bfloat16,
const bool enable_prefill,
const std::string& cache_quant_type_str,
cudaStream_t& stream,
paddle::Tensor* out);
paddle::Tensor* out,
const int sliding_window);
template void CascadeAppendAttentionC8Kernel<paddle::bfloat16, paddle::bfloat16, true>(
const AppendAttnMetaData& meta_data,
@@ -400,6 +419,7 @@ template void CascadeAppendAttentionC8Kernel<paddle::bfloat16, paddle::bfloat16,
const paddle::optional<paddle::Tensor>& cache_v_zp,
const paddle::optional<paddle::Tensor>& shift_bias,
const paddle::optional<paddle::Tensor>& smooth_weight,
const paddle::optional<paddle::Tensor>& sinks,
const paddle::Tensor& seq_lens_q,
const paddle::Tensor& seq_lens_kv,
const paddle::Tensor& seq_lens_encoder,
@@ -423,7 +443,8 @@ template void CascadeAppendAttentionC8Kernel<paddle::bfloat16, paddle::bfloat16,
const bool enable_prefill,
const std::string& cache_quant_type_str,
cudaStream_t& stream,
paddle::Tensor* out);
paddle::Tensor* out,
const int sliding_window);
template void CascadeAppendAttentionC8Kernel<paddle::bfloat16, paddle::float8_e4m3fn, false>(
const AppendAttnMetaData& meta_data,
@@ -437,6 +458,7 @@ template void CascadeAppendAttentionC8Kernel<paddle::bfloat16, paddle::float8_e4
const paddle::optional<paddle::Tensor>& cache_v_zp,
const paddle::optional<paddle::Tensor>& shift_bias,
const paddle::optional<paddle::Tensor>& smooth_weight,
const paddle::optional<paddle::Tensor>& sinks,
const paddle::Tensor& seq_lens_q,
const paddle::Tensor& seq_lens_kv,
const paddle::Tensor& seq_lens_encoder,
@@ -460,7 +482,8 @@ template void CascadeAppendAttentionC8Kernel<paddle::bfloat16, paddle::float8_e4
const bool enable_prefill,
const std::string& cache_quant_type_str,
cudaStream_t& stream,
paddle::Tensor* out);
paddle::Tensor* out,
const int sliding_window);
template void CascadeAppendAttentionC8Kernel<paddle::bfloat16, paddle::float8_e4m3fn, true>(
const AppendAttnMetaData& meta_data,
@@ -474,6 +497,7 @@ template void CascadeAppendAttentionC8Kernel<paddle::bfloat16, paddle::float8_e4
const paddle::optional<paddle::Tensor>& cache_v_zp,
const paddle::optional<paddle::Tensor>& shift_bias,
const paddle::optional<paddle::Tensor>& smooth_weight,
const paddle::optional<paddle::Tensor>& sinks,
const paddle::Tensor& seq_lens_q,
const paddle::Tensor& seq_lens_kv,
const paddle::Tensor& seq_lens_encoder,
@@ -497,7 +521,8 @@ template void CascadeAppendAttentionC8Kernel<paddle::bfloat16, paddle::float8_e4
const bool enable_prefill,
const std::string& cache_quant_type_str,
cudaStream_t& stream,
paddle::Tensor* out);
paddle::Tensor* out,
const int sliding_window);
template void CascadeAppendAttentionC8Kernel<paddle::bfloat16, int8_t, false>(
const AppendAttnMetaData& meta_data,
@@ -511,6 +536,7 @@ template void CascadeAppendAttentionC8Kernel<paddle::bfloat16, int8_t, false>(
const paddle::optional<paddle::Tensor>& cache_v_zp,
const paddle::optional<paddle::Tensor>& shift_bias,
const paddle::optional<paddle::Tensor>& smooth_weight,
const paddle::optional<paddle::Tensor>& sinks,
const paddle::Tensor& seq_lens_q,
const paddle::Tensor& seq_lens_kv,
const paddle::Tensor& seq_lens_encoder,
@@ -534,7 +560,8 @@ template void CascadeAppendAttentionC8Kernel<paddle::bfloat16, int8_t, false>(
const bool enable_prefill,
const std::string& cache_quant_type_str,
cudaStream_t& stream,
paddle::Tensor* out);
paddle::Tensor* out,
const int sliding_window);
template void CascadeAppendAttentionC8Kernel<paddle::bfloat16, int8_t, true>(
const AppendAttnMetaData& meta_data,
@@ -548,6 +575,7 @@ template void CascadeAppendAttentionC8Kernel<paddle::bfloat16, int8_t, true>(
const paddle::optional<paddle::Tensor>& cache_v_zp,
const paddle::optional<paddle::Tensor>& shift_bias,
const paddle::optional<paddle::Tensor>& smooth_weight,
const paddle::optional<paddle::Tensor>& sinks,
const paddle::Tensor& seq_lens_q,
const paddle::Tensor& seq_lens_kv,
const paddle::Tensor& seq_lens_encoder,
@@ -571,4 +599,5 @@ template void CascadeAppendAttentionC8Kernel<paddle::bfloat16, int8_t, true>(
const bool enable_prefill,
const std::string& cache_quant_type_str,
cudaStream_t& stream,
paddle::Tensor* out);
paddle::Tensor* out,
const int sliding_window);

View File

@@ -77,6 +77,14 @@ struct prefill_softmax_state_t {
__device__ __forceinline__ void normalize() {
const T d_t = static_cast<T>(d);
#pragma unroll
for (size_t i = 0; i < vec_size; ++i) {
o[i] /= d_t;
}
}
__device__ __forceinline__ void normalize(float current_sink) {
const T d_t = static_cast<T>(d + __expf(current_sink - m));
#pragma unroll
for (size_t i = 0; i < vec_size; ++i) {
o[i] /= d_t;
@@ -1028,7 +1036,8 @@ __device__ __forceinline__ void mask_s(const bool* attn_mask,
const uint32_t chunk_end,
const uint32_t attn_mask_len,
float (*s_frag)[num_frags_z][8],
const int *mask_offset = nullptr) {
const int *mask_offset = nullptr,
const int sliding_window = 0) {
const uint32_t tx = threadIdx.x;
#pragma unroll
for (uint32_t fx = 0; fx < num_frags_x; ++fx) {
@@ -1045,11 +1054,21 @@ __device__ __forceinline__ void mask_s(const bool* attn_mask,
bool out_of_boundary;
if (mask_offset) {
out_of_boundary = q_idx < qo_len ? (kv_idx >= mask_offset[q_idx * 2 + 1] || kv_idx < mask_offset[q_idx * 2]) : true;
} else {
}
else if (sliding_window > 0)
{
bool out_of_window = int(kv_idx) <= (int)kv_len + (int)q_idx - (int)qo_len - sliding_window;
out_of_boundary =
(causal
? (kv_idx > kv_len + q_idx - qo_len || (kv_idx >= chunk_end))
: kv_idx >= chunk_end);
(causal
? (kv_idx > kv_len + q_idx - qo_len || out_of_window || (kv_idx >= chunk_end))
: kv_idx >= chunk_end);
}
else
{
out_of_boundary =
(causal
? (kv_idx > kv_len + q_idx - qo_len || (kv_idx >= chunk_end))
: kv_idx >= chunk_end);
if (attn_mask != nullptr && kv_idx > kv_len - qo_len && kv_idx < chunk_end && q_idx < attn_mask_len) {
const int32_t mask_idx = q_idx * attn_mask_len + kv_idx - kv_len + qo_len;
bool mask = attn_mask[mask_idx];
@@ -1064,7 +1083,7 @@ __device__ __forceinline__ void mask_s(const bool* attn_mask,
s_frag[fx][fz][reg_id] =
out_of_boundary ? -3.0e+30f : s_frag[fx][fz][reg_id];
}
// printf("tid: %d. qk[%u,%u] = %f, mask: %d \n ", threadIdx.x, kv_idx, q_idx, static_cast<float>(s_frag[fx][fz][reg_id]), int(out_of_boundary));
} else {
const uint32_t q_idx = qo_idx_base,
kv_idx = kv_idx_base + fz * 16 + 2 * (tx % 4) +
@@ -1458,6 +1477,33 @@ __device__ __forceinline__ void normalize_d(float (*o_frag)[num_frags_y][8],
}
}
// Normalize the output fragments when attention sinks are enabled.
// The sink contributes an extra exp(sink - m) term to the softmax
// denominator (it behaves like one additional logit that produces no
// value), so each lane is scaled by 1 / (d + exp(sink - m)).
// Fragment layout note: within the 8-register MMA fragment, registers
// {0,1,4,5} and {2,3,6,7} belong to the two row halves, hence the
// (reg_id % 4) / 2 row-pair selector.
template <uint32_t num_frags_x, uint32_t num_frags_y>
__device__ __forceinline__ void normalize_d(float (*o_frag)[num_frags_y][8],
                                            float (*d)[2],
                                            float (*m)[2],
                                            float (*current_sinks)[2]) {
#pragma unroll
  for (uint32_t fx = 0; fx < num_frags_x; ++fx) {
    // Reciprocal of the sink-augmented denominator for both row halves.
    float inv_denom[2];
#pragma unroll
    for (uint32_t j = 0; j < 2; ++j) {
      inv_denom[j] =
          1.f / (d[fx][j] + __expf(current_sinks[fx][j] - m[fx][j]));
    }
#pragma unroll
    for (uint32_t fy = 0; fy < num_frags_y; ++fy) {
#pragma unroll
      for (uint32_t reg_id = 0; reg_id < 8; ++reg_id) {
        o_frag[fx][fy][reg_id] *= inv_denom[(reg_id % 4) / 2];
      }
    }
  }
}
template <uint32_t num_frags_x,
uint32_t num_frags_y,
uint32_t NUM_WARPS,
@@ -2271,6 +2317,7 @@ __global__ void merge_multi_chunks_decoder_kernel(
const int *__restrict__ cu_seqlens_q,
const T *__restrict__ shift_bias, // [q_num_heads * HEAD_DIM]
const T *__restrict__ smooth_weight, // [q_num_heads * HEAD_DIM]
const T *__restrict__ sinks, // [q_num_heads]
OutT *__restrict__ out,
const float quant_max_bound,
const float quant_min_bound,
@@ -2354,7 +2401,12 @@ __global__ void merge_multi_chunks_decoder_kernel(
const float m_tmp = md_smem[2 * i], d_tmp = md_smem[2 * i + 1];
st.merge(load_vec, m_tmp, d_tmp);
}
st.normalize();
if (sinks) {
float current_sink = static_cast<float>(sinks[hid]);
st.normalize(current_sink);
} else {
st.normalize();
}
const uint32_t shift_smooth_offset = hid * head_dim + vid * vec_size;
AlignedVector<T, vec_size> shift_bias_vec;
@@ -2394,6 +2446,7 @@ __global__ void merge_multi_chunks_v2_kernel(
const int *__restrict__ cu_seqlens_q,
const T *__restrict__ shift_bias, // [q_num_heads * HEAD_DIM]
const T *__restrict__ smooth_weight, // [q_num_heads * HEAD_DIM]
const T *__restrict__ sinks, // [q_num_heads]
OutT *__restrict__ out,
const float quant_max_bound,
const float quant_min_bound,
@@ -2511,7 +2564,13 @@ __global__ void merge_multi_chunks_v2_kernel(
const float m_tmp = md_smem[2 * i], d_tmp = md_smem[2 * i + 1];
st.merge(load_vec, m_tmp, d_tmp);
}
st.normalize();
if (sinks) {
float current_sink = static_cast<float>(sinks[hid]);
st.normalize(current_sink);
} else {
st.normalize();
}
const uint32_t shift_smooth_offset = hid * head_dim + vid * vec_size;
AlignedVector<T, vec_size> shift_bias_vec;

View File

@@ -40,6 +40,8 @@ void CascadeAppendAttentionKernel(
shift_bias, // [num_kv_heads, head_dim]
const paddle::optional<paddle::Tensor>&
smooth_weight, // [num_kv_heads, head_dim]
const paddle::optional<paddle::Tensor>&
sinks, // [num_heads]
const paddle::Tensor& seq_lens_q,
const paddle::Tensor& seq_lens_kv,
const paddle::Tensor& seq_lens_encoder,
@@ -63,7 +65,8 @@ void CascadeAppendAttentionKernel(
const bool is_decoder,
const bool enable_prefill,
cudaStream_t& stream,
paddle::Tensor* out) {
paddle::Tensor* out,
const int sliding_window) {
if (cache_quant_type_str == "none") {
CascadeAppendAttentionC16Kernel<T, OutT>(meta_data,
qkv,
@@ -76,6 +79,7 @@ void CascadeAppendAttentionKernel(
cache_v_zp,
shift_bias,
smooth_weight,
sinks,
seq_lens_q,
seq_lens_kv,
seq_lens_encoder,
@@ -98,7 +102,8 @@ void CascadeAppendAttentionKernel(
is_decoder,
enable_prefill,
stream,
out);
out,
sliding_window);
} else if (cache_quant_type_str == "cache_int8") {
CascadeAppendAttentionC8Kernel<T, OutT, false>(meta_data,
qkv,
@@ -111,6 +116,7 @@ void CascadeAppendAttentionKernel(
cache_v_zp,
shift_bias,
smooth_weight,
sinks,
seq_lens_q,
seq_lens_kv,
seq_lens_encoder,
@@ -134,7 +140,8 @@ void CascadeAppendAttentionKernel(
enable_prefill,
cache_quant_type_str,
stream,
out);
out,
sliding_window);
} else if (cache_quant_type_str == "cache_fp8" or cache_quant_type_str == "block_wise_fp8") {
CascadeAppendAttentionC8Kernel<T, OutT, true>(meta_data,
qkv,
@@ -147,6 +154,7 @@ void CascadeAppendAttentionKernel(
cache_v_zp,
shift_bias,
smooth_weight,
sinks,
seq_lens_q,
seq_lens_kv,
seq_lens_encoder,
@@ -170,7 +178,8 @@ void CascadeAppendAttentionKernel(
enable_prefill,
cache_quant_type_str,
stream,
out);
out,
sliding_window);
} else if (cache_quant_type_str == "cache_int4_zp") {
CascadeAppendAttentionC4Kernel<T, OutT>(meta_data,
qkv,
@@ -183,6 +192,7 @@ void CascadeAppendAttentionKernel(
cache_v_zp,
shift_bias,
smooth_weight,
sinks,
seq_lens_q,
seq_lens_kv,
seq_lens_encoder,
@@ -205,7 +215,8 @@ void CascadeAppendAttentionKernel(
is_decoder,
enable_prefill,
stream,
out);
out,
sliding_window);
} else {
PD_THROW(
"cache_quant_type_str should be one of [none, cache_int8, "

View File

@@ -2179,7 +2179,9 @@ void gqa_rotary_qk_norm_variable(
qkv_out_scales
? token_num * (num_heads + 2 * kv_num_heads) * dim_head
: token_num * (num_heads + kv_num_heads) * dim_head; // for all q k v
assert(dim_head == 128 && "dim_head must be 128");
if (dim_head != 128) {
PADDLE_THROW("gqa rotary with qk norm only support head_dim=128, but got %d.", dim_head);
}
constexpr int HEAD_DIM = 128;
constexpr int PackSize = HEAD_DIM / kWarpSize;
const int pack_num = elem_nums / PackSize;

View File

@@ -36,6 +36,7 @@ __global__ void multi_query_append_attention_kernel(
T *__restrict__ cache_v,
const T *__restrict__ shift_bias, // [q_num_heads * HEAD_DIM]
const T *__restrict__ smooth_weight, // [q_num_heads * HEAD_DIM]
const T *__restrict__ sinks, // [q_num_heads]
const int *__restrict__ seq_lens,
const int *__restrict__ seq_lens_kv,
const int *__restrict__ batch_ids,
@@ -57,7 +58,8 @@ __global__ void multi_query_append_attention_kernel(
float *__restrict__ tmp_m, // [token_num, num_chunks, num_heads]
float *__restrict__ tmp_d, // [token_num, num_chunks, num_heads]
OutT *__restrict__ out,
const int speculate_max_draft_token_num = 5) {
const int speculate_max_draft_token_num = 5,
const int sliding_window = 0) {
constexpr uint32_t num_vecs_per_head = HEAD_DIM / num_elems_per_128b<T>();
const uint32_t btid = blockIdx.x, kv_head_idx = blockIdx.z;
const uint32_t kv_num_heads = gridDim.z;
@@ -244,7 +246,7 @@ __global__ void multi_query_append_attention_kernel(
compute_qk<num_frags_x, num_frags_y, num_frags_z, T>(
&qo_smem, &q_smem_offset_r, &k_smem, &k_smem_offset_r, s_frag);
// mask according to kv_idx and q_idx
if (iter >= mask_check_iteration) {
if (iter >= mask_check_iteration || sliding_window > 0) {
mask_s<T,
partition_kv,
CAUSAL,
@@ -260,7 +262,8 @@ __global__ void multi_query_append_attention_kernel(
chunk_end,
-1,
s_frag,
mask_offset_this_seq);
mask_offset_this_seq,
sliding_window);
}
@@ -318,8 +321,21 @@ __global__ void multi_query_append_attention_kernel(
wait_group<0>();
__syncthreads();
if constexpr (!partition_kv) {
normalize_d<num_frags_x, num_frags_y>(o_frag, d_frag);
if constexpr (!partition_kv ) {
if (sinks) {
float current_sinks[num_frags_x][2];
#pragma unroll
for (uint32_t fx = 0; fx < num_frags_x; ++fx) {
#pragma unroll
for (uint32_t j = 0; j < 2; ++j) {
const uint32_t h_offset = (q_base_seq_id_this_block + fx * 16 + tid / 4 + 8 * j) % GROUP_SIZE;
current_sinks[fx][j] = static_cast<float>(sinks[q_head_idx + h_offset]);
}
}
normalize_d<num_frags_x, num_frags_y>(o_frag, d_frag, m_frag, current_sinks);
} else {
normalize_d<num_frags_x, num_frags_y>(o_frag, d_frag);
}
}
if constexpr (partition_kv) {
write_o_reg_gmem_shift_smooth_quant<GROUP_SIZE,
@@ -411,6 +427,7 @@ __global__ void multi_query_append_attention_warp1_4_kernel(
T *__restrict__ cache_v,
const T *__restrict__ shift_bias, // [q_num_heads * HEAD_DIM]
const T *__restrict__ smooth_weight, // [q_num_heads * HEAD_DIM]
const T *__restrict__ sinks, // [q_num_heads]
const int *__restrict__ seq_lens,
const int *__restrict__ seq_lens_kv,
const int *__restrict__ batch_ids,
@@ -434,7 +451,8 @@ __global__ void multi_query_append_attention_warp1_4_kernel(
float *__restrict__ tmp_d, // [token_num, num_chunks, num_heads]
OutT *__restrict__ out,
const int speculate_max_draft_token_num = 5,
const uint32_t attn_mask_len = -1) {
const uint32_t attn_mask_len = -1,
const int sliding_window = 0) {
constexpr uint32_t num_vecs_per_head = HEAD_DIM / num_elems_per_128b<T>();
static_assert(NUM_WARP_Q == 1, "NUM_WARP_Q must be 1");
static_assert(NUM_WARP_KV == 4, "NUM_WARP_KV must be 4");
@@ -622,7 +640,7 @@ __global__ void multi_query_append_attention_warp1_4_kernel(
compute_qk<num_frags_x, num_frags_y, num_frags_z, T>(
&qo_smem, &q_smem_offset_r, &k_smem, &k_smem_offset_r, s_frag);
// mask according to kv_idx and q_idx
if (iter >= mask_check_iteration) {
if (iter >= mask_check_iteration || sliding_window > 0) {
mask_s<T,
partition_kv,
CAUSAL,
@@ -638,7 +656,8 @@ __global__ void multi_query_append_attention_warp1_4_kernel(
chunk_end,
attn_mask_len,
s_frag,
mask_offset_this_seq);
mask_offset_this_seq,
sliding_window);
}
// update m,d
@@ -699,7 +718,20 @@ __global__ void multi_query_append_attention_warp1_4_kernel(
o_frag, reinterpret_cast<float *>(smem), m_frag, d_frag, wid, tid);
if (num_chunks_this_seq <= 1) {
normalize_d<num_frags_x, num_frags_y>(o_frag, d_frag);
if (sinks) {
float current_sinks[num_frags_x][2];
#pragma unroll
for (uint32_t fx = 0; fx < num_frags_x; ++fx) {
#pragma unroll
for (uint32_t j = 0; j < 2; ++j) {
const uint32_t h_offset = (q_base_seq_id_this_block + fx * 16 + tid / 4 + 8 * j) % GROUP_SIZE;
current_sinks[fx][j] = static_cast<float>(sinks[q_head_idx + h_offset]);
}
}
normalize_d<num_frags_x, num_frags_y>(o_frag, d_frag, m_frag, current_sinks);
} else {
normalize_d<num_frags_x, num_frags_y>(o_frag, d_frag);
}
}
// write o
@@ -792,6 +824,7 @@ void MultiQueryAppendAttention(
const paddle::optional<paddle::Tensor> &attn_mask,
const paddle::optional<paddle::Tensor> &shift_bias,
const paddle::optional<paddle::Tensor> &smooth_weight,
const paddle::optional<paddle::Tensor> &sinks,
const paddle::Tensor &seq_lens_q,
const paddle::Tensor &seq_lens_kv,
const paddle::Tensor &seq_lens_encoder,
@@ -811,7 +844,8 @@ void MultiQueryAppendAttention(
const int speculate_max_draft_token_num,
const bool is_decoder,
cudaStream_t &stream,
paddle::Tensor *out) {
paddle::Tensor *out,
const int sliding_window) {
using NV_TYPE = typename cascade_attn_type_traits<T>::type;
using OUT_NV_TYPE = typename cascade_attn_type_traits<OutT>::type;
@@ -898,6 +932,9 @@ void MultiQueryAppendAttention(
smooth_weight ? reinterpret_cast<NV_TYPE *>(
const_cast<T *>(smooth_weight.get().data<T>()))
: nullptr,
sinks ? reinterpret_cast<NV_TYPE *>(
const_cast<T *>(sinks.get().data<T>()))
: nullptr,
seq_lens_q.data<int>(),
seq_lens_kv.data<int>(),
batch_ids.data<int>(),
@@ -918,7 +955,8 @@ void MultiQueryAppendAttention(
nullptr,
nullptr,
reinterpret_cast<OUT_NV_TYPE *>(out->data<OutT>()),
speculate_max_draft_token_num);
speculate_max_draft_token_num,
sliding_window);
} else {
phi::Allocator::AllocationPtr tmp_workspace, tmp_m, tmp_d;
@@ -957,6 +995,9 @@ void MultiQueryAppendAttention(
smooth_weight ? reinterpret_cast<NV_TYPE *>(
const_cast<T *>(smooth_weight.get().data<T>()))
: nullptr,
sinks ? reinterpret_cast<NV_TYPE *>(
const_cast<T *>(sinks.get().data<T>()))
: nullptr,
seq_lens_q.data<int>(),
seq_lens_kv.data<int>(),
batch_ids.data<int>(),
@@ -977,7 +1018,8 @@ void MultiQueryAppendAttention(
static_cast<float *>(tmp_m->ptr()),
static_cast<float *>(tmp_d->ptr()),
reinterpret_cast<OUT_NV_TYPE *>(out->data<OutT>()),
speculate_max_draft_token_num);
speculate_max_draft_token_num,
sliding_window);
// merge
constexpr int vec_size = num_elems_per_128b<NV_TYPE>();
if (is_decoder) {
@@ -1005,6 +1047,9 @@ void MultiQueryAppendAttention(
smooth_weight ? reinterpret_cast<NV_TYPE *>(const_cast<T *>(
smooth_weight.get().data<T>()))
: nullptr,
sinks ? reinterpret_cast<NV_TYPE *>(
const_cast<T *>(sinks.get().data<T>()))
: nullptr,
reinterpret_cast<OUT_NV_TYPE *>(out->data<OutT>()),
quant_max_bound,
quant_min_bound,
@@ -1041,6 +1086,9 @@ void MultiQueryAppendAttention(
smooth_weight ? reinterpret_cast<NV_TYPE *>(const_cast<T *>(
smooth_weight.get().data<T>()))
: nullptr,
sinks ? reinterpret_cast<NV_TYPE *>(
const_cast<T *>(sinks.get().data<T>()))
: nullptr,
reinterpret_cast<OUT_NV_TYPE *>(out->data<OutT>()),
quant_max_bound,
quant_min_bound,
@@ -1130,6 +1178,9 @@ void MultiQueryAppendAttention(
smooth_weight ? reinterpret_cast<NV_TYPE *>(
const_cast<T *>(smooth_weight.get().data<T>()))
: nullptr,
sinks ? reinterpret_cast<NV_TYPE *>(
const_cast<T *>(sinks.get().data<T>()))
: nullptr,
seq_lens_q.data<int>(),
seq_lens_kv.data<int>(),
batch_ids.data<int>(),
@@ -1153,7 +1204,8 @@ void MultiQueryAppendAttention(
nullptr,
reinterpret_cast<OUT_NV_TYPE *>(out->data<OutT>()),
speculate_max_draft_token_num,
attn_mask_len);
attn_mask_len,
sliding_window);
} else {
phi::Allocator::AllocationPtr tmp_workspace, tmp_m, tmp_d;
if (is_decoder) {
@@ -1203,6 +1255,9 @@ void MultiQueryAppendAttention(
smooth_weight ? reinterpret_cast<NV_TYPE *>(
const_cast<T *>(smooth_weight.get().data<T>()))
: nullptr,
sinks ? reinterpret_cast<NV_TYPE *>(
const_cast<T *>(sinks.get().data<T>()))
: nullptr,
seq_lens_q.data<int>(),
seq_lens_kv.data<int>(),
batch_ids.data<int>(),
@@ -1226,7 +1281,8 @@ void MultiQueryAppendAttention(
static_cast<float *>(tmp_d->ptr()),
reinterpret_cast<OUT_NV_TYPE *>(out->data<OutT>()),
speculate_max_draft_token_num,
attn_mask_len);
attn_mask_len,
sliding_window);
// merge
constexpr int vec_size = num_elems_per_128b<NV_TYPE>();
@@ -1255,6 +1311,9 @@ void MultiQueryAppendAttention(
smooth_weight ? reinterpret_cast<NV_TYPE *>(const_cast<T *>(
smooth_weight.get().data<T>()))
: nullptr,
sinks ? reinterpret_cast<NV_TYPE *>(
const_cast<T *>(sinks.get().data<T>()))
: nullptr,
reinterpret_cast<OUT_NV_TYPE *>(out->data<OutT>()),
quant_max_bound,
quant_min_bound,
@@ -1291,6 +1350,9 @@ void MultiQueryAppendAttention(
smooth_weight ? reinterpret_cast<NV_TYPE *>(const_cast<T *>(
smooth_weight.get().data<T>()))
: nullptr,
sinks ? reinterpret_cast<NV_TYPE *>(
const_cast<T *>(sinks.get().data<T>()))
: nullptr,
reinterpret_cast<OUT_NV_TYPE *>(out->data<OutT>()),
quant_max_bound,
quant_min_bound,

View File

@@ -32,6 +32,7 @@ void MultiQueryAppendAttention(
const paddle::optional<paddle::Tensor> &attn_mask,
const paddle::optional<paddle::Tensor> &shift_bias,
const paddle::optional<paddle::Tensor> &smooth_weight,
const paddle::optional<paddle::Tensor> &sinks,
const paddle::Tensor &seq_lens_q,
const paddle::Tensor &seq_lens_kv,
const paddle::Tensor &seq_lens_encoder,
@@ -51,4 +52,5 @@ void MultiQueryAppendAttention(
const int speculate_max_draft_token_num,
const bool is_decoder,
cudaStream_t &stream,
paddle::Tensor *out);
paddle::Tensor *out,
const int sliding_window);

View File

@@ -41,6 +41,7 @@ __global__ void multi_query_append_attention_c4_kernel(
const T *__restrict__ cache_v_zero_point, // [num_kv_heads, head_dim]
const T *__restrict__ shift_bias, // [q_num_heads * HEAD_DIM]
const T *__restrict__ smooth_weight, // [q_num_heads * HEAD_DIM]
const T *__restrict__ sinks, // [q_num_heads]
const int *__restrict__ seq_lens,
const int *__restrict__ seq_lens_kv,
const int *__restrict__ batch_ids,
@@ -62,7 +63,8 @@ __global__ void multi_query_append_attention_c4_kernel(
float *__restrict__ tmp_m, // [token_num, num_chunks, num_heads]
float *__restrict__ tmp_d, // [token_num, num_chunks, num_heads]
OutT *__restrict__ out,
const int speculate_max_draft_token_num = 5) {
const int speculate_max_draft_token_num = 5,
const int sliding_window = 0) {
constexpr uint32_t num_vecs_per_head = HEAD_DIM / num_elems_per_128b<T>();
constexpr uint32_t num_vecs_per_head_k =
HEAD_DIM / 2 / num_elems_per_128b<CacheT>();
@@ -332,7 +334,7 @@ __global__ void multi_query_append_attention_c4_kernel(
cache_k_scale_frag,
cache_k_zp_frag);
if (iter >= mask_check_iteration) {
if (iter >= mask_check_iteration || sliding_window > 0) {
mask_s<T,
partition_kv,
CAUSAL,
@@ -348,7 +350,8 @@ __global__ void multi_query_append_attention_c4_kernel(
chunk_end,
-1,
s_frag,
mask_offset_this_seq);
mask_offset_this_seq,
sliding_window);
}
update_mdo_states<num_frags_x, num_frags_y, num_frags_z>(
@@ -412,7 +415,20 @@ __global__ void multi_query_append_attention_c4_kernel(
__syncthreads();
if constexpr (!partition_kv) {
normalize_d<num_frags_x, num_frags_y>(o_frag, d_frag);
if (sinks) {
float current_sinks[num_frags_x][2];
#pragma unroll
for (uint32_t fx = 0; fx < num_frags_x; ++fx) {
#pragma unroll
for (uint32_t j = 0; j < 2; ++j) {
const uint32_t h_offset = (q_base_seq_id_this_block + fx * 16 + tid / 4 + 8 * j) % GROUP_SIZE;
current_sinks[fx][j] = static_cast<float>(sinks[q_head_idx + h_offset]);
}
}
normalize_d<num_frags_x, num_frags_y>(o_frag, d_frag, m_frag, current_sinks);
} else {
normalize_d<num_frags_x, num_frags_y>(o_frag, d_frag);
}
}
if constexpr (partition_kv) {
@@ -509,6 +525,7 @@ __global__ void multi_query_append_attention_c4_warp1_4_kernel(
const T *__restrict__ cache_v_zero_point, // [num_kv_heads, head_dim]
const T *__restrict__ shift_bias, // [q_num_heads * HEAD_DIM]
const T *__restrict__ smooth_weight, // [q_num_heads * HEAD_DIM]
const T *__restrict__ sinks, // [q_num_heads]
const int *__restrict__ seq_lens,
const int *__restrict__ seq_lens_kv,
const int *__restrict__ batch_ids,
@@ -532,7 +549,8 @@ __global__ void multi_query_append_attention_c4_warp1_4_kernel(
float *__restrict__ tmp_d, // [token_num, num_chunks, num_heads]
OutT *__restrict__ out,
const int speculate_max_draft_token_num = 5,
const uint32_t attn_mask_len = -1) {
const uint32_t attn_mask_len = -1,
const int sliding_window = 0) {
constexpr uint32_t num_vecs_per_head = HEAD_DIM / num_elems_per_128b<T>();
constexpr uint32_t num_vecs_per_head_k =
HEAD_DIM / 2 / num_elems_per_128b<CacheT>();
@@ -798,7 +816,7 @@ __global__ void multi_query_append_attention_c4_warp1_4_kernel(
s_frag,
cache_k_scale_frag,
cache_k_zp_frag);
if (iter >= mask_check_iteration) {
if (iter >= mask_check_iteration || sliding_window > 0) {
mask_s<T,
partition_kv,
CAUSAL,
@@ -814,7 +832,8 @@ __global__ void multi_query_append_attention_c4_warp1_4_kernel(
chunk_end,
attn_mask_len,
s_frag,
mask_offset_this_seq);
mask_offset_this_seq,
sliding_window);
}
update_mdo_states<num_frags_x, num_frags_y, num_frags_z>(
@@ -882,7 +901,20 @@ __global__ void multi_query_append_attention_c4_warp1_4_kernel(
o_frag, reinterpret_cast<float *>(smem), m_frag, d_frag, wid, tid);
if (num_chunks_this_seq <= 1) {
normalize_d<num_frags_x, num_frags_y>(o_frag, d_frag);
if (sinks) {
float current_sinks[num_frags_x][2];
#pragma unroll
for (uint32_t fx = 0; fx < num_frags_x; ++fx) {
#pragma unroll
for (uint32_t j = 0; j < 2; ++j) {
const uint32_t h_offset = (q_base_seq_id_this_block + fx * 16 + tid / 4 + 8 * j) % GROUP_SIZE;
current_sinks[fx][j] = static_cast<float>(sinks[q_head_idx + h_offset]);
}
}
normalize_d<num_frags_x, num_frags_y>(o_frag, d_frag, m_frag, current_sinks);
} else {
normalize_d<num_frags_x, num_frags_y>(o_frag, d_frag);
}
}
// write o
@@ -978,6 +1010,7 @@ void MultiQueryAppendC4Attention(
const paddle::optional<paddle::Tensor> &cache_v_zp,
const paddle::optional<paddle::Tensor> &shift_bias,
const paddle::optional<paddle::Tensor> &smooth_weight,
const paddle::optional<paddle::Tensor> &sinks,
const paddle::Tensor &seq_lens_q,
const paddle::Tensor &seq_lens_kv,
const paddle::Tensor &seq_lens_encoder,
@@ -997,7 +1030,8 @@ void MultiQueryAppendC4Attention(
const int speculate_max_draft_token_num,
const bool is_decoder,
cudaStream_t &stream,
paddle::Tensor *out) {
paddle::Tensor *out,
const int sliding_window) {
using NV_TYPE = typename cascade_attn_type_traits<T>::type;
using OUT_NV_TYPE = typename cascade_attn_type_traits<OutT>::type;
@@ -1103,6 +1137,9 @@ void MultiQueryAppendC4Attention(
smooth_weight ? reinterpret_cast<NV_TYPE *>(
const_cast<T *>(smooth_weight.get().data<T>()))
: nullptr,
sinks ? reinterpret_cast<NV_TYPE *>(
const_cast<T *>(sinks.get().data<T>()))
: nullptr,
seq_lens_q.data<int>(),
seq_lens_kv.data<int>(),
batch_ids.data<int>(),
@@ -1123,7 +1160,8 @@ void MultiQueryAppendC4Attention(
nullptr,
nullptr,
reinterpret_cast<OUT_NV_TYPE *>(out->data<OutT>()),
speculate_max_draft_token_num);
speculate_max_draft_token_num,
sliding_window);
} else {
phi::Allocator::AllocationPtr tmp_workspace, tmp_m, tmp_d;
if (ENABLE_PREFILL) {
@@ -1168,6 +1206,9 @@ void MultiQueryAppendC4Attention(
smooth_weight ? reinterpret_cast<NV_TYPE *>(
const_cast<T *>(smooth_weight.get().data<T>()))
: nullptr,
sinks ? reinterpret_cast<NV_TYPE *>(
const_cast<T *>(sinks.get().data<T>()))
: nullptr,
seq_lens_q.data<int>(),
seq_lens_kv.data<int>(),
batch_ids.data<int>(),
@@ -1188,7 +1229,8 @@ void MultiQueryAppendC4Attention(
static_cast<float *>(tmp_m->ptr()),
static_cast<float *>(tmp_d->ptr()),
reinterpret_cast<OUT_NV_TYPE *>(out->data<OutT>()),
speculate_max_draft_token_num);
speculate_max_draft_token_num,
sliding_window);
// merge
constexpr int vec_size = num_elems_per_128b<NV_TYPE>();
if (is_decoder) {
@@ -1216,6 +1258,9 @@ void MultiQueryAppendC4Attention(
smooth_weight ? reinterpret_cast<NV_TYPE *>(const_cast<T *>(
smooth_weight.get().data<T>()))
: nullptr,
sinks ? reinterpret_cast<NV_TYPE *>(
const_cast<T *>(sinks.get().data<T>()))
: nullptr,
reinterpret_cast<OUT_NV_TYPE *>(out->data<OutT>()),
quant_max_bound,
quant_min_bound,
@@ -1252,6 +1297,9 @@ void MultiQueryAppendC4Attention(
smooth_weight ? reinterpret_cast<NV_TYPE *>(const_cast<T *>(
smooth_weight.get().data<T>()))
: nullptr,
sinks ? reinterpret_cast<NV_TYPE *>(
const_cast<T *>(sinks.get().data<T>()))
: nullptr,
reinterpret_cast<OUT_NV_TYPE *>(out->data<OutT>()),
quant_max_bound,
quant_min_bound,
@@ -1361,6 +1409,9 @@ void MultiQueryAppendC4Attention(
smooth_weight ? reinterpret_cast<NV_TYPE *>(
const_cast<T *>(smooth_weight.get().data<T>()))
: nullptr,
sinks ? reinterpret_cast<NV_TYPE *>(
const_cast<T *>(sinks.get().data<T>()))
: nullptr,
seq_lens_q.data<int>(),
seq_lens_kv.data<int>(),
batch_ids.data<int>(),
@@ -1384,7 +1435,8 @@ void MultiQueryAppendC4Attention(
nullptr,
reinterpret_cast<OUT_NV_TYPE *>(out->data<OutT>()),
speculate_max_draft_token_num,
attn_mask_len);
attn_mask_len,
sliding_window);
} else {
phi::Allocator::AllocationPtr tmp_workspace, tmp_m, tmp_d;
if (is_decoder) {
@@ -1442,6 +1494,9 @@ void MultiQueryAppendC4Attention(
smooth_weight ? reinterpret_cast<NV_TYPE *>(
const_cast<T *>(smooth_weight.get().data<T>()))
: nullptr,
sinks ? reinterpret_cast<NV_TYPE *>(
const_cast<T *>(sinks.get().data<T>()))
: nullptr,
seq_lens_q.data<int>(),
seq_lens_kv.data<int>(),
batch_ids.data<int>(),
@@ -1465,7 +1520,8 @@ void MultiQueryAppendC4Attention(
static_cast<float *>(tmp_d->ptr()),
reinterpret_cast<OUT_NV_TYPE *>(out->data<OutT>()),
speculate_max_draft_token_num,
attn_mask_len);
attn_mask_len,
sliding_window);
// merge
constexpr int vec_size = num_elems_per_128b<NV_TYPE>();
if (is_decoder) {
@@ -1493,6 +1549,9 @@ void MultiQueryAppendC4Attention(
smooth_weight ? reinterpret_cast<NV_TYPE *>(const_cast<T *>(
smooth_weight.get().data<T>()))
: nullptr,
sinks ? reinterpret_cast<NV_TYPE *>(
const_cast<T *>(sinks.get().data<T>()))
: nullptr,
reinterpret_cast<OUT_NV_TYPE *>(out->data<OutT>()),
quant_max_bound,
quant_min_bound,
@@ -1529,6 +1588,9 @@ void MultiQueryAppendC4Attention(
smooth_weight ? reinterpret_cast<NV_TYPE *>(const_cast<T *>(
smooth_weight.get().data<T>()))
: nullptr,
sinks ? reinterpret_cast<NV_TYPE *>(
const_cast<T *>(sinks.get().data<T>()))
: nullptr,
reinterpret_cast<OUT_NV_TYPE *>(out->data<OutT>()),
quant_max_bound,
quant_min_bound,

View File

@@ -36,6 +36,7 @@ void MultiQueryAppendC4Attention(
const paddle::optional<paddle::Tensor> &cache_v_zp,
const paddle::optional<paddle::Tensor> &shift_bias,
const paddle::optional<paddle::Tensor> &smooth_weight,
const paddle::optional<paddle::Tensor> &sinks,
const paddle::Tensor &seq_lens_q,
const paddle::Tensor &seq_lens_kv,
const paddle::Tensor &seq_lens_encoder,
@@ -55,4 +56,5 @@ void MultiQueryAppendC4Attention(
const int speculate_max_draft_token_num,
const bool is_decoder,
cudaStream_t &stream,
paddle::Tensor *out);
paddle::Tensor *out,
const int sliding_window);

View File

@@ -42,6 +42,7 @@ __global__ void multi_query_append_attention_c8_kernel(
const T *__restrict__ cache_v_scale, // [num_kv_heads] or [max_block_num, num_heads, block_size]
const T *__restrict__ shift_bias, // [q_num_heads * HEAD_DIM]
const T *__restrict__ smooth_weight, // [q_num_heads * HEAD_DIM]
const T *__restrict__ sinks, // [q_num_heads]
const int *__restrict__ seq_lens,
const int *__restrict__ seq_lens_kv,
const int *__restrict__ batch_ids,
@@ -63,7 +64,8 @@ __global__ void multi_query_append_attention_c8_kernel(
float *__restrict__ tmp_m, // [token_num, num_chunks, num_heads]
float *__restrict__ tmp_d, // [token_num, num_chunks, num_heads]
OutT *__restrict__ out,
const int speculate_max_draft_token_num = 5) {
const int speculate_max_draft_token_num = 5,
const int sliding_window = 0) {
constexpr uint32_t num_vecs_per_head =
HEAD_DIM / num_elems_per_128b<T>(); // 128 / 8 = 16
constexpr uint32_t num_vecs_per_head_k =
@@ -321,7 +323,7 @@ __global__ void multi_query_append_attention_c8_kernel(
s_frag);
// mask according to kv_idx and q_idx
if (iter >= mask_check_iteration) {
if (iter >= mask_check_iteration || sliding_window > 0) {
mask_s<T,
partition_kv,
CAUSAL,
@@ -337,7 +339,8 @@ __global__ void multi_query_append_attention_c8_kernel(
chunk_end,
-1,
s_frag,
mask_offset_this_seq);
mask_offset_this_seq,
sliding_window);
}
// update m,d
@@ -415,7 +418,20 @@ __global__ void multi_query_append_attention_c8_kernel(
__syncthreads();
if constexpr (!partition_kv) {
normalize_d<num_frags_x, num_frags_y>(o_frag, d_frag);
if (sinks) {
float current_sinks[num_frags_x][2];
#pragma unroll
for (uint32_t fx = 0; fx < num_frags_x; ++fx) {
#pragma unroll
for (uint32_t j = 0; j < 2; ++j) {
const uint32_t h_offset = (q_base_seq_id_this_block + fx * 16 + tid / 4 + 8 * j) % GROUP_SIZE;
current_sinks[fx][j] = static_cast<float>(sinks[q_head_idx + h_offset]);
}
}
normalize_d<num_frags_x, num_frags_y>(o_frag, d_frag, m_frag, current_sinks);
} else {
normalize_d<num_frags_x, num_frags_y>(o_frag, d_frag);
}
}
// write o
@@ -516,6 +532,7 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel(
const T *__restrict__ cache_v_scale, // [num_kv_heads] or [max_block_num, num_heads, block_size]
const T *__restrict__ shift_bias, // [q_num_heads * HEAD_DIM]
const T *__restrict__ smooth_weight, // [q_num_heads * HEAD_DIM]
const T *__restrict__ sinks, // [q_num_heads]
const int *__restrict__ seq_lens,
const int *__restrict__ seq_lens_kv,
const int *__restrict__ batch_ids,
@@ -539,7 +556,8 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel(
float *__restrict__ tmp_d, // [token_num, num_chunks, num_heads]
OutT *__restrict__ out,
const int speculate_max_draft_token_num = 5,
const uint32_t attn_mask_len = -1) {
const uint32_t attn_mask_len = -1,
const int sliding_window = 0) {
constexpr uint32_t num_vecs_per_head = HEAD_DIM / num_elems_per_128b<T>();
constexpr uint32_t num_vecs_per_head_k =
HEAD_DIM / num_elems_per_128b<CacheT>();
@@ -798,7 +816,7 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel(
cache_k_scale_reg,
s_frag);
// mask according to kv_idx and q_idx
if (iter >= mask_check_iteration) {
if (iter >= mask_check_iteration || sliding_window > 0) {
mask_s<T,
partition_kv,
CAUSAL,
@@ -814,7 +832,8 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel(
chunk_end,
attn_mask_len,
s_frag,
mask_offset_this_seq);
mask_offset_this_seq,
sliding_window);
}
@@ -895,7 +914,20 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel(
o_frag, reinterpret_cast<float *>(smem), m_frag, d_frag, wid, tid);
if (num_chunks_this_seq <= 1) {
normalize_d<num_frags_x, num_frags_y>(o_frag, d_frag);
if (sinks) {
float current_sinks[num_frags_x][2];
#pragma unroll
for (uint32_t fx = 0; fx < num_frags_x; ++fx) {
#pragma unroll
for (uint32_t j = 0; j < 2; ++j) {
const uint32_t h_offset = (q_base_seq_id_this_block + fx * 16 + tid / 4 + 8 * j) % GROUP_SIZE;
current_sinks[fx][j] = static_cast<float>(sinks[q_head_idx + h_offset]);
}
}
normalize_d<num_frags_x, num_frags_y>(o_frag, d_frag, m_frag, current_sinks);
} else {
normalize_d<num_frags_x, num_frags_y>(o_frag, d_frag);
}
}
// write o
@@ -992,6 +1024,7 @@ void MultiQueryAppendC8Attention(
const paddle::Tensor &cache_v_scale,
const paddle::optional<paddle::Tensor> &shift_bias,
const paddle::optional<paddle::Tensor> &smooth_weight,
const paddle::optional<paddle::Tensor> &sinks,
const paddle::Tensor &seq_lens_q,
const paddle::Tensor &seq_lens_kv,
const paddle::Tensor &seq_lens_encoder,
@@ -1011,7 +1044,8 @@ void MultiQueryAppendC8Attention(
const int speculate_max_draft_token_num,
const bool is_decoder,
cudaStream_t &stream,
paddle::Tensor *out) {
paddle::Tensor *out,
const int sliding_window) {
using NV_TYPE = typename cascade_attn_type_traits<T>::type;
using OUT_NV_TYPE = typename cascade_attn_type_traits<OutT>::type;
@@ -1155,6 +1189,9 @@ void MultiQueryAppendC8Attention(
smooth_weight ? reinterpret_cast<NV_TYPE *>(
const_cast<T *>(smooth_weight.get().data<T>()))
: nullptr,
sinks ? reinterpret_cast<NV_TYPE *>(
const_cast<T *>(sinks.get().data<T>()))
: nullptr,
seq_lens_q.data<int>(),
seq_lens_kv.data<int>(),
batch_ids.data<int>(),
@@ -1175,7 +1212,8 @@ void MultiQueryAppendC8Attention(
nullptr,
nullptr,
reinterpret_cast<OUT_NV_TYPE *>(out->data<OutT>()),
speculate_max_draft_token_num);
speculate_max_draft_token_num,
sliding_window);
} else {
phi::Allocator::AllocationPtr tmp_workspace, tmp_m, tmp_d;
if (ENABLE_PREFILL) {
@@ -1214,6 +1252,9 @@ void MultiQueryAppendC8Attention(
smooth_weight ? reinterpret_cast<NV_TYPE *>(
const_cast<T *>(smooth_weight.get().data<T>()))
: nullptr,
sinks ? reinterpret_cast<NV_TYPE *>(
const_cast<T *>(sinks.get().data<T>()))
: nullptr,
seq_lens_q.data<int>(),
seq_lens_kv.data<int>(),
batch_ids.data<int>(),
@@ -1234,7 +1275,8 @@ void MultiQueryAppendC8Attention(
static_cast<float *>(tmp_m->ptr()),
static_cast<float *>(tmp_d->ptr()),
reinterpret_cast<OUT_NV_TYPE *>(out->data<OutT>()),
speculate_max_draft_token_num);
speculate_max_draft_token_num,
sliding_window);
// merge
constexpr int vec_size = num_elems_per_128b<NV_TYPE>();
if (is_decoder) {
@@ -1262,6 +1304,9 @@ void MultiQueryAppendC8Attention(
smooth_weight ? reinterpret_cast<NV_TYPE *>(const_cast<T *>(
smooth_weight.get().data<T>()))
: nullptr,
sinks ? reinterpret_cast<NV_TYPE *>(
const_cast<T *>(sinks.get().data<T>()))
: nullptr,
reinterpret_cast<OUT_NV_TYPE *>(out->data<OutT>()),
quant_max_bound,
quant_min_bound,
@@ -1298,6 +1343,9 @@ void MultiQueryAppendC8Attention(
smooth_weight ? reinterpret_cast<NV_TYPE *>(const_cast<T *>(
smooth_weight.get().data<T>()))
: nullptr,
sinks ? reinterpret_cast<NV_TYPE *>(
const_cast<T *>(sinks.get().data<T>()))
: nullptr,
reinterpret_cast<OUT_NV_TYPE *>(out->data<OutT>()),
quant_max_bound,
quant_min_bound,
@@ -1439,6 +1487,9 @@ void MultiQueryAppendC8Attention(
smooth_weight ? reinterpret_cast<NV_TYPE *>(
const_cast<T *>(smooth_weight.get().data<T>()))
: nullptr,
sinks ? reinterpret_cast<NV_TYPE *>(
const_cast<T *>(sinks.get().data<T>()))
: nullptr,
seq_lens_q.data<int>(),
seq_lens_kv.data<int>(),
batch_ids.data<int>(),
@@ -1462,7 +1513,8 @@ void MultiQueryAppendC8Attention(
nullptr,
reinterpret_cast<OUT_NV_TYPE *>(out->data<OutT>()),
speculate_max_draft_token_num,
attn_mask_len);
attn_mask_len,
sliding_window);
} else {
phi::Allocator::AllocationPtr tmp_workspace, tmp_m, tmp_d;
if (is_decoder) {
@@ -1514,6 +1566,9 @@ void MultiQueryAppendC8Attention(
smooth_weight ? reinterpret_cast<NV_TYPE *>(
const_cast<T *>(smooth_weight.get().data<T>()))
: nullptr,
sinks ? reinterpret_cast<NV_TYPE *>(
const_cast<T *>(sinks.get().data<T>()))
: nullptr,
seq_lens_q.data<int>(),
seq_lens_kv.data<int>(),
batch_ids.data<int>(),
@@ -1537,7 +1592,8 @@ void MultiQueryAppendC8Attention(
static_cast<float *>(tmp_d->ptr()),
reinterpret_cast<OUT_NV_TYPE *>(out->data<OutT>()),
speculate_max_draft_token_num,
attn_mask_len);
attn_mask_len,
sliding_window);
// merge
constexpr int vec_size = num_elems_per_128b<NV_TYPE>();
if (is_decoder) {
@@ -1560,6 +1616,9 @@ void MultiQueryAppendC8Attention(
smooth_weight ? reinterpret_cast<NV_TYPE *>(const_cast<T *>(
smooth_weight.get().data<T>()))
: nullptr,
sinks ? reinterpret_cast<NV_TYPE *>(
const_cast<T *>(sinks.get().data<T>()))
: nullptr,
reinterpret_cast<OUT_NV_TYPE *>(out->data<OutT>()),
quant_max_bound,
quant_min_bound,
@@ -1596,6 +1655,9 @@ void MultiQueryAppendC8Attention(
smooth_weight ? reinterpret_cast<NV_TYPE *>(const_cast<T *>(
smooth_weight.get().data<T>()))
: nullptr,
sinks ? reinterpret_cast<NV_TYPE *>(
const_cast<T *>(sinks.get().data<T>()))
: nullptr,
reinterpret_cast<OUT_NV_TYPE *>(out->data<OutT>()),
quant_max_bound,
quant_min_bound,

View File

@@ -36,6 +36,7 @@ void MultiQueryAppendC8Attention(
const paddle::Tensor &cache_v_scale,
const paddle::optional<paddle::Tensor> &shift_bias,
const paddle::optional<paddle::Tensor> &smooth_weight,
const paddle::optional<paddle::Tensor> &sinks,
const paddle::Tensor &seq_lens_q,
const paddle::Tensor &seq_lens_kv,
const paddle::Tensor &seq_lens_encoder,
@@ -55,4 +56,5 @@ void MultiQueryAppendC8Attention(
const int speculate_max_draft_token_num,
const bool is_decoder,
cudaStream_t &stream,
paddle::Tensor *out);
paddle::Tensor *out,
const int sliding_window);

View File

@@ -36,7 +36,7 @@
],
"max_instances_per_file": 80,
"file_prefix": "multiquery_attention_c8_",
"function_signature": "template void {function_name}{template_args}(\n const AppendAttnMetaData &meta_data,\n const paddle::Tensor &qkv,\n const paddle::Tensor &cache_k,\n const paddle::Tensor &cache_v,\n const paddle::optional<paddle::Tensor> &attn_mask,\n const paddle::Tensor &cache_k_scale,\n const paddle::Tensor &cache_v_scale,\n const paddle::optional<paddle::Tensor> &shift_bias,\n const paddle::optional<paddle::Tensor> &smooth_weight,\n const paddle::Tensor &seq_lens_q,\n const paddle::Tensor &seq_lens_kv,\n const paddle::Tensor &seq_lens_encoder,\n const paddle::Tensor &batch_id_per_token,\n const paddle::Tensor &cu_seqlens_q,\n const paddle::Tensor &block_table,\n const paddle::Tensor &batch_ids,\n const paddle::Tensor &tile_ids_per_batch,\n const int num_blocks_x_cpu,\n const int max_seq_len,\n const int max_dec_len,\n const float quant_max_bound,\n const float quant_min_bound,\n const float in_scale,\n const int max_partition_size,\n const int encoder_max_partition_size,\n const int speculate_max_draft_token_num,\n const bool is_decoder,\n cudaStream_t &stream,\n paddle::Tensor *out);\n\n"
"function_signature": "template void {function_name}{template_args}(\n const AppendAttnMetaData &meta_data,\n const paddle::Tensor &qkv,\n const paddle::Tensor &cache_k,\n const paddle::Tensor &cache_v,\n const paddle::optional<paddle::Tensor> &attn_mask,\n const paddle::Tensor &cache_k_scale,\n const paddle::Tensor &cache_v_scale,\n const paddle::optional<paddle::Tensor> &shift_bias,\n const paddle::optional<paddle::Tensor> &smooth_weight,\n const paddle::optional<paddle::Tensor> &sinks,\n const paddle::Tensor &seq_lens_q,\n const paddle::Tensor &seq_lens_kv,\n const paddle::Tensor &seq_lens_encoder,\n const paddle::Tensor &batch_id_per_token,\n const paddle::Tensor &cu_seqlens_q,\n const paddle::Tensor &block_table,\n const paddle::Tensor &batch_ids,\n const paddle::Tensor &tile_ids_per_batch,\n const int num_blocks_x_cpu,\n const int max_seq_len,\n const int max_dec_len,\n const float quant_max_bound,\n const float quant_min_bound,\n const float in_scale,\n const int max_partition_size,\n const int encoder_max_partition_size,\n const int speculate_max_draft_token_num,\n const bool is_decoder,\n cudaStream_t &stream,\n paddle::Tensor *out,\n const int sliding_window);\n\n"
},
"multiquery_attention_c4": {
"name": "multiquery_attention_c4",
@@ -71,7 +71,7 @@
],
"max_instances_per_file": 160,
"file_prefix": "multiquery_attention_c4_",
"function_signature": "template void {function_name}{template_args}(\n const AppendAttnMetaData &meta_data,\n const paddle::Tensor &qkv,\n const paddle::Tensor &cache_k,\n const paddle::Tensor &cache_v,\n const paddle::optional<paddle::Tensor> &attn_mask,\n const paddle::Tensor &cache_k_scale,\n const paddle::Tensor &cache_v_scale,\n const paddle::optional<paddle::Tensor> &cache_k_zp,\n const paddle::optional<paddle::Tensor> &cache_v_zp,\n const paddle::optional<paddle::Tensor> &shift_bias,\n const paddle::optional<paddle::Tensor> &smooth_weight,\n const paddle::Tensor &seq_lens_q,\n const paddle::Tensor &seq_lens_kv,\n const paddle::Tensor &seq_lens_encoder,\n const paddle::Tensor &batch_id_per_token,\n const paddle::Tensor &cu_seqlens_q,\n const paddle::Tensor &block_table,\n const paddle::Tensor &batch_ids,\n const paddle::Tensor &tile_ids_per_batch,\n const int num_blocks_x_cpu,\n const int max_seq_len,\n const int max_dec_len,\n const float quant_max_bound,\n const float quant_min_bound,\n const float in_scale,\n const int max_partition_size,\n const int encoder_max_partition_size,\n const int speculate_max_draft_token_num,\n const bool is_decoder,\n cudaStream_t &stream,\n paddle::Tensor *out);\n\n"
"function_signature": "template void {function_name}{template_args}(\n const AppendAttnMetaData &meta_data,\n const paddle::Tensor &qkv,\n const paddle::Tensor &cache_k,\n const paddle::Tensor &cache_v,\n const paddle::optional<paddle::Tensor> &attn_mask,\n const paddle::Tensor &cache_k_scale,\n const paddle::Tensor &cache_v_scale,\n const paddle::optional<paddle::Tensor> &cache_k_zp,\n const paddle::optional<paddle::Tensor> &cache_v_zp,\n const paddle::optional<paddle::Tensor> &shift_bias,\n const paddle::optional<paddle::Tensor> &smooth_weight,\n const paddle::optional<paddle::Tensor> &sinks,\n const paddle::Tensor &seq_lens_q,\n const paddle::Tensor &seq_lens_kv,\n const paddle::Tensor &seq_lens_encoder,\n const paddle::Tensor &batch_id_per_token,\n const paddle::Tensor &cu_seqlens_q,\n const paddle::Tensor &block_table,\n const paddle::Tensor &batch_ids,\n const paddle::Tensor &tile_ids_per_batch,\n const int num_blocks_x_cpu,\n const int max_seq_len,\n const int max_dec_len,\n const float quant_max_bound,\n const float quant_min_bound,\n const float in_scale,\n const int max_partition_size,\n const int encoder_max_partition_size,\n const int speculate_max_draft_token_num,\n const bool is_decoder,\n cudaStream_t &stream,\n paddle::Tensor *out,\n const int sliding_window);\n\n"
},
"multiquery_attention_c16": {
"name": "multiquery_attention_c16",
@@ -90,7 +90,7 @@
],
"dispatch_params": {
"GROUP_SIZE": [1, 2, 4, 5, 6, 7, 8, 12, 14, 16],
"HEAD_DIM": [128],
"HEAD_DIM": [64,128],
"BLOCK_SIZE": [64],
"CAUSAL": [0, 1],
"BLOCK_SHAPE_Q": [16, 32, 64, 128],
@@ -106,7 +106,7 @@
],
"max_instances_per_file": 160,
"file_prefix": "multiquery_attention_c16_",
"function_signature": "template void {function_name}{template_args}(\n const AppendAttnMetaData &meta_data,\n const paddle::Tensor &qkv,\n const paddle::Tensor &cache_k,\n const paddle::Tensor &cache_v,\n const paddle::optional<paddle::Tensor> &attn_mask,\n const paddle::optional<paddle::Tensor> &shift_bias,\n const paddle::optional<paddle::Tensor> &smooth_weight,\n const paddle::Tensor &seq_lens_q,\n const paddle::Tensor &seq_lens_kv,\n const paddle::Tensor &seq_lens_encoder,\n const paddle::Tensor &batch_id_per_token,\n const paddle::Tensor &cu_seqlens_q,\n const paddle::Tensor &block_table,\n const paddle::Tensor &batch_ids,\n const paddle::Tensor &tile_ids_per_batch,\n const int num_blocks_x_cpu,\n const int max_seq_len,\n const int max_dec_len,\n const float quant_max_bound,\n const float quant_min_bound,\n const float in_scale,\n const int max_partition_size,\n const int encoder_max_partition_size,\n const int speculate_max_draft_token_num,\n const bool is_decoder,\n cudaStream_t &stream,\n paddle::Tensor *out);\n\n"
"function_signature": "template void {function_name}{template_args}(\n const AppendAttnMetaData &meta_data,\n const paddle::Tensor &qkv,\n const paddle::Tensor &cache_k,\n const paddle::Tensor &cache_v,\n const paddle::optional<paddle::Tensor> &attn_mask,\n const paddle::optional<paddle::Tensor> &shift_bias,\n const paddle::optional<paddle::Tensor> &smooth_weight,\n const paddle::optional<paddle::Tensor> &sinks,\n const paddle::Tensor &seq_lens_q,\n const paddle::Tensor &seq_lens_kv,\n const paddle::Tensor &seq_lens_encoder,\n const paddle::Tensor &batch_id_per_token,\n const paddle::Tensor &cu_seqlens_q,\n const paddle::Tensor &block_table,\n const paddle::Tensor &batch_ids,\n const paddle::Tensor &tile_ids_per_batch,\n const int num_blocks_x_cpu,\n const int max_seq_len,\n const int max_dec_len,\n const float quant_max_bound,\n const float quant_min_bound,\n const float in_scale,\n const int max_partition_size,\n const int encoder_max_partition_size,\n const int speculate_max_draft_token_num,\n const bool is_decoder,\n cudaStream_t &stream,\n paddle::Tensor *out,\n const int sliding_window);\n\n"
},
"multiquery_decoder_attention": {
"name": "multiquery_decoder_attention",

View File

@@ -301,6 +301,11 @@ __forceinline__ __host__ __device__ void vec_cast<nv_bfloat16, float>(
#define DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, ...) \
switch (head_dim) { \
case 64: { \
constexpr size_t HEAD_DIM = 64; \
__VA_ARGS__ \
break; \
} \
case 128: { \
constexpr size_t HEAD_DIM = 128; \
__VA_ARGS__ \

View File

@@ -81,6 +81,7 @@ std::vector<paddle::Tensor> AppendAttention(
const paddle::optional<paddle::Tensor> &kv_signal_data,
const paddle::optional<paddle::Tensor>& q_norm_weight,
const paddle::optional<paddle::Tensor>& k_norm_weight,
const paddle::optional<paddle::Tensor>& sinks,
const float rms_norm_eps,
const std::string &compute_dtype, const std::string &cache_quant_type_str,
const bool use_neox_rotary_style, const bool rope_3d,
@@ -89,7 +90,8 @@ std::vector<paddle::Tensor> AppendAttention(
const int encoder_block_shape_q, const int decoder_block_shape_q,
const int max_partition_size, const int encoder_max_partition_size,
const int speculate_max_draft_token_num, const bool causal,
const bool speculate_decoder);
const bool speculate_decoder,
const int sliding_window);
std::vector<paddle::Tensor> AppendAttentionWithOutput(
const paddle::Tensor &qkv, const paddle::Tensor &key_cache,
@@ -124,6 +126,7 @@ std::vector<paddle::Tensor> AppendAttentionWithOutput(
const paddle::optional<paddle::Tensor> &kv_signal_data,
const paddle::optional<paddle::Tensor>& q_norm_weight,
const paddle::optional<paddle::Tensor>& k_norm_weight,
const paddle::optional<paddle::Tensor>& sinks,
const float rms_norm_eps,
const std::string &compute_dtype, const std::string &cache_quant_type_str,
const bool use_neox_rotary_style, const bool rope_3d,
@@ -132,7 +135,8 @@ std::vector<paddle::Tensor> AppendAttentionWithOutput(
const int encoder_block_shape_q, const int decoder_block_shape_q,
const int max_partition_size, const int encoder_max_partition_size,
const int speculate_max_draft_token_num, const bool causal,
const bool speculate_decoder);
const bool speculate_decoder,
const int sliding_window);
std::vector<paddle::Tensor> GQARopeWriteCacheKernel(
const paddle::Tensor &qkv, const paddle::Tensor &key_cache,
@@ -248,15 +252,18 @@ std::vector<std::vector<int>> GetExpertTokenNum(const paddle::Tensor &topk_ids,
paddle::Tensor MoeExpertFFNFunc(
const paddle::Tensor& permute_input,
const paddle::Tensor& tokens_expert_prefix_sum,
const paddle::Tensor& up_gate_proj_weight, const paddle::Tensor& down_proj_weight,
const paddle::Tensor& up_gate_proj_weight,
const paddle::Tensor& down_proj_weight,
const paddle::optional<paddle::Tensor>& up_gate_proj_bias,
const paddle::optional<paddle::Tensor>& up_gate_proj_scale,
const paddle::optional<paddle::Tensor>& down_proj_scale,
const paddle::optional<paddle::Tensor>& down_proj_in_scale,
const paddle::optional<paddle::Tensor>& expert_idx_per_token,
const std::string& quant_method, const bool used_in_ep_low_latency,
const std::string& quant_method,
const bool used_in_ep_low_latency,
const int estimate_total_token_nums,
const int hadamard_block_size);
const int hadamard_block_size,
const std::string& activation);
paddle::Tensor MoeExpertFFNWint2Func(
const paddle::Tensor& permute_input,

View File

@@ -21,6 +21,7 @@
#include "moe/fast_hardamard_kernel.h"
#include "moe/fused_moe_helper.h"
#include "w4afp8_gemm/w4afp8_gemm.h"
#include "swigluoai.h"
template <paddle::DataType T>
void MoeFFNKernel(const paddle::Tensor& permute_input,
@@ -36,7 +37,8 @@ void MoeFFNKernel(const paddle::Tensor& permute_input,
paddle::Tensor ffn_out,
bool used_in_ep_low_latency,
const int estimate_total_token_nums,
const int hadamard_block_size) {
const int hadamard_block_size,
const std::string& activation) {
using namespace phi;
typedef PDTraits<T> traits_;
typedef typename traits_::DataType DataType_;
@@ -233,8 +235,13 @@ void MoeFFNKernel(const paddle::Tensor& permute_input,
if (used_in_ep_low_latency) {
act_out_tensor = GroupSwigluWithMasked(fc1_out_tensor, tokens_expert_prefix_sum);
} else {
act_out_tensor = paddle::experimental::swiglu(fc1_out_tensor, nullptr);
if (activation == "swigluoai") {
act_out_tensor = SwigluOAI(fc1_out_tensor, 1.702, 7.0, "interleave");
} else {
act_out_tensor = paddle::experimental::swiglu(fc1_out_tensor, nullptr);
}
}
auto act_out = act_out_tensor.data<data_t>();
if (quant_method == "weight_only_int8") {
typename cutlass::WintQuantTraits<DataType_, cutlass::WintQuantMethod::kWeightOnlyInt8>::Arguments quant_args;
@@ -405,8 +412,10 @@ paddle::Tensor MoeExpertFFNFunc(
const paddle::optional<paddle::Tensor>& down_proj_scale,
const paddle::optional<paddle::Tensor>& down_proj_in_scale,
const paddle::optional<paddle::Tensor>& expert_idx_per_token,
const std::string& quant_method, const bool used_in_ep_low_latency,
const int estimate_total_token_nums, const int hadamard_block_size) {
const std::string& quant_method,
const bool used_in_ep_low_latency,
const int estimate_total_token_nums, const int hadamard_block_size,
const std::string& activation) {
const auto t_type = (quant_method == "w4a8") ? up_gate_proj_scale.get().dtype() :
(quant_method == "w4afp8") ? paddle::DataType::BFLOAT16 :
@@ -430,7 +439,8 @@ const auto t_type = (quant_method == "w4a8") ? up_gate_proj_scale.get().dtype()
ffn_out,
used_in_ep_low_latency,
estimate_total_token_nums,
hadamard_block_size);
hadamard_block_size,
activation);
break;
case paddle::DataType::FLOAT16:
MoeFFNKernel<paddle::DataType::FLOAT16>(permute_input,
@@ -446,7 +456,8 @@ const auto t_type = (quant_method == "w4a8") ? up_gate_proj_scale.get().dtype()
ffn_out,
used_in_ep_low_latency,
estimate_total_token_nums,
hadamard_block_size);
hadamard_block_size,
activation);
break;
default:
PD_THROW("Unsupported data type for MoeExpertFFN");
@@ -466,7 +477,8 @@ std::vector<paddle::Tensor> MoeExpertFFN(
const paddle::optional<paddle::Tensor>& expert_idx_per_token,
const std::string& quant_method, const bool used_in_ep_low_latency,
const int estimate_total_token_nums,
const int hadamard_block_size) {
const int hadamard_block_size,
const std::string& activation) {
return {MoeExpertFFNFunc(permute_input,
tokens_expert_prefix_sum,
up_gate_proj_weight,
@@ -479,7 +491,8 @@ std::vector<paddle::Tensor> MoeExpertFFN(
quant_method,
used_in_ep_low_latency,
estimate_total_token_nums,
hadamard_block_size)};
hadamard_block_size,
activation)};
}
std::vector<std::vector<int64_t>> MoeExpertFFNInferShape(
@@ -495,7 +508,8 @@ std::vector<std::vector<int64_t>> MoeExpertFFNInferShape(
const std::string& quant_method,
const bool used_in_ep_low_latency,
const int estimate_total_token_nums,
const int hadamard_block_size) {
const int hadamard_block_size,
const std::string& activation) {
return {permute_input_shape};
}
@@ -509,7 +523,8 @@ std::vector<paddle::DataType> MoeExpertFFNInferDtype(
const paddle::optional<paddle::DataType> &down_proj_scale_dtype,
const paddle::optional<paddle::DataType> &down_proj_in_scale_dtype,
const std::string &quant_method, const bool used_in_ep_low_latency,
const int estimate_total_token_nums, const int hadamard_block_size) {
const int estimate_total_token_nums, const int hadamard_block_size,
const std::string &activation) {
if (quant_method == "w4a8" || quant_method == "w4afp8") {
return {up_gate_proj_scale_dtype.get()};
} else {
@@ -583,7 +598,7 @@ PD_BUILD_STATIC_OP(moe_expert_ffn)
paddle::Optional("down_proj_in_scale"),
paddle::Optional("expert_idx_per_token")})
.Outputs({"output_tensor"})
.Attrs({"quant_method:std::string", "used_in_ep_low_latency:bool", "estimate_total_token_nums:int", "hadamard_block_size:int"})
.Attrs({"quant_method:std::string", "used_in_ep_low_latency:bool", "estimate_total_token_nums:int", "hadamard_block_size:int", "activation:std::string"})
.SetKernelFn(PD_KERNEL(MoeExpertFFN))
.SetInferShapeFn(PD_INFER_SHAPE(MoeExpertFFNInferShape))
.SetInferDtypeFn(PD_INFER_DTYPE(MoeExpertFFNInferDtype));

View File

@@ -0,0 +1,189 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "helper.h"
#include "swigluoai.h"
#pragma once
// dim3 grid(256)
// dim3 block(512)
//
// GPT-OSS style SwiGLU, "interleave" layout: the last dim of `input` stores
// (gate, linear) pairs interleaved as [a0, b0, a1, b1, ...].
// Per pair:
//   a = min(a, limit)                      // gate clamped from above only
//   b = clamp(b, -limit, limit)
//   out = (b + 1) * a * sigmoid(alpha * a)
//
// input  : [seq_len, hidden_dim * 2]
// act_out: [seq_len, hidden_dim]
// Grid-stride loop; each iteration a thread consumes two consecutive input
// vectors (2 * VecSize values = VecSize pairs) and emits one VecSize-wide
// output vector. VecSize must be even.
template <typename T, int VecSize>
__global__ void swigluoai_interleave_kernel(T* act_out,
                                            const T* input,
                                            const float alpha,
                                            const float limit,
                                            const int64_t seq_len,
                                            const int64_t hidden_dim) {
  const int64_t tid = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  using LoadT = AlignedVector<T, VecSize>;
  LoadT src_vec0, src_vec1;
  LoadT res_vec;
  const int64_t col_size = hidden_dim / VecSize;  // output vectors per row
  const int64_t vec_num = col_size * seq_len;     // total output vectors
  const int64_t stride = static_cast<int64_t>(gridDim.x) * blockDim.x;
  const int64_t times = (vec_num - 1) / stride + 1;
  for (int64_t i = 0; i < times; ++i) {
    const int64_t index = tid + i * stride;
    const int64_t row = index / col_size;
    const int64_t col = index % col_size;
    if (row >= seq_len || col >= col_size) {
      continue;
    }
    // Two adjacent input vectors hold the 2*VecSize interleaved values that
    // feed this output vector.
    Load<T, VecSize>(&input[row * hidden_dim * 2 + col * VecSize * 2], &src_vec0);
    Load<T, VecSize>(&input[row * hidden_dim * 2 + col * VecSize * 2 + VecSize],
                     &src_vec1);
#pragma unroll
    for (int j = 0; j < VecSize / 2; ++j) {
      float a = static_cast<float>(src_vec0[2 * j]);
      float b = static_cast<float>(src_vec0[2 * j + 1]);
      a = fminf(a, limit);
      b = fminf(fmaxf(b, -limit), limit);
      // (b + 1) * a * sigmoid(alpha * a)
      res_vec[j] = static_cast<T>((b + 1.f) * a / (1.f + expf(-a * alpha)));
    }
#pragma unroll
    for (int j = 0; j < VecSize / 2; ++j) {
      float a = static_cast<float>(src_vec1[2 * j]);
      float b = static_cast<float>(src_vec1[2 * j + 1]);
      a = fminf(a, limit);
      b = fminf(fmaxf(b, -limit), limit);
      res_vec[j + VecSize / 2] =
          static_cast<T>((b + 1.f) * a / (1.f + expf(-a * alpha)));
    }
    Store<T, VecSize>(res_vec, &act_out[row * hidden_dim + col * VecSize]);
  }
}
// dim3 grid(256)
// dim3 block(512)
//
// SwiGLU, "norm" (split) layout: gate values occupy the first half of the
// last dim of `input`, linear values the second half.
// Per element:
//   z   = clamp(a * alpha, -limit, limit)
//   out = b * a * sigmoid(z)
// NOTE(review): unlike the interleave kernel, this variant clamps the
// pre-sigmoid product a*alpha (not a and b individually) and does not add 1
// to b — confirm this asymmetry is intentional.
//
// input  : [seq_len, hidden_dim * 2]
// act_out: [seq_len, hidden_dim]
// Grid-stride loop; each thread handles VecSize outputs per iteration.
template <typename T, int VecSize>
__global__ void swigluoai_norm_kernel(T* act_out,
                                      const T* input,
                                      const float alpha,
                                      const float limit,
                                      const int64_t seq_len,
                                      const int64_t hidden_dim) {
  const int64_t tid = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  using LoadT = AlignedVector<T, VecSize>;
  LoadT src_vec0, src_vec1;
  LoadT res_vec;
  const int64_t col_size = hidden_dim / VecSize;  // output vectors per row
  const int64_t vec_num = col_size * seq_len;     // total output vectors
  const int64_t stride = static_cast<int64_t>(gridDim.x) * blockDim.x;
  const int64_t times = (vec_num - 1) / stride + 1;
  for (int64_t i = 0; i < times; ++i) {
    const int64_t index = tid + i * stride;
    const int64_t row = index / col_size;
    const int64_t col = index % col_size;
    if (row >= seq_len || col >= col_size) {
      continue;
    }
    // Gate from the first half of the row, linear from the second half.
    Load<T, VecSize>(&input[row * hidden_dim * 2 + col * VecSize], &src_vec0);
    Load<T, VecSize>(&input[row * hidden_dim * 2 + hidden_dim + col * VecSize],
                     &src_vec1);
#pragma unroll
    for (int j = 0; j < VecSize; ++j) {
      const float a = static_cast<float>(src_vec0[j]);
      const float b = static_cast<float>(src_vec1[j]);
      const float z = fminf(fmaxf(a * alpha, -limit), limit);
      res_vec[j] = static_cast<T>(b * a / (1.f + expf(-z)));
    }
    Store<T, VecSize>(res_vec, &act_out[row * hidden_dim + col * VecSize]);
  }
}
// Applies the GPT-OSS SwiGLU activation to a fused FFN1 output.
//
// fc1_out_tensor: [seq_len, hidden_dim * 2], BFLOAT16 only (PD_CHECK-ed).
// alpha/limit   : activation slope and clamp bound (see kernel headers).
// type          : "interleave" selects the interleaved (a,b)-pair layout;
//                 any other value selects the split first/second-half layout.
// Returns a new [seq_len, hidden_dim] tensor on the same place/stream.
//
// Kernels are launched on the input tensor's stream with a fixed
// 256x512 grid-stride configuration; hidden_dim must be a multiple of the
// vector width (8).
paddle::Tensor SwigluOAI(const paddle::Tensor &fc1_out_tensor, const float alpha,
                         const float limit, const std::string &type) {
  const int64_t seq_len = fc1_out_tensor.shape()[0];
  // Input packs gate and linear halves, so output width is half the input's.
  const int64_t hidden_dim = fc1_out_tensor.shape()[1] / 2;
  auto act_out_tensor = GetEmptyTensor(
      {seq_len, hidden_dim}, fc1_out_tensor.dtype(), fc1_out_tensor.place());

  constexpr int VecSize = 8;
  PD_CHECK(fc1_out_tensor.dtype() == paddle::DataType::BFLOAT16);
  PD_CHECK(hidden_dim % VecSize == 0);

  constexpr paddle::DataType D = paddle::DataType::BFLOAT16;
  typedef PDTraits<D> traits_;
  typedef typename traits_::DataType DataType_;
  typedef typename traits_::data_t data_t;

  const int block_size = 512;
  const int grid_size = 256;
  auto *out_ptr = reinterpret_cast<DataType_ *>(
      const_cast<data_t *>(act_out_tensor.data<data_t>()));
  const auto *in_ptr =
      reinterpret_cast<const DataType_ *>(fc1_out_tensor.data<data_t>());
  if (type == "interleave") {
    swigluoai_interleave_kernel<DataType_, VecSize>
        <<<grid_size, block_size, 0, fc1_out_tensor.stream()>>>(
            out_ptr, in_ptr, alpha, limit, seq_len, hidden_dim);
  } else {
    swigluoai_norm_kernel<DataType_, VecSize>
        <<<grid_size, block_size, 0, fc1_out_tensor.stream()>>>(
            out_ptr, in_ptr, alpha, limit, seq_len, hidden_dim);
  }
  return act_out_tensor;
}
// Custom-op entry point: forwards to SwigluOAI and wraps its single output
// tensor in the vector shape the Paddle op framework expects.
std::vector<paddle::Tensor> SwigluOAIWrapper(const paddle::Tensor &fc1_out_tensor,
                                             const float alpha,
                                             const float limit,
                                             const std::string &type) {
  auto activated = SwigluOAI(fc1_out_tensor, alpha, limit, type);
  return {activated};
}
// Registers the `swigluoai` static custom op: one input tensor, three
// attributes (alpha, limit, layout type), one output; kernel is
// SwigluOAIWrapper above.
PD_BUILD_STATIC_OP(swigluoai)
.Inputs({"fc1_out_tensor"})
.Attrs({"alpha: float", "limit: float", "type: std::string"})
.Outputs({"output_tensor"})
.SetKernelFn(PD_KERNEL(SwigluOAIWrapper));

View File

@@ -0,0 +1,19 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "helper.h"
// Applies the GPT-OSS SwiGLU activation to a fused FFN1 output tensor of
// shape [seq_len, hidden_dim * 2] and returns a [seq_len, hidden_dim] result.
// `type` == "interleave" selects the interleaved (gate, linear) pair layout;
// any other value selects the split first/second-half layout.
// Defined in swigluoai.cu; BFLOAT16 input only.
paddle::Tensor
SwigluOAI(const paddle::Tensor &fc1_out_tensor, const float alpha, const float limit, const std::string& type);

View File

@@ -170,7 +170,10 @@ std::vector<paddle::Tensor> tritonmoe_preprocess_kernel(const paddle::Tensor& to
run_align_kernel(128);
} else if (num_experts == 160) {
run_align_kernel(160);
} else {
} else if (num_experts == 32) {
run_align_kernel(32);
}
else {
PD_THROW("Not support num_experts: %d", num_experts);
}