Mirror of https://github.com/PaddlePaddle/FastDeploy.git
Sync v2.0 version of code to github repo
fastdeploy/worker/forward_meta.py (new file, 418 lines)
@@ -0,0 +1,418 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""

import abc
import logging
from dataclasses import dataclass
from enum import IntEnum, auto
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union

import numpy as np
import paddle

if TYPE_CHECKING:
    from fastdeploy.model_executor.layers.attention import (Attention,
                                                            AttentionBackend)

logger = logging.getLogger(__name__)


class ForwardMode(IntEnum):
    """
    Forward mode used during attention.
    """

    # For prefill and extend
    EXTEND = auto()
    # For generation
    DECODE = auto()
    # For mixed prefill/decode batches
    MIXED = auto()

    def is_prefill(self):
        """Whether it's a prefill forward"""
        return self == ForwardMode.EXTEND

    def is_decode(self):
        """Whether it's a decode forward"""
        return self == ForwardMode.DECODE

    def is_mixed(self):
        """Whether it's a mixed forward"""
        return self == ForwardMode.MIXED
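
# Usage sketch (illustrative only, not part of the original file): attention
# code can branch on the mode predicates to pick a prefill or a decode kernel.
#   mode = ForwardMode.DECODE
#   assert mode.is_decode() and not mode.is_prefill() and not mode.is_mixed()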


class ReqToTokenPool:
    """A memory pool that maps a request to its token locations."""

    def __init__(self, size: int, max_context_len: int):
        self.size = size
        self.max_context_len = max_context_len
        self.req_to_token = paddle.zeros((size, max_context_len),
                                         dtype=paddle.int32)
        self.free_slots = list(range(size))

    def write(self, indices, values):
        """Write data into the request buffer"""
        self.req_to_token[indices] = values

    def available_size(self):
        """Get the number of slots left"""
        return len(self.free_slots)

    def alloc(self, need_size: int) -> Optional[List[int]]:
        """Allocate `need_size` slots, or return None if not enough slots are free"""
        if need_size > len(self.free_slots):
            return None

        select_index = self.free_slots[:need_size]
        self.free_slots = self.free_slots[need_size:]

        return select_index

    def free(self, free_index: Union[int, List[int]]):
        """Free one slot or a list of slots"""
        if isinstance(free_index, int):
            self.free_slots.append(free_index)
        else:
            self.free_slots.extend(free_index)

    def clear(self):
        """Clear all slots"""
        self.free_slots = list(range(self.size))
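
# Usage sketch (illustrative only; the sizes and names are made up):
#   pool = ReqToTokenPool(size=4, max_context_len=2048)
#   slots = pool.alloc(2)                    # -> [0, 1]; None if not enough free
#   pool.write(slots[0], token_locations)    # token_locations: int32 tensor row
#   pool.free(slots)                         # slots become reusable
#   assert pool.available_size() == 4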


class KVCache(abc.ABC):
    """Abstract base class representing a key-value cache"""

    @abc.abstractmethod
    def get_kv_buffer(self,
                      layer_id: int) -> Tuple[paddle.Tensor, paddle.Tensor]:
        """
        Return cached keys and values given the layer id.
        Args:
            layer_id: int
        Returns:
            tuple: (keys, values)
        """
        raise NotImplementedError()

    @abc.abstractmethod
    def set_kv_buffer(
        self,
        layer: 'Attention',
        loc: paddle.Tensor,
        cache_k: paddle.Tensor,
        cache_v: paddle.Tensor,
    ) -> None:
        """
        Set cached keys and values given the layer id.
        Args:
            layer: Attention
            loc: paddle.Tensor
            cache_k: paddle.Tensor
            cache_v: paddle.Tensor
        """
        raise NotImplementedError()

    @abc.abstractmethod
    def transfer(self, indices, flat_data):
        """Transfer kv_data between devices"""
        raise NotImplementedError()

    @abc.abstractmethod
    def transfer_per_layer(self, indices, flat_data, layer_id):
        """Not used yet"""
        raise NotImplementedError()

    def register_layer_transfer_counter(self, layer_transfer_counter):
        """Not used yet"""
        self.layer_transfer_counter = layer_transfer_counter


class MHATokenToKVPool(KVCache):
    """Token To Key Value Pool for MultiHeadAttention"""

    def __init__(
        self,
        max_block_num: int,
        block_size: int,
        dtype: paddle.dtype,
        head_num: int,
        head_dim: int,
        layer_num: int,
        device: str,
    ):
        self.max_block_num = max_block_num
        self.block_size = block_size
        self.dtype = dtype
        self.device = device
        if dtype in (paddle.int8, paddle.float8_e4m3fn):
            # NOTE: store as uint8 because tensor index assignment is not
            # implemented for the low-precision int8/float8 dtypes
            self.store_dtype = paddle.uint8
        else:
            self.store_dtype = dtype

        self.head_num = head_num
        self.head_dim = head_dim
        self.layer_num = layer_num
        self._create_buffers()

        k_size, v_size = self.get_kv_size_bytes()
        GB = 1024 * 1024 * 1024
        logger.info(
            f"KV Cache is allocated. #tokens: {self.max_block_num * self.block_size}, "
            f"K size: {k_size / GB:.2f} GB, V size: {v_size / GB:.2f} GB")

    def _create_buffers(self):
        # [max_block_num, head_num, block_size, head_dim] for each layer.
        # The padded slot 0 is used for writing dummy outputs from padded tokens.
        self.k_buffer = [
            paddle.zeros(
                (self.max_block_num, self.head_num, self.block_size,
                 self.head_dim),
                dtype=self.store_dtype,
            ) for _ in range(self.layer_num)
        ]
        self.v_buffer = [
            paddle.zeros(
                (self.max_block_num, self.head_num, self.block_size,
                 self.head_dim),
                dtype=self.store_dtype,
            ) for _ in range(self.layer_num)
        ]

    def _clear_buffers(self):
        del self.k_buffer
        del self.v_buffer

    def get_kv_size_bytes(self):
        """For debugging purposes; assumes 4 bytes per element"""
        assert hasattr(self, "k_buffer")
        assert hasattr(self, "v_buffer")
        k_size_bytes = 0
        for k_cache in self.k_buffer:
            k_size_bytes += np.prod(k_cache.shape) * 4
        v_size_bytes = 0
        for v_cache in self.v_buffer:
            v_size_bytes += np.prod(v_cache.shape) * 4
        return k_size_bytes, v_size_bytes
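
    # Worked example (illustrative numbers only): with max_block_num=64,
    # block_size=64, head_num=8, head_dim=128 and layer_num=2, each per-layer
    # buffer holds 64 * 8 * 64 * 128 = 4,194,304 elements. At the assumed
    # 4 bytes per element that is 16 MiB per layer, so k_size_bytes comes to
    # 2 * 16 MiB = 32 MiB (and the same again for v_size_bytes).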

    def transfer(self, indices, flat_data):
        """Transfer prepared data from host to device"""
        flat_data = flat_data.to(device=self.device, non_blocking=False)
        k_data, v_data = flat_data[0], flat_data[1]
        for i in range(self.layer_num):
            self.k_buffer[i][indices] = k_data[i]
            self.v_buffer[i][indices] = v_data[i]

    def transfer_per_layer(self, indices, flat_data, layer_id):
        """Transfer prepared data for a specific layer from host to device"""
        flat_data = flat_data.to(device=self.device, non_blocking=False)
        k_data, v_data = flat_data[0], flat_data[1]
        self.k_buffer[layer_id][indices] = k_data
        self.v_buffer[layer_id][indices] = v_data

    def get_key_buffer(self, layer_id: int):
        """Return cached keys given the layer id."""
        if self.store_dtype != self.dtype:
            return self.k_buffer[layer_id].view(self.dtype)
        return self.k_buffer[layer_id]

    def get_value_buffer(self, layer_id: int):
        """Return cached values given the layer id."""
        if self.store_dtype != self.dtype:
            return self.v_buffer[layer_id].view(self.dtype)
        return self.v_buffer[layer_id]

    def get_kv_buffer(self, layer_id: int):
        """Return cached keys and values given the layer id."""
        return self.get_key_buffer(layer_id), self.get_value_buffer(layer_id)

    def set_kv_buffer(
        self,
        layer: 'Attention',
        loc: paddle.Tensor,
        cache_k: paddle.Tensor,
        cache_v: paddle.Tensor,
        k_scale: Optional[float] = None,
        v_scale: Optional[float] = None,
    ):
        """Set cached keys and values given the layer id."""
        layer_id = layer.layer_id
        if cache_k.dtype != self.dtype:
            # Apply quantization scales before casting to the cache dtype
            if k_scale is not None:
                cache_k.div_(k_scale)
            if v_scale is not None:
                cache_v.div_(v_scale)
            cache_k = cache_k.to(self.dtype)
            cache_v = cache_v.to(self.dtype)

        if self.store_dtype != self.dtype:
            cache_k = cache_k.view(self.store_dtype)
            cache_v = cache_v.view(self.store_dtype)

        self.k_buffer[layer_id][loc] = cache_k
        self.v_buffer[layer_id][loc] = cache_v
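
# Usage sketch (illustrative only; the sizes are made up):
#   pool = MHATokenToKVPool(max_block_num=64, block_size=64,
#                           dtype=paddle.float16, head_num=8, head_dim=128,
#                           layer_num=2, device="gpu")
#   k, v = pool.get_kv_buffer(layer_id=0)    # per-layer K/V block buffers
# During the forward pass, each attention layer calls set_kv_buffer() to
# scatter the freshly computed K/V tensors into the slots selected by the
# block table.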


@dataclass
class ForwardMeta():
    """
    ForwardMeta is used to store the global meta information of the forward pass.
    """
    # Input token ids
    input_ids: paddle.Tensor

    # Attention meta: forward mode (prefill/extend, decode, or mixed)
    forward_mode: ForwardMode = ForwardMode.MIXED

    # Input token ids with padding removed
    ids_remove_padding: paddle.Tensor = None

    # Sequence lengths of the encoder (prefill) part, per request
    seq_lens_encoder: Optional[paddle.Tensor] = None

    # Sequence lengths of the decoder part, per request
    seq_lens_decoder: Optional[paddle.Tensor] = None

    # Number of tokens processed in this step, per request
    seq_lens_this_time: Optional[paddle.Tensor] = None

    # Cumulative padding offsets, per request
    cum_offsets: Optional[paddle.Tensor] = None

    # Block tables mapping each request to its KV cache blocks
    block_tables: Optional[paddle.Tensor] = None

    # Attention backend used for this forward
    attn_backend: 'AttentionBackend' = None

    # Rotary position embeddings
    rotary_embs: Optional[paddle.Tensor] = None

    # Per-token padding offsets
    padding_offset: Optional[paddle.Tensor] = None

    # Cumulative query sequence lengths
    cu_seqlens_q: Optional[paddle.Tensor] = None

    # Cumulative key sequence lengths
    cu_seqlens_k: Optional[paddle.Tensor] = None

    # KV caches
    caches: Optional[paddle.Tensor] = None

    # Attention mask
    attn_mask: Optional[paddle.Tensor] = None

    # Length of the pre-filled caches
    pre_caches_length: int = 0

    # Whether to use CUDA Graph in this step. Used to avoid running the CUDA
    # Graph during a dummy run or the prefill stage.
    step_use_cudagraph: bool = False

    # Batch ids for the attention backend's decode path
    decoder_batch_ids: Optional[paddle.Tensor] = None
    # Tile ids per batch for the attention backend's decode path
    decoder_tile_ids_per_batch: Optional[paddle.Tensor] = None

    @classmethod
    def init_forward_meta(cls, share_inputs: Dict,
                          attn_backend: "AttentionBackend"):
        """Initialize the forward meta from shared inputs"""
        # TODO(gongshaotian): delete this func
        ret = cls(
            forward_mode=ForwardMode.MIXED,
            input_ids=share_inputs["input_ids"],
            ids_remove_padding=share_inputs["ids_remove_padding"],
            seq_lens_encoder=share_inputs["seq_lens_encoder"],
            seq_lens_decoder=share_inputs["seq_lens_decoder"],
            seq_lens_this_time=share_inputs["seq_lens_this_time"],
            cum_offsets=share_inputs["cum_offsets"],
            block_tables=share_inputs["block_tables"],
            attn_backend=attn_backend,
            rotary_embs=share_inputs["rope_emb"],
            padding_offset=share_inputs["padding_offset"],
            cu_seqlens_q=share_inputs["cu_seqlens_q"],
            cu_seqlens_k=share_inputs["cu_seqlens_k"],
            caches=share_inputs["caches"],
            decoder_batch_ids=share_inputs.get("decoder_batch_ids", None),
            decoder_tile_ids_per_batch=share_inputs.get(
                "decoder_tile_ids_per_batch", None),
        )
        return ret
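
    # Usage sketch (illustrative only): `share_inputs` is the dict of
    # pre-allocated tensors maintained by the model runner, `attn_backend` is
    # the backend instance, and the call site shown here is hypothetical.
    #   meta = ForwardMeta.init_forward_meta(share_inputs, attn_backend)
    #   model(forward_meta=meta)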


@dataclass
class XPUForwardMeta(ForwardMeta):
    """
    XPUForwardMeta is used to store the global meta information of the forward pass,
    plus some XPU-specific meta info.
    """
    # Batch map of the encoder (prefill) requests
    encoder_batch_map: Optional[paddle.Tensor] = None

    # Batch map of the decoder requests
    decoder_batch_map: Optional[paddle.Tensor] = None

    # Batch indices of the encoder requests
    encoder_batch_idx: Optional[paddle.Tensor] = None

    # Batch indices of the decoder requests
    decoder_batch_idx: Optional[paddle.Tensor] = None

    # LoD (level-of-detail) info of the encoder sequences
    encoder_seq_lod: Optional[paddle.Tensor] = None

    # Context lengths of the decoder requests
    decoder_context_len: Optional[paddle.Tensor] = None

    # Cached context lengths of the decoder requests
    decoder_context_len_cache: Optional[paddle.Tensor] = None

    # CPU copy of encoder_batch_map
    encoder_batch_map_cpu: Optional[paddle.Tensor] = None

    # CPU copy of decoder_batch_map
    decoder_batch_map_cpu: Optional[paddle.Tensor] = None

    # CPU copy of encoder_batch_idx
    encoder_batch_idx_cpu: Optional[paddle.Tensor] = None

    # CPU copy of decoder_batch_idx
    decoder_batch_idx_cpu: Optional[paddle.Tensor] = None

    # CPU copy of encoder_seq_lod
    encoder_seq_lod_cpu: Optional[paddle.Tensor] = None

    # CPU copy of decoder_context_len
    decoder_context_len_cpu: Optional[paddle.Tensor] = None

    # CPU copy of decoder_context_len_cache
    decoder_context_len_cache_cpu: Optional[paddle.Tensor] = None

    # Auxiliary batch tensor
    batch_tensor: Optional[paddle.Tensor] = None

    # Number of encoder (prefill) requests in the batch
    enc_batch: Optional[paddle.Tensor] = None

    # Number of decoder requests in the batch
    dec_batch: Optional[paddle.Tensor] = None

    # Total number of encoder tokens
    total_enc_len: Optional[paddle.Tensor] = None