Centralized FA3 support (#3112)

yangjianfengo1
2025-08-01 18:03:36 +08:00
committed by GitHub
parent bdb83e007d
commit 64d7a3194d
4 changed files with 257 additions and 25 deletions

View File

@@ -761,6 +761,17 @@ void SpeculateStepPaddle(
const int encoder_decoder_block_num,
const int max_draft_tokens);
void MergePrefillDecodeOutput(
const paddle::Tensor &encoder_res,
const paddle::Tensor &decoder_res,
const paddle::Tensor &seq_lens_encoder,
const paddle::Tensor &seq_lens_decoder,
const paddle::Tensor &seq_lens_this_time,
const paddle::Tensor &cu_seq_q,
const int head_num,
const int head_dim,
const int max_token);
PYBIND11_MODULE(fastdeploy_ops, m) {
m.def("get_expert_token_num", &GetExpertTokenNum, py::arg("topk_ids"),
@@ -1111,4 +1122,6 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
m.def("mtp_step_paddle",&MTPStepPaddle, "mtp_step_paddle function"); m.def("mtp_step_paddle",&MTPStepPaddle, "mtp_step_paddle function");
m.def("speculate_step_paddle",&SpeculateStepPaddle, "speculate_step_paddle function"); m.def("speculate_step_paddle",&SpeculateStepPaddle, "speculate_step_paddle function");
m.def("merge_prefill_decode_output", &MergePrefillDecodeOutput, "merge_prefill_decode_output function");
}

View File

@@ -0,0 +1,117 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/extension.h"
#ifndef PD_BUILD_STATIC_OP
#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
#endif
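// Merge kernel: for decode-only sequences, copy the append_attention (decoder)
// result rows into the FA3 (encoder) output buffer in place, so a single tensor
// covers the whole mixed prefill/decode batch.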
template <int warps, typename T>
__global__ void FillEncoderDecoderResKernel(
T * encoder_res_data,
T * decoder_res_data,
const int * seq_lens_encoder,
const int * seq_lens_decoder,
const int * seq_lens_this_time,
const int * cu_seq_q,
const int head_num,
const int head_dim) {
const int bidb = blockIdx.x;
const int bidh = blockIdx.y;
const int bidt = blockIdx.z * warps;
const int tid = threadIdx.x;
const int warp_id = tid / 32;
const int lane_id = tid % 32;
const int token_id = bidt + warp_id;
const int seq_len_encoder = seq_lens_encoder[bidb];
const int seq_len_decoder = seq_lens_decoder[bidb];
const int seq_len_this_time = seq_lens_this_time[bidb];
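// Skip sequences that ran prefill this step (their FA3 rows are already correct),
// sequences with no decode context, and tokens beyond this step's length.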
if (seq_len_encoder > 0 || seq_len_decoder == 0 || token_id >= seq_len_this_time) {
return;
}
const int load_idx = ((cu_seq_q[bidb] + token_id) * head_num + bidh) * head_dim + lane_id * 4;
*reinterpret_cast<float2*>(encoder_res_data + load_idx) = *reinterpret_cast<float2*>(decoder_res_data + load_idx);
}
void MergePrefillDecodeOutput(
const paddle::Tensor &encoder_res,
const paddle::Tensor &decoder_res,
const paddle::Tensor &seq_lens_encoder,
const paddle::Tensor &seq_lens_decoder,
const paddle::Tensor &seq_lens_this_time,
const paddle::Tensor &cu_seq_q,
const int head_num,
const int head_dim,
const int max_token) {
if (head_dim != 128) {
PD_THROW("Only head_dim = 128 is supported.");
}
const int batch_size = seq_lens_encoder.shape()[0];
constexpr int warps = 4;
const int tokens_block = (max_token + warps - 1) / warps;
dim3 grid_dims;
grid_dims.x = batch_size;
grid_dims.y = head_num;
grid_dims.z = tokens_block;
if (encoder_res.dtype() == paddle::DataType::FLOAT16) {
using T = phi::dtype::float16;
FillEncoderDecoderResKernel<warps>
<<<grid_dims, 128, 0, encoder_res.stream()>>>(
const_cast<T*>(encoder_res.data<T>()),
const_cast<T*>(decoder_res.data<T>()),
seq_lens_encoder.data<int>(),
seq_lens_decoder.data<int>(),
seq_lens_this_time.data<int>(),
cu_seq_q.data<int>(),
head_num,
head_dim
);
} else if (encoder_res.dtype() == paddle::DataType::BFLOAT16) {
using T = phi::dtype::bfloat16;
FillEncoderDecoderResKernel<warps>
<<<grid_dims, 128, 0, encoder_res.stream()>>>(
const_cast<T*>(encoder_res.data<T>()),
const_cast<T*>(decoder_res.data<T>()),
seq_lens_encoder.data<int>(),
seq_lens_decoder.data<int>(),
seq_lens_this_time.data<int>(),
cu_seq_q.data<int>(),
head_num,
head_dim
);
}
}
PD_BUILD_STATIC_OP(merge_prefill_decode_output)
.Inputs({"encoder_res",
"decoder_res",
"seq_lens_encoder",
"seq_lens_decoder",
"seq_lens_this_time",
"cu_seq_q"})
.Outputs({"res"})
.Attrs({"head_num: int",
"head_dim: int",
"max_token: int"})
.SetInplaceMap({{"encoder_res", "res"}})
.SetKernelFn(PD_KERNEL(MergePrefillDecodeOutput));
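
For orientation, a minimal usage sketch of the op registered above. This is illustrative only: it assumes the op has been built into fastdeploy.model_executor.ops.gpu (the import path used in the backend change further down), that a CUDA device is available, and the shapes and values are made up.

import paddle
from fastdeploy.model_executor.ops.gpu import merge_prefill_decode_output

head_num, head_dim, max_token = 8, 128, 2     # head_dim must be 128
total_tokens = 5                              # 4 prefill tokens + 1 decode token

# Flat [token, head, dim] buffers, matching the kernel's indexing.
encoder_res = paddle.zeros([total_tokens, head_num, head_dim], dtype="float16")
decoder_res = paddle.ones([total_tokens, head_num, head_dim], dtype="float16")

seq_lens_encoder = paddle.to_tensor([4, 0], dtype="int32")    # batch 0 is prefill
seq_lens_decoder = paddle.to_tensor([0, 9], dtype="int32")    # batch 1 is decode
seq_lens_this_time = paddle.to_tensor([4, 1], dtype="int32")
cu_seq_q = paddle.to_tensor([0, 4, 5], dtype="int32")         # cumulative query offsets

# Copies batch 1's single decode row from decoder_res into encoder_res in place.
merge_prefill_decode_output(
    encoder_res, decoder_res,
    seq_lens_encoder, seq_lens_decoder, seq_lens_this_time, cu_seq_q,
    head_num, head_dim, max_token,
)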

View File

@@ -294,6 +294,7 @@ elif paddle.is_compiled_with_cuda():
"gpu_ops/fused_rotary_position_encoding.cu", "gpu_ops/fused_rotary_position_encoding.cu",
"gpu_ops/noaux_tc.cu", "gpu_ops/noaux_tc.cu",
"gpu_ops/custom_all_reduce/all_reduce.cu", "gpu_ops/custom_all_reduce/all_reduce.cu",
"gpu_ops/merge_prefill_decode_output.cu",
]
# pd_disaggregation

View File

@@ -34,6 +34,7 @@ from fastdeploy.model_executor.layers.attention.base_attention_backend import (
AttentionMetadata,
)
from fastdeploy.model_executor.layers.attention.ops import (
append_attention,
get_block_shape_and_split_kv_block,
gqa_rope_write_cache,
init_kv_signal_per_query,
@@ -46,6 +47,15 @@ from fastdeploy.model_executor.layers.attention.utils import init_rank_and_devic
if TYPE_CHECKING:
from fastdeploy.model_executor.forward_meta import ForwardMeta
from fastdeploy.platforms import current_platform
if current_platform.is_cuda():
from fastdeploy.model_executor.ops.gpu import merge_prefill_decode_output
else:
merge_prefill_decode_output = None
import os
@dataclass
class FlashAttentionMetadata(AttentionMetadata):
@@ -61,6 +71,7 @@ class FlashAttentionMetadata(AttentionMetadata):
kv_batch_ids: paddle.Tensor = None
kv_tile_ids_per_batch: paddle.Tensor = None
kv_num_blocks: paddle.Tensor = None
max_len_kv: paddle.Tensor = None
cu_seqlens_q: paddle.Tensor = None
cu_seqlens_k: paddle.Tensor = None
@@ -76,6 +87,12 @@ class FlashAttentionMetadata(AttentionMetadata):
kv_signal_metadata: Optional[paddle.Tensor] = None
kv_signal_data_list: List[Optional[paddle.Tensor]] = field(default_factory=list)
_fuse_kernel_compute_dtype: str = "bf16"
_dtype: paddle.dtype = paddle.bfloat16
max_len_tensor_cpu: paddle.Tensor = None
max_len_tensor_cpu_decoder: paddle.Tensor = None
class FlashAttentionBackend(AttentionBackend):
"""
@@ -143,6 +160,11 @@ class FlashAttentionBackend(AttentionBackend):
print(
"The current platform does not support Flash Attention V3, so Flash Attention V2 will be used instead."
)
self.rope_3d: bool = getattr(fd_config.model_config, "rope_3d", False)
self.max_partition_size: int = int(os.getenv("FLAGS_max_partition_size", "32768"))
self.zero_seq_enc_lens_for_decode = paddle.zeros(
shape=[fd_config.parallel_config.max_num_seqs, 1], dtype=paddle.int32
)
def get_attntion_meta(self):
"""get_attntion_meta"""
@@ -208,7 +230,7 @@ class FlashAttentionBackend(AttentionBackend):
) = pre_cache_len_concat(
forward_meta.seq_lens_decoder,
forward_meta.seq_lens_this_time,
forward_meta.max_len_tensor_cpu[2],
self.block_size,
)
@@ -227,6 +249,18 @@ class FlashAttentionBackend(AttentionBackend):
metadata.kv_signal_metadata = open_shm_and_get_meta_signal(
self.rank, int(self.device_id), self.keep_pd_step_flag
)
if metadata._dtype == "bfloat16":
metadata._fuse_kernel_compute_dtype = "bf16"
elif metadata._dtype == "float16":
metadata._fuse_kernel_compute_dtype = "fp16"
elif metadata._dtype == "float32":
metadata._fuse_kernel_compute_dtype = "fp32"
metadata.max_len_tensor_cpu = forward_meta.max_len_tensor_cpu
metadata.max_len_tensor_cpu_decoder = paddle.clone(metadata.max_len_tensor_cpu)
metadata.max_len_tensor_cpu_decoder[1] = 0
self.attention_metadata = metadata
def forward_mixed(
@@ -248,6 +282,7 @@ class FlashAttentionBackend(AttentionBackend):
layer.layer_id + self.start_layer_index,
)
if metadata.max_len_tensor_cpu[1] > 0:
q, k, v, _ = gqa_rope_write_cache(
qkv,
forward_meta.caches[2 * layer.layer_id],
@@ -278,7 +313,7 @@ class FlashAttentionBackend(AttentionBackend):
getattr(layer, "cache_quant_type_str", "none"), getattr(layer, "cache_quant_type_str", "none"),
) )
res_encoder = self.flash_attn_func(
q,
k,
v,
@@ -289,4 +324,70 @@ class FlashAttentionBackend(AttentionBackend):
causal=self.causal,
**self.flash_attn_kwargs,
)[0].reshape([-1, self.attn_outputsize_tp])
res_decoder = append_attention(
qkv,
forward_meta.caches[2 * layer.layer_id],
forward_meta.caches[2 * layer.layer_id + 1],
self.zero_seq_enc_lens_for_decode,
forward_meta.seq_lens_decoder,
forward_meta.seq_lens_this_time,
forward_meta.batch_id_per_token,
forward_meta.cu_seqlens_q,
metadata.block_tables,
metadata.encoder_batch_ids,
metadata.encoder_tile_ids_per_batch,
metadata.encoder_num_blocks,
metadata.kv_batch_ids,
metadata.kv_tile_ids_per_batch,
metadata.kv_num_blocks,
forward_meta.decoder_batch_ids, # from buffer
forward_meta.decoder_tile_ids_per_batch, # from buffer
forward_meta.decoder_num_blocks_cpu,
forward_meta.max_len_tensor_cpu,
metadata.max_len_kv,
metadata.rotary_embs,
forward_meta.attn_mask,
layer.qkv_bias,
layer.qkv_scale,
getattr(layer, "cache_k_scale", None),
getattr(layer, "cache_v_scale", None),
getattr(layer, "cache_k_out_scale", None),
getattr(layer, "cache_v_out_scale", None),
getattr(layer, "cache_k_zp", None),
getattr(layer, "cache_v_zp", None),
layer.linear_shift,
layer.linear_smooth,
metadata.kv_signal_data_list[layer.layer_id],
metadata._fuse_kernel_compute_dtype,
getattr(layer, "cache_quant_type_str", "none"),
layer.use_neox_rotary_style,
self.rope_3d,
self.max_seq_len,
getattr(layer, "quant_max_bound", 0.0),
getattr(layer, "quant_min_bound", 0.0),
getattr(layer, "out_scale", -1.0),
self.encoder_block_shape_q,
self.decoder_block_shape_q,
self.max_partition_size,
self.max_seq_len,
self.speculate_max_draft_token_num + 1,
self.causal,
self.speculative_method is not None,
)[0]
if metadata.max_len_tensor_cpu[1] > 0:
merge_prefill_decode_output(
res_encoder,
res_decoder,
forward_meta.seq_lens_encoder,
forward_meta.seq_lens_decoder,
forward_meta.seq_lens_this_time,
forward_meta.cu_seqlens_q,
self.num_heads,
self.head_dim,
self.speculate_max_draft_token_num + 1,
)
return res_encoder
else:
return res_decoder
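
For reference, the merge that reconciles the two results above behaves like this pure-Python sketch (an illustrative reference of the CUDA kernel's semantics, not the shipped implementation; the tensor arguments are assumed to be row-major [token, hidden] buffers as returned by the two attention calls):

def merge_prefill_decode_output_ref(encoder_res, decoder_res,
                                    seq_lens_encoder, seq_lens_decoder,
                                    seq_lens_this_time, cu_seq_q):
    # For every batch that is decode-only this step, copy its rows from the
    # append_attention output into the FA3 output buffer, so res_encoder can
    # be returned for the whole mixed batch.
    for b in range(seq_lens_encoder.shape[0]):
        if int(seq_lens_encoder[b]) > 0 or int(seq_lens_decoder[b]) == 0:
            continue
        start = int(cu_seq_q[b])
        end = start + int(seq_lens_this_time[b])
        encoder_res[start:end] = decoder_res[start:end]
    return encoder_res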