Mirror of https://github.com/PaddlePaddle/FastDeploy.git (synced 2025-10-05 08:37:06 +08:00)
[Sync] Update to latest code (#2679)
* [Sync] Update to latest code
* Add new code files
* Add new code files
* update code
* Try to fix build.sh
* Try to fix build.sh
* Update code
* Update requirements.txt
* Update code

Co-authored-by: Jiang-Jia-Jun <jiangjiajun@baidu.com>
@@ -20,6 +20,7 @@ from typing import List, Optional
import numpy as np
import paddle
import paddle.nn as nn
from paddleformers.utils.log import logger

from fastdeploy.config import FDConfig
from fastdeploy.engine.request import Request
@@ -41,13 +42,10 @@ from fastdeploy.model_executor.pre_and_post_process import (post_process,
                                                            rebuild_padding,
                                                            step_cuda)
from fastdeploy.spec_decode import MTPProposer, NgramProposer
from fastdeploy.utils import get_logger
from fastdeploy.worker.forward_meta import ForwardMeta
from fastdeploy.worker.model_runner_base import ModelRunnerBase
from fastdeploy.worker.output import ModelOutputData, ModelRunnerOutput

# Note: rebinds the paddleformers `logger` imported above with a module-specific file logger.
logger = get_logger("gpu_model_runner", "gpu_model_runner.log")


class GPUModelRunner(ModelRunnerBase):
    """GPU model runner: loads the model and executes forward passes on CUDA devices."""
@@ -593,6 +591,10 @@ class GPUModelRunner(ModelRunnerBase):
        time_before_load = time.perf_counter()
        # 1. Load the original model
        self.model = get_model_from_loader(fd_config=self.fd_config)
        # 1.1 Load the RL dynamic-weight model
        if self.fd_config.load_config.dynamic_load_weight:
            from fastdeploy.rl.dynamic_weight_manager import DynamicWeightManager
            self.dynamic_weight_manager = DynamicWeightManager(self.fd_config, self.model)

        # 2. Load the LoRA model
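For context, dynamic weight loading lets an RL trainer push fresh weights into a running inference worker without re-instantiating the model. Below is a minimal sketch of the idea for a Paddle model; `SimpleDynamicWeightManager` is a hypothetical stand-in for illustration, not FastDeploy's actual `DynamicWeightManager`:

import paddle

class SimpleDynamicWeightManager:
    """Hypothetical stand-in for fastdeploy.rl.dynamic_weight_manager.DynamicWeightManager."""

    def __init__(self, model: paddle.nn.Layer):
        self.model = model

    def update_parameters(self, state_dict):
        # Copy trainer weights into the live inference parameters in place,
        # so existing references to the model object stay valid.
        for name, param in self.model.named_parameters():
            if name in state_dict:
                param.set_value(state_dict[name])

    def log_memory(self, tag):
        # Report current GPU allocation after an update/clear step.
        allocated = paddle.device.cuda.memory_allocated() / 1024 ** 2
        print(f"{tag}: {allocated:.1f} MiB allocated")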
@@ -620,6 +622,25 @@ class GPUModelRunner(ModelRunnerBase):
        # Initialize attention metadata
        for attn_backend in self.attn_backends:
            attn_backend.init_attention_metadata(self.forward_meta)

    def clear_cache(self):
        """Clear cached data from shared inputs and forward metadata."""
        self.share_inputs.pop("caches", None)
        if self.forward_meta is not None:
            self.forward_meta.clear_caches()

    def clear_parameters(self, pid):
        """Used by the dynamic model loader to clear parameters for RL."""
        self.dynamic_weight_manager.clear_parameters(pid)
        self.clear_cache()
        paddle.device.cuda.empty_cache()
        self.dynamic_weight_manager._log_memory("dynamic weight manager clear all memory")

    def update_parameters(self, pid):
        """Used by the dynamic model loader to update parameters for RL."""
        self.dynamic_weight_manager.update_parameters(pid)
        self.initialize_kv_cache()
        self.dynamic_weight_manager._log_memory("dynamic weight manager update all memory")

    def initialize_kv_cache(self) -> None:
        """
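Read together, the two new methods form a clear/rebuild cycle for RL weight synchronization: clearing drops both the parameters and the KV cache before the trainer reuses the GPU memory, and updating restores the weights and re-allocates the cache. A sketch of how a caller might drive the cycle; `runner` and `pid` are placeholders for a live `GPUModelRunner` and the trainer's process id:

def sync_weights_for_rl(runner, pid):
    # 1. Release stale weights and KV caches so the trainer can use the GPU memory.
    runner.clear_parameters(pid)
    # ... trainer trains and publishes new weights under `pid` ...
    # 2. Pull in the fresh weights and rebuild the KV cache for serving.
    runner.update_parameters(pid)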
@@ -691,15 +712,14 @@ class GPUModelRunner(ModelRunnerBase):
             head_dim = self.model_config.head_dim

             # Get the attention backend
-            attn_cls = get_attention_backend(
-                self.parallel_config.attention_backend)
+            attn_cls = get_attention_backend()
             attn_backend = attn_cls(self.fd_config,
                                     kv_num_heads=self.model_config.kv_num_heads,
                                     num_heads=num_heads,
                                     head_dim=head_dim)
             if attn_backend is None:
                 raise NotImplementedError(
-                    f"{self.parallel_config.attention_backend} attention backend is not supported by GPUModelRunner"
+                    "The chosen attention backend is not supported by GPUModelRunner"
                 )
             self.attn_backends.append(attn_backend)
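This hunk drops the explicit backend argument: `get_attention_backend()` now resolves the backend itself instead of receiving `self.parallel_config.attention_backend` from the caller. A minimal sketch of the registry pattern such a resolver typically sits on; the names below are illustrative assumptions, not FastDeploy's actual implementation:

_ATTN_BACKENDS = {}
_DEFAULT_BACKEND = "append_attn"  # assumption: some globally configured default

def register_attention_backend(name):
    """Class decorator that records an attention backend under a name."""
    def deco(cls):
        _ATTN_BACKENDS[name] = cls
        return cls
    return deco

def get_attention_backend(name=None):
    # With no explicit name, fall back to the configured default;
    # returns None for unknown names, which the caller turns into an error.
    if name is None:
        name = _DEFAULT_BACKEND
    return _ATTN_BACKENDS.get(name)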
@@ -735,6 +755,7 @@ class GPUModelRunner(ModelRunnerBase):
            is_decode_batch = not ((self.share_inputs["seq_lens_this_time"] > 1).sum() > 0)
            self.forward_meta.step_use_cudagraph = is_decode_batch and in_capturing
            self.forward_meta.is_decode_batch = is_decode_batch
            model_output = self.model(
                ids_remove_padding=self.share_inputs["ids_remove_padding"],
                forward_meta=self.forward_meta)
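The `is_decode_batch` predicate asks whether any sequence consumes more than one token this step: a pure-decode batch (every request advancing exactly one token) has stable shapes, which is what makes it safe to pair with CUDA graph capture. A small self-contained illustration:

import paddle

# All three requests decode a single token: a pure-decode, graph-friendly step.
seq_lens_this_time = paddle.to_tensor([1, 1, 1])
print(not bool(((seq_lens_this_time > 1).sum() > 0)))  # True

# One request is prefilling 128 tokens, so the batch is no longer pure decode.
seq_lens_this_time = paddle.to_tensor([128, 1, 1])
print(not bool(((seq_lens_this_time > 1).sum() > 0)))  # False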
@@ -967,6 +988,7 @@ class GPUModelRunner(ModelRunnerBase):
            is_decode_batch = not ((self.share_inputs["seq_lens_this_time"] > 1).sum() > 0)
            self.forward_meta.step_use_cudagraph = self.use_cudagraph and is_decode_batch
            self.forward_meta.is_decode_batch = is_decode_batch
            model_output = self.model(
                ids_remove_padding=self.share_inputs["ids_remove_padding"],
                forward_meta=self.forward_meta)
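Note the asymmetry between the two call sites: the capture path above gates on `in_capturing`, while this execution path gates on `self.use_cudagraph`. An illustrative helper (not part of FastDeploy) that makes the two predicates explicit:

def should_use_cudagraph(is_decode_batch: bool, in_capturing: bool,
                         use_cudagraph: bool, capturing_path: bool) -> bool:
    if capturing_path:
        # Warm-up/capture call site: only capture pure-decode batches.
        return is_decode_batch and in_capturing
    # Regular execution call site: only replay when CUDA graphs are enabled.
    return use_cudagraph and is_decode_batch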
@@ -1124,9 +1146,7 @@ class GPUModelRunner(ModelRunnerBase):
             batch_size=min(self.parallel_config.max_num_seqs, 3))

         # 3. gc
-        del self.share_inputs["caches"]
-        if self.forward_meta is not None:
-            del self.forward_meta.caches
+        self.clear_cache()

         if self.speculative_method in ["mtp"]:
             self.proposer.clear_dummy_input()
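One likely motivation for routing this cleanup through `clear_cache()`: `dict.pop("caches", None)` is idempotent, while the old `del self.share_inputs["caches"]` raises `KeyError` if the caches were never allocated or were already cleared. A quick demonstration:

share_inputs = {}  # "caches" was never allocated

try:
    del share_inputs["caches"]          # old cleanup path
except KeyError:
    print("del raises when the key is absent")

share_inputs.pop("caches", None)        # new path: safe to call repeatedly
share_inputs.pop("caches", None)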