mirror of
				https://github.com/PaddlePaddle/FastDeploy.git
				synced 2025-10-31 20:02:53 +08:00 
			
		
		
		
	 05c670e593
			
		
	
	05c670e593
	
	
	
		
			
			* [Sync] Update to latest code * Add new code files * Add new code files * update code * Try to fix build.sh * Try to fix build.sh * Update code * Update requirements.txt * Update code --------- Co-authored-by: Jiang-Jia-Jun <jiangjiajun@baidu.com>
		
			
				
	
	
		
			237 lines
		
	
	
		
			8.5 KiB
		
	
	
	
		
			Plaintext
		
	
	
	
	
	
			
		
		
	
	
			237 lines
		
	
	
		
			8.5 KiB
		
	
	
	
		
			Plaintext
		
	
	
	
	
	
| // Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
 | |
| //
 | |
| // Licensed under the Apache License, Version 2.0 (the "License");
 | |
| // you may not use this file except in compliance with the License.
 | |
| // You may obtain a copy of the License at
 | |
| //
 | |
| //     http://www.apache.org/licenses/LICENSE-2.0
 | |
| //
 | |
| // Unless required by applicable law or agreed to in writing, software
 | |
| // distributed under the License is distributed on an "AS IS" BASIS,
 | |
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | |
| // See the License for the specific language governing permissions and
 | |
| // limitations under the License.
 | |
| #pragma once
 | |
| 
 | |
| 
 | |
| #include "multi_head_latent_attention_kernel.h"
 | |
| 
 | |
// Per-thread softmax accumulator whose statistics are stored in the element
// type T.
//   o : un-normalized output accumulator, vec_size elements wide
//   m : running maximum (merge() keeps the larger of the two maxima)
//   d : running denominator (rescaled in merge(), divisor in normalize())
// init() only writes `o` and `m` when T is half or __nv_bfloat16.
template <size_t vec_size, typename T>
struct softmax_state_t {
  AlignedVector<T, vec_size> o;
  T m;
  T d;

  // Resets the accumulator: zero-fills `o` two elements at a time, sets the
  // denominator to 1 and the running maximum to a large negative sentinel
  // (presumably a finite stand-in for -inf -- confirm against callers).
  __device__ __forceinline__ void init() {
    d = 1.f;
    if constexpr (std::is_same<T, half>::value) {
      half2* o_pairs = (half2*)(&o);
#pragma unroll
      for (int pair = 0; pair < vec_size / 2; ++pair) {
        o_pairs[pair] = make_half2(0, 0);
      }
      m = __float2half(-5e4f);
    } else if constexpr (std::is_same<T, __nv_bfloat16>::value) {
      nv_bfloat162* o_pairs = (nv_bfloat162*)(&o);
#pragma unroll
      for (int pair = 0; pair < vec_size / 2; ++pair) {
        o_pairs[pair] = make_bfloat162(0, 0);
      }
      m = __float2bfloat16(-3.38953e38f);
    }
  }

  // Constructs the state already reset.
  __device__ __forceinline__ softmax_state_t() {
    init();
  }

  // Folds another partial result (other_o, other_m, other_d) into this one.
  // Both sides are rescaled by exp(old_max - new_max) before being summed, so
  // `o` and `d` stay expressed relative to the updated maximum `m`.
  __device__ __forceinline__ void merge(const AlignedVector<T, vec_size>& other_o,
                                        T other_m,
                                        T other_d) {
    const T self_m = m;
    const T self_d = d;

    // New maximum is the larger of the two maxima.
    T merged_m = other_m;
    if (self_m > other_m) {
      merged_m = self_m;
    }
    m = merged_m;

    T scale1 = hexp(self_m - m);
    T scale2 = hexp(other_m - m);

    d = self_d * scale1 + other_d * scale2;

#pragma unroll
    for (size_t idx = 0; idx < vec_size; ++idx) {
      o[idx] = o[idx] * scale1 + other_o[idx] * scale2;
    }
  }

  // Divides every accumulated output element by the running denominator.
  __device__ __forceinline__ void normalize() {
#pragma unroll
    for (size_t idx = 0; idx < vec_size; ++idx) {
      o[idx] /= d;
    }
  }
};
 | |
