[PD Disaggregation] support DP via v1 router and decouple DP and EP (#5197)

* [fix] support DP via v1 router and decouple DP and EP

* [fix] fix scripts

* [fix] reset model path

* [fix] dp use get_output_ep, fix router port type, update scripts

* [merge] merge with latest code

* [chore] remove some debug log

* [fix] fix code style check

* [fix] fix test_multi_api_server for log_dir name

* [chore] reduce logs

* Apply suggestions from code review

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

---------

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
Yonghua Li
2025-12-04 15:38:43 +08:00
committed by GitHub
parent 5cd17fd662
commit f4119d51b4
15 changed files with 394 additions and 146 deletions

View File

@@ -26,71 +26,70 @@
#define MAX_BSZ 512
// #define GET_OUTPUT_DEBUG

// Message buffer exchanged over the SysV message queue.
// The diff rendering had duplicated the two members; a struct must declare
// each member exactly once.
struct msgdata {
  long mtype;              // SysV message type; must be the first member
  int mtext[MAX_BSZ + 2];  // [0]=stop_flag, [1]=bsz, [2..bsz+1]=tokens
};
// Pop one batch of generated tokens from the SysV message queue and copy it
// into tensor `x`. Received layout: mtext[0]=stop_flag, mtext[1]=bsz,
// mtext[2..bsz+1]=tokens. When no message is available (non-blocking mode)
// out_data[0] is set to -2 and out_data[1] to 0.
// NOTE: the diff rendering had concatenated the old and new bodies of this
// function; this is the single, deduplicated implementation.
void GetOutput(const paddle::Tensor& x,
               int64_t rank_id,
               bool wait_flag,
               int msg_queue_id) {
  // Only rank 0 consumes from the queue; all other ranks are no-ops.
  if (rank_id > 0) {
    return;
  }
  static struct msgdata msg_rcv;
  // The INFERENCE_MSG_QUEUE_ID environment variable overrides the argument.
  if (const char* inference_msg_queue_id_env_p =
          std::getenv("INFERENCE_MSG_QUEUE_ID")) {
    std::string inference_msg_queue_id_env_str(inference_msg_queue_id_env_p);
    int inference_msg_queue_id_from_env =
        std::stoi(inference_msg_queue_id_env_str);
#ifdef GET_OUTPUT_DEBUG
    std::cout << "Your INFERENCE_MSG_QUEUE_ID is: "
              << inference_msg_queue_id_from_env << std::endl;
#endif
    msg_queue_id = inference_msg_queue_id_from_env;
  }
  // key/msgid are static: the queue id is resolved once per process, so a
  // later change of msg_queue_id has no effect — intentional per original.
  static key_t key = ftok("/dev/shm", msg_queue_id);
  static int msgid = msgget(key, IPC_CREAT | 0666);
#ifdef GET_OUTPUT_DEBUG
  std::cout << "get_output_key: " << key << std::endl;
  std::cout << "get_output msgid: " << msgid << std::endl;
#endif
  int64_t* out_data = const_cast<int64_t*>(x.data<int64_t>());
  int ret = -1;
  // (MAX_BSZ + 2) * 4 is the payload size in bytes — assumes 4-byte int.
  if (!wait_flag) {
    ret = msgrcv(msgid, &msg_rcv, (MAX_BSZ + 2) * 4, 0, IPC_NOWAIT);
  } else {
    ret = msgrcv(msgid, &msg_rcv, (MAX_BSZ + 2) * 4, 0, 0);
  }
  if (ret == -1) {
    // Queue empty (or receive failed): report "no output" sentinel.
    out_data[0] = -2;
    out_data[1] = 0;
    return;
  }
  int bsz = msg_rcv.mtext[1];
  // Copy stop_flag, bsz, and the bsz token ids into the output tensor.
  for (int64_t i = 0; i < bsz + 2; i++) {
    out_data[i] = (int64_t)msg_rcv.mtext[i];
  }
#ifdef GET_OUTPUT_DEBUG
  std::cout << "get_output finished: " << msgid << std::endl;
#endif
  return;
}
// Static-queue variant: always reads from message queue id 1.
// The diff rendering duplicated the call line; calling GetOutput twice would
// consume two messages per invocation, so exactly one call is kept.
void GetOutputStatic(const paddle::Tensor& x, int64_t rank_id, bool wait_flag) {
  GetOutput(x, rank_id, wait_flag, 1);
}
// Dynamic-queue variant: reads from the caller-provided message queue id.
// The diff rendering duplicated the call line; one call is kept so each
// invocation consumes exactly one message.
void GetOutputDynamic(const paddle::Tensor& x,
                      int64_t rank_id,
                      bool wait_flag,
                      int msg_queue_id) {
  GetOutput(x, rank_id, wait_flag, msg_queue_id);
}
PD_BUILD_STATIC_OP(get_output)

View File

@@ -0,0 +1,141 @@
#!/bin/bash
set -e
# Test splitwise (PD-disaggregated) deployment.
# There are two methods for splitwise deployment:
#   v0: using splitwise_scheduler or dp_scheduler
#   v1: using local_scheduler + router   <-- exercised here

MODEL_NAME="PaddlePaddle/ERNIE-4.5-0.3B-Paddle"
DATA_PARALLEL_SIZE=2
TENSOR_PARALLEL_SIZE=1
NUM_GPUS=$(($DATA_PARALLEL_SIZE * $TENSOR_PARALLEL_SIZE))
LOG_DATE=$(date +%Y%m%d_%H%M%S)

export FD_DEBUG=1
export ENABLE_V1_KVCACHE_SCHEDULER=1
export KVCACHE_GDRCOPY_FLUSH_ENABLE=1
export FD_ENABLE_MULTI_API_SERVER=1

SCRIPT_PATH=$(readlink -f "$0")
SCRIPT_DIR=$(dirname "$SCRIPT_PATH")

# RDMA NICs are mandatory for the rdma cache-transfer protocol used below.
export $(bash ${SCRIPT_DIR}/../../scripts/get_rdma_nics.sh gpu)
echo "KVCACHE_RDMA_NICS:${KVCACHE_RDMA_NICS}"
if [ -z "${KVCACHE_RDMA_NICS}" ]; then
    echo "KVCACHE_RDMA_NICS is empty, please check the output of get_rdma_nics.sh"
    exit 1
fi

unset http_proxy && unset https_proxy
source ${SCRIPT_DIR}/utils.sh

# ---------------- start router ----------------
ROUTER_PORT=$(get_free_ports 1)
echo "---------------------------"
echo ROUTER_PORT: $ROUTER_PORT

export FD_LOG_DIR="log/$LOG_DATE/router"
rm -rf $FD_LOG_DIR
mkdir -p ${FD_LOG_DIR}
# Redirect stdout to the log first, THEN dup stderr onto it so both streams
# are captured ("2>&1 >file" would leave stderr on the terminal).
nohup python -m fastdeploy.router.launch \
    --port ${ROUTER_PORT} \
    --splitwise \
    >${FD_LOG_DIR}/nohup 2>&1 &
sleep 1

# ---------------- start prefill ----------------
P_SERVER_PORTS=$(get_free_ports $DATA_PARALLEL_SIZE)
P_METRICS_PORTS=$(get_free_ports $DATA_PARALLEL_SIZE)
P_ENGINE_WORKER_QUEUE_PORTS=$(get_free_ports $DATA_PARALLEL_SIZE)
P_CACHE_QUEUE_PORTS=$(get_free_ports $DATA_PARALLEL_SIZE)
P_RDMA_COMM_PORTS=$(get_free_ports $NUM_GPUS)
P_PD_COMM_PORTS=$(get_free_ports $DATA_PARALLEL_SIZE)
echo "---------------------------"
echo P_SERVER_PORTS: $P_SERVER_PORTS
echo P_METRICS_PORTS: $P_METRICS_PORTS
echo P_ENGINE_WORKER_QUEUE_PORTS: $P_ENGINE_WORKER_QUEUE_PORTS
echo P_CACHE_QUEUE_PORTS: $P_CACHE_QUEUE_PORTS
echo P_RDMA_COMM_PORTS: $P_RDMA_COMM_PORTS
echo P_PD_COMM_PORTS: $P_PD_COMM_PORTS

export CUDA_VISIBLE_DEVICES="0,1"
export FD_LOG_DIR="log/$LOG_DATE/prefill"
rm -rf $FD_LOG_DIR
mkdir -p ${FD_LOG_DIR}
nohup python -m fastdeploy.entrypoints.openai.multi_api_server \
    --num-servers ${DATA_PARALLEL_SIZE} \
    --ports ${P_SERVER_PORTS} \
    --metrics-port ${P_METRICS_PORTS} \
    --args --model ${MODEL_NAME} \
    --engine-worker-queue-port ${P_ENGINE_WORKER_QUEUE_PORTS} \
    --cache-queue-port ${P_CACHE_QUEUE_PORTS} \
    --max-model-len 32768 \
    --data-parallel-size ${DATA_PARALLEL_SIZE} \
    --tensor-parallel-size ${TENSOR_PARALLEL_SIZE} \
    --splitwise-role "prefill" \
    --cache-transfer-protocol "rdma" \
    --rdma-comm-ports ${P_RDMA_COMM_PORTS} \
    --pd-comm-port ${P_PD_COMM_PORTS} \
    --router "0.0.0.0:${ROUTER_PORT}" \
    >${FD_LOG_DIR}/nohup 2>&1 &

echo "--- Health Check Status ---"
wait_for_health ${P_SERVER_PORTS}

# ---------------- start decode ----------------
D_SERVER_PORTS=$(get_free_ports $DATA_PARALLEL_SIZE)
D_ENGINE_WORKER_QUEUE_PORTS=$(get_free_ports $DATA_PARALLEL_SIZE)
D_CACHE_QUEUE_PORTS=$(get_free_ports $DATA_PARALLEL_SIZE)
D_METRICS_PORTS=$(get_free_ports $DATA_PARALLEL_SIZE)
D_RDMA_COMM_PORTS=$(get_free_ports $NUM_GPUS)
D_PD_COMM_PORTS=$(get_free_ports $DATA_PARALLEL_SIZE)
echo "---------------------------"
echo D_SERVER_PORTS: $D_SERVER_PORTS
echo D_ENGINE_WORKER_QUEUE_PORTS: $D_ENGINE_WORKER_QUEUE_PORTS
echo D_CACHE_QUEUE_PORTS: $D_CACHE_QUEUE_PORTS
echo D_METRICS_PORTS: $D_METRICS_PORTS
echo D_RDMA_COMM_PORTS: $D_RDMA_COMM_PORTS
echo D_PD_COMM_PORTS: $D_PD_COMM_PORTS

export CUDA_VISIBLE_DEVICES="2,3"
export FD_LOG_DIR="log/$LOG_DATE/decode"
rm -rf $FD_LOG_DIR
mkdir -p ${FD_LOG_DIR}
nohup python -m fastdeploy.entrypoints.openai.multi_api_server \
    --num-servers ${DATA_PARALLEL_SIZE} \
    --ports ${D_SERVER_PORTS} \
    --metrics-port ${D_METRICS_PORTS} \
    --args --model ${MODEL_NAME} \
    --engine-worker-queue-port ${D_ENGINE_WORKER_QUEUE_PORTS} \
    --cache-queue-port ${D_CACHE_QUEUE_PORTS} \
    --max-model-len 32768 \
    --data-parallel-size ${DATA_PARALLEL_SIZE} \
    --tensor-parallel-size ${TENSOR_PARALLEL_SIZE} \
    --splitwise-role "decode" \
    --cache-transfer-protocol "rdma" \
    --rdma-comm-ports ${D_RDMA_COMM_PORTS} \
    --pd-comm-port ${D_PD_COMM_PORTS} \
    --router "0.0.0.0:${ROUTER_PORT}" \
    >${FD_LOG_DIR}/nohup 2>&1 &

echo "--- Health Check Status ---"
wait_for_health ${D_SERVER_PORTS}

# ---------------- send request ----------------
echo "------ Request Check ------"
sleep 10 # make sure server is registered to router
curl -X POST "http://0.0.0.0:${ROUTER_PORT}/v1/chat/completions" \
    -H "Content-Type: application/json" \
    -d '{
          "messages": [
            {"role": "user", "content": "hello"}
          ],
          "max_tokens": 100,
          "stream": false
        }'

