mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-09-26 20:41:53 +08:00
Add GPU memory utilization warning for values >= 0.95
Co-authored-by: Jiang-Jia-Jun <163579578+Jiang-Jia-Jun@users.noreply.github.com>
This commit is contained in:
@@ -31,7 +31,7 @@ from fastdeploy.model_executor.layers.quantization.quant_base import QuantConfig
|
|||||||
from fastdeploy.multimodal.registry import MultimodalRegistry
|
from fastdeploy.multimodal.registry import MultimodalRegistry
|
||||||
from fastdeploy.platforms import current_platform
|
from fastdeploy.platforms import current_platform
|
||||||
from fastdeploy.scheduler import SchedulerConfig
|
from fastdeploy.scheduler import SchedulerConfig
|
||||||
from fastdeploy.utils import ceil_div, check_unified_ckpt, get_host_ip, get_logger
|
from fastdeploy.utils import ceil_div, check_unified_ckpt, get_host_ip, get_logger, console_logger
|
||||||
|
|
||||||
logger = get_logger("config", "config.log")
|
logger = get_logger("config", "config.log")
|
||||||
|
|
||||||
@@ -967,6 +967,12 @@ class CacheConfig:
|
|||||||
def _verify_args(self):
|
def _verify_args(self):
|
||||||
if self.gpu_memory_utilization > 1.0:
|
if self.gpu_memory_utilization > 1.0:
|
||||||
raise ValueError("GPU memory utilization must be less than 1.0. Got " f"{self.gpu_memory_utilization}.")
|
raise ValueError("GPU memory utilization must be less than 1.0. Got " f"{self.gpu_memory_utilization}.")
|
||||||
|
if self.gpu_memory_utilization >= 0.95:
|
||||||
|
console_logger.warning(
|
||||||
|
f"GPU memory utilization is set to {self.gpu_memory_utilization}, which is >= 0.95. "
|
||||||
|
"This may cause out-of-memory (OOM) issues during inference due to GPU memory fluctuations. "
|
||||||
|
"It is recommended to configure gpu_memory_utilization below 0.9 for stable operation."
|
||||||
|
)
|
||||||
if self.kv_cache_ratio > 1.0:
|
if self.kv_cache_ratio > 1.0:
|
||||||
raise ValueError("KV cache ratio must be less than 1.0. Got " f"{self.kv_cache_ratio}.")
|
raise ValueError("KV cache ratio must be less than 1.0. Got " f"{self.kv_cache_ratio}.")
|
||||||
|
|
||||||
|
@@ -1,4 +1,6 @@
|
|||||||
import unittest
|
import unittest
|
||||||
|
import logging
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
from fastdeploy import envs
|
from fastdeploy import envs
|
||||||
from fastdeploy.config import (
|
from fastdeploy.config import (
|
||||||
@@ -79,6 +81,38 @@ class TestConfig(unittest.TestCase):
|
|||||||
fd_config.init_cache_info()
|
fd_config.init_cache_info()
|
||||||
assert fd_config.disaggregate_info["role"] == "prefill"
|
assert fd_config.disaggregate_info["role"] == "prefill"
|
||||||
|
|
||||||
|
def test_gpu_memory_utilization_warning(self):
|
||||||
|
"""Test that a warning is issued when gpu_memory_utilization >= 0.95"""
|
||||||
|
with patch('fastdeploy.utils.console_logger') as mock_logger:
|
||||||
|
# Test case 1: gpu_memory_utilization = 0.95 should trigger warning
|
||||||
|
cache_config = CacheConfig({"gpu_memory_utilization": 0.95})
|
||||||
|
mock_logger.warning.assert_called_once()
|
||||||
|
warning_call = mock_logger.warning.call_args[0][0]
|
||||||
|
self.assertIn("0.95", warning_call)
|
||||||
|
self.assertIn("out-of-memory", warning_call)
|
||||||
|
self.assertIn("below 0.9", warning_call)
|
||||||
|
|
||||||
|
# Reset mock
|
||||||
|
mock_logger.reset_mock()
|
||||||
|
|
||||||
|
# Test case 2: gpu_memory_utilization = 0.99 should trigger warning
|
||||||
|
cache_config = CacheConfig({"gpu_memory_utilization": 0.99})
|
||||||
|
mock_logger.warning.assert_called_once()
|
||||||
|
|
||||||
|
# Reset mock
|
||||||
|
mock_logger.reset_mock()
|
||||||
|
|
||||||
|
# Test case 3: gpu_memory_utilization = 0.9 should NOT trigger warning
|
||||||
|
cache_config = CacheConfig({"gpu_memory_utilization": 0.9})
|
||||||
|
mock_logger.warning.assert_not_called()
|
||||||
|
|
||||||
|
# Reset mock
|
||||||
|
mock_logger.reset_mock()
|
||||||
|
|
||||||
|
# Test case 4: gpu_memory_utilization = 0.8 should NOT trigger warning
|
||||||
|
cache_config = CacheConfig({"gpu_memory_utilization": 0.8})
|
||||||
|
mock_logger.warning.assert_not_called()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
Reference in New Issue
Block a user