mirror of
				https://github.com/PaddlePaddle/FastDeploy.git
				synced 2025-10-31 11:56:44 +08:00 
			
		
		
		
	 89397516a8
			
		
	
	89397516a8
	
	
	
		
			
			* init * 增加多线程编译 * fix bug * fix bug * code style * 增加fp16 * 将print替换成assert * 修复stmatrix * 减小单测shape * 减小单测shape
		
			
				
	
	
		
			155 lines
		
	
	
		
			6.0 KiB
		
	
	
	
		
			C++
		
	
	
	
	
	
			
		
		
	
	
			155 lines
		
	
	
		
			6.0 KiB
		
	
	
	
		
			C++
		
	
	
	
	
	
| // Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
 | |
| //
 | |
| // Licensed under the Apache License, Version 2.0 (the "License");
 | |
| // you may not use this file except in compliance with the License.
 | |
| // You may obtain a copy of the License at
 | |
| //
 | |
| //     http://www.apache.org/licenses/LICENSE-2.0
 | |
| //
 | |
| // Unless required by applicable law or agreed to in writing, software
 | |
| // distributed under the License is distributed on an "AS IS" BASIS,
 | |
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | |
| // See the License for the specific language governing permissions and
 | |
| // limitations under the License.
 | |
| 
 | |
| #include "cute/algorithm/copy.hpp"
 | |
| #include "cute/atom/mma_atom.hpp"
 | |
| #include "cutlass/gemm/collective/collective_builder.hpp"
 | |
| 
 | |
| #include "cutlass/cutlass.h"
 | |
| #include "cutlass/layout/layout.h"
 | |
| #include "cutlass/numeric_types.h"
 | |
| #include "cutlass/pipeline/pipeline.hpp"
 | |
| 
 | |
| using namespace cute;
 | |
| 
 | |
| template <int kStages, class GemmType, class OutputType, class SmemLayoutA,
 | |
|           class SmemLayoutB, class SmemLayoutC>
 | |
| struct SharedStorage {
 | |
|     union {
 | |
|         struct {
 | |
|             cute::array_aligned<GemmType, cute::cosize_v<SmemLayoutA>> smem_a;
 | |
|             cute::array_aligned<GemmType, cute::cosize_v<SmemLayoutB>> smem_b;
 | |
|         };
 | |
|         cute::array_aligned<OutputType, cute::cosize_v<SmemLayoutC>> smem_c;
 | |
|     };
 | |
| 
 | |
|   struct {
 | |
|     typename cutlass::PipelineTmaAsync<kStages>::SharedStorage pipeline;
 | |
|   };
 | |
| };
 | |
| 
 | |
| template<int kBlockM_, int kBlockN_, int kBlockK_,
 | |
|         int kNWarps_, int kStages_,
 | |
|         int kTiles_, int M_,
 | |
|         int TokenPackSize_,
 | |
|         int TAIL_N_ = 0,
 | |
|         int kClusterM_ = 1,
 | |
|         typename elem_type=cutlass::float_e4m3_t,
 | |
|         typename OutputType = cutlass::bfloat16_t>
 | |
