Mirror of https://github.com/PaddlePaddle/FastDeploy.git
【Inference Optimize】DeepSeek-V3-model MLA Optimize (#3886)
* support MLA chunk_size auto search & cuda_graph
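In short: on the MLA decode path the KV cache of each request is processed in chunks, and instead of a fixed partition size a small device kernel (search_chunk_size_for_mla, added below) scans candidate chunk sizes (64, 128, ..., 131072) and keeps the one whose total tile count best fills the last wave of SMs. The chosen chunk size and grid size live in device tensors (decoder_chunk_size_device, decoder_num_blocks_device), and the decode kernels are launched with a fixed, SM-count-sized grid, so the whole path can be captured and replayed by CUDA graphs. A rough host-side sketch of the same selection heuristic (illustrative only; pick_chunk_size and its arguments are hypothetical names, not part of this patch):

#include <vector>

// Sketch: choose the chunk size whose total tile count leaves the fullest last wave of SMs.
// dec_lens holds the current KV length of each active decode request.
int pick_chunk_size(const std::vector<int>& dec_lens, int block_size, int sm_count) {
  constexpr int kNumConfigs = 12;               // candidates: 64, 128, ..., 131072
  int tiles[kNumConfigs];
  for (int i = 0; i < kNumConfigs; ++i) {
    const int chunk = block_size << i;          // block_size is 64 on this code path
    tiles[i] = 0;
    for (int len : dec_lens) tiles[i] += (len + chunk - 1) / chunk;
  }
  int best = 0;
  int max_last_wave = 0;
  for (int i = 1; i < kNumConfigs; ++i) {
    const int last_wave = tiles[i] % sm_count;  // CTAs left over for the final wave
    if (last_wave >= max_last_wave) {
      max_last_wave = last_wave;
      best = i;
    }
  }
  return block_size << best;
}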
@@ -11,10 +11,11 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "cute/tensor.hpp"
#include "helper.h"
#include "paddle/extension.h"
#include "paddle/phi/core/memory/memcpy.h"
#include "utils.cuh"

template <int THREADBLOCK_SIZE>
__global__ void
@@ -116,6 +117,93 @@ void GetMaxLen(const paddle::Tensor &seq_lens_tensor,
max_len_tensor.data<int>(), batch_size);
}

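// Added for the MLA chunk-size auto search: launched as a single thread block, each of the
// first config_size threads evaluates one candidate chunk size (block_size << conf_id) and
// counts how many decode tiles the whole batch would produce at that size. Thread 0 then
// keeps the candidate whose tile count leaves the fullest last wave of SMs and writes the
// winning grid size and chunk size to num_blocks_x / res_chunk_size. If set_chunk_size > 0
// (user override via FLAGS_mla_dec_chunk_size), that value is used as-is.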
template <uint32_t config_size>
__global__ void search_chunk_size_for_mla(
const int *__restrict__ seq_lens_q,
const int *__restrict__ seq_lens_encoder,
const int *__restrict__ seq_lens_decoder,
int *__restrict__ num_blocks_x,
int *__restrict__ res_chunk_size,
const int bsz,
const int set_chunk_size,
const int block_size,
const int sm_cout) {
const uint32_t conf_id = threadIdx.x;
int gridx = 0;
if (set_chunk_size > 0 && conf_id == 0) {
for (uint32_t bid = 0; bid < bsz; bid++) {
int seq_len = seq_lens_q[bid];
int seq_len_encoder = seq_lens_encoder[bid];
int seq_len_decoder = seq_lens_decoder[bid] + seq_len;
if (seq_len == 0 || seq_len_encoder > 0) continue;

int loop_times;
loop_times = cute::ceil_div(seq_len_decoder, set_chunk_size);
gridx += loop_times;
}
*num_blocks_x = gridx;
*res_chunk_size = set_chunk_size;
} else if (conf_id < config_size) {
__shared__ int gridx_shared[config_size];
// chunk_size is a multiple of 64
const int chunk_size = block_size << conf_id;
for (uint32_t bid = 0; bid < bsz; bid++) {
int seq_len = seq_lens_q[bid];
int seq_len_encoder = seq_lens_encoder[bid];
int seq_len_decoder = seq_lens_decoder[bid] + seq_len;
if (seq_len == 0 || seq_len_encoder > 0) continue;

int loop_times;
loop_times = cute::ceil_div(seq_len_decoder, chunk_size);
gridx += loop_times;
}
gridx_shared[conf_id] = gridx;
__syncthreads();
if (threadIdx.x == 0) {
uint32_t res_id = 0;
uint32_t max_last_wave_block = 0;
for (uint32_t i = 1; i < config_size; i++) {
uint32_t last_wave_block = gridx_shared[i] % sm_cout;
if (last_wave_block >= max_last_wave_block) {
res_id = i;
max_last_wave_block = last_wave_block;
}
}
*num_blocks_x = gridx_shared[res_id];
*res_chunk_size = block_size << res_id;
}
}
}
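
// Expands the chosen chunking into per-CTA work items: for every decode request this
// single-thread kernel emits one (batch_id, tile_id) pair per chunk of its KV length,
// which the MLA decode kernel later consumes. Prefill requests (seq_len_encoder > 0)
// get zero tiles here.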
__global__ void split_block_for_mla(const int *__restrict__ seq_lens_q,
const int *__restrict__ seq_lens_encoder,
const int *__restrict__ seq_lens_decoder,
int *__restrict__ batch_ids,
int *__restrict__ tile_ids_per_batch,
const int bsz,
const int chunk_size) {
if (threadIdx.x == 0) {
int index = 0;
for (uint32_t bid = 0; bid < bsz; bid++) {
int seq_len = seq_lens_q[bid];
int seq_len_encoder = seq_lens_encoder[bid];
int seq_len_decoder = seq_lens_decoder[bid] + seq_len;

if (seq_len == 0) continue;

int loop_times;
loop_times = cute::ceil_div(seq_len_decoder, chunk_size);
if (seq_len_encoder > 0) {
loop_times = 0;
}
for (uint32_t tile_id = 0; tile_id < loop_times; tile_id++) {
batch_ids[index] = bid;
tile_ids_per_batch[index++] = tile_id;
}
}
}
}

__global__ void split_q_block(const int *__restrict__ seq_lens_q,
const int *__restrict__ seq_lens_encoder,
int *__restrict__ batch_ids,
@@ -197,7 +285,9 @@ void GetBlockShapeAndSplitKVBlock(
const paddle::Tensor &seq_lens_this_time,
paddle::Tensor &decoder_batch_ids, // Inplace
paddle::Tensor &decoder_tile_ids_per_batch, // Inplace
paddle::Tensor &decoder_num_blocks_x_cpu, // Inplace, Pinned Memory
paddle::Tensor &decoder_num_blocks_cpu, // Inplace, Pinned Memory
paddle::Tensor &decoder_num_blocks_device, // Inplace
paddle::Tensor &decoder_chunk_size_device, // Inplace
paddle::Tensor &max_len_tensor_cpu, // Inplace, CPU
paddle::Tensor &encoder_batch_ids, // Inplace
paddle::Tensor &encoder_tile_ids_per_batch, // Inplace
@@ -230,8 +320,6 @@ void GetBlockShapeAndSplitKVBlock(
int max_system_len = max_len_cpu_ptr[6];
int max_just_dec_len_without_system = max_len_cpu_ptr[7];

auto max_len_kv =
GetEmptyTensor({1}, paddle::DataType::INT32, seq_lens_decoder.place());
get_max_len_kv_ernel<128><<<1, 128, 0, stream>>>(
@@ -241,6 +329,106 @@ void GetBlockShapeAndSplitKVBlock(

max_len_kv_cpu.copy_(max_len_kv, max_len_kv_cpu.place(), false);

// decoder
if (max_dec_len_this_time > 0) {
const bool mla_use_tensorcore = GetMlaUseTensorcore();
if (mla_use_tensorcore && group_size <= 64) {
const int set_chunk_size = get_mla_dec_chunk_size(bsz);
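// get_mla_dec_chunk_size() returns FLAGS_mla_dec_chunk_size (default -1) for bsz > 1
// and a fixed 64 for bsz == 1, so by default set_chunk_size <= 0 here and
// search_chunk_size_for_mla picks the chunk size itself.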

PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(
decoder_chunk_size_device.data<int>(), 64, sizeof(int32_t), stream));

PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(
decoder_num_blocks_device.data<int>(), 0, sizeof(int32_t), stream));

int device;
cudaGetDevice(&device);
int sm_cout;
cudaDeviceGetAttribute(&sm_cout, cudaDevAttrMultiProcessorCount, device);
constexpr int config_size =
12; // search space for chunk size:[64, 128, 256, ... 131072]

search_chunk_size_for_mla<config_size>
<<<1, 32, 0, stream>>>(seq_lens_this_time.data<int>(),
seq_lens_encoder.data<int>(),
seq_lens_decoder.data<int>(),
decoder_num_blocks_device.data<int>(),
decoder_chunk_size_device.data<int>(),
bsz,
set_chunk_size,
block_size,
sm_cout);

decoder_num_blocks_cpu.copy_(
decoder_num_blocks_device, decoder_num_blocks_cpu.place(), false);
auto decoder_chunk_size_cpu =
decoder_chunk_size_device.copy_to(paddle::CPUPlace(), false);
const int chunk_size = decoder_chunk_size_cpu.data<int>()[0];

// NOTE: (changwenbin) When using auto_chunk,
// decode_max_tile_size must take into account the maximum case, where * 1024 can cover 128K.
// const uint32_t decoder_batch_shape = seq_lens_decoder.dims()[0] * 1024;

const uint32_t decoder_max_tile_size_per_bs_q =
div_up((decoder_step_token_num * group_size), decoder_block_shape_q);
const uint32_t decoder_batch_shape =
bsz * 1024 * decoder_max_tile_size_per_bs_q;

PADDLE_ENFORCE_GPU_SUCCESS(
cudaMemsetAsync(decoder_batch_ids.data<int>(),
0,
decoder_batch_shape * sizeof(int32_t),
stream));
PADDLE_ENFORCE_GPU_SUCCESS(
cudaMemsetAsync(decoder_tile_ids_per_batch.data<int>(),
0,
decoder_batch_shape * sizeof(int32_t),
stream));

split_block_for_mla<<<1, 32, 0, stream>>>(
seq_lens_this_time.data<int>(),
seq_lens_encoder.data<int>(),
seq_lens_decoder.data<int>(),
decoder_batch_ids.data<int>(),
decoder_tile_ids_per_batch.data<int>(),
bsz,
chunk_size);

} else {
// Note: (changwenbin) In order to adapt to cudagraph, the maximum value should be taken here
const uint32_t decoder_max_tile_size_per_bs_q = div_up((decoder_step_token_num * group_size), decoder_block_shape_q);
const uint32_t decoder_batch_shape = bsz * 1024 * decoder_max_tile_size_per_bs_q;

PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(decoder_batch_ids.data<int>(), 0, decoder_batch_shape * sizeof(int32_t), stream));
PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(decoder_tile_ids_per_batch.data<int>(), 0, decoder_batch_shape * sizeof(int32_t), stream));
PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(decoder_num_blocks_device.data<int>(), 0, sizeof(int32_t), stream));

split_q_block<<<1, 32, 0, stream>>>(
seq_lens_this_time.data<int>(),
seq_lens_encoder.data<int>(),
decoder_batch_ids.data<int>(),
decoder_tile_ids_per_batch.data<int>(),
decoder_num_blocks_device.data<int>(),
bsz,
decoder_block_shape_q,
group_size);

decoder_num_blocks_cpu.copy_(
decoder_num_blocks_device, decoder_num_blocks_cpu.place(), false);
PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(
decoder_chunk_size_device.data<int>(), 64, sizeof(int32_t), stream));
}
} else {
PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(
decoder_chunk_size_device.data<int>(), 64, sizeof(int32_t), stream));
PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(
decoder_num_blocks_device.data<int>(), 0, sizeof(int32_t), stream));
decoder_num_blocks_cpu.copy_(
decoder_num_blocks_device, decoder_num_blocks_cpu.place(), false);
}

// encoder
if (max_enc_len_this_time > 0) {
const uint32_t max_tile_size_per_bs_kv = div_up(max_enc_dec_len_this_time, block_size);
const uint32_t kv_batch_shape = bsz * max_tile_size_per_bs_kv;

@@ -272,28 +460,6 @@ void GetBlockShapeAndSplitKVBlock(
encoder_num_blocks_x_cpu.copy_(encoder_num_blocks_x, encoder_num_blocks_x_cpu.place(), false);
}

if (max_just_dec_len_this_time > 0) {
// Clear buffer
const uint32_t decoder_max_tile_size_per_bs_q = div_up((decoder_step_token_num * group_size), decoder_block_shape_q);
const uint32_t decoder_batch_shape = bsz * decoder_max_tile_size_per_bs_q;
PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(decoder_batch_ids.data<int>(), 0, decoder_batch_shape * sizeof(int32_t), stream));
PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(decoder_tile_ids_per_batch.data<int>(), 0, decoder_batch_shape * sizeof(int32_t), stream));
PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(decoder_num_blocks_x_cpu.data<int>(), 0, sizeof(int32_t), stream));

auto decoder_num_blocks_x =
GetEmptyTensor({1}, paddle::DataType::INT32, seq_lens_encoder.place());
split_q_block<<<1, 32, 0, stream>>>(
seq_lens_this_time.data<int>(),
seq_lens_encoder.data<int>(),
decoder_batch_ids.data<int>(),
decoder_tile_ids_per_batch.data<int>(),
decoder_num_blocks_x.data<int>(),
bsz,
decoder_block_shape_q,
group_size);
decoder_num_blocks_x_cpu.copy_(decoder_num_blocks_x, decoder_num_blocks_x_cpu.place(), false);
}

}

PD_BUILD_STATIC_OP(get_block_shape_and_split_kv_block)
@@ -303,7 +469,9 @@ PD_BUILD_STATIC_OP(get_block_shape_and_split_kv_block)
"seq_lens_this_time",
"decoder_batch_ids",
"decoder_tile_ids_per_batch",
"decoder_num_blocks_x_cpu",
"decoder_num_blocks_cpu",
"decoder_num_blocks_device",
"decoder_chunk_size_device",
"max_len_tensor_cpu",
"encoder_batch_ids",
"encoder_tile_ids_per_batch",
@@ -63,7 +63,7 @@ std::vector<paddle::Tensor> AppendAttention(
const paddle::Tensor &kv_num_blocks,
const paddle::Tensor &decoder_batch_ids,
const paddle::Tensor &decoder_tile_ids_per_batch,
const paddle::Tensor &decoder_num_blocks,
const paddle::Tensor &decoder_num_blocks_cpu,
const paddle::Tensor &set_max_lengths, const paddle::Tensor &max_len_kv,
const paddle::optional<paddle::Tensor> &rotary_embs,
const paddle::optional<paddle::Tensor> &attn_mask,

@@ -105,7 +105,7 @@ void AppendAttentionWithOutput(
const paddle::Tensor &kv_num_blocks,
const paddle::Tensor &decoder_batch_ids,
const paddle::Tensor &decoder_tile_ids_per_batch,
const paddle::Tensor &decoder_num_blocks,
const paddle::Tensor &decoder_num_blocks_cpu,
const paddle::Tensor &set_max_lengths, const paddle::Tensor &max_len_kv,
paddle::Tensor &fmha_out,
const paddle::optional<paddle::Tensor> &rotary_embs,

@@ -305,7 +305,9 @@ void GetBlockShapeAndSplitKVBlock(
const paddle::Tensor &seq_lens_this_time,
paddle::Tensor &decoder_batch_ids, // Inplace
paddle::Tensor &decoder_tile_ids_per_batch, // Inplace
paddle::Tensor &decoder_num_blocks_x_cpu, // Inplace, Pinned Memory
paddle::Tensor &decoder_num_blocks_cpu, // Inplace, Pinned Memory
paddle::Tensor &decoder_num_blocks_device, // Inplace
paddle::Tensor &decoder_chunk_size_device, // Inplace
paddle::Tensor &max_len_tensor_cpu, // Inplace, Pinned Memory
paddle::Tensor &encoder_batch_ids, // Inplace
paddle::Tensor &encoder_tile_ids_per_batch, // Inplace

@@ -473,23 +475,18 @@ std::vector<paddle::Tensor> MultiHeadLatentAttention(
const paddle::Tensor& query,
const paddle::Tensor& key_cache,
const paddle::Tensor& value_cache,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& batch_id_per_token,
const paddle::Tensor& block_tables,
const paddle::Tensor& encoder_batch_ids,
const paddle::Tensor& encoder_tile_ids_per_batch,
const paddle::Tensor& encoder_num_blocks,
const paddle::Tensor& kv_batch_ids,
const paddle::Tensor& kv_tile_ids_per_batch,
const paddle::Tensor& kv_num_blocks,
const paddle::Tensor& decoder_batch_ids,
const paddle::Tensor& decoder_tile_ids_per_batch,
const paddle::Tensor& decoder_num_blocks,
const paddle::Tensor& decoder_num_blocks_cpu,
const paddle::Tensor& max_enc_len_this_time,
const paddle::Tensor& decoder_num_blocks_device,
const paddle::Tensor& decoder_chunk_size_device,
const paddle::Tensor& max_dec_len_this_time,
const paddle::Tensor& max_len_kv,
const paddle::optional<paddle::Tensor>& attn_mask,
@@ -59,6 +59,15 @@ inline uint32_t get_cascade_attention_num_threads() {
inline bool get_mla_use_tensorcore() {
static const char* mla_use_tensorcore_env = std::getenv("FLAGS_mla_use_tensorcore");
static const uint32_t mla_use_tensorcore =
mla_use_tensorcore_env == nullptr ? 1 : std::stoul(std::string(mla_use_tensorcore_env));
mla_use_tensorcore_env == nullptr ? 0 : std::stoul(std::string(mla_use_tensorcore_env));
return mla_use_tensorcore != 0 ? true : false;
}
inline int get_mla_dec_chunk_size(int bsz) {
static const char* mla_dec_chunk_size_env =
std::getenv("FLAGS_mla_dec_chunk_size");
static const int mla_dec_chunk_size =
mla_dec_chunk_size_env == nullptr
? -1
: std::stoi(std::string(mla_dec_chunk_size_env));
return bsz > 1 ? mla_dec_chunk_size : 64;
}
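// Two environment knobs govern the MLA decode path: FLAGS_mla_use_tensorcore toggles the
// tensor-core kernel (non-zero enables it), and FLAGS_mla_dec_chunk_size (unset, i.e. -1,
// by default) forces a fixed KV chunk size instead of the on-device auto search; batch
// size 1 always uses a chunk size of 64.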

@@ -563,3 +563,11 @@ inline int GetSMVersion() {
return sm_version;

}

inline bool GetMlaUseTensorcore() {
static const bool flags_mla_use_tensorcore = get_mla_use_tensorcore();
static const bool enable_mla_tensorcore = GetSMVersion() >= 90 ? true : false;
const bool mla_use_tensorcore =
flags_mla_use_tensorcore && enable_mla_tensorcore;
return mla_use_tensorcore;
}
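// GetMlaUseTensorcore(): the tensor-core MLA decode kernel is used only when the flag is
// enabled and the GPU is SM 90 or newer; otherwise the non-tensor-core path runs.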
@@ -70,7 +70,6 @@ void BatchMLAWithPagedKVCacheKernel(
const paddle::optional<paddle::Tensor>& smooth_weight, // [num_kv_heads, head_dim]
const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& batch_id_per_token,
const paddle::Tensor& block_tables,

@@ -78,9 +77,8 @@ void BatchMLAWithPagedKVCacheKernel(
const paddle::Tensor& tile_ids_per_batch,
const paddle::Tensor& num_blocks_x_device,
const std::string& cache_quant_type_str,
const int num_blocks_x,
const paddle::Tensor& decoder_chunk_size_device,
const int max_seq_len,
const int max_dec_len,
const float softmax_scale,
const float quant_max_bound,
const float quant_min_bound,

@@ -97,14 +95,12 @@ void BatchMLAWithPagedKVCacheKernel(
const auto q_head_num = meta_data.q_num_heads;
const auto max_block_num_per_seq = meta_data.max_blocks_per_seq;
const auto max_block_num = bsz * max_block_num_per_seq;
const uint32_t chunk_size = get_max_partition_size(bsz);

int q_head_dim = meta_data.head_dims;
int k_head_dim = meta_data.head_dims;
int v_head_dim = meta_data.head_dims_v;
// int num_chunks = max_dec_len / chunk_size;
int num_chunks = div_up(max_dec_len, chunk_size);
int num_chunks = div_up(max_seq_len, 64);
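// Since the chunk size is now chosen at run time on the device, the temporary buffers are
// sized for the worst case: max_seq_len divided by the smallest candidate chunk (64), so
// any chunk size the search may pick still fits (presumably why O_tmp/m/d below are
// documented as [max_num_chunks, ...]).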

auto *allocator = paddle::GetAllocator(q.place());
phi::Allocator::AllocationPtr O_tmp, m_tmp, d_tmp;

@@ -127,14 +123,14 @@ void BatchMLAWithPagedKVCacheKernel(
params.d = reinterpret_cast<float*>(d_tmp->ptr());
params.block_tables = const_cast<int*>(block_tables.data<int>());
params.seq_lens_this_time = const_cast<int*>(seq_lens_this_time.data<int>());
params.seq_lens_encoder = const_cast<int*>(seq_lens_encoder.data<int>());
params.seq_lens_decoder = const_cast<int*>(seq_lens_decoder.data<int>());
params.cumsum_q_seqlens = const_cast<int*>(cu_seqlens_q.data<int>());
params.batch_id_per_token = const_cast<int*>(batch_id_per_token.data<int>());
params.batch_ids = const_cast<int*>(batch_ids.data<int>());
params.tile_ids_per_batch = const_cast<int*>(tile_ids_per_batch.data<int>());
params.num_blocks_x = const_cast<int*>(num_blocks_x_device.data<int>());
params.num_blocks_x_int = num_blocks_x;
params.chunk_size_device =
const_cast<int*>(decoder_chunk_size_device.data<int>());
params.q_stride_bsz = q_head_num * q_head_dim;
params.q_stride_head_num = q_head_dim;
params.kv_stride_block_num = block_size * k_head_dim;

@@ -151,7 +147,6 @@ void BatchMLAWithPagedKVCacheKernel(
params.block_size = block_size;
params.max_draft_token_num = draft_token_num;
params.sm_scale = softmax_scale;
params.chunk_size = chunk_size;
params.chunk_num = num_chunks;

if (q_head_dim == 576) {

@@ -176,7 +171,6 @@ template void BatchMLAWithPagedKVCacheKernel<paddle::bfloat16>(
const paddle::optional<paddle::Tensor>& smooth_weight, // [num_kv_heads, head_dim]
const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& batch_id_per_token,
const paddle::Tensor& block_tables,

@@ -184,9 +178,8 @@ template void BatchMLAWithPagedKVCacheKernel<paddle::bfloat16>(
const paddle::Tensor& tile_ids_per_batch,
const paddle::Tensor& num_blocks_x_device,
const std::string& cache_quant_type_str,
const int num_blocks_x,
const paddle::Tensor& decoder_chunk_size_device,
const int max_seq_len,
const int max_dec_len,
const float softmax_scale,
const float quant_max_bound,
const float quant_min_bound,

@@ -210,7 +203,6 @@ template void BatchMLAWithPagedKVCacheKernel<paddle::float16>(
const paddle::optional<paddle::Tensor>& smooth_weight, // [num_kv_heads, head_dim]
const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& batch_id_per_token,
const paddle::Tensor& block_tables,

@@ -218,9 +210,8 @@ template void BatchMLAWithPagedKVCacheKernel<paddle::float16>(
const paddle::Tensor& tile_ids_per_batch,
const paddle::Tensor& num_blocks_x_device,
const std::string& cache_quant_type_str,
const int num_blocks_x,
const paddle::Tensor& decoder_chunk_size_device,
const int max_seq_len,
const int max_dec_len,
const float softmax_scale,
const float quant_max_bound,
const float quant_min_bound,
@@ -47,7 +47,6 @@ void BatchMLAWithPagedKVCacheKernel(
const paddle::optional<paddle::Tensor>& smooth_weight, // [num_kv_heads, head_dim]
const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& batch_id_per_token,
const paddle::Tensor& block_tables,

@@ -55,9 +54,8 @@ void BatchMLAWithPagedKVCacheKernel(
const paddle::Tensor& tile_ids_per_batch,
const paddle::Tensor& num_blocks_x_device,
const std::string& cache_quant_type_str,
const int num_blocks_x,
const paddle::Tensor& decoder_chunk_size_device,
const int max_seq_len,
const int max_dec_len,
const float softmax_scale,
const float quant_max_bound,
const float quant_min_bound,
@@ -128,12 +128,13 @@ struct CollectiveMainloop {
DTypeMD const* d_ptr;
IdType const* kv_block_tables;
IdType const* seq_lens_this_time;
IdType const* seq_lens_encoder;
// IdType const* seq_lens_encoder;
IdType const* seq_lens_decoder;
IdType const* cumsum_q_seqlens;
IdType const* batch_ids;
IdType const* tile_ids_per_batch;
IdType const* num_blocks_x;
IdType const* chunk_size_device;
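// chunk_size_device points at the one-element tensor written by search_chunk_size_for_mla;
// the kernel reads the chunk size from device memory instead of a launch-time constant,
// so the value can change between CUDA-graph replays without re-capturing the graph.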
float sm_scale;
int bsz;
int max_block_num;

@@ -144,7 +145,7 @@ struct CollectiveMainloop {
int kv_stride_block_size;
int o_stride_bsz;
int o_stride_head_num;
int chunk_size;
// int chunk_size;
int chunk_num;
int max_draft_token_num;
};

@@ -160,12 +161,13 @@ struct CollectiveMainloop {
DTypeMD* d_ptr;
IdType* kv_block_tables;
IdType* seq_lens_this_time;
IdType* seq_lens_encoder;
// IdType* seq_lens_encoder;
IdType* seq_lens_decoder;
IdType* cumsum_q_seqlens;
IdType* batch_ids;
IdType* tile_ids_per_batch;
IdType* num_blocks_x;
IdType* chunk_size_device;
float sm_scale;
int bsz;
int max_block_num;

@@ -176,7 +178,7 @@ struct CollectiveMainloop {
int kv_stride_block_size;
int o_stride_bsz;
int o_stride_head_num;
int chunk_size;
// int chunk_size;
int chunk_num;
int max_draft_token_num;
TMA_KV tma_load_KV;

@@ -198,12 +200,13 @@ struct CollectiveMainloop {
const_cast<DTypeMD*>(args.d_ptr),
const_cast<IdType*>(args.kv_block_tables),
const_cast<IdType*>(args.seq_lens_this_time),
const_cast<IdType*>(args.seq_lens_encoder),
// const_cast<IdType*>(args.seq_lens_encoder),
const_cast<IdType*>(args.seq_lens_decoder),
const_cast<IdType*>(args.cumsum_q_seqlens),
const_cast<IdType*>(args.batch_ids),
const_cast<IdType*>(args.tile_ids_per_batch),
const_cast<IdType*>(args.num_blocks_x),
const_cast<IdType*>(args.chunk_size_device),
args.sm_scale,
args.bsz,
args.max_block_num,

@@ -214,7 +217,7 @@ struct CollectiveMainloop {
args.kv_stride_block_size,
args.o_stride_bsz,
args.o_stride_head_num,
args.chunk_size,
// args.chunk_size,
args.chunk_num,
args.max_draft_token_num,
tma_load_KV

@@ -281,9 +284,9 @@ struct CollectiveMainloop {
auto gmem_thr_copy_kv = gmem_tiled_copy_kv.get_slice(thread_idx);

static constexpr int BLOCK_SHAPE_KV = get<1>(TileShape_QKD{});
const int start_len = tile_idx * mainloop_params.chunk_size;
const int start_len = tile_idx * mainloop_params.chunk_size_device[0];
const int start_tile_idx = start_len / BLOCK_SHAPE_KV;
const int end_tile_idx = cute::ceil_div(min(start_len + mainloop_params.chunk_size, kv_len), BLOCK_SHAPE_KV) - 1;
const int end_tile_idx = cute::ceil_div(min(start_len + mainloop_params.chunk_size_device[0], kv_len), BLOCK_SHAPE_KV) - 1;

auto kv_block_tables = make_tensor(make_gmem_ptr(mainloop_params.kv_block_tables), make_layout(make_shape(mainloop_params.bsz, mainloop_params.max_block_num_per_seq), make_stride(mainloop_params.max_block_num_per_seq, 1)));

@@ -322,9 +325,9 @@ struct CollectiveMainloop {
group_modes<0, 2>(sK), group_modes<0, 2>(gKV));

static constexpr int BLOCK_SHAPE_KV = get<1>(TileShape_QKD{});
const int start_len = tile_idx * mainloop_params.chunk_size;
const int start_len = tile_idx * mainloop_params.chunk_size_device[0];
const int start_tile_idx = start_len / BLOCK_SHAPE_KV;
const int end_tile_idx = cute::ceil_div(min(start_len + mainloop_params.chunk_size, kv_len), BLOCK_SHAPE_KV) - 1;
const int end_tile_idx = cute::ceil_div(min(start_len + mainloop_params.chunk_size_device[0], kv_len), BLOCK_SHAPE_KV) - 1;

auto kv_block_tables = make_tensor(make_gmem_ptr(mainloop_params.kv_block_tables), make_layout(make_shape(mainloop_params.bsz, mainloop_params.max_block_num_per_seq), make_stride(mainloop_params.max_block_num_per_seq, 1)));
@@ -57,7 +57,7 @@ CUTLASS_DEVICE void mma_f16(const Params& mainloop_params,
using SmemLayoutVtOneStage = typename Ktraits::SmemLayoutVtOneStage;
static_assert(is_rmem<FrgTensorO>::value, "O tensor must be rmem resident.");

const int chunk_num_this_seq = cute::ceil_div(kv_len, mainloop_params.chunk_size);
const int chunk_num_this_seq = cute::ceil_div(kv_len, mainloop_params.chunk_size_device[0]);

static constexpr int BLOCK_SHAPE_Q = get<0>(TileShape_QKD{});
static constexpr int BLOCK_SHAPE_KV = get<1>(TileShape_QKD{});

@@ -84,9 +84,9 @@ CUTLASS_DEVICE void mma_f16(const Params& mainloop_params,
Tensor tOrV2 = threadMmaPVSS.partition_fragment_B(sVt_s2);
Tensor tOrP_CS2 = threadMmaPVSS.partition_fragment_A(sPSS);

const int start_len = tile_idx * mainloop_params.chunk_size;
const int start_len = tile_idx * mainloop_params.chunk_size_device[0];
const int start_tile_idx = start_len / BLOCK_SHAPE_KV;
const int end_tile_idx = cute::ceil_div(min(start_len + mainloop_params.chunk_size, kv_len), BLOCK_SHAPE_KV) - 1;
const int end_tile_idx = cute::ceil_div(min(start_len + mainloop_params.chunk_size_device[0], kv_len), BLOCK_SHAPE_KV) - 1;
int kv_tile_idx = end_tile_idx;

auto consumer_wait = [](auto& pipeline, auto& smem_pipe_read) {

@@ -263,7 +263,7 @@ CUTLASS_DEVICE void mma_f16_two_stages(const Params& mainloop_params,
using SmemLayoutVtOneStage = typename Ktraits::SmemLayoutVtOneStage;
static_assert(is_rmem<FrgTensorO>::value, "O tensor must be rmem resident.");

const int chunk_num_this_seq = cute::ceil_div(kv_len, mainloop_params.chunk_size);
const int chunk_num_this_seq = cute::ceil_div(kv_len, mainloop_params.chunk_size_device[0]);

static constexpr int BLOCK_SHAPE_Q = get<0>(TileShape_QKD{});
static constexpr int BLOCK_SHAPE_KV = get<1>(TileShape_QKD{});

@@ -295,9 +295,9 @@ CUTLASS_DEVICE void mma_f16_two_stages(const Params& mainloop_params,
Tensor tOrV4 = threadMmaPVSS.partition_fragment_B(sVt_s4);
Tensor tOrP_CS2 = threadMmaPVSS.partition_fragment_A(sPSS);

const int start_len = tile_idx * mainloop_params.chunk_size;
const int start_len = tile_idx * mainloop_params.chunk_size_device[0];
const int start_tile_idx = start_len / BLOCK_SHAPE_KV;
const int end_tile_idx = cute::ceil_div(min(start_len + mainloop_params.chunk_size, kv_len), BLOCK_SHAPE_KV) - 1;
const int end_tile_idx = cute::ceil_div(min(start_len + mainloop_params.chunk_size_device[0], kv_len), BLOCK_SHAPE_KV) - 1;
int kv_tile_idx = end_tile_idx;

auto consumer_wait = [](auto& pipeline, auto& smem_pipe_read) {
@@ -62,13 +62,12 @@ struct Params {
alignas(16) DTypeQ *Q; // [token_num, head_num, dim_head]
alignas(16) DTypeKV *KV; // [max_block_num, block_size, dim_head]
alignas(16) DTypeO *O; // [token_num, head_num, dim_head]
alignas(16) DTypeO *O_tmp; // [num_chunks, bsz, head_num, dim_head]
alignas(16) float *m; // [num_chunks, bsz * max_draft_token_num * head_num]
alignas(16) float *d; // [num_chunks, bsz * max_draft_token_num * head_num]
alignas(16) DTypeO *O_tmp; // [max_num_chunks, bsz, head_num, dim_head]
alignas(16) float *m; // [max_num_chunks, bsz * max_draft_token_num * head_num]
alignas(16) float *d; // [max_num_chunks, bsz * max_draft_token_num * head_num]

alignas(16) IdType *block_tables;
alignas(16) IdType *seq_lens_this_time;
alignas(16) IdType *seq_lens_encoder;
alignas(16) IdType *seq_lens_decoder;
alignas(16) IdType *cumsum_q_seqlens;
alignas(16) IdType *batch_id_per_token;

@@ -76,7 +75,7 @@ struct Params {
alignas(16) IdType *batch_ids;
alignas(16) IdType *tile_ids_per_batch;
alignas(16) IdType *num_blocks_x;

alignas(16) IdType *chunk_size_device;

uint32_t q_stride_bsz;
uint32_t q_stride_head_num;

@@ -96,9 +95,7 @@ struct Params {
int vo_head_dim;
int block_size;
int max_draft_token_num;
int chunk_size;
int chunk_num;
int num_blocks_x_int;

float sm_scale;
};

@@ -118,7 +115,7 @@ struct Params {
return cudaErrorNotSupported; \
}

template <typename CollectiveMainloop, typename CollectiveEpilogue, typename Ktraits, bool CAUSAL, int SM_COUNT = 132, bool USE_REG_EALLOC=false, bool USE_FIXED_BLOCK=false>
template <typename CollectiveMainloop, typename CollectiveEpilogue, typename Ktraits, bool CAUSAL, int SM_COUNT = 132, bool USE_REG_EALLOC=false, bool USE_FIXED_BLOCK=true>
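// USE_FIXED_BLOCK now defaults to true: the kernel is launched with a fixed grid of one
// block per SM and each block walks work items with a stride-SM_COUNT loop over
// num_blocks_x (read from device memory). A fixed launch configuration is what lets the
// decode kernel be captured once and replayed by CUDA graphs while the per-step work
// count keeps changing.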
__global__ void __launch_bounds__(Ktraits::NUM_WARPS * cutlass::NumThreadsPerWarp, 1)
MLAWithKVCacheKernel(CUTE_GRID_CONSTANT
typename CollectiveMainloop::Params const mainloop_params,

@@ -137,6 +134,7 @@ MLAWithKVCacheKernel(CUTE_GRID_CONSTANT
static constexpr int BLOCK_SHAPE_Q = Ktraits::BLOCK_SHAPE_Q;
static constexpr int BLOCK_SHAPE_KV = Ktraits::BLOCK_SHAPE_KV;
const int num_blocks_x = mainloop_params.num_blocks_x[0];
const int chunk_size = mainloop_params.chunk_size_device[0];

static constexpr bool use_tma_load_kv = CollectiveMainloop::USE_TMA_LOAD_KV;

@@ -205,58 +203,10 @@ MLAWithKVCacheKernel(CUTE_GRID_CONSTANT

PipelineStateQ smem_pipe_write_q = cutlass::make_producer_start_state<MainloopPipelineQ>();
PipelineState smem_pipe_write_kv = cutlass::make_producer_start_state<MainloopPipeline>();
if constexpr(USE_FIXED_BLOCK) {
for (int i = blockIdx.x; i < num_blocks_x; i += SM_COUNT) {
const int bid = mainloop_params.batch_ids[i];
const int tile_id = mainloop_params.tile_ids_per_batch[i];
const int seq_len_now = mainloop_params.seq_lens_this_time[bid];
const int seq_len_encoder_now = mainloop_params.seq_lens_encoder[bid];
const int seq_len_decoder_now = mainloop_params.seq_lens_decoder[bid] + seq_len_now;
const int start_token_idx = mainloop_params.cumsum_q_seqlens[bid];
cutlass::arch::NamedBarrier::sync(Ktraits::NUM_THREADS,
/*id=*/static_cast<int>(NamedBarriers::kWG0WG1WG2Sync));

// load Q
collective_mainloop.load_q(
mainloop_params,
pipeline_q,
smem_pipe_write_q,
shared_storage,
threadIdx.x,
bid);

if constexpr (!use_tma_load_kv) {
// load kv
collective_mainloop.load_kv(
mainloop_params,
pipeline_kv,
smem_pipe_write_kv,
shared_storage,
bid,
seq_len_decoder_now,
tile_id
);
} else {
if (warp_idx_in_warpgroup == 0) {
// load kv tma
collective_mainloop.load_kv_tma(
mainloop_params,
pipeline_kv,
smem_pipe_write_kv,
shared_storage,
bid,
seq_len_decoder_now,
tile_id
);
}
}
}
} else {
const int block_id = blockIdx.x;
const int bid = mainloop_params.batch_ids[block_id];
const int tile_id = mainloop_params.tile_ids_per_batch[block_id];
for (int i = blockIdx.x; i < num_blocks_x; i += SM_COUNT) {
const int bid = mainloop_params.batch_ids[i];
const int tile_id = mainloop_params.tile_ids_per_batch[i];
const int seq_len_now = mainloop_params.seq_lens_this_time[bid];
const int seq_len_encoder_now = mainloop_params.seq_lens_encoder[bid];
const int seq_len_decoder_now = mainloop_params.seq_lens_decoder[bid] + seq_len_now;
const int start_token_idx = mainloop_params.cumsum_q_seqlens[bid];
cutlass::arch::NamedBarrier::sync(Ktraits::NUM_THREADS,

@@ -309,76 +259,12 @@ MLAWithKVCacheKernel(CUTE_GRID_CONSTANT
Tensor tOrO = partition_fragment_C(tiled_mma_pv, select<0, 1>(TileShape_PDV{}));

auto attention_updater = OnlineSoftmax<2 * size<1>(tOrO), /*WITH_SCALE=*/true>(mainloop_params.sm_scale);
if constexpr(USE_FIXED_BLOCK) {
for (int i = blockIdx.x; i < num_blocks_x; i += SM_COUNT) {
clear(tOrO);
clear(attention_updater.scores_scale);
const int bid = mainloop_params.batch_ids[i];
const int tile_id = mainloop_params.tile_ids_per_batch[i];
const int seq_len_now = mainloop_params.seq_lens_this_time[bid];
const int seq_len_encoder_now = mainloop_params.seq_lens_encoder[bid];
const int seq_len_decoder_now = mainloop_params.seq_lens_decoder[bid] + seq_len_now;
const int start_token_idx = mainloop_params.cumsum_q_seqlens[bid];
cutlass::arch::NamedBarrier::sync(Ktraits::NUM_THREADS,
/*id=*/static_cast<int>(NamedBarriers::kWG0WG1WG2Sync));

if constexpr (BLOCK_SHAPE_KV == 64) {
mma_f16<Ktraits, CAUSAL>(
mainloop_params,
pipeline_q,
smem_pipe_read_q,
pipeline_kv,
smem_pipe_read_kv,
tOrO,
attention_updater,
threadIdx.x - NUM_COPY_THREADS,
bid,
seq_len_decoder_now,
seq_len_now,
tile_id,
shared_storage);
} else if (BLOCK_SHAPE_KV == 32) {
mma_f16_two_stages<Ktraits, CAUSAL>(
mainloop_params,
pipeline_q,
smem_pipe_read_q,
pipeline_kv,
smem_pipe_read_kv,
tOrO,
attention_updater,
threadIdx.x - NUM_COPY_THREADS,
bid,
seq_len_decoder_now,
seq_len_now,
tile_id,
shared_storage);
}

collective_epilogue.store(
epilogue_params,
tOrO,
attention_updater.get_lse(),
shared_storage,
tiled_mma_pv,
threadIdx.x - NUM_COPY_THREADS,
bid,
mainloop_params.bsz,
seq_len_now,
start_token_idx,
tile_id,
seq_len_decoder_now,
mainloop_params.chunk_size,
mainloop_params.max_draft_token_num,
mainloop_params.o_stride_bsz);
}
} else {
const int block_id = blockIdx.x;
for (int i = blockIdx.x; i < num_blocks_x; i += SM_COUNT) {
clear(tOrO);
clear(attention_updater.scores_scale);
const int bid = mainloop_params.batch_ids[block_id];
const int tile_id = mainloop_params.tile_ids_per_batch[block_id];
const int bid = mainloop_params.batch_ids[i];
const int tile_id = mainloop_params.tile_ids_per_batch[i];
const int seq_len_now = mainloop_params.seq_lens_this_time[bid];
const int seq_len_encoder_now = mainloop_params.seq_lens_encoder[bid];
const int seq_len_decoder_now = mainloop_params.seq_lens_decoder[bid] + seq_len_now;
const int start_token_idx = mainloop_params.cumsum_q_seqlens[bid];
cutlass::arch::NamedBarrier::sync(Ktraits::NUM_THREADS,

@@ -429,15 +315,15 @@ MLAWithKVCacheKernel(CUTE_GRID_CONSTANT
start_token_idx,
tile_id,
seq_len_decoder_now,
mainloop_params.chunk_size,
chunk_size,
mainloop_params.max_draft_token_num,
mainloop_params.o_stride_bsz);
}
}
}
}

template <typename KernelTraits, bool CAUSAL, typename Params, bool USE_REG_EALLOC=false, bool USE_FIXED_BLOCK=false>
template <typename KernelTraits, bool CAUSAL, typename Params, bool USE_REG_EALLOC=false, bool USE_FIXED_BLOCK=true>
cudaError_t BatchMLAWithPagedKVCacheKernelTraitsDispatched(Params& params,
cudaStream_t stream) {
using DTypeQ = typename KernelTraits::DTypeQ;

@@ -460,12 +346,12 @@ cudaError_t BatchMLAWithPagedKVCacheKernelTraitsDispatched(Params& params,
params.d,
params.block_tables,
params.seq_lens_this_time,
params.seq_lens_encoder,
params.seq_lens_decoder,
params.cumsum_q_seqlens,
params.batch_ids,
params.tile_ids_per_batch,
params.num_blocks_x,
params.chunk_size_device,
params.sm_scale,
params.bsz,
params.max_block_num,

@@ -476,7 +362,6 @@ cudaError_t BatchMLAWithPagedKVCacheKernelTraitsDispatched(Params& params,
params.kv_stride_block_size,
params.o_stride_bsz,
params.o_stride_head_num,
params.chunk_size,
params.chunk_num,
params.max_draft_token_num
});

@@ -500,13 +385,9 @@ cudaError_t BatchMLAWithPagedKVCacheKernelTraitsDispatched(Params& params,
cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&act_blocks_per_sm, kernel, KernelTraits::NUM_WARPS * 32, smem_size);

int gridx;
if constexpr(USE_FIXED_BLOCK) {
gridx = multiprocessor_count;
} else {
gridx = params.num_blocks_x_int;
}
dim3 grid_dims = {gridx, 1, 1};
// NOTE: (changwenbin) Here the grid size is fixed so that MLA can be captured
// by the graph.
dim3 grid_dims = {multiprocessor_count, 1, 1};
static constexpr int ctaSize = KernelTraits::NUM_WARPS * 32;
dim3 block_dims(ctaSize, 1, 1);
kernel<<<grid_dims, block_dims, smem_size, stream>>>(

@@ -517,37 +398,38 @@ cudaError_t BatchMLAWithPagedKVCacheKernelTraitsDispatched(Params& params,
constexpr int merge_block_size = 256;
constexpr int blockx = KernelTraits::HEAD_DIM_VO / vec_size;
constexpr int blocky = (merge_block_size + blockx - 1) / blockx;
dim3 grids_merge(min(multiprocessor_count, params.token_num), params.q_num_head); // 128k is too large
dim3 grids_merge(multiprocessor_count, params.q_num_head); // 128k is too large
dim3 blocks_merge(blockx, blocky);
merge_multi_chunks_kernel<NV_TYPE, vec_size, blocky, KernelTraits::HEAD_DIM_VO><<<grids_merge, blocks_merge, 0, stream>>>(
reinterpret_cast<NV_TYPE*>(params.O_tmp),
params.m,
params.d,
params.seq_lens_this_time,
params.seq_lens_decoder,
params.seq_lens_encoder,
params.cumsum_q_seqlens,
params.batch_id_per_token,
reinterpret_cast<NV_TYPE*>(params.O),
params.chunk_num,
params.q_num_head,
params.chunk_size,
params.vo_head_dim,
params.token_num,
params.bsz,
params.max_draft_token_num
);
merge_multi_chunks_kernel<NV_TYPE,
vec_size,
blocky,
KernelTraits::HEAD_DIM_VO>
<<<grids_merge, blocks_merge, 0, stream>>>(
reinterpret_cast<NV_TYPE *>(params.O_tmp),
params.m,
params.d,
params.seq_lens_this_time,
params.seq_lens_decoder,
params.cumsum_q_seqlens,
params.batch_id_per_token,
params.chunk_size_device,
reinterpret_cast<NV_TYPE *>(params.O),
params.q_num_head,
params.vo_head_dim,
params.token_num,
params.bsz,
params.max_draft_token_num);
}
return cudaSuccess;
}

template <uint32_t HEAD_DIM_QK, uint32_t HEAD_DIM_VO, typename NV_TYPE, typename Params, bool USE_REG_EALLOC=false, bool USE_FIXED_BLOCK=false>
template <uint32_t HEAD_DIM_QK, uint32_t HEAD_DIM_VO, typename NV_TYPE, typename Params, bool USE_REG_EALLOC=false, bool USE_FIXED_BLOCK=true>
cudaError_t BatchMLAWithPagedKVCacheDispatched(Params& params, cudaStream_t stream) {
constexpr bool CAUSAL = true;
if constexpr (HEAD_DIM_QK == 576) {
DISPATCH_GROUP_SIZE(params.q_num_head, GROUP_SIZE,
BatchMLAWithPagedKVCacheKernelTraitsDispatched<
AttentionKernelTraits</*USE_TMA_LOAD_KV=*/false,
AttentionKernelTraits</*USE_TMA_LOAD_KV=*/true,
HEAD_DIM_QK,
HEAD_DIM_VO,
GROUP_SIZE,
@@ -249,18 +249,16 @@ struct prefill_softmax_state_t {
};
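// merge_multi_chunks_kernel combines the per-chunk partial attention outputs: each chunk
// wrote its partial O together with its softmax row max (m) and row sum (d), and this
// kernel rescales and sums them into the final output for every query token whose
// sequence was split across more than one chunk.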

template <typename T, int vec_size, uint32_t bdy, uint32_t HEAD_DIM>
__global__ void merge_multi_chunks_kernel(const T * __restrict__ multi_out, // [num_chunks, bsz, max_draft_token, num_heads, head_dim]
const float * __restrict__ multi_m, // [num_chunks, bsz, max_draft_token, num_heads]
const float * __restrict__ multi_d, // [num_chunks, bsz, max_draft_token, num_heads]
__global__ void merge_multi_chunks_kernel(const T * __restrict__ multi_out, // [max_num_chunks, bsz, max_draft_token, num_heads, head_dim]
const float * __restrict__ multi_m, // [max_num_chunks, bsz, max_draft_token, num_heads]
const float * __restrict__ multi_d, // [max_num_chunks, bsz, max_draft_token, num_heads]
const int * __restrict__ seq_lens_this_time,
const int * __restrict__ seq_lens_decoder,
const int * __restrict__ seq_lens_encoder,
const int *__restrict__ cu_seqlens_q,
const int * __restrict__ batch_id_per_token,
const int * __restrict__ chunk_size_device,
T * __restrict__ out, // [token_num, num_heads, head_dim]
const int num_chunks,
const int num_heads,
const int chunk_size,
const int head_dim,
const int token_num,
const int bsz,

@@ -271,13 +269,15 @@ __global__ void merge_multi_chunks_kernel(const T * __restrict__ multi_out, // [
__shared__ float md_smem[bdy * 2];
for (int qid = blockIdx.x; qid < token_num; qid += gridDim.x) {
const uint32_t bid = batch_id_per_token[qid];
// NOTE : (changwenbin) Batch_id_per_token is initialized to [:]=-1, Marking meaningless batch IDs.
if (bid == -1) continue;
const int seq_len_q = seq_lens_this_time[bid];
if (seq_len_q == 0) continue;
const uint32_t local_seq_id = qid - cu_seqlens_q[bid];
int seq_len_kv = seq_lens_decoder[bid];
if (seq_len_kv == 0) continue;
seq_len_kv += seq_len_q;
const int num_chunks_this_seq = cute::ceil_div(seq_len_kv, chunk_size);
const int num_chunks_this_seq = cute::ceil_div(seq_len_kv, chunk_size_device[0]);
if (num_chunks_this_seq <= 1) {
// not need merge
continue;
@@ -22,23 +22,18 @@ std::vector<paddle::Tensor> MultiHeadLatentAttentionKernel(
const paddle::Tensor& query,
const paddle::Tensor& key_cache,
const paddle::Tensor& value_cache,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& batch_id_per_token,
const paddle::Tensor& block_tables,
const paddle::Tensor& encoder_batch_ids,
const paddle::Tensor& encoder_tile_ids_per_batch,
const paddle::Tensor& encoder_num_blocks,
const paddle::Tensor& kv_batch_ids,
const paddle::Tensor& kv_tile_ids_per_batch,
const paddle::Tensor& kv_num_blocks,
const paddle::Tensor& decoder_batch_ids,
const paddle::Tensor& decoder_tile_ids_per_batch,
const paddle::Tensor& decoder_num_blocks,
const paddle::Tensor& decoder_num_blocks_cpu,
const paddle::Tensor& max_enc_len_this_time,
const paddle::Tensor& decoder_chunk_size_device,
const paddle::Tensor& max_dec_len_this_time,
const paddle::Tensor& max_len_kv,
const paddle::optional<paddle::Tensor>& attn_mask,

@@ -64,9 +59,12 @@ std::vector<paddle::Tensor> MultiHeadLatentAttentionKernel(
typedef PDTraits<D> traits_;
typedef typename traits_::data_t data_t;

int decoder_num_blocks_data = decoder_num_blocks_cpu.data<int>()[0];
// NOTE: (changwenbin) In cuda graph, it will be fixed in the capture stage
// int decoder_num_blocks_data = decoder_num_blocks_cpu.data<int>()[0];
int max_dec_len_this_time_data = max_dec_len_this_time.data<int>()[0];
int max_len_kv_data = max_len_kv.data<int>()[0];
// int chunk_size = decoder_chunk_size_cpu.data<int>()[0];
//

const bool mla_use_tensorcore = get_mla_use_tensorcore();
auto sm_version = GetSMVersion();

@@ -96,7 +94,6 @@ std::vector<paddle::Tensor> MultiHeadLatentAttentionKernel(
out_linear_smooths,
seq_lens_this_time,
seq_lens_decoder,
seq_lens_encoder,
cu_seqlens_q,
batch_id_per_token,
block_tables,

@@ -104,9 +101,8 @@ std::vector<paddle::Tensor> MultiHeadLatentAttentionKernel(
decoder_tile_ids_per_batch,
decoder_num_blocks,
cache_quant_type_str,
decoder_num_blocks_data,
decoder_chunk_size_device,
max_input_length,
max_len_kv_data,
softmax_scale,
quant_max_bound,
quant_min_bound,

@@ -145,23 +141,18 @@ std::vector<paddle::Tensor> MultiHeadLatentAttention(
const paddle::Tensor& query,
const paddle::Tensor& key_cache,
const paddle::Tensor& value_cache,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& batch_id_per_token,
const paddle::Tensor& block_tables,
const paddle::Tensor& encoder_batch_ids,
const paddle::Tensor& encoder_tile_ids_per_batch,
const paddle::Tensor& encoder_num_blocks,
const paddle::Tensor& kv_batch_ids,
const paddle::Tensor& kv_tile_ids_per_batch,
const paddle::Tensor& kv_num_blocks,
const paddle::Tensor& decoder_batch_ids,
const paddle::Tensor& decoder_tile_ids_per_batch,
const paddle::Tensor& decoder_num_blocks,
const paddle::Tensor& decoder_num_blocks_cpu,
const paddle::Tensor& max_enc_len_this_time,
const paddle::Tensor& decoder_chunk_size_device,
const paddle::Tensor& max_dec_len_this_time,
const paddle::Tensor& max_len_kv,
const paddle::optional<paddle::Tensor>& attn_mask,

@@ -208,23 +199,18 @@ std::vector<paddle::Tensor> MultiHeadLatentAttention(
query,
key_cache,
value_cache,
seq_lens_encoder,
seq_lens_decoder,
seq_lens_this_time,
cu_seqlens_q,
batch_id_per_token,
block_tables,
encoder_batch_ids,
encoder_tile_ids_per_batch,
encoder_num_blocks,
kv_batch_ids,
kv_tile_ids_per_batch,
kv_num_blocks,
decoder_batch_ids,
decoder_tile_ids_per_batch,
decoder_num_blocks,
decoder_num_blocks_cpu,
max_enc_len_this_time,
decoder_chunk_size_device,
max_dec_len_this_time,
max_len_kv,
attn_mask,

@@ -254,23 +240,18 @@ std::vector<paddle::Tensor> MultiHeadLatentAttention(
query,
key_cache,
value_cache,
seq_lens_encoder,
seq_lens_decoder,
seq_lens_this_time,
cu_seqlens_q,
batch_id_per_token,
block_tables,
encoder_batch_ids,
encoder_tile_ids_per_batch,
encoder_num_blocks,
kv_batch_ids,
kv_tile_ids_per_batch,
kv_num_blocks,
decoder_batch_ids,
decoder_tile_ids_per_batch,
decoder_num_blocks,
decoder_num_blocks_cpu,
max_enc_len_this_time,
decoder_chunk_size_device,
max_dec_len_this_time,
max_len_kv,
attn_mask,

@@ -307,23 +288,18 @@ std::vector<std::vector<int64_t>> MultiHeadLatentAttentionInferShape(
const std::vector<int64_t>& query_shape,
const std::vector<int64_t>& key_cache_shape,
const std::vector<int64_t>& value_cache_shape,
const std::vector<int64_t>& seq_lens_encoder_shape,
const std::vector<int64_t>& seq_lens_decoder_shape,
const std::vector<int64_t>& seq_lens_this_time_shape,
const std::vector<int64_t>& cu_seqlens_q_shape,
const std::vector<int64_t>& batch_id_per_token_shape,
const std::vector<int64_t>& block_tables_shape,
const std::vector<int64_t>& encoder_batch_ids_shape,
const std::vector<int64_t>& encoder_tile_ids_per_batch_shape,
const std::vector<int64_t>& encoder_num_blocks_shape,
const std::vector<int64_t>& kv_batch_ids_shape,
const std::vector<int64_t>& kv_tile_ids_per_batch_shape,
const std::vector<int64_t>& kv_num_blocks_shape,
const std::vector<int64_t>& decoder_batch_ids_shape,
const std::vector<int64_t>& decoder_tile_ids_per_batch_shape,
const std::vector<int64_t>& decoder_num_blocks_shape,
const std::vector<int64_t>& decoder_num_blocks_cpu_shape,
const std::vector<int64_t>& max_enc_len_this_time_shape,
const std::vector<int64_t>& decoder_chunk_size_device_shape,
const std::vector<int64_t>& max_dec_len_this_time_shape,
const std::vector<int64_t>& max_len_kv_shape,
const paddle::optional<std::vector<int64_t>>& attn_mask_shape,

@@ -361,23 +337,18 @@ std::vector<paddle::DataType> MultiHeadLatentAttentionInferDtype(
const paddle::DataType& query_dtype,
const paddle::DataType& key_cache_dtype,
const paddle::DataType& value_cache_dtype,
const paddle::DataType& seq_lens_encoder_dtype,
const paddle::DataType& seq_lens_decoder_dtype,
const paddle::DataType& seq_lens_this_time_dtype,
const paddle::DataType& cu_seqlens_q_dtype,
const paddle::DataType& batch_id_per_token_dtype,
const paddle::DataType& block_tables_dtype,
const paddle::DataType& encoder_batch_ids_dtype,
const paddle::DataType& encoder_tile_ids_per_batch_dtype,
const paddle::DataType& encoder_num_blocks_dtype,
const paddle::DataType& kv_batch_ids_dtype,
const paddle::DataType& kv_tile_ids_per_batch_dtype,
const paddle::DataType& kv_num_blocks_dtype,
const paddle::DataType& decoder_batch_ids_dtype,
const paddle::DataType& decoder_tile_ids_per_batch_dtype,
const paddle::DataType& decoder_num_blocks_dtype,
const paddle::DataType& decoder_num_blocks_cpu_dtype,
const paddle::DataType& max_enc_len_this_time_dtype,
const paddle::DataType& decoder_chunk_size_device_dtype,
const paddle::DataType& max_dec_len_this_time_dtype,
const paddle::DataType& max_len_kv_dtype,
const paddle::optional<paddle::DataType>& attn_mask_dtype,

@@ -415,23 +386,18 @@ PD_BUILD_STATIC_OP(multi_head_latent_attention)
.Inputs({"query",
"key_cache",
"value_cache",
"seq_lens_encoder",
"seq_lens_decoder",
"seq_lens_this_time",
"cu_seqlens_q",
"batch_id_per_token",
"block_tables",
"encoder_batch_ids",
"encoder_tile_ids_per_batch",
"encoder_num_blocks",
"kv_batch_ids",
"kv_tile_ids_per_batch",
"kv_num_blocks",
"decoder_batch_ids",
"decoder_tile_ids_per_batch",
"decoder_num_blocks",
"decoder_num_blocks_cpu",
"max_enc_len_this_time",
"decoder_chunk_size_device",
"max_dec_len_this_time",
"max_len_kv",
paddle::Optional("attn_mask"),
Binary image files not shown (2 files, 81 KiB each).
@@ -99,6 +99,8 @@ class ForwardMeta:
decoder_batch_ids: Optional[paddle.Tensor] = None
# Maps the thread block index (blockIdx.x) to the specific data tile being processed within that batch for the decoder stage in multi_query_append_attention_warp1_4_kernel.
decoder_tile_ids_per_batch: Optional[paddle.Tensor] = None
# The number of blocks that the attention backend can use in the decode stage
decoder_num_blocks_device: Optional[paddle.Tensor] = None
# The number of CUDA blocks to launch in the x-dimension for the multi_query_append_attention_warp1_4_kernel, defining its grids.x.
decoder_num_blocks_cpu: Optional[paddle.Tensor] = None
# A tensor that holds multiple lengths related to prefill or decode stages.

@@ -118,6 +120,8 @@ class ForwardMeta:
# The maximum sequence length of the KV cache, which may represent the current maximum decoder length.
max_len_kv_cpu: Optional[paddle.Tensor] = None

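# Device tensor holding the KV chunk size chosen for the MLA decode stage (written by get_block_shape_and_split_kv_block).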
decoder_chunk_size_device: Optional[paddle.Tensor] = None

# Sequence length of encoder for every batch
seq_lens_encoder: Optional[paddle.Tensor] = None
# Sequence length of Encoder for every batch

@@ -141,6 +141,8 @@ class AppendAttentionBackend(AttentionBackend):
forward_meta.decoder_batch_ids,
forward_meta.decoder_tile_ids_per_batch,
forward_meta.decoder_num_blocks_cpu,
forward_meta.decoder_num_blocks_device,
forward_meta.decoder_chunk_size_device,
forward_meta.max_len_tensor_cpu,
forward_meta.encoder_batch_ids,
forward_meta.encoder_tile_ids_per_batch,

@@ -198,6 +198,8 @@ class FlashAttentionBackend(AttentionBackend):
forward_meta.decoder_batch_ids,
forward_meta.decoder_tile_ids_per_batch,
forward_meta.decoder_num_blocks_cpu,
forward_meta.decoder_num_blocks_device,
forward_meta.decoder_chunk_size_device,
forward_meta.max_len_tensor_cpu,
forward_meta.encoder_batch_ids,
forward_meta.encoder_tile_ids_per_batch,
@@ -187,9 +187,11 @@ class MLAAttentionBackend(AttentionBackend):
|
||||
forward_meta.seq_lens_encoder,
|
||||
forward_meta.seq_lens_decoder,
|
||||
forward_meta.seq_lens_this_time,
|
||||
forward_meta.decoder_batch_ids,
|
||||
forward_meta.decoder_tile_ids_per_batch,
|
||||
forward_meta.decoder_batch_ids, # decoder_batch_ids_per_ctax
|
||||
forward_meta.decoder_tile_ids_per_batch, # decoder_chunk_ids_per_ctax_each_batch
|
||||
forward_meta.decoder_num_blocks_cpu,
|
||||
forward_meta.decoder_num_blocks_device,
|
||||
forward_meta.decoder_chunk_size_device,
|
||||
forward_meta.max_len_tensor_cpu,
|
||||
forward_meta.encoder_batch_ids,
|
||||
forward_meta.encoder_tile_ids_per_batch,
|
||||
@@ -347,23 +349,18 @@ class MLAAttentionBackend(AttentionBackend):
q,
latent_cache,
latent_cache,
forward_meta.seq_lens_encoder,
forward_meta.seq_lens_decoder,
forward_meta.seq_lens_this_time,
forward_meta.cu_seqlens_q,
forward_meta.batch_id_per_token,
metadata.block_tables,
forward_meta.encoder_batch_ids,
forward_meta.encoder_tile_ids_per_batch,
forward_meta.encoder_num_blocks_x_cpu,
forward_meta.kv_batch_ids,
forward_meta.kv_tile_ids_per_batch,
forward_meta.kv_num_blocks_x_cpu,
forward_meta.decoder_batch_ids,
forward_meta.decoder_tile_ids_per_batch,
forward_meta.decoder_num_blocks_cpu,
forward_meta.decoder_num_blocks_cpu,
metadata.max_enc_len_this_time,
forward_meta.decoder_num_blocks_device,
forward_meta.decoder_chunk_size_device,
metadata.max_dec_len_this_time,
forward_meta.max_len_kv_cpu,
None, # attn_mask
@@ -468,23 +465,18 @@ class MLAAttentionBackend(AttentionBackend):
q,
latent_cache,
latent_cache,
forward_meta.seq_lens_encoder,
forward_meta.seq_lens_decoder,
forward_meta.seq_lens_this_time,
forward_meta.cu_seqlens_q,
forward_meta.batch_id_per_token,
metadata.block_tables,
forward_meta.encoder_batch_ids,
forward_meta.encoder_tile_ids_per_batch,
forward_meta.encoder_num_blocks_x_cpu,
forward_meta.kv_batch_ids,
forward_meta.kv_tile_ids_per_batch,
forward_meta.kv_num_blocks_x_cpu,
forward_meta.decoder_batch_ids,
forward_meta.decoder_tile_ids_per_batch,
forward_meta.decoder_num_blocks_cpu,
forward_meta.decoder_num_blocks_cpu,
metadata.max_enc_len_this_time,
forward_meta.decoder_num_blocks_device,
forward_meta.decoder_chunk_size_device,
metadata.max_dec_len_this_time,
forward_meta.max_len_kv_cpu,
None, # attn_mask
@@ -30,7 +30,9 @@ def get_block_shape_and_split_kv_block(
seq_lens_this_time: paddle.Tensor,
decoder_batch_ids: paddle.Tensor,
decoder_tile_ids_per_batch: paddle.Tensor,
decoder_num_blocks_x_cpu: paddle.Tensor,
decoder_num_blocks_cpu: paddle.Tensor,
decoder_num_blocks_device: paddle.Tensor,
decoder_chunk_size_device: paddle.Tensor,
max_len_tensor_cpu: paddle.Tensor,
encoder_batch_ids: paddle.Tensor,
encoder_tile_ids_per_batch: paddle.Tensor,
@@ -55,7 +57,9 @@ def get_block_shape_and_split_kv_block(
seq_lens_this_time,
decoder_batch_ids,
decoder_tile_ids_per_batch,
decoder_num_blocks_x_cpu,
decoder_num_blocks_cpu,
decoder_num_blocks_device,
decoder_chunk_size_device,
max_len_tensor_cpu,
encoder_batch_ids,
encoder_tile_ids_per_batch,
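For reference, a hedged sketch of an assumed call site for the updated helper; argument names are taken from the hunks above, and arguments not shown in this diff are elided rather than guessed.

get_block_shape_and_split_kv_block(
    forward_meta.seq_lens_encoder,
    forward_meta.seq_lens_decoder,
    forward_meta.seq_lens_this_time,
    forward_meta.decoder_batch_ids,
    forward_meta.decoder_tile_ids_per_batch,
    forward_meta.decoder_num_blocks_cpu,     # pinned host copy of the decode grid size
    forward_meta.decoder_num_blocks_device,  # device copy, readable under CUDA Graph capture
    forward_meta.decoder_chunk_size_device,  # chunk size selected by the on-device search
    forward_meta.max_len_tensor_cpu,
    forward_meta.encoder_batch_ids,
    forward_meta.encoder_tile_ids_per_batch,
    # ... remaining encoder/kv metadata arguments are unchanged by this diff
)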
@@ -210,6 +210,12 @@ class MTPProposer(Proposer):
self.model_inputs["decoder_num_blocks_cpu"] = paddle.zeros_like(
self.target_model_inputs["decoder_num_blocks_cpu"]
).pin_memory()
self.model_inputs["decoder_num_blocks_device"] = paddle.zeros_like(
self.target_model_inputs["decoder_num_blocks_device"]
)
self.model_inputs["decoder_chunk_size_device"] = paddle.zeros_like(
self.target_model_inputs["decoder_chunk_size_device"]
)
self.model_inputs["max_len_tensor_cpu"] = paddle.zeros_like(
self.target_model_inputs["max_len_tensor_cpu"]
).cpu()
@@ -338,6 +344,8 @@ class MTPProposer(Proposer):
self.model_inputs["decoder_batch_ids"] = None
self.model_inputs["decoder_tile_ids_per_batch"] = None
self.model_inputs["decoder_num_blocks_cpu"] = None # Pinning Memory
self.model_inputs["decoder_num_blocks_device"] = None
self.model_inputs["decoder_chunk_size_device"] = None
self.model_inputs["max_len_tensor_cpu"] = None # CPU
self.model_inputs["encoder_batch_ids"] = None
self.model_inputs["encoder_tile_ids_per_batch"] = None
@@ -528,6 +536,8 @@ class MTPProposer(Proposer):
decoder_batch_ids=self.model_inputs["decoder_batch_ids"],
decoder_tile_ids_per_batch=self.model_inputs["decoder_tile_ids_per_batch"],
decoder_num_blocks_cpu=self.model_inputs["decoder_num_blocks_cpu"],
decoder_num_blocks_device=self.model_inputs["decoder_num_blocks_device"],
decoder_chunk_size_device=self.model_inputs["decoder_chunk_size_device"],
max_len_tensor_cpu=self.model_inputs["max_len_tensor_cpu"],
seq_lens_encoder=self.model_inputs["seq_lens_encoder"],
seq_lens_decoder=self.model_inputs["seq_lens_decoder"],
@@ -838,6 +838,8 @@ class GPUModelRunner(ModelRunnerBase):
self.share_inputs["decoder_batch_ids"] = None
self.share_inputs["decoder_tile_ids_per_batch"] = None
self.share_inputs["decoder_num_blocks_cpu"] = None # Pinning Memory
self.share_inputs["decoder_num_blocks_device"] = None
self.share_inputs["decoder_chunk_size_device"] = None
self.share_inputs["max_len_tensor_cpu"] = None # CPU
self.share_inputs["encoder_batch_ids"] = None
self.share_inputs["encoder_tile_ids_per_batch"] = None
@@ -991,6 +993,8 @@ class GPUModelRunner(ModelRunnerBase):
)

self.share_inputs["ids_remove_padding"].copy_(ids_remove_padding, False)
# NOTE: (changwenbin) Initialize all max_num_seq entries to '-1' before copying, marking illegal positions
self.share_inputs["batch_id_per_token"][:] = -1
self.share_inputs["batch_id_per_token"].copy_(batch_id_per_token, False)
self.share_inputs["cu_seqlens_q"].copy_(cu_seqlens_q, False)
self.share_inputs["cu_seqlens_k"].copy_(cu_seqlens_k, False)
@@ -1070,6 +1074,10 @@ class GPUModelRunner(ModelRunnerBase):
decoder_batch_ids=self.share_inputs["decoder_batch_ids"],
decoder_tile_ids_per_batch=self.share_inputs["decoder_tile_ids_per_batch"],
decoder_num_blocks_cpu=self.share_inputs["decoder_num_blocks_cpu"],
# NOTE: (changwenbin) The MLA kernel only needs decoder_num_blocks_device as a GPU tensor
# (in place of the CPU copy), which keeps it compatible with CUDA Graph.
decoder_num_blocks_device=self.share_inputs["decoder_num_blocks_device"],
decoder_chunk_size_device=self.share_inputs["decoder_chunk_size_device"],
max_len_tensor_cpu=self.share_inputs["max_len_tensor_cpu"],
seq_lens_encoder=self.share_inputs["seq_lens_encoder"],
seq_lens_decoder=self.share_inputs["seq_lens_decoder"],
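A minimal sketch of the distinction flagged in the NOTE above; the allocation calls are taken from this diff, while the reading of the CUDA Graph constraint is an interpretation rather than a statement from the commit.

import paddle

# Pinned host tensor: convenient for host-side reads, but a value read on the CPU
# during capture would be baked into the recorded graph.
decoder_num_blocks_cpu = paddle.full([1], 0, dtype="int32").pin_memory()

# Device tensor: the MLA decode kernel reads the block count from GPU memory,
# so the same captured graph stays valid when the count changes between steps.
decoder_num_blocks_device = paddle.full([1], 0, dtype="int32")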
@@ -1196,8 +1204,12 @@ class GPUModelRunner(ModelRunnerBase):
decoder_step_token_num = self.speculative_config.num_speculative_tokens + 1
group_size = np.ceil(num_heads / self.model_config.kv_num_heads)

decode_max_tile_size = self.parallel_config.max_num_seqs * np.ceil(
(decoder_step_token_num * group_size) / decoder_block_shape_q
# NOTE: (changwenbin) When using auto_chunk, decode_max_tile_size must cover the
# maximum case; the extra factor of 1024 is enough for a 128K context.
decode_max_tile_size = (
1024
* self.parallel_config.max_num_seqs
* np.ceil((decoder_step_token_num * group_size) / decoder_block_shape_q)
)
encode_max_tile_size = self.parallel_config.max_num_seqs * np.ceil(
(self.model_config.max_model_len * group_size) / encoder_block_shape_q
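A small worked check of the sizing above; the concrete values are illustrative assumptions, not taken from the repository's configuration.

import numpy as np

max_num_seqs = 64
decoder_step_token_num = 2      # num_speculative_tokens + 1
group_size = 10                 # ceil(num_heads / kv_num_heads)
decoder_block_shape_q = 12

tiles_per_seq = np.ceil((decoder_step_token_num * group_size) / decoder_block_shape_q)
# With auto chunking each sequence may be split into up to ~1024 KV chunks
# (e.g. 1024 chunks of 128 tokens cover a 131072-token, i.e. 128K, context),
# hence the extra factor of 1024 in decode_max_tile_size.
decode_max_tile_size = 1024 * max_num_seqs * tiles_per_seq
print(int(decode_max_tile_size))  # 131072 tiles for these illustrative values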
@@ -1208,6 +1220,10 @@ class GPUModelRunner(ModelRunnerBase):
self.share_inputs["decoder_batch_ids"] = paddle.full([int(decode_max_tile_size)], 0, dtype="int32")
self.share_inputs["decoder_tile_ids_per_batch"] = paddle.full([int(decode_max_tile_size)], 0, dtype="int32")
self.share_inputs["decoder_num_blocks_cpu"] = paddle.full([1], 0, dtype="int32").pin_memory()
# NOTE: (changwenbin) The MLA kernel only needs decoder_num_blocks_device as a GPU tensor
# (in place of the CPU copy), which keeps it compatible with CUDA Graph.
self.share_inputs["decoder_num_blocks_device"] = paddle.full([1], 0, dtype="int32")
self.share_inputs["decoder_chunk_size_device"] = paddle.full([1], 64, dtype="int32")
self.share_inputs["max_len_tensor_cpu"] = paddle.full([8], 0, dtype="int32").cpu()

self.share_inputs["encoder_batch_ids"] = paddle.full([int(encode_max_tile_size)], 0, dtype="int32")
@@ -380,10 +380,12 @@ class TestAppendGroupQueryAttnWithRope(unittest.TestCase):
self.max_enc_len_this_time = paddle.to_tensor([self.max_enc_len_this_time], "int32", place=paddle.CPUPlace())
self.max_dec_len_this_time = paddle.to_tensor([self.max_dec_len_this_time], "int32", place=paddle.CPUPlace())
self.seq_lens_this_time = self.seq_lens_encoder

self.decoder_batch_ids = paddle.full([self.batch_size], 0, dtype="int32")
self.decoder_tile_ids_per_batch = paddle.full([self.batch_size], 0, dtype="int32")
decode_max_tile_size = 1024 * self.batch_size * np.ceil((2 * 10) / 12)
self.decoder_batch_ids = paddle.full([int(decode_max_tile_size)], 0, dtype="int32")
self.decoder_tile_ids_per_batch = paddle.full([int(decode_max_tile_size)], 0, dtype="int32")
self.decoder_num_blocks_cpu = paddle.full([1], 0, dtype="int32").pin_memory()
self.decoder_num_blocks_device = paddle.full([1], 0, dtype="int32")
self.decoder_chunk_size_device = paddle.full([1], 64, dtype="int32")
self.max_len_tensor_cpu = paddle.full([8], 0, dtype="int32").cpu()

self.encoder_batch_ids = paddle.full([self.batch_size], 0, dtype="int32")
@@ -484,6 +486,8 @@ class TestAppendGroupQueryAttnWithRope(unittest.TestCase):
self.decoder_batch_ids,
self.decoder_tile_ids_per_batch,
self.decoder_num_blocks_cpu,
self.decoder_num_blocks_device,
self.decoder_chunk_size_device,
self.max_len_tensor_cpu,
self.encoder_batch_ids,
self.encoder_tile_ids_per_batch,
@@ -378,9 +378,12 @@ class TestAppendGroupQueryAttnWithRope(unittest.TestCase):
self.max_dec_len_this_time = paddle.to_tensor([self.max_dec_len_this_time], "int32", place=paddle.CPUPlace())
self.seq_lens_this_time = self.seq_lens_encoder

self.decoder_batch_ids = paddle.full([self.batch_size], 0, dtype="int32")
self.decoder_tile_ids_per_batch = paddle.full([self.batch_size], 0, dtype="int32")
decode_max_tile_size = 1024 * self.batch_size * np.ceil((2 * 10) / 12)
self.decoder_batch_ids = paddle.full([int(decode_max_tile_size)], 0, dtype="int32")
self.decoder_tile_ids_per_batch = paddle.full([int(decode_max_tile_size)], 0, dtype="int32")
self.decoder_num_blocks_cpu = paddle.full([1], 0, dtype="int32").pin_memory()
self.decoder_num_blocks_device = paddle.full([1], 0, dtype="int32")
self.decoder_chunk_size_device = paddle.full([1], 64, dtype="int32")
self.max_len_tensor_cpu = paddle.full([8], 0, dtype="int32").cpu()
self.encoder_batch_ids = paddle.full([self.batch_size], 0, dtype="int32")
self.encoder_tile_ids_per_batch = paddle.full([self.batch_size], 0, dtype="int32")
@@ -464,6 +467,8 @@ class TestAppendGroupQueryAttnWithRope(unittest.TestCase):
self.decoder_batch_ids,
self.decoder_tile_ids_per_batch,
self.decoder_num_blocks_cpu,
self.decoder_num_blocks_device,
self.decoder_chunk_size_device,
self.max_len_tensor_cpu,
self.encoder_batch_ids,
self.encoder_tile_ids_per_batch,
@@ -192,7 +192,7 @@ class TestTreeMask(unittest.TestCase):
decoder_block_shape_q = 16
group_size = self.num_q_head // self.num_kv_head
decode_max_tile_size = (
self.bsz * (decoder_step_token_num * group_size + decoder_block_shape_q - 1) / decoder_block_shape_q
1024 * self.bsz * (decoder_step_token_num * group_size + decoder_block_shape_q - 1) / decoder_block_shape_q
)
encode_max_tile_size = (
self.bsz * (self.max_seq_len * group_size + encoder_block_shape_q - 1) / encoder_block_shape_q
@@ -202,6 +202,8 @@ class TestTreeMask(unittest.TestCase):
decoder_batch_ids = paddle.full([int(decode_max_tile_size)], 0, dtype="int32")
decoder_tile_ids_per_batch = paddle.full([int(decode_max_tile_size)], 0, dtype="int32")
decoder_num_blocks = paddle.full([1], 0, dtype="int32").pin_memory()
decoder_num_blocks_device = paddle.full([1], 0, dtype="int32")
decoder_chunk_size_device = paddle.full([1], 64, dtype="int32")
max_len_tensor_cpu = paddle.full([8], 0, dtype="int32").cpu()
encoder_batch_ids = paddle.full([int(encode_max_tile_size)], 0, dtype="int32")
encoder_tile_ids_per_batch = paddle.full([int(encode_max_tile_size)], 0, dtype="int32")
@@ -222,6 +224,8 @@ class TestTreeMask(unittest.TestCase):
decoder_batch_ids,
decoder_tile_ids_per_batch,
decoder_num_blocks,
decoder_num_blocks_device,
decoder_chunk_size_device,
max_len_tensor_cpu,
encoder_batch_ids,
encoder_tile_ids_per_batch,