mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-05 16:48:03 +08:00
[Feature] Optimize prefix cache (#3208)
* [LLM] support ep * Update worker_process.py * Update expert_service.py * Update worker_process.py * format files * optimize prefix cache * optimize prefix cache * optimize prefix cache * pre commit format * pre commit format * pre commit format * Update cache_messager.py
This commit is contained in:
@@ -28,7 +28,7 @@ from fastdeploy.config import SpeculativeConfig
|
||||
from fastdeploy.inter_communicator import EngineCacheQueue, IPCSignal
|
||||
from fastdeploy.model_executor.ops.gpu import (
|
||||
cuda_host_alloc,
|
||||
set_data_ipc,
|
||||
share_external_data,
|
||||
swap_cache_all_layers,
|
||||
)
|
||||
from fastdeploy.utils import get_logger
|
||||
@@ -39,26 +39,12 @@ def parse_args():
|
||||
从命令行解析参数
|
||||
"""
|
||||
parser = argparse.ArgumentParser("Cache transfer manager")
|
||||
parser.add_argument(
|
||||
"--splitwise_role",
|
||||
type=str,
|
||||
default="mixed",
|
||||
help="splitwise role, can be decode, prefill or mixed",
|
||||
)
|
||||
parser.add_argument("--rank", type=int, default=0, help="current rank")
|
||||
parser.add_argument("--device_id", type=int, default=0, help="device id")
|
||||
parser.add_argument("--num_layers", type=int, default=1, help="model num layers")
|
||||
parser.add_argument("--num_hidden_layers", type=int, default=1, help="model num layers")
|
||||
parser.add_argument("--head_dim", type=int, default=1, help="model head dim")
|
||||
parser.add_argument("--kv_num_head", type=int, default=1, help="model kv num head")
|
||||
parser.add_argument("--rdma_port", type=str, default="", help="rmda port")
|
||||
parser.add_argument("--mp_num", type=int, default=1, help="number of model parallel")
|
||||
parser.add_argument(
|
||||
"--protocol",
|
||||
type=str,
|
||||
default="ipc",
|
||||
help="cache transfer protocol, only surport ipc now",
|
||||
)
|
||||
parser.add_argument("--enable_splitwise", type=int, default=0, help="enable splitwise ")
|
||||
parser.add_argument("--cache_queue_port", type=int, default=9923, help="cache queue port")
|
||||
parser.add_argument("--pod_ip", type=str, default="0.0.0.0", help="pod ip")
|
||||
parser.add_argument(
|
||||
@@ -68,7 +54,6 @@ def parse_args():
|
||||
help="engine worker queue port",
|
||||
)
|
||||
parser.add_argument("--engine_pid", type=str, default=None, help="engine pid")
|
||||
|
||||
parser.add_argument("--num_gpu_blocks", type=int, default=1, help="gpu cache block number")
|
||||
parser.add_argument("--num_cpu_blocks", type=int, default=4, help="cpu cache block number")
|
||||
parser.add_argument("--block_size", type=int, default=64, help="cache block size(tokens)")
|
||||
@@ -109,7 +94,6 @@ class CacheTransferManager:
|
||||
|
||||
device = args.device_id
|
||||
rank = args.rank
|
||||
paddle.set_device(f"gpu:{device}")
|
||||
self.gpu_cache_kvs = {}
|
||||
self.cpu_cache_kvs = {}
|
||||
self.gpu_cache_k_tensors = []
|
||||
@@ -138,40 +122,27 @@ class CacheTransferManager:
|
||||
self.num_cpu_blocks = args.num_cpu_blocks
|
||||
|
||||
cache_type = args.cache_dtype
|
||||
for i in range(args.num_layers + self.num_extra_layers):
|
||||
num_gpu_blocks = args.num_gpu_blocks if i < args.num_layers else self.num_extra_layer_gpu_blocks
|
||||
cache_shape = [
|
||||
args.num_gpu_blocks,
|
||||
args.kv_num_head,
|
||||
args.block_size,
|
||||
args.head_dim,
|
||||
]
|
||||
|
||||
self.gpu_cache_kvs[f"key_caches_{i}_rank{rank}_device{device}"] = paddle.full(
|
||||
shape=[
|
||||
num_gpu_blocks,
|
||||
args.kv_num_head,
|
||||
args.block_size,
|
||||
args.head_dim,
|
||||
],
|
||||
fill_value=0,
|
||||
dtype=cache_type,
|
||||
)
|
||||
self.gpu_cache_k_tensors.append(self.gpu_cache_kvs[f"key_caches_{i}_rank{rank}_device{device}"])
|
||||
self.gpu_cache_kvs[f"value_caches_{i}_rank{rank}_device{device}"] = paddle.full(
|
||||
shape=[
|
||||
num_gpu_blocks,
|
||||
args.kv_num_head,
|
||||
args.block_size,
|
||||
args.head_dim,
|
||||
],
|
||||
fill_value=0,
|
||||
dtype=cache_type,
|
||||
)
|
||||
self.gpu_cache_v_tensors.append(self.gpu_cache_kvs[f"value_caches_{i}_rank{rank}_device{device}"])
|
||||
for i in range(args.num_hidden_layers + self.num_extra_layers):
|
||||
num_gpu_blocks = args.num_gpu_blocks if i < args.num_hidden_layers else self.num_extra_layer_gpu_blocks
|
||||
cache_shape[0] = num_gpu_blocks
|
||||
key_name = f"key_caches_{i}_rank{rank}.device{device}"
|
||||
value_name = f"value_caches_{i}_rank{rank}.device{device}"
|
||||
key_cache = paddle.empty(shape=[], dtype=cache_type)
|
||||
value_cache = paddle.empty(shape=[], dtype=cache_type)
|
||||
key_cache = share_external_data(key_cache, key_name, cache_shape)
|
||||
value_cache = share_external_data(value_cache, value_name, cache_shape)
|
||||
self.gpu_cache_kvs[key_name] = key_cache
|
||||
self.gpu_cache_kvs[value_name] = value_cache
|
||||
self.gpu_cache_k_tensors.append(self.gpu_cache_kvs[key_name])
|
||||
self.gpu_cache_v_tensors.append(self.gpu_cache_kvs[value_name])
|
||||
|
||||
set_data_ipc(
|
||||
self.gpu_cache_kvs[f"key_caches_{i}_rank{rank}_device{device}"],
|
||||
f"key_caches_{i}_rank{rank}.device{device}",
|
||||
)
|
||||
set_data_ipc(
|
||||
self.gpu_cache_kvs[f"value_caches_{i}_rank{rank}_device{device}"],
|
||||
f"value_caches_{i}_rank{rank}.device{device}",
|
||||
)
|
||||
cache_kv_size_byte = sum([tmp.numel() * 1 for key, tmp in self.gpu_cache_kvs.items()])
|
||||
logger.info(f"device :{self.device}")
|
||||
logger.info(f"cache_kv_size_byte : {cache_kv_size_byte}")
|
||||
@@ -180,7 +151,7 @@ class CacheTransferManager:
|
||||
paddle.set_device("cpu")
|
||||
self.k_dst_ptrs = []
|
||||
self.v_dst_ptrs = []
|
||||
for i in range(args.num_layers + self.num_extra_layers):
|
||||
for i in range(args.num_hidden_layers + self.num_extra_layers):
|
||||
self.cpu_cache_kvs[f"key_caches_{i}_rank{rank}"] = cuda_host_alloc(
|
||||
args.num_cpu_blocks * args.bytes_per_layer_per_block
|
||||
)
|
||||
@@ -190,38 +161,6 @@ class CacheTransferManager:
|
||||
)
|
||||
self.v_dst_ptrs.append(self.cpu_cache_kvs[f"value_caches_{i}_rank{rank}"])
|
||||
|
||||
cache_ready_signal_data = np.zeros(shape=[args.mp_num], dtype=np.int32)
|
||||
self.cache_ready_signal = IPCSignal(
|
||||
name="cache_ready_signal",
|
||||
array=cache_ready_signal_data,
|
||||
dtype=np.int32,
|
||||
suffix=args.engine_pid,
|
||||
create=False,
|
||||
)
|
||||
self.cache_ready_signal.value[self.rank] = 1
|
||||
|
||||
paddle.set_device(f"gpu:{device}")
|
||||
if args.enable_splitwise:
|
||||
logger.debug("create cache messager...")
|
||||
logger.info(f"{args}")
|
||||
from fastdeploy.cache_manager.cache_messager import CacheMessager
|
||||
|
||||
self.cache_messager = CacheMessager(
|
||||
splitwise_role=args.splitwise_role,
|
||||
transfer_protocol=args.protocol,
|
||||
pod_ip=args.pod_ip,
|
||||
engine_worker_queue_port=args.engine_worker_queue_port,
|
||||
local_data_parallel_id=args.local_data_parallel_id,
|
||||
gpu_cache_kvs=self.gpu_cache_kvs,
|
||||
rank=self.rank,
|
||||
nranks=args.mp_num,
|
||||
num_layers=args.num_layers + self.num_extra_layers,
|
||||
gpu_id=self.device,
|
||||
rdma_port=args.rdma_port,
|
||||
)
|
||||
logger.info("successfully create cache messager")
|
||||
logger.info(f"done init CacheMessager gmem alloc : {paddle.device.cuda.memory_allocated()}")
|
||||
|
||||
cache_task_broadcast_data = np.zeros(shape=[1], dtype=np.int32)
|
||||
self.cache_task_broadcast_signal = IPCSignal(
|
||||
name="cache_task_broadcast_signal",
|
||||
|
Reference in New Issue
Block a user