// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef ATTENTION_HOPPER_UTILS_CUH_
#define ATTENTION_HOPPER_UTILS_CUH_

#include "cute/tensor.hpp"
#include "cutlass/cutlass.h"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/epilogue/collective/default_epilogue.hpp"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/gemm/device/gemm_grouped.h"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/gemm/group_array_problem_shape.hpp"
#include "cutlass/gemm/kernel/default_gemm_grouped.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/layout/matrix.h"
#include "cutlass/numeric_types.h"
#include "cutlass/tensor_ref.h"
#include "cutlass/util/command_line.h"
#include "cutlass/util/distribution.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/packed_stride.hpp"
#include "cutlass/util/reference/device/gemm.h"
#include "cutlass/util/reference/device/tensor_compare.h"
#include "cutlass/util/reference/device/tensor_fill.h"
#include "cutlass/util/tensor_view_io.h"

#include <cuda.h>
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>

#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
#include <cuda_fp8.h>  // assumed: FP8 intrinsics for SM80+
#endif

#include <cstdint>
#include <type_traits>

#include "cutlass/fast_math.h"

namespace mla_attn {

using namespace cute;

template <typename TensorT>
CUTLASS_HOST_DEVICE auto flatten_1(TensorT tensor) {
  Tensor tensor_flatten = cute::flatten(tensor);
  return cute::group_modes<1, rank(tensor_flatten)>(tensor_flatten);
}

// View of a [nnz, num_heads, head_dim] buffer as a (nnz, head_dim, num_heads)
// CuTe layout; head_dim is the contiguous mode.
CUTLASS_HOST_DEVICE auto get_gmem_layout(int nnz, int num_heads, int head_dim, int64_t n_stride,
                                         int64_t h_stride) {
  return make_layout(make_shape(nnz, head_dim, num_heads),
                     make_stride(n_stride, cute::_1{}, h_stride));
}

// (num_heads, nnz) view of an LSE buffer laid out as [nnz, num_heads].
CUTLASS_HOST_DEVICE auto get_lse_gmem_layout(int nnz, int num_heads) {
  return make_layout(make_shape(num_heads, nnz), make_stride(cute::_1{}, int64_t(num_heads)));
}

template <typename MTensor, typename Shape>
CUTLASS_DEVICE auto get_local_tile_tensor(const MTensor& m_tensor, const Shape& tile_shape,
                                          int head_idx, int offset, int seq_len) {
  auto g_offset = local_tile(m_tensor(_, _, head_idx), cute::make_shape(1, get<1>(tile_shape)),
                             make_coord(offset, _0{}));
  auto g_sequence =
      make_tensor(g_offset.data(),
                  make_layout(cute::make_shape(seq_len, get<1>(tile_shape)), g_offset.stride()));
  auto g_tensor = local_tile(g_sequence, tile_shape, make_coord(_, _0{}));
  return g_tensor;
}

template <typename MTensor, typename Shape>
CUTLASS_DEVICE auto get_lse_local_tile_tensor(const MTensor& m_tensor, const Shape& tile_shape,
                                              int head_idx, int offset, int seq_len) {
  auto g_offset = local_tile(m_tensor(head_idx, _), cute::make_shape(_1{}), make_coord(offset));
  auto g_sequence =
      make_tensor(g_offset.data(),
                  make_layout(cute::make_shape(seq_len), cute::make_shape(shape<0>(m_tensor))));
  auto g_tensor = local_tile(g_sequence, tile_shape, make_coord(_));
  return g_tensor;
}
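
// Usage sketch for the layout/tile helpers above (illustrative only; `q_ptr`,
// `nnz`, `q_start`, `TileShapeQD` and the packed strides are hypothetical
// names from a typical caller, not part of this header):
//
//   Tensor mQ = make_tensor(make_gmem_ptr(q_ptr),
//                           get_gmem_layout(nnz, num_heads, head_dim,
//                                           /*n_stride=*/num_heads * head_dim,
//                                           /*h_stride=*/head_dim));
//   // (BLOCK_Q, head_dim, num_q_tiles) view of one head's rows, starting at
//   // token `q_start` of the packed [nnz, ...] buffer.
//   Tensor gQ = get_local_tile_tensor(mQ, TileShapeQD{}, head_idx, q_start, seq_len_q);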

// For SM90, convert acc_layout from ((2, 2, V), MMA_M, MMA_N) to
// (nrow=(2, MMA_M), ncol=(2, V, MMA_N))
template <typename Layout>
__forceinline__ __device__ auto convert_layout_acc_rowcol(Layout acc_layout) {
  static_assert(decltype(size<0, 0>(acc_layout))::value == 2);
  static_assert(decltype(size<0, 1>(acc_layout))::value == 2);
  static_assert(decltype(rank(acc_layout))::value == 3);
  auto l = acc_layout;
  return make_layout(make_layout(get<0, 1>(l), get<1>(l)),
                     make_layout(get<0, 0>(l), get<0, 2>(l), get<2>(l)));
};

// For SM90, convert acc_layout from ((2, 2, N / 8), MMA_M, MMA_N) to
// ((2, 2, 2), MMA_M, (N / 16, MMA_N))
template <typename MMA_traits, typename Layout>
__forceinline__ __device__ auto convert_layout_acc_Aregs(Layout acc_layout) {
  using X = Underscore;
  static_assert(decltype(size<0, 0>(acc_layout))::value == 2);
  static_assert(decltype(size<0, 1>(acc_layout))::value == 2);
  static_assert(decltype(rank(acc_layout))::value == 3);
  static_assert(decltype(rank(get<0>(acc_layout)))::value == 3);
  auto l = logical_divide(get<0>(acc_layout), Shape<X, X, _2>{});  // (2, 2, (2, N / 16))
  return make_layout(make_layout(get<0>(l), get<1>(l), get<2, 0>(l)), get<1>(acc_layout),
                     make_layout(get<2, 1>(l), get<2>(acc_layout)));
};

template <typename To_type, typename Engine, typename Layout>
__forceinline__ __device__ auto convert_type(Tensor<Engine, Layout> const& tensor) {
  using From_type = typename Engine::value_type;
  constexpr int numel = decltype(size(tensor))::value;
  cutlass::NumericArrayConverter<To_type, From_type, numel> convert_op;
  // Reinterpret the contiguous register fragment as a cutlass::Array and
  // convert it element-wise.
  auto frag =
      convert_op(*reinterpret_cast<const cutlass::Array<From_type, numel>*>(tensor.data()));
  return make_tensor(make_rmem_ptr<To_type>(&frag), tensor.layout());
}

template <bool init = false, int wg_wait = 0, typename TensorA, typename TensorB, typename TensorC,
          typename TiledMma>
__forceinline__ __device__ void gemm(TiledMma& tiled_mma, TensorA const& tCrA, TensorB const& tCrB,
                                     TensorC& tCrC) {
  constexpr bool Is_RS =
      !cute::is_base_of<cute::GMMA::DescriptorIterator, typename TiledMma::FrgTypeA>::value;
  // Need to cast away const on tCrA since warpgroup_fence_operand doesn't take const
  if constexpr (Is_RS) {
    warpgroup_fence_operand(const_cast<TensorA&>(tCrA));
  }
  warpgroup_fence_operand(tCrC);
  warpgroup_arrive();
  if constexpr (init) {
    tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;
    // Unroll the K mode manually to set scale D to 1
    CUTLASS_PRAGMA_UNROLL
    for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {
      cute::gemm(tiled_mma, tCrA(_, _, k_block), tCrB(_, _, k_block), tCrC);
      tiled_mma.accumulate_ = GMMA::ScaleOut::One;
    }
  } else {
    // cute::gemm(tiled_mma, tCrA, tCrB, tCrC);
    // Unroll the K mode manually to set scale D to 1
    CUTLASS_PRAGMA_UNROLL
    for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {
      cute::gemm(tiled_mma, tCrA(_, _, k_block), tCrB(_, _, k_block), tCrC);
      tiled_mma.accumulate_ = GMMA::ScaleOut::One;
    }
  }
  warpgroup_commit_batch();
  if constexpr (wg_wait >= 0) {
    warpgroup_wait<wg_wait>();
  }
  warpgroup_fence_operand(tCrC);
  if constexpr (Is_RS) {
    warpgroup_fence_operand(const_cast<TensorA&>(tCrA));
  }
}
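
// Usage sketch for the gemm<> wrapper (illustrative; `tiled_mma`, `tSrQ`,
// `tSrK`, `tSrS` are hypothetical register/descriptor fragments from a
// surrounding mainloop, not defined in this header):
//
//   // First GEMM of a tile: overwrite the accumulator (init = true) and skip
//   // the wait (wg_wait = -1) so subsequent work can overlap with the MMA.
//   gemm</*init=*/true, /*wg_wait=*/-1>(tiled_mma, tSrQ, tSrK, tSrS);
//   ...
//   // Accumulate on top of tSrS (init = false) and wait for the whole batch
//   // to drain (wg_wait = 0) before reading the scores.
//   gemm</*init=*/false, /*wg_wait=*/0>(tiled_mma, tSrQ, tSrK, tSrS);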

#define HOSTDEVICE __host__ __device__

template <typename T, int Size>
struct alignas(sizeof(T) * Size) AlignedVector {
  T val[Size];

  HOSTDEVICE inline const T& operator[](int i) const { return val[i]; }
  HOSTDEVICE inline T& operator[](int i) { return val[i]; }
};

template <typename T, int Size>
HOSTDEVICE inline void Load(const T* addr, AlignedVector<T, Size>* vec) {
  const AlignedVector<T, Size>* addr_vec = reinterpret_cast<const AlignedVector<T, Size>*>(addr);
  *vec = *addr_vec;
}

template <typename T, int Size>
HOSTDEVICE inline void Store(const AlignedVector<T, Size>& vec, T* addr) {
  AlignedVector<T, Size>* addr_vec = reinterpret_cast<AlignedVector<T, Size>*>(addr);
  *addr_vec = vec;
}

template <uint32_t vec_size, typename T>
struct prefill_softmax_state_t {
  AlignedVector<T, vec_size> o;
  float m;
  float d;

  __device__ __forceinline__ void init() {
    if constexpr (std::is_same<T, half>::value) {
#pragma unroll
      for (int i = 0; i < vec_size / 2; ++i) {
        *((half2*)(&o) + i) = make_half2(0, 0);
      }
    } else if constexpr (std::is_same<T, __nv_bfloat16>::value) {
#pragma unroll
      for (int i = 0; i < vec_size / 2; ++i) {
        *((nv_bfloat162*)(&o) + i) = make_bfloat162(0, 0);
      }
    }
    d = 1.f;
    if constexpr (std::is_same<T, half>::value) {
      m = -5e4f;
    } else if constexpr (std::is_same<T, __nv_bfloat16>::value) {
      m = -3.38953e38f;
    }
  }

  __device__ __forceinline__ void merge(const AlignedVector<T, vec_size>& other_o,
                                        const float other_m, const float other_d) {
    float m_prev = m, d_prev = d;
    m = max(m_prev, other_m);
    const float scale1 = __expf(m_prev - m), scale2 = __expf(other_m - m);
    const T scale1_T = static_cast<T>(scale1), scale2_T = static_cast<T>(scale2);
    d = d_prev * scale1 + other_d * scale2;
#pragma unroll
    for (size_t i = 0; i < vec_size; ++i) {
      o[i] = o[i] * scale1_T + other_o[i] * scale2_T;
    }
  }

  __device__ __forceinline__ void normalize() {
    const T d_t = static_cast<T>(d);
#pragma unroll
    for (size_t i = 0; i < vec_size; ++i) {
      o[i] /= d_t;
    }
  }
};
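
// prefill_softmax_state_t implements the usual split-softmax combination rule:
// given two partial results (o1, m1, d1) and (o2, m2, d2) computed over
// disjoint key chunks,
//   m = max(m1, m2)
//   d = d1 * exp(m1 - m) + d2 * exp(m2 - m)
//   o = o1 * exp(m1 - m) + o2 * exp(m2 - m)
// and normalize() divides o by d at the end. merge_multi_chunks_kernel below
// uses this state to fold the per-chunk partial outputs (multi_out, multi_m,
// multi_d) into one output row per (token, head).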

template <typename T, int vec_size, uint32_t bdy, uint32_t HEAD_DIM>
__global__ void merge_multi_chunks_kernel(
    const T* __restrict__ multi_out,    // [max_num_chunks, bsz, max_draft_token, num_heads, head_dim]
    const float* __restrict__ multi_m,  // [max_num_chunks, bsz, max_draft_token, num_heads]
    const float* __restrict__ multi_d,  // [max_num_chunks, bsz, max_draft_token, num_heads]
    const int* __restrict__ seq_lens_this_time,
    const int* __restrict__ seq_lens_decoder,
    const int* __restrict__ cu_seqlens_q,
    const int* __restrict__ batch_id_per_token,
    const int* __restrict__ chunk_size_device,
    T* __restrict__ out,  // [token_num, num_heads, head_dim]
    const int num_heads,
    const int head_dim,
    const int token_num,
    const int bsz,
    const int max_draft_token_num) {
  const int vid = threadIdx.x, ty = threadIdx.y;
  const int hid = blockIdx.y;
  __shared__ T smem[bdy * HEAD_DIM];
  __shared__ float md_smem[bdy * 2];
  for (int qid = blockIdx.x; qid < token_num; qid += gridDim.x) {
    const uint32_t bid = batch_id_per_token[qid];
    // NOTE(changwenbin): batch_id_per_token is initialized to -1 to mark
    // meaningless (padding) batch ids; skip those tokens.
    if (bid == -1) continue;
    const int seq_len_q = seq_lens_this_time[bid];
    if (seq_len_q == 0) continue;
    const uint32_t local_seq_id = qid - cu_seqlens_q[bid];
    int seq_len_kv = seq_lens_decoder[bid];
    if (seq_len_kv == 0) continue;
    seq_len_kv += seq_len_q;
    const int num_chunks_this_seq = cute::ceil_div(seq_len_kv, chunk_size_device[0]);
    if (num_chunks_this_seq <= 1) {
      // no merge needed
      continue;
    }

    using LoadT = AlignedVector<T, vec_size>;
    LoadT load_vec;
    LoadT res_vec;
    if constexpr (std::is_same<T, half>::value) {
#pragma unroll
      for (int i = 0; i < vec_size / 2; ++i) {
        *((half2*)(&res_vec) + i) = make_half2(0, 0);
      }
    } else {
#pragma unroll
      for (int i = 0; i < vec_size / 2; ++i) {
        *((nv_bfloat162*)(&res_vec) + i) = make_bfloat162(0, 0);
      }
    }
    float m;
    float d = 1.f;
    if constexpr (std::is_same<T, half>::value) {
      m = -5e4f;
    } else if constexpr (std::is_same<T, __nv_bfloat16>::value) {
      m = -3.0e+30f;
    }
    // Each ty lane accumulates a strided subset of the chunks.
    for (int i = ty; i < num_chunks_this_seq; i += bdy) {
      uint32_t offset;
      offset = ((i * bsz + bid) * max_draft_token_num + local_seq_id) * num_heads + hid;
      float m_prev = m;
      float d_prev = d;
      const float m_now = multi_m[offset];
      const float d_now = multi_d[offset];
      m = max(m_prev, m_now);
      offset = (((i * bsz + bid) * max_draft_token_num + local_seq_id) * num_heads + hid) *
                   head_dim +
               vid * vec_size;
      Load<T, vec_size>(&multi_out[offset], &load_vec);
      const float scale1 = __expf(m_prev - m), scale2 = __expf(m_now - m);
      const T scale1_T = static_cast<T>(scale1), scale2_T = static_cast<T>(scale2);
      d = d * scale1 + d_now * scale2;
#pragma unroll
      for (int j = 0; j < vec_size; j++) {
        res_vec[j] = res_vec[j] * scale1_T + load_vec[j] * scale2_T;
      }
    }
    // store this ty lane's partial result
    Store<T, vec_size>(res_vec, &smem[ty * head_dim + vid * vec_size]);
    md_smem[2 * ty] = m;
    md_smem[2 * ty + 1] = d;
    __syncthreads();
    if (ty == 0) {
      // merge the bdy partial results
      prefill_softmax_state_t<vec_size, T> st;
      st.init();
#pragma unroll
      for (int i = 0; i < bdy; i++) {
        Load<T, vec_size>(&smem[i * head_dim + vid * vec_size], &load_vec);
        const float m_tmp = md_smem[2 * i], d_tmp = md_smem[2 * i + 1];
        st.merge(load_vec, m_tmp, d_tmp);
      }
      st.normalize();
      Store<T, vec_size>(st.o, &out[(qid * num_heads + hid) * head_dim + vid * vec_size]);
    }
    __syncthreads();
  }
}

}  // namespace mla_attn

#endif  // ATTENTION_HOPPER_UTILS_CUH_