View File

@@ -1,8 +1,16 @@
#!/bin/bash
is_port_free() {
    # Succeed (0) when no TCP listener is bound to the given port, else 1.
    local candidate=$1
    if ss -ltn | awk '{print $4}' | grep -q ":${candidate}$"; then
        return 1 # a listener already owns this port
    fi
    return 0
}
check_ports() {
    # Verify every port given as an argument is free; report and fail on the
    # first occupied one. (The diff view had kept both the old `ss -tuln`
    # check and the new is_port_free call — only the new check is valid.)
    for port in "$@"; do
        if ! is_port_free $port; then
            echo "❌ Port $port is already in use"
            return 1
        fi
    done
    return 0
}
wait_for_health() {
    # Poll /health on every comma-separated port in $1 until ALL return 200.
    # Redraws one status line per port each second using ANSI cursor moves.
    # (Reconstructed: the diff view interleaved the old single-port poller
    # with this multi-port version; only the multi-port logic is kept.)
    IFS=',' read -r -a server_ports <<< "$1"
    local num_ports=${#server_ports[@]}
    local total_lines=$((num_ports + 1)) # per-port lines + the summary line
    local GREEN='\033[0;32m'
    local RED='\033[0;31m'
    local NC='\033[0m' # No Color
    local start_time=$(date +%s)
    while true; do
        local all_ready=true
        for port in "${server_ports[@]}"; do
            status_code=$(curl -s --max-time 1 -o /dev/null -w "%{http_code}" "http://0.0.0.0:${port}/health" || echo "000")
            if [ "$status_code" -eq 200 ]; then
                printf "Port %s: ${GREEN}[OK] 200${NC}\033[K\n" "$port"
            else
                all_ready=false
                printf "Port %s: ${RED}[WAIT] %s${NC}\033[K\n" "$port" "$status_code"
            fi
        done
        cur_time=$(date +%s)
        if [ "$all_ready" = "true" ]; then
            echo "All services are ready! [$((cur_time-start_time))s]"
            break
        else
            echo "Waiting for services... [$((cur_time-start_time))s]"
            printf "\033[%dA" "$total_lines" # roll cursor back up for redraw
            sleep 1
        fi
    done
}
get_free_ports() {
    # Print $1 (default 1) free TCP ports chosen from [$2, $3) (default
    # 8000..9000), comma-separated on stdout. Returns 1 on bad arguments.
    local free_ports_num=${1:-1}
    local start_port=${2:-8000}
    local end_port=${3:-9000}
    local free_ports=()
    if [[ ! -n ${free_ports_num} || "${free_ports_num}" -le 0 ]]; then
        log_warn "param can't be empty, and should > 0"
        echo ${free_ports[@]}
        return 1
    fi
    # Ports seen as local ($4) or remote ($5) endpoints in netstat output.
    local used_ports1=$(netstat -an | grep -E "(0.0.0.0|127.0.0.1|${POD_IP}|tcp6)" | awk '{n=split($4,a,":"); if(a[n]~/^[0-9]+$/) print a[n];}' | sort -u)
    local used_ports2=$(netstat -an | grep -E "(0.0.0.0|127.0.0.1|${POD_IP}|tcp6)" | awk '{n=split($5,a,":"); if(a[n]~/^[0-9]+$/) print a[n];}' | sort -u)
    local all_used_ports=$(printf "%s\n" "${used_ports1}" "${used_ports2}" | sort -u)
    # Start scanning from a random offset inside the range.
    local port=$(( RANDOM % (end_port - start_port + 1) + start_port ))
    while true; do
        (( port++ ))
        if [[ ${port} -ge ${end_port} ]]; then
            port=${start_port}
        fi
        # Exact whole-line match. The previous substring test
        # ([[ "${all_used_ports[@]}" =~ "${port}" ]]) wrongly skipped e.g.
        # port 80 whenever 8080 was in use.
        if printf '%s\n' "${all_used_ports}" | grep -qx "${port}"; then
            continue
        fi
        if is_port_free ${port}; then
            free_ports+=("${port}")
            (( free_ports_num-- ))
            if [[ ${free_ports_num} = 0 ]]; then
                break
            fi
        fi
    done
    # Join with commas in a subshell so the caller's IFS is not clobbered
    # (the old `IFS=',' && echo` leaked the changed IFS to the caller).
    (IFS=','; echo "${free_ports[*]}")
    return 0
}

View File

@@ -41,7 +41,7 @@ from fastdeploy.inter_communicator import (
)
from fastdeploy.utils import envs, get_logger
logger = get_logger("cache_messager", "cache_messager.log")
# logger = get_logger("cache_messager", "cache_messager.log")
def parse_args():
@@ -552,6 +552,7 @@ class CacheMessagerV1:
cache_info = self.engine_worker_queue.get_cache_info()
finished_add_cache_task_req_ids = []
if cache_info:
logger.debug(f"Get cache info from engine worker queue, {cache_info}")
self.engine_worker_queue.cache_info_barrier.wait()
for info in cache_info:
if info["request_id"] in self.cache_info:
@@ -570,14 +571,15 @@ class CacheMessagerV1:
current_info["sended_layer_id"] = -1
current_info["sended_block_num"] = current_info["decode_cached_tokens"] // self.block_size
current_info["status"] = "init"
logger.info(f"Get cache info from P: finish add cache task: {current_info}")
logger.info(f"Get cache info from D: finish add cache task: {current_info}")
self.cache_info[info["request_id"]] = current_info
self.idx_cache_task_dict[current_info["current_id"]] = current_info
else:
logger.info(f"Get cache info from D: {info}")
logger.info(f"Get cache info from P: {info}")
self.cache_info[info["request_id"]] = info
if finished_add_cache_task_req_ids:
logger.info(f"Put processed tasks into engine worker queue: {finished_add_cache_task_req_ids}")
self.engine_worker_queue.put_finished_add_cache_task_req(finished_add_cache_task_req_ids)
self.engine_worker_queue.finish_add_cache_task_barrier.wait()
else:
@@ -671,7 +673,7 @@ class CacheMessagerV1:
target_ip, target_id, decode_tp_size
)
if status:
logger.info(f"connect to {target_ip}:{target_id} success")
logger.debug(f"connect to {target_ip}:{target_id} success")
else:
logger.error(f"connect to {target_ip}:{target_id} failed")
task["status"] = "connection error"
@@ -722,7 +724,7 @@ class CacheMessagerV1:
if "error" not in task["status"]:
task["status"] = "finished"
logger.info(
f"finish write cache for all layers, req_id: {req_id}, block_id_end {block_id_end} need_prefill_tokens {task['need_prefill_tokens']}"
f"Finish write cache for all layers, req_id: {req_id}, block_id_end {block_id_end} need_prefill_tokens {task['need_prefill_tokens']}"
)
else:
task["sended_layer_id"] = -1
@@ -736,7 +738,9 @@ class CacheMessagerV1:
self.messager["ipc"].write_block_by_sync(target_id)
self.engine_worker_queue.finish_send_cache_barrier.wait()
self.engine_worker_queue.put_finished_req([[task["request_id"], task["status"]]])
logger.info(f"put write cache {task['request_id']}, status {task['status']}")
logger.info(
f"Put successful cache writing task in engine worker queue, req_id: {task['request_id']}, status: {task['status']}"
)
self.engine_cache_tasks[task["current_id"]] = dict()
del self.cache_info[task["request_id"]]
del self.idx_cache_task_dict[task["current_id"]]
@@ -928,7 +932,8 @@ if __name__ == "__main__":
args = parse_args()
rank_id = args.rank + args.local_data_parallel_id * args.mp_num
logger = get_logger("cache_messager", f"cache_messager_rank{rank_id}.log")
logger = get_logger("cache_messager", f"cache_messager_tprank{args.rank}.log")
logger.info("create cache messager...")
logger.info(f"{args}")
main()

