/******************************************************************************
 * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
 ******************************************************************************/

#pragma once

#include <cmath>

#include <cute/tensor.hpp>

#include <cutlass/numeric_types.h>

#include "utils.hpp"

using namespace cute;

// Butterfly allreduce over THREADS lanes of a warp using XOR shuffles.
template <int THREADS>
struct Allreduce {
    static_assert(THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4);
    template <typename T, typename Operator>
    static __device__ __forceinline__ T run(T x, Operator &op) {
        constexpr int OFFSET = THREADS / 2;
        x = op(x, __shfl_xor_sync(uint32_t(-1), x, OFFSET));
        return Allreduce<OFFSET>::run(x, op);
    }
};

template <>
struct Allreduce<2> {
    template <typename T, typename Operator>
    static __device__ __forceinline__ T run(T x, Operator &op) {
        x = op(x, __shfl_xor_sync(uint32_t(-1), x, 1));
        return x;
    }
};

// Per-thread reduction along the columns of a 2D tensor into a 1D summary.
template <bool zero_init = true, typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename Operator>
__device__ __forceinline__ void thread_reduce_(Tensor<Engine0, Layout0> const &tensor, Tensor<Engine1, Layout1> &summary, Operator &op) {
    static_assert(Layout0::rank == 2, "Only support 2D Tensor");
    static_assert(Layout1::rank == 1, "Only support 1D Tensor");
    CUTE_STATIC_ASSERT_V(size<0>(summary) == size<0>(tensor));
    #pragma unroll
    for (int mi = 0; mi < size<0>(tensor); mi++) {
        summary(mi) = zero_init ? tensor(mi, 0) : op(summary(mi), tensor(mi, 0));
        #pragma unroll
        for (int ni = 1; ni < size<1>(tensor); ni++) {
            summary(mi) = op(summary(mi), tensor(mi, ni));
        }
    }
}

// Reduce across the group of 4 threads that share each accumulator row.
template <typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename Operator>
__device__ __forceinline__ void quad_allreduce_(Tensor<Engine0, Layout0> &dst, Tensor<Engine1, Layout1> &src, Operator &op) {
    CUTE_STATIC_ASSERT_V(size(dst) == size(src));
    #pragma unroll
    for (int i = 0; i < size(dst); i++) {
        dst(i) = Allreduce<4>::run(src(i), op);
    }
}

template <bool zero_init = true, typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename Operator>
__device__ __forceinline__ void reduce_(Tensor<Engine0, Layout0> const &tensor, Tensor<Engine1, Layout1> &summary, Operator &op) {
    thread_reduce_<zero_init>(tensor, summary, op);
    quad_allreduce_(summary, summary, op);
}

template <bool zero_init = true, typename Engine0, typename Layout0, typename Engine1, typename Layout1>
__device__ __forceinline__ void reduce_max(Tensor<Engine0, Layout0> const &tensor, Tensor<Engine1, Layout1> &max) {
    MaxOp<float> max_op;
    reduce_<zero_init>(tensor, max, max_op);
}

template <bool zero_init = true, bool warp_reduce = true, typename Engine0, typename Layout0, typename Engine1, typename Layout1>
__device__ __forceinline__ void reduce_sum(Tensor<Engine0, Layout0> const &tensor, Tensor<Engine1, Layout1> &sum) {
    SumOp<float> sum_op;
    thread_reduce_<zero_init>(tensor, sum, sum_op);
    if constexpr (warp_reduce) { quad_allreduce_(sum, sum, sum_op); }
}

// exp2 on a packed pair of halves via the ex2.approx.f16x2 PTX instruction.
__forceinline__ __device__ __half2 half_exp(__half2 x) {
    uint32_t tmp_out, tmp_in;
    tmp_in = reinterpret_cast<uint32_t&>(x);
    asm ("ex2.approx.f16x2 %0, %1;\n" : "=r"(tmp_out) : "r"(tmp_in));
    __half2 out = reinterpret_cast<__half2&>(tmp_out);
    return out;
}

// Apply the exp to all the elements: compute the row max, exponentiate, and accumulate the row sum.
template <bool zero_init = true, typename Engine0, typename Layout0, typename Engine1, typename Layout1>
__forceinline__ __device__ void max_scale_exp2_sum(Tensor<Engine0, Layout0> &tensor, Tensor<Engine1, Layout1> &max, Tensor<Engine1, Layout1> &sum, const float scale) {
    static_assert(Layout0::rank == 2, "Only support 2D Tensor");
    static_assert(Layout1::rank == 1, "Only support 1D Tensor");
    CUTE_STATIC_ASSERT_V(size<0>(max) == size<0>(tensor));
    #pragma unroll
    for (int mi = 0; mi < size<0>(tensor); ++mi) {
        MaxOp<float> max_op;
        max(mi) = zero_init ? tensor(mi, 0) : max_op(max(mi), tensor(mi, 0));
        #pragma unroll
        for (int ni = 1; ni < size<1>(tensor); ni++) {
            max(mi) = max_op(max(mi), tensor(mi, ni));
        }
        max(mi) = Allreduce<4>::run(max(mi), max_op);
        // If the row max is -inf (e.g. the whole row is masked out), use 0 so we
        // don't compute -inf - (-inf) = NaN below.
        const float max_scaled = max(mi) == -INFINITY
            ? 0.f : max(mi) * scale;
        sum(mi) = 0;
        #pragma unroll
        for (int ni = 0; ni < size<1>(tensor); ++ni) {
            // Compute exp2(x * scale - max * scale) instead of exp(x - max):
            // the scale and subtraction fuse into a single FFMA.
            tensor(mi, ni) = exp2f(tensor(mi, ni) * scale - max_scaled);
            sum(mi) += tensor(mi, ni);
        }
    }
}

// Subtract the (scaled) row max and apply exp2 to all the elements.
template <typename Engine0, typename Layout0, typename Engine1, typename Layout1>
__forceinline__ __device__ void scale_apply_exp2(Tensor<Engine0, Layout0> &tensor, Tensor<Engine1, Layout1> const &max, const float scale) {
    static_assert(Layout0::rank == 2, "Only support 2D Tensor");
    static_assert(Layout1::rank == 1, "Only support 1D Tensor");
    CUTE_STATIC_ASSERT_V(size<0>(max) == size<0>(tensor));
    #pragma unroll
    for (int mi = 0; mi < size<0>(tensor); ++mi) {
        const float max_scaled = max(mi) * scale;
        #pragma unroll
        for (int ni = 0; ni < size<1>(tensor); ++ni) {
            tensor(mi, ni) = exp2f(tensor(mi, ni) * scale - max_scaled);
        }
    }
}

// Online softmax state for kNRows accumulator rows per thread.
template <int kNRows>
struct Softmax {

    using TensorT = decltype(make_tensor<float>(Shape<Int<kNRows>>{}));
    TensorT row_max, row_sum;

    CUTLASS_DEVICE Softmax() {};

    // Update the running row max with the current tile's scores and return the factor
    // by which the existing partial output (and row_sum) must be rescaled.
    template <bool Is_first, typename Tensor0>
    __forceinline__ __device__ TensorT max(Tensor0 &acc_s, float softmax_scale_log2) {
        // Reshape acc_s from (MMA, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N)).
        Tensor scores = make_tensor(acc_s.data(), convert_layout_acc_rowcol(acc_s.layout()));
        static_assert(decltype(size<0>(scores))::value == kNRows);
        TensorT scores_scale;
        if constexpr (Is_first) {
            reduce_max</*zero_init=*/true>(scores, row_max);
            cute::fill(scores_scale, 1.f);
        } else {
            Tensor scores_max_prev = make_fragment_like(row_max);
            cute::copy(row_max, scores_max_prev);
            reduce_max</*zero_init=*/false>(scores, row_max);
            #pragma unroll
            for (int mi = 0; mi < size(row_max); ++mi) {
                float scores_max_cur = row_max(mi);
                scores_scale(mi) = exp2f((scores_max_prev(mi) - scores_max_cur) * softmax_scale_log2);
                row_sum(mi) *= scores_scale(mi);
            }
        }
        return scores_scale;
    };

    // Exponentiate the current tile's scores and accumulate the per-thread row sums.
    template <bool Is_first, typename Tensor0>
    __forceinline__ __device__ TensorT online_softmax(Tensor0 &acc_s, float softmax_scale_log2) {
        // Reshape acc_s from (MMA, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N)).
        Tensor scores = make_tensor(acc_s.data(), convert_layout_acc_rowcol(acc_s.layout()));
        static_assert(decltype(size<0>(scores))::value == kNRows);
        TensorT scores_scale;
        if constexpr (Is_first) {
            reduce_max</*zero_init=*/true>(scores, row_max);
            scale_apply_exp2(scores, row_max, softmax_scale_log2);
            // Defer the cross-thread reduction of the sums to finalize().
            reduce_sum</*zero_init=*/true, /*warp_reduce=*/false>(scores, row_sum);
            cute::fill(scores_scale, 1.f);
        } else {
            scale_apply_exp2(scores, row_max, softmax_scale_log2);
            reduce_sum</*zero_init=*/false, /*warp_reduce=*/false>(scores, row_sum);
        }
        return scores_scale;
    };

    // Finish the cross-thread sum, store the log-sum-exp in row_sum, and return 1/sum
    // so the caller can normalize the output.
    __forceinline__ __device__ TensorT finalize(float softmax_scale_log2) {
        SumOp<float> sum_op;
        quad_allreduce_(row_sum, row_sum, sum_op);
        TensorT scores_scale;
        #pragma unroll
        for (int mi = 0; mi < size(row_max); ++mi) {
            float sum = row_sum(mi);
            float inv_sum = 1.0f / sum;
            // row_sum now holds the log-sum-exp (in natural log) for the backward pass.
            row_sum(mi) = row_max(mi) * (softmax_scale_log2 * float(M_LN2)) + __logf(sum);
            scores_scale(mi) = inv_sum;
        }
        return scores_scale;
    };

    // Multiply each row of the output accumulator by its scale factor.
    template <typename Tensor1>
    __forceinline__ __device__ void rescale_o(Tensor1 &acc_o, TensorT const &scores_scale) {
        // Reshape acc_o from (MMA, MMA_M, MMA_K) to (nrow=(2, MMA_M), ncol=(2, MMA_K)).
        Tensor acc_o_rowcol = make_tensor(acc_o.data(), convert_layout_acc_rowcol(acc_o.layout()));
        static_assert(decltype(size<0>(acc_o_rowcol))::value == kNRows);
        #pragma unroll
        for (int mi = 0; mi < size(row_max); ++mi) {
            #pragma unroll
            for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) {
                acc_o_rowcol(mi, ni) *= scores_scale(mi);
            }
        }
    };

};
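
// Illustrative usage sketch (not from the original kernels): how a caller might drive
// Softmax<kNRows> across the K/V tiles of one query block. Names such as `acc_s`,
// `acc_o`, `kMmaM`, `n_block`, and `softmax_scale_log2`, and the exact peeling of the
// first tile, are assumptions for illustration; the real mainloops differ in detail.
//
//   Softmax<2 * kMmaM> softmax;
//
//   // First K/V tile: acc_s = Q @ K^T, then initialize row_max/row_sum.
//   auto scores_scale = softmax.max</*Is_first=*/true>(acc_s, softmax_scale_log2);
//   softmax.online_softmax</*Is_first=*/true>(acc_s, softmax_scale_log2);
//   // acc_o = softmax(acc_s) @ V.
//
//   // Remaining tiles: refresh the running max, rescale the partial output,
//   // then exponentiate and accumulate the new tile.
//   for (; n_block >= 0; --n_block) {
//       // acc_s = Q @ K^T for this tile.
//       scores_scale = softmax.max</*Is_first=*/false>(acc_s, softmax_scale_log2);
//       softmax.rescale_o(acc_o, scores_scale);
//       softmax.online_softmax</*Is_first=*/false>(acc_s, softmax_scale_log2);
//       // acc_o += softmax(acc_s) @ V.
//   }
//
//   // Cross-thread reduction of the row sums, then divide the output by them.
//   scores_scale = softmax.finalize(softmax_scale_log2);
//   softmax.rescale_o(acc_o, scores_scale);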