diff --git a/fastdeploy/engine/async_llm.py b/fastdeploy/engine/async_llm.py
index 46a701f14..b60029019 100644
--- a/fastdeploy/engine/async_llm.py
+++ b/fastdeploy/engine/async_llm.py
@@ -801,7 +801,6 @@ class AsyncLLMEngine:
             f" --tensor_parallel_size {self.cfg.parallel_config.tensor_parallel_size}"
             f" --engine_worker_queue_port {ports}"
             f" --pod_ip {self.cfg.master_ip}"
-            f" --total_block_num {self.cfg.cache_config.total_block_num}"
             f" --block_size {self.cfg.cache_config.block_size}"
             f" --enc_dec_block_num {self.cfg.cache_config.enc_dec_block_num}"
             f" --eos_tokens_lens {self.data_processor.eos_token_id_len}"
@@ -833,7 +832,7 @@ class AsyncLLMEngine:
             f" --logprobs_mode {self.cfg.model_config.logprobs_mode}"
         )
 
-        worker_append_flag = {
+        worker_store_true_flag = {
             "enable_expert_parallel": self.cfg.parallel_config.enable_expert_parallel,
             "enable_prefix_caching": self.cfg.cache_config.enable_prefix_caching,
             "enable_chunked_prefill": self.cfg.cache_config.enable_chunked_prefill,
@@ -844,9 +843,17 @@ class AsyncLLMEngine:
             "enable_logprob": self.cfg.model_config.enable_logprob,
             "lm_head_fp32": self.cfg.model_config.lm_head_fp32,
         }
-        for worker_flag, value in worker_append_flag.items():
+        for worker_flag, value in worker_store_true_flag.items():
             if value:
                 arguments = arguments + f" --{worker_flag}"
+
+        worker_default_none_flag = {
+            "num_gpu_blocks_override": self.cfg.cache_config.num_gpu_blocks_override,
+        }
+        for worker_flag, value in worker_default_none_flag.items():
+            if value:
+                arguments = arguments + f" --{worker_flag} {value}"
+
         if self.cfg.nnode > 1:
             pd_cmd = pd_cmd + f" --ips {ips} --nnodes {len(self.cfg.ips)}"
         pd_cmd = pd_cmd + arguments + f" 2>{log_dir}/launch_worker.log"
diff --git a/tests/engine/test_async_llm.py b/tests/engine/test_async_llm.py
index f8cc50cd1..19a2ad3ea 100644
--- a/tests/engine/test_async_llm.py
+++ b/tests/engine/test_async_llm.py
@@ -64,7 +64,7 @@ class TestAsyncLLMEngine(unittest.TestCase):
 
         except Exception as e:
             print(f"Setting up AsyncLLMEngine failed: {e}")
-            raise unittest.SkipTest(f"AsyncLLMEngine initialization failed: {e}")
+            raise
 
     @classmethod
     def tearDownClass(cls):