mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-05 08:37:06 +08:00
[Feature] Support pd ep deployment with yiyan adapter (#4029)
* [Feature] Support mixed deployment with yiyan adapter in release2.2 * fix metrics * add unit test * add unit test * add unit test * Support pd ep deployment with yiyan adapter * Support pd ep deployment with yiyan adapter * refactor cache messager * support scheduler v1 in PD * suppport pd v1 + chunk prefill * suppport pd v1 + chunk prefill * add eplb * support eplb * support eplb * support eplb * support v1 * fix * fix * fix bug * remove eplb support * support prefix cache in P * fix bug * fix bug * support one stop in V1 * fix bug * fix ci * fix ci * fix * fix * fix * fix * fix --------- Co-authored-by: YuBaoku <49938469+EmmonsCurse@users.noreply.github.com>
This commit is contained in:
@@ -32,7 +32,8 @@ __global__ void update_inputs_kernel_v1(bool *not_need_stop,
|
||||
const int max_bsz,
|
||||
const int input_ids_stride,
|
||||
const int block_num_per_seq,
|
||||
const int block_size) {
|
||||
const int block_size,
|
||||
bool prefill_one_step_stop) {
|
||||
int thread_idx = threadIdx.x;
|
||||
typedef cub::BlockReduce<int64_t, THREADBLOCK_SIZE> BlockReduce;
|
||||
__shared__ typename BlockReduce::TempStorage temp_storage;
|
||||
@@ -54,23 +55,32 @@ __global__ void update_inputs_kernel_v1(bool *not_need_stop,
|
||||
seq_lens_encoder[thread_idx] = 0;
|
||||
} else {
|
||||
if (seq_lens_this_time[thread_idx] + seq_lens_decoder[thread_idx] >= prompt_lens[thread_idx]) {
|
||||
// decoding
|
||||
seq_lens_decoder[thread_idx] += seq_lens_this_time[thread_idx];
|
||||
seq_lens_this_time[thread_idx] = 1;
|
||||
seq_lens_encoder[thread_idx] = 0;
|
||||
int64_t *input_ids_now = input_ids + thread_idx * input_ids_stride;
|
||||
input_ids_now[0] = next_tokens[thread_idx];
|
||||
if (prefill_one_step_stop) {
|
||||
// prefill done, stop
|
||||
stop_flags[thread_idx] = true;
|
||||
seq_lens_this_time[thread_idx] = 0;
|
||||
seq_lens_decoder[thread_idx] = 0;
|
||||
seq_lens_encoder[thread_idx] = 0;
|
||||
stop_flag_now_int = 1;
|
||||
} else{
|
||||
// decoding
|
||||
seq_lens_decoder[thread_idx] += seq_lens_this_time[thread_idx];
|
||||
seq_lens_this_time[thread_idx] = 1;
|
||||
seq_lens_encoder[thread_idx] = 0;
|
||||
int64_t *input_ids_now = input_ids + thread_idx * input_ids_stride;
|
||||
input_ids_now[0] = next_tokens[thread_idx];
|
||||
|
||||
// to judge whether block is not enough
|
||||
int *block_table_now = block_tables + thread_idx * block_num_per_seq;
|
||||
if (seq_lens_this_time[thread_idx] != 0 && block_table_now[seq_lens_decoder[thread_idx] / block_size] == -1) {
|
||||
// should be scheduled by server
|
||||
is_block_step[thread_idx] = true;
|
||||
seq_lens_this_time[thread_idx]= 0;
|
||||
stop_flags[thread_idx] = true;
|
||||
step_seq_lens_decoder[thread_idx] = seq_lens_decoder[thread_idx];
|
||||
seq_lens_decoder[thread_idx] = 0;
|
||||
stop_flag_now_int = 1;
|
||||
// to judge whether block is not enough
|
||||
int *block_table_now = block_tables + thread_idx * block_num_per_seq;
|
||||
if (seq_lens_this_time[thread_idx] != 0 && block_table_now[seq_lens_decoder[thread_idx] / block_size] == -1) {
|
||||
// should be scheduled by server
|
||||
is_block_step[thread_idx] = true;
|
||||
seq_lens_this_time[thread_idx]= 0;
|
||||
stop_flags[thread_idx] = true;
|
||||
step_seq_lens_decoder[thread_idx] = seq_lens_decoder[thread_idx];
|
||||
seq_lens_decoder[thread_idx] = 0;
|
||||
stop_flag_now_int = 1;
|
||||
}
|
||||
}
|
||||
} else
|
||||
{
|
||||
@@ -110,6 +120,12 @@ void UpdateInputesV1(const paddle::Tensor &stop_flags,
|
||||
#else
|
||||
auto cu_stream = input_ids.stream();
|
||||
#endif
|
||||
bool prefill_one_step_stop = false;
|
||||
if (const char *env_p = std::getenv("PREFILL_NODE_ONE_STEP_STOP_V1")) {
|
||||
if (env_p[0] == '1') {
|
||||
prefill_one_step_stop = true;
|
||||
}
|
||||
}
|
||||
const int max_bsz = stop_flags.shape()[0];
|
||||
const int now_bsz = seq_lens_this_time.shape()[0];
|
||||
const int input_ids_stride = input_ids.shape()[1];
|
||||
@@ -133,7 +149,8 @@ void UpdateInputesV1(const paddle::Tensor &stop_flags,
|
||||
max_bsz,
|
||||
input_ids_stride,
|
||||
block_num_per_seq,
|
||||
block_size);
|
||||
block_size,
|
||||
prefill_one_step_stop);
|
||||
auto not_need_stop_cpu =
|
||||
not_need_stop_gpu.copy_to(not_need_stop.place(), false);
|
||||
bool *not_need_stop_data = const_cast<bool *>(not_need_stop.data<bool>());
|
||||
|
@@ -14,7 +14,10 @@
|
||||
# limitations under the License.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import math
|
||||
import queue
|
||||
import threading
|
||||
import time
|
||||
import traceback
|
||||
@@ -23,16 +26,72 @@ import numpy as np
|
||||
import paddle
|
||||
|
||||
from fastdeploy.cache_manager.transfer_factory import IPCCommManager, RDMACommManager
|
||||
from fastdeploy.config import SpeculativeConfig
|
||||
from fastdeploy.inter_communicator import (
|
||||
EngineWorkerQueue,
|
||||
IPCSignal,
|
||||
shared_memory_exists,
|
||||
)
|
||||
from fastdeploy.utils import get_logger
|
||||
from fastdeploy.model_executor.ops.gpu import get_output_kv_signal, set_data_ipc
|
||||
from fastdeploy.utils import envs, get_logger
|
||||
|
||||
logger = get_logger("cache_messager", "cache_messager.log")
|
||||
|
||||
|
||||
def parse_args():
|
||||
"""
|
||||
从命令行解析参数
|
||||
"""
|
||||
parser = argparse.ArgumentParser("Cache Messager")
|
||||
parser.add_argument(
|
||||
"--splitwise_role",
|
||||
type=str,
|
||||
default="mixed",
|
||||
help="splitwise role, can be decode, prefill or mixed",
|
||||
)
|
||||
parser.add_argument("--rank", type=int, default=0, help="current rank")
|
||||
parser.add_argument("--device_id", type=int, default=0, help="device id")
|
||||
parser.add_argument("--num_layers", type=int, default=1, help="model num layers")
|
||||
parser.add_argument("--head_dim", type=int, default=1, help="model head dim")
|
||||
parser.add_argument("--kv_num_head", type=int, default=1, help="model kv num head")
|
||||
parser.add_argument("--rdma_port", type=str, default="", help="rmda port")
|
||||
parser.add_argument("--mp_num", type=int, default=1, help="number of model parallel")
|
||||
parser.add_argument("--engine_pid", type=str, default=None, help="engine pid")
|
||||
parser.add_argument(
|
||||
"--protocol",
|
||||
type=str,
|
||||
default="ipc",
|
||||
help="cache transfer protocol, only surport ipc now",
|
||||
)
|
||||
parser.add_argument("--pod_ip", type=str, default="0.0.0.0", help="pod ip")
|
||||
parser.add_argument("--cache_queue_port", type=int, default=9924, help="cache queue port")
|
||||
parser.add_argument(
|
||||
"--engine_worker_queue_port",
|
||||
type=int,
|
||||
default=9923,
|
||||
help="engine worker queue port",
|
||||
)
|
||||
parser.add_argument("--num_gpu_blocks", type=int, default=1, help="gpu cache block number")
|
||||
parser.add_argument("--block_size", type=int, default=64, help="cache block size(tokens)")
|
||||
parser.add_argument(
|
||||
"--cache_dtype",
|
||||
type=str,
|
||||
default="bfloat16",
|
||||
choices=["uint8", "bfloat16"],
|
||||
help="cache dtype",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--speculative_config",
|
||||
type=json.loads,
|
||||
default="{}",
|
||||
help="speculative config",
|
||||
)
|
||||
parser.add_argument("--local_data_parallel_id", type=int, default=0)
|
||||
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
|
||||
class CacheMessager:
|
||||
"""
|
||||
CacheMessager is used to send the cache data between the engine worker and the cache server.
|
||||
@@ -69,11 +128,6 @@ class CacheMessager:
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
|
||||
assert splitwise_role in [
|
||||
"prefill",
|
||||
"decode",
|
||||
], "splitwise_role must be prefill or decode"
|
||||
self.splitwise_role = splitwise_role
|
||||
self.gpu_cache_kvs = gpu_cache_kvs
|
||||
self.rank = rank
|
||||
@@ -147,15 +201,16 @@ class CacheMessager:
|
||||
|
||||
self.gpu_id = gpu_id
|
||||
self.cache_info = dict()
|
||||
self.dp_rank_id = self.rank + local_data_parallel_id * self.nranks
|
||||
self.rank_id = self.rank + local_data_parallel_id * self.nranks
|
||||
|
||||
layerwise_send_cache_thread = threading.Thread(target=self._prefill_layerwise_send_cache_thread)
|
||||
layerwise_send_cache_thread.daemon = True
|
||||
layerwise_send_cache_thread.start()
|
||||
if self.splitwise_role != "mixed":
|
||||
connect_rdma_thread = threading.Thread(target=self._handle_connect_task)
|
||||
connect_rdma_thread.daemon = True
|
||||
connect_rdma_thread.start()
|
||||
|
||||
logger.info(f"cache messager init finished, use {transfer_protocol}")
|
||||
|
||||
def _prefill_layerwise_send_cache_thread(self):
|
||||
def prefill_layerwise_send_cache_thread(self):
|
||||
"""
|
||||
layerwise_send_cache_thread:
|
||||
send cache to other instance
|
||||
@@ -163,23 +218,23 @@ class CacheMessager:
|
||||
try:
|
||||
prefilled_step_idx_data = np.zeros(shape=[1], dtype=np.int32)
|
||||
prefilled_layer_idx_data = np.zeros(shape=[1], dtype=np.int32)
|
||||
prefilled_layer_name = f"splitwise_complete_prefilled_layer_{self.dp_rank_id}.{self.gpu_id}"
|
||||
prefilled_step_name = f"splitwise_complete_prefilled_step_{self.dp_rank_id}.{self.gpu_id}"
|
||||
prefilled_layer_name = f"splitwise_complete_prefilled_step_{self.rank_id}.{self.gpu_id}"
|
||||
prefilled_step_name = f"splitwise_complete_prefilled_step_{self.rank_id}.{self.gpu_id}"
|
||||
step_shm_value = IPCSignal(
|
||||
name=f"splitwise_complete_prefilled_step_{self.dp_rank_id}",
|
||||
name=f"splitwise_complete_prefilled_step_{self.rank_id}",
|
||||
array=prefilled_step_idx_data,
|
||||
dtype=np.int32,
|
||||
suffix=self.gpu_id,
|
||||
create=not shared_memory_exists(prefilled_step_name),
|
||||
)
|
||||
layer_shm_value = IPCSignal(
|
||||
name=f"splitwise_complete_prefilled_layer_{self.dp_rank_id}",
|
||||
name=f"splitwise_complete_prefilled_layer_{self.rank_id}",
|
||||
array=prefilled_layer_idx_data,
|
||||
dtype=np.int32,
|
||||
suffix=self.gpu_id,
|
||||
create=not shared_memory_exists(prefilled_layer_name),
|
||||
)
|
||||
logger.info(f"splitwise_complete_prefilled_step_{self.dp_rank_id}, gpu_id: {self.gpu_id}")
|
||||
logger.info(f"splitwise_complete_prefilled_step_{self.rank_id}, gpu_id: {self.gpu_id}")
|
||||
|
||||
step_shm_value.value[0] = -1
|
||||
layer_shm_value.value[0] = -1
|
||||
@@ -187,6 +242,9 @@ class CacheMessager:
|
||||
self.last_step_idx = -1
|
||||
self.last_layer_idx = -1 # int32
|
||||
|
||||
max_step_idx = 100003
|
||||
engine_recycled_count = 0
|
||||
|
||||
while True:
|
||||
|
||||
cache_info = self.engine_worker_queue.get_cache_info()
|
||||
@@ -202,11 +260,9 @@ class CacheMessager:
|
||||
-len(current_info["dest_block_ids"]) :
|
||||
]
|
||||
current_info["src_block_ids"] = current_src_blocks
|
||||
current_info["current_layer_ids"] = 0
|
||||
current_info["status"] = "init"
|
||||
logger.info(f"start cache_infos: {current_info}")
|
||||
self.cache_info[info["request_id"]] = current_info
|
||||
self.last_step_idx = min(self.last_step_idx, current_info["current_id"])
|
||||
else:
|
||||
self.cache_info[info["request_id"]] = info
|
||||
prefilled_layer_idx = layer_shm_value.value[0]
|
||||
@@ -223,7 +279,18 @@ class CacheMessager:
|
||||
if not self.cache_info:
|
||||
time.sleep(0.001)
|
||||
continue
|
||||
logger.debug(f"prefilled_layer_idx: {prefilled_layer_idx}, prefilled_step_idx: {prefilled_step_idx}")
|
||||
|
||||
if self.last_step_idx > prefilled_step_idx:
|
||||
engine_recycled_count += 1
|
||||
self.last_step_idx = prefilled_step_idx # only copy value read from shm memory
|
||||
prefilled_step_idx = (
|
||||
prefilled_step_idx + max_step_idx * engine_recycled_count
|
||||
) # remap prefilled_step_idx for comparison
|
||||
|
||||
logger.debug(
|
||||
f"prefilled_layer_idx: {prefilled_layer_idx}, prefilled_step_idx in shm: {self.last_step_idx},"
|
||||
f"prefilled_step_idx: {prefilled_step_idx} engine_recycled_count {engine_recycled_count}"
|
||||
)
|
||||
for req_id, item in list(self.cache_info.items()):
|
||||
if "status" not in item:
|
||||
continue
|
||||
@@ -294,12 +361,493 @@ class CacheMessager:
|
||||
logger.info(f"finish write cache {item['request_id']}")
|
||||
self.engine_worker_queue.finish_request_barrier.wait()
|
||||
if self.rank == 0:
|
||||
# to do: robust in TP: here we assume all status in tp are the same. If wrong, all wrong. If ok, all ok.
|
||||
self.engine_worker_queue.put_finished_req([(item["request_id"], "finished")])
|
||||
logger.info(f"put write cache {item['request_id']}")
|
||||
del self.cache_info[req_id]
|
||||
|
||||
self.last_step_idx = prefilled_step_idx
|
||||
self.last_layer_idx = prefilled_layer_idx
|
||||
self.last_layer_idx = prefilled_layer_idx
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"prefill layerwise send cache thread has exception: {e}, {str(traceback.format_exc())}")
|
||||
|
||||
def _handle_connect_task(self):
|
||||
while True:
|
||||
try:
|
||||
task = self.engine_worker_queue.get_connect_rdma_task()
|
||||
if task is None:
|
||||
time.sleep(0.001)
|
||||
continue
|
||||
logger.info(f"_handle_connect_task recv task: {task}")
|
||||
task_id = task["task_id"]
|
||||
ip, rdma_port = task["ip"], task["rdma_port"]
|
||||
status = self.messager["rdma"].connect(ip, rdma_port)
|
||||
if not status:
|
||||
response = {"task_id": task_id, "success": False}
|
||||
else:
|
||||
response = {"task_id": task_id, "success": True}
|
||||
self.engine_worker_queue.put_connect_rdma_task_response(response)
|
||||
except Exception as e:
|
||||
logger.error(f"handle_connect_task has exception: {e}")
|
||||
|
||||
|
||||
class CacheMessagerV1:
|
||||
"""
|
||||
CacheMessager is used to send the cache data between the engine worker and the cache server.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
splitwise_role,
|
||||
transfer_protocol,
|
||||
pod_ip,
|
||||
engine_worker_queue_port,
|
||||
local_data_parallel_id,
|
||||
gpu_cache_kvs,
|
||||
rank,
|
||||
nranks,
|
||||
num_layers,
|
||||
gpu_id=0,
|
||||
block_size=64,
|
||||
rdma_port=None,
|
||||
):
|
||||
"""
|
||||
Initialize the CacheMessager object.
|
||||
|
||||
Args:
|
||||
splitwise_role (str): splitwise_role only can be 'prefill' or 'decode'.
|
||||
transfer_protocol (str): support ipc and rdma
|
||||
engine_worker_queue_port (int): engine_worker_queue port
|
||||
gpu_cache_kvs (dict): GPU kv cache
|
||||
rank (int): current rank
|
||||
nranks (int): global rank number
|
||||
num_layers (int): model layer number
|
||||
gpu_id (int, optional): GPU ID
|
||||
rdma_port (int, optional): RDMA port
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
self.splitwise_role = splitwise_role
|
||||
self.gpu_cache_kvs = gpu_cache_kvs
|
||||
self.rank = rank
|
||||
self.nranks = nranks
|
||||
address = (pod_ip, engine_worker_queue_port)
|
||||
self.engine_worker_queue = EngineWorkerQueue(
|
||||
address=address,
|
||||
is_server=False,
|
||||
num_client=self.nranks,
|
||||
client_id=self.rank,
|
||||
local_data_parallel_id=local_data_parallel_id,
|
||||
)
|
||||
self.block_size = block_size
|
||||
transfer_protocol = transfer_protocol.split(",")
|
||||
|
||||
logger.info(f"splitwise role: {splitwise_role}, {transfer_protocol}" f"rank: {rank}")
|
||||
|
||||
# 1. initialize the cache_k_ptr_list and cache_v_ptr_list
|
||||
self.num_layers = num_layers
|
||||
cache_k_ptr_list = []
|
||||
cache_v_ptr_list = []
|
||||
cache_k = []
|
||||
cache_v = []
|
||||
self.messager = {}
|
||||
for layer_idx in range(self.num_layers):
|
||||
key_cache = self.gpu_cache_kvs[f"key_caches_{layer_idx}_rank{self.rank}_device{gpu_id}"]
|
||||
val_cache = self.gpu_cache_kvs[f"value_caches_{layer_idx}_rank{self.rank}_device{gpu_id}"]
|
||||
cache_k.append(key_cache)
|
||||
cache_v.append(val_cache)
|
||||
cache_k_ptr_list.append(key_cache.data_ptr())
|
||||
cache_v_ptr_list.append(val_cache.data_ptr())
|
||||
cache_k_ptr_list = np.array(cache_k_ptr_list)
|
||||
cache_v_ptr_list = np.array(cache_v_ptr_list)
|
||||
|
||||
# 2. initialize the block_bytes
|
||||
cache_shape = key_cache.shape
|
||||
max_block_num = cache_shape[0]
|
||||
block_bytes = math.prod(cache_shape[1:])
|
||||
if key_cache.dtype == paddle.bfloat16:
|
||||
block_bytes *= 2
|
||||
logger.info(
|
||||
f"layers {num_layers} cache_shape: {cache_shape}, max_block_num: {max_block_num}, "
|
||||
f"block_bytes: {block_bytes}, dtype: {key_cache.dtype}"
|
||||
)
|
||||
self.block_bytes = block_bytes
|
||||
|
||||
# 3. initialize the messager
|
||||
for protocol in transfer_protocol:
|
||||
if protocol == "ipc":
|
||||
self.messager[protocol] = IPCCommManager(
|
||||
self.rank,
|
||||
gpu_id,
|
||||
cache_k,
|
||||
cache_v,
|
||||
)
|
||||
local_device_id = int(str(cache_k[0].place)[-2])
|
||||
logger.info(f"done create ipc_comm with local_device_id:{local_device_id}, ")
|
||||
|
||||
elif protocol == "rdma":
|
||||
logger.info(f"splitwise_role rdma: {self.splitwise_role}, rank: {self.rank}, gpu_id: {gpu_id}")
|
||||
|
||||
self.messager[protocol] = RDMACommManager(
|
||||
splitwise_role,
|
||||
rank,
|
||||
gpu_id,
|
||||
cache_k_ptr_list,
|
||||
cache_v_ptr_list,
|
||||
max_block_num,
|
||||
block_bytes,
|
||||
rdma_port,
|
||||
)
|
||||
|
||||
self.gpu_id = gpu_id
|
||||
self.cache_info = dict()
|
||||
self.rank_id = self.rank + local_data_parallel_id * self.nranks
|
||||
self.engine_cache_task_thread_lock = threading.Lock()
|
||||
self.engine_cache_tasks = [dict() for _ in range(512)]
|
||||
self.idx_cache_task_dict = {}
|
||||
self.cache_prefilled_engine_ids_queue = queue.Queue() # keep batch slot index for each prefill step
|
||||
if splitwise_role == "prefill":
|
||||
consume_signals_thread = threading.Thread(target=self.consume_signals)
|
||||
consume_signals_thread.daemon = True
|
||||
consume_signals_thread.start()
|
||||
add_cache_task_thread = threading.Thread(target=self._add_cache_task_thread)
|
||||
add_cache_task_thread.daemon = True
|
||||
add_cache_task_thread.start()
|
||||
|
||||
if self.splitwise_role != "mixed":
|
||||
connect_rdma_thread = threading.Thread(target=self._handle_connect_task)
|
||||
connect_rdma_thread.daemon = True
|
||||
connect_rdma_thread.start()
|
||||
|
||||
logger.info(f"cache messager init finished, use {transfer_protocol}")
|
||||
|
||||
def _add_cache_task_thread(self):
|
||||
while True:
|
||||
try:
|
||||
cache_info = self.engine_worker_queue.get_cache_info()
|
||||
self.engine_worker_queue.finish_add_cache_task_barrier.wait()
|
||||
finished_add_cache_task_req_ids = []
|
||||
if cache_info:
|
||||
for info in cache_info:
|
||||
if info["request_id"] in self.cache_info:
|
||||
self.cache_info[info["request_id"]].update(info)
|
||||
current_info = self.cache_info[info["request_id"]]
|
||||
assert "dest_block_ids" in current_info and "src_block_ids" in current_info
|
||||
finished_add_cache_task_req_ids.append(info["request_id"])
|
||||
decode_cached_block_num = len(current_info["src_block_ids"]) - len(
|
||||
current_info["dest_block_ids"]
|
||||
)
|
||||
padding_decode_block_ids = [-1 for i in range(decode_cached_block_num)] + current_info[
|
||||
"dest_block_ids"
|
||||
]
|
||||
current_info["dest_block_ids"] = padding_decode_block_ids
|
||||
current_info["decode_cached_tokens"] = decode_cached_block_num * self.block_size
|
||||
current_info["sended_layer_id"] = -1
|
||||
current_info["sended_block_num"] = current_info["decode_cached_tokens"] // self.block_size
|
||||
current_info["status"] = "init"
|
||||
logger.info(f"finish add cache task: {current_info}")
|
||||
self.cache_info[info["request_id"]] = current_info
|
||||
self.idx_cache_task_dict[current_info["current_id"]] = current_info
|
||||
else:
|
||||
self.cache_info[info["request_id"]] = info
|
||||
if self.rank == 0 and finished_add_cache_task_req_ids:
|
||||
self.engine_worker_queue.put_finished_add_cache_task_req(finished_add_cache_task_req_ids)
|
||||
else:
|
||||
time.sleep(0.001)
|
||||
except Exception as e:
|
||||
logger.info(f"add cache task occured error: {e}, {traceback.format_exc()!s}.")
|
||||
|
||||
def prefill_layerwise_send_cache_thread(self):
|
||||
"""
|
||||
layerwise_send_cache_thread:
|
||||
send cache to other instance
|
||||
"""
|
||||
while True:
|
||||
try:
|
||||
engine_indexes = self.cache_prefilled_engine_ids_queue.get()
|
||||
self.engine_worker_queue.finish_request_barrier.wait()
|
||||
block_start_end_list = []
|
||||
current_prefilled_token_num_list = []
|
||||
for engine_index in engine_indexes:
|
||||
assert engine_index in self.idx_cache_task_dict
|
||||
block_id_start = self.idx_cache_task_dict[engine_index]["sended_block_num"]
|
||||
prefilled_token_num = self.engine_cache_tasks[engine_index]["prefilled_token_num"]
|
||||
if (
|
||||
prefilled_token_num == self.idx_cache_task_dict[engine_index]["need_prefill_tokens"]
|
||||
): # all chunks have been prefilled
|
||||
block_id_end = len(self.idx_cache_task_dict[engine_index]["src_block_ids"])
|
||||
else:
|
||||
block_id_end = prefilled_token_num // self.block_size # [block_id_start, block_id_end)
|
||||
block_start_end_list.append((block_id_start, block_id_end))
|
||||
current_prefilled_token_num_list.append(prefilled_token_num)
|
||||
while True: # from layer0 to last layer
|
||||
sended_layer_idx = self.idx_cache_task_dict[engine_indexes[0]]["sended_layer_id"]
|
||||
start_layer_idx = sended_layer_idx + 1
|
||||
with self.engine_cache_task_thread_lock: # to check end_layer_idx
|
||||
prefilled_layer_idx = self.engine_cache_tasks[engine_indexes[0]]["prefilled_layer_idx"]
|
||||
if sended_layer_idx > prefilled_layer_idx: # computation must in next chunk
|
||||
logger.info(
|
||||
f"current_prefilled_token_num_list[0] {current_prefilled_token_num_list[0]} prefilled_token_num {self.engine_cache_tasks[engine_indexes[0]]['prefilled_token_num']}"
|
||||
)
|
||||
assert (
|
||||
current_prefilled_token_num_list[0]
|
||||
< self.engine_cache_tasks[engine_indexes[0]]["prefilled_token_num"]
|
||||
), "when sended_layer_idx > prefilled_layer_idx, must be in next chunk, but not, sth wrong"
|
||||
end_layer_idx = self.num_layers - 1 # [start_layer_idx, end_layer_idx)
|
||||
else:
|
||||
end_layer_idx = prefilled_layer_idx
|
||||
if sended_layer_idx == prefilled_layer_idx: # computation not in next layer
|
||||
time.sleep(0.01)
|
||||
for layer_idx in range(start_layer_idx, end_layer_idx + 1):
|
||||
for i, (block_id_start, block_id_end) in enumerate(block_start_end_list):
|
||||
engine_index = engine_indexes[i]
|
||||
task = self.idx_cache_task_dict[engine_index]
|
||||
req_id = task["request_id"]
|
||||
if (
|
||||
block_id_start >= block_id_end
|
||||
): # no blocks need to transfer for this request in this chunk
|
||||
task["sended_layer_id"] += 1
|
||||
assert task["sended_layer_id"] == layer_idx
|
||||
if task["sended_layer_id"] == self.num_layers - 1:
|
||||
task["sended_layer_id"] = -1
|
||||
continue
|
||||
else:
|
||||
current_transfer_protocol = task["transfer_protocol"]
|
||||
if task["transfer_protocol"] == "rdma":
|
||||
target_ip = task["ip"]
|
||||
target_id = int(task["rdma_ports"][self.rank])
|
||||
if task["status"] == "error":
|
||||
continue
|
||||
status = self.messager[current_transfer_protocol].connect(target_ip, target_id)
|
||||
if not status:
|
||||
logger.error(f"connect to {target_ip}:{target_id} failed")
|
||||
task["status"] = "connection error"
|
||||
continue
|
||||
elif task["transfer_protocol"] == "ipc":
|
||||
target_ip = "0.0.0.0"
|
||||
target_id = int(task["device_ids"][self.rank])
|
||||
|
||||
src_block_ids = task["src_block_ids"][block_id_start:block_id_end]
|
||||
dest_block_ids = task["dest_block_ids"][block_id_start:block_id_end]
|
||||
src_block_ids = paddle.to_tensor(src_block_ids, dtype="int32", place="cpu")
|
||||
dest_block_ids = paddle.to_tensor(dest_block_ids, dtype="int32", place="cpu")
|
||||
|
||||
logger.info(
|
||||
f"start write cache for a layer, {req_id}, {layer_idx}, {target_ip}, {target_id}, block_id_start {block_id_start} block_id_end {block_id_end}"
|
||||
)
|
||||
tic = time.time()
|
||||
return_code = self.messager[current_transfer_protocol].write_cache(
|
||||
target_ip,
|
||||
target_id,
|
||||
src_block_ids,
|
||||
dest_block_ids,
|
||||
layer_idx,
|
||||
)
|
||||
if return_code != 0:
|
||||
task["status"] = "write cache error"
|
||||
logger.error(
|
||||
f"write cache failed, layer_idx: {layer_idx}, req_id: {req_id}, dest_ip: {target_ip}, block_id_start {block_id_start} block_id_end {block_id_end}"
|
||||
)
|
||||
tok = time.time()
|
||||
cost_time = tok - tic
|
||||
block_num = len(src_block_ids)
|
||||
avg_time_per_block = cost_time * 1000 / block_num # ms
|
||||
send_cache_speed = block_num * self.block_bytes / 1073741824 / cost_time # GB/s
|
||||
logger.debug(
|
||||
f"finish write cache for a layer, {req_id}, {layer_idx}, {target_ip}, {target_id},"
|
||||
f"block_num: {block_num}, send_cache_speed(GB/s): {round(send_cache_speed, 5)},"
|
||||
f"avg_time per block(ms): {round(avg_time_per_block, 5)} block_id_start {block_id_start} block_id_end {block_id_end}"
|
||||
)
|
||||
|
||||
task["sended_layer_id"] += 1
|
||||
assert task["sended_layer_id"] == layer_idx
|
||||
if task["sended_layer_id"] == self.num_layers - 1:
|
||||
self.idx_cache_task_dict[engine_index]["sended_block_num"] += (
|
||||
block_id_end - block_id_start
|
||||
)
|
||||
if current_prefilled_token_num_list[i] == task["need_prefill_tokens"]:
|
||||
if task["status"] != "error":
|
||||
task["status"] = "finished"
|
||||
logger.info(
|
||||
f"finish write cache for all layers, req_id: {req_id}, block_id_end {block_id_end} need_prefill_tokens {task['need_prefill_tokens']}"
|
||||
)
|
||||
else:
|
||||
task["sended_layer_id"] = -1
|
||||
if end_layer_idx == self.num_layers - 1:
|
||||
with self.engine_cache_task_thread_lock:
|
||||
for engine_idx in engine_indexes:
|
||||
task = self.idx_cache_task_dict[engine_idx]
|
||||
if task["status"] == "finished" or ("error" in task["status"]):
|
||||
target_id = int(task["rdma_ports"][self.rank])
|
||||
if task["transfer_protocol"] == "ipc":
|
||||
self.messager["ipc"].write_block_by_sync(target_id)
|
||||
if self.rank == 0:
|
||||
# to do: robust in TP, here we assume all status in tp are the same. If wrong, all wrong. If ok, all ok.
|
||||
self.engine_worker_queue.put_finished_req(
|
||||
[(task["request_id"], task["status"])]
|
||||
)
|
||||
logger.info(f"put write cache {task['request_id']}, status {task['status']}")
|
||||
self.engine_cache_tasks[task["current_id"]] = dict()
|
||||
del self.cache_info[task["request_id"]]
|
||||
del self.idx_cache_task_dict[task["current_id"]]
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"prefill layerwise send cache thread has exception: {e} {traceback.format_exc()!s}")
|
||||
time.sleep(0.01)
|
||||
|
||||
def consume_signals(self):
|
||||
paddle.device.set_device("cpu")
|
||||
kv_signal_data = paddle.full(shape=[512 * 3 + 2], fill_value=-1, dtype="int32")
|
||||
while True:
|
||||
try:
|
||||
get_output_kv_signal(kv_signal_data, self.rank_id, 0) # wait_flag
|
||||
if not self.cache_info:
|
||||
time.sleep(0.01)
|
||||
continue
|
||||
tasks_count = kv_signal_data[0]
|
||||
if tasks_count == -1:
|
||||
time.sleep(0.001)
|
||||
continue
|
||||
layer_id = kv_signal_data[1].numpy().tolist()
|
||||
if layer_id == self.num_layers - 1:
|
||||
logger.info(f"tasks_count: {tasks_count}, layer_id: {layer_id}")
|
||||
batch_engine_ids = []
|
||||
with self.engine_cache_task_thread_lock:
|
||||
for bi in range(tasks_count):
|
||||
engine_idx = kv_signal_data[3 * bi + 2].numpy().tolist()
|
||||
chuck_token_offset = kv_signal_data[3 * bi + 3].numpy().tolist()
|
||||
current_seq_len = kv_signal_data[3 * bi + 4].numpy().tolist()
|
||||
self.engine_cache_tasks[engine_idx]["prefilled_layer_idx"] = layer_id
|
||||
self.engine_cache_tasks[engine_idx]["prefilled_token_num"] = (
|
||||
chuck_token_offset + current_seq_len
|
||||
)
|
||||
batch_engine_ids.append(engine_idx)
|
||||
if layer_id == 0:
|
||||
self.cache_prefilled_engine_ids_queue.put(batch_engine_ids)
|
||||
except Exception as e:
|
||||
logger.error(f"Consume signals get exception: {e}")
|
||||
|
||||
def _handle_connect_task(self):
|
||||
while True:
|
||||
try:
|
||||
task = self.engine_worker_queue.get_connect_rdma_task()
|
||||
if task is None:
|
||||
time.sleep(0.001)
|
||||
continue
|
||||
logger.info(f"_handle_connect_task recv task: {task}")
|
||||
task_id = task["task_id"]
|
||||
ip, rdma_port = task["ip"], task["rdma_port"]
|
||||
status = self.messager["rdma"].connect(ip, rdma_port)
|
||||
if not status:
|
||||
response = {"task_id": task_id, "success": False}
|
||||
else:
|
||||
response = {"task_id": task_id, "success": True}
|
||||
self.engine_worker_queue.put_connect_rdma_task_response(response)
|
||||
except Exception as e:
|
||||
logger.error(f"handle_connect_task has exception: {e}")
|
||||
|
||||
|
||||
def main():
|
||||
device = args.device_id
|
||||
rank = args.rank
|
||||
paddle.set_device(f"gpu:{device}")
|
||||
cache_type = args.cache_dtype
|
||||
speculative_config = SpeculativeConfig(args.speculative_config)
|
||||
num_extra_layers = speculative_config.num_extra_cache_layer
|
||||
num_extra_layer_gpu_blocks = int(args.num_gpu_blocks * speculative_config.num_gpu_block_expand_ratio)
|
||||
gpu_cache_kvs = {}
|
||||
gpu_cache_k_tensors = []
|
||||
gpu_cache_v_tensors = []
|
||||
|
||||
for i in range(args.num_layers + num_extra_layers):
|
||||
num_gpu_blocks = args.num_gpu_blocks if i < args.num_layers else num_extra_layer_gpu_blocks
|
||||
|
||||
gpu_cache_kvs[f"key_caches_{i}_rank{rank}_device{device}"] = paddle.full(
|
||||
shape=[
|
||||
num_gpu_blocks,
|
||||
args.kv_num_head,
|
||||
args.block_size,
|
||||
args.head_dim,
|
||||
],
|
||||
fill_value=0,
|
||||
dtype=cache_type,
|
||||
)
|
||||
gpu_cache_k_tensors.append(gpu_cache_kvs[f"key_caches_{i}_rank{rank}_device{device}"])
|
||||
gpu_cache_kvs[f"value_caches_{i}_rank{rank}_device{device}"] = paddle.full(
|
||||
shape=[
|
||||
num_gpu_blocks,
|
||||
args.kv_num_head,
|
||||
args.block_size,
|
||||
args.head_dim,
|
||||
],
|
||||
fill_value=0,
|
||||
dtype=cache_type,
|
||||
)
|
||||
gpu_cache_v_tensors.append(gpu_cache_kvs[f"value_caches_{i}_rank{rank}_device{device}"])
|
||||
|
||||
set_data_ipc(
|
||||
gpu_cache_kvs[f"key_caches_{i}_rank{rank}_device{device}"],
|
||||
f"key_caches_{i}_rank{rank}.device{device}",
|
||||
)
|
||||
set_data_ipc(
|
||||
gpu_cache_kvs[f"value_caches_{i}_rank{rank}_device{device}"],
|
||||
f"value_caches_{i}_rank{rank}.device{device}",
|
||||
)
|
||||
cache_kv_size_byte = sum([tmp.numel() * 1 for key, tmp in gpu_cache_kvs.items()])
|
||||
logger.info(f"device :{device}")
|
||||
logger.info(f"cache_kv_size_byte : {cache_kv_size_byte}")
|
||||
logger.info(f"done init cache (full) gmem alloc : {paddle.device.cuda.memory_allocated()}")
|
||||
|
||||
if envs.ENABLE_V1_KVCACHE_SCHEDULER:
|
||||
cache_messager = CacheMessagerV1(
|
||||
splitwise_role=args.splitwise_role,
|
||||
transfer_protocol=args.protocol,
|
||||
pod_ip=args.pod_ip,
|
||||
engine_worker_queue_port=args.engine_worker_queue_port,
|
||||
local_data_parallel_id=args.local_data_parallel_id,
|
||||
gpu_cache_kvs=gpu_cache_kvs,
|
||||
rank=rank,
|
||||
nranks=args.mp_num,
|
||||
num_layers=args.num_layers + num_extra_layers,
|
||||
gpu_id=device,
|
||||
rdma_port=args.rdma_port,
|
||||
)
|
||||
else:
|
||||
cache_messager = CacheMessager(
|
||||
splitwise_role=args.splitwise_role,
|
||||
transfer_protocol=args.protocol,
|
||||
pod_ip=args.pod_ip,
|
||||
engine_worker_queue_port=args.engine_worker_queue_port,
|
||||
local_data_parallel_id=args.local_data_parallel_id,
|
||||
gpu_cache_kvs=gpu_cache_kvs,
|
||||
rank=rank,
|
||||
nranks=args.mp_num,
|
||||
num_layers=args.num_layers + num_extra_layers,
|
||||
gpu_id=device,
|
||||
rdma_port=args.rdma_port,
|
||||
)
|
||||
|
||||
cache_ready_signal_data = np.zeros(shape=[args.mp_num], dtype=np.int32)
|
||||
cache_ready_signal = IPCSignal(
|
||||
name="cache_ready_signal",
|
||||
array=cache_ready_signal_data,
|
||||
dtype=np.int32,
|
||||
suffix=args.engine_pid,
|
||||
create=False,
|
||||
)
|
||||
cache_ready_signal.value[rank] = 1
|
||||
if args.splitwise_role == "mixed":
|
||||
while True:
|
||||
time.sleep(1)
|
||||
cache_messager.prefill_layerwise_send_cache_thread()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
args = parse_args()
|
||||
rank_id = args.rank + args.local_data_parallel_id * args.mp_num
|
||||
logger = get_logger("cache_messager", f"cache_messager_rank{rank_id}.log")
|
||||
|
||||
logger.info("create cache messager...")
|
||||
logger.info(f"{args}")
|
||||
main()
|
||||
|
@@ -29,7 +29,7 @@ from fastdeploy.config import SpeculativeConfig
|
||||
from fastdeploy.inter_communicator import EngineCacheQueue, IPCSignal
|
||||
from fastdeploy.model_executor.ops.gpu import (
|
||||
cuda_host_alloc,
|
||||
set_data_ipc,
|
||||
share_external_data,
|
||||
swap_cache_all_layers,
|
||||
)
|
||||
from fastdeploy.utils import get_logger
|
||||
@@ -139,40 +139,27 @@ class CacheTransferManager:
|
||||
self.num_cpu_blocks = args.num_cpu_blocks
|
||||
|
||||
cache_type = args.cache_dtype
|
||||
cache_shape = [
|
||||
args.num_gpu_blocks,
|
||||
args.kv_num_head,
|
||||
args.block_size,
|
||||
args.head_dim,
|
||||
]
|
||||
|
||||
for i in range(args.num_layers + self.num_extra_layers):
|
||||
num_gpu_blocks = args.num_gpu_blocks if i < args.num_layers else self.num_extra_layer_gpu_blocks
|
||||
cache_shape[0] = num_gpu_blocks
|
||||
key_name = f"key_caches_{i}_rank{rank}.device{device}"
|
||||
value_name = f"value_caches_{i}_rank{rank}.device{device}"
|
||||
key_cache = paddle.empty(shape=[], dtype=cache_type)
|
||||
value_cache = paddle.empty(shape=[], dtype=cache_type)
|
||||
key_cache = share_external_data(key_cache, key_name, cache_shape)
|
||||
value_cache = share_external_data(value_cache, value_name, cache_shape)
|
||||
self.gpu_cache_kvs[key_name] = key_cache
|
||||
self.gpu_cache_kvs[value_name] = value_cache
|
||||
self.gpu_cache_k_tensors.append(self.gpu_cache_kvs[key_name])
|
||||
self.gpu_cache_v_tensors.append(self.gpu_cache_kvs[value_name])
|
||||
|
||||
self.gpu_cache_kvs[f"key_caches_{i}_rank{rank}_device{device}"] = paddle.full(
|
||||
shape=[
|
||||
num_gpu_blocks,
|
||||
args.kv_num_head,
|
||||
args.block_size,
|
||||
args.head_dim,
|
||||
],
|
||||
fill_value=0,
|
||||
dtype=cache_type,
|
||||
)
|
||||
self.gpu_cache_k_tensors.append(self.gpu_cache_kvs[f"key_caches_{i}_rank{rank}_device{device}"])
|
||||
self.gpu_cache_kvs[f"value_caches_{i}_rank{rank}_device{device}"] = paddle.full(
|
||||
shape=[
|
||||
num_gpu_blocks,
|
||||
args.kv_num_head,
|
||||
args.block_size,
|
||||
args.head_dim,
|
||||
],
|
||||
fill_value=0,
|
||||
dtype=cache_type,
|
||||
)
|
||||
self.gpu_cache_v_tensors.append(self.gpu_cache_kvs[f"value_caches_{i}_rank{rank}_device{device}"])
|
||||
|
||||
set_data_ipc(
|
||||
self.gpu_cache_kvs[f"key_caches_{i}_rank{rank}_device{device}"],
|
||||
f"key_caches_{i}_rank{rank}.device{device}",
|
||||
)
|
||||
set_data_ipc(
|
||||
self.gpu_cache_kvs[f"value_caches_{i}_rank{rank}_device{device}"],
|
||||
f"value_caches_{i}_rank{rank}.device{device}",
|
||||
)
|
||||
cache_kv_size_byte = sum([tmp.numel() * 1 for key, tmp in self.gpu_cache_kvs.items()])
|
||||
logger.info(f"device :{self.device}")
|
||||
logger.info(f"cache_kv_size_byte : {cache_kv_size_byte}")
|
||||
@@ -201,28 +188,6 @@ class CacheTransferManager:
|
||||
)
|
||||
self.cache_ready_signal.value[self.rank] = 1
|
||||
|
||||
paddle.set_device(f"gpu:{device}")
|
||||
if args.enable_splitwise:
|
||||
logger.debug("create cache messager...")
|
||||
logger.info(f"{args}")
|
||||
from fastdeploy.cache_manager.cache_messager import CacheMessager
|
||||
|
||||
self.cache_messager = CacheMessager(
|
||||
splitwise_role=args.splitwise_role,
|
||||
transfer_protocol=args.protocol,
|
||||
pod_ip=args.pod_ip,
|
||||
engine_worker_queue_port=args.engine_worker_queue_port,
|
||||
local_data_parallel_id=args.local_data_parallel_id,
|
||||
gpu_cache_kvs=self.gpu_cache_kvs,
|
||||
rank=self.rank,
|
||||
nranks=args.mp_num,
|
||||
num_layers=args.num_layers + self.num_extra_layers,
|
||||
gpu_id=self.device,
|
||||
rdma_port=args.rdma_port,
|
||||
)
|
||||
logger.info("successfully create cache messager")
|
||||
logger.info(f"done init CacheMessager gmem alloc : {paddle.device.cuda.memory_allocated()}")
|
||||
|
||||
cache_task_broadcast_data = np.zeros(shape=[1], dtype=np.int32)
|
||||
self.cache_task_broadcast_signal = IPCSignal(
|
||||
name="cache_task_broadcast_signal",
|
||||
@@ -443,5 +408,7 @@ def main():
|
||||
if __name__ == "__main__":
|
||||
|
||||
args = parse_args()
|
||||
logger = get_logger("cache_transfer_manager", "cache_transfer_manager.log")
|
||||
rank_id = args.rank + args.local_data_parallel_id * args.mp_num
|
||||
logger = get_logger("cache_transfer_manager", f"cache_transfer_manager_rank{rank_id}.log")
|
||||
paddle.set_device(f"gpu:{args.device_id}")
|
||||
main()
|
||||
|
@@ -150,6 +150,19 @@ class PrefixCacheManager:
|
||||
filename = "cache_transfer_manager.py"
|
||||
py_path = os.path.join(current_dir_path, filename)
|
||||
|
||||
cache_messager_processes = []
|
||||
cache_messager_processes = self.launch_cache_messager(
|
||||
cache_config,
|
||||
tensor_parallel_size,
|
||||
device_ids,
|
||||
pod_ip,
|
||||
engine_worker_queue_port,
|
||||
pid_suffix,
|
||||
)
|
||||
if cache_messager_processes is None:
|
||||
raise RuntimeError("Launch cache messager failed")
|
||||
return []
|
||||
|
||||
if (
|
||||
hasattr(cache_config.model_cfg, "num_key_value_heads")
|
||||
and hasattr(cache_config.model_cfg, "num_key_value_heads")
|
||||
@@ -213,7 +226,76 @@ class PrefixCacheManager:
|
||||
if cache_config.enable_hierarchical_cache and self.num_cpu_blocks > 0:
|
||||
logger.info("Enable hierarchical cache.")
|
||||
self._enable_cpu_cache()
|
||||
return cache_manager_processes
|
||||
all_cache_processes = cache_messager_processes + cache_manager_processes
|
||||
return all_cache_processes
|
||||
|
||||
def launch_cache_messager(
|
||||
self, cache_config, tensor_parallel_size, device_ids, pod_ip, engine_worker_queue_port, pid_suffix
|
||||
):
|
||||
"""
|
||||
launch_cache_messager function used to initialize the cache messager.
|
||||
"""
|
||||
current_dir_path = os.path.split(os.path.abspath(__file__))[0]
|
||||
filename = "cache_messager.py"
|
||||
if (
|
||||
hasattr(cache_config.model_cfg, "num_key_value_heads")
|
||||
and hasattr(cache_config.model_cfg, "num_key_value_heads")
|
||||
and cache_config.model_cfg.num_key_value_heads is not None
|
||||
and int(cache_config.model_cfg.num_key_value_heads) > 0
|
||||
):
|
||||
kv_num_head = int(cache_config.model_cfg.num_key_value_heads) // tensor_parallel_size
|
||||
else:
|
||||
kv_num_head = cache_config.model_cfg.num_attention_heads // tensor_parallel_size
|
||||
|
||||
cache_ready_signal_data = np.zeros(shape=[tensor_parallel_size], dtype=np.int32)
|
||||
self.cache_ready_signal = IPCSignal(
|
||||
name="cache_ready_signal",
|
||||
array=cache_ready_signal_data,
|
||||
dtype=np.int32,
|
||||
suffix=pid_suffix,
|
||||
create=True,
|
||||
)
|
||||
|
||||
py_path = os.path.join(current_dir_path, filename)
|
||||
log_dir = envs.FD_LOG_DIR
|
||||
cache_messager_processes = []
|
||||
for i in range(tensor_parallel_size):
|
||||
launch_cmd = (
|
||||
"FLAGS_allocator_strategy=auto_growth CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7"
|
||||
+ " NCCL_MAX_NCHANNELS=1 NCCL_BUFFSIZE=0"
|
||||
+ f" {sys.executable} {py_path}"
|
||||
+ f" --device_id {int(device_ids[i])}"
|
||||
+ f" --rank {i}"
|
||||
+ f" --splitwise_role {self.splitwise_role}"
|
||||
+ f" --num_layers {cache_config.model_cfg.num_hidden_layers}"
|
||||
+ f" --head_dim {cache_config.model_cfg.head_dim}"
|
||||
+ f" --kv_num_head {kv_num_head}"
|
||||
+ f" --mp_num {tensor_parallel_size}"
|
||||
+ f" --cache_dtype {cache_config.cache_dtype}"
|
||||
+ f" --pod_ip {pod_ip}"
|
||||
+ f" --cache_queue_port {cache_config.cache_queue_port}"
|
||||
+ f" --engine_worker_queue_port {engine_worker_queue_port}"
|
||||
+ f" --num_gpu_blocks {cache_config.total_block_num}"
|
||||
+ f" --block_size {cache_config.block_size}"
|
||||
+ f" --protocol {cache_config.cache_transfer_protocol}"
|
||||
+ f" --local_data_parallel_id {self.local_data_parallel_id}"
|
||||
+ f" --engine_pid {pid_suffix}"
|
||||
+ f" --rdma_port {cache_config.rdma_comm_ports[i] if cache_config.rdma_comm_ports is not None else '0'}"
|
||||
+ f" --speculative_config '{self.speculative_config.to_json_string()}'"
|
||||
+ f" >{log_dir}/launch_cache_messager_{int(device_ids[i])}.log 2>&1"
|
||||
)
|
||||
logger.info(f"Launch cache messager, command:{launch_cmd}")
|
||||
cache_messager_processes.append(subprocess.Popen(launch_cmd, shell=True, preexec_fn=os.setsid))
|
||||
logger.info("Waiting for cache ready...")
|
||||
while np.sum(self.cache_ready_signal.value) != tensor_parallel_size:
|
||||
time.sleep(1)
|
||||
exit_code = cache_messager_processes[-1].poll()
|
||||
if exit_code is None:
|
||||
logger.info("Launch cache messager successful")
|
||||
else:
|
||||
logger.info("Launch cache messager failed, see launch_cache_messager.log for more information")
|
||||
cache_messager_processes = None
|
||||
return cache_messager_processes
|
||||
|
||||
def update_cache_config(self, cache_config):
|
||||
"""
|
||||
|
@@ -61,18 +61,12 @@ class RDMACommManager:
|
||||
Connect to remote gpu and write cache.
|
||||
"""
|
||||
assert self.splitwise_role == "prefill", "only prefill can call this method"
|
||||
addr = f"{ip}:{port!s}"
|
||||
if addr in self.connected_rdma:
|
||||
return True
|
||||
ret = self.messager.is_connected(ip, str(port))
|
||||
if ret:
|
||||
self.connected_rdma.add(addr)
|
||||
return True
|
||||
|
||||
ret = self.messager.connect(ip, str(port))
|
||||
logger.info(f"connect to remote rdma address {ip}:{port} status is {ret}")
|
||||
if ret == 0:
|
||||
self.connected_rdma.add(addr)
|
||||
return ret == 0
|
||||
|
||||
def write_cache(self, ip, port, local_block_ids, remote_block_ids, layer_idx):
|
||||
|
@@ -1481,7 +1481,7 @@ class FDConfig:
|
||||
self.model_config.model_format = "torch"
|
||||
|
||||
# TODO
|
||||
self.max_prefill_batch = 3
|
||||
self.max_prefill_batch = int(os.getenv("MAX_PREFILL_NUM", "3"))
|
||||
if current_platform.is_xpu():
|
||||
self.max_prefill_batch = 1
|
||||
if self.model_config is not None and self.model_config.enable_mm:
|
||||
|
@@ -422,7 +422,7 @@ class EngineArgs:
|
||||
raise NotImplementedError("Only CUDA platform supports logprob.")
|
||||
if self.speculative_config is not None:
|
||||
envs.ENABLE_V1_KVCACHE_SCHEDULER = 0
|
||||
if self.splitwise_role != "mixed":
|
||||
if self.splitwise_role != "mixed" and self.cache_transfer_protocol != "rdma":
|
||||
envs.ENABLE_V1_KVCACHE_SCHEDULER = 0
|
||||
if not current_platform.is_cuda():
|
||||
envs.ENABLE_V1_KVCACHE_SCHEDULER = 0
|
||||
|
@@ -46,7 +46,7 @@ from fastdeploy.model_executor.guided_decoding import schema_checker
|
||||
from fastdeploy.plugins.token_processor import load_token_processor_plugins
|
||||
from fastdeploy.splitwise.internal_adapter_utils import InternalAdapter
|
||||
from fastdeploy.splitwise.splitwise_connector import SplitwiseConnector
|
||||
from fastdeploy.utils import EngineError, envs, llm_logger
|
||||
from fastdeploy.utils import EngineError, envs, get_logger, llm_logger
|
||||
|
||||
try:
|
||||
TokenProcessor = load_token_processor_plugins()
|
||||
@@ -69,6 +69,13 @@ class EngineService:
|
||||
"""
|
||||
self.cfg = cfg
|
||||
|
||||
if self.cfg.parallel_config.enable_expert_parallel:
|
||||
self.llm_logger = get_logger(
|
||||
"fastdeploy", f"fastdeploy_rank{self.cfg.parallel_config.local_data_parallel_id}.log"
|
||||
)
|
||||
else:
|
||||
self.llm_logger = llm_logger
|
||||
|
||||
self.scheduler = cfg.scheduler_config.scheduler()
|
||||
|
||||
if envs.ENABLE_V1_KVCACHE_SCHEDULER:
|
||||
@@ -79,10 +86,6 @@ class EngineService:
|
||||
cfg.scheduler_config.splitwise_role,
|
||||
cfg.parallel_config.local_data_parallel_id,
|
||||
)
|
||||
if cfg.scheduler_config.splitwise_role != "mixed":
|
||||
raise NotImplementedError(
|
||||
"Currently ENABLE_V1_KVCACHE_SCHEDULER=1 only supported in mixed sampling now."
|
||||
)
|
||||
else:
|
||||
self.resource_manager = ResourceManager(
|
||||
cfg.scheduler_config.max_num_seqs,
|
||||
@@ -135,12 +138,14 @@ class EngineService:
|
||||
self.insert_task_to_worker_thread.start()
|
||||
self.token_processor.tasks_queue = self.engine_worker_queue
|
||||
self.token_processor.run()
|
||||
if self.cfg.scheduler_config.splitwise_role != "mixed":
|
||||
self.split_mode_get_tasks()
|
||||
|
||||
def _init_worker_monitor_signals(self): # exist_task_signal 用于各worker进程感知是否有新Task需要处理
|
||||
current_suffix = int(
|
||||
self.cfg.parallel_config.engine_worker_queue_port[self.cfg.parallel_config.local_data_parallel_id]
|
||||
)
|
||||
llm_logger.info(f"current_suffix: {current_suffix}")
|
||||
self.llm_logger.info(f"current_suffix: {current_suffix}")
|
||||
exist_task_signal_data = np.zeros([1], dtype=np.int32)
|
||||
self.exist_task_signal = IPCSignal(
|
||||
name="exist_task_signal",
|
||||
@@ -201,7 +206,7 @@ class EngineService:
|
||||
)
|
||||
|
||||
if start_queue and (self.cfg.host_ip == self.cfg.master_ip or self.cfg.master_ip == "0.0.0.0"):
|
||||
llm_logger.info(f"Starting engine worker queue server service at {address}")
|
||||
self.llm_logger.info(f"Starting engine worker queue server service at {address}")
|
||||
self.engine_worker_queue_server = EngineWorkerQueue(
|
||||
address=address,
|
||||
is_server=True,
|
||||
@@ -225,7 +230,7 @@ class EngineService:
|
||||
client_id=-1,
|
||||
local_data_parallel_size=self.cfg.parallel_config.data_parallel_size,
|
||||
)
|
||||
llm_logger.info(
|
||||
self.llm_logger.info(
|
||||
f"local {min(self.cfg.worker_num_per_node * self.cfg.node_rank + self.cfg.parallel_config.local_data_parallel_id,self.cfg.parallel_config.data_parallel_size - 1)}"
|
||||
)
|
||||
self.engine_worker_queue = EngineWorkerQueue(
|
||||
@@ -254,7 +259,17 @@ class EngineService:
|
||||
cur_task_idx = self.resource_manager.req_dict[task.request_id]
|
||||
del self.resource_manager.req_dict[task.request_id]
|
||||
cur_task = self.resource_manager.tasks_list[cur_task_idx]
|
||||
if envs.FD_ENABLE_INTERNAL_ADAPTER:
|
||||
if not task.outputs.token_ids: # first token is eos in Prefill, just recycle resource and continue
|
||||
self.resource_manager.stop_flags[cur_task_idx] = True
|
||||
self.resource_manager.tasks_list[cur_task_idx] = None
|
||||
self.resource_manager._recycle_block_tables(cur_task)
|
||||
if task.request_id in self.token_processor.tokens_counter:
|
||||
del self.token_processor.tokens_counter[task.request_id]
|
||||
self.llm_logger.warning(f"{task.request_id} need not decode after first token")
|
||||
continue
|
||||
cur_task.prompt_token_ids[0] = task.outputs.token_ids[0]
|
||||
cur_task.num_cached_tokens = task.num_cached_tokens
|
||||
if (
|
||||
self.cfg.speculative_config.method in ["mtp"]
|
||||
and self.cfg.scheduler_config.splitwise_role == "decode"
|
||||
@@ -267,13 +282,14 @@ class EngineService:
|
||||
if task.request_id in self.token_processor.tokens_counter:
|
||||
del self.token_processor.tokens_counter[task.request_id]
|
||||
self.scheduler.put_results([task])
|
||||
llm_logger.warning(
|
||||
self.llm_logger.warning(
|
||||
f"{task.request_id} prefill failed with msg:{task.error_msg}, recycle resource."
|
||||
)
|
||||
continue
|
||||
self.token_processor.tokens_counter[task.request_id] = 1
|
||||
current_tasks.append(cur_task)
|
||||
self.engine_worker_queue.put_tasks((current_tasks, self.resource_manager.real_bsz))
|
||||
if current_tasks:
|
||||
self.engine_worker_queue.put_tasks((current_tasks, self.resource_manager.real_bsz))
|
||||
return True
|
||||
|
||||
self.resource_manager.check_and_free_block_tables()
|
||||
@@ -281,13 +297,34 @@ class EngineService:
|
||||
if not isinstance(tasks, list):
|
||||
tasks = [tasks]
|
||||
|
||||
need_delete_tasks = []
|
||||
for task in tasks:
|
||||
if self.cfg.scheduler_config.splitwise_role != "mixed":
|
||||
status, msg = self.split_connector.check_decode_allocated(task)
|
||||
if not status:
|
||||
self.llm_logger.error(f"{task.request_id} prefill failed with msg:{msg}.")
|
||||
self.scheduler.put_results(
|
||||
[
|
||||
RequestOutput(
|
||||
request_id=task.request_id,
|
||||
finished=True,
|
||||
error_code=500,
|
||||
error_msg=msg,
|
||||
)
|
||||
]
|
||||
)
|
||||
need_delete_tasks.append(task)
|
||||
continue
|
||||
for tmp_task in need_delete_tasks:
|
||||
tasks.remove(tmp_task)
|
||||
|
||||
for item in tasks:
|
||||
item.schedule_start_time = time.time()
|
||||
|
||||
available_batch = np.sum(self.resource_manager.stop_flags)
|
||||
if len(tasks) > available_batch:
|
||||
llm_logger.error(f"Inserting batch:{len(tasks)} exceeds the available batch:{available_batch}.")
|
||||
llm_logger.error("The exceeded part will be ignored!")
|
||||
self.llm_logger.error(f"Inserting batch:{len(tasks)} exceeds the available batch:{available_batch}.")
|
||||
self.llm_logger.error("The exceeded part will be ignored!")
|
||||
tasks = tasks[:available_batch]
|
||||
|
||||
req_ids = [t.request_id for t in tasks]
|
||||
@@ -296,7 +333,7 @@ class EngineService:
|
||||
|
||||
if not tasks:
|
||||
error_msg = f"The request required resources is exceed the limit, request id={req_ids}."
|
||||
llm_logger.error(error_msg)
|
||||
self.llm_logger.error(error_msg)
|
||||
raise EngineError(error_msg, error_code=500)
|
||||
return False
|
||||
|
||||
@@ -314,7 +351,7 @@ class EngineService:
|
||||
|
||||
self.split_connector.send_cache_infos(tasks, current_id)
|
||||
if not is_decode:
|
||||
llm_logger.info(f"Tasks are sent to engine, req_ids={req_ids}")
|
||||
self.llm_logger.info(f"Tasks are sent to engine, req_ids={req_ids}")
|
||||
for task in tasks:
|
||||
task.inference_start_time = time.time()
|
||||
if not is_prefill:
|
||||
@@ -473,7 +510,7 @@ class EngineService:
|
||||
Insert task to engine thread, monitor scheduler request queue.
|
||||
if the engine has resource, insert task to engine
|
||||
"""
|
||||
current_id = -1
|
||||
current_id = 0
|
||||
while getattr(self, "running", True):
|
||||
try:
|
||||
if self.resource_manager.available_batch() == 0:
|
||||
@@ -514,18 +551,21 @@ class EngineService:
|
||||
time.sleep(0.001)
|
||||
continue
|
||||
|
||||
current_id = (current_id + 1) % 100003
|
||||
if self.cfg.scheduler_config.splitwise_role != "mixed":
|
||||
llm_logger.info("Inserting splitwise tasks")
|
||||
self.llm_logger.info("Inserting splitwise tasks")
|
||||
self.split_connector.send_splitwise_tasks(tasks, current_id)
|
||||
|
||||
self.insert_tasks(tasks, current_id)
|
||||
insert_successful = self.insert_tasks(tasks, current_id)
|
||||
if insert_successful:
|
||||
current_id = current_id + 1
|
||||
else:
|
||||
continue
|
||||
|
||||
main_process_metrics.num_requests_waiting.dec(len(tasks))
|
||||
main_process_metrics.num_requests_running.inc(len(tasks))
|
||||
except Exception as e:
|
||||
err_msg = f"Error happened while insert task to engine: {e}, {traceback.format_exc()!s}."
|
||||
llm_logger.error(err_msg)
|
||||
err_msg = f"Error happend while insert task to engine: {e}, {traceback.format_exc()!s}."
|
||||
self.llm_logger.error(err_msg)
|
||||
|
||||
def _scheduler_task_to_worker_v1(self):
|
||||
"""
|
||||
@@ -535,40 +575,100 @@ class EngineService:
|
||||
is_fetching = False
|
||||
|
||||
def _fetch_request():
|
||||
nonlocal is_fetching
|
||||
is_fetching = True
|
||||
num_prefill_batch = min(
|
||||
int(self.resource_manager.available_batch()),
|
||||
self.cfg.max_prefill_batch,
|
||||
)
|
||||
if self.cfg.model_config.enable_mm:
|
||||
available_blocks = self.resource_manager.available_block_num()
|
||||
else:
|
||||
available_blocks = self.cfg.cache_config.max_block_num_per_seq
|
||||
try:
|
||||
nonlocal is_fetching
|
||||
is_fetching = True
|
||||
num_prefill_batch = min(
|
||||
int(self.resource_manager.available_batch()),
|
||||
self.cfg.max_prefill_batch,
|
||||
)
|
||||
if self.cfg.model_config.enable_mm:
|
||||
available_blocks = self.resource_manager.available_block_num()
|
||||
else:
|
||||
available_blocks = self.cfg.cache_config.max_block_num_per_seq
|
||||
|
||||
tasks = self.scheduler.get_requests(
|
||||
available_blocks=available_blocks,
|
||||
block_size=self.cfg.cache_config.block_size,
|
||||
reserved_output_blocks=self.cfg.cache_config.enc_dec_block_num,
|
||||
max_num_batched_tokens=self.cfg.max_model_len,
|
||||
batch=num_prefill_batch,
|
||||
)
|
||||
# Fetch requests and add them to the scheduling queue
|
||||
for task in tasks:
|
||||
self.resource_manager.add_request(task)
|
||||
is_fetching = False
|
||||
tasks = self.scheduler.get_requests(
|
||||
available_blocks=available_blocks,
|
||||
block_size=self.cfg.cache_config.block_size,
|
||||
reserved_output_blocks=self.cfg.cache_config.enc_dec_block_num,
|
||||
max_num_batched_tokens=self.cfg.max_model_len,
|
||||
batch=num_prefill_batch,
|
||||
)
|
||||
if self.cfg.scheduler_config.splitwise_role != "mixed":
|
||||
for task in tasks:
|
||||
# assure can allocate block ids in P
|
||||
while not self.resource_manager.preallocate_resource_in_p(task):
|
||||
time.sleep(0.005)
|
||||
self.llm_logger.info(f"ask D resource for req_id: {task.request_id}")
|
||||
self.split_connector.send_splitwise_tasks([task], task.idx)
|
||||
need_delete_tasks = []
|
||||
for task in tasks:
|
||||
if self.cfg.scheduler_config.splitwise_role != "mixed":
|
||||
# assure fetch block ids from D
|
||||
status, msg = self.split_connector.check_decode_allocated(task)
|
||||
if not status:
|
||||
self.llm_logger.error(f"{task.request_id} prefill failed with msg:{msg}.")
|
||||
self.scheduler.put_results(
|
||||
[
|
||||
RequestOutput(
|
||||
request_id=task.request_id,
|
||||
finished=True,
|
||||
error_code=500,
|
||||
error_msg=msg,
|
||||
)
|
||||
]
|
||||
)
|
||||
need_delete_tasks.append(task)
|
||||
continue
|
||||
for tmp_task in need_delete_tasks:
|
||||
tasks.remove(tmp_task)
|
||||
# release resource in P
|
||||
self.resource_manager.prerelease_resource(task)
|
||||
if self.cfg.scheduler_config.splitwise_role == "prefill":
|
||||
# to send cache info to cache messager
|
||||
if tasks:
|
||||
self.split_connector.send_cache_infos(tasks, 0)
|
||||
# ensure cache tasks has sent to cache_messager
|
||||
need_check_req_ids = [task.request_id for task in tasks]
|
||||
while need_check_req_ids:
|
||||
req_ids = self.engine_worker_queue.get_finished_add_cache_task_req()
|
||||
self.llm_logger.info(f"get_finished_add_cache_task_req: {req_ids}")
|
||||
if req_ids:
|
||||
for req_id in req_ids:
|
||||
assert req_id in need_check_req_ids
|
||||
need_check_req_ids.remove(req_id)
|
||||
else:
|
||||
time.sleep(0.001)
|
||||
# Fetch requests and add them to the scheduling queue
|
||||
if tasks:
|
||||
if self.cfg.scheduler_config.splitwise_role == "prefill":
|
||||
self.resource_manager.add_request_in_p(tasks)
|
||||
else:
|
||||
for task in tasks:
|
||||
self.resource_manager.add_request(task)
|
||||
is_fetching = False
|
||||
except Exception as e:
|
||||
self.llm_logger.error(f"fetching request error {e} {str(traceback.format_exc())}")
|
||||
is_fetching = False
|
||||
|
||||
while self.running:
|
||||
try:
|
||||
if self.engine_worker_queue.num_tasks() > 0:
|
||||
time.sleep(0.001)
|
||||
continue
|
||||
if (
|
||||
len(self.resource_manager.waiting) == 0
|
||||
and (not is_fetching)
|
||||
and self.exist_prefill_task_signal.value[0] == 0
|
||||
):
|
||||
get_request_pool.submit(_fetch_request)
|
||||
if self.cfg.scheduler_config.splitwise_role != "mixed":
|
||||
if self.scheduler.get_unhandled_request_num() <= envs.FD_EP_MAX_PREFETCH_TASK_NUM and (
|
||||
not is_fetching
|
||||
):
|
||||
get_request_pool.submit(_fetch_request)
|
||||
|
||||
else:
|
||||
if (
|
||||
len(self.resource_manager.waiting) == 0
|
||||
and (not is_fetching)
|
||||
and self.exist_prefill_task_signal.value[0] == 0
|
||||
):
|
||||
get_request_pool.submit(_fetch_request)
|
||||
# 2. Schedule requests
|
||||
tasks = self.resource_manager.schedule()
|
||||
# 3. Send to engine
|
||||
@@ -579,8 +679,8 @@ class EngineService:
|
||||
time.sleep(0.005)
|
||||
|
||||
except Exception as e:
|
||||
err_msg = "Error happened while insert task to engine: {}, {}.".format(e, str(traceback.format_exc()))
|
||||
llm_logger.error(err_msg)
|
||||
err_msg = "Error happend while insert task to engine: {}, {}.".format(e, str(traceback.format_exc()))
|
||||
self.llm_logger.error(err_msg)
|
||||
|
||||
def start_zmq_service(self, api_server_pid=None):
|
||||
if api_server_pid is None:
|
||||
@@ -608,6 +708,9 @@ class EngineService:
|
||||
|
||||
def _insert_zmq_task_to_scheduler(self):
|
||||
added_requests: Dict[str, int] = dict()
|
||||
if envs.FD_ENABLE_INTERNAL_ADAPTER:
|
||||
if self.cfg.scheduler_config.splitwise_role == "decode":
|
||||
return
|
||||
while self.running:
|
||||
try:
|
||||
block = True if len(added_requests) == 0 else False
|
||||
@@ -616,7 +719,7 @@ class EngineService:
|
||||
else:
|
||||
err, data = self.recv_request_server.receive_pyobj_once(block)
|
||||
if err is not None:
|
||||
llm_logger.error(f"Engine stops inserting zmq task into scheduler, err:{err}")
|
||||
self.llm_logger.error(f"Engine stops inserting zmq task into scheduler, err:{err}")
|
||||
break
|
||||
|
||||
request, insert_task = None, []
|
||||
@@ -627,16 +730,16 @@ class EngineService:
|
||||
request = Request.from_dict(data)
|
||||
start_span("ENQUEUE_ZMQ", data, trace.SpanKind.PRODUCER)
|
||||
main_process_metrics.requests_number.inc()
|
||||
llm_logger.debug(f"Receive request: {request}")
|
||||
self.llm_logger.debug(f"Receive request: {request}")
|
||||
except Exception as e:
|
||||
llm_logger.error(f"Receive request error: {e}, {traceback.format_exc()!s}")
|
||||
self.llm_logger.error(f"Receive request error: {e}, {traceback.format_exc()!s}")
|
||||
err_msg = str(e)
|
||||
results.append((data["request_id"], err_msg))
|
||||
|
||||
if self.guided_decoding_checker is not None and err_msg is None:
|
||||
request, err_msg = self.guided_decoding_checker.schema_format(request)
|
||||
if err_msg is not None:
|
||||
llm_logger.error(f"Receive request error: {err_msg}")
|
||||
self.llm_logger.error(f"Receive request error: {err_msg}")
|
||||
results.append((request.request_id, err_msg))
|
||||
|
||||
if err_msg is None:
|
||||
@@ -670,7 +773,7 @@ class EngineService:
|
||||
# Send result by zmq directly
|
||||
self.send_response_server.send_response(request_id, [error_result])
|
||||
except Exception as e:
|
||||
llm_logger.error(
|
||||
self.llm_logger.error(
|
||||
f"Error happened while receiving new request from zmq, details={e}, "
|
||||
f"traceback={traceback.format_exc()}"
|
||||
)
|
||||
@@ -689,7 +792,7 @@ class EngineService:
|
||||
self.send_response_server.send_response(request_id, contents)
|
||||
|
||||
except Exception as e:
|
||||
llm_logger.error(f"Unexcepted error happened: {e}, {traceback.format_exc()!s}")
|
||||
self.llm_logger.error(f"Unexcepted error happend: {e}, {traceback.format_exc()!s}")
|
||||
|
||||
def split_mode_get_tasks(self):
|
||||
"""
|
||||
@@ -702,13 +805,22 @@ class EngineService:
|
||||
|
||||
processed_indices = []
|
||||
for idx, task in enumerate(self.waiting_requests):
|
||||
if self.resource_manager.is_resource_sufficient(task.prompt_token_ids_len):
|
||||
self.insert_tasks([task])
|
||||
llm_logger.info(f"Resource available, processing task {task.request_id}")
|
||||
processed_indices.append(idx)
|
||||
if envs.ENABLE_V1_KVCACHE_SCHEDULER:
|
||||
if self.resource_manager.preallocate_resource_in_d(task):
|
||||
self.llm_logger.info(f"Resource available, processing task {task.request_id}")
|
||||
self.split_connector.send_cache_infos([task], -1)
|
||||
processed_indices.append(idx)
|
||||
else:
|
||||
self.llm_logger.debug(f"Still waiting for resources {task.request_id}")
|
||||
break
|
||||
else:
|
||||
llm_logger.debug(f"Still waiting for resources {task.request_id}")
|
||||
break
|
||||
if self.resource_manager.is_resource_sufficient(task.prompt_token_ids_len):
|
||||
self.insert_tasks([task])
|
||||
self.llm_logger.info(f"Resource available, processing task {task.request_id}")
|
||||
processed_indices.append(idx)
|
||||
else:
|
||||
self.llm_logger.debug(f"Still waiting for resources {task.request_id}")
|
||||
break
|
||||
|
||||
for idx in sorted(processed_indices, reverse=True):
|
||||
self.waiting_requests.pop(idx)
|
||||
@@ -730,32 +842,79 @@ class EngineService:
|
||||
tasks = [tasks]
|
||||
for task in tasks:
|
||||
task.finished = False
|
||||
self.insert_tasks(tasks, allocated=True)
|
||||
|
||||
if self.cfg.innode_prefill_ports is not None:
|
||||
self.scheduler.put_results(tasks)
|
||||
if envs.ENABLE_V1_KVCACHE_SCHEDULER:
|
||||
for task in tasks:
|
||||
if envs.FD_ENABLE_INTERNAL_ADAPTER:
|
||||
if (
|
||||
not task.outputs.token_ids
|
||||
): # first token is eos in Prefill, just recycle resource and continue
|
||||
cur_task = self.resource_manager.requests[task.request_id]
|
||||
self.resource_manager.stop_flags[cur_task.idx] = True
|
||||
self.resource_manager.tasks_list[cur_task.idx] = None
|
||||
self.resource_manager._free_blocks(cur_task)
|
||||
if cur_task.request_id in self.token_processor.tokens_counter:
|
||||
del self.token_processor.tokens_counter[task.request_id]
|
||||
self.llm_logger.warning(
|
||||
f"{task.request_id} need not decode after first token"
|
||||
)
|
||||
del self.resource_manager.requests[task.request_id]
|
||||
del self.resource_manager.req_dict[task.request_id]
|
||||
continue
|
||||
if task.error_code != 200:
|
||||
cur_task = self.resource_manager.requests[task.request_id]
|
||||
self.resource_manager.stop_flags[cur_task.idx] = True
|
||||
self.resource_manager.tasks_list[cur_task.idx] = None
|
||||
self.resource_manager._free_blocks(cur_task)
|
||||
if cur_task.request_id in self.token_processor.tokens_counter:
|
||||
del self.token_processor.tokens_counter[task.request_id]
|
||||
self.scheduler.put_results([task])
|
||||
self.llm_logger.warning(
|
||||
f"{task.request_id} prefill failed with msg:{task.error_msg}, recycle resource."
|
||||
)
|
||||
continue
|
||||
self.resource_manager.insert_task_for_decoding(task)
|
||||
|
||||
else:
|
||||
self.insert_tasks(tasks, allocated=True)
|
||||
if self.cfg.innode_prefill_ports is not None:
|
||||
self.scheduler.put_results(tasks)
|
||||
else:
|
||||
if len(self.waiting_requests):
|
||||
llm_logger.info(f"Waiting for resource for task {tasks[0].request_id}")
|
||||
self.llm_logger.info(f"Waiting for resource for task {tasks[0].request_id}")
|
||||
self.waiting_requests.extend(tasks)
|
||||
else:
|
||||
new_waiting = []
|
||||
for task in tasks:
|
||||
if self.resource_manager.is_resource_sufficient(task.prompt_token_ids_len):
|
||||
self.insert_tasks([task])
|
||||
can_allocate_resource = False
|
||||
if envs.ENABLE_V1_KVCACHE_SCHEDULER:
|
||||
if self.resource_manager.preallocate_resource_in_d(task):
|
||||
self.split_connector.send_cache_infos([task], -1)
|
||||
can_allocate_resource = True
|
||||
else:
|
||||
if self.resource_manager.is_resource_sufficient(
|
||||
task.prompt_token_ids_len
|
||||
):
|
||||
self.insert_tasks([task])
|
||||
can_allocate_resource = True
|
||||
if can_allocate_resource is False:
|
||||
if not self.enable_decode_cache_task:
|
||||
task.error_msg = "Not enough resources"
|
||||
new_waiting.append(task)
|
||||
|
||||
if new_waiting:
|
||||
self.waiting_requests.extend(new_waiting)
|
||||
llm_logger.info(f"Added {len(new_waiting)} tasks to waiting queue")
|
||||
if not self.enable_decode_cache_task:
|
||||
self.split_connector.send_cache_infos(new_waiting, -1)
|
||||
else:
|
||||
self.waiting_requests.extend(new_waiting)
|
||||
self.llm_logger.info(
|
||||
f"Added {len(new_waiting)} tasks to waiting queue"
|
||||
)
|
||||
|
||||
else:
|
||||
time.sleep(0.001)
|
||||
|
||||
except Exception as e:
|
||||
llm_logger.error(f"Error in main loop: {e}")
|
||||
self.llm_logger.error(f"Error in main loop: {e}")
|
||||
time.sleep(0.1)
|
||||
|
||||
threading.Thread(target=receiver_loop, daemon=True).start()
|
||||
|
@@ -120,11 +120,10 @@ class LLMEngine:
|
||||
|
||||
self.data_processor = self.input_processor.create_processor()
|
||||
self.engine.data_processor = self.data_processor
|
||||
# Launch components: scheduler, cache_manager, expert_service et.al.
|
||||
self.launch_components()
|
||||
|
||||
self.engine.start()
|
||||
if api_server_pid is not None:
|
||||
llm_logger.info(f"Start zmq server, api_server_pid: {api_server_pid}")
|
||||
self.engine.start_zmq_service(api_server_pid)
|
||||
|
||||
if self.do_profile == 0 and (
|
||||
self.cfg.cache_config.enable_prefix_caching or self.cfg.scheduler_config.splitwise_role != "mixed"
|
||||
@@ -159,11 +158,14 @@ class LLMEngine:
|
||||
|
||||
if self.do_profile:
|
||||
self._stop_profile()
|
||||
# Launch components: scheduler, cache_manager, expert_service et.al.
|
||||
self.launch_components()
|
||||
|
||||
if self.cfg.cache_config.enable_prefix_caching or self.cfg.scheduler_config.splitwise_role != "mixed":
|
||||
self.launched_cache_manager_signal.value[0] = 1
|
||||
|
||||
if api_server_pid is not None:
|
||||
llm_logger.info(f"Start zmq server, api_server_pid: {api_server_pid}")
|
||||
self.engine.start_zmq_service(api_server_pid)
|
||||
|
||||
# Worker launched
|
||||
self.check_worker_initialize_status_func_thread.join()
|
||||
if not result_container["worker_is_alive"]:
|
||||
@@ -427,7 +429,10 @@ class LLMEngine:
|
||||
)
|
||||
|
||||
if self.cfg.scheduler_config.splitwise_role != "mixed":
|
||||
variables["FLAGS_use_pd_disaggregation"] = 1
|
||||
if envs.ENABLE_V1_KVCACHE_SCHEDULER:
|
||||
variables["FLAGS_use_pd_disaggregation_per_chunk"] = 1
|
||||
else:
|
||||
variables["FLAGS_use_pd_disaggregation"] = 1
|
||||
# TODO dynamic load environment variable
|
||||
if self.cfg.scheduler_config.splitwise_role == "prefill":
|
||||
variables["FLAGS_fmt_write_cache_completed_signal"] = 1
|
||||
@@ -498,6 +503,7 @@ class LLMEngine:
|
||||
f" --load_choices {self.cfg.load_config.load_choices}"
|
||||
f" --moba_attention_config '{self.cfg.moba_attention_config.to_json_string()}'"
|
||||
f" --ips {ips}"
|
||||
f" --cache-transfer-protocol {self.cfg.cache_config.cache_transfer_protocol}"
|
||||
f" --runner {self.cfg.model_config.runner}"
|
||||
f" --convert {self.cfg.model_config.convert}"
|
||||
f" --override-pooler-config {self.cfg.model_config.override_pooler_config}"
|
||||
@@ -625,13 +631,11 @@ class LLMEngine:
|
||||
if self.cfg.scheduler_config.splitwise_role != "mixed":
|
||||
# 单机逻辑
|
||||
self.engine.engine_worker_queue.available_prefill_instances.put(1)
|
||||
self.engine.split_mode_get_tasks()
|
||||
if self.cfg.scheduler_config.name == "splitwise":
|
||||
self.splitwise_receive_thread = threading.Thread(
|
||||
target=self.engine.split_connector.start_receiver, args=()
|
||||
)
|
||||
self.splitwise_receive_thread.daemon = True
|
||||
self.splitwise_receive_thread.start()
|
||||
self.splitwise_receive_thread = threading.Thread(
|
||||
target=self.engine.split_connector.start_receiver, args=()
|
||||
)
|
||||
self.splitwise_receive_thread.daemon = True
|
||||
self.splitwise_receive_thread.start()
|
||||
|
||||
self.cfg.init_cache_info()
|
||||
|
||||
@@ -640,6 +644,14 @@ class LLMEngine:
|
||||
disaggregate = self.cfg.disaggregate_info
|
||||
if self.cfg.scheduler_config.name == "splitwise":
|
||||
self.engine.scheduler.start(role, host_ip, disaggregate)
|
||||
elif self.cfg.scheduler_config.name == "dp":
|
||||
request_queues_for_dp_ipc = []
|
||||
result_queue_for_dp_ipc = multiprocessing.Queue()
|
||||
for i in range(self.cfg.parallel_config.data_parallel_size):
|
||||
request_queues_for_dp_ipc.append(multiprocessing.Queue())
|
||||
self.engine.scheduler.start(
|
||||
self.cfg.node_rank * self.cfg.worker_num_per_node, request_queues_for_dp_ipc, result_queue_for_dp_ipc
|
||||
)
|
||||
|
||||
if not envs.FD_ENABLE_MULTI_API_SERVER:
|
||||
if self.cfg.parallel_config.enable_expert_parallel and self.cfg.parallel_config.data_parallel_size > 1:
|
||||
@@ -669,6 +681,9 @@ class LLMEngine:
|
||||
args=(
|
||||
self.cfg,
|
||||
i,
|
||||
None,
|
||||
request_queues_for_dp_ipc,
|
||||
result_queue_for_dp_ipc,
|
||||
),
|
||||
)
|
||||
)
|
||||
|
@@ -27,6 +27,7 @@ import numpy as np
|
||||
|
||||
from fastdeploy.engine.common_engine import EngineService
|
||||
from fastdeploy.inter_communicator import IPCSignal
|
||||
from fastdeploy.splitwise.internal_adapter_utils import InternalAdapter
|
||||
from fastdeploy.utils import console_logger, envs, llm_logger
|
||||
|
||||
|
||||
@@ -69,8 +70,12 @@ class ExpertService:
|
||||
self.engine.scheduler.reset_nodeid(f"{self.engine.scheduler.infer.nodeid}_{local_data_parallel_id!s}")
|
||||
|
||||
self._finalizer = weakref.finalize(self, self._exit_sub_services)
|
||||
if envs.FD_ENABLE_INTERNAL_ADAPTER:
|
||||
self.internal_adapter = InternalAdapter(cfg=self.cfg, engine=self.engine, dp_rank=local_data_parallel_id)
|
||||
|
||||
def start(self, ipc_signal_suffix, local_data_parallel_id):
|
||||
def start(
|
||||
self, ipc_signal_suffix, local_data_parallel_id, request_queues_for_dp_ipc=None, result_queue_for_dp_ipc=None
|
||||
):
|
||||
"""
|
||||
Initializes the engine and starts its sub-services.
|
||||
If `api_server_pid` is defined, will launch a thread
|
||||
@@ -80,6 +85,11 @@ class ExpertService:
|
||||
|
||||
start_time = time.time()
|
||||
self.engine.start()
|
||||
if self.cfg.scheduler_config.name == "dp":
|
||||
self.cfg.init_cache_info()
|
||||
assert (request_queues_for_dp_ipc is not None) and (result_queue_for_dp_ipc is not None)
|
||||
self.engine.scheduler.start(local_data_parallel_id, request_queues_for_dp_ipc, result_queue_for_dp_ipc)
|
||||
|
||||
if ipc_signal_suffix is not None:
|
||||
self.api_server_pid = ipc_signal_suffix
|
||||
self.engine.start_zmq_service(ipc_signal_suffix)
|
||||
@@ -88,8 +98,8 @@ class ExpertService:
|
||||
|
||||
llm_logger.info(f"start expert service {local_data_parallel_id}")
|
||||
if self.cfg.scheduler_config.splitwise_role != "mixed":
|
||||
self.engine.start_cache_service(self.cfg.local_device_ids, ipc_signal_suffix)
|
||||
self.engine.split_mode_get_tasks()
|
||||
ipc_signal_suffix_cache = self.cfg.parallel_config.engine_worker_queue_port[local_data_parallel_id]
|
||||
self.engine.start_cache_service(self.cfg.local_device_ids, ipc_signal_suffix_cache)
|
||||
|
||||
if self.cfg.scheduler_config.name == "splitwise":
|
||||
self.cfg.init_cache_info()
|
||||
@@ -144,14 +154,18 @@ class ExpertService:
|
||||
self.zmq_server.close()
|
||||
|
||||
|
||||
def start_data_parallel_service(cfg, local_data_parallel_id, ipc_signal_suffix=None):
|
||||
def start_data_parallel_service(
|
||||
cfg, local_data_parallel_id, ipc_signal_suffix=None, request_queues_for_dp_ipc=None, result_queue_for_dp_ipc=None
|
||||
):
|
||||
"""
|
||||
Start expert service
|
||||
"""
|
||||
expert_service = ExpertService(cfg, local_data_parallel_id, start_queue=False)
|
||||
|
||||
try:
|
||||
expert_service.start(ipc_signal_suffix, local_data_parallel_id)
|
||||
expert_service.start(
|
||||
ipc_signal_suffix, local_data_parallel_id, request_queues_for_dp_ipc, result_queue_for_dp_ipc
|
||||
)
|
||||
|
||||
def deamon_thread():
|
||||
while True:
|
||||
@@ -159,5 +173,6 @@ def start_data_parallel_service(cfg, local_data_parallel_id, ipc_signal_suffix=N
|
||||
|
||||
t_deamon = threading.Thread(target=deamon_thread, daemon=True)
|
||||
t_deamon.start()
|
||||
t_deamon.join()
|
||||
except Exception as e:
|
||||
llm_logger.exception(f"Expert service failed to start: {e}, {str(traceback.format_exc())}")
|
||||
|
@@ -73,6 +73,7 @@ class Request:
|
||||
guided_json_object: Optional[bool] = None,
|
||||
enable_thinking: Optional[bool] = True,
|
||||
trace_carrier: dict = dict(),
|
||||
dp_rank: Optional[int] = None,
|
||||
chat_template: Optional[str] = None,
|
||||
image_start: int = 0,
|
||||
video_start: int = 0,
|
||||
@@ -145,6 +146,8 @@ class Request:
|
||||
# extend block tables
|
||||
self.use_extend_tables = False
|
||||
self.extend_block_tables = []
|
||||
# dp
|
||||
self.dp_rank = dp_rank
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, d: dict):
|
||||
@@ -187,6 +190,7 @@ class Request:
|
||||
image_end=d.get("image_end", 0),
|
||||
video_end=d.get("video_end", 0),
|
||||
audio_end=d.get("audio_end", 0),
|
||||
dp_rank=d.get("dp_rank", None),
|
||||
)
|
||||
|
||||
@property
|
||||
|
@@ -328,8 +328,8 @@ class ResourceManager:
|
||||
Delete cached data from the task's prompt token ids based on the cached length.
|
||||
"""
|
||||
if cached_len == len(task.prompt_token_ids):
|
||||
task.prompt_token_ids = task.prompt_token_ids[cached_len - 1 :]
|
||||
task.seq_lens_decoder = cached_len - 1
|
||||
task.prompt_token_ids = task.prompt_token_ids[cached_len - self.cfg.block_size :]
|
||||
task.seq_lens_decoder = cached_len - self.cfg.block_size
|
||||
else:
|
||||
task.prompt_token_ids = task.prompt_token_ids[cached_len:]
|
||||
task.seq_lens_decoder = cached_len
|
||||
|
@@ -14,6 +14,7 @@
|
||||
# limitations under the License.
|
||||
"""
|
||||
|
||||
import copy
|
||||
import threading
|
||||
import time
|
||||
import traceback
|
||||
@@ -26,7 +27,7 @@ from typing import Union
|
||||
import numpy as np
|
||||
import paddle
|
||||
|
||||
from fastdeploy.engine.request import Request, RequestStatus, RequestType
|
||||
from fastdeploy.engine.request import Request, RequestOutput, RequestStatus, RequestType
|
||||
from fastdeploy.engine.resource_manager import ResourceManager
|
||||
from fastdeploy.metrics.metrics import main_process_metrics
|
||||
from fastdeploy.utils import llm_logger
|
||||
@@ -297,6 +298,11 @@ class ResourceManagerV1(ResourceManager):
|
||||
while req_index < len(self.running) and token_budget > 0:
|
||||
request = self.running[req_index]
|
||||
if request.num_computed_tokens >= request.need_prefill_tokens: # to be decoding
|
||||
if (
|
||||
self.config.scheduler_config.splitwise_role == "prefill"
|
||||
): # do not need to schedule for decoding
|
||||
req_index += 1
|
||||
continue
|
||||
if request.num_total_tokens > request.need_prefill_tokens: # has generated tokens
|
||||
request.num_computed_tokens = request.num_total_tokens - 1
|
||||
if (
|
||||
@@ -400,11 +406,12 @@ class ResourceManagerV1(ResourceManager):
|
||||
request.status = RequestStatus.RUNNING
|
||||
main_process_metrics.num_requests_waiting.dec(1)
|
||||
main_process_metrics.num_requests_running.inc(1)
|
||||
allocated_position = self.get_available_position()
|
||||
request.idx = allocated_position
|
||||
self.tasks_list[allocated_position] = request
|
||||
self.stop_flags[allocated_position] = False
|
||||
self.req_dict[request.request_id] = allocated_position
|
||||
if self.config.scheduler_config.splitwise_role == "mixed":
|
||||
allocated_position = self.get_available_position()
|
||||
request.idx = allocated_position
|
||||
self.tasks_list[allocated_position] = request
|
||||
self.stop_flags[allocated_position] = False
|
||||
self.req_dict[request.request_id] = allocated_position
|
||||
else:
|
||||
if self.config.cache_config.enable_prefix_caching:
|
||||
self._free_blocks(request)
|
||||
@@ -569,6 +576,127 @@ class ResourceManagerV1(ResourceManager):
|
||||
self.waiting.append(request)
|
||||
self.requests[request.request_id] = request
|
||||
|
||||
def prerelease_resource(self, request: Request):
|
||||
"""
|
||||
Release resource in P or D before finished due to unexpected error.
|
||||
"""
|
||||
with self.lock:
|
||||
self.tasks_list[request.idx] = None
|
||||
self.stop_flags[request.idx] = True
|
||||
del self.requests[request.request_id]
|
||||
del self.req_dict[request.request_id]
|
||||
self._free_blocks(request)
|
||||
|
||||
def add_request_in_p(self, requests: list[Request]):
|
||||
with self.lock:
|
||||
for request in requests:
|
||||
request.inference_start_time = time.time()
|
||||
request.schedule_start_time = time.time()
|
||||
self.running.append(request)
|
||||
|
||||
def preallocate_resource_in_p(self, request: Request):
|
||||
"""
|
||||
In P/D aggregated deployment, preallocate resource for P.
|
||||
If can allocate, allocate resources and return True
|
||||
If can not, return False
|
||||
"""
|
||||
assert self.config.scheduler_config.splitwise_role == "prefill", "Only P instance can call this method"
|
||||
with self.lock:
|
||||
if self.available_batch() == 0:
|
||||
return False
|
||||
request.need_prefill_tokens = len(request.prompt_token_ids)
|
||||
need_prealloc_prefill_blocks = (
|
||||
request.need_prefill_tokens + self.config.cache_config.block_size - 1
|
||||
) // self.config.cache_config.block_size + self.config.cache_config.enc_dec_block_num # consider for mtp, plus enc_dec_block_num
|
||||
if self.config.cache_config.enable_prefix_caching:
|
||||
# Enable prefix caching
|
||||
if self.config.cache_config.enable_hierarchical_cache and self.cache_manager.num_cpu_blocks > 0:
|
||||
if not self.cache_manager.can_allocate_gpu_blocks(
|
||||
need_prealloc_prefill_blocks
|
||||
): # to prevent block allocation for matching in hierarchical cache and cause dead lock
|
||||
return False
|
||||
success = self.get_prefix_cached_blocks(request)
|
||||
if not success:
|
||||
self._free_blocks(request)
|
||||
return False
|
||||
# consider for mtp, plus enc_dec_block_num
|
||||
need_extra_prefill_blocks = need_prealloc_prefill_blocks - request.cache_info[0]
|
||||
if self.cache_manager.can_allocate_gpu_blocks(need_extra_prefill_blocks):
|
||||
request.block_tables.extend(self.cache_manager.allocate_gpu_blocks(need_extra_prefill_blocks))
|
||||
allocated_position = self.get_available_position()
|
||||
request.idx = allocated_position
|
||||
self.tasks_list[request.idx] = request
|
||||
self.stop_flags[request.idx] = False
|
||||
self.requests[request.request_id] = request
|
||||
self.req_dict[request.request_id] = allocated_position
|
||||
return True
|
||||
else:
|
||||
self._free_blocks(request)
|
||||
return False
|
||||
|
||||
else:
|
||||
if self.cache_manager.can_allocate_gpu_blocks(need_prealloc_prefill_blocks):
|
||||
request.block_tables.extend(self.cache_manager.allocate_gpu_blocks(need_prealloc_prefill_blocks))
|
||||
request.num_computed_tokens = 0
|
||||
allocated_position = self.get_available_position()
|
||||
request.idx = allocated_position
|
||||
self.tasks_list[request.idx] = request
|
||||
self.stop_flags[request.idx] = False
|
||||
self.requests[request.request_id] = request
|
||||
self.req_dict[request.request_id] = allocated_position
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def preallocate_resource_in_d(self, request: Request):
|
||||
"""
|
||||
In P/D aggregated deployment, D should preallocate resource for P.
|
||||
If can allocate, allocate resources and return True
|
||||
If can not, return False
|
||||
"""
|
||||
assert self.config.scheduler_config.splitwise_role == "decode", "Only D instance can call this method"
|
||||
with self.lock:
|
||||
if len(self.waiting) > 0:
|
||||
return False
|
||||
if self.available_batch() == 0:
|
||||
return False
|
||||
request.need_prefill_tokens = len(request.prompt_token_ids)
|
||||
need_prealloc_prefill_blocks = (
|
||||
request.need_prefill_tokens + self.config.cache_config.block_size - 1
|
||||
) // self.config.cache_config.block_size + self.config.cache_config.enc_dec_block_num # consider for mtp, plus enc_dec_block_num
|
||||
if self.cache_manager.can_allocate_gpu_blocks(need_prealloc_prefill_blocks):
|
||||
request.block_tables.extend(self.cache_manager.allocate_gpu_blocks(need_prealloc_prefill_blocks))
|
||||
request.num_computed_tokens = request.need_prefill_tokens
|
||||
request.disaggregate_info["block_tables"] = request.block_tables
|
||||
allocated_position = self.get_available_position()
|
||||
request.idx = allocated_position
|
||||
self.tasks_list[request.idx] = request
|
||||
self.stop_flags[request.idx] = False
|
||||
self.requests[request.request_id] = request
|
||||
self.req_dict[request.request_id] = allocated_position
|
||||
return True
|
||||
return False
|
||||
|
||||
def insert_task_for_decoding(self, request_output_in_p: RequestOutput):
|
||||
"""
|
||||
In P/D aggregated deployment, D should continue to decode after recieving first token and cache from P.
|
||||
"""
|
||||
assert self.config.scheduler_config.splitwise_role == "decode", "Only D instance can call this method"
|
||||
with self.lock:
|
||||
request = self.requests[request_output_in_p.request_id]
|
||||
request.output_token_ids.append(request_output_in_p.outputs.token_ids[0])
|
||||
request.num_cached_tokens = request_output_in_p.num_cached_tokens
|
||||
if (
|
||||
self.config.speculative_config.method in ["mtp"]
|
||||
and self.config.scheduler_config.splitwise_role == "decode"
|
||||
):
|
||||
request.draft_token_ids = copy.deepcopy(request_output_in_p.outputs.draft_token_ids)
|
||||
# update request.need_prefill_tokens
|
||||
request.need_prefill_tokens = len(request.prompt_token_ids) + 1
|
||||
request.inference_start_time = time.time()
|
||||
request.schedule_start_time = time.time()
|
||||
self.running.append(request)
|
||||
|
||||
def _free_blocks(self, request: Request):
|
||||
if self.config.cache_config.enable_prefix_caching:
|
||||
self.cache_manager.release_block_ids(request)
|
||||
@@ -620,5 +748,7 @@ class ResourceManagerV1(ResourceManager):
|
||||
self.tasks_list[request.idx] = None
|
||||
self.stop_flags[request.idx] = True
|
||||
del self.requests[req_id]
|
||||
if req_id in self.req_dict:
|
||||
del self.req_dict[req_id]
|
||||
except Exception as e:
|
||||
llm_logger.error(f"finish_request err: {e}, {str(traceback.format_exc())}")
|
||||
|
@@ -109,6 +109,12 @@ environment_variables: dict[str, Callable[[], Any]] = {
|
||||
"FD_ZMQ_SEND_RESPONSE_SERVER_PORT": lambda: os.getenv("FD_ZMQ_SEND_RESPONSE_SERVER_PORT", "8201"),
|
||||
# LLMEngine recieve control command port, used when FD_ENABLE_INTERNAL_ADAPTER=1
|
||||
"FD_ZMQ_CONTROL_CMD_SERVER_PORTS": lambda: os.getenv("FD_ZMQ_CONTROL_CMD_SERVER_PORTS", "8202"),
|
||||
# Whether to enable cache task in decode node
|
||||
"FD_ENABLE_CACHE_TASK": lambda: os.getenv("FD_ENABLE_CACHE_TASK", "1"),
|
||||
# Batched token timeout in EP
|
||||
"FD_EP_BATCHED_TOKEN_TIMEOUT": lambda: float(os.getenv("FD_EP_BATCHED_TOKEN_TIMEOUT", "0.1")),
|
||||
# Max pre-fetch requests number in PD
|
||||
"FD_EP_MAX_PREFETCH_TASK_NUM": lambda: int(os.getenv("FD_EP_MAX_PREFETCH_TASK_NUM", "8")),
|
||||
"FD_ENABLE_MODEL_LOAD_CACHE": lambda: bool(int(os.getenv("FD_ENABLE_MODEL_LOAD_CACHE", "0"))),
|
||||
}
|
||||
|
||||
@@ -120,6 +126,14 @@ def __getattr__(name: str):
|
||||
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
|
||||
|
||||
|
||||
def get_unique_name(self, name):
|
||||
"""
|
||||
Get unique name for config
|
||||
"""
|
||||
shm_uuid = os.getenv("SHM_UUID", "")
|
||||
return name + f"_{shm_uuid}"
|
||||
|
||||
|
||||
def __setattr__(name: str, value: Any):
|
||||
assert name in environment_variables
|
||||
environment_variables[name] = lambda: value
|
||||
|
@@ -84,18 +84,28 @@ class EngineWorkerQueue:
|
||||
Value("i", 0) for _ in range(self.local_data_parallel_size)
|
||||
]
|
||||
self.finished_req_queue = [Queue() for _ in range(self.local_data_parallel_size)]
|
||||
self.finished_add_cache_task_queue = [Queue() for _ in range(self.local_data_parallel_size)]
|
||||
self.cache_infos_init: List[List[Any]] = [list() for _ in range(self.local_data_parallel_size)]
|
||||
self.connect_rdma_tasks_list = [list() for _ in range(self.local_data_parallel_size)]
|
||||
self.connect_rdma_tasks_response_list = [list() for _ in range(self.local_data_parallel_size)]
|
||||
self.client_read_info_flag_init: List[List[int]] = [
|
||||
[1] * self.num_client for _ in range(self.local_data_parallel_size)
|
||||
]
|
||||
self.lock_info_init: List[threading.Lock] = [
|
||||
threading.Lock() for _ in range(self.local_data_parallel_size)
|
||||
]
|
||||
self.connect_task_lock_init: List[threading.Lock] = [
|
||||
threading.Lock() for _ in range(self.local_data_parallel_size)
|
||||
]
|
||||
|
||||
self.finish_request_barrier = [
|
||||
threading.Barrier(self.num_client) for _ in range(self.local_data_parallel_size)
|
||||
]
|
||||
|
||||
self.finish_add_cache_task_barrier = [
|
||||
threading.Barrier(self.num_client) for _ in range(self.local_data_parallel_size)
|
||||
]
|
||||
|
||||
# Register shared objects with proxy types
|
||||
QueueManager.register(
|
||||
"get_tasks",
|
||||
@@ -117,6 +127,19 @@ class EngineWorkerQueue:
|
||||
callable=lambda idx: self.read_finish_flag_init[idx],
|
||||
proxytype=ValueProxy,
|
||||
)
|
||||
QueueManager.register(
|
||||
"get_connect_task_lock",
|
||||
callable=lambda idx: self.connect_task_lock_init[idx],
|
||||
proxytype=AcquirerProxy,
|
||||
)
|
||||
QueueManager.register(
|
||||
"get_connect_rdma_tasks", callable=lambda idx: self.connect_rdma_tasks_list[idx], proxytype=ListProxy
|
||||
)
|
||||
QueueManager.register(
|
||||
"get_connect_rdma_tasks_responses",
|
||||
callable=lambda idx: self.connect_rdma_tasks_response_list[idx],
|
||||
proxytype=ListProxy,
|
||||
)
|
||||
QueueManager.register(
|
||||
"get_connected_client_counter",
|
||||
callable=lambda idx: self.connected_client_counter_init[idx],
|
||||
@@ -128,6 +151,11 @@ class EngineWorkerQueue:
|
||||
callable=lambda idx: self.finished_req_queue[idx],
|
||||
)
|
||||
|
||||
QueueManager.register(
|
||||
"get_finish_add_cache_task_queue",
|
||||
callable=lambda idx: self.finished_add_cache_task_queue[idx],
|
||||
)
|
||||
|
||||
QueueManager.register(
|
||||
"get_cache_infos",
|
||||
callable=lambda idx: self.cache_infos_init[idx],
|
||||
@@ -161,6 +189,10 @@ class EngineWorkerQueue:
|
||||
"get_finish_request_barrier",
|
||||
callable=lambda idx: self.finish_request_barrier[idx],
|
||||
)
|
||||
QueueManager.register(
|
||||
"get_finish_add_cache_task_barrier",
|
||||
callable=lambda idx: self.finish_add_cache_task_barrier[idx],
|
||||
)
|
||||
self.manager: BaseManager = QueueManager(address=self.address, authkey=self.authkey)
|
||||
self.manager.start()
|
||||
else:
|
||||
@@ -174,12 +206,17 @@ class EngineWorkerQueue:
|
||||
QueueManager.register("get_read_finish_flag")
|
||||
QueueManager.register("get_connected_client_counter")
|
||||
QueueManager.register("get_finish_request_queue")
|
||||
QueueManager.register("get_finish_add_cache_task_queue")
|
||||
QueueManager.register("get_cache_infos")
|
||||
QueueManager.register("get_client_read_info_flag")
|
||||
QueueManager.register("get_lock_info")
|
||||
QueueManager.register("get_disaggregate_requests")
|
||||
QueueManager.register("get_available_prefill_instances")
|
||||
QueueManager.register("get_finish_request_barrier")
|
||||
QueueManager.register("get_finish_add_cache_task_barrier")
|
||||
QueueManager.register("get_connect_rdma_tasks")
|
||||
QueueManager.register("get_connect_rdma_tasks_responses")
|
||||
QueueManager.register("get_connect_task_lock")
|
||||
self.manager = QueueManager(address=self.address, authkey=self.authkey)
|
||||
self._connect_with_retry()
|
||||
|
||||
@@ -199,7 +236,20 @@ class EngineWorkerQueue:
|
||||
self.disaggregate_requests = self.manager.get_disaggregate_requests(self.local_data_parallel_id)
|
||||
self.available_prefill_instances = self.manager.get_available_prefill_instances()
|
||||
self.finish_request_barrier = self.manager.get_finish_request_barrier(self.local_data_parallel_id)
|
||||
self.finish_add_cache_task_barrier = self.manager.get_finish_add_cache_task_barrier(
|
||||
self.local_data_parallel_id
|
||||
)
|
||||
self.finished_req_queue = self.manager.get_finish_request_queue(self.local_data_parallel_id)
|
||||
self.finished_add_cache_task_queue = self.manager.get_finish_add_cache_task_queue(
|
||||
self.local_data_parallel_id
|
||||
)
|
||||
# p/d互联
|
||||
self.connect_rdma_task_queue = self.manager.get_connect_rdma_tasks(self.local_data_parallel_id)
|
||||
self.connect_rdma_task_response_queue = self.manager.get_connect_rdma_tasks_responses(
|
||||
self.local_data_parallel_id
|
||||
)
|
||||
self.connect_task_lock = self.manager.get_connect_task_lock(self.local_data_parallel_id)
|
||||
|
||||
assert self.num_client == len(self.client_read_flag)
|
||||
|
||||
if is_server:
|
||||
@@ -281,6 +331,44 @@ class EngineWorkerQueue:
|
||||
self.lock.release()
|
||||
return total_num
|
||||
|
||||
def put_connect_rdma_task(self, connect_rdma_task):
|
||||
self.connect_task_lock.acquire()
|
||||
self.connect_rdma_task_queue.append(connect_rdma_task)
|
||||
self.connect_task_lock.release()
|
||||
|
||||
def get_connect_rdma_task(self):
|
||||
result = None
|
||||
self.connect_task_lock.acquire()
|
||||
if len(self.connect_rdma_task_queue) == 0:
|
||||
self.connect_task_lock.release()
|
||||
return result
|
||||
try:
|
||||
result = self.connect_rdma_task_queue.pop(0)
|
||||
except Exception as e:
|
||||
llm_logger.info(f"get_connect_rdma_task got exception: {e}")
|
||||
finally:
|
||||
self.connect_task_lock.release()
|
||||
return result
|
||||
|
||||
def put_connect_rdma_task_response(self, connect_rdma_task_response):
|
||||
self.connect_task_lock.acquire()
|
||||
self.connect_rdma_task_response_queue.append(connect_rdma_task_response)
|
||||
self.connect_task_lock.release()
|
||||
|
||||
def get_connect_rdma_task_response(self):
|
||||
result = None
|
||||
self.connect_task_lock.acquire()
|
||||
if len(self.connect_rdma_task_response_queue) == 0:
|
||||
self.connect_task_lock.release()
|
||||
return result
|
||||
try:
|
||||
result = self.connect_rdma_task_response_queue.pop(0)
|
||||
except Exception as e:
|
||||
llm_logger.info(f"get_connect_rdma_task_response got exception: {e}")
|
||||
finally:
|
||||
self.connect_task_lock.release()
|
||||
return result
|
||||
|
||||
def get_prefill_instances(self):
|
||||
"""
|
||||
check if the prefill queue is empty
|
||||
@@ -365,6 +453,29 @@ class EngineWorkerQueue:
|
||||
llm_logger.debug(f"get finished req: {ans}")
|
||||
return ans
|
||||
|
||||
def put_finished_add_cache_task_req(self, req_ids) -> None:
|
||||
"""
|
||||
Put finished request ID into the queue.
|
||||
|
||||
Args:
|
||||
req_ids: Request ID to be added to the queue
|
||||
"""
|
||||
self.finished_add_cache_task_queue.put(req_ids)
|
||||
|
||||
def get_finished_add_cache_task_req(self) -> str:
|
||||
"""
|
||||
Get finished request ID from the queue.
|
||||
|
||||
Returns:
|
||||
str: Finished request ID
|
||||
"""
|
||||
ans = []
|
||||
if self.finished_add_cache_task_queue.empty():
|
||||
return ans
|
||||
ans = self.finished_add_cache_task_queue.get()
|
||||
llm_logger.debug(f"get finished req: {ans}")
|
||||
return ans
|
||||
|
||||
def disaggregate_queue_empty(self):
|
||||
"""
|
||||
Check if the disaggregated task queue is empty.
|
||||
|
@@ -211,9 +211,8 @@ class DeepEPEngine:
|
||||
self.num_experts = num_experts
|
||||
self.num_local_experts = num_experts // ep_size
|
||||
self.async_finish = async_finish
|
||||
from paddle.base.core import Config
|
||||
|
||||
self.ep_config = Config(24, 6, 256)
|
||||
self.ep_config = None
|
||||
|
||||
# Store phase and role for buffer management
|
||||
self._splitwise_role = splitwise_role
|
||||
|
@@ -76,6 +76,7 @@ else:
|
||||
update_inputs,
|
||||
step_reschedule,
|
||||
update_inputs_v1,
|
||||
speculate_step_reschedule,
|
||||
)
|
||||
|
||||
|
||||
@@ -413,12 +414,11 @@ def step_cuda(
|
||||
"""
|
||||
|
||||
if speculative_config.method is not None:
|
||||
if enable_prefix_caching:
|
||||
speculate_step_system_cache(
|
||||
if DISABLE_RECOVER:
|
||||
speculate_step_reschedule(
|
||||
share_inputs["stop_flags"],
|
||||
share_inputs["seq_lens_this_time"],
|
||||
share_inputs["step_seq_lens_encoder"],
|
||||
share_inputs["step_seq_lens_decoder"],
|
||||
share_inputs["seq_lens_encoder"],
|
||||
share_inputs["seq_lens_decoder"],
|
||||
share_inputs["block_tables"],
|
||||
@@ -444,64 +444,67 @@ def step_cuda(
|
||||
speculative_config.num_speculative_tokens,
|
||||
)
|
||||
else:
|
||||
speculate_step_paddle(
|
||||
share_inputs["stop_flags"],
|
||||
share_inputs["seq_lens_this_time"],
|
||||
share_inputs["step_seq_lens_encoder"],
|
||||
share_inputs["seq_lens_encoder"],
|
||||
share_inputs["seq_lens_decoder"],
|
||||
share_inputs["block_tables"],
|
||||
share_inputs["encoder_block_lens"],
|
||||
share_inputs["is_block_step"],
|
||||
share_inputs["step_block_list"],
|
||||
share_inputs["step_lens"],
|
||||
share_inputs["recover_block_list"],
|
||||
share_inputs["recover_lens"],
|
||||
share_inputs["need_block_list"],
|
||||
share_inputs["need_block_len"],
|
||||
share_inputs["used_list_len"],
|
||||
share_inputs["free_list"],
|
||||
share_inputs["free_list_len"],
|
||||
share_inputs["input_ids"],
|
||||
share_inputs["pre_ids"],
|
||||
share_inputs["step_idx"],
|
||||
share_inputs["next_tokens"],
|
||||
share_inputs["first_token_ids"],
|
||||
share_inputs["accept_num"],
|
||||
block_size,
|
||||
enc_dec_block_num,
|
||||
speculative_config.num_speculative_tokens,
|
||||
)
|
||||
if enable_prefix_caching:
|
||||
speculate_step_system_cache(
|
||||
share_inputs["stop_flags"],
|
||||
share_inputs["seq_lens_this_time"],
|
||||
share_inputs["step_seq_lens_encoder"],
|
||||
share_inputs["step_seq_lens_decoder"],
|
||||
share_inputs["seq_lens_encoder"],
|
||||
share_inputs["seq_lens_decoder"],
|
||||
share_inputs["block_tables"],
|
||||
share_inputs["encoder_block_lens"],
|
||||
share_inputs["is_block_step"],
|
||||
share_inputs["step_block_list"],
|
||||
share_inputs["step_lens"],
|
||||
share_inputs["recover_block_list"],
|
||||
share_inputs["recover_lens"],
|
||||
share_inputs["need_block_list"],
|
||||
share_inputs["need_block_len"],
|
||||
share_inputs["used_list_len"],
|
||||
share_inputs["free_list"],
|
||||
share_inputs["free_list_len"],
|
||||
share_inputs["input_ids"],
|
||||
share_inputs["pre_ids"],
|
||||
share_inputs["step_idx"],
|
||||
share_inputs["next_tokens"],
|
||||
share_inputs["first_token_ids"],
|
||||
share_inputs["accept_num"],
|
||||
block_size,
|
||||
enc_dec_block_num,
|
||||
speculative_config.num_speculative_tokens,
|
||||
)
|
||||
else:
|
||||
speculate_step_paddle(
|
||||
share_inputs["stop_flags"],
|
||||
share_inputs["seq_lens_this_time"],
|
||||
share_inputs["step_seq_lens_encoder"],
|
||||
share_inputs["seq_lens_encoder"],
|
||||
share_inputs["seq_lens_decoder"],
|
||||
share_inputs["block_tables"],
|
||||
share_inputs["encoder_block_lens"],
|
||||
share_inputs["is_block_step"],
|
||||
share_inputs["step_block_list"],
|
||||
share_inputs["step_lens"],
|
||||
share_inputs["recover_block_list"],
|
||||
share_inputs["recover_lens"],
|
||||
share_inputs["need_block_list"],
|
||||
share_inputs["need_block_len"],
|
||||
share_inputs["used_list_len"],
|
||||
share_inputs["free_list"],
|
||||
share_inputs["free_list_len"],
|
||||
share_inputs["input_ids"],
|
||||
share_inputs["pre_ids"],
|
||||
share_inputs["step_idx"],
|
||||
share_inputs["next_tokens"],
|
||||
share_inputs["first_token_ids"],
|
||||
share_inputs["accept_num"],
|
||||
block_size,
|
||||
enc_dec_block_num,
|
||||
speculative_config.num_speculative_tokens,
|
||||
)
|
||||
else:
|
||||
if enable_prefix_caching:
|
||||
step_system_cache(
|
||||
share_inputs["stop_flags"],
|
||||
share_inputs["seq_lens_this_time"],
|
||||
share_inputs["step_seq_lens_encoder"],
|
||||
share_inputs["step_seq_lens_decoder"],
|
||||
share_inputs["seq_lens_encoder"],
|
||||
share_inputs["seq_lens_decoder"],
|
||||
share_inputs["block_tables"],
|
||||
share_inputs["encoder_block_lens"],
|
||||
share_inputs["is_block_step"],
|
||||
share_inputs["step_block_list"],
|
||||
share_inputs["step_lens"],
|
||||
share_inputs["recover_block_list"],
|
||||
share_inputs["recover_lens"],
|
||||
share_inputs["need_block_list"],
|
||||
share_inputs["need_block_len"],
|
||||
share_inputs["used_list_len"],
|
||||
share_inputs["free_list"],
|
||||
share_inputs["free_list_len"],
|
||||
share_inputs["input_ids"],
|
||||
share_inputs["pre_ids"],
|
||||
share_inputs["step_idx"],
|
||||
share_inputs["next_tokens"],
|
||||
share_inputs["first_token_ids"],
|
||||
block_size,
|
||||
enc_dec_block_num,
|
||||
)
|
||||
elif DISABLE_RECOVER:
|
||||
if DISABLE_RECOVER:
|
||||
step_reschedule(
|
||||
share_inputs["stop_flags"],
|
||||
share_inputs["seq_lens_this_time"],
|
||||
@@ -529,32 +532,61 @@ def step_cuda(
|
||||
enc_dec_block_num,
|
||||
)
|
||||
else:
|
||||
step_paddle(
|
||||
share_inputs["stop_flags"],
|
||||
share_inputs["seq_lens_this_time"],
|
||||
share_inputs["step_seq_lens_encoder"],
|
||||
share_inputs["seq_lens_encoder"],
|
||||
share_inputs["seq_lens_decoder"],
|
||||
share_inputs["block_tables"],
|
||||
share_inputs["encoder_block_lens"],
|
||||
share_inputs["is_block_step"],
|
||||
share_inputs["step_block_list"],
|
||||
share_inputs["step_lens"],
|
||||
share_inputs["recover_block_list"],
|
||||
share_inputs["recover_lens"],
|
||||
share_inputs["need_block_list"],
|
||||
share_inputs["need_block_len"],
|
||||
share_inputs["used_list_len"],
|
||||
share_inputs["free_list"],
|
||||
share_inputs["free_list_len"],
|
||||
share_inputs["input_ids"],
|
||||
share_inputs["pre_ids"],
|
||||
share_inputs["step_idx"],
|
||||
share_inputs["next_tokens"],
|
||||
share_inputs["first_token_ids"],
|
||||
block_size,
|
||||
enc_dec_block_num,
|
||||
)
|
||||
if enable_prefix_caching:
|
||||
step_system_cache(
|
||||
share_inputs["stop_flags"],
|
||||
share_inputs["seq_lens_this_time"],
|
||||
share_inputs["step_seq_lens_encoder"],
|
||||
share_inputs["step_seq_lens_decoder"],
|
||||
share_inputs["seq_lens_encoder"],
|
||||
share_inputs["seq_lens_decoder"],
|
||||
share_inputs["block_tables"],
|
||||
share_inputs["encoder_block_lens"],
|
||||
share_inputs["is_block_step"],
|
||||
share_inputs["step_block_list"],
|
||||
share_inputs["step_lens"],
|
||||
share_inputs["recover_block_list"],
|
||||
share_inputs["recover_lens"],
|
||||
share_inputs["need_block_list"],
|
||||
share_inputs["need_block_len"],
|
||||
share_inputs["used_list_len"],
|
||||
share_inputs["free_list"],
|
||||
share_inputs["free_list_len"],
|
||||
share_inputs["input_ids"],
|
||||
share_inputs["pre_ids"],
|
||||
share_inputs["step_idx"],
|
||||
share_inputs["next_tokens"],
|
||||
share_inputs["first_token_ids"],
|
||||
block_size,
|
||||
enc_dec_block_num,
|
||||
)
|
||||
else:
|
||||
step_paddle(
|
||||
share_inputs["stop_flags"],
|
||||
share_inputs["seq_lens_this_time"],
|
||||
share_inputs["step_seq_lens_encoder"],
|
||||
share_inputs["seq_lens_encoder"],
|
||||
share_inputs["seq_lens_decoder"],
|
||||
share_inputs["block_tables"],
|
||||
share_inputs["encoder_block_lens"],
|
||||
share_inputs["is_block_step"],
|
||||
share_inputs["step_block_list"],
|
||||
share_inputs["step_lens"],
|
||||
share_inputs["recover_block_list"],
|
||||
share_inputs["recover_lens"],
|
||||
share_inputs["need_block_list"],
|
||||
share_inputs["need_block_len"],
|
||||
share_inputs["used_list_len"],
|
||||
share_inputs["free_list"],
|
||||
share_inputs["free_list_len"],
|
||||
share_inputs["input_ids"],
|
||||
share_inputs["pre_ids"],
|
||||
share_inputs["step_idx"],
|
||||
share_inputs["next_tokens"],
|
||||
share_inputs["first_token_ids"],
|
||||
block_size,
|
||||
enc_dec_block_num,
|
||||
)
|
||||
|
||||
|
||||
def rebuild_padding(
|
||||
|
@@ -58,7 +58,6 @@ class TokenProcessor:
|
||||
self.split_connector = split_connector
|
||||
|
||||
if envs.FD_USE_GET_SAVE_OUTPUT_V1:
|
||||
|
||||
llm_logger.debug(f"create zmq get_save_output_rank{self.cfg.parallel_config.local_data_parallel_id}")
|
||||
self.zmq_server = ZmqIpcServer(
|
||||
name=f"get_save_output_rank{self.cfg.parallel_config.local_data_parallel_id}", mode=zmq.PULL
|
||||
@@ -298,10 +297,15 @@ class TokenProcessor:
|
||||
try:
|
||||
is_blocking = True
|
||||
if self.speculative_decoding:
|
||||
speculate_get_output(self.output_tokens, rank_id, is_blocking, False)
|
||||
if (
|
||||
self.cfg.parallel_config.enable_expert_parallel
|
||||
and self.cfg.parallel_config.data_parallel_size > 1
|
||||
):
|
||||
speculate_get_output(self.output_tokens, rank_id, is_blocking, True)
|
||||
else:
|
||||
speculate_get_output(self.output_tokens, rank_id, is_blocking, False)
|
||||
if self.output_tokens[0] == -2:
|
||||
continue
|
||||
|
||||
else:
|
||||
if self.use_logprobs:
|
||||
get_output_topk(
|
||||
@@ -370,14 +374,18 @@ class TokenProcessor:
|
||||
llm_logger.info(f"finished_task_id: {finished_task_id}")
|
||||
self.prefill_result_status[finished_task_id[0]] = finished_task_id[1]
|
||||
if task_id in self.prefill_result_status:
|
||||
self.split_connector.send_first_token(task.disaggregate_info, [result])
|
||||
self.resource_manager.stop_flags[index] = True
|
||||
self.resource_manager.tasks_list[index] = None
|
||||
self.resource_manager._recycle_block_tables(task)
|
||||
if envs.ENABLE_V1_KVCACHE_SCHEDULER:
|
||||
self.resource_manager.finish_requests_async(task_id)
|
||||
else:
|
||||
self.resource_manager.stop_flags[index] = True
|
||||
self.resource_manager.tasks_list[index] = None
|
||||
self.resource_manager._recycle_block_tables(task)
|
||||
if task_id in self.resource_manager.req_dict:
|
||||
del self.resource_manager.req_dict[task_id]
|
||||
if self.prefill_result_status[task_id] != "finished":
|
||||
result.error_code = 400
|
||||
result.error_message = f"{task_id} failed to {self.prefill_result_status[task_id]}"
|
||||
del self.resource_manager.req_dict[task_id]
|
||||
self.split_connector.send_first_token(task.disaggregate_info, [result])
|
||||
break
|
||||
else:
|
||||
time.sleep(0.002)
|
||||
@@ -388,6 +396,8 @@ class TokenProcessor:
|
||||
self.resource_manager.stop_flags[index] = True
|
||||
self.resource_manager.tasks_list[index] = None
|
||||
self.resource_manager._recycle_block_tables(task)
|
||||
if task_id in self.resource_manager.req_dict:
|
||||
del self.resource_manager.req_dict[task_id]
|
||||
|
||||
task_used_block_num = sum([len(task.block_tables) if task else 0 for task in self.resource_manager.tasks_list])
|
||||
main_process_metrics.available_gpu_block_num.set(
|
||||
@@ -461,16 +471,22 @@ class TokenProcessor:
|
||||
|
||||
task_id = task.request_id
|
||||
if self.cfg.speculative_config.method:
|
||||
token_ids = tokens[
|
||||
2
|
||||
+ SPECULATE_MAX_BSZ
|
||||
+ i * MAX_DRAFT_TOKENS : 2
|
||||
+ SPECULATE_MAX_BSZ
|
||||
+ i * MAX_DRAFT_TOKENS
|
||||
+ accept_num[i]
|
||||
].tolist()
|
||||
if len(token_ids) == 0 or token_ids[-1] <= 0:
|
||||
continue
|
||||
if accept_num[i] == -3:
|
||||
recovery_stop = True
|
||||
if recovery_stop:
|
||||
llm_logger.info(f"recovery stop signal found at task {task_id}")
|
||||
token_ids = [RECOVERY_STOP_SIGNAL]
|
||||
else:
|
||||
token_ids = tokens[
|
||||
2
|
||||
+ SPECULATE_MAX_BSZ
|
||||
+ i * MAX_DRAFT_TOKENS : 2
|
||||
+ SPECULATE_MAX_BSZ
|
||||
+ i * MAX_DRAFT_TOKENS
|
||||
+ accept_num[i]
|
||||
].tolist()
|
||||
if (not recovery_stop) and (len(token_ids) == 0 or token_ids[-1] <= 0):
|
||||
continue
|
||||
else:
|
||||
token_id = int(tokens[i, 0])
|
||||
token_ids = [token_id]
|
||||
@@ -527,7 +543,7 @@ class TokenProcessor:
|
||||
if self.tokens_counter[task_id] == 0:
|
||||
if task.messages is not None:
|
||||
result.prompt = task.messages
|
||||
result.num_cached_tokens = task.num_cached_tokens
|
||||
result.num_cached_tokens = task.num_cached_tokens
|
||||
|
||||
is_prefill = task.disaggregate_info is not None and task.disaggregate_info["role"] == "prefill"
|
||||
|
||||
@@ -537,7 +553,8 @@ class TokenProcessor:
|
||||
for token_id in token_ids:
|
||||
self.tokens_counter[task_id] += 1
|
||||
if token_id != RECOVERY_STOP_SIGNAL:
|
||||
result.outputs.token_ids.append(token_id)
|
||||
if not (envs.FD_ENABLE_INTERNAL_ADAPTER and token_id in task.eos_token_ids):
|
||||
result.outputs.token_ids.append(token_id)
|
||||
task.output_token_ids.append(token_id)
|
||||
if self.use_logprobs:
|
||||
result.outputs.logprob = float(scores[i, 0])
|
||||
@@ -567,7 +584,11 @@ class TokenProcessor:
|
||||
self._record_completion_metrics(task, current_time)
|
||||
self._recycle_resources(task_id, i, task, result, is_prefill)
|
||||
break
|
||||
if not is_prefill or self.cfg.scheduler_config.name == "splitwise":
|
||||
if (
|
||||
not is_prefill
|
||||
or self.cfg.scheduler_config.name == "splitwise"
|
||||
or self.cfg.scheduler_config.name == "dp"
|
||||
):
|
||||
batch_result.append(result)
|
||||
|
||||
self.postprocess(batch_result)
|
||||
@@ -609,7 +630,7 @@ class TokenProcessor:
|
||||
self.cfg.speculative_config.num_speculative_tokens,
|
||||
)
|
||||
|
||||
real_accept_num = [x for x in accept_num if x != 0]
|
||||
real_accept_num = [x for x in accept_num if x > 0]
|
||||
num_accepted_tokens = sum([x - 1 for x in real_accept_num])
|
||||
self.num_accepted_tokens += num_accepted_tokens
|
||||
num_emitted_tokens = sum(real_accept_num)
|
||||
|
@@ -18,6 +18,7 @@ import redis
|
||||
|
||||
from fastdeploy.utils import llm_logger
|
||||
|
||||
from .dp_scheduler import DPScheduler
|
||||
from .global_scheduler import GlobalScheduler
|
||||
from .local_scheduler import LocalScheduler
|
||||
from .splitwise_scheduler import SplitWiseScheduler, SplitWiseSchedulerConfig
|
||||
@@ -89,6 +90,54 @@ class LocalSchedulerConfig:
|
||||
llm_logger.info("=============================================================")
|
||||
|
||||
|
||||
class DPLocalSchedulerConfig(LocalSchedulerConfig):
|
||||
"""
|
||||
Configuration class for DPLocalScheduler.
|
||||
Attributes:
|
||||
max_size: Maximum number of concurrent requests (-1 for unlimited)
|
||||
ttl: Time-to-live in seconds for request expiration
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
max_size: int = -1,
|
||||
ttl: int = 900,
|
||||
max_model_len: int = 8192,
|
||||
enable_chunked_prefill: bool = False,
|
||||
max_num_partial_prefills: int = 1,
|
||||
max_long_partial_prefills: int = 1,
|
||||
long_prefill_token_threshold: int = 0,
|
||||
splitwise_role: str = "prefill",
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Initialize LocalScheduler configuration.
|
||||
Args:
|
||||
max_size: Maximum concurrent requests (-1 for unlimited, 0 for disabled)
|
||||
ttl: Time-to-live in seconds for request expiration (default 900s)
|
||||
max_model_len: Maximum model context length in tokens
|
||||
enable_chunked_prefill: Whether to enable chunked prefill processing
|
||||
max_num_partial_prefills: Max partial prefill operations allowed
|
||||
max_long_partial_prefills: Max long-running partial prefill ops
|
||||
long_prefill_token_threshold: Token count threshold for long prefill
|
||||
**kwargs: Additional unused arguments (for forward compatibility)
|
||||
Note:
|
||||
- If long_prefill_token_threshold is 0, it's auto-calculated as 4% of max_model_len
|
||||
- See LocalScheduler class for implementation details
|
||||
"""
|
||||
self.max_size = max_size
|
||||
self.ttl = ttl
|
||||
|
||||
self.max_model_len = max_model_len
|
||||
self.enable_chunked_prefill = enable_chunked_prefill
|
||||
self.max_num_partial_prefills = max_num_partial_prefills
|
||||
self.max_long_partial_prefills = max_long_partial_prefills
|
||||
self.long_prefill_token_threshold = long_prefill_token_threshold
|
||||
if self.long_prefill_token_threshold == 0:
|
||||
self.long_prefill_token_threshold = int(self.max_model_len * 0.04)
|
||||
self.splitwise_role = splitwise_role
|
||||
|
||||
|
||||
class GlobalSchedulerConfig:
|
||||
"""
|
||||
Configuration class for GlobalScheduler (Redis-based).
|
||||
@@ -235,6 +284,9 @@ class SchedulerConfig:
|
||||
if self.name == "splitwise":
|
||||
self.config = SplitWiseSchedulerConfig(**args)
|
||||
|
||||
if self.name == "dp":
|
||||
self.config = DPLocalSchedulerConfig(**args)
|
||||
|
||||
def check(self):
|
||||
"""
|
||||
Validate the configuration.
|
||||
@@ -242,7 +294,7 @@ class SchedulerConfig:
|
||||
Raises:
|
||||
Exception: If invalid scheduler type is specified
|
||||
"""
|
||||
if self.name not in ["local", "global", "splitwise"]:
|
||||
if self.name not in ["local", "global", "splitwise", "dp"]:
|
||||
raise Exception(f"Unknown scheduler type {self.name}")
|
||||
|
||||
self.config.check()
|
||||
@@ -280,6 +332,17 @@ class SchedulerConfig:
|
||||
if self.name == "splitwise":
|
||||
return SplitWiseScheduler(self.config)
|
||||
|
||||
if self.name == "dp":
|
||||
return DPScheduler(
|
||||
max_size=self.config.max_size,
|
||||
ttl=self.config.ttl,
|
||||
enable_chunked_prefill=self.config.enable_chunked_prefill,
|
||||
max_num_partial_prefills=self.config.max_num_partial_prefills,
|
||||
max_long_partial_prefills=self.config.max_long_partial_prefills,
|
||||
long_prefill_token_threshold=self.config.long_prefill_token_threshold,
|
||||
splitwise_role=self.config.splitwise_role,
|
||||
)
|
||||
|
||||
return LocalScheduler(
|
||||
max_size=self.config.max_size,
|
||||
ttl=self.config.ttl,
|
||||
|
272
fastdeploy/scheduler/dp_scheduler.py
Normal file
272
fastdeploy/scheduler/dp_scheduler.py
Normal file
@@ -0,0 +1,272 @@
|
||||
"""
|
||||
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import threading
|
||||
import time
|
||||
from multiprocessing import Queue
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from fastdeploy.engine.request import Request, RequestOutput
|
||||
from fastdeploy.scheduler.data import ScheduledResponse
|
||||
from fastdeploy.scheduler.local_scheduler import LocalScheduler
|
||||
from fastdeploy.utils import envs, get_logger
|
||||
|
||||
|
||||
class DPLocalScheduler(LocalScheduler):
|
||||
def __init__(
|
||||
self,
|
||||
max_size: int,
|
||||
ttl: int,
|
||||
enable_chunked_prefill: bool,
|
||||
max_num_partial_prefills: int,
|
||||
max_long_partial_prefills: int,
|
||||
long_prefill_token_threshold: int,
|
||||
splitwise_role: str = "prefill",
|
||||
):
|
||||
super().__init__(
|
||||
max_size,
|
||||
ttl,
|
||||
enable_chunked_prefill,
|
||||
max_num_partial_prefills,
|
||||
max_long_partial_prefills,
|
||||
long_prefill_token_threshold,
|
||||
)
|
||||
self.splitwise_role = splitwise_role
|
||||
self.scheduler_logger = logging
|
||||
|
||||
def put_results(self, results: List[RequestOutput]):
|
||||
"""
|
||||
Add processing results back to the scheduler.
|
||||
Args:
|
||||
results: List of RequestOutput objects containing results
|
||||
"""
|
||||
responses: List[ScheduledResponse] = [ScheduledResponse(result) for result in results]
|
||||
|
||||
finished_responses = [response.request_id for response in responses if response.finished]
|
||||
if len(finished_responses) > 0:
|
||||
self.scheduler_logger.info(f"Scheduler has received some finished responses: {finished_responses}")
|
||||
|
||||
with self.mutex:
|
||||
for response in responses:
|
||||
if response.request_id not in self.responses:
|
||||
self.responses[response.request_id] = [response]
|
||||
continue
|
||||
self.responses[response.request_id].append(response)
|
||||
self.responses_not_empty.notify_all()
|
||||
|
||||
def _recycle(self, request_id: Optional[str] = None):
|
||||
"""
|
||||
Clean up expired or completed requests to free memory.
|
||||
Args:
|
||||
request_id: Optional specific request ID to remove.
|
||||
If None, removes all expired requests.
|
||||
"""
|
||||
if request_id is not None:
|
||||
self.requests.pop(request_id, None)
|
||||
self.responses.pop(request_id, None)
|
||||
if self.splitwise_role == "decode":
|
||||
return
|
||||
self.ids.pop(self.ids.index(request_id))
|
||||
self.ids_read_cursor -= 1
|
||||
return
|
||||
|
||||
if self.max_size <= 0:
|
||||
return
|
||||
|
||||
if len(self.requests) <= self.max_size:
|
||||
return
|
||||
|
||||
now = time.time()
|
||||
expired_ids = []
|
||||
for request_id in self.ids:
|
||||
request = self.requests[request_id]
|
||||
if now - request.schedule_time < self.ttl:
|
||||
break
|
||||
expired_ids.append(request.request_id)
|
||||
|
||||
for i, expired_id in enumerate(expired_ids):
|
||||
self.requests.pop(expired_id, None)
|
||||
self.responses.pop(expired_id, None)
|
||||
self.ids.pop(i)
|
||||
|
||||
if len(expired_ids) > 0:
|
||||
if len(expired_ids) - 1 >= self.ids_read_cursor:
|
||||
self.ids_read_cursor = 0
|
||||
else:
|
||||
self.ids_read_cursor -= len(expired_ids)
|
||||
|
||||
def get_requests(
|
||||
self,
|
||||
available_blocks,
|
||||
block_size,
|
||||
reserved_output_blocks,
|
||||
max_num_batched_tokens,
|
||||
batch=1,
|
||||
) -> List[Request]:
|
||||
"""
|
||||
Retrieve requests from the scheduler based on available resources.
|
||||
|
||||
Args:
|
||||
available_blocks: Number of available processing blocks
|
||||
block_size: Size of each processing block
|
||||
reserved_output_blocks: Blocks reserved for output
|
||||
max_num_batched_tokens: Maximum tokens that can be batched
|
||||
batch: Preferred batch size
|
||||
|
||||
Returns:
|
||||
List of Request objects ready for processing
|
||||
"""
|
||||
if available_blocks <= reserved_output_blocks or batch < 1:
|
||||
self.scheduler_logger.debug(
|
||||
f"Scheduler's resource are insufficient: available_blocks={available_blocks} "
|
||||
f"reserved_output_blocks={reserved_output_blocks} batch={batch} "
|
||||
f"max_num_batched_tokens={max_num_batched_tokens}"
|
||||
)
|
||||
return []
|
||||
required_total_blocks = 0
|
||||
current_prefill_tokens = 0
|
||||
start_batch_time = time.time()
|
||||
requests: List[Request] = []
|
||||
|
||||
with self.requests_not_empty:
|
||||
if not envs.ENABLE_V1_KVCACHE_SCHEDULER:
|
||||
while True:
|
||||
batch_ids = self.requests_not_empty.wait_for(
|
||||
lambda: self.ids[self.ids_read_cursor : self.ids_read_cursor + batch],
|
||||
0.005,
|
||||
)
|
||||
if batch_ids:
|
||||
for request_id in batch_ids:
|
||||
request = self.requests[request_id]
|
||||
required_input_blocks = self.calc_required_blocks(
|
||||
request.prompt_tokens_ids_len, block_size
|
||||
)
|
||||
current_prefill_tokens += request.prompt_tokens_ids_len
|
||||
required_total_blocks += required_input_blocks + reserved_output_blocks
|
||||
if required_total_blocks > available_blocks:
|
||||
break
|
||||
|
||||
requests.append(request.raw)
|
||||
self.ids_read_cursor += 1
|
||||
start_batch_time = time.time()
|
||||
if current_prefill_tokens > max_num_batched_tokens:
|
||||
break
|
||||
if len(requests) >= batch:
|
||||
break
|
||||
if (
|
||||
(current_prefill_tokens > max_num_batched_tokens)
|
||||
or (len(requests) >= batch)
|
||||
or (time.time() - start_batch_time > envs.FD_EP_BATCHED_TOKEN_TIMEOUT)
|
||||
):
|
||||
break
|
||||
else:
|
||||
batch_ids = self.requests_not_empty.wait_for(
|
||||
lambda: self.ids[self.ids_read_cursor : self.ids_read_cursor + batch],
|
||||
0.005,
|
||||
)
|
||||
if batch_ids:
|
||||
for request_id in batch_ids:
|
||||
request = self.requests[request_id]
|
||||
requests.append(request.raw)
|
||||
self.ids_read_cursor += 1
|
||||
|
||||
if batch_ids:
|
||||
if len(batch_ids) > 0 and len(requests) == 0:
|
||||
self.scheduler_logger.debug(
|
||||
f"Scheduler has put all just-pulled request into the queue: {len(batch_ids)}"
|
||||
)
|
||||
|
||||
if len(requests) > 0:
|
||||
self.scheduler_logger.info(
|
||||
f"Scheduler has pulled some request: {[request.request_id for request in requests]}"
|
||||
)
|
||||
|
||||
return requests
|
||||
|
||||
|
||||
class DPScheduler:
|
||||
def __init__(
|
||||
self,
|
||||
max_size: int,
|
||||
ttl: int,
|
||||
enable_chunked_prefill: bool,
|
||||
max_num_partial_prefills: int,
|
||||
max_long_partial_prefills: int,
|
||||
long_prefill_token_threshold: int,
|
||||
splitwise_role: str = "prefill",
|
||||
):
|
||||
self._scheduler = DPLocalScheduler(
|
||||
max_size,
|
||||
ttl,
|
||||
enable_chunked_prefill,
|
||||
max_num_partial_prefills,
|
||||
max_long_partial_prefills,
|
||||
long_prefill_token_threshold,
|
||||
splitwise_role,
|
||||
)
|
||||
|
||||
def start(self, dp_rank: int, request_queues: List[Queue], result_queue: Queue):
|
||||
self.dp_rank = dp_rank
|
||||
self.request_queues = request_queues
|
||||
self.result_queue = result_queue
|
||||
self.scheduler_logger = get_logger("dpscheduler", f"dp_scheduler_rank{self.dp_rank}.log")
|
||||
self._scheduler.scheduler_logger = self.scheduler_logger
|
||||
threading.Thread(target=self._put_requests_to_local).start()
|
||||
threading.Thread(target=self._get_response_from_local).start()
|
||||
|
||||
def put_requests(self, requests: List[Dict]):
|
||||
results = []
|
||||
for request in requests:
|
||||
if not hasattr(request, "dp_rank"):
|
||||
raise ValueError(f"Request object is missing the 'dp_rank' attribute: {request}")
|
||||
self.request_queues[request.dp_rank].put(request)
|
||||
results.append((request.request_id, None))
|
||||
return results
|
||||
|
||||
def _put_requests_to_local(self):
|
||||
while True:
|
||||
request = self.request_queues[self.dp_rank].get()
|
||||
self.scheduler_logger.info(f"Recieve request from puller, request_id: {request.request_id}")
|
||||
self._scheduler.put_requests([request])
|
||||
|
||||
def _get_response_from_local(self):
|
||||
while True:
|
||||
results = self._scheduler.get_results()
|
||||
if len(results) == 0:
|
||||
continue
|
||||
self.result_queue.put(results)
|
||||
|
||||
def get_requests(
|
||||
self,
|
||||
available_blocks,
|
||||
block_size,
|
||||
reserved_output_blocks,
|
||||
max_num_batched_tokens,
|
||||
batch=1,
|
||||
) -> List[Request]:
|
||||
return self._scheduler.get_requests(
|
||||
available_blocks, block_size, reserved_output_blocks, max_num_batched_tokens, batch
|
||||
)
|
||||
|
||||
def get_unhandled_request_num(self):
|
||||
return len(self._scheduler.requests)
|
||||
|
||||
def put_results(self, results: List[RequestOutput]):
|
||||
self._scheduler.put_results(results)
|
||||
|
||||
def get_results(self) -> Dict[str, List[RequestOutput]]:
|
||||
return self.result_queue.get()
|
@@ -28,8 +28,6 @@ from fastdeploy.inter_communicator import EngineWorkerQueue
|
||||
from fastdeploy.metrics.metrics import main_process_metrics
|
||||
from fastdeploy.utils import get_logger
|
||||
|
||||
logger = get_logger("splitwise_connector", "splitwise_connector.log")
|
||||
|
||||
|
||||
class SplitwiseConnector:
|
||||
"""
|
||||
@@ -46,12 +44,19 @@ class SplitwiseConnector:
|
||||
resource_manager (object): Resource manager object.
|
||||
"""
|
||||
self.cfg = cfg
|
||||
if self.cfg.parallel_config.enable_expert_parallel and self.cfg.parallel_config.data_parallel_size > 1:
|
||||
self.logger = get_logger(
|
||||
"splitwise_connector", f"splitwise_connector_{self.cfg.parallel_config.local_data_parallel_id}.log"
|
||||
)
|
||||
else:
|
||||
self.logger = get_logger("splitwise_connector", "splitwise_connector.log")
|
||||
self.engine_worker_queue = worker_queue
|
||||
self.resource_manager = resource_manager
|
||||
self.connect_innode_instances = {}
|
||||
self.temp_cache_info = dict()
|
||||
self.current_request_ids = dict()
|
||||
self.idx = self.cfg.parallel_config.local_data_parallel_id
|
||||
self.enable_decode_cache_task = envs.FD_ENABLE_CACHE_TASK == "1"
|
||||
|
||||
if self.cfg.cache_config.pd_comm_port is not None:
|
||||
self.zmq_ctx = zmq.Context()
|
||||
@@ -70,7 +75,7 @@ class SplitwiseConnector:
|
||||
self.router_socket.setsockopt(zmq.SNDHWM, 1000)
|
||||
self.router_socket.setsockopt(zmq.ROUTER_MANDATORY, 1)
|
||||
self.router_socket.bind(f"tcp://*:{self.cfg.cache_config.pd_comm_port[0]}")
|
||||
logger.info(f"bind {self.cfg.cache_config.pd_comm_port}")
|
||||
self.logger.info(f"bind {self.cfg.cache_config.pd_comm_port}")
|
||||
|
||||
self.poller = zmq.Poller()
|
||||
self.poller.register(self.router_socket, zmq.POLLIN)
|
||||
@@ -90,17 +95,17 @@ class SplitwiseConnector:
|
||||
if not socks:
|
||||
continue
|
||||
else:
|
||||
logger.debug(f"receive {socks}")
|
||||
self.logger.debug(f"receive {socks}")
|
||||
|
||||
frames = self.router_socket.recv_multipart()
|
||||
logger.debug(f"frames: {frames}")
|
||||
self.logger.debug(f"frames: {frames}")
|
||||
message = frames[-1]
|
||||
self.io_executor.submit(self._process_message, message)
|
||||
time.sleep(0.001)
|
||||
else:
|
||||
time.sleep(5)
|
||||
except Exception as e:
|
||||
logger.error(f"Receiver error: {e}, {str(traceback.format_exc())}")
|
||||
self.logger.error(f"Receiver error: {e}, {str(traceback.format_exc())}")
|
||||
time.sleep(1)
|
||||
|
||||
def _get_push_socket(self, addr):
|
||||
@@ -112,7 +117,7 @@ class SplitwiseConnector:
|
||||
return sock
|
||||
|
||||
try:
|
||||
logger.info(f"Establishing new connection to {addr}")
|
||||
self.logger.info(f"Establishing new connection to {addr}")
|
||||
sock = self.zmq_ctx.socket(zmq.DEALER)
|
||||
|
||||
# 设置连接参数
|
||||
@@ -131,7 +136,7 @@ class SplitwiseConnector:
|
||||
return sock
|
||||
|
||||
except zmq.ZMQError as e:
|
||||
logger.error(f"Connection to {addr} failed: {e}")
|
||||
self.logger.error(f"Connection to {addr} failed: {e}")
|
||||
|
||||
raise ConnectionError(f"Failed to connect to {addr}") from e
|
||||
|
||||
@@ -140,7 +145,7 @@ class SplitwiseConnector:
|
||||
return
|
||||
|
||||
try:
|
||||
logger.info(f"Sent {msg_type} to {addr}")
|
||||
self.logger.info(f"Sent {msg_type} to {addr}")
|
||||
message = self._serialize_message(msg_type, payload)
|
||||
|
||||
try:
|
||||
@@ -148,19 +153,19 @@ class SplitwiseConnector:
|
||||
sock = self._get_push_socket(addr)
|
||||
sock.send_multipart([b"", message])
|
||||
|
||||
logger.info(f"Sent {msg_type} to {addr}")
|
||||
self.logger.info(f"Sent {msg_type} to {addr}")
|
||||
|
||||
except ConnectionError:
|
||||
logger.warning(f"Connection to {addr} not established")
|
||||
self.logger.warning(f"Connection to {addr} not established")
|
||||
except zmq.Again:
|
||||
logger.warning(f"Send queue full for {addr}")
|
||||
self.logger.warning(f"Send queue full for {addr}")
|
||||
except Exception as e:
|
||||
logger.error(f"Send to {addr} failed: {e}, {str(traceback.format_exc())}")
|
||||
self.logger.error(f"Send to {addr} failed: {e}, {str(traceback.format_exc())}")
|
||||
main_process_metrics.send_cache_failed_num.inc()
|
||||
self._close_connection(addr)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Message preparation failed: {e}")
|
||||
self.logger.error(f"Message preparation failed: {e}")
|
||||
|
||||
def _close_connection(self, addr):
|
||||
"""
|
||||
@@ -265,7 +270,7 @@ class SplitwiseConnector:
|
||||
f"{task.disaggregate_info['cache_info']['rdma']['ip']}:"
|
||||
+ f"{task.disaggregate_info['cache_info']['rdma']['port']}"
|
||||
)
|
||||
logger.info(f"send splitwise tasks to port {addr} decode")
|
||||
self.logger.info(f"send splitwise tasks to port {addr} decode")
|
||||
self.current_request_ids[task.request_id] = "init"
|
||||
decode_diagg = task.disaggregate_info["cache_info"]
|
||||
task.disaggregate_info["cache_info"] = self.cfg.disaggregate_info["cache_info"]
|
||||
@@ -295,7 +300,7 @@ class SplitwiseConnector:
|
||||
self.connect_innode_instances[port].put_disaggregated_tasks(("decode", tasks))
|
||||
for task in tasks:
|
||||
task.disaggregate_info["cache_info"]["ipc"]["port"] = port
|
||||
logger.info(f"send splitwise tasks to port {port} decode")
|
||||
self.logger.info(f"send splitwise tasks to port {port} decode")
|
||||
current_port = port
|
||||
return current_port
|
||||
|
||||
@@ -305,7 +310,7 @@ class SplitwiseConnector:
|
||||
"""
|
||||
if not isinstance(tasks_list, list):
|
||||
tasks_list = [tasks_list]
|
||||
logger.info("send first token to port decode")
|
||||
self.logger.info("send first token to port decode")
|
||||
if prefill_msg["transfer_protocol"] == "ipc":
|
||||
port = prefill_msg["cache_info"]["ipc"]["port"]
|
||||
if port not in self.connect_innode_instances:
|
||||
@@ -313,7 +318,7 @@ class SplitwiseConnector:
|
||||
self.connect_innode_instances[port].put_disaggregated_tasks(("decode", tasks_list))
|
||||
else:
|
||||
node = f"{prefill_msg['cache_info']['rdma']['ip']}:{prefill_msg['cache_info']['rdma']['port']}"
|
||||
logger.info(f"send first token to port {node} decode")
|
||||
self.logger.info(f"send first token to port {node} decode")
|
||||
self._send_message(node, "decode", tasks_list)
|
||||
|
||||
def create_connection(self, port):
|
||||
@@ -329,6 +334,26 @@ class SplitwiseConnector:
|
||||
client_id=0,
|
||||
)
|
||||
|
||||
def check_decode_allocated(self, task):
|
||||
start_time = time.time()
|
||||
if task.disaggregate_info is None:
|
||||
return True, ""
|
||||
if self.enable_decode_cache_task:
|
||||
return True, ""
|
||||
if task.disaggregate_info["role"] != "prefill":
|
||||
return True, ""
|
||||
while self.current_request_ids[task.request_id] == "init":
|
||||
time.sleep(0.001)
|
||||
if time.time() - start_time > 30:
|
||||
del self.current_request_ids[task.request_id]
|
||||
return False, "timeout"
|
||||
msg = self.current_request_ids[task.request_id]
|
||||
del self.current_request_ids[task.request_id]
|
||||
if msg == "finished":
|
||||
return True, ""
|
||||
self.logger.error(f"Receive_decode_allocated error: {msg}")
|
||||
return False, msg
|
||||
|
||||
def send_cache_infos(self, tasks, current_id):
|
||||
"""
|
||||
Send cache information to specific port.
|
||||
@@ -345,7 +370,7 @@ class SplitwiseConnector:
|
||||
for i in range(len(tasks)):
|
||||
if tasks[i].disaggregate_info is None:
|
||||
continue
|
||||
logger.info(f"{tasks[i].disaggregate_info}")
|
||||
self.logger.info(f"{tasks[i].disaggregate_info}")
|
||||
if tasks[i].disaggregate_info["role"] == "decode":
|
||||
if tasks[i].disaggregate_info["transfer_protocol"] == "ipc":
|
||||
cache_info = {
|
||||
@@ -380,11 +405,19 @@ class SplitwiseConnector:
|
||||
addr = "prefill"
|
||||
if current_id == -1:
|
||||
current_id = tasks[i].disaggregate_info["cache_info"]["ipc"]["current_id"]
|
||||
cache_info = {
|
||||
"request_id": tasks[i].request_id,
|
||||
"src_block_ids": tasks[i].block_tables,
|
||||
"current_id": current_id,
|
||||
}
|
||||
if envs.ENABLE_V1_KVCACHE_SCHEDULER:
|
||||
cache_info = {
|
||||
"request_id": tasks[i].request_id,
|
||||
"src_block_ids": tasks[i].block_tables,
|
||||
"current_id": tasks[i].idx,
|
||||
"need_prefill_tokens": tasks[i].need_prefill_tokens,
|
||||
}
|
||||
else:
|
||||
cache_info = {
|
||||
"request_id": tasks[i].request_id,
|
||||
"src_block_ids": tasks[i].block_tables,
|
||||
"current_id": current_id,
|
||||
}
|
||||
if addr not in temp_cache_info:
|
||||
temp_cache_info[addr] = []
|
||||
|
||||
@@ -396,7 +429,7 @@ class SplitwiseConnector:
|
||||
else:
|
||||
if len(temp_cache_info):
|
||||
for k, v in temp_cache_info.items():
|
||||
logger.info(f"{k} {v}")
|
||||
self.logger.info(f"{k} {v}")
|
||||
if ":" in str(k):
|
||||
self._send_message(k, "cache_sync", v)
|
||||
else:
|
||||
@@ -427,7 +460,7 @@ class SplitwiseConnector:
|
||||
"""
|
||||
try:
|
||||
msg_type, payload = self._deserialize_message(message)
|
||||
logger.info(f"{msg_type}")
|
||||
self.logger.info(f"{msg_type}")
|
||||
|
||||
if msg_type == "prefill":
|
||||
self._handle_prefill(payload)
|
||||
@@ -435,11 +468,16 @@ class SplitwiseConnector:
|
||||
self._handle_decode(payload)
|
||||
elif msg_type == "cache_sync":
|
||||
for task in payload:
|
||||
del self.current_request_ids[task["request_id"]]
|
||||
self.engine_worker_queue.put_cache_info(payload)
|
||||
self.logger.info(f"cache_sync task: {task}")
|
||||
current_status = task.get("error_msg", "finished")
|
||||
self.current_request_ids[task["request_id"]] = current_status
|
||||
if self.enable_decode_cache_task:
|
||||
del self.current_request_ids[task["request_id"]]
|
||||
if current_status == "finished":
|
||||
self.engine_worker_queue.put_cache_info(payload)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Message processing failed: {e}, {str(traceback.format_exc())}")
|
||||
self.logger.error(f"Message processing failed: {e}, {str(traceback.format_exc())}")
|
||||
|
||||
def _handle_prefill(self, tasks):
|
||||
"""
|
||||
@@ -462,8 +500,12 @@ class SplitwiseConnector:
|
||||
index=task["outputs"]["index"],
|
||||
send_idx=0,
|
||||
token_ids=task["outputs"]["token_ids"],
|
||||
draft_token_ids=task["outputs"]["draft_token_ids"],
|
||||
),
|
||||
finished=True,
|
||||
num_cached_tokens=task["num_cached_tokens"],
|
||||
error_code=task["error_code"],
|
||||
error_msg=task["error_msg"],
|
||||
)
|
||||
)
|
||||
self.engine_worker_queue.put_disaggregated_tasks(("decode", tasks))
|
||||
|
@@ -16,6 +16,7 @@
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
from typing import Tuple
|
||||
|
||||
@@ -259,6 +260,7 @@ class PaddleDisWorkerProc:
|
||||
"""Main event loop for Paddle Distributed Workers.
|
||||
TODO(gongshaotian): support remote calling of functions that control worker.
|
||||
"""
|
||||
|
||||
# Currently, only support single node
|
||||
self.nnode = int((self.parallel_config.tensor_parallel_size + 7) // 8)
|
||||
req_ids = []
|
||||
@@ -643,6 +645,12 @@ def parse_args():
|
||||
help="Flag to specify dtype of lm_head as FP32",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--cache-transfer-protocol",
|
||||
type=str,
|
||||
default="ipc",
|
||||
help="support protocol list, comma separated, default is ipc",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--runner",
|
||||
type=str,
|
||||
@@ -762,8 +770,7 @@ def initialize_fd_config(args, ranks: int = 1, local_rank: int = 0) -> FDConfig:
|
||||
):
|
||||
logger.info("Set ENABLE_V1_KVCACHE_SCHEDULER to 0 due to not support speculative decoding now.")
|
||||
envs.ENABLE_V1_KVCACHE_SCHEDULER = 0
|
||||
if args.splitwise_role != "mixed":
|
||||
logger.info(f"Set ENABLE_V1_KVCACHE_SCHEDULER to 0 due to not supported {args.splitwise_role} now.")
|
||||
if args.splitwise_role != "mixed" and args.cache_transfer_protocol != "rdma":
|
||||
envs.ENABLE_V1_KVCACHE_SCHEDULER = 0
|
||||
if not current_platform.is_cuda():
|
||||
logger.info("Set ENABLE_V1_KVCACHE_SCHEDULER to 0 due to not supported.")
|
||||
@@ -772,6 +779,9 @@ def initialize_fd_config(args, ranks: int = 1, local_rank: int = 0) -> FDConfig:
|
||||
logger.info("Set ENABLE_V1_KVCACHE_SCHEDULER to 0 due to not supported guided_decoding.")
|
||||
envs.ENABLE_V1_KVCACHE_SCHEDULER = 0
|
||||
|
||||
if envs.ENABLE_V1_KVCACHE_SCHEDULER and args.splitwise_role == "prefill":
|
||||
os.environ["PREFILL_NODE_ONE_STEP_STOP_V1"] = "1"
|
||||
|
||||
fd_config = FDConfig(
|
||||
model_config=model_config,
|
||||
parallel_config=parallel_config,
|
||||
|
Reference in New Issue
Block a user