Mirror of https://github.com/PaddlePaddle/FastDeploy.git, synced 2025-10-07 01:22:59 +08:00
remove cum_offsets from get_block_shape_and_split_kv_block (#2913)
* remove padding_offsets from get_padding_offset.cu
* remove cum_offsets from get_block_shape_and_split_kv_block
@@ -194,12 +194,12 @@ get_max_len_kv_ernel(int *max_seq_lens_out, const int *seq_lens_this_time,
 std::vector<paddle::Tensor> GetBlockShapeAndSplitKVBlock(
     const paddle::Tensor &seq_lens_encoder,
     const paddle::Tensor &seq_lens_decoder,
-    const paddle::Tensor &seq_lens_this_time, const paddle::Tensor &cum_offsets,
+    const paddle::Tensor &seq_lens_this_time,
     const int encoder_block_shape_q, const int decoder_block_shape_q,
     const int group_size, const int block_size,
     const int decoder_step_token_num) {
   auto stream = seq_lens_encoder.stream();
-  int bsz = cum_offsets.shape()[0];
+  int bsz = seq_lens_encoder.shape()[0];
   auto max_len_tensor =
       GetEmptyTensor({8}, paddle::DataType::INT32, seq_lens_encoder.place());
   GetMaxLen(seq_lens_decoder, seq_lens_this_time, seq_lens_encoder,
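Note: the only behavioral change inside the launcher is where the batch size is read from. Both cum_offsets and seq_lens_encoder carry one entry per request, so the leading dimension of seq_lens_encoder yields the same bsz that cum_offsets.shape()[0] did. A minimal Python sketch of that assumption, with illustrative values only:

    import paddle

    bsz = 3
    # Per-request encoder lengths: one entry per request, shape [bsz].
    seq_lens_encoder = paddle.to_tensor([128, 0, 37], dtype="int32")
    # cum_offsets (the removed input) was likewise sized per request, shape [bsz].
    cum_offsets = paddle.zeros([bsz], dtype="int32")

    # Old derivation vs. new derivation of the batch size inside the op:
    assert cum_offsets.shape[0] == seq_lens_encoder.shape[0] == bsz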
@@ -335,8 +335,7 @@ std::vector<paddle::Tensor> GetBlockShapeAndSplitKVBlock(
 std::vector<paddle::DataType> GetBlockShapeAndSplitKVBlockInferDtype(
     const paddle::DataType &seq_lens_encoder_dtype,
     const paddle::DataType &seq_lens_decoder_dtype,
-    const paddle::DataType &seq_lens_this_time_dtype,
-    const paddle::DataType &cum_offsets_dtype) {
+    const paddle::DataType &seq_lens_this_time_dtype) {
   return {
       paddle::DataType::INT32, paddle::DataType::INT32, paddle::DataType::INT32,
       paddle::DataType::INT32, paddle::DataType::INT32, paddle::DataType::INT32,
@@ -347,8 +346,7 @@ std::vector<paddle::DataType> GetBlockShapeAndSplitKVBlockInferDtype(
 std::vector<std::vector<int64_t>> GetBlockShapeAndSplitKVBlockInferShape(
     const std::vector<int64_t> &seq_lens_encoder_shape,
     const std::vector<int64_t> &seq_lens_decoder_shape,
-    const std::vector<int64_t> &seq_lens_this_time_shape,
-    const std::vector<int64_t> &cum_offsets_shape) {
+    const std::vector<int64_t> &seq_lens_this_time_shape) {
   std::vector<int64_t> dynamic_shape = {-1};

   return {dynamic_shape,
@@ -365,8 +363,7 @@ std::vector<std::vector<int64_t>> GetBlockShapeAndSplitKVBlockInferShape(
 }

 PD_BUILD_STATIC_OP(get_block_shape_and_split_kv_block)
-    .Inputs({"seq_lens_encoder", "seq_lens_decoder", "seq_lens_this_time",
-             "cum_offsets"})
+    .Inputs({"seq_lens_encoder", "seq_lens_decoder", "seq_lens_this_time"})
     .Outputs({paddle::Optional("encoder_batch_ids"),
               paddle::Optional("encoder_tile_ids_per_batch"),
               paddle::Optional("encoder_num_blocks"),
|
@@ -234,7 +234,7 @@ paddle::Tensor InitSignalLayerwiseFunc(const paddle::Tensor &kv_signal_metadata,
 std::vector<paddle::Tensor> GetBlockShapeAndSplitKVBlock(
     const paddle::Tensor &seq_lens_encoder,
     const paddle::Tensor &seq_lens_decoder,
-    const paddle::Tensor &seq_lens_this_time, const paddle::Tensor &cum_offsets,
+    const paddle::Tensor &seq_lens_this_time,
     const int encoder_block_shape_q, const int decoder_block_shape_q,
     const int group_size, const int block_size,
     const int decoder_step_token_num);
|
@@ -145,7 +145,6 @@ class AppendAttentionBackend(AttentionBackend):
             forward_meta.seq_lens_encoder,
             forward_meta.seq_lens_decoder,
             forward_meta.seq_lens_this_time,
-            forward_meta.cum_offsets,
             metadata.encoder_block_shape_q,
             metadata.decoder_block_shape_q,
             self.num_heads // self.kv_num_heads,
|
@@ -151,7 +151,6 @@ class FlashAttentionBackend(AttentionBackend):
             forward_meta.seq_lens_encoder,
             forward_meta.seq_lens_decoder,
             forward_meta.seq_lens_this_time,
-            forward_meta.cum_offsets,
             metadata.encoder_block_shape_q,
             metadata.decoder_block_shape_q,
             self.num_heads // self.kv_num_heads,
|
@@ -173,7 +173,6 @@ class MLAAttentionBackend(AttentionBackend):
             forward_meta.seq_lens_encoder,
             forward_meta.seq_lens_decoder,
             forward_meta.seq_lens_this_time,
-            forward_meta.cum_offsets,
             metadata.encoder_block_shape_q,
             metadata.decoder_block_shape_q,
             self.num_heads // self.kv_num_heads,
|
@@ -28,7 +28,6 @@ def get_block_shape_and_split_kv_block(
     seq_lens_encoder: paddle.Tensor,
     seq_lens_decoder: paddle.Tensor,
     seq_lens_this_time: paddle.Tensor,
-    cum_offsets: paddle.Tensor,
     encoder_block_shape_q: int,
     decoder_block_shape_q: int,
     group_size: int,
@@ -55,7 +54,6 @@ def get_block_shape_and_split_kv_block(
         seq_lens_encoder,
         seq_lens_decoder,
         seq_lens_this_time,
-        cum_offsets,
         encoder_block_shape_q,
         decoder_block_shape_q,
         group_size,
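For reference, a minimal sketch of calling the updated Python helper after this change. Only the parameter names visible in the hunks above come from this diff; the import path and the trailing block_size / decoder_step_token_num arguments are assumptions mirroring the C++ signature, and the numeric values are placeholders:

    import paddle

    # Hypothetical import path; the module that exposes the helper is not shown in this diff.
    from fastdeploy.model_executor.layers.attention.ops import (
        get_block_shape_and_split_kv_block,
    )

    # Per-request lengths, one entry per batch element (shape [bsz]).
    seq_lens_encoder = paddle.to_tensor([128, 0], dtype="int32")
    seq_lens_decoder = paddle.to_tensor([0, 512], dtype="int32")
    seq_lens_this_time = paddle.to_tensor([128, 1], dtype="int32")

    outputs = get_block_shape_and_split_kv_block(
        seq_lens_encoder,
        seq_lens_decoder,
        seq_lens_this_time,  # cum_offsets is no longer passed here
        64,  # encoder_block_shape_q
        16,  # decoder_block_shape_q
        8,   # group_size (num_heads // kv_num_heads)
        64,  # block_size (assumed trailing arg, mirroring the C++ signature)
        1,   # decoder_step_token_num (assumed trailing arg, mirroring the C++ signature)
    )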
|