[PD Disaggregation] Support PD deployment of DeepSeekv3. (#5251)

* Support DeepSeek-V3 cache transfer for PD deployment

* Clean up some log output

---------

Co-authored-by: K11OntheBoat <ruianmaidanglao@163.com>
Author: K11OntheBoat
Date: 2025-12-02 14:11:50 +08:00
Committed by: GitHub
Parent: 117980dd4e
Commit: 2e1680838f
17 changed files with 620 additions and 400 deletions

---------

@@ -13,22 +13,24 @@
 // limitations under the License.
 #pragma once
-#include "helper.h"
 #include "mla_cache_kernel.cuh"
+#include "helper.h"
+#include "remote_cache_kv_ipc.h"
 template <paddle::DataType T>
 std::vector<paddle::Tensor> PrefillMLAWriteCache(
     const AppendAttnMetaData& meta_data,
     const paddle::Tensor& kv_nope,
     const paddle::Tensor& kv_pe,
     const paddle::Tensor& seq_lens,
     const paddle::Tensor& seq_lens_decoder,
     const paddle::Tensor& batch_id_per_token,
     const paddle::Tensor& cu_seqlens_q,
     const paddle::Tensor& block_tables,
-    const int max_seq_len,
-    cudaStream_t& stream,
-    paddle::Tensor* kv_cache) {
+    const paddle::optional<paddle::Tensor>& kv_signal_data,
+    const int max_seq_len,
+    cudaStream_t& stream,
+    paddle::Tensor* kv_cache) {
   typedef PDTraits<T> traits_;
   typedef typename traits_::DataType DataType_;
   typedef typename traits_::data_t data_t;
@@ -50,8 +52,10 @@ std::vector<paddle::Tensor> PrefillMLAWriteCache(
   prefill_absorb_cache_kernel<DataType_, PackSize>
       <<<grid_size, blocksize, 0, stream>>>(
-          reinterpret_cast<DataType_*>(const_cast<data_t*>(kv_nope.data<data_t>())),
-          reinterpret_cast<DataType_*>(const_cast<data_t*>(kv_pe.data<data_t>())),
+          reinterpret_cast<DataType_*>(
+              const_cast<data_t*>(kv_nope.data<data_t>())),
+          reinterpret_cast<DataType_*>(
+              const_cast<data_t*>(kv_pe.data<data_t>())),
           reinterpret_cast<DataType_*>(kv_cache->data<data_t>()),
           block_tables.data<int>(),
           batch_id_per_token.data<int>(),
@@ -65,6 +69,33 @@ std::vector<paddle::Tensor> PrefillMLAWriteCache(
           pe_size,
           block_size,
           elem_nums);
+  const char* fmt_write_cache_completed_signal_str =
+      std::getenv("FLAGS_fmt_write_cache_completed_signal");
+  const char* FLAGS_use_pd_disaggregation_per_chunk =
+      std::getenv("FLAGS_use_pd_disaggregation_per_chunk");
+  if (fmt_write_cache_completed_signal_str &&
+      (std::strcmp(fmt_write_cache_completed_signal_str, "true") == 0 ||
+       std::strcmp(fmt_write_cache_completed_signal_str, "1") == 0)) {
+    if (FLAGS_use_pd_disaggregation_per_chunk &&
+        (std::strcmp(FLAGS_use_pd_disaggregation_per_chunk, "true") == 0 ||
+         std::strcmp(FLAGS_use_pd_disaggregation_per_chunk, "1") == 0)) {
+      cudaLaunchHostFunc(
+          stream,
+          &(RemoteCacheKvIpc::
+                save_cache_kv_complete_signal_layerwise_per_query),
+          (void*)nullptr);
+    } else {
+      if (kv_signal_data) {
+        cudaLaunchHostFunc(
+            stream,
+            &RemoteCacheKvIpc::save_cache_kv_complete_signal_layerwise,
+            (void*)(const_cast<int64_t*>(
+                kv_signal_data.get().data<int64_t>())));
+      }
+    }
+  }
   return {};
 }
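The block added above signals cache-transfer completion from the GPU timeline: a host callback is enqueued on the same stream as the cache-writing kernel, so it fires only after the KV cache is actually written. A minimal standalone sketch of that pattern (env-flag gating plus cudaLaunchHostFunc; the flag name matches the diff, the callback and its payload are hypothetical):

#include <cuda_runtime.h>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <cstring>

// Hypothetical host callback; CUDA runs it on an internal thread once all
// work previously enqueued on the stream has finished.
static void CUDART_CB OnCacheWritten(void* user_data) {
  std::printf("layer %ld: cache write visible, signal sent\n",
              (long)reinterpret_cast<std::intptr_t>(user_data));
}

static bool FlagIsTrue(const char* name) {
  const char* v = std::getenv(name);
  return v && (std::strcmp(v, "true") == 0 || std::strcmp(v, "1") == 0);
}

void EnqueueCompletionSignal(cudaStream_t stream, std::intptr_t layer_idx) {
  if (FlagIsTrue("FLAGS_fmt_write_cache_completed_signal")) {
    // Serialized behind the cache-writing kernel on the same stream.
    cudaLaunchHostFunc(stream, OnCacheWritten,
                       reinterpret_cast<void*>(layer_idx));
  }
}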
@@ -77,6 +108,7 @@ std::vector<paddle::Tensor> PrefillMLAWriteCacheKernel(
     const paddle::Tensor& batch_id_per_token,
     const paddle::Tensor& cu_seqlens_q,
     const paddle::Tensor& block_tables,
+    const paddle::optional<paddle::Tensor>& kv_signal_data,
     const std::string& cache_quant_type_str,
     const int max_seq_len) {
   cudaStream_t stream = kv_pe.stream();
@@ -85,7 +117,8 @@ std::vector<paddle::Tensor> PrefillMLAWriteCacheKernel(
   const auto& kv_pe_dims = kv_pe.dims();
   const auto& kv_cache_dims = kv_cache.dims();
   meta_data.kv_num_heads = kv_cache_dims[1];
-  const auto nope_size = kv_nope_dims[kv_nope_dims.size() - 1] / meta_data.kv_num_heads;
+  const auto nope_size =
+      kv_nope_dims[kv_nope_dims.size() - 1] / meta_data.kv_num_heads;
   meta_data.token_nums = kv_nope_dims[0];
   meta_data.head_dims = kv_cache_dims[3];
   meta_data.head_dims_v = nope_size;
@@ -95,30 +128,34 @@ std::vector<paddle::Tensor> PrefillMLAWriteCacheKernel(
   meta_data.batch_size = seq_lens_decoder.dims()[0];
   switch (kv_pe.dtype()) {
     case paddle::DataType::BFLOAT16: {
-      return PrefillMLAWriteCache<paddle::DataType::BFLOAT16>(meta_data,
-                                                              kv_nope,
-                                                              kv_pe,
-                                                              seq_lens,
-                                                              seq_lens_decoder,
-                                                              batch_id_per_token,
-                                                              cu_seqlens_q,
-                                                              block_tables,
-                                                              max_seq_len,
-                                                              stream,
-                                                              const_cast<paddle::Tensor*>(&kv_cache));
+      return PrefillMLAWriteCache<paddle::DataType::BFLOAT16>(
+          meta_data,
+          kv_nope,
+          kv_pe,
+          seq_lens,
+          seq_lens_decoder,
+          batch_id_per_token,
+          cu_seqlens_q,
+          block_tables,
+          kv_signal_data,
+          max_seq_len,
+          stream,
+          const_cast<paddle::Tensor*>(&kv_cache));
     }
     case paddle::DataType::FLOAT16: {
-      return PrefillMLAWriteCache<paddle::DataType::FLOAT16>(meta_data,
-                                                             kv_nope,
-                                                             kv_pe,
-                                                             seq_lens,
-                                                             seq_lens_decoder,
-                                                             batch_id_per_token,
-                                                             cu_seqlens_q,
-                                                             block_tables,
-                                                             max_seq_len,
-                                                             stream,
-                                                             const_cast<paddle::Tensor*>(&kv_cache));
+      return PrefillMLAWriteCache<paddle::DataType::FLOAT16>(
+          meta_data,
+          kv_nope,
+          kv_pe,
+          seq_lens,
+          seq_lens_decoder,
+          batch_id_per_token,
+          cu_seqlens_q,
+          block_tables,
+          kv_signal_data,
+          max_seq_len,
+          stream,
+          const_cast<paddle::Tensor*>(&kv_cache));
     }
   }
   return {};
@@ -126,18 +163,18 @@ std::vector<paddle::Tensor> PrefillMLAWriteCacheKernel(
 template <paddle::DataType T>
 std::vector<paddle::Tensor> DecodeMLAWriteCache(
     const AppendAttnMetaData& meta_data,
     const paddle::Tensor& kv_nope,
     const paddle::Tensor& kv_pe,
     const paddle::Tensor& seq_lens,
     const paddle::Tensor& seq_lens_encoder,
     const paddle::Tensor& batch_id_per_token,
     const paddle::Tensor& cu_seqlens_q,
     const paddle::Tensor& block_tables,
     const int max_seq_len,
     const bool speculate_decoder,
     cudaStream_t& stream,
     paddle::Tensor* kv_cache) {
   typedef PDTraits<T> traits_;
   typedef typename traits_::DataType DataType_;
   typedef typename traits_::data_t data_t;
@@ -154,15 +191,16 @@ std::vector<paddle::Tensor> DecodeMLAWriteCache(
   const int blocksize = 128;
   int grid_size = 1;
   if (speculate_decoder) {
     const uint32_t elem_nums = token_num * kv_num_heads * all_size;
     const int pack_num = elem_nums / PackSize;
     GetNumBlocks<128>(pack_num, &grid_size);
     speculate_decode_absorb_cache_kernel<DataType_, PackSize>
         <<<grid_size, blocksize, 0, stream>>>(
-            reinterpret_cast<DataType_*>(const_cast<data_t*>(kv_nope.data<data_t>())),
-            reinterpret_cast<DataType_*>(const_cast<data_t*>(kv_pe.data<data_t>())),
+            reinterpret_cast<DataType_*>(
+                const_cast<data_t*>(kv_nope.data<data_t>())),
+            reinterpret_cast<DataType_*>(
+                const_cast<data_t*>(kv_pe.data<data_t>())),
             reinterpret_cast<DataType_*>(kv_cache->data<data_t>()),
             block_tables.data<int>(),
             batch_id_per_token.data<int>(),
@@ -182,8 +220,10 @@ std::vector<paddle::Tensor> DecodeMLAWriteCache(
     GetNumBlocks<128>(pack_num, &grid_size);
     decode_absorb_cache_kernel<DataType_, PackSize>
         <<<grid_size, blocksize, 0, stream>>>(
-            reinterpret_cast<DataType_*>(const_cast<data_t*>(kv_nope.data<data_t>())),
-            reinterpret_cast<DataType_*>(const_cast<data_t*>(kv_pe.data<data_t>())),
+            reinterpret_cast<DataType_*>(
+                const_cast<data_t*>(kv_nope.data<data_t>())),
+            reinterpret_cast<DataType_*>(
+                const_cast<data_t*>(kv_pe.data<data_t>())),
             reinterpret_cast<DataType_*>(kv_cache->data<data_t>()),
             block_tables.data<int>(),
             cu_seqlens_q.data<int>(),
@@ -218,7 +258,8 @@ std::vector<paddle::Tensor> DecodeMLAWriteCacheKernel(
   const auto& kv_pe_dims = kv_pe.dims();
   const auto& kv_cache_dims = kv_cache.dims();
   meta_data.kv_num_heads = kv_cache_dims[1];
-  const auto nope_size = kv_nope_dims[kv_nope_dims.size() - 1] / meta_data.kv_num_heads;
+  const auto nope_size =
+      kv_nope_dims[kv_nope_dims.size() - 1] / meta_data.kv_num_heads;
   meta_data.token_nums = kv_nope_dims[0];
   meta_data.head_dims = kv_cache_dims[3];
   meta_data.head_dims_v = nope_size;
@@ -228,38 +269,39 @@ std::vector<paddle::Tensor> DecodeMLAWriteCacheKernel(
   meta_data.batch_size = seq_lens_encoder.dims()[0];
   switch (kv_pe.dtype()) {
     case paddle::DataType::BFLOAT16: {
-      return DecodeMLAWriteCache<paddle::DataType::BFLOAT16>(meta_data,
-                                                             kv_nope,
-                                                             kv_pe,
-                                                             seq_lens,
-                                                             seq_lens_encoder,
-                                                             batch_id_per_token,
-                                                             cu_seqlens_q,
-                                                             block_tables,
-                                                             max_seq_len,
-                                                             speculate_decoder,
-                                                             stream,
-                                                             const_cast<paddle::Tensor*>(&kv_cache));
+      return DecodeMLAWriteCache<paddle::DataType::BFLOAT16>(
+          meta_data,
+          kv_nope,
+          kv_pe,
+          seq_lens,
+          seq_lens_encoder,
+          batch_id_per_token,
+          cu_seqlens_q,
+          block_tables,
+          max_seq_len,
+          speculate_decoder,
+          stream,
+          const_cast<paddle::Tensor*>(&kv_cache));
     }
     case paddle::DataType::FLOAT16: {
-      return DecodeMLAWriteCache<paddle::DataType::FLOAT16>(meta_data,
-                                                            kv_nope,
-                                                            kv_pe,
-                                                            seq_lens,
-                                                            seq_lens_encoder,
-                                                            batch_id_per_token,
-                                                            cu_seqlens_q,
-                                                            block_tables,
-                                                            max_seq_len,
-                                                            speculate_decoder,
-                                                            stream,
-                                                            const_cast<paddle::Tensor*>(&kv_cache));
+      return DecodeMLAWriteCache<paddle::DataType::FLOAT16>(
+          meta_data,
+          kv_nope,
+          kv_pe,
+          seq_lens,
+          seq_lens_encoder,
+          batch_id_per_token,
+          cu_seqlens_q,
+          block_tables,
+          max_seq_len,
+          speculate_decoder,
+          stream,
+          const_cast<paddle::Tensor*>(&kv_cache));
     }
   }
   return {};
 }
 PD_BUILD_STATIC_OP(prefill_mla_write_cache)
     .Inputs({"kv_nope",
              "kv_pe",
@@ -268,11 +310,11 @@ PD_BUILD_STATIC_OP(prefill_mla_write_cache)
              "seq_lens_decoder",
              "batch_id_per_token",
              "cu_seqlens_q",
-             "block_tables"})
+             "block_tables",
+             paddle::Optional("kv_signal_data")})
     .Outputs({"kv_cache_out"})
     .SetInplaceMap({{"kv_cache", "kv_cache_out"}})
-    .Attrs({"cache_quant_type_str: std::string",
-            "max_seq_len: int"})
+    .Attrs({"cache_quant_type_str: std::string", "max_seq_len: int"})
     .SetKernelFn(PD_KERNEL(PrefillMLAWriteCacheKernel));
 PD_BUILD_STATIC_OP(decode_mla_write_cache)
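The registration above makes kv_signal_data an optional input: the op declares paddle::Optional("kv_signal_data") and the kernel receives a paddle::optional<paddle::Tensor>. For reference, a minimal sketch of a Paddle custom op with one optional input (op name and kernel body are hypothetical, not the FastDeploy op):

#include <vector>
#include "paddle/extension.h"

// The optional tensor is engaged only when the caller passes it.
std::vector<paddle::Tensor> ToyKernel(
    const paddle::Tensor& x,
    const paddle::optional<paddle::Tensor>& signal) {
  if (signal) {
    // signal.get() yields the underlying paddle::Tensor.
    (void)signal.get().data<int64_t>();
  }
  return {x};
}

PD_BUILD_OP(toy_optional_op)
    .Inputs({"x", paddle::Optional("signal")})
    .Outputs({"out"})
    .SetKernelFn(PD_KERNEL(ToyKernel));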

---------

@@ -527,6 +527,7 @@ std::vector<paddle::Tensor> PrefillMLAWriteCacheKernel(
     const paddle::Tensor& batch_id_per_token,
     const paddle::Tensor& cu_seqlens_q,
     const paddle::Tensor& block_tables,
+    const paddle::optional<paddle::Tensor>& kv_signal_data,
     const std::string& cache_quant_type_str,
     const int max_seq_len);

---------

@@ -13,8 +13,8 @@
 // limitations under the License.
 /*
- * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri
- * Dao. Licensed under the BSD 3-Clause.
+ * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar,
+ * Pradeep Ramani, Tri Dao. Licensed under the BSD 3-Clause.
  *
  * Modified by the FlashInfer team.
  */
@@ -39,8 +39,8 @@
 #include "epilogue.cuh"
 #include "helper.h"
 #include "kernel_traits.cuh"
-#include "mainloop_mma.cuh"
 #include "mainloop_load.cuh"
+#include "mainloop_mma.cuh"
 #include "utils.cuh"
 #ifdef DEBUG_MLA
@@ -52,76 +52,91 @@ namespace mla_attn {
 using namespace cute;
-template <typename DTypeQ_, typename DTypeKV_, typename DTypeO_, typename IdType_>
+template <typename DTypeQ_,
+          typename DTypeKV_,
+          typename DTypeO_,
+          typename IdType_>
 struct Params {
   using DTypeQ = DTypeQ_;
   using DTypeKV = DTypeKV_;
   using DTypeO = DTypeO_;
   using IdType = IdType_;
   alignas(16) DTypeQ *Q;      // [token_num, head_num, dim_head]
   alignas(16) DTypeKV *KV;    // [max_block_num, block_size, dim_head]
   alignas(16) DTypeO *O;      // [token_num, head_num, dim_head]
   alignas(16) DTypeO *O_tmp;  // [max_num_chunks, bsz, head_num, dim_head]
-  alignas(16) float *m;  // [max_num_chunks, bsz * max_draft_token_num * head_num]
-  alignas(16) float *d;  // [max_num_chunks, bsz * max_draft_token_num * head_num]
+  alignas(
+      16) float *m;  // [max_num_chunks, bsz * max_draft_token_num * head_num]
+  alignas(
+      16) float *d;  // [max_num_chunks, bsz * max_draft_token_num * head_num]
   alignas(16) IdType *block_tables;
   alignas(16) IdType *seq_lens_this_time;
   alignas(16) IdType *seq_lens_decoder;
   alignas(16) IdType *cumsum_q_seqlens;
   alignas(16) IdType *batch_id_per_token;
   alignas(16) IdType *batch_ids;
   alignas(16) IdType *tile_ids_per_batch;
   alignas(16) IdType *num_blocks_x;
   alignas(16) IdType *chunk_size_device;
   uint32_t q_stride_bsz;
   uint32_t q_stride_head_num;
   uint32_t kv_stride_block_num;
   uint32_t kv_stride_block_size;
   uint32_t o_stride_bsz;
   uint32_t o_stride_head_num;
   int bsz;
   int token_num;
   int max_block_num;
   int max_block_num_per_seq;
   int q_num_head;
   int qk_head_dim;
   int vo_head_dim;
   int block_size;
   int max_draft_token_num;
   int chunk_num;
   float sm_scale;
 };
 #define DISPATCH_GROUP_SIZE(group_size, GROUP_SIZE, ...)   \
   if (group_size == 8) {                                   \
     constexpr size_t GROUP_SIZE = 8;                       \
     __VA_ARGS__                                            \
   } else if (group_size == 16) {                           \
     constexpr size_t GROUP_SIZE = 16;                      \
     __VA_ARGS__                                            \
   } else if (group_size == 64) {                           \
     constexpr size_t GROUP_SIZE = 64;                      \
     __VA_ARGS__                                            \
-  } else {                                                 \
+  } else if (group_size == 128) {                          \
+    constexpr size_t GROUP_SIZE = 128;                     \
+    __VA_ARGS__                                            \
+  } else {                                                 \
     PD_THROW("not support the group_size: ", group_size);  \
     return cudaErrorNotSupported;                          \
   }
-template <typename CollectiveMainloop, typename CollectiveEpilogue, typename Ktraits, bool CAUSAL, int SM_COUNT = 132, bool USE_REG_EALLOC=false, bool USE_FIXED_BLOCK=true>
-__global__ void __launch_bounds__(Ktraits::NUM_WARPS * cutlass::NumThreadsPerWarp, 1)
-    MLAWithKVCacheKernel(CUTE_GRID_CONSTANT
-                         typename CollectiveMainloop::Params const mainloop_params,
-                         CUTE_GRID_CONSTANT
-                         typename CollectiveEpilogue::Params const epilogue_params) {
+template <typename CollectiveMainloop,
+          typename CollectiveEpilogue,
+          typename Ktraits,
+          bool CAUSAL,
+          int SM_COUNT = 132,
+          bool USE_REG_EALLOC = false,
+          bool USE_FIXED_BLOCK = true>
+__global__ void __launch_bounds__(
+    Ktraits::NUM_WARPS *cutlass::NumThreadsPerWarp, 1)
+    MLAWithKVCacheKernel(
+        CUTE_GRID_CONSTANT
+        typename CollectiveMainloop::Params const mainloop_params,
+        CUTE_GRID_CONSTANT
+        typename CollectiveEpilogue::Params const epilogue_params) {
   using DTypeQ = typename Ktraits::DTypeQ;
   using DTypeKV = typename Ktraits::DTypeKV;
   using DTypeO = typename Ktraits::DTypeO;
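The DISPATCH_GROUP_SIZE change above extends the runtime-to-compile-time dispatch with a 128 branch, presumably to cover DeepSeek-V3's larger query-head group per KV head. The pattern itself is easy to lift out; a minimal sketch with hypothetical names:

#include <cstddef>
#include <cstdio>

// Kernel template specialized on a compile-time group size.
template <size_t GROUP_SIZE>
void run_attention() {
  std::printf("compiled for GROUP_SIZE=%zu\n", GROUP_SIZE);
}

// Map a runtime value onto the small set of supported constants so the
// hot path can be fully templated; unsupported values fail loudly.
bool dispatch_group_size(size_t group_size) {
  switch (group_size) {
    case 8:   run_attention<8>();   return true;
    case 16:  run_attention<16>();  return true;
    case 64:  run_attention<64>();  return true;
    case 128: run_attention<128>(); return true;  // newly supported
    default:  return false;  // mirrors PD_THROW + cudaErrorNotSupported
  }
}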
@@ -147,7 +162,8 @@ MLAWithKVCacheKernel(CUTE_GRID_CONSTANT
   using PipelineStateQ = typename MainloopPipelineQ::PipelineState;
   extern __shared__ char shared_memory[];
-  auto& shared_storage = *reinterpret_cast<typename Ktraits::SharedStorage*>(shared_memory);
+  auto &shared_storage =
+      *reinterpret_cast<typename Ktraits::SharedStorage *>(shared_memory);
   int const lane_predicate = cute::elect_one_sync();
   int const warp_idx = cutlass::canonical_warp_idx_sync();
@@ -158,12 +174,14 @@ MLAWithKVCacheKernel(CUTE_GRID_CONSTANT
   }
   // Obtain warp index
-  int const warp_group_thread_idx = threadIdx.x % cutlass::NumThreadsPerWarpGroup;
+  int const warp_group_thread_idx =
+      threadIdx.x % cutlass::NumThreadsPerWarpGroup;
   PipelineParams pipeline_params;
   int warp_group_idx = cutlass::canonical_warp_group_idx();
-  pipeline_params.role = warp_group_idx == 0 ? MainloopPipeline::ThreadCategory::Producer
-                                             : MainloopPipeline::ThreadCategory::Consumer;
+  pipeline_params.role = warp_group_idx == 0
+                             ? MainloopPipeline::ThreadCategory::Producer
+                             : MainloopPipeline::ThreadCategory::Consumer;
   if constexpr (use_tma_load_kv) {
     pipeline_params.is_leader = warp_group_thread_idx == 0;
     pipeline_params.num_consumers = NUM_MMA_THREADS;
@@ -173,17 +191,20 @@ MLAWithKVCacheKernel(CUTE_GRID_CONSTANT
   }
   PipelineParamsQ pipeline_params_q;
-  pipeline_params_q.role = warp_group_idx == 0 ? MainloopPipelineQ::ThreadCategory::Producer
-                                               : MainloopPipelineQ::ThreadCategory::Consumer;
+  pipeline_params_q.role = warp_group_idx == 0
+                               ? MainloopPipelineQ::ThreadCategory::Producer
+                               : MainloopPipelineQ::ThreadCategory::Consumer;
   pipeline_params_q.producer_arv_count = NUM_COPY_THREADS;
-  pipeline_params_q.consumer_arv_count = cutlass::NumThreadsPerWarpGroup;  // just one wg qk
+  pipeline_params_q.consumer_arv_count =
+      cutlass::NumThreadsPerWarpGroup;  // just one wg qk
   MainloopPipelineQ pipeline_q(shared_storage.pipeline_q, pipeline_params_q);
   MainloopPipeline pipeline_kv = [&] {
     if constexpr (use_tma_load_kv) {
-      pipeline_params.transaction_bytes = CollectiveMainloop::TmaTransactionBytesKV;
-      return MainloopPipeline(shared_storage.pipeline_kv, pipeline_params,
+      pipeline_params.transaction_bytes =
+          CollectiveMainloop::TmaTransactionBytesKV;
+      return MainloopPipeline(shared_storage.pipeline_kv,
+                              pipeline_params,
                               /*cluster_shape=*/Shape<_1, _1, _1>{});
     } else {
       return MainloopPipeline(shared_storage.pipeline_kv, pipeline_params);
@@ -196,191 +217,217 @@ MLAWithKVCacheKernel(CUTE_GRID_CONSTANT
   if (warp_group_idx == 0) {
     // producer
-    if constexpr(USE_REG_EALLOC) {
+    if constexpr (USE_REG_EALLOC) {
       cutlass::arch::warpgroup_reg_dealloc<72>();
     }
-    const uint32_t warp_idx_in_warpgroup = __shfl_sync(0xffffffff, warp_idx % 4, 0);
-    PipelineStateQ smem_pipe_write_q = cutlass::make_producer_start_state<MainloopPipelineQ>();
-    PipelineState smem_pipe_write_kv = cutlass::make_producer_start_state<MainloopPipeline>();
+    const uint32_t warp_idx_in_warpgroup =
+        __shfl_sync(0xffffffff, warp_idx % 4, 0);
+    PipelineStateQ smem_pipe_write_q =
+        cutlass::make_producer_start_state<MainloopPipelineQ>();
+    PipelineState smem_pipe_write_kv =
+        cutlass::make_producer_start_state<MainloopPipeline>();
     for (int i = blockIdx.x; i < num_blocks_x; i += SM_COUNT) {
       const int bid = mainloop_params.batch_ids[i];
       const int tile_id = mainloop_params.tile_ids_per_batch[i];
      const int seq_len_now = mainloop_params.seq_lens_this_time[bid];
-      const int seq_len_decoder_now = mainloop_params.seq_lens_decoder[bid] + seq_len_now;
+      const int seq_len_decoder_now =
+          mainloop_params.seq_lens_decoder[bid] + seq_len_now;
       const int start_token_idx = mainloop_params.cumsum_q_seqlens[bid];
-      cutlass::arch::NamedBarrier::sync(Ktraits::NUM_THREADS,
-                                        /*id=*/static_cast<int>(NamedBarriers::kWG0WG1WG2Sync));
+      cutlass::arch::NamedBarrier::sync(
+          Ktraits::NUM_THREADS,
+          /*id=*/static_cast<int>(NamedBarriers::kWG0WG1WG2Sync));
       // load Q
-      collective_mainloop.load_q(
-          mainloop_params,
-          pipeline_q,
-          smem_pipe_write_q,
-          shared_storage,
-          threadIdx.x,
-          bid);
+      collective_mainloop.load_q(mainloop_params,
+                                 pipeline_q,
+                                 smem_pipe_write_q,
+                                 shared_storage,
+                                 threadIdx.x,
+                                 bid);
       if constexpr (!use_tma_load_kv) {
         // load kv
-        collective_mainloop.load_kv(
-            mainloop_params,
-            pipeline_kv,
-            smem_pipe_write_kv,
-            shared_storage,
-            bid,
-            seq_len_decoder_now,
-            tile_id
-        );
+        collective_mainloop.load_kv(mainloop_params,
+                                    pipeline_kv,
+                                    smem_pipe_write_kv,
+                                    shared_storage,
+                                    bid,
+                                    seq_len_decoder_now,
+                                    tile_id);
       } else {
         if (warp_idx_in_warpgroup == 0) {
           // load kv tma
-          collective_mainloop.load_kv_tma(
-              mainloop_params,
-              pipeline_kv,
-              smem_pipe_write_kv,
-              shared_storage,
-              bid,
-              seq_len_decoder_now,
-              tile_id
-          );
+          collective_mainloop.load_kv_tma(mainloop_params,
+                                          pipeline_kv,
+                                          smem_pipe_write_kv,
+                                          shared_storage,
+                                          bid,
+                                          seq_len_decoder_now,
+                                          tile_id);
         }
       }
     }
   } else {
     // consumer
-    if constexpr(USE_REG_EALLOC) {
+    if constexpr (USE_REG_EALLOC) {
      cutlass::arch::warpgroup_reg_alloc<216>();
     }
     PipelineStateQ smem_pipe_read_q;
     PipelineState smem_pipe_read_kv;
     typename Ktraits::TiledMmaPVSS tiled_mma_pv;
-    Tensor tOrO = partition_fragment_C(tiled_mma_pv, select<0, 1>(TileShape_PDV{}));
-    auto attention_updater = OnlineSoftmax<2 * size<1>(tOrO), /*WITH_SCALE=*/true>(mainloop_params.sm_scale);
+    Tensor tOrO =
+        partition_fragment_C(tiled_mma_pv, select<0, 1>(TileShape_PDV{}));
+    auto attention_updater =
+        OnlineSoftmax<2 * size<1>(tOrO), /*WITH_SCALE=*/true>(
+            mainloop_params.sm_scale);
     for (int i = blockIdx.x; i < num_blocks_x; i += SM_COUNT) {
       clear(tOrO);
       clear(attention_updater.scores_scale);
       const int bid = mainloop_params.batch_ids[i];
       const int tile_id = mainloop_params.tile_ids_per_batch[i];
       const int seq_len_now = mainloop_params.seq_lens_this_time[bid];
-      const int seq_len_decoder_now = mainloop_params.seq_lens_decoder[bid] + seq_len_now;
+      const int seq_len_decoder_now =
+          mainloop_params.seq_lens_decoder[bid] + seq_len_now;
       const int start_token_idx = mainloop_params.cumsum_q_seqlens[bid];
-      cutlass::arch::NamedBarrier::sync(Ktraits::NUM_THREADS,
-                                        /*id=*/static_cast<int>(NamedBarriers::kWG0WG1WG2Sync));
+      cutlass::arch::NamedBarrier::sync(
+          Ktraits::NUM_THREADS,
+          /*id=*/static_cast<int>(NamedBarriers::kWG0WG1WG2Sync));
       if constexpr (BLOCK_SHAPE_KV == 64) {
-        mma_f16<Ktraits, CAUSAL>(
-            mainloop_params,
-            pipeline_q,
-            smem_pipe_read_q,
-            pipeline_kv,
-            smem_pipe_read_kv,
-            tOrO,
-            attention_updater,
-            threadIdx.x - NUM_COPY_THREADS,
-            bid,
-            seq_len_decoder_now,
-            seq_len_now,
-            tile_id,
-            shared_storage);
+        mma_f16<Ktraits, CAUSAL>(mainloop_params,
+                                 pipeline_q,
+                                 smem_pipe_read_q,
+                                 pipeline_kv,
+                                 smem_pipe_read_kv,
+                                 tOrO,
+                                 attention_updater,
+                                 threadIdx.x - NUM_COPY_THREADS,
+                                 bid,
+                                 seq_len_decoder_now,
+                                 seq_len_now,
+                                 tile_id,
+                                 shared_storage);
       } else if (BLOCK_SHAPE_KV == 32) {
-        mma_f16_two_stages<Ktraits, CAUSAL>(
-            mainloop_params,
-            pipeline_q,
-            smem_pipe_read_q,
-            pipeline_kv,
-            smem_pipe_read_kv,
-            tOrO,
-            attention_updater,
-            threadIdx.x - NUM_COPY_THREADS,
-            bid,
-            seq_len_decoder_now,
-            seq_len_now,
-            tile_id,
-            shared_storage);
+        mma_f16_two_stages<Ktraits, CAUSAL>(mainloop_params,
+                                            pipeline_q,
+                                            smem_pipe_read_q,
+                                            pipeline_kv,
+                                            smem_pipe_read_kv,
+                                            tOrO,
+                                            attention_updater,
+                                            threadIdx.x - NUM_COPY_THREADS,
+                                            bid,
+                                            seq_len_decoder_now,
+                                            seq_len_now,
+                                            tile_id,
+                                            shared_storage);
       }
-      collective_epilogue.store(
-          epilogue_params,
-          tOrO,
-          attention_updater.get_lse(),
-          shared_storage,
-          tiled_mma_pv,
-          threadIdx.x - NUM_COPY_THREADS,
-          bid,
-          mainloop_params.bsz,
-          seq_len_now,
-          start_token_idx,
-          tile_id,
-          seq_len_decoder_now,
-          chunk_size,
-          mainloop_params.max_draft_token_num,
-          mainloop_params.o_stride_bsz);
+      collective_epilogue.store(epilogue_params,
+                                tOrO,
+                                attention_updater.get_lse(),
+                                shared_storage,
+                                tiled_mma_pv,
+                                threadIdx.x - NUM_COPY_THREADS,
+                                bid,
+                                mainloop_params.bsz,
+                                seq_len_now,
+                                start_token_idx,
+                                tile_id,
+                                seq_len_decoder_now,
+                                chunk_size,
+                                mainloop_params.max_draft_token_num,
+                                mainloop_params.o_stride_bsz);
     }
   }
 }
-template <typename KernelTraits, bool CAUSAL, typename Params, bool USE_REG_EALLOC=false, bool USE_FIXED_BLOCK=true>
-cudaError_t BatchMLAWithPagedKVCacheKernelTraitsDispatched(Params& params,
-                                                           cudaStream_t stream) {
+template <typename KernelTraits,
+          bool CAUSAL,
+          typename Params,
+          bool USE_REG_EALLOC = false,
+          bool USE_FIXED_BLOCK = true>
+cudaError_t BatchMLAWithPagedKVCacheKernelTraitsDispatched(
+    Params &params, cudaStream_t stream) {
   using DTypeQ = typename KernelTraits::DTypeQ;
   using DTypeKV = typename KernelTraits::DTypeKV;
   using DTypeO = typename KernelTraits::DTypeO;
   using IdType = typename KernelTraits::IdType;
   using NV_TYPE = typename KernelTraits::NV_TYPE;
-  using CollectiveMainloop =
-      CollectiveMainloop<KernelTraits, CAUSAL>;
+  using CollectiveMainloop = CollectiveMainloop<KernelTraits, CAUSAL>;
   using CollectiveEpilogue = CollectiveEpilogue<KernelTraits>;
-  typename CollectiveMainloop::Params mainloop_params = CollectiveMainloop::to_underlying_arguments({
-      make_layout(make_shape(KernelTraits::BLOCK_SHAPE_Q, params.qk_head_dim), make_stride(params.qk_head_dim, _1{})),  // layout q
-      make_layout(make_shape(params.block_size, params.qk_head_dim, params.max_block_num), make_stride(params.qk_head_dim, _1{}, params.block_size * params.qk_head_dim)),
-      make_layout(make_shape(params.chunk_num, params.bsz * params.max_draft_token_num * params.q_num_head), make_stride(params.bsz * params.max_draft_token_num * params.q_num_head, _1{})),
-      params.Q,
-      params.KV,
-      params.m,
-      params.d,
-      params.block_tables,
-      params.seq_lens_this_time,
-      params.seq_lens_decoder,
-      params.cumsum_q_seqlens,
-      params.batch_ids,
-      params.tile_ids_per_batch,
-      params.num_blocks_x,
-      params.chunk_size_device,
-      params.sm_scale,
-      params.bsz,
-      params.max_block_num,
-      params.max_block_num_per_seq,
-      params.q_stride_bsz,
-      params.q_stride_head_num,
-      params.kv_stride_block_num,
-      params.kv_stride_block_size,
-      params.o_stride_bsz,
-      params.o_stride_head_num,
-      params.chunk_num,
-      params.max_draft_token_num
-  });
+  typename CollectiveMainloop::Params mainloop_params =
+      CollectiveMainloop::to_underlying_arguments(
+          {make_layout(
+               make_shape(KernelTraits::BLOCK_SHAPE_Q, params.qk_head_dim),
+               make_stride(params.qk_head_dim, _1{})),  // layout q
+           make_layout(
+               make_shape(
+                   params.block_size, params.qk_head_dim, params.max_block_num),
+               make_stride(params.qk_head_dim,
+                           _1{},
+                           params.block_size * params.qk_head_dim)),
+           make_layout(make_shape(params.chunk_num,
+                                  params.bsz * params.max_draft_token_num *
+                                      params.q_num_head),
+                       make_stride(params.bsz * params.max_draft_token_num *
+                                       params.q_num_head,
+                                   _1{})),
+           params.Q,
+           params.KV,
+           params.m,
+           params.d,
+           params.block_tables,
+           params.seq_lens_this_time,
+           params.seq_lens_decoder,
+           params.cumsum_q_seqlens,
+           params.batch_ids,
+           params.tile_ids_per_batch,
+           params.num_blocks_x,
+           params.chunk_size_device,
+           params.sm_scale,
+           params.bsz,
+           params.max_block_num,
+           params.max_block_num_per_seq,
+           params.q_stride_bsz,
+           params.q_stride_head_num,
+           params.kv_stride_block_num,
+           params.kv_stride_block_size,
+           params.o_stride_bsz,
+           params.o_stride_head_num,
+           params.chunk_num,
+           params.max_draft_token_num});
-  typename CollectiveEpilogue::Params epilogue_params = CollectiveEpilogue::to_underlying_arguments_ntma({
-      params.O,
-      make_layout(make_shape(KernelTraits::BLOCK_SHAPE_Q, params.vo_head_dim), make_stride(params.vo_head_dim, _1{})),  // layout O
-      params.O_tmp,
-      make_layout(make_shape(KernelTraits::BLOCK_SHAPE_Q, params.vo_head_dim), make_stride(params.vo_head_dim, _1{}))  // layout O_tmp
-  });
+  typename CollectiveEpilogue::Params epilogue_params =
+      CollectiveEpilogue::to_underlying_arguments_ntma({
+          params.O,
+          make_layout(
+              make_shape(KernelTraits::BLOCK_SHAPE_Q, params.vo_head_dim),
+              make_stride(params.vo_head_dim, _1{})),  // layout O
+          params.O_tmp,
+          make_layout(
+              make_shape(KernelTraits::BLOCK_SHAPE_Q, params.vo_head_dim),
+              make_stride(params.vo_head_dim, _1{}))  // layout O_tmp
+      });
   // Get the ptr to kernel function.
-  auto kernel =
-      MLAWithKVCacheKernel<CollectiveMainloop, CollectiveEpilogue, KernelTraits, CAUSAL, 132>;
+  auto kernel = MLAWithKVCacheKernel<CollectiveMainloop,
+                                     CollectiveEpilogue,
+                                     KernelTraits,
+                                     CAUSAL,
+                                     132>;
   int smem_size = sizeof(typename KernelTraits::SharedStorage);
-  cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
+  cudaFuncSetAttribute(
+      kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
   int device;
   cudaGetDevice(&device);
   int multiprocessor_count;
-  cudaDeviceGetAttribute(&multiprocessor_count, cudaDevAttrMultiProcessorCount, device);
+  cudaDeviceGetAttribute(
+      &multiprocessor_count, cudaDevAttrMultiProcessorCount, device);
   int act_blocks_per_sm;
   cudaOccupancyMaxActiveBlocksPerMultiprocessor(
       &act_blocks_per_sm, kernel, KernelTraits::NUM_WARPS * 32, smem_size);
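The dispatcher sizes the grid from the device's SM count and the kernel walks work items with a grid-stride loop (for (int i = blockIdx.x; i < num_blocks_x; i += SM_COUNT) above), i.e. a persistent-kernel launch where every block stays resident. A minimal sketch of that launch recipe (kernel body hypothetical):

#include <cuda_runtime.h>

__global__ void persistent_kernel(const int* work, int num_items,
                                  int sm_count) {
  // Each resident block walks the work list with a grid-wide stride.
  for (int i = blockIdx.x; i < num_items; i += sm_count) {
    // ... process work[i] ...
  }
}

void launch(const int* work, int num_items, cudaStream_t stream) {
  int device = 0, sm_count = 0;
  cudaGetDevice(&device);
  cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, device);
  // One block per SM keeps every block resident for the whole launch.
  persistent_kernel<<<sm_count, 128, 0, stream>>>(work, num_items, sm_count);
}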
@@ -390,15 +437,15 @@ cudaError_t BatchMLAWithPagedKVCacheKernelTraitsDispatched(Params& params,
   dim3 grid_dims = {multiprocessor_count, 1, 1};
   static constexpr int ctaSize = KernelTraits::NUM_WARPS * 32;
   dim3 block_dims(ctaSize, 1, 1);
-  kernel<<<grid_dims, block_dims, smem_size, stream>>>(
-      mainloop_params, epilogue_params
-  );
+  kernel<<<grid_dims, block_dims, smem_size, stream>>>(mainloop_params,
+                                                       epilogue_params);
   if (params.chunk_num > 1) {
     constexpr int vec_size = 16 / sizeof(DTypeO);
     constexpr int merge_block_size = 256;
     constexpr int blockx = KernelTraits::HEAD_DIM_VO / vec_size;
     constexpr int blocky = (merge_block_size + blockx - 1) / blockx;
-    dim3 grids_merge(multiprocessor_count, params.q_num_head);  // 128k is too large
+    dim3 grids_merge(multiprocessor_count,
+                     params.q_num_head);  // 128k is too large
     dim3 blocks_merge(blockx, blocky);
     merge_multi_chunks_kernel<NV_TYPE,
                               vec_size,
@@ -423,28 +470,35 @@ cudaError_t BatchMLAWithPagedKVCacheKernelTraitsDispatched(Params& params,
   return cudaSuccess;
 }
-template <uint32_t HEAD_DIM_QK, uint32_t HEAD_DIM_VO, typename NV_TYPE, typename Params, bool USE_REG_EALLOC=false, bool USE_FIXED_BLOCK=true>
-cudaError_t BatchMLAWithPagedKVCacheDispatched(Params& params, cudaStream_t stream) {
+template <uint32_t HEAD_DIM_QK,
+          uint32_t HEAD_DIM_VO,
+          typename NV_TYPE,
+          typename Params,
+          bool USE_REG_EALLOC = false,
+          bool USE_FIXED_BLOCK = true>
+cudaError_t BatchMLAWithPagedKVCacheDispatched(Params &params,
+                                               cudaStream_t stream) {
   constexpr bool CAUSAL = true;
   if constexpr (HEAD_DIM_QK == 576) {
-    DISPATCH_GROUP_SIZE(params.q_num_head, GROUP_SIZE,
-                        BatchMLAWithPagedKVCacheKernelTraitsDispatched<
-                            AttentionKernelTraits</*USE_TMA_LOAD_KV=*/true,
-                                                  HEAD_DIM_QK,
-                                                  HEAD_DIM_VO,
-                                                  GROUP_SIZE,
-                                                  /*BLOCK_SHAPE_Q_=*/64,
-                                                  /*BLOCK_SHAPE_KV_=*/64,
-                                                  /*NUM_STAGES_=*/2,
-                                                  typename Params::DTypeQ,
-                                                  typename Params::DTypeKV,
-                                                  typename Params::DTypeO,
-                                                  typename Params::IdType,
-                                                  NV_TYPE>,
-                            CAUSAL,
-                            Params,
-                            USE_REG_EALLOC,
-                            USE_FIXED_BLOCK>(params, stream);)
+    DISPATCH_GROUP_SIZE(params.q_num_head,
+                        GROUP_SIZE,
+                        BatchMLAWithPagedKVCacheKernelTraitsDispatched<
+                            AttentionKernelTraits</*USE_TMA_LOAD_KV=*/true,
+                                                  HEAD_DIM_QK,
+                                                  HEAD_DIM_VO,
+                                                  GROUP_SIZE,
+                                                  /*BLOCK_SHAPE_Q_=*/64,
+                                                  /*BLOCK_SHAPE_KV_=*/64,
+                                                  /*NUM_STAGES_=*/2,
+                                                  typename Params::DTypeQ,
+                                                  typename Params::DTypeKV,
+                                                  typename Params::DTypeO,
+                                                  typename Params::IdType,
+                                                  NV_TYPE>,
+                            CAUSAL,
+                            Params,
+                            USE_REG_EALLOC,
+                            USE_FIXED_BLOCK>(params, stream);)
   } else {
     return cudaErrorNotSupported;
  }

---------

@@ -1,6 +1,7 @@
pkill -9 -f python pkill -9 -f python
pkill -9 -f fastdeploy pkill -9 -f fastdeploy
pkill -9 -f gunicorn pkill -9 -f gunicorn
pkill -9 -f redis-server # Kill redis-server if you need.
#pkill -9 -f redis-server
sleep 1 sleep 1

---------

@@ -159,16 +159,22 @@ class CacheMessager:
         cache_v = []
         self.messager = {}
         for layer_idx in range(self.num_layers):
+            # value cache
+            val_cache_key = f"value_caches_{layer_idx}_rank{self.rank}_device{gpu_id}"
+            if val_cache_key in self.gpu_cache_kvs:
+                val_cache = self.gpu_cache_kvs[val_cache_key]
+                cache_v.append(val_cache)
+                if paddle.is_compiled_with_xpu():
+                    cache_v_ptr_list.append(get_peer_mem_addr(val_cache.data_ptr()))
+                else:
+                    cache_v_ptr_list.append(val_cache.data_ptr())
+            # key cache
             key_cache = self.gpu_cache_kvs[f"key_caches_{layer_idx}_rank{self.rank}_device{gpu_id}"]
-            val_cache = self.gpu_cache_kvs[f"value_caches_{layer_idx}_rank{self.rank}_device{gpu_id}"]
             cache_k.append(key_cache)
-            cache_v.append(val_cache)
             if paddle.is_compiled_with_xpu():
                 cache_k_ptr_list.append(get_peer_mem_addr(key_cache.data_ptr()))
-                cache_v_ptr_list.append(get_peer_mem_addr(val_cache.data_ptr()))
             else:
                 cache_k_ptr_list.append(key_cache.data_ptr())
-                cache_v_ptr_list.append(val_cache.data_ptr())
         cache_k_ptr_list = np.array(cache_k_ptr_list)
         cache_v_ptr_list = np.array(cache_v_ptr_list)
@@ -198,7 +204,6 @@ class CacheMessager:
         elif protocol == "rdma":
             logger.info(f"splitwise_role rdma: {self.splitwise_role}, rank: {self.rank}, gpu_id: {gpu_id}")
-
             self.messager[protocol] = RDMACommManager(
                 splitwise_role,
                 rank,
@@ -460,16 +465,22 @@ class CacheMessagerV1:
         cache_v = []
         self.messager = {}
         for layer_idx in range(self.num_layers):
+            # value cache
+            val_cache_key = f"value_caches_{layer_idx}_rank{self.rank}_device{gpu_id}"
+            if val_cache_key in self.gpu_cache_kvs:
+                val_cache = self.gpu_cache_kvs[val_cache_key]
+                cache_v.append(val_cache)
+                if paddle.is_compiled_with_xpu():
+                    cache_v_ptr_list.append(get_peer_mem_addr(val_cache.data_ptr()))
+                else:
+                    cache_v_ptr_list.append(val_cache.data_ptr())
+            # key cache
             key_cache = self.gpu_cache_kvs[f"key_caches_{layer_idx}_rank{self.rank}_device{gpu_id}"]
-            val_cache = self.gpu_cache_kvs[f"value_caches_{layer_idx}_rank{self.rank}_device{gpu_id}"]
             cache_k.append(key_cache)
-            cache_v.append(val_cache)
             if paddle.is_compiled_with_xpu():
                 cache_k_ptr_list.append(get_peer_mem_addr(key_cache.data_ptr()))
-                cache_v_ptr_list.append(get_peer_mem_addr(val_cache.data_ptr()))
             else:
                 cache_k_ptr_list.append(key_cache.data_ptr())
-                cache_v_ptr_list.append(val_cache.data_ptr())
         cache_k_ptr_list = np.array(cache_k_ptr_list)
         cache_v_ptr_list = np.array(cache_v_ptr_list)
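The two hunks above make the value cache optional: pointers are collected only when a value_caches_* entry actually exists in gpu_cache_kvs, so an MLA model, whose per-layer cache is a single latent tensor, simply yields empty value lists. A C++ analogue of that lookup-then-collect logic (the map layout mirrors gpu_cache_kvs; all names here are hypothetical):

#include <cstdint>
#include <string>
#include <unordered_map>
#include <vector>

// Collect per-layer device pointers; value entries may be absent for MLA.
void collect_cache_ptrs(
    const std::unordered_map<std::string, std::uintptr_t>& gpu_cache_kvs,
    int num_layers, int rank, int gpu_id,
    std::vector<std::uintptr_t>& key_ptrs,
    std::vector<std::uintptr_t>& val_ptrs) {
  for (int layer = 0; layer < num_layers; ++layer) {
    const std::string suffix = std::to_string(layer) + "_rank" +
                               std::to_string(rank) + "_device" +
                               std::to_string(gpu_id);
    // Value cache only if it exists (MLA stores key-only).
    auto it = gpu_cache_kvs.find("value_caches_" + suffix);
    if (it != gpu_cache_kvs.end()) val_ptrs.push_back(it->second);
    // Key cache is always present.
    key_ptrs.push_back(gpu_cache_kvs.at("key_caches_" + suffix));
  }
}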

---------

@@ -245,6 +245,15 @@ class PrefixCacheManager:
         log_dir = envs.FD_LOG_DIR
         cache_manager_processes = []
         visible_devices = get_all_visible_devices()
+        val_cache_arg_str = ""
+        if val_cache_shape:
+            if isinstance(val_cache_shape, list):
+                val_shape_str = ",".join(map(str, val_cache_shape))
+            else:
+                val_shape_str = str(val_cache_shape)
+            val_cache_arg_str = f" --value_cache_shape {val_shape_str}"
         for i in range(tensor_parallel_size):
             launch_cmd = (
                 "FLAGS_allocator_strategy=auto_growth "
@@ -259,7 +268,7 @@ class PrefixCacheManager:
                 + f" --mp_num {tensor_parallel_size}"
                 + f" --cache_dtype {cache_config.cache_dtype}"
                 + f" --key_cache_shape {key_cache_shape}"
-                + f" --value_cache_shape {val_cache_shape}"
+                + val_cache_arg_str
                 + f" --cache_queue_port {cache_config.cache_queue_port}"
                 + f" --enable_splitwise {int(self.enable_splitwise)}"
                 + f" --pod_ip {pod_ip}"
@@ -332,6 +341,15 @@ class PrefixCacheManager:
         log_dir = envs.FD_LOG_DIR
         cache_messager_processes = []
         visible_devices = get_all_visible_devices()
+        val_cache_arg_str = ""
+        if value_cache_shape:
+            if isinstance(value_cache_shape, list):
+                val_shape_str = ",".join(map(str, value_cache_shape))
+            else:
+                val_shape_str = str(value_cache_shape)
+            val_cache_arg_str = f" --value_cache_shape {val_shape_str}"
         for i in range(tensor_parallel_size):
             launch_cmd = (
                 "FLAGS_allocator_strategy=auto_growth "
@@ -345,7 +363,7 @@ class PrefixCacheManager:
                 + f" --mp_num {tensor_parallel_size}"
                 + f" --cache_dtype {cache_config.cache_dtype}"
                 + f" --key_cache_shape {key_cache_shape}"
-                + f" --value_cache_shape {value_cache_shape}"
+                + val_cache_arg_str
                 + f" --pod_ip {pod_ip}"
                 + f" --cache_queue_port {cache_config.cache_queue_port}"
                 + f" --engine_worker_queue_port {engine_worker_queue_port}"

---------

@@ -198,8 +198,8 @@ int get_port_info(struct ibv_context* Context,
 int parse_port_ib_info();
 // Memory region exchange
-bool client_exchange_mr(struct RdmaContext* ctx);
-bool server_exchange_mr(struct RdmaContext* ctx);
+bool client_exchange_mr(struct RdmaContext* ctx, bool has_value_cache);
+bool server_exchange_mr(struct RdmaContext* ctx, bool has_value_cache);
 bool server_send_memory_region(struct RdmaContext* ctx,
                                void* local_mr,
                                int byte_num);

---------

@@ -149,6 +149,7 @@ class RDMACommunicator {
   struct ibv_pd* g_pd = NULL;          // fd
   int RDMACommunicator_status;         // Communicator status flag
   bool start_client_listener = false;  // Client listener flag
+  bool has_value_cache_;               // MLA does not have value cache.
 };
 #endif  // KVCACHE_RDMA_H

---------

@@ -712,8 +712,8 @@ bool exchange_mr_vector(struct RdmaContext *ctx,
  * @param ctx The RDMA context
  * @return true on success, false on failure
  */
-bool client_exchange_mr(struct RdmaContext *ctx) {
-  LOGD("verb client exchange mr: start");
+bool client_exchange_mr(struct RdmaContext *ctx, bool has_value_cache) {
+  LOGD("verb client exchange mr: start. has_value_cache=%d", has_value_cache);
   if (ctx->conn.layer_number <= 0) {
     ERR("Invalid layer number: %d", ctx->conn.layer_number);
@@ -723,19 +723,27 @@ bool client_exchange_mr(struct RdmaContext *ctx) {
   auto layer_num = ctx->conn.layer_number;
   std::vector<void *> key_ptrs(layer_num);
   std::vector<uint32_t> key_rkeys(layer_num);
-  std::vector<void *> val_ptrs(layer_num);
-  std::vector<uint32_t> val_rkeys(layer_num);
+  std::vector<void *> val_ptrs;
+  std::vector<uint32_t> val_rkeys;
+  if (has_value_cache) {
+    val_ptrs.resize(layer_num);
+    val_rkeys.resize(layer_num);
+  }
   if (!exchange_mr_vector(ctx, key_ptrs, true)) return false;
   if (!exchange_mr_vector(ctx, key_rkeys, true)) return false;
-  if (!exchange_mr_vector(ctx, val_ptrs, true)) return false;
-  if (!exchange_mr_vector(ctx, val_rkeys, true)) return false;
+  if (has_value_cache) {
+    if (!exchange_mr_vector(ctx, val_ptrs, true)) return false;
+    if (!exchange_mr_vector(ctx, val_rkeys, true)) return false;
+  }
   for (int i = 0; i < layer_num; ++i) {
     ctx->conn.write_cache_key_remote_ptr_list.push_back(key_ptrs[i]);
     ctx->conn.write_cache_key_remote_rkey_list.push_back(key_rkeys[i]);
-    ctx->conn.write_cache_value_remote_ptr_list.push_back(val_ptrs[i]);
-    ctx->conn.write_cache_value_remote_rkey_list.push_back(val_rkeys[i]);
+    if (has_value_cache) {
+      ctx->conn.write_cache_value_remote_ptr_list.push_back(val_ptrs[i]);
+      ctx->conn.write_cache_value_remote_rkey_list.push_back(val_rkeys[i]);
+    }
   }
   return true;
 }
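Client and server must agree on how many vectors cross the wire, so both sides now branch on the same has_value_cache flag: key pointers and rkeys are always exchanged, value metadata only when a value cache exists. A reduced sketch of that handshake shape (exchange_vec is a hypothetical stand-in for exchange_mr_vector):

#include <cstdint>
#include <vector>

// Stand-in for exchange_mr_vector: sends or receives one vector and
// returns false on transport failure.
template <typename T>
bool exchange_vec(std::vector<T>& v, bool is_client);

bool exchange_mr_metadata(int layer_num, bool has_value_cache,
                          bool is_client) {
  std::vector<std::uint64_t> key_ptrs(layer_num), val_ptrs;
  std::vector<std::uint32_t> key_rkeys(layer_num), val_rkeys;
  // Key metadata always flows; both peers perform these two exchanges.
  if (!exchange_vec(key_ptrs, is_client)) return false;
  if (!exchange_vec(key_rkeys, is_client)) return false;
  // Value metadata flows only in key+value mode; in key-only (MLA) mode
  // neither side queues these messages, keeping the wire protocol in sync.
  if (has_value_cache) {
    val_ptrs.resize(layer_num);
    val_rkeys.resize(layer_num);
    if (!exchange_vec(val_ptrs, is_client)) return false;
    if (!exchange_vec(val_rkeys, is_client)) return false;
  }
  return true;
}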
@@ -746,8 +754,8 @@ bool client_exchange_mr(struct RdmaContext *ctx) {
  * @param ctx The RDMA context
  * @return true on success, false on failure
  */
-bool server_exchange_mr(struct RdmaContext *ctx) {
-  LOGD("verbs server exchange mr: start");
+bool server_exchange_mr(struct RdmaContext *ctx, bool has_value_cache) {
+  LOGD("verbs server exchange mr: start. has_value_cache=%d", has_value_cache);
   if (ctx->conn.layer_number <= 0) {
     ERR("Invalid layer number: %d", ctx->conn.layer_number);
@@ -759,8 +767,16 @@ bool server_exchange_mr(struct RdmaContext *ctx) {
   auto &val_mrs = ctx->conn.write_cache_value_server_mr_list;
   // Verify that server memory regions are properly initialized
-  if (key_mrs.size() != layer_num || val_mrs.size() != layer_num) {
-    ERR("server write cache memory region size error");
+  if (key_mrs.size() != layer_num) {
+    ERR("server write cache KEY memory region size error: %zu vs %d",
+        key_mrs.size(),
+        layer_num);
+    return false;
+  }
+  if (has_value_cache && val_mrs.size() != layer_num) {
+    ERR("server write cache VALUE memory region size error: %zu vs %d",
+        val_mrs.size(),
+        layer_num);
     return false;
   }
@@ -772,22 +788,27 @@ bool server_exchange_mr(struct RdmaContext *ctx) {
   send_key_ptrs.reserve(layer_num);
   send_key_rkeys.reserve(layer_num);
-  send_val_ptrs.reserve(layer_num);
-  send_val_rkeys.reserve(layer_num);
+  if (has_value_cache) {
+    send_val_ptrs.reserve(layer_num);
+    send_val_rkeys.reserve(layer_num);
+  }
   // Collect memory region information from local MRs
   for (int i = 0; i < layer_num; ++i) {
     send_key_ptrs.push_back(reinterpret_cast<uint64_t>(key_mrs[i]->addr));
     send_key_rkeys.push_back(key_mrs[i]->rkey);
-    send_val_ptrs.push_back(reinterpret_cast<uint64_t>(val_mrs[i]->addr));
-    send_val_rkeys.push_back(val_mrs[i]->rkey);
+    if (has_value_cache) {
+      send_val_ptrs.push_back(reinterpret_cast<uint64_t>(val_mrs[i]->addr));
+      send_val_rkeys.push_back(val_mrs[i]->rkey);
+    }
   }
+  // Send all vectors to client
   if (!exchange_mr_vector(ctx, send_key_ptrs, false)) return false;
   if (!exchange_mr_vector(ctx, send_key_rkeys, false)) return false;
-  if (!exchange_mr_vector(ctx, send_val_ptrs, false)) return false;
-  if (!exchange_mr_vector(ctx, send_val_rkeys, false)) return false;
+  if (has_value_cache) {
+    if (!exchange_mr_vector(ctx, send_val_ptrs, false)) return false;
+    if (!exchange_mr_vector(ctx, send_val_rkeys, false)) return false;
+  }
   return true;
 }

---------

@@ -78,6 +78,18 @@ RDMACommunicator::RDMACommunicator(std::string& role,
     throw std::runtime_error("Invalid layer number");
   }
+  if (local_cache_value_ptr_layer_head_.empty()) {
+    has_value_cache_ = false;
+    WARN(
+        "Value Cache is empty (Maybe MLA Model). RDMA will run in Key-Only "
+        "mode.");
+  } else {
+    has_value_cache_ = true;
+    if (local_cache_value_ptr_layer_head_.size() != layer_number) {
+      throw std::runtime_error("Key and Value cache layer number mismatch!");
+    }
+  }
   // Step 2: Setup cache vectors and pointers
   resize_vectors();
   assign_pointers();
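Key-only mode is decided once, in the constructor: an empty list of value-cache layer heads (as with MLA, where each layer keeps a single latent KV tensor) flips has_value_cache_ off, and every later code path branches on that member. A condensed sketch of the decision with simplified, hypothetical types:

#include <cstdint>
#include <cstdio>
#include <stdexcept>
#include <utility>
#include <vector>

class RdmaCacheConfig {
 public:
  RdmaCacheConfig(std::vector<std::uint64_t> key_heads,
                  std::vector<std::uint64_t> value_heads)
      : key_heads_(std::move(key_heads)),
        value_heads_(std::move(value_heads)) {
    if (value_heads_.empty()) {
      // MLA: only the latent (key) cache exists; run key-only.
      has_value_cache_ = false;
      std::fprintf(stderr, "value cache empty; RDMA runs in key-only mode\n");
    } else {
      has_value_cache_ = true;
      if (value_heads_.size() != key_heads_.size())
        throw std::runtime_error("key/value cache layer count mismatch");
    }
  }
  bool has_value_cache() const { return has_value_cache_; }

 private:
  std::vector<std::uint64_t> key_heads_, value_heads_;
  bool has_value_cache_ = false;
};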
@@ -100,7 +112,6 @@ RDMACommunicator::RDMACommunicator(std::string& role,
       });
       server_thread.detach();
     }
-
     RDMACommunicator_status = 1;
     INFO("RDMA communicator initialized successfully");
   } catch (const std::exception& e) {
@@ -119,7 +130,9 @@ void RDMACommunicator::resize_vectors() {
   }
   local_cache_key_ptr_per_layer.resize(layer_number);
-  local_cache_value_ptr_per_layer.resize(layer_number);
+  if (has_value_cache_) {
+    local_cache_value_ptr_per_layer.resize(layer_number);
+  }
 }
 void RDMACommunicator::assign_pointers() {
@@ -131,15 +144,19 @@ void RDMACommunicator::assign_pointers() {
   // Assign pointers for each layer and block
   for (int layer_idx = 0; layer_idx < layer_number; ++layer_idx) {
     // Validate layer head pointers
-    if (local_cache_key_ptr_layer_head_[layer_idx] == 0 ||
-        local_cache_value_ptr_layer_head_[layer_idx] == 0) {
+    if (local_cache_key_ptr_layer_head_[layer_idx] == 0) {
       throw std::runtime_error("Invalid cache pointer for layer " +
                                std::to_string(layer_idx));
     }
-    // Resize block vectors for current layer
     local_cache_key_ptr_per_layer[layer_idx].resize(block_number);
-    local_cache_value_ptr_per_layer[layer_idx].resize(block_number);
+    if (has_value_cache_) {
+      if (local_cache_value_ptr_layer_head_[layer_idx] == 0) {
+        throw std::runtime_error("Invalid VALUE cache pointer for layer " +
+                                 std::to_string(layer_idx));
+      }
+      local_cache_value_ptr_per_layer[layer_idx].resize(block_number);
+    }
     // Calculate and assign block pointers
     for (int block_idx = 0; block_idx < block_number; ++block_idx) {
@@ -147,9 +164,12 @@ void RDMACommunicator::assign_pointers() {
           reinterpret_cast<void*>(local_cache_key_ptr_layer_head_[layer_idx] +
                                   block_idx * block_size_byte);
-      local_cache_value_ptr_per_layer[layer_idx][block_idx] =
-          reinterpret_cast<void*>(local_cache_value_ptr_layer_head_[layer_idx] +
-                                  block_idx * block_size_byte);
+      if (has_value_cache_) {
+        local_cache_value_ptr_per_layer[layer_idx][block_idx] =
+            reinterpret_cast<void*>(
+                local_cache_value_ptr_layer_head_[layer_idx] +
+                block_idx * block_size_byte);
+      }
     }
   }
 }
@@ -347,7 +367,7 @@ int RDMACommunicator::start_server(int sport, int sgid_idx, int gpu_index) {
         continue;
       }
-      server_exchange_mr(ctx);
+      server_exchange_mr(ctx, has_value_cache_);
     } else {
       auto ctx_iter = connectionContexts.find(event_fd);
       if (ctx_iter == connectionContexts.end()) {
@@ -435,18 +455,33 @@ bool RDMACommunicator::deregister_memory_regions(struct RdmaContext* ctx) {
     return false;
   }
-  for (int layer_idx = 0; layer_idx < layer_number; layer_idx++) {
-    if (!write_mr_key_list.empty() && !write_mr_value_list.empty()) {
-      if (ibv_dereg_mr(write_mr_key_list[layer_idx])) {
-        ERR("Failed to deregister memory region: write_mr_key_list, layer %d",
-            layer_idx);
-      }
-      if (ibv_dereg_mr(write_mr_value_list[layer_idx])) {
-        ERR("Failed to deregister memory region: write_mr_value_list, layer %d",
-            layer_idx);
+  if (!write_mr_key_list.empty()) {
+    for (int layer_idx = 0; layer_idx < layer_number; layer_idx++) {
+      if (write_mr_key_list[layer_idx]) {
+        if (ibv_dereg_mr(write_mr_key_list[layer_idx])) {
+          ERR("Failed to deregister memory region: write_mr_key_list, layer %d",
+              layer_idx);
+        }
+        write_mr_key_list[layer_idx] = nullptr;
       }
     }
+    write_mr_key_list.clear();
   }
+  if (!write_mr_value_list.empty()) {
+    for (int layer_idx = 0; layer_idx < layer_number; layer_idx++) {
+      if (write_mr_value_list[layer_idx]) {
+        if (ibv_dereg_mr(write_mr_value_list[layer_idx])) {
+          ERR("Failed to deregister memory region: write_mr_value_list, layer "
+              "%d",
+              layer_idx);
+        }
+        write_mr_value_list[layer_idx] = nullptr;
+      }
+    }
+    write_mr_value_list.clear();
+  }
   return true;
 }
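The reworked teardown deregisters keys and values in independent loops, skips slots that were never registered, and nulls each entry after ibv_dereg_mr so a second call is a no-op. A standalone sketch of that idempotent pattern over one MR list:

#include <infiniband/verbs.h>
#include <cstdio>
#include <vector>

// Deregister every registered MR in the list exactly once; safe to call
// again after a partial registration or an earlier teardown.
void dereg_mr_list(std::vector<ibv_mr*>& mrs) {
  if (mrs.empty()) return;  // key-only mode leaves the value list empty
  for (size_t i = 0; i < mrs.size(); ++i) {
    if (!mrs[i]) continue;  // slot never registered (partial failure)
    if (ibv_dereg_mr(mrs[i])) {
      std::fprintf(stderr, "failed to deregister MR at layer %zu\n", i);
    }
    mrs[i] = nullptr;  // make repeated teardown a no-op
  }
  mrs.clear();
}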
@@ -548,7 +583,7 @@ int RDMACommunicator::connect(const std::string& dst_ip,
ERR("Couldn't getexchange port infodestinations"); ERR("Couldn't getexchange port infodestinations");
return static_cast<int>(ConnStatus::kError); return static_cast<int>(ConnStatus::kError);
} else { } else {
client_exchange_mr(ctx); client_exchange_mr(ctx, has_value_cache_);
} }
// Allocate RDMA read and register read buffers // Allocate RDMA read and register read buffers
@@ -735,15 +770,17 @@ bool RDMACommunicator::client_mr_register_per_layer(RdmaContext* ctx) {
} }
std::lock_guard<std::mutex> lock(mutex_); std::lock_guard<std::mutex> lock(mutex_);
if (!write_mr_key_list.empty()) {
if (!write_mr_key_list.empty() || !write_mr_value_list.empty()) {
WARN("Memory regions already registered"); WARN("Memory regions already registered");
return true; return true;
} }
const size_t list_size = layer_number; const size_t list_size = layer_number;
write_mr_key_list.resize(list_size, nullptr); write_mr_key_list.resize(list_size, nullptr);
write_mr_value_list.resize(list_size, nullptr);
if (has_value_cache_) {
write_mr_value_list.resize(list_size, nullptr);
}
const uint32_t access_flags = const uint32_t access_flags =
IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_LOCAL_WRITE |
@@ -753,8 +790,6 @@ bool RDMACommunicator::client_mr_register_per_layer(RdmaContext* ctx) {
for (int i = 0; i < static_cast<int>(list_size); ++i) { for (int i = 0; i < static_cast<int>(list_size); ++i) {
void* key_ptr = reinterpret_cast<void*>(local_cache_key_ptr_layer_head_[i]); void* key_ptr = reinterpret_cast<void*>(local_cache_key_ptr_layer_head_[i]);
void* val_ptr =
reinterpret_cast<void*>(local_cache_value_ptr_layer_head_[i]);
size_t size = static_cast<size_t>(block_size_byte) * block_number; size_t size = static_cast<size_t>(block_size_byte) * block_number;
write_mr_key_list[i] = write_mr_key_list[i] =
@@ -765,13 +800,18 @@ bool RDMACommunicator::client_mr_register_per_layer(RdmaContext* ctx) {
                                access_flags);
     if (!write_mr_key_list[i]) goto fail;
-    write_mr_value_list[i] =
-        register_memory_region(ctx->pd,
-                               val_ptr,
-                               size,
-                               "client_value_" + std::to_string(i),
-                               access_flags);
-    if (!write_mr_value_list[i]) goto fail;
+    if (has_value_cache_) {
+      void* val_ptr =
+          reinterpret_cast<void*>(local_cache_value_ptr_layer_head_[i]);
+      write_mr_value_list[i] =
+          register_memory_region(ctx->pd,
+                                 val_ptr,
+                                 size,
+                                 "client_value_" + std::to_string(i),
+                                 access_flags);
+      if (!write_mr_value_list[i]) goto fail;
+    }
   }
   return true;
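Condensed sketch of the per-layer registration flow after the change, assuming a boolean like the communicator's has_value_cache_ member (register_layer and its parameters are hypothetical, and raw ibv_reg_mr stands in for the codebase's register_memory_region wrapper): the key MR is always registered, while the value MR is registered only when a separate value cache exists; MLA keeps a single compressed latent KV cache, so such deployments can skip value MRs entirely.

#include <infiniband/verbs.h>

static bool register_layer(ibv_pd* pd, void* key_ptr, void* val_ptr,
                           size_t bytes, bool has_value_cache,
                           ibv_mr*& key_mr, ibv_mr*& val_mr) {
  const int access = IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE;
  key_mr = ibv_reg_mr(pd, key_ptr, bytes, access);
  if (!key_mr) return false;
  val_mr = nullptr;
  if (has_value_cache) {
    val_mr = ibv_reg_mr(pd, val_ptr, bytes, access);
    if (!val_mr) {  // roll back the key MR on partial failure
      ibv_dereg_mr(key_mr);
      key_mr = nullptr;
      return false;
    }
  }
  return true;
}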
@@ -812,8 +852,6 @@ bool RDMACommunicator::server_mr_register_per_layer(RdmaContext* ctx) {
   for (int i = 0; i < layer_number; ++i) {
     void* key_ptr = reinterpret_cast<void*>(local_cache_key_ptr_layer_head_[i]);
-    void* val_ptr =
-        reinterpret_cast<void*>(local_cache_value_ptr_layer_head_[i]);
     size_t size = static_cast<size_t>(block_size_byte) * block_number;
     struct ibv_mr* key_mr = register_memory_region(
@@ -822,21 +860,25 @@ bool RDMACommunicator::server_mr_register_per_layer(RdmaContext* ctx) {
ERR("Failed to register key MR at layer %d", i); ERR("Failed to register key MR at layer %d", i);
goto fail; goto fail;
} }
struct ibv_mr* value_mr = register_memory_region(
ctx->pd, val_ptr, size, "value_" + std::to_string(i), access_flags);
if (!value_mr) {
ERR("Failed to register value MR at layer %d", i);
ibv_dereg_mr(key_mr);
goto fail;
}
write_cache_key_server_mr_list.push_back(key_mr); write_cache_key_server_mr_list.push_back(key_mr);
write_cache_value_server_mr_list.push_back(value_mr);
if (has_value_cache_) {
void* val_ptr =
reinterpret_cast<void*>(local_cache_value_ptr_layer_head_[i]);
struct ibv_mr* value_mr = register_memory_region(
ctx->pd, val_ptr, size, "value_" + std::to_string(i), access_flags);
if (!value_mr) {
ERR("Failed to register value MR at layer %d", i);
ibv_dereg_mr(key_mr);
goto fail;
}
write_cache_value_server_mr_list.push_back(value_mr);
}
} }
ctx->conn.write_cache_key_server_mr_list = write_cache_key_server_mr_list; ctx->conn.write_cache_key_server_mr_list = write_cache_key_server_mr_list;
ctx->conn.write_cache_value_server_mr_list = write_cache_value_server_mr_list; ctx->conn.write_cache_value_server_mr_list = write_cache_value_server_mr_list;
return true; return true;
fail: fail:
@@ -899,8 +941,12 @@ int RDMACommunicator::write_cache(const std::string& ip,
   uint32_t cache_key_rkey =
       ctx->conn.write_cache_key_remote_rkey_list[layer_idx];
-  uint32_t cache_value_rkey =
-      ctx->conn.write_cache_value_remote_rkey_list[layer_idx];
+  uint32_t cache_value_rkey = 0;
+  if (has_value_cache_) {
+    cache_value_rkey = ctx->conn.write_cache_value_remote_rkey_list[layer_idx];
+  }
   uint32_t crc_cache_key_rkey, crc_cache_value_rkey;
   bool pd_tp_size_is_same = prefill_tp_size == ctx->conn.decode_tp_size;
   uint64_t offset_in_block =
@@ -914,15 +960,19 @@ int RDMACommunicator::write_cache(const std::string& ip,
     cache_key_remote_addr[block_index] = (uint64_t(
         char_ptr + remote_block_ids[block_index] * total_block_size_byte +
         offset_in_block));
-    char_ptr = static_cast<char*>(
-        ctx->conn.write_cache_value_remote_ptr_list[layer_idx]);
-    cache_value_remote_addr[block_index] = (uint64_t(
-        char_ptr + remote_block_ids[block_index] * total_block_size_byte +
-        offset_in_block));
+    if (has_value_cache_) {
+      char_ptr = static_cast<char*>(
+          ctx->conn.write_cache_value_remote_ptr_list[layer_idx]);
+      cache_value_remote_addr[block_index] = (uint64_t(
+          char_ptr + remote_block_ids[block_index] * total_block_size_byte +
+          offset_in_block));
+    }
   }
   ctx->conn.wc_target_count = 0;
-  for (int i = 0; i < 2; ++i) {
+  int loop_count = has_value_cache_ ? 2 : 1;
+  for (int i = 0; i < loop_count; ++i) {
     bool is_key = (i == 0);
     uint32_t rkey = (is_key ? cache_key_rkey : cache_value_rkey);
     std::vector<uint64_t>& remote_addr =
@@ -1038,6 +1088,10 @@ void RDMACommunicator::prepare_write_requests(
     bool is_key,
     std::vector<uint64_t>& remote_addr,
     uint32_t rkey) {
+  if (!is_key) {
+    assert(!write_mr_value_list.empty() &&
+           "Trying to process Value Cache but it is empty!");
+  }
   auto block_num = local_block_ids.size();
   for (size_t i = 0; i < block_num; ++i) {
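Self-contained sketch of the issue loop write_cache now runs (names are illustrative, not the library's API): with no value cache the loop makes a single key pass, so the zero-initialized value rkey and the value remote addresses are never consumed.

#include <cstdio>

static void post_writes(bool is_key) {
  // stand-in for building and posting the RDMA write requests of one pass
  std::printf("posting RDMA writes for the %s cache\n", is_key ? "key" : "value");
}

int main() {
  const bool has_value_cache = false;  // MLA: one compressed latent cache
  const int loop_count = has_value_cache ? 2 : 1;
  for (int i = 0; i < loop_count; ++i) {
    post_writes(/*is_key=*/i == 0);
  }
  return 0;
}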
View File
@@ -40,11 +40,10 @@ class RDMACommManager:
         try:
             import rdma_comm
         except:
-            logger.error(
+            raise RuntimeError(
                 "The installation of the RDMA library failed."
                 "Confirm whether your network card supports RDMA transmission."
             )
-            return
         self.messager = rdma_comm.RDMACommunicator(
             splitwise_role,
             gpu_id,
View File
@@ -755,8 +755,9 @@ class LLMEngine:
                 local_data_parallel_size=self.cfg.parallel_config.data_parallel_size,
             )
         )
+        ctx = multiprocessing.get_context("spawn")
         self.dp_processed.append(
-            multiprocessing.Process(
+            ctx.Process(
                 target=start_data_parallel_service,
                 args=(
                     self.cfg,
View File
@@ -205,7 +205,6 @@ class MLAAttentionBackend(AttentionBackend):
             self.group_size,
             self.block_size,
         )
-
         # MLA
         metadata.max_enc_len_this_time = forward_meta.max_len_tensor_cpu[1]
         metadata.max_dec_len_this_time = forward_meta.max_len_tensor_cpu[2]
@@ -279,6 +278,7 @@ class MLAAttentionBackend(AttentionBackend):
             forward_meta.batch_id_per_token,
             forward_meta.cu_seqlens_q,
             metadata.block_tables,
+            metadata.kv_signal_data_list[layer.layer_id],
             "none",
             getattr(forward_meta, "max_input_length", -1),
         )
@@ -422,10 +422,10 @@ class MLAAttentionBackend(AttentionBackend):
             forward_meta.batch_id_per_token,
             forward_meta.cu_seqlens_q,
             metadata.block_tables,
+            metadata.kv_signal_data_list[layer.layer_id],
             "none",
             self.max_seq_len,
         )
-
         # FA
         fmha_out = self.flash_attn_func(
             q,
View File
@@ -307,6 +307,7 @@ class MetaxMLAAttentionBackend(AttentionBackend):
             forward_meta.batch_id_per_token,
             forward_meta.cu_seqlens_q,
             metadata.block_tables,
+            metadata.kv_signal_data_list[layer.layer_id],
             "none",
             getattr(forward_meta, "max_input_length", -1),
         )
View File
@@ -258,10 +258,10 @@ class FusedMoE(nn.Layer):
         else:
             SHARD_ID_TO_SHARDED_DIM = {"gate": 0, "down": 1, "up": 0}
-        if not param._is_initialized():
-            param.initialize()
         if not (expert_id - self.expert_id_offset >= 0 and expert_id - self.expert_id_offset < self.num_local_experts):
             return
+        if not param._is_initialized():
+            param.initialize()
         weight_need_transpose = getattr(param, "weight_need_transpose", False)
         if shard_id is None:
             # 1.gate up fused in disk
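The reorder above is more than style: the expert-ownership range check must run before the lazy parameter initialization, otherwise every rank materializes weights for experts it will never own. A hypothetical, self-contained sketch of the guard order (Param is a stand-in type, not the framework's parameter class):

struct Param {
  bool initialized = false;
  void initialize() { initialized = true; }  // stand-in for real allocation
};

static bool load_expert_shard(int expert_id, int expert_id_offset,
                              int num_local_experts, Param& param) {
  const int local_id = expert_id - expert_id_offset;
  if (local_id < 0 || local_id >= num_local_experts) {
    return false;  // not owned by this rank: bail out before allocating
  }
  if (!param.initialized) {
    param.initialize();  // allocate only for experts this rank will load
  }
  // ... copy the weight shard into param ...
  return true;
}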
View File
@@ -341,6 +341,7 @@ class DeepseekV3MLAAttention(nn.Layer):
         # NOTE: (changwenbin) qkv_a_proj horizontal fusion
         qkv_a_out = self.qkv_a_proj_with_mqa(hidden_states)
+
         query, compressed_kv, key_pe = qkv_a_out.split(
             [self.q_lora_rank, self.kv_lora_rank, self.qk_rope_head_dim], axis=-1
         )
@@ -399,6 +400,7 @@ class DeepseekV3MLAAttention(nn.Layer):
                 self.num_attention_heads_tp * (self.kv_lora_rank + self.qk_rope_head_dim),
             ]
         )
+
         fmha_out_decode = self.mla_attn(
             q=q_input,
             k=None,
@@ -418,6 +420,7 @@ class DeepseekV3MLAAttention(nn.Layer):
             .transpose([1, 0, 2])
             .reshape([-1, self.num_attention_heads_tp * self.v_head_dim])
         )
+
         if fmha_out is None:
             fmha_out = fmha_out_decode
         else:
@@ -515,6 +518,7 @@ class DeepSeekV3DecoderLayer(nn.Layer):
         hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
         hidden_states = self.mlp(hidden_states)
         return hidden_states, residual
+
@@ -674,7 +678,6 @@ class DeepseekV3ForCausalLM(ModelForCasualLM):
         process_weights_after_loading_fn = process_weights_after_loading(dict(self.named_sublayers()), self.fd_config)
         for loaded_weight_name, loaded_weight in weights_iterator:
             loaded_weight_name = loaded_weight_name.replace("deepseek_v3", "model")
-
             for param_name, weight_name, shard_id in stacked_params_mapping:
                 if weight_name not in loaded_weight_name:
                     continue
@@ -741,6 +744,20 @@ class DeepseekV3ForCausalLM(ModelForCasualLM):
         )
         return position_ids, mask_encoder_batch
 
+    def empty_input_forward(self):
+        """
+        empty_input_forward
+        """
+        fake_hidden_states = paddle.empty(
+            shape=[1, self.fd_config.model_config.hidden_size],
+            dtype=paddle.get_default_dtype(),
+        )
+        for i in range(
+            self.fd_config.model_config.first_k_dense_replace,
+            self.fd_config.model_config.num_hidden_layers,
+        ):
+            self.model.layers[i].mlp.experts(fake_hidden_states, self.model.layers[i].mlp.gate)
+
     def forward(
         self,
         ids_remove_padding: paddle.Tensor,
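Why empty_input_forward exists (a reading of the diff, not documented behavior): under PD disaggregation a rank can receive zero real tokens in a step, yet the expert-parallel MoE layers still run collective communication, so every rank must push at least a placeholder activation through the experts to keep the group in step. A runnable sketch of the layer walk, with assumed config values for illustration:

#include <cstdio>

int main() {
  const int first_k_dense_replace = 3;  // assumed: dense layers sit before this index
  const int num_hidden_layers = 61;     // assumed DeepSeek-V3-sized depth
  // only MoE layers need the warm-up pass; dense layers hold no experts
  for (int i = first_k_dense_replace; i < num_hidden_layers; ++i) {
    std::printf("feed 1 x hidden_size placeholder through experts of layer %d\n", i);
  }
  return 0;
}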
View File
@@ -2328,7 +2328,6 @@ class GPUModelRunner(ModelRunnerBase):
             self.parallel_config.data_parallel_rank * self.parallel_config.tensor_parallel_size,
             group=self.parallel_config.tp_group,
         )
-
         # 5. Post Process
         model_output_data = ModelOutputData(
             next_tokens=self.share_inputs["next_tokens"],