"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""

from __future__ import annotations

import os
from dataclasses import dataclass
from math import sqrt
from typing import TYPE_CHECKING, Optional

import paddle
from paddle.nn.functional.flash_attention import flash_attn_unpadded

from fastdeploy.config import FDConfig
from fastdeploy.model_executor.layers.attention.attention import Attention
from fastdeploy.model_executor.layers.attention.base_attention_backend import (
    AttentionBackend,
    AttentionMetadata,
)
from fastdeploy.model_executor.ops.iluvatar import paged_attention

if TYPE_CHECKING:
    from fastdeploy.model_executor.forward_meta import ForwardMeta


@dataclass
class IluvatarAttentionMetadata(AttentionMetadata):
    """
    IluvatarAttentionMetadata
    """

    # flash_attn metadata
    cu_seqlens_q: Optional[paddle.Tensor] = None
    cu_seqlens_k: Optional[paddle.Tensor] = None
    fixed_seed_offset: Optional[paddle.Tensor] = None
    attn_mask: Optional[paddle.Tensor] = None
    attn_mask_start_row_indices: Optional[paddle.Tensor] = None
    dropout: float = 0.0
    causal: bool = True
    return_softmax: bool = False
    rng_name: str = ""

    # paged_attn metadata
    block_tables: Optional[paddle.Tensor] = None
    seq_lens: Optional[paddle.Tensor] = None
    num_kv_heads: int = 1
    scale: float = 1.0
    block_size: int = 1
    max_context_len: int = 1
    alibi_slopes: Optional[paddle.Tensor] = None
    # causal: bool = True
    window_left: int = -1
    window_right: int = -1
    softcap: float = 0.0
    use_cuda_graph: bool = False
    use_sqrt_alibi: bool = False


# qk: [seq, h, d], cos/sin: [seq, 1, d]
def apply_rope(qk, cos, sin):
    # Interleaved rotate-half: (x0, x1, x2, x3, ...) -> (-x1, x0, -x3, x2, ...)
    rotate_half = paddle.reshape(
        paddle.stack([-qk[..., 1::2], qk[..., 0::2]], axis=-1),
        paddle.shape(qk),
    )
    out = paddle.add(paddle.multiply(qk, cos), paddle.multiply(rotate_half, sin))
    return paddle.cast(out, qk.dtype)
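
# A minimal usage sketch for apply_rope above (shapes and values are
# illustrative assumptions, not taken from this file). With sin == 0 and
# cos == 1 the rotation is the identity:
#
#     q = paddle.randn([8, 4, 64])         # [seq, heads, head_dim]
#     cos = paddle.ones([8, 1, 64])        # broadcast over the head axis
#     sin = paddle.zeros([8, 1, 64])
#     q_rot = apply_rope(q, cos, sin)      # equals q in this degenerate case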


class IluvatarAttnBackend(AttentionBackend):
    """
    Attention backend for Iluvatar GPUs.

    Prefill tokens are processed with flash_attn_unpadded, decode tokens with
    the Iluvatar paged_attention op; mixed batches are split into the two
    groups, computed separately, and merged back into a single output.
    """

    def __init__(self, fd_config: FDConfig, kv_num_heads: int, num_heads: int, head_dim: int):
        super().__init__()
        self.attention_metadata = IluvatarAttentionMetadata()
        self.attention_metadata.block_size = fd_config.parallel_config.block_size
        assert (
            fd_config.parallel_config.enc_dec_block_num == 0
        ), f"Iluvatar does not support enc_dec_block_num > 0 yet, got {fd_config.parallel_config.enc_dec_block_num}"
        assert self.attention_metadata.block_size == 16, "Iluvatar paged attention requires block_size == 16."

        self.attention_metadata.max_context_len = fd_config.parallel_config.max_model_len
        self.attention_metadata.causal = getattr(fd_config.model_config, "causal", True)
        self.speculate_method = getattr(fd_config.parallel_config, "speculate_method", None)
        self.use_speculate = self.speculate_method is not None
        self.attention_metadata.num_kv_heads = kv_num_heads
        self.attention_metadata.dropout = fd_config.model_config.hidden_dropout_prob
        self.num_heads = num_heads
        self.total_num_heads = num_heads + 2 * kv_num_heads
        self.head_dim = head_dim
        self.hidden_dim = num_heads * head_dim
        self.total_hidden_dim = self.total_num_heads * head_dim
        # note: the scale needs to change if MLA is used
        self.attention_metadata.scale = 1.0 / sqrt(head_dim)
        self.num_layers = fd_config.model_config.num_hidden_layers
        self.dtype = paddle.get_default_dtype()

        self.record_block_table_metadata = {}
        self.enable_fused_attention = int(os.getenv("FD_ILUVATAR_ENABLE_FUSED_ATTN", "1"))
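
    # The fused decode path can be disabled before launch through the
    # environment variable read above, e.g. (assumed shell usage, not part of
    # this file):
    #     export FD_ILUVATAR_ENABLE_FUSED_ATTN=0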

    def init_attention_metadata(self, forward_meta: ForwardMeta):
        """Initialize attention metadata once per forward pass so that all layers can reuse it."""
        self.prefill_info_dict = {}
        self.decode_info_dict = {}

        prefill_non_zeros_ids = forward_meta.seq_lens_this_time > 1
        decode_non_zeros_ids = forward_meta.seq_lens_this_time == 1
        self.prefill_info_dict["batch_ids"] = paddle.where(prefill_non_zeros_ids)[0]
        self.decode_info_dict["batch_ids"] = paddle.where(decode_non_zeros_ids)[0]

        self.prefill_len = len(self.prefill_info_dict["batch_ids"])
        self.decode_len = len(self.decode_info_dict["batch_ids"])
        # prefill-only batch
        if self.decode_len == 0:
            cu_seq_ids = list(range(self.prefill_len + 1))
            self.prefill_info_dict["cu_seqlens_q"] = forward_meta.cu_seqlens_q[cu_seq_ids]
        # decode-only batch
        elif self.prefill_len == 0:
            pass
        # mixed batch with both prefill and decode requests
        else:
            prefill_num_tokens = paddle.sum(forward_meta.seq_lens_this_time[prefill_non_zeros_ids])
            decode_num_tokens = paddle.sum(forward_meta.seq_lens_this_time[decode_non_zeros_ids])

            self.prefill_info_dict["cu_seqlens_q"] = paddle.zeros(
                [self.prefill_len + 1], dtype=forward_meta.cu_seqlens_q.dtype
            )
            self.prefill_info_dict["cu_seqlens_q"][1:] = forward_meta.seq_lens_encoder[
                self.prefill_info_dict["batch_ids"], 0
            ]
            self.prefill_info_dict["cu_seqlens_q"] = paddle.cumsum(self.prefill_info_dict["cu_seqlens_q"])

            self.prefill_qkv = paddle.zeros([prefill_num_tokens, self.total_hidden_dim], dtype=self.dtype)
            self.decode_qkv = paddle.zeros([decode_num_tokens, self.total_hidden_dim], dtype=self.dtype)
            self.merged_output = paddle.zeros(
                [prefill_num_tokens + decode_num_tokens, self.num_heads, self.head_dim], dtype=self.dtype
            )

            # Walk the batch in order and record contiguous prefill/decode runs,
            # both in the split tensors (id_group) and in the packed token order
            # (reverse_id_group), so split_pd_qkv/merge_pd_output can map tokens
            # back and forth.
            prefill_start, decode_start, start = 0, 0, 0
            non_zeros_ids = forward_meta.seq_lens_this_time != 0
            non_zeros_seq_lens = forward_meta.seq_lens_this_time[non_zeros_ids]
            end = non_zeros_seq_lens[0]
            if end > 1:
                last_stage = "prefill"
                prefill_end = end
                decode_end = 0
            else:
                last_stage = "decode"
                prefill_end = 0
                decode_end = end

            self.prefill_info_dict["id_group"] = []
            self.prefill_info_dict["reverse_id_group"] = []
            self.decode_info_dict["id_group"] = []
            self.decode_info_dict["reverse_id_group"] = []
            self.record_stages = []
            for seq_len in non_zeros_seq_lens[1:]:
                if seq_len > 1:
                    if last_stage == "decode":
                        self.record_stages.append((last_stage, len(self.decode_info_dict["id_group"])))
                        self.decode_info_dict["id_group"].append((decode_start, decode_end))
                        self.decode_info_dict["reverse_id_group"].append((start, end))
                        decode_start = decode_end
                        start = end
                        last_stage = "prefill"
                    prefill_end += seq_len
                    end += seq_len
                else:
                    if last_stage == "prefill":
                        self.record_stages.append((last_stage, len(self.prefill_info_dict["id_group"])))
                        self.prefill_info_dict["id_group"].append((prefill_start, prefill_end))
                        self.prefill_info_dict["reverse_id_group"].append((start, end))
                        prefill_start = prefill_end
                        start = end
                        last_stage = "decode"
                    decode_end += seq_len
                    end += seq_len

            if prefill_start < prefill_end:
                self.record_stages.append(("prefill", len(self.prefill_info_dict["id_group"])))
                self.prefill_info_dict["id_group"].append((prefill_start, prefill_end))
                self.prefill_info_dict["reverse_id_group"].append((start, end))
            if decode_start < decode_end:
                self.record_stages.append(("decode", len(self.decode_info_dict["id_group"])))
                self.decode_info_dict["id_group"].append((decode_start, decode_end))
                self.decode_info_dict["reverse_id_group"].append((start, end))
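
    # Worked example of the bookkeeping above (illustrative numbers only, not
    # taken from this file). For seq_lens_this_time = [5, 1, 1, 3, 1]:
    #   prefill batch_ids = [0, 3], decode batch_ids = [1, 2, 4]
    #   record_stages            = [("prefill", 0), ("decode", 0), ("prefill", 1), ("decode", 1)]
    #   prefill id_group         = [(0, 5), (5, 8)]    # rows in prefill_qkv
    #   prefill reverse_id_group = [(0, 5), (7, 10)]   # rows in the packed qkv / merged output
    #   decode id_group          = [(0, 2), (2, 3)]    # rows in decode_qkv
    #   decode reverse_id_group  = [(5, 7), (10, 11)]  # rows in the packed qkv / merged output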

    def get_attntion_meta(self):
        """Return the attention metadata shared by all layers."""
        return self.attention_metadata

    def get_kv_cache_shape(
        self,
        max_num_blocks: int,
        kv_cache_quant_type: str = None,
    ):
        """
        Calculate the KV cache shape (per key/value tensor).
        """
        return (
            max_num_blocks,
            self.attention_metadata.num_kv_heads,
            self.attention_metadata.block_size,
            self.head_dim,
        )
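
    # Shape example (illustrative numbers, not from this file): with
    # max_num_blocks=1024, num_kv_heads=8, block_size=16 and head_dim=128, each
    # cache tensor has shape [1024, 8, 16, 128], i.e. blocks are indexed first
    # and every block holds 16 token slots per KV head.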

    def prefill_update_kv_cache(
        self, k, v, k_cache_id: int, v_cache_id: int, layer_id: int, forward_meta: ForwardMeta, prefill_batch_ids: list
    ):
        """Scatter the freshly computed prefill K/V into the paged KV cache, block by block."""
        # [num_tokens, num_kv_heads, head_dim] -> [num_kv_heads, num_tokens, head_dim]
        trans_k = k.transpose([1, 0, 2]).contiguous()
        trans_v = v.transpose([1, 0, 2]).contiguous()
        tensor_start = 0
        for batch_idx in prefill_batch_ids:
            seq_len = forward_meta.seq_lens_this_time[batch_idx]

            tensor_end = tensor_start + seq_len
            slice_trans_k = trans_k[:, tensor_start:tensor_end, :]
            slice_trans_v = trans_v[:, tensor_start:tensor_end, :]

            cur_block_tables = forward_meta.block_tables[batch_idx]
            cur_used_block_tables = cur_block_tables[cur_block_tables != -1]

            cache_start = 0
            cur_used_num_blocks = cur_used_block_tables.shape[0]
            for i, block_id in enumerate(cur_used_block_tables):
                # last block: seq_len - cache_start <= block_size
                if i == cur_used_num_blocks - 1:
                    cache_end = seq_len - cache_start
                    assert cache_end <= self.attention_metadata.block_size
                    paddle.assign(
                        slice_trans_k[:, cache_start:seq_len, :],
                        output=forward_meta.caches[k_cache_id][block_id, :, 0:cache_end, :],
                    )
                    paddle.assign(
                        slice_trans_v[:, cache_start:seq_len, :],
                        output=forward_meta.caches[v_cache_id][block_id, :, 0:cache_end, :],
                    )
                    if layer_id == self.num_layers - 1:
                        self.record_block_table_metadata[batch_idx] = {
                            "block_id": block_id.item(),
                            "cache_end": cache_end.item(),
                        }
                # non-last block: seq_lens_this_time > block_size
                else:
                    assert seq_len > self.attention_metadata.block_size
                    cache_end = cache_start + self.attention_metadata.block_size
                    paddle.assign(
                        slice_trans_k[:, cache_start:cache_end, :], output=forward_meta.caches[k_cache_id][block_id]
                    )
                    paddle.assign(
                        slice_trans_v[:, cache_start:cache_end, :], output=forward_meta.caches[v_cache_id][block_id]
                    )
                    cache_start += self.attention_metadata.block_size

            tensor_start = tensor_end
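
    # Example of the block-wise write above (illustrative numbers): for a
    # prefill request with seq_len = 35 and block_size = 16, tokens 0:16 go to
    # the first used block, 16:32 to the second, and 32:35 to slots 0:3 of the
    # last block; on the final layer the entry
    # {"block_id": last_used_block_id, "cache_end": 3} is recorded for that
    # request.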

    def get_splited_qkv(
        self, qkv: paddle.Tensor, forward_meta: ForwardMeta, cu_seqlens_q: paddle.Tensor, batch_ids=None
    ):
        """Split the packed QKV tensor into per-head Q, K, V and apply RoPE to Q and K."""
        q_end = self.hidden_dim
        k_end = q_end + self.attention_metadata.num_kv_heads * self.head_dim
        v_end = k_end + self.attention_metadata.num_kv_heads * self.head_dim
        assert v_end == qkv.shape[-1], f"Shape mismatch: {v_end} vs {qkv.shape[-1]}"
        assert qkv.shape[0] == cu_seqlens_q[-1], f"Shape mismatch: {qkv.shape[0]} vs {cu_seqlens_q[-1]}"

        if batch_ids is None:
            batch_ids = list(range(forward_meta.seq_lens_this_time.shape[0]))

        q = qkv[..., 0:q_end]
        k = qkv[..., q_end:k_end]
        v = qkv[..., k_end:v_end]
        q = q.view([-1, self.num_heads, self.head_dim])
        k = k.view([-1, self.attention_metadata.num_kv_heads, self.head_dim])
        v = v.view([-1, self.attention_metadata.num_kv_heads, self.head_dim])

        for idx in range(len(cu_seqlens_q) - 1):
            batch_idx = batch_ids[idx]
            seq_len_i = forward_meta.seq_lens_this_time[batch_idx]
            if seq_len_i == 0:
                continue
            cached_kv_len = forward_meta.seq_lens_decoder[batch_idx][0]
            cu_seq_start_q = cu_seqlens_q[idx]
            cu_seq_end_q = cu_seqlens_q[idx + 1]
            # forward_meta.rotary_embs is [2, 1, S, 1, D]
            if forward_meta.rotary_embs is not None:
                cos = forward_meta.rotary_embs[0, 0, cached_kv_len : cached_kv_len + seq_len_i, :, :]
                sin = forward_meta.rotary_embs[1, 0, cached_kv_len : cached_kv_len + seq_len_i, :, :]
                q[cu_seq_start_q:cu_seq_end_q] = apply_rope(q[cu_seq_start_q:cu_seq_end_q], cos, sin)
                k[cu_seq_start_q:cu_seq_end_q] = apply_rope(k[cu_seq_start_q:cu_seq_end_q], cos, sin)

        return q, k, v
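
    # Packed-QKV layout example (illustrative numbers): with num_heads = 32,
    # num_kv_heads = 8 and head_dim = 128, each row of qkv holds 32*128 Q
    # values, then 8*128 K values, then 8*128 V values, so q_end = 4096,
    # k_end = 5120 and v_end = 6144 in the split above.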

    def split_pd_qkv(self, qkv):
        """Split the packed QKV of a mixed batch into contiguous prefill and decode tensors."""
        for ids, reverse_ids in zip(self.prefill_info_dict["id_group"], self.prefill_info_dict["reverse_id_group"]):
            self.prefill_qkv[ids[0] : ids[1], :] = qkv[reverse_ids[0] : reverse_ids[1], :]

        for ids, reverse_ids in zip(self.decode_info_dict["id_group"], self.decode_info_dict["reverse_id_group"]):
            self.decode_qkv[ids[0] : ids[1], :] = qkv[reverse_ids[0] : reverse_ids[1], :]

        return self.prefill_qkv, self.decode_qkv

    def merge_pd_output(self, prefill_out, decode_out):
        """Scatter the prefill and decode outputs back into the original token order of the mixed batch."""
        for stage, idx in self.record_stages:
            if stage == "prefill":
                ids = self.prefill_info_dict["id_group"][idx]
                reverse_ids = self.prefill_info_dict["reverse_id_group"][idx]
                self.merged_output[reverse_ids[0] : reverse_ids[1], :, :] = prefill_out[ids[0] : ids[1], :, :]
            else:
                ids = self.decode_info_dict["id_group"][idx]
                reverse_ids = self.decode_info_dict["reverse_id_group"][idx]
                self.merged_output[reverse_ids[0] : reverse_ids[1], :, :] = decode_out[ids[0] : ids[1], :, :]
        return self.merged_output
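
    # Continuing the seq_lens_this_time = [5, 1, 1, 3, 1] example above:
    # split_pd_qkv copies qkv rows 0:5 and 7:10 into prefill_qkv rows 0:5 and
    # 5:8, and qkv rows 5:7 and 10:11 into decode_qkv rows 0:2 and 2:3;
    # merge_pd_output performs the inverse scatter on the attention outputs.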

    def forward_prefill(self, prefill_qkv, layer_id, k_cache_id, v_cache_id, forward_meta: ForwardMeta):
        """Run variable-length flash attention over the prefill tokens and write their K/V into the cache."""
        prefill_q, prefill_k, prefill_v = self.get_splited_qkv(
            prefill_qkv,
            forward_meta,
            self.prefill_info_dict["cu_seqlens_q"],
            batch_ids=self.prefill_info_dict["batch_ids"],
        )

        # K/V cover exactly the prefill tokens here, so Q and K share the same cu_seqlens.
        prefill_out = flash_attn_unpadded(
            prefill_q,
            prefill_k,
            prefill_v,
            cu_seqlens_q=self.prefill_info_dict["cu_seqlens_q"],
            cu_seqlens_k=self.prefill_info_dict["cu_seqlens_q"],
            max_seqlen_q=self.attention_metadata.max_context_len,
            max_seqlen_k=self.attention_metadata.max_context_len,
            scale=self.attention_metadata.scale,
            dropout=self.attention_metadata.dropout,
            causal=self.attention_metadata.causal,
            return_softmax=self.attention_metadata.return_softmax,
        )[0]
        self.prefill_update_kv_cache(
            prefill_k, prefill_v, k_cache_id, v_cache_id, layer_id, forward_meta, self.prefill_info_dict["batch_ids"]
        )

        return prefill_out

    def forward_decode(self, decode_qkv, k_cache_id, v_cache_id, forward_meta: ForwardMeta):
        """Run paged attention for the decode tokens against the KV cache."""
        k_cache = forward_meta.caches[k_cache_id]
        v_cache = forward_meta.caches[v_cache_id]
        if self.enable_fused_attention:
            # Fused path: pass the packed QKV together with the rope tables to the op.
            rope_cos = forward_meta.rotary_embs[0, 0, :, :, :]
            rope_sin = forward_meta.rotary_embs[1, 0, :, :, :]
            decode_out = paged_attention(
                decode_qkv.view([-1, self.total_num_heads, self.head_dim]),
                k_cache,
                v_cache,
                block_tables=forward_meta.block_tables[self.decode_info_dict["batch_ids"], :],
                seq_lens=forward_meta.seq_lens_decoder[self.decode_info_dict["batch_ids"], 0] + 1,
                num_kv_heads=self.attention_metadata.num_kv_heads,
                scale=self.attention_metadata.scale,
                block_size=self.attention_metadata.block_size,
                max_context_len=self.attention_metadata.max_context_len,
                alibi_slopes=self.attention_metadata.alibi_slopes,
                causal=self.attention_metadata.causal,
                window_left=self.attention_metadata.window_left,
                window_right=self.attention_metadata.window_right,
                softcap=self.attention_metadata.softcap,
                use_cuda_graph=self.attention_metadata.use_cuda_graph,
                use_sqrt_alibi=self.attention_metadata.use_sqrt_alibi,
                merged_qkv=True,
                k=decode_qkv,
                v=decode_qkv,
                rope_sin=rope_sin,
                rope_cos=rope_cos,
            )
        else:
            # Unfused path: split QKV and apply RoPE before calling the op.
            decode_q, decode_k, decode_v = self.get_splited_qkv(
                decode_qkv,
                forward_meta,
                self.decode_info_dict["cu_seqlens_q"],
                batch_ids=self.decode_info_dict["batch_ids"],
            )

            decode_out = paged_attention(
                decode_q,
                k_cache,
                v_cache,
                block_tables=forward_meta.block_tables[self.decode_info_dict["batch_ids"], :],
                seq_lens=forward_meta.seq_lens_decoder[self.decode_info_dict["batch_ids"], 0] + 1,
                num_kv_heads=self.attention_metadata.num_kv_heads,
                scale=self.attention_metadata.scale,
                block_size=self.attention_metadata.block_size,
                max_context_len=self.attention_metadata.max_context_len,
                alibi_slopes=self.attention_metadata.alibi_slopes,
                causal=self.attention_metadata.causal,
                window_left=self.attention_metadata.window_left,
                window_right=self.attention_metadata.window_right,
                softcap=self.attention_metadata.softcap,
                use_cuda_graph=self.attention_metadata.use_cuda_graph,
                use_sqrt_alibi=self.attention_metadata.use_sqrt_alibi,
                k=decode_k,
                v=decode_v,
            )

        return decode_out

    def forward_mixed(
        self,
        q: paddle.Tensor,
        k: paddle.Tensor,
        v: paddle.Tensor,
        qkv: paddle.Tensor,
        compressed_kv: paddle.Tensor,
        k_pe: paddle.Tensor,
        layer: Attention,
        forward_meta: ForwardMeta,
    ):
        """
        Dispatch the packed QKV to the prefill, decode, or mixed path and return the attention output.
        """
        assert not self.use_speculate, "IluvatarAttnBackend does not support speculative decoding yet"
        layer_id = layer.layer_id
        k_cache_id = layer_id * 2
        v_cache_id = k_cache_id + 1
        q_dim = qkv.dim()
        assert q_dim == 2

        if self.decode_len == 0:
            output = self.forward_prefill(qkv, layer_id, k_cache_id, v_cache_id, forward_meta)
        elif self.prefill_len == 0:
            output = self.forward_decode(qkv, k_cache_id, v_cache_id, forward_meta)
        else:
            prefill_qkv, decode_qkv = self.split_pd_qkv(qkv)
            prefill_output = self.forward_prefill(prefill_qkv, layer_id, k_cache_id, v_cache_id, forward_meta)
            decode_output = self.forward_decode(decode_qkv, k_cache_id, v_cache_id, forward_meta)
            output = self.merge_pd_output(prefill_output, decode_output)

        output = output.view([-1, self.num_heads * self.head_dim])
        return output
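

# Rough per-step usage sketch (the call order and arguments are assumptions
# inferred from the method signatures in this file, not a documented API):
#
#     backend = IluvatarAttnBackend(fd_config, kv_num_heads=8, num_heads=32, head_dim=128)
#     backend.init_attention_metadata(forward_meta)   # once per forward pass
#     for layer in attention_layers:                  # once per attention layer
#         out = backend.forward_mixed(None, None, None, qkv, None, None, layer, forward_meta)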