mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-04 08:16:42 +08:00
[Feature] Support pd ep deployment with yiyan adapter (#4029)
* [Feature] Support mixed deployment with yiyan adapter in release2.2
* fix metrics
* add unit test
* add unit test
* add unit test
* Support pd ep deployment with yiyan adapter
* Support pd ep deployment with yiyan adapter
* refactor cache messager
* support scheduler v1 in PD
* support pd v1 + chunk prefill
* support pd v1 + chunk prefill
* add eplb
* support eplb
* support eplb
* support eplb
* support v1
* fix
* fix
* fix bug
* remove eplb support
* support prefix cache in P
* fix bug
* fix bug
* support one stop in V1
* fix bug
* fix ci
* fix ci
* fix
* fix
* fix
* fix
* fix

Co-authored-by: YuBaoku <49938469+EmmonsCurse@users.noreply.github.com>
@@ -32,7 +32,8 @@ __global__ void update_inputs_kernel_v1(bool *not_need_stop,
                                         const int max_bsz,
                                         const int input_ids_stride,
                                         const int block_num_per_seq,
-                                        const int block_size) {
+                                        const int block_size,
+                                        bool prefill_one_step_stop) {
   int thread_idx = threadIdx.x;
   typedef cub::BlockReduce<int64_t, THREADBLOCK_SIZE> BlockReduce;
   __shared__ typename BlockReduce::TempStorage temp_storage;
@@ -54,23 +55,32 @@ __global__ void update_inputs_kernel_v1(bool *not_need_stop,
       seq_lens_encoder[thread_idx] = 0;
     } else {
       if (seq_lens_this_time[thread_idx] + seq_lens_decoder[thread_idx] >= prompt_lens[thread_idx]) {
-        // decoding
-        seq_lens_decoder[thread_idx] += seq_lens_this_time[thread_idx];
-        seq_lens_this_time[thread_idx] = 1;
-        seq_lens_encoder[thread_idx] = 0;
-        int64_t *input_ids_now = input_ids + thread_idx * input_ids_stride;
-        input_ids_now[0] = next_tokens[thread_idx];
+        if (prefill_one_step_stop) {
+          // prefill done, stop
+          stop_flags[thread_idx] = true;
+          seq_lens_this_time[thread_idx] = 0;
+          seq_lens_decoder[thread_idx] = 0;
+          seq_lens_encoder[thread_idx] = 0;
+          stop_flag_now_int = 1;
+        } else {
+          // decoding
+          seq_lens_decoder[thread_idx] += seq_lens_this_time[thread_idx];
+          seq_lens_this_time[thread_idx] = 1;
+          seq_lens_encoder[thread_idx] = 0;
+          int64_t *input_ids_now = input_ids + thread_idx * input_ids_stride;
+          input_ids_now[0] = next_tokens[thread_idx];
+
         // to judge whether block is not enough
         int *block_table_now = block_tables + thread_idx * block_num_per_seq;
         if (seq_lens_this_time[thread_idx] != 0 && block_table_now[seq_lens_decoder[thread_idx] / block_size] == -1) {
           // should be scheduled by server
           is_block_step[thread_idx] = true;
           seq_lens_this_time[thread_idx] = 0;
           stop_flags[thread_idx] = true;
           step_seq_lens_decoder[thread_idx] = seq_lens_decoder[thread_idx];
           seq_lens_decoder[thread_idx] = 0;
           stop_flag_now_int = 1;
         }
+        }
       }
     } else
     {
@@ -110,6 +120,12 @@ void UpdateInputesV1(const paddle::Tensor &stop_flags,
 #else
   auto cu_stream = input_ids.stream();
 #endif
+  bool prefill_one_step_stop = false;
+  if (const char *env_p = std::getenv("PREFILL_NODE_ONE_STEP_STOP_V1")) {
+    if (env_p[0] == '1') {
+      prefill_one_step_stop = true;
+    }
+  }
   const int max_bsz = stop_flags.shape()[0];
   const int now_bsz = seq_lens_this_time.shape()[0];
   const int input_ids_stride = input_ids.shape()[1];
@@ -133,7 +149,8 @@ void UpdateInputesV1(const paddle::Tensor &stop_flags,
       max_bsz,
       input_ids_stride,
       block_num_per_seq,
-      block_size);
+      block_size,
+      prefill_one_step_stop);
   auto not_need_stop_cpu =
       not_need_stop_gpu.copy_to(not_need_stop.place(), false);
   bool *not_need_stop_data = const_cast<bool *>(not_need_stop.data<bool>());
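Note: with the env-var gate added above, a prefill-only node can stop each request right after prefill instead of entering decode. A minimal Python sketch of the same per-slot update, assuming plain lists in place of the device tensors (names are illustrative, not part of the diff):

    import os

    def update_slot_after_prefill(slot, stop_flags, seq_lens_this_time, seq_lens_decoder, seq_lens_encoder):
        """Mirror the prefill_one_step_stop branch of update_inputs_kernel_v1."""
        # Host side: the flag is on when PREFILL_NODE_ONE_STEP_STOP_V1 starts with '1'.
        if os.getenv("PREFILL_NODE_ONE_STEP_STOP_V1", "0").startswith("1"):
            stop_flags[slot] = True          # prefill done, stop this slot
            seq_lens_this_time[slot] = 0
            seq_lens_decoder[slot] = 0
            seq_lens_encoder[slot] = 0
            return 1                         # stop_flag_now_int
        return 0                             # otherwise fall through to the normal decoding update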
@@ -14,7 +14,10 @@
 # limitations under the License.
 """

+import argparse
+import json
 import math
+import queue
 import threading
 import time
 import traceback
@@ -23,16 +26,72 @@ import numpy as np
 import paddle

 from fastdeploy.cache_manager.transfer_factory import IPCCommManager, RDMACommManager
+from fastdeploy.config import SpeculativeConfig
 from fastdeploy.inter_communicator import (
     EngineWorkerQueue,
     IPCSignal,
     shared_memory_exists,
 )
-from fastdeploy.utils import get_logger
+from fastdeploy.model_executor.ops.gpu import get_output_kv_signal, set_data_ipc
+from fastdeploy.utils import envs, get_logger

 logger = get_logger("cache_messager", "cache_messager.log")


+def parse_args():
+    """
+    Parse arguments from the command line.
+    """
+    parser = argparse.ArgumentParser("Cache Messager")
+    parser.add_argument(
+        "--splitwise_role",
+        type=str,
+        default="mixed",
+        help="splitwise role, can be decode, prefill or mixed",
+    )
+    parser.add_argument("--rank", type=int, default=0, help="current rank")
+    parser.add_argument("--device_id", type=int, default=0, help="device id")
+    parser.add_argument("--num_layers", type=int, default=1, help="model num layers")
+    parser.add_argument("--head_dim", type=int, default=1, help="model head dim")
+    parser.add_argument("--kv_num_head", type=int, default=1, help="model kv num head")
+    parser.add_argument("--rdma_port", type=str, default="", help="rdma port")
+    parser.add_argument("--mp_num", type=int, default=1, help="number of model parallel")
+    parser.add_argument("--engine_pid", type=str, default=None, help="engine pid")
+    parser.add_argument(
+        "--protocol",
+        type=str,
+        default="ipc",
+        help="cache transfer protocol, only support ipc now",
+    )
+    parser.add_argument("--pod_ip", type=str, default="0.0.0.0", help="pod ip")
+    parser.add_argument("--cache_queue_port", type=int, default=9924, help="cache queue port")
+    parser.add_argument(
+        "--engine_worker_queue_port",
+        type=int,
+        default=9923,
+        help="engine worker queue port",
+    )
+    parser.add_argument("--num_gpu_blocks", type=int, default=1, help="gpu cache block number")
+    parser.add_argument("--block_size", type=int, default=64, help="cache block size(tokens)")
+    parser.add_argument(
+        "--cache_dtype",
+        type=str,
+        default="bfloat16",
+        choices=["uint8", "bfloat16"],
+        help="cache dtype",
+    )
+    parser.add_argument(
+        "--speculative_config",
+        type=json.loads,
+        default="{}",
+        help="speculative config",
+    )
+    parser.add_argument("--local_data_parallel_id", type=int, default=0)
+
+    args = parser.parse_args()
+    return args
+
+
 class CacheMessager:
     """
     CacheMessager is used to send the cache data between the engine worker and the cache server.
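For orientation, the new parse_args() mirrors the flags that launch_cache_messager (added later in this PR) puts on the command line. A rough way to exercise it in isolation; all values below are made up:

    import sys

    # Hypothetical values; in production they come from PrefixCacheManager.launch_cache_messager.
    sys.argv = [
        "cache_messager.py",
        "--splitwise_role", "prefill",
        "--rank", "0",
        "--device_id", "3",
        "--num_layers", "32",
        "--head_dim", "128",
        "--kv_num_head", "8",
        "--mp_num", "4",
        "--protocol", "ipc,rdma",
        "--cache_dtype", "bfloat16",
        "--speculative_config", "{}",
    ]
    args = parse_args()          # parse_args as defined in the hunk above
    print(args.splitwise_role, args.mp_num, args.cache_dtype)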
@@ -69,11 +128,6 @@ class CacheMessager:
         Returns:
             None
         """

-        assert splitwise_role in [
-            "prefill",
-            "decode",
-        ], "splitwise_role must be prefill or decode"
         self.splitwise_role = splitwise_role
         self.gpu_cache_kvs = gpu_cache_kvs
         self.rank = rank
@@ -147,15 +201,16 @@ class CacheMessager:

         self.gpu_id = gpu_id
         self.cache_info = dict()
-        self.dp_rank_id = self.rank + local_data_parallel_id * self.nranks
+        self.rank_id = self.rank + local_data_parallel_id * self.nranks

-        layerwise_send_cache_thread = threading.Thread(target=self._prefill_layerwise_send_cache_thread)
-        layerwise_send_cache_thread.daemon = True
-        layerwise_send_cache_thread.start()
+        if self.splitwise_role != "mixed":
+            connect_rdma_thread = threading.Thread(target=self._handle_connect_task)
+            connect_rdma_thread.daemon = True
+            connect_rdma_thread.start()

         logger.info(f"cache messager init finished, use {transfer_protocol}")

-    def _prefill_layerwise_send_cache_thread(self):
+    def prefill_layerwise_send_cache_thread(self):
         """
         layerwise_send_cache_thread:
             send cache to other instance
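Note on the rename from dp_rank_id to rank_id: the value is still rank + local_data_parallel_id * nranks, which gives each (data-parallel replica, tensor-parallel rank) pair a unique signal suffix. A quick worked example:

    def rank_id(rank, local_data_parallel_id, nranks):
        # One unique id per TP rank within each local data-parallel replica.
        return rank + local_data_parallel_id * nranks

    # nranks = 4 TP ranks per replica:
    assert rank_id(0, 0, 4) == 0   # replica 0, rank 0
    assert rank_id(3, 0, 4) == 3   # replica 0, rank 3
    assert rank_id(0, 1, 4) == 4   # replica 1, rank 0
    assert rank_id(2, 2, 4) == 10  # replica 2, rank 2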
@@ -163,23 +218,23 @@ class CacheMessager:
         try:
             prefilled_step_idx_data = np.zeros(shape=[1], dtype=np.int32)
             prefilled_layer_idx_data = np.zeros(shape=[1], dtype=np.int32)
-            prefilled_layer_name = f"splitwise_complete_prefilled_layer_{self.dp_rank_id}.{self.gpu_id}"
-            prefilled_step_name = f"splitwise_complete_prefilled_step_{self.dp_rank_id}.{self.gpu_id}"
+            prefilled_layer_name = f"splitwise_complete_prefilled_step_{self.rank_id}.{self.gpu_id}"
+            prefilled_step_name = f"splitwise_complete_prefilled_step_{self.rank_id}.{self.gpu_id}"
             step_shm_value = IPCSignal(
-                name=f"splitwise_complete_prefilled_step_{self.dp_rank_id}",
+                name=f"splitwise_complete_prefilled_step_{self.rank_id}",
                 array=prefilled_step_idx_data,
                 dtype=np.int32,
                 suffix=self.gpu_id,
                 create=not shared_memory_exists(prefilled_step_name),
             )
             layer_shm_value = IPCSignal(
-                name=f"splitwise_complete_prefilled_layer_{self.dp_rank_id}",
+                name=f"splitwise_complete_prefilled_layer_{self.rank_id}",
                 array=prefilled_layer_idx_data,
                 dtype=np.int32,
                 suffix=self.gpu_id,
                 create=not shared_memory_exists(prefilled_layer_name),
             )
-            logger.info(f"splitwise_complete_prefilled_step_{self.dp_rank_id}, gpu_id: {self.gpu_id}")
+            logger.info(f"splitwise_complete_prefilled_step_{self.rank_id}, gpu_id: {self.gpu_id}")

             step_shm_value.value[0] = -1
             layer_shm_value.value[0] = -1
@@ -187,6 +242,9 @@ class CacheMessager:
             self.last_step_idx = -1
             self.last_layer_idx = -1  # int32

+            max_step_idx = 100003
+            engine_recycled_count = 0
+
             while True:

                 cache_info = self.engine_worker_queue.get_cache_info()
@@ -202,11 +260,9 @@ class CacheMessager:
                             -len(current_info["dest_block_ids"]) :
                         ]
                         current_info["src_block_ids"] = current_src_blocks
-                        current_info["current_layer_ids"] = 0
                         current_info["status"] = "init"
                         logger.info(f"start cache_infos: {current_info}")
                         self.cache_info[info["request_id"]] = current_info
-                        self.last_step_idx = min(self.last_step_idx, current_info["current_id"])
                     else:
                         self.cache_info[info["request_id"]] = info
                 prefilled_layer_idx = layer_shm_value.value[0]
@@ -223,7 +279,18 @@ class CacheMessager:
                 if not self.cache_info:
                     time.sleep(0.001)
                     continue
-                logger.debug(f"prefilled_layer_idx: {prefilled_layer_idx}, prefilled_step_idx: {prefilled_step_idx}")
+                if self.last_step_idx > prefilled_step_idx:
+                    engine_recycled_count += 1
+                self.last_step_idx = prefilled_step_idx  # only copy value read from shm memory
+                prefilled_step_idx = (
+                    prefilled_step_idx + max_step_idx * engine_recycled_count
+                )  # remap prefilled_step_idx for comparison
+
+                logger.debug(
+                    f"prefilled_layer_idx: {prefilled_layer_idx}, prefilled_step_idx in shm: {self.last_step_idx},"
+                    f"prefilled_step_idx: {prefilled_step_idx} engine_recycled_count {engine_recycled_count}"
+                )
                 for req_id, item in list(self.cache_info.items()):
                     if "status" not in item:
                         continue
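The added max_step_idx / engine_recycled_count bookkeeping remaps the step counter read from shared memory (which the engine wraps back to 0 around 100003) onto a monotonically increasing value, so comparisons against last_step_idx keep working across wrap-arounds. A small sketch of the same remapping, assuming a plain integer counter:

    MAX_STEP_IDX = 100003  # wrap-around period used in this PR

    def remap_step(raw_step, last_raw_step, recycled_count):
        """Return (monotonic_step, new_recycled_count)."""
        if last_raw_step > raw_step:      # the shared-memory counter wrapped around
            recycled_count += 1
        return raw_step + MAX_STEP_IDX * recycled_count, recycled_count

    steps = [100001, 100002, 5, 6]        # raw values read from shm
    last, recycled, out = -1, 0, []
    for s in steps:
        mono, recycled = remap_step(s, last, recycled)
        last = s
        out.append(mono)
    assert out == [100001, 100002, 100008, 100009]  # strictly increasing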
@@ -294,12 +361,493 @@ class CacheMessager:
                             logger.info(f"finish write cache {item['request_id']}")
                             self.engine_worker_queue.finish_request_barrier.wait()
                             if self.rank == 0:
+                                # to do: robust in TP: here we assume all status in tp are the same. If wrong, all wrong. If ok, all ok.
                                 self.engine_worker_queue.put_finished_req([(item["request_id"], "finished")])
                                 logger.info(f"put write cache {item['request_id']}")
                             del self.cache_info[req_id]
-                self.last_step_idx = prefilled_step_idx
-                self.last_layer_idx = prefilled_layer_idx
+                    self.last_layer_idx = prefilled_layer_idx

         except Exception as e:
             logger.error(f"prefill layerwise send cache thread has exception: {e}, {str(traceback.format_exc())}")

+    def _handle_connect_task(self):
+        while True:
+            try:
+                task = self.engine_worker_queue.get_connect_rdma_task()
+                if task is None:
+                    time.sleep(0.001)
+                    continue
+                logger.info(f"_handle_connect_task recv task: {task}")
+                task_id = task["task_id"]
+                ip, rdma_port = task["ip"], task["rdma_port"]
+                status = self.messager["rdma"].connect(ip, rdma_port)
+                if not status:
+                    response = {"task_id": task_id, "success": False}
+                else:
+                    response = {"task_id": task_id, "success": True}
+                self.engine_worker_queue.put_connect_rdma_task_response(response)
+            except Exception as e:
+                logger.error(f"handle_connect_task has exception: {e}")
+
+
+class CacheMessagerV1:
+    """
+    CacheMessager is used to send the cache data between the engine worker and the cache server.
+    """
+
+    def __init__(
+        self,
+        splitwise_role,
+        transfer_protocol,
+        pod_ip,
+        engine_worker_queue_port,
+        local_data_parallel_id,
+        gpu_cache_kvs,
+        rank,
+        nranks,
+        num_layers,
+        gpu_id=0,
+        block_size=64,
+        rdma_port=None,
+    ):
+        """
+        Initialize the CacheMessager object.
+
+        Args:
+            splitwise_role (str): splitwise_role only can be 'prefill' or 'decode'.
+            transfer_protocol (str): support ipc and rdma
+            engine_worker_queue_port (int): engine_worker_queue port
+            gpu_cache_kvs (dict): GPU kv cache
+            rank (int): current rank
+            nranks (int): global rank number
+            num_layers (int): model layer number
+            gpu_id (int, optional): GPU ID
+            rdma_port (int, optional): RDMA port
+
+        Returns:
+            None
+        """
+        self.splitwise_role = splitwise_role
+        self.gpu_cache_kvs = gpu_cache_kvs
+        self.rank = rank
+        self.nranks = nranks
+        address = (pod_ip, engine_worker_queue_port)
+        self.engine_worker_queue = EngineWorkerQueue(
+            address=address,
+            is_server=False,
+            num_client=self.nranks,
+            client_id=self.rank,
+            local_data_parallel_id=local_data_parallel_id,
+        )
+        self.block_size = block_size
+        transfer_protocol = transfer_protocol.split(",")
+
+        logger.info(f"splitwise role: {splitwise_role}, {transfer_protocol}" f"rank: {rank}")
+
+        # 1. initialize the cache_k_ptr_list and cache_v_ptr_list
+        self.num_layers = num_layers
+        cache_k_ptr_list = []
+        cache_v_ptr_list = []
+        cache_k = []
+        cache_v = []
+        self.messager = {}
+        for layer_idx in range(self.num_layers):
+            key_cache = self.gpu_cache_kvs[f"key_caches_{layer_idx}_rank{self.rank}_device{gpu_id}"]
+            val_cache = self.gpu_cache_kvs[f"value_caches_{layer_idx}_rank{self.rank}_device{gpu_id}"]
+            cache_k.append(key_cache)
+            cache_v.append(val_cache)
+            cache_k_ptr_list.append(key_cache.data_ptr())
+            cache_v_ptr_list.append(val_cache.data_ptr())
+        cache_k_ptr_list = np.array(cache_k_ptr_list)
+        cache_v_ptr_list = np.array(cache_v_ptr_list)
+
+        # 2. initialize the block_bytes
+        cache_shape = key_cache.shape
+        max_block_num = cache_shape[0]
+        block_bytes = math.prod(cache_shape[1:])
+        if key_cache.dtype == paddle.bfloat16:
+            block_bytes *= 2
+        logger.info(
+            f"layers {num_layers} cache_shape: {cache_shape}, max_block_num: {max_block_num}, "
+            f"block_bytes: {block_bytes}, dtype: {key_cache.dtype}"
+        )
+        self.block_bytes = block_bytes
+
+        # 3. initialize the messager
+        for protocol in transfer_protocol:
+            if protocol == "ipc":
+                self.messager[protocol] = IPCCommManager(
+                    self.rank,
+                    gpu_id,
+                    cache_k,
+                    cache_v,
+                )
+                local_device_id = int(str(cache_k[0].place)[-2])
+                logger.info(f"done create ipc_comm with local_device_id:{local_device_id}, ")
+
+            elif protocol == "rdma":
+                logger.info(f"splitwise_role rdma: {self.splitwise_role}, rank: {self.rank}, gpu_id: {gpu_id}")
+
+                self.messager[protocol] = RDMACommManager(
+                    splitwise_role,
+                    rank,
+                    gpu_id,
+                    cache_k_ptr_list,
+                    cache_v_ptr_list,
+                    max_block_num,
+                    block_bytes,
+                    rdma_port,
+                )
+
+        self.gpu_id = gpu_id
+        self.cache_info = dict()
+        self.rank_id = self.rank + local_data_parallel_id * self.nranks
+        self.engine_cache_task_thread_lock = threading.Lock()
+        self.engine_cache_tasks = [dict() for _ in range(512)]
+        self.idx_cache_task_dict = {}
+        self.cache_prefilled_engine_ids_queue = queue.Queue()  # keep batch slot index for each prefill step
+        if splitwise_role == "prefill":
+            consume_signals_thread = threading.Thread(target=self.consume_signals)
+            consume_signals_thread.daemon = True
+            consume_signals_thread.start()
+            add_cache_task_thread = threading.Thread(target=self._add_cache_task_thread)
+            add_cache_task_thread.daemon = True
+            add_cache_task_thread.start()
+
+        if self.splitwise_role != "mixed":
+            connect_rdma_thread = threading.Thread(target=self._handle_connect_task)
+            connect_rdma_thread.daemon = True
+            connect_rdma_thread.start()
+
+        logger.info(f"cache messager init finished, use {transfer_protocol}")
+
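Worked example of the block_bytes computation in CacheMessagerV1.__init__ above: the per-block byte count is the product of the non-block cache dimensions, doubled for bfloat16 (2 bytes per element). Values below are illustrative only:

    import math

    cache_shape = [2048, 8, 64, 128]          # [max_block_num, kv_num_head, block_size, head_dim]
    block_bytes = math.prod(cache_shape[1:])  # 8 * 64 * 128 = 65536 elements per block
    dtype_is_bf16 = True
    if dtype_is_bf16:
        block_bytes *= 2                      # 2 bytes per bfloat16 element
    assert block_bytes == 131072              # 128 KiB per KV block (per layer, per K or V tensor)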
+    def _add_cache_task_thread(self):
+        while True:
+            try:
+                cache_info = self.engine_worker_queue.get_cache_info()
+                self.engine_worker_queue.finish_add_cache_task_barrier.wait()
+                finished_add_cache_task_req_ids = []
+                if cache_info:
+                    for info in cache_info:
+                        if info["request_id"] in self.cache_info:
+                            self.cache_info[info["request_id"]].update(info)
+                            current_info = self.cache_info[info["request_id"]]
+                            assert "dest_block_ids" in current_info and "src_block_ids" in current_info
+                            finished_add_cache_task_req_ids.append(info["request_id"])
+                            decode_cached_block_num = len(current_info["src_block_ids"]) - len(
+                                current_info["dest_block_ids"]
+                            )
+                            padding_decode_block_ids = [-1 for i in range(decode_cached_block_num)] + current_info[
+                                "dest_block_ids"
+                            ]
+                            current_info["dest_block_ids"] = padding_decode_block_ids
+                            current_info["decode_cached_tokens"] = decode_cached_block_num * self.block_size
+                            current_info["sended_layer_id"] = -1
+                            current_info["sended_block_num"] = current_info["decode_cached_tokens"] // self.block_size
+                            current_info["status"] = "init"
+                            logger.info(f"finish add cache task: {current_info}")
+                            self.cache_info[info["request_id"]] = current_info
+                            self.idx_cache_task_dict[current_info["current_id"]] = current_info
+                        else:
+                            self.cache_info[info["request_id"]] = info
+                    if self.rank == 0 and finished_add_cache_task_req_ids:
+                        self.engine_worker_queue.put_finished_add_cache_task_req(finished_add_cache_task_req_ids)
+                else:
+                    time.sleep(0.001)
+            except Exception as e:
+                logger.info(f"add cache task occurred error: {e}, {traceback.format_exc()!s}.")
+
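The _add_cache_task_thread above left-pads the decode side's dest_block_ids with -1 so that source and destination block lists line up index-for-index when the decode instance already holds a prefix of the blocks. A minimal sketch of the same alignment, with made-up block ids:

    block_size = 64
    src_block_ids = [10, 11, 12, 13, 14]   # prefill-side blocks for the request
    dest_block_ids = [70, 71, 72]          # decode side only needs the last 3 blocks

    decode_cached_block_num = len(src_block_ids) - len(dest_block_ids)   # 2 blocks already cached
    padded_dest = [-1] * decode_cached_block_num + dest_block_ids        # [-1, -1, 70, 71, 72]
    decode_cached_tokens = decode_cached_block_num * block_size          # 128 tokens to skip
    sended_block_num = decode_cached_tokens // block_size                # start sending from block index 2

    assert padded_dest == [-1, -1, 70, 71, 72]
    assert (decode_cached_tokens, sended_block_num) == (128, 2)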
+    def prefill_layerwise_send_cache_thread(self):
+        """
+        layerwise_send_cache_thread:
+            send cache to other instance
+        """
+        while True:
+            try:
+                engine_indexes = self.cache_prefilled_engine_ids_queue.get()
+                self.engine_worker_queue.finish_request_barrier.wait()
+                block_start_end_list = []
+                current_prefilled_token_num_list = []
+                for engine_index in engine_indexes:
+                    assert engine_index in self.idx_cache_task_dict
+                    block_id_start = self.idx_cache_task_dict[engine_index]["sended_block_num"]
+                    prefilled_token_num = self.engine_cache_tasks[engine_index]["prefilled_token_num"]
+                    if (
+                        prefilled_token_num == self.idx_cache_task_dict[engine_index]["need_prefill_tokens"]
+                    ):  # all chunks have been prefilled
+                        block_id_end = len(self.idx_cache_task_dict[engine_index]["src_block_ids"])
+                    else:
+                        block_id_end = prefilled_token_num // self.block_size  # [block_id_start, block_id_end)
+                    block_start_end_list.append((block_id_start, block_id_end))
+                    current_prefilled_token_num_list.append(prefilled_token_num)
+                while True:  # from layer0 to last layer
+                    sended_layer_idx = self.idx_cache_task_dict[engine_indexes[0]]["sended_layer_id"]
+                    start_layer_idx = sended_layer_idx + 1
+                    with self.engine_cache_task_thread_lock:  # to check end_layer_idx
+                        prefilled_layer_idx = self.engine_cache_tasks[engine_indexes[0]]["prefilled_layer_idx"]
+                        if sended_layer_idx > prefilled_layer_idx:  # computation must in next chunk
+                            logger.info(
+                                f"current_prefilled_token_num_list[0] {current_prefilled_token_num_list[0]} prefilled_token_num {self.engine_cache_tasks[engine_indexes[0]]['prefilled_token_num']}"
+                            )
+                            assert (
+                                current_prefilled_token_num_list[0]
+                                < self.engine_cache_tasks[engine_indexes[0]]["prefilled_token_num"]
+                            ), "when sended_layer_idx > prefilled_layer_idx, must be in next chunk, but not, sth wrong"
+                            end_layer_idx = self.num_layers - 1  # [start_layer_idx, end_layer_idx)
+                        else:
+                            end_layer_idx = prefilled_layer_idx
+                        if sended_layer_idx == prefilled_layer_idx:  # computation not in next layer
+                            time.sleep(0.01)
+                    for layer_idx in range(start_layer_idx, end_layer_idx + 1):
+                        for i, (block_id_start, block_id_end) in enumerate(block_start_end_list):
+                            engine_index = engine_indexes[i]
+                            task = self.idx_cache_task_dict[engine_index]
+                            req_id = task["request_id"]
+                            if (
+                                block_id_start >= block_id_end
+                            ):  # no blocks need to transfer for this request in this chunk
+                                task["sended_layer_id"] += 1
+                                assert task["sended_layer_id"] == layer_idx
+                                if task["sended_layer_id"] == self.num_layers - 1:
+                                    task["sended_layer_id"] = -1
+                                continue
+                            else:
+                                current_transfer_protocol = task["transfer_protocol"]
+                                if task["transfer_protocol"] == "rdma":
+                                    target_ip = task["ip"]
+                                    target_id = int(task["rdma_ports"][self.rank])
+                                    if task["status"] == "error":
+                                        continue
+                                    status = self.messager[current_transfer_protocol].connect(target_ip, target_id)
+                                    if not status:
+                                        logger.error(f"connect to {target_ip}:{target_id} failed")
+                                        task["status"] = "connection error"
+                                        continue
+                                elif task["transfer_protocol"] == "ipc":
+                                    target_ip = "0.0.0.0"
+                                    target_id = int(task["device_ids"][self.rank])
+
+                                src_block_ids = task["src_block_ids"][block_id_start:block_id_end]
+                                dest_block_ids = task["dest_block_ids"][block_id_start:block_id_end]
+                                src_block_ids = paddle.to_tensor(src_block_ids, dtype="int32", place="cpu")
+                                dest_block_ids = paddle.to_tensor(dest_block_ids, dtype="int32", place="cpu")
+
+                                logger.info(
+                                    f"start write cache for a layer, {req_id}, {layer_idx}, {target_ip}, {target_id}, block_id_start {block_id_start} block_id_end {block_id_end}"
+                                )
+                                tic = time.time()
+                                return_code = self.messager[current_transfer_protocol].write_cache(
+                                    target_ip,
+                                    target_id,
+                                    src_block_ids,
+                                    dest_block_ids,
+                                    layer_idx,
+                                )
+                                if return_code != 0:
+                                    task["status"] = "write cache error"
+                                    logger.error(
+                                        f"write cache failed, layer_idx: {layer_idx}, req_id: {req_id}, dest_ip: {target_ip}, block_id_start {block_id_start} block_id_end {block_id_end}"
+                                    )
+                                tok = time.time()
+                                cost_time = tok - tic
+                                block_num = len(src_block_ids)
+                                avg_time_per_block = cost_time * 1000 / block_num  # ms
+                                send_cache_speed = block_num * self.block_bytes / 1073741824 / cost_time  # GB/s
+                                logger.debug(
+                                    f"finish write cache for a layer, {req_id}, {layer_idx}, {target_ip}, {target_id},"
+                                    f"block_num: {block_num}, send_cache_speed(GB/s): {round(send_cache_speed, 5)},"
+                                    f"avg_time per block(ms): {round(avg_time_per_block, 5)} block_id_start {block_id_start} block_id_end {block_id_end}"
+                                )
+
+                                task["sended_layer_id"] += 1
+                                assert task["sended_layer_id"] == layer_idx
+                                if task["sended_layer_id"] == self.num_layers - 1:
+                                    self.idx_cache_task_dict[engine_index]["sended_block_num"] += (
+                                        block_id_end - block_id_start
+                                    )
+                                    if current_prefilled_token_num_list[i] == task["need_prefill_tokens"]:
+                                        if task["status"] != "error":
+                                            task["status"] = "finished"
+                                        logger.info(
+                                            f"finish write cache for all layers, req_id: {req_id}, block_id_end {block_id_end} need_prefill_tokens {task['need_prefill_tokens']}"
+                                        )
+                                    else:
+                                        task["sended_layer_id"] = -1
+                    if end_layer_idx == self.num_layers - 1:
+                        with self.engine_cache_task_thread_lock:
+                            for engine_idx in engine_indexes:
+                                task = self.idx_cache_task_dict[engine_idx]
+                                if task["status"] == "finished" or ("error" in task["status"]):
+                                    target_id = int(task["rdma_ports"][self.rank])
+                                    if task["transfer_protocol"] == "ipc":
+                                        self.messager["ipc"].write_block_by_sync(target_id)
+                                    if self.rank == 0:
+                                        # to do: robust in TP, here we assume all status in tp are the same. If wrong, all wrong. If ok, all ok.
+                                        self.engine_worker_queue.put_finished_req(
+                                            [(task["request_id"], task["status"])]
+                                        )
+                                        logger.info(f"put write cache {task['request_id']}, status {task['status']}")
+                                    self.engine_cache_tasks[task["current_id"]] = dict()
+                                    del self.cache_info[task["request_id"]]
+                                    del self.idx_cache_task_dict[task["current_id"]]
+                        break
+            except Exception as e:
+                logger.error(f"prefill layerwise send cache thread has exception: {e} {traceback.format_exc()!s}")
+                time.sleep(0.01)
+
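Two bits of arithmetic from prefill_layerwise_send_cache_thread above, sketched with illustrative numbers: the half-open block range sent for one chunk, and the reported transfer speed.

    block_size, block_bytes = 64, 131072   # illustrative; block_bytes as in the earlier example

    # Block range for one chunk: [block_id_start, block_id_end)
    sended_block_num = 2                    # blocks already sent (or decode-cached)
    prefilled_token_num = 320               # tokens prefilled so far for this request
    need_prefill_tokens = 500               # not finished yet, so only full blocks are sent
    block_id_start = sended_block_num
    block_id_end = prefilled_token_num // block_size          # 320 // 64 = 5
    assert (block_id_start, block_id_end) == (2, 5)           # blocks 2, 3, 4 go out this round

    # Reported speed for one layer's transfer (cost_time is made up)
    block_num, cost_time = block_id_end - block_id_start, 0.004
    send_cache_speed = block_num * block_bytes / 1073741824 / cost_time   # GiB/s
    avg_time_per_block = cost_time * 1000 / block_num                     # ms per block
    print(round(send_cache_speed, 5), round(avg_time_per_block, 5))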
+    def consume_signals(self):
+        paddle.device.set_device("cpu")
+        kv_signal_data = paddle.full(shape=[512 * 3 + 2], fill_value=-1, dtype="int32")
+        while True:
+            try:
+                get_output_kv_signal(kv_signal_data, self.rank_id, 0)  # wait_flag
+                if not self.cache_info:
+                    time.sleep(0.01)
+                    continue
+                tasks_count = kv_signal_data[0]
+                if tasks_count == -1:
+                    time.sleep(0.001)
+                    continue
+                layer_id = kv_signal_data[1].numpy().tolist()
+                if layer_id == self.num_layers - 1:
+                    logger.info(f"tasks_count: {tasks_count}, layer_id: {layer_id}")
+                batch_engine_ids = []
+                with self.engine_cache_task_thread_lock:
+                    for bi in range(tasks_count):
+                        engine_idx = kv_signal_data[3 * bi + 2].numpy().tolist()
+                        chuck_token_offset = kv_signal_data[3 * bi + 3].numpy().tolist()
+                        current_seq_len = kv_signal_data[3 * bi + 4].numpy().tolist()
+                        self.engine_cache_tasks[engine_idx]["prefilled_layer_idx"] = layer_id
+                        self.engine_cache_tasks[engine_idx]["prefilled_token_num"] = (
+                            chuck_token_offset + current_seq_len
+                        )
+                        batch_engine_ids.append(engine_idx)
+                if layer_id == 0:
+                    self.cache_prefilled_engine_ids_queue.put(batch_engine_ids)
+            except Exception as e:
+                logger.error(f"Consume signals get exception: {e}")
+
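Layout of the kv_signal_data buffer consumed above: index 0 is the task count, index 1 the layer id, followed by one (engine_idx, chunk_token_offset, current_seq_len) triple per task. A plain-list sketch of the same decoding (no Paddle tensors):

    def decode_kv_signal(buf):
        """buf mirrors kv_signal_data: [tasks_count, layer_id, (engine_idx, offset, seq_len) * tasks_count]."""
        tasks_count, layer_id = buf[0], buf[1]
        if tasks_count == -1:
            return None
        tasks = []
        for bi in range(tasks_count):
            engine_idx = buf[3 * bi + 2]
            chunk_token_offset = buf[3 * bi + 3]
            current_seq_len = buf[3 * bi + 4]
            tasks.append((engine_idx, layer_id, chunk_token_offset + current_seq_len))
        return tasks

    # Two requests finished layer 0 in this step: engine slots 7 and 9.
    assert decode_kv_signal([2, 0, 7, 0, 256, 9, 128, 64]) == [(7, 0, 256), (9, 0, 192)]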
+    def _handle_connect_task(self):
+        while True:
+            try:
+                task = self.engine_worker_queue.get_connect_rdma_task()
+                if task is None:
+                    time.sleep(0.001)
+                    continue
+                logger.info(f"_handle_connect_task recv task: {task}")
+                task_id = task["task_id"]
+                ip, rdma_port = task["ip"], task["rdma_port"]
+                status = self.messager["rdma"].connect(ip, rdma_port)
+                if not status:
+                    response = {"task_id": task_id, "success": False}
+                else:
+                    response = {"task_id": task_id, "success": True}
+                self.engine_worker_queue.put_connect_rdma_task_response(response)
+            except Exception as e:
+                logger.error(f"handle_connect_task has exception: {e}")
+
+
+def main():
+    device = args.device_id
+    rank = args.rank
+    paddle.set_device(f"gpu:{device}")
+    cache_type = args.cache_dtype
+    speculative_config = SpeculativeConfig(args.speculative_config)
+    num_extra_layers = speculative_config.num_extra_cache_layer
+    num_extra_layer_gpu_blocks = int(args.num_gpu_blocks * speculative_config.num_gpu_block_expand_ratio)
+    gpu_cache_kvs = {}
+    gpu_cache_k_tensors = []
+    gpu_cache_v_tensors = []
+
+    for i in range(args.num_layers + num_extra_layers):
+        num_gpu_blocks = args.num_gpu_blocks if i < args.num_layers else num_extra_layer_gpu_blocks
+
+        gpu_cache_kvs[f"key_caches_{i}_rank{rank}_device{device}"] = paddle.full(
+            shape=[
+                num_gpu_blocks,
+                args.kv_num_head,
+                args.block_size,
+                args.head_dim,
+            ],
+            fill_value=0,
+            dtype=cache_type,
+        )
+        gpu_cache_k_tensors.append(gpu_cache_kvs[f"key_caches_{i}_rank{rank}_device{device}"])
+        gpu_cache_kvs[f"value_caches_{i}_rank{rank}_device{device}"] = paddle.full(
+            shape=[
+                num_gpu_blocks,
+                args.kv_num_head,
+                args.block_size,
+                args.head_dim,
+            ],
+            fill_value=0,
+            dtype=cache_type,
+        )
+        gpu_cache_v_tensors.append(gpu_cache_kvs[f"value_caches_{i}_rank{rank}_device{device}"])
+
+        set_data_ipc(
+            gpu_cache_kvs[f"key_caches_{i}_rank{rank}_device{device}"],
+            f"key_caches_{i}_rank{rank}.device{device}",
+        )
+        set_data_ipc(
+            gpu_cache_kvs[f"value_caches_{i}_rank{rank}_device{device}"],
+            f"value_caches_{i}_rank{rank}.device{device}",
+        )
+    cache_kv_size_byte = sum([tmp.numel() * 1 for key, tmp in gpu_cache_kvs.items()])
+    logger.info(f"device :{device}")
+    logger.info(f"cache_kv_size_byte : {cache_kv_size_byte}")
+    logger.info(f"done init cache (full) gmem alloc : {paddle.device.cuda.memory_allocated()}")
+
+    if envs.ENABLE_V1_KVCACHE_SCHEDULER:
+        cache_messager = CacheMessagerV1(
+            splitwise_role=args.splitwise_role,
+            transfer_protocol=args.protocol,
+            pod_ip=args.pod_ip,
+            engine_worker_queue_port=args.engine_worker_queue_port,
+            local_data_parallel_id=args.local_data_parallel_id,
+            gpu_cache_kvs=gpu_cache_kvs,
+            rank=rank,
+            nranks=args.mp_num,
+            num_layers=args.num_layers + num_extra_layers,
+            gpu_id=device,
+            rdma_port=args.rdma_port,
+        )
+    else:
+        cache_messager = CacheMessager(
+            splitwise_role=args.splitwise_role,
+            transfer_protocol=args.protocol,
+            pod_ip=args.pod_ip,
+            engine_worker_queue_port=args.engine_worker_queue_port,
+            local_data_parallel_id=args.local_data_parallel_id,
+            gpu_cache_kvs=gpu_cache_kvs,
+            rank=rank,
+            nranks=args.mp_num,
+            num_layers=args.num_layers + num_extra_layers,
+            gpu_id=device,
+            rdma_port=args.rdma_port,
+        )
+
+    cache_ready_signal_data = np.zeros(shape=[args.mp_num], dtype=np.int32)
+    cache_ready_signal = IPCSignal(
+        name="cache_ready_signal",
+        array=cache_ready_signal_data,
+        dtype=np.int32,
+        suffix=args.engine_pid,
+        create=False,
+    )
+    cache_ready_signal.value[rank] = 1
+    if args.splitwise_role == "mixed":
+        while True:
+            time.sleep(1)
+    cache_messager.prefill_layerwise_send_cache_thread()
+
+
+if __name__ == "__main__":
+
+    args = parse_args()
+    rank_id = args.rank + args.local_data_parallel_id * args.mp_num
+    logger = get_logger("cache_messager", f"cache_messager_rank{rank_id}.log")
+
+    logger.info("create cache messager...")
+    logger.info(f"{args}")
+    main()
@@ -29,7 +29,7 @@ from fastdeploy.config import SpeculativeConfig
 from fastdeploy.inter_communicator import EngineCacheQueue, IPCSignal
 from fastdeploy.model_executor.ops.gpu import (
     cuda_host_alloc,
-    set_data_ipc,
+    share_external_data,
     swap_cache_all_layers,
 )
 from fastdeploy.utils import get_logger
@@ -139,40 +139,27 @@ class CacheTransferManager:
         self.num_cpu_blocks = args.num_cpu_blocks

         cache_type = args.cache_dtype
+        cache_shape = [
+            args.num_gpu_blocks,
+            args.kv_num_head,
+            args.block_size,
+            args.head_dim,
+        ]
+
         for i in range(args.num_layers + self.num_extra_layers):
             num_gpu_blocks = args.num_gpu_blocks if i < args.num_layers else self.num_extra_layer_gpu_blocks
+            cache_shape[0] = num_gpu_blocks
+            key_name = f"key_caches_{i}_rank{rank}.device{device}"
+            value_name = f"value_caches_{i}_rank{rank}.device{device}"
+            key_cache = paddle.empty(shape=[], dtype=cache_type)
+            value_cache = paddle.empty(shape=[], dtype=cache_type)
+            key_cache = share_external_data(key_cache, key_name, cache_shape)
+            value_cache = share_external_data(value_cache, value_name, cache_shape)
+            self.gpu_cache_kvs[key_name] = key_cache
+            self.gpu_cache_kvs[value_name] = value_cache
+            self.gpu_cache_k_tensors.append(self.gpu_cache_kvs[key_name])
+            self.gpu_cache_v_tensors.append(self.gpu_cache_kvs[value_name])

-            self.gpu_cache_kvs[f"key_caches_{i}_rank{rank}_device{device}"] = paddle.full(
-                shape=[
-                    num_gpu_blocks,
-                    args.kv_num_head,
-                    args.block_size,
-                    args.head_dim,
-                ],
-                fill_value=0,
-                dtype=cache_type,
-            )
-            self.gpu_cache_k_tensors.append(self.gpu_cache_kvs[f"key_caches_{i}_rank{rank}_device{device}"])
-            self.gpu_cache_kvs[f"value_caches_{i}_rank{rank}_device{device}"] = paddle.full(
-                shape=[
-                    num_gpu_blocks,
-                    args.kv_num_head,
-                    args.block_size,
-                    args.head_dim,
-                ],
-                fill_value=0,
-                dtype=cache_type,
-            )
-            self.gpu_cache_v_tensors.append(self.gpu_cache_kvs[f"value_caches_{i}_rank{rank}_device{device}"])
-
-            set_data_ipc(
-                self.gpu_cache_kvs[f"key_caches_{i}_rank{rank}_device{device}"],
-                f"key_caches_{i}_rank{rank}.device{device}",
-            )
-            set_data_ipc(
-                self.gpu_cache_kvs[f"value_caches_{i}_rank{rank}_device{device}"],
-                f"value_caches_{i}_rank{rank}.device{device}",
-            )
         cache_kv_size_byte = sum([tmp.numel() * 1 for key, tmp in self.gpu_cache_kvs.items()])
         logger.info(f"device :{self.device}")
         logger.info(f"cache_kv_size_byte : {cache_kv_size_byte}")
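With this change the transfer manager no longer allocates and IPC-exports the KV tensors itself; it attaches to the tensors the cache messager process already exported under the same names. A hedged sketch of the new attach path, assuming the exporter has already run (shape, dtype and name below are illustrative):

    import paddle
    from fastdeploy.model_executor.ops.gpu import share_external_data

    cache_shape = [2048, 8, 64, 128]          # [num_gpu_blocks, kv_num_head, block_size, head_dim]
    key_name = "key_caches_0_rank0.device0"   # must match the name passed to set_data_ipc in cache_messager.py

    # Attach to the already-exported GPU tensor instead of allocating a new one.
    key_cache = paddle.empty(shape=[], dtype="bfloat16")
    key_cache = share_external_data(key_cache, key_name, cache_shape)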
@@ -201,28 +188,6 @@ class CacheTransferManager:
         )
         self.cache_ready_signal.value[self.rank] = 1

-        paddle.set_device(f"gpu:{device}")
-        if args.enable_splitwise:
-            logger.debug("create cache messager...")
-            logger.info(f"{args}")
-            from fastdeploy.cache_manager.cache_messager import CacheMessager
-
-            self.cache_messager = CacheMessager(
-                splitwise_role=args.splitwise_role,
-                transfer_protocol=args.protocol,
-                pod_ip=args.pod_ip,
-                engine_worker_queue_port=args.engine_worker_queue_port,
-                local_data_parallel_id=args.local_data_parallel_id,
-                gpu_cache_kvs=self.gpu_cache_kvs,
-                rank=self.rank,
-                nranks=args.mp_num,
-                num_layers=args.num_layers + self.num_extra_layers,
-                gpu_id=self.device,
-                rdma_port=args.rdma_port,
-            )
-            logger.info("successfully create cache messager")
-            logger.info(f"done init CacheMessager gmem alloc : {paddle.device.cuda.memory_allocated()}")
-
         cache_task_broadcast_data = np.zeros(shape=[1], dtype=np.int32)
         self.cache_task_broadcast_signal = IPCSignal(
             name="cache_task_broadcast_signal",
@@ -443,5 +408,7 @@ def main():
 if __name__ == "__main__":

     args = parse_args()
-    logger = get_logger("cache_transfer_manager", "cache_transfer_manager.log")
+    rank_id = args.rank + args.local_data_parallel_id * args.mp_num
+    logger = get_logger("cache_transfer_manager", f"cache_transfer_manager_rank{rank_id}.log")
+    paddle.set_device(f"gpu:{args.device_id}")
     main()
@@ -150,6 +150,19 @@ class PrefixCacheManager:
         filename = "cache_transfer_manager.py"
         py_path = os.path.join(current_dir_path, filename)

+        cache_messager_processes = []
+        cache_messager_processes = self.launch_cache_messager(
+            cache_config,
+            tensor_parallel_size,
+            device_ids,
+            pod_ip,
+            engine_worker_queue_port,
+            pid_suffix,
+        )
+        if cache_messager_processes is None:
+            raise RuntimeError("Launch cache messager failed")
+            return []
+
         if (
             hasattr(cache_config.model_cfg, "num_key_value_heads")
             and hasattr(cache_config.model_cfg, "num_key_value_heads")
@@ -213,7 +226,76 @@ class PrefixCacheManager:
         if cache_config.enable_hierarchical_cache and self.num_cpu_blocks > 0:
             logger.info("Enable hierarchical cache.")
             self._enable_cpu_cache()
-        return cache_manager_processes
+        all_cache_processes = cache_messager_processes + cache_manager_processes
+        return all_cache_processes
+
+    def launch_cache_messager(
+        self, cache_config, tensor_parallel_size, device_ids, pod_ip, engine_worker_queue_port, pid_suffix
+    ):
+        """
+        launch_cache_messager function used to initialize the cache messager.
+        """
+        current_dir_path = os.path.split(os.path.abspath(__file__))[0]
+        filename = "cache_messager.py"
+        if (
+            hasattr(cache_config.model_cfg, "num_key_value_heads")
+            and hasattr(cache_config.model_cfg, "num_key_value_heads")
+            and cache_config.model_cfg.num_key_value_heads is not None
+            and int(cache_config.model_cfg.num_key_value_heads) > 0
+        ):
+            kv_num_head = int(cache_config.model_cfg.num_key_value_heads) // tensor_parallel_size
+        else:
+            kv_num_head = cache_config.model_cfg.num_attention_heads // tensor_parallel_size
+
+        cache_ready_signal_data = np.zeros(shape=[tensor_parallel_size], dtype=np.int32)
+        self.cache_ready_signal = IPCSignal(
+            name="cache_ready_signal",
+            array=cache_ready_signal_data,
+            dtype=np.int32,
+            suffix=pid_suffix,
+            create=True,
+        )
+
+        py_path = os.path.join(current_dir_path, filename)
+        log_dir = envs.FD_LOG_DIR
+        cache_messager_processes = []
+        for i in range(tensor_parallel_size):
+            launch_cmd = (
+                "FLAGS_allocator_strategy=auto_growth CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7"
+                + " NCCL_MAX_NCHANNELS=1 NCCL_BUFFSIZE=0"
+                + f" {sys.executable} {py_path}"
+                + f" --device_id {int(device_ids[i])}"
+                + f" --rank {i}"
+                + f" --splitwise_role {self.splitwise_role}"
+                + f" --num_layers {cache_config.model_cfg.num_hidden_layers}"
+                + f" --head_dim {cache_config.model_cfg.head_dim}"
+                + f" --kv_num_head {kv_num_head}"
+                + f" --mp_num {tensor_parallel_size}"
+                + f" --cache_dtype {cache_config.cache_dtype}"
+                + f" --pod_ip {pod_ip}"
+                + f" --cache_queue_port {cache_config.cache_queue_port}"
+                + f" --engine_worker_queue_port {engine_worker_queue_port}"
+                + f" --num_gpu_blocks {cache_config.total_block_num}"
+                + f" --block_size {cache_config.block_size}"
+                + f" --protocol {cache_config.cache_transfer_protocol}"
+                + f" --local_data_parallel_id {self.local_data_parallel_id}"
+                + f" --engine_pid {pid_suffix}"
+                + f" --rdma_port {cache_config.rdma_comm_ports[i] if cache_config.rdma_comm_ports is not None else '0'}"
+                + f" --speculative_config '{self.speculative_config.to_json_string()}'"
+                + f" >{log_dir}/launch_cache_messager_{int(device_ids[i])}.log 2>&1"
+            )
+            logger.info(f"Launch cache messager, command:{launch_cmd}")
+            cache_messager_processes.append(subprocess.Popen(launch_cmd, shell=True, preexec_fn=os.setsid))
+        logger.info("Waiting for cache ready...")
+        while np.sum(self.cache_ready_signal.value) != tensor_parallel_size:
+            time.sleep(1)
+        exit_code = cache_messager_processes[-1].poll()
+        if exit_code is None:
+            logger.info("Launch cache messager successful")
+        else:
+            logger.info("Launch cache messager failed, see launch_cache_messager.log for more information")
+            cache_messager_processes = None
+        return cache_messager_processes
+
     def update_cache_config(self, cache_config):
         """
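The launcher above and cache_messager.py's main() coordinate through the "cache_ready_signal" IPC array: each messager process flips its own slot to 1 once its KV tensors are exported, and the launcher polls until every slot is set. A stand-in sketch of that handshake with a plain NumPy array (no IPCSignal):

    import numpy as np

    tensor_parallel_size = 4
    cache_ready = np.zeros([tensor_parallel_size], dtype=np.int32)  # stand-in for the IPCSignal array

    # Each cache_messager process sets its own slot once its KV tensors are exported:
    for rank in range(tensor_parallel_size):
        cache_ready[rank] = 1

    # The launcher's wait condition from launch_cache_messager:
    assert np.sum(cache_ready) == tensor_parallel_size   # all ranks ready -> stop polling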
@@ -61,18 +61,12 @@ class RDMACommManager:
         Connect to remote gpu and write cache.
         """
         assert self.splitwise_role == "prefill", "only prefill can call this method"
-        addr = f"{ip}:{port!s}"
-        if addr in self.connected_rdma:
-            return True
         ret = self.messager.is_connected(ip, str(port))
         if ret:
-            self.connected_rdma.add(addr)
             return True

         ret = self.messager.connect(ip, str(port))
         logger.info(f"connect to remote rdma address {ip}:{port} status is {ret}")
-        if ret == 0:
-            self.connected_rdma.add(addr)
         return ret == 0

     def write_cache(self, ip, port, local_block_ids, remote_block_ids, layer_idx):
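After this change RDMACommManager.connect no longer keeps its own connected_rdma set; it always asks the underlying messager whether the link exists and reconnects otherwise, so a connection dropped on the remote side is not masked by stale local state. Sketched control flow, assuming a messager object exposing the two calls used above:

    def connect(messager, ip, port):
        # Ask the transport layer, not a local cache, whether the link is alive.
        if messager.is_connected(ip, str(port)):
            return True
        return messager.connect(ip, str(port)) == 0   # 0 means the connect call succeeded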
@@ -1481,7 +1481,7 @@ class FDConfig:
             self.model_config.model_format = "torch"

         # TODO
-        self.max_prefill_batch = 3
+        self.max_prefill_batch = int(os.getenv("MAX_PREFILL_NUM", "3"))
         if current_platform.is_xpu():
             self.max_prefill_batch = 1
         if self.model_config is not None and self.model_config.enable_mm:
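The prefill batch cap is now overridable via the MAX_PREFILL_NUM environment variable (default 3, still forced to 1 on XPU). Example, assuming the variable is set before FDConfig is constructed:

    import os

    os.environ["MAX_PREFILL_NUM"] = "6"                    # raise the prefill batch cap
    max_prefill_batch = int(os.getenv("MAX_PREFILL_NUM", "3"))
    assert max_prefill_batch == 6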
@@ -422,7 +422,7 @@ class EngineArgs:
             raise NotImplementedError("Only CUDA platform supports logprob.")
         if self.speculative_config is not None:
             envs.ENABLE_V1_KVCACHE_SCHEDULER = 0
-        if self.splitwise_role != "mixed":
+        if self.splitwise_role != "mixed" and self.cache_transfer_protocol != "rdma":
             envs.ENABLE_V1_KVCACHE_SCHEDULER = 0
         if not current_platform.is_cuda():
             envs.ENABLE_V1_KVCACHE_SCHEDULER = 0
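Net effect of the changed condition: a non-mixed (splitwise P/D) deployment no longer disables the V1 KV-cache scheduler as long as the cache transfer protocol is RDMA; it is still disabled for speculative decoding, for non-RDMA splitwise setups, and off CUDA. A small helper mirroring only the checks visible in this hunk (illustrative, not the real EngineArgs code path):

    def v1_kvcache_scheduler_enabled(speculative, splitwise_role, cache_transfer_protocol, is_cuda):
        enabled = True
        if speculative is not None:
            enabled = False
        if splitwise_role != "mixed" and cache_transfer_protocol != "rdma":
            enabled = False
        if not is_cuda:
            enabled = False
        return enabled

    assert v1_kvcache_scheduler_enabled(None, "prefill", "rdma", True) is True    # newly allowed by this PR
    assert v1_kvcache_scheduler_enabled(None, "prefill", "ipc", True) is False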
@@ -46,7 +46,7 @@ from fastdeploy.model_executor.guided_decoding import schema_checker
 from fastdeploy.plugins.token_processor import load_token_processor_plugins
 from fastdeploy.splitwise.internal_adapter_utils import InternalAdapter
 from fastdeploy.splitwise.splitwise_connector import SplitwiseConnector
-from fastdeploy.utils import EngineError, envs, llm_logger
+from fastdeploy.utils import EngineError, envs, get_logger, llm_logger

 try:
     TokenProcessor = load_token_processor_plugins()
@@ -69,6 +69,13 @@ class EngineService:
         """
         self.cfg = cfg

+        if self.cfg.parallel_config.enable_expert_parallel:
+            self.llm_logger = get_logger(
+                "fastdeploy", f"fastdeploy_rank{self.cfg.parallel_config.local_data_parallel_id}.log"
+            )
+        else:
+            self.llm_logger = llm_logger
+
         self.scheduler = cfg.scheduler_config.scheduler()

         if envs.ENABLE_V1_KVCACHE_SCHEDULER:
@@ -79,10 +86,6 @@ class EngineService:
                 cfg.scheduler_config.splitwise_role,
                 cfg.parallel_config.local_data_parallel_id,
             )
-            if cfg.scheduler_config.splitwise_role != "mixed":
-                raise NotImplementedError(
-                    "Currently ENABLE_V1_KVCACHE_SCHEDULER=1 only supported in mixed sampling now."
-                )
         else:
             self.resource_manager = ResourceManager(
                 cfg.scheduler_config.max_num_seqs,
@@ -135,12 +138,14 @@ class EngineService:
         self.insert_task_to_worker_thread.start()
         self.token_processor.tasks_queue = self.engine_worker_queue
         self.token_processor.run()
+        if self.cfg.scheduler_config.splitwise_role != "mixed":
+            self.split_mode_get_tasks()

     def _init_worker_monitor_signals(self):  # exist_task_signal lets each worker process know whether there is a new task to handle
         current_suffix = int(
             self.cfg.parallel_config.engine_worker_queue_port[self.cfg.parallel_config.local_data_parallel_id]
         )
-        llm_logger.info(f"current_suffix: {current_suffix}")
+        self.llm_logger.info(f"current_suffix: {current_suffix}")
         exist_task_signal_data = np.zeros([1], dtype=np.int32)
         self.exist_task_signal = IPCSignal(
             name="exist_task_signal",
@@ -201,7 +206,7 @@ class EngineService:
         )

         if start_queue and (self.cfg.host_ip == self.cfg.master_ip or self.cfg.master_ip == "0.0.0.0"):
-            llm_logger.info(f"Starting engine worker queue server service at {address}")
+            self.llm_logger.info(f"Starting engine worker queue server service at {address}")
             self.engine_worker_queue_server = EngineWorkerQueue(
                 address=address,
                 is_server=True,
@@ -225,7 +230,7 @@ class EngineService:
                 client_id=-1,
                 local_data_parallel_size=self.cfg.parallel_config.data_parallel_size,
             )
-            llm_logger.info(
+            self.llm_logger.info(
                 f"local {min(self.cfg.worker_num_per_node * self.cfg.node_rank + self.cfg.parallel_config.local_data_parallel_id,self.cfg.parallel_config.data_parallel_size - 1)}"
             )
         self.engine_worker_queue = EngineWorkerQueue(
@@ -254,7 +259,17 @@ class EngineService:
                     cur_task_idx = self.resource_manager.req_dict[task.request_id]
                     del self.resource_manager.req_dict[task.request_id]
                     cur_task = self.resource_manager.tasks_list[cur_task_idx]
+                    if envs.FD_ENABLE_INTERNAL_ADAPTER:
+                        if not task.outputs.token_ids:  # first token is eos in Prefill, just recycle resource and continue
+                            self.resource_manager.stop_flags[cur_task_idx] = True
+                            self.resource_manager.tasks_list[cur_task_idx] = None
+                            self.resource_manager._recycle_block_tables(cur_task)
+                            if task.request_id in self.token_processor.tokens_counter:
+                                del self.token_processor.tokens_counter[task.request_id]
+                            self.llm_logger.warning(f"{task.request_id} need not decode after first token")
+                            continue
                     cur_task.prompt_token_ids[0] = task.outputs.token_ids[0]
+                    cur_task.num_cached_tokens = task.num_cached_tokens
                     if (
                         self.cfg.speculative_config.method in ["mtp"]
                         and self.cfg.scheduler_config.splitwise_role == "decode"
@@ -267,13 +282,14 @@ class EngineService:
|
|||||||
if task.request_id in self.token_processor.tokens_counter:
|
if task.request_id in self.token_processor.tokens_counter:
|
||||||
del self.token_processor.tokens_counter[task.request_id]
|
del self.token_processor.tokens_counter[task.request_id]
|
||||||
self.scheduler.put_results([task])
|
self.scheduler.put_results([task])
|
||||||
llm_logger.warning(
|
self.llm_logger.warning(
|
||||||
f"{task.request_id} prefill failed with msg:{task.error_msg}, recycle resource."
|
f"{task.request_id} prefill failed with msg:{task.error_msg}, recycle resource."
|
||||||
)
|
)
|
||||||
continue
|
continue
|
||||||
self.token_processor.tokens_counter[task.request_id] = 1
|
self.token_processor.tokens_counter[task.request_id] = 1
|
||||||
current_tasks.append(cur_task)
|
current_tasks.append(cur_task)
|
||||||
self.engine_worker_queue.put_tasks((current_tasks, self.resource_manager.real_bsz))
|
if current_tasks:
|
||||||
|
self.engine_worker_queue.put_tasks((current_tasks, self.resource_manager.real_bsz))
|
||||||
return True
|
return True
|
||||||
|
|
||||||
self.resource_manager.check_and_free_block_tables()
|
self.resource_manager.check_and_free_block_tables()
|
||||||
@@ -281,13 +297,34 @@ class EngineService:
         if not isinstance(tasks, list):
             tasks = [tasks]

+        need_delete_tasks = []
+        for task in tasks:
+            if self.cfg.scheduler_config.splitwise_role != "mixed":
+                status, msg = self.split_connector.check_decode_allocated(task)
+                if not status:
+                    self.llm_logger.error(f"{task.request_id} prefill failed with msg:{msg}.")
+                    self.scheduler.put_results(
+                        [
+                            RequestOutput(
+                                request_id=task.request_id,
+                                finished=True,
+                                error_code=500,
+                                error_msg=msg,
+                            )
+                        ]
+                    )
+                    need_delete_tasks.append(task)
+                    continue
+        for tmp_task in need_delete_tasks:
+            tasks.remove(tmp_task)
+
         for item in tasks:
             item.schedule_start_time = time.time()

         available_batch = np.sum(self.resource_manager.stop_flags)
         if len(tasks) > available_batch:
-            llm_logger.error(f"Inserting batch:{len(tasks)} exceeds the available batch:{available_batch}.")
-            llm_logger.error("The exceeded part will be ignored!")
+            self.llm_logger.error(f"Inserting batch:{len(tasks)} exceeds the available batch:{available_batch}.")
+            self.llm_logger.error("The exceeded part will be ignored!")
             tasks = tasks[:available_batch]

         req_ids = [t.request_id for t in tasks]
@@ -296,7 +333,7 @@ class EngineService:

         if not tasks:
             error_msg = f"The request required resources is exceed the limit, request id={req_ids}."
-            llm_logger.error(error_msg)
+            self.llm_logger.error(error_msg)
             raise EngineError(error_msg, error_code=500)
             return False

@@ -314,7 +351,7 @@ class EngineService:

             self.split_connector.send_cache_infos(tasks, current_id)
         if not is_decode:
-            llm_logger.info(f"Tasks are sent to engine, req_ids={req_ids}")
+            self.llm_logger.info(f"Tasks are sent to engine, req_ids={req_ids}")
             for task in tasks:
                 task.inference_start_time = time.time()
         if not is_prefill:
@@ -473,7 +510,7 @@ class EngineService:
         Insert task to engine thread, monitor scheduler request queue.
         if the engine has resource, insert task to engine
         """
-        current_id = -1
+        current_id = 0
         while getattr(self, "running", True):
             try:
                 if self.resource_manager.available_batch() == 0:
@@ -514,18 +551,21 @@ class EngineService:
                     time.sleep(0.001)
                     continue

-                current_id = (current_id + 1) % 100003
                 if self.cfg.scheduler_config.splitwise_role != "mixed":
-                    llm_logger.info("Inserting splitwise tasks")
+                    self.llm_logger.info("Inserting splitwise tasks")
                     self.split_connector.send_splitwise_tasks(tasks, current_id)

-                self.insert_tasks(tasks, current_id)
+                insert_successful = self.insert_tasks(tasks, current_id)
+                if insert_successful:
+                    current_id = current_id + 1
+                else:
+                    continue

                 main_process_metrics.num_requests_waiting.dec(len(tasks))
                 main_process_metrics.num_requests_running.inc(len(tasks))
             except Exception as e:
-                err_msg = f"Error happened while insert task to engine: {e}, {traceback.format_exc()!s}."
-                llm_logger.error(err_msg)
+                err_msg = f"Error happend while insert task to engine: {e}, {traceback.format_exc()!s}."
+                self.llm_logger.error(err_msg)

     def _scheduler_task_to_worker_v1(self):
         """
@@ -535,40 +575,100 @@ class EngineService:
         is_fetching = False

         def _fetch_request():
-            nonlocal is_fetching
-            is_fetching = True
-            num_prefill_batch = min(
-                int(self.resource_manager.available_batch()),
-                self.cfg.max_prefill_batch,
-            )
-            if self.cfg.model_config.enable_mm:
-                available_blocks = self.resource_manager.available_block_num()
-            else:
-                available_blocks = self.cfg.cache_config.max_block_num_per_seq
+            try:
+                nonlocal is_fetching
+                is_fetching = True
+                num_prefill_batch = min(
+                    int(self.resource_manager.available_batch()),
+                    self.cfg.max_prefill_batch,
+                )
+                if self.cfg.model_config.enable_mm:
+                    available_blocks = self.resource_manager.available_block_num()
+                else:
+                    available_blocks = self.cfg.cache_config.max_block_num_per_seq

                 tasks = self.scheduler.get_requests(
                     available_blocks=available_blocks,
                     block_size=self.cfg.cache_config.block_size,
                     reserved_output_blocks=self.cfg.cache_config.enc_dec_block_num,
                     max_num_batched_tokens=self.cfg.max_model_len,
                     batch=num_prefill_batch,
                 )
-            # Fetch requests and add them to the scheduling queue
-            for task in tasks:
-                self.resource_manager.add_request(task)
-            is_fetching = False
+                if self.cfg.scheduler_config.splitwise_role != "mixed":
+                    for task in tasks:
+                        # assure can allocate block ids in P
+                        while not self.resource_manager.preallocate_resource_in_p(task):
+                            time.sleep(0.005)
+                        self.llm_logger.info(f"ask D resource for req_id: {task.request_id}")
+                        self.split_connector.send_splitwise_tasks([task], task.idx)
+                need_delete_tasks = []
+                for task in tasks:
+                    if self.cfg.scheduler_config.splitwise_role != "mixed":
+                        # assure fetch block ids from D
+                        status, msg = self.split_connector.check_decode_allocated(task)
+                        if not status:
+                            self.llm_logger.error(f"{task.request_id} prefill failed with msg:{msg}.")
+                            self.scheduler.put_results(
+                                [
+                                    RequestOutput(
+                                        request_id=task.request_id,
+                                        finished=True,
+                                        error_code=500,
+                                        error_msg=msg,
+                                    )
+                                ]
+                            )
+                            need_delete_tasks.append(task)
+                            continue
+                for tmp_task in need_delete_tasks:
+                    tasks.remove(tmp_task)
+                    # release resource in P
+                    self.resource_manager.prerelease_resource(task)
+                if self.cfg.scheduler_config.splitwise_role == "prefill":
+                    # to send cache info to cache messager
+                    if tasks:
+                        self.split_connector.send_cache_infos(tasks, 0)
+                        # ensure cache tasks has sent to cache_messager
+                        need_check_req_ids = [task.request_id for task in tasks]
+                        while need_check_req_ids:
+                            req_ids = self.engine_worker_queue.get_finished_add_cache_task_req()
+                            self.llm_logger.info(f"get_finished_add_cache_task_req: {req_ids}")
+                            if req_ids:
+                                for req_id in req_ids:
+                                    assert req_id in need_check_req_ids
+                                    need_check_req_ids.remove(req_id)
+                            else:
+                                time.sleep(0.001)
+                # Fetch requests and add them to the scheduling queue
+                if tasks:
+                    if self.cfg.scheduler_config.splitwise_role == "prefill":
+                        self.resource_manager.add_request_in_p(tasks)
+                    else:
+                        for task in tasks:
+                            self.resource_manager.add_request(task)
+                is_fetching = False
+            except Exception as e:
+                self.llm_logger.error(f"fetching request error {e} {str(traceback.format_exc())}")
+                is_fetching = False

         while self.running:
             try:
                 if self.engine_worker_queue.num_tasks() > 0:
                     time.sleep(0.001)
                     continue
-                if (
-                    len(self.resource_manager.waiting) == 0
-                    and (not is_fetching)
-                    and self.exist_prefill_task_signal.value[0] == 0
-                ):
-                    get_request_pool.submit(_fetch_request)
+                if self.cfg.scheduler_config.splitwise_role != "mixed":
+                    if self.scheduler.get_unhandled_request_num() <= envs.FD_EP_MAX_PREFETCH_TASK_NUM and (
+                        not is_fetching
+                    ):
+                        get_request_pool.submit(_fetch_request)
+                else:
+                    if (
+                        len(self.resource_manager.waiting) == 0
+                        and (not is_fetching)
+                        and self.exist_prefill_task_signal.value[0] == 0
+                    ):
+                        get_request_pool.submit(_fetch_request)
                 # 2. Schedule requests
                 tasks = self.resource_manager.schedule()
                 # 3. Send to engine
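
Note on the splitwise branch above: new requests are only prefetched while the scheduler backlog stays at or below `FD_EP_MAX_PREFETCH_TASK_NUM` and no fetch is already in flight. A minimal sketch of that gate; the default comes from the new environment variable, everything else here is illustrative:

FD_EP_MAX_PREFETCH_TASK_NUM = 8  # default of the new env var; override via the environment


def should_fetch(unhandled_request_num: int, is_fetching: bool) -> bool:
    # Mirrors the splitwise gate: bounded backlog, at most one fetcher at a time.
    return unhandled_request_num <= FD_EP_MAX_PREFETCH_TASK_NUM and not is_fetching


print(should_fetch(3, False))  # True  -> submit _fetch_request
print(should_fetch(9, False))  # False -> backlog too large
print(should_fetch(3, True))   # False -> a fetch is already running
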
@@ -579,8 +679,8 @@ class EngineService:
                 time.sleep(0.005)

             except Exception as e:
-                err_msg = "Error happened while insert task to engine: {}, {}.".format(e, str(traceback.format_exc()))
-                llm_logger.error(err_msg)
+                err_msg = "Error happend while insert task to engine: {}, {}.".format(e, str(traceback.format_exc()))
+                self.llm_logger.error(err_msg)

     def start_zmq_service(self, api_server_pid=None):
         if api_server_pid is None:
@@ -608,6 +708,9 @@ class EngineService:

     def _insert_zmq_task_to_scheduler(self):
         added_requests: Dict[str, int] = dict()
+        if envs.FD_ENABLE_INTERNAL_ADAPTER:
+            if self.cfg.scheduler_config.splitwise_role == "decode":
+                return
         while self.running:
             try:
                 block = True if len(added_requests) == 0 else False
@@ -616,7 +719,7 @@ class EngineService:
                 else:
                     err, data = self.recv_request_server.receive_pyobj_once(block)
                 if err is not None:
-                    llm_logger.error(f"Engine stops inserting zmq task into scheduler, err:{err}")
+                    self.llm_logger.error(f"Engine stops inserting zmq task into scheduler, err:{err}")
                     break

                 request, insert_task = None, []
@@ -627,16 +730,16 @@ class EngineService:
                     request = Request.from_dict(data)
                     start_span("ENQUEUE_ZMQ", data, trace.SpanKind.PRODUCER)
                     main_process_metrics.requests_number.inc()
-                    llm_logger.debug(f"Receive request: {request}")
+                    self.llm_logger.debug(f"Receive request: {request}")
                 except Exception as e:
-                    llm_logger.error(f"Receive request error: {e}, {traceback.format_exc()!s}")
+                    self.llm_logger.error(f"Receive request error: {e}, {traceback.format_exc()!s}")
                     err_msg = str(e)
                     results.append((data["request_id"], err_msg))

                 if self.guided_decoding_checker is not None and err_msg is None:
                     request, err_msg = self.guided_decoding_checker.schema_format(request)
                     if err_msg is not None:
-                        llm_logger.error(f"Receive request error: {err_msg}")
+                        self.llm_logger.error(f"Receive request error: {err_msg}")
                         results.append((request.request_id, err_msg))

                 if err_msg is None:
@@ -670,7 +773,7 @@ class EngineService:
                     # Send result by zmq directly
                     self.send_response_server.send_response(request_id, [error_result])
             except Exception as e:
-                llm_logger.error(
+                self.llm_logger.error(
                     f"Error happened while receiving new request from zmq, details={e}, "
                     f"traceback={traceback.format_exc()}"
                 )
@@ -689,7 +792,7 @@ class EngineService:
                 self.send_response_server.send_response(request_id, contents)

         except Exception as e:
-            llm_logger.error(f"Unexcepted error happened: {e}, {traceback.format_exc()!s}")
+            self.llm_logger.error(f"Unexcepted error happend: {e}, {traceback.format_exc()!s}")

     def split_mode_get_tasks(self):
         """
@@ -702,13 +805,22 @@ class EngineService:

             processed_indices = []
             for idx, task in enumerate(self.waiting_requests):
-                if self.resource_manager.is_resource_sufficient(task.prompt_token_ids_len):
-                    self.insert_tasks([task])
-                    llm_logger.info(f"Resource available, processing task {task.request_id}")
-                    processed_indices.append(idx)
+                if envs.ENABLE_V1_KVCACHE_SCHEDULER:
+                    if self.resource_manager.preallocate_resource_in_d(task):
+                        self.llm_logger.info(f"Resource available, processing task {task.request_id}")
+                        self.split_connector.send_cache_infos([task], -1)
+                        processed_indices.append(idx)
+                    else:
+                        self.llm_logger.debug(f"Still waiting for resources {task.request_id}")
+                        break
                 else:
-                    llm_logger.debug(f"Still waiting for resources {task.request_id}")
-                    break
+                    if self.resource_manager.is_resource_sufficient(task.prompt_token_ids_len):
+                        self.insert_tasks([task])
+                        self.llm_logger.info(f"Resource available, processing task {task.request_id}")
+                        processed_indices.append(idx)
+                    else:
+                        self.llm_logger.debug(f"Still waiting for resources {task.request_id}")
+                        break

             for idx in sorted(processed_indices, reverse=True):
                 self.waiting_requests.pop(idx)
@@ -730,32 +842,79 @@ class EngineService:
                             tasks = [tasks]
                             for task in tasks:
                                 task.finished = False
-                            self.insert_tasks(tasks, allocated=True)
-                            if self.cfg.innode_prefill_ports is not None:
-                                self.scheduler.put_results(tasks)
+                            if envs.ENABLE_V1_KVCACHE_SCHEDULER:
+                                for task in tasks:
+                                    if envs.FD_ENABLE_INTERNAL_ADAPTER:
+                                        if (
+                                            not task.outputs.token_ids
+                                        ):  # first token is eos in Prefill, just recycle resource and continue
+                                            cur_task = self.resource_manager.requests[task.request_id]
+                                            self.resource_manager.stop_flags[cur_task.idx] = True
+                                            self.resource_manager.tasks_list[cur_task.idx] = None
+                                            self.resource_manager._free_blocks(cur_task)
+                                            if cur_task.request_id in self.token_processor.tokens_counter:
+                                                del self.token_processor.tokens_counter[task.request_id]
+                                            self.llm_logger.warning(
+                                                f"{task.request_id} need not decode after first token"
+                                            )
+                                            del self.resource_manager.requests[task.request_id]
+                                            del self.resource_manager.req_dict[task.request_id]
+                                            continue
+                                        if task.error_code != 200:
+                                            cur_task = self.resource_manager.requests[task.request_id]
+                                            self.resource_manager.stop_flags[cur_task.idx] = True
+                                            self.resource_manager.tasks_list[cur_task.idx] = None
+                                            self.resource_manager._free_blocks(cur_task)
+                                            if cur_task.request_id in self.token_processor.tokens_counter:
+                                                del self.token_processor.tokens_counter[task.request_id]
+                                            self.scheduler.put_results([task])
+                                            self.llm_logger.warning(
+                                                f"{task.request_id} prefill failed with msg:{task.error_msg}, recycle resource."
+                                            )
+                                            continue
+                                    self.resource_manager.insert_task_for_decoding(task)
+
+                            else:
+                                self.insert_tasks(tasks, allocated=True)
+                                if self.cfg.innode_prefill_ports is not None:
+                                    self.scheduler.put_results(tasks)
                         else:
                             if len(self.waiting_requests):
-                                llm_logger.info(f"Waiting for resource for task {tasks[0].request_id}")
+                                self.llm_logger.info(f"Waiting for resource for task {tasks[0].request_id}")
                                 self.waiting_requests.extend(tasks)
                             else:
                                 new_waiting = []
                                 for task in tasks:
-                                    if self.resource_manager.is_resource_sufficient(task.prompt_token_ids_len):
-                                        self.insert_tasks([task])
+                                    can_allocate_resource = False
+                                    if envs.ENABLE_V1_KVCACHE_SCHEDULER:
+                                        if self.resource_manager.preallocate_resource_in_d(task):
+                                            self.split_connector.send_cache_infos([task], -1)
+                                            can_allocate_resource = True
                                     else:
+                                        if self.resource_manager.is_resource_sufficient(
+                                            task.prompt_token_ids_len
+                                        ):
+                                            self.insert_tasks([task])
+                                            can_allocate_resource = True
+                                    if can_allocate_resource is False:
+                                        if not self.enable_decode_cache_task:
+                                            task.error_msg = "Not enough resources"
                                         new_waiting.append(task)

                                 if new_waiting:
-                                    self.waiting_requests.extend(new_waiting)
-                                    llm_logger.info(f"Added {len(new_waiting)} tasks to waiting queue")
+                                    if not self.enable_decode_cache_task:
+                                        self.split_connector.send_cache_infos(new_waiting, -1)
+                                    else:
+                                        self.waiting_requests.extend(new_waiting)
+                                        self.llm_logger.info(
+                                            f"Added {len(new_waiting)} tasks to waiting queue"
+                                        )

                     else:
                         time.sleep(0.001)

             except Exception as e:
-                llm_logger.error(f"Error in main loop: {e}")
+                self.llm_logger.error(f"Error in main loop: {e}")
                 time.sleep(0.1)

         threading.Thread(target=receiver_loop, daemon=True).start()
@@ -120,11 +120,10 @@ class LLMEngine:

        self.data_processor = self.input_processor.create_processor()
        self.engine.data_processor = self.data_processor
+        # Launch components: scheduler, cache_manager, expert_service et.al.
+        self.launch_components()

        self.engine.start()
-        if api_server_pid is not None:
-            llm_logger.info(f"Start zmq server, api_server_pid: {api_server_pid}")
-            self.engine.start_zmq_service(api_server_pid)

        if self.do_profile == 0 and (
            self.cfg.cache_config.enable_prefix_caching or self.cfg.scheduler_config.splitwise_role != "mixed"
@@ -159,11 +158,14 @@ class LLMEngine:

        if self.do_profile:
            self._stop_profile()
-        # Launch components: scheduler, cache_manager, expert_service et.al.
-        self.launch_components()
        if self.cfg.cache_config.enable_prefix_caching or self.cfg.scheduler_config.splitwise_role != "mixed":
            self.launched_cache_manager_signal.value[0] = 1

+        if api_server_pid is not None:
+            llm_logger.info(f"Start zmq server, api_server_pid: {api_server_pid}")
+            self.engine.start_zmq_service(api_server_pid)
+
        # Worker launched
        self.check_worker_initialize_status_func_thread.join()
        if not result_container["worker_is_alive"]:
@@ -427,7 +429,10 @@ class LLMEngine:
        )

        if self.cfg.scheduler_config.splitwise_role != "mixed":
-            variables["FLAGS_use_pd_disaggregation"] = 1
+            if envs.ENABLE_V1_KVCACHE_SCHEDULER:
+                variables["FLAGS_use_pd_disaggregation_per_chunk"] = 1
+            else:
+                variables["FLAGS_use_pd_disaggregation"] = 1
            # TODO dynamic load environment variable
            if self.cfg.scheduler_config.splitwise_role == "prefill":
                variables["FLAGS_fmt_write_cache_completed_signal"] = 1
@@ -498,6 +503,7 @@ class LLMEngine:
            f" --load_choices {self.cfg.load_config.load_choices}"
            f" --moba_attention_config '{self.cfg.moba_attention_config.to_json_string()}'"
            f" --ips {ips}"
+            f" --cache-transfer-protocol {self.cfg.cache_config.cache_transfer_protocol}"
            f" --runner {self.cfg.model_config.runner}"
            f" --convert {self.cfg.model_config.convert}"
            f" --override-pooler-config {self.cfg.model_config.override_pooler_config}"
@@ -625,13 +631,11 @@ class LLMEngine:
        if self.cfg.scheduler_config.splitwise_role != "mixed":
            # single-node logic
            self.engine.engine_worker_queue.available_prefill_instances.put(1)
-            self.engine.split_mode_get_tasks()
-            if self.cfg.scheduler_config.name == "splitwise":
-                self.splitwise_receive_thread = threading.Thread(
-                    target=self.engine.split_connector.start_receiver, args=()
-                )
-                self.splitwise_receive_thread.daemon = True
-                self.splitwise_receive_thread.start()
+            self.splitwise_receive_thread = threading.Thread(
+                target=self.engine.split_connector.start_receiver, args=()
+            )
+            self.splitwise_receive_thread.daemon = True
+            self.splitwise_receive_thread.start()

        self.cfg.init_cache_info()

@@ -640,6 +644,14 @@ class LLMEngine:
        disaggregate = self.cfg.disaggregate_info
        if self.cfg.scheduler_config.name == "splitwise":
            self.engine.scheduler.start(role, host_ip, disaggregate)
+        elif self.cfg.scheduler_config.name == "dp":
+            request_queues_for_dp_ipc = []
+            result_queue_for_dp_ipc = multiprocessing.Queue()
+            for i in range(self.cfg.parallel_config.data_parallel_size):
+                request_queues_for_dp_ipc.append(multiprocessing.Queue())
+            self.engine.scheduler.start(
+                self.cfg.node_rank * self.cfg.worker_num_per_node, request_queues_for_dp_ipc, result_queue_for_dp_ipc
+            )

        if not envs.FD_ENABLE_MULTI_API_SERVER:
            if self.cfg.parallel_config.enable_expert_parallel and self.cfg.parallel_config.data_parallel_size > 1:
@@ -669,6 +681,9 @@ class LLMEngine:
                        args=(
                            self.cfg,
                            i,
+                            None,
+                            request_queues_for_dp_ipc,
+                            result_queue_for_dp_ipc,
                        ),
                    )
                )
@@ -27,6 +27,7 @@ import numpy as np

 from fastdeploy.engine.common_engine import EngineService
 from fastdeploy.inter_communicator import IPCSignal
+from fastdeploy.splitwise.internal_adapter_utils import InternalAdapter
 from fastdeploy.utils import console_logger, envs, llm_logger


@@ -69,8 +70,12 @@ class ExpertService:
         self.engine.scheduler.reset_nodeid(f"{self.engine.scheduler.infer.nodeid}_{local_data_parallel_id!s}")

         self._finalizer = weakref.finalize(self, self._exit_sub_services)
+        if envs.FD_ENABLE_INTERNAL_ADAPTER:
+            self.internal_adapter = InternalAdapter(cfg=self.cfg, engine=self.engine, dp_rank=local_data_parallel_id)

-    def start(self, ipc_signal_suffix, local_data_parallel_id):
+    def start(
+        self, ipc_signal_suffix, local_data_parallel_id, request_queues_for_dp_ipc=None, result_queue_for_dp_ipc=None
+    ):
         """
         Initializes the engine and starts its sub-services.
         If `api_server_pid` is defined, will launch a thread
@@ -80,6 +85,11 @@ class ExpertService:

         start_time = time.time()
         self.engine.start()
+        if self.cfg.scheduler_config.name == "dp":
+            self.cfg.init_cache_info()
+            assert (request_queues_for_dp_ipc is not None) and (result_queue_for_dp_ipc is not None)
+            self.engine.scheduler.start(local_data_parallel_id, request_queues_for_dp_ipc, result_queue_for_dp_ipc)
+
         if ipc_signal_suffix is not None:
             self.api_server_pid = ipc_signal_suffix
             self.engine.start_zmq_service(ipc_signal_suffix)
@@ -88,8 +98,8 @@ class ExpertService:

         llm_logger.info(f"start expert service {local_data_parallel_id}")
         if self.cfg.scheduler_config.splitwise_role != "mixed":
-            self.engine.start_cache_service(self.cfg.local_device_ids, ipc_signal_suffix)
-            self.engine.split_mode_get_tasks()
+            ipc_signal_suffix_cache = self.cfg.parallel_config.engine_worker_queue_port[local_data_parallel_id]
+            self.engine.start_cache_service(self.cfg.local_device_ids, ipc_signal_suffix_cache)

         if self.cfg.scheduler_config.name == "splitwise":
             self.cfg.init_cache_info()
@@ -144,14 +154,18 @@ class ExpertService:
         self.zmq_server.close()


-def start_data_parallel_service(cfg, local_data_parallel_id, ipc_signal_suffix=None):
+def start_data_parallel_service(
+    cfg, local_data_parallel_id, ipc_signal_suffix=None, request_queues_for_dp_ipc=None, result_queue_for_dp_ipc=None
+):
     """
     Start expert service
     """
     expert_service = ExpertService(cfg, local_data_parallel_id, start_queue=False)

     try:
-        expert_service.start(ipc_signal_suffix, local_data_parallel_id)
+        expert_service.start(
+            ipc_signal_suffix, local_data_parallel_id, request_queues_for_dp_ipc, result_queue_for_dp_ipc
+        )

         def deamon_thread():
             while True:
@@ -159,5 +173,6 @@ def start_data_parallel_service(cfg, local_data_parallel_id, ipc_signal_suffix=N

         t_deamon = threading.Thread(target=deamon_thread, daemon=True)
         t_deamon.start()
+        t_deamon.join()
     except Exception as e:
         llm_logger.exception(f"Expert service failed to start: {e}, {str(traceback.format_exc())}")
@@ -73,6 +73,7 @@ class Request:
         guided_json_object: Optional[bool] = None,
         enable_thinking: Optional[bool] = True,
         trace_carrier: dict = dict(),
+        dp_rank: Optional[int] = None,
         chat_template: Optional[str] = None,
         image_start: int = 0,
         video_start: int = 0,
@@ -145,6 +146,8 @@ class Request:
         # extend block tables
         self.use_extend_tables = False
         self.extend_block_tables = []
+        # dp
+        self.dp_rank = dp_rank

     @classmethod
     def from_dict(cls, d: dict):
@@ -187,6 +190,7 @@ class Request:
             image_end=d.get("image_end", 0),
             video_end=d.get("video_end", 0),
             audio_end=d.get("audio_end", 0),
+            dp_rank=d.get("dp_rank", None),
         )

     @property
@@ -328,8 +328,8 @@ class ResourceManager:
         Delete cached data from the task's prompt token ids based on the cached length.
         """
         if cached_len == len(task.prompt_token_ids):
-            task.prompt_token_ids = task.prompt_token_ids[cached_len - 1 :]
-            task.seq_lens_decoder = cached_len - 1
+            task.prompt_token_ids = task.prompt_token_ids[cached_len - self.cfg.block_size :]
+            task.seq_lens_decoder = cached_len - self.cfg.block_size
         else:
             task.prompt_token_ids = task.prompt_token_ids[cached_len:]
             task.seq_lens_decoder = cached_len
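
When the whole prompt hits the prefix cache, the hunk above now keeps the last full KV block of the prompt un-skipped (previously only the last token), so the prefill step still has something to compute and the skipped length stays block-aligned. A small, self-contained sketch of that slicing with toy values (none of these numbers come from a real config):

# Illustrative only: mirrors the cached-prefix trimming above with toy values.
block_size = 4
prompt_token_ids = list(range(12))   # a 12-token prompt
cached_len = 12                      # the whole prompt matched the prefix cache

if cached_len == len(prompt_token_ids):
    # keep the last block so there is still something to prefill
    prompt_token_ids = prompt_token_ids[cached_len - block_size:]
    seq_lens_decoder = cached_len - block_size
else:
    prompt_token_ids = prompt_token_ids[cached_len:]
    seq_lens_decoder = cached_len

print(prompt_token_ids, seq_lens_decoder)   # [8, 9, 10, 11] 8
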
@@ -14,6 +14,7 @@
 # limitations under the License.
 """

+import copy
 import threading
 import time
 import traceback
@@ -26,7 +27,7 @@ from typing import Union
 import numpy as np
 import paddle

-from fastdeploy.engine.request import Request, RequestStatus, RequestType
+from fastdeploy.engine.request import Request, RequestOutput, RequestStatus, RequestType
 from fastdeploy.engine.resource_manager import ResourceManager
 from fastdeploy.metrics.metrics import main_process_metrics
 from fastdeploy.utils import llm_logger
@@ -297,6 +298,11 @@ class ResourceManagerV1(ResourceManager):
         while req_index < len(self.running) and token_budget > 0:
             request = self.running[req_index]
             if request.num_computed_tokens >= request.need_prefill_tokens:  # to be decoding
+                if (
+                    self.config.scheduler_config.splitwise_role == "prefill"
+                ):  # do not need to schedule for decoding
+                    req_index += 1
+                    continue
                 if request.num_total_tokens > request.need_prefill_tokens:  # has generated tokens
                     request.num_computed_tokens = request.num_total_tokens - 1
                 if (
@@ -400,11 +406,12 @@ class ResourceManagerV1(ResourceManager):
                     request.status = RequestStatus.RUNNING
                     main_process_metrics.num_requests_waiting.dec(1)
                     main_process_metrics.num_requests_running.inc(1)
-                    allocated_position = self.get_available_position()
-                    request.idx = allocated_position
-                    self.tasks_list[allocated_position] = request
-                    self.stop_flags[allocated_position] = False
-                    self.req_dict[request.request_id] = allocated_position
+                    if self.config.scheduler_config.splitwise_role == "mixed":
+                        allocated_position = self.get_available_position()
+                        request.idx = allocated_position
+                        self.tasks_list[allocated_position] = request
+                        self.stop_flags[allocated_position] = False
+                        self.req_dict[request.request_id] = allocated_position
                 else:
                     if self.config.cache_config.enable_prefix_caching:
                         self._free_blocks(request)
@@ -569,6 +576,127 @@ class ResourceManagerV1(ResourceManager):
             self.waiting.append(request)
             self.requests[request.request_id] = request

+    def prerelease_resource(self, request: Request):
+        """
+        Release resource in P or D before finished due to unexpected error.
+        """
+        with self.lock:
+            self.tasks_list[request.idx] = None
+            self.stop_flags[request.idx] = True
+            del self.requests[request.request_id]
+            del self.req_dict[request.request_id]
+            self._free_blocks(request)
+
+    def add_request_in_p(self, requests: list[Request]):
+        with self.lock:
+            for request in requests:
+                request.inference_start_time = time.time()
+                request.schedule_start_time = time.time()
+                self.running.append(request)
+
+    def preallocate_resource_in_p(self, request: Request):
+        """
+        In P/D aggregated deployment, preallocate resource for P.
+        If can allocate, allocate resources and return True
+        If can not, return False
+        """
+        assert self.config.scheduler_config.splitwise_role == "prefill", "Only P instance can call this method"
+        with self.lock:
+            if self.available_batch() == 0:
+                return False
+            request.need_prefill_tokens = len(request.prompt_token_ids)
+            need_prealloc_prefill_blocks = (
+                request.need_prefill_tokens + self.config.cache_config.block_size - 1
+            ) // self.config.cache_config.block_size + self.config.cache_config.enc_dec_block_num  # consider for mtp, plus enc_dec_block_num
+            if self.config.cache_config.enable_prefix_caching:
+                # Enable prefix caching
+                if self.config.cache_config.enable_hierarchical_cache and self.cache_manager.num_cpu_blocks > 0:
+                    if not self.cache_manager.can_allocate_gpu_blocks(
+                        need_prealloc_prefill_blocks
+                    ):  # to prevent block allocation for matching in hierarchical cache and cause dead lock
+                        return False
+                success = self.get_prefix_cached_blocks(request)
+                if not success:
+                    self._free_blocks(request)
+                    return False
+                # consider for mtp, plus enc_dec_block_num
+                need_extra_prefill_blocks = need_prealloc_prefill_blocks - request.cache_info[0]
+                if self.cache_manager.can_allocate_gpu_blocks(need_extra_prefill_blocks):
+                    request.block_tables.extend(self.cache_manager.allocate_gpu_blocks(need_extra_prefill_blocks))
+                    allocated_position = self.get_available_position()
+                    request.idx = allocated_position
+                    self.tasks_list[request.idx] = request
+                    self.stop_flags[request.idx] = False
+                    self.requests[request.request_id] = request
+                    self.req_dict[request.request_id] = allocated_position
+                    return True
+                else:
+                    self._free_blocks(request)
+                    return False
+
+            else:
+                if self.cache_manager.can_allocate_gpu_blocks(need_prealloc_prefill_blocks):
+                    request.block_tables.extend(self.cache_manager.allocate_gpu_blocks(need_prealloc_prefill_blocks))
+                    request.num_computed_tokens = 0
+                    allocated_position = self.get_available_position()
+                    request.idx = allocated_position
+                    self.tasks_list[request.idx] = request
+                    self.stop_flags[request.idx] = False
+                    self.requests[request.request_id] = request
+                    self.req_dict[request.request_id] = allocated_position
+                    return True

+            return False
+
+    def preallocate_resource_in_d(self, request: Request):
+        """
+        In P/D aggregated deployment, D should preallocate resource for P.
+        If can allocate, allocate resources and return True
+        If can not, return False
+        """
+        assert self.config.scheduler_config.splitwise_role == "decode", "Only D instance can call this method"
+        with self.lock:
+            if len(self.waiting) > 0:
+                return False
+            if self.available_batch() == 0:
+                return False
+            request.need_prefill_tokens = len(request.prompt_token_ids)
+            need_prealloc_prefill_blocks = (
+                request.need_prefill_tokens + self.config.cache_config.block_size - 1
+            ) // self.config.cache_config.block_size + self.config.cache_config.enc_dec_block_num  # consider for mtp, plus enc_dec_block_num
+            if self.cache_manager.can_allocate_gpu_blocks(need_prealloc_prefill_blocks):
+                request.block_tables.extend(self.cache_manager.allocate_gpu_blocks(need_prealloc_prefill_blocks))
+                request.num_computed_tokens = request.need_prefill_tokens
+                request.disaggregate_info["block_tables"] = request.block_tables
+                allocated_position = self.get_available_position()
+                request.idx = allocated_position
+                self.tasks_list[request.idx] = request
+                self.stop_flags[request.idx] = False
+                self.requests[request.request_id] = request
+                self.req_dict[request.request_id] = allocated_position
+                return True
+            return False
+
+    def insert_task_for_decoding(self, request_output_in_p: RequestOutput):
+        """
+        In P/D aggregated deployment, D should continue to decode after recieving first token and cache from P.
+        """
+        assert self.config.scheduler_config.splitwise_role == "decode", "Only D instance can call this method"
+        with self.lock:
+            request = self.requests[request_output_in_p.request_id]
+            request.output_token_ids.append(request_output_in_p.outputs.token_ids[0])
+            request.num_cached_tokens = request_output_in_p.num_cached_tokens
+            if (
+                self.config.speculative_config.method in ["mtp"]
+                and self.config.scheduler_config.splitwise_role == "decode"
+            ):
+                request.draft_token_ids = copy.deepcopy(request_output_in_p.outputs.draft_token_ids)
+            # update request.need_prefill_tokens
+            request.need_prefill_tokens = len(request.prompt_token_ids) + 1
+            request.inference_start_time = time.time()
+            request.schedule_start_time = time.time()
+            self.running.append(request)
+
     def _free_blocks(self, request: Request):
         if self.config.cache_config.enable_prefix_caching:
             self.cache_manager.release_block_ids(request)
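
Both preallocation helpers above size the KV-cache reservation with the same ceiling division over the block size, plus `enc_dec_block_num` spare blocks. A worked example with made-up numbers (none of these values come from a real deployment):

# Illustrative only: the block budget computed by preallocate_resource_in_p / _in_d.
need_prefill_tokens = 1000      # prompt length
block_size = 64                 # tokens per KV block
enc_dec_block_num = 2           # extra blocks reserved (e.g. for draft/MTP tokens)

need_prealloc_prefill_blocks = (
    need_prefill_tokens + block_size - 1
) // block_size + enc_dec_block_num

print(need_prealloc_prefill_blocks)   # ceil(1000 / 64) + 2 = 16 + 2 = 18
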
@@ -620,5 +748,7 @@ class ResourceManagerV1(ResourceManager):
                 self.tasks_list[request.idx] = None
                 self.stop_flags[request.idx] = True
                 del self.requests[req_id]
+                if req_id in self.req_dict:
+                    del self.req_dict[req_id]
         except Exception as e:
             llm_logger.error(f"finish_request err: {e}, {str(traceback.format_exc())}")
@@ -109,6 +109,12 @@ environment_variables: dict[str, Callable[[], Any]] = {
     "FD_ZMQ_SEND_RESPONSE_SERVER_PORT": lambda: os.getenv("FD_ZMQ_SEND_RESPONSE_SERVER_PORT", "8201"),
     # LLMEngine recieve control command port, used when FD_ENABLE_INTERNAL_ADAPTER=1
     "FD_ZMQ_CONTROL_CMD_SERVER_PORTS": lambda: os.getenv("FD_ZMQ_CONTROL_CMD_SERVER_PORTS", "8202"),
+    # Whether to enable cache task in decode node
+    "FD_ENABLE_CACHE_TASK": lambda: os.getenv("FD_ENABLE_CACHE_TASK", "1"),
+    # Batched token timeout in EP
+    "FD_EP_BATCHED_TOKEN_TIMEOUT": lambda: float(os.getenv("FD_EP_BATCHED_TOKEN_TIMEOUT", "0.1")),
+    # Max pre-fetch requests number in PD
+    "FD_EP_MAX_PREFETCH_TASK_NUM": lambda: int(os.getenv("FD_EP_MAX_PREFETCH_TASK_NUM", "8")),
     "FD_ENABLE_MODEL_LOAD_CACHE": lambda: bool(int(os.getenv("FD_ENABLE_MODEL_LOAD_CACHE", "0"))),
 }

@@ -120,6 +126,14 @@ def __getattr__(name: str):
     raise AttributeError(f"module {__name__!r} has no attribute {name!r}")


+def get_unique_name(self, name):
+    """
+    Get unique name for config
+    """
+    shm_uuid = os.getenv("SHM_UUID", "")
+    return name + f"_{shm_uuid}"
+
+
 def __setattr__(name: str, value: Any):
     assert name in environment_variables
     environment_variables[name] = lambda: value
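
The new entries follow the table's existing lazy-lookup pattern: each name maps to a zero-argument callable, so the value is read from the process environment at access time rather than at import time. A hedged, standalone sketch of that pattern (a stand-in table, not FastDeploy's actual `envs` module):

import os
from typing import Any, Callable

# Illustrative stand-in for the environment_variables table shown above.
environment_variables: dict[str, Callable[[], Any]] = {
    "FD_EP_MAX_PREFETCH_TASK_NUM": lambda: int(os.getenv("FD_EP_MAX_PREFETCH_TASK_NUM", "8")),
    "FD_EP_BATCHED_TOKEN_TIMEOUT": lambda: float(os.getenv("FD_EP_BATCHED_TOKEN_TIMEOUT", "0.1")),
}


def lookup(name: str) -> Any:
    # Same idea as the module-level __getattr__: resolve lazily on every access.
    return environment_variables[name]()


os.environ["FD_EP_MAX_PREFETCH_TASK_NUM"] = "16"
print(lookup("FD_EP_MAX_PREFETCH_TASK_NUM"))   # 16, picked up at call time
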
@@ -84,18 +84,28 @@ class EngineWorkerQueue:
                 Value("i", 0) for _ in range(self.local_data_parallel_size)
             ]
             self.finished_req_queue = [Queue() for _ in range(self.local_data_parallel_size)]
+            self.finished_add_cache_task_queue = [Queue() for _ in range(self.local_data_parallel_size)]
             self.cache_infos_init: List[List[Any]] = [list() for _ in range(self.local_data_parallel_size)]
+            self.connect_rdma_tasks_list = [list() for _ in range(self.local_data_parallel_size)]
+            self.connect_rdma_tasks_response_list = [list() for _ in range(self.local_data_parallel_size)]
             self.client_read_info_flag_init: List[List[int]] = [
                 [1] * self.num_client for _ in range(self.local_data_parallel_size)
             ]
             self.lock_info_init: List[threading.Lock] = [
                 threading.Lock() for _ in range(self.local_data_parallel_size)
             ]
+            self.connect_task_lock_init: List[threading.Lock] = [
+                threading.Lock() for _ in range(self.local_data_parallel_size)
+            ]

             self.finish_request_barrier = [
                 threading.Barrier(self.num_client) for _ in range(self.local_data_parallel_size)
             ]

+            self.finish_add_cache_task_barrier = [
+                threading.Barrier(self.num_client) for _ in range(self.local_data_parallel_size)
+            ]
+
             # Register shared objects with proxy types
             QueueManager.register(
                 "get_tasks",
@@ -117,6 +127,19 @@ class EngineWorkerQueue:
                 callable=lambda idx: self.read_finish_flag_init[idx],
                 proxytype=ValueProxy,
             )
+            QueueManager.register(
+                "get_connect_task_lock",
+                callable=lambda idx: self.connect_task_lock_init[idx],
+                proxytype=AcquirerProxy,
+            )
+            QueueManager.register(
+                "get_connect_rdma_tasks", callable=lambda idx: self.connect_rdma_tasks_list[idx], proxytype=ListProxy
+            )
+            QueueManager.register(
+                "get_connect_rdma_tasks_responses",
+                callable=lambda idx: self.connect_rdma_tasks_response_list[idx],
+                proxytype=ListProxy,
+            )
             QueueManager.register(
                 "get_connected_client_counter",
                 callable=lambda idx: self.connected_client_counter_init[idx],
@@ -128,6 +151,11 @@ class EngineWorkerQueue:
                 callable=lambda idx: self.finished_req_queue[idx],
             )

+            QueueManager.register(
+                "get_finish_add_cache_task_queue",
+                callable=lambda idx: self.finished_add_cache_task_queue[idx],
+            )
+
             QueueManager.register(
                 "get_cache_infos",
                 callable=lambda idx: self.cache_infos_init[idx],
@@ -161,6 +189,10 @@ class EngineWorkerQueue:
                 "get_finish_request_barrier",
                 callable=lambda idx: self.finish_request_barrier[idx],
             )
+            QueueManager.register(
+                "get_finish_add_cache_task_barrier",
+                callable=lambda idx: self.finish_add_cache_task_barrier[idx],
+            )
             self.manager: BaseManager = QueueManager(address=self.address, authkey=self.authkey)
             self.manager.start()
         else:
@@ -174,12 +206,17 @@ class EngineWorkerQueue:
             QueueManager.register("get_read_finish_flag")
             QueueManager.register("get_connected_client_counter")
             QueueManager.register("get_finish_request_queue")
+            QueueManager.register("get_finish_add_cache_task_queue")
             QueueManager.register("get_cache_infos")
             QueueManager.register("get_client_read_info_flag")
             QueueManager.register("get_lock_info")
             QueueManager.register("get_disaggregate_requests")
             QueueManager.register("get_available_prefill_instances")
             QueueManager.register("get_finish_request_barrier")
+            QueueManager.register("get_finish_add_cache_task_barrier")
+            QueueManager.register("get_connect_rdma_tasks")
+            QueueManager.register("get_connect_rdma_tasks_responses")
+            QueueManager.register("get_connect_task_lock")
             self.manager = QueueManager(address=self.address, authkey=self.authkey)
             self._connect_with_retry()

@@ -199,7 +236,20 @@ class EngineWorkerQueue:
|
|||||||
self.disaggregate_requests = self.manager.get_disaggregate_requests(self.local_data_parallel_id)
|
self.disaggregate_requests = self.manager.get_disaggregate_requests(self.local_data_parallel_id)
|
||||||
self.available_prefill_instances = self.manager.get_available_prefill_instances()
|
self.available_prefill_instances = self.manager.get_available_prefill_instances()
|
||||||
self.finish_request_barrier = self.manager.get_finish_request_barrier(self.local_data_parallel_id)
|
self.finish_request_barrier = self.manager.get_finish_request_barrier(self.local_data_parallel_id)
|
||||||
|
self.finish_add_cache_task_barrier = self.manager.get_finish_add_cache_task_barrier(
|
||||||
|
self.local_data_parallel_id
|
||||||
|
)
|
||||||
self.finished_req_queue = self.manager.get_finish_request_queue(self.local_data_parallel_id)
|
self.finished_req_queue = self.manager.get_finish_request_queue(self.local_data_parallel_id)
|
||||||
|
self.finished_add_cache_task_queue = self.manager.get_finish_add_cache_task_queue(
|
||||||
|
self.local_data_parallel_id
|
||||||
|
)
|
||||||
|
# p/d互联
|
||||||
|
self.connect_rdma_task_queue = self.manager.get_connect_rdma_tasks(self.local_data_parallel_id)
|
||||||
|
self.connect_rdma_task_response_queue = self.manager.get_connect_rdma_tasks_responses(
|
||||||
|
self.local_data_parallel_id
|
||||||
|
)
|
||||||
|
self.connect_task_lock = self.manager.get_connect_task_lock(self.local_data_parallel_id)
|
||||||
|
|
||||||
assert self.num_client == len(self.client_read_flag)
|
assert self.num_client == len(self.client_read_flag)
|
||||||
|
|
||||||
if is_server:
|
if is_server:
|
||||||
@@ -281,6 +331,44 @@ class EngineWorkerQueue:
        self.lock.release()
        return total_num

+   def put_connect_rdma_task(self, connect_rdma_task):
+       self.connect_task_lock.acquire()
+       self.connect_rdma_task_queue.append(connect_rdma_task)
+       self.connect_task_lock.release()
+
+   def get_connect_rdma_task(self):
+       result = None
+       self.connect_task_lock.acquire()
+       if len(self.connect_rdma_task_queue) == 0:
+           self.connect_task_lock.release()
+           return result
+       try:
+           result = self.connect_rdma_task_queue.pop(0)
+       except Exception as e:
+           llm_logger.info(f"get_connect_rdma_task got exception: {e}")
+       finally:
+           self.connect_task_lock.release()
+       return result
+
+   def put_connect_rdma_task_response(self, connect_rdma_task_response):
+       self.connect_task_lock.acquire()
+       self.connect_rdma_task_response_queue.append(connect_rdma_task_response)
+       self.connect_task_lock.release()
+
+   def get_connect_rdma_task_response(self):
+       result = None
+       self.connect_task_lock.acquire()
+       if len(self.connect_rdma_task_response_queue) == 0:
+           self.connect_task_lock.release()
+           return result
+       try:
+           result = self.connect_rdma_task_response_queue.pop(0)
+       except Exception as e:
+           llm_logger.info(f"get_connect_rdma_task_response got exception: {e}")
+       finally:
+           self.connect_task_lock.release()
+       return result
+
    def get_prefill_instances(self):
        """
        check if the prefill queue is empty
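The four connect-RDMA helpers above form a small lock-protected mailbox on top of the shared queue manager, so prefill and decode instances can exchange connection requests and responses asynchronously. A minimal round-trip sketch, assuming `q` is an already-connected EngineWorkerQueue and using an illustrative payload (the dict fields are not prescribed by this diff):

    # requester: enqueue a connect request for a remote RDMA endpoint (fields illustrative)
    q.put_connect_rdma_task({"request_id": "req-0", "ip": "10.0.0.2", "rdma_port": 8765})

    # responder: poll the queue; get_connect_rdma_task() returns None when it is empty
    task = q.get_connect_rdma_task()
    if task is not None:
        q.put_connect_rdma_task_response({"request_id": task["request_id"], "success": True})

    # requester, later: collect the response (also None when nothing is pending)
    resp = q.get_connect_rdma_task_response()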
@@ -365,6 +453,29 @@ class EngineWorkerQueue:
        llm_logger.debug(f"get finished req: {ans}")
        return ans

+   def put_finished_add_cache_task_req(self, req_ids) -> None:
+       """
+       Put finished request ID into the queue.
+
+       Args:
+           req_ids: Request ID to be added to the queue
+       """
+       self.finished_add_cache_task_queue.put(req_ids)
+
+   def get_finished_add_cache_task_req(self) -> str:
+       """
+       Get finished request ID from the queue.
+
+       Returns:
+           str: Finished request ID
+       """
+       ans = []
+       if self.finished_add_cache_task_queue.empty():
+           return ans
+       ans = self.finished_add_cache_task_queue.get()
+       llm_logger.debug(f"get finished req: {ans}")
+       return ans
+
    def disaggregate_queue_empty(self):
        """
        Check if the disaggregated task queue is empty.
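These two methods mirror the existing finished-request queue: the cache messager reports request IDs whose add-cache task has completed, and the engine drains them in batches. A short usage sketch, again assuming a connected EngineWorkerQueue `q`:

    q.put_finished_add_cache_task_req(["req-0", "req-1"])  # producer: a batch of request IDs

    done = q.get_finished_add_cache_task_req()             # consumer: [] when nothing is queued
    for req_id in done:
        ...  # mark the add-cache task for req_id as finished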
@@ -211,9 +211,8 @@ class DeepEPEngine:
        self.num_experts = num_experts
        self.num_local_experts = num_experts // ep_size
        self.async_finish = async_finish
-       from paddle.base.core import Config

-       self.ep_config = Config(24, 6, 256)
+       self.ep_config = None

        # Store phase and role for buffer management
        self._splitwise_role = splitwise_role
@@ -76,6 +76,7 @@ else:
        update_inputs,
        step_reschedule,
        update_inputs_v1,
+       speculate_step_reschedule,
    )

@@ -413,12 +414,11 @@ def step_cuda(
    """

    if speculative_config.method is not None:
-       if enable_prefix_caching:
-           speculate_step_system_cache(
+       if DISABLE_RECOVER:
+           speculate_step_reschedule(
                share_inputs["stop_flags"],
                share_inputs["seq_lens_this_time"],
                share_inputs["step_seq_lens_encoder"],
-               share_inputs["step_seq_lens_decoder"],
                share_inputs["seq_lens_encoder"],
                share_inputs["seq_lens_decoder"],
                share_inputs["block_tables"],
@@ -444,64 +444,67 @@ def step_cuda(
                speculative_config.num_speculative_tokens,
            )
        else:
-           speculate_step_paddle(
-               share_inputs["stop_flags"],
-               share_inputs["seq_lens_this_time"],
-               share_inputs["step_seq_lens_encoder"],
-               share_inputs["seq_lens_encoder"],
-               share_inputs["seq_lens_decoder"],
-               share_inputs["block_tables"],
-               share_inputs["encoder_block_lens"],
-               share_inputs["is_block_step"],
-               share_inputs["step_block_list"],
-               share_inputs["step_lens"],
-               share_inputs["recover_block_list"],
-               share_inputs["recover_lens"],
-               share_inputs["need_block_list"],
-               share_inputs["need_block_len"],
-               share_inputs["used_list_len"],
-               share_inputs["free_list"],
-               share_inputs["free_list_len"],
-               share_inputs["input_ids"],
-               share_inputs["pre_ids"],
-               share_inputs["step_idx"],
-               share_inputs["next_tokens"],
-               share_inputs["first_token_ids"],
-               share_inputs["accept_num"],
-               block_size,
-               enc_dec_block_num,
-               speculative_config.num_speculative_tokens,
-           )
+           if enable_prefix_caching:
+               speculate_step_system_cache(
+                   share_inputs["stop_flags"],
+                   share_inputs["seq_lens_this_time"],
+                   share_inputs["step_seq_lens_encoder"],
+                   share_inputs["step_seq_lens_decoder"],
+                   share_inputs["seq_lens_encoder"],
+                   share_inputs["seq_lens_decoder"],
+                   share_inputs["block_tables"],
+                   share_inputs["encoder_block_lens"],
+                   share_inputs["is_block_step"],
+                   share_inputs["step_block_list"],
+                   share_inputs["step_lens"],
+                   share_inputs["recover_block_list"],
+                   share_inputs["recover_lens"],
+                   share_inputs["need_block_list"],
+                   share_inputs["need_block_len"],
+                   share_inputs["used_list_len"],
+                   share_inputs["free_list"],
+                   share_inputs["free_list_len"],
+                   share_inputs["input_ids"],
+                   share_inputs["pre_ids"],
+                   share_inputs["step_idx"],
+                   share_inputs["next_tokens"],
+                   share_inputs["first_token_ids"],
+                   share_inputs["accept_num"],
+                   block_size,
+                   enc_dec_block_num,
+                   speculative_config.num_speculative_tokens,
+               )
+           else:
+               speculate_step_paddle(
+                   share_inputs["stop_flags"],
+                   share_inputs["seq_lens_this_time"],
+                   share_inputs["step_seq_lens_encoder"],
+                   share_inputs["seq_lens_encoder"],
+                   share_inputs["seq_lens_decoder"],
+                   share_inputs["block_tables"],
+                   share_inputs["encoder_block_lens"],
+                   share_inputs["is_block_step"],
+                   share_inputs["step_block_list"],
+                   share_inputs["step_lens"],
+                   share_inputs["recover_block_list"],
+                   share_inputs["recover_lens"],
+                   share_inputs["need_block_list"],
+                   share_inputs["need_block_len"],
+                   share_inputs["used_list_len"],
+                   share_inputs["free_list"],
+                   share_inputs["free_list_len"],
+                   share_inputs["input_ids"],
+                   share_inputs["pre_ids"],
+                   share_inputs["step_idx"],
+                   share_inputs["next_tokens"],
+                   share_inputs["first_token_ids"],
+                   share_inputs["accept_num"],
+                   block_size,
+                   enc_dec_block_num,
+                   speculative_config.num_speculative_tokens,
+               )
    else:
-       if enable_prefix_caching:
-           step_system_cache(
-               share_inputs["stop_flags"],
-               share_inputs["seq_lens_this_time"],
-               share_inputs["step_seq_lens_encoder"],
-               share_inputs["step_seq_lens_decoder"],
-               share_inputs["seq_lens_encoder"],
-               share_inputs["seq_lens_decoder"],
-               share_inputs["block_tables"],
-               share_inputs["encoder_block_lens"],
-               share_inputs["is_block_step"],
-               share_inputs["step_block_list"],
-               share_inputs["step_lens"],
-               share_inputs["recover_block_list"],
-               share_inputs["recover_lens"],
-               share_inputs["need_block_list"],
-               share_inputs["need_block_len"],
-               share_inputs["used_list_len"],
-               share_inputs["free_list"],
-               share_inputs["free_list_len"],
-               share_inputs["input_ids"],
-               share_inputs["pre_ids"],
-               share_inputs["step_idx"],
-               share_inputs["next_tokens"],
-               share_inputs["first_token_ids"],
-               block_size,
-               enc_dec_block_num,
-           )
-       elif DISABLE_RECOVER:
+       if DISABLE_RECOVER:
            step_reschedule(
                share_inputs["stop_flags"],
                share_inputs["seq_lens_this_time"],
@@ -529,32 +532,61 @@ def step_cuda(
                enc_dec_block_num,
            )
        else:
-           step_paddle(
-               share_inputs["stop_flags"],
-               share_inputs["seq_lens_this_time"],
-               share_inputs["step_seq_lens_encoder"],
-               share_inputs["seq_lens_encoder"],
-               share_inputs["seq_lens_decoder"],
-               share_inputs["block_tables"],
-               share_inputs["encoder_block_lens"],
-               share_inputs["is_block_step"],
-               share_inputs["step_block_list"],
-               share_inputs["step_lens"],
-               share_inputs["recover_block_list"],
-               share_inputs["recover_lens"],
-               share_inputs["need_block_list"],
-               share_inputs["need_block_len"],
-               share_inputs["used_list_len"],
-               share_inputs["free_list"],
-               share_inputs["free_list_len"],
-               share_inputs["input_ids"],
-               share_inputs["pre_ids"],
-               share_inputs["step_idx"],
-               share_inputs["next_tokens"],
-               share_inputs["first_token_ids"],
-               block_size,
-               enc_dec_block_num,
-           )
+           if enable_prefix_caching:
+               step_system_cache(
+                   share_inputs["stop_flags"],
+                   share_inputs["seq_lens_this_time"],
+                   share_inputs["step_seq_lens_encoder"],
+                   share_inputs["step_seq_lens_decoder"],
+                   share_inputs["seq_lens_encoder"],
+                   share_inputs["seq_lens_decoder"],
+                   share_inputs["block_tables"],
+                   share_inputs["encoder_block_lens"],
+                   share_inputs["is_block_step"],
+                   share_inputs["step_block_list"],
+                   share_inputs["step_lens"],
+                   share_inputs["recover_block_list"],
+                   share_inputs["recover_lens"],
+                   share_inputs["need_block_list"],
+                   share_inputs["need_block_len"],
+                   share_inputs["used_list_len"],
+                   share_inputs["free_list"],
+                   share_inputs["free_list_len"],
+                   share_inputs["input_ids"],
+                   share_inputs["pre_ids"],
+                   share_inputs["step_idx"],
+                   share_inputs["next_tokens"],
+                   share_inputs["first_token_ids"],
+                   block_size,
+                   enc_dec_block_num,
+               )
+           else:
+               step_paddle(
+                   share_inputs["stop_flags"],
+                   share_inputs["seq_lens_this_time"],
+                   share_inputs["step_seq_lens_encoder"],
+                   share_inputs["seq_lens_encoder"],
+                   share_inputs["seq_lens_decoder"],
+                   share_inputs["block_tables"],
+                   share_inputs["encoder_block_lens"],
+                   share_inputs["is_block_step"],
+                   share_inputs["step_block_list"],
+                   share_inputs["step_lens"],
+                   share_inputs["recover_block_list"],
+                   share_inputs["recover_lens"],
+                   share_inputs["need_block_list"],
+                   share_inputs["need_block_len"],
+                   share_inputs["used_list_len"],
+                   share_inputs["free_list"],
+                   share_inputs["free_list_len"],
+                   share_inputs["input_ids"],
+                   share_inputs["pre_ids"],
+                   share_inputs["step_idx"],
+                   share_inputs["next_tokens"],
+                   share_inputs["first_token_ids"],
+                   block_size,
+                   enc_dec_block_num,
+               )


def rebuild_padding(
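After this restructuring, step_cuda launches exactly one stepping kernel per call, and DISABLE_RECOVER takes priority over prefix caching in both the speculative and the normal path. A sketch of the selection logic only (the real code calls the chosen kernel inline with the full share_inputs argument lists shown above, and assumes the imports from the earlier hunk):

    def _select_step_kernel(speculative, disable_recover, enable_prefix_caching):
        # mirrors the branch order of the hunks above
        if speculative:
            if disable_recover:
                return speculate_step_reschedule
            return speculate_step_system_cache if enable_prefix_caching else speculate_step_paddle
        if disable_recover:
            return step_reschedule
        return step_system_cache if enable_prefix_caching else step_paddle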
@@ -58,7 +58,6 @@ class TokenProcessor:
        self.split_connector = split_connector

        if envs.FD_USE_GET_SAVE_OUTPUT_V1:
-
            llm_logger.debug(f"create zmq get_save_output_rank{self.cfg.parallel_config.local_data_parallel_id}")
            self.zmq_server = ZmqIpcServer(
                name=f"get_save_output_rank{self.cfg.parallel_config.local_data_parallel_id}", mode=zmq.PULL
@@ -298,10 +297,15 @@ class TokenProcessor:
            try:
                is_blocking = True
                if self.speculative_decoding:
-                   speculate_get_output(self.output_tokens, rank_id, is_blocking, False)
+                   if (
+                       self.cfg.parallel_config.enable_expert_parallel
+                       and self.cfg.parallel_config.data_parallel_size > 1
+                   ):
+                       speculate_get_output(self.output_tokens, rank_id, is_blocking, True)
+                   else:
+                       speculate_get_output(self.output_tokens, rank_id, is_blocking, False)
                    if self.output_tokens[0] == -2:
                        continue

                else:
                    if self.use_logprobs:
                        get_output_topk(
@@ -370,14 +374,18 @@ class TokenProcessor:
                        llm_logger.info(f"finished_task_id: {finished_task_id}")
                        self.prefill_result_status[finished_task_id[0]] = finished_task_id[1]
                    if task_id in self.prefill_result_status:
-                       self.split_connector.send_first_token(task.disaggregate_info, [result])
-                       self.resource_manager.stop_flags[index] = True
-                       self.resource_manager.tasks_list[index] = None
-                       self.resource_manager._recycle_block_tables(task)
+                       if envs.ENABLE_V1_KVCACHE_SCHEDULER:
+                           self.resource_manager.finish_requests_async(task_id)
+                       else:
+                           self.resource_manager.stop_flags[index] = True
+                           self.resource_manager.tasks_list[index] = None
+                           self.resource_manager._recycle_block_tables(task)
+                           if task_id in self.resource_manager.req_dict:
+                               del self.resource_manager.req_dict[task_id]
                        if self.prefill_result_status[task_id] != "finished":
                            result.error_code = 400
                            result.error_message = f"{task_id} failed to {self.prefill_result_status[task_id]}"
-                       del self.resource_manager.req_dict[task_id]
+                       self.split_connector.send_first_token(task.disaggregate_info, [result])
                        break
                    else:
                        time.sleep(0.002)
@@ -388,6 +396,8 @@ class TokenProcessor:
        self.resource_manager.stop_flags[index] = True
        self.resource_manager.tasks_list[index] = None
        self.resource_manager._recycle_block_tables(task)
+       if task_id in self.resource_manager.req_dict:
+           del self.resource_manager.req_dict[task_id]

        task_used_block_num = sum([len(task.block_tables) if task else 0 for task in self.resource_manager.tasks_list])
        main_process_metrics.available_gpu_block_num.set(
@@ -461,16 +471,22 @@ class TokenProcessor:

            task_id = task.request_id
            if self.cfg.speculative_config.method:
-               token_ids = tokens[
-                   2
-                   + SPECULATE_MAX_BSZ
-                   + i * MAX_DRAFT_TOKENS : 2
-                   + SPECULATE_MAX_BSZ
-                   + i * MAX_DRAFT_TOKENS
-                   + accept_num[i]
-               ].tolist()
-               if len(token_ids) == 0 or token_ids[-1] <= 0:
-                   continue
+               if accept_num[i] == -3:
+                   recovery_stop = True
+               if recovery_stop:
+                   llm_logger.info(f"recovery stop signal found at task {task_id}")
+                   token_ids = [RECOVERY_STOP_SIGNAL]
+               else:
+                   token_ids = tokens[
+                       2
+                       + SPECULATE_MAX_BSZ
+                       + i * MAX_DRAFT_TOKENS : 2
+                       + SPECULATE_MAX_BSZ
+                       + i * MAX_DRAFT_TOKENS
+                       + accept_num[i]
+                   ].tolist()
+               if (not recovery_stop) and (len(token_ids) == 0 or token_ids[-1] <= 0):
+                   continue
            else:
                token_id = int(tokens[i, 0])
                token_ids = [token_id]
@@ -527,7 +543,7 @@ class TokenProcessor:
            if self.tokens_counter[task_id] == 0:
                if task.messages is not None:
                    result.prompt = task.messages
                result.num_cached_tokens = task.num_cached_tokens

            is_prefill = task.disaggregate_info is not None and task.disaggregate_info["role"] == "prefill"

@@ -537,7 +553,8 @@ class TokenProcessor:
            for token_id in token_ids:
                self.tokens_counter[task_id] += 1
                if token_id != RECOVERY_STOP_SIGNAL:
-                   result.outputs.token_ids.append(token_id)
+                   if not (envs.FD_ENABLE_INTERNAL_ADAPTER and token_id in task.eos_token_ids):
+                       result.outputs.token_ids.append(token_id)
                    task.output_token_ids.append(token_id)
                if self.use_logprobs:
                    result.outputs.logprob = float(scores[i, 0])
@@ -567,7 +584,11 @@ class TokenProcessor:
                    self._record_completion_metrics(task, current_time)
                    self._recycle_resources(task_id, i, task, result, is_prefill)
                    break
-           if not is_prefill or self.cfg.scheduler_config.name == "splitwise":
+           if (
+               not is_prefill
+               or self.cfg.scheduler_config.name == "splitwise"
+               or self.cfg.scheduler_config.name == "dp"
+           ):
                batch_result.append(result)

        self.postprocess(batch_result)
@@ -609,7 +630,7 @@ class TokenProcessor:
            self.cfg.speculative_config.num_speculative_tokens,
        )

-       real_accept_num = [x for x in accept_num if x != 0]
+       real_accept_num = [x for x in accept_num if x > 0]
        num_accepted_tokens = sum([x - 1 for x in real_accept_num])
        self.num_accepted_tokens += num_accepted_tokens
        num_emitted_tokens = sum(real_accept_num)
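The switch from `x != 0` to `x > 0` matters because accept_num can now carry negative in-band markers (for example -3 for a recovery stop, as handled above); counting them would corrupt the speculative-decoding metrics. For instance:

    accept_num = [3, 0, -3, 2]                         # -3 marks a recovery-stop slot, 0 an idle slot
    real_accept_num = [x for x in accept_num if x > 0]
    assert real_accept_num == [3, 2]
    assert sum(x - 1 for x in real_accept_num) == 3    # accepted draft tokens
    assert sum(real_accept_num) == 5                   # emitted tokens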
@@ -18,6 +18,7 @@ import redis

from fastdeploy.utils import llm_logger

+from .dp_scheduler import DPScheduler
from .global_scheduler import GlobalScheduler
from .local_scheduler import LocalScheduler
from .splitwise_scheduler import SplitWiseScheduler, SplitWiseSchedulerConfig
@@ -89,6 +90,54 @@ class LocalSchedulerConfig:
        llm_logger.info("=============================================================")


+class DPLocalSchedulerConfig(LocalSchedulerConfig):
+   """
+   Configuration class for DPLocalScheduler.
+   Attributes:
+       max_size: Maximum number of concurrent requests (-1 for unlimited)
+       ttl: Time-to-live in seconds for request expiration
+   """
+
+   def __init__(
+       self,
+       max_size: int = -1,
+       ttl: int = 900,
+       max_model_len: int = 8192,
+       enable_chunked_prefill: bool = False,
+       max_num_partial_prefills: int = 1,
+       max_long_partial_prefills: int = 1,
+       long_prefill_token_threshold: int = 0,
+       splitwise_role: str = "prefill",
+       **kwargs,
+   ):
+       """
+       Initialize LocalScheduler configuration.
+       Args:
+           max_size: Maximum concurrent requests (-1 for unlimited, 0 for disabled)
+           ttl: Time-to-live in seconds for request expiration (default 900s)
+           max_model_len: Maximum model context length in tokens
+           enable_chunked_prefill: Whether to enable chunked prefill processing
+           max_num_partial_prefills: Max partial prefill operations allowed
+           max_long_partial_prefills: Max long-running partial prefill ops
+           long_prefill_token_threshold: Token count threshold for long prefill
+           **kwargs: Additional unused arguments (for forward compatibility)
+       Note:
+           - If long_prefill_token_threshold is 0, it's auto-calculated as 4% of max_model_len
+           - See LocalScheduler class for implementation details
+       """
+       self.max_size = max_size
+       self.ttl = ttl
+
+       self.max_model_len = max_model_len
+       self.enable_chunked_prefill = enable_chunked_prefill
+       self.max_num_partial_prefills = max_num_partial_prefills
+       self.max_long_partial_prefills = max_long_partial_prefills
+       self.long_prefill_token_threshold = long_prefill_token_threshold
+       if self.long_prefill_token_threshold == 0:
+           self.long_prefill_token_threshold = int(self.max_model_len * 0.04)
+       self.splitwise_role = splitwise_role
+
+
class GlobalSchedulerConfig:
    """
    Configuration class for GlobalScheduler (Redis-based).
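DPLocalSchedulerConfig only adds splitwise_role on top of the LocalScheduler settings and keeps the same 4%-of-context default for long_prefill_token_threshold. A minimal construction, with values chosen purely for illustration:

    cfg = DPLocalSchedulerConfig(
        max_size=-1,                 # unlimited concurrent requests
        ttl=900,                     # expire unscheduled requests after 15 minutes
        max_model_len=32768,
        enable_chunked_prefill=True,
        splitwise_role="prefill",
    )
    assert cfg.long_prefill_token_threshold == int(32768 * 0.04)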
@@ -235,6 +284,9 @@ class SchedulerConfig:
        if self.name == "splitwise":
            self.config = SplitWiseSchedulerConfig(**args)

+       if self.name == "dp":
+           self.config = DPLocalSchedulerConfig(**args)
+
    def check(self):
        """
        Validate the configuration.
@@ -242,7 +294,7 @@ class SchedulerConfig:
        Raises:
            Exception: If invalid scheduler type is specified
        """
-       if self.name not in ["local", "global", "splitwise"]:
+       if self.name not in ["local", "global", "splitwise", "dp"]:
            raise Exception(f"Unknown scheduler type {self.name}")

        self.config.check()
@@ -280,6 +332,17 @@ class SchedulerConfig:
        if self.name == "splitwise":
            return SplitWiseScheduler(self.config)

+       if self.name == "dp":
+           return DPScheduler(
+               max_size=self.config.max_size,
+               ttl=self.config.ttl,
+               enable_chunked_prefill=self.config.enable_chunked_prefill,
+               max_num_partial_prefills=self.config.max_num_partial_prefills,
+               max_long_partial_prefills=self.config.max_long_partial_prefills,
+               long_prefill_token_threshold=self.config.long_prefill_token_threshold,
+               splitwise_role=self.config.splitwise_role,
+           )
+
        return LocalScheduler(
            max_size=self.config.max_size,
            ttl=self.config.ttl,
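With the registration above, picking the new scheduler is just a matter of naming it: check() now whitelists "dp" and the factory returns a DPScheduler built from DPLocalSchedulerConfig. A hedged example, assuming SchedulerConfig forwards its remaining keyword arguments to the selected config class as it does for the other scheduler types:

    scheduler_cfg = SchedulerConfig(name="dp", max_size=-1, ttl=900, splitwise_role="prefill")
    scheduler_cfg.check()                  # "dp" is now an accepted scheduler type
    scheduler = scheduler_cfg.scheduler()  # DPScheduler wired from DPLocalSchedulerConfig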
fastdeploy/scheduler/dp_scheduler.py  (new file, 272 lines)
@@ -0,0 +1,272 @@
+"""
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+
+import logging
+import threading
+import time
+from multiprocessing import Queue
+from typing import Dict, List, Optional
+
+from fastdeploy.engine.request import Request, RequestOutput
+from fastdeploy.scheduler.data import ScheduledResponse
+from fastdeploy.scheduler.local_scheduler import LocalScheduler
+from fastdeploy.utils import envs, get_logger
+
+
+class DPLocalScheduler(LocalScheduler):
+    def __init__(
+        self,
+        max_size: int,
+        ttl: int,
+        enable_chunked_prefill: bool,
+        max_num_partial_prefills: int,
+        max_long_partial_prefills: int,
+        long_prefill_token_threshold: int,
+        splitwise_role: str = "prefill",
+    ):
+        super().__init__(
+            max_size,
+            ttl,
+            enable_chunked_prefill,
+            max_num_partial_prefills,
+            max_long_partial_prefills,
+            long_prefill_token_threshold,
+        )
+        self.splitwise_role = splitwise_role
+        self.scheduler_logger = logging
+
+    def put_results(self, results: List[RequestOutput]):
+        """
+        Add processing results back to the scheduler.
+        Args:
+            results: List of RequestOutput objects containing results
+        """
+        responses: List[ScheduledResponse] = [ScheduledResponse(result) for result in results]
+
+        finished_responses = [response.request_id for response in responses if response.finished]
+        if len(finished_responses) > 0:
+            self.scheduler_logger.info(f"Scheduler has received some finished responses: {finished_responses}")
+
+        with self.mutex:
+            for response in responses:
+                if response.request_id not in self.responses:
+                    self.responses[response.request_id] = [response]
+                    continue
+                self.responses[response.request_id].append(response)
+            self.responses_not_empty.notify_all()
+
+    def _recycle(self, request_id: Optional[str] = None):
+        """
+        Clean up expired or completed requests to free memory.
+        Args:
+            request_id: Optional specific request ID to remove.
+                        If None, removes all expired requests.
+        """
+        if request_id is not None:
+            self.requests.pop(request_id, None)
+            self.responses.pop(request_id, None)
+            if self.splitwise_role == "decode":
+                return
+            self.ids.pop(self.ids.index(request_id))
+            self.ids_read_cursor -= 1
+            return
+
+        if self.max_size <= 0:
+            return
+
+        if len(self.requests) <= self.max_size:
+            return
+
+        now = time.time()
+        expired_ids = []
+        for request_id in self.ids:
+            request = self.requests[request_id]
+            if now - request.schedule_time < self.ttl:
+                break
+            expired_ids.append(request.request_id)
+
+        for i, expired_id in enumerate(expired_ids):
+            self.requests.pop(expired_id, None)
+            self.responses.pop(expired_id, None)
+            self.ids.pop(i)
+
+        if len(expired_ids) > 0:
+            if len(expired_ids) - 1 >= self.ids_read_cursor:
+                self.ids_read_cursor = 0
+            else:
+                self.ids_read_cursor -= len(expired_ids)
+
+    def get_requests(
+        self,
+        available_blocks,
+        block_size,
+        reserved_output_blocks,
+        max_num_batched_tokens,
+        batch=1,
+    ) -> List[Request]:
+        """
+        Retrieve requests from the scheduler based on available resources.
+
+        Args:
+            available_blocks: Number of available processing blocks
+            block_size: Size of each processing block
+            reserved_output_blocks: Blocks reserved for output
+            max_num_batched_tokens: Maximum tokens that can be batched
+            batch: Preferred batch size
+
+        Returns:
+            List of Request objects ready for processing
+        """
+        if available_blocks <= reserved_output_blocks or batch < 1:
+            self.scheduler_logger.debug(
+                f"Scheduler's resource are insufficient: available_blocks={available_blocks} "
+                f"reserved_output_blocks={reserved_output_blocks} batch={batch} "
+                f"max_num_batched_tokens={max_num_batched_tokens}"
+            )
+            return []
+        required_total_blocks = 0
+        current_prefill_tokens = 0
+        start_batch_time = time.time()
+        requests: List[Request] = []
+
+        with self.requests_not_empty:
+            if not envs.ENABLE_V1_KVCACHE_SCHEDULER:
+                while True:
+                    batch_ids = self.requests_not_empty.wait_for(
+                        lambda: self.ids[self.ids_read_cursor : self.ids_read_cursor + batch],
+                        0.005,
+                    )
+                    if batch_ids:
+                        for request_id in batch_ids:
+                            request = self.requests[request_id]
+                            required_input_blocks = self.calc_required_blocks(
+                                request.prompt_tokens_ids_len, block_size
+                            )
+                            current_prefill_tokens += request.prompt_tokens_ids_len
+                            required_total_blocks += required_input_blocks + reserved_output_blocks
+                            if required_total_blocks > available_blocks:
+                                break
+
+                            requests.append(request.raw)
+                            self.ids_read_cursor += 1
+                            start_batch_time = time.time()
+                            if current_prefill_tokens > max_num_batched_tokens:
+                                break
+                            if len(requests) >= batch:
+                                break
+                    if (
+                        (current_prefill_tokens > max_num_batched_tokens)
+                        or (len(requests) >= batch)
+                        or (time.time() - start_batch_time > envs.FD_EP_BATCHED_TOKEN_TIMEOUT)
+                    ):
+                        break
+            else:
+                batch_ids = self.requests_not_empty.wait_for(
+                    lambda: self.ids[self.ids_read_cursor : self.ids_read_cursor + batch],
+                    0.005,
+                )
+                if batch_ids:
+                    for request_id in batch_ids:
+                        request = self.requests[request_id]
+                        requests.append(request.raw)
+                        self.ids_read_cursor += 1
+
+        if batch_ids:
+            if len(batch_ids) > 0 and len(requests) == 0:
+                self.scheduler_logger.debug(
+                    f"Scheduler has put all just-pulled request into the queue: {len(batch_ids)}"
+                )
+
+        if len(requests) > 0:
+            self.scheduler_logger.info(
+                f"Scheduler has pulled some request: {[request.request_id for request in requests]}"
+            )
+
+        return requests
+
+
+class DPScheduler:
+    def __init__(
+        self,
+        max_size: int,
+        ttl: int,
+        enable_chunked_prefill: bool,
+        max_num_partial_prefills: int,
+        max_long_partial_prefills: int,
+        long_prefill_token_threshold: int,
+        splitwise_role: str = "prefill",
+    ):
+        self._scheduler = DPLocalScheduler(
+            max_size,
+            ttl,
+            enable_chunked_prefill,
+            max_num_partial_prefills,
+            max_long_partial_prefills,
+            long_prefill_token_threshold,
+            splitwise_role,
+        )
+
+    def start(self, dp_rank: int, request_queues: List[Queue], result_queue: Queue):
+        self.dp_rank = dp_rank
+        self.request_queues = request_queues
+        self.result_queue = result_queue
+        self.scheduler_logger = get_logger("dpscheduler", f"dp_scheduler_rank{self.dp_rank}.log")
+        self._scheduler.scheduler_logger = self.scheduler_logger
+        threading.Thread(target=self._put_requests_to_local).start()
+        threading.Thread(target=self._get_response_from_local).start()
+
+    def put_requests(self, requests: List[Dict]):
+        results = []
+        for request in requests:
+            if not hasattr(request, "dp_rank"):
+                raise ValueError(f"Request object is missing the 'dp_rank' attribute: {request}")
+            self.request_queues[request.dp_rank].put(request)
+            results.append((request.request_id, None))
+        return results
+
+    def _put_requests_to_local(self):
+        while True:
+            request = self.request_queues[self.dp_rank].get()
+            self.scheduler_logger.info(f"Recieve request from puller, request_id: {request.request_id}")
+            self._scheduler.put_requests([request])
+
+    def _get_response_from_local(self):
+        while True:
+            results = self._scheduler.get_results()
+            if len(results) == 0:
+                continue
+            self.result_queue.put(results)
+
+    def get_requests(
+        self,
+        available_blocks,
+        block_size,
+        reserved_output_blocks,
+        max_num_batched_tokens,
+        batch=1,
+    ) -> List[Request]:
+        return self._scheduler.get_requests(
+            available_blocks, block_size, reserved_output_blocks, max_num_batched_tokens, batch
+        )
+
+    def get_unhandled_request_num(self):
+        return len(self._scheduler.requests)
+
+    def put_results(self, results: List[RequestOutput]):
+        self._scheduler.put_results(results)
+
+    def get_results(self) -> Dict[str, List[RequestOutput]]:
+        return self.result_queue.get()
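DPScheduler is a thin multiprocess front-end around DPLocalScheduler: put_requests() routes each request to the queue of its target dp_rank, while two background threads shuttle work into the local scheduler and push results onto a shared result queue. A wiring sketch for a two-rank setup (the queue layout is an assumption of this example, not something this file prescribes):

    from multiprocessing import Queue

    dp_size = 2
    request_queues = [Queue() for _ in range(dp_size)]   # one inbox per data-parallel rank
    result_queue = Queue()                               # shared outbox

    scheduler = DPScheduler(
        max_size=-1,
        ttl=900,
        enable_chunked_prefill=True,
        max_num_partial_prefills=1,
        max_long_partial_prefills=1,
        long_prefill_token_threshold=0,
        splitwise_role="prefill",
    )
    scheduler.start(dp_rank=0, request_queues=request_queues, result_queue=result_queue)
    # requests passed to put_requests() need a dp_rank attribute so they land in the right inbox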
@@ -28,8 +28,6 @@ from fastdeploy.inter_communicator import EngineWorkerQueue
from fastdeploy.metrics.metrics import main_process_metrics
from fastdeploy.utils import get_logger

-logger = get_logger("splitwise_connector", "splitwise_connector.log")
-

class SplitwiseConnector:
    """
@@ -46,12 +44,19 @@ class SplitwiseConnector:
            resource_manager (object): Resource manager object.
        """
        self.cfg = cfg
+       if self.cfg.parallel_config.enable_expert_parallel and self.cfg.parallel_config.data_parallel_size > 1:
+           self.logger = get_logger(
+               "splitwise_connector", f"splitwise_connector_{self.cfg.parallel_config.local_data_parallel_id}.log"
+           )
+       else:
+           self.logger = get_logger("splitwise_connector", "splitwise_connector.log")
        self.engine_worker_queue = worker_queue
        self.resource_manager = resource_manager
        self.connect_innode_instances = {}
        self.temp_cache_info = dict()
        self.current_request_ids = dict()
        self.idx = self.cfg.parallel_config.local_data_parallel_id
+       self.enable_decode_cache_task = envs.FD_ENABLE_CACHE_TASK == "1"

        if self.cfg.cache_config.pd_comm_port is not None:
            self.zmq_ctx = zmq.Context()
@@ -70,7 +75,7 @@ class SplitwiseConnector:
            self.router_socket.setsockopt(zmq.SNDHWM, 1000)
            self.router_socket.setsockopt(zmq.ROUTER_MANDATORY, 1)
            self.router_socket.bind(f"tcp://*:{self.cfg.cache_config.pd_comm_port[0]}")
-           logger.info(f"bind {self.cfg.cache_config.pd_comm_port}")
+           self.logger.info(f"bind {self.cfg.cache_config.pd_comm_port}")

            self.poller = zmq.Poller()
            self.poller.register(self.router_socket, zmq.POLLIN)
@@ -90,17 +95,17 @@ class SplitwiseConnector:
                    if not socks:
                        continue
                    else:
-                       logger.debug(f"receive {socks}")
+                       self.logger.debug(f"receive {socks}")

                    frames = self.router_socket.recv_multipart()
-                   logger.debug(f"frames: {frames}")
+                   self.logger.debug(f"frames: {frames}")
                    message = frames[-1]
                    self.io_executor.submit(self._process_message, message)
                    time.sleep(0.001)
                else:
                    time.sleep(5)
            except Exception as e:
-               logger.error(f"Receiver error: {e}, {str(traceback.format_exc())}")
+               self.logger.error(f"Receiver error: {e}, {str(traceback.format_exc())}")
                time.sleep(1)

    def _get_push_socket(self, addr):
@@ -112,7 +117,7 @@ class SplitwiseConnector:
            return sock

        try:
-           logger.info(f"Establishing new connection to {addr}")
+           self.logger.info(f"Establishing new connection to {addr}")
            sock = self.zmq_ctx.socket(zmq.DEALER)

            # Set connection parameters
@@ -131,7 +136,7 @@ class SplitwiseConnector:
            return sock

        except zmq.ZMQError as e:
-           logger.error(f"Connection to {addr} failed: {e}")
+           self.logger.error(f"Connection to {addr} failed: {e}")

            raise ConnectionError(f"Failed to connect to {addr}") from e

@@ -140,7 +145,7 @@ class SplitwiseConnector:
            return

        try:
-           logger.info(f"Sent {msg_type} to {addr}")
+           self.logger.info(f"Sent {msg_type} to {addr}")
            message = self._serialize_message(msg_type, payload)

            try:
@@ -148,19 +153,19 @@ class SplitwiseConnector:
                sock = self._get_push_socket(addr)
                sock.send_multipart([b"", message])

-               logger.info(f"Sent {msg_type} to {addr}")
+               self.logger.info(f"Sent {msg_type} to {addr}")

            except ConnectionError:
-               logger.warning(f"Connection to {addr} not established")
+               self.logger.warning(f"Connection to {addr} not established")
            except zmq.Again:
-               logger.warning(f"Send queue full for {addr}")
+               self.logger.warning(f"Send queue full for {addr}")
            except Exception as e:
-               logger.error(f"Send to {addr} failed: {e}, {str(traceback.format_exc())}")
+               self.logger.error(f"Send to {addr} failed: {e}, {str(traceback.format_exc())}")
                main_process_metrics.send_cache_failed_num.inc()
                self._close_connection(addr)

        except Exception as e:
-           logger.error(f"Message preparation failed: {e}")
+           self.logger.error(f"Message preparation failed: {e}")

    def _close_connection(self, addr):
        """
@@ -265,7 +270,7 @@ class SplitwiseConnector:
                        f"{task.disaggregate_info['cache_info']['rdma']['ip']}:"
                        + f"{task.disaggregate_info['cache_info']['rdma']['port']}"
                    )
-                   logger.info(f"send splitwise tasks to port {addr} decode")
+                   self.logger.info(f"send splitwise tasks to port {addr} decode")
                    self.current_request_ids[task.request_id] = "init"
                    decode_diagg = task.disaggregate_info["cache_info"]
                    task.disaggregate_info["cache_info"] = self.cfg.disaggregate_info["cache_info"]
@@ -295,7 +300,7 @@ class SplitwiseConnector:
            self.connect_innode_instances[port].put_disaggregated_tasks(("decode", tasks))
            for task in tasks:
                task.disaggregate_info["cache_info"]["ipc"]["port"] = port
-           logger.info(f"send splitwise tasks to port {port} decode")
+           self.logger.info(f"send splitwise tasks to port {port} decode")
            current_port = port
        return current_port

@@ -305,7 +310,7 @@ class SplitwiseConnector:
        """
        if not isinstance(tasks_list, list):
            tasks_list = [tasks_list]
-       logger.info("send first token to port decode")
+       self.logger.info("send first token to port decode")
        if prefill_msg["transfer_protocol"] == "ipc":
            port = prefill_msg["cache_info"]["ipc"]["port"]
            if port not in self.connect_innode_instances:
@@ -313,7 +318,7 @@ class SplitwiseConnector:
            self.connect_innode_instances[port].put_disaggregated_tasks(("decode", tasks_list))
        else:
            node = f"{prefill_msg['cache_info']['rdma']['ip']}:{prefill_msg['cache_info']['rdma']['port']}"
-           logger.info(f"send first token to port {node} decode")
+           self.logger.info(f"send first token to port {node} decode")
            self._send_message(node, "decode", tasks_list)

    def create_connection(self, port):
@@ -329,6 +334,26 @@ class SplitwiseConnector:
            client_id=0,
        )

+   def check_decode_allocated(self, task):
+       start_time = time.time()
+       if task.disaggregate_info is None:
+           return True, ""
+       if self.enable_decode_cache_task:
+           return True, ""
+       if task.disaggregate_info["role"] != "prefill":
+           return True, ""
+       while self.current_request_ids[task.request_id] == "init":
+           time.sleep(0.001)
+           if time.time() - start_time > 30:
+               del self.current_request_ids[task.request_id]
+               return False, "timeout"
+       msg = self.current_request_ids[task.request_id]
+       del self.current_request_ids[task.request_id]
+       if msg == "finished":
+           return True, ""
+       self.logger.error(f"Receive_decode_allocated error: {msg}")
+       return False, msg
+
    def send_cache_infos(self, tasks, current_id):
        """
        Send cache information to specific port.
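check_decode_allocated lets the prefill side block until the decode instance has either confirmed its cache allocation or reported an error through the cache_sync path handled below; it returns an (ok, error_msg) pair and gives up after 30 seconds. A sketch of how a caller might gate sending the first token on it (the error handling shown is illustrative, modeled on the TokenProcessor changes earlier in this diff):

    ok, err = connector.check_decode_allocated(task)   # `connector` is a SplitwiseConnector
    if not ok:
        result.error_code = 400
        result.error_message = f"{task.request_id} failed to {err}"
    connector.send_first_token(task.disaggregate_info, [result])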
@@ -345,7 +370,7 @@ class SplitwiseConnector:
        for i in range(len(tasks)):
            if tasks[i].disaggregate_info is None:
                continue
-           logger.info(f"{tasks[i].disaggregate_info}")
+           self.logger.info(f"{tasks[i].disaggregate_info}")
            if tasks[i].disaggregate_info["role"] == "decode":
                if tasks[i].disaggregate_info["transfer_protocol"] == "ipc":
                    cache_info = {
@@ -380,11 +405,19 @@ class SplitwiseConnector:
                    addr = "prefill"
                    if current_id == -1:
                        current_id = tasks[i].disaggregate_info["cache_info"]["ipc"]["current_id"]
-                   cache_info = {
-                       "request_id": tasks[i].request_id,
-                       "src_block_ids": tasks[i].block_tables,
-                       "current_id": current_id,
-                   }
+                   if envs.ENABLE_V1_KVCACHE_SCHEDULER:
+                       cache_info = {
+                           "request_id": tasks[i].request_id,
+                           "src_block_ids": tasks[i].block_tables,
+                           "current_id": tasks[i].idx,
+                           "need_prefill_tokens": tasks[i].need_prefill_tokens,
+                       }
+                   else:
+                       cache_info = {
+                           "request_id": tasks[i].request_id,
+                           "src_block_ids": tasks[i].block_tables,
+                           "current_id": current_id,
+                       }
            if addr not in temp_cache_info:
                temp_cache_info[addr] = []

@@ -396,7 +429,7 @@ class SplitwiseConnector:
        else:
            if len(temp_cache_info):
                for k, v in temp_cache_info.items():
-                   logger.info(f"{k} {v}")
+                   self.logger.info(f"{k} {v}")
                    if ":" in str(k):
                        self._send_message(k, "cache_sync", v)
                    else:
@@ -427,7 +460,7 @@ class SplitwiseConnector:
        """
        try:
            msg_type, payload = self._deserialize_message(message)
-           logger.info(f"{msg_type}")
+           self.logger.info(f"{msg_type}")

            if msg_type == "prefill":
                self._handle_prefill(payload)
@@ -435,11 +468,16 @@ class SplitwiseConnector:
                self._handle_decode(payload)
            elif msg_type == "cache_sync":
                for task in payload:
-                   del self.current_request_ids[task["request_id"]]
-               self.engine_worker_queue.put_cache_info(payload)
+                   self.logger.info(f"cache_sync task: {task}")
+                   current_status = task.get("error_msg", "finished")
+                   self.current_request_ids[task["request_id"]] = current_status
+                   if self.enable_decode_cache_task:
+                       del self.current_request_ids[task["request_id"]]
+                       if current_status == "finished":
+                           self.engine_worker_queue.put_cache_info(payload)

        except Exception as e:
-           logger.error(f"Message processing failed: {e}, {str(traceback.format_exc())}")
+           self.logger.error(f"Message processing failed: {e}, {str(traceback.format_exc())}")

    def _handle_prefill(self, tasks):
        """
@@ -462,8 +500,12 @@ class SplitwiseConnector:
                        index=task["outputs"]["index"],
                        send_idx=0,
                        token_ids=task["outputs"]["token_ids"],
+                       draft_token_ids=task["outputs"]["draft_token_ids"],
                    ),
                    finished=True,
+                   num_cached_tokens=task["num_cached_tokens"],
+                   error_code=task["error_code"],
+                   error_msg=task["error_msg"],
                )
            )
        self.engine_worker_queue.put_disaggregated_tasks(("decode", tasks))
@@ -16,6 +16,7 @@

import argparse
import json
+import os
import time
from typing import Tuple

@@ -259,6 +260,7 @@ class PaddleDisWorkerProc:
        """Main event loop for Paddle Distributed Workers.
        TODO(gongshaotian): support remote calling of functions that control worker.
        """
+
        # Currently, only support single node
        self.nnode = int((self.parallel_config.tensor_parallel_size + 7) // 8)
        req_ids = []
@@ -643,6 +645,12 @@ def parse_args():
        help="Flag to specify dtype of lm_head as FP32",
    )

+   parser.add_argument(
+       "--cache-transfer-protocol",
+       type=str,
+       default="ipc",
+       help="support protocol list, comma separated, default is ipc",
+   )
    parser.add_argument(
        "--runner",
        type=str,
@@ -762,8 +770,7 @@ def initialize_fd_config(args, ranks: int = 1, local_rank: int = 0) -> FDConfig:
    ):
        logger.info("Set ENABLE_V1_KVCACHE_SCHEDULER to 0 due to not support speculative decoding now.")
        envs.ENABLE_V1_KVCACHE_SCHEDULER = 0
-   if args.splitwise_role != "mixed":
-       logger.info(f"Set ENABLE_V1_KVCACHE_SCHEDULER to 0 due to not supported {args.splitwise_role} now.")
+   if args.splitwise_role != "mixed" and args.cache_transfer_protocol != "rdma":
        envs.ENABLE_V1_KVCACHE_SCHEDULER = 0
    if not current_platform.is_cuda():
        logger.info("Set ENABLE_V1_KVCACHE_SCHEDULER to 0 due to not supported.")
@@ -772,6 +779,9 @@ def initialize_fd_config(args, ranks: int = 1, local_rank: int = 0) -> FDConfig:
        logger.info("Set ENABLE_V1_KVCACHE_SCHEDULER to 0 due to not supported guided_decoding.")
        envs.ENABLE_V1_KVCACHE_SCHEDULER = 0

+   if envs.ENABLE_V1_KVCACHE_SCHEDULER and args.splitwise_role == "prefill":
+       os.environ["PREFILL_NODE_ONE_STEP_STOP_V1"] = "1"
+
    fd_config = FDConfig(
        model_config=model_config,
        parallel_config=parallel_config,
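The new environment flag is how a prefill-only instance running under the v1 KV-cache scheduler signals that each request should stop after its single prefill step; consumers can read it as a plain boolean. A minimal read-side sketch (the exact consumer is outside this hunk):

    import os

    prefill_one_step_stop = os.environ.get("PREFILL_NODE_ONE_STEP_STOP_V1", "0") == "1"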