mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-06 00:57:33 +08:00
[Feature] Support pd ep deployment with yiyan adapter (#4029)
* [Feature] Support mixed deployment with yiyan adapter in release2.2 * fix metrics * add unit test * add unit test * add unit test * Support pd ep deployment with yiyan adapter * Support pd ep deployment with yiyan adapter * refactor cache messager * support scheduler v1 in PD * suppport pd v1 + chunk prefill * suppport pd v1 + chunk prefill * add eplb * support eplb * support eplb * support eplb * support v1 * fix * fix * fix bug * remove eplb support * support prefix cache in P * fix bug * fix bug * support one stop in V1 * fix bug * fix ci * fix ci * fix * fix * fix * fix * fix --------- Co-authored-by: YuBaoku <49938469+EmmonsCurse@users.noreply.github.com>
This commit is contained in:
@@ -14,7 +14,10 @@
|
||||
# limitations under the License.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import math
|
||||
import queue
|
||||
import threading
|
||||
import time
|
||||
import traceback
|
||||
@@ -23,16 +26,72 @@ import numpy as np
|
||||
import paddle
|
||||
|
||||
from fastdeploy.cache_manager.transfer_factory import IPCCommManager, RDMACommManager
|
||||
from fastdeploy.config import SpeculativeConfig
|
||||
from fastdeploy.inter_communicator import (
|
||||
EngineWorkerQueue,
|
||||
IPCSignal,
|
||||
shared_memory_exists,
|
||||
)
|
||||
from fastdeploy.utils import get_logger
|
||||
from fastdeploy.model_executor.ops.gpu import get_output_kv_signal, set_data_ipc
|
||||
from fastdeploy.utils import envs, get_logger
|
||||
|
||||
logger = get_logger("cache_messager", "cache_messager.log")
|
||||
|
||||
|
||||
def parse_args():
|
||||
"""
|
||||
从命令行解析参数
|
||||
"""
|
||||
parser = argparse.ArgumentParser("Cache Messager")
|
||||
parser.add_argument(
|
||||
"--splitwise_role",
|
||||
type=str,
|
||||
default="mixed",
|
||||
help="splitwise role, can be decode, prefill or mixed",
|
||||
)
|
||||
parser.add_argument("--rank", type=int, default=0, help="current rank")
|
||||
parser.add_argument("--device_id", type=int, default=0, help="device id")
|
||||
parser.add_argument("--num_layers", type=int, default=1, help="model num layers")
|
||||
parser.add_argument("--head_dim", type=int, default=1, help="model head dim")
|
||||
parser.add_argument("--kv_num_head", type=int, default=1, help="model kv num head")
|
||||
parser.add_argument("--rdma_port", type=str, default="", help="rmda port")
|
||||
parser.add_argument("--mp_num", type=int, default=1, help="number of model parallel")
|
||||
parser.add_argument("--engine_pid", type=str, default=None, help="engine pid")
|
||||
parser.add_argument(
|
||||
"--protocol",
|
||||
type=str,
|
||||
default="ipc",
|
||||
help="cache transfer protocol, only surport ipc now",
|
||||
)
|
||||
parser.add_argument("--pod_ip", type=str, default="0.0.0.0", help="pod ip")
|
||||
parser.add_argument("--cache_queue_port", type=int, default=9924, help="cache queue port")
|
||||
parser.add_argument(
|
||||
"--engine_worker_queue_port",
|
||||
type=int,
|
||||
default=9923,
|
||||
help="engine worker queue port",
|
||||
)
|
||||
parser.add_argument("--num_gpu_blocks", type=int, default=1, help="gpu cache block number")
|
||||
parser.add_argument("--block_size", type=int, default=64, help="cache block size(tokens)")
|
||||
parser.add_argument(
|
||||
"--cache_dtype",
|
||||
type=str,
|
||||
default="bfloat16",
|
||||
choices=["uint8", "bfloat16"],
|
||||
help="cache dtype",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--speculative_config",
|
||||
type=json.loads,
|
||||
default="{}",
|
||||
help="speculative config",
|
||||
)
|
||||
parser.add_argument("--local_data_parallel_id", type=int, default=0)
|
||||
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
|
||||
class CacheMessager:
|
||||
"""
|
||||
CacheMessager is used to send the cache data between the engine worker and the cache server.
|
||||
@@ -69,11 +128,6 @@ class CacheMessager:
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
|
||||
assert splitwise_role in [
|
||||
"prefill",
|
||||
"decode",
|
||||
], "splitwise_role must be prefill or decode"
|
||||
self.splitwise_role = splitwise_role
|
||||
self.gpu_cache_kvs = gpu_cache_kvs
|
||||
self.rank = rank
|
||||
@@ -147,15 +201,16 @@ class CacheMessager:
|
||||
|
||||
self.gpu_id = gpu_id
|
||||
self.cache_info = dict()
|
||||
self.dp_rank_id = self.rank + local_data_parallel_id * self.nranks
|
||||
self.rank_id = self.rank + local_data_parallel_id * self.nranks
|
||||
|
||||
layerwise_send_cache_thread = threading.Thread(target=self._prefill_layerwise_send_cache_thread)
|
||||
layerwise_send_cache_thread.daemon = True
|
||||
layerwise_send_cache_thread.start()
|
||||
if self.splitwise_role != "mixed":
|
||||
connect_rdma_thread = threading.Thread(target=self._handle_connect_task)
|
||||
connect_rdma_thread.daemon = True
|
||||
connect_rdma_thread.start()
|
||||
|
||||
logger.info(f"cache messager init finished, use {transfer_protocol}")
|
||||
|
||||
def _prefill_layerwise_send_cache_thread(self):
|
||||
def prefill_layerwise_send_cache_thread(self):
|
||||
"""
|
||||
layerwise_send_cache_thread:
|
||||
send cache to other instance
|
||||
@@ -163,23 +218,23 @@ class CacheMessager:
|
||||
try:
|
||||
prefilled_step_idx_data = np.zeros(shape=[1], dtype=np.int32)
|
||||
prefilled_layer_idx_data = np.zeros(shape=[1], dtype=np.int32)
|
||||
prefilled_layer_name = f"splitwise_complete_prefilled_layer_{self.dp_rank_id}.{self.gpu_id}"
|
||||
prefilled_step_name = f"splitwise_complete_prefilled_step_{self.dp_rank_id}.{self.gpu_id}"
|
||||
prefilled_layer_name = f"splitwise_complete_prefilled_step_{self.rank_id}.{self.gpu_id}"
|
||||
prefilled_step_name = f"splitwise_complete_prefilled_step_{self.rank_id}.{self.gpu_id}"
|
||||
step_shm_value = IPCSignal(
|
||||
name=f"splitwise_complete_prefilled_step_{self.dp_rank_id}",
|
||||
name=f"splitwise_complete_prefilled_step_{self.rank_id}",
|
||||
array=prefilled_step_idx_data,
|
||||
dtype=np.int32,
|
||||
suffix=self.gpu_id,
|
||||
create=not shared_memory_exists(prefilled_step_name),
|
||||
)
|
||||
layer_shm_value = IPCSignal(
|
||||
name=f"splitwise_complete_prefilled_layer_{self.dp_rank_id}",
|
||||
name=f"splitwise_complete_prefilled_layer_{self.rank_id}",
|
||||
array=prefilled_layer_idx_data,
|
||||
dtype=np.int32,
|
||||
suffix=self.gpu_id,
|
||||
create=not shared_memory_exists(prefilled_layer_name),
|
||||
)
|
||||
logger.info(f"splitwise_complete_prefilled_step_{self.dp_rank_id}, gpu_id: {self.gpu_id}")
|
||||
logger.info(f"splitwise_complete_prefilled_step_{self.rank_id}, gpu_id: {self.gpu_id}")
|
||||
|
||||
step_shm_value.value[0] = -1
|
||||
layer_shm_value.value[0] = -1
|
||||
@@ -187,6 +242,9 @@ class CacheMessager:
|
||||
self.last_step_idx = -1
|
||||
self.last_layer_idx = -1 # int32
|
||||
|
||||
max_step_idx = 100003
|
||||
engine_recycled_count = 0
|
||||
|
||||
while True:
|
||||
|
||||
cache_info = self.engine_worker_queue.get_cache_info()
|
||||
@@ -202,11 +260,9 @@ class CacheMessager:
|
||||
-len(current_info["dest_block_ids"]) :
|
||||
]
|
||||
current_info["src_block_ids"] = current_src_blocks
|
||||
current_info["current_layer_ids"] = 0
|
||||
current_info["status"] = "init"
|
||||
logger.info(f"start cache_infos: {current_info}")
|
||||
self.cache_info[info["request_id"]] = current_info
|
||||
self.last_step_idx = min(self.last_step_idx, current_info["current_id"])
|
||||
else:
|
||||
self.cache_info[info["request_id"]] = info
|
||||
prefilled_layer_idx = layer_shm_value.value[0]
|
||||
@@ -223,7 +279,18 @@ class CacheMessager:
|
||||
if not self.cache_info:
|
||||
time.sleep(0.001)
|
||||
continue
|
||||
logger.debug(f"prefilled_layer_idx: {prefilled_layer_idx}, prefilled_step_idx: {prefilled_step_idx}")
|
||||
|
||||
if self.last_step_idx > prefilled_step_idx:
|
||||
engine_recycled_count += 1
|
||||
self.last_step_idx = prefilled_step_idx # only copy value read from shm memory
|
||||
prefilled_step_idx = (
|
||||
prefilled_step_idx + max_step_idx * engine_recycled_count
|
||||
) # remap prefilled_step_idx for comparison
|
||||
|
||||
logger.debug(
|
||||
f"prefilled_layer_idx: {prefilled_layer_idx}, prefilled_step_idx in shm: {self.last_step_idx},"
|
||||
f"prefilled_step_idx: {prefilled_step_idx} engine_recycled_count {engine_recycled_count}"
|
||||
)
|
||||
for req_id, item in list(self.cache_info.items()):
|
||||
if "status" not in item:
|
||||
continue
|
||||
@@ -294,12 +361,493 @@ class CacheMessager:
|
||||
logger.info(f"finish write cache {item['request_id']}")
|
||||
self.engine_worker_queue.finish_request_barrier.wait()
|
||||
if self.rank == 0:
|
||||
# to do: robust in TP: here we assume all status in tp are the same. If wrong, all wrong. If ok, all ok.
|
||||
self.engine_worker_queue.put_finished_req([(item["request_id"], "finished")])
|
||||
logger.info(f"put write cache {item['request_id']}")
|
||||
del self.cache_info[req_id]
|
||||
|
||||
self.last_step_idx = prefilled_step_idx
|
||||
self.last_layer_idx = prefilled_layer_idx
|
||||
self.last_layer_idx = prefilled_layer_idx
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"prefill layerwise send cache thread has exception: {e}, {str(traceback.format_exc())}")
|
||||
|
||||
def _handle_connect_task(self):
|
||||
while True:
|
||||
try:
|
||||
task = self.engine_worker_queue.get_connect_rdma_task()
|
||||
if task is None:
|
||||
time.sleep(0.001)
|
||||
continue
|
||||
logger.info(f"_handle_connect_task recv task: {task}")
|
||||
task_id = task["task_id"]
|
||||
ip, rdma_port = task["ip"], task["rdma_port"]
|
||||
status = self.messager["rdma"].connect(ip, rdma_port)
|
||||
if not status:
|
||||
response = {"task_id": task_id, "success": False}
|
||||
else:
|
||||
response = {"task_id": task_id, "success": True}
|
||||
self.engine_worker_queue.put_connect_rdma_task_response(response)
|
||||
except Exception as e:
|
||||
logger.error(f"handle_connect_task has exception: {e}")
|
||||
|
||||
|
||||
class CacheMessagerV1:
|
||||
"""
|
||||
CacheMessager is used to send the cache data between the engine worker and the cache server.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
splitwise_role,
|
||||
transfer_protocol,
|
||||
pod_ip,
|
||||
engine_worker_queue_port,
|
||||
local_data_parallel_id,
|
||||
gpu_cache_kvs,
|
||||
rank,
|
||||
nranks,
|
||||
num_layers,
|
||||
gpu_id=0,
|
||||
block_size=64,
|
||||
rdma_port=None,
|
||||
):
|
||||
"""
|
||||
Initialize the CacheMessager object.
|
||||
|
||||
Args:
|
||||
splitwise_role (str): splitwise_role only can be 'prefill' or 'decode'.
|
||||
transfer_protocol (str): support ipc and rdma
|
||||
engine_worker_queue_port (int): engine_worker_queue port
|
||||
gpu_cache_kvs (dict): GPU kv cache
|
||||
rank (int): current rank
|
||||
nranks (int): global rank number
|
||||
num_layers (int): model layer number
|
||||
gpu_id (int, optional): GPU ID
|
||||
rdma_port (int, optional): RDMA port
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
self.splitwise_role = splitwise_role
|
||||
self.gpu_cache_kvs = gpu_cache_kvs
|
||||
self.rank = rank
|
||||
self.nranks = nranks
|
||||
address = (pod_ip, engine_worker_queue_port)
|
||||
self.engine_worker_queue = EngineWorkerQueue(
|
||||
address=address,
|
||||
is_server=False,
|
||||
num_client=self.nranks,
|
||||
client_id=self.rank,
|
||||
local_data_parallel_id=local_data_parallel_id,
|
||||
)
|
||||
self.block_size = block_size
|
||||
transfer_protocol = transfer_protocol.split(",")
|
||||
|
||||
logger.info(f"splitwise role: {splitwise_role}, {transfer_protocol}" f"rank: {rank}")
|
||||
|
||||
# 1. initialize the cache_k_ptr_list and cache_v_ptr_list
|
||||
self.num_layers = num_layers
|
||||
cache_k_ptr_list = []
|
||||
cache_v_ptr_list = []
|
||||
cache_k = []
|
||||
cache_v = []
|
||||
self.messager = {}
|
||||
for layer_idx in range(self.num_layers):
|
||||
key_cache = self.gpu_cache_kvs[f"key_caches_{layer_idx}_rank{self.rank}_device{gpu_id}"]
|
||||
val_cache = self.gpu_cache_kvs[f"value_caches_{layer_idx}_rank{self.rank}_device{gpu_id}"]
|
||||
cache_k.append(key_cache)
|
||||
cache_v.append(val_cache)
|
||||
cache_k_ptr_list.append(key_cache.data_ptr())
|
||||
cache_v_ptr_list.append(val_cache.data_ptr())
|
||||
cache_k_ptr_list = np.array(cache_k_ptr_list)
|
||||
cache_v_ptr_list = np.array(cache_v_ptr_list)
|
||||
|
||||
# 2. initialize the block_bytes
|
||||
cache_shape = key_cache.shape
|
||||
max_block_num = cache_shape[0]
|
||||
block_bytes = math.prod(cache_shape[1:])
|
||||
if key_cache.dtype == paddle.bfloat16:
|
||||
block_bytes *= 2
|
||||
logger.info(
|
||||
f"layers {num_layers} cache_shape: {cache_shape}, max_block_num: {max_block_num}, "
|
||||
f"block_bytes: {block_bytes}, dtype: {key_cache.dtype}"
|
||||
)
|
||||
self.block_bytes = block_bytes
|
||||
|
||||
# 3. initialize the messager
|
||||
for protocol in transfer_protocol:
|
||||
if protocol == "ipc":
|
||||
self.messager[protocol] = IPCCommManager(
|
||||
self.rank,
|
||||
gpu_id,
|
||||
cache_k,
|
||||
cache_v,
|
||||
)
|
||||
local_device_id = int(str(cache_k[0].place)[-2])
|
||||
logger.info(f"done create ipc_comm with local_device_id:{local_device_id}, ")
|
||||
|
||||
elif protocol == "rdma":
|
||||
logger.info(f"splitwise_role rdma: {self.splitwise_role}, rank: {self.rank}, gpu_id: {gpu_id}")
|
||||
|
||||
self.messager[protocol] = RDMACommManager(
|
||||
splitwise_role,
|
||||
rank,
|
||||
gpu_id,
|
||||
cache_k_ptr_list,
|
||||
cache_v_ptr_list,
|
||||
max_block_num,
|
||||
block_bytes,
|
||||
rdma_port,
|
||||
)
|
||||
|
||||
self.gpu_id = gpu_id
|
||||
self.cache_info = dict()
|
||||
self.rank_id = self.rank + local_data_parallel_id * self.nranks
|
||||
self.engine_cache_task_thread_lock = threading.Lock()
|
||||
self.engine_cache_tasks = [dict() for _ in range(512)]
|
||||
self.idx_cache_task_dict = {}
|
||||
self.cache_prefilled_engine_ids_queue = queue.Queue() # keep batch slot index for each prefill step
|
||||
if splitwise_role == "prefill":
|
||||
consume_signals_thread = threading.Thread(target=self.consume_signals)
|
||||
consume_signals_thread.daemon = True
|
||||
consume_signals_thread.start()
|
||||
add_cache_task_thread = threading.Thread(target=self._add_cache_task_thread)
|
||||
add_cache_task_thread.daemon = True
|
||||
add_cache_task_thread.start()
|
||||
|
||||
if self.splitwise_role != "mixed":
|
||||
connect_rdma_thread = threading.Thread(target=self._handle_connect_task)
|
||||
connect_rdma_thread.daemon = True
|
||||
connect_rdma_thread.start()
|
||||
|
||||
logger.info(f"cache messager init finished, use {transfer_protocol}")
|
||||
|
||||
def _add_cache_task_thread(self):
|
||||
while True:
|
||||
try:
|
||||
cache_info = self.engine_worker_queue.get_cache_info()
|
||||
self.engine_worker_queue.finish_add_cache_task_barrier.wait()
|
||||
finished_add_cache_task_req_ids = []
|
||||
if cache_info:
|
||||
for info in cache_info:
|
||||
if info["request_id"] in self.cache_info:
|
||||
self.cache_info[info["request_id"]].update(info)
|
||||
current_info = self.cache_info[info["request_id"]]
|
||||
assert "dest_block_ids" in current_info and "src_block_ids" in current_info
|
||||
finished_add_cache_task_req_ids.append(info["request_id"])
|
||||
decode_cached_block_num = len(current_info["src_block_ids"]) - len(
|
||||
current_info["dest_block_ids"]
|
||||
)
|
||||
padding_decode_block_ids = [-1 for i in range(decode_cached_block_num)] + current_info[
|
||||
"dest_block_ids"
|
||||
]
|
||||
current_info["dest_block_ids"] = padding_decode_block_ids
|
||||
current_info["decode_cached_tokens"] = decode_cached_block_num * self.block_size
|
||||
current_info["sended_layer_id"] = -1
|
||||
current_info["sended_block_num"] = current_info["decode_cached_tokens"] // self.block_size
|
||||
current_info["status"] = "init"
|
||||
logger.info(f"finish add cache task: {current_info}")
|
||||
self.cache_info[info["request_id"]] = current_info
|
||||
self.idx_cache_task_dict[current_info["current_id"]] = current_info
|
||||
else:
|
||||
self.cache_info[info["request_id"]] = info
|
||||
if self.rank == 0 and finished_add_cache_task_req_ids:
|
||||
self.engine_worker_queue.put_finished_add_cache_task_req(finished_add_cache_task_req_ids)
|
||||
else:
|
||||
time.sleep(0.001)
|
||||
except Exception as e:
|
||||
logger.info(f"add cache task occured error: {e}, {traceback.format_exc()!s}.")
|
||||
|
||||
def prefill_layerwise_send_cache_thread(self):
|
||||
"""
|
||||
layerwise_send_cache_thread:
|
||||
send cache to other instance
|
||||
"""
|
||||
while True:
|
||||
try:
|
||||
engine_indexes = self.cache_prefilled_engine_ids_queue.get()
|
||||
self.engine_worker_queue.finish_request_barrier.wait()
|
||||
block_start_end_list = []
|
||||
current_prefilled_token_num_list = []
|
||||
for engine_index in engine_indexes:
|
||||
assert engine_index in self.idx_cache_task_dict
|
||||
block_id_start = self.idx_cache_task_dict[engine_index]["sended_block_num"]
|
||||
prefilled_token_num = self.engine_cache_tasks[engine_index]["prefilled_token_num"]
|
||||
if (
|
||||
prefilled_token_num == self.idx_cache_task_dict[engine_index]["need_prefill_tokens"]
|
||||
): # all chunks have been prefilled
|
||||
block_id_end = len(self.idx_cache_task_dict[engine_index]["src_block_ids"])
|
||||
else:
|
||||
block_id_end = prefilled_token_num // self.block_size # [block_id_start, block_id_end)
|
||||
block_start_end_list.append((block_id_start, block_id_end))
|
||||
current_prefilled_token_num_list.append(prefilled_token_num)
|
||||
while True: # from layer0 to last layer
|
||||
sended_layer_idx = self.idx_cache_task_dict[engine_indexes[0]]["sended_layer_id"]
|
||||
start_layer_idx = sended_layer_idx + 1
|
||||
with self.engine_cache_task_thread_lock: # to check end_layer_idx
|
||||
prefilled_layer_idx = self.engine_cache_tasks[engine_indexes[0]]["prefilled_layer_idx"]
|
||||
if sended_layer_idx > prefilled_layer_idx: # computation must in next chunk
|
||||
logger.info(
|
||||
f"current_prefilled_token_num_list[0] {current_prefilled_token_num_list[0]} prefilled_token_num {self.engine_cache_tasks[engine_indexes[0]]['prefilled_token_num']}"
|
||||
)
|
||||
assert (
|
||||
current_prefilled_token_num_list[0]
|
||||
< self.engine_cache_tasks[engine_indexes[0]]["prefilled_token_num"]
|
||||
), "when sended_layer_idx > prefilled_layer_idx, must be in next chunk, but not, sth wrong"
|
||||
end_layer_idx = self.num_layers - 1 # [start_layer_idx, end_layer_idx)
|
||||
else:
|
||||
end_layer_idx = prefilled_layer_idx
|
||||
if sended_layer_idx == prefilled_layer_idx: # computation not in next layer
|
||||
time.sleep(0.01)
|
||||
for layer_idx in range(start_layer_idx, end_layer_idx + 1):
|
||||
for i, (block_id_start, block_id_end) in enumerate(block_start_end_list):
|
||||
engine_index = engine_indexes[i]
|
||||
task = self.idx_cache_task_dict[engine_index]
|
||||
req_id = task["request_id"]
|
||||
if (
|
||||
block_id_start >= block_id_end
|
||||
): # no blocks need to transfer for this request in this chunk
|
||||
task["sended_layer_id"] += 1
|
||||
assert task["sended_layer_id"] == layer_idx
|
||||
if task["sended_layer_id"] == self.num_layers - 1:
|
||||
task["sended_layer_id"] = -1
|
||||
continue
|
||||
else:
|
||||
current_transfer_protocol = task["transfer_protocol"]
|
||||
if task["transfer_protocol"] == "rdma":
|
||||
target_ip = task["ip"]
|
||||
target_id = int(task["rdma_ports"][self.rank])
|
||||
if task["status"] == "error":
|
||||
continue
|
||||
status = self.messager[current_transfer_protocol].connect(target_ip, target_id)
|
||||
if not status:
|
||||
logger.error(f"connect to {target_ip}:{target_id} failed")
|
||||
task["status"] = "connection error"
|
||||
continue
|
||||
elif task["transfer_protocol"] == "ipc":
|
||||
target_ip = "0.0.0.0"
|
||||
target_id = int(task["device_ids"][self.rank])
|
||||
|
||||
src_block_ids = task["src_block_ids"][block_id_start:block_id_end]
|
||||
dest_block_ids = task["dest_block_ids"][block_id_start:block_id_end]
|
||||
src_block_ids = paddle.to_tensor(src_block_ids, dtype="int32", place="cpu")
|
||||
dest_block_ids = paddle.to_tensor(dest_block_ids, dtype="int32", place="cpu")
|
||||
|
||||
logger.info(
|
||||
f"start write cache for a layer, {req_id}, {layer_idx}, {target_ip}, {target_id}, block_id_start {block_id_start} block_id_end {block_id_end}"
|
||||
)
|
||||
tic = time.time()
|
||||
return_code = self.messager[current_transfer_protocol].write_cache(
|
||||
target_ip,
|
||||
target_id,
|
||||
src_block_ids,
|
||||
dest_block_ids,
|
||||
layer_idx,
|
||||
)
|
||||
if return_code != 0:
|
||||
task["status"] = "write cache error"
|
||||
logger.error(
|
||||
f"write cache failed, layer_idx: {layer_idx}, req_id: {req_id}, dest_ip: {target_ip}, block_id_start {block_id_start} block_id_end {block_id_end}"
|
||||
)
|
||||
tok = time.time()
|
||||
cost_time = tok - tic
|
||||
block_num = len(src_block_ids)
|
||||
avg_time_per_block = cost_time * 1000 / block_num # ms
|
||||
send_cache_speed = block_num * self.block_bytes / 1073741824 / cost_time # GB/s
|
||||
logger.debug(
|
||||
f"finish write cache for a layer, {req_id}, {layer_idx}, {target_ip}, {target_id},"
|
||||
f"block_num: {block_num}, send_cache_speed(GB/s): {round(send_cache_speed, 5)},"
|
||||
f"avg_time per block(ms): {round(avg_time_per_block, 5)} block_id_start {block_id_start} block_id_end {block_id_end}"
|
||||
)
|
||||
|
||||
task["sended_layer_id"] += 1
|
||||
assert task["sended_layer_id"] == layer_idx
|
||||
if task["sended_layer_id"] == self.num_layers - 1:
|
||||
self.idx_cache_task_dict[engine_index]["sended_block_num"] += (
|
||||
block_id_end - block_id_start
|
||||
)
|
||||
if current_prefilled_token_num_list[i] == task["need_prefill_tokens"]:
|
||||
if task["status"] != "error":
|
||||
task["status"] = "finished"
|
||||
logger.info(
|
||||
f"finish write cache for all layers, req_id: {req_id}, block_id_end {block_id_end} need_prefill_tokens {task['need_prefill_tokens']}"
|
||||
)
|
||||
else:
|
||||
task["sended_layer_id"] = -1
|
||||
if end_layer_idx == self.num_layers - 1:
|
||||
with self.engine_cache_task_thread_lock:
|
||||
for engine_idx in engine_indexes:
|
||||
task = self.idx_cache_task_dict[engine_idx]
|
||||
if task["status"] == "finished" or ("error" in task["status"]):
|
||||
target_id = int(task["rdma_ports"][self.rank])
|
||||
if task["transfer_protocol"] == "ipc":
|
||||
self.messager["ipc"].write_block_by_sync(target_id)
|
||||
if self.rank == 0:
|
||||
# to do: robust in TP, here we assume all status in tp are the same. If wrong, all wrong. If ok, all ok.
|
||||
self.engine_worker_queue.put_finished_req(
|
||||
[(task["request_id"], task["status"])]
|
||||
)
|
||||
logger.info(f"put write cache {task['request_id']}, status {task['status']}")
|
||||
self.engine_cache_tasks[task["current_id"]] = dict()
|
||||
del self.cache_info[task["request_id"]]
|
||||
del self.idx_cache_task_dict[task["current_id"]]
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"prefill layerwise send cache thread has exception: {e} {traceback.format_exc()!s}")
|
||||
time.sleep(0.01)
|
||||
|
||||
def consume_signals(self):
|
||||
paddle.device.set_device("cpu")
|
||||
kv_signal_data = paddle.full(shape=[512 * 3 + 2], fill_value=-1, dtype="int32")
|
||||
while True:
|
||||
try:
|
||||
get_output_kv_signal(kv_signal_data, self.rank_id, 0) # wait_flag
|
||||
if not self.cache_info:
|
||||
time.sleep(0.01)
|
||||
continue
|
||||
tasks_count = kv_signal_data[0]
|
||||
if tasks_count == -1:
|
||||
time.sleep(0.001)
|
||||
continue
|
||||
layer_id = kv_signal_data[1].numpy().tolist()
|
||||
if layer_id == self.num_layers - 1:
|
||||
logger.info(f"tasks_count: {tasks_count}, layer_id: {layer_id}")
|
||||
batch_engine_ids = []
|
||||
with self.engine_cache_task_thread_lock:
|
||||
for bi in range(tasks_count):
|
||||
engine_idx = kv_signal_data[3 * bi + 2].numpy().tolist()
|
||||
chuck_token_offset = kv_signal_data[3 * bi + 3].numpy().tolist()
|
||||
current_seq_len = kv_signal_data[3 * bi + 4].numpy().tolist()
|
||||
self.engine_cache_tasks[engine_idx]["prefilled_layer_idx"] = layer_id
|
||||
self.engine_cache_tasks[engine_idx]["prefilled_token_num"] = (
|
||||
chuck_token_offset + current_seq_len
|
||||
)
|
||||
batch_engine_ids.append(engine_idx)
|
||||
if layer_id == 0:
|
||||
self.cache_prefilled_engine_ids_queue.put(batch_engine_ids)
|
||||
except Exception as e:
|
||||
logger.error(f"Consume signals get exception: {e}")
|
||||
|
||||
def _handle_connect_task(self):
|
||||
while True:
|
||||
try:
|
||||
task = self.engine_worker_queue.get_connect_rdma_task()
|
||||
if task is None:
|
||||
time.sleep(0.001)
|
||||
continue
|
||||
logger.info(f"_handle_connect_task recv task: {task}")
|
||||
task_id = task["task_id"]
|
||||
ip, rdma_port = task["ip"], task["rdma_port"]
|
||||
status = self.messager["rdma"].connect(ip, rdma_port)
|
||||
if not status:
|
||||
response = {"task_id": task_id, "success": False}
|
||||
else:
|
||||
response = {"task_id": task_id, "success": True}
|
||||
self.engine_worker_queue.put_connect_rdma_task_response(response)
|
||||
except Exception as e:
|
||||
logger.error(f"handle_connect_task has exception: {e}")
|
||||
|
||||
|
||||
def main():
|
||||
device = args.device_id
|
||||
rank = args.rank
|
||||
paddle.set_device(f"gpu:{device}")
|
||||
cache_type = args.cache_dtype
|
||||
speculative_config = SpeculativeConfig(args.speculative_config)
|
||||
num_extra_layers = speculative_config.num_extra_cache_layer
|
||||
num_extra_layer_gpu_blocks = int(args.num_gpu_blocks * speculative_config.num_gpu_block_expand_ratio)
|
||||
gpu_cache_kvs = {}
|
||||
gpu_cache_k_tensors = []
|
||||
gpu_cache_v_tensors = []
|
||||
|
||||
for i in range(args.num_layers + num_extra_layers):
|
||||
num_gpu_blocks = args.num_gpu_blocks if i < args.num_layers else num_extra_layer_gpu_blocks
|
||||
|
||||
gpu_cache_kvs[f"key_caches_{i}_rank{rank}_device{device}"] = paddle.full(
|
||||
shape=[
|
||||
num_gpu_blocks,
|
||||
args.kv_num_head,
|
||||
args.block_size,
|
||||
args.head_dim,
|
||||
],
|
||||
fill_value=0,
|
||||
dtype=cache_type,
|
||||
)
|
||||
gpu_cache_k_tensors.append(gpu_cache_kvs[f"key_caches_{i}_rank{rank}_device{device}"])
|
||||
gpu_cache_kvs[f"value_caches_{i}_rank{rank}_device{device}"] = paddle.full(
|
||||
shape=[
|
||||
num_gpu_blocks,
|
||||
args.kv_num_head,
|
||||
args.block_size,
|
||||
args.head_dim,
|
||||
],
|
||||
fill_value=0,
|
||||
dtype=cache_type,
|
||||
)
|
||||
gpu_cache_v_tensors.append(gpu_cache_kvs[f"value_caches_{i}_rank{rank}_device{device}"])
|
||||
|
||||
set_data_ipc(
|
||||
gpu_cache_kvs[f"key_caches_{i}_rank{rank}_device{device}"],
|
||||
f"key_caches_{i}_rank{rank}.device{device}",
|
||||
)
|
||||
set_data_ipc(
|
||||
gpu_cache_kvs[f"value_caches_{i}_rank{rank}_device{device}"],
|
||||
f"value_caches_{i}_rank{rank}.device{device}",
|
||||
)
|
||||
cache_kv_size_byte = sum([tmp.numel() * 1 for key, tmp in gpu_cache_kvs.items()])
|
||||
logger.info(f"device :{device}")
|
||||
logger.info(f"cache_kv_size_byte : {cache_kv_size_byte}")
|
||||
logger.info(f"done init cache (full) gmem alloc : {paddle.device.cuda.memory_allocated()}")
|
||||
|
||||
if envs.ENABLE_V1_KVCACHE_SCHEDULER:
|
||||
cache_messager = CacheMessagerV1(
|
||||
splitwise_role=args.splitwise_role,
|
||||
transfer_protocol=args.protocol,
|
||||
pod_ip=args.pod_ip,
|
||||
engine_worker_queue_port=args.engine_worker_queue_port,
|
||||
local_data_parallel_id=args.local_data_parallel_id,
|
||||
gpu_cache_kvs=gpu_cache_kvs,
|
||||
rank=rank,
|
||||
nranks=args.mp_num,
|
||||
num_layers=args.num_layers + num_extra_layers,
|
||||
gpu_id=device,
|
||||
rdma_port=args.rdma_port,
|
||||
)
|
||||
else:
|
||||
cache_messager = CacheMessager(
|
||||
splitwise_role=args.splitwise_role,
|
||||
transfer_protocol=args.protocol,
|
||||
pod_ip=args.pod_ip,
|
||||
engine_worker_queue_port=args.engine_worker_queue_port,
|
||||
local_data_parallel_id=args.local_data_parallel_id,
|
||||
gpu_cache_kvs=gpu_cache_kvs,
|
||||
rank=rank,
|
||||
nranks=args.mp_num,
|
||||
num_layers=args.num_layers + num_extra_layers,
|
||||
gpu_id=device,
|
||||
rdma_port=args.rdma_port,
|
||||
)
|
||||
|
||||
cache_ready_signal_data = np.zeros(shape=[args.mp_num], dtype=np.int32)
|
||||
cache_ready_signal = IPCSignal(
|
||||
name="cache_ready_signal",
|
||||
array=cache_ready_signal_data,
|
||||
dtype=np.int32,
|
||||
suffix=args.engine_pid,
|
||||
create=False,
|
||||
)
|
||||
cache_ready_signal.value[rank] = 1
|
||||
if args.splitwise_role == "mixed":
|
||||
while True:
|
||||
time.sleep(1)
|
||||
cache_messager.prefill_layerwise_send_cache_thread()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
args = parse_args()
|
||||
rank_id = args.rank + args.local_data_parallel_id * args.mp_num
|
||||
logger = get_logger("cache_messager", f"cache_messager_rank{rank_id}.log")
|
||||
|
||||
logger.info("create cache messager...")
|
||||
logger.info(f"{args}")
|
||||
main()
|
||||
|
Reference in New Issue
Block a user