diff --git a/custom_ops/gpu_ops/cpp_extensions.cc b/custom_ops/gpu_ops/cpp_extensions.cc
index 4639d1e93..000820688 100644
--- a/custom_ops/gpu_ops/cpp_extensions.cc
+++ b/custom_ops/gpu_ops/cpp_extensions.cc
@@ -761,6 +761,17 @@ void SpeculateStepPaddle(
     const int encoder_decoder_block_num,
     const int max_draft_tokens);
 
+void MergePrefillDecodeOutput(
+    const paddle::Tensor &encoder_res,
+    const paddle::Tensor &decoder_res,
+    const paddle::Tensor &seq_lens_encoder,
+    const paddle::Tensor &seq_lens_decoder,
+    const paddle::Tensor &seq_lens_this_time,
+    const paddle::Tensor &cu_seq_q,
+    const int head_num,
+    const int head_dim,
+    const int max_token);
+
 PYBIND11_MODULE(fastdeploy_ops, m) {
   m.def("get_expert_token_num", &GetExpertTokenNum,
         py::arg("topk_ids"),
@@ -1111,4 +1122,6 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
   m.def("mtp_step_paddle",&MTPStepPaddle, "mtp_step_paddle function");
 
   m.def("speculate_step_paddle",&SpeculateStepPaddle, "speculate_step_paddle function");
+
+  m.def("merge_prefill_decode_output", &MergePrefillDecodeOutput, "merge_prefill_decode_output function");
 }
diff --git a/custom_ops/gpu_ops/merge_prefill_decode_output.cu b/custom_ops/gpu_ops/merge_prefill_decode_output.cu
new file mode 100644
index 000000000..6902b7250
--- /dev/null
+++ b/custom_ops/gpu_ops/merge_prefill_decode_output.cu
@@ -0,0 +1,117 @@
+// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
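+
+// Merges the outputs of the prefill (encoder) and decode attention paths of
+// the flash-attention backend. Both results live in buffers of shape
+// [token_num, head_num * head_dim]; for requests that are decode-only in this
+// step, the kernel below overwrites the corresponding rows of the encoder
+// buffer with the decoder result in place, so the caller can return a single
+// merged tensor.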
+ + +#include "paddle/extension.h" + +#ifndef PD_BUILD_STATIC_OP +#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name) +#endif + +template +__global__ void FillEncoderDecoderResKernel( + T * encoder_res_data, + T * decoder_res_data, + const int * seq_lens_encoder, + const int * seq_lens_decoder, + const int * seq_lens_this_time, + const int * cu_seq_q, + const int head_num, + const int head_dim) { + + const int bidb = blockIdx.x; + const int bidh = blockIdx.y; + const int bidt = blockIdx.z * warps; + const int tid = threadIdx.x; + const int warp_id = tid / 32; + const int land_id = tid % 32; + const int token_id = bidt + warp_id; + + const int seq_len_encoder = seq_lens_encoder[bidb]; + const int seq_len_decoder = seq_lens_decoder[bidb]; + const int seq_len_this_time = seq_lens_this_time[bidb]; + + if (seq_len_encoder > 0 || seq_len_decoder == 0 || token_id >= seq_len_this_time) { + return; + } + + const int load_idx = ((cu_seq_q[bidb] + token_id) * head_num + bidh) * head_dim + land_id * 4; + + *reinterpret_cast(encoder_res_data + load_idx) = *reinterpret_cast(decoder_res_data + load_idx); +} + +void MergePrefillDecodeOutput( + const paddle::Tensor &encoder_res, + const paddle::Tensor &decoder_res, + const paddle::Tensor &seq_lens_encoder, + const paddle::Tensor &seq_lens_decoder, + const paddle::Tensor &seq_lens_this_time, + const paddle::Tensor &cu_seq_q, + const int head_num, + const int head_dim, + const int max_token) { + + if (head_dim != 128) { + PD_THROW("Only supported head_dim = 128"); + } + const int batch_size = seq_lens_encoder.shape()[0]; + constexpr int warps = 4; + const int tokens_block = (max_token + warps - 1) / warps; + dim3 grid_dims; + grid_dims.x = batch_size; + grid_dims.y = head_num; + grid_dims.z = tokens_block; + + if (encoder_res.dtype() == paddle::DataType::FLOAT16) { + using T = phi::dtype::float16; + FillEncoderDecoderResKernel + <<>>( + const_cast(encoder_res.data()), + const_cast(decoder_res.data()), + seq_lens_encoder.data(), + seq_lens_decoder.data(), + seq_lens_this_time.data(), + cu_seq_q.data(), + head_num, + head_dim + ); + } else if (encoder_res.dtype() == paddle::DataType::BFLOAT16) { + using T = phi::dtype::bfloat16; + FillEncoderDecoderResKernel + <<>>( + const_cast(encoder_res.data()), + const_cast(decoder_res.data()), + seq_lens_encoder.data(), + seq_lens_decoder.data(), + seq_lens_this_time.data(), + cu_seq_q.data(), + head_num, + head_dim + ); + } +} + +PD_BUILD_STATIC_OP(merge_prefill_decode_output) + .Inputs({"encoder_res", + "decoder_res", + "seq_lens_encoder", + "seq_lens_decoder", + "seq_lens_this_time", + "cu_seq_q"}) + .Outputs({"res"}) + .Attrs({"head_num: int", + "head_dim: int", + "max_token: int"}) + .SetInplaceMap({{"encoder_res", "res"}}) + .SetKernelFn(PD_KERNEL(MergePrefillDecodeOutput)); diff --git a/custom_ops/setup_ops.py b/custom_ops/setup_ops.py index 1cb091116..f7c934aa1 100644 --- a/custom_ops/setup_ops.py +++ b/custom_ops/setup_ops.py @@ -294,6 +294,7 @@ elif paddle.is_compiled_with_cuda(): "gpu_ops/fused_rotary_position_encoding.cu", "gpu_ops/noaux_tc.cu", "gpu_ops/custom_all_reduce/all_reduce.cu", + "gpu_ops/merge_prefill_decode_output.cu", ] # pd_disaggregation diff --git a/fastdeploy/model_executor/layers/attention/flash_attn_backend.py b/fastdeploy/model_executor/layers/attention/flash_attn_backend.py index 306164635..cfcf9ef92 100644 --- a/fastdeploy/model_executor/layers/attention/flash_attn_backend.py +++ b/fastdeploy/model_executor/layers/attention/flash_attn_backend.py @@ -34,6 +34,7 @@ from 
+void MergePrefillDecodeOutput(
+    const paddle::Tensor &encoder_res,
+    const paddle::Tensor &decoder_res,
+    const paddle::Tensor &seq_lens_encoder,
+    const paddle::Tensor &seq_lens_decoder,
+    const paddle::Tensor &seq_lens_this_time,
+    const paddle::Tensor &cu_seq_q,
+    const int head_num,
+    const int head_dim,
+    const int max_token) {
+
+  if (head_dim != 128) {
+    PD_THROW("Only head_dim = 128 is supported.");
+  }
+  const int batch_size = seq_lens_encoder.shape()[0];
+  constexpr int warps = 4;
+  const int tokens_block = (max_token + warps - 1) / warps;
+  dim3 grid_dims;
+  grid_dims.x = batch_size;
+  grid_dims.y = head_num;
+  grid_dims.z = tokens_block;
+
+  if (encoder_res.dtype() == paddle::DataType::FLOAT16) {
+    using T = phi::dtype::float16;
+    FillEncoderDecoderResKernel<T, warps>
+        <<<grid_dims, warps * 32, 0, encoder_res.stream()>>>(
+            const_cast<T *>(encoder_res.data<T>()),
+            const_cast<T *>(decoder_res.data<T>()),
+            seq_lens_encoder.data<int>(),
+            seq_lens_decoder.data<int>(),
+            seq_lens_this_time.data<int>(),
+            cu_seq_q.data<int>(),
+            head_num,
+            head_dim);
+  } else if (encoder_res.dtype() == paddle::DataType::BFLOAT16) {
+    using T = phi::dtype::bfloat16;
+    FillEncoderDecoderResKernel<T, warps>
+        <<<grid_dims, warps * 32, 0, encoder_res.stream()>>>(
+            const_cast<T *>(encoder_res.data<T>()),
+            const_cast<T *>(decoder_res.data<T>()),
+            seq_lens_encoder.data<int>(),
+            seq_lens_decoder.data<int>(),
+            seq_lens_this_time.data<int>(),
+            cu_seq_q.data<int>(),
+            head_num,
+            head_dim);
+  }
+}
+
+PD_BUILD_STATIC_OP(merge_prefill_decode_output)
+    .Inputs({"encoder_res",
+             "decoder_res",
+             "seq_lens_encoder",
+             "seq_lens_decoder",
+             "seq_lens_this_time",
+             "cu_seq_q"})
+    .Outputs({"res"})
+    .Attrs({"head_num: int",
+            "head_dim: int",
+            "max_token: int"})
+    .SetInplaceMap({{"encoder_res", "res"}})
+    .SetKernelFn(PD_KERNEL(MergePrefillDecodeOutput));
diff --git a/custom_ops/setup_ops.py b/custom_ops/setup_ops.py
index 1cb091116..f7c934aa1 100644
--- a/custom_ops/setup_ops.py
+++ b/custom_ops/setup_ops.py
@@ -294,6 +294,7 @@ elif paddle.is_compiled_with_cuda():
             "gpu_ops/fused_rotary_position_encoding.cu",
             "gpu_ops/noaux_tc.cu",
             "gpu_ops/custom_all_reduce/all_reduce.cu",
+            "gpu_ops/merge_prefill_decode_output.cu",
         ]
 
         # pd_disaggregation
diff --git a/fastdeploy/model_executor/layers/attention/flash_attn_backend.py b/fastdeploy/model_executor/layers/attention/flash_attn_backend.py
index 306164635..cfcf9ef92 100644
--- a/fastdeploy/model_executor/layers/attention/flash_attn_backend.py
+++ b/fastdeploy/model_executor/layers/attention/flash_attn_backend.py
@@ -34,6 +34,7 @@ from fastdeploy.model_executor.layers.attention.base_attention_backend import (
     AttentionMetadata,
 )
 from fastdeploy.model_executor.layers.attention.ops import (
+    append_attention,
     get_block_shape_and_split_kv_block,
     gqa_rope_write_cache,
     init_kv_signal_per_query,
@@ -46,6 +47,15 @@ from fastdeploy.model_executor.layers.attention.utils import init_rank_and_device_id
 
 if TYPE_CHECKING:
     from fastdeploy.model_executor.forward_meta import ForwardMeta
+from fastdeploy.platforms import current_platform
+
+if current_platform.is_cuda():
+    from fastdeploy.model_executor.ops.gpu import merge_prefill_decode_output
+else:
+    merge_prefill_decode_output = None
+
+import os
+
 
 @dataclass
 class FlashAttentionMetadata(AttentionMetadata):
@@ -61,6 +71,7 @@ class FlashAttentionMetadata(AttentionMetadata):
     kv_batch_ids: paddle.Tensor = None
     kv_tile_ids_per_batch: paddle.Tensor = None
     kv_num_blocks: paddle.Tensor = None
+    max_len_kv: paddle.Tensor = None
 
     cu_seqlens_q: paddle.Tensor = None
     cu_seqlens_k: paddle.Tensor = None
@@ -76,6 +87,12 @@ class FlashAttentionMetadata(AttentionMetadata):
     kv_signal_metadata: Optional[paddle.Tensor] = None
     kv_signal_data_list: List[Optional[paddle.Tensor]] = field(default_factory=list)
 
+    _fuse_kernel_compute_dtype: str = "bf16"
+    _dtype: paddle.dtype = paddle.bfloat16
+
+    max_len_tensor_cpu: paddle.Tensor = None
+    max_len_tensor_cpu_decoder: paddle.Tensor = None
+
 
 class FlashAttentionBackend(AttentionBackend):
     """
@@ -143,6 +160,11 @@ class FlashAttentionBackend(AttentionBackend):
             print(
                 "The current platform does not support Flash Attention V3, so Flash Attention V2 will be used instead."
            )
+        self.rope_3d: bool = getattr(fd_config.model_config, "rope_3d", False)
+        self.max_partition_size: int = int(os.getenv("FLAGS_max_partition_size", "32768"))
+        self.zero_seq_enc_lens_for_decode = paddle.zeros(
+            shape=[fd_config.parallel_config.max_num_seqs, 1], dtype=paddle.int32
+        )
 
     def get_attntion_meta(self):
         """get_attntion_meta"""
@@ -208,7 +230,7 @@ class FlashAttentionBackend(AttentionBackend):
         ) = pre_cache_len_concat(
             forward_meta.seq_lens_decoder,
             forward_meta.seq_lens_this_time,
-            metadata.set_max_lengths[2],
+            forward_meta.max_len_tensor_cpu[2],
             self.block_size,
         )
@@ -227,6 +249,18 @@ class FlashAttentionBackend(AttentionBackend):
             metadata.kv_signal_metadata = open_shm_and_get_meta_signal(
                 self.rank, int(self.device_id), self.keep_pd_step_flag
             )
+
+        if metadata._dtype == "bfloat16":
+            metadata._fuse_kernel_compute_dtype = "bf16"
+        elif metadata._dtype == "float16":
+            metadata._fuse_kernel_compute_dtype = "fp16"
+        elif metadata._dtype == "float32":
+            metadata._fuse_kernel_compute_dtype = "fp32"
+
+        metadata.max_len_tensor_cpu = forward_meta.max_len_tensor_cpu
+        metadata.max_len_tensor_cpu_decoder = paddle.clone(metadata.max_len_tensor_cpu)
+        metadata.max_len_tensor_cpu_decoder[1] = 0
+
         self.attention_metadata = metadata
 
     def forward_mixed(
@@ -248,45 +282,112 @@ class FlashAttentionBackend(AttentionBackend):
                 layer.layer_id + self.start_layer_index,
             )
 
-        q, k, v, _ = gqa_rope_write_cache(
+        if metadata.max_len_tensor_cpu[1] > 0:
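+            # Prefill tokens are present this step: run the varlen
+            # FlashAttention path over them here; decode tokens go through
+            # append_attention below, and the two results are merged into
+            # res_encoder before returning.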
+            q, k, v, _ = gqa_rope_write_cache(
+                qkv,
+                forward_meta.caches[2 * layer.layer_id],
+                forward_meta.caches[2 * layer.layer_id + 1],
+                metadata.cu_seqlens_q,
+                metadata.cu_seqlens_k,
+                metadata.rotary_embs,
+                forward_meta.seq_lens_this_time,
+                forward_meta.seq_lens_encoder,
+                forward_meta.seq_lens_decoder,
+                forward_meta.batch_id_per_token,
+                metadata.block_tables,
+                metadata.kv_batch_ids,
+                metadata.kv_tile_ids_per_batch,
+                metadata.kv_num_blocks,
+                metadata.pre_cache_batch_ids,
+                metadata.pre_cache_tile_ids_per_batch,
+                metadata.pre_cache_num_blocks_cpu,
+                getattr(layer, "cache_k_scale", None),
+                getattr(layer, "cache_v_scale", None),
+                getattr(layer, "cache_k_out_scale", None),
+                getattr(layer, "cache_v_out_scale", None),
+                getattr(layer, "cache_k_zp", None),
+                getattr(layer, "cache_v_zp", None),
+                metadata.kv_signal_data_list[layer.layer_id],
+                metadata.kv_token_num_cpu[0].item(),
+                self.max_seq_len,
+                getattr(layer, "cache_quant_type_str", "none"),
+            )
+
+            res_encoder = self.flash_attn_func(
+                q,
+                k,
+                v,
+                metadata.cu_seqlens_q,
+                metadata.cu_seqlens_k,
+                max_seqlen_q=forward_meta.max_len_tensor_cpu[0],
+                max_seqlen_k=forward_meta.max_len_tensor_cpu[3],
+                causal=self.causal,
+                **self.flash_attn_kwargs,
+            )[0].reshape([-1, self.attn_outputsize_tp])
+
+        res_decoder = append_attention(
             qkv,
             forward_meta.caches[2 * layer.layer_id],
             forward_meta.caches[2 * layer.layer_id + 1],
-            metadata.cu_seqlens_q,
-            metadata.cu_seqlens_k,
-            metadata.rotary_embs,
-            forward_meta.seq_lens_this_time,
-            forward_meta.seq_lens_encoder,
+            self.zero_seq_enc_lens_for_decode,
             forward_meta.seq_lens_decoder,
+            forward_meta.seq_lens_this_time,
             forward_meta.batch_id_per_token,
+            forward_meta.cu_seqlens_q,
             metadata.block_tables,
+            metadata.encoder_batch_ids,
+            metadata.encoder_tile_ids_per_batch,
+            metadata.encoder_num_blocks,
             metadata.kv_batch_ids,
             metadata.kv_tile_ids_per_batch,
             metadata.kv_num_blocks,
-            metadata.pre_cache_batch_ids,
-            metadata.pre_cache_tile_ids_per_batch,
-            metadata.pre_cache_num_blocks_cpu,
+            forward_meta.decoder_batch_ids,  # from buffer
+            forward_meta.decoder_tile_ids_per_batch,  # from buffer
+            forward_meta.decoder_num_blocks_cpu,
+            forward_meta.max_len_tensor_cpu,
+            metadata.max_len_kv,
+            metadata.rotary_embs,
+            forward_meta.attn_mask,
+            layer.qkv_bias,
+            layer.qkv_scale,
             getattr(layer, "cache_k_scale", None),
             getattr(layer, "cache_v_scale", None),
             getattr(layer, "cache_k_out_scale", None),
             getattr(layer, "cache_v_out_scale", None),
             getattr(layer, "cache_k_zp", None),
             getattr(layer, "cache_v_zp", None),
+            layer.linear_shift,
+            layer.linear_smooth,
             metadata.kv_signal_data_list[layer.layer_id],
-            metadata.kv_token_num_cpu[0].item(),
-            self.max_seq_len,
+            metadata._fuse_kernel_compute_dtype,
             getattr(layer, "cache_quant_type_str", "none"),
-        )
+            layer.use_neox_rotary_style,
+            self.rope_3d,
+            self.max_seq_len,
+            getattr(layer, "quant_max_bound", 0.0),
+            getattr(layer, "quant_min_bound", 0.0),
+            getattr(layer, "out_scale", -1.0),
+            self.encoder_block_shape_q,
+            self.decoder_block_shape_q,
+            self.max_partition_size,
+            self.max_seq_len,
+            self.speculate_max_draft_token_num + 1,
+            self.causal,
+            self.speculative_method is not None,
+        )[0]
 
-        res = self.flash_attn_func(
-            q,
-            k,
-            v,
-            metadata.cu_seqlens_q,
-            metadata.cu_seqlens_k,
-            max_seqlen_q=forward_meta.max_len_tensor_cpu[0],
-            max_seqlen_k=forward_meta.max_len_tensor_cpu[3],
-            causal=self.causal,
-            **self.flash_attn_kwargs,
-        )[0].reshape([-1, self.attn_outputsize_tp])
-        return res
+        if metadata.max_len_tensor_cpu[1] > 0:
+            merge_prefill_decode_output(
+                res_encoder,
+                res_decoder,
+                forward_meta.seq_lens_encoder,
+                forward_meta.seq_lens_decoder,
+                forward_meta.seq_lens_this_time,
+                forward_meta.cu_seqlens_q,
+                self.num_heads,
+                self.head_dim,
+                self.speculate_max_draft_token_num + 1,
+            )
+            return res_encoder
+        else:
+            return res_decoder
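A minimal smoke test of the new op, for reviewers (a sketch under assumptions: the custom ops are compiled for CUDA and a GPU is the default device; the batch layout, sequence lengths, and head_num below are illustrative, not taken from this patch):

import paddle
from fastdeploy.model_executor.ops.gpu import merge_prefill_decode_output

head_num, head_dim = 8, 128  # the op only supports head_dim = 128
# Two requests this step: request 0 is prefill (3 tokens), request 1 is decode-only (1 token).
seq_lens_encoder = paddle.to_tensor([3, 0], dtype="int32")
seq_lens_decoder = paddle.to_tensor([0, 7], dtype="int32")
seq_lens_this_time = paddle.to_tensor([3, 1], dtype="int32")
cu_seq_q = paddle.to_tensor([0, 3, 4], dtype="int32")  # cumulative query-token offsets

total_tokens = 4  # 3 prefill tokens + 1 decode token
encoder_res = paddle.randn([total_tokens, head_num * head_dim]).astype("bfloat16")
decoder_res = paddle.randn([total_tokens, head_num * head_dim]).astype("bfloat16")

# In place: the decode-only request's row of encoder_res (row 3 = cu_seq_q[1] + 0)
# is overwritten with the matching row of decoder_res; prefill rows are untouched.
merge_prefill_decode_output(
    encoder_res, decoder_res,
    seq_lens_encoder, seq_lens_decoder, seq_lens_this_time,
    cu_seq_q, head_num, head_dim, 1,  # max_token: at most 1 decode token per request
)
assert bool((encoder_res[3] == decoder_res[3]).all())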