[Attention] remove cum_offsets from atten, and use cu_seqlens_q (#2870)
Some checks failed
Deploy GitHub Pages / deploy (push) Has been cancelled

[Attention] remove cum_offsets from atten, and use cu_seqlens_q (#2870)
This commit is contained in:
周周周
2025-07-16 20:10:57 +08:00
committed by GitHub
parent 42b80182e0
commit aa76085d1f
47 changed files with 237 additions and 260 deletions

View File

@@ -23,7 +23,7 @@ std::vector<paddle::Tensor> PrefillMLAWriteCache(
const paddle::Tensor& seq_lens,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_tables,
const int max_seq_len,
cudaStream_t& stream,
@@ -54,7 +54,7 @@ std::vector<paddle::Tensor> PrefillMLAWriteCache(
reinterpret_cast<DataType_*>(kv_cache->data<data_t>()),
block_tables.data<int>(),
padding_offsets.data<int>(),
cum_offsets.data<int>(),
cu_seqlens_q.data<int>(),
seq_lens.data<int>(),
seq_lens_decoder.data<int>(),
max_seq_len,
@@ -74,7 +74,7 @@ std::vector<paddle::Tensor> PrefillMLAWriteCacheKernel(
const paddle::Tensor& seq_lens,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_tables,
const std::string& cache_quant_type_str,
const int max_seq_len) {
@@ -91,7 +91,7 @@ std::vector<paddle::Tensor> PrefillMLAWriteCacheKernel(
meta_data.max_blocks_per_seq = block_tables.dims()[1];
meta_data.block_size = kv_cache_dims[2];
meta_data.batch_size = cum_offsets.dims()[0];
meta_data.batch_size = cu_seqlens_q.dims()[0];
switch (kv_pe.dtype()) {
case paddle::DataType::BFLOAT16: {
return PrefillMLAWriteCache<paddle::DataType::BFLOAT16>(meta_data,
@@ -100,7 +100,7 @@ std::vector<paddle::Tensor> PrefillMLAWriteCacheKernel(
seq_lens,
seq_lens_decoder,
padding_offsets,
cum_offsets,
cu_seqlens_q,
block_tables,
max_seq_len,
stream,
@@ -113,7 +113,7 @@ std::vector<paddle::Tensor> PrefillMLAWriteCacheKernel(
seq_lens,
seq_lens_decoder,
padding_offsets,
cum_offsets,
cu_seqlens_q,
block_tables,
max_seq_len,
stream,
@@ -131,7 +131,7 @@ std::vector<paddle::Tensor> DecodeMLAWriteCache(
const paddle::Tensor& seq_lens,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_tables,
const int max_seq_len,
const bool speculate_decoder,
@@ -165,7 +165,7 @@ std::vector<paddle::Tensor> DecodeMLAWriteCache(
reinterpret_cast<DataType_*>(kv_cache->data<data_t>()),
block_tables.data<int>(),
padding_offsets.data<int>(),
cum_offsets.data<int>(),
cu_seqlens_q.data<int>(),
seq_lens.data<int>(),
seq_lens_encoder.data<int>(),
max_seq_len,
@@ -185,7 +185,7 @@ std::vector<paddle::Tensor> DecodeMLAWriteCache(
reinterpret_cast<DataType_*>(const_cast<data_t*>(kv_pe.data<data_t>())),
reinterpret_cast<DataType_*>(kv_cache->data<data_t>()),
block_tables.data<int>(),
cum_offsets.data<int>(),
cu_seqlens_q.data<int>(),
seq_lens.data<int>(),
seq_lens_encoder.data<int>(),
max_seq_len,
@@ -206,7 +206,7 @@ std::vector<paddle::Tensor> DecodeMLAWriteCacheKernel(
const paddle::Tensor& seq_lens,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_tables,
const std::string& cache_quant_type_str,
const int max_seq_len,
@@ -224,7 +224,7 @@ std::vector<paddle::Tensor> DecodeMLAWriteCacheKernel(
meta_data.max_blocks_per_seq = block_tables.dims()[1];
meta_data.block_size = kv_cache_dims[2];
meta_data.batch_size = cum_offsets.dims()[0];
meta_data.batch_size = cu_seqlens_q.dims()[0];
switch (kv_pe.dtype()) {
case paddle::DataType::BFLOAT16: {
return DecodeMLAWriteCache<paddle::DataType::BFLOAT16>(meta_data,
@@ -233,7 +233,7 @@ std::vector<paddle::Tensor> DecodeMLAWriteCacheKernel(
seq_lens,
seq_lens_encoder,
padding_offsets,
cum_offsets,
cu_seqlens_q,
block_tables,
max_seq_len,
speculate_decoder,
@@ -247,7 +247,7 @@ std::vector<paddle::Tensor> DecodeMLAWriteCacheKernel(
seq_lens,
seq_lens_encoder,
padding_offsets,
cum_offsets,
cu_seqlens_q,
block_tables,
max_seq_len,
speculate_decoder,
@@ -266,7 +266,7 @@ PD_BUILD_OP(prefill_mla_write_cache)
"seq_lens",
"seq_lens_decoder",
"padding_offsets",
"cum_offsets",
"cu_seqlens_q",
"block_tables"})
.Outputs({"kv_cache_out"})
.SetInplaceMap({{"kv_cache", "kv_cache_out"}})
@@ -281,7 +281,7 @@ PD_BUILD_OP(decode_mla_write_cache)
"seq_lens",
"seq_lens_encoder",
"padding_offsets",
"cum_offsets",
"cu_seqlens_q",
"block_tables"})
.Outputs({"kv_cache_out"})
.SetInplaceMap({{"kv_cache", "kv_cache_out"}})