Files
FastDeploy/fastdeploy/worker/forward_meta.py
2025-06-29 23:29:37 +00:00

419 lines
12 KiB
Python

"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
import abc
import logging
from dataclasses import dataclass
from enum import IntEnum, auto
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
import numpy as np
import paddle
if TYPE_CHECKING:
from fastdeploy.model_executor.layers.attention import (Attention,
AttentionBackend)
logger = logging.getLogger(__name__)
class ForwardMode(IntEnum):
    """
    Kind of forward pass the attention layers are asked to run.
    """

    # Prefill and extend passes.
    EXTEND = auto()
    # Generation (decode) pass.
    DECODE = auto()
    # Third mode, distinct from the two above (presumably a batch that
    # contains both prefill and decode requests — TODO confirm).
    MIXED = auto()

    def is_prefill(self):
        """Return True when this mode is ``EXTEND``."""
        return ForwardMode.EXTEND == self

    def is_decode(self):
        """Return True when this mode is ``DECODE``."""
        return ForwardMode.DECODE == self

    def is_mixed(self):
        """Return True when this mode is ``MIXED``."""
        return ForwardMode.MIXED == self
class ReqToTokenPool:
    """A memory pool that maps a request to its token locations.

    The pool owns a ``(size, max_context_len)`` int32 tensor
    (``req_to_token``) and a free list of row indices (``free_slots``).
    Rows are handed out by ``alloc`` and returned by ``free``.
    """

    def __init__(self, size: int, max_context_len: int):
        """
        Create the backing tensor and mark every row as free.
        Args:
            size: number of request rows in the pool
            max_context_len: number of token entries per row
        """
        self.size = size
        self.max_context_len = max_context_len
        self.req_to_token = paddle.zeros((size, max_context_len),
                                         dtype=paddle.int32)
        self.free_slots = list(range(size))

    def write(self, indices, values):
        """Write `values` into the request buffer at `indices`."""
        self.req_to_token[indices] = values

    def available_size(self):
        """Get number of slots left"""
        return len(self.free_slots)

    def alloc(self, need_size: int) -> Optional[List[int]]:
        """
        Allocate `need_size` slots.
        Returns:
            The allocated slot indices, taken from the front of the free
            list, or None when fewer than `need_size` slots are free (the
            pool is left untouched in that case).
        """
        if need_size > len(self.free_slots):
            return None
        select_index = self.free_slots[:need_size]
        self.free_slots = self.free_slots[need_size:]
        return select_index

    def free(self, free_index: Union[int, List[int]]):
        """Return one slot (int) or several slots (list) to the free list."""
        if isinstance(free_index, int):
            self.free_slots.append(free_index)
        else:
            self.free_slots.extend(free_index)

    def clear(self):
        """Clear all slots: every row of the pool becomes free again."""
        self.free_slots = list(range(self.size))
class KVCache(abc.ABC):
    """Abstract interface that every key/value cache implementation provides."""

    @abc.abstractmethod
    def get_kv_buffer(self,
                      layer_id: int) -> Tuple[paddle.Tensor, paddle.Tensor]:
        """
        Fetch the cached keys and values of one layer.
        Args:
            layer_id: int
        Returns:
            tuple: (keys, values)
        """
        raise NotImplementedError

    @abc.abstractmethod
    def set_kv_buffer(
        self,
        layer: 'Attention',
        loc: paddle.Tensor,
        cache_k: paddle.Tensor,
        cache_v: paddle.Tensor,
    ) -> None:
        """
        Store keys and values into the cache of the given layer.
        Args:
            layer: Attention
            loc: paddle.Tensor
            cache_k: paddle.Tensor
            cache_v: paddle.Tensor
        """
        raise NotImplementedError

    @abc.abstractmethod
    def transfer(self, indices, flat_data):
        """Transfer kv_data between devices"""
        raise NotImplementedError

    @abc.abstractmethod
    def transfer_per_layer(self, indices, flat_data, layer_id):
        """Per-layer variant of `transfer`. Not used yet."""
        raise NotImplementedError

    def register_layer_transfer_counter(self, layer_transfer_counter):
        """Remember `layer_transfer_counter` on the instance. Not used yet."""
        self.layer_transfer_counter = layer_transfer_counter
