Mirror of https://github.com/PaddlePaddle/FastDeploy.git, synced 2025-10-04 00:06:38 +08:00
Clear dead code And supplementary notes (#2757)
Some checks failed
Deploy GitHub Pages / deploy (push) Has been cancelled
* 1. supplementary notes 2. delete dead code
* fix bug of forward meta
* Global modification of forward meta
* fix vl model_runner bug
@@ -46,13 +46,9 @@ class ConcreteSizeEntry:
# Output buffer of cudagraph
output_buffer: Optional[paddle.Tensor] = None

# for cudagraph debugging, track the input addresses
# during capture, and check if they are the same during replay
input_addresses: Optional[list[int]] = None

class CudaGraphPiecewiseBackend:
""" """
""" Manage the capture and replay of CUDA graphs at the subgraph level. """

def __init__(
self,
@@ -65,33 +61,31 @@ class CudaGraphPiecewiseBackend:
self.warm_up_size = fd_config.graph_opt_config.cudagraph_num_of_warmups
self.batch_size_to_captured_size = fd_config.graph_opt_config.batch_size_to_captured_size

# runtime_bs -> ConcreteSizeEntry
# Runtime batch size -> ConcreteSizeEntry
self.concrete_size_entries: Dict[int, ConcreteSizeEntry] = {}

for shape in self.cudagraph_capture_sizes:
self.concrete_size_entries[shape] = ConcreteSizeEntry(
runtime_bs=shape)

print("[CUDA GRAPH] Created all batch size entry ")
logger.debug("[CUDA GRAPH] Created all batch size entry ")

def __call__(self, **kwargs):
# Get batch size
ids_remove_padding: paddle.Tensor = kwargs["ids_remove_padding"]
batch_size = ids_remove_padding.shape[0]

padding_batch_size = self.batch_size_to_captured_size[batch_size]
# print(
# f"[CUDA GRAPH] The actual batch size obtained by CUDAGraph is :{batch_size}, ",
# f"The padded batch size is :{padding_batch_size}"
# )
logger.debug(
f"[CUDA GRAPH] The actual batch size obtained by CUDAGraph is :{batch_size}, ",
f"The padded batch size is :{padding_batch_size}")

entry = self.concrete_size_entries.get(padding_batch_size)
assert entry is not None, f"Batch size:{padding_batch_size} is not in cuda graph capture list."
if entry.runnable is None:
entry.runnable = self.runnable
# print(
# f"[CUDA GRAPH] New entry lazy initialize with batch size {padding_batch_size}"
# )
logger.debug(
f"[CUDA GRAPH] New entry lazy initialize with batch size {padding_batch_size}"
)

if not entry.use_cudagraph:
return entry.runnable(**kwargs)
@@ -102,10 +96,10 @@ class CudaGraphPiecewiseBackend:
for n in range(entry.num_finished_warmup, self.warm_up_size):
entry.num_finished_warmup += 1
entry.runnable(**kwargs)
# print(
# "[CUDA GRAPH] Warm up for batch size ",
# f"{padding_batch_size}, finished ({n+1}/{entry.num_finished_warmup}) times"
# )
logger.debug(
"[CUDA GRAPH] Warm up for batch size ",
f"{padding_batch_size}, finished ({n+1}/{entry.num_finished_warmup}) times"
)

# Store input addresses for debug
input_addresses = [
@@ -129,11 +123,13 @@ class CudaGraphPiecewiseBackend:
output._clear

paddle.device.synchronize()
# print(
# f"[CUDA GRAPH] CUDAGraph captured for batch size {padding_batch_size}"
# )
logger.debug(
f"[CUDA GRAPH] CUDAGraph captured for batch size {padding_batch_size}"
)

# Replay
entry.cuda_graph.replay()
# print(f"[CUDA GRAPH] CUDAGraph replayed for batch size {padding_batch_size}")
logger.debug(
f"[CUDA GRAPH] CUDAGraph replayed for batch size {padding_batch_size}"
)
return entry.output_buffer
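For readers outside the diff context, here is a small, self-contained sketch of the lookup that __call__ performs above: the runtime batch size is padded up to the nearest captured size and the matching entry is fetched. The capture sizes, the mapping helper and the Entry dataclass below are invented for illustration and are not FastDeploy code.

# Illustration only: simplified stand-ins for ConcreteSizeEntry and the
# batch_size_to_captured_size mapping used by CudaGraphPiecewiseBackend above.
from dataclasses import dataclass
from typing import Dict, Optional

@dataclass
class Entry:
    runtime_bs: int
    captured: bool = False

def build_batch_size_to_captured_size(capture_sizes, max_bs):
    # Map every runtime batch size to the smallest captured size >= it.
    sizes = sorted(capture_sizes)
    return {bs: next((s for s in sizes if s >= bs), max_bs)
            for bs in range(1, max_bs + 1)}

capture_sizes = [1, 2, 4, 8]
mapping: Dict[int, int] = build_batch_size_to_captured_size(capture_sizes, max_bs=8)
entries: Dict[int, Entry] = {s: Entry(runtime_bs=s) for s in capture_sizes}

runtime_batch_size = 3
padded = mapping[runtime_batch_size]          # 3 is padded up to 4
entry: Optional[Entry] = entries.get(padded)  # entry captured for the padded size
assert entry is not None, f"Batch size:{padded} is not in cuda graph capture list."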
@@ -28,7 +28,7 @@ _T = TypeVar("_T", bound=type[paddle.nn.Layer])
def support_graph_optimization(cls: Optional[_T] = None) -> _T:
"""
A decorator for wrapping models or layers with CUDA graph support.
A decorator for wrapping models or layers with static graph and CUDAGraph support.
This enables efficient kernel launch sequencing for improved GPU performance.

Example usage:
@@ -74,7 +74,7 @@ def support_graph_optimization(cls: Optional[_T] = None) -> _T:

class GraphOptWrapper:
""" """
""" The wrapper for GraphOptBackend """

def __init__(
self,
@@ -87,7 +87,7 @@ class GraphOptWrapper:

@abstractmethod
def forward(self, **kwargs):
""" """
""" Abstract methods for implementing model.forward() """
pass

def __call__(self, **kwargs):
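As a rough illustration of the decorator pattern documented above ("Example usage:"), the sketch below wraps a toy paddle.nn.Layer. The support_graph_optimization defined here is a local dummy that only marks the class; the real decorator reroutes forward() through GraphOptBackend, so this is a pattern sketch rather than FastDeploy's implementation.

# Illustration only: a dummy decorator standing in for the real
# support_graph_optimization, which wires the model into GraphOptBackend.
import paddle

def support_graph_optimization(cls):
    cls.uses_graph_optimization = True  # marker attribute for this sketch
    return cls

@support_graph_optimization
class ToyModel(paddle.nn.Layer):
    def __init__(self):
        super().__init__()
        self.linear = paddle.nn.Linear(16, 16)

    def forward(self, ids_remove_padding: paddle.Tensor, **kwargs):
        # Wrapped models are called with keyword arguments, matching the
        # **kwargs-style __call__ of GraphOptWrapper shown above.
        return self.linear(ids_remove_padding)

model = ToyModel()
out = model(ids_remove_padding=paddle.randn([4, 16]))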
@@ -24,7 +24,10 @@ from fastdeploy.model_executor.graph_optimization.cudagraph_piecewise_backend im
class GraphOptBackend:
""" """
"""
Integrated various graph optimization functions, including dynamic graph to static graph conversion,
CINN compilation optimization, CudaGraph, and so on.
"""

fd_config: FDConfig
cudagraph_piecewise_backend: Optional[CudaGraphPiecewiseBackend] = None
@@ -436,8 +436,24 @@ class MTPProposer(Proposer):
Initialize forward meta and attention meta data
"""
# Initialize forward meta
self.forward_meta = ForwardMeta.init_forward_meta(
self.model_inputs, self.attn_backends[0])
self.forward_meta = ForwardMeta(
input_ids=self.model_inputs["input_ids"],
ids_remove_padding=self.model_inputs["ids_remove_padding"],
rotary_embs=self.model_inputs["rope_emb"],
attn_backend=self.attn_backends[0],
decoder_batch_ids=self.model_inputs["decoder_batch_ids"],
decoder_tile_ids_per_batch=self.model_inputs["decoder_tile_ids_per_batch"],
seq_lens_encoder=self.model_inputs["seq_lens_encoder"],
seq_lens_decoder=self.model_inputs["seq_lens_decoder"],
seq_lens_this_time=self.model_inputs["seq_lens_this_time"],
cum_offsets=self.model_inputs["cum_offsets"],
padding_offset=self.model_inputs["padding_offset"],
cu_seqlens_q=self.model_inputs["cu_seqlens_q"],
cu_seqlens_k=self.model_inputs["cu_seqlens_k"],
block_tables=self.model_inputs["block_tables"],
caches=self.model_inputs["caches"]
)

# Initialzie attention meta data
for attn_backend in self.attn_backends:
@@ -14,18 +14,15 @@
# limitations under the License.
"""

import abc
import logging
from dataclasses import dataclass
from enum import IntEnum, auto
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
from typing import TYPE_CHECKING, Optional

import numpy as np
import paddle

if TYPE_CHECKING:
from fastdeploy.model_executor.layers.attention import (Attention,
AttentionBackend)
from fastdeploy.model_executor.layers.attention import AttentionBackend

logger = logging.getLogger(__name__)

@@ -34,333 +31,79 @@ class ForwardMode(IntEnum):
"""
Forward mode used during attention.
"""

# for prefill and extend
# Prefill and Extend mode
EXTEND = auto()
# for generation
# Decode mode
DECODE = auto()

# Mixed mode
MIXED = auto()

def is_prefill(self):
"""Whether it's a prefill forward"""
""" Is Extend mode """
return self == ForwardMode.EXTEND

def is_decode(self):
"""Whether it's a decode forward"""
""" Is Decode mode """
return self == ForwardMode.DECODE

def is_mixed(self):
"""Whether it's a decode forward"""
""" Is Mixed mode """
return self == ForwardMode.MIXED

class ReqToTokenPool:
"""A memory pool that maps a request to its token locations."""

def __init__(self, size: int, max_context_len: int):

self.size = size
self.max_context_len = max_context_len
self.req_to_token = paddle.zeros((size, max_context_len),
dtype=paddle.int32)
self.free_slots = list(range(size))

def write(self, indices, values):
"""Write data into request buffer"""
self.req_to_token[indices] = values

def available_size(self):
"""Get number of slots left"""
return len(self.free_slots)

def alloc(self, need_size: int) -> List[int]:
"""Allocate `need_size` slots"""
if need_size > len(self.free_slots):
return None

select_index = self.free_slots[:need_size]
self.free_slots = self.free_slots[need_size:]

return select_index

def free(self, free_index: Union[int, List[int]]):
"""Free slot"""
if isinstance(free_index, (int, )):
self.free_slots.append(free_index)
else:
self.free_slots.extend(free_index)

def clear(self):
"""Clear all slots"""
self.free_slots = list(range(self.size))
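A brief usage sketch of the ReqToTokenPool API shown above (this pool is part of the dead code this commit removes); the sizes and token values are made up for illustration and assume the class definition above is in scope.

# Illustration only: exercises the ReqToTokenPool interface defined above.
import paddle

pool = ReqToTokenPool(size=4, max_context_len=8)
slots = pool.alloc(2)                                 # e.g. [0, 1]
pool.write(slots[0], paddle.arange(8, dtype='int32')) # arbitrary token locations
assert pool.available_size() == 2                     # two slots left
pool.free(slots)                                      # give both slots back
pool.clear()                                          # reset the whole pool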
class KVCache(abc.ABC):
"""Abstract base class representing a key value cache"""

@abc.abstractmethod
def get_kv_buffer(self,
layer_id: int) -> Tuple[paddle.Tensor, paddle.Tensor]:
"""
Return cached keys and values given layer id.
Args:
layer_id: int
Returns:
tuple: (keys, values)
"""
raise NotImplementedError()

@abc.abstractmethod
def set_kv_buffer(
self,
layer: 'Attention',
loc: paddle.Tensor,
cache_k: paddle.Tensor,
cache_v: paddle.Tensor,
) -> None:
"""
Set cached keys and values given layer id.
Args:
layer: Attention
loc: paddle.Tensor
cache_k: paddle.Tensor
cache_v: paddle.Tensor
"""
raise NotImplementedError()

@abc.abstractmethod
def transfer(self, indices, flat_data):
"""Transfer kv_data between devices"""
raise NotImplementedError()

@abc.abstractmethod
def transfer_per_layer(self, indices, flat_data, layer_id):
"""Not used yet"""
raise NotImplementedError()

def register_layer_transfer_counter(self, layer_transfer_counter):
"""Not used yet"""
self.layer_transfer_counter = layer_transfer_counter

class MHATokenToKVPool(KVCache):
"""Token To Key Value Pool for MultiHeadAttention"""

def __init__(
self,
max_block_num: int,
block_size: int,
dtype: paddle.dtype,
head_num: int,
head_dim: int,
layer_num: int,
device: str,
):
self.max_block_num = max_block_num
self.block_size = block_size
self.dtype = dtype
self.device = device
if dtype in (paddle.int8, paddle.float8_e4m3fn):
# NOTE: Store as torch.uint8 because Tensor.index_put is not implemented for torch.float8_e5m2
self.store_dtype = paddle.uint8
else:
self.store_dtype = dtype

self.head_num = head_num
self.head_dim = head_dim
self.layer_num = layer_num
self._create_buffers()

k_size, v_size = self.get_kv_size_bytes()
GB = 1024 * 1024 * 1024
logger.info(
f"KV Cache is allocated. #tokens: {self.size}, K size: {k_size / GB:.2f} GB, V size: {v_size / GB:.2f} GB"
)

def _create_buffers(self):
# [size, head_num, head_dim] for each layer
# The padded slot 0 is used for writing dummy outputs from padded tokens.
self.k_buffer = [
paddle.zeros(
(self.max_block_num, self.head_num, self.block_size,
self.head_dim),
dtype=self.store_dtype,
) for _ in range(self.layer_num)
]
self.v_buffer = [
paddle.zeros(
(self.max_block_num, self.head_num, self.block_size,
self.head_dim),
dtype=self.store_dtype,
) for _ in range(self.layer_num)
]

def _clear_buffers(self):
del self.k_buffer
del self.v_buffer

def get_kv_size_bytes(self):
"""for debugging purpose"""
assert hasattr(self, "k_buffer")
assert hasattr(self, "v_buffer")
k_size_bytes = 0
for k_cache in self.k_buffer:
k_size_bytes += np.prod(k_cache.shape) * 4
v_size_bytes = 0
for v_cache in self.v_buffer:
v_size_bytes += np.prod(v_cache.shape) * 4
return k_size_bytes, v_size_bytes

def transfer(self, indices, flat_data):
# transfer prepared data from host to device
flat_data = flat_data.to(device=self.device, non_blocking=False)
k_data, v_data = flat_data[0], flat_data[1]
for i in range(self.layer_num):
self.k_buffer[i][indices] = k_data[i]
self.v_buffer[i][indices] = v_data[i]

def transfer_per_layer(self, indices, flat_data, layer_id):
# transfer prepared data for a specific layer from host to device
flat_data = flat_data.to(device=self.device, non_blocking=False)
k_data, v_data = flat_data[0], flat_data[1]
self.k_buffer[layer_id][indices] = k_data
self.v_buffer[layer_id][indices] = v_data

def get_key_buffer(self, layer_id: int):
"""Return cached keys given layer id."""
if self.store_dtype != self.dtype:
return self.k_buffer[layer_id].view(self.dtype)
return self.k_buffer[layer_id]

def get_value_buffer(self, layer_id: int):
"""Return cached values given layer id."""
if self.store_dtype != self.dtype:
return self.v_buffer[layer_id].view(self.dtype)
return self.v_buffer[layer_id]

def get_kv_buffer(self, layer_id: int):
"""Return cached keys and values given layer id."""
return self.get_key_buffer(layer_id), self.get_value_buffer(layer_id)

def set_kv_buffer(
self,
layer: 'Attention',
loc: paddle.Tensor,
cache_k: paddle.Tensor,
cache_v: paddle.Tensor,
k_scale: Optional[float] = None,
v_scale: Optional[float] = None,
):
"""Set cached keys and values given layer id."""
layer_id = layer.layer_id
if cache_k.dtype != self.dtype:
if k_scale is not None:
cache_k.div_(k_scale)
if v_scale is not None:
cache_v.div_(v_scale)
cache_k = cache_k.to(self.dtype)
cache_v = cache_v.to(self.dtype)

if self.store_dtype != self.dtype:
cache_k = cache_k.view(self.store_dtype)
cache_v = cache_v.view(self.store_dtype)

self.k_buffer[layer_id][loc] = cache_k
self.v_buffer[layer_id][loc] = cache_v
@dataclass
class ForwardMeta():
"""
ForwardMeta is used to store the global meta information of the forward.
ForwardMeta is used to store the global meta information of the model forward.
"""
#
# Input tokens IDs
input_ids: paddle.Tensor

#attention meta
forward_mode: ForwardMode = ForwardMode.MIXED

#
ids_remove_padding: paddle.Tensor = None

#
seq_lens_encoder: Optional[paddle.Tensor] = None

#
seq_lens_decoder: Optional[paddle.Tensor] = None

#
seq_lens_this_time: Optional[paddle.Tensor] = None

#
cum_offsets: Optional[paddle.Tensor] = None

#
block_tables: Optional[paddle.Tensor] = None

#
attn_backend: 'AttentionBackend' = None

#
# Input tokens IDs of removed padding
ids_remove_padding: paddle.Tensor
# Rotation position embedding
rotary_embs: Optional[paddle.Tensor] = None

#
padding_offset: Optional[paddle.Tensor] = None

#
cu_seqlens_q: Optional[paddle.Tensor] = None

#
cu_seqlens_k: Optional[paddle.Tensor] = None

#
caches: Optional[paddle.Tensor] = None

#
attn_mask: Optional[paddle.Tensor] = None

#
pre_caches_length: int = 0

# Use cuda graph in this step. Used to avoid run cuda graph when in dummy run or prefill stage.
# Use cuda graph in this step or not. Used to avoid run cuda graph when in dummy run or prefill stage.
step_use_cudagraph: bool = False

# for attention backend
decoder_batch_ids: Optional[paddle.Tensor] = None
# for attention backend
decoder_tile_ids_per_batch: Optional[paddle.Tensor] = None
# is_decode_batch or not
# Batch type flag
is_decode_batch: bool = False

@classmethod
def init_forward_meta(cls, share_inputs: Dict,
attn_backend: "AttentionBackend"):
""" init forward meta """
# TODO(gongshaotian): delete this func
ret = cls(
forward_mode=ForwardMode.MIXED,
input_ids=share_inputs["input_ids"],
ids_remove_padding=share_inputs["ids_remove_padding"],
seq_lens_encoder=share_inputs["seq_lens_encoder"],
seq_lens_decoder=share_inputs["seq_lens_decoder"],
seq_lens_this_time=share_inputs["seq_lens_this_time"],
cum_offsets=share_inputs["cum_offsets"],
block_tables=share_inputs["block_tables"],
attn_backend=attn_backend,
rotary_embs=share_inputs["rope_emb"],
padding_offset=share_inputs["padding_offset"],
cu_seqlens_q=share_inputs["cu_seqlens_q"],
cu_seqlens_k=share_inputs["cu_seqlens_k"],
caches=share_inputs["caches"],
decoder_batch_ids=share_inputs.get("decoder_batch_ids", None),
decoder_tile_ids_per_batch=share_inputs.get(
"decoder_tile_ids_per_batch", None),
)
return ret

# Attention backend object
attn_backend: 'AttentionBackend' = None
# Forward mode used during attention
forward_mode: ForwardMode = ForwardMode.MIXED
# Attention mask
attn_mask: Optional[paddle.Tensor] = None
# Decoder batch id. Used by attention backend.
decoder_batch_ids: Optional[paddle.Tensor] = None
# Tile ID for each batch of the decoder. Used by attention backend.
decoder_tile_ids_per_batch: Optional[paddle.Tensor] = None

# Sequence length of encoder for ever batch
seq_lens_encoder: Optional[paddle.Tensor] = None
# Sequence length of Encoder for ever batch
seq_lens_decoder: Optional[paddle.Tensor] = None
# The sequence length processed in the current step
seq_lens_this_time: Optional[paddle.Tensor] = None

# Accumulated offset
cum_offsets: Optional[paddle.Tensor] = None
# Offset tensor, used to restore the position of ids_remove_madding after padding removal to the original input_ids
padding_offset: Optional[paddle.Tensor] = None
# Accumulated sequence length of query
cu_seqlens_q: Optional[paddle.Tensor] = None
# Accumulated sequence length of key
cu_seqlens_k: Optional[paddle.Tensor] = None

# Pre-cache length
pre_caches_length: int = 0
# Block tables
block_tables: Optional[paddle.Tensor] = None
# KV caches
caches: Optional[paddle.Tensor] = None

def clear_caches(self):
"""safe clear caches"""
""" Safely clean up the caches """
if self.caches:
del self.caches
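With init_forward_meta deleted, the runners below construct ForwardMeta directly. Going by the field list above, only the two fields without defaults are required to instantiate the dataclass; the tensors in this sketch are dummies for illustration, while the real runners pass the full argument set shown in the hunks that follow.

# Illustration only: minimal direct construction of the ForwardMeta above.
import paddle

meta = ForwardMeta(
    input_ids=paddle.zeros([8, 1], dtype='int64'),
    ids_remove_padding=paddle.zeros([8], dtype='int64'),
)
assert meta.forward_mode == ForwardMode.MIXED   # default
assert meta.step_use_cudagraph is False         # default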
@@ -370,56 +113,42 @@ class XPUForwardMeta(ForwardMeta):
"""
XPUForwardMeta is used to store the global meta information of the forward, and some XPU specific meta info.
"""
# TODO(wanghaitao): Supplementary notes
#
encoder_batch_map: Optional[paddle.Tensor] = None

#
decoder_batch_map: Optional[paddle.Tensor] = None

#
encoder_batch_idx: Optional[paddle.Tensor] = None

#
decoder_batch_idx: Optional[paddle.Tensor] = None

#
encoder_seq_lod: Optional[paddle.Tensor] = None

#
decoder_context_len: Optional[paddle.Tensor] = None

#
decoder_context_len_cache: Optional[paddle.Tensor] = None

#
encoder_batch_map_cpu: Optional[paddle.Tensor] = None

#
decoder_batch_map_cpu: Optional[paddle.Tensor] = None

#
encoder_batch_idx_cpu: Optional[paddle.Tensor] = None

#
decoder_batch_idx_cpu: Optional[paddle.Tensor] = None

#
encoder_seq_lod_cpu: Optional[paddle.Tensor] = None

#
decoder_context_len_cpu: Optional[paddle.Tensor] = None

#
decoder_context_len_cache_cpu: Optional[paddle.Tensor] = None

#
batch_tensor: Optional[paddle.Tensor] = None

#
enc_batch: Optional[paddle.Tensor] = None

#
dec_batch: Optional[paddle.Tensor] = None

#
total_enc_len: Optional[paddle.Tensor] = None
@@ -606,8 +606,23 @@ class GCUModelRunner(ModelRunnerBase):
Initialize forward meta and attention meta data
"""
# Initialize forward meta
self.forward_meta = ForwardMeta.init_forward_meta(
self.share_inputs, self.attn_backends[0])
self.forward_meta = ForwardMeta(
input_ids=self.share_inputs["input_ids"],
ids_remove_padding=self.share_inputs["ids_remove_padding"],
rotary_embs=self.share_inputs["rope_emb"],
attn_backend=self.attn_backends[0],
decoder_batch_ids=self.share_inputs["decoder_batch_ids"],
decoder_tile_ids_per_batch=self.share_inputs["decoder_tile_ids_per_batch"],
seq_lens_encoder=self.share_inputs["seq_lens_encoder"],
seq_lens_decoder=self.share_inputs["seq_lens_decoder"],
seq_lens_this_time=self.share_inputs["seq_lens_this_time"],
cum_offsets=self.share_inputs["cum_offsets"],
padding_offset=self.share_inputs["padding_offset"],
cu_seqlens_q=self.share_inputs["cu_seqlens_q"],
cu_seqlens_k=self.share_inputs["cu_seqlens_k"],
block_tables=self.share_inputs["block_tables"],
caches=self.share_inputs["caches"]
)

# Initialzie attention meta data
for attn_backend in self.attn_backends:
@@ -48,7 +48,6 @@ from fastdeploy.worker.output import ModelOutputData, ModelRunnerOutput
class GPUModelRunner(ModelRunnerBase):
""" """

def __init__(
self,
@@ -81,9 +80,6 @@ class GPUModelRunner(ModelRunnerBase):
self.use_cudagraph = self.graph_opt_config.use_cudagraph
self.cudagraph_capture_sizes = list(
reversed(self.graph_opt_config.cudagraph_capture_sizes))
self.cudagraph_num_of_warmups = self.graph_opt_config.cudagraph_num_of_warmups
self.input_ids = paddle.zeros(self.parallel_config.max_num_seqs,
dtype='int32')

# Initialize share inputs
self._init_share_inputs(self.parallel_config.max_num_seqs)
@@ -94,7 +90,7 @@ class GPUModelRunner(ModelRunnerBase):
self.restore_chunked_prefill_request = dict()

# Initialize attention Backend
# Note(gonshaotian): Currently, all attention layers share one attention backend instance.
# NOTE(gonshaotian): Currently, all attention layers share one attention backend instance.
# In the future, we will expand it as a list.
self.attn_backends: list[AttentionBackend] = []
# self.attn_metadatas: list[AttentionMetadata] = []
@@ -110,14 +106,14 @@ class GPUModelRunner(ModelRunnerBase):

def prefill_finished(self):
"""
check whether prefill stage finished
Check whether prefill stage finished
"""
if int(paddle.max(self.share_inputs['seq_lens_encoder'])) != 0:
return 1
else:
return 0

def init_speculative_proposer(self):
def _init_speculative_proposer(self):
"""
Init speculative proposer
"""
@@ -333,8 +329,8 @@ class GPUModelRunner(ModelRunnerBase):
(idx + 1) * block_num, 1)

def _init_share_inputs(self, max_num_seqs: int):
"""Initialize all share buffers for model inputs.
Note: In the future, we may abandon share buffers.
"""
Initialize all share buffers for model inputs.
"""
self.MAX_INFER_SEED = 9223372036854775806
self.share_inputs = {}
@@ -469,6 +465,7 @@ class GPUModelRunner(ModelRunnerBase):
# Initialize rotary position embedding
tmp_position_ids = paddle.arange(
self.parallel_config.max_model_len).reshape((1, -1))

# TODO(gongshaotian): move to models
self.share_inputs["rope_emb"] = get_rope(
rotary_dim=self.model_config.head_dim,
@@ -536,7 +533,7 @@ class GPUModelRunner(ModelRunnerBase):
dtype="int32")

def _prepare_inputs(self) -> None:
""" prepare the model inputs """
""" Prepare the model inputs """
# Remove padding
(
ids_remove_padding,
@@ -595,7 +592,8 @@ class GPUModelRunner(ModelRunnerBase):
if self.fd_config.load_config.dynamic_load_weight:
from fastdeploy.rl.dynamic_weight_manager import \
DynamicWeightManager
self.dynamic_weight_manager = DynamicWeightManager(self.fd_config, self.model)
self.dynamic_weight_manager = DynamicWeightManager(
self.fd_config, self.model)

# 2. Load lora model

@@ -606,10 +604,10 @@ class GPUModelRunner(ModelRunnerBase):
f"Model loading took {time_after_load - time_before_load} seconds")

# 4. Init proposer for speculative method
self.init_speculative_proposer()
self._init_speculative_proposer()

def get_model(self) -> nn.Layer:
""" get current model """
""" Get current model """
return self.model

def initialize_forward_meta(self):
@@ -617,32 +615,28 @@ class GPUModelRunner(ModelRunnerBase):
Initialize forward meta and attention meta data
"""
# Initialize forward meta
self.forward_meta = ForwardMeta.init_forward_meta(
self.share_inputs, self.attn_backends[0])
self.forward_meta = ForwardMeta(
input_ids=self.share_inputs["input_ids"],
ids_remove_padding=self.share_inputs["ids_remove_padding"],
rotary_embs=self.share_inputs["rope_emb"],
attn_backend=self.attn_backends[0],
decoder_batch_ids=self.share_inputs["decoder_batch_ids"],
decoder_tile_ids_per_batch=self.share_inputs["decoder_tile_ids_per_batch"],
seq_lens_encoder=self.share_inputs["seq_lens_encoder"],
seq_lens_decoder=self.share_inputs["seq_lens_decoder"],
seq_lens_this_time=self.share_inputs["seq_lens_this_time"],
cum_offsets=self.share_inputs["cum_offsets"],
padding_offset=self.share_inputs["padding_offset"],
cu_seqlens_q=self.share_inputs["cu_seqlens_q"],
cu_seqlens_k=self.share_inputs["cu_seqlens_k"],
block_tables=self.share_inputs["block_tables"],
caches=self.share_inputs["caches"]
)

# Initialzie attention meta data
for attn_backend in self.attn_backends:
attn_backend.init_attention_metadata(self.forward_meta)

def clear_cache(self):
"""Clear cached data from shared inputs and forward metadata."""
self.share_inputs.pop("caches", None)
if self.forward_meta is not None:
self.forward_meta.clear_caches()

def clear_parameters(self, pid):
""""dynamic model loader use to clear parameters use for RL"""
self.dynamic_weight_manager.clear_parameters(pid)
self.clear_cache()
paddle.device.cuda.empty_cache()
self.dynamic_weight_manager._log_memory("dynamic weight manager clear all memory")

def update_parameters(self, pid):
""""dynamic model loader use to update parameters use for RL"""
self.dynamic_weight_manager.update_parameters(pid)
self.initialize_kv_cache()
self.dynamic_weight_manager._log_memory("dynamic weight manager update all memory")

def initialize_kv_cache(self) -> None:
"""
Initialize kv cache
@@ -701,11 +695,10 @@ class GPUModelRunner(ModelRunnerBase):

def initialize_attn_backend(self) -> None:
"""
Initialize attention backends and forward metadata
Initialize attention backends
"""
assert len(self.attn_backends) == 0

# TODO(gongshaotian): Get rank from config
num_heads = self.model_config.num_attention_heads // self.parallel_config.tensor_parallel_degree
self.model_config.kv_num_heads = int(
self.model_config.num_key_value_heads
@@ -718,10 +711,7 @@ class GPUModelRunner(ModelRunnerBase):
kv_num_heads=self.model_config.kv_num_heads,
num_heads=num_heads,
head_dim=head_dim)
if attn_backend is None:
raise NotImplementedError(
"Attention backend which you specified is not supported, please set FD_ATTENTION_BACKEND correctly."
)

self.attn_backends.append(attn_backend)

def _dummy_run(self,
@@ -745,14 +735,12 @@ class GPUModelRunner(ModelRunnerBase):
expected_decode_len=expected_decode_len)
while True:

# 1. Compute real num_tokens
# 1. Initialize forward meta and attention meta data
self._prepare_inputs()

# 2. Initialize attention backend and forward meta data
# 2. Prepare lora

# 3. Prepare lora

# 4. Run model
# 3. Run model
is_decode_batch = not ((self.share_inputs["seq_lens_this_time"]
> 1).sum() > 0)
self.forward_meta.step_use_cudagraph = is_decode_batch and in_capturing
@@ -773,7 +761,7 @@ class GPUModelRunner(ModelRunnerBase):
self.parallel_config.max_model_len,
)

# 5. Execute spec decode
# 4. Execute spec decode
logits = self.model.compute_logits(hiddden_states)

if not self.speculative_decoding:
@@ -805,7 +793,7 @@ class GPUModelRunner(ModelRunnerBase):
paddle.distributed.broadcast(
self.share_inputs["stop_flags"], 0)

# 6. post process
# 5. post process
model_output_data = ModelOutputData(
next_tokens=self.share_inputs["next_tokens"],
stop_flags=self.share_inputs["stop_flags"],
@@ -858,7 +846,7 @@ class GPUModelRunner(ModelRunnerBase):

def _update_chunked_prefill(self, tasks):
"""
更新chunked prefill相关参数
Update chunked prefill related parameters
"""
if not self.parallel_config.enable_chunked_prefill:
return
@@ -903,13 +891,9 @@ class GPUModelRunner(ModelRunnerBase):
self.proposer.update_task_chunk_prefill(task)
task.chunk_idx += 1

def _dummy_sampler_run(self) -> paddle.Tensor:
""" """
pass

def capture_model(self) -> None:
"""
Trigger CUDA Graph capture for all shapes in 'CudaGraphConfig.cudagraph_capture_sizes'
Trigger CUDA Graph capture for all shapes in cuda graph capture list
"""
if not self.use_cudagraph:
logger.info(
@@ -933,7 +917,8 @@ class GPUModelRunner(ModelRunnerBase):
f"Cuda Graph capturing took {time_after_capture - time_before_capture} seconds"
)

def _get_skip_idx(self, model_forward_batch):
def _get_skip_idx(self,
model_forward_batch: Optional[List[Request]] = None):
"""
Get the index of the request that needs to be skipped during execution.
Args:
@@ -972,20 +957,19 @@ class GPUModelRunner(ModelRunnerBase):
We plan to replace it with 'ModelForwardBatch'.
intermediate_tensors:
"""
# Note(@wufeisheng): If `not_need_stop`` is False, it means the current worker is in an idle state.
# NOTE(wufeisheng): If `not_need_stop`` is False, it means the current worker is in an idle state.
# This logic is not used in TP (Tensor Parallelism) mode. However, in EP (Expert Parallelism) mode,
# when there is data on other runner, the current runner is required to execute part of the model.
if not self.not_need_stop():
self._execute_empty_input()
return None

# 1. Prepare inputs of model and decoder.
# sampler create async operation
# 1. Prepare inputs of model and sampler.
skip_idx_list = self._get_skip_idx(model_forward_batch)
self._prepare_inputs()
self.sampler.pre_process(skip_idx_list)

# 2. Padding inputs for cuda grph
# 2. Padding inputs for cuda graph

# 3. Execute model
# TODO(gongshaotian): Use seq_lens_encoder to set is_decode_batch
@@ -1136,7 +1120,7 @@ class GPUModelRunner(ModelRunnerBase):
f"{type(self.model)} has no attribute 'empty_input_forward")

def profile_run(self) -> None:
"""Execute a forward pass with dummy inputs to profile the memory usage of the model."""
""" Execute a forward pass with dummy inputs to profile the memory usage of the model """

# Initialize kv cache for profile run. After profile run kv cache will be reset.
# TODO(gongshaotian): Optimize the management logic of kvcache
@@ -1222,5 +1206,26 @@ class GPUModelRunner(ModelRunnerBase):
return required_memory

def not_need_stop(self) -> bool:
""" """
""" Stop decoding if the tensor meets the termination condition """
return self.share_inputs["not_need_stop"][0]

def clear_cache(self):
""" Clear cached data from shared inputs and forward metadata """
self.share_inputs.pop("caches", None)
if self.forward_meta is not None:
self.forward_meta.clear_caches()

def clear_parameters(self, pid):
"""" Dynamic model loader use to clear parameters use for RL """
self.dynamic_weight_manager.clear_parameters(pid)
self.clear_cache()
paddle.device.cuda.empty_cache()
self.dynamic_weight_manager._log_memory(
"dynamic weight manager clear all memory")

def update_parameters(self, pid):
"""" Dynamic model loader use to update parameters use for RL """
self.dynamic_weight_manager.update_parameters(pid)
self.initialize_kv_cache()
self.dynamic_weight_manager._log_memory(
"dynamic weight manager update all memory")
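The decode-batch gating used in _dummy_run and execute_model above boils down to one check: a step counts as decode-only when no request processes more than one token, and CUDA graph replay is only allowed for such steps. A tiny standalone sketch with made-up lengths (the in_capturing flag is assumed):

# Illustration only: the is_decode_batch / step_use_cudagraph gating above.
import paddle

seq_lens_this_time = paddle.to_tensor([1, 1, 1])        # every request decodes one token
is_decode_batch = not bool((seq_lens_this_time > 1).sum() > 0)
in_capturing = True                                     # assumed capture flag
step_use_cudagraph = is_decode_batch and in_capturing   # True here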
@@ -32,7 +32,6 @@ logger = get_logger("gpu_worker", "gpu_worker.log")
class GpuWorker(WorkerBase):
""" """

def __init__(
self,
@@ -48,7 +47,8 @@ class GpuWorker(WorkerBase):
pass

def init_device(self):
""" Initialize device and Construct model runner
"""
Initialize device and construct model runner
"""
if self.device_config.device_type == "cuda" and paddle.device.is_compiled_with_cuda(
):
@@ -74,10 +74,10 @@ class GpuWorker(WorkerBase):
device_id=self.device_ids[self.local_rank],
rank=self.rank,
local_rank=self.local_rank)

def prefill_finished(self):
"""
check whether prefill stage finished
Check whether prefill stage finished
"""
return self.model_runner.prefill_finished()

@@ -115,7 +115,8 @@ class GpuWorker(WorkerBase):
f"\nDevice used memory: {before_run_meminfo.used / Gb}",
f"\nDevice free memory: {before_run_meminfo.free / Gb}",
f"\nPaddle reserved memory: {paddle_reserved_mem_before_run / Gb}",
f"\nPaddle allocated memory: {paddle_allocated_mem_before_run / Gb}"))
f"\nPaddle allocated memory: {paddle_allocated_mem_before_run / Gb}"
))

# 2. Profile run
self.model_runner.profile_run()
@@ -126,15 +127,6 @@ class GpuWorker(WorkerBase):
paddle_allocated_mem_after_run = paddle.device.cuda.max_memory_allocated(
self.local_rank)

# NOTE(gongshaotian): v1 worker
# not_paddle_use_mem = after_run_meminfo.used - paddle_reserved_mem_after_run
# peak_memory = paddle_allocated_mem_after_run + not_paddle_use_mem
# available_kv_cache_memory = after_run_meminfo.total * \
# self.parallel_config.gpu_memory_utilization - peak_memory

# v0 worker
model_block_memory_used = self.cal_theortical_kvcache()
paddle_peak_increase = paddle_reserved_mem_after_run - paddle_allocated_mem_before_run

@@ -146,32 +138,31 @@ class GpuWorker(WorkerBase):
available_kv_cache_memory = after_run_meminfo.total * \
self.parallel_config.gpu_memory_utilization - after_run_meminfo.used - paddle_peak_increase
available_kv_cache_memory += model_block_memory_used * self.parallel_config.max_block_num

end_time = time.perf_counter()
logger.info(
("After running the profile, the memory usage info is as follows:",
f"\nDevice Total memory: {after_run_meminfo.total / Gb}",
f"\nDevice used memory: {after_run_meminfo.used / Gb}",
f"\nDevice free memory: {after_run_meminfo.free / Gb}",
f"\nPaddle reserved memory: {paddle_reserved_mem_after_run / Gb}",
f"\nPaddle allocated memory: {paddle_allocated_mem_after_run / Gb}",
f"\nAvailable KV Cache meomory: {available_kv_cache_memory / Gb}",
f"Profile time: {end_time - start_time}"))
logger.info((
"After running the profile, the memory usage info is as follows:",
f"\nDevice Total memory: {after_run_meminfo.total / Gb}",
f"\nDevice used memory: {after_run_meminfo.used / Gb}",
f"\nDevice free memory: {after_run_meminfo.free / Gb}",
f"\nPaddle reserved memory: {paddle_reserved_mem_after_run / Gb}",
f"\nPaddle allocated memory: {paddle_allocated_mem_after_run / Gb}",
f"\nAvailable KV Cache meomory: {available_kv_cache_memory / Gb}",
f"Profile time: {end_time - start_time}"))

return available_kv_cache_memory # return to caculate the block num in this device
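To make the KV-cache memory formula above concrete, here is the same arithmetic with invented numbers; only the structure of the computation mirrors the worker code, and the values plus the "profile-run blocks are freed afterwards" reading are illustrative assumptions.

# Illustration only: the available-KV-cache-memory formula with dummy numbers.
GB = 1024 ** 3

total_device_mem = 80 * GB        # after_run_meminfo.total
used_after_profile = 30 * GB      # after_run_meminfo.used
paddle_peak_increase = 2 * GB     # reserved_after_run - allocated_before_run
gpu_memory_utilization = 0.9      # parallel_config.gpu_memory_utilization
block_mem = 2 * 1024 * 1024       # cal_theortical_kvcache(), bytes per block (assumed)
max_block_num = 1000              # parallel_config.max_block_num (assumed)

available_kv_cache_memory = (total_device_mem * gpu_memory_utilization
                             - used_after_profile - paddle_peak_increase)
# The KV blocks used during the profile run are assumed to be freed afterwards,
# so their memory is added back before the final block count is derived.
available_kv_cache_memory += block_mem * max_block_num
print(available_kv_cache_memory / GB)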
def load_model(self) -> None:
""" """
""" Load model """
self.model_runner.load_model()

def get_model(self) -> nn.Layer:
""" """
""" Get current model """
return self.model_runner.get_model()

def initialize_cache(self, num_gpu_blocks: int,
num_cpu_blocks: int) -> None:
""" """
""" Initizlize the KV Cache """
pass

def execute_model(
@@ -193,10 +184,7 @@ class GpuWorker(WorkerBase):
"""
Perform the warm-up and the graph optimization
"""
# 1. Warm up model
# NOTE(gongshaotian): may be not need warm_up at this place

# 2. Triger cuda grpah capture
# Triger cuda grpah capture
self.model_runner.capture_model()

def check_health(self) -> bool:
@@ -204,10 +192,10 @@ class GpuWorker(WorkerBase):
return True

def cal_theortical_kvcache(self) -> int:
""" """
""" Calculate the block memory required """
return self.model_runner.cal_theortical_kvcache()

def reinitialize_kv_cache(self, num_gpu_blocks: int) -> None:
""" """
""" Reinitialize the kv cache using the parameters from the profile """
self.model_runner.update_share_input_block_num(
num_gpu_blocks=num_gpu_blocks)
@@ -593,8 +593,23 @@ class IluvatarModelRunner(ModelRunnerBase):
Initialize forward meta and attention meta data
"""
# Initialize forward meta
self.forward_meta = ForwardMeta.init_forward_meta(
self.share_inputs, self.attn_backends[0])
self.forward_meta = ForwardMeta(
input_ids=self.share_inputs["input_ids"],
ids_remove_padding=self.share_inputs["ids_remove_padding"],
rotary_embs=self.share_inputs["rope_emb"],
attn_backend=self.attn_backends[0],
decoder_batch_ids=self.share_inputs["decoder_batch_ids"],
decoder_tile_ids_per_batch=self.share_inputs["decoder_tile_ids_per_batch"],
seq_lens_encoder=self.share_inputs["seq_lens_encoder"],
seq_lens_decoder=self.share_inputs["seq_lens_decoder"],
seq_lens_this_time=self.share_inputs["seq_lens_this_time"],
cum_offsets=self.share_inputs["cum_offsets"],
padding_offset=self.share_inputs["padding_offset"],
cu_seqlens_q=self.share_inputs["cu_seqlens_q"],
cu_seqlens_k=self.share_inputs["cu_seqlens_k"],
block_tables=self.share_inputs["block_tables"],
caches=self.share_inputs["caches"]
)

# Initialzie attention meta data
for attn_backend in self.attn_backends:
@@ -816,9 +816,23 @@ class GPUVLModelRunner(VLModelRunnerBase):
self.share_inputs["decoder_tile_ids_per_batch"] = paddle.full(
[self.fd_config.parallel_config.max_num_seqs, 1], 0, dtype='int32')
# initialize_forward_meta
self.forward_meta = ForwardMeta.init_forward_meta(
self.share_inputs, self.attn_backend)

self.forward_meta = ForwardMeta(
input_ids=self.share_inputs["input_ids"],
ids_remove_padding=self.share_inputs["ids_remove_padding"],
rotary_embs=self.share_inputs["rope_emb"],
attn_backend=self.attn_backend,
decoder_batch_ids=self.share_inputs["decoder_batch_ids"],
decoder_tile_ids_per_batch=self.share_inputs["decoder_tile_ids_per_batch"],
seq_lens_encoder=self.share_inputs["seq_lens_encoder"],
seq_lens_decoder=self.share_inputs["seq_lens_decoder"],
seq_lens_this_time=self.share_inputs["seq_lens_this_time"],
cum_offsets=self.share_inputs["cum_offsets"],
padding_offset=self.share_inputs["padding_offset"],
cu_seqlens_q=self.share_inputs["cu_seqlens_q"],
cu_seqlens_k=self.share_inputs["cu_seqlens_k"],
block_tables=self.share_inputs["block_tables"],
caches=self.share_inputs["caches"]
)
self.attn_backend.init_attention_metadata(self.forward_meta)

self.sampling_metadata = SamplingMetadata(
@@ -70,7 +70,21 @@ def xpu_pre_process(
share_inputs["cu_seqlens_q"] = cu_seqlens_q
share_inputs["cu_seqlens_k"] = cu_seqlens_k

xpu_forward_meta = XPUForwardMeta.init_forward_meta(share_inputs, None)
xpu_forward_meta = XPUForwardMeta(
input_ids=share_inputs["input_ids"],
ids_remove_padding=share_inputs["ids_remove_padding"],
rotary_embs=share_inputs["rope_emb"],
attn_backend=None,
seq_lens_encoder=share_inputs["seq_lens_encoder"],
seq_lens_decoder=share_inputs["seq_lens_decoder"],
seq_lens_this_time=share_inputs["seq_lens_this_time"],
cum_offsets=share_inputs["cum_offsets"],
padding_offset=share_inputs["padding_offset"],
cu_seqlens_q=share_inputs["cu_seqlens_q"],
cu_seqlens_k=share_inputs["cu_seqlens_k"],
block_tables=share_inputs["block_tables"],
caches=share_inputs["caches"]
)

# Get xpu extra param
(
@@ -21,8 +21,7 @@ import paddle
from fastdeploy.model_executor.layers.attention import (
Attention, PaddleNativeAttnBackend)
from fastdeploy.worker.forward_meta import (ForwardMeta, ForwardMode,
MHATokenToKVPool)
from fastdeploy.worker.forward_meta import ForwardMeta, ForwardMode

class MockModelRunner:
@@ -63,15 +62,6 @@ class MockModelRunner:
},
)
self.page_size = page_size
max_total_num_tokens = max_batch_size * max_context_len
self.token_to_kv_pool = MHATokenToKVPool(
size=max_total_num_tokens,
page_size=page_size,
dtype=self.dtype,
head_num=num_heads,
head_dim=head_dim,
layer_num=1, # only consider layer=1 for unit test
device=self.device)

class TestNativePaddleAttentionBackend(unittest.TestCase):