"""
|
|
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License"
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
"""
|
|
|
|
import argparse
|
|
import json
|
|
import math
|
|
import queue
|
|
import threading
|
|
import time
|
|
import traceback
|
|
|
|
import numpy as np
|
|
import paddle
|
|
|
|
from fastdeploy.cache_manager.transfer_factory import IPCCommManager, RDMACommManager
|
|
from fastdeploy.config import SpeculativeConfig
|
|
from fastdeploy.inter_communicator import (
|
|
EngineWorkerQueue,
|
|
IPCSignal,
|
|
shared_memory_exists,
|
|
)
|
|
from fastdeploy.model_executor.ops.gpu import get_output_kv_signal, set_data_ipc
|
|
from fastdeploy.utils import envs, get_logger
|
|
|
|
logger = get_logger("cache_messager", "cache_messager.log")
|
|
|
|
|
|
def parse_args():
    """
    Parse command line arguments.
    """
    parser = argparse.ArgumentParser("Cache Messager")
    parser.add_argument(
        "--splitwise_role",
        type=str,
        default="mixed",
        help="splitwise role, can be decode, prefill or mixed",
    )
    parser.add_argument("--rank", type=int, default=0, help="current rank")
    parser.add_argument("--device_id", type=int, default=0, help="device id")
    parser.add_argument("--num_layers", type=int, default=1, help="model num layers")
    parser.add_argument("--head_dim", type=int, default=1, help="model head dim")
    parser.add_argument("--kv_num_head", type=int, default=1, help="model kv num head")
    parser.add_argument("--rdma_port", type=str, default="", help="rdma port")
    parser.add_argument("--mp_num", type=int, default=1, help="number of model parallel ranks")
    parser.add_argument("--engine_pid", type=str, default=None, help="engine pid")
    parser.add_argument(
        "--protocol",
        type=str,
        default="ipc",
        help="cache transfer protocol, only support ipc now",
    )
    parser.add_argument("--pod_ip", type=str, default="0.0.0.0", help="pod ip")
    parser.add_argument("--cache_queue_port", type=int, default=9924, help="cache queue port")
    parser.add_argument(
        "--engine_worker_queue_port",
        type=int,
        default=9923,
        help="engine worker queue port",
    )
    parser.add_argument("--num_gpu_blocks", type=int, default=1, help="gpu cache block number")
    parser.add_argument("--block_size", type=int, default=64, help="cache block size (tokens)")
    parser.add_argument(
        "--cache_dtype",
        type=str,
        default="bfloat16",
        choices=["uint8", "bfloat16"],
        help="cache dtype",
    )
    parser.add_argument(
        "--speculative_config",
        type=json.loads,
        default="{}",
        help="speculative config",
    )
    parser.add_argument("--local_data_parallel_id", type=int, default=0)

    args = parser.parse_args()
    return args

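
# Example invocation (illustrative only; in a real deployment this script is
# typically launched by the serving engine, which supplies the ports, ranks and
# cache shapes, so the values below are placeholders rather than a recommended
# configuration):
#
#   python cache_messager.py --splitwise_role prefill --rank 0 --device_id 0 \
#       --num_layers 32 --head_dim 128 --kv_num_head 8 --mp_num 1 \
#       --protocol ipc --pod_ip 0.0.0.0 --engine_worker_queue_port 9923 \
#       --num_gpu_blocks 1024 --block_size 64 --cache_dtype bfloat16 \
#       --engine_pid 12345

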
class CacheMessager:
    """
    CacheMessager is used to send the cache data between the engine worker and the cache server.
    """

    def __init__(
        self,
        splitwise_role,
        transfer_protocol,
        pod_ip,
        engine_worker_queue_port,
        local_data_parallel_id,
        gpu_cache_kvs,
        rank,
        nranks,
        num_layers,
        gpu_id=0,
        rdma_port=None,
    ):
        """
        Initialize the CacheMessager object.

        Args:
            splitwise_role (str): splitwise role, one of 'prefill', 'decode' or 'mixed'.
            transfer_protocol (str): support ipc and rdma
            engine_worker_queue_port (int): engine_worker_queue port
            gpu_cache_kvs (dict): GPU kv cache
            rank (int): current rank
            nranks (int): global rank number
            num_layers (int): model layer number
            gpu_id (int, optional): GPU ID
            rdma_port (int, optional): RDMA port

        Returns:
            None
        """
        self.splitwise_role = splitwise_role
        self.gpu_cache_kvs = gpu_cache_kvs
        self.rank = rank
        self.nranks = nranks
        address = (pod_ip, engine_worker_queue_port)
        self.engine_worker_queue = EngineWorkerQueue(
            address=address,
            is_server=False,
            num_client=self.nranks,
            client_id=self.rank,
            local_data_parallel_id=local_data_parallel_id,
        )
        transfer_protocol = transfer_protocol.split(",")

        logger.info(f"splitwise role: {splitwise_role}, {transfer_protocol}, rank: {rank}")

        # 1. initialize the cache_k_ptr_list and cache_v_ptr_list
        self.num_layers = num_layers
        cache_k_ptr_list = []
        cache_v_ptr_list = []
        cache_k = []
        cache_v = []
        self.messager = {}
        for layer_idx in range(self.num_layers):
            key_cache = self.gpu_cache_kvs[f"key_caches_{layer_idx}_rank{self.rank}_device{gpu_id}"]
            val_cache = self.gpu_cache_kvs[f"value_caches_{layer_idx}_rank{self.rank}_device{gpu_id}"]
            cache_k.append(key_cache)
            cache_v.append(val_cache)
            cache_k_ptr_list.append(key_cache.data_ptr())
            cache_v_ptr_list.append(val_cache.data_ptr())
        cache_k_ptr_list = np.array(cache_k_ptr_list)
        cache_v_ptr_list = np.array(cache_v_ptr_list)

        # 2. initialize the block_bytes
        cache_shape = key_cache.shape
        max_block_num = cache_shape[0]
        block_bytes = math.prod(cache_shape[1:])
        if key_cache.dtype == paddle.bfloat16:
            block_bytes *= 2
        logger.info(
            f"layers {num_layers} cache_shape: {cache_shape}, max_block_num: {max_block_num}, "
            f"block_bytes: {block_bytes}, dtype: {key_cache.dtype}"
        )
        self.block_bytes = block_bytes

        # 3. initialize the messager
        for protocol in transfer_protocol:
            if protocol == "ipc":
                self.messager[protocol] = IPCCommManager(
                    self.rank,
                    gpu_id,
                    cache_k,
                    cache_v,
                )
                local_device_id = int(str(cache_k[0].place)[-2])
                logger.info(f"done create ipc_comm with local_device_id: {local_device_id}")

            elif protocol == "rdma":
                logger.info(f"splitwise_role rdma: {self.splitwise_role}, rank: {self.rank}, gpu_id: {gpu_id}")

                self.messager[protocol] = RDMACommManager(
                    splitwise_role,
                    rank,
                    gpu_id,
                    cache_k_ptr_list,
                    cache_v_ptr_list,
                    max_block_num,
                    block_bytes,
                    rdma_port,
                )

        self.gpu_id = gpu_id
        self.cache_info = dict()
        self.rank_id = self.rank + local_data_parallel_id * self.nranks

        if self.splitwise_role != "mixed":
            connect_rdma_thread = threading.Thread(target=self._handle_connect_task)
            connect_rdma_thread.daemon = True
            connect_rdma_thread.start()

        logger.info(f"cache messager init finished, use {transfer_protocol}")

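    # Worked example for the block_bytes computed in __init__ (illustrative
    # numbers, not taken from any particular model): with a per-layer cache
    # shape of [max_block_num=1024, kv_num_head=8, block_size=64, head_dim=128],
    # one block spans 8 * 64 * 128 = 65536 elements, doubled to 131072 bytes for
    # bfloat16; this is the per-layer unit behind the send_cache_speed numbers
    # logged below.
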
    def prefill_layerwise_send_cache_thread(self):
        """
        layerwise_send_cache_thread:
        send cache to other instance
        """
        try:
            prefilled_step_idx_data = np.zeros(shape=[1], dtype=np.int32)
            prefilled_layer_idx_data = np.zeros(shape=[1], dtype=np.int32)
            prefilled_layer_name = f"splitwise_complete_prefilled_layer_{self.rank_id}.{self.gpu_id}"
            prefilled_step_name = f"splitwise_complete_prefilled_step_{self.rank_id}.{self.gpu_id}"
            step_shm_value = IPCSignal(
                name=f"splitwise_complete_prefilled_step_{self.rank_id}",
                array=prefilled_step_idx_data,
                dtype=np.int32,
                suffix=self.gpu_id,
                create=not shared_memory_exists(prefilled_step_name),
            )
            layer_shm_value = IPCSignal(
                name=f"splitwise_complete_prefilled_layer_{self.rank_id}",
                array=prefilled_layer_idx_data,
                dtype=np.int32,
                suffix=self.gpu_id,
                create=not shared_memory_exists(prefilled_layer_name),
            )
            logger.info(f"splitwise_complete_prefilled_step_{self.rank_id}, gpu_id: {self.gpu_id}")

            step_shm_value.value[0] = -1
            layer_shm_value.value[0] = -1

            self.last_step_idx = -1
            self.last_layer_idx = -1  # int32

            max_step_idx = 100003
            engine_recycled_count = 0

            while True:

                cache_info = self.engine_worker_queue.get_cache_info()

                if cache_info:
                    logger.debug(f"cache info {cache_info}")
                    for info in cache_info:
                        if info["request_id"] in self.cache_info:
                            self.cache_info[info["request_id"]].update(info)
                            current_info = self.cache_info[info["request_id"]]
                            if "dest_block_ids" in current_info and "src_block_ids" in current_info:
                                current_src_blocks = current_info["src_block_ids"][
                                    -len(current_info["dest_block_ids"]) :
                                ]
                                current_info["src_block_ids"] = current_src_blocks
                                current_info["status"] = "init"
                                logger.info(f"start cache_infos: {current_info}")
                                self.cache_info[info["request_id"]] = current_info
                        else:
                            self.cache_info[info["request_id"]] = info
                prefilled_layer_idx = layer_shm_value.value[0]
                prefilled_step_idx = step_shm_value.value[0]
                logger.info(f"prefilled_layer_idx: {prefilled_layer_idx}, prefilled_step_idx: {prefilled_step_idx}")
                if prefilled_layer_idx == self.num_layers - 1:
                    time.sleep(0.001)
                    prefilled_layer_idx = layer_shm_value.value[0]
                    prefilled_step_idx = step_shm_value.value[0]

                if prefilled_step_idx == -1:
                    time.sleep(0.001)
                    continue
                if not self.cache_info:
                    time.sleep(0.001)
                    continue

                if self.last_step_idx > prefilled_step_idx:
                    engine_recycled_count += 1
                self.last_step_idx = prefilled_step_idx  # only copy the value read from shm memory
                prefilled_step_idx = (
                    prefilled_step_idx + max_step_idx * engine_recycled_count
                )  # remap prefilled_step_idx for comparison

                logger.debug(
                    f"prefilled_layer_idx: {prefilled_layer_idx}, prefilled_step_idx in shm: {self.last_step_idx}, "
                    f"prefilled_step_idx: {prefilled_step_idx}, engine_recycled_count: {engine_recycled_count}"
                )
                for req_id, item in list(self.cache_info.items()):
                    if "status" not in item:
                        continue
                    if "layer_idx" not in item:
                        item["layer_idx"] = 0
                    if item["status"] == "error":
                        del self.cache_info[req_id]
                        continue
                    if item["current_id"] > prefilled_step_idx:
                        continue
                    current_transfer_protocol = item["transfer_protocol"]
                    if item["transfer_protocol"] == "rdma":
                        target_ip = item["ip"]
                        target_id = int(item["rdma_ports"][self.rank])
                        status = self.messager[current_transfer_protocol].connect(target_ip, target_id)
                        if not status:
                            logger.error(f"connect to {target_ip}:{target_id} failed")
                            item["status"] = "error"
                            self.engine_worker_queue.finish_request_barrier.wait()
                            if self.rank == 0:
                                self.engine_worker_queue.put_finished_req([(item["request_id"], "connect error")])
                            continue
                    elif item["transfer_protocol"] == "ipc":
                        target_ip = "0.0.0.0"
                        target_id = int(item["device_ids"][self.rank])
                    src_block_ids = paddle.to_tensor(item["src_block_ids"], dtype="int32", place="cpu")
                    dest_block_ids = paddle.to_tensor(item["dest_block_ids"], dtype="int32", place="cpu")
                    if item["current_id"] < prefilled_step_idx:
                        current_layer_idx = self.num_layers
                    else:
                        current_layer_idx = prefilled_layer_idx + 1

                    for layer_idx in range(item["layer_idx"], current_layer_idx):
                        tic = time.time()
                        return_code = self.messager[current_transfer_protocol].write_cache(
                            target_ip,
                            target_id,
                            src_block_ids,
                            dest_block_ids,
                            layer_idx,
                        )
                        if return_code != 0:
                            item["status"] = "error"
                            self.engine_worker_queue.finish_request_barrier.wait()
                            if self.rank == 0:
                                self.engine_worker_queue.put_finished_req([(item["request_id"], "write cache error")])
                            logger.error(
                                f"write cache failed, layer_idx: {layer_idx}, "
                                f"req_id: {item['request_id']}, dest_ip: {target_ip}"
                            )
                            break

                        tok = time.time()
                        cost_time = tok - tic
                        block_num = len(src_block_ids)
                        avg_time_per_block = cost_time * 1000 / block_num  # ms
                        send_cache_speed = block_num * self.block_bytes / 1073741824 / cost_time  # GB/s
                        logger.debug(
                            f"finish write cache for a layer, {item['request_id']}, {layer_idx}, "
                            f"{current_transfer_protocol}, "
                            f"block_num: {block_num}, send_cache_speed(GB/s): {round(send_cache_speed, 5)}, "
                            f"avg_time per block(ms): {round(avg_time_per_block, 5)}"
                        )

                    item["layer_idx"] = current_layer_idx
                    if item["layer_idx"] == self.num_layers:
                        if item["transfer_protocol"] == "ipc":
                            self.messager["ipc"].write_block_by_sync(target_id)
                        logger.info(f"finish write cache {item['request_id']}")
                        self.engine_worker_queue.finish_request_barrier.wait()
                        if self.rank == 0:
                            # to do: robust in TP: here we assume all status in tp are the same. If wrong, all wrong. If ok, all ok.
                            self.engine_worker_queue.put_finished_req([(item["request_id"], "finished")])
                            logger.info(f"put write cache {item['request_id']}")
                        del self.cache_info[req_id]
                self.last_layer_idx = prefilled_layer_idx

        except Exception as e:
            logger.error(f"prefill layerwise send cache thread has exception: {e}, {str(traceback.format_exc())}")

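    # Note on the step-index remapping in the loop above: the step counter read
    # from shared memory appears to wrap around at max_step_idx (100003), so a
    # drop in the raw value, e.g. from 100002 back to 1, is counted as one
    # recycle and the remapped value becomes 1 + 1 * 100003 = 100004, keeping
    # the item["current_id"] comparison monotonic across wrap-arounds.
    # (Illustrative numbers; the wrap width is taken from the code above.)
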
    def _handle_connect_task(self):
        while True:
            try:
                task = self.engine_worker_queue.get_connect_rdma_task()
                if task is None:
                    time.sleep(0.001)
                    continue
                logger.info(f"_handle_connect_task recv task: {task}")
                task_id = task["task_id"]
                ip, rdma_port = task["ip"], task["rdma_port"]
                status = self.messager["rdma"].connect(ip, rdma_port)
                if not status:
                    response = {"task_id": task_id, "success": False}
                else:
                    response = {"task_id": task_id, "success": True}
                self.engine_worker_queue.put_connect_rdma_task_response(response)
            except Exception as e:
                logger.error(f"handle_connect_task has exception: {e}")


class CacheMessagerV1:
    """
    CacheMessagerV1 is used to send the cache data between the engine worker and the cache server.
    """

    def __init__(
        self,
        splitwise_role,
        transfer_protocol,
        pod_ip,
        engine_worker_queue_port,
        local_data_parallel_id,
        gpu_cache_kvs,
        rank,
        nranks,
        num_layers,
        gpu_id=0,
        block_size=64,
        rdma_port=None,
    ):
        """
        Initialize the CacheMessagerV1 object.

        Args:
            splitwise_role (str): splitwise role, one of 'prefill', 'decode' or 'mixed'.
            transfer_protocol (str): support ipc and rdma
            engine_worker_queue_port (int): engine_worker_queue port
            gpu_cache_kvs (dict): GPU kv cache
            rank (int): current rank
            nranks (int): global rank number
            num_layers (int): model layer number
            gpu_id (int, optional): GPU ID
            block_size (int, optional): cache block size in tokens
            rdma_port (int, optional): RDMA port

        Returns:
            None
        """
        self.splitwise_role = splitwise_role
        self.gpu_cache_kvs = gpu_cache_kvs
        self.rank = rank
        self.nranks = nranks
        address = (pod_ip, engine_worker_queue_port)
        self.engine_worker_queue = EngineWorkerQueue(
            address=address,
            is_server=False,
            num_client=self.nranks,
            client_id=self.rank,
            local_data_parallel_id=local_data_parallel_id,
        )
        self.block_size = block_size
        transfer_protocol = transfer_protocol.split(",")

        logger.info(f"splitwise role: {splitwise_role}, {transfer_protocol}, rank: {rank}")

        # 1. initialize the cache_k_ptr_list and cache_v_ptr_list
        self.num_layers = num_layers
        cache_k_ptr_list = []
        cache_v_ptr_list = []
        cache_k = []
        cache_v = []
        self.messager = {}
        for layer_idx in range(self.num_layers):
            key_cache = self.gpu_cache_kvs[f"key_caches_{layer_idx}_rank{self.rank}_device{gpu_id}"]
            val_cache = self.gpu_cache_kvs[f"value_caches_{layer_idx}_rank{self.rank}_device{gpu_id}"]
            cache_k.append(key_cache)
            cache_v.append(val_cache)
            cache_k_ptr_list.append(key_cache.data_ptr())
            cache_v_ptr_list.append(val_cache.data_ptr())
        cache_k_ptr_list = np.array(cache_k_ptr_list)
        cache_v_ptr_list = np.array(cache_v_ptr_list)

        # 2. initialize the block_bytes
        cache_shape = key_cache.shape
        max_block_num = cache_shape[0]
        block_bytes = math.prod(cache_shape[1:])
        if key_cache.dtype == paddle.bfloat16:
            block_bytes *= 2
        logger.info(
            f"layers {num_layers} cache_shape: {cache_shape}, max_block_num: {max_block_num}, "
            f"block_bytes: {block_bytes}, dtype: {key_cache.dtype}"
        )
        self.block_bytes = block_bytes

        # 3. initialize the messager
        for protocol in transfer_protocol:
            if protocol == "ipc":
                self.messager[protocol] = IPCCommManager(
                    self.rank,
                    gpu_id,
                    cache_k,
                    cache_v,
                )
                local_device_id = int(str(cache_k[0].place)[-2])
                logger.info(f"done create ipc_comm with local_device_id: {local_device_id}")

            elif protocol == "rdma":
                logger.info(f"splitwise_role rdma: {self.splitwise_role}, rank: {self.rank}, gpu_id: {gpu_id}")

                self.messager[protocol] = RDMACommManager(
                    splitwise_role,
                    rank,
                    gpu_id,
                    cache_k_ptr_list,
                    cache_v_ptr_list,
                    max_block_num,
                    block_bytes,
                    rdma_port,
                )

        self.gpu_id = gpu_id
        self.cache_info = dict()
        self.rank_id = self.rank + local_data_parallel_id * self.nranks
        self.engine_cache_task_thread_lock = threading.Lock()
        self.engine_cache_tasks = [dict() for _ in range(512)]
        self.idx_cache_task_dict = {}
        self.cache_prefilled_engine_ids_queue = queue.Queue()  # keep batch slot indexes for each prefill step
        if splitwise_role == "prefill":
            consume_signals_thread = threading.Thread(target=self.consume_signals)
            consume_signals_thread.daemon = True
            consume_signals_thread.start()
            add_cache_task_thread = threading.Thread(target=self._add_cache_task_thread)
            add_cache_task_thread.daemon = True
            add_cache_task_thread.start()

        if self.splitwise_role != "mixed":
            connect_rdma_thread = threading.Thread(target=self._handle_connect_task)
            connect_rdma_thread.daemon = True
            connect_rdma_thread.start()

        logger.info(f"cache messager init finished, use {transfer_protocol}")

    def _add_cache_task_thread(self):
        while True:
            try:
                cache_info = self.engine_worker_queue.get_cache_info()
                self.engine_worker_queue.finish_add_cache_task_barrier.wait()
                finished_add_cache_task_req_ids = []
                if cache_info:
                    for info in cache_info:
                        if info["request_id"] in self.cache_info:
                            self.cache_info[info["request_id"]].update(info)
                            current_info = self.cache_info[info["request_id"]]
                            assert "dest_block_ids" in current_info and "src_block_ids" in current_info
                            finished_add_cache_task_req_ids.append(info["request_id"])
                            decode_cached_block_num = len(current_info["src_block_ids"]) - len(
                                current_info["dest_block_ids"]
                            )
                            padding_decode_block_ids = [-1 for i in range(decode_cached_block_num)] + current_info[
                                "dest_block_ids"
                            ]
                            current_info["dest_block_ids"] = padding_decode_block_ids
                            current_info["decode_cached_tokens"] = decode_cached_block_num * self.block_size
                            current_info["sended_layer_id"] = -1
                            current_info["sended_block_num"] = current_info["decode_cached_tokens"] // self.block_size
                            current_info["status"] = "init"
                            logger.info(f"finish add cache task: {current_info}")
                            self.cache_info[info["request_id"]] = current_info
                            self.idx_cache_task_dict[current_info["current_id"]] = current_info
                        else:
                            self.cache_info[info["request_id"]] = info
                    if self.rank == 0 and finished_add_cache_task_req_ids:
                        self.engine_worker_queue.put_finished_add_cache_task_req(finished_add_cache_task_req_ids)
                else:
                    time.sleep(0.001)
            except Exception as e:
                logger.error(f"add cache task thread encountered an error: {e}, {traceback.format_exc()!s}")

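    # Worked example for the padding in _add_cache_task_thread (illustrative
    # numbers): if prefill has 10 src blocks for a request but the decode side
    # already holds the first 4 in its prefix cache, decode sends only 6 dest
    # block ids; the list is left-padded with four -1 entries so src and dest
    # align index by index, decode_cached_tokens becomes 4 * block_size, and
    # sended_block_num starts at 4 so the cached blocks are never re-sent.
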
    def prefill_layerwise_send_cache_thread(self):
        """
        layerwise_send_cache_thread:
        send cache to other instance
        """
        while True:
            try:
                engine_indexes = self.cache_prefilled_engine_ids_queue.get()
                self.engine_worker_queue.finish_request_barrier.wait()
                block_start_end_list = []
                current_prefilled_token_num_list = []
                for engine_index in engine_indexes:
                    assert engine_index in self.idx_cache_task_dict
                    block_id_start = self.idx_cache_task_dict[engine_index]["sended_block_num"]
                    prefilled_token_num = self.engine_cache_tasks[engine_index]["prefilled_token_num"]
                    if (
                        prefilled_token_num == self.idx_cache_task_dict[engine_index]["need_prefill_tokens"]
                    ):  # all chunks have been prefilled
                        block_id_end = len(self.idx_cache_task_dict[engine_index]["src_block_ids"])
                    else:
                        block_id_end = prefilled_token_num // self.block_size  # [block_id_start, block_id_end)
                    block_start_end_list.append((block_id_start, block_id_end))
                    current_prefilled_token_num_list.append(prefilled_token_num)
                while True:  # from layer 0 to the last layer
                    sended_layer_idx = self.idx_cache_task_dict[engine_indexes[0]]["sended_layer_id"]
                    start_layer_idx = sended_layer_idx + 1
                    with self.engine_cache_task_thread_lock:  # to check end_layer_idx
                        prefilled_layer_idx = self.engine_cache_tasks[engine_indexes[0]]["prefilled_layer_idx"]
                        if sended_layer_idx > prefilled_layer_idx:  # computation must be in the next chunk
                            logger.info(
                                f"current_prefilled_token_num_list[0] {current_prefilled_token_num_list[0]} "
                                f"prefilled_token_num {self.engine_cache_tasks[engine_indexes[0]]['prefilled_token_num']}"
                            )
                            assert (
                                current_prefilled_token_num_list[0]
                                < self.engine_cache_tasks[engine_indexes[0]]["prefilled_token_num"]
                            ), "when sended_layer_idx > prefilled_layer_idx, must be in next chunk, but not, sth wrong"
                            end_layer_idx = self.num_layers - 1  # [start_layer_idx, end_layer_idx]
                        else:
                            end_layer_idx = prefilled_layer_idx
                            if sended_layer_idx == prefilled_layer_idx:  # computation has not reached the next layer yet
                                time.sleep(0.01)
                    for layer_idx in range(start_layer_idx, end_layer_idx + 1):
                        for i, (block_id_start, block_id_end) in enumerate(block_start_end_list):
                            engine_index = engine_indexes[i]
                            task = self.idx_cache_task_dict[engine_index]
                            req_id = task["request_id"]
                            if (
                                block_id_start >= block_id_end
                            ):  # no blocks need to be transferred for this request in this chunk
                                task["sended_layer_id"] += 1
                                assert task["sended_layer_id"] == layer_idx
                                if task["sended_layer_id"] == self.num_layers - 1:
                                    task["sended_layer_id"] = -1
                                continue
                            else:
                                current_transfer_protocol = task["transfer_protocol"]
                                if task["transfer_protocol"] == "rdma":
                                    target_ip = task["ip"]
                                    target_id = int(task["rdma_ports"][self.rank])
                                    if task["status"] == "error":
                                        continue
                                    status = self.messager[current_transfer_protocol].connect(target_ip, target_id)
                                    if not status:
                                        logger.error(f"connect to {target_ip}:{target_id} failed")
                                        task["status"] = "connection error"
                                        continue
                                elif task["transfer_protocol"] == "ipc":
                                    target_ip = "0.0.0.0"
                                    target_id = int(task["device_ids"][self.rank])

                                src_block_ids = task["src_block_ids"][block_id_start:block_id_end]
                                dest_block_ids = task["dest_block_ids"][block_id_start:block_id_end]
                                src_block_ids = paddle.to_tensor(src_block_ids, dtype="int32", place="cpu")
                                dest_block_ids = paddle.to_tensor(dest_block_ids, dtype="int32", place="cpu")

                                logger.info(
                                    f"start write cache for a layer, {req_id}, {layer_idx}, {target_ip}, {target_id}, "
                                    f"block_id_start {block_id_start} block_id_end {block_id_end}"
                                )
                                tic = time.time()
                                return_code = self.messager[current_transfer_protocol].write_cache(
                                    target_ip,
                                    target_id,
                                    src_block_ids,
                                    dest_block_ids,
                                    layer_idx,
                                )
                                if return_code != 0:
                                    task["status"] = "write cache error"
                                    logger.error(
                                        f"write cache failed, layer_idx: {layer_idx}, req_id: {req_id}, "
                                        f"dest_ip: {target_ip}, block_id_start {block_id_start} block_id_end {block_id_end}"
                                    )
                                tok = time.time()
                                cost_time = tok - tic
                                block_num = len(src_block_ids)
                                avg_time_per_block = cost_time * 1000 / block_num  # ms
                                send_cache_speed = block_num * self.block_bytes / 1073741824 / cost_time  # GB/s
                                logger.debug(
                                    f"finish write cache for a layer, {req_id}, {layer_idx}, {target_ip}, {target_id}, "
                                    f"block_num: {block_num}, send_cache_speed(GB/s): {round(send_cache_speed, 5)}, "
                                    f"avg_time per block(ms): {round(avg_time_per_block, 5)}, "
                                    f"block_id_start {block_id_start} block_id_end {block_id_end}"
                                )

                                task["sended_layer_id"] += 1
                                assert task["sended_layer_id"] == layer_idx
                                if task["sended_layer_id"] == self.num_layers - 1:
                                    self.idx_cache_task_dict[engine_index]["sended_block_num"] += (
                                        block_id_end - block_id_start
                                    )
                                    if current_prefilled_token_num_list[i] == task["need_prefill_tokens"]:
                                        if task["status"] != "error":
                                            task["status"] = "finished"
                                        logger.info(
                                            f"finish write cache for all layers, req_id: {req_id}, "
                                            f"block_id_end {block_id_end} need_prefill_tokens {task['need_prefill_tokens']}"
                                        )
                                    else:
                                        task["sended_layer_id"] = -1
                    if end_layer_idx == self.num_layers - 1:
                        with self.engine_cache_task_thread_lock:
                            for engine_idx in engine_indexes:
                                task = self.idx_cache_task_dict[engine_idx]
                                if task["status"] == "finished" or ("error" in task["status"]):
                                    if task["transfer_protocol"] == "ipc":
                                        target_id = int(task["device_ids"][self.rank])
                                        self.messager["ipc"].write_block_by_sync(target_id)
                                    if self.rank == 0:
                                        # to do: robust in TP, here we assume all status in tp are the same. If wrong, all wrong. If ok, all ok.
                                        self.engine_worker_queue.put_finished_req(
                                            [(task["request_id"], task["status"])]
                                        )
                                        logger.info(f"put write cache {task['request_id']}, status {task['status']}")
                                    self.engine_cache_tasks[task["current_id"]] = dict()
                                    del self.cache_info[task["request_id"]]
                                    del self.idx_cache_task_dict[task["current_id"]]
                        break
            except Exception as e:
                logger.error(f"prefill layerwise send cache thread has exception: {e} {traceback.format_exc()!s}")
                time.sleep(0.01)

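    # Worked example for the per-chunk block window above (illustrative numbers,
    # block_size=64, need_prefill_tokens=300, nothing cached on the decode side):
    # a first chunk ending at prefilled_token_num=128 sends blocks [0, 2); the
    # next chunk ending at 256 sends [2, 4); the final chunk, where
    # prefilled_token_num equals need_prefill_tokens, sends all remaining src
    # blocks, i.e. [4, 5), including the partially filled last block.
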
    def consume_signals(self):
        paddle.device.set_device("cpu")
        kv_signal_data = paddle.full(shape=[512 * 3 + 2], fill_value=-1, dtype="int32")
        while True:
            try:
                get_output_kv_signal(kv_signal_data, self.rank_id, 0)  # wait_flag
                if not self.cache_info:
                    time.sleep(0.01)
                    continue
                tasks_count = kv_signal_data[0]
                if tasks_count == -1:
                    time.sleep(0.001)
                    continue
                layer_id = kv_signal_data[1].numpy().tolist()
                if layer_id == self.num_layers - 1:
                    logger.info(f"tasks_count: {tasks_count}, layer_id: {layer_id}")
                batch_engine_ids = []
                with self.engine_cache_task_thread_lock:
                    for bi in range(tasks_count):
                        engine_idx = kv_signal_data[3 * bi + 2].numpy().tolist()
                        chunk_token_offset = kv_signal_data[3 * bi + 3].numpy().tolist()
                        current_seq_len = kv_signal_data[3 * bi + 4].numpy().tolist()
                        self.engine_cache_tasks[engine_idx]["prefilled_layer_idx"] = layer_id
                        self.engine_cache_tasks[engine_idx]["prefilled_token_num"] = (
                            chunk_token_offset + current_seq_len
                        )
                        batch_engine_ids.append(engine_idx)
                if layer_id == 0:
                    self.cache_prefilled_engine_ids_queue.put(batch_engine_ids)
            except Exception as e:
                logger.error(f"Consume signals get exception: {e}")

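    # Layout of kv_signal_data as consumed above (derived from the indexing, for
    # up to 512 batch slots): element 0 is tasks_count, element 1 is the layer id
    # that just finished prefilling, and each task bi contributes the triple
    # (engine_idx, chunk_token_offset, current_seq_len) at offsets
    # 3 * bi + 2 .. 3 * bi + 4; prefilled_token_num is the sum of the last two.
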
    def _handle_connect_task(self):
        while True:
            try:
                task = self.engine_worker_queue.get_connect_rdma_task()
                if task is None:
                    time.sleep(0.001)
                    continue
                logger.info(f"_handle_connect_task recv task: {task}")
                task_id = task["task_id"]
                ip, rdma_port = task["ip"], task["rdma_port"]
                status = self.messager["rdma"].connect(ip, rdma_port)
                if not status:
                    response = {"task_id": task_id, "success": False}
                else:
                    response = {"task_id": task_id, "success": True}
                self.engine_worker_queue.put_connect_rdma_task_response(response)
            except Exception as e:
                logger.error(f"handle_connect_task has exception: {e}")


def main():
    device = args.device_id
    rank = args.rank
    paddle.set_device(f"gpu:{device}")
    cache_type = args.cache_dtype
    speculative_config = SpeculativeConfig(args.speculative_config)
    num_extra_layers = speculative_config.num_extra_cache_layer
    num_extra_layer_gpu_blocks = int(args.num_gpu_blocks * speculative_config.num_gpu_block_expand_ratio)
    gpu_cache_kvs = {}
    gpu_cache_k_tensors = []
    gpu_cache_v_tensors = []

    for i in range(args.num_layers + num_extra_layers):
        num_gpu_blocks = args.num_gpu_blocks if i < args.num_layers else num_extra_layer_gpu_blocks

        gpu_cache_kvs[f"key_caches_{i}_rank{rank}_device{device}"] = paddle.full(
            shape=[
                num_gpu_blocks,
                args.kv_num_head,
                args.block_size,
                args.head_dim,
            ],
            fill_value=0,
            dtype=cache_type,
        )
        gpu_cache_k_tensors.append(gpu_cache_kvs[f"key_caches_{i}_rank{rank}_device{device}"])
        gpu_cache_kvs[f"value_caches_{i}_rank{rank}_device{device}"] = paddle.full(
            shape=[
                num_gpu_blocks,
                args.kv_num_head,
                args.block_size,
                args.head_dim,
            ],
            fill_value=0,
            dtype=cache_type,
        )
        gpu_cache_v_tensors.append(gpu_cache_kvs[f"value_caches_{i}_rank{rank}_device{device}"])

        set_data_ipc(
            gpu_cache_kvs[f"key_caches_{i}_rank{rank}_device{device}"],
            f"key_caches_{i}_rank{rank}.device{device}",
        )
        set_data_ipc(
            gpu_cache_kvs[f"value_caches_{i}_rank{rank}_device{device}"],
            f"value_caches_{i}_rank{rank}.device{device}",
        )
    cache_kv_size_byte = sum([tmp.numel() * 1 for key, tmp in gpu_cache_kvs.items()])
    logger.info(f"device: {device}")
    logger.info(f"cache_kv_size_byte: {cache_kv_size_byte}")
    logger.info(f"done init cache (full) gmem alloc: {paddle.device.cuda.memory_allocated()}")

    if envs.ENABLE_V1_KVCACHE_SCHEDULER:
        cache_messager = CacheMessagerV1(
            splitwise_role=args.splitwise_role,
            transfer_protocol=args.protocol,
            pod_ip=args.pod_ip,
            engine_worker_queue_port=args.engine_worker_queue_port,
            local_data_parallel_id=args.local_data_parallel_id,
            gpu_cache_kvs=gpu_cache_kvs,
            rank=rank,
            nranks=args.mp_num,
            num_layers=args.num_layers + num_extra_layers,
            gpu_id=device,
            block_size=args.block_size,
            rdma_port=args.rdma_port,
        )
    else:
        cache_messager = CacheMessager(
            splitwise_role=args.splitwise_role,
            transfer_protocol=args.protocol,
            pod_ip=args.pod_ip,
            engine_worker_queue_port=args.engine_worker_queue_port,
            local_data_parallel_id=args.local_data_parallel_id,
            gpu_cache_kvs=gpu_cache_kvs,
            rank=rank,
            nranks=args.mp_num,
            num_layers=args.num_layers + num_extra_layers,
            gpu_id=device,
            rdma_port=args.rdma_port,
        )

    cache_ready_signal_data = np.zeros(shape=[args.mp_num], dtype=np.int32)
    cache_ready_signal = IPCSignal(
        name="cache_ready_signal",
        array=cache_ready_signal_data,
        dtype=np.int32,
        suffix=args.engine_pid,
        create=False,
    )
    cache_ready_signal.value[rank] = 1
    if args.splitwise_role == "mixed":
        while True:
            time.sleep(1)
    cache_messager.prefill_layerwise_send_cache_thread()


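# Worked example for the allocation in main() (illustrative numbers): with
# num_layers=32, num_gpu_blocks=1024, kv_num_head=8, block_size=64, head_dim=128
# and cache_dtype=bfloat16, each key or value tensor holds
# 1024 * 8 * 64 * 128 = 67,108,864 elements (128 MiB), i.e. 256 MiB per layer
# and 8 GiB for the full key/value pair across layers, before any extra
# speculative-decoding layers are added.

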
if __name__ == "__main__":

    args = parse_args()
    rank_id = args.rank + args.local_data_parallel_id * args.mp_num
    logger = get_logger("cache_messager", f"cache_messager_rank{rank_id}.log")

    logger.info("create cache messager...")
    logger.info(f"{args}")
    main()