diff --git a/custom_ops/gpu_ops/append_attention.cu b/custom_ops/gpu_ops/append_attention.cu index f1f5a6177..6af601dad 100644 --- a/custom_ops/gpu_ops/append_attention.cu +++ b/custom_ops/gpu_ops/append_attention.cu @@ -317,7 +317,6 @@ void AppendAttentionKernel( qkv, // [token_num, num_heads, head_dim] seq_lens_decoder, seq_lens_encoder, - batch_id_per_token, cu_seqlens_q, block_tables, rotary_embs, @@ -344,7 +343,6 @@ void AppendAttentionKernel( qkv_out, // [token_num, num_heads, head_dim] seq_lens_decoder, seq_lens_encoder, - batch_id_per_token, cu_seqlens_q, block_tables, rotary_embs, diff --git a/custom_ops/gpu_ops/append_attn/decoder_write_cache_with_rope_impl.cuh b/custom_ops/gpu_ops/append_attn/decoder_write_cache_with_rope_impl.cuh index 75f9ebd8d..45c9d0a02 100644 --- a/custom_ops/gpu_ops/append_attn/decoder_write_cache_with_rope_impl.cuh +++ b/custom_ops/gpu_ops/append_attn/decoder_write_cache_with_rope_impl.cuh @@ -18,6 +18,53 @@ #include "mma_tensor_op.cuh" #include "utils.cuh" + +// Note(ZKK) +// This function is very easy! +// just make HeadDim data to be new HeadDim data! + +template +__device__ __forceinline__ void apply_rope( + const T* input, + const float* cos_emb, + const float* sin_emb, + T* output, + const int thread_id) { + + using LoadT = AlignedVector; + using LoadBiasT = AlignedVector; + using LoadOutScaleT = AlignedVector; + constexpr int HalfVecSize = VecSize / 2; + using LoadEmbT = AlignedVector; + + LoadT src_vec; + LoadBiasT out_vec; + LoadEmbT cos_emb_vec; + LoadEmbT sin_emb_vec; + +#pragma unroll + for (uint32_t head_bias = thread_id * VecSize; head_bias < HEAD_DIM; head_bias += NUM_THREADS * VecSize) { + Load(&input[head_bias], &src_vec); + const uint32_t emb_idx = head_bias / 2; + Load(&cos_emb[emb_idx], &cos_emb_vec); + Load(&sin_emb[emb_idx], &sin_emb_vec); +#pragma unroll + for (int i = 0; i < HalfVecSize; i++) { + + float input_left = static_cast(src_vec[2 * i]); + float input_right = static_cast(src_vec[2 * i + 1]); + + const float cos_tmp = cos_emb_vec[i]; + const float sin_tmp = sin_emb_vec[i]; + out_vec[2 * i] = + static_cast(input_left * cos_tmp - input_right * sin_tmp); + out_vec[2 * i + 1] = + static_cast(input_right * cos_tmp + input_left * sin_tmp); + } + Store(out_vec, &output[head_bias]); + } +} + template __global__ void append_decode_cache_T_rope_qk_norm_kernel( const T* __restrict__ quant_qkv, // [bsz, num_heads + 2 * kv_num_heads, @@ -28,7 +75,7 @@ __global__ void append_decode_cache_T_rope_qk_norm_kernel( // head_size // 2] T* __restrict__ qkv_out, const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq] - const int* __restrict__ batch_id_per_token, // [num_tokens] + const int* __restrict__ cu_seqlens_q, const int* __restrict__ seq_lens, // [bsz] const int* __restrict__ seq_lens_encoder, // [bsz] @@ -164,7 +211,7 @@ __global__ void append_decode_cache_T_rope_kernel( // head_size // 2] T* __restrict__ qkv_out, const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq] - const int* __restrict__ batch_id_per_token, // [num_tokens] + const int* __restrict__ cu_seqlens_q, const int* __restrict__ seq_lens, // [bsz] const int* __restrict__ seq_lens_encoder, // [bsz] @@ -270,7 +317,7 @@ __global__ void append_decode_cache_T_rope_kernel( // head_size // 2] T* __restrict__ qkv_out, const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq] - const int* __restrict__ batch_id_per_token, // [num_tokens] + const int* __restrict__ cu_seqlens_q, const int* __restrict__ seq_lens, // [bsz] const int* __restrict__ seq_lens_encoder, // [bsz] @@ -391,7 +438,6 @@ __global__ void append_decode_cache_T_neox_rope_kernel( // head_size // 2] T* __restrict__ qkv_out, const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq] - const int* __restrict__ batch_id_per_token, // [num_tokens] const int* __restrict__ cu_seqlens_q, const int* __restrict__ seq_lens, // [bsz] const int* __restrict__ seq_lens_encoder, // [bsz] @@ -505,7 +551,6 @@ __global__ void append_decode_cache_T_neox_rope_kernel( // head_size // 2] T* __restrict__ qkv_out, const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq] - const int* __restrict__ batch_id_per_token, // [num_tokens] const int* __restrict__ cu_seqlens_q, const int* __restrict__ seq_lens, // [bsz] const int* __restrict__ seq_lens_encoder, // [bsz] @@ -639,7 +684,6 @@ __global__ void append_decode_cache_int8_rope_kernel( // block_size, head_size // 2] T* __restrict__ qkv_out, const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq] - const int* __restrict__ batch_id_per_token, // [num_tokens] const int* __restrict__ cu_seqlens_q, const int* __restrict__ seq_lens, // [bsz] const int* __restrict__ seq_lens_encoder, // [bsz] @@ -677,44 +721,18 @@ __global__ void append_decode_cache_int8_rope_kernel( if (head_idx < num_heads) { // q - using LoadT = AlignedVector; - using LoadBiasT = AlignedVector; - using LoadOutScaleT = AlignedVector; - constexpr int HalfVecSize = VecSize / 2; - using LoadEmbT = AlignedVector; + const T* qkv_now = quant_qkv + start_token_idx * hidden_size + head_idx * HeadDim; + T* qkv_out_now = qkv_out + start_token_idx * hidden_size + head_idx * HeadDim; - LoadT src_vec; - LoadBiasT out_vec; - LoadEmbT cos_emb_vec; - LoadEmbT sin_emb_vec; - const T* qkv_now = quant_qkv + start_token_idx * hidden_size; - T* qkv_out_now = qkv_out + start_token_idx * hidden_size; -#pragma unroll - for (uint32_t head_bias = lane_id * VecSize; head_bias < HeadDim; - head_bias += 32 * VecSize) { - const int bias_idx = head_idx * HeadDim + head_bias; - Load(&qkv_now[bias_idx], &src_vec); + uint32_t emb_offset = write_seq_id * half_head_size; + emb_offset += rope_3d ? bid * max_seq_len * HeadDim : 0; + apply_rope( + qkv_now, + cos_emb + emb_offset, + sin_emb + emb_offset, + qkv_out_now, + lane_id); - // q rope - const uint32_t emb_idx = write_seq_id * half_head_size + head_bias / 2; - uint32_t new_emb_idx = rope_3d ? emb_idx + bid * max_seq_len * HeadDim : emb_idx; - Load(&cos_emb[new_emb_idx], &cos_emb_vec); - Load(&sin_emb[new_emb_idx], &sin_emb_vec); -#pragma unroll - for (int i = 0; i < HalfVecSize; i++) { - // dequant + add_bias + rope - float input_left = static_cast(src_vec[2 * i]); - float input_right = static_cast(src_vec[2 * i + 1]); - - const float cos_tmp = cos_emb_vec[i]; - const float sin_tmp = sin_emb_vec[i]; - out_vec[2 * i] = - static_cast(input_left * cos_tmp - input_right * sin_tmp); - out_vec[2 * i + 1] = - static_cast(input_right * cos_tmp + input_left * sin_tmp); - } - Store(out_vec, &qkv_out_now[bias_idx]); - } } else if (head_idx < num_heads + 2 * kv_num_heads) { // k constexpr int KV_VEC_SIZE = 16 / sizeof(uint8_t); // 16 @@ -889,7 +907,6 @@ __global__ void append_decode_cache_int8_rope_kernel( // block_size, head_size // 2] T* __restrict__ qkv_out, const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq] - const int* __restrict__ batch_id_per_token, // [num_tokens] const int* __restrict__ cu_seqlens_q, const int* __restrict__ seq_lens, // [bsz] const int* __restrict__ seq_lens_encoder, // [bsz] @@ -1194,7 +1211,6 @@ __global__ void append_decode_cache_int8_neox_rope_kernel( // block_size, head_size // 2] T* __restrict__ qkv_out, const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq] - const int* __restrict__ batch_id_per_token, // [num_tokens] const int* __restrict__ cu_seqlens_q, const int* __restrict__ seq_lens, // [bsz] const int* __restrict__ seq_lens_encoder, // [bsz] @@ -1496,7 +1512,7 @@ __global__ void append_decode_cache_int8_neox_rope_kernel( // block_size, head_size // 2] T* __restrict__ qkv_out, const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq] - const int* __restrict__ batch_id_per_token, // [num_tokens] + const int* __restrict__ cu_seqlens_q, const int* __restrict__ seq_lens, // [bsz] const int* __restrict__ seq_lens_encoder, // [bsz] @@ -1893,7 +1909,7 @@ __global__ void append_decode_cache_int4_rope_kernel( // block_size, head_size // 2] T* __restrict__ qkv_out, const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq] - const int* __restrict__ batch_id_per_token, // [num_tokens] + const int* __restrict__ cu_seqlens_q, const int* __restrict__ seq_lens, // [bsz] const int* __restrict__ seq_lens_encoder, // [bsz] @@ -1934,44 +1950,18 @@ __global__ void append_decode_cache_int4_rope_kernel( if (head_idx < num_heads) { // q - using LoadT = AlignedVector; - using LoadBiasT = AlignedVector; - using LoadOutScaleT = AlignedVector; - constexpr int HalfVecSize = VecSize / 2; - using LoadEmbT = AlignedVector; + const T* qkv_now = quant_qkv + start_token_idx * hidden_size + head_idx * HeadDim; + T* qkv_out_now = qkv_out + start_token_idx * hidden_size + head_idx * HeadDim; - LoadT src_vec; - LoadBiasT out_vec; - LoadEmbT cos_emb_vec; - LoadEmbT sin_emb_vec; - const T* qkv_now = quant_qkv + start_token_idx * hidden_size; - T* qkv_out_now = qkv_out + start_token_idx * hidden_size; -#pragma unroll - for (uint32_t head_bias = lane_id * VecSize; head_bias < HeadDim; - head_bias += 32 * VecSize) { - const int bias_idx = head_idx * HeadDim + head_bias; - Load(&qkv_now[bias_idx], &src_vec); + uint32_t emb_offset = write_seq_id * half_head_size; + emb_offset += rope_3d ? bid * max_seq_len * HeadDim : 0; + apply_rope( + qkv_now, + cos_emb + emb_offset, + sin_emb + emb_offset, + qkv_out_now, + lane_id); - // q rope - const uint32_t emb_idx = write_seq_id * half_head_size + head_bias / 2; - uint32_t new_emb_idx = rope_3d ? emb_idx + bid * max_seq_len * HeadDim : emb_idx; - Load(&cos_emb[new_emb_idx], &cos_emb_vec); - Load(&sin_emb[new_emb_idx], &sin_emb_vec); -#pragma unroll - for (int i = 0; i < HalfVecSize; i++) { - // dequant + add_bias + rope - float input_left = static_cast(src_vec[2 * i]); - float input_right = static_cast(src_vec[2 * i + 1]); - - const float cos_tmp = cos_emb_vec[i]; - const float sin_tmp = sin_emb_vec[i]; - out_vec[2 * i] = - static_cast(input_left * cos_tmp - input_right * sin_tmp); - out_vec[2 * i + 1] = - static_cast(input_right * cos_tmp + input_left * sin_tmp); - } - Store(out_vec, &qkv_out_now[bias_idx]); - } } else if (head_idx < num_heads + 2 * kv_num_heads) { // k constexpr int KV_VEC_SIZE = 16 / sizeof(uint8_t); // 16 @@ -2191,7 +2181,7 @@ __global__ void append_decode_cache_int4_rope_kernel( // block_size, head_size // 2] T* __restrict__ qkv_out, const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq] - const int* __restrict__ batch_id_per_token, // [num_tokens] + const int* __restrict__ cu_seqlens_q, const int* __restrict__ seq_lens, // [bsz] const int* __restrict__ seq_lens_encoder, // [bsz] @@ -2522,7 +2512,7 @@ __global__ void append_decode_cache_int4_neox_rope_kernel( // block_size, head_size // 2] T* __restrict__ qkv_out, const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq] - const int* __restrict__ batch_id_per_token, // [num_tokens] + const int* __restrict__ cu_seqlens_q, const int* __restrict__ seq_lens, // [bsz] const int* __restrict__ seq_lens_encoder, // [bsz] @@ -2895,7 +2885,7 @@ __global__ void append_decode_cache_int4_neox_rope_kernel( // block_size, head_size // 2] T* __restrict__ qkv_out, const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq] - const int* __restrict__ batch_id_per_token, // [num_tokens] + const int* __restrict__ cu_seqlens_q, const int* __restrict__ seq_lens, // [bsz] const int* __restrict__ seq_lens_encoder, // [bsz] diff --git a/custom_ops/gpu_ops/append_attn/decoder_write_cache_with_rope_kernel.cu b/custom_ops/gpu_ops/append_attn/decoder_write_cache_with_rope_kernel.cu index 68b22968b..d6643ca20 100644 --- a/custom_ops/gpu_ops/append_attn/decoder_write_cache_with_rope_kernel.cu +++ b/custom_ops/gpu_ops/append_attn/decoder_write_cache_with_rope_kernel.cu @@ -21,7 +21,6 @@ void append_decode_cache_rope_qk_norm(const QKV_TYPE* qkv, T* value_cache, T* qkv_out, const int* block_tables, - const int* batch_id_per_token, const int* cu_seqlens_q, const int* seq_lens, const int* seq_lens_encoder, @@ -59,7 +58,6 @@ void append_decode_cache_rope_qk_norm(const QKV_TYPE* qkv, value_cache, qkv_out, block_tables, - batch_id_per_token, cu_seqlens_q, seq_lens, seq_lens_encoder, @@ -84,7 +82,6 @@ void append_decode_cache_rope(const QKV_TYPE* qkv, T* value_cache, T* qkv_out, const int* block_tables, - const int* batch_id_per_token, const int* cu_seqlens_q, const int* seq_lens, const int* seq_lens_encoder, @@ -120,7 +117,6 @@ void append_decode_cache_rope(const QKV_TYPE* qkv, value_cache, qkv_out, block_tables, - batch_id_per_token, cu_seqlens_q, seq_lens, seq_lens_encoder, @@ -143,7 +139,6 @@ void append_decode_cache_rope(const QKV_TYPE* qkv, value_cache, qkv_out, block_tables, - batch_id_per_token, cu_seqlens_q, seq_lens, seq_lens_encoder, @@ -167,7 +162,6 @@ void append_decode_cache_rope(const QKV_TYPE* qkv, value_cache, qkv_out, block_tables, - batch_id_per_token, cu_seqlens_q, seq_lens, seq_lens_encoder, @@ -190,7 +184,6 @@ void append_decode_cache_rope(const QKV_TYPE* qkv, value_cache, qkv_out, block_tables, - batch_id_per_token, cu_seqlens_q, seq_lens, seq_lens_encoder, @@ -214,7 +207,6 @@ void append_decode_cache_int8_rope(const QKV_TYPE* qkv, uint8_t* value_cache, T* qkv_out, const int* block_tables, - const int* batch_id_per_token, const int* cu_seqlens_q, const int* seq_lens, const int* seq_lens_encoder, @@ -247,7 +239,6 @@ void append_decode_cache_int8_rope(const QKV_TYPE* qkv, value_cache, qkv_out, block_tables, - batch_id_per_token, cu_seqlens_q, seq_lens, seq_lens_encoder, @@ -273,7 +264,6 @@ void append_decode_cache_int8_rope(const QKV_TYPE* qkv, value_cache, qkv_out, block_tables, - batch_id_per_token, cu_seqlens_q, seq_lens, seq_lens_encoder, @@ -299,7 +289,6 @@ void append_decode_cache_int8_rope(const QKV_TYPE* qkv, value_cache, qkv_out, block_tables, - batch_id_per_token, cu_seqlens_q, seq_lens, seq_lens_encoder, @@ -325,7 +314,6 @@ void append_decode_cache_int8_rope(const QKV_TYPE* qkv, value_cache, qkv_out, block_tables, - batch_id_per_token, cu_seqlens_q, seq_lens, seq_lens_encoder, @@ -351,7 +339,6 @@ void append_decode_cache_int4_rope(const QKV_TYPE* qkv, uint8_t* value_cache, T* qkv_out, const int* block_tables, - const int* batch_id_per_token, const int* cu_seqlens_q, const int* seq_lens, const int* seq_lens_encoder, @@ -386,7 +373,6 @@ void append_decode_cache_int4_rope(const QKV_TYPE* qkv, value_cache, qkv_out, block_tables, - batch_id_per_token, cu_seqlens_q, seq_lens, seq_lens_encoder, @@ -414,7 +400,6 @@ void append_decode_cache_int4_rope(const QKV_TYPE* qkv, value_cache, qkv_out, block_tables, - batch_id_per_token, cu_seqlens_q, seq_lens, seq_lens_encoder, @@ -442,7 +427,6 @@ void append_decode_cache_int4_rope(const QKV_TYPE* qkv, value_cache, qkv_out, block_tables, - batch_id_per_token, cu_seqlens_q, seq_lens, seq_lens_encoder, @@ -470,7 +454,6 @@ void append_decode_cache_int4_rope(const QKV_TYPE* qkv, value_cache, qkv_out, block_tables, - batch_id_per_token, cu_seqlens_q, seq_lens, seq_lens_encoder, @@ -497,7 +480,6 @@ void DecoderWriteCacheWithRoPEKernel( const paddle::Tensor& qkv, const paddle::Tensor& seq_lens, const paddle::Tensor& seq_lens_encoder, - const paddle::Tensor& batch_id_per_token, const paddle::Tensor& cu_seqlens_q, const paddle::Tensor& block_tables, const paddle::optional& rotary_embs, @@ -549,7 +531,6 @@ void DecoderWriteCacheWithRoPEKernel( reinterpret_cast(value_cache_out->data()), reinterpret_cast(qkv_out->data()), block_tables.data(), - batch_id_per_token.data(), cu_seqlens_q.data(), seq_lens.data(), seq_lens_encoder.data(), @@ -584,7 +565,6 @@ void DecoderWriteCacheWithRoPEKernel( reinterpret_cast(value_cache_out->data()), reinterpret_cast(qkv_out->data()), block_tables.data(), - batch_id_per_token.data(), cu_seqlens_q.data(), seq_lens.data(), seq_lens_encoder.data(), @@ -616,7 +596,6 @@ void DecoderWriteCacheWithRoPEKernel( value_cache_out->data(), reinterpret_cast(qkv_out->data()), block_tables.data(), - batch_id_per_token.data(), cu_seqlens_q.data(), seq_lens.data(), seq_lens_encoder.data(), @@ -649,7 +628,6 @@ void DecoderWriteCacheWithRoPEKernel( value_cache_out->data(), reinterpret_cast(qkv_out->data()), block_tables.data(), - batch_id_per_token.data(), cu_seqlens_q.data(), seq_lens.data(), seq_lens_encoder.data(), @@ -683,7 +661,6 @@ void DecoderWriteCacheWithRoPEKernel( value_cache_out->data(), reinterpret_cast(qkv_out->data()), block_tables.data(), - batch_id_per_token.data(), cu_seqlens_q.data(), seq_lens.data(), seq_lens_encoder.data(), @@ -716,7 +693,6 @@ void DecoderWriteCacheWithRoPEKernel( value_cache_out->data(), reinterpret_cast(const_cast(qkv_out->data())), block_tables.data(), - batch_id_per_token.data(), cu_seqlens_q.data(), seq_lens.data(), seq_lens_encoder.data(), @@ -764,7 +740,6 @@ template void DecoderWriteCacheWithRoPEKernel( // kv_num_heads, head_dim] if GQA) const paddle::Tensor& seq_lens, const paddle::Tensor& seq_lens_encoder, - const paddle::Tensor& batch_id_per_token, const paddle::Tensor& cu_seqlens_q, const paddle::Tensor& block_tables, const paddle::optional& rotary_embs, @@ -794,7 +769,6 @@ DecoderWriteCacheWithRoPEKernel( // kv_num_heads, head_dim] if GQA) const paddle::Tensor& seq_lens, const paddle::Tensor& seq_lens_encoder, - const paddle::Tensor& batch_id_per_token, const paddle::Tensor& cu_seqlens_q, const paddle::Tensor& block_tables, const paddle::optional& rotary_embs, @@ -823,7 +797,6 @@ template void DecoderWriteCacheWithRoPEKernel( // kv_num_heads, head_dim] if GQA) const paddle::Tensor& seq_lens, const paddle::Tensor& seq_lens_encoder, - const paddle::Tensor& batch_id_per_token, const paddle::Tensor& cu_seqlens_q, const paddle::Tensor& block_tables, const paddle::optional& rotary_embs, @@ -852,7 +825,6 @@ template void DecoderWriteCacheWithRoPEKernel( // kv_num_heads, head_dim] if GQA) const paddle::Tensor& seq_lens, const paddle::Tensor& seq_lens_encoder, - const paddle::Tensor& batch_id_per_token, const paddle::Tensor& cu_seqlens_q, const paddle::Tensor& block_tables, const paddle::optional& rotary_embs, diff --git a/custom_ops/gpu_ops/append_attn/decoder_write_cache_with_rope_kernel.h b/custom_ops/gpu_ops/append_attn/decoder_write_cache_with_rope_kernel.h index 459f29448..887276195 100644 --- a/custom_ops/gpu_ops/append_attn/decoder_write_cache_with_rope_kernel.h +++ b/custom_ops/gpu_ops/append_attn/decoder_write_cache_with_rope_kernel.h @@ -23,7 +23,6 @@ void DecoderWriteCacheWithRoPEKernel( // kv_num_heads, head_dim] if GQA) const paddle::Tensor& seq_lens, const paddle::Tensor& seq_lens_encoder, - const paddle::Tensor& batch_id_per_token, const paddle::Tensor& cu_seqlens_q, const paddle::Tensor& block_tables, const paddle::optional& rotary_embs,