View File

@@ -740,6 +740,6 @@ if __name__ == "__main__":
args = parse_args()
rank_id = args.rank + args.local_data_parallel_id * args.mp_num
logger = get_logger("cache_transfer_manager", f"cache_transfer_manager_rank{rank_id}.log")
logger = get_logger("cache_transfer_manager", f"cache_transfer_manager_tprank{args.rank}.log")
set_device(args.device_id)
main()

View File

@@ -280,7 +280,7 @@ class PrefixCacheManager:
+ f" --rdma_port {cache_config.rdma_comm_ports[i] if cache_config.rdma_comm_ports is not None else '0'}"
+ f" --speculative_config '{self.speculative_config.to_json_string()}'"
+ (" --create_cache_tensor" if create_cache_tensor else "")
+ f" >{log_dir}/launch_cache_transfer_manager_{int(device_ids[i])}.log 2>&1"
+ f" >{log_dir}/launch_cache_transfer_manager_tprank{i}.log 2>&1"
)
logger.info(f"Launch cache transfer manager, command:{launch_cmd}")
cache_manager_processes.append(subprocess.Popen(launch_cmd, shell=True, preexec_fn=os.setsid))
@@ -372,7 +372,7 @@ class PrefixCacheManager:
+ f" --engine_pid {pid_suffix}"
+ f" --rdma_port {cache_config.rdma_comm_ports[i] if cache_config.rdma_comm_ports is not None else '0'}"
+ f" --speculative_config '{self.speculative_config.to_json_string()}'"
+ f" >{log_dir}/launch_cache_messager_{int(device_ids[i])}.log 2>&1"
+ f" >{log_dir}/launch_cache_messager_tprank{i}.log 2>&1"
)
logger.info(f"Launch cache messager, command:{launch_cmd}")
cache_messager_processes.append(subprocess.Popen(launch_cmd, shell=True, preexec_fn=os.setsid))

