mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-12-24 13:28:13 +08:00
[PD Disaggregation] support DP via v1 router and decouple DP and EP (#5197)
* [fix] support DP via v1 router and decouple DP and EP * [fix] fix scripts * [fix] reset model path * [fix] dp use get_output_ep, fix router port type, update scripts * [merge] merge with latest code * [chore] remove some debug log * [fix] fix code style check * [fix] fix test_multi_api_server for log_dir name * [chore] reduce logs * Apply suggestions from code review Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --------- Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
@@ -40,8 +40,7 @@ void GetOutput(const paddle::Tensor& x,
|
||||
static struct msgdata msg_rcv;
|
||||
if (const char* inference_msg_queue_id_env_p =
|
||||
std::getenv("INFERENCE_MSG_QUEUE_ID")) {
|
||||
std::string inference_msg_queue_id_env_str(
|
||||
inference_msg_queue_id_env_p);
|
||||
std::string inference_msg_queue_id_env_str(inference_msg_queue_id_env_p);
|
||||
int inference_msg_queue_id_from_env =
|
||||
std::stoi(inference_msg_queue_id_env_str);
|
||||
#ifdef GET_OUTPUT_DEBUG
|
||||
|
||||
141
examples/splitwise/start_v1_dp2.sh
Normal file
141
examples/splitwise/start_v1_dp2.sh
Normal file
@@ -0,0 +1,141 @@
|
||||
#!/bin/bash
|
||||
set -e
|
||||
|
||||
# Test splitwise deployment
|
||||
# There are two methods for splitwise deployment:
|
||||
# v0: using splitwise_scheduler or dp_scheduler
|
||||
# v1: using local_scheduler + router
|
||||
|
||||
MODEL_NAME="PaddlePaddle/ERNIE-4.5-0.3B-Paddle"
|
||||
DATA_PARALLEL_SIZE=2
|
||||
TENSOR_PARALLEL_SIZE=1
|
||||
NUM_GPUS=$(($DATA_PARALLEL_SIZE * $TENSOR_PARALLEL_SIZE))
|
||||
LOG_DATE=$(date +%Y%m%d_%H%M%S)
|
||||
|
||||
export FD_DEBUG=1
|
||||
export ENABLE_V1_KVCACHE_SCHEDULER=1
|
||||
export KVCACHE_GDRCOPY_FLUSH_ENABLE=1
|
||||
export FD_ENABLE_MULTI_API_SERVER=1
|
||||
|
||||
SCRIPT_PATH=$(readlink -f "$0")
|
||||
SCRIPT_DIR=$(dirname "$SCRIPT_PATH")
|
||||
export $(bash ${SCRIPT_DIR}/../../scripts/get_rdma_nics.sh gpu)
|
||||
echo "KVCACHE_RDMA_NICS:${KVCACHE_RDMA_NICS}"
|
||||
if [ -z "${KVCACHE_RDMA_NICS}" ]; then
|
||||
echo "KVCACHE_RDMA_NICS is empty, please check the output of get_rdma_nics.sh"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
unset http_proxy && unset https_proxy
|
||||
source ${SCRIPT_DIR}/utils.sh
|
||||
|
||||
# start router
|
||||
ROUTER_PORT=$(get_free_ports 1)
|
||||
echo "---------------------------"
|
||||
echo ROUTER_PORT: $ROUTER_PORT
|
||||
|
||||
export FD_LOG_DIR="log/$LOG_DATE/router"
|
||||
rm -rf $FD_LOG_DIR
|
||||
mkdir -p ${FD_LOG_DIR}
|
||||
|
||||
nohup python -m fastdeploy.router.launch \
|
||||
--port ${ROUTER_PORT} \
|
||||
--splitwise \
|
||||
2>&1 >${FD_LOG_DIR}/nohup &
|
||||
sleep 1
|
||||
|
||||
|
||||
# start prefill
|
||||
P_SERVER_PORTS=$(get_free_ports $DATA_PARALLEL_SIZE)
|
||||
P_METRICS_PORTS=$(get_free_ports $DATA_PARALLEL_SIZE)
|
||||
P_ENGINE_WORKER_QUEUE_PORTS=$(get_free_ports $DATA_PARALLEL_SIZE)
|
||||
P_CACHE_QUEUE_PORTS=$(get_free_ports $DATA_PARALLEL_SIZE)
|
||||
P_RDMA_COMM_PORTS=$(get_free_ports $NUM_GPUS)
|
||||
P_PD_COMM_PORTS=$(get_free_ports $DATA_PARALLEL_SIZE)
|
||||
echo "---------------------------"
|
||||
echo P_SERVER_PORTS: $P_SERVER_PORTS
|
||||
echo P_METRICS_PORTS: $P_METRICS_PORTS
|
||||
echo P_ENGINE_WORKER_QUEUE_PORTS: $P_ENGINE_WORKER_QUEUE_PORTS
|
||||
echo P_CACHE_QUEUE_PORTS: $P_CACHE_QUEUE_PORTS
|
||||
echo P_RDMA_COMM_PORTS: $P_RDMA_COMM_PORTS
|
||||
echo P_PD_COMM_PORTS: $P_PD_COMM_PORTS
|
||||
|
||||
export CUDA_VISIBLE_DEVICES="0,1"
|
||||
export FD_LOG_DIR="log/$LOG_DATE/prefill"
|
||||
rm -rf $FD_LOG_DIR
|
||||
mkdir -p ${FD_LOG_DIR}
|
||||
|
||||
nohup python -m fastdeploy.entrypoints.openai.multi_api_server \
|
||||
--num-servers ${DATA_PARALLEL_SIZE}\
|
||||
--ports ${P_SERVER_PORTS} \
|
||||
--metrics-port ${P_METRICS_PORTS} \
|
||||
--args --model ${MODEL_NAME} \
|
||||
--engine-worker-queue-port ${P_ENGINE_WORKER_QUEUE_PORTS} \
|
||||
--cache-queue-port ${P_CACHE_QUEUE_PORTS} \
|
||||
--max-model-len 32768 \
|
||||
--data-parallel-size ${DATA_PARALLEL_SIZE} \
|
||||
--tensor-parallel-size ${TENSOR_PARALLEL_SIZE} \
|
||||
--splitwise-role "prefill" \
|
||||
--cache-transfer-protocol "rdma" \
|
||||
--rdma-comm-ports ${P_RDMA_COMM_PORTS} \
|
||||
--pd-comm-port ${P_PD_COMM_PORTS} \
|
||||
--router "0.0.0.0:${ROUTER_PORT}" \
|
||||
2>&1 >${FD_LOG_DIR}/nohup &
|
||||
|
||||
echo "--- Health Check Status ---"
|
||||
wait_for_health ${P_SERVER_PORTS}
|
||||
|
||||
|
||||
# start decode
|
||||
D_SERVER_PORTS=$(get_free_ports $DATA_PARALLEL_SIZE)
|
||||
D_ENGINE_WORKER_QUEUE_PORTS=$(get_free_ports $DATA_PARALLEL_SIZE)
|
||||
D_CACHE_QUEUE_PORTS=$(get_free_ports $DATA_PARALLEL_SIZE)
|
||||
D_METRICS_PORTS=$(get_free_ports $DATA_PARALLEL_SIZE)
|
||||
D_RDMA_COMM_PORTS=$(get_free_ports $NUM_GPUS)
|
||||
D_PD_COMM_PORTS=$(get_free_ports $DATA_PARALLEL_SIZE)
|
||||
echo "---------------------------"
|
||||
echo D_SERVER_PORTS: $D_SERVER_PORTS
|
||||
echo D_ENGINE_WORKER_QUEUE_PORTS: $D_ENGINE_WORKER_QUEUE_PORTS
|
||||
echo D_CACHE_QUEUE_PORTS: $D_CACHE_QUEUE_PORTS
|
||||
echo D_METRICS_PORTS: $D_METRICS_PORTS
|
||||
echo D_RDMA_COMM_PORTS: $D_RDMA_COMM_PORTS
|
||||
echo D_PD_COMM_PORTS: $D_PD_COMM_PORTS
|
||||
|
||||
export CUDA_VISIBLE_DEVICES="2,3"
|
||||
export FD_LOG_DIR="log/$LOG_DATE/decode"
|
||||
rm -rf $FD_LOG_DIR
|
||||
mkdir -p ${FD_LOG_DIR}
|
||||
|
||||
nohup python -m fastdeploy.entrypoints.openai.multi_api_server \
|
||||
--num-servers ${DATA_PARALLEL_SIZE}\
|
||||
--ports ${D_SERVER_PORTS} \
|
||||
--metrics-port ${D_METRICS_PORTS} \
|
||||
--args --model ${MODEL_NAME} \
|
||||
--engine-worker-queue-port ${D_ENGINE_WORKER_QUEUE_PORTS} \
|
||||
--cache-queue-port ${D_CACHE_QUEUE_PORTS} \
|
||||
--max-model-len 32768 \
|
||||
--data-parallel-size ${DATA_PARALLEL_SIZE} \
|
||||
--tensor-parallel-size ${TENSOR_PARALLEL_SIZE} \
|
||||
--splitwise-role "decode" \
|
||||
--cache-transfer-protocol "rdma" \
|
||||
--rdma-comm-ports ${D_RDMA_COMM_PORTS} \
|
||||
--pd-comm-port ${D_PD_COMM_PORTS} \
|
||||
--router "0.0.0.0:${ROUTER_PORT}" \
|
||||
2>&1 >${FD_LOG_DIR}/nohup &
|
||||
|
||||
echo "--- Health Check Status ---"
|
||||
wait_for_health ${D_SERVER_PORTS}
|
||||
|
||||
|
||||
# send request
|
||||
echo "------ Request Check ------"
|
||||
sleep 10 # make sure server is registered to router
|
||||
curl -X POST "http://0.0.0.0:${ROUTER_PORT}/v1/chat/completions" \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"messages": [
|
||||
{"role": "user", "content": "hello"}
|
||||
],
|
||||
"max_tokens": 100,
|
||||
"stream": false
|
||||
}'
|
||||
@@ -1,8 +1,16 @@
|
||||
#!/bin/bash
|
||||
|
||||
is_port_free() {
|
||||
local port=$1
|
||||
if ss -ltn | awk '{print $4}' | grep -q ":${port}$"; then
|
||||
return 1 # Port is occupied
|
||||
fi
|
||||
return 0 # Port is free
|
||||
}
|
||||
|
||||
check_ports() {
|
||||
for port in "$@"; do
|
||||
if ss -tuln | grep -q ":$port "; then
|
||||
if ! is_port_free $port; then
|
||||
echo "❌ Port $port is already in use"
|
||||
return 1
|
||||
fi
|
||||
@@ -11,14 +19,79 @@ check_ports() {
|
||||
}
|
||||
|
||||
wait_for_health() {
|
||||
local server_port=$1
|
||||
IFS=',' read -r -a server_ports <<< "$1"
|
||||
local num_ports=${#server_ports[@]}
|
||||
local total_lines=$((num_ports + 1))
|
||||
local first_run=true
|
||||
local GREEN='\033[0;32m'
|
||||
local RED='\033[0;31m'
|
||||
local NC='\033[0m' # No Color
|
||||
local start_time=$(date +%s)
|
||||
|
||||
while true; do
|
||||
status_code=$(curl -s -o /dev/null -w "%{http_code}" "http://0.0.0.0:${server_port}/health" || echo "000")
|
||||
local all_ready=true
|
||||
for port in "${server_ports[@]}"; do
|
||||
status_code=$(curl -s --max-time 1 -o /dev/null -w "%{http_code}" "http://0.0.0.0:${port}/health" || echo "000")
|
||||
if [ "$status_code" -eq 200 ]; then
|
||||
printf "Port %s: ${GREEN}[OK] 200${NC}\033[K\n" "$port"
|
||||
else
|
||||
all_ready=false
|
||||
printf "Port %s: ${RED}[WAIT] %s${NC}\033[K\n" "$port" "$status_code"
|
||||
fi
|
||||
done
|
||||
cur_time=$(date +%s)
|
||||
if [ "$all_ready" = "true" ]; then
|
||||
echo "All services are ready! [$((cur_time-start_time))s]"
|
||||
break
|
||||
else
|
||||
echo "Service not ready. Retrying in 4s..."
|
||||
sleep 4
|
||||
echo "Waiting for services... [$((cur_time-start_time))s]"
|
||||
printf "\033[%dA" "$total_lines" # roll back cursor
|
||||
sleep 1
|
||||
fi
|
||||
done
|
||||
}
|
||||
|
||||
get_free_ports() {
|
||||
free_ports_num=${1:-1}
|
||||
start_port=${2:-8000}
|
||||
end_port=${3:-9000}
|
||||
|
||||
free_ports=()
|
||||
if [[ ! -n ${free_ports_num} || "${free_ports_num}" -le 0 ]]; then
|
||||
log_warn "param can't be empty, and should > 0"
|
||||
echo ${free_ports[@]}
|
||||
return 1
|
||||
fi
|
||||
|
||||
used_ports1=$(netstat -an | grep -E "(0.0.0.0|127.0.0.1|${POD_IP}|tcp6)" | awk '{n=split($4,a,":"); if(a[n]~/^[0-9]+$/) print a[n];}' | sort -u)
|
||||
used_ports2=$(netstat -an | grep -E "(0.0.0.0|127.0.0.1|${POD_IP}|tcp6)" | awk '{n=split($5,a,":"); if(a[n]~/^[0-9]+$/) print a[n];}' | sort -u)
|
||||
all_used_ports=$(printf "%s\n" "${used_ports1}" "${used_ports2}" | sort -u)
|
||||
|
||||
# Generate random number between 0 and 32767
|
||||
random_num=$(( RANDOM ))
|
||||
port=$(( random_num % (end_port - start_port + 1) + start_port ))
|
||||
|
||||
while true; do
|
||||
(( port++ ))
|
||||
if [[ ${port} -ge ${end_port} ]]; then
|
||||
port=${start_port}
|
||||
fi
|
||||
|
||||
if [[ "${all_used_ports[@]}" =~ "${port}" ]]; then
|
||||
continue
|
||||
fi
|
||||
|
||||
if is_port_free ${port}; then
|
||||
free_ports+=("${port}")
|
||||
(( free_ports_num-- ))
|
||||
if [[ ${free_ports_num} = 0 ]]; then
|
||||
break
|
||||
fi
|
||||
fi
|
||||
|
||||
done
|
||||
|
||||
# echo ${free_ports[@]}
|
||||
IFS=',' && echo "${free_ports[*]}"
|
||||
return 0
|
||||
}
|
||||
|
||||
@@ -41,7 +41,7 @@ from fastdeploy.inter_communicator import (
|
||||
)
|
||||
from fastdeploy.utils import envs, get_logger
|
||||
|
||||
logger = get_logger("cache_messager", "cache_messager.log")
|
||||
# logger = get_logger("cache_messager", "cache_messager.log")
|
||||
|
||||
|
||||
def parse_args():
|
||||
@@ -552,6 +552,7 @@ class CacheMessagerV1:
|
||||
cache_info = self.engine_worker_queue.get_cache_info()
|
||||
finished_add_cache_task_req_ids = []
|
||||
if cache_info:
|
||||
logger.debug(f"Get cache info from engine worker queue, {cache_info}")
|
||||
self.engine_worker_queue.cache_info_barrier.wait()
|
||||
for info in cache_info:
|
||||
if info["request_id"] in self.cache_info:
|
||||
@@ -570,14 +571,15 @@ class CacheMessagerV1:
|
||||
current_info["sended_layer_id"] = -1
|
||||
current_info["sended_block_num"] = current_info["decode_cached_tokens"] // self.block_size
|
||||
current_info["status"] = "init"
|
||||
logger.info(f"Get cache info from P: finish add cache task: {current_info}")
|
||||
logger.info(f"Get cache info from D: finish add cache task: {current_info}")
|
||||
self.cache_info[info["request_id"]] = current_info
|
||||
self.idx_cache_task_dict[current_info["current_id"]] = current_info
|
||||
else:
|
||||
logger.info(f"Get cache info from D: {info}")
|
||||
logger.info(f"Get cache info from P: {info}")
|
||||
self.cache_info[info["request_id"]] = info
|
||||
|
||||
if finished_add_cache_task_req_ids:
|
||||
logger.info(f"Put processed tasks into engine worker queue: {finished_add_cache_task_req_ids}")
|
||||
self.engine_worker_queue.put_finished_add_cache_task_req(finished_add_cache_task_req_ids)
|
||||
self.engine_worker_queue.finish_add_cache_task_barrier.wait()
|
||||
else:
|
||||
@@ -671,7 +673,7 @@ class CacheMessagerV1:
|
||||
target_ip, target_id, decode_tp_size
|
||||
)
|
||||
if status:
|
||||
logger.info(f"connect to {target_ip}:{target_id} success")
|
||||
logger.debug(f"connect to {target_ip}:{target_id} success")
|
||||
else:
|
||||
logger.error(f"connect to {target_ip}:{target_id} failed")
|
||||
task["status"] = "connection error"
|
||||
@@ -722,7 +724,7 @@ class CacheMessagerV1:
|
||||
if "error" not in task["status"]:
|
||||
task["status"] = "finished"
|
||||
logger.info(
|
||||
f"finish write cache for all layers, req_id: {req_id}, block_id_end {block_id_end} need_prefill_tokens {task['need_prefill_tokens']}"
|
||||
f"Finish write cache for all layers, req_id: {req_id}, block_id_end {block_id_end} need_prefill_tokens {task['need_prefill_tokens']}"
|
||||
)
|
||||
else:
|
||||
task["sended_layer_id"] = -1
|
||||
@@ -736,7 +738,9 @@ class CacheMessagerV1:
|
||||
self.messager["ipc"].write_block_by_sync(target_id)
|
||||
self.engine_worker_queue.finish_send_cache_barrier.wait()
|
||||
self.engine_worker_queue.put_finished_req([[task["request_id"], task["status"]]])
|
||||
logger.info(f"put write cache {task['request_id']}, status {task['status']}")
|
||||
logger.info(
|
||||
f"Put successful cache writing task in engine worker queue, req_id: {task['request_id']}, status: {task['status']}"
|
||||
)
|
||||
self.engine_cache_tasks[task["current_id"]] = dict()
|
||||
del self.cache_info[task["request_id"]]
|
||||
del self.idx_cache_task_dict[task["current_id"]]
|
||||
@@ -928,7 +932,8 @@ if __name__ == "__main__":
|
||||
|
||||
args = parse_args()
|
||||
rank_id = args.rank + args.local_data_parallel_id * args.mp_num
|
||||
logger = get_logger("cache_messager", f"cache_messager_rank{rank_id}.log")
|
||||
logger = get_logger("cache_messager", f"cache_messager_tprank{args.rank}.log")
|
||||
|
||||
logger.info("create cache messager...")
|
||||
logger.info(f"{args}")
|
||||
main()
|
||||
|
||||
@@ -740,6 +740,6 @@ if __name__ == "__main__":
|
||||
|
||||
args = parse_args()
|
||||
rank_id = args.rank + args.local_data_parallel_id * args.mp_num
|
||||
logger = get_logger("cache_transfer_manager", f"cache_transfer_manager_rank{rank_id}.log")
|
||||
logger = get_logger("cache_transfer_manager", f"cache_transfer_manager_tprank{args.rank}.log")
|
||||
set_device(args.device_id)
|
||||
main()
|
||||
|
||||
@@ -280,7 +280,7 @@ class PrefixCacheManager:
|
||||
+ f" --rdma_port {cache_config.rdma_comm_ports[i] if cache_config.rdma_comm_ports is not None else '0'}"
|
||||
+ f" --speculative_config '{self.speculative_config.to_json_string()}'"
|
||||
+ (" --create_cache_tensor" if create_cache_tensor else "")
|
||||
+ f" >{log_dir}/launch_cache_transfer_manager_{int(device_ids[i])}.log 2>&1"
|
||||
+ f" >{log_dir}/launch_cache_transfer_manager_tprank{i}.log 2>&1"
|
||||
)
|
||||
logger.info(f"Launch cache transfer manager, command:{launch_cmd}")
|
||||
cache_manager_processes.append(subprocess.Popen(launch_cmd, shell=True, preexec_fn=os.setsid))
|
||||
@@ -372,7 +372,7 @@ class PrefixCacheManager:
|
||||
+ f" --engine_pid {pid_suffix}"
|
||||
+ f" --rdma_port {cache_config.rdma_comm_ports[i] if cache_config.rdma_comm_ports is not None else '0'}"
|
||||
+ f" --speculative_config '{self.speculative_config.to_json_string()}'"
|
||||
+ f" >{log_dir}/launch_cache_messager_{int(device_ids[i])}.log 2>&1"
|
||||
+ f" >{log_dir}/launch_cache_messager_tprank{i}.log 2>&1"
|
||||
)
|
||||
logger.info(f"Launch cache messager, command:{launch_cmd}")
|
||||
cache_messager_processes.append(subprocess.Popen(launch_cmd, shell=True, preexec_fn=os.setsid))
|
||||
|
||||
@@ -545,6 +545,7 @@ class ParallelConfig:
|
||||
self.tensor_parallel_size = 1 # TP degree
|
||||
self.expert_parallel_rank = 0 # EP rank ID
|
||||
self.expert_parallel_size = 1 # EP degree
|
||||
self.data_parallel_rank = 0 # DP rank ID
|
||||
self.data_parallel_size = 1 # DP degree
|
||||
self.enable_expert_parallel = False
|
||||
self.enable_chunked_moe = False
|
||||
@@ -1887,7 +1888,11 @@ class FDConfig:
|
||||
engine_worker_queue_port = self.parallel_config.engine_worker_queue_port[
|
||||
self.parallel_config.local_data_parallel_id
|
||||
]
|
||||
connector_port = self.cache_config.pd_comm_port[0] if self.cache_config.pd_comm_port else None
|
||||
connector_port = (
|
||||
self.cache_config.pd_comm_port[self.parallel_config.local_data_parallel_id]
|
||||
if self.cache_config.pd_comm_port
|
||||
else None
|
||||
)
|
||||
|
||||
self.disaggregate_info = {}
|
||||
if self.scheduler_config.splitwise_role != "mixed":
|
||||
|
||||
@@ -82,9 +82,9 @@ class EngineService:
|
||||
self.cfg.cache_config.cache_queue_port[self.cfg.parallel_config.local_data_parallel_id]
|
||||
)
|
||||
|
||||
if self.cfg.parallel_config.enable_expert_parallel:
|
||||
if self.cfg.parallel_config.data_parallel_size > 1:
|
||||
self.llm_logger = get_logger(
|
||||
"fastdeploy", f"fastdeploy_rank{self.cfg.parallel_config.local_data_parallel_id}.log"
|
||||
"fastdeploy", f"fastdeploy_dprank{self.cfg.parallel_config.local_data_parallel_id}.log"
|
||||
)
|
||||
else:
|
||||
self.llm_logger = llm_logger
|
||||
@@ -716,7 +716,11 @@ class EngineService:
|
||||
is_fetching = False
|
||||
return
|
||||
|
||||
self.llm_logger.debug(f"get tasks from {type(self.scheduler)}: {tasks}")
|
||||
if tasks:
|
||||
self.llm_logger.debug(
|
||||
f"Engine has fetched tasks from {self.scheduler.__class__.__name__}: {[task.request_id for task in tasks]}"
|
||||
)
|
||||
|
||||
if self.cfg.scheduler_config.splitwise_role != "mixed":
|
||||
need_delete_tasks = []
|
||||
if envs.FD_OFFLINE_PERF_TEST_FOR_PD:
|
||||
@@ -724,22 +728,24 @@ class EngineService:
|
||||
# assure can allocate block ids in P
|
||||
while not self.resource_manager.preallocate_resource_in_p(task):
|
||||
time.sleep(0.005)
|
||||
self.llm_logger.info(f"ask D resource for req_id: {task.request_id}")
|
||||
self.llm_logger.debug(f"P has allocated resources for request: {task.request_id}")
|
||||
while True:
|
||||
self.split_connector.send_splitwise_tasks([task], task.idx)
|
||||
status, msg = self.split_connector.check_decode_allocated(task)
|
||||
if not status:
|
||||
self.llm_logger.error(f"{task.request_id} ask D resource failed, try again.")
|
||||
self.llm_logger.error(
|
||||
f"D failed to allocate resource for request {task.request_id}, try again."
|
||||
)
|
||||
time.sleep(0.05)
|
||||
else:
|
||||
break
|
||||
self.llm_logger.debug(f"D has allocated resource for request: {task.request_id}")
|
||||
else:
|
||||
for task in tasks:
|
||||
# assure can allocate block ids in P
|
||||
while not self.resource_manager.preallocate_resource_in_p(task):
|
||||
self.llm_logger.info("wait for preallocate_resource_in_p")
|
||||
time.sleep(0.005)
|
||||
self.llm_logger.info(f"ask D resource for req_id: {task.request_id}")
|
||||
self.llm_logger.debug(f"P has allocated resources for request: {task.request_id}")
|
||||
self.split_connector.send_splitwise_tasks([task], task.idx)
|
||||
|
||||
for task in tasks:
|
||||
@@ -747,7 +753,9 @@ class EngineService:
|
||||
# assure fetch block ids from D
|
||||
status, msg = self.split_connector.check_decode_allocated(task)
|
||||
if not status:
|
||||
self.llm_logger.error(f"{task.request_id} prefill failed with msg:{msg}.")
|
||||
self.llm_logger.error(
|
||||
f"D failed to allocate resource for request {task.request_id}, message: {msg}."
|
||||
)
|
||||
self.scheduler.put_results(
|
||||
[
|
||||
RequestOutput(
|
||||
@@ -760,25 +768,32 @@ class EngineService:
|
||||
)
|
||||
need_delete_tasks.append(task)
|
||||
continue
|
||||
else:
|
||||
self.llm_logger.debug(f"D has allocated resource for request: {task.request_id}")
|
||||
|
||||
for tmp_task in need_delete_tasks:
|
||||
tasks.remove(tmp_task)
|
||||
# release resource in P
|
||||
self.resource_manager.pre_recycle_resource(tmp_task.request_id)
|
||||
|
||||
if self.cfg.scheduler_config.splitwise_role == "prefill":
|
||||
# to send cache info to cache messager
|
||||
if tasks:
|
||||
need_check_req_ids = [task.request_id for task in tasks]
|
||||
self.split_connector.send_cache_info_to_messager(tasks, 0)
|
||||
# ensure cache tasks has sent to cache_messager
|
||||
need_check_req_ids = [task.request_id for task in tasks]
|
||||
while need_check_req_ids:
|
||||
req_ids = self.engine_worker_queue.get_finished_add_cache_task_req()
|
||||
self.llm_logger.info(f"get_finished_add_cache_task_req: {req_ids}")
|
||||
if req_ids:
|
||||
self.llm_logger.debug(
|
||||
f"P has successfully sent cache infos to cache messager for requests: {req_ids}"
|
||||
)
|
||||
for req_id in req_ids:
|
||||
assert req_id in need_check_req_ids
|
||||
need_check_req_ids.remove(req_id)
|
||||
else:
|
||||
time.sleep(0.001)
|
||||
|
||||
# Fetch requests and add them to the scheduling queue
|
||||
if tasks:
|
||||
for task in tasks:
|
||||
@@ -787,6 +802,9 @@ class EngineService:
|
||||
)
|
||||
if self.cfg.scheduler_config.splitwise_role == "prefill":
|
||||
self.resource_manager.add_request_in_p(tasks)
|
||||
self.llm_logger.info(
|
||||
f"P add requests into running queue: {[task.request_id for task in tasks]}"
|
||||
)
|
||||
else:
|
||||
for task in tasks:
|
||||
self.resource_manager.add_request(task)
|
||||
@@ -917,7 +935,6 @@ class EngineService:
|
||||
request.llm_engine_recv_req_timestamp = time.time()
|
||||
start_span("ENQUEUE_ZMQ", data, trace.SpanKind.PRODUCER)
|
||||
main_process_metrics.requests_number.inc()
|
||||
self.llm_logger.debug(f"Receive request: {request}")
|
||||
trace_print(LoggingEventName.PREPROCESSING_END, data["request_id"], data.get("user", ""))
|
||||
trace_print(LoggingEventName.REQUEST_SCHEDULE_START, data["request_id"], data.get("user", ""))
|
||||
trace_print(LoggingEventName.REQUEST_QUEUE_START, data["request_id"], data.get("user", ""))
|
||||
@@ -1082,10 +1099,14 @@ class EngineService:
|
||||
for item in items:
|
||||
tasks = item[1]
|
||||
if isinstance(tasks[0], Request):
|
||||
self.llm_logger.debug(f"receive tasks to preallocate resource, {tasks}")
|
||||
self.llm_logger.debug(
|
||||
f"D has received tasks to preallocate resource for tasks: {[task.request_id for task in tasks]}"
|
||||
)
|
||||
allocate_resource_requests.extend(tasks)
|
||||
elif isinstance(tasks[0], RequestOutput):
|
||||
self.llm_logger.debug(f"receive prefilled tasks, {tasks}")
|
||||
self.llm_logger.debug(
|
||||
f"D has received tasks to process prefilled tasks: {[task.request_id for task in tasks]}"
|
||||
)
|
||||
if not isinstance(tasks, list):
|
||||
tasks = [tasks]
|
||||
for task in tasks:
|
||||
@@ -1099,13 +1120,13 @@ class EngineService:
|
||||
|
||||
if envs.ENABLE_V1_KVCACHE_SCHEDULER:
|
||||
if self.resource_manager.preallocate_resource_in_d(task):
|
||||
self.llm_logger.info(f"Resource available, processing task {task.request_id}")
|
||||
self.split_connector.send_cache_info_to_prefill([task])
|
||||
self.llm_logger.debug(f"D has successfully sent cache infos for task {task.request_id}")
|
||||
processed_indices.append(idx)
|
||||
is_success = True
|
||||
else:
|
||||
if self.resource_manager.is_resource_sufficient(task.prompt_token_ids_len):
|
||||
self.llm_logger.info(f"Resource available, processing task {task.request_id}")
|
||||
self.llm_logger.debug(f"D Resource available, processing task {task.request_id}")
|
||||
self.insert_tasks([task])
|
||||
processed_indices.append(idx)
|
||||
is_success = True
|
||||
@@ -1114,6 +1135,7 @@ class EngineService:
|
||||
if not self.enable_decode_cache_task:
|
||||
task.error_msg = "Not enough resources"
|
||||
self.split_connector.send_cache_info_to_prefill([task])
|
||||
self.llm_logger.warning(f"D has failed to send cache infos for task {task.request_id}")
|
||||
processed_indices.append(idx)
|
||||
else:
|
||||
self.llm_logger.debug(f"Still waiting for resources {task.request_id}")
|
||||
@@ -1169,7 +1191,7 @@ class EngineService:
|
||||
if envs.FD_ENABLE_INTERNAL_ADAPTER: # first token sent by D instance
|
||||
self.scheduler.put_results([req_output])
|
||||
self.resource_manager.add_prefilled_request(req_output)
|
||||
self.llm_logger.debug(f"add prefilled request success, {request_id}")
|
||||
self.llm_logger.info(f"D has successfully added prefilled request, {request_id}")
|
||||
|
||||
def decode_loop():
|
||||
while self.running:
|
||||
|
||||
@@ -61,7 +61,6 @@ class ExpertService:
|
||||
]
|
||||
self.cfg.local_device_ids = self.cfg.parallel_config.device_ids.split(",")[start_pos:end_pos]
|
||||
llm_logger.info(f"local_data_parallel_id: {local_data_parallel_id}")
|
||||
self.cfg.disaggregate_info = None
|
||||
|
||||
if self.cfg.cache_config.num_gpu_blocks_override is None:
|
||||
self.do_profile = True
|
||||
@@ -127,6 +126,7 @@ class ExpertService:
|
||||
local_rank = local_data_parallel_id % self.cfg.worker_num_per_node
|
||||
|
||||
if not envs.FD_ENABLE_MULTI_API_SERVER:
|
||||
if self.cfg.parallel_config.enable_expert_parallel:
|
||||
launched_expert_service_signal_data = np.zeros(
|
||||
shape=[self.cfg.parallel_config.data_parallel_size // self.cfg.nnode], dtype=np.int32
|
||||
)
|
||||
|
||||
@@ -496,6 +496,8 @@ class EngineWorkerQueue:
|
||||
self.tasks.append(tasks)
|
||||
self.lock.release()
|
||||
|
||||
llm_logger.debug(f"put_tasks: tasks={tasks}")
|
||||
|
||||
def get_tasks(self) -> Tuple[List[Any], bool]:
|
||||
"""
|
||||
Retrieve tasks from the shared queue and update read status.
|
||||
@@ -512,6 +514,7 @@ class EngineWorkerQueue:
|
||||
if all_client_read:
|
||||
self.tasks[:] = list()
|
||||
self.lock.release()
|
||||
llm_logger.debug(f"get_tasks: tasks={tasks}")
|
||||
return tasks, all_client_read
|
||||
|
||||
def num_tasks(self) -> int:
|
||||
@@ -600,8 +603,7 @@ class EngineWorkerQueue:
|
||||
|
||||
self.cache_infos.extend(cache_info)
|
||||
llm_logger.debug(
|
||||
f"put cache_infos to engine worker queue: {self.cache_infos}, "
|
||||
f"local_data_parallel_id:{self.local_data_parallel_id}"
|
||||
f"put_cache_info: cache_info={cache_info}, local_data_parallel_id={self.local_data_parallel_id}"
|
||||
)
|
||||
self.lock_info.release()
|
||||
|
||||
|
||||
@@ -214,6 +214,9 @@ class ZmqServerBase(ABC):
|
||||
except zmq.Again:
|
||||
time.sleep(0.001)
|
||||
continue
|
||||
except zmq.error.ZMQError as e:
|
||||
llm_logger.error(f"recv_result_handle get zmq error: {e}")
|
||||
break
|
||||
except Exception as e:
|
||||
llm_logger.error(f"recv_result_handle get unknown exception: {e}")
|
||||
continue
|
||||
|
||||
@@ -402,12 +402,8 @@ class TokenProcessor:
|
||||
rank_id,
|
||||
is_blocking,
|
||||
)
|
||||
elif (
|
||||
self.cfg.parallel_config.enable_expert_parallel
|
||||
and self.cfg.parallel_config.data_parallel_size > 1
|
||||
):
|
||||
elif self.cfg.parallel_config.data_parallel_size > 1:
|
||||
get_output_ep(self.output_tokens, rank_id, is_blocking)
|
||||
|
||||
else:
|
||||
get_output(self.output_tokens, rank_id, is_blocking)
|
||||
|
||||
|
||||
@@ -91,6 +91,7 @@ class Router:
|
||||
self.prefill_servers = []
|
||||
self.decode_servers = []
|
||||
self.lock = asyncio.Lock() # async-safe lock
|
||||
logger.info("Router started at http://{}:{}".format(self.host, self.port))
|
||||
|
||||
async def register_instance(self, instance_info_dict: dict):
|
||||
"""Register an instance asynchronously"""
|
||||
@@ -172,6 +173,8 @@ class Router:
|
||||
async def handle_splitwise_request(self, request_data: dict, endpoint_name: str):
|
||||
logger.debug(f"Received request: {request_data}")
|
||||
prefill_server, decode_server = await self.select_pd()
|
||||
logger.debug(f"Selected prefill server: {prefill_server}")
|
||||
logger.debug(f"Selected decode server: {decode_server}")
|
||||
|
||||
if prefill_server.tp_size != decode_server.tp_size and decode_server.tp_size != 1:
|
||||
raise HTTPException(
|
||||
@@ -371,4 +374,4 @@ def launch_router(router_args: RouterArgs):
|
||||
app.state.router = Router(app.state.router_args)
|
||||
asyncio.create_task(app.state.router.monitor_instance_health(interval_secs=5))
|
||||
|
||||
uvicorn.run(app, host=router_args.host, port=router_args.port)
|
||||
uvicorn.run(app, host=router_args.host, port=int(router_args.port))
|
||||
|
||||
@@ -44,9 +44,10 @@ class SplitwiseConnector:
|
||||
resource_manager (object): Resource manager object.
|
||||
"""
|
||||
self.cfg = cfg
|
||||
if self.cfg.parallel_config.enable_expert_parallel and self.cfg.parallel_config.data_parallel_size > 1:
|
||||
self.local_data_parallel_id = self.cfg.parallel_config.local_data_parallel_id
|
||||
if self.cfg.parallel_config.data_parallel_size > 1:
|
||||
self.logger = get_logger(
|
||||
"splitwise_connector", f"splitwise_connector_{self.cfg.parallel_config.local_data_parallel_id}.log"
|
||||
"splitwise_connector", f"splitwise_connector_dprank{self.local_data_parallel_id}.log"
|
||||
)
|
||||
else:
|
||||
self.logger = get_logger("splitwise_connector", "splitwise_connector.log")
|
||||
@@ -54,7 +55,6 @@ class SplitwiseConnector:
|
||||
self.resource_manager = resource_manager
|
||||
self.connect_innode_instances = {}
|
||||
self.current_request_ids = dict()
|
||||
self.idx = self.cfg.parallel_config.local_data_parallel_id
|
||||
self.enable_decode_cache_task = envs.FD_ENABLE_CACHE_TASK == "1"
|
||||
|
||||
if self.cfg.cache_config.pd_comm_port is not None:
|
||||
@@ -74,7 +74,7 @@ class SplitwiseConnector:
|
||||
self.router_socket.setsockopt(zmq.SNDHWM, 1000)
|
||||
self.router_socket.setsockopt(zmq.ROUTER_MANDATORY, 1)
|
||||
self.router_socket.bind(f"tcp://*:{self.cfg.cache_config.pd_comm_port[0]}")
|
||||
self.logger.info(f"bind {self.cfg.cache_config.pd_comm_port}")
|
||||
self.logger.info(f"_init_network: bind {self.cfg.cache_config.pd_comm_port}")
|
||||
|
||||
self.poller = zmq.Poller()
|
||||
self.poller.register(self.router_socket, zmq.POLLIN)
|
||||
@@ -94,17 +94,17 @@ class SplitwiseConnector:
|
||||
if not socks:
|
||||
continue
|
||||
else:
|
||||
self.logger.debug(f"receive {socks}")
|
||||
self.logger.debug(f"start_receiver: receive {socks}")
|
||||
|
||||
frames = self.router_socket.recv_multipart()
|
||||
self.logger.debug(f"frames: {frames}")
|
||||
self.logger.debug(f"start_receiver: frames: {frames}")
|
||||
message = frames[-1]
|
||||
self.io_executor.submit(self._process_message, message)
|
||||
time.sleep(0.001)
|
||||
else:
|
||||
time.sleep(5)
|
||||
except Exception as e:
|
||||
self.logger.error(f"Receiver error: {e}, {str(traceback.format_exc())}")
|
||||
self.logger.error(f"start_receiver: Receiver error: {e}, {str(traceback.format_exc())}")
|
||||
time.sleep(1)
|
||||
|
||||
def _get_push_socket(self, addr):
|
||||
@@ -116,7 +116,7 @@ class SplitwiseConnector:
|
||||
return sock
|
||||
|
||||
try:
|
||||
self.logger.info(f"Establishing new connection to {addr}")
|
||||
self.logger.info(f"_get_push_socket: Establishing new connection to {addr}")
|
||||
sock = self.zmq_ctx.socket(zmq.DEALER)
|
||||
|
||||
# 设置连接参数
|
||||
@@ -135,36 +135,29 @@ class SplitwiseConnector:
|
||||
return sock
|
||||
|
||||
except zmq.ZMQError as e:
|
||||
self.logger.error(f"Connection to {addr} failed: {e}")
|
||||
self.logger.error(f"_get_push_socket: Connection to {addr} failed: {e}")
|
||||
|
||||
raise ConnectionError(f"Failed to connect to {addr}") from e
|
||||
|
||||
def _send_message(self, addr, msg_type: str, payload):
|
||||
if not addr:
|
||||
return
|
||||
|
||||
try:
|
||||
self.logger.info(f"Sent {msg_type} to {addr}")
|
||||
message = self._serialize_message(msg_type, payload)
|
||||
|
||||
try:
|
||||
|
||||
self.logger.info(f"_send_message: msg_type={msg_type} addr={addr}")
|
||||
sock = self._get_push_socket(addr)
|
||||
sock.send_multipart([b"", message])
|
||||
|
||||
self.logger.info(f"Sent {msg_type} to {addr}")
|
||||
|
||||
except ConnectionError:
|
||||
self.logger.warning(f"Connection to {addr} not established")
|
||||
self.logger.warning(f"_send_message: Connection to {addr} not established")
|
||||
except zmq.Again:
|
||||
self.logger.warning(f"Send queue full for {addr}")
|
||||
self.logger.warning(f"_send_message: Send queue full for {addr}")
|
||||
except Exception as e:
|
||||
self.logger.error(f"Send to {addr} failed: {e}, {str(traceback.format_exc())}")
|
||||
self.logger.error(f"_send_message: Send to {addr} failed: {e}, {str(traceback.format_exc())}")
|
||||
main_process_metrics.send_cache_failed_num.inc()
|
||||
self._close_connection(addr)
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Message preparation failed: {e}")
|
||||
self.logger.error(f"_send_message: Message preparation failed: {e}")
|
||||
|
||||
def _close_connection(self, addr):
|
||||
"""
|
||||
@@ -191,21 +184,20 @@ class SplitwiseConnector:
|
||||
if task.disaggregate_info["transfer_protocol"] == "ipc":
|
||||
addr = task.disaggregate_info["cache_info"]["ipc"]["port"]
|
||||
task.disaggregate_info["cache_info"]["ipc"]["current_id"] = current_id
|
||||
self.logger.info(f"send_splitwise_tasks: protocol=ipc, addr={addr}, task={task.request_id}")
|
||||
self.send_splitwise_tasks_innode([task], addr)
|
||||
|
||||
else:
|
||||
|
||||
addr = (
|
||||
f"{task.disaggregate_info['cache_info']['rdma']['ip']}:"
|
||||
+ f"{task.disaggregate_info['cache_info']['rdma']['port']}"
|
||||
)
|
||||
self.logger.info(f"send splitwise tasks to port {addr} decode, {task.request_id}")
|
||||
self.current_request_ids[task.request_id] = "init"
|
||||
decode_diagg = task.disaggregate_info["cache_info"]
|
||||
task.disaggregate_info["cache_info"] = self.cfg.disaggregate_info["cache_info"]
|
||||
task.disaggregate_info["cache_info"]["rdma"]["current_id"] = current_id
|
||||
task.disaggregate_info["role"] = "decode"
|
||||
self.logger.debug(f"send task to coupled instance, {addr}, {task}")
|
||||
self.logger.info(f"send_splitwise_tasks: protocol=rdma, addr={addr}, task={task.request_id}")
|
||||
self._send_message(addr, "prefill", [task])
|
||||
task.disaggregate_info["cache_info"] = decode_diagg
|
||||
task.disaggregate_info["role"] = "prefill"
|
||||
@@ -226,12 +218,12 @@ class SplitwiseConnector:
|
||||
self.create_connection(port)
|
||||
for task in tasks:
|
||||
task.disaggregate_info["cache_info"]["ipc"]["port"] = self.cfg.parallel_config.engine_worker_queue_port[
|
||||
self.idx
|
||||
self.local_data_parallel_id
|
||||
]
|
||||
self.logger.info(f"send_splitwise_tasks_innode: port={port}, tasks={[task.request_id for task in tasks]}")
|
||||
self.connect_innode_instances[port].put_disaggregated_tasks(("decode", tasks))
|
||||
for task in tasks:
|
||||
task.disaggregate_info["cache_info"]["ipc"]["port"] = port
|
||||
self.logger.info(f"send splitwise tasks to port {port} decode")
|
||||
current_port = port
|
||||
return current_port
|
||||
|
||||
@@ -241,7 +233,7 @@ class SplitwiseConnector:
|
||||
"""
|
||||
if not isinstance(tasks_list, list):
|
||||
tasks_list = [tasks_list]
|
||||
self.logger.info(f"send first token to decode, {[x.request_id for x in tasks_list]}")
|
||||
self.logger.info(f"send_first_token: send first token to decode, {[x.request_id for x in tasks_list]}")
|
||||
if prefill_msg["transfer_protocol"] == "ipc":
|
||||
port = prefill_msg["cache_info"]["ipc"]["port"]
|
||||
if port not in self.connect_innode_instances:
|
||||
@@ -249,7 +241,7 @@ class SplitwiseConnector:
|
||||
self.connect_innode_instances[port].put_disaggregated_tasks(("decode", tasks_list))
|
||||
else:
|
||||
node = f"{prefill_msg['cache_info']['rdma']['ip']}:{prefill_msg['cache_info']['rdma']['port']}"
|
||||
self.logger.info(f"send first token to port {node} decode")
|
||||
self.logger.info(f"send_first_token: send first token to port {node} decode")
|
||||
self._send_message(node, "decode", tasks_list)
|
||||
|
||||
def create_connection(self, port):
|
||||
@@ -288,7 +280,7 @@ class SplitwiseConnector:
|
||||
del self.current_request_ids[task.request_id]
|
||||
if msg == "finished":
|
||||
return True, ""
|
||||
self.logger.error(f"Receive_decode_allocated error: {msg}")
|
||||
self.logger.error(f"check_decode_allocated: Receive_decode_allocated error: {msg}")
|
||||
return False, msg
|
||||
|
||||
def send_cache_info_to_messager(self, tasks: List[Request], current_id):
|
||||
@@ -359,9 +351,11 @@ class SplitwiseConnector:
|
||||
else:
|
||||
info = {
|
||||
"request_id": tasks[i].request_id,
|
||||
"device_ids": self.cfg.parallel_config.device_ids.split(","),
|
||||
"device_ids": [self.cfg.parallel_config.device_ids.split(",")[self.local_data_parallel_id]],
|
||||
"ip": self.cfg.host_ip,
|
||||
"rdma_ports": self.cfg.disaggregate_info["cache_info"]["rdma"]["rdma_port"],
|
||||
"rdma_ports": [
|
||||
self.cfg.disaggregate_info["cache_info"]["rdma"]["rdma_port"][self.local_data_parallel_id]
|
||||
],
|
||||
"transfer_protocol": "rdma",
|
||||
"dest_block_ids": dsg_info["block_tables"],
|
||||
"decode_tp_size": self.cfg.parallel_config.tensor_parallel_size,
|
||||
@@ -404,7 +398,7 @@ class SplitwiseConnector:
|
||||
"""
|
||||
try:
|
||||
msg_type, payload = self._deserialize_message(message)
|
||||
self.logger.info(f"{msg_type}")
|
||||
self.logger.info(f"_process_message: {msg_type}")
|
||||
|
||||
if msg_type == "prefill":
|
||||
self._handle_prefill(payload)
|
||||
@@ -412,7 +406,7 @@ class SplitwiseConnector:
|
||||
self._handle_decode(payload)
|
||||
elif msg_type == "cache_sync":
|
||||
for task in payload:
|
||||
self.logger.info(f"cache_sync task: {task}")
|
||||
self.logger.info(f"_process_message: cache_sync task: {task}")
|
||||
current_status = task.get("error_msg", "finished")
|
||||
self.current_request_ids[task["request_id"]] = current_status
|
||||
if self.enable_decode_cache_task:
|
||||
@@ -421,13 +415,13 @@ class SplitwiseConnector:
|
||||
self.engine_worker_queue.put_cache_info(payload)
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Message processing failed: {e}, {str(traceback.format_exc())}")
|
||||
self.logger.error(f"_process_message: Message processing failed: {e}, {str(traceback.format_exc())}")
|
||||
|
||||
def _handle_prefill(self, tasks):
|
||||
"""
|
||||
Handle prefill tasks from other nodes.
|
||||
"""
|
||||
self.logger.debug(f"_handle_prefill function receive {tasks}")
|
||||
self.logger.debug(f"_handle_prefill: receive payload {tasks}")
|
||||
tasks_data = [Request.from_dict(task) for task in tasks]
|
||||
self.engine_worker_queue.put_disaggregated_tasks(("decode", tasks_data))
|
||||
|
||||
@@ -435,7 +429,7 @@ class SplitwiseConnector:
|
||||
"""
|
||||
Handle decode tasks from other nodes.
|
||||
"""
|
||||
self.logger.debug(f"_handle_decode function receive {payload}")
|
||||
self.logger.debug(f"_handle_decode: receive payload {payload}")
|
||||
tasks = []
|
||||
for task in payload:
|
||||
tasks.append(RequestOutput.from_dict(task))
|
||||
|
||||
@@ -173,7 +173,11 @@ class PaddleDisWorkerProc:
|
||||
model_weights_status:
|
||||
"""
|
||||
self.max_chips_per_node = 16 if current_platform.is_iluvatar() else 8
|
||||
if self.parallel_config.data_parallel_size > 1 and not envs.FD_ENABLE_MULTI_API_SERVER:
|
||||
if (
|
||||
self.parallel_config.enable_expert_parallel
|
||||
and self.parallel_config.data_parallel_size > 1
|
||||
and not envs.FD_ENABLE_MULTI_API_SERVER
|
||||
):
|
||||
launched_expert_service_signal_data = np.zeros(
|
||||
shape=[self.parallel_config.data_parallel_size // self.fd_config.nnode], dtype=np.int32
|
||||
)
|
||||
@@ -905,6 +909,12 @@ def initialize_fd_config(args, ranks: int = 1, local_rank: int = 0) -> FDConfig:
|
||||
|
||||
parallel_config.tensor_parallel_rank = local_rank % parallel_config.tensor_parallel_size
|
||||
parallel_config.data_parallel_rank = local_rank // parallel_config.tensor_parallel_size
|
||||
# config for DP
|
||||
if parallel_config.data_parallel_size > 1:
|
||||
max_chips_per_node = 16 if current_platform.is_iluvatar() else 8
|
||||
parallel_config.local_data_parallel_id = parallel_config.data_parallel_rank % (
|
||||
max_chips_per_node // parallel_config.tensor_parallel_size
|
||||
)
|
||||
# config for EP
|
||||
if parallel_config.expert_parallel_size > 1:
|
||||
expert_parallel_rank = int(local_rank % parallel_config.expert_parallel_size)
|
||||
@@ -914,11 +924,6 @@ def initialize_fd_config(args, ranks: int = 1, local_rank: int = 0) -> FDConfig:
|
||||
num_experts = model_config.moe_num_experts
|
||||
num_experts_per_rank = num_experts // parallel_config.expert_parallel_size
|
||||
num_experts_start_offset = expert_parallel_rank * num_experts_per_rank
|
||||
max_chips_per_node = 16 if current_platform.is_iluvatar() else 8
|
||||
parallel_config.local_data_parallel_id = parallel_config.data_parallel_rank % (
|
||||
max_chips_per_node // parallel_config.tensor_parallel_size
|
||||
)
|
||||
|
||||
parallel_config.expert_parallel_rank = expert_parallel_rank
|
||||
parallel_config.num_experts_per_rank = num_experts_per_rank
|
||||
parallel_config.num_experts_start_offset = num_experts_start_offset
|
||||
|
||||
Reference in New Issue
Block a user