mirror of https://github.com/PaddlePaddle/FastDeploy.git (synced 2025-10-05 08:37:06 +08:00)
polish code with new pre-commit rule (#2923)
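The hunks below are pure formatting changes: wrapped argument lists are either collapsed onto one line or exploded with a trailing comma, single-quoted strings become double-quoted, and the redundant object base class is dropped. As a minimal sketch, assuming the new pre-commit stack runs Black with a 120-character line length (the actual hook configuration is not part of this diff), the same kind of rewrite can be reproduced locally:

# Sketch only: reproduces the kind of rewrite seen in the first hunk, assuming
# the pre-commit configuration uses Black with line_length=120 (an assumption;
# the repository's real .pre-commit-config.yaml is not shown in this diff).
import black

old_import = (
    "from fastdeploy.cache_manager.transfer_factory import (IPCCommManager,\n"
    "                                                       RDMACommManager)\n"
)

# Black re-flows the wrapped import so that it fits on a single line,
# matching the change in the first hunk below.
print(black.format_str(old_import, mode=black.Mode(line_length=120)))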
@@ -21,51 +21,54 @@ import time
 import numpy as np
 import paddle

-from fastdeploy.cache_manager.transfer_factory import (IPCCommManager,
-                                                        RDMACommManager)
+from fastdeploy.cache_manager.transfer_factory import IPCCommManager, RDMACommManager
 from fastdeploy.inter_communicator import EngineWorkerQueue, IPCSignal
 from fastdeploy.utils import get_logger

 logger = get_logger("cache_messager", "cache_messager.log")


-class CacheMessager(object):
+class CacheMessager:
     """
     CacheMessager is used to send the cache data between the engine worker and the cache server.
     """

-    def __init__(self,
-                 splitwise_role,
-                 transfer_protocol,
-                 pod_ip,
-                 engine_worker_queue_port,
-                 local_data_parallel_id,
-                 gpu_cache_kvs,
-                 rank,
-                 nranks,
-                 num_layers,
-                 gpu_id=0,
-                 rdma_port=None):
+    def __init__(
+        self,
+        splitwise_role,
+        transfer_protocol,
+        pod_ip,
+        engine_worker_queue_port,
+        local_data_parallel_id,
+        gpu_cache_kvs,
+        rank,
+        nranks,
+        num_layers,
+        gpu_id=0,
+        rdma_port=None,
+    ):
         """
         Initialize the CacheMessager object.

         Args:
             splitwise_role (str): splitwise_role only can be 'prefill' or 'decode'.
             transfer_protocol (str): support ipc and rdma
             engine_worker_queue_port (int): engine_worker_queue port
             gpu_cache_kvs (dict): GPU kv cache
             rank (int): current rank
             nranks (int): global rank number
             num_layers (int): model layer number
             gpu_id (int, optional): GPU ID
             rdma_port (int, optional): RDMA port

         Returns:
             None
         """

-        assert splitwise_role in ["prefill", "decode"], \
-            "splitwise_role must be prefill or decode"
+        assert splitwise_role in [
+            "prefill",
+            "decode",
+        ], "splitwise_role must be prefill or decode"
         self.splitwise_role = splitwise_role
         self.gpu_cache_kvs = gpu_cache_kvs
         self.rank = rank
@@ -76,11 +79,11 @@ class CacheMessager(object):
             is_server=False,
             num_client=self.nranks,
             client_id=self.rank,
-            local_data_parallel_id=local_data_parallel_id)
+            local_data_parallel_id=local_data_parallel_id,
+        )
         transfer_protocol = transfer_protocol.split(",")

-        logger.info(f"splitwise role: {splitwise_role}, {transfer_protocol}"
-                    f"rank: {rank}")
+        logger.info(f"splitwise role: {splitwise_role}, {transfer_protocol}" f"rank: {rank}")

         # 1. initialize the cache_k_ptr_list and cache_v_ptr_list
         self.num_layers = num_layers
@@ -90,10 +93,8 @@ class CacheMessager(object):
         cache_v = []
         self.messager = {}
         for layer_idx in range(self.num_layers):
-            key_cache = self.gpu_cache_kvs[
-                f'key_caches_{layer_idx}_rank{self.rank}_device{gpu_id}']
-            val_cache = self.gpu_cache_kvs[
-                f'value_caches_{layer_idx}_rank{self.rank}_device{gpu_id}']
+            key_cache = self.gpu_cache_kvs[f"key_caches_{layer_idx}_rank{self.rank}_device{gpu_id}"]
+            val_cache = self.gpu_cache_kvs[f"value_caches_{layer_idx}_rank{self.rank}_device{gpu_id}"]
             cache_k.append(key_cache)
             cache_v.append(val_cache)
             cache_k_ptr_list.append(key_cache.data_ptr())
@@ -109,7 +110,8 @@ class CacheMessager(object):
         block_bytes *= 2
         logger.info(
             f"layers {num_layers} cache_shape: {cache_shape}, max_block_num: {max_block_num}, "
-            f"block_bytes: {block_bytes}, dtype: {key_cache.dtype}")
+            f"block_bytes: {block_bytes}, dtype: {key_cache.dtype}"
+        )
         self.block_bytes = block_bytes

         # 3. initialize the messager
@@ -122,24 +124,26 @@ class CacheMessager(object):
                     cache_v,
                 )
                 local_device_id = int(str(cache_k[0].place)[-2])
-                logger.info(
-                    f"done create ipc_comm with local_device_id:{local_device_id}, "
-                )
+                logger.info(f"done create ipc_comm with local_device_id:{local_device_id}, ")

             elif protocol == "rdma":
-                logger.info(
-                    f"splitwise_role rdma: {self.splitwise_role}, rank: {self.rank}, gpu_id: {gpu_id}"
-                )
+                logger.info(f"splitwise_role rdma: {self.splitwise_role}, rank: {self.rank}, gpu_id: {gpu_id}")

                 self.messager[protocol] = RDMACommManager(
-                    splitwise_role, rank, gpu_id, cache_k_ptr_list,
-                    cache_v_ptr_list, max_block_num, block_bytes, rdma_port)
+                    splitwise_role,
+                    rank,
+                    gpu_id,
+                    cache_k_ptr_list,
+                    cache_v_ptr_list,
+                    max_block_num,
+                    block_bytes,
+                    rdma_port,
+                )

         self.gpu_id = gpu_id
         self.cache_info = dict()

-        layerwise_send_cache_thread = threading.Thread(
-            target=self._prefill_layerwise_send_cache_thread)
+        layerwise_send_cache_thread = threading.Thread(target=self._prefill_layerwise_send_cache_thread)
         layerwise_send_cache_thread.daemon = True
         layerwise_send_cache_thread.start()

@@ -159,26 +163,30 @@ class CacheMessager(object):
                 array=prefilled_step_idx_data,
                 dtype=np.int32,
                 suffix=self.gpu_id,
-                create=True)
+                create=True,
+            )
             layer_shm_value = IPCSignal(
                 name=f"splitwise_complete_prefilled_layer_{self.rank}",
                 array=prefilled_layer_idx_data,
                 dtype=np.int32,
                 suffix=self.gpu_id,
-                create=True)
+                create=True,
+            )
         except:
             step_shm_value = IPCSignal(
                 name=f"splitwise_complete_prefilled_step_{self.rank}",
                 array=prefilled_step_idx_data,
                 dtype=np.int32,
                 suffix=self.gpu_id,
-                create=False)
+                create=False,
+            )
             layer_shm_value = IPCSignal(
                 name=f"splitwise_complete_prefilled_layer_{self.rank}",
                 array=prefilled_layer_idx_data,
                 dtype=np.int32,
                 suffix=self.gpu_id,
-                create=False)
+                create=False,
+            )

         step_shm_value.value[0] = -1
         layer_shm_value.value[0] = -1
@@ -193,21 +201,19 @@ class CacheMessager(object):
                 if cache_info:
                     logger.debug(f"cache info {cache_info}")
                     for info in cache_info:
-                        if info['request_id'] in self.cache_info:
+                        if info["request_id"] in self.cache_info:
                             self.cache_info[info["request_id"]].update(info)
                             current_info = self.cache_info[info["request_id"]]
                             if "dest_block_ids" in current_info and "src_block_ids" in current_info:
-                                current_src_blocks = current_info[
-                                    "src_block_ids"][-len(current_info["dest_block_ids"]):]
-                                current_info[
-                                    "src_block_ids"] = current_src_blocks
+                                current_src_blocks = current_info["src_block_ids"][
+                                    -len(current_info["dest_block_ids"]) :
+                                ]
+                                current_info["src_block_ids"] = current_src_blocks
                                 current_info["current_layer_ids"] = 0
                                 current_info["status"] = "init"
-                                logger.info(
-                                    f"start cache_infos: {current_info}")
+                                logger.info(f"start cache_infos: {current_info}")
                                 self.cache_info[info["request_id"]] = current_info
-                                self.last_step_idx = min(
-                                    self.last_step_idx, current_info['current_id'])
+                                self.last_step_idx = min(self.last_step_idx, current_info["current_id"])
                         else:
                             self.cache_info[info["request_id"]] = info
                 prefilled_layer_idx = layer_shm_value.value[0]
@@ -223,64 +229,53 @@ class CacheMessager(object):
                 if not self.cache_info:
                     time.sleep(0.001)
                     continue
-                logger.debug(
-                    f"prefilled_layer_idx: {prefilled_layer_idx}, prefilled_step_idx: {prefilled_step_idx}"
-                )
+                logger.debug(f"prefilled_layer_idx: {prefilled_layer_idx}, prefilled_step_idx: {prefilled_step_idx}")
                 for req_id, item in list(self.cache_info.items()):
                     if "status" not in item:
                         continue
                     if "layer_idx" not in item:
                         item["layer_idx"] = 0
-                    if item['status'] == 'error':
+                    if item["status"] == "error":
                         del self.cache_info[req_id]
                         continue
-                    if item['current_id'] > prefilled_step_idx:
+                    if item["current_id"] > prefilled_step_idx:
                         continue
                     current_transfer_protocol = item["transfer_protocol"]
                     if item["transfer_protocol"] == "rdma":
-                        target_ip = item['ip']
-                        target_id = int(item['rdma_ports'][self.rank])
-                        status = self.messager[
-                            current_transfer_protocol].connect(
-                                target_ip, target_id)
+                        target_ip = item["ip"]
+                        target_id = int(item["rdma_ports"][self.rank])
+                        status = self.messager[current_transfer_protocol].connect(target_ip, target_id)
                         if not status:
-                            logger.error(
-                                f"connect to {target_ip}:{target_id} failed")
+                            logger.error(f"connect to {target_ip}:{target_id} failed")
                             item["status"] = "error"
                             self.engine_worker_queue.finish_request_barrier.wait()
                             if self.rank == 0:
-                                self.engine_worker_queue.put_finished_req([
-                                    (item['request_id'], "connect error")
-                                ])
+                                self.engine_worker_queue.put_finished_req([(item["request_id"], "connect error")])
                             continue
                     elif item["transfer_protocol"] == "ipc":
                         target_ip = "0.0.0.0"
-                        target_id = int(item['device_ids'][self.rank])
-                    src_block_ids = paddle.to_tensor(item['src_block_ids'],
-                                                     dtype='int32',
-                                                     place='cpu')
-                    dest_block_ids = paddle.to_tensor(item['dest_block_ids'],
-                                                      dtype='int32',
-                                                      place='cpu')
-                    if item['current_id'] < prefilled_step_idx:
+                        target_id = int(item["device_ids"][self.rank])
+                    src_block_ids = paddle.to_tensor(item["src_block_ids"], dtype="int32", place="cpu")
+                    dest_block_ids = paddle.to_tensor(item["dest_block_ids"], dtype="int32", place="cpu")
+                    if item["current_id"] < prefilled_step_idx:
                         current_layer_idx = self.num_layers
                     else:
                         current_layer_idx = prefilled_layer_idx + 1

-                    for layer_idx in range(item["layer_idx"],
-                                           current_layer_idx):
+                    for layer_idx in range(item["layer_idx"], current_layer_idx):
                         tic = time.time()
-                        return_code = self.messager[
-                            current_transfer_protocol].write_cache(
-                                target_ip, target_id, src_block_ids,
-                                dest_block_ids, layer_idx)
+                        return_code = self.messager[current_transfer_protocol].write_cache(
+                            target_ip,
+                            target_id,
+                            src_block_ids,
+                            dest_block_ids,
+                            layer_idx,
+                        )
                         if return_code != 0:
                             item["status"] = "error"
                             self.engine_worker_queue.finish_request_barrier.wait()
                             if self.rank == 0:
-                                self.engine_worker_queue.put_finished_req([
-                                    (item['request_id'], "write cache error")
-                                ])
+                                self.engine_worker_queue.put_finished_req([(item["request_id"], "write cache error")])
                             logger.error(
                                 f"write cache failed, layer_idx: {layer_idx}, "
                                 f"req_id: {item['request_id']}, dest_ip: {target_ip}"
@@ -298,16 +293,14 @@ class CacheMessager(object):
                             f"block_num: {block_num}, send_cache_speed(GB/s): {round(send_cache_speed, 5)},"
                             f"avg_time per block(ms): {round(avg_time_per_block, 5)}"
                         )
-                    item['layer_idx'] = current_layer_idx
-                    if item['layer_idx'] == self.num_layers:
+                    item["layer_idx"] = current_layer_idx
+                    if item["layer_idx"] == self.num_layers:
                         if item["transfer_protocol"] == "ipc":
                             self.messager["ipc"].write_block_by_sync(target_id)
                         logger.info(f"finish write cache {item['request_id']}")
                         self.engine_worker_queue.finish_request_barrier.wait()
                         if self.rank == 0:
-                            self.engine_worker_queue.put_finished_req([
-                                (item['request_id'], "finished")
-                            ])
+                            self.engine_worker_queue.put_finished_req([(item["request_id"], "finished")])
                         logger.info(f"put write cache {item['request_id']}")
                         del self.cache_info[req_id]

@@ -315,5 +308,4 @@ class CacheMessager(object):
                 self.last_layer_idx = prefilled_layer_idx

         except Exception as e:
-            logger.error(
-                f"prefill layerwise send cache thread has exception: {e}")
+            logger.error(f"prefill layerwise send cache thread has exception: {e}")
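For reference, the constructor reformatted in the first hunk keeps its original parameters. A hypothetical call site is sketched below; every value, the import path, and the cache dict are illustrative assumptions, not code copied from FastDeploy's engine:

# Hypothetical wiring only: the argument names follow the __init__ signature in
# the first hunk; the values, shapes, and the assumed module path are not from
# the FastDeploy source shown in this diff.
import paddle
from fastdeploy.cache_manager.cache_messager import CacheMessager  # assumed module path

num_layers = 2
gpu_cache_kvs = {}
for layer_idx in range(num_layers):
    # __init__ looks caches up by this naming scheme: key/value, layer, rank, device.
    gpu_cache_kvs[f"key_caches_{layer_idx}_rank0_device0"] = paddle.zeros([64, 8, 64, 128], dtype="bfloat16")
    gpu_cache_kvs[f"value_caches_{layer_idx}_rank0_device0"] = paddle.zeros([64, 8, 64, 128], dtype="bfloat16")

messager = CacheMessager(
    splitwise_role="prefill",     # asserted to be "prefill" or "decode"
    transfer_protocol="ipc",      # comma-separated; split on "," inside __init__
    pod_ip="127.0.0.1",
    engine_worker_queue_port=6677,
    local_data_parallel_id=0,
    gpu_cache_kvs=gpu_cache_kvs,
    rank=0,
    nranks=1,
    num_layers=num_layers,
    gpu_id=0,
    rdma_port=None,
)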