mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-12-24 13:28:13 +08:00
107
tests/distributed/test_communication.py
Normal file
107
tests/distributed/test_communication.py
Normal file
@@ -0,0 +1,107 @@
|
||||
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import paddle
|
||||
|
||||
from fastdeploy.distributed import communication
|
||||
|
||||
|
||||
class TestCommunicationBasic(unittest.TestCase):
|
||||
def setUp(self):
|
||||
communication._TP_AR = None
|
||||
|
||||
def test_capture_custom_allreduce_no_tp_ar(self):
|
||||
with communication.capture_custom_allreduce():
|
||||
pass
|
||||
|
||||
def test_capture_custom_allreduce_with_tp_ar(self):
|
||||
mock_tp_ar = MagicMock()
|
||||
mock_context = MagicMock()
|
||||
mock_tp_ar.capture.return_value = mock_context
|
||||
communication._TP_AR = mock_tp_ar
|
||||
with communication.capture_custom_allreduce():
|
||||
pass
|
||||
mock_tp_ar.capture.assert_called_once()
|
||||
|
||||
@patch("paddle.distributed.fleet.get_hybrid_communicate_group")
|
||||
@patch("fastdeploy.distributed.custom_all_reduce.CustomAllreduce")
|
||||
def test_use_custom_allreduce(self, mock_custom_ar, mock_get_hcg):
|
||||
mock_hcg = MagicMock()
|
||||
mock_get_hcg.return_value = mock_hcg
|
||||
|
||||
# fake group with required attributes used by CustomAllreduce
|
||||
fake_group = MagicMock()
|
||||
fake_group.rank = 0
|
||||
fake_group.world_size = 2
|
||||
mock_hcg.get_model_parallel_group.return_value = fake_group
|
||||
|
||||
communication.use_custom_allreduce()
|
||||
|
||||
self.assertIsNotNone(communication._TP_AR)
|
||||
mock_custom_ar.assert_called_once_with(fake_group, 8192 * 1024)
|
||||
|
||||
def test_custom_ar_clear_ipc_handles(self):
|
||||
mock_tp_ar = MagicMock()
|
||||
communication._TP_AR = mock_tp_ar
|
||||
communication.custom_ar_clear_ipc_handles()
|
||||
mock_tp_ar.clear_ipc_handles.assert_called_once()
|
||||
|
||||
@patch("fastdeploy.distributed.communication.dist.all_reduce")
|
||||
@patch("paddle.distributed.fleet.get_hybrid_communicate_group")
|
||||
def test_tensor_model_parallel_all_reduce(self, mock_get_hcg, mock_all_reduce):
|
||||
# ensure group exists
|
||||
mock_hcg = MagicMock()
|
||||
mock_get_hcg.return_value = mock_hcg
|
||||
fake_group = MagicMock()
|
||||
fake_group.world_size = 2
|
||||
mock_hcg.get_model_parallel_group.return_value = fake_group
|
||||
|
||||
# make all_reduce callable
|
||||
def fake_all_reduce(x, group=None):
|
||||
return x
|
||||
|
||||
mock_all_reduce.side_effect = fake_all_reduce
|
||||
|
||||
x = paddle.to_tensor([1.0])
|
||||
# call should not raise, ensure all_reduce was invoked
|
||||
_ = communication.tensor_model_parallel_all_reduce(x)
|
||||
mock_all_reduce.assert_called()
|
||||
|
||||
@patch("fastdeploy.distributed.communication.stream.all_reduce")
|
||||
@patch("paddle.distributed.fleet.get_hybrid_communicate_group")
|
||||
def test_tensor_model_parallel_all_reduce_custom(self, mock_get_hcg, mock_stream_ar):
|
||||
# ensure group exists
|
||||
mock_hcg = MagicMock()
|
||||
mock_get_hcg.return_value = mock_hcg
|
||||
fake_group = MagicMock()
|
||||
fake_group.world_size = 2
|
||||
mock_hcg.get_model_parallel_group.return_value = fake_group
|
||||
|
||||
# stream.all_reduce may not return value in source; ensure callable
|
||||
def fake_stream_all_reduce(x, **kwargs):
|
||||
return None
|
||||
|
||||
mock_stream_ar.side_effect = fake_stream_all_reduce
|
||||
|
||||
x = paddle.to_tensor([2.0])
|
||||
# the function does not return input_ in source; ensure call succeeds and stream.all_reduce used
|
||||
_ = communication.tensor_model_parallel_all_reduce_custom(x)
|
||||
mock_stream_ar.assert_called_once()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
96
tests/distributed/test_cuda_wrapper.py
Normal file
96
tests/distributed/test_cuda_wrapper.py
Normal file
@@ -0,0 +1,96 @@
|
||||
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import ctypes
|
||||
import unittest
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from fastdeploy.distributed.custom_all_reduce import cuda_wrapper
|
||||
|
||||
|
||||
class TestCudaRTLibrary(unittest.TestCase):
|
||||
|
||||
@patch("fastdeploy.distributed.custom_all_reduce.cuda_wrapper.find_loaded_library")
|
||||
@patch("ctypes.CDLL")
|
||||
def test_basic_init_and_function_calls(self, mock_cdll, mock_find_lib):
|
||||
"""Test initialization and basic function calls of CudaRTLibrary"""
|
||||
mock_find_lib.return_value = "/usr/local/cuda/lib64/libcudart.so"
|
||||
mock_lib = MagicMock()
|
||||
mock_cdll.return_value = mock_lib
|
||||
|
||||
# Mock all exported functions to return success (0)
|
||||
for func in cuda_wrapper.CudaRTLibrary.exported_functions:
|
||||
setattr(mock_lib, func.name, MagicMock(return_value=0))
|
||||
mock_lib.cudaGetErrorString.return_value = b"no error"
|
||||
|
||||
lib = cuda_wrapper.CudaRTLibrary()
|
||||
ptr = lib.cudaMalloc(64)
|
||||
lib.cudaMemset(ptr, 1, 64)
|
||||
lib.cudaMemcpy(ptr, ptr, 64)
|
||||
lib.cudaFree(ptr)
|
||||
lib.cudaSetDevice(0)
|
||||
lib.cudaDeviceSynchronize()
|
||||
lib.cudaDeviceReset()
|
||||
handle = lib.cudaIpcGetMemHandle(ptr)
|
||||
lib.cudaIpcOpenMemHandle(handle)
|
||||
lib.cudaStreamIsCapturing(ctypes.c_void_p(0))
|
||||
|
||||
self.assertTrue(mock_lib.cudaMalloc.called)
|
||||
self.assertTrue(mock_lib.cudaFree.called)
|
||||
|
||||
@patch("builtins.open", create=True)
|
||||
def test_find_loaded_library_found(self, mock_open):
|
||||
"""Test find_loaded_library returns correct path when library is found"""
|
||||
# Simulate maps file containing libcudart path
|
||||
mock_open.return_value.__enter__.return_value = ["7f... /usr/local/cuda/lib64/libcudart.so.11.0\n"]
|
||||
result = cuda_wrapper.find_loaded_library("libcudart")
|
||||
self.assertIn("libcudart.so.11.0", result)
|
||||
|
||||
def test_find_loaded_library_not_found(self):
|
||||
"""Test find_loaded_library returns None when library is not found"""
|
||||
with patch("builtins.open", unittest.mock.mock_open(read_data="")):
|
||||
path = cuda_wrapper.find_loaded_library("libcudart")
|
||||
self.assertIsNone(path)
|
||||
|
||||
def test_cudart_check_raises_error(self):
|
||||
"""Test CUDART_CHECK raises RuntimeError for non-zero error codes"""
|
||||
lib = MagicMock()
|
||||
lib.cudaGetErrorString.return_value = b"mock error"
|
||||
fake = cuda_wrapper.CudaRTLibrary.__new__(cuda_wrapper.CudaRTLibrary)
|
||||
fake.funcs = {"cudaGetErrorString": lib.cudaGetErrorString}
|
||||
|
||||
with self.assertRaises(RuntimeError):
|
||||
fake.CUDART_CHECK(1) # Non-zero error code triggers exception
|
||||
|
||||
@patch("fastdeploy.distributed.custom_all_reduce.cuda_wrapper.find_loaded_library")
|
||||
@patch("ctypes.CDLL")
|
||||
def test_cache_path_reuse(self, mock_cdll, mock_find_lib):
|
||||
"""Test library path caching is reused between instances"""
|
||||
mock_find_lib.return_value = "/usr/local/cuda/lib64/libcudart.so"
|
||||
mock_lib = MagicMock()
|
||||
mock_cdll.return_value = mock_lib
|
||||
|
||||
for func in cuda_wrapper.CudaRTLibrary.exported_functions:
|
||||
setattr(mock_lib, func.name, MagicMock(return_value=0))
|
||||
mock_lib.cudaGetErrorString.return_value = b"ok"
|
||||
|
||||
first = cuda_wrapper.CudaRTLibrary()
|
||||
second = cuda_wrapper.CudaRTLibrary()
|
||||
|
||||
self.assertIs(first.funcs, second.funcs)
|
||||
self.assertIs(first.lib, second.lib)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
@@ -33,63 +33,63 @@ from fastdeploy.logger.handlers import (
|
||||
|
||||
class TestIntervalRotatingFileHandler(unittest.TestCase):
|
||||
def setUp(self):
|
||||
# 创建临时目录
|
||||
# Create temporary directory
|
||||
self.temp_dir = tempfile.mkdtemp()
|
||||
self.base_filename = os.path.join(self.temp_dir, "test.log")
|
||||
|
||||
def tearDown(self):
|
||||
# 清理临时目录
|
||||
# Clean up temporary directory
|
||||
shutil.rmtree(self.temp_dir, ignore_errors=True)
|
||||
|
||||
def test_initialization(self):
|
||||
"""测试初始化参数校验"""
|
||||
# 测试无效interval
|
||||
"""Test initialization parameter validation"""
|
||||
# Test invalid interval
|
||||
with self.assertRaises(ValueError):
|
||||
handler = IntervalRotatingFileHandler(self.base_filename, interval=7)
|
||||
IntervalRotatingFileHandler(self.base_filename, interval=7)
|
||||
|
||||
# 测试有效初始化
|
||||
# Test valid initialization
|
||||
handler = IntervalRotatingFileHandler(self.base_filename, interval=6, backupDays=3)
|
||||
self.assertEqual(handler.interval, 6)
|
||||
self.assertEqual(handler.backup_days, 3)
|
||||
handler.close()
|
||||
|
||||
def test_file_rotation(self):
|
||||
"""测试日志文件滚动"""
|
||||
"""Test log file rotation mechanism"""
|
||||
handler = IntervalRotatingFileHandler(self.base_filename, interval=6, backupDays=1)
|
||||
|
||||
# 模拟初始状态
|
||||
# Get initial state
|
||||
initial_day = handler.current_day
|
||||
initial_hour = handler.current_hour
|
||||
|
||||
# 首次写入
|
||||
# First log write
|
||||
record = LogRecord("test", 20, "/path", 1, "Test message", [], None)
|
||||
handler.emit(record)
|
||||
|
||||
# 验证文件存在
|
||||
# Verify file existence
|
||||
expected_dir = Path(self.temp_dir) / initial_day
|
||||
expected_file = f"test_{initial_day}-{initial_hour:02d}.log"
|
||||
self.assertTrue((expected_dir / expected_file).exists())
|
||||
|
||||
# 验证符号链接
|
||||
# Verify symlink creation
|
||||
symlink = Path(self.temp_dir) / "current_test.log"
|
||||
self.assertTrue(symlink.is_symlink())
|
||||
|
||||
handler.close()
|
||||
|
||||
def test_time_based_rollover(self):
|
||||
"""测试基于时间的滚动触发"""
|
||||
"""Test time-based rollover triggers"""
|
||||
handler = IntervalRotatingFileHandler(self.base_filename, interval=1, backupDays=1)
|
||||
|
||||
# 强制设置初始时间
|
||||
# Force initial time settings
|
||||
handler.current_day = "2000-01-01"
|
||||
handler.current_hour = 0
|
||||
|
||||
# 测试小时变化触发
|
||||
# Test hour change trigger
|
||||
with unittest.mock.patch.object(handler, "_get_current_day", return_value="2000-01-01"):
|
||||
with unittest.mock.patch.object(handler, "_get_current_hour", return_value=1):
|
||||
self.assertTrue(handler.shouldRollover(None))
|
||||
|
||||
# 测试日期变化触发
|
||||
# Test day change trigger
|
||||
with unittest.mock.patch.object(handler, "_get_current_day", return_value="2000-01-02"):
|
||||
with unittest.mock.patch.object(handler, "_get_current_hour", return_value=0):
|
||||
self.assertTrue(handler.shouldRollover(None))
|
||||
@@ -97,40 +97,38 @@ class TestIntervalRotatingFileHandler(unittest.TestCase):
|
||||
handler.close()
|
||||
|
||||
def test_cleanup_logic(self):
|
||||
"""测试过期文件清理"""
|
||||
# 使用固定测试时间
|
||||
"""Test expired file cleanup mechanism"""
|
||||
# Use fixed test time
|
||||
test_time = datetime(2023, 1, 1, 12, 0)
|
||||
with unittest.mock.patch("time.time", return_value=time.mktime(test_time.timetuple())):
|
||||
handler = IntervalRotatingFileHandler(self.base_filename, interval=1, backupDays=0) # 立即清理
|
||||
handler = IntervalRotatingFileHandler(self.base_filename, interval=1, backupDays=0) # Clean immediately
|
||||
|
||||
# 创建测试目录结构
|
||||
# Create test directory structure
|
||||
old_day = (test_time - timedelta(days=2)).strftime("%Y-%m-%d")
|
||||
old_dir = Path(self.temp_dir) / old_day
|
||||
old_dir.mkdir()
|
||||
|
||||
# 创建测试文件
|
||||
# Create test file
|
||||
old_file = old_dir / f"test_{old_day}-00.log"
|
||||
old_file.write_text("test content")
|
||||
|
||||
# 确保文件时间戳正确
|
||||
# Ensure correct timestamps
|
||||
old_time = time.mktime((test_time - timedelta(days=2)).timetuple())
|
||||
os.utime(str(old_dir), (old_time, old_time))
|
||||
os.utime(str(old_file), (old_time, old_time))
|
||||
|
||||
# 验证文件创建成功
|
||||
# Verify file creation
|
||||
self.assertTrue(old_file.exists())
|
||||
|
||||
# 执行清理
|
||||
# Execute cleanup
|
||||
handler._clean_expired_data()
|
||||
|
||||
# 添加短暂延迟确保文件系统操作完成
|
||||
# Short delay for filesystem operations
|
||||
time.sleep(0.1)
|
||||
|
||||
# 验证清理结果
|
||||
# Verify cleanup result
|
||||
if old_dir.exists():
|
||||
# 调试输出:列出目录内容
|
||||
print(f"Directory contents: {list(old_dir.glob('*'))}")
|
||||
# 尝试强制删除以清理测试环境
|
||||
try:
|
||||
shutil.rmtree(str(old_dir))
|
||||
except Exception as e:
|
||||
@@ -144,7 +142,7 @@ class TestIntervalRotatingFileHandler(unittest.TestCase):
|
||||
handler.close()
|
||||
|
||||
def test_multi_interval(self):
|
||||
"""测试多间隔配置"""
|
||||
"""Test multiple interval configurations"""
|
||||
for interval in [1, 2, 3, 4, 6, 8, 12, 24]:
|
||||
with self.subTest(interval=interval):
|
||||
handler = IntervalRotatingFileHandler(self.base_filename, interval=interval)
|
||||
@@ -154,35 +152,35 @@ class TestIntervalRotatingFileHandler(unittest.TestCase):
|
||||
handler.close()
|
||||
|
||||
def test_utc_mode(self):
|
||||
"""测试UTC时间模式"""
|
||||
"""Test UTC time mode"""
|
||||
handler = IntervalRotatingFileHandler(self.base_filename, utc=True)
|
||||
self.assertTrue(time.strftime("%Y-%m-%d", time.gmtime()).startswith(handler.current_day))
|
||||
handler.close()
|
||||
|
||||
def test_symlink_creation(self):
|
||||
"""测试符号链接创建和更新"""
|
||||
"""Test symlink creation and updates"""
|
||||
handler = IntervalRotatingFileHandler(self.base_filename)
|
||||
symlink = Path(self.temp_dir) / "current_test.log"
|
||||
|
||||
# 获取初始符号链接目标
|
||||
# Get initial symlink target
|
||||
initial_target = os.readlink(str(symlink))
|
||||
|
||||
# 强制触发滚动(模拟时间变化)
|
||||
# Force rollover (simulate time change)
|
||||
with unittest.mock.patch.object(handler, "_get_current_day", return_value="2000-01-01"):
|
||||
with unittest.mock.patch.object(handler, "_get_current_hour", return_value=12):
|
||||
handler.doRollover()
|
||||
|
||||
# 获取新符号链接目标
|
||||
# Get new symlink target
|
||||
new_target = os.readlink(str(symlink))
|
||||
|
||||
# 验证目标已更新
|
||||
# Verify target updated
|
||||
self.assertNotEqual(initial_target, new_target)
|
||||
self.assertIn("2000-01-01/test_2000-01-01-12.log", new_target)
|
||||
handler.close()
|
||||
|
||||
|
||||
class TestDailyRotatingFileHandler(unittest.TestCase):
|
||||
"""测试 DailyRotatingFileHandler"""
|
||||
"""Tests for DailyRotatingFileHandler"""
|
||||
|
||||
def setUp(self):
|
||||
self.temp_dir = tempfile.mkdtemp(prefix="fd_handler_test_")
|
||||
@@ -191,112 +189,112 @@ class TestDailyRotatingFileHandler(unittest.TestCase):
|
||||
shutil.rmtree(self.temp_dir, ignore_errors=True)
|
||||
|
||||
def test_daily_rotation(self):
|
||||
"""测试每天滚动"""
|
||||
"""Test daily log rotation"""
|
||||
log_file = os.path.join(self.temp_dir, "test.log")
|
||||
handler = DailyRotatingFileHandler(log_file, backupCount=3)
|
||||
logger = getLogger("test_daily_rotation")
|
||||
logger.addHandler(handler)
|
||||
logger.setLevel(INFO)
|
||||
|
||||
# 写入第一条日志
|
||||
# Write first log
|
||||
logger.info("Test log message day 1")
|
||||
handler.flush()
|
||||
|
||||
# 模拟时间变化到第二天
|
||||
# Simulate time change to next day
|
||||
with patch.object(handler, "_compute_fn") as mock_compute:
|
||||
tomorrow = (datetime.now() + timedelta(days=1)).strftime("%Y-%m-%d")
|
||||
new_filename = f"test.log.{tomorrow}"
|
||||
mock_compute.return_value = new_filename
|
||||
|
||||
# 手动触发滚动检查和执行
|
||||
# Manually trigger rollover check
|
||||
mock_record = MagicMock()
|
||||
if handler.shouldRollover(mock_record):
|
||||
handler.doRollover()
|
||||
|
||||
# 写入第二条日志
|
||||
# Write second log
|
||||
logger.info("Test log message day 2")
|
||||
handler.flush()
|
||||
handler.close()
|
||||
|
||||
# 验证文件存在
|
||||
# Verify file existence
|
||||
today = datetime.now().strftime("%Y-%m-%d")
|
||||
tomorrow = (datetime.now() + timedelta(days=1)).strftime("%Y-%m-%d")
|
||||
|
||||
# 检查原始文件和带日期的文件
|
||||
base_file = os.path.join(self.temp_dir, "test.log")
|
||||
today_file = os.path.join(self.temp_dir, f"test.log.{today}")
|
||||
tomorrow_file = os.path.join(self.temp_dir, f"test.log.{tomorrow}")
|
||||
|
||||
# 至少应该有一个文件存在
|
||||
# At least one file should exist
|
||||
files_exist = any([os.path.isfile(base_file), os.path.isfile(today_file), os.path.isfile(tomorrow_file)])
|
||||
self.assertTrue(files_exist, f"No log files found in {self.temp_dir}")
|
||||
|
||||
def test_backup_count(self):
|
||||
"""测试备份文件数量限制"""
|
||||
"""Test backup file count limitation"""
|
||||
log_file = os.path.join(self.temp_dir, "test.log")
|
||||
handler = DailyRotatingFileHandler(log_file, backupCount=2)
|
||||
logger = getLogger("test_backup_count")
|
||||
logger.addHandler(handler)
|
||||
logger.setLevel(INFO)
|
||||
|
||||
# 创建多个日期的日志文件
|
||||
# Create log files for multiple dates
|
||||
base_date = datetime.now()
|
||||
|
||||
for i in range(5): # 创建5天的日志
|
||||
for i in range(5): # Create 5 days of logs
|
||||
date_str = (base_date - timedelta(days=i)).strftime("%Y-%m-%d")
|
||||
test_file = os.path.join(self.temp_dir, f"test.log.{date_str}")
|
||||
|
||||
# 直接创建文件
|
||||
# Create file directly
|
||||
with open(test_file, "w") as f:
|
||||
f.write(f"Test log for {date_str}\n")
|
||||
|
||||
# 触发清理
|
||||
# Trigger cleanup
|
||||
handler.delete_expired_files()
|
||||
handler.close()
|
||||
|
||||
# 验证备份文件数量(应该保留最新的2个 + 当前文件)
|
||||
# Verify backup count (should keep latest 2 + current file)
|
||||
log_files = [f for f in os.listdir(self.temp_dir) if f.startswith("test.log.")]
|
||||
print(f"Log files found: {log_files}") # 调试输出
|
||||
print(f"Log files found: {log_files}") # Debug output
|
||||
|
||||
# backupCount=2 意味着应该最多保留2个备份文件
|
||||
self.assertLessEqual(len(log_files), 3) # 2个备份 + 可能的当前文件
|
||||
# backupCount=2 means max 2 backup files should remain
|
||||
self.assertLessEqual(len(log_files), 3) # 2 backups + possible current file
|
||||
|
||||
|
||||
class TestLazyFileHandler(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
# 创建临时目录
|
||||
# Create temporary directory
|
||||
self.tmpdir = tempfile.TemporaryDirectory()
|
||||
self.logfile = Path(self.tmpdir.name) / "test.log"
|
||||
|
||||
def tearDown(self):
|
||||
# 清理临时目录
|
||||
# Clean up temporary directory
|
||||
self.tmpdir.cleanup()
|
||||
|
||||
def test_lazy_initialization_and_write(self):
|
||||
"""Test lazy initialization and log writing"""
|
||||
logger = logging.getLogger("test_lazy")
|
||||
logger.setLevel(logging.DEBUG)
|
||||
|
||||
# 初始化 LazyFileHandler
|
||||
# Initialize LazyFileHandler
|
||||
handler = LazyFileHandler(str(self.logfile), backupCount=3, level=logging.DEBUG)
|
||||
logger.addHandler(handler)
|
||||
|
||||
# 此时 _real_handler 应该还没创建
|
||||
# _real_handler should not be created yet
|
||||
self.assertIsNone(handler._real_handler)
|
||||
|
||||
# 写一条日志
|
||||
# Write a log
|
||||
logger.info("Hello Lazy Handler")
|
||||
|
||||
# 写入后 _real_handler 应该被创建
|
||||
# _real_handler should be created after writing
|
||||
self.assertIsNotNone(handler._real_handler)
|
||||
|
||||
# 日志文件应该存在且内容包含日志信息
|
||||
# Log file should exist with correct content
|
||||
self.assertTrue(self.logfile.exists())
|
||||
with open(self.logfile, "r") as f:
|
||||
content = f.read()
|
||||
self.assertIn("Hello Lazy Handler", content)
|
||||
|
||||
# 关闭 handler
|
||||
# Close handler
|
||||
handler.close()
|
||||
logger.removeHandler(handler)
|
||||
|
||||
|
||||
@@ -30,7 +30,7 @@ class LoggerTests(unittest.TestCase):
|
||||
self.env_patchers = [
|
||||
patch("fastdeploy.envs.FD_LOG_DIR", self.tmp_dir),
|
||||
patch("fastdeploy.envs.FD_DEBUG", 0),
|
||||
patch("fastdeploy.envs.FD_LOG_BACKUP_COUNT", "1"),
|
||||
patch("fastdeploy.envs.FD_LOG_BACKUP_COUNT", 1),
|
||||
]
|
||||
for p in self.env_patchers:
|
||||
p.start()
|
||||
@@ -77,5 +77,54 @@ class LoggerTests(unittest.TestCase):
|
||||
self.assertTrue(legacy_logger.propagate)
|
||||
|
||||
|
||||
class LoggerExtraTests(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.logger = FastDeployLogger()
|
||||
|
||||
def tearDown(self):
|
||||
if hasattr(FastDeployLogger, "_instance"):
|
||||
FastDeployLogger._instance = None
|
||||
if hasattr(FastDeployLogger, "_initialized"):
|
||||
FastDeployLogger._initialized = False
|
||||
|
||||
def test_singleton_behavior(self):
|
||||
"""Ensure multiple instances are same"""
|
||||
a = FastDeployLogger()
|
||||
b = FastDeployLogger()
|
||||
self.assertIs(a, b)
|
||||
|
||||
def test_initialize_only_once(self):
|
||||
"""Ensure _initialize won't re-run if already initialized"""
|
||||
self.logger._initialized = True
|
||||
with patch("fastdeploy.logger.logger.setup_logging") as mock_setup:
|
||||
self.logger._initialize()
|
||||
mock_setup.assert_not_called()
|
||||
|
||||
def test_get_logger_unified_path(self):
|
||||
"""Directly test get_logger unified path"""
|
||||
with patch("fastdeploy.logger.logger.setup_logging") as mock_setup:
|
||||
log = self.logger.get_logger("utils")
|
||||
self.assertTrue(log.name.startswith("fastdeploy."))
|
||||
mock_setup.assert_called_once()
|
||||
|
||||
def test_get_logger_legacy_path(self):
|
||||
"""Test legacy get_logger path"""
|
||||
with patch("fastdeploy.logger.logger.FastDeployLogger._get_legacy_logger") as mock_legacy:
|
||||
self.logger.get_logger("x", "y.log", False, False)
|
||||
mock_legacy.assert_called_once()
|
||||
|
||||
def test_get_legacy_logger_debug_mode(self):
|
||||
"""Ensure debug level is set when FD_DEBUG=1"""
|
||||
with patch("fastdeploy.envs.FD_DEBUG", 1):
|
||||
logger = self.logger._get_legacy_logger("debug_case", "d.log")
|
||||
self.assertEqual(logger.level, logging.DEBUG)
|
||||
|
||||
def test_get_legacy_logger_without_formatter(self):
|
||||
"""Test legacy logger without formatter"""
|
||||
logger = self.logger._get_legacy_logger("nofmt", "n.log", without_formater=True)
|
||||
for h in logger.handlers:
|
||||
self.assertIsNone(h.formatter)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main(verbosity=2)
|
||||
|
||||
@@ -12,7 +12,6 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
@@ -26,13 +25,8 @@ from fastdeploy.logger.setup_logging import setup_logging
|
||||
|
||||
|
||||
class TestSetupLogging(unittest.TestCase):
|
||||
|
||||
# -------------------------------------------------
|
||||
# 夹具:每个测试独占临时目录
|
||||
# -------------------------------------------------
|
||||
def setUp(self):
|
||||
self.temp_dir = tempfile.mkdtemp(prefix="logger_setup_test_")
|
||||
# 统一 patch 环境变量
|
||||
self.patches = [
|
||||
patch("fastdeploy.envs.FD_LOG_DIR", self.temp_dir),
|
||||
patch("fastdeploy.envs.FD_DEBUG", 0),
|
||||
@@ -43,92 +37,91 @@ class TestSetupLogging(unittest.TestCase):
|
||||
def tearDown(self):
|
||||
[p.stop() for p in self.patches]
|
||||
shutil.rmtree(self.temp_dir, ignore_errors=True)
|
||||
# 清理单例标记,避免影响其他测试
|
||||
if hasattr(setup_logging, "_configured"):
|
||||
delattr(setup_logging, "_configured")
|
||||
|
||||
# -------------------------------------------------
|
||||
# 基础:目录自动创建
|
||||
# -------------------------------------------------
|
||||
def test_log_dir_created(self):
|
||||
nested = os.path.join(self.temp_dir, "a", "b", "c")
|
||||
setup_logging(log_dir=nested)
|
||||
self.assertTrue(Path(nested).is_dir())
|
||||
|
||||
# -------------------------------------------------
|
||||
# 默认配置文件:文件 handler 不带颜色
|
||||
# -------------------------------------------------
|
||||
def test_default_config_file_no_ansi(self):
|
||||
setup_logging()
|
||||
def test_default_config_fallback(self):
|
||||
"""Pass a non-existent config_file to trigger default_config"""
|
||||
fake_cfg = os.path.join(self.temp_dir, "no_such_cfg.json")
|
||||
setup_logging(config_file=fake_cfg)
|
||||
logger = logging.getLogger("fastdeploy")
|
||||
logger.error("test ansi")
|
||||
self.assertTrue(logger.handlers)
|
||||
handler_classes = [h.__class__.__name__ for h in logger.handlers]
|
||||
self.assertIn("TimedRotatingFileHandler", handler_classes[0])
|
||||
|
||||
default_file = Path(self.temp_dir) / "default.log"
|
||||
self.assertTrue(default_file.exists())
|
||||
with default_file.open() as f:
|
||||
content = f.read()
|
||||
# 文件中不应出现 ANSI 转义
|
||||
self.assertNotIn("\033[", content)
|
||||
|
||||
# -------------------------------------------------
|
||||
# 调试级别开关
|
||||
# -------------------------------------------------
|
||||
def test_debug_level(self):
|
||||
def test_debug_level_affects_handlers(self):
|
||||
"""FD_DEBUG=1 should force DEBUG level"""
|
||||
with patch("fastdeploy.envs.FD_DEBUG", 1):
|
||||
setup_logging()
|
||||
logger = logging.getLogger("fastdeploy")
|
||||
self.assertEqual(logger.level, logging.DEBUG)
|
||||
# debug 消息应该能落到文件
|
||||
logger.debug("debug msg")
|
||||
default_file = Path(self.temp_dir) / "default.log"
|
||||
self.assertIn("debug msg", default_file.read_text())
|
||||
with patch("logging.config.dictConfig") as mock_cfg:
|
||||
setup_logging()
|
||||
called_config = mock_cfg.call_args[0][0]
|
||||
for handler in called_config["handlers"].values():
|
||||
self.assertIn("formatter", handler)
|
||||
self.assertEqual(called_config["handlers"]["console"]["level"], "DEBUG")
|
||||
|
||||
# -------------------------------------------------
|
||||
# 自定义 JSON 配置文件加载
|
||||
# -------------------------------------------------
|
||||
def test_custom_config_file(self):
|
||||
@patch("logging.config.dictConfig")
|
||||
def test_custom_config_with_dailyrotating_and_debug(self, mock_dict):
|
||||
custom_cfg = {
|
||||
"version": 1,
|
||||
"disable_existing_loggers": False,
|
||||
"formatters": {"plain": {"format": "%(message)s"}},
|
||||
"handlers": {
|
||||
"custom": {
|
||||
"class": "logging.FileHandler",
|
||||
"filename": os.path.join(self.temp_dir, "custom.log"),
|
||||
"daily": {
|
||||
"class": "logging.handlers.DailyRotatingFileHandler",
|
||||
"level": "INFO",
|
||||
"formatter": "plain",
|
||||
}
|
||||
},
|
||||
"loggers": {"fastdeploy": {"handlers": ["custom"], "level": "INFO"}},
|
||||
"loggers": {"fastdeploy": {"handlers": ["daily"], "level": "INFO"}},
|
||||
}
|
||||
cfg_path = Path(self.temp_dir) / "cfg.json"
|
||||
cfg_path.write_text(json.dumps(custom_cfg))
|
||||
|
||||
setup_logging(config_file=str(cfg_path))
|
||||
logger = logging.getLogger("fastdeploy")
|
||||
logger.info("from custom cfg")
|
||||
with patch("fastdeploy.envs.FD_DEBUG", 1):
|
||||
setup_logging(config_file=str(cfg_path))
|
||||
|
||||
custom_file = Path(self.temp_dir) / "custom.log"
|
||||
self.assertEqual(custom_file.read_text().strip(), "from custom cfg")
|
||||
config_used = mock_dict.call_args[0][0]
|
||||
self.assertIn("daily", config_used["handlers"])
|
||||
self.assertEqual(config_used["handlers"]["daily"]["level"], "DEBUG")
|
||||
self.assertIn("backupCount", config_used["handlers"]["daily"])
|
||||
|
||||
# -------------------------------------------------
|
||||
# 重复调用 setup_logging 不会重复配置
|
||||
# -------------------------------------------------
|
||||
def test_configure_once(self):
|
||||
logger1 = setup_logging()
|
||||
logger2 = setup_logging()
|
||||
self.assertIs(logger1, logger2)
|
||||
"""Ensure idempotent setup"""
|
||||
l1 = setup_logging()
|
||||
l2 = setup_logging()
|
||||
self.assertIs(l1, l2)
|
||||
|
||||
def test_envs_priority_used_for_log_dir(self):
|
||||
"""When log_dir=None, should use envs.FD_LOG_DIR"""
|
||||
with patch("fastdeploy.envs.FD_LOG_DIR", self.temp_dir):
|
||||
setup_logging()
|
||||
self.assertTrue(os.path.exists(self.temp_dir))
|
||||
|
||||
# -------------------------------------------------
|
||||
# 控制台 handler 使用 ColoredFormatter
|
||||
# -------------------------------------------------
|
||||
@patch("logging.StreamHandler.emit")
|
||||
def test_console_colored(self, mock_emit):
|
||||
setup_logging()
|
||||
logger = logging.getLogger("fastdeploy")
|
||||
logger.error("color test")
|
||||
# 只要 ColoredFormatter 被实例化即可,简单断言 emit 被调用
|
||||
self.assertTrue(mock_emit.called)
|
||||
|
||||
@patch("logging.config.dictConfig")
|
||||
def test_backup_count_merging(self, mock_dict):
|
||||
custom_cfg = {
|
||||
"version": 1,
|
||||
"handlers": {"daily": {"class": "logging.handlers.DailyRotatingFileHandler", "formatter": "plain"}},
|
||||
"loggers": {"fastdeploy": {"handlers": ["daily"], "level": "INFO"}},
|
||||
}
|
||||
cfg_path = Path(self.temp_dir) / "cfg.json"
|
||||
cfg_path.write_text(json.dumps(custom_cfg))
|
||||
|
||||
setup_logging(config_file=str(cfg_path))
|
||||
|
||||
config_used = mock_dict.call_args[0][0]
|
||||
self.assertEqual(config_used["handlers"]["daily"]["backupCount"], 3)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user