mirror of https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-12-24 13:28:13 +08:00
338 lines · 16 KiB · Plaintext
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <iostream>
#include <memory>
#include <stdexcept>

#include "helper.h"
#include "iluvatar_context.h"
#define CUINFER_CHECK(func) \
|
|
do { \
|
|
cuinferStatus_t status = (func); \
|
|
if (status != CUINFER_STATUS_SUCCESS) { \
|
|
std::cerr << "Error in file " << __FILE__ << " on line " << __LINE__ << ": " \
|
|
<< cuinferGetErrorString(status) << std::endl; \
|
|
throw std::runtime_error("CUINFER_CHECK ERROR"); \
|
|
} \
|
|
} while (0)
|
|
|
|
template <paddle::DataType T>
|
|
void PagedAttnKernel(const paddle::Tensor& q,
|
|
const paddle::Tensor& k_cache,
|
|
const paddle::Tensor& v_cache,
|
|
const paddle::Tensor& block_table,
|
|
const paddle::Tensor& seq_lens,
|
|
const paddle::optional<paddle::Tensor> &alibi_slopes,
|
|
const paddle::optional<paddle::Tensor> &k,
|
|
const paddle::optional<paddle::Tensor> &v,
|
|
int num_kv_heads,
|
|
float scale,
|
|
int block_size,
|
|
int max_context_len,
|
|
bool causal,
|
|
int window_left,
|
|
int window_right,
|
|
float softcap,
|
|
bool enable_cuda_graph,
|
|
bool use_sqrt_alibi,
|
|
paddle::Tensor& out) {
|
|
if (alibi_slopes) {
|
|
PADDLE_ENFORCE_EQ(alibi_slopes.get().dtype(),
|
|
paddle::DataType::FLOAT32,
|
|
common::errors::InvalidArgument(
|
|
"paged_attention expects alibi_slopes float tensor"));
|
|
PADDLE_ENFORCE_EQ(alibi_slopes.get().is_contiguous(),
|
|
true,
|
|
common::errors::InvalidArgument(
|
|
"paged_attention expects alibi_slopes is contiguous"));
|
|
}
|
|
|
|
// check dtype and contiguous
|
|
const auto& dtype = q.dtype();
|
|
cudaDataType_t data_type;
|
|
if (dtype == paddle::DataType::FLOAT16) {
|
|
data_type = CUDA_R_16F;
|
|
} else if (dtype == paddle::DataType::BFLOAT16) {
|
|
data_type = CUDA_R_16BF;
|
|
} else {
|
|
common::errors::InvalidArgument("paged_attention support half and bfloat16 now");
|
|
}
|
|
|
|
PADDLE_ENFORCE_EQ(k_cache.dtype(),
|
|
dtype,
|
|
common::errors::InvalidArgument(
|
|
"k_cache dtype must be the same as query dtype"));
|
|
PADDLE_ENFORCE_EQ(k_cache.is_contiguous(),
|
|
true,
|
|
common::errors::InvalidArgument(
|
|
"paged_attention expects k_cache is contiguous"));
|
|
PADDLE_ENFORCE_EQ(v_cache.dtype(),
|
|
dtype,
|
|
common::errors::InvalidArgument(
|
|
"v_cache dtype must be the same as query dtype"));
|
|
PADDLE_ENFORCE_EQ(v_cache.is_contiguous(),
|
|
true,
|
|
common::errors::InvalidArgument(
|
|
"paged_attention expects v_cache is contiguous"));
|
|
PADDLE_ENFORCE_EQ(block_table.dtype(),
|
|
paddle::DataType::INT32,
|
|
common::errors::InvalidArgument(
|
|
"block_table dtype must be int32"));
|
|
PADDLE_ENFORCE_EQ(block_table.is_contiguous(),
|
|
true,
|
|
common::errors::InvalidArgument(
|
|
"paged_attention expects block_table is contiguous"));
|
|
PADDLE_ENFORCE_EQ(seq_lens.dtype(),
|
|
paddle::DataType::INT32,
|
|
common::errors::InvalidArgument(
|
|
"seq_lens dtype must be int32"));
|
|
PADDLE_ENFORCE_EQ(seq_lens.is_contiguous(),
|
|
true,
|
|
common::errors::InvalidArgument(
|
|
"paged_attention expects seq_lens is contiguous"));
|
|
|
|
// check dim and shape
|
|
// out: [num_seqs, num_heads, head_size]
|
|
// q: [num_seqs, num_heads, head_size]
|
|
// k_chache: [num_blocks, kv_num_heads, block_size, head_size]
|
|
// v_chache: [num_blocks, kv_num_heads, block_size, head_size]
|
|
// block_table: [num_seqs, max_num_blocks_per_seq]
|
|
// seq_lens: [num_seqs]
|
|
|
|
const auto& q_dims = q.dims();
|
|
PADDLE_ENFORCE_EQ(q_dims.size(),
|
|
3,
|
|
common::errors::InvalidArgument(
|
|
"paged_attn receive query dims is "
|
|
"[num_seqs, num_heads, head_size]"));
|
|
PADDLE_ENFORCE_EQ(out.dims().size(),
|
|
3,
|
|
common::errors::InvalidArgument(
|
|
"paged_attn receive out dims is "
|
|
"[num_seqs, num_heads, head_size]"));
|
|
PADDLE_ENFORCE_EQ(k_cache.dims(),
|
|
v_cache.dims(),
|
|
common::errors::InvalidArgument(
|
|
"paged_attn requires k_cache size is the "
|
|
"same as v_cache"));
|
|
|
|
const auto& kv_cache_dims = k_cache.dims();
|
|
PADDLE_ENFORCE_EQ(kv_cache_dims.size(),
|
|
4,
|
|
common::errors::InvalidArgument(
|
|
"paged_attn receive kv cache dims is "
|
|
"[num_blocks, kv_num_heads, block_size, head_size]"));
|
|
|
|
const auto& block_table_dims = block_table.dims();
|
|
PADDLE_ENFORCE_EQ(block_table_dims.size(),
|
|
2,
|
|
common::errors::InvalidArgument(
|
|
"paged_attn receive block_table dims is "
|
|
"[num_seqs, max_num_blocks_per_seq]"));
|
|
|
|
const auto& seq_lens_dims = seq_lens.dims();
|
|
PADDLE_ENFORCE_EQ(seq_lens_dims.size(),
|
|
1,
|
|
common::errors::InvalidArgument(
|
|
"paged_attn receive seq_lens dims is [num_seqs]"));
|
|
|
|
int num_seqs = q_dims[0];
|
|
int num_heads = q_dims[1];
|
|
int head_size = q_dims[2];
|
|
int max_num_blocks_per_seq = block_table_dims[1];
|
|
int q_stride = q.strides()[0];
|
|
int num_blocks = kv_cache_dims[0];
|
|
|
|
PADDLE_ENFORCE_EQ(kv_cache_dims[1],
|
|
num_kv_heads,
|
|
common::errors::InvalidArgument(
|
|
"kv_cache_dims[1] must be equal to num_kv_head"));
|
|
PADDLE_ENFORCE_EQ(kv_cache_dims[2],
|
|
block_size,
|
|
common::errors::InvalidArgument(
|
|
"kv_cache_dims[2] must be equal to block_size"));
|
|
PADDLE_ENFORCE_EQ(kv_cache_dims[3],
|
|
head_size,
|
|
common::errors::InvalidArgument(
|
|
"kv_cache_dims[3] must be equal to head_size"));
|
|
PADDLE_ENFORCE_EQ(block_table_dims[0],
|
|
num_seqs,
|
|
common::errors::InvalidArgument(
|
|
"block_table_dims[0] must be equal to num_seqs"));
|
|
PADDLE_ENFORCE_EQ(seq_lens_dims[0],
|
|
num_seqs,
|
|
common::errors::InvalidArgument(
|
|
"seq_lens_dims[0] must be equal to num_seqs"));
|
|
|
|
int kv_block_stride = k_cache.strides()[0];
|
|
int kv_head_stride = k_cache.strides()[1];
|
|
const float *alibi_slopes_ptr = alibi_slopes ? alibi_slopes.get().data<float>() : nullptr;
|
|
const void *key_ptr = k ? k.get().data() : nullptr;
|
|
const void *value_ptr = v ? v.get().data() : nullptr;
|
|
|
|
size_t workspace_size = 0;
|
|
void* workspace_ptr = nullptr;
|
|
CUINFER_CHECK(cuInferPageAttentionGetWorkspaceV7(
|
|
num_seqs, num_heads, num_kv_heads, head_size, block_size, max_context_len, &workspace_size));
|
|
|
|
CUDA_CHECK(cudaMalloc((void**)&workspace_ptr, workspace_size));
|
|
CUDA_CHECK(cudaMemset(workspace_ptr, 0xff, workspace_size));
|
|
|
|
auto dev_ctx = static_cast<const phi::CustomContext*>(paddle::experimental::DeviceContextPool::Instance().Get(q.place()));
|
|
auto stream = static_cast<const cudaStream_t>(dev_ctx->stream());
|
|
cuinferHandle_t cuinfer_handle = iluvatar::getContextInstance()->getIxInferHandle();
|
|
|
|
PageAttentionWithKVCacheArguments args{
|
|
static_cast<float>(scale), 1.0, 1.0, static_cast<float>(softcap), window_left, window_right,
|
|
causal, use_sqrt_alibi, enable_cuda_graph, false, alibi_slopes_ptr, key_ptr, value_ptr, workspace_ptr};
|
|
CUINFER_CHECK(cuInferPageAttentionV7(cuinfer_handle,
|
|
out.data(),
|
|
data_type,
|
|
q.data(),
|
|
data_type,
|
|
num_seqs,
|
|
num_heads,
|
|
num_kv_heads,
|
|
head_size,
|
|
q_stride,
|
|
kv_block_stride,
|
|
kv_head_stride,
|
|
k_cache.data(),
|
|
data_type,
|
|
v_cache.data(),
|
|
data_type,
|
|
block_size,
|
|
max_num_blocks_per_seq,
|
|
max_context_len,
|
|
block_table.data<int32_t>(),
|
|
seq_lens.data<int32_t>(),
|
|
args));
|
|
|
|
CUDA_CHECK(cudaFree(workspace_ptr));
|
|
}
|
|
|
|
std::vector<paddle::Tensor> PagedAttn(const paddle::Tensor& q,
|
|
const paddle::Tensor& k_cache,
|
|
const paddle::Tensor& v_cache,
|
|
const paddle::Tensor& block_table,
|
|
const paddle::Tensor& seq_lens,
|
|
const paddle::optional<paddle::Tensor> &alibi_slopes,
|
|
const paddle::optional<paddle::Tensor> &k,
|
|
const paddle::optional<paddle::Tensor> &v,
|
|
int num_kv_heads,
|
|
float scale,
|
|
int block_size,
|
|
int max_context_len,
|
|
bool causal,
|
|
int window_left,
|
|
int window_right,
|
|
float softcap,
|
|
bool enable_cuda_graph,
|
|
bool use_sqrt_alibi) {
|
|
|
|
const auto dtype = q.dtype();
|
|
auto out = paddle::empty_like(q, dtype);
|
|
|
|
switch (dtype) {
|
|
case paddle::DataType::BFLOAT16:
|
|
PagedAttnKernel<paddle::DataType::BFLOAT16>(q,
|
|
k_cache,
|
|
v_cache,
|
|
block_table,
|
|
seq_lens,
|
|
alibi_slopes,
|
|
k,
|
|
v,
|
|
num_kv_heads,
|
|
scale,
|
|
block_size,
|
|
max_context_len,
|
|
causal,
|
|
window_left,
|
|
window_right,
|
|
softcap,
|
|
enable_cuda_graph,
|
|
use_sqrt_alibi,
|
|
out);
|
|
break;
|
|
case paddle::DataType::FLOAT16:
|
|
PagedAttnKernel<paddle::DataType::FLOAT16>(q,
|
|
k_cache,
|
|
v_cache,
|
|
block_table,
|
|
seq_lens,
|
|
alibi_slopes,
|
|
k,
|
|
v,
|
|
num_kv_heads,
|
|
scale,
|
|
block_size,
|
|
max_context_len,
|
|
causal,
|
|
window_left,
|
|
window_right,
|
|
softcap,
|
|
enable_cuda_graph,
|
|
use_sqrt_alibi,
|
|
out);
|
|
break;
|
|
default:
|
|
PD_THROW("Unsupported data type for Paged attn");
|
|
}
|
|
return {out};
|
|
}
|
|
|
|
std::vector<std::vector<int64_t>> PagedAttnInferShape(const std::vector<int64_t>& q_shape,
|
|
const std::vector<int64_t>& k_cache_shape,
|
|
const std::vector<int64_t>& v_cache_shape,
|
|
const std::vector<int64_t>& block_table_shape,
|
|
const std::vector<int64_t>& seq_lens_shape,
|
|
const std::vector<int64_t>& alibi_slopes_shape,
|
|
const std::vector<int64_t>& k_shape,
|
|
const std::vector<int64_t>& v_shape) {
|
|
return {q_shape};
|
|
}
|
|
|
|
std::vector<paddle::DataType> PagedAttnInferDtype(const paddle::DataType& q_dtype,
|
|
const paddle::DataType& k_cache_dtype,
|
|
const paddle::DataType& v_cache_dtype,
|
|
const paddle::DataType& block_table_dtype,
|
|
const paddle::DataType& seq_lens_dtype,
|
|
const paddle::DataType& alibi_slopes_dtype,
|
|
const paddle::DataType& k_dtype,
|
|
const paddle::DataType& v_dtype) {
|
|
return {q_dtype};
|
|
}
|
|
|
|
|
|
PD_BUILD_STATIC_OP(paged_attn)
|
|
.Inputs({"q", "k_cache", "v_cache", "block_table", "seq_lens", paddle::Optional("alibi_slopes"), paddle::Optional("k"), paddle::Optional("v")})
|
|
.Outputs({"out"})
|
|
.Attrs({"num_kv_heads:int",
|
|
"scale:float",
|
|
"block_size:int",
|
|
"max_context_len:int",
|
|
"causal:bool",
|
|
"window_left:int",
|
|
"window_right:int",
|
|
"softcap:float",
|
|
"enable_cuda_graph:bool",
|
|
"use_sqrt_alibi:bool"})
|
|
.SetKernelFn(PD_KERNEL(PagedAttn))
|
|
.SetInferShapeFn(PD_INFER_SHAPE(PagedAttnInferShape))
|
|
.SetInferDtypeFn(PD_INFER_DTYPE(PagedAttnInferDtype));
|
|
|
|
|
|
PYBIND11_MODULE(fastdeploy_ops, m) {
|
|
m.def("paged_attn", &PagedAttn, "paged attn function");
|
|
}
|