[Feature] Set prefix caching as default (#3814)

* Set prefix caching as default

* Set prefix caching as default

* Set prefix caching as default

* skip dynamic load scenario

* fix kill bug

* fix kill bug

* fix kill bug

* fix

* fix

* fix ci
chenjian
2025-09-16 20:34:27 +08:00
committed by GitHub
parent de8638b1e9
commit 67e6d8c691
5 changed files with 23 additions and 8 deletions

@@ -14,6 +14,7 @@
 # limitations under the License.
 """
+import argparse
 import json
 from dataclasses import asdict, dataclass
 from dataclasses import fields as dataclass_fields
@@ -190,7 +191,7 @@ class EngineArgs:
     """
     Flag to indicate whether to use warm-up before inference.
     """
-    enable_prefix_caching: bool = False
+    enable_prefix_caching: bool = True
     """
     Flag to enable prefix caching.
     """
@@ -387,6 +388,16 @@ class EngineArgs:
         """
         if not self.tokenizer:
             self.tokenizer = self.model
+        if self.splitwise_role == "decode":
+            self.enable_prefix_caching = False
+        if self.speculative_config is not None:
+            self.enable_prefix_caching = False
+        if self.enable_mm:
+            self.enable_prefix_caching = False
+        if not current_platform.is_cuda():
+            self.enable_prefix_caching = False
+        if self.dynamic_load_weight:
+            self.enable_prefix_caching = False
         if self.enable_logprob:
             if self.speculative_config is not None:
                 raise NotImplementedError("Logprob does not support speculation_config.")
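
Since the default now flips to True, the guard above turns prefix caching back off in configurations where it is not supported: decode-only splitwise roles, speculative decoding, multimodal models, non-CUDA platforms, and dynamic weight loading. A minimal standalone sketch of that fallback logic, with hypothetical names (_Args, effective_prefix_caching) and a plain is_cuda flag standing in for current_platform.is_cuda():

from dataclasses import dataclass
from typing import Any, Optional

@dataclass
class _Args:
    # Hypothetical stand-in for the EngineArgs fields touched above.
    enable_prefix_caching: bool = True
    splitwise_role: str = "mixed"
    speculative_config: Optional[Any] = None
    enable_mm: bool = False
    dynamic_load_weight: bool = False

def effective_prefix_caching(args: _Args, is_cuda: bool = True) -> bool:
    # Mirrors the guard chain in the diff: any unsupported feature overrides
    # the new True default.
    unsupported = (
        args.splitwise_role == "decode"
        or args.speculative_config is not None
        or args.enable_mm
        or not is_cuda
        or args.dynamic_load_weight
    )
    return args.enable_prefix_caching and not unsupported

assert effective_prefix_caching(_Args()) is True
assert effective_prefix_caching(_Args(enable_mm=True)) is False
assert effective_prefix_caching(_Args(), is_cuda=False) is False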
@@ -725,7 +736,7 @@ class EngineArgs:
         perf_group = parser.add_argument_group("Performance Tuning")
         perf_group.add_argument(
             "--enable-prefix-caching",
-            action="store_true",
+            action=argparse.BooleanOptionalAction,
             default=EngineArgs.enable_prefix_caching,
             help="Flag to enable prefix caching.",
         )
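
Moving the CLI flag from action="store_true" to argparse.BooleanOptionalAction (standard library, Python 3.9+) is what keeps the option usable now that the default is True: the action registers both --enable-prefix-caching and an auto-generated --no-enable-prefix-caching, so callers can still opt out. A minimal standalone sketch of that behavior:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--enable-prefix-caching",
    action=argparse.BooleanOptionalAction,
    default=True,  # mirrors the new EngineArgs default
    help="Flag to enable prefix caching.",
)

# The default is on; the auto-generated negative flag turns it off.
assert parser.parse_args([]).enable_prefix_caching is True
assert parser.parse_args(["--no-enable-prefix-caching"]).enable_prefix_caching is False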

@@ -369,7 +369,8 @@ class LLMEngine:
         for p in self.cache_manager_processes:
             llm_logger.info(f"Killing cache manager process {p.pid}")
             try:
-                os.killpg(p.pid, signal.SIGTERM)
+                pgid = os.getpgid(p.pid)
+                os.killpg(pgid, signal.SIGTERM)
             except Exception as e:
                 console_logger.error(
                     f"Error killing cache manager process {p.pid}: {e}, {str(traceback.format_exc())}"
@@ -381,7 +382,8 @@ class LLMEngine:
             self.get_profile_block_num_signal.clear()
         if hasattr(self, "worker_proc") and self.worker_proc is not None:
             try:
-                os.killpg(self.worker_proc.pid, signal.SIGTERM)
+                pgid = os.getpgid(self.worker_proc.pid)
+                os.killpg(pgid, signal.SIGTERM)
             except Exception as e:
                 console_logger.error(f"Error extracting sub services: {e}, {str(traceback.format_exc())}")

@@ -32,6 +32,7 @@ for file in $TEST_FILES; do
     else
         success_pytest=$((success_pytest+1))
     fi
+    ps -ef | grep "${FD_CACHE_QUEUE_PORT}" | grep -v grep | awk '{print $2}' | xargs -r kill -9
 done
 ##################################

@@ -29,6 +29,7 @@ FD_ENGINE_QUEUE_PORTS = [
     [9991, 9992],
 ]
 FD_CACHE_QUEUE_PORT = int(os.getenv("FD_CACHE_QUEUE_PORT", 8333))
+FD_CACHE_QUEUE_PORTS = [FD_CACHE_QUEUE_PORT, FD_CACHE_QUEUE_PORT + 1, FD_CACHE_QUEUE_PORT + 2, FD_CACHE_QUEUE_PORT + 3]
 models = [
@@ -54,7 +55,7 @@ def llm(request):
         max_model_len=8192,
         num_gpu_blocks_override=1024,
         engine_worker_queue_port=FD_ENGINE_QUEUE_PORTS[port_index],
-        cache_queue_port=FD_CACHE_QUEUE_PORT,
+        cache_queue_port=FD_CACHE_QUEUE_PORTS[port_index],
         load_choices="default",
         enable_expert_parallel=True,
     )
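
With prefix caching on by default, each engine instance in these tests presumably brings up its own cache manager on a cache queue port, so the fixture now assigns one port per parametrized instance instead of sharing FD_CACHE_QUEUE_PORT. A minimal sketch of the assumed port layout:

import os

FD_CACHE_QUEUE_PORT = int(os.getenv("FD_CACHE_QUEUE_PORT", 8333))
# One dedicated port per engine instance, matching the list added above.
FD_CACHE_QUEUE_PORTS = [FD_CACHE_QUEUE_PORT + i for i in range(4)]

for port_index, port in enumerate(FD_CACHE_QUEUE_PORTS):
    print(f"engine {port_index} -> cache_queue_port {port}")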

@@ -30,8 +30,8 @@ def test_normal_schedule():
         max_num_seqs=max_num_seqs, config=fd_config, tensor_parallel_size=8, splitwise_role="mixed"
     )
     req1 = Request.from_dict({"request_id": "req1", "prompt_token_ids": [1] * 3199, "prompt_token_ids_len": 3199})
-    req2 = Request.from_dict({"request_id": "req2", "prompt_token_ids": [1] * 3201, "prompt_token_ids_len": 3201})
-    req3 = Request.from_dict({"request_id": "req3", "prompt_token_ids": [1] * 3200, "prompt_token_ids_len": 3200})
+    req2 = Request.from_dict({"request_id": "req2", "prompt_token_ids": [2] * 3201, "prompt_token_ids_len": 3201})
+    req3 = Request.from_dict({"request_id": "req3", "prompt_token_ids": [3] * 3200, "prompt_token_ids_len": 3200})
     resource_manager_v1.add_request(req1)
     resource_manager_v1.add_request(req2)
     resource_manager_v1.add_request(req3)
@@ -93,7 +93,7 @@ def test_preempted_request():
         max_num_seqs=max_num_seqs, config=fd_config, tensor_parallel_size=8, splitwise_role="mixed"
     )
     req1 = Request.from_dict({"request_id": "req1", "prompt_token_ids": [1] * 3200, "prompt_token_ids_len": 3200})
-    req2 = Request.from_dict({"request_id": "req2", "prompt_token_ids": [1] * 3200, "prompt_token_ids_len": 3200})
+    req2 = Request.from_dict({"request_id": "req2", "prompt_token_ids": [2] * 3200, "prompt_token_ids_len": 3200})
     resource_manager_v1.add_request(req1)
     resource_manager_v1.add_request(req2)
     # step 1
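
The distinct filler tokens likely matter because prefix caching now defaults to on: identical prompt_token_ids across requests would share cached prefix blocks and skew the block accounting these scheduler tests assert on. A hypothetical helper (not in the diff) showing the cache-independent prompt pattern:

def make_request_dict(index: int, length: int) -> dict:
    # A different filler token per request, so no two prompts share a prefix.
    return {
        "request_id": f"req{index}",
        "prompt_token_ids": [index] * length,
        "prompt_token_ids_len": length,
    }

assert make_request_dict(2, 3201)["prompt_token_ids"][:3] == [2, 2, 2]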