| 
 | |
// Multi-tile variant of the per-thread softmax accumulator: `num_tiles`
// output vectors share a single running maximum `m` and denominator `d`,
// both kept in float regardless of the element type T.
// init() only writes `o` and `m` when T is half or __nv_bfloat16.
template <size_t vec_size, typename T, uint32_t num_tiles = 0>
struct softmax_state_ts {
  uint32_t num_tiles_ = num_tiles;          // runtime copy of the tile count
  AlignedVector<T, vec_size> o[num_tiles];  // one accumulator per tile
  float m;                                  // running maximum
  float d;                                  // running denominator

  // Zero-fills every tile (two elements per store), sets the denominator to 1
  // and the running maximum to a type-dependent large negative sentinel.
  __device__ __forceinline__ void init() {
    if constexpr (std::is_same<T, half>::value) {
#pragma unroll
      for (uint32_t tile = 0; tile < num_tiles_; ++tile) {
        half2* tile_pairs = (half2*)(&o[tile]);
#pragma unroll
        for (int pair = 0; pair < vec_size / 2; ++pair) {
          tile_pairs[pair] = make_half2(0, 0);
        }
      }
      m = -5e4f;
    } else if constexpr (std::is_same<T, __nv_bfloat16>::value) {
#pragma unroll
      for (uint32_t tile = 0; tile < num_tiles_; ++tile) {
        nv_bfloat162* tile_pairs = (nv_bfloat162*)(&o[tile]);
#pragma unroll
        for (int pair = 0; pair < vec_size / 2; ++pair) {
          tile_pairs[pair] = make_bfloat162(0, 0);
        }
      }
      m = -3.38953e38f;
    }
    d = 1.f;
  }

  // Constructs the state already reset.
  __device__ __forceinline__ softmax_state_ts() {
    init();
  }

  // Divides every element of tile `tile_id` by the running denominator.
  __device__ __forceinline__ void normalize(const uint32_t tile_id) {
    AlignedVector<T, vec_size>& tile = o[tile_id];
#pragma unroll
    for (size_t idx = 0; idx < vec_size; idx++) {
      tile[idx] /= d;
    }
  }
};
 | |
| 
 | |
// Issues predicated 128-bit loads that copy one cached row (HEAD_DIM_QK
// elements) from the paged KV cache into the staging buffer `smem`.
//
//   smem             destination buffer; the row lands at
//                    seq_offset_smem * HEAD_DIM_QK
//   kv_base_gptr     base pointer of the paged cache
//   block_table_smem logical-block -> physical-block table, read with __ldg.
//                    NOTE(review): __ldg is a global-memory read-only load
//                    while the name suggests shared memory -- confirm where
//                    this table actually lives.
//   seq_offset_gmem  position of the row in the sequence; selects the table
//                    entry (/ BLOCK_SIZE) and the slot in the block (% BLOCK_SIZE)
//   seq_offset_smem  row slot inside `smem`
//   kv_head_idx,
//   kv_num_heads     head index / head count used in the cache offset
//   tidx             this thread's lane in the bdx-wide group; vectors are
//                    assigned tidx, tidx + bdx, ...
//   chunk_start      not referenced in this function
//   chunk_end        rows at or beyond this position are loaded with a false
//                    predicate (handling then depends on fill_mode)
//
// Negative table entries are clamped to block 0 so the computed source
// address stays inside the cache even for unassigned blocks.
template <SharedMemFillMode fill_mode, uint32_t HEAD_DIM_QK, uint32_t vec_size, uint32_t NUM_VEC_PER_HEAD, uint32_t bdx, uint32_t BLOCK_SIZE, uint32_t CACHE_VEC_SIZE, typename CacheT>
__device__ __forceinline__ void produce_kv(CacheT *smem,
                                          CacheT *kv_base_gptr,
                                          const int * block_table_smem,
                                          const uint32_t seq_offset_gmem,
                                          const uint32_t seq_offset_smem,
                                          const uint32_t kv_head_idx,
                                          const uint32_t kv_num_heads,
                                          const uint32_t tidx,
                                          const uint32_t chunk_start,
                                          const uint32_t chunk_end) {
  int block_id = __ldg(&block_table_smem[seq_offset_gmem / BLOCK_SIZE]);
  if (block_id < 0) {
    block_id = 0;
  }
  const uint32_t block_offset = seq_offset_gmem % BLOCK_SIZE;
  // The element offset into the global cache is computed in size_t:
  // block_id * kv_num_heads * BLOCK_SIZE * HEAD_DIM_QK can exceed 2^32 for
  // large caches, and a 32-bit product would wrap silently and address the
  // wrong block.
  const size_t k_offset_base =
      ((static_cast<size_t>(block_id) * kv_num_heads + kv_head_idx) * BLOCK_SIZE + block_offset) * HEAD_DIM_QK;
  const uint32_t smem_offset_base = seq_offset_smem * HEAD_DIM_QK;
  // The predicate does not depend on the vector index, so evaluate it once.
  const bool row_in_chunk = seq_offset_gmem < chunk_end;
  // Each load moves CACHE_VEC_SIZE elements (8/16 for T/int8).
  for (uint32_t vid = tidx; vid < NUM_VEC_PER_HEAD; vid += bdx) {
    pred_load<128, PrefetchMode::kPrefetch, fill_mode, CacheT>(
      smem + smem_offset_base + vid * CACHE_VEC_SIZE,
      kv_base_gptr + k_offset_base + vid * CACHE_VEC_SIZE,
      row_in_chunk
    );
  }
}
 | |
