diff --git a/fastdeploy/model_executor/graph_optimization/cudagraph_piecewise_backend.py b/fastdeploy/model_executor/graph_optimization/cudagraph_piecewise_backend.py index 19dfb98de..74c6fc14c 100644 --- a/fastdeploy/model_executor/graph_optimization/cudagraph_piecewise_backend.py +++ b/fastdeploy/model_executor/graph_optimization/cudagraph_piecewise_backend.py @@ -46,13 +46,9 @@ class ConcreteSizeEntry: # Output buffer of cudagraph output_buffer: Optional[paddle.Tensor] = None - # for cudagraph debugging, track the input addresses - # during capture, and check if they are the same during replay - input_addresses: Optional[list[int]] = None - class CudaGraphPiecewiseBackend: - """ """ + """ Manage the capture and replay of CUDA graphs at the subgraph level. """ def __init__( self, @@ -65,33 +61,31 @@ class CudaGraphPiecewiseBackend: self.warm_up_size = fd_config.graph_opt_config.cudagraph_num_of_warmups self.batch_size_to_captured_size = fd_config.graph_opt_config.batch_size_to_captured_size - # runtime_bs -> ConcreteSizeEntry + # Runtime batch size -> ConcreteSizeEntry self.concrete_size_entries: Dict[int, ConcreteSizeEntry] = {} for shape in self.cudagraph_capture_sizes: self.concrete_size_entries[shape] = ConcreteSizeEntry( runtime_bs=shape) - print("[CUDA GRAPH] Created all batch size entry ") + logger.debug("[CUDA GRAPH] Created all batch size entries") def __call__(self, **kwargs): # Get batch size ids_remove_padding: paddle.Tensor = kwargs["ids_remove_padding"] batch_size = ids_remove_padding.shape[0] - padding_batch_size = self.batch_size_to_captured_size[batch_size] - # print( - # f"[CUDA GRAPH] The actual batch size obtained by CUDAGraph is :{batch_size}, ", - # f"The padded batch size is :{padding_batch_size}" - # ) + logger.debug( + f"[CUDA GRAPH] The actual batch size obtained by CUDAGraph is: {batch_size}, " + f"the padded batch size is: {padding_batch_size}") entry = self.concrete_size_entries.get(padding_batch_size) assert entry is not None, f"Batch size:{padding_batch_size} is not in cuda graph capture list."
if entry.runnable is None: entry.runnable = self.runnable - # print( - # f"[CUDA GRAPH] New entry lazy initialize with batch size {padding_batch_size}" - # ) + logger.debug( + f"[CUDA GRAPH] New entry lazily initialized with batch size {padding_batch_size}" + ) if not entry.use_cudagraph: return entry.runnable(**kwargs) @@ -102,10 +96,10 @@ class CudaGraphPiecewiseBackend: for n in range(entry.num_finished_warmup, self.warm_up_size): entry.num_finished_warmup += 1 entry.runnable(**kwargs) - # print( - # "[CUDA GRAPH] Warm up for batch size ", - # f"{padding_batch_size}, finished ({n+1}/{entry.num_finished_warmup}) times" - # ) + logger.debug( + "[CUDA GRAPH] Warm up for batch size " + f"{padding_batch_size}, finished ({n+1}/{entry.num_finished_warmup}) times" + ) # Store input addresses for debug input_addresses = [ @@ -129,11 +123,13 @@ class CudaGraphPiecewiseBackend: output._clear paddle.device.synchronize() - # print( - # f"[CUDA GRAPH] CUDAGraph captured for batch size {padding_batch_size}" - # ) + logger.debug( + f"[CUDA GRAPH] CUDAGraph captured for batch size {padding_batch_size}" + ) # Replay entry.cuda_graph.replay() - # print(f"[CUDA GRAPH] CUDAGraph replayed for batch size {padding_batch_size}") + logger.debug( + f"[CUDA GRAPH] CUDAGraph replayed for batch size {padding_batch_size}" + ) return entry.output_buffer diff --git a/fastdeploy/model_executor/graph_optimization/decorator.py b/fastdeploy/model_executor/graph_optimization/decorator.py index ad0ddb5b6..8661a7beb 100644 --- a/fastdeploy/model_executor/graph_optimization/decorator.py +++ b/fastdeploy/model_executor/graph_optimization/decorator.py @@ -28,7 +28,7 @@ _T = TypeVar("_T", bound=type[paddle.nn.Layer]) def support_graph_optimization(cls: Optional[_T] = None) -> _T: """ - A decorator for wrapping models or layers with CUDA graph support. + A decorator for wrapping models or layers with static graph and CUDAGraph support. This enables efficient kernel launch sequencing for improved GPU performance. Example usage: @@ -74,7 +74,7 @@ def support_graph_optimization(cls: Optional[_T] = None) -> _T: class GraphOptWrapper: - """ """ + """ The wrapper for GraphOptBackend """ def __init__( self, @@ -87,7 +87,7 @@ class GraphOptWrapper: @abstractmethod def forward(self, **kwargs): - """ """ + """ Abstract method for implementing the wrapped model's forward() """ pass def __call__(self, **kwargs): diff --git a/fastdeploy/model_executor/graph_optimization/graph_optimization_backend.py b/fastdeploy/model_executor/graph_optimization/graph_optimization_backend.py index 7189989dd..9ce6f7372 100644 --- a/fastdeploy/model_executor/graph_optimization/graph_optimization_backend.py +++ b/fastdeploy/model_executor/graph_optimization/graph_optimization_backend.py @@ -24,7 +24,10 @@ from fastdeploy.model_executor.graph_optimization.cudagraph_piecewise_backend im class GraphOptBackend: - """ """ + """ + Integrates various graph optimization functions, including dynamic-to-static graph conversion, + CINN compilation optimization, CudaGraph, and so on.
+ """ fd_config: FDConfig cudagraph_piecewise_backend: Optional[CudaGraphPiecewiseBackend] = None diff --git a/fastdeploy/spec_decode/mtp.py b/fastdeploy/spec_decode/mtp.py index 97e836445..6de3ce633 100644 --- a/fastdeploy/spec_decode/mtp.py +++ b/fastdeploy/spec_decode/mtp.py @@ -436,8 +436,24 @@ class MTPProposer(Proposer): Initialize forward meta and attention meta data """ # Initialize forward meta - self.forward_meta = ForwardMeta.init_forward_meta( - self.model_inputs, self.attn_backends[0]) + self.forward_meta = ForwardMeta( + input_ids=self.model_inputs["input_ids"], + ids_remove_padding=self.model_inputs["ids_remove_padding"], + rotary_embs=self.model_inputs["rope_emb"], + attn_backend=self.attn_backends[0], + decoder_batch_ids=self.model_inputs["decoder_batch_ids"], + decoder_tile_ids_per_batch=self.model_inputs["decoder_tile_ids_per_batch"], + seq_lens_encoder=self.model_inputs["seq_lens_encoder"], + seq_lens_decoder=self.model_inputs["seq_lens_decoder"], + seq_lens_this_time=self.model_inputs["seq_lens_this_time"], + cum_offsets=self.model_inputs["cum_offsets"], + padding_offset=self.model_inputs["padding_offset"], + cu_seqlens_q=self.model_inputs["cu_seqlens_q"], + cu_seqlens_k=self.model_inputs["cu_seqlens_k"], + block_tables=self.model_inputs["block_tables"], + caches=self.model_inputs["caches"] + ) + # Initialzie attention meta data for attn_backend in self.attn_backends: diff --git a/fastdeploy/worker/forward_meta.py b/fastdeploy/worker/forward_meta.py index a1007f4e1..4948821e6 100644 --- a/fastdeploy/worker/forward_meta.py +++ b/fastdeploy/worker/forward_meta.py @@ -14,18 +14,15 @@ # limitations under the License. """ -import abc import logging from dataclasses import dataclass from enum import IntEnum, auto -from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union +from typing import TYPE_CHECKING, Optional -import numpy as np import paddle if TYPE_CHECKING: - from fastdeploy.model_executor.layers.attention import (Attention, - AttentionBackend) + from fastdeploy.model_executor.layers.attention import AttentionBackend logger = logging.getLogger(__name__) @@ -34,333 +31,79 @@ class ForwardMode(IntEnum): """ Forward mode used during attention. 
""" - - # for prefill and extend + # Prefill and Extend mode EXTEND = auto() - # for generation + # Decode mode DECODE = auto() - + # Mixed mode MIXED = auto() def is_prefill(self): - """Whether it's a prefill forward""" + """ Is Extend mode """ return self == ForwardMode.EXTEND def is_decode(self): - """Whether it's a decode forward""" + """ Is Decode mode """ return self == ForwardMode.DECODE def is_mixed(self): - """Whether it's a decode forward""" + """ Is Mixed mode """ return self == ForwardMode.MIXED -class ReqToTokenPool: - """A memory pool that maps a request to its token locations.""" - - def __init__(self, size: int, max_context_len: int): - - self.size = size - self.max_context_len = max_context_len - self.req_to_token = paddle.zeros((size, max_context_len), - dtype=paddle.int32) - self.free_slots = list(range(size)) - - def write(self, indices, values): - """Write data into request buffer""" - self.req_to_token[indices] = values - - def available_size(self): - """Get number of slots left""" - return len(self.free_slots) - - def alloc(self, need_size: int) -> List[int]: - """Allocate `need_size` slots""" - if need_size > len(self.free_slots): - return None - - select_index = self.free_slots[:need_size] - self.free_slots = self.free_slots[need_size:] - - return select_index - - def free(self, free_index: Union[int, List[int]]): - """Free slot""" - if isinstance(free_index, (int, )): - self.free_slots.append(free_index) - else: - self.free_slots.extend(free_index) - - def clear(self): - """Clear all slots""" - self.free_slots = list(range(self.size)) - - -class KVCache(abc.ABC): - """Abstract base class representing a key value cache""" - - @abc.abstractmethod - def get_kv_buffer(self, - layer_id: int) -> Tuple[paddle.Tensor, paddle.Tensor]: - """ - Return cached keys and values given layer id. - Args: - layer_id: int - Returns: - tuple: (keys, values) - """ - raise NotImplementedError() - - @abc.abstractmethod - def set_kv_buffer( - self, - layer: 'Attention', - loc: paddle.Tensor, - cache_k: paddle.Tensor, - cache_v: paddle.Tensor, - ) -> None: - """ - Set cached keys and values given layer id. - Args: - layer: Attention - loc: paddle.Tensor - cache_k: paddle.Tensor - cache_v: paddle.Tensor - """ - raise NotImplementedError() - - @abc.abstractmethod - def transfer(self, indices, flat_data): - """Transfer kv_data between devices""" - raise NotImplementedError() - - @abc.abstractmethod - def transfer_per_layer(self, indices, flat_data, layer_id): - """Not used yet""" - raise NotImplementedError() - - def register_layer_transfer_counter(self, layer_transfer_counter): - """Not used yet""" - self.layer_transfer_counter = layer_transfer_counter - - -class MHATokenToKVPool(KVCache): - """Token To Key Value Pool for MultiHeadAttention""" - - def __init__( - self, - max_block_num: int, - block_size: int, - dtype: paddle.dtype, - head_num: int, - head_dim: int, - layer_num: int, - device: str, - ): - self.max_block_num = max_block_num - self.block_size = block_size - self.dtype = dtype - self.device = device - if dtype in (paddle.int8, paddle.float8_e4m3fn): - # NOTE: Store as torch.uint8 because Tensor.index_put is not implemented for torch.float8_e5m2 - self.store_dtype = paddle.uint8 - else: - self.store_dtype = dtype - - self.head_num = head_num - self.head_dim = head_dim - self.layer_num = layer_num - self._create_buffers() - - k_size, v_size = self.get_kv_size_bytes() - GB = 1024 * 1024 * 1024 - logger.info( - f"KV Cache is allocated. 
#tokens: {self.size}, K size: {k_size / GB:.2f} GB, V size: {v_size / GB:.2f} GB" - ) - - def _create_buffers(self): - # [size, head_num, head_dim] for each layer - # The padded slot 0 is used for writing dummy outputs from padded tokens. - self.k_buffer = [ - paddle.zeros( - (self.max_block_num, self.head_num, self.block_size, - self.head_dim), - dtype=self.store_dtype, - ) for _ in range(self.layer_num) - ] - self.v_buffer = [ - paddle.zeros( - (self.max_block_num, self.head_num, self.block_size, - self.head_dim), - dtype=self.store_dtype, - ) for _ in range(self.layer_num) - ] - - def _clear_buffers(self): - del self.k_buffer - del self.v_buffer - - def get_kv_size_bytes(self): - """for debugging purpose""" - assert hasattr(self, "k_buffer") - assert hasattr(self, "v_buffer") - k_size_bytes = 0 - for k_cache in self.k_buffer: - k_size_bytes += np.prod(k_cache.shape) * 4 - v_size_bytes = 0 - for v_cache in self.v_buffer: - v_size_bytes += np.prod(v_cache.shape) * 4 - return k_size_bytes, v_size_bytes - - def transfer(self, indices, flat_data): - # transfer prepared data from host to device - flat_data = flat_data.to(device=self.device, non_blocking=False) - k_data, v_data = flat_data[0], flat_data[1] - for i in range(self.layer_num): - self.k_buffer[i][indices] = k_data[i] - self.v_buffer[i][indices] = v_data[i] - - def transfer_per_layer(self, indices, flat_data, layer_id): - # transfer prepared data for a specific layer from host to device - flat_data = flat_data.to(device=self.device, non_blocking=False) - k_data, v_data = flat_data[0], flat_data[1] - self.k_buffer[layer_id][indices] = k_data - self.v_buffer[layer_id][indices] = v_data - - def get_key_buffer(self, layer_id: int): - """Return cached keys given layer id.""" - if self.store_dtype != self.dtype: - return self.k_buffer[layer_id].view(self.dtype) - return self.k_buffer[layer_id] - - def get_value_buffer(self, layer_id: int): - """Return cached values given layer id.""" - if self.store_dtype != self.dtype: - return self.v_buffer[layer_id].view(self.dtype) - return self.v_buffer[layer_id] - - def get_kv_buffer(self, layer_id: int): - """Return cached keys and values given layer id.""" - return self.get_key_buffer(layer_id), self.get_value_buffer(layer_id) - - def set_kv_buffer( - self, - layer: 'Attention', - loc: paddle.Tensor, - cache_k: paddle.Tensor, - cache_v: paddle.Tensor, - k_scale: Optional[float] = None, - v_scale: Optional[float] = None, - ): - """Set cached keys and values given layer id.""" - layer_id = layer.layer_id - if cache_k.dtype != self.dtype: - if k_scale is not None: - cache_k.div_(k_scale) - if v_scale is not None: - cache_v.div_(v_scale) - cache_k = cache_k.to(self.dtype) - cache_v = cache_v.to(self.dtype) - - if self.store_dtype != self.dtype: - cache_k = cache_k.view(self.store_dtype) - cache_v = cache_v.view(self.store_dtype) - - self.k_buffer[layer_id][loc] = cache_k - self.v_buffer[layer_id][loc] = cache_v - - @dataclass class ForwardMeta(): """ - ForwardMeta is used to store the global meta information of the forward. + ForwardMeta is used to store the global meta information of the model forward. 
""" - # + # Input tokens IDs input_ids: paddle.Tensor - - #attention meta - forward_mode: ForwardMode = ForwardMode.MIXED - - # - ids_remove_padding: paddle.Tensor = None - - # - seq_lens_encoder: Optional[paddle.Tensor] = None - - # - seq_lens_decoder: Optional[paddle.Tensor] = None - - # - seq_lens_this_time: Optional[paddle.Tensor] = None - - # - cum_offsets: Optional[paddle.Tensor] = None - - # - block_tables: Optional[paddle.Tensor] = None - - # - attn_backend: 'AttentionBackend' = None - - # + # Input tokens IDs of removed padding + ids_remove_padding: paddle.Tensor + # Rotation position embedding rotary_embs: Optional[paddle.Tensor] = None - # - padding_offset: Optional[paddle.Tensor] = None - - # - cu_seqlens_q: Optional[paddle.Tensor] = None - - # - cu_seqlens_k: Optional[paddle.Tensor] = None - - # - caches: Optional[paddle.Tensor] = None - - # - attn_mask: Optional[paddle.Tensor] = None - - # - pre_caches_length: int = 0 - - # Use cuda graph in this step. Used to avoid run cuda graph when in dummy run or prefill stage. + # Use cuda graph in this step or not. Used to avoid run cuda graph when in dummy run or prefill stage. step_use_cudagraph: bool = False - - # for attention backend - decoder_batch_ids: Optional[paddle.Tensor] = None - # for attention backend - decoder_tile_ids_per_batch: Optional[paddle.Tensor] = None - # is_decode_batch or not + # Batch type flag is_decode_batch: bool = False - @classmethod - def init_forward_meta(cls, share_inputs: Dict, - attn_backend: "AttentionBackend"): - """ init forward meta """ - # TODO(gongshaotian): delete this func - ret = cls( - forward_mode=ForwardMode.MIXED, - input_ids=share_inputs["input_ids"], - ids_remove_padding=share_inputs["ids_remove_padding"], - seq_lens_encoder=share_inputs["seq_lens_encoder"], - seq_lens_decoder=share_inputs["seq_lens_decoder"], - seq_lens_this_time=share_inputs["seq_lens_this_time"], - cum_offsets=share_inputs["cum_offsets"], - block_tables=share_inputs["block_tables"], - attn_backend=attn_backend, - rotary_embs=share_inputs["rope_emb"], - padding_offset=share_inputs["padding_offset"], - cu_seqlens_q=share_inputs["cu_seqlens_q"], - cu_seqlens_k=share_inputs["cu_seqlens_k"], - caches=share_inputs["caches"], - decoder_batch_ids=share_inputs.get("decoder_batch_ids", None), - decoder_tile_ids_per_batch=share_inputs.get( - "decoder_tile_ids_per_batch", None), - ) - return ret - + # Attention backend object + attn_backend: 'AttentionBackend' = None + # Forward mode used during attention + forward_mode: ForwardMode = ForwardMode.MIXED + # Attention mask + attn_mask: Optional[paddle.Tensor] = None + # Decoder batch id. Used by attention backend. + decoder_batch_ids: Optional[paddle.Tensor] = None + # Tile ID for each batch of the decoder. Used by attention backend. 
+ decoder_tile_ids_per_batch: Optional[paddle.Tensor] = None + + # Sequence length of encoder for ever batch + seq_lens_encoder: Optional[paddle.Tensor] = None + # Sequence length of Encoder for ever batch + seq_lens_decoder: Optional[paddle.Tensor] = None + # The sequence length processed in the current step + seq_lens_this_time: Optional[paddle.Tensor] = None + + # Accumulated offset + cum_offsets: Optional[paddle.Tensor] = None + # Offset tensor, used to restore the position of ids_remove_madding after padding removal to the original input_ids + padding_offset: Optional[paddle.Tensor] = None + # Accumulated sequence length of query + cu_seqlens_q: Optional[paddle.Tensor] = None + # Accumulated sequence length of key + cu_seqlens_k: Optional[paddle.Tensor] = None + + # Pre-cache length + pre_caches_length: int = 0 + # Block tables + block_tables: Optional[paddle.Tensor] = None + # KV caches + caches: Optional[paddle.Tensor] = None + def clear_caches(self): - """safe clear caches""" + """ Safely clean up the caches """ if self.caches: del self.caches @@ -370,56 +113,42 @@ class XPUForwardMeta(ForwardMeta): """ XPUForwardMeta is used to store the global meta information of the forward, and some XPU specific meta info. """ + # TODO(wanghaitao): Supplementary notes # encoder_batch_map: Optional[paddle.Tensor] = None - # decoder_batch_map: Optional[paddle.Tensor] = None - # encoder_batch_idx: Optional[paddle.Tensor] = None - # decoder_batch_idx: Optional[paddle.Tensor] = None - # encoder_seq_lod: Optional[paddle.Tensor] = None - # decoder_context_len: Optional[paddle.Tensor] = None - # decoder_context_len_cache: Optional[paddle.Tensor] = None # encoder_batch_map_cpu: Optional[paddle.Tensor] = None - # decoder_batch_map_cpu: Optional[paddle.Tensor] = None - # encoder_batch_idx_cpu: Optional[paddle.Tensor] = None - # decoder_batch_idx_cpu: Optional[paddle.Tensor] = None - # encoder_seq_lod_cpu: Optional[paddle.Tensor] = None - # decoder_context_len_cpu: Optional[paddle.Tensor] = None - # decoder_context_len_cache_cpu: Optional[paddle.Tensor] = None # batch_tensor: Optional[paddle.Tensor] = None - # enc_batch: Optional[paddle.Tensor] = None - # dec_batch: Optional[paddle.Tensor] = None - # total_enc_len: Optional[paddle.Tensor] = None diff --git a/fastdeploy/worker/gcu_model_runner.py b/fastdeploy/worker/gcu_model_runner.py index 5dd8cef1b..811b2b691 100644 --- a/fastdeploy/worker/gcu_model_runner.py +++ b/fastdeploy/worker/gcu_model_runner.py @@ -606,8 +606,23 @@ class GCUModelRunner(ModelRunnerBase): Initialize forward meta and attention meta data """ # Initialize forward meta - self.forward_meta = ForwardMeta.init_forward_meta( - self.share_inputs, self.attn_backends[0]) + self.forward_meta = ForwardMeta( + input_ids=self.share_inputs["input_ids"], + ids_remove_padding=self.share_inputs["ids_remove_padding"], + rotary_embs=self.share_inputs["rope_emb"], + attn_backend=self.attn_backends[0], + decoder_batch_ids=self.share_inputs["decoder_batch_ids"], + decoder_tile_ids_per_batch=self.share_inputs["decoder_tile_ids_per_batch"], + seq_lens_encoder=self.share_inputs["seq_lens_encoder"], + seq_lens_decoder=self.share_inputs["seq_lens_decoder"], + seq_lens_this_time=self.share_inputs["seq_lens_this_time"], + cum_offsets=self.share_inputs["cum_offsets"], + padding_offset=self.share_inputs["padding_offset"], + cu_seqlens_q=self.share_inputs["cu_seqlens_q"], + cu_seqlens_k=self.share_inputs["cu_seqlens_k"], + block_tables=self.share_inputs["block_tables"], + caches=self.share_inputs["caches"] + ) # 
Initialzie attention meta data for attn_backend in self.attn_backends: diff --git a/fastdeploy/worker/gpu_model_runner.py b/fastdeploy/worker/gpu_model_runner.py index 8d6ca79a1..e7b78062f 100644 --- a/fastdeploy/worker/gpu_model_runner.py +++ b/fastdeploy/worker/gpu_model_runner.py @@ -48,7 +48,6 @@ from fastdeploy.worker.output import ModelOutputData, ModelRunnerOutput class GPUModelRunner(ModelRunnerBase): - """ """ def __init__( self, @@ -81,9 +80,6 @@ class GPUModelRunner(ModelRunnerBase): self.use_cudagraph = self.graph_opt_config.use_cudagraph self.cudagraph_capture_sizes = list( reversed(self.graph_opt_config.cudagraph_capture_sizes)) - self.cudagraph_num_of_warmups = self.graph_opt_config.cudagraph_num_of_warmups - self.input_ids = paddle.zeros(self.parallel_config.max_num_seqs, - dtype='int32') # Initialize share inputs self._init_share_inputs(self.parallel_config.max_num_seqs) @@ -94,7 +90,7 @@ class GPUModelRunner(ModelRunnerBase): self.restore_chunked_prefill_request = dict() # Initialize attention Backend - # Note(gonshaotian): Currently, all attention layers share one attention backend instance. + # NOTE(gonshaotian): Currently, all attention layers share one attention backend instance. # In the future, we will expand it as a list. self.attn_backends: list[AttentionBackend] = [] # self.attn_metadatas: list[AttentionMetadata] = [] @@ -110,14 +106,14 @@ class GPUModelRunner(ModelRunnerBase): def prefill_finished(self): """ - check whether prefill stage finished + Check whether prefill stage finished """ if int(paddle.max(self.share_inputs['seq_lens_encoder'])) != 0: return 1 else: return 0 - def init_speculative_proposer(self): + def _init_speculative_proposer(self): """ Init speculative proposer """ @@ -333,8 +329,8 @@ class GPUModelRunner(ModelRunnerBase): (idx + 1) * block_num, 1) def _init_share_inputs(self, max_num_seqs: int): - """Initialize all share buffers for model inputs. - Note: In the future, we may abandon share buffers. + """ + Initialize all share buffers for model inputs. """ self.MAX_INFER_SEED = 9223372036854775806 self.share_inputs = {} @@ -469,6 +465,7 @@ class GPUModelRunner(ModelRunnerBase): # Initialize rotary position embedding tmp_position_ids = paddle.arange( self.parallel_config.max_model_len).reshape((1, -1)) + # TODO(gongshaotian): move to models self.share_inputs["rope_emb"] = get_rope( rotary_dim=self.model_config.head_dim, @@ -536,7 +533,7 @@ class GPUModelRunner(ModelRunnerBase): dtype="int32") def _prepare_inputs(self) -> None: - """ prepare the model inputs """ + """ Prepare the model inputs """ # Remove padding ( ids_remove_padding, @@ -595,7 +592,8 @@ class GPUModelRunner(ModelRunnerBase): if self.fd_config.load_config.dynamic_load_weight: from fastdeploy.rl.dynamic_weight_manager import \ DynamicWeightManager - self.dynamic_weight_manager = DynamicWeightManager(self.fd_config, self.model) + self.dynamic_weight_manager = DynamicWeightManager( + self.fd_config, self.model) # 2. Load lora model @@ -606,10 +604,10 @@ class GPUModelRunner(ModelRunnerBase): f"Model loading took {time_after_load - time_before_load} seconds") # 4. 
Init proposer for speculative method - self.init_speculative_proposer() + self._init_speculative_proposer() def get_model(self) -> nn.Layer: - """ get current model """ + """ Get current model """ return self.model def initialize_forward_meta(self): @@ -617,32 +615,28 @@ class GPUModelRunner(ModelRunnerBase): Initialize forward meta and attention meta data """ # Initialize forward meta - self.forward_meta = ForwardMeta.init_forward_meta( - self.share_inputs, self.attn_backends[0]) + self.forward_meta = ForwardMeta( + input_ids=self.share_inputs["input_ids"], + ids_remove_padding=self.share_inputs["ids_remove_padding"], + rotary_embs=self.share_inputs["rope_emb"], + attn_backend=self.attn_backends[0], + decoder_batch_ids=self.share_inputs["decoder_batch_ids"], + decoder_tile_ids_per_batch=self.share_inputs["decoder_tile_ids_per_batch"], + seq_lens_encoder=self.share_inputs["seq_lens_encoder"], + seq_lens_decoder=self.share_inputs["seq_lens_decoder"], + seq_lens_this_time=self.share_inputs["seq_lens_this_time"], + cum_offsets=self.share_inputs["cum_offsets"], + padding_offset=self.share_inputs["padding_offset"], + cu_seqlens_q=self.share_inputs["cu_seqlens_q"], + cu_seqlens_k=self.share_inputs["cu_seqlens_k"], + block_tables=self.share_inputs["block_tables"], + caches=self.share_inputs["caches"] + ) # Initialzie attention meta data for attn_backend in self.attn_backends: attn_backend.init_attention_metadata(self.forward_meta) - def clear_cache(self): - """Clear cached data from shared inputs and forward metadata.""" - self.share_inputs.pop("caches", None) - if self.forward_meta is not None: - self.forward_meta.clear_caches() - - def clear_parameters(self, pid): - """"dynamic model loader use to clear parameters use for RL""" - self.dynamic_weight_manager.clear_parameters(pid) - self.clear_cache() - paddle.device.cuda.empty_cache() - self.dynamic_weight_manager._log_memory("dynamic weight manager clear all memory") - - def update_parameters(self, pid): - """"dynamic model loader use to update parameters use for RL""" - self.dynamic_weight_manager.update_parameters(pid) - self.initialize_kv_cache() - self.dynamic_weight_manager._log_memory("dynamic weight manager update all memory") - def initialize_kv_cache(self) -> None: """ Initialize kv cache @@ -701,11 +695,10 @@ class GPUModelRunner(ModelRunnerBase): def initialize_attn_backend(self) -> None: """ - Initialize attention backends and forward metadata + Initialize attention backends """ assert len(self.attn_backends) == 0 - # TODO(gongshaotian): Get rank from config num_heads = self.model_config.num_attention_heads // self.parallel_config.tensor_parallel_degree self.model_config.kv_num_heads = int( self.model_config.num_key_value_heads @@ -718,10 +711,7 @@ class GPUModelRunner(ModelRunnerBase): kv_num_heads=self.model_config.kv_num_heads, num_heads=num_heads, head_dim=head_dim) - if attn_backend is None: - raise NotImplementedError( - "Attention backend which you specified is not supported, please set FD_ATTENTION_BACKEND correctly." - ) + self.attn_backends.append(attn_backend) def _dummy_run(self, @@ -745,14 +735,12 @@ class GPUModelRunner(ModelRunnerBase): expected_decode_len=expected_decode_len) while True: - # 1. Compute real num_tokens + # 1. Initialize forward meta and attention meta data self._prepare_inputs() - # 2. Initialize attention backend and forward meta data + # 2. Prepare lora - # 3. Prepare lora - - # 4. Run model + # 3. 
Run model is_decode_batch = not ((self.share_inputs["seq_lens_this_time"] > 1).sum() > 0) self.forward_meta.step_use_cudagraph = is_decode_batch and in_capturing @@ -773,7 +761,7 @@ class GPUModelRunner(ModelRunnerBase): self.parallel_config.max_model_len, ) - # 5. Execute spec decode + # 4. Execute spec decode logits = self.model.compute_logits(hiddden_states) if not self.speculative_decoding: @@ -805,7 +793,7 @@ class GPUModelRunner(ModelRunnerBase): paddle.distributed.broadcast( self.share_inputs["stop_flags"], 0) - # 6. post process + # 5. post process model_output_data = ModelOutputData( next_tokens=self.share_inputs["next_tokens"], stop_flags=self.share_inputs["stop_flags"], @@ -858,7 +846,7 @@ class GPUModelRunner(ModelRunnerBase): def _update_chunked_prefill(self, tasks): """ - 更新chunked prefill相关参数 + Update chunked prefill related parameters """ if not self.parallel_config.enable_chunked_prefill: return @@ -903,13 +891,9 @@ class GPUModelRunner(ModelRunnerBase): self.proposer.update_task_chunk_prefill(task) task.chunk_idx += 1 - def _dummy_sampler_run(self) -> paddle.Tensor: - """ """ - pass - def capture_model(self) -> None: """ - Trigger CUDA Graph capture for all shapes in 'CudaGraphConfig.cudagraph_capture_sizes' + Trigger CUDA Graph capture for all shapes in cuda graph capture list """ if not self.use_cudagraph: logger.info( @@ -933,7 +917,8 @@ class GPUModelRunner(ModelRunnerBase): f"Cuda Graph capturing took {time_after_capture - time_before_capture} seconds" ) - def _get_skip_idx(self, model_forward_batch): + def _get_skip_idx(self, + model_forward_batch: Optional[List[Request]] = None): """ Get the index of the request that needs to be skipped during execution. Args: @@ -972,20 +957,19 @@ class GPUModelRunner(ModelRunnerBase): We plan to replace it with 'ModelForwardBatch'. intermediate_tensors: """ - # Note(@wufeisheng): If `not_need_stop`` is False, it means the current worker is in an idle state. + # NOTE(wufeisheng): If `not_need_stop`` is False, it means the current worker is in an idle state. # This logic is not used in TP (Tensor Parallelism) mode. However, in EP (Expert Parallelism) mode, # when there is data on other runner, the current runner is required to execute part of the model. if not self.not_need_stop(): self._execute_empty_input() return None - # 1. Prepare inputs of model and decoder. - # sampler create async operation + # 1. Prepare inputs of model and sampler. skip_idx_list = self._get_skip_idx(model_forward_batch) self._prepare_inputs() self.sampler.pre_process(skip_idx_list) - # 2. Padding inputs for cuda grph + # 2. Padding inputs for cuda graph # 3. Execute model # TODO(gongshaotian): Use seq_lens_encoder to set is_decode_batch @@ -1136,7 +1120,7 @@ class GPUModelRunner(ModelRunnerBase): f"{type(self.model)} has no attribute 'empty_input_forward") def profile_run(self) -> None: - """Execute a forward pass with dummy inputs to profile the memory usage of the model.""" + """ Execute a forward pass with dummy inputs to profile the memory usage of the model """ # Initialize kv cache for profile run. After profile run kv cache will be reset. 
# TODO(gongshaotian): Optimize the management logic of kvcache @@ -1222,5 +1206,26 @@ class GPUModelRunner(ModelRunnerBase): return required_memory def not_need_stop(self) -> bool: - """ """ + """ Return whether decoding needs to continue; False once the termination condition is met """ return self.share_inputs["not_need_stop"][0] + + def clear_cache(self): + """ Clear cached data from shared inputs and forward metadata """ + self.share_inputs.pop("caches", None) + if self.forward_meta is not None: + self.forward_meta.clear_caches() + + def clear_parameters(self, pid): + """ Used by the dynamic model loader to clear parameters for RL """ + self.dynamic_weight_manager.clear_parameters(pid) + self.clear_cache() + paddle.device.cuda.empty_cache() + self.dynamic_weight_manager._log_memory( + "dynamic weight manager clear all memory") + + def update_parameters(self, pid): + """ Used by the dynamic model loader to update parameters for RL """ + self.dynamic_weight_manager.update_parameters(pid) + self.initialize_kv_cache() + self.dynamic_weight_manager._log_memory( + "dynamic weight manager update all memory") diff --git a/fastdeploy/worker/gpu_worker.py b/fastdeploy/worker/gpu_worker.py index 0386485fa..b9f08c6b2 100644 --- a/fastdeploy/worker/gpu_worker.py +++ b/fastdeploy/worker/gpu_worker.py @@ -32,7 +32,6 @@ logger = get_logger("gpu_worker", "gpu_worker.log") class GpuWorker(WorkerBase): - """ """ def __init__( self, @@ -48,7 +47,8 @@ class GpuWorker(WorkerBase): pass def init_device(self): - """ Initialize device and Construct model runner + """ + Initialize device and construct model runner """ if self.device_config.device_type == "cuda" and paddle.device.is_compiled_with_cuda( ): @@ -74,10 +74,10 @@ class GpuWorker(WorkerBase): device_id=self.device_ids[self.local_rank], rank=self.rank, local_rank=self.local_rank) - + def prefill_finished(self): """ - check whether prefill stage finished + Check whether prefill stage finished """ return self.model_runner.prefill_finished() @@ -115,7 +115,8 @@ class GpuWorker(WorkerBase): f"\nDevice used memory: {before_run_meminfo.used / Gb}", f"\nDevice free memory: {before_run_meminfo.free / Gb}", f"\nPaddle reserved memory: {paddle_reserved_mem_before_run / Gb}", - f"\nPaddle allocated memory: {paddle_allocated_mem_before_run / Gb}")) + f"\nPaddle allocated memory: {paddle_allocated_mem_before_run / Gb}" + )) # 2.
Profile run self.model_runner.profile_run() @@ -126,15 +127,6 @@ class GpuWorker(WorkerBase): paddle_allocated_mem_after_run = paddle.device.cuda.max_memory_allocated( self.local_rank) - - - # NOTE(gongshaotian): v1 worker - # not_paddle_use_mem = after_run_meminfo.used - paddle_reserved_mem_after_run - # peak_memory = paddle_allocated_mem_after_run + not_paddle_use_mem - # available_kv_cache_memory = after_run_meminfo.total * \ - # self.parallel_config.gpu_memory_utilization - peak_memory - - # v0 worker model_block_memory_used = self.cal_theortical_kvcache() paddle_peak_increase = paddle_reserved_mem_after_run - paddle_allocated_mem_before_run @@ -146,32 +138,31 @@ class GpuWorker(WorkerBase): available_kv_cache_memory = after_run_meminfo.total * \ self.parallel_config.gpu_memory_utilization - after_run_meminfo.used - paddle_peak_increase available_kv_cache_memory += model_block_memory_used * self.parallel_config.max_block_num - end_time = time.perf_counter() - logger.info( - ("After running the profile, the memory usage info is as follows:", - f"\nDevice Total memory: {after_run_meminfo.total / Gb}", - f"\nDevice used memory: {after_run_meminfo.used / Gb}", - f"\nDevice free memory: {after_run_meminfo.free / Gb}", - f"\nPaddle reserved memory: {paddle_reserved_mem_after_run / Gb}", - f"\nPaddle allocated memory: {paddle_allocated_mem_after_run / Gb}", - f"\nAvailable KV Cache meomory: {available_kv_cache_memory / Gb}", - f"Profile time: {end_time - start_time}")) + logger.info(( + "After running the profile, the memory usage info is as follows:", + f"\nDevice Total memory: {after_run_meminfo.total / Gb}", + f"\nDevice used memory: {after_run_meminfo.used / Gb}", + f"\nDevice free memory: {after_run_meminfo.free / Gb}", + f"\nPaddle reserved memory: {paddle_reserved_mem_after_run / Gb}", + f"\nPaddle allocated memory: {paddle_allocated_mem_after_run / Gb}", + f"\nAvailable KV Cache memory: {available_kv_cache_memory / Gb}", + f"Profile time: {end_time - start_time}")) return available_kv_cache_memory # return to caculate the block num in this device def load_model(self) -> None: - """ """ + """ Load model """ self.model_runner.load_model() def get_model(self) -> nn.Layer: - """ """ + """ Get current model """ return self.model_runner.get_model() def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks: int) -> None: - """ """ + """ Initialize the KV Cache """ pass def execute_model( @@ -193,10 +184,7 @@ class GpuWorker(WorkerBase): """ Perform the warm-up and the graph optimization """ - # 1. Warm up model - # NOTE(gongshaotian): may be not need warm_up at this place - - # 2.
Triger cuda grpah capture + # Trigger cuda graph capture self.model_runner.capture_model() def check_health(self) -> bool: @@ -204,10 +192,10 @@ class GpuWorker(WorkerBase): return True def cal_theortical_kvcache(self) -> int: - """ """ + """ Calculate the block memory required """ return self.model_runner.cal_theortical_kvcache() def reinitialize_kv_cache(self, num_gpu_blocks: int) -> None: - """ """ + """ Reinitialize the kv cache using the parameters from the profile """ self.model_runner.update_share_input_block_num( num_gpu_blocks=num_gpu_blocks) diff --git a/fastdeploy/worker/iluvatar_model_runner.py b/fastdeploy/worker/iluvatar_model_runner.py index 534853c72..42aadd9b6 100644 --- a/fastdeploy/worker/iluvatar_model_runner.py +++ b/fastdeploy/worker/iluvatar_model_runner.py @@ -593,8 +593,23 @@ class IluvatarModelRunner(ModelRunnerBase): Initialize forward meta and attention meta data """ # Initialize forward meta - self.forward_meta = ForwardMeta.init_forward_meta( - self.share_inputs, self.attn_backends[0]) + self.forward_meta = ForwardMeta( + input_ids=self.share_inputs["input_ids"], + ids_remove_padding=self.share_inputs["ids_remove_padding"], + rotary_embs=self.share_inputs["rope_emb"], + attn_backend=self.attn_backends[0], + decoder_batch_ids=self.share_inputs["decoder_batch_ids"], + decoder_tile_ids_per_batch=self.share_inputs["decoder_tile_ids_per_batch"], + seq_lens_encoder=self.share_inputs["seq_lens_encoder"], + seq_lens_decoder=self.share_inputs["seq_lens_decoder"], + seq_lens_this_time=self.share_inputs["seq_lens_this_time"], + cum_offsets=self.share_inputs["cum_offsets"], + padding_offset=self.share_inputs["padding_offset"], + cu_seqlens_q=self.share_inputs["cu_seqlens_q"], + cu_seqlens_k=self.share_inputs["cu_seqlens_k"], + block_tables=self.share_inputs["block_tables"], + caches=self.share_inputs["caches"] + ) # Initialzie attention meta data for attn_backend in self.attn_backends: diff --git a/fastdeploy/worker/vl_gpu_model_runner.py b/fastdeploy/worker/vl_gpu_model_runner.py index e78c30946..77dd21500 100644 --- a/fastdeploy/worker/vl_gpu_model_runner.py +++ b/fastdeploy/worker/vl_gpu_model_runner.py @@ -816,9 +816,23 @@ class GPUVLModelRunner(VLModelRunnerBase): self.share_inputs["decoder_tile_ids_per_batch"] = paddle.full( [self.fd_config.parallel_config.max_num_seqs, 1], 0, dtype='int32') # initialize_forward_meta - self.forward_meta = ForwardMeta.init_forward_meta( - self.share_inputs, self.attn_backend) - + self.forward_meta = ForwardMeta( + input_ids=self.share_inputs["input_ids"], + ids_remove_padding=self.share_inputs["ids_remove_padding"], + rotary_embs=self.share_inputs["rope_emb"], + attn_backend=self.attn_backend, + decoder_batch_ids=self.share_inputs["decoder_batch_ids"], + decoder_tile_ids_per_batch=self.share_inputs["decoder_tile_ids_per_batch"], + seq_lens_encoder=self.share_inputs["seq_lens_encoder"], + seq_lens_decoder=self.share_inputs["seq_lens_decoder"], + seq_lens_this_time=self.share_inputs["seq_lens_this_time"], + cum_offsets=self.share_inputs["cum_offsets"], + padding_offset=self.share_inputs["padding_offset"], + cu_seqlens_q=self.share_inputs["cu_seqlens_q"], + cu_seqlens_k=self.share_inputs["cu_seqlens_k"], + block_tables=self.share_inputs["block_tables"], + caches=self.share_inputs["caches"] + ) self.attn_backend.init_attention_metadata(self.forward_meta) self.sampling_metadata = SamplingMetadata( diff --git a/fastdeploy/worker/xpu_model_runner.py b/fastdeploy/worker/xpu_model_runner.py index 41c7f9fe7..b82eda700 100644 --- 
a/fastdeploy/worker/xpu_model_runner.py +++ b/fastdeploy/worker/xpu_model_runner.py @@ -70,7 +70,21 @@ def xpu_pre_process( share_inputs["cu_seqlens_q"] = cu_seqlens_q share_inputs["cu_seqlens_k"] = cu_seqlens_k - xpu_forward_meta = XPUForwardMeta.init_forward_meta(share_inputs, None) + xpu_forward_meta = XPUForwardMeta( + input_ids=share_inputs["input_ids"], + ids_remove_padding=share_inputs["ids_remove_padding"], + rotary_embs=share_inputs["rope_emb"], + attn_backend=None, + seq_lens_encoder=share_inputs["seq_lens_encoder"], + seq_lens_decoder=share_inputs["seq_lens_decoder"], + seq_lens_this_time=share_inputs["seq_lens_this_time"], + cum_offsets=share_inputs["cum_offsets"], + padding_offset=share_inputs["padding_offset"], + cu_seqlens_q=share_inputs["cu_seqlens_q"], + cu_seqlens_k=share_inputs["cu_seqlens_k"], + block_tables=share_inputs["block_tables"], + caches=share_inputs["caches"] + ) # Get xpu extra param ( diff --git a/test/layers/test_attention.py b/test/layers/test_attention.py index 9d4b09679..b499ee1c2 100644 --- a/test/layers/test_attention.py +++ b/test/layers/test_attention.py @@ -21,8 +21,7 @@ import paddle from fastdeploy.model_executor.layers.attention import ( Attention, PaddleNativeAttnBackend) -from fastdeploy.worker.forward_meta import (ForwardMeta, ForwardMode, - MHATokenToKVPool) +from fastdeploy.worker.forward_meta import ForwardMeta, ForwardMode class MockModelRunner: @@ -63,15 +62,6 @@ class MockModelRunner: }, ) self.page_size = page_size - max_total_num_tokens = max_batch_size * max_context_len - self.token_to_kv_pool = MHATokenToKVPool( - size=max_total_num_tokens, - page_size=page_size, - dtype=self.dtype, - head_num=num_heads, - head_dim=head_dim, - layer_num=1, # only consider layer=1 for unit test - device=self.device) class TestNativePaddleAttentionBackend(unittest.TestCase):
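
Note on the repeated ForwardMeta construction (illustrative, not part of the patch): the same keyword-by-keyword ForwardMeta(...) block now appears in mtp.py, gcu_model_runner.py, gpu_model_runner.py, iluvatar_model_runner.py, vl_gpu_model_runner.py and xpu_model_runner.py. A shared helper could factor out that duplication; the sketch below is a suggestion only, the helper name and placement are assumptions, while the field names and dict keys are copied from the diff.

# Hypothetical helper, not part of this patch: builds a ForwardMeta from a
# runner's share_inputs dict, mirroring the construction repeated in the diff.
from fastdeploy.worker.forward_meta import ForwardMeta

def build_forward_meta(share_inputs, attn_backend):
    return ForwardMeta(
        input_ids=share_inputs["input_ids"],
        ids_remove_padding=share_inputs["ids_remove_padding"],
        rotary_embs=share_inputs["rope_emb"],
        attn_backend=attn_backend,
        # .get() keeps the helper usable for runners (e.g. XPU) that do not
        # populate the decoder scheduling buffers.
        decoder_batch_ids=share_inputs.get("decoder_batch_ids"),
        decoder_tile_ids_per_batch=share_inputs.get("decoder_tile_ids_per_batch"),
        seq_lens_encoder=share_inputs["seq_lens_encoder"],
        seq_lens_decoder=share_inputs["seq_lens_decoder"],
        seq_lens_this_time=share_inputs["seq_lens_this_time"],
        cum_offsets=share_inputs["cum_offsets"],
        padding_offset=share_inputs["padding_offset"],
        cu_seqlens_q=share_inputs["cu_seqlens_q"],
        cu_seqlens_k=share_inputs["cu_seqlens_k"],
        block_tables=share_inputs["block_tables"],
        caches=share_inputs["caches"],
    )

# Each runner could then call, for example:
#     self.forward_meta = build_forward_meta(self.share_inputs, self.attn_backends[0])
# and the XPU pre-process path could pass attn_backend=None, as it does in the diff.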
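Note on the CUDA graph batch-size padding (illustrative, not part of the patch): CudaGraphPiecewiseBackend looks up batch_size_to_captured_size[batch_size] and replays the graph captured for that padded size. The mapping itself is built in the graph-optimization config, which this diff does not show; the sketch below only illustrates the usual "round up to the nearest captured size" idea and is an assumption, not the actual FastDeploy implementation.

# Assumed construction of a runtime-batch-size -> captured-size mapping:
# every runtime batch size is rounded up to the smallest captured size that
# can hold it, so one captured graph serves a range of real batch sizes.
def build_batch_size_to_captured_size(capture_sizes, max_batch_size):
    capture_sizes = sorted(capture_sizes)
    mapping = {}
    for bs in range(1, max_batch_size + 1):
        padded = next((s for s in capture_sizes if s >= bs), None)
        if padded is not None:
            mapping[bs] = padded
    return mapping

# Example: with capture_sizes=[1, 2, 4, 8], a runtime batch of 3 replays the
# graph captured for batch size 4.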