[PD Disaggregation] Support PD deployment of DeepSeekv3. (#5251)

* Support deepseekv3 cache transfer for PD deploy

* clean some log info

---------

Co-authored-by: K11OntheBoat <“ruianmaidanglao@163.com”>
This commit is contained in:
K11OntheBoat
2025-12-02 14:11:50 +08:00
committed by GitHub
parent 117980dd4e
commit 2e1680838f
17 changed files with 620 additions and 400 deletions

View File

@@ -13,22 +13,24 @@
// limitations under the License.
#pragma once
#include "helper.h"
#include "mla_cache_kernel.cuh"
#include "helper.h"
#include "remote_cache_kv_ipc.h"
template <paddle::DataType T>
std::vector<paddle::Tensor> PrefillMLAWriteCache(
const AppendAttnMetaData& meta_data,
const paddle::Tensor& kv_nope,
const paddle::Tensor& kv_pe,
const paddle::Tensor& seq_lens,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& batch_id_per_token,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_tables,
const int max_seq_len,
cudaStream_t& stream,
paddle::Tensor* kv_cache) {
const AppendAttnMetaData& meta_data,
const paddle::Tensor& kv_nope,
const paddle::Tensor& kv_pe,
const paddle::Tensor& seq_lens,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& batch_id_per_token,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_tables,
const paddle::optional<paddle::Tensor>& kv_signal_data,
const int max_seq_len,
cudaStream_t& stream,
paddle::Tensor* kv_cache) {
typedef PDTraits<T> traits_;
typedef typename traits_::DataType DataType_;
typedef typename traits_::data_t data_t;
@@ -50,8 +52,10 @@ std::vector<paddle::Tensor> PrefillMLAWriteCache(
prefill_absorb_cache_kernel<DataType_, PackSize>
<<<grid_size, blocksize, 0, stream>>>(
reinterpret_cast<DataType_*>(const_cast<data_t*>(kv_nope.data<data_t>())),
reinterpret_cast<DataType_*>(const_cast<data_t*>(kv_pe.data<data_t>())),
reinterpret_cast<DataType_*>(
const_cast<data_t*>(kv_nope.data<data_t>())),
reinterpret_cast<DataType_*>(
const_cast<data_t*>(kv_pe.data<data_t>())),
reinterpret_cast<DataType_*>(kv_cache->data<data_t>()),
block_tables.data<int>(),
batch_id_per_token.data<int>(),
@@ -65,6 +69,33 @@ std::vector<paddle::Tensor> PrefillMLAWriteCache(
pe_size,
block_size,
elem_nums);
const char* fmt_write_cache_completed_signal_str =
std::getenv("FLAGS_fmt_write_cache_completed_signal");
const char* FLAGS_use_pd_disaggregation_per_chunk =
std::getenv("FLAGS_use_pd_disaggregation_per_chunk");
if (fmt_write_cache_completed_signal_str &&
(std::strcmp(fmt_write_cache_completed_signal_str, "true") == 0 ||
std::strcmp(fmt_write_cache_completed_signal_str, "1") == 0)) {
if (FLAGS_use_pd_disaggregation_per_chunk &&
(std::strcmp(FLAGS_use_pd_disaggregation_per_chunk, "true") == 0 ||
std::strcmp(FLAGS_use_pd_disaggregation_per_chunk, "1") == 0)) {
cudaLaunchHostFunc(
stream,
&(RemoteCacheKvIpc::
save_cache_kv_complete_signal_layerwise_per_query),
(void*)nullptr);
} else {
if (kv_signal_data) {
cudaLaunchHostFunc(
stream,
&RemoteCacheKvIpc::save_cache_kv_complete_signal_layerwise,
(void*)(const_cast<int64_t*>(
kv_signal_data.get().data<int64_t>())));
}
}
}
return {};
}
@@ -77,6 +108,7 @@ std::vector<paddle::Tensor> PrefillMLAWriteCacheKernel(
const paddle::Tensor& batch_id_per_token,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_tables,
const paddle::optional<paddle::Tensor>& kv_signal_data,
const std::string& cache_quant_type_str,
const int max_seq_len) {
cudaStream_t stream = kv_pe.stream();
@@ -85,7 +117,8 @@ std::vector<paddle::Tensor> PrefillMLAWriteCacheKernel(
const auto& kv_pe_dims = kv_pe.dims();
const auto& kv_cache_dims = kv_cache.dims();
meta_data.kv_num_heads = kv_cache_dims[1];
const auto nope_size = kv_nope_dims[kv_nope_dims.size() - 1] / meta_data.kv_num_heads;
const auto nope_size =
kv_nope_dims[kv_nope_dims.size() - 1] / meta_data.kv_num_heads;
meta_data.token_nums = kv_nope_dims[0];
meta_data.head_dims = kv_cache_dims[3];
meta_data.head_dims_v = nope_size;
@@ -95,30 +128,34 @@ std::vector<paddle::Tensor> PrefillMLAWriteCacheKernel(
meta_data.batch_size = seq_lens_decoder.dims()[0];
switch (kv_pe.dtype()) {
case paddle::DataType::BFLOAT16: {
return PrefillMLAWriteCache<paddle::DataType::BFLOAT16>(meta_data,
kv_nope,
kv_pe,
seq_lens,
seq_lens_decoder,
batch_id_per_token,
cu_seqlens_q,
block_tables,
max_seq_len,
stream,
const_cast<paddle::Tensor*>(&kv_cache));
return PrefillMLAWriteCache<paddle::DataType::BFLOAT16>(
meta_data,
kv_nope,
kv_pe,
seq_lens,
seq_lens_decoder,
batch_id_per_token,
cu_seqlens_q,
block_tables,
kv_signal_data,
max_seq_len,
stream,
const_cast<paddle::Tensor*>(&kv_cache));
}
case paddle::DataType::FLOAT16: {
return PrefillMLAWriteCache<paddle::DataType::FLOAT16>(meta_data,
kv_nope,
kv_pe,
seq_lens,
seq_lens_decoder,
batch_id_per_token,
cu_seqlens_q,
block_tables,
max_seq_len,
stream,
const_cast<paddle::Tensor*>(&kv_cache));
return PrefillMLAWriteCache<paddle::DataType::FLOAT16>(
meta_data,
kv_nope,
kv_pe,
seq_lens,
seq_lens_decoder,
batch_id_per_token,
cu_seqlens_q,
block_tables,
kv_signal_data,
max_seq_len,
stream,
const_cast<paddle::Tensor*>(&kv_cache));
}
}
return {};
@@ -126,18 +163,18 @@ std::vector<paddle::Tensor> PrefillMLAWriteCacheKernel(
template <paddle::DataType T>
std::vector<paddle::Tensor> DecodeMLAWriteCache(
const AppendAttnMetaData& meta_data,
const paddle::Tensor& kv_nope,
const paddle::Tensor& kv_pe,
const paddle::Tensor& seq_lens,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& batch_id_per_token,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_tables,
const int max_seq_len,
const bool speculate_decoder,
cudaStream_t& stream,
paddle::Tensor* kv_cache) {
const AppendAttnMetaData& meta_data,
const paddle::Tensor& kv_nope,
const paddle::Tensor& kv_pe,
const paddle::Tensor& seq_lens,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& batch_id_per_token,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_tables,
const int max_seq_len,
const bool speculate_decoder,
cudaStream_t& stream,
paddle::Tensor* kv_cache) {
typedef PDTraits<T> traits_;
typedef typename traits_::DataType DataType_;
typedef typename traits_::data_t data_t;
@@ -154,15 +191,16 @@ std::vector<paddle::Tensor> DecodeMLAWriteCache(
const int blocksize = 128;
int grid_size = 1;
if (speculate_decoder) {
const uint32_t elem_nums = token_num * kv_num_heads * all_size;
const int pack_num = elem_nums / PackSize;
GetNumBlocks<128>(pack_num, &grid_size);
speculate_decode_absorb_cache_kernel<DataType_, PackSize>
<<<grid_size, blocksize, 0, stream>>>(
reinterpret_cast<DataType_*>(const_cast<data_t*>(kv_nope.data<data_t>())),
reinterpret_cast<DataType_*>(const_cast<data_t*>(kv_pe.data<data_t>())),
reinterpret_cast<DataType_*>(
const_cast<data_t*>(kv_nope.data<data_t>())),
reinterpret_cast<DataType_*>(
const_cast<data_t*>(kv_pe.data<data_t>())),
reinterpret_cast<DataType_*>(kv_cache->data<data_t>()),
block_tables.data<int>(),
batch_id_per_token.data<int>(),
@@ -182,8 +220,10 @@ std::vector<paddle::Tensor> DecodeMLAWriteCache(
GetNumBlocks<128>(pack_num, &grid_size);
decode_absorb_cache_kernel<DataType_, PackSize>
<<<grid_size, blocksize, 0, stream>>>(
reinterpret_cast<DataType_*>(const_cast<data_t*>(kv_nope.data<data_t>())),
reinterpret_cast<DataType_*>(const_cast<data_t*>(kv_pe.data<data_t>())),
reinterpret_cast<DataType_*>(
const_cast<data_t*>(kv_nope.data<data_t>())),
reinterpret_cast<DataType_*>(
const_cast<data_t*>(kv_pe.data<data_t>())),
reinterpret_cast<DataType_*>(kv_cache->data<data_t>()),
block_tables.data<int>(),
cu_seqlens_q.data<int>(),
@@ -218,7 +258,8 @@ std::vector<paddle::Tensor> DecodeMLAWriteCacheKernel(
const auto& kv_pe_dims = kv_pe.dims();
const auto& kv_cache_dims = kv_cache.dims();
meta_data.kv_num_heads = kv_cache_dims[1];
const auto nope_size = kv_nope_dims[kv_nope_dims.size() - 1] / meta_data.kv_num_heads;
const auto nope_size =
kv_nope_dims[kv_nope_dims.size() - 1] / meta_data.kv_num_heads;
meta_data.token_nums = kv_nope_dims[0];
meta_data.head_dims = kv_cache_dims[3];
meta_data.head_dims_v = nope_size;
@@ -228,38 +269,39 @@ std::vector<paddle::Tensor> DecodeMLAWriteCacheKernel(
meta_data.batch_size = seq_lens_encoder.dims()[0];
switch (kv_pe.dtype()) {
case paddle::DataType::BFLOAT16: {
return DecodeMLAWriteCache<paddle::DataType::BFLOAT16>(meta_data,
kv_nope,
kv_pe,
seq_lens,
seq_lens_encoder,
batch_id_per_token,
cu_seqlens_q,
block_tables,
max_seq_len,
speculate_decoder,
stream,
const_cast<paddle::Tensor*>(&kv_cache));
return DecodeMLAWriteCache<paddle::DataType::BFLOAT16>(
meta_data,
kv_nope,
kv_pe,
seq_lens,
seq_lens_encoder,
batch_id_per_token,
cu_seqlens_q,
block_tables,
max_seq_len,
speculate_decoder,
stream,
const_cast<paddle::Tensor*>(&kv_cache));
}
case paddle::DataType::FLOAT16: {
return DecodeMLAWriteCache<paddle::DataType::FLOAT16>(meta_data,
kv_nope,
kv_pe,
seq_lens,
seq_lens_encoder,
batch_id_per_token,
cu_seqlens_q,
block_tables,
max_seq_len,
speculate_decoder,
stream,
const_cast<paddle::Tensor*>(&kv_cache));
return DecodeMLAWriteCache<paddle::DataType::FLOAT16>(
meta_data,
kv_nope,
kv_pe,
seq_lens,
seq_lens_encoder,
batch_id_per_token,
cu_seqlens_q,
block_tables,
max_seq_len,
speculate_decoder,
stream,
const_cast<paddle::Tensor*>(&kv_cache));
}
}
return {};
}
PD_BUILD_STATIC_OP(prefill_mla_write_cache)
.Inputs({"kv_nope",
"kv_pe",
@@ -268,11 +310,11 @@ PD_BUILD_STATIC_OP(prefill_mla_write_cache)
"seq_lens_decoder",
"batch_id_per_token",
"cu_seqlens_q",
"block_tables"})
"block_tables",
paddle::Optional("kv_signal_data")})
.Outputs({"kv_cache_out"})
.SetInplaceMap({{"kv_cache", "kv_cache_out"}})
.Attrs({"cache_quant_type_str: std::string",
"max_seq_len: int"})
.Attrs({"cache_quant_type_str: std::string", "max_seq_len: int"})
.SetKernelFn(PD_KERNEL(PrefillMLAWriteCacheKernel));
PD_BUILD_STATIC_OP(decode_mla_write_cache)

View File

@@ -527,6 +527,7 @@ std::vector<paddle::Tensor> PrefillMLAWriteCacheKernel(
const paddle::Tensor& batch_id_per_token,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_tables,
const paddle::optional<paddle::Tensor>& kv_signal_data,
const std::string& cache_quant_type_str,
const int max_seq_len);

View File

@@ -13,8 +13,8 @@
// limitations under the License.
/*
* Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri
* Dao. Licensed under the BSD 3-Clause.
* Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar,
* Pradeep Ramani, Tri Dao. Licensed under the BSD 3-Clause.
*
* Modified by the FlashInfer team.
*/
@@ -39,8 +39,8 @@
#include "epilogue.cuh"
#include "helper.h"
#include "kernel_traits.cuh"
#include "mainloop_mma.cuh"
#include "mainloop_load.cuh"
#include "mainloop_mma.cuh"
#include "utils.cuh"
#ifdef DEBUG_MLA
@@ -52,76 +52,91 @@ namespace mla_attn {
using namespace cute;
template <typename DTypeQ_, typename DTypeKV_, typename DTypeO_, typename IdType_>
template <typename DTypeQ_,
typename DTypeKV_,
typename DTypeO_,
typename IdType_>
struct Params {
using DTypeQ = DTypeQ_;
using DTypeKV = DTypeKV_;
using DTypeO = DTypeO_;
using IdType = IdType_;
using DTypeQ = DTypeQ_;
using DTypeKV = DTypeKV_;
using DTypeO = DTypeO_;
using IdType = IdType_;
alignas(16) DTypeQ *Q; // [token_num, head_num, dim_head]
alignas(16) DTypeKV *KV; // [max_block_num, block_size, dim_head]
alignas(16) DTypeO *O; // [token_num, head_num, dim_head]
alignas(16) DTypeO *O_tmp; // [max_num_chunks, bsz, head_num, dim_head]
alignas(16) float *m; // [max_num_chunks, bsz * max_draft_token_num * head_num]
alignas(16) float *d; // [max_num_chunks, bsz * max_draft_token_num * head_num]
alignas(16) DTypeQ *Q; // [token_num, head_num, dim_head]
alignas(16) DTypeKV *KV; // [max_block_num, block_size, dim_head]
alignas(16) DTypeO *O; // [token_num, head_num, dim_head]
alignas(16) DTypeO *O_tmp; // [max_num_chunks, bsz, head_num, dim_head]
alignas(
16) float *m; // [max_num_chunks, bsz * max_draft_token_num * head_num]
alignas(
16) float *d; // [max_num_chunks, bsz * max_draft_token_num * head_num]
alignas(16) IdType *block_tables;
alignas(16) IdType *seq_lens_this_time;
alignas(16) IdType *seq_lens_decoder;
alignas(16) IdType *cumsum_q_seqlens;
alignas(16) IdType *batch_id_per_token;
alignas(16) IdType *block_tables;
alignas(16) IdType *seq_lens_this_time;
alignas(16) IdType *seq_lens_decoder;
alignas(16) IdType *cumsum_q_seqlens;
alignas(16) IdType *batch_id_per_token;
alignas(16) IdType *batch_ids;
alignas(16) IdType *tile_ids_per_batch;
alignas(16) IdType *num_blocks_x;
alignas(16) IdType *chunk_size_device;
alignas(16) IdType *batch_ids;
alignas(16) IdType *tile_ids_per_batch;
alignas(16) IdType *num_blocks_x;
alignas(16) IdType *chunk_size_device;
uint32_t q_stride_bsz;
uint32_t q_stride_head_num;
uint32_t q_stride_bsz;
uint32_t q_stride_head_num;
uint32_t kv_stride_block_num;
uint32_t kv_stride_block_size;
uint32_t kv_stride_block_num;
uint32_t kv_stride_block_size;
uint32_t o_stride_bsz;
uint32_t o_stride_head_num;
uint32_t o_stride_bsz;
uint32_t o_stride_head_num;
int bsz;
int token_num;
int max_block_num;
int max_block_num_per_seq;
int q_num_head;
int qk_head_dim;
int vo_head_dim;
int block_size;
int max_draft_token_num;
int chunk_num;
int bsz;
int token_num;
int max_block_num;
int max_block_num_per_seq;
int q_num_head;
int qk_head_dim;
int vo_head_dim;
int block_size;
int max_draft_token_num;
int chunk_num;
float sm_scale;
float sm_scale;
};
#define DISPATCH_GROUP_SIZE(group_size, GROUP_SIZE, ...) \
if (group_size == 8) { \
constexpr size_t GROUP_SIZE = 8; \
__VA_ARGS__ \
} else if (group_size == 16) { \
constexpr size_t GROUP_SIZE = 16; \
__VA_ARGS__ \
} else if (group_size == 64) { \
constexpr size_t GROUP_SIZE = 64; \
__VA_ARGS__ \
} else { \
PD_THROW("not support the group_size: ", group_size); \
return cudaErrorNotSupported; \
#define DISPATCH_GROUP_SIZE(group_size, GROUP_SIZE, ...) \
if (group_size == 8) { \
constexpr size_t GROUP_SIZE = 8; \
__VA_ARGS__ \
} else if (group_size == 16) { \
constexpr size_t GROUP_SIZE = 16; \
__VA_ARGS__ \
} else if (group_size == 64) { \
constexpr size_t GROUP_SIZE = 64; \
__VA_ARGS__ \
} else if (group_size == 128) { \
constexpr size_t GROUP_SIZE = 128; \
__VA_ARGS__ \
} else { \
PD_THROW("not support the group_size: ", group_size); \
return cudaErrorNotSupported; \
}
template <typename CollectiveMainloop, typename CollectiveEpilogue, typename Ktraits, bool CAUSAL, int SM_COUNT = 132, bool USE_REG_EALLOC=false, bool USE_FIXED_BLOCK=true>
__global__ void __launch_bounds__(Ktraits::NUM_WARPS * cutlass::NumThreadsPerWarp, 1)
MLAWithKVCacheKernel(CUTE_GRID_CONSTANT
typename CollectiveMainloop::Params const mainloop_params,
CUTE_GRID_CONSTANT
typename CollectiveEpilogue::Params const epilogue_params) {
template <typename CollectiveMainloop,
typename CollectiveEpilogue,
typename Ktraits,
bool CAUSAL,
int SM_COUNT = 132,
bool USE_REG_EALLOC = false,
bool USE_FIXED_BLOCK = true>
__global__ void __launch_bounds__(
Ktraits::NUM_WARPS *cutlass::NumThreadsPerWarp, 1)
MLAWithKVCacheKernel(
CUTE_GRID_CONSTANT
typename CollectiveMainloop::Params const mainloop_params,
CUTE_GRID_CONSTANT
typename CollectiveEpilogue::Params const epilogue_params) {
using DTypeQ = typename Ktraits::DTypeQ;
using DTypeKV = typename Ktraits::DTypeKV;
using DTypeO = typename Ktraits::DTypeO;
@@ -147,7 +162,8 @@ MLAWithKVCacheKernel(CUTE_GRID_CONSTANT
using PipelineStateQ = typename MainloopPipelineQ::PipelineState;
extern __shared__ char shared_memory[];
auto& shared_storage = *reinterpret_cast<typename Ktraits::SharedStorage*>(shared_memory);
auto &shared_storage =
*reinterpret_cast<typename Ktraits::SharedStorage *>(shared_memory);
int const lane_predicate = cute::elect_one_sync();
int const warp_idx = cutlass::canonical_warp_idx_sync();
@@ -158,12 +174,14 @@ MLAWithKVCacheKernel(CUTE_GRID_CONSTANT
}
// Obtain warp index
int const warp_group_thread_idx = threadIdx.x % cutlass::NumThreadsPerWarpGroup;
int const warp_group_thread_idx =
threadIdx.x % cutlass::NumThreadsPerWarpGroup;
PipelineParams pipeline_params;
int warp_group_idx = cutlass::canonical_warp_group_idx();
pipeline_params.role = warp_group_idx == 0 ? MainloopPipeline::ThreadCategory::Producer
: MainloopPipeline::ThreadCategory::Consumer;
pipeline_params.role = warp_group_idx == 0
? MainloopPipeline::ThreadCategory::Producer
: MainloopPipeline::ThreadCategory::Consumer;
if constexpr (use_tma_load_kv) {
pipeline_params.is_leader = warp_group_thread_idx == 0;
pipeline_params.num_consumers = NUM_MMA_THREADS;
@@ -173,17 +191,20 @@ MLAWithKVCacheKernel(CUTE_GRID_CONSTANT
}
PipelineParamsQ pipeline_params_q;
pipeline_params_q.role = warp_group_idx == 0 ? MainloopPipelineQ::ThreadCategory::Producer
: MainloopPipelineQ::ThreadCategory::Consumer;
pipeline_params_q.role = warp_group_idx == 0
? MainloopPipelineQ::ThreadCategory::Producer
: MainloopPipelineQ::ThreadCategory::Consumer;
pipeline_params_q.producer_arv_count = NUM_COPY_THREADS;
pipeline_params_q.consumer_arv_count = cutlass::NumThreadsPerWarpGroup; // just one wg qk
pipeline_params_q.consumer_arv_count =
cutlass::NumThreadsPerWarpGroup; // just one wg qk
MainloopPipelineQ pipeline_q(shared_storage.pipeline_q, pipeline_params_q);
MainloopPipeline pipeline_kv = [&] {
if constexpr (use_tma_load_kv) {
pipeline_params.transaction_bytes = CollectiveMainloop::TmaTransactionBytesKV;
return MainloopPipeline(shared_storage.pipeline_kv, pipeline_params,
pipeline_params.transaction_bytes =
CollectiveMainloop::TmaTransactionBytesKV;
return MainloopPipeline(shared_storage.pipeline_kv,
pipeline_params,
/*cluster_shape=*/Shape<_1, _1, _1>{});
} else {
return MainloopPipeline(shared_storage.pipeline_kv, pipeline_params);
@@ -196,191 +217,217 @@ MLAWithKVCacheKernel(CUTE_GRID_CONSTANT
if (warp_group_idx == 0) {
// producer
if constexpr(USE_REG_EALLOC) {
if constexpr (USE_REG_EALLOC) {
cutlass::arch::warpgroup_reg_dealloc<72>();
}
const uint32_t warp_idx_in_warpgroup = __shfl_sync(0xffffffff, warp_idx % 4, 0);
const uint32_t warp_idx_in_warpgroup =
__shfl_sync(0xffffffff, warp_idx % 4, 0);
PipelineStateQ smem_pipe_write_q = cutlass::make_producer_start_state<MainloopPipelineQ>();
PipelineState smem_pipe_write_kv = cutlass::make_producer_start_state<MainloopPipeline>();
PipelineStateQ smem_pipe_write_q =
cutlass::make_producer_start_state<MainloopPipelineQ>();
PipelineState smem_pipe_write_kv =
cutlass::make_producer_start_state<MainloopPipeline>();
for (int i = blockIdx.x; i < num_blocks_x; i += SM_COUNT) {
const int bid = mainloop_params.batch_ids[i];
const int tile_id = mainloop_params.tile_ids_per_batch[i];
const int seq_len_now = mainloop_params.seq_lens_this_time[bid];
const int seq_len_decoder_now = mainloop_params.seq_lens_decoder[bid] + seq_len_now;
const int seq_len_decoder_now =
mainloop_params.seq_lens_decoder[bid] + seq_len_now;
const int start_token_idx = mainloop_params.cumsum_q_seqlens[bid];
cutlass::arch::NamedBarrier::sync(Ktraits::NUM_THREADS,
/*id=*/static_cast<int>(NamedBarriers::kWG0WG1WG2Sync));
cutlass::arch::NamedBarrier::sync(
Ktraits::NUM_THREADS,
/*id=*/static_cast<int>(NamedBarriers::kWG0WG1WG2Sync));
// load Q
collective_mainloop.load_q(
mainloop_params,
pipeline_q,
smem_pipe_write_q,
shared_storage,
threadIdx.x,
bid);
collective_mainloop.load_q(mainloop_params,
pipeline_q,
smem_pipe_write_q,
shared_storage,
threadIdx.x,
bid);
if constexpr (!use_tma_load_kv) {
// load kv
collective_mainloop.load_kv(
mainloop_params,
pipeline_kv,
smem_pipe_write_kv,
shared_storage,
bid,
seq_len_decoder_now,
tile_id
);
collective_mainloop.load_kv(mainloop_params,
pipeline_kv,
smem_pipe_write_kv,
shared_storage,
bid,
seq_len_decoder_now,
tile_id);
} else {
if (warp_idx_in_warpgroup == 0) {
// load kv tma
collective_mainloop.load_kv_tma(
mainloop_params,
pipeline_kv,
smem_pipe_write_kv,
shared_storage,
bid,
seq_len_decoder_now,
tile_id
);
collective_mainloop.load_kv_tma(mainloop_params,
pipeline_kv,
smem_pipe_write_kv,
shared_storage,
bid,
seq_len_decoder_now,
tile_id);
}
}
}
} else {
// consumer
if constexpr(USE_REG_EALLOC) {
if constexpr (USE_REG_EALLOC) {
cutlass::arch::warpgroup_reg_alloc<216>();
}
PipelineStateQ smem_pipe_read_q;
PipelineState smem_pipe_read_kv;
typename Ktraits::TiledMmaPVSS tiled_mma_pv;
Tensor tOrO = partition_fragment_C(tiled_mma_pv, select<0, 1>(TileShape_PDV{}));
Tensor tOrO =
partition_fragment_C(tiled_mma_pv, select<0, 1>(TileShape_PDV{}));
auto attention_updater = OnlineSoftmax<2 * size<1>(tOrO), /*WITH_SCALE=*/true>(mainloop_params.sm_scale);
auto attention_updater =
OnlineSoftmax<2 * size<1>(tOrO), /*WITH_SCALE=*/true>(
mainloop_params.sm_scale);
for (int i = blockIdx.x; i < num_blocks_x; i += SM_COUNT) {
clear(tOrO);
clear(attention_updater.scores_scale);
const int bid = mainloop_params.batch_ids[i];
const int tile_id = mainloop_params.tile_ids_per_batch[i];
const int seq_len_now = mainloop_params.seq_lens_this_time[bid];
const int seq_len_decoder_now = mainloop_params.seq_lens_decoder[bid] + seq_len_now;
const int seq_len_decoder_now =
mainloop_params.seq_lens_decoder[bid] + seq_len_now;
const int start_token_idx = mainloop_params.cumsum_q_seqlens[bid];
cutlass::arch::NamedBarrier::sync(Ktraits::NUM_THREADS,
/*id=*/static_cast<int>(NamedBarriers::kWG0WG1WG2Sync));
cutlass::arch::NamedBarrier::sync(
Ktraits::NUM_THREADS,
/*id=*/static_cast<int>(NamedBarriers::kWG0WG1WG2Sync));
if constexpr (BLOCK_SHAPE_KV == 64) {
mma_f16<Ktraits, CAUSAL>(
mainloop_params,
pipeline_q,
smem_pipe_read_q,
pipeline_kv,
smem_pipe_read_kv,
tOrO,
attention_updater,
threadIdx.x - NUM_COPY_THREADS,
bid,
seq_len_decoder_now,
seq_len_now,
tile_id,
shared_storage);
mma_f16<Ktraits, CAUSAL>(mainloop_params,
pipeline_q,
smem_pipe_read_q,
pipeline_kv,
smem_pipe_read_kv,
tOrO,
attention_updater,
threadIdx.x - NUM_COPY_THREADS,
bid,
seq_len_decoder_now,
seq_len_now,
tile_id,
shared_storage);
} else if (BLOCK_SHAPE_KV == 32) {
mma_f16_two_stages<Ktraits, CAUSAL>(
mainloop_params,
pipeline_q,
smem_pipe_read_q,
pipeline_kv,
smem_pipe_read_kv,
tOrO,
attention_updater,
threadIdx.x - NUM_COPY_THREADS,
bid,
seq_len_decoder_now,
seq_len_now,
tile_id,
shared_storage);
mma_f16_two_stages<Ktraits, CAUSAL>(mainloop_params,
pipeline_q,
smem_pipe_read_q,
pipeline_kv,
smem_pipe_read_kv,
tOrO,
attention_updater,
threadIdx.x - NUM_COPY_THREADS,
bid,
seq_len_decoder_now,
seq_len_now,
tile_id,
shared_storage);
}
collective_epilogue.store(
epilogue_params,
tOrO,
attention_updater.get_lse(),
shared_storage,
tiled_mma_pv,
threadIdx.x - NUM_COPY_THREADS,
bid,
mainloop_params.bsz,
seq_len_now,
start_token_idx,
tile_id,
seq_len_decoder_now,
chunk_size,
mainloop_params.max_draft_token_num,
mainloop_params.o_stride_bsz);
}
collective_epilogue.store(epilogue_params,
tOrO,
attention_updater.get_lse(),
shared_storage,
tiled_mma_pv,
threadIdx.x - NUM_COPY_THREADS,
bid,
mainloop_params.bsz,
seq_len_now,
start_token_idx,
tile_id,
seq_len_decoder_now,
chunk_size,
mainloop_params.max_draft_token_num,
mainloop_params.o_stride_bsz);
}
}
}
template <typename KernelTraits, bool CAUSAL, typename Params, bool USE_REG_EALLOC=false, bool USE_FIXED_BLOCK=true>
cudaError_t BatchMLAWithPagedKVCacheKernelTraitsDispatched(Params& params,
cudaStream_t stream) {
template <typename KernelTraits,
bool CAUSAL,
typename Params,
bool USE_REG_EALLOC = false,
bool USE_FIXED_BLOCK = true>
cudaError_t BatchMLAWithPagedKVCacheKernelTraitsDispatched(
Params &params, cudaStream_t stream) {
using DTypeQ = typename KernelTraits::DTypeQ;
using DTypeKV = typename KernelTraits::DTypeKV;
using DTypeO = typename KernelTraits::DTypeO;
using IdType = typename KernelTraits::IdType;
using NV_TYPE = typename KernelTraits::NV_TYPE;
using CollectiveMainloop =
CollectiveMainloop<KernelTraits, CAUSAL>;
using CollectiveMainloop = CollectiveMainloop<KernelTraits, CAUSAL>;
using CollectiveEpilogue = CollectiveEpilogue<KernelTraits>;
typename CollectiveMainloop::Params mainloop_params = CollectiveMainloop::to_underlying_arguments({
make_layout(make_shape(KernelTraits::BLOCK_SHAPE_Q, params.qk_head_dim), make_stride(params.qk_head_dim, _1{})), // layout q
make_layout(make_shape(params.block_size, params.qk_head_dim, params.max_block_num), make_stride(params.qk_head_dim, _1{}, params.block_size * params.qk_head_dim)),
make_layout(make_shape(params.chunk_num, params.bsz * params.max_draft_token_num * params.q_num_head), make_stride(params.bsz * params.max_draft_token_num * params.q_num_head, _1{})),
params.Q,
params.KV,
params.m,
params.d,
params.block_tables,
params.seq_lens_this_time,
params.seq_lens_decoder,
params.cumsum_q_seqlens,
params.batch_ids,
params.tile_ids_per_batch,
params.num_blocks_x,
params.chunk_size_device,
params.sm_scale,
params.bsz,
params.max_block_num,
params.max_block_num_per_seq,
params.q_stride_bsz,
params.q_stride_head_num,
params.kv_stride_block_num,
params.kv_stride_block_size,
params.o_stride_bsz,
params.o_stride_head_num,
params.chunk_num,
params.max_draft_token_num
});
typename CollectiveEpilogue::Params epilogue_params = CollectiveEpilogue::to_underlying_arguments_ntma({
params.O,
make_layout(make_shape(KernelTraits::BLOCK_SHAPE_Q, params.vo_head_dim), make_stride(params.vo_head_dim, _1{})), // layout O
params.O_tmp,
make_layout(make_shape(KernelTraits::BLOCK_SHAPE_Q, params.vo_head_dim), make_stride(params.vo_head_dim, _1{})) // layout O_tmp
});
typename CollectiveMainloop::Params mainloop_params =
CollectiveMainloop::to_underlying_arguments(
{make_layout(
make_shape(KernelTraits::BLOCK_SHAPE_Q, params.qk_head_dim),
make_stride(params.qk_head_dim, _1{})), // layout q
make_layout(
make_shape(
params.block_size, params.qk_head_dim, params.max_block_num),
make_stride(params.qk_head_dim,
_1{},
params.block_size * params.qk_head_dim)),
make_layout(make_shape(params.chunk_num,
params.bsz * params.max_draft_token_num *
params.q_num_head),
make_stride(params.bsz * params.max_draft_token_num *
params.q_num_head,
_1{})),
params.Q,
params.KV,
params.m,
params.d,
params.block_tables,
params.seq_lens_this_time,
params.seq_lens_decoder,
params.cumsum_q_seqlens,
params.batch_ids,
params.tile_ids_per_batch,
params.num_blocks_x,
params.chunk_size_device,
params.sm_scale,
params.bsz,
params.max_block_num,
params.max_block_num_per_seq,
params.q_stride_bsz,
params.q_stride_head_num,
params.kv_stride_block_num,
params.kv_stride_block_size,
params.o_stride_bsz,
params.o_stride_head_num,
params.chunk_num,
params.max_draft_token_num});
typename CollectiveEpilogue::Params epilogue_params =
CollectiveEpilogue::to_underlying_arguments_ntma({
params.O,
make_layout(
make_shape(KernelTraits::BLOCK_SHAPE_Q, params.vo_head_dim),
make_stride(params.vo_head_dim, _1{})), // layout O
params.O_tmp,
make_layout(
make_shape(KernelTraits::BLOCK_SHAPE_Q, params.vo_head_dim),
make_stride(params.vo_head_dim, _1{})) // layout O_tmp
});
// Get the ptr to kernel function.
auto kernel =
MLAWithKVCacheKernel<CollectiveMainloop, CollectiveEpilogue, KernelTraits, CAUSAL, 132>;
auto kernel = MLAWithKVCacheKernel<CollectiveMainloop,
CollectiveEpilogue,
KernelTraits,
CAUSAL,
132>;
int smem_size = sizeof(typename KernelTraits::SharedStorage);
cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
cudaFuncSetAttribute(
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
int device;
cudaGetDevice(&device);
int multiprocessor_count;
cudaDeviceGetAttribute(&multiprocessor_count, cudaDevAttrMultiProcessorCount, device);
cudaDeviceGetAttribute(
&multiprocessor_count, cudaDevAttrMultiProcessorCount, device);
int act_blocks_per_sm;
cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&act_blocks_per_sm, kernel, KernelTraits::NUM_WARPS * 32, smem_size);
@@ -390,15 +437,15 @@ cudaError_t BatchMLAWithPagedKVCacheKernelTraitsDispatched(Params& params,
dim3 grid_dims = {multiprocessor_count, 1, 1};
static constexpr int ctaSize = KernelTraits::NUM_WARPS * 32;
dim3 block_dims(ctaSize, 1, 1);
kernel<<<grid_dims, block_dims, smem_size, stream>>>(
mainloop_params, epilogue_params
);
kernel<<<grid_dims, block_dims, smem_size, stream>>>(mainloop_params,
epilogue_params);
if (params.chunk_num > 1) {
constexpr int vec_size = 16 / sizeof(DTypeO);
constexpr int merge_block_size = 256;
constexpr int blockx = KernelTraits::HEAD_DIM_VO / vec_size;
constexpr int blocky = (merge_block_size + blockx - 1) / blockx;
dim3 grids_merge(multiprocessor_count, params.q_num_head); // 128k is too large
dim3 grids_merge(multiprocessor_count,
params.q_num_head); // 128k is too large
dim3 blocks_merge(blockx, blocky);
merge_multi_chunks_kernel<NV_TYPE,
vec_size,
@@ -423,28 +470,35 @@ cudaError_t BatchMLAWithPagedKVCacheKernelTraitsDispatched(Params& params,
return cudaSuccess;
}
template <uint32_t HEAD_DIM_QK, uint32_t HEAD_DIM_VO, typename NV_TYPE, typename Params, bool USE_REG_EALLOC=false, bool USE_FIXED_BLOCK=true>
cudaError_t BatchMLAWithPagedKVCacheDispatched(Params& params, cudaStream_t stream) {
template <uint32_t HEAD_DIM_QK,
uint32_t HEAD_DIM_VO,
typename NV_TYPE,
typename Params,
bool USE_REG_EALLOC = false,
bool USE_FIXED_BLOCK = true>
cudaError_t BatchMLAWithPagedKVCacheDispatched(Params &params,
cudaStream_t stream) {
constexpr bool CAUSAL = true;
if constexpr (HEAD_DIM_QK == 576) {
DISPATCH_GROUP_SIZE(params.q_num_head, GROUP_SIZE,
BatchMLAWithPagedKVCacheKernelTraitsDispatched<
AttentionKernelTraits</*USE_TMA_LOAD_KV=*/true,
HEAD_DIM_QK,
HEAD_DIM_VO,
GROUP_SIZE,
/*BLOCK_SHAPE_Q_=*/64,
/*BLOCK_SHAPE_KV_=*/64,
/*NUM_STAGES_=*/2,
typename Params::DTypeQ,
typename Params::DTypeKV,
typename Params::DTypeO,
typename Params::IdType,
NV_TYPE>,
CAUSAL,
Params,
USE_REG_EALLOC,
USE_FIXED_BLOCK>(params, stream);)
DISPATCH_GROUP_SIZE(params.q_num_head,
GROUP_SIZE,
BatchMLAWithPagedKVCacheKernelTraitsDispatched<
AttentionKernelTraits</*USE_TMA_LOAD_KV=*/true,
HEAD_DIM_QK,
HEAD_DIM_VO,
GROUP_SIZE,
/*BLOCK_SHAPE_Q_=*/64,
/*BLOCK_SHAPE_KV_=*/64,
/*NUM_STAGES_=*/2,
typename Params::DTypeQ,
typename Params::DTypeKV,
typename Params::DTypeO,
typename Params::IdType,
NV_TYPE>,
CAUSAL,
Params,
USE_REG_EALLOC,
USE_FIXED_BLOCK>(params, stream);)
} else {
return cudaErrorNotSupported;
}