| 
 | |
// Computes up to DEAL_EACH_TIME query.key scores for one pipeline stage and
// folds them into the running softmax statistics in `st`.
//
//   cu_q_smem   query vector, read as NUM_VEC_PER_HEAD vectors of vec_size
//   k_smem      staged keys; the rows for this call start at
//               stage_idx * DEAL_EACH_TIME * HEAD_DIM
//   iter_base,
//   iter_bound  row j is valid when iter_base + j < iter_bound; invalid rows
//               receive a large negative sentinel score
//   tidx        this thread's lane in the bdx-wide group sharing one dot product
//   s           out: DEAL_EACH_TIME values, exp(score - st.m) on return
//   st          in/out: st.m, st.d and every st.o tile are rescaled
//   kv_idx_base, gid, scale
//               not referenced in this function; in particular the raw dot
//               product is NOT multiplied by `scale` here
//               (NOTE(review): presumably the query is pre-scaled -- confirm).
//
// Must be called by every thread of the block: it executes one
// __syncthreads() per j, DEAL_EACH_TIME in total.
template <uint32_t vec_size, uint32_t NUM_VEC_PER_HEAD, uint32_t bdx, uint32_t bdy, uint32_t HEAD_DIM, uint32_t DEAL_EACH_TIME, uint32_t num_tile_v, typename T, typename CacheT>
__device__ __forceinline__ void compute_qk(const T* cu_q_smem,
                                           const CacheT* k_smem,
                                           const uint32_t kv_idx_base,
                                           const uint32_t stage_idx,
                                           const uint32_t iter_base,
                                           const uint32_t iter_bound,
                                           const uint32_t tidx,
                                           const uint32_t gid,
                                           const float scale,
                                           float *s,
                                           softmax_state_ts<vec_size, T, num_tile_v>& st) {
  AlignedVector<T, vec_size> q_vec;
  AlignedVector<T, vec_size> k_vec;
  const float m_prev = st.m;
  const CacheT* smem = k_smem + stage_idx * DEAL_EACH_TIME * HEAD_DIM;
#pragma unroll
  for (uint32_t j = 0; j < DEAL_EACH_TIME; ++j) {
    if (iter_base + j < iter_bound) {
      // Start from zero for every element type so the accumulation below
      // never reads an uninitialized s[j].
      s[j] = 0.f;
#pragma unroll
      for (uint32_t vid = tidx; vid < NUM_VEC_PER_HEAD; vid += bdx) {
        Load<T, vec_size>(cu_q_smem + vid * vec_size, &q_vec);
        Load<CacheT, vec_size>(smem + j * HEAD_DIM + vid * vec_size, &k_vec);
        for (uint32_t i = 0; i < vec_size; ++i) {
          s[j] += static_cast<float>(q_vec[i] * k_vec[i]);
        }
      }
      // XOR-shuffle reduction of the partial sums across the bdx lanes. The
      // full-warp mask requires all 32 lanes of the warp to reach this point.
#pragma unroll
      for (uint32_t offset = bdx / 2; offset > 0; offset /= 2) {
        s[j] += __shfl_xor_sync(0xffffffffu, s[j], offset, 32);
      }
    } else {
      // Row is past the end of the chunk: give it a score that contributes
      // (effectively) nothing after exponentiation.
      if constexpr (std::is_same<T, half>::value) {
        s[j] = -5e4f;
      } else if constexpr (std::is_same<T, __nv_bfloat16>::value) {
        s[j] = -3.38953e38f;
      }
    }
    // Block barrier kept once per j, but hoisted out of the data-dependent
    // branch: __syncthreads() inside a conditional is undefined unless the
    // condition is identical for every thread of the block.
    __syncthreads();
    st.m = st.m > s[j] ? st.m : s[j];
  }

  // Rescale the previous denominator to the new running maximum.
  float o_scale = __expf(m_prev - st.m);
  st.d *= o_scale;

#pragma unroll
  for (uint32_t j = 0; j < DEAL_EACH_TIME; ++j) {
    s[j] = __expf(s[j] - st.m);
    st.d += s[j];
  }
  // Rescale the already-accumulated outputs by the same factor.
#pragma unroll
  for (uint32_t tile_id = 0; tile_id < num_tile_v; ++tile_id) {
    for (uint32_t i = 0; i < vec_size; ++i) {
      st.o[tile_id][i] *= o_scale;
    }
  }
}
 | |
