mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-09-28 21:32:29 +08:00

* [Feature] Support mixed deployment with yiyan adapter in release2.2 * fix metrics * add unit test * add unit test * add unit test * Support pd ep deployment with yiyan adapter * Support pd ep deployment with yiyan adapter * refactor cache messager * support scheduler v1 in PD * suppport pd v1 + chunk prefill * suppport pd v1 + chunk prefill * add eplb * support eplb * support eplb * support eplb * support v1 * fix * fix * fix bug * remove eplb support * support prefix cache in P * fix bug * fix bug * support one stop in V1 * fix bug * fix ci * fix ci * fix * fix * fix * fix * fix --------- Co-authored-by: YuBaoku <49938469+EmmonsCurse@users.noreply.github.com>
415 lines
15 KiB
Python
415 lines
15 KiB
Python
"""
|
|
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License"
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
"""
|
|
|
|
import argparse
|
|
import concurrent.futures
|
|
import json
|
|
import queue
|
|
import time
|
|
import traceback
|
|
|
|
import numpy as np
|
|
import paddle
|
|
|
|
from fastdeploy.cache_manager.cache_data import CacheStatus
|
|
from fastdeploy.config import SpeculativeConfig
|
|
from fastdeploy.inter_communicator import EngineCacheQueue, IPCSignal
|
|
from fastdeploy.model_executor.ops.gpu import (
|
|
cuda_host_alloc,
|
|
share_external_data,
|
|
swap_cache_all_layers,
|
|
)
|
|
from fastdeploy.utils import get_logger
|
|
|
|
|
|
def parse_args(argv=None):
    """
    Parse command-line arguments for the cache transfer manager.

    Args:
        argv: optional list of argument strings. When ``None`` (the default),
            argparse reads ``sys.argv[1:]``, matching the previous behavior.

    Returns:
        argparse.Namespace holding the parsed options. ``speculative_config``
        is decoded with ``json.loads``; because argparse applies ``type`` to
        string defaults, the default ``"{}"`` is returned as ``{}``.
    """
    parser = argparse.ArgumentParser("Cache transfer manager")
    parser.add_argument(
        "--splitwise_role",
        type=str,
        default="mixed",
        help="splitwise role, can be decode, prefill or mixed",
    )
    parser.add_argument("--rank", type=int, default=0, help="current rank")
    parser.add_argument("--device_id", type=int, default=0, help="device id")
    parser.add_argument("--num_layers", type=int, default=1, help="model num layers")
    parser.add_argument("--head_dim", type=int, default=1, help="model head dim")
    parser.add_argument("--kv_num_head", type=int, default=1, help="model kv num head")
    parser.add_argument("--rdma_port", type=str, default="", help="rdma port")
    parser.add_argument("--mp_num", type=int, default=1, help="number of model parallel")
    parser.add_argument(
        "--protocol",
        type=str,
        default="ipc",
        help="cache transfer protocol, only support ipc now",
    )
    parser.add_argument("--enable_splitwise", type=int, default=0, help="enable splitwise")
    parser.add_argument("--cache_queue_port", type=int, default=9923, help="cache queue port")
    parser.add_argument("--pod_ip", type=str, default="0.0.0.0", help="pod ip")
    parser.add_argument(
        "--engine_worker_queue_port",
        type=int,
        default=9923,
        help="engine worker queue port",
    )
    parser.add_argument("--engine_pid", type=str, default=None, help="engine pid")

    parser.add_argument("--num_gpu_blocks", type=int, default=1, help="gpu cache block number")
    parser.add_argument("--num_cpu_blocks", type=int, default=4, help="cpu cache block number")
    parser.add_argument("--block_size", type=int, default=64, help="cache block size(tokens)")
    parser.add_argument(
        "--bytes_per_layer_per_block",
        type=int,
        default=1024,
        help="per layer per block bytes",
    )
    parser.add_argument(
        "--cache_dtype",
        type=str,
        default="bfloat16",
        choices=["uint8", "bfloat16"],
        help="cache dtype",
    )
    parser.add_argument(
        "--speculative_config",
        type=json.loads,
        default="{}",
        help="speculative config",
    )
    parser.add_argument("--local_data_parallel_id", type=int, default=0)

    return parser.parse_args(argv)
