mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-12-24 13:28:13 +08:00
[Others] Remove useless code (#5404)
This commit is contained in:
@@ -28,7 +28,7 @@ __global__ void GQAVariableLengthRotarySplitKernel(
|
||||
const float *k_norm_weight,
|
||||
const int *batch_id_per_token,
|
||||
const int *cu_seqlens_q,
|
||||
const int *seq_lens,
|
||||
const int *seq_lens_encoder,
|
||||
const int *seq_lens_decoder,
|
||||
const int *cu_seqlens_k,
|
||||
T *qkv_out,
|
||||
@@ -38,8 +38,8 @@ __global__ void GQAVariableLengthRotarySplitKernel(
|
||||
const int64_t elem_cnt,
|
||||
const int q_num_head,
|
||||
const int kv_num_head,
|
||||
const int seq_len,
|
||||
const int last_dim,
|
||||
const int max_model_len,
|
||||
const int head_dim,
|
||||
const bool rope_3d,
|
||||
const float rms_norm_eps) {
|
||||
using LoadT = AlignedVector<T, VecSize>;
|
||||
@@ -53,30 +53,33 @@ __global__ void GQAVariableLengthRotarySplitKernel(
|
||||
LoadFloat q_norm_vec, k_norm_vec;
|
||||
int64_t global_warp_idx = blockDim.y * blockIdx.x + threadIdx.y;
|
||||
int64_t all_warp_num = gridDim.x * blockDim.y;
|
||||
const int half_lastdim = last_dim / 2;
|
||||
const int half_headdim = head_dim / 2;
|
||||
const int offset =
|
||||
(q_num_head + kv_num_head * 2) * last_dim; // for all q,k,v
|
||||
const int all_head_num = elem_cnt / last_dim;
|
||||
(q_num_head + kv_num_head * 2) * head_dim; // for all q,k,v
|
||||
const int all_head_num = elem_cnt / head_dim;
|
||||
for (int gloabl_hi = global_warp_idx; gloabl_hi < all_head_num;
|
||||
gloabl_hi += all_warp_num) {
|
||||
int64_t linear_index =
|
||||
gloabl_hi * last_dim + threadIdx.x * VecSize; // 全局index
|
||||
gloabl_hi * head_dim + threadIdx.x * VecSize; // 全局index
|
||||
const int token_idx =
|
||||
linear_index / offset; // token id(第几个token,不分qkv)
|
||||
const int ori_bi = batch_id_per_token[token_idx]; // 第几个batch
|
||||
if (seq_lens[ori_bi] == 0) continue;
|
||||
|
||||
int cache_kv_len = seq_lens_decoder[ori_bi];
|
||||
// 这里其实是不需要处理的,但是由于FA3的bug,所以必须!
|
||||
if (seq_lens_encoder[ori_bi] == 0) cache_kv_len = 0;
|
||||
|
||||
const int bias = linear_index % offset;
|
||||
const int hi = bias / last_dim;
|
||||
const int h_bias = bias % last_dim;
|
||||
const int hi = bias / head_dim;
|
||||
const int h_bias = bias % head_dim;
|
||||
|
||||
const int ori_seq_id =
|
||||
(token_idx - cu_seqlens_q[ori_bi]) +
|
||||
seq_lens_decoder
|
||||
[ori_bi]; // 在当前seq中的id(拼接了seq到一个batch的情况下有效)
|
||||
cache_kv_len; // 在当前seq中的id(拼接了seq到一个batch的情况下有效)
|
||||
const int64_t emb_idx =
|
||||
ori_seq_id * half_lastdim + h_bias / 2; // embedding的id
|
||||
ori_seq_id * half_headdim + h_bias / 2; // embedding的id
|
||||
const int64_t base_idx =
|
||||
token_idx * (q_num_head + 2 * kv_num_head) * last_dim + hi * last_dim +
|
||||
token_idx * (q_num_head + 2 * kv_num_head) * head_dim + hi * head_dim +
|
||||
h_bias;
|
||||
Load<T, VecSize>(&qkv[base_idx], &src_vec);
|
||||
const int kv_write_idx = cu_seqlens_k[ori_bi] + ori_seq_id;
|
||||
@@ -84,21 +87,21 @@ __global__ void GQAVariableLengthRotarySplitKernel(
|
||||
T *out_p = nullptr;
|
||||
if (hi < q_num_head) {
|
||||
base_split_idx =
|
||||
token_idx * q_num_head * last_dim + hi * last_dim + h_bias;
|
||||
token_idx * q_num_head * head_dim + hi * head_dim + h_bias;
|
||||
out_p = q;
|
||||
} else if (hi < q_num_head + kv_num_head) {
|
||||
base_split_idx = kv_write_idx * kv_num_head * last_dim +
|
||||
(hi - q_num_head) * last_dim + h_bias;
|
||||
base_split_idx = kv_write_idx * kv_num_head * head_dim +
|
||||
(hi - q_num_head) * head_dim + h_bias;
|
||||
out_p = k;
|
||||
} else {
|
||||
out_p = v;
|
||||
base_split_idx = kv_write_idx * kv_num_head * last_dim +
|
||||
(hi - q_num_head - kv_num_head) * last_dim + h_bias;
|
||||
base_split_idx = kv_write_idx * kv_num_head * head_dim +
|
||||
(hi - q_num_head - kv_num_head) * head_dim + h_bias;
|
||||
}
|
||||
|
||||
// TODO check this correct or not
|
||||
int64_t new_emb_idx =
|
||||
rope_3d ? emb_idx + ori_bi * last_dim * seq_len : emb_idx;
|
||||
rope_3d ? emb_idx + ori_bi * head_dim * max_model_len : emb_idx;
|
||||
float thread_m2 = 0.0f;
|
||||
float warp_m2 = 0.0f;
|
||||
|
||||
@@ -122,7 +125,7 @@ __global__ void GQAVariableLengthRotarySplitKernel(
|
||||
WelfordWarpAllReduce<float, 32>(thread_m2, &warp_m2); // 单个head的标准差
|
||||
|
||||
if (hi < q_num_head + kv_num_head) { // only q and k need norm
|
||||
float row_variance = max(warp_m2 / last_dim, 0.0f);
|
||||
float row_variance = max(warp_m2 / head_dim, 0.0f);
|
||||
float row_inv_var = Rsqrt(row_variance + rms_norm_eps);
|
||||
if (hi < q_num_head) {
|
||||
Load<float, VecSize>(&q_norm_weight[threadIdx.x * VecSize],
|
||||
@@ -165,12 +168,12 @@ __global__ void GQAVariableLengthRotarySplitKernel(
|
||||
|
||||
template <typename T>
|
||||
void gqa_rotary_qk_split_variable(
|
||||
T *qkv_out, // [token_num, 3, num_head, dim_head]
|
||||
T *qkv_out, // [token_num, 3, num_head, head_dim]
|
||||
T *q,
|
||||
T *k,
|
||||
T *v,
|
||||
const T *qkv_input,
|
||||
const float *rotary_emb, // [2, 1, 1, seq_len, dim_head / 2]
|
||||
const float *rotary_emb, // [2, 1, 1, seq_len, head_dim / 2]
|
||||
const float *q_norm_weight,
|
||||
const float *k_norm_weight,
|
||||
const int *batch_id_per_token,
|
||||
@@ -181,14 +184,14 @@ void gqa_rotary_qk_split_variable(
|
||||
const int token_num,
|
||||
const int num_heads,
|
||||
const int kv_num_heads,
|
||||
const int seq_len,
|
||||
const int max_model_len,
|
||||
const int input_output_len,
|
||||
const int dim_head,
|
||||
const int head_dim,
|
||||
const bool rope_3d,
|
||||
const float rms_norm_eps,
|
||||
const cudaStream_t &stream) {
|
||||
assert(dim_head == 128 && "dim_head must be 128");
|
||||
int64_t elem_nums = token_num * (num_heads + 2 * kv_num_heads) * dim_head;
|
||||
assert(head_dim == 128 && "head_dim must be 128");
|
||||
int64_t elem_nums = token_num * (num_heads + 2 * kv_num_heads) * head_dim;
|
||||
|
||||
constexpr int HEAD_DIM = 128;
|
||||
constexpr int PackSize = HEAD_DIM / kWarpSize;
|
||||
@@ -199,7 +202,7 @@ void gqa_rotary_qk_split_variable(
|
||||
dim3 block_size(kWarpSize, blocksize / kWarpSize);
|
||||
|
||||
const float *cos_emb = rotary_emb;
|
||||
const float *sin_emb = rotary_emb + input_output_len * dim_head / 2;
|
||||
const float *sin_emb = rotary_emb + input_output_len * head_dim / 2;
|
||||
launchWithPdlWhenEnabled(GQAVariableLengthRotarySplitKernel<T, PackSize>,
|
||||
grid_size,
|
||||
block_size,
|
||||
@@ -222,8 +225,8 @@ void gqa_rotary_qk_split_variable(
|
||||
elem_nums,
|
||||
num_heads,
|
||||
kv_num_heads,
|
||||
seq_len,
|
||||
dim_head,
|
||||
max_model_len,
|
||||
head_dim,
|
||||
rope_3d,
|
||||
rms_norm_eps);
|
||||
}
|
||||
@@ -1163,9 +1166,6 @@ std::vector<paddle::Tensor> GQARopeWriteCacheKernel(
|
||||
meta_data.block_size = block_size;
|
||||
meta_data.batch_size = seq_lens_this_time.dims()[0];
|
||||
|
||||
phi::GPUContext *dev_ctx = static_cast<phi::GPUContext *>(
|
||||
phi::DeviceContextPool::Instance().Get(qkv.place()));
|
||||
|
||||
auto stream = qkv.stream();
|
||||
paddle::Tensor qkv_out = GetEmptyTensor(qkv.dims(), qkv.dtype(), qkv.place());
|
||||
paddle::Tensor q = GetEmptyTensor(
|
||||
|
||||
@@ -16,25 +16,26 @@
|
||||
#include "paddle/extension.h"
|
||||
#include "paddle/phi/core/memory/memcpy.h"
|
||||
|
||||
__global__ void pre_cache_len_concat(const int* __restrict__ seq_lens_decoder,
|
||||
const int* __restrict__ seq_lens_this_time,
|
||||
int* __restrict__ cu_seqlens_k,
|
||||
int* __restrict__ batch_ids,
|
||||
int* __restrict__ tile_ids_per_batch,
|
||||
int* __restrict__ num_blocks_x,
|
||||
int* __restrict__ kv_token_num,
|
||||
const int bsz,
|
||||
const int num_row_per_block) {
|
||||
__global__ void pre_cache_len_concat(const int* __restrict__ seq_lens_encoder,
|
||||
const int* __restrict__ seq_lens_decoder,
|
||||
const int* __restrict__ seq_lens_this_time,
|
||||
int* __restrict__ cu_seqlens_k,
|
||||
int* __restrict__ batch_ids,
|
||||
int* __restrict__ tile_ids_per_batch,
|
||||
int* __restrict__ num_blocks_x,
|
||||
int* __restrict__ kv_token_num,
|
||||
const int bsz,
|
||||
const int num_row_per_block) {
|
||||
if (threadIdx.x == 0) {
|
||||
int gridx = 0;
|
||||
int index = 0;
|
||||
int total_tokens = 0;
|
||||
cu_seqlens_k[0] = 0;
|
||||
for (uint32_t bid = 0; bid < bsz; bid++) {
|
||||
int cache_len = seq_lens_decoder[bid];
|
||||
const int q_len = seq_lens_this_time[bid];
|
||||
if (q_len <= 0) {
|
||||
cache_len = 0;
|
||||
int cache_len = 0;
|
||||
if (seq_lens_encoder[bid] > 0) {
|
||||
// only deal with chunked prefill case.
|
||||
cache_len = seq_lens_decoder[bid];
|
||||
}
|
||||
const int loop_times = div_up(cache_len, num_row_per_block);
|
||||
for (uint32_t tile_id = 0; tile_id < loop_times; tile_id++) {
|
||||
@@ -42,6 +43,7 @@ __global__ void pre_cache_len_concat(const int* __restrict__ seq_lens_decoder,
|
||||
tile_ids_per_batch[index++] = tile_id;
|
||||
}
|
||||
gridx += loop_times;
|
||||
const int q_len = seq_lens_this_time[bid];
|
||||
total_tokens += (cache_len + q_len);
|
||||
cu_seqlens_k[bid + 1] = total_tokens;
|
||||
}
|
||||
@@ -51,6 +53,7 @@ __global__ void pre_cache_len_concat(const int* __restrict__ seq_lens_decoder,
|
||||
}
|
||||
|
||||
std::vector<paddle::Tensor> PreCacheLenConcat(
|
||||
const paddle::Tensor& seq_lens_encoder,
|
||||
const paddle::Tensor& seq_lens_decoder,
|
||||
const paddle::Tensor& seq_lens_this_time,
|
||||
const int max_dec_len,
|
||||
@@ -58,45 +61,43 @@ std::vector<paddle::Tensor> PreCacheLenConcat(
|
||||
auto stream = seq_lens_decoder.stream();
|
||||
auto place = seq_lens_decoder.place();
|
||||
int bsz = seq_lens_this_time.shape()[0];
|
||||
const uint32_t max_tile_size_per_bs_pre_cache = div_up(max_dec_len, block_size);
|
||||
const uint32_t max_tile_size_per_bs_pre_cache =
|
||||
div_up(max_dec_len, block_size);
|
||||
|
||||
paddle::Tensor cu_seqlens_k = GetEmptyTensor(
|
||||
{bsz + 1},
|
||||
paddle::DataType::INT32,
|
||||
place);
|
||||
paddle::Tensor cu_seqlens_k =
|
||||
GetEmptyTensor({bsz + 1}, paddle::DataType::INT32, place);
|
||||
paddle::Tensor pre_cache_batch_ids = GetEmptyTensor(
|
||||
{bsz * max_tile_size_per_bs_pre_cache},
|
||||
paddle::DataType::INT32,
|
||||
place);
|
||||
{bsz * max_tile_size_per_bs_pre_cache}, paddle::DataType::INT32, place);
|
||||
paddle::Tensor pre_cache_tile_ids_per_batch = GetEmptyTensor(
|
||||
{bsz * max_tile_size_per_bs_pre_cache},
|
||||
paddle::DataType::INT32,
|
||||
place);
|
||||
{bsz * max_tile_size_per_bs_pre_cache}, paddle::DataType::INT32, place);
|
||||
paddle::Tensor pre_cache_num_blocks =
|
||||
GetEmptyTensor({1}, paddle::DataType::INT32, place);
|
||||
GetEmptyTensor({1}, paddle::DataType::INT32, place);
|
||||
paddle::Tensor kv_token_num =
|
||||
GetEmptyTensor({1}, paddle::DataType::INT32, place);
|
||||
GetEmptyTensor({1}, paddle::DataType::INT32, place);
|
||||
|
||||
pre_cache_len_concat<<<1, 32, 0, stream>>>(
|
||||
seq_lens_decoder.data<int>(),
|
||||
seq_lens_this_time.data<int>(),
|
||||
cu_seqlens_k.data<int>(),
|
||||
pre_cache_batch_ids.data<int>(),
|
||||
pre_cache_tile_ids_per_batch.data<int>(),
|
||||
pre_cache_num_blocks.data<int>(),
|
||||
kv_token_num.data<int>(),
|
||||
bsz,
|
||||
block_size
|
||||
);
|
||||
paddle::Tensor pre_cache_num_blocks_cpu = pre_cache_num_blocks.copy_to(paddle::CPUPlace(), false);
|
||||
paddle::Tensor kv_token_num_cpu = kv_token_num.copy_to(paddle::CPUPlace(), false);
|
||||
seq_lens_encoder.data<int>(),
|
||||
seq_lens_decoder.data<int>(),
|
||||
seq_lens_this_time.data<int>(),
|
||||
cu_seqlens_k.data<int>(),
|
||||
pre_cache_batch_ids.data<int>(),
|
||||
pre_cache_tile_ids_per_batch.data<int>(),
|
||||
pre_cache_num_blocks.data<int>(),
|
||||
kv_token_num.data<int>(),
|
||||
bsz,
|
||||
block_size);
|
||||
paddle::Tensor pre_cache_num_blocks_cpu =
|
||||
pre_cache_num_blocks.copy_to(paddle::CPUPlace(), false);
|
||||
paddle::Tensor kv_token_num_cpu =
|
||||
kv_token_num.copy_to(paddle::CPUPlace(), false);
|
||||
|
||||
return {cu_seqlens_k,
|
||||
pre_cache_batch_ids,
|
||||
pre_cache_tile_ids_per_batch,
|
||||
pre_cache_num_blocks_cpu, /*cpu*/
|
||||
kv_token_num_cpu /*cpu*/
|
||||
};
|
||||
return {
|
||||
cu_seqlens_k,
|
||||
pre_cache_batch_ids,
|
||||
pre_cache_tile_ids_per_batch,
|
||||
pre_cache_num_blocks_cpu, /*cpu*/
|
||||
kv_token_num_cpu /*cpu*/
|
||||
};
|
||||
}
|
||||
|
||||
std::vector<paddle::DataType> PreCacheLenConcatInferDtype(
|
||||
@@ -121,15 +122,13 @@ std::vector<std::vector<int64_t>> PreCacheLenConcatInferShape(
|
||||
}
|
||||
|
||||
PD_BUILD_STATIC_OP(pre_cache_len_concat)
|
||||
.Inputs({"seq_lens_decoder",
|
||||
"seq_lens_this_time"})
|
||||
.Inputs({"seq_lens_encoder", "seq_lens_decoder", "seq_lens_this_time"})
|
||||
.Outputs({"cu_seqlens_k",
|
||||
"pre_cache_batch_ids",
|
||||
"pre_cache_tile_ids_per_batch",
|
||||
"pre_cache_num_blocks_cpu", /*cpu*/
|
||||
"kv_token_num_cpu"}) /*cpu*/
|
||||
.Attrs({"max_dec_len: int",
|
||||
"block_size: int"})
|
||||
"kv_token_num_cpu"}) /*cpu*/
|
||||
.Attrs({"max_dec_len: int", "block_size: int"})
|
||||
.SetKernelFn(PD_KERNEL(PreCacheLenConcat))
|
||||
.SetInferShapeFn(PD_INFER_SHAPE(PreCacheLenConcatInferShape))
|
||||
.SetInferDtypeFn(PD_INFER_DTYPE(PreCacheLenConcatInferDtype));
|
||||
|
||||
@@ -194,6 +194,7 @@ std::vector<paddle::Tensor> GQARopeWriteCacheKernel(
|
||||
const bool rope_3d);
|
||||
|
||||
std::vector<paddle::Tensor> PreCacheLenConcat(
|
||||
const paddle::Tensor& seq_lens_encoder,
|
||||
const paddle::Tensor& seq_lens_decoder,
|
||||
const paddle::Tensor& seq_lens_this_time,
|
||||
const int max_dec_len,
|
||||
|
||||
@@ -206,20 +206,9 @@ class AppendAttentionBackend(AttentionBackend):
|
||||
Calculate kv cache shape
|
||||
"""
|
||||
key_cache_shape = [max_num_blocks, self.kv_num_heads, self.block_size, self.head_dim]
|
||||
value_cache_shape = [max_num_blocks, self.kv_num_heads, self.block_size, self.head_dim]
|
||||
if kv_cache_quant_type is not None and kv_cache_quant_type == "int4_zp":
|
||||
key_cache_shape = [
|
||||
max_num_blocks,
|
||||
self.kv_num_heads,
|
||||
self.block_size,
|
||||
self.head_dim // 2,
|
||||
]
|
||||
value_cache_shape = [
|
||||
max_num_blocks,
|
||||
self.kv_num_heads,
|
||||
self.block_size,
|
||||
self.head_dim // 2,
|
||||
]
|
||||
key_cache_shape[-1] = self.head_dim // 2
|
||||
value_cache_shape = key_cache_shape
|
||||
return key_cache_shape, value_cache_shape
|
||||
|
||||
def forward_mixed(
|
||||
|
||||
@@ -63,13 +63,7 @@ class FlashAttentionMetadata(AttentionMetadata):
|
||||
FlashAttentionMetadata
|
||||
"""
|
||||
|
||||
rotary_embs: Optional[paddle.Tensor] = None
|
||||
block_tables: Optional[paddle.Tensor] = None
|
||||
|
||||
cu_seqlens_q: paddle.Tensor = None
|
||||
cu_seqlens_k: paddle.Tensor = None
|
||||
max_seqlen_q: int = 0
|
||||
max_seqlen_k: int = 0
|
||||
|
||||
pre_cache_batch_ids = None
|
||||
pre_cache_tile_ids_per_batch = None
|
||||
@@ -83,7 +77,6 @@ class FlashAttentionMetadata(AttentionMetadata):
|
||||
_fuse_kernel_compute_dtype: str = "bf16"
|
||||
_dtype: paddle.dtype = paddle.bfloat16
|
||||
|
||||
max_len_tensor_cpu: paddle.Tensor = None
|
||||
max_len_tensor_cpu_decoder: paddle.Tensor = None
|
||||
|
||||
|
||||
@@ -133,9 +126,6 @@ class FlashAttentionBackend(AttentionBackend):
|
||||
|
||||
self.start_layer_index: int = fd_config.model_config.start_layer_index
|
||||
|
||||
if fd_config.parallel_config.expert_parallel_rank is None:
|
||||
fd_config.parallel_config.expert_parallel_rank = 0
|
||||
|
||||
self.rank, self.device_id = init_rank_and_device_id(fd_config)
|
||||
|
||||
if self.flash_attn_func is None:
|
||||
@@ -154,7 +144,8 @@ class FlashAttentionBackend(AttentionBackend):
|
||||
"The current platform does not support Flash Attention V3, so Flash Attention V2 will be used instead."
|
||||
)
|
||||
self.rope_3d: bool = getattr(fd_config.model_config, "rope_3d", False)
|
||||
self.max_partition_size: int = int(os.getenv("FLAGS_max_partition_size", "32768"))
|
||||
# Note(ZKK): here must be consistent with append_attn_backend.py
|
||||
self.max_partition_size: int = int(os.getenv("FLAGS_max_partition_size", 1024))
|
||||
self.zero_seq_enc_lens_for_decode = paddle.zeros(
|
||||
shape=[fd_config.scheduler_config.max_num_seqs, 1], dtype=paddle.int32
|
||||
)
|
||||
@@ -172,27 +163,13 @@ class FlashAttentionBackend(AttentionBackend):
|
||||
Calculate kv cache shape
|
||||
"""
|
||||
key_cache_shape = [max_num_blocks, self.kv_num_heads, self.block_size, self.head_dim]
|
||||
value_cache_shape = [max_num_blocks, self.kv_num_heads, self.block_size, self.head_dim]
|
||||
if kv_cache_quant_type is not None and kv_cache_quant_type == "int4_zp":
|
||||
key_cache_shape = [
|
||||
max_num_blocks,
|
||||
self.kv_num_heads,
|
||||
self.block_size,
|
||||
self.head_dim // 2,
|
||||
]
|
||||
value_cache_shape = [
|
||||
max_num_blocks,
|
||||
self.kv_num_heads,
|
||||
self.block_size,
|
||||
self.head_dim // 2,
|
||||
]
|
||||
key_cache_shape[-1] = self.head_dim // 2
|
||||
value_cache_shape = key_cache_shape
|
||||
return key_cache_shape, value_cache_shape
|
||||
|
||||
def init_attention_metadata(self, forward_meta: ForwardMeta):
|
||||
metadata = FlashAttentionMetadata()
|
||||
metadata.cu_seqlens_q = forward_meta.cu_seqlens_q
|
||||
metadata.rotary_embs = forward_meta.rotary_embs
|
||||
metadata.block_tables = forward_meta.block_tables
|
||||
get_block_shape_and_split_kv_block(
|
||||
forward_meta.seq_lens_encoder,
|
||||
forward_meta.seq_lens_decoder,
|
||||
@@ -215,18 +192,20 @@ class FlashAttentionBackend(AttentionBackend):
|
||||
self.block_size,
|
||||
)
|
||||
|
||||
(
|
||||
metadata.cu_seqlens_k,
|
||||
metadata.pre_cache_batch_ids,
|
||||
metadata.pre_cache_tile_ids_per_batch,
|
||||
metadata.pre_cache_num_blocks_cpu,
|
||||
metadata.kv_token_num_cpu,
|
||||
) = pre_cache_len_concat(
|
||||
forward_meta.seq_lens_decoder,
|
||||
forward_meta.seq_lens_this_time,
|
||||
forward_meta.max_len_tensor_cpu[2],
|
||||
self.block_size,
|
||||
)
|
||||
if forward_meta.max_len_tensor_cpu[1] > 0:
|
||||
(
|
||||
metadata.cu_seqlens_k,
|
||||
metadata.pre_cache_batch_ids,
|
||||
metadata.pre_cache_tile_ids_per_batch,
|
||||
metadata.pre_cache_num_blocks_cpu,
|
||||
metadata.kv_token_num_cpu,
|
||||
) = pre_cache_len_concat(
|
||||
forward_meta.seq_lens_encoder,
|
||||
forward_meta.seq_lens_decoder,
|
||||
forward_meta.seq_lens_this_time,
|
||||
forward_meta.max_len_tensor_cpu[2],
|
||||
self.block_size,
|
||||
)
|
||||
|
||||
# pd_disaggregation
|
||||
metadata.kv_signal_data_list = [None] * self.num_layers
|
||||
@@ -251,8 +230,7 @@ class FlashAttentionBackend(AttentionBackend):
|
||||
elif metadata._dtype == "float32":
|
||||
metadata._fuse_kernel_compute_dtype = "fp32"
|
||||
|
||||
metadata.max_len_tensor_cpu = forward_meta.max_len_tensor_cpu
|
||||
metadata.max_len_tensor_cpu_decoder = paddle.clone(metadata.max_len_tensor_cpu)
|
||||
metadata.max_len_tensor_cpu_decoder = paddle.clone(forward_meta.max_len_tensor_cpu)
|
||||
metadata.max_len_tensor_cpu_decoder[1] = 0
|
||||
|
||||
self.attention_metadata = metadata
|
||||
@@ -276,19 +254,21 @@ class FlashAttentionBackend(AttentionBackend):
|
||||
layer.layer_id + self.start_layer_index,
|
||||
)
|
||||
|
||||
if metadata.max_len_tensor_cpu[1] > 0:
|
||||
use_fa_do_prefill = forward_meta.max_len_tensor_cpu[1].item() > 0
|
||||
|
||||
if use_fa_do_prefill:
|
||||
q, k, v, _ = gqa_rope_write_cache(
|
||||
qkv,
|
||||
forward_meta.caches[2 * layer.layer_id],
|
||||
forward_meta.caches[2 * layer.layer_id + 1],
|
||||
metadata.cu_seqlens_q,
|
||||
forward_meta.cu_seqlens_q,
|
||||
metadata.cu_seqlens_k,
|
||||
metadata.rotary_embs,
|
||||
forward_meta.rotary_embs,
|
||||
forward_meta.seq_lens_this_time,
|
||||
forward_meta.seq_lens_encoder,
|
||||
forward_meta.seq_lens_decoder,
|
||||
forward_meta.batch_id_per_token,
|
||||
metadata.block_tables,
|
||||
forward_meta.block_tables,
|
||||
forward_meta.kv_batch_ids,
|
||||
forward_meta.kv_tile_ids_per_batch,
|
||||
forward_meta.kv_num_blocks_x_cpu,
|
||||
@@ -315,7 +295,7 @@ class FlashAttentionBackend(AttentionBackend):
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
metadata.cu_seqlens_q,
|
||||
forward_meta.cu_seqlens_q,
|
||||
metadata.cu_seqlens_k,
|
||||
max_seqlen_q=forward_meta.max_len_tensor_cpu[0],
|
||||
max_seqlen_k=forward_meta.max_len_tensor_cpu[3],
|
||||
@@ -327,23 +307,23 @@ class FlashAttentionBackend(AttentionBackend):
|
||||
qkv,
|
||||
forward_meta.caches[2 * layer.layer_id],
|
||||
forward_meta.caches[2 * layer.layer_id + 1],
|
||||
self.zero_seq_enc_lens_for_decode,
|
||||
self.zero_seq_enc_lens_for_decode if use_fa_do_prefill else forward_meta.seq_lens_encoder,
|
||||
forward_meta.seq_lens_decoder,
|
||||
forward_meta.seq_lens_this_time,
|
||||
forward_meta.batch_id_per_token,
|
||||
forward_meta.cu_seqlens_q,
|
||||
metadata.block_tables,
|
||||
forward_meta.block_tables,
|
||||
forward_meta.encoder_batch_ids,
|
||||
forward_meta.encoder_tile_ids_per_batch,
|
||||
forward_meta.encoder_num_blocks_x_cpu,
|
||||
forward_meta.kv_batch_ids,
|
||||
forward_meta.kv_tile_ids_per_batch,
|
||||
forward_meta.kv_num_blocks_x_cpu,
|
||||
forward_meta.decoder_batch_ids, # from buffer
|
||||
forward_meta.decoder_tile_ids_per_batch, # from buffer
|
||||
forward_meta.decoder_batch_ids,
|
||||
forward_meta.decoder_tile_ids_per_batch,
|
||||
forward_meta.decoder_num_blocks_cpu,
|
||||
metadata.max_len_tensor_cpu_decoder,
|
||||
metadata.rotary_embs,
|
||||
metadata.max_len_tensor_cpu_decoder if use_fa_do_prefill else forward_meta.max_len_tensor_cpu,
|
||||
forward_meta.rotary_embs,
|
||||
forward_meta.attn_mask,
|
||||
layer.qkv_bias,
|
||||
layer.qkv_scale,
|
||||
@@ -378,7 +358,7 @@ class FlashAttentionBackend(AttentionBackend):
|
||||
self.speculative_method is not None,
|
||||
)
|
||||
|
||||
if metadata.max_len_tensor_cpu[1] > 0:
|
||||
if use_fa_do_prefill:
|
||||
merge_prefill_decode_output(
|
||||
res_encoder,
|
||||
res_decoder,
|
||||
|
||||
@@ -24,6 +24,7 @@ from fastdeploy.platforms import current_platform
|
||||
|
||||
|
||||
def pre_cache_len_concat(
|
||||
seq_lens_encoder: paddle.Tensor,
|
||||
seq_lens_decoder: paddle.Tensor,
|
||||
seq_lens_this_time: paddle.Tensor,
|
||||
max_dec_len: int = 0,
|
||||
@@ -32,7 +33,7 @@ def pre_cache_len_concat(
|
||||
if current_platform.is_cuda():
|
||||
from fastdeploy.model_executor.ops.gpu import pre_cache_len_concat
|
||||
|
||||
out = pre_cache_len_concat(seq_lens_decoder, seq_lens_this_time, max_dec_len, block_size)
|
||||
out = pre_cache_len_concat(seq_lens_encoder, seq_lens_decoder, seq_lens_this_time, max_dec_len, block_size)
|
||||
return out
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -71,7 +71,6 @@ class TestAttentionPerformance(unittest.TestCase):
|
||||
self.fd_config.parallel_config.tp_group = [0]
|
||||
|
||||
# Initialize Attention Layer
|
||||
os.environ["FD_ATTENTION_BACKEND"] = "APPEND_ATTN"
|
||||
attn_cls = get_attention_backend()
|
||||
self.attn_backend = attn_cls(
|
||||
self.fd_config,
|
||||
@@ -123,10 +122,10 @@ class TestAttentionPerformance(unittest.TestCase):
|
||||
"max_position_embeddings": 131072,
|
||||
"max_model_len": 131072,
|
||||
"head_dim": 128,
|
||||
"hidden_size": 4096,
|
||||
"num_attention_heads": 32,
|
||||
"num_key_value_heads": 4,
|
||||
"num_hidden_layers": 57,
|
||||
"hidden_size": 8192,
|
||||
"num_attention_heads": 64,
|
||||
"num_key_value_heads": 8,
|
||||
"num_hidden_layers": 2,
|
||||
}
|
||||
model_dir = tempfile.mkdtemp(prefix="tmp_model_config_")
|
||||
config_path = os.path.join(model_dir, "config.json")
|
||||
@@ -158,6 +157,7 @@ class TestAttentionPerformance(unittest.TestCase):
|
||||
dense_quant_type="block_wise_fp8",
|
||||
moe_quant_type="block_wise_fp8",
|
||||
kv_cache_quant_type="float8_e4m3fn",
|
||||
# kv_cache_quant_type=None,
|
||||
),
|
||||
graph_opt_config=GraphOptimizationConfig({}),
|
||||
commit_config=CommitConfig(),
|
||||
@@ -270,7 +270,7 @@ class TestAttentionPerformance(unittest.TestCase):
|
||||
partial_rotary_factor=fd_config.model_config.partial_rotary_factor,
|
||||
)
|
||||
|
||||
input_ids = paddle.zeros([batch_size, seq_len if mode == ForwardMode.EXTEND else 1], dtype="int64")
|
||||
input_ids = paddle.zeros([batch_size, max_model_len], dtype="int64")
|
||||
token_num = paddle.sum(seq_lens_this_time)
|
||||
ids_remove_padding, batch_id_per_token, cu_seqlens_q, cu_seqlens_k = get_padding_offset(
|
||||
input_ids, token_num, seq_lens_this_time
|
||||
@@ -294,12 +294,13 @@ class TestAttentionPerformance(unittest.TestCase):
|
||||
attn_mask_offsets=None,
|
||||
**attn_backend_buffers,
|
||||
)
|
||||
return forward_meta
|
||||
|
||||
hidden_states = paddle.randn([token_num, self.fd_config.model_config.hidden_size], dtype="bfloat16")
|
||||
return forward_meta, hidden_states
|
||||
|
||||
def test_decode_performance_with_prefill(self):
|
||||
# Test parameters
|
||||
test_steps = 100
|
||||
act_tensor_dtype = paddle.bfloat16
|
||||
|
||||
# prefill_batch_size = 1
|
||||
# prefill_seq_len = 4096
|
||||
@@ -356,11 +357,7 @@ class TestAttentionPerformance(unittest.TestCase):
|
||||
# p.step()
|
||||
|
||||
for decode_batch_size in [32, 16, 8, 4, 2]:
|
||||
decode_hidden_states = paddle.randn(
|
||||
[decode_batch_size, self.fd_config.model_config.hidden_size], dtype=act_tensor_dtype
|
||||
)
|
||||
|
||||
forward_meta = self.create_forward_meta(
|
||||
forward_meta, hidden_states = self.create_forward_meta(
|
||||
batch_size=decode_batch_size,
|
||||
seq_len=36 * 1024,
|
||||
mode=ForwardMode.DECODE,
|
||||
@@ -374,12 +371,12 @@ class TestAttentionPerformance(unittest.TestCase):
|
||||
paddle.device.synchronize()
|
||||
|
||||
# 必须要先预热一次!因为预处理被放到了第一层再做了!
|
||||
self.attn_forward(forward_meta, decode_hidden_states)
|
||||
self.attn_forward(forward_meta, hidden_states)
|
||||
|
||||
attn_cuda_graphs = graphs.CUDAGraph()
|
||||
attn_cuda_graphs.capture_begin()
|
||||
|
||||
self.attn_forward(forward_meta, decode_hidden_states)
|
||||
self.attn_forward(forward_meta, hidden_states)
|
||||
|
||||
attn_cuda_graphs.capture_end()
|
||||
|
||||
|
||||
@@ -69,7 +69,10 @@ class TestPreCacheLenConcat(unittest.TestCase):
|
||||
seq_lens_decoder_t = paddle.to_tensor(seq_lens_decoder, dtype="int32")
|
||||
seq_lens_this_time_t = paddle.to_tensor(seq_lens_this_time, dtype="int32")
|
||||
|
||||
outputs = pre_cache_len_concat(seq_lens_decoder_t, seq_lens_this_time_t, max_dec_len, block_size)
|
||||
seq_lens_encoder_t = seq_lens_this_time_t
|
||||
outputs = pre_cache_len_concat(
|
||||
seq_lens_encoder_t, seq_lens_decoder_t, seq_lens_this_time_t, max_dec_len, block_size
|
||||
)
|
||||
cu_seqlens_k, batch_ids, tile_ids, num_blocks, kv_token_num = [out.numpy() for out in outputs]
|
||||
|
||||
# Shape checks
|
||||
@@ -91,8 +94,11 @@ class TestPreCacheLenConcat(unittest.TestCase):
|
||||
|
||||
seq_lens_decoder_t = paddle.to_tensor(seq_lens_decoder, dtype="int32")
|
||||
seq_lens_this_time_t = paddle.to_tensor(seq_lens_this_time, dtype="int32")
|
||||
seq_lens_encoder_t = seq_lens_this_time_t
|
||||
|
||||
outputs = pre_cache_len_concat(seq_lens_decoder_t, seq_lens_this_time_t, max_dec_len, block_size)
|
||||
outputs = pre_cache_len_concat(
|
||||
seq_lens_encoder_t, seq_lens_decoder_t, seq_lens_this_time_t, max_dec_len, block_size
|
||||
)
|
||||
cu_seqlens_k, batch_ids, tile_ids, num_blocks, kv_token_num = [out.numpy() for out in outputs]
|
||||
|
||||
# Reference implementation
|
||||
|
||||
Reference in New Issue
Block a user