mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-04 08:16:42 +08:00
Compare commits
16 Commits
develop
...
feature/on
Author | SHA1 | Date | |
---|---|---|---|
![]() |
7b09611d6b | ||
![]() |
606d9e9c2c | ||
![]() |
d18a637a17 | ||
![]() |
6854506533 | ||
![]() |
c487b62ee0 | ||
![]() |
d2f6c3b998 | ||
![]() |
aba94169dc | ||
![]() |
3f86ae0007 | ||
![]() |
89177d881c | ||
![]() |
7573802a88 | ||
![]() |
110f33a530 | ||
![]() |
a4572a5e5d | ||
![]() |
a9d231c900 | ||
![]() |
b20ffe3697 | ||
![]() |
dcf9c2daff | ||
![]() |
9f9971844f |
@@ -14,6 +14,8 @@
|
||||
# limitations under the License.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import math
|
||||
import threading
|
||||
import time
|
||||
@@ -22,10 +24,63 @@ import numpy as np
|
||||
import paddle
|
||||
|
||||
from fastdeploy.cache_manager.transfer_factory import IPCCommManager, RDMACommManager
|
||||
from fastdeploy.config import SpeculativeConfig
|
||||
from fastdeploy.inter_communicator import EngineWorkerQueue, IPCSignal
|
||||
from fastdeploy.model_executor.ops.gpu import set_data_ipc
|
||||
from fastdeploy.utils import get_logger
|
||||
|
||||
logger = get_logger("cache_messager", "cache_messager.log")
|
||||
|
||||
def parse_args():
|
||||
"""
|
||||
从命令行解析参数
|
||||
"""
|
||||
parser = argparse.ArgumentParser("Cache Messager")
|
||||
parser.add_argument(
|
||||
"--splitwise_role",
|
||||
type=str,
|
||||
default="mixed",
|
||||
help="splitwise role, can be decode, prefill or mixed",
|
||||
)
|
||||
parser.add_argument("--rank", type=int, default=0, help="current rank")
|
||||
parser.add_argument("--device_id", type=int, default=0, help="device id")
|
||||
parser.add_argument("--num_hidden_layers", type=int, default=1, help="model num layers")
|
||||
parser.add_argument("--head_dim", type=int, default=1, help="model head dim")
|
||||
parser.add_argument("--kv_num_head", type=int, default=1, help="model kv num head")
|
||||
parser.add_argument("--rdma_port", type=str, default="", help="rmda port")
|
||||
parser.add_argument("--mp_num", type=int, default=1, help="number of model parallel")
|
||||
parser.add_argument("--engine_pid", type=str, default=None, help="engine pid")
|
||||
parser.add_argument(
|
||||
"--protocol",
|
||||
type=str,
|
||||
default="ipc",
|
||||
help="cache transfer protocol, only surport ipc now",
|
||||
)
|
||||
parser.add_argument("--pod_ip", type=str, default="0.0.0.0", help="pod ip")
|
||||
parser.add_argument(
|
||||
"--engine_worker_queue_port",
|
||||
type=int,
|
||||
default=9923,
|
||||
help="engine worker queue port",
|
||||
)
|
||||
parser.add_argument("--num_gpu_blocks", type=int, default=1, help="gpu cache block number")
|
||||
parser.add_argument("--block_size", type=int, default=64, help="cache block size(tokens)")
|
||||
parser.add_argument(
|
||||
"--cache_dtype",
|
||||
type=str,
|
||||
default="bfloat16",
|
||||
choices=["uint8", "bfloat16"],
|
||||
help="cache dtype",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--speculative_config",
|
||||
type=json.loads,
|
||||
default="{}",
|
||||
help="speculative config",
|
||||
)
|
||||
parser.add_argument("--local_data_parallel_id", type=int, default=0)
|
||||
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
|
||||
class CacheMessager:
|
||||
@@ -43,7 +98,7 @@ class CacheMessager:
|
||||
gpu_cache_kvs,
|
||||
rank,
|
||||
nranks,
|
||||
num_layers,
|
||||
num_hidden_layers,
|
||||
gpu_id=0,
|
||||
rdma_port=None,
|
||||
):
|
||||
@@ -57,7 +112,7 @@ class CacheMessager:
|
||||
gpu_cache_kvs (dict): GPU kv cache
|
||||
rank (int): current rank
|
||||
nranks (int): global rank number
|
||||
num_layers (int): model layer number
|
||||
num_hidden_layers (int): model layer number
|
||||
gpu_id (int, optional): GPU ID
|
||||
rdma_port (int, optional): RDMA port
|
||||
|
||||
@@ -73,7 +128,7 @@ class CacheMessager:
|
||||
self.gpu_cache_kvs = gpu_cache_kvs
|
||||
self.rank = rank
|
||||
self.nranks = nranks
|
||||
address = (pod_ip, engine_worker_queue_port)
|
||||
address = (pod_ip, engine_worker_queue_port + local_data_parallel_id)
|
||||
self.engine_worker_queue = EngineWorkerQueue(
|
||||
address=address,
|
||||
is_server=False,
|
||||
@@ -86,13 +141,13 @@ class CacheMessager:
|
||||
logger.info(f"splitwise role: {splitwise_role}, {transfer_protocol}" f"rank: {rank}")
|
||||
|
||||
# 1. initialize the cache_k_ptr_list and cache_v_ptr_list
|
||||
self.num_layers = num_layers
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
cache_k_ptr_list = []
|
||||
cache_v_ptr_list = []
|
||||
cache_k = []
|
||||
cache_v = []
|
||||
self.messager = {}
|
||||
for layer_idx in range(self.num_layers):
|
||||
for layer_idx in range(self.num_hidden_layers):
|
||||
key_cache = self.gpu_cache_kvs[f"key_caches_{layer_idx}_rank{self.rank}_device{gpu_id}"]
|
||||
val_cache = self.gpu_cache_kvs[f"value_caches_{layer_idx}_rank{self.rank}_device{gpu_id}"]
|
||||
cache_k.append(key_cache)
|
||||
@@ -109,7 +164,7 @@ class CacheMessager:
|
||||
if key_cache.dtype == paddle.bfloat16:
|
||||
block_bytes *= 2
|
||||
logger.info(
|
||||
f"layers {num_layers} cache_shape: {cache_shape}, max_block_num: {max_block_num}, "
|
||||
f"layers {num_hidden_layers} cache_shape: {cache_shape}, max_block_num: {max_block_num}, "
|
||||
f"block_bytes: {block_bytes}, dtype: {key_cache.dtype}"
|
||||
)
|
||||
self.block_bytes = block_bytes
|
||||
@@ -142,15 +197,17 @@ class CacheMessager:
|
||||
|
||||
self.gpu_id = gpu_id
|
||||
self.cache_info = dict()
|
||||
self.dp_rank_id = self.rank + local_data_parallel_id * self.nranks
|
||||
self.rank_id = (
|
||||
self.rank + local_data_parallel_id * self.nranks
|
||||
) # align with engine worker rank (paddle.distributed.launch)
|
||||
|
||||
layerwise_send_cache_thread = threading.Thread(target=self._prefill_layerwise_send_cache_thread)
|
||||
layerwise_send_cache_thread.daemon = True
|
||||
layerwise_send_cache_thread.start()
|
||||
connect_rdma_thread = threading.Thread(target=self._handle_connect_task)
|
||||
connect_rdma_thread.daemon = True
|
||||
connect_rdma_thread.start()
|
||||
|
||||
logger.info(f"cache messager init finished, use {transfer_protocol}")
|
||||
|
||||
def _prefill_layerwise_send_cache_thread(self):
|
||||
def prefill_layerwise_send_cache_thread(self):
|
||||
"""
|
||||
layerwise_send_cache_thread:
|
||||
send cache to other instance
|
||||
@@ -160,14 +217,14 @@ class CacheMessager:
|
||||
prefilled_layer_idx_data = np.zeros(shape=[1], dtype=np.int32)
|
||||
try:
|
||||
step_shm_value = IPCSignal(
|
||||
name=f"splitwise_complete_prefilled_step_{self.dp_rank_id}",
|
||||
name=f"splitwise_complete_prefilled_step_{self.rank_id}",
|
||||
array=prefilled_step_idx_data,
|
||||
dtype=np.int32,
|
||||
suffix=self.gpu_id,
|
||||
create=True,
|
||||
)
|
||||
layer_shm_value = IPCSignal(
|
||||
name=f"splitwise_complete_prefilled_layer_{self.dp_rank_id}",
|
||||
name=f"splitwise_complete_prefilled_layer_{self.rank_id}",
|
||||
array=prefilled_layer_idx_data,
|
||||
dtype=np.int32,
|
||||
suffix=self.gpu_id,
|
||||
@@ -175,14 +232,14 @@ class CacheMessager:
|
||||
)
|
||||
except:
|
||||
step_shm_value = IPCSignal(
|
||||
name=f"splitwise_complete_prefilled_step_{self.dp_rank_id}",
|
||||
name=f"splitwise_complete_prefilled_step_{self.rank_id}",
|
||||
array=prefilled_step_idx_data,
|
||||
dtype=np.int32,
|
||||
suffix=self.gpu_id,
|
||||
create=False,
|
||||
)
|
||||
layer_shm_value = IPCSignal(
|
||||
name=f"splitwise_complete_prefilled_layer_{self.dp_rank_id}",
|
||||
name=f"splitwise_complete_prefilled_layer_{self.rank_id}",
|
||||
array=prefilled_layer_idx_data,
|
||||
dtype=np.int32,
|
||||
suffix=self.gpu_id,
|
||||
@@ -195,12 +252,15 @@ class CacheMessager:
|
||||
self.last_step_idx = -1
|
||||
self.last_layer_idx = -1 # int32
|
||||
|
||||
max_step_idx = 100003
|
||||
engine_recycled_count = 0
|
||||
|
||||
while True:
|
||||
|
||||
cache_info = self.engine_worker_queue.get_cache_info()
|
||||
|
||||
if cache_info:
|
||||
logger.debug(f"cache info {cache_info}")
|
||||
logger.info(f"cache info {cache_info}")
|
||||
for info in cache_info:
|
||||
if info["request_id"] in self.cache_info:
|
||||
self.cache_info[info["request_id"]].update(info)
|
||||
@@ -214,12 +274,11 @@ class CacheMessager:
|
||||
current_info["status"] = "init"
|
||||
logger.info(f"start cache_infos: {current_info}")
|
||||
self.cache_info[info["request_id"]] = current_info
|
||||
self.last_step_idx = min(self.last_step_idx, current_info["current_id"])
|
||||
else:
|
||||
self.cache_info[info["request_id"]] = info
|
||||
prefilled_layer_idx = layer_shm_value.value[0]
|
||||
prefilled_step_idx = step_shm_value.value[0]
|
||||
if prefilled_layer_idx == self.num_layers - 1:
|
||||
if prefilled_layer_idx == self.num_hidden_layers - 1:
|
||||
time.sleep(0.001)
|
||||
prefilled_layer_idx = layer_shm_value.value[0]
|
||||
prefilled_step_idx = step_shm_value.value[0]
|
||||
@@ -230,7 +289,18 @@ class CacheMessager:
|
||||
if not self.cache_info:
|
||||
time.sleep(0.001)
|
||||
continue
|
||||
logger.debug(f"prefilled_layer_idx: {prefilled_layer_idx}, prefilled_step_idx: {prefilled_step_idx}")
|
||||
if self.last_step_idx > prefilled_step_idx:
|
||||
engine_recycled_count += 1
|
||||
self.last_step_idx = prefilled_step_idx # only copy value read from shm memory
|
||||
prefilled_step_idx = (
|
||||
prefilled_step_idx + max_step_idx * engine_recycled_count
|
||||
) # remap prefilled_step_idx for comparison
|
||||
|
||||
logger.debug(
|
||||
f"prefilled_layer_idx: {prefilled_layer_idx}, prefilled_step_idx in shm: {self.last_step_idx},"
|
||||
f"prefilled_step_idx: {prefilled_step_idx} engine_recycled_count {engine_recycled_count}"
|
||||
)
|
||||
|
||||
for req_id, item in list(self.cache_info.items()):
|
||||
if "status" not in item:
|
||||
continue
|
||||
@@ -247,7 +317,7 @@ class CacheMessager:
|
||||
target_id = int(item["rdma_ports"][self.rank])
|
||||
status = self.messager[current_transfer_protocol].connect(target_ip, target_id)
|
||||
if not status:
|
||||
logger.error(f"connect to {target_ip}:{target_id} failed")
|
||||
logger.info(f"connect to {target_ip}:{target_id} failed")
|
||||
item["status"] = "error"
|
||||
self.engine_worker_queue.finish_request_barrier.wait()
|
||||
if self.rank == 0:
|
||||
@@ -259,9 +329,10 @@ class CacheMessager:
|
||||
src_block_ids = paddle.to_tensor(item["src_block_ids"], dtype="int32", place="cpu")
|
||||
dest_block_ids = paddle.to_tensor(item["dest_block_ids"], dtype="int32", place="cpu")
|
||||
if item["current_id"] < prefilled_step_idx:
|
||||
current_layer_idx = self.num_layers
|
||||
current_layer_idx = self.num_hidden_layers
|
||||
else:
|
||||
current_layer_idx = prefilled_layer_idx + 1
|
||||
if item["current_id"] == prefilled_step_idx:
|
||||
current_layer_idx = prefilled_layer_idx + 1
|
||||
|
||||
for layer_idx in range(item["layer_idx"], current_layer_idx):
|
||||
tic = time.time()
|
||||
@@ -277,7 +348,7 @@ class CacheMessager:
|
||||
self.engine_worker_queue.finish_request_barrier.wait()
|
||||
if self.rank == 0:
|
||||
self.engine_worker_queue.put_finished_req([(item["request_id"], "write cache error")])
|
||||
logger.error(
|
||||
logger.info(
|
||||
f"write cache failed, layer_idx: {layer_idx}, "
|
||||
f"req_id: {item['request_id']}, dest_ip: {target_ip}"
|
||||
)
|
||||
@@ -288,14 +359,14 @@ class CacheMessager:
|
||||
block_num = len(src_block_ids)
|
||||
avg_time_per_block = cost_time * 1000 / block_num # ms
|
||||
send_cache_speed = block_num * self.block_bytes / 1073741824 / cost_time # GB/s
|
||||
logger.debug(
|
||||
logger.info(
|
||||
f"finish write cache for a layer, {item['request_id']}, {layer_idx}"
|
||||
f" {current_transfer_protocol}"
|
||||
f"block_num: {block_num}, send_cache_speed(GB/s): {round(send_cache_speed, 5)},"
|
||||
f"avg_time per block(ms): {round(avg_time_per_block, 5)}"
|
||||
)
|
||||
item["layer_idx"] = current_layer_idx
|
||||
if item["layer_idx"] == self.num_layers:
|
||||
if item["layer_idx"] == self.num_hidden_layers:
|
||||
if item["transfer_protocol"] == "ipc":
|
||||
self.messager["ipc"].write_block_by_sync(target_id)
|
||||
logger.info(f"finish write cache {item['request_id']}")
|
||||
@@ -304,9 +375,114 @@ class CacheMessager:
|
||||
self.engine_worker_queue.put_finished_req([(item["request_id"], "finished")])
|
||||
logger.info(f"put write cache {item['request_id']}")
|
||||
del self.cache_info[req_id]
|
||||
|
||||
self.last_step_idx = prefilled_step_idx
|
||||
self.last_layer_idx = prefilled_layer_idx
|
||||
self.last_layer_idx = prefilled_layer_idx
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"prefill layerwise send cache thread has exception: {e}")
|
||||
logger.info(f"prefill layerwise send cache thread has exception: {e}")
|
||||
|
||||
def _handle_connect_task(self):
|
||||
while True:
|
||||
try:
|
||||
task = self.engine_worker_queue.get_connect_rdma_task()
|
||||
if task is None:
|
||||
time.sleep(0.001)
|
||||
continue
|
||||
logger.info(f"_handle_connect_task recv task: {task}")
|
||||
task_id = task["task_id"]
|
||||
ip, rdma_port = task["ip"], task["rdma_port"]
|
||||
status = self.messager["rdma"].connect(ip, rdma_port)
|
||||
if not status:
|
||||
response = {"task_id": task_id, "success": False}
|
||||
else:
|
||||
response = {"task_id": task_id, "success": True}
|
||||
self.engine_worker_queue.put_connect_rdma_task_response(response)
|
||||
except Exception as e:
|
||||
logger.error(f"handle_connect_task has exception: {e}")
|
||||
|
||||
|
||||
def main():
|
||||
device = args.device_id
|
||||
rank = args.rank
|
||||
paddle.set_device(f"gpu:{device}")
|
||||
cache_type = args.cache_dtype
|
||||
speculative_config = SpeculativeConfig(args.speculative_config)
|
||||
num_extra_layers = speculative_config.num_extra_cache_layer
|
||||
num_extra_layer_gpu_blocks = int(args.num_gpu_blocks * speculative_config.num_gpu_block_expand_ratio)
|
||||
gpu_cache_kvs = {}
|
||||
gpu_cache_k_tensors = []
|
||||
gpu_cache_v_tensors = []
|
||||
|
||||
for i in range(args.num_hidden_layers + num_extra_layers):
|
||||
num_gpu_blocks = args.num_gpu_blocks if i < args.num_hidden_layers else num_extra_layer_gpu_blocks
|
||||
|
||||
gpu_cache_kvs[f"key_caches_{i}_rank{rank}_device{device}"] = paddle.full(
|
||||
shape=[
|
||||
num_gpu_blocks,
|
||||
args.kv_num_head,
|
||||
args.block_size,
|
||||
args.head_dim,
|
||||
],
|
||||
fill_value=0,
|
||||
dtype=cache_type,
|
||||
)
|
||||
gpu_cache_k_tensors.append(gpu_cache_kvs[f"key_caches_{i}_rank{rank}_device{device}"])
|
||||
gpu_cache_kvs[f"value_caches_{i}_rank{rank}_device{device}"] = paddle.full(
|
||||
shape=[
|
||||
num_gpu_blocks,
|
||||
args.kv_num_head,
|
||||
args.block_size,
|
||||
args.head_dim,
|
||||
],
|
||||
fill_value=0,
|
||||
dtype=cache_type,
|
||||
)
|
||||
gpu_cache_v_tensors.append(gpu_cache_kvs[f"value_caches_{i}_rank{rank}_device{device}"])
|
||||
|
||||
set_data_ipc(
|
||||
gpu_cache_kvs[f"key_caches_{i}_rank{rank}_device{device}"],
|
||||
f"key_caches_{i}_rank{rank}.device{device}",
|
||||
)
|
||||
set_data_ipc(
|
||||
gpu_cache_kvs[f"value_caches_{i}_rank{rank}_device{device}"],
|
||||
f"value_caches_{i}_rank{rank}.device{device}",
|
||||
)
|
||||
cache_kv_size_byte = sum([tmp.numel() * 1 for key, tmp in gpu_cache_kvs.items()])
|
||||
logger.info(f"device :{device}")
|
||||
logger.info(f"cache_kv_size_byte : {cache_kv_size_byte}")
|
||||
logger.info(f"done init cache (full) gmem alloc : {paddle.device.cuda.memory_allocated()}")
|
||||
|
||||
cache_messager = CacheMessager(
|
||||
splitwise_role=args.splitwise_role,
|
||||
transfer_protocol=args.protocol,
|
||||
pod_ip=args.pod_ip,
|
||||
engine_worker_queue_port=args.engine_worker_queue_port,
|
||||
local_data_parallel_id=args.local_data_parallel_id,
|
||||
gpu_cache_kvs=gpu_cache_kvs,
|
||||
rank=rank,
|
||||
nranks=args.mp_num,
|
||||
num_hidden_layers=args.num_hidden_layers + num_extra_layers,
|
||||
gpu_id=device,
|
||||
rdma_port=args.rdma_port,
|
||||
)
|
||||
|
||||
cache_ready_signal_data = np.zeros(shape=[args.mp_num], dtype=np.int32)
|
||||
cache_ready_signal = IPCSignal(
|
||||
name="cache_ready_signal",
|
||||
array=cache_ready_signal_data,
|
||||
dtype=np.int32,
|
||||
suffix=args.engine_pid,
|
||||
create=False,
|
||||
)
|
||||
cache_ready_signal.value[rank] = 1
|
||||
cache_messager.prefill_layerwise_send_cache_thread()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
args = parse_args()
|
||||
rank_id = args.rank + args.local_data_parallel_id * args.mp_num
|
||||
logger = get_logger("cache_messager", f"cache_messager_rank{rank_id}.log")
|
||||
|
||||
logger.info("create cache messager...")
|
||||
logger.info(f"{args}")
|
||||
main()
|
||||
|
@@ -28,7 +28,7 @@ from fastdeploy.config import SpeculativeConfig
|
||||
from fastdeploy.inter_communicator import EngineCacheQueue, IPCSignal
|
||||
from fastdeploy.model_executor.ops.gpu import (
|
||||
cuda_host_alloc,
|
||||
set_data_ipc,
|
||||
share_external_data,
|
||||
swap_cache_all_layers,
|
||||
)
|
||||
from fastdeploy.utils import get_logger
|
||||
@@ -39,26 +39,12 @@ def parse_args():
|
||||
从命令行解析参数
|
||||
"""
|
||||
parser = argparse.ArgumentParser("Cache transfer manager")
|
||||
parser.add_argument(
|
||||
"--splitwise_role",
|
||||
type=str,
|
||||
default="mixed",
|
||||
help="splitwise role, can be decode, prefill or mixed",
|
||||
)
|
||||
parser.add_argument("--rank", type=int, default=0, help="current rank")
|
||||
parser.add_argument("--device_id", type=int, default=0, help="device id")
|
||||
parser.add_argument("--num_layers", type=int, default=1, help="model num layers")
|
||||
parser.add_argument("--num_hidden_layers", type=int, default=1, help="model num layers")
|
||||
parser.add_argument("--head_dim", type=int, default=1, help="model head dim")
|
||||
parser.add_argument("--kv_num_head", type=int, default=1, help="model kv num head")
|
||||
parser.add_argument("--rdma_port", type=str, default="", help="rmda port")
|
||||
parser.add_argument("--mp_num", type=int, default=1, help="number of model parallel")
|
||||
parser.add_argument(
|
||||
"--protocol",
|
||||
type=str,
|
||||
default="ipc",
|
||||
help="cache transfer protocol, only surport ipc now",
|
||||
)
|
||||
parser.add_argument("--enable_splitwise", type=int, default=0, help="enable splitwise ")
|
||||
parser.add_argument("--cache_queue_port", type=int, default=9923, help="cache queue port")
|
||||
parser.add_argument("--pod_ip", type=str, default="0.0.0.0", help="pod ip")
|
||||
parser.add_argument(
|
||||
@@ -68,7 +54,6 @@ def parse_args():
|
||||
help="engine worker queue port",
|
||||
)
|
||||
parser.add_argument("--engine_pid", type=str, default=None, help="engine pid")
|
||||
|
||||
parser.add_argument("--num_gpu_blocks", type=int, default=1, help="gpu cache block number")
|
||||
parser.add_argument("--num_cpu_blocks", type=int, default=4, help="cpu cache block number")
|
||||
parser.add_argument("--block_size", type=int, default=64, help="cache block size(tokens)")
|
||||
@@ -109,7 +94,6 @@ class CacheTransferManager:
|
||||
|
||||
device = args.device_id
|
||||
rank = args.rank
|
||||
paddle.set_device(f"gpu:{device}")
|
||||
self.gpu_cache_kvs = {}
|
||||
self.cpu_cache_kvs = {}
|
||||
self.gpu_cache_k_tensors = []
|
||||
@@ -138,40 +122,27 @@ class CacheTransferManager:
|
||||
self.num_cpu_blocks = args.num_cpu_blocks
|
||||
|
||||
cache_type = args.cache_dtype
|
||||
for i in range(args.num_layers + self.num_extra_layers):
|
||||
num_gpu_blocks = args.num_gpu_blocks if i < args.num_layers else self.num_extra_layer_gpu_blocks
|
||||
cache_shape = [
|
||||
args.num_gpu_blocks,
|
||||
args.kv_num_head,
|
||||
args.block_size,
|
||||
args.head_dim,
|
||||
]
|
||||
|
||||
self.gpu_cache_kvs[f"key_caches_{i}_rank{rank}_device{device}"] = paddle.full(
|
||||
shape=[
|
||||
num_gpu_blocks,
|
||||
args.kv_num_head,
|
||||
args.block_size,
|
||||
args.head_dim,
|
||||
],
|
||||
fill_value=0,
|
||||
dtype=cache_type,
|
||||
)
|
||||
self.gpu_cache_k_tensors.append(self.gpu_cache_kvs[f"key_caches_{i}_rank{rank}_device{device}"])
|
||||
self.gpu_cache_kvs[f"value_caches_{i}_rank{rank}_device{device}"] = paddle.full(
|
||||
shape=[
|
||||
num_gpu_blocks,
|
||||
args.kv_num_head,
|
||||
args.block_size,
|
||||
args.head_dim,
|
||||
],
|
||||
fill_value=0,
|
||||
dtype=cache_type,
|
||||
)
|
||||
self.gpu_cache_v_tensors.append(self.gpu_cache_kvs[f"value_caches_{i}_rank{rank}_device{device}"])
|
||||
for i in range(args.num_hidden_layers + self.num_extra_layers):
|
||||
num_gpu_blocks = args.num_gpu_blocks if i < args.num_hidden_layers else self.num_extra_layer_gpu_blocks
|
||||
cache_shape[0] = num_gpu_blocks
|
||||
key_name = f"key_caches_{i}_rank{rank}.device{device}"
|
||||
value_name = f"value_caches_{i}_rank{rank}.device{device}"
|
||||
key_cache = paddle.empty(shape=[], dtype=cache_type)
|
||||
value_cache = paddle.empty(shape=[], dtype=cache_type)
|
||||
key_cache = share_external_data(key_cache, key_name, cache_shape)
|
||||
value_cache = share_external_data(value_cache, value_name, cache_shape)
|
||||
self.gpu_cache_kvs[key_name] = key_cache
|
||||
self.gpu_cache_kvs[value_name] = value_cache
|
||||
self.gpu_cache_k_tensors.append(self.gpu_cache_kvs[key_name])
|
||||
self.gpu_cache_v_tensors.append(self.gpu_cache_kvs[value_name])
|
||||
|
||||
set_data_ipc(
|
||||
self.gpu_cache_kvs[f"key_caches_{i}_rank{rank}_device{device}"],
|
||||
f"key_caches_{i}_rank{rank}.device{device}",
|
||||
)
|
||||
set_data_ipc(
|
||||
self.gpu_cache_kvs[f"value_caches_{i}_rank{rank}_device{device}"],
|
||||
f"value_caches_{i}_rank{rank}.device{device}",
|
||||
)
|
||||
cache_kv_size_byte = sum([tmp.numel() * 1 for key, tmp in self.gpu_cache_kvs.items()])
|
||||
logger.info(f"device :{self.device}")
|
||||
logger.info(f"cache_kv_size_byte : {cache_kv_size_byte}")
|
||||
@@ -180,7 +151,7 @@ class CacheTransferManager:
|
||||
paddle.set_device("cpu")
|
||||
self.k_dst_ptrs = []
|
||||
self.v_dst_ptrs = []
|
||||
for i in range(args.num_layers + self.num_extra_layers):
|
||||
for i in range(args.num_hidden_layers + self.num_extra_layers):
|
||||
self.cpu_cache_kvs[f"key_caches_{i}_rank{rank}"] = cuda_host_alloc(
|
||||
args.num_cpu_blocks * args.bytes_per_layer_per_block
|
||||
)
|
||||
@@ -190,38 +161,6 @@ class CacheTransferManager:
|
||||
)
|
||||
self.v_dst_ptrs.append(self.cpu_cache_kvs[f"value_caches_{i}_rank{rank}"])
|
||||
|
||||
cache_ready_signal_data = np.zeros(shape=[args.mp_num], dtype=np.int32)
|
||||
self.cache_ready_signal = IPCSignal(
|
||||
name="cache_ready_signal",
|
||||
array=cache_ready_signal_data,
|
||||
dtype=np.int32,
|
||||
suffix=args.engine_pid,
|
||||
create=False,
|
||||
)
|
||||
self.cache_ready_signal.value[self.rank] = 1
|
||||
|
||||
paddle.set_device(f"gpu:{device}")
|
||||
if args.enable_splitwise:
|
||||
logger.debug("create cache messager...")
|
||||
logger.info(f"{args}")
|
||||
from fastdeploy.cache_manager.cache_messager import CacheMessager
|
||||
|
||||
self.cache_messager = CacheMessager(
|
||||
splitwise_role=args.splitwise_role,
|
||||
transfer_protocol=args.protocol,
|
||||
pod_ip=args.pod_ip,
|
||||
engine_worker_queue_port=args.engine_worker_queue_port,
|
||||
local_data_parallel_id=args.local_data_parallel_id,
|
||||
gpu_cache_kvs=self.gpu_cache_kvs,
|
||||
rank=self.rank,
|
||||
nranks=args.mp_num,
|
||||
num_layers=args.num_layers + self.num_extra_layers,
|
||||
gpu_id=self.device,
|
||||
rdma_port=args.rdma_port,
|
||||
)
|
||||
logger.info("successfully create cache messager")
|
||||
logger.info(f"done init CacheMessager gmem alloc : {paddle.device.cuda.memory_allocated()}")
|
||||
|
||||
cache_task_broadcast_data = np.zeros(shape=[1], dtype=np.int32)
|
||||
self.cache_task_broadcast_signal = IPCSignal(
|
||||
name="cache_task_broadcast_signal",
|
||||
@@ -443,4 +382,5 @@ if __name__ == "__main__":
|
||||
|
||||
args = parse_args()
|
||||
logger = get_logger("cache_transfer_manager", "cache_transfer_manager.log")
|
||||
paddle.set_device(f"gpu:{args.device_id}")
|
||||
main()
|
||||
|
@@ -31,6 +31,7 @@ from fastdeploy import envs
|
||||
from fastdeploy.cache_manager.cache_data import BlockNode, CacheStatus
|
||||
from fastdeploy.cache_manager.cache_metrics import CacheMetrics
|
||||
from fastdeploy.inter_communicator import EngineCacheQueue, IPCSignal
|
||||
from fastdeploy.metrics.metrics import main_process_metrics
|
||||
from fastdeploy.utils import get_logger
|
||||
|
||||
logger = get_logger("prefix_cache_manager", "prefix_cache_manager.log")
|
||||
@@ -106,6 +107,10 @@ class PrefixCacheManager:
|
||||
+ f"{self.num_cpu_blocks}, bytes_per_layer_per_block {self.cache_config.bytes_per_layer_per_block}"
|
||||
)
|
||||
|
||||
@property
|
||||
def available_gpu_resource(self):
|
||||
return len(self.gpu_free_block_list) / self.num_gpu_blocks if self.num_gpu_blocks > 0 else 0.0
|
||||
|
||||
def launch_cache_manager(
|
||||
self,
|
||||
cache_config,
|
||||
@@ -141,6 +146,76 @@ class PrefixCacheManager:
|
||||
filename = "cache_transfer_manager.py"
|
||||
py_path = os.path.join(current_dir_path, filename)
|
||||
|
||||
cache_messager_processes = []
|
||||
if self.splitwise_role != "mixed":
|
||||
cache_messager_processes = self.launch_cache_messager(
|
||||
cache_config,
|
||||
tensor_parallel_size,
|
||||
device_ids,
|
||||
pod_ip,
|
||||
engine_worker_queue_port,
|
||||
pid_suffix,
|
||||
)
|
||||
if cache_messager_processes is None:
|
||||
raise RuntimeError("Launch cache messager failed")
|
||||
return []
|
||||
|
||||
if (
|
||||
hasattr(cache_config.model_cfg, "num_key_value_heads")
|
||||
and hasattr(cache_config.model_cfg, "num_key_value_heads")
|
||||
and cache_config.model_cfg.num_key_value_heads is not None
|
||||
and int(cache_config.model_cfg.num_key_value_heads) > 0
|
||||
):
|
||||
kv_num_head = int(cache_config.model_cfg.num_key_value_heads) // tensor_parallel_size
|
||||
else:
|
||||
kv_num_head = cache_config.model_cfg.num_attention_heads // tensor_parallel_size
|
||||
|
||||
log_dir = envs.FD_LOG_DIR
|
||||
cache_manager_processes = []
|
||||
for i in range(tensor_parallel_size):
|
||||
launch_cmd = (
|
||||
f" {sys.executable} {py_path}"
|
||||
+ f" --device_id {int(device_ids[i])}"
|
||||
+ f" --rank {i}"
|
||||
+ f" --num_hidden_layers {cache_config.model_cfg.num_hidden_layers}"
|
||||
+ f" --head_dim {cache_config.model_cfg.head_dim}"
|
||||
+ f" --kv_num_head {kv_num_head}"
|
||||
+ f" --mp_num {tensor_parallel_size}"
|
||||
+ f" --cache_dtype {cache_config.cache_dtype}"
|
||||
+ f" --cache_queue_port {cache_config.cache_queue_port}"
|
||||
+ f" --pod_ip {pod_ip}"
|
||||
+ f" --engine_worker_queue_port {engine_worker_queue_port}"
|
||||
+ f" --num_gpu_blocks {cache_config.total_block_num}"
|
||||
+ f" --num_cpu_blocks {cache_config.num_cpu_blocks}"
|
||||
+ f" --bytes_per_layer_per_block {cache_config.bytes_per_layer_per_block}"
|
||||
+ f" --block_size {cache_config.block_size}"
|
||||
+ f" --engine_pid {pid_suffix}"
|
||||
+ f" --local_data_parallel_id {self.local_data_parallel_id}"
|
||||
+ f" --speculative_config '{self.speculative_config.to_json_string()}'"
|
||||
+ f" >{log_dir}/launch_cache_manager_{int(device_ids[i])}.log 2>&1"
|
||||
)
|
||||
logger.info(f"Launch cache transfer manager, command:{launch_cmd}")
|
||||
cache_manager_processes.append(subprocess.Popen(launch_cmd, shell=True, preexec_fn=os.setsid))
|
||||
exit_code = cache_manager_processes[-1].poll()
|
||||
if exit_code is None:
|
||||
logger.info("Launch cache transfer manager successful")
|
||||
else:
|
||||
logger.info("Launch cache transfer manager failed, see launch_cache_manager.log for more information")
|
||||
|
||||
if cache_config.enable_hierarchical_cache and self.num_cpu_blocks > 0:
|
||||
logger.info("Enable hierarchical cache.")
|
||||
self._enable_cpu_cache()
|
||||
cache_manager_processes.extend(cache_messager_processes)
|
||||
return cache_manager_processes
|
||||
|
||||
def launch_cache_messager(
|
||||
self, cache_config, tensor_parallel_size, device_ids, pod_ip, engine_worker_queue_port, pid_suffix
|
||||
):
|
||||
"""
|
||||
launch_cache_messager function used to initialize the cache messager.
|
||||
"""
|
||||
current_dir_path = os.path.split(os.path.abspath(__file__))[0]
|
||||
filename = "cache_messager.py"
|
||||
if (
|
||||
hasattr(cache_config.model_cfg, "num_key_value_heads")
|
||||
and hasattr(cache_config.model_cfg, "num_key_value_heads")
|
||||
@@ -159,8 +234,10 @@ class PrefixCacheManager:
|
||||
suffix=pid_suffix,
|
||||
create=True,
|
||||
)
|
||||
|
||||
py_path = os.path.join(current_dir_path, filename)
|
||||
log_dir = envs.FD_LOG_DIR
|
||||
cache_manager_processes = []
|
||||
cache_messager_processes = []
|
||||
for i in range(tensor_parallel_size):
|
||||
launch_cmd = (
|
||||
"FLAGS_allocator_strategy=auto_growth CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7"
|
||||
@@ -169,42 +246,34 @@ class PrefixCacheManager:
|
||||
+ f" --device_id {int(device_ids[i])}"
|
||||
+ f" --rank {i}"
|
||||
+ f" --splitwise_role {self.splitwise_role}"
|
||||
+ f" --num_layers {cache_config.model_cfg.num_hidden_layers}"
|
||||
+ f" --num_hidden_layers {cache_config.model_cfg.num_hidden_layers}"
|
||||
+ f" --head_dim {cache_config.model_cfg.head_dim}"
|
||||
+ f" --kv_num_head {kv_num_head}"
|
||||
+ f" --mp_num {tensor_parallel_size}"
|
||||
+ f" --cache_dtype {cache_config.cache_dtype}"
|
||||
+ f" --cache_queue_port {cache_config.cache_queue_port}"
|
||||
+ f" --enable_splitwise {int(self.enable_splitwise)}"
|
||||
+ f" --pod_ip {pod_ip}"
|
||||
+ f" --engine_worker_queue_port {engine_worker_queue_port}"
|
||||
+ f" --num_gpu_blocks {cache_config.total_block_num}"
|
||||
+ f" --num_cpu_blocks {cache_config.num_cpu_blocks}"
|
||||
+ f" --bytes_per_layer_per_block {cache_config.bytes_per_layer_per_block}"
|
||||
+ f" --block_size {cache_config.block_size}"
|
||||
+ f" --engine_pid {pid_suffix}"
|
||||
+ f" --protocol {cache_config.cache_transfer_protocol}"
|
||||
+ f" --local_data_parallel_id {self.local_data_parallel_id}"
|
||||
+ f" --engine_pid {pid_suffix}"
|
||||
+ f" --rdma_port {cache_config.rdma_comm_ports[i] if cache_config.rdma_comm_ports is not None else '0'}"
|
||||
+ f" --speculative_config '{self.speculative_config.to_json_string()}'"
|
||||
+ f" >{log_dir}/launch_cache_manager_{int(device_ids[i])}.log 2>&1"
|
||||
+ f" >{log_dir}/launch_cache_messager_{int(device_ids[i])}.log 2>&1"
|
||||
)
|
||||
logger.info(f"Launch cache transfer manager, command:{launch_cmd}")
|
||||
cache_manager_processes.append(subprocess.Popen(launch_cmd, shell=True, preexec_fn=os.setsid))
|
||||
# 等待cache初始化完毕
|
||||
logger.info("Waiting for cache transfer manager ready...")
|
||||
logger.info(f"Launch cache messager, command:{launch_cmd}")
|
||||
cache_messager_processes.append(subprocess.Popen(launch_cmd, shell=True, preexec_fn=os.setsid))
|
||||
logger.info("Waiting for cache ready...")
|
||||
while np.sum(self.cache_ready_signal.value) != tensor_parallel_size:
|
||||
time.sleep(1)
|
||||
exit_code = cache_manager_processes[-1].poll()
|
||||
exit_code = cache_messager_processes[-1].poll()
|
||||
if exit_code is None:
|
||||
logger.info("Launch cache transfer manager successful")
|
||||
logger.info("Launch cache messager successful")
|
||||
else:
|
||||
logger.info("Launch cache transfer manager failed, see launch_cache_manager.log for more information")
|
||||
|
||||
if cache_config.enable_hierarchical_cache and self.num_cpu_blocks > 0:
|
||||
logger.info("Enable hierarchical cache.")
|
||||
self._enable_cpu_cache()
|
||||
return cache_manager_processes
|
||||
logger.info("Launch cache messager failed, see launch_cache_messager.log for more information")
|
||||
cache_messager_processes = None
|
||||
return cache_messager_processes
|
||||
|
||||
def update_cache_config(self, cache_config):
|
||||
"""
|
||||
@@ -225,6 +294,9 @@ class PrefixCacheManager:
|
||||
heapq.heapify(self.gpu_free_block_list)
|
||||
self.node_id_pool = list(range(self.num_gpu_blocks + self.num_cpu_blocks))
|
||||
|
||||
main_process_metrics.max_gpu_block_num.set(self.num_gpu_blocks)
|
||||
main_process_metrics.available_gpu_resource.set(1.0)
|
||||
|
||||
def _enable_cpu_cache(self):
|
||||
"""
|
||||
_enable_cpu_cache function used to enable cpu cache.
|
||||
@@ -260,6 +332,8 @@ class PrefixCacheManager:
|
||||
logger.info(
|
||||
f"allocate_gpu_blocks: {allocated_block_ids}, len(self.gpu_free_block_list) {len(self.gpu_free_block_list)}"
|
||||
)
|
||||
main_process_metrics.free_gpu_block_num.set(len(self.gpu_free_block_list))
|
||||
main_process_metrics.available_gpu_resource.set(self.available_gpu_resource)
|
||||
return allocated_block_ids
|
||||
|
||||
def recycle_gpu_blocks(self, gpu_block_ids):
|
||||
@@ -274,6 +348,8 @@ class PrefixCacheManager:
|
||||
heapq.heappush(self.gpu_free_block_list, gpu_block_id)
|
||||
else:
|
||||
heapq.heappush(self.gpu_free_block_list, gpu_block_ids)
|
||||
main_process_metrics.free_gpu_block_num.set(len(self.gpu_free_block_list))
|
||||
main_process_metrics.available_gpu_resource.set(self.available_gpu_resource)
|
||||
|
||||
def allocate_cpu_blocks(self, num_blocks):
|
||||
"""
|
||||
|
@@ -61,18 +61,12 @@ class RDMACommManager:
|
||||
Connect to remote gpu and write cache.
|
||||
"""
|
||||
assert self.splitwise_role == "prefill", "only prefill can call this method"
|
||||
addr = f"{ip}:{port!s}"
|
||||
if addr in self.connected_rdma:
|
||||
return True
|
||||
ret = self.messager.is_connected(ip, str(port))
|
||||
if ret:
|
||||
self.connected_rdma.add(addr)
|
||||
return True
|
||||
|
||||
ret = self.messager.connect(ip, str(port))
|
||||
logger.info(f"connect to remote rdma address {ip}:{port} status is {ret}")
|
||||
if ret == 0:
|
||||
self.connected_rdma.add(addr)
|
||||
return ret == 0
|
||||
|
||||
def write_cache(self, ip, port, local_block_ids, remote_block_ids, layer_idx):
|
||||
|
@@ -820,6 +820,7 @@ class EngineArgs:
|
||||
"max_num_partial_prefills",
|
||||
"max_long_partial_prefills",
|
||||
"long_prefill_token_threshold",
|
||||
"splitwise_role"
|
||||
]
|
||||
|
||||
all = asdict(self)
|
||||
|
@@ -293,10 +293,11 @@ class Config:
|
||||
)
|
||||
|
||||
if not self.cache_config.enable_chunked_prefill:
|
||||
assert self.max_num_batched_tokens >= self.max_model_len, (
|
||||
f"max_num_batched_tokens: {self.max_num_batched_tokens} "
|
||||
f"should be larger than or equal to max_model_len: {self.max_model_len}"
|
||||
)
|
||||
if not int(os.getenv("FD_ENABLE_INTERNAL_ADAPTER", "0")):
|
||||
assert self.max_num_batched_tokens >= self.max_model_len, (
|
||||
f"max_num_batched_tokens: {self.max_num_batched_tokens} "
|
||||
f"should be larger than or equal to max_model_len: {self.max_model_len}"
|
||||
)
|
||||
else:
|
||||
assert self.max_num_batched_tokens >= self.cache_config.block_size, (
|
||||
f"max_num_batched_tokens: {self.max_num_batched_tokens} "
|
||||
|
@@ -28,6 +28,7 @@ import time
|
||||
import traceback
|
||||
import uuid
|
||||
import weakref
|
||||
from collections import deque
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
@@ -47,12 +48,14 @@ from fastdeploy.inter_communicator import (
|
||||
EngineCacheQueue,
|
||||
EngineWorkerQueue,
|
||||
IPCSignal,
|
||||
ZmqClient,
|
||||
ZmqIpcServer,
|
||||
ZmqTcpServer,
|
||||
)
|
||||
from fastdeploy.metrics.metrics import main_process_metrics
|
||||
from fastdeploy.metrics.trace_util import start_span, start_span_request
|
||||
from fastdeploy.model_executor.guided_decoding import schema_checker
|
||||
from fastdeploy.output.token_processor import TokenProcessor, WarmUpTokenProcessor
|
||||
from fastdeploy.splitwise.internal_adapter_utils import InternalAdapter
|
||||
from fastdeploy.splitwise.splitwise_connector import SplitwiseConnector
|
||||
from fastdeploy.utils import EngineError, console_logger, envs, llm_logger
|
||||
|
||||
@@ -110,6 +113,8 @@ class LLMEngine:
|
||||
|
||||
self.start_queue_service()
|
||||
|
||||
self.enable_decode_cache_task = envs.FD_ENABLE_CACHE_TASK == "1"
|
||||
|
||||
if envs.ENABLE_V1_KVCACHE_SCHEDULER:
|
||||
self.resource_manager = ResourceManagerV1(
|
||||
cfg.max_num_seqs, cfg, cfg.tensor_parallel_size, cfg.splitwise_role
|
||||
@@ -123,9 +128,17 @@ class LLMEngine:
|
||||
cfg.max_num_seqs, cfg, cfg.tensor_parallel_size, cfg.splitwise_role
|
||||
)
|
||||
|
||||
os.environ["INFERENCE_MSG_QUEUE_ID"] = str(self.cfg.engine_worker_queue_port)
|
||||
|
||||
self.split_connector = SplitwiseConnector(cfg, self.scheduler, self.engine_worker_queue, self.resource_manager)
|
||||
os.environ["INFERENCE_MSG_QUEUE_ID"] = str(
|
||||
self.cfg.engine_worker_queue_port + self.cfg.parallel_config.local_data_parallel_id
|
||||
)
|
||||
self.splitwise_queue = deque()
|
||||
self.split_connector = SplitwiseConnector(
|
||||
cfg,
|
||||
self.scheduler,
|
||||
self.engine_worker_queue,
|
||||
self.resource_manager,
|
||||
self.splitwise_queue,
|
||||
)
|
||||
|
||||
self.token_processor = TokenProcessor(
|
||||
cfg=self.cfg,
|
||||
@@ -177,13 +190,71 @@ class LLMEngine:
|
||||
self._init_worker_signals()
|
||||
|
||||
self.data_processor = self.input_processor.create_processor()
|
||||
self.response_lock = threading.Lock() # prevent to call send_multipart in zmq concurrently
|
||||
|
||||
if api_server_pid is not None:
|
||||
self.zmq_server = ZmqClient(name=api_server_pid, mode=zmq.PULL)
|
||||
self.zmq_server.start_server()
|
||||
self.zmq_server.create_router()
|
||||
if envs.FD_ENABLE_INTERNAL_ADAPTER:
|
||||
self.recv_request_server = ZmqTcpServer(port=envs.FD_ZMQ_RECV_REQUEST_SERVER_PORT, mode=zmq.PULL)
|
||||
self.send_response_server = ZmqTcpServer(port=envs.FD_ZMQ_SEND_RESPONSE_SERVER_PORT, mode=zmq.ROUTER)
|
||||
self.external_adapter = InternalAdapter(
|
||||
cfg=self.cfg, engine=self, dp_rank=self.cfg.node_rank * self.cfg.worker_num_per_node
|
||||
)
|
||||
else:
|
||||
self.recv_request_server = ZmqIpcServer(name=api_server_pid, mode=zmq.PULL)
|
||||
self.send_response_server = ZmqIpcServer(name=api_server_pid, mode=zmq.ROUTER)
|
||||
self.recv_result_handle_thread = threading.Thread(
|
||||
target=self.send_response_server.recv_result_handle, daemon=True
|
||||
)
|
||||
self.recv_result_handle_thread.start()
|
||||
time.sleep(3)
|
||||
|
||||
self.cfg.init_cache_info()
|
||||
|
||||
role = self.cfg.splitwise_role
|
||||
host_ip = self.cfg.host_ip
|
||||
disaggregate = self.cfg.disaggregate_info
|
||||
request_queues_for_dp_ipc = (
|
||||
None # Different dp has its own process, use multiprocessing.Queue to deliver requests for each dp
|
||||
)
|
||||
result_queue_for_dp_ipc = None
|
||||
if self.cfg.scheduler_config.name == "splitwise":
|
||||
self.scheduler.start(role, host_ip, disaggregate)
|
||||
elif self.cfg.scheduler_config.name == "dp":
|
||||
request_queues_for_dp_ipc = []
|
||||
result_queue_for_dp_ipc = multiprocessing.Queue()
|
||||
for i in range(self.cfg.parallel_config.data_parallel_size):
|
||||
request_queues_for_dp_ipc.append(multiprocessing.Queue())
|
||||
self.scheduler.start(
|
||||
self.cfg.node_rank * self.cfg.worker_num_per_node, request_queues_for_dp_ipc, result_queue_for_dp_ipc
|
||||
)
|
||||
|
||||
time.sleep(1)
|
||||
|
||||
if self.cfg.parallel_config.enable_expert_parallel and self.cfg.parallel_config.data_parallel_size > 1:
|
||||
self.dp_processed = []
|
||||
for i in range(
|
||||
1,
|
||||
self.cfg.parallel_config.data_parallel_size // self.cfg.nnode,
|
||||
):
|
||||
time.sleep(1)
|
||||
self.dp_processed.append(
|
||||
multiprocessing.Process(
|
||||
target=start_expert_service,
|
||||
args=(
|
||||
self.cfg,
|
||||
i + self.cfg.node_rank * self.cfg.worker_num_per_node,
|
||||
self.ipc_signal_suffix,
|
||||
request_queues_for_dp_ipc,
|
||||
result_queue_for_dp_ipc,
|
||||
),
|
||||
)
|
||||
)
|
||||
llm_logger.info(
|
||||
f"Engine is initialized successfully with {self.cfg.tensor_parallel_size}"
|
||||
+ f" data parallel id {i}"
|
||||
)
|
||||
self.dp_processed[-1].start()
|
||||
|
||||
if self.do_profile == 0 and (
|
||||
self.cfg.cache_config.enable_prefix_caching or self.cfg.splitwise_role != "mixed"
|
||||
):
|
||||
@@ -238,44 +309,11 @@ class LLMEngine:
|
||||
# 单机逻辑
|
||||
self.engine_worker_queue.available_prefill_instances.put(1)
|
||||
self.split_mode_get_tasks()
|
||||
if self.cfg.scheduler_config.name == "splitwise":
|
||||
if self.cfg.scheduler_config.name == "splitwise" or self.cfg.scheduler_config.name == "dp":
|
||||
self.splitwise_receive_thread = threading.Thread(target=self.split_connector.start_receiver, args=())
|
||||
self.splitwise_receive_thread.daemon = True
|
||||
self.splitwise_receive_thread.start()
|
||||
|
||||
self.cfg.init_cache_info()
|
||||
|
||||
role = self.cfg.splitwise_role
|
||||
host_ip = self.cfg.host_ip
|
||||
disaggregate = self.cfg.disaggregate_info
|
||||
if self.cfg.scheduler_config.name == "splitwise":
|
||||
self.scheduler.start(role, host_ip, disaggregate)
|
||||
|
||||
time.sleep(1)
|
||||
|
||||
if self.cfg.parallel_config.enable_expert_parallel and self.cfg.parallel_config.data_parallel_size > 1:
|
||||
self.dp_processed = []
|
||||
for i in range(
|
||||
1,
|
||||
self.cfg.parallel_config.data_parallel_size // self.cfg.nnode,
|
||||
):
|
||||
time.sleep(1)
|
||||
self.dp_processed.append(
|
||||
multiprocessing.Process(
|
||||
target=start_expert_service,
|
||||
args=(
|
||||
self.cfg,
|
||||
i + self.cfg.node_rank * self.cfg.worker_num_per_node,
|
||||
self.ipc_signal_suffix,
|
||||
),
|
||||
)
|
||||
)
|
||||
llm_logger.info(
|
||||
f"Engine is initialized successfully with {self.cfg.tensor_parallel_size}"
|
||||
+ f" data parallel id {i}"
|
||||
)
|
||||
self.dp_processed[-1].start()
|
||||
|
||||
console_logger.info(f"Worker processes are launched with {time.time() - start_time} seconds.")
|
||||
return True
|
||||
|
||||
@@ -290,8 +328,9 @@ class LLMEngine:
|
||||
if len(results) == 0:
|
||||
time.sleep(0.005)
|
||||
continue
|
||||
for request_id, contents in results.items():
|
||||
self.zmq_server.send_multipart(request_id, contents)
|
||||
with self.response_lock:
|
||||
for request_id, contents in results.items():
|
||||
self.send_response_server.send_response(request_id, contents)
|
||||
|
||||
except Exception as e:
|
||||
llm_logger.error(f"Unexcepted error happend: {e}, {traceback.format_exc()!s}")
|
||||
@@ -308,7 +347,7 @@ class LLMEngine:
|
||||
Insert task to engine thread, monitor scheduler request queue.
|
||||
if the engine has resource, insert task to engine
|
||||
"""
|
||||
current_id = -1
|
||||
current_id = 0
|
||||
while self.running:
|
||||
try:
|
||||
if self.resource_manager.available_batch() == 0:
|
||||
@@ -321,18 +360,15 @@ class LLMEngine:
|
||||
if self.cfg.splitwise_role == "mixed" or self.split_connector.has_splitwise_tasks():
|
||||
time.sleep(0.005)
|
||||
continue
|
||||
if self.engine_worker_queue.num_cache_infos() > 0:
|
||||
time.sleep(0.001)
|
||||
continue
|
||||
if len(self.split_connector.current_request_ids) > 0:
|
||||
time.sleep(0.001)
|
||||
continue
|
||||
|
||||
num_prefill_batch = min(
|
||||
int(self.resource_manager.available_batch()),
|
||||
self.cfg.max_prefill_batch,
|
||||
)
|
||||
|
||||
if envs.FD_ENABLE_INTERNAL_ADAPTER:
|
||||
num_prefill_batch = int(self.resource_manager.available_batch())
|
||||
|
||||
self.resource_manager.check_and_free_block_tables()
|
||||
tasks = self.scheduler.get_requests(
|
||||
available_blocks=self.resource_manager.available_block_num(),
|
||||
@@ -346,12 +382,15 @@ class LLMEngine:
|
||||
time.sleep(0.001)
|
||||
continue
|
||||
|
||||
current_id = (current_id + 1) % 100003
|
||||
if self.cfg.splitwise_role != "mixed":
|
||||
llm_logger.info("Inserting splitwise tasks")
|
||||
self.split_connector.send_splitwise_tasks(tasks, current_id)
|
||||
|
||||
self.insert_tasks(tasks, current_id)
|
||||
insert_successful = self.insert_tasks(tasks, current_id)
|
||||
if insert_successful:
|
||||
current_id = current_id + 1
|
||||
else:
|
||||
continue
|
||||
|
||||
main_process_metrics.num_requests_waiting.dec(len(tasks))
|
||||
main_process_metrics.num_requests_running.inc(len(tasks))
|
||||
@@ -400,6 +439,8 @@ class LLMEngine:
|
||||
get_request_pool.submit(_fetch_request)
|
||||
# 2. Schedule requests
|
||||
tasks = self.resource_manager.schedule()
|
||||
main_process_metrics.num_requests_waiting.dec(len(tasks))
|
||||
main_process_metrics.num_requests_running.inc(len(tasks))
|
||||
# 3. Send to engine
|
||||
if tasks:
|
||||
self.resource_manager.get_real_bsz()
|
||||
@@ -415,14 +456,18 @@ class LLMEngine:
|
||||
if self.api_server_pid is None:
|
||||
return
|
||||
|
||||
if envs.FD_ENABLE_INTERNAL_ADAPTER:
|
||||
if self.cfg.splitwise_role == "decode":
|
||||
return
|
||||
|
||||
added_requests: Dict[str, int] = dict()
|
||||
while self.running:
|
||||
try:
|
||||
block = True if len(added_requests) == 0 else False
|
||||
if not self.cfg.enable_mm:
|
||||
err, data = self.zmq_server.receive_json_once(block)
|
||||
err, data = self.recv_request_server.receive_json_once(block)
|
||||
else:
|
||||
err, data = self.zmq_server.receive_pyobj_once(block)
|
||||
err, data = self.recv_request_server.receive_pyobj_once(block)
|
||||
if err is not None:
|
||||
llm_logger.error("Engine stops inserting zmq task into scheduler, err:{err}")
|
||||
break
|
||||
@@ -433,6 +478,7 @@ class LLMEngine:
|
||||
request = Request.from_dict(data)
|
||||
start_span("ENQUEUE_ZMQ", data, trace.SpanKind.PRODUCER)
|
||||
|
||||
main_process_metrics.requests_number.inc()
|
||||
llm_logger.debug(f"Receive request: {request}")
|
||||
|
||||
err_msg = None
|
||||
@@ -461,7 +507,7 @@ class LLMEngine:
|
||||
if failed is None:
|
||||
main_process_metrics.num_requests_waiting.inc(1)
|
||||
continue
|
||||
|
||||
llm_logger.error(f"request {request_id} insert to scheduler failed: {failed}")
|
||||
error_result = RequestOutput(
|
||||
request_id=request_id,
|
||||
finished=True,
|
||||
@@ -470,7 +516,8 @@ class LLMEngine:
|
||||
)
|
||||
# Since the request is not in scheduler
|
||||
# Send result by zmq directly
|
||||
self.zmq_server.send_multipart(request_id, error_result)
|
||||
with self.response_lock:
|
||||
self.send_response_server.send_response(request_id, [error_result])
|
||||
except Exception as e:
|
||||
llm_logger.error(
|
||||
f"Error happend while receving new request from zmq, details={e}, "
|
||||
@@ -570,41 +617,44 @@ class LLMEngine:
|
||||
for idx in sorted(processed_indices, reverse=True):
|
||||
self.waiting_requests.pop(idx)
|
||||
|
||||
if not self.engine_worker_queue.disaggregate_queue_empty():
|
||||
items = self.engine_worker_queue.get_disaggregated_tasks()
|
||||
for item in items:
|
||||
role = item[0]
|
||||
tasks = item[1]
|
||||
if len(self.splitwise_queue) > 0:
|
||||
items = self.splitwise_queue.pop()
|
||||
role = items[0]
|
||||
tasks = items[1]
|
||||
|
||||
if role == "prefill":
|
||||
if role == "prefill":
|
||||
for task in tasks:
|
||||
task.max_tokens = task.min_tokens = 2
|
||||
self.insert_tasks(tasks)
|
||||
|
||||
elif role == "decode":
|
||||
if hasattr(tasks[0], "finished"):
|
||||
if not isinstance(tasks, list):
|
||||
tasks = [tasks]
|
||||
for task in tasks:
|
||||
task.max_tokens = task.min_tokens = 2
|
||||
self.insert_tasks(tasks)
|
||||
task.finished = False
|
||||
self.insert_tasks(tasks, allocated=True)
|
||||
|
||||
elif role == "decode":
|
||||
if hasattr(tasks[0], "finished"):
|
||||
if not isinstance(tasks, list):
|
||||
tasks = [tasks]
|
||||
for task in tasks:
|
||||
task.finished = False
|
||||
self.insert_tasks(tasks, allocated=True)
|
||||
|
||||
if self.cfg.innode_prefill_ports is not None:
|
||||
self.scheduler.put_results(tasks)
|
||||
if self.cfg.innode_prefill_ports is not None:
|
||||
self.scheduler.put_results(tasks)
|
||||
|
||||
else:
|
||||
if len(self.waiting_requests):
|
||||
llm_logger.info(f"Waiting for resource for task {tasks[0].request_id}")
|
||||
self.waiting_requests.extend(tasks)
|
||||
else:
|
||||
if len(self.waiting_requests):
|
||||
llm_logger.info(f"Waiting for resource for task {tasks[0].request_id}")
|
||||
self.waiting_requests.extend(tasks)
|
||||
else:
|
||||
new_waiting = []
|
||||
for task in tasks:
|
||||
if self.resource_manager.is_resource_sufficient(task.prompt_token_ids_len):
|
||||
self.insert_tasks([task])
|
||||
else:
|
||||
new_waiting.append(task)
|
||||
|
||||
if new_waiting:
|
||||
new_waiting = []
|
||||
for task in tasks:
|
||||
if self.resource_manager.is_resource_sufficient(task.prompt_token_ids_len):
|
||||
self.insert_tasks([task])
|
||||
else:
|
||||
if not self.enable_decode_cache_task:
|
||||
task.error_msg = "Not enough resources"
|
||||
new_waiting.append(task)
|
||||
if new_waiting:
|
||||
if not self.enable_decode_cache_task:
|
||||
self.split_connector.send_cache_infos(new_waiting, -1)
|
||||
else:
|
||||
self.waiting_requests.extend(new_waiting)
|
||||
llm_logger.info(f"Added {len(new_waiting)} tasks to waiting queue")
|
||||
|
||||
@@ -749,10 +799,6 @@ class LLMEngine:
|
||||
"""
|
||||
Insert tasks to engine.
|
||||
"""
|
||||
for task in tasks:
|
||||
start_span_request("DEQUEUE", task, trace.SpanKind.CONSUMER)
|
||||
if task.sampling_params.bad_words is not None:
|
||||
task.sampling_params.update_from_tokenizer(self.data_processor.tokenizer)
|
||||
# TODO 返回至 scheduler
|
||||
if allocated:
|
||||
current_tasks = []
|
||||
@@ -760,6 +806,15 @@ class LLMEngine:
|
||||
cur_task_idx = self.resource_manager.req_dict[task.request_id]
|
||||
del self.resource_manager.req_dict[task.request_id]
|
||||
cur_task = self.resource_manager.tasks_list[cur_task_idx]
|
||||
if envs.FD_ENABLE_INTERNAL_ADAPTER:
|
||||
if not task.outputs.token_ids: # first token is eos in Prefill, just recycle resource and continue
|
||||
self.resource_manager.stop_flags[cur_task_idx] = True
|
||||
self.resource_manager.tasks_list[cur_task_idx] = None
|
||||
self.resource_manager._recycle_block_tables(cur_task)
|
||||
if task.request_id in self.token_processor.tokens_counter:
|
||||
del self.token_processor.tokens_counter[task.request_id]
|
||||
llm_logger.warning(f"{task.request_id} need not decode after first token")
|
||||
continue
|
||||
cur_task.prompt_token_ids[0] = task.outputs.token_ids[0]
|
||||
if self.cfg.speculative_config.method in ["mtp"] and self.cfg.splitwise_role == "decode":
|
||||
cur_task.draft_token_ids = copy.deepcopy(task.outputs.draft_token_ids)
|
||||
@@ -769,32 +824,58 @@ class LLMEngine:
|
||||
self.resource_manager._recycle_block_tables(cur_task)
|
||||
if task.request_id in self.token_processor.tokens_counter:
|
||||
del self.token_processor.tokens_counter[task.request_id]
|
||||
self.scheduler.put_results([task])
|
||||
llm_logger.warning(
|
||||
f"{task.request_id} prefill failed with msg:{task.error_msg}, recycle resource."
|
||||
)
|
||||
continue
|
||||
self.token_processor.tokens_counter[task.request_id] = 1
|
||||
current_tasks.append(cur_task)
|
||||
self.engine_worker_queue.put_tasks((current_tasks, self.resource_manager.real_bsz))
|
||||
if current_tasks:
|
||||
self.engine_worker_queue.put_tasks((current_tasks, self.resource_manager.real_bsz))
|
||||
return True
|
||||
|
||||
self.resource_manager.check_and_free_block_tables()
|
||||
|
||||
if not isinstance(tasks, list):
|
||||
tasks = [tasks]
|
||||
need_delete_tasks = []
|
||||
for task in tasks:
|
||||
start_span_request("DEQUEUE", task, trace.SpanKind.CONSUMER)
|
||||
if self.cfg.splitwise_role != "mixed":
|
||||
status, msg = self.split_connector.check_decode_allocated(task)
|
||||
if not status:
|
||||
llm_logger.error(f"{task.request_id} prefill failed with msg:{msg}.")
|
||||
self.scheduler.put_results(
|
||||
[
|
||||
RequestOutput(
|
||||
request_id=task.request_id,
|
||||
finished=True,
|
||||
error_code=500,
|
||||
error_msg=msg,
|
||||
)
|
||||
]
|
||||
)
|
||||
need_delete_tasks.append(task)
|
||||
continue
|
||||
if task.sampling_params.bad_words is not None:
|
||||
task.sampling_params.update_from_tokenizer(self.data_processor.tokenizer)
|
||||
|
||||
self.resource_manager.check_and_free_block_tables()
|
||||
|
||||
for tmp_task in need_delete_tasks:
|
||||
tasks.remove(tmp_task)
|
||||
|
||||
for item in tasks:
|
||||
item.schedule_start_time = time.time()
|
||||
|
||||
req_ids = [t.request_id for t in tasks]
|
||||
|
||||
if len(tasks) == 0:
|
||||
return False
|
||||
available_batch = np.sum(self.resource_manager.stop_flags)
|
||||
if len(tasks) > available_batch:
|
||||
llm_logger.error(f"Inserting batch:{len(tasks)} exceeds the available batch:{available_batch}.")
|
||||
llm_logger.error("The exceeded part will be ignored!")
|
||||
tasks = tasks[:available_batch]
|
||||
|
||||
req_ids = [t.request_id for t in tasks]
|
||||
|
||||
tasks = self.resource_manager.allocate_resources_for_new_tasks(tasks)
|
||||
|
||||
if not tasks:
|
||||
@@ -815,19 +896,19 @@ class LLMEngine:
|
||||
is_prefill = True
|
||||
self.token_processor.number_of_input_tokens += tasks[i].prompt_token_ids_len
|
||||
|
||||
self.split_connector.send_cache_infos(tasks, current_id)
|
||||
for task in tasks:
|
||||
task.inference_start_time = time.time()
|
||||
if not is_decode:
|
||||
llm_logger.info(f"Tasks are sent to engine, req_ids={req_ids}")
|
||||
for task in tasks:
|
||||
task.inference_start_time = time.time()
|
||||
if not is_prefill:
|
||||
if not self.cfg.enable_mm:
|
||||
self.update_requests_chunk_size(tasks)
|
||||
else:
|
||||
self.update_mm_requests_chunk_size(tasks)
|
||||
if not self.cfg.enable_mm:
|
||||
self.update_requests_chunk_size(tasks)
|
||||
else:
|
||||
self.update_mm_requests_chunk_size(tasks)
|
||||
self.engine_worker_queue.put_tasks((tasks, self.resource_manager.real_bsz))
|
||||
if is_prefill and self.cfg.scheduler_config.name != "splitwise":
|
||||
self.engine_worker_queue.available_prefill_instances.put(1)
|
||||
|
||||
self.split_connector.send_cache_infos(tasks, current_id)
|
||||
return True
|
||||
|
||||
def task_is_finished(self, index):
|
||||
@@ -966,14 +1047,17 @@ class LLMEngine:
|
||||
self.running = False
|
||||
|
||||
if hasattr(self, "cache_manager_processes"):
|
||||
self.resource_manager.cache_manager.shm_cache_task_flag_broadcast.clear()
|
||||
self.resource_manager.cache_manager.cache_ready_signal.clear()
|
||||
for p in self.cache_manager_processes:
|
||||
llm_logger.info(f"Killing cache manager process {p.pid}")
|
||||
try:
|
||||
os.killpg(p.pid, signal.SIGTERM)
|
||||
except Exception as e:
|
||||
print(f"Error extracting file: {e}")
|
||||
if hasattr(self.resource_manager.cache_manager, "cache_ready_signal"):
|
||||
self.resource_manager.cache_manager.cache_ready_signal.clear()
|
||||
self.resource_manager.cache_manager.shm_cache_task_flag_broadcast.clear()
|
||||
if hasattr(self, "zmq_server") and self.zmq_server is not None:
|
||||
self.zmq_server.close()
|
||||
self.worker_ready_signal.clear()
|
||||
self.exist_task_signal.clear()
|
||||
self.exist_swapped_task_signal.clear()
|
||||
@@ -988,12 +1072,19 @@ class LLMEngine:
|
||||
except Exception as e:
|
||||
print(f"Error extracting sub services: {e}")
|
||||
|
||||
self.engine_worker_queue.cleanup()
|
||||
if hasattr(self, "zmq_server") and self.zmq_server is not None:
|
||||
self.zmq_server.close()
|
||||
for worker_queue in self.engine_worker_queue_server:
|
||||
worker_queue.cleanup()
|
||||
if hasattr(self, "send_response_server") and self.send_response_server is not None:
|
||||
self.send_response_server.close()
|
||||
if hasattr(self, "recv_request_server") and self.recv_request_server is not None:
|
||||
self.recv_request_server.close()
|
||||
if hasattr(self, "recv_control_cmd_server") and self.recv_control_cmd_server is not None:
|
||||
self.recv_control_cmd_server.close()
|
||||
|
||||
if hasattr(self, "dp_processed"):
|
||||
for p in self.dp_processed:
|
||||
p.join()
|
||||
self.engine_worker_queue_server.cleanup()
|
||||
|
||||
def _setting_environ_variables(self):
|
||||
"""
|
||||
@@ -1291,15 +1382,20 @@ class LLMEngine:
|
||||
"""
|
||||
start queue service for engine worker communication
|
||||
"""
|
||||
address = (self.cfg.master_ip, self.cfg.engine_worker_queue_port)
|
||||
|
||||
self.engine_worker_queue_server = list()
|
||||
if self.cfg.host_ip == self.cfg.master_ip or self.cfg.master_ip == "0.0.0.0":
|
||||
llm_logger.info(f"Starting engine worker queue server service at {address}")
|
||||
self.engine_worker_queue_server = EngineWorkerQueue(
|
||||
address=address,
|
||||
is_server=True,
|
||||
num_client=self.cfg.tensor_parallel_size,
|
||||
local_data_parallel_size=self.cfg.parallel_config.data_parallel_size,
|
||||
)
|
||||
for i in range(self.cfg.parallel_config.data_parallel_size // self.cfg.nnode):
|
||||
address = (self.cfg.master_ip, self.cfg.engine_worker_queue_port + i)
|
||||
llm_logger.info(f"Starting engine worker queue service at {address}")
|
||||
self.engine_worker_queue_server.append(
|
||||
EngineWorkerQueue(
|
||||
address=address,
|
||||
is_server=True,
|
||||
num_client=self.cfg.tensor_parallel_size,
|
||||
local_data_parallel_size=self.cfg.parallel_config.data_parallel_size,
|
||||
)
|
||||
)
|
||||
|
||||
if self.cfg.cache_config.enable_prefix_caching or self.cfg.splitwise_role != "mixed":
|
||||
self.cache_task_queue = EngineCacheQueue(
|
||||
@@ -1314,6 +1410,7 @@ class LLMEngine:
|
||||
local_data_parallel_size=self.cfg.parallel_config.data_parallel_size,
|
||||
)
|
||||
|
||||
address = (self.cfg.master_ip, self.cfg.engine_worker_queue_port)
|
||||
self.engine_worker_queue = EngineWorkerQueue(
|
||||
address=address,
|
||||
is_server=False,
|
||||
|
@@ -16,21 +16,25 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import copy
|
||||
import os
|
||||
import signal
|
||||
import threading
|
||||
import time
|
||||
import traceback
|
||||
import weakref
|
||||
from collections import deque
|
||||
|
||||
import numpy as np
|
||||
|
||||
from fastdeploy.engine.request import RequestOutput
|
||||
from fastdeploy.engine.resource_manager import ResourceManager
|
||||
from fastdeploy.inter_communicator import EngineWorkerQueue
|
||||
from fastdeploy.metrics.metrics import main_process_metrics
|
||||
from fastdeploy.output.token_processor import TokenProcessor
|
||||
from fastdeploy.splitwise.internal_adapter_utils import InternalAdapter
|
||||
from fastdeploy.splitwise.splitwise_connector import SplitwiseConnector
|
||||
from fastdeploy.utils import EngineError, console_logger, llm_logger
|
||||
from fastdeploy.utils import EngineError, console_logger, envs, get_logger, llm_logger
|
||||
|
||||
|
||||
class ExpertService:
|
||||
@@ -52,6 +56,10 @@ class ExpertService:
|
||||
self.cfg = cfg
|
||||
start_pos = (local_data_parallel_id * self.cfg.tensor_parallel_size) % cfg.worker_num_per_node
|
||||
end_pos = start_pos + self.cfg.tensor_parallel_size
|
||||
self.waiting_requests = []
|
||||
self.disaggregate_queue = deque()
|
||||
|
||||
self.llm_logger = get_logger("expert_service", f"expert_service_{local_data_parallel_id}.log")
|
||||
if cfg.splitwise_role != "mixed":
|
||||
self.cfg.cache_config.rdma_comm_ports = self.cfg.cache_config.rdma_comm_ports[start_pos:end_pos]
|
||||
self.cfg.local_device_ids = self.cfg.device_ids.split(",")[start_pos:end_pos]
|
||||
@@ -60,11 +68,12 @@ class ExpertService:
|
||||
|
||||
self.scheduler = cfg.scheduler_config.scheduler()
|
||||
|
||||
self.scheduler.reset_nodeid(f"{self.scheduler.infer.nodeid}_{local_data_parallel_id!s}")
|
||||
if self.cfg.scheduler_config.name == "splitwise":
|
||||
self.scheduler.reset_nodeid(f"{self.scheduler.infer.nodeid}_{local_data_parallel_id!s}")
|
||||
|
||||
self.cfg.parallel_config.local_data_parallel_id = local_data_parallel_id
|
||||
|
||||
address = (cfg.master_ip, cfg.engine_worker_queue_port)
|
||||
address = (cfg.master_ip, cfg.engine_worker_queue_port + local_data_parallel_id)
|
||||
self.engine_worker_queue = EngineWorkerQueue(
|
||||
address=address,
|
||||
is_server=False,
|
||||
@@ -88,10 +97,7 @@ class ExpertService:
|
||||
self.cfg.cache_config.pd_comm_port = [self.cfg.cache_config.pd_comm_port[local_data_parallel_id]]
|
||||
|
||||
self.split_connector = SplitwiseConnector(
|
||||
self.cfg,
|
||||
self.scheduler,
|
||||
self.engine_worker_queue,
|
||||
self.resource_manager,
|
||||
self.cfg, self.scheduler, self.engine_worker_queue, self.resource_manager, self.disaggregate_queue
|
||||
)
|
||||
|
||||
self.token_processor = TokenProcessor(
|
||||
@@ -111,8 +117,12 @@ class ExpertService:
|
||||
)
|
||||
|
||||
self._finalizer = weakref.finalize(self, self._exit_sub_services)
|
||||
if envs.FD_ENABLE_INTERNAL_ADAPTER:
|
||||
self.external_adapter = InternalAdapter(cfg=self.cfg, engine=self, dp_rank=local_data_parallel_id)
|
||||
|
||||
def start(self, ipc_signal_suffix, local_data_parallel_id):
|
||||
def start(
|
||||
self, ipc_signal_suffix, local_data_parallel_id, request_queues_for_dp_ipc=None, result_queue_for_dp_ipc=None
|
||||
):
|
||||
"""
|
||||
Initializes the engine and starts its sub-services.
|
||||
If `api_server_pid` is defined, will launch a thread
|
||||
@@ -121,13 +131,13 @@ class ExpertService:
|
||||
# assert not self.is_started, "The engine is already started."
|
||||
start_time = time.time()
|
||||
|
||||
llm_logger.info(f"start expert service {local_data_parallel_id}")
|
||||
self.llm_logger.info(f"start expert service {local_data_parallel_id}")
|
||||
if self.cfg.splitwise_role != "mixed":
|
||||
self.cache_manager_processes = self.resource_manager.cache_manager.launch_cache_manager(
|
||||
cache_config=self.cfg.cache_config,
|
||||
tensor_parallel_size=self.cfg.tensor_parallel_size,
|
||||
device_ids=self.cfg.local_device_ids,
|
||||
pod_ip=self.cfg.pod_ips[0],
|
||||
pod_ip=self.cfg.master_ip,
|
||||
engine_worker_queue_port=self.cfg.engine_worker_queue_port,
|
||||
pid_suffix=f"{local_data_parallel_id}_{ipc_signal_suffix}",
|
||||
)
|
||||
@@ -139,7 +149,7 @@ class ExpertService:
|
||||
|
||||
# Start TokenProcessor thread
|
||||
os.environ["INFERENCE_MSG_QUEUE_ID"] = str(local_data_parallel_id + int(self.cfg.engine_worker_queue_port))
|
||||
|
||||
self.enable_decode_cache_task = envs.FD_ENABLE_CACHE_TASK
|
||||
self.token_processor.run()
|
||||
|
||||
self.cfg.init_cache_info()
|
||||
@@ -147,7 +157,11 @@ class ExpertService:
|
||||
role = self.cfg.splitwise_role
|
||||
host_ip = self.cfg.host_ip
|
||||
disaggregate = self.cfg.disaggregate_info
|
||||
self.scheduler.start(role, host_ip, disaggregate)
|
||||
if self.cfg.scheduler_config.name == "dp":
|
||||
assert (request_queues_for_dp_ipc is not None) and (result_queue_for_dp_ipc is not None)
|
||||
self.scheduler.start(local_data_parallel_id, request_queues_for_dp_ipc, result_queue_for_dp_ipc)
|
||||
elif self.cfg.scheduler_config.name == "splitwise":
|
||||
self.scheduler.start(role, host_ip, disaggregate)
|
||||
self.cfg.print()
|
||||
|
||||
console_logger.info(f"Worker processes are launched with {time.time() - start_time} seconds.")
|
||||
@@ -158,7 +172,7 @@ class ExpertService:
|
||||
Insert task to engine thread, monitor scheduler request queue.
|
||||
if the engine has resource, insert task to engine
|
||||
"""
|
||||
current_id = -1
|
||||
current_id = 0
|
||||
while True:
|
||||
try:
|
||||
if self.resource_manager.available_batch() == 0:
|
||||
@@ -167,15 +181,13 @@ class ExpertService:
|
||||
if self.engine_worker_queue.num_tasks() > 0:
|
||||
time.sleep(0.001)
|
||||
continue
|
||||
if len(self.split_connector.current_request_ids) > 0:
|
||||
time.sleep(0.001)
|
||||
continue
|
||||
|
||||
num_prefill_batch = min(
|
||||
int(self.resource_manager.available_batch()),
|
||||
self.cfg.max_prefill_batch,
|
||||
)
|
||||
|
||||
if envs.FD_ENABLE_INTERNAL_ADAPTER:
|
||||
num_prefill_batch = int(self.resource_manager.available_batch())
|
||||
self.resource_manager.check_and_free_block_tables()
|
||||
tasks = self.scheduler.get_requests(
|
||||
available_blocks=self.resource_manager.available_block_num(),
|
||||
@@ -190,73 +202,88 @@ class ExpertService:
|
||||
continue
|
||||
|
||||
if self.cfg.splitwise_role != "mixed":
|
||||
llm_logger.info("Inserting splitwise tasks")
|
||||
self.llm_logger.info("Inserting splitwise tasks")
|
||||
self.split_connector.send_splitwise_tasks(tasks, current_id)
|
||||
|
||||
current_id = (current_id + 1) % 100003
|
||||
|
||||
self.insert_tasks(tasks, current_id)
|
||||
insert_successful = self.insert_tasks(tasks, current_id)
|
||||
if insert_successful:
|
||||
current_id = current_id + 1
|
||||
else:
|
||||
continue
|
||||
|
||||
main_process_metrics.num_requests_waiting.dec(len(tasks))
|
||||
main_process_metrics.num_requests_running.inc(len(tasks))
|
||||
except Exception as e:
|
||||
err_msg = f"Error happend while insert task to engine: {e}, {traceback.format_exc()!s}."
|
||||
llm_logger.error(err_msg)
|
||||
err_msg = "Error happend while insert task to engine: {}, {}.".format(e, str(traceback.format_exc()))
|
||||
self.llm_logger.error(err_msg)
|
||||
|
||||
def split_mode_get_tasks(self):
|
||||
"""
|
||||
Split mode get tasks
|
||||
"""
|
||||
waiting_requests = []
|
||||
|
||||
def receiver_loop():
|
||||
while True:
|
||||
try:
|
||||
if len(waiting_requests) > 0:
|
||||
for task in waiting_requests:
|
||||
if self.resource_manager.is_resource_sufficient(task.prompt_token_ids_len):
|
||||
self.insert_tasks([task])
|
||||
waiting_requests.remove(task)
|
||||
else:
|
||||
break
|
||||
if not self.engine_worker_queue.disaggregate_queue_empty():
|
||||
items = self.engine_worker_queue.get_disaggregated_tasks()
|
||||
for item in items:
|
||||
role = item[0]
|
||||
tasks = item[1]
|
||||
if role == "prefill":
|
||||
llm_logger.info("get prefill tasks")
|
||||
for task in tasks:
|
||||
task.max_tokens = task.min_tokens = 2
|
||||
self.insert_tasks(tasks)
|
||||
elif role == "decode":
|
||||
llm_logger.info(f"get decode tasks {tasks}")
|
||||
if hasattr(tasks[0], "finished"):
|
||||
if not isinstance(tasks, list):
|
||||
tasks = [tasks]
|
||||
for task in tasks:
|
||||
task.finished = False
|
||||
# self.scheduler.put_results(tasks)
|
||||
|
||||
self.insert_tasks(tasks, allocated=True)
|
||||
processed_indices = []
|
||||
for idx, task in enumerate(self.waiting_requests):
|
||||
if self.resource_manager.is_resource_sufficient(task.prompt_token_ids_len):
|
||||
self.insert_tasks([task])
|
||||
self.llm_logger.info(f"Resource available, processing task {task.request_id}")
|
||||
processed_indices.append(idx)
|
||||
else:
|
||||
self.llm_logger.debug(f"Still waiting for resources {task.request_id}")
|
||||
break
|
||||
|
||||
for idx in sorted(processed_indices, reverse=True):
|
||||
self.waiting_requests.pop(idx)
|
||||
|
||||
if len(self.disaggregate_queue) > 0:
|
||||
items = self.disaggregate_queue.pop()
|
||||
role = items[0]
|
||||
tasks = items[1]
|
||||
if role == "prefill":
|
||||
for task in tasks:
|
||||
task.max_tokens = task.min_tokens = 2
|
||||
self.insert_tasks(tasks)
|
||||
elif role == "decode":
|
||||
if hasattr(tasks[0], "finished"):
|
||||
if not isinstance(tasks, list):
|
||||
tasks = [tasks]
|
||||
for task in tasks:
|
||||
task.finished = False
|
||||
self.insert_tasks(tasks, allocated=True)
|
||||
|
||||
if self.cfg.innode_prefill_ports is not None:
|
||||
self.scheduler.put_results(tasks)
|
||||
|
||||
else:
|
||||
if len(self.waiting_requests):
|
||||
self.llm_logger.info(f"Waiting for resource for task {tasks[0].request_id}")
|
||||
self.waiting_requests.extend(tasks)
|
||||
else:
|
||||
if len(waiting_requests):
|
||||
for task in tasks:
|
||||
waiting_requests.append(task)
|
||||
else:
|
||||
for task in tasks:
|
||||
if not self.resource_manager.is_resource_sufficient(
|
||||
task.prompt_token_ids_len
|
||||
):
|
||||
waiting_requests.append(task)
|
||||
else:
|
||||
self.insert_tasks([task])
|
||||
new_waiting = []
|
||||
for task in tasks:
|
||||
if self.resource_manager.is_resource_sufficient(task.prompt_token_ids_len):
|
||||
self.insert_tasks([task])
|
||||
else:
|
||||
if not self.enable_decode_cache_task:
|
||||
task.error_msg = "Not enough resources"
|
||||
new_waiting.append(task)
|
||||
if new_waiting:
|
||||
if not self.enable_decode_cache_task:
|
||||
self.split_connector.send_cache_infos(new_waiting, -1)
|
||||
else:
|
||||
self.waiting_requests.extend(new_waiting)
|
||||
self.llm_logger.info(f"Added {len(new_waiting)} tasks to waiting queue")
|
||||
|
||||
else:
|
||||
time.sleep(0.001)
|
||||
continue
|
||||
|
||||
except Exception as e:
|
||||
llm_logger.error(f"get decode tasks error: {e}")
|
||||
self.llm_logger.error(f"Error in main loop: {e} {str(traceback.format_exc())}")
|
||||
time.sleep(0.1)
|
||||
|
||||
threading.Thread(target=receiver_loop, daemon=True).start()
|
||||
|
||||
@@ -270,22 +297,32 @@ class ExpertService:
|
||||
cur_task_idx = self.resource_manager.req_dict[task.request_id]
|
||||
del self.resource_manager.req_dict[task.request_id]
|
||||
cur_task = self.resource_manager.tasks_list[cur_task_idx]
|
||||
if envs.FD_ENABLE_INTERNAL_ADAPTER:
|
||||
if not task.outputs.token_ids: # first token is eos in Prefill, just recycle resource and continue
|
||||
self.resource_manager.stop_flags[cur_task_idx] = True
|
||||
self.resource_manager.tasks_list[cur_task_idx] = None
|
||||
self.resource_manager._recycle_block_tables(cur_task)
|
||||
if task.request_id in self.token_processor.tokens_counter:
|
||||
del self.token_processor.tokens_counter[task.request_id]
|
||||
self.llm_logger.warning(f"{task.request_id} need not decode after first token")
|
||||
continue
|
||||
cur_task.prompt_token_ids[0] = task.outputs.token_ids[0]
|
||||
if self.cfg.speculative_config.method in ["mtp"] and self.cfg.splitwise_role == "decode":
|
||||
cur_task.draft_token_ids = copy.deepcopy(task.outputs.draft_token_ids)
|
||||
if task.error_code != 200:
|
||||
self.resource_manager.stop_flags[cur_task_idx] = True
|
||||
self.resource_manager.tasks_list[cur_task_idx] = None
|
||||
self.resource_manager._recycle_block_tables(cur_task)
|
||||
if task.request_id in self.token_processor.tokens_counter:
|
||||
del self.token_processor.tokens_counter[task.request_id]
|
||||
self.scheduler.put_results([task])
|
||||
llm_logger.warning(
|
||||
self.llm_logger.warning(
|
||||
f"{task.request_id} prefill failed with msg:{task.error_msg}, recycle resource."
|
||||
)
|
||||
continue
|
||||
llm_logger.info(f"{cur_task_idx} {task.request_id}")
|
||||
cur_task.prompt_token_ids[0] = task.outputs.token_ids[0]
|
||||
self.token_processor.tokens_counter[task.request_id] = 1
|
||||
current_tasks.append(cur_task)
|
||||
self.engine_worker_queue.put_tasks((current_tasks, self.resource_manager.real_bsz))
|
||||
if current_tasks:
|
||||
self.engine_worker_queue.put_tasks((current_tasks, self.resource_manager.real_bsz))
|
||||
return True
|
||||
|
||||
self.resource_manager.check_and_free_block_tables()
|
||||
@@ -293,22 +330,48 @@ class ExpertService:
|
||||
if not isinstance(tasks, list):
|
||||
tasks = [tasks]
|
||||
|
||||
need_delete_tasks = []
|
||||
for task in tasks:
|
||||
if self.cfg.splitwise_role != "mixed":
|
||||
status, msg = self.split_connector.check_decode_allocated(task)
|
||||
if not status:
|
||||
self.llm_logger.error(f"{task.request_id} prefill failed with msg:{msg}.")
|
||||
self.scheduler.put_results(
|
||||
[
|
||||
RequestOutput(
|
||||
request_id=task.request_id,
|
||||
finished=True,
|
||||
error_code=500,
|
||||
error_msg=msg,
|
||||
)
|
||||
]
|
||||
)
|
||||
need_delete_tasks.append(task)
|
||||
continue
|
||||
|
||||
for tmp_task in need_delete_tasks:
|
||||
tasks.remove(tmp_task)
|
||||
|
||||
for item in tasks:
|
||||
item.schedule_start_time = time.time()
|
||||
|
||||
req_ids = [t.request_id for t in tasks]
|
||||
|
||||
if len(tasks) == 0:
|
||||
return False
|
||||
available_batch = np.sum(self.resource_manager.stop_flags)
|
||||
if len(tasks) > available_batch:
|
||||
llm_logger.error(f"Inserting batch:{len(tasks)} exceeds the available batch:{available_batch}.")
|
||||
llm_logger.error("The exceeded part will be ignored!")
|
||||
self.llm_logger.error(
|
||||
"Inserting batch:{} exceeds the available batch:{}.".format(len(tasks), available_batch)
|
||||
)
|
||||
self.llm_logger.error("The exceeded part will be ignored!")
|
||||
tasks = tasks[:available_batch]
|
||||
|
||||
req_ids = [t.request_id for t in tasks]
|
||||
|
||||
tasks = self.resource_manager.allocate_resources_for_new_tasks(tasks)
|
||||
|
||||
if not tasks:
|
||||
error_msg = f"The request required resources is exceed the limit, request id={req_ids}."
|
||||
llm_logger.error(error_msg)
|
||||
self.llm_logger.error(error_msg)
|
||||
raise EngineError(error_msg, error_code=500)
|
||||
return False
|
||||
|
||||
@@ -328,7 +391,7 @@ class ExpertService:
|
||||
for task in tasks:
|
||||
task.infer_start_time = time.time()
|
||||
if not is_decode:
|
||||
llm_logger.info(f"Tasks are sent to engine, req_ids={req_ids}")
|
||||
self.llm_logger.info(f"Tasks are sent to engine, req_ids={req_ids}")
|
||||
if not is_prefill and self.cfg.cache_config.enable_chunked_prefill:
|
||||
if not self.cfg.enable_mm:
|
||||
self.update_requests_chunk_size(tasks)
|
||||
@@ -346,7 +409,7 @@ class ExpertService:
|
||||
self.resource_manager.cache_manager.shm_cache_task_flag_broadcast.clear()
|
||||
self.resource_manager.cache_manager.cache_ready_signal.clear()
|
||||
for p in self.cache_manager_processes:
|
||||
llm_logger.info(f"Killing cache manager process {p.pid}")
|
||||
self.llm_logger.info(f"Killing cache manager process {p.pid}")
|
||||
try:
|
||||
os.killpg(p.pid, signal.SIGTERM)
|
||||
except:
|
||||
@@ -356,13 +419,17 @@ class ExpertService:
|
||||
self.zmq_server.close()
|
||||
|
||||
|
||||
def start_expert_service(cfg, local_data_parallel_id, ipc_signal_suffix):
|
||||
def start_expert_service(
|
||||
cfg, local_data_parallel_id, ipc_signal_suffix, request_queues_for_dp_ipc=None, result_queue_for_dp_ipc=None
|
||||
):
|
||||
"""
|
||||
Start expert service
|
||||
"""
|
||||
expert_service = ExpertService(cfg, local_data_parallel_id)
|
||||
try:
|
||||
expert_service.start(ipc_signal_suffix, local_data_parallel_id)
|
||||
expert_service.start(
|
||||
ipc_signal_suffix, local_data_parallel_id, request_queues_for_dp_ipc, result_queue_for_dp_ipc
|
||||
)
|
||||
expert_service.split_connector.start_receiver()
|
||||
except Exception as e:
|
||||
llm_logger.exception(f"Expert service failed to start: {e}")
|
||||
|
@@ -71,6 +71,7 @@ class Request:
|
||||
guided_json_object: Optional[bool] = None,
|
||||
enable_thinking: Optional[bool] = True,
|
||||
trace_carrier: dict = dict(),
|
||||
dp_rank: Optional[int] = None
|
||||
) -> None:
|
||||
self.request_id = request_id
|
||||
self.prompt = prompt
|
||||
@@ -119,6 +120,7 @@ class Request:
|
||||
self.task_type = RequestType.PREFILL
|
||||
self.idx = None
|
||||
self.need_prefill_tokens = self.prompt_token_ids_len
|
||||
self.dp_rank = dp_rank
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, d: dict):
|
||||
@@ -151,6 +153,7 @@ class Request:
|
||||
guided_json_object=d.get("guided_json_object", None),
|
||||
enable_thinking=d.get("enable_thinking", True),
|
||||
trace_carrier=d.get("trace_carrier", {}),
|
||||
dp_rank=d.get("dp_rank", None)
|
||||
)
|
||||
|
||||
@property
|
||||
|
@@ -22,7 +22,7 @@ import numpy as np
|
||||
|
||||
from fastdeploy.cache_manager.prefix_cache_manager import PrefixCacheManager
|
||||
from fastdeploy.metrics.metrics import main_process_metrics
|
||||
from fastdeploy.utils import llm_logger
|
||||
from fastdeploy.utils import get_logger, llm_logger
|
||||
|
||||
|
||||
class ResourceManager:
|
||||
@@ -49,16 +49,23 @@ class ResourceManager:
|
||||
Initializes the engine with the given configuration and sets up necessary
|
||||
data structures to manage tasks and blocks.
|
||||
"""
|
||||
if local_data_parallel_id > 0:
|
||||
self.logger = get_logger(
|
||||
f"expert_service_{local_data_parallel_id}", f"expert_service_{local_data_parallel_id}.log"
|
||||
)
|
||||
else:
|
||||
self.logger = llm_logger
|
||||
self.cfg = config.cache_config
|
||||
self.max_num_seqs = max_num_seqs
|
||||
self.stop_flags = [True] * max_num_seqs
|
||||
self.stop_flags = [True] * max_num_seqs # flag set to true if the slot has not been taken
|
||||
self.enable_prefix_cache = config.cache_config.enable_prefix_caching
|
||||
self.cache_manager = PrefixCacheManager(config, tensor_parallel_size, splitwise_role, local_data_parallel_id)
|
||||
self.tasks_list = [None] * max_num_seqs
|
||||
self.tasks_list = [None] * max_num_seqs # task slots
|
||||
self.req_dict = dict()
|
||||
# current batch status of the engine
|
||||
self.real_bsz = 0
|
||||
llm_logger.info(f"{self.info()}")
|
||||
self.logger.info(f"{self.info()}")
|
||||
main_process_metrics.max_batch_size.set(max_num_seqs)
|
||||
|
||||
def reset_cache_config(self, cfg):
|
||||
"""
|
||||
@@ -134,10 +141,10 @@ class ResourceManager:
|
||||
block_list = list()
|
||||
current_block_num = self.available_block_num()
|
||||
if block_num > current_block_num:
|
||||
llm_logger.error(f"block_num:{block_num} > free_list len:{current_block_num}")
|
||||
self.logger.error("block_num:{0} > free_list len:{1}".format(block_num, current_block_num))
|
||||
return block_list
|
||||
block_list = self.cache_manager.allocate_gpu_blocks(block_num)
|
||||
llm_logger.debug(f"dispatch {len(block_list)} blocks.")
|
||||
self.logger.debug(f"dispatch {len(block_list)} blocks.")
|
||||
return block_list
|
||||
|
||||
def check_and_free_block_tables(self):
|
||||
@@ -169,7 +176,7 @@ class ResourceManager:
|
||||
self.cache_manager.recycle_gpu_blocks(block_tables)
|
||||
cur_number = self.available_block_num()
|
||||
main_process_metrics.gpu_cache_usage_perc.set(self.get_gpu_cache_usage_perc())
|
||||
llm_logger.info(f"recycle {req_id} {cur_number - ori_number} blocks.")
|
||||
self.logger.info(f"recycle {req_id} {cur_number - ori_number} blocks.")
|
||||
|
||||
def available_batch(self):
|
||||
"""
|
||||
@@ -222,47 +229,47 @@ class ResourceManager:
|
||||
Returns:
|
||||
list: processed task list
|
||||
"""
|
||||
|
||||
allocated_position = 0
|
||||
processing_task_index = 0
|
||||
llm_logger.debug(f"Allocating resources for a batch of new tasks: {tasks}")
|
||||
allocated_position = 0 # number of tasks that have been allocated, also the position in request slots
|
||||
processing_task_index = 0 # current task
|
||||
processed_tasks = list()
|
||||
while allocated_position < self.max_num_seqs:
|
||||
if processing_task_index >= len(tasks):
|
||||
while allocated_position < self.max_num_seqs: # loop until all tasks are allocated resources for
|
||||
if processing_task_index >= len(tasks): # if all taskes have been tried, don't give a second chance
|
||||
break
|
||||
|
||||
can_insert = False
|
||||
while allocated_position + 1 <= self.max_num_seqs:
|
||||
if sum(self.stop_flags[allocated_position : allocated_position + 1]) == 1:
|
||||
can_insert = True
|
||||
can_insert = True # if there is a empty slot, try to allocate resources for current task
|
||||
break
|
||||
allocated_position += 1
|
||||
if can_insert:
|
||||
if self.stop_flags[allocated_position]:
|
||||
|
||||
task = tasks[processing_task_index]
|
||||
task = tasks[processing_task_index] # retrieve current task
|
||||
|
||||
if task.get("seed") is None:
|
||||
task.set("seed", random.randint(0, 9223372036854775807))
|
||||
task.idx = allocated_position
|
||||
|
||||
if self.enable_prefix_cache:
|
||||
if self.enable_prefix_cache: # if prefix caching is enabled
|
||||
# 1. request for enough blocks for current task
|
||||
cache_prepare_time = time.time()
|
||||
common_block_ids, unique_block_ids, hit_info = self.cache_manager.request_block_ids(
|
||||
task,
|
||||
self.cfg.block_size,
|
||||
self.cfg.dec_token_num,
|
||||
task, self.cfg.block_size, self.cfg.dec_token_num
|
||||
)
|
||||
if unique_block_ids is None:
|
||||
llm_logger.warning("req_id: {0} not enough blocks available".format(task["req_id"]))
|
||||
self.logger.warning("req_id: {0} not enough blocks available".format(task["req_id"]))
|
||||
return
|
||||
|
||||
# 2. record cache hit information, and return the number of tokens already in cache
|
||||
cached_len = self._record_request_cache_info(
|
||||
task, common_block_ids, unique_block_ids, hit_info
|
||||
)
|
||||
task.cache_prepare_time = time.time() - cache_prepare_time
|
||||
|
||||
# 3. if prefill/decode disaggregation is enabled
|
||||
if task.disaggregate_info is not None:
|
||||
if task.disaggregate_info["role"] == "prefill":
|
||||
# record the slot position for current task, indexed by request id
|
||||
self.req_dict[task.request_id] = allocated_position
|
||||
task.disaggregate_info["block_tables"] = task.block_tables
|
||||
self._delete_cached_data(task, cached_len)
|
||||
@@ -270,17 +277,19 @@ class ResourceManager:
|
||||
self.req_dict[task.request_id] = allocated_position
|
||||
task.disaggregate_info["block_tables"] = task.need_block_tables
|
||||
else:
|
||||
# remove cached tokens from prompt token ids to avoid kv recomputation
|
||||
self._delete_cached_data(task, cached_len)
|
||||
|
||||
else:
|
||||
else: # if prefix caching is disabled
|
||||
# 1. directly allocate empty block from the cache, if there is any
|
||||
block_tables = self._get_block_tables(task.prompt_token_ids_len)
|
||||
if not block_tables:
|
||||
llm_logger.error(f"req_id: {task.request_id} block_tables is empty")
|
||||
continue
|
||||
continue # retry
|
||||
else:
|
||||
task.block_tables = block_tables
|
||||
task.need_block_tables = task.block_tables
|
||||
|
||||
# 2. if prefill/decode disaggregation is enabled
|
||||
if task.disaggregate_info is not None:
|
||||
task.disaggregate_info["block_tables"] = block_tables
|
||||
if task.disaggregate_info["role"] == "prefill":
|
||||
@@ -288,13 +297,13 @@ class ResourceManager:
|
||||
elif task.disaggregate_info["role"] == "decode":
|
||||
self.req_dict[task.request_id] = allocated_position
|
||||
|
||||
processed_tasks.append(task)
|
||||
self.stop_flags[allocated_position] = False
|
||||
processed_tasks.append(task) # add current task
|
||||
self.stop_flags[allocated_position] = False # mark the slot as occupied
|
||||
task.inference_start_time = time.time()
|
||||
task.inference_time_cost = -1.0
|
||||
task.tokens_all_num = 0
|
||||
self.tasks_list[allocated_position] = task
|
||||
llm_logger.info(
|
||||
self.logger.info(
|
||||
f"Allocate request: {task.request_id}, "
|
||||
f"allocated_position:{allocated_position}, "
|
||||
f"length of prompt token: {task.prompt_token_ids_len}"
|
||||
@@ -303,15 +312,22 @@ class ResourceManager:
|
||||
processing_task_index += 1
|
||||
|
||||
# batch size when the statistical engine is inferring
|
||||
# determine batch size by index of the first slot that is not occupied
|
||||
for i in range(self.max_num_seqs - 1, -1, -1):
|
||||
if not self.stop_flags[i]:
|
||||
self.real_bsz = i + 1
|
||||
break
|
||||
|
||||
llm_logger.info(
|
||||
# record batch size here
|
||||
task_used_block_num = sum([len(task.block_tables) if task else 0 for task in self.tasks_list])
|
||||
main_process_metrics.available_gpu_block_num.set(self.total_block_number() - task_used_block_num)
|
||||
main_process_metrics.batch_size.set(self.max_num_seqs - self.available_batch())
|
||||
main_process_metrics.gpu_cache_usage_perc.set(self.get_gpu_cache_usage_perc())
|
||||
|
||||
self.logger.info(
|
||||
f"Number of allocated requests: {len(tasks)}, number of " f"running requests in worker: {self.real_bsz}"
|
||||
)
|
||||
llm_logger.info(f"{self.info()}")
|
||||
self.logger.info(f"{self.info()}")
|
||||
main_process_metrics.gpu_cache_usage_perc.set(self.get_gpu_cache_usage_perc())
|
||||
|
||||
return processed_tasks
|
||||
@@ -321,8 +337,8 @@ class ResourceManager:
|
||||
Delete cached data from the task's prompt token ids based on the cached length.
|
||||
"""
|
||||
if cached_len == len(task.prompt_token_ids):
|
||||
task.prompt_token_ids = task.prompt_token_ids[cached_len - 1 :]
|
||||
task.seq_lens_decoder = cached_len - 1
|
||||
task.prompt_token_ids = task.prompt_token_ids[cached_len - self.cfg.block_size :]
|
||||
task.seq_lens_decoder = cached_len - self.cfg.block_size
|
||||
else:
|
||||
task.prompt_token_ids = task.prompt_token_ids[cached_len:]
|
||||
task.seq_lens_decoder = cached_len
|
||||
@@ -339,11 +355,16 @@ class ResourceManager:
|
||||
task.cpu_cache_token_num = hit_info["cpu_cache_blocks"] * self.cfg.block_size
|
||||
task.cache_info = (cache_block_num, no_cache_block_num)
|
||||
|
||||
# Report the number of cached tokens to Prometheus metrics
|
||||
main_process_metrics.prefix_cache_token_num.inc(task.num_cached_tokens)
|
||||
main_process_metrics.prefix_gpu_cache_token_num.inc(task.gpu_cache_token_num)
|
||||
main_process_metrics.prefix_cpu_cache_token_num.inc(task.cpu_cache_token_num)
|
||||
|
||||
cached_len = len(common_block_ids) * self.cfg.block_size
|
||||
task.block_tables = common_block_ids + unique_block_ids
|
||||
task.need_block_tables = unique_block_ids
|
||||
llm_logger.debug(f"common: {common_block_ids} ")
|
||||
llm_logger.debug(f"unique: {unique_block_ids} ")
|
||||
self.logger.debug(f"common: {common_block_ids} ")
|
||||
self.logger.debug(f"unique: {unique_block_ids} ")
|
||||
return cached_len
|
||||
|
||||
def info(self):
|
||||
|
@@ -27,6 +27,7 @@ import paddle
|
||||
|
||||
from fastdeploy.engine.request import Request, RequestStatus, RequestType
|
||||
from fastdeploy.engine.resource_manager import ResourceManager
|
||||
from fastdeploy.metrics.metrics import main_process_metrics
|
||||
from fastdeploy.utils import llm_logger
|
||||
|
||||
|
||||
@@ -75,6 +76,7 @@ class ResourceManagerV1(ResourceManager):
|
||||
self.running: list[Request] = []
|
||||
self.finish_execution_pool = ThreadPoolExecutor(max_workers=1)
|
||||
self.lock = threading.Lock()
|
||||
main_process_metrics.max_batch_size.set(max_num_seqs)
|
||||
|
||||
def allocated_slots(self, request: Request):
|
||||
return len(request.block_tables) * self.config.cache_config.block_size
|
||||
@@ -98,6 +100,9 @@ class ResourceManagerV1(ResourceManager):
|
||||
return ScheduledPreemptTask(idx=request.idx, request_id=request.request_id)
|
||||
|
||||
def _trigger_preempt(self, request, num_new_blocks, preempted_reqs, scheduled_reqs):
|
||||
"""
|
||||
If the request cannot be scheduled, preempt the running request one by one until it can be scheduled. Last in, first out.
|
||||
"""
|
||||
can_schedule = True
|
||||
while True:
|
||||
if not self.cache_manager.can_allocate_gpu_blocks(num_new_blocks):
|
||||
@@ -201,6 +206,9 @@ class ResourceManagerV1(ResourceManager):
|
||||
return False
|
||||
|
||||
def schedule(self):
|
||||
"""
|
||||
Try to pull a batch of requests from the waiting queue and schedule them.
|
||||
"""
|
||||
with self.lock:
|
||||
scheduled_reqs: list[Request] = []
|
||||
preempted_reqs: list[Request] = []
|
||||
@@ -262,7 +270,7 @@ class ResourceManagerV1(ResourceManager):
|
||||
request.block_tables.extend(self.cache_manager.allocate_gpu_blocks(num_new_block))
|
||||
# Prepare prefill task
|
||||
scheduled_reqs.append(self._prepare_prefill_task(request, num_new_tokens))
|
||||
else:
|
||||
else: # Not enough blocks to allocate, trigger preemption
|
||||
can_schedule = self._trigger_preempt(request, num_new_block, preempted_reqs, scheduled_reqs)
|
||||
if not can_schedule:
|
||||
break
|
||||
@@ -328,6 +336,10 @@ class ResourceManagerV1(ResourceManager):
|
||||
else:
|
||||
llm_logger.error("Unknown request status type")
|
||||
if scheduled_reqs:
|
||||
task_used_block_num = sum([len(task.block_tables) if task else 0 for task in self.tasks_list])
|
||||
main_process_metrics.available_gpu_block_num.set(self.total_block_number() - task_used_block_num)
|
||||
main_process_metrics.batch_size.set(self.max_num_seqs - self.available_batch())
|
||||
main_process_metrics.gpu_cache_usage_perc.set(self.get_gpu_cache_usage_perc())
|
||||
llm_logger.debug(f"schedued_reqs: {scheduled_reqs}")
|
||||
return scheduled_reqs
|
||||
|
||||
@@ -369,6 +381,11 @@ class ResourceManagerV1(ResourceManager):
|
||||
request.block_tables = common_block_ids
|
||||
request.skip_allocate = False
|
||||
|
||||
# Report the number of cached tokens to Prometheus metrics
|
||||
main_process_metrics.prefix_cache_token_num.inc(matched_token_num)
|
||||
main_process_metrics.prefix_gpu_cache_token_num.inc(request.gpu_cache_token_num)
|
||||
main_process_metrics.prefix_cpu_cache_token_num.inc(request.cpu_cache_token_num)
|
||||
|
||||
if matched_token_num == request.prompt_token_ids_len:
|
||||
request.num_computed_tokens = matched_token_num - 1
|
||||
request.skip_allocate = True
|
||||
|
@@ -21,7 +21,7 @@ import numpy as np
|
||||
|
||||
from fastdeploy.engine.config import ModelConfig
|
||||
from fastdeploy.input.preprocess import InputPreprocessor
|
||||
from fastdeploy.inter_communicator import IPCSignal, ZmqClient
|
||||
from fastdeploy.inter_communicator import IPCSignal, ZmqIpcClient
|
||||
from fastdeploy.metrics.work_metrics import work_process_metrics
|
||||
from fastdeploy.multimodal.registry import MultimodalRegistry
|
||||
from fastdeploy.platforms import current_platform
|
||||
@@ -90,7 +90,7 @@ class EngineClient:
|
||||
"""
|
||||
Create a ZMQ client.
|
||||
"""
|
||||
self.zmq_client = ZmqClient(model, mode)
|
||||
self.zmq_client = ZmqIpcClient(model, mode)
|
||||
self.zmq_client.connect()
|
||||
|
||||
def format_and_add_data(self, prompts: dict):
|
||||
|
@@ -177,6 +177,8 @@ class OpenAIServingChat:
|
||||
for res in response:
|
||||
if res.get("error_code", 200) != 200:
|
||||
raise ValueError("{}".format(res["error_msg"]))
|
||||
if res["finished"]:
|
||||
api_server_logger.info(f"chat completion finished: {request_id}")
|
||||
|
||||
self.engine_client.data_processor.process_response_dict(
|
||||
res,
|
||||
|
@@ -80,8 +80,20 @@ environment_variables: dict[str, Callable[[], Any]] = {
|
||||
"EXPORTER_OTLP_HEADERS": lambda: os.getenv("EXPORTER_OTLP_HEADERS"),
|
||||
# enable kv cache block scheduler v1 (no need for kv_cache_ratio)
|
||||
"ENABLE_V1_KVCACHE_SCHEDULER": lambda: int(os.getenv("ENABLE_V1_KVCACHE_SCHEDULER", "0")),
|
||||
# enable internal module to access LLMEngine.
|
||||
"FD_ENABLE_INTERNAL_ADAPTER": lambda: int(os.getenv("FD_ENABLE_INTERNAL_ADAPTER", "0")),
|
||||
# LLMEngine recieve requests port, used when FD_ENABLE_INTERNAL_ADAPTER=1
|
||||
"FD_ZMQ_RECV_REQUEST_SERVER_PORT": lambda: os.getenv("FD_ZMQ_RECV_REQUEST_SERVER_PORT", "8200"),
|
||||
# LLMEngine send response port, used when FD_ENABLE_INTERNAL_ADAPTER=1
|
||||
"FD_ZMQ_SEND_RESPONSE_SERVER_PORT": lambda: os.getenv("FD_ZMQ_SEND_RESPONSE_SERVER_PORT", "8201"),
|
||||
# LLMEngine recieve control command port, used when FD_ENABLE_INTERNAL_ADAPTER=1
|
||||
"FD_ZMQ_CONTROL_CMD_SERVER_PORTS": lambda: os.getenv("FD_ZMQ_CONTROL_CMD_SERVER_PORTS", "8202"),
|
||||
# Batched token timeout in EP
|
||||
"FD_EP_BATCHED_TOKEN_TIMEOUT": lambda: float(os.getenv("FD_EP_BATCHED_TOKEN_TIMEOUT", "0.1")),
|
||||
# Whether to use PLUGINS.
|
||||
"FD_PLUGINS": lambda: None if "FD_PLUGINS" not in os.environ else os.environ["FD_PLUGINS"].split(","),
|
||||
# Whether to enable cache task in decode node
|
||||
"FD_ENABLE_CACHE_TASK": lambda: os.getenv("FD_ENABLE_CACHE_TASK", "1"),
|
||||
}
|
||||
|
||||
|
||||
|
@@ -17,6 +17,7 @@
|
||||
from .engine_cache_queue import EngineCacheQueue
|
||||
from .engine_worker_queue import EngineWorkerQueue
|
||||
from .ipc_signal import IPCSignal
|
||||
from .zmq_client import ZmqClient
|
||||
from .zmq_client import ZmqIpcClient
|
||||
from .zmq_server import ZmqIpcServer, ZmqTcpServer
|
||||
|
||||
__all__ = ["ZmqClient", "IPCSignal", "EngineWorkerQueue", "EngineCacheQueue"]
|
||||
__all__ = ["ZmqIpcClient", "IPCSignal", "EngineWorkerQueue", "EngineCacheQueue", "ZmqTcpServer", "ZmqIpcServer"]
|
||||
|
@@ -85,12 +85,15 @@ class EngineWorkerQueue:
|
||||
]
|
||||
self.finished_req_queue = [Queue() for _ in range(self.local_data_parallel_size)]
|
||||
self.cache_infos_init: List[List[Any]] = [list() for _ in range(self.local_data_parallel_size)]
|
||||
self.connect_rdma_tasks_list = [list() for _ in range(self.local_data_parallel_size)]
|
||||
self.connect_rdma_tasks_response_list = [list() for _ in range(self.local_data_parallel_size)]
|
||||
self.client_read_info_flag_init: List[List[int]] = [
|
||||
[1] * self.num_client for _ in range(self.local_data_parallel_size)
|
||||
]
|
||||
self.lock_info_init: List[threading.Lock] = [
|
||||
threading.Lock() for _ in range(self.local_data_parallel_size)
|
||||
]
|
||||
self.connect_task_lock_init: List[threading.Lock] = [threading.Lock() for _ in range(self.local_data_parallel_size)]
|
||||
|
||||
self.finish_request_barrier = [
|
||||
threading.Barrier(self.num_client) for _ in range(self.local_data_parallel_size)
|
||||
@@ -112,11 +115,26 @@ class EngineWorkerQueue:
|
||||
callable=lambda idx: self.lock_init[idx],
|
||||
proxytype=AcquirerProxy,
|
||||
)
|
||||
QueueManager.register(
|
||||
"get_connect_task_lock",
|
||||
callable=lambda idx: self.connect_task_lock_init[idx],
|
||||
proxytype=AcquirerProxy,
|
||||
)
|
||||
QueueManager.register(
|
||||
"get_read_finish_flag",
|
||||
callable=lambda idx: self.read_finish_flag_init[idx],
|
||||
proxytype=ValueProxy,
|
||||
)
|
||||
QueueManager.register(
|
||||
"get_connect_rdma_tasks",
|
||||
callable=lambda idx: self.connect_rdma_tasks_list[idx],
|
||||
proxytype=ListProxy
|
||||
)
|
||||
QueueManager.register(
|
||||
"get_connect_rdma_tasks_responses",
|
||||
callable=lambda idx: self.connect_rdma_tasks_response_list[idx],
|
||||
proxytype=ListProxy
|
||||
)
|
||||
QueueManager.register(
|
||||
"get_connected_client_counter",
|
||||
callable=lambda idx: self.connected_client_counter_init[idx],
|
||||
@@ -180,6 +198,9 @@ class EngineWorkerQueue:
|
||||
QueueManager.register("get_disaggregate_requests")
|
||||
QueueManager.register("get_available_prefill_instances")
|
||||
QueueManager.register("get_finish_request_barrier")
|
||||
QueueManager.register("get_connect_rdma_tasks")
|
||||
QueueManager.register("get_connect_rdma_tasks_responses")
|
||||
QueueManager.register("get_connect_task_lock")
|
||||
self.manager = QueueManager(address=self.address, authkey=self.authkey)
|
||||
self._connect_with_retry()
|
||||
|
||||
@@ -200,6 +221,13 @@ class EngineWorkerQueue:
|
||||
self.available_prefill_instances = self.manager.get_available_prefill_instances()
|
||||
self.finish_request_barrier = self.manager.get_finish_request_barrier(self.local_data_parallel_id)
|
||||
self.finished_req_queue = self.manager.get_finish_request_queue(self.local_data_parallel_id)
|
||||
# p/d互联
|
||||
self.connect_rdma_task_queue = self.manager.get_connect_rdma_tasks(self.local_data_parallel_id)
|
||||
self.connect_rdma_task_response_queue = self.manager.get_connect_rdma_tasks_responses(
|
||||
self.local_data_parallel_id
|
||||
)
|
||||
self.connect_task_lock = self.manager.get_connect_task_lock(self.local_data_parallel_id)
|
||||
|
||||
assert self.num_client == len(self.client_read_flag)
|
||||
|
||||
if is_server:
|
||||
@@ -280,6 +308,45 @@ class EngineWorkerQueue:
|
||||
total_num: int = len(self.tasks)
|
||||
self.lock.release()
|
||||
return total_num
|
||||
|
||||
def put_connect_rdma_task(self, connect_rdma_task):
|
||||
self.connect_task_lock.acquire()
|
||||
self.connect_rdma_task_queue.append(connect_rdma_task)
|
||||
self.connect_task_lock.release()
|
||||
|
||||
def get_connect_rdma_task(self):
|
||||
result = None
|
||||
self.connect_task_lock.acquire()
|
||||
if len(self.connect_rdma_task_queue) == 0:
|
||||
self.connect_task_lock.release()
|
||||
return result
|
||||
try:
|
||||
result = self.connect_rdma_task_queue.pop(0)
|
||||
except Exception as e:
|
||||
llm_logger.info(f"get_connect_rdma_task got exception: {e}")
|
||||
finally:
|
||||
self.connect_task_lock.release()
|
||||
return result
|
||||
|
||||
def put_connect_rdma_task_response(self, connect_rdma_task_response):
|
||||
self.connect_task_lock.acquire()
|
||||
self.connect_rdma_task_response_queue.append(connect_rdma_task_response)
|
||||
self.connect_task_lock.release()
|
||||
|
||||
def get_connect_rdma_task_response(self):
|
||||
result = None
|
||||
self.connect_task_lock.acquire()
|
||||
if len(self.connect_rdma_task_response_queue) == 0:
|
||||
self.connect_task_lock.release()
|
||||
return result
|
||||
try:
|
||||
result = self.connect_rdma_task_response_queue.pop(0)
|
||||
except Exception as e:
|
||||
llm_logger.info(f"get_connect_rdma_task_response got exception: {e}")
|
||||
finally:
|
||||
self.connect_task_lock.release()
|
||||
return result
|
||||
|
||||
|
||||
def get_prefill_instances(self):
|
||||
"""
|
||||
|
@@ -14,200 +14,78 @@
|
||||
# limitations under the License.
|
||||
"""
|
||||
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
import msgpack
|
||||
import zmq
|
||||
|
||||
from fastdeploy import envs
|
||||
from fastdeploy.utils import llm_logger
|
||||
|
||||
|
||||
class ZmqClient:
|
||||
class ZmqClientBase(ABC):
|
||||
"""
|
||||
ZmqClient is a class that provides a client-side interface for sending and receiving messages using ZeroMQ.
|
||||
ZmqClientBase is a base class that provides a client-side interface for sending and receiving messages using ZeroMQ.
|
||||
"""
|
||||
|
||||
def __init__(self, name, mode):
|
||||
self.context = zmq.Context()
|
||||
self.socket = self.context.socket(mode)
|
||||
self.file_name = f"/dev/shm/{name}.socket"
|
||||
self.router_path = f"/dev/shm/router_{name}.ipc"
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
self.ZMQ_SNDHWM = int(envs.FD_ZMQ_SNDHWM)
|
||||
self.aggregate_send = envs.FD_USE_AGGREGATE_SEND
|
||||
@abstractmethod
|
||||
def _create_socket(self):
|
||||
"""Abstract method to create and return a ZeroMQ socket."""
|
||||
pass
|
||||
|
||||
self.mutex = threading.Lock()
|
||||
self.req_dict = dict()
|
||||
self.router = None
|
||||
self.poller = None
|
||||
self.running = True
|
||||
def _ensure_socket(self):
|
||||
"""Ensure the socket is created before use."""
|
||||
if self.socket is None:
|
||||
self.socket = self._create_socket()
|
||||
|
||||
@abstractmethod
|
||||
def connect(self):
|
||||
"""
|
||||
Connect to the server using the file name specified in the constructor.
|
||||
"""
|
||||
self.socket.connect(f"ipc://{self.file_name}")
|
||||
|
||||
def start_server(self):
|
||||
"""
|
||||
Start the server using the file name specified in the constructor.
|
||||
"""
|
||||
self.socket.setsockopt(zmq.SNDHWM, self.ZMQ_SNDHWM)
|
||||
self.socket.setsockopt(zmq.SNDTIMEO, -1)
|
||||
self.socket.bind(f"ipc://{self.file_name}")
|
||||
self.poller = zmq.Poller()
|
||||
self.poller.register(self.socket, zmq.POLLIN)
|
||||
|
||||
def create_router(self):
|
||||
"""
|
||||
Create a ROUTER socket and bind it to the specified router path.
|
||||
"""
|
||||
self.router = self.context.socket(zmq.ROUTER)
|
||||
self.router.setsockopt(zmq.SNDHWM, self.ZMQ_SNDHWM)
|
||||
self.router.setsockopt(zmq.SNDTIMEO, -1)
|
||||
self.router.bind(f"ipc://{self.router_path}")
|
||||
pass
|
||||
|
||||
def send_json(self, data):
|
||||
"""
|
||||
Send a JSON-serializable object over the socket.
|
||||
"""
|
||||
self._ensure_socket()
|
||||
self.socket.send_json(data)
|
||||
|
||||
def recv_json(self):
|
||||
"""
|
||||
Receive a JSON-serializable object from the socket.
|
||||
"""
|
||||
self._ensure_socket()
|
||||
return self.socket.recv_json()
|
||||
|
||||
def send_pyobj(self, data):
|
||||
"""
|
||||
Send a Pickle-serializable object over the socket.
|
||||
"""
|
||||
self._ensure_socket()
|
||||
self.socket.send_pyobj(data)
|
||||
|
||||
def recv_pyobj(self):
|
||||
"""
|
||||
Receive a Pickle-serializable object from the socket.
|
||||
"""
|
||||
self._ensure_socket()
|
||||
return self.socket.recv_pyobj()
|
||||
|
||||
def pack_aggregated_data(self, data):
|
||||
"""
|
||||
Aggregate multiple responses into one and send them to the client.
|
||||
"""
|
||||
result = data[0]
|
||||
if len(data) > 1:
|
||||
for response in data[1:]:
|
||||
result.add(response)
|
||||
result = msgpack.packb([result.to_dict()])
|
||||
return result
|
||||
|
||||
def send_multipart(self, req_id, data):
|
||||
"""
|
||||
Send a multipart message to the router socket.
|
||||
"""
|
||||
if self.router is None:
|
||||
raise RuntimeError("Router socket not created. Call create_router() first.")
|
||||
class ZmqIpcClient(ZmqClientBase):
|
||||
def __init__(self, name, mode):
|
||||
self.name = name
|
||||
self.mode = mode
|
||||
self.file_name = f"/dev/shm/{name}.socket"
|
||||
self.context = zmq.Context()
|
||||
self.socket = self.context.socket(self.mode)
|
||||
|
||||
while self.running:
|
||||
with self.mutex:
|
||||
if req_id not in self.req_dict:
|
||||
try:
|
||||
client, _, request_id = self.router.recv_multipart(flags=zmq.NOBLOCK)
|
||||
req_id_str = request_id.decode("utf-8")
|
||||
self.req_dict[req_id_str] = client
|
||||
except zmq.Again:
|
||||
time.sleep(0.001)
|
||||
continue
|
||||
else:
|
||||
break
|
||||
def _create_socket(self):
|
||||
"""create and return a ZeroMQ socket."""
|
||||
self.context = zmq.Context()
|
||||
return self.context.socket(self.mode)
|
||||
|
||||
try:
|
||||
start_send = time.time()
|
||||
if self.aggregate_send:
|
||||
result = self.pack_aggregated_data(data)
|
||||
else:
|
||||
result = msgpack.packb([response.to_dict() for response in data])
|
||||
self.router.send_multipart([self.req_dict[req_id], b"", result])
|
||||
llm_logger.debug(f"send_multipart result: {req_id} len {len(data)} elapse: {time.time()-start_send}")
|
||||
|
||||
except Exception as e:
|
||||
llm_logger.error(f"Send result to zmq client failed: {e}")
|
||||
|
||||
if data[-1].finished:
|
||||
with self.mutex:
|
||||
self.req_dict.pop(req_id, None)
|
||||
llm_logger.info(f"send_multipart finished, req_id: {req_id}")
|
||||
|
||||
def receive_json_once(self, block=False):
|
||||
"""
|
||||
Receive a single message from the socket.
|
||||
"""
|
||||
if self.socket is None or self.socket.closed:
|
||||
return "zmp socket has closed", None
|
||||
try:
|
||||
flags = zmq.NOBLOCK if not block else 0
|
||||
return None, self.socket.recv_json(flags=flags)
|
||||
except zmq.Again:
|
||||
return None, None
|
||||
except Exception as e:
|
||||
self.close()
|
||||
llm_logger.warning(f"{e}")
|
||||
return str(e), None
|
||||
|
||||
def receive_pyobj_once(self, block=False):
|
||||
"""
|
||||
Receive a single message from the socket.
|
||||
"""
|
||||
if self.socket is None or self.socket.closed:
|
||||
return "zmp socket has closed", None
|
||||
try:
|
||||
flags = zmq.NOBLOCK if not block else 0
|
||||
return None, self.socket.recv_pyobj(flags=flags)
|
||||
except zmq.Again:
|
||||
return None, None
|
||||
except Exception as e:
|
||||
self.close()
|
||||
llm_logger.warning(f"{e}")
|
||||
return str(e), None
|
||||
|
||||
def _clear_ipc(self, name):
|
||||
"""
|
||||
Remove the IPC file with the given name.
|
||||
"""
|
||||
if os.path.exists(name):
|
||||
try:
|
||||
os.remove(name)
|
||||
except OSError as e:
|
||||
llm_logger.warning(f"Failed to remove IPC file {name} - {e}")
|
||||
|
||||
def close(self):
|
||||
"""
|
||||
Close the socket and context, and remove the IPC files.
|
||||
"""
|
||||
if not self.running:
|
||||
return
|
||||
|
||||
self.running = False
|
||||
llm_logger.info("Closing ZMQ connection...")
|
||||
try:
|
||||
if hasattr(self, "socket") and not self.socket.closed:
|
||||
self.socket.close()
|
||||
|
||||
if self.router is not None and not self.router.closed:
|
||||
self.router.close()
|
||||
|
||||
if not self.context.closed:
|
||||
self.context.term()
|
||||
|
||||
self._clear_ipc(self.file_name)
|
||||
self._clear_ipc(self.router_path)
|
||||
except Exception as e:
|
||||
llm_logger.warning(f"Failed to close ZMQ connection - {e}")
|
||||
return
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
self.close()
|
||||
def connect(self):
|
||||
self._ensure_socket()
|
||||
self.socket.connect(f"ipc://{self.file_name}")
|
||||
|
303
fastdeploy/inter_communicator/zmq_server.py
Normal file
303
fastdeploy/inter_communicator/zmq_server.py
Normal file
@@ -0,0 +1,303 @@
|
||||
"""
|
||||
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
from abc import ABC, abstractmethod
|
||||
from collections import defaultdict
|
||||
|
||||
import msgpack
|
||||
import zmq
|
||||
|
||||
from fastdeploy import envs
|
||||
from fastdeploy.utils import llm_logger
|
||||
|
||||
|
||||
class ZmqServerBase(ABC):
|
||||
"""
|
||||
ZmqServerBase
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.cached_results = defaultdict(list)
|
||||
self.response_token_lock = threading.Lock()
|
||||
|
||||
@abstractmethod
|
||||
def _create_socket(self):
|
||||
"""Abstract method to create and return a ZeroMQ socket."""
|
||||
pass
|
||||
|
||||
def _ensure_socket(self):
|
||||
"""Ensure the socket is created before use."""
|
||||
if self.socket is None:
|
||||
self.socket = self._create_socket()
|
||||
|
||||
def pack_aggregated_data(self, data):
|
||||
"""
|
||||
Aggregate multiple responses into one and send them to the client.
|
||||
"""
|
||||
result = data[0]
|
||||
if len(data) > 1:
|
||||
for response in data[1:]:
|
||||
result.add(response)
|
||||
result = msgpack.packb([result.to_dict()])
|
||||
return result
|
||||
|
||||
def receive_json_once(self, block=False):
|
||||
"""
|
||||
Receive a single message from the socket.
|
||||
"""
|
||||
self._ensure_socket()
|
||||
if self.socket is None or self.socket.closed:
|
||||
return "zmp socket has closed", None
|
||||
try:
|
||||
flags = zmq.NOBLOCK if not block else 0
|
||||
return None, self.socket.recv_json(flags=flags)
|
||||
except zmq.Again:
|
||||
return None, None
|
||||
except Exception as e:
|
||||
self.close()
|
||||
llm_logger.warning(f"{e}")
|
||||
return str(e), None
|
||||
|
||||
def receive_pyobj_once(self, block=False):
|
||||
"""
|
||||
Receive a single message from the socket.
|
||||
"""
|
||||
self._ensure_socket()
|
||||
if self.socket is None or self.socket.closed:
|
||||
return "zmp socket has closed", None
|
||||
try:
|
||||
flags = zmq.NOBLOCK if not block else 0
|
||||
return None, self.socket.recv_pyobj(flags=flags)
|
||||
except zmq.Again:
|
||||
return None, None
|
||||
except Exception as e:
|
||||
self.close()
|
||||
llm_logger.warning(f"{e}")
|
||||
return str(e), None
|
||||
|
||||
def recv_result_handle(self):
|
||||
while True:
|
||||
try:
|
||||
with self.response_token_lock:
|
||||
client, _, request_id = self.socket.recv_multipart(flags=zmq.NOBLOCK)
|
||||
req_id_str = request_id.decode("utf-8")
|
||||
with self.mutex:
|
||||
self.req_dict[req_id_str] = client
|
||||
except zmq.Again:
|
||||
time.sleep(0.001)
|
||||
continue
|
||||
except Exception as e:
|
||||
llm_logger.error(f"recv_result_handle get unknown exception: {e}")
|
||||
continue
|
||||
|
||||
def send_response(self, req_id, data):
|
||||
"""
|
||||
Send generated token result to client.
|
||||
"""
|
||||
self._ensure_socket()
|
||||
if self.socket is None:
|
||||
raise RuntimeError("Router socket not created. Call create_router() first.")
|
||||
new_data = []
|
||||
has_result_handle = False
|
||||
with self.mutex:
|
||||
if req_id not in self.req_dict:
|
||||
self.cached_results[req_id].append(data)
|
||||
else:
|
||||
has_result_handle = True
|
||||
if req_id in self.cached_results:
|
||||
for history_data in self.cached_results[req_id]:
|
||||
new_data.extend(history_data)
|
||||
llm_logger.info(
|
||||
f"get request {req_id} result handle after cached result, total cached length {len(self.cached_results[req_id])}"
|
||||
)
|
||||
del self.cached_results[req_id]
|
||||
if has_result_handle:
|
||||
try:
|
||||
new_data.extend(data)
|
||||
start_send = time.time()
|
||||
if self.aggregate_send:
|
||||
result = self.pack_aggregated_data(new_data)
|
||||
else:
|
||||
result = msgpack.packb([response.to_dict() for response in new_data])
|
||||
with self.response_token_lock:
|
||||
self.socket.send_multipart([self.req_dict[req_id], b"", result])
|
||||
llm_logger.debug(
|
||||
f"send_multipart result: {req_id} len {len(new_data)} elapse: {time.time()-start_send}"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
llm_logger.error(f"Send result to zmq client failed: {e}")
|
||||
|
||||
if data[-1].finished:
|
||||
with self.mutex:
|
||||
if req_id not in self.req_dict:
|
||||
llm_logger.warning(f"req_id {req_id} finished but no result handle, drop it")
|
||||
if req_id in self.cached_results:
|
||||
del self.cached_results[req_id]
|
||||
else:
|
||||
llm_logger.info(f"send_multipart finished, req_id: {req_id}")
|
||||
self.req_dict.pop(req_id, None)
|
||||
|
||||
@abstractmethod
|
||||
def close(self):
|
||||
pass
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
self.close()
|
||||
|
||||
|
||||
class ZmqIpcServer(ZmqServerBase):
|
||||
"""
|
||||
ZmqIpcServer, used when FD_ENABLE_INTERNAL_ADAPTER=0
|
||||
"""
|
||||
|
||||
def __init__(self, name, mode):
|
||||
self.name = name
|
||||
self.mode = mode
|
||||
self.cached_results = defaultdict(list)
|
||||
if mode == zmq.PULL:
|
||||
self.file_name = f"/dev/shm/{name}.socket"
|
||||
elif mode == zmq.ROUTER:
|
||||
self.file_name = f"/dev/shm/router_{name}.ipc"
|
||||
self.ZMQ_SNDHWM = int(envs.FD_ZMQ_SNDHWM)
|
||||
self.aggregate_send = envs.FD_USE_AGGREGATE_SEND
|
||||
self.mutex = threading.Lock()
|
||||
self.response_token_lock = threading.Lock()
|
||||
self.req_dict = dict()
|
||||
self.running = True
|
||||
self.context = zmq.Context()
|
||||
self._create_socket()
|
||||
|
||||
def _create_socket(self):
|
||||
"""create and return a ZeroMQ socket."""
|
||||
self.socket = self.context.socket(self.mode)
|
||||
self.socket.setsockopt(zmq.SNDHWM, self.ZMQ_SNDHWM)
|
||||
self.socket.setsockopt(zmq.SNDTIMEO, -1)
|
||||
self.socket.bind(f"ipc://{self.file_name}")
|
||||
return self.socket
|
||||
|
||||
def _clear_ipc(self, name):
|
||||
"""
|
||||
Remove the IPC file with the given name.
|
||||
"""
|
||||
if os.path.exists(name):
|
||||
try:
|
||||
os.remove(name)
|
||||
except OSError as e:
|
||||
llm_logger.warning(f"Failed to remove IPC file {name} - {e}")
|
||||
|
||||
def close(self):
|
||||
"""
|
||||
Close the socket and context, and remove the IPC files.
|
||||
"""
|
||||
if not self.running:
|
||||
return
|
||||
|
||||
self.running = False
|
||||
llm_logger.info("Closing ZMQ connection...")
|
||||
try:
|
||||
if self.socket is not None and not self.socket.closed:
|
||||
self.socket.close()
|
||||
if not self.context.closed:
|
||||
self.context.term()
|
||||
self._clear_ipc(self.file_name)
|
||||
except Exception as e:
|
||||
llm_logger.warning(f"Failed to close ZMQ connection - {e}")
|
||||
return
|
||||
|
||||
|
||||
class ZmqTcpServer(ZmqServerBase):
|
||||
"""
|
||||
ZmqTcpServer, used when FD_ENABLE_INTERNAL_ADAPTER=1
|
||||
"""
|
||||
|
||||
def __init__(self, port, mode):
|
||||
self.mode = mode
|
||||
self.port = port
|
||||
self.cached_results = defaultdict(list)
|
||||
self.ZMQ_SNDHWM = int(envs.FD_ZMQ_SNDHWM)
|
||||
self.aggregate_send = envs.FD_USE_AGGREGATE_SEND
|
||||
|
||||
self.mutex = threading.Lock()
|
||||
self.req_dict = dict()
|
||||
self.running = True
|
||||
self.context = zmq.Context()
|
||||
self._create_socket()
|
||||
self.mutex = threading.Lock()
|
||||
self.response_token_lock = threading.Lock()
|
||||
|
||||
def _create_socket(self):
|
||||
"""create and return a ZeroMQ socket."""
|
||||
self.socket = self.context.socket(self.mode)
|
||||
self.socket.setsockopt(zmq.SNDHWM, self.ZMQ_SNDHWM)
|
||||
self.socket.setsockopt(zmq.SNDTIMEO, -1)
|
||||
self.socket.bind(f"tcp://*:{self.port}")
|
||||
return self.socket
|
||||
|
||||
def recv_control_cmd(self):
|
||||
"""
|
||||
Recieve control command from client
|
||||
"""
|
||||
self._ensure_socket()
|
||||
try:
|
||||
client, _, task_data = self.socket.recv_multipart(flags=zmq.NOBLOCK)
|
||||
task = msgpack.unpackb(task_data)
|
||||
task_id_str = task["task_id"]
|
||||
except zmq.Again:
|
||||
return None
|
||||
with self.mutex:
|
||||
self.req_dict[task_id_str] = client
|
||||
return task
|
||||
|
||||
def response_for_control_cmd(self, task_id, result):
|
||||
"""
|
||||
Send command result back to client.
|
||||
"""
|
||||
self._ensure_socket()
|
||||
if self.socket is None:
|
||||
raise RuntimeError("Router socket not created.")
|
||||
try:
|
||||
result = msgpack.packb(result)
|
||||
self.socket.send_multipart([self.req_dict[task_id], b"", result])
|
||||
|
||||
except Exception as e:
|
||||
llm_logger.error(f"Send result to zmq client failed: {e}")
|
||||
|
||||
with self.mutex:
|
||||
self.req_dict.pop(task_id, None)
|
||||
llm_logger.debug(f"response control cmd finished, task_id: {task_id}")
|
||||
|
||||
def close(self):
|
||||
"""
|
||||
Close the socket and context.
|
||||
"""
|
||||
if not self.running:
|
||||
return
|
||||
|
||||
self.running = False
|
||||
llm_logger.info("Closing ZMQ connection...")
|
||||
try:
|
||||
if self.socket is not None and not self.socket.closed:
|
||||
self.socket.close()
|
||||
if not self.context.closed:
|
||||
self.context.term()
|
||||
|
||||
except Exception as e:
|
||||
llm_logger.warning(f"Failed to close ZMQ connection - {e}")
|
||||
return
|
@@ -154,6 +154,22 @@ class MetricsManager:
|
||||
spec_decode_num_emitted_tokens_total: "Counter"
|
||||
spec_decode_draft_single_head_acceptance_rate: "list[Gauge]"
|
||||
|
||||
# for YIYAN Adapter
|
||||
prefix_cache_token_num: "Gauge"
|
||||
prefix_gpu_cache_token_num: "Gauge"
|
||||
prefix_cpu_cache_token_num: "Gauge"
|
||||
prefix_ssd_cache_token_num: "Gauge"
|
||||
batch_size: "Gauge"
|
||||
max_batch_size: "Gauge"
|
||||
available_gpu_block_num: "Gauge"
|
||||
free_gpu_block_num: "Gauge"
|
||||
max_gpu_block_num: "Gauge"
|
||||
available_gpu_resource: "Gauge"
|
||||
requests_number: "Counter"
|
||||
send_cache_failed_num: "Counter"
|
||||
first_token_latency: "Gauge"
|
||||
infer_latency: "Gauge"
|
||||
|
||||
# 定义所有指标配置
|
||||
METRICS = {
|
||||
"num_requests_running": {
|
||||
@@ -258,6 +274,91 @@ class MetricsManager:
|
||||
"description": "Total number of successfully processed requests",
|
||||
"kwargs": {},
|
||||
},
|
||||
# for YIYAN Adapter
|
||||
"prefix_cache_token_num": {
|
||||
"type": Counter,
|
||||
"name": "fastdeploy:prefix_cache_token_num",
|
||||
"description": "Total number of cached tokens",
|
||||
"kwargs": {},
|
||||
},
|
||||
"prefix_gpu_cache_token_num": {
|
||||
"type": Counter,
|
||||
"name": "fastdeploy:prefix_gpu_cache_token_num",
|
||||
"description": "Total number of cached tokens on GPU",
|
||||
"kwargs": {},
|
||||
},
|
||||
"prefix_cpu_cache_token_num": {
|
||||
"type": Counter,
|
||||
"name": "fastdeploy:prefix_cpu_cache_token_num",
|
||||
"description": "Total number of cached tokens on CPU",
|
||||
"kwargs": {},
|
||||
},
|
||||
"prefix_ssd_cache_token_num": {
|
||||
"type": Counter,
|
||||
"name": "fastdeploy:prefix_ssd_cache_token_num",
|
||||
"description": "Total number of cached tokens on SSD",
|
||||
"kwargs": {},
|
||||
},
|
||||
"batch_size": {
|
||||
"type": Gauge,
|
||||
"name": "fastdeploy:batch_size",
|
||||
"description": "Real batch size during inference",
|
||||
"kwargs": {},
|
||||
},
|
||||
"max_batch_size": {
|
||||
"type": Gauge,
|
||||
"name": "fastdeploy:max_batch_size",
|
||||
"description": "Maximum batch size determined when service started",
|
||||
"kwargs": {},
|
||||
},
|
||||
"available_gpu_block_num": {
|
||||
"type": Gauge,
|
||||
"name": "fastdeploy:available_gpu_block_num",
|
||||
"description": "Number of available gpu blocks in cache, including prefix caching blocks that are not officially released",
|
||||
"kwargs": {},
|
||||
},
|
||||
"free_gpu_block_num": {
|
||||
"type": Gauge,
|
||||
"name": "fastdeploy:free_gpu_block_num",
|
||||
"description": "Number of free blocks in cache",
|
||||
"kwargs": {},
|
||||
},
|
||||
"max_gpu_block_num": {
|
||||
"type": Gauge,
|
||||
"name": "fastdeploy:max_gpu_block_num",
|
||||
"description": "Number of total blocks determined when service started",
|
||||
"kwargs": {},
|
||||
},
|
||||
"available_gpu_resource": {
|
||||
"type": Gauge,
|
||||
"name": "fastdeploy:available_gpu_resource",
|
||||
"description": "Available blocks percentage, i.e. available_gpu_block_num / max_gpu_block_num",
|
||||
"kwargs": {},
|
||||
},
|
||||
"requests_number": {
|
||||
"type": Counter,
|
||||
"name": "fastdeploy:requests_number",
|
||||
"description": "Total number of requests received",
|
||||
"kwargs": {},
|
||||
},
|
||||
"send_cache_failed_num": {
|
||||
"type": Counter,
|
||||
"name": "fastdeploy:send_cache_failed_num",
|
||||
"description": "Total number of failures of sending cache",
|
||||
"kwargs": {},
|
||||
},
|
||||
"first_token_latency": {
|
||||
"type": Gauge,
|
||||
"name": "fastdeploy:first_token_latency",
|
||||
"description": "Latest time to first token in seconds",
|
||||
"kwargs": {},
|
||||
},
|
||||
"infer_latency": {
|
||||
"type": Gauge,
|
||||
"name": "fastdeploy:infer_latency",
|
||||
"description": "Latest time to generate one token in seconds",
|
||||
"kwargs": {},
|
||||
},
|
||||
}
|
||||
SPECULATIVE_METRICS = {}
|
||||
|
||||
|
@@ -445,8 +445,8 @@ class MTPSampler(nn.Layer):
|
||||
sampling_metadata.min_dec_lens,
|
||||
sampling_metadata.eos_token_ids,
|
||||
share_inputs["seq_lens_this_time"],
|
||||
share_inputs["seq_lens_encoder"],
|
||||
share_inputs["seq_lens_decoder"],
|
||||
share_inputs["output_padding_offset"],
|
||||
share_inputs["output_cum_offsets"],
|
||||
max_model_len,
|
||||
)
|
||||
probs = F.softmax(logits)
|
||||
|
@@ -65,6 +65,7 @@ else:
|
||||
update_inputs,
|
||||
step_reschedule,
|
||||
update_inputs_v1,
|
||||
speculate_step_reschedule,
|
||||
)
|
||||
|
||||
from fastdeploy.worker.output import ModelOutputData, ModelRunnerOutput, SamplerOutput
|
||||
@@ -355,12 +356,11 @@ def step_cuda(
|
||||
"""
|
||||
|
||||
if speculative_config.method is not None:
|
||||
if enable_prefix_caching:
|
||||
speculate_step_system_cache(
|
||||
if DISABLE_RECOVER:
|
||||
speculate_step_reschedule(
|
||||
share_inputs["stop_flags"],
|
||||
share_inputs["seq_lens_this_time"],
|
||||
share_inputs["step_seq_lens_encoder"],
|
||||
share_inputs["step_seq_lens_decoder"],
|
||||
share_inputs["seq_lens_encoder"],
|
||||
share_inputs["seq_lens_decoder"],
|
||||
share_inputs["block_tables"],
|
||||
@@ -386,64 +386,67 @@ def step_cuda(
|
||||
speculative_config.num_speculative_tokens,
|
||||
)
|
||||
else:
|
||||
speculate_step_paddle(
|
||||
share_inputs["stop_flags"],
|
||||
share_inputs["seq_lens_this_time"],
|
||||
share_inputs["step_seq_lens_encoder"],
|
||||
share_inputs["seq_lens_encoder"],
|
||||
share_inputs["seq_lens_decoder"],
|
||||
share_inputs["block_tables"],
|
||||
share_inputs["encoder_block_lens"],
|
||||
share_inputs["is_block_step"],
|
||||
share_inputs["step_block_list"],
|
||||
share_inputs["step_lens"],
|
||||
share_inputs["recover_block_list"],
|
||||
share_inputs["recover_lens"],
|
||||
share_inputs["need_block_list"],
|
||||
share_inputs["need_block_len"],
|
||||
share_inputs["used_list_len"],
|
||||
share_inputs["free_list"],
|
||||
share_inputs["free_list_len"],
|
||||
share_inputs["input_ids"],
|
||||
share_inputs["pre_ids"],
|
||||
share_inputs["step_idx"],
|
||||
share_inputs["next_tokens"],
|
||||
share_inputs["first_token_ids"],
|
||||
share_inputs["accept_num"],
|
||||
block_size,
|
||||
enc_dec_block_num,
|
||||
speculative_config.num_speculative_tokens,
|
||||
)
|
||||
if enable_prefix_caching:
|
||||
speculate_step_system_cache(
|
||||
share_inputs["stop_flags"],
|
||||
share_inputs["seq_lens_this_time"],
|
||||
share_inputs["step_seq_lens_encoder"],
|
||||
share_inputs["step_seq_lens_decoder"],
|
||||
share_inputs["seq_lens_encoder"],
|
||||
share_inputs["seq_lens_decoder"],
|
||||
share_inputs["block_tables"],
|
||||
share_inputs["encoder_block_lens"],
|
||||
share_inputs["is_block_step"],
|
||||
share_inputs["step_block_list"],
|
||||
share_inputs["step_lens"],
|
||||
share_inputs["recover_block_list"],
|
||||
share_inputs["recover_lens"],
|
||||
share_inputs["need_block_list"],
|
||||
share_inputs["need_block_len"],
|
||||
share_inputs["used_list_len"],
|
||||
share_inputs["free_list"],
|
||||
share_inputs["free_list_len"],
|
||||
share_inputs["input_ids"],
|
||||
share_inputs["pre_ids"],
|
||||
share_inputs["step_idx"],
|
||||
share_inputs["next_tokens"],
|
||||
share_inputs["first_token_ids"],
|
||||
share_inputs["accept_num"],
|
||||
block_size,
|
||||
enc_dec_block_num,
|
||||
speculative_config.num_speculative_tokens,
|
||||
)
|
||||
else:
|
||||
speculate_step_paddle(
|
||||
share_inputs["stop_flags"],
|
||||
share_inputs["seq_lens_this_time"],
|
||||
share_inputs["step_seq_lens_encoder"],
|
||||
share_inputs["seq_lens_encoder"],
|
||||
share_inputs["seq_lens_decoder"],
|
||||
share_inputs["block_tables"],
|
||||
share_inputs["encoder_block_lens"],
|
||||
share_inputs["is_block_step"],
|
||||
share_inputs["step_block_list"],
|
||||
share_inputs["step_lens"],
|
||||
share_inputs["recover_block_list"],
|
||||
share_inputs["recover_lens"],
|
||||
share_inputs["need_block_list"],
|
||||
share_inputs["need_block_len"],
|
||||
share_inputs["used_list_len"],
|
||||
share_inputs["free_list"],
|
||||
share_inputs["free_list_len"],
|
||||
share_inputs["input_ids"],
|
||||
share_inputs["pre_ids"],
|
||||
share_inputs["step_idx"],
|
||||
share_inputs["next_tokens"],
|
||||
share_inputs["first_token_ids"],
|
||||
share_inputs["accept_num"],
|
||||
block_size,
|
||||
enc_dec_block_num,
|
||||
speculative_config.num_speculative_tokens,
|
||||
)
|
||||
else:
|
||||
if enable_prefix_caching:
|
||||
step_system_cache(
|
||||
share_inputs["stop_flags"],
|
||||
share_inputs["seq_lens_this_time"],
|
||||
share_inputs["step_seq_lens_encoder"],
|
||||
share_inputs["step_seq_lens_decoder"],
|
||||
share_inputs["seq_lens_encoder"],
|
||||
share_inputs["seq_lens_decoder"],
|
||||
share_inputs["block_tables"],
|
||||
share_inputs["encoder_block_lens"],
|
||||
share_inputs["is_block_step"],
|
||||
share_inputs["step_block_list"],
|
||||
share_inputs["step_lens"],
|
||||
share_inputs["recover_block_list"],
|
||||
share_inputs["recover_lens"],
|
||||
share_inputs["need_block_list"],
|
||||
share_inputs["need_block_len"],
|
||||
share_inputs["used_list_len"],
|
||||
share_inputs["free_list"],
|
||||
share_inputs["free_list_len"],
|
||||
share_inputs["input_ids"],
|
||||
share_inputs["pre_ids"],
|
||||
share_inputs["step_idx"],
|
||||
share_inputs["next_tokens"],
|
||||
share_inputs["first_token_ids"],
|
||||
block_size,
|
||||
enc_dec_block_num,
|
||||
)
|
||||
elif DISABLE_RECOVER:
|
||||
if DISABLE_RECOVER:
|
||||
step_reschedule(
|
||||
share_inputs["stop_flags"],
|
||||
share_inputs["seq_lens_this_time"],
|
||||
@@ -471,32 +474,61 @@ def step_cuda(
|
||||
enc_dec_block_num,
|
||||
)
|
||||
else:
|
||||
step_paddle(
|
||||
share_inputs["stop_flags"],
|
||||
share_inputs["seq_lens_this_time"],
|
||||
share_inputs["step_seq_lens_encoder"],
|
||||
share_inputs["seq_lens_encoder"],
|
||||
share_inputs["seq_lens_decoder"],
|
||||
share_inputs["block_tables"],
|
||||
share_inputs["encoder_block_lens"],
|
||||
share_inputs["is_block_step"],
|
||||
share_inputs["step_block_list"],
|
||||
share_inputs["step_lens"],
|
||||
share_inputs["recover_block_list"],
|
||||
share_inputs["recover_lens"],
|
||||
share_inputs["need_block_list"],
|
||||
share_inputs["need_block_len"],
|
||||
share_inputs["used_list_len"],
|
||||
share_inputs["free_list"],
|
||||
share_inputs["free_list_len"],
|
||||
share_inputs["input_ids"],
|
||||
share_inputs["pre_ids"],
|
||||
share_inputs["step_idx"],
|
||||
share_inputs["next_tokens"],
|
||||
share_inputs["first_token_ids"],
|
||||
block_size,
|
||||
enc_dec_block_num,
|
||||
)
|
||||
if enable_prefix_caching:
|
||||
step_system_cache(
|
||||
share_inputs["stop_flags"],
|
||||
share_inputs["seq_lens_this_time"],
|
||||
share_inputs["step_seq_lens_encoder"],
|
||||
share_inputs["step_seq_lens_decoder"],
|
||||
share_inputs["seq_lens_encoder"],
|
||||
share_inputs["seq_lens_decoder"],
|
||||
share_inputs["block_tables"],
|
||||
share_inputs["encoder_block_lens"],
|
||||
share_inputs["is_block_step"],
|
||||
share_inputs["step_block_list"],
|
||||
share_inputs["step_lens"],
|
||||
share_inputs["recover_block_list"],
|
||||
share_inputs["recover_lens"],
|
||||
share_inputs["need_block_list"],
|
||||
share_inputs["need_block_len"],
|
||||
share_inputs["used_list_len"],
|
||||
share_inputs["free_list"],
|
||||
share_inputs["free_list_len"],
|
||||
share_inputs["input_ids"],
|
||||
share_inputs["pre_ids"],
|
||||
share_inputs["step_idx"],
|
||||
share_inputs["next_tokens"],
|
||||
share_inputs["first_token_ids"],
|
||||
block_size,
|
||||
enc_dec_block_num,
|
||||
)
|
||||
else:
|
||||
step_paddle(
|
||||
share_inputs["stop_flags"],
|
||||
share_inputs["seq_lens_this_time"],
|
||||
share_inputs["step_seq_lens_encoder"],
|
||||
share_inputs["seq_lens_encoder"],
|
||||
share_inputs["seq_lens_decoder"],
|
||||
share_inputs["block_tables"],
|
||||
share_inputs["encoder_block_lens"],
|
||||
share_inputs["is_block_step"],
|
||||
share_inputs["step_block_list"],
|
||||
share_inputs["step_lens"],
|
||||
share_inputs["recover_block_list"],
|
||||
share_inputs["recover_lens"],
|
||||
share_inputs["need_block_list"],
|
||||
share_inputs["need_block_len"],
|
||||
share_inputs["used_list_len"],
|
||||
share_inputs["free_list"],
|
||||
share_inputs["free_list_len"],
|
||||
share_inputs["input_ids"],
|
||||
share_inputs["pre_ids"],
|
||||
share_inputs["step_idx"],
|
||||
share_inputs["next_tokens"],
|
||||
share_inputs["first_token_ids"],
|
||||
block_size,
|
||||
enc_dec_block_num,
|
||||
)
|
||||
|
||||
|
||||
def rebuild_padding(
|
||||
|
@@ -195,7 +195,14 @@ class TokenProcessor:
|
||||
try:
|
||||
is_blocking = True
|
||||
if self.speculative_decoding:
|
||||
speculate_get_output(self.output_tokens, rank_id, is_blocking, False)
|
||||
if (
|
||||
self.cfg.parallel_config.enable_expert_parallel
|
||||
and self.cfg.parallel_config.data_parallel_size > 1
|
||||
):
|
||||
speculate_get_output(self.output_tokens, rank_id, is_blocking, True)
|
||||
else:
|
||||
|
||||
speculate_get_output(self.output_tokens, rank_id, is_blocking, False)
|
||||
if self.output_tokens[0] == -2:
|
||||
continue
|
||||
|
||||
@@ -258,13 +265,13 @@ class TokenProcessor:
|
||||
llm_logger.info(f"finished_task_id: {finished_task_id}")
|
||||
self.prefill_result_status[finished_task_id[0]] = finished_task_id[1]
|
||||
if task_id in self.prefill_result_status:
|
||||
self.split_connector.send_first_token(task.disaggregate_info, [result])
|
||||
self.resource_manager.stop_flags[index] = True
|
||||
self.resource_manager.tasks_list[index] = None
|
||||
self.resource_manager._recycle_block_tables(task)
|
||||
if self.prefill_result_status[task_id] != "finished":
|
||||
result.error_code = 400
|
||||
result.error_message = f"{task_id} failed to {self.prefill_result_status[task_id]}"
|
||||
result.error_msg = f"{task_id} failed to {self.prefill_result_status[task_id]}"
|
||||
self.split_connector.send_first_token(task.disaggregate_info, [result])
|
||||
del self.resource_manager.req_dict[task_id]
|
||||
break
|
||||
else:
|
||||
@@ -276,6 +283,15 @@ class TokenProcessor:
|
||||
self.resource_manager.stop_flags[index] = True
|
||||
self.resource_manager.tasks_list[index] = None
|
||||
self.resource_manager._recycle_block_tables(task)
|
||||
|
||||
task_used_block_num = sum([len(task.block_tables) if task else 0 for task in self.resource_manager.tasks_list])
|
||||
main_process_metrics.available_gpu_block_num.set(
|
||||
self.resource_manager.total_block_number() - task_used_block_num
|
||||
)
|
||||
main_process_metrics.batch_size.set(
|
||||
self.resource_manager.max_num_seqs - self.resource_manager.available_batch()
|
||||
)
|
||||
|
||||
if task_id in self.tokens_counter:
|
||||
del self.tokens_counter[task_id]
|
||||
|
||||
@@ -412,7 +428,11 @@ class TokenProcessor:
|
||||
self._record_completion_metrics(task, current_time)
|
||||
self._recycle_resources(task_id, i, task, result, is_prefill)
|
||||
break
|
||||
if not is_prefill or self.cfg.scheduler_config.name == "splitwise":
|
||||
if (
|
||||
not is_prefill
|
||||
or self.cfg.scheduler_config.name == "splitwise"
|
||||
or self.cfg.scheduler_config.name == "dp"
|
||||
):
|
||||
batch_result.append(result)
|
||||
|
||||
self.postprocess(batch_result)
|
||||
@@ -427,6 +447,7 @@ class TokenProcessor:
|
||||
batch = self.output_tokens[1]
|
||||
accept_num = tokens[2 : batch + 2]
|
||||
self._record_speculative_decoding_mertics(accept_num)
|
||||
|
||||
else:
|
||||
batch = self.output_tokens[1, 0]
|
||||
tokens = tokens[2 : batch + 2]
|
||||
@@ -441,16 +462,22 @@ class TokenProcessor:
|
||||
|
||||
task_id = task.request_id
|
||||
if self.cfg.speculative_config.method:
|
||||
token_ids = tokens[
|
||||
2
|
||||
+ SPECULATE_MAX_BSZ
|
||||
+ i * MAX_DRAFT_TOKENS : 2
|
||||
+ SPECULATE_MAX_BSZ
|
||||
+ i * MAX_DRAFT_TOKENS
|
||||
+ accept_num[i]
|
||||
].tolist()
|
||||
if len(token_ids) == 0 or token_ids[-1] <= 0:
|
||||
continue
|
||||
if accept_num[i] == -3:
|
||||
recovery_stop = True
|
||||
if recovery_stop:
|
||||
llm_logger.info(f"recovery stop signal found at task {task_id}")
|
||||
token_ids = [RECOVERY_STOP_SIGNAL]
|
||||
else:
|
||||
token_ids = tokens[
|
||||
2
|
||||
+ SPECULATE_MAX_BSZ
|
||||
+ i * MAX_DRAFT_TOKENS : 2
|
||||
+ SPECULATE_MAX_BSZ
|
||||
+ i * MAX_DRAFT_TOKENS
|
||||
+ accept_num[i]
|
||||
].tolist()
|
||||
if (not recovery_stop) and (len(token_ids) == 0 or token_ids[-1] <= 0):
|
||||
continue
|
||||
else:
|
||||
token_id = int(tokens[i, 0])
|
||||
token_ids = [token_id]
|
||||
@@ -474,6 +501,7 @@ class TokenProcessor:
|
||||
arrival_time=task.arrival_time,
|
||||
inference_start_time=task.inference_start_time,
|
||||
first_token_time=time.time() - task.inference_start_time,
|
||||
model_execute_time=time.time() - task.inference_start_time,
|
||||
time_in_queue=task.schedule_start_time - task.preprocess_end_time,
|
||||
preprocess_cost_time=task.preprocess_end_time - task.preprocess_start_time,
|
||||
request_start_time=task.arrival_time,
|
||||
@@ -485,6 +513,7 @@ class TokenProcessor:
|
||||
metrics = RequestMetrics(
|
||||
arrival_time=time.time(),
|
||||
request_start_time=task.arrival_time,
|
||||
model_execute_time=time.time() - task.inference_start_time,
|
||||
)
|
||||
self.number_of_output_tokens += len(token_ids)
|
||||
self._record_metrics(task, current_time, token_ids)
|
||||
@@ -502,7 +531,7 @@ class TokenProcessor:
|
||||
if self.tokens_counter[task_id] == 0:
|
||||
if task.messages is not None:
|
||||
result.prompt = task.messages
|
||||
result.num_cached_tokens = task.num_cached_tokens
|
||||
result.num_cached_tokens = task.num_cached_tokens
|
||||
|
||||
is_prefill = task.disaggregate_info is not None and task.disaggregate_info["role"] == "prefill"
|
||||
|
||||
@@ -512,7 +541,8 @@ class TokenProcessor:
|
||||
for token_id in token_ids:
|
||||
self.tokens_counter[task_id] += 1
|
||||
if token_id != RECOVERY_STOP_SIGNAL:
|
||||
result.outputs.token_ids.append(token_id)
|
||||
if not (envs.FD_ENABLE_INTERNAL_ADAPTER and token_id in task.eos_token_ids):
|
||||
result.outputs.token_ids.append(token_id)
|
||||
task.output_token_ids.append(token_id)
|
||||
if token_id in task.eos_token_ids or is_prefill or recovery_stop:
|
||||
result.finished = True
|
||||
@@ -531,7 +561,11 @@ class TokenProcessor:
|
||||
self._record_completion_metrics(task, current_time)
|
||||
self._recycle_resources(task_id, i, task, result, is_prefill)
|
||||
break
|
||||
if not is_prefill or self.cfg.scheduler_config.name == "splitwise":
|
||||
if (
|
||||
not is_prefill
|
||||
or self.cfg.scheduler_config.name == "splitwise"
|
||||
or self.cfg.scheduler_config.name == "dp"
|
||||
):
|
||||
batch_result.append(result)
|
||||
|
||||
self.postprocess(batch_result)
|
||||
@@ -549,6 +583,7 @@ class TokenProcessor:
|
||||
def _record_first_token_metrics(self, task, current_time):
|
||||
"""Record metrics for first token"""
|
||||
task.first_token_time = current_time
|
||||
main_process_metrics.first_token_latency.set(current_time - task.inference_start_time)
|
||||
main_process_metrics.time_to_first_token.observe(current_time - task.inference_start_time)
|
||||
main_process_metrics.request_queue_time.observe(task.schedule_start_time - task.preprocess_end_time)
|
||||
|
||||
@@ -560,6 +595,7 @@ class TokenProcessor:
|
||||
|
||||
main_process_metrics.num_requests_running.dec(1)
|
||||
main_process_metrics.request_success_total.inc()
|
||||
main_process_metrics.infer_latency.set(current_time - task.inference_start_time)
|
||||
main_process_metrics.request_inference_time.observe(current_time - task.inference_start_time)
|
||||
main_process_metrics.request_generation_tokens.observe(self.tokens_counter[task.request_id])
|
||||
|
||||
@@ -571,7 +607,7 @@ class TokenProcessor:
|
||||
self.cfg.speculative_config.num_speculative_tokens,
|
||||
)
|
||||
|
||||
real_accept_num = [x for x in accept_num if x != 0]
|
||||
real_accept_num = [x for x in accept_num if x > 0]
|
||||
num_accepted_tokens = sum([x - 1 for x in real_accept_num])
|
||||
self.num_accepted_tokens += num_accepted_tokens
|
||||
num_emitted_tokens = sum(real_accept_num)
|
||||
|
@@ -18,6 +18,7 @@ import redis
|
||||
|
||||
from fastdeploy.utils import llm_logger
|
||||
|
||||
from .dp_scheduler import DPScheduler
|
||||
from .global_scheduler import GlobalScheduler
|
||||
from .local_scheduler import LocalScheduler
|
||||
from .splitwise_scheduler import SplitWiseScheduler, SplitWiseSchedulerConfig
|
||||
@@ -89,6 +90,57 @@ class LocalSchedulerConfig:
|
||||
llm_logger.info("=============================================================")
|
||||
|
||||
|
||||
class DPLocalSchedulerConfig(LocalSchedulerConfig):
|
||||
"""
|
||||
Configuration class for DPLocalScheduler.
|
||||
|
||||
Attributes:
|
||||
max_size: Maximum number of concurrent requests (-1 for unlimited)
|
||||
ttl: Time-to-live in seconds for request expiration
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
max_size: int = -1,
|
||||
ttl: int = 900,
|
||||
max_model_len: int = 8192,
|
||||
enable_chunked_prefill: bool = False,
|
||||
max_num_partial_prefills: int = 1,
|
||||
max_long_partial_prefills: int = 1,
|
||||
long_prefill_token_threshold: int = 0,
|
||||
splitwise_role: str = "prefill",
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Initialize LocalScheduler configuration.
|
||||
|
||||
Args:
|
||||
max_size: Maximum concurrent requests (-1 for unlimited, 0 for disabled)
|
||||
ttl: Time-to-live in seconds for request expiration (default 900s)
|
||||
max_model_len: Maximum model context length in tokens
|
||||
enable_chunked_prefill: Whether to enable chunked prefill processing
|
||||
max_num_partial_prefills: Max partial prefill operations allowed
|
||||
max_long_partial_prefills: Max long-running partial prefill ops
|
||||
long_prefill_token_threshold: Token count threshold for long prefill
|
||||
**kwargs: Additional unused arguments (for forward compatibility)
|
||||
|
||||
Note:
|
||||
- If long_prefill_token_threshold is 0, it's auto-calculated as 4% of max_model_len
|
||||
- See LocalScheduler class for implementation details
|
||||
"""
|
||||
self.max_size = max_size
|
||||
self.ttl = ttl
|
||||
|
||||
self.max_model_len = max_model_len
|
||||
self.enable_chunked_prefill = enable_chunked_prefill
|
||||
self.max_num_partial_prefills = max_num_partial_prefills
|
||||
self.max_long_partial_prefills = max_long_partial_prefills
|
||||
self.long_prefill_token_threshold = long_prefill_token_threshold
|
||||
if self.long_prefill_token_threshold == 0:
|
||||
self.long_prefill_token_threshold = int(self.max_model_len * 0.04)
|
||||
self.splitwise_role = splitwise_role
|
||||
|
||||
|
||||
class GlobalSchedulerConfig:
|
||||
"""
|
||||
Configuration class for GlobalScheduler (Redis-based).
|
||||
@@ -229,6 +281,9 @@ class SchedulerConfig:
|
||||
if name == "splitwise":
|
||||
self.config = SplitWiseSchedulerConfig(**kwargs)
|
||||
|
||||
if name == "dp":
|
||||
self.config = DPLocalSchedulerConfig(**kwargs)
|
||||
|
||||
def check(self):
|
||||
"""
|
||||
Validate the configuration.
|
||||
@@ -236,7 +291,7 @@ class SchedulerConfig:
|
||||
Raises:
|
||||
Exception: If invalid scheduler type is specified
|
||||
"""
|
||||
if self.name not in ["local", "global", "splitwise"]:
|
||||
if self.name not in ["local", "global", "splitwise", "dp"]:
|
||||
raise Exception(f"Unknown scheduler type {self.name}")
|
||||
|
||||
self.config.check()
|
||||
@@ -274,6 +329,17 @@ class SchedulerConfig:
|
||||
if self.name == "splitwise":
|
||||
return SplitWiseScheduler(self.config)
|
||||
|
||||
if self.name == "dp":
|
||||
return DPScheduler(
|
||||
max_size=self.config.max_size,
|
||||
ttl=self.config.ttl,
|
||||
enable_chunked_prefill=self.config.enable_chunked_prefill,
|
||||
max_num_partial_prefills=self.config.max_num_partial_prefills,
|
||||
max_long_partial_prefills=self.config.max_long_partial_prefills,
|
||||
long_prefill_token_threshold=self.config.long_prefill_token_threshold,
|
||||
splitwise_role=self.config.splitwise_role,
|
||||
)
|
||||
|
||||
return LocalScheduler(
|
||||
max_size=self.config.max_size,
|
||||
ttl=self.config.ttl,
|
||||
|
258
fastdeploy/scheduler/dp_scheduler.py
Normal file
258
fastdeploy/scheduler/dp_scheduler.py
Normal file
@@ -0,0 +1,258 @@
|
||||
"""
|
||||
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import threading
|
||||
import time
|
||||
from multiprocessing import Queue
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from fastdeploy.engine.request import Request, RequestOutput
|
||||
from fastdeploy.scheduler.data import ScheduledResponse
|
||||
from fastdeploy.scheduler.local_scheduler import LocalScheduler
|
||||
from fastdeploy.utils import envs, get_logger
|
||||
|
||||
|
||||
class DPLocalScheduler(LocalScheduler):
|
||||
def __init__(
|
||||
self,
|
||||
max_size: int,
|
||||
ttl: int,
|
||||
enable_chunked_prefill: bool,
|
||||
max_num_partial_prefills: int,
|
||||
max_long_partial_prefills: int,
|
||||
long_prefill_token_threshold: int,
|
||||
splitwise_role: str = "prefill",
|
||||
):
|
||||
super().__init__(
|
||||
max_size,
|
||||
ttl,
|
||||
enable_chunked_prefill,
|
||||
max_num_partial_prefills,
|
||||
max_long_partial_prefills,
|
||||
long_prefill_token_threshold,
|
||||
)
|
||||
self.splitwise_role = splitwise_role
|
||||
self.scheduler_logger = logging
|
||||
|
||||
def put_results(self, results: List[RequestOutput]):
|
||||
"""
|
||||
Add processing results back to the scheduler.
|
||||
Args:
|
||||
results: List of RequestOutput objects containing results
|
||||
"""
|
||||
responses: List[ScheduledResponse] = [ScheduledResponse(result) for result in results]
|
||||
|
||||
finished_responses = [response.request_id for response in responses if response.finished]
|
||||
if len(finished_responses) > 0:
|
||||
self.scheduler_logger.info(f"Scheduler has received some finished responses: {finished_responses}")
|
||||
|
||||
with self.mutex:
|
||||
for response in responses:
|
||||
if response.request_id not in self.responses:
|
||||
self.responses[response.request_id] = [response]
|
||||
continue
|
||||
self.responses[response.request_id].append(response)
|
||||
self.responses_not_empty.notify_all()
|
||||
|
||||
def _recycle(self, request_id: Optional[str] = None):
|
||||
"""
|
||||
Clean up expired or completed requests to free memory.
|
||||
Args:
|
||||
request_id: Optional specific request ID to remove.
|
||||
If None, removes all expired requests.
|
||||
"""
|
||||
if request_id is not None:
|
||||
self.requests.pop(request_id, None)
|
||||
self.responses.pop(request_id, None)
|
||||
if self.splitwise_role == "decode":
|
||||
return
|
||||
self.ids.pop(self.ids.index(request_id))
|
||||
self.ids_read_cursor -= 1
|
||||
return
|
||||
|
||||
if self.max_size <= 0:
|
||||
return
|
||||
|
||||
if len(self.requests) <= self.max_size:
|
||||
return
|
||||
|
||||
now = time.time()
|
||||
expired_ids = []
|
||||
for request_id in self.ids:
|
||||
request = self.requests[request_id]
|
||||
if now - request.schedule_time < self.ttl:
|
||||
break
|
||||
expired_ids.append(request.request_id)
|
||||
|
||||
for i, expired_id in enumerate(expired_ids):
|
||||
self.requests.pop(expired_id, None)
|
||||
self.responses.pop(expired_id, None)
|
||||
self.ids.pop(i)
|
||||
|
||||
if len(expired_ids) > 0:
|
||||
if len(expired_ids) - 1 >= self.ids_read_cursor:
|
||||
self.ids_read_cursor = 0
|
||||
else:
|
||||
self.ids_read_cursor -= len(expired_ids)
|
||||
|
||||
def get_requests(
|
||||
self,
|
||||
available_blocks,
|
||||
block_size,
|
||||
reserved_output_blocks,
|
||||
max_num_batched_tokens,
|
||||
batch=1,
|
||||
) -> List[Request]:
|
||||
"""
|
||||
Retrieve requests from the scheduler based on available resources.
|
||||
|
||||
Args:
|
||||
available_blocks: Number of available processing blocks
|
||||
block_size: Size of each processing block
|
||||
reserved_output_blocks: Blocks reserved for output
|
||||
max_num_batched_tokens: Maximum tokens that can be batched
|
||||
batch: Preferred batch size
|
||||
|
||||
Returns:
|
||||
List of Request objects ready for processing
|
||||
"""
|
||||
if available_blocks <= reserved_output_blocks or batch < 1:
|
||||
self.scheduler_logger.debug(
|
||||
f"Scheduler's resource are insufficient: available_blocks={available_blocks} "
|
||||
f"reserved_output_blocks={reserved_output_blocks} batch={batch} "
|
||||
f"max_num_batched_tokens={max_num_batched_tokens}"
|
||||
)
|
||||
return []
|
||||
required_total_blocks = 0
|
||||
current_prefill_tokens = 0
|
||||
start_batch_time = time.time()
|
||||
requests: List[Request] = []
|
||||
|
||||
with self.requests_not_empty:
|
||||
while True:
|
||||
batch_ids = self.requests_not_empty.wait_for(
|
||||
lambda: self.ids[self.ids_read_cursor : self.ids_read_cursor + batch],
|
||||
0.005,
|
||||
)
|
||||
if batch_ids:
|
||||
for request_id in batch_ids:
|
||||
request = self.requests[request_id]
|
||||
required_input_blocks = self.calc_required_blocks(request.prompt_tokens_ids_len, block_size)
|
||||
current_prefill_tokens += request.prompt_tokens_ids_len
|
||||
required_total_blocks += required_input_blocks + reserved_output_blocks
|
||||
if required_total_blocks > available_blocks:
|
||||
break
|
||||
|
||||
requests.append(request.raw)
|
||||
self.ids_read_cursor += 1
|
||||
start_batch_time = time.time()
|
||||
if current_prefill_tokens > max_num_batched_tokens:
|
||||
break
|
||||
if len(requests) >= batch:
|
||||
break
|
||||
if (
|
||||
(current_prefill_tokens > max_num_batched_tokens)
|
||||
or (len(requests) >= batch)
|
||||
or (time.time() - start_batch_time > envs.FD_EP_BATCHED_TOKEN_TIMEOUT)
|
||||
):
|
||||
break
|
||||
if batch_ids:
|
||||
if len(batch_ids) > 0 and len(requests) == 0:
|
||||
self.scheduler_logger.debug(
|
||||
f"Scheduler has put all just-pulled request into the queue: {len(batch_ids)}"
|
||||
)
|
||||
|
||||
if len(requests) > 0:
|
||||
self.scheduler_logger.info(
|
||||
f"Scheduler has pulled some request: {[request.request_id for request in requests]}"
|
||||
)
|
||||
|
||||
return requests
|
||||
|
||||
|
||||
class DPScheduler:
|
||||
def __init__(
|
||||
self,
|
||||
max_size: int,
|
||||
ttl: int,
|
||||
enable_chunked_prefill: bool,
|
||||
max_num_partial_prefills: int,
|
||||
max_long_partial_prefills: int,
|
||||
long_prefill_token_threshold: int,
|
||||
splitwise_role: str = "prefill",
|
||||
):
|
||||
self._scheduler = DPLocalScheduler(
|
||||
max_size,
|
||||
ttl,
|
||||
enable_chunked_prefill,
|
||||
max_num_partial_prefills,
|
||||
max_long_partial_prefills,
|
||||
long_prefill_token_threshold,
|
||||
splitwise_role,
|
||||
)
|
||||
|
||||
def start(self, dp_rank: int, request_queues: List[Queue], result_queue: Queue):
|
||||
self.dp_rank = dp_rank
|
||||
self.request_queues = request_queues
|
||||
self.result_queue = result_queue
|
||||
self.scheduler_logger = get_logger("dpscheduler", f"dp_scheduler_rank{self.dp_rank}.log")
|
||||
self._scheduler.scheduler_logger = self.scheduler_logger
|
||||
threading.Thread(target=self._put_requests_to_local).start()
|
||||
threading.Thread(target=self._get_response_from_local).start()
|
||||
|
||||
def put_requests(self, requests: List[Dict]):
|
||||
results = []
|
||||
for request in requests:
|
||||
if not hasattr(request, "dp_rank"):
|
||||
raise ValueError(f"Request object is missing the 'dp_rank' attribute: {request}")
|
||||
self.request_queues[request.dp_rank].put(request)
|
||||
results.append((request.request_id, None))
|
||||
return results
|
||||
|
||||
def _put_requests_to_local(self):
|
||||
while True:
|
||||
request = self.request_queues[self.dp_rank].get()
|
||||
self.scheduler_logger.info(f"Recieve request from puller, request_id: {request.request_id}")
|
||||
self._scheduler.put_requests([request])
|
||||
|
||||
def _get_response_from_local(self):
|
||||
while True:
|
||||
results = self._scheduler.get_results()
|
||||
if len(results) == 0:
|
||||
continue
|
||||
self.result_queue.put(results)
|
||||
|
||||
def get_requests(
|
||||
self,
|
||||
available_blocks,
|
||||
block_size,
|
||||
reserved_output_blocks,
|
||||
max_num_batched_tokens,
|
||||
batch=1,
|
||||
) -> List[Request]:
|
||||
return self._scheduler.get_requests(
|
||||
available_blocks, block_size, reserved_output_blocks, max_num_batched_tokens, batch
|
||||
)
|
||||
|
||||
def get_unhandled_request_num(self):
|
||||
return len(self._scheduler.requests)
|
||||
|
||||
def put_results(self, results: List[RequestOutput]):
|
||||
self._scheduler.put_results(results)
|
||||
|
||||
def get_results(self) -> Dict[str, List[RequestOutput]]:
|
||||
return self.result_queue.get()
|
@@ -208,6 +208,9 @@ class LocalScheduler:
|
||||
"""
|
||||
return (token_num + block_size - 1) // block_size
|
||||
|
||||
def get_unhandled_request_num(self):
|
||||
return len(self.requests)
|
||||
|
||||
def get_requests(
|
||||
self,
|
||||
available_blocks,
|
||||
|
@@ -37,6 +37,7 @@ from fastdeploy.model_executor.ops.gpu import (
|
||||
eagle_get_self_hidden_states,
|
||||
mtp_save_first_token,
|
||||
mtp_step_paddle,
|
||||
set_data_ipc,
|
||||
share_external_data,
|
||||
)
|
||||
from fastdeploy.model_executor.pre_and_post_process import pre_process, rebuild_padding
|
||||
@@ -75,6 +76,7 @@ class MTPProposer(Proposer):
|
||||
self.model_config.num_hidden_layers = 1
|
||||
self.model_config.model = self.speculative_config.model
|
||||
self.model_config.pretrained_config.prefix_name = "ernie.mtp_block"
|
||||
self.model_config.is_quantized = False
|
||||
if self.speculative_config.quantization != "":
|
||||
self.model_config.quantization = self.speculative_config.quantization
|
||||
self.model_config.start_layer_index = self.num_main_model_layers
|
||||
@@ -141,17 +143,16 @@ class MTPProposer(Proposer):
|
||||
kv_cache_shape = self.attn_backends[0].get_kv_cache_shape(
|
||||
max_num_blocks=self.num_gpu_blocks, kv_cache_quant_type=kv_cache_quant_type
|
||||
)
|
||||
if not self.parallel_config.do_profile and (
|
||||
self.cache_config.enable_prefix_caching or self.parallel_config.splitwise_role != "mixed"
|
||||
):
|
||||
local_rank = self.local_rank % self.parallel_config.tensor_parallel_size
|
||||
if not self.parallel_config.do_profile and self.parallel_config.splitwise_role != "mixed":
|
||||
cache_kvs_list = []
|
||||
for i in range(
|
||||
self.num_main_model_layers,
|
||||
self.num_main_model_layers + self.model_config.num_hidden_layers,
|
||||
):
|
||||
key_cache = paddle.empty(shape=[], dtype=cache_type)
|
||||
key_cache_name = f"key_caches_{i}_rank{self.local_rank}.device{self.device_id}"
|
||||
val_cache_name = f"value_caches_{i}_rank{self.local_rank}.device{self.device_id}"
|
||||
key_cache_name = f"key_caches_{i}_rank{local_rank}.device{self.device_id}"
|
||||
val_cache_name = f"value_caches_{i}_rank{local_rank}.device{self.device_id}"
|
||||
key_cache = share_external_data(key_cache, key_cache_name, kv_cache_shape)
|
||||
cache_kvs_list.append(key_cache)
|
||||
value_cache = paddle.empty(shape=[], dtype=cache_type)
|
||||
@@ -160,7 +161,10 @@ class MTPProposer(Proposer):
|
||||
|
||||
self.model_inputs["caches"] = cache_kvs_list
|
||||
else:
|
||||
for i in range(self.model_config.num_hidden_layers):
|
||||
for i in range(
|
||||
self.num_main_model_layers,
|
||||
self.num_main_model_layers + self.model_config.num_hidden_layers,
|
||||
):
|
||||
self.cache_kvs[f"key_caches_{i}"] = paddle.full(
|
||||
shape=kv_cache_shape,
|
||||
fill_value=0,
|
||||
@@ -171,6 +175,15 @@ class MTPProposer(Proposer):
|
||||
fill_value=0,
|
||||
dtype=cache_type,
|
||||
)
|
||||
if self.cache_config.enable_prefix_caching:
|
||||
set_data_ipc(
|
||||
self.cache_kvs[f"key_caches_{i}"],
|
||||
f"key_caches_{i}_rank{local_rank}.device{self.device_id}",
|
||||
)
|
||||
set_data_ipc(
|
||||
self.cache_kvs[f"value_caches_{i}"],
|
||||
f"value_caches_{i}_rank{local_rank}.device{self.device_id}",
|
||||
)
|
||||
self.model_inputs["caches"] = list(self.cache_kvs.values())
|
||||
for value in self.cache_kvs.values():
|
||||
del value
|
||||
@@ -235,7 +248,7 @@ class MTPProposer(Proposer):
|
||||
|
||||
self.main_model_num_gpu_blocks = num_gpu_blocks
|
||||
self.num_gpu_blocks = int(num_gpu_blocks * self.speculative_config.num_gpu_block_expand_ratio)
|
||||
if not (self.cache_config.enable_prefix_caching or self.parallel_config.splitwise_role != "mixed"):
|
||||
if self.parallel_config.splitwise_role == "mixed":
|
||||
self.initialize_kv_cache()
|
||||
|
||||
# Reset free list
|
||||
|
117
fastdeploy/splitwise/internal_adapter_utils.py
Normal file
117
fastdeploy/splitwise/internal_adapter_utils.py
Normal file
@@ -0,0 +1,117 @@
|
||||
"""
|
||||
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
|
||||
import threading
|
||||
import time
|
||||
import traceback
|
||||
|
||||
# **Note**: Just for internal use
|
||||
import zmq
|
||||
|
||||
from fastdeploy.inter_communicator import ZmqTcpServer
|
||||
from fastdeploy.metrics.metrics import get_filtered_metrics, main_process_metrics
|
||||
from fastdeploy.utils import envs, get_logger
|
||||
|
||||
logger = get_logger("internal_adapter_utils", "internal_adapter_utils.log")
|
||||
|
||||
|
||||
class InternalAdapter:
|
||||
def __init__(self, cfg, engine, dp_rank):
|
||||
self.cfg = cfg
|
||||
self.engine = engine
|
||||
self.dp_rank = dp_rank
|
||||
recv_control_cmd_ports = envs.FD_ZMQ_CONTROL_CMD_SERVER_PORTS.split(",")
|
||||
self.response_lock = threading.Lock() # prevent to call send_multipart in zmq concurrently
|
||||
self.recv_control_cmd_server = ZmqTcpServer(port=recv_control_cmd_ports[dp_rank], mode=zmq.ROUTER)
|
||||
self.recv_external_instruct_thread = threading.Thread(
|
||||
target=self._recv_external_module_control_instruct, daemon=True
|
||||
)
|
||||
self.recv_external_instruct_thread.start()
|
||||
self.response_external_instruct_thread = threading.Thread(
|
||||
target=self._response_external_module_control_instruct, daemon=True
|
||||
)
|
||||
self.response_external_instruct_thread.start()
|
||||
|
||||
def _get_current_server_info(self):
|
||||
"""
|
||||
Get resources information
|
||||
"""
|
||||
available_batch_size = min(self.cfg.max_prefill_batch, self.engine.resource_manager.available_batch())
|
||||
|
||||
available_block_num = self.engine.resource_manager.available_block_num()
|
||||
server_info = {
|
||||
"splitwise_role": self.cfg.splitwise_role,
|
||||
"block_size": int(self.cfg.cache_config.block_size),
|
||||
"block_num": int(available_block_num),
|
||||
"max_block_num": int(self.cfg.cache_config.total_block_num),
|
||||
"dec_token_num": int(self.cfg.cache_config.dec_token_num),
|
||||
"available_resource": float(1.0 * available_block_num / self.cfg.cache_config.total_block_num),
|
||||
"max_batch_size": int(available_batch_size),
|
||||
"max_input_token_num": self.cfg.max_num_batched_tokens,
|
||||
"unhandled_request_num": self.engine.scheduler.get_unhandled_request_num(),
|
||||
"available_batch": int(self.engine.resource_manager.available_batch()),
|
||||
}
|
||||
return server_info
|
||||
|
||||
def _recv_external_module_control_instruct(self):
|
||||
"""
|
||||
Receive a multipart message from the control cmd socket.
|
||||
"""
|
||||
while True:
|
||||
try:
|
||||
with self.response_lock:
|
||||
task = self.recv_control_cmd_server.recv_control_cmd()
|
||||
if task is None:
|
||||
time.sleep(0.001)
|
||||
continue
|
||||
logger.info(f"Recieve control task: {task}")
|
||||
task_id_str = task["task_id"]
|
||||
if task["cmd"] == "get_payload":
|
||||
payload_info = self._get_current_server_info()
|
||||
result = {"task_id": task_id_str, "result": payload_info}
|
||||
logger.debug(f"Response for task: {task_id_str}")
|
||||
with self.response_lock:
|
||||
self.recv_control_cmd_server.response_for_control_cmd(task_id_str, result)
|
||||
|
||||
elif task["cmd"] == "get_metrics":
|
||||
metrics_text = get_filtered_metrics(
|
||||
[],
|
||||
extra_register_func=lambda reg: main_process_metrics.register_all(reg, workers=1),
|
||||
)
|
||||
result = {"task_id": task_id_str, "result": metrics_text}
|
||||
logger.debug(f"Response for task: {task_id_str}")
|
||||
with self.response_lock:
|
||||
self.recv_control_cmd_server.response_for_control_cmd(task_id_str, result)
|
||||
elif task["cmd"] == "connect_rdma":
|
||||
self.engine.engine_worker_queue.put_connect_rdma_task(task)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"handle_control_cmd got error: {e}, {traceback.format_exc()!s}")
|
||||
|
||||
def _response_external_module_control_instruct(self):
|
||||
while True:
|
||||
try:
|
||||
result_data = self.engine.engine_worker_queue.get_connect_rdma_task_response()
|
||||
if result_data:
|
||||
task_id_str = result_data["task_id"]
|
||||
result = {"task_id": task_id_str, "result": result_data}
|
||||
logger.info(f"Response for task: {task_id_str}")
|
||||
with self.response_lock:
|
||||
self.recv_control_cmd_server.response_for_control_cmd(task_id_str, result)
|
||||
else:
|
||||
time.sleep(0.001)
|
||||
except Exception as e:
|
||||
logger.error(f"_handle_connect_rdma_results got error: {e}, {traceback.format_exc() !s}")
|
@@ -14,27 +14,26 @@
|
||||
# limitations under the License.
|
||||
"""
|
||||
|
||||
import json
|
||||
import time
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from typing import Dict
|
||||
|
||||
import msgpack
|
||||
import zmq
|
||||
|
||||
from fastdeploy import envs
|
||||
from fastdeploy.engine.request import CompletionOutput, Request, RequestOutput
|
||||
from fastdeploy.inter_communicator import EngineWorkerQueue
|
||||
from fastdeploy.metrics.metrics import main_process_metrics
|
||||
from fastdeploy.utils import get_logger
|
||||
|
||||
logger = get_logger("splitwise_connector", "splitwise_connector.log")
|
||||
|
||||
|
||||
class SplitwiseConnector:
|
||||
"""
|
||||
SplitwiseConnector class for managing and scheduling Splitwise tasks.
|
||||
"""
|
||||
|
||||
def __init__(self, cfg, scheduler, worker_queue, resource_manager):
|
||||
def __init__(self, cfg, scheduler, worker_queue, resource_manager, splitwise_queue):
|
||||
"""
|
||||
Initialize the SplitwiseConnector instance.
|
||||
|
||||
@@ -45,12 +44,20 @@ class SplitwiseConnector:
|
||||
resource_manager (object): Resource manager object.
|
||||
"""
|
||||
self.cfg = cfg
|
||||
if self.cfg.parallel_config.enable_expert_parallel and self.cfg.parallel_config.data_parallel_size > 1:
|
||||
self.logger = get_logger(
|
||||
"splitwise_connector", f"splitwise_connector_{self.cfg.parallel_config.local_data_parallel_id}.log"
|
||||
)
|
||||
else:
|
||||
self.logger = get_logger("splitwise_connector", "splitwise_connector.log")
|
||||
self.scheduler = scheduler
|
||||
self.engine_worker_queue = worker_queue
|
||||
self.resource_manager = resource_manager
|
||||
self.connect_innode_instances = {}
|
||||
self.temp_cache_info = dict()
|
||||
self.current_request_ids = dict()
|
||||
self.splitwise_queue = splitwise_queue
|
||||
self.enable_decode_cache_task = envs.FD_ENABLE_CACHE_TASK == "1"
|
||||
|
||||
if self.cfg.cache_config.pd_comm_port is not None:
|
||||
self.zmq_ctx = zmq.Context()
|
||||
@@ -69,7 +76,7 @@ class SplitwiseConnector:
|
||||
self.router_socket.setsockopt(zmq.SNDHWM, 1000)
|
||||
self.router_socket.setsockopt(zmq.ROUTER_MANDATORY, 1)
|
||||
self.router_socket.bind(f"tcp://*:{self.cfg.cache_config.pd_comm_port[0]}")
|
||||
logger.info(f"bind {self.cfg.cache_config.pd_comm_port}")
|
||||
self.logger.info(f"bind {self.cfg.cache_config.pd_comm_port[0]}")
|
||||
|
||||
self.poller = zmq.Poller()
|
||||
self.poller.register(self.router_socket, zmq.POLLIN)
|
||||
@@ -88,16 +95,16 @@ class SplitwiseConnector:
|
||||
if not socks:
|
||||
continue
|
||||
else:
|
||||
logger.debug(f"receive {socks}")
|
||||
self.logger.debug(f"receive {socks}")
|
||||
|
||||
frames = self.router_socket.recv_multipart()
|
||||
logger.debug(f"frames: {frames}")
|
||||
self.logger.debug(f"frames: {frames}")
|
||||
message = frames[-1]
|
||||
self.io_executor.submit(self._process_message, message)
|
||||
time.sleep(0.001)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Receiver error: {e}")
|
||||
self.logger.error(f"Receiver error: {e}")
|
||||
time.sleep(1)
|
||||
|
||||
def _get_push_socket(self, addr):
|
||||
@@ -109,7 +116,7 @@ class SplitwiseConnector:
|
||||
return sock
|
||||
|
||||
try:
|
||||
logger.info(f"Establishing new connection to {addr}")
|
||||
self.logger.info(f"Establishing new connection to {addr}")
|
||||
sock = self.zmq_ctx.socket(zmq.DEALER)
|
||||
|
||||
# 设置连接参数
|
||||
@@ -128,7 +135,7 @@ class SplitwiseConnector:
|
||||
return sock
|
||||
|
||||
except zmq.ZMQError as e:
|
||||
logger.error(f"Connection to {addr} failed: {e}")
|
||||
self.logger.error(f"Connection to {addr} failed: {e}")
|
||||
|
||||
raise ConnectionError(f"Failed to connect to {addr}") from e
|
||||
|
||||
@@ -137,7 +144,7 @@ class SplitwiseConnector:
|
||||
return
|
||||
|
||||
try:
|
||||
logger.info(f"Sent {msg_type} to {addr}")
|
||||
self.logger.info(f"Sent {msg_type} to {addr}")
|
||||
message = self._serialize_message(msg_type, payload)
|
||||
|
||||
try:
|
||||
@@ -145,18 +152,19 @@ class SplitwiseConnector:
|
||||
sock = self._get_push_socket(addr)
|
||||
sock.send_multipart([b"", message])
|
||||
|
||||
logger.info(f"Sent {msg_type} to {addr}")
|
||||
self.logger.info(f"Sent {msg_type} to {addr}")
|
||||
|
||||
except ConnectionError:
|
||||
logger.warning(f"Connection to {addr} not established")
|
||||
self.logger.warning(f"Connection to {addr} not established")
|
||||
except zmq.Again:
|
||||
logger.warning(f"Send queue full for {addr}")
|
||||
self.logger.warning(f"Send queue full for {addr}")
|
||||
except Exception as e:
|
||||
logger.error(f"Send to {addr} failed: {e}")
|
||||
main_process_metrics.send_cache_failed_num.inc()
|
||||
self.logger.error(f"Send to {addr} failed: {e}")
|
||||
self._close_connection(addr)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Message preparation failed: {e}")
|
||||
self.logger.error(f"Message preparation failed: {e}")
|
||||
|
||||
def _close_connection(self, addr):
|
||||
"""
|
||||
@@ -261,7 +269,7 @@ class SplitwiseConnector:
|
||||
f"{task.disaggregate_info['cache_info']['rdma']['ip']}:"
|
||||
+ f"{task.disaggregate_info['cache_info']['rdma']['port']}"
|
||||
)
|
||||
logger.info(f"send splitwise tasks to port {addr} decode")
|
||||
self.logger.info(f"send splitwise tasks to port {addr} decode")
|
||||
self.current_request_ids[task.request_id] = "init"
|
||||
decode_diagg = task.disaggregate_info["cache_info"]
|
||||
task.disaggregate_info["cache_info"] = self.cfg.disaggregate_info["cache_info"]
|
||||
@@ -289,7 +297,7 @@ class SplitwiseConnector:
|
||||
self.connect_innode_instances[port].put_disaggregated_tasks(("decode", tasks))
|
||||
for task in tasks:
|
||||
task.disaggregate_info["cache_info"]["ipc"]["port"] = port
|
||||
logger.info(f"send splitwise tasks to port {port} decode")
|
||||
self.logger.info(f"send splitwise tasks to port {port} decode")
|
||||
current_port = port
|
||||
return current_port
|
||||
|
||||
@@ -299,7 +307,7 @@ class SplitwiseConnector:
|
||||
"""
|
||||
if not isinstance(tasks_list, list):
|
||||
tasks_list = [tasks_list]
|
||||
logger.info("send first token to port decode")
|
||||
self.logger.info("send first token to port decode")
|
||||
if prefill_msg["transfer_protocol"] == "ipc":
|
||||
port = prefill_msg["cache_info"]["ipc"]["port"]
|
||||
if port not in self.connect_innode_instances:
|
||||
@@ -307,7 +315,7 @@ class SplitwiseConnector:
|
||||
self.connect_innode_instances[port].put_disaggregated_tasks(("decode", tasks_list))
|
||||
else:
|
||||
node = f"{prefill_msg['cache_info']['rdma']['ip']}:{prefill_msg['cache_info']['rdma']['port']}"
|
||||
logger.info(f"send first token to port {node} decode")
|
||||
self.logger.info(f"send first token to port {node} decode")
|
||||
self._send_message(node, "decode", tasks_list)
|
||||
|
||||
def create_connection(self, port):
|
||||
@@ -323,6 +331,22 @@ class SplitwiseConnector:
|
||||
client_id=0,
|
||||
)
|
||||
|
||||
def check_decode_allocated(self, task):
|
||||
if task.disaggregate_info is None:
|
||||
return True, ""
|
||||
if self.enable_decode_cache_task:
|
||||
return True, ""
|
||||
if task.disaggregate_info["role"] != "prefill":
|
||||
return True, ""
|
||||
while self.current_request_ids[task.request_id] == "init":
|
||||
time.sleep(0.001)
|
||||
msg = self.current_request_ids[task.request_id]
|
||||
del self.current_request_ids[task.request_id]
|
||||
if msg == "finished":
|
||||
return True, ""
|
||||
self.logger.error(f"Receive_decode_allocated error: {msg}")
|
||||
return False, msg
|
||||
|
||||
def send_cache_infos(self, tasks, current_id):
|
||||
"""
|
||||
Send cache information to specific port.
|
||||
@@ -339,15 +363,21 @@ class SplitwiseConnector:
|
||||
for i in range(len(tasks)):
|
||||
if tasks[i].disaggregate_info is None:
|
||||
continue
|
||||
logger.info(f"{tasks[i].disaggregate_info}")
|
||||
self.logger.info(f"{tasks[i].disaggregate_info}")
|
||||
if tasks[i].disaggregate_info["role"] == "decode":
|
||||
if tasks[i].disaggregate_info["transfer_protocol"] == "ipc":
|
||||
cache_info = {
|
||||
"request_id": tasks[i].request_id,
|
||||
"device_ids": self.cfg.device_ids.split(","),
|
||||
"transfer_protocol": "ipc",
|
||||
"dest_block_ids": tasks[i].disaggregate_info["block_tables"],
|
||||
}
|
||||
if tasks[i].get("error_msg", None) is not None:
|
||||
cache_info = {
|
||||
"request_id": tasks[i].request_id,
|
||||
"error_msg": tasks[i].get("error_msg"),
|
||||
}
|
||||
else:
|
||||
cache_info = {
|
||||
"request_id": tasks[i].request_id,
|
||||
"device_ids": self.cfg.device_ids.split(","),
|
||||
"transfer_protocol": "ipc",
|
||||
"dest_block_ids": tasks[i].disaggregate_info["block_tables"],
|
||||
}
|
||||
if tasks[i].disaggregate_info["cache_info"]["ipc"]["port"] not in temp_cache_info:
|
||||
temp_cache_info[tasks[i].disaggregate_info["cache_info"]["ipc"]["port"]] = []
|
||||
temp_cache_info[tasks[i].disaggregate_info["cache_info"]["ipc"]["port"]].append(cache_info)
|
||||
@@ -356,14 +386,20 @@ class SplitwiseConnector:
|
||||
f"{tasks[i].disaggregate_info['cache_info']['rdma']['ip']}:"
|
||||
+ f"{tasks[i].disaggregate_info['cache_info']['rdma']['port']}"
|
||||
)
|
||||
cache_info = {
|
||||
"request_id": tasks[i].request_id,
|
||||
"device_ids": self.cfg.device_ids.split(","),
|
||||
"ip": self.cfg.host_ip,
|
||||
"rdma_ports": self.cfg.disaggregate_info["cache_info"]["rdma"]["rdma_port"],
|
||||
"transfer_protocol": "rdma",
|
||||
"dest_block_ids": tasks[i].disaggregate_info["block_tables"],
|
||||
}
|
||||
if tasks[i].get("error_msg", None) is not None:
|
||||
cache_info = {
|
||||
"request_id": tasks[i].request_id,
|
||||
"error_msg": tasks[i].get("error_msg"),
|
||||
}
|
||||
else:
|
||||
cache_info = {
|
||||
"request_id": tasks[i].request_id,
|
||||
"device_ids": self.cfg.device_ids.split(","),
|
||||
"ip": self.cfg.host_ip,
|
||||
"rdma_ports": self.cfg.disaggregate_info["cache_info"]["rdma"]["rdma_port"],
|
||||
"transfer_protocol": "rdma",
|
||||
"dest_block_ids": tasks[i].disaggregate_info["block_tables"],
|
||||
}
|
||||
if addr not in temp_cache_info:
|
||||
temp_cache_info[addr] = []
|
||||
|
||||
@@ -390,7 +426,7 @@ class SplitwiseConnector:
|
||||
else:
|
||||
if len(temp_cache_info):
|
||||
for k, v in temp_cache_info.items():
|
||||
logger.info(f"{k} {v}")
|
||||
self.logger.info(f"{k} {v}")
|
||||
if ":" in str(k):
|
||||
self._send_message(k, "cache_sync", v)
|
||||
else:
|
||||
@@ -406,13 +442,19 @@ class SplitwiseConnector:
|
||||
if msg_type == "decode" or msg_type == "prefill":
|
||||
payload = [output.to_dict() for output in payload]
|
||||
|
||||
json_data = json.dumps({"type": msg_type, "payload": payload}).encode("utf-8")
|
||||
req_ids = [task["request_id"] for task in payload]
|
||||
self.logger.info(f"send message {msg_type} {req_ids}")
|
||||
|
||||
json_data = msgpack.packb({"type": msg_type, "payload": payload})
|
||||
|
||||
return json_data
|
||||
|
||||
def _deserialize_message(self, data: bytes):
|
||||
|
||||
# JSON反序列化
|
||||
message = json.loads(data.decode("utf-8"))
|
||||
message = msgpack.unpackb(data)
|
||||
req_ids = [task["request_id"] for task in message["payload"]]
|
||||
self.logger.info(f"recv message type {message['type']} for {req_ids}")
|
||||
return message["type"], message["payload"]
|
||||
|
||||
def _process_message(self, message: bytes):
|
||||
@@ -421,7 +463,7 @@ class SplitwiseConnector:
|
||||
"""
|
||||
try:
|
||||
msg_type, payload = self._deserialize_message(message)
|
||||
logger.info(f"{msg_type}")
|
||||
self.logger.info(f"{msg_type}")
|
||||
|
||||
if msg_type == "prefill":
|
||||
self._handle_prefill(payload)
|
||||
@@ -429,11 +471,16 @@ class SplitwiseConnector:
|
||||
self._handle_decode(payload)
|
||||
elif msg_type == "cache_sync":
|
||||
for task in payload:
|
||||
del self.current_request_ids[task["request_id"]]
|
||||
self.engine_worker_queue.put_cache_info(payload)
|
||||
self.logger.info(f"cache_sync task: {task}")
|
||||
current_status = task.get("error_msg", "finished")
|
||||
self.current_request_ids[task["request_id"]] = current_status
|
||||
if self.enable_decode_cache_task:
|
||||
del self.current_request_ids[task["request_id"]]
|
||||
if current_status == "finished":
|
||||
self.engine_worker_queue.put_cache_info(payload)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Message processing failed: {e}")
|
||||
self.logger.error(f"Message processing failed: {e}")
|
||||
|
||||
def _handle_prefill(self, tasks):
|
||||
"""
|
||||
@@ -441,7 +488,9 @@ class SplitwiseConnector:
|
||||
"""
|
||||
|
||||
tasks_data = [Request.from_dict(task) for task in tasks]
|
||||
self.engine_worker_queue.put_disaggregated_tasks(("decode", tasks_data))
|
||||
req_ids = [task["request_id"] for task in tasks]
|
||||
self.splitwise_queue.append(("decode", tasks_data))
|
||||
self.logger.info(f"{req_ids} received prefill data")
|
||||
|
||||
def _handle_decode(self, payload):
|
||||
"""
|
||||
@@ -456,8 +505,13 @@ class SplitwiseConnector:
|
||||
index=task["outputs"]["index"],
|
||||
send_idx=0,
|
||||
token_ids=task["outputs"]["token_ids"],
|
||||
draft_token_ids=task["outputs"]["draft_token_ids"],
|
||||
),
|
||||
finished=True,
|
||||
error_code=task["error_code"],
|
||||
error_msg=task["error_msg"],
|
||||
)
|
||||
)
|
||||
self.engine_worker_queue.put_disaggregated_tasks(("decode", tasks))
|
||||
req_ids = [task["request_id"] for task in payload]
|
||||
self.splitwise_queue.append(("decode", tasks))
|
||||
self.logger.info(f"{req_ids} received decode data")
|
||||
|
@@ -43,6 +43,7 @@ from fastdeploy.model_executor.layers.sample.sampler import Sampler, Speculative
|
||||
from fastdeploy.model_executor.model_loader import get_model_loader
|
||||
from fastdeploy.model_executor.ops.gpu import (
|
||||
recover_decode_task,
|
||||
set_data_ipc,
|
||||
set_value_by_flags_and_idx,
|
||||
share_external_data,
|
||||
)
|
||||
@@ -904,7 +905,7 @@ class GPUModelRunner(ModelRunnerBase):
|
||||
)
|
||||
local_rank = self.local_rank % self.parallel_config.tensor_parallel_size
|
||||
|
||||
if not profile and (self.cache_config.enable_prefix_caching or self.parallel_config.splitwise_role != "mixed"):
|
||||
if not profile and self.parallel_config.splitwise_role != "mixed":
|
||||
cache_kvs_list = []
|
||||
for i in range(self.model_config.num_hidden_layers):
|
||||
key_cache = paddle.empty(shape=[], dtype=cache_type)
|
||||
@@ -930,6 +931,15 @@ class GPUModelRunner(ModelRunnerBase):
|
||||
fill_value=0,
|
||||
dtype=cache_type,
|
||||
)
|
||||
if self.cache_config.enable_prefix_caching:
|
||||
set_data_ipc(
|
||||
cache_kvs[f"key_caches_{i}"],
|
||||
f"key_caches_{i}_rank{local_rank}.device{self.device_id}",
|
||||
)
|
||||
set_data_ipc(
|
||||
cache_kvs[f"value_caches_{i}"],
|
||||
f"value_caches_{i}_rank{local_rank}.device{self.device_id}",
|
||||
)
|
||||
self.share_inputs["caches"] = list(cache_kvs.values())
|
||||
for value in cache_kvs.values():
|
||||
del value
|
||||
@@ -1138,6 +1148,8 @@ class GPUModelRunner(ModelRunnerBase):
|
||||
if task.chunk_idx > len(task.prefill_chunk_info):
|
||||
continue
|
||||
self.restore_chunked_prefill_request[task.request_id] = task
|
||||
if len(self.restore_chunked_prefill_request) > 0:
|
||||
self.share_inputs["not_need_stop"][0] = True
|
||||
|
||||
for id, task in list(self.restore_chunked_prefill_request.items()):
|
||||
idx = task.idx
|
||||
@@ -1182,7 +1194,7 @@ class GPUModelRunner(ModelRunnerBase):
|
||||
self.share_inputs["seq_lens_encoder"][idx : idx + 1] = token_chunk_size
|
||||
self.share_inputs["prompt_lens"][idx : idx + 1] += token_chunk_size
|
||||
self.share_inputs["step_idx"][idx : idx + 1] = 0
|
||||
|
||||
self.share_inputs["stop_flags"][idx : idx + 1] = False
|
||||
if self.speculative_decoding and self.proposer.is_chunk_prefill_enabled():
|
||||
self.proposer.update_task_chunk_prefill(task)
|
||||
task.chunk_idx += 1
|
||||
@@ -1256,17 +1268,17 @@ class GPUModelRunner(ModelRunnerBase):
|
||||
We plan to replace it with 'ModelForwardBatch'.
|
||||
intermediate_tensors:
|
||||
"""
|
||||
# 1. Prepare inputs of model and sampler.
|
||||
skip_idx_list = self._get_skip_idx(model_forward_batch)
|
||||
self._prepare_inputs()
|
||||
self.sampler.pre_process(skip_idx_list)
|
||||
|
||||
# NOTE(wufeisheng): If `not_need_stop`` is False, it means the current worker is in an idle state.
|
||||
# This logic is not used in TP (Tensor Parallelism) mode. However, in EP (Expert Parallelism) mode,
|
||||
# when there is data on other runner, the current runner is required to execute part of the model.
|
||||
if not self.not_need_stop():
|
||||
self._execute_empty_input()
|
||||
return None
|
||||
start_time = time.time()
|
||||
# 1. Prepare inputs of model and sampler.
|
||||
skip_idx_list = self._get_skip_idx(model_forward_batch)
|
||||
self._prepare_inputs()
|
||||
self.sampler.pre_process(skip_idx_list)
|
||||
|
||||
# 2. Padding inputs for cuda graph
|
||||
self.padding_cudagraph_inputs()
|
||||
@@ -1397,6 +1409,8 @@ class GPUModelRunner(ModelRunnerBase):
|
||||
|
||||
self._update_chunked_prefill(model_forward_batch)
|
||||
self._add_cache(model_forward_batch)
|
||||
end_time = time.time()
|
||||
logger.debug(f"execute one step cost time: {end_time-start_time:.3f} s")
|
||||
return None
|
||||
|
||||
def _add_cache(self, model_forward_batch) -> None:
|
||||
@@ -1507,12 +1521,12 @@ class GPUModelRunner(ModelRunnerBase):
|
||||
|
||||
hidden_dim = self.model_config.head_dim * self.model_config.kv_num_heads
|
||||
# NOTE(liuzichang): Implement multi-layer MTP architecture in the future
|
||||
num_layers = (
|
||||
num_hidden_layers = (
|
||||
self.model_config.num_hidden_layers + self.speculative_config.num_gpu_block_expand_ratio
|
||||
if self.speculative_method in ["mtp"]
|
||||
else self.model_config.num_hidden_layers
|
||||
)
|
||||
required_memory = byte_of_dtype * 2 * (self.cache_config.block_size * hidden_dim) * num_layers # k + v
|
||||
required_memory = byte_of_dtype * 2 * (self.cache_config.block_size * hidden_dim) * num_hidden_layers # k + v
|
||||
return required_memory
|
||||
|
||||
def not_need_stop(self) -> bool:
|
||||
|
@@ -150,7 +150,7 @@ class PaddleDisWorkerProc:
|
||||
# Initialize task queue
|
||||
task_address = (
|
||||
self.parallel_config.pod_ip,
|
||||
self.parallel_config.engine_worker_queue_port,
|
||||
self.parallel_config.engine_worker_queue_port + self.parallel_config.expert_parallel_rank,
|
||||
)
|
||||
self.max_chips_per_node = 16 if current_platform.is_iluvatar() else 8
|
||||
self.task_queue = TaskQueue(
|
||||
@@ -252,9 +252,11 @@ class PaddleDisWorkerProc:
|
||||
for req_dict, bsz in tasks:
|
||||
num_running_requests = int(bsz)
|
||||
req_dicts.extend(req_dict)
|
||||
req_ids = [req.request_id for req in req_dicts]
|
||||
|
||||
logger.info(
|
||||
f"Rank: {self.local_rank}, num_running_requests: {num_running_requests}, "
|
||||
f"num_insert_requests: {len(req_dicts)}"
|
||||
f"num_insert_requests: {len(req_dicts)}, req_ids: {req_ids}"
|
||||
)
|
||||
# Process prefill inputs
|
||||
self.worker.preprocess_new_task(req_dicts)
|
||||
@@ -408,7 +410,7 @@ class PaddleDisWorkerProc:
|
||||
|
||||
logger.info(f"------- num_blocks_global: {num_blocks_local} --------")
|
||||
# wait engine launch cache_manager
|
||||
if self.cache_config.enable_prefix_caching or self.parallel_config.splitwise_role != "mixed":
|
||||
if self.parallel_config.splitwise_role != "mixed":
|
||||
launched_cache_manager_signal_data = np.zeros([1], dtype=np.int32)
|
||||
self.launched_cache_manager_signal = IPCSignal(
|
||||
name="launched_cache_manager_signal",
|
||||
|
Reference in New Issue
Block a user