mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-12-24 13:28:13 +08:00
Support GPT-OSS-BF16 (#4240)
* [Feature] AppendAtten support sinks & HEAD_DIM=64 * fix bug * fix bug * fix bug * fix bug * [Feature] support gpt-oss * fix bug * add mask * support-gpt-oss * support-gpt-oss * fix long seq * support wint8 * support wint8 * support wint8 * update test * change sliding windows init pos --------- Co-authored-by: ming1753 <ideaminghp@163.com> Co-authored-by: YuBaoku <49938469+EmmonsCurse@users.noreply.github.com> Co-authored-by: ming1753 <61511741+ming1753@users.noreply.github.com>
This commit is contained in:
@@ -72,10 +72,10 @@ void AppendAttentionKernel(
|
||||
const paddle::optional<paddle::Tensor>& cache_v_zp,
|
||||
const paddle::optional<paddle::Tensor>& out_linear_shifts,
|
||||
const paddle::optional<paddle::Tensor>& out_linear_smooths,
|
||||
const paddle::optional<paddle::Tensor>& mask_offset,
|
||||
const paddle::optional<paddle::Tensor>& kv_signal_data,
|
||||
const paddle::optional<paddle::Tensor>& q_norm_weight,
|
||||
const paddle::optional<paddle::Tensor>& k_norm_weight,
|
||||
const paddle::optional<paddle::Tensor>& sinks,
|
||||
const float rms_norm_eps,
|
||||
const std::string& cache_quant_type_str,
|
||||
const bool use_neox_rotary_style,
|
||||
@@ -90,7 +90,8 @@ void AppendAttentionKernel(
|
||||
const int encoder_max_partition_size,
|
||||
const int speculate_max_draft_token_num,
|
||||
const bool causal,
|
||||
const bool speculate_decoder) {
|
||||
const bool speculate_decoder,
|
||||
const int sliding_window) {
|
||||
typedef PDTraits<D> traits_;
|
||||
typedef typename traits_::DataType DataType_;
|
||||
typedef typename traits_::data_t data_t;
|
||||
@@ -146,6 +147,7 @@ void AppendAttentionKernel(
|
||||
cache_v_zp,
|
||||
out_linear_shifts,
|
||||
out_linear_smooths,
|
||||
sinks,
|
||||
seq_lens_this_time,
|
||||
seq_lens_decoder,
|
||||
seq_lens_encoder,
|
||||
@@ -169,7 +171,8 @@ void AppendAttentionKernel(
|
||||
lambda_is_decoder,
|
||||
lambda_enable_prefill,
|
||||
lambda_stream,
|
||||
&fmha_out);
|
||||
&fmha_out,
|
||||
sliding_window);
|
||||
};
|
||||
|
||||
if (max_enc_len_this_time > 0) {
|
||||
@@ -428,6 +431,7 @@ std::vector<paddle::Tensor> AppendAttention(
|
||||
const paddle::optional<paddle::Tensor>& kv_signal_data,
|
||||
const paddle::optional<paddle::Tensor>& q_norm_weight,
|
||||
const paddle::optional<paddle::Tensor>& k_norm_weight,
|
||||
const paddle::optional<paddle::Tensor>& sinks,
|
||||
const float rms_norm_eps,
|
||||
const std::string& compute_dtype,
|
||||
const std::string& cache_quant_type_str,
|
||||
@@ -443,7 +447,8 @@ std::vector<paddle::Tensor> AppendAttention(
|
||||
const int encoder_max_partition_size,
|
||||
const int speculate_max_draft_token_num,
|
||||
const bool causal,
|
||||
const bool speculate_decoder) {
|
||||
const bool speculate_decoder,
|
||||
const int sliding_window) {
|
||||
AppendAttnMetaData meta_data;
|
||||
|
||||
const auto& qkv_dims = qkv.dims();
|
||||
@@ -550,10 +555,10 @@ std::vector<paddle::Tensor> AppendAttention(
|
||||
cache_v_zp,
|
||||
out_linear_shifts,
|
||||
out_linear_smooths,
|
||||
mask_offset,
|
||||
kv_signal_data,
|
||||
q_norm_weight,
|
||||
k_norm_weight,
|
||||
sinks,
|
||||
rms_norm_eps,
|
||||
cache_quant_type_str,
|
||||
use_neox_rotary_style,
|
||||
@@ -568,7 +573,8 @@ std::vector<paddle::Tensor> AppendAttention(
|
||||
encoder_max_partition_size,
|
||||
speculate_max_draft_token_num,
|
||||
causal,
|
||||
speculate_decoder);
|
||||
speculate_decoder,
|
||||
sliding_window);
|
||||
};
|
||||
|
||||
|
||||
@@ -630,6 +636,7 @@ std::vector<paddle::Tensor> AppendAttentionWithOutput(
|
||||
const paddle::optional<paddle::Tensor>& kv_signal_data,
|
||||
const paddle::optional<paddle::Tensor>& q_norm_weight,
|
||||
const paddle::optional<paddle::Tensor>& k_norm_weight,
|
||||
const paddle::optional<paddle::Tensor>& sinks,
|
||||
const float rms_norm_eps,
|
||||
const std::string& compute_dtype,
|
||||
const std::string& cache_quant_type_str,
|
||||
@@ -645,7 +652,8 @@ std::vector<paddle::Tensor> AppendAttentionWithOutput(
|
||||
const int encoder_max_partition_size,
|
||||
const int speculate_max_draft_token_num,
|
||||
const bool causal,
|
||||
const bool speculate_decoder) {
|
||||
const bool speculate_decoder,
|
||||
const int sliding_window) {
|
||||
AppendAttnMetaData meta_data;
|
||||
|
||||
const auto& qkv_dims = qkv.dims();
|
||||
@@ -704,10 +712,10 @@ std::vector<paddle::Tensor> AppendAttentionWithOutput(
|
||||
cache_v_zp,
|
||||
out_linear_shifts,
|
||||
out_linear_smooths,
|
||||
mask_offset,
|
||||
kv_signal_data,
|
||||
q_norm_weight,
|
||||
k_norm_weight,
|
||||
sinks,
|
||||
rms_norm_eps,
|
||||
cache_quant_type_str,
|
||||
use_neox_rotary_style,
|
||||
@@ -722,7 +730,8 @@ std::vector<paddle::Tensor> AppendAttentionWithOutput(
|
||||
encoder_max_partition_size,
|
||||
speculate_max_draft_token_num,
|
||||
causal,
|
||||
speculate_decoder);
|
||||
speculate_decoder,
|
||||
sliding_window);
|
||||
};
|
||||
|
||||
phi::dtype::float16 fp16_dtype;
|
||||
@@ -797,6 +806,7 @@ std::vector<std::vector<int64_t>> AppendAttentionInferShape(
|
||||
const paddle::optional<std::vector<int64_t>>& kv_signal_data_shape,
|
||||
const paddle::optional<std::vector<int64_t>>& q_norm_weight_shape,
|
||||
const paddle::optional<std::vector<int64_t>>& k_norm_weight_shape,
|
||||
const paddle::optional<std::vector<int64_t>>& sinks_shape,
|
||||
const float rms_norm_eps,
|
||||
const std::string& compute_dtype,
|
||||
const std::string& cache_quant_type_str,
|
||||
@@ -812,7 +822,8 @@ std::vector<std::vector<int64_t>> AppendAttentionInferShape(
|
||||
const int encoder_max_partition_size,
|
||||
const int speculate_max_draft_token_num,
|
||||
const bool causal,
|
||||
const bool speculate_decoder) {
|
||||
const bool speculate_decoder,
|
||||
const int sliding_window) {
|
||||
const int token_num = qkv_shape[0];
|
||||
const int kv_num_heads = key_cache_shape[1];
|
||||
int head_dim = key_cache_shape[3];
|
||||
@@ -860,6 +871,7 @@ std::vector<paddle::DataType> AppendAttentionInferDtype(
|
||||
const paddle::optional<paddle::DataType>& kv_signal_data_dtype,
|
||||
const paddle::optional<paddle::DataType>& q_norm_weight_dtype,
|
||||
const paddle::optional<paddle::DataType>& k_norm_weight_dtype,
|
||||
const paddle::optional<paddle::DataType>& sinks_dtype,
|
||||
const float rms_norm_eps,
|
||||
const std::string& compute_dtype,
|
||||
const std::string& cache_quant_type_str,
|
||||
@@ -875,7 +887,8 @@ std::vector<paddle::DataType> AppendAttentionInferDtype(
|
||||
const int encoder_max_partition_size,
|
||||
const int speculate_max_draft_token_num,
|
||||
const bool causal,
|
||||
const bool speculate_decoder) {
|
||||
const bool speculate_decoder,
|
||||
const int sliding_window) {
|
||||
if (compute_dtype == "bf16") {
|
||||
if (out_linear_in_scale > 0.0) {
|
||||
if (fabs(quant_max_bound - 127.0f) < 0.000001) {
|
||||
@@ -942,6 +955,7 @@ std::vector<std::vector<int64_t>> AppendAttentionWithOutputInferShape(
|
||||
const paddle::optional<std::vector<int64_t>>& kv_signal_data_shape,
|
||||
const paddle::optional<std::vector<int64_t>>& q_norm_weight_shape,
|
||||
const paddle::optional<std::vector<int64_t>>& k_norm_weight_shape,
|
||||
const paddle::optional<std::vector<int64_t>>& sinks_shape,
|
||||
const float rms_norm_eps,
|
||||
const std::string& compute_dtype,
|
||||
const std::string& cache_quant_type_str,
|
||||
@@ -957,7 +971,8 @@ std::vector<std::vector<int64_t>> AppendAttentionWithOutputInferShape(
|
||||
const int encoder_max_partition_size,
|
||||
const int speculate_max_draft_token_num,
|
||||
const bool causal,
|
||||
const bool speculate_decoder) {
|
||||
const bool speculate_decoder,
|
||||
const int sliding_window) {
|
||||
return {fmha_out_shape};
|
||||
}
|
||||
|
||||
@@ -998,6 +1013,7 @@ std::vector<paddle::DataType> AppendAttentionWithOutputInferDtype(
|
||||
const paddle::optional<paddle::DataType>& kv_signal_data_dtype,
|
||||
const paddle::optional<paddle::DataType>& q_norm_weight_dtype,
|
||||
const paddle::optional<paddle::DataType>& k_norm_weight_dtype,
|
||||
const paddle::optional<paddle::DataType>& sinks_dtype,
|
||||
const float rms_norm_eps,
|
||||
const std::string& compute_dtype,
|
||||
const std::string& cache_quant_type_str,
|
||||
@@ -1013,7 +1029,8 @@ std::vector<paddle::DataType> AppendAttentionWithOutputInferDtype(
|
||||
const int encoder_max_partition_size,
|
||||
const int speculate_max_draft_token_num,
|
||||
const bool causal,
|
||||
const bool speculate_decoder) {
|
||||
const bool speculate_decoder,
|
||||
const int sliding_window) {
|
||||
return {fmha_out_dtype};
|
||||
}
|
||||
|
||||
@@ -1054,7 +1071,8 @@ PD_BUILD_STATIC_OP(append_attention)
|
||||
paddle::Optional("mask_offset"),
|
||||
paddle::Optional("kv_signal_data"),
|
||||
paddle::Optional("q_norm_weight"),
|
||||
paddle::Optional("k_norm_weight")})
|
||||
paddle::Optional("k_norm_weight"),
|
||||
paddle::Optional("sinks")})
|
||||
.Outputs({"fmha_out"})
|
||||
.Attrs({"rms_norm_eps: float",
|
||||
"compute_type: std::string",
|
||||
@@ -1072,6 +1090,7 @@ PD_BUILD_STATIC_OP(append_attention)
|
||||
"speculate_max_draft_token_num: int",
|
||||
"causal: bool",
|
||||
"speculate_decoder: bool",
|
||||
"sliding_window: int",
|
||||
})
|
||||
.SetKernelFn(PD_KERNEL(AppendAttention))
|
||||
.SetInferShapeFn(PD_INFER_SHAPE(AppendAttentionInferShape))
|
||||
@@ -1113,7 +1132,8 @@ PD_BUILD_STATIC_OP(append_attention_with_output)
|
||||
paddle::Optional("mask_offset"),
|
||||
paddle::Optional("kv_signal_data"),
|
||||
paddle::Optional("q_norm_weight"),
|
||||
paddle::Optional("k_norm_weight")})
|
||||
paddle::Optional("k_norm_weight"),
|
||||
paddle::Optional("sinks")})
|
||||
.Outputs({"fmha_out_out"})
|
||||
.SetInplaceMap({{"fmha_out", "fmha_out_out"}})
|
||||
.Attrs({"rms_norm_eps: float",
|
||||
@@ -1132,6 +1152,7 @@ PD_BUILD_STATIC_OP(append_attention_with_output)
|
||||
"speculate_max_draft_token_num: int",
|
||||
"causal: bool",
|
||||
"speculate_decoder: bool",
|
||||
"sliding_window: int",
|
||||
})
|
||||
.SetKernelFn(PD_KERNEL(AppendAttentionWithOutput))
|
||||
.SetInferShapeFn(PD_INFER_SHAPE(AppendAttentionWithOutputInferShape))
|
||||
|
||||
@@ -36,6 +36,8 @@ void CascadeAppendAttentionC16Kernel(
|
||||
shift_bias, // [num_kv_heads, head_dim]
|
||||
const paddle::optional<paddle::Tensor>&
|
||||
smooth_weight, // [num_kv_heads, head_dim]
|
||||
const paddle::optional<paddle::Tensor>&
|
||||
sinks, // [num_heads]
|
||||
const paddle::Tensor& seq_lens_q,
|
||||
const paddle::Tensor& seq_lens_kv,
|
||||
const paddle::Tensor& seq_lens_encoder,
|
||||
@@ -58,7 +60,8 @@ void CascadeAppendAttentionC16Kernel(
|
||||
const bool is_decoder,
|
||||
const bool enable_prefill,
|
||||
cudaStream_t& stream,
|
||||
paddle::Tensor* out) {
|
||||
paddle::Tensor* out,
|
||||
const int sliding_window) {
|
||||
const auto token_num = meta_data.token_nums;
|
||||
const auto block_size = meta_data.block_size;
|
||||
const auto bsz = meta_data.batch_size;
|
||||
@@ -99,6 +102,7 @@ void CascadeAppendAttentionC16Kernel(
|
||||
attn_mask,
|
||||
shift_bias,
|
||||
smooth_weight,
|
||||
sinks,
|
||||
seq_lens_q,
|
||||
seq_lens_kv,
|
||||
seq_lens_encoder,
|
||||
@@ -118,7 +122,8 @@ void CascadeAppendAttentionC16Kernel(
|
||||
speculate_max_draft_token_num,
|
||||
is_decoder,
|
||||
stream,
|
||||
out);
|
||||
out,
|
||||
sliding_window);
|
||||
})})})})})})
|
||||
}
|
||||
|
||||
@@ -142,6 +147,8 @@ template void CascadeAppendAttentionC16Kernel<paddle::bfloat16, paddle::bfloat16
|
||||
shift_bias, // [num_kv_heads, head_dim]
|
||||
const paddle::optional<paddle::Tensor>&
|
||||
smooth_weight, // [num_kv_heads, head_dim]
|
||||
const paddle::optional<paddle::Tensor>&
|
||||
sinks, // [num_heads]
|
||||
const paddle::Tensor& seq_lens_q,
|
||||
const paddle::Tensor& seq_lens_kv,
|
||||
const paddle::Tensor& seq_lens_encoder,
|
||||
@@ -164,7 +171,8 @@ template void CascadeAppendAttentionC16Kernel<paddle::bfloat16, paddle::bfloat16
|
||||
const bool is_decoder,
|
||||
const bool enable_prefill,
|
||||
cudaStream_t& stream,
|
||||
paddle::Tensor* out);
|
||||
paddle::Tensor* out,
|
||||
const int sliding_window);
|
||||
|
||||
template void CascadeAppendAttentionC16Kernel<paddle::bfloat16, paddle::float8_e4m3fn>(
|
||||
const AppendAttnMetaData& meta_data,
|
||||
@@ -186,6 +194,8 @@ template void CascadeAppendAttentionC16Kernel<paddle::bfloat16, paddle::float8_e
|
||||
shift_bias, // [num_kv_heads, head_dim]
|
||||
const paddle::optional<paddle::Tensor>&
|
||||
smooth_weight, // [num_kv_heads, head_dim]
|
||||
const paddle::optional<paddle::Tensor>&
|
||||
sinks, // [num_heads]
|
||||
const paddle::Tensor& seq_lens_q,
|
||||
const paddle::Tensor& seq_lens_kv,
|
||||
const paddle::Tensor& seq_lens_encoder,
|
||||
@@ -208,7 +218,8 @@ template void CascadeAppendAttentionC16Kernel<paddle::bfloat16, paddle::float8_e
|
||||
const bool is_decoder,
|
||||
const bool enable_prefill,
|
||||
cudaStream_t& stream,
|
||||
paddle::Tensor* out);
|
||||
paddle::Tensor* out,
|
||||
const int sliding_window);
|
||||
|
||||
template void CascadeAppendAttentionC16Kernel<paddle::bfloat16, int8_t>(
|
||||
const AppendAttnMetaData& meta_data,
|
||||
@@ -230,6 +241,8 @@ template void CascadeAppendAttentionC16Kernel<paddle::bfloat16, int8_t>(
|
||||
shift_bias, // [num_kv_heads, head_dim]
|
||||
const paddle::optional<paddle::Tensor>&
|
||||
smooth_weight, // [num_kv_heads, head_dim]
|
||||
const paddle::optional<paddle::Tensor>&
|
||||
sinks, // [num_heads]
|
||||
const paddle::Tensor& seq_lens_q,
|
||||
const paddle::Tensor& seq_lens_kv,
|
||||
const paddle::Tensor& seq_lens_encoder,
|
||||
@@ -252,7 +265,8 @@ template void CascadeAppendAttentionC16Kernel<paddle::bfloat16, int8_t>(
|
||||
const bool is_decoder,
|
||||
const bool enable_prefill,
|
||||
cudaStream_t& stream,
|
||||
paddle::Tensor* out);
|
||||
paddle::Tensor* out,
|
||||
const int sliding_window);
|
||||
|
||||
template void CascadeAppendAttentionC16Kernel<paddle::float16, paddle::float16>(
|
||||
const AppendAttnMetaData& meta_data,
|
||||
@@ -274,6 +288,8 @@ template void CascadeAppendAttentionC16Kernel<paddle::float16, paddle::float16>(
|
||||
shift_bias, // [num_kv_heads, head_dim]
|
||||
const paddle::optional<paddle::Tensor>&
|
||||
smooth_weight, // [num_kv_heads, head_dim]
|
||||
const paddle::optional<paddle::Tensor>&
|
||||
sinks, // [num_heads]
|
||||
const paddle::Tensor& seq_lens_q,
|
||||
const paddle::Tensor& seq_lens_kv,
|
||||
const paddle::Tensor& seq_lens_encoder,
|
||||
@@ -296,7 +312,8 @@ template void CascadeAppendAttentionC16Kernel<paddle::float16, paddle::float16>(
|
||||
const bool is_decoder,
|
||||
const bool enable_prefill,
|
||||
cudaStream_t& stream,
|
||||
paddle::Tensor* out);
|
||||
paddle::Tensor* out,
|
||||
const int sliding_window);
|
||||
|
||||
template void CascadeAppendAttentionC16Kernel<paddle::float16, paddle::float8_e4m3fn>(
|
||||
const AppendAttnMetaData& meta_data,
|
||||
@@ -318,6 +335,8 @@ template void CascadeAppendAttentionC16Kernel<paddle::float16, paddle::float8_e4
|
||||
shift_bias, // [num_kv_heads, head_dim]
|
||||
const paddle::optional<paddle::Tensor>&
|
||||
smooth_weight, // [num_kv_heads, head_dim]
|
||||
const paddle::optional<paddle::Tensor>&
|
||||
sinks, // [num_heads]
|
||||
const paddle::Tensor& seq_lens_q,
|
||||
const paddle::Tensor& seq_lens_kv,
|
||||
const paddle::Tensor& seq_lens_encoder,
|
||||
@@ -340,7 +359,8 @@ template void CascadeAppendAttentionC16Kernel<paddle::float16, paddle::float8_e4
|
||||
const bool is_decoder,
|
||||
const bool enable_prefill,
|
||||
cudaStream_t& stream,
|
||||
paddle::Tensor* out);
|
||||
paddle::Tensor* out,
|
||||
const int sliding_window);
|
||||
|
||||
template void CascadeAppendAttentionC16Kernel<paddle::float16, int8_t>(
|
||||
const AppendAttnMetaData& meta_data,
|
||||
@@ -362,6 +382,8 @@ template void CascadeAppendAttentionC16Kernel<paddle::float16, int8_t>(
|
||||
shift_bias, // [num_kv_heads, head_dim]
|
||||
const paddle::optional<paddle::Tensor>&
|
||||
smooth_weight, // [num_kv_heads, head_dim]
|
||||
const paddle::optional<paddle::Tensor>&
|
||||
sinks, // [num_heads]
|
||||
const paddle::Tensor& seq_lens_q,
|
||||
const paddle::Tensor& seq_lens_kv,
|
||||
const paddle::Tensor& seq_lens_encoder,
|
||||
@@ -384,4 +406,5 @@ template void CascadeAppendAttentionC16Kernel<paddle::float16, int8_t>(
|
||||
const bool is_decoder,
|
||||
const bool enable_prefill,
|
||||
cudaStream_t& stream,
|
||||
paddle::Tensor* out);
|
||||
paddle::Tensor* out,
|
||||
const int sliding_window);
|
||||
|
||||
@@ -36,6 +36,8 @@ void CascadeAppendAttentionC4Kernel(
|
||||
shift_bias, // [num_kv_heads, head_dim]
|
||||
const paddle::optional<paddle::Tensor>&
|
||||
smooth_weight, // [num_kv_heads, head_dim]
|
||||
const paddle::optional<paddle::Tensor>&
|
||||
sinks, // [num_heads]
|
||||
const paddle::Tensor& seq_lens_q,
|
||||
const paddle::Tensor& seq_lens_kv,
|
||||
const paddle::Tensor& seq_lens_encoder,
|
||||
@@ -58,7 +60,8 @@ void CascadeAppendAttentionC4Kernel(
|
||||
const bool is_decoder,
|
||||
const bool enable_prefill,
|
||||
cudaStream_t& stream,
|
||||
paddle::Tensor* out) {
|
||||
paddle::Tensor* out,
|
||||
const int sliding_window) {
|
||||
const auto token_num = meta_data.token_nums;
|
||||
const auto block_size = meta_data.block_size;
|
||||
const auto bsz = meta_data.batch_size;
|
||||
@@ -103,6 +106,7 @@ void CascadeAppendAttentionC4Kernel(
|
||||
cache_v_zp,
|
||||
shift_bias,
|
||||
smooth_weight,
|
||||
sinks,
|
||||
seq_lens_q,
|
||||
seq_lens_kv,
|
||||
seq_lens_encoder,
|
||||
@@ -122,7 +126,8 @@ void CascadeAppendAttentionC4Kernel(
|
||||
speculate_max_draft_token_num,
|
||||
is_decoder,
|
||||
stream,
|
||||
out);
|
||||
out,
|
||||
sliding_window);
|
||||
})})})})})})
|
||||
}
|
||||
|
||||
@@ -146,6 +151,8 @@ template void CascadeAppendAttentionC4Kernel<paddle::bfloat16, paddle::bfloat16>
|
||||
shift_bias, // [num_kv_heads, head_dim]
|
||||
const paddle::optional<paddle::Tensor>&
|
||||
smooth_weight, // [num_kv_heads, head_dim]
|
||||
const paddle::optional<paddle::Tensor>&
|
||||
sinks, // [num_heads]
|
||||
const paddle::Tensor& seq_lens_q,
|
||||
const paddle::Tensor& seq_lens_kv,
|
||||
const paddle::Tensor& seq_lens_encoder,
|
||||
@@ -168,7 +175,8 @@ template void CascadeAppendAttentionC4Kernel<paddle::bfloat16, paddle::bfloat16>
|
||||
const bool is_decoder,
|
||||
const bool enable_prefill,
|
||||
cudaStream_t& stream,
|
||||
paddle::Tensor* out);
|
||||
paddle::Tensor* out,
|
||||
const int sliding_window);
|
||||
|
||||
template void CascadeAppendAttentionC4Kernel<paddle::bfloat16, paddle::float8_e4m3fn>(
|
||||
const AppendAttnMetaData& meta_data,
|
||||
@@ -190,6 +198,8 @@ template void CascadeAppendAttentionC4Kernel<paddle::bfloat16, paddle::float8_e4
|
||||
shift_bias, // [num_kv_heads, head_dim]
|
||||
const paddle::optional<paddle::Tensor>&
|
||||
smooth_weight, // [num_kv_heads, head_dim]
|
||||
const paddle::optional<paddle::Tensor>&
|
||||
sinks, // [num_heads]
|
||||
const paddle::Tensor& seq_lens_q,
|
||||
const paddle::Tensor& seq_lens_kv,
|
||||
const paddle::Tensor& seq_lens_encoder,
|
||||
@@ -212,7 +222,8 @@ template void CascadeAppendAttentionC4Kernel<paddle::bfloat16, paddle::float8_e4
|
||||
const bool is_decoder,
|
||||
const bool enable_prefill,
|
||||
cudaStream_t& stream,
|
||||
paddle::Tensor* out);
|
||||
paddle::Tensor* out,
|
||||
const int sliding_window);
|
||||
|
||||
template void CascadeAppendAttentionC4Kernel<paddle::bfloat16, int8_t>(
|
||||
const AppendAttnMetaData& meta_data,
|
||||
@@ -234,6 +245,8 @@ template void CascadeAppendAttentionC4Kernel<paddle::bfloat16, int8_t>(
|
||||
shift_bias, // [num_kv_heads, head_dim]
|
||||
const paddle::optional<paddle::Tensor>&
|
||||
smooth_weight, // [num_kv_heads, head_dim]
|
||||
const paddle::optional<paddle::Tensor>&
|
||||
sinks, // [num_heads]
|
||||
const paddle::Tensor& seq_lens_q,
|
||||
const paddle::Tensor& seq_lens_kv,
|
||||
const paddle::Tensor& seq_lens_encoder,
|
||||
@@ -256,7 +269,8 @@ template void CascadeAppendAttentionC4Kernel<paddle::bfloat16, int8_t>(
|
||||
const bool is_decoder,
|
||||
const bool enable_prefill,
|
||||
cudaStream_t& stream,
|
||||
paddle::Tensor* out);
|
||||
paddle::Tensor* out,
|
||||
const int sliding_window);
|
||||
|
||||
template void CascadeAppendAttentionC4Kernel<paddle::float16, paddle::float16>(
|
||||
const AppendAttnMetaData& meta_data,
|
||||
@@ -278,6 +292,8 @@ template void CascadeAppendAttentionC4Kernel<paddle::float16, paddle::float16>(
|
||||
shift_bias, // [num_kv_heads, head_dim]
|
||||
const paddle::optional<paddle::Tensor>&
|
||||
smooth_weight, // [num_kv_heads, head_dim]
|
||||
const paddle::optional<paddle::Tensor>&
|
||||
sinks, // [num_heads]
|
||||
const paddle::Tensor& seq_lens_q,
|
||||
const paddle::Tensor& seq_lens_kv,
|
||||
const paddle::Tensor& seq_lens_encoder,
|
||||
@@ -300,7 +316,8 @@ template void CascadeAppendAttentionC4Kernel<paddle::float16, paddle::float16>(
|
||||
const bool is_decoder,
|
||||
const bool enable_prefill,
|
||||
cudaStream_t& stream,
|
||||
paddle::Tensor* out);
|
||||
paddle::Tensor* out,
|
||||
const int sliding_window);
|
||||
|
||||
template void CascadeAppendAttentionC4Kernel<paddle::float16, paddle::float8_e4m3fn>(
|
||||
const AppendAttnMetaData& meta_data,
|
||||
@@ -322,6 +339,8 @@ template void CascadeAppendAttentionC4Kernel<paddle::float16, paddle::float8_e4m
|
||||
shift_bias, // [num_kv_heads, head_dim]
|
||||
const paddle::optional<paddle::Tensor>&
|
||||
smooth_weight, // [num_kv_heads, head_dim]
|
||||
const paddle::optional<paddle::Tensor>&
|
||||
sinks, // [num_heads]
|
||||
const paddle::Tensor& seq_lens_q,
|
||||
const paddle::Tensor& seq_lens_kv,
|
||||
const paddle::Tensor& seq_lens_encoder,
|
||||
@@ -344,7 +363,8 @@ template void CascadeAppendAttentionC4Kernel<paddle::float16, paddle::float8_e4m
|
||||
const bool is_decoder,
|
||||
const bool enable_prefill,
|
||||
cudaStream_t& stream,
|
||||
paddle::Tensor* out);
|
||||
paddle::Tensor* out,
|
||||
const int sliding_window);
|
||||
|
||||
template void CascadeAppendAttentionC4Kernel<paddle::float16, int8_t>(
|
||||
const AppendAttnMetaData& meta_data,
|
||||
@@ -366,6 +386,8 @@ template void CascadeAppendAttentionC4Kernel<paddle::float16, int8_t>(
|
||||
shift_bias, // [num_kv_heads, head_dim]
|
||||
const paddle::optional<paddle::Tensor>&
|
||||
smooth_weight, // [num_kv_heads, head_dim]
|
||||
const paddle::optional<paddle::Tensor>&
|
||||
sinks, // [num_heads]
|
||||
const paddle::Tensor& seq_lens_q,
|
||||
const paddle::Tensor& seq_lens_kv,
|
||||
const paddle::Tensor& seq_lens_encoder,
|
||||
@@ -388,4 +410,5 @@ template void CascadeAppendAttentionC4Kernel<paddle::float16, int8_t>(
|
||||
const bool is_decoder,
|
||||
const bool enable_prefill,
|
||||
cudaStream_t& stream,
|
||||
paddle::Tensor* out);
|
||||
paddle::Tensor* out,
|
||||
const int sliding_window);
|
||||
|
||||
@@ -36,6 +36,8 @@ void CascadeAppendAttentionC8Kernel(
|
||||
shift_bias, // [num_kv_heads, head_dim]
|
||||
const paddle::optional<paddle::Tensor>&
|
||||
smooth_weight, // [num_kv_heads, head_dim]
|
||||
const paddle::optional<paddle::Tensor>&
|
||||
sinks, // [num_heads]
|
||||
const paddle::Tensor& seq_lens_q,
|
||||
const paddle::Tensor& seq_lens_kv,
|
||||
const paddle::Tensor& seq_lens_encoder,
|
||||
@@ -59,7 +61,8 @@ void CascadeAppendAttentionC8Kernel(
|
||||
const bool enable_prefill,
|
||||
const std::string& cache_quant_type_str,
|
||||
cudaStream_t& stream,
|
||||
paddle::Tensor* out) {
|
||||
paddle::Tensor* out,
|
||||
const int sliding_window) {
|
||||
const auto token_num = meta_data.token_nums;
|
||||
const auto block_size = meta_data.block_size;
|
||||
const auto bsz = meta_data.batch_size;
|
||||
@@ -106,7 +109,8 @@ void CascadeAppendAttentionC8Kernel(
|
||||
cache_v_scale.get(),
|
||||
shift_bias,
|
||||
smooth_weight,
|
||||
seq_lens_q,
|
||||
sinks,
|
||||
seq_lens_q,
|
||||
seq_lens_kv,
|
||||
seq_lens_encoder,
|
||||
batch_id_per_token,
|
||||
@@ -125,7 +129,8 @@ void CascadeAppendAttentionC8Kernel(
|
||||
speculate_max_draft_token_num,
|
||||
is_decoder,
|
||||
stream,
|
||||
out);
|
||||
out,
|
||||
sliding_window);
|
||||
})})})})})})})
|
||||
}
|
||||
|
||||
@@ -141,6 +146,7 @@ template void CascadeAppendAttentionC8Kernel<paddle::float16, paddle::float16, f
|
||||
const paddle::optional<paddle::Tensor>& cache_v_zp,
|
||||
const paddle::optional<paddle::Tensor>& shift_bias,
|
||||
const paddle::optional<paddle::Tensor>& smooth_weight,
|
||||
const paddle::optional<paddle::Tensor>& sinks,
|
||||
const paddle::Tensor& seq_lens_q,
|
||||
const paddle::Tensor& seq_lens_kv,
|
||||
const paddle::Tensor& seq_lens_encoder,
|
||||
@@ -164,7 +170,8 @@ template void CascadeAppendAttentionC8Kernel<paddle::float16, paddle::float16, f
|
||||
const bool enable_prefill,
|
||||
const std::string& cache_quant_type_str,
|
||||
cudaStream_t& stream,
|
||||
paddle::Tensor* out);
|
||||
paddle::Tensor* out,
|
||||
const int sliding_window);
|
||||
|
||||
template void CascadeAppendAttentionC8Kernel<paddle::float16, paddle::float16, true>(
|
||||
const AppendAttnMetaData& meta_data,
|
||||
@@ -178,6 +185,7 @@ template void CascadeAppendAttentionC8Kernel<paddle::float16, paddle::float16, t
|
||||
const paddle::optional<paddle::Tensor>& cache_v_zp,
|
||||
const paddle::optional<paddle::Tensor>& shift_bias,
|
||||
const paddle::optional<paddle::Tensor>& smooth_weight,
|
||||
const paddle::optional<paddle::Tensor>& sinks,
|
||||
const paddle::Tensor& seq_lens_q,
|
||||
const paddle::Tensor& seq_lens_kv,
|
||||
const paddle::Tensor& seq_lens_encoder,
|
||||
@@ -201,7 +209,8 @@ template void CascadeAppendAttentionC8Kernel<paddle::float16, paddle::float16, t
|
||||
const bool enable_prefill,
|
||||
const std::string& cache_quant_type_str,
|
||||
cudaStream_t& stream,
|
||||
paddle::Tensor* out);
|
||||
paddle::Tensor* out,
|
||||
const int sliding_window);
|
||||
|
||||
template void CascadeAppendAttentionC8Kernel<paddle::float16, paddle::float8_e4m3fn, false>(
|
||||
const AppendAttnMetaData& meta_data,
|
||||
@@ -215,6 +224,7 @@ template void CascadeAppendAttentionC8Kernel<paddle::float16, paddle::float8_e4m
|
||||
const paddle::optional<paddle::Tensor>& cache_v_zp,
|
||||
const paddle::optional<paddle::Tensor>& shift_bias,
|
||||
const paddle::optional<paddle::Tensor>& smooth_weight,
|
||||
const paddle::optional<paddle::Tensor>& sinks,
|
||||
const paddle::Tensor& seq_lens_q,
|
||||
const paddle::Tensor& seq_lens_kv,
|
||||
const paddle::Tensor& seq_lens_encoder,
|
||||
@@ -238,7 +248,8 @@ template void CascadeAppendAttentionC8Kernel<paddle::float16, paddle::float8_e4m
|
||||
const bool enable_prefill,
|
||||
const std::string& cache_quant_type_str,
|
||||
cudaStream_t& stream,
|
||||
paddle::Tensor* out);
|
||||
paddle::Tensor* out,
|
||||
const int sliding_window);
|
||||
|
||||
template void CascadeAppendAttentionC8Kernel<paddle::float16, paddle::float8_e4m3fn, true>(
|
||||
const AppendAttnMetaData& meta_data,
|
||||
@@ -252,6 +263,7 @@ template void CascadeAppendAttentionC8Kernel<paddle::float16, paddle::float8_e4m
|
||||
const paddle::optional<paddle::Tensor>& cache_v_zp,
|
||||
const paddle::optional<paddle::Tensor>& shift_bias,
|
||||
const paddle::optional<paddle::Tensor>& smooth_weight,
|
||||
const paddle::optional<paddle::Tensor>& sinks,
|
||||
const paddle::Tensor& seq_lens_q,
|
||||
const paddle::Tensor& seq_lens_kv,
|
||||
const paddle::Tensor& seq_lens_encoder,
|
||||
@@ -275,7 +287,8 @@ template void CascadeAppendAttentionC8Kernel<paddle::float16, paddle::float8_e4m
|
||||
const bool enable_prefill,
|
||||
const std::string& cache_quant_type_str,
|
||||
cudaStream_t& stream,
|
||||
paddle::Tensor* out);
|
||||
paddle::Tensor* out,
|
||||
const int sliding_window);
|
||||
|
||||
template void CascadeAppendAttentionC8Kernel<paddle::float16, int8_t, false>(
|
||||
const AppendAttnMetaData& meta_data,
|
||||
@@ -289,6 +302,7 @@ template void CascadeAppendAttentionC8Kernel<paddle::float16, int8_t, false>(
|
||||
const paddle::optional<paddle::Tensor>& cache_v_zp,
|
||||
const paddle::optional<paddle::Tensor>& shift_bias,
|
||||
const paddle::optional<paddle::Tensor>& smooth_weight,
|
||||
const paddle::optional<paddle::Tensor>& sinks,
|
||||
const paddle::Tensor& seq_lens_q,
|
||||
const paddle::Tensor& seq_lens_kv,
|
||||
const paddle::Tensor& seq_lens_encoder,
|
||||
@@ -312,7 +326,8 @@ template void CascadeAppendAttentionC8Kernel<paddle::float16, int8_t, false>(
|
||||
const bool enable_prefill,
|
||||
const std::string& cache_quant_type_str,
|
||||
cudaStream_t& stream,
|
||||
paddle::Tensor* out);
|
||||
paddle::Tensor* out,
|
||||
const int sliding_window);
|
||||
|
||||
template void CascadeAppendAttentionC8Kernel<paddle::float16, int8_t, true>(
|
||||
const AppendAttnMetaData& meta_data,
|
||||
@@ -326,6 +341,7 @@ template void CascadeAppendAttentionC8Kernel<paddle::float16, int8_t, true>(
|
||||
const paddle::optional<paddle::Tensor>& cache_v_zp,
|
||||
const paddle::optional<paddle::Tensor>& shift_bias,
|
||||
const paddle::optional<paddle::Tensor>& smooth_weight,
|
||||
const paddle::optional<paddle::Tensor>& sinks,
|
||||
const paddle::Tensor& seq_lens_q,
|
||||
const paddle::Tensor& seq_lens_kv,
|
||||
const paddle::Tensor& seq_lens_encoder,
|
||||
@@ -349,7 +365,8 @@ template void CascadeAppendAttentionC8Kernel<paddle::float16, int8_t, true>(
|
||||
const bool enable_prefill,
|
||||
const std::string& cache_quant_type_str,
|
||||
cudaStream_t& stream,
|
||||
paddle::Tensor* out);
|
||||
paddle::Tensor* out,
|
||||
const int sliding_window);
|
||||
|
||||
template void CascadeAppendAttentionC8Kernel<paddle::bfloat16, paddle::bfloat16, false>(
|
||||
const AppendAttnMetaData& meta_data,
|
||||
@@ -363,6 +380,7 @@ template void CascadeAppendAttentionC8Kernel<paddle::bfloat16, paddle::bfloat16,
|
||||
const paddle::optional<paddle::Tensor>& cache_v_zp,
|
||||
const paddle::optional<paddle::Tensor>& shift_bias,
|
||||
const paddle::optional<paddle::Tensor>& smooth_weight,
|
||||
const paddle::optional<paddle::Tensor>& sinks,
|
||||
const paddle::Tensor& seq_lens_q,
|
||||
const paddle::Tensor& seq_lens_kv,
|
||||
const paddle::Tensor& seq_lens_encoder,
|
||||
@@ -386,7 +404,8 @@ template void CascadeAppendAttentionC8Kernel<paddle::bfloat16, paddle::bfloat16,
|
||||
const bool enable_prefill,
|
||||
const std::string& cache_quant_type_str,
|
||||
cudaStream_t& stream,
|
||||
paddle::Tensor* out);
|
||||
paddle::Tensor* out,
|
||||
const int sliding_window);
|
||||
|
||||
template void CascadeAppendAttentionC8Kernel<paddle::bfloat16, paddle::bfloat16, true>(
|
||||
const AppendAttnMetaData& meta_data,
|
||||
@@ -400,6 +419,7 @@ template void CascadeAppendAttentionC8Kernel<paddle::bfloat16, paddle::bfloat16,
|
||||
const paddle::optional<paddle::Tensor>& cache_v_zp,
|
||||
const paddle::optional<paddle::Tensor>& shift_bias,
|
||||
const paddle::optional<paddle::Tensor>& smooth_weight,
|
||||
const paddle::optional<paddle::Tensor>& sinks,
|
||||
const paddle::Tensor& seq_lens_q,
|
||||
const paddle::Tensor& seq_lens_kv,
|
||||
const paddle::Tensor& seq_lens_encoder,
|
||||
@@ -423,7 +443,8 @@ template void CascadeAppendAttentionC8Kernel<paddle::bfloat16, paddle::bfloat16,
|
||||
const bool enable_prefill,
|
||||
const std::string& cache_quant_type_str,
|
||||
cudaStream_t& stream,
|
||||
paddle::Tensor* out);
|
||||
paddle::Tensor* out,
|
||||
const int sliding_window);
|
||||
|
||||
template void CascadeAppendAttentionC8Kernel<paddle::bfloat16, paddle::float8_e4m3fn, false>(
|
||||
const AppendAttnMetaData& meta_data,
|
||||
@@ -437,6 +458,7 @@ template void CascadeAppendAttentionC8Kernel<paddle::bfloat16, paddle::float8_e4
|
||||
const paddle::optional<paddle::Tensor>& cache_v_zp,
|
||||
const paddle::optional<paddle::Tensor>& shift_bias,
|
||||
const paddle::optional<paddle::Tensor>& smooth_weight,
|
||||
const paddle::optional<paddle::Tensor>& sinks,
|
||||
const paddle::Tensor& seq_lens_q,
|
||||
const paddle::Tensor& seq_lens_kv,
|
||||
const paddle::Tensor& seq_lens_encoder,
|
||||
@@ -460,7 +482,8 @@ template void CascadeAppendAttentionC8Kernel<paddle::bfloat16, paddle::float8_e4
|
||||
const bool enable_prefill,
|
||||
const std::string& cache_quant_type_str,
|
||||
cudaStream_t& stream,
|
||||
paddle::Tensor* out);
|
||||
paddle::Tensor* out,
|
||||
const int sliding_window);
|
||||
|
||||
template void CascadeAppendAttentionC8Kernel<paddle::bfloat16, paddle::float8_e4m3fn, true>(
|
||||
const AppendAttnMetaData& meta_data,
|
||||
@@ -474,6 +497,7 @@ template void CascadeAppendAttentionC8Kernel<paddle::bfloat16, paddle::float8_e4
|
||||
const paddle::optional<paddle::Tensor>& cache_v_zp,
|
||||
const paddle::optional<paddle::Tensor>& shift_bias,
|
||||
const paddle::optional<paddle::Tensor>& smooth_weight,
|
||||
const paddle::optional<paddle::Tensor>& sinks,
|
||||
const paddle::Tensor& seq_lens_q,
|
||||
const paddle::Tensor& seq_lens_kv,
|
||||
const paddle::Tensor& seq_lens_encoder,
|
||||
@@ -497,7 +521,8 @@ template void CascadeAppendAttentionC8Kernel<paddle::bfloat16, paddle::float8_e4
|
||||
const bool enable_prefill,
|
||||
const std::string& cache_quant_type_str,
|
||||
cudaStream_t& stream,
|
||||
paddle::Tensor* out);
|
||||
paddle::Tensor* out,
|
||||
const int sliding_window);
|
||||
|
||||
template void CascadeAppendAttentionC8Kernel<paddle::bfloat16, int8_t, false>(
|
||||
const AppendAttnMetaData& meta_data,
|
||||
@@ -511,6 +536,7 @@ template void CascadeAppendAttentionC8Kernel<paddle::bfloat16, int8_t, false>(
|
||||
const paddle::optional<paddle::Tensor>& cache_v_zp,
|
||||
const paddle::optional<paddle::Tensor>& shift_bias,
|
||||
const paddle::optional<paddle::Tensor>& smooth_weight,
|
||||
const paddle::optional<paddle::Tensor>& sinks,
|
||||
const paddle::Tensor& seq_lens_q,
|
||||
const paddle::Tensor& seq_lens_kv,
|
||||
const paddle::Tensor& seq_lens_encoder,
|
||||
@@ -534,7 +560,8 @@ template void CascadeAppendAttentionC8Kernel<paddle::bfloat16, int8_t, false>(
|
||||
const bool enable_prefill,
|
||||
const std::string& cache_quant_type_str,
|
||||
cudaStream_t& stream,
|
||||
paddle::Tensor* out);
|
||||
paddle::Tensor* out,
|
||||
const int sliding_window);
|
||||
|
||||
template void CascadeAppendAttentionC8Kernel<paddle::bfloat16, int8_t, true>(
|
||||
const AppendAttnMetaData& meta_data,
|
||||
@@ -548,6 +575,7 @@ template void CascadeAppendAttentionC8Kernel<paddle::bfloat16, int8_t, true>(
|
||||
const paddle::optional<paddle::Tensor>& cache_v_zp,
|
||||
const paddle::optional<paddle::Tensor>& shift_bias,
|
||||
const paddle::optional<paddle::Tensor>& smooth_weight,
|
||||
const paddle::optional<paddle::Tensor>& sinks,
|
||||
const paddle::Tensor& seq_lens_q,
|
||||
const paddle::Tensor& seq_lens_kv,
|
||||
const paddle::Tensor& seq_lens_encoder,
|
||||
@@ -571,4 +599,5 @@ template void CascadeAppendAttentionC8Kernel<paddle::bfloat16, int8_t, true>(
|
||||
const bool enable_prefill,
|
||||
const std::string& cache_quant_type_str,
|
||||
cudaStream_t& stream,
|
||||
paddle::Tensor* out);
|
||||
paddle::Tensor* out,
|
||||
const int sliding_window);
|
||||
|
||||
@@ -77,6 +77,14 @@ struct prefill_softmax_state_t {
|
||||
|
||||
__device__ __forceinline__ void normalize() {
|
||||
const T d_t = static_cast<T>(d);
|
||||
#pragma unroll
|
||||
for (size_t i = 0; i < vec_size; ++i) {
|
||||
o[i] /= d_t;
|
||||
}
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void normalize(float current_sink) {
|
||||
const T d_t = static_cast<T>(d + __expf(current_sink - m));
|
||||
#pragma unroll
|
||||
for (size_t i = 0; i < vec_size; ++i) {
|
||||
o[i] /= d_t;
|
||||
@@ -1028,7 +1036,8 @@ __device__ __forceinline__ void mask_s(const bool* attn_mask,
|
||||
const uint32_t chunk_end,
|
||||
const uint32_t attn_mask_len,
|
||||
float (*s_frag)[num_frags_z][8],
|
||||
const int *mask_offset = nullptr) {
|
||||
const int *mask_offset = nullptr,
|
||||
const int sliding_window = 0) {
|
||||
const uint32_t tx = threadIdx.x;
|
||||
#pragma unroll
|
||||
for (uint32_t fx = 0; fx < num_frags_x; ++fx) {
|
||||
@@ -1045,11 +1054,21 @@ __device__ __forceinline__ void mask_s(const bool* attn_mask,
|
||||
bool out_of_boundary;
|
||||
if (mask_offset) {
|
||||
out_of_boundary = q_idx < qo_len ? (kv_idx >= mask_offset[q_idx * 2 + 1] || kv_idx < mask_offset[q_idx * 2]) : true;
|
||||
} else {
|
||||
}
|
||||
else if (sliding_window > 0)
|
||||
{
|
||||
bool out_of_window = int(kv_idx) <= (int)kv_len + (int)q_idx - (int)qo_len - sliding_window;
|
||||
out_of_boundary =
|
||||
(causal
|
||||
? (kv_idx > kv_len + q_idx - qo_len || (kv_idx >= chunk_end))
|
||||
: kv_idx >= chunk_end);
|
||||
(causal
|
||||
? (kv_idx > kv_len + q_idx - qo_len || out_of_window || (kv_idx >= chunk_end))
|
||||
: kv_idx >= chunk_end);
|
||||
}
|
||||
else
|
||||
{
|
||||
out_of_boundary =
|
||||
(causal
|
||||
? (kv_idx > kv_len + q_idx - qo_len || (kv_idx >= chunk_end))
|
||||
: kv_idx >= chunk_end);
|
||||
if (attn_mask != nullptr && kv_idx > kv_len - qo_len && kv_idx < chunk_end && q_idx < attn_mask_len) {
|
||||
const int32_t mask_idx = q_idx * attn_mask_len + kv_idx - kv_len + qo_len;
|
||||
bool mask = attn_mask[mask_idx];
|
||||
@@ -1064,7 +1083,7 @@ __device__ __forceinline__ void mask_s(const bool* attn_mask,
|
||||
s_frag[fx][fz][reg_id] =
|
||||
out_of_boundary ? -3.0e+30f : s_frag[fx][fz][reg_id];
|
||||
}
|
||||
// printf("tid: %d. qk[%u,%u] = %f, mask: %d \n ", threadIdx.x, kv_idx, q_idx, static_cast<float>(s_frag[fx][fz][reg_id]), int(out_of_boundary));
|
||||
|
||||
} else {
|
||||
const uint32_t q_idx = qo_idx_base,
|
||||
kv_idx = kv_idx_base + fz * 16 + 2 * (tx % 4) +
|
||||
@@ -1458,6 +1477,33 @@ __device__ __forceinline__ void normalize_d(float (*o_frag)[num_frags_y][8],
|
||||
}
|
||||
}
|
||||
|
||||
template <uint32_t num_frags_x, uint32_t num_frags_y>
|
||||
__device__ __forceinline__ void normalize_d(float (*o_frag)[num_frags_y][8],
|
||||
float (*d)[2],
|
||||
float (*m)[2],
|
||||
float (*current_sinks)[2]) {
|
||||
float d_rcp[num_frags_x][2];
|
||||
#pragma unroll
|
||||
for (uint32_t fx = 0; fx < num_frags_x; ++fx) {
|
||||
#pragma unroll
|
||||
for (uint32_t j = 0; j < 2; ++j) {
|
||||
d_rcp[fx][j] = 1.f / (d[fx][j] + __expf(current_sinks[fx][j] - m[fx][j]));
|
||||
}
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (uint32_t fx = 0; fx < num_frags_x; ++fx) {
|
||||
#pragma unroll
|
||||
for (uint32_t fy = 0; fy < num_frags_y; ++fy) {
|
||||
#pragma unroll
|
||||
for (uint32_t reg_id = 0; reg_id < 8; ++reg_id) {
|
||||
o_frag[fx][fy][reg_id] =
|
||||
o_frag[fx][fy][reg_id] * d_rcp[fx][(reg_id % 4) / 2];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <uint32_t num_frags_x,
|
||||
uint32_t num_frags_y,
|
||||
uint32_t NUM_WARPS,
|
||||
@@ -2271,6 +2317,7 @@ __global__ void merge_multi_chunks_decoder_kernel(
|
||||
const int *__restrict__ cu_seqlens_q,
|
||||
const T *__restrict__ shift_bias, // [q_num_heads * HEAD_DIM]
|
||||
const T *__restrict__ smooth_weight, // [q_num_heads * HEAD_DIM]
|
||||
const T *__restrict__ sinks, // [q_num_heads]
|
||||
OutT *__restrict__ out,
|
||||
const float quant_max_bound,
|
||||
const float quant_min_bound,
|
||||
@@ -2354,7 +2401,12 @@ __global__ void merge_multi_chunks_decoder_kernel(
|
||||
const float m_tmp = md_smem[2 * i], d_tmp = md_smem[2 * i + 1];
|
||||
st.merge(load_vec, m_tmp, d_tmp);
|
||||
}
|
||||
st.normalize();
|
||||
if (sinks) {
|
||||
float current_sink = static_cast<float>(sinks[hid]);
|
||||
st.normalize(current_sink);
|
||||
} else {
|
||||
st.normalize();
|
||||
}
|
||||
|
||||
const uint32_t shift_smooth_offset = hid * head_dim + vid * vec_size;
|
||||
AlignedVector<T, vec_size> shift_bias_vec;
|
||||
@@ -2394,6 +2446,7 @@ __global__ void merge_multi_chunks_v2_kernel(
|
||||
const int *__restrict__ cu_seqlens_q,
|
||||
const T *__restrict__ shift_bias, // [q_num_heads * HEAD_DIM]
|
||||
const T *__restrict__ smooth_weight, // [q_num_heads * HEAD_DIM]
|
||||
const T *__restrict__ sinks, // [q_num_heads]
|
||||
OutT *__restrict__ out,
|
||||
const float quant_max_bound,
|
||||
const float quant_min_bound,
|
||||
@@ -2511,7 +2564,13 @@ __global__ void merge_multi_chunks_v2_kernel(
|
||||
const float m_tmp = md_smem[2 * i], d_tmp = md_smem[2 * i + 1];
|
||||
st.merge(load_vec, m_tmp, d_tmp);
|
||||
}
|
||||
st.normalize();
|
||||
|
||||
if (sinks) {
|
||||
float current_sink = static_cast<float>(sinks[hid]);
|
||||
st.normalize(current_sink);
|
||||
} else {
|
||||
st.normalize();
|
||||
}
|
||||
|
||||
const uint32_t shift_smooth_offset = hid * head_dim + vid * vec_size;
|
||||
AlignedVector<T, vec_size> shift_bias_vec;
|
||||
|
||||
@@ -40,6 +40,8 @@ void CascadeAppendAttentionKernel(
|
||||
shift_bias, // [num_kv_heads, head_dim]
|
||||
const paddle::optional<paddle::Tensor>&
|
||||
smooth_weight, // [num_kv_heads, head_dim]
|
||||
const paddle::optional<paddle::Tensor>&
|
||||
sinks, // [num_heads]
|
||||
const paddle::Tensor& seq_lens_q,
|
||||
const paddle::Tensor& seq_lens_kv,
|
||||
const paddle::Tensor& seq_lens_encoder,
|
||||
@@ -63,7 +65,8 @@ void CascadeAppendAttentionKernel(
|
||||
const bool is_decoder,
|
||||
const bool enable_prefill,
|
||||
cudaStream_t& stream,
|
||||
paddle::Tensor* out) {
|
||||
paddle::Tensor* out,
|
||||
const int sliding_window) {
|
||||
if (cache_quant_type_str == "none") {
|
||||
CascadeAppendAttentionC16Kernel<T, OutT>(meta_data,
|
||||
qkv,
|
||||
@@ -76,6 +79,7 @@ void CascadeAppendAttentionKernel(
|
||||
cache_v_zp,
|
||||
shift_bias,
|
||||
smooth_weight,
|
||||
sinks,
|
||||
seq_lens_q,
|
||||
seq_lens_kv,
|
||||
seq_lens_encoder,
|
||||
@@ -98,7 +102,8 @@ void CascadeAppendAttentionKernel(
|
||||
is_decoder,
|
||||
enable_prefill,
|
||||
stream,
|
||||
out);
|
||||
out,
|
||||
sliding_window);
|
||||
} else if (cache_quant_type_str == "cache_int8") {
|
||||
CascadeAppendAttentionC8Kernel<T, OutT, false>(meta_data,
|
||||
qkv,
|
||||
@@ -111,6 +116,7 @@ void CascadeAppendAttentionKernel(
|
||||
cache_v_zp,
|
||||
shift_bias,
|
||||
smooth_weight,
|
||||
sinks,
|
||||
seq_lens_q,
|
||||
seq_lens_kv,
|
||||
seq_lens_encoder,
|
||||
@@ -134,7 +140,8 @@ void CascadeAppendAttentionKernel(
|
||||
enable_prefill,
|
||||
cache_quant_type_str,
|
||||
stream,
|
||||
out);
|
||||
out,
|
||||
sliding_window);
|
||||
} else if (cache_quant_type_str == "cache_fp8" or cache_quant_type_str == "block_wise_fp8") {
|
||||
CascadeAppendAttentionC8Kernel<T, OutT, true>(meta_data,
|
||||
qkv,
|
||||
@@ -147,6 +154,7 @@ void CascadeAppendAttentionKernel(
|
||||
cache_v_zp,
|
||||
shift_bias,
|
||||
smooth_weight,
|
||||
sinks,
|
||||
seq_lens_q,
|
||||
seq_lens_kv,
|
||||
seq_lens_encoder,
|
||||
@@ -170,7 +178,8 @@ void CascadeAppendAttentionKernel(
|
||||
enable_prefill,
|
||||
cache_quant_type_str,
|
||||
stream,
|
||||
out);
|
||||
out,
|
||||
sliding_window);
|
||||
} else if (cache_quant_type_str == "cache_int4_zp") {
|
||||
CascadeAppendAttentionC4Kernel<T, OutT>(meta_data,
|
||||
qkv,
|
||||
@@ -183,6 +192,7 @@ void CascadeAppendAttentionKernel(
|
||||
cache_v_zp,
|
||||
shift_bias,
|
||||
smooth_weight,
|
||||
sinks,
|
||||
seq_lens_q,
|
||||
seq_lens_kv,
|
||||
seq_lens_encoder,
|
||||
@@ -205,7 +215,8 @@ void CascadeAppendAttentionKernel(
|
||||
is_decoder,
|
||||
enable_prefill,
|
||||
stream,
|
||||
out);
|
||||
out,
|
||||
sliding_window);
|
||||
} else {
|
||||
PD_THROW(
|
||||
"cache_quant_type_str should be one of [none, cache_int8, "
|
||||
|
||||
@@ -2179,7 +2179,9 @@ void gqa_rotary_qk_norm_variable(
|
||||
qkv_out_scales
|
||||
? token_num * (num_heads + 2 * kv_num_heads) * dim_head
|
||||
: token_num * (num_heads + kv_num_heads) * dim_head; // for all q k v
|
||||
assert(dim_head == 128 && "dim_head must be 128");
|
||||
if (dim_head != 128) {
|
||||
PADDLE_THROW("gqa rotary with qk norm only support head_dim=128, but got %d.", dim_head);
|
||||
}
|
||||
constexpr int HEAD_DIM = 128;
|
||||
constexpr int PackSize = HEAD_DIM / kWarpSize;
|
||||
const int pack_num = elem_nums / PackSize;
|
||||
|
||||
@@ -36,6 +36,7 @@ __global__ void multi_query_append_attention_kernel(
|
||||
T *__restrict__ cache_v,
|
||||
const T *__restrict__ shift_bias, // [q_num_heads * HEAD_DIM]
|
||||
const T *__restrict__ smooth_weight, // [q_num_heads * HEAD_DIM]
|
||||
const T *__restrict__ sinks, // [q_num_heads]
|
||||
const int *__restrict__ seq_lens,
|
||||
const int *__restrict__ seq_lens_kv,
|
||||
const int *__restrict__ batch_ids,
|
||||
@@ -57,7 +58,8 @@ __global__ void multi_query_append_attention_kernel(
|
||||
float *__restrict__ tmp_m, // [token_num, num_chunks, num_heads]
|
||||
float *__restrict__ tmp_d, // [token_num, num_chunks, num_heads]
|
||||
OutT *__restrict__ out,
|
||||
const int speculate_max_draft_token_num = 5) {
|
||||
const int speculate_max_draft_token_num = 5,
|
||||
const int sliding_window = 0) {
|
||||
constexpr uint32_t num_vecs_per_head = HEAD_DIM / num_elems_per_128b<T>();
|
||||
const uint32_t btid = blockIdx.x, kv_head_idx = blockIdx.z;
|
||||
const uint32_t kv_num_heads = gridDim.z;
|
||||
@@ -244,7 +246,7 @@ __global__ void multi_query_append_attention_kernel(
|
||||
compute_qk<num_frags_x, num_frags_y, num_frags_z, T>(
|
||||
&qo_smem, &q_smem_offset_r, &k_smem, &k_smem_offset_r, s_frag);
|
||||
// mask according to kv_idx and q_idx
|
||||
if (iter >= mask_check_iteration) {
|
||||
if (iter >= mask_check_iteration || sliding_window > 0) {
|
||||
mask_s<T,
|
||||
partition_kv,
|
||||
CAUSAL,
|
||||
@@ -260,7 +262,8 @@ __global__ void multi_query_append_attention_kernel(
|
||||
chunk_end,
|
||||
-1,
|
||||
s_frag,
|
||||
mask_offset_this_seq);
|
||||
mask_offset_this_seq,
|
||||
sliding_window);
|
||||
|
||||
}
|
||||
|
||||
@@ -318,8 +321,21 @@ __global__ void multi_query_append_attention_kernel(
|
||||
wait_group<0>();
|
||||
__syncthreads();
|
||||
|
||||
if constexpr (!partition_kv) {
|
||||
normalize_d<num_frags_x, num_frags_y>(o_frag, d_frag);
|
||||
if constexpr (!partition_kv ) {
|
||||
if (sinks) {
|
||||
float current_sinks[num_frags_x][2];
|
||||
#pragma unroll
|
||||
for (uint32_t fx = 0; fx < num_frags_x; ++fx) {
|
||||
#pragma unroll
|
||||
for (uint32_t j = 0; j < 2; ++j) {
|
||||
const uint32_t h_offset = (q_base_seq_id_this_block + fx * 16 + tid / 4 + 8 * j) % GROUP_SIZE;
|
||||
current_sinks[fx][j] = static_cast<float>(sinks[q_head_idx + h_offset]);
|
||||
}
|
||||
}
|
||||
normalize_d<num_frags_x, num_frags_y>(o_frag, d_frag, m_frag, current_sinks);
|
||||
} else {
|
||||
normalize_d<num_frags_x, num_frags_y>(o_frag, d_frag);
|
||||
}
|
||||
}
|
||||
if constexpr (partition_kv) {
|
||||
write_o_reg_gmem_shift_smooth_quant<GROUP_SIZE,
|
||||
@@ -411,6 +427,7 @@ __global__ void multi_query_append_attention_warp1_4_kernel(
|
||||
T *__restrict__ cache_v,
|
||||
const T *__restrict__ shift_bias, // [q_num_heads * HEAD_DIM]
|
||||
const T *__restrict__ smooth_weight, // [q_num_heads * HEAD_DIM]
|
||||
const T *__restrict__ sinks, // [q_num_heads]
|
||||
const int *__restrict__ seq_lens,
|
||||
const int *__restrict__ seq_lens_kv,
|
||||
const int *__restrict__ batch_ids,
|
||||
@@ -434,7 +451,8 @@ __global__ void multi_query_append_attention_warp1_4_kernel(
|
||||
float *__restrict__ tmp_d, // [token_num, num_chunks, num_heads]
|
||||
OutT *__restrict__ out,
|
||||
const int speculate_max_draft_token_num = 5,
|
||||
const uint32_t attn_mask_len = -1) {
|
||||
const uint32_t attn_mask_len = -1,
|
||||
const int sliding_window = 0) {
|
||||
constexpr uint32_t num_vecs_per_head = HEAD_DIM / num_elems_per_128b<T>();
|
||||
static_assert(NUM_WARP_Q == 1, "NUM_WARP_Q must be 1");
|
||||
static_assert(NUM_WARP_KV == 4, "NUM_WARP_KV must be 4");
|
||||
@@ -622,7 +640,7 @@ __global__ void multi_query_append_attention_warp1_4_kernel(
|
||||
compute_qk<num_frags_x, num_frags_y, num_frags_z, T>(
|
||||
&qo_smem, &q_smem_offset_r, &k_smem, &k_smem_offset_r, s_frag);
|
||||
// mask according to kv_idx and q_idx
|
||||
if (iter >= mask_check_iteration) {
|
||||
if (iter >= mask_check_iteration || sliding_window > 0) {
|
||||
mask_s<T,
|
||||
partition_kv,
|
||||
CAUSAL,
|
||||
@@ -638,7 +656,8 @@ __global__ void multi_query_append_attention_warp1_4_kernel(
|
||||
chunk_end,
|
||||
attn_mask_len,
|
||||
s_frag,
|
||||
mask_offset_this_seq);
|
||||
mask_offset_this_seq,
|
||||
sliding_window);
|
||||
}
|
||||
|
||||
// update m,d
|
||||
@@ -699,7 +718,20 @@ __global__ void multi_query_append_attention_warp1_4_kernel(
|
||||
o_frag, reinterpret_cast<float *>(smem), m_frag, d_frag, wid, tid);
|
||||
|
||||
if (num_chunks_this_seq <= 1) {
|
||||
normalize_d<num_frags_x, num_frags_y>(o_frag, d_frag);
|
||||
if (sinks) {
|
||||
float current_sinks[num_frags_x][2];
|
||||
#pragma unroll
|
||||
for (uint32_t fx = 0; fx < num_frags_x; ++fx) {
|
||||
#pragma unroll
|
||||
for (uint32_t j = 0; j < 2; ++j) {
|
||||
const uint32_t h_offset = (q_base_seq_id_this_block + fx * 16 + tid / 4 + 8 * j) % GROUP_SIZE;
|
||||
current_sinks[fx][j] = static_cast<float>(sinks[q_head_idx + h_offset]);
|
||||
}
|
||||
}
|
||||
normalize_d<num_frags_x, num_frags_y>(o_frag, d_frag, m_frag, current_sinks);
|
||||
} else {
|
||||
normalize_d<num_frags_x, num_frags_y>(o_frag, d_frag);
|
||||
}
|
||||
}
|
||||
|
||||
// write o
|
||||
@@ -792,6 +824,7 @@ void MultiQueryAppendAttention(
|
||||
const paddle::optional<paddle::Tensor> &attn_mask,
|
||||
const paddle::optional<paddle::Tensor> &shift_bias,
|
||||
const paddle::optional<paddle::Tensor> &smooth_weight,
|
||||
const paddle::optional<paddle::Tensor> &sinks,
|
||||
const paddle::Tensor &seq_lens_q,
|
||||
const paddle::Tensor &seq_lens_kv,
|
||||
const paddle::Tensor &seq_lens_encoder,
|
||||
@@ -811,7 +844,8 @@ void MultiQueryAppendAttention(
|
||||
const int speculate_max_draft_token_num,
|
||||
const bool is_decoder,
|
||||
cudaStream_t &stream,
|
||||
paddle::Tensor *out) {
|
||||
paddle::Tensor *out,
|
||||
const int sliding_window) {
|
||||
using NV_TYPE = typename cascade_attn_type_traits<T>::type;
|
||||
using OUT_NV_TYPE = typename cascade_attn_type_traits<OutT>::type;
|
||||
|
||||
@@ -898,6 +932,9 @@ void MultiQueryAppendAttention(
|
||||
smooth_weight ? reinterpret_cast<NV_TYPE *>(
|
||||
const_cast<T *>(smooth_weight.get().data<T>()))
|
||||
: nullptr,
|
||||
sinks ? reinterpret_cast<NV_TYPE *>(
|
||||
const_cast<T *>(sinks.get().data<T>()))
|
||||
: nullptr,
|
||||
seq_lens_q.data<int>(),
|
||||
seq_lens_kv.data<int>(),
|
||||
batch_ids.data<int>(),
|
||||
@@ -918,7 +955,8 @@ void MultiQueryAppendAttention(
|
||||
nullptr,
|
||||
nullptr,
|
||||
reinterpret_cast<OUT_NV_TYPE *>(out->data<OutT>()),
|
||||
speculate_max_draft_token_num);
|
||||
speculate_max_draft_token_num,
|
||||
sliding_window);
|
||||
|
||||
} else {
|
||||
phi::Allocator::AllocationPtr tmp_workspace, tmp_m, tmp_d;
|
||||
@@ -957,6 +995,9 @@ void MultiQueryAppendAttention(
|
||||
smooth_weight ? reinterpret_cast<NV_TYPE *>(
|
||||
const_cast<T *>(smooth_weight.get().data<T>()))
|
||||
: nullptr,
|
||||
sinks ? reinterpret_cast<NV_TYPE *>(
|
||||
const_cast<T *>(sinks.get().data<T>()))
|
||||
: nullptr,
|
||||
seq_lens_q.data<int>(),
|
||||
seq_lens_kv.data<int>(),
|
||||
batch_ids.data<int>(),
|
||||
@@ -977,7 +1018,8 @@ void MultiQueryAppendAttention(
|
||||
static_cast<float *>(tmp_m->ptr()),
|
||||
static_cast<float *>(tmp_d->ptr()),
|
||||
reinterpret_cast<OUT_NV_TYPE *>(out->data<OutT>()),
|
||||
speculate_max_draft_token_num);
|
||||
speculate_max_draft_token_num,
|
||||
sliding_window);
|
||||
// merge
|
||||
constexpr int vec_size = num_elems_per_128b<NV_TYPE>();
|
||||
if (is_decoder) {
|
||||
@@ -1005,6 +1047,9 @@ void MultiQueryAppendAttention(
|
||||
smooth_weight ? reinterpret_cast<NV_TYPE *>(const_cast<T *>(
|
||||
smooth_weight.get().data<T>()))
|
||||
: nullptr,
|
||||
sinks ? reinterpret_cast<NV_TYPE *>(
|
||||
const_cast<T *>(sinks.get().data<T>()))
|
||||
: nullptr,
|
||||
reinterpret_cast<OUT_NV_TYPE *>(out->data<OutT>()),
|
||||
quant_max_bound,
|
||||
quant_min_bound,
|
||||
@@ -1041,6 +1086,9 @@ void MultiQueryAppendAttention(
|
||||
smooth_weight ? reinterpret_cast<NV_TYPE *>(const_cast<T *>(
|
||||
smooth_weight.get().data<T>()))
|
||||
: nullptr,
|
||||
sinks ? reinterpret_cast<NV_TYPE *>(
|
||||
const_cast<T *>(sinks.get().data<T>()))
|
||||
: nullptr,
|
||||
reinterpret_cast<OUT_NV_TYPE *>(out->data<OutT>()),
|
||||
quant_max_bound,
|
||||
quant_min_bound,
|
||||
@@ -1130,6 +1178,9 @@ void MultiQueryAppendAttention(
|
||||
smooth_weight ? reinterpret_cast<NV_TYPE *>(
|
||||
const_cast<T *>(smooth_weight.get().data<T>()))
|
||||
: nullptr,
|
||||
sinks ? reinterpret_cast<NV_TYPE *>(
|
||||
const_cast<T *>(sinks.get().data<T>()))
|
||||
: nullptr,
|
||||
seq_lens_q.data<int>(),
|
||||
seq_lens_kv.data<int>(),
|
||||
batch_ids.data<int>(),
|
||||
@@ -1153,7 +1204,8 @@ void MultiQueryAppendAttention(
|
||||
nullptr,
|
||||
reinterpret_cast<OUT_NV_TYPE *>(out->data<OutT>()),
|
||||
speculate_max_draft_token_num,
|
||||
attn_mask_len);
|
||||
attn_mask_len,
|
||||
sliding_window);
|
||||
} else {
|
||||
phi::Allocator::AllocationPtr tmp_workspace, tmp_m, tmp_d;
|
||||
if (is_decoder) {
|
||||
@@ -1203,6 +1255,9 @@ void MultiQueryAppendAttention(
|
||||
smooth_weight ? reinterpret_cast<NV_TYPE *>(
|
||||
const_cast<T *>(smooth_weight.get().data<T>()))
|
||||
: nullptr,
|
||||
sinks ? reinterpret_cast<NV_TYPE *>(
|
||||
const_cast<T *>(sinks.get().data<T>()))
|
||||
: nullptr,
|
||||
seq_lens_q.data<int>(),
|
||||
seq_lens_kv.data<int>(),
|
||||
batch_ids.data<int>(),
|
||||
@@ -1226,7 +1281,8 @@ void MultiQueryAppendAttention(
|
||||
static_cast<float *>(tmp_d->ptr()),
|
||||
reinterpret_cast<OUT_NV_TYPE *>(out->data<OutT>()),
|
||||
speculate_max_draft_token_num,
|
||||
attn_mask_len);
|
||||
attn_mask_len,
|
||||
sliding_window);
|
||||
|
||||
// merge
|
||||
constexpr int vec_size = num_elems_per_128b<NV_TYPE>();
|
||||
@@ -1255,6 +1311,9 @@ void MultiQueryAppendAttention(
|
||||
smooth_weight ? reinterpret_cast<NV_TYPE *>(const_cast<T *>(
|
||||
smooth_weight.get().data<T>()))
|
||||
: nullptr,
|
||||
sinks ? reinterpret_cast<NV_TYPE *>(
|
||||
const_cast<T *>(sinks.get().data<T>()))
|
||||
: nullptr,
|
||||
reinterpret_cast<OUT_NV_TYPE *>(out->data<OutT>()),
|
||||
quant_max_bound,
|
||||
quant_min_bound,
|
||||
@@ -1291,6 +1350,9 @@ void MultiQueryAppendAttention(
|
||||
smooth_weight ? reinterpret_cast<NV_TYPE *>(const_cast<T *>(
|
||||
smooth_weight.get().data<T>()))
|
||||
: nullptr,
|
||||
sinks ? reinterpret_cast<NV_TYPE *>(
|
||||
const_cast<T *>(sinks.get().data<T>()))
|
||||
: nullptr,
|
||||
reinterpret_cast<OUT_NV_TYPE *>(out->data<OutT>()),
|
||||
quant_max_bound,
|
||||
quant_min_bound,
|
||||
|
||||
@@ -32,6 +32,7 @@ void MultiQueryAppendAttention(
|
||||
const paddle::optional<paddle::Tensor> &attn_mask,
|
||||
const paddle::optional<paddle::Tensor> &shift_bias,
|
||||
const paddle::optional<paddle::Tensor> &smooth_weight,
|
||||
const paddle::optional<paddle::Tensor> &sinks,
|
||||
const paddle::Tensor &seq_lens_q,
|
||||
const paddle::Tensor &seq_lens_kv,
|
||||
const paddle::Tensor &seq_lens_encoder,
|
||||
@@ -51,4 +52,5 @@ void MultiQueryAppendAttention(
|
||||
const int speculate_max_draft_token_num,
|
||||
const bool is_decoder,
|
||||
cudaStream_t &stream,
|
||||
paddle::Tensor *out);
|
||||
paddle::Tensor *out,
|
||||
const int sliding_window);
|
||||
|
||||
@@ -41,6 +41,7 @@ __global__ void multi_query_append_attention_c4_kernel(
|
||||
const T *__restrict__ cache_v_zero_point, // [num_kv_heads, head_dim]
|
||||
const T *__restrict__ shift_bias, // [q_num_heads * HEAD_DIM]
|
||||
const T *__restrict__ smooth_weight, // [q_num_heads * HEAD_DIM]
|
||||
const T *__restrict__ sinks, // [q_num_heads]
|
||||
const int *__restrict__ seq_lens,
|
||||
const int *__restrict__ seq_lens_kv,
|
||||
const int *__restrict__ batch_ids,
|
||||
@@ -62,7 +63,8 @@ __global__ void multi_query_append_attention_c4_kernel(
|
||||
float *__restrict__ tmp_m, // [token_num, num_chunks, num_heads]
|
||||
float *__restrict__ tmp_d, // [token_num, num_chunks, num_heads]
|
||||
OutT *__restrict__ out,
|
||||
const int speculate_max_draft_token_num = 5) {
|
||||
const int speculate_max_draft_token_num = 5,
|
||||
const int sliding_window = 0) {
|
||||
constexpr uint32_t num_vecs_per_head = HEAD_DIM / num_elems_per_128b<T>();
|
||||
constexpr uint32_t num_vecs_per_head_k =
|
||||
HEAD_DIM / 2 / num_elems_per_128b<CacheT>();
|
||||
@@ -332,7 +334,7 @@ __global__ void multi_query_append_attention_c4_kernel(
|
||||
cache_k_scale_frag,
|
||||
cache_k_zp_frag);
|
||||
|
||||
if (iter >= mask_check_iteration) {
|
||||
if (iter >= mask_check_iteration || sliding_window > 0) {
|
||||
mask_s<T,
|
||||
partition_kv,
|
||||
CAUSAL,
|
||||
@@ -348,7 +350,8 @@ __global__ void multi_query_append_attention_c4_kernel(
|
||||
chunk_end,
|
||||
-1,
|
||||
s_frag,
|
||||
mask_offset_this_seq);
|
||||
mask_offset_this_seq,
|
||||
sliding_window);
|
||||
}
|
||||
|
||||
update_mdo_states<num_frags_x, num_frags_y, num_frags_z>(
|
||||
@@ -412,7 +415,20 @@ __global__ void multi_query_append_attention_c4_kernel(
|
||||
__syncthreads();
|
||||
|
||||
if constexpr (!partition_kv) {
|
||||
normalize_d<num_frags_x, num_frags_y>(o_frag, d_frag);
|
||||
if (sinks) {
|
||||
float current_sinks[num_frags_x][2];
|
||||
#pragma unroll
|
||||
for (uint32_t fx = 0; fx < num_frags_x; ++fx) {
|
||||
#pragma unroll
|
||||
for (uint32_t j = 0; j < 2; ++j) {
|
||||
const uint32_t h_offset = (q_base_seq_id_this_block + fx * 16 + tid / 4 + 8 * j) % GROUP_SIZE;
|
||||
current_sinks[fx][j] = static_cast<float>(sinks[q_head_idx + h_offset]);
|
||||
}
|
||||
}
|
||||
normalize_d<num_frags_x, num_frags_y>(o_frag, d_frag, m_frag, current_sinks);
|
||||
} else {
|
||||
normalize_d<num_frags_x, num_frags_y>(o_frag, d_frag);
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr (partition_kv) {
|
||||
@@ -509,6 +525,7 @@ __global__ void multi_query_append_attention_c4_warp1_4_kernel(
|
||||
const T *__restrict__ cache_v_zero_point, // [num_kv_heads, head_dim]
|
||||
const T *__restrict__ shift_bias, // [q_num_heads * HEAD_DIM]
|
||||
const T *__restrict__ smooth_weight, // [q_num_heads * HEAD_DIM]
|
||||
const T *__restrict__ sinks, // [q_num_heads]
|
||||
const int *__restrict__ seq_lens,
|
||||
const int *__restrict__ seq_lens_kv,
|
||||
const int *__restrict__ batch_ids,
|
||||
@@ -532,7 +549,8 @@ __global__ void multi_query_append_attention_c4_warp1_4_kernel(
|
||||
float *__restrict__ tmp_d, // [token_num, num_chunks, num_heads]
|
||||
OutT *__restrict__ out,
|
||||
const int speculate_max_draft_token_num = 5,
|
||||
const uint32_t attn_mask_len = -1) {
|
||||
const uint32_t attn_mask_len = -1,
|
||||
const int sliding_window = 0) {
|
||||
constexpr uint32_t num_vecs_per_head = HEAD_DIM / num_elems_per_128b<T>();
|
||||
constexpr uint32_t num_vecs_per_head_k =
|
||||
HEAD_DIM / 2 / num_elems_per_128b<CacheT>();
|
||||
@@ -798,7 +816,7 @@ __global__ void multi_query_append_attention_c4_warp1_4_kernel(
|
||||
s_frag,
|
||||
cache_k_scale_frag,
|
||||
cache_k_zp_frag);
|
||||
if (iter >= mask_check_iteration) {
|
||||
if (iter >= mask_check_iteration || sliding_window > 0) {
|
||||
mask_s<T,
|
||||
partition_kv,
|
||||
CAUSAL,
|
||||
@@ -814,7 +832,8 @@ __global__ void multi_query_append_attention_c4_warp1_4_kernel(
|
||||
chunk_end,
|
||||
attn_mask_len,
|
||||
s_frag,
|
||||
mask_offset_this_seq);
|
||||
mask_offset_this_seq,
|
||||
sliding_window);
|
||||
}
|
||||
|
||||
update_mdo_states<num_frags_x, num_frags_y, num_frags_z>(
|
||||
@@ -882,7 +901,20 @@ __global__ void multi_query_append_attention_c4_warp1_4_kernel(
|
||||
o_frag, reinterpret_cast<float *>(smem), m_frag, d_frag, wid, tid);
|
||||
|
||||
if (num_chunks_this_seq <= 1) {
|
||||
normalize_d<num_frags_x, num_frags_y>(o_frag, d_frag);
|
||||
if (sinks) {
|
||||
float current_sinks[num_frags_x][2];
|
||||
#pragma unroll
|
||||
for (uint32_t fx = 0; fx < num_frags_x; ++fx) {
|
||||
#pragma unroll
|
||||
for (uint32_t j = 0; j < 2; ++j) {
|
||||
const uint32_t h_offset = (q_base_seq_id_this_block + fx * 16 + tid / 4 + 8 * j) % GROUP_SIZE;
|
||||
current_sinks[fx][j] = static_cast<float>(sinks[q_head_idx + h_offset]);
|
||||
}
|
||||
}
|
||||
normalize_d<num_frags_x, num_frags_y>(o_frag, d_frag, m_frag, current_sinks);
|
||||
} else {
|
||||
normalize_d<num_frags_x, num_frags_y>(o_frag, d_frag);
|
||||
}
|
||||
}
|
||||
|
||||
// write o
|
||||
@@ -978,6 +1010,7 @@ void MultiQueryAppendC4Attention(
|
||||
const paddle::optional<paddle::Tensor> &cache_v_zp,
|
||||
const paddle::optional<paddle::Tensor> &shift_bias,
|
||||
const paddle::optional<paddle::Tensor> &smooth_weight,
|
||||
const paddle::optional<paddle::Tensor> &sinks,
|
||||
const paddle::Tensor &seq_lens_q,
|
||||
const paddle::Tensor &seq_lens_kv,
|
||||
const paddle::Tensor &seq_lens_encoder,
|
||||
@@ -997,7 +1030,8 @@ void MultiQueryAppendC4Attention(
|
||||
const int speculate_max_draft_token_num,
|
||||
const bool is_decoder,
|
||||
cudaStream_t &stream,
|
||||
paddle::Tensor *out) {
|
||||
paddle::Tensor *out,
|
||||
const int sliding_window) {
|
||||
using NV_TYPE = typename cascade_attn_type_traits<T>::type;
|
||||
using OUT_NV_TYPE = typename cascade_attn_type_traits<OutT>::type;
|
||||
|
||||
@@ -1103,6 +1137,9 @@ void MultiQueryAppendC4Attention(
|
||||
smooth_weight ? reinterpret_cast<NV_TYPE *>(
|
||||
const_cast<T *>(smooth_weight.get().data<T>()))
|
||||
: nullptr,
|
||||
sinks ? reinterpret_cast<NV_TYPE *>(
|
||||
const_cast<T *>(sinks.get().data<T>()))
|
||||
: nullptr,
|
||||
seq_lens_q.data<int>(),
|
||||
seq_lens_kv.data<int>(),
|
||||
batch_ids.data<int>(),
|
||||
@@ -1123,7 +1160,8 @@ void MultiQueryAppendC4Attention(
|
||||
nullptr,
|
||||
nullptr,
|
||||
reinterpret_cast<OUT_NV_TYPE *>(out->data<OutT>()),
|
||||
speculate_max_draft_token_num);
|
||||
speculate_max_draft_token_num,
|
||||
sliding_window);
|
||||
} else {
|
||||
phi::Allocator::AllocationPtr tmp_workspace, tmp_m, tmp_d;
|
||||
if (ENABLE_PREFILL) {
|
||||
@@ -1168,6 +1206,9 @@ void MultiQueryAppendC4Attention(
|
||||
smooth_weight ? reinterpret_cast<NV_TYPE *>(
|
||||
const_cast<T *>(smooth_weight.get().data<T>()))
|
||||
: nullptr,
|
||||
sinks ? reinterpret_cast<NV_TYPE *>(
|
||||
const_cast<T *>(sinks.get().data<T>()))
|
||||
: nullptr,
|
||||
seq_lens_q.data<int>(),
|
||||
seq_lens_kv.data<int>(),
|
||||
batch_ids.data<int>(),
|
||||
@@ -1188,7 +1229,8 @@ void MultiQueryAppendC4Attention(
|
||||
static_cast<float *>(tmp_m->ptr()),
|
||||
static_cast<float *>(tmp_d->ptr()),
|
||||
reinterpret_cast<OUT_NV_TYPE *>(out->data<OutT>()),
|
||||
speculate_max_draft_token_num);
|
||||
speculate_max_draft_token_num,
|
||||
sliding_window);
|
||||
// merge
|
||||
constexpr int vec_size = num_elems_per_128b<NV_TYPE>();
|
||||
if (is_decoder) {
|
||||
@@ -1216,6 +1258,9 @@ void MultiQueryAppendC4Attention(
|
||||
smooth_weight ? reinterpret_cast<NV_TYPE *>(const_cast<T *>(
|
||||
smooth_weight.get().data<T>()))
|
||||
: nullptr,
|
||||
sinks ? reinterpret_cast<NV_TYPE *>(
|
||||
const_cast<T *>(sinks.get().data<T>()))
|
||||
: nullptr,
|
||||
reinterpret_cast<OUT_NV_TYPE *>(out->data<OutT>()),
|
||||
quant_max_bound,
|
||||
quant_min_bound,
|
||||
@@ -1252,6 +1297,9 @@ void MultiQueryAppendC4Attention(
|
||||
smooth_weight ? reinterpret_cast<NV_TYPE *>(const_cast<T *>(
|
||||
smooth_weight.get().data<T>()))
|
||||
: nullptr,
|
||||
sinks ? reinterpret_cast<NV_TYPE *>(
|
||||
const_cast<T *>(sinks.get().data<T>()))
|
||||
: nullptr,
|
||||
reinterpret_cast<OUT_NV_TYPE *>(out->data<OutT>()),
|
||||
quant_max_bound,
|
||||
quant_min_bound,
|
||||
@@ -1361,6 +1409,9 @@ void MultiQueryAppendC4Attention(
|
||||
smooth_weight ? reinterpret_cast<NV_TYPE *>(
|
||||
const_cast<T *>(smooth_weight.get().data<T>()))
|
||||
: nullptr,
|
||||
sinks ? reinterpret_cast<NV_TYPE *>(
|
||||
const_cast<T *>(sinks.get().data<T>()))
|
||||
: nullptr,
|
||||
seq_lens_q.data<int>(),
|
||||
seq_lens_kv.data<int>(),
|
||||
batch_ids.data<int>(),
|
||||
@@ -1384,7 +1435,8 @@ void MultiQueryAppendC4Attention(
|
||||
nullptr,
|
||||
reinterpret_cast<OUT_NV_TYPE *>(out->data<OutT>()),
|
||||
speculate_max_draft_token_num,
|
||||
attn_mask_len);
|
||||
attn_mask_len,
|
||||
sliding_window);
|
||||
} else {
|
||||
phi::Allocator::AllocationPtr tmp_workspace, tmp_m, tmp_d;
|
||||
if (is_decoder) {
|
||||
@@ -1442,6 +1494,9 @@ void MultiQueryAppendC4Attention(
|
||||
smooth_weight ? reinterpret_cast<NV_TYPE *>(
|
||||
const_cast<T *>(smooth_weight.get().data<T>()))
|
||||
: nullptr,
|
||||
sinks ? reinterpret_cast<NV_TYPE *>(
|
||||
const_cast<T *>(sinks.get().data<T>()))
|
||||
: nullptr,
|
||||
seq_lens_q.data<int>(),
|
||||
seq_lens_kv.data<int>(),
|
||||
batch_ids.data<int>(),
|
||||
@@ -1465,7 +1520,8 @@ void MultiQueryAppendC4Attention(
|
||||
static_cast<float *>(tmp_d->ptr()),
|
||||
reinterpret_cast<OUT_NV_TYPE *>(out->data<OutT>()),
|
||||
speculate_max_draft_token_num,
|
||||
attn_mask_len);
|
||||
attn_mask_len,
|
||||
sliding_window);
|
||||
// merge
|
||||
constexpr int vec_size = num_elems_per_128b<NV_TYPE>();
|
||||
if (is_decoder) {
|
||||
@@ -1493,6 +1549,9 @@ void MultiQueryAppendC4Attention(
|
||||
smooth_weight ? reinterpret_cast<NV_TYPE *>(const_cast<T *>(
|
||||
smooth_weight.get().data<T>()))
|
||||
: nullptr,
|
||||
sinks ? reinterpret_cast<NV_TYPE *>(
|
||||
const_cast<T *>(sinks.get().data<T>()))
|
||||
: nullptr,
|
||||
reinterpret_cast<OUT_NV_TYPE *>(out->data<OutT>()),
|
||||
quant_max_bound,
|
||||
quant_min_bound,
|
||||
@@ -1529,6 +1588,9 @@ void MultiQueryAppendC4Attention(
|
||||
smooth_weight ? reinterpret_cast<NV_TYPE *>(const_cast<T *>(
|
||||
smooth_weight.get().data<T>()))
|
||||
: nullptr,
|
||||
sinks ? reinterpret_cast<NV_TYPE *>(
|
||||
const_cast<T *>(sinks.get().data<T>()))
|
||||
: nullptr,
|
||||
reinterpret_cast<OUT_NV_TYPE *>(out->data<OutT>()),
|
||||
quant_max_bound,
|
||||
quant_min_bound,
|
||||
|
||||
@@ -36,6 +36,7 @@ void MultiQueryAppendC4Attention(
|
||||
const paddle::optional<paddle::Tensor> &cache_v_zp,
|
||||
const paddle::optional<paddle::Tensor> &shift_bias,
|
||||
const paddle::optional<paddle::Tensor> &smooth_weight,
|
||||
const paddle::optional<paddle::Tensor> &sinks,
|
||||
const paddle::Tensor &seq_lens_q,
|
||||
const paddle::Tensor &seq_lens_kv,
|
||||
const paddle::Tensor &seq_lens_encoder,
|
||||
@@ -55,4 +56,5 @@ void MultiQueryAppendC4Attention(
|
||||
const int speculate_max_draft_token_num,
|
||||
const bool is_decoder,
|
||||
cudaStream_t &stream,
|
||||
paddle::Tensor *out);
|
||||
paddle::Tensor *out,
|
||||
const int sliding_window);
|
||||
|
||||
@@ -42,6 +42,7 @@ __global__ void multi_query_append_attention_c8_kernel(
|
||||
const T *__restrict__ cache_v_scale, // [num_kv_heads] or [max_block_num, num_heads, block_size]
|
||||
const T *__restrict__ shift_bias, // [q_num_heads * HEAD_DIM]
|
||||
const T *__restrict__ smooth_weight, // [q_num_heads * HEAD_DIM]
|
||||
const T *__restrict__ sinks, // [q_num_heads]
|
||||
const int *__restrict__ seq_lens,
|
||||
const int *__restrict__ seq_lens_kv,
|
||||
const int *__restrict__ batch_ids,
|
||||
@@ -63,7 +64,8 @@ __global__ void multi_query_append_attention_c8_kernel(
|
||||
float *__restrict__ tmp_m, // [token_num, num_chunks, num_heads]
|
||||
float *__restrict__ tmp_d, // [token_num, num_chunks, num_heads]
|
||||
OutT *__restrict__ out,
|
||||
const int speculate_max_draft_token_num = 5) {
|
||||
const int speculate_max_draft_token_num = 5,
|
||||
const int sliding_window = 0) {
|
||||
constexpr uint32_t num_vecs_per_head =
|
||||
HEAD_DIM / num_elems_per_128b<T>(); // 128 / 8 = 16
|
||||
constexpr uint32_t num_vecs_per_head_k =
|
||||
@@ -321,7 +323,7 @@ __global__ void multi_query_append_attention_c8_kernel(
|
||||
s_frag);
|
||||
|
||||
// mask according to kv_idx and q_idx
|
||||
if (iter >= mask_check_iteration) {
|
||||
if (iter >= mask_check_iteration || sliding_window > 0) {
|
||||
mask_s<T,
|
||||
partition_kv,
|
||||
CAUSAL,
|
||||
@@ -337,7 +339,8 @@ __global__ void multi_query_append_attention_c8_kernel(
|
||||
chunk_end,
|
||||
-1,
|
||||
s_frag,
|
||||
mask_offset_this_seq);
|
||||
mask_offset_this_seq,
|
||||
sliding_window);
|
||||
}
|
||||
|
||||
// update m,d
|
||||
@@ -415,7 +418,20 @@ __global__ void multi_query_append_attention_c8_kernel(
|
||||
__syncthreads();
|
||||
|
||||
if constexpr (!partition_kv) {
|
||||
normalize_d<num_frags_x, num_frags_y>(o_frag, d_frag);
|
||||
if (sinks) {
|
||||
float current_sinks[num_frags_x][2];
|
||||
#pragma unroll
|
||||
for (uint32_t fx = 0; fx < num_frags_x; ++fx) {
|
||||
#pragma unroll
|
||||
for (uint32_t j = 0; j < 2; ++j) {
|
||||
const uint32_t h_offset = (q_base_seq_id_this_block + fx * 16 + tid / 4 + 8 * j) % GROUP_SIZE;
|
||||
current_sinks[fx][j] = static_cast<float>(sinks[q_head_idx + h_offset]);
|
||||
}
|
||||
}
|
||||
normalize_d<num_frags_x, num_frags_y>(o_frag, d_frag, m_frag, current_sinks);
|
||||
} else {
|
||||
normalize_d<num_frags_x, num_frags_y>(o_frag, d_frag);
|
||||
}
|
||||
}
|
||||
|
||||
// write o
|
||||
@@ -516,6 +532,7 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel(
|
||||
const T *__restrict__ cache_v_scale, // [num_kv_heads] or [max_block_num, num_heads, block_size]
|
||||
const T *__restrict__ shift_bias, // [q_num_heads * HEAD_DIM]
|
||||
const T *__restrict__ smooth_weight, // [q_num_heads * HEAD_DIM]
|
||||
const T *__restrict__ sinks, // [q_num_heads]
|
||||
const int *__restrict__ seq_lens,
|
||||
const int *__restrict__ seq_lens_kv,
|
||||
const int *__restrict__ batch_ids,
|
||||
@@ -539,7 +556,8 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel(
|
||||
float *__restrict__ tmp_d, // [token_num, num_chunks, num_heads]
|
||||
OutT *__restrict__ out,
|
||||
const int speculate_max_draft_token_num = 5,
|
||||
const uint32_t attn_mask_len = -1) {
|
||||
const uint32_t attn_mask_len = -1,
|
||||
const int sliding_window = 0) {
|
||||
constexpr uint32_t num_vecs_per_head = HEAD_DIM / num_elems_per_128b<T>();
|
||||
constexpr uint32_t num_vecs_per_head_k =
|
||||
HEAD_DIM / num_elems_per_128b<CacheT>();
|
||||
@@ -798,7 +816,7 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel(
|
||||
cache_k_scale_reg,
|
||||
s_frag);
|
||||
// mask according to kv_idx and q_idx
|
||||
if (iter >= mask_check_iteration) {
|
||||
if (iter >= mask_check_iteration || sliding_window > 0) {
|
||||
mask_s<T,
|
||||
partition_kv,
|
||||
CAUSAL,
|
||||
@@ -814,7 +832,8 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel(
|
||||
chunk_end,
|
||||
attn_mask_len,
|
||||
s_frag,
|
||||
mask_offset_this_seq);
|
||||
mask_offset_this_seq,
|
||||
sliding_window);
|
||||
|
||||
}
|
||||
|
||||
@@ -895,7 +914,20 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel(
|
||||
o_frag, reinterpret_cast<float *>(smem), m_frag, d_frag, wid, tid);
|
||||
|
||||
if (num_chunks_this_seq <= 1) {
|
||||
normalize_d<num_frags_x, num_frags_y>(o_frag, d_frag);
|
||||
if (sinks) {
|
||||
float current_sinks[num_frags_x][2];
|
||||
#pragma unroll
|
||||
for (uint32_t fx = 0; fx < num_frags_x; ++fx) {
|
||||
#pragma unroll
|
||||
for (uint32_t j = 0; j < 2; ++j) {
|
||||
const uint32_t h_offset = (q_base_seq_id_this_block + fx * 16 + tid / 4 + 8 * j) % GROUP_SIZE;
|
||||
current_sinks[fx][j] = static_cast<float>(sinks[q_head_idx + h_offset]);
|
||||
}
|
||||
}
|
||||
normalize_d<num_frags_x, num_frags_y>(o_frag, d_frag, m_frag, current_sinks);
|
||||
} else {
|
||||
normalize_d<num_frags_x, num_frags_y>(o_frag, d_frag);
|
||||
}
|
||||
}
|
||||
|
||||
// write o
|
||||
@@ -992,6 +1024,7 @@ void MultiQueryAppendC8Attention(
|
||||
const paddle::Tensor &cache_v_scale,
|
||||
const paddle::optional<paddle::Tensor> &shift_bias,
|
||||
const paddle::optional<paddle::Tensor> &smooth_weight,
|
||||
const paddle::optional<paddle::Tensor> &sinks,
|
||||
const paddle::Tensor &seq_lens_q,
|
||||
const paddle::Tensor &seq_lens_kv,
|
||||
const paddle::Tensor &seq_lens_encoder,
|
||||
@@ -1011,7 +1044,8 @@ void MultiQueryAppendC8Attention(
|
||||
const int speculate_max_draft_token_num,
|
||||
const bool is_decoder,
|
||||
cudaStream_t &stream,
|
||||
paddle::Tensor *out) {
|
||||
paddle::Tensor *out,
|
||||
const int sliding_window) {
|
||||
using NV_TYPE = typename cascade_attn_type_traits<T>::type;
|
||||
using OUT_NV_TYPE = typename cascade_attn_type_traits<OutT>::type;
|
||||
|
||||
@@ -1155,6 +1189,9 @@ void MultiQueryAppendC8Attention(
|
||||
smooth_weight ? reinterpret_cast<NV_TYPE *>(
|
||||
const_cast<T *>(smooth_weight.get().data<T>()))
|
||||
: nullptr,
|
||||
sinks ? reinterpret_cast<NV_TYPE *>(
|
||||
const_cast<T *>(sinks.get().data<T>()))
|
||||
: nullptr,
|
||||
seq_lens_q.data<int>(),
|
||||
seq_lens_kv.data<int>(),
|
||||
batch_ids.data<int>(),
|
||||
@@ -1175,7 +1212,8 @@ void MultiQueryAppendC8Attention(
|
||||
nullptr,
|
||||
nullptr,
|
||||
reinterpret_cast<OUT_NV_TYPE *>(out->data<OutT>()),
|
||||
speculate_max_draft_token_num);
|
||||
speculate_max_draft_token_num,
|
||||
sliding_window);
|
||||
} else {
|
||||
phi::Allocator::AllocationPtr tmp_workspace, tmp_m, tmp_d;
|
||||
if (ENABLE_PREFILL) {
|
||||
@@ -1214,6 +1252,9 @@ void MultiQueryAppendC8Attention(
|
||||
smooth_weight ? reinterpret_cast<NV_TYPE *>(
|
||||
const_cast<T *>(smooth_weight.get().data<T>()))
|
||||
: nullptr,
|
||||
sinks ? reinterpret_cast<NV_TYPE *>(
|
||||
const_cast<T *>(sinks.get().data<T>()))
|
||||
: nullptr,
|
||||
seq_lens_q.data<int>(),
|
||||
seq_lens_kv.data<int>(),
|
||||
batch_ids.data<int>(),
|
||||
@@ -1234,7 +1275,8 @@ void MultiQueryAppendC8Attention(
|
||||
static_cast<float *>(tmp_m->ptr()),
|
||||
static_cast<float *>(tmp_d->ptr()),
|
||||
reinterpret_cast<OUT_NV_TYPE *>(out->data<OutT>()),
|
||||
speculate_max_draft_token_num);
|
||||
speculate_max_draft_token_num,
|
||||
sliding_window);
|
||||
// merge
|
||||
constexpr int vec_size = num_elems_per_128b<NV_TYPE>();
|
||||
if (is_decoder) {
|
||||
@@ -1262,6 +1304,9 @@ void MultiQueryAppendC8Attention(
|
||||
smooth_weight ? reinterpret_cast<NV_TYPE *>(const_cast<T *>(
|
||||
smooth_weight.get().data<T>()))
|
||||
: nullptr,
|
||||
sinks ? reinterpret_cast<NV_TYPE *>(
|
||||
const_cast<T *>(sinks.get().data<T>()))
|
||||
: nullptr,
|
||||
reinterpret_cast<OUT_NV_TYPE *>(out->data<OutT>()),
|
||||
quant_max_bound,
|
||||
quant_min_bound,
|
||||
@@ -1298,6 +1343,9 @@ void MultiQueryAppendC8Attention(
|
||||
smooth_weight ? reinterpret_cast<NV_TYPE *>(const_cast<T *>(
|
||||
smooth_weight.get().data<T>()))
|
||||
: nullptr,
|
||||
sinks ? reinterpret_cast<NV_TYPE *>(
|
||||
const_cast<T *>(sinks.get().data<T>()))
|
||||
: nullptr,
|
||||
reinterpret_cast<OUT_NV_TYPE *>(out->data<OutT>()),
|
||||
quant_max_bound,
|
||||
quant_min_bound,
|
||||
@@ -1439,6 +1487,9 @@ void MultiQueryAppendC8Attention(
|
||||
smooth_weight ? reinterpret_cast<NV_TYPE *>(
|
||||
const_cast<T *>(smooth_weight.get().data<T>()))
|
||||
: nullptr,
|
||||
sinks ? reinterpret_cast<NV_TYPE *>(
|
||||
const_cast<T *>(sinks.get().data<T>()))
|
||||
: nullptr,
|
||||
seq_lens_q.data<int>(),
|
||||
seq_lens_kv.data<int>(),
|
||||
batch_ids.data<int>(),
|
||||
@@ -1462,7 +1513,8 @@ void MultiQueryAppendC8Attention(
|
||||
nullptr,
|
||||
reinterpret_cast<OUT_NV_TYPE *>(out->data<OutT>()),
|
||||
speculate_max_draft_token_num,
|
||||
attn_mask_len);
|
||||
attn_mask_len,
|
||||
sliding_window);
|
||||
} else {
|
||||
phi::Allocator::AllocationPtr tmp_workspace, tmp_m, tmp_d;
|
||||
if (is_decoder) {
|
||||
@@ -1514,6 +1566,9 @@ void MultiQueryAppendC8Attention(
|
||||
smooth_weight ? reinterpret_cast<NV_TYPE *>(
|
||||
const_cast<T *>(smooth_weight.get().data<T>()))
|
||||
: nullptr,
|
||||
sinks ? reinterpret_cast<NV_TYPE *>(
|
||||
const_cast<T *>(sinks.get().data<T>()))
|
||||
: nullptr,
|
||||
seq_lens_q.data<int>(),
|
||||
seq_lens_kv.data<int>(),
|
||||
batch_ids.data<int>(),
|
||||
@@ -1537,7 +1592,8 @@ void MultiQueryAppendC8Attention(
|
||||
static_cast<float *>(tmp_d->ptr()),
|
||||
reinterpret_cast<OUT_NV_TYPE *>(out->data<OutT>()),
|
||||
speculate_max_draft_token_num,
|
||||
attn_mask_len);
|
||||
attn_mask_len,
|
||||
sliding_window);
|
||||
// merge
|
||||
constexpr int vec_size = num_elems_per_128b<NV_TYPE>();
|
||||
if (is_decoder) {
|
||||
@@ -1560,6 +1616,9 @@ void MultiQueryAppendC8Attention(
|
||||
smooth_weight ? reinterpret_cast<NV_TYPE *>(const_cast<T *>(
|
||||
smooth_weight.get().data<T>()))
|
||||
: nullptr,
|
||||
sinks ? reinterpret_cast<NV_TYPE *>(
|
||||
const_cast<T *>(sinks.get().data<T>()))
|
||||
: nullptr,
|
||||
reinterpret_cast<OUT_NV_TYPE *>(out->data<OutT>()),
|
||||
quant_max_bound,
|
||||
quant_min_bound,
|
||||
@@ -1596,6 +1655,9 @@ void MultiQueryAppendC8Attention(
|
||||
smooth_weight ? reinterpret_cast<NV_TYPE *>(const_cast<T *>(
|
||||
smooth_weight.get().data<T>()))
|
||||
: nullptr,
|
||||
sinks ? reinterpret_cast<NV_TYPE *>(
|
||||
const_cast<T *>(sinks.get().data<T>()))
|
||||
: nullptr,
|
||||
reinterpret_cast<OUT_NV_TYPE *>(out->data<OutT>()),
|
||||
quant_max_bound,
|
||||
quant_min_bound,
|
||||
|
||||
@@ -36,6 +36,7 @@ void MultiQueryAppendC8Attention(
|
||||
const paddle::Tensor &cache_v_scale,
|
||||
const paddle::optional<paddle::Tensor> &shift_bias,
|
||||
const paddle::optional<paddle::Tensor> &smooth_weight,
|
||||
const paddle::optional<paddle::Tensor> &sinks,
|
||||
const paddle::Tensor &seq_lens_q,
|
||||
const paddle::Tensor &seq_lens_kv,
|
||||
const paddle::Tensor &seq_lens_encoder,
|
||||
@@ -55,4 +56,5 @@ void MultiQueryAppendC8Attention(
|
||||
const int speculate_max_draft_token_num,
|
||||
const bool is_decoder,
|
||||
cudaStream_t &stream,
|
||||
paddle::Tensor *out);
|
||||
paddle::Tensor *out,
|
||||
const int sliding_window);
|
||||
|
||||
@@ -36,7 +36,7 @@
|
||||
],
|
||||
"max_instances_per_file": 80,
|
||||
"file_prefix": "multiquery_attention_c8_",
|
||||
"function_signature": "template void {function_name}{template_args}(\n const AppendAttnMetaData &meta_data,\n const paddle::Tensor &qkv,\n const paddle::Tensor &cache_k,\n const paddle::Tensor &cache_v,\n const paddle::optional<paddle::Tensor> &attn_mask,\n const paddle::Tensor &cache_k_scale,\n const paddle::Tensor &cache_v_scale,\n const paddle::optional<paddle::Tensor> &shift_bias,\n const paddle::optional<paddle::Tensor> &smooth_weight,\n const paddle::Tensor &seq_lens_q,\n const paddle::Tensor &seq_lens_kv,\n const paddle::Tensor &seq_lens_encoder,\n const paddle::Tensor &batch_id_per_token,\n const paddle::Tensor &cu_seqlens_q,\n const paddle::Tensor &block_table,\n const paddle::Tensor &batch_ids,\n const paddle::Tensor &tile_ids_per_batch,\n const int num_blocks_x_cpu,\n const int max_seq_len,\n const int max_dec_len,\n const float quant_max_bound,\n const float quant_min_bound,\n const float in_scale,\n const int max_partition_size,\n const int encoder_max_partition_size,\n const int speculate_max_draft_token_num,\n const bool is_decoder,\n cudaStream_t &stream,\n paddle::Tensor *out);\n\n"
|
||||
"function_signature": "template void {function_name}{template_args}(\n const AppendAttnMetaData &meta_data,\n const paddle::Tensor &qkv,\n const paddle::Tensor &cache_k,\n const paddle::Tensor &cache_v,\n const paddle::optional<paddle::Tensor> &attn_mask,\n const paddle::Tensor &cache_k_scale,\n const paddle::Tensor &cache_v_scale,\n const paddle::optional<paddle::Tensor> &shift_bias,\n const paddle::optional<paddle::Tensor> &smooth_weight,\n const paddle::optional<paddle::Tensor> &sinks,\n const paddle::Tensor &seq_lens_q,\n const paddle::Tensor &seq_lens_kv,\n const paddle::Tensor &seq_lens_encoder,\n const paddle::Tensor &batch_id_per_token,\n const paddle::Tensor &cu_seqlens_q,\n const paddle::Tensor &block_table,\n const paddle::Tensor &batch_ids,\n const paddle::Tensor &tile_ids_per_batch,\n const int num_blocks_x_cpu,\n const int max_seq_len,\n const int max_dec_len,\n const float quant_max_bound,\n const float quant_min_bound,\n const float in_scale,\n const int max_partition_size,\n const int encoder_max_partition_size,\n const int speculate_max_draft_token_num,\n const bool is_decoder,\n cudaStream_t &stream,\n paddle::Tensor *out,\n const int sliding_window);\n\n"
|
||||
},
|
||||
"multiquery_attention_c4": {
|
||||
"name": "multiquery_attention_c4",
|
||||
@@ -71,7 +71,7 @@
|
||||
],
|
||||
"max_instances_per_file": 160,
|
||||
"file_prefix": "multiquery_attention_c4_",
|
||||
"function_signature": "template void {function_name}{template_args}(\n const AppendAttnMetaData &meta_data,\n const paddle::Tensor &qkv,\n const paddle::Tensor &cache_k,\n const paddle::Tensor &cache_v,\n const paddle::optional<paddle::Tensor> &attn_mask,\n const paddle::Tensor &cache_k_scale,\n const paddle::Tensor &cache_v_scale,\n const paddle::optional<paddle::Tensor> &cache_k_zp,\n const paddle::optional<paddle::Tensor> &cache_v_zp,\n const paddle::optional<paddle::Tensor> &shift_bias,\n const paddle::optional<paddle::Tensor> &smooth_weight,\n const paddle::Tensor &seq_lens_q,\n const paddle::Tensor &seq_lens_kv,\n const paddle::Tensor &seq_lens_encoder,\n const paddle::Tensor &batch_id_per_token,\n const paddle::Tensor &cu_seqlens_q,\n const paddle::Tensor &block_table,\n const paddle::Tensor &batch_ids,\n const paddle::Tensor &tile_ids_per_batch,\n const int num_blocks_x_cpu,\n const int max_seq_len,\n const int max_dec_len,\n const float quant_max_bound,\n const float quant_min_bound,\n const float in_scale,\n const int max_partition_size,\n const int encoder_max_partition_size,\n const int speculate_max_draft_token_num,\n const bool is_decoder,\n cudaStream_t &stream,\n paddle::Tensor *out);\n\n"
|
||||
"function_signature": "template void {function_name}{template_args}(\n const AppendAttnMetaData &meta_data,\n const paddle::Tensor &qkv,\n const paddle::Tensor &cache_k,\n const paddle::Tensor &cache_v,\n const paddle::optional<paddle::Tensor> &attn_mask,\n const paddle::Tensor &cache_k_scale,\n const paddle::Tensor &cache_v_scale,\n const paddle::optional<paddle::Tensor> &cache_k_zp,\n const paddle::optional<paddle::Tensor> &cache_v_zp,\n const paddle::optional<paddle::Tensor> &shift_bias,\n const paddle::optional<paddle::Tensor> &smooth_weight,\n const paddle::optional<paddle::Tensor> &sinks,\n const paddle::Tensor &seq_lens_q,\n const paddle::Tensor &seq_lens_kv,\n const paddle::Tensor &seq_lens_encoder,\n const paddle::Tensor &batch_id_per_token,\n const paddle::Tensor &cu_seqlens_q,\n const paddle::Tensor &block_table,\n const paddle::Tensor &batch_ids,\n const paddle::Tensor &tile_ids_per_batch,\n const int num_blocks_x_cpu,\n const int max_seq_len,\n const int max_dec_len,\n const float quant_max_bound,\n const float quant_min_bound,\n const float in_scale,\n const int max_partition_size,\n const int encoder_max_partition_size,\n const int speculate_max_draft_token_num,\n const bool is_decoder,\n cudaStream_t &stream,\n paddle::Tensor *out,\n const int sliding_window);\n\n"
|
||||
},
|
||||
"multiquery_attention_c16": {
|
||||
"name": "multiquery_attention_c16",
|
||||
@@ -90,7 +90,7 @@
|
||||
],
|
||||
"dispatch_params": {
|
||||
"GROUP_SIZE": [1, 2, 4, 5, 6, 7, 8, 12, 14, 16],
|
||||
"HEAD_DIM": [128],
|
||||
"HEAD_DIM": [64,128],
|
||||
"BLOCK_SIZE": [64],
|
||||
"CAUSAL": [0, 1],
|
||||
"BLOCK_SHAPE_Q": [16, 32, 64, 128],
|
||||
@@ -106,7 +106,7 @@
|
||||
],
|
||||
"max_instances_per_file": 160,
|
||||
"file_prefix": "multiquery_attention_c16_",
|
||||
"function_signature": "template void {function_name}{template_args}(\n const AppendAttnMetaData &meta_data,\n const paddle::Tensor &qkv,\n const paddle::Tensor &cache_k,\n const paddle::Tensor &cache_v,\n const paddle::optional<paddle::Tensor> &attn_mask,\n const paddle::optional<paddle::Tensor> &shift_bias,\n const paddle::optional<paddle::Tensor> &smooth_weight,\n const paddle::Tensor &seq_lens_q,\n const paddle::Tensor &seq_lens_kv,\n const paddle::Tensor &seq_lens_encoder,\n const paddle::Tensor &batch_id_per_token,\n const paddle::Tensor &cu_seqlens_q,\n const paddle::Tensor &block_table,\n const paddle::Tensor &batch_ids,\n const paddle::Tensor &tile_ids_per_batch,\n const int num_blocks_x_cpu,\n const int max_seq_len,\n const int max_dec_len,\n const float quant_max_bound,\n const float quant_min_bound,\n const float in_scale,\n const int max_partition_size,\n const int encoder_max_partition_size,\n const int speculate_max_draft_token_num,\n const bool is_decoder,\n cudaStream_t &stream,\n paddle::Tensor *out);\n\n"
|
||||
"function_signature": "template void {function_name}{template_args}(\n const AppendAttnMetaData &meta_data,\n const paddle::Tensor &qkv,\n const paddle::Tensor &cache_k,\n const paddle::Tensor &cache_v,\n const paddle::optional<paddle::Tensor> &attn_mask,\n const paddle::optional<paddle::Tensor> &shift_bias,\n const paddle::optional<paddle::Tensor> &smooth_weight,\n const paddle::optional<paddle::Tensor> &sinks,\n const paddle::Tensor &seq_lens_q,\n const paddle::Tensor &seq_lens_kv,\n const paddle::Tensor &seq_lens_encoder,\n const paddle::Tensor &batch_id_per_token,\n const paddle::Tensor &cu_seqlens_q,\n const paddle::Tensor &block_table,\n const paddle::Tensor &batch_ids,\n const paddle::Tensor &tile_ids_per_batch,\n const int num_blocks_x_cpu,\n const int max_seq_len,\n const int max_dec_len,\n const float quant_max_bound,\n const float quant_min_bound,\n const float in_scale,\n const int max_partition_size,\n const int encoder_max_partition_size,\n const int speculate_max_draft_token_num,\n const bool is_decoder,\n cudaStream_t &stream,\n paddle::Tensor *out,\n const int sliding_window);\n\n"
|
||||
},
|
||||
"multiquery_decoder_attention": {
|
||||
"name": "multiquery_decoder_attention",
|
||||
|
||||
@@ -301,6 +301,11 @@ __forceinline__ __host__ __device__ void vec_cast<nv_bfloat16, float>(
|
||||
|
||||
#define DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, ...) \
|
||||
switch (head_dim) { \
|
||||
case 64: { \
|
||||
constexpr size_t HEAD_DIM = 64; \
|
||||
__VA_ARGS__ \
|
||||
break; \
|
||||
} \
|
||||
case 128: { \
|
||||
constexpr size_t HEAD_DIM = 128; \
|
||||
__VA_ARGS__ \
|
||||
|
||||
@@ -81,6 +81,7 @@ std::vector<paddle::Tensor> AppendAttention(
|
||||
const paddle::optional<paddle::Tensor> &kv_signal_data,
|
||||
const paddle::optional<paddle::Tensor>& q_norm_weight,
|
||||
const paddle::optional<paddle::Tensor>& k_norm_weight,
|
||||
const paddle::optional<paddle::Tensor>& sinks,
|
||||
const float rms_norm_eps,
|
||||
const std::string &compute_dtype, const std::string &cache_quant_type_str,
|
||||
const bool use_neox_rotary_style, const bool rope_3d,
|
||||
@@ -89,7 +90,8 @@ std::vector<paddle::Tensor> AppendAttention(
|
||||
const int encoder_block_shape_q, const int decoder_block_shape_q,
|
||||
const int max_partition_size, const int encoder_max_partition_size,
|
||||
const int speculate_max_draft_token_num, const bool causal,
|
||||
const bool speculate_decoder);
|
||||
const bool speculate_decoder,
|
||||
const int sliding_window);
|
||||
|
||||
std::vector<paddle::Tensor> AppendAttentionWithOutput(
|
||||
const paddle::Tensor &qkv, const paddle::Tensor &key_cache,
|
||||
@@ -124,6 +126,7 @@ std::vector<paddle::Tensor> AppendAttentionWithOutput(
|
||||
const paddle::optional<paddle::Tensor> &kv_signal_data,
|
||||
const paddle::optional<paddle::Tensor>& q_norm_weight,
|
||||
const paddle::optional<paddle::Tensor>& k_norm_weight,
|
||||
const paddle::optional<paddle::Tensor>& sinks,
|
||||
const float rms_norm_eps,
|
||||
const std::string &compute_dtype, const std::string &cache_quant_type_str,
|
||||
const bool use_neox_rotary_style, const bool rope_3d,
|
||||
@@ -132,7 +135,8 @@ std::vector<paddle::Tensor> AppendAttentionWithOutput(
|
||||
const int encoder_block_shape_q, const int decoder_block_shape_q,
|
||||
const int max_partition_size, const int encoder_max_partition_size,
|
||||
const int speculate_max_draft_token_num, const bool causal,
|
||||
const bool speculate_decoder);
|
||||
const bool speculate_decoder,
|
||||
const int sliding_window);
|
||||
|
||||
std::vector<paddle::Tensor> GQARopeWriteCacheKernel(
|
||||
const paddle::Tensor &qkv, const paddle::Tensor &key_cache,
|
||||
@@ -248,15 +252,18 @@ std::vector<std::vector<int>> GetExpertTokenNum(const paddle::Tensor &topk_ids,
|
||||
paddle::Tensor MoeExpertFFNFunc(
|
||||
const paddle::Tensor& permute_input,
|
||||
const paddle::Tensor& tokens_expert_prefix_sum,
|
||||
const paddle::Tensor& up_gate_proj_weight, const paddle::Tensor& down_proj_weight,
|
||||
const paddle::Tensor& up_gate_proj_weight,
|
||||
const paddle::Tensor& down_proj_weight,
|
||||
const paddle::optional<paddle::Tensor>& up_gate_proj_bias,
|
||||
const paddle::optional<paddle::Tensor>& up_gate_proj_scale,
|
||||
const paddle::optional<paddle::Tensor>& down_proj_scale,
|
||||
const paddle::optional<paddle::Tensor>& down_proj_in_scale,
|
||||
const paddle::optional<paddle::Tensor>& expert_idx_per_token,
|
||||
const std::string& quant_method, const bool used_in_ep_low_latency,
|
||||
const std::string& quant_method,
|
||||
const bool used_in_ep_low_latency,
|
||||
const int estimate_total_token_nums,
|
||||
const int hadamard_block_size);
|
||||
const int hadamard_block_size,
|
||||
const std::string& activation);
|
||||
|
||||
paddle::Tensor MoeExpertFFNWint2Func(
|
||||
const paddle::Tensor& permute_input,
|
||||
|
||||
@@ -21,6 +21,7 @@
|
||||
#include "moe/fast_hardamard_kernel.h"
|
||||
#include "moe/fused_moe_helper.h"
|
||||
#include "w4afp8_gemm/w4afp8_gemm.h"
|
||||
#include "swigluoai.h"
|
||||
|
||||
template <paddle::DataType T>
|
||||
void MoeFFNKernel(const paddle::Tensor& permute_input,
|
||||
@@ -36,7 +37,8 @@ void MoeFFNKernel(const paddle::Tensor& permute_input,
|
||||
paddle::Tensor ffn_out,
|
||||
bool used_in_ep_low_latency,
|
||||
const int estimate_total_token_nums,
|
||||
const int hadamard_block_size) {
|
||||
const int hadamard_block_size,
|
||||
const std::string& activation) {
|
||||
using namespace phi;
|
||||
typedef PDTraits<T> traits_;
|
||||
typedef typename traits_::DataType DataType_;
|
||||
@@ -233,8 +235,13 @@ void MoeFFNKernel(const paddle::Tensor& permute_input,
|
||||
if (used_in_ep_low_latency) {
|
||||
act_out_tensor = GroupSwigluWithMasked(fc1_out_tensor, tokens_expert_prefix_sum);
|
||||
} else {
|
||||
act_out_tensor = paddle::experimental::swiglu(fc1_out_tensor, nullptr);
|
||||
if (activation == "swigluoai") {
|
||||
act_out_tensor = SwigluOAI(fc1_out_tensor, 1.702, 7.0, "interleave");
|
||||
} else {
|
||||
act_out_tensor = paddle::experimental::swiglu(fc1_out_tensor, nullptr);
|
||||
}
|
||||
}
|
||||
|
||||
auto act_out = act_out_tensor.data<data_t>();
|
||||
if (quant_method == "weight_only_int8") {
|
||||
typename cutlass::WintQuantTraits<DataType_, cutlass::WintQuantMethod::kWeightOnlyInt8>::Arguments quant_args;
|
||||
@@ -405,8 +412,10 @@ paddle::Tensor MoeExpertFFNFunc(
|
||||
const paddle::optional<paddle::Tensor>& down_proj_scale,
|
||||
const paddle::optional<paddle::Tensor>& down_proj_in_scale,
|
||||
const paddle::optional<paddle::Tensor>& expert_idx_per_token,
|
||||
const std::string& quant_method, const bool used_in_ep_low_latency,
|
||||
const int estimate_total_token_nums, const int hadamard_block_size) {
|
||||
const std::string& quant_method,
|
||||
const bool used_in_ep_low_latency,
|
||||
const int estimate_total_token_nums, const int hadamard_block_size,
|
||||
const std::string& activation) {
|
||||
|
||||
const auto t_type = (quant_method == "w4a8") ? up_gate_proj_scale.get().dtype() :
|
||||
(quant_method == "w4afp8") ? paddle::DataType::BFLOAT16 :
|
||||
@@ -430,7 +439,8 @@ const auto t_type = (quant_method == "w4a8") ? up_gate_proj_scale.get().dtype()
|
||||
ffn_out,
|
||||
used_in_ep_low_latency,
|
||||
estimate_total_token_nums,
|
||||
hadamard_block_size);
|
||||
hadamard_block_size,
|
||||
activation);
|
||||
break;
|
||||
case paddle::DataType::FLOAT16:
|
||||
MoeFFNKernel<paddle::DataType::FLOAT16>(permute_input,
|
||||
@@ -446,7 +456,8 @@ const auto t_type = (quant_method == "w4a8") ? up_gate_proj_scale.get().dtype()
|
||||
ffn_out,
|
||||
used_in_ep_low_latency,
|
||||
estimate_total_token_nums,
|
||||
hadamard_block_size);
|
||||
hadamard_block_size,
|
||||
activation);
|
||||
break;
|
||||
default:
|
||||
PD_THROW("Unsupported data type for MoeExpertFFN");
|
||||
@@ -466,7 +477,8 @@ std::vector<paddle::Tensor> MoeExpertFFN(
|
||||
const paddle::optional<paddle::Tensor>& expert_idx_per_token,
|
||||
const std::string& quant_method, const bool used_in_ep_low_latency,
|
||||
const int estimate_total_token_nums,
|
||||
const int hadamard_block_size) {
|
||||
const int hadamard_block_size,
|
||||
const std::string& activation) {
|
||||
return {MoeExpertFFNFunc(permute_input,
|
||||
tokens_expert_prefix_sum,
|
||||
up_gate_proj_weight,
|
||||
@@ -479,7 +491,8 @@ std::vector<paddle::Tensor> MoeExpertFFN(
|
||||
quant_method,
|
||||
used_in_ep_low_latency,
|
||||
estimate_total_token_nums,
|
||||
hadamard_block_size)};
|
||||
hadamard_block_size,
|
||||
activation)};
|
||||
}
|
||||
|
||||
std::vector<std::vector<int64_t>> MoeExpertFFNInferShape(
|
||||
@@ -495,7 +508,8 @@ std::vector<std::vector<int64_t>> MoeExpertFFNInferShape(
|
||||
const std::string& quant_method,
|
||||
const bool used_in_ep_low_latency,
|
||||
const int estimate_total_token_nums,
|
||||
const int hadamard_block_size) {
|
||||
const int hadamard_block_size,
|
||||
const std::string& activation) {
|
||||
return {permute_input_shape};
|
||||
}
|
||||
|
||||
@@ -509,7 +523,8 @@ std::vector<paddle::DataType> MoeExpertFFNInferDtype(
|
||||
const paddle::optional<paddle::DataType> &down_proj_scale_dtype,
|
||||
const paddle::optional<paddle::DataType> &down_proj_in_scale_dtype,
|
||||
const std::string &quant_method, const bool used_in_ep_low_latency,
|
||||
const int estimate_total_token_nums, const int hadamard_block_size) {
|
||||
const int estimate_total_token_nums, const int hadamard_block_size,
|
||||
const std::string &activation) {
|
||||
if (quant_method == "w4a8" || quant_method == "w4afp8") {
|
||||
return {up_gate_proj_scale_dtype.get()};
|
||||
} else {
|
||||
@@ -583,7 +598,7 @@ PD_BUILD_STATIC_OP(moe_expert_ffn)
|
||||
paddle::Optional("down_proj_in_scale"),
|
||||
paddle::Optional("expert_idx_per_token")})
|
||||
.Outputs({"output_tensor"})
|
||||
.Attrs({"quant_method:std::string", "used_in_ep_low_latency:bool", "estimate_total_token_nums:int", "hadamard_block_size:int"})
|
||||
.Attrs({"quant_method:std::string", "used_in_ep_low_latency:bool", "estimate_total_token_nums:int", "hadamard_block_size:int", "activation:std::string"})
|
||||
.SetKernelFn(PD_KERNEL(MoeExpertFFN))
|
||||
.SetInferShapeFn(PD_INFER_SHAPE(MoeExpertFFNInferShape))
|
||||
.SetInferDtypeFn(PD_INFER_DTYPE(MoeExpertFFNInferDtype));
|
||||
|
||||
189
custom_ops/gpu_ops/moe/swigluoai.cu
Normal file
189
custom_ops/gpu_ops/moe/swigluoai.cu
Normal file
@@ -0,0 +1,189 @@
|
||||
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
// NOTE(review): this is a .cu translation unit, so `#pragma once` has no
// effect (it only guards headers against multiple inclusion); the original
// file carried it twice — both occurrences removed.
#include "helper.h"
#include "swigluoai.h"
||||
// dim3 grid(256)
|
||||
// dim3 block(512)
|
||||
template <typename T, int VecSize>
|
||||
__global__ void swigluoai_interleave_kernel(T* act_out,
|
||||
const T* input,
|
||||
const float alpha,
|
||||
const float limit,
|
||||
const int64_t seq_len,
|
||||
const int64_t hidden_dim) {
|
||||
int64_t tid = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
|
||||
const int64_t num = seq_len * hidden_dim;
|
||||
using LoadT = AlignedVector<T, VecSize>;
|
||||
LoadT src_vec0, src_vec1;
|
||||
LoadT res_vec;
|
||||
|
||||
int64_t vec_num = hidden_dim / VecSize * seq_len;
|
||||
int64_t col_size = hidden_dim / VecSize;
|
||||
int64_t times = (vec_num - 1) / (gridDim.x * blockDim.x) + 1;
|
||||
|
||||
for(int i = 0; i < times; i++)
|
||||
{
|
||||
int64_t index = tid + i * gridDim.x * blockDim.x ;
|
||||
int64_t row = index / col_size;
|
||||
int64_t col = index % col_size;
|
||||
|
||||
if(row < seq_len && col < col_size)
|
||||
{
|
||||
Load<T, VecSize>(&input[row*hidden_dim*2 + col*VecSize*2], &src_vec0);
|
||||
Load<T, VecSize>(&input[row*hidden_dim*2 + col*VecSize*2 + VecSize], &src_vec1);
|
||||
|
||||
for (int j = 0; j < VecSize/2; ++j) {
|
||||
float a = static_cast<float>(src_vec0[2*j]);
|
||||
float b = static_cast<float>(src_vec0[2*j + 1]);
|
||||
a = fminf(a, limit);
|
||||
b = fminf(fmaxf(b,-limit), limit);
|
||||
float res = (b + 1) * a / (1.f + expf(-a * alpha));
|
||||
res_vec[j] = static_cast<T>(res);
|
||||
}
|
||||
for (int j = 0; j < VecSize/2; ++j) {
|
||||
float a = static_cast<float>(src_vec1[2*j]);
|
||||
float b = static_cast<float>(src_vec1[2*j + 1]);
|
||||
a = fminf(a, limit);
|
||||
b = fminf(fmaxf(b,-limit), limit);
|
||||
float res = (b + 1) * a / (1.f + expf(-a * alpha));
|
||||
res_vec[j + VecSize/2] = static_cast<T>(res);
|
||||
}
|
||||
|
||||
Store<T, VecSize>(res_vec, &act_out[row*hidden_dim + col*VecSize]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// dim3 grid(256)
|
||||
// dim3 block(512)
|
||||
template <typename T, int VecSize>
|
||||
__global__ void swigluoai_norm_kernel(T* act_out,
|
||||
const T* input,
|
||||
const float alpha,
|
||||
const float limit,
|
||||
const int64_t seq_len,
|
||||
const int64_t hidden_dim) {
|
||||
int64_t tid = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
|
||||
const int64_t num = seq_len * hidden_dim;
|
||||
using LoadT = AlignedVector<T, VecSize>;
|
||||
LoadT src_vec0, src_vec1;
|
||||
LoadT res_vec;
|
||||
|
||||
int64_t vec_num = hidden_dim / VecSize * seq_len;
|
||||
int64_t col_size = hidden_dim / VecSize;
|
||||
int64_t times = (vec_num - 1) / (gridDim.x * blockDim.x) + 1;
|
||||
|
||||
for(int i = 0; i < times; i++)
|
||||
{
|
||||
int64_t index = tid + i * gridDim.x * blockDim.x ;
|
||||
int64_t row = index / col_size;
|
||||
int64_t col = index % col_size;
|
||||
|
||||
if(row < seq_len && col < col_size)
|
||||
{
|
||||
Load<T, VecSize>(&input[row*hidden_dim*2 + col*VecSize], &src_vec0);
|
||||
Load<T, VecSize>(&input[row*hidden_dim*2 + hidden_dim + col*VecSize], &src_vec1);
|
||||
|
||||
for (int j = 0; j < VecSize; ++j) {
|
||||
float a = static_cast<float>(src_vec0[j]);
|
||||
float b = static_cast<float>(src_vec1[j]);
|
||||
float z = fminf(fmaxf(a * alpha, -limit), limit);
|
||||
float res = b * a / (1.f + expf(-z));
|
||||
res_vec[j] = static_cast<T>(res);
|
||||
}
|
||||
|
||||
Store<T, VecSize>(res_vec, &act_out[row*hidden_dim + col*VecSize]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Applies the SwigluOAI activation to a fused FFN fc1 output.
//
// fc1_out_tensor: [seq_len, 2 * hidden_dim], bfloat16 only (checked).
// alpha, limit:   gate sharpness and clamping bound (see kernel docs).
// type:           "interleave" selects the interleaved (gate, up) pair
//                 layout; any other value selects the split-halves layout
//                 (preserved fallback behavior — no error on unknown type).
// Returns a newly allocated [seq_len, hidden_dim] tensor on the input's
// place, filled asynchronously on the input tensor's stream.
//
// Fixes vs. original: the two dispatch macros were identical except for the
// kernel name and were never #undef-ed (leaking past this function) — they
// are merged into one macro that is #undef-ed after use; dead commented-out
// code removed; added a check that the fused dim splits evenly in two.
paddle::Tensor SwigluOAI(const paddle::Tensor &fc1_out_tensor,
                         const float alpha,
                         const float limit,
                         const std::string& type) {
  const int64_t seq_len = fc1_out_tensor.shape()[0];
  const int64_t hidden_dim = fc1_out_tensor.shape()[1] / 2;
  auto act_out_tensor = GetEmptyTensor({seq_len, hidden_dim},
                                       fc1_out_tensor.dtype(),
                                       fc1_out_tensor.place());

  constexpr int VecSize = 8;
  PD_CHECK(fc1_out_tensor.dtype() == paddle::DataType::BFLOAT16);
  // The fused dim must split evenly into gate/up halves.
  PD_CHECK(fc1_out_tensor.shape()[1] % 2 == 0);
  PD_CHECK(hidden_dim % VecSize == 0);

  constexpr paddle::DataType D = paddle::DataType::BFLOAT16;
  typedef PDTraits<D> traits_;
  typedef typename traits_::DataType DataType_;
  typedef typename traits_::data_t data_t;

  const int block_size = 512;
  const int grid_size = 256;

// Launches KERNEL with the shared argument list on the input's stream.
#define DISPATCH_SWIGLUOAI(KERNEL)                                        \
  KERNEL<DataType_, VecSize>                                              \
      <<<grid_size, block_size, 0, fc1_out_tensor.stream()>>>(            \
          reinterpret_cast<DataType_*>(                                   \
              const_cast<data_t*>(act_out_tensor.data<data_t>())),        \
          reinterpret_cast<const DataType_*>(                             \
              fc1_out_tensor.data<data_t>()),                             \
          alpha, limit, seq_len, hidden_dim)

  if (type == "interleave") {
    DISPATCH_SWIGLUOAI(swigluoai_interleave_kernel);
  } else {
    // Any non-"interleave" value falls through to the split-halves kernel,
    // matching the original behavior.
    DISPATCH_SWIGLUOAI(swigluoai_norm_kernel);
  }
#undef DISPATCH_SWIGLUOAI

  return act_out_tensor;
}
|
||||
|
||||
|
||||
// Adapter for op registration: wraps the single-tensor SwigluOAI result in
// the std::vector form PD_KERNEL expects. Arguments are forwarded verbatim.
std::vector<paddle::Tensor> SwigluOAIWrapper(
    const paddle::Tensor& fc1_out_tensor,
    const float alpha,
    const float limit,
    const std::string& type) {
  paddle::Tensor activated = SwigluOAI(fc1_out_tensor, alpha, limit, type);
  return {activated};
}
|
||||
|
||||
// Registers the `swigluoai` static custom op:
//   input:  fc1_out_tensor [seq_len, 2 * hidden_dim]
//   attrs:  alpha (float), limit (float), type (std::string — "interleave"
//           or anything else for the split-halves layout)
//   output: output_tensor [seq_len, hidden_dim]
// NOTE(review): these attr strings use "name: type" with a space after the
// colon, while the moe_expert_ffn registration in this file uses "name:type"
// without one — confirm the op-meta parser accepts both forms.
PD_BUILD_STATIC_OP(swigluoai)
    .Inputs({"fc1_out_tensor"})
    .Attrs({"alpha: float", "limit: float", "type: std::string"})
    .Outputs({"output_tensor"})
    .SetKernelFn(PD_KERNEL(SwigluOAIWrapper));
|
||||
19
custom_ops/gpu_ops/moe/swigluoai.h
Normal file
19
custom_ops/gpu_ops/moe/swigluoai.h
Normal file
@@ -0,0 +1,19 @@
|
||||
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
#include "helper.h"

// Applies the SwigluOAI (SwiGLU-variant) activation to a fused FFN fc1
// output. `fc1_out_tensor` is [seq_len, 2 * hidden_dim] and must be
// bfloat16; returns a new [seq_len, hidden_dim] tensor. `type ==
// "interleave"` selects the interleaved (gate, up) pair layout; any other
// value selects the split-halves layout. `alpha` and `limit` control the
// clamped sigmoid gating (see kernel definitions in swigluoai.cu).
paddle::Tensor
SwigluOAI(const paddle::Tensor &fc1_out_tensor, const float alpha, const float limit, const std::string& type);
|
||||
@@ -170,7 +170,10 @@ std::vector<paddle::Tensor> tritonmoe_preprocess_kernel(const paddle::Tensor& to
|
||||
run_align_kernel(128);
|
||||
} else if (num_experts == 160) {
|
||||
run_align_kernel(160);
|
||||
} else {
|
||||
} else if (num_experts == 32) {
|
||||
run_align_kernel(32);
|
||||
}
|
||||
else {
|
||||
PD_THROW("Not support num_experts: %d", num_experts);
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user