Mirror of https://github.com/PaddlePaddle/FastDeploy.git (synced 2025-10-04 08:16:42 +08:00)
[SOT] Extend SOT warmup support to new hardware (#3032)
* add new hardware
* add_sot_warmup4new_hardware
* fix conflict
* rm Optional
@@ -26,6 +26,10 @@ from fastdeploy import envs
 from fastdeploy.config import FDConfig
 from fastdeploy.engine.request import Request, RequestType
 from fastdeploy.model_executor.forward_meta import ForwardMeta, XPUForwardMeta
+from fastdeploy.model_executor.graph_optimization.utils import (
+    profile_run_guard,
+    sot_warmup_guard,
+)
 from fastdeploy.model_executor.layers.attention import get_attention_backend
 from fastdeploy.model_executor.layers.attention.base_attention_backend import (
     AttentionBackend,
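The only new imports are the two decorators from graph_optimization.utils that bracket profiling and warmup. Their implementation is not part of this diff; the sketch below is one plausible reading of what such a guard does, assuming a decorator factory that flips a process-wide "warming up" switch around the wrapped method. The flag name and body are assumptions, not FastDeploy's actual code.

# Hypothetical sketch of a guard decorator in the spirit of sot_warmup_guard;
# the real one lives in fastdeploy/model_executor/graph_optimization/utils.py
# and is not shown in this commit.
import functools

_IN_SOT_WARMUP = False  # assumed module-level flag other code could consult

def sot_warmup_guard(enabled: bool):
    """Flip the warmup flag around the wrapped method when enabled."""
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            global _IN_SOT_WARMUP
            if not enabled:
                return func(*args, **kwargs)
            _IN_SOT_WARMUP = True
            try:
                return func(*args, **kwargs)
            finally:
                _IN_SOT_WARMUP = False
        return wrapper
    return decorator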
@@ -346,7 +350,9 @@ class XPUModelRunner(ModelRunnerBase):
         # self.kv_caches: list[paddle.Tensor] = []

         # Cuda Graph
+        self.graph_opt_level = self.graph_opt_config.graph_opt_level
         self.use_cudagraph = False
+        self.sot_warmup_sizes = self.graph_opt_config.sot_warmup_sizes
         self.input_ids = paddle.zeros(self.parallel_config.max_num_seqs, dtype="int32")

         # Initialize share inputs
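Both new attributes are read off the runner's graph_opt_config. As an illustration of the shape of that config, two plain attributes are all this hunk requires; the field names mirror the diff, while the defaults and level semantics below are assumptions, not FastDeploy's actual definition.

# Illustrative stand-in for the GraphOptimizationConfig fields used above.
from dataclasses import dataclass, field

@dataclass
class GraphOptimizationConfig:
    # assumed semantics: 0 = plain dynamic graph, 1+ = static capture via SOT
    graph_opt_level: int = 0
    # batch sizes to trace ahead of time; the values here are invented examples
    sot_warmup_sizes: list = field(default_factory=lambda: [1, 8, 32])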
@@ -764,6 +770,17 @@ class XPUModelRunner(ModelRunnerBase):
             logger.warn("XPU not support cuda graph currently")
             pass

+    @sot_warmup_guard(True)
+    def sot_warmup(self) -> None:
+        start_time = time.perf_counter()
+        for batch_size in self.sot_warmup_sizes:
+            self._dummy_run(
+                num_tokens=self.parallel_config.max_num_batched_tokens,
+                batch_size=batch_size,
+            )
+            logger.info(f"SOT warmup the model with the batch size:{batch_size}")
+        logger.info(f"SOT warmup took {time.perf_counter() - start_time} seconds")
+
     def exist_prefill(self):
         """
         check whether prefill stage exist
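The new sot_warmup method runs one dummy forward pass per configured batch size, so dynamic-to-static tracing happens before real requests arrive; note that the token budget stays fixed at max_num_batched_tokens while only the batch size varies. The toy below shows why per-size warmup matters, assuming the mechanism underneath is Paddle's dynamic-to-static converter (paddle.jit.to_static); this is background, not code from the commit.

# Toy demonstration: each new input shape can trigger a fresh trace in
# Paddle's dynamic-to-static machinery, which is why every entry in
# sot_warmup_sizes gets its own dummy run.
import paddle

@paddle.jit.to_static
def forward(x):
    return paddle.nn.functional.relu(x) * 2

for batch_size in (1, 8, 32):        # stand-in for self.sot_warmup_sizes
    x = paddle.zeros([batch_size, 16])
    forward(x)                        # first call per shape compiles; later calls reuse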
@@ -901,6 +918,7 @@ class XPUModelRunner(ModelRunnerBase):
         self.num_gpu_blocks = self.parallel_config.total_block_num
         self.initialize_kv_cache()

+    @profile_run_guard(True)
     def profile_run(self) -> None:
         """Execute a forward pass with dummy inputs to profile the memory usage of the model."""