mirror of
				https://github.com/PaddlePaddle/FastDeploy.git
				synced 2025-10-25 01:20:43 +08:00 
			
		
		
		
	 b047681c5d
			
		
	
	b047681c5d
	
	
		
			
	
		
	
	
		
			Some checks failed
		
		
	
	Deploy GitHub Pages / deploy (push) Has been cancelled
				
			Publish Job / publish_pre_check (push) Has been cancelled
				
			Publish Job / print_publish_pre_check_outputs (push) Has been cancelled
				
			Publish Job / FD-Clone-Linux (push) Has been cancelled
				
			Publish Job / Show Code Archive Output (push) Has been cancelled
				
			Publish Job / BUILD_SM8090 (push) Has been cancelled
				
			Publish Job / BUILD_SM8689 (push) Has been cancelled
				
			Publish Job / PADDLE_PYPI_UPLOAD_8090 (push) Has been cancelled
				
			Publish Job / PADDLE_PYPI_UPLOAD_8689 (push) Has been cancelled
				
			Publish Job / Run FastDeploy Unit Tests and Coverage (push) Has been cancelled
				
			Publish Job / Run FastDeploy LogProb Tests (push) Has been cancelled
				
			Publish Job / Extracted partial CE model tasks to run in CI. (push) Has been cancelled
				
			Publish Job / Run Base Tests (push) Has been cancelled
				
			Publish Job / Run Accuracy Tests (push) Has been cancelled
				
			* 支持24稀疏 * code style * 增加stmatrix 宏定义判断 * code style
		
			
				
	
	
		
			101 lines
		
	
	
		
			2.9 KiB
		
	
	
	
		
			C++
		
	
	
	
	
	
			
		
		
	
	
			101 lines
		
	
	
		
			2.9 KiB
		
	
	
	
		
			C++
		
	
	
	
	
	
| // Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
 | |
| //
 | |
| // Licensed under the Apache License, Version 2.0 (the "License");
 | |
| // you may not use this file except in compliance with the License.
 | |
| // You may obtain a copy of the License at
 | |
| //
 | |
| //     http://www.apache.org/licenses/LICENSE-2.0
 | |
| //
 | |
| // Unless required by applicable law or agreed to in writing, software
 | |
| // distributed under the License is distributed on an "AS IS" BASIS,
 | |
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | |
| // See the License for the specific language governing permissions and
 | |
| // limitations under the License.
 | |
| 
 | |
| #pragma once
 | |
| 
 | |
| #include <assert.h>
 | |
| #include <stdint.h>
 | |
| #include <stdlib.h>
 | |
| 
 | |
| #include <cuda_fp16.h>
 | |
| 
 | |
| #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
 | |
| #include <cuda_bf16.h>
 | |
| #endif
 | |
| 
 | |
| #include <cute/tensor.hpp>
 | |
| #include <cute/arch/cluster_sm90.hpp>  // For cute::elect_one_sync()
 | |
| 
 | |
| #include <cutlass/array.h>
 | |
| #include <cutlass/cutlass.h>
 | |
| #include <cutlass/numeric_conversion.h>
 | |
| #include <cutlass/numeric_types.h>
 | |
| 
 | |
| 
 | |
| using namespace cute;
 | |
| 
 | |
| 
 | |
// Trait mapping a CUTLASS half-precision element type to the matching packed
// CUDA vector type, exposed as the nested alias `Type`.
// Only the two specializations below are defined; using any other T is an
// error because the primary template is declared but never defined.
template <class T>
struct PackedHalf;

// cutlass::half_t  ->  __half2
template <>
struct PackedHalf<cutlass::half_t> {
    typedef __half2 Type;
};

// cutlass::bfloat16_t  ->  nv_bfloat162
template <>
struct PackedHalf<cutlass::bfloat16_t> {
    typedef nv_bfloat162 Type;
};
 | |
| 
 | |
