mirror of
				https://github.com/PaddlePaddle/FastDeploy.git
				synced 2025-10-27 02:20:31 +08:00 
			
		
		
		
	 05c670e593
			
		
	
	05c670e593
	
	
	
		
			
			* [Sync] Update to latest code * Add new code files * Add new code files * update code * Try to fix build.sh * Try to fix build.sh * Update code * Update requirements.txt * Update code --------- Co-authored-by: Jiang-Jia-Jun <jiangjiajun@baidu.com>
		
			
				
	
	
		
			176 lines
		
	
	
		
			7.6 KiB
		
	
	
	
		
			Plaintext
		
	
	
	
	
	
			
		
		
	
	
			176 lines
		
	
	
		
			7.6 KiB
		
	
	
	
		
			Plaintext
		
	
	
	
	
	
| // Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
 | |
| //
 | |
| // Licensed under the Apache License, Version 2.0 (the "License");
 | |
| // you may not use this file except in compliance with the License.
 | |
| // You may obtain a copy of the License at
 | |
| //
 | |
| //     http://www.apache.org/licenses/LICENSE-2.0
 | |
| //
 | |
| // Unless required by applicable law or agreed to in writing, software
 | |
| // distributed under the License is distributed on an "AS IS" BASIS,
 | |
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | |
| // See the License for the specific language governing permissions and
 | |
| // limitations under the License.
 | |
| 
 | |
| /*
 | |
|  * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri
 | |
|  * Dao. Licensed under the BSD 3-Clause.
 | |
|  *
 | |
|  * Modified by the FlashInfer team.
 | |
|  */
 | |
| 
 | |
| 
 | |
| #ifndef ATTENTION_HOPPER_EPILOGUE_CUH_
 | |
| #define ATTENTION_HOPPER_EPILOGUE_CUH_
 | |
| 
 | |
| #include <cutlass/cutlass.h>
 | |
| 
 | |
| #include "cute/tensor.hpp"
 | |
| #include "cutlass/gemm/collective/collective_builder.hpp"
 | |
| #include "named_barrier.cuh"
 | |
| #include "utils.cuh"
 | |
| 
 | |
| #ifdef DEBUG_MLA
 | |
| #undef DEBUG_MLA
 | |
| #endif
 | |
| // #define DEBUG_MLA
 | |
| 
 | |
