mirror of
				https://github.com/PaddlePaddle/FastDeploy.git
				synced 2025-10-31 03:46:40 +08:00 
			
		
		
		
	
		
			
				
	
	
		
			125 lines
		
	
	
		
			5.6 KiB
		
	
	
	
		
			C++
		
	
	
	
	
	
			
		
		
	
	
			125 lines
		
	
	
		
			5.6 KiB
		
	
	
	
		
			C++
		
	
	
	
	
	
| /******************************************************************************
 | |
|  * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
 | |
|  ******************************************************************************/
 | |
| 
 | |
| #pragma once
 | |
| 
 | |
| #include "cute/atom/mma_atom.hpp"
 | |
| #include "cutlass/gemm/collective/collective_builder.hpp"
 | |
| 
 | |
| #include "cutlass/cutlass.h"
 | |
| #include "cutlass/layout/layout.h"
 | |
| #include "cutlass/numeric_types.h"
 | |
| #include "cutlass/pipeline/pipeline.hpp"
 | |
| 
 | |
| using namespace cute;
 | |
| 
 | |
/// Argument bundle for the flash-mask attention kernel.
///
/// Every member carries a default initializer, so a default-constructed
/// instance holds null pointers and zeroed scalars instead of indeterminate
/// values. The struct remains an aggregate (C++14 and later), trivially
/// copyable and standard-layout, so brace initialization in declaration order
/// and by-value passing keep working exactly as before.
///
/// NOTE(review): the per-field descriptions below are inferred from the field
/// names only; confirm them against the kernel that consumes this struct.
struct Flash_mask_params {
    // Type-erased tensor pointers; the element type is not recorded here.
    void *__restrict__ q_ptr = nullptr;  // presumably the query tensor
    void *__restrict__ k_ptr = nullptr;  // presumably the key tensor
    void *__restrict__ v_ptr = nullptr;  // presumably the value tensor
    void *__restrict__ o_ptr = nullptr;  // presumably the attention output

    // Integer metadata arrays.
    int *__restrict__ cu_seq_q = nullptr;  // presumably cumulative query sequence offsets
    int *__restrict__ cu_seq_k = nullptr;  // presumably cumulative key sequence offsets
    int *__restrict__ mask = nullptr;      // mask data; layout is defined by the kernel
    int *seq_len_encoder = nullptr;        // not restrict-qualified in the original either

    // Problem sizes.
    int head_num = 0;
    int kv_head_num = 0;
    int max_seq_len_q = 0;
    int max_seq_len_k = 0;
    int batch_size = 0;
    int gqa_group_size = 0;  // presumably head_num / kv_head_num; set by the caller

    // Softmax scale; the name suggests it is pre-multiplied for a base-2
    // exponential -- TODO confirm against the softmax implementation.
    float scale_softmax_log2 = 0.0f;
};
 | |
| 
 | |
| template <int kStages, class Gemm1Type, class Gemm2Type, class OutputType, class SmemLayoutQ,
 | |
|           class SmemLayoutK, class SmemLayoutV, class SmemLayoutO>
 | |
| struct SharedStorageQKVO {
 | |
|     cute::array_aligned<Gemm1Type, cute::cosize_v<SmemLayoutQ>> smem_q;
 | |
|     cute::array_aligned<Gemm1Type, cute::cosize_v<SmemLayoutK>> smem_k;
 | |
|     union {
 | |
|         cute::array_aligned<Gemm2Type, cute::cosize_v<SmemLayoutV>> smem_v;
 | |
|         cute::array_aligned<OutputType, cute::cosize_v<SmemLayoutO>> smem_o;
 | |
|     };
 | |
|     struct {
 | |
|         cutlass::arch::ClusterTransactionBarrier barrier_Q;
 | |
|         typename cutlass::PipelineTmaAsync<kStages>::SharedStorage pipeline_k;
 | |
|         typename cutlass::PipelineTmaAsync<kStages>::SharedStorage pipeline_v;
 | |
|     };
 | |
| };
 | |
| 
 | |
| template<int kHeadDim_, int kBlockM_, int kBlockN_, int kNWarps_, int kStages_, bool NeedMask_, typename elem_type=cutlass::half_t>
 | |
