mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-12-24 13:28:13 +08:00
[Attention] remove cum_offsets from atten, and use cu_seqlens_q (#2870)
Some checks failed
Deploy GitHub Pages / deploy (push) Has been cancelled
Some checks failed
Deploy GitHub Pages / deploy (push) Has been cancelled
[Attention] remove cum_offsets from atten, and use cu_seqlens_q (#2870)
This commit is contained in:
@@ -23,7 +23,7 @@ std::vector<paddle::Tensor> PrefillMLAWriteCache(
|
||||
const paddle::Tensor& seq_lens,
|
||||
const paddle::Tensor& seq_lens_decoder,
|
||||
const paddle::Tensor& padding_offsets,
|
||||
const paddle::Tensor& cum_offsets,
|
||||
const paddle::Tensor& cu_seqlens_q,
|
||||
const paddle::Tensor& block_tables,
|
||||
const int max_seq_len,
|
||||
cudaStream_t& stream,
|
||||
@@ -54,7 +54,7 @@ std::vector<paddle::Tensor> PrefillMLAWriteCache(
|
||||
reinterpret_cast<DataType_*>(kv_cache->data<data_t>()),
|
||||
block_tables.data<int>(),
|
||||
padding_offsets.data<int>(),
|
||||
cum_offsets.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
seq_lens.data<int>(),
|
||||
seq_lens_decoder.data<int>(),
|
||||
max_seq_len,
|
||||
@@ -74,7 +74,7 @@ std::vector<paddle::Tensor> PrefillMLAWriteCacheKernel(
|
||||
const paddle::Tensor& seq_lens,
|
||||
const paddle::Tensor& seq_lens_decoder,
|
||||
const paddle::Tensor& padding_offsets,
|
||||
const paddle::Tensor& cum_offsets,
|
||||
const paddle::Tensor& cu_seqlens_q,
|
||||
const paddle::Tensor& block_tables,
|
||||
const std::string& cache_quant_type_str,
|
||||
const int max_seq_len) {
|
||||
@@ -91,7 +91,7 @@ std::vector<paddle::Tensor> PrefillMLAWriteCacheKernel(
|
||||
|
||||
meta_data.max_blocks_per_seq = block_tables.dims()[1];
|
||||
meta_data.block_size = kv_cache_dims[2];
|
||||
meta_data.batch_size = cum_offsets.dims()[0];
|
||||
meta_data.batch_size = cu_seqlens_q.dims()[0];
|
||||
switch (kv_pe.dtype()) {
|
||||
case paddle::DataType::BFLOAT16: {
|
||||
return PrefillMLAWriteCache<paddle::DataType::BFLOAT16>(meta_data,
|
||||
@@ -100,7 +100,7 @@ std::vector<paddle::Tensor> PrefillMLAWriteCacheKernel(
|
||||
seq_lens,
|
||||
seq_lens_decoder,
|
||||
padding_offsets,
|
||||
cum_offsets,
|
||||
cu_seqlens_q,
|
||||
block_tables,
|
||||
max_seq_len,
|
||||
stream,
|
||||
@@ -113,7 +113,7 @@ std::vector<paddle::Tensor> PrefillMLAWriteCacheKernel(
|
||||
seq_lens,
|
||||
seq_lens_decoder,
|
||||
padding_offsets,
|
||||
cum_offsets,
|
||||
cu_seqlens_q,
|
||||
block_tables,
|
||||
max_seq_len,
|
||||
stream,
|
||||
@@ -131,7 +131,7 @@ std::vector<paddle::Tensor> DecodeMLAWriteCache(
|
||||
const paddle::Tensor& seq_lens,
|
||||
const paddle::Tensor& seq_lens_encoder,
|
||||
const paddle::Tensor& padding_offsets,
|
||||
const paddle::Tensor& cum_offsets,
|
||||
const paddle::Tensor& cu_seqlens_q,
|
||||
const paddle::Tensor& block_tables,
|
||||
const int max_seq_len,
|
||||
const bool speculate_decoder,
|
||||
@@ -165,7 +165,7 @@ std::vector<paddle::Tensor> DecodeMLAWriteCache(
|
||||
reinterpret_cast<DataType_*>(kv_cache->data<data_t>()),
|
||||
block_tables.data<int>(),
|
||||
padding_offsets.data<int>(),
|
||||
cum_offsets.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
seq_lens.data<int>(),
|
||||
seq_lens_encoder.data<int>(),
|
||||
max_seq_len,
|
||||
@@ -185,7 +185,7 @@ std::vector<paddle::Tensor> DecodeMLAWriteCache(
|
||||
reinterpret_cast<DataType_*>(const_cast<data_t*>(kv_pe.data<data_t>())),
|
||||
reinterpret_cast<DataType_*>(kv_cache->data<data_t>()),
|
||||
block_tables.data<int>(),
|
||||
cum_offsets.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
seq_lens.data<int>(),
|
||||
seq_lens_encoder.data<int>(),
|
||||
max_seq_len,
|
||||
@@ -206,7 +206,7 @@ std::vector<paddle::Tensor> DecodeMLAWriteCacheKernel(
|
||||
const paddle::Tensor& seq_lens,
|
||||
const paddle::Tensor& seq_lens_encoder,
|
||||
const paddle::Tensor& padding_offsets,
|
||||
const paddle::Tensor& cum_offsets,
|
||||
const paddle::Tensor& cu_seqlens_q,
|
||||
const paddle::Tensor& block_tables,
|
||||
const std::string& cache_quant_type_str,
|
||||
const int max_seq_len,
|
||||
@@ -224,7 +224,7 @@ std::vector<paddle::Tensor> DecodeMLAWriteCacheKernel(
|
||||
|
||||
meta_data.max_blocks_per_seq = block_tables.dims()[1];
|
||||
meta_data.block_size = kv_cache_dims[2];
|
||||
meta_data.batch_size = cum_offsets.dims()[0];
|
||||
meta_data.batch_size = cu_seqlens_q.dims()[0];
|
||||
switch (kv_pe.dtype()) {
|
||||
case paddle::DataType::BFLOAT16: {
|
||||
return DecodeMLAWriteCache<paddle::DataType::BFLOAT16>(meta_data,
|
||||
@@ -233,7 +233,7 @@ std::vector<paddle::Tensor> DecodeMLAWriteCacheKernel(
|
||||
seq_lens,
|
||||
seq_lens_encoder,
|
||||
padding_offsets,
|
||||
cum_offsets,
|
||||
cu_seqlens_q,
|
||||
block_tables,
|
||||
max_seq_len,
|
||||
speculate_decoder,
|
||||
@@ -247,7 +247,7 @@ std::vector<paddle::Tensor> DecodeMLAWriteCacheKernel(
|
||||
seq_lens,
|
||||
seq_lens_encoder,
|
||||
padding_offsets,
|
||||
cum_offsets,
|
||||
cu_seqlens_q,
|
||||
block_tables,
|
||||
max_seq_len,
|
||||
speculate_decoder,
|
||||
@@ -266,7 +266,7 @@ PD_BUILD_OP(prefill_mla_write_cache)
|
||||
"seq_lens",
|
||||
"seq_lens_decoder",
|
||||
"padding_offsets",
|
||||
"cum_offsets",
|
||||
"cu_seqlens_q",
|
||||
"block_tables"})
|
||||
.Outputs({"kv_cache_out"})
|
||||
.SetInplaceMap({{"kv_cache", "kv_cache_out"}})
|
||||
@@ -281,7 +281,7 @@ PD_BUILD_OP(decode_mla_write_cache)
|
||||
"seq_lens",
|
||||
"seq_lens_encoder",
|
||||
"padding_offsets",
|
||||
"cum_offsets",
|
||||
"cu_seqlens_q",
|
||||
"block_tables"})
|
||||
.Outputs({"kv_cache_out"})
|
||||
.SetInplaceMap({{"kv_cache", "kv_cache_out"}})
|
||||
|
||||
Reference in New Issue
Block a user