| namespace mla_attn {
 | |
| 
 | |
| using namespace cute;
 | |
| 
 | |
| template <typename Ktraits>
 | |
| struct CollectiveEpilogue {
 | |
|   using DTypeO = typename Ktraits::DTypeO;
 | |
|   static constexpr int BLOCK_SHAPE_Q = Ktraits::BLOCK_SHAPE_Q;
 | |
|   static constexpr int BLOCK_SHAPE_KV = Ktraits::BLOCK_SHAPE_KV;
 | |
|   static constexpr int HEAD_DIM_VO = Ktraits::HEAD_DIM_VO;
 | |
|   using TileShape_PDV = Shape<Int<BLOCK_SHAPE_Q>, Int<HEAD_DIM_VO>, Int<BLOCK_SHAPE_KV>>;
 | |
| 
 | |
|   static constexpr int NUM_WARPS = Ktraits::NUM_WARPS;
 | |
|   static constexpr int NUM_THREADS = NUM_WARPS * cutlass::NumThreadsPerWarp;
 | |
| 
 | |
|   static constexpr int NUM_COPY_THREADS = cutlass::NumThreadsPerWarpGroup;
 | |
|   static constexpr int NUM_MMA_THREADS = Ktraits::NUM_MMA_THREADS;
 | |
| 
 | |
|   using SmemLayoutAtomO = decltype(cutlass::gemm::collective::detail::ss_smem_selector<
 | |
|                                    GMMA::Major::K, DTypeO, decltype(cute::get<0>(TileShape_PDV{})),
 | |
|                                    decltype(cute::get<1>(TileShape_PDV{}))>());
 | |
|   using SmemLayoutO = decltype(tile_to_shape(SmemLayoutAtomO{}, select<0, 1>(TileShape_PDV{})));
 | |
| 
 | |
|   using SmemCopyAtomO = Copy_Atom<cute::SM90_U32x4_STSM_N, DTypeO>;
 | |
|   using SharedStorage = cute::array_aligned<DTypeO, cute::cosize_v<SmemLayoutO>>;
 | |
| 
 | |
|   using ShapeT = cute::Shape<int32_t, int32_t, int32_t>;
 | |
|   using StrideT = cute::Shape<int32_t, _1, int32_t>;
 | |
|   using LayoutT = cute::Layout<ShapeT, StrideT>;
 | |
| 
 | |
|   using ShapeTmpT = cute::Shape<int32_t, int32_t, int32_t, int32_t>;
 | |
|   using StrideTmpT = cute::Shape<int32_t, _1, int32_t, int32_t>;
 | |
|   using LayoutTmpT = cute::Layout<ShapeTmpT, StrideTmpT>;
 | |
| 
 | |
|   using ShapeNTMAT = cute::Shape<int32_t, int32_t>;
 | |
|   using StrideNTMAT = cute::Shape<int32_t, _1>;
 | |
|   using LayoutNTMAT = cute::Layout<ShapeNTMAT, StrideNTMAT>;
 | |
| 
 | |
|   using GmemTiledCopyOTMA = cute::SM90_TMA_STORE;
 | |
|   using TMA_O = decltype(make_tma_copy(
 | |
|       GmemTiledCopyOTMA{},
 | |
|       make_tensor(make_gmem_ptr(static_cast<DTypeO*>(nullptr)), ShapeT{}, StrideT{}), SmemLayoutO{},
 | |
|       select<0, 1>(TileShape_PDV{}), _1{}));  // no mcast for O
 | |
| 
 | |
|   static constexpr int VEC_SIZE = cute::ceil_div(128, sizeof_bits_v<DTypeO>); // 8
 | |
|   static_assert(HEAD_DIM_VO % VEC_SIZE == 0);
 | |
|   static constexpr int NUM_THREADS_PER_ROW = HEAD_DIM_VO / VEC_SIZE; // 64
 | |
|   static_assert(NUM_MMA_THREADS % NUM_THREADS_PER_ROW == 0);
 | |
|   static constexpr int NUM_ROWS = NUM_MMA_THREADS / NUM_THREADS_PER_ROW;
 | |
|   using TiledCopyOAtom = cute::Copy_Atom<cute::UniversalCopy<cutlass::uint128_t>, DTypeO>;
 | |
|   using TiledCopyOThrLayout = decltype(cute::make_layout(
 | |
|       cute::make_shape(Int<NUM_ROWS>{}, Int<NUM_THREADS_PER_ROW>{}), LayoutRight{}));
 | |
|   using TiledCopyOValLayout =
 | |
|       decltype(cute::make_layout(cute::make_shape(_1{}, Int<VEC_SIZE>{}), LayoutRight{}));
 | |
|   using TiledCopyO =
 | |
|       decltype(make_tiled_copy(TiledCopyOAtom{}, TiledCopyOThrLayout{},  // Thr layout
 | |
|                                TiledCopyOValLayout{}                     // Val layout
 | |
|                                ));
 | |
|   struct Arguments {
 | |
|     DTypeO* O_ptr;
 | |
|     LayoutNTMAT const layout_O;
 | |
|     DTypeO* O_ptr_tmp;
 | |
|     LayoutNTMAT const layout_O_tmp;
 | |
|   };
 | |
| 
 | |
|   // Device side kernel params
 | |
|   struct Params {
 | |
|     DTypeO* O_ptr;
 | |
|     LayoutNTMAT const layout_O;
 | |
|     DTypeO* O_ptr_tmp;
 | |
|     LayoutNTMAT const layout_O_tmp;
 | |
|   };
 | |
| 
 | |
|   static Params to_underlying_arguments_ntma(Arguments const& args) {
 | |
|     return {args.O_ptr, args.layout_O, args.O_ptr_tmp, args.layout_O_tmp};
 | |
|   }
 | |
| 
 | |
|   CUTLASS_DEVICE
 | |
|   static void prefetch_tma_descriptors(Params const& epilogue_params) {}
 | |
| 
 | |
|   template <typename SharedStorage, typename FrgTensorO, typename FrgTensorLSE,
 | |
|             typename TiledMma>
 | |
|   CUTLASS_DEVICE void store(Params const& epilogue_params,
 | |
|                             FrgTensorO const& tOrO,
 | |
|                             FrgTensorLSE const& lse,
 | |
|                             SharedStorage& shared_storage,
 | |
|                             TiledMma tiled_mma,
 | |
|                             const int thread_idx,
 | |
|                             const int bid,
 | |
|                             const int bsz,
 | |
|                             const int seq_len_now,
 | |
|                             const int start_token_idx,
 | |
|                             const int tile_idx,
 | |
|                             const int kv_len,
 | |
|                             const int chunk_size,
 | |
|                             const int max_draft_token_num,
 | |
|                             const int o_stride_bsz) {
 | |
|     const int num_chunks = cute::ceil_div(kv_len, chunk_size);
 | |
|     Tensor sO = make_tensor(make_smem_ptr(shared_storage.smem_o.data()), SmemLayoutO{});
 | |
|     auto smem_tiled_copy_O = make_tiled_copy_C(SmemCopyAtomO{}, tiled_mma);
 | |
|     auto smem_thr_copy_O = smem_tiled_copy_O.get_thread_slice(thread_idx);
 | |
| 
 | |
|     Tensor tOrO_out = convert_type<DTypeO>(tOrO);
 | |
|     Tensor taccOrO = smem_thr_copy_O.retile_S(tOrO_out);  // ((Atom,AtomNum), MMA_M, MMA_N)
 | |
|     Tensor taccOsO = smem_thr_copy_O.partition_D(sO);     // ((Atom,AtomNum),PIPE_M,PIPE_N)
 | |
|     // make sure gemm done
 | |
|     cutlass::arch::NamedBarrier::sync(NUM_MMA_THREADS,
 | |
|                                       /*id=*/static_cast<int>(NamedBarriers::kValueEmpty));
 | |
|     // r2s
 | |
|     cute::copy(smem_tiled_copy_O, taccOrO, taccOsO);
 | |
|     // make sure r2s done
 | |
|     cutlass::arch::NamedBarrier::sync(NUM_MMA_THREADS,
 | |
|                                       /*id=*/static_cast<int>(NamedBarriers::kValueEmpty));
 | |
|     TiledCopyO gmem_tiled_copy_O;
 | |
|     auto O_ptr = num_chunks == 1 ? epilogue_params.O_ptr + start_token_idx * o_stride_bsz : epilogue_params.O_ptr_tmp + (tile_idx * bsz + bid) * max_draft_token_num * o_stride_bsz;
 | |
|     Tensor mO = make_tensor(make_gmem_ptr(O_ptr), epilogue_params.layout_O);
 | |
|     Tensor gO = local_tile(mO, select<0, 1>(TileShape_PDV{}), make_coord(_, _0{}))(_, _, _0{});
 | |
|     Tensor cO = make_identity_tensor(gO.shape());  // (O, D) -> (o_idx, d_idx)
 | |
|     ThrCopy thr_copy_O = gmem_tiled_copy_O.get_slice(thread_idx);
 | |
|     Tensor tOgO = thr_copy_O.partition_D(gO);  // (CPY, CPY_O, CPY_D)
 | |
|     Tensor tOsO = thr_copy_O.partition_S(sO);  // (CPY, CPY_O, CPY_D)
 | |
|     Tensor tOcO = thr_copy_O.partition_D(cO);  // (CPY, CPY_O, CPY_D)
 | |
|     Tensor tOgOGroup = flatten_1(tOgO);        // (CPY, (CPY_O, CPY_D))
 | |
|     Tensor tOsOGroup = flatten_1(tOsO);        // (CPY, (CPY_O, CPY_D))
 | |
|     Tensor tOcOGroup = flatten_1(tOcO);        // (CPY, (CPY_O, CPY_D))
 | |
| 
 | |
|     // copy if not out of bound
 | |
|     auto predicate_fn = [&](auto coords) {
 | |
|       auto s_coords = tOcOGroup(_0{}, coords);
 | |
|       return elem_less(get<0>(s_coords) / Ktraits::GROUP_SIZE, seq_len_now);
 | |
|     };
 | |
|     copy_if(gmem_tiled_copy_O, predicate_fn, tOsOGroup, tOgOGroup);
 | |
|   }
 | |
| };
 | |
| 
 | |
| }  // namespace mla_attn
 | |
| 
 | |
| #endif  // ATTENTION_HOPPER_EPILOGUE_CUH_
 |