mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-12-24 13:28:13 +08:00
[PD Disaggregation] Support PD deployment of DeepSeekv3. (#5251)
* Support deepseekv3 cache transfer for PD deploy * clean some log info --------- Co-authored-by: K11OntheBoat <“ruianmaidanglao@163.com”>
This commit is contained in:
@@ -13,22 +13,24 @@
|
|||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include "helper.h"
|
|
||||||
#include "mla_cache_kernel.cuh"
|
#include "mla_cache_kernel.cuh"
|
||||||
|
#include "helper.h"
|
||||||
|
#include "remote_cache_kv_ipc.h"
|
||||||
|
|
||||||
template <paddle::DataType T>
|
template <paddle::DataType T>
|
||||||
std::vector<paddle::Tensor> PrefillMLAWriteCache(
|
std::vector<paddle::Tensor> PrefillMLAWriteCache(
|
||||||
const AppendAttnMetaData& meta_data,
|
const AppendAttnMetaData& meta_data,
|
||||||
const paddle::Tensor& kv_nope,
|
const paddle::Tensor& kv_nope,
|
||||||
const paddle::Tensor& kv_pe,
|
const paddle::Tensor& kv_pe,
|
||||||
const paddle::Tensor& seq_lens,
|
const paddle::Tensor& seq_lens,
|
||||||
const paddle::Tensor& seq_lens_decoder,
|
const paddle::Tensor& seq_lens_decoder,
|
||||||
const paddle::Tensor& batch_id_per_token,
|
const paddle::Tensor& batch_id_per_token,
|
||||||
const paddle::Tensor& cu_seqlens_q,
|
const paddle::Tensor& cu_seqlens_q,
|
||||||
const paddle::Tensor& block_tables,
|
const paddle::Tensor& block_tables,
|
||||||
const int max_seq_len,
|
const paddle::optional<paddle::Tensor>& kv_signal_data,
|
||||||
cudaStream_t& stream,
|
const int max_seq_len,
|
||||||
paddle::Tensor* kv_cache) {
|
cudaStream_t& stream,
|
||||||
|
paddle::Tensor* kv_cache) {
|
||||||
typedef PDTraits<T> traits_;
|
typedef PDTraits<T> traits_;
|
||||||
typedef typename traits_::DataType DataType_;
|
typedef typename traits_::DataType DataType_;
|
||||||
typedef typename traits_::data_t data_t;
|
typedef typename traits_::data_t data_t;
|
||||||
@@ -50,8 +52,10 @@ std::vector<paddle::Tensor> PrefillMLAWriteCache(
|
|||||||
|
|
||||||
prefill_absorb_cache_kernel<DataType_, PackSize>
|
prefill_absorb_cache_kernel<DataType_, PackSize>
|
||||||
<<<grid_size, blocksize, 0, stream>>>(
|
<<<grid_size, blocksize, 0, stream>>>(
|
||||||
reinterpret_cast<DataType_*>(const_cast<data_t*>(kv_nope.data<data_t>())),
|
reinterpret_cast<DataType_*>(
|
||||||
reinterpret_cast<DataType_*>(const_cast<data_t*>(kv_pe.data<data_t>())),
|
const_cast<data_t*>(kv_nope.data<data_t>())),
|
||||||
|
reinterpret_cast<DataType_*>(
|
||||||
|
const_cast<data_t*>(kv_pe.data<data_t>())),
|
||||||
reinterpret_cast<DataType_*>(kv_cache->data<data_t>()),
|
reinterpret_cast<DataType_*>(kv_cache->data<data_t>()),
|
||||||
block_tables.data<int>(),
|
block_tables.data<int>(),
|
||||||
batch_id_per_token.data<int>(),
|
batch_id_per_token.data<int>(),
|
||||||
@@ -65,6 +69,33 @@ std::vector<paddle::Tensor> PrefillMLAWriteCache(
|
|||||||
pe_size,
|
pe_size,
|
||||||
block_size,
|
block_size,
|
||||||
elem_nums);
|
elem_nums);
|
||||||
|
|
||||||
|
const char* fmt_write_cache_completed_signal_str =
|
||||||
|
std::getenv("FLAGS_fmt_write_cache_completed_signal");
|
||||||
|
const char* FLAGS_use_pd_disaggregation_per_chunk =
|
||||||
|
std::getenv("FLAGS_use_pd_disaggregation_per_chunk");
|
||||||
|
|
||||||
|
if (fmt_write_cache_completed_signal_str &&
|
||||||
|
(std::strcmp(fmt_write_cache_completed_signal_str, "true") == 0 ||
|
||||||
|
std::strcmp(fmt_write_cache_completed_signal_str, "1") == 0)) {
|
||||||
|
if (FLAGS_use_pd_disaggregation_per_chunk &&
|
||||||
|
(std::strcmp(FLAGS_use_pd_disaggregation_per_chunk, "true") == 0 ||
|
||||||
|
std::strcmp(FLAGS_use_pd_disaggregation_per_chunk, "1") == 0)) {
|
||||||
|
cudaLaunchHostFunc(
|
||||||
|
stream,
|
||||||
|
&(RemoteCacheKvIpc::
|
||||||
|
save_cache_kv_complete_signal_layerwise_per_query),
|
||||||
|
(void*)nullptr);
|
||||||
|
} else {
|
||||||
|
if (kv_signal_data) {
|
||||||
|
cudaLaunchHostFunc(
|
||||||
|
stream,
|
||||||
|
&RemoteCacheKvIpc::save_cache_kv_complete_signal_layerwise,
|
||||||
|
(void*)(const_cast<int64_t*>(
|
||||||
|
kv_signal_data.get().data<int64_t>())));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
return {};
|
return {};
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -77,6 +108,7 @@ std::vector<paddle::Tensor> PrefillMLAWriteCacheKernel(
|
|||||||
const paddle::Tensor& batch_id_per_token,
|
const paddle::Tensor& batch_id_per_token,
|
||||||
const paddle::Tensor& cu_seqlens_q,
|
const paddle::Tensor& cu_seqlens_q,
|
||||||
const paddle::Tensor& block_tables,
|
const paddle::Tensor& block_tables,
|
||||||
|
const paddle::optional<paddle::Tensor>& kv_signal_data,
|
||||||
const std::string& cache_quant_type_str,
|
const std::string& cache_quant_type_str,
|
||||||
const int max_seq_len) {
|
const int max_seq_len) {
|
||||||
cudaStream_t stream = kv_pe.stream();
|
cudaStream_t stream = kv_pe.stream();
|
||||||
@@ -85,7 +117,8 @@ std::vector<paddle::Tensor> PrefillMLAWriteCacheKernel(
|
|||||||
const auto& kv_pe_dims = kv_pe.dims();
|
const auto& kv_pe_dims = kv_pe.dims();
|
||||||
const auto& kv_cache_dims = kv_cache.dims();
|
const auto& kv_cache_dims = kv_cache.dims();
|
||||||
meta_data.kv_num_heads = kv_cache_dims[1];
|
meta_data.kv_num_heads = kv_cache_dims[1];
|
||||||
const auto nope_size = kv_nope_dims[kv_nope_dims.size() - 1] / meta_data.kv_num_heads;
|
const auto nope_size =
|
||||||
|
kv_nope_dims[kv_nope_dims.size() - 1] / meta_data.kv_num_heads;
|
||||||
meta_data.token_nums = kv_nope_dims[0];
|
meta_data.token_nums = kv_nope_dims[0];
|
||||||
meta_data.head_dims = kv_cache_dims[3];
|
meta_data.head_dims = kv_cache_dims[3];
|
||||||
meta_data.head_dims_v = nope_size;
|
meta_data.head_dims_v = nope_size;
|
||||||
@@ -95,30 +128,34 @@ std::vector<paddle::Tensor> PrefillMLAWriteCacheKernel(
|
|||||||
meta_data.batch_size = seq_lens_decoder.dims()[0];
|
meta_data.batch_size = seq_lens_decoder.dims()[0];
|
||||||
switch (kv_pe.dtype()) {
|
switch (kv_pe.dtype()) {
|
||||||
case paddle::DataType::BFLOAT16: {
|
case paddle::DataType::BFLOAT16: {
|
||||||
return PrefillMLAWriteCache<paddle::DataType::BFLOAT16>(meta_data,
|
return PrefillMLAWriteCache<paddle::DataType::BFLOAT16>(
|
||||||
kv_nope,
|
meta_data,
|
||||||
kv_pe,
|
kv_nope,
|
||||||
seq_lens,
|
kv_pe,
|
||||||
seq_lens_decoder,
|
seq_lens,
|
||||||
batch_id_per_token,
|
seq_lens_decoder,
|
||||||
cu_seqlens_q,
|
batch_id_per_token,
|
||||||
block_tables,
|
cu_seqlens_q,
|
||||||
max_seq_len,
|
block_tables,
|
||||||
stream,
|
kv_signal_data,
|
||||||
const_cast<paddle::Tensor*>(&kv_cache));
|
max_seq_len,
|
||||||
|
stream,
|
||||||
|
const_cast<paddle::Tensor*>(&kv_cache));
|
||||||
}
|
}
|
||||||
case paddle::DataType::FLOAT16: {
|
case paddle::DataType::FLOAT16: {
|
||||||
return PrefillMLAWriteCache<paddle::DataType::FLOAT16>(meta_data,
|
return PrefillMLAWriteCache<paddle::DataType::FLOAT16>(
|
||||||
kv_nope,
|
meta_data,
|
||||||
kv_pe,
|
kv_nope,
|
||||||
seq_lens,
|
kv_pe,
|
||||||
seq_lens_decoder,
|
seq_lens,
|
||||||
batch_id_per_token,
|
seq_lens_decoder,
|
||||||
cu_seqlens_q,
|
batch_id_per_token,
|
||||||
block_tables,
|
cu_seqlens_q,
|
||||||
max_seq_len,
|
block_tables,
|
||||||
stream,
|
kv_signal_data,
|
||||||
const_cast<paddle::Tensor*>(&kv_cache));
|
max_seq_len,
|
||||||
|
stream,
|
||||||
|
const_cast<paddle::Tensor*>(&kv_cache));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return {};
|
return {};
|
||||||
@@ -126,18 +163,18 @@ std::vector<paddle::Tensor> PrefillMLAWriteCacheKernel(
|
|||||||
|
|
||||||
template <paddle::DataType T>
|
template <paddle::DataType T>
|
||||||
std::vector<paddle::Tensor> DecodeMLAWriteCache(
|
std::vector<paddle::Tensor> DecodeMLAWriteCache(
|
||||||
const AppendAttnMetaData& meta_data,
|
const AppendAttnMetaData& meta_data,
|
||||||
const paddle::Tensor& kv_nope,
|
const paddle::Tensor& kv_nope,
|
||||||
const paddle::Tensor& kv_pe,
|
const paddle::Tensor& kv_pe,
|
||||||
const paddle::Tensor& seq_lens,
|
const paddle::Tensor& seq_lens,
|
||||||
const paddle::Tensor& seq_lens_encoder,
|
const paddle::Tensor& seq_lens_encoder,
|
||||||
const paddle::Tensor& batch_id_per_token,
|
const paddle::Tensor& batch_id_per_token,
|
||||||
const paddle::Tensor& cu_seqlens_q,
|
const paddle::Tensor& cu_seqlens_q,
|
||||||
const paddle::Tensor& block_tables,
|
const paddle::Tensor& block_tables,
|
||||||
const int max_seq_len,
|
const int max_seq_len,
|
||||||
const bool speculate_decoder,
|
const bool speculate_decoder,
|
||||||
cudaStream_t& stream,
|
cudaStream_t& stream,
|
||||||
paddle::Tensor* kv_cache) {
|
paddle::Tensor* kv_cache) {
|
||||||
typedef PDTraits<T> traits_;
|
typedef PDTraits<T> traits_;
|
||||||
typedef typename traits_::DataType DataType_;
|
typedef typename traits_::DataType DataType_;
|
||||||
typedef typename traits_::data_t data_t;
|
typedef typename traits_::data_t data_t;
|
||||||
@@ -154,15 +191,16 @@ std::vector<paddle::Tensor> DecodeMLAWriteCache(
|
|||||||
const int blocksize = 128;
|
const int blocksize = 128;
|
||||||
int grid_size = 1;
|
int grid_size = 1;
|
||||||
|
|
||||||
|
|
||||||
if (speculate_decoder) {
|
if (speculate_decoder) {
|
||||||
const uint32_t elem_nums = token_num * kv_num_heads * all_size;
|
const uint32_t elem_nums = token_num * kv_num_heads * all_size;
|
||||||
const int pack_num = elem_nums / PackSize;
|
const int pack_num = elem_nums / PackSize;
|
||||||
GetNumBlocks<128>(pack_num, &grid_size);
|
GetNumBlocks<128>(pack_num, &grid_size);
|
||||||
speculate_decode_absorb_cache_kernel<DataType_, PackSize>
|
speculate_decode_absorb_cache_kernel<DataType_, PackSize>
|
||||||
<<<grid_size, blocksize, 0, stream>>>(
|
<<<grid_size, blocksize, 0, stream>>>(
|
||||||
reinterpret_cast<DataType_*>(const_cast<data_t*>(kv_nope.data<data_t>())),
|
reinterpret_cast<DataType_*>(
|
||||||
reinterpret_cast<DataType_*>(const_cast<data_t*>(kv_pe.data<data_t>())),
|
const_cast<data_t*>(kv_nope.data<data_t>())),
|
||||||
|
reinterpret_cast<DataType_*>(
|
||||||
|
const_cast<data_t*>(kv_pe.data<data_t>())),
|
||||||
reinterpret_cast<DataType_*>(kv_cache->data<data_t>()),
|
reinterpret_cast<DataType_*>(kv_cache->data<data_t>()),
|
||||||
block_tables.data<int>(),
|
block_tables.data<int>(),
|
||||||
batch_id_per_token.data<int>(),
|
batch_id_per_token.data<int>(),
|
||||||
@@ -182,8 +220,10 @@ std::vector<paddle::Tensor> DecodeMLAWriteCache(
|
|||||||
GetNumBlocks<128>(pack_num, &grid_size);
|
GetNumBlocks<128>(pack_num, &grid_size);
|
||||||
decode_absorb_cache_kernel<DataType_, PackSize>
|
decode_absorb_cache_kernel<DataType_, PackSize>
|
||||||
<<<grid_size, blocksize, 0, stream>>>(
|
<<<grid_size, blocksize, 0, stream>>>(
|
||||||
reinterpret_cast<DataType_*>(const_cast<data_t*>(kv_nope.data<data_t>())),
|
reinterpret_cast<DataType_*>(
|
||||||
reinterpret_cast<DataType_*>(const_cast<data_t*>(kv_pe.data<data_t>())),
|
const_cast<data_t*>(kv_nope.data<data_t>())),
|
||||||
|
reinterpret_cast<DataType_*>(
|
||||||
|
const_cast<data_t*>(kv_pe.data<data_t>())),
|
||||||
reinterpret_cast<DataType_*>(kv_cache->data<data_t>()),
|
reinterpret_cast<DataType_*>(kv_cache->data<data_t>()),
|
||||||
block_tables.data<int>(),
|
block_tables.data<int>(),
|
||||||
cu_seqlens_q.data<int>(),
|
cu_seqlens_q.data<int>(),
|
||||||
@@ -218,7 +258,8 @@ std::vector<paddle::Tensor> DecodeMLAWriteCacheKernel(
|
|||||||
const auto& kv_pe_dims = kv_pe.dims();
|
const auto& kv_pe_dims = kv_pe.dims();
|
||||||
const auto& kv_cache_dims = kv_cache.dims();
|
const auto& kv_cache_dims = kv_cache.dims();
|
||||||
meta_data.kv_num_heads = kv_cache_dims[1];
|
meta_data.kv_num_heads = kv_cache_dims[1];
|
||||||
const auto nope_size = kv_nope_dims[kv_nope_dims.size() - 1] / meta_data.kv_num_heads;
|
const auto nope_size =
|
||||||
|
kv_nope_dims[kv_nope_dims.size() - 1] / meta_data.kv_num_heads;
|
||||||
meta_data.token_nums = kv_nope_dims[0];
|
meta_data.token_nums = kv_nope_dims[0];
|
||||||
meta_data.head_dims = kv_cache_dims[3];
|
meta_data.head_dims = kv_cache_dims[3];
|
||||||
meta_data.head_dims_v = nope_size;
|
meta_data.head_dims_v = nope_size;
|
||||||
@@ -228,38 +269,39 @@ std::vector<paddle::Tensor> DecodeMLAWriteCacheKernel(
|
|||||||
meta_data.batch_size = seq_lens_encoder.dims()[0];
|
meta_data.batch_size = seq_lens_encoder.dims()[0];
|
||||||
switch (kv_pe.dtype()) {
|
switch (kv_pe.dtype()) {
|
||||||
case paddle::DataType::BFLOAT16: {
|
case paddle::DataType::BFLOAT16: {
|
||||||
return DecodeMLAWriteCache<paddle::DataType::BFLOAT16>(meta_data,
|
return DecodeMLAWriteCache<paddle::DataType::BFLOAT16>(
|
||||||
kv_nope,
|
meta_data,
|
||||||
kv_pe,
|
kv_nope,
|
||||||
seq_lens,
|
kv_pe,
|
||||||
seq_lens_encoder,
|
seq_lens,
|
||||||
batch_id_per_token,
|
seq_lens_encoder,
|
||||||
cu_seqlens_q,
|
batch_id_per_token,
|
||||||
block_tables,
|
cu_seqlens_q,
|
||||||
max_seq_len,
|
block_tables,
|
||||||
speculate_decoder,
|
max_seq_len,
|
||||||
stream,
|
speculate_decoder,
|
||||||
const_cast<paddle::Tensor*>(&kv_cache));
|
stream,
|
||||||
|
const_cast<paddle::Tensor*>(&kv_cache));
|
||||||
}
|
}
|
||||||
case paddle::DataType::FLOAT16: {
|
case paddle::DataType::FLOAT16: {
|
||||||
return DecodeMLAWriteCache<paddle::DataType::FLOAT16>(meta_data,
|
return DecodeMLAWriteCache<paddle::DataType::FLOAT16>(
|
||||||
kv_nope,
|
meta_data,
|
||||||
kv_pe,
|
kv_nope,
|
||||||
seq_lens,
|
kv_pe,
|
||||||
seq_lens_encoder,
|
seq_lens,
|
||||||
batch_id_per_token,
|
seq_lens_encoder,
|
||||||
cu_seqlens_q,
|
batch_id_per_token,
|
||||||
block_tables,
|
cu_seqlens_q,
|
||||||
max_seq_len,
|
block_tables,
|
||||||
speculate_decoder,
|
max_seq_len,
|
||||||
stream,
|
speculate_decoder,
|
||||||
const_cast<paddle::Tensor*>(&kv_cache));
|
stream,
|
||||||
|
const_cast<paddle::Tensor*>(&kv_cache));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return {};
|
return {};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
PD_BUILD_STATIC_OP(prefill_mla_write_cache)
|
PD_BUILD_STATIC_OP(prefill_mla_write_cache)
|
||||||
.Inputs({"kv_nope",
|
.Inputs({"kv_nope",
|
||||||
"kv_pe",
|
"kv_pe",
|
||||||
@@ -268,11 +310,11 @@ PD_BUILD_STATIC_OP(prefill_mla_write_cache)
|
|||||||
"seq_lens_decoder",
|
"seq_lens_decoder",
|
||||||
"batch_id_per_token",
|
"batch_id_per_token",
|
||||||
"cu_seqlens_q",
|
"cu_seqlens_q",
|
||||||
"block_tables"})
|
"block_tables",
|
||||||
|
paddle::Optional("kv_signal_data")})
|
||||||
.Outputs({"kv_cache_out"})
|
.Outputs({"kv_cache_out"})
|
||||||
.SetInplaceMap({{"kv_cache", "kv_cache_out"}})
|
.SetInplaceMap({{"kv_cache", "kv_cache_out"}})
|
||||||
.Attrs({"cache_quant_type_str: std::string",
|
.Attrs({"cache_quant_type_str: std::string", "max_seq_len: int"})
|
||||||
"max_seq_len: int"})
|
|
||||||
.SetKernelFn(PD_KERNEL(PrefillMLAWriteCacheKernel));
|
.SetKernelFn(PD_KERNEL(PrefillMLAWriteCacheKernel));
|
||||||
|
|
||||||
PD_BUILD_STATIC_OP(decode_mla_write_cache)
|
PD_BUILD_STATIC_OP(decode_mla_write_cache)
|
||||||
|
|||||||
@@ -527,6 +527,7 @@ std::vector<paddle::Tensor> PrefillMLAWriteCacheKernel(
|
|||||||
const paddle::Tensor& batch_id_per_token,
|
const paddle::Tensor& batch_id_per_token,
|
||||||
const paddle::Tensor& cu_seqlens_q,
|
const paddle::Tensor& cu_seqlens_q,
|
||||||
const paddle::Tensor& block_tables,
|
const paddle::Tensor& block_tables,
|
||||||
|
const paddle::optional<paddle::Tensor>& kv_signal_data,
|
||||||
const std::string& cache_quant_type_str,
|
const std::string& cache_quant_type_str,
|
||||||
const int max_seq_len);
|
const int max_seq_len);
|
||||||
|
|
||||||
|
|||||||
@@ -13,8 +13,8 @@
|
|||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
/*
|
/*
|
||||||
* Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri
|
* Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar,
|
||||||
* Dao. Licensed under the BSD 3-Clause.
|
* Pradeep Ramani, Tri Dao. Licensed under the BSD 3-Clause.
|
||||||
*
|
*
|
||||||
* Modified by the FlashInfer team.
|
* Modified by the FlashInfer team.
|
||||||
*/
|
*/
|
||||||
@@ -39,8 +39,8 @@
|
|||||||
#include "epilogue.cuh"
|
#include "epilogue.cuh"
|
||||||
#include "helper.h"
|
#include "helper.h"
|
||||||
#include "kernel_traits.cuh"
|
#include "kernel_traits.cuh"
|
||||||
#include "mainloop_mma.cuh"
|
|
||||||
#include "mainloop_load.cuh"
|
#include "mainloop_load.cuh"
|
||||||
|
#include "mainloop_mma.cuh"
|
||||||
#include "utils.cuh"
|
#include "utils.cuh"
|
||||||
|
|
||||||
#ifdef DEBUG_MLA
|
#ifdef DEBUG_MLA
|
||||||
@@ -52,76 +52,91 @@ namespace mla_attn {
|
|||||||
|
|
||||||
using namespace cute;
|
using namespace cute;
|
||||||
|
|
||||||
template <typename DTypeQ_, typename DTypeKV_, typename DTypeO_, typename IdType_>
|
template <typename DTypeQ_,
|
||||||
|
typename DTypeKV_,
|
||||||
|
typename DTypeO_,
|
||||||
|
typename IdType_>
|
||||||
struct Params {
|
struct Params {
|
||||||
using DTypeQ = DTypeQ_;
|
using DTypeQ = DTypeQ_;
|
||||||
using DTypeKV = DTypeKV_;
|
using DTypeKV = DTypeKV_;
|
||||||
using DTypeO = DTypeO_;
|
using DTypeO = DTypeO_;
|
||||||
using IdType = IdType_;
|
using IdType = IdType_;
|
||||||
|
|
||||||
alignas(16) DTypeQ *Q; // [token_num, head_num, dim_head]
|
alignas(16) DTypeQ *Q; // [token_num, head_num, dim_head]
|
||||||
alignas(16) DTypeKV *KV; // [max_block_num, block_size, dim_head]
|
alignas(16) DTypeKV *KV; // [max_block_num, block_size, dim_head]
|
||||||
alignas(16) DTypeO *O; // [token_num, head_num, dim_head]
|
alignas(16) DTypeO *O; // [token_num, head_num, dim_head]
|
||||||
alignas(16) DTypeO *O_tmp; // [max_num_chunks, bsz, head_num, dim_head]
|
alignas(16) DTypeO *O_tmp; // [max_num_chunks, bsz, head_num, dim_head]
|
||||||
alignas(16) float *m; // [max_num_chunks, bsz * max_draft_token_num * head_num]
|
alignas(
|
||||||
alignas(16) float *d; // [max_num_chunks, bsz * max_draft_token_num * head_num]
|
16) float *m; // [max_num_chunks, bsz * max_draft_token_num * head_num]
|
||||||
|
alignas(
|
||||||
|
16) float *d; // [max_num_chunks, bsz * max_draft_token_num * head_num]
|
||||||
|
|
||||||
alignas(16) IdType *block_tables;
|
alignas(16) IdType *block_tables;
|
||||||
alignas(16) IdType *seq_lens_this_time;
|
alignas(16) IdType *seq_lens_this_time;
|
||||||
alignas(16) IdType *seq_lens_decoder;
|
alignas(16) IdType *seq_lens_decoder;
|
||||||
alignas(16) IdType *cumsum_q_seqlens;
|
alignas(16) IdType *cumsum_q_seqlens;
|
||||||
alignas(16) IdType *batch_id_per_token;
|
alignas(16) IdType *batch_id_per_token;
|
||||||
|
|
||||||
alignas(16) IdType *batch_ids;
|
alignas(16) IdType *batch_ids;
|
||||||
alignas(16) IdType *tile_ids_per_batch;
|
alignas(16) IdType *tile_ids_per_batch;
|
||||||
alignas(16) IdType *num_blocks_x;
|
alignas(16) IdType *num_blocks_x;
|
||||||
alignas(16) IdType *chunk_size_device;
|
alignas(16) IdType *chunk_size_device;
|
||||||
|
|
||||||
uint32_t q_stride_bsz;
|
uint32_t q_stride_bsz;
|
||||||
uint32_t q_stride_head_num;
|
uint32_t q_stride_head_num;
|
||||||
|
|
||||||
uint32_t kv_stride_block_num;
|
uint32_t kv_stride_block_num;
|
||||||
uint32_t kv_stride_block_size;
|
uint32_t kv_stride_block_size;
|
||||||
|
|
||||||
uint32_t o_stride_bsz;
|
uint32_t o_stride_bsz;
|
||||||
uint32_t o_stride_head_num;
|
uint32_t o_stride_head_num;
|
||||||
|
|
||||||
int bsz;
|
int bsz;
|
||||||
int token_num;
|
int token_num;
|
||||||
int max_block_num;
|
int max_block_num;
|
||||||
int max_block_num_per_seq;
|
int max_block_num_per_seq;
|
||||||
int q_num_head;
|
int q_num_head;
|
||||||
int qk_head_dim;
|
int qk_head_dim;
|
||||||
int vo_head_dim;
|
int vo_head_dim;
|
||||||
int block_size;
|
int block_size;
|
||||||
int max_draft_token_num;
|
int max_draft_token_num;
|
||||||
int chunk_num;
|
int chunk_num;
|
||||||
|
|
||||||
float sm_scale;
|
float sm_scale;
|
||||||
};
|
};
|
||||||
|
|
||||||
#define DISPATCH_GROUP_SIZE(group_size, GROUP_SIZE, ...) \
|
#define DISPATCH_GROUP_SIZE(group_size, GROUP_SIZE, ...) \
|
||||||
if (group_size == 8) { \
|
if (group_size == 8) { \
|
||||||
constexpr size_t GROUP_SIZE = 8; \
|
constexpr size_t GROUP_SIZE = 8; \
|
||||||
__VA_ARGS__ \
|
__VA_ARGS__ \
|
||||||
} else if (group_size == 16) { \
|
} else if (group_size == 16) { \
|
||||||
constexpr size_t GROUP_SIZE = 16; \
|
constexpr size_t GROUP_SIZE = 16; \
|
||||||
__VA_ARGS__ \
|
__VA_ARGS__ \
|
||||||
} else if (group_size == 64) { \
|
} else if (group_size == 64) { \
|
||||||
constexpr size_t GROUP_SIZE = 64; \
|
constexpr size_t GROUP_SIZE = 64; \
|
||||||
__VA_ARGS__ \
|
__VA_ARGS__ \
|
||||||
} else { \
|
} else if (group_size == 128) { \
|
||||||
PD_THROW("not support the group_size: ", group_size); \
|
constexpr size_t GROUP_SIZE = 128; \
|
||||||
return cudaErrorNotSupported; \
|
__VA_ARGS__ \
|
||||||
|
} else { \
|
||||||
|
PD_THROW("not support the group_size: ", group_size); \
|
||||||
|
return cudaErrorNotSupported; \
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename CollectiveMainloop, typename CollectiveEpilogue, typename Ktraits, bool CAUSAL, int SM_COUNT = 132, bool USE_REG_EALLOC=false, bool USE_FIXED_BLOCK=true>
|
template <typename CollectiveMainloop,
|
||||||
__global__ void __launch_bounds__(Ktraits::NUM_WARPS * cutlass::NumThreadsPerWarp, 1)
|
typename CollectiveEpilogue,
|
||||||
MLAWithKVCacheKernel(CUTE_GRID_CONSTANT
|
typename Ktraits,
|
||||||
typename CollectiveMainloop::Params const mainloop_params,
|
bool CAUSAL,
|
||||||
CUTE_GRID_CONSTANT
|
int SM_COUNT = 132,
|
||||||
typename CollectiveEpilogue::Params const epilogue_params) {
|
bool USE_REG_EALLOC = false,
|
||||||
|
bool USE_FIXED_BLOCK = true>
|
||||||
|
__global__ void __launch_bounds__(
|
||||||
|
Ktraits::NUM_WARPS *cutlass::NumThreadsPerWarp, 1)
|
||||||
|
MLAWithKVCacheKernel(
|
||||||
|
CUTE_GRID_CONSTANT
|
||||||
|
typename CollectiveMainloop::Params const mainloop_params,
|
||||||
|
CUTE_GRID_CONSTANT
|
||||||
|
typename CollectiveEpilogue::Params const epilogue_params) {
|
||||||
using DTypeQ = typename Ktraits::DTypeQ;
|
using DTypeQ = typename Ktraits::DTypeQ;
|
||||||
using DTypeKV = typename Ktraits::DTypeKV;
|
using DTypeKV = typename Ktraits::DTypeKV;
|
||||||
using DTypeO = typename Ktraits::DTypeO;
|
using DTypeO = typename Ktraits::DTypeO;
|
||||||
@@ -147,7 +162,8 @@ MLAWithKVCacheKernel(CUTE_GRID_CONSTANT
|
|||||||
using PipelineStateQ = typename MainloopPipelineQ::PipelineState;
|
using PipelineStateQ = typename MainloopPipelineQ::PipelineState;
|
||||||
|
|
||||||
extern __shared__ char shared_memory[];
|
extern __shared__ char shared_memory[];
|
||||||
auto& shared_storage = *reinterpret_cast<typename Ktraits::SharedStorage*>(shared_memory);
|
auto &shared_storage =
|
||||||
|
*reinterpret_cast<typename Ktraits::SharedStorage *>(shared_memory);
|
||||||
|
|
||||||
int const lane_predicate = cute::elect_one_sync();
|
int const lane_predicate = cute::elect_one_sync();
|
||||||
int const warp_idx = cutlass::canonical_warp_idx_sync();
|
int const warp_idx = cutlass::canonical_warp_idx_sync();
|
||||||
@@ -158,12 +174,14 @@ MLAWithKVCacheKernel(CUTE_GRID_CONSTANT
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Obtain warp index
|
// Obtain warp index
|
||||||
int const warp_group_thread_idx = threadIdx.x % cutlass::NumThreadsPerWarpGroup;
|
int const warp_group_thread_idx =
|
||||||
|
threadIdx.x % cutlass::NumThreadsPerWarpGroup;
|
||||||
|
|
||||||
PipelineParams pipeline_params;
|
PipelineParams pipeline_params;
|
||||||
int warp_group_idx = cutlass::canonical_warp_group_idx();
|
int warp_group_idx = cutlass::canonical_warp_group_idx();
|
||||||
pipeline_params.role = warp_group_idx == 0 ? MainloopPipeline::ThreadCategory::Producer
|
pipeline_params.role = warp_group_idx == 0
|
||||||
: MainloopPipeline::ThreadCategory::Consumer;
|
? MainloopPipeline::ThreadCategory::Producer
|
||||||
|
: MainloopPipeline::ThreadCategory::Consumer;
|
||||||
if constexpr (use_tma_load_kv) {
|
if constexpr (use_tma_load_kv) {
|
||||||
pipeline_params.is_leader = warp_group_thread_idx == 0;
|
pipeline_params.is_leader = warp_group_thread_idx == 0;
|
||||||
pipeline_params.num_consumers = NUM_MMA_THREADS;
|
pipeline_params.num_consumers = NUM_MMA_THREADS;
|
||||||
@@ -173,17 +191,20 @@ MLAWithKVCacheKernel(CUTE_GRID_CONSTANT
|
|||||||
}
|
}
|
||||||
|
|
||||||
PipelineParamsQ pipeline_params_q;
|
PipelineParamsQ pipeline_params_q;
|
||||||
pipeline_params_q.role = warp_group_idx == 0 ? MainloopPipelineQ::ThreadCategory::Producer
|
pipeline_params_q.role = warp_group_idx == 0
|
||||||
: MainloopPipelineQ::ThreadCategory::Consumer;
|
? MainloopPipelineQ::ThreadCategory::Producer
|
||||||
|
: MainloopPipelineQ::ThreadCategory::Consumer;
|
||||||
pipeline_params_q.producer_arv_count = NUM_COPY_THREADS;
|
pipeline_params_q.producer_arv_count = NUM_COPY_THREADS;
|
||||||
pipeline_params_q.consumer_arv_count = cutlass::NumThreadsPerWarpGroup; // just one wg qk
|
pipeline_params_q.consumer_arv_count =
|
||||||
|
cutlass::NumThreadsPerWarpGroup; // just one wg qk
|
||||||
|
|
||||||
MainloopPipelineQ pipeline_q(shared_storage.pipeline_q, pipeline_params_q);
|
MainloopPipelineQ pipeline_q(shared_storage.pipeline_q, pipeline_params_q);
|
||||||
MainloopPipeline pipeline_kv = [&] {
|
MainloopPipeline pipeline_kv = [&] {
|
||||||
if constexpr (use_tma_load_kv) {
|
if constexpr (use_tma_load_kv) {
|
||||||
pipeline_params.transaction_bytes = CollectiveMainloop::TmaTransactionBytesKV;
|
pipeline_params.transaction_bytes =
|
||||||
return MainloopPipeline(shared_storage.pipeline_kv, pipeline_params,
|
CollectiveMainloop::TmaTransactionBytesKV;
|
||||||
|
return MainloopPipeline(shared_storage.pipeline_kv,
|
||||||
|
pipeline_params,
|
||||||
/*cluster_shape=*/Shape<_1, _1, _1>{});
|
/*cluster_shape=*/Shape<_1, _1, _1>{});
|
||||||
} else {
|
} else {
|
||||||
return MainloopPipeline(shared_storage.pipeline_kv, pipeline_params);
|
return MainloopPipeline(shared_storage.pipeline_kv, pipeline_params);
|
||||||
@@ -196,191 +217,217 @@ MLAWithKVCacheKernel(CUTE_GRID_CONSTANT
|
|||||||
|
|
||||||
if (warp_group_idx == 0) {
|
if (warp_group_idx == 0) {
|
||||||
// producer
|
// producer
|
||||||
if constexpr(USE_REG_EALLOC) {
|
if constexpr (USE_REG_EALLOC) {
|
||||||
cutlass::arch::warpgroup_reg_dealloc<72>();
|
cutlass::arch::warpgroup_reg_dealloc<72>();
|
||||||
}
|
}
|
||||||
const uint32_t warp_idx_in_warpgroup = __shfl_sync(0xffffffff, warp_idx % 4, 0);
|
const uint32_t warp_idx_in_warpgroup =
|
||||||
|
__shfl_sync(0xffffffff, warp_idx % 4, 0);
|
||||||
|
|
||||||
PipelineStateQ smem_pipe_write_q = cutlass::make_producer_start_state<MainloopPipelineQ>();
|
PipelineStateQ smem_pipe_write_q =
|
||||||
PipelineState smem_pipe_write_kv = cutlass::make_producer_start_state<MainloopPipeline>();
|
cutlass::make_producer_start_state<MainloopPipelineQ>();
|
||||||
|
PipelineState smem_pipe_write_kv =
|
||||||
|
cutlass::make_producer_start_state<MainloopPipeline>();
|
||||||
for (int i = blockIdx.x; i < num_blocks_x; i += SM_COUNT) {
|
for (int i = blockIdx.x; i < num_blocks_x; i += SM_COUNT) {
|
||||||
const int bid = mainloop_params.batch_ids[i];
|
const int bid = mainloop_params.batch_ids[i];
|
||||||
const int tile_id = mainloop_params.tile_ids_per_batch[i];
|
const int tile_id = mainloop_params.tile_ids_per_batch[i];
|
||||||
const int seq_len_now = mainloop_params.seq_lens_this_time[bid];
|
const int seq_len_now = mainloop_params.seq_lens_this_time[bid];
|
||||||
const int seq_len_decoder_now = mainloop_params.seq_lens_decoder[bid] + seq_len_now;
|
const int seq_len_decoder_now =
|
||||||
|
mainloop_params.seq_lens_decoder[bid] + seq_len_now;
|
||||||
const int start_token_idx = mainloop_params.cumsum_q_seqlens[bid];
|
const int start_token_idx = mainloop_params.cumsum_q_seqlens[bid];
|
||||||
cutlass::arch::NamedBarrier::sync(Ktraits::NUM_THREADS,
|
cutlass::arch::NamedBarrier::sync(
|
||||||
/*id=*/static_cast<int>(NamedBarriers::kWG0WG1WG2Sync));
|
Ktraits::NUM_THREADS,
|
||||||
|
/*id=*/static_cast<int>(NamedBarriers::kWG0WG1WG2Sync));
|
||||||
|
|
||||||
// load Q
|
// load Q
|
||||||
collective_mainloop.load_q(
|
collective_mainloop.load_q(mainloop_params,
|
||||||
mainloop_params,
|
pipeline_q,
|
||||||
pipeline_q,
|
smem_pipe_write_q,
|
||||||
smem_pipe_write_q,
|
shared_storage,
|
||||||
shared_storage,
|
threadIdx.x,
|
||||||
threadIdx.x,
|
bid);
|
||||||
bid);
|
|
||||||
|
|
||||||
if constexpr (!use_tma_load_kv) {
|
if constexpr (!use_tma_load_kv) {
|
||||||
// load kv
|
// load kv
|
||||||
collective_mainloop.load_kv(
|
collective_mainloop.load_kv(mainloop_params,
|
||||||
mainloop_params,
|
pipeline_kv,
|
||||||
pipeline_kv,
|
smem_pipe_write_kv,
|
||||||
smem_pipe_write_kv,
|
shared_storage,
|
||||||
shared_storage,
|
bid,
|
||||||
bid,
|
seq_len_decoder_now,
|
||||||
seq_len_decoder_now,
|
tile_id);
|
||||||
tile_id
|
|
||||||
);
|
|
||||||
} else {
|
} else {
|
||||||
if (warp_idx_in_warpgroup == 0) {
|
if (warp_idx_in_warpgroup == 0) {
|
||||||
// load kv tma
|
// load kv tma
|
||||||
collective_mainloop.load_kv_tma(
|
collective_mainloop.load_kv_tma(mainloop_params,
|
||||||
mainloop_params,
|
pipeline_kv,
|
||||||
pipeline_kv,
|
smem_pipe_write_kv,
|
||||||
smem_pipe_write_kv,
|
shared_storage,
|
||||||
shared_storage,
|
bid,
|
||||||
bid,
|
seq_len_decoder_now,
|
||||||
seq_len_decoder_now,
|
tile_id);
|
||||||
tile_id
|
|
||||||
);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
// consumer
|
// consumer
|
||||||
if constexpr(USE_REG_EALLOC) {
|
if constexpr (USE_REG_EALLOC) {
|
||||||
cutlass::arch::warpgroup_reg_alloc<216>();
|
cutlass::arch::warpgroup_reg_alloc<216>();
|
||||||
}
|
}
|
||||||
PipelineStateQ smem_pipe_read_q;
|
PipelineStateQ smem_pipe_read_q;
|
||||||
PipelineState smem_pipe_read_kv;
|
PipelineState smem_pipe_read_kv;
|
||||||
|
|
||||||
typename Ktraits::TiledMmaPVSS tiled_mma_pv;
|
typename Ktraits::TiledMmaPVSS tiled_mma_pv;
|
||||||
Tensor tOrO = partition_fragment_C(tiled_mma_pv, select<0, 1>(TileShape_PDV{}));
|
Tensor tOrO =
|
||||||
|
partition_fragment_C(tiled_mma_pv, select<0, 1>(TileShape_PDV{}));
|
||||||
|
|
||||||
auto attention_updater = OnlineSoftmax<2 * size<1>(tOrO), /*WITH_SCALE=*/true>(mainloop_params.sm_scale);
|
auto attention_updater =
|
||||||
|
OnlineSoftmax<2 * size<1>(tOrO), /*WITH_SCALE=*/true>(
|
||||||
|
mainloop_params.sm_scale);
|
||||||
for (int i = blockIdx.x; i < num_blocks_x; i += SM_COUNT) {
|
for (int i = blockIdx.x; i < num_blocks_x; i += SM_COUNT) {
|
||||||
clear(tOrO);
|
clear(tOrO);
|
||||||
clear(attention_updater.scores_scale);
|
clear(attention_updater.scores_scale);
|
||||||
const int bid = mainloop_params.batch_ids[i];
|
const int bid = mainloop_params.batch_ids[i];
|
||||||
const int tile_id = mainloop_params.tile_ids_per_batch[i];
|
const int tile_id = mainloop_params.tile_ids_per_batch[i];
|
||||||
const int seq_len_now = mainloop_params.seq_lens_this_time[bid];
|
const int seq_len_now = mainloop_params.seq_lens_this_time[bid];
|
||||||
const int seq_len_decoder_now = mainloop_params.seq_lens_decoder[bid] + seq_len_now;
|
const int seq_len_decoder_now =
|
||||||
|
mainloop_params.seq_lens_decoder[bid] + seq_len_now;
|
||||||
const int start_token_idx = mainloop_params.cumsum_q_seqlens[bid];
|
const int start_token_idx = mainloop_params.cumsum_q_seqlens[bid];
|
||||||
cutlass::arch::NamedBarrier::sync(Ktraits::NUM_THREADS,
|
cutlass::arch::NamedBarrier::sync(
|
||||||
/*id=*/static_cast<int>(NamedBarriers::kWG0WG1WG2Sync));
|
Ktraits::NUM_THREADS,
|
||||||
|
/*id=*/static_cast<int>(NamedBarriers::kWG0WG1WG2Sync));
|
||||||
|
|
||||||
if constexpr (BLOCK_SHAPE_KV == 64) {
|
if constexpr (BLOCK_SHAPE_KV == 64) {
|
||||||
mma_f16<Ktraits, CAUSAL>(
|
mma_f16<Ktraits, CAUSAL>(mainloop_params,
|
||||||
mainloop_params,
|
pipeline_q,
|
||||||
pipeline_q,
|
smem_pipe_read_q,
|
||||||
smem_pipe_read_q,
|
pipeline_kv,
|
||||||
pipeline_kv,
|
smem_pipe_read_kv,
|
||||||
smem_pipe_read_kv,
|
tOrO,
|
||||||
tOrO,
|
attention_updater,
|
||||||
attention_updater,
|
threadIdx.x - NUM_COPY_THREADS,
|
||||||
threadIdx.x - NUM_COPY_THREADS,
|
bid,
|
||||||
bid,
|
seq_len_decoder_now,
|
||||||
seq_len_decoder_now,
|
seq_len_now,
|
||||||
seq_len_now,
|
tile_id,
|
||||||
tile_id,
|
shared_storage);
|
||||||
shared_storage);
|
|
||||||
} else if (BLOCK_SHAPE_KV == 32) {
|
} else if (BLOCK_SHAPE_KV == 32) {
|
||||||
mma_f16_two_stages<Ktraits, CAUSAL>(
|
mma_f16_two_stages<Ktraits, CAUSAL>(mainloop_params,
|
||||||
mainloop_params,
|
pipeline_q,
|
||||||
pipeline_q,
|
smem_pipe_read_q,
|
||||||
smem_pipe_read_q,
|
pipeline_kv,
|
||||||
pipeline_kv,
|
smem_pipe_read_kv,
|
||||||
smem_pipe_read_kv,
|
tOrO,
|
||||||
tOrO,
|
attention_updater,
|
||||||
attention_updater,
|
threadIdx.x - NUM_COPY_THREADS,
|
||||||
threadIdx.x - NUM_COPY_THREADS,
|
bid,
|
||||||
bid,
|
seq_len_decoder_now,
|
||||||
seq_len_decoder_now,
|
seq_len_now,
|
||||||
seq_len_now,
|
tile_id,
|
||||||
tile_id,
|
shared_storage);
|
||||||
shared_storage);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
collective_epilogue.store(
|
collective_epilogue.store(epilogue_params,
|
||||||
epilogue_params,
|
tOrO,
|
||||||
tOrO,
|
attention_updater.get_lse(),
|
||||||
attention_updater.get_lse(),
|
shared_storage,
|
||||||
shared_storage,
|
tiled_mma_pv,
|
||||||
tiled_mma_pv,
|
threadIdx.x - NUM_COPY_THREADS,
|
||||||
threadIdx.x - NUM_COPY_THREADS,
|
bid,
|
||||||
bid,
|
mainloop_params.bsz,
|
||||||
mainloop_params.bsz,
|
seq_len_now,
|
||||||
seq_len_now,
|
start_token_idx,
|
||||||
start_token_idx,
|
tile_id,
|
||||||
tile_id,
|
seq_len_decoder_now,
|
||||||
seq_len_decoder_now,
|
chunk_size,
|
||||||
chunk_size,
|
mainloop_params.max_draft_token_num,
|
||||||
mainloop_params.max_draft_token_num,
|
mainloop_params.o_stride_bsz);
|
||||||
mainloop_params.o_stride_bsz);
|
}
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <typename KernelTraits,
|
||||||
template <typename KernelTraits, bool CAUSAL, typename Params, bool USE_REG_EALLOC=false, bool USE_FIXED_BLOCK=true>
|
bool CAUSAL,
|
||||||
cudaError_t BatchMLAWithPagedKVCacheKernelTraitsDispatched(Params& params,
|
typename Params,
|
||||||
cudaStream_t stream) {
|
bool USE_REG_EALLOC = false,
|
||||||
|
bool USE_FIXED_BLOCK = true>
|
||||||
|
cudaError_t BatchMLAWithPagedKVCacheKernelTraitsDispatched(
|
||||||
|
Params &params, cudaStream_t stream) {
|
||||||
using DTypeQ = typename KernelTraits::DTypeQ;
|
using DTypeQ = typename KernelTraits::DTypeQ;
|
||||||
using DTypeKV = typename KernelTraits::DTypeKV;
|
using DTypeKV = typename KernelTraits::DTypeKV;
|
||||||
using DTypeO = typename KernelTraits::DTypeO;
|
using DTypeO = typename KernelTraits::DTypeO;
|
||||||
using IdType = typename KernelTraits::IdType;
|
using IdType = typename KernelTraits::IdType;
|
||||||
using NV_TYPE = typename KernelTraits::NV_TYPE;
|
using NV_TYPE = typename KernelTraits::NV_TYPE;
|
||||||
|
|
||||||
using CollectiveMainloop =
|
using CollectiveMainloop = CollectiveMainloop<KernelTraits, CAUSAL>;
|
||||||
CollectiveMainloop<KernelTraits, CAUSAL>;
|
|
||||||
using CollectiveEpilogue = CollectiveEpilogue<KernelTraits>;
|
using CollectiveEpilogue = CollectiveEpilogue<KernelTraits>;
|
||||||
|
|
||||||
typename CollectiveMainloop::Params mainloop_params = CollectiveMainloop::to_underlying_arguments({
|
typename CollectiveMainloop::Params mainloop_params =
|
||||||
make_layout(make_shape(KernelTraits::BLOCK_SHAPE_Q, params.qk_head_dim), make_stride(params.qk_head_dim, _1{})), // layout q
|
CollectiveMainloop::to_underlying_arguments(
|
||||||
make_layout(make_shape(params.block_size, params.qk_head_dim, params.max_block_num), make_stride(params.qk_head_dim, _1{}, params.block_size * params.qk_head_dim)),
|
{make_layout(
|
||||||
make_layout(make_shape(params.chunk_num, params.bsz * params.max_draft_token_num * params.q_num_head), make_stride(params.bsz * params.max_draft_token_num * params.q_num_head, _1{})),
|
make_shape(KernelTraits::BLOCK_SHAPE_Q, params.qk_head_dim),
|
||||||
params.Q,
|
make_stride(params.qk_head_dim, _1{})), // layout q
|
||||||
params.KV,
|
make_layout(
|
||||||
params.m,
|
make_shape(
|
||||||
params.d,
|
params.block_size, params.qk_head_dim, params.max_block_num),
|
||||||
params.block_tables,
|
make_stride(params.qk_head_dim,
|
||||||
params.seq_lens_this_time,
|
_1{},
|
||||||
params.seq_lens_decoder,
|
params.block_size * params.qk_head_dim)),
|
||||||
params.cumsum_q_seqlens,
|
make_layout(make_shape(params.chunk_num,
|
||||||
params.batch_ids,
|
params.bsz * params.max_draft_token_num *
|
||||||
params.tile_ids_per_batch,
|
params.q_num_head),
|
||||||
params.num_blocks_x,
|
make_stride(params.bsz * params.max_draft_token_num *
|
||||||
params.chunk_size_device,
|
params.q_num_head,
|
||||||
params.sm_scale,
|
_1{})),
|
||||||
params.bsz,
|
params.Q,
|
||||||
params.max_block_num,
|
params.KV,
|
||||||
params.max_block_num_per_seq,
|
params.m,
|
||||||
params.q_stride_bsz,
|
params.d,
|
||||||
params.q_stride_head_num,
|
params.block_tables,
|
||||||
params.kv_stride_block_num,
|
params.seq_lens_this_time,
|
||||||
params.kv_stride_block_size,
|
params.seq_lens_decoder,
|
||||||
params.o_stride_bsz,
|
params.cumsum_q_seqlens,
|
||||||
params.o_stride_head_num,
|
params.batch_ids,
|
||||||
params.chunk_num,
|
params.tile_ids_per_batch,
|
||||||
params.max_draft_token_num
|
params.num_blocks_x,
|
||||||
});
|
params.chunk_size_device,
|
||||||
typename CollectiveEpilogue::Params epilogue_params = CollectiveEpilogue::to_underlying_arguments_ntma({
|
params.sm_scale,
|
||||||
params.O,
|
params.bsz,
|
||||||
make_layout(make_shape(KernelTraits::BLOCK_SHAPE_Q, params.vo_head_dim), make_stride(params.vo_head_dim, _1{})), // layout O
|
params.max_block_num,
|
||||||
params.O_tmp,
|
params.max_block_num_per_seq,
|
||||||
make_layout(make_shape(KernelTraits::BLOCK_SHAPE_Q, params.vo_head_dim), make_stride(params.vo_head_dim, _1{})) // layout O_tmp
|
params.q_stride_bsz,
|
||||||
});
|
params.q_stride_head_num,
|
||||||
|
params.kv_stride_block_num,
|
||||||
|
params.kv_stride_block_size,
|
||||||
|
params.o_stride_bsz,
|
||||||
|
params.o_stride_head_num,
|
||||||
|
params.chunk_num,
|
||||||
|
params.max_draft_token_num});
|
||||||
|
typename CollectiveEpilogue::Params epilogue_params =
|
||||||
|
CollectiveEpilogue::to_underlying_arguments_ntma({
|
||||||
|
params.O,
|
||||||
|
make_layout(
|
||||||
|
make_shape(KernelTraits::BLOCK_SHAPE_Q, params.vo_head_dim),
|
||||||
|
make_stride(params.vo_head_dim, _1{})), // layout O
|
||||||
|
params.O_tmp,
|
||||||
|
make_layout(
|
||||||
|
make_shape(KernelTraits::BLOCK_SHAPE_Q, params.vo_head_dim),
|
||||||
|
make_stride(params.vo_head_dim, _1{})) // layout O_tmp
|
||||||
|
});
|
||||||
|
|
||||||
// Get the ptr to kernel function.
|
// Get the ptr to kernel function.
|
||||||
auto kernel =
|
auto kernel = MLAWithKVCacheKernel<CollectiveMainloop,
|
||||||
MLAWithKVCacheKernel<CollectiveMainloop, CollectiveEpilogue, KernelTraits, CAUSAL, 132>;
|
CollectiveEpilogue,
|
||||||
|
KernelTraits,
|
||||||
|
CAUSAL,
|
||||||
|
132>;
|
||||||
int smem_size = sizeof(typename KernelTraits::SharedStorage);
|
int smem_size = sizeof(typename KernelTraits::SharedStorage);
|
||||||
cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
|
cudaFuncSetAttribute(
|
||||||
|
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
|
||||||
int device;
|
int device;
|
||||||
cudaGetDevice(&device);
|
cudaGetDevice(&device);
|
||||||
int multiprocessor_count;
|
int multiprocessor_count;
|
||||||
cudaDeviceGetAttribute(&multiprocessor_count, cudaDevAttrMultiProcessorCount, device);
|
cudaDeviceGetAttribute(
|
||||||
|
&multiprocessor_count, cudaDevAttrMultiProcessorCount, device);
|
||||||
int act_blocks_per_sm;
|
int act_blocks_per_sm;
|
||||||
cudaOccupancyMaxActiveBlocksPerMultiprocessor(
|
cudaOccupancyMaxActiveBlocksPerMultiprocessor(
|
||||||
&act_blocks_per_sm, kernel, KernelTraits::NUM_WARPS * 32, smem_size);
|
&act_blocks_per_sm, kernel, KernelTraits::NUM_WARPS * 32, smem_size);
|
||||||
@@ -390,15 +437,15 @@ cudaError_t BatchMLAWithPagedKVCacheKernelTraitsDispatched(Params& params,
|
|||||||
dim3 grid_dims = {multiprocessor_count, 1, 1};
|
dim3 grid_dims = {multiprocessor_count, 1, 1};
|
||||||
static constexpr int ctaSize = KernelTraits::NUM_WARPS * 32;
|
static constexpr int ctaSize = KernelTraits::NUM_WARPS * 32;
|
||||||
dim3 block_dims(ctaSize, 1, 1);
|
dim3 block_dims(ctaSize, 1, 1);
|
||||||
kernel<<<grid_dims, block_dims, smem_size, stream>>>(
|
kernel<<<grid_dims, block_dims, smem_size, stream>>>(mainloop_params,
|
||||||
mainloop_params, epilogue_params
|
epilogue_params);
|
||||||
);
|
|
||||||
if (params.chunk_num > 1) {
|
if (params.chunk_num > 1) {
|
||||||
constexpr int vec_size = 16 / sizeof(DTypeO);
|
constexpr int vec_size = 16 / sizeof(DTypeO);
|
||||||
constexpr int merge_block_size = 256;
|
constexpr int merge_block_size = 256;
|
||||||
constexpr int blockx = KernelTraits::HEAD_DIM_VO / vec_size;
|
constexpr int blockx = KernelTraits::HEAD_DIM_VO / vec_size;
|
||||||
constexpr int blocky = (merge_block_size + blockx - 1) / blockx;
|
constexpr int blocky = (merge_block_size + blockx - 1) / blockx;
|
||||||
dim3 grids_merge(multiprocessor_count, params.q_num_head); // 128k is too large
|
dim3 grids_merge(multiprocessor_count,
|
||||||
|
params.q_num_head); // 128k is too large
|
||||||
dim3 blocks_merge(blockx, blocky);
|
dim3 blocks_merge(blockx, blocky);
|
||||||
merge_multi_chunks_kernel<NV_TYPE,
|
merge_multi_chunks_kernel<NV_TYPE,
|
||||||
vec_size,
|
vec_size,
|
||||||
@@ -423,28 +470,35 @@ cudaError_t BatchMLAWithPagedKVCacheKernelTraitsDispatched(Params& params,
|
|||||||
return cudaSuccess;
|
return cudaSuccess;
|
||||||
}
|
}
|
||||||
|
|
||||||
template <uint32_t HEAD_DIM_QK, uint32_t HEAD_DIM_VO, typename NV_TYPE, typename Params, bool USE_REG_EALLOC=false, bool USE_FIXED_BLOCK=true>
|
template <uint32_t HEAD_DIM_QK,
|
||||||
cudaError_t BatchMLAWithPagedKVCacheDispatched(Params& params, cudaStream_t stream) {
|
uint32_t HEAD_DIM_VO,
|
||||||
|
typename NV_TYPE,
|
||||||
|
typename Params,
|
||||||
|
bool USE_REG_EALLOC = false,
|
||||||
|
bool USE_FIXED_BLOCK = true>
|
||||||
|
cudaError_t BatchMLAWithPagedKVCacheDispatched(Params &params,
|
||||||
|
cudaStream_t stream) {
|
||||||
constexpr bool CAUSAL = true;
|
constexpr bool CAUSAL = true;
|
||||||
if constexpr (HEAD_DIM_QK == 576) {
|
if constexpr (HEAD_DIM_QK == 576) {
|
||||||
DISPATCH_GROUP_SIZE(params.q_num_head, GROUP_SIZE,
|
DISPATCH_GROUP_SIZE(params.q_num_head,
|
||||||
BatchMLAWithPagedKVCacheKernelTraitsDispatched<
|
GROUP_SIZE,
|
||||||
AttentionKernelTraits</*USE_TMA_LOAD_KV=*/true,
|
BatchMLAWithPagedKVCacheKernelTraitsDispatched<
|
||||||
HEAD_DIM_QK,
|
AttentionKernelTraits</*USE_TMA_LOAD_KV=*/true,
|
||||||
HEAD_DIM_VO,
|
HEAD_DIM_QK,
|
||||||
GROUP_SIZE,
|
HEAD_DIM_VO,
|
||||||
/*BLOCK_SHAPE_Q_=*/64,
|
GROUP_SIZE,
|
||||||
/*BLOCK_SHAPE_KV_=*/64,
|
/*BLOCK_SHAPE_Q_=*/64,
|
||||||
/*NUM_STAGES_=*/2,
|
/*BLOCK_SHAPE_KV_=*/64,
|
||||||
typename Params::DTypeQ,
|
/*NUM_STAGES_=*/2,
|
||||||
typename Params::DTypeKV,
|
typename Params::DTypeQ,
|
||||||
typename Params::DTypeO,
|
typename Params::DTypeKV,
|
||||||
typename Params::IdType,
|
typename Params::DTypeO,
|
||||||
NV_TYPE>,
|
typename Params::IdType,
|
||||||
CAUSAL,
|
NV_TYPE>,
|
||||||
Params,
|
CAUSAL,
|
||||||
USE_REG_EALLOC,
|
Params,
|
||||||
USE_FIXED_BLOCK>(params, stream);)
|
USE_REG_EALLOC,
|
||||||
|
USE_FIXED_BLOCK>(params, stream);)
|
||||||
} else {
|
} else {
|
||||||
return cudaErrorNotSupported;
|
return cudaErrorNotSupported;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
pkill -9 -f python
|
pkill -9 -f python
|
||||||
pkill -9 -f fastdeploy
|
pkill -9 -f fastdeploy
|
||||||
pkill -9 -f gunicorn
|
pkill -9 -f gunicorn
|
||||||
pkill -9 -f redis-server
|
# Kill redis-server if you need.
|
||||||
|
#pkill -9 -f redis-server
|
||||||
|
|
||||||
sleep 1
|
sleep 1
|
||||||
|
|||||||
@@ -159,16 +159,22 @@ class CacheMessager:
|
|||||||
cache_v = []
|
cache_v = []
|
||||||
self.messager = {}
|
self.messager = {}
|
||||||
for layer_idx in range(self.num_layers):
|
for layer_idx in range(self.num_layers):
|
||||||
|
# value cache
|
||||||
|
val_cache_key = f"value_caches_{layer_idx}_rank{self.rank}_device{gpu_id}"
|
||||||
|
if val_cache_key in self.gpu_cache_kvs:
|
||||||
|
val_cache = self.gpu_cache_kvs[val_cache_key]
|
||||||
|
cache_v.append(val_cache)
|
||||||
|
if paddle.is_compiled_with_xpu():
|
||||||
|
cache_v_ptr_list.append(get_peer_mem_addr(val_cache.data_ptr()))
|
||||||
|
else:
|
||||||
|
cache_v_ptr_list.append(val_cache.data_ptr())
|
||||||
|
# key cache
|
||||||
key_cache = self.gpu_cache_kvs[f"key_caches_{layer_idx}_rank{self.rank}_device{gpu_id}"]
|
key_cache = self.gpu_cache_kvs[f"key_caches_{layer_idx}_rank{self.rank}_device{gpu_id}"]
|
||||||
val_cache = self.gpu_cache_kvs[f"value_caches_{layer_idx}_rank{self.rank}_device{gpu_id}"]
|
|
||||||
cache_k.append(key_cache)
|
cache_k.append(key_cache)
|
||||||
cache_v.append(val_cache)
|
|
||||||
if paddle.is_compiled_with_xpu():
|
if paddle.is_compiled_with_xpu():
|
||||||
cache_k_ptr_list.append(get_peer_mem_addr(key_cache.data_ptr()))
|
cache_k_ptr_list.append(get_peer_mem_addr(key_cache.data_ptr()))
|
||||||
cache_v_ptr_list.append(get_peer_mem_addr(val_cache.data_ptr()))
|
|
||||||
else:
|
else:
|
||||||
cache_k_ptr_list.append(key_cache.data_ptr())
|
cache_k_ptr_list.append(key_cache.data_ptr())
|
||||||
cache_v_ptr_list.append(val_cache.data_ptr())
|
|
||||||
cache_k_ptr_list = np.array(cache_k_ptr_list)
|
cache_k_ptr_list = np.array(cache_k_ptr_list)
|
||||||
cache_v_ptr_list = np.array(cache_v_ptr_list)
|
cache_v_ptr_list = np.array(cache_v_ptr_list)
|
||||||
|
|
||||||
@@ -198,7 +204,6 @@ class CacheMessager:
|
|||||||
|
|
||||||
elif protocol == "rdma":
|
elif protocol == "rdma":
|
||||||
logger.info(f"splitwise_role rdma: {self.splitwise_role}, rank: {self.rank}, gpu_id: {gpu_id}")
|
logger.info(f"splitwise_role rdma: {self.splitwise_role}, rank: {self.rank}, gpu_id: {gpu_id}")
|
||||||
|
|
||||||
self.messager[protocol] = RDMACommManager(
|
self.messager[protocol] = RDMACommManager(
|
||||||
splitwise_role,
|
splitwise_role,
|
||||||
rank,
|
rank,
|
||||||
@@ -460,16 +465,22 @@ class CacheMessagerV1:
|
|||||||
cache_v = []
|
cache_v = []
|
||||||
self.messager = {}
|
self.messager = {}
|
||||||
for layer_idx in range(self.num_layers):
|
for layer_idx in range(self.num_layers):
|
||||||
|
# value cache
|
||||||
|
val_cache_key = f"value_caches_{layer_idx}_rank{self.rank}_device{gpu_id}"
|
||||||
|
if val_cache_key in self.gpu_cache_kvs:
|
||||||
|
val_cache = self.gpu_cache_kvs[val_cache_key]
|
||||||
|
cache_v.append(val_cache)
|
||||||
|
if paddle.is_compiled_with_xpu():
|
||||||
|
cache_v_ptr_list.append(get_peer_mem_addr(val_cache.data_ptr()))
|
||||||
|
else:
|
||||||
|
cache_v_ptr_list.append(val_cache.data_ptr())
|
||||||
|
# key cache
|
||||||
key_cache = self.gpu_cache_kvs[f"key_caches_{layer_idx}_rank{self.rank}_device{gpu_id}"]
|
key_cache = self.gpu_cache_kvs[f"key_caches_{layer_idx}_rank{self.rank}_device{gpu_id}"]
|
||||||
val_cache = self.gpu_cache_kvs[f"value_caches_{layer_idx}_rank{self.rank}_device{gpu_id}"]
|
|
||||||
cache_k.append(key_cache)
|
cache_k.append(key_cache)
|
||||||
cache_v.append(val_cache)
|
|
||||||
if paddle.is_compiled_with_xpu():
|
if paddle.is_compiled_with_xpu():
|
||||||
cache_k_ptr_list.append(get_peer_mem_addr(key_cache.data_ptr()))
|
cache_k_ptr_list.append(get_peer_mem_addr(key_cache.data_ptr()))
|
||||||
cache_v_ptr_list.append(get_peer_mem_addr(val_cache.data_ptr()))
|
|
||||||
else:
|
else:
|
||||||
cache_k_ptr_list.append(key_cache.data_ptr())
|
cache_k_ptr_list.append(key_cache.data_ptr())
|
||||||
cache_v_ptr_list.append(val_cache.data_ptr())
|
|
||||||
cache_k_ptr_list = np.array(cache_k_ptr_list)
|
cache_k_ptr_list = np.array(cache_k_ptr_list)
|
||||||
cache_v_ptr_list = np.array(cache_v_ptr_list)
|
cache_v_ptr_list = np.array(cache_v_ptr_list)
|
||||||
|
|
||||||
|
|||||||
@@ -245,6 +245,15 @@ class PrefixCacheManager:
|
|||||||
log_dir = envs.FD_LOG_DIR
|
log_dir = envs.FD_LOG_DIR
|
||||||
cache_manager_processes = []
|
cache_manager_processes = []
|
||||||
visible_devices = get_all_visible_devices()
|
visible_devices = get_all_visible_devices()
|
||||||
|
|
||||||
|
val_cache_arg_str = ""
|
||||||
|
if val_cache_shape:
|
||||||
|
if isinstance(val_cache_shape, list):
|
||||||
|
val_shape_str = ",".join(map(str, val_cache_shape))
|
||||||
|
else:
|
||||||
|
val_shape_str = str(val_cache_shape)
|
||||||
|
val_cache_arg_str = f" --value_cache_shape {val_shape_str}"
|
||||||
|
|
||||||
for i in range(tensor_parallel_size):
|
for i in range(tensor_parallel_size):
|
||||||
launch_cmd = (
|
launch_cmd = (
|
||||||
"FLAGS_allocator_strategy=auto_growth "
|
"FLAGS_allocator_strategy=auto_growth "
|
||||||
@@ -259,7 +268,7 @@ class PrefixCacheManager:
|
|||||||
+ f" --mp_num {tensor_parallel_size}"
|
+ f" --mp_num {tensor_parallel_size}"
|
||||||
+ f" --cache_dtype {cache_config.cache_dtype}"
|
+ f" --cache_dtype {cache_config.cache_dtype}"
|
||||||
+ f" --key_cache_shape {key_cache_shape}"
|
+ f" --key_cache_shape {key_cache_shape}"
|
||||||
+ f" --value_cache_shape {val_cache_shape}"
|
+ val_cache_arg_str
|
||||||
+ f" --cache_queue_port {cache_config.cache_queue_port}"
|
+ f" --cache_queue_port {cache_config.cache_queue_port}"
|
||||||
+ f" --enable_splitwise {int(self.enable_splitwise)}"
|
+ f" --enable_splitwise {int(self.enable_splitwise)}"
|
||||||
+ f" --pod_ip {pod_ip}"
|
+ f" --pod_ip {pod_ip}"
|
||||||
@@ -332,6 +341,15 @@ class PrefixCacheManager:
|
|||||||
log_dir = envs.FD_LOG_DIR
|
log_dir = envs.FD_LOG_DIR
|
||||||
cache_messager_processes = []
|
cache_messager_processes = []
|
||||||
visible_devices = get_all_visible_devices()
|
visible_devices = get_all_visible_devices()
|
||||||
|
|
||||||
|
val_cache_arg_str = ""
|
||||||
|
if value_cache_shape:
|
||||||
|
if isinstance(value_cache_shape, list):
|
||||||
|
val_shape_str = ",".join(map(str, value_cache_shape))
|
||||||
|
else:
|
||||||
|
val_shape_str = str(value_cache_shape)
|
||||||
|
val_cache_arg_str = f" --value_cache_shape {val_shape_str}"
|
||||||
|
|
||||||
for i in range(tensor_parallel_size):
|
for i in range(tensor_parallel_size):
|
||||||
launch_cmd = (
|
launch_cmd = (
|
||||||
"FLAGS_allocator_strategy=auto_growth "
|
"FLAGS_allocator_strategy=auto_growth "
|
||||||
@@ -345,7 +363,7 @@ class PrefixCacheManager:
|
|||||||
+ f" --mp_num {tensor_parallel_size}"
|
+ f" --mp_num {tensor_parallel_size}"
|
||||||
+ f" --cache_dtype {cache_config.cache_dtype}"
|
+ f" --cache_dtype {cache_config.cache_dtype}"
|
||||||
+ f" --key_cache_shape {key_cache_shape}"
|
+ f" --key_cache_shape {key_cache_shape}"
|
||||||
+ f" --value_cache_shape {value_cache_shape}"
|
+ val_cache_arg_str
|
||||||
+ f" --pod_ip {pod_ip}"
|
+ f" --pod_ip {pod_ip}"
|
||||||
+ f" --cache_queue_port {cache_config.cache_queue_port}"
|
+ f" --cache_queue_port {cache_config.cache_queue_port}"
|
||||||
+ f" --engine_worker_queue_port {engine_worker_queue_port}"
|
+ f" --engine_worker_queue_port {engine_worker_queue_port}"
|
||||||
|
|||||||
@@ -198,8 +198,8 @@ int get_port_info(struct ibv_context* Context,
|
|||||||
int parse_port_ib_info();
|
int parse_port_ib_info();
|
||||||
|
|
||||||
// Memory region exchange
|
// Memory region exchange
|
||||||
bool client_exchange_mr(struct RdmaContext* ctx);
|
bool client_exchange_mr(struct RdmaContext* ctx, bool has_value_cache);
|
||||||
bool server_exchange_mr(struct RdmaContext* ctx);
|
bool server_exchange_mr(struct RdmaContext* ctx, bool has_value_cache);
|
||||||
bool server_send_memory_region(struct RdmaContext* ctx,
|
bool server_send_memory_region(struct RdmaContext* ctx,
|
||||||
void* local_mr,
|
void* local_mr,
|
||||||
int byte_num);
|
int byte_num);
|
||||||
|
|||||||
@@ -149,6 +149,7 @@ class RDMACommunicator {
|
|||||||
struct ibv_pd* g_pd = NULL; // fd
|
struct ibv_pd* g_pd = NULL; // fd
|
||||||
int RDMACommunicator_status; // Communicator status flag
|
int RDMACommunicator_status; // Communicator status flag
|
||||||
bool start_client_listener = false; // Client listener flag
|
bool start_client_listener = false; // Client listener flag
|
||||||
|
bool has_value_cache_; // MLA does not have value cache.
|
||||||
};
|
};
|
||||||
|
|
||||||
#endif // KVCACHE_RDMA_H
|
#endif // KVCACHE_RDMA_H
|
||||||
|
|||||||
@@ -712,8 +712,8 @@ bool exchange_mr_vector(struct RdmaContext *ctx,
|
|||||||
* @param ctx The RDMA context
|
* @param ctx The RDMA context
|
||||||
* @return true on success, false on failure
|
* @return true on success, false on failure
|
||||||
*/
|
*/
|
||||||
bool client_exchange_mr(struct RdmaContext *ctx) {
|
bool client_exchange_mr(struct RdmaContext *ctx, bool has_value_cache) {
|
||||||
LOGD("verb client exchange mr: start");
|
LOGD("verb client exchange mr: start. has_value_cache=%d", has_value_cache);
|
||||||
|
|
||||||
if (ctx->conn.layer_number <= 0) {
|
if (ctx->conn.layer_number <= 0) {
|
||||||
ERR("Invalid layer number: %d", ctx->conn.layer_number);
|
ERR("Invalid layer number: %d", ctx->conn.layer_number);
|
||||||
@@ -723,19 +723,27 @@ bool client_exchange_mr(struct RdmaContext *ctx) {
|
|||||||
auto layer_num = ctx->conn.layer_number;
|
auto layer_num = ctx->conn.layer_number;
|
||||||
std::vector<void *> key_ptrs(layer_num);
|
std::vector<void *> key_ptrs(layer_num);
|
||||||
std::vector<uint32_t> key_rkeys(layer_num);
|
std::vector<uint32_t> key_rkeys(layer_num);
|
||||||
std::vector<void *> val_ptrs(layer_num);
|
std::vector<void *> val_ptrs;
|
||||||
std::vector<uint32_t> val_rkeys(layer_num);
|
std::vector<uint32_t> val_rkeys;
|
||||||
|
if (has_value_cache) {
|
||||||
|
val_ptrs.resize(layer_num);
|
||||||
|
val_rkeys.resize(layer_num);
|
||||||
|
}
|
||||||
|
|
||||||
if (!exchange_mr_vector(ctx, key_ptrs, true)) return false;
|
if (!exchange_mr_vector(ctx, key_ptrs, true)) return false;
|
||||||
if (!exchange_mr_vector(ctx, key_rkeys, true)) return false;
|
if (!exchange_mr_vector(ctx, key_rkeys, true)) return false;
|
||||||
if (!exchange_mr_vector(ctx, val_ptrs, true)) return false;
|
if (has_value_cache) {
|
||||||
if (!exchange_mr_vector(ctx, val_rkeys, true)) return false;
|
if (!exchange_mr_vector(ctx, val_ptrs, true)) return false;
|
||||||
|
if (!exchange_mr_vector(ctx, val_rkeys, true)) return false;
|
||||||
|
}
|
||||||
|
|
||||||
for (int i = 0; i < layer_num; ++i) {
|
for (int i = 0; i < layer_num; ++i) {
|
||||||
ctx->conn.write_cache_key_remote_ptr_list.push_back(key_ptrs[i]);
|
ctx->conn.write_cache_key_remote_ptr_list.push_back(key_ptrs[i]);
|
||||||
ctx->conn.write_cache_key_remote_rkey_list.push_back(key_rkeys[i]);
|
ctx->conn.write_cache_key_remote_rkey_list.push_back(key_rkeys[i]);
|
||||||
ctx->conn.write_cache_value_remote_ptr_list.push_back(val_ptrs[i]);
|
if (has_value_cache) {
|
||||||
ctx->conn.write_cache_value_remote_rkey_list.push_back(val_rkeys[i]);
|
ctx->conn.write_cache_value_remote_ptr_list.push_back(val_ptrs[i]);
|
||||||
|
ctx->conn.write_cache_value_remote_rkey_list.push_back(val_rkeys[i]);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
@@ -746,8 +754,8 @@ bool client_exchange_mr(struct RdmaContext *ctx) {
|
|||||||
* @param ctx The RDMA context
|
* @param ctx The RDMA context
|
||||||
* @return true on success, false on failure
|
* @return true on success, false on failure
|
||||||
*/
|
*/
|
||||||
bool server_exchange_mr(struct RdmaContext *ctx) {
|
bool server_exchange_mr(struct RdmaContext *ctx, bool has_value_cache) {
|
||||||
LOGD("verbs server exchange mr: start");
|
LOGD("verbs server exchange mr: start. has_value_cache=%d", has_value_cache);
|
||||||
|
|
||||||
if (ctx->conn.layer_number <= 0) {
|
if (ctx->conn.layer_number <= 0) {
|
||||||
ERR("Invalid layer number: %d", ctx->conn.layer_number);
|
ERR("Invalid layer number: %d", ctx->conn.layer_number);
|
||||||
@@ -759,8 +767,16 @@ bool server_exchange_mr(struct RdmaContext *ctx) {
|
|||||||
auto &val_mrs = ctx->conn.write_cache_value_server_mr_list;
|
auto &val_mrs = ctx->conn.write_cache_value_server_mr_list;
|
||||||
|
|
||||||
// Verify that server memory regions are properly initialized
|
// Verify that server memory regions are properly initialized
|
||||||
if (key_mrs.size() != layer_num || val_mrs.size() != layer_num) {
|
if (key_mrs.size() != layer_num) {
|
||||||
ERR("server write cache memory region size error");
|
ERR("server write cache KEY memory region size error: %zu vs %d",
|
||||||
|
key_mrs.size(),
|
||||||
|
layer_num);
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
if (has_value_cache && val_mrs.size() != layer_num) {
|
||||||
|
ERR("server write cache VALUE memory region size error: %zu vs %d",
|
||||||
|
val_mrs.size(),
|
||||||
|
layer_num);
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -772,22 +788,27 @@ bool server_exchange_mr(struct RdmaContext *ctx) {
|
|||||||
|
|
||||||
send_key_ptrs.reserve(layer_num);
|
send_key_ptrs.reserve(layer_num);
|
||||||
send_key_rkeys.reserve(layer_num);
|
send_key_rkeys.reserve(layer_num);
|
||||||
send_val_ptrs.reserve(layer_num);
|
if (has_value_cache) {
|
||||||
send_val_rkeys.reserve(layer_num);
|
send_val_ptrs.reserve(layer_num);
|
||||||
|
send_val_rkeys.reserve(layer_num);
|
||||||
|
}
|
||||||
|
|
||||||
// Collect memory region information from local MRs
|
// Collect memory region information from local MRs
|
||||||
for (int i = 0; i < layer_num; ++i) {
|
for (int i = 0; i < layer_num; ++i) {
|
||||||
send_key_ptrs.push_back(reinterpret_cast<uint64_t>(key_mrs[i]->addr));
|
send_key_ptrs.push_back(reinterpret_cast<uint64_t>(key_mrs[i]->addr));
|
||||||
send_key_rkeys.push_back(key_mrs[i]->rkey);
|
send_key_rkeys.push_back(key_mrs[i]->rkey);
|
||||||
send_val_ptrs.push_back(reinterpret_cast<uint64_t>(val_mrs[i]->addr));
|
if (has_value_cache) {
|
||||||
send_val_rkeys.push_back(val_mrs[i]->rkey);
|
send_val_ptrs.push_back(reinterpret_cast<uint64_t>(val_mrs[i]->addr));
|
||||||
|
send_val_rkeys.push_back(val_mrs[i]->rkey);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Send all vectors to client
|
|
||||||
if (!exchange_mr_vector(ctx, send_key_ptrs, false)) return false;
|
if (!exchange_mr_vector(ctx, send_key_ptrs, false)) return false;
|
||||||
if (!exchange_mr_vector(ctx, send_key_rkeys, false)) return false;
|
if (!exchange_mr_vector(ctx, send_key_rkeys, false)) return false;
|
||||||
if (!exchange_mr_vector(ctx, send_val_ptrs, false)) return false;
|
if (has_value_cache) {
|
||||||
if (!exchange_mr_vector(ctx, send_val_rkeys, false)) return false;
|
if (!exchange_mr_vector(ctx, send_val_ptrs, false)) return false;
|
||||||
|
if (!exchange_mr_vector(ctx, send_val_rkeys, false)) return false;
|
||||||
|
}
|
||||||
|
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -78,6 +78,18 @@ RDMACommunicator::RDMACommunicator(std::string& role,
|
|||||||
throw std::runtime_error("Invalid layer number");
|
throw std::runtime_error("Invalid layer number");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (local_cache_value_ptr_layer_head_.empty()) {
|
||||||
|
has_value_cache_ = false;
|
||||||
|
WARN(
|
||||||
|
"Value Cache is empty (Maybe MLA Model). RDMA will run in Key-Only "
|
||||||
|
"mode.");
|
||||||
|
} else {
|
||||||
|
has_value_cache_ = true;
|
||||||
|
if (local_cache_value_ptr_layer_head_.size() != layer_number) {
|
||||||
|
throw std::runtime_error("Key and Value cache layer number mismatch!");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Step 2: Setup cache vectors and pointers
|
// Step 2: Setup cache vectors and pointers
|
||||||
resize_vectors();
|
resize_vectors();
|
||||||
assign_pointers();
|
assign_pointers();
|
||||||
@@ -100,7 +112,6 @@ RDMACommunicator::RDMACommunicator(std::string& role,
|
|||||||
});
|
});
|
||||||
server_thread.detach();
|
server_thread.detach();
|
||||||
}
|
}
|
||||||
|
|
||||||
RDMACommunicator_status = 1;
|
RDMACommunicator_status = 1;
|
||||||
INFO("RDMA communicator initialized successfully");
|
INFO("RDMA communicator initialized successfully");
|
||||||
} catch (const std::exception& e) {
|
} catch (const std::exception& e) {
|
||||||
@@ -119,7 +130,9 @@ void RDMACommunicator::resize_vectors() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
local_cache_key_ptr_per_layer.resize(layer_number);
|
local_cache_key_ptr_per_layer.resize(layer_number);
|
||||||
local_cache_value_ptr_per_layer.resize(layer_number);
|
if (has_value_cache_) {
|
||||||
|
local_cache_value_ptr_per_layer.resize(layer_number);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void RDMACommunicator::assign_pointers() {
|
void RDMACommunicator::assign_pointers() {
|
||||||
@@ -131,15 +144,19 @@ void RDMACommunicator::assign_pointers() {
|
|||||||
// Assign pointers for each layer and block
|
// Assign pointers for each layer and block
|
||||||
for (int layer_idx = 0; layer_idx < layer_number; ++layer_idx) {
|
for (int layer_idx = 0; layer_idx < layer_number; ++layer_idx) {
|
||||||
// Validate layer head pointers
|
// Validate layer head pointers
|
||||||
if (local_cache_key_ptr_layer_head_[layer_idx] == 0 ||
|
if (local_cache_key_ptr_layer_head_[layer_idx] == 0) {
|
||||||
local_cache_value_ptr_layer_head_[layer_idx] == 0) {
|
|
||||||
throw std::runtime_error("Invalid cache pointer for layer " +
|
throw std::runtime_error("Invalid cache pointer for layer " +
|
||||||
std::to_string(layer_idx));
|
std::to_string(layer_idx));
|
||||||
}
|
}
|
||||||
|
|
||||||
// Resize block vectors for current layer
|
|
||||||
local_cache_key_ptr_per_layer[layer_idx].resize(block_number);
|
local_cache_key_ptr_per_layer[layer_idx].resize(block_number);
|
||||||
local_cache_value_ptr_per_layer[layer_idx].resize(block_number);
|
|
||||||
|
if (has_value_cache_) {
|
||||||
|
if (local_cache_value_ptr_layer_head_[layer_idx] == 0) {
|
||||||
|
throw std::runtime_error("Invalid VALUE cache pointer for layer " +
|
||||||
|
std::to_string(layer_idx));
|
||||||
|
}
|
||||||
|
local_cache_value_ptr_per_layer[layer_idx].resize(block_number);
|
||||||
|
}
|
||||||
|
|
||||||
// Calculate and assign block pointers
|
// Calculate and assign block pointers
|
||||||
for (int block_idx = 0; block_idx < block_number; ++block_idx) {
|
for (int block_idx = 0; block_idx < block_number; ++block_idx) {
|
||||||
@@ -147,9 +164,12 @@ void RDMACommunicator::assign_pointers() {
|
|||||||
reinterpret_cast<void*>(local_cache_key_ptr_layer_head_[layer_idx] +
|
reinterpret_cast<void*>(local_cache_key_ptr_layer_head_[layer_idx] +
|
||||||
block_idx * block_size_byte);
|
block_idx * block_size_byte);
|
||||||
|
|
||||||
local_cache_value_ptr_per_layer[layer_idx][block_idx] =
|
if (has_value_cache_) {
|
||||||
reinterpret_cast<void*>(local_cache_value_ptr_layer_head_[layer_idx] +
|
local_cache_value_ptr_per_layer[layer_idx][block_idx] =
|
||||||
block_idx * block_size_byte);
|
reinterpret_cast<void*>(
|
||||||
|
local_cache_value_ptr_layer_head_[layer_idx] +
|
||||||
|
block_idx * block_size_byte);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -347,7 +367,7 @@ int RDMACommunicator::start_server(int sport, int sgid_idx, int gpu_index) {
|
|||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
server_exchange_mr(ctx);
|
server_exchange_mr(ctx, has_value_cache_);
|
||||||
} else {
|
} else {
|
||||||
auto ctx_iter = connectionContexts.find(event_fd);
|
auto ctx_iter = connectionContexts.find(event_fd);
|
||||||
if (ctx_iter == connectionContexts.end()) {
|
if (ctx_iter == connectionContexts.end()) {
|
||||||
@@ -435,18 +455,33 @@ bool RDMACommunicator::deregister_memory_regions(struct RdmaContext* ctx) {
|
|||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
for (int layer_idx = 0; layer_idx < layer_number; layer_idx++) {
|
if (!write_mr_key_list.empty()) {
|
||||||
if (!write_mr_key_list.empty() && !write_mr_value_list.empty()) {
|
for (int layer_idx = 0; layer_idx < layer_number; layer_idx++) {
|
||||||
if (ibv_dereg_mr(write_mr_key_list[layer_idx])) {
|
if (write_mr_key_list[layer_idx]) {
|
||||||
ERR("Failed to deregister memory region: write_mr_key_list, layer %d",
|
if (ibv_dereg_mr(write_mr_key_list[layer_idx])) {
|
||||||
layer_idx);
|
ERR("Failed to deregister memory region: write_mr_key_list, layer %d",
|
||||||
}
|
layer_idx);
|
||||||
if (ibv_dereg_mr(write_mr_value_list[layer_idx])) {
|
}
|
||||||
ERR("Failed to deregister memory region: write_mr_value_list, layer %d",
|
write_mr_key_list[layer_idx] = nullptr;
|
||||||
layer_idx);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
write_mr_key_list.clear();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (!write_mr_value_list.empty()) {
|
||||||
|
for (int layer_idx = 0; layer_idx < layer_number; layer_idx++) {
|
||||||
|
if (write_mr_value_list[layer_idx]) {
|
||||||
|
if (ibv_dereg_mr(write_mr_value_list[layer_idx])) {
|
||||||
|
ERR("Failed to deregister memory region: write_mr_value_list, layer "
|
||||||
|
"%d",
|
||||||
|
layer_idx);
|
||||||
|
}
|
||||||
|
write_mr_value_list[layer_idx] = nullptr;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
write_mr_value_list.clear();
|
||||||
|
}
|
||||||
|
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -548,7 +583,7 @@ int RDMACommunicator::connect(const std::string& dst_ip,
|
|||||||
ERR("Couldn't getexchange port infodestinations");
|
ERR("Couldn't getexchange port infodestinations");
|
||||||
return static_cast<int>(ConnStatus::kError);
|
return static_cast<int>(ConnStatus::kError);
|
||||||
} else {
|
} else {
|
||||||
client_exchange_mr(ctx);
|
client_exchange_mr(ctx, has_value_cache_);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Allocate RDMA read and register read buffers
|
// Allocate RDMA read and register read buffers
|
||||||
@@ -735,15 +770,17 @@ bool RDMACommunicator::client_mr_register_per_layer(RdmaContext* ctx) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
std::lock_guard<std::mutex> lock(mutex_);
|
std::lock_guard<std::mutex> lock(mutex_);
|
||||||
|
if (!write_mr_key_list.empty()) {
|
||||||
if (!write_mr_key_list.empty() || !write_mr_value_list.empty()) {
|
|
||||||
WARN("Memory regions already registered");
|
WARN("Memory regions already registered");
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
const size_t list_size = layer_number;
|
const size_t list_size = layer_number;
|
||||||
write_mr_key_list.resize(list_size, nullptr);
|
write_mr_key_list.resize(list_size, nullptr);
|
||||||
write_mr_value_list.resize(list_size, nullptr);
|
|
||||||
|
if (has_value_cache_) {
|
||||||
|
write_mr_value_list.resize(list_size, nullptr);
|
||||||
|
}
|
||||||
|
|
||||||
const uint32_t access_flags =
|
const uint32_t access_flags =
|
||||||
IBV_ACCESS_LOCAL_WRITE |
|
IBV_ACCESS_LOCAL_WRITE |
|
||||||
@@ -753,8 +790,6 @@ bool RDMACommunicator::client_mr_register_per_layer(RdmaContext* ctx) {
|
|||||||
|
|
||||||
for (int i = 0; i < static_cast<int>(list_size); ++i) {
|
for (int i = 0; i < static_cast<int>(list_size); ++i) {
|
||||||
void* key_ptr = reinterpret_cast<void*>(local_cache_key_ptr_layer_head_[i]);
|
void* key_ptr = reinterpret_cast<void*>(local_cache_key_ptr_layer_head_[i]);
|
||||||
void* val_ptr =
|
|
||||||
reinterpret_cast<void*>(local_cache_value_ptr_layer_head_[i]);
|
|
||||||
size_t size = static_cast<size_t>(block_size_byte) * block_number;
|
size_t size = static_cast<size_t>(block_size_byte) * block_number;
|
||||||
|
|
||||||
write_mr_key_list[i] =
|
write_mr_key_list[i] =
|
||||||
@@ -765,13 +800,18 @@ bool RDMACommunicator::client_mr_register_per_layer(RdmaContext* ctx) {
|
|||||||
access_flags);
|
access_flags);
|
||||||
if (!write_mr_key_list[i]) goto fail;
|
if (!write_mr_key_list[i]) goto fail;
|
||||||
|
|
||||||
write_mr_value_list[i] =
|
if (has_value_cache_) {
|
||||||
register_memory_region(ctx->pd,
|
void* val_ptr =
|
||||||
val_ptr,
|
reinterpret_cast<void*>(local_cache_value_ptr_layer_head_[i]);
|
||||||
size,
|
|
||||||
"client_value_" + std::to_string(i),
|
write_mr_value_list[i] =
|
||||||
access_flags);
|
register_memory_region(ctx->pd,
|
||||||
if (!write_mr_value_list[i]) goto fail;
|
val_ptr,
|
||||||
|
size,
|
||||||
|
"client_value_" + std::to_string(i),
|
||||||
|
access_flags);
|
||||||
|
if (!write_mr_value_list[i]) goto fail;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return true;
|
return true;
|
||||||
@@ -812,8 +852,6 @@ bool RDMACommunicator::server_mr_register_per_layer(RdmaContext* ctx) {
|
|||||||
|
|
||||||
for (int i = 0; i < layer_number; ++i) {
|
for (int i = 0; i < layer_number; ++i) {
|
||||||
void* key_ptr = reinterpret_cast<void*>(local_cache_key_ptr_layer_head_[i]);
|
void* key_ptr = reinterpret_cast<void*>(local_cache_key_ptr_layer_head_[i]);
|
||||||
void* val_ptr =
|
|
||||||
reinterpret_cast<void*>(local_cache_value_ptr_layer_head_[i]);
|
|
||||||
size_t size = static_cast<size_t>(block_size_byte) * block_number;
|
size_t size = static_cast<size_t>(block_size_byte) * block_number;
|
||||||
|
|
||||||
struct ibv_mr* key_mr = register_memory_region(
|
struct ibv_mr* key_mr = register_memory_region(
|
||||||
@@ -822,21 +860,25 @@ bool RDMACommunicator::server_mr_register_per_layer(RdmaContext* ctx) {
|
|||||||
ERR("Failed to register key MR at layer %d", i);
|
ERR("Failed to register key MR at layer %d", i);
|
||||||
goto fail;
|
goto fail;
|
||||||
}
|
}
|
||||||
|
|
||||||
struct ibv_mr* value_mr = register_memory_region(
|
|
||||||
ctx->pd, val_ptr, size, "value_" + std::to_string(i), access_flags);
|
|
||||||
if (!value_mr) {
|
|
||||||
ERR("Failed to register value MR at layer %d", i);
|
|
||||||
ibv_dereg_mr(key_mr);
|
|
||||||
goto fail;
|
|
||||||
}
|
|
||||||
|
|
||||||
write_cache_key_server_mr_list.push_back(key_mr);
|
write_cache_key_server_mr_list.push_back(key_mr);
|
||||||
write_cache_value_server_mr_list.push_back(value_mr);
|
|
||||||
|
if (has_value_cache_) {
|
||||||
|
void* val_ptr =
|
||||||
|
reinterpret_cast<void*>(local_cache_value_ptr_layer_head_[i]);
|
||||||
|
struct ibv_mr* value_mr = register_memory_region(
|
||||||
|
ctx->pd, val_ptr, size, "value_" + std::to_string(i), access_flags);
|
||||||
|
if (!value_mr) {
|
||||||
|
ERR("Failed to register value MR at layer %d", i);
|
||||||
|
ibv_dereg_mr(key_mr);
|
||||||
|
goto fail;
|
||||||
|
}
|
||||||
|
write_cache_value_server_mr_list.push_back(value_mr);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx->conn.write_cache_key_server_mr_list = write_cache_key_server_mr_list;
|
ctx->conn.write_cache_key_server_mr_list = write_cache_key_server_mr_list;
|
||||||
ctx->conn.write_cache_value_server_mr_list = write_cache_value_server_mr_list;
|
ctx->conn.write_cache_value_server_mr_list = write_cache_value_server_mr_list;
|
||||||
|
|
||||||
return true;
|
return true;
|
||||||
|
|
||||||
fail:
|
fail:
|
||||||
@@ -899,8 +941,12 @@ int RDMACommunicator::write_cache(const std::string& ip,
|
|||||||
|
|
||||||
uint32_t cache_key_rkey =
|
uint32_t cache_key_rkey =
|
||||||
ctx->conn.write_cache_key_remote_rkey_list[layer_idx];
|
ctx->conn.write_cache_key_remote_rkey_list[layer_idx];
|
||||||
uint32_t cache_value_rkey =
|
|
||||||
ctx->conn.write_cache_value_remote_rkey_list[layer_idx];
|
uint32_t cache_value_rkey = 0;
|
||||||
|
if (has_value_cache_) {
|
||||||
|
cache_value_rkey = ctx->conn.write_cache_value_remote_rkey_list[layer_idx];
|
||||||
|
}
|
||||||
|
|
||||||
uint32_t crc_cache_key_rkey, crc_cache_value_rkey;
|
uint32_t crc_cache_key_rkey, crc_cache_value_rkey;
|
||||||
bool pd_tp_size_is_same = prefill_tp_size == ctx->conn.decode_tp_size;
|
bool pd_tp_size_is_same = prefill_tp_size == ctx->conn.decode_tp_size;
|
||||||
uint64_t offset_in_block =
|
uint64_t offset_in_block =
|
||||||
@@ -914,15 +960,19 @@ int RDMACommunicator::write_cache(const std::string& ip,
|
|||||||
cache_key_remote_addr[block_index] = (uint64_t(
|
cache_key_remote_addr[block_index] = (uint64_t(
|
||||||
char_ptr + remote_block_ids[block_index] * total_block_size_byte +
|
char_ptr + remote_block_ids[block_index] * total_block_size_byte +
|
||||||
offset_in_block));
|
offset_in_block));
|
||||||
char_ptr = static_cast<char*>(
|
|
||||||
ctx->conn.write_cache_value_remote_ptr_list[layer_idx]);
|
if (has_value_cache_) {
|
||||||
cache_value_remote_addr[block_index] = (uint64_t(
|
char_ptr = static_cast<char*>(
|
||||||
char_ptr + remote_block_ids[block_index] * total_block_size_byte +
|
ctx->conn.write_cache_value_remote_ptr_list[layer_idx]);
|
||||||
offset_in_block));
|
cache_value_remote_addr[block_index] = (uint64_t(
|
||||||
|
char_ptr + remote_block_ids[block_index] * total_block_size_byte +
|
||||||
|
offset_in_block));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx->conn.wc_target_count = 0;
|
ctx->conn.wc_target_count = 0;
|
||||||
for (int i = 0; i < 2; ++i) {
|
int loop_count = has_value_cache_ ? 2 : 1;
|
||||||
|
for (int i = 0; i < loop_count; ++i) {
|
||||||
bool is_key = (i == 0);
|
bool is_key = (i == 0);
|
||||||
uint32_t rkey = (is_key ? cache_key_rkey : cache_value_rkey);
|
uint32_t rkey = (is_key ? cache_key_rkey : cache_value_rkey);
|
||||||
std::vector<uint64_t>& remote_addr =
|
std::vector<uint64_t>& remote_addr =
|
||||||
@@ -1038,6 +1088,10 @@ void RDMACommunicator::prepare_write_requests(
|
|||||||
bool is_key,
|
bool is_key,
|
||||||
std::vector<uint64_t>& remote_addr,
|
std::vector<uint64_t>& remote_addr,
|
||||||
uint32_t rkey) {
|
uint32_t rkey) {
|
||||||
|
if (!is_key) {
|
||||||
|
assert(!write_mr_value_list.empty() &&
|
||||||
|
"Trying to process Value Cache but it is empty!");
|
||||||
|
}
|
||||||
auto block_num = local_block_ids.size();
|
auto block_num = local_block_ids.size();
|
||||||
|
|
||||||
for (size_t i = 0; i < block_num; ++i) {
|
for (size_t i = 0; i < block_num; ++i) {
|
||||||
|
|||||||
@@ -40,11 +40,10 @@ class RDMACommManager:
|
|||||||
try:
|
try:
|
||||||
import rdma_comm
|
import rdma_comm
|
||||||
except:
|
except:
|
||||||
logger.error(
|
raise RuntimeError(
|
||||||
"The installation of the RDMA library failed."
|
"The installation of the RDMA library failed."
|
||||||
"Confirm whether your network card supports RDMA transmission."
|
"Confirm whether your network card supports RDMA transmission."
|
||||||
)
|
)
|
||||||
return
|
|
||||||
self.messager = rdma_comm.RDMACommunicator(
|
self.messager = rdma_comm.RDMACommunicator(
|
||||||
splitwise_role,
|
splitwise_role,
|
||||||
gpu_id,
|
gpu_id,
|
||||||
|
|||||||
@@ -755,8 +755,9 @@ class LLMEngine:
|
|||||||
local_data_parallel_size=self.cfg.parallel_config.data_parallel_size,
|
local_data_parallel_size=self.cfg.parallel_config.data_parallel_size,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
ctx = multiprocessing.get_context("spawn")
|
||||||
self.dp_processed.append(
|
self.dp_processed.append(
|
||||||
multiprocessing.Process(
|
ctx.Process(
|
||||||
target=start_data_parallel_service,
|
target=start_data_parallel_service,
|
||||||
args=(
|
args=(
|
||||||
self.cfg,
|
self.cfg,
|
||||||
|
|||||||
@@ -205,7 +205,6 @@ class MLAAttentionBackend(AttentionBackend):
|
|||||||
self.group_size,
|
self.group_size,
|
||||||
self.block_size,
|
self.block_size,
|
||||||
)
|
)
|
||||||
|
|
||||||
# MLA
|
# MLA
|
||||||
metadata.max_enc_len_this_time = forward_meta.max_len_tensor_cpu[1]
|
metadata.max_enc_len_this_time = forward_meta.max_len_tensor_cpu[1]
|
||||||
metadata.max_dec_len_this_time = forward_meta.max_len_tensor_cpu[2]
|
metadata.max_dec_len_this_time = forward_meta.max_len_tensor_cpu[2]
|
||||||
@@ -279,6 +278,7 @@ class MLAAttentionBackend(AttentionBackend):
|
|||||||
forward_meta.batch_id_per_token,
|
forward_meta.batch_id_per_token,
|
||||||
forward_meta.cu_seqlens_q,
|
forward_meta.cu_seqlens_q,
|
||||||
metadata.block_tables,
|
metadata.block_tables,
|
||||||
|
metadata.kv_signal_data_list[layer.layer_id],
|
||||||
"none",
|
"none",
|
||||||
getattr(forward_meta, "max_input_length", -1),
|
getattr(forward_meta, "max_input_length", -1),
|
||||||
)
|
)
|
||||||
@@ -422,10 +422,10 @@ class MLAAttentionBackend(AttentionBackend):
|
|||||||
forward_meta.batch_id_per_token,
|
forward_meta.batch_id_per_token,
|
||||||
forward_meta.cu_seqlens_q,
|
forward_meta.cu_seqlens_q,
|
||||||
metadata.block_tables,
|
metadata.block_tables,
|
||||||
|
metadata.kv_signal_data_list[layer.layer_id],
|
||||||
"none",
|
"none",
|
||||||
self.max_seq_len,
|
self.max_seq_len,
|
||||||
)
|
)
|
||||||
|
|
||||||
# FA
|
# FA
|
||||||
fmha_out = self.flash_attn_func(
|
fmha_out = self.flash_attn_func(
|
||||||
q,
|
q,
|
||||||
|
|||||||
@@ -307,6 +307,7 @@ class MetaxMLAAttentionBackend(AttentionBackend):
|
|||||||
forward_meta.batch_id_per_token,
|
forward_meta.batch_id_per_token,
|
||||||
forward_meta.cu_seqlens_q,
|
forward_meta.cu_seqlens_q,
|
||||||
metadata.block_tables,
|
metadata.block_tables,
|
||||||
|
metadata.kv_signal_data_list[layer.layer_id],
|
||||||
"none",
|
"none",
|
||||||
getattr(forward_meta, "max_input_length", -1),
|
getattr(forward_meta, "max_input_length", -1),
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -258,10 +258,10 @@ class FusedMoE(nn.Layer):
|
|||||||
else:
|
else:
|
||||||
SHARD_ID_TO_SHARDED_DIM = {"gate": 0, "down": 1, "up": 0}
|
SHARD_ID_TO_SHARDED_DIM = {"gate": 0, "down": 1, "up": 0}
|
||||||
|
|
||||||
if not param._is_initialized():
|
|
||||||
param.initialize()
|
|
||||||
if not (expert_id - self.expert_id_offset >= 0 and expert_id - self.expert_id_offset < self.num_local_experts):
|
if not (expert_id - self.expert_id_offset >= 0 and expert_id - self.expert_id_offset < self.num_local_experts):
|
||||||
return
|
return
|
||||||
|
if not param._is_initialized():
|
||||||
|
param.initialize()
|
||||||
weight_need_transpose = getattr(param, "weight_need_transpose", False)
|
weight_need_transpose = getattr(param, "weight_need_transpose", False)
|
||||||
if shard_id is None:
|
if shard_id is None:
|
||||||
# 1.gate up fused in disk
|
# 1.gate up fused in disk
|
||||||
|
|||||||
@@ -341,6 +341,7 @@ class DeepseekV3MLAAttention(nn.Layer):
|
|||||||
|
|
||||||
# NOTE: (changwenbin) qkv_a_proj horizontal fusion
|
# NOTE: (changwenbin) qkv_a_proj horizontal fusion
|
||||||
qkv_a_out = self.qkv_a_proj_with_mqa(hidden_states)
|
qkv_a_out = self.qkv_a_proj_with_mqa(hidden_states)
|
||||||
|
|
||||||
query, compressed_kv, key_pe = qkv_a_out.split(
|
query, compressed_kv, key_pe = qkv_a_out.split(
|
||||||
[self.q_lora_rank, self.kv_lora_rank, self.qk_rope_head_dim], axis=-1
|
[self.q_lora_rank, self.kv_lora_rank, self.qk_rope_head_dim], axis=-1
|
||||||
)
|
)
|
||||||
@@ -399,6 +400,7 @@ class DeepseekV3MLAAttention(nn.Layer):
|
|||||||
self.num_attention_heads_tp * (self.kv_lora_rank + self.qk_rope_head_dim),
|
self.num_attention_heads_tp * (self.kv_lora_rank + self.qk_rope_head_dim),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
fmha_out_decode = self.mla_attn(
|
fmha_out_decode = self.mla_attn(
|
||||||
q=q_input,
|
q=q_input,
|
||||||
k=None,
|
k=None,
|
||||||
@@ -418,6 +420,7 @@ class DeepseekV3MLAAttention(nn.Layer):
|
|||||||
.transpose([1, 0, 2])
|
.transpose([1, 0, 2])
|
||||||
.reshape([-1, self.num_attention_heads_tp * self.v_head_dim])
|
.reshape([-1, self.num_attention_heads_tp * self.v_head_dim])
|
||||||
)
|
)
|
||||||
|
|
||||||
if fmha_out is None:
|
if fmha_out is None:
|
||||||
fmha_out = fmha_out_decode
|
fmha_out = fmha_out_decode
|
||||||
else:
|
else:
|
||||||
@@ -515,6 +518,7 @@ class DeepSeekV3DecoderLayer(nn.Layer):
|
|||||||
|
|
||||||
hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
|
hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
|
||||||
hidden_states = self.mlp(hidden_states)
|
hidden_states = self.mlp(hidden_states)
|
||||||
|
|
||||||
return hidden_states, residual
|
return hidden_states, residual
|
||||||
|
|
||||||
|
|
||||||
@@ -674,7 +678,6 @@ class DeepseekV3ForCausalLM(ModelForCasualLM):
|
|||||||
process_weights_after_loading_fn = process_weights_after_loading(dict(self.named_sublayers()), self.fd_config)
|
process_weights_after_loading_fn = process_weights_after_loading(dict(self.named_sublayers()), self.fd_config)
|
||||||
for loaded_weight_name, loaded_weight in weights_iterator:
|
for loaded_weight_name, loaded_weight in weights_iterator:
|
||||||
loaded_weight_name = loaded_weight_name.replace("deepseek_v3", "model")
|
loaded_weight_name = loaded_weight_name.replace("deepseek_v3", "model")
|
||||||
|
|
||||||
for param_name, weight_name, shard_id in stacked_params_mapping:
|
for param_name, weight_name, shard_id in stacked_params_mapping:
|
||||||
if weight_name not in loaded_weight_name:
|
if weight_name not in loaded_weight_name:
|
||||||
continue
|
continue
|
||||||
@@ -741,6 +744,20 @@ class DeepseekV3ForCausalLM(ModelForCasualLM):
|
|||||||
)
|
)
|
||||||
return position_ids, mask_encoder_batch
|
return position_ids, mask_encoder_batch
|
||||||
|
|
||||||
|
def empty_input_forward(self):
|
||||||
|
"""
|
||||||
|
empty_input_forward
|
||||||
|
"""
|
||||||
|
fake_hidden_states = paddle.empty(
|
||||||
|
shape=[1, self.fd_config.model_config.hidden_size],
|
||||||
|
dtype=paddle.get_default_dtype(),
|
||||||
|
)
|
||||||
|
for i in range(
|
||||||
|
self.fd_config.model_config.first_k_dense_replace,
|
||||||
|
self.fd_config.model_config.num_hidden_layers,
|
||||||
|
):
|
||||||
|
self.model.layers[i].mlp.experts(fake_hidden_states, self.model.layers[i].mlp.gate)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
ids_remove_padding: paddle.Tensor,
|
ids_remove_padding: paddle.Tensor,
|
||||||
|
|||||||
@@ -2328,7 +2328,6 @@ class GPUModelRunner(ModelRunnerBase):
|
|||||||
self.parallel_config.data_parallel_rank * self.parallel_config.tensor_parallel_size,
|
self.parallel_config.data_parallel_rank * self.parallel_config.tensor_parallel_size,
|
||||||
group=self.parallel_config.tp_group,
|
group=self.parallel_config.tp_group,
|
||||||
)
|
)
|
||||||
|
|
||||||
# 5. Post Process
|
# 5. Post Process
|
||||||
model_output_data = ModelOutputData(
|
model_output_data = ModelOutputData(
|
||||||
next_tokens=self.share_inputs["next_tokens"],
|
next_tokens=self.share_inputs["next_tokens"],
|
||||||
|
|||||||
Reference in New Issue
Block a user