diff --git a/fastdeploy/engine/args_utils.py b/fastdeploy/engine/args_utils.py
index 4f0ee57cb..2d1532ae5 100644
--- a/fastdeploy/engine/args_utils.py
+++ b/fastdeploy/engine/args_utils.py
@@ -14,6 +14,7 @@
 # limitations under the License.
 """
 
+import argparse
 import json
 from dataclasses import asdict, dataclass
 from dataclasses import fields as dataclass_fields
@@ -190,7 +191,7 @@ class EngineArgs:
     """
     Flag to indicate whether to use warm-up before inference.
     """
-    enable_prefix_caching: bool = False
+    enable_prefix_caching: bool = True
     """
     Flag to enable prefix caching.
     """
@@ -387,6 +388,16 @@ class EngineArgs:
         """
         if not self.tokenizer:
             self.tokenizer = self.model
+        if self.splitwise_role == "decode":
+            self.enable_prefix_caching = False
+        if self.speculative_config is not None:
+            self.enable_prefix_caching = False
+        if self.enable_mm:
+            self.enable_prefix_caching = False
+        if not current_platform.is_cuda():
+            self.enable_prefix_caching = False
+        if self.dynamic_load_weight:
+            self.enable_prefix_caching = False
         if self.enable_logprob:
             if self.speculative_config is not None:
                 raise NotImplementedError("Logprob does not support speculation_config.")
@@ -725,7 +736,7 @@ class EngineArgs:
         perf_group = parser.add_argument_group("Performance Tuning")
         perf_group.add_argument(
             "--enable-prefix-caching",
-            action="store_true",
+            action=argparse.BooleanOptionalAction,
             default=EngineArgs.enable_prefix_caching,
             help="Flag to enable prefix caching.",
         )
diff --git a/fastdeploy/engine/engine.py b/fastdeploy/engine/engine.py
index 9109cc7b6..623dec456 100644
--- a/fastdeploy/engine/engine.py
+++ b/fastdeploy/engine/engine.py
@@ -369,7 +369,8 @@ class LLMEngine:
             for p in self.cache_manager_processes:
                 llm_logger.info(f"Killing cache manager process {p.pid}")
                 try:
-                    os.killpg(p.pid, signal.SIGTERM)
+                    pgid = os.getpgid(p.pid)
+                    os.killpg(pgid, signal.SIGTERM)
                 except Exception as e:
                     console_logger.error(
                         f"Error killing cache manager process {p.pid}: {e}, {str(traceback.format_exc())}"
@@ -381,7 +382,8 @@
             self.get_profile_block_num_signal.clear()
         if hasattr(self, "worker_proc") and self.worker_proc is not None:
             try:
-                os.killpg(self.worker_proc.pid, signal.SIGTERM)
+                pgid = os.getpgid(self.worker_proc.pid)
+                os.killpg(pgid, signal.SIGTERM)
             except Exception as e:
                 console_logger.error(f"Error extracting sub services: {e}, {str(traceback.format_exc())}")
 
diff --git a/scripts/coverage_run.sh b/scripts/coverage_run.sh
index 23550ac29..da539e97a 100644
--- a/scripts/coverage_run.sh
+++ b/scripts/coverage_run.sh
@@ -32,6 +32,7 @@ for file in $TEST_FILES; do
     else
         success_pytest=$((success_pytest+1))
     fi
+    ps -ef | grep "${FD_CACHE_QUEUE_PORT}" | grep -v grep | awk '{print $2}' | xargs -r kill -9
 done
 
 ##################################
diff --git a/tests/model_loader/test_w4a8_model.py b/tests/model_loader/test_w4a8_model.py
index 3521cca3d..7b55c7aac 100644
--- a/tests/model_loader/test_w4a8_model.py
+++ b/tests/model_loader/test_w4a8_model.py
@@ -29,6 +29,7 @@ FD_ENGINE_QUEUE_PORTS = [
     [9991, 9992],
 ]
 FD_CACHE_QUEUE_PORT = int(os.getenv("FD_CACHE_QUEUE_PORT", 8333))
+FD_CACHE_QUEUE_PORTS = [FD_CACHE_QUEUE_PORT, FD_CACHE_QUEUE_PORT + 1, FD_CACHE_QUEUE_PORT + 2, FD_CACHE_QUEUE_PORT + 3]
 
 
 models = [
@@ -54,7 +55,7 @@ def llm(request):
         max_model_len=8192,
         num_gpu_blocks_override=1024,
         engine_worker_queue_port=FD_ENGINE_QUEUE_PORTS[port_index],
-        cache_queue_port=FD_CACHE_QUEUE_PORT,
+        cache_queue_port=FD_CACHE_QUEUE_PORTS[port_index],
         load_choices="default",
         enable_expert_parallel=True,
     )
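Note on the args_utils.py change: with the default flipped to True, the old action="store_true" flag could only ever set the value to True, leaving no command-line way to turn prefix caching off. argparse.BooleanOptionalAction (Python 3.9+) registers a paired negative switch instead. A minimal standalone sketch of the flag semantics, assuming only the stdlib behavior shown here, not FastDeploy's full parser wiring:

    import argparse

    # Standalone sketch of the new flag behavior (not the project's parser setup).
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--enable-prefix-caching",
        action=argparse.BooleanOptionalAction,  # also registers --no-enable-prefix-caching
        default=True,                           # mirrors the new EngineArgs default
        help="Flag to enable prefix caching.",
    )

    print(parser.parse_args([]))                              # Namespace(enable_prefix_caching=True)
    print(parser.parse_args(["--no-enable-prefix-caching"]))  # Namespace(enable_prefix_caching=False)

Callers that previously relied on prefix caching being off must now pass --no-enable-prefix-caching explicitly, or fall into one of the automatic opt-outs added above (decode splitwise role, speculative decoding, multimodal input, non-CUDA platforms, dynamic weight loading).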
diff --git a/tests/v1/test_schedule_output.py b/tests/v1/test_schedule_output.py
index ffe9432c3..d99a4ce76 100644
--- a/tests/v1/test_schedule_output.py
+++ b/tests/v1/test_schedule_output.py
@@ -30,8 +30,8 @@ def test_normal_schedule():
         max_num_seqs=max_num_seqs, config=fd_config, tensor_parallel_size=8, splitwise_role="mixed"
     )
     req1 = Request.from_dict({"request_id": "req1", "prompt_token_ids": [1] * 3199, "prompt_token_ids_len": 3199})
-    req2 = Request.from_dict({"request_id": "req2", "prompt_token_ids": [1] * 3201, "prompt_token_ids_len": 3201})
-    req3 = Request.from_dict({"request_id": "req3", "prompt_token_ids": [1] * 3200, "prompt_token_ids_len": 3200})
+    req2 = Request.from_dict({"request_id": "req2", "prompt_token_ids": [2] * 3201, "prompt_token_ids_len": 3201})
+    req3 = Request.from_dict({"request_id": "req3", "prompt_token_ids": [3] * 3200, "prompt_token_ids_len": 3200})
     resource_manager_v1.add_request(req1)
     resource_manager_v1.add_request(req2)
     resource_manager_v1.add_request(req3)
@@ -93,7 +93,7 @@ def test_preempted_request():
         max_num_seqs=max_num_seqs, config=fd_config, tensor_parallel_size=8, splitwise_role="mixed"
     )
     req1 = Request.from_dict({"request_id": "req1", "prompt_token_ids": [1] * 3200, "prompt_token_ids_len": 3200})
-    req2 = Request.from_dict({"request_id": "req2", "prompt_token_ids": [1] * 3200, "prompt_token_ids_len": 3200})
+    req2 = Request.from_dict({"request_id": "req2", "prompt_token_ids": [2] * 3200, "prompt_token_ids_len": 3200})
     resource_manager_v1.add_request(req1)
     resource_manager_v1.add_request(req2)
     # step 1
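Note on the scheduler-test change: with prefix caching now on by default, prompts built from the same repeated token id share every full-block prefix, so req2/req3 could reuse blocks already cached for req1 and shift the block-allocation counts these tests assert; giving each request a distinct token value presumably keeps their allocations independent. A hypothetical illustration of block-hash keying (BLOCK_SIZE and prefix_block_keys are made up for this sketch, not FastDeploy's implementation):

    import hashlib
    from typing import List

    BLOCK_SIZE = 64  # hypothetical block size; the engine's real value may differ


    def prefix_block_keys(token_ids: List[int]) -> List[str]:
        # Key each full block by a hash of the block plus everything before it,
        # the usual keying scheme for a block-level prefix cache (illustrative only).
        keys, h = [], hashlib.sha256()
        for start in range(0, len(token_ids) - len(token_ids) % BLOCK_SIZE, BLOCK_SIZE):
            h.update(str(token_ids[start:start + BLOCK_SIZE]).encode())
            keys.append(h.hexdigest())
        return keys


    req1 = [1] * 3199  # same repeated token as req2 -> shared full-block prefixes
    req2 = [1] * 3201
    req3 = [2] * 3201  # distinct token value -> no shared blocks

    shared_12 = len(set(prefix_block_keys(req1)) & set(prefix_block_keys(req2)))
    shared_13 = len(set(prefix_block_keys(req1)) & set(prefix_block_keys(req3)))
    print(shared_12, shared_13)  # prints: 49 0

Running the sketch prints 49 0: the two all-ones prompts share 49 full blocks under this keying, while the prompt with a different token value shares none.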