class MHATokenToKVPool(KVCache):
    """Token To Key Value Pool for MultiHeadAttention.

    Holds one key buffer and one value buffer per layer, each created with
    shape ``(max_block_num, head_num, block_size, head_dim)`` and dtype
    ``store_dtype``.
    """

    def __init__(
        self,
        max_block_num: int,
        block_size: int,
        dtype: paddle.dtype,
        head_num: int,
        head_dim: int,
        layer_num: int,
        device: str,
    ):
        """
        Allocate the per-layer key/value buffers and log their size.
        Args:
            max_block_num: first dimension of every buffer
            block_size: third dimension of every buffer
            dtype: logical dtype of the cached keys/values
            head_num: second dimension of every buffer
            head_dim: last dimension of every buffer
            layer_num: number of (key, value) buffer pairs to create
            device: target passed to ``Tensor.to`` in the transfer methods
        """
        self.max_block_num = max_block_num
        self.block_size = block_size
        self.dtype = dtype
        self.device = device
        if dtype in (paddle.int8, paddle.float8_e4m3fn):
            # NOTE: these dtypes are kept in uint8 storage and reinterpreted
            # with ``view(self.dtype)`` on read (see get_key_buffer /
            # get_value_buffer). Presumably indexed assignment is not
            # available for them — TODO confirm for Paddle.
            self.store_dtype = paddle.uint8
        else:
            self.store_dtype = dtype
        self.head_num = head_num
        self.head_dim = head_dim
        self.layer_num = layer_num
        self._create_buffers()

        k_size, v_size = self.get_kv_size_bytes()
        GB = 1024 * 1024 * 1024
        # This class never defines ``self.size``; the token capacity is
        # derived from the buffer layout instead (blocks x slots per block).
        num_tokens = self.max_block_num * self.block_size
        logger.info(
            f"KV Cache is allocated. #tokens: {num_tokens}, K size: {k_size / GB:.2f} GB, V size: {v_size / GB:.2f} GB"
        )

    def _create_buffers(self):
        """Allocate zero-filled key and value buffers, one pair per layer."""
        # [max_block_num, head_num, block_size, head_dim] for each layer
        buffer_shape = (self.max_block_num, self.head_num, self.block_size,
                        self.head_dim)
        self.k_buffer = [
            paddle.zeros(buffer_shape, dtype=self.store_dtype)
            for _ in range(self.layer_num)
        ]
        self.v_buffer = [
            paddle.zeros(buffer_shape, dtype=self.store_dtype)
            for _ in range(self.layer_num)
        ]

    def _clear_buffers(self):
        """Drop the references to the key and value buffers."""
        del self.k_buffer
        del self.v_buffer

    def get_kv_size_bytes(self):
        """
        Return ``(k_size_bytes, v_size_bytes)`` summed over all layers.
        For debugging purpose. The per-element size is taken from each
        buffer so uint8 / 16-bit storage is not reported as 4 bytes wide.
        """
        assert hasattr(self, "k_buffer")
        assert hasattr(self, "v_buffer")
        k_size_bytes = 0
        for k_cache in self.k_buffer:
            k_size_bytes += np.prod(k_cache.shape) * k_cache.element_size()
        v_size_bytes = 0
        for v_cache in self.v_buffer:
            v_size_bytes += np.prod(v_cache.shape) * v_cache.element_size()
        return k_size_bytes, v_size_bytes

    def transfer(self, indices, flat_data):
        """
        Copy prepared data for every layer into the buffers at `indices`.
        `flat_data[0]` / `flat_data[1]` are read as the key / value parts and
        indexed by layer id.
        """
        # transfer prepared data from host to device; ``blocking=True`` is
        # Paddle's spelling of a synchronous copy.
        flat_data = flat_data.to(device=self.device, blocking=True)
        k_data, v_data = flat_data[0], flat_data[1]
        for i in range(self.layer_num):
            self.k_buffer[i][indices] = k_data[i]
            self.v_buffer[i][indices] = v_data[i]

    def transfer_per_layer(self, indices, flat_data, layer_id):
        """
        Copy prepared data for one layer into its buffers at `indices`.
        `flat_data[0]` / `flat_data[1]` are read as the key / value parts.
        """
        # transfer prepared data for a specific layer from host to device
        flat_data = flat_data.to(device=self.device, blocking=True)
        k_data, v_data = flat_data[0], flat_data[1]
        self.k_buffer[layer_id][indices] = k_data
        self.v_buffer[layer_id][indices] = v_data

    def get_key_buffer(self, layer_id: int):
        """Return cached keys given layer id, viewed as ``self.dtype``."""
        if self.store_dtype != self.dtype:
            return self.k_buffer[layer_id].view(self.dtype)
        return self.k_buffer[layer_id]

    def get_value_buffer(self, layer_id: int):
        """Return cached values given layer id, viewed as ``self.dtype``."""
        if self.store_dtype != self.dtype:
            return self.v_buffer[layer_id].view(self.dtype)
        return self.v_buffer[layer_id]

    def get_kv_buffer(self, layer_id: int):
        """Return cached keys and values given layer id."""
        return self.get_key_buffer(layer_id), self.get_value_buffer(layer_id)

    def set_kv_buffer(
        self,
        layer: 'Attention',
        loc: paddle.Tensor,
        cache_k: paddle.Tensor,
        cache_v: paddle.Tensor,
        k_scale: Optional[float] = None,
        v_scale: Optional[float] = None,
    ):
        """
        Set cached keys and values for ``layer.layer_id`` at `loc`.
        Args:
            layer: Attention, only ``layer.layer_id`` is read
            loc: index used to address the layer buffers
            cache_k: keys to store
            cache_v: values to store
            k_scale: optional divisor applied to the keys before casting
            v_scale: optional divisor applied to the values before casting
        """
        layer_id = layer.layer_id
        if cache_k.dtype != self.dtype:
            # Scale out of place so the caller's tensors are left untouched.
            if k_scale is not None:
                cache_k = cache_k / k_scale
            if v_scale is not None:
                cache_v = cache_v / v_scale
            cache_k = cache_k.to(self.dtype)
            cache_v = cache_v.to(self.dtype)
        if self.store_dtype != self.dtype:
            # Reinterpret as the storage dtype used by the buffers.
            cache_k = cache_k.view(self.store_dtype)
            cache_v = cache_v.view(self.store_dtype)
        self.k_buffer[layer_id][loc] = cache_k
        self.v_buffer[layer_id][loc] = cache_v
