Files
FastDeploy/fastdeploy/worker/model_runner/forward_meta.py
2025-06-16 00:04:48 +08:00

324 lines
10 KiB
Python

"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
from dataclasses import dataclass
from enum import IntEnum, auto
from typing import Optional
from typing import List, Optional, Tuple, Union
from typing import TYPE_CHECKING
import abc
import paddle
import numpy as np
import logging
import paddle
if TYPE_CHECKING:
from fastdeploy.model_executor.layers.attention import AttentionBackend, Attention
from fastdeploy.worker.model_runner.model_runner_base import ModelRunnerBase
# Module-scoped logger; records are emitted under this module's import name.
logger = logging.getLogger(name=__name__)
class ForwardMode(IntEnum):
    """
    Forward mode used during attention.

    Members are numbered by ``auto()`` starting at 1 (EXTEND=1, DECODE=2, MIXED=3).
    """
    # for prefill and extend
    EXTEND = auto()
    # for generation
    DECODE = auto()
    # NOTE(review): presumably a batch that mixes prefill/extend and decode
    # requests — confirm against the attention backends.
    MIXED = auto()

    def is_prefill(self) -> bool:
        """Whether it's a prefill (EXTEND) forward"""
        return self == ForwardMode.EXTEND

    def is_decode(self) -> bool:
        """Whether it's a decode forward"""
        return self == ForwardMode.DECODE

    def is_mixed(self) -> bool:
        """Whether it's a mixed forward"""
        return self == ForwardMode.MIXED
class ReqToTokenPool:
    """A memory pool that maps a request to its token locations.

    Slots are row indices into ``req_to_token``, an int32 tensor of shape
    ``(size, max_context_len)``. ``alloc`` hands slots out from the front of
    ``free_slots``; ``free`` appends them back to the end.
    """

    def __init__(
        self,
        size: int,
        max_context_len: int
    ):
        """
        Args:
            size: number of request slots (rows of ``req_to_token``).
            max_context_len: number of token locations per slot (columns).
        """
        self.size = size
        self.max_context_len = max_context_len
        self.req_to_token = paddle.zeros(
            (size, max_context_len), dtype=paddle.int32
        )
        self.free_slots = list(range(size))

    def write(self, indices, values):
        """Write ``values`` into ``req_to_token`` at ``indices``."""
        self.req_to_token[indices] = values

    def available_size(self) -> int:
        """Return the number of slots left."""
        return len(self.free_slots)

    def alloc(self, need_size: int) -> Optional[List[int]]:
        """Allocate ``need_size`` slots.

        Args:
            need_size: number of slots requested; must be non-negative.
        Returns:
            The allocated slot indices, or ``None`` when fewer than
            ``need_size`` slots are free (nothing is allocated in that case).
        Raises:
            ValueError: if ``need_size`` is negative. A negative value would
                otherwise slice from the end of the free list and hand out
                the wrong slots.
        """
        if need_size < 0:
            raise ValueError(f"need_size must be non-negative, got {need_size}")
        if need_size > len(self.free_slots):
            return None
        select_index = self.free_slots[:need_size]
        self.free_slots = self.free_slots[need_size:]
        return select_index

    def free(self, free_index: Union[int, List[int]]):
        """Return one slot (int) or several slots (iterable of int) to the pool."""
        if isinstance(free_index, int):
            self.free_slots.append(free_index)
        else:
            self.free_slots.extend(free_index)

    def clear(self):
        """Mark every slot as free again (tensor contents are left untouched)."""
        self.free_slots = list(range(self.size))
class KVCache(abc.ABC):
    """
    Abstract interface of a key/value cache.

    Subclasses supply per-layer reads (``get_kv_buffer``), writes
    (``set_kv_buffer``) and the ``transfer`` / ``transfer_per_layer`` copy
    hooks. The abstract methods below only raise ``NotImplementedError``.
    """

    @abc.abstractmethod
    def get_kv_buffer(self, layer_id: int) -> Tuple[paddle.Tensor, paddle.Tensor]:
        """
        Fetch the cached keys and values of a single layer.

        Args:
            layer_id: int, index of the layer to read.
        Returns:
            tuple: (keys, values)
        """
        raise NotImplementedError

    @abc.abstractmethod
    def set_kv_buffer(
        self,
        layer: 'Attention',
        loc: paddle.Tensor,
        cache_k: paddle.Tensor,
        cache_v: paddle.Tensor,
    ) -> None:
        """
        Store keys and values into the cache of ``layer``.

        Args:
            layer: Attention, the layer whose cache is written.
            loc: paddle.Tensor, index used to address the cache entries.
            cache_k: paddle.Tensor, keys to store.
            cache_v: paddle.Tensor, values to store.
        """
        raise NotImplementedError

    @abc.abstractmethod
    def transfer(self, indices, flat_data):
        """Move kv data between devices into the entries selected by ``indices``."""
        raise NotImplementedError

    @abc.abstractmethod
    def transfer_per_layer(self, indices, flat_data, layer_id):
        """Single-layer counterpart of ``transfer`` (not used yet)."""
        raise NotImplementedError

    def register_layer_transfer_counter(self, layer_transfer_counter):
        """Keep ``layer_transfer_counter`` on this cache instance (not used yet)."""
        self.layer_transfer_counter = layer_transfer_counter
