[Feature] Support pd ep deployment with yiyan adapter (#4029)

* [Feature] Support mixed deployment with yiyan adapter in release2.2

* fix metrics

* add unit test

* add unit test

* add unit test

* Support pd ep deployment with yiyan adapter

* Support pd ep deployment with yiyan adapter

* refactor cache messager

* support scheduler v1 in PD

* support pd v1 + chunk prefill

* support pd v1 + chunk prefill

* add eplb

* support eplb

* support eplb

* support eplb

* support v1

* fix

* fix

* fix bug

* remove eplb support

* support prefix cache in P

* fix bug

* fix bug

* support one stop in V1

* fix bug

* fix ci

* fix ci

* fix

* fix

* fix

* fix

* fix

---------

Co-authored-by: YuBaoku <49938469+EmmonsCurse@users.noreply.github.com>
Authored by chenjian on 2025-09-22 16:41:38 +08:00, committed by GitHub
parent 9845f0d010
commit 918ccdb123
22 changed files with 1838 additions and 343 deletions

View File

@@ -32,7 +32,8 @@ __global__ void update_inputs_kernel_v1(bool *not_need_stop,
                                        const int max_bsz,
                                        const int input_ids_stride,
                                        const int block_num_per_seq,
-                                       const int block_size) {
+                                       const int block_size,
+                                       bool prefill_one_step_stop) {
  int thread_idx = threadIdx.x;
  typedef cub::BlockReduce<int64_t, THREADBLOCK_SIZE> BlockReduce;
  __shared__ typename BlockReduce::TempStorage temp_storage;
@@ -54,23 +55,32 @@ __global__ void update_inputs_kernel_v1(bool *not_need_stop,
      seq_lens_encoder[thread_idx] = 0;
    } else {
      if (seq_lens_this_time[thread_idx] + seq_lens_decoder[thread_idx] >= prompt_lens[thread_idx]) {
-       // decoding
-       seq_lens_decoder[thread_idx] += seq_lens_this_time[thread_idx];
-       seq_lens_this_time[thread_idx] = 1;
-       seq_lens_encoder[thread_idx] = 0;
-       int64_t *input_ids_now = input_ids + thread_idx * input_ids_stride;
-       input_ids_now[0] = next_tokens[thread_idx];
+       if (prefill_one_step_stop) {
+         // prefill done, stop
+         stop_flags[thread_idx] = true;
+         seq_lens_this_time[thread_idx] = 0;
+         seq_lens_decoder[thread_idx] = 0;
+         seq_lens_encoder[thread_idx] = 0;
+         stop_flag_now_int = 1;
+       } else {
+         // decoding
+         seq_lens_decoder[thread_idx] += seq_lens_this_time[thread_idx];
+         seq_lens_this_time[thread_idx] = 1;
+         seq_lens_encoder[thread_idx] = 0;
+         int64_t *input_ids_now = input_ids + thread_idx * input_ids_stride;
+         input_ids_now[0] = next_tokens[thread_idx];
          // to judge whether block is not enough
          int *block_table_now = block_tables + thread_idx * block_num_per_seq;
          if (seq_lens_this_time[thread_idx] != 0 && block_table_now[seq_lens_decoder[thread_idx] / block_size] == -1) {
            // should be scheduled by server
            is_block_step[thread_idx] = true;
            seq_lens_this_time[thread_idx]= 0;
            stop_flags[thread_idx] = true;
            step_seq_lens_decoder[thread_idx] = seq_lens_decoder[thread_idx];
            seq_lens_decoder[thread_idx] = 0;
            stop_flag_now_int = 1;
          }
+       }
      } else
      {
@@ -110,6 +120,12 @@ void UpdateInputesV1(const paddle::Tensor &stop_flags,
#else
  auto cu_stream = input_ids.stream();
#endif
+ bool prefill_one_step_stop = false;
+ if (const char *env_p = std::getenv("PREFILL_NODE_ONE_STEP_STOP_V1")) {
+   if (env_p[0] == '1') {
+     prefill_one_step_stop = true;
+   }
+ }
  const int max_bsz = stop_flags.shape()[0];
  const int now_bsz = seq_lens_this_time.shape()[0];
  const int input_ids_stride = input_ids.shape()[1];
@@ -133,7 +149,8 @@ void UpdateInputesV1(const paddle::Tensor &stop_flags,
      max_bsz,
      input_ids_stride,
      block_num_per_seq,
-     block_size);
+     block_size,
+     prefill_one_step_stop);
  auto not_need_stop_cpu =
      not_need_stop_gpu.copy_to(not_need_stop.place(), false);
  bool *not_need_stop_data = const_cast<bool *>(not_need_stop.data<bool>());

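The new prefill_one_step_stop path is controlled solely by the PREFILL_NODE_ONE_STEP_STOP_V1 environment variable read in UpdateInputesV1 above. A minimal sketch of enabling it for a prefill-role instance, assuming the flag only needs to be present in the worker's environment before startup (the launch script name is a placeholder, not part of this commit):

import os
import subprocess

# Inherit the current environment and switch on one-step stop for the prefill node.
env = dict(os.environ)
env["PREFILL_NODE_ONE_STEP_STOP_V1"] = "1"  # kernel then stops a request right after prefill produces its first token

# "launch_prefill.sh" is a hypothetical wrapper that starts the prefill-role FastDeploy instance.
subprocess.Popen(["bash", "launch_prefill.sh"], env=env)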
View File

@@ -14,7 +14,10 @@
# limitations under the License.
"""
+import argparse
+import json
import math
+import queue
import threading
import time
import traceback
@@ -23,16 +26,72 @@ import numpy as np
import paddle

from fastdeploy.cache_manager.transfer_factory import IPCCommManager, RDMACommManager
+from fastdeploy.config import SpeculativeConfig
from fastdeploy.inter_communicator import (
    EngineWorkerQueue,
    IPCSignal,
    shared_memory_exists,
)
-from fastdeploy.utils import get_logger
+from fastdeploy.model_executor.ops.gpu import get_output_kv_signal, set_data_ipc
+from fastdeploy.utils import envs, get_logger

logger = get_logger("cache_messager", "cache_messager.log")

def parse_args():
"""
    Parse arguments from the command line.
"""
parser = argparse.ArgumentParser("Cache Messager")
parser.add_argument(
"--splitwise_role",
type=str,
default="mixed",
help="splitwise role, can be decode, prefill or mixed",
)
parser.add_argument("--rank", type=int, default=0, help="current rank")
parser.add_argument("--device_id", type=int, default=0, help="device id")
parser.add_argument("--num_layers", type=int, default=1, help="model num layers")
parser.add_argument("--head_dim", type=int, default=1, help="model head dim")
parser.add_argument("--kv_num_head", type=int, default=1, help="model kv num head")
parser.add_argument("--rdma_port", type=str, default="", help="rmda port")
parser.add_argument("--mp_num", type=int, default=1, help="number of model parallel")
parser.add_argument("--engine_pid", type=str, default=None, help="engine pid")
parser.add_argument(
"--protocol",
type=str,
default="ipc",
help="cache transfer protocol, only surport ipc now",
)
parser.add_argument("--pod_ip", type=str, default="0.0.0.0", help="pod ip")
parser.add_argument("--cache_queue_port", type=int, default=9924, help="cache queue port")
parser.add_argument(
"--engine_worker_queue_port",
type=int,
default=9923,
help="engine worker queue port",
)
parser.add_argument("--num_gpu_blocks", type=int, default=1, help="gpu cache block number")
parser.add_argument("--block_size", type=int, default=64, help="cache block size(tokens)")
parser.add_argument(
"--cache_dtype",
type=str,
default="bfloat16",
choices=["uint8", "bfloat16"],
help="cache dtype",
)
parser.add_argument(
"--speculative_config",
type=json.loads,
default="{}",
help="speculative config",
)
parser.add_argument("--local_data_parallel_id", type=int, default=0)
args = parser.parse_args()
return args
class CacheMessager:
    """
    CacheMessager is used to send the cache data between the engine worker and the cache server.
@@ -69,11 +128,6 @@ class CacheMessager:
        Returns:
            None
        """
-       assert splitwise_role in [
-           "prefill",
-           "decode",
-       ], "splitwise_role must be prefill or decode"
        self.splitwise_role = splitwise_role
        self.gpu_cache_kvs = gpu_cache_kvs
        self.rank = rank
@@ -147,15 +201,16 @@ class CacheMessager:
        self.gpu_id = gpu_id
        self.cache_info = dict()
-       self.dp_rank_id = self.rank + local_data_parallel_id * self.nranks
+       self.rank_id = self.rank + local_data_parallel_id * self.nranks

-       layerwise_send_cache_thread = threading.Thread(target=self._prefill_layerwise_send_cache_thread)
-       layerwise_send_cache_thread.daemon = True
-       layerwise_send_cache_thread.start()
+       if self.splitwise_role != "mixed":
+           connect_rdma_thread = threading.Thread(target=self._handle_connect_task)
+           connect_rdma_thread.daemon = True
+           connect_rdma_thread.start()

        logger.info(f"cache messager init finished, use {transfer_protocol}")

-   def _prefill_layerwise_send_cache_thread(self):
+   def prefill_layerwise_send_cache_thread(self):
        """
        layerwise_send_cache_thread:
        send cache to other instance
@@ -163,23 +218,23 @@ class CacheMessager:
        try:
            prefilled_step_idx_data = np.zeros(shape=[1], dtype=np.int32)
            prefilled_layer_idx_data = np.zeros(shape=[1], dtype=np.int32)
-           prefilled_layer_name = f"splitwise_complete_prefilled_layer_{self.dp_rank_id}.{self.gpu_id}"
-           prefilled_step_name = f"splitwise_complete_prefilled_step_{self.dp_rank_id}.{self.gpu_id}"
+           prefilled_layer_name = f"splitwise_complete_prefilled_step_{self.rank_id}.{self.gpu_id}"
+           prefilled_step_name = f"splitwise_complete_prefilled_step_{self.rank_id}.{self.gpu_id}"
            step_shm_value = IPCSignal(
-               name=f"splitwise_complete_prefilled_step_{self.dp_rank_id}",
+               name=f"splitwise_complete_prefilled_step_{self.rank_id}",
                array=prefilled_step_idx_data,
                dtype=np.int32,
                suffix=self.gpu_id,
                create=not shared_memory_exists(prefilled_step_name),
            )
            layer_shm_value = IPCSignal(
-               name=f"splitwise_complete_prefilled_layer_{self.dp_rank_id}",
+               name=f"splitwise_complete_prefilled_layer_{self.rank_id}",
                array=prefilled_layer_idx_data,
                dtype=np.int32,
                suffix=self.gpu_id,
                create=not shared_memory_exists(prefilled_layer_name),
            )
-           logger.info(f"splitwise_complete_prefilled_step_{self.dp_rank_id}, gpu_id: {self.gpu_id}")
+           logger.info(f"splitwise_complete_prefilled_step_{self.rank_id}, gpu_id: {self.gpu_id}")
            step_shm_value.value[0] = -1
            layer_shm_value.value[0] = -1
@@ -187,6 +242,9 @@ class CacheMessager:
            self.last_step_idx = -1
            self.last_layer_idx = -1  # int32
+           max_step_idx = 100003
+           engine_recycled_count = 0

            while True:
                cache_info = self.engine_worker_queue.get_cache_info()
@@ -202,11 +260,9 @@ class CacheMessager:
                                -len(current_info["dest_block_ids"]) :
                            ]
                            current_info["src_block_ids"] = current_src_blocks
-                           current_info["current_layer_ids"] = 0
                            current_info["status"] = "init"
                            logger.info(f"start cache_infos: {current_info}")
                            self.cache_info[info["request_id"]] = current_info
-                           self.last_step_idx = min(self.last_step_idx, current_info["current_id"])
                        else:
                            self.cache_info[info["request_id"]] = info
                prefilled_layer_idx = layer_shm_value.value[0]
@@ -223,7 +279,18 @@ class CacheMessager:
                if not self.cache_info:
                    time.sleep(0.001)
                    continue
-               logger.debug(f"prefilled_layer_idx: {prefilled_layer_idx}, prefilled_step_idx: {prefilled_step_idx}")
+               if self.last_step_idx > prefilled_step_idx:
+                   engine_recycled_count += 1
+               self.last_step_idx = prefilled_step_idx  # only copy value read from shm memory
+               prefilled_step_idx = (
+                   prefilled_step_idx + max_step_idx * engine_recycled_count
+               )  # remap prefilled_step_idx for comparison
+               logger.debug(
+                   f"prefilled_layer_idx: {prefilled_layer_idx}, prefilled_step_idx in shm: {self.last_step_idx},"
+                   f"prefilled_step_idx: {prefilled_step_idx} engine_recycled_count {engine_recycled_count}"
+               )
                for req_id, item in list(self.cache_info.items()):
                    if "status" not in item:
                        continue
@@ -294,12 +361,493 @@ class CacheMessager:
logger.info(f"finish write cache {item['request_id']}") logger.info(f"finish write cache {item['request_id']}")
self.engine_worker_queue.finish_request_barrier.wait() self.engine_worker_queue.finish_request_barrier.wait()
if self.rank == 0: if self.rank == 0:
# to do: robust in TP: here we assume all status in tp are the same. If wrong, all wrong. If ok, all ok.
self.engine_worker_queue.put_finished_req([(item["request_id"], "finished")]) self.engine_worker_queue.put_finished_req([(item["request_id"], "finished")])
logger.info(f"put write cache {item['request_id']}") logger.info(f"put write cache {item['request_id']}")
del self.cache_info[req_id] del self.cache_info[req_id]
self.last_layer_idx = prefilled_layer_idx
self.last_step_idx = prefilled_step_idx
self.last_layer_idx = prefilled_layer_idx
except Exception as e: except Exception as e:
logger.error(f"prefill layerwise send cache thread has exception: {e}, {str(traceback.format_exc())}") logger.error(f"prefill layerwise send cache thread has exception: {e}, {str(traceback.format_exc())}")
def _handle_connect_task(self):
while True:
try:
task = self.engine_worker_queue.get_connect_rdma_task()
if task is None:
time.sleep(0.001)
continue
logger.info(f"_handle_connect_task recv task: {task}")
task_id = task["task_id"]
ip, rdma_port = task["ip"], task["rdma_port"]
status = self.messager["rdma"].connect(ip, rdma_port)
if not status:
response = {"task_id": task_id, "success": False}
else:
response = {"task_id": task_id, "success": True}
self.engine_worker_queue.put_connect_rdma_task_response(response)
except Exception as e:
logger.error(f"handle_connect_task has exception: {e}")
class CacheMessagerV1:
"""
CacheMessager is used to send the cache data between the engine worker and the cache server.
"""
def __init__(
self,
splitwise_role,
transfer_protocol,
pod_ip,
engine_worker_queue_port,
local_data_parallel_id,
gpu_cache_kvs,
rank,
nranks,
num_layers,
gpu_id=0,
block_size=64,
rdma_port=None,
):
"""
Initialize the CacheMessager object.
Args:
splitwise_role (str): splitwise_role only can be 'prefill' or 'decode'.
transfer_protocol (str): support ipc and rdma
engine_worker_queue_port (int): engine_worker_queue port
gpu_cache_kvs (dict): GPU kv cache
rank (int): current rank
nranks (int): global rank number
num_layers (int): model layer number
gpu_id (int, optional): GPU ID
rdma_port (int, optional): RDMA port
Returns:
None
"""
self.splitwise_role = splitwise_role
self.gpu_cache_kvs = gpu_cache_kvs
self.rank = rank
self.nranks = nranks
address = (pod_ip, engine_worker_queue_port)
self.engine_worker_queue = EngineWorkerQueue(
address=address,
is_server=False,
num_client=self.nranks,
client_id=self.rank,
local_data_parallel_id=local_data_parallel_id,
)
self.block_size = block_size
transfer_protocol = transfer_protocol.split(",")
logger.info(f"splitwise role: {splitwise_role}, {transfer_protocol}" f"rank: {rank}")
# 1. initialize the cache_k_ptr_list and cache_v_ptr_list
self.num_layers = num_layers
cache_k_ptr_list = []
cache_v_ptr_list = []
cache_k = []
cache_v = []
self.messager = {}
for layer_idx in range(self.num_layers):
key_cache = self.gpu_cache_kvs[f"key_caches_{layer_idx}_rank{self.rank}_device{gpu_id}"]
val_cache = self.gpu_cache_kvs[f"value_caches_{layer_idx}_rank{self.rank}_device{gpu_id}"]
cache_k.append(key_cache)
cache_v.append(val_cache)
cache_k_ptr_list.append(key_cache.data_ptr())
cache_v_ptr_list.append(val_cache.data_ptr())
cache_k_ptr_list = np.array(cache_k_ptr_list)
cache_v_ptr_list = np.array(cache_v_ptr_list)
# 2. initialize the block_bytes
cache_shape = key_cache.shape
max_block_num = cache_shape[0]
block_bytes = math.prod(cache_shape[1:])
if key_cache.dtype == paddle.bfloat16:
block_bytes *= 2
logger.info(
f"layers {num_layers} cache_shape: {cache_shape}, max_block_num: {max_block_num}, "
f"block_bytes: {block_bytes}, dtype: {key_cache.dtype}"
)
self.block_bytes = block_bytes
# 3. initialize the messager
for protocol in transfer_protocol:
if protocol == "ipc":
self.messager[protocol] = IPCCommManager(
self.rank,
gpu_id,
cache_k,
cache_v,
)
local_device_id = int(str(cache_k[0].place)[-2])
logger.info(f"done create ipc_comm with local_device_id:{local_device_id}, ")
elif protocol == "rdma":
logger.info(f"splitwise_role rdma: {self.splitwise_role}, rank: {self.rank}, gpu_id: {gpu_id}")
self.messager[protocol] = RDMACommManager(
splitwise_role,
rank,
gpu_id,
cache_k_ptr_list,
cache_v_ptr_list,
max_block_num,
block_bytes,
rdma_port,
)
self.gpu_id = gpu_id
self.cache_info = dict()
self.rank_id = self.rank + local_data_parallel_id * self.nranks
self.engine_cache_task_thread_lock = threading.Lock()
self.engine_cache_tasks = [dict() for _ in range(512)]
self.idx_cache_task_dict = {}
self.cache_prefilled_engine_ids_queue = queue.Queue() # keep batch slot index for each prefill step
if splitwise_role == "prefill":
consume_signals_thread = threading.Thread(target=self.consume_signals)
consume_signals_thread.daemon = True
consume_signals_thread.start()
add_cache_task_thread = threading.Thread(target=self._add_cache_task_thread)
add_cache_task_thread.daemon = True
add_cache_task_thread.start()
if self.splitwise_role != "mixed":
connect_rdma_thread = threading.Thread(target=self._handle_connect_task)
connect_rdma_thread.daemon = True
connect_rdma_thread.start()
logger.info(f"cache messager init finished, use {transfer_protocol}")
def _add_cache_task_thread(self):
while True:
try:
cache_info = self.engine_worker_queue.get_cache_info()
self.engine_worker_queue.finish_add_cache_task_barrier.wait()
finished_add_cache_task_req_ids = []
if cache_info:
for info in cache_info:
if info["request_id"] in self.cache_info:
self.cache_info[info["request_id"]].update(info)
current_info = self.cache_info[info["request_id"]]
assert "dest_block_ids" in current_info and "src_block_ids" in current_info
finished_add_cache_task_req_ids.append(info["request_id"])
decode_cached_block_num = len(current_info["src_block_ids"]) - len(
current_info["dest_block_ids"]
)
padding_decode_block_ids = [-1 for i in range(decode_cached_block_num)] + current_info[
"dest_block_ids"
]
current_info["dest_block_ids"] = padding_decode_block_ids
current_info["decode_cached_tokens"] = decode_cached_block_num * self.block_size
current_info["sended_layer_id"] = -1
current_info["sended_block_num"] = current_info["decode_cached_tokens"] // self.block_size
current_info["status"] = "init"
logger.info(f"finish add cache task: {current_info}")
self.cache_info[info["request_id"]] = current_info
self.idx_cache_task_dict[current_info["current_id"]] = current_info
else:
self.cache_info[info["request_id"]] = info
if self.rank == 0 and finished_add_cache_task_req_ids:
self.engine_worker_queue.put_finished_add_cache_task_req(finished_add_cache_task_req_ids)
else:
time.sleep(0.001)
except Exception as e:
logger.info(f"add cache task occured error: {e}, {traceback.format_exc()!s}.")
def prefill_layerwise_send_cache_thread(self):
"""
layerwise_send_cache_thread:
send cache to other instance
"""
while True:
try:
engine_indexes = self.cache_prefilled_engine_ids_queue.get()
self.engine_worker_queue.finish_request_barrier.wait()
block_start_end_list = []
current_prefilled_token_num_list = []
for engine_index in engine_indexes:
assert engine_index in self.idx_cache_task_dict
block_id_start = self.idx_cache_task_dict[engine_index]["sended_block_num"]
prefilled_token_num = self.engine_cache_tasks[engine_index]["prefilled_token_num"]
if (
prefilled_token_num == self.idx_cache_task_dict[engine_index]["need_prefill_tokens"]
): # all chunks have been prefilled
block_id_end = len(self.idx_cache_task_dict[engine_index]["src_block_ids"])
else:
block_id_end = prefilled_token_num // self.block_size # [block_id_start, block_id_end)
block_start_end_list.append((block_id_start, block_id_end))
current_prefilled_token_num_list.append(prefilled_token_num)
while True: # from layer0 to last layer
sended_layer_idx = self.idx_cache_task_dict[engine_indexes[0]]["sended_layer_id"]
start_layer_idx = sended_layer_idx + 1
with self.engine_cache_task_thread_lock: # to check end_layer_idx
prefilled_layer_idx = self.engine_cache_tasks[engine_indexes[0]]["prefilled_layer_idx"]
if sended_layer_idx > prefilled_layer_idx: # computation must in next chunk
logger.info(
f"current_prefilled_token_num_list[0] {current_prefilled_token_num_list[0]} prefilled_token_num {self.engine_cache_tasks[engine_indexes[0]]['prefilled_token_num']}"
)
assert (
current_prefilled_token_num_list[0]
< self.engine_cache_tasks[engine_indexes[0]]["prefilled_token_num"]
), "when sended_layer_idx > prefilled_layer_idx, must be in next chunk, but not, sth wrong"
end_layer_idx = self.num_layers - 1 # [start_layer_idx, end_layer_idx)
else:
end_layer_idx = prefilled_layer_idx
if sended_layer_idx == prefilled_layer_idx: # computation not in next layer
time.sleep(0.01)
for layer_idx in range(start_layer_idx, end_layer_idx + 1):
for i, (block_id_start, block_id_end) in enumerate(block_start_end_list):
engine_index = engine_indexes[i]
task = self.idx_cache_task_dict[engine_index]
req_id = task["request_id"]
if (
block_id_start >= block_id_end
): # no blocks need to transfer for this request in this chunk
task["sended_layer_id"] += 1
assert task["sended_layer_id"] == layer_idx
if task["sended_layer_id"] == self.num_layers - 1:
task["sended_layer_id"] = -1
continue
else:
current_transfer_protocol = task["transfer_protocol"]
if task["transfer_protocol"] == "rdma":
target_ip = task["ip"]
target_id = int(task["rdma_ports"][self.rank])
if task["status"] == "error":
continue
status = self.messager[current_transfer_protocol].connect(target_ip, target_id)
if not status:
logger.error(f"connect to {target_ip}:{target_id} failed")
task["status"] = "connection error"
continue
elif task["transfer_protocol"] == "ipc":
target_ip = "0.0.0.0"
target_id = int(task["device_ids"][self.rank])
src_block_ids = task["src_block_ids"][block_id_start:block_id_end]
dest_block_ids = task["dest_block_ids"][block_id_start:block_id_end]
src_block_ids = paddle.to_tensor(src_block_ids, dtype="int32", place="cpu")
dest_block_ids = paddle.to_tensor(dest_block_ids, dtype="int32", place="cpu")
logger.info(
f"start write cache for a layer, {req_id}, {layer_idx}, {target_ip}, {target_id}, block_id_start {block_id_start} block_id_end {block_id_end}"
)
tic = time.time()
return_code = self.messager[current_transfer_protocol].write_cache(
target_ip,
target_id,
src_block_ids,
dest_block_ids,
layer_idx,
)
if return_code != 0:
task["status"] = "write cache error"
logger.error(
f"write cache failed, layer_idx: {layer_idx}, req_id: {req_id}, dest_ip: {target_ip}, block_id_start {block_id_start} block_id_end {block_id_end}"
)
tok = time.time()
cost_time = tok - tic
block_num = len(src_block_ids)
avg_time_per_block = cost_time * 1000 / block_num # ms
send_cache_speed = block_num * self.block_bytes / 1073741824 / cost_time # GB/s
logger.debug(
f"finish write cache for a layer, {req_id}, {layer_idx}, {target_ip}, {target_id},"
f"block_num: {block_num}, send_cache_speed(GB/s): {round(send_cache_speed, 5)},"
f"avg_time per block(ms): {round(avg_time_per_block, 5)} block_id_start {block_id_start} block_id_end {block_id_end}"
)
task["sended_layer_id"] += 1
assert task["sended_layer_id"] == layer_idx
if task["sended_layer_id"] == self.num_layers - 1:
self.idx_cache_task_dict[engine_index]["sended_block_num"] += (
block_id_end - block_id_start
)
if current_prefilled_token_num_list[i] == task["need_prefill_tokens"]:
if task["status"] != "error":
task["status"] = "finished"
logger.info(
f"finish write cache for all layers, req_id: {req_id}, block_id_end {block_id_end} need_prefill_tokens {task['need_prefill_tokens']}"
)
else:
task["sended_layer_id"] = -1
if end_layer_idx == self.num_layers - 1:
with self.engine_cache_task_thread_lock:
for engine_idx in engine_indexes:
task = self.idx_cache_task_dict[engine_idx]
if task["status"] == "finished" or ("error" in task["status"]):
target_id = int(task["rdma_ports"][self.rank])
if task["transfer_protocol"] == "ipc":
self.messager["ipc"].write_block_by_sync(target_id)
if self.rank == 0:
# to do: robust in TP, here we assume all status in tp are the same. If wrong, all wrong. If ok, all ok.
self.engine_worker_queue.put_finished_req(
[(task["request_id"], task["status"])]
)
logger.info(f"put write cache {task['request_id']}, status {task['status']}")
self.engine_cache_tasks[task["current_id"]] = dict()
del self.cache_info[task["request_id"]]
del self.idx_cache_task_dict[task["current_id"]]
break
except Exception as e:
logger.error(f"prefill layerwise send cache thread has exception: {e} {traceback.format_exc()!s}")
time.sleep(0.01)
def consume_signals(self):
paddle.device.set_device("cpu")
kv_signal_data = paddle.full(shape=[512 * 3 + 2], fill_value=-1, dtype="int32")
while True:
try:
get_output_kv_signal(kv_signal_data, self.rank_id, 0) # wait_flag
if not self.cache_info:
time.sleep(0.01)
continue
tasks_count = kv_signal_data[0]
if tasks_count == -1:
time.sleep(0.001)
continue
layer_id = kv_signal_data[1].numpy().tolist()
if layer_id == self.num_layers - 1:
logger.info(f"tasks_count: {tasks_count}, layer_id: {layer_id}")
batch_engine_ids = []
with self.engine_cache_task_thread_lock:
for bi in range(tasks_count):
engine_idx = kv_signal_data[3 * bi + 2].numpy().tolist()
chuck_token_offset = kv_signal_data[3 * bi + 3].numpy().tolist()
current_seq_len = kv_signal_data[3 * bi + 4].numpy().tolist()
self.engine_cache_tasks[engine_idx]["prefilled_layer_idx"] = layer_id
self.engine_cache_tasks[engine_idx]["prefilled_token_num"] = (
chuck_token_offset + current_seq_len
)
batch_engine_ids.append(engine_idx)
if layer_id == 0:
self.cache_prefilled_engine_ids_queue.put(batch_engine_ids)
except Exception as e:
logger.error(f"Consume signals get exception: {e}")
def _handle_connect_task(self):
while True:
try:
task = self.engine_worker_queue.get_connect_rdma_task()
if task is None:
time.sleep(0.001)
continue
logger.info(f"_handle_connect_task recv task: {task}")
task_id = task["task_id"]
ip, rdma_port = task["ip"], task["rdma_port"]
status = self.messager["rdma"].connect(ip, rdma_port)
if not status:
response = {"task_id": task_id, "success": False}
else:
response = {"task_id": task_id, "success": True}
self.engine_worker_queue.put_connect_rdma_task_response(response)
except Exception as e:
logger.error(f"handle_connect_task has exception: {e}")
def main():
device = args.device_id
rank = args.rank
paddle.set_device(f"gpu:{device}")
cache_type = args.cache_dtype
speculative_config = SpeculativeConfig(args.speculative_config)
num_extra_layers = speculative_config.num_extra_cache_layer
num_extra_layer_gpu_blocks = int(args.num_gpu_blocks * speculative_config.num_gpu_block_expand_ratio)
gpu_cache_kvs = {}
gpu_cache_k_tensors = []
gpu_cache_v_tensors = []
for i in range(args.num_layers + num_extra_layers):
num_gpu_blocks = args.num_gpu_blocks if i < args.num_layers else num_extra_layer_gpu_blocks
gpu_cache_kvs[f"key_caches_{i}_rank{rank}_device{device}"] = paddle.full(
shape=[
num_gpu_blocks,
args.kv_num_head,
args.block_size,
args.head_dim,
],
fill_value=0,
dtype=cache_type,
)
gpu_cache_k_tensors.append(gpu_cache_kvs[f"key_caches_{i}_rank{rank}_device{device}"])
gpu_cache_kvs[f"value_caches_{i}_rank{rank}_device{device}"] = paddle.full(
shape=[
num_gpu_blocks,
args.kv_num_head,
args.block_size,
args.head_dim,
],
fill_value=0,
dtype=cache_type,
)
gpu_cache_v_tensors.append(gpu_cache_kvs[f"value_caches_{i}_rank{rank}_device{device}"])
set_data_ipc(
gpu_cache_kvs[f"key_caches_{i}_rank{rank}_device{device}"],
f"key_caches_{i}_rank{rank}.device{device}",
)
set_data_ipc(
gpu_cache_kvs[f"value_caches_{i}_rank{rank}_device{device}"],
f"value_caches_{i}_rank{rank}.device{device}",
)
cache_kv_size_byte = sum([tmp.numel() * 1 for key, tmp in gpu_cache_kvs.items()])
logger.info(f"device :{device}")
logger.info(f"cache_kv_size_byte : {cache_kv_size_byte}")
logger.info(f"done init cache (full) gmem alloc : {paddle.device.cuda.memory_allocated()}")
if envs.ENABLE_V1_KVCACHE_SCHEDULER:
cache_messager = CacheMessagerV1(
splitwise_role=args.splitwise_role,
transfer_protocol=args.protocol,
pod_ip=args.pod_ip,
engine_worker_queue_port=args.engine_worker_queue_port,
local_data_parallel_id=args.local_data_parallel_id,
gpu_cache_kvs=gpu_cache_kvs,
rank=rank,
nranks=args.mp_num,
num_layers=args.num_layers + num_extra_layers,
gpu_id=device,
rdma_port=args.rdma_port,
)
else:
cache_messager = CacheMessager(
splitwise_role=args.splitwise_role,
transfer_protocol=args.protocol,
pod_ip=args.pod_ip,
engine_worker_queue_port=args.engine_worker_queue_port,
local_data_parallel_id=args.local_data_parallel_id,
gpu_cache_kvs=gpu_cache_kvs,
rank=rank,
nranks=args.mp_num,
num_layers=args.num_layers + num_extra_layers,
gpu_id=device,
rdma_port=args.rdma_port,
)
cache_ready_signal_data = np.zeros(shape=[args.mp_num], dtype=np.int32)
cache_ready_signal = IPCSignal(
name="cache_ready_signal",
array=cache_ready_signal_data,
dtype=np.int32,
suffix=args.engine_pid,
create=False,
)
cache_ready_signal.value[rank] = 1
if args.splitwise_role == "mixed":
while True:
time.sleep(1)
cache_messager.prefill_layerwise_send_cache_thread()
if __name__ == "__main__":
args = parse_args()
rank_id = args.rank + args.local_data_parallel_id * args.mp_num
logger = get_logger("cache_messager", f"cache_messager_rank{rank_id}.log")
logger.info("create cache messager...")
logger.info(f"{args}")
main()

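The prefilled step index that the engine writes to shared memory wraps around at 100003, so CacheMessager remaps it into a monotonically increasing value (via max_step_idx and engine_recycled_count) before comparing it with previously seen steps. A standalone sketch of that remapping logic, assuming only the wrap constant taken from the diff:

MAX_STEP_IDX = 100003  # wrap-around constant used by the engine's step counter


class StepRemapper:
    """Turn a wrapping step counter into a monotonically increasing one."""

    def __init__(self):
        self.last_raw = -1       # last raw value read from shared memory
        self.recycled_count = 0  # how many times the counter has wrapped

    def remap(self, raw_step: int) -> int:
        if self.last_raw > raw_step:  # counter moved backwards, so it must have wrapped
            self.recycled_count += 1
        self.last_raw = raw_step
        return raw_step + MAX_STEP_IDX * self.recycled_count


remapper = StepRemapper()
assert remapper.remap(100000) == 100000
assert remapper.remap(5) == 5 + 100003  # ordering is preserved across the wrap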
View File

@@ -29,7 +29,7 @@ from fastdeploy.config import SpeculativeConfig
from fastdeploy.inter_communicator import EngineCacheQueue, IPCSignal
from fastdeploy.model_executor.ops.gpu import (
    cuda_host_alloc,
-   set_data_ipc,
+   share_external_data,
    swap_cache_all_layers,
)
from fastdeploy.utils import get_logger
@@ -139,40 +139,27 @@ class CacheTransferManager:
        self.num_cpu_blocks = args.num_cpu_blocks
        cache_type = args.cache_dtype
+       cache_shape = [
+           args.num_gpu_blocks,
+           args.kv_num_head,
+           args.block_size,
+           args.head_dim,
+       ]
        for i in range(args.num_layers + self.num_extra_layers):
            num_gpu_blocks = args.num_gpu_blocks if i < args.num_layers else self.num_extra_layer_gpu_blocks
+           cache_shape[0] = num_gpu_blocks
+           key_name = f"key_caches_{i}_rank{rank}.device{device}"
+           value_name = f"value_caches_{i}_rank{rank}.device{device}"
+           key_cache = paddle.empty(shape=[], dtype=cache_type)
+           value_cache = paddle.empty(shape=[], dtype=cache_type)
+           key_cache = share_external_data(key_cache, key_name, cache_shape)
+           value_cache = share_external_data(value_cache, value_name, cache_shape)
+           self.gpu_cache_kvs[key_name] = key_cache
+           self.gpu_cache_kvs[value_name] = value_cache
+           self.gpu_cache_k_tensors.append(self.gpu_cache_kvs[key_name])
+           self.gpu_cache_v_tensors.append(self.gpu_cache_kvs[value_name])
-           self.gpu_cache_kvs[f"key_caches_{i}_rank{rank}_device{device}"] = paddle.full(
-               shape=[
-                   num_gpu_blocks,
-                   args.kv_num_head,
-                   args.block_size,
-                   args.head_dim,
-               ],
-               fill_value=0,
-               dtype=cache_type,
-           )
-           self.gpu_cache_k_tensors.append(self.gpu_cache_kvs[f"key_caches_{i}_rank{rank}_device{device}"])
-           self.gpu_cache_kvs[f"value_caches_{i}_rank{rank}_device{device}"] = paddle.full(
-               shape=[
-                   num_gpu_blocks,
-                   args.kv_num_head,
-                   args.block_size,
-                   args.head_dim,
-               ],
-               fill_value=0,
-               dtype=cache_type,
-           )
-           self.gpu_cache_v_tensors.append(self.gpu_cache_kvs[f"value_caches_{i}_rank{rank}_device{device}"])
-           set_data_ipc(
-               self.gpu_cache_kvs[f"key_caches_{i}_rank{rank}_device{device}"],
-               f"key_caches_{i}_rank{rank}.device{device}",
-           )
-           set_data_ipc(
-               self.gpu_cache_kvs[f"value_caches_{i}_rank{rank}_device{device}"],
-               f"value_caches_{i}_rank{rank}.device{device}",
-           )
        cache_kv_size_byte = sum([tmp.numel() * 1 for key, tmp in self.gpu_cache_kvs.items()])
        logger.info(f"device :{self.device}")
        logger.info(f"cache_kv_size_byte : {cache_kv_size_byte}")
@@ -201,28 +188,6 @@ class CacheTransferManager:
        )
        self.cache_ready_signal.value[self.rank] = 1
-       paddle.set_device(f"gpu:{device}")
-       if args.enable_splitwise:
-           logger.debug("create cache messager...")
-           logger.info(f"{args}")
-           from fastdeploy.cache_manager.cache_messager import CacheMessager
-           self.cache_messager = CacheMessager(
-               splitwise_role=args.splitwise_role,
-               transfer_protocol=args.protocol,
-               pod_ip=args.pod_ip,
-               engine_worker_queue_port=args.engine_worker_queue_port,
-               local_data_parallel_id=args.local_data_parallel_id,
-               gpu_cache_kvs=self.gpu_cache_kvs,
-               rank=self.rank,
-               nranks=args.mp_num,
-               num_layers=args.num_layers + self.num_extra_layers,
-               gpu_id=self.device,
-               rdma_port=args.rdma_port,
-           )
-           logger.info("successfully create cache messager")
-           logger.info(f"done init CacheMessager gmem alloc : {paddle.device.cuda.memory_allocated()}")
        cache_task_broadcast_data = np.zeros(shape=[1], dtype=np.int32)
        self.cache_task_broadcast_signal = IPCSignal(
            name="cache_task_broadcast_signal",
@@ -443,5 +408,7 @@ def main():
if __name__ == "__main__": if __name__ == "__main__":
args = parse_args() args = parse_args()
logger = get_logger("cache_transfer_manager", "cache_transfer_manager.log") rank_id = args.rank + args.local_data_parallel_id * args.mp_num
logger = get_logger("cache_transfer_manager", f"cache_transfer_manager_rank{rank_id}.log")
paddle.set_device(f"gpu:{args.device_id}")
main() main()

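After this change the GPU KV-cache blocks are allocated once in cache_messager.py, which publishes them over IPC with set_data_ipc, while cache_transfer_manager.py only attaches to them through share_external_data. A rough sketch of that producer/consumer split, using the call shapes visible in the diff (shape values are illustrative only, and in the real system the two halves run in separate processes):

import paddle
from fastdeploy.model_executor.ops.gpu import set_data_ipc, share_external_data

cache_shape = [1024, 8, 64, 128]  # [num_blocks, kv_num_head, block_size, head_dim], illustrative values
ipc_name = "key_caches_0_rank0.device0"

# Producer side (cache_messager.py): owns the allocation and publishes it under an IPC name.
key_cache = paddle.full(shape=cache_shape, fill_value=0, dtype="bfloat16")
set_data_ipc(key_cache, ipc_name)

# Consumer side (cache_transfer_manager.py): attaches to the published buffer instead of allocating again.
attached = paddle.empty(shape=[], dtype="bfloat16")
attached = share_external_data(attached, ipc_name, cache_shape)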
View File

@@ -150,6 +150,19 @@ class PrefixCacheManager:
filename = "cache_transfer_manager.py" filename = "cache_transfer_manager.py"
py_path = os.path.join(current_dir_path, filename) py_path = os.path.join(current_dir_path, filename)
cache_messager_processes = []
cache_messager_processes = self.launch_cache_messager(
cache_config,
tensor_parallel_size,
device_ids,
pod_ip,
engine_worker_queue_port,
pid_suffix,
)
if cache_messager_processes is None:
raise RuntimeError("Launch cache messager failed")
return []
        if (
            hasattr(cache_config.model_cfg, "num_key_value_heads")
            and hasattr(cache_config.model_cfg, "num_key_value_heads")
@@ -213,7 +226,76 @@ class PrefixCacheManager:
        if cache_config.enable_hierarchical_cache and self.num_cpu_blocks > 0:
            logger.info("Enable hierarchical cache.")
            self._enable_cpu_cache()
-       return cache_manager_processes
+       all_cache_processes = cache_messager_processes + cache_manager_processes
+       return all_cache_processes
def launch_cache_messager(
self, cache_config, tensor_parallel_size, device_ids, pod_ip, engine_worker_queue_port, pid_suffix
):
"""
launch_cache_messager function used to initialize the cache messager.
"""
current_dir_path = os.path.split(os.path.abspath(__file__))[0]
filename = "cache_messager.py"
if (
hasattr(cache_config.model_cfg, "num_key_value_heads")
and hasattr(cache_config.model_cfg, "num_key_value_heads")
and cache_config.model_cfg.num_key_value_heads is not None
and int(cache_config.model_cfg.num_key_value_heads) > 0
):
kv_num_head = int(cache_config.model_cfg.num_key_value_heads) // tensor_parallel_size
else:
kv_num_head = cache_config.model_cfg.num_attention_heads // tensor_parallel_size
cache_ready_signal_data = np.zeros(shape=[tensor_parallel_size], dtype=np.int32)
self.cache_ready_signal = IPCSignal(
name="cache_ready_signal",
array=cache_ready_signal_data,
dtype=np.int32,
suffix=pid_suffix,
create=True,
)
py_path = os.path.join(current_dir_path, filename)
log_dir = envs.FD_LOG_DIR
cache_messager_processes = []
for i in range(tensor_parallel_size):
launch_cmd = (
"FLAGS_allocator_strategy=auto_growth CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7"
+ " NCCL_MAX_NCHANNELS=1 NCCL_BUFFSIZE=0"
+ f" {sys.executable} {py_path}"
+ f" --device_id {int(device_ids[i])}"
+ f" --rank {i}"
+ f" --splitwise_role {self.splitwise_role}"
+ f" --num_layers {cache_config.model_cfg.num_hidden_layers}"
+ f" --head_dim {cache_config.model_cfg.head_dim}"
+ f" --kv_num_head {kv_num_head}"
+ f" --mp_num {tensor_parallel_size}"
+ f" --cache_dtype {cache_config.cache_dtype}"
+ f" --pod_ip {pod_ip}"
+ f" --cache_queue_port {cache_config.cache_queue_port}"
+ f" --engine_worker_queue_port {engine_worker_queue_port}"
+ f" --num_gpu_blocks {cache_config.total_block_num}"
+ f" --block_size {cache_config.block_size}"
+ f" --protocol {cache_config.cache_transfer_protocol}"
+ f" --local_data_parallel_id {self.local_data_parallel_id}"
+ f" --engine_pid {pid_suffix}"
+ f" --rdma_port {cache_config.rdma_comm_ports[i] if cache_config.rdma_comm_ports is not None else '0'}"
+ f" --speculative_config '{self.speculative_config.to_json_string()}'"
+ f" >{log_dir}/launch_cache_messager_{int(device_ids[i])}.log 2>&1"
)
logger.info(f"Launch cache messager, command:{launch_cmd}")
cache_messager_processes.append(subprocess.Popen(launch_cmd, shell=True, preexec_fn=os.setsid))
logger.info("Waiting for cache ready...")
while np.sum(self.cache_ready_signal.value) != tensor_parallel_size:
time.sleep(1)
exit_code = cache_messager_processes[-1].poll()
if exit_code is None:
logger.info("Launch cache messager successful")
else:
logger.info("Launch cache messager failed, see launch_cache_messager.log for more information")
cache_messager_processes = None
return cache_messager_processes
    def update_cache_config(self, cache_config):
        """

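launch_cache_messager above starts one cache_messager.py process per tensor-parallel rank, blocks until every rank has flipped its slot in the shared cache_ready_signal array, and then checks that the launched process is still alive. A generic sketch of a similar wait-for-ready loop that also bails out if a child dies while waiting (plain stand-ins, not the FastDeploy IPCSignal class; in FastDeploy the flags live in shared memory updated by the child ranks):

import time
import numpy as np


def wait_until_ready(ready_flags: np.ndarray, procs: list, poll_s: float = 1.0) -> bool:
    """Return True once every rank reported ready, False if a launched process exited early."""
    while int(np.sum(ready_flags)) != len(ready_flags):
        for p in procs:
            if p.poll() is not None:  # child exited before signalling readiness
                return False
        time.sleep(poll_s)
    return True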
View File

@@ -61,18 +61,12 @@ class RDMACommManager:
        Connect to remote gpu and write cache.
        """
        assert self.splitwise_role == "prefill", "only prefill can call this method"
-       addr = f"{ip}:{port!s}"
-       if addr in self.connected_rdma:
-           return True
        ret = self.messager.is_connected(ip, str(port))
        if ret:
-           self.connected_rdma.add(addr)
            return True

        ret = self.messager.connect(ip, str(port))
        logger.info(f"connect to remote rdma address {ip}:{port} status is {ret}")
-       if ret == 0:
-           self.connected_rdma.add(addr)
        return ret == 0

    def write_cache(self, ip, port, local_block_ids, remote_block_ids, layer_idx):

View File

@@ -1481,7 +1481,7 @@ class FDConfig:
            self.model_config.model_format = "torch"
        # TODO
-       self.max_prefill_batch = 3
+       self.max_prefill_batch = int(os.getenv("MAX_PREFILL_NUM", "3"))
        if current_platform.is_xpu():
            self.max_prefill_batch = 1
        if self.model_config is not None and self.model_config.enable_mm:

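max_prefill_batch is now read from the MAX_PREFILL_NUM environment variable, keeping the old hard-coded 3 as the default (and still forced to 1 on XPU). A two-line illustration of the override, with the value chosen here purely as an example:

import os

os.environ["MAX_PREFILL_NUM"] = "8"  # hypothetical override; must be set before FDConfig is constructed
max_prefill_batch = int(os.getenv("MAX_PREFILL_NUM", "3"))  # -> 8 here, 3 when the variable is unset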
View File

@@ -422,7 +422,7 @@ class EngineArgs:
raise NotImplementedError("Only CUDA platform supports logprob.") raise NotImplementedError("Only CUDA platform supports logprob.")
if self.speculative_config is not None: if self.speculative_config is not None:
envs.ENABLE_V1_KVCACHE_SCHEDULER = 0 envs.ENABLE_V1_KVCACHE_SCHEDULER = 0
if self.splitwise_role != "mixed": if self.splitwise_role != "mixed" and self.cache_transfer_protocol != "rdma":
envs.ENABLE_V1_KVCACHE_SCHEDULER = 0 envs.ENABLE_V1_KVCACHE_SCHEDULER = 0
if not current_platform.is_cuda(): if not current_platform.is_cuda():
envs.ENABLE_V1_KVCACHE_SCHEDULER = 0 envs.ENABLE_V1_KVCACHE_SCHEDULER = 0

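With this change a splitwise (prefill/decode) deployment keeps the v1 KV-cache scheduler as long as the cache transfer protocol is rdma; it only falls back to v0 for non-RDMA splitwise setups, speculative decoding, or non-CUDA platforms. A minimal sketch of the resulting rule (a free function with names mirroring the EngineArgs fields, not the actual method):

def v1_kvcache_scheduler_enabled(speculative_config, splitwise_role, cache_transfer_protocol, is_cuda):
    if speculative_config is not None:
        return False  # speculative decoding still disables the v1 scheduler
    if splitwise_role != "mixed" and cache_transfer_protocol != "rdma":
        return False  # splitwise without RDMA falls back to v0
    if not is_cuda:
        return False
    return True


assert v1_kvcache_scheduler_enabled(None, "prefill", "rdma", True) is True   # PD over RDMA keeps the v1 scheduler
assert v1_kvcache_scheduler_enabled(None, "prefill", "ipc", True) is False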
View File

@@ -46,7 +46,7 @@ from fastdeploy.model_executor.guided_decoding import schema_checker
from fastdeploy.plugins.token_processor import load_token_processor_plugins
from fastdeploy.splitwise.internal_adapter_utils import InternalAdapter
from fastdeploy.splitwise.splitwise_connector import SplitwiseConnector
-from fastdeploy.utils import EngineError, envs, llm_logger
+from fastdeploy.utils import EngineError, envs, get_logger, llm_logger

try:
    TokenProcessor = load_token_processor_plugins()
@@ -69,6 +69,13 @@ class EngineService:
""" """
self.cfg = cfg self.cfg = cfg
if self.cfg.parallel_config.enable_expert_parallel:
self.llm_logger = get_logger(
"fastdeploy", f"fastdeploy_rank{self.cfg.parallel_config.local_data_parallel_id}.log"
)
else:
self.llm_logger = llm_logger
self.scheduler = cfg.scheduler_config.scheduler() self.scheduler = cfg.scheduler_config.scheduler()
if envs.ENABLE_V1_KVCACHE_SCHEDULER: if envs.ENABLE_V1_KVCACHE_SCHEDULER:
@@ -79,10 +86,6 @@ class EngineService:
                cfg.scheduler_config.splitwise_role,
                cfg.parallel_config.local_data_parallel_id,
            )
-           if cfg.scheduler_config.splitwise_role != "mixed":
-               raise NotImplementedError(
-                   "Currently ENABLE_V1_KVCACHE_SCHEDULER=1 only supported in mixed sampling now."
-               )
        else:
            self.resource_manager = ResourceManager(
                cfg.scheduler_config.max_num_seqs,
@@ -135,12 +138,14 @@ class EngineService:
        self.insert_task_to_worker_thread.start()
        self.token_processor.tasks_queue = self.engine_worker_queue
        self.token_processor.run()
+       if self.cfg.scheduler_config.splitwise_role != "mixed":
+           self.split_mode_get_tasks()

    def _init_worker_monitor_signals(self):  # exist_task_signal lets each worker process detect whether new tasks need handling
        current_suffix = int(
            self.cfg.parallel_config.engine_worker_queue_port[self.cfg.parallel_config.local_data_parallel_id]
        )
-       llm_logger.info(f"current_suffix: {current_suffix}")
+       self.llm_logger.info(f"current_suffix: {current_suffix}")
        exist_task_signal_data = np.zeros([1], dtype=np.int32)
        self.exist_task_signal = IPCSignal(
            name="exist_task_signal",
@@ -201,7 +206,7 @@ class EngineService:
        )
        if start_queue and (self.cfg.host_ip == self.cfg.master_ip or self.cfg.master_ip == "0.0.0.0"):
-           llm_logger.info(f"Starting engine worker queue server service at {address}")
+           self.llm_logger.info(f"Starting engine worker queue server service at {address}")
            self.engine_worker_queue_server = EngineWorkerQueue(
                address=address,
                is_server=True,
@@ -225,7 +230,7 @@ class EngineService:
                client_id=-1,
                local_data_parallel_size=self.cfg.parallel_config.data_parallel_size,
            )
-       llm_logger.info(
+       self.llm_logger.info(
            f"local {min(self.cfg.worker_num_per_node * self.cfg.node_rank + self.cfg.parallel_config.local_data_parallel_id,self.cfg.parallel_config.data_parallel_size - 1)}"
        )
        self.engine_worker_queue = EngineWorkerQueue(
@@ -254,7 +259,17 @@ class EngineService:
                    cur_task_idx = self.resource_manager.req_dict[task.request_id]
                    del self.resource_manager.req_dict[task.request_id]
                    cur_task = self.resource_manager.tasks_list[cur_task_idx]
+                   if envs.FD_ENABLE_INTERNAL_ADAPTER:
+                       if not task.outputs.token_ids:  # first token is eos in Prefill, just recycle resource and continue
+                           self.resource_manager.stop_flags[cur_task_idx] = True
+                           self.resource_manager.tasks_list[cur_task_idx] = None
+                           self.resource_manager._recycle_block_tables(cur_task)
+                           if task.request_id in self.token_processor.tokens_counter:
+                               del self.token_processor.tokens_counter[task.request_id]
+                           self.llm_logger.warning(f"{task.request_id} need not decode after first token")
+                           continue
                    cur_task.prompt_token_ids[0] = task.outputs.token_ids[0]
+                   cur_task.num_cached_tokens = task.num_cached_tokens
                    if (
                        self.cfg.speculative_config.method in ["mtp"]
                        and self.cfg.scheduler_config.splitwise_role == "decode"
@@ -267,13 +282,14 @@ class EngineService:
                        if task.request_id in self.token_processor.tokens_counter:
                            del self.token_processor.tokens_counter[task.request_id]
                        self.scheduler.put_results([task])
-                       llm_logger.warning(
+                       self.llm_logger.warning(
                            f"{task.request_id} prefill failed with msg:{task.error_msg}, recycle resource."
                        )
                        continue
                    self.token_processor.tokens_counter[task.request_id] = 1
                    current_tasks.append(cur_task)
-               self.engine_worker_queue.put_tasks((current_tasks, self.resource_manager.real_bsz))
+               if current_tasks:
+                   self.engine_worker_queue.put_tasks((current_tasks, self.resource_manager.real_bsz))
                return True

        self.resource_manager.check_and_free_block_tables()
@@ -281,13 +297,34 @@ class EngineService:
        if not isinstance(tasks, list):
            tasks = [tasks]
need_delete_tasks = []
for task in tasks:
if self.cfg.scheduler_config.splitwise_role != "mixed":
status, msg = self.split_connector.check_decode_allocated(task)
if not status:
self.llm_logger.error(f"{task.request_id} prefill failed with msg:{msg}.")
self.scheduler.put_results(
[
RequestOutput(
request_id=task.request_id,
finished=True,
error_code=500,
error_msg=msg,
)
]
)
need_delete_tasks.append(task)
continue
for tmp_task in need_delete_tasks:
tasks.remove(tmp_task)
        for item in tasks:
            item.schedule_start_time = time.time()

        available_batch = np.sum(self.resource_manager.stop_flags)
        if len(tasks) > available_batch:
-           llm_logger.error(f"Inserting batch:{len(tasks)} exceeds the available batch:{available_batch}.")
-           llm_logger.error("The exceeded part will be ignored!")
+           self.llm_logger.error(f"Inserting batch:{len(tasks)} exceeds the available batch:{available_batch}.")
+           self.llm_logger.error("The exceeded part will be ignored!")
            tasks = tasks[:available_batch]

        req_ids = [t.request_id for t in tasks]
@@ -296,7 +333,7 @@ class EngineService:
        if not tasks:
            error_msg = f"The request required resources is exceed the limit, request id={req_ids}."
-           llm_logger.error(error_msg)
+           self.llm_logger.error(error_msg)
            raise EngineError(error_msg, error_code=500)
            return False
@@ -314,7 +351,7 @@ class EngineService:
            self.split_connector.send_cache_infos(tasks, current_id)
        if not is_decode:
-           llm_logger.info(f"Tasks are sent to engine, req_ids={req_ids}")
+           self.llm_logger.info(f"Tasks are sent to engine, req_ids={req_ids}")
            for task in tasks:
                task.inference_start_time = time.time()
            if not is_prefill:
@@ -473,7 +510,7 @@ class EngineService:
        Insert task to engine thread, monitor scheduler request queue.
        if the engine has resource, insert task to engine
        """
-       current_id = -1
+       current_id = 0
        while getattr(self, "running", True):
            try:
                if self.resource_manager.available_batch() == 0:
@@ -514,18 +551,21 @@ class EngineService:
                        time.sleep(0.001)
                        continue
-                   current_id = (current_id + 1) % 100003
                    if self.cfg.scheduler_config.splitwise_role != "mixed":
-                       llm_logger.info("Inserting splitwise tasks")
+                       self.llm_logger.info("Inserting splitwise tasks")
                        self.split_connector.send_splitwise_tasks(tasks, current_id)

-                   self.insert_tasks(tasks, current_id)
+                   insert_successful = self.insert_tasks(tasks, current_id)
+                   if insert_successful:
+                       current_id = current_id + 1
+                   else:
+                       continue

                    main_process_metrics.num_requests_waiting.dec(len(tasks))
                    main_process_metrics.num_requests_running.inc(len(tasks))

            except Exception as e:
-               err_msg = f"Error happened while insert task to engine: {e}, {traceback.format_exc()!s}."
+               err_msg = f"Error happend while insert task to engine: {e}, {traceback.format_exc()!s}."
-               llm_logger.error(err_msg)
+               self.llm_logger.error(err_msg)

    def _scheduler_task_to_worker_v1(self):
        """
@@ -535,40 +575,100 @@ class EngineService:
        is_fetching = False

        def _fetch_request():
-           nonlocal is_fetching
-           is_fetching = True
-           num_prefill_batch = min(
-               int(self.resource_manager.available_batch()),
-               self.cfg.max_prefill_batch,
-           )
-           if self.cfg.model_config.enable_mm:
-               available_blocks = self.resource_manager.available_block_num()
-           else:
-               available_blocks = self.cfg.cache_config.max_block_num_per_seq
-           tasks = self.scheduler.get_requests(
-               available_blocks=available_blocks,
-               block_size=self.cfg.cache_config.block_size,
-               reserved_output_blocks=self.cfg.cache_config.enc_dec_block_num,
-               max_num_batched_tokens=self.cfg.max_model_len,
-               batch=num_prefill_batch,
-           )
-           # Fetch requests and add them to the scheduling queue
-           for task in tasks:
-               self.resource_manager.add_request(task)
-           is_fetching = False
+           try:
+               nonlocal is_fetching
+               is_fetching = True
+               num_prefill_batch = min(
+                   int(self.resource_manager.available_batch()),
+                   self.cfg.max_prefill_batch,
+               )
+               if self.cfg.model_config.enable_mm:
+                   available_blocks = self.resource_manager.available_block_num()
+               else:
+                   available_blocks = self.cfg.cache_config.max_block_num_per_seq
+               tasks = self.scheduler.get_requests(
+                   available_blocks=available_blocks,
+                   block_size=self.cfg.cache_config.block_size,
+                   reserved_output_blocks=self.cfg.cache_config.enc_dec_block_num,
+                   max_num_batched_tokens=self.cfg.max_model_len,
+                   batch=num_prefill_batch,
+               )
+               if self.cfg.scheduler_config.splitwise_role != "mixed":
+                   for task in tasks:
+                       # assure can allocate block ids in P
+                       while not self.resource_manager.preallocate_resource_in_p(task):
time.sleep(0.005)
self.llm_logger.info(f"ask D resource for req_id: {task.request_id}")
self.split_connector.send_splitwise_tasks([task], task.idx)
need_delete_tasks = []
for task in tasks:
if self.cfg.scheduler_config.splitwise_role != "mixed":
# assure fetch block ids from D
status, msg = self.split_connector.check_decode_allocated(task)
if not status:
self.llm_logger.error(f"{task.request_id} prefill failed with msg:{msg}.")
self.scheduler.put_results(
[
RequestOutput(
request_id=task.request_id,
finished=True,
error_code=500,
error_msg=msg,
)
]
)
need_delete_tasks.append(task)
continue
for tmp_task in need_delete_tasks:
tasks.remove(tmp_task)
# release resource in P
self.resource_manager.prerelease_resource(task)
if self.cfg.scheduler_config.splitwise_role == "prefill":
# to send cache info to cache messager
if tasks:
self.split_connector.send_cache_infos(tasks, 0)
# ensure cache tasks has sent to cache_messager
need_check_req_ids = [task.request_id for task in tasks]
while need_check_req_ids:
req_ids = self.engine_worker_queue.get_finished_add_cache_task_req()
self.llm_logger.info(f"get_finished_add_cache_task_req: {req_ids}")
if req_ids:
for req_id in req_ids:
assert req_id in need_check_req_ids
need_check_req_ids.remove(req_id)
else:
time.sleep(0.001)
# Fetch requests and add them to the scheduling queue
if tasks:
if self.cfg.scheduler_config.splitwise_role == "prefill":
self.resource_manager.add_request_in_p(tasks)
else:
for task in tasks:
self.resource_manager.add_request(task)
is_fetching = False
except Exception as e:
self.llm_logger.error(f"fetching request error {e} {str(traceback.format_exc())}")
is_fetching = False
        while self.running:
            try:
                if self.engine_worker_queue.num_tasks() > 0:
                    time.sleep(0.001)
                    continue
-               if (
-                   len(self.resource_manager.waiting) == 0
-                   and (not is_fetching)
-                   and self.exist_prefill_task_signal.value[0] == 0
-               ):
-                   get_request_pool.submit(_fetch_request)
+               if self.cfg.scheduler_config.splitwise_role != "mixed":
+                   if self.scheduler.get_unhandled_request_num() <= envs.FD_EP_MAX_PREFETCH_TASK_NUM and (
+                       not is_fetching
+                   ):
+                       get_request_pool.submit(_fetch_request)
+               else:
+                   if (
+                       len(self.resource_manager.waiting) == 0
+                       and (not is_fetching)
+                       and self.exist_prefill_task_signal.value[0] == 0
+                   ):
+                       get_request_pool.submit(_fetch_request)

                # 2. Schedule requests
                tasks = self.resource_manager.schedule()

                # 3. Send to engine
@@ -579,8 +679,8 @@ class EngineService:
                    time.sleep(0.005)
            except Exception as e:
-               err_msg = "Error happened while insert task to engine: {}, {}.".format(e, str(traceback.format_exc()))
+               err_msg = "Error happend while insert task to engine: {}, {}.".format(e, str(traceback.format_exc()))
-               llm_logger.error(err_msg)
+               self.llm_logger.error(err_msg)

    def start_zmq_service(self, api_server_pid=None):
        if api_server_pid is None:
@@ -608,6 +708,9 @@ class EngineService:
    def _insert_zmq_task_to_scheduler(self):
        added_requests: Dict[str, int] = dict()
+       if envs.FD_ENABLE_INTERNAL_ADAPTER:
+           if self.cfg.scheduler_config.splitwise_role == "decode":
+               return
        while self.running:
            try:
                block = True if len(added_requests) == 0 else False
@@ -616,7 +719,7 @@ class EngineService:
                else:
                    err, data = self.recv_request_server.receive_pyobj_once(block)
                if err is not None:
-                   llm_logger.error(f"Engine stops inserting zmq task into scheduler, err:{err}")
+                   self.llm_logger.error(f"Engine stops inserting zmq task into scheduler, err:{err}")
                    break

                request, insert_task = None, []
@@ -627,16 +730,16 @@ class EngineService:
                        request = Request.from_dict(data)
                        start_span("ENQUEUE_ZMQ", data, trace.SpanKind.PRODUCER)
                        main_process_metrics.requests_number.inc()
-                       llm_logger.debug(f"Receive request: {request}")
+                       self.llm_logger.debug(f"Receive request: {request}")
                    except Exception as e:
-                       llm_logger.error(f"Receive request error: {e}, {traceback.format_exc()!s}")
+                       self.llm_logger.error(f"Receive request error: {e}, {traceback.format_exc()!s}")
                        err_msg = str(e)
                        results.append((data["request_id"], err_msg))

                    if self.guided_decoding_checker is not None and err_msg is None:
                        request, err_msg = self.guided_decoding_checker.schema_format(request)
                        if err_msg is not None:
-                           llm_logger.error(f"Receive request error: {err_msg}")
+                           self.llm_logger.error(f"Receive request error: {err_msg}")
                            results.append((request.request_id, err_msg))

                    if err_msg is None:
@@ -670,7 +773,7 @@ class EngineService:
                        # Send result by zmq directly
                        self.send_response_server.send_response(request_id, [error_result])
            except Exception as e:
-               llm_logger.error(
+               self.llm_logger.error(
                    f"Error happened while receiving new request from zmq, details={e}, "
                    f"traceback={traceback.format_exc()}"
                )
@@ -689,7 +792,7 @@ class EngineService:
                    self.send_response_server.send_response(request_id, contents)
            except Exception as e:
-               llm_logger.error(f"Unexpected error happened: {e}, {traceback.format_exc()!s}")
+               self.llm_logger.error(f"Unexpected error happened: {e}, {traceback.format_exc()!s}")
    def split_mode_get_tasks(self):
        """
@@ -702,13 +805,22 @@ class EngineService:
                    processed_indices = []
                    for idx, task in enumerate(self.waiting_requests):
-                       if self.resource_manager.is_resource_sufficient(task.prompt_token_ids_len):
-                           self.insert_tasks([task])
-                           llm_logger.info(f"Resource available, processing task {task.request_id}")
-                           processed_indices.append(idx)
+                       if envs.ENABLE_V1_KVCACHE_SCHEDULER:
+                           if self.resource_manager.preallocate_resource_in_d(task):
+                               self.llm_logger.info(f"Resource available, processing task {task.request_id}")
+                               self.split_connector.send_cache_infos([task], -1)
+                               processed_indices.append(idx)
+                           else:
+                               self.llm_logger.debug(f"Still waiting for resources {task.request_id}")
+                               break
                        else:
-                           llm_logger.debug(f"Still waiting for resources {task.request_id}")
-                           break
+                           if self.resource_manager.is_resource_sufficient(task.prompt_token_ids_len):
+                               self.insert_tasks([task])
+                               self.llm_logger.info(f"Resource available, processing task {task.request_id}")
+                               processed_indices.append(idx)
+                           else:
+                               self.llm_logger.debug(f"Still waiting for resources {task.request_id}")
+                               break

                    for idx in sorted(processed_indices, reverse=True):
                        self.waiting_requests.pop(idx)
@@ -730,32 +842,79 @@ class EngineService:
                                    tasks = [tasks]

                                for task in tasks:
                                    task.finished = False
-                               self.insert_tasks(tasks, allocated=True)
-                               if self.cfg.innode_prefill_ports is not None:
-                                   self.scheduler.put_results(tasks)
+                               if envs.ENABLE_V1_KVCACHE_SCHEDULER:
+                                   for task in tasks:
+                                       if envs.FD_ENABLE_INTERNAL_ADAPTER:
+                                           if (
+                                               not task.outputs.token_ids
+                                           ):  # first token is eos in Prefill, just recycle resource and continue
+                                               cur_task = self.resource_manager.requests[task.request_id]
+                                               self.resource_manager.stop_flags[cur_task.idx] = True
+                                               self.resource_manager.tasks_list[cur_task.idx] = None
+                                               self.resource_manager._free_blocks(cur_task)
+                                               if cur_task.request_id in self.token_processor.tokens_counter:
+                                                   del self.token_processor.tokens_counter[task.request_id]
+                                               self.llm_logger.warning(
+                                                   f"{task.request_id} need not decode after first token"
+                                               )
+                                               del self.resource_manager.requests[task.request_id]
+                                               del self.resource_manager.req_dict[task.request_id]
+                                               continue
+                                           if task.error_code != 200:
+                                               cur_task = self.resource_manager.requests[task.request_id]
+                                               self.resource_manager.stop_flags[cur_task.idx] = True
+                                               self.resource_manager.tasks_list[cur_task.idx] = None
+                                               self.resource_manager._free_blocks(cur_task)
+                                               if cur_task.request_id in self.token_processor.tokens_counter:
+                                                   del self.token_processor.tokens_counter[task.request_id]
+                                               self.scheduler.put_results([task])
+                                               self.llm_logger.warning(
+                                                   f"{task.request_id} prefill failed with msg:{task.error_msg}, recycle resource."
+                                               )
+                                               continue
+                                       self.resource_manager.insert_task_for_decoding(task)
+                               else:
+                                   self.insert_tasks(tasks, allocated=True)
+                                   if self.cfg.innode_prefill_ports is not None:
+                                       self.scheduler.put_results(tasks)
                            else:
                                if len(self.waiting_requests):
-                                   llm_logger.info(f"Waiting for resource for task {tasks[0].request_id}")
+                                   self.llm_logger.info(f"Waiting for resource for task {tasks[0].request_id}")
                                    self.waiting_requests.extend(tasks)
                                else:
                                    new_waiting = []
                                    for task in tasks:
-                                       if self.resource_manager.is_resource_sufficient(task.prompt_token_ids_len):
-                                           self.insert_tasks([task])
+                                       can_allocate_resource = False
+                                       if envs.ENABLE_V1_KVCACHE_SCHEDULER:
+                                           if self.resource_manager.preallocate_resource_in_d(task):
+                                               self.split_connector.send_cache_infos([task], -1)
+                                               can_allocate_resource = True
                                        else:
+                                           if self.resource_manager.is_resource_sufficient(
+                                               task.prompt_token_ids_len
+                                           ):
+                                               self.insert_tasks([task])
+                                               can_allocate_resource = True
+                                       if can_allocate_resource is False:
+                                           if not self.enable_decode_cache_task:
+                                               task.error_msg = "Not enough resources"
                                            new_waiting.append(task)

                                    if new_waiting:
-                                       self.waiting_requests.extend(new_waiting)
-                                       llm_logger.info(f"Added {len(new_waiting)} tasks to waiting queue")
+                                       if not self.enable_decode_cache_task:
+                                           self.split_connector.send_cache_infos(new_waiting, -1)
+                                       else:
+                                           self.waiting_requests.extend(new_waiting)
+                                           self.llm_logger.info(
+                                               f"Added {len(new_waiting)} tasks to waiting queue"
+                                           )

                    else:
                        time.sleep(0.001)
            except Exception as e:
-               llm_logger.error(f"Error in main loop: {e}")
+               self.llm_logger.error(f"Error in main loop: {e}")
                time.sleep(0.1)

        threading.Thread(target=receiver_loop, daemon=True).start()
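Under ENABLE_V1_KVCACHE_SCHEDULER the decode-side branch above boils down to: preallocate now and notify the cache messager, or park (or bounce) the task. A minimal standalone sketch of that control flow, where `preallocate` and `send_cache_infos` are stand-ins for the ResourceManagerV1 and split-connector calls, not the real FastDeploy APIs:

    from collections import deque

    def dispatch_decode_tasks(tasks, preallocate, send_cache_infos, waiting: deque,
                              enable_decode_cache_task: bool):
        """Mirror of the V1 decode branch: allocate now or park/bounce the task."""
        rejected = []
        for task in tasks:
            if preallocate(task):              # resources reserved on the decode side
                send_cache_infos([task], -1)   # tell P where to write the KV cache
            else:
                if not enable_decode_cache_task:
                    task["error_msg"] = "Not enough resources"
                rejected.append(task)
        if rejected:
            if not enable_decode_cache_task:
                send_cache_infos(rejected, -1)  # bounce back to P immediately
            else:
                waiting.extend(rejected)        # retry once resources free up
        return waiting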


@@ -120,11 +120,10 @@ class LLMEngine:
        self.data_processor = self.input_processor.create_processor()
        self.engine.data_processor = self.data_processor
+       # Launch components: scheduler, cache_manager, expert_service et.al.
+       self.launch_components()
        self.engine.start()
-       if api_server_pid is not None:
-           llm_logger.info(f"Start zmq server, api_server_pid: {api_server_pid}")
-           self.engine.start_zmq_service(api_server_pid)

        if self.do_profile == 0 and (
            self.cfg.cache_config.enable_prefix_caching or self.cfg.scheduler_config.splitwise_role != "mixed"
@@ -159,11 +158,14 @@ class LLMEngine:
        if self.do_profile:
            self._stop_profile()
-       # Launch components: scheduler, cache_manager, expert_service et.al.
-       self.launch_components()
        if self.cfg.cache_config.enable_prefix_caching or self.cfg.scheduler_config.splitwise_role != "mixed":
            self.launched_cache_manager_signal.value[0] = 1
+       if api_server_pid is not None:
+           llm_logger.info(f"Start zmq server, api_server_pid: {api_server_pid}")
+           self.engine.start_zmq_service(api_server_pid)

        # Worker launched
        self.check_worker_initialize_status_func_thread.join()
        if not result_container["worker_is_alive"]:
@@ -427,7 +429,10 @@ class LLMEngine:
        )

        if self.cfg.scheduler_config.splitwise_role != "mixed":
-           variables["FLAGS_use_pd_disaggregation"] = 1
+           if envs.ENABLE_V1_KVCACHE_SCHEDULER:
+               variables["FLAGS_use_pd_disaggregation_per_chunk"] = 1
+           else:
+               variables["FLAGS_use_pd_disaggregation"] = 1
            # TODO dynamic load environment variable
            if self.cfg.scheduler_config.splitwise_role == "prefill":
                variables["FLAGS_fmt_write_cache_completed_signal"] = 1
@@ -498,6 +503,7 @@ class LLMEngine:
f" --load_choices {self.cfg.load_config.load_choices}" f" --load_choices {self.cfg.load_config.load_choices}"
f" --moba_attention_config '{self.cfg.moba_attention_config.to_json_string()}'" f" --moba_attention_config '{self.cfg.moba_attention_config.to_json_string()}'"
f" --ips {ips}" f" --ips {ips}"
f" --cache-transfer-protocol {self.cfg.cache_config.cache_transfer_protocol}"
f" --runner {self.cfg.model_config.runner}" f" --runner {self.cfg.model_config.runner}"
f" --convert {self.cfg.model_config.convert}" f" --convert {self.cfg.model_config.convert}"
f" --override-pooler-config {self.cfg.model_config.override_pooler_config}" f" --override-pooler-config {self.cfg.model_config.override_pooler_config}"
@@ -625,13 +631,11 @@ class LLMEngine:
if self.cfg.scheduler_config.splitwise_role != "mixed": if self.cfg.scheduler_config.splitwise_role != "mixed":
# 单机逻辑 # 单机逻辑
self.engine.engine_worker_queue.available_prefill_instances.put(1) self.engine.engine_worker_queue.available_prefill_instances.put(1)
self.engine.split_mode_get_tasks() self.splitwise_receive_thread = threading.Thread(
if self.cfg.scheduler_config.name == "splitwise": target=self.engine.split_connector.start_receiver, args=()
self.splitwise_receive_thread = threading.Thread( )
target=self.engine.split_connector.start_receiver, args=() self.splitwise_receive_thread.daemon = True
) self.splitwise_receive_thread.start()
self.splitwise_receive_thread.daemon = True
self.splitwise_receive_thread.start()
self.cfg.init_cache_info() self.cfg.init_cache_info()
@@ -640,6 +644,14 @@ class LLMEngine:
        disaggregate = self.cfg.disaggregate_info
        if self.cfg.scheduler_config.name == "splitwise":
            self.engine.scheduler.start(role, host_ip, disaggregate)
+       elif self.cfg.scheduler_config.name == "dp":
+           request_queues_for_dp_ipc = []
+           result_queue_for_dp_ipc = multiprocessing.Queue()
+           for i in range(self.cfg.parallel_config.data_parallel_size):
+               request_queues_for_dp_ipc.append(multiprocessing.Queue())
+           self.engine.scheduler.start(
+               self.cfg.node_rank * self.cfg.worker_num_per_node, request_queues_for_dp_ipc, result_queue_for_dp_ipc
+           )

        if not envs.FD_ENABLE_MULTI_API_SERVER:
            if self.cfg.parallel_config.enable_expert_parallel and self.cfg.parallel_config.data_parallel_size > 1:
@@ -669,6 +681,9 @@ class LLMEngine:
                        args=(
                            self.cfg,
                            i,
+                           None,
+                           request_queues_for_dp_ipc,
+                           result_queue_for_dp_ipc,
                        ),
                    )
                )
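For the new "dp" scheduler the parent process creates one request queue per data-parallel rank plus a single shared result queue, then hands them to every expert-service process. A minimal standalone sketch of that wiring (the queued payloads here are placeholders, not real Request objects):

    import multiprocessing as mp

    def build_dp_ipc(data_parallel_size: int):
        # one inbound queue per DP rank, one shared outbound queue
        request_queues = [mp.Queue() for _ in range(data_parallel_size)]
        result_queue = mp.Queue()
        return request_queues, result_queue

    if __name__ == "__main__":
        request_queues, result_queue = build_dp_ipc(4)
        request_queues[2].put({"request_id": "req-0", "dp_rank": 2})  # route to rank 2
        print(request_queues[2].get())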


@@ -27,6 +27,7 @@ import numpy as np
from fastdeploy.engine.common_engine import EngineService
from fastdeploy.inter_communicator import IPCSignal
+from fastdeploy.splitwise.internal_adapter_utils import InternalAdapter
from fastdeploy.utils import console_logger, envs, llm_logger
@@ -69,8 +70,12 @@ class ExpertService:
            self.engine.scheduler.reset_nodeid(f"{self.engine.scheduler.infer.nodeid}_{local_data_parallel_id!s}")

        self._finalizer = weakref.finalize(self, self._exit_sub_services)
+       if envs.FD_ENABLE_INTERNAL_ADAPTER:
+           self.internal_adapter = InternalAdapter(cfg=self.cfg, engine=self.engine, dp_rank=local_data_parallel_id)

-   def start(self, ipc_signal_suffix, local_data_parallel_id):
+   def start(
+       self, ipc_signal_suffix, local_data_parallel_id, request_queues_for_dp_ipc=None, result_queue_for_dp_ipc=None
+   ):
        """
        Initializes the engine and starts its sub-services.
        If `api_server_pid` is defined, will launch a thread
@@ -80,6 +85,11 @@ class ExpertService:
        start_time = time.time()

        self.engine.start()
+       if self.cfg.scheduler_config.name == "dp":
+           self.cfg.init_cache_info()
+           assert (request_queues_for_dp_ipc is not None) and (result_queue_for_dp_ipc is not None)
+           self.engine.scheduler.start(local_data_parallel_id, request_queues_for_dp_ipc, result_queue_for_dp_ipc)
        if ipc_signal_suffix is not None:
            self.api_server_pid = ipc_signal_suffix
            self.engine.start_zmq_service(ipc_signal_suffix)
@@ -88,8 +98,8 @@ class ExpertService:
llm_logger.info(f"start expert service {local_data_parallel_id}") llm_logger.info(f"start expert service {local_data_parallel_id}")
if self.cfg.scheduler_config.splitwise_role != "mixed": if self.cfg.scheduler_config.splitwise_role != "mixed":
self.engine.start_cache_service(self.cfg.local_device_ids, ipc_signal_suffix) ipc_signal_suffix_cache = self.cfg.parallel_config.engine_worker_queue_port[local_data_parallel_id]
self.engine.split_mode_get_tasks() self.engine.start_cache_service(self.cfg.local_device_ids, ipc_signal_suffix_cache)
if self.cfg.scheduler_config.name == "splitwise": if self.cfg.scheduler_config.name == "splitwise":
self.cfg.init_cache_info() self.cfg.init_cache_info()
@@ -144,14 +154,18 @@ class ExpertService:
            self.zmq_server.close()


-def start_data_parallel_service(cfg, local_data_parallel_id, ipc_signal_suffix=None):
+def start_data_parallel_service(
+   cfg, local_data_parallel_id, ipc_signal_suffix=None, request_queues_for_dp_ipc=None, result_queue_for_dp_ipc=None
+):
    """
    Start expert service
    """
    expert_service = ExpertService(cfg, local_data_parallel_id, start_queue=False)
    try:
-       expert_service.start(ipc_signal_suffix, local_data_parallel_id)
+       expert_service.start(
+           ipc_signal_suffix, local_data_parallel_id, request_queues_for_dp_ipc, result_queue_for_dp_ipc
+       )

        def deamon_thread():
            while True:
@@ -159,5 +173,6 @@ def start_data_parallel_service(cfg, local_data_parallel_id, ipc_signal_suffix=N
        t_deamon = threading.Thread(target=deamon_thread, daemon=True)
        t_deamon.start()
+       t_deamon.join()
    except Exception as e:
        llm_logger.exception(f"Expert service failed to start: {e}, {str(traceback.format_exc())}")


@@ -73,6 +73,7 @@ class Request:
        guided_json_object: Optional[bool] = None,
        enable_thinking: Optional[bool] = True,
        trace_carrier: dict = dict(),
+       dp_rank: Optional[int] = None,
        chat_template: Optional[str] = None,
        image_start: int = 0,
        video_start: int = 0,
@@ -145,6 +146,8 @@ class Request:
        # extend block tables
        self.use_extend_tables = False
        self.extend_block_tables = []
+       # dp
+       self.dp_rank = dp_rank

    @classmethod
    def from_dict(cls, d: dict):
@@ -187,6 +190,7 @@ class Request:
image_end=d.get("image_end", 0), image_end=d.get("image_end", 0),
video_end=d.get("video_end", 0), video_end=d.get("video_end", 0),
audio_end=d.get("audio_end", 0), audio_end=d.get("audio_end", 0),
dp_rank=d.get("dp_rank", None),
) )
@property @property
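The new `dp_rank` field simply rides along the request dictionary; a toy round trip with a stripped-down stand-in class (not the real Request):

    from dataclasses import dataclass
    from typing import Optional

    @dataclass
    class ToyRequest:
        request_id: str
        dp_rank: Optional[int] = None

        @classmethod
        def from_dict(cls, d: dict) -> "ToyRequest":
            return cls(request_id=d["request_id"], dp_rank=d.get("dp_rank", None))

    print(ToyRequest.from_dict({"request_id": "r1", "dp_rank": 3}))  # dp_rank preserved
    print(ToyRequest.from_dict({"request_id": "r2"}))                # defaults to None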


@@ -328,8 +328,8 @@ class ResourceManager:
        Delete cached data from the task's prompt token ids based on the cached length.
        """
        if cached_len == len(task.prompt_token_ids):
-           task.prompt_token_ids = task.prompt_token_ids[cached_len - 1 :]
-           task.seq_lens_decoder = cached_len - 1
+           task.prompt_token_ids = task.prompt_token_ids[cached_len - self.cfg.block_size :]
+           task.seq_lens_decoder = cached_len - self.cfg.block_size
        else:
            task.prompt_token_ids = task.prompt_token_ids[cached_len:]
            task.seq_lens_decoder = cached_len
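The change above keeps one full block un-hit when the whole prompt is prefix-cached, so the prefill instance still has real work (and a first token) to produce. A worked standalone sketch, with a block size of 64 assumed for illustration:

    def trim_cached_prompt(prompt_token_ids, cached_len, block_size=64):
        """Return (remaining_prompt, seq_lens_decoder) after prefix-cache hits."""
        if cached_len == len(prompt_token_ids):
            # fully cached: re-prefill the last block so P still produces a first token
            return prompt_token_ids[cached_len - block_size:], cached_len - block_size
        return prompt_token_ids[cached_len:], cached_len

    prompt = list(range(256))
    print(len(trim_cached_prompt(prompt, 256)[0]), trim_cached_prompt(prompt, 256)[1])  # 64 192
    print(len(trim_cached_prompt(prompt, 128)[0]), trim_cached_prompt(prompt, 128)[1])  # 128 128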


@@ -14,6 +14,7 @@
# limitations under the License.
"""

+import copy
import threading
import time
import traceback
@@ -26,7 +27,7 @@ from typing import Union
import numpy as np
import paddle

-from fastdeploy.engine.request import Request, RequestStatus, RequestType
+from fastdeploy.engine.request import Request, RequestOutput, RequestStatus, RequestType
from fastdeploy.engine.resource_manager import ResourceManager
from fastdeploy.metrics.metrics import main_process_metrics
from fastdeploy.utils import llm_logger
@@ -297,6 +298,11 @@ class ResourceManagerV1(ResourceManager):
        while req_index < len(self.running) and token_budget > 0:
            request = self.running[req_index]
            if request.num_computed_tokens >= request.need_prefill_tokens:  # to be decoding
+               if (
+                   self.config.scheduler_config.splitwise_role == "prefill"
+               ):  # do not need to schedule for decoding
+                   req_index += 1
+                   continue
                if request.num_total_tokens > request.need_prefill_tokens:  # has generated tokens
                    request.num_computed_tokens = request.num_total_tokens - 1
                if (
@@ -400,11 +406,12 @@ class ResourceManagerV1(ResourceManager):
                    request.status = RequestStatus.RUNNING
                    main_process_metrics.num_requests_waiting.dec(1)
                    main_process_metrics.num_requests_running.inc(1)
-                   allocated_position = self.get_available_position()
-                   request.idx = allocated_position
-                   self.tasks_list[allocated_position] = request
-                   self.stop_flags[allocated_position] = False
-                   self.req_dict[request.request_id] = allocated_position
+                   if self.config.scheduler_config.splitwise_role == "mixed":
+                       allocated_position = self.get_available_position()
+                       request.idx = allocated_position
+                       self.tasks_list[allocated_position] = request
+                       self.stop_flags[allocated_position] = False
+                       self.req_dict[request.request_id] = allocated_position
                else:
                    if self.config.cache_config.enable_prefix_caching:
                        self._free_blocks(request)
@@ -569,6 +576,127 @@ class ResourceManagerV1(ResourceManager):
            self.waiting.append(request)
            self.requests[request.request_id] = request
def prerelease_resource(self, request: Request):
"""
Release resource in P or D before finished due to unexpected error.
"""
with self.lock:
self.tasks_list[request.idx] = None
self.stop_flags[request.idx] = True
del self.requests[request.request_id]
del self.req_dict[request.request_id]
self._free_blocks(request)
def add_request_in_p(self, requests: list[Request]):
with self.lock:
for request in requests:
request.inference_start_time = time.time()
request.schedule_start_time = time.time()
self.running.append(request)
def preallocate_resource_in_p(self, request: Request):
"""
        In P/D disaggregated deployment, preallocate resources in P.
        If resources can be allocated, allocate them and return True;
        otherwise, return False.
"""
assert self.config.scheduler_config.splitwise_role == "prefill", "Only P instance can call this method"
with self.lock:
if self.available_batch() == 0:
return False
request.need_prefill_tokens = len(request.prompt_token_ids)
need_prealloc_prefill_blocks = (
request.need_prefill_tokens + self.config.cache_config.block_size - 1
) // self.config.cache_config.block_size + self.config.cache_config.enc_dec_block_num # consider for mtp, plus enc_dec_block_num
if self.config.cache_config.enable_prefix_caching:
# Enable prefix caching
if self.config.cache_config.enable_hierarchical_cache and self.cache_manager.num_cpu_blocks > 0:
if not self.cache_manager.can_allocate_gpu_blocks(
need_prealloc_prefill_blocks
): # to prevent block allocation for matching in hierarchical cache and cause dead lock
return False
success = self.get_prefix_cached_blocks(request)
if not success:
self._free_blocks(request)
return False
# consider for mtp, plus enc_dec_block_num
need_extra_prefill_blocks = need_prealloc_prefill_blocks - request.cache_info[0]
if self.cache_manager.can_allocate_gpu_blocks(need_extra_prefill_blocks):
request.block_tables.extend(self.cache_manager.allocate_gpu_blocks(need_extra_prefill_blocks))
allocated_position = self.get_available_position()
request.idx = allocated_position
self.tasks_list[request.idx] = request
self.stop_flags[request.idx] = False
self.requests[request.request_id] = request
self.req_dict[request.request_id] = allocated_position
return True
else:
self._free_blocks(request)
return False
else:
if self.cache_manager.can_allocate_gpu_blocks(need_prealloc_prefill_blocks):
request.block_tables.extend(self.cache_manager.allocate_gpu_blocks(need_prealloc_prefill_blocks))
request.num_computed_tokens = 0
allocated_position = self.get_available_position()
request.idx = allocated_position
self.tasks_list[request.idx] = request
self.stop_flags[request.idx] = False
self.requests[request.request_id] = request
self.req_dict[request.request_id] = allocated_position
return True
return False
def preallocate_resource_in_d(self, request: Request):
"""
        In P/D disaggregated deployment, D preallocates resources for requests handed over from P.
        If resources can be allocated, allocate them and return True;
        otherwise, return False.
"""
assert self.config.scheduler_config.splitwise_role == "decode", "Only D instance can call this method"
with self.lock:
if len(self.waiting) > 0:
return False
if self.available_batch() == 0:
return False
request.need_prefill_tokens = len(request.prompt_token_ids)
need_prealloc_prefill_blocks = (
request.need_prefill_tokens + self.config.cache_config.block_size - 1
) // self.config.cache_config.block_size + self.config.cache_config.enc_dec_block_num # consider for mtp, plus enc_dec_block_num
if self.cache_manager.can_allocate_gpu_blocks(need_prealloc_prefill_blocks):
request.block_tables.extend(self.cache_manager.allocate_gpu_blocks(need_prealloc_prefill_blocks))
request.num_computed_tokens = request.need_prefill_tokens
request.disaggregate_info["block_tables"] = request.block_tables
allocated_position = self.get_available_position()
request.idx = allocated_position
self.tasks_list[request.idx] = request
self.stop_flags[request.idx] = False
self.requests[request.request_id] = request
self.req_dict[request.request_id] = allocated_position
return True
return False
def insert_task_for_decoding(self, request_output_in_p: RequestOutput):
"""
        In P/D disaggregated deployment, D continues decoding after receiving the first token and the KV cache from P.
"""
assert self.config.scheduler_config.splitwise_role == "decode", "Only D instance can call this method"
with self.lock:
request = self.requests[request_output_in_p.request_id]
request.output_token_ids.append(request_output_in_p.outputs.token_ids[0])
request.num_cached_tokens = request_output_in_p.num_cached_tokens
if (
self.config.speculative_config.method in ["mtp"]
and self.config.scheduler_config.splitwise_role == "decode"
):
request.draft_token_ids = copy.deepcopy(request_output_in_p.outputs.draft_token_ids)
# update request.need_prefill_tokens
request.need_prefill_tokens = len(request.prompt_token_ids) + 1
request.inference_start_time = time.time()
request.schedule_start_time = time.time()
self.running.append(request)
    def _free_blocks(self, request: Request):
        if self.config.cache_config.enable_prefix_caching:
            self.cache_manager.release_block_ids(request)
@@ -620,5 +748,7 @@ class ResourceManagerV1(ResourceManager):
                    self.tasks_list[request.idx] = None
                    self.stop_flags[request.idx] = True
                    del self.requests[req_id]
+                   if req_id in self.req_dict:
+                       del self.req_dict[req_id]
        except Exception as e:
            llm_logger.error(f"finish_request err: {e}, {str(traceback.format_exc())}")


@@ -109,6 +109,12 @@ environment_variables: dict[str, Callable[[], Any]] = {
"FD_ZMQ_SEND_RESPONSE_SERVER_PORT": lambda: os.getenv("FD_ZMQ_SEND_RESPONSE_SERVER_PORT", "8201"), "FD_ZMQ_SEND_RESPONSE_SERVER_PORT": lambda: os.getenv("FD_ZMQ_SEND_RESPONSE_SERVER_PORT", "8201"),
# LLMEngine recieve control command port, used when FD_ENABLE_INTERNAL_ADAPTER=1 # LLMEngine recieve control command port, used when FD_ENABLE_INTERNAL_ADAPTER=1
"FD_ZMQ_CONTROL_CMD_SERVER_PORTS": lambda: os.getenv("FD_ZMQ_CONTROL_CMD_SERVER_PORTS", "8202"), "FD_ZMQ_CONTROL_CMD_SERVER_PORTS": lambda: os.getenv("FD_ZMQ_CONTROL_CMD_SERVER_PORTS", "8202"),
# Whether to enable cache task in decode node
"FD_ENABLE_CACHE_TASK": lambda: os.getenv("FD_ENABLE_CACHE_TASK", "1"),
# Batched token timeout in EP
"FD_EP_BATCHED_TOKEN_TIMEOUT": lambda: float(os.getenv("FD_EP_BATCHED_TOKEN_TIMEOUT", "0.1")),
# Max pre-fetch requests number in PD
"FD_EP_MAX_PREFETCH_TASK_NUM": lambda: int(os.getenv("FD_EP_MAX_PREFETCH_TASK_NUM", "8")),
"FD_ENABLE_MODEL_LOAD_CACHE": lambda: bool(int(os.getenv("FD_ENABLE_MODEL_LOAD_CACHE", "0"))), "FD_ENABLE_MODEL_LOAD_CACHE": lambda: bool(int(os.getenv("FD_ENABLE_MODEL_LOAD_CACHE", "0"))),
} }
@@ -120,6 +126,14 @@ def __getattr__(name: str):
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")


+def get_unique_name(self, name):
+   """
+   Get unique name for config
+   """
+   shm_uuid = os.getenv("SHM_UUID", "")
+   return name + f"_{shm_uuid}"
+
+
def __setattr__(name: str, value: Any):
    assert name in environment_variables
    environment_variables[name] = lambda: value
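The three new switches are read once through the same table as the other entries; a hedged example of overriding them before launching a PD deployment (the values here are illustrative, the defaults in this diff are "1", 0.1 and 8):

    import os

    # illustrative overrides, set before the engine imports fastdeploy.envs
    os.environ["FD_ENABLE_CACHE_TASK"] = "0"            # do not cache tasks on the decode node
    os.environ["FD_EP_BATCHED_TOKEN_TIMEOUT"] = "0.05"  # tighter batching window in EP
    os.environ["FD_EP_MAX_PREFETCH_TASK_NUM"] = "16"    # allow deeper prefetch in PD

    # mirrors how the table parses them
    batched_timeout = float(os.getenv("FD_EP_BATCHED_TOKEN_TIMEOUT", "0.1"))
    max_prefetch = int(os.getenv("FD_EP_MAX_PREFETCH_TASK_NUM", "8"))
    print(batched_timeout, max_prefetch)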


@@ -84,18 +84,28 @@ class EngineWorkerQueue:
Value("i", 0) for _ in range(self.local_data_parallel_size) Value("i", 0) for _ in range(self.local_data_parallel_size)
] ]
self.finished_req_queue = [Queue() for _ in range(self.local_data_parallel_size)] self.finished_req_queue = [Queue() for _ in range(self.local_data_parallel_size)]
self.finished_add_cache_task_queue = [Queue() for _ in range(self.local_data_parallel_size)]
self.cache_infos_init: List[List[Any]] = [list() for _ in range(self.local_data_parallel_size)] self.cache_infos_init: List[List[Any]] = [list() for _ in range(self.local_data_parallel_size)]
self.connect_rdma_tasks_list = [list() for _ in range(self.local_data_parallel_size)]
self.connect_rdma_tasks_response_list = [list() for _ in range(self.local_data_parallel_size)]
self.client_read_info_flag_init: List[List[int]] = [ self.client_read_info_flag_init: List[List[int]] = [
[1] * self.num_client for _ in range(self.local_data_parallel_size) [1] * self.num_client for _ in range(self.local_data_parallel_size)
] ]
self.lock_info_init: List[threading.Lock] = [ self.lock_info_init: List[threading.Lock] = [
threading.Lock() for _ in range(self.local_data_parallel_size) threading.Lock() for _ in range(self.local_data_parallel_size)
] ]
self.connect_task_lock_init: List[threading.Lock] = [
threading.Lock() for _ in range(self.local_data_parallel_size)
]
self.finish_request_barrier = [ self.finish_request_barrier = [
threading.Barrier(self.num_client) for _ in range(self.local_data_parallel_size) threading.Barrier(self.num_client) for _ in range(self.local_data_parallel_size)
] ]
self.finish_add_cache_task_barrier = [
threading.Barrier(self.num_client) for _ in range(self.local_data_parallel_size)
]
# Register shared objects with proxy types # Register shared objects with proxy types
QueueManager.register( QueueManager.register(
"get_tasks", "get_tasks",
@@ -117,6 +127,19 @@ class EngineWorkerQueue:
                callable=lambda idx: self.read_finish_flag_init[idx],
                proxytype=ValueProxy,
            )
+           QueueManager.register(
+               "get_connect_task_lock",
+               callable=lambda idx: self.connect_task_lock_init[idx],
+               proxytype=AcquirerProxy,
+           )
+           QueueManager.register(
+               "get_connect_rdma_tasks", callable=lambda idx: self.connect_rdma_tasks_list[idx], proxytype=ListProxy
+           )
+           QueueManager.register(
+               "get_connect_rdma_tasks_responses",
+               callable=lambda idx: self.connect_rdma_tasks_response_list[idx],
+               proxytype=ListProxy,
+           )
            QueueManager.register(
                "get_connected_client_counter",
                callable=lambda idx: self.connected_client_counter_init[idx],
@@ -128,6 +151,11 @@ class EngineWorkerQueue:
                callable=lambda idx: self.finished_req_queue[idx],
            )
+           QueueManager.register(
+               "get_finish_add_cache_task_queue",
+               callable=lambda idx: self.finished_add_cache_task_queue[idx],
+           )
            QueueManager.register(
                "get_cache_infos",
                callable=lambda idx: self.cache_infos_init[idx],
@@ -161,6 +189,10 @@ class EngineWorkerQueue:
"get_finish_request_barrier", "get_finish_request_barrier",
callable=lambda idx: self.finish_request_barrier[idx], callable=lambda idx: self.finish_request_barrier[idx],
) )
QueueManager.register(
"get_finish_add_cache_task_barrier",
callable=lambda idx: self.finish_add_cache_task_barrier[idx],
)
self.manager: BaseManager = QueueManager(address=self.address, authkey=self.authkey) self.manager: BaseManager = QueueManager(address=self.address, authkey=self.authkey)
self.manager.start() self.manager.start()
else: else:
@@ -174,12 +206,17 @@ class EngineWorkerQueue:
QueueManager.register("get_read_finish_flag") QueueManager.register("get_read_finish_flag")
QueueManager.register("get_connected_client_counter") QueueManager.register("get_connected_client_counter")
QueueManager.register("get_finish_request_queue") QueueManager.register("get_finish_request_queue")
QueueManager.register("get_finish_add_cache_task_queue")
QueueManager.register("get_cache_infos") QueueManager.register("get_cache_infos")
QueueManager.register("get_client_read_info_flag") QueueManager.register("get_client_read_info_flag")
QueueManager.register("get_lock_info") QueueManager.register("get_lock_info")
QueueManager.register("get_disaggregate_requests") QueueManager.register("get_disaggregate_requests")
QueueManager.register("get_available_prefill_instances") QueueManager.register("get_available_prefill_instances")
QueueManager.register("get_finish_request_barrier") QueueManager.register("get_finish_request_barrier")
QueueManager.register("get_finish_add_cache_task_barrier")
QueueManager.register("get_connect_rdma_tasks")
QueueManager.register("get_connect_rdma_tasks_responses")
QueueManager.register("get_connect_task_lock")
self.manager = QueueManager(address=self.address, authkey=self.authkey) self.manager = QueueManager(address=self.address, authkey=self.authkey)
self._connect_with_retry() self._connect_with_retry()
@@ -199,7 +236,20 @@ class EngineWorkerQueue:
        self.disaggregate_requests = self.manager.get_disaggregate_requests(self.local_data_parallel_id)
        self.available_prefill_instances = self.manager.get_available_prefill_instances()
        self.finish_request_barrier = self.manager.get_finish_request_barrier(self.local_data_parallel_id)
+       self.finish_add_cache_task_barrier = self.manager.get_finish_add_cache_task_barrier(
+           self.local_data_parallel_id
+       )
        self.finished_req_queue = self.manager.get_finish_request_queue(self.local_data_parallel_id)
+       self.finished_add_cache_task_queue = self.manager.get_finish_add_cache_task_queue(
+           self.local_data_parallel_id
+       )
+       # P/D interconnection
+       self.connect_rdma_task_queue = self.manager.get_connect_rdma_tasks(self.local_data_parallel_id)
+       self.connect_rdma_task_response_queue = self.manager.get_connect_rdma_tasks_responses(
+           self.local_data_parallel_id
+       )
+       self.connect_task_lock = self.manager.get_connect_task_lock(self.local_data_parallel_id)
        assert self.num_client == len(self.client_read_flag)

        if is_server:
@@ -281,6 +331,44 @@ class EngineWorkerQueue:
        self.lock.release()
        return total_num
def put_connect_rdma_task(self, connect_rdma_task):
self.connect_task_lock.acquire()
self.connect_rdma_task_queue.append(connect_rdma_task)
self.connect_task_lock.release()
def get_connect_rdma_task(self):
result = None
self.connect_task_lock.acquire()
if len(self.connect_rdma_task_queue) == 0:
self.connect_task_lock.release()
return result
try:
result = self.connect_rdma_task_queue.pop(0)
except Exception as e:
llm_logger.info(f"get_connect_rdma_task got exception: {e}")
finally:
self.connect_task_lock.release()
return result
def put_connect_rdma_task_response(self, connect_rdma_task_response):
self.connect_task_lock.acquire()
self.connect_rdma_task_response_queue.append(connect_rdma_task_response)
self.connect_task_lock.release()
def get_connect_rdma_task_response(self):
result = None
self.connect_task_lock.acquire()
if len(self.connect_rdma_task_response_queue) == 0:
self.connect_task_lock.release()
return result
try:
result = self.connect_rdma_task_response_queue.pop(0)
except Exception as e:
llm_logger.info(f"get_connect_rdma_task_response got exception: {e}")
finally:
self.connect_task_lock.release()
return result
    def get_prefill_instances(self):
        """
        check if the prefill queue is empty
@@ -365,6 +453,29 @@ class EngineWorkerQueue:
llm_logger.debug(f"get finished req: {ans}") llm_logger.debug(f"get finished req: {ans}")
return ans return ans
def put_finished_add_cache_task_req(self, req_ids) -> None:
"""
Put finished request ID into the queue.
Args:
req_ids: Request ID to be added to the queue
"""
self.finished_add_cache_task_queue.put(req_ids)
def get_finished_add_cache_task_req(self) -> str:
"""
Get finished request ID from the queue.
Returns:
str: Finished request ID
"""
ans = []
if self.finished_add_cache_task_queue.empty():
return ans
ans = self.finished_add_cache_task_queue.get()
llm_logger.debug(f"get finished req: {ans}")
return ans
    def disaggregate_queue_empty(self):
        """
        Check if the disaggregated task queue is empty.
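The new connect-RDMA lists behave as tiny lock-guarded FIFOs shared through the QueueManager. A standalone sketch of the same put/get pattern, with plain threading objects standing in for the manager proxies:

    import threading

    class TinyTaskChannel:
        """Lock-guarded FIFO mirroring put_connect_rdma_task / get_connect_rdma_task."""

        def __init__(self):
            self._lock = threading.Lock()
            self._items = []

        def put(self, task):
            with self._lock:
                self._items.append(task)

        def get(self):
            with self._lock:
                if not self._items:
                    return None  # non-blocking: the caller simply polls again later
                return self._items.pop(0)

    channel = TinyTaskChannel()
    channel.put({"request_id": "r1", "rdma_ports": [8301]})
    print(channel.get())  # {'request_id': 'r1', 'rdma_ports': [8301]}
    print(channel.get())  # None, queue drained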


@@ -211,9 +211,8 @@ class DeepEPEngine:
        self.num_experts = num_experts
        self.num_local_experts = num_experts // ep_size
        self.async_finish = async_finish
-       from paddle.base.core import Config
-
-       self.ep_config = Config(24, 6, 256)
+       self.ep_config = None

        # Store phase and role for buffer management
        self._splitwise_role = splitwise_role


@@ -76,6 +76,7 @@ else:
        update_inputs,
        step_reschedule,
        update_inputs_v1,
+       speculate_step_reschedule,
    )
@@ -413,12 +414,11 @@ def step_cuda(
""" """
if speculative_config.method is not None: if speculative_config.method is not None:
if enable_prefix_caching: if DISABLE_RECOVER:
speculate_step_system_cache( speculate_step_reschedule(
share_inputs["stop_flags"], share_inputs["stop_flags"],
share_inputs["seq_lens_this_time"], share_inputs["seq_lens_this_time"],
share_inputs["step_seq_lens_encoder"], share_inputs["step_seq_lens_encoder"],
share_inputs["step_seq_lens_decoder"],
share_inputs["seq_lens_encoder"], share_inputs["seq_lens_encoder"],
share_inputs["seq_lens_decoder"], share_inputs["seq_lens_decoder"],
share_inputs["block_tables"], share_inputs["block_tables"],
@@ -444,64 +444,67 @@ def step_cuda(
                speculative_config.num_speculative_tokens,
            )
        else:
-           speculate_step_paddle(
-               share_inputs["stop_flags"],
-               share_inputs["seq_lens_this_time"],
-               share_inputs["step_seq_lens_encoder"],
-               share_inputs["seq_lens_encoder"],
-               share_inputs["seq_lens_decoder"],
-               share_inputs["block_tables"],
-               share_inputs["encoder_block_lens"],
-               share_inputs["is_block_step"],
-               share_inputs["step_block_list"],
-               share_inputs["step_lens"],
-               share_inputs["recover_block_list"],
-               share_inputs["recover_lens"],
-               share_inputs["need_block_list"],
-               share_inputs["need_block_len"],
-               share_inputs["used_list_len"],
-               share_inputs["free_list"],
-               share_inputs["free_list_len"],
-               share_inputs["input_ids"],
-               share_inputs["pre_ids"],
-               share_inputs["step_idx"],
-               share_inputs["next_tokens"],
-               share_inputs["first_token_ids"],
-               share_inputs["accept_num"],
-               block_size,
-               enc_dec_block_num,
-               speculative_config.num_speculative_tokens,
-           )
+           if enable_prefix_caching:
+               speculate_step_system_cache(
+                   share_inputs["stop_flags"],
+                   share_inputs["seq_lens_this_time"],
+                   share_inputs["step_seq_lens_encoder"],
+                   share_inputs["step_seq_lens_decoder"],
+                   share_inputs["seq_lens_encoder"],
+                   share_inputs["seq_lens_decoder"],
+                   share_inputs["block_tables"],
+                   share_inputs["encoder_block_lens"],
+                   share_inputs["is_block_step"],
+                   share_inputs["step_block_list"],
+                   share_inputs["step_lens"],
+                   share_inputs["recover_block_list"],
+                   share_inputs["recover_lens"],
+                   share_inputs["need_block_list"],
+                   share_inputs["need_block_len"],
+                   share_inputs["used_list_len"],
+                   share_inputs["free_list"],
+                   share_inputs["free_list_len"],
+                   share_inputs["input_ids"],
+                   share_inputs["pre_ids"],
+                   share_inputs["step_idx"],
+                   share_inputs["next_tokens"],
+                   share_inputs["first_token_ids"],
+                   share_inputs["accept_num"],
+                   block_size,
+                   enc_dec_block_num,
+                   speculative_config.num_speculative_tokens,
+               )
+           else:
+               speculate_step_paddle(
+                   share_inputs["stop_flags"],
+                   share_inputs["seq_lens_this_time"],
+                   share_inputs["step_seq_lens_encoder"],
+                   share_inputs["seq_lens_encoder"],
+                   share_inputs["seq_lens_decoder"],
+                   share_inputs["block_tables"],
+                   share_inputs["encoder_block_lens"],
+                   share_inputs["is_block_step"],
+                   share_inputs["step_block_list"],
+                   share_inputs["step_lens"],
+                   share_inputs["recover_block_list"],
+                   share_inputs["recover_lens"],
+                   share_inputs["need_block_list"],
+                   share_inputs["need_block_len"],
+                   share_inputs["used_list_len"],
+                   share_inputs["free_list"],
+                   share_inputs["free_list_len"],
+                   share_inputs["input_ids"],
+                   share_inputs["pre_ids"],
+                   share_inputs["step_idx"],
+                   share_inputs["next_tokens"],
+                   share_inputs["first_token_ids"],
+                   share_inputs["accept_num"],
+                   block_size,
+                   enc_dec_block_num,
+                   speculative_config.num_speculative_tokens,
+               )
    else:
-       if enable_prefix_caching:
-           step_system_cache(
-               share_inputs["stop_flags"],
-               share_inputs["seq_lens_this_time"],
-               share_inputs["step_seq_lens_encoder"],
-               share_inputs["step_seq_lens_decoder"],
-               share_inputs["seq_lens_encoder"],
-               share_inputs["seq_lens_decoder"],
-               share_inputs["block_tables"],
-               share_inputs["encoder_block_lens"],
-               share_inputs["is_block_step"],
-               share_inputs["step_block_list"],
-               share_inputs["step_lens"],
-               share_inputs["recover_block_list"],
-               share_inputs["recover_lens"],
-               share_inputs["need_block_list"],
-               share_inputs["need_block_len"],
-               share_inputs["used_list_len"],
-               share_inputs["free_list"],
-               share_inputs["free_list_len"],
-               share_inputs["input_ids"],
-               share_inputs["pre_ids"],
-               share_inputs["step_idx"],
-               share_inputs["next_tokens"],
-               share_inputs["first_token_ids"],
-               block_size,
-               enc_dec_block_num,
-           )
-       elif DISABLE_RECOVER:
+       if DISABLE_RECOVER:
            step_reschedule(
                share_inputs["stop_flags"],
                share_inputs["seq_lens_this_time"],
@@ -529,32 +532,61 @@ def step_cuda(
                enc_dec_block_num,
            )
        else:
-           step_paddle(
-               share_inputs["stop_flags"],
-               share_inputs["seq_lens_this_time"],
-               share_inputs["step_seq_lens_encoder"],
-               share_inputs["seq_lens_encoder"],
-               share_inputs["seq_lens_decoder"],
-               share_inputs["block_tables"],
-               share_inputs["encoder_block_lens"],
-               share_inputs["is_block_step"],
-               share_inputs["step_block_list"],
-               share_inputs["step_lens"],
-               share_inputs["recover_block_list"],
-               share_inputs["recover_lens"],
-               share_inputs["need_block_list"],
-               share_inputs["need_block_len"],
-               share_inputs["used_list_len"],
-               share_inputs["free_list"],
-               share_inputs["free_list_len"],
-               share_inputs["input_ids"],
-               share_inputs["pre_ids"],
-               share_inputs["step_idx"],
-               share_inputs["next_tokens"],
-               share_inputs["first_token_ids"],
-               block_size,
-               enc_dec_block_num,
-           )
+           if enable_prefix_caching:
+               step_system_cache(
+                   share_inputs["stop_flags"],
+                   share_inputs["seq_lens_this_time"],
+                   share_inputs["step_seq_lens_encoder"],
+                   share_inputs["step_seq_lens_decoder"],
+                   share_inputs["seq_lens_encoder"],
+                   share_inputs["seq_lens_decoder"],
+                   share_inputs["block_tables"],
+                   share_inputs["encoder_block_lens"],
+                   share_inputs["is_block_step"],
+                   share_inputs["step_block_list"],
+                   share_inputs["step_lens"],
+                   share_inputs["recover_block_list"],
+                   share_inputs["recover_lens"],
+                   share_inputs["need_block_list"],
+                   share_inputs["need_block_len"],
+                   share_inputs["used_list_len"],
+                   share_inputs["free_list"],
+                   share_inputs["free_list_len"],
+                   share_inputs["input_ids"],
+                   share_inputs["pre_ids"],
+                   share_inputs["step_idx"],
+                   share_inputs["next_tokens"],
+                   share_inputs["first_token_ids"],
+                   block_size,
+                   enc_dec_block_num,
+               )
+           else:
+               step_paddle(
+                   share_inputs["stop_flags"],
+                   share_inputs["seq_lens_this_time"],
+                   share_inputs["step_seq_lens_encoder"],
+                   share_inputs["seq_lens_encoder"],
+                   share_inputs["seq_lens_decoder"],
+                   share_inputs["block_tables"],
+                   share_inputs["encoder_block_lens"],
+                   share_inputs["is_block_step"],
+                   share_inputs["step_block_list"],
+                   share_inputs["step_lens"],
+                   share_inputs["recover_block_list"],
+                   share_inputs["recover_lens"],
+                   share_inputs["need_block_list"],
+                   share_inputs["need_block_len"],
+                   share_inputs["used_list_len"],
+                   share_inputs["free_list"],
+                   share_inputs["free_list_len"],
+                   share_inputs["input_ids"],
+                   share_inputs["pre_ids"],
+                   share_inputs["step_idx"],
+                   share_inputs["next_tokens"],
+                   share_inputs["first_token_ids"],
+                   block_size,
+                   enc_dec_block_num,
+               )
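The net effect of this reshuffle is a new branch order: speculative vs. normal decoding first, then DISABLE_RECOVER, then prefix caching. A small standalone recap of that dispatch (it returns the op name only and is not the real step_cuda signature):

    def select_step_op(speculative: bool, disable_recover: bool, prefix_caching: bool) -> str:
        # mirrors the post-refactor branch order in step_cuda
        if speculative:
            if disable_recover:
                return "speculate_step_reschedule"
            return "speculate_step_system_cache" if prefix_caching else "speculate_step_paddle"
        if disable_recover:
            return "step_reschedule"
        return "step_system_cache" if prefix_caching else "step_paddle"

    assert select_step_op(True, True, True) == "speculate_step_reschedule"
    assert select_step_op(False, False, True) == "step_system_cache"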
def rebuild_padding(


@@ -58,7 +58,6 @@ class TokenProcessor:
        self.split_connector = split_connector
        if envs.FD_USE_GET_SAVE_OUTPUT_V1:
            llm_logger.debug(f"create zmq get_save_output_rank{self.cfg.parallel_config.local_data_parallel_id}")
            self.zmq_server = ZmqIpcServer(
                name=f"get_save_output_rank{self.cfg.parallel_config.local_data_parallel_id}", mode=zmq.PULL
@@ -298,10 +297,15 @@ class TokenProcessor:
            try:
                is_blocking = True
                if self.speculative_decoding:
-                   speculate_get_output(self.output_tokens, rank_id, is_blocking, False)
+                   if (
+                       self.cfg.parallel_config.enable_expert_parallel
+                       and self.cfg.parallel_config.data_parallel_size > 1
+                   ):
+                       speculate_get_output(self.output_tokens, rank_id, is_blocking, True)
+                   else:
+                       speculate_get_output(self.output_tokens, rank_id, is_blocking, False)
                    if self.output_tokens[0] == -2:
                        continue
                else:
                    if self.use_logprobs:
                        get_output_topk(
@@ -370,14 +374,18 @@ class TokenProcessor:
llm_logger.info(f"finished_task_id: {finished_task_id}") llm_logger.info(f"finished_task_id: {finished_task_id}")
self.prefill_result_status[finished_task_id[0]] = finished_task_id[1] self.prefill_result_status[finished_task_id[0]] = finished_task_id[1]
if task_id in self.prefill_result_status: if task_id in self.prefill_result_status:
self.split_connector.send_first_token(task.disaggregate_info, [result]) if envs.ENABLE_V1_KVCACHE_SCHEDULER:
self.resource_manager.stop_flags[index] = True self.resource_manager.finish_requests_async(task_id)
self.resource_manager.tasks_list[index] = None else:
self.resource_manager._recycle_block_tables(task) self.resource_manager.stop_flags[index] = True
self.resource_manager.tasks_list[index] = None
self.resource_manager._recycle_block_tables(task)
if task_id in self.resource_manager.req_dict:
del self.resource_manager.req_dict[task_id]
if self.prefill_result_status[task_id] != "finished": if self.prefill_result_status[task_id] != "finished":
result.error_code = 400 result.error_code = 400
result.error_message = f"{task_id} failed to {self.prefill_result_status[task_id]}" result.error_message = f"{task_id} failed to {self.prefill_result_status[task_id]}"
del self.resource_manager.req_dict[task_id] self.split_connector.send_first_token(task.disaggregate_info, [result])
break break
else: else:
time.sleep(0.002) time.sleep(0.002)
@@ -388,6 +396,8 @@ class TokenProcessor:
            self.resource_manager.stop_flags[index] = True
            self.resource_manager.tasks_list[index] = None
            self.resource_manager._recycle_block_tables(task)
+           if task_id in self.resource_manager.req_dict:
+               del self.resource_manager.req_dict[task_id]
        task_used_block_num = sum([len(task.block_tables) if task else 0 for task in self.resource_manager.tasks_list])
        main_process_metrics.available_gpu_block_num.set(
@@ -461,16 +471,22 @@ class TokenProcessor:
            task_id = task.request_id
            if self.cfg.speculative_config.method:
-               token_ids = tokens[
-                   2
-                   + SPECULATE_MAX_BSZ
-                   + i * MAX_DRAFT_TOKENS : 2
-                   + SPECULATE_MAX_BSZ
-                   + i * MAX_DRAFT_TOKENS
-                   + accept_num[i]
-               ].tolist()
-               if len(token_ids) == 0 or token_ids[-1] <= 0:
-                   continue
+               if accept_num[i] == -3:
+                   recovery_stop = True
+               if recovery_stop:
+                   llm_logger.info(f"recovery stop signal found at task {task_id}")
+                   token_ids = [RECOVERY_STOP_SIGNAL]
+               else:
+                   token_ids = tokens[
+                       2
+                       + SPECULATE_MAX_BSZ
+                       + i * MAX_DRAFT_TOKENS : 2
+                       + SPECULATE_MAX_BSZ
+                       + i * MAX_DRAFT_TOKENS
+                       + accept_num[i]
+                   ].tolist()
+               if (not recovery_stop) and (len(token_ids) == 0 or token_ids[-1] <= 0):
+                   continue
            else:
                token_id = int(tokens[i, 0])
                token_ids = [token_id]
@@ -527,7 +543,7 @@ class TokenProcessor:
            if self.tokens_counter[task_id] == 0:
                if task.messages is not None:
                    result.prompt = task.messages
                result.num_cached_tokens = task.num_cached_tokens

            is_prefill = task.disaggregate_info is not None and task.disaggregate_info["role"] == "prefill"
@@ -537,7 +553,8 @@ class TokenProcessor:
            for token_id in token_ids:
                self.tokens_counter[task_id] += 1
                if token_id != RECOVERY_STOP_SIGNAL:
-                   result.outputs.token_ids.append(token_id)
+                   if not (envs.FD_ENABLE_INTERNAL_ADAPTER and token_id in task.eos_token_ids):
+                       result.outputs.token_ids.append(token_id)
                    task.output_token_ids.append(token_id)
                if self.use_logprobs:
                    result.outputs.logprob = float(scores[i, 0])
@@ -567,7 +584,11 @@ class TokenProcessor:
                    self._record_completion_metrics(task, current_time)
                    self._recycle_resources(task_id, i, task, result, is_prefill)
                    break
-           if not is_prefill or self.cfg.scheduler_config.name == "splitwise":
+           if (
+               not is_prefill
+               or self.cfg.scheduler_config.name == "splitwise"
+               or self.cfg.scheduler_config.name == "dp"
+           ):
                batch_result.append(result)

        self.postprocess(batch_result)
@@ -609,7 +630,7 @@ class TokenProcessor:
                self.cfg.speculative_config.num_speculative_tokens,
            )
-       real_accept_num = [x for x in accept_num if x != 0]
+       real_accept_num = [x for x in accept_num if x > 0]
        num_accepted_tokens = sum([x - 1 for x in real_accept_num])
        self.num_accepted_tokens += num_accepted_tokens
        num_emitted_tokens = sum(real_accept_num)
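With the -3 recovery-stop sentinel folded in, only strictly positive accept counts feed the speculative-decoding metrics; a standalone recap of that bookkeeping (0 stands for an inactive slot, -3 for a recovery stop):

    def speculative_counters(accept_num):
        real = [x for x in accept_num if x > 0]    # drop inactive slots (0) and sentinels (-3)
        num_accepted = sum(x - 1 for x in real)    # tokens accepted beyond the base token
        num_emitted = sum(real)
        return num_accepted, num_emitted

    print(speculative_counters([3, 0, -3, 2]))  # (3, 5)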


@@ -18,6 +18,7 @@ import redis
from fastdeploy.utils import llm_logger

+from .dp_scheduler import DPScheduler
from .global_scheduler import GlobalScheduler
from .local_scheduler import LocalScheduler
from .splitwise_scheduler import SplitWiseScheduler, SplitWiseSchedulerConfig
@@ -89,6 +90,54 @@ class LocalSchedulerConfig:
llm_logger.info("=============================================================") llm_logger.info("=============================================================")
class DPLocalSchedulerConfig(LocalSchedulerConfig):
"""
Configuration class for DPLocalScheduler.
Attributes:
max_size: Maximum number of concurrent requests (-1 for unlimited)
ttl: Time-to-live in seconds for request expiration
"""
def __init__(
self,
max_size: int = -1,
ttl: int = 900,
max_model_len: int = 8192,
enable_chunked_prefill: bool = False,
max_num_partial_prefills: int = 1,
max_long_partial_prefills: int = 1,
long_prefill_token_threshold: int = 0,
splitwise_role: str = "prefill",
**kwargs,
):
"""
Initialize LocalScheduler configuration.
Args:
max_size: Maximum concurrent requests (-1 for unlimited, 0 for disabled)
ttl: Time-to-live in seconds for request expiration (default 900s)
max_model_len: Maximum model context length in tokens
enable_chunked_prefill: Whether to enable chunked prefill processing
max_num_partial_prefills: Max partial prefill operations allowed
max_long_partial_prefills: Max long-running partial prefill ops
long_prefill_token_threshold: Token count threshold for long prefill
**kwargs: Additional unused arguments (for forward compatibility)
Note:
- If long_prefill_token_threshold is 0, it's auto-calculated as 4% of max_model_len
- See LocalScheduler class for implementation details
"""
self.max_size = max_size
self.ttl = ttl
self.max_model_len = max_model_len
self.enable_chunked_prefill = enable_chunked_prefill
self.max_num_partial_prefills = max_num_partial_prefills
self.max_long_partial_prefills = max_long_partial_prefills
self.long_prefill_token_threshold = long_prefill_token_threshold
if self.long_prefill_token_threshold == 0:
self.long_prefill_token_threshold = int(self.max_model_len * 0.04)
self.splitwise_role = splitwise_role
class GlobalSchedulerConfig:
    """
    Configuration class for GlobalScheduler (Redis-based).
@@ -235,6 +284,9 @@ class SchedulerConfig:
if self.name == "splitwise": if self.name == "splitwise":
self.config = SplitWiseSchedulerConfig(**args) self.config = SplitWiseSchedulerConfig(**args)
if self.name == "dp":
self.config = DPLocalSchedulerConfig(**args)
def check(self): def check(self):
""" """
Validate the configuration. Validate the configuration.
@@ -242,7 +294,7 @@ class SchedulerConfig:
        Raises:
            Exception: If invalid scheduler type is specified
        """
-       if self.name not in ["local", "global", "splitwise"]:
+       if self.name not in ["local", "global", "splitwise", "dp"]:
            raise Exception(f"Unknown scheduler type {self.name}")

        self.config.check()
@@ -280,6 +332,17 @@ class SchedulerConfig:
if self.name == "splitwise": if self.name == "splitwise":
return SplitWiseScheduler(self.config) return SplitWiseScheduler(self.config)
if self.name == "dp":
return DPScheduler(
max_size=self.config.max_size,
ttl=self.config.ttl,
enable_chunked_prefill=self.config.enable_chunked_prefill,
max_num_partial_prefills=self.config.max_num_partial_prefills,
max_long_partial_prefills=self.config.max_long_partial_prefills,
long_prefill_token_threshold=self.config.long_prefill_token_threshold,
splitwise_role=self.config.splitwise_role,
)
        return LocalScheduler(
            max_size=self.config.max_size,
            ttl=self.config.ttl,
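The only derived field in DPLocalSchedulerConfig is the long-prefill threshold, which falls back to 4% of the model context length when left at 0; a standalone recap of that default:

    def resolve_long_prefill_threshold(max_model_len: int, long_prefill_token_threshold: int = 0) -> int:
        # DPLocalSchedulerConfig falls back to 4% of the context length when unset
        if long_prefill_token_threshold == 0:
            return int(max_model_len * 0.04)
        return long_prefill_token_threshold

    print(resolve_long_prefill_threshold(8192))       # 327
    print(resolve_long_prefill_threshold(8192, 512))  # 512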


@@ -0,0 +1,272 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
import logging
import threading
import time
from multiprocessing import Queue
from typing import Dict, List, Optional
from fastdeploy.engine.request import Request, RequestOutput
from fastdeploy.scheduler.data import ScheduledResponse
from fastdeploy.scheduler.local_scheduler import LocalScheduler
from fastdeploy.utils import envs, get_logger
class DPLocalScheduler(LocalScheduler):
def __init__(
self,
max_size: int,
ttl: int,
enable_chunked_prefill: bool,
max_num_partial_prefills: int,
max_long_partial_prefills: int,
long_prefill_token_threshold: int,
splitwise_role: str = "prefill",
):
super().__init__(
max_size,
ttl,
enable_chunked_prefill,
max_num_partial_prefills,
max_long_partial_prefills,
long_prefill_token_threshold,
)
self.splitwise_role = splitwise_role
self.scheduler_logger = logging
def put_results(self, results: List[RequestOutput]):
"""
Add processing results back to the scheduler.
Args:
results: List of RequestOutput objects containing results
"""
responses: List[ScheduledResponse] = [ScheduledResponse(result) for result in results]
finished_responses = [response.request_id for response in responses if response.finished]
if len(finished_responses) > 0:
self.scheduler_logger.info(f"Scheduler has received some finished responses: {finished_responses}")
with self.mutex:
for response in responses:
if response.request_id not in self.responses:
self.responses[response.request_id] = [response]
continue
self.responses[response.request_id].append(response)
self.responses_not_empty.notify_all()
def _recycle(self, request_id: Optional[str] = None):
"""
Clean up expired or completed requests to free memory.
Args:
request_id: Optional specific request ID to remove.
If None, removes all expired requests.
"""
if request_id is not None:
self.requests.pop(request_id, None)
self.responses.pop(request_id, None)
if self.splitwise_role == "decode":
return
self.ids.pop(self.ids.index(request_id))
self.ids_read_cursor -= 1
return
if self.max_size <= 0:
return
if len(self.requests) <= self.max_size:
return
now = time.time()
expired_ids = []
for request_id in self.ids:
request = self.requests[request_id]
if now - request.schedule_time < self.ttl:
break
expired_ids.append(request.request_id)
for i, expired_id in enumerate(expired_ids):
self.requests.pop(expired_id, None)
self.responses.pop(expired_id, None)
self.ids.pop(i)
if len(expired_ids) > 0:
if len(expired_ids) - 1 >= self.ids_read_cursor:
self.ids_read_cursor = 0
else:
self.ids_read_cursor -= len(expired_ids)
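A short sketch of the two cleanup paths implemented above (illustrative only; _recycle is internal, the module path and constructor values are made up):
from fastdeploy.scheduler.dp_scheduler import DPLocalScheduler  # module path assumed

scheduler = DPLocalScheduler(
    max_size=4,
    ttl=900,
    enable_chunked_prefill=True,
    max_num_partial_prefills=1,
    max_long_partial_prefills=1,
    long_prefill_token_threshold=0,
    splitwise_role="decode",
)
# Targeted cleanup: drops the stored request/response for one finished request.
# With splitwise_role="decode", the ids list and read cursor are left untouched.
scheduler._recycle("req-0")
# Bulk cleanup: only runs when max_size > 0 and more than max_size requests are
# held; expired entries (older than ttl seconds) at the head of the queue are dropped.
scheduler._recycle()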
def get_requests(
self,
available_blocks,
block_size,
reserved_output_blocks,
max_num_batched_tokens,
batch=1,
) -> List[Request]:
"""
Retrieve requests from the scheduler based on available resources.
Args:
available_blocks: Number of available processing blocks
block_size: Size of each processing block
reserved_output_blocks: Blocks reserved for output
max_num_batched_tokens: Maximum tokens that can be batched
batch: Preferred batch size
Returns:
List of Request objects ready for processing
"""
if available_blocks <= reserved_output_blocks or batch < 1:
self.scheduler_logger.debug(
f"Scheduler's resource are insufficient: available_blocks={available_blocks} "
f"reserved_output_blocks={reserved_output_blocks} batch={batch} "
f"max_num_batched_tokens={max_num_batched_tokens}"
)
return []
required_total_blocks = 0
current_prefill_tokens = 0
start_batch_time = time.time()
requests: List[Request] = []
with self.requests_not_empty:
if not envs.ENABLE_V1_KVCACHE_SCHEDULER:
while True:
batch_ids = self.requests_not_empty.wait_for(
lambda: self.ids[self.ids_read_cursor : self.ids_read_cursor + batch],
0.005,
)
if batch_ids:
for request_id in batch_ids:
request = self.requests[request_id]
required_input_blocks = self.calc_required_blocks(
request.prompt_tokens_ids_len, block_size
)
current_prefill_tokens += request.prompt_tokens_ids_len
required_total_blocks += required_input_blocks + reserved_output_blocks
if required_total_blocks > available_blocks:
break
requests.append(request.raw)
self.ids_read_cursor += 1
start_batch_time = time.time()
if current_prefill_tokens > max_num_batched_tokens:
break
if len(requests) >= batch:
break
if (
(current_prefill_tokens > max_num_batched_tokens)
or (len(requests) >= batch)
or (time.time() - start_batch_time > envs.FD_EP_BATCHED_TOKEN_TIMEOUT)
):
break
else:
batch_ids = self.requests_not_empty.wait_for(
lambda: self.ids[self.ids_read_cursor : self.ids_read_cursor + batch],
0.005,
)
if batch_ids:
for request_id in batch_ids:
request = self.requests[request_id]
requests.append(request.raw)
self.ids_read_cursor += 1
if batch_ids:
if len(batch_ids) > 0 and len(requests) == 0:
self.scheduler_logger.debug(
f"Scheduler has put all just-pulled request into the queue: {len(batch_ids)}"
)
if len(requests) > 0:
self.scheduler_logger.info(
f"Scheduler has pulled some request: {[request.request_id for request in requests]}"
)
return requests
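To make the admission logic above concrete, a worked example with made-up numbers (calc_required_blocks is inherited from LocalScheduler and not shown here; it is assumed to be a ceiling division over block_size):
# Standalone arithmetic mirroring the block/token budget checks above.
import math

available_blocks = 100
block_size = 64
reserved_output_blocks = 10
max_num_batched_tokens = 4096

prompt_lens = [700, 1500, 2300]     # three queued requests, in arrival order
required_total_blocks = 0
current_prefill_tokens = 0
admitted = []
for tokens in prompt_lens:
    required_input_blocks = math.ceil(tokens / block_size)   # assumed behaviour of calc_required_blocks
    current_prefill_tokens += tokens
    required_total_blocks += required_input_blocks + reserved_output_blocks
    if required_total_blocks > available_blocks:
        break                        # KV-cache blocks exhausted: stop pulling
    admitted.append(tokens)
    if current_prefill_tokens > max_num_batched_tokens:
        break                        # prefill token budget for this batch exhausted
# 700 -> 11+10 = 21 blocks, 1500 -> 24+10 = 55, 2300 -> 36+10 = 101 > 100,
# so only the first two requests are admitted in this round.
By contrast, the ENABLE_V1_KVCACHE_SCHEDULER branch above simply drains up to batch requests without these budget checks, leaving admission to the engine.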
class DPScheduler:
def __init__(
self,
max_size: int,
ttl: int,
enable_chunked_prefill: bool,
max_num_partial_prefills: int,
max_long_partial_prefills: int,
long_prefill_token_threshold: int,
splitwise_role: str = "prefill",
):
self._scheduler = DPLocalScheduler(
max_size,
ttl,
enable_chunked_prefill,
max_num_partial_prefills,
max_long_partial_prefills,
long_prefill_token_threshold,
splitwise_role,
)
def start(self, dp_rank: int, request_queues: List[Queue], result_queue: Queue):
self.dp_rank = dp_rank
self.request_queues = request_queues
self.result_queue = result_queue
self.scheduler_logger = get_logger("dpscheduler", f"dp_scheduler_rank{self.dp_rank}.log")
self._scheduler.scheduler_logger = self.scheduler_logger
threading.Thread(target=self._put_requests_to_local).start()
threading.Thread(target=self._get_response_from_local).start()
def put_requests(self, requests: List[Dict]):
results = []
for request in requests:
if not hasattr(request, "dp_rank"):
raise ValueError(f"Request object is missing the 'dp_rank' attribute: {request}")
self.request_queues[request.dp_rank].put(request)
results.append((request.request_id, None))
return results
def _put_requests_to_local(self):
while True:
request = self.request_queues[self.dp_rank].get()
self.scheduler_logger.info(f"Recieve request from puller, request_id: {request.request_id}")
self._scheduler.put_requests([request])
def _get_response_from_local(self):
while True:
results = self._scheduler.get_results()
if len(results) == 0:
continue
self.result_queue.put(results)
def get_requests(
self,
available_blocks,
block_size,
reserved_output_blocks,
max_num_batched_tokens,
batch=1,
) -> List[Request]:
return self._scheduler.get_requests(
available_blocks, block_size, reserved_output_blocks, max_num_batched_tokens, batch
)
def get_unhandled_request_num(self):
return len(self._scheduler.requests)
def put_results(self, results: List[RequestOutput]):
self._scheduler.put_results(results)
def get_results(self) -> Dict[str, List[RequestOutput]]:
return self.result_queue.get()
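A hedged wiring sketch for DPScheduler across data-parallel ranks; the module path, queue layout, and engine-side calls are assumptions, only start(), put_requests(), get_requests(), and get_results() come from the class above:
from multiprocessing import Queue

from fastdeploy.scheduler.dp_scheduler import DPScheduler  # module path assumed

dp_size = 2
request_queues = [Queue() for _ in range(dp_size)]   # one inbox per DP rank
result_queue = Queue()                               # shared outbox

schedulers = []
for rank in range(dp_size):
    s = DPScheduler(
        max_size=-1,
        ttl=900,
        enable_chunked_prefill=True,
        max_num_partial_prefills=1,
        max_long_partial_prefills=1,
        long_prefill_token_threshold=0,
        splitwise_role="prefill",
    )
    s.start(rank, request_queues, result_queue)      # spawns the two forwarding threads
    schedulers.append(s)

# A fastdeploy Request carrying a dp_rank attribute would then be routed like:
#     schedulers[0].put_requests([request])          # request.dp_rank picks the queue
#     batch = schedulers[request.dp_rank].get_requests(
#         available_blocks, block_size, reserved_output_blocks,
#         max_num_batched_tokens, batch=8)
#     outputs = schedulers[request.dp_rank].get_results()   # blocks on result_queue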

View File

@@ -28,8 +28,6 @@ from fastdeploy.inter_communicator import EngineWorkerQueue
from fastdeploy.metrics.metrics import main_process_metrics
from fastdeploy.utils import get_logger
-logger = get_logger("splitwise_connector", "splitwise_connector.log")
class SplitwiseConnector:
"""
@@ -46,12 +44,19 @@ class SplitwiseConnector:
resource_manager (object): Resource manager object.
"""
self.cfg = cfg
if self.cfg.parallel_config.enable_expert_parallel and self.cfg.parallel_config.data_parallel_size > 1:
self.logger = get_logger(
"splitwise_connector", f"splitwise_connector_{self.cfg.parallel_config.local_data_parallel_id}.log"
)
else:
self.logger = get_logger("splitwise_connector", "splitwise_connector.log")
self.engine_worker_queue = worker_queue
self.resource_manager = resource_manager
self.connect_innode_instances = {}
self.temp_cache_info = dict()
self.current_request_ids = dict()
self.idx = self.cfg.parallel_config.local_data_parallel_id
self.enable_decode_cache_task = envs.FD_ENABLE_CACHE_TASK == "1"
if self.cfg.cache_config.pd_comm_port is not None:
self.zmq_ctx = zmq.Context()
@@ -70,7 +75,7 @@ class SplitwiseConnector:
self.router_socket.setsockopt(zmq.SNDHWM, 1000)
self.router_socket.setsockopt(zmq.ROUTER_MANDATORY, 1)
self.router_socket.bind(f"tcp://*:{self.cfg.cache_config.pd_comm_port[0]}")
-logger.info(f"bind {self.cfg.cache_config.pd_comm_port}")
self.logger.info(f"bind {self.cfg.cache_config.pd_comm_port}")
self.poller = zmq.Poller()
self.poller.register(self.router_socket, zmq.POLLIN)
@@ -90,17 +95,17 @@ class SplitwiseConnector:
if not socks:
continue
else:
-logger.debug(f"receive {socks}")
self.logger.debug(f"receive {socks}")
frames = self.router_socket.recv_multipart()
-logger.debug(f"frames: {frames}")
self.logger.debug(f"frames: {frames}")
message = frames[-1]
self.io_executor.submit(self._process_message, message)
time.sleep(0.001)
else:
time.sleep(5)
except Exception as e:
-logger.error(f"Receiver error: {e}, {str(traceback.format_exc())}")
self.logger.error(f"Receiver error: {e}, {str(traceback.format_exc())}")
time.sleep(1)
def _get_push_socket(self, addr):
@@ -112,7 +117,7 @@ class SplitwiseConnector:
return sock
try:
-logger.info(f"Establishing new connection to {addr}")
self.logger.info(f"Establishing new connection to {addr}")
sock = self.zmq_ctx.socket(zmq.DEALER)
# Set connection parameters
@@ -131,7 +136,7 @@ class SplitwiseConnector:
return sock
except zmq.ZMQError as e:
-logger.error(f"Connection to {addr} failed: {e}")
self.logger.error(f"Connection to {addr} failed: {e}")
raise ConnectionError(f"Failed to connect to {addr}") from e
@@ -140,7 +145,7 @@ class SplitwiseConnector:
return
try:
-logger.info(f"Sent {msg_type} to {addr}")
self.logger.info(f"Sent {msg_type} to {addr}")
message = self._serialize_message(msg_type, payload)
try:
@@ -148,19 +153,19 @@ class SplitwiseConnector:
sock = self._get_push_socket(addr)
sock.send_multipart([b"", message])
-logger.info(f"Sent {msg_type} to {addr}")
self.logger.info(f"Sent {msg_type} to {addr}")
except ConnectionError:
-logger.warning(f"Connection to {addr} not established")
self.logger.warning(f"Connection to {addr} not established")
except zmq.Again:
-logger.warning(f"Send queue full for {addr}")
self.logger.warning(f"Send queue full for {addr}")
except Exception as e:
-logger.error(f"Send to {addr} failed: {e}, {str(traceback.format_exc())}")
self.logger.error(f"Send to {addr} failed: {e}, {str(traceback.format_exc())}")
main_process_metrics.send_cache_failed_num.inc()
self._close_connection(addr)
except Exception as e:
-logger.error(f"Message preparation failed: {e}")
self.logger.error(f"Message preparation failed: {e}")
def _close_connection(self, addr):
"""
@@ -265,7 +270,7 @@ class SplitwiseConnector:
f"{task.disaggregate_info['cache_info']['rdma']['ip']}:" f"{task.disaggregate_info['cache_info']['rdma']['ip']}:"
+ f"{task.disaggregate_info['cache_info']['rdma']['port']}" + f"{task.disaggregate_info['cache_info']['rdma']['port']}"
) )
logger.info(f"send splitwise tasks to port {addr} decode") self.logger.info(f"send splitwise tasks to port {addr} decode")
self.current_request_ids[task.request_id] = "init" self.current_request_ids[task.request_id] = "init"
decode_diagg = task.disaggregate_info["cache_info"] decode_diagg = task.disaggregate_info["cache_info"]
task.disaggregate_info["cache_info"] = self.cfg.disaggregate_info["cache_info"] task.disaggregate_info["cache_info"] = self.cfg.disaggregate_info["cache_info"]
@@ -295,7 +300,7 @@ class SplitwiseConnector:
self.connect_innode_instances[port].put_disaggregated_tasks(("decode", tasks)) self.connect_innode_instances[port].put_disaggregated_tasks(("decode", tasks))
for task in tasks: for task in tasks:
task.disaggregate_info["cache_info"]["ipc"]["port"] = port task.disaggregate_info["cache_info"]["ipc"]["port"] = port
logger.info(f"send splitwise tasks to port {port} decode") self.logger.info(f"send splitwise tasks to port {port} decode")
current_port = port current_port = port
return current_port return current_port
@@ -305,7 +310,7 @@ class SplitwiseConnector:
""" """
if not isinstance(tasks_list, list): if not isinstance(tasks_list, list):
tasks_list = [tasks_list] tasks_list = [tasks_list]
logger.info("send first token to port decode") self.logger.info("send first token to port decode")
if prefill_msg["transfer_protocol"] == "ipc": if prefill_msg["transfer_protocol"] == "ipc":
port = prefill_msg["cache_info"]["ipc"]["port"] port = prefill_msg["cache_info"]["ipc"]["port"]
if port not in self.connect_innode_instances: if port not in self.connect_innode_instances:
@@ -313,7 +318,7 @@ class SplitwiseConnector:
self.connect_innode_instances[port].put_disaggregated_tasks(("decode", tasks_list)) self.connect_innode_instances[port].put_disaggregated_tasks(("decode", tasks_list))
else: else:
node = f"{prefill_msg['cache_info']['rdma']['ip']}:{prefill_msg['cache_info']['rdma']['port']}" node = f"{prefill_msg['cache_info']['rdma']['ip']}:{prefill_msg['cache_info']['rdma']['port']}"
logger.info(f"send first token to port {node} decode") self.logger.info(f"send first token to port {node} decode")
self._send_message(node, "decode", tasks_list) self._send_message(node, "decode", tasks_list)
def create_connection(self, port):
@@ -329,6 +334,26 @@ class SplitwiseConnector:
client_id=0,
)
def check_decode_allocated(self, task):
start_time = time.time()
if task.disaggregate_info is None:
return True, ""
if self.enable_decode_cache_task:
return True, ""
if task.disaggregate_info["role"] != "prefill":
return True, ""
while self.current_request_ids[task.request_id] == "init":
time.sleep(0.001)
if time.time() - start_time > 30:
del self.current_request_ids[task.request_id]
return False, "timeout"
msg = self.current_request_ids[task.request_id]
del self.current_request_ids[task.request_id]
if msg == "finished":
return True, ""
self.logger.error(f"Receive_decode_allocated error: {msg}")
return False, msg
def send_cache_infos(self, tasks, current_id):
"""
Send cache information to specific port.
@@ -345,7 +370,7 @@ class SplitwiseConnector:
for i in range(len(tasks)):
if tasks[i].disaggregate_info is None:
continue
-logger.info(f"{tasks[i].disaggregate_info}")
self.logger.info(f"{tasks[i].disaggregate_info}")
if tasks[i].disaggregate_info["role"] == "decode":
if tasks[i].disaggregate_info["transfer_protocol"] == "ipc":
cache_info = {
@@ -380,11 +405,19 @@ class SplitwiseConnector:
addr = "prefill" addr = "prefill"
if current_id == -1: if current_id == -1:
current_id = tasks[i].disaggregate_info["cache_info"]["ipc"]["current_id"] current_id = tasks[i].disaggregate_info["cache_info"]["ipc"]["current_id"]
-cache_info = {
-"request_id": tasks[i].request_id,
-"src_block_ids": tasks[i].block_tables,
-"current_id": current_id,
-}
if envs.ENABLE_V1_KVCACHE_SCHEDULER:
cache_info = {
"request_id": tasks[i].request_id,
"src_block_ids": tasks[i].block_tables,
"current_id": tasks[i].idx,
"need_prefill_tokens": tasks[i].need_prefill_tokens,
}
else:
cache_info = {
"request_id": tasks[i].request_id,
"src_block_ids": tasks[i].block_tables,
"current_id": current_id,
}
if addr not in temp_cache_info:
temp_cache_info[addr] = []
@@ -396,7 +429,7 @@ class SplitwiseConnector:
else:
if len(temp_cache_info):
for k, v in temp_cache_info.items():
-logger.info(f"{k} {v}")
self.logger.info(f"{k} {v}")
if ":" in str(k):
self._send_message(k, "cache_sync", v)
else:
@@ -427,7 +460,7 @@ class SplitwiseConnector:
""" """
try: try:
msg_type, payload = self._deserialize_message(message) msg_type, payload = self._deserialize_message(message)
logger.info(f"{msg_type}") self.logger.info(f"{msg_type}")
if msg_type == "prefill": if msg_type == "prefill":
self._handle_prefill(payload) self._handle_prefill(payload)
@@ -435,11 +468,16 @@ class SplitwiseConnector:
self._handle_decode(payload)
elif msg_type == "cache_sync":
for task in payload:
-del self.current_request_ids[task["request_id"]]
-self.engine_worker_queue.put_cache_info(payload)
self.logger.info(f"cache_sync task: {task}")
current_status = task.get("error_msg", "finished")
self.current_request_ids[task["request_id"]] = current_status
if self.enable_decode_cache_task:
del self.current_request_ids[task["request_id"]]
if current_status == "finished":
self.engine_worker_queue.put_cache_info(payload)
except Exception as e:
-logger.error(f"Message processing failed: {e}, {str(traceback.format_exc())}")
self.logger.error(f"Message processing failed: {e}, {str(traceback.format_exc())}")
def _handle_prefill(self, tasks):
"""
@@ -462,8 +500,12 @@ class SplitwiseConnector:
index=task["outputs"]["index"], index=task["outputs"]["index"],
send_idx=0, send_idx=0,
token_ids=task["outputs"]["token_ids"], token_ids=task["outputs"]["token_ids"],
draft_token_ids=task["outputs"]["draft_token_ids"],
),
finished=True,
num_cached_tokens=task["num_cached_tokens"],
error_code=task["error_code"],
error_msg=task["error_msg"],
)
)
self.engine_worker_queue.put_disaggregated_tasks(("decode", tasks))

View File

@@ -16,6 +16,7 @@
import argparse
import json
import os
import time
from typing import Tuple
@@ -259,6 +260,7 @@ class PaddleDisWorkerProc:
"""Main event loop for Paddle Distributed Workers. """Main event loop for Paddle Distributed Workers.
TODO(gongshaotian): support remote calling of functions that control worker. TODO(gongshaotian): support remote calling of functions that control worker.
""" """
# Currently, only support single node # Currently, only support single node
self.nnode = int((self.parallel_config.tensor_parallel_size + 7) // 8)
req_ids = []
@@ -643,6 +645,12 @@ def parse_args():
help="Flag to specify dtype of lm_head as FP32", help="Flag to specify dtype of lm_head as FP32",
) )
parser.add_argument(
"--cache-transfer-protocol",
type=str,
default="ipc",
help="support protocol list, comma separated, default is ipc",
)
parser.add_argument(
"--runner",
type=str,
@@ -762,8 +770,7 @@ def initialize_fd_config(args, ranks: int = 1, local_rank: int = 0) -> FDConfig:
):
logger.info("Set ENABLE_V1_KVCACHE_SCHEDULER to 0 due to not support speculative decoding now.")
envs.ENABLE_V1_KVCACHE_SCHEDULER = 0
-if args.splitwise_role != "mixed":
if args.splitwise_role != "mixed" and args.cache_transfer_protocol != "rdma":
-logger.info(f"Set ENABLE_V1_KVCACHE_SCHEDULER to 0 due to not supported {args.splitwise_role} now.")
envs.ENABLE_V1_KVCACHE_SCHEDULER = 0
if not current_platform.is_cuda():
logger.info("Set ENABLE_V1_KVCACHE_SCHEDULER to 0 due to not supported.")
@@ -772,6 +779,9 @@ def initialize_fd_config(args, ranks: int = 1, local_rank: int = 0) -> FDConfig:
logger.info("Set ENABLE_V1_KVCACHE_SCHEDULER to 0 due to not supported guided_decoding.") logger.info("Set ENABLE_V1_KVCACHE_SCHEDULER to 0 due to not supported guided_decoding.")
envs.ENABLE_V1_KVCACHE_SCHEDULER = 0 envs.ENABLE_V1_KVCACHE_SCHEDULER = 0
if envs.ENABLE_V1_KVCACHE_SCHEDULER and args.splitwise_role == "prefill":
os.environ["PREFILL_NODE_ONE_STEP_STOP_V1"] = "1"
fd_config = FDConfig(
model_config=model_config,
parallel_config=parallel_config,