[Inference, rename] remove padding_offsets from atten, use batch_id_per_token (#2880)

* remove padding_offsets from atten
Authored by 周周周 on 2025-07-17 18:41:31 +08:00, committed by GitHub
commit ddb10ac509, parent d49f8fb30a
50 changed files with 311 additions and 288 deletions
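The rename tracks a change in how the attention kernels recover a token's request: with padding_offsets, a kernel reconstructs the padded position of a flattened token index and divides by max_seq_len to get the batch id; with batch_id_per_token, the batch id is read directly and the in-sequence position follows from cu_seqlens_q. A minimal CUDA sketch of the two lookups, assuming the usual flattened-token indexing (the helper names are illustrative, not from this diff):

```cpp
// Sketch only: padding_offsets, batch_id_per_token, and cu_seqlens_q are
// the tensors renamed in this diff; everything else is assumed.

// Before: recover the padded position, then divide by max_seq_len.
__device__ int batch_id_from_padding(const int* padding_offsets,
                                     int token_idx, int max_seq_len) {
  const int ori_token_idx = token_idx + padding_offsets[token_idx];
  return ori_token_idx / max_seq_len;
}

// After: the batch id is stored per token; no dependence on max_seq_len.
__device__ int batch_id_direct(const int* batch_id_per_token, int token_idx) {
  return batch_id_per_token[token_idx];
}

// The token's position inside its own sequence comes from the prefix
// sums in cu_seqlens_q rather than from a division.
__device__ int local_token_pos(const int* batch_id_per_token,
                               const int* cu_seqlens_q, int token_idx) {
  return token_idx - cu_seqlens_q[batch_id_per_token[token_idx]];
}
```

Dropping the division also removes the index math's dependence on max_seq_len, which would explain why every kernel signature below swaps the tensor outright rather than keeping both.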


@@ -22,7 +22,7 @@ std::vector<paddle::Tensor> PrefillMLAWriteCache(
 const paddle::Tensor& kv_pe,
 const paddle::Tensor& seq_lens,
 const paddle::Tensor& seq_lens_decoder,
-const paddle::Tensor& padding_offsets,
+const paddle::Tensor& batch_id_per_token,
 const paddle::Tensor& cu_seqlens_q,
 const paddle::Tensor& block_tables,
 const int max_seq_len,
@@ -53,7 +53,7 @@ std::vector<paddle::Tensor> PrefillMLAWriteCache(
 reinterpret_cast<DataType_*>(const_cast<data_t*>(kv_pe.data<data_t>())),
 reinterpret_cast<DataType_*>(kv_cache->data<data_t>()),
 block_tables.data<int>(),
-padding_offsets.data<int>(),
+batch_id_per_token.data<int>(),
 cu_seqlens_q.data<int>(),
 seq_lens.data<int>(),
 seq_lens_decoder.data<int>(),
@@ -73,7 +73,7 @@ std::vector<paddle::Tensor> PrefillMLAWriteCacheKernel(
 const paddle::Tensor& kv_cache,
 const paddle::Tensor& seq_lens,
 const paddle::Tensor& seq_lens_decoder,
-const paddle::Tensor& padding_offsets,
+const paddle::Tensor& batch_id_per_token,
 const paddle::Tensor& cu_seqlens_q,
 const paddle::Tensor& block_tables,
 const std::string& cache_quant_type_str,
@@ -99,7 +99,7 @@ std::vector<paddle::Tensor> PrefillMLAWriteCacheKernel(
 kv_pe,
 seq_lens,
 seq_lens_decoder,
-padding_offsets,
+batch_id_per_token,
 cu_seqlens_q,
 block_tables,
 max_seq_len,
@@ -112,7 +112,7 @@ std::vector<paddle::Tensor> PrefillMLAWriteCacheKernel(
 kv_pe,
 seq_lens,
 seq_lens_decoder,
-padding_offsets,
+batch_id_per_token,
 cu_seqlens_q,
 block_tables,
 max_seq_len,
@@ -130,7 +130,7 @@ std::vector<paddle::Tensor> DecodeMLAWriteCache(
 const paddle::Tensor& kv_pe,
 const paddle::Tensor& seq_lens,
 const paddle::Tensor& seq_lens_encoder,
-const paddle::Tensor& padding_offsets,
+const paddle::Tensor& batch_id_per_token,
 const paddle::Tensor& cu_seqlens_q,
 const paddle::Tensor& block_tables,
 const int max_seq_len,
@@ -164,7 +164,7 @@ std::vector<paddle::Tensor> DecodeMLAWriteCache(
 reinterpret_cast<DataType_*>(const_cast<data_t*>(kv_pe.data<data_t>())),
 reinterpret_cast<DataType_*>(kv_cache->data<data_t>()),
 block_tables.data<int>(),
-padding_offsets.data<int>(),
+batch_id_per_token.data<int>(),
 cu_seqlens_q.data<int>(),
 seq_lens.data<int>(),
 seq_lens_encoder.data<int>(),
@@ -205,7 +205,7 @@ std::vector<paddle::Tensor> DecodeMLAWriteCacheKernel(
 const paddle::Tensor& kv_cache,
 const paddle::Tensor& seq_lens,
 const paddle::Tensor& seq_lens_encoder,
-const paddle::Tensor& padding_offsets,
+const paddle::Tensor& batch_id_per_token,
 const paddle::Tensor& cu_seqlens_q,
 const paddle::Tensor& block_tables,
 const std::string& cache_quant_type_str,
@@ -232,7 +232,7 @@ std::vector<paddle::Tensor> DecodeMLAWriteCacheKernel(
 kv_pe,
 seq_lens,
 seq_lens_encoder,
-padding_offsets,
+batch_id_per_token,
 cu_seqlens_q,
 block_tables,
 max_seq_len,
@@ -246,7 +246,7 @@ std::vector<paddle::Tensor> DecodeMLAWriteCacheKernel(
 kv_pe,
 seq_lens,
 seq_lens_encoder,
-padding_offsets,
+batch_id_per_token,
 cu_seqlens_q,
 block_tables,
 max_seq_len,
@@ -265,7 +265,7 @@ PD_BUILD_OP(prefill_mla_write_cache)
"kv_cache",
"seq_lens",
"seq_lens_decoder",
"padding_offsets",
"batch_id_per_token",
"cu_seqlens_q",
"block_tables"})
.Outputs({"kv_cache_out"})
@@ -280,7 +280,7 @@ PD_BUILD_OP(decode_mla_write_cache)
"kv_cache",
"seq_lens",
"seq_lens_encoder",
"padding_offsets",
"batch_id_per_token",
"cu_seqlens_q",
"block_tables"})
.Outputs({"kv_cache_out"})