[SOT] Add sot warmup (NVIDIA GPU Only) (#2929)
* add sot warmup
* fix code style
* change batch_size list
* add param to config
* rm free_list settings && set sot_warmup_sizes
* finish debug with dynamic dims by type annotations
* add profile_run guard
* rm sth useless
@@ -319,6 +319,8 @@ class GraphOptimizationConfig:
     - With dynamic graph backend: ...
     - With static graph backend: WIP
     """
+    sot_warmup_sizes: Optional[list[int]] = field(default_factory=list)
+    """ Batch sizes used for SOT warmup runs. """
     use_cudagraph: bool = False
     """Sizes to capture cudagraph.
     - None (default): capture sizes are inferred from llm config.
@@ -429,6 +429,7 @@ class GraphOptimizationConfig:
         graph_opt_level: Optional[int] = 0,
         use_cudagraph: Optional[bool] = None,
         cudagraph_capture_sizes: Optional[List[int]] = None,
+        sot_warmup_sizes: Optional[List[int]] = None,
         **kwargs,
     ):
         """
@@ -444,6 +445,7 @@ class GraphOptimizationConfig:
         self.graph_opt_level = graph_opt_level
         self.use_cudagraph = use_cudagraph
         self.cudagraph_capture_sizes = cudagraph_capture_sizes
+        self.sot_warmup_sizes = [] if sot_warmup_sizes is None else sot_warmup_sizes

     def to_json_string(self):
         """
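For orientation, here is a minimal, self-contained sketch of the field pattern added above; the class and values are illustrative stand-ins, not the actual FastDeploy GraphOptimizationConfig. `sot_warmup_sizes` uses `default_factory=list`, so leaving it unset yields an empty list and the SOT warmup loop simply runs zero iterations.

from dataclasses import dataclass, field
from typing import Optional


@dataclass
class GraphOptSketch:  # illustrative stand-in, not fastdeploy.config.GraphOptimizationConfig
    # Batch sizes to feed through the SOT warmup loop; an empty list means no warmup.
    # default_factory avoids sharing one mutable list across instances.
    sot_warmup_sizes: Optional[list[int]] = field(default_factory=list)
    use_cudagraph: bool = False
    cudagraph_capture_sizes: Optional[list[int]] = None


print(GraphOptSketch(sot_warmup_sizes=[1, 8, 32]).sot_warmup_sizes)  # [1, 8, 32]
print(GraphOptSketch().sot_warmup_sizes)                             # [] -> no warmup runs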
@@ -31,26 +31,15 @@ from fastdeploy.model_executor.graph_optimization.cudagraph_piecewise_backend im
 from fastdeploy.model_executor.graph_optimization.dynamic_dims_marker import (
     resolve_dynamic_dims,
 )
+from fastdeploy.model_executor.graph_optimization.utils import in_profile_run_mode
+from fastdeploy.model_executor.graph_optimization.utils import (
+    in_sot_warmup_mode as in_warmup_mode,
+)

 P = ParamSpec("P")
 T = TypeVar("T")


-# TODO(SigureMo): Replace this fn with real implementation by DrRyanHuang
-def create_in_warmup_mode():
-    cnt = 0
-
-    def in_warmup_mode():
-        nonlocal cnt
-        cnt += 1
-        return cnt < 32
-
-    return in_warmup_mode
-
-
-in_warmup_mode = create_in_warmup_mode()
-
-
 def apply_to_static_optimization(fn: Callable[P, T], backend: ToStaticBackend) -> Callable[P, T]:
     forward_fn = fn
     forward_sig = inspect.signature(forward_fn)
@@ -99,6 +88,8 @@ def apply_to_static_optimization(fn: Callable[P, T], backend: ToStaticBackend) -

    @functools.wraps(forward_fn)
    def static_forward(self, *args, **kwargs):
+        if in_profile_run_mode():
+            return forward_fn(self, *args, **kwargs)
        nonlocal need_warmup
        is_warmup = in_warmup_mode() and need_warmup
        if is_warmup:
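The two added lines make `static_forward` fall back to the plain dynamic-graph `forward_fn` whenever a profile run is in progress, so memory profiling never triggers SOT warmup or static-graph conversion. Below is a rough standalone sketch of that bypass pattern, importing the flag helpers from the new utils module; it assumes the fastdeploy package is importable, and the wrapper and function names are illustrative rather than the FastDeploy API.

import functools

from fastdeploy.model_executor.graph_optimization.utils import (
    in_profile_run_mode,
    profile_run_guard,
)


def skip_static_during_profile(fn):
    """Illustrative wrapper: use the plain eager call while profiling."""

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        if in_profile_run_mode():
            return fn(*args, **kwargs)  # eager path, no SOT warmup
        # ... to-static dispatch would normally happen here ...
        return fn(*args, **kwargs)

    return wrapper


@skip_static_during_profile
def forward(x):
    return x * 2


@profile_run_guard(True)  # the flag is only set while this function runs
def profile_run():
    return forward(21)


print(profile_run())  # 42, computed via the eager branch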
fastdeploy/model_executor/graph_optimization/utils.py (new file, 40 lines)
@@ -0,0 +1,40 @@
+"""
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+
+import contextlib
+
+
+def create_guard(default_value):
+    _state = default_value
+
+    @contextlib.contextmanager
+    def state_guard(current_state):
+        nonlocal _state
+        old_state = _state
+        _state = current_state
+        try:
+            yield
+        finally:
+            _state = old_state
+
+    def get_state():
+        return _state
+
+    return state_guard, get_state
+
+
+sot_warmup_guard, in_sot_warmup_mode = create_guard(False)
+profile_run_guard, in_profile_run_mode = create_guard(False)
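The closure returned by `create_guard` pairs a context-manager factory with a getter over a single flag, so callers can scope a mode change with a with-block and query it from anywhere in the call stack. A quick usage sketch, assuming the fastdeploy package is importable:

from fastdeploy.model_executor.graph_optimization.utils import (
    in_profile_run_mode,
    profile_run_guard,
)

print(in_profile_run_mode())      # False: the guard defaults to off
with profile_run_guard(True):
    print(in_profile_run_mode())  # True only inside the with-block
print(in_profile_run_mode())      # False again, restored by the finally clause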
@@ -25,6 +25,10 @@ from paddleformers.utils.log import logger

 from fastdeploy.config import FDConfig
 from fastdeploy.engine.request import Request
+from fastdeploy.model_executor.graph_optimization.utils import (
+    profile_run_guard,
+    sot_warmup_guard,
+)
 from fastdeploy.model_executor.guided_decoding import get_guided_backend
 from fastdeploy.model_executor.guided_decoding.base_guided_decoding import (
     LogitsProcessorBase,
@@ -113,8 +117,10 @@ class GPUModelRunner(ModelRunnerBase):
         # self.kv_caches: list[paddle.Tensor] = []

         # Cuda Graph
+        self.graph_opt_level = self.graph_opt_config.graph_opt_level
         self.use_cudagraph = self.graph_opt_config.use_cudagraph
         self.cudagraph_capture_sizes = list(reversed(self.graph_opt_config.cudagraph_capture_sizes))
+        self.sot_warmup_sizes = self.graph_opt_config.sot_warmup_sizes

         # Initialize share inputs
         self._init_share_inputs(self.parallel_config.max_num_seqs)
@@ -367,9 +373,6 @@ class GPUModelRunner(ModelRunnerBase):
     def _dummy_prefill_inputs(self, num_tokens: int, batch_size: int, expected_decode_len: int):
         """Set dummy prefill inputs to share_inputs"""
         # NOTE(gongshaotian): The maximum decoding length is equal to the expected decoded tokens plus the eos token
-        if self.enable_mm:
-            self.share_inputs["free_list"] = paddle.to_tensor([], dtype="int32")
-            self.share_inputs["free_list_len"][0] = 0
         max_dec_len = expected_decode_len + 1
         full_length = min(
             num_tokens // batch_size,
@@ -1007,6 +1010,17 @@ class GPUModelRunner(ModelRunnerBase):
         time_after_capture = time.perf_counter()
         logger.info(f"Cuda Graph capturing took {time_after_capture - time_before_capture} seconds")

+    @sot_warmup_guard(True)
+    def sot_warmup(self) -> None:
+        start_time = time.perf_counter()
+        for batch_size in self.sot_warmup_sizes:
+            self._dummy_run(
+                num_tokens=self.parallel_config.max_num_batched_tokens,
+                batch_size=batch_size,
+            )
+            logger.info(f"SOT warmup the model with the batch size:{batch_size}")
+        logger.info(f"SOT warmup took {time.perf_counter() - start_time} seconds")
+
     def _get_skip_idx(self, model_forward_batch: Optional[List[Request]] = None):
         """
         Get the index of the request that needs to be skipped during execution.
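Because `state_guard` is produced by `contextlib.contextmanager`, the object returned by `sot_warmup_guard(True)` also works as a decorator (it is a ContextDecorator): each call to the decorated method turns the flag on for the duration of the call and restores it afterwards, which is how `@sot_warmup_guard(True)` wraps `sot_warmup` above and `@profile_run_guard(True)` wraps `profile_run` below. A small sketch of that behavior, assuming the fastdeploy package is importable; the class here is an illustrative stand-in, not GPUModelRunner.

from fastdeploy.model_executor.graph_optimization.utils import (
    in_sot_warmup_mode,
    sot_warmup_guard,
)


class RunnerSketch:  # illustrative stand-in for GPUModelRunner
    @sot_warmup_guard(True)
    def sot_warmup(self):
        # Anything invoked from here (e.g. the dummy runs) sees warmup mode as on.
        return in_sot_warmup_mode()


print(RunnerSketch().sot_warmup())  # True while the decorated method runs
print(in_sot_warmup_mode())         # False once the call has returned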
@@ -1208,6 +1222,7 @@ class GPUModelRunner(ModelRunnerBase):
         else:
             raise ValueError(f"{type(self.model)} has no attribute 'empty_input_forward'")

+    @profile_run_guard(True)
     def profile_run(self) -> None:
         """Execute a forward pass with dummy inputs to profile the memory usage of the model"""
@@ -189,6 +189,8 @@ class GpuWorker(WorkerBase):
         """
         Perform the warm-up and the graph optimization
         """
+        if self.model_runner.graph_opt_level >= 1:
+            self.model_runner.sot_warmup()
         # Trigger cuda graph capture
         self.model_runner.capture_model()
@@ -632,6 +632,7 @@ def initialize_fd_config(args, ranks: int = 1, local_rank: int = 0) -> FDConfig:
         use_cudagraph=args.graph_optimization_config["use_cudagraph"],
         graph_opt_level=args.graph_optimization_config["graph_opt_level"],
         cudagraph_capture_sizes=args.graph_optimization_config["cudagraph_capture_sizes"],
+        sot_warmup_sizes=args.graph_optimization_config["sot_warmup_sizes"],
     )

     # Note(tangbinhan): used for load_checkpoint