// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

/*
 * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri
 * Dao. Licensed under the BSD 3-Clause.
 *
 * Modified by the FlashInfer team.
 */

#include <cutlass/cutlass.h>

#include <cute/tensor.hpp>

#include "utils.cuh"

namespace mla_attn {

using namespace cute;

template <typename T>
struct MaxOp {
  __device__ __forceinline__ T operator()(T const& x, T const& y) { return x > y ? x : y; }
};

template <>
struct MaxOp<float> {
  // This is slightly faster
  __device__ __forceinline__ float operator()(float const& x, float const& y) { return max(x, y); }
};

template <typename T>
struct SumOp {
  __device__ __forceinline__ T operator()(T const& x, T const& y) { return x + y; }
};

template <int THREADS>
struct Allreduce {
  static_assert(THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4);
  template <typename T, typename Operator>
  static __device__ __forceinline__ T run(T x, Operator& op) {
    constexpr int OFFSET = THREADS / 2;
    x = op(x, __shfl_xor_sync(uint32_t(-1), x, OFFSET));
    return Allreduce<OFFSET>::run(x, op);
  }
};

template <>
struct Allreduce<2> {
  template <typename T, typename Operator>
  static __device__ __forceinline__ T run(T x, Operator& op) {
    x = op(x, __shfl_xor_sync(uint32_t(-1), x, 1));
    return x;
  }
};

template <bool init, typename Engine0, typename Layout0, typename Engine1, typename Layout1,
          typename Operator>
__device__ __forceinline__ void thread_reduce_(Tensor<Engine0, Layout0> const& tensor,
                                               Tensor<Engine1, Layout1>& summary, Operator& op) {
  static_assert(Layout0::rank == 2, "Only support 2D Tensor");
  static_assert(Layout1::rank == 1, "Only support 1D Tensor");
  CUTE_STATIC_ASSERT_V(size<0>(summary) == size<0>(tensor));
#pragma unroll
  for (int mi = 0; mi < size<0>(tensor); mi++) {
    summary(mi) = init ? tensor(mi, 0) : op(summary(mi), tensor(mi, 0));
#pragma unroll
    for (int ni = 1; ni < size<1>(tensor); ni++) {
      summary(mi) = op(summary(mi), tensor(mi, ni));
    }
  }
}
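// Illustrative sketch (not part of the original file): how the Allreduce<4>
// quad reduction below is typically driven. The four threads that co-own one
// row of an MMA accumulator exchange values via __shfl_xor_sync, so after the
// call every lane of the quad holds the maximum of the four thread-local inputs.
__device__ __forceinline__ float example_quad_max(float thread_local_max) {
  MaxOp<float> max_op;
  return Allreduce<4>::run(thread_local_max, max_op);
}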
template <typename Engine0, typename Layout0, typename Engine1, typename Layout1,
          typename Operator>
__device__ __forceinline__ void quad_allreduce_(Tensor<Engine0, Layout0>& dst,
                                                Tensor<Engine1, Layout1>& src, Operator& op) {
  CUTE_STATIC_ASSERT_V(size(dst) == size(src));
#pragma unroll
  for (int i = 0; i < size(dst); i++) {
    dst(i) = Allreduce<4>::run(src(i), op);
  }
}

template <bool init, typename Engine0, typename Layout0, typename Engine1, typename Layout1,
          typename Operator>
__device__ __forceinline__ void reduce_(Tensor<Engine0, Layout0> const& tensor,
                                        Tensor<Engine1, Layout1>& summary, Operator& op) {
  thread_reduce_<init>(tensor, summary, op);
  quad_allreduce_(summary, summary, op);
}

template <bool init, typename Engine0, typename Layout0, typename Engine1, typename Layout1>
__device__ __forceinline__ void reduce_max(Tensor<Engine0, Layout0> const& tensor,
                                           Tensor<Engine1, Layout1>& max) {
  MaxOp<float> max_op;
  reduce_<init>(tensor, max, max_op);
}

template <bool init, bool warp_reduce, typename Engine0, typename Layout0, typename Engine1,
          typename Layout1>
__device__ __forceinline__ void reduce_sum(Tensor<Engine0, Layout0> const& tensor,
                                           Tensor<Engine1, Layout1>& sum) {
  SumOp<float> sum_op;
  thread_reduce_<init>(tensor, sum, sum_op);
  if constexpr (warp_reduce) {
    quad_allreduce_(sum, sum, sum_op);
  }
}

template <typename Engine0, typename Layout0, typename Engine1, typename Layout1>
__forceinline__ __device__ void apply_exp2(Tensor<Engine0, Layout0>& tensor,
                                           Tensor<Engine1, Layout1> const& max) {
  static_assert(Layout0::rank == 2, "Only support 2D Tensor");
  static_assert(Layout1::rank == 1, "Only support 1D Tensor");
  CUTE_STATIC_ASSERT_V(size<0>(max) == size<0>(tensor));
#pragma unroll
  for (int mi = 0; mi < size<0>(tensor); ++mi) {
    auto row_max = max(mi);
#pragma unroll
    for (int ni = 0; ni < size<1>(tensor); ++ni) {
      tensor(mi, ni) = __expf(tensor(mi, ni) - row_max);
    }
  }
}

template <typename Engine0, typename Layout0, typename Engine1, typename Layout1>
__forceinline__ __device__ void scale_apply_exp2(Tensor<Engine0, Layout0>& tensor,
                                                 Tensor<Engine1, Layout1> const& max,
                                                 const float scale) {
  static_assert(Layout0::rank == 2, "Only support 2D Tensor");
  static_assert(Layout1::rank == 1, "Only support 1D Tensor");
  CUTE_STATIC_ASSERT_V(size<0>(max) == size<0>(tensor));
#pragma unroll
  for (int mi = 0; mi < size<0>(tensor); ++mi) {
    auto row_max = max(mi);
#pragma unroll
    for (int ni = 0; ni < size<1>(tensor); ++ni) {
      // row_max * scale is a constant for each row, so we can use fma here
      tensor(mi, ni) = __expf(tensor(mi, ni) * scale - row_max * scale);
    }
  }
}
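// Illustrative scalar sketch (not part of the original file): the per-row
// recurrence that OnlineSoftmax below maintains. For each new chunk of scores
// s[0..n), with the running max m initialized to a very negative value and the
// running sum d initialized to 0, one step computes
//   m_new = max(m_old, max_j s[j])
//   c     = exp((m_old - m_new) * scale)
//   d_new = d_old * c + sum_j exp((s[j] - m_new) * scale)
// and the caller multiplies its partial output by c (cf. rescale_o below).
__device__ __forceinline__ void example_online_softmax_step(const float* s, int n, float scale,
                                                            float& m, float& d, float& c) {
  float m_old = m;
  for (int j = 0; j < n; ++j) m = fmaxf(m, s[j]);
  c = __expf((m_old - m) * scale);
  d *= c;
  for (int j = 0; j < n; ++j) d += __expf((s[j] - m) * scale);
}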
template <int NUM_ROWS_PER_THREAD, bool WITH_SCALE>
struct OnlineSoftmax {
  constexpr static float fill_value = -5e4;
  using TensorT = decltype(make_tensor<float>(Shape<Int<NUM_ROWS_PER_THREAD>>{}));
  TensorT row_max, row_sum, scores_scale;
  float sm_scale_log2;

  CUTLASS_DEVICE OnlineSoftmax(float sm_scale_log2) : sm_scale_log2(sm_scale_log2) {
    clear(scores_scale);
  };

  __forceinline__ __device__ TensorT get_lse() const { return row_sum; }

  template <bool init, typename Tensor0>
  __forceinline__ __device__ TensorT update(Tensor0& acc_s) {
    // Reshape acc_s from ((2, 2, V), MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, V, MMA_N))
    Tensor scores = make_tensor(acc_s.data(), convert_layout_acc_rowcol(acc_s.layout()));
    static_assert(decltype(size<0>(scores))::value == NUM_ROWS_PER_THREAD);
    if constexpr (init) {
      reduce_max</*init=*/true>(scores, row_max);
      if constexpr (WITH_SCALE) {
        scale_apply_exp2(scores, row_max, sm_scale_log2);
      } else {
        apply_exp2(scores, row_max);
      }
      reduce_sum</*init=*/true, /*warp_reduce=*/false>(scores, row_sum);
    } else {
      // update row_max
      Tensor scores_max_prev = make_fragment_like(row_max);
      cute::copy(row_max, scores_max_prev);
      reduce_max</*init=*/false>(scores, row_max);
      // update scores_scale and scale row_sum
#pragma unroll
      for (int mi = 0; mi < size(row_max); ++mi) {
        float scores_max_cur = row_max(mi);
        if constexpr (WITH_SCALE) {
          scores_scale(mi) = __expf((scores_max_prev(mi) - scores_max_cur) * sm_scale_log2);
        } else {
          scores_scale(mi) = __expf(scores_max_prev(mi) - scores_max_cur);
        }
        row_sum(mi) *= scores_scale(mi);
      }
      // perform exp2 on scores
      if constexpr (WITH_SCALE) {
        scale_apply_exp2(scores, row_max, sm_scale_log2);
      } else {
        apply_exp2(scores, row_max);
      }
      // update row_sum
      reduce_sum</*init=*/false, /*warp_reduce=*/false>(scores, row_sum);
    }
    // Return scores_scale from both branches so the caller can rescale acc_o.
    return scores_scale;
  };

  template <typename Tensor0>
  __forceinline__ __device__ TensorT finalize(Tensor0& acc_s) {
    // Reshape acc_s from ((2, 2, V), MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, V, MMA_N))
    Tensor scores = make_tensor(acc_s.data(), convert_layout_acc_rowcol(acc_s.layout()));
    static_assert(decltype(size<0>(scores))::value == NUM_ROWS_PER_THREAD);
    SumOp<float> sum_op;
    quad_allreduce_(row_sum, row_sum, sum_op);
#pragma unroll
    for (int mi = 0; mi < size(row_max); ++mi) {
      float sum = row_sum(mi);
      float inv_sum = 1.f / sum;
      scores_scale(mi) = inv_sum;
      row_max(mi) *= sm_scale_log2;
    }
    return scores_scale;
  };

  template <typename Tensor1>
  __forceinline__ __device__ void rescale_o(Tensor1& acc_o) {
    // Reshape acc_o from (MMA=4, MMA_M, MMA_K) to (nrow=(2, MMA_M), ncol=(2, MMA_K))
    Tensor acc_o_rowcol = make_tensor(acc_o.data(), convert_layout_acc_rowcol(acc_o.layout()));
    static_assert(decltype(size<0>(acc_o_rowcol))::value == NUM_ROWS_PER_THREAD);
#pragma unroll
    for (int mi = 0; mi < size(row_max); ++mi) {
#pragma unroll
      for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) {
        acc_o_rowcol(mi, ni) *= scores_scale(mi);
      }
    }
  };

  template <typename Tensor1, typename Tensor2>
  __forceinline__ __device__ void rescale_o(Tensor1& acc_o, Tensor2& scores_scale_input) {
    // Reshape acc_o from (MMA=4, MMA_M, MMA_K) to (nrow=(2, MMA_M), ncol=(2, MMA_K))
    Tensor acc_o_rowcol = make_tensor(acc_o.data(), convert_layout_acc_rowcol(acc_o.layout()));
    static_assert(decltype(size<0>(acc_o_rowcol))::value == NUM_ROWS_PER_THREAD);
#pragma unroll
    for (int mi = 0; mi < size(row_max); ++mi) {
#pragma unroll
      for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) {
        acc_o_rowcol(mi, ni) *= scores_scale_input(mi);
      }
    }
  };
};

}  // namespace mla_attn
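// Illustrative sketch (not part of the original file): the call sequence a
// hypothetical attention mainloop would follow with OnlineSoftmax. AccS / AccO
// stand for MMA accumulator fragments whose layouts convert_layout_acc_rowcol
// (from utils.cuh) understands; example_softmax_loop itself is not an API of
// this header.
namespace mla_attn {

template <int NUM_ROWS_PER_THREAD, bool WITH_SCALE, typename AccS, typename AccO>
__device__ __forceinline__ void example_softmax_loop(
    OnlineSoftmax<NUM_ROWS_PER_THREAD, WITH_SCALE>& softmax, AccS& acc_s, AccO& acc_o,
    int num_tiles) {
  // First tile: initialize row_max / row_sum from the freshly computed scores.
  softmax.template update</*init=*/true>(acc_s);
  for (int tile = 1; tile < num_tiles; ++tile) {
    // ... compute the next acc_s tile here ...
    // Subsequent tiles: fold the new scores in and rescale the partial output
    // by the factor update() returns.
    auto scores_scale = softmax.template update</*init=*/false>(acc_s);
    softmax.rescale_o(acc_o, scores_scale);
  }
  // Finish: reduce row_sum across the quad, then scale acc_o by 1 / row_sum.
  softmax.finalize(acc_s);
  softmax.rescale_o(acc_o);
}

}  // namespace mla_attn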