// Builds a GMMA shared-memory matrix descriptor for `smem_ptr`.
//
//   smem_ptr            - generic pointer into shared memory; it is converted
//                         to a shared-window address with
//                         __cvta_generic_to_shared before being encoded.
//   layout_type         - raw value written to the descriptor's layout_type_
//                         field. NOTE(review): callers in this file pass 1;
//                         in CuTe's GMMA LayoutType enum that is presumably
//                         the 128B-swizzle layout - confirm.
//   leading_byte_offset - byte offset, stored in 16-byte units.
//   stride_byte_offset  - byte offset, stored in 16-byte units.
//
// All three addresses/offsets are shifted right by 4, so any bits below
// 16 bytes are discarded (values are presumably expected to be 16-byte
// aligned). base_offset_ is always set to 0.
template <class PointerType>
__device__ GmmaDescriptor make_smem_desc(
        PointerType smem_ptr,
        int layout_type,
        int leading_byte_offset = 0,
        int stride_byte_offset = 1024) {
    uint32_t const smem_addr =
        static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));

    GmmaDescriptor result;
    result.bitfield.layout_type_ = layout_type;
    result.bitfield.start_address_ = smem_addr >> 4;
    result.bitfield.stride_byte_offset_ = stride_byte_offset >> 4;
    result.bitfield.leading_byte_offset_ = leading_byte_offset >> 4;
    result.bitfield.base_offset_ = 0;
    return result;
}
 | |
| 
 | |
// Issues a single `Mma::fma` call.
// The index pack Idx = 0, 1, ..., N-1 expands the accumulator array into one
// argument per element: d[0], d[1], ..., d[N-1].
// `e` is forwarded unchanged (presumably the sparsity-metadata operand).
// The output scale is always GMMA::ScaleOut::One, so the existing
// accumulator contents are kept and added to rather than overwritten.
template <typename Mma, size_t ...Idx>
__forceinline__ __device__ static void gemm(
        uint64_t const& desc_a,
        uint64_t const& desc_b,
        float* d,
        const uint32_t e,
        std::index_sequence<Idx...>) {
    constexpr GMMA::ScaleOut keep_accumulator = GMMA::ScaleOut::One;
    Mma::fma(desc_a, desc_b, d[Idx]..., e, keep_accumulator);
}
 | |
| 
 | |
// Runs the sparse warpgroup MMAs for one kBlockK-deep tile held in shared
// memory, then waits for them to complete.
//
// Template parameters:
//   Mma           - MMA atom type providing `fma(...)` and `CRegisters`.
//   kBlockK       - K depth of the tile. It must be a multiple of 64; one MMA
//                   is issued per 64-wide K step.
//   NumMmaThreads - distance, in uint32_t entries, between the metadata words
//                   used for consecutive K steps in E.
//   T             - element type of the shared-memory operands.
//
// Parameters:
//   sA    - shared-memory A operand; advanced by 32 elements per K step
//           (presumably because A holds the 2:4-compressed operand - confirm).
//   sB    - shared-memory B operand; advanced by 64 elements per K step.
//   acc_c - accumulator array of sizeof(Mma::CRegisters) / sizeof(float)
//           floats. It is accumulated into, not overwritten.
//   E     - metadata words; E[i * NumMmaThreads] is used for K step i.
//           NOTE(review): assumes the caller has already offset E to this
//           thread's first word - confirm against the caller.
//
// All MMAs are bracketed by warpgroup_arrive() and
// warpgroup_commit_batch() + warpgroup_wait<0>().
// NOTE(review): CUTLASS mainloops usually also apply warpgroup_fence_operand
// to the accumulators around this region; consider whether it is needed here.
template <typename Mma, int kBlockK, int NumMmaThreads, typename T>
__forceinline__ __device__ void gemm(
        const T * sA,
        const T * sB,
        float * acc_c,
        const uint32_t *E) {
    // Each iteration consumes a 64-deep K slice; a remainder would be silently
    // skipped, so reject such tile sizes at compile time.
    static_assert(kBlockK % 64 == 0, "kBlockK must be a multiple of 64");

    // `typename` is required: CRegisters is a type that depends on Mma.
    constexpr int acc_num = sizeof(typename Mma::CRegisters) / sizeof(float);

    warpgroup_arrive();
    // 2:4 sparsity metadata encoding. For each pair of selected positions
    // within a group of four, the 4-bit value is first | (second << 2):
    //   selected indices   hex value
    //        0,1               4
    //        0,2               8
    //        0,3               12
    //        1,2               9
    //        1,3               13
    //        2,3               14
    #pragma unroll
    for (int i = 0; i < kBlockK / 64; i++) {
        GmmaDescriptor a_desc = make_smem_desc(sA + i * 32, 1, 0, 1024);
        GmmaDescriptor b_desc = make_smem_desc(sB + i * 64, 1, 0, 1024);
        gemm<Mma>(a_desc, b_desc, acc_c, E[i * NumMmaThreads], std::make_index_sequence<acc_num>{});
    }

    warpgroup_commit_batch();
    warpgroup_wait<0>();
}
 |