[Feature] Enable prefix caching as default (#3816)

* [Feature] Enable prefix caching as default

* [Feature] Enable prefix caching as default

* Set prefix caching as default

* skip dynamic load

* fix kill bug

* fix kill bug

* fix kill bug

* fix ci

* fix

---------

Co-authored-by: Jiang-Jia-Jun <163579578+Jiang-Jia-Jun@users.noreply.github.com>
This commit is contained in:
chenjian
2025-09-06 09:51:34 +08:00
committed by GitHub
parent 11b18e5ef0
commit 41cd3e24c9
6 changed files with 37 additions and 5 deletions

View File

@@ -14,6 +14,7 @@
# limitations under the License.
"""
import argparse
import json
from dataclasses import asdict, dataclass
from dataclasses import fields as dataclass_fields
@@ -190,7 +191,7 @@ class EngineArgs:
"""
Flag to indicate whether to use warm-up before inference.
"""
enable_prefix_caching: bool = False
enable_prefix_caching: bool = True
"""
Flag to enable prefix caching.
"""
@@ -387,6 +388,16 @@ class EngineArgs:
"""
if not self.tokenizer:
self.tokenizer = self.model
if self.splitwise_role == "decode":
self.enable_prefix_caching = False
if self.speculative_config is not None:
self.enable_prefix_caching = False
if self.enable_mm:
self.enable_prefix_caching = False
if not current_platform.is_cuda():
self.enable_prefix_caching = False
if self.dynamic_load_weight:
self.enable_prefix_caching = False
if self.enable_logprob:
if self.speculative_config is not None:
raise NotImplementedError("Logprob does not support speculation_config.")
@@ -725,7 +736,7 @@ class EngineArgs:
perf_group = parser.add_argument_group("Performance Tuning")
perf_group.add_argument(
"--enable-prefix-caching",
action="store_true",
action=argparse.BooleanOptionalAction,
default=EngineArgs.enable_prefix_caching,
help="Flag to enable prefix caching.",
)

View File

@@ -342,7 +342,8 @@ class LLMEngine:
for p in self.cache_manager_processes:
llm_logger.info(f"Killing cache manager process {p.pid}")
try:
os.killpg(p.pid, signal.SIGTERM)
pgid = os.getpgid(p.pid)
os.killpg(pgid, signal.SIGTERM)
except Exception as e:
console_logger.error(
f"Error killing cache manager process {p.pid}: {e}, {str(traceback.format_exc())}"

View File

@@ -221,6 +221,7 @@ class GPUModelRunner(ModelRunnerBase):
req_len = len(req_dicts)
has_prefill_task = False
has_decode_task = False
has_preempted_task = False
for i in range(req_len):
request = req_dicts[i]
idx = request.idx
@@ -320,6 +321,7 @@ class GPUModelRunner(ModelRunnerBase):
self.share_inputs["seq_lens_decoder"][idx : idx + 1] = 0
self.share_inputs["seq_lens_encoder"][idx : idx + 1] = 0
self.share_inputs["is_block_step"][idx : idx + 1] = False
has_preempted_task = True
continue
assert len(request.eos_token_ids) == self.model_config.eos_tokens_lens
@@ -375,6 +377,10 @@ class GPUModelRunner(ModelRunnerBase):
if has_prefill_task or has_decode_task:
self.share_inputs["not_need_stop"][0] = True
if has_preempted_task:
self.share_inputs["not_need_stop"][0] = not (
self.share_inputs["stop_flags"].sum() == self.parallel_config.max_num_seqs
)
self.share_inputs["seq_lens_this_time"] = self.seq_lens_this_time_buffer[:num_running_requests]
def insert_prefill_inputs(self, req_dicts: List[Request], num_running_requests: int = None):

View File

@@ -32,6 +32,7 @@ for file in $TEST_FILES; do
else
success_pytest=$((success_pytest+1))
fi
ps -ef | grep "${FD_CACHE_QUEUE_PORT}" | grep -v grep | awk '{print $2}' | xargs -r kill -9
done
##################################

View File

@@ -27,7 +27,7 @@ for subdir in "$run_path"*/; do
timeout 600 python -m pytest --disable-warnings -sv "$file"
exit_code=$?
set -e
ps -ef | grep "${FD_CACHE_QUEUE_PORT}" | grep -v grep | awk '{print $2}' | xargs -r kill -9
if [ $exit_code -ne 0 ]; then
if [ -f "${subdir%/}/log/workerlog.0" ]; then
echo "---------------- log/workerlog.0 -------------------"

View File

@@ -181,6 +181,19 @@ def stop_server(signum=None, frame=None):
except Exception as e:
print(f"Failed to stop server: {e}, {str(traceback.format_exc())}")
try:
result = subprocess.run(
f"ps -ef -ww | grep {FD_CACHE_QUEUE_PORT} | grep -v grep", shell=True, capture_output=True, text=True
)
for line in result.stdout.strip().split("\n"):
if not line:
continue
parts = line.split()
pid = int(parts[1]) # ps -ef 的第二列是 PID
print(f"Killing PID: {pid}")
os.kill(pid, signal.SIGKILL)
except Exception as e:
print(f"Failed to kill cache manager process: {e}, {str(traceback.format_exc())}")
for port in [FD_API_PORT, FD_ENGINE_QUEUE_PORT, FD_METRICS_PORT, FD_CACHE_QUEUE_PORT]:
try:
output = subprocess.check_output(f"lsof -i:{port} -t", shell=True).decode().strip()
@@ -285,7 +298,7 @@ def start_service():
def switch_service():
"""切换模型服务"""
# kill掉已有服务
stop_server()
res, status_code = stop_server()
time.sleep(2)
try: