mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-06 09:07:10 +08:00
[Feature] Optimize prefix cache (#3208)
* [LLM] support ep * Update worker_process.py * Update expert_service.py * Update worker_process.py * format files * optimize prefix cache * optimize prefix cache * optimize prefix cache * pre commit format * pre commit format * pre commit format * Update cache_messager.py
This commit is contained in:
@@ -14,18 +14,72 @@
|
||||
# limitations under the License.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import math
|
||||
import threading
|
||||
import time
|
||||
|
||||
import threading
|
||||
import numpy as np
|
||||
import paddle
|
||||
|
||||
from fastdeploy.cache_manager.transfer_factory import IPCCommManager, RDMACommManager
|
||||
from fastdeploy.config import SpeculativeConfig
|
||||
from fastdeploy.inter_communicator import EngineWorkerQueue, IPCSignal
|
||||
from fastdeploy.model_executor.ops.gpu import set_data_ipc
|
||||
from fastdeploy.utils import get_logger
|
||||
|
||||
logger = get_logger("cache_messager", "cache_messager.log")
|
||||
|
||||
def parse_args():
|
||||
"""
|
||||
从命令行解析参数
|
||||
"""
|
||||
parser = argparse.ArgumentParser("Cache Messager")
|
||||
parser.add_argument(
|
||||
"--splitwise_role",
|
||||
type=str,
|
||||
default="mixed",
|
||||
help="splitwise role, can be decode, prefill or mixed",
|
||||
)
|
||||
parser.add_argument("--rank", type=int, default=0, help="current rank")
|
||||
parser.add_argument("--device_id", type=int, default=0, help="device id")
|
||||
parser.add_argument("--num_hidden_layers", type=int, default=1, help="model num layers")
|
||||
parser.add_argument("--head_dim", type=int, default=1, help="model head dim")
|
||||
parser.add_argument("--kv_num_head", type=int, default=1, help="model kv num head")
|
||||
parser.add_argument("--rdma_port", type=str, default="", help="rmda port")
|
||||
parser.add_argument("--mp_num", type=int, default=1, help="number of model parallel")
|
||||
parser.add_argument("--engine_pid", type=str, default=None, help="engine pid")
|
||||
parser.add_argument(
|
||||
"--protocol",
|
||||
type=str,
|
||||
default="ipc",
|
||||
help="cache transfer protocol, only surport ipc now",
|
||||
)
|
||||
parser.add_argument("--pod_ip", type=str, default="0.0.0.0", help="pod ip")
|
||||
parser.add_argument(
|
||||
"--engine_worker_queue_port",
|
||||
type=int,
|
||||
default=9923,
|
||||
help="engine worker queue port",
|
||||
)
|
||||
parser.add_argument("--num_gpu_blocks", type=int, default=1, help="gpu cache block number")
|
||||
parser.add_argument("--block_size", type=int, default=64, help="cache block size(tokens)")
|
||||
parser.add_argument(
|
||||
"--cache_dtype",
|
||||
type=str,
|
||||
default="bfloat16",
|
||||
choices=["uint8", "bfloat16"],
|
||||
help="cache dtype",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--speculative_config",
|
||||
type=json.loads,
|
||||
default="{}",
|
||||
help="speculative config",
|
||||
)
|
||||
parser.add_argument("--local_data_parallel_id", type=int, default=0)
|
||||
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
|
||||
class CacheMessager:
|
||||
@@ -43,7 +97,7 @@ class CacheMessager:
|
||||
gpu_cache_kvs,
|
||||
rank,
|
||||
nranks,
|
||||
num_layers,
|
||||
num_hidden_layers,
|
||||
gpu_id=0,
|
||||
rdma_port=None,
|
||||
):
|
||||
@@ -57,7 +111,7 @@ class CacheMessager:
|
||||
gpu_cache_kvs (dict): GPU kv cache
|
||||
rank (int): current rank
|
||||
nranks (int): global rank number
|
||||
num_layers (int): model layer number
|
||||
num_hidden_layers (int): model layer number
|
||||
gpu_id (int, optional): GPU ID
|
||||
rdma_port (int, optional): RDMA port
|
||||
|
||||
@@ -86,13 +140,13 @@ class CacheMessager:
|
||||
logger.info(f"splitwise role: {splitwise_role}, {transfer_protocol}" f"rank: {rank}")
|
||||
|
||||
# 1. initialize the cache_k_ptr_list and cache_v_ptr_list
|
||||
self.num_layers = num_layers
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
cache_k_ptr_list = []
|
||||
cache_v_ptr_list = []
|
||||
cache_k = []
|
||||
cache_v = []
|
||||
self.messager = {}
|
||||
for layer_idx in range(self.num_layers):
|
||||
for layer_idx in range(self.num_hidden_layers):
|
||||
key_cache = self.gpu_cache_kvs[f"key_caches_{layer_idx}_rank{self.rank}_device{gpu_id}"]
|
||||
val_cache = self.gpu_cache_kvs[f"value_caches_{layer_idx}_rank{self.rank}_device{gpu_id}"]
|
||||
cache_k.append(key_cache)
|
||||
@@ -109,7 +163,7 @@ class CacheMessager:
|
||||
if key_cache.dtype == paddle.bfloat16:
|
||||
block_bytes *= 2
|
||||
logger.info(
|
||||
f"layers {num_layers} cache_shape: {cache_shape}, max_block_num: {max_block_num}, "
|
||||
f"layers {num_hidden_layers} cache_shape: {cache_shape}, max_block_num: {max_block_num}, "
|
||||
f"block_bytes: {block_bytes}, dtype: {key_cache.dtype}"
|
||||
)
|
||||
self.block_bytes = block_bytes
|
||||
@@ -144,17 +198,13 @@ class CacheMessager:
|
||||
self.cache_info = dict()
|
||||
self.rank_id = self.rank + local_data_parallel_id * self.nranks # align with engine worker rank (paddle.distributed.launch)
|
||||
|
||||
layerwise_send_cache_thread = threading.Thread(target=self._prefill_layerwise_send_cache_thread)
|
||||
layerwise_send_cache_thread.daemon = True
|
||||
layerwise_send_cache_thread.start()
|
||||
|
||||
connect_rdma_thread = threading.Thread(target=self._handle_connect_task)
|
||||
connect_rdma_thread.daemon = True
|
||||
connect_rdma_thread.start()
|
||||
|
||||
logger.info(f"cache messager init finished, use {transfer_protocol}")
|
||||
|
||||
def _prefill_layerwise_send_cache_thread(self):
|
||||
def prefill_layerwise_send_cache_thread(self):
|
||||
"""
|
||||
layerwise_send_cache_thread:
|
||||
send cache to other instance
|
||||
@@ -204,7 +254,7 @@ class CacheMessager:
|
||||
cache_info = self.engine_worker_queue.get_cache_info()
|
||||
|
||||
if cache_info:
|
||||
logger.debug(f"cache info {cache_info}")
|
||||
logger.info(f"cache info {cache_info}")
|
||||
for info in cache_info:
|
||||
if info["request_id"] in self.cache_info:
|
||||
self.cache_info[info["request_id"]].update(info)
|
||||
@@ -223,7 +273,7 @@ class CacheMessager:
|
||||
self.cache_info[info["request_id"]] = info
|
||||
prefilled_layer_idx = layer_shm_value.value[0]
|
||||
prefilled_step_idx = step_shm_value.value[0]
|
||||
if prefilled_layer_idx == self.num_layers - 1:
|
||||
if prefilled_layer_idx == self.num_hidden_layers - 1:
|
||||
time.sleep(0.001)
|
||||
prefilled_layer_idx = layer_shm_value.value[0]
|
||||
prefilled_step_idx = step_shm_value.value[0]
|
||||
@@ -234,7 +284,7 @@ class CacheMessager:
|
||||
if not self.cache_info:
|
||||
time.sleep(0.001)
|
||||
continue
|
||||
logger.debug(f"prefilled_layer_idx: {prefilled_layer_idx}, prefilled_step_idx: {prefilled_step_idx}")
|
||||
logger.info(f"prefilled_layer_idx: {prefilled_layer_idx}, prefilled_step_idx: {prefilled_step_idx}")
|
||||
for req_id, item in list(self.cache_info.items()):
|
||||
if "status" not in item:
|
||||
continue
|
||||
@@ -251,7 +301,7 @@ class CacheMessager:
|
||||
target_id = int(item["rdma_ports"][self.rank])
|
||||
status = self.messager[current_transfer_protocol].connect(target_ip, target_id)
|
||||
if not status:
|
||||
logger.error(f"connect to {target_ip}:{target_id} failed")
|
||||
logger.info(f"connect to {target_ip}:{target_id} failed")
|
||||
item["status"] = "error"
|
||||
self.engine_worker_queue.finish_request_barrier.wait()
|
||||
if self.rank == 0:
|
||||
@@ -263,7 +313,7 @@ class CacheMessager:
|
||||
src_block_ids = paddle.to_tensor(item["src_block_ids"], dtype="int32", place="cpu")
|
||||
dest_block_ids = paddle.to_tensor(item["dest_block_ids"], dtype="int32", place="cpu")
|
||||
if item["current_id"] < prefilled_step_idx:
|
||||
current_layer_idx = self.num_layers
|
||||
current_layer_idx = self.num_hidden_layers
|
||||
else:
|
||||
current_layer_idx = prefilled_layer_idx + 1
|
||||
|
||||
@@ -281,7 +331,7 @@ class CacheMessager:
|
||||
self.engine_worker_queue.finish_request_barrier.wait()
|
||||
if self.rank == 0:
|
||||
self.engine_worker_queue.put_finished_req([(item["request_id"], "write cache error")])
|
||||
logger.error(
|
||||
logger.info(
|
||||
f"write cache failed, layer_idx: {layer_idx}, "
|
||||
f"req_id: {item['request_id']}, dest_ip: {target_ip}"
|
||||
)
|
||||
@@ -292,14 +342,14 @@ class CacheMessager:
|
||||
block_num = len(src_block_ids)
|
||||
avg_time_per_block = cost_time * 1000 / block_num # ms
|
||||
send_cache_speed = block_num * self.block_bytes / 1073741824 / cost_time # GB/s
|
||||
logger.debug(
|
||||
logger.info(
|
||||
f"finish write cache for a layer, {item['request_id']}, {layer_idx}"
|
||||
f" {current_transfer_protocol}"
|
||||
f"block_num: {block_num}, send_cache_speed(GB/s): {round(send_cache_speed, 5)},"
|
||||
f"avg_time per block(ms): {round(avg_time_per_block, 5)}"
|
||||
)
|
||||
item["layer_idx"] = current_layer_idx
|
||||
if item["layer_idx"] == self.num_layers:
|
||||
if item["layer_idx"] == self.num_hidden_layers:
|
||||
if item["transfer_protocol"] == "ipc":
|
||||
self.messager["ipc"].write_block_by_sync(target_id)
|
||||
logger.info(f"finish write cache {item['request_id']}")
|
||||
@@ -313,8 +363,8 @@ class CacheMessager:
|
||||
self.last_layer_idx = prefilled_layer_idx
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"prefill layerwise send cache thread has exception: {e}")
|
||||
|
||||
logger.info(f"prefill layerwise send cache thread has exception: {e}")
|
||||
|
||||
def _handle_connect_task(self):
|
||||
while True:
|
||||
try:
|
||||
@@ -333,3 +383,90 @@ class CacheMessager:
|
||||
self.engine_worker_queue.put_connect_rdma_task_response(response)
|
||||
except Exception as e:
|
||||
logger.error(f"handle_connect_task has exception: {e}")
|
||||
|
||||
|
||||
def main():
|
||||
device = args.device_id
|
||||
rank = args.rank
|
||||
paddle.set_device(f"gpu:{device}")
|
||||
cache_type = args.cache_dtype
|
||||
speculative_config = SpeculativeConfig(args.speculative_config)
|
||||
num_extra_layers = speculative_config.num_extra_cache_layer
|
||||
num_extra_layer_gpu_blocks = int(args.num_gpu_blocks * speculative_config.num_gpu_block_expand_ratio)
|
||||
gpu_cache_kvs = {}
|
||||
gpu_cache_k_tensors = []
|
||||
gpu_cache_v_tensors = []
|
||||
|
||||
for i in range(args.num_hidden_layers + num_extra_layers):
|
||||
num_gpu_blocks = args.num_gpu_blocks if i < args.num_hidden_layers else num_extra_layer_gpu_blocks
|
||||
|
||||
gpu_cache_kvs[f"key_caches_{i}_rank{rank}_device{device}"] = paddle.full(
|
||||
shape=[
|
||||
num_gpu_blocks,
|
||||
args.kv_num_head,
|
||||
args.block_size,
|
||||
args.head_dim,
|
||||
],
|
||||
fill_value=0,
|
||||
dtype=cache_type,
|
||||
)
|
||||
gpu_cache_k_tensors.append(gpu_cache_kvs[f"key_caches_{i}_rank{rank}_device{device}"])
|
||||
gpu_cache_kvs[f"value_caches_{i}_rank{rank}_device{device}"] = paddle.full(
|
||||
shape=[
|
||||
num_gpu_blocks,
|
||||
args.kv_num_head,
|
||||
args.block_size,
|
||||
args.head_dim,
|
||||
],
|
||||
fill_value=0,
|
||||
dtype=cache_type,
|
||||
)
|
||||
gpu_cache_v_tensors.append(gpu_cache_kvs[f"value_caches_{i}_rank{rank}_device{device}"])
|
||||
|
||||
set_data_ipc(
|
||||
gpu_cache_kvs[f"key_caches_{i}_rank{rank}_device{device}"],
|
||||
f"key_caches_{i}_rank{rank}.device{device}",
|
||||
)
|
||||
set_data_ipc(
|
||||
gpu_cache_kvs[f"value_caches_{i}_rank{rank}_device{device}"],
|
||||
f"value_caches_{i}_rank{rank}.device{device}",
|
||||
)
|
||||
cache_kv_size_byte = sum([tmp.numel() * 1 for key, tmp in gpu_cache_kvs.items()])
|
||||
logger.info(f"device :{device}")
|
||||
logger.info(f"cache_kv_size_byte : {cache_kv_size_byte}")
|
||||
logger.info(f"done init cache (full) gmem alloc : {paddle.device.cuda.memory_allocated()}")
|
||||
|
||||
cache_messager = CacheMessager(
|
||||
splitwise_role=args.splitwise_role,
|
||||
transfer_protocol=args.protocol,
|
||||
pod_ip=args.pod_ip,
|
||||
engine_worker_queue_port=args.engine_worker_queue_port,
|
||||
local_data_parallel_id=args.local_data_parallel_id,
|
||||
gpu_cache_kvs=gpu_cache_kvs,
|
||||
rank=rank,
|
||||
nranks=args.mp_num,
|
||||
num_hidden_layers=args.num_hidden_layers + num_extra_layers,
|
||||
gpu_id=device,
|
||||
rdma_port=args.rdma_port,
|
||||
)
|
||||
|
||||
cache_ready_signal_data = np.zeros(shape=[args.mp_num], dtype=np.int32)
|
||||
cache_ready_signal = IPCSignal(
|
||||
name="cache_ready_signal",
|
||||
array=cache_ready_signal_data,
|
||||
dtype=np.int32,
|
||||
suffix=args.engine_pid,
|
||||
create=False,
|
||||
)
|
||||
cache_ready_signal.value[rank] = 1
|
||||
cache_messager.prefill_layerwise_send_cache_thread()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
args = parse_args()
|
||||
logger = get_logger("cache_messager", "cache_messager.log")
|
||||
|
||||
logger.info("create cache messager...")
|
||||
logger.info(f"{args}")
|
||||
main()
|
||||
|
Reference in New Issue
Block a user