diff --git a/custom_ops/gpu_ops/append_attn/gqa_rope_write_cache.cu b/custom_ops/gpu_ops/append_attn/gqa_rope_write_cache.cu
index 53b7e6266..ffedbe60a 100644
--- a/custom_ops/gpu_ops/append_attn/gqa_rope_write_cache.cu
+++ b/custom_ops/gpu_ops/append_attn/gqa_rope_write_cache.cu
@@ -28,7 +28,7 @@ __global__ void GQAVariableLengthRotarySplitKernel(
     const float *k_norm_weight,
     const int *batch_id_per_token,
     const int *cu_seqlens_q,
-    const int *seq_lens,
+    const int *seq_lens_encoder,
     const int *seq_lens_decoder,
     const int *cu_seqlens_k,
     T *qkv_out,
@@ -38,8 +38,8 @@ __global__ void GQAVariableLengthRotarySplitKernel(
     const int64_t elem_cnt,
     const int q_num_head,
     const int kv_num_head,
-    const int seq_len,
-    const int last_dim,
+    const int max_model_len,
+    const int head_dim,
     const bool rope_3d,
     const float rms_norm_eps) {
   using LoadT = AlignedVector<T, VecSize>;
@@ -53,30 +53,33 @@ __global__ void GQAVariableLengthRotarySplitKernel(
   LoadFloat q_norm_vec, k_norm_vec;
   int64_t global_warp_idx = blockDim.y * blockIdx.x + threadIdx.y;
   int64_t all_warp_num = gridDim.x * blockDim.y;
-  const int half_lastdim = last_dim / 2;
+  const int half_headdim = head_dim / 2;
   const int offset =
-      (q_num_head + kv_num_head * 2) * last_dim;  // for all q,k,v
-  const int all_head_num = elem_cnt / last_dim;
+      (q_num_head + kv_num_head * 2) * head_dim;  // for all q,k,v
+  const int all_head_num = elem_cnt / head_dim;
   for (int gloabl_hi = global_warp_idx; gloabl_hi < all_head_num;
        gloabl_hi += all_warp_num) {
     int64_t linear_index =
-        gloabl_hi * last_dim + threadIdx.x * VecSize;  // global index
+        gloabl_hi * head_dim + threadIdx.x * VecSize;  // global index
     const int token_idx =
         linear_index / offset;  // token id (which token overall, not split by q/k/v)
     const int ori_bi = batch_id_per_token[token_idx];  // batch index
-    if (seq_lens[ori_bi] == 0) continue;
+
+    int cache_kv_len = seq_lens_decoder[ori_bi];
+    // This actually should not need handling, but an FA3 bug makes it mandatory!
+    if (seq_lens_encoder[ori_bi] == 0) cache_kv_len = 0;
+
     const int bias = linear_index % offset;
-    const int hi = bias / last_dim;
-    const int h_bias = bias % last_dim;
+    const int hi = bias / head_dim;
+    const int h_bias = bias % head_dim;
     const int ori_seq_id = (token_idx - cu_seqlens_q[ori_bi]) +
-                           seq_lens_decoder
-                               [ori_bi];  // id within the current seq (valid when seqs are concatenated into one batch)
+                           cache_kv_len;  // id within the current seq (valid when seqs are concatenated into one batch)
     const int64_t emb_idx =
-        ori_seq_id * half_lastdim + h_bias / 2;  // embedding id
+        ori_seq_id * half_headdim + h_bias / 2;  // embedding id
     const int64_t base_idx =
-        token_idx * (q_num_head + 2 * kv_num_head) * last_dim + hi * last_dim +
+        token_idx * (q_num_head + 2 * kv_num_head) * head_dim + hi * head_dim +
         h_bias;
     Load(&qkv[base_idx], &src_vec);
     const int kv_write_idx = cu_seqlens_k[ori_bi] + ori_seq_id;
@@ -84,21 +87,21 @@ __global__ void GQAVariableLengthRotarySplitKernel(
     T *out_p = nullptr;
     if (hi < q_num_head) {
       base_split_idx =
-          token_idx * q_num_head * last_dim + hi * last_dim + h_bias;
+          token_idx * q_num_head * head_dim + hi * head_dim + h_bias;
       out_p = q;
     } else if (hi < q_num_head + kv_num_head) {
-      base_split_idx = kv_write_idx * kv_num_head * last_dim +
-                       (hi - q_num_head) * last_dim + h_bias;
+      base_split_idx = kv_write_idx * kv_num_head * head_dim +
+                       (hi - q_num_head) * head_dim + h_bias;
       out_p = k;
     } else {
       out_p = v;
-      base_split_idx = kv_write_idx * kv_num_head * last_dim +
-                       (hi - q_num_head - kv_num_head) * last_dim + h_bias;
+      base_split_idx = kv_write_idx * kv_num_head * head_dim +
+                       (hi - q_num_head - kv_num_head) * head_dim + h_bias;
     }
 
     // TODO check this correct or not
     int64_t new_emb_idx =
-        rope_3d ? emb_idx + ori_bi * last_dim * seq_len : emb_idx;
+        rope_3d ? emb_idx + ori_bi * head_dim * max_model_len : emb_idx;
 
     float thread_m2 = 0.0f;
     float warp_m2 = 0.0f;
@@ -122,7 +125,7 @@ __global__ void GQAVariableLengthRotarySplitKernel(
     WelfordWarpAllReduce(thread_m2, &warp_m2);  // standard deviation of a single head
     if (hi < q_num_head + kv_num_head) {  // only q and k need norm
-      float row_variance = max(warp_m2 / last_dim, 0.0f);
+      float row_variance = max(warp_m2 / head_dim, 0.0f);
       float row_inv_var = Rsqrt(row_variance + rms_norm_eps);
       if (hi < q_num_head) {
         Load(&q_norm_weight[threadIdx.x * VecSize],
@@ -165,12 +168,12 @@ __global__ void GQAVariableLengthRotarySplitKernel(
 template <typename T>
 void gqa_rotary_qk_split_variable(
-    T *qkv_out,  // [token_num, 3, num_head, dim_head]
+    T *qkv_out,  // [token_num, 3, num_head, head_dim]
     T *q,
     T *k,
     T *v,
     const T *qkv_input,
-    const float *rotary_emb,  // [2, 1, 1, seq_len, dim_head / 2]
+    const float *rotary_emb,  // [2, 1, 1, seq_len, head_dim / 2]
     const float *q_norm_weight,
     const float *k_norm_weight,
     const int *batch_id_per_token,
@@ -181,14 +184,14 @@ void gqa_rotary_qk_split_variable(
     const int token_num,
     const int num_heads,
     const int kv_num_heads,
-    const int seq_len,
+    const int max_model_len,
     const int input_output_len,
-    const int dim_head,
+    const int head_dim,
     const bool rope_3d,
     const float rms_norm_eps,
     const cudaStream_t &stream) {
-  assert(dim_head == 128 && "dim_head must be 128");
-  int64_t elem_nums = token_num * (num_heads + 2 * kv_num_heads) * dim_head;
+  assert(head_dim == 128 && "head_dim must be 128");
+  int64_t elem_nums = token_num * (num_heads + 2 * kv_num_heads) * head_dim;
   constexpr int HEAD_DIM = 128;
   constexpr int PackSize = HEAD_DIM / kWarpSize;
@@ -199,7 +202,7 @@ void gqa_rotary_qk_split_variable(
   dim3 block_size(kWarpSize, blocksize / kWarpSize);
 
   const float *cos_emb = rotary_emb;
-  const float *sin_emb =
-      rotary_emb + input_output_len * dim_head / 2;
+  const float *sin_emb = rotary_emb + input_output_len * head_dim / 2;
   launchWithPdlWhenEnabled(GQAVariableLengthRotarySplitKernel<T, PackSize>,
                            grid_size,
                            block_size,
@@ -222,8 +225,8 @@ void gqa_rotary_qk_split_variable(
                            elem_nums,
                            num_heads,
                            kv_num_heads,
-                           seq_len,
-                           dim_head,
+                           max_model_len,
+                           head_dim,
                            rope_3d,
                            rms_norm_eps);
 }
@@ -1163,9 +1166,6 @@ std::vector<paddle::Tensor> GQARopeWriteCacheKernel(
   meta_data.block_size = block_size;
   meta_data.batch_size = seq_lens_this_time.dims()[0];
 
-  phi::GPUContext *dev_ctx = static_cast<phi::GPUContext *>(
-      phi::DeviceContextPool::Instance().Get(qkv.place()));
-
   auto stream = qkv.stream();
   paddle::Tensor qkv_out = GetEmptyTensor(qkv.dims(), qkv.dtype(), qkv.place());
   paddle::Tensor q = GetEmptyTensor(
diff --git a/custom_ops/gpu_ops/append_attn/pre_cache_len_concat.cu b/custom_ops/gpu_ops/append_attn/pre_cache_len_concat.cu
index 15da09e08..492b3a266 100644
--- a/custom_ops/gpu_ops/append_attn/pre_cache_len_concat.cu
+++ b/custom_ops/gpu_ops/append_attn/pre_cache_len_concat.cu
@@ -16,25 +16,26 @@
 #include "paddle/extension.h"
 #include "paddle/phi/core/memory/memcpy.h"
 
-__global__ void pre_cache_len_concat(const int* __restrict__ seq_lens_decoder,
-                                     const int* __restrict__ seq_lens_this_time,
-                                     int* __restrict__ cu_seqlens_k,
-                                     int* __restrict__ batch_ids,
-                                     int* __restrict__ tile_ids_per_batch,
-                                     int* __restrict__ num_blocks_x,
-                                     int* __restrict__ kv_token_num,
-                                     const int bsz,
-                                     const int num_row_per_block) {
+__global__ void pre_cache_len_concat(const int* __restrict__ seq_lens_encoder,
+                                     const int* __restrict__ seq_lens_decoder,
+                                     const int* __restrict__ seq_lens_this_time,
+                                     int* __restrict__ cu_seqlens_k,
+                                     int* __restrict__ batch_ids,
+                                     int* __restrict__ tile_ids_per_batch,
+                                     int* __restrict__ num_blocks_x,
+                                     int* __restrict__ kv_token_num,
+                                     const int bsz,
+                                     const int num_row_per_block) {
   if (threadIdx.x == 0) {
     int gridx = 0;
     int index = 0;
     int total_tokens = 0;
     cu_seqlens_k[0] = 0;
     for (uint32_t bid = 0; bid < bsz; bid++) {
-      int cache_len = seq_lens_decoder[bid];
-      const int q_len = seq_lens_this_time[bid];
-      if (q_len <= 0) {
-        cache_len = 0;
+      int cache_len = 0;
+      if (seq_lens_encoder[bid] > 0) {
+        // Only handle the chunked prefill case.
+        cache_len = seq_lens_decoder[bid];
       }
       const int loop_times = div_up(cache_len, num_row_per_block);
       for (uint32_t tile_id = 0; tile_id < loop_times; tile_id++) {
@@ -42,6 +43,7 @@ __global__ void pre_cache_len_concat(const int* __restrict__ seq_lens_decoder,
         tile_ids_per_batch[index++] = tile_id;
       }
       gridx += loop_times;
+      const int q_len = seq_lens_this_time[bid];
       total_tokens += (cache_len + q_len);
       cu_seqlens_k[bid + 1] = total_tokens;
     }
@@ -51,6 +53,7 @@ __global__ void pre_cache_len_concat(const int* __restrict__ seq_lens_decoder,
 }
 
 std::vector<paddle::Tensor> PreCacheLenConcat(
+    const paddle::Tensor& seq_lens_encoder,
     const paddle::Tensor& seq_lens_decoder,
     const paddle::Tensor& seq_lens_this_time,
     const int max_dec_len,
@@ -58,45 +61,43 @@ std::vector<paddle::Tensor> PreCacheLenConcat(
   auto stream = seq_lens_decoder.stream();
   auto place = seq_lens_decoder.place();
   int bsz = seq_lens_this_time.shape()[0];
-  const uint32_t max_tile_size_per_bs_pre_cache = div_up(max_dec_len, block_size);
+  const uint32_t max_tile_size_per_bs_pre_cache =
+      div_up(max_dec_len, block_size);
 
-  paddle::Tensor cu_seqlens_k = GetEmptyTensor(
-      {bsz + 1},
-      paddle::DataType::INT32,
-      place);
+  paddle::Tensor cu_seqlens_k =
+      GetEmptyTensor({bsz + 1}, paddle::DataType::INT32, place);
   paddle::Tensor pre_cache_batch_ids = GetEmptyTensor(
-      {bsz * max_tile_size_per_bs_pre_cache},
-      paddle::DataType::INT32,
-      place);
+      {bsz * max_tile_size_per_bs_pre_cache}, paddle::DataType::INT32, place);
   paddle::Tensor pre_cache_tile_ids_per_batch = GetEmptyTensor(
-      {bsz * max_tile_size_per_bs_pre_cache},
-      paddle::DataType::INT32,
-      place);
+      {bsz * max_tile_size_per_bs_pre_cache}, paddle::DataType::INT32, place);
   paddle::Tensor pre_cache_num_blocks =
-      GetEmptyTensor({1}, paddle::DataType::INT32, place);
+      GetEmptyTensor({1}, paddle::DataType::INT32, place);
   paddle::Tensor kv_token_num =
-      GetEmptyTensor({1}, paddle::DataType::INT32, place);
+      GetEmptyTensor({1}, paddle::DataType::INT32, place);
 
   pre_cache_len_concat<<<1, 32, 0, stream>>>(
-      seq_lens_decoder.data<int>(),
-      seq_lens_this_time.data<int>(),
-      cu_seqlens_k.data<int>(),
-      pre_cache_batch_ids.data<int>(),
-      pre_cache_tile_ids_per_batch.data<int>(),
-      pre_cache_num_blocks.data<int>(),
-      kv_token_num.data<int>(),
-      bsz,
-      block_size
-  );
-  paddle::Tensor pre_cache_num_blocks_cpu = pre_cache_num_blocks.copy_to(paddle::CPUPlace(), false);
-  paddle::Tensor kv_token_num_cpu = kv_token_num.copy_to(paddle::CPUPlace(), false);
+      seq_lens_encoder.data<int>(),
+      seq_lens_decoder.data<int>(),
+      seq_lens_this_time.data<int>(),
+      cu_seqlens_k.data<int>(),
+      pre_cache_batch_ids.data<int>(),
+      pre_cache_tile_ids_per_batch.data<int>(),
+      pre_cache_num_blocks.data<int>(),
+      kv_token_num.data<int>(),
+      bsz,
+      block_size);
+  paddle::Tensor pre_cache_num_blocks_cpu =
+      pre_cache_num_blocks.copy_to(paddle::CPUPlace(), false);
+  paddle::Tensor kv_token_num_cpu =
+      kv_token_num.copy_to(paddle::CPUPlace(), false);
 
-  return {cu_seqlens_k,
-          pre_cache_batch_ids,
-          pre_cache_tile_ids_per_batch,
-          pre_cache_num_blocks_cpu, /*cpu*/
-          kv_token_num_cpu /*cpu*/
-  };
+  return {
+      cu_seqlens_k,
+      pre_cache_batch_ids,
+      pre_cache_tile_ids_per_batch,
+      pre_cache_num_blocks_cpu, /*cpu*/
+      kv_token_num_cpu /*cpu*/
+  };
 }
 
 std::vector<paddle::DataType> PreCacheLenConcatInferDtype(
@@ -121,15 +122,13 @@ std::vector<std::vector<int64_t>> PreCacheLenConcatInferShape(
 }
 
 PD_BUILD_STATIC_OP(pre_cache_len_concat)
-    .Inputs({"seq_lens_decoder",
-             "seq_lens_this_time"})
+    .Inputs({"seq_lens_encoder", "seq_lens_decoder", "seq_lens_this_time"})
     .Outputs({"cu_seqlens_k",
               "pre_cache_batch_ids",
               "pre_cache_tile_ids_per_batch",
              "pre_cache_num_blocks_cpu", /*cpu*/
- "kv_token_num_cpu"}) /*cpu*/ - .Attrs({"max_dec_len: int", - "block_size: int"}) + "kv_token_num_cpu"}) /*cpu*/ + .Attrs({"max_dec_len: int", "block_size: int"}) .SetKernelFn(PD_KERNEL(PreCacheLenConcat)) .SetInferShapeFn(PD_INFER_SHAPE(PreCacheLenConcatInferShape)) .SetInferDtypeFn(PD_INFER_DTYPE(PreCacheLenConcatInferDtype)); diff --git a/custom_ops/gpu_ops/cpp_extensions.cc b/custom_ops/gpu_ops/cpp_extensions.cc index abf16db95..c52971472 100644 --- a/custom_ops/gpu_ops/cpp_extensions.cc +++ b/custom_ops/gpu_ops/cpp_extensions.cc @@ -194,6 +194,7 @@ std::vector GQARopeWriteCacheKernel( const bool rope_3d); std::vector PreCacheLenConcat( + const paddle::Tensor& seq_lens_encoder, const paddle::Tensor& seq_lens_decoder, const paddle::Tensor& seq_lens_this_time, const int max_dec_len, diff --git a/fastdeploy/model_executor/layers/attention/append_attn_backend.py b/fastdeploy/model_executor/layers/attention/append_attn_backend.py index 346251a30..4608bd81e 100644 --- a/fastdeploy/model_executor/layers/attention/append_attn_backend.py +++ b/fastdeploy/model_executor/layers/attention/append_attn_backend.py @@ -206,20 +206,9 @@ class AppendAttentionBackend(AttentionBackend): Calculate kv cache shape """ key_cache_shape = [max_num_blocks, self.kv_num_heads, self.block_size, self.head_dim] - value_cache_shape = [max_num_blocks, self.kv_num_heads, self.block_size, self.head_dim] if kv_cache_quant_type is not None and kv_cache_quant_type == "int4_zp": - key_cache_shape = [ - max_num_blocks, - self.kv_num_heads, - self.block_size, - self.head_dim // 2, - ] - value_cache_shape = [ - max_num_blocks, - self.kv_num_heads, - self.block_size, - self.head_dim // 2, - ] + key_cache_shape[-1] = self.head_dim // 2 + value_cache_shape = key_cache_shape return key_cache_shape, value_cache_shape def forward_mixed( diff --git a/fastdeploy/model_executor/layers/attention/flash_attn_backend.py b/fastdeploy/model_executor/layers/attention/flash_attn_backend.py index 3f570aacf..951f6621b 100644 --- a/fastdeploy/model_executor/layers/attention/flash_attn_backend.py +++ b/fastdeploy/model_executor/layers/attention/flash_attn_backend.py @@ -63,13 +63,7 @@ class FlashAttentionMetadata(AttentionMetadata): FlashAttentionMetadata """ - rotary_embs: Optional[paddle.Tensor] = None - block_tables: Optional[paddle.Tensor] = None - - cu_seqlens_q: paddle.Tensor = None cu_seqlens_k: paddle.Tensor = None - max_seqlen_q: int = 0 - max_seqlen_k: int = 0 pre_cache_batch_ids = None pre_cache_tile_ids_per_batch = None @@ -83,7 +77,6 @@ class FlashAttentionMetadata(AttentionMetadata): _fuse_kernel_compute_dtype: str = "bf16" _dtype: paddle.dtype = paddle.bfloat16 - max_len_tensor_cpu: paddle.Tensor = None max_len_tensor_cpu_decoder: paddle.Tensor = None @@ -133,9 +126,6 @@ class FlashAttentionBackend(AttentionBackend): self.start_layer_index: int = fd_config.model_config.start_layer_index - if fd_config.parallel_config.expert_parallel_rank is None: - fd_config.parallel_config.expert_parallel_rank = 0 - self.rank, self.device_id = init_rank_and_device_id(fd_config) if self.flash_attn_func is None: @@ -154,7 +144,8 @@ class FlashAttentionBackend(AttentionBackend): "The current platform does not support Flash Attention V3, so Flash Attention V2 will be used instead." 
             )
         self.rope_3d: bool = getattr(fd_config.model_config, "rope_3d", False)
-        self.max_partition_size: int = int(os.getenv("FLAGS_max_partition_size", "32768"))
+        # Note(ZKK): this must stay consistent with append_attn_backend.py.
+        self.max_partition_size: int = int(os.getenv("FLAGS_max_partition_size", "1024"))
         self.zero_seq_enc_lens_for_decode = paddle.zeros(
             shape=[fd_config.scheduler_config.max_num_seqs, 1], dtype=paddle.int32
         )
@@ -172,27 +163,13 @@ class FlashAttentionBackend(AttentionBackend):
         Calculate kv cache shape
         """
         key_cache_shape = [max_num_blocks, self.kv_num_heads, self.block_size, self.head_dim]
-        value_cache_shape = [max_num_blocks, self.kv_num_heads, self.block_size, self.head_dim]
         if kv_cache_quant_type is not None and kv_cache_quant_type == "int4_zp":
-            key_cache_shape = [
-                max_num_blocks,
-                self.kv_num_heads,
-                self.block_size,
-                self.head_dim // 2,
-            ]
-            value_cache_shape = [
-                max_num_blocks,
-                self.kv_num_heads,
-                self.block_size,
-                self.head_dim // 2,
-            ]
+            key_cache_shape[-1] = self.head_dim // 2
+        value_cache_shape = key_cache_shape
         return key_cache_shape, value_cache_shape
 
     def init_attention_metadata(self, forward_meta: ForwardMeta):
         metadata = FlashAttentionMetadata()
-        metadata.cu_seqlens_q = forward_meta.cu_seqlens_q
-        metadata.rotary_embs = forward_meta.rotary_embs
-        metadata.block_tables = forward_meta.block_tables
         get_block_shape_and_split_kv_block(
             forward_meta.seq_lens_encoder,
             forward_meta.seq_lens_decoder,
@@ -215,18 +192,20 @@ class FlashAttentionBackend(AttentionBackend):
             self.block_size,
         )
 
-        (
-            metadata.cu_seqlens_k,
-            metadata.pre_cache_batch_ids,
-            metadata.pre_cache_tile_ids_per_batch,
-            metadata.pre_cache_num_blocks_cpu,
-            metadata.kv_token_num_cpu,
-        ) = pre_cache_len_concat(
-            forward_meta.seq_lens_decoder,
-            forward_meta.seq_lens_this_time,
-            forward_meta.max_len_tensor_cpu[2],
-            self.block_size,
-        )
+        if forward_meta.max_len_tensor_cpu[1] > 0:
+            (
+                metadata.cu_seqlens_k,
+                metadata.pre_cache_batch_ids,
+                metadata.pre_cache_tile_ids_per_batch,
+                metadata.pre_cache_num_blocks_cpu,
+                metadata.kv_token_num_cpu,
+            ) = pre_cache_len_concat(
+                forward_meta.seq_lens_encoder,
+                forward_meta.seq_lens_decoder,
+                forward_meta.seq_lens_this_time,
+                forward_meta.max_len_tensor_cpu[2],
+                self.block_size,
+            )
 
         # pd_disaggregation
         metadata.kv_signal_data_list = [None] * self.num_layers
@@ -251,8 +230,7 @@ class FlashAttentionBackend(AttentionBackend):
         elif metadata._dtype == "float32":
             metadata._fuse_kernel_compute_dtype = "fp32"
 
-        metadata.max_len_tensor_cpu = forward_meta.max_len_tensor_cpu
-        metadata.max_len_tensor_cpu_decoder = paddle.clone(metadata.max_len_tensor_cpu)
+        metadata.max_len_tensor_cpu_decoder = paddle.clone(forward_meta.max_len_tensor_cpu)
         metadata.max_len_tensor_cpu_decoder[1] = 0
 
         self.attention_metadata = metadata
@@ -276,19 +254,21 @@ class FlashAttentionBackend(AttentionBackend):
             layer.layer_id + self.start_layer_index,
         )
 
-        if metadata.max_len_tensor_cpu[1] > 0:
+        use_fa_do_prefill = forward_meta.max_len_tensor_cpu[1].item() > 0
+
+        if use_fa_do_prefill:
             q, k, v, _ = gqa_rope_write_cache(
                 qkv,
                 forward_meta.caches[2 * layer.layer_id],
                 forward_meta.caches[2 * layer.layer_id + 1],
-                metadata.cu_seqlens_q,
+                forward_meta.cu_seqlens_q,
                 metadata.cu_seqlens_k,
-                metadata.rotary_embs,
+                forward_meta.rotary_embs,
                 forward_meta.seq_lens_this_time,
                 forward_meta.seq_lens_encoder,
                 forward_meta.seq_lens_decoder,
                 forward_meta.batch_id_per_token,
-                metadata.block_tables,
+                forward_meta.block_tables,
                 forward_meta.kv_batch_ids,
                 forward_meta.kv_tile_ids_per_batch,
                 forward_meta.kv_num_blocks_x_cpu,
@@ -315,7 +295,7 @@ class FlashAttentionBackend(AttentionBackend):
                 q,
                 k,
                 v,
-                metadata.cu_seqlens_q,
+                forward_meta.cu_seqlens_q,
                 metadata.cu_seqlens_k,
                 max_seqlen_q=forward_meta.max_len_tensor_cpu[0],
                 max_seqlen_k=forward_meta.max_len_tensor_cpu[3],
@@ -327,23 +307,23 @@ class FlashAttentionBackend(AttentionBackend):
             qkv,
             forward_meta.caches[2 * layer.layer_id],
             forward_meta.caches[2 * layer.layer_id + 1],
-            self.zero_seq_enc_lens_for_decode,
+            self.zero_seq_enc_lens_for_decode if use_fa_do_prefill else forward_meta.seq_lens_encoder,
             forward_meta.seq_lens_decoder,
             forward_meta.seq_lens_this_time,
             forward_meta.batch_id_per_token,
             forward_meta.cu_seqlens_q,
-            metadata.block_tables,
+            forward_meta.block_tables,
             forward_meta.encoder_batch_ids,
             forward_meta.encoder_tile_ids_per_batch,
             forward_meta.encoder_num_blocks_x_cpu,
             forward_meta.kv_batch_ids,
             forward_meta.kv_tile_ids_per_batch,
             forward_meta.kv_num_blocks_x_cpu,
-            forward_meta.decoder_batch_ids,  # from buffer
-            forward_meta.decoder_tile_ids_per_batch,  # from buffer
+            forward_meta.decoder_batch_ids,
+            forward_meta.decoder_tile_ids_per_batch,
             forward_meta.decoder_num_blocks_cpu,
-            metadata.max_len_tensor_cpu_decoder,
-            metadata.rotary_embs,
+            metadata.max_len_tensor_cpu_decoder if use_fa_do_prefill else forward_meta.max_len_tensor_cpu,
+            forward_meta.rotary_embs,
             forward_meta.attn_mask,
             layer.qkv_bias,
             layer.qkv_scale,
@@ -378,7 +358,7 @@ class FlashAttentionBackend(AttentionBackend):
             self.speculative_method is not None,
         )
 
-        if metadata.max_len_tensor_cpu[1] > 0:
+        if use_fa_do_prefill:
             merge_prefill_decode_output(
                 res_encoder,
                 res_decoder,
diff --git a/fastdeploy/model_executor/layers/attention/ops/pre_cache_len_concat.py b/fastdeploy/model_executor/layers/attention/ops/pre_cache_len_concat.py
index 42a931d18..68eed2c8a 100644
--- a/fastdeploy/model_executor/layers/attention/ops/pre_cache_len_concat.py
+++ b/fastdeploy/model_executor/layers/attention/ops/pre_cache_len_concat.py
@@ -24,6 +24,7 @@ from fastdeploy.platforms import current_platform
 
 
 def pre_cache_len_concat(
+    seq_lens_encoder: paddle.Tensor,
     seq_lens_decoder: paddle.Tensor,
     seq_lens_this_time: paddle.Tensor,
     max_dec_len: int = 0,
@@ -32,7 +33,7 @@ def pre_cache_len_concat(
     if current_platform.is_cuda():
         from fastdeploy.model_executor.ops.gpu import pre_cache_len_concat
 
-        out = pre_cache_len_concat(seq_lens_decoder, seq_lens_this_time, max_dec_len, block_size)
+        out = pre_cache_len_concat(seq_lens_encoder, seq_lens_decoder, seq_lens_this_time, max_dec_len, block_size)
         return out
     else:
         raise NotImplementedError
diff --git a/tests/layers/test_attention_layer.py b/tests/layers/test_attention_layer.py
index 106cb93cd..0acbada35 100644
--- a/tests/layers/test_attention_layer.py
+++ b/tests/layers/test_attention_layer.py
@@ -71,7 +71,6 @@ class TestAttentionPerformance(unittest.TestCase):
         self.fd_config.parallel_config.tp_group = [0]
 
         # Initialize Attention Layer
-        os.environ["FD_ATTENTION_BACKEND"] = "APPEND_ATTN"
         attn_cls = get_attention_backend()
         self.attn_backend = attn_cls(
             self.fd_config,
@@ -123,10 +122,10 @@ class TestAttentionPerformance(unittest.TestCase):
             "max_position_embeddings": 131072,
             "max_model_len": 131072,
             "head_dim": 128,
-            "hidden_size": 4096,
-            "num_attention_heads": 32,
-            "num_key_value_heads": 4,
-            "num_hidden_layers": 57,
+            "hidden_size": 8192,
+            "num_attention_heads": 64,
+            "num_key_value_heads": 8,
+            "num_hidden_layers": 2,
         }
 
         model_dir = tempfile.mkdtemp(prefix="tmp_model_config_")
         config_path = os.path.join(model_dir, "config.json")
@@ -158,6 +157,7 @@ class TestAttentionPerformance(unittest.TestCase):
                 dense_quant_type="block_wise_fp8",
                 moe_quant_type="block_wise_fp8",
                 kv_cache_quant_type="float8_e4m3fn",
+                # kv_cache_quant_type=None,
             ),
             graph_opt_config=GraphOptimizationConfig({}),
             commit_config=CommitConfig(),
@@ -270,7 +270,7 @@ class TestAttentionPerformance(unittest.TestCase):
             partial_rotary_factor=fd_config.model_config.partial_rotary_factor,
         )
 
-        input_ids = paddle.zeros([batch_size, seq_len if mode == ForwardMode.EXTEND else 1], dtype="int64")
+        input_ids = paddle.zeros([batch_size, max_model_len], dtype="int64")
         token_num = paddle.sum(seq_lens_this_time)
         ids_remove_padding, batch_id_per_token, cu_seqlens_q, cu_seqlens_k = get_padding_offset(
             input_ids, token_num, seq_lens_this_time
@@ -294,12 +294,13 @@ class TestAttentionPerformance(unittest.TestCase):
             attn_mask_offsets=None,
             **attn_backend_buffers,
         )
-        return forward_meta
+
+        hidden_states = paddle.randn([token_num, self.fd_config.model_config.hidden_size], dtype="bfloat16")
+        return forward_meta, hidden_states
 
     def test_decode_performance_with_prefill(self):
         # Test parameters
         test_steps = 100
-        act_tensor_dtype = paddle.bfloat16
 
         # prefill_batch_size = 1
         # prefill_seq_len = 4096
@@ -356,11 +357,7 @@ class TestAttentionPerformance(unittest.TestCase):
         #             p.step()
 
         for decode_batch_size in [32, 16, 8, 4, 2]:
-            decode_hidden_states = paddle.randn(
-                [decode_batch_size, self.fd_config.model_config.hidden_size], dtype=act_tensor_dtype
-            )
-
-            forward_meta = self.create_forward_meta(
+            forward_meta, hidden_states = self.create_forward_meta(
                 batch_size=decode_batch_size,
                 seq_len=36 * 1024,
                 mode=ForwardMode.DECODE,
@@ -374,12 +371,12 @@ class TestAttentionPerformance(unittest.TestCase):
             paddle.device.synchronize()
 
             # Must warm up once first! Preprocessing is deferred to the first layer!
-            self.attn_forward(forward_meta, decode_hidden_states)
+            self.attn_forward(forward_meta, hidden_states)
 
             attn_cuda_graphs = graphs.CUDAGraph()
             attn_cuda_graphs.capture_begin()
 
-            self.attn_forward(forward_meta, decode_hidden_states)
+            self.attn_forward(forward_meta, hidden_states)
 
             attn_cuda_graphs.capture_end()
diff --git a/tests/operators/test_pre_cache_len_concat.py b/tests/operators/test_pre_cache_len_concat.py
index 4844c1c71..84389a104 100644
--- a/tests/operators/test_pre_cache_len_concat.py
+++ b/tests/operators/test_pre_cache_len_concat.py
@@ -69,7 +69,10 @@ class TestPreCacheLenConcat(unittest.TestCase):
         seq_lens_decoder_t = paddle.to_tensor(seq_lens_decoder, dtype="int32")
         seq_lens_this_time_t = paddle.to_tensor(seq_lens_this_time, dtype="int32")
 
-        outputs = pre_cache_len_concat(seq_lens_decoder_t, seq_lens_this_time_t, max_dec_len, block_size)
+        seq_lens_encoder_t = seq_lens_this_time_t
+        outputs = pre_cache_len_concat(
+            seq_lens_encoder_t, seq_lens_decoder_t, seq_lens_this_time_t, max_dec_len, block_size
+        )
         cu_seqlens_k, batch_ids, tile_ids, num_blocks, kv_token_num = [out.numpy() for out in outputs]
 
         # Shape checks
@@ -91,8 +94,11 @@ class TestPreCacheLenConcat(unittest.TestCase):
         seq_lens_decoder_t = paddle.to_tensor(seq_lens_decoder, dtype="int32")
         seq_lens_this_time_t = paddle.to_tensor(seq_lens_this_time, dtype="int32")
+        seq_lens_encoder_t = seq_lens_this_time_t
 
-        outputs = pre_cache_len_concat(seq_lens_decoder_t, seq_lens_this_time_t, max_dec_len, block_size)
+        outputs = pre_cache_len_concat(
+            seq_lens_encoder_t, seq_lens_decoder_t, seq_lens_this_time_t, max_dec_len, block_size
+        )
         cu_seqlens_k, batch_ids, tile_ids, num_blocks, kv_token_num = [out.numpy() for out in outputs]
 
         # Reference implementation
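For reference, the gating rule this patch threads through both CUDA kernels can be summarized in plain Python. The sketch below is illustrative only (the standalone function and its name are the editor's, not part of the patch); it mirrors what pre_cache_len_concat now computes on the GPU and matches the reference implementation the unit test checks against:

def pre_cache_len_concat_ref(seq_lens_encoder, seq_lens_decoder, seq_lens_this_time, block_size):
    # Per-batch pre-cache length: only a chunked-prefill request
    # (seq_lens_encoder > 0) contributes its already-cached tokens.
    cu_seqlens_k = [0]
    batch_ids, tile_ids = [], []
    total_tokens = 0
    for bid in range(len(seq_lens_this_time)):
        cache_len = seq_lens_decoder[bid] if seq_lens_encoder[bid] > 0 else 0
        loop_times = (cache_len + block_size - 1) // block_size  # div_up
        for tile_id in range(loop_times):
            batch_ids.append(bid)
            tile_ids.append(tile_id)
        total_tokens += cache_len + seq_lens_this_time[bid]
        cu_seqlens_k.append(total_tokens)
    # num_blocks_x == len(batch_ids); kv_token_num == total_tokens
    return cu_seqlens_k, batch_ids, tile_ids, len(batch_ids), total_tokens

Pure-decode entries (seq_lens_encoder == 0) therefore produce no pre-cache tiles, which is the same rule behind the cache_kv_len = 0 fallback added to GQAVariableLengthRotarySplitKernel.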