Mirror of https://github.com/PaddlePaddle/FastDeploy.git, synced 2025-10-05 16:48:03 +08:00
[Feature] Optimize prefix cache (#3208)
* [LLM] support ep
* Update worker_process.py
* Update expert_service.py
* Update worker_process.py
* format files
* optimize prefix cache
* optimize prefix cache
* optimize prefix cache
* pre commit format
* pre commit format
* pre commit format
* Update cache_messager.py
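At a high level, this change moves ownership of the GPU KV-cache blocks into the standalone cache messager process: cache_messager.py now allocates the per-layer key/value pools itself and publishes them over CUDA IPC with set_data_ipc, the GPU model runner and MTP proposer do the same when prefix caching is enabled, and cache_transfer_manager.py attaches to the already-published pools with share_external_data instead of allocating and registering its own copies. The snippet below is a minimal sketch of that handoff pattern only, not code from the commit; the shape, dtype, and tensor name are placeholder values, and the two FastDeploy GPU ops are used with the same signatures that appear in the hunks below.

import paddle

from fastdeploy.model_executor.ops.gpu import set_data_ipc, share_external_data

# Placeholder layout: [num_gpu_blocks, kv_num_head, block_size, head_dim]
cache_shape = [1024, 8, 64, 128]
# Naming scheme used throughout the diff: key_caches_{i}_rank{rank}.device{device}
name = "key_caches_0_rank0.device0"

# Producer side (cache_messager.py main()): allocate the block pool once and
# publish it under a well-known name via CUDA IPC.
key_cache = paddle.full(shape=cache_shape, fill_value=0, dtype="bfloat16")
set_data_ipc(key_cache, name)

# Consumer side (CacheTransferManager, running as a separate process on the
# same GPU): start from an empty placeholder tensor and map the published
# buffer instead of allocating a second copy.
placeholder = paddle.empty(shape=[], dtype="bfloat16")
key_cache_view = share_external_data(placeholder, name, cache_shape)

Because only one process allocates each pool, prefix caching no longer pays for duplicate KV buffers; the cache_ready_signal handshake visible in the hunks below is what sequences the publishing side ahead of the attaching side.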
@@ -14,18 +14,72 @@
 # limitations under the License.
 """

+import argparse
+import json
 import math
-import threading
 import time
+import threading

 import numpy as np
 import paddle

 from fastdeploy.cache_manager.transfer_factory import IPCCommManager, RDMACommManager
+from fastdeploy.config import SpeculativeConfig
 from fastdeploy.inter_communicator import EngineWorkerQueue, IPCSignal
+from fastdeploy.model_executor.ops.gpu import set_data_ipc
 from fastdeploy.utils import get_logger

-logger = get_logger("cache_messager", "cache_messager.log")
+
+def parse_args():
+    """
+    从命令行解析参数
+    """
+    parser = argparse.ArgumentParser("Cache Messager")
+    parser.add_argument(
+        "--splitwise_role",
+        type=str,
+        default="mixed",
+        help="splitwise role, can be decode, prefill or mixed",
+    )
+    parser.add_argument("--rank", type=int, default=0, help="current rank")
+    parser.add_argument("--device_id", type=int, default=0, help="device id")
+    parser.add_argument("--num_hidden_layers", type=int, default=1, help="model num layers")
+    parser.add_argument("--head_dim", type=int, default=1, help="model head dim")
+    parser.add_argument("--kv_num_head", type=int, default=1, help="model kv num head")
+    parser.add_argument("--rdma_port", type=str, default="", help="rmda port")
+    parser.add_argument("--mp_num", type=int, default=1, help="number of model parallel")
+    parser.add_argument("--engine_pid", type=str, default=None, help="engine pid")
+    parser.add_argument(
+        "--protocol",
+        type=str,
+        default="ipc",
+        help="cache transfer protocol, only surport ipc now",
+    )
+    parser.add_argument("--pod_ip", type=str, default="0.0.0.0", help="pod ip")
+    parser.add_argument(
+        "--engine_worker_queue_port",
+        type=int,
+        default=9923,
+        help="engine worker queue port",
+    )
+    parser.add_argument("--num_gpu_blocks", type=int, default=1, help="gpu cache block number")
+    parser.add_argument("--block_size", type=int, default=64, help="cache block size(tokens)")
+    parser.add_argument(
+        "--cache_dtype",
+        type=str,
+        default="bfloat16",
+        choices=["uint8", "bfloat16"],
+        help="cache dtype",
+    )
+    parser.add_argument(
+        "--speculative_config",
+        type=json.loads,
+        default="{}",
+        help="speculative config",
+    )
+    parser.add_argument("--local_data_parallel_id", type=int, default=0)
+
+    args = parser.parse_args()
+    return args


 class CacheMessager:
@@ -43,7 +97,7 @@ class CacheMessager:
         gpu_cache_kvs,
         rank,
         nranks,
-        num_layers,
+        num_hidden_layers,
         gpu_id=0,
         rdma_port=None,
     ):
@@ -57,7 +111,7 @@ class CacheMessager:
             gpu_cache_kvs (dict): GPU kv cache
             rank (int): current rank
             nranks (int): global rank number
-            num_layers (int): model layer number
+            num_hidden_layers (int): model layer number
             gpu_id (int, optional): GPU ID
             rdma_port (int, optional): RDMA port

@@ -86,13 +140,13 @@ class CacheMessager:
         logger.info(f"splitwise role: {splitwise_role}, {transfer_protocol}" f"rank: {rank}")

         # 1. initialize the cache_k_ptr_list and cache_v_ptr_list
-        self.num_layers = num_layers
+        self.num_hidden_layers = num_hidden_layers
         cache_k_ptr_list = []
         cache_v_ptr_list = []
         cache_k = []
         cache_v = []
         self.messager = {}
-        for layer_idx in range(self.num_layers):
+        for layer_idx in range(self.num_hidden_layers):
             key_cache = self.gpu_cache_kvs[f"key_caches_{layer_idx}_rank{self.rank}_device{gpu_id}"]
             val_cache = self.gpu_cache_kvs[f"value_caches_{layer_idx}_rank{self.rank}_device{gpu_id}"]
             cache_k.append(key_cache)
@@ -109,7 +163,7 @@ class CacheMessager:
         if key_cache.dtype == paddle.bfloat16:
             block_bytes *= 2
         logger.info(
-            f"layers {num_layers} cache_shape: {cache_shape}, max_block_num: {max_block_num}, "
+            f"layers {num_hidden_layers} cache_shape: {cache_shape}, max_block_num: {max_block_num}, "
             f"block_bytes: {block_bytes}, dtype: {key_cache.dtype}"
         )
         self.block_bytes = block_bytes
@@ -144,17 +198,13 @@ class CacheMessager:
         self.cache_info = dict()
         self.rank_id = self.rank + local_data_parallel_id * self.nranks  # align with engine worker rank (paddle.distributed.launch)

-        layerwise_send_cache_thread = threading.Thread(target=self._prefill_layerwise_send_cache_thread)
-        layerwise_send_cache_thread.daemon = True
-        layerwise_send_cache_thread.start()
-
         connect_rdma_thread = threading.Thread(target=self._handle_connect_task)
         connect_rdma_thread.daemon = True
         connect_rdma_thread.start()

         logger.info(f"cache messager init finished, use {transfer_protocol}")

-    def _prefill_layerwise_send_cache_thread(self):
+    def prefill_layerwise_send_cache_thread(self):
         """
         layerwise_send_cache_thread:
         send cache to other instance
@@ -204,7 +254,7 @@ class CacheMessager:
                 cache_info = self.engine_worker_queue.get_cache_info()

                 if cache_info:
-                    logger.debug(f"cache info {cache_info}")
+                    logger.info(f"cache info {cache_info}")
                     for info in cache_info:
                         if info["request_id"] in self.cache_info:
                             self.cache_info[info["request_id"]].update(info)
@@ -223,7 +273,7 @@ class CacheMessager:
                             self.cache_info[info["request_id"]] = info
                 prefilled_layer_idx = layer_shm_value.value[0]
                 prefilled_step_idx = step_shm_value.value[0]
-                if prefilled_layer_idx == self.num_layers - 1:
+                if prefilled_layer_idx == self.num_hidden_layers - 1:
                     time.sleep(0.001)
                     prefilled_layer_idx = layer_shm_value.value[0]
                     prefilled_step_idx = step_shm_value.value[0]
@@ -234,7 +284,7 @@ class CacheMessager:
                 if not self.cache_info:
                     time.sleep(0.001)
                     continue
-                logger.debug(f"prefilled_layer_idx: {prefilled_layer_idx}, prefilled_step_idx: {prefilled_step_idx}")
+                logger.info(f"prefilled_layer_idx: {prefilled_layer_idx}, prefilled_step_idx: {prefilled_step_idx}")
                 for req_id, item in list(self.cache_info.items()):
                     if "status" not in item:
                         continue
@@ -251,7 +301,7 @@ class CacheMessager:
                         target_id = int(item["rdma_ports"][self.rank])
                         status = self.messager[current_transfer_protocol].connect(target_ip, target_id)
                         if not status:
-                            logger.error(f"connect to {target_ip}:{target_id} failed")
+                            logger.info(f"connect to {target_ip}:{target_id} failed")
                             item["status"] = "error"
                             self.engine_worker_queue.finish_request_barrier.wait()
                             if self.rank == 0:
@@ -263,7 +313,7 @@ class CacheMessager:
                     src_block_ids = paddle.to_tensor(item["src_block_ids"], dtype="int32", place="cpu")
                     dest_block_ids = paddle.to_tensor(item["dest_block_ids"], dtype="int32", place="cpu")
                     if item["current_id"] < prefilled_step_idx:
-                        current_layer_idx = self.num_layers
+                        current_layer_idx = self.num_hidden_layers
                     else:
                         current_layer_idx = prefilled_layer_idx + 1

@@ -281,7 +331,7 @@ class CacheMessager:
                                 self.engine_worker_queue.finish_request_barrier.wait()
                                 if self.rank == 0:
                                     self.engine_worker_queue.put_finished_req([(item["request_id"], "write cache error")])
-                                logger.error(
+                                logger.info(
                                     f"write cache failed, layer_idx: {layer_idx}, "
                                     f"req_id: {item['request_id']}, dest_ip: {target_ip}"
                                 )
@@ -292,14 +342,14 @@ class CacheMessager:
                             block_num = len(src_block_ids)
                             avg_time_per_block = cost_time * 1000 / block_num  # ms
                             send_cache_speed = block_num * self.block_bytes / 1073741824 / cost_time  # GB/s
-                            logger.debug(
+                            logger.info(
                                 f"finish write cache for a layer, {item['request_id']}, {layer_idx}"
                                 f" {current_transfer_protocol}"
                                 f"block_num: {block_num}, send_cache_speed(GB/s): {round(send_cache_speed, 5)},"
                                 f"avg_time per block(ms): {round(avg_time_per_block, 5)}"
                             )
                         item["layer_idx"] = current_layer_idx
-                        if item["layer_idx"] == self.num_layers:
+                        if item["layer_idx"] == self.num_hidden_layers:
                             if item["transfer_protocol"] == "ipc":
                                 self.messager["ipc"].write_block_by_sync(target_id)
                                 logger.info(f"finish write cache {item['request_id']}")
@@ -313,7 +363,7 @@ class CacheMessager:
                 self.last_layer_idx = prefilled_layer_idx

             except Exception as e:
-                logger.error(f"prefill layerwise send cache thread has exception: {e}")
+                logger.info(f"prefill layerwise send cache thread has exception: {e}")

     def _handle_connect_task(self):
         while True:
@@ -333,3 +383,90 @@ class CacheMessager:
                 self.engine_worker_queue.put_connect_rdma_task_response(response)
             except Exception as e:
                 logger.error(f"handle_connect_task has exception: {e}")
+
+
+def main():
+    device = args.device_id
+    rank = args.rank
+    paddle.set_device(f"gpu:{device}")
+    cache_type = args.cache_dtype
+    speculative_config = SpeculativeConfig(args.speculative_config)
+    num_extra_layers = speculative_config.num_extra_cache_layer
+    num_extra_layer_gpu_blocks = int(args.num_gpu_blocks * speculative_config.num_gpu_block_expand_ratio)
+    gpu_cache_kvs = {}
+    gpu_cache_k_tensors = []
+    gpu_cache_v_tensors = []
+
+    for i in range(args.num_hidden_layers + num_extra_layers):
+        num_gpu_blocks = args.num_gpu_blocks if i < args.num_hidden_layers else num_extra_layer_gpu_blocks
+
+        gpu_cache_kvs[f"key_caches_{i}_rank{rank}_device{device}"] = paddle.full(
+            shape=[
+                num_gpu_blocks,
+                args.kv_num_head,
+                args.block_size,
+                args.head_dim,
+            ],
+            fill_value=0,
+            dtype=cache_type,
+        )
+        gpu_cache_k_tensors.append(gpu_cache_kvs[f"key_caches_{i}_rank{rank}_device{device}"])
+        gpu_cache_kvs[f"value_caches_{i}_rank{rank}_device{device}"] = paddle.full(
+            shape=[
+                num_gpu_blocks,
+                args.kv_num_head,
+                args.block_size,
+                args.head_dim,
+            ],
+            fill_value=0,
+            dtype=cache_type,
+        )
+        gpu_cache_v_tensors.append(gpu_cache_kvs[f"value_caches_{i}_rank{rank}_device{device}"])
+
+        set_data_ipc(
+            gpu_cache_kvs[f"key_caches_{i}_rank{rank}_device{device}"],
+            f"key_caches_{i}_rank{rank}.device{device}",
+        )
+        set_data_ipc(
+            gpu_cache_kvs[f"value_caches_{i}_rank{rank}_device{device}"],
+            f"value_caches_{i}_rank{rank}.device{device}",
+        )
+    cache_kv_size_byte = sum([tmp.numel() * 1 for key, tmp in gpu_cache_kvs.items()])
+    logger.info(f"device :{device}")
+    logger.info(f"cache_kv_size_byte : {cache_kv_size_byte}")
+    logger.info(f"done init cache (full) gmem alloc : {paddle.device.cuda.memory_allocated()}")
+
+    cache_messager = CacheMessager(
+        splitwise_role=args.splitwise_role,
+        transfer_protocol=args.protocol,
+        pod_ip=args.pod_ip,
+        engine_worker_queue_port=args.engine_worker_queue_port,
+        local_data_parallel_id=args.local_data_parallel_id,
+        gpu_cache_kvs=gpu_cache_kvs,
+        rank=rank,
+        nranks=args.mp_num,
+        num_hidden_layers=args.num_hidden_layers + num_extra_layers,
+        gpu_id=device,
+        rdma_port=args.rdma_port,
+    )
+
+    cache_ready_signal_data = np.zeros(shape=[args.mp_num], dtype=np.int32)
+    cache_ready_signal = IPCSignal(
+        name="cache_ready_signal",
+        array=cache_ready_signal_data,
+        dtype=np.int32,
+        suffix=args.engine_pid,
+        create=False,
+    )
+    cache_ready_signal.value[rank] = 1
+    cache_messager.prefill_layerwise_send_cache_thread()
+
+
+if __name__ == "__main__":
+
+    args = parse_args()
+    logger = get_logger("cache_messager", "cache_messager.log")
+
+    logger.info("create cache messager...")
+    logger.info(f"{args}")
+    main()
@@ -28,7 +28,7 @@ from fastdeploy.config import SpeculativeConfig
 from fastdeploy.inter_communicator import EngineCacheQueue, IPCSignal
 from fastdeploy.model_executor.ops.gpu import (
     cuda_host_alloc,
-    set_data_ipc,
+    share_external_data,
     swap_cache_all_layers,
 )
 from fastdeploy.utils import get_logger
@@ -39,26 +39,12 @@ def parse_args():
     从命令行解析参数
     """
     parser = argparse.ArgumentParser("Cache transfer manager")
-    parser.add_argument(
-        "--splitwise_role",
-        type=str,
-        default="mixed",
-        help="splitwise role, can be decode, prefill or mixed",
-    )
     parser.add_argument("--rank", type=int, default=0, help="current rank")
     parser.add_argument("--device_id", type=int, default=0, help="device id")
-    parser.add_argument("--num_layers", type=int, default=1, help="model num layers")
+    parser.add_argument("--num_hidden_layers", type=int, default=1, help="model num layers")
     parser.add_argument("--head_dim", type=int, default=1, help="model head dim")
     parser.add_argument("--kv_num_head", type=int, default=1, help="model kv num head")
-    parser.add_argument("--rdma_port", type=str, default="", help="rmda port")
     parser.add_argument("--mp_num", type=int, default=1, help="number of model parallel")
-    parser.add_argument(
-        "--protocol",
-        type=str,
-        default="ipc",
-        help="cache transfer protocol, only surport ipc now",
-    )
-    parser.add_argument("--enable_splitwise", type=int, default=0, help="enable splitwise ")
     parser.add_argument("--cache_queue_port", type=int, default=9923, help="cache queue port")
     parser.add_argument("--pod_ip", type=str, default="0.0.0.0", help="pod ip")
     parser.add_argument(
@@ -68,7 +54,6 @@ def parse_args():
         help="engine worker queue port",
     )
     parser.add_argument("--engine_pid", type=str, default=None, help="engine pid")
-
     parser.add_argument("--num_gpu_blocks", type=int, default=1, help="gpu cache block number")
     parser.add_argument("--num_cpu_blocks", type=int, default=4, help="cpu cache block number")
     parser.add_argument("--block_size", type=int, default=64, help="cache block size(tokens)")
@@ -109,7 +94,6 @@ class CacheTransferManager:

         device = args.device_id
         rank = args.rank
-        paddle.set_device(f"gpu:{device}")
         self.gpu_cache_kvs = {}
         self.cpu_cache_kvs = {}
         self.gpu_cache_k_tensors = []
@@ -138,40 +122,27 @@ class CacheTransferManager:
         self.num_cpu_blocks = args.num_cpu_blocks

         cache_type = args.cache_dtype
-        for i in range(args.num_layers + self.num_extra_layers):
-            num_gpu_blocks = args.num_gpu_blocks if i < args.num_layers else self.num_extra_layer_gpu_blocks
-
-            self.gpu_cache_kvs[f"key_caches_{i}_rank{rank}_device{device}"] = paddle.full(
-                shape=[
-                    num_gpu_blocks,
-                    args.kv_num_head,
-                    args.block_size,
-                    args.head_dim,
-                ],
-                fill_value=0,
-                dtype=cache_type,
-            )
-            self.gpu_cache_k_tensors.append(self.gpu_cache_kvs[f"key_caches_{i}_rank{rank}_device{device}"])
-            self.gpu_cache_kvs[f"value_caches_{i}_rank{rank}_device{device}"] = paddle.full(
-                shape=[
-                    num_gpu_blocks,
-                    args.kv_num_head,
-                    args.block_size,
-                    args.head_dim,
-                ],
-                fill_value=0,
-                dtype=cache_type,
-            )
-            self.gpu_cache_v_tensors.append(self.gpu_cache_kvs[f"value_caches_{i}_rank{rank}_device{device}"])
-
-            set_data_ipc(
-                self.gpu_cache_kvs[f"key_caches_{i}_rank{rank}_device{device}"],
-                f"key_caches_{i}_rank{rank}.device{device}",
-            )
-            set_data_ipc(
-                self.gpu_cache_kvs[f"value_caches_{i}_rank{rank}_device{device}"],
-                f"value_caches_{i}_rank{rank}.device{device}",
-            )
+        cache_shape = [
+            args.num_gpu_blocks,
+            args.kv_num_head,
+            args.block_size,
+            args.head_dim,
+        ]
+
+        for i in range(args.num_hidden_layers + self.num_extra_layers):
+            num_gpu_blocks = args.num_gpu_blocks if i < args.num_hidden_layers else self.num_extra_layer_gpu_blocks
+            cache_shape[0] = num_gpu_blocks
+            key_name = f"key_caches_{i}_rank{rank}.device{device}"
+            value_name = f"value_caches_{i}_rank{rank}.device{device}"
+            key_cache = paddle.empty(shape=[], dtype=cache_type)
+            value_cache = paddle.empty(shape=[], dtype=cache_type)
+            key_cache = share_external_data(key_cache, key_name, cache_shape)
+            value_cache = share_external_data(value_cache, value_name, cache_shape)
+            self.gpu_cache_kvs[key_name] = key_cache
+            self.gpu_cache_kvs[value_name] = value_cache
+            self.gpu_cache_k_tensors.append(self.gpu_cache_kvs[key_name])
+            self.gpu_cache_v_tensors.append(self.gpu_cache_kvs[value_name])
+
         cache_kv_size_byte = sum([tmp.numel() * 1 for key, tmp in self.gpu_cache_kvs.items()])
         logger.info(f"device :{self.device}")
         logger.info(f"cache_kv_size_byte : {cache_kv_size_byte}")
@@ -180,7 +151,7 @@ class CacheTransferManager:
         paddle.set_device("cpu")
         self.k_dst_ptrs = []
         self.v_dst_ptrs = []
-        for i in range(args.num_layers + self.num_extra_layers):
+        for i in range(args.num_hidden_layers + self.num_extra_layers):
             self.cpu_cache_kvs[f"key_caches_{i}_rank{rank}"] = cuda_host_alloc(
                 args.num_cpu_blocks * args.bytes_per_layer_per_block
             )
@@ -190,38 +161,6 @@ class CacheTransferManager:
             )
             self.v_dst_ptrs.append(self.cpu_cache_kvs[f"value_caches_{i}_rank{rank}"])

-        cache_ready_signal_data = np.zeros(shape=[args.mp_num], dtype=np.int32)
-        self.cache_ready_signal = IPCSignal(
-            name="cache_ready_signal",
-            array=cache_ready_signal_data,
-            dtype=np.int32,
-            suffix=args.engine_pid,
-            create=False,
-        )
-        self.cache_ready_signal.value[self.rank] = 1
-
-        paddle.set_device(f"gpu:{device}")
-        if args.enable_splitwise:
-            logger.debug("create cache messager...")
-            logger.info(f"{args}")
-            from fastdeploy.cache_manager.cache_messager import CacheMessager
-
-            self.cache_messager = CacheMessager(
-                splitwise_role=args.splitwise_role,
-                transfer_protocol=args.protocol,
-                pod_ip=args.pod_ip,
-                engine_worker_queue_port=args.engine_worker_queue_port,
-                local_data_parallel_id=args.local_data_parallel_id,
-                gpu_cache_kvs=self.gpu_cache_kvs,
-                rank=self.rank,
-                nranks=args.mp_num,
-                num_layers=args.num_layers + self.num_extra_layers,
-                gpu_id=self.device,
-                rdma_port=args.rdma_port,
-            )
-            logger.info("successfully create cache messager")
-        logger.info(f"done init CacheMessager gmem alloc : {paddle.device.cuda.memory_allocated()}")
-
         cache_task_broadcast_data = np.zeros(shape=[1], dtype=np.int32)
         self.cache_task_broadcast_signal = IPCSignal(
             name="cache_task_broadcast_signal",
@@ -141,6 +141,76 @@ class PrefixCacheManager:
         filename = "cache_transfer_manager.py"
         py_path = os.path.join(current_dir_path, filename)

+        cache_messager_processes = []
+        if self.splitwise_role != "mixed":
+            cache_messager_processes = self.launch_cache_messager(
+                cache_config,
+                tensor_parallel_size,
+                device_ids,
+                pod_ip,
+                engine_worker_queue_port,
+                pid_suffix,
+            )
+            if cache_messager_processes is None:
+                raise RuntimeError("Launch cache messager failed")
+                return []
+
+        if (
+            hasattr(cache_config.model_cfg, "num_key_value_heads")
+            and hasattr(cache_config.model_cfg, "num_key_value_heads")
+            and cache_config.model_cfg.num_key_value_heads is not None
+            and int(cache_config.model_cfg.num_key_value_heads) > 0
+        ):
+            kv_num_head = int(cache_config.model_cfg.num_key_value_heads) // tensor_parallel_size
+        else:
+            kv_num_head = cache_config.model_cfg.num_attention_heads // tensor_parallel_size
+
+        log_dir = envs.FD_LOG_DIR
+        cache_manager_processes = []
+        for i in range(tensor_parallel_size):
+            launch_cmd = (
+                f" {sys.executable} {py_path}"
+                + f" --device_id {int(device_ids[i])}"
+                + f" --rank {i}"
+                + f" --num_hidden_layers {cache_config.model_cfg.num_hidden_layers}"
+                + f" --head_dim {cache_config.model_cfg.head_dim}"
+                + f" --kv_num_head {kv_num_head}"
+                + f" --mp_num {tensor_parallel_size}"
+                + f" --cache_dtype {cache_config.cache_dtype}"
+                + f" --cache_queue_port {cache_config.cache_queue_port}"
+                + f" --pod_ip {pod_ip}"
+                + f" --engine_worker_queue_port {engine_worker_queue_port}"
+                + f" --num_gpu_blocks {cache_config.total_block_num}"
+                + f" --num_cpu_blocks {cache_config.num_cpu_blocks}"
+                + f" --bytes_per_layer_per_block {cache_config.bytes_per_layer_per_block}"
+                + f" --block_size {cache_config.block_size}"
+                + f" --engine_pid {pid_suffix}"
+                + f" --local_data_parallel_id {self.local_data_parallel_id}"
+                + f" --speculative_config '{self.speculative_config.to_json_string()}'"
+                + f" >{log_dir}/launch_cache_manager_{int(device_ids[i])}.log 2>&1"
+            )
+            logger.info(f"Launch cache transfer manager, command:{launch_cmd}")
+            cache_manager_processes.append(subprocess.Popen(launch_cmd, shell=True, preexec_fn=os.setsid))
+            exit_code = cache_manager_processes[-1].poll()
+            if exit_code is None:
+                logger.info("Launch cache transfer manager successful")
+            else:
+                logger.info("Launch cache transfer manager failed, see launch_cache_manager.log for more information")
+
+        if cache_config.enable_hierarchical_cache and self.num_cpu_blocks > 0:
+            logger.info("Enable hierarchical cache.")
+            self._enable_cpu_cache()
+        cache_manager_processes.extend(cache_messager_processes)
+        return cache_manager_processes
+
+    def launch_cache_messager(
+        self, cache_config, tensor_parallel_size, device_ids, pod_ip, engine_worker_queue_port, pid_suffix
+    ):
+        """
+        launch_cache_messager function used to initialize the cache messager.
+        """
+        current_dir_path = os.path.split(os.path.abspath(__file__))[0]
+        filename = "cache_messager.py"
         if (
             hasattr(cache_config.model_cfg, "num_key_value_heads")
             and hasattr(cache_config.model_cfg, "num_key_value_heads")
@@ -159,8 +229,10 @@ class PrefixCacheManager:
             suffix=pid_suffix,
             create=True,
         )

+        py_path = os.path.join(current_dir_path, filename)
         log_dir = envs.FD_LOG_DIR
-        cache_manager_processes = []
+        cache_messager_processes = []
         for i in range(tensor_parallel_size):
             launch_cmd = (
                 "FLAGS_allocator_strategy=auto_growth CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7"
@@ -169,42 +241,34 @@ class PrefixCacheManager:
                 + f" --device_id {int(device_ids[i])}"
                 + f" --rank {i}"
                 + f" --splitwise_role {self.splitwise_role}"
-                + f" --num_layers {cache_config.model_cfg.num_hidden_layers}"
+                + f" --num_hidden_layers {cache_config.model_cfg.num_hidden_layers}"
                 + f" --head_dim {cache_config.model_cfg.head_dim}"
                 + f" --kv_num_head {kv_num_head}"
                 + f" --mp_num {tensor_parallel_size}"
                 + f" --cache_dtype {cache_config.cache_dtype}"
-                + f" --cache_queue_port {cache_config.cache_queue_port}"
-                + f" --enable_splitwise {int(self.enable_splitwise)}"
                 + f" --pod_ip {pod_ip}"
                 + f" --engine_worker_queue_port {engine_worker_queue_port}"
                 + f" --num_gpu_blocks {cache_config.total_block_num}"
-                + f" --num_cpu_blocks {cache_config.num_cpu_blocks}"
-                + f" --bytes_per_layer_per_block {cache_config.bytes_per_layer_per_block}"
                 + f" --block_size {cache_config.block_size}"
-                + f" --engine_pid {pid_suffix}"
                 + f" --protocol {cache_config.cache_transfer_protocol}"
                 + f" --local_data_parallel_id {self.local_data_parallel_id}"
+                + f" --engine_pid {pid_suffix}"
                 + f" --rdma_port {cache_config.rdma_comm_ports[i] if cache_config.rdma_comm_ports is not None else '0'}"
                 + f" --speculative_config '{self.speculative_config.to_json_string()}'"
-                + f" >{log_dir}/launch_cache_manager_{int(device_ids[i])}.log 2>&1"
+                + f" >{log_dir}/launch_cache_messager_{int(device_ids[i])}.log 2>&1"
             )
-            logger.info(f"Launch cache transfer manager, command:{launch_cmd}")
-            cache_manager_processes.append(subprocess.Popen(launch_cmd, shell=True, preexec_fn=os.setsid))
-        # 等待cache初始化完毕
-        logger.info("Waiting for cache transfer manager ready...")
+            logger.info(f"Launch cache messager, command:{launch_cmd}")
+            cache_messager_processes.append(subprocess.Popen(launch_cmd, shell=True, preexec_fn=os.setsid))
+        logger.info("Waiting for cache ready...")
         while np.sum(self.cache_ready_signal.value) != tensor_parallel_size:
             time.sleep(1)
-        exit_code = cache_manager_processes[-1].poll()
+        exit_code = cache_messager_processes[-1].poll()
         if exit_code is None:
-            logger.info("Launch cache transfer manager successful")
+            logger.info("Launch cache messager successful")
         else:
-            logger.info("Launch cache transfer manager failed, see launch_cache_manager.log for more information")
-
-        if cache_config.enable_hierarchical_cache and self.num_cpu_blocks > 0:
-            logger.info("Enable hierarchical cache.")
-            self._enable_cpu_cache()
-        return cache_manager_processes
+            logger.info("Launch cache messager failed, see launch_cache_messager.log for more information")
+            cache_messager_processes = None
+        return cache_messager_processes

     def update_cache_config(self, cache_config):
         """
@@ -775,10 +775,6 @@ class LLMEngine:
         """
        Insert tasks to engine.
         """
-        for task in tasks:
-            start_span_request("DEQUEUE", task, trace.SpanKind.CONSUMER)
-            if task.sampling_params.bad_words is not None:
-                task.sampling_params.update_from_tokenizer(self.data_processor.tokenizer)
         # TODO 返回至 scheduler
         if allocated:
             current_tasks = []
@@ -805,6 +801,11 @@ class LLMEngine:
             self.engine_worker_queue.put_tasks((current_tasks, self.resource_manager.real_bsz))
             return True

+        for task in tasks:
+            start_span_request("DEQUEUE", task, trace.SpanKind.CONSUMER)
+            if task.sampling_params.bad_words is not None:
+                task.sampling_params.update_from_tokenizer(self.data_processor.tokenizer)
+
         self.resource_manager.check_and_free_block_tables()

         if not isinstance(tasks, list):
@@ -846,7 +847,6 @@ class LLMEngine:
         llm_logger.info(f"Tasks are sent to engine, req_ids={req_ids}")
         for task in tasks:
             task.inference_start_time = time.time()
-            if not is_prefill:
                 if not self.cfg.enable_mm:
                     self.update_requests_chunk_size(tasks)
                 else:
@@ -992,14 +992,17 @@ class LLMEngine:
         self.running = False

         if hasattr(self, "cache_manager_processes"):
-            self.resource_manager.cache_manager.shm_cache_task_flag_broadcast.clear()
-            self.resource_manager.cache_manager.cache_ready_signal.clear()
             for p in self.cache_manager_processes:
                 llm_logger.info(f"Killing cache manager process {p.pid}")
                 try:
                     os.killpg(p.pid, signal.SIGTERM)
                 except Exception as e:
                     print(f"Error extracting file: {e}")
+            if hasattr(self.resource_manager.cache_manager, "cache_ready_signal"):
+                self.resource_manager.cache_manager.cache_ready_signal.clear()
+                self.resource_manager.cache_manager.shm_cache_task_flag_broadcast.clear()
+        if hasattr(self, "zmq_server") and self.zmq_server is not None:
+            self.zmq_server.close()
         self.worker_ready_signal.clear()
         self.exist_task_signal.clear()
         self.exist_swapped_task_signal.clear()
@@ -1024,6 +1027,7 @@ class LLMEngine:
         if hasattr(self, "dp_processed"):
             for p in self.dp_processed:
                 p.join()
+        self.engine_worker_queue_server.cleanup()

     def _setting_environ_variables(self):
         """
@@ -37,6 +37,7 @@ from fastdeploy.model_executor.ops.gpu import (
     eagle_get_self_hidden_states,
     mtp_save_first_token,
     mtp_step_paddle,
+    set_data_ipc,
     share_external_data,
 )
 from fastdeploy.model_executor.pre_and_post_process import pre_process, rebuild_padding
@@ -141,9 +142,7 @@ class MTPProposer(Proposer):
         kv_cache_shape = self.attn_backends[0].get_kv_cache_shape(
             max_num_blocks=self.num_gpu_blocks, kv_cache_quant_type=kv_cache_quant_type
         )
-        if not self.parallel_config.do_profile and (
-            self.cache_config.enable_prefix_caching or self.parallel_config.splitwise_role != "mixed"
-        ):
+        if not self.parallel_config.do_profile and self.parallel_config.splitwise_role != "mixed":
             cache_kvs_list = []
             for i in range(
                 self.num_main_model_layers,
@@ -160,7 +159,10 @@ class MTPProposer(Proposer):

             self.model_inputs["caches"] = cache_kvs_list
         else:
-            for i in range(self.model_config.num_hidden_layers):
+            for i in range(
+                self.num_main_model_layers,
+                self.num_main_model_layers + self.model_config.num_hidden_layers,
+            ):
                 self.cache_kvs[f"key_caches_{i}"] = paddle.full(
                     shape=kv_cache_shape,
                     fill_value=0,
@@ -171,6 +173,15 @@ class MTPProposer(Proposer):
                     fill_value=0,
                     dtype=cache_type,
                 )
+                if self.cache_config.enable_prefix_caching:
+                    set_data_ipc(
+                        self.cache_kvs[f"key_caches_{i}"],
+                        f"key_caches_{i}_rank{self.local_rank}.device{self.device_id}",
+                    )
+                    set_data_ipc(
+                        self.cache_kvs[f"value_caches_{i}"],
+                        f"value_caches_{i}_rank{self.local_rank}.device{self.device_id}",
+                    )
             self.model_inputs["caches"] = list(self.cache_kvs.values())
             for value in self.cache_kvs.values():
                 del value
@@ -235,7 +246,7 @@ class MTPProposer(Proposer):

         self.main_model_num_gpu_blocks = num_gpu_blocks
         self.num_gpu_blocks = int(num_gpu_blocks * self.speculative_config.num_gpu_block_expand_ratio)
-        if not (self.cache_config.enable_prefix_caching or self.parallel_config.splitwise_role != "mixed"):
+        if self.parallel_config.splitwise_role == "mixed":
             self.initialize_kv_cache()

         # Reset free list
@@ -43,6 +43,7 @@ from fastdeploy.model_executor.layers.sample.sampler import Sampler, Speculative
 from fastdeploy.model_executor.model_loader import get_model_loader
 from fastdeploy.model_executor.ops.gpu import (
     recover_decode_task,
+    set_data_ipc,
     set_value_by_flags_and_idx,
     share_external_data,
 )
@@ -904,7 +905,7 @@ class GPUModelRunner(ModelRunnerBase):
         )
         local_rank = self.local_rank % self.parallel_config.tensor_parallel_size

-        if not profile and (self.cache_config.enable_prefix_caching or self.parallel_config.splitwise_role != "mixed"):
+        if not profile and self.parallel_config.splitwise_role != "mixed":
             cache_kvs_list = []
             for i in range(self.model_config.num_hidden_layers):
                 key_cache = paddle.empty(shape=[], dtype=cache_type)
@@ -930,6 +931,15 @@ class GPUModelRunner(ModelRunnerBase):
                     fill_value=0,
                     dtype=cache_type,
                 )
+                if self.cache_config.enable_prefix_caching:
+                    set_data_ipc(
+                        cache_kvs[f"key_caches_{i}"],
+                        f"key_caches_{i}_rank{local_rank}.device{self.device_id}",
+                    )
+                    set_data_ipc(
+                        cache_kvs[f"value_caches_{i}"],
+                        f"value_caches_{i}_rank{local_rank}.device{self.device_id}",
+                    )
             self.share_inputs["caches"] = list(cache_kvs.values())
             for value in cache_kvs.values():
                 del value
@@ -1138,6 +1148,8 @@ class GPUModelRunner(ModelRunnerBase):
                 if task.chunk_idx > len(task.prefill_chunk_info):
                     continue
                 self.restore_chunked_prefill_request[task.request_id] = task
+        if len(self.restore_chunked_prefill_request) > 0:
+            self.share_inputs["not_need_stop"][0] = True

         for id, task in list(self.restore_chunked_prefill_request.items()):
             idx = task.idx
@@ -1182,7 +1194,7 @@ class GPUModelRunner(ModelRunnerBase):
                 self.share_inputs["seq_lens_encoder"][idx : idx + 1] = token_chunk_size
                 self.share_inputs["prompt_lens"][idx : idx + 1] += token_chunk_size
                 self.share_inputs["step_idx"][idx : idx + 1] = 0
-
+                self.share_inputs["stop_flags"][idx : idx + 1] = False
             if self.speculative_decoding and self.proposer.is_chunk_prefill_enabled():
                 self.proposer.update_task_chunk_prefill(task)
             task.chunk_idx += 1
@@ -1507,12 +1519,12 @@ class GPUModelRunner(ModelRunnerBase):

         hidden_dim = self.model_config.head_dim * self.model_config.kv_num_heads
         # NOTE(liuzichang): Implement multi-layer MTP architecture in the future
-        num_layers = (
+        num_hidden_layers = (
             self.model_config.num_hidden_layers + self.speculative_config.num_gpu_block_expand_ratio
             if self.speculative_method in ["mtp"]
             else self.model_config.num_hidden_layers
         )
-        required_memory = byte_of_dtype * 2 * (self.cache_config.block_size * hidden_dim) * num_layers  # k + v
+        required_memory = byte_of_dtype * 2 * (self.cache_config.block_size * hidden_dim) * num_hidden_layers  # k + v
         return required_memory

     def not_need_stop(self) -> bool:
@@ -408,7 +408,7 @@ class PaddleDisWorkerProc:

         logger.info(f"------- num_blocks_global: {num_blocks_local} --------")
         # wait engine launch cache_manager
-        if self.cache_config.enable_prefix_caching or self.parallel_config.splitwise_role != "mixed":
+        if self.parallel_config.splitwise_role != "mixed":
             launched_cache_manager_signal_data = np.zeros([1], dtype=np.int32)
             self.launched_cache_manager_signal = IPCSignal(
                 name="launched_cache_manager_signal",