[Feature] Set prefix caching as default (#3814)

* Set prefix caching as default

* Set prefix caching as default

* Set prefix caching as default

* skip dynamic load scene

* fix kill bug

* fix kill bug

* fix kill bug

* fix

* fix

* fix ci
This commit is contained in:
chenjian
2025-09-16 20:34:27 +08:00
committed by GitHub
parent de8638b1e9
commit 67e6d8c691
5 changed files with 23 additions and 8 deletions

View File

@@ -14,6 +14,7 @@
# limitations under the License.
"""
import argparse
import json
from dataclasses import asdict, dataclass
from dataclasses import fields as dataclass_fields
@@ -190,7 +191,7 @@ class EngineArgs:
"""
Flag to indicate whether to use warm-up before inference.
"""
enable_prefix_caching: bool = False
enable_prefix_caching: bool = True
"""
Flag to enable prefix caching.
"""
@@ -387,6 +388,16 @@ class EngineArgs:
"""
if not self.tokenizer:
self.tokenizer = self.model
if self.splitwise_role == "decode":
self.enable_prefix_caching = False
if self.speculative_config is not None:
self.enable_prefix_caching = False
if self.enable_mm:
self.enable_prefix_caching = False
if not current_platform.is_cuda():
self.enable_prefix_caching = False
if self.dynamic_load_weight:
self.enable_prefix_caching = False
if self.enable_logprob:
if self.speculative_config is not None:
raise NotImplementedError("Logprob does not support speculation_config.")
@@ -725,7 +736,7 @@ class EngineArgs:
perf_group = parser.add_argument_group("Performance Tuning")
perf_group.add_argument(
"--enable-prefix-caching",
action="store_true",
action=argparse.BooleanOptionalAction,
default=EngineArgs.enable_prefix_caching,
help="Flag to enable prefix caching.",
)

View File

@@ -369,7 +369,8 @@ class LLMEngine:
for p in self.cache_manager_processes:
llm_logger.info(f"Killing cache manager process {p.pid}")
try:
os.killpg(p.pid, signal.SIGTERM)
pgid = os.getpgid(p.pid)
os.killpg(pgid, signal.SIGTERM)
except Exception as e:
console_logger.error(
f"Error killing cache manager process {p.pid}: {e}, {str(traceback.format_exc())}"
@@ -381,7 +382,8 @@ class LLMEngine:
self.get_profile_block_num_signal.clear()
if hasattr(self, "worker_proc") and self.worker_proc is not None:
try:
os.killpg(self.worker_proc.pid, signal.SIGTERM)
pgid = os.getpgid(self.worker_proc.pid)
os.killpg(pgid, signal.SIGTERM)
except Exception as e:
console_logger.error(f"Error extracting sub services: {e}, {str(traceback.format_exc())}")

View File

@@ -32,6 +32,7 @@ for file in $TEST_FILES; do
else
success_pytest=$((success_pytest+1))
fi
ps -ef | grep "${FD_CACHE_QUEUE_PORT}" | grep -v grep | awk '{print $2}' | xargs -r kill -9
done
##################################

View File

@@ -29,6 +29,7 @@ FD_ENGINE_QUEUE_PORTS = [
[9991, 9992],
]
FD_CACHE_QUEUE_PORT = int(os.getenv("FD_CACHE_QUEUE_PORT", 8333))
FD_CACHE_QUEUE_PORTS = [FD_CACHE_QUEUE_PORT, FD_CACHE_QUEUE_PORT + 1, FD_CACHE_QUEUE_PORT + 2, FD_CACHE_QUEUE_PORT + 3]
models = [
@@ -54,7 +55,7 @@ def llm(request):
max_model_len=8192,
num_gpu_blocks_override=1024,
engine_worker_queue_port=FD_ENGINE_QUEUE_PORTS[port_index],
cache_queue_port=FD_CACHE_QUEUE_PORT,
cache_queue_port=FD_CACHE_QUEUE_PORTS[port_index],
load_choices="default",
enable_expert_parallel=True,
)

View File

@@ -30,8 +30,8 @@ def test_normal_schedule():
max_num_seqs=max_num_seqs, config=fd_config, tensor_parallel_size=8, splitwise_role="mixed"
)
req1 = Request.from_dict({"request_id": "req1", "prompt_token_ids": [1] * 3199, "prompt_token_ids_len": 3199})
req2 = Request.from_dict({"request_id": "req2", "prompt_token_ids": [1] * 3201, "prompt_token_ids_len": 3201})
req3 = Request.from_dict({"request_id": "req3", "prompt_token_ids": [1] * 3200, "prompt_token_ids_len": 3200})
req2 = Request.from_dict({"request_id": "req2", "prompt_token_ids": [2] * 3201, "prompt_token_ids_len": 3201})
req3 = Request.from_dict({"request_id": "req3", "prompt_token_ids": [3] * 3200, "prompt_token_ids_len": 3200})
resource_manager_v1.add_request(req1)
resource_manager_v1.add_request(req2)
resource_manager_v1.add_request(req3)
@@ -93,7 +93,7 @@ def test_preempted_request():
max_num_seqs=max_num_seqs, config=fd_config, tensor_parallel_size=8, splitwise_role="mixed"
)
req1 = Request.from_dict({"request_id": "req1", "prompt_token_ids": [1] * 3200, "prompt_token_ids_len": 3200})
req2 = Request.from_dict({"request_id": "req2", "prompt_token_ids": [1] * 3200, "prompt_token_ids_len": 3200})
req2 = Request.from_dict({"request_id": "req2", "prompt_token_ids": [2] * 3200, "prompt_token_ids_len": 3200})
resource_manager_v1.add_request(req1)
resource_manager_v1.add_request(req2)
# step 1