// Files
// FastDeploy/custom_ops/iluvatar_ops/prefill_fused_attn.cu
// 2025-10-22 17:59:50 +08:00
//
// 401 lines
// 16 KiB
// Plaintext

// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "helper.h"
#include "iluvatar_context.h"
// Host-side launcher for the fused prefill attention provided by the cuinfer
// library (cuinferFmhaFwdMergedFuseRopeFunc).
//
// Steps performed here:
//   1. validate dtype / contiguity / rank / shape of the tensor arguments,
//   2. query and allocate the workspace required by the library,
//   3. build one tensor descriptor per operand,
//   4. invoke the fused kernel,
//   5. destroy the descriptors.
//
// Expected layouts (enforced by the checks below unless stated otherwise):
//   qkv:            [num_tokens, (num_heads + 2 * num_kv_heads) * head_dim]
//   k_cache:        [num_blocks, num_kv_heads, block_size, head_dim]
//   v_cache:        described to the library with the same dims as k_cache
//   block_table:    [batch_size, max_num_blocks_per_seq], int32
//   cu_seqlens_qkv: [batch_size + 1], int32
//   out:            [num_tokens, hidden_size]
//   rope_sin/cos:   optional, read as float and described as
//                   [max_seq_len, head_dim]
//
// Raises InvalidArgument when any of the checks fails, including an
// unsupported qkv dtype (only float16 and bfloat16 are accepted).
template <paddle::DataType T>
void PrefillFusedPagedAttnKernel(
    const paddle::Tensor& qkv,
    paddle::Tensor& k_cache,
    paddle::Tensor& v_cache,
    const paddle::Tensor& block_table,
    const paddle::Tensor& cu_seqlens_qkv,
    const paddle::optional<paddle::Tensor>& rope_sin,
    const paddle::optional<paddle::Tensor>& rope_cos,
    int num_heads,
    int head_dim,
    int num_kv_heads,
    int block_size,
    int max_seq_len,
    float scale,
    bool causal,
    bool q_rope,
    bool k_rope,
    bool v_rope,
    paddle::Tensor& out) {
  // check dtype and contiguous
  const auto& dtype = qkv.dtype();
  cuinferDataType_t data_type;
  if (dtype == paddle::DataType::FLOAT16) {
    data_type = CUINFER_DATA_HALF;
  } else if (dtype == paddle::DataType::BFLOAT16) {
    data_type = CUINFER_DATA_BFLOAT16;
  } else {
    // The error object has to be thrown; merely constructing it would let
    // execution continue with an uninitialized data_type.
    PADDLE_THROW(common::errors::InvalidArgument(
        "paged_attention support half and bfloat16 now"));
  }

  PADDLE_ENFORCE_EQ(k_cache.dtype(),
                    dtype,
                    common::errors::InvalidArgument(
                        "k_cache dtype must be the same as query dtype"));
  PADDLE_ENFORCE_EQ(k_cache.is_contiguous(),
                    true,
                    common::errors::InvalidArgument(
                        "paged_attention expects k_cache is contiguous"));
  // v_cache is handed to the library with the same data type and dense
  // strides as k_cache, so it must satisfy the same requirements.
  PADDLE_ENFORCE_EQ(v_cache.dtype(),
                    dtype,
                    common::errors::InvalidArgument(
                        "v_cache dtype must be the same as query dtype"));
  PADDLE_ENFORCE_EQ(v_cache.is_contiguous(),
                    true,
                    common::errors::InvalidArgument(
                        "paged_attention expects v_cache is contiguous"));
  PADDLE_ENFORCE_EQ(
      block_table.dtype(),
      paddle::DataType::INT32,
      common::errors::InvalidArgument("block_table dtype must be int32"));
  PADDLE_ENFORCE_EQ(block_table.is_contiguous(),
                    true,
                    common::errors::InvalidArgument(
                        "paged_attention expects block_table is contiguous"));
  PADDLE_ENFORCE_EQ(
      cu_seqlens_qkv.dtype(),
      paddle::DataType::INT32,
      common::errors::InvalidArgument("cu_seqlens_qkv dtype must be int32"));
  PADDLE_ENFORCE_EQ(
      cu_seqlens_qkv.is_contiguous(),
      true,
      common::errors::InvalidArgument(
          "paged_attention expects cu_seqlens_qkv is contiguous"));

  // check dim and shape
  // k_cache: [num_blocks, kv_num_heads, block_size, head_dim]
  // v_cache: [num_blocks, kv_num_heads, block_size, head_dim]
  // block_table: [batch_size, max_num_blocks_per_seq]
  // cu_seqlens_qkv: [batch_size + 1]
  // qkv: [num_tokens, (num_heads+2*num_kv_heads)*head_dim]
  // out: [num_tokens, hidden_size]
  const auto& qkv_dims = qkv.dims();
  PADDLE_ENFORCE_EQ(qkv_dims.size(),
                    2,
                    common::errors::InvalidArgument(
                        "paged_attn receive query dims is "
                        "[num_tokens, (num_heads+2*num_kv_heads)*head_dim]"));
  PADDLE_ENFORCE_EQ(
      out.dims().size(),
      2,
      common::errors::InvalidArgument("paged_attn receive out dims is "
                                      "[num_tokens, hidden_size]"));

  const auto& kv_cache_dims = k_cache.dims();
  PADDLE_ENFORCE_EQ(kv_cache_dims.size(),
                    4,
                    common::errors::InvalidArgument(
                        "paged_attn receive kv cache dims is "
                        "[num_blocks, kv_num_heads, block_size, head_dim]"));

  const auto& block_table_dims = block_table.dims();
  PADDLE_ENFORCE_EQ(
      block_table_dims.size(),
      2,
      common::errors::InvalidArgument("paged_attn receive block_table dims is "
                                      "[batch_size, max_num_blocks_per_seq]"));

  const auto& cu_seqlens_qkv_dims = cu_seqlens_qkv.dims();
  PADDLE_ENFORCE_EQ(
      cu_seqlens_qkv_dims.size(),
      1,
      common::errors::InvalidArgument(
          "paged_attn receive cu_seqlens_qkv dims is [batch_size]"));

  int batch_size = block_table_dims[0];
  int num_tokens = qkv_dims[0];
  int num_total_heads = num_heads + 2 * num_kv_heads;
  int num_blocks = kv_cache_dims[0];

  PADDLE_ENFORCE_EQ(kv_cache_dims[1],
                    num_kv_heads,
                    common::errors::InvalidArgument(
                        "kv_cache_dims[1] must be equal to num_kv_head"));
  PADDLE_ENFORCE_EQ(kv_cache_dims[2],
                    block_size,
                    common::errors::InvalidArgument(
                        "kv_cache_dims[2] must be equal to block_size"));
  PADDLE_ENFORCE_EQ(kv_cache_dims[3],
                    head_dim,
                    common::errors::InvalidArgument(
                        "kv_cache_dims[3] must be equal to head_dim"));
  PADDLE_ENFORCE_EQ(
      cu_seqlens_qkv_dims[0],
      batch_size + 1,
      common::errors::InvalidArgument(
          "cu_seqlens_qkv_dims[0] must be equal to batch_size + 1"));

  // Row stride of the block table; it is also used as the second extent of
  // the block_table descriptor below.
  int block_table_stride = block_table.strides()[0];

  // The rotary tables are optional; a null pointer is forwarded when absent.
  const float* rope_sin_ptr =
      rope_sin ? rope_sin.get().data<float>() : nullptr;
  const float* rope_cos_ptr =
      rope_cos ? rope_cos.get().data<float>() : nullptr;

  cuinferHandle_t cuinfer_handle =
      iluvatar::getContextInstance()->getIxInferHandle();

  // Ask the library how much scratch memory it needs, then allocate it on
  // the same place as qkv. The allocation is released when tmp_workspace
  // goes out of scope.
  size_t workspace_size = 0;
  CUINFER_CHECK(cuinferGetFmhaFwdMergedFuseRopeWorkspaceSize(num_tokens,
                                                             num_heads,
                                                             num_kv_heads,
                                                             head_dim,
                                                             q_rope,
                                                             k_rope,
                                                             v_rope,
                                                             data_type,
                                                             data_type,
                                                             data_type,
                                                             &workspace_size));
  auto* allocator = paddle::GetAllocator(qkv.place());
  phi::Allocator::AllocationPtr tmp_workspace =
      allocator->Allocate(workspace_size);
  void* workspace_ptr = tmp_workspace->ptr();

  // NOTE(review): if CUINFER_CHECK throws on failure, descriptors created
  // before the failing call are not destroyed — confirm whether that matters.

  // qkv viewed as [num_tokens, num_total_heads, head_dim], densely packed.
  cuinferTensorDescriptor_t qkv_desc;
  CUINFER_CHECK(cuinferCreateTensorDescriptor(&qkv_desc));
  CUINFER_CHECK(cuinferSetTensorNdDescriptor(
      qkv_desc,
      data_type,
      3,
      std::vector<int>({num_tokens, num_total_heads, head_dim}).data(),
      std::vector<int>({num_total_heads * head_dim, head_dim, 1}).data()));

  cuinferTensorDescriptor_t qkv_seqlens_desc;
  CUINFER_CHECK(cuinferCreateTensorDescriptor(&qkv_seqlens_desc));
  CUINFER_CHECK(
      cuinferSetTensorNdDescriptor(qkv_seqlens_desc,
                                   CUINFER_DATA_INT32,
                                   1,
                                   std::vector<int>({batch_size + 1}).data(),
                                   std::vector<int>({1}).data()));

  cuinferTensorDescriptor_t block_table_desc;
  CUINFER_CHECK(cuinferCreateTensorDescriptor(&block_table_desc));
  CUINFER_CHECK(cuinferSetTensorNdDescriptor(
      block_table_desc,
      CUINFER_DATA_INT32,
      2,
      std::vector<int>({batch_size, block_table_stride}).data(),
      std::vector<int>({block_table_stride, 1}).data()));

  // out viewed as [num_tokens, num_heads, head_dim], densely packed.
  cuinferTensorDescriptor_t o_desc;
  CUINFER_CHECK(cuinferCreateTensorDescriptor(&o_desc));
  CUINFER_CHECK(cuinferSetTensorNdDescriptor(
      o_desc,
      data_type,
      3,
      std::vector<int>({num_tokens, num_heads, head_dim}).data(),
      std::vector<int>({num_heads * head_dim, head_dim, 1}).data()));

  cuinferTensorDescriptor_t k_cache_desc;
  CUINFER_CHECK(cuinferCreateTensorDescriptor(&k_cache_desc));
  CUINFER_CHECK(cuinferSetTensorNdDescriptor(
      k_cache_desc,
      data_type,
      4,
      std::vector<int>({num_blocks, num_kv_heads, block_size, head_dim}).data(),
      std::vector<int>({num_kv_heads * block_size * head_dim,
                        block_size * head_dim,
                        head_dim,
                        1})
          .data()));

  cuinferTensorDescriptor_t v_cache_desc;
  CUINFER_CHECK(cuinferCreateTensorDescriptor(&v_cache_desc));
  CUINFER_CHECK(cuinferSetTensorNdDescriptor(
      v_cache_desc,
      data_type,
      4,
      std::vector<int>({num_blocks, num_kv_heads, block_size, head_dim}).data(),
      std::vector<int>({num_kv_heads * block_size * head_dim,
                        block_size * head_dim,
                        head_dim,
                        1})
          .data()));

  cuinferTensorDescriptor_t cos_desc;
  CUINFER_CHECK(cuinferCreateTensorDescriptor(&cos_desc));
  CUINFER_CHECK(cuinferSetTensorNdDescriptor(
      cos_desc,
      CUINFER_DATA_FLOAT,
      2,
      std::vector<int>({max_seq_len, head_dim}).data(),
      std::vector<int>({head_dim, 1}).data()));

  cuinferTensorDescriptor_t sin_desc;
  CUINFER_CHECK(cuinferCreateTensorDescriptor(&sin_desc));
  CUINFER_CHECK(cuinferSetTensorNdDescriptor(
      sin_desc,
      CUINFER_DATA_FLOAT,
      2,
      std::vector<int>({max_seq_len, head_dim}).data(),
      std::vector<int>({head_dim, 1}).data()));

  CUINFER_CHECK(cuinferFmhaFwdMergedFuseRopeFunc(cuinfer_handle,
                                                 qkv_desc,
                                                 qkv.data(),
                                                 qkv_seqlens_desc,
                                                 cu_seqlens_qkv.data<int32_t>(),
                                                 block_table_desc,
                                                 block_table.data<int32_t>(),
                                                 o_desc,
                                                 out.data(),
                                                 k_cache_desc,
                                                 k_cache.data(),
                                                 v_cache_desc,
                                                 v_cache.data(),
                                                 workspace_ptr,
                                                 workspace_size,
                                                 cos_desc,
                                                 rope_cos_ptr,
                                                 sin_desc,
                                                 rope_sin_ptr,
                                                 batch_size,
                                                 num_heads,
                                                 num_kv_heads,
                                                 head_dim,
                                                 causal,
                                                 scale,
                                                 q_rope,
                                                 k_rope,
                                                 v_rope));

  CUINFER_CHECK(cuinferDestroyTensorDescriptor(qkv_desc));
  CUINFER_CHECK(cuinferDestroyTensorDescriptor(qkv_seqlens_desc));
  CUINFER_CHECK(cuinferDestroyTensorDescriptor(block_table_desc));
  CUINFER_CHECK(cuinferDestroyTensorDescriptor(o_desc));
  CUINFER_CHECK(cuinferDestroyTensorDescriptor(k_cache_desc));
  CUINFER_CHECK(cuinferDestroyTensorDescriptor(v_cache_desc));
  CUINFER_CHECK(cuinferDestroyTensorDescriptor(cos_desc));
  CUINFER_CHECK(cuinferDestroyTensorDescriptor(sin_desc));
}
// Entry point of the prefill_fused_paged_attn op.
//
// Allocates the output tensor of shape [qkv.shape()[0], num_heads * head_dim]
// with the dtype and place of qkv, then forwards every argument to the
// PrefillFusedPagedAttnKernel instantiation selected by the qkv dtype
// (bfloat16 or float16). Any other dtype is rejected with PD_THROW.
// Returns a one-element vector holding the output tensor.
std::vector<paddle::Tensor> PrefillFusedPagedAttn(
    const paddle::Tensor& qkv,
    paddle::Tensor& k_cache,
    paddle::Tensor& v_cache,
    const paddle::Tensor& block_table,
    const paddle::Tensor& cu_seqlens_qkv,
    const paddle::optional<paddle::Tensor>& rope_sin,
    const paddle::optional<paddle::Tensor>& rope_cos,
    int num_heads,
    int head_dim,
    int num_kv_heads,
    int block_size,
    int max_seq_len,
    float scale,
    bool causal,
    bool q_rope,
    bool k_rope,
    bool v_rope) {
  const auto qkv_dtype = qkv.dtype();

  // The output is created before dispatching, mirroring the input's
  // dtype and placement.
  const int64_t token_count = qkv.shape()[0];
  const int64_t hidden_size = num_heads * head_dim;
  auto attn_out =
      paddle::empty({token_count, hidden_size}, qkv_dtype, qkv.place());

  if (qkv_dtype == paddle::DataType::BFLOAT16) {
    PrefillFusedPagedAttnKernel<paddle::DataType::BFLOAT16>(
        qkv, k_cache, v_cache, block_table, cu_seqlens_qkv,
        rope_sin, rope_cos,
        num_heads, head_dim, num_kv_heads, block_size, max_seq_len,
        scale, causal, q_rope, k_rope, v_rope,
        attn_out);
  } else if (qkv_dtype == paddle::DataType::FLOAT16) {
    PrefillFusedPagedAttnKernel<paddle::DataType::FLOAT16>(
        qkv, k_cache, v_cache, block_table, cu_seqlens_qkv,
        rope_sin, rope_cos,
        num_heads, head_dim, num_kv_heads, block_size, max_seq_len,
        scale, causal, q_rope, k_rope, v_rope,
        attn_out);
  } else {
    PD_THROW("Unsupported data type for Paged attn");
  }

  return {attn_out};
}
// Shape inference for prefill_fused_paged_attn.
//
// The op has a single output whose shape is
// [qkv_shape[0], num_heads * head_dim]; all other arguments are accepted
// only to match the op's input/attribute list and are not read.
std::vector<std::vector<int64_t>> PrefillFusedPagedAttnInferShape(
    const std::vector<int64_t>& qkv_shape,
    const std::vector<int64_t>& k_cache_shape,
    const std::vector<int64_t>& v_cache_shape,
    const std::vector<int64_t>& block_table_shape,
    const std::vector<int64_t>& cu_seqlens_qkv_shape,
    const std::vector<int64_t>& rope_sin_shape,
    const std::vector<int64_t>& rope_cos_shape,
    int num_heads,
    int head_dim,
    int num_kv_heads,
    int block_size,
    int max_seq_len,
    float scale,
    bool causal,
    bool q_rope,
    bool k_rope,
    bool v_rope) {
  // Leading extent is taken from qkv; the trailing extent is the product
  // of the two int attributes, widened afterwards.
  const int64_t leading_extent = qkv_shape[0];
  const int64_t trailing_extent = num_heads * head_dim;

  std::vector<int64_t> out_shape;
  out_shape.reserve(2);
  out_shape.push_back(leading_extent);
  out_shape.push_back(trailing_extent);

  std::vector<std::vector<int64_t>> inferred;
  inferred.push_back(out_shape);
  return inferred;
}
// Dtype inference for prefill_fused_paged_attn: the single output takes the
// dtype of the qkv input.
std::vector<paddle::DataType> PrefillFusedPagedAttnInferDtype(
    const paddle::DataType& qkv_dtype) {
  std::vector<paddle::DataType> out_dtypes(1, qkv_dtype);
  return out_dtypes;
}
// Registers the prefill_fused_paged_attn custom op.
// Inputs and attributes are listed in the same order as the parameters of
// PrefillFusedPagedAttn; keep both in sync when changing either.
PD_BUILD_STATIC_OP(prefill_fused_paged_attn)
// Tensor inputs; rope_sin and rope_cos are optional.
.Inputs({"qkv",
"k_cache",
"v_cache",
"block_table",
"cu_seqlens_qkv",
paddle::Optional("rope_sin"),
paddle::Optional("rope_cos")})
// Single output tensor.
.Outputs({"out"})
// Scalar attributes, declared as "name:type".
.Attrs({"num_heads:int",
"head_dim:int",
"num_kv_heads:int",
"block_size:int",
"max_seq_len:int",
"scale:float",
"causal:bool",
"q_rope:bool",
"k_rope:bool",
"v_rope:bool"})
// Kernel entry point plus the shape and dtype inference functions.
.SetKernelFn(PD_KERNEL(PrefillFusedPagedAttn))
.SetInferShapeFn(PD_INFER_SHAPE(PrefillFusedPagedAttnInferShape))
.SetInferDtypeFn(PD_INFER_DTYPE(PrefillFusedPagedAttnInferDtype));