Clear dead code And supplementary notes (#2757)
Some checks failed
Deploy GitHub Pages / deploy (push) Has been cancelled

* 1.supplementary notes 2.delete dead code

* fix bug of forward meta

* Global modification of forward meta

* fix vl model_runner bug
This commit is contained in:
RAM
2025-07-09 16:17:34 +08:00
committed by GitHub
parent b89180f1cd
commit 03a74995b8
12 changed files with 248 additions and 463 deletions

View File

@@ -46,13 +46,9 @@ class ConcreteSizeEntry:
# Output buffer of cudagraph
output_buffer: Optional[paddle.Tensor] = None
# for cudagraph debugging, track the input addresses
# during capture, and check if they are the same during replay
input_addresses: Optional[list[int]] = None
class CudaGraphPiecewiseBackend:
""" """
""" Manage the capture and replay of CUDA graphs at the subgraph level. """
def __init__(
self,
@@ -65,33 +61,31 @@ class CudaGraphPiecewiseBackend:
self.warm_up_size = fd_config.graph_opt_config.cudagraph_num_of_warmups
self.batch_size_to_captured_size = fd_config.graph_opt_config.batch_size_to_captured_size
# runtime_bs -> ConcreteSizeEntry
# Runtime batch size -> ConcreteSizeEntry
self.concrete_size_entries: Dict[int, ConcreteSizeEntry] = {}
for shape in self.cudagraph_capture_sizes:
self.concrete_size_entries[shape] = ConcreteSizeEntry(
runtime_bs=shape)
print("[CUDA GRAPH] Created all batch size entry ")
logger.debug("[CUDA GRAPH] Created all batch size entry ")
def __call__(self, **kwargs):
# Get batch size
ids_remove_padding: paddle.Tensor = kwargs["ids_remove_padding"]
batch_size = ids_remove_padding.shape[0]
padding_batch_size = self.batch_size_to_captured_size[batch_size]
# print(
# f"[CUDA GRAPH] The actual batch size obtained by CUDAGraph is :{batch_size}, ",
# f"The padded batch size is :{padding_batch_size}"
# )
logger.debug(
f"[CUDA GRAPH] The actual batch size obtained by CUDAGraph is :{batch_size}, ",
f"The padded batch size is :{padding_batch_size}")
entry = self.concrete_size_entries.get(padding_batch_size)
assert entry is not None, f"Batch size:{padding_batch_size} is not in cuda graph capture list."
if entry.runnable is None:
entry.runnable = self.runnable
# print(
# f"[CUDA GRAPH] New entry lazy initialize with batch size {padding_batch_size}"
# )
logger.debug(
f"[CUDA GRAPH] New entry lazy initialize with batch size {padding_batch_size}"
)
if not entry.use_cudagraph:
return entry.runnable(**kwargs)
@@ -102,10 +96,10 @@ class CudaGraphPiecewiseBackend:
for n in range(entry.num_finished_warmup, self.warm_up_size):
entry.num_finished_warmup += 1
entry.runnable(**kwargs)
# print(
# "[CUDA GRAPH] Warm up for batch size ",
# f"{padding_batch_size}, finished ({n+1}/{entry.num_finished_warmup}) times"
# )
logger.debug(
"[CUDA GRAPH] Warm up for batch size ",
f"{padding_batch_size}, finished ({n+1}/{entry.num_finished_warmup}) times"
)
# Store input addresses for debug
input_addresses = [
@@ -129,11 +123,13 @@ class CudaGraphPiecewiseBackend:
output._clear
paddle.device.synchronize()
# print(
# f"[CUDA GRAPH] CUDAGraph captured for batch size {padding_batch_size}"
# )
logger.debug(
f"[CUDA GRAPH] CUDAGraph captured for batch size {padding_batch_size}"
)
# Replay
entry.cuda_graph.replay()
# print(f"[CUDA GRAPH] CUDAGraph replayed for batch size {padding_batch_size}")
logger.debug(
f"[CUDA GRAPH] CUDAGraph replayed for batch size {padding_batch_size}"
)
return entry.output_buffer

View File

@@ -28,7 +28,7 @@ _T = TypeVar("_T", bound=type[paddle.nn.Layer])
def support_graph_optimization(cls: Optional[_T] = None) -> _T:
"""
A decorator for wrapping models or layers with CUDA graph support.
A decorator for wrapping models or layers with static graph and CUDAGraph support.
This enables efficient kernel launch sequencing for improved GPU performance.
Example usage:
@@ -74,7 +74,7 @@ def support_graph_optimization(cls: Optional[_T] = None) -> _T:
class GraphOptWrapper:
""" """
""" The wrapper for GraphOptBackend """
def __init__(
self,
@@ -87,7 +87,7 @@ class GraphOptWrapper:
@abstractmethod
def forward(self, **kwargs):
""" """
""" Abstract methods for implementing model.forward() """
pass
def __call__(self, **kwargs):

View File

@@ -24,7 +24,10 @@ from fastdeploy.model_executor.graph_optimization.cudagraph_piecewise_backend im
class GraphOptBackend:
""" """
"""
Integrated various graph optimization functions, including dynamic graph to static graph conversion,
CINN compilation optimization, CudaGraph, and so on.
"""
fd_config: FDConfig
cudagraph_piecewise_backend: Optional[CudaGraphPiecewiseBackend] = None

View File

@@ -436,8 +436,24 @@ class MTPProposer(Proposer):
Initialize forward meta and attention meta data
"""
# Initialize forward meta
self.forward_meta = ForwardMeta.init_forward_meta(
self.model_inputs, self.attn_backends[0])
self.forward_meta = ForwardMeta(
input_ids=self.model_inputs["input_ids"],
ids_remove_padding=self.model_inputs["ids_remove_padding"],
rotary_embs=self.model_inputs["rope_emb"],
attn_backend=self.attn_backends[0],
decoder_batch_ids=self.model_inputs["decoder_batch_ids"],
decoder_tile_ids_per_batch=self.model_inputs["decoder_tile_ids_per_batch"],
seq_lens_encoder=self.model_inputs["seq_lens_encoder"],
seq_lens_decoder=self.model_inputs["seq_lens_decoder"],
seq_lens_this_time=self.model_inputs["seq_lens_this_time"],
cum_offsets=self.model_inputs["cum_offsets"],
padding_offset=self.model_inputs["padding_offset"],
cu_seqlens_q=self.model_inputs["cu_seqlens_q"],
cu_seqlens_k=self.model_inputs["cu_seqlens_k"],
block_tables=self.model_inputs["block_tables"],
caches=self.model_inputs["caches"]
)
# Initialzie attention meta data
for attn_backend in self.attn_backends:

View File

@@ -14,18 +14,15 @@
# limitations under the License.
"""
import abc
import logging
from dataclasses import dataclass
from enum import IntEnum, auto
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
from typing import TYPE_CHECKING, Optional
import numpy as np
import paddle
if TYPE_CHECKING:
from fastdeploy.model_executor.layers.attention import (Attention,
AttentionBackend)
from fastdeploy.model_executor.layers.attention import AttentionBackend
logger = logging.getLogger(__name__)
@@ -34,333 +31,79 @@ class ForwardMode(IntEnum):
"""
Forward mode used during attention.
"""
# for prefill and extend
# Prefill and Extend mode
EXTEND = auto()
# for generation
# Decode mode
DECODE = auto()
# Mixed mode
MIXED = auto()
def is_prefill(self):
"""Whether it's a prefill forward"""
""" Is Extend mode """
return self == ForwardMode.EXTEND
def is_decode(self):
"""Whether it's a decode forward"""
""" Is Decode mode """
return self == ForwardMode.DECODE
def is_mixed(self):
"""Whether it's a decode forward"""
""" Is Mixed mode """
return self == ForwardMode.MIXED
class ReqToTokenPool:
"""A memory pool that maps a request to its token locations."""
def __init__(self, size: int, max_context_len: int):
self.size = size
self.max_context_len = max_context_len
self.req_to_token = paddle.zeros((size, max_context_len),
dtype=paddle.int32)
self.free_slots = list(range(size))
def write(self, indices, values):
"""Write data into request buffer"""
self.req_to_token[indices] = values
def available_size(self):
"""Get number of slots left"""
return len(self.free_slots)
def alloc(self, need_size: int) -> List[int]:
"""Allocate `need_size` slots"""
if need_size > len(self.free_slots):
return None
select_index = self.free_slots[:need_size]
self.free_slots = self.free_slots[need_size:]
return select_index
def free(self, free_index: Union[int, List[int]]):
"""Free slot"""
if isinstance(free_index, (int, )):
self.free_slots.append(free_index)
else:
self.free_slots.extend(free_index)
def clear(self):
"""Clear all slots"""
self.free_slots = list(range(self.size))
class KVCache(abc.ABC):
"""Abstract base class representing a key value cache"""
@abc.abstractmethod
def get_kv_buffer(self,
layer_id: int) -> Tuple[paddle.Tensor, paddle.Tensor]:
"""
Return cached keys and values given layer id.
Args:
layer_id: int
Returns:
tuple: (keys, values)
"""
raise NotImplementedError()
@abc.abstractmethod
def set_kv_buffer(
self,
layer: 'Attention',
loc: paddle.Tensor,
cache_k: paddle.Tensor,
cache_v: paddle.Tensor,
) -> None:
"""
Set cached keys and values given layer id.
Args:
layer: Attention
loc: paddle.Tensor
cache_k: paddle.Tensor
cache_v: paddle.Tensor
"""
raise NotImplementedError()
@abc.abstractmethod
def transfer(self, indices, flat_data):
"""Transfer kv_data between devices"""
raise NotImplementedError()
@abc.abstractmethod
def transfer_per_layer(self, indices, flat_data, layer_id):
"""Not used yet"""
raise NotImplementedError()
def register_layer_transfer_counter(self, layer_transfer_counter):
"""Not used yet"""
self.layer_transfer_counter = layer_transfer_counter
class MHATokenToKVPool(KVCache):
"""Token To Key Value Pool for MultiHeadAttention"""
def __init__(
self,
max_block_num: int,
block_size: int,
dtype: paddle.dtype,
head_num: int,
head_dim: int,
layer_num: int,
device: str,
):
self.max_block_num = max_block_num
self.block_size = block_size
self.dtype = dtype
self.device = device
if dtype in (paddle.int8, paddle.float8_e4m3fn):
# NOTE: Store as torch.uint8 because Tensor.index_put is not implemented for torch.float8_e5m2
self.store_dtype = paddle.uint8
else:
self.store_dtype = dtype
self.head_num = head_num
self.head_dim = head_dim
self.layer_num = layer_num
self._create_buffers()
k_size, v_size = self.get_kv_size_bytes()
GB = 1024 * 1024 * 1024
logger.info(
f"KV Cache is allocated. #tokens: {self.size}, K size: {k_size / GB:.2f} GB, V size: {v_size / GB:.2f} GB"
)
def _create_buffers(self):
# [size, head_num, head_dim] for each layer
# The padded slot 0 is used for writing dummy outputs from padded tokens.
self.k_buffer = [
paddle.zeros(
(self.max_block_num, self.head_num, self.block_size,
self.head_dim),
dtype=self.store_dtype,
) for _ in range(self.layer_num)
]
self.v_buffer = [
paddle.zeros(
(self.max_block_num, self.head_num, self.block_size,
self.head_dim),
dtype=self.store_dtype,
) for _ in range(self.layer_num)
]
def _clear_buffers(self):
del self.k_buffer
del self.v_buffer
def get_kv_size_bytes(self):
"""for debugging purpose"""
assert hasattr(self, "k_buffer")
assert hasattr(self, "v_buffer")
k_size_bytes = 0
for k_cache in self.k_buffer:
k_size_bytes += np.prod(k_cache.shape) * 4
v_size_bytes = 0
for v_cache in self.v_buffer:
v_size_bytes += np.prod(v_cache.shape) * 4
return k_size_bytes, v_size_bytes
def transfer(self, indices, flat_data):
# transfer prepared data from host to device
flat_data = flat_data.to(device=self.device, non_blocking=False)
k_data, v_data = flat_data[0], flat_data[1]
for i in range(self.layer_num):
self.k_buffer[i][indices] = k_data[i]
self.v_buffer[i][indices] = v_data[i]
def transfer_per_layer(self, indices, flat_data, layer_id):
# transfer prepared data for a specific layer from host to device
flat_data = flat_data.to(device=self.device, non_blocking=False)
k_data, v_data = flat_data[0], flat_data[1]
self.k_buffer[layer_id][indices] = k_data
self.v_buffer[layer_id][indices] = v_data
def get_key_buffer(self, layer_id: int):
"""Return cached keys given layer id."""
if self.store_dtype != self.dtype:
return self.k_buffer[layer_id].view(self.dtype)
return self.k_buffer[layer_id]
def get_value_buffer(self, layer_id: int):
"""Return cached values given layer id."""
if self.store_dtype != self.dtype:
return self.v_buffer[layer_id].view(self.dtype)
return self.v_buffer[layer_id]
def get_kv_buffer(self, layer_id: int):
"""Return cached keys and values given layer id."""
return self.get_key_buffer(layer_id), self.get_value_buffer(layer_id)
def set_kv_buffer(
self,
layer: 'Attention',
loc: paddle.Tensor,
cache_k: paddle.Tensor,
cache_v: paddle.Tensor,
k_scale: Optional[float] = None,
v_scale: Optional[float] = None,
):
"""Set cached keys and values given layer id."""
layer_id = layer.layer_id
if cache_k.dtype != self.dtype:
if k_scale is not None:
cache_k.div_(k_scale)
if v_scale is not None:
cache_v.div_(v_scale)
cache_k = cache_k.to(self.dtype)
cache_v = cache_v.to(self.dtype)
if self.store_dtype != self.dtype:
cache_k = cache_k.view(self.store_dtype)
cache_v = cache_v.view(self.store_dtype)
self.k_buffer[layer_id][loc] = cache_k
self.v_buffer[layer_id][loc] = cache_v
@dataclass
class ForwardMeta():
"""
ForwardMeta is used to store the global meta information of the forward.
ForwardMeta is used to store the global meta information of the model forward.
"""
#
# Input tokens IDs
input_ids: paddle.Tensor
#attention meta
forward_mode: ForwardMode = ForwardMode.MIXED
#
ids_remove_padding: paddle.Tensor = None
#
seq_lens_encoder: Optional[paddle.Tensor] = None
#
seq_lens_decoder: Optional[paddle.Tensor] = None
#
seq_lens_this_time: Optional[paddle.Tensor] = None
#
cum_offsets: Optional[paddle.Tensor] = None
#
block_tables: Optional[paddle.Tensor] = None
#
attn_backend: 'AttentionBackend' = None
#
# Input tokens IDs of removed padding
ids_remove_padding: paddle.Tensor
# Rotation position embedding
rotary_embs: Optional[paddle.Tensor] = None
#
padding_offset: Optional[paddle.Tensor] = None
#
cu_seqlens_q: Optional[paddle.Tensor] = None
#
cu_seqlens_k: Optional[paddle.Tensor] = None
#
caches: Optional[paddle.Tensor] = None
#
attn_mask: Optional[paddle.Tensor] = None
#
pre_caches_length: int = 0
# Use cuda graph in this step. Used to avoid run cuda graph when in dummy run or prefill stage.
# Use cuda graph in this step or not. Used to avoid run cuda graph when in dummy run or prefill stage.
step_use_cudagraph: bool = False
# for attention backend
decoder_batch_ids: Optional[paddle.Tensor] = None
# for attention backend
decoder_tile_ids_per_batch: Optional[paddle.Tensor] = None
# is_decode_batch or not
# Batch type flag
is_decode_batch: bool = False
@classmethod
def init_forward_meta(cls, share_inputs: Dict,
attn_backend: "AttentionBackend"):
""" init forward meta """
# TODO(gongshaotian): delete this func
ret = cls(
forward_mode=ForwardMode.MIXED,
input_ids=share_inputs["input_ids"],
ids_remove_padding=share_inputs["ids_remove_padding"],
seq_lens_encoder=share_inputs["seq_lens_encoder"],
seq_lens_decoder=share_inputs["seq_lens_decoder"],
seq_lens_this_time=share_inputs["seq_lens_this_time"],
cum_offsets=share_inputs["cum_offsets"],
block_tables=share_inputs["block_tables"],
attn_backend=attn_backend,
rotary_embs=share_inputs["rope_emb"],
padding_offset=share_inputs["padding_offset"],
cu_seqlens_q=share_inputs["cu_seqlens_q"],
cu_seqlens_k=share_inputs["cu_seqlens_k"],
caches=share_inputs["caches"],
decoder_batch_ids=share_inputs.get("decoder_batch_ids", None),
decoder_tile_ids_per_batch=share_inputs.get(
"decoder_tile_ids_per_batch", None),
)
return ret
# Attention backend object
attn_backend: 'AttentionBackend' = None
# Forward mode used during attention
forward_mode: ForwardMode = ForwardMode.MIXED
# Attention mask
attn_mask: Optional[paddle.Tensor] = None
# Decoder batch id. Used by attention backend.
decoder_batch_ids: Optional[paddle.Tensor] = None
# Tile ID for each batch of the decoder. Used by attention backend.
decoder_tile_ids_per_batch: Optional[paddle.Tensor] = None
# Sequence length of encoder for ever batch
seq_lens_encoder: Optional[paddle.Tensor] = None
# Sequence length of Encoder for ever batch
seq_lens_decoder: Optional[paddle.Tensor] = None
# The sequence length processed in the current step
seq_lens_this_time: Optional[paddle.Tensor] = None
# Accumulated offset
cum_offsets: Optional[paddle.Tensor] = None
# Offset tensor, used to restore the position of ids_remove_madding after padding removal to the original input_ids
padding_offset: Optional[paddle.Tensor] = None
# Accumulated sequence length of query
cu_seqlens_q: Optional[paddle.Tensor] = None
# Accumulated sequence length of key
cu_seqlens_k: Optional[paddle.Tensor] = None
# Pre-cache length
pre_caches_length: int = 0
# Block tables
block_tables: Optional[paddle.Tensor] = None
# KV caches
caches: Optional[paddle.Tensor] = None
def clear_caches(self):
"""safe clear caches"""
""" Safely clean up the caches """
if self.caches:
del self.caches
@@ -370,56 +113,42 @@ class XPUForwardMeta(ForwardMeta):
"""
XPUForwardMeta is used to store the global meta information of the forward, and some XPU specific meta info.
"""
# TODO(wanghaitao): Supplementary notes
#
encoder_batch_map: Optional[paddle.Tensor] = None
#
decoder_batch_map: Optional[paddle.Tensor] = None
#
encoder_batch_idx: Optional[paddle.Tensor] = None
#
decoder_batch_idx: Optional[paddle.Tensor] = None
#
encoder_seq_lod: Optional[paddle.Tensor] = None
#
decoder_context_len: Optional[paddle.Tensor] = None
#
decoder_context_len_cache: Optional[paddle.Tensor] = None
#
encoder_batch_map_cpu: Optional[paddle.Tensor] = None
#
decoder_batch_map_cpu: Optional[paddle.Tensor] = None
#
encoder_batch_idx_cpu: Optional[paddle.Tensor] = None
#
decoder_batch_idx_cpu: Optional[paddle.Tensor] = None
#
encoder_seq_lod_cpu: Optional[paddle.Tensor] = None
#
decoder_context_len_cpu: Optional[paddle.Tensor] = None
#
decoder_context_len_cache_cpu: Optional[paddle.Tensor] = None
#
batch_tensor: Optional[paddle.Tensor] = None
#
enc_batch: Optional[paddle.Tensor] = None
#
dec_batch: Optional[paddle.Tensor] = None
#
total_enc_len: Optional[paddle.Tensor] = None

View File

@@ -606,8 +606,23 @@ class GCUModelRunner(ModelRunnerBase):
Initialize forward meta and attention meta data
"""
# Initialize forward meta
self.forward_meta = ForwardMeta.init_forward_meta(
self.share_inputs, self.attn_backends[0])
self.forward_meta = ForwardMeta(
input_ids=self.share_inputs["input_ids"],
ids_remove_padding=self.share_inputs["ids_remove_padding"],
rotary_embs=self.share_inputs["rope_emb"],
attn_backend=self.attn_backends[0],
decoder_batch_ids=self.share_inputs["decoder_batch_ids"],
decoder_tile_ids_per_batch=self.share_inputs["decoder_tile_ids_per_batch"],
seq_lens_encoder=self.share_inputs["seq_lens_encoder"],
seq_lens_decoder=self.share_inputs["seq_lens_decoder"],
seq_lens_this_time=self.share_inputs["seq_lens_this_time"],
cum_offsets=self.share_inputs["cum_offsets"],
padding_offset=self.share_inputs["padding_offset"],
cu_seqlens_q=self.share_inputs["cu_seqlens_q"],
cu_seqlens_k=self.share_inputs["cu_seqlens_k"],
block_tables=self.share_inputs["block_tables"],
caches=self.share_inputs["caches"]
)
# Initialzie attention meta data
for attn_backend in self.attn_backends:

View File

@@ -48,7 +48,6 @@ from fastdeploy.worker.output import ModelOutputData, ModelRunnerOutput
class GPUModelRunner(ModelRunnerBase):
""" """
def __init__(
self,
@@ -81,9 +80,6 @@ class GPUModelRunner(ModelRunnerBase):
self.use_cudagraph = self.graph_opt_config.use_cudagraph
self.cudagraph_capture_sizes = list(
reversed(self.graph_opt_config.cudagraph_capture_sizes))
self.cudagraph_num_of_warmups = self.graph_opt_config.cudagraph_num_of_warmups
self.input_ids = paddle.zeros(self.parallel_config.max_num_seqs,
dtype='int32')
# Initialize share inputs
self._init_share_inputs(self.parallel_config.max_num_seqs)
@@ -94,7 +90,7 @@ class GPUModelRunner(ModelRunnerBase):
self.restore_chunked_prefill_request = dict()
# Initialize attention Backend
# Note(gonshaotian): Currently, all attention layers share one attention backend instance.
# NOTE(gonshaotian): Currently, all attention layers share one attention backend instance.
# In the future, we will expand it as a list.
self.attn_backends: list[AttentionBackend] = []
# self.attn_metadatas: list[AttentionMetadata] = []
@@ -110,14 +106,14 @@ class GPUModelRunner(ModelRunnerBase):
def prefill_finished(self):
"""
check whether prefill stage finished
Check whether prefill stage finished
"""
if int(paddle.max(self.share_inputs['seq_lens_encoder'])) != 0:
return 1
else:
return 0
def init_speculative_proposer(self):
def _init_speculative_proposer(self):
"""
Init speculative proposer
"""
@@ -333,8 +329,8 @@ class GPUModelRunner(ModelRunnerBase):
(idx + 1) * block_num, 1)
def _init_share_inputs(self, max_num_seqs: int):
"""Initialize all share buffers for model inputs.
Note: In the future, we may abandon share buffers.
"""
Initialize all share buffers for model inputs.
"""
self.MAX_INFER_SEED = 9223372036854775806
self.share_inputs = {}
@@ -469,6 +465,7 @@ class GPUModelRunner(ModelRunnerBase):
# Initialize rotary position embedding
tmp_position_ids = paddle.arange(
self.parallel_config.max_model_len).reshape((1, -1))
# TODO(gongshaotian): move to models
self.share_inputs["rope_emb"] = get_rope(
rotary_dim=self.model_config.head_dim,
@@ -536,7 +533,7 @@ class GPUModelRunner(ModelRunnerBase):
dtype="int32")
def _prepare_inputs(self) -> None:
""" prepare the model inputs """
""" Prepare the model inputs """
# Remove padding
(
ids_remove_padding,
@@ -595,7 +592,8 @@ class GPUModelRunner(ModelRunnerBase):
if self.fd_config.load_config.dynamic_load_weight:
from fastdeploy.rl.dynamic_weight_manager import \
DynamicWeightManager
self.dynamic_weight_manager = DynamicWeightManager(self.fd_config, self.model)
self.dynamic_weight_manager = DynamicWeightManager(
self.fd_config, self.model)
# 2. Load lora model
@@ -606,10 +604,10 @@ class GPUModelRunner(ModelRunnerBase):
f"Model loading took {time_after_load - time_before_load} seconds")
# 4. Init proposer for speculative method
self.init_speculative_proposer()
self._init_speculative_proposer()
def get_model(self) -> nn.Layer:
""" get current model """
""" Get current model """
return self.model
def initialize_forward_meta(self):
@@ -617,32 +615,28 @@ class GPUModelRunner(ModelRunnerBase):
Initialize forward meta and attention meta data
"""
# Initialize forward meta
self.forward_meta = ForwardMeta.init_forward_meta(
self.share_inputs, self.attn_backends[0])
self.forward_meta = ForwardMeta(
input_ids=self.share_inputs["input_ids"],
ids_remove_padding=self.share_inputs["ids_remove_padding"],
rotary_embs=self.share_inputs["rope_emb"],
attn_backend=self.attn_backends[0],
decoder_batch_ids=self.share_inputs["decoder_batch_ids"],
decoder_tile_ids_per_batch=self.share_inputs["decoder_tile_ids_per_batch"],
seq_lens_encoder=self.share_inputs["seq_lens_encoder"],
seq_lens_decoder=self.share_inputs["seq_lens_decoder"],
seq_lens_this_time=self.share_inputs["seq_lens_this_time"],
cum_offsets=self.share_inputs["cum_offsets"],
padding_offset=self.share_inputs["padding_offset"],
cu_seqlens_q=self.share_inputs["cu_seqlens_q"],
cu_seqlens_k=self.share_inputs["cu_seqlens_k"],
block_tables=self.share_inputs["block_tables"],
caches=self.share_inputs["caches"]
)
# Initialzie attention meta data
for attn_backend in self.attn_backends:
attn_backend.init_attention_metadata(self.forward_meta)
def clear_cache(self):
"""Clear cached data from shared inputs and forward metadata."""
self.share_inputs.pop("caches", None)
if self.forward_meta is not None:
self.forward_meta.clear_caches()
def clear_parameters(self, pid):
""""dynamic model loader use to clear parameters use for RL"""
self.dynamic_weight_manager.clear_parameters(pid)
self.clear_cache()
paddle.device.cuda.empty_cache()
self.dynamic_weight_manager._log_memory("dynamic weight manager clear all memory")
def update_parameters(self, pid):
""""dynamic model loader use to update parameters use for RL"""
self.dynamic_weight_manager.update_parameters(pid)
self.initialize_kv_cache()
self.dynamic_weight_manager._log_memory("dynamic weight manager update all memory")
def initialize_kv_cache(self) -> None:
"""
Initialize kv cache
@@ -701,11 +695,10 @@ class GPUModelRunner(ModelRunnerBase):
def initialize_attn_backend(self) -> None:
"""
Initialize attention backends and forward metadata
Initialize attention backends
"""
assert len(self.attn_backends) == 0
# TODO(gongshaotian): Get rank from config
num_heads = self.model_config.num_attention_heads // self.parallel_config.tensor_parallel_degree
self.model_config.kv_num_heads = int(
self.model_config.num_key_value_heads
@@ -718,10 +711,7 @@ class GPUModelRunner(ModelRunnerBase):
kv_num_heads=self.model_config.kv_num_heads,
num_heads=num_heads,
head_dim=head_dim)
if attn_backend is None:
raise NotImplementedError(
"Attention backend which you specified is not supported, please set FD_ATTENTION_BACKEND correctly."
)
self.attn_backends.append(attn_backend)
def _dummy_run(self,
@@ -745,14 +735,12 @@ class GPUModelRunner(ModelRunnerBase):
expected_decode_len=expected_decode_len)
while True:
# 1. Compute real num_tokens
# 1. Initialize forward meta and attention meta data
self._prepare_inputs()
# 2. Initialize attention backend and forward meta data
# 2. Prepare lora
# 3. Prepare lora
# 4. Run model
# 3. Run model
is_decode_batch = not ((self.share_inputs["seq_lens_this_time"]
> 1).sum() > 0)
self.forward_meta.step_use_cudagraph = is_decode_batch and in_capturing
@@ -773,7 +761,7 @@ class GPUModelRunner(ModelRunnerBase):
self.parallel_config.max_model_len,
)
# 5. Execute spec decode
# 4. Execute spec decode
logits = self.model.compute_logits(hiddden_states)
if not self.speculative_decoding:
@@ -805,7 +793,7 @@ class GPUModelRunner(ModelRunnerBase):
paddle.distributed.broadcast(
self.share_inputs["stop_flags"], 0)
# 6. post process
# 5. post process
model_output_data = ModelOutputData(
next_tokens=self.share_inputs["next_tokens"],
stop_flags=self.share_inputs["stop_flags"],
@@ -858,7 +846,7 @@ class GPUModelRunner(ModelRunnerBase):
def _update_chunked_prefill(self, tasks):
"""
更新chunked prefill相关参数
Update chunked prefill related parameters
"""
if not self.parallel_config.enable_chunked_prefill:
return
@@ -903,13 +891,9 @@ class GPUModelRunner(ModelRunnerBase):
self.proposer.update_task_chunk_prefill(task)
task.chunk_idx += 1
def _dummy_sampler_run(self) -> paddle.Tensor:
""" """
pass
def capture_model(self) -> None:
"""
Trigger CUDA Graph capture for all shapes in 'CudaGraphConfig.cudagraph_capture_sizes'
Trigger CUDA Graph capture for all shapes in cuda graph capture list
"""
if not self.use_cudagraph:
logger.info(
@@ -933,7 +917,8 @@ class GPUModelRunner(ModelRunnerBase):
f"Cuda Graph capturing took {time_after_capture - time_before_capture} seconds"
)
def _get_skip_idx(self, model_forward_batch):
def _get_skip_idx(self,
model_forward_batch: Optional[List[Request]] = None):
"""
Get the index of the request that needs to be skipped during execution.
Args:
@@ -972,20 +957,19 @@ class GPUModelRunner(ModelRunnerBase):
We plan to replace it with 'ModelForwardBatch'.
intermediate_tensors:
"""
# Note(@wufeisheng): If `not_need_stop`` is False, it means the current worker is in an idle state.
# NOTE(wufeisheng): If `not_need_stop`` is False, it means the current worker is in an idle state.
# This logic is not used in TP (Tensor Parallelism) mode. However, in EP (Expert Parallelism) mode,
# when there is data on other runner, the current runner is required to execute part of the model.
if not self.not_need_stop():
self._execute_empty_input()
return None
# 1. Prepare inputs of model and decoder.
# sampler create async operation
# 1. Prepare inputs of model and sampler.
skip_idx_list = self._get_skip_idx(model_forward_batch)
self._prepare_inputs()
self.sampler.pre_process(skip_idx_list)
# 2. Padding inputs for cuda grph
# 2. Padding inputs for cuda graph
# 3. Execute model
# TODO(gongshaotian): Use seq_lens_encoder to set is_decode_batch
@@ -1136,7 +1120,7 @@ class GPUModelRunner(ModelRunnerBase):
f"{type(self.model)} has no attribute 'empty_input_forward")
def profile_run(self) -> None:
"""Execute a forward pass with dummy inputs to profile the memory usage of the model."""
""" Execute a forward pass with dummy inputs to profile the memory usage of the model """
# Initialize kv cache for profile run. After profile run kv cache will be reset.
# TODO(gongshaotian): Optimize the management logic of kvcache
@@ -1222,5 +1206,26 @@ class GPUModelRunner(ModelRunnerBase):
return required_memory
def not_need_stop(self) -> bool:
""" """
""" Stop decoding if the tensor meets the termination condition """
return self.share_inputs["not_need_stop"][0]
def clear_cache(self):
""" Clear cached data from shared inputs and forward metadata """
self.share_inputs.pop("caches", None)
if self.forward_meta is not None:
self.forward_meta.clear_caches()
def clear_parameters(self, pid):
"""" Dynamic model loader use to clear parameters use for RL """
self.dynamic_weight_manager.clear_parameters(pid)
self.clear_cache()
paddle.device.cuda.empty_cache()
self.dynamic_weight_manager._log_memory(
"dynamic weight manager clear all memory")
def update_parameters(self, pid):
"""" Dynamic model loader use to update parameters use for RL """
self.dynamic_weight_manager.update_parameters(pid)
self.initialize_kv_cache()
self.dynamic_weight_manager._log_memory(
"dynamic weight manager update all memory")

View File

@@ -32,7 +32,6 @@ logger = get_logger("gpu_worker", "gpu_worker.log")
class GpuWorker(WorkerBase):
""" """
def __init__(
self,
@@ -48,7 +47,8 @@ class GpuWorker(WorkerBase):
pass
def init_device(self):
""" Initialize device and Construct model runner
"""
Initialize device and construct model runner
"""
if self.device_config.device_type == "cuda" and paddle.device.is_compiled_with_cuda(
):
@@ -74,10 +74,10 @@ class GpuWorker(WorkerBase):
device_id=self.device_ids[self.local_rank],
rank=self.rank,
local_rank=self.local_rank)
def prefill_finished(self):
"""
check whether prefill stage finished
Check whether prefill stage finished
"""
return self.model_runner.prefill_finished()
@@ -115,7 +115,8 @@ class GpuWorker(WorkerBase):
f"\nDevice used memory: {before_run_meminfo.used / Gb}",
f"\nDevice free memory: {before_run_meminfo.free / Gb}",
f"\nPaddle reserved memory: {paddle_reserved_mem_before_run / Gb}",
f"\nPaddle allocated memory: {paddle_allocated_mem_before_run / Gb}"))
f"\nPaddle allocated memory: {paddle_allocated_mem_before_run / Gb}"
))
# 2. Profile run
self.model_runner.profile_run()
@@ -126,15 +127,6 @@ class GpuWorker(WorkerBase):
paddle_allocated_mem_after_run = paddle.device.cuda.max_memory_allocated(
self.local_rank)
# NOTE(gongshaotian): v1 worker
# not_paddle_use_mem = after_run_meminfo.used - paddle_reserved_mem_after_run
# peak_memory = paddle_allocated_mem_after_run + not_paddle_use_mem
# available_kv_cache_memory = after_run_meminfo.total * \
# self.parallel_config.gpu_memory_utilization - peak_memory
# v0 worker
model_block_memory_used = self.cal_theortical_kvcache()
paddle_peak_increase = paddle_reserved_mem_after_run - paddle_allocated_mem_before_run
@@ -146,32 +138,31 @@ class GpuWorker(WorkerBase):
available_kv_cache_memory = after_run_meminfo.total * \
self.parallel_config.gpu_memory_utilization - after_run_meminfo.used - paddle_peak_increase
available_kv_cache_memory += model_block_memory_used * self.parallel_config.max_block_num
end_time = time.perf_counter()
logger.info(
("After running the profile, the memory usage info is as follows:",
f"\nDevice Total memory: {after_run_meminfo.total / Gb}",
f"\nDevice used memory: {after_run_meminfo.used / Gb}",
f"\nDevice free memory: {after_run_meminfo.free / Gb}",
f"\nPaddle reserved memory: {paddle_reserved_mem_after_run / Gb}",
f"\nPaddle allocated memory: {paddle_allocated_mem_after_run / Gb}",
f"\nAvailable KV Cache meomory: {available_kv_cache_memory / Gb}",
f"Profile time: {end_time - start_time}"))
logger.info((
"After running the profile, the memory usage info is as follows:",
f"\nDevice Total memory: {after_run_meminfo.total / Gb}",
f"\nDevice used memory: {after_run_meminfo.used / Gb}",
f"\nDevice free memory: {after_run_meminfo.free / Gb}",
f"\nPaddle reserved memory: {paddle_reserved_mem_after_run / Gb}",
f"\nPaddle allocated memory: {paddle_allocated_mem_after_run / Gb}",
f"\nAvailable KV Cache meomory: {available_kv_cache_memory / Gb}",
f"Profile time: {end_time - start_time}"))
return available_kv_cache_memory # return to caculate the block num in this device
def load_model(self) -> None:
""" """
""" Load model """
self.model_runner.load_model()
def get_model(self) -> nn.Layer:
""" """
""" Get current model """
return self.model_runner.get_model()
def initialize_cache(self, num_gpu_blocks: int,
num_cpu_blocks: int) -> None:
""" """
""" Initizlize the KV Cache """
pass
def execute_model(
@@ -193,10 +184,7 @@ class GpuWorker(WorkerBase):
"""
Perform the warm-up and the graph optimization
"""
# 1. Warm up model
# NOTE(gongshaotian): may be not need warm_up at this place
# 2. Triger cuda grpah capture
# Triger cuda grpah capture
self.model_runner.capture_model()
def check_health(self) -> bool:
@@ -204,10 +192,10 @@ class GpuWorker(WorkerBase):
return True
def cal_theortical_kvcache(self) -> int:
""" """
""" Calculate the block memory required """
return self.model_runner.cal_theortical_kvcache()
def reinitialize_kv_cache(self, num_gpu_blocks: int) -> None:
""" """
""" Reinitialize the kv cache using the parameters from the profile """
self.model_runner.update_share_input_block_num(
num_gpu_blocks=num_gpu_blocks)

View File

@@ -593,8 +593,23 @@ class IluvatarModelRunner(ModelRunnerBase):
Initialize forward meta and attention meta data
"""
# Initialize forward meta
self.forward_meta = ForwardMeta.init_forward_meta(
self.share_inputs, self.attn_backends[0])
self.forward_meta = ForwardMeta(
input_ids=self.share_inputs["input_ids"],
ids_remove_padding=self.share_inputs["ids_remove_padding"],
rotary_embs=self.share_inputs["rope_emb"],
attn_backend=self.attn_backends[0],
decoder_batch_ids=self.share_inputs["decoder_batch_ids"],
decoder_tile_ids_per_batch=self.share_inputs["decoder_tile_ids_per_batch"],
seq_lens_encoder=self.share_inputs["seq_lens_encoder"],
seq_lens_decoder=self.share_inputs["seq_lens_decoder"],
seq_lens_this_time=self.share_inputs["seq_lens_this_time"],
cum_offsets=self.share_inputs["cum_offsets"],
padding_offset=self.share_inputs["padding_offset"],
cu_seqlens_q=self.share_inputs["cu_seqlens_q"],
cu_seqlens_k=self.share_inputs["cu_seqlens_k"],
block_tables=self.share_inputs["block_tables"],
caches=self.share_inputs["caches"]
)
# Initialzie attention meta data
for attn_backend in self.attn_backends:

View File

@@ -816,9 +816,23 @@ class GPUVLModelRunner(VLModelRunnerBase):
self.share_inputs["decoder_tile_ids_per_batch"] = paddle.full(
[self.fd_config.parallel_config.max_num_seqs, 1], 0, dtype='int32')
# initialize_forward_meta
self.forward_meta = ForwardMeta.init_forward_meta(
self.share_inputs, self.attn_backend)
self.forward_meta = ForwardMeta(
input_ids=self.share_inputs["input_ids"],
ids_remove_padding=self.share_inputs["ids_remove_padding"],
rotary_embs=self.share_inputs["rope_emb"],
attn_backend=self.attn_backend,
decoder_batch_ids=self.share_inputs["decoder_batch_ids"],
decoder_tile_ids_per_batch=self.share_inputs["decoder_tile_ids_per_batch"],
seq_lens_encoder=self.share_inputs["seq_lens_encoder"],
seq_lens_decoder=self.share_inputs["seq_lens_decoder"],
seq_lens_this_time=self.share_inputs["seq_lens_this_time"],
cum_offsets=self.share_inputs["cum_offsets"],
padding_offset=self.share_inputs["padding_offset"],
cu_seqlens_q=self.share_inputs["cu_seqlens_q"],
cu_seqlens_k=self.share_inputs["cu_seqlens_k"],
block_tables=self.share_inputs["block_tables"],
caches=self.share_inputs["caches"]
)
self.attn_backend.init_attention_metadata(self.forward_meta)
self.sampling_metadata = SamplingMetadata(

View File

@@ -70,7 +70,21 @@ def xpu_pre_process(
share_inputs["cu_seqlens_q"] = cu_seqlens_q
share_inputs["cu_seqlens_k"] = cu_seqlens_k
xpu_forward_meta = XPUForwardMeta.init_forward_meta(share_inputs, None)
xpu_forward_meta = XPUForwardMeta(
input_ids=share_inputs["input_ids"],
ids_remove_padding=share_inputs["ids_remove_padding"],
rotary_embs=share_inputs["rope_emb"],
attn_backend=None,
seq_lens_encoder=share_inputs["seq_lens_encoder"],
seq_lens_decoder=share_inputs["seq_lens_decoder"],
seq_lens_this_time=share_inputs["seq_lens_this_time"],
cum_offsets=share_inputs["cum_offsets"],
padding_offset=share_inputs["padding_offset"],
cu_seqlens_q=share_inputs["cu_seqlens_q"],
cu_seqlens_k=share_inputs["cu_seqlens_k"],
block_tables=share_inputs["block_tables"],
caches=share_inputs["caches"]
)
# Get xpu extra param
(

View File

@@ -21,8 +21,7 @@ import paddle
from fastdeploy.model_executor.layers.attention import (
Attention, PaddleNativeAttnBackend)
from fastdeploy.worker.forward_meta import (ForwardMeta, ForwardMode,
MHATokenToKVPool)
from fastdeploy.worker.forward_meta import ForwardMeta, ForwardMode
class MockModelRunner:
@@ -63,15 +62,6 @@ class MockModelRunner:
},
)
self.page_size = page_size
max_total_num_tokens = max_batch_size * max_context_len
self.token_to_kv_pool = MHATokenToKVPool(
size=max_total_num_tokens,
page_size=page_size,
dtype=self.dtype,
head_num=num_heads,
head_dim=head_dim,
layer_num=1, # only consider layer=1 for unit test
device=self.device)
class TestNativePaddleAttentionBackend(unittest.TestCase):