[Others] Remove useless code (#5404)

周周周
2025-12-08 13:59:46 +08:00
committed by GitHub
parent 3066a0c34b
commit 2aea8a3a60
8 changed files with 139 additions and 166 deletions

View File

@@ -28,7 +28,7 @@ __global__ void GQAVariableLengthRotarySplitKernel(
const float *k_norm_weight,
const int *batch_id_per_token,
const int *cu_seqlens_q,
const int *seq_lens,
const int *seq_lens_encoder,
const int *seq_lens_decoder,
const int *cu_seqlens_k,
T *qkv_out,
@@ -38,8 +38,8 @@ __global__ void GQAVariableLengthRotarySplitKernel(
const int64_t elem_cnt,
const int q_num_head,
const int kv_num_head,
const int seq_len,
const int last_dim,
const int max_model_len,
const int head_dim,
const bool rope_3d,
const float rms_norm_eps) {
using LoadT = AlignedVector<T, VecSize>;
@@ -53,30 +53,33 @@ __global__ void GQAVariableLengthRotarySplitKernel(
LoadFloat q_norm_vec, k_norm_vec;
int64_t global_warp_idx = blockDim.y * blockIdx.x + threadIdx.y;
int64_t all_warp_num = gridDim.x * blockDim.y;
const int half_lastdim = last_dim / 2;
const int half_headdim = head_dim / 2;
const int offset =
(q_num_head + kv_num_head * 2) * last_dim; // for all q,k,v
const int all_head_num = elem_cnt / last_dim;
(q_num_head + kv_num_head * 2) * head_dim; // for all q,k,v
const int all_head_num = elem_cnt / head_dim;
for (int gloabl_hi = global_warp_idx; gloabl_hi < all_head_num;
gloabl_hi += all_warp_num) {
int64_t linear_index =
gloabl_hi * last_dim + threadIdx.x * VecSize; // global index
gloabl_hi * head_dim + threadIdx.x * VecSize; // global index
const int token_idx =
linear_index / offset; // token id (which token, regardless of q/k/v)
const int ori_bi = batch_id_per_token[token_idx]; // which batch
if (seq_lens[ori_bi] == 0) continue;
int cache_kv_len = seq_lens_decoder[ori_bi];
// This case should not need special handling, but an FA3 bug makes it necessary.
if (seq_lens_encoder[ori_bi] == 0) cache_kv_len = 0;
const int bias = linear_index % offset;
const int hi = bias / last_dim;
const int h_bias = bias % last_dim;
const int hi = bias / head_dim;
const int h_bias = bias % head_dim;
const int ori_seq_id =
(token_idx - cu_seqlens_q[ori_bi]) +
seq_lens_decoder
[ori_bi]; // id within the current seq (valid when seqs are concatenated into one batch)
cache_kv_len; // id within the current seq (valid when seqs are concatenated into one batch)
const int64_t emb_idx =
ori_seq_id * half_lastdim + h_bias / 2; // embedding id
ori_seq_id * half_headdim + h_bias / 2; // embedding id
const int64_t base_idx =
token_idx * (q_num_head + 2 * kv_num_head) * last_dim + hi * last_dim +
token_idx * (q_num_head + 2 * kv_num_head) * head_dim + hi * head_dim +
h_bias;
Load<T, VecSize>(&qkv[base_idx], &src_vec);
const int kv_write_idx = cu_seqlens_k[ori_bi] + ori_seq_id;
@@ -84,21 +87,21 @@ __global__ void GQAVariableLengthRotarySplitKernel(
T *out_p = nullptr;
if (hi < q_num_head) {
base_split_idx =
token_idx * q_num_head * last_dim + hi * last_dim + h_bias;
token_idx * q_num_head * head_dim + hi * head_dim + h_bias;
out_p = q;
} else if (hi < q_num_head + kv_num_head) {
base_split_idx = kv_write_idx * kv_num_head * last_dim +
(hi - q_num_head) * last_dim + h_bias;
base_split_idx = kv_write_idx * kv_num_head * head_dim +
(hi - q_num_head) * head_dim + h_bias;
out_p = k;
} else {
out_p = v;
base_split_idx = kv_write_idx * kv_num_head * last_dim +
(hi - q_num_head - kv_num_head) * last_dim + h_bias;
base_split_idx = kv_write_idx * kv_num_head * head_dim +
(hi - q_num_head - kv_num_head) * head_dim + h_bias;
}
// TODO check this correct or not
int64_t new_emb_idx =
rope_3d ? emb_idx + ori_bi * last_dim * seq_len : emb_idx;
rope_3d ? emb_idx + ori_bi * head_dim * max_model_len : emb_idx;
float thread_m2 = 0.0f;
float warp_m2 = 0.0f;
@@ -122,7 +125,7 @@ __global__ void GQAVariableLengthRotarySplitKernel(
WelfordWarpAllReduce<float, 32>(thread_m2, &warp_m2); // standard deviation for a single head
if (hi < q_num_head + kv_num_head) { // only q and k need norm
float row_variance = max(warp_m2 / last_dim, 0.0f);
float row_variance = max(warp_m2 / head_dim, 0.0f);
float row_inv_var = Rsqrt(row_variance + rms_norm_eps);
if (hi < q_num_head) {
Load<float, VecSize>(&q_norm_weight[threadIdx.x * VecSize],
@@ -165,12 +168,12 @@ __global__ void GQAVariableLengthRotarySplitKernel(
template <typename T>
void gqa_rotary_qk_split_variable(
T *qkv_out, // [token_num, 3, num_head, dim_head]
T *qkv_out, // [token_num, 3, num_head, head_dim]
T *q,
T *k,
T *v,
const T *qkv_input,
const float *rotary_emb, // [2, 1, 1, seq_len, dim_head / 2]
const float *rotary_emb, // [2, 1, 1, seq_len, head_dim / 2]
const float *q_norm_weight,
const float *k_norm_weight,
const int *batch_id_per_token,
@@ -181,14 +184,14 @@ void gqa_rotary_qk_split_variable(
const int token_num,
const int num_heads,
const int kv_num_heads,
const int seq_len,
const int max_model_len,
const int input_output_len,
const int dim_head,
const int head_dim,
const bool rope_3d,
const float rms_norm_eps,
const cudaStream_t &stream) {
assert(dim_head == 128 && "dim_head must be 128");
int64_t elem_nums = token_num * (num_heads + 2 * kv_num_heads) * dim_head;
assert(head_dim == 128 && "head_dim must be 128");
int64_t elem_nums = token_num * (num_heads + 2 * kv_num_heads) * head_dim;
constexpr int HEAD_DIM = 128;
constexpr int PackSize = HEAD_DIM / kWarpSize;
@@ -199,7 +202,7 @@ void gqa_rotary_qk_split_variable(
dim3 block_size(kWarpSize, blocksize / kWarpSize);
const float *cos_emb = rotary_emb;
const float *sin_emb = rotary_emb + input_output_len * dim_head / 2;
const float *sin_emb = rotary_emb + input_output_len * head_dim / 2;
launchWithPdlWhenEnabled(GQAVariableLengthRotarySplitKernel<T, PackSize>,
grid_size,
block_size,
@@ -222,8 +225,8 @@ void gqa_rotary_qk_split_variable(
elem_nums,
num_heads,
kv_num_heads,
seq_len,
dim_head,
max_model_len,
head_dim,
rope_3d,
rms_norm_eps);
}
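
The hunks above rename `last_dim`/`seq_len` to `head_dim`/`max_model_len` and replace the single `seq_lens` input with `seq_lens_encoder`/`seq_lens_decoder`, so each token's position inside its sequence now comes from an encoder-gated cache length. A rough Python sketch of that per-token indexing (the helper name and signature are illustrative, not part of the kernel):

```python
# Hedged sketch of the per-token indexing in GQAVariableLengthRotarySplitKernel.
# The real kernel does this per vector lane on the GPU; names here are illustrative.
def split_token_index(token_idx, batch_id_per_token, cu_seqlens_q, cu_seqlens_k,
                      seq_lens_encoder, seq_lens_decoder):
    bi = batch_id_per_token[token_idx]            # which batch this token belongs to
    cache_kv_len = seq_lens_decoder[bi]           # prefix already held in the KV cache
    if seq_lens_encoder[bi] == 0:                 # decode-only batch: FA3 workaround,
        cache_kv_len = 0                          # treat the cached prefix as empty
    ori_seq_id = (token_idx - cu_seqlens_q[bi]) + cache_kv_len  # position inside this sequence
    kv_write_idx = cu_seqlens_k[bi] + ori_seq_id  # row in the packed K/V output
    return bi, ori_seq_id, kv_write_idx
```

For `rope_3d`, the rotary table is additionally offset by `ori_bi * head_dim * max_model_len`, which is why the launcher now takes `max_model_len` rather than `seq_len`.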
@@ -1163,9 +1166,6 @@ std::vector<paddle::Tensor> GQARopeWriteCacheKernel(
meta_data.block_size = block_size;
meta_data.batch_size = seq_lens_this_time.dims()[0];
phi::GPUContext *dev_ctx = static_cast<phi::GPUContext *>(
phi::DeviceContextPool::Instance().Get(qkv.place()));
auto stream = qkv.stream();
paddle::Tensor qkv_out = GetEmptyTensor(qkv.dims(), qkv.dtype(), qkv.place());
paddle::Tensor q = GetEmptyTensor(

View File

@@ -16,25 +16,26 @@
#include "paddle/extension.h"
#include "paddle/phi/core/memory/memcpy.h"
__global__ void pre_cache_len_concat(const int* __restrict__ seq_lens_decoder,
const int* __restrict__ seq_lens_this_time,
int* __restrict__ cu_seqlens_k,
int* __restrict__ batch_ids,
int* __restrict__ tile_ids_per_batch,
int* __restrict__ num_blocks_x,
int* __restrict__ kv_token_num,
const int bsz,
const int num_row_per_block) {
__global__ void pre_cache_len_concat(const int* __restrict__ seq_lens_encoder,
const int* __restrict__ seq_lens_decoder,
const int* __restrict__ seq_lens_this_time,
int* __restrict__ cu_seqlens_k,
int* __restrict__ batch_ids,
int* __restrict__ tile_ids_per_batch,
int* __restrict__ num_blocks_x,
int* __restrict__ kv_token_num,
const int bsz,
const int num_row_per_block) {
if (threadIdx.x == 0) {
int gridx = 0;
int index = 0;
int total_tokens = 0;
cu_seqlens_k[0] = 0;
for (uint32_t bid = 0; bid < bsz; bid++) {
int cache_len = seq_lens_decoder[bid];
const int q_len = seq_lens_this_time[bid];
if (q_len <= 0) {
cache_len = 0;
int cache_len = 0;
if (seq_lens_encoder[bid] > 0) {
// only deal with chunked prefill case.
cache_len = seq_lens_decoder[bid];
}
const int loop_times = div_up(cache_len, num_row_per_block);
for (uint32_t tile_id = 0; tile_id < loop_times; tile_id++) {
@@ -42,6 +43,7 @@ __global__ void pre_cache_len_concat(const int* __restrict__ seq_lens_decoder,
tile_ids_per_batch[index++] = tile_id;
}
gridx += loop_times;
const int q_len = seq_lens_this_time[bid];
total_tokens += (cache_len + q_len);
cu_seqlens_k[bid + 1] = total_tokens;
}
@@ -51,6 +53,7 @@ __global__ void pre_cache_len_concat(const int* __restrict__ seq_lens_decoder,
}
std::vector<paddle::Tensor> PreCacheLenConcat(
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& seq_lens_this_time,
const int max_dec_len,
@@ -58,45 +61,43 @@ std::vector<paddle::Tensor> PreCacheLenConcat(
auto stream = seq_lens_decoder.stream();
auto place = seq_lens_decoder.place();
int bsz = seq_lens_this_time.shape()[0];
const uint32_t max_tile_size_per_bs_pre_cache = div_up(max_dec_len, block_size);
const uint32_t max_tile_size_per_bs_pre_cache =
div_up(max_dec_len, block_size);
paddle::Tensor cu_seqlens_k = GetEmptyTensor(
{bsz + 1},
paddle::DataType::INT32,
place);
paddle::Tensor cu_seqlens_k =
GetEmptyTensor({bsz + 1}, paddle::DataType::INT32, place);
paddle::Tensor pre_cache_batch_ids = GetEmptyTensor(
{bsz * max_tile_size_per_bs_pre_cache},
paddle::DataType::INT32,
place);
{bsz * max_tile_size_per_bs_pre_cache}, paddle::DataType::INT32, place);
paddle::Tensor pre_cache_tile_ids_per_batch = GetEmptyTensor(
{bsz * max_tile_size_per_bs_pre_cache},
paddle::DataType::INT32,
place);
{bsz * max_tile_size_per_bs_pre_cache}, paddle::DataType::INT32, place);
paddle::Tensor pre_cache_num_blocks =
GetEmptyTensor({1}, paddle::DataType::INT32, place);
GetEmptyTensor({1}, paddle::DataType::INT32, place);
paddle::Tensor kv_token_num =
GetEmptyTensor({1}, paddle::DataType::INT32, place);
GetEmptyTensor({1}, paddle::DataType::INT32, place);
pre_cache_len_concat<<<1, 32, 0, stream>>>(
seq_lens_decoder.data<int>(),
seq_lens_this_time.data<int>(),
cu_seqlens_k.data<int>(),
pre_cache_batch_ids.data<int>(),
pre_cache_tile_ids_per_batch.data<int>(),
pre_cache_num_blocks.data<int>(),
kv_token_num.data<int>(),
bsz,
block_size
);
paddle::Tensor pre_cache_num_blocks_cpu = pre_cache_num_blocks.copy_to(paddle::CPUPlace(), false);
paddle::Tensor kv_token_num_cpu = kv_token_num.copy_to(paddle::CPUPlace(), false);
seq_lens_encoder.data<int>(),
seq_lens_decoder.data<int>(),
seq_lens_this_time.data<int>(),
cu_seqlens_k.data<int>(),
pre_cache_batch_ids.data<int>(),
pre_cache_tile_ids_per_batch.data<int>(),
pre_cache_num_blocks.data<int>(),
kv_token_num.data<int>(),
bsz,
block_size);
paddle::Tensor pre_cache_num_blocks_cpu =
pre_cache_num_blocks.copy_to(paddle::CPUPlace(), false);
paddle::Tensor kv_token_num_cpu =
kv_token_num.copy_to(paddle::CPUPlace(), false);
return {cu_seqlens_k,
pre_cache_batch_ids,
pre_cache_tile_ids_per_batch,
pre_cache_num_blocks_cpu, /*cpu*/
kv_token_num_cpu /*cpu*/
};
return {
cu_seqlens_k,
pre_cache_batch_ids,
pre_cache_tile_ids_per_batch,
pre_cache_num_blocks_cpu, /*cpu*/
kv_token_num_cpu /*cpu*/
};
}
std::vector<paddle::DataType> PreCacheLenConcatInferDtype(
@@ -121,15 +122,13 @@ std::vector<std::vector<int64_t>> PreCacheLenConcatInferShape(
}
PD_BUILD_STATIC_OP(pre_cache_len_concat)
.Inputs({"seq_lens_decoder",
"seq_lens_this_time"})
.Inputs({"seq_lens_encoder", "seq_lens_decoder", "seq_lens_this_time"})
.Outputs({"cu_seqlens_k",
"pre_cache_batch_ids",
"pre_cache_tile_ids_per_batch",
"pre_cache_num_blocks_cpu", /*cpu*/
"kv_token_num_cpu"}) /*cpu*/
.Attrs({"max_dec_len: int",
"block_size: int"})
"kv_token_num_cpu"}) /*cpu*/
.Attrs({"max_dec_len: int", "block_size: int"})
.SetKernelFn(PD_KERNEL(PreCacheLenConcat))
.SetInferShapeFn(PD_INFER_SHAPE(PreCacheLenConcatInferShape))
.SetInferDtypeFn(PD_INFER_DTYPE(PreCacheLenConcatInferDtype));
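
The kernel's host-visible behavior can be summarized in plain Python: only batches that still have encoder (prefill) tokens contribute a cached prefix, and `cu_seqlens_k` accumulates cache plus current tokens per batch. This is an illustrative reference, not the exported op:

```python
def pre_cache_len_concat_ref(seq_lens_encoder, seq_lens_decoder,
                             seq_lens_this_time, num_row_per_block):
    """Illustrative CPU reference for the pre_cache_len_concat kernel above."""
    cu_seqlens_k = [0]
    batch_ids, tile_ids = [], []
    total_tokens = 0
    for bid in range(len(seq_lens_this_time)):
        # Only chunked-prefill batches (encoder tokens remaining) keep their cached prefix.
        cache_len = seq_lens_decoder[bid] if seq_lens_encoder[bid] > 0 else 0
        loop_times = (cache_len + num_row_per_block - 1) // num_row_per_block
        for tile_id in range(loop_times):
            batch_ids.append(bid)      # one (batch, tile) pair per pre-cache block
            tile_ids.append(tile_id)
        total_tokens += cache_len + seq_lens_this_time[bid]
        cu_seqlens_k.append(total_tokens)
    # num_blocks_x and kv_token_num come back as 1-element CPU tensors from the real op.
    return cu_seqlens_k, batch_ids, tile_ids, len(batch_ids), total_tokens
```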

View File

@@ -194,6 +194,7 @@ std::vector<paddle::Tensor> GQARopeWriteCacheKernel(
const bool rope_3d);
std::vector<paddle::Tensor> PreCacheLenConcat(
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& seq_lens_this_time,
const int max_dec_len,

View File

@@ -206,20 +206,9 @@ class AppendAttentionBackend(AttentionBackend):
Calculate kv cache shape
"""
key_cache_shape = [max_num_blocks, self.kv_num_heads, self.block_size, self.head_dim]
value_cache_shape = [max_num_blocks, self.kv_num_heads, self.block_size, self.head_dim]
if kv_cache_quant_type is not None and kv_cache_quant_type == "int4_zp":
key_cache_shape = [
max_num_blocks,
self.kv_num_heads,
self.block_size,
self.head_dim // 2,
]
value_cache_shape = [
max_num_blocks,
self.kv_num_heads,
self.block_size,
self.head_dim // 2,
]
key_cache_shape[-1] = self.head_dim // 2
value_cache_shape = key_cache_shape
return key_cache_shape, value_cache_shape
def forward_mixed(
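
The shortened branch produces the same shapes as before, since only the last dimension changes for the packed int4 layout. A minimal illustration with assumed example values:

```python
# Assumed example values: 1024 blocks, 8 KV heads, block size 64, head_dim 128.
max_num_blocks, kv_num_heads, block_size, head_dim = 1024, 8, 64, 128
kv_cache_quant_type = "int4_zp"

key_cache_shape = [max_num_blocks, kv_num_heads, block_size, head_dim]
value_cache_shape = [max_num_blocks, kv_num_heads, block_size, head_dim]
if kv_cache_quant_type is not None and kv_cache_quant_type == "int4_zp":
    key_cache_shape[-1] = head_dim // 2   # two int4 values per byte, so half the last dim
    value_cache_shape = key_cache_shape   # identical shape; sharing the list is harmless here
print(key_cache_shape, value_cache_shape)  # [1024, 8, 64, 64] [1024, 8, 64, 64]
```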

View File

@@ -63,13 +63,7 @@ class FlashAttentionMetadata(AttentionMetadata):
FlashAttentionMetadata
"""
rotary_embs: Optional[paddle.Tensor] = None
block_tables: Optional[paddle.Tensor] = None
cu_seqlens_q: paddle.Tensor = None
cu_seqlens_k: paddle.Tensor = None
max_seqlen_q: int = 0
max_seqlen_k: int = 0
pre_cache_batch_ids = None
pre_cache_tile_ids_per_batch = None
@@ -83,7 +77,6 @@ class FlashAttentionMetadata(AttentionMetadata):
_fuse_kernel_compute_dtype: str = "bf16"
_dtype: paddle.dtype = paddle.bfloat16
max_len_tensor_cpu: paddle.Tensor = None
max_len_tensor_cpu_decoder: paddle.Tensor = None
@@ -133,9 +126,6 @@ class FlashAttentionBackend(AttentionBackend):
self.start_layer_index: int = fd_config.model_config.start_layer_index
if fd_config.parallel_config.expert_parallel_rank is None:
fd_config.parallel_config.expert_parallel_rank = 0
self.rank, self.device_id = init_rank_and_device_id(fd_config)
if self.flash_attn_func is None:
@@ -154,7 +144,8 @@ class FlashAttentionBackend(AttentionBackend):
"The current platform does not support Flash Attention V3, so Flash Attention V2 will be used instead."
)
self.rope_3d: bool = getattr(fd_config.model_config, "rope_3d", False)
self.max_partition_size: int = int(os.getenv("FLAGS_max_partition_size", "32768"))
# Note(ZKK): this must be consistent with append_attn_backend.py
self.max_partition_size: int = int(os.getenv("FLAGS_max_partition_size", 1024))
self.zero_seq_enc_lens_for_decode = paddle.zeros(
shape=[fd_config.scheduler_config.max_num_seqs, 1], dtype=paddle.int32
)
@@ -172,27 +163,13 @@ class FlashAttentionBackend(AttentionBackend):
Calculate kv cache shape
"""
key_cache_shape = [max_num_blocks, self.kv_num_heads, self.block_size, self.head_dim]
value_cache_shape = [max_num_blocks, self.kv_num_heads, self.block_size, self.head_dim]
if kv_cache_quant_type is not None and kv_cache_quant_type == "int4_zp":
key_cache_shape = [
max_num_blocks,
self.kv_num_heads,
self.block_size,
self.head_dim // 2,
]
value_cache_shape = [
max_num_blocks,
self.kv_num_heads,
self.block_size,
self.head_dim // 2,
]
key_cache_shape[-1] = self.head_dim // 2
value_cache_shape = key_cache_shape
return key_cache_shape, value_cache_shape
def init_attention_metadata(self, forward_meta: ForwardMeta):
metadata = FlashAttentionMetadata()
metadata.cu_seqlens_q = forward_meta.cu_seqlens_q
metadata.rotary_embs = forward_meta.rotary_embs
metadata.block_tables = forward_meta.block_tables
get_block_shape_and_split_kv_block(
forward_meta.seq_lens_encoder,
forward_meta.seq_lens_decoder,
@@ -215,18 +192,20 @@ class FlashAttentionBackend(AttentionBackend):
self.block_size,
)
(
metadata.cu_seqlens_k,
metadata.pre_cache_batch_ids,
metadata.pre_cache_tile_ids_per_batch,
metadata.pre_cache_num_blocks_cpu,
metadata.kv_token_num_cpu,
) = pre_cache_len_concat(
forward_meta.seq_lens_decoder,
forward_meta.seq_lens_this_time,
forward_meta.max_len_tensor_cpu[2],
self.block_size,
)
if forward_meta.max_len_tensor_cpu[1] > 0:
(
metadata.cu_seqlens_k,
metadata.pre_cache_batch_ids,
metadata.pre_cache_tile_ids_per_batch,
metadata.pre_cache_num_blocks_cpu,
metadata.kv_token_num_cpu,
) = pre_cache_len_concat(
forward_meta.seq_lens_encoder,
forward_meta.seq_lens_decoder,
forward_meta.seq_lens_this_time,
forward_meta.max_len_tensor_cpu[2],
self.block_size,
)
# pd_disaggregation
metadata.kv_signal_data_list = [None] * self.num_layers
@@ -251,8 +230,7 @@ class FlashAttentionBackend(AttentionBackend):
elif metadata._dtype == "float32":
metadata._fuse_kernel_compute_dtype = "fp32"
metadata.max_len_tensor_cpu = forward_meta.max_len_tensor_cpu
metadata.max_len_tensor_cpu_decoder = paddle.clone(metadata.max_len_tensor_cpu)
metadata.max_len_tensor_cpu_decoder = paddle.clone(forward_meta.max_len_tensor_cpu)
metadata.max_len_tensor_cpu_decoder[1] = 0
self.attention_metadata = metadata
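
When FA3 handles the prefill tokens, the append-attention decode call sees a "decode-only" view of the step: `max_len_tensor_cpu_decoder` is a clone with the encoder entry zeroed, and `zero_seq_enc_lens_for_decode` stands in for the per-sequence encoder lengths. A condensed sketch of that masking (the standalone helper is hypothetical):

```python
import paddle

def make_decode_only_view(max_len_tensor_cpu, max_num_seqs):
    """Hypothetical helper mirroring what the backend stores on metadata / self."""
    # Index 1 holds the max encoder (prefill) length for this step; zeroing it makes
    # the append-attention kernel skip the prefill path entirely.
    max_len_tensor_cpu_decoder = paddle.clone(max_len_tensor_cpu)
    max_len_tensor_cpu_decoder[1] = 0
    # Per-sequence encoder lengths are likewise replaced with zeros for the decode call.
    zero_seq_enc_lens_for_decode = paddle.zeros(shape=[max_num_seqs, 1], dtype=paddle.int32)
    return max_len_tensor_cpu_decoder, zero_seq_enc_lens_for_decode
```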
@@ -276,19 +254,21 @@ class FlashAttentionBackend(AttentionBackend):
layer.layer_id + self.start_layer_index,
)
if metadata.max_len_tensor_cpu[1] > 0:
use_fa_do_prefill = forward_meta.max_len_tensor_cpu[1].item() > 0
if use_fa_do_prefill:
q, k, v, _ = gqa_rope_write_cache(
qkv,
forward_meta.caches[2 * layer.layer_id],
forward_meta.caches[2 * layer.layer_id + 1],
metadata.cu_seqlens_q,
forward_meta.cu_seqlens_q,
metadata.cu_seqlens_k,
metadata.rotary_embs,
forward_meta.rotary_embs,
forward_meta.seq_lens_this_time,
forward_meta.seq_lens_encoder,
forward_meta.seq_lens_decoder,
forward_meta.batch_id_per_token,
metadata.block_tables,
forward_meta.block_tables,
forward_meta.kv_batch_ids,
forward_meta.kv_tile_ids_per_batch,
forward_meta.kv_num_blocks_x_cpu,
@@ -315,7 +295,7 @@ class FlashAttentionBackend(AttentionBackend):
q,
k,
v,
metadata.cu_seqlens_q,
forward_meta.cu_seqlens_q,
metadata.cu_seqlens_k,
max_seqlen_q=forward_meta.max_len_tensor_cpu[0],
max_seqlen_k=forward_meta.max_len_tensor_cpu[3],
@@ -327,23 +307,23 @@ class FlashAttentionBackend(AttentionBackend):
qkv,
forward_meta.caches[2 * layer.layer_id],
forward_meta.caches[2 * layer.layer_id + 1],
self.zero_seq_enc_lens_for_decode,
self.zero_seq_enc_lens_for_decode if use_fa_do_prefill else forward_meta.seq_lens_encoder,
forward_meta.seq_lens_decoder,
forward_meta.seq_lens_this_time,
forward_meta.batch_id_per_token,
forward_meta.cu_seqlens_q,
metadata.block_tables,
forward_meta.block_tables,
forward_meta.encoder_batch_ids,
forward_meta.encoder_tile_ids_per_batch,
forward_meta.encoder_num_blocks_x_cpu,
forward_meta.kv_batch_ids,
forward_meta.kv_tile_ids_per_batch,
forward_meta.kv_num_blocks_x_cpu,
forward_meta.decoder_batch_ids, # from buffer
forward_meta.decoder_tile_ids_per_batch, # from buffer
forward_meta.decoder_batch_ids,
forward_meta.decoder_tile_ids_per_batch,
forward_meta.decoder_num_blocks_cpu,
metadata.max_len_tensor_cpu_decoder,
metadata.rotary_embs,
metadata.max_len_tensor_cpu_decoder if use_fa_do_prefill else forward_meta.max_len_tensor_cpu,
forward_meta.rotary_embs,
forward_meta.attn_mask,
layer.qkv_bias,
layer.qkv_scale,
@@ -378,7 +358,7 @@ class FlashAttentionBackend(AttentionBackend):
self.speculative_method is not None,
)
if metadata.max_len_tensor_cpu[1] > 0:
if use_fa_do_prefill:
merge_prefill_decode_output(
res_encoder,
res_decoder,

View File

@@ -24,6 +24,7 @@ from fastdeploy.platforms import current_platform
def pre_cache_len_concat(
seq_lens_encoder: paddle.Tensor,
seq_lens_decoder: paddle.Tensor,
seq_lens_this_time: paddle.Tensor,
max_dec_len: int = 0,
@@ -32,7 +33,7 @@ def pre_cache_len_concat(
if current_platform.is_cuda():
from fastdeploy.model_executor.ops.gpu import pre_cache_len_concat
out = pre_cache_len_concat(seq_lens_decoder, seq_lens_this_time, max_dec_len, block_size)
out = pre_cache_len_concat(seq_lens_encoder, seq_lens_decoder, seq_lens_this_time, max_dec_len, block_size)
return out
else:
raise NotImplementedError
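
A call through the updated wrapper then looks like this (the tensor values are made-up, and only the wrapper above is assumed to be in scope):

```python
import paddle
# assuming the wrapper above is importable / in scope as `pre_cache_len_concat`

seq_lens_encoder = paddle.to_tensor([128, 0], dtype="int32")    # batch 1 is decode-only
seq_lens_decoder = paddle.to_tensor([512, 640], dtype="int32")  # cached prefix lengths
seq_lens_this_time = paddle.to_tensor([128, 1], dtype="int32")  # tokens processed this step

(cu_seqlens_k, pre_cache_batch_ids, pre_cache_tile_ids_per_batch,
 pre_cache_num_blocks_cpu, kv_token_num_cpu) = pre_cache_len_concat(
    seq_lens_encoder, seq_lens_decoder, seq_lens_this_time, 1024, 64)

# Only batch 0 (still prefilling) contributes pre-cache tiles;
# cu_seqlens_k should come back as [0, 512 + 128, 640 + 0 + 1] == [0, 640, 641].
```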

View File

@@ -71,7 +71,6 @@ class TestAttentionPerformance(unittest.TestCase):
self.fd_config.parallel_config.tp_group = [0]
# Initialize Attention Layer
os.environ["FD_ATTENTION_BACKEND"] = "APPEND_ATTN"
attn_cls = get_attention_backend()
self.attn_backend = attn_cls(
self.fd_config,
@@ -123,10 +122,10 @@ class TestAttentionPerformance(unittest.TestCase):
"max_position_embeddings": 131072,
"max_model_len": 131072,
"head_dim": 128,
"hidden_size": 4096,
"num_attention_heads": 32,
"num_key_value_heads": 4,
"num_hidden_layers": 57,
"hidden_size": 8192,
"num_attention_heads": 64,
"num_key_value_heads": 8,
"num_hidden_layers": 2,
}
model_dir = tempfile.mkdtemp(prefix="tmp_model_config_")
config_path = os.path.join(model_dir, "config.json")
@@ -158,6 +157,7 @@ class TestAttentionPerformance(unittest.TestCase):
dense_quant_type="block_wise_fp8",
moe_quant_type="block_wise_fp8",
kv_cache_quant_type="float8_e4m3fn",
# kv_cache_quant_type=None,
),
graph_opt_config=GraphOptimizationConfig({}),
commit_config=CommitConfig(),
@@ -270,7 +270,7 @@ class TestAttentionPerformance(unittest.TestCase):
partial_rotary_factor=fd_config.model_config.partial_rotary_factor,
)
input_ids = paddle.zeros([batch_size, seq_len if mode == ForwardMode.EXTEND else 1], dtype="int64")
input_ids = paddle.zeros([batch_size, max_model_len], dtype="int64")
token_num = paddle.sum(seq_lens_this_time)
ids_remove_padding, batch_id_per_token, cu_seqlens_q, cu_seqlens_k = get_padding_offset(
input_ids, token_num, seq_lens_this_time
@@ -294,12 +294,13 @@ class TestAttentionPerformance(unittest.TestCase):
attn_mask_offsets=None,
**attn_backend_buffers,
)
return forward_meta
hidden_states = paddle.randn([token_num, self.fd_config.model_config.hidden_size], dtype="bfloat16")
return forward_meta, hidden_states
def test_decode_performance_with_prefill(self):
# Test parameters
test_steps = 100
act_tensor_dtype = paddle.bfloat16
# prefill_batch_size = 1
# prefill_seq_len = 4096
@@ -356,11 +357,7 @@ class TestAttentionPerformance(unittest.TestCase):
# p.step()
for decode_batch_size in [32, 16, 8, 4, 2]:
decode_hidden_states = paddle.randn(
[decode_batch_size, self.fd_config.model_config.hidden_size], dtype=act_tensor_dtype
)
forward_meta = self.create_forward_meta(
forward_meta, hidden_states = self.create_forward_meta(
batch_size=decode_batch_size,
seq_len=36 * 1024,
mode=ForwardMode.DECODE,
@@ -374,12 +371,12 @@ class TestAttentionPerformance(unittest.TestCase):
paddle.device.synchronize()
# Must warm up once first, because preprocessing is now deferred to the first layer!
self.attn_forward(forward_meta, decode_hidden_states)
self.attn_forward(forward_meta, hidden_states)
attn_cuda_graphs = graphs.CUDAGraph()
attn_cuda_graphs.capture_begin()
self.attn_forward(forward_meta, decode_hidden_states)
self.attn_forward(forward_meta, hidden_states)
attn_cuda_graphs.capture_end()

View File

@@ -69,7 +69,10 @@ class TestPreCacheLenConcat(unittest.TestCase):
seq_lens_decoder_t = paddle.to_tensor(seq_lens_decoder, dtype="int32")
seq_lens_this_time_t = paddle.to_tensor(seq_lens_this_time, dtype="int32")
outputs = pre_cache_len_concat(seq_lens_decoder_t, seq_lens_this_time_t, max_dec_len, block_size)
seq_lens_encoder_t = seq_lens_this_time_t
outputs = pre_cache_len_concat(
seq_lens_encoder_t, seq_lens_decoder_t, seq_lens_this_time_t, max_dec_len, block_size
)
cu_seqlens_k, batch_ids, tile_ids, num_blocks, kv_token_num = [out.numpy() for out in outputs]
# Shape checks
@@ -91,8 +94,11 @@ class TestPreCacheLenConcat(unittest.TestCase):
seq_lens_decoder_t = paddle.to_tensor(seq_lens_decoder, dtype="int32")
seq_lens_this_time_t = paddle.to_tensor(seq_lens_this_time, dtype="int32")
seq_lens_encoder_t = seq_lens_this_time_t
outputs = pre_cache_len_concat(seq_lens_decoder_t, seq_lens_this_time_t, max_dec_len, block_size)
outputs = pre_cache_len_concat(
seq_lens_encoder_t, seq_lens_decoder_t, seq_lens_this_time_t, max_dec_len, block_size
)
cu_seqlens_k, batch_ids, tile_ids, num_blocks, kv_token_num = [out.numpy() for out in outputs]
# Reference implementation