// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

/*
 * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri
 * Dao. Licensed under the BSD 3-Clause.
 *
 * Modified by the FlashInfer team.
 */

#ifndef ATTENTION_HOPPER_PREFILL_SM90_CUH_
#define ATTENTION_HOPPER_PREFILL_SM90_CUH_

#include <cuda.h>
#include <cuda_device_runtime_api.h>
#include <cutlass/arch/reg_reconfig.h>
#include <cutlass/array.h>
#include <cutlass/cutlass.h>
#include <cutlass/numeric_conversion.h>
#include <cutlass/numeric_types.h>

#include <type_traits>
#include <vector>

#include "attention_updater.cuh"
#include "cute/tensor.hpp"
#include "cutlass/pipeline/pipeline.hpp"
#include "epilogue.cuh"
#include "helper.h"
#include "kernel_traits.cuh"
#include "mainloop_mma.cuh"
#include "mainloop_load.cuh"
#include "utils.cuh"

#ifdef DEBUG_MLA
#undef DEBUG_MLA
#endif
// #define DEBUG_MLA

namespace mla_attn {

using namespace cute;

template <typename DTypeQ_, typename DTypeKV_, typename DTypeO_, typename IdType_>
struct Params {
  using DTypeQ = DTypeQ_;
  using DTypeKV = DTypeKV_;
  using DTypeO = DTypeO_;
  using IdType = IdType_;

  alignas(16) DTypeQ *Q;      // [token_num, head_num, dim_head]
  alignas(16) DTypeKV *KV;    // [max_block_num, block_size, dim_head]
  alignas(16) DTypeO *O;      // [token_num, head_num, dim_head]
  alignas(16) DTypeO *O_tmp;  // [max_num_chunks, bsz, head_num, dim_head]
  alignas(16) float *m;       // [max_num_chunks, bsz * max_draft_token_num * head_num]
  alignas(16) float *d;       // [max_num_chunks, bsz * max_draft_token_num * head_num]

  alignas(16) IdType *block_tables;
  alignas(16) IdType *seq_lens_this_time;
  alignas(16) IdType *seq_lens_decoder;
  alignas(16) IdType *cumsum_q_seqlens;
  alignas(16) IdType *batch_id_per_token;

  alignas(16) IdType *batch_ids;
  alignas(16) IdType *tile_ids_per_batch;
  alignas(16) IdType *num_blocks_x;
  alignas(16) IdType *chunk_size_device;

  uint32_t q_stride_bsz;
  uint32_t q_stride_head_num;

  uint32_t kv_stride_block_num;
  uint32_t kv_stride_block_size;

  uint32_t o_stride_bsz;
  uint32_t o_stride_head_num;

  int bsz;
  int token_num;
  int max_block_num;
  int max_block_num_per_seq;
  int q_num_head;
  int qk_head_dim;
  int vo_head_dim;
  int block_size;
  int max_draft_token_num;
  int chunk_num;

  float sm_scale;
};
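
// Notes on the fields above, based on how they are consumed in this file:
//  - batch_ids / tile_ids_per_batch / num_blocks_x describe the flattened list of
//    (batch, KV-tile) work items that the persistent kernel below iterates over;
//    num_blocks_x[0] is the total number of work items.
//  - chunk_size_device[0] is the KV chunk length used when a sequence's cache is split
//    across chunks; O_tmp, m and d then hold per-chunk partial outputs and softmax
//    statistics that merge_multi_chunks_kernel combines whenever chunk_num > 1.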

#define DISPATCH_GROUP_SIZE(group_size, GROUP_SIZE, ...)    \
  if (group_size == 8) {                                    \
    constexpr size_t GROUP_SIZE = 8;                        \
    __VA_ARGS__                                             \
  } else if (group_size == 16) {                            \
    constexpr size_t GROUP_SIZE = 16;                       \
    __VA_ARGS__                                             \
  } else if (group_size == 64) {                            \
    constexpr size_t GROUP_SIZE = 64;                       \
    __VA_ARGS__                                             \
  } else {                                                  \
    PD_THROW("Unsupported group_size: ", group_size);       \
    return cudaErrorNotSupported;                           \
  }
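
// DISPATCH_GROUP_SIZE lifts the runtime group size (params.q_num_head in the dispatch at the
// bottom of this file) into the compile-time constant GROUP_SIZE: the code passed via
// __VA_ARGS__ is instantiated once per supported value, and any other size fails with
// cudaErrorNotSupported.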

template <typename CollectiveMainloop, typename CollectiveEpilogue, typename Ktraits, bool CAUSAL,
          int SM_COUNT = 132, bool USE_REG_EALLOC = false, bool USE_FIXED_BLOCK = true>
__global__ void __launch_bounds__(Ktraits::NUM_WARPS * cutlass::NumThreadsPerWarp, 1)
    MLAWithKVCacheKernel(CUTE_GRID_CONSTANT
                         typename CollectiveMainloop::Params const mainloop_params,
                         CUTE_GRID_CONSTANT
                         typename CollectiveEpilogue::Params const epilogue_params) {
  using DTypeQ = typename Ktraits::DTypeQ;
  using DTypeKV = typename Ktraits::DTypeKV;
  using DTypeO = typename Ktraits::DTypeO;
  using DTypeQKAccum = typename Ktraits::DTypeQKAccum;
  using TileShape_QKD = typename Ktraits::TileShape_QKD;
  using TileShape_PDV = typename Ktraits::TileShape_PDV;

  static constexpr int NUM_MMA_THREADS = Ktraits::NUM_MMA_THREADS;
  static constexpr int NUM_COPY_THREADS = Ktraits::NUM_PRODUCER_THREADS;
  static constexpr int BLOCK_SHAPE_Q = Ktraits::BLOCK_SHAPE_Q;
  static constexpr int BLOCK_SHAPE_KV = Ktraits::BLOCK_SHAPE_KV;
  const int num_blocks_x = mainloop_params.num_blocks_x[0];
  const int chunk_size = mainloop_params.chunk_size_device[0];

  static constexpr bool use_tma_load_kv = CollectiveMainloop::USE_TMA_LOAD_KV;

  using MainloopPipeline = typename CollectiveMainloop::MainloopPipeline;
  using PipelineParams = typename MainloopPipeline::Params;
  using PipelineState = typename MainloopPipeline::PipelineState;

  using MainloopPipelineQ = typename CollectiveMainloop::MainloopPipelineQ;
  using PipelineParamsQ = typename MainloopPipelineQ::Params;
  using PipelineStateQ = typename MainloopPipelineQ::PipelineState;

  extern __shared__ char shared_memory[];
  auto& shared_storage = *reinterpret_cast<typename Ktraits::SharedStorage*>(shared_memory);

  int const lane_predicate = cute::elect_one_sync();
  int const warp_idx = cutlass::canonical_warp_idx_sync();

  if (warp_idx == 0 && lane_predicate) {
    CollectiveMainloop::prefetch_tma_descriptors(mainloop_params);
    CollectiveEpilogue::prefetch_tma_descriptors(epilogue_params);
  }
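
  // Above: a single elected thread of warp 0 prefetches the TMA descriptors for the mainloop
  // (Q/KV loads) and the epilogue (O stores) before any warp group issues TMA traffic.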

  // Thread index within the warp group.
  int const warp_group_thread_idx = threadIdx.x % cutlass::NumThreadsPerWarpGroup;

  PipelineParams pipeline_params;
  int warp_group_idx = cutlass::canonical_warp_group_idx();
  pipeline_params.role = warp_group_idx == 0 ? MainloopPipeline::ThreadCategory::Producer
                                             : MainloopPipeline::ThreadCategory::Consumer;
  if constexpr (use_tma_load_kv) {
    pipeline_params.is_leader = warp_group_thread_idx == 0;
    pipeline_params.num_consumers = NUM_MMA_THREADS;
  } else {
    pipeline_params.producer_arv_count = NUM_COPY_THREADS;
    pipeline_params.consumer_arv_count = NUM_MMA_THREADS;
  }
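
  // In the KV pipeline parameters above: with TMA a single leader thread drives the pipeline
  // and the expected transaction bytes are tracked; on the cp.async path every producer thread
  // arrives, so explicit producer/consumer arrive counts are used instead.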

  PipelineParamsQ pipeline_params_q;
  pipeline_params_q.role = warp_group_idx == 0 ? MainloopPipelineQ::ThreadCategory::Producer
                                               : MainloopPipelineQ::ThreadCategory::Consumer;
  pipeline_params_q.producer_arv_count = NUM_COPY_THREADS;
  pipeline_params_q.consumer_arv_count = cutlass::NumThreadsPerWarpGroup;  // only one warp group consumes Q (QK GEMM)

  MainloopPipelineQ pipeline_q(shared_storage.pipeline_q, pipeline_params_q);
  MainloopPipeline pipeline_kv = [&] {
    if constexpr (use_tma_load_kv) {
      pipeline_params.transaction_bytes = CollectiveMainloop::TmaTransactionBytesKV;
      return MainloopPipeline(shared_storage.pipeline_kv, pipeline_params,
                              /*cluster_shape=*/Shape<_1, _1, _1>{});
    } else {
      return MainloopPipeline(shared_storage.pipeline_kv, pipeline_params);
    }
  }();
  __syncthreads();
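
  // Warp-specialized split: warp group 0 acts as the producer, loading Q tiles and paged KV
  // blocks into shared memory through the pipelines above; the remaining warp groups are
  // consumers that run the QK/PV MMAs with online softmax and write O in the epilogue.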
  CollectiveMainloop collective_mainloop;
  CollectiveEpilogue collective_epilogue;

  if (warp_group_idx == 0) {
    // producer
    if constexpr (USE_REG_EALLOC) {
      cutlass::arch::warpgroup_reg_dealloc<72>();
    }
    const uint32_t warp_idx_in_warpgroup = __shfl_sync(0xffffffff, warp_idx % 4, 0);

    PipelineStateQ smem_pipe_write_q = cutlass::make_producer_start_state<MainloopPipelineQ>();
    PipelineState smem_pipe_write_kv = cutlass::make_producer_start_state<MainloopPipeline>();
    for (int i = blockIdx.x; i < num_blocks_x; i += SM_COUNT) {
      const int bid = mainloop_params.batch_ids[i];
      const int tile_id = mainloop_params.tile_ids_per_batch[i];
      const int seq_len_now = mainloop_params.seq_lens_this_time[bid];
      const int seq_len_decoder_now = mainloop_params.seq_lens_decoder[bid] + seq_len_now;
      const int start_token_idx = mainloop_params.cumsum_q_seqlens[bid];
      cutlass::arch::NamedBarrier::sync(Ktraits::NUM_THREADS,
                                        /*id=*/static_cast<int>(NamedBarriers::kWG0WG1WG2Sync));

      // load Q
      collective_mainloop.load_q(
          mainloop_params,
          pipeline_q,
          smem_pipe_write_q,
          shared_storage,
          threadIdx.x,
          bid);

      if constexpr (!use_tma_load_kv) {
        // load kv
        collective_mainloop.load_kv(
            mainloop_params,
            pipeline_kv,
            smem_pipe_write_kv,
            shared_storage,
            bid,
            seq_len_decoder_now,
            tile_id);
      } else {
        if (warp_idx_in_warpgroup == 0) {
          // load kv with TMA
          collective_mainloop.load_kv_tma(
              mainloop_params,
              pipeline_kv,
              smem_pipe_write_kv,
              shared_storage,
              bid,
              seq_len_decoder_now,
              tile_id);
        }
      }
    }
  } else {
    // consumer
    if constexpr (USE_REG_EALLOC) {
      cutlass::arch::warpgroup_reg_alloc<216>();
    }
    PipelineStateQ smem_pipe_read_q;
    PipelineState smem_pipe_read_kv;

    typename Ktraits::TiledMmaPVSS tiled_mma_pv;
    Tensor tOrO = partition_fragment_C(tiled_mma_pv, select<0, 1>(TileShape_PDV{}));

    auto attention_updater =
        OnlineSoftmax<2 * size<1>(tOrO), /*WITH_SCALE=*/true>(mainloop_params.sm_scale);
    for (int i = blockIdx.x; i < num_blocks_x; i += SM_COUNT) {
      clear(tOrO);
      clear(attention_updater.scores_scale);
      const int bid = mainloop_params.batch_ids[i];
      const int tile_id = mainloop_params.tile_ids_per_batch[i];
      const int seq_len_now = mainloop_params.seq_lens_this_time[bid];
      const int seq_len_decoder_now = mainloop_params.seq_lens_decoder[bid] + seq_len_now;
      const int start_token_idx = mainloop_params.cumsum_q_seqlens[bid];
      cutlass::arch::NamedBarrier::sync(Ktraits::NUM_THREADS,
                                        /*id=*/static_cast<int>(NamedBarriers::kWG0WG1WG2Sync));

      if constexpr (BLOCK_SHAPE_KV == 64) {
        mma_f16<Ktraits, CAUSAL>(
            mainloop_params,
            pipeline_q,
            smem_pipe_read_q,
            pipeline_kv,
            smem_pipe_read_kv,
            tOrO,
            attention_updater,
            threadIdx.x - NUM_COPY_THREADS,
            bid,
            seq_len_decoder_now,
            seq_len_now,
            tile_id,
            shared_storage);
      } else if (BLOCK_SHAPE_KV == 32) {
        mma_f16_two_stages<Ktraits, CAUSAL>(
            mainloop_params,
            pipeline_q,
            smem_pipe_read_q,
            pipeline_kv,
            smem_pipe_read_kv,
            tOrO,
            attention_updater,
            threadIdx.x - NUM_COPY_THREADS,
            bid,
            seq_len_decoder_now,
            seq_len_now,
            tile_id,
            shared_storage);
      }
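
      // Above, BLOCK_SHAPE_KV == 64 dispatches to the single-stage mainloop, while the 32-wide
      // KV tile path uses the two-stage variant; both are defined in mainloop_mma.cuh.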

      collective_epilogue.store(
          epilogue_params,
          tOrO,
          attention_updater.get_lse(),
          shared_storage,
          tiled_mma_pv,
          threadIdx.x - NUM_COPY_THREADS,
          bid,
          mainloop_params.bsz,
          seq_len_now,
          start_token_idx,
          tile_id,
          seq_len_decoder_now,
          chunk_size,
          mainloop_params.max_draft_token_num,
          mainloop_params.o_stride_bsz);
    }
  }
}

template <typename KernelTraits, bool CAUSAL, typename Params, bool USE_REG_EALLOC = false,
          bool USE_FIXED_BLOCK = true>
cudaError_t BatchMLAWithPagedKVCacheKernelTraitsDispatched(Params& params, cudaStream_t stream) {
  using DTypeQ = typename KernelTraits::DTypeQ;
  using DTypeKV = typename KernelTraits::DTypeKV;
  using DTypeO = typename KernelTraits::DTypeO;
  using IdType = typename KernelTraits::IdType;
  using NV_TYPE = typename KernelTraits::NV_TYPE;

  using CollectiveMainloop = CollectiveMainloop<KernelTraits, CAUSAL>;
  using CollectiveEpilogue = CollectiveEpilogue<KernelTraits>;

  typename CollectiveMainloop::Params mainloop_params = CollectiveMainloop::to_underlying_arguments({
      make_layout(make_shape(KernelTraits::BLOCK_SHAPE_Q, params.qk_head_dim),
                  make_stride(params.qk_head_dim, _1{})),  // layout q
      make_layout(make_shape(params.block_size, params.qk_head_dim, params.max_block_num),
                  make_stride(params.qk_head_dim, _1{}, params.block_size * params.qk_head_dim)),  // layout kv (paged blocks)
      make_layout(make_shape(params.chunk_num, params.bsz * params.max_draft_token_num * params.q_num_head),
                  make_stride(params.bsz * params.max_draft_token_num * params.q_num_head, _1{})),  // layout m/d
      params.Q,
      params.KV,
      params.m,
      params.d,
      params.block_tables,
      params.seq_lens_this_time,
      params.seq_lens_decoder,
      params.cumsum_q_seqlens,
      params.batch_ids,
      params.tile_ids_per_batch,
      params.num_blocks_x,
      params.chunk_size_device,
      params.sm_scale,
      params.bsz,
      params.max_block_num,
      params.max_block_num_per_seq,
      params.q_stride_bsz,
      params.q_stride_head_num,
      params.kv_stride_block_num,
      params.kv_stride_block_size,
      params.o_stride_bsz,
      params.o_stride_head_num,
      params.chunk_num,
      params.max_draft_token_num
  });
  typename CollectiveEpilogue::Params epilogue_params = CollectiveEpilogue::to_underlying_arguments_ntma({
      params.O,
      make_layout(make_shape(KernelTraits::BLOCK_SHAPE_Q, params.vo_head_dim),
                  make_stride(params.vo_head_dim, _1{})),  // layout O
      params.O_tmp,
      make_layout(make_shape(KernelTraits::BLOCK_SHAPE_Q, params.vo_head_dim),
                  make_stride(params.vo_head_dim, _1{}))  // layout O_tmp
  });

  // Get the pointer to the kernel function.
  auto kernel =
      MLAWithKVCacheKernel<CollectiveMainloop, CollectiveEpilogue, KernelTraits, CAUSAL, 132>;
  int smem_size = sizeof(typename KernelTraits::SharedStorage);
  cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
  int device;
  cudaGetDevice(&device);
  int multiprocessor_count;
  cudaDeviceGetAttribute(&multiprocessor_count, cudaDevAttrMultiProcessorCount, device);
  int act_blocks_per_sm;
  cudaOccupancyMaxActiveBlocksPerMultiprocessor(
      &act_blocks_per_sm, kernel, KernelTraits::NUM_WARPS * 32, smem_size);
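
  // act_blocks_per_sm is queried for reference only; the launch below pins the grid to one
  // block per SM (multiprocessor_count) rather than scaling it with occupancy.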

  // NOTE: (changwenbin) Here the grid size is fixed so that MLA can be captured
  // by the CUDA graph.
  dim3 grid_dims(multiprocessor_count, 1, 1);
  static constexpr int ctaSize = KernelTraits::NUM_WARPS * 32;
  dim3 block_dims(ctaSize, 1, 1);
  kernel<<<grid_dims, block_dims, smem_size, stream>>>(mainloop_params, epilogue_params);
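
  // When the KV cache was processed in more than one chunk, each chunk wrote a partial output
  // to O_tmp together with its softmax statistics (row max m and row sum d); the merge kernel
  // below combines those partials into O using the standard log-sum-exp weighting.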
  if (params.chunk_num > 1) {
    constexpr int vec_size = 16 / sizeof(DTypeO);
    constexpr int merge_block_size = 256;
    constexpr int blockx = KernelTraits::HEAD_DIM_VO / vec_size;
    constexpr int blocky = (merge_block_size + blockx - 1) / blockx;
    // Cap the grid x-dim at the SM count; one block per token (e.g. 128k tokens) would be too large.
    dim3 grids_merge(multiprocessor_count, params.q_num_head);
    dim3 blocks_merge(blockx, blocky);
    merge_multi_chunks_kernel<NV_TYPE, vec_size, blocky, KernelTraits::HEAD_DIM_VO>
        <<<grids_merge, blocks_merge, 0, stream>>>(
            reinterpret_cast<NV_TYPE *>(params.O_tmp),
            params.m,
            params.d,
            params.seq_lens_this_time,
            params.seq_lens_decoder,
            params.cumsum_q_seqlens,
            params.batch_id_per_token,
            params.chunk_size_device,
            reinterpret_cast<NV_TYPE *>(params.O),
            params.q_num_head,
            params.vo_head_dim,
            params.token_num,
            params.bsz,
            params.max_draft_token_num);
  }
  return cudaSuccess;
}

template <uint32_t HEAD_DIM_QK, uint32_t HEAD_DIM_VO, typename NV_TYPE, typename Params,
          bool USE_REG_EALLOC = false, bool USE_FIXED_BLOCK = true>
cudaError_t BatchMLAWithPagedKVCacheDispatched(Params& params, cudaStream_t stream) {
  constexpr bool CAUSAL = true;
  if constexpr (HEAD_DIM_QK == 576) {
    DISPATCH_GROUP_SIZE(params.q_num_head, GROUP_SIZE,
        BatchMLAWithPagedKVCacheKernelTraitsDispatched<
            AttentionKernelTraits</*USE_TMA_LOAD_KV=*/true,
                                  HEAD_DIM_QK,
                                  HEAD_DIM_VO,
                                  GROUP_SIZE,
                                  /*BLOCK_SHAPE_Q_=*/64,
                                  /*BLOCK_SHAPE_KV_=*/64,
                                  /*NUM_STAGES_=*/2,
                                  typename Params::DTypeQ,
                                  typename Params::DTypeKV,
                                  typename Params::DTypeO,
                                  typename Params::IdType,
                                  NV_TYPE>,
            CAUSAL,
            Params,
            USE_REG_EALLOC,
            USE_FIXED_BLOCK>(params, stream);)
  } else {
    return cudaErrorNotSupported;
  }
  return cudaSuccess;
}
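
// Host-side usage sketch (illustrative only; buffer names are placeholders, and the real setup
// lives in the .cu file that fills Params and calls the dispatcher. HEAD_DIM_QK = 576 is the
// only dimension dispatched above; 512 for VO and the bf16 types are just example choices):
//
//   using ParamsT = mla_attn::Params<cutlass::bfloat16_t, cutlass::bfloat16_t,
//                                    cutlass::bfloat16_t, int>;
//   ParamsT params{};
//   params.Q = q_ptr;          // [token_num, head_num, 576]
//   params.KV = kv_cache_ptr;  // paged latent KV cache
//   params.O = out_ptr;
//   ...                        // remaining pointers, strides and sizes
//   auto status = mla_attn::BatchMLAWithPagedKVCacheDispatched<576, 512, __nv_bfloat16>(
//       params, stream);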

}  // namespace mla_attn

#endif  // ATTENTION_HOPPER_PREFILL_SM90_CUH_