class MHATokenToKVPool(KVCache):
    """
    Token To Key Value Pool for MultiHeadAttention.

    Keeps one key tensor and one value tensor per layer (``self.k_buffer`` /
    ``self.v_buffer``). Each is shaped
    ``[max_block_num, head_num, block_size, head_dim]`` with ``store_dtype``.
    """

    def __init__(
        self,
        max_block_num: int,
        block_size: int,
        dtype: paddle.dtype,
        head_num: int,
        head_dim: int,
        layer_num: int,
        device: str,
    ):
        """
        Args:
            max_block_num: number of cache blocks per layer.
            block_size: number of tokens per block.
            dtype: logical dtype of the cached keys/values.
            head_num: number of heads stored in the cache.
            head_dim: size of each head.
            layer_num: number of layers (one K and one V buffer each).
            device: device passed to ``Tensor.to`` by the transfer methods.
        """
        self.max_block_num = max_block_num
        self.block_size = block_size
        self.dtype = dtype
        self.device = device
        if dtype in (paddle.int8, paddle.float8_e4m3fn):
            # NOTE: 8-bit caches are stored as uint8 and reinterpreted via
            # ``Tensor.view(dtype)`` when read. The rationale carried over from
            # the original port is that index assignment is not supported for
            # float8 tensors — TODO confirm this still holds for Paddle.
            self.store_dtype = paddle.uint8
        else:
            self.store_dtype = dtype
        self.head_num = head_num
        self.head_dim = head_dim
        self.layer_num = layer_num

        self._create_buffers()

        k_size, v_size = self.get_kv_size_bytes()
        GB = 1024 * 1024 * 1024
        # Each layer's buffer holds max_block_num blocks of block_size tokens.
        # (The previous message referenced an undefined ``size`` -> NameError.)
        num_tokens = self.max_block_num * self.block_size
        logger.info(
            f"KV Cache is allocated. #tokens: {num_tokens}, K size: {k_size / GB:.2f} GB, V size: {v_size / GB:.2f} GB"
        )

    def _create_buffers(self):
        """Allocate zero-filled key and value buffers, one tensor per layer."""
        # Per-layer shape: [max_block_num, head_num, block_size, head_dim].
        # NOTE(review): an inherited comment claimed slot 0 is reserved for
        # dummy outputs of padded tokens; nothing here reserves it — confirm
        # with callers.
        shape = (self.max_block_num, self.head_num,
                 self.block_size, self.head_dim)
        self.k_buffer = [
            paddle.zeros(shape, dtype=self.store_dtype)
            for _ in range(self.layer_num)
        ]
        self.v_buffer = [
            paddle.zeros(shape, dtype=self.store_dtype)
            for _ in range(self.layer_num)
        ]

    def _clear_buffers(self):
        """Drop the references to the key and value buffers."""
        del self.k_buffer
        del self.v_buffer

    def get_kv_size_bytes(self):
        """
        Return ``(k_bytes, v_bytes)`` allocated for the buffers, for debugging purpose.

        Uses each tensor's real element size (its stored dtype) rather than
        assuming 4 bytes per element, which over-reported fp16/bf16/uint8 caches.
        """
        assert hasattr(self, "k_buffer")
        assert hasattr(self, "v_buffer")
        k_size_bytes = sum(
            int(np.prod(k_cache.shape)) * k_cache.element_size()
            for k_cache in self.k_buffer
        )
        v_size_bytes = sum(
            int(np.prod(v_cache.shape)) * v_cache.element_size()
            for v_cache in self.v_buffer
        )
        return k_size_bytes, v_size_bytes

    def transfer(self, indices, flat_data):
        """
        Copy prepared KV data for every layer onto ``self.device``.

        Args:
            indices: entries of each layer's buffer to overwrite.
            flat_data: tensor where ``flat_data[0][i]`` / ``flat_data[1][i]`` are
                the keys / values for layer ``i`` (presumably shaped
                ``[2, layer_num, ...]`` — TODO confirm with callers).
        """
        # Paddle's Tensor.to takes ``blocking``; torch's ``non_blocking`` is
        # rejected with a TypeError. blocking=True == non_blocking=False.
        flat_data = flat_data.to(device=self.device, blocking=True)
        k_data, v_data = flat_data[0], flat_data[1]
        for i in range(self.layer_num):
            self.k_buffer[i][indices] = k_data[i]
            self.v_buffer[i][indices] = v_data[i]

    def transfer_per_layer(self, indices, flat_data, layer_id):
        """Copy prepared KV data for a single layer ``layer_id`` onto ``self.device``."""
        flat_data = flat_data.to(device=self.device, blocking=True)
        k_data, v_data = flat_data[0], flat_data[1]
        self.k_buffer[layer_id][indices] = k_data
        self.v_buffer[layer_id][indices] = v_data

    def get_key_buffer(self, layer_id: int):
        """Return cached keys of ``layer_id``, reinterpreted as ``self.dtype`` if stored differently."""
        if self.store_dtype != self.dtype:
            return self.k_buffer[layer_id].view(self.dtype)
        return self.k_buffer[layer_id]

    def get_value_buffer(self, layer_id: int):
        """Return cached values of ``layer_id``, reinterpreted as ``self.dtype`` if stored differently."""
        if self.store_dtype != self.dtype:
            return self.v_buffer[layer_id].view(self.dtype)
        return self.v_buffer[layer_id]

    def get_kv_buffer(self, layer_id: int):
        """Return cached keys and values given layer id."""
        return self.get_key_buffer(layer_id), self.get_value_buffer(layer_id)

    def set_kv_buffer(
        self,
        layer: 'Attention',
        loc: paddle.Tensor,
        cache_k: paddle.Tensor,
        cache_v: paddle.Tensor,
        k_scale: Optional[float] = None,
        v_scale: Optional[float] = None,
    ):
        """
        Write keys/values into the buffers of ``layer.layer_id`` at ``loc``.

        When the inputs' dtype differs from ``self.dtype`` they are divided by
        the optional scales and cast to ``self.dtype``. They are then
        reinterpreted as ``store_dtype`` if needed. The caller's tensors are
        never modified.
        """
        layer_id = layer.layer_id
        if cache_k.dtype != self.dtype:
            # Out-of-place scaling: the previous ``div_`` silently mutated the
            # caller's tensors (and is a torch-ism).
            if k_scale is not None:
                cache_k = cache_k / k_scale
            if v_scale is not None:
                cache_v = cache_v / v_scale
            cache_k = cache_k.astype(self.dtype)
            cache_v = cache_v.astype(self.dtype)
        if self.store_dtype != self.dtype:
            cache_k = cache_k.view(self.store_dtype)
            cache_v = cache_v.view(self.store_dtype)
        self.k_buffer[layer_id][loc] = cache_k
        self.v_buffer[layer_id][loc] = cache_v