| 
 | |
// Accumulates the weighted value rows of one pipeline stage into the output
// tiles of `st`: for every valid row j, st.o[tile] += T(s[j]) * V[j].
//
//   s           per-row weights (DEAL_EACH_TIME floats)
//   base_v_smem staged values; the rows for this call start at
//               stage_idx * DEAL_EACH_TIME * HEAD_DIM_QK
//   iter_base,
//   iter_bound  processing stops at the first row with
//               iter_base + j >= iter_bound
//   tidx        this thread's lane; it handles vectors tidx, tidx + bdx, ...
//               and vector `vid` accumulates into tile vid / bdx
//   st          in/out: only the st.o tiles are modified here
template<uint32_t vec_size, uint32_t NUM_VEC_PER_HEAD, uint32_t bdx, uint32_t DEAL_EACH_TIME, uint32_t HEAD_DIM_QK, uint32_t num_tile, typename T, typename CacheT>
__device__ __forceinline__ void compute_sv(const float *s,
                                           const CacheT *base_v_smem,
                                           const uint32_t stage_idx,
                                           const uint32_t iter_base,
                                           const uint32_t iter_bound,
                                           const uint32_t tidx,
                                           softmax_state_ts<vec_size, T, num_tile>& st) {
  AlignedVector<T, vec_size> v_vec;
  const CacheT* stage_smem = base_v_smem + stage_idx * DEAL_EACH_TIME * HEAD_DIM_QK;
#pragma unroll
  for (uint32_t row = 0; row < DEAL_EACH_TIME; ++row) {
    // Rows at or beyond iter_bound were never loaded with real data.
    if (!(iter_base + row < iter_bound)) {
      break;
    }
    const CacheT* row_smem = stage_smem + row * HEAD_DIM_QK;
    for (uint32_t vid = tidx; vid < NUM_VEC_PER_HEAD; vid += bdx) {
      Load<T, vec_size>(row_smem + vid * vec_size, &v_vec);
      AlignedVector<T, vec_size>& acc = st.o[vid / bdx];
      // Weight converted to the element type once per vector.
      const T weight = static_cast<T>(s[row]);
#pragma unroll
      for (uint32_t lane = 0; lane < vec_size; ++lane) {
        acc[lane] += weight * v_vec[lane];
      }
    }
  }
}
 |