[PD Disaggregation] Support PD deployment of DeepSeekv3. (#5251)

* Support deepseekv3 cache transfer for PD deploy

* clean some log info

---------

Co-authored-by: K11OntheBoat <“ruianmaidanglao@163.com”>
This commit is contained in:
K11OntheBoat
2025-12-02 14:11:50 +08:00
committed by GitHub
parent 117980dd4e
commit 2e1680838f
17 changed files with 620 additions and 400 deletions

View File

@@ -13,22 +13,24 @@
// limitations under the License.
#pragma once
#include "helper.h"
#include "mla_cache_kernel.cuh"
#include "helper.h"
#include "remote_cache_kv_ipc.h"
template <paddle::DataType T>
std::vector<paddle::Tensor> PrefillMLAWriteCache(
const AppendAttnMetaData& meta_data,
const paddle::Tensor& kv_nope,
const paddle::Tensor& kv_pe,
const paddle::Tensor& seq_lens,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& batch_id_per_token,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_tables,
const int max_seq_len,
cudaStream_t& stream,
paddle::Tensor* kv_cache) {
const AppendAttnMetaData& meta_data,
const paddle::Tensor& kv_nope,
const paddle::Tensor& kv_pe,
const paddle::Tensor& seq_lens,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& batch_id_per_token,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_tables,
const paddle::optional<paddle::Tensor>& kv_signal_data,
const int max_seq_len,
cudaStream_t& stream,
paddle::Tensor* kv_cache) {
typedef PDTraits<T> traits_;
typedef typename traits_::DataType DataType_;
typedef typename traits_::data_t data_t;
@@ -50,8 +52,10 @@ std::vector<paddle::Tensor> PrefillMLAWriteCache(
prefill_absorb_cache_kernel<DataType_, PackSize>
<<<grid_size, blocksize, 0, stream>>>(
reinterpret_cast<DataType_*>(const_cast<data_t*>(kv_nope.data<data_t>())),
reinterpret_cast<DataType_*>(const_cast<data_t*>(kv_pe.data<data_t>())),
reinterpret_cast<DataType_*>(
const_cast<data_t*>(kv_nope.data<data_t>())),
reinterpret_cast<DataType_*>(
const_cast<data_t*>(kv_pe.data<data_t>())),
reinterpret_cast<DataType_*>(kv_cache->data<data_t>()),
block_tables.data<int>(),
batch_id_per_token.data<int>(),
@@ -65,6 +69,33 @@ std::vector<paddle::Tensor> PrefillMLAWriteCache(
pe_size,
block_size,
elem_nums);
const char* fmt_write_cache_completed_signal_str =
std::getenv("FLAGS_fmt_write_cache_completed_signal");
const char* FLAGS_use_pd_disaggregation_per_chunk =
std::getenv("FLAGS_use_pd_disaggregation_per_chunk");
if (fmt_write_cache_completed_signal_str &&
(std::strcmp(fmt_write_cache_completed_signal_str, "true") == 0 ||
std::strcmp(fmt_write_cache_completed_signal_str, "1") == 0)) {
if (FLAGS_use_pd_disaggregation_per_chunk &&
(std::strcmp(FLAGS_use_pd_disaggregation_per_chunk, "true") == 0 ||
std::strcmp(FLAGS_use_pd_disaggregation_per_chunk, "1") == 0)) {
cudaLaunchHostFunc(
stream,
&(RemoteCacheKvIpc::
save_cache_kv_complete_signal_layerwise_per_query),
(void*)nullptr);
} else {
if (kv_signal_data) {
cudaLaunchHostFunc(
stream,
&RemoteCacheKvIpc::save_cache_kv_complete_signal_layerwise,
(void*)(const_cast<int64_t*>(
kv_signal_data.get().data<int64_t>())));
}
}
}
return {};
}
@@ -77,6 +108,7 @@ std::vector<paddle::Tensor> PrefillMLAWriteCacheKernel(
const paddle::Tensor& batch_id_per_token,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_tables,
const paddle::optional<paddle::Tensor>& kv_signal_data,
const std::string& cache_quant_type_str,
const int max_seq_len) {
cudaStream_t stream = kv_pe.stream();
@@ -85,7 +117,8 @@ std::vector<paddle::Tensor> PrefillMLAWriteCacheKernel(
const auto& kv_pe_dims = kv_pe.dims();
const auto& kv_cache_dims = kv_cache.dims();
meta_data.kv_num_heads = kv_cache_dims[1];
const auto nope_size = kv_nope_dims[kv_nope_dims.size() - 1] / meta_data.kv_num_heads;
const auto nope_size =
kv_nope_dims[kv_nope_dims.size() - 1] / meta_data.kv_num_heads;
meta_data.token_nums = kv_nope_dims[0];
meta_data.head_dims = kv_cache_dims[3];
meta_data.head_dims_v = nope_size;
@@ -95,30 +128,34 @@ std::vector<paddle::Tensor> PrefillMLAWriteCacheKernel(
meta_data.batch_size = seq_lens_decoder.dims()[0];
switch (kv_pe.dtype()) {
case paddle::DataType::BFLOAT16: {
return PrefillMLAWriteCache<paddle::DataType::BFLOAT16>(meta_data,
kv_nope,
kv_pe,
seq_lens,
seq_lens_decoder,
batch_id_per_token,
cu_seqlens_q,
block_tables,
max_seq_len,
stream,
const_cast<paddle::Tensor*>(&kv_cache));
return PrefillMLAWriteCache<paddle::DataType::BFLOAT16>(
meta_data,
kv_nope,
kv_pe,
seq_lens,
seq_lens_decoder,
batch_id_per_token,
cu_seqlens_q,
block_tables,
kv_signal_data,
max_seq_len,
stream,
const_cast<paddle::Tensor*>(&kv_cache));
}
case paddle::DataType::FLOAT16: {
return PrefillMLAWriteCache<paddle::DataType::FLOAT16>(meta_data,
kv_nope,
kv_pe,
seq_lens,
seq_lens_decoder,
batch_id_per_token,
cu_seqlens_q,
block_tables,
max_seq_len,
stream,
const_cast<paddle::Tensor*>(&kv_cache));
return PrefillMLAWriteCache<paddle::DataType::FLOAT16>(
meta_data,
kv_nope,
kv_pe,
seq_lens,
seq_lens_decoder,
batch_id_per_token,
cu_seqlens_q,
block_tables,
kv_signal_data,
max_seq_len,
stream,
const_cast<paddle::Tensor*>(&kv_cache));
}
}
return {};
@@ -126,18 +163,18 @@ std::vector<paddle::Tensor> PrefillMLAWriteCacheKernel(
template <paddle::DataType T>
std::vector<paddle::Tensor> DecodeMLAWriteCache(
const AppendAttnMetaData& meta_data,
const paddle::Tensor& kv_nope,
const paddle::Tensor& kv_pe,
const paddle::Tensor& seq_lens,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& batch_id_per_token,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_tables,
const int max_seq_len,
const bool speculate_decoder,
cudaStream_t& stream,
paddle::Tensor* kv_cache) {
const AppendAttnMetaData& meta_data,
const paddle::Tensor& kv_nope,
const paddle::Tensor& kv_pe,
const paddle::Tensor& seq_lens,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& batch_id_per_token,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_tables,
const int max_seq_len,
const bool speculate_decoder,
cudaStream_t& stream,
paddle::Tensor* kv_cache) {
typedef PDTraits<T> traits_;
typedef typename traits_::DataType DataType_;
typedef typename traits_::data_t data_t;
@@ -154,15 +191,16 @@ std::vector<paddle::Tensor> DecodeMLAWriteCache(
const int blocksize = 128;
int grid_size = 1;
if (speculate_decoder) {
const uint32_t elem_nums = token_num * kv_num_heads * all_size;
const int pack_num = elem_nums / PackSize;
GetNumBlocks<128>(pack_num, &grid_size);
speculate_decode_absorb_cache_kernel<DataType_, PackSize>
<<<grid_size, blocksize, 0, stream>>>(
reinterpret_cast<DataType_*>(const_cast<data_t*>(kv_nope.data<data_t>())),
reinterpret_cast<DataType_*>(const_cast<data_t*>(kv_pe.data<data_t>())),
reinterpret_cast<DataType_*>(
const_cast<data_t*>(kv_nope.data<data_t>())),
reinterpret_cast<DataType_*>(
const_cast<data_t*>(kv_pe.data<data_t>())),
reinterpret_cast<DataType_*>(kv_cache->data<data_t>()),
block_tables.data<int>(),
batch_id_per_token.data<int>(),
@@ -182,8 +220,10 @@ std::vector<paddle::Tensor> DecodeMLAWriteCache(
GetNumBlocks<128>(pack_num, &grid_size);
decode_absorb_cache_kernel<DataType_, PackSize>
<<<grid_size, blocksize, 0, stream>>>(
reinterpret_cast<DataType_*>(const_cast<data_t*>(kv_nope.data<data_t>())),
reinterpret_cast<DataType_*>(const_cast<data_t*>(kv_pe.data<data_t>())),
reinterpret_cast<DataType_*>(
const_cast<data_t*>(kv_nope.data<data_t>())),
reinterpret_cast<DataType_*>(
const_cast<data_t*>(kv_pe.data<data_t>())),
reinterpret_cast<DataType_*>(kv_cache->data<data_t>()),
block_tables.data<int>(),
cu_seqlens_q.data<int>(),
@@ -218,7 +258,8 @@ std::vector<paddle::Tensor> DecodeMLAWriteCacheKernel(
const auto& kv_pe_dims = kv_pe.dims();
const auto& kv_cache_dims = kv_cache.dims();
meta_data.kv_num_heads = kv_cache_dims[1];
const auto nope_size = kv_nope_dims[kv_nope_dims.size() - 1] / meta_data.kv_num_heads;
const auto nope_size =
kv_nope_dims[kv_nope_dims.size() - 1] / meta_data.kv_num_heads;
meta_data.token_nums = kv_nope_dims[0];
meta_data.head_dims = kv_cache_dims[3];
meta_data.head_dims_v = nope_size;
@@ -228,38 +269,39 @@ std::vector<paddle::Tensor> DecodeMLAWriteCacheKernel(
meta_data.batch_size = seq_lens_encoder.dims()[0];
switch (kv_pe.dtype()) {
case paddle::DataType::BFLOAT16: {
return DecodeMLAWriteCache<paddle::DataType::BFLOAT16>(meta_data,
kv_nope,
kv_pe,
seq_lens,
seq_lens_encoder,
batch_id_per_token,
cu_seqlens_q,
block_tables,
max_seq_len,
speculate_decoder,
stream,
const_cast<paddle::Tensor*>(&kv_cache));
return DecodeMLAWriteCache<paddle::DataType::BFLOAT16>(
meta_data,
kv_nope,
kv_pe,
seq_lens,
seq_lens_encoder,
batch_id_per_token,
cu_seqlens_q,
block_tables,
max_seq_len,
speculate_decoder,
stream,
const_cast<paddle::Tensor*>(&kv_cache));
}
case paddle::DataType::FLOAT16: {
return DecodeMLAWriteCache<paddle::DataType::FLOAT16>(meta_data,
kv_nope,
kv_pe,
seq_lens,
seq_lens_encoder,
batch_id_per_token,
cu_seqlens_q,
block_tables,
max_seq_len,
speculate_decoder,
stream,
const_cast<paddle::Tensor*>(&kv_cache));
return DecodeMLAWriteCache<paddle::DataType::FLOAT16>(
meta_data,
kv_nope,
kv_pe,
seq_lens,
seq_lens_encoder,
batch_id_per_token,
cu_seqlens_q,
block_tables,
max_seq_len,
speculate_decoder,
stream,
const_cast<paddle::Tensor*>(&kv_cache));
}
}
return {};
}
PD_BUILD_STATIC_OP(prefill_mla_write_cache)
.Inputs({"kv_nope",
"kv_pe",
@@ -268,11 +310,11 @@ PD_BUILD_STATIC_OP(prefill_mla_write_cache)
"seq_lens_decoder",
"batch_id_per_token",
"cu_seqlens_q",
"block_tables"})
"block_tables",
paddle::Optional("kv_signal_data")})
.Outputs({"kv_cache_out"})
.SetInplaceMap({{"kv_cache", "kv_cache_out"}})
.Attrs({"cache_quant_type_str: std::string",
"max_seq_len: int"})
.Attrs({"cache_quant_type_str: std::string", "max_seq_len: int"})
.SetKernelFn(PD_KERNEL(PrefillMLAWriteCacheKernel));
PD_BUILD_STATIC_OP(decode_mla_write_cache)