View File

@@ -545,6 +545,7 @@ class ParallelConfig:
self.tensor_parallel_size = 1 # TP degree
self.expert_parallel_rank = 0 # EP rank ID
self.expert_parallel_size = 1 # EP degree
self.data_parallel_rank = 0 # DP rank ID
self.data_parallel_size = 1 # DP degree
self.enable_expert_parallel = False
self.enable_chunked_moe = False
@@ -1887,7 +1888,11 @@ class FDConfig:
engine_worker_queue_port = self.parallel_config.engine_worker_queue_port[
self.parallel_config.local_data_parallel_id
]
connector_port = self.cache_config.pd_comm_port[0] if self.cache_config.pd_comm_port else None
connector_port = (
self.cache_config.pd_comm_port[self.parallel_config.local_data_parallel_id]
if self.cache_config.pd_comm_port
else None
)
self.disaggregate_info = {}
if self.scheduler_config.splitwise_role != "mixed":

View File

@@ -82,9 +82,9 @@ class EngineService:
self.cfg.cache_config.cache_queue_port[self.cfg.parallel_config.local_data_parallel_id]
)
if self.cfg.parallel_config.enable_expert_parallel:
if self.cfg.parallel_config.data_parallel_size > 1:
self.llm_logger = get_logger(
"fastdeploy", f"fastdeploy_rank{self.cfg.parallel_config.local_data_parallel_id}.log"
"fastdeploy", f"fastdeploy_dprank{self.cfg.parallel_config.local_data_parallel_id}.log"
)
else:
self.llm_logger = llm_logger
@@ -716,7 +716,11 @@ class EngineService:
is_fetching = False
return
self.llm_logger.debug(f"get tasks from {type(self.scheduler)}: {tasks}")
if tasks:
self.llm_logger.debug(
f"Engine has fetched tasks from {self.scheduler.__class__.__name__}: {[task.request_id for task in tasks]}"
)
if self.cfg.scheduler_config.splitwise_role != "mixed":
need_delete_tasks = []
if envs.FD_OFFLINE_PERF_TEST_FOR_PD:
@@ -724,22 +728,24 @@ class EngineService:
# assure can allocate block ids in P
while not self.resource_manager.preallocate_resource_in_p(task):
time.sleep(0.005)
self.llm_logger.info(f"ask D resource for req_id: {task.request_id}")
self.llm_logger.debug(f"P has allocated resources for request: {task.request_id}")
while True:
self.split_connector.send_splitwise_tasks([task], task.idx)
status, msg = self.split_connector.check_decode_allocated(task)
if not status:
self.llm_logger.error(f"{task.request_id} ask D resource failed, try again.")
self.llm_logger.error(
f"D failed to allocate resource for request {task.request_id}, try again."
)
time.sleep(0.05)
else:
break
self.llm_logger.debug(f"D has allocated resource for request: {task.request_id}")
else:
for task in tasks:
# assure can allocate block ids in P
while not self.resource_manager.preallocate_resource_in_p(task):
self.llm_logger.info("wait for preallocate_resource_in_p")
time.sleep(0.005)
self.llm_logger.info(f"ask D resource for req_id: {task.request_id}")
self.llm_logger.debug(f"P has allocated resources for request: {task.request_id}")
self.split_connector.send_splitwise_tasks([task], task.idx)
for task in tasks:
@@ -747,7 +753,9 @@ class EngineService:
# assure fetch block ids from D
status, msg = self.split_connector.check_decode_allocated(task)
if not status:
self.llm_logger.error(f"{task.request_id} prefill failed with msg:{msg}.")
self.llm_logger.error(
f"D failed to allocate resource for request {task.request_id}, message: {msg}."
)
self.scheduler.put_results(
[
RequestOutput(
@@ -760,25 +768,32 @@ class EngineService:
)
need_delete_tasks.append(task)
continue
else:
self.llm_logger.debug(f"D has allocated resource for request: {task.request_id}")
for tmp_task in need_delete_tasks:
tasks.remove(tmp_task)
# release resource in P
self.resource_manager.pre_recycle_resource(tmp_task.request_id)
if self.cfg.scheduler_config.splitwise_role == "prefill":
# to send cache info to cache messager
if tasks:
need_check_req_ids = [task.request_id for task in tasks]
self.split_connector.send_cache_info_to_messager(tasks, 0)
# ensure cache tasks has sent to cache_messager
need_check_req_ids = [task.request_id for task in tasks]
while need_check_req_ids:
req_ids = self.engine_worker_queue.get_finished_add_cache_task_req()
self.llm_logger.info(f"get_finished_add_cache_task_req: {req_ids}")
if req_ids:
self.llm_logger.debug(
f"P has successfully sent cache infos to cache messager for requests: {req_ids}"
)
for req_id in req_ids:
assert req_id in need_check_req_ids
need_check_req_ids.remove(req_id)
else:
time.sleep(0.001)
# Fetch requests and add them to the scheduling queue
if tasks:
for task in tasks:
@@ -787,6 +802,9 @@ class EngineService:
)
if self.cfg.scheduler_config.splitwise_role == "prefill":
self.resource_manager.add_request_in_p(tasks)
self.llm_logger.info(
f"P add requests into running queue: {[task.request_id for task in tasks]}"
)
else:
for task in tasks:
self.resource_manager.add_request(task)
@@ -917,7 +935,6 @@ class EngineService:
request.llm_engine_recv_req_timestamp = time.time()
start_span("ENQUEUE_ZMQ", data, trace.SpanKind.PRODUCER)
main_process_metrics.requests_number.inc()
self.llm_logger.debug(f"Receive request: {request}")
trace_print(LoggingEventName.PREPROCESSING_END, data["request_id"], data.get("user", ""))
trace_print(LoggingEventName.REQUEST_SCHEDULE_START, data["request_id"], data.get("user", ""))
trace_print(LoggingEventName.REQUEST_QUEUE_START, data["request_id"], data.get("user", ""))
@@ -1082,10 +1099,14 @@ class EngineService:
for item in items:
tasks = item[1]
if isinstance(tasks[0], Request):
self.llm_logger.debug(f"receive tasks to preallocate resource, {tasks}")
self.llm_logger.debug(
f"D has received tasks to preallocate resource for tasks: {[task.request_id for task in tasks]}"
)
allocate_resource_requests.extend(tasks)
elif isinstance(tasks[0], RequestOutput):
self.llm_logger.debug(f"receive prefilled tasks, {tasks}")
self.llm_logger.debug(
f"D has received tasks to process prefilled tasks: {[task.request_id for task in tasks]}"
)
if not isinstance(tasks, list):
tasks = [tasks]
for task in tasks:
@@ -1099,13 +1120,13 @@ class EngineService:
if envs.ENABLE_V1_KVCACHE_SCHEDULER:
if self.resource_manager.preallocate_resource_in_d(task):
self.llm_logger.info(f"Resource available, processing task {task.request_id}")
self.split_connector.send_cache_info_to_prefill([task])
self.llm_logger.debug(f"D has successfully sent cache infos for task {task.request_id}")
processed_indices.append(idx)
is_success = True
else:
if self.resource_manager.is_resource_sufficient(task.prompt_token_ids_len):
self.llm_logger.info(f"Resource available, processing task {task.request_id}")
self.llm_logger.debug(f"D Resource available, processing task {task.request_id}")
self.insert_tasks([task])
processed_indices.append(idx)
is_success = True
@@ -1114,6 +1135,7 @@ class EngineService:
if not self.enable_decode_cache_task:
task.error_msg = "Not enough resources"
self.split_connector.send_cache_info_to_prefill([task])
self.llm_logger.warning(f"D has failed to send cache infos for task {task.request_id}")
processed_indices.append(idx)
else:
self.llm_logger.debug(f"Still waiting for resources {task.request_id}")
@@ -1169,7 +1191,7 @@ class EngineService:
if envs.FD_ENABLE_INTERNAL_ADAPTER: # first token sent by D instance
self.scheduler.put_results([req_output])
self.resource_manager.add_prefilled_request(req_output)
self.llm_logger.debug(f"add prefilled request success, {request_id}")
self.llm_logger.info(f"D has successfully added prefilled request, {request_id}")
def decode_loop():
while self.running:

View File

@@ -61,7 +61,6 @@ class ExpertService:
]
self.cfg.local_device_ids = self.cfg.parallel_config.device_ids.split(",")[start_pos:end_pos]
llm_logger.info(f"local_data_parallel_id: {local_data_parallel_id}")
self.cfg.disaggregate_info = None
if self.cfg.cache_config.num_gpu_blocks_override is None:
self.do_profile = True
@@ -127,17 +126,18 @@ class ExpertService:
local_rank = local_data_parallel_id % self.cfg.worker_num_per_node
if not envs.FD_ENABLE_MULTI_API_SERVER:
launched_expert_service_signal_data = np.zeros(
shape=[self.cfg.parallel_config.data_parallel_size // self.cfg.nnode], dtype=np.int32
)
self.launched_expert_service_signal = IPCSignal(
name="launched_expert_service_signal",
array=launched_expert_service_signal_data,
dtype=np.int32,
suffix=ipc_signal_suffix,
create=False,
)
self.launched_expert_service_signal.value[local_rank] = 1
if self.cfg.parallel_config.enable_expert_parallel:
launched_expert_service_signal_data = np.zeros(
shape=[self.cfg.parallel_config.data_parallel_size // self.cfg.nnode], dtype=np.int32
)
self.launched_expert_service_signal = IPCSignal(
name="launched_expert_service_signal",
array=launched_expert_service_signal_data,
dtype=np.int32,
suffix=ipc_signal_suffix,
create=False,
)
self.launched_expert_service_signal.value[local_rank] = 1
if self.do_profile:
get_profile_block_num = np.zeros([1], dtype=np.int32)
while True:

View File

@@ -496,6 +496,8 @@ class EngineWorkerQueue:
self.tasks.append(tasks)
self.lock.release()
llm_logger.debug(f"put_tasks: tasks={tasks}")
def get_tasks(self) -> Tuple[List[Any], bool]:
"""
Retrieve tasks from the shared queue and update read status.
@@ -512,6 +514,7 @@ class EngineWorkerQueue:
if all_client_read:
self.tasks[:] = list()
self.lock.release()
llm_logger.debug(f"get_tasks: tasks={tasks}")
return tasks, all_client_read
def num_tasks(self) -> int:
@@ -600,8 +603,7 @@ class EngineWorkerQueue:
self.cache_infos.extend(cache_info)
llm_logger.debug(
f"put cache_infos to engine worker queue: {self.cache_infos}, "
f"local_data_parallel_id:{self.local_data_parallel_id}"
f"put_cache_info: cache_info={cache_info}, local_data_parallel_id={self.local_data_parallel_id}"
)
self.lock_info.release()

View File

@@ -214,6 +214,9 @@ class ZmqServerBase(ABC):
except zmq.Again:
time.sleep(0.001)
continue
except zmq.error.ZMQError as e:
llm_logger.error(f"recv_result_handle get zmq error: {e}")
break
except Exception as e:
llm_logger.error(f"recv_result_handle get unknown exception: {e}")
continue

View File

@@ -402,12 +402,8 @@ class TokenProcessor:
rank_id,
is_blocking,
)
elif (
self.cfg.parallel_config.enable_expert_parallel
and self.cfg.parallel_config.data_parallel_size > 1
):
elif self.cfg.parallel_config.data_parallel_size > 1:
get_output_ep(self.output_tokens, rank_id, is_blocking)
else:
get_output(self.output_tokens, rank_id, is_blocking)

View File

@@ -91,6 +91,7 @@ class Router:
self.prefill_servers = []
self.decode_servers = []
self.lock = asyncio.Lock() # async-safe lock
logger.info("Router started at http://{}:{}".format(self.host, self.port))
async def register_instance(self, instance_info_dict: dict):
"""Register an instance asynchronously"""
@@ -172,6 +173,8 @@ class Router:
async def handle_splitwise_request(self, request_data: dict, endpoint_name: str):
logger.debug(f"Received request: {request_data}")
prefill_server, decode_server = await self.select_pd()
logger.debug(f"Selected prefill server: {prefill_server}")
logger.debug(f"Selected decode server: {decode_server}")
if prefill_server.tp_size != decode_server.tp_size and decode_server.tp_size != 1:
raise HTTPException(
@@ -371,4 +374,4 @@ def launch_router(router_args: RouterArgs):
app.state.router = Router(app.state.router_args)
asyncio.create_task(app.state.router.monitor_instance_health(interval_secs=5))
uvicorn.run(app, host=router_args.host, port=router_args.port)
uvicorn.run(app, host=router_args.host, port=int(router_args.port))

View File

@@ -44,9 +44,10 @@ class SplitwiseConnector:
resource_manager (object): Resource manager object.
"""
self.cfg = cfg
if self.cfg.parallel_config.enable_expert_parallel and self.cfg.parallel_config.data_parallel_size > 1:
self.local_data_parallel_id = self.cfg.parallel_config.local_data_parallel_id
if self.cfg.parallel_config.data_parallel_size > 1:
self.logger = get_logger(
"splitwise_connector", f"splitwise_connector_{self.cfg.parallel_config.local_data_parallel_id}.log"
"splitwise_connector", f"splitwise_connector_dprank{self.local_data_parallel_id}.log"
)
else:
self.logger = get_logger("splitwise_connector", "splitwise_connector.log")
@@ -54,7 +55,6 @@ class SplitwiseConnector:
self.resource_manager = resource_manager
self.connect_innode_instances = {}
self.current_request_ids = dict()
self.idx = self.cfg.parallel_config.local_data_parallel_id
self.enable_decode_cache_task = envs.FD_ENABLE_CACHE_TASK == "1"
if self.cfg.cache_config.pd_comm_port is not None:
@@ -74,7 +74,7 @@ class SplitwiseConnector:
self.router_socket.setsockopt(zmq.SNDHWM, 1000)
self.router_socket.setsockopt(zmq.ROUTER_MANDATORY, 1)
self.router_socket.bind(f"tcp://*:{self.cfg.cache_config.pd_comm_port[0]}")
self.logger.info(f"bind {self.cfg.cache_config.pd_comm_port}")
self.logger.info(f"_init_network: bind {self.cfg.cache_config.pd_comm_port}")
self.poller = zmq.Poller()
self.poller.register(self.router_socket, zmq.POLLIN)
@@ -94,17 +94,17 @@ class SplitwiseConnector:
if not socks:
continue
else:
self.logger.debug(f"receive {socks}")
self.logger.debug(f"start_receiver: receive {socks}")
frames = self.router_socket.recv_multipart()
self.logger.debug(f"frames: {frames}")
self.logger.debug(f"start_receiver: frames: {frames}")
message = frames[-1]
self.io_executor.submit(self._process_message, message)
time.sleep(0.001)
else:
time.sleep(5)
except Exception as e:
self.logger.error(f"Receiver error: {e}, {str(traceback.format_exc())}")
self.logger.error(f"start_receiver: Receiver error: {e}, {str(traceback.format_exc())}")
time.sleep(1)
def _get_push_socket(self, addr):
@@ -116,7 +116,7 @@ class SplitwiseConnector:
return sock
try:
self.logger.info(f"Establishing new connection to {addr}")
self.logger.info(f"_get_push_socket: Establishing new connection to {addr}")
sock = self.zmq_ctx.socket(zmq.DEALER)
# 设置连接参数
@@ -135,36 +135,29 @@ class SplitwiseConnector:
return sock
except zmq.ZMQError as e:
self.logger.error(f"Connection to {addr} failed: {e}")
self.logger.error(f"_get_push_socket: Connection to {addr} failed: {e}")
raise ConnectionError(f"Failed to connect to {addr}") from e
def _send_message(self, addr, msg_type: str, payload):
if not addr:
return
try:
self.logger.info(f"Sent {msg_type} to {addr}")
message = self._serialize_message(msg_type, payload)
try:
self.logger.info(f"_send_message: msg_type={msg_type} addr={addr}")
sock = self._get_push_socket(addr)
sock.send_multipart([b"", message])
self.logger.info(f"Sent {msg_type} to {addr}")
except ConnectionError:
self.logger.warning(f"Connection to {addr} not established")
self.logger.warning(f"_send_message: Connection to {addr} not established")
except zmq.Again:
self.logger.warning(f"Send queue full for {addr}")
self.logger.warning(f"_send_message: Send queue full for {addr}")
except Exception as e:
self.logger.error(f"Send to {addr} failed: {e}, {str(traceback.format_exc())}")
self.logger.error(f"_send_message: Send to {addr} failed: {e}, {str(traceback.format_exc())}")
main_process_metrics.send_cache_failed_num.inc()
self._close_connection(addr)
except Exception as e:
self.logger.error(f"Message preparation failed: {e}")
self.logger.error(f"_send_message: Message preparation failed: {e}")
def _close_connection(self, addr):
"""
@@ -191,21 +184,20 @@ class SplitwiseConnector:
if task.disaggregate_info["transfer_protocol"] == "ipc":
addr = task.disaggregate_info["cache_info"]["ipc"]["port"]
task.disaggregate_info["cache_info"]["ipc"]["current_id"] = current_id
self.logger.info(f"send_splitwise_tasks: protocol=ipc, addr={addr}, task={task.request_id}")
self.send_splitwise_tasks_innode([task], addr)
else:
addr = (
f"{task.disaggregate_info['cache_info']['rdma']['ip']}:"
+ f"{task.disaggregate_info['cache_info']['rdma']['port']}"
)
self.logger.info(f"send splitwise tasks to port {addr} decode, {task.request_id}")
self.current_request_ids[task.request_id] = "init"
decode_diagg = task.disaggregate_info["cache_info"]
task.disaggregate_info["cache_info"] = self.cfg.disaggregate_info["cache_info"]
task.disaggregate_info["cache_info"]["rdma"]["current_id"] = current_id
task.disaggregate_info["role"] = "decode"
self.logger.debug(f"send task to coupled instance, {addr}, {task}")
self.logger.info(f"send_splitwise_tasks: protocol=rdma, addr={addr}, task={task.request_id}")
self._send_message(addr, "prefill", [task])
task.disaggregate_info["cache_info"] = decode_diagg
task.disaggregate_info["role"] = "prefill"
@@ -226,12 +218,12 @@ class SplitwiseConnector:
self.create_connection(port)
for task in tasks:
task.disaggregate_info["cache_info"]["ipc"]["port"] = self.cfg.parallel_config.engine_worker_queue_port[
self.idx
self.local_data_parallel_id
]
self.logger.info(f"send_splitwise_tasks_innode: port={port}, tasks={[task.request_id for task in tasks]}")
self.connect_innode_instances[port].put_disaggregated_tasks(("decode", tasks))
for task in tasks:
task.disaggregate_info["cache_info"]["ipc"]["port"] = port
self.logger.info(f"send splitwise tasks to port {port} decode")
current_port = port
return current_port
@@ -241,7 +233,7 @@ class SplitwiseConnector:
"""
if not isinstance(tasks_list, list):
tasks_list = [tasks_list]
self.logger.info(f"send first token to decode, {[x.request_id for x in tasks_list]}")
self.logger.info(f"send_first_token: send first token to decode, {[x.request_id for x in tasks_list]}")
if prefill_msg["transfer_protocol"] == "ipc":
port = prefill_msg["cache_info"]["ipc"]["port"]
if port not in self.connect_innode_instances:
@@ -249,7 +241,7 @@ class SplitwiseConnector:
self.connect_innode_instances[port].put_disaggregated_tasks(("decode", tasks_list))
else:
node = f"{prefill_msg['cache_info']['rdma']['ip']}:{prefill_msg['cache_info']['rdma']['port']}"
self.logger.info(f"send first token to port {node} decode")
self.logger.info(f"send_first_token: send first token to port {node} decode")
self._send_message(node, "decode", tasks_list)
def create_connection(self, port):
@@ -288,7 +280,7 @@ class SplitwiseConnector:
del self.current_request_ids[task.request_id]
if msg == "finished":
return True, ""
self.logger.error(f"Receive_decode_allocated error: {msg}")
self.logger.error(f"check_decode_allocated: Receive_decode_allocated error: {msg}")
return False, msg
def send_cache_info_to_messager(self, tasks: List[Request], current_id):
@@ -359,9 +351,11 @@ class SplitwiseConnector:
else:
info = {
"request_id": tasks[i].request_id,
"device_ids": self.cfg.parallel_config.device_ids.split(","),
"device_ids": [self.cfg.parallel_config.device_ids.split(",")[self.local_data_parallel_id]],
"ip": self.cfg.host_ip,
"rdma_ports": self.cfg.disaggregate_info["cache_info"]["rdma"]["rdma_port"],
"rdma_ports": [
self.cfg.disaggregate_info["cache_info"]["rdma"]["rdma_port"][self.local_data_parallel_id]
],
"transfer_protocol": "rdma",
"dest_block_ids": dsg_info["block_tables"],
"decode_tp_size": self.cfg.parallel_config.tensor_parallel_size,
@@ -404,7 +398,7 @@ class SplitwiseConnector:
"""
try:
msg_type, payload = self._deserialize_message(message)
self.logger.info(f"{msg_type}")
self.logger.info(f"_process_message: {msg_type}")
if msg_type == "prefill":
self._handle_prefill(payload)
@@ -412,7 +406,7 @@ class SplitwiseConnector:
self._handle_decode(payload)
elif msg_type == "cache_sync":
for task in payload:
self.logger.info(f"cache_sync task: {task}")
self.logger.info(f"_process_message: cache_sync task: {task}")
current_status = task.get("error_msg", "finished")
self.current_request_ids[task["request_id"]] = current_status
if self.enable_decode_cache_task:
@@ -421,13 +415,13 @@ class SplitwiseConnector:
self.engine_worker_queue.put_cache_info(payload)
except Exception as e:
self.logger.error(f"Message processing failed: {e}, {str(traceback.format_exc())}")
self.logger.error(f"_process_message: Message processing failed: {e}, {str(traceback.format_exc())}")
def _handle_prefill(self, tasks):
"""
Handle prefill tasks from other nodes.
"""
self.logger.debug(f"_handle_prefill function receive {tasks}")
self.logger.debug(f"_handle_prefill: receive payload {tasks}")
tasks_data = [Request.from_dict(task) for task in tasks]
self.engine_worker_queue.put_disaggregated_tasks(("decode", tasks_data))
@@ -435,7 +429,7 @@ class SplitwiseConnector:
"""
Handle decode tasks from other nodes.
"""
self.logger.debug(f"_handle_decode function receive {payload}")
self.logger.debug(f"_handle_decode: receive payload {payload}")
tasks = []
for task in payload:
tasks.append(RequestOutput.from_dict(task))

View File

@@ -173,7 +173,11 @@ class PaddleDisWorkerProc:
model_weights_status:
"""
self.max_chips_per_node = 16 if current_platform.is_iluvatar() else 8
if self.parallel_config.data_parallel_size > 1 and not envs.FD_ENABLE_MULTI_API_SERVER:
if (
self.parallel_config.enable_expert_parallel
and self.parallel_config.data_parallel_size > 1
and not envs.FD_ENABLE_MULTI_API_SERVER
):
launched_expert_service_signal_data = np.zeros(
shape=[self.parallel_config.data_parallel_size // self.fd_config.nnode], dtype=np.int32
)
@@ -905,6 +909,12 @@ def initialize_fd_config(args, ranks: int = 1, local_rank: int = 0) -> FDConfig:
parallel_config.tensor_parallel_rank = local_rank % parallel_config.tensor_parallel_size
parallel_config.data_parallel_rank = local_rank // parallel_config.tensor_parallel_size
# config for DP
if parallel_config.data_parallel_size > 1:
max_chips_per_node = 16 if current_platform.is_iluvatar() else 8
parallel_config.local_data_parallel_id = parallel_config.data_parallel_rank % (
max_chips_per_node // parallel_config.tensor_parallel_size
)
# config for EP
if parallel_config.expert_parallel_size > 1:
expert_parallel_rank = int(local_rank % parallel_config.expert_parallel_size)
@@ -914,11 +924,6 @@ def initialize_fd_config(args, ranks: int = 1, local_rank: int = 0) -> FDConfig:
num_experts = model_config.moe_num_experts
num_experts_per_rank = num_experts // parallel_config.expert_parallel_size
num_experts_start_offset = expert_parallel_rank * num_experts_per_rank
max_chips_per_node = 16 if current_platform.is_iluvatar() else 8
parallel_config.local_data_parallel_id = parallel_config.data_parallel_rank % (
max_chips_per_node // parallel_config.tensor_parallel_size
)
parallel_config.expert_parallel_rank = expert_parallel_rank
parallel_config.num_experts_per_rank = num_experts_per_rank
parallel_config.num_experts_start_offset = num_experts_start_offset