[Cherry-Pick][BugFix] cp fix_cpu_cache_bugs(#5544) (#5577)

* cp fix_cpu_cache_bugs

* update ce case

* update test case

* update code
This commit is contained in:
kevin
2025-12-19 11:48:50 +08:00
committed by GitHub
parent 2aa88d3621
commit 23bfd28624
8 changed files with 14 additions and 5 deletions

View File

@@ -209,7 +209,7 @@ jobs:
export TEMPLATE=TOKEN_NORMAL
curl -X POST http://0.0.0.0:${FLASK_PORT}/switch \
-H "Content-Type: application/json" \
-d "{\"--model\": \"/MODELDATA/ERNIE-4.5-VL-28B-A3B-Thinking\", \"--reasoning-parser\": \"ernie-45-vl-thinking\", \"--tool-call-parser\": \"ernie-45-vl-thinking\", \"--tensor-parallel-size\": 1, \"--quantization\": \"wint4\", \"--max-model-len\": 131072, \"--max-num-seqs\": 32}"
-d "{\"--model\": \"/MODELDATA/ERNIE-4.5-VL-28B-A3B-Thinking\", \"--reasoning-parser\": \"ernie-45-vl-thinking\", \"--tool-call-parser\": \"ernie-45-vl-thinking\", \"--tensor-parallel-size\": 1, \"--quantization\": \"wint4\", \"--max-model-len\": 131072, \"--max-num-seqs\": 32, \"--no-enable-prefix-caching\": true}"
check_service 90
python -m pytest -sv test_prompt_ids.py || TEST_EXIT_CODE=1

View File

@@ -85,6 +85,13 @@ def parse_args():
default="ipc",
help="cache transfer protocol; only ipc is supported now",
)
parser.add_argument(
"--default_dtype",
type=str,
default="bfloat16",
choices=["float16", "bfloat16", "uint8"],
help="paddle default dtype; swap_cache_batch only supports float16, bfloat16 and uint8 now",
)
parser.add_argument("--local_data_parallel_id", type=int, default=0)
parser.add_argument("--rdma_port", type=str, default="", help="rdma port")
parser.add_argument(
@@ -125,6 +132,7 @@ class CacheTransferManager:
self.num_extra_layers = self.speculative_config.num_extra_cache_layer
self.num_extra_layer_gpu_blocks = int(self.num_gpu_blocks * self.speculative_config.num_gpu_block_expand_ratio)
paddle.set_default_dtype(args.default_dtype)
self.swap_to_cpu_thread_pool = concurrent.futures.ThreadPoolExecutor(max_workers=1)
self.swap_to_gpu_thread_pool = concurrent.futures.ThreadPoolExecutor(max_workers=1)
self.transfer_task_queue = queue.Queue()  # queue used to receive transfer tasks

View File

@@ -275,6 +275,7 @@ class PrefixCacheManager:
+ f" --engine_worker_queue_port {engine_worker_queue_port}"
+ f" --num_cpu_blocks {cache_config.num_cpu_blocks}"
+ f" --engine_pid {pid_suffix}"
+ f" --default_dtype '{self.config.model_config.dtype}'"
+ f" --protocol {cache_config.cache_transfer_protocol}"
+ f" --local_data_parallel_id {self.local_data_parallel_id}"
+ f" --rdma_port {cache_config.rdma_comm_ports[i] if cache_config.rdma_comm_ports is not None else '0'}"

View File

@@ -1752,9 +1752,6 @@ class FDConfig:
else:
# It will hang when real batch_size < tp_size
self.graph_opt_config.filter_capture_size(tp_size=self.parallel_config.tensor_parallel_size)
if self.model_config.enable_mm and self.graph_opt_config.use_cudagraph:
self.cache_config.enable_prefix_caching = False
logger.info("Multi-modal models do not support prefix caching when using CUDAGraph!")
if self.scheduler_config.splitwise_role == "mixed":
self._disable_sequence_parallel_moe_if_needed("Mixed")

View File

@@ -26,6 +26,7 @@ class Args:
value_cache_shape = ""
create_cache_tensor = False
cache_dtype = "bfloat16"
default_dtype = "bfloat16"
# ==========================

View File

@@ -89,7 +89,7 @@ def build_command(config):
# Add configuration parameters
for key, value in config.items():
if "--enable" in key:
if "--enable" in key or "--no-enable" in key:
value = bool(value if isinstance(value, bool) else eval(value))
if value:
cmd.append(key)

View File

@@ -80,6 +80,7 @@ def setup_and_run_server():
'{"cudagraph_capture_sizes": [1], "use_cudagraph":true}',
"--routing-replay-config",
'{"enable_routing_replay":true, "routing_store_type":"local", "local_store_dir":"./routing_replay_output"}',
"--no-enable-prefix-caching",
]
# Start subprocess in new process group

View File

@@ -72,6 +72,7 @@ def setup_and_run_server():
"128",
"--limit-mm-per-prompt",
limit_mm_str,
"--no-enable-prefix-caching",
]
print(cmd)