@dataclass
class ForwardMeta():
    """
    ForwardMeta is used to store the global meta information of the forward.

    Every field except ``input_ids`` is optional and defaults to ``None``
    (``forward_mode`` defaults to ``ForwardMode.MIXED``, ``pre_caches_length``
    to 0).
    """
    input_ids: paddle.Tensor

    # attention meta
    forward_mode: ForwardMode = ForwardMode.MIXED
    ids_remove_padding: Optional[paddle.Tensor] = None
    seq_lens_encoder: Optional[paddle.Tensor] = None
    seq_lens_decoder: Optional[paddle.Tensor] = None
    seq_lens_this_time: Optional[paddle.Tensor] = None
    # Declared once; a duplicate declaration further down was removed
    # (field order is unchanged).
    cum_offsets: Optional[paddle.Tensor] = None
    block_tables: Optional[paddle.Tensor] = None
    attn_backend: Optional['AttentionBackend'] = None
    rotary_embs: Optional[paddle.Tensor] = None
    padding_offset: Optional[paddle.Tensor] = None
    cu_seqlens_q: Optional[paddle.Tensor] = None
    cu_seqlens_k: Optional[paddle.Tensor] = None
    # NOTE(review): may hold a collection of per-layer cache tensors rather
    # than a single tensor — confirm against model_runner.share_inputs.
    caches: Optional[paddle.Tensor] = None
    attn_mask: Optional[paddle.Tensor] = None
    pre_caches_length: int = 0

    @classmethod
    def init_forward_meta(
        cls,
        model_runner: "ModelRunnerBase"
    ):
        """
        Build a ForwardMeta in MIXED mode from ``model_runner``.

        Args:
            model_runner: runner whose ``share_inputs`` mapping provides the
                tensors, and whose ``attn_backend`` is attached as-is.
        Returns:
            A new instance. ``attn_mask`` and ``pre_caches_length`` keep
            their defaults.
        Raises:
            KeyError: if ``share_inputs`` lacks one of the required keys.
        """
        share_inputs = model_runner.share_inputs
        return cls(
            forward_mode=ForwardMode.MIXED,
            input_ids=share_inputs["input_ids"],
            ids_remove_padding=share_inputs["ids_remove_padding"],
            seq_lens_encoder=share_inputs["seq_lens_encoder"],
            seq_lens_decoder=share_inputs["seq_lens_decoder"],
            seq_lens_this_time=share_inputs["seq_lens_this_time"],
            cum_offsets=share_inputs["cum_offsets"],
            block_tables=share_inputs["block_tables"],
            attn_backend=model_runner.attn_backend,
            rotary_embs=share_inputs["rope_emb"],
            padding_offset=share_inputs["padding_offset"],
            cu_seqlens_q=share_inputs["cu_seqlens_q"],
            cu_seqlens_k=share_inputs["cu_seqlens_k"],
            caches=share_inputs["caches"],
        )

    # Backward-compatible alias for the original (misspelled) factory name.
    init_forward_mata = init_forward_meta