mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-05 16:48:03 +08:00
polish code with new pre-commit rule (#2923)
This commit is contained in:
@@ -13,13 +13,14 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
|
||||
import gc
|
||||
import time
|
||||
from typing import List, Optional
|
||||
|
||||
import paddle
|
||||
import paddle.nn as nn
|
||||
import pynvml
|
||||
from paddle import nn
|
||||
|
||||
from fastdeploy.config import FDConfig
|
||||
from fastdeploy.engine.request import Request
|
||||
@@ -33,7 +34,6 @@ logger = get_logger("gpu_worker", "gpu_worker.log")
|
||||
|
||||
|
||||
class GpuWorker(WorkerBase):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
fd_config: FDConfig,
|
||||
@@ -52,8 +52,7 @@ class GpuWorker(WorkerBase):
|
||||
Initialize device and construct model runner
|
||||
"""
|
||||
self.max_chips_per_node = 16 if current_platform.is_iluvatar() else 8
|
||||
if self.device_config.device_type == "cuda" and paddle.device.is_compiled_with_cuda(
|
||||
):
|
||||
if self.device_config.device_type == "cuda" and paddle.device.is_compiled_with_cuda():
|
||||
# Set evironment variable
|
||||
self.device_ids = self.parallel_config.device_ids.split(",")
|
||||
self.device = f"gpu:{self.local_rank % self.max_chips_per_node}"
|
||||
@@ -63,12 +62,11 @@ class GpuWorker(WorkerBase):
|
||||
gc.collect()
|
||||
paddle.device.cuda.empty_cache()
|
||||
if self.parallel_config.enable_custom_all_reduce:
|
||||
from fastdeploy.distributed.communication_op import \
|
||||
use_custom_allreduce
|
||||
from fastdeploy.distributed.communication_op import use_custom_allreduce
|
||||
|
||||
use_custom_allreduce()
|
||||
else:
|
||||
raise RuntimeError(
|
||||
f"Not support device type: {self.device_config.device}")
|
||||
raise RuntimeError(f"Not support device type: {self.device_config.device}")
|
||||
|
||||
# Construct model runner
|
||||
self.model_runner: GPUModelRunner = GPUModelRunner(
|
||||
@@ -76,7 +74,8 @@ class GpuWorker(WorkerBase):
|
||||
device=self.device,
|
||||
device_id=self.device_ids[self.local_rank % self.max_chips_per_node],
|
||||
rank=self.rank,
|
||||
local_rank=self.local_rank)
|
||||
local_rank=self.local_rank,
|
||||
)
|
||||
|
||||
def prefill_finished(self):
|
||||
"""
|
||||
@@ -102,33 +101,30 @@ class GpuWorker(WorkerBase):
|
||||
Gb = 1024**3
|
||||
paddle.device.cuda.reset_max_memory_reserved(self.local_rank)
|
||||
paddle.device.cuda.reset_max_memory_allocated(self.local_rank)
|
||||
paddle_reserved_mem_before_run = paddle.device.cuda.max_memory_reserved(
|
||||
self.local_rank)
|
||||
paddle_allocated_mem_before_run = paddle.device.cuda.max_memory_allocated(
|
||||
self.local_rank) # not reserved
|
||||
paddle_reserved_mem_before_run = paddle.device.cuda.max_memory_reserved(self.local_rank)
|
||||
paddle_allocated_mem_before_run = paddle.device.cuda.max_memory_allocated(self.local_rank) # not reserved
|
||||
|
||||
pynvml.nvmlInit()
|
||||
handle = pynvml.nvmlDeviceGetHandleByIndex(
|
||||
int(self.device_ids[self.local_rank]))
|
||||
handle = pynvml.nvmlDeviceGetHandleByIndex(int(self.device_ids[self.local_rank]))
|
||||
before_run_meminfo = pynvml.nvmlDeviceGetMemoryInfo(handle)
|
||||
|
||||
logger.info((
|
||||
"Before running the profile, the memory usage info is as follows:",
|
||||
f"\nDevice Total memory: {before_run_meminfo.total / Gb}",
|
||||
f"\nDevice used memory: {before_run_meminfo.used / Gb}",
|
||||
f"\nDevice free memory: {before_run_meminfo.free / Gb}",
|
||||
f"\nPaddle reserved memory: {paddle_reserved_mem_before_run / Gb}",
|
||||
f"\nPaddle allocated memory: {paddle_allocated_mem_before_run / Gb}"
|
||||
))
|
||||
logger.info(
|
||||
(
|
||||
"Before running the profile, the memory usage info is as follows:",
|
||||
f"\nDevice Total memory: {before_run_meminfo.total / Gb}",
|
||||
f"\nDevice used memory: {before_run_meminfo.used / Gb}",
|
||||
f"\nDevice free memory: {before_run_meminfo.free / Gb}",
|
||||
f"\nPaddle reserved memory: {paddle_reserved_mem_before_run / Gb}",
|
||||
f"\nPaddle allocated memory: {paddle_allocated_mem_before_run / Gb}",
|
||||
)
|
||||
)
|
||||
|
||||
# 2. Profile run
|
||||
self.model_runner.profile_run()
|
||||
|
||||
# 3. Statistical memory information
|
||||
paddle_reserved_mem_after_run = paddle.device.cuda.max_memory_reserved(
|
||||
self.local_rank)
|
||||
paddle_allocated_mem_after_run = paddle.device.cuda.max_memory_allocated(
|
||||
self.local_rank)
|
||||
paddle_reserved_mem_after_run = paddle.device.cuda.max_memory_reserved(self.local_rank)
|
||||
paddle_allocated_mem_after_run = paddle.device.cuda.max_memory_allocated(self.local_rank)
|
||||
|
||||
model_block_memory_used = self.cal_theortical_kvcache()
|
||||
paddle_peak_increase = paddle_reserved_mem_after_run - paddle_allocated_mem_before_run
|
||||
@@ -138,34 +134,39 @@ class GpuWorker(WorkerBase):
|
||||
after_run_meminfo = pynvml.nvmlDeviceGetMemoryInfo(handle)
|
||||
pynvml.nvmlShutdown()
|
||||
|
||||
available_kv_cache_memory = after_run_meminfo.total * \
|
||||
self.parallel_config.gpu_memory_utilization - after_run_meminfo.used - paddle_peak_increase
|
||||
available_kv_cache_memory = (
|
||||
after_run_meminfo.total * self.parallel_config.gpu_memory_utilization
|
||||
- after_run_meminfo.used
|
||||
- paddle_peak_increase
|
||||
)
|
||||
available_kv_cache_memory += model_block_memory_used * self.parallel_config.total_block_num
|
||||
|
||||
end_time = time.perf_counter()
|
||||
logger.info((
|
||||
"After running the profile, the memory usage info is as follows:",
|
||||
f"\nDevice Total memory: {after_run_meminfo.total / Gb}",
|
||||
f"\nDevice used memory: {after_run_meminfo.used / Gb}",
|
||||
f"\nDevice free memory: {after_run_meminfo.free / Gb}",
|
||||
f"\nPaddle reserved memory: {paddle_reserved_mem_after_run / Gb}",
|
||||
f"\nPaddle allocated memory: {paddle_allocated_mem_after_run / Gb}",
|
||||
f"\nAvailable KV Cache meomory: {available_kv_cache_memory / Gb}",
|
||||
f"Profile time: {end_time - start_time}"))
|
||||
logger.info(
|
||||
(
|
||||
"After running the profile, the memory usage info is as follows:",
|
||||
f"\nDevice Total memory: {after_run_meminfo.total / Gb}",
|
||||
f"\nDevice used memory: {after_run_meminfo.used / Gb}",
|
||||
f"\nDevice free memory: {after_run_meminfo.free / Gb}",
|
||||
f"\nPaddle reserved memory: {paddle_reserved_mem_after_run / Gb}",
|
||||
f"\nPaddle allocated memory: {paddle_allocated_mem_after_run / Gb}",
|
||||
f"\nAvailable KV Cache meomory: {available_kv_cache_memory / Gb}",
|
||||
f"Profile time: {end_time - start_time}",
|
||||
)
|
||||
)
|
||||
|
||||
return available_kv_cache_memory # return to caculate the block num in this device
|
||||
|
||||
def load_model(self) -> None:
|
||||
""" Load model """
|
||||
"""Load model"""
|
||||
self.model_runner.load_model()
|
||||
|
||||
def get_model(self) -> nn.Layer:
|
||||
""" Get current model """
|
||||
"""Get current model"""
|
||||
return self.model_runner.get_model()
|
||||
|
||||
def initialize_cache(self, num_gpu_blocks: int,
|
||||
num_cpu_blocks: int) -> None:
|
||||
""" Initizlize the KV Cache """
|
||||
def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks: int) -> None:
|
||||
"""Initizlize the KV Cache"""
|
||||
pass
|
||||
|
||||
def execute_model(
|
||||
@@ -177,7 +178,7 @@ class GpuWorker(WorkerBase):
|
||||
return output
|
||||
|
||||
def preprocess_new_task(self, req_dicts: List[Request]) -> None:
|
||||
""" Process new requests and then start the decode loop
|
||||
"""Process new requests and then start the decode loop
|
||||
TODO(gongshaotian):The scheduler should schedule the handling of prefill,
|
||||
and workers and modelrunners should not perceive it.
|
||||
"""
|
||||
@@ -195,10 +196,9 @@ class GpuWorker(WorkerBase):
|
||||
return True
|
||||
|
||||
def cal_theortical_kvcache(self) -> int:
|
||||
""" Calculate the block memory required """
|
||||
"""Calculate the block memory required"""
|
||||
return self.model_runner.cal_theortical_kvcache()
|
||||
|
||||
def reinitialize_kv_cache(self, num_gpu_blocks: int) -> None:
|
||||
""" Reinitialize the kv cache using the parameters from the profile """
|
||||
self.model_runner.update_share_input_block_num(
|
||||
num_gpu_blocks=num_gpu_blocks)
|
||||
"""Reinitialize the kv cache using the parameters from the profile"""
|
||||
self.model_runner.update_share_input_block_num(num_gpu_blocks=num_gpu_blocks)
|
||||
|
Reference in New Issue
Block a user