[Feature] Optimize prefix cache (#3208)

* [LLM] support ep

* Update worker_process.py

* Update expert_service.py

* Update worker_process.py

* format files

* optimize prefix cache

* optimize prefix cache

* optimize prefix cache

* pre commit format

* pre commit format

* pre commit format

* Update cache_messager.py
Author: ltd0924
Committed: 2025-08-05 17:13:11 +08:00 (via GitHub)
Commit: dcf9c2daff, parent: 9f9971844f
7 changed files with 314 additions and 147 deletions


@@ -28,7 +28,7 @@ from fastdeploy.config import SpeculativeConfig
from fastdeploy.inter_communicator import EngineCacheQueue, IPCSignal
from fastdeploy.model_executor.ops.gpu import (
cuda_host_alloc,
set_data_ipc,
share_external_data,
swap_cache_all_layers,
)
from fastdeploy.utils import get_logger
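
The import change tracks the core shift in this commit: rather than allocating its own KV tensors and publishing them over CUDA IPC with set_data_ipc, the transfer manager now attaches to cache tensors that already exist in shared memory via share_external_data. A minimal sketch of the attach flow, condensed from the allocation hunk further down (shape, dtype, and tensor name are illustrative):

```python
import paddle
from fastdeploy.model_executor.ops.gpu import share_external_data

# Attach to a KV tensor that another process allocated and published under
# a well-known name, instead of owning the allocation in this process.
cache_shape = [1024, 8, 64, 128]  # [num_gpu_blocks, kv_num_head, block_size, head_dim]
key_cache = paddle.empty(shape=[], dtype="bfloat16")
key_cache = share_external_data(key_cache, "key_caches_0_rank0.device0", cache_shape)
```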
@@ -39,26 +39,12 @@ def parse_args():
Parse arguments from the command line
"""
parser = argparse.ArgumentParser("Cache transfer manager")
parser.add_argument(
"--splitwise_role",
type=str,
default="mixed",
help="splitwise role, can be decode, prefill or mixed",
)
parser.add_argument("--rank", type=int, default=0, help="current rank")
parser.add_argument("--device_id", type=int, default=0, help="device id")
parser.add_argument("--num_layers", type=int, default=1, help="model num layers")
parser.add_argument("--num_hidden_layers", type=int, default=1, help="model num layers")
parser.add_argument("--head_dim", type=int, default=1, help="model head dim")
parser.add_argument("--kv_num_head", type=int, default=1, help="model kv num head")
parser.add_argument("--rdma_port", type=str, default="", help="rmda port")
parser.add_argument("--mp_num", type=int, default=1, help="number of model parallel")
parser.add_argument(
"--protocol",
type=str,
default="ipc",
help="cache transfer protocol, only surport ipc now",
)
parser.add_argument("--enable_splitwise", type=int, default=0, help="enable splitwise ")
parser.add_argument("--cache_queue_port", type=int, default=9923, help="cache queue port")
parser.add_argument("--pod_ip", type=str, default="0.0.0.0", help="pod ip")
parser.add_argument(
@@ -68,7 +54,6 @@ def parse_args():
help="engine worker queue port",
)
parser.add_argument("--engine_pid", type=str, default=None, help="engine pid")
parser.add_argument("--num_gpu_blocks", type=int, default=1, help="gpu cache block number")
parser.add_argument("--num_cpu_blocks", type=int, default=4, help="cpu cache block number")
parser.add_argument("--block_size", type=int, default=64, help="cache block size(tokens)")
@@ -109,7 +94,6 @@ class CacheTransferManager:
device = args.device_id
rank = args.rank
paddle.set_device(f"gpu:{device}")
self.gpu_cache_kvs = {}
self.cpu_cache_kvs = {}
self.gpu_cache_k_tensors = []
@@ -138,40 +122,27 @@ class CacheTransferManager:
self.num_cpu_blocks = args.num_cpu_blocks
cache_type = args.cache_dtype
for i in range(args.num_layers + self.num_extra_layers):
num_gpu_blocks = args.num_gpu_blocks if i < args.num_layers else self.num_extra_layer_gpu_blocks
cache_shape = [
args.num_gpu_blocks,
args.kv_num_head,
args.block_size,
args.head_dim,
]
self.gpu_cache_kvs[f"key_caches_{i}_rank{rank}_device{device}"] = paddle.full(
shape=[
num_gpu_blocks,
args.kv_num_head,
args.block_size,
args.head_dim,
],
fill_value=0,
dtype=cache_type,
)
self.gpu_cache_k_tensors.append(self.gpu_cache_kvs[f"key_caches_{i}_rank{rank}_device{device}"])
self.gpu_cache_kvs[f"value_caches_{i}_rank{rank}_device{device}"] = paddle.full(
shape=[
num_gpu_blocks,
args.kv_num_head,
args.block_size,
args.head_dim,
],
fill_value=0,
dtype=cache_type,
)
self.gpu_cache_v_tensors.append(self.gpu_cache_kvs[f"value_caches_{i}_rank{rank}_device{device}"])
for i in range(args.num_hidden_layers + self.num_extra_layers):
num_gpu_blocks = args.num_gpu_blocks if i < args.num_hidden_layers else self.num_extra_layer_gpu_blocks
cache_shape[0] = num_gpu_blocks
key_name = f"key_caches_{i}_rank{rank}.device{device}"
value_name = f"value_caches_{i}_rank{rank}.device{device}"
key_cache = paddle.empty(shape=[], dtype=cache_type)
value_cache = paddle.empty(shape=[], dtype=cache_type)
key_cache = share_external_data(key_cache, key_name, cache_shape)
value_cache = share_external_data(value_cache, value_name, cache_shape)
self.gpu_cache_kvs[key_name] = key_cache
self.gpu_cache_kvs[value_name] = value_cache
self.gpu_cache_k_tensors.append(self.gpu_cache_kvs[key_name])
self.gpu_cache_v_tensors.append(self.gpu_cache_kvs[value_name])
set_data_ipc(
self.gpu_cache_kvs[f"key_caches_{i}_rank{rank}_device{device}"],
f"key_caches_{i}_rank{rank}.device{device}",
)
set_data_ipc(
self.gpu_cache_kvs[f"value_caches_{i}_rank{rank}_device{device}"],
f"value_caches_{i}_rank{rank}.device{device}",
)
cache_kv_size_byte = sum([tmp.numel() * 1 for key, tmp in self.gpu_cache_kvs.items()])
logger.info(f"device :{self.device}")
logger.info(f"cache_kv_size_byte : {cache_kv_size_byte}")
@@ -180,7 +151,7 @@ class CacheTransferManager:
paddle.set_device("cpu")
self.k_dst_ptrs = []
self.v_dst_ptrs = []
for i in range(args.num_layers + self.num_extra_layers):
for i in range(args.num_hidden_layers + self.num_extra_layers):
self.cpu_cache_kvs[f"key_caches_{i}_rank{rank}"] = cuda_host_alloc(
args.num_cpu_blocks * args.bytes_per_layer_per_block
)
@@ -190,38 +161,6 @@ class CacheTransferManager:
)
self.v_dst_ptrs.append(self.cpu_cache_kvs[f"value_caches_{i}_rank{rank}"])
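
The pinned-host pool sizing, num_cpu_blocks * bytes_per_layer_per_block, is unchanged apart from the layer-count rename. How --bytes_per_layer_per_block itself is derived is not shown in this diff; a plausible reconstruction from the cache shape, stated as an assumption:

```python
import numpy as np

# Assumed derivation of --bytes_per_layer_per_block: one K or V block for
# one layer, times the element size (1 byte for a quantized cache).
kv_num_head, block_size, head_dim = 8, 64, 128
dtype_size = np.dtype("uint8").itemsize
bytes_per_layer_per_block = kv_num_head * block_size * head_dim * dtype_size

# Mirrors the cuda_host_alloc sizing above: one pinned region per layer.
num_cpu_blocks = 4
host_bytes_per_layer = num_cpu_blocks * bytes_per_layer_per_block
```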
cache_ready_signal_data = np.zeros(shape=[args.mp_num], dtype=np.int32)
self.cache_ready_signal = IPCSignal(
name="cache_ready_signal",
array=cache_ready_signal_data,
dtype=np.int32,
suffix=args.engine_pid,
create=False,
)
self.cache_ready_signal.value[self.rank] = 1
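
Each rank flips its own slot in cache_ready_signal once its caches are mapped. The engine side of this handshake is not part of the diff; a hedged sketch of what it plausibly looks like, creating the same named signal and polling until all ranks report ready (illustrative, not FastDeploy's actual engine code):

```python
import time

import numpy as np

from fastdeploy.inter_communicator import IPCSignal

# Illustrative engine-side barrier: create the shared signal and wait for
# every cache process to mark its slot ready. Values are placeholders.
mp_num, engine_pid = 2, "12345"
ready = IPCSignal(
    name="cache_ready_signal",
    array=np.zeros(shape=[mp_num], dtype=np.int32),
    dtype=np.int32,
    suffix=engine_pid,
    create=True,
)
while not np.all(ready.value == 1):
    time.sleep(0.1)
```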
paddle.set_device(f"gpu:{device}")
if args.enable_splitwise:
logger.debug("create cache messager...")
logger.info(f"{args}")
from fastdeploy.cache_manager.cache_messager import CacheMessager
self.cache_messager = CacheMessager(
splitwise_role=args.splitwise_role,
transfer_protocol=args.protocol,
pod_ip=args.pod_ip,
engine_worker_queue_port=args.engine_worker_queue_port,
local_data_parallel_id=args.local_data_parallel_id,
gpu_cache_kvs=self.gpu_cache_kvs,
rank=self.rank,
nranks=args.mp_num,
num_layers=args.num_layers + self.num_extra_layers,
gpu_id=self.device,
rdma_port=args.rdma_port,
)
logger.info("successfully create cache messager")
logger.info(f"done init CacheMessager gmem alloc : {paddle.device.cuda.memory_allocated()}")
cache_task_broadcast_data = np.zeros(shape=[1], dtype=np.int32)
self.cache_task_broadcast_signal = IPCSignal(
name="cache_task_broadcast_signal",