Sync v2.0 version of code to github repo

This commit is contained in:
Jiang-Jia-Jun
2025-06-29 23:29:37 +00:00
parent d151496038
commit 92c2cfa2e7
597 changed files with 78776 additions and 22905 deletions

View File

@@ -0,0 +1,15 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""

View File

@@ -0,0 +1,162 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
from enum import Enum
from fastdeploy.utils import get_logger
logger = get_logger("prefix_cache_manager", "prefix_cache_manager.log")
class CacheStatus(Enum):
"""
cache status enum class
"""
GPU = 0
SWAP2CPU = 1
SWAP2GPU = 2
CPU = 3
class BlockNode:
"""
BlockNode: store the information of a block node
"""
def __init__(
self,
node_id,
input_ids,
input_hash_value,
depth,
block_id,
token_num,
hash_value,
last_used_time,
parent=None,
shared_count=1,
reserved_dec_block_ids=None,
cache_status=CacheStatus.GPU,
is_persistent=False,
persistent_shared_count=0,
):
"""
Args:
node_id: Unique identifier of the node
depth: Depth of the node
block_id: Assigned block ID (CPU block ID if on CPU, GPU block ID if on GPU)
token_num: Number of tokens in the current block
hash_value: Hash value of the current block
last_used_time: Timestamp of last usage
parent: Parent node
shared_count: Reference count of requests currently using this node
reserved_dec_block_ids: Pre-allocated block IDs reserved for decoding, formatted as [block_id, block_id,...]
cache_status: Current cache state (GPU, SWAP2CPU, SWAP2GPU, CPU)
is_persistent: Whether the node is persistently stored
persistent_shared_count: Reference count of persistent cache requests
"""
self.node_id = node_id
self.depth = depth
self.parent = parent
self.hash_value = hash_value
self.token_num = token_num
self.input_ids = input_ids
self.input_hash_value = input_hash_value
self.children = {}
self.shared_count = shared_count
self.last_used_time = last_used_time
self.block_id = block_id
self.reserved_dec_block_ids = reserved_dec_block_ids if reserved_dec_block_ids is not None else []  # avoid a shared mutable default
self.cache_status = cache_status
self.is_persistent = is_persistent
self.persistent_shared_count = persistent_shared_count
self.req_id_set = set()
def __lt__(self, other):
"""
override the less than operator
"""
if self.last_used_time < other.last_used_time:
return True
elif self.last_used_time > other.last_used_time:
return False
else:
return self.depth > other.depth
def __str__(self):
"""
return node info
"""
if self.parent is not None:
parent_node_id = self.parent.node_id
else:
parent_node_id = None
return (
f"node_id {self.node_id}: depth {self.depth} hash_value {self.hash_value}"
f" shared_count {self.shared_count} is_gpu_leaf_node {self.is_gpu_leaf_node}"
f" is_cpu_leaf_node {self.is_cpu_leaf_node} block_id {self.block_id}"
f" has_in_gpu {self.has_in_gpu}"
f" cache_status {self.cache_status} parent {parent_node_id}"
f" with children number {len(self.children)} req_id_set {self.req_id_set}")
@property
def has_in_gpu(self):
"""
check if the node has been allocated in GPU
"""
return self.cache_status == CacheStatus.GPU
def increment_shared_count(self):
"""
increment shared count
"""
self.shared_count += 1
def decrement_shared_count(self):
"""
decrement shared count
"""
self.shared_count -= 1
@property
def is_cpu_leaf_node(self):
"""
check if the node is a leaf node in CPU
"""
return self.cache_status == CacheStatus.CPU and len(self.children) == 0
@property
def is_gpu_leaf_node(self):
"""
check if the node is a leaf node in GPU
"""
if not self.has_in_gpu:
return False
return all(not child.has_in_gpu for child in self.children.values())
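A minimal usage sketch (not part of the commit) showing how `BlockNode.__lt__` orders eviction candidates in a heap; all constructor values below are illustrative:

```python
import heapq

now = 1000.0
# parent and child share a timestamp; the child sits deeper in the radix tree
a = BlockNode(1, [1, 2], 0xA, 1, 10, 2, 0xA1, last_used_time=now - 5.0)
b = BlockNode(2, [1, 2, 3], 0xB, 2, 11, 1, 0xB1, last_used_time=now - 5.0, parent=a)
c = BlockNode(3, [9], 0xC, 1, 12, 1, 0xC1, last_used_time=now)

heap = [a, b, c]
heapq.heapify(heap)
# b pops before a (equal timestamps, greater depth), so leaves are reclaimed
# before their ancestors; both pop before the more recently used c
assert heapq.heappop(heap) is b
assert heapq.heappop(heap) is a
assert heapq.heappop(heap) is c
```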

View File

@@ -0,0 +1,318 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
import math
import threading
import time
import numpy as np
import paddle
from fastdeploy.cache_manager.transfer_factory import (IPCCommManager,
RDMACommManager)
from fastdeploy.inter_communicator import EngineWorkerQueue, IPCSignal
from fastdeploy.utils import get_logger
logger = get_logger("cache_messager", "cache_messager.log")
class CacheMessager(object):
"""
CacheMessager is used to send the cache data between the engine worker and the cache server.
"""
def __init__(self,
splitwise_role,
transfer_protocol,
engine_worker_queue_port,
local_data_parallel_id,
gpu_cache_kvs,
rank,
nranks,
num_layers,
gpu_id=0,
rdma_port=None):
"""
Initialize the CacheMessager object.
Args:
splitwise_role (str): splitwise_role only can be 'prefill' or 'decode'.
transfer_protocol (str): support ipc and rdma
engine_worker_queue_port (int): engine_worker_queue port
gpu_cache_kvs (dict): GPU kv cache
rank (int): current rank
nranks (int): global rank number
num_layers (int): model layer number
gpu_id (int, optional): GPU ID
rdma_port (int, optional): RDMA port
Returns:
None
"""
assert splitwise_role in ["prefill", "decode"], \
"splitwise_role must be prefill or decode"
self.splitwise_role = splitwise_role
self.gpu_cache_kvs = gpu_cache_kvs
self.rank = rank
self.nranks = nranks
address = ('0.0.0.0', engine_worker_queue_port)
self.engine_worker_queue = EngineWorkerQueue(
address=address,
is_server=False,
num_client=self.nranks,
client_id=self.rank,
local_data_parallel_id=local_data_parallel_id)
transfer_protocol = transfer_protocol.split(",")
logger.info(f"splitwise role: {splitwise_role}, {transfer_protocol}"
f"rank: {rank}")
# 1. initialize the cache_k_ptr_list and cache_v_ptr_list
self.num_layers = num_layers
cache_k_ptr_list = []
cache_v_ptr_list = []
cache_k = []
cache_v = []
self.messager = {}
for layer_idx in range(self.num_layers):
key_cache = self.gpu_cache_kvs[
f'key_caches_{layer_idx}_rank{self.rank}_device{gpu_id}']
val_cache = self.gpu_cache_kvs[
f'value_caches_{layer_idx}_rank{self.rank}_device{gpu_id}']
cache_k.append(key_cache)
cache_v.append(val_cache)
cache_k_ptr_list.append(key_cache.data_ptr())
cache_v_ptr_list.append(val_cache.data_ptr())
cache_k_ptr_list = np.array(cache_k_ptr_list)
cache_v_ptr_list = np.array(cache_v_ptr_list)
# 2. initialize the block_bytes
cache_shape = key_cache.shape
max_block_num = cache_shape[0]
# bytes per block per layer: element count times element size (1-byte baseline)
block_bytes = math.prod(cache_shape[1:])
if key_cache.dtype == paddle.bfloat16:
block_bytes *= 2
logger.info(
f"layers {num_layers} cache_shape: {cache_shape}, max_block_num: {max_block_num}, "
f"block_bytes: {block_bytes}, dtype: {key_cache.dtype}")
self.block_bytes = block_bytes
# 3. initialize the messager
for protocol in transfer_protocol:
if protocol == "ipc":
self.messager[protocol] = IPCCommManager(
self.rank,
gpu_id,
cache_k,
cache_v,
)
local_device_id = int(str(cache_k[0].place)[-2])  # parse the device index from a string like Place(gpu:0)
logger.info(
f"done create ipc_comm with local_device_id:{local_device_id}")
elif protocol == "rdma":
logger.info(
f"splitwise_role rdma: {self.splitwise_role}, rank: {self.rank}, gpu_id: {gpu_id}"
)
self.messager[protocol] = RDMACommManager(
splitwise_role, rank, gpu_id, cache_k_ptr_list,
cache_v_ptr_list, max_block_num, block_bytes, rdma_port)
self.gpu_id = gpu_id
self.cache_info = dict()
layerwise_send_cache_thread = threading.Thread(
target=self._prefill_layerwise_send_cache_thread)
layerwise_send_cache_thread.daemon = True
layerwise_send_cache_thread.start()
logger.info(f"cache messager init finished, use {transfer_protocol}")
def _prefill_layerwise_send_cache_thread(self):
"""
layerwise_send_cache_thread:
send cache to other instance
"""
try:
prefilled_step_idx_data = np.zeros(shape=[1], dtype=np.int32)
prefilled_layer_idx_data = np.zeros(shape=[1], dtype=np.int32)
try:
step_shm_value = IPCSignal(
name=f"splitwise_complete_prefilled_step_{self.rank}",
array=prefilled_step_idx_data,
dtype=np.int32,
suffix=self.gpu_id,
create=True)
layer_shm_value = IPCSignal(
name=f"splitwise_complete_prefilled_layer_{self.rank}",
array=prefilled_layer_idx_data,
dtype=np.int32,
suffix=self.gpu_id,
create=True)
except Exception:  # the signals already exist; attach to them instead of creating
step_shm_value = IPCSignal(
name=f"splitwise_complete_prefilled_step_{self.rank}",
array=prefilled_step_idx_data,
dtype=np.int32,
suffix=self.gpu_id,
create=False)
layer_shm_value = IPCSignal(
name=f"splitwise_complete_prefilled_layer_{self.rank}",
array=prefilled_layer_idx_data,
dtype=np.int32,
suffix=self.gpu_id,
create=False)
step_shm_value.value[0] = -1
layer_shm_value.value[0] = -1
self.last_step_idx = -1
self.last_layer_idx = -1 # int32
while True:
cache_info = self.engine_worker_queue.get_cache_info()
if cache_info:
logger.debug(f"cache info {cache_info}")
for info in cache_info:
if info['request_id'] in self.cache_info:
self.cache_info[info["request_id"]].update(info)
current_info = self.cache_info[info["request_id"]]
if "dest_block_ids" in current_info and "src_block_ids" in current_info:
current_info["src_block_ids"] = current_info["src_block_ids"][
-len(current_info["dest_block_ids"]):]
current_info["current_layer_ids"] = 0
current_info["status"] = "init"
logger.info(
f"start cache_infos: {current_info}")
self.cache_info[info["request_id"]] = current_info
self.last_step_idx = min(
self.last_step_idx, current_info['current_id'])
else:
self.cache_info[info["request_id"]] = info
prefilled_layer_idx = layer_shm_value.value[0]
prefilled_step_idx = step_shm_value.value[0]
if prefilled_layer_idx == self.num_layers - 1:
time.sleep(0.001)
prefilled_layer_idx = layer_shm_value.value[0]
prefilled_step_idx = step_shm_value.value[0]
if prefilled_step_idx == -1:
time.sleep(0.001)
continue
if not self.cache_info:
time.sleep(0.001)
continue
logger.debug(
f"prefilled_layer_idx: {prefilled_layer_idx}, prefilled_step_idx: {prefilled_step_idx}"
)
for req_id, item in list(self.cache_info.items()):
if "status" not in item:
continue
if "layer_idx" not in item:
item["layer_idx"] = 0
if item['status'] == 'error':
del self.cache_info[req_id]
continue
if item['current_id'] > prefilled_step_idx:
continue
current_transfer_protocol = item["transfer_protocol"]
if item["transfer_protocol"] == "rdma":
target_ip = item['ip']
target_id = int(item['rdma_ports'][self.rank])
status = self.messager[
current_transfer_protocol].connect(
target_ip, target_id)
if not status:
logger.error(
f"connect to {target_ip}:{target_id} failed")
item["status"] = "error"
self.engine_worker_queue.finish_request_barrier.wait()
if self.rank == 0:
self.engine_worker_queue.put_finished_req([
(item['request_id'], "connect error")
])
continue
elif item["transfer_protocol"] == "ipc":
target_ip = "0.0.0.0"
target_id = int(item['device_ids'][self.rank])
src_block_ids = paddle.to_tensor(item['src_block_ids'],
dtype='int32',
place='cpu')
dest_block_ids = paddle.to_tensor(item['dest_block_ids'],
dtype='int32',
place='cpu')
if item['current_id'] < prefilled_step_idx:
current_layer_idx = self.num_layers
else:
current_layer_idx = prefilled_layer_idx + 1
for layer_idx in range(item["layer_idx"],
current_layer_idx):
tic = time.time()
return_code = self.messager[
current_transfer_protocol].write_cache(
target_ip, target_id, src_block_ids,
dest_block_ids, layer_idx)
if return_code != 0:
item["status"] = "error"
self.engine_worker_queue.finish_request_barrier.wait()
if self.rank == 0:
self.engine_worker_queue.put_finished_req([
(item['request_id'], "write cache error")
])
logger.error(
f"write cache failed, layer_idx: {layer_idx}, "
f"req_id: {item['request_id']}, dest_ip: {target_ip}"
)
break
tok = time.time()
cost_time = tok - tic
block_num = len(src_block_ids)
avg_time_per_block = cost_time * 1000 / block_num # ms
send_cache_speed = block_num * self.block_bytes / 1073741824 / cost_time # GB/s
logger.debug(
f"finish write cache for a layer, {item['request_id']}, {layer_idx}, "
f"{current_transfer_protocol}, "
f"block_num: {block_num}, send_cache_speed(GB/s): {round(send_cache_speed, 5)}, "
f"avg_time per block(ms): {round(avg_time_per_block, 5)}"
)
item['layer_idx'] = current_layer_idx
if item['layer_idx'] == self.num_layers:
if item["transfer_protocol"] == "ipc":
self.messager["ipc"].write_block_by_sync(target_id)
logger.info(f"finish write cache {item['request_id']}")
self.engine_worker_queue.finish_request_barrier.wait()
if self.rank == 0:
self.engine_worker_queue.put_finished_req([
(item['request_id'], "finished")
])
logger.info(f"put write cache {item['request_id']}")
del self.cache_info[req_id]
self.last_step_idx = prefilled_step_idx
self.last_layer_idx = prefilled_layer_idx
except Exception as e:
logger.error(
f"prefill layerwise send cache thread has exception: {e}")

View File

@@ -0,0 +1,137 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
from fastdeploy.utils import get_logger
logger = get_logger("prefix_cache_manager", "prefix_cache_manager.log")
class CacheMetrics:
"""
Cache Metrics used to record the cache hit time, token num, request num, etc.
"""
def __init__(self):
self.total_match_time = 0.0
self.avg_match_time = 0.0
self.min_match_time = 1e9
self.max_match_time = 0.0
# request level
self.req_count = 0
self.hit_req_count = 0
self.hit_req_ratio = 0.0
# token level
self.total_gpu_matched_token_num = 0
self.total_cpu_matched_token_num = 0
self.matched_token_num = 0
self.total_token_num = 0
self.hit_token_ratio = 0.0
self.cpu_hit_token_ratio = 0.0
self.gpu_hit_token_ratio = 0.0
def _update_history_hit_metrics(self):
"""
update hit ratio
"""
self.hit_req_ratio = self.hit_req_count / self.req_count
self.hit_token_ratio = self.matched_token_num / self.total_token_num
self.cpu_hit_token_ratio = (
self.total_cpu_matched_token_num / self.total_token_num
)
self.gpu_hit_token_ratio = (
self.total_gpu_matched_token_num / self.total_token_num
)
logger.info(
f"Metrics for all requests: req_count {self.req_count} hit_req_count {self.hit_req_count}"
+ f" hit_req_ratio {self.hit_req_ratio:.2f} hit_token_ratio {self.hit_token_ratio:.2f}"
+ f" gpu_hit_token_ratio {self.gpu_hit_token_ratio:.2f}"
+ f" cpu_hit_token_ratio {self.cpu_hit_token_ratio:.2f}"
+ f" total_gpu_matched_token_num {self.total_gpu_matched_token_num}"
+ f" total_cpu_matched_token_num {self.total_cpu_matched_token_num}"
+ f" total_matched_token_num {self.matched_token_num}"
+ f" total_token_num {self.total_token_num}"
)
def calculate_hit_metrics(
self,
req_id,
current_query_cpu_match_token_num,
current_query_gpu_match_token_num,
current_query_token_num,
):
"""
calculate hit metrics for current query
"""
cpu_cache_match_ratio = (
current_query_cpu_match_token_num / current_query_token_num
)
gpu_cache_match_ratio = (
current_query_gpu_match_token_num / current_query_token_num
)
total_match_ratio = (
cpu_cache_match_ratio + gpu_cache_match_ratio
)
self.total_cpu_matched_token_num += (
current_query_cpu_match_token_num
)
self.total_gpu_matched_token_num += (
current_query_gpu_match_token_num
)
self.matched_token_num += (
current_query_cpu_match_token_num
+ current_query_gpu_match_token_num
)
self.total_token_num += current_query_token_num
logger.info(
f"Metrics for req_id {req_id}: token_num {current_query_token_num}"
+ f" cpu_cache_match_ratio {cpu_cache_match_ratio}"
+ f" gpu_cache_match_ratio {gpu_cache_match_ratio}"
+ f" total_match_ratio {total_match_ratio}"
)
def reset_metrics(self):
"""
reset metrics
"""
self.total_match_time = 0.0
self.avg_match_time = 0.0
self.min_match_time = 1e9
self.max_match_time = 0.0
self.req_count = 0
self.hit_req_count = 0
self.hit_req_ratio = 0.0
self.total_gpu_matched_token_num = 0
self.total_cpu_matched_token_num = 0
self.matched_token_num = 0
self.total_token_num = 0
self.hit_token_ratio = 0.0
self.cpu_hit_token_ratio = 0.0
self.gpu_hit_token_ratio = 0.0
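A minimal usage sketch (not part of the commit) with illustrative numbers:

```python
metrics = CacheMetrics()

# one request: 128 prompt tokens, 48 matched in GPU cache, 16 in CPU cache
metrics.req_count += 1
metrics.hit_req_count += 1  # at least one token matched
metrics.calculate_hit_metrics(
    req_id="req-0",
    current_query_cpu_match_token_num=16,
    current_query_gpu_match_token_num=48,
    current_query_token_num=128,
)
metrics._update_history_hit_metrics()
assert metrics.hit_token_ratio == (16 + 48) / 128  # 0.5
```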

View File

@@ -0,0 +1,470 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
import argparse
import concurrent.futures
import json
import queue
import time
import numpy as np
import paddle
from fastdeploy.cache_manager.cache_data import CacheStatus
from fastdeploy.engine.config import SpeculativeConfig
from fastdeploy.inter_communicator import EngineCacheQueue, IPCSignal
from fastdeploy.model_executor.ops.gpu import (cuda_host_alloc, set_data_ipc,
swap_cache_all_layers)
from fastdeploy.utils import get_logger
def parse_args():
"""
Parse command-line arguments.
"""
parser = argparse.ArgumentParser("Cache transfer manager")
parser.add_argument("--splitwise_role",
type=str,
default="mixed",
help="splitwise role, can be decode, prefill or mixed")
parser.add_argument("--rank", type=int, default=0, help="current rank")
parser.add_argument("--device_id", type=int, default=0, help="device id")
parser.add_argument("--num_layers",
type=int,
default=1,
help="model num layers")
parser.add_argument("--head_dim",
type=int,
default=1,
help="model head dim")
parser.add_argument("--kv_num_head",
type=int,
default=1,
help="model kv num head")
parser.add_argument("--rdma_port", type=str, default="", help="rmda port")
parser.add_argument("--mp_num",
type=int,
default=1,
help="number of model parallel")
parser.add_argument("--protocol",
type=str,
default="ipc",
help="cache transfer protocol, only surport ipc now")
parser.add_argument("--enable_splitwise",
type=int,
default=0,
help="enable splitwise ")
parser.add_argument("--cache_queue_port",
type=int,
default=9923,
help="cache queue port")
parser.add_argument("--engine_worker_queue_port",
type=int,
default=9923,
help="engine worker queue port")
parser.add_argument("--engine_pid",
type=str,
default=None,
help="engine pid")
parser.add_argument("--num_gpu_blocks",
type=int,
default=1,
help="gpu cache block number")
parser.add_argument("--num_cpu_blocks",
type=int,
default=4,
help="cpu cache block number")
parser.add_argument("--block_size",
type=int,
default=64,
help="cache block size(tokens)")
parser.add_argument("--bytes_per_layer_per_block",
type=int,
default=1024,
help="per layer per block bytes")
parser.add_argument("--cache_dtype",
type=str,
default="bfloat16",
choices=["uint8", "bfloat16"],
help="cache dtype")
parser.add_argument("--speculative_config",
type=json.loads,
default="{}",
help="speculative config")
parser.add_argument("--local_data_parallel_id", type=int, default=0)
args = parser.parse_args()
return args
class CacheTransferManager:
"""
Manage swapping and transferring KV cache blocks between CPU and GPU.
"""
def __init__(self, args):
"""
Initialize the CacheTransferManager.
"""
device = args.device_id
rank = args.rank
paddle.set_device(f"gpu:{device}")
self.gpu_cache_kvs = {}
self.cpu_cache_kvs = {}
self.gpu_cache_k_tensors = []
self.gpu_cache_v_tensors = []
self.speculative_config = SpeculativeConfig(**args.speculative_config)
self.num_extra_layers = self.speculative_config.num_extra_cache_layer
self.num_extra_layer_gpu_blocks = \
int(args.num_gpu_blocks * \
self.speculative_config.num_gpu_block_expand_ratio)
self.swap_to_cpu_thread_pool = concurrent.futures.ThreadPoolExecutor(
max_workers=1)
self.swap_to_gpu_thread_pool = concurrent.futures.ThreadPoolExecutor(
max_workers=1)
self.transfer_task_queue = queue.Queue()  # receives transfer tasks
self.transfer_done_queue = queue.Queue()  # reports completed tasks
self.n_ranks = args.mp_num
self.rank = rank
self.device = device
address = ('0.0.0.0', args.cache_queue_port)
self.cache_task_queue = EngineCacheQueue(
address=address,
is_server=False,
num_client=args.mp_num,
client_id=rank,
local_data_parallel_id=args.local_data_parallel_id)
self.num_cpu_blocks = args.num_cpu_blocks
cache_type = args.cache_dtype
for i in range(args.num_layers + self.num_extra_layers):
num_gpu_blocks = args.num_gpu_blocks if i < args.num_layers else \
self.num_extra_layer_gpu_blocks
self.gpu_cache_kvs["key_caches_{}_rank{}_device{}".format(
i, rank, device)] = paddle.full(
shape=[
num_gpu_blocks,
args.kv_num_head,
args.block_size,
args.head_dim,
],
fill_value=0,
dtype=cache_type,
)
self.gpu_cache_k_tensors.append(
self.gpu_cache_kvs["key_caches_{}_rank{}_device{}".format(
i, rank, device)])
self.gpu_cache_kvs["value_caches_{}_rank{}_device{}".format(
i, rank, device)] = paddle.full(
shape=[
num_gpu_blocks,
args.kv_num_head,
args.block_size,
args.head_dim,
],
fill_value=0,
dtype=cache_type,
)
self.gpu_cache_v_tensors.append(
self.gpu_cache_kvs["value_caches_{}_rank{}_device{}".format(
i, rank, device)])
set_data_ipc(
self.gpu_cache_kvs["key_caches_{}_rank{}_device{}".format(
i, rank, device)],
"key_caches_{}_rank{}.device{}".format(i, rank, device))
set_data_ipc(
self.gpu_cache_kvs["value_caches_{}_rank{}_device{}".format(
i, rank, device)],
"value_caches_{}_rank{}.device{}".format(i, rank, device))
# total cache elements; equals bytes only for 1-byte dtypes
cache_kv_size_byte = sum(
[tmp.numel() * 1 for key, tmp in self.gpu_cache_kvs.items()])
logger.info(f"device :{self.device}")
logger.info(f"cache_kv_size_byte : {cache_kv_size_byte}")
logger.info(
f"done init cache (full) gmem alloc : {paddle.device.cuda.memory_allocated()}"
)
paddle.set_device("cpu")
self.k_dst_ptrs = []
self.v_dst_ptrs = []
for i in range(args.num_layers + self.num_extra_layers):
self.cpu_cache_kvs["key_caches_{}_rank{}".format(
i, rank)] = cuda_host_alloc(args.num_cpu_blocks *
args.bytes_per_layer_per_block)
self.k_dst_ptrs.append(
self.cpu_cache_kvs["key_caches_{}_rank{}".format(i, rank)])
self.cpu_cache_kvs["value_caches_{}_rank{}".format(
i, rank)] = cuda_host_alloc(args.num_cpu_blocks *
args.bytes_per_layer_per_block)
self.v_dst_ptrs.append(
self.cpu_cache_kvs["value_caches_{}_rank{}".format(i, rank)])
cache_ready_signal_data = np.zeros(shape=[args.mp_num], dtype=np.int32)
self.cache_ready_signal = IPCSignal(name="cache_ready_signal",
array=cache_ready_signal_data,
dtype=np.int32,
suffix=args.engine_pid,
create=False)
self.cache_ready_signal.value[self.rank] = 1
paddle.set_device(f"gpu:{device}")
if args.enable_splitwise:
logger.debug("create cache messager...")
logger.info(f"{args}")
from fastdeploy.cache_manager.cache_messager import CacheMessager
self.cache_messager = CacheMessager(
splitwise_role=args.splitwise_role,
transfer_protocol=args.protocol,
engine_worker_queue_port=args.engine_worker_queue_port,
local_data_parallel_id=args.local_data_parallel_id,
gpu_cache_kvs=self.gpu_cache_kvs,
rank=self.rank,
nranks=args.mp_num,
num_layers=args.num_layers + self.num_extra_layers,
gpu_id=self.device,
rdma_port=args.rdma_port,
)
logger.info("successfully create cache messager")
logger.info(
f"done init CacheMessager gmem alloc : {paddle.device.cuda.memory_allocated()}"
)
cache_task_broadcast_data = np.zeros(shape=[1], dtype=np.int32)
self.cache_task_broadcast_signal = IPCSignal(
name="cache_task_broadcast_signal",
array=cache_task_broadcast_data,
dtype=np.int32,
suffix=args.engine_pid,
create=False)
def _do_swap_to_cpu_task(self, swap_node_ids, gpu_block_id, cpu_block_id,
event_type, transfer_task_id):
"""
swap cache GPU->CPU
"""
self.cache_task_queue.swap_to_cpu_barrier1.wait()
if self.rank == 0:
self.cache_task_queue.swap_to_cpu_barrier1.reset()
result = self._transfer_data(
swap_node_ids,
gpu_block_id,
cpu_block_id,
event_type,
transfer_task_id,
)
self.cache_task_queue.swap_to_cpu_barrier2.wait()
if self.rank == 0:
self.cache_task_queue.swap_to_cpu_barrier2.reset()
self.cache_task_queue.put_transfer_done_signal(result)
logger.debug(
f"_do_swap_to_cpu_task: put_transfer_done_signal {result}")
logger.info(
f"_do_swap_to_cpu_task: put_transfer_done_signal for transfer_task_id {transfer_task_id}"
)
def _do_swap_to_gpu_task(self, swap_node_ids, gpu_block_id, cpu_block_id,
event_type, transfer_task_id):
"""
swap cache CPU->GPU
"""
self.cache_task_queue.swap_to_gpu_barrier1.wait()
if self.rank == 0:
self.cache_task_queue.swap_to_gpu_barrier1.reset()
result = self._transfer_data(
swap_node_ids,
gpu_block_id,
cpu_block_id,
event_type,
transfer_task_id,
)
self.cache_task_queue.swap_to_gpu_barrier2.wait()
if self.rank == 0:
self.cache_task_queue.swap_to_gpu_barrier2.reset()
self.cache_task_queue.put_transfer_done_signal(result)
logger.debug(
f"_do_swap_to_gpu_task: put_transfer_done_signal {result}")
logger.info(
f"_do_swap_to_gpu_task: put_transfer_done_signal for transfer_task_id {transfer_task_id}"
)
def do_data_transfer(self):
"""
do data transfer task
"""
while True:
try:
if self.rank == 0:
if not self.cache_task_queue.empty():
self.cache_task_broadcast_signal.value[0] = 1
if self.n_ranks > 1:
self.cache_task_queue.barrier1.wait()
if self.rank == 0:
self.cache_task_queue.barrier1.reset()
if self.cache_task_broadcast_signal.value[0] == 1:
data, read_finish = self.cache_task_queue.get_transfer_task(
)
logger.debug(f"transfer data: get_transfer_task {data}")
if read_finish:
self.cache_task_broadcast_signal.value[0] = 0
(
swap_node_ids,
gpu_block_id,
cpu_block_id,
event_type,
transfer_task_id,
) = data
if event_type.value == CacheStatus.SWAP2CPU.value:
self.swap_to_cpu_thread_pool.submit(
self._do_swap_to_cpu_task,
swap_node_ids,
gpu_block_id,
cpu_block_id,
event_type,
transfer_task_id,
)
else:
self.swap_to_gpu_thread_pool.submit(
self._do_swap_to_gpu_task,
swap_node_ids,
gpu_block_id,
cpu_block_id,
event_type,
transfer_task_id,
)
else:
if self.n_ranks > 1:
self.cache_task_queue.barrier2.wait()
if self.rank == 0:
self.cache_task_queue.barrier2.reset()
continue
if self.n_ranks > 1:
self.cache_task_queue.barrier3.wait()
if self.rank == 0:
self.cache_task_queue.barrier3.reset()
except Exception as e:
logger.info(f"do_data_transfer: error: {e}")
def _transfer_data(
self,
swap_node_ids,
task_gpu_block_id,
task_cpu_block_id,
event_type,
transfer_task_id,
):
"""
transfer data
task_gpu_block_id format: [[block_id0, [fold_block_id0, fold_block_id1]],
[block_id1, [fold_block_id0, fold_block_id1]], ...]
"""
logger.debug(
f"transfer data: transfer_task_id {transfer_task_id}: swap_node_ids {swap_node_ids}"
+
f"task_gpu_block_id {task_gpu_block_id} task_cpu_block_id {task_cpu_block_id} event_type {event_type}"
)
start_time = time.time()
try:
# transform block id
assert len(task_gpu_block_id) == len(task_cpu_block_id)
gpu_block_ids = task_gpu_block_id
cpu_block_ids = task_cpu_block_id
if event_type.value == CacheStatus.SWAP2CPU.value:
swap_cache_all_layers(
self.gpu_cache_k_tensors,
self.k_dst_ptrs,
self.num_cpu_blocks,
gpu_block_ids,
cpu_block_ids,
self.device,
0,
)
swap_cache_all_layers(
self.gpu_cache_v_tensors,
self.v_dst_ptrs,
self.num_cpu_blocks,
gpu_block_ids,
cpu_block_ids,
self.device,
0,
)
elif event_type.value == CacheStatus.SWAP2GPU.value:
swap_cache_all_layers(
self.gpu_cache_k_tensors,
self.k_dst_ptrs,
self.num_cpu_blocks,
gpu_block_ids,
cpu_block_ids,
self.device,
1,
)
swap_cache_all_layers(
self.gpu_cache_v_tensors,
self.v_dst_ptrs,
self.num_cpu_blocks,
gpu_block_ids,
cpu_block_ids,
self.device,
1,
)
else:
logger.warning(
f"transfer data: Get unexpected event type {event_type}, only SWAP2CPU and SWAP2GPU supported"
)
except Exception as e:
logger.error(f"transfer data: error: {e}")
raise e
end_time = time.time()
elapsed_time = end_time - start_time
logger.info(
f"transfer data: transfer_task_id {transfer_task_id} event_type {event_type}: "
+
f"transfer {len(gpu_block_ids)} blocks done elapsed_time {elasped_time:.4f}"
)
return (
swap_node_ids,
task_gpu_block_id,
task_cpu_block_id,
event_type,
transfer_task_id,
)
def main():
"""
Start the cache transfer manager.
"""
cache_manager = CacheTransferManager(args)
cache_manager.do_data_transfer()
if __name__ == "__main__":
args = parse_args()
logger = get_logger("cache_transfer_manager", "cache_transfer_manager.log")
main()
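A hedged launch sketch (not part of the commit); the script is normally spawned by the engine, and every argument value below is illustrative:

```python
import subprocess

subprocess.Popen([
    "python", "cache_transfer_manager.py",
    "--device_id", "0", "--rank", "0", "--mp_num", "1",
    "--num_layers", "32", "--head_dim", "128", "--kv_num_head", "8",
    "--num_gpu_blocks", "1024", "--num_cpu_blocks", "2048",
    "--block_size", "64", "--cache_dtype", "bfloat16",
    "--cache_queue_port", "9923", "--engine_pid", "12345",
])
```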

File diff suppressed because it is too large

View File

@@ -0,0 +1,17 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
from .ipc_cache_transfer import IPCCommManager
from .rdma_cache_transfer import RDMACommManager

View File

@@ -0,0 +1,133 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
import os
import paddle
from fastdeploy.model_executor.ops.gpu import (
get_data_ptr_ipc, ipc_sent_key_value_cache_by_remote_ptr,
ipc_sent_key_value_cache_by_remote_ptr_block_sync)
from fastdeploy.utils import get_logger
logger = get_logger("cache_messager", "cache_messager.log")
class IPCConnector:
"""
IPC communication class.
"""
def __init__(self, rank_id_, remote_gpu_id_, layer_num, local_gpu_id_):
"""
Args:
rank_id_: rank id
remote_gpu_id_: remote gpu id
"""
self.remote_key_tensor_ptr_list = []
self.remote_value_tensor_ptr_list = []
self.remote_gpu_id = int(remote_gpu_id_)
self.rank_id = rank_id_
self.local_gpu_id = int(local_gpu_id_)
tmp = paddle.ones([1, 1])  # placeholder tensor handed to get_data_ptr_ipc
logger.info(
f"init ipc rank{self.rank_id} with remote {self.remote_gpu_id} {self.local_gpu_id}"
)
for layer_id in range(layer_num):
key_unique_name = f"key_caches_{layer_id}_rank{self.rank_id}.device{self.remote_gpu_id}"
value_unique_name = f"value_caches_{layer_id}_rank{self.rank_id}.device{self.remote_gpu_id}"
self.remote_key_tensor_ptr_list.append(
get_data_ptr_ipc(tmp, key_unique_name))
self.remote_value_tensor_ptr_list.append(
get_data_ptr_ipc(tmp, value_unique_name))
self.write_stream = paddle.device.Stream(f'gpu:{self.local_gpu_id}')
self.finish_event = paddle.device.Event()
class IPCCommManager:
"""
IPC communication manager, used to initialize ipc and cache transmission.
"""
def __init__(
self,
rank_id_,
gpu_idx_,
local_key_cache_tensor_list, # tensor list
local_value_cache_tensor_list,  # tensor list
):
self.rank_id = rank_id_
self.gpu_idx = gpu_idx_
# local cache to tensor
self.local_key_cache_tensor_list = local_key_cache_tensor_list
self.local_value_cache_tensor_list = local_value_cache_tensor_list
self.layer_num = len(self.local_key_cache_tensor_list)
# record connected ipc info
self.comm_map = {}
def connect(self, remote_gpu_id_=0):
"""
Connect to remote gpu.
"""
logger.info(
f"{self.rank_id}: connect to remote_gpu_id:{remote_gpu_id_} {self.layer_num} {self.gpu_idx}"
)
if self.is_connected(remote_gpu_id_):
return True
else:
self.comm_map[remote_gpu_id_] = IPCConnector(
self.rank_id, remote_gpu_id_, self.layer_num, self.gpu_idx)
return True
def is_connected(self, remote_gpu_id_=0):
"""
Check if remote gpu is connected.
"""
return remote_gpu_id_ in self.comm_map
def write_cache(self, ip, remote_gpu_id, local_block_ids, remote_block_ids,
layer_idx):
"""
Connect to remote gpu and write cache.
"""
block_num = len(local_block_ids)
if not self.is_connected(remote_gpu_id):
self.connect(remote_gpu_id)
comm = self.comm_map[remote_gpu_id]
with paddle.device.stream_guard(comm.write_stream):
ipc_sent_key_value_cache_by_remote_ptr(
self.local_key_cache_tensor_list[layer_idx],
self.local_value_cache_tensor_list[layer_idx], local_block_ids,
remote_block_ids, comm.remote_key_tensor_ptr_list[layer_idx],
comm.remote_value_tensor_ptr_list[layer_idx], block_num,
self.gpu_idx, comm.remote_gpu_id,
comm.write_stream.stream_base.cuda_stream)
return 0
def write_block_by_sync(self, remote_gpu_id):
"""
check finish event and wait for it
"""
paddle.set_device(f'gpu:{self.gpu_idx}')
comm = self.comm_map[remote_gpu_id]
ipc_sent_key_value_cache_by_remote_ptr_block_sync(
self.local_key_cache_tensor_list[0],  # tensor unused, kept for API symmetry
self.local_value_cache_tensor_list[0],  # tensor unused, kept for API symmetry
comm.write_stream.stream_base.cuda_stream)
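A minimal usage sketch (not part of the commit); tensors and block ids are illustrative, and the remote GPU must already have exported its caches via set_data_ipc under the matching names:

```python
import paddle

num_layers = 2
keys = [paddle.zeros([8, 4, 64, 128], dtype="bfloat16") for _ in range(num_layers)]
values = [paddle.zeros([8, 4, 64, 128], dtype="bfloat16") for _ in range(num_layers)]

mgr = IPCCommManager(0, 0, keys, values)  # rank 0, local GPU 0
mgr.write_cache("0.0.0.0", 1, [0, 1], [2, 3], layer_idx=0)  # async writes
mgr.write_block_by_sync(1)  # wait for outstanding writes to remote GPU 1
```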

View File

@@ -0,0 +1,35 @@
cmake_minimum_required (VERSION 3.5)
project(rdma_comm LANGUAGES CXX)
set(PROJECT_SOURCE_DIR ${CMAKE_SOURCE_DIR})
set(CMAKE_BINARY_DIR ${CMAKE_SOURCE_DIR}/bin)
set(EXECUTABLE_OUTPUT_PATH ${CMAKE_BINARY_DIR})
set(LIBRARY_OUTPUT_PATH ${CMAKE_BINARY_DIR})
set(CMAKE_CXX_STANDARD 11)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_CXX_EXTENSIONS ON)
set(CMAKE_BUILD_TYPE Release)
set(CMAKE_CXX_COMPILER g++)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O2 -Ofast -ffast-math -funroll-loops -march=native -std=c++11")
add_compile_options("-std=c++11")
find_library(IBVERBS_LIBRARY ibverbs)
find_library(RDMACM_LIBRARY rdmacm)
find_package(pybind11 CONFIG REQUIRED)
include_directories("${PROJECT_SOURCE_DIR}/include")
add_library(rdma_comm MODULE ${PROJECT_SOURCE_DIR}/src/pybind.cpp ${PROJECT_SOURCE_DIR}/src/kvcache_rdma.cpp ${PROJECT_SOURCE_DIR}/src/kvcache_connection.cpp ${PROJECT_SOURCE_DIR}/src/log.cpp)
set_target_properties(rdma_comm PROPERTIES
OUTPUT_NAME "rdma_comm"
PREFIX ""
SUFFIX ".so"
)
target_link_libraries(rdma_comm PRIVATE pybind11::module)
target_link_libraries(rdma_comm PUBLIC ibverbs pthread)

View File

@@ -0,0 +1,232 @@
# KVTransferManager
A dedicated component for transferring KV Cache between Prefill and Decode nodes, supporting RDMA communication with ultra-low latency.
## Performance Benchmark
### KVTransferManager vs Mooncake Performance Comparison
### Test Scenario
- **Hardware Configuration**:
- Single Mellanox ConnectX-7 400G NIC (single port)
- Tested with BATCH_SIZE = 1538 and block size = 1K - 256K
- Single pressure thread (threads = 1)
- **Comparison Baseline**:
- Mooncake performance measured using transfer_engine_bench from the example directory
- Same hardware configuration and test parameters applied to KVTransferManager
### Performance Results
| Block Size | KVTransferManager | Mooncake | Performance Gain |
|------------|-----------------|----------|------------------|
| 1K | 10.67 GB/s | 1.54 GB/s | 6.9x |
| 2K | 17.53 GB/s | 3.40 GB/s | 5.2x |
| 4K | 28.85 GB/s | 6.95 GB/s | 4.2x |
| 8K | 36.56 GB/s | 12.48 GB/s | 2.9x |
| 16K | 41.73 GB/s | 23.42 GB/s | 1.8x |
| 32K | 43.55 GB/s | 31.58 GB/s | 1.4x |
| 64K | 44.46 GB/s | 38.39 GB/s | 1.2x |
| 128K | 44.86 GB/s | 40.11 GB/s | 1.1x |
| 256K | 45.01 GB/s | 40.71 GB/s | 1.1x |
Bandwidth Saturation Capability: Under multi-threaded high-pressure scenarios, both KVTransferManager and Mooncake can fully utilize the 400G network card bandwidth, achieving transmission performance close to the theoretical hardware limit (approximately 50 GB/s).
## Quick start
### Requirements
- Supported Architectures:
  - Hopper GPUs
  - Kunlun XPU
  - Ampere GPUs (supported by enabling KVCACHE_GDRCOPY_FLUSH_ENABLE)
### Dependencies Installation
#### Python Packages
```bash
pip install pyzmq pybind11[global]
```
#### System Libraries (Linux)
```bash
# Ubuntu/Debian
sudo apt-get install -y libibverbs-dev librdmacm-dev
# RHEL/CentOS
sudo yum install -y libibverbs-devel librdmacm-devel
```
#### Hardware Requirements
- RDMA-capable network hardware (e.g. Mellanox NICs)
- Supported GPU architectures: Hopper, Kunlun XPU, Ampere
#### Ampere Architecture Note
To support Ampere GPUs, enable the environment variable KVCACHE_GDRCOPY_FLUSH_ENABLE.
- What it does:
Forces a memory flush after GDRCopy write operations to ensure data consistency on the Ampere architecture. When KVCACHE_GDRCOPY_FLUSH_ENABLE is enabled, an RDMA read operation is triggered after the last RDMA write.
- Why it's needed:
When the NIC delivers a completion to the CPU, it indicates that the data has reached the GPU. However, it doesn't mean that the GPU can read that data yet. To make sure the data has gone all the way down to GPU memory and the GPU can read it, we need to perform a read.
[NCCL Issue #683](https://github.com/NVIDIA/nccl/issues/683) |
[NCCL Issue #1702](https://github.com/NVIDIA/nccl/issues/1702)
Since the upper layer typically issues a cache arrival notification only after polling a Completion Queue Entry (CQE), the application is not notified before the data is actually written back to memory. Therefore, the potential race condition where the cache has not yet been flushed but the application assumes completion is considered rare in practice.
- How to enable:
```bash
export KVCACHE_GDRCOPY_FLUSH_ENABLE=1
```
### Development
```bash
# Build and make symbolic links for SO files
python setup.py bdist_wheel
pip install dist/*.whl
```
## Environment Variables Configuration
### RDMA Settings
| Variable | Default | Description |
|----------|---------|-------------|
| `KVCACHE_RDMA_GID_INDEX` | 3 | RDMA GID index |
| `KVCACHE_RDMA_NICS` | - | RDMA NIC list, comma-separated (e.g. "mlx5_0,mlx5_1"); selects the IB device based on gpu_index. This environment variable must be set. NICs are selected by taking gpu_index modulo the list length (see the sketch below this table). |
| `KVCACHE_IB_TIMEOUT` | 18 | InfiniBand communication timeout (14-31), where timeout = 4.096μs * 2^value (default 18 ≈ 1.07s).|
| `KVCACHE_RELAX_ORDERING` | false | Enable RDMA relaxed ordering to improve performance in multi-GPU scenarios. Recommended when multiple GPUs share the same NIC to mitigate TX pause issues. |
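The modulo-based NIC selection described above can be sketched as follows (illustrative Python, not the library's actual implementation):

```python
import os

def pick_rdma_nic(gpu_index: int) -> str:
    # KVCACHE_RDMA_NICS must be set, e.g. "mlx5_0,mlx5_1"
    nics = os.environ["KVCACHE_RDMA_NICS"].split(",")
    return nics[gpu_index % len(nics)]

# with KVCACHE_RDMA_NICS=mlx5_0,mlx5_1, a worker on GPU 3 gets "mlx5_1"
```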
### Network Settings
| Variable | Default | Description |
|----------|---------|-------------|
| `KVCACHE_SOCKET_IFNAME` | auto | Network interface for socket comm (e.g. "eth0") |
### Debugging
| Variable | Default | Description |
|----------|---------|-------------|
| `KVCACHE_DEBUG` | false | Enable debug logging |
| `KVCACHE_DEBUG_FILE` | - | Debug log file path |
| `KVCACHE_ERROR_FILE` | - | Error log file path |
### Performance Tuning
| Variable | Default | Description |
|----------|---------|-------------|
| `KVCACHE_GDRCOPY_FLUSH_ENABLE` | false | Enable GDRCopy flush for Ampere GPUs |
Example configuration:
```bash
# Set RDMA GID index
export KVCACHE_RDMA_GID_INDEX=3
# Set RDMA IB Device List
export KVCACHE_RDMA_NICS=mlx5_0,mlx5_1,mlx5_2
# Specify network interface
export KVCACHE_SOCKET_IFNAME=eth0
# Enable debug mode
export KVCACHE_DEBUG=1
# Set log files
export KVCACHE_DEBUG_FILE=/var/log/kvcache_debug.log
export KVCACHE_ERROR_FILE=/var/log/kvcache_error.log
```
## Network configurations
KV cache transfer has been fully tested with RDMA over Converged Ethernet (RoCE) networks. It is theoretically compatible with InfiniBand as well.
For complete implementation details and advanced usage, please refer to the source code.
## Python API Reference
### RDMACommunicator Class
```python
from rdma_comm import RDMACommunicator
# Constructor
comm = RDMACommunicator(
role, # Role ("prefill" or "decode")
gpu_idx, # GPU device index
port, # Communication port
local_key_cache, # List of local key cache pointers
local_value_cache, # List of local value cache pointers
block_number, # Number of blocks
block_bytes # Bytes per block
)
# Methods
comm.connect(dst_ip, dst_port) # Connect to target IP and port
comm.is_connected(dst_ip, dst_port) # Check connection status
comm.write_cache(
ip, # Target server IP address
port, # Target server port number
local_block_ids, # List of local block IDs to transfer
remote_block_ids, # List of remote block IDs to write
layer_idx # Model layer index (for multi-layer models)
)
```
**Parameter Details**:
1. `role`:
- "prefill": Prefill node role
- "decode": Decode node role
2. `gpu_idx`:
- GPU device index to use
3. `port`:
- RDMA communication port number
4. `local_key_cache`/`local_value_cache`:
- List of local KV cache pointers
5. `block_number`:
- Number of cache blocks
6. `block_bytes`:
- Bytes per cache block
**Example Usage**:
```python
import numpy as np
from rdma_comm import RDMACommunicator
# Initialize (server side)
local_keys = [np.array([0]*1024, dtype=np.int64).ctypes.data] # Example key pointer
local_values = [np.array([0]*1024, dtype=np.int64).ctypes.data] # Example value pointer
comm = RDMACommunicator(
role="prefill",
gpu_idx=0,
port="12345",
local_key_cache=local_keys,
local_value_cache=local_values,
block_number=1024,
block_bytes=4096
)
# Client connection
comm = RDMACommunicator(
role="prefill",
gpu_idx=0,
port="12345",
local_key_cache=local_keys,
local_value_cache=local_values,
block_number=1024,
block_bytes=4096
)
if comm.connect("192.168.1.100", "12345"):
print("Connection established")
# Write cache
comm.write_cache(
ip="192.168.1.100", # Target server IP
port="12345", # Target server port
local_block_ids=[0,1,2], # Local block IDs to transfer
remote_block_ids=[3,4,5], # Remote block IDs to write
layer_idx=0 # Model layer index (0 for first layer)
)
```
## Citation
If you use this codebase, or otherwise found our work valuable, please cite:

View File

@@ -0,0 +1,232 @@
# KVTransferManager (Chinese Documentation)
A dedicated component for transferring KV Cache between Prefill and Decode nodes, supporting RDMA communication.
## Performance Benchmark
### KVTransferManager vs Mooncake Performance Comparison
### Test Scenario
- **Hardware Configuration**:
- Single Mellanox ConnectX-7 400G NIC (single port)
- Test parameters: BATCH_SIZE = 1538, block size = 1K - 256K
- Single pressure thread (threads = 1)
- **Comparison Baseline**:
- Mooncake performance measured using transfer_engine_bench from the example directory
- KVTransferManager tested with the same hardware configuration and test parameters
### Performance Results
| Block Size | KVTransferManager | Mooncake | Performance Gain |
|--------|-----------------|----------|----------|
| 1K | 10.67 GB/s | 1.54 GB/s | 6.9x |
| 2K | 17.53 GB/s | 3.40 GB/s | 5.2x |
| 4K | 28.85 GB/s | 6.95 GB/s | 4.2x |
| 8K | 36.56 GB/s | 12.48 GB/s | 2.9x |
| 16K | 41.73 GB/s | 23.42 GB/s | 1.8x |
| 32K | 43.55 GB/s | 31.58 GB/s | 1.4x |
| 64K | 44.46 GB/s | 38.39 GB/s | 1.2x |
| 128K | 44.86 GB/s | 40.11 GB/s | 1.1x |
| 256K | 45.01 GB/s | 40.71 GB/s | 1.1x |
Under multi-threaded high-pressure scenarios, both KVTransferManager and Mooncake can fully utilize the 400G NIC bandwidth, achieving transmission performance close to the theoretical hardware limit.
## Quick start
### Requirements
- Supported Architectures:
  - Hopper GPUs
  - Kunlun XPU
  - Ampere GPUs (requires enabling KVCACHE_GDRCOPY_FLUSH_ENABLE)
### Dependencies Installation
#### Python Packages
```bash
pip install pyzmq pybind11[global]
```
#### System Libraries (Linux)
```bash
# Ubuntu/Debian
sudo apt-get install -y libibverbs-dev librdmacm-dev
# RHEL/CentOS
sudo yum install -y libibverbs-devel librdmacm-devel
```
#### Hardware Requirements
- RDMA-capable network hardware (e.g. Mellanox NICs)
- Supported GPU architectures: Hopper, Kunlun XPU, Ampere
#### Ampere Architecture Note
To support Ampere GPUs, enable the environment variable KVCACHE_GDRCOPY_FLUSH_ENABLE.
- What it does:
Forces a memory flush after GDRCopy write operations to ensure data consistency on the Ampere architecture. When enabled, an RDMA read operation is triggered after the last RDMA write.
- Why it's needed:
When the NIC delivers a completion to the CPU, it only indicates that the data has reached the GPU, not that the GPU can read that data yet. To make sure the data has been fully written to GPU memory and is readable by the GPU, a read operation must be performed.
[NCCL Issue #683](https://github.com/NVIDIA/nccl/issues/683) |
[NCCL Issue #1702](https://github.com/NVIDIA/nccl/issues/1702)
Since the upper layer typically issues a cache arrival notification only after polling a Completion Queue Entry (CQE), the application is not notified before the data is actually written back to memory. The potential race where the cache has not yet been flushed but the application assumes completion is therefore considered rare in practice.
- How to enable:
```bash
export KVCACHE_GDRCOPY_FLUSH_ENABLE=1
```
### Development
```bash
# Build and create symbolic links for SO files
python setup.py bdist_wheel
pip install dist/*.whl
```
## Environment Variables Configuration
### RDMA Settings
| Variable | Default | Description |
|------|--------|------|
| `KVCACHE_RDMA_GID_INDEX` | 3 | RDMA GID index |
| `KVCACHE_RDMA_NICS` | - | RDMA NIC list, comma-separated (e.g. "mlx5_0,mlx5_1"); selects the IB device based on gpu_index. This environment variable must be set. NICs are chosen by taking gpu_index modulo the list length. |
| `KVCACHE_IB_TIMEOUT` | 18 | InfiniBand communication timeout (14-31); timeout = 4.096μs * 2^value (default 18 ≈ 1.07s) |
| `KVCACHE_RELAX_ORDERING` | false | Enable RDMA relaxed ordering to improve performance in multi-GPU scenarios. Recommended when multiple GPUs share the same NIC to mitigate TX pause issues. |
### Network Settings
| Variable | Default | Description |
|------|--------|------|
| `KVCACHE_SOCKET_IFNAME` | auto | Network interface for socket communication (e.g. "eth0"); if unset, the first available NIC is auto-detected |
### Debugging
| Variable | Default | Description |
|------|--------|------|
| `KVCACHE_DEBUG` | false | Enable debug logging |
| `KVCACHE_DEBUG_FILE` | - | Debug log file path |
| `KVCACHE_ERROR_FILE` | - | Error log file path |
### Performance Tuning
| Variable | Default | Description |
|------|--------|------|
| `KVCACHE_GDRCOPY_FLUSH_ENABLE` | false | Enable GDRCopy flush for Ampere GPUs |
Example configuration:
```bash
# Set RDMA GID index
export KVCACHE_RDMA_GID_INDEX=3
# Set RDMA IB device list
export KVCACHE_RDMA_NICS=mlx5_0,mlx5_1,mlx5_2
# Specify network interface
export KVCACHE_SOCKET_IFNAME=eth0
# Enable debug mode
export KVCACHE_DEBUG=1
# Set log files
export KVCACHE_DEBUG_FILE=/var/log/kvcache_debug.log
export KVCACHE_ERROR_FILE=/var/log/kvcache_error.log
```
## Network configurations
KV cache transfer has been fully tested with RDMA over Converged Ethernet (RoCE) networks. It is theoretically compatible with InfiniBand as well.
For complete implementation details and advanced usage, please refer to the source code.
## Python API Reference
### RDMACommunicator Class
```python
from rdma_comm import RDMACommunicator
# Constructor
comm = RDMACommunicator(
role, # Role ("prefill" or "decode")
gpu_idx, # GPU device index (0-7)
port, # Communication port
local_key_cache, # List of local key cache pointers
local_value_cache, # List of local value cache pointers
block_number, # Number of blocks
block_bytes # Bytes per block
)
# Methods
comm.connect(dst_ip, dst_port) # Connect to target IP and port
comm.is_connected(dst_ip, dst_port) # Check connection status
comm.write_cache(
ip, # Target server IP address
port, # Target server port number
local_block_ids, # Local block IDs to transfer
remote_block_ids, # Remote block IDs to write
layer_idx # Model layer index (for multi-layer models)
)
```
**Parameter Details**:
1. `role`:
- "prefill": Prefill node role
- "decode": Decode node role
2. `gpu_idx`:
- GPU device index to use
3. `port`:
- RDMA communication port number
4. `local_key_cache`/`local_value_cache`:
- List of local KV cache pointers
5. `block_number`:
- Number of cache blocks
6. `block_bytes`:
- Bytes per cache block
**Example Usage**:
```python
import numpy as np
from rdma_comm import RDMACommunicator
# Initialize
local_keys = [np.array([0]*1024, dtype=np.int64).ctypes.data] # Example key pointer
local_values = [np.array([0]*1024, dtype=np.int64).ctypes.data] # Example value pointer
comm = RDMACommunicator(
role="decode",
gpu_idx=0,
port="12345",
local_key_cache=local_keys,
local_value_cache=local_values,
block_number=1024,
block_bytes=4096
)
# Client-side initialization
comm = RDMACommunicator(
role="prefill",
gpu_idx=0,
port="12345",
local_key_cache=local_keys,
local_value_cache=local_values,
block_number=1024,
block_bytes=4096
)
if comm.connect("192.168.1.100", "12345"):
print("连接成功")
# 写入缓存
comm.write_cache(
ip="192.168.1.100", # 目标服务器IP
port="12345", # 目标服务器端口
local_block_ids=[0,1,2], # 要传输的本地块ID列表
remote_block_ids=[3,4,5], # 要写入的远程块ID列表
layer_idx=0 # 模型层索引(0表示第一层)
)
```
## Citation
If you use this codebase, or otherwise found our work valuable, please cite:

View File

@@ -0,0 +1,211 @@
/**
* @file kvcache_connection.h
* @brief RDMA connection management for key-value cache
* @version 1.0.0
* @copyright Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef FASTDEPLOY_KVCACHE_CONNECTION_H
#define FASTDEPLOY_KVCACHE_CONNECTION_H
#pragma once
#include <rdma/rdma_cma.h>
#include <rdma/rdma_verbs.h>
#include <sys/epoll.h>
#include <atomic>
#include <string>
#include <vector>
#include <netinet/tcp.h>
#include <netinet/in.h>
#include <sys/socket.h>
#include <sstream>
#include <netdb.h>
#include <sstream>
#include <netinet/in.h>
#include <netinet/tcp.h>
#include <sys/socket.h>
#include <cstring>
#include <netdb.h>
#include <arpa/inet.h>
#include <fcntl.h>
#include <net/if.h>
#include <sys/ioctl.h>
#include <netinet/in.h>
#include <unistd.h>
#include <cstring>
#include <memory>
#include <iostream>
#include "kvcache_rdma.h"
#include "util.h"
#define KVCACHE_RDMA_NIC_MAX_LEN 256
#define KVCACHE_RDMA_MAX_NICS 8
#define NAME_MAX 255
#define MAXNAMESIZE 64
#define NVML_DEVICE_PCI_BUS_ID_BUFFER_SIZE 16
/// @brief IB device information structure
struct IbDeviceInfo {
int device;
uint64_t guid;
enum ibv_mtu mtu;
uint64_t busid;
uint8_t port;
uint8_t link;
uint8_t active_mtu;
int speed;
ibv_context* context;
char devName[64];
int realPort;
int maxQp;
};
/// @brief Queue Pair information for RDMA
struct QpInfo {
uint32_t lid;
uint32_t qpn;
uint32_t psn;
union ibv_gid gid;
enum ibv_mtu mtu;
/// @brief Serialize QP info to buffer
void serialize(char* buffer) const {
uint32_t* intBuffer = reinterpret_cast<uint32_t*>(buffer);
intBuffer[0] = htonl(lid);
intBuffer[1] = htonl(qpn);
intBuffer[2] = htonl(psn);
memcpy(buffer + 12, gid.raw, sizeof(gid.raw));
intBuffer[7] = htonl(static_cast<uint32_t>(mtu));
}
/// @brief Deserialize QP info from buffer
void deserialize(const char* buffer) {
const uint32_t* intBuffer = reinterpret_cast<const uint32_t*>(buffer);
lid = ntohl(intBuffer[0]);
qpn = ntohl(intBuffer[1]);
psn = ntohl(intBuffer[2]);
memcpy(gid.raw, buffer + 12, sizeof(gid.raw));
mtu = static_cast<ibv_mtu>(ntohl(intBuffer[7]));
}
static const size_t size = 12 + sizeof(gid.raw) + 4;
};
/// @brief RDMA connection context
struct Connection {
std::atomic<int> connected;
// Memory regions
struct ibv_mr *recv_mr;
struct ibv_mr *send_mr;
// Cache pointers
std::vector<std::vector<void*>> local_cache_key_ptr_per_layer;
std::vector<std::vector<void*>> local_cache_value_ptr_per_layer;
// Memory region lists
std::vector<ibv_mr*> write_cache_key_server_mr_list;
std::vector<ibv_mr*> write_cache_value_server_mr_list;
std::vector<std::vector<ibv_mr*>> write_mr_key_list;
std::vector<std::vector<ibv_mr*>> write_mr_value_list;
// Remote access information
std::vector<void*> write_cache_key_remote_ptr_list;
std::vector<uint32_t> write_cache_key_remote_rkey_list;
std::vector<void*> write_cache_value_remote_ptr_list;
std::vector<uint32_t> write_cache_value_remote_rkey_list;
// Received remote memory information
std::vector<void*> receive_write_cache_key_remote_ptr_list;
std::vector<uint32_t> receive_write_cache_key_remote_rkey_list;
std::vector<void*> receive_write_cache_value_remote_ptr_list;
std::vector<uint32_t> receive_write_cache_value_remote_rkey_list;
std::vector<void *> send_write_cache_key_remote_ptr_list;
std::vector<uint32_t> send_write_cache_key_remote_rkey_list;
std::vector<void *> send_write_cache_value_remote_ptr_list;
std::vector<uint32_t> send_write_cache_value_remote_rkey_list;
// For rdma read operations
std::vector<void*> read_bufs;
std::vector<ibv_mr*> read_mrs;
// Work completion tracking
int wc_count;
int wc_target_count;
// Configuration
int layer_number;
int block_number;
int block_byte_size;
std::string url;
Connection() = default;
~Connection();
};
/// @brief RDMA context structure
struct RdmaContext {
int sock_fd;
struct ibv_context* context;
struct ibv_comp_channel* channel;
struct ibv_pd* pd;
struct ibv_mr* mr;
struct ibv_cq* cq;
struct ibv_qp* qp;
struct ibv_port_attr portinfo;
struct Connection conn;
};
// Global variables
extern std::vector<IbDeviceInfo> g_ib_all_devs;
static int g_kvcache_ib_dev_nums = -1;
// Connection management functions
bool client_exchange_destinations(
struct RdmaContext* ctx,
int ib_port,
unsigned int port,
int gidx,
const std::string& dst_ip);
int server_exchange_qp_info(int connfd, QpInfo* local_dest, QpInfo* rem_dest);
struct RdmaContext* create_qp(struct IbDeviceInfo* ib_dev, struct ibv_pd** g_pd);
bool clear_qp_info(struct RdmaContext* ctx);
// QP modification functions
QpStatus modify_qp_to_rts(struct RdmaContext* ctx, int port, int my_psn,
struct QpInfo* dest, int sgid_id);
bool poll_cq_with_timeout(struct RdmaContext* ctx, int timeout_seconds, int cqe_count);
// Utility functions
int get_port_info(struct ibv_context* Context, int port,
struct ibv_port_attr* attr);
int parse_port_ib_info();
// Memory region exchange
bool client_exchange_mr(struct RdmaContext* ctx);
bool server_exchange_mr(struct RdmaContext* ctx);
bool server_send_memory_region(struct RdmaContext *ctx, void *local_mr, int byte_num);
bool client_receive_memory_region(struct RdmaContext *ctx, void *remote_mr, int byte_num);
// Network setup
int setup_listening_socket(int port);
int configure_epoll(int sockfd);
std::vector<std::string> get_net_ifname();
#endif // FASTDEPLOY_KVCACHE_CONNECTION_H

View File

@@ -0,0 +1,127 @@
#ifndef KVCACHE_RDMA_H
#define KVCACHE_RDMA_H
#pragma once
#include <rdma/rdma_cma.h>
#include <vector>
#include <string>
#include <map>
#include <mutex>
#include "util.h" // Contains constant definitions
#include "kvcache_connection.h"
#include "log.h"
/**
* @brief RDMA communication handler for key-value cache
*/
class RDMACommunicator {
public:
// Construction/Destruction
RDMACommunicator(std::string &role, int gpu_idx, std::string &port,
std::vector<int64_t> local_key_cache,
std::vector<int64_t> local_value_cache,
int block_number, int block_bytes);
~RDMACommunicator();
// Connection management
int connect(const std::string &dst_ip, const std::string &dst_port);
bool is_connected(const std::string &dst_ip, const std::string &dst_port);
// Core functionality
int write_cache(const std::string &ip, const std::string &port,
const std::vector<int64_t>& local_block_ids,
const std::vector<int64_t>& remote_block_ids,
int32_t layer_idx);
// Server Init
int init_server();
// get socket nic ip
std::string fetch_local_ip();
private:
// Server Core functions
int start_server(int sport, int sgid_idx, int gpu_index);
// Internal implementation methods
void resize_vectors();
void assign_pointers();
void validate_addr();
bool client_mr_register_per_layer(struct RdmaContext *ctx);
bool server_mr_register_per_layer(struct RdmaContext *ctx);
struct ibv_mr* register_memory_region(ibv_pd* pd, void* addr, size_t size,
const std::string& desc, uint32_t access_flags);
bool deregister_memory_regions(struct RdmaContext* ctx);
bool post_block_send(struct RdmaContext* ctx, int layer_idx,
const std::vector<int64_t>& local_block_ids,
bool is_key, std::vector<uint64_t>& remote_addr,
uint32_t rkey, const std::string &ip,
const std::string &port);
bool execute_rdma_writes(struct RdmaContext* ctx, int layer_idx,
const std::vector<int64_t>& local_block_ids,
bool is_key, std::vector<uint64_t>& remote_addr,
uint32_t rkey);
void prepare_write_requests(struct ibv_sge* sge_list,
struct ibv_send_wr* send_wr_list,
int layer_idx,
const std::vector<int64_t>& local_block_ids,
bool is_key,
std::vector<uint64_t>& remote_addr,
uint32_t rkey);
bool execute_read_verification(struct RdmaContext* ctx,
size_t block_idx,
uint64_t remote_addr,
uint32_t rkey,
int layer_idx,
const std::string& ip,
const std::string& port);
bool post_send_with_retry(struct RdmaContext* ctx,
struct ibv_send_wr* wr_list,
size_t inflight_wr,
bool need_poll);
// Connection management
int client_listener();
void close_server_connection(int fd, struct RdmaContext* ctx, int epollfd,
std::map<int, struct RdmaContext*>& connectionContexts);
void close_client_connection(int fd, struct RdmaContext* ctx, int epollfd);
void remove_conn(const std::string& url);
struct RdmaContext *get_conn(const std::string &ip,
const std::string &port);
// Member variables
std::string splitwise_role; // Role in distributed system ("decode" or other)
int gpu_idx; // GPU device index
std::string port; // Communication port
std::vector<int64_t> local_cache_key_ptr_layer_head_; // Key cache pointers
std::vector<int64_t> local_cache_value_ptr_layer_head_; // Value cache pointers
int block_number; // Number of blocks
int block_size_byte; // Size of each block in bytes
int layer_number; // Number of layers
std::vector<std::vector<void*>> local_cache_key_ptr_per_layer; // Per-layer key pointers
std::vector<std::vector<void*>> local_cache_value_ptr_per_layer; // Per-layer value pointers
std::vector<struct ibv_mr*> write_mr_key_list; // Memory regions for key writes
std::vector<struct ibv_mr*> write_mr_value_list; // Memory regions for value writes
std::vector<struct ibv_mr*> write_cache_key_server_mr_list; // Server-side key memory regions
std::vector<struct ibv_mr*> write_cache_value_server_mr_list; // Server-side value memory regions
std::vector<std::string> main_ip_list; // List of local IP addresses
std::map<std::string, struct RdmaContext*> conn_map; // Active connections map
std::mutex mutex_; // Thread synchronization mutex
int rdma_event_channel_epoll_fd; // Epoll file descriptor
struct ibv_pd *g_pd = NULL; // Global protection domain
int RDMACommunicator_status; // Communicator status flag
bool start_client_listener = false; // Client listener flag
};
#endif // KVCACHE_RDMA_H
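A usage sketch from the prefill side, assuming (as the Python wrapper later in this commit suggests) that "decode" instances act as servers while "prefill" instances connect and push blocks, and that connect() and write_cache() return 0 on success; the peer address, pointer lists and sizes are placeholders:

int push_blocks_example(std::vector<int64_t> k_ptrs, std::vector<int64_t> v_ptrs) {
    std::string role = "prefill";
    std::string port = "0";  // the prefill side does not listen
    RDMACommunicator comm(role, /*gpu_idx=*/0, port, k_ptrs, v_ptrs,
                          /*block_number=*/1024, /*block_bytes=*/262144);
    if (!comm.is_connected("10.0.0.2", "20001") &&
        comm.connect("10.0.0.2", "20001") != 0) {
        return -1;  // peer unreachable
    }
    std::vector<int64_t> local_blocks{0, 1};
    std::vector<int64_t> remote_blocks{7, 8};
    return comm.write_cache("10.0.0.2", "20001", local_blocks, remote_blocks,
                            /*layer_idx=*/0);
}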

View File

@@ -0,0 +1,117 @@
#pragma once
/**
* @file log.h
* @brief Logging module for key-value cache system
* @version 1.0.0
* @copyright Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <stdio.h>
#include <string.h>
#include <time.h>
#include <sys/time.h>
#include <unistd.h> //for gethostname
#include <sys/syscall.h>
#include <pthread.h>
#include <string>
#include <ctime>
#include <chrono>
#include <cstdlib> // std::getenv, used by KV_IS_DEBUG_ENABLED
#define KV_IS_DEBUG_ENABLED (std::getenv("KVCACHE_DEBUG"))
#define FILE_NAME(x) (strrchr(x,'/') ? strrchr(x,'/')+1 : x)
static thread_local char __attribute__((__unused__)) str[64]; // per-thread scratch buffer for GET_CURRENT_TIME()
// Log severity levels
typedef enum {
KV_LOG_LEVEL_INFO = 0,
KV_LOG_LEVEL_DEBUG = 1,
KV_LOG_LEVEL_WARN = 2,
KV_LOG_LEVEL_ERROR = 3
} KVLogLevel;
void debug_log(KVLogLevel level, bool enable_to_terminal, const char *filefunc,
int line, const char *fmt, ...) __attribute__ ((format (printf, 5, 6)));
/**
* @brief Unified logging macro to reduce duplication and improve maintainability.
*
* @param level Log level (e.g., KV_LOG_LEVEL_INFO, KV_LOG_LEVEL_WARN, KV_LOG_LEVEL_ERROR).
* @param to_terminal If true, the log will be printed to terminal.
* @param ... Format string and arguments (like printf).
*/
#define KV_LOG(level, to_terminal, ...) \
debug_log(level, to_terminal, FILE_NAME(__FILE__), __LINE__, __VA_ARGS__)
// Public logging macros with terminal output
#define WARN(...) KV_LOG(KV_LOG_LEVEL_WARN, true, __VA_ARGS__)
#define ERR(...) KV_LOG(KV_LOG_LEVEL_ERROR, true, __VA_ARGS__)
#define DEBUG(...) KV_LOG(KV_LOG_LEVEL_DEBUG, true, __VA_ARGS__)
#define INFO(...) KV_LOG(KV_LOG_LEVEL_INFO, true, __VA_ARGS__)
#define gettid() ((pid_t)syscall(SYS_gettid))
#define GET_CURRENT_TIME() do { \
time_t timer = time(0); \
struct tm* t = localtime(&timer); \
char hostname[32]; \
gethostname(hostname, 32); \
sprintf(str, "%02d:%02d:%02d][%.32s][%d", \
t->tm_hour, t->tm_min, t->tm_sec, hostname, gettid()); \
} while (0)
#define LOGE(fmt, arg...) do { \
GET_CURRENT_TIME(); \
fprintf(stderr, "[%s][ERR][KV_CACHE][%s:%d] " \
fmt "\n",str, \
FILE_NAME(__FILE__), __LINE__, ## arg); \
} while (0)
#define LOGW(fmt, arg...) do { \
GET_CURRENT_TIME(); \
fprintf(stderr, "[%s][WARN][KV_CACHE][%s:%d] " \
fmt "\n",str, \
FILE_NAME(__FILE__), __LINE__, ## arg); \
} while (0)
#define LOGI(fmt, arg...) do { \
GET_CURRENT_TIME(); \
fprintf(stdout, "[%s][INFO][KV_CACHE][%s:%d] " \
fmt "\n",str, \
FILE_NAME(__FILE__), __LINE__, ## arg); \
} while (0)
#define LOGD(fmt, arg...) do { \
if (KV_IS_DEBUG_ENABLED) { \
GET_CURRENT_TIME(); \
fprintf(stdout, "[%s][DBG][KV_CACHE][%s:%d] " \
fmt "\n", str, \
FILE_NAME(__FILE__), __LINE__, ## arg); \
} \
} while (0)
#define LOGD_IF(cond, fmt, ...) do { \
if ((cond)) \
LOGD(fmt, ##__VA_ARGS__); \
} while (0)
#define LOGD_RAW(fmt, arg...) do { \
if (std::getenv("KV_IS_DEBUG_ENABLED")) { \
GET_CURRENT_TIME(); \
fprintf(stdout, "[%s][DBG][KV_CACHE][%s:%d] " \
fmt "\n", str, \
FILE_NAME(__FILE__), __LINE__, ## arg); \
} \
} while (0)
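Illustrative call sites for the macros above (setting KVCACHE_DEBUG in the environment enables the LOGD output; values are placeholders):

void log_example(int port, int ret) {
    INFO("server listening on port %d", port);
    WARN("retrying connect, attempt %d", 3);
    LOGI("bootstrap socket ready on port %d", port);
    LOGD("poll returned %d", ret);  // printed only when KVCACHE_DEBUG is set
    LOGD_IF(ret != 0, "non-zero poll result %d", ret);
}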

View File

@@ -0,0 +1,315 @@
#ifndef KVCACHE_UTILS_H
#define KVCACHE_UTILS_H
#include <ctime>
#include <chrono>
#include <iostream>
#include <string>
#include <cstdlib>
#include <algorithm>
#include <cctype>
#include <stdexcept>
#include <cstdio>
#include <ifaddrs.h>
#include <net/if.h>
#include <netinet/in.h>
#include <arpa/inet.h>
#include <vector>
#include <cstring>
#include "log.h"
#ifndef PATH_MAX
#define PATH_MAX 4096 /* max # chars in a path name including nul */
#endif
#define RDMA_WR_LIST_MAX_SIZE 32
#define RDMA_SQ_MAX_SIZE 1024
#define RDMA_DEFAULT_PORT 20001
#define RDMA_TCP_CONNECT_SIZE 1024
#define RDMA_POLL_CQE_TIMEOUT 30
/// @brief Connection status enumeration
enum class ConnStatus {
kConnected, // Connection is active
kDisconnected, // Connection is not active
kError, // Connection error occurred
kTimeout, // Connection timed out
kInvalidParameters // Invalid connection parameters
};
/// @brief Queue Pair (QP) setup result status
enum class QpStatus {
kSuccess, // Successfully transitioned QP to RTS
kInvalidParameters, // ctx or dest is null
kDeviceQueryFailed, // ibv_query_device failed
kPortQueryFailed, // ibv_query_port failed
kMtuMismatch, // Requested MTU exceeds active MTU
kModifyToRTRFailed, // Failed to modify QP to RTR
kModifyToRTSFailed // Failed to modify QP to RTS
};
/**
* @brief Convert PCI bus ID string to int64_t
* @param busId PCI bus ID string (e.g. "0000:3b:00.0")
* @param[out] id Converted numeric ID
*/
inline void busid_to_int64(const char *busId, int64_t *id) {
char hexStr[17] = {0};
int hexOffset = 0;
// Filter valid hex characters
for (int i = 0; hexOffset < sizeof(hexStr) - 1 && busId[i] != '\0'; i++) {
char c = busId[i];
if (c == '.' || c == ':') continue;
if ((c >= '0' && c <= '9') ||
(c >= 'A' && c <= 'F') ||
(c >= 'a' && c <= 'f')) {
hexStr[hexOffset++] = c;
}
}
*id = strtoll(hexStr, NULL, 16); // strtoll: the value may not fit in a 32-bit long
}
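// Worked example (illustrative values): for busId "0000:3b:00.0" the loop
// above keeps the hex digits "00003b000", so the conversion yields
// 0x3b000 (241664):
//   int64_t id = 0;
//   busid_to_int64("0000:3b:00.0", &id);  // id == 0x3b000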
class NetworkInterfaceManager {
public:
struct InterfaceInfo {
std::string name;
std::string ip;
bool is_up;
bool is_running;
bool is_loopback;
bool isUsable() const {
return is_up && is_running && !is_loopback;
}
};
static std::vector<InterfaceInfo> getAllInterfaces() {
std::vector<InterfaceInfo> interfaces;
struct ifaddrs *ifaddrs_ptr = nullptr;
if (getifaddrs(&ifaddrs_ptr) == -1) {
return interfaces;
}
for (struct ifaddrs *ifa = ifaddrs_ptr; ifa != nullptr; ifa = ifa->ifa_next) {
if (ifa->ifa_addr == nullptr) continue;
if (ifa->ifa_addr->sa_family != AF_INET) continue;
InterfaceInfo info;
info.name = ifa->ifa_name;
info.is_up = (ifa->ifa_flags & IFF_UP) != 0;
info.is_running = (ifa->ifa_flags & IFF_RUNNING) != 0;
info.is_loopback = (ifa->ifa_flags & IFF_LOOPBACK) != 0;
struct sockaddr_in* sa = (struct sockaddr_in*)ifa->ifa_addr;
char ip_str[INET_ADDRSTRLEN];
inet_ntop(AF_INET, &sa->sin_addr, ip_str, INET_ADDRSTRLEN);
info.ip = ip_str;
interfaces.push_back(info);
}
freeifaddrs(ifaddrs_ptr);
return interfaces;
}
static std::string getFirstUsableInterface() {
auto interfaces = getAllInterfaces();
for (const auto& iface : interfaces) {
if (iface.isUsable()) {
return iface.name;
}
}
return "";
}
static void displayAllInterfaces() {
auto interfaces = getAllInterfaces();
printf("Available network interfaces:\n");
for (const auto& iface : interfaces) {
printf(" %s: %s [%s%s%s]\n",
iface.name.c_str(),
iface.ip.c_str(),
iface.is_up ? "UP" : "DOWN",
iface.is_running ? ",RUNNING" : "",
iface.is_loopback ? ",LOOPBACK" : "");
}
}
};
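// Usage sketch (hypothetical helper, not referenced elsewhere in this commit):
// resolve the NIC for the bootstrap socket, dumping diagnostics on failure.
inline std::string pick_bootstrap_ifname_example() {
    std::string ifname = NetworkInterfaceManager::getFirstUsableInterface();
    if (ifname.empty()) {
        NetworkInterfaceManager::displayAllInterfaces();  // aid debugging
    }
    return ifname;
}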
class KVCacheConfig {
private:
// Configuration values
int rdma_gid_index_;
bool has_rdma_dest_port_override_; // stands in for std::optional<int>
int rdma_dest_port_override_;
const char* socket_interface_;
char* socket_interface_buffer_;
bool gdrcopy_flush_enabled_;
bool verify_read_enabled_;
bool debug_mode_enabled_;
bool debug_output_enabled_;
const char* debug_file_path_;
const char* error_file_path_;
bool relax_ordering_enabled_;
int ib_timeout_;
const char* rdma_nics_;
// Private constructor for singleton pattern
KVCacheConfig() {
// Initialize configuration from environment variables
socket_interface_buffer_ = nullptr; // only set when the interface is auto-detected
rdma_gid_index_ = parse_int_value(
std::getenv("KVCACHE_RDMA_GID_INDEX"), 3, "KVCACHE_RDMA_GID_INDEX");
// Parse optional RDMA port override
const char* port_value = std::getenv("SET_RDMA_DEST_PORT");
has_rdma_dest_port_override_ = false; // defaults to false
if (port_value) {
try {
rdma_dest_port_override_ = std::stoi(std::string(port_value));
has_rdma_dest_port_override_ = true;
} catch (const std::exception& e) {
fprintf(stderr, "Invalid SET_RDMA_DEST_PORT value: '%s', ignoring\n", port_value);
}
}
const char* env_interface = std::getenv("KVCACHE_SOCKET_IFNAME");
if (env_interface && env_interface[0] != '\0') {
socket_interface_ = env_interface;
printf("Using specified interface: %s\n", socket_interface_);
} else {
std::string iface = NetworkInterfaceManager::getFirstUsableInterface();
if (!iface.empty()) {
socket_interface_buffer_ = new char[iface.size() + 1];
std::strcpy(socket_interface_buffer_, iface.c_str());
socket_interface_ = socket_interface_buffer_;
printf("Auto-detected interface: %s\n", socket_interface_);
} else {
fprintf(stderr, "Warning: No usable network interface found\n");
socket_interface_ = "";
}
NetworkInterfaceManager::displayAllInterfaces();
}
debug_file_path_ = std::getenv("KVCACHE_DEBUG_FILE");
error_file_path_ = std::getenv("KVCACHE_ERROR_FILE");
gdrcopy_flush_enabled_ = parse_bool_value(std::getenv("KVCACHE_GDRCOPY_FLUSH_ENABLE"));
verify_read_enabled_ = parse_bool_value(std::getenv("KVCACHE_VERIFY_READ"));
debug_mode_enabled_ = parse_bool_value(std::getenv("KVCACHE_DEBUG")) ||
parse_bool_value(std::getenv("KV_IS_DEBUG_ENABLED"));
debug_output_enabled_ = parse_bool_value(std::getenv("KVCACHE_DEBUG_OUTPUT"));
relax_ordering_enabled_ = parse_bool_value(std::getenv("KVCACHE_RELAX_ORDERING"));
ib_timeout_ = parse_int_value(
std::getenv("KVCACHE_IB_TIMEOUT"),
18,
"KVCACHE_IB_TIMEOUT"
);
rdma_nics_ = std::getenv("KVCACHE_RDMA_NICS");
}
// Helper methods
bool parse_bool_value(const char* value) {
if (!value) return false;
std::string str_value(value);
std::transform(str_value.begin(), str_value.end(), str_value.begin(), ::tolower);
return (str_value == "1" || str_value == "true" ||
str_value == "on" || str_value == "yes");
}
int parse_int_value(const char* value, int default_value, const char* env_name) {
if (!value) return default_value;
try {
return std::stoi(std::string(value));
} catch (const std::invalid_argument& e) {
fprintf(stderr, "Invalid value for %s: '%s', using default: %d\n",
env_name, value, default_value);
return default_value;
} catch (const std::out_of_range& e) {
fprintf(stderr, "%s value out of range: '%s', using default: %d\n",
env_name, value, default_value);
return default_value;
}
}
public:
// Prevent copying and assignment
KVCacheConfig(const KVCacheConfig&) = delete;
KVCacheConfig& operator=(const KVCacheConfig&) = delete;
// Get singleton instance
static KVCacheConfig& getInstance() {
static KVCacheConfig instance;
return instance;
}
int get_ib_timeout() const { return ib_timeout_; }
// Configuration retrieval methods
int get_rdma_gid_index() const { return rdma_gid_index_; }
int resolve_rdma_dest_port(int default_port) const {
return has_rdma_dest_port_override_ ? rdma_dest_port_override_ : default_port;
}
int resolve_rdma_dest_port(const std::string& default_port) const {
try {
return resolve_rdma_dest_port(std::stoi(default_port));
} catch (const std::exception& e) {
fprintf(stderr, "Invalid default port string: %s\n", default_port.c_str());
return 0;
}
}
const char* get_socket_interface() const { return socket_interface_; }
const char* get_debug_file_path() const { return debug_file_path_; }
const char* get_error_file_path() const { return error_file_path_; }
const char* get_rdma_nics() const { return rdma_nics_; }
// Feature check methods
bool is_gdrcopy_flush_enabled() const { return gdrcopy_flush_enabled_; }
bool is_verify_read_enabled() const { return verify_read_enabled_; }
bool is_debug_mode_enabled() const { return debug_mode_enabled_; }
bool is_debug_output_enabled() const { return debug_output_enabled_; }
bool is_relax_ordering_enabled() const { return relax_ordering_enabled_; }
// Display configuration
void displayConfiguration() const {
// debug_log() appends the trailing newline, so these format strings omit "\n"
INFO("KVCache Configuration:");
INFO("Init KVCacheConfig RDMA GID Index: %d", rdma_gid_index_);
if (has_rdma_dest_port_override_) {
INFO("Init KVCacheConfig RDMA Destination Port Override: %d", rdma_dest_port_override_);
}
if (socket_interface_) {
INFO("Init KVCacheConfig Socket Interface: %s", socket_interface_);
}
INFO("Init KVCacheConfig GDRCopy Flush: %s", gdrcopy_flush_enabled_ ? "enabled" : "disabled");
INFO("Init KVCacheConfig Verify Read: %s", verify_read_enabled_ ? "enabled" : "disabled");
INFO("Init KVCacheConfig Debug Mode: %s", debug_mode_enabled_ ? "enabled" : "disabled");
INFO("Init KVCacheConfig Debug Output: %s", debug_output_enabled_ ? "enabled" : "disabled");
if (debug_file_path_) {
INFO("Init KVCacheConfig Debug File: %s", debug_file_path_);
}
if (error_file_path_) {
INFO("Init KVCacheConfig Error File: %s", error_file_path_);
}
}
};
#endif
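A sketch of how call sites are expected to consume the singleton above; the helper name is hypothetical, and the INFO macro comes from log.h (already included by util.h):

#include "util.h"

static void log_effective_config_example() {
    KVCacheConfig &cfg = KVCacheConfig::getInstance();
    int gid_index = cfg.get_rdma_gid_index();  // KVCACHE_RDMA_GID_INDEX, default 3
    int dest_port = cfg.resolve_rdma_dest_port(RDMA_DEFAULT_PORT);
    INFO("gid_index=%d dest_port=%d ifname=%s",
         gid_index, dest_port, cfg.get_socket_interface());
    if (cfg.is_debug_mode_enabled()) {
        cfg.displayConfiguration();
    }
}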

File diff suppressed because it is too large

View File

@@ -0,0 +1,212 @@
/**
* @file log.cpp
* @brief Logging module implementation for key-value cache system
* @version 1.0.0
* @copyright Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <stdlib.h>
#include <stdarg.h>
#include <sys/syscall.h>
#include <sys/stat.h>
#include <libgen.h>
#include <errno.h>
#include <string.h>
#include <strings.h> // strcasecmp
#include "log.h"
#include "util.h"
static int pid = -1;
static __thread int tid = -1;
static char hostname[64];
char global_log_last_error[1024] = "";
FILE *global_debug_file = stdout;
FILE *global_error_file = stdout;
static char global_debug_file_name[PATH_MAX+1] = "";
static char global_err_file_name[PATH_MAX+1] = "";
int global_debug_level = -1;
pthread_mutex_t global_debug_lock = PTHREAD_MUTEX_INITIALIZER;
pthread_mutex_t global_log_file_lock = PTHREAD_MUTEX_INITIALIZER;
void log_file_init(FILE **kv_cache_log_file, const char *kv_cache_log_file_env, char *logFileName) {
int c = 0;
char *dfn = logFileName;
while (c < PATH_MAX && kv_cache_log_file_env[c] != '\0') {
if (kv_cache_log_file_env[c++] != '%') {
*dfn++ = kv_cache_log_file_env[c - 1];
continue;
}
switch (kv_cache_log_file_env[c++]) {
case '%': // Double %
*dfn++ = '%';
break;
case 'h': // %h = hostname
// bound by the space remaining in logFileName, not the full PATH_MAX
dfn += snprintf(dfn, PATH_MAX - (dfn - logFileName), "%s", hostname);
break;
case 'p': // %p = pid
dfn += snprintf(dfn, PATH_MAX - (dfn - logFileName), "%d", pid);
break;
default: // Echo everything we don't understand
*dfn++ = '%';
*dfn++ = kv_cache_log_file_env[c - 1];
break;
}
}
*dfn = '\0';
if (logFileName[0] != '\0') {
FILE *file = fopen(logFileName, "w");
if (file != nullptr) {
setbuf(file, nullptr); // disable buffering
*kv_cache_log_file = file;
}
}
}
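// For example (illustrative values): with hostname "node1" and pid 4242, the
// template "/tmp/kv_%h_%p.log" expands to "/tmp/kv_node1_4242.log", and "%%"
// produces a literal '%'.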
void recreate_log_file(FILE **kv_cache_log_file, char *logFileName) {
if (logFileName[0] != '\0') {
pthread_mutex_lock(&global_log_file_lock);
FILE *file = fopen(logFileName, "a"); // "a" appends if the file exists, otherwise creates it
if (file != NULL) {
setbuf(file, NULL); // disable buffering
// close the previous log file only once the new one is ready,
// so the stream pointer never ends up NULL mid-logging
if (*kv_cache_log_file != NULL && *kv_cache_log_file != file) {
fclose(*kv_cache_log_file);
}
*kv_cache_log_file = file;
}
pthread_mutex_unlock(&global_log_file_lock);
}
}
void debug_init() {
pthread_mutex_lock(&global_debug_lock);
if (global_debug_level != -1) {
pthread_mutex_unlock(&global_debug_lock);
return;
}
const char* kv_cache_debug = std::getenv("KV_IS_DEBUG_ENABLED");
int tempg_kv_cache_debug_level = -1;
if (kv_cache_debug == NULL) {
tempg_kv_cache_debug_level = KV_LOG_LEVEL_INFO;
} else if (strcasecmp(kv_cache_debug, "0") == 0) {
tempg_kv_cache_debug_level = KV_LOG_LEVEL_INFO;
} else if (strcasecmp(kv_cache_debug, "1") == 0) {
tempg_kv_cache_debug_level = KV_LOG_LEVEL_DEBUG;
} else if (strcasecmp(kv_cache_debug, "2") == 0) {
tempg_kv_cache_debug_level = KV_LOG_LEVEL_WARN;
} else if (strcasecmp(kv_cache_debug, "3") == 0) {
tempg_kv_cache_debug_level = KV_LOG_LEVEL_ERROR;
} else {
tempg_kv_cache_debug_level = KV_LOG_LEVEL_INFO;
}
gethostname(hostname, sizeof(hostname));
hostname[sizeof(hostname) - 1] = '\0'; // gethostname() may leave the name unterminated on truncation
pid = getpid();
const char* g_kv_cache_debug_fileEnv = KVCacheConfig::getInstance().get_debug_file_path();
if (tempg_kv_cache_debug_level >= KV_LOG_LEVEL_INFO && g_kv_cache_debug_fileEnv != NULL) {
log_file_init(&global_debug_file, g_kv_cache_debug_fileEnv, global_debug_file_name);
}
const char* g_kv_cache_error_fileEnv = KVCacheConfig::getInstance().get_error_file_path();
if (tempg_kv_cache_debug_level >= KV_LOG_LEVEL_INFO && g_kv_cache_error_fileEnv != NULL) {
log_file_init(&global_error_file, g_kv_cache_error_fileEnv, global_err_file_name);
char buffer[1024];
size_t len = 0;
char timeBuffer[80]; // Buffer to hold the formatted time
std::time_t absoluteTime = std::chrono::system_clock::to_time_t(std::chrono::system_clock::now());
std::strftime(timeBuffer, sizeof(timeBuffer), "%Y-%m-%d %H:%M:%S", std::localtime(&absoluteTime));
len = snprintf(buffer, sizeof(buffer), "%s KV_CACHE START ", timeBuffer);
buffer[len++] = '\n';
if (global_error_file != NULL) {
fwrite(buffer, 1, len, global_error_file);
}
}
__atomic_store_n(&global_debug_level, tempg_kv_cache_debug_level, __ATOMIC_RELEASE);
pthread_mutex_unlock(&global_debug_lock);
}
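// Level selection summary (from the branches above): KV_IS_DEBUG_ENABLED
// unset or "0" -> INFO, "1" -> DEBUG, "2" -> WARN, "3" -> ERROR, anything
// else falls back to INFO.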
/* Common logging function used by the INFO, DEBUG and WARN macros
* Also exported to the dynamically loadable Net transport modules so
* they can share the debugging mechanisms and output files
*/
void debug_log(KVLogLevel level, bool enable_to_terminal, const char *filefunc, int line, const char *fmt, ...) {
if (__atomic_load_n(&global_debug_level, __ATOMIC_ACQUIRE) == -1) {
debug_init();
}
// Save the last error (WARN) as a human readable string
if (level == KV_LOG_LEVEL_WARN) {
pthread_mutex_lock(&global_debug_lock);
va_list vargs;
va_start(vargs, fmt);
(void) vsnprintf(global_log_last_error, sizeof(global_log_last_error), fmt, vargs);
va_end(vargs);
pthread_mutex_unlock(&global_debug_lock);
}
if (tid == -1) {
tid = syscall(SYS_gettid);
}
char buffer[1024];
size_t len = 0;
// Convert timestamp to absolute time and directly use it in the snprintf function
std::time_t absoluteTime = std::chrono::system_clock::to_time_t(std::chrono::system_clock::now());
char timeBuffer[80]; // Buffer to hold the formatted time
std::strftime(timeBuffer, sizeof(timeBuffer), "%Y-%m-%d %H:%M:%S", std::localtime(&absoluteTime));
if (level == KV_LOG_LEVEL_WARN) {
len = snprintf(buffer, sizeof(buffer), "\n%s %s:%d:%d %s:%d KV_CACHE WARN ",
timeBuffer, hostname, pid, tid, filefunc, line);
} else if (level == KV_LOG_LEVEL_INFO) {
len = snprintf(buffer, sizeof(buffer), "%s %s:%d:%d KV_CACHE INFO ", timeBuffer, hostname, pid, tid);
} else if (level == KV_LOG_LEVEL_DEBUG) {
len = snprintf(buffer, sizeof(buffer), "%s %s:%d:%d KV_CACHE DEBUG ", timeBuffer, hostname, pid, tid);
} else if (level == KV_LOG_LEVEL_ERROR) {
len = snprintf(buffer, sizeof(buffer), "%s %s:%d:%d KV_CACHE ERROR ", timeBuffer, hostname, pid, tid);
} else {
len = snprintf(buffer, sizeof(buffer), "%s %s:%d:%d KV_CACHE ", timeBuffer, hostname, pid, tid);
}
if (len) {
va_list vargs;
va_start(vargs, fmt);
len += vsnprintf(buffer + len, sizeof(buffer) - len, fmt, vargs);
va_end(vargs);
// vsnprintf may return len >= sizeof(buffer) in the case of a truncated output.
// Rewind len so that the trailing '\n' still lands inside the buffer.
if (len >= sizeof(buffer)) {
len = sizeof(buffer) - 1;
}
buffer[len++] = '\n';
if (access(global_debug_file_name, F_OK) != 0) {
recreate_log_file(&global_debug_file, global_debug_file_name);
}
if (enable_to_terminal) {
fwrite(buffer, 1, len, global_debug_file);
}
if (level == KV_LOG_LEVEL_WARN && global_error_file != stdout) {
if (access(global_err_file_name, F_OK) != 0) {
recreate_log_file(&global_error_file, global_err_file_name);
}
if (global_error_file != NULL) {
fwrite(buffer, 1, len, global_error_file);
}
}
}
}

View File

@@ -0,0 +1,22 @@
#include "kvcache_connection.h"
#include "kvcache_rdma.h"
#include "pybind11/pybind11.h"
#include "pybind11/stl.h"
namespace py = pybind11;
PYBIND11_MODULE(rdma_comm, m) {
m.doc() = R"pbdoc(kv cache messager)pbdoc";
py::class_<RDMACommunicator>(m, "RDMACommunicator")
.def(py::init<std::string &, int, std::string &, std::vector<int64_t>,
std::vector<int64_t>, int, int>())
.def("connect", &RDMACommunicator::connect)
.def("is_connected", &RDMACommunicator::is_connected)
.def("write_cache", &RDMACommunicator::write_cache);
#ifdef VERSION_INFO
m.attr("__version__") = VERSION_INFO;
#else
m.attr("__version__") = "dev";
#endif
}

View File

@@ -0,0 +1,76 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
from fastdeploy.utils import get_logger
logger = get_logger("cache_messager", "cache_messager.log")
class RDMACommManager:
"""
RDMACommManager to manage rdma communication
"""
def __init__(self, splitwise_role, rank, gpu_id, cache_k_ptr_list,
cache_v_ptr_list, max_block_num, block_bytes, rdma_port):
try:
import rdma_comm
except ImportError:
logger.error("Failed to import the rdma_comm extension. "
"Confirm that your network card supports RDMA transmission.")
return
self.messager = rdma_comm.RDMACommunicator(
splitwise_role,
rank,
str(rdma_port) if splitwise_role == "decode" else "0",
cache_k_ptr_list,
cache_v_ptr_list,
max_block_num,
block_bytes,
)
self.splitwise_role = splitwise_role
self.connected_rdma = set()
logger.info(f"init rdma messager {gpu_id} {rdma_port}")
def connect(self, ip, port):
"""
Connect to remote gpu and write cache.
"""
assert self.splitwise_role == "prefill", "only prefill can call this method"
addr = f"{ip}:{str(port)}"
if addr in self.connected_rdma:
return True
ret = self.messager.is_connected(ip, str(port))
if ret:
self.connected_rdma.add(addr)
return True
ret = self.messager.connect(ip, str(port))
logger.info(
f"connect to remote rdma address {ip}:{port} status is {ret}")
if ret == 0:
self.connected_rdma.add(addr)
return ret == 0
def write_cache(self, ip, port, local_block_ids, remote_block_ids,
layer_idx):
"""
Connect to remote gpu and write cache.
"""
return self.messager.write_cache(ip, str(port), local_block_ids,
remote_block_ids, layer_idx)