// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/extension.h"

#ifndef PD_BUILD_STATIC_OP
#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
#endif

// Copies decode-phase rows from decoder_res into encoder_res so downstream code
// can read one merged output tensor. Grid: one block per (batch, head, token tile);
// each warp copies one token's head_dim == 128 vector, 4 fp16/bf16 elements
// (8 bytes, one float2) per lane.
template <int warps, typename T>
__global__ void FillEncoderDecoderResKernel(
        T * encoder_res_data,
        T * decoder_res_data,
        const int * seq_lens_encoder,
        const int * seq_lens_decoder,
        const int * seq_lens_this_time,
        const int * cu_seq_q,
        const int head_num,
        const int head_dim) {

    const int bidb = blockIdx.x;          // batch index
    const int bidh = blockIdx.y;          // head index
    const int bidt = blockIdx.z * warps;  // first token handled by this block
    const int tid = threadIdx.x;
    const int warp_id = tid / 32;
    const int lane_id = tid % 32;
    const int token_id = bidt + warp_id;

    const int seq_len_encoder = seq_lens_encoder[bidb];
    const int seq_len_decoder = seq_lens_decoder[bidb];
    const int seq_len_this_time = seq_lens_this_time[bidb];

    // Only decode-phase requests need the copy; prefill requests already wrote
    // their results into encoder_res.
    if (seq_len_encoder > 0 || seq_len_decoder == 0 || token_id >= seq_len_this_time) {
        return;
    }

    const int load_idx = ((cu_seq_q[bidb] + token_id) * head_num + bidh) * head_dim + lane_id * 4;

    // Vectorized 8-byte copy: 4 fp16/bf16 elements per lane, 128 per warp.
    *reinterpret_cast<float2*>(encoder_res_data + load_idx) = *reinterpret_cast<float2*>(decoder_res_data + load_idx);
}

void MergePrefillDecodeOutput(
        const paddle::Tensor &encoder_res,
        const paddle::Tensor &decoder_res,
        const paddle::Tensor &seq_lens_encoder,
        const paddle::Tensor &seq_lens_decoder,
        const paddle::Tensor &seq_lens_this_time,
        const paddle::Tensor &cu_seq_q,
        const int head_num,
        const int head_dim,
        const int max_token) {

    if (head_dim != 128) {
        PD_THROW("Only head_dim = 128 is supported");
    }
    const int batch_size = seq_lens_encoder.shape()[0];
    constexpr int warps = 4;
    const int tokens_block = (max_token + warps - 1) / warps;
    dim3 grid_dims;
    grid_dims.x = batch_size;
    grid_dims.y = head_num;
    grid_dims.z = tokens_block;

    if (encoder_res.dtype() == paddle::DataType::FLOAT16) {
        using T = phi::dtype::float16;
        FillEncoderDecoderResKernel<warps>
            <<<grid_dims, 128, 0, encoder_res.stream()>>>(
            const_cast<T*>(encoder_res.data<T>()),
            const_cast<T*>(decoder_res.data<T>()),
            seq_lens_encoder.data<int>(),
            seq_lens_decoder.data<int>(),
            seq_lens_this_time.data<int>(),
            cu_seq_q.data<int>(),
            head_num,
            head_dim);
    } else if (encoder_res.dtype() == paddle::DataType::BFLOAT16) {
        using T = phi::dtype::bfloat16;
        FillEncoderDecoderResKernel<warps>
            <<<grid_dims, 128, 0, encoder_res.stream()>>>(
            const_cast<T*>(encoder_res.data<T>()),
            const_cast<T*>(decoder_res.data<T>()),
            seq_lens_encoder.data<int>(),
            seq_lens_decoder.data<int>(),
            seq_lens_this_time.data<int>(),
            cu_seq_q.data<int>(),
            head_num,
            head_dim);
    }
}

PD_BUILD_STATIC_OP(merge_prefill_decode_output)
    .Inputs({"encoder_res",
             "decoder_res",
             "seq_lens_encoder",
             "seq_lens_decoder",
             "seq_lens_this_time",
             "cu_seq_q"})
    .Outputs({"res"})
    .Attrs({"head_num: int",
            "head_dim: int",
            "max_token: int"})
    .SetInplaceMap({{"encoder_res", "res"}})
    .SetKernelFn(PD_KERNEL(MergePrefillDecodeOutput));
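
A minimal usage sketch, assuming the extension has been JIT-compiled with paddle.utils.cpp_extension.load from a hypothetical source file name and that a GPU build of Paddle is active. Depending on whether the framework already defines PD_BUILD_STATIC_OP, the exported Python function may be named merge_prefill_decode_output or static_op_merge_prefill_decode_output; the call below assumes the former. Because of SetInplaceMap, the output "res" aliases "encoder_res", so the op writes decode-phase rows of decoder_res into encoder_res in place.

import paddle
from paddle.utils.cpp_extension import load

# Hypothetical module and source names; adjust to the actual build layout.
ops = load(name="merge_prefill_decode_ops",
           sources=["merge_prefill_decode_output.cu"])

batch, head_num, head_dim, max_token = 2, 8, 128, 4
total_q = 5  # sum(seq_lens_this_time)

encoder_res = paddle.zeros([total_q, head_num, head_dim], dtype="float16")
decoder_res = paddle.randn([total_q, head_num, head_dim]).astype("float16")

seq_lens_encoder = paddle.to_tensor([0, 0], dtype="int32")    # both requests are decode-phase
seq_lens_decoder = paddle.to_tensor([7, 9], dtype="int32")
seq_lens_this_time = paddle.to_tensor([1, 4], dtype="int32")
cu_seq_q = paddle.to_tensor([0, 1, 5], dtype="int32")         # prefix sums of seq_lens_this_time

# Copies the decode-phase token rows from decoder_res into encoder_res in place.
ops.merge_prefill_decode_output(encoder_res, decoder_res, seq_lens_encoder,
                                seq_lens_decoder, seq_lens_this_time, cu_seq_q,
                                head_num, head_dim, max_token)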