@dataclass
class ForwardMeta:
    """
    ForwardMeta is used to store the global meta information of the forward.

    Only `input_ids` is required; every other field has a default.
    `init_forward_meta` fills most fields from a `share_inputs` dict and
    leaves `attn_mask`, `pre_caches_length` and `step_use_cudagraph` at
    their defaults.
    """
    # Input token ids (share_inputs["input_ids"]).
    input_ids: paddle.Tensor
    # attention meta
    forward_mode: ForwardMode = ForwardMode.MIXED
    # share_inputs["ids_remove_padding"]; presumably the input ids with
    # padding stripped — TODO confirm.
    ids_remove_padding: Optional[paddle.Tensor] = None
    # share_inputs["seq_lens_encoder"]
    seq_lens_encoder: Optional[paddle.Tensor] = None
    # share_inputs["seq_lens_decoder"]
    seq_lens_decoder: Optional[paddle.Tensor] = None
    # share_inputs["seq_lens_this_time"]
    seq_lens_this_time: Optional[paddle.Tensor] = None
    # share_inputs["cum_offsets"]
    cum_offsets: Optional[paddle.Tensor] = None
    # share_inputs["block_tables"]
    block_tables: Optional[paddle.Tensor] = None
    # Attention backend handed to init_forward_meta.
    attn_backend: Optional['AttentionBackend'] = None
    # Rotary embeddings (share_inputs["rope_emb"]).
    rotary_embs: Optional[paddle.Tensor] = None
    # share_inputs["padding_offset"]
    padding_offset: Optional[paddle.Tensor] = None
    # share_inputs["cu_seqlens_q"]
    cu_seqlens_q: Optional[paddle.Tensor] = None
    # share_inputs["cu_seqlens_k"]
    cu_seqlens_k: Optional[paddle.Tensor] = None
    # share_inputs["caches"]
    caches: Optional[paddle.Tensor] = None
    # Attention mask; not populated by init_forward_meta.
    attn_mask: Optional[paddle.Tensor] = None
    # Not populated by init_forward_meta; defaults to 0.
    pre_caches_length: int = 0
    # Use cuda graph in this step. Used to avoid run cuda graph when in dummy run or prefill stage.
    step_use_cudagraph: bool = False
    # for attention backend
    decoder_batch_ids: Optional[paddle.Tensor] = None
    # for attention backend
    decoder_tile_ids_per_batch: Optional[paddle.Tensor] = None

    @classmethod
    def init_forward_meta(cls, share_inputs: Dict,
                          attn_backend: "AttentionBackend"):
        """
        Build an instance of `cls` from the shared input buffers.
        Args:
            share_inputs: dict that must contain "input_ids",
                "ids_remove_padding", "seq_lens_encoder", "seq_lens_decoder",
                "seq_lens_this_time", "cum_offsets", "block_tables",
                "rope_emb", "padding_offset", "cu_seqlens_q", "cu_seqlens_k"
                and "caches"; "decoder_batch_ids" and
                "decoder_tile_ids_per_batch" are optional.
            attn_backend: attention backend stored on the result
        Returns:
            A new `cls` instance with forward_mode set to MIXED.
        Raises:
            KeyError: if a required key is missing from `share_inputs`.
        """
        # TODO(gongshaotian): delete this func
        return cls(
            forward_mode=ForwardMode.MIXED,
            input_ids=share_inputs["input_ids"],
            ids_remove_padding=share_inputs["ids_remove_padding"],
            seq_lens_encoder=share_inputs["seq_lens_encoder"],
            seq_lens_decoder=share_inputs["seq_lens_decoder"],
            seq_lens_this_time=share_inputs["seq_lens_this_time"],
            cum_offsets=share_inputs["cum_offsets"],
            block_tables=share_inputs["block_tables"],
            attn_backend=attn_backend,
            rotary_embs=share_inputs["rope_emb"],
            padding_offset=share_inputs["padding_offset"],
            cu_seqlens_q=share_inputs["cu_seqlens_q"],
            cu_seqlens_k=share_inputs["cu_seqlens_k"],
            caches=share_inputs["caches"],
            # These two keys may be absent; fall back to None.
            decoder_batch_ids=share_inputs.get("decoder_batch_ids"),
            decoder_tile_ids_per_batch=share_inputs.get(
                "decoder_tile_ids_per_batch"),
        )