| struct Kernel_traits {
 | |
|     using Element = elem_type;
 | |
|     using ElementAccum = float;
 | |
|     using ElementOutput = OutputType;
 | |
|     static_assert(cutlass::sizeof_bits_v<Element> == 8);
 | |
| 
 | |
|     static constexpr int kNWarps = kNWarps_;
 | |
|     static constexpr int kNThreads = kNWarps * cutlass::NumThreadsPerWarp;
 | |
|     static constexpr int NumProducerThreads = cutlass::NumThreadsPerWarpGroup;
 | |
|     static constexpr int NumMmaThreads = kNThreads - NumProducerThreads;
 | |
| 
 | |
|     static_assert(kNWarps_ == 12 || kNWarps_ == 16);
 | |
| 
 | |
|     static constexpr int kBlockM = kBlockM_;
 | |
|     static constexpr int kBlockN = kBlockN_;
 | |
|     static constexpr int kBlockK = kBlockK_;
 | |
|     static constexpr int kTiles = kTiles_;
 | |
|     static constexpr int TokenPackSize = TokenPackSize_;
 | |
|     static constexpr int M = M_;
 | |
|     static constexpr int TAIL_N = TAIL_N_;
 | |
| 
 | |
|     using TileShape_MNK = Shape<Int<kBlockM>, Int<kBlockN>, Int<kBlockK>>;
 | |
|     using TileShape_MNK_TAIL = Shape<Int<kBlockM>, Int<TAIL_N>, Int<kBlockK>>;
 | |
| 
 | |
|     static constexpr int kClusterM = kClusterM_;
 | |
|     using ClusterShape_MNK = Shape<Int<kClusterM>, _1, _1>;
 | |
| 
 | |
|     static constexpr int kStages = kStages_;
 | |
|     static_assert(kStages > 1);
 | |
| 
 | |
|     using AtomLayoutMNK = Layout<Shape<Int<kBlockM / 64>, _1, _1>>;
 | |
| 
 | |
|     using TiledMma = decltype(cute::make_tiled_mma(
 | |
|         cute::GMMA::rs_op_selector<Element, Element, ElementAccum, TileShape_MNK>(),
 | |
|         AtomLayoutMNK{}));
 | |
| 
 | |
|     using TiledMma_TAIL = decltype(cute::make_tiled_mma(
 | |
|         cute::GMMA::rs_op_selector<Element, Element, ElementAccum, TileShape_MNK_TAIL>(),
 | |
|         AtomLayoutMNK{}));
 | |
| 
 | |
|     using SmemLayoutAtomA = decltype(
 | |
|         cutlass::gemm::collective::detail::rs_smem_selector<
 | |
|             GMMA::Major::K, Element, Int<kBlockM>, Int<kBlockK / 2>>());
 | |
| 
 | |
|     using SmemLayoutA = decltype(
 | |
|         tile_to_shape(SmemLayoutAtomA{},
 | |
|             make_shape(Int<kBlockM>{}, Int<kBlockK / 2>{}, Int<kStages>{})));
 | |
| 
 | |
|     using SmemLayoutAtomB = decltype(
 | |
|         cutlass::gemm::collective::detail::rs_smem_selector<
 | |
|             GMMA::Major::K, Element, decltype(cute::get<1>(TileShape_MNK{})),
 | |
|             decltype(cute::get<2>(TileShape_MNK{}))>());
 | |
| 
 | |
|     using SmemLayoutB = decltype(
 | |
|         tile_to_shape(SmemLayoutAtomB{},
 | |
|             make_shape(shape<1>(TileShape_MNK{}), shape<2>(TileShape_MNK{}), Int<kStages>{})));
 | |
| 
 | |
|     using SmemLayoutAtomB_TAIL = decltype(
 | |
|         cutlass::gemm::collective::detail::rs_smem_selector<
 | |
|             GMMA::Major::K, Element, decltype(cute::get<1>(TileShape_MNK_TAIL{})),
 | |
|             decltype(cute::get<2>(TileShape_MNK_TAIL{}))>());
 | |
| 
 | |
|     using SmemLayoutB_TAIL = decltype(
 | |
|         tile_to_shape(SmemLayoutAtomB_TAIL{},
 | |
|             make_shape(
 | |
|                 shape<1>(TileShape_MNK_TAIL{}),
 | |
|                 shape<2>(TileShape_MNK_TAIL{}),
 | |
|                 Int<kStages>{})
 | |
|             ));
 | |
| 
 | |
|     using SmemLayoutAtomC = decltype(
 | |
|         cutlass::gemm::collective::detail::rs_smem_selector<
 | |
|         GMMA::Major::K, ElementOutput,
 | |
|         decltype(cute::get<0>(TileShape_MNK{})),
 | |
|         decltype(cute::get<1>(TileShape_MNK{}))>());
 | |
| 
 | |
|     using SmemLayoutC = decltype(tile_to_shape(SmemLayoutAtomC{}, select<0, 1>(TileShape_MNK{})));
 | |
| 
 | |
|     using SmemCopyAtomAB = Copy_Atom<cute::SM75_U32x4_LDSM_N, Element>;
 | |
|     using SmemCopyAtomC = Copy_Atom<cute::SM90_U32x4_STSM_N, ElementOutput>;
 | |
| 
 | |
|     using SharedStorage = SharedStorage<
 | |
|         kStages, Element, ElementOutput, SmemLayoutA, SmemLayoutB, SmemLayoutC>;
 | |
| 
 | |
|     using MainloopPipeline = typename cutlass::PipelineTmaAsync<kStages>;
 | |
|     using PipelineState = typename cutlass::PipelineState<kStages>;
 | |
| 
 | |
| 
 | |
|     static constexpr int kNumVecElem = ceil_div(128, sizeof_bits_v<OutputType>);
 | |
|     static constexpr int kNumThreadsPerRow = kBlockN / kNumVecElem;
 | |
|     // static_assert(NumMmaThreads % kNumThreadsPerRow == 0);
 | |
|     static constexpr int kNumRows = NumMmaThreads / kNumThreadsPerRow;
 | |
|     using TiledCopyCAtom = cute::Copy_Atom<cute::UniversalCopy<cutlass::uint128_t>, OutputType>;
 | |
|     using TiledCopyCThrLayout = decltype(cute::make_layout(
 | |
|         cute::make_shape(Int<kNumRows>{}, Int<kNumThreadsPerRow>{}),
 | |
|         LayoutRight{}));
 | |
|     using TiledCopyCValLayout = decltype(cute::make_layout(
 | |
|         cute::make_shape(_1{}, Int<kNumVecElem>{}),
 | |
|         LayoutRight{}));
 | |
|     using TiledCopyC = decltype(make_tiled_copy(
 | |
|         TiledCopyCAtom{},
 | |
|         TiledCopyCThrLayout{}, // Thr layout
 | |
|         TiledCopyCValLayout{} // Val layout
 | |
|     ));
 | |
| };
 |