[Feature] Support pd ep deployment with yiyan adapter (#4029)

* [Feature] Support mixed deployment with yiyan adapter in release2.2

* fix metrics

* add unit test

* add unit test

* add unit test

* Support pd ep deployment with yiyan adapter

* Support pd ep deployment with yiyan adapter

* refactor cache messager

* support scheduler v1 in PD

* support pd v1 + chunk prefill

* support pd v1 + chunk prefill

* add eplb

* support eplb

* support eplb

* support eplb

* support v1

* fix

* fix

* fix bug

* remove eplb support

* support prefix cache in P

* fix bug

* fix bug

* support one stop in V1

* fix bug

* fix ci

* fix ci

* fix

* fix

* fix

* fix

* fix

---------

Co-authored-by: YuBaoku <49938469+EmmonsCurse@users.noreply.github.com>
chenjian
2025-09-22 16:41:38 +08:00
committed by GitHub
parent 9845f0d010
commit 918ccdb123
22 changed files with 1838 additions and 343 deletions

View File

@@ -32,7 +32,8 @@ __global__ void update_inputs_kernel_v1(bool *not_need_stop,
const int max_bsz,
const int input_ids_stride,
const int block_num_per_seq,
const int block_size) {
const int block_size,
bool prefill_one_step_stop) {
int thread_idx = threadIdx.x;
typedef cub::BlockReduce<int64_t, THREADBLOCK_SIZE> BlockReduce;
__shared__ typename BlockReduce::TempStorage temp_storage;
@@ -54,6 +55,14 @@ __global__ void update_inputs_kernel_v1(bool *not_need_stop,
seq_lens_encoder[thread_idx] = 0;
} else {
if (seq_lens_this_time[thread_idx] + seq_lens_decoder[thread_idx] >= prompt_lens[thread_idx]) {
if (prefill_one_step_stop) {
// prefill done, stop
stop_flags[thread_idx] = true;
seq_lens_this_time[thread_idx] = 0;
seq_lens_decoder[thread_idx] = 0;
seq_lens_encoder[thread_idx] = 0;
stop_flag_now_int = 1;
} else{
// decoding
seq_lens_decoder[thread_idx] += seq_lens_this_time[thread_idx];
seq_lens_this_time[thread_idx] = 1;
@@ -72,6 +81,7 @@ __global__ void update_inputs_kernel_v1(bool *not_need_stop,
seq_lens_decoder[thread_idx] = 0;
stop_flag_now_int = 1;
}
}
} else
{
stop_flags[thread_idx] = true;
@@ -110,6 +120,12 @@ void UpdateInputesV1(const paddle::Tensor &stop_flags,
#else
auto cu_stream = input_ids.stream();
#endif
bool prefill_one_step_stop = false;
if (const char *env_p = std::getenv("PREFILL_NODE_ONE_STEP_STOP_V1")) {
if (env_p[0] == '1') {
prefill_one_step_stop = true;
}
}
const int max_bsz = stop_flags.shape()[0];
const int now_bsz = seq_lens_this_time.shape()[0];
const int input_ids_stride = input_ids.shape()[1];
@@ -133,7 +149,8 @@ void UpdateInputesV1(const paddle::Tensor &stop_flags,
max_bsz,
input_ids_stride,
block_num_per_seq,
block_size);
block_size,
prefill_one_step_stop);
auto not_need_stop_cpu =
not_need_stop_gpu.copy_to(not_need_stop.place(), false);
bool *not_need_stop_data = const_cast<bool *>(not_need_stop.data<bool>());
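Read together, the hunks above gate a new per-slot behavior behind the PREFILL_NODE_ONE_STEP_STOP_V1 environment variable: on a prefill-only node the slot is stopped as soon as its prompt has been consumed, instead of rolling into decoding. A minimal Python sketch of that per-slot branch, mirroring the kernel arguments shown above (illustrative only, not the shipped CUDA code):

```python
import os

def update_slot(i, stop_flags, seq_lens_this_time, seq_lens_decoder,
                seq_lens_encoder, prompt_lens):
    """Per-slot update mirroring the new branch in update_inputs_kernel_v1 (illustrative)."""
    prefill_one_step_stop = os.getenv("PREFILL_NODE_ONE_STEP_STOP_V1", "0").startswith("1")
    if seq_lens_this_time[i] + seq_lens_decoder[i] >= prompt_lens[i]:
        if prefill_one_step_stop:
            # Prefill finished on a P node: stop the slot instead of starting decode.
            stop_flags[i] = True
            seq_lens_this_time[i] = 0
            seq_lens_decoder[i] = 0
            seq_lens_encoder[i] = 0
        else:
            # Mixed deployment: the slot switches to decoding one token per step.
            seq_lens_decoder[i] += seq_lens_this_time[i]
            seq_lens_this_time[i] = 1
```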

View File

@@ -14,7 +14,10 @@
# limitations under the License.
"""
import argparse
import json
import math
import queue
import threading
import time
import traceback
@@ -23,16 +26,72 @@ import numpy as np
import paddle
from fastdeploy.cache_manager.transfer_factory import IPCCommManager, RDMACommManager
from fastdeploy.config import SpeculativeConfig
from fastdeploy.inter_communicator import (
EngineWorkerQueue,
IPCSignal,
shared_memory_exists,
)
from fastdeploy.utils import get_logger
from fastdeploy.model_executor.ops.gpu import get_output_kv_signal, set_data_ipc
from fastdeploy.utils import envs, get_logger
logger = get_logger("cache_messager", "cache_messager.log")
def parse_args():
"""
从命令行解析参数
"""
parser = argparse.ArgumentParser("Cache Messager")
parser.add_argument(
"--splitwise_role",
type=str,
default="mixed",
help="splitwise role, can be decode, prefill or mixed",
)
parser.add_argument("--rank", type=int, default=0, help="current rank")
parser.add_argument("--device_id", type=int, default=0, help="device id")
parser.add_argument("--num_layers", type=int, default=1, help="model num layers")
parser.add_argument("--head_dim", type=int, default=1, help="model head dim")
parser.add_argument("--kv_num_head", type=int, default=1, help="model kv num head")
parser.add_argument("--rdma_port", type=str, default="", help="rmda port")
parser.add_argument("--mp_num", type=int, default=1, help="number of model parallel")
parser.add_argument("--engine_pid", type=str, default=None, help="engine pid")
parser.add_argument(
"--protocol",
type=str,
default="ipc",
help="cache transfer protocol, only surport ipc now",
)
parser.add_argument("--pod_ip", type=str, default="0.0.0.0", help="pod ip")
parser.add_argument("--cache_queue_port", type=int, default=9924, help="cache queue port")
parser.add_argument(
"--engine_worker_queue_port",
type=int,
default=9923,
help="engine worker queue port",
)
parser.add_argument("--num_gpu_blocks", type=int, default=1, help="gpu cache block number")
parser.add_argument("--block_size", type=int, default=64, help="cache block size(tokens)")
parser.add_argument(
"--cache_dtype",
type=str,
default="bfloat16",
choices=["uint8", "bfloat16"],
help="cache dtype",
)
parser.add_argument(
"--speculative_config",
type=json.loads,
default="{}",
help="speculative config",
)
parser.add_argument("--local_data_parallel_id", type=int, default=0)
args = parser.parse_args()
return args
class CacheMessager:
"""
CacheMessager is used to send the cache data between the engine worker and the cache server.
@@ -69,11 +128,6 @@ class CacheMessager:
Returns:
None
"""
assert splitwise_role in [
"prefill",
"decode",
], "splitwise_role must be prefill or decode"
self.splitwise_role = splitwise_role
self.gpu_cache_kvs = gpu_cache_kvs
self.rank = rank
@@ -147,15 +201,16 @@ class CacheMessager:
self.gpu_id = gpu_id
self.cache_info = dict()
self.dp_rank_id = self.rank + local_data_parallel_id * self.nranks
self.rank_id = self.rank + local_data_parallel_id * self.nranks
layerwise_send_cache_thread = threading.Thread(target=self._prefill_layerwise_send_cache_thread)
layerwise_send_cache_thread.daemon = True
layerwise_send_cache_thread.start()
if self.splitwise_role != "mixed":
connect_rdma_thread = threading.Thread(target=self._handle_connect_task)
connect_rdma_thread.daemon = True
connect_rdma_thread.start()
logger.info(f"cache messager init finished, use {transfer_protocol}")
def _prefill_layerwise_send_cache_thread(self):
def prefill_layerwise_send_cache_thread(self):
"""
layerwise_send_cache_thread:
send cache to other instance
@@ -163,23 +218,23 @@ class CacheMessager:
try:
prefilled_step_idx_data = np.zeros(shape=[1], dtype=np.int32)
prefilled_layer_idx_data = np.zeros(shape=[1], dtype=np.int32)
prefilled_layer_name = f"splitwise_complete_prefilled_layer_{self.dp_rank_id}.{self.gpu_id}"
prefilled_step_name = f"splitwise_complete_prefilled_step_{self.dp_rank_id}.{self.gpu_id}"
prefilled_layer_name = f"splitwise_complete_prefilled_step_{self.rank_id}.{self.gpu_id}"
prefilled_step_name = f"splitwise_complete_prefilled_step_{self.rank_id}.{self.gpu_id}"
step_shm_value = IPCSignal(
name=f"splitwise_complete_prefilled_step_{self.dp_rank_id}",
name=f"splitwise_complete_prefilled_step_{self.rank_id}",
array=prefilled_step_idx_data,
dtype=np.int32,
suffix=self.gpu_id,
create=not shared_memory_exists(prefilled_step_name),
)
layer_shm_value = IPCSignal(
name=f"splitwise_complete_prefilled_layer_{self.dp_rank_id}",
name=f"splitwise_complete_prefilled_layer_{self.rank_id}",
array=prefilled_layer_idx_data,
dtype=np.int32,
suffix=self.gpu_id,
create=not shared_memory_exists(prefilled_layer_name),
)
logger.info(f"splitwise_complete_prefilled_step_{self.dp_rank_id}, gpu_id: {self.gpu_id}")
logger.info(f"splitwise_complete_prefilled_step_{self.rank_id}, gpu_id: {self.gpu_id}")
step_shm_value.value[0] = -1
layer_shm_value.value[0] = -1
@@ -187,6 +242,9 @@ class CacheMessager:
self.last_step_idx = -1
self.last_layer_idx = -1 # int32
max_step_idx = 100003
engine_recycled_count = 0
while True:
cache_info = self.engine_worker_queue.get_cache_info()
@@ -202,11 +260,9 @@ class CacheMessager:
-len(current_info["dest_block_ids"]) :
]
current_info["src_block_ids"] = current_src_blocks
current_info["current_layer_ids"] = 0
current_info["status"] = "init"
logger.info(f"start cache_infos: {current_info}")
self.cache_info[info["request_id"]] = current_info
self.last_step_idx = min(self.last_step_idx, current_info["current_id"])
else:
self.cache_info[info["request_id"]] = info
prefilled_layer_idx = layer_shm_value.value[0]
@@ -223,7 +279,18 @@ class CacheMessager:
if not self.cache_info:
time.sleep(0.001)
continue
logger.debug(f"prefilled_layer_idx: {prefilled_layer_idx}, prefilled_step_idx: {prefilled_step_idx}")
if self.last_step_idx > prefilled_step_idx:
engine_recycled_count += 1
self.last_step_idx = prefilled_step_idx # only copy value read from shm memory
prefilled_step_idx = (
prefilled_step_idx + max_step_idx * engine_recycled_count
) # remap prefilled_step_idx for comparison
logger.debug(
f"prefilled_layer_idx: {prefilled_layer_idx}, prefilled_step_idx in shm: {self.last_step_idx},"
f"prefilled_step_idx: {prefilled_step_idx} engine_recycled_count {engine_recycled_count}"
)
for req_id, item in list(self.cache_info.items()):
if "status" not in item:
continue
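The step counter published through shared memory wraps at max_step_idx (100003), so the reader remaps it to a monotonically increasing value by counting how many times the engine has recycled it. A standalone sketch of that remapping, assuming the same wrap constant:

```python
MAX_STEP_IDX = 100003  # wrap point of the engine-side step counter

class StepRemapper:
    """Turn the wrapping step counter read from shared memory into a monotonic value."""

    def __init__(self):
        self.last_raw = -1      # last raw value read from shared memory
        self.recycled = 0       # how many times the counter has wrapped

    def remap(self, raw_step_idx: int) -> int:
        if self.last_raw > raw_step_idx:   # counter wrapped since the previous read
            self.recycled += 1
        self.last_raw = raw_step_idx
        return raw_step_idx + MAX_STEP_IDX * self.recycled

r = StepRemapper()
print(r.remap(100002), r.remap(1))   # 100002 100004: still increasing after the wrap
```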
@@ -294,12 +361,493 @@ class CacheMessager:
logger.info(f"finish write cache {item['request_id']}")
self.engine_worker_queue.finish_request_barrier.wait()
if self.rank == 0:
# TODO: make this robust under TP; here we assume all TP ranks share the same status (all fail or all succeed).
self.engine_worker_queue.put_finished_req([(item["request_id"], "finished")])
logger.info(f"put write cache {item['request_id']}")
del self.cache_info[req_id]
self.last_step_idx = prefilled_step_idx
self.last_layer_idx = prefilled_layer_idx
except Exception as e:
logger.error(f"prefill layerwise send cache thread has exception: {e}, {str(traceback.format_exc())}")
def _handle_connect_task(self):
while True:
try:
task = self.engine_worker_queue.get_connect_rdma_task()
if task is None:
time.sleep(0.001)
continue
logger.info(f"_handle_connect_task recv task: {task}")
task_id = task["task_id"]
ip, rdma_port = task["ip"], task["rdma_port"]
status = self.messager["rdma"].connect(ip, rdma_port)
if not status:
response = {"task_id": task_id, "success": False}
else:
response = {"task_id": task_id, "success": True}
self.engine_worker_queue.put_connect_rdma_task_response(response)
except Exception as e:
logger.error(f"handle_connect_task has exception: {e}")
class CacheMessagerV1:
"""
CacheMessager is used to send the cache data between the engine worker and the cache server.
"""
def __init__(
self,
splitwise_role,
transfer_protocol,
pod_ip,
engine_worker_queue_port,
local_data_parallel_id,
gpu_cache_kvs,
rank,
nranks,
num_layers,
gpu_id=0,
block_size=64,
rdma_port=None,
):
"""
Initialize the CacheMessager object.
Args:
splitwise_role (str): splitwise_role only can be 'prefill' or 'decode'.
transfer_protocol (str): support ipc and rdma
engine_worker_queue_port (int): engine_worker_queue port
gpu_cache_kvs (dict): GPU kv cache
rank (int): current rank
nranks (int): global rank number
num_layers (int): model layer number
gpu_id (int, optional): GPU ID
rdma_port (int, optional): RDMA port
Returns:
None
"""
self.splitwise_role = splitwise_role
self.gpu_cache_kvs = gpu_cache_kvs
self.rank = rank
self.nranks = nranks
address = (pod_ip, engine_worker_queue_port)
self.engine_worker_queue = EngineWorkerQueue(
address=address,
is_server=False,
num_client=self.nranks,
client_id=self.rank,
local_data_parallel_id=local_data_parallel_id,
)
self.block_size = block_size
transfer_protocol = transfer_protocol.split(",")
logger.info(f"splitwise role: {splitwise_role}, {transfer_protocol}" f"rank: {rank}")
# 1. initialize the cache_k_ptr_list and cache_v_ptr_list
self.num_layers = num_layers
cache_k_ptr_list = []
cache_v_ptr_list = []
cache_k = []
cache_v = []
self.messager = {}
for layer_idx in range(self.num_layers):
key_cache = self.gpu_cache_kvs[f"key_caches_{layer_idx}_rank{self.rank}_device{gpu_id}"]
val_cache = self.gpu_cache_kvs[f"value_caches_{layer_idx}_rank{self.rank}_device{gpu_id}"]
cache_k.append(key_cache)
cache_v.append(val_cache)
cache_k_ptr_list.append(key_cache.data_ptr())
cache_v_ptr_list.append(val_cache.data_ptr())
cache_k_ptr_list = np.array(cache_k_ptr_list)
cache_v_ptr_list = np.array(cache_v_ptr_list)
# 2. initialize the block_bytes
cache_shape = key_cache.shape
max_block_num = cache_shape[0]
block_bytes = math.prod(cache_shape[1:])
if key_cache.dtype == paddle.bfloat16:
block_bytes *= 2
logger.info(
f"layers {num_layers} cache_shape: {cache_shape}, max_block_num: {max_block_num}, "
f"block_bytes: {block_bytes}, dtype: {key_cache.dtype}"
)
self.block_bytes = block_bytes
# 3. initialize the messager
for protocol in transfer_protocol:
if protocol == "ipc":
self.messager[protocol] = IPCCommManager(
self.rank,
gpu_id,
cache_k,
cache_v,
)
local_device_id = int(str(cache_k[0].place)[-2])
logger.info(f"done create ipc_comm with local_device_id:{local_device_id}, ")
elif protocol == "rdma":
logger.info(f"splitwise_role rdma: {self.splitwise_role}, rank: {self.rank}, gpu_id: {gpu_id}")
self.messager[protocol] = RDMACommManager(
splitwise_role,
rank,
gpu_id,
cache_k_ptr_list,
cache_v_ptr_list,
max_block_num,
block_bytes,
rdma_port,
)
self.gpu_id = gpu_id
self.cache_info = dict()
self.rank_id = self.rank + local_data_parallel_id * self.nranks
self.engine_cache_task_thread_lock = threading.Lock()
self.engine_cache_tasks = [dict() for _ in range(512)]
self.idx_cache_task_dict = {}
self.cache_prefilled_engine_ids_queue = queue.Queue() # keep batch slot index for each prefill step
if splitwise_role == "prefill":
consume_signals_thread = threading.Thread(target=self.consume_signals)
consume_signals_thread.daemon = True
consume_signals_thread.start()
add_cache_task_thread = threading.Thread(target=self._add_cache_task_thread)
add_cache_task_thread.daemon = True
add_cache_task_thread.start()
if self.splitwise_role != "mixed":
connect_rdma_thread = threading.Thread(target=self._handle_connect_task)
connect_rdma_thread.daemon = True
connect_rdma_thread.start()
logger.info(f"cache messager init finished, use {transfer_protocol}")
def _add_cache_task_thread(self):
while True:
try:
cache_info = self.engine_worker_queue.get_cache_info()
self.engine_worker_queue.finish_add_cache_task_barrier.wait()
finished_add_cache_task_req_ids = []
if cache_info:
for info in cache_info:
if info["request_id"] in self.cache_info:
self.cache_info[info["request_id"]].update(info)
current_info = self.cache_info[info["request_id"]]
assert "dest_block_ids" in current_info and "src_block_ids" in current_info
finished_add_cache_task_req_ids.append(info["request_id"])
decode_cached_block_num = len(current_info["src_block_ids"]) - len(
current_info["dest_block_ids"]
)
padding_decode_block_ids = [-1 for i in range(decode_cached_block_num)] + current_info[
"dest_block_ids"
]
current_info["dest_block_ids"] = padding_decode_block_ids
current_info["decode_cached_tokens"] = decode_cached_block_num * self.block_size
current_info["sended_layer_id"] = -1
current_info["sended_block_num"] = current_info["decode_cached_tokens"] // self.block_size
current_info["status"] = "init"
logger.info(f"finish add cache task: {current_info}")
self.cache_info[info["request_id"]] = current_info
self.idx_cache_task_dict[current_info["current_id"]] = current_info
else:
self.cache_info[info["request_id"]] = info
if self.rank == 0 and finished_add_cache_task_req_ids:
self.engine_worker_queue.put_finished_add_cache_task_req(finished_add_cache_task_req_ids)
else:
time.sleep(0.001)
except Exception as e:
logger.info(f"add cache task occured error: {e}, {traceback.format_exc()!s}.")
def prefill_layerwise_send_cache_thread(self):
"""
layerwise_send_cache_thread:
send cache to other instance
"""
while True:
try:
engine_indexes = self.cache_prefilled_engine_ids_queue.get()
self.engine_worker_queue.finish_request_barrier.wait()
block_start_end_list = []
current_prefilled_token_num_list = []
for engine_index in engine_indexes:
assert engine_index in self.idx_cache_task_dict
block_id_start = self.idx_cache_task_dict[engine_index]["sended_block_num"]
prefilled_token_num = self.engine_cache_tasks[engine_index]["prefilled_token_num"]
if (
prefilled_token_num == self.idx_cache_task_dict[engine_index]["need_prefill_tokens"]
): # all chunks have been prefilled
block_id_end = len(self.idx_cache_task_dict[engine_index]["src_block_ids"])
else:
block_id_end = prefilled_token_num // self.block_size # [block_id_start, block_id_end)
block_start_end_list.append((block_id_start, block_id_end))
current_prefilled_token_num_list.append(prefilled_token_num)
while True: # from layer0 to last layer
sended_layer_idx = self.idx_cache_task_dict[engine_indexes[0]]["sended_layer_id"]
start_layer_idx = sended_layer_idx + 1
with self.engine_cache_task_thread_lock: # to check end_layer_idx
prefilled_layer_idx = self.engine_cache_tasks[engine_indexes[0]]["prefilled_layer_idx"]
if sended_layer_idx > prefilled_layer_idx: # computation must be in the next chunk
logger.info(
f"current_prefilled_token_num_list[0] {current_prefilled_token_num_list[0]} prefilled_token_num {self.engine_cache_tasks[engine_indexes[0]]['prefilled_token_num']}"
)
assert (
current_prefilled_token_num_list[0]
< self.engine_cache_tasks[engine_indexes[0]]["prefilled_token_num"]
), "when sended_layer_idx > prefilled_layer_idx, must be in next chunk, but not, sth wrong"
end_layer_idx = self.num_layers - 1 # [start_layer_idx, end_layer_idx]
else:
end_layer_idx = prefilled_layer_idx
if sended_layer_idx == prefilled_layer_idx: # computation has not reached the next layer yet
time.sleep(0.01)
for layer_idx in range(start_layer_idx, end_layer_idx + 1):
for i, (block_id_start, block_id_end) in enumerate(block_start_end_list):
engine_index = engine_indexes[i]
task = self.idx_cache_task_dict[engine_index]
req_id = task["request_id"]
if (
block_id_start >= block_id_end
): # no blocks need to be transferred for this request in this chunk
task["sended_layer_id"] += 1
assert task["sended_layer_id"] == layer_idx
if task["sended_layer_id"] == self.num_layers - 1:
task["sended_layer_id"] = -1
continue
else:
current_transfer_protocol = task["transfer_protocol"]
if task["transfer_protocol"] == "rdma":
target_ip = task["ip"]
target_id = int(task["rdma_ports"][self.rank])
if task["status"] == "error":
continue
status = self.messager[current_transfer_protocol].connect(target_ip, target_id)
if not status:
logger.error(f"connect to {target_ip}:{target_id} failed")
task["status"] = "connection error"
continue
elif task["transfer_protocol"] == "ipc":
target_ip = "0.0.0.0"
target_id = int(task["device_ids"][self.rank])
src_block_ids = task["src_block_ids"][block_id_start:block_id_end]
dest_block_ids = task["dest_block_ids"][block_id_start:block_id_end]
src_block_ids = paddle.to_tensor(src_block_ids, dtype="int32", place="cpu")
dest_block_ids = paddle.to_tensor(dest_block_ids, dtype="int32", place="cpu")
logger.info(
f"start write cache for a layer, {req_id}, {layer_idx}, {target_ip}, {target_id}, block_id_start {block_id_start} block_id_end {block_id_end}"
)
tic = time.time()
return_code = self.messager[current_transfer_protocol].write_cache(
target_ip,
target_id,
src_block_ids,
dest_block_ids,
layer_idx,
)
if return_code != 0:
task["status"] = "write cache error"
logger.error(
f"write cache failed, layer_idx: {layer_idx}, req_id: {req_id}, dest_ip: {target_ip}, block_id_start {block_id_start} block_id_end {block_id_end}"
)
tok = time.time()
cost_time = tok - tic
block_num = len(src_block_ids)
avg_time_per_block = cost_time * 1000 / block_num # ms
send_cache_speed = block_num * self.block_bytes / 1073741824 / cost_time # GB/s
logger.debug(
f"finish write cache for a layer, {req_id}, {layer_idx}, {target_ip}, {target_id},"
f"block_num: {block_num}, send_cache_speed(GB/s): {round(send_cache_speed, 5)},"
f"avg_time per block(ms): {round(avg_time_per_block, 5)} block_id_start {block_id_start} block_id_end {block_id_end}"
)
task["sended_layer_id"] += 1
assert task["sended_layer_id"] == layer_idx
if task["sended_layer_id"] == self.num_layers - 1:
self.idx_cache_task_dict[engine_index]["sended_block_num"] += (
block_id_end - block_id_start
)
if current_prefilled_token_num_list[i] == task["need_prefill_tokens"]:
if task["status"] != "error":
task["status"] = "finished"
logger.info(
f"finish write cache for all layers, req_id: {req_id}, block_id_end {block_id_end} need_prefill_tokens {task['need_prefill_tokens']}"
)
else:
task["sended_layer_id"] = -1
if end_layer_idx == self.num_layers - 1:
with self.engine_cache_task_thread_lock:
for engine_idx in engine_indexes:
task = self.idx_cache_task_dict[engine_idx]
if task["status"] == "finished" or ("error" in task["status"]):
target_id = int(task["rdma_ports"][self.rank])
if task["transfer_protocol"] == "ipc":
self.messager["ipc"].write_block_by_sync(target_id)
if self.rank == 0:
# TODO: make this robust under TP; here we assume all TP ranks share the same status (all fail or all succeed).
self.engine_worker_queue.put_finished_req(
[(task["request_id"], task["status"])]
)
logger.info(f"put write cache {task['request_id']}, status {task['status']}")
self.engine_cache_tasks[task["current_id"]] = dict()
del self.cache_info[task["request_id"]]
del self.idx_cache_task_dict[task["current_id"]]
break
except Exception as e:
logger.error(f"prefill layerwise send cache thread has exception: {e} {traceback.format_exc()!s}")
time.sleep(0.01)
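The send loop above derives, for each chunk, a window of layers to transfer from two counters: the last layer already sent (sended_layer_id) and the last layer the engine reports as prefilled (prefilled_layer_idx). A condensed sketch of that window computation, under the same three cases as the code:

```python
def layer_window(sended_layer_idx: int, prefilled_layer_idx: int, num_layers: int):
    """Return the inclusive (start, end) layer range to send now, or None to wait."""
    start = sended_layer_idx + 1
    if sended_layer_idx > prefilled_layer_idx:
        # The engine already moved on to the next chunk: flush all remaining layers.
        return start, num_layers - 1
    if sended_layer_idx == prefilled_layer_idx:
        # No new layer has been prefilled yet: the caller sleeps briefly and retries.
        return None
    # Normal case: send everything the engine has prefilled so far in this chunk.
    return start, prefilled_layer_idx

print(layer_window(-1, 5, 32))   # (0, 5)
print(layer_window(30, 2, 32))   # (31, 31) -> next chunk already running, flush the tail
```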
def consume_signals(self):
paddle.device.set_device("cpu")
kv_signal_data = paddle.full(shape=[512 * 3 + 2], fill_value=-1, dtype="int32")
while True:
try:
get_output_kv_signal(kv_signal_data, self.rank_id, 0) # wait_flag
if not self.cache_info:
time.sleep(0.01)
continue
tasks_count = kv_signal_data[0]
if tasks_count == -1:
time.sleep(0.001)
continue
layer_id = kv_signal_data[1].numpy().tolist()
if layer_id == self.num_layers - 1:
logger.info(f"tasks_count: {tasks_count}, layer_id: {layer_id}")
batch_engine_ids = []
with self.engine_cache_task_thread_lock:
for bi in range(tasks_count):
engine_idx = kv_signal_data[3 * bi + 2].numpy().tolist()
chuck_token_offset = kv_signal_data[3 * bi + 3].numpy().tolist()
current_seq_len = kv_signal_data[3 * bi + 4].numpy().tolist()
self.engine_cache_tasks[engine_idx]["prefilled_layer_idx"] = layer_id
self.engine_cache_tasks[engine_idx]["prefilled_token_num"] = (
chuck_token_offset + current_seq_len
)
batch_engine_ids.append(engine_idx)
if layer_id == 0:
self.cache_prefilled_engine_ids_queue.put(batch_engine_ids)
except Exception as e:
logger.error(f"Consume signals get exception: {e}")
def _handle_connect_task(self):
while True:
try:
task = self.engine_worker_queue.get_connect_rdma_task()
if task is None:
time.sleep(0.001)
continue
logger.info(f"_handle_connect_task recv task: {task}")
task_id = task["task_id"]
ip, rdma_port = task["ip"], task["rdma_port"]
status = self.messager["rdma"].connect(ip, rdma_port)
if not status:
response = {"task_id": task_id, "success": False}
else:
response = {"task_id": task_id, "success": True}
self.engine_worker_queue.put_connect_rdma_task_response(response)
except Exception as e:
logger.error(f"handle_connect_task has exception: {e}")
def main():
device = args.device_id
rank = args.rank
paddle.set_device(f"gpu:{device}")
cache_type = args.cache_dtype
speculative_config = SpeculativeConfig(args.speculative_config)
num_extra_layers = speculative_config.num_extra_cache_layer
num_extra_layer_gpu_blocks = int(args.num_gpu_blocks * speculative_config.num_gpu_block_expand_ratio)
gpu_cache_kvs = {}
gpu_cache_k_tensors = []
gpu_cache_v_tensors = []
for i in range(args.num_layers + num_extra_layers):
num_gpu_blocks = args.num_gpu_blocks if i < args.num_layers else num_extra_layer_gpu_blocks
gpu_cache_kvs[f"key_caches_{i}_rank{rank}_device{device}"] = paddle.full(
shape=[
num_gpu_blocks,
args.kv_num_head,
args.block_size,
args.head_dim,
],
fill_value=0,
dtype=cache_type,
)
gpu_cache_k_tensors.append(gpu_cache_kvs[f"key_caches_{i}_rank{rank}_device{device}"])
gpu_cache_kvs[f"value_caches_{i}_rank{rank}_device{device}"] = paddle.full(
shape=[
num_gpu_blocks,
args.kv_num_head,
args.block_size,
args.head_dim,
],
fill_value=0,
dtype=cache_type,
)
gpu_cache_v_tensors.append(gpu_cache_kvs[f"value_caches_{i}_rank{rank}_device{device}"])
set_data_ipc(
gpu_cache_kvs[f"key_caches_{i}_rank{rank}_device{device}"],
f"key_caches_{i}_rank{rank}.device{device}",
)
set_data_ipc(
gpu_cache_kvs[f"value_caches_{i}_rank{rank}_device{device}"],
f"value_caches_{i}_rank{rank}.device{device}",
)
cache_kv_size_byte = sum([tmp.numel() * 1 for key, tmp in gpu_cache_kvs.items()])
logger.info(f"device :{device}")
logger.info(f"cache_kv_size_byte : {cache_kv_size_byte}")
logger.info(f"done init cache (full) gmem alloc : {paddle.device.cuda.memory_allocated()}")
if envs.ENABLE_V1_KVCACHE_SCHEDULER:
cache_messager = CacheMessagerV1(
splitwise_role=args.splitwise_role,
transfer_protocol=args.protocol,
pod_ip=args.pod_ip,
engine_worker_queue_port=args.engine_worker_queue_port,
local_data_parallel_id=args.local_data_parallel_id,
gpu_cache_kvs=gpu_cache_kvs,
rank=rank,
nranks=args.mp_num,
num_layers=args.num_layers + num_extra_layers,
gpu_id=device,
rdma_port=args.rdma_port,
)
else:
cache_messager = CacheMessager(
splitwise_role=args.splitwise_role,
transfer_protocol=args.protocol,
pod_ip=args.pod_ip,
engine_worker_queue_port=args.engine_worker_queue_port,
local_data_parallel_id=args.local_data_parallel_id,
gpu_cache_kvs=gpu_cache_kvs,
rank=rank,
nranks=args.mp_num,
num_layers=args.num_layers + num_extra_layers,
gpu_id=device,
rdma_port=args.rdma_port,
)
cache_ready_signal_data = np.zeros(shape=[args.mp_num], dtype=np.int32)
cache_ready_signal = IPCSignal(
name="cache_ready_signal",
array=cache_ready_signal_data,
dtype=np.int32,
suffix=args.engine_pid,
create=False,
)
cache_ready_signal.value[rank] = 1
if args.splitwise_role == "mixed":
while True:
time.sleep(1)
cache_messager.prefill_layerwise_send_cache_thread()
if __name__ == "__main__":
args = parse_args()
rank_id = args.rank + args.local_data_parallel_id * args.mp_num
logger = get_logger("cache_messager", f"cache_messager_rank{rank_id}.log")
logger.info("create cache messager...")
logger.info(f"{args}")
main()

View File

@@ -29,7 +29,7 @@ from fastdeploy.config import SpeculativeConfig
from fastdeploy.inter_communicator import EngineCacheQueue, IPCSignal
from fastdeploy.model_executor.ops.gpu import (
cuda_host_alloc,
set_data_ipc,
share_external_data,
swap_cache_all_layers,
)
from fastdeploy.utils import get_logger
@@ -139,40 +139,27 @@ class CacheTransferManager:
self.num_cpu_blocks = args.num_cpu_blocks
cache_type = args.cache_dtype
cache_shape = [
args.num_gpu_blocks,
args.kv_num_head,
args.block_size,
args.head_dim,
]
for i in range(args.num_layers + self.num_extra_layers):
num_gpu_blocks = args.num_gpu_blocks if i < args.num_layers else self.num_extra_layer_gpu_blocks
cache_shape[0] = num_gpu_blocks
key_name = f"key_caches_{i}_rank{rank}.device{device}"
value_name = f"value_caches_{i}_rank{rank}.device{device}"
key_cache = paddle.empty(shape=[], dtype=cache_type)
value_cache = paddle.empty(shape=[], dtype=cache_type)
key_cache = share_external_data(key_cache, key_name, cache_shape)
value_cache = share_external_data(value_cache, value_name, cache_shape)
self.gpu_cache_kvs[key_name] = key_cache
self.gpu_cache_kvs[value_name] = value_cache
self.gpu_cache_k_tensors.append(self.gpu_cache_kvs[key_name])
self.gpu_cache_v_tensors.append(self.gpu_cache_kvs[value_name])
self.gpu_cache_kvs[f"key_caches_{i}_rank{rank}_device{device}"] = paddle.full(
shape=[
num_gpu_blocks,
args.kv_num_head,
args.block_size,
args.head_dim,
],
fill_value=0,
dtype=cache_type,
)
self.gpu_cache_k_tensors.append(self.gpu_cache_kvs[f"key_caches_{i}_rank{rank}_device{device}"])
self.gpu_cache_kvs[f"value_caches_{i}_rank{rank}_device{device}"] = paddle.full(
shape=[
num_gpu_blocks,
args.kv_num_head,
args.block_size,
args.head_dim,
],
fill_value=0,
dtype=cache_type,
)
self.gpu_cache_v_tensors.append(self.gpu_cache_kvs[f"value_caches_{i}_rank{rank}_device{device}"])
set_data_ipc(
self.gpu_cache_kvs[f"key_caches_{i}_rank{rank}_device{device}"],
f"key_caches_{i}_rank{rank}.device{device}",
)
set_data_ipc(
self.gpu_cache_kvs[f"value_caches_{i}_rank{rank}_device{device}"],
f"value_caches_{i}_rank{rank}.device{device}",
)
cache_kv_size_byte = sum([tmp.numel() * 1 for key, tmp in self.gpu_cache_kvs.items()])
logger.info(f"device :{self.device}")
logger.info(f"cache_kv_size_byte : {cache_kv_size_byte}")
@@ -201,28 +188,6 @@ class CacheTransferManager:
)
self.cache_ready_signal.value[self.rank] = 1
paddle.set_device(f"gpu:{device}")
if args.enable_splitwise:
logger.debug("create cache messager...")
logger.info(f"{args}")
from fastdeploy.cache_manager.cache_messager import CacheMessager
self.cache_messager = CacheMessager(
splitwise_role=args.splitwise_role,
transfer_protocol=args.protocol,
pod_ip=args.pod_ip,
engine_worker_queue_port=args.engine_worker_queue_port,
local_data_parallel_id=args.local_data_parallel_id,
gpu_cache_kvs=self.gpu_cache_kvs,
rank=self.rank,
nranks=args.mp_num,
num_layers=args.num_layers + self.num_extra_layers,
gpu_id=self.device,
rdma_port=args.rdma_port,
)
logger.info("successfully create cache messager")
logger.info(f"done init CacheMessager gmem alloc : {paddle.device.cuda.memory_allocated()}")
cache_task_broadcast_data = np.zeros(shape=[1], dtype=np.int32)
self.cache_task_broadcast_signal = IPCSignal(
name="cache_task_broadcast_signal",
@@ -443,5 +408,7 @@ def main():
if __name__ == "__main__":
args = parse_args()
logger = get_logger("cache_transfer_manager", "cache_transfer_manager.log")
rank_id = args.rank + args.local_data_parallel_id * args.mp_num
logger = get_logger("cache_transfer_manager", f"cache_transfer_manager_rank{rank_id}.log")
paddle.set_device(f"gpu:{args.device_id}")
main()
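With this change the KV cache tensors are allocated once in cache_messager.py and published via set_data_ipc, while cache_transfer_manager.py only attaches to them through share_external_data instead of allocating a second copy. A hedged sketch of that producer/consumer pairing, using the call shapes visible in the diff (shown in one script for brevity; in practice the two halves run in separate processes):

```python
import paddle
from fastdeploy.model_executor.ops.gpu import set_data_ipc, share_external_data

# Example shapes only: [num_blocks, kv_num_head, block_size, head_dim].
cache_shape = [1024, 8, 64, 128]
name = "key_caches_0_rank0.device0"

# Producer (cache_messager): owns the allocation and publishes it under a well-known name.
key_cache = paddle.full(shape=cache_shape, fill_value=0, dtype="bfloat16")
set_data_ipc(key_cache, name)

# Consumer (cache_transfer_manager): starts from an empty tensor and maps the shared buffer.
attached = paddle.empty(shape=[], dtype="bfloat16")
attached = share_external_data(attached, name, cache_shape)
```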

View File

@@ -150,6 +150,19 @@ class PrefixCacheManager:
filename = "cache_transfer_manager.py"
py_path = os.path.join(current_dir_path, filename)
cache_messager_processes = []
cache_messager_processes = self.launch_cache_messager(
cache_config,
tensor_parallel_size,
device_ids,
pod_ip,
engine_worker_queue_port,
pid_suffix,
)
if cache_messager_processes is None:
raise RuntimeError("Launch cache messager failed")
return []
if (
hasattr(cache_config.model_cfg, "num_key_value_heads")
and hasattr(cache_config.model_cfg, "num_key_value_heads")
@@ -213,7 +226,76 @@ class PrefixCacheManager:
if cache_config.enable_hierarchical_cache and self.num_cpu_blocks > 0:
logger.info("Enable hierarchical cache.")
self._enable_cpu_cache()
return cache_manager_processes
all_cache_processes = cache_messager_processes + cache_manager_processes
return all_cache_processes
def launch_cache_messager(
self, cache_config, tensor_parallel_size, device_ids, pod_ip, engine_worker_queue_port, pid_suffix
):
"""
launch_cache_messager function used to initialize the cache messager.
"""
current_dir_path = os.path.split(os.path.abspath(__file__))[0]
filename = "cache_messager.py"
if (
hasattr(cache_config.model_cfg, "num_key_value_heads")
and hasattr(cache_config.model_cfg, "num_key_value_heads")
and cache_config.model_cfg.num_key_value_heads is not None
and int(cache_config.model_cfg.num_key_value_heads) > 0
):
kv_num_head = int(cache_config.model_cfg.num_key_value_heads) // tensor_parallel_size
else:
kv_num_head = cache_config.model_cfg.num_attention_heads // tensor_parallel_size
cache_ready_signal_data = np.zeros(shape=[tensor_parallel_size], dtype=np.int32)
self.cache_ready_signal = IPCSignal(
name="cache_ready_signal",
array=cache_ready_signal_data,
dtype=np.int32,
suffix=pid_suffix,
create=True,
)
py_path = os.path.join(current_dir_path, filename)
log_dir = envs.FD_LOG_DIR
cache_messager_processes = []
for i in range(tensor_parallel_size):
launch_cmd = (
"FLAGS_allocator_strategy=auto_growth CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7"
+ " NCCL_MAX_NCHANNELS=1 NCCL_BUFFSIZE=0"
+ f" {sys.executable} {py_path}"
+ f" --device_id {int(device_ids[i])}"
+ f" --rank {i}"
+ f" --splitwise_role {self.splitwise_role}"
+ f" --num_layers {cache_config.model_cfg.num_hidden_layers}"
+ f" --head_dim {cache_config.model_cfg.head_dim}"
+ f" --kv_num_head {kv_num_head}"
+ f" --mp_num {tensor_parallel_size}"
+ f" --cache_dtype {cache_config.cache_dtype}"
+ f" --pod_ip {pod_ip}"
+ f" --cache_queue_port {cache_config.cache_queue_port}"
+ f" --engine_worker_queue_port {engine_worker_queue_port}"
+ f" --num_gpu_blocks {cache_config.total_block_num}"
+ f" --block_size {cache_config.block_size}"
+ f" --protocol {cache_config.cache_transfer_protocol}"
+ f" --local_data_parallel_id {self.local_data_parallel_id}"
+ f" --engine_pid {pid_suffix}"
+ f" --rdma_port {cache_config.rdma_comm_ports[i] if cache_config.rdma_comm_ports is not None else '0'}"
+ f" --speculative_config '{self.speculative_config.to_json_string()}'"
+ f" >{log_dir}/launch_cache_messager_{int(device_ids[i])}.log 2>&1"
)
logger.info(f"Launch cache messager, command:{launch_cmd}")
cache_messager_processes.append(subprocess.Popen(launch_cmd, shell=True, preexec_fn=os.setsid))
logger.info("Waiting for cache ready...")
while np.sum(self.cache_ready_signal.value) != tensor_parallel_size:
time.sleep(1)
exit_code = cache_messager_processes[-1].poll()
if exit_code is None:
logger.info("Launch cache messager successful")
else:
logger.info("Launch cache messager failed, see launch_cache_messager.log for more information")
cache_messager_processes = None
return cache_messager_processes
def update_cache_config(self, cache_config):
"""

View File

@@ -61,18 +61,12 @@ class RDMACommManager:
Connect to remote gpu and write cache.
"""
assert self.splitwise_role == "prefill", "only prefill can call this method"
addr = f"{ip}:{port!s}"
if addr in self.connected_rdma:
return True
ret = self.messager.is_connected(ip, str(port))
if ret:
self.connected_rdma.add(addr)
return True
ret = self.messager.connect(ip, str(port))
logger.info(f"connect to remote rdma address {ip}:{port} status is {ret}")
if ret == 0:
self.connected_rdma.add(addr)
return ret == 0
def write_cache(self, ip, port, local_block_ids, remote_block_ids, layer_idx):

View File

@@ -1481,7 +1481,7 @@ class FDConfig:
self.model_config.model_format = "torch"
# TODO
self.max_prefill_batch = 3
self.max_prefill_batch = int(os.getenv("MAX_PREFILL_NUM", "3"))
if current_platform.is_xpu():
self.max_prefill_batch = 1
if self.model_config is not None and self.model_config.enable_mm:

View File

@@ -422,7 +422,7 @@ class EngineArgs:
raise NotImplementedError("Only CUDA platform supports logprob.")
if self.speculative_config is not None:
envs.ENABLE_V1_KVCACHE_SCHEDULER = 0
if self.splitwise_role != "mixed":
if self.splitwise_role != "mixed" and self.cache_transfer_protocol != "rdma":
envs.ENABLE_V1_KVCACHE_SCHEDULER = 0
if not current_platform.is_cuda():
envs.ENABLE_V1_KVCACHE_SCHEDULER = 0

View File

@@ -46,7 +46,7 @@ from fastdeploy.model_executor.guided_decoding import schema_checker
from fastdeploy.plugins.token_processor import load_token_processor_plugins
from fastdeploy.splitwise.internal_adapter_utils import InternalAdapter
from fastdeploy.splitwise.splitwise_connector import SplitwiseConnector
from fastdeploy.utils import EngineError, envs, llm_logger
from fastdeploy.utils import EngineError, envs, get_logger, llm_logger
try:
TokenProcessor = load_token_processor_plugins()
@@ -69,6 +69,13 @@ class EngineService:
"""
self.cfg = cfg
if self.cfg.parallel_config.enable_expert_parallel:
self.llm_logger = get_logger(
"fastdeploy", f"fastdeploy_rank{self.cfg.parallel_config.local_data_parallel_id}.log"
)
else:
self.llm_logger = llm_logger
self.scheduler = cfg.scheduler_config.scheduler()
if envs.ENABLE_V1_KVCACHE_SCHEDULER:
@@ -79,10 +86,6 @@ class EngineService:
cfg.scheduler_config.splitwise_role,
cfg.parallel_config.local_data_parallel_id,
)
if cfg.scheduler_config.splitwise_role != "mixed":
raise NotImplementedError(
"Currently ENABLE_V1_KVCACHE_SCHEDULER=1 only supported in mixed sampling now."
)
else:
self.resource_manager = ResourceManager(
cfg.scheduler_config.max_num_seqs,
@@ -135,12 +138,14 @@ class EngineService:
self.insert_task_to_worker_thread.start()
self.token_processor.tasks_queue = self.engine_worker_queue
self.token_processor.run()
if self.cfg.scheduler_config.splitwise_role != "mixed":
self.split_mode_get_tasks()
def _init_worker_monitor_signals(self): # exist_task_signal lets each worker process detect whether a new task needs handling
current_suffix = int(
self.cfg.parallel_config.engine_worker_queue_port[self.cfg.parallel_config.local_data_parallel_id]
)
llm_logger.info(f"current_suffix: {current_suffix}")
self.llm_logger.info(f"current_suffix: {current_suffix}")
exist_task_signal_data = np.zeros([1], dtype=np.int32)
self.exist_task_signal = IPCSignal(
name="exist_task_signal",
@@ -201,7 +206,7 @@ class EngineService:
)
if start_queue and (self.cfg.host_ip == self.cfg.master_ip or self.cfg.master_ip == "0.0.0.0"):
llm_logger.info(f"Starting engine worker queue server service at {address}")
self.llm_logger.info(f"Starting engine worker queue server service at {address}")
self.engine_worker_queue_server = EngineWorkerQueue(
address=address,
is_server=True,
@@ -225,7 +230,7 @@ class EngineService:
client_id=-1,
local_data_parallel_size=self.cfg.parallel_config.data_parallel_size,
)
llm_logger.info(
self.llm_logger.info(
f"local {min(self.cfg.worker_num_per_node * self.cfg.node_rank + self.cfg.parallel_config.local_data_parallel_id,self.cfg.parallel_config.data_parallel_size - 1)}"
)
self.engine_worker_queue = EngineWorkerQueue(
@@ -254,7 +259,17 @@ class EngineService:
cur_task_idx = self.resource_manager.req_dict[task.request_id]
del self.resource_manager.req_dict[task.request_id]
cur_task = self.resource_manager.tasks_list[cur_task_idx]
if envs.FD_ENABLE_INTERNAL_ADAPTER:
if not task.outputs.token_ids: # first token is eos in Prefill, just recycle resource and continue
self.resource_manager.stop_flags[cur_task_idx] = True
self.resource_manager.tasks_list[cur_task_idx] = None
self.resource_manager._recycle_block_tables(cur_task)
if task.request_id in self.token_processor.tokens_counter:
del self.token_processor.tokens_counter[task.request_id]
self.llm_logger.warning(f"{task.request_id} need not decode after first token")
continue
cur_task.prompt_token_ids[0] = task.outputs.token_ids[0]
cur_task.num_cached_tokens = task.num_cached_tokens
if (
self.cfg.speculative_config.method in ["mtp"]
and self.cfg.scheduler_config.splitwise_role == "decode"
@@ -267,12 +282,13 @@ class EngineService:
if task.request_id in self.token_processor.tokens_counter:
del self.token_processor.tokens_counter[task.request_id]
self.scheduler.put_results([task])
llm_logger.warning(
self.llm_logger.warning(
f"{task.request_id} prefill failed with msg:{task.error_msg}, recycle resource."
)
continue
self.token_processor.tokens_counter[task.request_id] = 1
current_tasks.append(cur_task)
if current_tasks:
self.engine_worker_queue.put_tasks((current_tasks, self.resource_manager.real_bsz))
return True
@@ -281,13 +297,34 @@ class EngineService:
if not isinstance(tasks, list):
tasks = [tasks]
need_delete_tasks = []
for task in tasks:
if self.cfg.scheduler_config.splitwise_role != "mixed":
status, msg = self.split_connector.check_decode_allocated(task)
if not status:
self.llm_logger.error(f"{task.request_id} prefill failed with msg:{msg}.")
self.scheduler.put_results(
[
RequestOutput(
request_id=task.request_id,
finished=True,
error_code=500,
error_msg=msg,
)
]
)
need_delete_tasks.append(task)
continue
for tmp_task in need_delete_tasks:
tasks.remove(tmp_task)
for item in tasks:
item.schedule_start_time = time.time()
available_batch = np.sum(self.resource_manager.stop_flags)
if len(tasks) > available_batch:
llm_logger.error(f"Inserting batch:{len(tasks)} exceeds the available batch:{available_batch}.")
llm_logger.error("The exceeded part will be ignored!")
self.llm_logger.error(f"Inserting batch:{len(tasks)} exceeds the available batch:{available_batch}.")
self.llm_logger.error("The exceeded part will be ignored!")
tasks = tasks[:available_batch]
req_ids = [t.request_id for t in tasks]
@@ -296,7 +333,7 @@ class EngineService:
if not tasks:
error_msg = f"The request required resources is exceed the limit, request id={req_ids}."
llm_logger.error(error_msg)
self.llm_logger.error(error_msg)
raise EngineError(error_msg, error_code=500)
return False
@@ -314,7 +351,7 @@ class EngineService:
self.split_connector.send_cache_infos(tasks, current_id)
if not is_decode:
llm_logger.info(f"Tasks are sent to engine, req_ids={req_ids}")
self.llm_logger.info(f"Tasks are sent to engine, req_ids={req_ids}")
for task in tasks:
task.inference_start_time = time.time()
if not is_prefill:
@@ -473,7 +510,7 @@ class EngineService:
Insert task to engine thread, monitor scheduler request queue.
if the engine has resource, insert task to engine
"""
current_id = -1
current_id = 0
while getattr(self, "running", True):
try:
if self.resource_manager.available_batch() == 0:
@@ -514,18 +551,21 @@ class EngineService:
time.sleep(0.001)
continue
current_id = (current_id + 1) % 100003
if self.cfg.scheduler_config.splitwise_role != "mixed":
llm_logger.info("Inserting splitwise tasks")
self.llm_logger.info("Inserting splitwise tasks")
self.split_connector.send_splitwise_tasks(tasks, current_id)
self.insert_tasks(tasks, current_id)
insert_successful = self.insert_tasks(tasks, current_id)
if insert_successful:
current_id = current_id + 1
else:
continue
main_process_metrics.num_requests_waiting.dec(len(tasks))
main_process_metrics.num_requests_running.inc(len(tasks))
except Exception as e:
err_msg = f"Error happened while insert task to engine: {e}, {traceback.format_exc()!s}."
llm_logger.error(err_msg)
err_msg = f"Error happend while insert task to engine: {e}, {traceback.format_exc()!s}."
self.llm_logger.error(err_msg)
def _scheduler_task_to_worker_v1(self):
"""
@@ -535,6 +575,7 @@ class EngineService:
is_fetching = False
def _fetch_request():
try:
nonlocal is_fetching
is_fetching = True
num_prefill_batch = min(
@@ -553,16 +594,75 @@ class EngineService:
max_num_batched_tokens=self.cfg.max_model_len,
batch=num_prefill_batch,
)
if self.cfg.scheduler_config.splitwise_role != "mixed":
for task in tasks:
# assure can allocate block ids in P
while not self.resource_manager.preallocate_resource_in_p(task):
time.sleep(0.005)
self.llm_logger.info(f"ask D resource for req_id: {task.request_id}")
self.split_connector.send_splitwise_tasks([task], task.idx)
need_delete_tasks = []
for task in tasks:
if self.cfg.scheduler_config.splitwise_role != "mixed":
# assure fetch block ids from D
status, msg = self.split_connector.check_decode_allocated(task)
if not status:
self.llm_logger.error(f"{task.request_id} prefill failed with msg:{msg}.")
self.scheduler.put_results(
[
RequestOutput(
request_id=task.request_id,
finished=True,
error_code=500,
error_msg=msg,
)
]
)
need_delete_tasks.append(task)
continue
for tmp_task in need_delete_tasks:
tasks.remove(tmp_task)
# release resource in P
self.resource_manager.prerelease_resource(task)
if self.cfg.scheduler_config.splitwise_role == "prefill":
# to send cache info to cache messager
if tasks:
self.split_connector.send_cache_infos(tasks, 0)
# ensure cache tasks has sent to cache_messager
need_check_req_ids = [task.request_id for task in tasks]
while need_check_req_ids:
req_ids = self.engine_worker_queue.get_finished_add_cache_task_req()
self.llm_logger.info(f"get_finished_add_cache_task_req: {req_ids}")
if req_ids:
for req_id in req_ids:
assert req_id in need_check_req_ids
need_check_req_ids.remove(req_id)
else:
time.sleep(0.001)
# Fetch requests and add them to the scheduling queue
if tasks:
if self.cfg.scheduler_config.splitwise_role == "prefill":
self.resource_manager.add_request_in_p(tasks)
else:
for task in tasks:
self.resource_manager.add_request(task)
is_fetching = False
except Exception as e:
self.llm_logger.error(f"fetching request error {e} {str(traceback.format_exc())}")
is_fetching = False
while self.running:
try:
if self.engine_worker_queue.num_tasks() > 0:
time.sleep(0.001)
continue
if self.cfg.scheduler_config.splitwise_role != "mixed":
if self.scheduler.get_unhandled_request_num() <= envs.FD_EP_MAX_PREFETCH_TASK_NUM and (
not is_fetching
):
get_request_pool.submit(_fetch_request)
else:
if (
len(self.resource_manager.waiting) == 0
and (not is_fetching)
@@ -579,8 +679,8 @@ class EngineService:
time.sleep(0.005)
except Exception as e:
err_msg = "Error happened while insert task to engine: {}, {}.".format(e, str(traceback.format_exc()))
llm_logger.error(err_msg)
err_msg = "Error happend while insert task to engine: {}, {}.".format(e, str(traceback.format_exc()))
self.llm_logger.error(err_msg)
def start_zmq_service(self, api_server_pid=None):
if api_server_pid is None:
@@ -608,6 +708,9 @@ class EngineService:
def _insert_zmq_task_to_scheduler(self):
added_requests: Dict[str, int] = dict()
if envs.FD_ENABLE_INTERNAL_ADAPTER:
if self.cfg.scheduler_config.splitwise_role == "decode":
return
while self.running:
try:
block = True if len(added_requests) == 0 else False
@@ -616,7 +719,7 @@ class EngineService:
else:
err, data = self.recv_request_server.receive_pyobj_once(block)
if err is not None:
llm_logger.error(f"Engine stops inserting zmq task into scheduler, err:{err}")
self.llm_logger.error(f"Engine stops inserting zmq task into scheduler, err:{err}")
break
request, insert_task = None, []
@@ -627,16 +730,16 @@ class EngineService:
request = Request.from_dict(data)
start_span("ENQUEUE_ZMQ", data, trace.SpanKind.PRODUCER)
main_process_metrics.requests_number.inc()
llm_logger.debug(f"Receive request: {request}")
self.llm_logger.debug(f"Receive request: {request}")
except Exception as e:
llm_logger.error(f"Receive request error: {e}, {traceback.format_exc()!s}")
self.llm_logger.error(f"Receive request error: {e}, {traceback.format_exc()!s}")
err_msg = str(e)
results.append((data["request_id"], err_msg))
if self.guided_decoding_checker is not None and err_msg is None:
request, err_msg = self.guided_decoding_checker.schema_format(request)
if err_msg is not None:
llm_logger.error(f"Receive request error: {err_msg}")
self.llm_logger.error(f"Receive request error: {err_msg}")
results.append((request.request_id, err_msg))
if err_msg is None:
@@ -670,7 +773,7 @@ class EngineService:
# Send result by zmq directly
self.send_response_server.send_response(request_id, [error_result])
except Exception as e:
llm_logger.error(
self.llm_logger.error(
f"Error happened while receiving new request from zmq, details={e}, "
f"traceback={traceback.format_exc()}"
)
@@ -689,7 +792,7 @@ class EngineService:
self.send_response_server.send_response(request_id, contents)
except Exception as e:
llm_logger.error(f"Unexcepted error happened: {e}, {traceback.format_exc()!s}")
self.llm_logger.error(f"Unexcepted error happend: {e}, {traceback.format_exc()!s}")
def split_mode_get_tasks(self):
"""
@@ -702,12 +805,21 @@ class EngineService:
processed_indices = []
for idx, task in enumerate(self.waiting_requests):
if self.resource_manager.is_resource_sufficient(task.prompt_token_ids_len):
self.insert_tasks([task])
llm_logger.info(f"Resource available, processing task {task.request_id}")
if envs.ENABLE_V1_KVCACHE_SCHEDULER:
if self.resource_manager.preallocate_resource_in_d(task):
self.llm_logger.info(f"Resource available, processing task {task.request_id}")
self.split_connector.send_cache_infos([task], -1)
processed_indices.append(idx)
else:
llm_logger.debug(f"Still waiting for resources {task.request_id}")
self.llm_logger.debug(f"Still waiting for resources {task.request_id}")
break
else:
if self.resource_manager.is_resource_sufficient(task.prompt_token_ids_len):
self.insert_tasks([task])
self.llm_logger.info(f"Resource available, processing task {task.request_id}")
processed_indices.append(idx)
else:
self.llm_logger.debug(f"Still waiting for resources {task.request_id}")
break
for idx in sorted(processed_indices, reverse=True):
@@ -730,32 +842,79 @@ class EngineService:
tasks = [tasks]
for task in tasks:
task.finished = False
self.insert_tasks(tasks, allocated=True)
if self.cfg.innode_prefill_ports is not None:
self.scheduler.put_results(tasks)
if envs.ENABLE_V1_KVCACHE_SCHEDULER:
for task in tasks:
if envs.FD_ENABLE_INTERNAL_ADAPTER:
if (
not task.outputs.token_ids
): # first token is eos in Prefill, just recycle resource and continue
cur_task = self.resource_manager.requests[task.request_id]
self.resource_manager.stop_flags[cur_task.idx] = True
self.resource_manager.tasks_list[cur_task.idx] = None
self.resource_manager._free_blocks(cur_task)
if cur_task.request_id in self.token_processor.tokens_counter:
del self.token_processor.tokens_counter[task.request_id]
self.llm_logger.warning(
f"{task.request_id} need not decode after first token"
)
del self.resource_manager.requests[task.request_id]
del self.resource_manager.req_dict[task.request_id]
continue
if task.error_code != 200:
cur_task = self.resource_manager.requests[task.request_id]
self.resource_manager.stop_flags[cur_task.idx] = True
self.resource_manager.tasks_list[cur_task.idx] = None
self.resource_manager._free_blocks(cur_task)
if cur_task.request_id in self.token_processor.tokens_counter:
del self.token_processor.tokens_counter[task.request_id]
self.scheduler.put_results([task])
self.llm_logger.warning(
f"{task.request_id} prefill failed with msg:{task.error_msg}, recycle resource."
)
continue
self.resource_manager.insert_task_for_decoding(task)
else:
self.insert_tasks(tasks, allocated=True)
if self.cfg.innode_prefill_ports is not None:
self.scheduler.put_results(tasks)
else:
if len(self.waiting_requests):
llm_logger.info(f"Waiting for resource for task {tasks[0].request_id}")
self.llm_logger.info(f"Waiting for resource for task {tasks[0].request_id}")
self.waiting_requests.extend(tasks)
else:
new_waiting = []
for task in tasks:
if self.resource_manager.is_resource_sufficient(task.prompt_token_ids_len):
self.insert_tasks([task])
can_allocate_resource = False
if envs.ENABLE_V1_KVCACHE_SCHEDULER:
if self.resource_manager.preallocate_resource_in_d(task):
self.split_connector.send_cache_infos([task], -1)
can_allocate_resource = True
else:
if self.resource_manager.is_resource_sufficient(
task.prompt_token_ids_len
):
self.insert_tasks([task])
can_allocate_resource = True
if can_allocate_resource is False:
if not self.enable_decode_cache_task:
task.error_msg = "Not enough resources"
new_waiting.append(task)
if new_waiting:
if not self.enable_decode_cache_task:
self.split_connector.send_cache_infos(new_waiting, -1)
else:
self.waiting_requests.extend(new_waiting)
llm_logger.info(f"Added {len(new_waiting)} tasks to waiting queue")
self.llm_logger.info(
f"Added {len(new_waiting)} tasks to waiting queue"
)
else:
time.sleep(0.001)
except Exception as e:
llm_logger.error(f"Error in main loop: {e}")
self.llm_logger.error(f"Error in main loop: {e}")
time.sleep(0.1)
threading.Thread(target=receiver_loop, daemon=True).start()

View File

@@ -120,11 +120,10 @@ class LLMEngine:
self.data_processor = self.input_processor.create_processor()
self.engine.data_processor = self.data_processor
# Launch components: scheduler, cache_manager, expert_service et.al.
self.launch_components()
self.engine.start()
if api_server_pid is not None:
llm_logger.info(f"Start zmq server, api_server_pid: {api_server_pid}")
self.engine.start_zmq_service(api_server_pid)
if self.do_profile == 0 and (
self.cfg.cache_config.enable_prefix_caching or self.cfg.scheduler_config.splitwise_role != "mixed"
@@ -159,11 +158,14 @@ class LLMEngine:
if self.do_profile:
self._stop_profile()
# Launch components: scheduler, cache_manager, expert_service et.al.
self.launch_components()
if self.cfg.cache_config.enable_prefix_caching or self.cfg.scheduler_config.splitwise_role != "mixed":
self.launched_cache_manager_signal.value[0] = 1
if api_server_pid is not None:
llm_logger.info(f"Start zmq server, api_server_pid: {api_server_pid}")
self.engine.start_zmq_service(api_server_pid)
# Worker launched
self.check_worker_initialize_status_func_thread.join()
if not result_container["worker_is_alive"]:
@@ -427,6 +429,9 @@ class LLMEngine:
)
if self.cfg.scheduler_config.splitwise_role != "mixed":
if envs.ENABLE_V1_KVCACHE_SCHEDULER:
variables["FLAGS_use_pd_disaggregation_per_chunk"] = 1
else:
variables["FLAGS_use_pd_disaggregation"] = 1
# TODO dynamic load environment variable
if self.cfg.scheduler_config.splitwise_role == "prefill":
@@ -498,6 +503,7 @@ class LLMEngine:
f" --load_choices {self.cfg.load_config.load_choices}"
f" --moba_attention_config '{self.cfg.moba_attention_config.to_json_string()}'"
f" --ips {ips}"
f" --cache-transfer-protocol {self.cfg.cache_config.cache_transfer_protocol}"
f" --runner {self.cfg.model_config.runner}"
f" --convert {self.cfg.model_config.convert}"
f" --override-pooler-config {self.cfg.model_config.override_pooler_config}"
@@ -625,8 +631,6 @@ class LLMEngine:
if self.cfg.scheduler_config.splitwise_role != "mixed":
# single-node logic
self.engine.engine_worker_queue.available_prefill_instances.put(1)
self.engine.split_mode_get_tasks()
if self.cfg.scheduler_config.name == "splitwise":
self.splitwise_receive_thread = threading.Thread(
target=self.engine.split_connector.start_receiver, args=()
)
@@ -640,6 +644,14 @@ class LLMEngine:
disaggregate = self.cfg.disaggregate_info
if self.cfg.scheduler_config.name == "splitwise":
self.engine.scheduler.start(role, host_ip, disaggregate)
elif self.cfg.scheduler_config.name == "dp":
request_queues_for_dp_ipc = []
result_queue_for_dp_ipc = multiprocessing.Queue()
for i in range(self.cfg.parallel_config.data_parallel_size):
request_queues_for_dp_ipc.append(multiprocessing.Queue())
self.engine.scheduler.start(
self.cfg.node_rank * self.cfg.worker_num_per_node, request_queues_for_dp_ipc, result_queue_for_dp_ipc
)
if not envs.FD_ENABLE_MULTI_API_SERVER:
if self.cfg.parallel_config.enable_expert_parallel and self.cfg.parallel_config.data_parallel_size > 1:
@@ -669,6 +681,9 @@ class LLMEngine:
args=(
self.cfg,
i,
None,
request_queues_for_dp_ipc,
result_queue_for_dp_ipc,
),
)
)

View File

@@ -27,6 +27,7 @@ import numpy as np
from fastdeploy.engine.common_engine import EngineService
from fastdeploy.inter_communicator import IPCSignal
from fastdeploy.splitwise.internal_adapter_utils import InternalAdapter
from fastdeploy.utils import console_logger, envs, llm_logger
@@ -69,8 +70,12 @@ class ExpertService:
self.engine.scheduler.reset_nodeid(f"{self.engine.scheduler.infer.nodeid}_{local_data_parallel_id!s}")
self._finalizer = weakref.finalize(self, self._exit_sub_services)
if envs.FD_ENABLE_INTERNAL_ADAPTER:
self.internal_adapter = InternalAdapter(cfg=self.cfg, engine=self.engine, dp_rank=local_data_parallel_id)
def start(self, ipc_signal_suffix, local_data_parallel_id):
def start(
self, ipc_signal_suffix, local_data_parallel_id, request_queues_for_dp_ipc=None, result_queue_for_dp_ipc=None
):
"""
Initializes the engine and starts its sub-services.
If `api_server_pid` is defined, will launch a thread
@@ -80,6 +85,11 @@ class ExpertService:
start_time = time.time()
self.engine.start()
if self.cfg.scheduler_config.name == "dp":
self.cfg.init_cache_info()
assert (request_queues_for_dp_ipc is not None) and (result_queue_for_dp_ipc is not None)
self.engine.scheduler.start(local_data_parallel_id, request_queues_for_dp_ipc, result_queue_for_dp_ipc)
if ipc_signal_suffix is not None:
self.api_server_pid = ipc_signal_suffix
self.engine.start_zmq_service(ipc_signal_suffix)
@@ -88,8 +98,8 @@ class ExpertService:
llm_logger.info(f"start expert service {local_data_parallel_id}")
if self.cfg.scheduler_config.splitwise_role != "mixed":
self.engine.start_cache_service(self.cfg.local_device_ids, ipc_signal_suffix)
self.engine.split_mode_get_tasks()
ipc_signal_suffix_cache = self.cfg.parallel_config.engine_worker_queue_port[local_data_parallel_id]
self.engine.start_cache_service(self.cfg.local_device_ids, ipc_signal_suffix_cache)
if self.cfg.scheduler_config.name == "splitwise":
self.cfg.init_cache_info()
@@ -144,14 +154,18 @@ class ExpertService:
self.zmq_server.close()
def start_data_parallel_service(cfg, local_data_parallel_id, ipc_signal_suffix=None):
def start_data_parallel_service(
cfg, local_data_parallel_id, ipc_signal_suffix=None, request_queues_for_dp_ipc=None, result_queue_for_dp_ipc=None
):
"""
Start expert service
"""
expert_service = ExpertService(cfg, local_data_parallel_id, start_queue=False)
try:
expert_service.start(ipc_signal_suffix, local_data_parallel_id)
expert_service.start(
ipc_signal_suffix, local_data_parallel_id, request_queues_for_dp_ipc, result_queue_for_dp_ipc
)
def deamon_thread():
while True:
@@ -159,5 +173,6 @@ def start_data_parallel_service(cfg, local_data_parallel_id, ipc_signal_suffix=N
t_deamon = threading.Thread(target=deamon_thread, daemon=True)
t_deamon.start()
t_deamon.join()
except Exception as e:
llm_logger.exception(f"Expert service failed to start: {e}, {str(traceback.format_exc())}")

View File

@@ -73,6 +73,7 @@ class Request:
guided_json_object: Optional[bool] = None,
enable_thinking: Optional[bool] = True,
trace_carrier: dict = dict(),
dp_rank: Optional[int] = None,
chat_template: Optional[str] = None,
image_start: int = 0,
video_start: int = 0,
@@ -145,6 +146,8 @@ class Request:
# extend block tables
self.use_extend_tables = False
self.extend_block_tables = []
# dp
self.dp_rank = dp_rank
@classmethod
def from_dict(cls, d: dict):
@@ -187,6 +190,7 @@ class Request:
image_end=d.get("image_end", 0),
video_end=d.get("video_end", 0),
audio_end=d.get("audio_end", 0),
dp_rank=d.get("dp_rank", None),
)
@property
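
The new dp_rank field is what lets the DP scheduler route a request to the right rank's queue. A minimal, self-contained illustration with a toy stand-in for Request (the real class carries many more fields):

from dataclasses import dataclass
from queue import Queue
from typing import List, Optional

@dataclass
class _ToyRequest:
    # Illustrative stand-in for Request; only the new dp_rank field is modeled.
    request_id: str
    dp_rank: Optional[int] = None

def route_by_dp_rank(requests: List[_ToyRequest], request_queues: List[Queue]):
    # Mirrors DPScheduler.put_requests further down: each request goes to the
    # queue of the data-parallel rank recorded on the request itself.
    for req in requests:
        if req.dp_rank is None:
            raise ValueError(f"Request {req.request_id} is missing dp_rank")
        request_queues[req.dp_rank].put(req)

queues = [Queue() for _ in range(4)]
route_by_dp_rank([_ToyRequest("req-0", dp_rank=2)], queues)
assert queues[2].qsize() == 1
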

View File

@@ -328,8 +328,8 @@ class ResourceManager:
Delete cached data from the task's prompt token ids based on the cached length.
"""
if cached_len == len(task.prompt_token_ids):
task.prompt_token_ids = task.prompt_token_ids[cached_len - 1 :]
task.seq_lens_decoder = cached_len - 1
task.prompt_token_ids = task.prompt_token_ids[cached_len - self.cfg.block_size :]
task.seq_lens_decoder = cached_len - self.cfg.block_size
else:
task.prompt_token_ids = task.prompt_token_ids[cached_len:]
task.seq_lens_decoder = cached_len
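
A small worked example of the full-hit branch above, assuming block_size = 64: instead of rolling back a single token, the scheduler now rolls back one whole block so the last block is recomputed and a first output token can still be produced.

block_size = 64            # illustrative value; the real code reads cfg.block_size
prompt_len = 256
cached_len = 256           # full prefix-cache hit

# New behaviour: keep the last block for recomputation.
remaining_prompt = prompt_len - (cached_len - block_size)   # 64 tokens go through prefill
seq_lens_decoder = cached_len - block_size                  # 192 tokens treated as already decoded

# Old behaviour recomputed only the final token:
# remaining_prompt = 1, seq_lens_decoder = cached_len - 1
assert remaining_prompt == block_size
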

View File

@@ -14,6 +14,7 @@
# limitations under the License.
"""
import copy
import threading
import time
import traceback
@@ -26,7 +27,7 @@ from typing import Union
import numpy as np
import paddle
from fastdeploy.engine.request import Request, RequestStatus, RequestType
from fastdeploy.engine.request import Request, RequestOutput, RequestStatus, RequestType
from fastdeploy.engine.resource_manager import ResourceManager
from fastdeploy.metrics.metrics import main_process_metrics
from fastdeploy.utils import llm_logger
@@ -297,6 +298,11 @@ class ResourceManagerV1(ResourceManager):
while req_index < len(self.running) and token_budget > 0:
request = self.running[req_index]
if request.num_computed_tokens >= request.need_prefill_tokens: # to be decoding
if (
self.config.scheduler_config.splitwise_role == "prefill"
): # do not need to schedule for decoding
req_index += 1
continue
if request.num_total_tokens > request.need_prefill_tokens: # has generated tokens
request.num_computed_tokens = request.num_total_tokens - 1
if (
@@ -400,6 +406,7 @@ class ResourceManagerV1(ResourceManager):
request.status = RequestStatus.RUNNING
main_process_metrics.num_requests_waiting.dec(1)
main_process_metrics.num_requests_running.inc(1)
if self.config.scheduler_config.splitwise_role == "mixed":
allocated_position = self.get_available_position()
request.idx = allocated_position
self.tasks_list[allocated_position] = request
@@ -569,6 +576,127 @@ class ResourceManagerV1(ResourceManager):
self.waiting.append(request)
self.requests[request.request_id] = request
def prerelease_resource(self, request: Request):
"""
Release resources in P or D before the request finishes, due to an unexpected error.
"""
with self.lock:
self.tasks_list[request.idx] = None
self.stop_flags[request.idx] = True
del self.requests[request.request_id]
del self.req_dict[request.request_id]
self._free_blocks(request)
def add_request_in_p(self, requests: list[Request]):
with self.lock:
for request in requests:
request.inference_start_time = time.time()
request.schedule_start_time = time.time()
self.running.append(request)
def preallocate_resource_in_p(self, request: Request):
"""
In P/D disaggregated deployment, preallocate resources on the P instance.
If resources can be allocated, allocate them and return True;
otherwise return False.
"""
assert self.config.scheduler_config.splitwise_role == "prefill", "Only P instance can call this method"
with self.lock:
if self.available_batch() == 0:
return False
request.need_prefill_tokens = len(request.prompt_token_ids)
need_prealloc_prefill_blocks = (
request.need_prefill_tokens + self.config.cache_config.block_size - 1
) // self.config.cache_config.block_size + self.config.cache_config.enc_dec_block_num # consider for mtp, plus enc_dec_block_num
if self.config.cache_config.enable_prefix_caching:
# Enable prefix caching
if self.config.cache_config.enable_hierarchical_cache and self.cache_manager.num_cpu_blocks > 0:
if not self.cache_manager.can_allocate_gpu_blocks(
need_prealloc_prefill_blocks
): # to prevent block allocation for matching in hierarchical cache and cause dead lock
return False
success = self.get_prefix_cached_blocks(request)
if not success:
self._free_blocks(request)
return False
# consider for mtp, plus enc_dec_block_num
need_extra_prefill_blocks = need_prealloc_prefill_blocks - request.cache_info[0]
if self.cache_manager.can_allocate_gpu_blocks(need_extra_prefill_blocks):
request.block_tables.extend(self.cache_manager.allocate_gpu_blocks(need_extra_prefill_blocks))
allocated_position = self.get_available_position()
request.idx = allocated_position
self.tasks_list[request.idx] = request
self.stop_flags[request.idx] = False
self.requests[request.request_id] = request
self.req_dict[request.request_id] = allocated_position
return True
else:
self._free_blocks(request)
return False
else:
if self.cache_manager.can_allocate_gpu_blocks(need_prealloc_prefill_blocks):
request.block_tables.extend(self.cache_manager.allocate_gpu_blocks(need_prealloc_prefill_blocks))
request.num_computed_tokens = 0
allocated_position = self.get_available_position()
request.idx = allocated_position
self.tasks_list[request.idx] = request
self.stop_flags[request.idx] = False
self.requests[request.request_id] = request
self.req_dict[request.request_id] = allocated_position
return True
return False
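
The block budget above is a ceiling division over the prompt length plus the enc_dec_block_num slack kept for MTP. A quick numeric check with illustrative values:

def need_prealloc_prefill_blocks(need_prefill_tokens: int, block_size: int, enc_dec_block_num: int) -> int:
    # Ceiling division over the prompt length, plus the extra blocks reserved for MTP.
    return (need_prefill_tokens + block_size - 1) // block_size + enc_dec_block_num

# e.g. a 1000-token prompt with 64-token blocks and 2 reserved enc/dec blocks:
assert need_prealloc_prefill_blocks(1000, 64, 2) == 16 + 2
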
def preallocate_resource_in_d(self, request: Request):
"""
In P/D disaggregated deployment, D preallocates resources for requests dispatched from P.
If resources can be allocated, allocate them and return True;
otherwise return False.
"""
assert self.config.scheduler_config.splitwise_role == "decode", "Only D instance can call this method"
with self.lock:
if len(self.waiting) > 0:
return False
if self.available_batch() == 0:
return False
request.need_prefill_tokens = len(request.prompt_token_ids)
need_prealloc_prefill_blocks = (
request.need_prefill_tokens + self.config.cache_config.block_size - 1
) // self.config.cache_config.block_size + self.config.cache_config.enc_dec_block_num # consider for mtp, plus enc_dec_block_num
if self.cache_manager.can_allocate_gpu_blocks(need_prealloc_prefill_blocks):
request.block_tables.extend(self.cache_manager.allocate_gpu_blocks(need_prealloc_prefill_blocks))
request.num_computed_tokens = request.need_prefill_tokens
request.disaggregate_info["block_tables"] = request.block_tables
allocated_position = self.get_available_position()
request.idx = allocated_position
self.tasks_list[request.idx] = request
self.stop_flags[request.idx] = False
self.requests[request.request_id] = request
self.req_dict[request.request_id] = allocated_position
return True
return False
def insert_task_for_decoding(self, request_output_in_p: RequestOutput):
"""
In P/D disaggregated deployment, D continues decoding after receiving the first token and KV cache from P.
"""
assert self.config.scheduler_config.splitwise_role == "decode", "Only D instance can call this method"
with self.lock:
request = self.requests[request_output_in_p.request_id]
request.output_token_ids.append(request_output_in_p.outputs.token_ids[0])
request.num_cached_tokens = request_output_in_p.num_cached_tokens
if (
self.config.speculative_config.method in ["mtp"]
and self.config.scheduler_config.splitwise_role == "decode"
):
request.draft_token_ids = copy.deepcopy(request_output_in_p.outputs.draft_token_ids)
# update request.need_prefill_tokens
request.need_prefill_tokens = len(request.prompt_token_ids) + 1
request.inference_start_time = time.time()
request.schedule_start_time = time.time()
self.running.append(request)
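
Putting the two D-side methods together: blocks are reserved when the request arrives from P, and once P's first token lands, the request joins the running list with need_prefill_tokens extended by one. A tiny illustration of that bookkeeping (names are illustrative):

def first_token_handoff(prompt_token_ids, first_token_id):
    # Illustrative only: what D records after P reports its first generated token.
    output_token_ids = [first_token_id]              # token produced by P's prefill step
    need_prefill_tokens = len(prompt_token_ids) + 1  # prompt plus the token P already generated
    return output_token_ids, need_prefill_tokens

tokens, need = first_token_handoff(list(range(10)), 42)
assert tokens == [42] and need == 11
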
def _free_blocks(self, request: Request):
if self.config.cache_config.enable_prefix_caching:
self.cache_manager.release_block_ids(request)
@@ -620,5 +748,7 @@ class ResourceManagerV1(ResourceManager):
self.tasks_list[request.idx] = None
self.stop_flags[request.idx] = True
del self.requests[req_id]
if req_id in self.req_dict:
del self.req_dict[req_id]
except Exception as e:
llm_logger.error(f"finish_request err: {e}, {str(traceback.format_exc())}")

View File

@@ -109,6 +109,12 @@ environment_variables: dict[str, Callable[[], Any]] = {
"FD_ZMQ_SEND_RESPONSE_SERVER_PORT": lambda: os.getenv("FD_ZMQ_SEND_RESPONSE_SERVER_PORT", "8201"),
# Port on which LLMEngine receives control commands, used when FD_ENABLE_INTERNAL_ADAPTER=1
"FD_ZMQ_CONTROL_CMD_SERVER_PORTS": lambda: os.getenv("FD_ZMQ_CONTROL_CMD_SERVER_PORTS", "8202"),
# Whether to enable the cache task on the decode node
"FD_ENABLE_CACHE_TASK": lambda: os.getenv("FD_ENABLE_CACHE_TASK", "1"),
# Timeout (in seconds) for batching tokens in EP
"FD_EP_BATCHED_TOKEN_TIMEOUT": lambda: float(os.getenv("FD_EP_BATCHED_TOKEN_TIMEOUT", "0.1")),
# Maximum number of pre-fetched requests in PD
"FD_EP_MAX_PREFETCH_TASK_NUM": lambda: int(os.getenv("FD_EP_MAX_PREFETCH_TASK_NUM", "8")),
"FD_ENABLE_MODEL_LOAD_CACHE": lambda: bool(int(os.getenv("FD_ENABLE_MODEL_LOAD_CACHE", "0"))),
}
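
For reference, a minimal, self-contained replica of the lazy environment-variable pattern these entries plug into; the names match the additions above, and each access re-reads the process environment:

import os
from typing import Any, Callable

_env_vars: dict[str, Callable[[], Any]] = {
    "FD_ENABLE_CACHE_TASK": lambda: os.getenv("FD_ENABLE_CACHE_TASK", "1"),
    "FD_EP_BATCHED_TOKEN_TIMEOUT": lambda: float(os.getenv("FD_EP_BATCHED_TOKEN_TIMEOUT", "0.1")),
    "FD_EP_MAX_PREFETCH_TASK_NUM": lambda: int(os.getenv("FD_EP_MAX_PREFETCH_TASK_NUM", "8")),
}

def _get(name: str) -> Any:
    # Each access re-evaluates the lambda, so values can be overridden in the
    # environment before the consuming module first touches them.
    if name not in _env_vars:
        raise AttributeError(f"no such environment variable: {name}")
    return _env_vars[name]()

os.environ["FD_EP_BATCHED_TOKEN_TIMEOUT"] = "0.25"
assert _get("FD_EP_BATCHED_TOKEN_TIMEOUT") == 0.25
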
@@ -120,6 +126,14 @@ def __getattr__(name: str):
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
def get_unique_name(self, name):
"""
Get unique name for config
"""
shm_uuid = os.getenv("SHM_UUID", "")
return name + f"_{shm_uuid}"
def __setattr__(name: str, value: Any):
assert name in environment_variables
environment_variables[name] = lambda: value

View File

@@ -84,18 +84,28 @@ class EngineWorkerQueue:
Value("i", 0) for _ in range(self.local_data_parallel_size)
]
self.finished_req_queue = [Queue() for _ in range(self.local_data_parallel_size)]
self.finished_add_cache_task_queue = [Queue() for _ in range(self.local_data_parallel_size)]
self.cache_infos_init: List[List[Any]] = [list() for _ in range(self.local_data_parallel_size)]
self.connect_rdma_tasks_list = [list() for _ in range(self.local_data_parallel_size)]
self.connect_rdma_tasks_response_list = [list() for _ in range(self.local_data_parallel_size)]
self.client_read_info_flag_init: List[List[int]] = [
[1] * self.num_client for _ in range(self.local_data_parallel_size)
]
self.lock_info_init: List[threading.Lock] = [
threading.Lock() for _ in range(self.local_data_parallel_size)
]
self.connect_task_lock_init: List[threading.Lock] = [
threading.Lock() for _ in range(self.local_data_parallel_size)
]
self.finish_request_barrier = [
threading.Barrier(self.num_client) for _ in range(self.local_data_parallel_size)
]
self.finish_add_cache_task_barrier = [
threading.Barrier(self.num_client) for _ in range(self.local_data_parallel_size)
]
# Register shared objects with proxy types
QueueManager.register(
"get_tasks",
@@ -117,6 +127,19 @@ class EngineWorkerQueue:
callable=lambda idx: self.read_finish_flag_init[idx],
proxytype=ValueProxy,
)
QueueManager.register(
"get_connect_task_lock",
callable=lambda idx: self.connect_task_lock_init[idx],
proxytype=AcquirerProxy,
)
QueueManager.register(
"get_connect_rdma_tasks", callable=lambda idx: self.connect_rdma_tasks_list[idx], proxytype=ListProxy
)
QueueManager.register(
"get_connect_rdma_tasks_responses",
callable=lambda idx: self.connect_rdma_tasks_response_list[idx],
proxytype=ListProxy,
)
QueueManager.register(
"get_connected_client_counter",
callable=lambda idx: self.connected_client_counter_init[idx],
@@ -128,6 +151,11 @@ class EngineWorkerQueue:
callable=lambda idx: self.finished_req_queue[idx],
)
QueueManager.register(
"get_finish_add_cache_task_queue",
callable=lambda idx: self.finished_add_cache_task_queue[idx],
)
QueueManager.register(
"get_cache_infos",
callable=lambda idx: self.cache_infos_init[idx],
@@ -161,6 +189,10 @@ class EngineWorkerQueue:
"get_finish_request_barrier",
callable=lambda idx: self.finish_request_barrier[idx],
)
QueueManager.register(
"get_finish_add_cache_task_barrier",
callable=lambda idx: self.finish_add_cache_task_barrier[idx],
)
self.manager: BaseManager = QueueManager(address=self.address, authkey=self.authkey)
self.manager.start()
else:
@@ -174,12 +206,17 @@ class EngineWorkerQueue:
QueueManager.register("get_read_finish_flag")
QueueManager.register("get_connected_client_counter")
QueueManager.register("get_finish_request_queue")
QueueManager.register("get_finish_add_cache_task_queue")
QueueManager.register("get_cache_infos")
QueueManager.register("get_client_read_info_flag")
QueueManager.register("get_lock_info")
QueueManager.register("get_disaggregate_requests")
QueueManager.register("get_available_prefill_instances")
QueueManager.register("get_finish_request_barrier")
QueueManager.register("get_finish_add_cache_task_barrier")
QueueManager.register("get_connect_rdma_tasks")
QueueManager.register("get_connect_rdma_tasks_responses")
QueueManager.register("get_connect_task_lock")
self.manager = QueueManager(address=self.address, authkey=self.authkey)
self._connect_with_retry()
@@ -199,7 +236,20 @@ class EngineWorkerQueue:
self.disaggregate_requests = self.manager.get_disaggregate_requests(self.local_data_parallel_id)
self.available_prefill_instances = self.manager.get_available_prefill_instances()
self.finish_request_barrier = self.manager.get_finish_request_barrier(self.local_data_parallel_id)
self.finish_add_cache_task_barrier = self.manager.get_finish_add_cache_task_barrier(
self.local_data_parallel_id
)
self.finished_req_queue = self.manager.get_finish_request_queue(self.local_data_parallel_id)
self.finished_add_cache_task_queue = self.manager.get_finish_add_cache_task_queue(
self.local_data_parallel_id
)
# P/D interconnection
self.connect_rdma_task_queue = self.manager.get_connect_rdma_tasks(self.local_data_parallel_id)
self.connect_rdma_task_response_queue = self.manager.get_connect_rdma_tasks_responses(
self.local_data_parallel_id
)
self.connect_task_lock = self.manager.get_connect_task_lock(self.local_data_parallel_id)
assert self.num_client == len(self.client_read_flag)
if is_server:
@@ -281,6 +331,44 @@ class EngineWorkerQueue:
self.lock.release()
return total_num
def put_connect_rdma_task(self, connect_rdma_task):
self.connect_task_lock.acquire()
self.connect_rdma_task_queue.append(connect_rdma_task)
self.connect_task_lock.release()
def get_connect_rdma_task(self):
result = None
self.connect_task_lock.acquire()
if len(self.connect_rdma_task_queue) == 0:
self.connect_task_lock.release()
return result
try:
result = self.connect_rdma_task_queue.pop(0)
except Exception as e:
llm_logger.info(f"get_connect_rdma_task got exception: {e}")
finally:
self.connect_task_lock.release()
return result
def put_connect_rdma_task_response(self, connect_rdma_task_response):
self.connect_task_lock.acquire()
self.connect_rdma_task_response_queue.append(connect_rdma_task_response)
self.connect_task_lock.release()
def get_connect_rdma_task_response(self):
result = None
self.connect_task_lock.acquire()
if len(self.connect_rdma_task_response_queue) == 0:
self.connect_task_lock.release()
return result
try:
result = self.connect_rdma_task_response_queue.pop(0)
except Exception as e:
llm_logger.info(f"get_connect_rdma_task_response got exception: {e}")
finally:
self.connect_task_lock.release()
return result
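
The four helpers above are a lock-guarded list handoff between P and D over the shared queue manager. A local sketch of the same pattern, with threading.Lock standing in for the proxied lock:

import threading

class _ConnectTaskChannel:
    # Illustrative lock-guarded FIFO, mirroring put/get_connect_rdma_task.

    def __init__(self):
        self._lock = threading.Lock()
        self._tasks = []

    def put(self, task):
        with self._lock:
            self._tasks.append(task)

    def get(self):
        with self._lock:
            return self._tasks.pop(0) if self._tasks else None

chan = _ConnectTaskChannel()
chan.put({"request_id": "req-0", "rdma_ports": [1234]})
assert chan.get()["request_id"] == "req-0"
assert chan.get() is None
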
def get_prefill_instances(self):
"""
check if the prefill queue is empty
@@ -365,6 +453,29 @@ class EngineWorkerQueue:
llm_logger.debug(f"get finished req: {ans}")
return ans
def put_finished_add_cache_task_req(self, req_ids) -> None:
"""
Put finished add-cache-task request IDs into the queue.
Args:
req_ids: Request IDs to be added to the queue
"""
self.finished_add_cache_task_queue.put(req_ids)
def get_finished_add_cache_task_req(self) -> list:
"""
Get finished add-cache-task request IDs from the queue.
Returns:
list: Finished request IDs (empty if the queue is empty)
"""
ans = []
if self.finished_add_cache_task_queue.empty():
return ans
ans = self.finished_add_cache_task_queue.get()
llm_logger.debug(f"get finished add cache task req: {ans}")
return ans
def disaggregate_queue_empty(self):
"""
Check if the disaggregated task queue is empty.

View File

@@ -211,9 +211,8 @@ class DeepEPEngine:
self.num_experts = num_experts
self.num_local_experts = num_experts // ep_size
self.async_finish = async_finish
from paddle.base.core import Config
self.ep_config = Config(24, 6, 256)
self.ep_config = None
# Store phase and role for buffer management
self._splitwise_role = splitwise_role

View File

@@ -76,6 +76,7 @@ else:
update_inputs,
step_reschedule,
update_inputs_v1,
speculate_step_reschedule,
)
@@ -413,6 +414,36 @@ def step_cuda(
"""
if speculative_config.method is not None:
if DISABLE_RECOVER:
speculate_step_reschedule(
share_inputs["stop_flags"],
share_inputs["seq_lens_this_time"],
share_inputs["step_seq_lens_encoder"],
share_inputs["seq_lens_encoder"],
share_inputs["seq_lens_decoder"],
share_inputs["block_tables"],
share_inputs["encoder_block_lens"],
share_inputs["is_block_step"],
share_inputs["step_block_list"],
share_inputs["step_lens"],
share_inputs["recover_block_list"],
share_inputs["recover_lens"],
share_inputs["need_block_list"],
share_inputs["need_block_len"],
share_inputs["used_list_len"],
share_inputs["free_list"],
share_inputs["free_list_len"],
share_inputs["input_ids"],
share_inputs["pre_ids"],
share_inputs["step_idx"],
share_inputs["next_tokens"],
share_inputs["first_token_ids"],
share_inputs["accept_num"],
block_size,
enc_dec_block_num,
speculative_config.num_speculative_tokens,
)
else:
if enable_prefix_caching:
speculate_step_system_cache(
share_inputs["stop_flags"],
@@ -473,12 +504,11 @@ def step_cuda(
speculative_config.num_speculative_tokens,
)
else:
if enable_prefix_caching:
step_system_cache(
if DISABLE_RECOVER:
step_reschedule(
share_inputs["stop_flags"],
share_inputs["seq_lens_this_time"],
share_inputs["step_seq_lens_encoder"],
share_inputs["step_seq_lens_decoder"],
share_inputs["seq_lens_encoder"],
share_inputs["seq_lens_decoder"],
share_inputs["block_tables"],
@@ -501,11 +531,13 @@ def step_cuda(
block_size,
enc_dec_block_num,
)
elif DISABLE_RECOVER:
step_reschedule(
else:
if enable_prefix_caching:
step_system_cache(
share_inputs["stop_flags"],
share_inputs["seq_lens_this_time"],
share_inputs["step_seq_lens_encoder"],
share_inputs["step_seq_lens_decoder"],
share_inputs["seq_lens_encoder"],
share_inputs["seq_lens_decoder"],
share_inputs["block_tables"],

View File

@@ -58,7 +58,6 @@ class TokenProcessor:
self.split_connector = split_connector
if envs.FD_USE_GET_SAVE_OUTPUT_V1:
llm_logger.debug(f"create zmq get_save_output_rank{self.cfg.parallel_config.local_data_parallel_id}")
self.zmq_server = ZmqIpcServer(
name=f"get_save_output_rank{self.cfg.parallel_config.local_data_parallel_id}", mode=zmq.PULL
@@ -298,10 +297,15 @@ class TokenProcessor:
try:
is_blocking = True
if self.speculative_decoding:
if (
self.cfg.parallel_config.enable_expert_parallel
and self.cfg.parallel_config.data_parallel_size > 1
):
speculate_get_output(self.output_tokens, rank_id, is_blocking, True)
else:
speculate_get_output(self.output_tokens, rank_id, is_blocking, False)
if self.output_tokens[0] == -2:
continue
else:
if self.use_logprobs:
get_output_topk(
@@ -370,14 +374,18 @@ class TokenProcessor:
llm_logger.info(f"finished_task_id: {finished_task_id}")
self.prefill_result_status[finished_task_id[0]] = finished_task_id[1]
if task_id in self.prefill_result_status:
self.split_connector.send_first_token(task.disaggregate_info, [result])
if envs.ENABLE_V1_KVCACHE_SCHEDULER:
self.resource_manager.finish_requests_async(task_id)
else:
self.resource_manager.stop_flags[index] = True
self.resource_manager.tasks_list[index] = None
self.resource_manager._recycle_block_tables(task)
if task_id in self.resource_manager.req_dict:
del self.resource_manager.req_dict[task_id]
if self.prefill_result_status[task_id] != "finished":
result.error_code = 400
result.error_message = f"{task_id} failed to {self.prefill_result_status[task_id]}"
del self.resource_manager.req_dict[task_id]
self.split_connector.send_first_token(task.disaggregate_info, [result])
break
else:
time.sleep(0.002)
@@ -388,6 +396,8 @@ class TokenProcessor:
self.resource_manager.stop_flags[index] = True
self.resource_manager.tasks_list[index] = None
self.resource_manager._recycle_block_tables(task)
if task_id in self.resource_manager.req_dict:
del self.resource_manager.req_dict[task_id]
task_used_block_num = sum([len(task.block_tables) if task else 0 for task in self.resource_manager.tasks_list])
main_process_metrics.available_gpu_block_num.set(
@@ -461,6 +471,12 @@ class TokenProcessor:
task_id = task.request_id
if self.cfg.speculative_config.method:
if accept_num[i] == -3:
recovery_stop = True
if recovery_stop:
llm_logger.info(f"recovery stop signal found at task {task_id}")
token_ids = [RECOVERY_STOP_SIGNAL]
else:
token_ids = tokens[
2
+ SPECULATE_MAX_BSZ
@@ -469,7 +485,7 @@ class TokenProcessor:
+ i * MAX_DRAFT_TOKENS
+ accept_num[i]
].tolist()
if len(token_ids) == 0 or token_ids[-1] <= 0:
if (not recovery_stop) and (len(token_ids) == 0 or token_ids[-1] <= 0):
continue
else:
token_id = int(tokens[i, 0])
@@ -537,6 +553,7 @@ class TokenProcessor:
for token_id in token_ids:
self.tokens_counter[task_id] += 1
if token_id != RECOVERY_STOP_SIGNAL:
if not (envs.FD_ENABLE_INTERNAL_ADAPTER and token_id in task.eos_token_ids):
result.outputs.token_ids.append(token_id)
task.output_token_ids.append(token_id)
if self.use_logprobs:
@@ -567,7 +584,11 @@ class TokenProcessor:
self._record_completion_metrics(task, current_time)
self._recycle_resources(task_id, i, task, result, is_prefill)
break
if not is_prefill or self.cfg.scheduler_config.name == "splitwise":
if (
not is_prefill
or self.cfg.scheduler_config.name == "splitwise"
or self.cfg.scheduler_config.name == "dp"
):
batch_result.append(result)
self.postprocess(batch_result)
@@ -609,7 +630,7 @@ class TokenProcessor:
self.cfg.speculative_config.num_speculative_tokens,
)
real_accept_num = [x for x in accept_num if x != 0]
real_accept_num = [x for x in accept_num if x > 0]
num_accepted_tokens = sum([x - 1 for x in real_accept_num])
self.num_accepted_tokens += num_accepted_tokens
num_emitted_tokens = sum(real_accept_num)
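
With the x > 0 filter above, negative sentinels such as the -3 recovery-stop signal no longer skew the speculative-decoding metrics. A quick check:

accept_num = [3, 1, 0, -3, 2]          # per-slot accepted counts; 0 = inactive, -3 = recovery stop

real_accept_num = [x for x in accept_num if x > 0]
num_accepted_tokens = sum(x - 1 for x in real_accept_num)   # draft tokens actually accepted
num_emitted_tokens = sum(real_accept_num)                   # tokens emitted to the output stream

assert real_accept_num == [3, 1, 2]
assert num_accepted_tokens == 3 and num_emitted_tokens == 6
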

View File

@@ -18,6 +18,7 @@ import redis
from fastdeploy.utils import llm_logger
from .dp_scheduler import DPScheduler
from .global_scheduler import GlobalScheduler
from .local_scheduler import LocalScheduler
from .splitwise_scheduler import SplitWiseScheduler, SplitWiseSchedulerConfig
@@ -89,6 +90,54 @@ class LocalSchedulerConfig:
llm_logger.info("=============================================================")
class DPLocalSchedulerConfig(LocalSchedulerConfig):
"""
Configuration class for DPLocalScheduler.
Attributes:
max_size: Maximum number of concurrent requests (-1 for unlimited)
ttl: Time-to-live in seconds for request expiration
"""
def __init__(
self,
max_size: int = -1,
ttl: int = 900,
max_model_len: int = 8192,
enable_chunked_prefill: bool = False,
max_num_partial_prefills: int = 1,
max_long_partial_prefills: int = 1,
long_prefill_token_threshold: int = 0,
splitwise_role: str = "prefill",
**kwargs,
):
"""
Initialize LocalScheduler configuration.
Args:
max_size: Maximum concurrent requests (-1 for unlimited, 0 for disabled)
ttl: Time-to-live in seconds for request expiration (default 900s)
max_model_len: Maximum model context length in tokens
enable_chunked_prefill: Whether to enable chunked prefill processing
max_num_partial_prefills: Max partial prefill operations allowed
max_long_partial_prefills: Max long-running partial prefill ops
long_prefill_token_threshold: Token count threshold for long prefill
splitwise_role: Role of this instance, "prefill" or "decode"
**kwargs: Additional unused arguments (for forward compatibility)
Note:
- If long_prefill_token_threshold is 0, it's auto-calculated as 4% of max_model_len
- See LocalScheduler class for implementation details
"""
self.max_size = max_size
self.ttl = ttl
self.max_model_len = max_model_len
self.enable_chunked_prefill = enable_chunked_prefill
self.max_num_partial_prefills = max_num_partial_prefills
self.max_long_partial_prefills = max_long_partial_prefills
self.long_prefill_token_threshold = long_prefill_token_threshold
if self.long_prefill_token_threshold == 0:
self.long_prefill_token_threshold = int(self.max_model_len * 0.04)
self.splitwise_role = splitwise_role
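
Given this constructor, a config for a decode-side "dp" scheduler could look like the sketch below; values are illustrative, and the 4% auto-threshold kicks in when long_prefill_token_threshold is 0:

# DPLocalSchedulerConfig as defined above.
cfg = DPLocalSchedulerConfig(
    max_size=-1,                     # unlimited concurrent requests
    ttl=900,
    max_model_len=32768,
    enable_chunked_prefill=True,
    max_num_partial_prefills=1,
    max_long_partial_prefills=1,
    long_prefill_token_threshold=0,  # 0 -> auto: 4% of max_model_len
    splitwise_role="decode",
)
assert cfg.long_prefill_token_threshold == int(32768 * 0.04)
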
class GlobalSchedulerConfig:
"""
Configuration class for GlobalScheduler (Redis-based).
@@ -235,6 +284,9 @@ class SchedulerConfig:
if self.name == "splitwise":
self.config = SplitWiseSchedulerConfig(**args)
if self.name == "dp":
self.config = DPLocalSchedulerConfig(**args)
def check(self):
"""
Validate the configuration.
@@ -242,7 +294,7 @@ class SchedulerConfig:
Raises:
Exception: If invalid scheduler type is specified
"""
if self.name not in ["local", "global", "splitwise"]:
if self.name not in ["local", "global", "splitwise", "dp"]:
raise Exception(f"Unknown scheduler type {self.name}")
self.config.check()
@@ -280,6 +332,17 @@ class SchedulerConfig:
if self.name == "splitwise":
return SplitWiseScheduler(self.config)
if self.name == "dp":
return DPScheduler(
max_size=self.config.max_size,
ttl=self.config.ttl,
enable_chunked_prefill=self.config.enable_chunked_prefill,
max_num_partial_prefills=self.config.max_num_partial_prefills,
max_long_partial_prefills=self.config.max_long_partial_prefills,
long_prefill_token_threshold=self.config.long_prefill_token_threshold,
splitwise_role=self.config.splitwise_role,
)
return LocalScheduler(
max_size=self.config.max_size,
ttl=self.config.ttl,

View File

@@ -0,0 +1,272 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
import logging
import threading
import time
from multiprocessing import Queue
from typing import Dict, List, Optional
from fastdeploy.engine.request import Request, RequestOutput
from fastdeploy.scheduler.data import ScheduledResponse
from fastdeploy.scheduler.local_scheduler import LocalScheduler
from fastdeploy.utils import envs, get_logger
class DPLocalScheduler(LocalScheduler):
def __init__(
self,
max_size: int,
ttl: int,
enable_chunked_prefill: bool,
max_num_partial_prefills: int,
max_long_partial_prefills: int,
long_prefill_token_threshold: int,
splitwise_role: str = "prefill",
):
super().__init__(
max_size,
ttl,
enable_chunked_prefill,
max_num_partial_prefills,
max_long_partial_prefills,
long_prefill_token_threshold,
)
self.splitwise_role = splitwise_role
self.scheduler_logger = logging
def put_results(self, results: List[RequestOutput]):
"""
Add processing results back to the scheduler.
Args:
results: List of RequestOutput objects containing results
"""
responses: List[ScheduledResponse] = [ScheduledResponse(result) for result in results]
finished_responses = [response.request_id for response in responses if response.finished]
if len(finished_responses) > 0:
self.scheduler_logger.info(f"Scheduler has received some finished responses: {finished_responses}")
with self.mutex:
for response in responses:
if response.request_id not in self.responses:
self.responses[response.request_id] = [response]
continue
self.responses[response.request_id].append(response)
self.responses_not_empty.notify_all()
def _recycle(self, request_id: Optional[str] = None):
"""
Clean up expired or completed requests to free memory.
Args:
request_id: Optional specific request ID to remove.
If None, removes all expired requests.
"""
if request_id is not None:
self.requests.pop(request_id, None)
self.responses.pop(request_id, None)
if self.splitwise_role == "decode":
return
self.ids.pop(self.ids.index(request_id))
self.ids_read_cursor -= 1
return
if self.max_size <= 0:
return
if len(self.requests) <= self.max_size:
return
now = time.time()
expired_ids = []
for request_id in self.ids:
request = self.requests[request_id]
if now - request.schedule_time < self.ttl:
break
expired_ids.append(request.request_id)
for expired_id in expired_ids:
self.requests.pop(expired_id, None)
self.responses.pop(expired_id, None)
# expired ids were collected from the front of self.ids, so always pop index 0
self.ids.pop(0)
if len(expired_ids) > 0:
if len(expired_ids) - 1 >= self.ids_read_cursor:
self.ids_read_cursor = 0
else:
self.ids_read_cursor -= len(expired_ids)
def get_requests(
self,
available_blocks,
block_size,
reserved_output_blocks,
max_num_batched_tokens,
batch=1,
) -> List[Request]:
"""
Retrieve requests from the scheduler based on available resources.
Args:
available_blocks: Number of available processing blocks
block_size: Size of each processing block
reserved_output_blocks: Blocks reserved for output
max_num_batched_tokens: Maximum tokens that can be batched
batch: Preferred batch size
Returns:
List of Request objects ready for processing
"""
if available_blocks <= reserved_output_blocks or batch < 1:
self.scheduler_logger.debug(
f"Scheduler's resource are insufficient: available_blocks={available_blocks} "
f"reserved_output_blocks={reserved_output_blocks} batch={batch} "
f"max_num_batched_tokens={max_num_batched_tokens}"
)
return []
required_total_blocks = 0
current_prefill_tokens = 0
start_batch_time = time.time()
requests: List[Request] = []
with self.requests_not_empty:
if not envs.ENABLE_V1_KVCACHE_SCHEDULER:
while True:
batch_ids = self.requests_not_empty.wait_for(
lambda: self.ids[self.ids_read_cursor : self.ids_read_cursor + batch],
0.005,
)
if batch_ids:
for request_id in batch_ids:
request = self.requests[request_id]
required_input_blocks = self.calc_required_blocks(
request.prompt_tokens_ids_len, block_size
)
current_prefill_tokens += request.prompt_tokens_ids_len
required_total_blocks += required_input_blocks + reserved_output_blocks
if required_total_blocks > available_blocks:
break
requests.append(request.raw)
self.ids_read_cursor += 1
start_batch_time = time.time()
if current_prefill_tokens > max_num_batched_tokens:
break
if len(requests) >= batch:
break
if (
(current_prefill_tokens > max_num_batched_tokens)
or (len(requests) >= batch)
or (time.time() - start_batch_time > envs.FD_EP_BATCHED_TOKEN_TIMEOUT)
):
break
else:
batch_ids = self.requests_not_empty.wait_for(
lambda: self.ids[self.ids_read_cursor : self.ids_read_cursor + batch],
0.005,
)
if batch_ids:
for request_id in batch_ids:
request = self.requests[request_id]
requests.append(request.raw)
self.ids_read_cursor += 1
if batch_ids:
if len(batch_ids) > 0 and len(requests) == 0:
self.scheduler_logger.debug(
f"Scheduler has put all just-pulled request into the queue: {len(batch_ids)}"
)
if len(requests) > 0:
self.scheduler_logger.info(
f"Scheduler has pulled some request: {[request.request_id for request in requests]}"
)
return requests
class DPScheduler:
def __init__(
self,
max_size: int,
ttl: int,
enable_chunked_prefill: bool,
max_num_partial_prefills: int,
max_long_partial_prefills: int,
long_prefill_token_threshold: int,
splitwise_role: str = "prefill",
):
self._scheduler = DPLocalScheduler(
max_size,
ttl,
enable_chunked_prefill,
max_num_partial_prefills,
max_long_partial_prefills,
long_prefill_token_threshold,
splitwise_role,
)
def start(self, dp_rank: int, request_queues: List[Queue], result_queue: Queue):
self.dp_rank = dp_rank
self.request_queues = request_queues
self.result_queue = result_queue
self.scheduler_logger = get_logger("dpscheduler", f"dp_scheduler_rank{self.dp_rank}.log")
self._scheduler.scheduler_logger = self.scheduler_logger
threading.Thread(target=self._put_requests_to_local).start()
threading.Thread(target=self._get_response_from_local).start()
def put_requests(self, requests: List[Dict]):
results = []
for request in requests:
if not hasattr(request, "dp_rank"):
raise ValueError(f"Request object is missing the 'dp_rank' attribute: {request}")
self.request_queues[request.dp_rank].put(request)
results.append((request.request_id, None))
return results
def _put_requests_to_local(self):
while True:
request = self.request_queues[self.dp_rank].get()
self.scheduler_logger.info(f"Recieve request from puller, request_id: {request.request_id}")
self._scheduler.put_requests([request])
def _get_response_from_local(self):
while True:
results = self._scheduler.get_results()
if len(results) == 0:
continue
self.result_queue.put(results)
def get_requests(
self,
available_blocks,
block_size,
reserved_output_blocks,
max_num_batched_tokens,
batch=1,
) -> List[Request]:
return self._scheduler.get_requests(
available_blocks, block_size, reserved_output_blocks, max_num_batched_tokens, batch
)
def get_unhandled_request_num(self):
return len(self._scheduler.requests)
def put_results(self, results: List[RequestOutput]):
self._scheduler.put_results(results)
def get_results(self) -> Dict[str, List[RequestOutput]]:
return self.result_queue.get()
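
A hedged wiring sketch for DPScheduler with illustrative sizes; start() spawns the two forwarding threads shown above, and put_requests()/get_results() then move work through the per-rank queues:

from multiprocessing import Queue

# DPScheduler as defined above.
dp_size = 4
request_queues = [Queue() for _ in range(dp_size)]
result_queue = Queue()

scheduler = DPScheduler(
    max_size=-1,
    ttl=900,
    enable_chunked_prefill=True,
    max_num_partial_prefills=1,
    max_long_partial_prefills=1,
    long_prefill_token_threshold=0,
    splitwise_role="prefill",
)
scheduler.start(dp_rank=0, request_queues=request_queues, result_queue=result_queue)
# put_requests() routes each request by its dp_rank; get_results() blocks on the shared result queue.
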

View File

@@ -28,8 +28,6 @@ from fastdeploy.inter_communicator import EngineWorkerQueue
from fastdeploy.metrics.metrics import main_process_metrics
from fastdeploy.utils import get_logger
logger = get_logger("splitwise_connector", "splitwise_connector.log")
class SplitwiseConnector:
"""
@@ -46,12 +44,19 @@ class SplitwiseConnector:
resource_manager (object): Resource manager object.
"""
self.cfg = cfg
if self.cfg.parallel_config.enable_expert_parallel and self.cfg.parallel_config.data_parallel_size > 1:
self.logger = get_logger(
"splitwise_connector", f"splitwise_connector_{self.cfg.parallel_config.local_data_parallel_id}.log"
)
else:
self.logger = get_logger("splitwise_connector", "splitwise_connector.log")
self.engine_worker_queue = worker_queue
self.resource_manager = resource_manager
self.connect_innode_instances = {}
self.temp_cache_info = dict()
self.current_request_ids = dict()
self.idx = self.cfg.parallel_config.local_data_parallel_id
self.enable_decode_cache_task = envs.FD_ENABLE_CACHE_TASK == "1"
if self.cfg.cache_config.pd_comm_port is not None:
self.zmq_ctx = zmq.Context()
@@ -70,7 +75,7 @@ class SplitwiseConnector:
self.router_socket.setsockopt(zmq.SNDHWM, 1000)
self.router_socket.setsockopt(zmq.ROUTER_MANDATORY, 1)
self.router_socket.bind(f"tcp://*:{self.cfg.cache_config.pd_comm_port[0]}")
logger.info(f"bind {self.cfg.cache_config.pd_comm_port}")
self.logger.info(f"bind {self.cfg.cache_config.pd_comm_port}")
self.poller = zmq.Poller()
self.poller.register(self.router_socket, zmq.POLLIN)
@@ -90,17 +95,17 @@ class SplitwiseConnector:
if not socks:
continue
else:
logger.debug(f"receive {socks}")
self.logger.debug(f"receive {socks}")
frames = self.router_socket.recv_multipart()
logger.debug(f"frames: {frames}")
self.logger.debug(f"frames: {frames}")
message = frames[-1]
self.io_executor.submit(self._process_message, message)
time.sleep(0.001)
else:
time.sleep(5)
except Exception as e:
logger.error(f"Receiver error: {e}, {str(traceback.format_exc())}")
self.logger.error(f"Receiver error: {e}, {str(traceback.format_exc())}")
time.sleep(1)
def _get_push_socket(self, addr):
@@ -112,7 +117,7 @@ class SplitwiseConnector:
return sock
try:
logger.info(f"Establishing new connection to {addr}")
self.logger.info(f"Establishing new connection to {addr}")
sock = self.zmq_ctx.socket(zmq.DEALER)
# Set connection parameters
@@ -131,7 +136,7 @@ class SplitwiseConnector:
return sock
except zmq.ZMQError as e:
logger.error(f"Connection to {addr} failed: {e}")
self.logger.error(f"Connection to {addr} failed: {e}")
raise ConnectionError(f"Failed to connect to {addr}") from e
@@ -140,7 +145,7 @@ class SplitwiseConnector:
return
try:
logger.info(f"Sent {msg_type} to {addr}")
self.logger.info(f"Sent {msg_type} to {addr}")
message = self._serialize_message(msg_type, payload)
try:
@@ -148,19 +153,19 @@ class SplitwiseConnector:
sock = self._get_push_socket(addr)
sock.send_multipart([b"", message])
logger.info(f"Sent {msg_type} to {addr}")
self.logger.info(f"Sent {msg_type} to {addr}")
except ConnectionError:
logger.warning(f"Connection to {addr} not established")
self.logger.warning(f"Connection to {addr} not established")
except zmq.Again:
logger.warning(f"Send queue full for {addr}")
self.logger.warning(f"Send queue full for {addr}")
except Exception as e:
logger.error(f"Send to {addr} failed: {e}, {str(traceback.format_exc())}")
self.logger.error(f"Send to {addr} failed: {e}, {str(traceback.format_exc())}")
main_process_metrics.send_cache_failed_num.inc()
self._close_connection(addr)
except Exception as e:
logger.error(f"Message preparation failed: {e}")
self.logger.error(f"Message preparation failed: {e}")
def _close_connection(self, addr):
"""
@@ -265,7 +270,7 @@ class SplitwiseConnector:
f"{task.disaggregate_info['cache_info']['rdma']['ip']}:"
+ f"{task.disaggregate_info['cache_info']['rdma']['port']}"
)
logger.info(f"send splitwise tasks to port {addr} decode")
self.logger.info(f"send splitwise tasks to port {addr} decode")
self.current_request_ids[task.request_id] = "init"
decode_diagg = task.disaggregate_info["cache_info"]
task.disaggregate_info["cache_info"] = self.cfg.disaggregate_info["cache_info"]
@@ -295,7 +300,7 @@ class SplitwiseConnector:
self.connect_innode_instances[port].put_disaggregated_tasks(("decode", tasks))
for task in tasks:
task.disaggregate_info["cache_info"]["ipc"]["port"] = port
logger.info(f"send splitwise tasks to port {port} decode")
self.logger.info(f"send splitwise tasks to port {port} decode")
current_port = port
return current_port
@@ -305,7 +310,7 @@ class SplitwiseConnector:
"""
if not isinstance(tasks_list, list):
tasks_list = [tasks_list]
logger.info("send first token to port decode")
self.logger.info("send first token to port decode")
if prefill_msg["transfer_protocol"] == "ipc":
port = prefill_msg["cache_info"]["ipc"]["port"]
if port not in self.connect_innode_instances:
@@ -313,7 +318,7 @@ class SplitwiseConnector:
self.connect_innode_instances[port].put_disaggregated_tasks(("decode", tasks_list))
else:
node = f"{prefill_msg['cache_info']['rdma']['ip']}:{prefill_msg['cache_info']['rdma']['port']}"
logger.info(f"send first token to port {node} decode")
self.logger.info(f"send first token to port {node} decode")
self._send_message(node, "decode", tasks_list)
def create_connection(self, port):
@@ -329,6 +334,26 @@ class SplitwiseConnector:
client_id=0,
)
def check_decode_allocated(self, task):
start_time = time.time()
if task.disaggregate_info is None:
return True, ""
if self.enable_decode_cache_task:
return True, ""
if task.disaggregate_info["role"] != "prefill":
return True, ""
while self.current_request_ids[task.request_id] == "init":
time.sleep(0.001)
if time.time() - start_time > 30:
del self.current_request_ids[task.request_id]
return False, "timeout"
msg = self.current_request_ids[task.request_id]
del self.current_request_ids[task.request_id]
if msg == "finished":
return True, ""
self.logger.error(f"Receive_decode_allocated error: {msg}")
return False, msg
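
check_decode_allocated completes a small handshake: the request is marked "init" when dispatched to D, the cache_sync reply overwrites that status with "finished" or an error message, and P polls with a 30-second timeout. A compact, self-contained simulation of that flow:

import threading
import time

current_request_ids = {"req-0": "init"}   # set when the task is dispatched to D

def _decode_acks_later():
    time.sleep(0.05)
    # cache_sync handler: status comes from task.get("error_msg", "finished")
    current_request_ids["req-0"] = "finished"

def wait_for_decode(request_id, timeout=30.0):
    start = time.time()
    while current_request_ids[request_id] == "init":
        time.sleep(0.001)
        if time.time() - start > timeout:
            del current_request_ids[request_id]
            return False, "timeout"
    msg = current_request_ids.pop(request_id)
    return (msg == "finished"), ("" if msg == "finished" else msg)

threading.Thread(target=_decode_acks_later).start()
assert wait_for_decode("req-0") == (True, "")
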
def send_cache_infos(self, tasks, current_id):
"""
Send cache information to specific port.
@@ -345,7 +370,7 @@ class SplitwiseConnector:
for i in range(len(tasks)):
if tasks[i].disaggregate_info is None:
continue
logger.info(f"{tasks[i].disaggregate_info}")
self.logger.info(f"{tasks[i].disaggregate_info}")
if tasks[i].disaggregate_info["role"] == "decode":
if tasks[i].disaggregate_info["transfer_protocol"] == "ipc":
cache_info = {
@@ -380,6 +405,14 @@ class SplitwiseConnector:
addr = "prefill"
if current_id == -1:
current_id = tasks[i].disaggregate_info["cache_info"]["ipc"]["current_id"]
if envs.ENABLE_V1_KVCACHE_SCHEDULER:
cache_info = {
"request_id": tasks[i].request_id,
"src_block_ids": tasks[i].block_tables,
"current_id": tasks[i].idx,
"need_prefill_tokens": tasks[i].need_prefill_tokens,
}
else:
cache_info = {
"request_id": tasks[i].request_id,
"src_block_ids": tasks[i].block_tables,
@@ -396,7 +429,7 @@ class SplitwiseConnector:
else:
if len(temp_cache_info):
for k, v in temp_cache_info.items():
logger.info(f"{k} {v}")
self.logger.info(f"{k} {v}")
if ":" in str(k):
self._send_message(k, "cache_sync", v)
else:
@@ -427,7 +460,7 @@ class SplitwiseConnector:
"""
try:
msg_type, payload = self._deserialize_message(message)
logger.info(f"{msg_type}")
self.logger.info(f"{msg_type}")
if msg_type == "prefill":
self._handle_prefill(payload)
@@ -435,11 +468,16 @@ class SplitwiseConnector:
self._handle_decode(payload)
elif msg_type == "cache_sync":
for task in payload:
self.logger.info(f"cache_sync task: {task}")
current_status = task.get("error_msg", "finished")
self.current_request_ids[task["request_id"]] = current_status
if self.enable_decode_cache_task:
del self.current_request_ids[task["request_id"]]
if current_status == "finished":
self.engine_worker_queue.put_cache_info(payload)
except Exception as e:
logger.error(f"Message processing failed: {e}, {str(traceback.format_exc())}")
self.logger.error(f"Message processing failed: {e}, {str(traceback.format_exc())}")
def _handle_prefill(self, tasks):
"""
@@ -462,8 +500,12 @@ class SplitwiseConnector:
index=task["outputs"]["index"],
send_idx=0,
token_ids=task["outputs"]["token_ids"],
draft_token_ids=task["outputs"]["draft_token_ids"],
),
finished=True,
num_cached_tokens=task["num_cached_tokens"],
error_code=task["error_code"],
error_msg=task["error_msg"],
)
)
self.engine_worker_queue.put_disaggregated_tasks(("decode", tasks))

View File

@@ -16,6 +16,7 @@
import argparse
import json
import os
import time
from typing import Tuple
@@ -259,6 +260,7 @@ class PaddleDisWorkerProc:
"""Main event loop for Paddle Distributed Workers.
TODO(gongshaotian): support remote calling of functions that control worker.
"""
# Currently, only a single node is supported
self.nnode = int((self.parallel_config.tensor_parallel_size + 7) // 8)
req_ids = []
@@ -643,6 +645,12 @@ def parse_args():
help="Flag to specify dtype of lm_head as FP32",
)
parser.add_argument(
"--cache-transfer-protocol",
type=str,
default="ipc",
help="support protocol list, comma separated, default is ipc",
)
parser.add_argument(
"--runner",
type=str,
@@ -762,8 +770,7 @@ def initialize_fd_config(args, ranks: int = 1, local_rank: int = 0) -> FDConfig:
):
logger.info("Set ENABLE_V1_KVCACHE_SCHEDULER to 0 due to not support speculative decoding now.")
envs.ENABLE_V1_KVCACHE_SCHEDULER = 0
if args.splitwise_role != "mixed":
logger.info(f"Set ENABLE_V1_KVCACHE_SCHEDULER to 0 due to not supported {args.splitwise_role} now.")
if args.splitwise_role != "mixed" and args.cache_transfer_protocol != "rdma":
envs.ENABLE_V1_KVCACHE_SCHEDULER = 0
if not current_platform.is_cuda():
logger.info("Set ENABLE_V1_KVCACHE_SCHEDULER to 0 due to not supported.")
@@ -772,6 +779,9 @@ def initialize_fd_config(args, ranks: int = 1, local_rank: int = 0) -> FDConfig:
logger.info("Set ENABLE_V1_KVCACHE_SCHEDULER to 0 due to not supported guided_decoding.")
envs.ENABLE_V1_KVCACHE_SCHEDULER = 0
if envs.ENABLE_V1_KVCACHE_SCHEDULER and args.splitwise_role == "prefill":
os.environ["PREFILL_NODE_ONE_STEP_STOP_V1"] = "1"
fd_config = FDConfig(
model_config=model_config,
parallel_config=parallel_config,