mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-05 08:37:06 +08:00
Sync v2.0 version of code to github repo
This commit is contained in:
15
fastdeploy/cache_manager/__init__.py
Normal file
15
fastdeploy/cache_manager/__init__.py
Normal file
@@ -0,0 +1,15 @@
|
||||
"""
|
||||
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
162
fastdeploy/cache_manager/cache_data.py
Normal file
162
fastdeploy/cache_manager/cache_data.py
Normal file
@@ -0,0 +1,162 @@
|
||||
"""
|
||||
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
|
||||
from enum import Enum
|
||||
|
||||
from fastdeploy.utils import get_logger
|
||||
|
||||
logger = get_logger("prefix_cache_manager", "prefix_cache_manager.log")
|
||||
|
||||
|
||||
class CacheStatus(Enum):
|
||||
"""
|
||||
cache status enum class
|
||||
"""
|
||||
|
||||
GPU = 0
|
||||
SWAP2CPU = 1
|
||||
SWAP2GPU = 2
|
||||
CPU = 3
|
||||
|
||||
|
||||
class BlockNode:
|
||||
"""
|
||||
BlockNode: store the information of a block node
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
node_id,
|
||||
input_ids,
|
||||
input_hash_value,
|
||||
depth,
|
||||
block_id,
|
||||
token_num,
|
||||
hash_value,
|
||||
last_used_time,
|
||||
parent=None,
|
||||
shared_count=1,
|
||||
reverved_dec_block_ids=[],
|
||||
cache_status=CacheStatus.GPU,
|
||||
is_persistent=False,
|
||||
persistent_shared_count=0,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
node_id: Unique identifier of the node
|
||||
depth: Depth of the node
|
||||
block_id: Assigned block ID (CPU block ID if on CPU, GPU block ID if on GPU)
|
||||
token_num: Number of tokens in the current block
|
||||
hash_value: Hash value of the current block
|
||||
last_used_time: Timestamp of last usage
|
||||
parent: Parent node
|
||||
shared_count: Reference count of requests currently using this node
|
||||
reserved_dec_block_ids: Pre-allocated block IDs reserved for decoding, formatted as [block_id, block_id,...]
|
||||
cache_status: Current cache state (USING, SWAP2CPU, SWAP2GPU, FREE)
|
||||
is_persistent: Whether the node is persistently stored
|
||||
persistent_shared_count: Reference count of persistent cache requests
|
||||
"""
|
||||
|
||||
self.node_id = node_id
|
||||
self.depth = depth
|
||||
self.parent = parent
|
||||
self.hash_value = hash_value
|
||||
self.token_num = token_num
|
||||
self.input_ids = input_ids
|
||||
self.input_hash_value = input_hash_value
|
||||
|
||||
self.children = {}
|
||||
self.shared_count = shared_count
|
||||
self.last_used_time = last_used_time
|
||||
self.block_id = block_id
|
||||
self.reverved_dec_block_ids = reverved_dec_block_ids
|
||||
self.cache_status = cache_status
|
||||
self.is_persistent = is_persistent
|
||||
self.persistent_shared_count = persistent_shared_count
|
||||
self.req_id_set = set()
|
||||
|
||||
def __lt__(self, other):
|
||||
"""
|
||||
override the less than operator
|
||||
"""
|
||||
if self.last_used_time < other.last_used_time:
|
||||
return True
|
||||
elif self.last_used_time > other.last_used_time:
|
||||
return False
|
||||
else:
|
||||
return self.depth > other.depth
|
||||
|
||||
def __str__(self):
|
||||
"""
|
||||
return node info
|
||||
"""
|
||||
if self.parent is not None:
|
||||
parent_node_id = self.parent.node_id
|
||||
else:
|
||||
parent_node_id = None
|
||||
return (
|
||||
f"node_id {self.node_id}: depth {self.depth} hash_value {self.hash_value}"
|
||||
+
|
||||
f" shared_count {self.shared_count} is_gpu_leaf_node {self.is_gpu_leaf_node}"
|
||||
+
|
||||
f" is_cpu_leaf_node {self.is_cpu_leaf_node} block_id {self.block_id} "
|
||||
+ f"has_in_gpu {self.has_in_gpu} " +
|
||||
f"cache_status {self.cache_status} parent {parent_node_id} with children number "
|
||||
+ f"{len(self.children)} req_id_set {self.req_id_set}")
|
||||
|
||||
@property
|
||||
def has_in_gpu(self):
|
||||
"""
|
||||
check if the node has been allocated in GPU
|
||||
"""
|
||||
return self.cache_status == CacheStatus.GPU
|
||||
|
||||
def increment_shared_count(self):
|
||||
"""
|
||||
increment shared count
|
||||
"""
|
||||
self.shared_count += 1
|
||||
|
||||
def decrement_shared_count(self):
|
||||
"""
|
||||
decrement shared count
|
||||
"""
|
||||
self.shared_count -= 1
|
||||
|
||||
@property
|
||||
def is_cpu_leaf_node(self):
|
||||
"""
|
||||
check if the node is a leaf node in CPU
|
||||
"""
|
||||
if (self.cache_status == CacheStatus.CPU) and (len(self.children)
|
||||
== 0):
|
||||
return True
|
||||
return False
|
||||
|
||||
@property
|
||||
def is_gpu_leaf_node(self):
|
||||
"""
|
||||
check if the node is a leaf node in GPU
|
||||
"""
|
||||
if self.has_in_gpu is False:
|
||||
return False
|
||||
else:
|
||||
if len(self.children) == 0:
|
||||
return True
|
||||
for child in self.children.values():
|
||||
if child.has_in_gpu is True:
|
||||
return False
|
||||
return True
|
318
fastdeploy/cache_manager/cache_messager.py
Normal file
318
fastdeploy/cache_manager/cache_messager.py
Normal file
@@ -0,0 +1,318 @@
|
||||
"""
|
||||
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
|
||||
import math
|
||||
import threading
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
import paddle
|
||||
|
||||
from fastdeploy.cache_manager.transfer_factory import (IPCCommManager,
|
||||
RDMACommManager)
|
||||
from fastdeploy.inter_communicator import EngineWorkerQueue, IPCSignal
|
||||
from fastdeploy.utils import get_logger
|
||||
|
||||
logger = get_logger("cache_messager", "cache_messager.log")
|
||||
|
||||
|
||||
class CacheMessager(object):
|
||||
"""
|
||||
CacheMessager is used to send the cache data between the engine worker and the cache server.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
splitwise_role,
|
||||
transfer_protocol,
|
||||
engine_worker_queue_port,
|
||||
local_data_parallel_id,
|
||||
gpu_cache_kvs,
|
||||
rank,
|
||||
nranks,
|
||||
num_layers,
|
||||
gpu_id=0,
|
||||
rdma_port=None):
|
||||
"""
|
||||
Initialize the CacheMessager object.
|
||||
|
||||
Args:
|
||||
splitwise_role (str): splitwise_role only can be 'prefill' or 'decode'.
|
||||
transfer_protocol (str): support ipc and rdma
|
||||
engine_worker_queue_port (int): engine_worker_queue port
|
||||
gpu_cache_kvs (dict): GPU kv cache
|
||||
rank (int): current rank
|
||||
nranks (int): global rank number
|
||||
num_layers (int): model layer number
|
||||
gpu_id (int, optional): GPU ID
|
||||
rdma_port (int, optional): RDMA port
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
|
||||
assert splitwise_role in ["prefill", "decode"], \
|
||||
"splitwise_role must be prefill or decode"
|
||||
self.splitwise_role = splitwise_role
|
||||
self.gpu_cache_kvs = gpu_cache_kvs
|
||||
self.rank = rank
|
||||
self.nranks = nranks
|
||||
address = ('0.0.0.0', engine_worker_queue_port)
|
||||
self.engine_worker_queue = EngineWorkerQueue(
|
||||
address=address,
|
||||
is_server=False,
|
||||
num_client=self.nranks,
|
||||
client_id=self.rank,
|
||||
local_data_parallel_id=local_data_parallel_id)
|
||||
transfer_protocol = transfer_protocol.split(",")
|
||||
|
||||
logger.info(f"splitwise role: {splitwise_role}, {transfer_protocol}"
|
||||
f"rank: {rank}")
|
||||
|
||||
# 1. initialize the cache_k_ptr_list and cache_v_ptr_list
|
||||
self.num_layers = num_layers
|
||||
cache_k_ptr_list = []
|
||||
cache_v_ptr_list = []
|
||||
cache_k = []
|
||||
cache_v = []
|
||||
self.messager = {}
|
||||
for layer_idx in range(self.num_layers):
|
||||
key_cache = self.gpu_cache_kvs[
|
||||
f'key_caches_{layer_idx}_rank{self.rank}_device{gpu_id}']
|
||||
val_cache = self.gpu_cache_kvs[
|
||||
f'value_caches_{layer_idx}_rank{self.rank}_device{gpu_id}']
|
||||
cache_k.append(key_cache)
|
||||
cache_v.append(val_cache)
|
||||
cache_k_ptr_list.append(key_cache.data_ptr())
|
||||
cache_v_ptr_list.append(val_cache.data_ptr())
|
||||
cache_k_ptr_list = np.array(cache_k_ptr_list)
|
||||
cache_v_ptr_list = np.array(cache_v_ptr_list)
|
||||
|
||||
# 2. initialize the block_bytes
|
||||
cache_shape = key_cache.shape
|
||||
max_block_num = cache_shape[0]
|
||||
block_bytes = math.prod(cache_shape[1:])
|
||||
if key_cache.dtype == paddle.bfloat16:
|
||||
block_bytes *= 2
|
||||
logger.info(
|
||||
f"layers {num_layers} cache_shape: {cache_shape}, max_block_num: {max_block_num}, "
|
||||
f"block_bytes: {block_bytes}, dtype: {key_cache.dtype}")
|
||||
self.block_bytes = block_bytes
|
||||
|
||||
# 3. initialize the messager
|
||||
for protocol in transfer_protocol:
|
||||
if protocol == "ipc":
|
||||
self.messager[protocol] = IPCCommManager(
|
||||
self.rank,
|
||||
gpu_id,
|
||||
cache_k,
|
||||
cache_v,
|
||||
)
|
||||
local_device_id = int(str(cache_k[0].place)[-2])
|
||||
logger.info(
|
||||
f"done create ipc_comm with local_device_id:{local_device_id}, "
|
||||
)
|
||||
|
||||
elif protocol == "rdma":
|
||||
logger.info(
|
||||
f"splitwise_role rdma: {self.splitwise_role}, rank: {self.rank}, gpu_id: {gpu_id}"
|
||||
)
|
||||
|
||||
self.messager[protocol] = RDMACommManager(
|
||||
splitwise_role, rank, gpu_id, cache_k_ptr_list,
|
||||
cache_v_ptr_list, max_block_num, block_bytes, rdma_port)
|
||||
|
||||
self.gpu_id = gpu_id
|
||||
self.cache_info = dict()
|
||||
|
||||
layerwise_send_cache_thread = threading.Thread(
|
||||
target=self._prefill_layerwise_send_cache_thread)
|
||||
layerwise_send_cache_thread.daemon = True
|
||||
layerwise_send_cache_thread.start()
|
||||
|
||||
logger.info(f"cache messager init finished, use {transfer_protocol}")
|
||||
|
||||
def _prefill_layerwise_send_cache_thread(self):
|
||||
"""
|
||||
layerwise_send_cache_thread:
|
||||
send cache to other instance
|
||||
"""
|
||||
try:
|
||||
prefilled_step_idx_data = np.zeros(shape=[1], dtype=np.int32)
|
||||
prefilled_layer_idx_data = np.zeros(shape=[1], dtype=np.int32)
|
||||
try:
|
||||
step_shm_value = IPCSignal(
|
||||
name=f"splitwise_complete_prefilled_step_{self.rank}",
|
||||
array=prefilled_step_idx_data,
|
||||
dtype=np.int32,
|
||||
suffix=self.gpu_id,
|
||||
create=True)
|
||||
layer_shm_value = IPCSignal(
|
||||
name=f"splitwise_complete_prefilled_layer_{self.rank}",
|
||||
array=prefilled_layer_idx_data,
|
||||
dtype=np.int32,
|
||||
suffix=self.gpu_id,
|
||||
create=True)
|
||||
except:
|
||||
step_shm_value = IPCSignal(
|
||||
name=f"splitwise_complete_prefilled_step_{self.rank}",
|
||||
array=prefilled_step_idx_data,
|
||||
dtype=np.int32,
|
||||
suffix=self.gpu_id,
|
||||
create=False)
|
||||
layer_shm_value = IPCSignal(
|
||||
name=f"splitwise_complete_prefilled_layer_{self.rank}",
|
||||
array=prefilled_layer_idx_data,
|
||||
dtype=np.int32,
|
||||
suffix=self.gpu_id,
|
||||
create=False)
|
||||
|
||||
step_shm_value.value[0] = -1
|
||||
layer_shm_value.value[0] = -1
|
||||
|
||||
self.last_step_idx = -1
|
||||
self.last_layer_idx = -1 # int32
|
||||
|
||||
while True:
|
||||
|
||||
cache_info = self.engine_worker_queue.get_cache_info()
|
||||
|
||||
if cache_info:
|
||||
logger.debug(f"cache info {cache_info}")
|
||||
for info in cache_info:
|
||||
if info['request_id'] in self.cache_info:
|
||||
self.cache_info[info["request_id"]].update(info)
|
||||
current_info = self.cache_info[info["request_id"]]
|
||||
if "dest_block_ids" in current_info and "src_block_ids" in current_info:
|
||||
current_src_blocks = current_info[
|
||||
"src_block_ids"][-len(current_info["dest_block_ids"]):]
|
||||
current_info[
|
||||
"src_block_ids"] = current_src_blocks
|
||||
current_info["current_layer_ids"] = 0
|
||||
current_info["status"] = "init"
|
||||
logger.info(
|
||||
f"start cache_infos: {current_info}")
|
||||
self.cache_info[info["request_id"]] = current_info
|
||||
self.last_step_idx = min(
|
||||
self.last_step_idx, current_info['current_id'])
|
||||
else:
|
||||
self.cache_info[info["request_id"]] = info
|
||||
prefilled_layer_idx = layer_shm_value.value[0]
|
||||
prefilled_step_idx = step_shm_value.value[0]
|
||||
if prefilled_layer_idx == self.num_layers - 1:
|
||||
time.sleep(0.001)
|
||||
prefilled_layer_idx = layer_shm_value.value[0]
|
||||
prefilled_step_idx = step_shm_value.value[0]
|
||||
|
||||
if prefilled_step_idx == -1:
|
||||
time.sleep(0.001)
|
||||
continue
|
||||
if not self.cache_info:
|
||||
time.sleep(0.001)
|
||||
continue
|
||||
logger.debug(
|
||||
f"prefilled_layer_idx: {prefilled_layer_idx}, prefilled_step_idx: {prefilled_step_idx}"
|
||||
)
|
||||
for req_id, item in list(self.cache_info.items()):
|
||||
if "status" not in item:
|
||||
continue
|
||||
if "layer_idx" not in item:
|
||||
item["layer_idx"] = 0
|
||||
if item['status'] == 'error':
|
||||
del self.cache_info[req_id]
|
||||
continue
|
||||
if item['current_id'] > prefilled_step_idx:
|
||||
continue
|
||||
current_transfer_protocol = item["transfer_protocol"]
|
||||
if item["transfer_protocol"] == "rdma":
|
||||
target_ip = item['ip']
|
||||
target_id = int(item['rdma_ports'][self.rank])
|
||||
status = self.messager[
|
||||
current_transfer_protocol].connect(
|
||||
target_ip, target_id)
|
||||
if not status:
|
||||
logger.error(
|
||||
f"connect to {target_ip}:{target_id} failed")
|
||||
item["status"] = "error"
|
||||
self.engine_worker_queue.finish_request_barrier.wait()
|
||||
if self.rank == 0:
|
||||
self.engine_worker_queue.put_finished_req([
|
||||
(item['request_id'], "connect error")
|
||||
])
|
||||
continue
|
||||
elif item["transfer_protocol"] == "ipc":
|
||||
target_ip = "0.0.0.0"
|
||||
target_id = int(item['device_ids'][self.rank])
|
||||
src_block_ids = paddle.to_tensor(item['src_block_ids'],
|
||||
dtype='int32',
|
||||
place='cpu')
|
||||
dest_block_ids = paddle.to_tensor(item['dest_block_ids'],
|
||||
dtype='int32',
|
||||
place='cpu')
|
||||
if item['current_id'] < prefilled_step_idx:
|
||||
current_layer_idx = self.num_layers
|
||||
else:
|
||||
current_layer_idx = prefilled_layer_idx + 1
|
||||
|
||||
for layer_idx in range(item["layer_idx"],
|
||||
current_layer_idx):
|
||||
tic = time.time()
|
||||
return_code = self.messager[
|
||||
current_transfer_protocol].write_cache(
|
||||
target_ip, target_id, src_block_ids,
|
||||
dest_block_ids, layer_idx)
|
||||
if return_code != 0:
|
||||
item["status"] = "error"
|
||||
self.engine_worker_queue.finish_request_barrier.wait()
|
||||
if self.rank == 0:
|
||||
self.engine_worker_queue.put_finished_req([
|
||||
(item['request_id'], "write cache error")
|
||||
])
|
||||
logger.error(
|
||||
f"write cache failed, layer_idx: {layer_idx}, "
|
||||
f"req_id: {item['request_id']}, dest_ip: {target_ip}"
|
||||
)
|
||||
break
|
||||
|
||||
tok = time.time()
|
||||
cost_time = tok - tic
|
||||
block_num = len(src_block_ids)
|
||||
avg_time_per_block = cost_time * 1000 / block_num # ms
|
||||
send_cache_speed = block_num * self.block_bytes / 1073741824 / cost_time # GB/s
|
||||
logger.debug(
|
||||
f"finish write cache for a layer, {item['request_id']}, {layer_idx}"
|
||||
f" {current_transfer_protocol}"
|
||||
f"block_num: {block_num}, send_cache_speed(GB/s): {round(send_cache_speed, 5)},"
|
||||
f"avg_time per block(ms): {round(avg_time_per_block, 5)}"
|
||||
)
|
||||
item['layer_idx'] = current_layer_idx
|
||||
if item['layer_idx'] == self.num_layers:
|
||||
if item["transfer_protocol"] == "ipc":
|
||||
self.messager["ipc"].write_block_by_sync(target_id)
|
||||
logger.info(f"finish write cache {item['request_id']}")
|
||||
self.engine_worker_queue.finish_request_barrier.wait()
|
||||
if self.rank == 0:
|
||||
self.engine_worker_queue.put_finished_req([
|
||||
(item['request_id'], "finished")
|
||||
])
|
||||
logger.info(f"put write cache {item['request_id']}")
|
||||
del self.cache_info[req_id]
|
||||
|
||||
self.last_step_idx = prefilled_step_idx
|
||||
self.last_layer_idx = prefilled_layer_idx
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"prefill layerwise send cache thread has exception: {e}")
|
137
fastdeploy/cache_manager/cache_metrics.py
Normal file
137
fastdeploy/cache_manager/cache_metrics.py
Normal file
@@ -0,0 +1,137 @@
|
||||
"""
|
||||
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
|
||||
|
||||
from fastdeploy.utils import get_logger
|
||||
|
||||
logger = get_logger("prefix_cache_manager", "prefix_cache_manager.log")
|
||||
|
||||
|
||||
|
||||
|
||||
class CacheMetrics:
|
||||
"""
|
||||
Cache Metrics used to record the cache hit time, token num, request num, etc.
|
||||
"""
|
||||
def __init__(self):
|
||||
self.total_match_time = 0.0
|
||||
self.avg_match_time = 0.0
|
||||
self.min_match_time = 1e9
|
||||
self.max_match_time = 0.0
|
||||
|
||||
# request level
|
||||
self.req_count = 0
|
||||
self.hit_req_count = 0
|
||||
self.hit_req_ratio = 0.0
|
||||
|
||||
# token level
|
||||
self.total_gpu_matched_token_num = 0
|
||||
self.total_cpu_matched_token_num = 0
|
||||
|
||||
self.matched_token_num = 0
|
||||
self.total_token_num = 0
|
||||
self.hit_token_ratio = 0.0
|
||||
self.cpu_hit_token_ratio = 0.0
|
||||
self.gpu_hit_token_ratio = 0.0
|
||||
|
||||
|
||||
def _update_history_hit_metrics(self):
|
||||
"""
|
||||
update hit ratio
|
||||
"""
|
||||
self.hit_req_ratio = self.hit_req_count / self.req_count
|
||||
self.hit_token_ratio = self.matched_token_num / self.total_token_num
|
||||
self.cpu_hit_token_ratio = (
|
||||
self.total_cpu_matched_token_num / self.total_token_num
|
||||
)
|
||||
self.gpu_hit_token_ratio = (
|
||||
self.total_gpu_matched_token_num / self.total_token_num
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Metrics for all requests: req_count {self.req_count} hit_req_count {self.hit_req_count}"
|
||||
+ f" hit_req_ratio {self.hit_req_ratio:.2f} hit_token_ratio {self.hit_token_ratio:.2f}"
|
||||
+ f" gpu_hit_token_ratio {self.gpu_hit_token_ratio:.2f}"
|
||||
+ f" cpu_hit_token_ratio {self.cpu_hit_token_ratio:.2f}"
|
||||
+ f" total_gpu_matched_token_num {self.total_gpu_matched_token_num}"
|
||||
+ f" total_cpu_matched_token_num {self.total_cpu_matched_token_num}"
|
||||
+ f" total_matched_token_num {self.matched_token_num}"
|
||||
+ f" total_token_num {self.total_token_num}"
|
||||
)
|
||||
|
||||
def calculate_hit_metrics(
|
||||
self,
|
||||
req_id,
|
||||
current_query_cpu_match_token_num,
|
||||
current_query_gpu_match_token_num,
|
||||
current_query_token_num,
|
||||
):
|
||||
"""
|
||||
calculate hit metrics for current query
|
||||
"""
|
||||
|
||||
cpu_cache_match_ratio = (
|
||||
current_query_cpu_match_token_num / current_query_token_num
|
||||
)
|
||||
gpu_cache_match_ratio = (
|
||||
current_query_gpu_match_token_num / current_query_token_num
|
||||
)
|
||||
|
||||
total_match_ratio = (
|
||||
cpu_cache_match_ratio + gpu_cache_match_ratio
|
||||
)
|
||||
|
||||
|
||||
self.total_cpu_matched_token_num += (
|
||||
current_query_cpu_match_token_num
|
||||
)
|
||||
self.total_gpu_matched_token_num += (
|
||||
current_query_gpu_match_token_num
|
||||
)
|
||||
|
||||
self.matched_token_num += (
|
||||
current_query_cpu_match_token_num
|
||||
+ current_query_gpu_match_token_num
|
||||
)
|
||||
self.total_token_num += current_query_token_num
|
||||
logger.info(
|
||||
f"Metrics for req_id {req_id}: token_num {current_query_token_num}"
|
||||
+ f" cpu_cache_match_ratio {cpu_cache_match_ratio}"
|
||||
+ f" gpu_cache_match_ratio {gpu_cache_match_ratio}"
|
||||
+ f" total_match_ratio {total_match_ratio}"
|
||||
)
|
||||
|
||||
def reset_metrics(self):
|
||||
"""
|
||||
reset metrics
|
||||
"""
|
||||
self.total_match_time = 0.0
|
||||
self.avg_match_time = 0.0
|
||||
self.min_match_time = 1e9
|
||||
self.max_match_time = 0.0
|
||||
|
||||
self.req_count = 0
|
||||
self.hit_req_count = 0
|
||||
self.hit_req_ratio = 0.0
|
||||
|
||||
self.total_gpu_matched_token_num = 0
|
||||
self.total_cpu_matched_token_num = 0
|
||||
|
||||
self.matched_token_num = 0
|
||||
self.total_token_num = 0
|
||||
self.hit_token_ratio = 0.0
|
||||
self.cpu_hit_token_ratio = 0.0
|
||||
self.gpu_hit_token_ratio = 0.0
|
470
fastdeploy/cache_manager/cache_transfer_manager.py
Normal file
470
fastdeploy/cache_manager/cache_transfer_manager.py
Normal file
@@ -0,0 +1,470 @@
|
||||
"""
|
||||
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import concurrent.futures
|
||||
import json
|
||||
import queue
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
import paddle
|
||||
|
||||
from fastdeploy.cache_manager.cache_data import CacheStatus
|
||||
from fastdeploy.engine.config import SpeculativeConfig
|
||||
from fastdeploy.inter_communicator import EngineCacheQueue, IPCSignal
|
||||
from fastdeploy.model_executor.ops.gpu import (cuda_host_alloc, set_data_ipc,
|
||||
swap_cache_all_layers)
|
||||
from fastdeploy.utils import get_logger
|
||||
|
||||
|
||||
def parse_args():
|
||||
"""
|
||||
从命令行解析参数
|
||||
"""
|
||||
parser = argparse.ArgumentParser("Cache transfer manager")
|
||||
parser.add_argument("--splitwise_role",
|
||||
type=str,
|
||||
default="mixed",
|
||||
help="splitwise role, can be decode, prefill or mixed")
|
||||
parser.add_argument("--rank", type=int, default=0, help="current rank")
|
||||
parser.add_argument("--device_id", type=int, default=0, help="device id")
|
||||
parser.add_argument("--num_layers",
|
||||
type=int,
|
||||
default=1,
|
||||
help="model num layers")
|
||||
parser.add_argument("--head_dim",
|
||||
type=int,
|
||||
default=1,
|
||||
help="model head dim")
|
||||
parser.add_argument("--kv_num_head",
|
||||
type=int,
|
||||
default=1,
|
||||
help="model kv num head")
|
||||
parser.add_argument("--rdma_port", type=str, default="", help="rmda port")
|
||||
parser.add_argument("--mp_num",
|
||||
type=int,
|
||||
default=1,
|
||||
help="number of model parallel")
|
||||
parser.add_argument("--protocol",
|
||||
type=str,
|
||||
default="ipc",
|
||||
help="cache transfer protocol, only surport ipc now")
|
||||
parser.add_argument("--enable_splitwise",
|
||||
type=int,
|
||||
default=0,
|
||||
help="enable splitwise ")
|
||||
parser.add_argument("--cache_queue_port",
|
||||
type=int,
|
||||
default=9923,
|
||||
help="cache queue port")
|
||||
parser.add_argument("--engine_worker_queue_port",
|
||||
type=int,
|
||||
default=9923,
|
||||
help="engine worker queue port")
|
||||
parser.add_argument("--engine_pid",
|
||||
type=str,
|
||||
default=None,
|
||||
help="engine pid")
|
||||
|
||||
parser.add_argument("--num_gpu_blocks",
|
||||
type=int,
|
||||
default=1,
|
||||
help="gpu cache block number")
|
||||
parser.add_argument("--num_cpu_blocks",
|
||||
type=int,
|
||||
default=4,
|
||||
help="cpu cache block number")
|
||||
parser.add_argument("--block_size",
|
||||
type=int,
|
||||
default=64,
|
||||
help="cache block size(tokens)")
|
||||
parser.add_argument("--bytes_per_layer_per_block",
|
||||
type=int,
|
||||
default=1024,
|
||||
help="per layer per block bytes")
|
||||
parser.add_argument("--cache_dtype",
|
||||
type=str,
|
||||
default="bfloat16",
|
||||
choices=["uint8", "bfloat16"],
|
||||
help="cache dtype")
|
||||
parser.add_argument("--speculative_config",
|
||||
type=json.loads,
|
||||
default="{}",
|
||||
help="speculative config")
|
||||
parser.add_argument("--local_data_parallel_id", type=int, default=0)
|
||||
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
|
||||
class CacheTransferManager:
|
||||
"""
|
||||
管理CPU和GPU之间缓存的交换传输
|
||||
"""
|
||||
|
||||
def __init__(self, args):
|
||||
"""
|
||||
初始化CacheTransferManager
|
||||
"""
|
||||
|
||||
device = args.device_id
|
||||
rank = args.rank
|
||||
paddle.set_device(f"gpu:{device}")
|
||||
self.gpu_cache_kvs = {}
|
||||
self.cpu_cache_kvs = {}
|
||||
self.gpu_cache_k_tensors = []
|
||||
self.gpu_cache_v_tensors = []
|
||||
self.speculative_config = SpeculativeConfig(**args.speculative_config)
|
||||
self.num_extra_layers = self.speculative_config.num_extra_cache_layer
|
||||
self.num_extra_layer_gpu_blocks = \
|
||||
int(args.num_gpu_blocks * \
|
||||
self.speculative_config.num_gpu_block_expand_ratio)
|
||||
|
||||
self.swap_to_cpu_thread_pool = concurrent.futures.ThreadPoolExecutor(
|
||||
max_workers=1)
|
||||
self.swap_to_gpu_thread_pool = concurrent.futures.ThreadPoolExecutor(
|
||||
max_workers=1)
|
||||
self.transfer_task_queue = queue.Queue() # 用来接收传输任务
|
||||
self.tansfer_done_queue = queue.Queue() # 用来告知任务执行完毕
|
||||
self.n_ranks = args.mp_num
|
||||
self.rank = rank
|
||||
self.device = device
|
||||
|
||||
address = ('0.0.0.0', args.cache_queue_port)
|
||||
self.cache_task_queue = EngineCacheQueue(
|
||||
address=address,
|
||||
is_server=False,
|
||||
num_client=args.mp_num,
|
||||
client_id=rank,
|
||||
local_data_parallel_id=args.local_data_parallel_id)
|
||||
|
||||
self.num_cpu_blocks = args.num_cpu_blocks
|
||||
|
||||
cache_type = args.cache_dtype
|
||||
for i in range(args.num_layers + self.num_extra_layers):
|
||||
num_gpu_blocks = args.num_gpu_blocks if i < args.num_layers else \
|
||||
self.num_extra_layer_gpu_blocks
|
||||
|
||||
self.gpu_cache_kvs["key_caches_{}_rank{}_device{}".format(
|
||||
i, rank, device)] = paddle.full(
|
||||
shape=[
|
||||
num_gpu_blocks,
|
||||
args.kv_num_head,
|
||||
args.block_size,
|
||||
args.head_dim,
|
||||
],
|
||||
fill_value=0,
|
||||
dtype=cache_type,
|
||||
)
|
||||
self.gpu_cache_k_tensors.append(
|
||||
self.gpu_cache_kvs["key_caches_{}_rank{}_device{}".format(
|
||||
i, rank, device)])
|
||||
self.gpu_cache_kvs["value_caches_{}_rank{}_device{}".format(
|
||||
i, rank, device)] = paddle.full(
|
||||
shape=[
|
||||
num_gpu_blocks,
|
||||
args.kv_num_head,
|
||||
args.block_size,
|
||||
args.head_dim,
|
||||
],
|
||||
fill_value=0,
|
||||
dtype=cache_type,
|
||||
)
|
||||
self.gpu_cache_v_tensors.append(
|
||||
self.gpu_cache_kvs["value_caches_{}_rank{}_device{}".format(
|
||||
i, rank, device)])
|
||||
|
||||
set_data_ipc(
|
||||
self.gpu_cache_kvs["key_caches_{}_rank{}_device{}".format(
|
||||
i, rank, device)],
|
||||
"key_caches_{}_rank{}.device{}".format(i, rank, device))
|
||||
set_data_ipc(
|
||||
self.gpu_cache_kvs["value_caches_{}_rank{}_device{}".format(
|
||||
i, rank, device)],
|
||||
"value_caches_{}_rank{}.device{}".format(i, rank, device))
|
||||
cache_kv_size_byte = sum(
|
||||
[tmp.numel() * 1 for key, tmp in self.gpu_cache_kvs.items()])
|
||||
logger.info(f"device :{self.device}")
|
||||
logger.info(f"cache_kv_size_byte : {cache_kv_size_byte}")
|
||||
logger.info(
|
||||
f"done init cache (full) gmem alloc : {paddle.device.cuda.memory_allocated()}"
|
||||
)
|
||||
|
||||
paddle.set_device("cpu")
|
||||
self.k_dst_ptrs = []
|
||||
self.v_dst_ptrs = []
|
||||
for i in range(args.num_layers + self.num_extra_layers):
|
||||
self.cpu_cache_kvs["key_caches_{}_rank{}".format(
|
||||
i, rank)] = cuda_host_alloc(args.num_cpu_blocks *
|
||||
args.bytes_per_layer_per_block)
|
||||
self.k_dst_ptrs.append(
|
||||
self.cpu_cache_kvs["key_caches_{}_rank{}".format(i, rank)])
|
||||
self.cpu_cache_kvs["value_caches_{}_rank{}".format(
|
||||
i, rank)] = cuda_host_alloc(args.num_cpu_blocks *
|
||||
args.bytes_per_layer_per_block)
|
||||
self.v_dst_ptrs.append(
|
||||
self.cpu_cache_kvs["value_caches_{}_rank{}".format(i, rank)])
|
||||
|
||||
cache_ready_signal_data = np.zeros(shape=[args.mp_num], dtype=np.int32)
|
||||
self.cache_ready_signal = IPCSignal(name="cache_ready_signal",
|
||||
array=cache_ready_signal_data,
|
||||
dtype=np.int32,
|
||||
suffix=args.engine_pid,
|
||||
create=False)
|
||||
self.cache_ready_signal.value[self.rank] = 1
|
||||
|
||||
paddle.set_device(f"gpu:{device}")
|
||||
if args.enable_splitwise:
|
||||
logger.debug("create cache messager...")
|
||||
logger.info(f"{args}")
|
||||
from fastdeploy.cache_manager.cache_messager import CacheMessager
|
||||
|
||||
self.cache_messager = CacheMessager(
|
||||
splitwise_role=args.splitwise_role,
|
||||
transfer_protocol=args.protocol,
|
||||
engine_worker_queue_port=args.engine_worker_queue_port,
|
||||
local_data_parallel_id=args.local_data_parallel_id,
|
||||
gpu_cache_kvs=self.gpu_cache_kvs,
|
||||
rank=self.rank,
|
||||
nranks=args.mp_num,
|
||||
num_layers=args.num_layers + self.num_extra_layers,
|
||||
gpu_id=self.device,
|
||||
rdma_port=args.rdma_port,
|
||||
)
|
||||
logger.info("successfully create cache messager")
|
||||
logger.info(
|
||||
f"done init CacheMessager gmem alloc : {paddle.device.cuda.memory_allocated()}"
|
||||
)
|
||||
|
||||
cache_task_broadcast_data = np.zeros(shape=[1], dtype=np.int32)
|
||||
self.cache_task_broadcast_signal = IPCSignal(
|
||||
name="cache_task_broadcast_signal",
|
||||
array=cache_task_broadcast_data,
|
||||
dtype=np.int32,
|
||||
suffix=args.engine_pid,
|
||||
create=False)
|
||||
|
||||
def _do_swap_to_cpu_task(self, swap_node_ids, gpu_block_id, cpu_block_id,
|
||||
event_type, transfer_task_id):
|
||||
"""
|
||||
swap cache GPU->CPU
|
||||
"""
|
||||
self.cache_task_queue.swap_to_cpu_barrier1.wait()
|
||||
if self.rank == 0:
|
||||
self.cache_task_queue.swap_to_cpu_barrier1.reset()
|
||||
result = self._transfer_data(
|
||||
swap_node_ids,
|
||||
gpu_block_id,
|
||||
cpu_block_id,
|
||||
event_type,
|
||||
transfer_task_id,
|
||||
)
|
||||
self.cache_task_queue.swap_to_cpu_barrier2.wait()
|
||||
if self.rank == 0:
|
||||
self.cache_task_queue.swap_to_cpu_barrier2.reset()
|
||||
self.cache_task_queue.put_transfer_done_signal(result)
|
||||
logger.debug(
|
||||
f"_do_swap_to_cpu_task: put_transfer_done_signal {result}")
|
||||
logger.info(
|
||||
f"_do_swap_to_cpu_task: put_transfer_done_signal for transfer_task_id {transfer_task_id}"
|
||||
)
|
||||
|
||||
def _do_swap_to_gpu_task(self, swap_node_ids, gpu_block_id, cpu_block_id,
|
||||
event_type, transfer_task_id):
|
||||
"""
|
||||
swap cache CPU->GPU
|
||||
"""
|
||||
self.cache_task_queue.swap_to_gpu_barrier1.wait()
|
||||
if self.rank == 0:
|
||||
self.cache_task_queue.swap_to_gpu_barrier1.reset()
|
||||
result = self._transfer_data(
|
||||
swap_node_ids,
|
||||
gpu_block_id,
|
||||
cpu_block_id,
|
||||
event_type,
|
||||
transfer_task_id,
|
||||
)
|
||||
self.cache_task_queue.swap_to_gpu_barrier2.wait()
|
||||
if self.rank == 0:
|
||||
self.cache_task_queue.swap_to_gpu_barrier2.reset()
|
||||
self.cache_task_queue.put_transfer_done_signal(result)
|
||||
logger.debug(
|
||||
f"_do_swap_to_gpu_task: put_transfer_done_signal {result}")
|
||||
logger.info(
|
||||
f"_do_swap_to_gpu_task: put_transfer_done_signal for transfer_task_id {transfer_task_id}"
|
||||
)
|
||||
|
||||
def do_data_transfer(self):
|
||||
"""
|
||||
do data transfer task
|
||||
"""
|
||||
while True:
|
||||
try:
|
||||
if self.rank == 0:
|
||||
if not self.cache_task_queue.empty():
|
||||
self.cache_task_broadcast_signal.value[0] = 1
|
||||
if self.n_ranks > 1:
|
||||
self.cache_task_queue.barrier1.wait()
|
||||
if self.rank == 0:
|
||||
self.cache_task_queue.barrier1.reset()
|
||||
if self.cache_task_broadcast_signal.value[0] == 1:
|
||||
data, read_finish = self.cache_task_queue.get_transfer_task(
|
||||
)
|
||||
logger.debug(f"transfer data: get_transfer_task {data}")
|
||||
if read_finish:
|
||||
self.cache_task_broadcast_signal.value[0] = 0
|
||||
(
|
||||
swap_node_ids,
|
||||
gpu_block_id,
|
||||
cpu_block_id,
|
||||
event_type,
|
||||
transfer_task_id,
|
||||
) = data
|
||||
if event_type.value == CacheStatus.SWAP2CPU.value:
|
||||
self.swap_to_cpu_thread_pool.submit(
|
||||
self._do_swap_to_cpu_task,
|
||||
swap_node_ids,
|
||||
gpu_block_id,
|
||||
cpu_block_id,
|
||||
event_type,
|
||||
transfer_task_id,
|
||||
)
|
||||
else:
|
||||
self.swap_to_gpu_thread_pool.submit(
|
||||
self._do_swap_to_gpu_task,
|
||||
swap_node_ids,
|
||||
gpu_block_id,
|
||||
cpu_block_id,
|
||||
event_type,
|
||||
transfer_task_id,
|
||||
)
|
||||
else:
|
||||
if self.n_ranks > 1:
|
||||
self.cache_task_queue.barrier2.wait()
|
||||
if self.rank == 0:
|
||||
self.cache_task_queue.barrier2.reset()
|
||||
continue
|
||||
|
||||
if self.n_ranks > 1:
|
||||
self.cache_task_queue.barrier3.wait()
|
||||
if self.rank == 0:
|
||||
self.cache_task_queue.barrier3.reset()
|
||||
except Exception as e:
|
||||
logger.info(f"do_data_transfer: error: {e}")
|
||||
|
||||
def _transfer_data(
|
||||
self,
|
||||
swap_node_ids,
|
||||
task_gpu_block_id,
|
||||
task_cpu_block_id,
|
||||
event_type,
|
||||
transfer_task_id,
|
||||
):
|
||||
"""
|
||||
transfer data
|
||||
task_gpu_block_id format: [[block_id0, [fold_block_id0, fold_block_id1]],
|
||||
[block_id1, [fold_block_id0, fold_block_id1]], ...]
|
||||
"""
|
||||
logger.debug(
|
||||
f"transfer data: transfer_task_id {transfer_task_id}: swap_node_ids {swap_node_ids}"
|
||||
+
|
||||
f"task_gpu_block_id {task_gpu_block_id} task_cpu_block_id {task_cpu_block_id} event_type {event_type}"
|
||||
)
|
||||
start_time = time.time()
|
||||
try:
|
||||
# transform block id
|
||||
assert len(task_gpu_block_id) == len(task_cpu_block_id)
|
||||
gpu_block_ids = task_gpu_block_id
|
||||
cpu_block_ids = task_cpu_block_id
|
||||
|
||||
if event_type.value == CacheStatus.SWAP2CPU.value:
|
||||
swap_cache_all_layers(
|
||||
self.gpu_cache_k_tensors,
|
||||
self.k_dst_ptrs,
|
||||
self.num_cpu_blocks,
|
||||
gpu_block_ids,
|
||||
cpu_block_ids,
|
||||
self.device,
|
||||
0,
|
||||
)
|
||||
swap_cache_all_layers(
|
||||
self.gpu_cache_v_tensors,
|
||||
self.v_dst_ptrs,
|
||||
self.num_cpu_blocks,
|
||||
gpu_block_ids,
|
||||
cpu_block_ids,
|
||||
self.device,
|
||||
0,
|
||||
)
|
||||
|
||||
elif event_type.value == CacheStatus.SWAP2GPU.value:
|
||||
swap_cache_all_layers(
|
||||
self.gpu_cache_k_tensors,
|
||||
self.k_dst_ptrs,
|
||||
self.num_cpu_blocks,
|
||||
gpu_block_ids,
|
||||
cpu_block_ids,
|
||||
self.device,
|
||||
1,
|
||||
)
|
||||
swap_cache_all_layers(
|
||||
self.gpu_cache_v_tensors,
|
||||
self.v_dst_ptrs,
|
||||
self.num_cpu_blocks,
|
||||
gpu_block_ids,
|
||||
cpu_block_ids,
|
||||
self.device,
|
||||
1,
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
f"transfer data: Get unexpected event type {event_type}, only SWAP2CPU and SWAP2GPU supported"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"transfer data: error: {e}")
|
||||
raise e
|
||||
end_time = time.time()
|
||||
elasped_time = end_time - start_time
|
||||
logger.info(
|
||||
f"transfer data: transfer_task_id {transfer_task_id} event_type {event_type}: "
|
||||
+
|
||||
f"transfer {len(gpu_block_ids)} blocks done elapsed_time {elasped_time:.4f}"
|
||||
)
|
||||
return (
|
||||
swap_node_ids,
|
||||
task_gpu_block_id,
|
||||
task_cpu_block_id,
|
||||
event_type,
|
||||
transfer_task_id,
|
||||
)
|
||||
|
||||
|
||||
def main():
|
||||
"""
|
||||
启动cache manager
|
||||
"""
|
||||
|
||||
cache_manager = CacheTransferManager(args)
|
||||
|
||||
cache_manager.do_data_transfer()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
args = parse_args()
|
||||
logger = get_logger("cache_transfer_manager", "cache_transfer_manager.log")
|
||||
main()
|
1033
fastdeploy/cache_manager/prefix_cache_manager.py
Normal file
1033
fastdeploy/cache_manager/prefix_cache_manager.py
Normal file
File diff suppressed because it is too large
Load Diff
17
fastdeploy/cache_manager/transfer_factory/__init__.py
Normal file
17
fastdeploy/cache_manager/transfer_factory/__init__.py
Normal file
@@ -0,0 +1,17 @@
|
||||
"""
|
||||
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
from .ipc_cache_transfer import IPCCommManager
|
||||
from .rdma_cache_transfer import RDMACommManager
|
133
fastdeploy/cache_manager/transfer_factory/ipc_cache_transfer.py
Normal file
133
fastdeploy/cache_manager/transfer_factory/ipc_cache_transfer.py
Normal file
@@ -0,0 +1,133 @@
|
||||
"""
|
||||
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
import paddle
|
||||
|
||||
from fastdeploy.model_executor.ops.gpu import (
|
||||
get_data_ptr_ipc, ipc_sent_key_value_cache_by_remote_ptr,
|
||||
ipc_sent_key_value_cache_by_remote_ptr_block_sync)
|
||||
from fastdeploy.utils import get_logger
|
||||
|
||||
logger = get_logger("cache_messager", "cache_messager.log")
|
||||
|
||||
|
||||
class IPCConnector:
|
||||
"""
|
||||
IPC communication class.
|
||||
"""
|
||||
|
||||
def __init__(self, rank_id_, remote_gpu_id_, layer_num, local_gpu_id_):
|
||||
"""
|
||||
Args:
|
||||
rank_id_: rank id
|
||||
remote_gpu_id_: remote gpu id
|
||||
|
||||
"""
|
||||
self.remote_key_tensor_ptr_list = []
|
||||
self.remote_value_tensor_ptr_list = []
|
||||
self.remote_gpu_id = int(remote_gpu_id_)
|
||||
self.rank_id = rank_id_
|
||||
self.local_gpu_id = int(local_gpu_id_)
|
||||
tmp = paddle.ones([1, 1])
|
||||
logger.info(
|
||||
f"init ipc rank{self.rank_id} with remote {self.remote_gpu_id} {self.local_gpu_id}"
|
||||
)
|
||||
for layer_id in range(layer_num):
|
||||
key_unique_name = f"key_caches_{layer_id}_rank{self.rank_id}.device{self.remote_gpu_id}"
|
||||
value_unique_name = f"value_caches_{layer_id}_rank{self.rank_id}.device{self.remote_gpu_id}"
|
||||
self.remote_key_tensor_ptr_list.append(
|
||||
get_data_ptr_ipc(tmp, key_unique_name))
|
||||
self.remote_value_tensor_ptr_list.append(
|
||||
get_data_ptr_ipc(tmp, value_unique_name))
|
||||
self.write_stream = paddle.device.Stream(f'gpu:{self.local_gpu_id}')
|
||||
self.finish_event = paddle.device.Event()
|
||||
|
||||
|
||||
class IPCCommManager:
|
||||
"""
|
||||
IPC communication manager, used to initialize ipc and cache transmission.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
rank_id_,
|
||||
gpu_idx_,
|
||||
local_key_cache_tensor_list, # tensor list
|
||||
local_value_cache_tensor_list, # tensor
|
||||
):
|
||||
self.rank_id = rank_id_
|
||||
self.gpu_idx = gpu_idx_
|
||||
# local cache to tensor
|
||||
self.local_key_cache_tensor_list = local_key_cache_tensor_list
|
||||
self.local_value_cache_tensor_list = local_value_cache_tensor_list
|
||||
self.layer_num = len(self.local_key_cache_tensor_list)
|
||||
# record connected ipc info
|
||||
self.comm_map = {}
|
||||
|
||||
def connect(self, remote_gpu_id_=0):
|
||||
"""
|
||||
Connect to remote gpu.
|
||||
"""
|
||||
logger.info(
|
||||
f"{self.rank_id}: connect to remote_gpu_id:{remote_gpu_id_} {self.layer_num} {self.gpu_idx}"
|
||||
)
|
||||
if self.is_connected(remote_gpu_id_):
|
||||
return True
|
||||
else:
|
||||
self.comm_map[remote_gpu_id_] = IPCConnector(
|
||||
self.rank_id, remote_gpu_id_, self.layer_num, self.gpu_idx)
|
||||
return True
|
||||
|
||||
def is_connected(self, remote_gpu_id_=0):
|
||||
"""
|
||||
Check if remote gpu is connected.
|
||||
"""
|
||||
if remote_gpu_id_ in self.comm_map.keys():
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
def write_cache(self, ip, remote_gpu_id, local_block_ids, remote_block_ids,
|
||||
layer_idx):
|
||||
"""
|
||||
Connect to remote gpu and write cache.
|
||||
"""
|
||||
block_num = len(local_block_ids)
|
||||
if not self.is_connected(remote_gpu_id):
|
||||
self.connect(remote_gpu_id)
|
||||
comm = self.comm_map[remote_gpu_id]
|
||||
with paddle.device.stream_guard(comm.write_stream):
|
||||
ipc_sent_key_value_cache_by_remote_ptr(
|
||||
self.local_key_cache_tensor_list[layer_idx],
|
||||
self.local_value_cache_tensor_list[layer_idx], local_block_ids,
|
||||
remote_block_ids, comm.remote_key_tensor_ptr_list[layer_idx],
|
||||
comm.remote_value_tensor_ptr_list[layer_idx], block_num,
|
||||
self.gpu_idx, comm.remote_gpu_id,
|
||||
comm.write_stream.stream_base.cuda_stream)
|
||||
return 0
|
||||
|
||||
def write_block_by_sync(self, remote_gpu_id):
|
||||
"""
|
||||
check finish event and wait for it
|
||||
"""
|
||||
paddle.set_device(f'gpu:{self.gpu_idx}')
|
||||
comm = self.comm_map[remote_gpu_id]
|
||||
ipc_sent_key_value_cache_by_remote_ptr_block_sync(
|
||||
self.local_key_cache_tensor_list[0], #tensor no use
|
||||
self.local_value_cache_tensor_list[0], #tensor no use
|
||||
comm.write_stream.stream_base.cuda_stream)
|
@@ -0,0 +1,35 @@
|
||||
cmake_minimum_required (VERSION 3.5)
|
||||
|
||||
project(rdma_comm LANGUAGES CXX)
|
||||
|
||||
set(PROJECT_SOURCE_DIR ${CMAKE_SOURCE_DIR})
|
||||
set(CMAKE_BINARY_DIR ${CMAKE_SOURCE_DIR}/bin)
|
||||
set(EXECUTABLE_OUTPUT_PATH ${CMAKE_BINARY_DIR})
|
||||
set(LIBRARY_OUTPUT_PATH ${CMAKE_BINARY_DIR})
|
||||
|
||||
set(CMAKE_CXX_STANDARD 11)
|
||||
set(CMAKE_CXX_STANDARD_REQUIRED ON)
|
||||
set(CMAKE_CXX_EXTENSIONS ON)
|
||||
|
||||
set(CMAKE_BUILD_TYPE Release)
|
||||
|
||||
set(CMAKE_CXX_COMPILER g++)
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O2 -Ofast -ffast-math -funroll-loops -march=native -std=c++11")
|
||||
add_compile_options("-std=c++11")
|
||||
|
||||
find_library(IBVERBS_LIBRARY ibverbs)
|
||||
find_library(RDMACM_LIBRARY rdmacm)
|
||||
|
||||
find_package(pybind11 CONFIG REQUIRED)
|
||||
|
||||
include_directories("${PROJECT_SOURCE_DIR}/include")
|
||||
|
||||
add_library(rdma_comm MODULE ${PROJECT_SOURCE_DIR}/src/pybind.cpp ${PROJECT_SOURCE_DIR}/src/kvcache_rdma.cpp ${PROJECT_SOURCE_DIR}/src/kvcache_connection.cpp ${PROJECT_SOURCE_DIR}/src/log.cpp)
|
||||
set_target_properties(rdma_comm PROPERTIES
|
||||
OUTPUT_NAME "rdma_comm"
|
||||
PREFIX ""
|
||||
SUFFIX ".so"
|
||||
)
|
||||
target_link_libraries(rdma_comm PRIVATE pybind11::module)
|
||||
|
||||
target_link_libraries(rdma_comm LINK_PUBLIC ibverbs pthread)
|
@@ -0,0 +1,232 @@
|
||||
# KVTransferManager
|
||||
|
||||
A dedicated component for transferring KV Cache between Prefill and Decode nodes, supporting RDMA communication with ultra-low latency.
|
||||
|
||||
## Performance Benchmark
|
||||
|
||||
### KVTransferManager vs Mooncake Performance Comparison
|
||||
|
||||
### Test Scenario
|
||||
- **Hardware Configuration**:
|
||||
- Single Mellanox ConnectX-7 400G NIC (single port)
|
||||
- Tested with BATCH_SIZE = 1538 and block size = 1K - 256K
|
||||
- Single pressure thread (threads = 1)
|
||||
|
||||
- **Comparison Baseline**:
|
||||
- Mooncake performance measured using transfer_engine_bench from example directory
|
||||
- Same hardware configuration and test parameters applied to KVTransferManager
|
||||
|
||||
### Performance Results
|
||||
| Block Size | KVTransferManager | Mooncake | Performance Gain |
|
||||
|------------|-----------------|----------|------------------|
|
||||
| 1K | 10.67 GB/s | 1.54 GB/s | 6.9x |
|
||||
| 2K | 17.53 GB/s | 3.40 GB/s | 5.2x |
|
||||
| 4K | 28.85 GB/s | 6.95 GB/s | 4.2x |
|
||||
| 8K | 36.56 GB/s | 12.48 GB/s | 2.9x |
|
||||
| 16K | 41.73 GB/s | 23.42 GB/s | 1.8x |
|
||||
| 32K | 43.55 GB/s | 31.58 GB/s | 1.4x |
|
||||
| 64K | 44.46 GB/s | 38.39 GB/s | 1.2x |
|
||||
| 128K | 44.86 GB/s | 40.11 GB/s | 1.1x |
|
||||
| 256K | 45.01 GB/s | 40.71 GB/s | 1.1x |
|
||||
|
||||
Bandwidth Saturation Capability: Under multi-threaded high-pressure scenarios, both KVTransferManager and Mooncake can fully utilize the 400G network card bandwidth, achieving transmission performance close to the theoretical hardware limit (approximately 50 GB/s).
|
||||
|
||||
## Quick start
|
||||
|
||||
### Requirements
|
||||
|
||||
- Supported Architectures:
|
||||
Hopper GPUs
|
||||
Kunlun XPU
|
||||
Ampere GPUs (supported by enabling KVCACHE_GDRCOPY_FLUSH_ENABLE)
|
||||
### Dependencies Installation
|
||||
|
||||
#### Python Packages
|
||||
```bash
|
||||
pip install pyzmq pybind11[global]
|
||||
```
|
||||
|
||||
#### System Libraries (Linux)
|
||||
```bash
|
||||
# Ubuntu/Debian
|
||||
sudo apt-get install -y libibverbs-dev librdmacm-dev
|
||||
|
||||
# RHEL/CentOS
|
||||
sudo yum install -y libibverbs-devel librdmacm-devel
|
||||
```
|
||||
|
||||
#### Hardware Requirements
|
||||
- RDMA-capable network hardware (e.g. Mellanox NICs)
|
||||
- Supported GPU architectures: Hopper, Kunlun XPU, Ampere
|
||||
|
||||
#### Ampere Architecture Note
|
||||
To support Ampere GPUs, enable the environment variable KVCACHE_GDRCOPY_FLUSH_ENABLE.
|
||||
- What it does:
|
||||
Forces memory flushing after a GDRCopy write operation to ensure data consistency on the Ampere architecture. Here if KVCACHE_GDRCOPY_FLUSH_ENABLE is enable we trigger an RDMA read operation after the last RDMA write.
|
||||
- Why it’s needed:
|
||||
When the NIC delivers a completion to the CPU, it indicates that the data has reach the GPU. However, it doesn't mean that the GPU can read that data yet. To make sure the data has gone all the way down to the GPU memory and the GPU can read it, we need to perform a read.
|
||||
[NCCL Issue #683](https://github.com/NVIDIA/nccl/issues/683) |
|
||||
[NCCL Issue #1702](https://github.com/NVIDIA/nccl/issues/1702)
|
||||
Since the upper layer typically issues a cache arrival notification only after polling a Completion Queue Entry (CQE), this prevents the application from being notified before the data is actually written back to memory. Therefore, the potential race condition where the cache has not yet been flushed but the application assumes completion is considered a rare event in practice.
|
||||
- How to enable:
|
||||
export KVCACHE_GDRCOPY_FLUSH_ENABLE=1
|
||||
|
||||
### Development
|
||||
|
||||
```bash
|
||||
# Build and make symbolic links for SO files
|
||||
python setup.py bdist_wheel
|
||||
|
||||
pip install dist/*.whl
|
||||
```
|
||||
|
||||
## Environment Variables Configuration
|
||||
|
||||
### RDMA Settings
|
||||
| Variable | Default | Description |
|
||||
|----------|---------|-------------|
|
||||
| `KVCACHE_RDMA_GID_INDEX` | 3 | RDMA GID index |
|
||||
| `KVCACHE_RDMA_NICS` | - | RDMA NIC list, comma-separated (e.g., “mlx5_0,mlx5_1”), selects ib device based on gpu_index. This environment variable must be set. NICs are selected using modulo operation on gpu_index. |
|
||||
| `KVCACHE_IB_TIMEOUT` | 18 | InfiniBand communication timeout (14-31), where timeout = 4.096μs * 2^value (default 18 ≈ 1.07s).|
|
||||
| `KVCACHE_RELAX_ORDERING` | false | Enable RDMA relaxed ordering to improve performance in multi-GPU scenarios. Recommended when multiple GPUs share the same NIC to mitigate TX pause issues. |
|
||||
|
||||
### Network Settings
|
||||
| Variable | Default | Description |
|
||||
|----------|---------|-------------|
|
||||
| `KVCACHE_SOCKET_IFNAME` | auto | Network interface for socket comm (e.g. "eth0") |
|
||||
|
||||
### Debugging
|
||||
| Variable | Default | Description |
|
||||
|----------|---------|-------------|
|
||||
| `KVCACHE_DEBUG` | false | Enable debug logging |
|
||||
| `KVCACHE_DEBUG_FILE` | - | Debug log file path |
|
||||
| `KVCACHE_ERROR_FILE` | - | Error log file path |
|
||||
|
||||
### Performance Tuning
|
||||
| Variable | Default | Description |
|
||||
|----------|---------|-------------|
|
||||
| `KVCACHE_GDRCOPY_FLUSH_ENABLE` | false | Enable GDRCopy flush for Ampere GPUs |
|
||||
|
||||
|
||||
|
||||
# Set RDMA GID index
|
||||
export KVCACHE_RDMA_GID_INDEX=3
|
||||
|
||||
# Set RDMA IB Device List
|
||||
export KVCACHE_RDMA_NICS=mlx5_0,mlx5_1,mlx5_2
|
||||
|
||||
# Specify network interface
|
||||
export KVCACHE_SOCKET_IFNAME=eth0
|
||||
|
||||
# Enable debug mode
|
||||
export KVCACHE_DEBUG=1
|
||||
|
||||
# Set log files
|
||||
export KVCACHE_DEBUG_FILE=/var/log/kvcache_debug.log
|
||||
export KVCACHE_ERROR_FILE=/var/log/kvcache_error.log
|
||||
|
||||
|
||||
## Network configurations
|
||||
|
||||
kvcache transfer is fully tested with RDMA over Converged Ethernet (RoCE) networks. However, it is theoretically compatible with Infiniband as well.
|
||||
|
||||
For complete implementation details and advanced usage, please refer to the source code.
|
||||
|
||||
## Python API Reference
|
||||
|
||||
### RDMACommunicator Class
|
||||
|
||||
```python
|
||||
from rdma_comm import RDMACommunicator
|
||||
|
||||
# Constructor
|
||||
comm = RDMACommunicator(
|
||||
role, # Role ("prefill" or "decode")
|
||||
gpu_idx, # GPU device index
|
||||
port, # Communication port
|
||||
local_key_cache, # List of local key cache pointers
|
||||
local_value_cache, # List of local value cache pointers
|
||||
block_number, # Number of blocks
|
||||
block_bytes # Bytes per block
|
||||
)
|
||||
|
||||
# Methods
|
||||
comm.connect(dst_ip, dst_port) # Connect to target IP and port
|
||||
comm.is_connected(dst_ip, dst_port) # Check connection status
|
||||
comm.write_cache(
|
||||
ip, # Target server IP address
|
||||
port, # Target server port number
|
||||
local_block_ids, # List of local block IDs to transfer
|
||||
remote_block_ids, # List of remote block IDs to write
|
||||
layer_idx # Model layer index (for multi-layer models)
|
||||
)
|
||||
```
|
||||
|
||||
**Parameter Details**:
|
||||
|
||||
1. `role`:
|
||||
- "prefill": Prefill node role
|
||||
- "decode": Decode node role
|
||||
|
||||
2. `gpu_idx`:
|
||||
- GPU device index to use
|
||||
|
||||
3. `port`:
|
||||
- RDMA communication port number
|
||||
|
||||
4. `local_key_cache`/`local_value_cache`:
|
||||
- List of local KV cache pointers
|
||||
|
||||
5. `block_number`:
|
||||
- Number of cache blocks
|
||||
|
||||
6. `block_bytes`:
|
||||
- Bytes per cache block
|
||||
|
||||
**Example Usage**:
|
||||
|
||||
```python
|
||||
import numpy as np
|
||||
from rdma_comm import RDMACommunicator
|
||||
|
||||
# Initialize
|
||||
local_keys = [np.array([0]*1024, dtype=np.int64).ctypes.data] # Example key pointer
|
||||
local_values = [np.array([0]*1024, dtype=np.int64).ctypes.data] # Example value pointer
|
||||
|
||||
comm = RDMACommunicator(
|
||||
role="prefill",
|
||||
gpu_idx=0,
|
||||
port="12345",
|
||||
local_key_cache=local_keys,
|
||||
local_value_cache=local_values,
|
||||
block_number=1024,
|
||||
block_bytes=4096
|
||||
)
|
||||
|
||||
# Client connection
|
||||
comm = RDMACommunicator(
|
||||
role="prefill",
|
||||
gpu_idx=0,
|
||||
port="12345",
|
||||
local_key_cache=local_keys,
|
||||
local_value_cache=local_values,
|
||||
block_number=1024,
|
||||
block_bytes=4096
|
||||
)
|
||||
|
||||
if comm.connect("192.168.1.100", "12345"):
|
||||
print("Connection established")
|
||||
|
||||
# Write cache
|
||||
comm.write_cache(
|
||||
ip="192.168.1.100", # Target server IP
|
||||
port="12345", # Target server port
|
||||
local_block_ids=[0,1,2], # Local block IDs to transfer
|
||||
remote_block_ids=[3,4,5], # Remote block IDs to write
|
||||
layer_idx=0 # Model layer index (0 for first layer)
|
||||
)
|
||||
```
|
||||
|
||||
## Citation
|
||||
|
||||
If you use this codebase, or otherwise found our work valuable, please cite:
|
@@ -0,0 +1,232 @@
|
||||
# KVTransferManager 中文文档
|
||||
|
||||
一个专为Prefill节点和Decode节点传输KV Cache的组件,支持RDMA通信。
|
||||
|
||||
## 性能基准测试
|
||||
|
||||
### KVTransferManager 与 Mooncake 性能对比
|
||||
|
||||
### 测试场景
|
||||
- **硬件配置**:
|
||||
- 单张Mellanox ConnectX-7 400G网卡(单端口)
|
||||
- 测试参数: BATCH_SIZE = 1538, 块大小 = 1K - 256K
|
||||
- 单压力线程(threads = 1)
|
||||
|
||||
- **对比基准**:
|
||||
- Mooncake性能使用example目录中的transfer_engine_bench测量
|
||||
- KVTransferManager使用相同的硬件配置和测试参数
|
||||
|
||||
### 性能结果
|
||||
| Block Size | KVTransferManager | Mooncake | 性能提升 |
|
||||
|--------|-----------------|----------|----------|
|
||||
| 1K | 10.67 GB/s | 1.54 GB/s | 6.9倍 |
|
||||
| 2K | 17.53 GB/s | 3.40 GB/s | 5.2倍 |
|
||||
| 4K | 28.85 GB/s | 6.95 GB/s | 4.2倍 |
|
||||
| 8K | 36.56 GB/s | 12.48 GB/s | 2.9倍 |
|
||||
| 16K | 41.73 GB/s | 23.42 GB/s | 1.8倍 |
|
||||
| 32K | 43.55 GB/s | 31.58 GB/s | 1.4倍 |
|
||||
| 64K | 44.46 GB/s | 38.39 GB/s | 1.2倍 |
|
||||
| 128K | 44.86 GB/s | 40.11 GB/s | 1.1倍 |
|
||||
| 256K | 45.01 GB/s | 40.71 GB/s | 1.1倍 |
|
||||
|
||||
在多压力线程场景下,KVTransferManager 和 Mooncake 都能够充分利用 400Gb 网卡带宽,达到接近网卡硬件理论极限的传输性能
|
||||
|
||||
## 快速开始
|
||||
|
||||
### 系统要求
|
||||
|
||||
- 支持的架构:
|
||||
Hopper GPU
|
||||
昆仑XPU
|
||||
Ampere GPU (需启用KVCACHE_GDRCOPY_FLUSH_ENABLE)
|
||||
|
||||
### 依赖安装
|
||||
|
||||
#### Python包
|
||||
```bash
|
||||
pip install pyzmq pybind11[global]
|
||||
```
|
||||
|
||||
#### 系统库(Linux)
|
||||
```bash
|
||||
# Ubuntu/Debian
|
||||
sudo apt-get install -y libibverbs-dev librdmacm-dev
|
||||
|
||||
# RHEL/CentOS
|
||||
sudo yum install -y libibverbs-devel librdmacm-devel
|
||||
```
|
||||
|
||||
#### 硬件要求
|
||||
- 支持RDMA的网络硬件(如Mellanox网卡)
|
||||
- 支持的GPU架构: Hopper, 昆仑XPU, Ampere
|
||||
|
||||
#### Ampere架构注意事项
|
||||
要支持Ampere GPU,需启用环境变量KVCACHE_GDRCOPY_FLUSH_ENABLE。
|
||||
- 作用:
|
||||
在GDRCopy写操作后强制内存刷新,确保Ampere架构上的数据一致性。启用后会在最后一个RDMA写操作后触发一个RDMA读操作。
|
||||
- 原因:
|
||||
当网卡向CPU发送完成通知时,仅表示数据已到达GPU,但不保证GPU可以立即读取该数据。为确保数据已完全写入GPU内存且可被GPU读取,需要执行读操作。
|
||||
[NCCL Issue #683](https://github.com/NVIDIA/nccl/issues/683) |
|
||||
[NCCL Issue #1702](https://github.com/NVIDIA/nccl/issues/1702)
|
||||
由于上层通常只在轮询完成队列条目(CQE)后发出缓存到达通知,这避免了应用在数据实际写回内存前收到通知的情况。因此,缓存未刷新但应用认为已完成这种潜在问题在实践中被认为是罕见情况。
|
||||
- 启用方式:
|
||||
export KVCACHE_GDRCOPY_FLUSH_ENABLE=1
|
||||
|
||||
### 开发构建
|
||||
|
||||
```bash
|
||||
# 构建并创建SO文件的符号链接
|
||||
python setup.py bdist_wheel
|
||||
|
||||
pip install dist/*.whl
|
||||
```
|
||||
|
||||
## 环境变量配置
|
||||
|
||||
### RDMA设置
|
||||
| 变量 | 默认值 | 描述 |
|
||||
|------|--------|------|
|
||||
| `KVCACHE_RDMA_GID_INDEX` | 3 | RDMA GID索引 |
|
||||
| `KVCACHE_RDMA_NICS` | - | RDMA网卡列表,逗号分隔(如"mlx5_0,mlx5_1"),根据 gpu_index 选取ib device设备, 此环境变量必须设置, 根据gpu_index取模选取网卡 |
|
||||
| `KVCACHE_IB_TIMEOUT` | 18 | InfiniBand通信超时(14-31),超时时间=4.096μs * 2^值(默认18≈1.07秒) |
|
||||
| `KVCACHE_RELAX_ORDERING` | false | 启用RDMA宽松排序以提高多GPU场景性能。当多个GPU共享同一网卡时推荐启用,可缓解TX Pause问题。 |
|
||||
|
||||
### 网络设置
|
||||
| 变量 | 默认值 | 描述 |
|
||||
|------|--------|------|
|
||||
| `KVCACHE_SOCKET_IFNAME` | auto | 用于socket通信的网络接口(如"eth0"),如果不设置自动检测第一张可用网卡 |
|
||||
|
||||
### 调试
|
||||
| 变量 | 默认值 | 描述 |
|
||||
|------|--------|------|
|
||||
| `KVCACHE_DEBUG` | false | 启用调试日志 |
|
||||
| `KVCACHE_DEBUG_FILE` | - | 调试日志文件路径 |
|
||||
| `KVCACHE_ERROR_FILE` | - | 错误日志文件路径 |
|
||||
|
||||
### 性能调优
|
||||
| 变量 | 默认值 | 描述 |
|
||||
|------|--------|------|
|
||||
| `KVCACHE_GDRCOPY_FLUSH_ENABLE` | false | 为Ampere GPU启用GDRCopy刷新 |
|
||||
|
||||
|
||||
# 设置RDMA GID索引
|
||||
export KVCACHE_RDMA_GID_INDEX=3
|
||||
|
||||
# 设置RDMA IB设备列表
|
||||
export KVCACHE_RDMA_NICS=mlx5_0,mlx5_1,mlx5_2
|
||||
|
||||
# 指定网络接口
|
||||
export KVCACHE_SOCKET_IFNAME=eth0
|
||||
|
||||
# 启用调试模式
|
||||
export KVCACHE_DEBUG=1
|
||||
|
||||
# 设置日志文件
|
||||
export KVCACHE_DEBUG_FILE=/var/log/kvcache_debug.log
|
||||
export KVCACHE_ERROR_FILE=/var/log/kvcache_error.log
|
||||
|
||||
|
||||
## 网络配置
|
||||
|
||||
kvcache transfer已通过RDMA over Converged Ethernet (RoCE)网络全面测试。理论上也兼容Infiniband。
|
||||
|
||||
完整实现细节和高级用法,请参考源代码。
|
||||
|
||||
## Python API 接口
|
||||
|
||||
### RDMACommunicator 类
|
||||
|
||||
```python
|
||||
from rdma_comm import RDMACommunicator
|
||||
|
||||
# 构造函数
|
||||
comm = RDMACommunicator(
|
||||
role, # 角色("prefill"或"decode")
|
||||
gpu_idx, # GPU设备索引(0~7)
|
||||
port, # 通信端口
|
||||
local_key_cache, # 本地key缓存指针列表
|
||||
local_value_cache, # 本地value缓存指针列表
|
||||
block_number, # 块数量
|
||||
block_bytes # 每块字节数
|
||||
)
|
||||
|
||||
# 方法说明
|
||||
comm.connect(dst_ip, dst_port) # 连接到目标IP和端口
|
||||
comm.is_connected(dst_ip, dst_port) # 检查是否已连接
|
||||
comm.write_cache(
|
||||
ip, # 目标服务器IP地址
|
||||
port, # 目标服务器端口号
|
||||
local_block_ids, # 本地缓存块ID列表,指定要传输的本地块
|
||||
remote_block_ids, # 远程缓存块ID列表,指定要写入的远程块
|
||||
layer_idx # 模型层索引,用于多层模型场景
|
||||
)
|
||||
```
|
||||
|
||||
**参数说明**:
|
||||
|
||||
1. `role`:
|
||||
- "prefill"
|
||||
- "decode"
|
||||
|
||||
2. `gpu_idx`:
|
||||
- 使用的GPU设备索引
|
||||
|
||||
3. `port`:
|
||||
- RDMA通信端口号
|
||||
|
||||
4. `local_key_cache`/`local_value_cache`:
|
||||
- 本地KV缓存指针列表
|
||||
|
||||
5. `block_number`:
|
||||
- 缓存块数量
|
||||
|
||||
6. `block_bytes`:
|
||||
- 每个缓存块的字节大小
|
||||
|
||||
**示例代码**:
|
||||
|
||||
```python
|
||||
import numpy as np
|
||||
from rdma_comm import RDMACommunicator
|
||||
|
||||
# 初始化
|
||||
local_keys = [np.array([0]*1024, dtype=np.int64).ctypes.data] # 示例key指针
|
||||
local_values = [np.array([0]*1024, dtype=np.int64).ctypes.data] # 示例value指针
|
||||
|
||||
comm = RDMACommunicator(
|
||||
role="decode",
|
||||
gpu_idx=0,
|
||||
port="12345",
|
||||
local_key_cache=local_keys,
|
||||
local_value_cache=local_values,
|
||||
block_number=1024,
|
||||
block_bytes=4096
|
||||
)
|
||||
|
||||
# 客户端初始化
|
||||
comm = RDMACommunicator(
|
||||
role="prefill",
|
||||
gpu_idx=0,
|
||||
port="12345",
|
||||
local_key_cache=local_keys,
|
||||
local_value_cache=local_values,
|
||||
block_number=1024,
|
||||
block_bytes=4096
|
||||
)
|
||||
|
||||
if comm.connect("192.168.1.100", "12345"):
|
||||
print("连接成功")
|
||||
|
||||
# 写入缓存
|
||||
comm.write_cache(
|
||||
ip="192.168.1.100", # 目标服务器IP
|
||||
port="12345", # 目标服务器端口
|
||||
local_block_ids=[0,1,2], # 要传输的本地块ID列表
|
||||
remote_block_ids=[3,4,5], # 要写入的远程块ID列表
|
||||
layer_idx=0 # 模型层索引(0表示第一层)
|
||||
)
|
||||
```
|
||||
|
||||
## 引用
|
||||
|
||||
如果您使用此代码库,或认为我们的工作有价值,请引用:
|
@@ -0,0 +1,211 @@
|
||||
/**
|
||||
* @file kvcache_connection.h
|
||||
* @brief RDMA connection management for key-value cache
|
||||
* @version 1.0.0
|
||||
* @copyright Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef FASTDEPLOY_KVCACHE_CONNECTION_H
|
||||
#define FASTDEPLOY_KVCACHE_CONNECTION_H
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <rdma/rdma_cma.h>
|
||||
#include <rdma/rdma_verbs.h>
|
||||
#include <sys/epoll.h>
|
||||
#include <atomic>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <netinet/tcp.h>
|
||||
#include <netinet/in.h>
|
||||
#include <sys/socket.h>
|
||||
#include <sstream>
|
||||
#include <netdb.h>
|
||||
#include <sstream>
|
||||
#include <netinet/in.h>
|
||||
#include <netinet/tcp.h>
|
||||
#include <sys/socket.h>
|
||||
#include <cstring>
|
||||
#include <netdb.h>
|
||||
#include <arpa/inet.h>
|
||||
#include <fcntl.h>
|
||||
#include <net/if.h>
|
||||
#include <sys/ioctl.h>
|
||||
#include <netinet/in.h>
|
||||
#include <unistd.h>
|
||||
#include <cstring>
|
||||
#include <memory>
|
||||
#include <iostream>
|
||||
|
||||
#include "kvcache_rdma.h"
|
||||
#include "util.h"
|
||||
|
||||
#define KVCACHE_RDMA_NIC_MAX_LEN 256
|
||||
#define KVCACHE_RDMA_MAX_NICS 8
|
||||
#define NAME_MAX 255
|
||||
#define MAXNAMESIZE 64
|
||||
#define NVML_DEVICE_PCI_BUS_ID_BUFFER_SIZE 16
|
||||
|
||||
/// @brief IB device information structure
|
||||
struct IbDeviceInfo {
|
||||
int device;
|
||||
uint64_t guid;
|
||||
enum ibv_mtu mtu;
|
||||
uint64_t busid;
|
||||
uint8_t port;
|
||||
uint8_t link;
|
||||
uint8_t active_mtu;
|
||||
int speed;
|
||||
ibv_context* context;
|
||||
char devName[64];
|
||||
int realPort;
|
||||
int maxQp;
|
||||
};
|
||||
|
||||
/// @brief Queue Pair information for RDMA
|
||||
struct QpInfo {
|
||||
uint32_t lid;
|
||||
uint32_t qpn;
|
||||
uint32_t psn;
|
||||
union ibv_gid gid;
|
||||
enum ibv_mtu mtu;
|
||||
|
||||
/// @brief Serialize QP info to buffer
|
||||
void serialize(char* buffer) const {
|
||||
uint32_t* intBuffer = reinterpret_cast<uint32_t*>(buffer);
|
||||
intBuffer[0] = htonl(lid);
|
||||
intBuffer[1] = htonl(qpn);
|
||||
intBuffer[2] = htonl(psn);
|
||||
memcpy(buffer + 12, gid.raw, sizeof(gid.raw));
|
||||
intBuffer[7] = htonl(static_cast<uint32_t>(mtu));
|
||||
}
|
||||
|
||||
/// @brief Deserialize QP info from buffer
|
||||
void deserialize(const char* buffer) {
|
||||
const uint32_t* intBuffer = reinterpret_cast<const uint32_t*>(buffer);
|
||||
lid = ntohl(intBuffer[0]);
|
||||
qpn = ntohl(intBuffer[1]);
|
||||
psn = ntohl(intBuffer[2]);
|
||||
memcpy(gid.raw, buffer + 12, sizeof(gid.raw));
|
||||
mtu = static_cast<ibv_mtu>(ntohl(intBuffer[7]));
|
||||
}
|
||||
|
||||
static const size_t size = 12 + sizeof(gid.raw) + 4;
|
||||
};
|
||||
|
||||
/// @brief RDMA connection context
|
||||
struct Connection {
|
||||
std::atomic<int> connected;
|
||||
|
||||
// Memory regions
|
||||
struct ibv_mr *recv_mr;
|
||||
struct ibv_mr *send_mr;
|
||||
|
||||
// Cache pointers
|
||||
std::vector<std::vector<void*>> local_cache_key_ptr_per_layer;
|
||||
std::vector<std::vector<void*>> local_cache_value_ptr_per_layer;
|
||||
|
||||
// Memory region lists
|
||||
std::vector<ibv_mr*> write_cache_key_server_mr_list;
|
||||
std::vector<ibv_mr*> write_cache_value_server_mr_list;
|
||||
std::vector<std::vector<ibv_mr*>> write_mr_key_list;
|
||||
std::vector<std::vector<ibv_mr*>> write_mr_value_list;
|
||||
|
||||
// Remote access information
|
||||
std::vector<void*> write_cache_key_remote_ptr_list;
|
||||
std::vector<uint32_t> write_cache_key_remote_rkey_list;
|
||||
std::vector<void*> write_cache_value_remote_ptr_list;
|
||||
std::vector<uint32_t> write_cache_value_remote_rkey_list;
|
||||
|
||||
// Received remote memory information
|
||||
std::vector<void*> receive_write_cache_key_remote_ptr_list;
|
||||
std::vector<uint32_t> receive_write_cache_key_remote_rkey_list;
|
||||
std::vector<void*> receive_write_cache_value_remote_ptr_list;
|
||||
std::vector<uint32_t> receive_write_cache_value_remote_rkey_list;
|
||||
|
||||
std::vector<void *> send_write_cache_key_remote_ptr_list;
|
||||
std::vector<uint32_t> send_write_cache_key_remote_rkey_list;
|
||||
std::vector<void *> send_write_cache_value_remote_ptr_list;
|
||||
std::vector<uint32_t> send_write_cache_value_remote_rkey_list;
|
||||
|
||||
// For rdma read operations
|
||||
std::vector<void*> read_bufs;
|
||||
std::vector<ibv_mr*> read_mrs;
|
||||
|
||||
// Work completion tracking
|
||||
int wc_count;
|
||||
int wc_target_count;
|
||||
|
||||
// Configuration
|
||||
int layer_number;
|
||||
int block_number;
|
||||
int block_byte_size;
|
||||
std::string url;
|
||||
|
||||
Connection() = default;
|
||||
~Connection();
|
||||
};
|
||||
|
||||
/// @brief RDMA context structure
|
||||
struct RdmaContext {
|
||||
int sock_fd;
|
||||
struct ibv_context* context;
|
||||
struct ibv_comp_channel* channel;
|
||||
struct ibv_pd* pd;
|
||||
struct ibv_mr* mr;
|
||||
struct ibv_cq* cq;
|
||||
struct ibv_qp* qp;
|
||||
struct ibv_port_attr portinfo;
|
||||
struct Connection conn;
|
||||
};
|
||||
|
||||
// Global variables
|
||||
extern std::vector<IbDeviceInfo> g_ib_all_devs;
|
||||
static int g_kvcache_ib_dev_nums = -1;
|
||||
|
||||
// Connection management functions
|
||||
bool client_exchange_destinations(
|
||||
struct RdmaContext* ctx,
|
||||
int ib_port,
|
||||
unsigned int port,
|
||||
int gidx,
|
||||
const std::string& dst_ip);
|
||||
|
||||
int server_exchange_qp_info(int connfd, QpInfo* local_dest, QpInfo* rem_dest);
|
||||
struct RdmaContext* create_qp(struct IbDeviceInfo* ib_dev, struct ibv_pd** g_pd);
|
||||
bool clear_qp_info(struct RdmaContext* ctx);
|
||||
|
||||
// QP modification functions
|
||||
QpStatus modify_qp_to_rts(struct RdmaContext* ctx, int port, int my_psn,
|
||||
struct QpInfo* dest, int sgid_id);
|
||||
bool poll_cq_with_timeout(struct RdmaContext* ctx, int timeout_seconds, int cqe_count);
|
||||
|
||||
// Utility functions
|
||||
int get_port_info(struct ibv_context* Context, int port,
|
||||
struct ibv_port_attr* attr);
|
||||
int parse_port_ib_info();
|
||||
|
||||
// Memory region exchange
|
||||
bool client_exchange_mr(struct RdmaContext* ctx);
|
||||
bool server_exchange_mr(struct RdmaContext* ctx);
|
||||
bool server_send_memory_region(struct RdmaContext *ctx, void *local_mr, int byte_num);
|
||||
bool client_receive_memory_region(struct RdmaContext *ctx, void *remote_mr, int byte_num);
|
||||
|
||||
// Network setup
|
||||
int setup_listening_socket(int port);
|
||||
int configure_epoll(int sockfd);
|
||||
std::vector<std::string> get_net_ifname();
|
||||
|
||||
#endif // FASTDEPLOY_KVCACHE_CONNECTION_H
|
@@ -0,0 +1,127 @@
|
||||
#ifndef KVCACHE_RDMA_H
|
||||
#define KVCACHE_RDMA_H
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <rdma/rdma_cma.h>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <map>
|
||||
#include <mutex>
|
||||
#include "util.h" // Contains constant definitions
|
||||
#include "kvcache_connection.h"
|
||||
#include "log.h"
|
||||
|
||||
|
||||
/**
|
||||
* @brief RDMA communication handler for key-value cache
|
||||
*/
|
||||
class RDMACommunicator {
|
||||
public:
|
||||
// Construction/Destruction
|
||||
RDMACommunicator(std::string &role, int gpu_idx, std::string &port,
|
||||
std::vector<int64_t> local_key_cache,
|
||||
std::vector<int64_t> local_value_cache,
|
||||
int block_number, int block_bytes);
|
||||
~RDMACommunicator();
|
||||
|
||||
// Connection management
|
||||
int connect(const std::string &dst_ip, const std::string &dst_port);
|
||||
bool is_connected(const std::string &dst_ip, const std::string &dst_port);
|
||||
|
||||
// Core functionality
|
||||
int write_cache(const std::string &ip, const std::string &port,
|
||||
const std::vector<int64_t>& local_block_ids,
|
||||
const std::vector<int64_t>& remote_block_ids,
|
||||
int32_t layer_idx);
|
||||
|
||||
// Server Init
|
||||
int init_server();
|
||||
|
||||
// get socket nic ip
|
||||
std::string fetch_local_ip();
|
||||
|
||||
private:
|
||||
// Server Core functions
|
||||
int start_server(int sport, int sgid_idx, int gpu_index);
|
||||
|
||||
// Internal implementation methods
|
||||
void resize_vectors();
|
||||
void assign_pointers();
|
||||
void validate_addr();
|
||||
bool client_mr_register_per_layer(struct RdmaContext *ctx);
|
||||
bool server_mr_register_per_layer(struct RdmaContext *ctx);
|
||||
struct ibv_mr* register_memory_region(ibv_pd* pd, void* addr, size_t size,
|
||||
const std::string& desc, uint32_t access_flags);
|
||||
bool deregister_memory_regions(struct RdmaContext* ctx);
|
||||
|
||||
bool post_block_send(struct RdmaContext* ctx, int layer_idx,
|
||||
const std::vector<int64_t>& local_block_ids,
|
||||
bool is_key, std::vector<uint64_t>& remote_addr,
|
||||
uint32_t rkey, const std::string &ip,
|
||||
const std::string &port);
|
||||
|
||||
bool execute_rdma_writes(struct RdmaContext* ctx, int layer_idx,
|
||||
const std::vector<int64_t>& local_block_ids,
|
||||
bool is_key, std::vector<uint64_t>& remote_addr,
|
||||
uint32_t rkey);
|
||||
|
||||
void prepare_write_requests(struct ibv_sge* sge_list,
|
||||
struct ibv_send_wr* send_wr_list,
|
||||
int layer_idx,
|
||||
const std::vector<int64_t>& local_block_ids,
|
||||
bool is_key,
|
||||
std::vector<uint64_t>& remote_addr,
|
||||
uint32_t rkey);
|
||||
|
||||
bool execute_read_verification(struct RdmaContext* ctx,
|
||||
size_t block_idx,
|
||||
uint64_t remote_addr,
|
||||
uint32_t rkey,
|
||||
int layer_idx,
|
||||
const std::string& ip,
|
||||
const std::string& port);
|
||||
|
||||
bool post_send_with_retry(struct RdmaContext* ctx,
|
||||
struct ibv_send_wr* wr_list,
|
||||
size_t inflight_wr,
|
||||
bool need_poll);
|
||||
|
||||
// Connection management
|
||||
int client_listener();
|
||||
void close_server_connection(int fd, struct RdmaContext* ctx, int epollfd,
|
||||
std::map<int, struct RdmaContext*>& connectionContexts);
|
||||
void close_client_connection(int fd, struct RdmaContext* ctx, int epollfd);
|
||||
|
||||
void remove_conn(const std::string& url);
|
||||
struct RdmaContext *get_conn(const std::string &ip,
|
||||
const std::string &port);
|
||||
|
||||
// Member variables
|
||||
std::string splitwise_role; // Role in distributed system ("decode" or other)
|
||||
int gpu_idx; // GPU device index
|
||||
std::string port; // Communication port
|
||||
std::vector<int64_t> local_cache_key_ptr_layer_head_; // Key cache pointers
|
||||
std::vector<int64_t> local_cache_value_ptr_layer_head_; // Value cache pointers
|
||||
int block_number; // Number of blocks
|
||||
int block_size_byte; // Size of each block in bytes
|
||||
int layer_number; // Number of layers
|
||||
|
||||
std::vector<std::vector<void*>> local_cache_key_ptr_per_layer; // Per-layer key pointers
|
||||
std::vector<std::vector<void*>> local_cache_value_ptr_per_layer; // Per-layer value pointers
|
||||
|
||||
std::vector<struct ibv_mr*> write_mr_key_list; // Memory regions for key writes
|
||||
std::vector<struct ibv_mr*> write_mr_value_list; // Memory regions for value writes
|
||||
std::vector<struct ibv_mr*> write_cache_key_server_mr_list; // Server-side key memory regions
|
||||
std::vector<struct ibv_mr*> write_cache_value_server_mr_list; // Server-side value memory regions
|
||||
|
||||
std::vector<std::string> main_ip_list; // List of local IP addresses
|
||||
std::map<std::string, struct RdmaContext*> conn_map; // Active connections map
|
||||
std::mutex mutex_; // Thread synchronization mutex
|
||||
int rdma_event_channel_epoll_fd; // Epoll file descriptor
|
||||
struct ibv_pd *g_pd = NULL; // fd
|
||||
int RDMACommunicator_status; // Communicator status flag
|
||||
bool start_client_listener = false; // Client listener flag
|
||||
};
|
||||
|
||||
#endif // KVCACHE_RDMA_H
|
@@ -0,0 +1,117 @@
|
||||
#pragma once
|
||||
|
||||
/**
|
||||
* @file log.h
|
||||
* @brief Logging module for key-value cache system
|
||||
* @version 1.0.0
|
||||
* @copyright Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <stdio.h>
|
||||
#include <string.h>
|
||||
#include <time.h>
|
||||
#include <sys/time.h>
|
||||
#include <unistd.h> //for gethostname
|
||||
#include <sys/syscall.h>
|
||||
#include <pthread.h>
|
||||
#include <string>
|
||||
#include <ctime>
|
||||
#include <chrono>
|
||||
|
||||
#define KV_IS_DEBUG_ENABLED (std::getenv("KVCACHE_DEBUG"))
|
||||
#define FILE_NAME(x) (strrchr(x,'/') ? strrchr(x,'/')+1 : x)
|
||||
|
||||
static thread_local char __attribute__((__unused__)) str[64];
|
||||
|
||||
// for log levels (C++ enum class style in C)
|
||||
typedef enum {
|
||||
KV_LOG_LEVEL_INFO = 0,
|
||||
KV_LOG_LEVEL_DEBUG = 1,
|
||||
KV_LOG_LEVEL_WARN = 2,
|
||||
KV_LOG_LEVEL_ERROR = 3
|
||||
} KVLogLevel;
|
||||
|
||||
void debug_log(KVLogLevel level, bool enable_to_terminal, const char *filefunc,
|
||||
int line, const char *fmt, ...) __attribute__ ((format (printf, 5, 6)));
|
||||
|
||||
/**
|
||||
* @brief Unified logging macro to reduce duplication and improve maintainability.
|
||||
*
|
||||
* @param level Log level (e.g., INFO, DEBUG, WARN, ERR).
|
||||
* @param to_terminal If true, the log will be printed to terminal.
|
||||
* @param ... Format string and arguments (like printf).
|
||||
*/
|
||||
#define KV_LOG(level, to_terminal, ...) \
|
||||
debug_log(level, to_terminal, FILE_NAME(__FILE__), __LINE__, __VA_ARGS__)
|
||||
|
||||
// Public logging macros with terminal output
|
||||
#define WARN(...) KV_LOG(KV_LOG_LEVEL_WARN, true, __VA_ARGS__)
|
||||
#define ERR(...) KV_LOG(KV_LOG_LEVEL_ERROR, true, __VA_ARGS__)
|
||||
#define DEBUG(...) KV_LOG(KV_LOG_LEVEL_DEBUG, true, __VA_ARGS__)
|
||||
#define INFO(...) KV_LOG(KV_LOG_LEVEL_INFO, true, __VA_ARGS__)
|
||||
|
||||
#define gettid() ((pid_t)syscall(SYS_gettid))
|
||||
#define GET_CURRENT_TIME() do { \
|
||||
time_t timer = time(0); \
|
||||
struct tm* t = localtime(&timer); \
|
||||
char hostname[32]; \
|
||||
gethostname(hostname, 32); \
|
||||
sprintf(str, "%02d:%02d:%02d][%.32s][%d", \
|
||||
t->tm_hour, t->tm_min, t->tm_sec, hostname, gettid()); \
|
||||
} while (0)
|
||||
|
||||
#define LOGE(fmt, arg...) do { \
|
||||
GET_CURRENT_TIME(); \
|
||||
fprintf(stderr, "[%s][ERR][KV_CACHE][%s:%d] " \
|
||||
fmt "\n",str, \
|
||||
FILE_NAME(__FILE__), __LINE__, ## arg); \
|
||||
} while (0)
|
||||
|
||||
#define LOGW(fmt, arg...) do { \
|
||||
GET_CURRENT_TIME(); \
|
||||
fprintf(stderr, "[%s][WARN][KV_CACHE][%s:%d] " \
|
||||
fmt "\n",str, \
|
||||
FILE_NAME(__FILE__), __LINE__, ## arg); \
|
||||
} while (0)
|
||||
|
||||
#define LOGI(fmt, arg...) do { \
|
||||
GET_CURRENT_TIME(); \
|
||||
fprintf(stdout, "[%s][INFO][KV_CACHE][%s:%d] " \
|
||||
fmt "\n",str, \
|
||||
FILE_NAME(__FILE__), __LINE__, ## arg); \
|
||||
} while (0)
|
||||
|
||||
#define LOGD(fmt, arg...) do { \
|
||||
if (KV_IS_DEBUG_ENABLED) { \
|
||||
GET_CURRENT_TIME(); \
|
||||
fprintf(stdout, "[%s][DBG][KV_CACHE][%s:%d] " \
|
||||
fmt "\n", str, \
|
||||
FILE_NAME(__FILE__), __LINE__, ## arg); \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
#define LOGD_IF(cond, fmt, ...) do { \
|
||||
if ((cond)) \
|
||||
LOGD(fmt, __VA_ARGS__); \
|
||||
} while (0)
|
||||
|
||||
#define LOGD_RAW(fmt, arg...) do { \
|
||||
if (ENV_ENABLE_RAW("KV_IS_DEBUG_ENABLED")) { \
|
||||
GET_CURRENT_TIME(); \
|
||||
fprintf(stdout, "[%s][DBG][KV_CACHE][%s:%d] " \
|
||||
fmt "\n", str, \
|
||||
FILE_NAME(__FILE__), __LINE__, ## arg); \
|
||||
} \
|
||||
} while (0)
|
@@ -0,0 +1,315 @@
|
||||
#ifndef KVCACHE_UTILS_H
|
||||
#define KVCACHE_UTILS_H
|
||||
|
||||
#include <ctime>
|
||||
#include <chrono>
|
||||
#include <iostream>
|
||||
#include <string>
|
||||
#include <cstdlib>
|
||||
#include <algorithm>
|
||||
#include <cctype>
|
||||
#include <stdexcept>
|
||||
#include <cstdio>
|
||||
#include <ifaddrs.h>
|
||||
#include <net/if.h>
|
||||
#include <netinet/in.h>
|
||||
#include <arpa/inet.h>
|
||||
#include <vector>
|
||||
#include <cstring>
|
||||
#include "log.h"
|
||||
|
||||
#define PATH_MAX 4096 /* # chars in a path name including nul */
|
||||
#define RDMA_WR_LIST_MAX_SIZE 32
|
||||
#define RDMA_SQ_MAX_SIZE 1024
|
||||
|
||||
#define RDMA_DEFAULT_PORT 20001
|
||||
#define RDMA_TCP_CONNECT_SIZE 1024
|
||||
#define RDMA_POLL_CQE_TIMEOUT 30
|
||||
|
||||
/// @brief Connection status enumeration
|
||||
enum class ConnStatus {
|
||||
kConnected, // Connection is active
|
||||
kDisconnected, // Connection is not active
|
||||
kError, // Connection error occurred
|
||||
kTimeout, // Connection timed out
|
||||
kInvalidParameters // Invalid connection parameters
|
||||
};
|
||||
|
||||
/// @brief Queue Pair (QP) setup result status
|
||||
enum class QpStatus {
|
||||
kSuccess, // Successfully transitioned QP to RTS
|
||||
kInvalidParameters, // ctx or dest is null
|
||||
kDeviceQueryFailed, // ibv_query_device failed
|
||||
kPortQueryFailed, // ibv_query_port failed
|
||||
kMtuMismatch, // Requested MTU exceeds active MTU
|
||||
kModifyToRTRFailed, // Failed to modify QP to RTR
|
||||
kModifyToRTSFailed // Failed to modify QP to RTS
|
||||
};
|
||||
|
||||
/**
|
||||
* @brief Convert PCI bus ID string to int64_t
|
||||
* @param busId PCI bus ID string (e.g. "0000:3b:00.0")
|
||||
* @param[out] id Converted numeric ID
|
||||
*/
|
||||
inline void busid_to_int64(const char *busId, int64_t *id) {
|
||||
char hexStr[17] = {0};
|
||||
int hexOffset = 0;
|
||||
|
||||
// Filter valid hex characters
|
||||
for (int i = 0; hexOffset < sizeof(hexStr) - 1 && busId[i] != '\0'; i++) {
|
||||
char c = busId[i];
|
||||
if (c == '.' || c == ':') continue;
|
||||
|
||||
if ((c >= '0' && c <= '9') ||
|
||||
(c >= 'A' && c <= 'F') ||
|
||||
(c >= 'a' && c <= 'f')) {
|
||||
hexStr[hexOffset++] = c;
|
||||
}
|
||||
}
|
||||
|
||||
*id = strtol(hexStr, NULL, 16);
|
||||
}
|
||||
|
||||
class NetworkInterfaceManager {
|
||||
public:
|
||||
struct InterfaceInfo {
|
||||
std::string name;
|
||||
std::string ip;
|
||||
bool is_up;
|
||||
bool is_running;
|
||||
bool is_loopback;
|
||||
|
||||
bool isUsable() const {
|
||||
return is_up && is_running && !is_loopback;
|
||||
}
|
||||
};
|
||||
|
||||
static std::vector<InterfaceInfo> getAllInterfaces() {
|
||||
std::vector<InterfaceInfo> interfaces;
|
||||
struct ifaddrs *ifaddrs_ptr = nullptr;
|
||||
|
||||
if (getifaddrs(&ifaddrs_ptr) == -1) {
|
||||
return interfaces;
|
||||
}
|
||||
|
||||
for (struct ifaddrs *ifa = ifaddrs_ptr; ifa != nullptr; ifa = ifa->ifa_next) {
|
||||
if (ifa->ifa_addr == nullptr) continue;
|
||||
if (ifa->ifa_addr->sa_family != AF_INET) continue;
|
||||
|
||||
InterfaceInfo info;
|
||||
info.name = ifa->ifa_name;
|
||||
info.is_up = (ifa->ifa_flags & IFF_UP) != 0;
|
||||
info.is_running = (ifa->ifa_flags & IFF_RUNNING) != 0;
|
||||
info.is_loopback = (ifa->ifa_flags & IFF_LOOPBACK) != 0;
|
||||
|
||||
struct sockaddr_in* sa = (struct sockaddr_in*)ifa->ifa_addr;
|
||||
char ip_str[INET_ADDRSTRLEN];
|
||||
inet_ntop(AF_INET, &sa->sin_addr, ip_str, INET_ADDRSTRLEN);
|
||||
info.ip = ip_str;
|
||||
|
||||
interfaces.push_back(info);
|
||||
}
|
||||
|
||||
freeifaddrs(ifaddrs_ptr);
|
||||
return interfaces;
|
||||
}
|
||||
|
||||
static std::string getFirstUsableInterface() {
|
||||
auto interfaces = getAllInterfaces();
|
||||
|
||||
for (const auto& iface : interfaces) {
|
||||
if (iface.isUsable()) {
|
||||
return iface.name;
|
||||
}
|
||||
}
|
||||
return "";
|
||||
}
|
||||
|
||||
static void displayAllInterfaces() {
|
||||
auto interfaces = getAllInterfaces();
|
||||
|
||||
printf("Available network interfaces:\n");
|
||||
for (const auto& iface : interfaces) {
|
||||
printf(" %s: %s [%s%s%s]\n",
|
||||
iface.name.c_str(),
|
||||
iface.ip.c_str(),
|
||||
iface.is_up ? "UP" : "DOWN",
|
||||
iface.is_running ? ",RUNNING" : "",
|
||||
iface.is_loopback ? ",LOOPBACK" : "");
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
class KVCacheConfig {
|
||||
private:
|
||||
// Configuration values
|
||||
int rdma_gid_index_;
|
||||
bool has_rdma_dest_port_override_; // 替代 std::optional
|
||||
int rdma_dest_port_override_;
|
||||
const char* socket_interface_;
|
||||
char* socket_interface_buffer_;
|
||||
bool gdrcopy_flush_enabled_;
|
||||
bool verify_read_enabled_;
|
||||
bool debug_mode_enabled_;
|
||||
bool debug_output_enabled_;
|
||||
const char* debug_file_path_;
|
||||
const char* error_file_path_;
|
||||
bool relax_ordering_enabled_;
|
||||
int ib_timeout_;
|
||||
const char* rdma_nics_;
|
||||
|
||||
// Private constructor for singleton pattern
|
||||
KVCacheConfig() {
|
||||
// Initialize configuration from environment variables
|
||||
rdma_gid_index_ = parse_int_value(
|
||||
std::getenv("KVCACHE_RDMA_GID_INDEX"), 3, "KVCACHE_RDMA_GID_INDEX");
|
||||
|
||||
// Parse optional RDMA port override
|
||||
const char* port_value = std::getenv("SET_RDMA_DEST_PORT");
|
||||
has_rdma_dest_port_override_ = false; // 默认为false
|
||||
if (port_value) {
|
||||
try {
|
||||
rdma_dest_port_override_ = std::stoi(std::string(port_value));
|
||||
has_rdma_dest_port_override_ = true;
|
||||
} catch (const std::exception& e) {
|
||||
fprintf(stderr, "Invalid SET_RDMA_DEST_PORT value: '%s', ignoring\n", port_value);
|
||||
}
|
||||
}
|
||||
|
||||
const char* env_interface = std::getenv("KVCACHE_SOCKET_IFNAME");
|
||||
|
||||
if (env_interface && env_interface[0] != '\0') {
|
||||
socket_interface_ = env_interface;
|
||||
printf("Using specified interface: %s\n", socket_interface_);
|
||||
} else {
|
||||
std::string iface = NetworkInterfaceManager::getFirstUsableInterface();
|
||||
if (!iface.empty()) {
|
||||
socket_interface_buffer_ = new char[iface.size() + 1];
|
||||
std::strcpy(socket_interface_buffer_, iface.c_str());
|
||||
socket_interface_ = socket_interface_buffer_;
|
||||
printf("Auto-detected interface: %s\n", socket_interface_);
|
||||
} else {
|
||||
fprintf(stderr, "Warning: No usable network interface found\n");
|
||||
socket_interface_ = "";
|
||||
}
|
||||
NetworkInterfaceManager::displayAllInterfaces();
|
||||
}
|
||||
|
||||
socket_interface_ = std::getenv("KVCACHE_SOCKET_IFNAME");
|
||||
debug_file_path_ = std::getenv("KVCACHE_DEBUG_FILE");
|
||||
error_file_path_ = std::getenv("KVCACHE_ERROR_FILE");
|
||||
|
||||
gdrcopy_flush_enabled_ = parse_bool_value(std::getenv("KVCACHE_GDRCOPY_FLUSH_ENABLE"));
|
||||
verify_read_enabled_ = parse_bool_value(std::getenv("KVCACHE_VERIFY_READ"));
|
||||
debug_mode_enabled_ = parse_bool_value(std::getenv("KVCACHE_DEBUG")) ||
|
||||
parse_bool_value(std::getenv("KV_IS_DEBUG_ENABLED"));
|
||||
debug_output_enabled_ = parse_bool_value(std::getenv("KVCACHE_DEBUG_OUTPUT"));
|
||||
|
||||
relax_ordering_enabled_ = parse_bool_value(std::getenv("KVCACHE_RELAX_ORDERING"));
|
||||
|
||||
ib_timeout_ = parse_int_value(
|
||||
std::getenv("KVCACHE_IB_TIMEOUT"),
|
||||
18,
|
||||
"KVCACHE_IB_TIMEOUT"
|
||||
);
|
||||
|
||||
rdma_nics_ = std::getenv("KVCACHE_RDMA_NICS");
|
||||
}
|
||||
|
||||
// Helper methods
|
||||
bool parse_bool_value(const char* value) {
|
||||
if (!value) return false;
|
||||
|
||||
std::string str_value(value);
|
||||
std::transform(str_value.begin(), str_value.end(), str_value.begin(), ::tolower);
|
||||
|
||||
return (str_value == "1" || str_value == "true" ||
|
||||
str_value == "on" || str_value == "yes");
|
||||
}
|
||||
|
||||
int parse_int_value(const char* value, int default_value, const char* env_name) {
|
||||
if (!value) return default_value;
|
||||
|
||||
try {
|
||||
return std::stoi(std::string(value));
|
||||
} catch (const std::invalid_argument& e) {
|
||||
fprintf(stderr, "Invalid value for %s: '%s', using default: %d\n",
|
||||
env_name, value, default_value);
|
||||
return default_value;
|
||||
} catch (const std::out_of_range& e) {
|
||||
fprintf(stderr, "%s value out of range: '%s', using default: %d\n",
|
||||
env_name, value, default_value);
|
||||
return default_value;
|
||||
}
|
||||
}
|
||||
|
||||
public:
|
||||
// Prevent copying and assignment
|
||||
KVCacheConfig(const KVCacheConfig&) = delete;
|
||||
KVCacheConfig& operator=(const KVCacheConfig&) = delete;
|
||||
|
||||
// Get singleton instance
|
||||
static KVCacheConfig& getInstance() {
|
||||
static KVCacheConfig instance;
|
||||
return instance;
|
||||
}
|
||||
|
||||
int get_ib_timeout() const { return ib_timeout_; }
|
||||
|
||||
// Configuration retrieval methods
|
||||
int get_rdma_gid_index() const { return rdma_gid_index_; }
|
||||
|
||||
int resolve_rdma_dest_port(int default_port) const {
|
||||
return has_rdma_dest_port_override_ ? rdma_dest_port_override_ : default_port;
|
||||
}
|
||||
|
||||
int resolve_rdma_dest_port(const std::string& default_port) const {
|
||||
try {
|
||||
return resolve_rdma_dest_port(std::stoi(default_port));
|
||||
} catch (const std::exception& e) {
|
||||
fprintf(stderr, "Invalid default port string: %s\n", default_port.c_str());
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
|
||||
const char* get_socket_interface() const { return socket_interface_; }
|
||||
const char* get_debug_file_path() const { return debug_file_path_; }
|
||||
const char* get_error_file_path() const { return error_file_path_; }
|
||||
const char* get_rdma_nics() const { return rdma_nics_; }
|
||||
|
||||
// Feature check methods
|
||||
bool is_gdrcopy_flush_enabled() const { return gdrcopy_flush_enabled_; }
|
||||
bool is_verify_read_enabled() const { return verify_read_enabled_; }
|
||||
bool is_debug_mode_enabled() const { return debug_mode_enabled_; }
|
||||
bool is_debug_output_enabled() const { return debug_output_enabled_; }
|
||||
bool is_relax_ordering_enabled() const { return relax_ordering_enabled_; }
|
||||
|
||||
// Display configuration
|
||||
void displayConfiguration() const {
|
||||
INFO("KVCache Configuration:\n");
|
||||
INFO("Init KVCacheConfig RDMA GID Index: %d\n", rdma_gid_index_);
|
||||
|
||||
if (has_rdma_dest_port_override_) {
|
||||
INFO("Init KVCacheConfig RDMA Destination Port Override: %d\n", rdma_dest_port_override_);
|
||||
}
|
||||
|
||||
if (socket_interface_) {
|
||||
INFO("Init KVCacheConfig Socket Interface: %s\n", socket_interface_);
|
||||
}
|
||||
|
||||
INFO("Init KVCacheConfig GDRCopy Flush: %s\n", gdrcopy_flush_enabled_ ? "enabled" : "disabled");
|
||||
INFO("Init KVCacheConfig Verify Read: %s\n", verify_read_enabled_ ? "enabled" : "disabled");
|
||||
INFO("Init KVCacheConfig Debug Mode: %s\n", debug_mode_enabled_ ? "enabled" : "disabled");
|
||||
INFO("Init KVCacheConfig Debug Output: %s\n", debug_output_enabled_ ? "enabled" : "disabled");
|
||||
|
||||
if (debug_file_path_) {
|
||||
INFO("Init KVCacheConfig Debug File: %s\n", debug_file_path_);
|
||||
}
|
||||
|
||||
if (error_file_path_) {
|
||||
INFO("Init KVCacheConfig Error File: %s\n", error_file_path_);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
#endif
|
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,212 @@
|
||||
/**
|
||||
* @file log.cpp
|
||||
* @brief Logging module implementation for key-value cache system
|
||||
* @version 1.0.0
|
||||
* @copyright Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <stdlib.h>
|
||||
#include <stdarg.h>
|
||||
#include <sys/syscall.h>
|
||||
#include <sys/stat.h>
|
||||
#include <libgen.h>
|
||||
#include <errno.h>
|
||||
#include <string.h>
|
||||
#include "log.h"
|
||||
#include "util.h"
|
||||
|
||||
static int pid = -1;
|
||||
static __thread int tid = -1;
|
||||
static char hostname[64];
|
||||
char global_log_last_error[1024] = "";
|
||||
FILE *global_debug_file = stdout;
|
||||
FILE *global_error_file = stdout;
|
||||
static char global_debug_file_name[PATH_MAX+1] = "";
|
||||
static char global_err_file_name[PATH_MAX+1] = "";
|
||||
int global_debug_level = -1;
|
||||
pthread_mutex_t global_debug_lock = PTHREAD_MUTEX_INITIALIZER;
|
||||
pthread_mutex_t global_log_file_lock = PTHREAD_MUTEX_INITIALIZER;
|
||||
|
||||
void log_file_init(FILE **kv_cache_log_file, const char *kv_cache_log_file_env, char *logFileName) {
|
||||
int c = 0;
|
||||
char *dfn = logFileName;
|
||||
while (c < PATH_MAX && kv_cache_log_file_env[c] != '\0') {
|
||||
if (kv_cache_log_file_env[c++] != '%') {
|
||||
*dfn++ = kv_cache_log_file_env[c - 1];
|
||||
continue;
|
||||
}
|
||||
switch (kv_cache_log_file_env[c++]) {
|
||||
case '%': // Double %
|
||||
*dfn++ = '%';
|
||||
break;
|
||||
case 'h': // %h = hostname
|
||||
dfn += snprintf(dfn, PATH_MAX, "%s", hostname);
|
||||
break;
|
||||
case 'p': // %p = pid
|
||||
dfn += snprintf(dfn, PATH_MAX, "%d", pid);
|
||||
break;
|
||||
default: // Echo everything we don't understand
|
||||
*dfn++ = '%';
|
||||
*dfn++ = kv_cache_log_file_env[c - 1];
|
||||
break;
|
||||
}
|
||||
}
|
||||
*dfn = '\0';
|
||||
if (logFileName[0] != '\0') {
|
||||
FILE *file = fopen(logFileName, "w");
|
||||
if (file != nullptr) {
|
||||
setbuf(file, nullptr); // disable buffering
|
||||
*kv_cache_log_file = file;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void recreate_log_file(FILE **kv_cache_log_file, char *logFileName) {
|
||||
if (logFileName[0] != '\0') {
|
||||
pthread_mutex_lock(&global_log_file_lock);
|
||||
FILE *file = fopen(logFileName, "a"); // Use "a" mode to append if file exists, otherwise create it
|
||||
// close the previous log file if it exists
|
||||
if (*kv_cache_log_file != NULL && *kv_cache_log_file != file) {
|
||||
fclose(*kv_cache_log_file);
|
||||
*kv_cache_log_file = NULL;
|
||||
}
|
||||
if (file != NULL) {
|
||||
setbuf(file, NULL); // disable buffering
|
||||
*kv_cache_log_file = file;
|
||||
}
|
||||
pthread_mutex_unlock(&global_log_file_lock);
|
||||
}
|
||||
}
|
||||
|
||||
void debug_init() {
|
||||
pthread_mutex_lock(&global_debug_lock);
|
||||
if (global_debug_level != -1) {
|
||||
pthread_mutex_unlock(&global_debug_lock);
|
||||
return;
|
||||
}
|
||||
|
||||
const char* kv_cache_debug = std::getenv("KV_IS_DEBUG_ENABLED");
|
||||
int tempg_kv_cache_debug_level = -1;
|
||||
|
||||
if (kv_cache_debug == NULL) {
|
||||
tempg_kv_cache_debug_level = KV_LOG_LEVEL_INFO;
|
||||
} else if (strcasecmp(kv_cache_debug, "0") == 0) {
|
||||
tempg_kv_cache_debug_level = KV_LOG_LEVEL_INFO;
|
||||
} else if (strcasecmp(kv_cache_debug, "1") == 0) {
|
||||
tempg_kv_cache_debug_level = KV_LOG_LEVEL_DEBUG;
|
||||
} else if (strcasecmp(kv_cache_debug, "2") == 0) {
|
||||
tempg_kv_cache_debug_level = KV_LOG_LEVEL_WARN;
|
||||
} else if (strcasecmp(kv_cache_debug, "3") == 0) {
|
||||
tempg_kv_cache_debug_level = KV_LOG_LEVEL_ERROR;
|
||||
} else {
|
||||
tempg_kv_cache_debug_level = KV_LOG_LEVEL_INFO;
|
||||
}
|
||||
|
||||
gethostname(hostname, 64);
|
||||
pid = getpid();
|
||||
|
||||
const char* g_kv_cache_debug_fileEnv = KVCacheConfig::getInstance().get_debug_file_path();
|
||||
if (tempg_kv_cache_debug_level >= KV_LOG_LEVEL_INFO && g_kv_cache_debug_fileEnv != NULL) {
|
||||
log_file_init(&global_debug_file, g_kv_cache_debug_fileEnv, global_debug_file_name);
|
||||
}
|
||||
|
||||
const char* g_kv_cache_error_fileEnv = KVCacheConfig::getInstance().get_error_file_path();
|
||||
if (tempg_kv_cache_debug_level >= KV_LOG_LEVEL_INFO && g_kv_cache_error_fileEnv != NULL) {
|
||||
log_file_init(&global_error_file, g_kv_cache_error_fileEnv, global_err_file_name);
|
||||
char buffer[1024];
|
||||
size_t len = 0;
|
||||
char timeBuffer[80]; // Buffer to hold the formatted time
|
||||
std::time_t absoluteTime = std::chrono::system_clock::to_time_t(std::chrono::system_clock::now());
|
||||
std::strftime(timeBuffer, sizeof(timeBuffer), "%Y-%m-%d %H:%M:%S", std::localtime(&absoluteTime));
|
||||
len = snprintf(buffer, sizeof(buffer), "%s KV_CACHE START ", timeBuffer);
|
||||
buffer[len++] = '\n';
|
||||
if (global_error_file != NULL) {
|
||||
fwrite(buffer, 1, len, global_error_file);
|
||||
}
|
||||
}
|
||||
__atomic_store_n(&global_debug_level, tempg_kv_cache_debug_level, __ATOMIC_RELEASE);
|
||||
pthread_mutex_unlock(&global_debug_lock);
|
||||
}
|
||||
|
||||
/* Common logging function used by the INFO, DEBUG and WARN macros
|
||||
* Also exported to the dynamically loadable Net transport modules so
|
||||
* they can share the debugging mechanisms and output files
|
||||
*/
|
||||
void debug_log(KVLogLevel level, bool enable_to_terminal, const char *filefunc, int line, const char *fmt, ...) {
|
||||
if (__atomic_load_n(&global_debug_level, __ATOMIC_ACQUIRE) == -1) {
|
||||
debug_init();
|
||||
}
|
||||
|
||||
// Save the last error (WARN) as a human readable string
|
||||
if (level == KV_LOG_LEVEL_WARN) {
|
||||
pthread_mutex_lock(&global_debug_lock);
|
||||
va_list vargs;
|
||||
va_start(vargs, fmt);
|
||||
(void) vsnprintf(global_log_last_error, sizeof(global_log_last_error), fmt, vargs);
|
||||
va_end(vargs);
|
||||
pthread_mutex_unlock(&global_debug_lock);
|
||||
}
|
||||
|
||||
if (tid == -1) {
|
||||
tid = syscall(SYS_gettid);
|
||||
}
|
||||
|
||||
char buffer[1024];
|
||||
size_t len = 0;
|
||||
// Convert timestamp to absolute time and directly use it in the snprintf function
|
||||
std::time_t absoluteTime = std::chrono::system_clock::to_time_t(std::chrono::system_clock::now());
|
||||
char timeBuffer[80]; // Buffer to hold the formatted time
|
||||
std::strftime(timeBuffer, sizeof(timeBuffer), "%Y-%m-%d %H:%M:%S", std::localtime(&absoluteTime));
|
||||
|
||||
if (level == KV_LOG_LEVEL_WARN) {
|
||||
len = snprintf(buffer, sizeof(buffer), "\n%s %s:%d:%d %s:%d KV_CACHE WARN ",
|
||||
timeBuffer, hostname, pid, tid, filefunc, line);
|
||||
} else if (level == KV_LOG_LEVEL_INFO) {
|
||||
len = snprintf(buffer, sizeof(buffer), "%s %s:%d:%d KV_CACHE INFO ", timeBuffer, hostname, pid, tid);
|
||||
} else if (level == KV_LOG_LEVEL_DEBUG) {
|
||||
len = snprintf(buffer, sizeof(buffer), "%s %s:%d:%d KV_CACHE DEBUG ", timeBuffer, hostname, pid, tid);
|
||||
} else if (level == KV_LOG_LEVEL_ERROR) {
|
||||
len = snprintf(buffer, sizeof(buffer), "%s %s:%d:%d KV_CACHE ERROR ", timeBuffer, hostname, pid, tid);
|
||||
} else {
|
||||
len = snprintf(buffer, sizeof(buffer), "%s %s:%d:%d KV_CACHE ", timeBuffer, hostname, pid, tid);
|
||||
}
|
||||
|
||||
if (len) {
|
||||
va_list vargs;
|
||||
va_start(vargs, fmt);
|
||||
len += vsnprintf(buffer + len, sizeof(buffer) - len, fmt, vargs);
|
||||
va_end(vargs);
|
||||
// vsnprintf may return len > sizeof(buffer) in the case of a truncated output.
|
||||
// Rewind len so that we can replace the final \0 by \n
|
||||
if (len > sizeof(buffer)) {
|
||||
len = sizeof(buffer) - 1;
|
||||
}
|
||||
buffer[len++] = '\n';
|
||||
if (access(global_debug_file_name, F_OK) != 0) {
|
||||
recreate_log_file(&global_debug_file, global_debug_file_name);
|
||||
}
|
||||
if (enable_to_terminal) {
|
||||
fwrite(buffer, 1, len, global_debug_file);
|
||||
}
|
||||
if (level == KV_LOG_LEVEL_WARN && global_error_file != stdout) {
|
||||
if (access(global_err_file_name, F_OK) != 0) {
|
||||
recreate_log_file(&global_error_file, global_err_file_name);
|
||||
}
|
||||
if (global_error_file != NULL) {
|
||||
fwrite(buffer, 1, len, global_error_file);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
@@ -0,0 +1,22 @@
|
||||
#include "kvcache_connection.h"
|
||||
#include "kvcache_rdma.h"
|
||||
#include "pybind11/pybind11.h"
|
||||
#include "pybind11/stl.h"
|
||||
|
||||
namespace py = pybind11;
|
||||
|
||||
PYBIND11_MODULE(rdma_comm, m) {
|
||||
m.doc() = R"pbdoc(kv cache messager)pbdoc";
|
||||
py::class_<RDMACommunicator>(m, "RDMACommunicator")
|
||||
.def(py::init<std::string &, int, std::string &, std::vector<int64_t>,
|
||||
std::vector<int64_t>, int, int>())
|
||||
.def("connect", &RDMACommunicator::connect)
|
||||
.def("is_connected", &RDMACommunicator::is_connected)
|
||||
.def("write_cache", &RDMACommunicator::write_cache);
|
||||
|
||||
#ifdef VERSION_INFO
|
||||
m.attr("__version__") = VERSION_INFO;
|
||||
#else
|
||||
m.attr("__version__") = "dev";
|
||||
#endif
|
||||
}
|
@@ -0,0 +1,76 @@
|
||||
"""
|
||||
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
|
||||
from fastdeploy.utils import get_logger
|
||||
|
||||
logger = get_logger("cache_messager", "cache_messager.log")
|
||||
|
||||
|
||||
class RDMACommManager:
|
||||
"""
|
||||
RDMACommManager to manage rdma communication
|
||||
"""
|
||||
|
||||
def __init__(self, splitwise_role, rank, gpu_id, cache_k_ptr_list, \
|
||||
cache_v_ptr_list, max_block_num, block_bytes, rdma_port):
|
||||
try:
|
||||
import rdma_comm
|
||||
except:
|
||||
logger.error(f"The installation of the RDMA library failed." \
|
||||
"Confirm whether your network card supports RDMA transmission.")
|
||||
return
|
||||
self.messager = rdma_comm.RDMACommunicator(
|
||||
splitwise_role,
|
||||
rank,
|
||||
str(rdma_port) if splitwise_role == "decode" else "0",
|
||||
cache_k_ptr_list,
|
||||
cache_v_ptr_list,
|
||||
max_block_num,
|
||||
block_bytes,
|
||||
)
|
||||
self.splitwise_role = splitwise_role
|
||||
self.connected_rdma = set()
|
||||
logger.info(f"init rdma messager {gpu_id} {rdma_port}")
|
||||
|
||||
def connect(self, ip, port):
|
||||
"""
|
||||
Connect to remote gpu and write cache.
|
||||
"""
|
||||
assert self.splitwise_role == "prefill", "only prefill can call this method"
|
||||
addr = f"{ip}:{str(port)}"
|
||||
if addr in self.connected_rdma:
|
||||
return True
|
||||
ret = self.messager.is_connected(ip, str(port))
|
||||
if ret:
|
||||
self.connected_rdma.add(addr)
|
||||
return True
|
||||
|
||||
ret = self.messager.connect(ip, str(port))
|
||||
logger.info(
|
||||
f"connect to remote rdma address {ip}:{port} status is {ret}")
|
||||
if ret == 0:
|
||||
self.connected_rdma.add(addr)
|
||||
return ret == 0
|
||||
|
||||
def write_cache(self, ip, port, local_block_ids, remote_block_ids,
|
||||
layer_idx):
|
||||
"""
|
||||
Connect to remote gpu and write cache.
|
||||
"""
|
||||
return self.messager.write_cache(ip, str(port), local_block_ids,
|
||||
remote_block_ids, layer_idx)
|
||||
|
||||
|
Reference in New Issue
Block a user