diff --git a/custom_ops/gpu_ops/append_attn/append_attention_func.cuh b/custom_ops/gpu_ops/append_attn/append_attention_func.cuh index c4afa3d1c..f09dbb99d 100644 --- a/custom_ops/gpu_ops/append_attn/append_attention_func.cuh +++ b/custom_ops/gpu_ops/append_attn/append_attention_func.cuh @@ -142,7 +142,6 @@ __device__ __forceinline__ void load_q_global_smem_multi_warps( const uint32_t tx_offset = tx / 8; #pragma unroll for (uint32_t fx = 0; fx < num_frags_x; ++fx) { - const uint32_t base_offset = q_idx_base + fx * 16 + tx_offset; #pragma unroll const int j = ty; @@ -151,8 +150,7 @@ __device__ __forceinline__ void load_q_global_smem_multi_warps( const uint32_t h_offset = offset_now % group_size; T* q_ptr = q_ptr_base + n_offset * qo_n_stride + h_offset * qo_h_stride; #pragma unroll - for (uint32_t fyo = 0; fyo < num_frags_y / 4; - ++fyo) { + for (uint32_t fyo = 0; fyo < num_frags_y / 4; ++fyo) { q_smem->load_128b_async( q_smem_offset_w, q_ptr, n_offset < qo_upper_bound); q_smem_offset_w = @@ -171,7 +169,7 @@ template __device__ __forceinline__ void load_q_global_smem( - T* q_ptr_base, + const T* q_ptr_base, smem_t* q_smem, uint32_t q_idx_base, const uint32_t qo_upper_bound, @@ -194,10 +192,10 @@ __device__ __forceinline__ void load_q_global_smem( const uint32_t offset_now = base_offset + j * 4; const uint32_t n_offset = offset_now / group_size; const uint32_t h_offset = offset_now % group_size; - T* q_ptr = q_ptr_base + n_offset * qo_n_stride + h_offset * qo_h_stride; + const T* q_ptr = + q_ptr_base + n_offset * qo_n_stride + h_offset * qo_h_stride; #pragma unroll - for (uint32_t fyo = 0; fyo < num_frags_y / 4; - ++fyo) { + for (uint32_t fyo = 0; fyo < num_frags_y / 4; ++fyo) { q_smem->load_128b_async( q_smem_offset_w, q_ptr, n_offset < qo_upper_bound); q_smem_offset_w = @@ -223,8 +221,7 @@ __device__ __forceinline__ void q_smem_inplace_multiply_sm_scale_multi_warps( constexpr uint32_t num_vecs_per_head = head_dim / num_elems_per_128b(); #pragma unroll - for (uint32_t i = 0; i < num_frags_x * 16 * head_dim / 1024; - ++i) { + for (uint32_t i = 0; i < num_frags_x * 16 * head_dim / 1024; ++i) { const int offset = i * 1024 + ty * 256 + tx * 8; Load(reinterpret_cast(q_smem->base) + offset, &tmp_vec); #pragma unroll @@ -289,11 +286,9 @@ __device__ __forceinline__ void produce_kv_blockwise( const uint32_t tx = threadIdx.x, ty = threadIdx.y; uint32_t kv_idx = kv_idx_base + ty * 4 + tx / 8; // kv_idx used to check #pragma unroll - for (uint32_t i = 0; i < NUM_WARP_KV * num_frags_z * 4 / num_warps; - ++i) { + for (uint32_t i = 0; i < NUM_WARP_KV * num_frags_z * 4 / num_warps; ++i) { #pragma unroll - for (uint32_t j = 0; j < num_frags_y / 4; - ++j) { + for (uint32_t j = 0; j < num_frags_y / 4; ++j) { smem.load_128b_async(*smem_offset, *gptr, kv_idx < kv_len); *smem_offset = smem.advance_offset_by_column<8>(*smem_offset, j); *gptr += 8 * num_elems_per_128b(); @@ -332,9 +327,7 @@ __device__ __forceinline__ void produce_v_blockwise_c8( block_size / num_elems_per_128b(); // 8 constexpr uint32_t NUM_WARP_KV = num_warps / NUM_WARP_Q; const uint32_t tx = threadIdx.x, ty = threadIdx.y; - uint32_t kv_idx = - kv_idx_base + - tx % 4 * num_elems_per_128b(); + uint32_t kv_idx = kv_idx_base + tx % 4 * num_elems_per_128b(); if constexpr (NUM_WARP_Q == 4) { int block_id = __ldg(&block_table_now[kv_idx / block_size]); if (block_id < 0) block_id = 0; @@ -343,8 +336,7 @@ __device__ __forceinline__ void produce_v_blockwise_c8( for (uint32_t i = 0; i < num_frags_y * 2 / num_warps; ++i) { // m (num_frags_y * 16 / 
(num_warps * 8)) #pragma unroll - for (uint32_t j = 0; j < num_frags_z / 4; - ++j) { + for (uint32_t j = 0; j < num_frags_z / 4; ++j) { smem.load_128b_async(*smem_offset, cache_v_now, true); *smem_offset = smem.advance_offset_by_column<4, num_vecs_per_blocksize>( *smem_offset, j); @@ -369,8 +361,7 @@ __device__ __forceinline__ void produce_v_blockwise_c8( for (uint32_t i = 0; i < num_frags_y * 2 / num_warps; ++i) { // m (num_frags_y * 16 / (num_warps * 8)) #pragma unroll - for (uint32_t j = 0; j < 2 * num_frags_z / 4; - ++j) { + for (uint32_t j = 0; j < 2 * num_frags_z / 4; ++j) { smem.load_128b_async(*smem_offset, cache_v_now, true); *smem_offset = smem.advance_offset_by_column<4, num_vecs_per_blocksize>( @@ -392,27 +383,28 @@ __device__ __forceinline__ void produce_v_blockwise_c8( } } -template +template __device__ __forceinline__ void produce_k_dynamic_scale( - T* k_smem_scale, - T* cache_k_reg, - const int* block_table_now, - const T* cache_k_scale, - const uint32_t kv_idx, - const uint32_t kv_num_heads, - const uint32_t kv_head_idx, - const uint32_t chunk_end -) { + T* k_smem_scale, + T* cache_k_reg, + const int* block_table_now, + const T* cache_k_scale, + const uint32_t kv_idx, + const uint32_t kv_num_heads, + const uint32_t kv_head_idx, + const uint32_t chunk_end) { const uint32_t tx = threadIdx.x, ty = threadIdx.y; if constexpr (NUM_WARP_Q == 4) { // 4 warps shared block_size const uint32_t tid = ty * 32 + tx; int block_id = __ldg(&block_table_now[kv_idx / block_size]); if (block_id < 0) block_id = 0; - const T* cache_k_scale_now = cache_k_scale + block_id * kv_num_heads * block_size + kv_head_idx * block_size; + const T* cache_k_scale_now = cache_k_scale + + block_id * kv_num_heads * block_size + + kv_head_idx * block_size; if (tid < block_size) { k_smem_scale[tid] = cache_k_scale_now[tid]; } @@ -427,10 +419,12 @@ __device__ __forceinline__ void produce_k_dynamic_scale( const uint32_t kv_idx_now = kv_idx + block_size * ty / 2; int block_id = __ldg(&block_table_now[kv_idx_now / block_size]); if (block_id < 0) block_id = 0; - const T* cache_k_scale_now = cache_k_scale + block_id * kv_num_heads * block_size + kv_head_idx * block_size; + const T* cache_k_scale_now = cache_k_scale + + block_id * kv_num_heads * block_size + + kv_head_idx * block_size; const int kv_idx_this_thread = kv_idx + ty * 32 + tx; if (kv_idx_this_thread < chunk_end) { - k_smem_scale[ty * 32 + tx] = cache_k_scale_now[(ty % 2) * 32 + tx]; + k_smem_scale[ty * 32 + tx] = cache_k_scale_now[(ty % 2) * 32 + tx]; } else { k_smem_scale[ty * 32 + tx] = 0; } @@ -443,20 +437,19 @@ __device__ __forceinline__ void produce_k_dynamic_scale( } } -template +template __device__ __forceinline__ void produce_v_dynamic_scale( - T* v_smem_scale, - T* cache_v_reg, - const int* block_table_now, - const T* cache_v_scale, - const uint32_t kv_idx, - const uint32_t kv_num_heads, - const uint32_t kv_head_idx, - const uint32_t chunk_end -) { + T* v_smem_scale, + T* cache_v_reg, + const int* block_table_now, + const T* cache_v_scale, + const uint32_t kv_idx, + const uint32_t kv_num_heads, + const uint32_t kv_head_idx, + const uint32_t chunk_end) { const uint32_t tx = threadIdx.x, ty = threadIdx.y; if constexpr (NUM_WARP_Q == 4) { @@ -464,7 +457,9 @@ __device__ __forceinline__ void produce_v_dynamic_scale( const uint32_t tid = ty * 32 + tx; int block_id = __ldg(&block_table_now[kv_idx / block_size]); if (block_id < 0) block_id = 0; - const T* cache_v_scale_now = cache_v_scale + block_id * kv_num_heads * block_size + kv_head_idx * block_size; + 
const T* cache_v_scale_now = cache_v_scale + + block_id * kv_num_heads * block_size + + kv_head_idx * block_size; if (tid < block_size) { v_smem_scale[tid] = cache_v_scale_now[tid]; } @@ -481,10 +476,12 @@ __device__ __forceinline__ void produce_v_dynamic_scale( const uint32_t kv_idx_now = kv_idx + block_size * ty / 2; int block_id = __ldg(&block_table_now[kv_idx_now / block_size]); if (block_id < 0) block_id = 0; - const T* cache_v_scale_now = cache_v_scale + block_id * kv_num_heads * block_size + kv_head_idx * block_size; + const T* cache_v_scale_now = cache_v_scale + + block_id * kv_num_heads * block_size + + kv_head_idx * block_size; const int kv_idx_this_thread = kv_idx + ty * 32 + tx; if (kv_idx_this_thread < chunk_end) { - v_smem_scale[ty * 32 + tx] = cache_v_scale_now[(ty % 2) * 32 + tx]; + v_smem_scale[ty * 32 + tx] = cache_v_scale_now[(ty % 2) * 32 + tx]; } else { v_smem_scale[ty * 32 + tx] = 0; } @@ -560,8 +557,7 @@ __device__ __forceinline__ void produce_k_blockwise_c8( for (uint32_t i = 0; i < 2 * num_frags_z * 4 / num_warps; ++i) { // m num_frags_z * 16 / (num_warps * 4) #pragma unroll - for (uint32_t j = 0; j < num_frags_y / 8; - ++j) { + for (uint32_t j = 0; j < num_frags_y / 8; ++j) { smem.load_128b_async(*smem_offset, cache_k_now, true); *smem_offset = smem.advance_offset_by_column<8, num_vecs_per_head>( *smem_offset, j); @@ -614,8 +610,7 @@ __device__ __forceinline__ void produce_v_blockwise_c4( #pragma unroll for (uint32_t i = 0; i < num_frags_y / num_warps; ++i) { // m #pragma unroll - for (uint32_t j = 0; j < num_frags_z / 4; - ++j) { + for (uint32_t j = 0; j < num_frags_z / 4; ++j) { smem.load_128b_async(*smem_offset, cache_v_now, true); *smem_offset = smem.advance_offset_by_column<2, num_vecs_per_blocksize>( *smem_offset, j); @@ -671,8 +666,7 @@ __device__ __forceinline__ void produce_k_blockwise_c4( for (uint32_t i = 0; i < num_frags_z * 2 / num_warps; ++i) { // m num_frags_z * 16 / (num_warps * 8) #pragma unroll - for (uint32_t j = 0; j < num_frags_y / 8; - ++j) { + for (uint32_t j = 0; j < num_frags_y / 8; ++j) { smem.load_128b_async(*smem_offset, cache_k_now, true); *smem_offset = smem.advance_offset_by_column<4, num_vecs_per_head>( *smem_offset, j); @@ -937,7 +931,7 @@ __device__ __forceinline__ void compute_qk_c8(smem_t* q_smem, uint32_t* q_smem_offset_r, smem_t* k_smem, uint32_t* k_smem_offset_r, - const T *cache_k_scale, + const T* cache_k_scale, float (*s_frag)[num_frags_z][8]) { constexpr uint32_t head_dim = num_frags_y * 16; constexpr uint32_t num_vecs_per_head_q = head_dim / num_elems_per_128b(); @@ -973,8 +967,8 @@ __device__ __forceinline__ void compute_qk_c8(smem_t* q_smem, #pragma unroll for (uint32_t fy = 0; fy < 2; ++fy) { T* b_frag_dq_T = reinterpret_cast(b_frag_dq); - convert_c8(b_frag_dq_T, b_frag[fy * 2]); - convert_c8(b_frag_dq_T + 4, b_frag[fy * 2 + 1]); + convert_c8(b_frag_dq_T, b_frag[fy * 2]); + convert_c8(b_frag_dq_T + 4, b_frag[fy * 2 + 1]); // scale zp if constexpr (!IsDynamicC8) { if constexpr (is_scale_channel_wise) { @@ -1036,7 +1030,7 @@ __device__ __forceinline__ void mask_s(const bool* attn_mask, const uint32_t chunk_end, const uint32_t attn_mask_len, float (*s_frag)[num_frags_z][8], - const int *mask_offset = nullptr, + const int* mask_offset = nullptr, const int sliding_window = 0) { const uint32_t tx = threadIdx.x; #pragma unroll @@ -1053,24 +1047,25 @@ __device__ __forceinline__ void mask_s(const bool* attn_mask, 8 * (reg_id / 4) + reg_id % 2; bool out_of_boundary; if (mask_offset) { - out_of_boundary = q_idx < qo_len ? 
(kv_idx >= mask_offset[q_idx * 2 + 1] || kv_idx < mask_offset[q_idx * 2]) : true; - } - else if (sliding_window > 0) - { - bool out_of_window = int(kv_idx) <= (int)kv_len + (int)q_idx - (int)qo_len - sliding_window; - out_of_boundary = - (causal - ? (kv_idx > kv_len + q_idx - qo_len || out_of_window || (kv_idx >= chunk_end)) - : kv_idx >= chunk_end); - } - else - { - out_of_boundary = - (causal - ? (kv_idx > kv_len + q_idx - qo_len || (kv_idx >= chunk_end)) - : kv_idx >= chunk_end); - if (attn_mask != nullptr && kv_idx > kv_len - qo_len && kv_idx < chunk_end && q_idx < attn_mask_len) { - const int32_t mask_idx = q_idx * attn_mask_len + kv_idx - kv_len + qo_len; + out_of_boundary = q_idx < qo_len + ? (kv_idx >= mask_offset[q_idx * 2 + 1] || + kv_idx < mask_offset[q_idx * 2]) + : true; + } else if (sliding_window > 0) { + bool out_of_window = int(kv_idx) <= (int)kv_len + (int)q_idx - + (int)qo_len - + sliding_window; + out_of_boundary = (causal ? (kv_idx > kv_len + q_idx - qo_len || + out_of_window || (kv_idx >= chunk_end)) + : kv_idx >= chunk_end); + } else { + out_of_boundary = (causal ? (kv_idx > kv_len + q_idx - qo_len || + (kv_idx >= chunk_end)) + : kv_idx >= chunk_end); + if (attn_mask != nullptr && kv_idx > kv_len - qo_len && + kv_idx < chunk_end && q_idx < attn_mask_len) { + const int32_t mask_idx = + q_idx * attn_mask_len + kv_idx - kv_len + qo_len; bool mask = attn_mask[mask_idx]; out_of_boundary |= mask; } @@ -1236,7 +1231,7 @@ __device__ __forceinline__ void compute_sfm_v_c8( float (*s_frag)[num_frags_z][8], float (*o_frag)[num_frags_y][8], float (*d)[2], - const T *cache_v_scale) { + const T* cache_v_scale) { constexpr uint32_t num_vecs_per_blocksize = block_size / num_elems_per_128b(); T s_frag_f16[num_frags_x][num_frags_z][8]; @@ -1268,8 +1263,8 @@ __device__ __forceinline__ void compute_sfm_v_c8( #pragma unroll for (uint32_t fz = 0; fz < 2; ++fz) { T* b_frag_dq_T = reinterpret_cast(b_frag_dq); - convert_c8(b_frag_dq_T, b_frag[fz * 2]); - convert_c8(b_frag_dq_T + 4, b_frag[fz * 2 + 1]); + convert_c8(b_frag_dq_T, b_frag[fz * 2]); + convert_c8(b_frag_dq_T + 4, b_frag[fz * 2 + 1]); // scale zp if constexpr (!IsDynamicC8) { if constexpr (is_scale_channel_wise) { @@ -1300,7 +1295,6 @@ __device__ __forceinline__ void compute_sfm_v_c8( o_frag[fx][fy], (uint32_t*)(s_frag_f16[fx][kz * 2 + fz]), b_frag_dq); - } } } @@ -1328,7 +1322,7 @@ __device__ __forceinline__ void compute_sfm_v_c8_iter_sq_bvec( float (*s_frag)[num_frags_z][8], float (*o_frag)[num_frags_y][8], float (*d)[2], - T *cache_v_scale) { + T* cache_v_scale) { constexpr uint32_t num_vecs_per_blocksize = block_size / num_elems_per_128b(); @@ -1362,8 +1356,8 @@ __device__ __forceinline__ void compute_sfm_v_c8_iter_sq_bvec( for (uint32_t fz = 0; fz < 2; ++fz) { // dequant b_frag -> b_frag_dq T* b_frag_dq_T = reinterpret_cast(b_frag_dq); - convert_c8(b_frag_dq_T, b_frag[fz * 2]); - convert_c8(b_frag_dq_T + 4, b_frag[fz * 2 + 1]); + convert_c8(b_frag_dq_T, b_frag[fz * 2]); + convert_c8(b_frag_dq_T + 4, b_frag[fz * 2 + 1]); // scale zp if constexpr (!IsDynamicC8) { if constexpr (is_scale_channel_wise) { @@ -1372,7 +1366,7 @@ __device__ __forceinline__ void compute_sfm_v_c8_iter_sq_bvec( b_frag_dq_T[b_i] *= cache_v_scale[b_i / 4 + fy * 2]; } } else { - #pragma unroll +#pragma unroll for (uint32_t b_i = 0; b_i < 8; ++b_i) { b_frag_dq_T[b_i] *= cache_v_scale[0]; } @@ -1431,8 +1425,7 @@ __device__ __forceinline__ void compute_sfm_v(smem_t* v_smem, } #pragma unroll - for (uint32_t fz = 0; fz < num_frags_z; - ++fz) { + for (uint32_t fz = 
0; fz < num_frags_z; ++fz) { #pragma unroll for (uint32_t fy = 0; fy < num_frags_y; ++fy) { uint32_t b_frag[4]; @@ -1611,10 +1604,9 @@ __device__ __forceinline__ void write_o_reg_gmem_kv_multi_warps( for (uint32_t fy = 0; fy < num_frags_y; ++fy) { uint32_t o_frag_f16[4]; vec_cast((T*)o_frag_f16, o_frag[fx][fy]); - uint32_t o_smem_offset_w = smem_t::get_permuted_offset< - num_vecs_per_head>( - fx * 16 + tx / 4, - fy * 2); + uint32_t o_smem_offset_w = + smem_t::get_permuted_offset(fx * 16 + tx / 4, + fy * 2); ((uint32_t*)(o_smem->base + o_smem_offset_w))[tx % 4] = o_frag_f16[0]; ((uint32_t*)(o_smem->base + o_smem_offset_w + 8 * num_vecs_per_head))[tx % 4] = o_frag_f16[1]; @@ -1627,8 +1619,8 @@ __device__ __forceinline__ void write_o_reg_gmem_kv_multi_warps( } __syncthreads(); - uint32_t o_smem_offset_w = smem_t::get_permuted_offset( - ty * 4 + tx / 8, tx % 8); + uint32_t o_smem_offset_w = + smem_t::get_permuted_offset(ty * 4 + tx / 8, tx % 8); o_idx_base += (tx / 8) / group_size; o_ptr_base += ((tx / 8) / group_size) * qo_n_stride + @@ -1642,8 +1634,7 @@ __device__ __forceinline__ void write_o_reg_gmem_kv_multi_warps( T* o_ptr = o_ptr_base + ((fx * 16 + j * 4) / group_size) * qo_n_stride + ((fx * 16 + j * 4) % group_size) * qo_h_stride; #pragma unroll - for (uint32_t fyo = 0; fyo < num_frags_y / 4; - ++fyo) { + for (uint32_t fyo = 0; fyo < num_frags_y / 4; ++fyo) { if (o_idx < qo_upper_bound) { // need write o_smem->store_128b(o_smem_offset_w, o_ptr); @@ -1658,7 +1649,6 @@ __device__ __forceinline__ void write_o_reg_gmem_kv_multi_warps( } } - template struct StoreFunc { __device__ __forceinline__ void operator()( @@ -1717,7 +1707,6 @@ struct StoreFunc { } }; - template struct StoreFunc { __device__ __forceinline__ void operator()( @@ -1770,10 +1759,9 @@ __device__ __forceinline__ void write_o_reg_gmem_multi_warps_shift_smooth_quant( for (uint32_t fy = 0; fy < num_frags_y; ++fy) { uint32_t o_frag_f16[4]; vec_cast((T*)o_frag_f16, o_frag[fx][fy]); - uint32_t o_smem_offset_w = smem_t::get_permuted_offset< - num_vecs_per_head>( - fx * 16 + tx / 4, - fy * 2); + uint32_t o_smem_offset_w = + smem_t::get_permuted_offset(fx * 16 + tx / 4, + fy * 2); ((uint32_t*)(o_smem->base + o_smem_offset_w))[tx % 4] = o_frag_f16[0]; ((uint32_t*)(o_smem->base + o_smem_offset_w + 8 * num_vecs_per_head))[tx % 4] = o_frag_f16[1]; @@ -1786,8 +1774,8 @@ __device__ __forceinline__ void write_o_reg_gmem_multi_warps_shift_smooth_quant( } __syncthreads(); - uint32_t o_smem_offset_w = smem_t::get_permuted_offset( - ty * 4 + tx / 8, tx % 8); + uint32_t o_smem_offset_w = + smem_t::get_permuted_offset(ty * 4 + tx / 8, tx % 8); const uint32_t tx_offset = tx / 8; #pragma unroll @@ -1804,8 +1792,7 @@ __device__ __forceinline__ void write_o_reg_gmem_multi_warps_shift_smooth_quant( uint32_t shift_smooth_offset = (q_head_idx_base + h_offset) * head_dim + tx % 8 * num_elems_per_128b(); #pragma unroll - for (uint32_t fyo = 0; fyo < num_frags_y / 4; - ++fyo) { + for (uint32_t fyo = 0; fyo < num_frags_y / 4; ++fyo) { if (n_offset < qo_upper_bound) { if constexpr (!partition_kv) { Load( @@ -1881,10 +1868,8 @@ __device__ __forceinline__ void write_o_reg_gmem_shift_smooth_quant( for (uint32_t fy = 0; fy < num_frags_y; ++fy) { uint32_t o_frag_f16[4]; vec_cast((T*)o_frag_f16, o_frag[fx][fy]); - uint32_t o_smem_offset_w = smem_t::get_permuted_offset< - num_vecs_per_head>( - (ty * num_frags_x + fx) * 16 + tx / 4, - fy * 2); + uint32_t o_smem_offset_w = smem_t::get_permuted_offset( + (ty * num_frags_x + fx) * 16 + tx / 4, fy * 2); 
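// Assumption: get_permuted_offset applies a FlashInfer-style swizzle of the
// 128-bit column index against the row, staggering columns across rows so the
// 32-bit accumulator stores here and the later 128-bit store_128b reads avoid
// shared-memory bank conflicts; (ty * num_frags_x + fx) * 16 + tx / 4 is this
// thread's row inside the warp's 16-row output fragment, and fy * 2 selects
// that fragment's two 128-bit columns along head_dim.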
((uint32_t*)(o_smem->base + o_smem_offset_w))[tx % 4] = o_frag_f16[0]; ((uint32_t*)(o_smem->base + o_smem_offset_w + 8 * num_vecs_per_head))[tx % 4] = o_frag_f16[1]; @@ -1897,8 +1882,7 @@ __device__ __forceinline__ void write_o_reg_gmem_shift_smooth_quant( __syncthreads(); uint32_t o_smem_offset_w = smem_t::get_permuted_offset( - ty * num_frags_x * 16 + tx / 8, - tx % 8); + ty * num_frags_x * 16 + tx / 8, tx % 8); const uint32_t tx_offset = tx / 8; #pragma unroll @@ -1914,13 +1898,12 @@ __device__ __forceinline__ void write_o_reg_gmem_shift_smooth_quant( uint32_t shift_smooth_offset = (q_head_idx_base + h_offset) * head_dim + tx % 8 * num_elems_per_128b(); #pragma unroll - for (uint32_t fyo = 0; fyo < num_frags_y / 4; - ++fyo) { + for (uint32_t fyo = 0; fyo < num_frags_y / 4; ++fyo) { if (n_offset < qo_upper_bound) { if (!partition_kv) { Load( - reinterpret_cast(o_smem->base + o_smem_offset_w), - &ori_out_vec); + reinterpret_cast(o_smem->base + o_smem_offset_w), + &ori_out_vec); if (in_scale > 0.0) { if (shift_bias) { Load(shift_bias + shift_smooth_offset, @@ -1929,16 +1912,16 @@ __device__ __forceinline__ void write_o_reg_gmem_shift_smooth_quant( &smooth_weight_vec); } } - #pragma unroll +#pragma unroll for (int i = 0; i < VEC_SIZE; ++i) { StoreFunc()(ori_out_vec, - shift_bias_vec, - smooth_weight_vec, - out_vec, - quant_max_bound, - quant_min_bound, - in_scale, - i); + shift_bias_vec, + smooth_weight_vec, + out_vec, + quant_max_bound, + quant_min_bound, + in_scale, + i); } Store(out_vec, o_ptr); } else { @@ -1979,10 +1962,8 @@ __device__ __forceinline__ void write_o_reg_gmem( for (uint32_t fy = 0; fy < num_frags_y; ++fy) { uint32_t o_frag_f16[4]; vec_cast((T*)o_frag_f16, o_frag[fx][fy]); - uint32_t o_smem_offset_w = smem_t::get_permuted_offset< - num_vecs_per_head>( - (ty * num_frags_x + fx) * 16 + tx / 4, - fy * 2); + uint32_t o_smem_offset_w = smem_t::get_permuted_offset( + (ty * num_frags_x + fx) * 16 + tx / 4, fy * 2); ((uint32_t*)(o_smem->base + o_smem_offset_w))[tx % 4] = o_frag_f16[0]; ((uint32_t*)(o_smem->base + o_smem_offset_w + 8 * num_vecs_per_head))[tx % 4] = o_frag_f16[1]; @@ -1995,8 +1976,7 @@ __device__ __forceinline__ void write_o_reg_gmem( __syncthreads(); uint32_t o_smem_offset_w = smem_t::get_permuted_offset( - ty * num_frags_x * 16 + tx / 8, - tx % 8); + ty * num_frags_x * 16 + tx / 8, tx % 8); o_idx_base += (tx / 8) / group_size; o_ptr_base += ((tx / 8) / group_size) * qo_n_stride + @@ -2009,8 +1989,7 @@ __device__ __forceinline__ void write_o_reg_gmem( T* o_ptr = o_ptr_base + ((fx * 16 + j * 4) / group_size) * qo_n_stride + ((fx * 16 + j * 4) % group_size) * qo_h_stride; #pragma unroll - for (uint32_t fyo = 0; fyo < num_frags_y / 4; - ++fyo) { + for (uint32_t fyo = 0; fyo < num_frags_y / 4; ++fyo) { if (o_idx < qo_upper_bound) { o_smem->store_128b(o_smem_offset_w, o_ptr); } @@ -2125,7 +2104,6 @@ __global__ void merge_multi_chunks_kernel( &out[(qid * num_heads + hid) * head_dim + vid * vec_size]); } - template __device__ __forceinline__ void merge_block_res(float (*o_frag)[num_frags_y][8], float* md_smem, @@ -2307,18 +2285,18 @@ template __global__ void merge_multi_chunks_decoder_kernel( - const T *__restrict__ multi_out, // [token_num, num_chunks, num_heads, + const T* __restrict__ multi_out, // [token_num, num_chunks, num_heads, // head_dim] - const float *__restrict__ multi_m, // [token_num, num_chunks, num_heads] - const float *__restrict__ multi_d, // [token_num, num_chunks, num_heads] - const int *__restrict__ seq_lens_q, - const int *__restrict__ seq_lens_kv, 
- const int *__restrict__ seq_lens_encoder, - const int *__restrict__ cu_seqlens_q, - const T *__restrict__ shift_bias, // [q_num_heads * HEAD_DIM] - const T *__restrict__ smooth_weight, // [q_num_heads * HEAD_DIM] - const T *__restrict__ sinks, // [q_num_heads] - OutT *__restrict__ out, + const float* __restrict__ multi_m, // [token_num, num_chunks, num_heads] + const float* __restrict__ multi_d, // [token_num, num_chunks, num_heads] + const int* __restrict__ seq_lens_q, + const int* __restrict__ seq_lens_kv, + const int* __restrict__ seq_lens_encoder, + const int* __restrict__ cu_seqlens_q, + const T* __restrict__ shift_bias, // [q_num_heads * HEAD_DIM] + const T* __restrict__ smooth_weight, // [q_num_heads * HEAD_DIM] + const T* __restrict__ sinks, // [q_num_heads] + OutT* __restrict__ out, const float quant_max_bound, const float quant_min_bound, const float in_scale, @@ -2419,8 +2397,14 @@ __global__ void merge_multi_chunks_decoder_kernel( } #pragma unroll for (int i = 0; i < vec_size; ++i) { - StoreFunc()( - st.o, shift_bias_vec, smooth_weight_vec, out_vec, quant_max_bound, quant_min_bound, in_scale, i); + StoreFunc()(st.o, + shift_bias_vec, + smooth_weight_vec, + out_vec, + quant_max_bound, + quant_min_bound, + in_scale, + i); } Store( out_vec, @@ -2435,19 +2419,19 @@ template __global__ void merge_multi_chunks_v2_kernel( - const T *__restrict__ multi_out, // [token_num, num_chunks, num_heads, + const T* __restrict__ multi_out, // [token_num, num_chunks, num_heads, // head_dim] - const float *__restrict__ multi_m, // [token_num, num_chunks, num_heads] - const float *__restrict__ multi_d, // [token_num, num_chunks, num_heads] - const int *__restrict__ seq_lens_q, - const int *__restrict__ seq_lens_kv, - const int *__restrict__ seq_lens_encoder, - const int *__restrict__ batch_id_per_token, - const int *__restrict__ cu_seqlens_q, - const T *__restrict__ shift_bias, // [q_num_heads * HEAD_DIM] - const T *__restrict__ smooth_weight, // [q_num_heads * HEAD_DIM] - const T *__restrict__ sinks, // [q_num_heads] - OutT *__restrict__ out, + const float* __restrict__ multi_m, // [token_num, num_chunks, num_heads] + const float* __restrict__ multi_d, // [token_num, num_chunks, num_heads] + const int* __restrict__ seq_lens_q, + const int* __restrict__ seq_lens_kv, + const int* __restrict__ seq_lens_encoder, + const int* __restrict__ batch_id_per_token, + const int* __restrict__ cu_seqlens_q, + const T* __restrict__ shift_bias, // [q_num_heads * HEAD_DIM] + const T* __restrict__ smooth_weight, // [q_num_heads * HEAD_DIM] + const T* __restrict__ sinks, // [q_num_heads] + OutT* __restrict__ out, const float quant_max_bound, const float quant_min_bound, const float in_scale, @@ -2464,7 +2448,7 @@ __global__ void merge_multi_chunks_v2_kernel( __shared__ float md_smem[bdy * 2]; for (int qid = blockIdx.x; qid < token_num; qid += gridDim.x) { const uint32_t bid = batch_id_per_token[qid]; - if(bid == -1){ + if (bid == -1) { continue; } const uint32_t local_seq_id = qid - cu_seqlens_q[bid]; @@ -2486,7 +2470,7 @@ __global__ void merge_multi_chunks_v2_kernel( const int num_chunks_this_seq = div_up(seq_len_kv, chunk_size); if (num_chunks_this_seq <= 1) { continue; - }else if (!ENABLE_PREFILL){ + } else if (!ENABLE_PREFILL) { continue; } @@ -2496,12 +2480,12 @@ __global__ void merge_multi_chunks_v2_kernel( if constexpr (std::is_same::value) { #pragma unroll for (int i = 0; i < vec_size / 2; ++i) { - *((half2 *)(&res_vec) + i) = make_half2(0, 0); + *((half2*)(&res_vec) + i) = make_half2(0, 0); } } else { 
#pragma unroll for (int i = 0; i < vec_size / 2; ++i) { - *((nv_bfloat162 *)(&res_vec) + i) = make_bfloat162(0, 0); + *((nv_bfloat162*)(&res_vec) + i) = make_bfloat162(0, 0); } } float m; @@ -2581,10 +2565,17 @@ __global__ void merge_multi_chunks_v2_kernel( Load(smooth_weight + shift_smooth_offset, &smooth_weight_vec); } + #pragma unroll for (int i = 0; i < vec_size; ++i) { - StoreFunc()( - st.o, shift_bias_vec, smooth_weight_vec, out_vec, quant_max_bound, quant_min_bound, in_scale, i); + StoreFunc()(st.o, + shift_bias_vec, + smooth_weight_vec, + out_vec, + quant_max_bound, + quant_min_bound, + in_scale, + i); } Store( out_vec, &out[(qid * num_heads + hid) * head_dim + vid * vec_size]); diff --git a/custom_ops/gpu_ops/append_attn/get_block_shape_and_split_kv_block.cu b/custom_ops/gpu_ops/append_attn/get_block_shape_and_split_kv_block.cu index 4fc43e34f..4a42235f5 100644 --- a/custom_ops/gpu_ops/append_attn/get_block_shape_and_split_kv_block.cu +++ b/custom_ops/gpu_ops/append_attn/get_block_shape_and_split_kv_block.cu @@ -20,12 +20,11 @@ #include "utils.cuh" template -__global__ void -GetMaxLenKernel(const int *seq_lens_decoder, const int *seq_lens_this_time, - const int *seq_lens_encoder, - const int *seq_lens_this_time_merged, - const int *seq_lens_encoder_merged, const int *seq_mapping, - const int *system_lens, int *max_lens, const int batch_size) { +__global__ void GetMaxLenKernel(const int *seq_lens_decoder, + const int *seq_lens_this_time, + const int *seq_lens_encoder, + int *max_lens, + const int batch_size) { const int tid = threadIdx.x; typedef cub::BlockReduce BlockReduce; @@ -36,9 +35,6 @@ GetMaxLenKernel(const int *seq_lens_decoder, const int *seq_lens_this_time, int max_len_decoder_this_thread = 0; int max_len_this_thread = 0; int max_just_dec_len_this_thread = 0; - int max_just_dec_merged_len_this_time_this_thread = 0; - int max_system_len_this_thread = 0; - int max_dec_len_without_system_this_thread = 0; int max_len_kv_this_thread = 0; for (int i = tid; i < batch_size; i += blockDim.x) { const int seq_len_this_time = seq_lens_this_time[i]; @@ -47,17 +43,17 @@ GetMaxLenKernel(const int *seq_lens_decoder, const int *seq_lens_this_time, max(seq_len_this_time, max_len_this_time_this_thread); max_len_encoder_this_thread = max(seq_lens_encoder[i], max_len_encoder_this_thread); - max_len_decoder_this_thread = max(seq_len_decoder, max_len_decoder_this_thread); - if (seq_len_this_time <= 0) - continue; - const int max_just_dec_len_now = seq_lens_encoder[i] > 0 ? 0 : seq_len_decoder; + max_len_decoder_this_thread = + max(seq_len_decoder, max_len_decoder_this_thread); + if (seq_len_this_time <= 0) continue; + const int max_just_dec_len_now = + seq_lens_encoder[i] > 0 ? 
0 : seq_len_decoder; max_len_this_thread = max(seq_len_decoder + seq_len_this_time, max_len_this_thread); max_just_dec_len_this_thread = max(max_just_dec_len_this_thread, max_just_dec_len_now); - if (seq_len_decoder == 0) - continue; + if (seq_len_decoder == 0) continue; max_len_kv_this_thread = max(seq_len_this_time + seq_len_decoder, max_len_kv_this_thread); } @@ -74,14 +70,6 @@ GetMaxLenKernel(const int *seq_lens_decoder, const int *seq_lens_this_time, BlockReduce(temp_storage).Reduce(max_len_this_thread, MaxOp()); int total_just_dec = BlockReduce(temp_storage) .Reduce(max_just_dec_len_this_thread, MaxOp()); - int total_just_dec_merged = - BlockReduce(temp_storage) - .Reduce(max_just_dec_merged_len_this_time_this_thread, MaxOp()); - int total_system_len = BlockReduce(temp_storage) - .Reduce(max_system_len_this_thread, MaxOp()); - int total_dec_len_without_system = - BlockReduce(temp_storage) - .Reduce(max_dec_len_without_system_this_thread, MaxOp()); int total_max_len_kv = BlockReduce(temp_storage).Reduce(max_len_kv_this_thread, MaxOp()); if (tid == 0) { @@ -90,24 +78,10 @@ GetMaxLenKernel(const int *seq_lens_decoder, const int *seq_lens_this_time, max_lens[2] = total_max_len_decoder; max_lens[3] = total; max_lens[4] = total_just_dec; - max_lens[5] = total_just_dec_merged; - max_lens[6] = total_system_len; - max_lens[7] = total_dec_len_without_system; max_lens[8] = total_max_len_kv; } } -void GetMaxLen(const paddle::Tensor &seq_lens_tensor, - const paddle::Tensor &seq_lens_this_time, - const paddle::Tensor &seq_lens_encoder, - paddle::Tensor &max_len_tensor, const int batch_size) { - constexpr int blockSize = 1024; - GetMaxLenKernel<<<1, blockSize, 0, seq_lens_encoder.stream()>>>( - seq_lens_tensor.data(), seq_lens_this_time.data(), - seq_lens_encoder.data(), nullptr, nullptr, nullptr, nullptr, - max_len_tensor.data(), batch_size); -} - template __global__ void search_chunk_size_for_mla( const int *__restrict__ seq_lens_q, @@ -154,11 +128,11 @@ __global__ void search_chunk_size_for_mla( uint32_t res_id = 0; uint32_t max_last_wave_block = 0; for (uint32_t i = 1; i < config_size; i++) { - uint32_t last_wave_block = gridx_shared[i] % sm_cout; - if (last_wave_block >= max_last_wave_block) { - res_id = i; - max_last_wave_block = last_wave_block; - } + uint32_t last_wave_block = gridx_shared[i] % sm_cout; + if (last_wave_block >= max_last_wave_block) { + res_id = i; + max_last_wave_block = last_wave_block; + } } *num_blocks_x = gridx_shared[res_id]; *res_chunk_size = block_size << res_id; @@ -185,11 +159,11 @@ __global__ void split_block_for_mla(const int *__restrict__ seq_lens_q, int loop_times; loop_times = cute::ceil_div(seq_len_decoder, chunk_size); if (seq_len_encoder > 0) { - loop_times = 0; + loop_times = 0; } for (uint32_t tile_id = 0; tile_id < loop_times; tile_id++) { - batch_ids[index] = bid; - tile_ids_per_batch[index++] = tile_id; + batch_ids[index] = bid; + tile_ids_per_batch[index++] = tile_id; } } } @@ -255,8 +229,10 @@ __global__ void split_kv_block(const int *__restrict__ seq_lens_decoder, const int *__restrict__ seq_lens_encoder, int *__restrict__ batch_ids, int *__restrict__ tile_ids_per_batch, - int *__restrict__ num_blocks_x, const int bsz, - const int pad_len, const int num_row_per_block) { + int *__restrict__ num_blocks_x, + const int bsz, + const int pad_len, + const int num_row_per_block) { if (threadIdx.x == 0) { int gridx = 0; int index = 0; @@ -281,31 +257,39 @@ void GetBlockShapeAndSplitKVBlock( const paddle::Tensor &seq_lens_encoder, const paddle::Tensor 
&seq_lens_decoder, const paddle::Tensor &seq_lens_this_time, - paddle::Tensor &decoder_batch_ids, // Inplace - paddle::Tensor &decoder_tile_ids_per_batch, // Inplace - paddle::Tensor &decoder_num_blocks_cpu, // Inplace, Pinned Memory - paddle::Tensor &decoder_num_blocks_device, // Inplace - paddle::Tensor &decoder_chunk_size_device, // Inplace - paddle::Tensor &max_len_tensor_cpu, // Inplace, CPU - paddle::Tensor &encoder_batch_ids, // Inplace - paddle::Tensor &encoder_tile_ids_per_batch, // Inplace - paddle::Tensor &encoder_num_blocks_x_cpu, // Inplace, CPU - paddle::Tensor &kv_batch_ids, // Inplace - paddle::Tensor &kv_tile_ids_per_batch, // Inplace - paddle::Tensor &kv_num_blocks_x_cpu, // Inplace, CPU + paddle::Tensor &decoder_batch_ids, // Inplace + paddle::Tensor &decoder_tile_ids_per_batch, // Inplace + paddle::Tensor &decoder_num_blocks_cpu, // Inplace, Pinned Memory + paddle::Tensor &decoder_num_blocks_device, // Inplace + paddle::Tensor &decoder_chunk_size_device, // Inplace + paddle::Tensor &max_len_tensor_cpu, // Inplace, CPU + paddle::Tensor &encoder_batch_ids, // Inplace + paddle::Tensor &encoder_tile_ids_per_batch, // Inplace + paddle::Tensor &encoder_num_blocks_x_cpu, // Inplace, CPU + paddle::Tensor &kv_batch_ids, // Inplace + paddle::Tensor &kv_tile_ids_per_batch, // Inplace + paddle::Tensor &kv_num_blocks_x_cpu, // Inplace, CPU const int encoder_block_shape_q, const int decoder_block_shape_q, const int group_size, const int block_size, - const int decoder_step_token_num) -{ + const int decoder_step_token_num) { auto stream = seq_lens_encoder.stream(); int bsz = seq_lens_this_time.shape()[0]; - paddle::Tensor max_len_tensor_gpu = GetEmptyTensor({max_len_tensor_cpu.shape()[0]}, paddle::DataType::INT32, seq_lens_this_time.place()); - GetMaxLen(seq_lens_decoder, seq_lens_this_time, seq_lens_encoder, - max_len_tensor_gpu, bsz); - max_len_tensor_cpu.copy_(max_len_tensor_gpu, max_len_tensor_cpu.place(), false); + paddle::Tensor max_len_tensor_gpu = + GetEmptyTensor({max_len_tensor_cpu.shape()[0]}, + paddle::DataType::INT32, + seq_lens_this_time.place()); + + GetMaxLenKernel<1024><<<1, 1024, 0, stream>>>(seq_lens_decoder.data(), + seq_lens_this_time.data(), + seq_lens_encoder.data(), + max_len_tensor_gpu.data(), + bsz); + + max_len_tensor_cpu.copy_( + max_len_tensor_gpu, max_len_tensor_cpu.place(), false); auto max_len_cpu_ptr = max_len_tensor_cpu.data(); int max_len_this_time = max_len_cpu_ptr[0]; @@ -320,7 +304,6 @@ void GetBlockShapeAndSplitKVBlock( // decoder if (max_dec_len_this_time > 0) { - const bool mla_backend = checkAttentionBackend(); if (mla_backend && group_size <= 64) { const int set_chunk_size = get_mla_dec_chunk_size(bsz); @@ -356,8 +339,9 @@ void GetBlockShapeAndSplitKVBlock( const int chunk_size = decoder_chunk_size_cpu.data()[0]; // NOTE: (changwenbin) When using auto_chunk, - // decode_max_tile_size must take into account the maximum case, where * 1024 can cover 128K. - // const uint32_t decoder_batch_shape = seq_lens_decoder.dims()[0] * 1024; + // decode_max_tile_size must take into account the maximum case, where * + // 1024 can cover 128K. 
const uint32_t decoder_batch_shape = + // seq_lens_decoder.dims()[0] * 1024; const uint32_t decoder_max_tile_size_per_bs_q = div_up((decoder_step_token_num * group_size), decoder_block_shape_q); @@ -375,7 +359,6 @@ void GetBlockShapeAndSplitKVBlock( decoder_batch_shape * sizeof(int32_t), stream)); - split_block_for_mla<<<1, 32, 0, stream>>>( seq_lens_this_time.data(), seq_lens_encoder.data(), @@ -419,49 +402,72 @@ void GetBlockShapeAndSplitKVBlock( decoder_num_blocks_cpu.copy_( decoder_num_blocks_device, decoder_num_blocks_cpu.place(), false); PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync( - decoder_chunk_size_device.data(), 64, sizeof(int32_t), stream)); + decoder_chunk_size_device.data(), 64, sizeof(int32_t), stream)); } } else { - PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync( - decoder_chunk_size_device.data(), 64, sizeof(int32_t), stream)); - PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync( - decoder_num_blocks_device.data(), 0, sizeof(int32_t), stream)); - decoder_num_blocks_cpu.copy_( - decoder_num_blocks_device, decoder_num_blocks_cpu.place(), false); + PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync( + decoder_chunk_size_device.data(), 64, sizeof(int32_t), stream)); + PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync( + decoder_num_blocks_device.data(), 0, sizeof(int32_t), stream)); + decoder_num_blocks_cpu.copy_( + decoder_num_blocks_device, decoder_num_blocks_cpu.place(), false); } // encoder if (max_enc_len_this_time > 0) { - const uint32_t max_tile_size_per_bs_kv = div_up(max_enc_dec_len_this_time, block_size); + const uint32_t max_tile_size_per_bs_kv = + div_up(max_enc_dec_len_this_time, block_size); const uint32_t kv_batch_shape = bsz * max_tile_size_per_bs_kv; - PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(kv_batch_ids.data(), 0, kv_batch_shape * sizeof(int32_t), stream)); - PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(kv_tile_ids_per_batch.data(), 0, kv_batch_shape * sizeof(int32_t), stream)); + PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync( + kv_batch_ids.data(), 0, kv_batch_shape * sizeof(int32_t), stream)); + PADDLE_ENFORCE_GPU_SUCCESS( + cudaMemsetAsync(kv_tile_ids_per_batch.data(), + 0, + kv_batch_shape * sizeof(int32_t), + stream)); auto kv_num_blocks_x = GetEmptyTensor({1}, paddle::DataType::INT32, seq_lens_encoder.place()); split_kv_block<<<1, 32, 0, seq_lens_encoder.stream()>>>( seq_lens_decoder.data(), // sequence_lengths->data(), - seq_lens_encoder.data(), kv_batch_ids.data(), - kv_tile_ids_per_batch.data(), kv_num_blocks_x.data(), bsz, - block_size, block_size); + seq_lens_encoder.data(), + kv_batch_ids.data(), + kv_tile_ids_per_batch.data(), + kv_num_blocks_x.data(), + bsz, + block_size, + block_size); - kv_num_blocks_x_cpu.copy_(kv_num_blocks_x, kv_num_blocks_x_cpu.place(), false); + kv_num_blocks_x_cpu.copy_( + kv_num_blocks_x, kv_num_blocks_x_cpu.place(), false); // Clear buffer - const uint32_t encoder_max_tile_size_per_bs_q = div_up((max_enc_dec_len_this_time * group_size), encoder_block_shape_q); + const uint32_t encoder_max_tile_size_per_bs_q = + div_up((max_enc_dec_len_this_time * group_size), encoder_block_shape_q); const uint32_t encoder_batch_shape = bsz * encoder_max_tile_size_per_bs_q; - PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(encoder_batch_ids.data(), 0, encoder_batch_shape * sizeof(int32_t), stream)); - PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(encoder_tile_ids_per_batch.data(), 0, encoder_batch_shape * sizeof(int32_t), stream)); + PADDLE_ENFORCE_GPU_SUCCESS( + cudaMemsetAsync(encoder_batch_ids.data(), + 0, + encoder_batch_shape * sizeof(int32_t), + stream)); + 
PADDLE_ENFORCE_GPU_SUCCESS( + cudaMemsetAsync(encoder_tile_ids_per_batch.data(), + 0, + encoder_batch_shape * sizeof(int32_t), + stream)); auto encoder_num_blocks_x = GetEmptyTensor({1}, paddle::DataType::INT32, seq_lens_encoder.place()); - split_q_block<<<1, 32, 0, stream>>>(seq_lens_encoder.data(), nullptr, + split_q_block<<<1, 32, 0, stream>>>(seq_lens_encoder.data(), + nullptr, encoder_batch_ids.data(), encoder_tile_ids_per_batch.data(), - encoder_num_blocks_x.data(), bsz, - encoder_block_shape_q, group_size); - encoder_num_blocks_x_cpu.copy_(encoder_num_blocks_x, encoder_num_blocks_x_cpu.place(), false); + encoder_num_blocks_x.data(), + bsz, + encoder_block_shape_q, + group_size); + encoder_num_blocks_x_cpu.copy_( + encoder_num_blocks_x, encoder_num_blocks_x_cpu.place(), false); } - } std::vector> GetBlockShapeAndSplitKVBlockInferShape( @@ -472,8 +478,7 @@ std::vector> GetBlockShapeAndSplitKVBlockInferShape( const int decoder_block_shape_q, const int group_size, const int block_size, - const int decoder_step_token_num -) { + const int decoder_step_token_num) { return {}; } @@ -485,39 +490,36 @@ std::vector GetBlockShapeAndSplitKVBlockInferDtype( const int decoder_block_shape_q, const int group_size, const int block_size, - const int decoder_step_token_num -) { + const int decoder_step_token_num) { return {}; } PD_BUILD_STATIC_OP(get_block_shape_and_split_kv_block) .Inputs({ - "seq_lens_encoder", - "seq_lens_decoder", - "seq_lens_this_time", - "decoder_batch_ids", - "decoder_tile_ids_per_batch", - "decoder_num_blocks_cpu", - "decoder_num_blocks_device", - "decoder_chunk_size_device", - "max_len_tensor_cpu", - "encoder_batch_ids", - "encoder_tile_ids_per_batch", - "encoder_num_blocks_x_cpu", - "kv_batch_ids", - "kv_tile_ids_per_batch", - "kv_num_blocks_x_cpu", + "seq_lens_encoder", + "seq_lens_decoder", + "seq_lens_this_time", + "decoder_batch_ids", + "decoder_tile_ids_per_batch", + "decoder_num_blocks_cpu", + "decoder_num_blocks_device", + "decoder_chunk_size_device", + "max_len_tensor_cpu", + "encoder_batch_ids", + "encoder_tile_ids_per_batch", + "encoder_num_blocks_x_cpu", + "kv_batch_ids", + "kv_tile_ids_per_batch", + "kv_num_blocks_x_cpu", }) .Outputs({ }) - .Attrs({ - "encoder_block_shape_q: int", - "decoder_block_shape_q: int", - "group_size: int", - "block_size: int", - "decoder_step_token_num: int" - }) + .Attrs({"encoder_block_shape_q: int", + "decoder_block_shape_q: int", + "group_size: int", + "block_size: int", + "decoder_step_token_num: int"}) .SetKernelFn(PD_KERNEL(GetBlockShapeAndSplitKVBlock)) .SetInferShapeFn(PD_INFER_SHAPE(GetBlockShapeAndSplitKVBlockInferShape)) .SetInferDtypeFn(PD_INFER_DTYPE(GetBlockShapeAndSplitKVBlockInferDtype)); diff --git a/custom_ops/gpu_ops/append_attn/multiquery_attention_c16_impl.cuh b/custom_ops/gpu_ops/append_attn/multiquery_attention_c16_impl.cuh index 99cc613d8..90fd7079c 100644 --- a/custom_ops/gpu_ops/append_attn/multiquery_attention_c16_impl.cuh +++ b/custom_ops/gpu_ops/append_attn/multiquery_attention_c16_impl.cuh @@ -21,7 +21,6 @@ template __global__ void multi_query_append_attention_kernel( - T *__restrict__ q, // [token_num, (num_heads + 2* kv_num_head) * head_dim] - T *__restrict__ cache_k, // [max_block_num, num_heads, block_size, + const T *__restrict__ q, // [token_num, (num_heads + 2* kv_num_head) * // head_dim] - T *__restrict__ cache_v, + const T *__restrict__ cache_k, // [max_block_num, num_heads, block_size, + // head_dim] + const T *__restrict__ cache_v, const T *__restrict__ shift_bias, // [q_num_heads * HEAD_DIM] 
const T *__restrict__ smooth_weight, // [q_num_heads * HEAD_DIM] - const T *__restrict__ sinks, // [q_num_heads] + const T *__restrict__ sinks, // [q_num_heads] const int *__restrict__ seq_lens, const int *__restrict__ seq_lens_kv, const int *__restrict__ batch_ids, @@ -45,7 +45,6 @@ __global__ void multi_query_append_attention_kernel( const int *__restrict__ block_table, // [bsz, block_num_per_seq] const int *__restrict__ mask_offset, const int max_seq_len, - const int max_dec_len, const int max_block_num_per_seq, const float scale, const float quant_max_bound, @@ -72,12 +71,10 @@ __global__ void multi_query_append_attention_kernel( const uint32_t batch_id = batch_ids[btid]; const uint32_t tile_id = tile_ids_per_batch[btid]; const uint32_t num_rows_per_block = NUM_WARPS * num_frags_x * 16; - const int *block_table_now = nullptr; + const int *block_table_now = block_table + batch_id * max_block_num_per_seq; - block_table_now = block_table + batch_id * max_block_num_per_seq; - - //When cudagraph capture prefill, may launch more gridDim.x - if(btid >= static_cast(num_blocks_x_cpu)){ + // When cudagraph capture prefill, may launch more gridDim.x + if (btid >= static_cast(num_blocks_x_cpu)) { return; } @@ -85,8 +82,7 @@ __global__ void multi_query_append_attention_kernel( if (q_len <= 0) { return; } - const uint32_t q_end = - min(q_len, div_up((tile_id + 1) * num_rows_per_block, GROUP_SIZE)); + uint32_t kv_len = seq_lens_kv[batch_id]; if (ENABLE_PREFILL) { kv_len += q_len; @@ -111,6 +107,9 @@ __global__ void multi_query_append_attention_kernel( const uint32_t chunk_len = chunk_end - chunk_start; extern __shared__ uint8_t smem[]; + static_assert(num_frags_y * 16 == HEAD_DIM); + static_assert(num_frags_z * 16 == BLOCK_SIZE); + float s_frag[num_frags_x][num_frags_z][8]; float o_frag[num_frags_x][num_frags_y][8]; float m_frag[num_frags_x][2]; @@ -131,7 +130,7 @@ __global__ void multi_query_append_attention_kernel( const uint32_t o_offset = q_start_seq_id * q_n_stride + q_head_idx * HEAD_DIM + tid % 8 * num_elems_per_128b(); - T *q_base_ptr = q + q_offset; + const T *q_base_ptr = q + q_offset; T *o_base_ptr_T = nullptr; OutT *o_base_ptr_int8 = nullptr; if constexpr (partition_kv) { @@ -149,11 +148,16 @@ __global__ void multi_query_append_attention_kernel( } else { o_base_ptr_int8 = out + o_offset; } - const int *mask_offset_this_seq = mask_offset ? mask_offset + q_start_seq_id * 2 : nullptr; + const int *mask_offset_this_seq = + mask_offset ? mask_offset + q_start_seq_id * 2 : nullptr; smem_t qo_smem(smem); uint32_t q_smem_offset_r = smem_t::get_permuted_offset( wid * num_frags_x * 16 + tid % 16, tid / 16); // 16 * 16 + + const uint32_t q_end = + min(q_len, div_up((tile_id + 1) * num_rows_per_block, GROUP_SIZE)); + load_q_global_smem( q_base_ptr, &qo_smem, @@ -172,7 +176,6 @@ __global__ void multi_query_append_attention_kernel( v_smem(smem + (NUM_WARPS * num_frags_x + num_frags_z) * 16 * HEAD_DIM * sizeof(T)); - const uint32_t num_iterations = div_up( CAUSAL ? (min(chunk_len, @@ -183,12 +186,13 @@ __global__ void multi_query_append_attention_kernel( : chunk_len, num_frags_z * 16); const uint32_t mask_check_iteration = - (CAUSAL ? (min(chunk_len, + (CAUSAL ? (min(chunk_len, sub_if_greater_or_zero( kv_len - q_len + tile_id * num_rows_per_block / GROUP_SIZE, chunk_start))) - : mask_offset ? 0 : chunk_len) / + : mask_offset ? 
0 + : chunk_len) / (num_frags_z * 16); uint32_t k_smem_offset_r = smem_t::get_permuted_offset( 8 * (tid / 16) + tid % 8, (tid % 16) / 8); @@ -204,8 +208,8 @@ __global__ void multi_query_append_attention_kernel( const uint32_t const_offset = kv_head_idx * kv_h_stride + (wid * 4 + tid / 8) * kv_b_stride + tid % 8 * num_elems_per_128b(); - T *cache_k_now = cache_k + block_id * kv_n_stride + const_offset; - T *cache_v_now = cache_v + block_id * kv_n_stride + const_offset; + const T *cache_k_now = cache_k + block_id * kv_n_stride + const_offset; + const T *cache_v_now = cache_v + block_id * kv_n_stride + const_offset; produce_kv_blockwise(); __syncthreads(); - if constexpr (!partition_kv ) { + if constexpr (!partition_kv) { if (sinks) { float current_sinks[num_frags_x][2]; - #pragma unroll +#pragma unroll for (uint32_t fx = 0; fx < num_frags_x; ++fx) { - #pragma unroll +#pragma unroll for (uint32_t j = 0; j < 2; ++j) { - const uint32_t h_offset = (q_base_seq_id_this_block + fx * 16 + tid / 4 + 8 * j) % GROUP_SIZE; - current_sinks[fx][j] = static_cast(sinks[q_head_idx + h_offset]); + const uint32_t h_offset = + (q_base_seq_id_this_block + fx * 16 + tid / 4 + 8 * j) % + GROUP_SIZE; + current_sinks[fx][j] = + static_cast(sinks[q_head_idx + h_offset]); } } - normalize_d(o_frag, d_frag, m_frag, current_sinks); + normalize_d( + o_frag, d_frag, m_frag, current_sinks); } else { normalize_d(o_frag, d_frag); } @@ -375,7 +382,6 @@ __global__ void multi_query_append_attention_kernel( HEAD_DIM); } - if constexpr (partition_kv) { #pragma unroll for (uint32_t fx = 0; fx < num_frags_x; ++fx) { @@ -421,13 +427,13 @@ template __global__ void multi_query_append_attention_warp1_4_kernel( - T *__restrict__ q, // [token_num, (num_heads + 2* kv_num_head) * head_dim] + T *__restrict__ q, // [token_num, (num_heads + 2* kv_num_head) * head_dim] T *__restrict__ cache_k, // [max_block_num, num_heads, block_size, // head_dim] T *__restrict__ cache_v, const T *__restrict__ shift_bias, // [q_num_heads * HEAD_DIM] const T *__restrict__ smooth_weight, // [q_num_heads * HEAD_DIM] - const T *__restrict__ sinks, // [q_num_heads] + const T *__restrict__ sinks, // [q_num_heads] const int *__restrict__ seq_lens, const int *__restrict__ seq_lens_kv, const int *__restrict__ batch_ids, @@ -435,9 +441,8 @@ __global__ void multi_query_append_attention_warp1_4_kernel( const int *__restrict__ cu_seqlens_q, const int *__restrict__ block_table, // [bsz, block_num_per_seq] const int *__restrict__ mask_offset, - const bool *__restrict__ attn_mask, // [bsz, max_q, max_q] for tree-mask + const bool *__restrict__ attn_mask, // [bsz, max_q, max_q] for tree-mask const int max_seq_len, - const int max_dec_len, const int max_block_num_per_seq, const float scale, const float quant_max_bound, @@ -469,8 +474,8 @@ __global__ void multi_query_append_attention_warp1_4_kernel( const uint32_t num_rows_per_block = num_frags_x * 16; const int *block_table_now = block_table + batch_id * max_block_num_per_seq; - //When cudagraph capture prefill, may launch more gridDim.x - if(btid >= static_cast(num_blocks_x_cpu)){ + // When cudagraph capture prefill, may launch more gridDim.x + if (btid >= static_cast(num_blocks_x_cpu)) { return; } @@ -478,8 +483,7 @@ __global__ void multi_query_append_attention_warp1_4_kernel( if (q_len <= 0) { return; } - const uint32_t q_end = - min(q_len, div_up((tile_id + 1) * num_rows_per_block, GROUP_SIZE)); + uint32_t kv_len = seq_lens_kv[batch_id]; if (ENABLE_PREFILL) { kv_len += q_len; @@ -540,11 +544,16 @@ __global__ void 
multi_query_append_attention_warp1_4_kernel( tid % 8 * num_elems_per_128b(); } } - const int *mask_offset_this_seq = mask_offset ? mask_offset + q_start_seq_id * 2 : nullptr; + const int *mask_offset_this_seq = + mask_offset ? mask_offset + q_start_seq_id * 2 : nullptr; smem_t qo_smem(smem); uint32_t q_smem_offset_r = smem_t::get_permuted_offset( tid % 16, tid / 16); // 16 * 16 + + const uint32_t q_end = + min(q_len, div_up((tile_id + 1) * num_rows_per_block, GROUP_SIZE)); + load_q_global_smem_multi_warps( @@ -648,16 +656,18 @@ __global__ void multi_query_append_attention_warp1_4_kernel( NUM_WARPS, num_frags_x, num_frags_y, - num_frags_z>(attn_mask ? attn_mask + batch_id * attn_mask_len *attn_mask_len : nullptr, - q_base_seq_id_this_block, - kv_idx_base + wid * num_frags_z * 16, - q_len, - kv_len, - chunk_end, - attn_mask_len, - s_frag, - mask_offset_this_seq, - sliding_window); + num_frags_z>( + attn_mask ? attn_mask + batch_id * attn_mask_len * attn_mask_len + : nullptr, + q_base_seq_id_this_block, + kv_idx_base + wid * num_frags_z * 16, + q_len, + kv_len, + chunk_end, + attn_mask_len, + s_frag, + mask_offset_this_seq, + sliding_window); } // update m,d @@ -720,15 +730,19 @@ __global__ void multi_query_append_attention_warp1_4_kernel( if (num_chunks_this_seq <= 1) { if (sinks) { float current_sinks[num_frags_x][2]; - #pragma unroll +#pragma unroll for (uint32_t fx = 0; fx < num_frags_x; ++fx) { - #pragma unroll +#pragma unroll for (uint32_t j = 0; j < 2; ++j) { - const uint32_t h_offset = (q_base_seq_id_this_block + fx * 16 + tid / 4 + 8 * j) % GROUP_SIZE; - current_sinks[fx][j] = static_cast(sinks[q_head_idx + h_offset]); + const uint32_t h_offset = + (q_base_seq_id_this_block + fx * 16 + tid / 4 + 8 * j) % + GROUP_SIZE; + current_sinks[fx][j] = + static_cast(sinks[q_head_idx + h_offset]); } } - normalize_d(o_frag, d_frag, m_frag, current_sinks); + normalize_d( + o_frag, d_frag, m_frag, current_sinks); } else { normalize_d(o_frag, d_frag); } @@ -876,7 +890,6 @@ void MultiQueryAppendAttention( CAUSAL, num_warps, NUM_WARP_Q, - NUM_WARP_KV, HEAD_DIM, BLOCK_SIZE, num_frags_x, @@ -908,7 +921,6 @@ void MultiQueryAppendAttention( CAUSAL, num_warps, NUM_WARP_Q, - NUM_WARP_KV, HEAD_DIM, BLOCK_SIZE, num_frags_x, @@ -933,8 +945,8 @@ void MultiQueryAppendAttention( const_cast(smooth_weight.get().data())) : nullptr, sinks ? reinterpret_cast( - const_cast(sinks.get().data())) - : nullptr, + const_cast(sinks.get().data())) + : nullptr, seq_lens_q.data(), seq_lens_kv.data(), batch_ids.data(), @@ -943,7 +955,6 @@ void MultiQueryAppendAttention( block_table.data(), meta_data.mask_offset, max_seq_len, - max_dec_len, max_block_num_per_seq, scale, quant_max_bound, @@ -996,8 +1007,8 @@ void MultiQueryAppendAttention( const_cast(smooth_weight.get().data())) : nullptr, sinks ? reinterpret_cast( - const_cast(sinks.get().data())) - : nullptr, + const_cast(sinks.get().data())) + : nullptr, seq_lens_q.data(), seq_lens_kv.data(), batch_ids.data(), @@ -1006,7 +1017,6 @@ void MultiQueryAppendAttention( block_table.data(), meta_data.mask_offset, max_seq_len, - max_dec_len, max_block_num_per_seq, scale, quant_max_bound, @@ -1048,8 +1058,8 @@ void MultiQueryAppendAttention( smooth_weight.get().data())) : nullptr, sinks ? reinterpret_cast( - const_cast(sinks.get().data())) - : nullptr, + const_cast(sinks.get().data())) + : nullptr, reinterpret_cast(out->data()), quant_max_bound, quant_min_bound, @@ -1087,8 +1097,8 @@ void MultiQueryAppendAttention( smooth_weight.get().data())) : nullptr, sinks ? 
reinterpret_cast( - const_cast(sinks.get().data())) - : nullptr, + const_cast(sinks.get().data())) + : nullptr, reinterpret_cast(out->data()), quant_max_bound, quant_min_bound, @@ -1138,9 +1148,9 @@ void MultiQueryAppendAttention( uint32_t attn_mask_len; if (attn_mask) { - attn_mask_len = attn_mask.get().shape()[1]; + attn_mask_len = attn_mask.get().shape()[1]; } else { - attn_mask_len = -1; + attn_mask_len = -1; } const int num_chunks = div_up(max_seq_len, chunk_size); @@ -1179,8 +1189,8 @@ void MultiQueryAppendAttention( const_cast(smooth_weight.get().data())) : nullptr, sinks ? reinterpret_cast( - const_cast(sinks.get().data())) - : nullptr, + const_cast(sinks.get().data())) + : nullptr, seq_lens_q.data(), seq_lens_kv.data(), batch_ids.data(), @@ -1189,9 +1199,8 @@ void MultiQueryAppendAttention( block_table.data(), meta_data.mask_offset, attn_mask ? const_cast(attn_mask.get().data()) - : nullptr, + : nullptr, max_seq_len, - max_dec_len, max_block_num_per_seq, scale, quant_max_bound, @@ -1250,14 +1259,14 @@ void MultiQueryAppendAttention( reinterpret_cast(const_cast(cache_k.data())), reinterpret_cast(const_cast(cache_v.data())), shift_bias ? reinterpret_cast( - const_cast(shift_bias.get().data())) - : nullptr, + const_cast(shift_bias.get().data())) + : nullptr, smooth_weight ? reinterpret_cast( const_cast(smooth_weight.get().data())) : nullptr, sinks ? reinterpret_cast( - const_cast(sinks.get().data())) - : nullptr, + const_cast(sinks.get().data())) + : nullptr, seq_lens_q.data(), seq_lens_kv.data(), batch_ids.data(), @@ -1266,9 +1275,8 @@ void MultiQueryAppendAttention( block_table.data(), meta_data.mask_offset, attn_mask ? const_cast(attn_mask.get().data()) - : nullptr, + : nullptr, max_seq_len, - max_dec_len, max_block_num_per_seq, scale, quant_max_bound, @@ -1306,14 +1314,14 @@ void MultiQueryAppendAttention( seq_lens_encoder.data(), cu_seqlens_q.data(), shift_bias ? reinterpret_cast( - const_cast(shift_bias.get().data())) - : nullptr, + const_cast(shift_bias.get().data())) + : nullptr, smooth_weight ? reinterpret_cast(const_cast( smooth_weight.get().data())) : nullptr, sinks ? reinterpret_cast( - const_cast(sinks.get().data())) - : nullptr, + const_cast(sinks.get().data())) + : nullptr, reinterpret_cast(out->data()), quant_max_bound, quant_min_bound, @@ -1326,15 +1334,14 @@ void MultiQueryAppendAttention( } else { constexpr int blockx = HEAD_DIM / vec_size; constexpr int blocky = (128 + blockx - 1) / blockx; - dim3 grids_merge(min(sm_count * 4, token_num), - num_heads); + dim3 grids_merge(min(sm_count * 4, token_num), num_heads); dim3 blocks_merge(blockx, blocky); merge_multi_chunks_v2_kernel + vec_size, + blocky, + HEAD_DIM, + OUT_NV_TYPE, + ENABLE_PREFILL> <<>>( reinterpret_cast(tmp_workspace->ptr()), static_cast(tmp_m->ptr()), @@ -1345,14 +1352,14 @@ void MultiQueryAppendAttention( batch_id_per_token.data(), cu_seqlens_q.data(), shift_bias ? reinterpret_cast( - const_cast(shift_bias.get().data())) - : nullptr, + const_cast(shift_bias.get().data())) + : nullptr, smooth_weight ? reinterpret_cast(const_cast( smooth_weight.get().data())) : nullptr, sinks ? reinterpret_cast( - const_cast(sinks.get().data())) - : nullptr, + const_cast(sinks.get().data())) + : nullptr, reinterpret_cast(out->data()), quant_max_bound, quant_min_bound,
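For reference, the merge kernels reformatted above (merge_multi_chunks_decoder_kernel and merge_multi_chunks_v2_kernel) combine per-chunk partial attention results with the usual max-rescaled (log-sum-exp) accumulation. The following is a minimal host-side C++ sketch of that combine step, not code from this repository: PartialState and merge_chunks are illustrative names, and it assumes each chunk stores an un-normalized output accumulator alongside its running max m and denominator d (the kernels' exact per-chunk storage convention may differ).

#include <cmath>
#include <vector>

// One chunk's partial state for a single (query token, head):
// m = running max of the attention logits in that chunk,
// d = sum of exp(logit - m), o = un-normalized output accumulator.
struct PartialState {
  float m;
  float d;
  std::vector<float> o;  // length == head_dim
};

// Combine all chunks, rescaling previous results whenever a larger max
// appears. Assumes chunks is non-empty.
std::vector<float> merge_chunks(const std::vector<PartialState>& chunks) {
  const size_t head_dim = chunks.front().o.size();
  float m = -INFINITY, d = 0.f;
  std::vector<float> o(head_dim, 0.f);
  for (const PartialState& c : chunks) {
    const float m_new = std::max(m, c.m);
    const float scale_prev = std::exp(m - m_new);   // rescale accumulated state
    const float scale_cur = std::exp(c.m - m_new);  // rescale incoming chunk
    for (size_t i = 0; i < head_dim; ++i) {
      o[i] = o[i] * scale_prev + c.o[i] * scale_cur;
    }
    d = d * scale_prev + c.d * scale_cur;
    m = m_new;
  }
  for (size_t i = 0; i < head_dim; ++i) o[i] /= d;  // final softmax normalization
  return o;
}

In the kernels themselves this combine runs per 128-bit vector lane with shared-memory reductions over m and d across threads, and the optional per-head sink term introduced in this diff appears to be folded into the denominator before the final division (see the normalize_d overload that takes current_sinks).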