@dataclass
class XPUForwardMeta(ForwardMeta):
    """
    XPUForwardMeta is used to store the global meta information of the forward, and some XPU specific meta info.

    Every field added here is optional and defaults to None.
    """
    # Encoder / decoder batch mapping and index tensors (exact semantics are
    # defined by the XPU code that fills them — not visible in this file).
    encoder_batch_map: Optional[paddle.Tensor] = None
    decoder_batch_map: Optional[paddle.Tensor] = None
    encoder_batch_idx: Optional[paddle.Tensor] = None
    decoder_batch_idx: Optional[paddle.Tensor] = None
    # Encoder sequence "lod" and decoder context lengths.
    encoder_seq_lod: Optional[paddle.Tensor] = None
    decoder_context_len: Optional[paddle.Tensor] = None
    decoder_context_len_cache: Optional[paddle.Tensor] = None

    # `_cpu` counterparts of the fields above (presumably host-side copies
    # — TODO confirm).
    encoder_batch_map_cpu: Optional[paddle.Tensor] = None
    decoder_batch_map_cpu: Optional[paddle.Tensor] = None
    encoder_batch_idx_cpu: Optional[paddle.Tensor] = None
    decoder_batch_idx_cpu: Optional[paddle.Tensor] = None
    encoder_seq_lod_cpu: Optional[paddle.Tensor] = None
    decoder_context_len_cpu: Optional[paddle.Tensor] = None
    decoder_context_len_cache_cpu: Optional[paddle.Tensor] = None

    # Batch-level bookkeeping tensors.
    batch_tensor: Optional[paddle.Tensor] = None
    enc_batch: Optional[paddle.Tensor] = None
    dec_batch: Optional[paddle.Tensor] = None
    total_enc_len: Optional[paddle.Tensor] = None