Mirror of https://github.com/PaddlePaddle/FastDeploy.git, synced 2025-10-04 08:16:42 +08:00
Clear dead code And supplementary notes (#2757)
Some checks failed: Deploy GitHub Pages / deploy (push) has been cancelled
* 1. supplementary notes 2. delete dead code
* fix bug of forward meta
* Global modification of forward meta
* fix vl model_runner bug
@@ -48,7 +48,6 @@ from fastdeploy.worker.output import ModelOutputData, ModelRunnerOutput
 class GPUModelRunner(ModelRunnerBase):
     """ """

     def __init__(
         self,
@@ -81,9 +80,6 @@ class GPUModelRunner(ModelRunnerBase):
         self.use_cudagraph = self.graph_opt_config.use_cudagraph
         self.cudagraph_capture_sizes = list(
             reversed(self.graph_opt_config.cudagraph_capture_sizes))
         self.cudagraph_num_of_warmups = self.graph_opt_config.cudagraph_num_of_warmups
-        self.input_ids = paddle.zeros(self.parallel_config.max_num_seqs,
-                                      dtype='int32')
-
         # Initialize share inputs
         self._init_share_inputs(self.parallel_config.max_num_seqs)
@@ -94,7 +90,7 @@ class GPUModelRunner(ModelRunnerBase):
         self.restore_chunked_prefill_request = dict()

         # Initialize attention Backend
-        # Note(gonshaotian): Currently, all attention layers share one attention backend instance.
+        # NOTE(gonshaotian): Currently, all attention layers share one attention backend instance.
         # In the future, we will expand it as a list.
         self.attn_backends: list[AttentionBackend] = []
         # self.attn_metadatas: list[AttentionMetadata] = []
@@ -110,14 +106,14 @@ class GPUModelRunner(ModelRunnerBase):

     def prefill_finished(self):
         """
-        check whether prefill stage finished
+        Check whether prefill stage finished
         """
         if int(paddle.max(self.share_inputs['seq_lens_encoder'])) != 0:
             return 1
         else:
             return 0

-    def init_speculative_proposer(self):
+    def _init_speculative_proposer(self):
         """
         Init speculative proposer
         """
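For orientation (not part of the diff): the prefill check above treats the batch as unfinished as long as any request still has encoder tokens outstanding. A minimal, self-contained sketch of the same condition, using a toy paddle tensor rather than the runner's shared buffers:

import paddle

# Hypothetical per-request encoder lengths; a non-zero value means that
# request still has prompt tokens left to encode.
seq_lens_encoder = paddle.to_tensor([0, 0, 128, 0], dtype='int32')

# Same condition as in the diff: a non-zero maximum means at least one
# request in the batch is still in the prefill stage.
still_prefilling = int(paddle.max(seq_lens_encoder)) != 0
print(still_prefilling)  # True while 128 tokens remain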
@@ -333,8 +329,8 @@ class GPUModelRunner(ModelRunnerBase):
             (idx + 1) * block_num, 1)

     def _init_share_inputs(self, max_num_seqs: int):
-        """Initialize all share buffers for model inputs.
-        Note: In the future, we may abandon share buffers.
+        """
+        Initialize all share buffers for model inputs.
         """
         self.MAX_INFER_SEED = 9223372036854775806
         self.share_inputs = {}
@@ -469,6 +465,7 @@ class GPUModelRunner(ModelRunnerBase):
         # Initialize rotary position embedding
         tmp_position_ids = paddle.arange(
             self.parallel_config.max_model_len).reshape((1, -1))

+        # TODO(gongshaotian): move to models
         self.share_inputs["rope_emb"] = get_rope(
             rotary_dim=self.model_config.head_dim,
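As background for the rope_emb buffer above: get_rope builds a rotary-embedding table from the position-id range and the head dimension. A minimal sketch of how such a table is typically computed (illustrative only; numpy and the frequency base of 10000 are assumptions, this is not FastDeploy's get_rope implementation):

import numpy as np

def build_rope_table(max_position: int, head_dim: int, base: float = 10000.0):
    # Inverse frequencies over channel pairs, outer product with positions,
    # then cos/sin tables indexed by position.
    inv_freq = 1.0 / (base ** (np.arange(0, head_dim, 2) / head_dim))
    positions = np.arange(max_position)
    angles = np.outer(positions, inv_freq)   # [max_position, head_dim // 2]
    return np.cos(angles), np.sin(angles)

cos, sin = build_rope_table(max_position=8192, head_dim=128)
print(cos.shape, sin.shape)  # (8192, 64) (8192, 64)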
@@ -536,7 +533,7 @@ class GPUModelRunner(ModelRunnerBase):
                                                  dtype="int32")

     def _prepare_inputs(self) -> None:
-        """ prepare the model inputs """
+        """ Prepare the model inputs """
         # Remove padding
         (
             ids_remove_padding,
@@ -595,7 +592,8 @@ class GPUModelRunner(ModelRunnerBase):
         if self.fd_config.load_config.dynamic_load_weight:
             from fastdeploy.rl.dynamic_weight_manager import \
                 DynamicWeightManager
-            self.dynamic_weight_manager = DynamicWeightManager(self.fd_config, self.model)
+            self.dynamic_weight_manager = DynamicWeightManager(
+                self.fd_config, self.model)

         # 2. Load lora model
@@ -606,10 +604,10 @@ class GPUModelRunner(ModelRunnerBase):
             f"Model loading took {time_after_load - time_before_load} seconds")

         # 4. Init proposer for speculative method
-        self.init_speculative_proposer()
+        self._init_speculative_proposer()

     def get_model(self) -> nn.Layer:
-        """ get current model """
+        """ Get current model """
         return self.model

     def initialize_forward_meta(self):
@@ -617,32 +615,28 @@ class GPUModelRunner(ModelRunnerBase):
         Initialize forward meta and attention meta data
         """
         # Initialize forward meta
-        self.forward_meta = ForwardMeta.init_forward_meta(
-            self.share_inputs, self.attn_backends[0])
+        self.forward_meta = ForwardMeta(
+            input_ids=self.share_inputs["input_ids"],
+            ids_remove_padding=self.share_inputs["ids_remove_padding"],
+            rotary_embs=self.share_inputs["rope_emb"],
+            attn_backend=self.attn_backends[0],
+            decoder_batch_ids=self.share_inputs["decoder_batch_ids"],
+            decoder_tile_ids_per_batch=self.share_inputs["decoder_tile_ids_per_batch"],
+            seq_lens_encoder=self.share_inputs["seq_lens_encoder"],
+            seq_lens_decoder=self.share_inputs["seq_lens_decoder"],
+            seq_lens_this_time=self.share_inputs["seq_lens_this_time"],
+            cum_offsets=self.share_inputs["cum_offsets"],
+            padding_offset=self.share_inputs["padding_offset"],
+            cu_seqlens_q=self.share_inputs["cu_seqlens_q"],
+            cu_seqlens_k=self.share_inputs["cu_seqlens_k"],
+            block_tables=self.share_inputs["block_tables"],
+            caches=self.share_inputs["caches"]
+        )

         # Initialzie attention meta data
         for attn_backend in self.attn_backends:
             attn_backend.init_attention_metadata(self.forward_meta)

-    def clear_cache(self):
-        """Clear cached data from shared inputs and forward metadata."""
-        self.share_inputs.pop("caches", None)
-        if self.forward_meta is not None:
-            self.forward_meta.clear_caches()
-
-    def clear_parameters(self, pid):
-        """"dynamic model loader use to clear parameters use for RL"""
-        self.dynamic_weight_manager.clear_parameters(pid)
-        self.clear_cache()
-        paddle.device.cuda.empty_cache()
-        self.dynamic_weight_manager._log_memory("dynamic weight manager clear all memory")
-
-    def update_parameters(self, pid):
-        """"dynamic model loader use to update parameters use for RL"""
-        self.dynamic_weight_manager.update_parameters(pid)
-        self.initialize_kv_cache()
-        self.dynamic_weight_manager._log_memory("dynamic weight manager update all memory")
-
     def initialize_kv_cache(self) -> None:
         """
         Initialize kv cache
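To make the refactor above concrete (a minimal sketch assuming a dataclass-style container; the field names are taken from the diff, but the real ForwardMeta in fastdeploy may differ): constructing the meta object with explicit keyword arguments, instead of the removed init_forward_meta factory that read a shared-input dict, makes every tensor the forward pass depends on visible at the call site.

from dataclasses import dataclass
from typing import Any, Optional

@dataclass
class ForwardMetaSketch:
    # Hypothetical stand-in for fastdeploy's ForwardMeta; only a subset of
    # the fields from the diff is reproduced here.
    input_ids: Any
    ids_remove_padding: Any
    rotary_embs: Any
    attn_backend: Any
    seq_lens_encoder: Any
    seq_lens_decoder: Any
    seq_lens_this_time: Any
    caches: Optional[Any] = None
    step_use_cudagraph: bool = False

    def clear_caches(self):
        # Mirrors the clear_caches() call used by clear_cache() in the diff.
        self.caches = None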
@@ -701,11 +695,10 @@ class GPUModelRunner(ModelRunnerBase):

     def initialize_attn_backend(self) -> None:
         """
-        Initialize attention backends and forward metadata
+        Initialize attention backends
         """
         assert len(self.attn_backends) == 0

         # TODO(gongshaotian): Get rank from config
         num_heads = self.model_config.num_attention_heads // self.parallel_config.tensor_parallel_degree
         self.model_config.kv_num_heads = int(
             self.model_config.num_key_value_heads
@@ -718,10 +711,7 @@ class GPUModelRunner(ModelRunnerBase):
             kv_num_heads=self.model_config.kv_num_heads,
             num_heads=num_heads,
             head_dim=head_dim)
         if attn_backend is None:
             raise NotImplementedError(
                 "Attention backend which you specified is not supported, please set FD_ATTENTION_BACKEND correctly."
             )

         self.attn_backends.append(attn_backend)

     def _dummy_run(self,
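As a side note on the num_heads and kv_num_heads computation in the two hunks above (illustrative arithmetic only, with made-up values): under tensor parallelism each rank owns num_attention_heads // tensor_parallel_degree query heads, and the key/value head count is divided the same way.

# Hypothetical model/parallel configuration for illustration.
num_attention_heads = 64
num_key_value_heads = 8            # grouped-query attention
tensor_parallel_degree = 4

num_heads_per_rank = num_attention_heads // tensor_parallel_degree         # 16
kv_num_heads_per_rank = int(num_key_value_heads / tensor_parallel_degree)  # 2

print(num_heads_per_rank, kv_num_heads_per_rank)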
@@ -745,14 +735,12 @@ class GPUModelRunner(ModelRunnerBase):
                 expected_decode_len=expected_decode_len)
         while True:

-            # 1. Compute real num_tokens
+            # 1. Initialize forward meta and attention meta data
             self._prepare_inputs()

-            # 2. Initialize attention backend and forward meta data
-
-            # 3. Prepare lora
+            # 2. Prepare lora

-            # 4. Run model
+            # 3. Run model
             is_decode_batch = not ((self.share_inputs["seq_lens_this_time"]
                                     > 1).sum() > 0)
             self.forward_meta.step_use_cudagraph = is_decode_batch and in_capturing
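For reference, a standalone sketch of the same predicate (using a toy tensor rather than the runner's shared buffers): a batch counts as decode-only when no request processes more than one token this step, and only such batches are replayed through CUDA Graphs while capturing.

import paddle

# Hypothetical per-request token counts for the current step: prefill
# requests process many tokens at once, decode requests exactly one.
seq_lens_this_time = paddle.to_tensor([1, 1, 1, 1], dtype='int32')

# Same expression as in the diff: True only if no request has > 1 token.
is_decode_batch = not ((seq_lens_this_time > 1).sum() > 0)
print(is_decode_batch)  # True -> eligible for CUDA Graph replay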
@@ -773,7 +761,7 @@ class GPUModelRunner(ModelRunnerBase):
                     self.parallel_config.max_model_len,
                 )

-            # 5. Execute spec decode
+            # 4. Execute spec decode
             logits = self.model.compute_logits(hiddden_states)

             if not self.speculative_decoding:
@@ -805,7 +793,7 @@ class GPUModelRunner(ModelRunnerBase):
                     paddle.distributed.broadcast(
                         self.share_inputs["stop_flags"], 0)

-            # 6. post process
+            # 5. post process
             model_output_data = ModelOutputData(
                 next_tokens=self.share_inputs["next_tokens"],
                 stop_flags=self.share_inputs["stop_flags"],
@@ -858,7 +846,7 @@ class GPUModelRunner(ModelRunnerBase):

     def _update_chunked_prefill(self, tasks):
         """
-        更新chunked prefill相关参数
+        Update chunked prefill related parameters
         """
         if not self.parallel_config.enable_chunked_prefill:
             return
@@ -903,13 +891,9 @@ class GPUModelRunner(ModelRunnerBase):
                     self.proposer.update_task_chunk_prefill(task)
                 task.chunk_idx += 1

-    def _dummy_sampler_run(self) -> paddle.Tensor:
-        """ """
-        pass
-
     def capture_model(self) -> None:
         """
-        Trigger CUDA Graph capture for all shapes in 'CudaGraphConfig.cudagraph_capture_sizes'
+        Trigger CUDA Graph capture for all shapes in cuda graph capture list
         """
         if not self.use_cudagraph:
             logger.info(
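For background on the chunked-prefill bookkeeping above (an illustrative sketch, not FastDeploy code; the chunk size and variable names are assumptions): chunked prefill splits a long prompt into fixed-size chunks and advances a per-task chunk index after each scheduled chunk, which is what the task.chunk_idx update maintains.

# Hypothetical illustration of chunked prefill bookkeeping.
prompt_token_ids = list(range(2500))   # a long prompt
chunk_size = 1024

# Split the prompt into prefill chunks up front.
chunks = [prompt_token_ids[i:i + chunk_size]
          for i in range(0, len(prompt_token_ids), chunk_size)]

chunk_idx = 0
while chunk_idx < len(chunks):
    current = chunks[chunk_idx]        # tokens encoded in this step
    # ... one prefill step over `current` would run here ...
    chunk_idx += 1                     # same advance as task.chunk_idx += 1

print(len(chunks), chunk_idx)          # 3 chunks, all consumed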
@@ -933,7 +917,8 @@ class GPUModelRunner(ModelRunnerBase):
             f"Cuda Graph capturing took {time_after_capture - time_before_capture} seconds"
         )

-    def _get_skip_idx(self, model_forward_batch):
+    def _get_skip_idx(self,
+                      model_forward_batch: Optional[List[Request]] = None):
         """
         Get the index of the request that needs to be skipped during execution.
         Args:
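As a rough illustration of the capture loop referenced above (a sketch only; FastDeploy drives capture through _dummy_run, and the warm-up count, size list, and stub function here are assumptions): each batch size in the capture list is warmed up a few times and then recorded once, typically from largest to smallest.

# Illustrative capture loop (not FastDeploy code).
cudagraph_capture_sizes = sorted([1, 2, 4, 8, 16, 32], reverse=True)
cudagraph_num_of_warmups = 3

def run_dummy_batch(batch_size: int, capturing: bool) -> None:
    # Stand-in for _dummy_run(); a real runner would execute the model here.
    pass

for batch_size in cudagraph_capture_sizes:
    for _ in range(cudagraph_num_of_warmups):
        run_dummy_batch(batch_size, capturing=False)   # warm-up passes
    run_dummy_batch(batch_size, capturing=True)        # pass recorded as a graph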
@@ -972,20 +957,19 @@ class GPUModelRunner(ModelRunnerBase):
                 We plan to replace it with 'ModelForwardBatch'.
             intermediate_tensors:
         """
-        # Note(@wufeisheng): If `not_need_stop`` is False, it means the current worker is in an idle state.
+        # NOTE(wufeisheng): If `not_need_stop`` is False, it means the current worker is in an idle state.
         # This logic is not used in TP (Tensor Parallelism) mode. However, in EP (Expert Parallelism) mode,
         # when there is data on other runner, the current runner is required to execute part of the model.
         if not self.not_need_stop():
             self._execute_empty_input()
             return None

-        # 1. Prepare inputs of model and decoder.
-        # sampler create async operation
+        # 1. Prepare inputs of model and sampler.
         skip_idx_list = self._get_skip_idx(model_forward_batch)
         self._prepare_inputs()
         self.sampler.pre_process(skip_idx_list)

-        # 2. Padding inputs for cuda grph
+        # 2. Padding inputs for cuda graph

         # 3. Execute model
         # TODO(gongshaotian): Use seq_lens_encoder to set is_decode_batch
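For orientation on the idle-worker guard above (a schematic sketch; the wrapper function is invented and only the runner methods named in the diff are assumed to exist): in expert-parallel deployments an idle rank may still need to participate in the step, so it runs an empty forward instead of returning with nothing.

# Schematic of the idle-worker guard (wrapper name is illustrative).
def execute_model_step(runner) -> None:
    if not runner.not_need_stop():
        # No local requests, but EP peers may still need this rank to
        # execute part of the model, so run a forward pass on empty input.
        runner._execute_empty_input()
        return
    # ... normal path: prepare inputs, run the model, post-process ...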
@@ -1136,7 +1120,7 @@ class GPUModelRunner(ModelRunnerBase):
                 f"{type(self.model)} has no attribute 'empty_input_forward")

     def profile_run(self) -> None:
-        """Execute a forward pass with dummy inputs to profile the memory usage of the model."""
+        """ Execute a forward pass with dummy inputs to profile the memory usage of the model """

         # Initialize kv cache for profile run. After profile run kv cache will be reset.
         # TODO(gongshaotian): Optimize the management logic of kvcache
@@ -1222,5 +1206,26 @@ class GPUModelRunner(ModelRunnerBase):
         return required_memory

     def not_need_stop(self) -> bool:
-        """ """
+        """ Stop decoding if the tensor meets the termination condition """
         return self.share_inputs["not_need_stop"][0]
+
+    def clear_cache(self):
+        """ Clear cached data from shared inputs and forward metadata """
+        self.share_inputs.pop("caches", None)
+        if self.forward_meta is not None:
+            self.forward_meta.clear_caches()
+
+    def clear_parameters(self, pid):
+        """" Dynamic model loader use to clear parameters use for RL """
+        self.dynamic_weight_manager.clear_parameters(pid)
+        self.clear_cache()
+        paddle.device.cuda.empty_cache()
+        self.dynamic_weight_manager._log_memory(
+            "dynamic weight manager clear all memory")
+
+    def update_parameters(self, pid):
+        """" Dynamic model loader use to update parameters use for RL """
+        self.dynamic_weight_manager.update_parameters(pid)
+        self.initialize_kv_cache()
+        self.dynamic_weight_manager._log_memory(
+            "dynamic weight manager update all memory")
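To show how these hooks are intended to fit together (a usage sketch with an assumed runner instance; the call order is inferred from the methods above, not prescribed by the diff): an RL training loop releases the old weights and KV cache, then installs updated weights and rebuilds the cache before serving resumes.

# Hypothetical RL weight-refresh cycle built from the methods added above.
def refresh_weights(model_runner, pid: int) -> None:
    # Drop old parameters and cached KV blocks, then release GPU memory.
    model_runner.clear_parameters(pid)   # also calls clear_cache() internally

    # ... the RL trainer writes the new checkpoint for `pid` here ...

    # Load the updated parameters and rebuild the KV cache.
    model_runner.update_parameters(pid)  # also calls initialize_kv_cache()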