|
|
|
|
|
|
class CacheTransferManager:
    """
    Manage swap transfers of KV-cache blocks between GPU memory and CPU host buffers.
    """

    # Bytes per element for each value accepted by ``--cache_dtype``; used to
    # report the attached GPU cache size in bytes.
    _CACHE_DTYPE_BYTES = {"uint8": 1, "bfloat16": 2}

    def __init__(self, args):
        """
        Initialize the CacheTransferManager.

        Attaches to the per-layer GPU key/value cache tensors shared under
        ``key_caches_{i}_rank{rank}.device{device}`` / ``value_caches_...`` names,
        allocates CPU-side buffers with ``cuda_host_alloc``, connects to the
        engine cache-task queue as client ``rank`` and marks this rank as ready
        in the ``cache_ready_signal`` IPC signal.

        Args:
            args: parsed command-line namespace (see ``parse_args``).
        """
        device = args.device_id
        rank = args.rank
        paddle.set_device(f"gpu:{device}")
        self.gpu_cache_kvs = {}
        self.cpu_cache_kvs = {}
        self.gpu_cache_k_tensors = []
        self.gpu_cache_v_tensors = []
        self.speculative_config = SpeculativeConfig(args.speculative_config)
        self.num_extra_layers = self.speculative_config.num_extra_cache_layer
        self.num_extra_layer_gpu_blocks = int(args.num_gpu_blocks * self.speculative_config.num_gpu_block_expand_ratio)

        # One worker per direction: transfers in the same direction run one at a
        # time, while do_data_transfer keeps polling for new tasks.
        self.swap_to_cpu_thread_pool = concurrent.futures.ThreadPoolExecutor(max_workers=1)
        self.swap_to_gpu_thread_pool = concurrent.futures.ThreadPoolExecutor(max_workers=1)
        self.transfer_task_queue = queue.Queue()  # receives transfer tasks
        self.transfer_done_queue = queue.Queue()  # reports finished tasks
        # Backward-compatible alias for the historically misspelled attribute name.
        self.tansfer_done_queue = self.transfer_done_queue
        self.n_ranks = args.mp_num
        self.rank = rank
        self.device = device

        address = (args.pod_ip, args.cache_queue_port)
        self.cache_task_queue = EngineCacheQueue(
            address=address,
            is_server=False,
            num_client=args.mp_num,
            client_id=rank,
            local_data_parallel_id=args.local_data_parallel_id,
        )

        self.num_cpu_blocks = args.num_cpu_blocks

        cache_type = args.cache_dtype
        num_total_layers = args.num_layers + self.num_extra_layers

        # Attach to the GPU caches. Speculative-decoding extra layers use a
        # different block count than the regular model layers.
        for i in range(num_total_layers):
            num_gpu_blocks = args.num_gpu_blocks if i < args.num_layers else self.num_extra_layer_gpu_blocks
            cache_shape = [num_gpu_blocks, args.kv_num_head, args.block_size, args.head_dim]
            key_name = f"key_caches_{i}_rank{rank}.device{device}"
            value_name = f"value_caches_{i}_rank{rank}.device{device}"
            key_cache = paddle.empty(shape=[], dtype=cache_type)
            value_cache = paddle.empty(shape=[], dtype=cache_type)
            key_cache = share_external_data(key_cache, key_name, cache_shape)
            value_cache = share_external_data(value_cache, value_name, cache_shape)
            self.gpu_cache_kvs[key_name] = key_cache
            self.gpu_cache_kvs[value_name] = value_cache
            self.gpu_cache_k_tensors.append(key_cache)
            self.gpu_cache_v_tensors.append(value_cache)

        # numel() counts elements; scale by the element size so the value logged
        # as bytes really is bytes.
        bytes_per_elem = self._CACHE_DTYPE_BYTES.get(cache_type, 1)
        cache_kv_size_byte = sum(tmp.numel() * bytes_per_elem for tmp in self.gpu_cache_kvs.values())
        logger.info(f"device :{self.device}")
        logger.info(f"cache_kv_size_byte : {cache_kv_size_byte}")
        logger.info(f"done init cache (full) gmem alloc : {paddle.device.cuda.memory_allocated()}")

        # CPU-side cache: one buffer per layer for keys and one for values, each
        # sized num_cpu_blocks * bytes_per_layer_per_block.
        paddle.set_device("cpu")
        self.k_dst_ptrs = []
        self.v_dst_ptrs = []
        cpu_cache_bytes = args.num_cpu_blocks * args.bytes_per_layer_per_block
        for i in range(num_total_layers):
            key_name = f"key_caches_{i}_rank{rank}"
            value_name = f"value_caches_{i}_rank{rank}"
            self.cpu_cache_kvs[key_name] = cuda_host_alloc(cpu_cache_bytes)
            self.k_dst_ptrs.append(self.cpu_cache_kvs[key_name])
            self.cpu_cache_kvs[value_name] = cuda_host_alloc(cpu_cache_bytes)
            self.v_dst_ptrs.append(self.cpu_cache_kvs[value_name])

        cache_ready_signal_data = np.zeros(shape=[args.mp_num], dtype=np.int32)
        self.cache_ready_signal = IPCSignal(
            name="cache_ready_signal",
            array=cache_ready_signal_data,
            dtype=np.int32,
            suffix=args.engine_pid,
            create=False,
        )
        self.cache_ready_signal.value[self.rank] = 1

        cache_task_broadcast_data = np.zeros(shape=[1], dtype=np.int32)
        self.cache_task_broadcast_signal = IPCSignal(
            name="cache_task_broadcast_signal",
            array=cache_task_broadcast_data,
            dtype=np.int32,
            suffix=args.engine_pid,
            create=False,
        )

    def _run_swap_task(
        self,
        barrier1,
        barrier2,
        task_name,
        swap_node_ids,
        gpu_block_id,
        cpu_block_id,
        event_type,
        transfer_task_id,
    ):
        """
        Run one swap transfer between two cross-rank barriers.

        All ranks meet at ``barrier1``, perform their part of the transfer and
        meet again at ``barrier2``. Rank 0 resets the barriers and reports the
        finished task through ``put_transfer_done_signal``.
        """
        barrier1.wait()
        if self.rank == 0:
            barrier1.reset()
        result = self._transfer_data(
            swap_node_ids,
            gpu_block_id,
            cpu_block_id,
            event_type,
            transfer_task_id,
        )
        barrier2.wait()
        if self.rank == 0:
            barrier2.reset()
            self.cache_task_queue.put_transfer_done_signal(result)
            logger.debug(f"{task_name}: put_transfer_done_signal {result}")
            logger.info(f"{task_name}: put_transfer_done_signal for transfer_task_id {transfer_task_id}")

    def _do_swap_to_cpu_task(
        self,
        swap_node_ids,
        gpu_block_id,
        cpu_block_id,
        event_type,
        transfer_task_id,
    ):
        """
        Swap cache blocks GPU -> CPU, synchronized with the other ranks.
        """
        self._run_swap_task(
            self.cache_task_queue.swap_to_cpu_barrier1,
            self.cache_task_queue.swap_to_cpu_barrier2,
            "_do_swap_to_cpu_task",
            swap_node_ids,
            gpu_block_id,
            cpu_block_id,
            event_type,
            transfer_task_id,
        )

    def _do_swap_to_gpu_task(
        self,
        swap_node_ids,
        gpu_block_id,
        cpu_block_id,
        event_type,
        transfer_task_id,
    ):
        """
        Swap cache blocks CPU -> GPU, synchronized with the other ranks.
        """
        self._run_swap_task(
            self.cache_task_queue.swap_to_gpu_barrier1,
            self.cache_task_queue.swap_to_gpu_barrier2,
            "_do_swap_to_gpu_task",
            swap_node_ids,
            gpu_block_id,
            cpu_block_id,
            event_type,
            transfer_task_id,
        )

    @staticmethod
    def _log_swap_task_failure(future):
        """
        Done-callback for swap futures.

        The submitted futures are not kept anywhere, so without this callback an
        exception raised inside a swap task would be stored on the future and
        never reported.
        """
        if future.cancelled():
            return
        exc = future.exception()
        if exc is not None:
            details = "".join(traceback.format_exception(type(exc), exc, exc.__traceback__))
            logger.error(f"swap task failed: {exc}, {details}")

    def do_data_transfer(self):
        """
        Main loop: fetch transfer tasks from the cache-task queue and dispatch them.

        Rank 0 raises ``cache_task_broadcast_signal`` when the queue is non-empty.
        Every rank then reads the task and submits it to the GPU->CPU or
        CPU->GPU worker. With ``n_ranks > 1``, the ranks stay in lock-step
        through ``barrier1``/``barrier2``/``barrier3``. Exceptions are logged and
        the loop continues; it never returns.
        """
        while True:
            try:
                if self.rank == 0 and not self.cache_task_queue.empty():
                    self.cache_task_broadcast_signal.value[0] = 1
                if self.n_ranks > 1:
                    self.cache_task_queue.barrier1.wait()
                    if self.rank == 0:
                        self.cache_task_queue.barrier1.reset()
                if self.cache_task_broadcast_signal.value[0] == 1:
                    data, read_finish = self.cache_task_queue.get_transfer_task()
                    logger.debug(f"transfer data: get_transfer_task {data}")
                    if read_finish:
                        self.cache_task_broadcast_signal.value[0] = 0
                    (
                        swap_node_ids,
                        gpu_block_id,
                        cpu_block_id,
                        event_type,
                        transfer_task_id,
                    ) = data
                    if event_type.value == CacheStatus.SWAP2CPU.value:
                        pool, task = self.swap_to_cpu_thread_pool, self._do_swap_to_cpu_task
                    else:
                        pool, task = self.swap_to_gpu_thread_pool, self._do_swap_to_gpu_task
                    future = pool.submit(
                        task,
                        swap_node_ids,
                        gpu_block_id,
                        cpu_block_id,
                        event_type,
                        transfer_task_id,
                    )
                    future.add_done_callback(self._log_swap_task_failure)
                else:
                    if self.n_ranks > 1:
                        self.cache_task_queue.barrier2.wait()
                        if self.rank == 0:
                            self.cache_task_queue.barrier2.reset()
                    continue

                if self.n_ranks > 1:
                    self.cache_task_queue.barrier3.wait()
                    if self.rank == 0:
                        self.cache_task_queue.barrier3.reset()
            except Exception as e:
                # Best-effort loop: keep serving, but report at error level.
                logger.error(f"do_data_transfer: error: {e}, {str(traceback.format_exc())}")

    def _transfer_data(
        self,
        swap_node_ids,
        task_gpu_block_id,
        task_cpu_block_id,
        event_type,
        transfer_task_id,
    ):
        """
        Copy the given blocks between the GPU and CPU caches for every layer.

        The GPU and CPU block-id sequences are passed unchanged to
        ``swap_cache_all_layers`` and must have equal length. NOTE(review): a
        previous docstring described task_gpu_block_id as
        ``[[block_id, [fold_block_id, ...]], ...]``; confirm the expected format
        against swap_cache_all_layers.

        Returns:
            The input tuple ``(swap_node_ids, task_gpu_block_id,
            task_cpu_block_id, event_type, transfer_task_id)``.

        Raises:
            ValueError: if the GPU and CPU block-id sequences differ in length.
            Any exception from the copy is logged and re-raised.
        """
        logger.debug(
            f"transfer data: transfer_task_id {transfer_task_id}: swap_node_ids {swap_node_ids} "
            + f"task_gpu_block_id {task_gpu_block_id} task_cpu_block_id {task_cpu_block_id} event_type {event_type}"
        )
        start_time = time.time()
        gpu_block_ids = task_gpu_block_id
        cpu_block_ids = task_cpu_block_id
        try:
            if len(gpu_block_ids) != len(cpu_block_ids):
                raise ValueError(
                    f"gpu/cpu block id count mismatch: {len(gpu_block_ids)} != {len(cpu_block_ids)}"
                )

            # Mode flag for swap_cache_all_layers: 0 for GPU->CPU, 1 for CPU->GPU.
            if event_type.value == CacheStatus.SWAP2CPU.value:
                mode = 0
            elif event_type.value == CacheStatus.SWAP2GPU.value:
                mode = 1
            else:
                mode = None
                logger.warning(
                    f"transfer data: Get unexpected event type {event_type}, only SWAP2CPU and SWAP2GPU supported"
                )

            if mode is not None:
                for gpu_tensors, cpu_ptrs in (
                    (self.gpu_cache_k_tensors, self.k_dst_ptrs),
                    (self.gpu_cache_v_tensors, self.v_dst_ptrs),
                ):
                    swap_cache_all_layers(
                        gpu_tensors,
                        cpu_ptrs,
                        self.num_cpu_blocks,
                        gpu_block_ids,
                        cpu_block_ids,
                        self.device,
                        mode,
                    )
        except Exception as e:
            logger.error(f"transfer data: error: {e}")
            raise
        elapsed_time = time.time() - start_time
        logger.info(
            f"transfer data: transfer_task_id {transfer_task_id} event_type {event_type}: "
            + f"transfer {len(gpu_block_ids)} blocks done elapsed_time {elapsed_time:.4f}"
        )
        return (
            swap_node_ids,
            task_gpu_block_id,
            task_cpu_block_id,
            event_type,
            transfer_task_id,
        )
|
|
|
|
|
|
def main(args=None):
    """
    Start the cache transfer manager and run its transfer loop.

    Args:
        args: parsed command-line namespace. When omitted, arguments are
            parsed from the command line with ``parse_args()``. Previously
            this function read a module-global ``args`` that only exists when
            the file runs as a script.
    """
    if args is None:
        args = parse_args()

    cache_manager = CacheTransferManager(args)

    cache_manager.do_data_transfer()
|
|
|
|
|
|
if __name__ == "__main__":
    args = parse_args()
    # Rank offset by the local data-parallel group (local_data_parallel_id * mp_num);
    # used only to give each process its own log file name.
    rank_id = args.local_data_parallel_id * args.mp_num + args.rank
    # Module-level logger looked up by CacheTransferManager and its methods.
    logger = get_logger(
        "cache_transfer_manager",
        f"cache_transfer_manager_rank{rank_id}.log",
    )
    paddle.set_device(f"gpu:{args.device_id}")
    main()
|