| struct Flash_mask_kernel_traits {
 | |
|     using Element = elem_type;
 | |
|     using ElementAccum = float;
 | |
|     using index_t = int32_t;
 | |
| 
 | |
|     static constexpr int kNWarps = kNWarps_;
 | |
|     static constexpr int kNThreads = kNWarps * cutlass::NumThreadsPerWarp;
 | |
| 
 | |
|     static constexpr int kBlockM = kBlockM_;
 | |
|     static constexpr int kBlockN = kBlockN_;
 | |
|     static constexpr int kHeadDim = kHeadDim_;
 | |
|     static_assert(kHeadDim % 32 == 0);
 | |
|     using TileShape_MNK = Shape<Int<kBlockM>, Int<kBlockN>, Int<kHeadDim>>;
 | |
|     using ClusterShape_MNK = Shape<Int<1>, Int<1>, Int<1>>;
 | |
|     static constexpr int kStages = kStages_;
 | |
|     static constexpr int NeedMask = NeedMask_;
 | |
| 
 | |
|     using AtomLayoutMNK = Layout<Shape<Int<kBlockM / 64>, _1, _1>>;
 | |
|     using TiledMma0 = decltype(cute::make_tiled_mma(
 | |
|         cute::GMMA::ss_op_selector<Element, Element, ElementAccum, TileShape_MNK>(),
 | |
|         AtomLayoutMNK{}));
 | |
|     using TiledMma1 = decltype(cute::make_tiled_mma(
 | |
|         cute::GMMA::rs_op_selector<Element, Element, ElementAccum, decltype(select<0, 2, 1>(TileShape_MNK{})),
 | |
|             GMMA::Major::K, GMMA::Major::MN>(),
 | |
|         AtomLayoutMNK{}));
 | |
| 
 | |
|     using SmemLayoutAtomQ = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
 | |
|         decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
 | |
|     using SmemLayoutQ = decltype(tile_to_shape(SmemLayoutAtomQ{}, select<0, 2>(TileShape_MNK{})));
 | |
| 
 | |
|     using SmemLayoutAtomK = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
 | |
|         decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
 | |
|     using SmemLayoutK =
 | |
|         decltype(tile_to_shape(SmemLayoutAtomK{},
 | |
|                  make_shape(shape<1>(TileShape_MNK{}), shape<2>(TileShape_MNK{}), Int<kStages>{})));
 | |
| 
 | |
|     using SmemLayoutAtomV = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
 | |
|         decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
 | |
|     using SmemLayoutV =
 | |
|         decltype(tile_to_shape(SmemLayoutAtomV{},
 | |
|                  make_shape(shape<1>(TileShape_MNK{}), shape<2>(TileShape_MNK{}), Int<kStages>{})));
 | |
| 
 | |
|     using SmemLayoutAtomO = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
 | |
|         decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
 | |
|     using SmemLayoutO = decltype(tile_to_shape(SmemLayoutAtomO{}, select<0, 2>(TileShape_MNK{})));
 | |
| 
 | |
|     using SmemCopyAtomQ = Copy_Atom<cute::SM75_U32x4_LDSM_N, Element>;
 | |
|     using SmemCopyAtomO = Copy_Atom<cute::SM90_U32x4_STSM_N, Element>;
 | |
| 
 | |
|     using SharedStorage = SharedStorageQKVO<kStages, Element, Element, Element, SmemLayoutQ, SmemLayoutK, SmemLayoutV, SmemLayoutO>;
 | |
| 
 | |
|     static constexpr int NumProducerThreads = cutlass::NumThreadsPerWarpGroup;
 | |
|     static constexpr int NumMmaThreads = kNThreads - NumProducerThreads;
 | |
|     static constexpr int kNumVecElem = ceil_div(128, sizeof_bits_v<Element>);
 | |
|     static constexpr int kNumThreadsPerRow = kHeadDim / kNumVecElem;
 | |
|     static_assert(NumMmaThreads % kNumThreadsPerRow == 0);
 | |
|     static constexpr int kNumRows = NumMmaThreads / kNumThreadsPerRow;
 | |
|     using TiledCopyOAtom = cute::Copy_Atom<cute::UniversalCopy<cutlass::uint128_t>, Element>;
 | |
|     using TiledCopyOThrLayout = decltype(cute::make_layout(
 | |
|         cute::make_shape(Int<kNumRows>{}, Int<kNumThreadsPerRow>{}),
 | |
|         LayoutRight{}));
 | |
|     using TiledCopyOValLayout = decltype(cute::make_layout(
 | |
|         cute::make_shape(_1{}, Int<kNumVecElem>{}),
 | |
|         LayoutRight{}));
 | |
|     using GmemTiledCopyO = decltype(make_tiled_copy(
 | |
|         TiledCopyOAtom{},
 | |
|         TiledCopyOThrLayout{},
 | |
|         TiledCopyOValLayout{}
 | |
|     ));
 | |
| 
 | |
|     using MainloopPipeline = typename cutlass::PipelineTmaAsync<kStages>;
 | |
|     using PipelineState = typename cutlass::PipelineState<kStages>;
 | |
| };
 | 
