From 2aabaecbc2c8c8c95fd6713fdc820ece7485fae2 Mon Sep 17 00:00:00 2001 From: Echo-Nie <157974576+Echo-Nie@users.noreply.github.com> Date: Wed, 12 Nov 2025 10:43:33 +0800 Subject: [PATCH] [CI] Add five unittest (#4958) * add unittest * Update test_logger.py --- tests/distributed/test_communication.py | 107 +++++++++++++++++++++ tests/distributed/test_cuda_wrapper.py | 96 +++++++++++++++++++ tests/logger/test_handlers.py | 118 ++++++++++++------------ tests/logger/test_logger.py | 51 +++++++++- tests/logger/test_setup_logging.py | 111 +++++++++++----------- 5 files changed, 363 insertions(+), 120 deletions(-) create mode 100644 tests/distributed/test_communication.py create mode 100644 tests/distributed/test_cuda_wrapper.py diff --git a/tests/distributed/test_communication.py b/tests/distributed/test_communication.py new file mode 100644 index 000000000..0d35adc57 --- /dev/null +++ b/tests/distributed/test_communication.py @@ -0,0 +1,107 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +from unittest.mock import MagicMock, patch + +import paddle + +from fastdeploy.distributed import communication + + +class TestCommunicationBasic(unittest.TestCase): + def setUp(self): + communication._TP_AR = None + + def test_capture_custom_allreduce_no_tp_ar(self): + with communication.capture_custom_allreduce(): + pass + + def test_capture_custom_allreduce_with_tp_ar(self): + mock_tp_ar = MagicMock() + mock_context = MagicMock() + mock_tp_ar.capture.return_value = mock_context + communication._TP_AR = mock_tp_ar + with communication.capture_custom_allreduce(): + pass + mock_tp_ar.capture.assert_called_once() + + @patch("paddle.distributed.fleet.get_hybrid_communicate_group") + @patch("fastdeploy.distributed.custom_all_reduce.CustomAllreduce") + def test_use_custom_allreduce(self, mock_custom_ar, mock_get_hcg): + mock_hcg = MagicMock() + mock_get_hcg.return_value = mock_hcg + + # fake group with required attributes used by CustomAllreduce + fake_group = MagicMock() + fake_group.rank = 0 + fake_group.world_size = 2 + mock_hcg.get_model_parallel_group.return_value = fake_group + + communication.use_custom_allreduce() + + self.assertIsNotNone(communication._TP_AR) + mock_custom_ar.assert_called_once_with(fake_group, 8192 * 1024) + + def test_custom_ar_clear_ipc_handles(self): + mock_tp_ar = MagicMock() + communication._TP_AR = mock_tp_ar + communication.custom_ar_clear_ipc_handles() + mock_tp_ar.clear_ipc_handles.assert_called_once() + + @patch("fastdeploy.distributed.communication.dist.all_reduce") + @patch("paddle.distributed.fleet.get_hybrid_communicate_group") + def test_tensor_model_parallel_all_reduce(self, mock_get_hcg, mock_all_reduce): + # ensure group exists + mock_hcg = MagicMock() + mock_get_hcg.return_value = mock_hcg + fake_group = MagicMock() + fake_group.world_size = 2 + mock_hcg.get_model_parallel_group.return_value = fake_group + + # make all_reduce callable + def fake_all_reduce(x, group=None): + return x + + mock_all_reduce.side_effect = fake_all_reduce + + x = paddle.to_tensor([1.0]) + # call should not raise, ensure all_reduce was invoked + _ = communication.tensor_model_parallel_all_reduce(x) + mock_all_reduce.assert_called() + + @patch("fastdeploy.distributed.communication.stream.all_reduce") + @patch("paddle.distributed.fleet.get_hybrid_communicate_group") + def test_tensor_model_parallel_all_reduce_custom(self, mock_get_hcg, mock_stream_ar): + # ensure group exists + mock_hcg = MagicMock() + mock_get_hcg.return_value = mock_hcg + fake_group = MagicMock() + fake_group.world_size = 2 + mock_hcg.get_model_parallel_group.return_value = fake_group + + # stream.all_reduce may not return value in source; ensure callable + def fake_stream_all_reduce(x, **kwargs): + return None + + mock_stream_ar.side_effect = fake_stream_all_reduce + + x = paddle.to_tensor([2.0]) + # the function does not return input_ in source; ensure call succeeds and stream.all_reduce used + _ = communication.tensor_model_parallel_all_reduce_custom(x) + mock_stream_ar.assert_called_once() + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/distributed/test_cuda_wrapper.py b/tests/distributed/test_cuda_wrapper.py new file mode 100644 index 000000000..aa62059ed --- /dev/null +++ b/tests/distributed/test_cuda_wrapper.py @@ -0,0 +1,96 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import ctypes +import unittest +from unittest.mock import MagicMock, patch + +from fastdeploy.distributed.custom_all_reduce import cuda_wrapper + + +class TestCudaRTLibrary(unittest.TestCase): + + @patch("fastdeploy.distributed.custom_all_reduce.cuda_wrapper.find_loaded_library") + @patch("ctypes.CDLL") + def test_basic_init_and_function_calls(self, mock_cdll, mock_find_lib): + """Test initialization and basic function calls of CudaRTLibrary""" + mock_find_lib.return_value = "/usr/local/cuda/lib64/libcudart.so" + mock_lib = MagicMock() + mock_cdll.return_value = mock_lib + + # Mock all exported functions to return success (0) + for func in cuda_wrapper.CudaRTLibrary.exported_functions: + setattr(mock_lib, func.name, MagicMock(return_value=0)) + mock_lib.cudaGetErrorString.return_value = b"no error" + + lib = cuda_wrapper.CudaRTLibrary() + ptr = lib.cudaMalloc(64) + lib.cudaMemset(ptr, 1, 64) + lib.cudaMemcpy(ptr, ptr, 64) + lib.cudaFree(ptr) + lib.cudaSetDevice(0) + lib.cudaDeviceSynchronize() + lib.cudaDeviceReset() + handle = lib.cudaIpcGetMemHandle(ptr) + lib.cudaIpcOpenMemHandle(handle) + lib.cudaStreamIsCapturing(ctypes.c_void_p(0)) + + self.assertTrue(mock_lib.cudaMalloc.called) + self.assertTrue(mock_lib.cudaFree.called) + + @patch("builtins.open", create=True) + def test_find_loaded_library_found(self, mock_open): + """Test find_loaded_library returns correct path when library is found""" + # Simulate maps file containing libcudart path + mock_open.return_value.__enter__.return_value = ["7f... /usr/local/cuda/lib64/libcudart.so.11.0\n"] + result = cuda_wrapper.find_loaded_library("libcudart") + self.assertIn("libcudart.so.11.0", result) + + def test_find_loaded_library_not_found(self): + """Test find_loaded_library returns None when library is not found""" + with patch("builtins.open", unittest.mock.mock_open(read_data="")): + path = cuda_wrapper.find_loaded_library("libcudart") + self.assertIsNone(path) + + def test_cudart_check_raises_error(self): + """Test CUDART_CHECK raises RuntimeError for non-zero error codes""" + lib = MagicMock() + lib.cudaGetErrorString.return_value = b"mock error" + fake = cuda_wrapper.CudaRTLibrary.__new__(cuda_wrapper.CudaRTLibrary) + fake.funcs = {"cudaGetErrorString": lib.cudaGetErrorString} + + with self.assertRaises(RuntimeError): + fake.CUDART_CHECK(1) # Non-zero error code triggers exception + + @patch("fastdeploy.distributed.custom_all_reduce.cuda_wrapper.find_loaded_library") + @patch("ctypes.CDLL") + def test_cache_path_reuse(self, mock_cdll, mock_find_lib): + """Test library path caching is reused between instances""" + mock_find_lib.return_value = "/usr/local/cuda/lib64/libcudart.so" + mock_lib = MagicMock() + mock_cdll.return_value = mock_lib + + for func in cuda_wrapper.CudaRTLibrary.exported_functions: + setattr(mock_lib, func.name, MagicMock(return_value=0)) + mock_lib.cudaGetErrorString.return_value = b"ok" + + first = cuda_wrapper.CudaRTLibrary() + second = cuda_wrapper.CudaRTLibrary() + + self.assertIs(first.funcs, second.funcs) + self.assertIs(first.lib, second.lib) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/logger/test_handlers.py b/tests/logger/test_handlers.py index 3b0d32370..6880de98f 100644 --- a/tests/logger/test_handlers.py +++ b/tests/logger/test_handlers.py @@ -33,63 +33,63 @@ from fastdeploy.logger.handlers import ( class TestIntervalRotatingFileHandler(unittest.TestCase): def setUp(self): - # 创建临时目录 + # Create temporary directory self.temp_dir = tempfile.mkdtemp() self.base_filename = os.path.join(self.temp_dir, "test.log") def tearDown(self): - # 清理临时目录 + # Clean up temporary directory shutil.rmtree(self.temp_dir, ignore_errors=True) def test_initialization(self): - """测试初始化参数校验""" - # 测试无效interval + """Test initialization parameter validation""" + # Test invalid interval with self.assertRaises(ValueError): - handler = IntervalRotatingFileHandler(self.base_filename, interval=7) + IntervalRotatingFileHandler(self.base_filename, interval=7) - # 测试有效初始化 + # Test valid initialization handler = IntervalRotatingFileHandler(self.base_filename, interval=6, backupDays=3) self.assertEqual(handler.interval, 6) self.assertEqual(handler.backup_days, 3) handler.close() def test_file_rotation(self): - """测试日志文件滚动""" + """Test log file rotation mechanism""" handler = IntervalRotatingFileHandler(self.base_filename, interval=6, backupDays=1) - # 模拟初始状态 + # Get initial state initial_day = handler.current_day initial_hour = handler.current_hour - # 首次写入 + # First log write record = LogRecord("test", 20, "/path", 1, "Test message", [], None) handler.emit(record) - # 验证文件存在 + # Verify file existence expected_dir = Path(self.temp_dir) / initial_day expected_file = f"test_{initial_day}-{initial_hour:02d}.log" self.assertTrue((expected_dir / expected_file).exists()) - # 验证符号链接 + # Verify symlink creation symlink = Path(self.temp_dir) / "current_test.log" self.assertTrue(symlink.is_symlink()) handler.close() def test_time_based_rollover(self): - """测试基于时间的滚动触发""" + """Test time-based rollover triggers""" handler = IntervalRotatingFileHandler(self.base_filename, interval=1, backupDays=1) - # 强制设置初始时间 + # Force initial time settings handler.current_day = "2000-01-01" handler.current_hour = 0 - # 测试小时变化触发 + # Test hour change trigger with unittest.mock.patch.object(handler, "_get_current_day", return_value="2000-01-01"): with unittest.mock.patch.object(handler, "_get_current_hour", return_value=1): self.assertTrue(handler.shouldRollover(None)) - # 测试日期变化触发 + # Test day change trigger with unittest.mock.patch.object(handler, "_get_current_day", return_value="2000-01-02"): with unittest.mock.patch.object(handler, "_get_current_hour", return_value=0): self.assertTrue(handler.shouldRollover(None)) @@ -97,40 +97,38 @@ class TestIntervalRotatingFileHandler(unittest.TestCase): handler.close() def test_cleanup_logic(self): - """测试过期文件清理""" - # 使用固定测试时间 + """Test expired file cleanup mechanism""" + # Use fixed test time test_time = datetime(2023, 1, 1, 12, 0) with unittest.mock.patch("time.time", return_value=time.mktime(test_time.timetuple())): - handler = IntervalRotatingFileHandler(self.base_filename, interval=1, backupDays=0) # 立即清理 + handler = IntervalRotatingFileHandler(self.base_filename, interval=1, backupDays=0) # Clean immediately - # 创建测试目录结构 + # Create test directory structure old_day = (test_time - timedelta(days=2)).strftime("%Y-%m-%d") old_dir = Path(self.temp_dir) / old_day old_dir.mkdir() - # 创建测试文件 + # Create test file old_file = old_dir / f"test_{old_day}-00.log" old_file.write_text("test content") - # 确保文件时间戳正确 + # Ensure correct timestamps old_time = time.mktime((test_time - timedelta(days=2)).timetuple()) os.utime(str(old_dir), (old_time, old_time)) os.utime(str(old_file), (old_time, old_time)) - # 验证文件创建成功 + # Verify file creation self.assertTrue(old_file.exists()) - # 执行清理 + # Execute cleanup handler._clean_expired_data() - # 添加短暂延迟确保文件系统操作完成 + # Short delay for filesystem operations time.sleep(0.1) - # 验证清理结果 + # Verify cleanup result if old_dir.exists(): - # 调试输出:列出目录内容 print(f"Directory contents: {list(old_dir.glob('*'))}") - # 尝试强制删除以清理测试环境 try: shutil.rmtree(str(old_dir)) except Exception as e: @@ -144,7 +142,7 @@ class TestIntervalRotatingFileHandler(unittest.TestCase): handler.close() def test_multi_interval(self): - """测试多间隔配置""" + """Test multiple interval configurations""" for interval in [1, 2, 3, 4, 6, 8, 12, 24]: with self.subTest(interval=interval): handler = IntervalRotatingFileHandler(self.base_filename, interval=interval) @@ -154,35 +152,35 @@ class TestIntervalRotatingFileHandler(unittest.TestCase): handler.close() def test_utc_mode(self): - """测试UTC时间模式""" + """Test UTC time mode""" handler = IntervalRotatingFileHandler(self.base_filename, utc=True) self.assertTrue(time.strftime("%Y-%m-%d", time.gmtime()).startswith(handler.current_day)) handler.close() def test_symlink_creation(self): - """测试符号链接创建和更新""" + """Test symlink creation and updates""" handler = IntervalRotatingFileHandler(self.base_filename) symlink = Path(self.temp_dir) / "current_test.log" - # 获取初始符号链接目标 + # Get initial symlink target initial_target = os.readlink(str(symlink)) - # 强制触发滚动(模拟时间变化) + # Force rollover (simulate time change) with unittest.mock.patch.object(handler, "_get_current_day", return_value="2000-01-01"): with unittest.mock.patch.object(handler, "_get_current_hour", return_value=12): handler.doRollover() - # 获取新符号链接目标 + # Get new symlink target new_target = os.readlink(str(symlink)) - # 验证目标已更新 + # Verify target updated self.assertNotEqual(initial_target, new_target) self.assertIn("2000-01-01/test_2000-01-01-12.log", new_target) handler.close() class TestDailyRotatingFileHandler(unittest.TestCase): - """测试 DailyRotatingFileHandler""" + """Tests for DailyRotatingFileHandler""" def setUp(self): self.temp_dir = tempfile.mkdtemp(prefix="fd_handler_test_") @@ -191,112 +189,112 @@ class TestDailyRotatingFileHandler(unittest.TestCase): shutil.rmtree(self.temp_dir, ignore_errors=True) def test_daily_rotation(self): - """测试每天滚动""" + """Test daily log rotation""" log_file = os.path.join(self.temp_dir, "test.log") handler = DailyRotatingFileHandler(log_file, backupCount=3) logger = getLogger("test_daily_rotation") logger.addHandler(handler) logger.setLevel(INFO) - # 写入第一条日志 + # Write first log logger.info("Test log message day 1") handler.flush() - # 模拟时间变化到第二天 + # Simulate time change to next day with patch.object(handler, "_compute_fn") as mock_compute: tomorrow = (datetime.now() + timedelta(days=1)).strftime("%Y-%m-%d") new_filename = f"test.log.{tomorrow}" mock_compute.return_value = new_filename - # 手动触发滚动检查和执行 + # Manually trigger rollover check mock_record = MagicMock() if handler.shouldRollover(mock_record): handler.doRollover() - # 写入第二条日志 + # Write second log logger.info("Test log message day 2") handler.flush() handler.close() - # 验证文件存在 + # Verify file existence today = datetime.now().strftime("%Y-%m-%d") tomorrow = (datetime.now() + timedelta(days=1)).strftime("%Y-%m-%d") - # 检查原始文件和带日期的文件 base_file = os.path.join(self.temp_dir, "test.log") today_file = os.path.join(self.temp_dir, f"test.log.{today}") tomorrow_file = os.path.join(self.temp_dir, f"test.log.{tomorrow}") - # 至少应该有一个文件存在 + # At least one file should exist files_exist = any([os.path.isfile(base_file), os.path.isfile(today_file), os.path.isfile(tomorrow_file)]) self.assertTrue(files_exist, f"No log files found in {self.temp_dir}") def test_backup_count(self): - """测试备份文件数量限制""" + """Test backup file count limitation""" log_file = os.path.join(self.temp_dir, "test.log") handler = DailyRotatingFileHandler(log_file, backupCount=2) logger = getLogger("test_backup_count") logger.addHandler(handler) logger.setLevel(INFO) - # 创建多个日期的日志文件 + # Create log files for multiple dates base_date = datetime.now() - for i in range(5): # 创建5天的日志 + for i in range(5): # Create 5 days of logs date_str = (base_date - timedelta(days=i)).strftime("%Y-%m-%d") test_file = os.path.join(self.temp_dir, f"test.log.{date_str}") - # 直接创建文件 + # Create file directly with open(test_file, "w") as f: f.write(f"Test log for {date_str}\n") - # 触发清理 + # Trigger cleanup handler.delete_expired_files() handler.close() - # 验证备份文件数量(应该保留最新的2个 + 当前文件) + # Verify backup count (should keep latest 2 + current file) log_files = [f for f in os.listdir(self.temp_dir) if f.startswith("test.log.")] - print(f"Log files found: {log_files}") # 调试输出 + print(f"Log files found: {log_files}") # Debug output - # backupCount=2 意味着应该最多保留2个备份文件 - self.assertLessEqual(len(log_files), 3) # 2个备份 + 可能的当前文件 + # backupCount=2 means max 2 backup files should remain + self.assertLessEqual(len(log_files), 3) # 2 backups + possible current file class TestLazyFileHandler(unittest.TestCase): def setUp(self): - # 创建临时目录 + # Create temporary directory self.tmpdir = tempfile.TemporaryDirectory() self.logfile = Path(self.tmpdir.name) / "test.log" def tearDown(self): - # 清理临时目录 + # Clean up temporary directory self.tmpdir.cleanup() def test_lazy_initialization_and_write(self): + """Test lazy initialization and log writing""" logger = logging.getLogger("test_lazy") logger.setLevel(logging.DEBUG) - # 初始化 LazyFileHandler + # Initialize LazyFileHandler handler = LazyFileHandler(str(self.logfile), backupCount=3, level=logging.DEBUG) logger.addHandler(handler) - # 此时 _real_handler 应该还没创建 + # _real_handler should not be created yet self.assertIsNone(handler._real_handler) - # 写一条日志 + # Write a log logger.info("Hello Lazy Handler") - # 写入后 _real_handler 应该被创建 + # _real_handler should be created after writing self.assertIsNotNone(handler._real_handler) - # 日志文件应该存在且内容包含日志信息 + # Log file should exist with correct content self.assertTrue(self.logfile.exists()) with open(self.logfile, "r") as f: content = f.read() self.assertIn("Hello Lazy Handler", content) - # 关闭 handler + # Close handler handler.close() logger.removeHandler(handler) diff --git a/tests/logger/test_logger.py b/tests/logger/test_logger.py index 1cb0f0441..fdb63740f 100644 --- a/tests/logger/test_logger.py +++ b/tests/logger/test_logger.py @@ -30,7 +30,7 @@ class LoggerTests(unittest.TestCase): self.env_patchers = [ patch("fastdeploy.envs.FD_LOG_DIR", self.tmp_dir), patch("fastdeploy.envs.FD_DEBUG", 0), - patch("fastdeploy.envs.FD_LOG_BACKUP_COUNT", "1"), + patch("fastdeploy.envs.FD_LOG_BACKUP_COUNT", 1), ] for p in self.env_patchers: p.start() @@ -77,5 +77,54 @@ class LoggerTests(unittest.TestCase): self.assertTrue(legacy_logger.propagate) +class LoggerExtraTests(unittest.TestCase): + def setUp(self): + self.logger = FastDeployLogger() + + def tearDown(self): + if hasattr(FastDeployLogger, "_instance"): + FastDeployLogger._instance = None + if hasattr(FastDeployLogger, "_initialized"): + FastDeployLogger._initialized = False + + def test_singleton_behavior(self): + """Ensure multiple instances are same""" + a = FastDeployLogger() + b = FastDeployLogger() + self.assertIs(a, b) + + def test_initialize_only_once(self): + """Ensure _initialize won't re-run if already initialized""" + self.logger._initialized = True + with patch("fastdeploy.logger.logger.setup_logging") as mock_setup: + self.logger._initialize() + mock_setup.assert_not_called() + + def test_get_logger_unified_path(self): + """Directly test get_logger unified path""" + with patch("fastdeploy.logger.logger.setup_logging") as mock_setup: + log = self.logger.get_logger("utils") + self.assertTrue(log.name.startswith("fastdeploy.")) + mock_setup.assert_called_once() + + def test_get_logger_legacy_path(self): + """Test legacy get_logger path""" + with patch("fastdeploy.logger.logger.FastDeployLogger._get_legacy_logger") as mock_legacy: + self.logger.get_logger("x", "y.log", False, False) + mock_legacy.assert_called_once() + + def test_get_legacy_logger_debug_mode(self): + """Ensure debug level is set when FD_DEBUG=1""" + with patch("fastdeploy.envs.FD_DEBUG", 1): + logger = self.logger._get_legacy_logger("debug_case", "d.log") + self.assertEqual(logger.level, logging.DEBUG) + + def test_get_legacy_logger_without_formatter(self): + """Test legacy logger without formatter""" + logger = self.logger._get_legacy_logger("nofmt", "n.log", without_formater=True) + for h in logger.handlers: + self.assertIsNone(h.formatter) + + if __name__ == "__main__": unittest.main(verbosity=2) diff --git a/tests/logger/test_setup_logging.py b/tests/logger/test_setup_logging.py index c2fdb0994..e83c25964 100644 --- a/tests/logger/test_setup_logging.py +++ b/tests/logger/test_setup_logging.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. - import json import logging import os @@ -26,13 +25,8 @@ from fastdeploy.logger.setup_logging import setup_logging class TestSetupLogging(unittest.TestCase): - - # ------------------------------------------------- - # 夹具:每个测试独占临时目录 - # ------------------------------------------------- def setUp(self): self.temp_dir = tempfile.mkdtemp(prefix="logger_setup_test_") - # 统一 patch 环境变量 self.patches = [ patch("fastdeploy.envs.FD_LOG_DIR", self.temp_dir), patch("fastdeploy.envs.FD_DEBUG", 0), @@ -43,92 +37,91 @@ class TestSetupLogging(unittest.TestCase): def tearDown(self): [p.stop() for p in self.patches] shutil.rmtree(self.temp_dir, ignore_errors=True) - # 清理单例标记,避免影响其他测试 if hasattr(setup_logging, "_configured"): delattr(setup_logging, "_configured") - # ------------------------------------------------- - # 基础:目录自动创建 - # ------------------------------------------------- def test_log_dir_created(self): nested = os.path.join(self.temp_dir, "a", "b", "c") setup_logging(log_dir=nested) self.assertTrue(Path(nested).is_dir()) - # ------------------------------------------------- - # 默认配置文件:文件 handler 不带颜色 - # ------------------------------------------------- - def test_default_config_file_no_ansi(self): - setup_logging() + def test_default_config_fallback(self): + """Pass a non-existent config_file to trigger default_config""" + fake_cfg = os.path.join(self.temp_dir, "no_such_cfg.json") + setup_logging(config_file=fake_cfg) logger = logging.getLogger("fastdeploy") - logger.error("test ansi") + self.assertTrue(logger.handlers) + handler_classes = [h.__class__.__name__ for h in logger.handlers] + self.assertIn("TimedRotatingFileHandler", handler_classes[0]) - default_file = Path(self.temp_dir) / "default.log" - self.assertTrue(default_file.exists()) - with default_file.open() as f: - content = f.read() - # 文件中不应出现 ANSI 转义 - self.assertNotIn("\033[", content) - - # ------------------------------------------------- - # 调试级别开关 - # ------------------------------------------------- - def test_debug_level(self): + def test_debug_level_affects_handlers(self): + """FD_DEBUG=1 should force DEBUG level""" with patch("fastdeploy.envs.FD_DEBUG", 1): - setup_logging() - logger = logging.getLogger("fastdeploy") - self.assertEqual(logger.level, logging.DEBUG) - # debug 消息应该能落到文件 - logger.debug("debug msg") - default_file = Path(self.temp_dir) / "default.log" - self.assertIn("debug msg", default_file.read_text()) + with patch("logging.config.dictConfig") as mock_cfg: + setup_logging() + called_config = mock_cfg.call_args[0][0] + for handler in called_config["handlers"].values(): + self.assertIn("formatter", handler) + self.assertEqual(called_config["handlers"]["console"]["level"], "DEBUG") - # ------------------------------------------------- - # 自定义 JSON 配置文件加载 - # ------------------------------------------------- - def test_custom_config_file(self): + @patch("logging.config.dictConfig") + def test_custom_config_with_dailyrotating_and_debug(self, mock_dict): custom_cfg = { "version": 1, - "disable_existing_loggers": False, - "formatters": {"plain": {"format": "%(message)s"}}, "handlers": { - "custom": { - "class": "logging.FileHandler", - "filename": os.path.join(self.temp_dir, "custom.log"), + "daily": { + "class": "logging.handlers.DailyRotatingFileHandler", + "level": "INFO", "formatter": "plain", } }, - "loggers": {"fastdeploy": {"handlers": ["custom"], "level": "INFO"}}, + "loggers": {"fastdeploy": {"handlers": ["daily"], "level": "INFO"}}, } cfg_path = Path(self.temp_dir) / "cfg.json" cfg_path.write_text(json.dumps(custom_cfg)) - setup_logging(config_file=str(cfg_path)) - logger = logging.getLogger("fastdeploy") - logger.info("from custom cfg") + with patch("fastdeploy.envs.FD_DEBUG", 1): + setup_logging(config_file=str(cfg_path)) - custom_file = Path(self.temp_dir) / "custom.log" - self.assertEqual(custom_file.read_text().strip(), "from custom cfg") + config_used = mock_dict.call_args[0][0] + self.assertIn("daily", config_used["handlers"]) + self.assertEqual(config_used["handlers"]["daily"]["level"], "DEBUG") + self.assertIn("backupCount", config_used["handlers"]["daily"]) - # ------------------------------------------------- - # 重复调用 setup_logging 不会重复配置 - # ------------------------------------------------- def test_configure_once(self): - logger1 = setup_logging() - logger2 = setup_logging() - self.assertIs(logger1, logger2) + """Ensure idempotent setup""" + l1 = setup_logging() + l2 = setup_logging() + self.assertIs(l1, l2) + + def test_envs_priority_used_for_log_dir(self): + """When log_dir=None, should use envs.FD_LOG_DIR""" + with patch("fastdeploy.envs.FD_LOG_DIR", self.temp_dir): + setup_logging() + self.assertTrue(os.path.exists(self.temp_dir)) - # ------------------------------------------------- - # 控制台 handler 使用 ColoredFormatter - # ------------------------------------------------- @patch("logging.StreamHandler.emit") def test_console_colored(self, mock_emit): setup_logging() logger = logging.getLogger("fastdeploy") logger.error("color test") - # 只要 ColoredFormatter 被实例化即可,简单断言 emit 被调用 self.assertTrue(mock_emit.called) + @patch("logging.config.dictConfig") + def test_backup_count_merging(self, mock_dict): + custom_cfg = { + "version": 1, + "handlers": {"daily": {"class": "logging.handlers.DailyRotatingFileHandler", "formatter": "plain"}}, + "loggers": {"fastdeploy": {"handlers": ["daily"], "level": "INFO"}}, + } + cfg_path = Path(self.temp_dir) / "cfg.json" + cfg_path.write_text(json.dumps(custom_cfg)) + + setup_logging(config_file=str(cfg_path)) + + config_used = mock_dict.call_args[0][0] + self.assertEqual(config_used["handlers"]["daily"]["backupCount"], 3) + if __name__ == "__main__": unittest.main()