[Feature] Optimize prefix cache (#3208)

* [LLM] support ep

* Update worker_process.py

* Update expert_service.py

* Update worker_process.py

* format files

* optimize prefix cache

* optimize prefix cache

* optimize prefix cache

* pre commit format

* pre commit format

* pre commit format

* Update cache_messager.py
Author: ltd0924
Date: 2025-08-05 17:13:11 +08:00 (committed by GitHub)
Parent: 9f9971844f
Commit: dcf9c2daff
7 changed files with 314 additions and 147 deletions

File 1/7: cache_messager.py

@@ -14,18 +14,72 @@
 # limitations under the License.
 """
+import argparse
+import json
 import math
+import threading
 import time
-import threading

 import numpy as np
 import paddle

 from fastdeploy.cache_manager.transfer_factory import IPCCommManager, RDMACommManager
+from fastdeploy.config import SpeculativeConfig
 from fastdeploy.inter_communicator import EngineWorkerQueue, IPCSignal
+from fastdeploy.model_executor.ops.gpu import set_data_ipc
 from fastdeploy.utils import get_logger

 logger = get_logger("cache_messager", "cache_messager.log")

+
+def parse_args():
+    """
+    从命令行解析参数
+    """
+    parser = argparse.ArgumentParser("Cache Messager")
+    parser.add_argument(
+        "--splitwise_role",
+        type=str,
+        default="mixed",
+        help="splitwise role, can be decode, prefill or mixed",
+    )
+    parser.add_argument("--rank", type=int, default=0, help="current rank")
+    parser.add_argument("--device_id", type=int, default=0, help="device id")
+    parser.add_argument("--num_hidden_layers", type=int, default=1, help="model num layers")
+    parser.add_argument("--head_dim", type=int, default=1, help="model head dim")
+    parser.add_argument("--kv_num_head", type=int, default=1, help="model kv num head")
+    parser.add_argument("--rdma_port", type=str, default="", help="rmda port")
+    parser.add_argument("--mp_num", type=int, default=1, help="number of model parallel")
+    parser.add_argument("--engine_pid", type=str, default=None, help="engine pid")
+    parser.add_argument(
+        "--protocol",
+        type=str,
+        default="ipc",
+        help="cache transfer protocol, only surport ipc now",
+    )
+    parser.add_argument("--pod_ip", type=str, default="0.0.0.0", help="pod ip")
+    parser.add_argument(
+        "--engine_worker_queue_port",
+        type=int,
+        default=9923,
+        help="engine worker queue port",
+    )
+    parser.add_argument("--num_gpu_blocks", type=int, default=1, help="gpu cache block number")
+    parser.add_argument("--block_size", type=int, default=64, help="cache block size(tokens)")
+    parser.add_argument(
+        "--cache_dtype",
+        type=str,
+        default="bfloat16",
+        choices=["uint8", "bfloat16"],
+        help="cache dtype",
+    )
+    parser.add_argument(
+        "--speculative_config",
+        type=json.loads,
+        default="{}",
+        help="speculative config",
+    )
+    parser.add_argument("--local_data_parallel_id", type=int, default=0)
+
+    args = parser.parse_args()
+    return args
+
+
 class CacheMessager:

@@ -43,7 +97,7 @@ class CacheMessager:
         gpu_cache_kvs,
         rank,
         nranks,
-        num_layers,
+        num_hidden_layers,
         gpu_id=0,
         rdma_port=None,
     ):

@@ -57,7 +111,7 @@ class CacheMessager:
             gpu_cache_kvs (dict): GPU kv cache
             rank (int): current rank
             nranks (int): global rank number
-            num_layers (int): model layer number
+            num_hidden_layers (int): model layer number
             gpu_id (int, optional): GPU ID
             rdma_port (int, optional): RDMA port

@@ -86,13 +140,13 @@ class CacheMessager:
         logger.info(f"splitwise role: {splitwise_role}, {transfer_protocol}" f"rank: {rank}")

         # 1. initialize the cache_k_ptr_list and cache_v_ptr_list
-        self.num_layers = num_layers
+        self.num_hidden_layers = num_hidden_layers
         cache_k_ptr_list = []
         cache_v_ptr_list = []
         cache_k = []
         cache_v = []
         self.messager = {}
-        for layer_idx in range(self.num_layers):
+        for layer_idx in range(self.num_hidden_layers):
             key_cache = self.gpu_cache_kvs[f"key_caches_{layer_idx}_rank{self.rank}_device{gpu_id}"]
             val_cache = self.gpu_cache_kvs[f"value_caches_{layer_idx}_rank{self.rank}_device{gpu_id}"]
             cache_k.append(key_cache)

@@ -109,7 +163,7 @@ class CacheMessager:
             if key_cache.dtype == paddle.bfloat16:
                 block_bytes *= 2
             logger.info(
-                f"layers {num_layers} cache_shape: {cache_shape}, max_block_num: {max_block_num}, "
+                f"layers {num_hidden_layers} cache_shape: {cache_shape}, max_block_num: {max_block_num}, "
                 f"block_bytes: {block_bytes}, dtype: {key_cache.dtype}"
             )
             self.block_bytes = block_bytes

@@ -144,17 +198,13 @@ class CacheMessager:
         self.cache_info = dict()
         self.rank_id = self.rank + local_data_parallel_id * self.nranks  # align with engine worker rank (paddle.distributed.launch)

-        layerwise_send_cache_thread = threading.Thread(target=self._prefill_layerwise_send_cache_thread)
-        layerwise_send_cache_thread.daemon = True
-        layerwise_send_cache_thread.start()
-
         connect_rdma_thread = threading.Thread(target=self._handle_connect_task)
         connect_rdma_thread.daemon = True
         connect_rdma_thread.start()

         logger.info(f"cache messager init finished, use {transfer_protocol}")

-    def _prefill_layerwise_send_cache_thread(self):
+    def prefill_layerwise_send_cache_thread(self):
         """
         layerwise_send_cache_thread:
         send cache to other instance

@@ -204,7 +254,7 @@ class CacheMessager:
                 cache_info = self.engine_worker_queue.get_cache_info()

                 if cache_info:
-                    logger.debug(f"cache info {cache_info}")
+                    logger.info(f"cache info {cache_info}")
                     for info in cache_info:
                         if info["request_id"] in self.cache_info:
                             self.cache_info[info["request_id"]].update(info)

@@ -223,7 +273,7 @@ class CacheMessager:
                             self.cache_info[info["request_id"]] = info
                 prefilled_layer_idx = layer_shm_value.value[0]
                 prefilled_step_idx = step_shm_value.value[0]
-                if prefilled_layer_idx == self.num_layers - 1:
+                if prefilled_layer_idx == self.num_hidden_layers - 1:
                     time.sleep(0.001)
                     prefilled_layer_idx = layer_shm_value.value[0]
                     prefilled_step_idx = step_shm_value.value[0]

@@ -234,7 +284,7 @@ class CacheMessager:
                 if not self.cache_info:
                     time.sleep(0.001)
                     continue
-                logger.debug(f"prefilled_layer_idx: {prefilled_layer_idx}, prefilled_step_idx: {prefilled_step_idx}")
+                logger.info(f"prefilled_layer_idx: {prefilled_layer_idx}, prefilled_step_idx: {prefilled_step_idx}")
                 for req_id, item in list(self.cache_info.items()):
                     if "status" not in item:
                         continue

@@ -251,7 +301,7 @@ class CacheMessager:
                         target_id = int(item["rdma_ports"][self.rank])
                         status = self.messager[current_transfer_protocol].connect(target_ip, target_id)
                         if not status:
-                            logger.error(f"connect to {target_ip}:{target_id} failed")
+                            logger.info(f"connect to {target_ip}:{target_id} failed")
                             item["status"] = "error"
                             self.engine_worker_queue.finish_request_barrier.wait()
                             if self.rank == 0:

@@ -263,7 +313,7 @@ class CacheMessager:
                     src_block_ids = paddle.to_tensor(item["src_block_ids"], dtype="int32", place="cpu")
                     dest_block_ids = paddle.to_tensor(item["dest_block_ids"], dtype="int32", place="cpu")
                     if item["current_id"] < prefilled_step_idx:
-                        current_layer_idx = self.num_layers
+                        current_layer_idx = self.num_hidden_layers
                     else:
                         current_layer_idx = prefilled_layer_idx + 1

@@ -281,7 +331,7 @@ class CacheMessager:
                             self.engine_worker_queue.finish_request_barrier.wait()
                             if self.rank == 0:
                                 self.engine_worker_queue.put_finished_req([(item["request_id"], "write cache error")])
-                            logger.error(
+                            logger.info(
                                 f"write cache failed, layer_idx: {layer_idx}, "
                                 f"req_id: {item['request_id']}, dest_ip: {target_ip}"
                             )

@@ -292,14 +342,14 @@ class CacheMessager:
                         block_num = len(src_block_ids)
                         avg_time_per_block = cost_time * 1000 / block_num  # ms
                         send_cache_speed = block_num * self.block_bytes / 1073741824 / cost_time  # GB/s
-                        logger.debug(
+                        logger.info(
                             f"finish write cache for a layer, {item['request_id']}, {layer_idx}"
                             f" {current_transfer_protocol}"
                             f"block_num: {block_num}, send_cache_speed(GB/s): {round(send_cache_speed, 5)},"
                             f"avg_time per block(ms): {round(avg_time_per_block, 5)}"
                         )
                     item["layer_idx"] = current_layer_idx
-                    if item["layer_idx"] == self.num_layers:
+                    if item["layer_idx"] == self.num_hidden_layers:
                         if item["transfer_protocol"] == "ipc":
                             self.messager["ipc"].write_block_by_sync(target_id)
                             logger.info(f"finish write cache {item['request_id']}")

@@ -313,7 +363,7 @@ class CacheMessager:
                     self.last_layer_idx = prefilled_layer_idx

             except Exception as e:
-                logger.error(f"prefill layerwise send cache thread has exception: {e}")
+                logger.info(f"prefill layerwise send cache thread has exception: {e}")

     def _handle_connect_task(self):
         while True:

@@ -333,3 +383,90 @@ class CacheMessager:
                 self.engine_worker_queue.put_connect_rdma_task_response(response)
             except Exception as e:
                 logger.error(f"handle_connect_task has exception: {e}")
+
+
+def main():
+    device = args.device_id
+    rank = args.rank
+    paddle.set_device(f"gpu:{device}")
+    cache_type = args.cache_dtype
+    speculative_config = SpeculativeConfig(args.speculative_config)
+    num_extra_layers = speculative_config.num_extra_cache_layer
+    num_extra_layer_gpu_blocks = int(args.num_gpu_blocks * speculative_config.num_gpu_block_expand_ratio)
+    gpu_cache_kvs = {}
+    gpu_cache_k_tensors = []
+    gpu_cache_v_tensors = []
+
+    for i in range(args.num_hidden_layers + num_extra_layers):
+        num_gpu_blocks = args.num_gpu_blocks if i < args.num_hidden_layers else num_extra_layer_gpu_blocks
+
+        gpu_cache_kvs[f"key_caches_{i}_rank{rank}_device{device}"] = paddle.full(
+            shape=[
+                num_gpu_blocks,
+                args.kv_num_head,
+                args.block_size,
+                args.head_dim,
+            ],
+            fill_value=0,
+            dtype=cache_type,
+        )
+        gpu_cache_k_tensors.append(gpu_cache_kvs[f"key_caches_{i}_rank{rank}_device{device}"])
+        gpu_cache_kvs[f"value_caches_{i}_rank{rank}_device{device}"] = paddle.full(
+            shape=[
+                num_gpu_blocks,
+                args.kv_num_head,
+                args.block_size,
+                args.head_dim,
+            ],
+            fill_value=0,
+            dtype=cache_type,
+        )
+        gpu_cache_v_tensors.append(gpu_cache_kvs[f"value_caches_{i}_rank{rank}_device{device}"])
+
+        set_data_ipc(
+            gpu_cache_kvs[f"key_caches_{i}_rank{rank}_device{device}"],
+            f"key_caches_{i}_rank{rank}.device{device}",
+        )
+        set_data_ipc(
+            gpu_cache_kvs[f"value_caches_{i}_rank{rank}_device{device}"],
+            f"value_caches_{i}_rank{rank}.device{device}",
+        )
+    cache_kv_size_byte = sum([tmp.numel() * 1 for key, tmp in gpu_cache_kvs.items()])
+    logger.info(f"device :{device}")
+    logger.info(f"cache_kv_size_byte : {cache_kv_size_byte}")
+    logger.info(f"done init cache (full) gmem alloc : {paddle.device.cuda.memory_allocated()}")
+
+    cache_messager = CacheMessager(
+        splitwise_role=args.splitwise_role,
+        transfer_protocol=args.protocol,
+        pod_ip=args.pod_ip,
+        engine_worker_queue_port=args.engine_worker_queue_port,
+        local_data_parallel_id=args.local_data_parallel_id,
+        gpu_cache_kvs=gpu_cache_kvs,
+        rank=rank,
+        nranks=args.mp_num,
+        num_hidden_layers=args.num_hidden_layers + num_extra_layers,
+        gpu_id=device,
+        rdma_port=args.rdma_port,
+    )
+
+    cache_ready_signal_data = np.zeros(shape=[args.mp_num], dtype=np.int32)
+    cache_ready_signal = IPCSignal(
+        name="cache_ready_signal",
+        array=cache_ready_signal_data,
+        dtype=np.int32,
+        suffix=args.engine_pid,
+        create=False,
+    )
+    cache_ready_signal.value[rank] = 1
+    cache_messager.prefill_layerwise_send_cache_thread()
+
+
+if __name__ == "__main__":
+    args = parse_args()
+    logger = get_logger("cache_messager", "cache_messager.log")
+
+    logger.info("create cache messager...")
+    logger.info(f"{args}")
+    main()
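The standalone entrypoint above allocates the per-rank KV blocks, exports each tensor over CUDA IPC, flips its slot in `cache_ready_signal`, and then blocks in `prefill_layerwise_send_cache_thread`. Every export and attach in this PR hinges on one deterministic shared-memory name per layer and rank. A minimal sketch of that naming contract follows; the helper `ipc_cache_names` is illustrative, not part of the diff:

```python
# Sketch only: mirrors the f-strings passed to set_data_ipc()/share_external_data()
# in this commit. The helper itself is hypothetical.
def ipc_cache_names(layer_idx: int, rank: int, device: int) -> tuple:
    """Return the (key, value) IPC names one rank exports for one layer."""
    return (
        f"key_caches_{layer_idx}_rank{rank}.device{device}",
        f"value_caches_{layer_idx}_rank{rank}.device{device}",
    )


# Example: layer 5 on rank 0, GPU 2
assert ipc_cache_names(5, 0, 2)[0] == "key_caches_5_rank0.device2"
```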

File 2/7: cache_transfer_manager.py

@@ -28,7 +28,7 @@ from fastdeploy.config import SpeculativeConfig
 from fastdeploy.inter_communicator import EngineCacheQueue, IPCSignal
 from fastdeploy.model_executor.ops.gpu import (
     cuda_host_alloc,
-    set_data_ipc,
+    share_external_data,
     swap_cache_all_layers,
 )
 from fastdeploy.utils import get_logger

@@ -39,26 +39,12 @@ def parse_args():
     从命令行解析参数
     """
     parser = argparse.ArgumentParser("Cache transfer manager")
-    parser.add_argument(
-        "--splitwise_role",
-        type=str,
-        default="mixed",
-        help="splitwise role, can be decode, prefill or mixed",
-    )
     parser.add_argument("--rank", type=int, default=0, help="current rank")
     parser.add_argument("--device_id", type=int, default=0, help="device id")
-    parser.add_argument("--num_layers", type=int, default=1, help="model num layers")
+    parser.add_argument("--num_hidden_layers", type=int, default=1, help="model num layers")
     parser.add_argument("--head_dim", type=int, default=1, help="model head dim")
     parser.add_argument("--kv_num_head", type=int, default=1, help="model kv num head")
-    parser.add_argument("--rdma_port", type=str, default="", help="rmda port")
     parser.add_argument("--mp_num", type=int, default=1, help="number of model parallel")
-    parser.add_argument(
-        "--protocol",
-        type=str,
-        default="ipc",
-        help="cache transfer protocol, only surport ipc now",
-    )
-    parser.add_argument("--enable_splitwise", type=int, default=0, help="enable splitwise ")
     parser.add_argument("--cache_queue_port", type=int, default=9923, help="cache queue port")
     parser.add_argument("--pod_ip", type=str, default="0.0.0.0", help="pod ip")
     parser.add_argument(

@@ -68,7 +54,6 @@ def parse_args():
         help="engine worker queue port",
     )
     parser.add_argument("--engine_pid", type=str, default=None, help="engine pid")
     parser.add_argument("--num_gpu_blocks", type=int, default=1, help="gpu cache block number")
     parser.add_argument("--num_cpu_blocks", type=int, default=4, help="cpu cache block number")
     parser.add_argument("--block_size", type=int, default=64, help="cache block size(tokens)")

@@ -109,7 +94,6 @@ class CacheTransferManager:
         device = args.device_id
         rank = args.rank
-        paddle.set_device(f"gpu:{device}")
         self.gpu_cache_kvs = {}
         self.cpu_cache_kvs = {}
         self.gpu_cache_k_tensors = []

@@ -138,40 +122,27 @@ class CacheTransferManager:
         self.num_cpu_blocks = args.num_cpu_blocks
         cache_type = args.cache_dtype

-        for i in range(args.num_layers + self.num_extra_layers):
-            num_gpu_blocks = args.num_gpu_blocks if i < args.num_layers else self.num_extra_layer_gpu_blocks
-            self.gpu_cache_kvs[f"key_caches_{i}_rank{rank}_device{device}"] = paddle.full(
-                shape=[
-                    num_gpu_blocks,
-                    args.kv_num_head,
-                    args.block_size,
-                    args.head_dim,
-                ],
-                fill_value=0,
-                dtype=cache_type,
-            )
-            self.gpu_cache_k_tensors.append(self.gpu_cache_kvs[f"key_caches_{i}_rank{rank}_device{device}"])
-            self.gpu_cache_kvs[f"value_caches_{i}_rank{rank}_device{device}"] = paddle.full(
-                shape=[
-                    num_gpu_blocks,
-                    args.kv_num_head,
-                    args.block_size,
-                    args.head_dim,
-                ],
-                fill_value=0,
-                dtype=cache_type,
-            )
-            self.gpu_cache_v_tensors.append(self.gpu_cache_kvs[f"value_caches_{i}_rank{rank}_device{device}"])
-            set_data_ipc(
-                self.gpu_cache_kvs[f"key_caches_{i}_rank{rank}_device{device}"],
-                f"key_caches_{i}_rank{rank}.device{device}",
-            )
-            set_data_ipc(
-                self.gpu_cache_kvs[f"value_caches_{i}_rank{rank}_device{device}"],
-                f"value_caches_{i}_rank{rank}.device{device}",
-            )
+        cache_shape = [
+            args.num_gpu_blocks,
+            args.kv_num_head,
+            args.block_size,
+            args.head_dim,
+        ]
+
+        for i in range(args.num_hidden_layers + self.num_extra_layers):
+            num_gpu_blocks = args.num_gpu_blocks if i < args.num_hidden_layers else self.num_extra_layer_gpu_blocks
+            cache_shape[0] = num_gpu_blocks
+            key_name = f"key_caches_{i}_rank{rank}.device{device}"
+            value_name = f"value_caches_{i}_rank{rank}.device{device}"
+            key_cache = paddle.empty(shape=[], dtype=cache_type)
+            value_cache = paddle.empty(shape=[], dtype=cache_type)
+            key_cache = share_external_data(key_cache, key_name, cache_shape)
+            value_cache = share_external_data(value_cache, value_name, cache_shape)
+            self.gpu_cache_kvs[key_name] = key_cache
+            self.gpu_cache_kvs[value_name] = value_cache
+            self.gpu_cache_k_tensors.append(self.gpu_cache_kvs[key_name])
+            self.gpu_cache_v_tensors.append(self.gpu_cache_kvs[value_name])
+
         cache_kv_size_byte = sum([tmp.numel() * 1 for key, tmp in self.gpu_cache_kvs.items()])
         logger.info(f"device :{self.device}")
         logger.info(f"cache_kv_size_byte : {cache_kv_size_byte}")

@@ -180,7 +151,7 @@ class CacheTransferManager:
         paddle.set_device("cpu")
         self.k_dst_ptrs = []
         self.v_dst_ptrs = []
-        for i in range(args.num_layers + self.num_extra_layers):
+        for i in range(args.num_hidden_layers + self.num_extra_layers):
             self.cpu_cache_kvs[f"key_caches_{i}_rank{rank}"] = cuda_host_alloc(
                 args.num_cpu_blocks * args.bytes_per_layer_per_block
             )

@@ -190,38 +161,6 @@ class CacheTransferManager:
             )
             self.v_dst_ptrs.append(self.cpu_cache_kvs[f"value_caches_{i}_rank{rank}"])

-        cache_ready_signal_data = np.zeros(shape=[args.mp_num], dtype=np.int32)
-        self.cache_ready_signal = IPCSignal(
-            name="cache_ready_signal",
-            array=cache_ready_signal_data,
-            dtype=np.int32,
-            suffix=args.engine_pid,
-            create=False,
-        )
-        self.cache_ready_signal.value[self.rank] = 1
-        paddle.set_device(f"gpu:{device}")
-
-        if args.enable_splitwise:
-            logger.debug("create cache messager...")
-            logger.info(f"{args}")
-            from fastdeploy.cache_manager.cache_messager import CacheMessager
-
-            self.cache_messager = CacheMessager(
-                splitwise_role=args.splitwise_role,
-                transfer_protocol=args.protocol,
-                pod_ip=args.pod_ip,
-                engine_worker_queue_port=args.engine_worker_queue_port,
-                local_data_parallel_id=args.local_data_parallel_id,
-                gpu_cache_kvs=self.gpu_cache_kvs,
-                rank=self.rank,
-                nranks=args.mp_num,
-                num_layers=args.num_layers + self.num_extra_layers,
-                gpu_id=self.device,
-                rdma_port=args.rdma_port,
-            )
-            logger.info("successfully create cache messager")
-            logger.info(f"done init CacheMessager gmem alloc : {paddle.device.cuda.memory_allocated()}")
-
         cache_task_broadcast_data = np.zeros(shape=[1], dtype=np.int32)
         self.cache_task_broadcast_signal = IPCSignal(
             name="cache_task_broadcast_signal",

File 3/7: prefix cache manager (class PrefixCacheManager)

@@ -141,6 +141,76 @@ class PrefixCacheManager:
         filename = "cache_transfer_manager.py"
         py_path = os.path.join(current_dir_path, filename)

+        cache_messager_processes = []
+        if self.splitwise_role != "mixed":
+            cache_messager_processes = self.launch_cache_messager(
+                cache_config,
+                tensor_parallel_size,
+                device_ids,
+                pod_ip,
+                engine_worker_queue_port,
+                pid_suffix,
+            )
+            if cache_messager_processes is None:
+                raise RuntimeError("Launch cache messager failed")
+                return []
+
+        if (
+            hasattr(cache_config.model_cfg, "num_key_value_heads")
+            and hasattr(cache_config.model_cfg, "num_key_value_heads")
+            and cache_config.model_cfg.num_key_value_heads is not None
+            and int(cache_config.model_cfg.num_key_value_heads) > 0
+        ):
+            kv_num_head = int(cache_config.model_cfg.num_key_value_heads) // tensor_parallel_size
+        else:
+            kv_num_head = cache_config.model_cfg.num_attention_heads // tensor_parallel_size
+
+        log_dir = envs.FD_LOG_DIR
+        cache_manager_processes = []
+        for i in range(tensor_parallel_size):
+            launch_cmd = (
+                f" {sys.executable} {py_path}"
+                + f" --device_id {int(device_ids[i])}"
+                + f" --rank {i}"
+                + f" --num_hidden_layers {cache_config.model_cfg.num_hidden_layers}"
+                + f" --head_dim {cache_config.model_cfg.head_dim}"
+                + f" --kv_num_head {kv_num_head}"
+                + f" --mp_num {tensor_parallel_size}"
+                + f" --cache_dtype {cache_config.cache_dtype}"
+                + f" --cache_queue_port {cache_config.cache_queue_port}"
+                + f" --pod_ip {pod_ip}"
+                + f" --engine_worker_queue_port {engine_worker_queue_port}"
+                + f" --num_gpu_blocks {cache_config.total_block_num}"
+                + f" --num_cpu_blocks {cache_config.num_cpu_blocks}"
+                + f" --bytes_per_layer_per_block {cache_config.bytes_per_layer_per_block}"
+                + f" --block_size {cache_config.block_size}"
+                + f" --engine_pid {pid_suffix}"
+                + f" --local_data_parallel_id {self.local_data_parallel_id}"
+                + f" --speculative_config '{self.speculative_config.to_json_string()}'"
+                + f" >{log_dir}/launch_cache_manager_{int(device_ids[i])}.log 2>&1"
+            )
+            logger.info(f"Launch cache transfer manager, command:{launch_cmd}")
+            cache_manager_processes.append(subprocess.Popen(launch_cmd, shell=True, preexec_fn=os.setsid))
+
+        exit_code = cache_manager_processes[-1].poll()
+        if exit_code is None:
+            logger.info("Launch cache transfer manager successful")
+        else:
+            logger.info("Launch cache transfer manager failed, see launch_cache_manager.log for more information")
+
+        if cache_config.enable_hierarchical_cache and self.num_cpu_blocks > 0:
+            logger.info("Enable hierarchical cache.")
+            self._enable_cpu_cache()
+        cache_manager_processes.extend(cache_messager_processes)
+        return cache_manager_processes
+
+    def launch_cache_messager(
+        self, cache_config, tensor_parallel_size, device_ids, pod_ip, engine_worker_queue_port, pid_suffix
+    ):
+        """
+        launch_cache_messager function used to initialize the cache messager.
+        """
+        current_dir_path = os.path.split(os.path.abspath(__file__))[0]
+        filename = "cache_messager.py"
         if (
             hasattr(cache_config.model_cfg, "num_key_value_heads")
             and hasattr(cache_config.model_cfg, "num_key_value_heads")

@@ -159,8 +229,10 @@ class PrefixCacheManager:
             suffix=pid_suffix,
             create=True,
         )
+
+        py_path = os.path.join(current_dir_path, filename)
         log_dir = envs.FD_LOG_DIR
-        cache_manager_processes = []
+        cache_messager_processes = []
         for i in range(tensor_parallel_size):
             launch_cmd = (
                 "FLAGS_allocator_strategy=auto_growth CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7"

@@ -169,42 +241,34 @@ class PrefixCacheManager:
                 + f" --device_id {int(device_ids[i])}"
                 + f" --rank {i}"
                 + f" --splitwise_role {self.splitwise_role}"
-                + f" --num_layers {cache_config.model_cfg.num_hidden_layers}"
+                + f" --num_hidden_layers {cache_config.model_cfg.num_hidden_layers}"
                 + f" --head_dim {cache_config.model_cfg.head_dim}"
                 + f" --kv_num_head {kv_num_head}"
                 + f" --mp_num {tensor_parallel_size}"
                 + f" --cache_dtype {cache_config.cache_dtype}"
-                + f" --cache_queue_port {cache_config.cache_queue_port}"
-                + f" --enable_splitwise {int(self.enable_splitwise)}"
                 + f" --pod_ip {pod_ip}"
                 + f" --engine_worker_queue_port {engine_worker_queue_port}"
                 + f" --num_gpu_blocks {cache_config.total_block_num}"
-                + f" --num_cpu_blocks {cache_config.num_cpu_blocks}"
-                + f" --bytes_per_layer_per_block {cache_config.bytes_per_layer_per_block}"
                 + f" --block_size {cache_config.block_size}"
-                + f" --engine_pid {pid_suffix}"
                 + f" --protocol {cache_config.cache_transfer_protocol}"
                 + f" --local_data_parallel_id {self.local_data_parallel_id}"
+                + f" --engine_pid {pid_suffix}"
                 + f" --rdma_port {cache_config.rdma_comm_ports[i] if cache_config.rdma_comm_ports is not None else '0'}"
                 + f" --speculative_config '{self.speculative_config.to_json_string()}'"
-                + f" >{log_dir}/launch_cache_manager_{int(device_ids[i])}.log 2>&1"
+                + f" >{log_dir}/launch_cache_messager_{int(device_ids[i])}.log 2>&1"
             )
-            logger.info(f"Launch cache transfer manager, command:{launch_cmd}")
-            cache_manager_processes.append(subprocess.Popen(launch_cmd, shell=True, preexec_fn=os.setsid))
-        # 等待cache初始化完毕
-        logger.info("Waiting for cache transfer manager ready...")
+            logger.info(f"Launch cache messager, command:{launch_cmd}")
+            cache_messager_processes.append(subprocess.Popen(launch_cmd, shell=True, preexec_fn=os.setsid))
+        logger.info("Waiting for cache ready...")
         while np.sum(self.cache_ready_signal.value) != tensor_parallel_size:
             time.sleep(1)
-        exit_code = cache_manager_processes[-1].poll()
+        exit_code = cache_messager_processes[-1].poll()
         if exit_code is None:
-            logger.info("Launch cache transfer manager successful")
+            logger.info("Launch cache messager successful")
         else:
-            logger.info("Launch cache transfer manager failed, see launch_cache_manager.log for more information")
-
-        if cache_config.enable_hierarchical_cache and self.num_cpu_blocks > 0:
-            logger.info("Enable hierarchical cache.")
-            self._enable_cpu_cache()
-        return cache_manager_processes
+            logger.info("Launch cache messager failed, see launch_cache_messager.log for more information")
+            cache_messager_processes = None
+        return cache_messager_processes

     def update_cache_config(self, cache_config):
         """

File 4/7: engine (class LLMEngine)

@@ -775,10 +775,6 @@ class LLMEngine:
         """
         Insert tasks to engine.
         """
-        for task in tasks:
-            start_span_request("DEQUEUE", task, trace.SpanKind.CONSUMER)
-            if task.sampling_params.bad_words is not None:
-                task.sampling_params.update_from_tokenizer(self.data_processor.tokenizer)
         # TODO 返回至 scheduler
         if allocated:
             current_tasks = []

@@ -805,6 +801,11 @@ class LLMEngine:
             self.engine_worker_queue.put_tasks((current_tasks, self.resource_manager.real_bsz))
             return True

+        for task in tasks:
+            start_span_request("DEQUEUE", task, trace.SpanKind.CONSUMER)
+            if task.sampling_params.bad_words is not None:
+                task.sampling_params.update_from_tokenizer(self.data_processor.tokenizer)
+
         self.resource_manager.check_and_free_block_tables()

         if not isinstance(tasks, list):

@@ -846,7 +847,6 @@ class LLMEngine:
         llm_logger.info(f"Tasks are sent to engine, req_ids={req_ids}")
         for task in tasks:
             task.inference_start_time = time.time()
-            if not is_prefill:
                 if not self.cfg.enable_mm:
                     self.update_requests_chunk_size(tasks)
                 else:

@@ -992,14 +992,17 @@ class LLMEngine:
         self.running = False
         if hasattr(self, "cache_manager_processes"):
-            self.resource_manager.cache_manager.shm_cache_task_flag_broadcast.clear()
-            self.resource_manager.cache_manager.cache_ready_signal.clear()
             for p in self.cache_manager_processes:
                 llm_logger.info(f"Killing cache manager process {p.pid}")
                 try:
                     os.killpg(p.pid, signal.SIGTERM)
                 except Exception as e:
                     print(f"Error extracting file: {e}")
+            if hasattr(self.resource_manager.cache_manager, "cache_ready_signal"):
+                self.resource_manager.cache_manager.cache_ready_signal.clear()
+                self.resource_manager.cache_manager.shm_cache_task_flag_broadcast.clear()
+        if hasattr(self, "zmq_server") and self.zmq_server is not None:
+            self.zmq_server.close()
         self.worker_ready_signal.clear()
         self.exist_task_signal.clear()
         self.exist_swapped_task_signal.clear()

@@ -1024,6 +1027,7 @@ class LLMEngine:
         if hasattr(self, "dp_processed"):
             for p in self.dp_processed:
                 p.join()
+        self.engine_worker_queue_server.cleanup()

     def _setting_environ_variables(self):
         """

File 5/7: MTP proposer (class MTPProposer)

@@ -37,6 +37,7 @@ from fastdeploy.model_executor.ops.gpu import (
     eagle_get_self_hidden_states,
     mtp_save_first_token,
     mtp_step_paddle,
+    set_data_ipc,
     share_external_data,
 )
 from fastdeploy.model_executor.pre_and_post_process import pre_process, rebuild_padding

@@ -141,9 +142,7 @@ class MTPProposer(Proposer):
         kv_cache_shape = self.attn_backends[0].get_kv_cache_shape(
             max_num_blocks=self.num_gpu_blocks, kv_cache_quant_type=kv_cache_quant_type
         )
-        if not self.parallel_config.do_profile and (
-            self.cache_config.enable_prefix_caching or self.parallel_config.splitwise_role != "mixed"
-        ):
+        if not self.parallel_config.do_profile and self.parallel_config.splitwise_role != "mixed":
             cache_kvs_list = []
             for i in range(
                 self.num_main_model_layers,

@@ -160,7 +159,10 @@ class MTPProposer(Proposer):
             self.model_inputs["caches"] = cache_kvs_list
         else:
-            for i in range(self.model_config.num_hidden_layers):
+            for i in range(
+                self.num_main_model_layers,
+                self.num_main_model_layers + self.model_config.num_hidden_layers,
+            ):
                 self.cache_kvs[f"key_caches_{i}"] = paddle.full(
                     shape=kv_cache_shape,
                     fill_value=0,

@@ -171,6 +173,15 @@ class MTPProposer(Proposer):
                     fill_value=0,
                     dtype=cache_type,
                 )
+                if self.cache_config.enable_prefix_caching:
+                    set_data_ipc(
+                        self.cache_kvs[f"key_caches_{i}"],
+                        f"key_caches_{i}_rank{self.local_rank}.device{self.device_id}",
+                    )
+                    set_data_ipc(
+                        self.cache_kvs[f"value_caches_{i}"],
+                        f"value_caches_{i}_rank{self.local_rank}.device{self.device_id}",
+                    )
             self.model_inputs["caches"] = list(self.cache_kvs.values())
         for value in self.cache_kvs.values():
             del value

@@ -235,7 +246,7 @@ class MTPProposer(Proposer):
         self.main_model_num_gpu_blocks = num_gpu_blocks
         self.num_gpu_blocks = int(num_gpu_blocks * self.speculative_config.num_gpu_block_expand_ratio)
-        if not (self.cache_config.enable_prefix_caching or self.parallel_config.splitwise_role != "mixed"):
+        if self.parallel_config.splitwise_role == "mixed":
             self.initialize_kv_cache()

         # Reset free list
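The draft-model caches are now indexed after the main model's layers, so the IPC names they register under `enable_prefix_caching` cannot collide with the target model's caches. A short sketch of the offset arithmetic; the helper is illustrative, not part of the diff:

```python
def mtp_layer_indices(num_main_model_layers: int, num_hidden_layers: int) -> range:
    # Draft (MTP) layer i is stored as global layer num_main_model_layers + i,
    # so its IPC name key_caches_{idx}_rank{r}.device{d} stays unique.
    return range(num_main_model_layers, num_main_model_layers + num_hidden_layers)


# A 48-layer main model with a 1-layer MTP head exports key_caches_48_rank{r}.device{d}
assert list(mtp_layer_indices(48, 1)) == [48]
```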

File 6/7: GPU model runner (class GPUModelRunner)

@@ -43,6 +43,7 @@ from fastdeploy.model_executor.layers.sample.sampler import Sampler, Speculative
 from fastdeploy.model_executor.model_loader import get_model_loader
 from fastdeploy.model_executor.ops.gpu import (
     recover_decode_task,
+    set_data_ipc,
     set_value_by_flags_and_idx,
     share_external_data,
 )

@@ -904,7 +905,7 @@ class GPUModelRunner(ModelRunnerBase):
         )
         local_rank = self.local_rank % self.parallel_config.tensor_parallel_size
-        if not profile and (self.cache_config.enable_prefix_caching or self.parallel_config.splitwise_role != "mixed"):
+        if not profile and self.parallel_config.splitwise_role != "mixed":
             cache_kvs_list = []
             for i in range(self.model_config.num_hidden_layers):
                 key_cache = paddle.empty(shape=[], dtype=cache_type)

@@ -930,6 +931,15 @@ class GPUModelRunner(ModelRunnerBase):
                     fill_value=0,
                     dtype=cache_type,
                 )
+                if self.cache_config.enable_prefix_caching:
+                    set_data_ipc(
+                        cache_kvs[f"key_caches_{i}"],
+                        f"key_caches_{i}_rank{local_rank}.device{self.device_id}",
+                    )
+                    set_data_ipc(
+                        cache_kvs[f"value_caches_{i}"],
+                        f"value_caches_{i}_rank{local_rank}.device{self.device_id}",
+                    )
             self.share_inputs["caches"] = list(cache_kvs.values())
         for value in cache_kvs.values():
             del value

@@ -1138,6 +1148,8 @@ class GPUModelRunner(ModelRunnerBase):
                 if task.chunk_idx > len(task.prefill_chunk_info):
                     continue
                 self.restore_chunked_prefill_request[task.request_id] = task
+        if len(self.restore_chunked_prefill_request) > 0:
+            self.share_inputs["not_need_stop"][0] = True

         for id, task in list(self.restore_chunked_prefill_request.items()):
             idx = task.idx

@@ -1182,7 +1194,7 @@ class GPUModelRunner(ModelRunnerBase):
                     self.share_inputs["seq_lens_encoder"][idx : idx + 1] = token_chunk_size
                     self.share_inputs["prompt_lens"][idx : idx + 1] += token_chunk_size
                     self.share_inputs["step_idx"][idx : idx + 1] = 0
+                    self.share_inputs["stop_flags"][idx : idx + 1] = False
                 if self.speculative_decoding and self.proposer.is_chunk_prefill_enabled():
                     self.proposer.update_task_chunk_prefill(task)
                 task.chunk_idx += 1

@@ -1507,12 +1519,12 @@ class GPUModelRunner(ModelRunnerBase):
         hidden_dim = self.model_config.head_dim * self.model_config.kv_num_heads

         # NOTE(liuzichang): Implement multi-layer MTP architecture in the future
-        num_layers = (
+        num_hidden_layers = (
             self.model_config.num_hidden_layers + self.speculative_config.num_gpu_block_expand_ratio
             if self.speculative_method in ["mtp"]
             else self.model_config.num_hidden_layers
         )
-        required_memory = byte_of_dtype * 2 * (self.cache_config.block_size * hidden_dim) * num_layers  # k + v
+        required_memory = byte_of_dtype * 2 * (self.cache_config.block_size * hidden_dim) * num_hidden_layers  # k + v
         return required_memory

     def not_need_stop(self) -> bool:
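The renamed `num_hidden_layers` feeds the same per-block memory estimate as before. A standalone restatement of that formula, with made-up example numbers:

```python
def kv_block_bytes(byte_of_dtype: int, block_size: int, hidden_dim: int, num_hidden_layers: int) -> int:
    # One cache block holds block_size tokens of K and of V (hence the factor 2)
    # for every layer; hidden_dim = head_dim * kv_num_heads.
    return byte_of_dtype * 2 * (block_size * hidden_dim) * num_hidden_layers


# e.g. bfloat16 (2 bytes), 64-token blocks, hidden_dim 1024, 48 layers
print(kv_block_bytes(2, 64, 1024, 48))  # 12582912 bytes per block
```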

File 7/7: worker_process.py (class PaddleDisWorkerProc)

@@ -408,7 +408,7 @@ class PaddleDisWorkerProc:
         logger.info(f"------- num_blocks_global: {num_blocks_local} --------")
         # wait engine launch cache_manager
-        if self.cache_config.enable_prefix_caching or self.parallel_config.splitwise_role != "mixed":
+        if self.parallel_config.splitwise_role != "mixed":
             launched_cache_manager_signal_data = np.zeros([1], dtype=np.int32)
             self.launched_cache_manager_signal = IPCSignal(
                 name="launched_cache_manager_signal",