Mirror of https://github.com/PaddlePaddle/FastDeploy.git (synced 2025-12-24 13:28:13 +08:00)
[Feature] [PD] add simple router and refine splitwise deployment (#4709)
* add simple router and refine splitwise deployment
* fix
@@ -94,10 +94,11 @@ async def async_request_eb_openai_chat_completions(
        "stream_options": {
            "include_usage": True,
            "continuous_usage_stats": True,
        }
    },
    "max_tokens": request_func_input.output_len,
}
if request_func_input.response_format:
    payload["response_format"] =request_func_input.response_format
    payload["response_format"] = request_func_input.response_format

# Hyperparameters are passed in via the YAML config
payload.update(request_func_input.hyper_parameters)
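The hunk above merges YAML-provided hyperparameters into the request payload via `payload.update(request_func_input.hyper_parameters)`. A minimal sketch of that flow, assuming PyYAML and a hypothetical `hyper_parameters.yaml` (the file name and keys are illustrative, not part of the benchmark script):

```python
import yaml  # assumes PyYAML is available

def load_hyper_parameters(path: str) -> dict:
    # Hypothetical helper: read sampling hyperparameters from a YAML file
    with open(path) as f:
        return yaml.safe_load(f) or {}

payload = {
    "messages": [{"role": "user", "content": "hello"}],
    "max_tokens": 128,
}
# Same merge step as payload.update(request_func_input.hyper_parameters) above
payload.update(load_hyper_parameters("hyper_parameters.yaml"))
```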
@@ -132,13 +133,13 @@ async def async_request_eb_openai_chat_completions(

chunk = chunk_bytes.decode("utf-8").removeprefix("data: ")
if chunk != "[DONE]":
    #print("####chunk:", chunk, type(chunk))
    # print("####chunk:", chunk, type(chunk))
    timestamp = time.perf_counter()
    data = json.loads(chunk)

    if request_id == "None" and "id" in data:
        request_id = data["id"]

    if choices := data.get("choices"):
        content = choices[0]["delta"].get("content")
        reason_content = choices[0]["delta"].get("reasoning_content")
@@ -164,7 +165,6 @@ async def async_request_eb_openai_chat_completions(
    elif usage := data.get("usage", {}):
        output.output_tokens = usage.get("completion_tokens", 0)
        output.prompt_tokens = usage.get("prompt_tokens", 0)

    most_recent_timestamp = timestamp

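For reference, a minimal standalone sketch of the chunk handling shown above: strip the `data: ` prefix, skip `[DONE]`, and read either the delta content or the usage block. The function name is illustrative only.

```python
import json

def parse_stream_chunk(chunk_bytes: bytes):
    # Mirrors the hunk above: SSE lines look like b'data: {...}' or b'data: [DONE]'
    chunk = chunk_bytes.decode("utf-8").removeprefix("data: ")
    if chunk.strip() == "[DONE]":
        return None
    data = json.loads(chunk)
    if choices := data.get("choices"):
        delta = choices[0].get("delta", {})
        return {"content": delta.get("content"), "reasoning_content": delta.get("reasoning_content")}
    if usage := data.get("usage", {}):
        return {"completion_tokens": usage.get("completion_tokens", 0), "prompt_tokens": usage.get("prompt_tokens", 0)}
    return None
```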
@@ -46,7 +46,7 @@ class SampleRequest:
|
||||
prompt_len: int
|
||||
expected_output_len: int
|
||||
response_format: Optional[dict] = None
|
||||
|
||||
|
||||
|
||||
class BenchmarkDataset(ABC):
|
||||
"""BenchmarkDataset"""
|
||||
@@ -299,7 +299,7 @@ class EBChatDataset(BenchmarkDataset):
|
||||
prompt = entry["messages"][-1].get("content", "")
|
||||
history_QA = entry.get("messages", [])
|
||||
response_format = entry.get("response_format")
|
||||
new_output_len = int(entry.get("max_tokens", 12288))
|
||||
new_output_len = int(entry.get("max_tokens", output_len if output_len else 12288))
|
||||
|
||||
if enable_multimodal_chat:
|
||||
prompt = self.apply_multimodal_chat_transformation(prompt, None)
|
||||
@@ -311,7 +311,7 @@ class EBChatDataset(BenchmarkDataset):
|
||||
prompt_len=0,
|
||||
history_QA=history_QA,
|
||||
expected_output_len=new_output_len,
|
||||
response_format=response_format
|
||||
response_format=response_format,
|
||||
)
|
||||
)
|
||||
cnt += 1
|
||||
|
||||
@@ -352,7 +352,7 @@ async def benchmark(
|
||||
ignore_eos=ignore_eos,
|
||||
debug=debug,
|
||||
extra_body=extra_body,
|
||||
response_format=response_format
|
||||
response_format=response_format,
|
||||
)
|
||||
|
||||
print("test_input:", test_input)
|
||||
@@ -384,7 +384,7 @@ async def benchmark(
|
||||
logprobs=logprobs,
|
||||
ignore_eos=ignore_eos,
|
||||
extra_body=extra_body,
|
||||
response_format=response_format
|
||||
response_format=response_format,
|
||||
)
|
||||
profile_output = await request_func(request_func_input=profile_input)
|
||||
if profile_output.success:
|
||||
@@ -444,7 +444,7 @@ async def benchmark(
|
||||
debug=debug,
|
||||
ignore_eos=ignore_eos,
|
||||
extra_body=extra_body,
|
||||
response_format=response_format
|
||||
response_format=response_format,
|
||||
)
|
||||
tasks.append(asyncio.create_task(limited_request_func(request_func_input=request_func_input, pbar=pbar)))
|
||||
outputs: list[RequestFuncOutput] = await asyncio.gather(*tasks)
|
||||
@@ -460,7 +460,7 @@ async def benchmark(
|
||||
api_url=base_url + "/stop_profile",
|
||||
output_len=test_output_len,
|
||||
logprobs=logprobs,
|
||||
response_format=response_format
|
||||
response_format=response_format,
|
||||
)
|
||||
profile_output = await request_func(request_func_input=profile_input)
|
||||
if profile_output.success:
|
||||
|
||||
@@ -3,4 +3,4 @@ max_num_seqs: 128
|
||||
gpu_memory_utilization: 0.85
|
||||
tensor_parallel_size: 1
|
||||
limit_mm_per_prompt: '{"image": 100, "video": 100}'
|
||||
enable_mm: True
|
||||
enable_mm: True
|
||||
|
||||
@@ -5,4 +5,4 @@ metadata:
|
||||
max_tokens: 32768
|
||||
repetition_penalty: 1.05
|
||||
frequency_penalty: 0
|
||||
presence_penalty: 0
|
||||
presence_penalty: 0
|
||||
|
||||
@@ -26,7 +26,7 @@ We recommend using mpirun for one-command startup without manually starting each
|
||||
4. Ensure all nodes can resolve each other's hostnames
|
||||
|
||||
* Online inference startup example:
|
||||
|
||||
|
||||
```shell
|
||||
python -m fastdeploy.entrypoints.openai.api_server \
|
||||
--model baidu/ERNIE-4.5-300B-A47B-Paddle \
|
||||
@@ -40,7 +40,7 @@ We recommend using mpirun for one-command startup without manually starting each
|
||||
```
|
||||
|
||||
* Offline startup example:
|
||||
|
||||
|
||||
```python
|
||||
from fastdeploy.engine.sampling_params import SamplingParams
|
||||
from fastdeploy.entrypoints.llm import LLM
|
||||
|
||||
@@ -26,7 +26,7 @@
|
||||
4. 确保所有节点能够解析彼此的主机名
|
||||
|
||||
* 在线推理启动示例:
|
||||
|
||||
|
||||
```shell
|
||||
python -m fastdeploy.entrypoints.openai.api_server \
|
||||
--model baidu/ERNIE-4.5-300B-A47B-Paddle \
|
||||
@@ -40,7 +40,7 @@
|
||||
```
|
||||
|
||||
* 离线启动示例:
|
||||
|
||||
|
||||
```python
|
||||
from fastdeploy.engine.sampling_params import SamplingParams
|
||||
from fastdeploy.entrypoints.llm import LLM
|
||||
|
||||
examples/splitwise/start_mixed.sh (new file, 71 lines)
@@ -0,0 +1,71 @@
#!/bin/bash
set -e

wait_for_health() {
    local server_port=$1
    while true; do
        status_code=$(curl -s -o /dev/null -w "%{http_code}" "http://0.0.0.0:${server_port}/health" || echo "000")
        if [ "$status_code" -eq 200 ]; then
            break
        else
            echo "Service not ready. Retrying in 2s..."
            sleep 2
        fi
    done
}

# prepare environment
MODEL_NAME="PaddlePaddle/ERNIE-4.5-0.3B-Paddle"
# MODEL_NAME="baidu/ERNIE-4.5-21B-A3B-Paddle"

export FD_DEBUG=1
export ENABLE_V1_KVCACHE_SCHEDULER=0
export KVCACHE_GDRCOPY_FLUSH_ENABLE=1

unset http_proxy && unset https_proxy
rm -rf log_*

# start router
export FD_LOG_DIR="log_router"
mkdir -p ${FD_LOG_DIR}

router_port=9000
nohup python -m fastdeploy.router.launch \
    --port ${router_port} \
    2>&1 >${FD_LOG_DIR}/nohup &
sleep 1

# start modelserver 0
export CUDA_VISIBLE_DEVICES=0
export FD_LOG_DIR="log_server_0"
mkdir -p ${FD_LOG_DIR}

nohup python -m fastdeploy.entrypoints.openai.api_server \
    --model ${MODEL_NAME} \
    --port 8100 \
    --metrics-port 8101 \
    --engine-worker-queue-port 8102 \
    --cache-queue-port 8103 \
    --max-model-len 32768 \
    --router "0.0.0.0:${router_port}" \
    2>&1 >${FD_LOG_DIR}/nohup &
sleep 1

wait_for_health 8100

# start modelserver 1
export CUDA_VISIBLE_DEVICES=1
export FD_LOG_DIR="log_server_1"
mkdir -p ${FD_LOG_DIR}

nohup python -m fastdeploy.entrypoints.openai.api_server \
    --model ${MODEL_NAME} \
    --port 8200 \
    --metrics-port 8201 \
    --engine-worker-queue-port 8202 \
    --cache-queue-port 8203 \
    --max-model-len 32768 \
    --router "0.0.0.0:${router_port}" \
    2>&1 >${FD_LOG_DIR}/nohup &

wait_for_health 8200
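A small Python counterpart of the `wait_for_health` helper can be handy when driving these scripts from a test harness. This is a sketch assuming the two model servers from `start_mixed.sh` listen on ports 8100 and 8200:

```python
import time
import requests

def wait_for_health(port: int, timeout_s: int = 600) -> None:
    # Poll the /health endpoint, mirroring the shell helper in the script above
    deadline = time.time() + timeout_s
    while time.time() < deadline:
        try:
            if requests.get(f"http://0.0.0.0:{port}/health", timeout=2).status_code == 200:
                return
        except requests.RequestException:
            pass
        print("Service not ready. Retrying in 2s...")
        time.sleep(2)
    raise TimeoutError(f"server on port {port} did not become healthy")

for port in (8100, 8200):
    wait_for_health(port)
```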
examples/splitwise/start_v0_tp1.sh (new file, 66 lines)
@@ -0,0 +1,66 @@
|
||||
#!/bin/bash
|
||||
set -e
|
||||
|
||||
# Test splitwise deployment
|
||||
# v0 requires prefill and decode in one node and it uses local scheduler
|
||||
# v1 supports prefill and decode in multi node and it uses splitwise scheduler
|
||||
# v2 supports prefill and decode in multi node and it uses router and local scheduler
|
||||
|
||||
wait_for_health() {
|
||||
local server_port=$1
|
||||
while true; do
|
||||
status_code=$(curl -s -o /dev/null -w "%{http_code}" "http://0.0.0.0:${server_port}/health" || echo "000")
|
||||
if [ "$status_code" -eq 200 ]; then
|
||||
break
|
||||
else
|
||||
echo "Service not ready. Retrying in 2s..."
|
||||
sleep 2
|
||||
fi
|
||||
done
|
||||
}
|
||||
|
||||
MODEL_NAME="PaddlePaddle/ERNIE-4.5-0.3B-Paddle"
|
||||
# MODEL_NAME="baidu/ERNIE-4.5-21B-A3B-Paddle"
|
||||
aistudio download --model ${MODEL_NAME}
|
||||
|
||||
unset http_proxy && unset https_proxy
|
||||
rm -rf log_*
|
||||
|
||||
# start prefill
|
||||
export FD_LOG_DIR="log_prefill"
|
||||
mkdir -p ${FD_LOG_DIR}
|
||||
|
||||
export CUDA_VISIBLE_DEVICES=0
|
||||
export FD_DEBUG=1
|
||||
export ENABLE_V1_KVCACHE_SCHEDULER=0
|
||||
|
||||
nohup python -m fastdeploy.entrypoints.openai.api_server \
|
||||
--model ${MODEL_NAME} \
|
||||
--port 8100 \
|
||||
--metrics-port 8101 \
|
||||
--engine-worker-queue-port 8102 \
|
||||
--cache-queue-port 8103 \
|
||||
--max-model-len 32768 \
|
||||
--splitwise-role "prefill" \
|
||||
2>&1 >${FD_LOG_DIR}/nohup &
|
||||
wait_for_health 8100
|
||||
|
||||
# start decode
|
||||
export FD_LOG_DIR="log_decode"
|
||||
mkdir -p ${FD_LOG_DIR}
|
||||
|
||||
export CUDA_VISIBLE_DEVICES=1
|
||||
export FD_DEBUG=1
|
||||
export ENABLE_V1_KVCACHE_SCHEDULER=0
|
||||
|
||||
nohup python -m fastdeploy.entrypoints.openai.api_server \
|
||||
--model ${MODEL_NAME} \
|
||||
--port 9000 \
|
||||
--metrics-port 9001 \
|
||||
--engine-worker-queue-port 9002 \
|
||||
--cache-queue-port 9003 \
|
||||
--max-model-len 32768 \
|
||||
--splitwise-role "decode" \
|
||||
--innode-prefill-ports 8102 \
|
||||
2>&1 >${FD_LOG_DIR}/nohup &
|
||||
wait_for_health 9000
|
||||
examples/splitwise/start_v1_tp1.sh (new file, 96 lines)
@@ -0,0 +1,96 @@
|
||||
#!/bin/bash
|
||||
set -e
|
||||
|
||||
# Test splitwise deployment
|
||||
# v0 requires prefill and decode in one node and it uses local scheduler
|
||||
# v1 supports prefill and decode in multi node and it uses splitwise scheduler
|
||||
# v2 supports prefill and decode in multi node and it uses router and local scheduler
|
||||
|
||||
wait_for_health() {
|
||||
local server_port=$1
|
||||
while true; do
|
||||
status_code=$(curl -s -o /dev/null -w "%{http_code}" "http://0.0.0.0:${server_port}/health" || echo "000")
|
||||
if [ "$status_code" -eq 200 ]; then
|
||||
break
|
||||
else
|
||||
echo "Service not ready. Retrying in 2s..."
|
||||
sleep 2
|
||||
fi
|
||||
done
|
||||
}
|
||||
|
||||
# prepare environment
|
||||
MODEL_NAME="PaddlePaddle/ERNIE-4.5-0.3B-Paddle"
|
||||
# MODEL_NAME="baidu/ERNIE-4.5-21B-A3B-Paddle"
|
||||
|
||||
export FD_DEBUG=1
|
||||
export ENABLE_V1_KVCACHE_SCHEDULER=0
|
||||
export KVCACHE_GDRCOPY_FLUSH_ENABLE=1
|
||||
|
||||
SCRIPT_PATH=$(readlink -f "$0")
|
||||
SCRIPT_DIR=$(dirname "$SCRIPT_PATH")
|
||||
export $(bash ${SCRIPT_DIR}/../../scripts/get_rdma_nics.sh gpu)
|
||||
echo "KVCACHE_RDMA_NICS:${KVCACHE_RDMA_NICS}"
|
||||
if [ -z "${KVCACHE_RDMA_NICS}" ]; then
|
||||
echo "KVCACHE_RDMA_NICS is empty, please check the output of get_rdma_nics.sh"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
unset http_proxy && unset https_proxy
|
||||
rm -rf log_*
|
||||
|
||||
# start redis
|
||||
if ! redis-cli ping &>/dev/null; then
|
||||
echo "Redis is not running. Starting redis-server..."
|
||||
redis-server --daemonize yes
|
||||
sleep 1
|
||||
else
|
||||
echo "Redis is already running."
|
||||
fi
|
||||
sleep 1
|
||||
|
||||
# start prefill
|
||||
export CUDA_VISIBLE_DEVICES=0
|
||||
export FD_LOG_DIR="log_prefill"
|
||||
mkdir -p ${FD_LOG_DIR}
|
||||
|
||||
nohup python -m fastdeploy.entrypoints.openai.api_server \
|
||||
--model ${MODEL_NAME} \
|
||||
--port 8100 \
|
||||
--metrics-port 8101 \
|
||||
--engine-worker-queue-port 8102 \
|
||||
--cache-queue-port 8103 \
|
||||
--max-model-len 32768 \
|
||||
--splitwise-role "prefill" \
|
||||
--cache-transfer-protocol "rdma,ipc" \
|
||||
--rdma-comm-ports 8104 \
|
||||
--pd-comm-port 8105 \
|
||||
--scheduler-name "splitwise" \
|
||||
--scheduler-host "127.0.0.1" \
|
||||
--scheduler-port 6379 \
|
||||
--scheduler-ttl 9000 \
|
||||
2>&1 >${FD_LOG_DIR}/nohup &
|
||||
wait_for_health 8100
|
||||
|
||||
# start decode
|
||||
export CUDA_VISIBLE_DEVICES=1
|
||||
export FD_LOG_DIR="log_decode"
|
||||
mkdir -p ${FD_LOG_DIR}
|
||||
|
||||
nohup python -m fastdeploy.entrypoints.openai.api_server \
|
||||
--model ${MODEL_NAME} \
|
||||
--port 9000 \
|
||||
--metrics-port 9001 \
|
||||
--engine-worker-queue-port 9002 \
|
||||
--cache-queue-port 9003 \
|
||||
--max-model-len 32768 \
|
||||
--splitwise-role "decode" \
|
||||
--cache-transfer-protocol "rdma,ipc" \
|
||||
--rdma-comm-ports 9004 \
|
||||
--pd-comm-port 9005 \
|
||||
--scheduler-name "splitwise" \
|
||||
--scheduler-host "127.0.0.1" \
|
||||
--scheduler-port 6379 \
|
||||
--scheduler-ttl 9000 \
|
||||
2>&1 >${FD_LOG_DIR}/nohup &
|
||||
wait_for_health 9000
|
||||
examples/splitwise/start_v1_tp2.sh (new file, 98 lines)
@@ -0,0 +1,98 @@
|
||||
#!/bin/bash
|
||||
set -e
|
||||
|
||||
# Test splitwise deployment
|
||||
# v0 requires prefill and decode in one node and it uses local scheduler
|
||||
# v1 supports prefill and decode in multi node and it uses splitwise scheduler
|
||||
# v2 supports prefill and decode in multi node and it uses router and local scheduler
|
||||
|
||||
wait_for_health() {
|
||||
local server_port=$1
|
||||
while true; do
|
||||
status_code=$(curl -s -o /dev/null -w "%{http_code}" "http://0.0.0.0:${server_port}/health" || echo "000")
|
||||
if [ "$status_code" -eq 200 ]; then
|
||||
break
|
||||
else
|
||||
echo "Service not ready. Retrying in 2s..."
|
||||
sleep 2
|
||||
fi
|
||||
done
|
||||
}
|
||||
|
||||
# prepare environment
|
||||
MODEL_NAME="PaddlePaddle/ERNIE-4.5-0.3B-Paddle"
|
||||
# MODEL_NAME="baidu/ERNIE-4.5-21B-A3B-Paddle"
|
||||
|
||||
export FD_DEBUG=1
|
||||
export ENABLE_V1_KVCACHE_SCHEDULER=0
|
||||
export KVCACHE_GDRCOPY_FLUSH_ENABLE=1
|
||||
|
||||
SCRIPT_PATH=$(readlink -f "$0")
|
||||
SCRIPT_DIR=$(dirname "$SCRIPT_PATH")
|
||||
export $(bash ${SCRIPT_DIR}/../../scripts/get_rdma_nics.sh gpu)
|
||||
echo "KVCACHE_RDMA_NICS:${KVCACHE_RDMA_NICS}"
|
||||
if [ -z "${KVCACHE_RDMA_NICS}" ]; then
|
||||
echo "KVCACHE_RDMA_NICS is empty, please check the output of get_rdma_nics.sh"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
unset http_proxy && unset https_proxy
|
||||
rm -rf log_*
|
||||
|
||||
# start redis
|
||||
if ! redis-cli ping &>/dev/null; then
|
||||
echo "Redis is not running. Starting redis-server..."
|
||||
redis-server --daemonize yes
|
||||
sleep 1
|
||||
else
|
||||
echo "Redis is already running."
|
||||
fi
|
||||
sleep 1
|
||||
|
||||
# start prefill
|
||||
export CUDA_VISIBLE_DEVICES=0,1
|
||||
export FD_LOG_DIR="log_prefill"
|
||||
mkdir -p ${FD_LOG_DIR}
|
||||
|
||||
nohup python -m fastdeploy.entrypoints.openai.api_server \
|
||||
--model ${MODEL_NAME} \
|
||||
--port 8100 \
|
||||
--metrics-port 8101 \
|
||||
--engine-worker-queue-port 8102 \
|
||||
--cache-queue-port 8103 \
|
||||
--max-model-len 32768 \
|
||||
--tensor-parallel-size 2 \
|
||||
--splitwise-role "prefill" \
|
||||
--cache-transfer-protocol "rdma,ipc" \
|
||||
--pd-comm-port 8104 \
|
||||
--rdma-comm-ports 8105,8106 \
|
||||
--scheduler-name "splitwise" \
|
||||
--scheduler-host "127.0.0.1" \
|
||||
--scheduler-port 6379 \
|
||||
--scheduler-ttl 9000 \
|
||||
2>&1 >${FD_LOG_DIR}/nohup &
|
||||
wait_for_health 8100
|
||||
|
||||
# start decode
|
||||
export CUDA_VISIBLE_DEVICES=2,3
|
||||
export FD_LOG_DIR="log_decode"
|
||||
mkdir -p ${FD_LOG_DIR}
|
||||
|
||||
nohup python -m fastdeploy.entrypoints.openai.api_server \
|
||||
--model ${MODEL_NAME} \
|
||||
--port 9000 \
|
||||
--metrics-port 9001 \
|
||||
--engine-worker-queue-port 9002 \
|
||||
--cache-queue-port 9003 \
|
||||
--max-model-len 32768 \
|
||||
--tensor-parallel-size 2 \
|
||||
--splitwise-role "decode" \
|
||||
--cache-transfer-protocol "rdma,ipc" \
|
||||
--pd-comm-port 9004 \
|
||||
--rdma-comm-ports 9005,9006 \
|
||||
--scheduler-name "splitwise" \
|
||||
--scheduler-host "127.0.0.1" \
|
||||
--scheduler-port 6379 \
|
||||
--scheduler-ttl 9000 \
|
||||
2>&1 >${FD_LOG_DIR}/nohup &
|
||||
wait_for_health 9000
|
||||
examples/splitwise/start_v2_tp1.sh (new file, 93 lines)
@@ -0,0 +1,93 @@
|
||||
#!/bin/bash
|
||||
set -e
|
||||
|
||||
# Test splitwise deployment
|
||||
# v0 requires prefill and decode in one node and it uses local scheduler
|
||||
# v1 supports prefill and decode in multi node and it uses splitwise scheduler
|
||||
# v2 supports prefill and decode in multi node and it uses router and local scheduler
|
||||
|
||||
wait_for_health() {
|
||||
local server_port=$1
|
||||
while true; do
|
||||
status_code=$(curl -s -o /dev/null -w "%{http_code}" "http://0.0.0.0:${server_port}/health" || echo "000")
|
||||
if [ "$status_code" -eq 200 ]; then
|
||||
break
|
||||
else
|
||||
echo "Service not ready. Retrying in 2s..."
|
||||
sleep 2
|
||||
fi
|
||||
done
|
||||
}
|
||||
|
||||
# prepare environment
|
||||
MODEL_NAME="PaddlePaddle/ERNIE-4.5-0.3B-Paddle"
|
||||
# MODEL_NAME="baidu/ERNIE-4.5-21B-A3B-Paddle"
|
||||
|
||||
export FD_DEBUG=1
|
||||
export ENABLE_V1_KVCACHE_SCHEDULER=0
|
||||
export KVCACHE_GDRCOPY_FLUSH_ENABLE=1
|
||||
|
||||
SCRIPT_PATH=$(readlink -f "$0")
|
||||
SCRIPT_DIR=$(dirname "$SCRIPT_PATH")
|
||||
export $(bash ${SCRIPT_DIR}/../../scripts/get_rdma_nics.sh gpu)
|
||||
echo "KVCACHE_RDMA_NICS:${KVCACHE_RDMA_NICS}"
|
||||
if [ -z "${KVCACHE_RDMA_NICS}" ]; then
|
||||
echo "KVCACHE_RDMA_NICS is empty, please check the output of get_rdma_nics.sh"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
unset http_proxy && unset https_proxy
|
||||
rm -rf log_*
|
||||
|
||||
# start router
|
||||
export FD_LOG_DIR="log_router"
|
||||
mkdir -p ${FD_LOG_DIR}
|
||||
|
||||
router_port=9000
|
||||
nohup python -m fastdeploy.router.launch \
|
||||
--port ${router_port} \
|
||||
--splitwise \
|
||||
2>&1 >${FD_LOG_DIR}/nohup &
|
||||
sleep 1
|
||||
|
||||
# start prefill
|
||||
export CUDA_VISIBLE_DEVICES=0
|
||||
export FD_LOG_DIR="log_prefill"
|
||||
mkdir -p ${FD_LOG_DIR}
|
||||
|
||||
nohup python -m fastdeploy.entrypoints.openai.api_server \
|
||||
--model ${MODEL_NAME} \
|
||||
--port 8100 \
|
||||
--metrics-port 8101 \
|
||||
--engine-worker-queue-port 8102 \
|
||||
--cache-queue-port 8103 \
|
||||
--max-model-len 32768 \
|
||||
--splitwise-role "prefill" \
|
||||
--cache-transfer-protocol "ipc,rdma" \
|
||||
--rdma-comm-ports 8104 \
|
||||
--pd-comm-port 8105 \
|
||||
--router "0.0.0.0:${router_port}" \
|
||||
2>&1 >${FD_LOG_DIR}/nohup &
|
||||
|
||||
wait_for_health 8100
|
||||
|
||||
# start decode
|
||||
export CUDA_VISIBLE_DEVICES=1
|
||||
export FD_LOG_DIR="log_decode"
|
||||
mkdir -p ${FD_LOG_DIR}
|
||||
|
||||
nohup python -m fastdeploy.entrypoints.openai.api_server \
|
||||
--model ${MODEL_NAME} \
|
||||
--port 8200 \
|
||||
--metrics-port 8201 \
|
||||
--engine-worker-queue-port 8202 \
|
||||
--cache-queue-port 8203 \
|
||||
--max-model-len 32768 \
|
||||
--splitwise-role "decode" \
|
||||
--cache-transfer-protocol "ipc,rdma" \
|
||||
--rdma-comm-ports 8204 \
|
||||
--pd-comm-port 8205 \
|
||||
--router "0.0.0.0:${router_port}" \
|
||||
2>&1 >${FD_LOG_DIR}/nohup &
|
||||
|
||||
wait_for_health 8200
|
||||
examples/splitwise/start_v2_tp2.sh (new file, 96 lines)
@@ -0,0 +1,96 @@
|
||||
#!/bin/bash
|
||||
set -e
|
||||
|
||||
# Test splitwise deployment
|
||||
# v0 requires prefill and decode in one node and it uses local scheduler
|
||||
# v1 supports prefill and decode in multi node and it uses splitwise scheduler
|
||||
# v2 supports prefill and decode in multi node and it uses router and local scheduler
|
||||
|
||||
wait_for_health() {
|
||||
local server_port=$1
|
||||
while true; do
|
||||
status_code=$(curl -s -o /dev/null -w "%{http_code}" "http://0.0.0.0:${server_port}/health" || echo "000")
|
||||
if [ "$status_code" -eq 200 ]; then
|
||||
break
|
||||
else
|
||||
echo "Service not ready. Retrying in 2s..."
|
||||
sleep 2
|
||||
fi
|
||||
done
|
||||
}
|
||||
|
||||
# prepare environment
|
||||
MODEL_NAME="PaddlePaddle/ERNIE-4.5-0.3B-Paddle"
|
||||
# MODEL_NAME="baidu/ERNIE-4.5-21B-A3B-Paddle"
|
||||
|
||||
export FD_DEBUG=1
|
||||
export ENABLE_V1_KVCACHE_SCHEDULER=0
|
||||
export KVCACHE_GDRCOPY_FLUSH_ENABLE=1
|
||||
|
||||
SCRIPT_PATH=$(readlink -f "$0")
|
||||
SCRIPT_DIR=$(dirname "$SCRIPT_PATH")
|
||||
export $(bash ${SCRIPT_DIR}/../../scripts/get_rdma_nics.sh gpu)
|
||||
echo "KVCACHE_RDMA_NICS:${KVCACHE_RDMA_NICS}"
|
||||
if [ -z "${KVCACHE_RDMA_NICS}" ]; then
|
||||
echo "KVCACHE_RDMA_NICS is empty, please check the output of get_rdma_nics.sh"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
unset http_proxy && unset https_proxy
|
||||
rm -rf log_*
|
||||
|
||||
# start router
|
||||
export FD_LOG_DIR="log_router"
|
||||
mkdir -p ${FD_LOG_DIR}
|
||||
|
||||
echo "start router"
|
||||
router_port=9000
|
||||
nohup python -m fastdeploy.router.launch \
|
||||
--port ${router_port} \
|
||||
--splitwise \
|
||||
2>&1 >${FD_LOG_DIR}/nohup &
|
||||
sleep 1
|
||||
|
||||
# start prefill
|
||||
export CUDA_VISIBLE_DEVICES=0,1
|
||||
export FD_LOG_DIR="log_prefill"
|
||||
mkdir -p ${FD_LOG_DIR}
|
||||
|
||||
echo "start prefill"
|
||||
nohup python -m fastdeploy.entrypoints.openai.api_server \
|
||||
--model ${MODEL_NAME} \
|
||||
--port 8100 \
|
||||
--metrics-port 8101 \
|
||||
--engine-worker-queue-port 8102 \
|
||||
--cache-queue-port 8103 \
|
||||
--tensor-parallel-size 2 \
|
||||
--max-model-len 32768 \
|
||||
--splitwise-role "prefill" \
|
||||
--pd-comm-port 8104 \
|
||||
--rdma-comm-ports 8105,8106 \
|
||||
--router "0.0.0.0:${router_port}" \
|
||||
2>&1 >${FD_LOG_DIR}/nohup &
|
||||
|
||||
wait_for_health 8100
|
||||
|
||||
# start decode
|
||||
export CUDA_VISIBLE_DEVICES=2,3
|
||||
export FD_LOG_DIR="log_decode"
|
||||
mkdir -p ${FD_LOG_DIR}
|
||||
|
||||
echo "start decode"
|
||||
nohup python -m fastdeploy.entrypoints.openai.api_server \
|
||||
--model ${MODEL_NAME} \
|
||||
--port 8200 \
|
||||
--metrics-port 8201 \
|
||||
--engine-worker-queue-port 8202 \
|
||||
--cache-queue-port 8203 \
|
||||
--max-model-len 32768 \
|
||||
--tensor-parallel-size 2 \
|
||||
--splitwise-role "decode" \
|
||||
--pd-comm-port 8204 \
|
||||
--rdma-comm-ports 8205,8206 \
|
||||
--router "0.0.0.0:${router_port}" \
|
||||
2>&1 >${FD_LOG_DIR}/nohup &
|
||||
|
||||
wait_for_health 8200
|
||||
examples/splitwise/stop.sh (new file, 7 lines)
@@ -0,0 +1,7 @@
pkill -9 -f python
pkill -9 -f fastdeploy
pkill -9 -f gunicorn

if redis-cli ping >/dev/null 2>&1; then
    redis-cli shutdown
fi
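After running `stop.sh`, a quick way to confirm the example ports have been released is to probe them; the port list below is an assumption taken from the other example scripts, not from `stop.sh` itself:

```python
import socket

for port in (9000, 8100, 8200):
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.settimeout(0.5)
        in_use = s.connect_ex(("127.0.0.1", port)) == 0  # 0 means something still listens
        print(f"port {port}: {'still in use' if in_use else 'free'}")
```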
examples/splitwise/test.sh (new file, 20 lines)
@@ -0,0 +1,20 @@
#!/bin/bash

# With the v0 version, the request must be sent to the decode instance
# With the v1 version, the request can be sent to either the prefill or the decode instance
# With the v2 version, the request must be sent to the router

port=${1:-9000}
echo "port: ${port}"

unset http_proxy && unset https_proxy

curl -X POST "http://0.0.0.0:${port}/v1/chat/completions" \
    -H "Content-Type: application/json" \
    -d '{
        "messages": [
            {"role": "user", "content": "Introduce Shenzhen"}
        ],
        "max_tokens": 20,
        "stream": true
    }'
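The same request can be issued from Python, which is convenient for scripted checks; a sketch assuming the router (or instance) from the scripts above is listening on port 9000:

```python
import requests

resp = requests.post(
    "http://0.0.0.0:9000/v1/chat/completions",
    json={
        "messages": [{"role": "user", "content": "Introduce Shenzhen"}],
        "max_tokens": 20,
        "stream": True,
    },
    stream=True,
    timeout=1800,
)
for line in resp.iter_lines():
    if line:
        print(line.decode("utf-8"))
```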
@@ -344,6 +344,8 @@ class CacheMessager:
|
||||
)
|
||||
item["layer_idx"] = current_layer_idx
|
||||
if item["layer_idx"] == self.num_layers:
|
||||
if "error" not in item["status"]:
|
||||
item["status"] = "finished"
|
||||
if item["transfer_protocol"] == "ipc":
|
||||
self.messager["ipc"].write_block_by_sync(target_id)
|
||||
logger.info(f"finish write cache {item['request_id']}")
|
||||
@@ -359,7 +361,7 @@ class CacheMessager:
|
||||
def _handle_connect_task(self):
|
||||
while True:
|
||||
try:
|
||||
task = self.engine_worker_queue.get_connect_rdma_task()
|
||||
task, _ = self.engine_worker_queue.get_connect_rdma_task()
|
||||
if task is None:
|
||||
time.sleep(0.001)
|
||||
continue
|
||||
@@ -376,7 +378,8 @@ class CacheMessager:
|
||||
self.engine_worker_queue.connect_task_response_barrier.wait()
|
||||
self.engine_worker_queue.put_connect_rdma_task_response(response)
|
||||
except Exception as e:
|
||||
logger.error(f"handle_connect_task has exception: {e}")
|
||||
time.sleep(0.001)
|
||||
logger.error(f"handle_connect_task has exception: {e}, {str(traceback.format_exc())}")
|
||||
|
||||
|
||||
class CacheMessagerV1:
|
||||
|
||||
@@ -1310,6 +1310,24 @@ class CacheConfig:
        logger.info("=============================================================")


class RouterConfig:
    """
    Configuration for the router.
    Attributes:
        router: the URL of the router, e.g. http://127.0.0.1:8000
        api_server_host: the host IP of the model server
        api_server_port: the HTTP port of the model server
    """

    def __init__(self, args: dict):
        self.router = args["router"]
        if self.router is not None and not self.router.startswith(("http://", "https://")):
            self.router = f"http://{self.router}"

        self.api_server_host = get_host_ip()
        self.api_server_port = args["port"]

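A short sketch of how `RouterConfig` normalizes the `--router` value; the argument dict below is illustrative, and `get_host_ip()` runs inside the constructor, so this is meant to execute on the serving host:

```python
from fastdeploy.config import RouterConfig

cfg = RouterConfig({"router": "0.0.0.0:9000", "port": 8100})
print(cfg.router)           # -> http://0.0.0.0:9000 (scheme added when missing)
print(cfg.api_server_port)  # -> 8100
```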
class CommitConfig:
    """
    Configuration for tracking version information from version.txt
@@ -1411,6 +1429,7 @@ class FDConfig:
|
||||
speculative_config: SpeculativeConfig = None,
|
||||
eplb_config: EPLBConfig = None,
|
||||
structured_outputs_config: StructuredOutputsConfig = None,
|
||||
router_config: RouterConfig = None,
|
||||
tokenizer: str = None,
|
||||
ips: str = None,
|
||||
use_warmup: bool = False,
|
||||
@@ -1438,6 +1457,7 @@ class FDConfig:
|
||||
self.cache_config: CacheConfig = cache_config # type: ignore
|
||||
self.plas_attention_config: Optional[PlasAttentionConfig] = plas_attention_config
|
||||
self.structured_outputs_config: StructuredOutputsConfig = structured_outputs_config
|
||||
self.router_config: RouterConfig = router_config
|
||||
|
||||
# Initialize cuda graph capture list
|
||||
max_capture_shape = self.scheduler_config.max_num_seqs
|
||||
@@ -1517,6 +1537,7 @@ class FDConfig:
|
||||
|
||||
self.read_from_config()
|
||||
self.postprocess()
|
||||
self.init_cache_info()
|
||||
if test_mode:
|
||||
return
|
||||
self.check()
|
||||
@@ -1734,29 +1755,66 @@ class FDConfig:
        """
        initialize cache info
        """
        disaggregate_info = {}
        # TODO: group the splitwise params, remove code of v0
        # v0 requires prefill and decode in one node and it uses local scheduler
        # v1 supports prefill and decode in multi node and it uses splitwise or dp scheduler
        # v2 supports prefill and decode in multi node and it uses router and local scheduler
        self.splitwise_version = None
        if self.scheduler_config.name == "local" and (self.router_config is None or self.router_config.router is None):
            self.splitwise_version = "v0"
        elif self.scheduler_config.name in ("splitwise", "dp"):
            self.splitwise_version = "v1"
        elif self.scheduler_config.name == "local" and self.router_config and self.router_config.router:
            self.splitwise_version = "v2"
        else:
            raise ValueError(
                f"Unsupported scheduler mode, scheduler_name: {self.scheduler_config.name}, "
                f"router_config: {self.router_config}"
            )
        logger.info(f"splitwise_version: {self.splitwise_version}")

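The version selection above reduces to a small decision table; a standalone sketch of the same rules follows (function name and arguments are illustrative, not part of FastDeploy):

```python
def pick_splitwise_version(scheduler_name: str, router_url) -> str:
    # Same rules as FDConfig.init_cache_info above
    if scheduler_name == "local" and not router_url:
        return "v0"  # prefill and decode in one node, local scheduler
    if scheduler_name in ("splitwise", "dp"):
        return "v1"  # multi-node, splitwise or dp scheduler
    if scheduler_name == "local" and router_url:
        return "v2"  # multi-node, router plus local scheduler
    raise ValueError(f"Unsupported scheduler mode: {scheduler_name}, router: {router_url}")

assert pick_splitwise_version("local", None) == "v0"
assert pick_splitwise_version("splitwise", None) == "v1"
assert pick_splitwise_version("local", "http://0.0.0.0:9000") == "v2"
```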
if isinstance(self.parallel_config.engine_worker_queue_port, (int, str)):
|
||||
engine_worker_queue_port = self.parallel_config.engine_worker_queue_port
|
||||
else:
|
||||
engine_worker_queue_port = self.parallel_config.engine_worker_queue_port[
|
||||
self.parallel_config.local_data_parallel_id
|
||||
]
|
||||
connector_port = self.cache_config.pd_comm_port[0] if self.cache_config.pd_comm_port else None
|
||||
|
||||
self.disaggregate_info = {}
|
||||
if self.scheduler_config.splitwise_role != "mixed":
|
||||
disaggregate_info["role"] = self.scheduler_config.splitwise_role
|
||||
disaggregate_info["cache_info"] = dict()
|
||||
self.disaggregate_info["role"] = self.scheduler_config.splitwise_role
|
||||
self.disaggregate_info["cache_info"] = dict()
|
||||
current_protocol = self.cache_config.cache_transfer_protocol.split(",")
|
||||
disaggregate_info["transfer_protocol"] = current_protocol
|
||||
self.disaggregate_info["transfer_protocol"] = current_protocol
|
||||
|
||||
for protocol in current_protocol:
|
||||
if protocol == "ipc":
|
||||
disaggregate_info["cache_info"][protocol] = {
|
||||
self.disaggregate_info["cache_info"][protocol] = {
|
||||
"ip": self.host_ip,
|
||||
"port": self.parallel_config.engine_worker_queue_port[
|
||||
self.parallel_config.local_data_parallel_id
|
||||
],
|
||||
"port": engine_worker_queue_port,
|
||||
"device_ids": self.local_device_ids,
|
||||
}
|
||||
elif protocol == "rdma":
|
||||
disaggregate_info["cache_info"][protocol] = {
|
||||
self.disaggregate_info["cache_info"][protocol] = {
|
||||
"ip": self.host_ip,
|
||||
"port": self.cache_config.pd_comm_port[0],
|
||||
"port": connector_port,
|
||||
"rdma_port": self.cache_config.rdma_comm_ports,
|
||||
}
|
||||
self.disaggregate_info = disaggregate_info
|
||||
logger.info(f"disaggregate_info: {self.disaggregate_info}")
|
||||
logger.info(f"disaggregate_info: {self.disaggregate_info}")
|
||||
|
||||
if self.router_config:
|
||||
self.register_info = {
|
||||
"role": self.scheduler_config.splitwise_role,
|
||||
"host_ip": self.host_ip,
|
||||
"port": self.router_config.api_server_port,
|
||||
"connector_port": connector_port,
|
||||
"rdma_ports": self.cache_config.rdma_comm_ports,
|
||||
"engine_worker_queue_port": engine_worker_queue_port,
|
||||
"device_ids": self.local_device_ids,
|
||||
"transfer_protocol": self.cache_config.cache_transfer_protocol.split(","),
|
||||
}
|
||||
logger.info(f"register_info: {self.register_info}")
|
||||
|
||||
def read_from_config(self):
|
||||
"""
|
||||
|
||||
@@ -34,6 +34,7 @@ from fastdeploy.config import (
|
||||
ParallelConfig,
|
||||
PlasAttentionConfig,
|
||||
PoolerConfig,
|
||||
RouterConfig,
|
||||
RunnerOption,
|
||||
SpeculativeConfig,
|
||||
StructuredOutputsConfig,
|
||||
@@ -74,6 +75,10 @@ class EngineArgs:
|
||||
"""
|
||||
The name or path of the model to be used.
|
||||
"""
|
||||
port: Optional[str] = None
|
||||
"""
|
||||
Port for api server.
|
||||
"""
|
||||
served_model_name: Optional[str] = None
|
||||
"""
|
||||
The name of the model being served.
|
||||
@@ -445,6 +450,11 @@ class EngineArgs:
|
||||
- To enable custom logits processors, add your dotted paths to module and class names to the list.
|
||||
"""
|
||||
|
||||
router: Optional[str] = None
|
||||
"""
|
||||
Url for router server, such as `0.0.0.0:30000`.
|
||||
"""
|
||||
|
||||
def __post_init__(self):
|
||||
"""
|
||||
Post-initialization processing to set default tokenizer if not provided.
|
||||
@@ -859,21 +869,6 @@ class EngineArgs:
|
||||
help="Flag to enable prefix caching.",
|
||||
)
|
||||
|
||||
perf_group.add_argument(
|
||||
"--splitwise-role",
|
||||
type=str,
|
||||
default=EngineArgs.splitwise_role,
|
||||
help="Role of splitwise. Default is \
|
||||
'mixed'. (prefill, decode, mixed)",
|
||||
)
|
||||
|
||||
perf_group.add_argument(
|
||||
"--innode-prefill-ports",
|
||||
type=lambda s: s.split(",") if s else None,
|
||||
default=EngineArgs.innode_prefill_ports,
|
||||
help="port for innode prefill",
|
||||
)
|
||||
|
||||
perf_group.add_argument(
|
||||
"--enable-chunked-prefill",
|
||||
action="store_true",
|
||||
@@ -903,27 +898,53 @@ class EngineArgs:
|
||||
help=("For chunked prefill, the threshold number of" " tokens for a prompt to be considered long."),
|
||||
)
|
||||
|
||||
perf_group.add_argument(
|
||||
# Splitwise deployment parameters group
|
||||
splitwise_group = parser.add_argument_group("Splitwise Deployment")
|
||||
splitwise_group.add_argument(
|
||||
"--splitwise-role",
|
||||
type=str,
|
||||
default=EngineArgs.splitwise_role,
|
||||
help="Role of splitwise. Default is \
|
||||
'mixed'. (prefill, decode, mixed)",
|
||||
)
|
||||
|
||||
splitwise_group.add_argument(
|
||||
"--innode-prefill-ports",
|
||||
type=lambda s: s.split(",") if s else None,
|
||||
default=EngineArgs.innode_prefill_ports,
|
||||
help="port for innode prefill, only used in single machine splitwise deployment",
|
||||
)
|
||||
|
||||
splitwise_group.add_argument(
|
||||
"--cache-transfer-protocol",
|
||||
type=str,
|
||||
default=EngineArgs.cache_transfer_protocol,
|
||||
help="support protocol list, comma separated, default is ipc",
|
||||
help="support protocol list (ipc or rdma), comma separated, default is ipc",
|
||||
)
|
||||
|
||||
perf_group.add_argument(
|
||||
splitwise_group.add_argument(
|
||||
"--pd-comm-port",
|
||||
type=lambda s: s.split(",") if s else None,
|
||||
default=EngineArgs.pd_comm_port,
|
||||
help="port for splitwise communication.",
|
||||
)
|
||||
|
||||
perf_group.add_argument(
|
||||
splitwise_group.add_argument(
|
||||
"--rdma-comm-ports",
|
||||
type=lambda s: s.split(",") if s else None,
|
||||
default=EngineArgs.rdma_comm_ports,
|
||||
help="ports for rdma communication.",
|
||||
)
|
||||
|
||||
# Router parameters group
|
||||
router_group = parser.add_argument_group("Router")
|
||||
router_group.add_argument(
|
||||
"--router",
|
||||
type=str,
|
||||
default=EngineArgs.router,
|
||||
help="url for router server.",
|
||||
)
|
||||
|
||||
# Scheduler parameters group
|
||||
scheduler_group = parser.add_argument_group("Scheduler")
|
||||
scheduler_group.add_argument(
|
||||
@@ -1044,7 +1065,11 @@ class EngineArgs:
        """
        Create an instance of EngineArgs from command line arguments.
        """
        return cls(**{field.name: getattr(args, field.name) for field in dataclass_fields(cls)})
        args_dict = {}
        for field in dataclass_fields(cls):
            if hasattr(args, field.name):
                args_dict[field.name] = getattr(args, field.name)
        return cls(**args_dict)

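The rewritten `from_cli_args` only copies attributes that actually exist on the parsed namespace, so CLI-only options do not break dataclass construction. A minimal illustration of the same pattern with hypothetical fields:

```python
from dataclasses import dataclass, fields as dataclass_fields
from types import SimpleNamespace
from typing import Optional

@dataclass
class DemoArgs:  # stand-in for EngineArgs, fields are illustrative
    model: str = ""
    router: Optional[str] = None

ns = SimpleNamespace(model="PaddlePaddle/ERNIE-4.5-0.3B-Paddle", router="0.0.0.0:9000")
kwargs = {f.name: getattr(ns, f.name) for f in dataclass_fields(DemoArgs) if hasattr(ns, f.name)}
print(DemoArgs(**kwargs))
```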
def create_speculative_config(self) -> SpeculativeConfig:
|
||||
""" """
|
||||
@@ -1063,6 +1088,7 @@ class EngineArgs:
|
||||
prefix_len = len(prefix)
|
||||
|
||||
all = asdict(self)
|
||||
all.pop("port") # port and scheduler_port are not the same
|
||||
params = dict()
|
||||
for k, v in all.items():
|
||||
if k[:prefix_len] == prefix:
|
||||
@@ -1151,6 +1177,7 @@ class EngineArgs:
|
||||
scheduler_cfg = self.create_scheduler_config()
|
||||
graph_opt_cfg = self.create_graph_optimization_config()
|
||||
plas_attention_config = self.create_plas_attention_config()
|
||||
router_config = RouterConfig(all_dict)
|
||||
|
||||
early_stop_cfg = self.create_early_stop_config()
|
||||
early_stop_cfg.update_enable_early_stop(self.enable_early_stop)
|
||||
@@ -1170,6 +1197,7 @@ class EngineArgs:
|
||||
speculative_config=speculative_cfg,
|
||||
eplb_config=eplb_cfg,
|
||||
structured_outputs_config=structured_outputs_config,
|
||||
router_config=router_config,
|
||||
ips=self.ips,
|
||||
use_warmup=self.use_warmup,
|
||||
limit_mm_per_prompt=self.limit_mm_per_prompt,
|
||||
|
||||
@@ -23,10 +23,11 @@ import time
|
||||
import traceback
|
||||
import weakref
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import paddle
|
||||
import requests
|
||||
import zmq
|
||||
from opentelemetry import trace
|
||||
|
||||
@@ -45,6 +46,7 @@ from fastdeploy.metrics.metrics import main_process_metrics
|
||||
from fastdeploy.metrics.trace_util import start_span, start_span_request
|
||||
from fastdeploy.model_executor.guided_decoding import schema_checker
|
||||
from fastdeploy.plugins.token_processor import load_token_processor_plugins
|
||||
from fastdeploy.router.utils import check_service_health
|
||||
from fastdeploy.splitwise.internal_adapter_utils import InternalAdapter
|
||||
from fastdeploy.splitwise.splitwise_connector import SplitwiseConnector
|
||||
from fastdeploy.utils import (
|
||||
@@ -95,6 +97,7 @@ class EngineService:
|
||||
self.enable_decode_cache_task = envs.FD_ENABLE_CACHE_TASK == "1"
|
||||
|
||||
if envs.ENABLE_V1_KVCACHE_SCHEDULER:
|
||||
self.llm_logger.info("Use V1 KVCache Scheduler")
|
||||
self.resource_manager = ResourceManagerV1(
|
||||
cfg.scheduler_config.max_num_seqs,
|
||||
cfg,
|
||||
@@ -103,6 +106,7 @@ class EngineService:
|
||||
cfg.parallel_config.local_data_parallel_id,
|
||||
)
|
||||
else:
|
||||
self.llm_logger.info("Use V0 KVCache Scheduler")
|
||||
self.resource_manager = ResourceManager(
|
||||
cfg.scheduler_config.max_num_seqs,
|
||||
cfg,
|
||||
@@ -118,7 +122,6 @@ class EngineService:
|
||||
]
|
||||
|
||||
self.split_connector = SplitwiseConnector(cfg, self.engine_worker_queue, self.resource_manager)
|
||||
self.waiting_requests = []
|
||||
self.token_processor = TokenProcessor(
|
||||
cfg=cfg,
|
||||
cached_generated_tokens=self.scheduler,
|
||||
@@ -149,14 +152,18 @@ class EngineService:
|
||||
def start(self):
|
||||
self.running = True
|
||||
if envs.ENABLE_V1_KVCACHE_SCHEDULER:
|
||||
self.insert_task_to_worker_thread = threading.Thread(target=self._scheduler_task_to_worker_v1, daemon=True)
|
||||
self.insert_task_to_worker_thread = threading.Thread(
|
||||
target=self._schedule_request_to_worker_v1, daemon=True
|
||||
)
|
||||
else:
|
||||
self.insert_task_to_worker_thread = threading.Thread(target=self._insert_task_to_worker, daemon=True)
|
||||
self.insert_task_to_worker_thread = threading.Thread(target=self._schedule_request_to_worker, daemon=True)
|
||||
self.insert_task_to_worker_thread.start()
|
||||
self.token_processor.tasks_queue = self.engine_worker_queue
|
||||
self.token_processor.run()
|
||||
if self.cfg.scheduler_config.splitwise_role != "mixed":
|
||||
self.split_mode_get_tasks()
|
||||
self._process_splitwise_task()
|
||||
|
||||
self._register_to_router()
|
||||
|
||||
def create_data_processor(self):
|
||||
self.input_processor = InputPreprocessor(
|
||||
@@ -313,7 +320,7 @@ class EngineService:
|
||||
local_data_parallel_id=self.cfg.parallel_config.local_data_parallel_id,
|
||||
)
|
||||
|
||||
def insert_tasks(self, tasks, current_id=-1, allocated=False):
|
||||
def insert_tasks(self, tasks: Union[List[Request], List[RequestOutput]], current_id=-1, allocated=False):
|
||||
"""
|
||||
Insert tasks to engine.
|
||||
"""
|
||||
@@ -358,6 +365,7 @@ class EngineService:
|
||||
current_tasks.append(cur_task)
|
||||
if current_tasks:
|
||||
self.engine_worker_queue.put_tasks((current_tasks, self.resource_manager.real_bsz))
|
||||
self.llm_logger.debug(f"put task to engine worker queue, task:{current_tasks}")
|
||||
return True
|
||||
|
||||
self.resource_manager.check_and_free_block_tables()
|
||||
@@ -574,7 +582,7 @@ class EngineService:
|
||||
patch_st += chunk_patch_num
|
||||
request.set("prefill_chunk_info", chunks_info)
|
||||
|
||||
def _insert_task_to_worker(self):
|
||||
def _schedule_request_to_worker(self):
|
||||
"""
|
||||
Insert task to engine thread, monitor scheduler request queue.
|
||||
if the engine has resource, insert task to engine
|
||||
@@ -619,9 +627,12 @@ class EngineService:
|
||||
if len(tasks) == 0:
|
||||
time.sleep(0.001)
|
||||
continue
|
||||
if self.cfg.splitwise_version == "v2" and self.cfg.scheduler_config.splitwise_role == "decode":
|
||||
# tasks in the decode instance will be processed in the _process_splitwise_task thread
|
||||
continue
|
||||
|
||||
llm_logger.debug(f"get tasks from scheduler: {tasks}")
|
||||
if self.cfg.scheduler_config.splitwise_role != "mixed":
|
||||
self.llm_logger.info("Inserting splitwise tasks")
|
||||
self.split_connector.send_splitwise_tasks(tasks, current_id)
|
||||
|
||||
insert_successful = self.insert_tasks(tasks, current_id)
|
||||
@@ -636,7 +647,7 @@ class EngineService:
|
||||
err_msg = f"Error happend while insert task to engine: {e}, {traceback.format_exc()!s}."
|
||||
self.llm_logger.error(err_msg)
|
||||
|
||||
def _scheduler_task_to_worker_v1(self):
|
||||
def _schedule_request_to_worker_v1(self):
|
||||
"""
|
||||
Insert tasks to worker with scheduler v1 (ENABLE_V1_KVCACHE_SCHEDULER=1).
|
||||
"""
|
||||
@@ -664,6 +675,7 @@ class EngineService:
|
||||
max_num_batched_tokens=max_num_batched_tokens,
|
||||
batch=num_prefill_batch,
|
||||
)
|
||||
self.llm_logger.debug(f"get tasks from scheduler: {tasks}")
|
||||
if self.cfg.scheduler_config.splitwise_role != "mixed":
|
||||
need_delete_tasks = []
|
||||
if envs.FD_OFFLINE_PERF_TEST_FOR_PD:
|
||||
@@ -822,6 +834,7 @@ class EngineService:
|
||||
if envs.FD_ENABLE_INTERNAL_ADAPTER:
|
||||
if self.cfg.scheduler_config.splitwise_role == "decode":
|
||||
return
|
||||
|
||||
while self.running:
|
||||
try:
|
||||
block = True if len(added_requests) == 0 else False
|
||||
@@ -975,17 +988,38 @@ class EngineService:
|
||||
except Exception as e:
|
||||
llm_logger.error(f"Unexcepted error happend: {e}, {traceback.format_exc()!s}")
|
||||
|
||||
def split_mode_get_tasks(self):
|
||||
def _process_splitwise_task(self):
|
||||
"""
|
||||
Split mode get tasks
|
||||
Processing tasks from engine worker queue in splitwise deployment.
|
||||
For v0 version, prefill instance gets tasks from engine worker queue.
|
||||
For v1 and v2 version, decode instance gets raw tasks from engine worker queue to preallocate resources,
|
||||
and decode instance gets prefilled tasks from engine worker queue to generate tokens.
|
||||
TODO: unify the communication between decode and prefill instances.
|
||||
"""
|
||||
|
||||
def receiver_loop():
|
||||
waiting_resource_requests = []
|
||||
waiting_ready_tasks = []
|
||||
|
||||
# Waiting for the api_server and scheduler in decode to
|
||||
# receive the request sent by the client
|
||||
def _decode_process_prefilled_task_v0_scheduler(input_tasks):
|
||||
ready_tasks = []
|
||||
waiting_tasks = []
|
||||
for task in input_tasks:
|
||||
if not hasattr(self.scheduler, "has_request") or self.scheduler.has_request(task.request_id):
|
||||
ready_tasks.append(task)
|
||||
else:
|
||||
waiting_tasks.append(task)
|
||||
self.insert_tasks(ready_tasks, allocated=True)
|
||||
if self.cfg.splitwise_version in ("v0", "v2"):
|
||||
self.scheduler.put_results(ready_tasks)
|
||||
return waiting_tasks
|
||||
|
||||
while self.running:
|
||||
try:
|
||||
|
||||
processed_indices = []
|
||||
for idx, task in enumerate(self.waiting_requests):
|
||||
for idx, task in enumerate(waiting_resource_requests):
|
||||
if envs.ENABLE_V1_KVCACHE_SCHEDULER:
|
||||
if self.resource_manager.preallocate_resource_in_d(task):
|
||||
self.llm_logger.info(f"Resource available, processing task {task.request_id}")
|
||||
@@ -1004,21 +1038,27 @@ class EngineService:
|
||||
break
|
||||
|
||||
for idx in sorted(processed_indices, reverse=True):
|
||||
self.waiting_requests.pop(idx)
|
||||
waiting_resource_requests.pop(idx)
|
||||
|
||||
if not self.engine_worker_queue.disaggregate_queue_empty():
|
||||
waiting_ready_tasks = _decode_process_prefilled_task_v0_scheduler(waiting_ready_tasks)
|
||||
|
||||
if self.engine_worker_queue.disaggregate_queue_empty():
|
||||
time.sleep(0.001)
|
||||
else:
|
||||
items = self.engine_worker_queue.get_disaggregated_tasks()
|
||||
for item in items:
|
||||
role = item[0]
|
||||
tasks = item[1]
|
||||
|
||||
# prefill instance gets tasks from engine worker queue
|
||||
if role == "prefill":
|
||||
for task in tasks:
|
||||
task.max_tokens = task.min_tokens = 2
|
||||
self.insert_tasks(tasks)
|
||||
|
||||
# decode instance gets tasks from engine worker queue
|
||||
elif role == "decode":
|
||||
if hasattr(tasks[0], "finished"):
|
||||
if isinstance(tasks[0], RequestOutput):
|
||||
self.llm_logger.debug(f"receive prefilled tasks, {tasks}")
|
||||
if not isinstance(tasks, list):
|
||||
tasks = [tasks]
|
||||
for task in tasks:
|
||||
@@ -1057,13 +1097,12 @@ class EngineService:
|
||||
self.resource_manager.insert_task_for_decoding(task)
|
||||
|
||||
else:
|
||||
self.insert_tasks(tasks, allocated=True)
|
||||
if self.cfg.innode_prefill_ports is not None:
|
||||
self.scheduler.put_results(tasks)
|
||||
else:
|
||||
if len(self.waiting_requests):
|
||||
waiting_ready_tasks.extend(_decode_process_prefilled_task_v0_scheduler(tasks))
|
||||
elif isinstance(tasks[0], Request):
|
||||
self.llm_logger.debug(f"receive tasks to preallocate resource, {tasks}")
|
||||
if len(waiting_resource_requests):
|
||||
self.llm_logger.info(f"Waiting for resource for task {tasks[0].request_id}")
|
||||
self.waiting_requests.extend(tasks)
|
||||
waiting_resource_requests.extend(tasks)
|
||||
else:
|
||||
new_waiting = []
|
||||
for task in tasks:
|
||||
@@ -1087,13 +1126,12 @@ class EngineService:
|
||||
if not self.enable_decode_cache_task:
|
||||
self.split_connector.send_cache_infos(new_waiting, -1)
|
||||
else:
|
||||
self.waiting_requests.extend(new_waiting)
|
||||
waiting_resource_requests.extend(new_waiting)
|
||||
self.llm_logger.info(
|
||||
f"Added {len(new_waiting)} tasks to waiting queue"
|
||||
)
|
||||
|
||||
else:
|
||||
time.sleep(0.001)
|
||||
else:
|
||||
raise ValueError(f"Unsupported task type: {type(tasks[0])}")
|
||||
|
||||
except Exception as e:
|
||||
self.llm_logger.error(f"Error in main loop: {e}")
|
||||
@@ -1130,6 +1168,42 @@ class EngineService:
|
||||
llm_logger.error(f"Clear data error: {e}")
|
||||
return False
|
||||
|
||||
def _register_to_router(self):
|
||||
"""If use router, register this server to router"""
|
||||
timeout = 5
|
||||
sleep_seconds = 10
|
||||
|
||||
def _register():
|
||||
while True:
|
||||
try:
|
||||
time.sleep(sleep_seconds)
|
||||
|
||||
api_server_host = self.cfg.router_config.api_server_host
|
||||
api_server_port = self.cfg.router_config.api_server_port
|
||||
api_server_url = f"http://{api_server_host}:{api_server_port}"
|
||||
if not check_service_health(api_server_url):
|
||||
continue
|
||||
|
||||
router_url = self.cfg.router_config.router
|
||||
resp = requests.post(
|
||||
f"{router_url}/register",
|
||||
json=self.cfg.register_info,
|
||||
timeout=timeout,
|
||||
)
|
||||
if not resp.ok:
|
||||
llm_logger.error(
|
||||
f"Router registration failed: {resp.status_code}, "
|
||||
f"{resp.text}, {self.cfg.register_info}"
|
||||
)
|
||||
except requests.exceptions.RequestException as e:
|
||||
llm_logger.error(f"Register to router request error: {e}")
|
||||
except Exception as e:
|
||||
llm_logger.exception(f"Unexpected error during router registration: {e}")
|
||||
|
||||
if self.cfg.router_config.router is not None:
|
||||
register_thread = threading.Thread(target=_register, daemon=True)
|
||||
register_thread.start()
|
||||
|
||||
def _exit_sub_services(self):
|
||||
"""
|
||||
exit sub services
|
||||
|
||||
@@ -694,8 +694,6 @@ class LLMEngine:
|
||||
self.splitwise_receive_thread.daemon = True
|
||||
self.splitwise_receive_thread.start()
|
||||
|
||||
self.cfg.init_cache_info()
|
||||
|
||||
role = self.cfg.scheduler_config.splitwise_role
|
||||
host_ip = self.cfg.host_ip
|
||||
disaggregate = self.cfg.disaggregate_info
|
||||
|
||||
@@ -527,6 +527,8 @@ class RequestOutput:
|
||||
f"num_input_image_tokens={self.num_input_image_tokens}, "
|
||||
f"num_input_video_tokens={self.num_input_video_tokens}, "
|
||||
f"metrics={self.metrics}, "
|
||||
f"error_code={self.error_code}, "
|
||||
f"error_msg={self.error_msg},"
|
||||
)
|
||||
|
||||
@classmethod
|
||||
|
||||
@@ -451,6 +451,8 @@ class CompletionRequest(BaseModel):
|
||||
temperature: Optional[float] = Field(default=None, ge=0)
|
||||
top_p: Optional[float] = Field(default=None, ge=0, le=1)
|
||||
user: Optional[str] = None
|
||||
request_id: Optional[str] = None
|
||||
disaggregate_info: Optional[dict] = None
|
||||
|
||||
# doc: begin-completion-sampling-params
|
||||
top_k: Optional[int] = None
|
||||
@@ -486,8 +488,6 @@ class CompletionRequest(BaseModel):
|
||||
dict: request parameters in dict format
|
||||
"""
|
||||
req_dict = {}
|
||||
if request_id is not None:
|
||||
req_dict["request_id"] = request_id
|
||||
|
||||
# parse request model into dict
|
||||
if self.suffix is not None:
|
||||
@@ -497,6 +497,8 @@ class CompletionRequest(BaseModel):
|
||||
if value is not None:
|
||||
req_dict[key] = value
|
||||
|
||||
if request_id is not None:
|
||||
req_dict["request_id"] = request_id
|
||||
if prompt is not None:
|
||||
req_dict["prompt"] = prompt
|
||||
|
||||
@@ -604,6 +606,8 @@ class ChatCompletionRequest(BaseModel):
|
||||
user: Optional[str] = None
|
||||
metadata: Optional[dict] = None
|
||||
response_format: Optional[AnyResponseFormat] = None
|
||||
request_id: Optional[str] = None
|
||||
disaggregate_info: Optional[dict] = None
|
||||
|
||||
# doc: begin-chat-completion-sampling-params
|
||||
top_k: Optional[int] = None
|
||||
@@ -644,8 +648,6 @@ class ChatCompletionRequest(BaseModel):
|
||||
dict: request parameters in dict format
|
||||
"""
|
||||
req_dict = {}
|
||||
if request_id is not None:
|
||||
req_dict["request_id"] = request_id
|
||||
|
||||
req_dict["max_tokens"] = self.max_completion_tokens or self.max_tokens
|
||||
req_dict["logprobs"] = self.top_logprobs if self.logprobs else None
|
||||
@@ -666,6 +668,9 @@ class ChatCompletionRequest(BaseModel):
|
||||
if value is not None:
|
||||
req_dict[key] = value
|
||||
|
||||
if request_id is not None:
|
||||
req_dict["request_id"] = request_id
|
||||
|
||||
if "prompt_token_ids" in req_dict:
|
||||
if "messages" in req_dict:
|
||||
del req_dict["messages"]
|
||||
|
||||
@@ -114,7 +114,11 @@ class OpenAIServingChat:
            await asyncio.wait_for(self.engine_client.semaphore.acquire(), timeout=self.max_waiting_time)
            api_server_logger.info(f"current {self.engine_client.semaphore.status()}")

        if request.user is not None:
        if request.request_id is not None:
            request_id = request.request_id
            if not request_id.startswith("chatcmpl-"):
                request_id = f"chatcmpl-{request_id}"
        elif request.user is not None:
            request_id = f"chatcmpl-{request.user}-{uuid.uuid4()}"
        else:
            request_id = f"chatcmpl-{uuid.uuid4()}"

@@ -85,7 +85,11 @@ class OpenAIServingCompletion:
                error=ErrorInfo(message=err_msg, type=ErrorType.INTERNAL_ERROR, code=ErrorCode.MODEL_NOT_SUPPORT)
            )
        created_time = int(time.time())
        if request.user is not None:
        if request.request_id is not None:
            request_id = request.request_id
            if not request_id.startswith("cmpl-"):
                request_id = f"cmpl-{request_id}"
        elif request.user is not None:
            request_id = f"cmpl-{request.user}-{uuid.uuid4()}"
        else:
            request_id = f"cmpl-{uuid.uuid4()}"

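With these changes a caller-supplied `request_id` wins over the `user`-derived and random ids, and the expected prefix is added when missing. A condensed sketch of the selection order (the helper name is illustrative):

```python
import uuid

def make_request_id(request_id=None, user=None, prefix="chatcmpl-"):
    # Same precedence as the serving code above: explicit id > user > random
    if request_id is not None:
        return request_id if request_id.startswith(prefix) else f"{prefix}{request_id}"
    if user is not None:
        return f"{prefix}{user}-{uuid.uuid4()}"
    return f"{prefix}{uuid.uuid4()}"

print(make_request_id("abc123"))      # chatcmpl-abc123
print(make_request_id(user="alice"))  # chatcmpl-alice-<uuid>
```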
@@ -662,7 +662,10 @@ class EngineWorkerQueue:
|
||||
self.client_read_info_flag[:] = [0] * self.num_client
|
||||
|
||||
self.cache_infos.extend(cache_info)
|
||||
llm_logger.debug(f"cache_infos: {self.cache_infos} local_data_parallel_id:{self.local_data_parallel_id}")
|
||||
llm_logger.debug(
|
||||
f"put cache_infos to engine worker queue: {self.cache_infos}, "
|
||||
f"local_data_parallel_id:{self.local_data_parallel_id}"
|
||||
)
|
||||
self.lock_info.release()
|
||||
|
||||
def get_cache_info(self) -> List[Any]:
|
||||
@@ -684,7 +687,10 @@ class EngineWorkerQueue:
|
||||
self.cache_infos[:] = list()
|
||||
self.lock_info.release()
|
||||
if len(cache_infos) != 0:
|
||||
llm_logger.debug(f"get cache infos: {cache_infos} local_data_parallel_id:{self.local_data_parallel_id}")
|
||||
llm_logger.debug(
|
||||
f"get cache infos from engine worker queue: {cache_infos}, "
|
||||
f"local_data_parallel_id:{self.local_data_parallel_id}"
|
||||
)
|
||||
return cache_infos
|
||||
|
||||
def num_cache_infos(self) -> int:
|
||||
|
||||
@@ -456,6 +456,7 @@ class TokenProcessor:
|
||||
recycle resources
|
||||
"""
|
||||
if is_prefill:
|
||||
start_time = time.time()
|
||||
while True:
|
||||
finished_task_ids = self.engine_worker_queue.get_finished_req()
|
||||
if len(finished_task_ids) > 0:
|
||||
@@ -474,6 +475,9 @@ class TokenProcessor:
|
||||
if self.prefill_result_status[task_id] != "finished":
|
||||
result.error_code = 400
|
||||
result.error_message = f"{task_id} failed to {self.prefill_result_status[task_id]}"
|
||||
llm_logger.info(
|
||||
f"wait for sending cache, request_id: {task_id}, cost seconds: {time.time()-start_time:.5f}"
|
||||
)
|
||||
self.split_connector.send_first_token(task.disaggregate_info, [result])
|
||||
break
|
||||
else:
|
||||
@@ -731,11 +735,10 @@ class TokenProcessor:
|
||||
self._record_completion_metrics(task, current_time)
|
||||
self._recycle_resources(task_id, i, task, result, is_prefill)
|
||||
break
|
||||
if (
|
||||
not is_prefill
|
||||
or self.cfg.scheduler_config.name == "splitwise"
|
||||
or self.cfg.scheduler_config.name == "dp"
|
||||
):
|
||||
|
||||
if not (is_prefill and self.cfg.splitwise_version == "v0"):
|
||||
# NOTE: prefill instance in v0 version does not return result to scheduler
|
||||
llm_logger.debug(f"get response from infer: {result}")
|
||||
batch_result.append(result)
|
||||
|
||||
self.postprocess(batch_result, mtype)
|
||||
|
||||
fastdeploy/router/__init__.py (new file, 15 lines)
@@ -0,0 +1,15 @@
|
||||
"""
|
||||
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
fastdeploy/router/launch.py (new file, 58 lines)
@@ -0,0 +1,58 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""

import argparse

from fastdeploy.router.router import start_router
from fastdeploy.utils import router_logger as logger


def main() -> None:
    parser = argparse.ArgumentParser(description="Router for splitwise deployment testing")
    parser.add_argument(
        "--host",
        type=str,
        default="0.0.0.0",
        help="Host address to bind the router server.",
    )
    parser.add_argument(
        "--port",
        type=int,
        default=9000,
        help="Port number to bind the router server.",
    )
    parser.add_argument(
        "--splitwise",
        action="store_true",
        help="Router uses splitwise deployment.",
    )
    parser.add_argument(
        "--request-timeout-secs",
        type=int,
        default=1800,
        help="Request timeout in seconds.",
    )
    args = parser.parse_args()

    try:
        start_router(args)
    except Exception as e:
        logger.error(f"Error starting router: {e}")
        raise e


if __name__ == "__main__":
    main()
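Besides `python -m fastdeploy.router.launch`, the router can be started programmatically by handing `start_router` a namespace with the same fields that `main()` parses. This is a sketch under the assumption that `start_router` only needs these four attributes:

```python
from argparse import Namespace
from fastdeploy.router.router import start_router

args = Namespace(host="0.0.0.0", port=9000, splitwise=True, request_timeout_secs=1800)
start_router(args)  # blocks while the router is serving
```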
fastdeploy/router/router.py (new file, 317 lines)
@@ -0,0 +1,317 @@
|
||||
"""
|
||||
Async Router server for FastDeploy.
|
||||
Handles client requests and manages prefill/decode/mixed instances.
|
||||
This module references the router implementations of sglang and vLLM.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import random
|
||||
from itertools import chain
|
||||
from uuid import uuid4
|
||||
|
||||
import aiohttp
|
||||
import uvicorn
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from fastapi.responses import ORJSONResponse, Response, StreamingResponse
|
||||
|
||||
from fastdeploy.router.utils import (
|
||||
InstanceInfo,
|
||||
InstanceRole,
|
||||
check_service_health_async,
|
||||
)
|
||||
from fastdeploy.utils import router_logger as logger
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
|
||||
class Router:
|
||||
"""
|
||||
Router class that handles requests from client and
|
||||
collects prefill/decode instance information
|
||||
"""
|
||||
|
||||
def __init__(self, args):
|
||||
self.args = args
|
||||
self.host = args.host
|
||||
self.port = args.port
|
||||
self.splitwise = args.splitwise
|
||||
self.timeout = args.request_timeout_secs
|
||||
|
||||
self.mixed_servers = []
|
||||
self.prefill_servers = []
|
||||
self.decode_servers = []
|
||||
self.lock = asyncio.Lock() # async-safe lock
|
||||
|
||||
async def register_instance(self, instance_info_dict: dict):
|
||||
"""Register an instance asynchronously"""
|
||||
try:
|
||||
inst_info = InstanceInfo(**instance_info_dict)
|
||||
except Exception as e:
|
||||
logger.error(f"register instance failed: {e}")
|
||||
raise
|
||||
|
||||
if (self.splitwise and inst_info.role == InstanceRole.MIXED) or (
|
||||
not self.splitwise and inst_info.role != InstanceRole.MIXED
|
||||
):
|
||||
raise ValueError(f"Invalid instance role: {inst_info.role}, splitwise: {self.splitwise}")
|
||||
|
||||
if not await check_service_health_async(inst_info.url()):
|
||||
raise RuntimeError(f"Instance {inst_info} is not healthy")
|
||||
|
||||
async with self.lock:
|
||||
if inst_info.role == InstanceRole.MIXED and inst_info not in self.mixed_servers:
|
||||
self.mixed_servers.append(inst_info)
|
||||
logger.info(
|
||||
f"Register mixed instance success: {inst_info}, " f"total mixed: {len(self.mixed_servers)}"
|
||||
)
|
||||
elif inst_info.role == InstanceRole.PREFILL and inst_info not in self.prefill_servers:
|
||||
self.prefill_servers.append(inst_info)
|
||||
logger.info(
|
||||
f"Register prefill instance success: {inst_info}, "
|
||||
f"prefill: {len(self.prefill_servers)}, decode: {len(self.decode_servers)}"
|
||||
)
|
||||
elif inst_info.role == InstanceRole.DECODE and inst_info not in self.decode_servers:
|
||||
self.decode_servers.append(inst_info)
|
||||
logger.info(
|
||||
f"Register decode instance success: {inst_info}, "
|
||||
f"prefill: {len(self.prefill_servers)}, decode: {len(self.decode_servers)}"
|
||||
)
|
||||
|
||||
async def registered_number(self):
|
||||
"""Get number of registered instances"""
|
||||
return {
|
||||
"mixed": len(self.mixed_servers),
|
||||
"prefill": len(self.prefill_servers),
|
||||
"decode": len(self.decode_servers),
|
||||
}
|
||||
|
||||
async def select_pd(self):
|
||||
"""Select one prefill and one decode server"""
|
||||
async with self.lock:
|
||||
if not self.prefill_servers:
|
||||
raise RuntimeError("No prefill servers available")
|
||||
if not self.decode_servers:
|
||||
raise RuntimeError("No decode servers available")
|
||||
pidx = random.randint(0, len(self.prefill_servers) - 1)
|
||||
didx = random.randint(0, len(self.decode_servers) - 1)
|
||||
return self.prefill_servers[pidx], self.decode_servers[didx]
|
||||
|
||||
async def select_mixed(self):
|
||||
"""Select one mixed server"""
|
||||
async with self.lock:
|
||||
if not self.mixed_servers:
|
||||
raise RuntimeError("No mixed servers available")
|
||||
idx = random.randint(0, len(self.mixed_servers) - 1)
|
||||
return self.mixed_servers[idx]
|
||||
|
||||
async def handle_request(self, request_data: dict, endpoint_name: str):
|
||||
if self.splitwise:
|
||||
return await self.handle_splitwise_request(request_data, endpoint_name)
|
||||
else:
|
||||
return await self.handle_mixed_request(request_data, endpoint_name)
|
||||
|
||||
async def handle_mixed_request(self, request_data: dict, endpoint_name: str):
|
||||
logger.debug(f"Received request: {request_data}")
|
||||
mixed_server = await self.select_mixed()
|
||||
|
||||
if request_data.get("stream", False):
|
||||
return await self._generate_stream(request_data, [mixed_server.url()], endpoint=endpoint_name)
|
||||
else:
|
||||
return await self._generate(request_data, [mixed_server.url()], endpoint=endpoint_name)
|
||||
|
||||
async def handle_splitwise_request(self, request_data: dict, endpoint_name: str):
|
||||
logger.debug(f"Received request: {request_data}")
|
||||
prefill_server, decode_server = await self.select_pd()
|
||||
|
||||
# TODO: unify the disaggregate_info in server and remove redundancy params
|
||||
is_same_node = prefill_server.host_ip == decode_server.host_ip
|
||||
use_ipc = (
|
||||
is_same_node and "ipc" in prefill_server.transfer_protocol and "ipc" in decode_server.transfer_protocol
|
||||
)
|
||||
|
||||
cache_info = {}
|
||||
if use_ipc:
|
||||
cache_info["ipc"] = {
|
||||
"ip": decode_server.host_ip,
|
||||
"port": decode_server.engine_worker_queue_port,
|
||||
"device_ids": decode_server.device_ids,
|
||||
}
|
||||
else:
|
||||
cache_info["rdma"] = {
|
||||
"ip": decode_server.host_ip,
|
||||
"port": decode_server.connector_port,
|
||||
"rdma_port": decode_server.rdma_ports,
|
||||
}
|
||||
|
||||
disaggregate_info = {
|
||||
"prefill": prefill_server.to_dict(),
|
||||
"decode": decode_server.to_dict(),
|
||||
"role": "decode",
|
||||
"cache_info": cache_info,
|
||||
"transfer_protocol": "ipc" if use_ipc else "rdma",
|
||||
}
|
||||
|
||||
modified_request = request_data.copy()
|
||||
modified_request["disaggregate_info"] = disaggregate_info
|
||||
if "request_id" not in modified_request:
|
||||
modified_request["request_id"] = str(uuid4())
|
||||
|
||||
logger.debug(f"Modified request: {modified_request}")
|
||||
|
||||
if request_data.get("stream", False):
|
||||
return await self._generate_stream(
|
||||
modified_request, [prefill_server.url(), decode_server.url()], endpoint=endpoint_name
|
||||
)
|
||||
else:
|
||||
return await self._generate(
|
||||
modified_request, [prefill_server.url(), decode_server.url()], endpoint=endpoint_name
|
||||
)
|
||||
|
||||
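To make the splitwise path above easier to follow, here is an illustrative, trimmed sketch of a request after handle_splitwise_request attaches disaggregate_info; all addresses, ports and ids are placeholders, and the real prefill/decode entries carry the full InstanceInfo.to_dict() contents.

# Illustrative only: rough shape of the payload the router fans out to the
# prefill and decode instances. Addresses, ports and ids are placeholders,
# and the prefill/decode entries are trimmed.
modified_request = {
    "messages": [{"role": "user", "content": "hello"}],
    "stream": True,
    "request_id": "2f6c5a1e-...",  # added when the client omits one
    "disaggregate_info": {
        "prefill": {"role": "PREFILL", "host_ip": "10.0.0.1", "port": "8188"},
        "decode": {"role": "DECODE", "host_ip": "10.0.0.2", "port": "8189"},
        "role": "decode",
        # "ipc" is chosen only when both instances sit on the same node and
        # both advertise ipc; otherwise the router falls back to rdma.
        "cache_info": {
            "rdma": {"ip": "10.0.0.2", "port": "8433", "rdma_port": ["9000"]},
        },
        "transfer_protocol": "rdma",
    },
}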
async def _generate(
|
||||
self, modified_request, urls, return_result_url_index=-1, endpoint="v1/chat/completions"
|
||||
) -> ORJSONResponse:
|
||||
async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=self.timeout)) as session:
|
||||
tasks = [session.post(f"{url}/{endpoint}", json=modified_request) for url in urls]
|
||||
results = await asyncio.gather(*tasks)
|
||||
ret_json = await results[return_result_url_index].json()
|
||||
return ORJSONResponse(content=ret_json, status_code=results[return_result_url_index].status)
|
||||
|
||||
async def _generate_stream(
|
||||
self, modified_request, urls, return_result_url_index=-1, endpoint="v1/chat/completions"
|
||||
):
|
||||
async def stream_results():
|
||||
async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=self.timeout)) as session:
|
||||
tasks = [session.post(f"{url}/{endpoint}", json=modified_request) for url in urls]
|
||||
results = await asyncio.gather(*tasks)
|
||||
|
||||
AIOHTTP_STREAM_READ_CHUNK_SIZE = 1024 * 64 # prevent aiohttp's "Chunk too big" error
|
||||
async for chunk in results[return_result_url_index].content.iter_chunked(
|
||||
AIOHTTP_STREAM_READ_CHUNK_SIZE
|
||||
):
|
||||
logger.debug(f"receive response chunk: {chunk}")
|
||||
yield chunk
|
||||
|
||||
return StreamingResponse(stream_results(), media_type="text/event-stream")
|
||||
|
||||
async def monitor_instance_health(self, interval_secs: float = 5.0):
|
||||
"""
|
||||
Continuously check the health of prefill, decode, and mixed instances and remove unhealthy ones.
|
||||
"""
|
||||
while True:
|
||||
try:
|
||||
prefill_to_remove = []
|
||||
decode_to_remove = []
|
||||
mixed_to_remove = []
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
# check servers
|
||||
prefill_tasks = [(inst, session.get(f"{inst.url()}/health")) for inst in self.prefill_servers]
|
||||
decode_tasks = [(inst, session.get(f"{inst.url()}/health")) for inst in self.decode_servers]
|
||||
mixed_tasks = [(inst, session.get(f"{inst.url()}/health")) for inst in self.mixed_servers]
|
||||
|
||||
# gather all tasks concurrently
|
||||
all_tasks = prefill_tasks + decode_tasks + mixed_tasks
|
||||
for inst, coro in all_tasks:
|
||||
try:
|
||||
resp = await coro
|
||||
if resp.status != 200:
|
||||
logger.warning(f"Instance {inst.url()} unhealthy: {resp.status}")
|
||||
if inst in self.prefill_servers:
|
||||
prefill_to_remove.append(inst)
|
||||
elif inst in self.decode_servers:
|
||||
decode_to_remove.append(inst)
|
||||
elif inst in self.mixed_servers:
|
||||
mixed_to_remove.append(inst)
|
||||
except Exception as e:
|
||||
logger.warning(f"Instance {inst.url()} check failed: {e}")
|
||||
if inst in self.prefill_servers:
|
||||
prefill_to_remove.append(inst)
|
||||
elif inst in self.decode_servers:
|
||||
decode_to_remove.append(inst)
|
||||
elif inst in self.mixed_servers:
|
||||
mixed_to_remove.append(inst)
|
||||
|
||||
# remove unhealthy instances under lock
|
||||
async with self.lock:
|
||||
if prefill_to_remove:
|
||||
for inst in prefill_to_remove:
|
||||
self.prefill_servers.remove(inst)
|
||||
logger.info(f"Removed unhealthy prefill instance: {inst.url()}")
|
||||
if decode_to_remove:
|
||||
for inst in decode_to_remove:
|
||||
self.decode_servers.remove(inst)
|
||||
logger.info(f"Removed unhealthy decode instance: {inst.url()}")
|
||||
if mixed_to_remove:
|
||||
for inst in mixed_to_remove:
|
||||
self.mixed_servers.remove(inst)
|
||||
logger.info(f"Removed unhealthy mixed instance: {inst.url()}")
|
||||
|
||||
await asyncio.sleep(interval_secs)
|
||||
|
||||
prefill_instances = [inst.url() for inst in self.prefill_servers]
|
||||
decode_instances = [inst.url() for inst in self.decode_servers]
|
||||
mixed_instance = [inst.url() for inst in self.mixed_servers]
|
||||
logger.debug(
|
||||
f"Healthy prefill instances: {prefill_instances}, "
|
||||
f"Healthy decode instances: {decode_instances}, "
|
||||
f"Healthy mixed instance: {mixed_instance}"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f"Failed to monitor instance health: {e}")
|
||||
|
||||
|
||||
@app.post("/register")
|
||||
async def register(instance_info_dict: dict):
|
||||
"""Register prefill/decode/mixed servers"""
|
||||
try:
|
||||
await app.state.router.register_instance(instance_info_dict)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
return {"status": "success"}
|
||||
|
||||
|
||||
@app.get("/registered_number")
|
||||
async def registered_number():
|
||||
"""Get the number of registered prefill/decode/mixed servers"""
|
||||
return await app.state.router.registered_number()
|
||||
|
||||
|
||||
@app.post("/v1/chat/completions")
|
||||
async def create_chat_completion(request_data: dict):
|
||||
return await app.state.router.handle_request(request_data, "v1/chat/completions")
|
||||
|
||||
|
||||
@app.post("/v1/completions")
|
||||
async def create_completion(request_data: dict):
|
||||
return await app.state.router.handle_request(request_data, "v1/completions")
|
||||
|
||||
|
||||
@app.get("/health")
|
||||
async def health_check():
|
||||
"""Basic health check"""
|
||||
return Response(status_code=200)
|
||||
|
||||
|
||||
@app.get("/health_generate")
|
||||
async def health_generate():
|
||||
"""Check all prefill and decode servers are healthy"""
|
||||
router = app.state.router
|
||||
async with aiohttp.ClientSession() as session:
|
||||
tasks = [session.get(f"{s.url()}/health") for s in chain(router.prefill_servers, router.decode_servers)]
|
||||
for coro in asyncio.as_completed(tasks):
|
||||
resp = await coro
|
||||
if resp.status != 200:
|
||||
logger.warning(f"Server {resp.url} not healthy: {resp.status}")
|
||||
return Response(status_code=200)
|
||||
|
||||
|
||||
def start_router(router_args):
|
||||
app.state.router_args = router_args
|
||||
|
||||
@app.on_event("startup")
|
||||
async def startup_event():
|
||||
app.state.router = Router(app.state.router_args)
|
||||
asyncio.create_task(app.state.router.monitor_instance_health(interval_secs=5))
|
||||
|
||||
uvicorn.run(app, host=router_args.host, port=router_args.port)
|
||||
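A minimal sketch of how an instance could register with this router over HTTP and how the registration can be verified; the router address and all instance fields below are placeholder assumptions.

# Minimal sketch (placeholder values): register an instance and check the counts.
import requests

router_url = "http://127.0.0.1:9000"
instance_info = {
    "role": "prefill",  # must match the router mode: prefill/decode require --splitwise
    "host_ip": "127.0.0.1",
    "port": 8188,
    "connector_port": 8433,
    "engine_worker_queue_port": 8133,
    "transfer_protocol": ["ipc", "rdma"],
    "rdma_ports": [9000],
    "device_ids": [0],
}

# The router health-checks the instance via GET /health before accepting it.
requests.post(f"{router_url}/register", json=instance_info, timeout=10).raise_for_status()
print(requests.get(f"{router_url}/registered_number", timeout=10).json())
# e.g. {"mixed": 0, "prefill": 1, "decode": 0}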
131
fastdeploy/router/utils.py
Normal file
@@ -0,0 +1,131 @@
|
||||
"""
|
||||
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from dataclasses import asdict, dataclass, field
|
||||
from enum import Enum
|
||||
from typing import List, Union
|
||||
|
||||
import aiohttp
|
||||
import requests
|
||||
|
||||
|
||||
class InstanceRole(Enum):
|
||||
MIXED = 0
|
||||
PREFILL = 1
|
||||
DECODE = 2
|
||||
|
||||
|
||||
@dataclass
|
||||
class InstanceInfo:
|
||||
role: Union[InstanceRole, str]
|
||||
host_ip: str
|
||||
port: Union[int, str]
|
||||
connector_port: Union[int, str] = 0
|
||||
engine_worker_queue_port: Union[int, str] = 0
|
||||
transfer_protocol: List[str] = field(default_factory=list)
|
||||
rdma_ports: Union[List[str], List[int]] = field(default_factory=list)
|
||||
device_ids: Union[List[str], List[int]] = field(default_factory=list)
|
||||
|
||||
def __post_init__(self):
|
||||
"""check and unify fields"""
|
||||
if isinstance(self.role, str):
|
||||
try:
|
||||
self.role = InstanceRole[self.role.upper()]
|
||||
except KeyError:
|
||||
raise ValueError(f"Invalid role string: {self.role}")
|
||||
elif not isinstance(self.role, InstanceRole):
|
||||
raise TypeError(f"role must be InstanceRole or str, got {type(self.role)}")
|
||||
|
||||
for t in self.transfer_protocol:
|
||||
assert t in ["ipc", "rdma"], f"Invalid transfer_protocol: {self.transfer_protocol}"
|
||||
|
||||
self.port = str(self.port)
|
||||
self.connector_port = str(self.connector_port)
|
||||
self.engine_worker_queue_port = str(self.engine_worker_queue_port)
|
||||
if self.rdma_ports:
|
||||
self.rdma_ports = [str(p) for p in self.rdma_ports]
|
||||
if self.device_ids:
|
||||
self.device_ids = [str(i) for i in self.device_ids]
|
||||
|
||||
def to_dict(self):
|
||||
return {k: (v.name if isinstance(v, Enum) else v) for k, v in asdict(self).items()}
|
||||
|
||||
def url(self) -> str:
|
||||
url = f"{self.host_ip}:{self.port}"
|
||||
if not url.startswith(("http://", "https://")):
|
||||
url = f"http://{url}"
|
||||
return url
|
||||
|
||||
|
||||
def check_service_health(base_url: str, timeout: int = 3) -> bool:
|
||||
"""
|
||||
Check the health status of a service.
|
||||
|
||||
Args:
|
||||
base_url (str): The base URL of the service, e.g. "http://127.0.0.1:8080"
|
||||
timeout (int): Request timeout in seconds.
|
||||
|
||||
Returns:
|
||||
bool: True if the service is healthy, False otherwise.
|
||||
"""
|
||||
if not base_url.startswith(("http://", "https://")):
|
||||
base_url = f"http://{base_url}"
|
||||
|
||||
url = f"{base_url.rstrip('/')}/health"
|
||||
try:
|
||||
resp = requests.get(url, timeout=timeout)
|
||||
if resp.status_code == 200:
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
async def check_service_health_async(base_url: str, timeout: int = 3) -> bool:
|
||||
"""
|
||||
Asynchronously check the health status of a service.
|
||||
|
||||
Args:
|
||||
base_url (str): The base URL of the service, e.g. "http://127.0.0.1:8080"
|
||||
timeout (int): Request timeout in seconds.
|
||||
|
||||
Returns:
|
||||
bool: True if the service is healthy, False otherwise.
|
||||
"""
|
||||
if not base_url.startswith(("http://", "https://")):
|
||||
base_url = f"http://{base_url}"
|
||||
|
||||
url = f"{base_url.rstrip('/')}/health"
|
||||
try:
|
||||
async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=timeout)) as session:
|
||||
async with session.get(url) as resp:
|
||||
status = resp.status
|
||||
text = await resp.text()
|
||||
|
||||
if status == 200:
|
||||
print(f"[OK] Service is healthy ({status})")
|
||||
return True
|
||||
else:
|
||||
print(f"[WARN] Service not healthy ({status}): {text}")
|
||||
return False
|
||||
except aiohttp.ClientError as e:
|
||||
print(f"[ERROR] Failed to connect to {url}: {e}")
|
||||
return False
|
||||
except asyncio.TimeoutError:
|
||||
print(f"[ERROR] Request to {url} timed out after {timeout}s")
|
||||
return False
|
||||
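A short sketch of how the helpers above behave, assuming FastDeploy is installed; the address and ports are placeholders.

# Sketch of how the helpers above are used; the address and ports are placeholders.
from fastdeploy.router.utils import InstanceInfo, InstanceRole, check_service_health

info = InstanceInfo(role="decode", host_ip="127.0.0.1", port=8189, rdma_ports=[9000])
assert info.role is InstanceRole.DECODE  # string roles are normalized in __post_init__
assert info.port == "8189"               # ports are unified to strings
print(info.url())                        # -> "http://127.0.0.1:8189"
print(info.to_dict())                    # enum role is serialized by name ("DECODE")

if not check_service_health(info.url(), timeout=3):
    print("instance is not reachable yet")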
@@ -16,7 +16,9 @@
|
||||
|
||||
import redis
|
||||
|
||||
from fastdeploy.utils import llm_logger
|
||||
from fastdeploy.utils import get_logger, llm_logger
|
||||
|
||||
config_logger = get_logger("config", "config.log")
|
||||
|
||||
from .dp_scheduler import DPScheduler
|
||||
from .global_scheduler import GlobalScheduler
|
||||
@@ -84,10 +86,10 @@ class LocalSchedulerConfig:
|
||||
"""
|
||||
Print the current configuration to logs.
|
||||
"""
|
||||
llm_logger.info("LocalScheduler Configuration Information :")
|
||||
config_logger.info("LocalScheduler Configuration Information :")
|
||||
for k, v in self.__dict__.items():
|
||||
llm_logger.info("{:<20}:{:<6}{}".format(k, "", v))
|
||||
llm_logger.info("=============================================================")
|
||||
config_logger.info("{:<20}:{:<6}{}".format(k, "", v))
|
||||
config_logger.info("=============================================================")
|
||||
|
||||
|
||||
class DPLocalSchedulerConfig(LocalSchedulerConfig):
|
||||
@@ -312,6 +314,7 @@ class SchedulerConfig:
|
||||
Returns:
|
||||
Initialized scheduler instance (LocalScheduler or GlobalScheduler)
|
||||
"""
|
||||
llm_logger.info("Scheduler Type: %s" % self.name)
|
||||
|
||||
if self.name == "global":
|
||||
return GlobalScheduler(
|
||||
|
||||
@@ -195,6 +195,20 @@ class LocalScheduler:
|
||||
results += [(request_id, "duplicated request_id") for request_id in duplicated_ids]
|
||||
return results
|
||||
|
||||
def has_request(self, request_id: str) -> bool:
|
||||
"""
|
||||
Check whether a request with the given ID exists in the scheduler.

Args:
    request_id: The request ID to look up.

Returns:
    True if the request exists in the scheduler, False otherwise.
|
||||
"""
|
||||
with self.mutex:
|
||||
return request_id in self.requests
|
||||
|
||||
def calc_required_blocks(self, token_num, block_size):
|
||||
"""
|
||||
Calculate the number of blocks needed for a given number of tokens.
|
||||
@@ -292,6 +306,7 @@ class LocalScheduler:
|
||||
Args:
|
||||
results: List of RequestOutput objects containing results
|
||||
"""
|
||||
scheduler_logger.debug(f"put results: {results}")
|
||||
responses: List[ScheduledResponse] = [ScheduledResponse(result) for result in results]
|
||||
|
||||
finished_responses = [response.request_id for response in responses if response.finished]
|
||||
@@ -354,4 +369,8 @@ class LocalScheduler:
|
||||
if finished:
|
||||
self._recycle(request_id)
|
||||
scheduler_logger.info(f"Scheduler has pulled a finished response: {[request_id]}")
|
||||
|
||||
if results:
|
||||
scheduler_logger.debug(f"get responses, {results}")
|
||||
|
||||
return results
|
||||
|
||||
@@ -18,12 +18,12 @@ import json
|
||||
import time
|
||||
import traceback
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from typing import Dict
|
||||
from typing import Dict, List
|
||||
|
||||
import zmq
|
||||
|
||||
from fastdeploy import envs
|
||||
from fastdeploy.engine.request import CompletionOutput, Request, RequestOutput
|
||||
from fastdeploy.engine.request import Request, RequestOutput
|
||||
from fastdeploy.inter_communicator import EngineWorkerQueue
|
||||
from fastdeploy.metrics.metrics import main_process_metrics
|
||||
from fastdeploy.utils import get_logger
|
||||
@@ -241,7 +241,7 @@ class SplitwiseConnector:
|
||||
},
|
||||
}
|
||||
|
||||
def send_splitwise_tasks(self, tasks, current_id):
|
||||
def send_splitwise_tasks(self, tasks: List[Request], current_id):
|
||||
"""
|
||||
Send splitwise tasks to all connected addresses.
|
||||
|
||||
@@ -276,6 +276,7 @@ class SplitwiseConnector:
|
||||
task.disaggregate_info["cache_info"] = self.cfg.disaggregate_info["cache_info"]
|
||||
task.disaggregate_info["cache_info"]["rdma"]["current_id"] = current_id
|
||||
task.disaggregate_info["role"] = "decode"
|
||||
self.logger.debug(f"send task to coupled instance, {addr}, {task}")
|
||||
self._send_message(addr, "prefill", [task])
|
||||
task.disaggregate_info["cache_info"] = decode_diagg
|
||||
task.disaggregate_info["role"] = "prefill"
|
||||
@@ -311,7 +312,7 @@ class SplitwiseConnector:
|
||||
"""
|
||||
if not isinstance(tasks_list, list):
|
||||
tasks_list = [tasks_list]
|
||||
self.logger.info("send first token to port decode")
|
||||
self.logger.info(f"send first token to decode, {[x.request_id for x in tasks_list]}")
|
||||
if prefill_msg["transfer_protocol"] == "ipc":
|
||||
port = prefill_msg["cache_info"]["ipc"]["port"]
|
||||
if port not in self.connect_innode_instances:
|
||||
@@ -355,7 +356,7 @@ class SplitwiseConnector:
|
||||
self.logger.error(f"Receive_decode_allocated error: {msg}")
|
||||
return False, msg
|
||||
|
||||
def send_cache_infos(self, tasks, current_id):
|
||||
def send_cache_infos(self, tasks: List[Request], current_id):
|
||||
"""
|
||||
Send cache information to specific port.
|
||||
|
||||
@@ -432,8 +433,10 @@ class SplitwiseConnector:
|
||||
|
||||
if not is_decode and len(temp_cache_info):
|
||||
for k, v in temp_cache_info.items():
|
||||
self.logger.debug(f"send cache info to cachemessager, {v}")
|
||||
self.engine_worker_queue.put_cache_info(v)
|
||||
else:
|
||||
self.logger.debug(f"send cache info to coupled instance, {temp_cache_info}")
|
||||
if len(temp_cache_info):
|
||||
for k, v in temp_cache_info.items():
|
||||
self.logger.info(f"{k} {v}")
|
||||
@@ -490,7 +493,7 @@ class SplitwiseConnector:
|
||||
"""
|
||||
Handle prefill tasks from other nodes.
|
||||
"""
|
||||
|
||||
self.logger.debug(f"_handle_prefill function receive {tasks}")
|
||||
tasks_data = [Request.from_dict(task) for task in tasks]
|
||||
self.engine_worker_queue.put_disaggregated_tasks(("decode", tasks_data))
|
||||
|
||||
@@ -498,21 +501,9 @@ class SplitwiseConnector:
|
||||
"""
|
||||
Handle decode tasks from other nodes.
|
||||
"""
|
||||
self.logger.debug(f"_handle_decode function receive {payload}")
|
||||
tasks = []
|
||||
for task in payload:
|
||||
tasks.append(
|
||||
RequestOutput(
|
||||
request_id=task["request_id"],
|
||||
outputs=CompletionOutput(
|
||||
index=task["outputs"]["index"],
|
||||
send_idx=0,
|
||||
token_ids=task["outputs"]["token_ids"],
|
||||
draft_token_ids=task["outputs"]["draft_token_ids"],
|
||||
),
|
||||
finished=True,
|
||||
num_cached_tokens=task["num_cached_tokens"],
|
||||
error_code=task["error_code"],
|
||||
error_msg=task["error_msg"],
|
||||
)
|
||||
)
|
||||
output = RequestOutput.from_dict(task)
|
||||
tasks.append(output)
|
||||
self.engine_worker_queue.put_disaggregated_tasks(("decode", tasks))
|
||||
|
||||
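For clarity, an illustrative payload for the simplified _handle_decode above; the field names follow the hand-built RequestOutput the removed code constructed, the values are placeholders, and whether RequestOutput.from_dict accepts exactly this shape is an assumption.

# Illustrative payload only; field names follow the removed hand-built
# RequestOutput above, values are placeholders.
payload = [
    {
        "request_id": "req-0",
        "finished": True,
        "num_cached_tokens": 0,
        "error_code": 200,
        "error_msg": "",
        "outputs": {
            "index": 0,
            "send_idx": 0,
            "token_ids": [101, 102, 103],
            "draft_token_ids": [],
        },
    }
]
# tasks = [RequestOutput.from_dict(task) for task in payload]  # as in the new code above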
@@ -982,6 +982,7 @@ api_server_logger = get_logger("api_server", "api_server.log")
|
||||
console_logger = get_logger("console", "console.log", print_to_console=True)
|
||||
spec_logger = get_logger("speculate", "speculate.log")
|
||||
zmq_client_logger = get_logger("zmq_client", "zmq_client.log")
|
||||
router_logger = get_logger("router", "router.log")
|
||||
|
||||
|
||||
def parse_type(return_type: Callable[[str], T]) -> Callable[[str], T]:
|
||||
|
||||
@@ -475,8 +475,10 @@ class PaddleDisWorkerProc:
|
||||
|
||||
# Execute model to generate token. The generated token will be written to the buffer.
|
||||
# These generated tokens can be obtained through get_output op.
|
||||
start_execute_time = time.time()
|
||||
self.worker.execute_model(req_dicts, num_running_requests)
|
||||
self.exist_prefill_task_signal.value[0] = self.worker.exist_prefill()
|
||||
logger.debug(f"execute model cost: {time.time()-start_execute_time:.5f} s")
|
||||
|
||||
def initialize_kv_cache(self) -> None:
|
||||
"""Profiles the peak memory usage of the model to determine how many
|
||||
|
||||
@@ -42,3 +42,5 @@ opentelemetry-instrumentation-fastapi
|
||||
partial_json_parser
|
||||
msgspec
|
||||
einops
|
||||
setproctitle
|
||||
aistudio_sdk
|
||||
|
||||
@@ -45,18 +45,24 @@ def setup_and_run_server():
|
||||
clean_ports()
|
||||
|
||||
print("log dir clean ")
|
||||
if os.path.exists("log") and os.path.isdir("log"):
|
||||
shutil.rmtree("log")
|
||||
if os.path.exists("log_prefill") and os.path.isdir("log_prefill"):
|
||||
shutil.rmtree("log_prefill")
|
||||
if os.path.exists("log_decode") and os.path.isdir("log_decode"):
|
||||
shutil.rmtree("log_decode")
|
||||
|
||||
base_path = os.getenv("MODEL_PATH")
|
||||
if base_path:
|
||||
model_path = os.path.join(base_path, "ERNIE-4.5-0.3B-Paddle")
|
||||
else:
|
||||
model_path = "./ERNIE-4.5-0.3B-Paddle"
|
||||
model_path = "baidu/ERNIE-4.5-0.3B-Paddle"
|
||||
print(f"model_path: {model_path}")
|
||||
|
||||
# prefill instance
|
||||
print("start prefill...")
|
||||
env_prefill = os.environ.copy()
|
||||
env_prefill["CUDA_VISIBLE_DEVICES"] = "0"
|
||||
env_prefill["ENABLE_V1_KVCACHE_SCHEDULER"] = "0"
|
||||
env_prefill["FD_LOG_DIR"] = "log_prefill"
|
||||
env_prefill["INFERENCE_MSG_QUEUE_ID"] = str(FD_API_PORT)
|
||||
prefill_log_path = "server.log"
|
||||
prefill_cmd = [
|
||||
@@ -94,12 +100,15 @@ def setup_and_run_server():
|
||||
start_new_session=True, # Enables killing full group via os.killpg
|
||||
env=env_prefill,
|
||||
)
|
||||
time.sleep(3)
|
||||
|
||||
# decode instance
|
||||
print("start decode...")
|
||||
env_decode = os.environ.copy()
|
||||
env_decode["CUDA_VISIBLE_DEVICES"] = "1"
|
||||
env_prefill["ENABLE_V1_KVCACHE_SCHEDULER"] = "0"
|
||||
env_decode["INFERENCE_MSG_QUEUE_ID"] = str(FD_API_PORT + 1)
|
||||
env_decode["FD_LOG_DIR"] = "decode_log"
|
||||
env_decode["FD_LOG_DIR"] = "log_decode"
|
||||
decode_log_path = "decode_server.log"
|
||||
decode_cmd = [
|
||||
sys.executable,
|
||||
@@ -125,6 +134,8 @@ def setup_and_run_server():
|
||||
"wint8",
|
||||
"--splitwise-role",
|
||||
"decode",
|
||||
"--innode-prefill-ports",
|
||||
str(FD_ENGINE_QUEUE_PORT),
|
||||
]
|
||||
|
||||
# Start subprocess in new process group
|
||||
@@ -260,18 +271,7 @@ def test_chat_usage_stream(api_url):
|
||||
"stream_options": {"include_usage": True, "continuous_usage_stats": True},
|
||||
"metadata": {"min_tokens": 10},
|
||||
}
|
||||
p_url, d_url = api_url
|
||||
|
||||
response = send_request(url=p_url, payload=payload)
|
||||
chunks = get_stream_chunks(response)
|
||||
result = "".join([x["choices"][0]["delta"]["content"] for x in chunks[:-1]])
|
||||
print("Prefill Response:", result)
|
||||
assert result != "", "结果为空"
|
||||
usage = chunks[-1]["usage"]
|
||||
total_tokens = usage["completion_tokens"] + usage["prompt_tokens"]
|
||||
assert payload["max_tokens"] >= usage["completion_tokens"], "completion_tokens大于max_tokens"
|
||||
assert payload["metadata"]["min_tokens"] <= usage["completion_tokens"], "completion_tokens小于min_tokens"
|
||||
assert usage["total_tokens"] == total_tokens, "total_tokens不等于prompt_tokens + completion_tokens"
|
||||
_, d_url = api_url # Only the decode server receives the request
|
||||
|
||||
response = send_request(url=d_url, payload=payload)
|
||||
chunks = get_stream_chunks(response)
|
||||
@@ -302,16 +302,7 @@ def test_chat_usage_non_stream(api_url):
|
||||
"stream": False,
|
||||
"metadata": {"min_tokens": 10},
|
||||
}
|
||||
p_url, d_url = api_url
|
||||
|
||||
response = send_request(url=p_url, payload=payload).json()
|
||||
usage = response["usage"]
|
||||
result = response["choices"][0]["message"]["content"]
|
||||
assert result != "", "结果为空"
|
||||
total_tokens = usage["completion_tokens"] + usage["prompt_tokens"]
|
||||
assert payload["max_tokens"] >= usage["completion_tokens"], "completion_tokens大于max_tokens"
|
||||
assert payload["metadata"]["min_tokens"] <= usage["completion_tokens"], "completion_tokens小于min_tokens"
|
||||
assert usage["total_tokens"] == total_tokens, "total_tokens不等于prompt_tokens + completion_tokens"
|
||||
_, d_url = api_url
|
||||
|
||||
response = send_request(url=d_url, payload=payload).json()
|
||||
usage = response["usage"]
|
||||
@@ -336,25 +327,13 @@ def test_non_chat_usage_stream(api_url):
|
||||
"stream_options": {"include_usage": True, "continuous_usage_stats": True},
|
||||
"metadata": {"min_tokens": 10},
|
||||
}
|
||||
p_url, d_url = api_url
|
||||
p_url = p_url.replace("chat/completions", "completions")
|
||||
_, d_url = api_url
|
||||
d_url = d_url.replace("chat/completions", "completions")
|
||||
|
||||
response = send_request(url=p_url, payload=payload)
|
||||
chunks = get_stream_chunks(response)
|
||||
result = "".join([x["choices"][0]["text"] for x in chunks[:-1]])
|
||||
# print("Prefill Response:", result)
|
||||
assert result != "", "结果为空"
|
||||
usage = chunks[-1]["usage"]
|
||||
total_tokens = usage["completion_tokens"] + usage["prompt_tokens"]
|
||||
assert payload["max_tokens"] >= usage["completion_tokens"], "completion_tokens大于max_tokens"
|
||||
assert payload["metadata"]["min_tokens"] <= usage["completion_tokens"], "completion_tokens小于min_tokens"
|
||||
assert usage["total_tokens"] == total_tokens, "total_tokens不等于prompt_tokens + completion_tokens"
|
||||
|
||||
response = send_request(url=d_url, payload=payload)
|
||||
chunks = get_stream_chunks(response)
|
||||
result = "".join([x["choices"][0]["text"] for x in chunks[:-1]])
|
||||
# print("Decode Response:", result)
|
||||
print("Decode Response:", result)
|
||||
assert result != "", "结果为空"
|
||||
usage = chunks[-1]["usage"]
|
||||
total_tokens = usage["completion_tokens"] + usage["prompt_tokens"]
|
||||
@@ -375,23 +354,13 @@ def test_non_chat_usage_non_stream(api_url):
|
||||
"stream": False,
|
||||
"metadata": {"min_tokens": 10},
|
||||
}
|
||||
p_url, d_url = api_url
|
||||
p_url = p_url.replace("chat/completions", "completions")
|
||||
_, d_url = api_url
|
||||
d_url = d_url.replace("chat/completions", "completions")
|
||||
|
||||
response = send_request(url=p_url, payload=payload).json()
|
||||
usage = response["usage"]
|
||||
result = response["choices"][0]["text"]
|
||||
# print("Prefill Response:", result)
|
||||
assert result != "", "结果为空"
|
||||
total_tokens = usage["completion_tokens"] + usage["prompt_tokens"]
|
||||
assert payload["max_tokens"] >= usage["completion_tokens"], "completion_tokens大于max_tokens"
|
||||
assert payload["metadata"]["min_tokens"] <= usage["completion_tokens"], "completion_tokens小于min_tokens"
|
||||
assert usage["total_tokens"] == total_tokens, "total_tokens不等于prompt_tokens + completion_tokens"
|
||||
|
||||
response = send_request(url=d_url, payload=payload).json()
|
||||
usage = response["usage"]
|
||||
result = response["choices"][0]["text"]
|
||||
print("Decode Response:", result)
|
||||
assert result != "", "结果为空"
|
||||
total_tokens = usage["completion_tokens"] + usage["prompt_tokens"]
|
||||
assert payload["max_tokens"] >= usage["completion_tokens"], "completion_tokens大于max_tokens"
|
||||
|
||||
500
tests/e2e/test_ernie_03b_pd_multi_node.py
Normal file
@@ -0,0 +1,500 @@
|
||||
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import json
|
||||
import os
|
||||
import shutil
|
||||
import signal
|
||||
import socket
|
||||
import subprocess
|
||||
import sys
|
||||
import time
|
||||
|
||||
import pytest
|
||||
import requests
|
||||
|
||||
# Read ports from environment variables; use default values if not set
|
||||
FD_API_PORT = int(os.getenv("FD_API_PORT", 8188))
|
||||
FD_ENGINE_QUEUE_PORT = int(os.getenv("FD_ENGINE_QUEUE_PORT", 8133))
|
||||
FD_METRICS_PORT = int(os.getenv("FD_METRICS_PORT", 8233))
|
||||
FD_CACHE_QUEUE_PORT = int(os.getenv("FD_CACHE_QUEUE_PORT", 8333))
|
||||
FD_CONNECTOR_PORT = int(os.getenv("FD_CONNECTOR_PORT", 8433))
|
||||
FD_ROUTER_PORT = int(os.getenv("FD_ROUTER_PORT", 8533))
|
||||
|
||||
# List of ports to clean before and after tests
|
||||
PORTS_TO_CLEAN = [
|
||||
FD_API_PORT,
|
||||
FD_ENGINE_QUEUE_PORT,
|
||||
FD_METRICS_PORT,
|
||||
FD_CACHE_QUEUE_PORT,
|
||||
FD_CONNECTOR_PORT,
|
||||
FD_API_PORT + 1,
|
||||
FD_ENGINE_QUEUE_PORT + 1,
|
||||
FD_METRICS_PORT + 1,
|
||||
FD_CACHE_QUEUE_PORT + 1,
|
||||
FD_CONNECTOR_PORT + 1,
|
||||
FD_ROUTER_PORT,
|
||||
]
|
||||
|
||||
|
||||
def is_port_open(host: str, port: int, timeout=1.0):
|
||||
"""
|
||||
Check if a TCP port is open on the given host.
|
||||
Returns True if connection succeeds, False otherwise.
|
||||
"""
|
||||
try:
|
||||
with socket.create_connection((host, port), timeout):
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
def check_service_health(base_url: str, timeout: int = 3) -> bool:
|
||||
"""
|
||||
Check the health status of a service.
|
||||
|
||||
Args:
|
||||
base_url (str): The base URL of the service, e.g. "http://127.0.0.1:8080"
|
||||
timeout (int): Request timeout in seconds.
|
||||
|
||||
Returns:
|
||||
bool: True if the service is healthy, False otherwise.
|
||||
"""
|
||||
if not base_url.startswith("http"):
|
||||
base_url = f"http://{base_url}"
|
||||
url = f"{base_url.rstrip('/')}/health"
|
||||
try:
|
||||
resp = requests.get(url, timeout=timeout)
|
||||
if resp.status_code == 200:
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
def get_registered_number(router_url) -> dict:
"""
Get the number of registered instances from the router.

Args:
router_url (str): The base URL of the router, e.g. "http://localhost:8080".

Returns:
dict: Counts of registered "mixed", "prefill" and "decode" instances.
"""
|
||||
if not router_url.startswith("http"):
|
||||
router_url = f"http://{router_url}"
|
||||
|
||||
try:
|
||||
response = requests.get(f"{router_url}/registered_number", timeout=60)
|
||||
registered_numbers = response.json()
|
||||
return registered_numbers
|
||||
except Exception:
|
||||
return {"mixed": 0, "prefill": 0, "decode": 0}
|
||||
|
||||
|
||||
def kill_process_on_port(port: int):
|
||||
"""
|
||||
Kill processes that are listening on the given port.
|
||||
Uses `lsof` to find process ids and sends SIGKILL.
|
||||
"""
|
||||
try:
|
||||
output = subprocess.check_output(f"lsof -i:{port} -t", shell=True).decode().strip()
|
||||
current_pid = os.getpid()
|
||||
parent_pid = os.getppid()
|
||||
for pid in output.splitlines():
|
||||
pid = int(pid)
|
||||
if pid in (current_pid, parent_pid):
|
||||
print(f"Skip killing current process (pid={pid}) on port {port}")
|
||||
continue
|
||||
os.kill(pid, signal.SIGKILL)
|
||||
print(f"Killed process on port {port}, pid={pid}")
|
||||
except subprocess.CalledProcessError:
|
||||
pass
|
||||
|
||||
|
||||
def clean_ports():
|
||||
"""
|
||||
Kill all processes occupying the ports listed in PORTS_TO_CLEAN.
|
||||
"""
|
||||
for port in PORTS_TO_CLEAN:
|
||||
kill_process_on_port(port)
|
||||
time.sleep(2)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session", autouse=True)
|
||||
def setup_and_run_server():
|
||||
"""
|
||||
Pytest fixture that runs once per test session:
|
||||
- Cleans ports before tests
|
||||
- Starts the router, prefill and decode servers as subprocesses
- Waits for both instances to register with the router (up to 5 minutes)
|
||||
- Tears down server after all tests finish
|
||||
"""
|
||||
print("Pre-test port cleanup...")
|
||||
clean_ports()
|
||||
|
||||
print("log dir clean ")
|
||||
if os.path.exists("log_router") and os.path.isdir("log_router"):
|
||||
shutil.rmtree("log_router")
|
||||
if os.path.exists("log_prefill") and os.path.isdir("log_prefill"):
|
||||
shutil.rmtree("log_prefill")
|
||||
if os.path.exists("log_decode") and os.path.isdir("log_decode"):
|
||||
shutil.rmtree("log_decode")
|
||||
|
||||
base_path = os.getenv("MODEL_PATH")
|
||||
if base_path:
|
||||
model_path = os.path.join(base_path, "ERNIE-4.5-0.3B-Paddle")
|
||||
else:
|
||||
model_path = "baidu/ERNIE-4.5-0.3B-Paddle"
|
||||
print(f"model_path: {model_path}")
|
||||
|
||||
# router
|
||||
print("start router...")
|
||||
env_router = os.environ.copy()
|
||||
env_router["FD_LOG_DIR"] = "log_router"
|
||||
router_log_path = "router.log"
|
||||
|
||||
router_cmd = [
|
||||
sys.executable,
|
||||
"-m",
|
||||
"fastdeploy.router.launch",
|
||||
"--port",
|
||||
str(FD_ROUTER_PORT),
|
||||
"--splitwise",
|
||||
]
|
||||
|
||||
with open(router_log_path, "w") as logfile:
|
||||
process_router = subprocess.Popen(
|
||||
router_cmd,
|
||||
stdout=logfile,
|
||||
stderr=subprocess.STDOUT,
|
||||
start_new_session=True, # Enables killing full group via os.killpg
|
||||
env=env_router,
|
||||
)
|
||||
|
||||
# prefill instance
|
||||
print("start prefill...")
|
||||
env_prefill = os.environ.copy()
|
||||
env_prefill["CUDA_VISIBLE_DEVICES"] = "0"
|
||||
env_prefill["ENABLE_V1_KVCACHE_SCHEDULER"] = "0"
|
||||
env_prefill["FD_LOG_DIR"] = "log_prefill"
|
||||
env_prefill["INFERENCE_MSG_QUEUE_ID"] = str(FD_API_PORT)
|
||||
prefill_log_path = "server.log"
|
||||
prefill_cmd = [
|
||||
sys.executable,
|
||||
"-m",
|
||||
"fastdeploy.entrypoints.openai.api_server",
|
||||
"--model",
|
||||
model_path,
|
||||
"--port",
|
||||
str(FD_API_PORT),
|
||||
"--tensor-parallel-size",
|
||||
"1",
|
||||
"--engine-worker-queue-port",
|
||||
str(FD_ENGINE_QUEUE_PORT),
|
||||
"--metrics-port",
|
||||
str(FD_METRICS_PORT),
|
||||
"--cache-queue-port",
|
||||
str(FD_CACHE_QUEUE_PORT),
|
||||
"--max-model-len",
|
||||
"8192",
|
||||
"--max-num-seqs",
|
||||
"20",
|
||||
"--quantization",
|
||||
"wint8",
|
||||
"--splitwise-role",
|
||||
"prefill",
|
||||
"--cache-transfer-protocol",
|
||||
"ipc",
|
||||
"--pd-comm-port",
|
||||
str(FD_CONNECTOR_PORT),
|
||||
"--router",
|
||||
f"0.0.0.0:{FD_ROUTER_PORT}",
|
||||
]
|
||||
|
||||
# Start subprocess in new process group
|
||||
with open(prefill_log_path, "w") as logfile:
|
||||
process_prefill = subprocess.Popen(
|
||||
prefill_cmd,
|
||||
stdout=logfile,
|
||||
stderr=subprocess.STDOUT,
|
||||
start_new_session=True, # Enables killing full group via os.killpg
|
||||
env=env_prefill,
|
||||
)
|
||||
time.sleep(1)
|
||||
|
||||
# decode instance
|
||||
print("start decode...")
|
||||
env_decode = os.environ.copy()
|
||||
env_decode["CUDA_VISIBLE_DEVICES"] = "1"
|
||||
env_decode["ENABLE_V1_KVCACHE_SCHEDULER"] = "0"
|
||||
env_decode["INFERENCE_MSG_QUEUE_ID"] = str(FD_API_PORT + 1)
|
||||
env_decode["FD_LOG_DIR"] = "log_decode"
|
||||
decode_log_path = "decode_server.log"
|
||||
decode_cmd = [
|
||||
sys.executable,
|
||||
"-m",
|
||||
"fastdeploy.entrypoints.openai.api_server",
|
||||
"--model",
|
||||
model_path,
|
||||
"--port",
|
||||
str(FD_API_PORT + 1),
|
||||
"--tensor-parallel-size",
|
||||
"1",
|
||||
"--engine-worker-queue-port",
|
||||
str(FD_ENGINE_QUEUE_PORT + 1),
|
||||
"--metrics-port",
|
||||
str(FD_METRICS_PORT + 1),
|
||||
"--cache-queue-port",
|
||||
str(FD_CACHE_QUEUE_PORT + 1),
|
||||
"--max-model-len",
|
||||
"8192",
|
||||
"--max-num-seqs",
|
||||
"20",
|
||||
"--quantization",
|
||||
"wint8",
|
||||
"--splitwise-role",
|
||||
"decode",
|
||||
"--cache-transfer-protocol",
|
||||
"ipc",
|
||||
"--pd-comm-port",
|
||||
str(FD_CONNECTOR_PORT + 1),
|
||||
"--router",
|
||||
f"0.0.0.0:{FD_ROUTER_PORT}",
|
||||
]
|
||||
|
||||
# Start subprocess in new process group
|
||||
with open(decode_log_path, "w") as logfile:
|
||||
process_decode = subprocess.Popen(
|
||||
decode_cmd,
|
||||
stdout=logfile,
|
||||
stderr=subprocess.STDOUT,
|
||||
start_new_session=True, # Enables killing full group via os.killpg
|
||||
env=env_decode,
|
||||
)
|
||||
|
||||
# Wait up to 300 seconds for API server to be ready
|
||||
for _ in range(60):
|
||||
registered_numbers = get_registered_number(f"0.0.0.0:{FD_ROUTER_PORT}")
|
||||
if registered_numbers["prefill"] >= 1 and registered_numbers["decode"] >= 1:
|
||||
print("Prefill and decode servers are both online")
|
||||
break
|
||||
time.sleep(5)
|
||||
else:
|
||||
print("[TIMEOUT] API server failed to start in 5 minutes. Cleaning up...")
|
||||
try:
|
||||
os.killpg(process_prefill.pid, signal.SIGTERM)
|
||||
os.killpg(process_decode.pid, signal.SIGTERM)
|
||||
clean_ports()
|
||||
except Exception as e:
|
||||
print(f"Failed to kill process group: {e}")
|
||||
raise RuntimeError(f"API server did not start on port {FD_API_PORT}")
|
||||
|
||||
yield # Run tests
|
||||
|
||||
print("\n===== Post-test server cleanup... =====")
|
||||
try:
|
||||
os.killpg(process_router.pid, signal.SIGTERM)
|
||||
os.killpg(process_prefill.pid, signal.SIGTERM)
|
||||
os.killpg(process_decode.pid, signal.SIGTERM)
|
||||
clean_ports()
|
||||
print(f"Prefill server (pid={process_prefill.pid}) terminated")
|
||||
print(f"Decode server (pid={process_decode.pid}) terminated")
|
||||
except Exception as e:
|
||||
print(f"Failed to terminate API server: {e}")
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def api_url(request):
|
||||
"""
|
||||
Returns the API endpoint URL for chat completions.
|
||||
"""
|
||||
return f"http://0.0.0.0:{FD_ROUTER_PORT}/v1/chat/completions"
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def metrics_url(request):
|
||||
"""
|
||||
Returns the metrics endpoint URL.
|
||||
"""
|
||||
return f"http://0.0.0.0:{FD_METRICS_PORT}/metrics"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def headers():
|
||||
"""
|
||||
Returns common HTTP request headers.
|
||||
"""
|
||||
return {"Content-Type": "application/json"}
|
||||
|
||||
|
||||
def test_metrics_config(metrics_url):
|
||||
timeout = 600
|
||||
url = metrics_url.replace("metrics", "config-info")
|
||||
res = requests.get(url, timeout=timeout)
|
||||
assert res.status_code == 200
|
||||
|
||||
|
||||
def send_request(url, payload, timeout=600):
|
||||
"""
|
||||
Send a request to the given URL and return the response.
|
||||
"""
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
try:
|
||||
res = requests.post(url, headers=headers, json=payload, timeout=timeout)
|
||||
print("🟢 接收响应中...\n")
|
||||
return res
|
||||
except requests.exceptions.Timeout:
|
||||
print(f"❌ 请求超时(超过 {timeout} 秒)")
|
||||
return None
|
||||
except requests.exceptions.RequestException as e:
|
||||
print(f"❌ 请求失败:{e}")
|
||||
return None
|
||||
|
||||
|
||||
def get_stream_chunks(response):
|
||||
"""解析流式返回,生成chunk List[dict]"""
|
||||
chunks = []
|
||||
|
||||
if response.status_code == 200:
|
||||
for line in response.iter_lines(decode_unicode=True):
|
||||
if line:
|
||||
if line.startswith("data: "):
|
||||
line = line[len("data: ") :]
|
||||
|
||||
if line.strip() == "[DONE]":
|
||||
break
|
||||
|
||||
try:
|
||||
chunk = json.loads(line)
|
||||
chunks.append(chunk)
|
||||
except Exception as e:
|
||||
print(f"解析失败: {e}, 行内容: {line}")
|
||||
else:
|
||||
print(f"请求失败,状态码: {response.status_code}")
|
||||
print("返回内容:", response.text)
|
||||
|
||||
return chunks
|
||||
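For reference, a small self-contained sketch of the SSE format get_stream_chunks consumes, using the same parsing steps; the sample lines are made up.

# Self-contained sketch of the stream format parsed above; the sample lines are made up.
import json

sample_stream = [
    'data: {"choices": [{"index": 0, "delta": {"content": "Newton"}}]}',
    'data: {"choices": [], "usage": {"prompt_tokens": 20, "completion_tokens": 50, "total_tokens": 70}}',
    "data: [DONE]",
]

chunks = []
for line in sample_stream:
    line = line[len("data: ") :]
    if line.strip() == "[DONE]":
        break
    chunks.append(json.loads(line))

print(chunks[-1]["usage"]["total_tokens"])  # -> 70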
|
||||
|
||||
def test_chat_usage_stream(api_url):
|
||||
"""测试流式chat usage"""
|
||||
payload = {
|
||||
"model": "default",
|
||||
"temperature": 0,
|
||||
"top_p": 0,
|
||||
"seed": 33,
|
||||
"messages": [
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
{"role": "user", "content": "牛顿的三大运动定律是什么?"},
|
||||
],
|
||||
"max_tokens": 50,
|
||||
"stream": True,
|
||||
"stream_options": {"include_usage": True, "continuous_usage_stats": True},
|
||||
"metadata": {"min_tokens": 10},
|
||||
}
|
||||
|
||||
response = send_request(url=api_url, payload=payload)
|
||||
chunks = get_stream_chunks(response)
|
||||
result = "".join([x["choices"][0]["delta"]["content"] for x in chunks[:-1]])
|
||||
print("Decode Response:", result)
|
||||
assert result != "", "结果为空"
|
||||
usage = chunks[-1]["usage"]
|
||||
total_tokens = usage["completion_tokens"] + usage["prompt_tokens"]
|
||||
assert payload["max_tokens"] >= usage["completion_tokens"], "completion_tokens大于max_tokens"
|
||||
assert payload["metadata"]["min_tokens"] <= usage["completion_tokens"], "completion_tokens小于min_tokens"
|
||||
assert usage["total_tokens"] == total_tokens, "total_tokens不等于prompt_tokens + completion_tokens"
|
||||
|
||||
|
||||
def test_chat_usage_non_stream(api_url):
|
||||
"""测试非流式chat usage"""
|
||||
payload = {
|
||||
"model": "default",
|
||||
"temperature": 0,
|
||||
"top_p": 0,
|
||||
"seed": 33,
|
||||
"messages": [
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
{"role": "user", "content": "牛顿的三大运动定律是什么?"},
|
||||
],
|
||||
"max_tokens": 50,
|
||||
"stream": False,
|
||||
"metadata": {"min_tokens": 10},
|
||||
}
|
||||
|
||||
response = send_request(url=api_url, payload=payload).json()
|
||||
usage = response["usage"]
|
||||
result = response["choices"][0]["message"]["content"]
|
||||
assert result != "", "结果为空"
|
||||
total_tokens = usage["completion_tokens"] + usage["prompt_tokens"]
|
||||
assert payload["max_tokens"] >= usage["completion_tokens"], "completion_tokens大于max_tokens"
|
||||
assert payload["metadata"]["min_tokens"] <= usage["completion_tokens"], "completion_tokens小于min_tokens"
|
||||
assert usage["total_tokens"] == total_tokens, "total_tokens不等于prompt_tokens + completion_tokens"
|
||||
|
||||
|
||||
def test_non_chat_usage_stream(api_url):
|
||||
"""测试流式非chat usage"""
|
||||
payload = {
|
||||
"model": "default",
|
||||
"temperature": 0,
|
||||
"top_p": 0,
|
||||
"seed": 33,
|
||||
"prompt": "牛顿的三大运动定律是什么?",
|
||||
"max_tokens": 50,
|
||||
"stream": True,
|
||||
"stream_options": {"include_usage": True, "continuous_usage_stats": True},
|
||||
"metadata": {"min_tokens": 10},
|
||||
}
|
||||
api_url = api_url.replace("chat/completions", "completions")
|
||||
|
||||
response = send_request(url=api_url, payload=payload)
|
||||
chunks = get_stream_chunks(response)
|
||||
result = "".join([x["choices"][0]["text"] for x in chunks[:-1]])
|
||||
print("Decode Response:", result)
|
||||
assert result != "", "结果为空"
|
||||
usage = chunks[-1]["usage"]
|
||||
total_tokens = usage["completion_tokens"] + usage["prompt_tokens"]
|
||||
assert payload["max_tokens"] >= usage["completion_tokens"], "completion_tokens大于max_tokens"
|
||||
assert payload["metadata"]["min_tokens"] <= usage["completion_tokens"], "completion_tokens小于min_tokens"
|
||||
assert usage["total_tokens"] == total_tokens, "total_tokens不等于prompt_tokens + completion_tokens"
|
||||
|
||||
|
||||
def test_non_chat_usage_non_stream(api_url):
|
||||
"""测试非流式非chat usage"""
|
||||
payload = {
|
||||
"model": "default",
|
||||
"temperature": 0,
|
||||
"top_p": 0,
|
||||
"seed": 33,
|
||||
"prompt": "牛顿的三大运动定律是什么?",
|
||||
"max_tokens": 50,
|
||||
"stream": False,
|
||||
"metadata": {"min_tokens": 10},
|
||||
}
|
||||
api_url = api_url.replace("chat/completions", "completions")
|
||||
|
||||
response = send_request(url=api_url, payload=payload).json()
|
||||
usage = response["usage"]
|
||||
result = response["choices"][0]["text"]
|
||||
print("Decode Response:", result)
|
||||
assert result != "", "结果为空"
|
||||
total_tokens = usage["completion_tokens"] + usage["prompt_tokens"]
|
||||
assert payload["max_tokens"] >= usage["completion_tokens"], "completion_tokens大于max_tokens"
|
||||
assert payload["metadata"]["min_tokens"] <= usage["completion_tokens"], "completion_tokens小于min_tokens"
|
||||
assert usage["total_tokens"] == total_tokens, "total_tokens不等于prompt_tokens + completion_tokens"
|
||||
486
tests/e2e/test_ernie_03b_router.py
Normal file
@@ -0,0 +1,486 @@
|
||||
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# Test for router and mixed server
|
||||
|
||||
import json
|
||||
import os
|
||||
import shutil
|
||||
import signal
|
||||
import socket
|
||||
import subprocess
|
||||
import sys
|
||||
import time
|
||||
|
||||
import pytest
|
||||
import requests
|
||||
|
||||
# Read ports from environment variables; use default values if not set
|
||||
FD_API_PORT = int(os.getenv("FD_API_PORT", 8188))
|
||||
FD_ENGINE_QUEUE_PORT = int(os.getenv("FD_ENGINE_QUEUE_PORT", 8133))
|
||||
FD_METRICS_PORT = int(os.getenv("FD_METRICS_PORT", 8233))
|
||||
FD_CACHE_QUEUE_PORT = int(os.getenv("FD_CACHE_QUEUE_PORT", 8333))
|
||||
FD_ROUTER_PORT = int(os.getenv("FD_ROUTER_PORT", 8533))
|
||||
|
||||
# List of ports to clean before and after tests
|
||||
PORTS_TO_CLEAN = [
|
||||
FD_API_PORT,
|
||||
FD_ENGINE_QUEUE_PORT,
|
||||
FD_METRICS_PORT,
|
||||
FD_CACHE_QUEUE_PORT,
|
||||
FD_API_PORT + 1,
|
||||
FD_ENGINE_QUEUE_PORT + 1,
|
||||
FD_METRICS_PORT + 1,
|
||||
FD_CACHE_QUEUE_PORT + 1,
|
||||
FD_ROUTER_PORT,
|
||||
]
|
||||
|
||||
|
||||
def is_port_open(host: str, port: int, timeout=1.0):
|
||||
"""
|
||||
Check if a TCP port is open on the given host.
|
||||
Returns True if connection succeeds, False otherwise.
|
||||
"""
|
||||
try:
|
||||
with socket.create_connection((host, port), timeout):
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
def check_service_health(base_url: str, timeout: int = 3) -> bool:
|
||||
"""
|
||||
Check the health status of a service.
|
||||
|
||||
Args:
|
||||
base_url (str): The base URL of the service, e.g. "http://127.0.0.1:8080"
|
||||
timeout (int): Request timeout in seconds.
|
||||
|
||||
Returns:
|
||||
bool: True if the service is healthy, False otherwise.
|
||||
"""
|
||||
if not base_url.startswith("http"):
|
||||
base_url = f"http://{base_url}"
|
||||
url = f"{base_url.rstrip('/')}/health"
|
||||
try:
|
||||
resp = requests.get(url, timeout=timeout)
|
||||
if resp.status_code == 200:
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
def get_registered_number(router_url) -> dict:
"""
Get the number of registered instances from the router.

Args:
router_url (str): The base URL of the router, e.g. "http://localhost:8080".

Returns:
dict: Counts of registered "mixed", "prefill" and "decode" instances.
"""
|
||||
if not router_url.startswith("http"):
|
||||
router_url = f"http://{router_url}"
|
||||
|
||||
try:
|
||||
response = requests.get(f"{router_url}/registered_number", timeout=60)
|
||||
registered_numbers = response.json()
|
||||
return registered_numbers
|
||||
except Exception:
|
||||
return {"mixed": 0, "prefill": 0, "decode": 0}
|
||||
|
||||
|
||||
def kill_process_on_port(port: int):
|
||||
"""
|
||||
Kill processes that are listening on the given port.
|
||||
Uses `lsof` to find process ids and sends SIGKILL.
|
||||
"""
|
||||
try:
|
||||
output = subprocess.check_output(f"lsof -i:{port} -t", shell=True).decode().strip()
|
||||
current_pid = os.getpid()
|
||||
parent_pid = os.getppid()
|
||||
for pid in output.splitlines():
|
||||
pid = int(pid)
|
||||
if pid in (current_pid, parent_pid):
|
||||
print(f"Skip killing current process (pid={pid}) on port {port}")
|
||||
continue
|
||||
os.kill(pid, signal.SIGKILL)
|
||||
print(f"Killed process on port {port}, pid={pid}")
|
||||
except subprocess.CalledProcessError:
|
||||
pass
|
||||
|
||||
|
||||
def clean_ports():
|
||||
"""
|
||||
Kill all processes occupying the ports listed in PORTS_TO_CLEAN.
|
||||
"""
|
||||
for port in PORTS_TO_CLEAN:
|
||||
kill_process_on_port(port)
|
||||
time.sleep(2)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session", autouse=True)
|
||||
def setup_and_run_server():
|
||||
"""
|
||||
Pytest fixture that runs once per test session:
|
||||
- Cleans ports before tests
|
||||
- Starts the router and two mixed API servers as subprocesses
- Waits for both servers to register with the router (up to 5 minutes)
|
||||
- Tears down server after all tests finish
|
||||
"""
|
||||
print("Pre-test port cleanup...")
|
||||
clean_ports()
|
||||
|
||||
print("log dir clean ")
|
||||
if os.path.exists("log_router") and os.path.isdir("log_router"):
|
||||
shutil.rmtree("log_router")
|
||||
if os.path.exists("log_server_0") and os.path.isdir("log_server_0"):
|
||||
shutil.rmtree("log_server_0")
|
||||
if os.path.exists("log_server_1") and os.path.isdir("log_server_1"):
|
||||
shutil.rmtree("log_server_1")
|
||||
|
||||
base_path = os.getenv("MODEL_PATH")
|
||||
if base_path:
|
||||
model_path = os.path.join(base_path, "ERNIE-4.5-0.3B-Paddle")
|
||||
else:
|
||||
model_path = "baidu/ERNIE-4.5-0.3B-Paddle"
|
||||
print(f"model_path: {model_path}")
|
||||
|
||||
# router
|
||||
print("start router...")
|
||||
env_router = os.environ.copy()
|
||||
env_router["FD_LOG_DIR"] = "log_router"
|
||||
router_log_path = "router.log"
|
||||
|
||||
router_cmd = [
|
||||
sys.executable,
|
||||
"-m",
|
||||
"fastdeploy.router.launch",
|
||||
"--port",
|
||||
str(FD_ROUTER_PORT),
|
||||
]
|
||||
|
||||
with open(router_log_path, "w") as logfile:
|
||||
process_router = subprocess.Popen(
|
||||
router_cmd,
|
||||
stdout=logfile,
|
||||
stderr=subprocess.STDOUT,
|
||||
start_new_session=True, # Enables killing full group via os.killpg
|
||||
env=env_router,
|
||||
)
|
||||
|
||||
# server0
|
||||
print("start server0...")
|
||||
env_server_0 = os.environ.copy()
|
||||
env_server_0["CUDA_VISIBLE_DEVICES"] = "0"
|
||||
env_server_0["ENABLE_V1_KVCACHE_SCHEDULER"] = "0"
|
||||
env_server_0["FD_LOG_DIR"] = "log_server_0"
|
||||
env_server_0["INFERENCE_MSG_QUEUE_ID"] = str(FD_API_PORT)
|
||||
log_path = "server_0.log"
|
||||
cmd = [
|
||||
sys.executable,
|
||||
"-m",
|
||||
"fastdeploy.entrypoints.openai.api_server",
|
||||
"--model",
|
||||
model_path,
|
||||
"--port",
|
||||
str(FD_API_PORT),
|
||||
"--tensor-parallel-size",
|
||||
"1",
|
||||
"--engine-worker-queue-port",
|
||||
str(FD_ENGINE_QUEUE_PORT),
|
||||
"--metrics-port",
|
||||
str(FD_METRICS_PORT),
|
||||
"--cache-queue-port",
|
||||
str(FD_CACHE_QUEUE_PORT),
|
||||
"--max-model-len",
|
||||
"8192",
|
||||
"--max-num-seqs",
|
||||
"20",
|
||||
"--quantization",
|
||||
"wint8",
|
||||
"--router",
|
||||
f"0.0.0.0:{FD_ROUTER_PORT}",
|
||||
]
|
||||
|
||||
# Start subprocess in new process group
|
||||
with open(log_path, "w") as logfile:
|
||||
process_server_0 = subprocess.Popen(
|
||||
cmd,
|
||||
stdout=logfile,
|
||||
stderr=subprocess.STDOUT,
|
||||
start_new_session=True, # Enables killing full group via os.killpg
|
||||
env=env_server_0,
|
||||
)
|
||||
time.sleep(1)
    # server 1
    print("start server 1...")
    env_server_1 = os.environ.copy()
    env_server_1["CUDA_VISIBLE_DEVICES"] = "1"
    env_server_1["ENABLE_V1_KVCACHE_SCHEDULER"] = "0"
    env_server_1["INFERENCE_MSG_QUEUE_ID"] = str(FD_API_PORT + 1)
    env_server_1["FD_LOG_DIR"] = "log_server_1"
    log_path = "server_1.log"
    cmd = [
        sys.executable,
        "-m",
        "fastdeploy.entrypoints.openai.api_server",
        "--model",
        model_path,
        "--port",
        str(FD_API_PORT + 1),
        "--tensor-parallel-size",
        "1",
        "--engine-worker-queue-port",
        str(FD_ENGINE_QUEUE_PORT + 1),
        "--metrics-port",
        str(FD_METRICS_PORT + 1),
        "--cache-queue-port",
        str(FD_CACHE_QUEUE_PORT + 1),
        "--max-model-len",
        "8192",
        "--max-num-seqs",
        "20",
        "--quantization",
        "wint8",
        "--router",
        f"0.0.0.0:{FD_ROUTER_PORT}",
    ]

    # Start subprocess in new process group
    with open(log_path, "w") as logfile:
        process_server_1 = subprocess.Popen(
            cmd,
            stdout=logfile,
            stderr=subprocess.STDOUT,
            start_new_session=True,  # Enables killing full group via os.killpg
            env=env_server_1,
        )

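    # Both api_server instances above were given "--router" pointing at the router started
    # earlier, so they are expected to register themselves with it; that registration is
    # what the readiness check below polls for.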
    # Wait up to 300 seconds for both servers to register with the router
    for _ in range(60):
        registered_numbers = get_registered_number(f"0.0.0.0:{FD_ROUTER_PORT}")
        if registered_numbers["mixed"] >= 2:
            print("Mixed servers are both online")
            break
        time.sleep(5)
    else:
        print("[TIMEOUT] API server failed to start in 5 minutes. Cleaning up...")
        try:
            os.killpg(process_server_0.pid, signal.SIGTERM)
            os.killpg(process_server_1.pid, signal.SIGTERM)
            clean_ports()
        except Exception as e:
            print(f"Failed to kill process group: {e}")
        raise RuntimeError(f"API server did not start on port {FD_API_PORT}")

    yield  # Run tests

    print("\n===== Post-test server cleanup... =====")
    try:
        os.killpg(process_router.pid, signal.SIGTERM)
        os.killpg(process_server_0.pid, signal.SIGTERM)
        os.killpg(process_server_1.pid, signal.SIGTERM)
        clean_ports()
        print(f"server (pid={process_server_0.pid}) terminated")
        print(f"server (pid={process_server_1.pid}) terminated")
    except Exception as e:
        print(f"Failed to terminate API server: {e}")


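# A minimal helper sketch of the readiness loop in the fixture above, assuming that
# get_registered_number() (defined earlier in this file) returns a dict of registered
# server counts keyed by role, e.g. {"mixed": 2}, as implied by the
# registered_numbers["mixed"] check.
def wait_for_registration(router_addr, expected=2, retries=60, interval=5):
    """Poll the router until `expected` mixed-role servers have registered."""
    for _ in range(retries):
        if get_registered_number(router_addr).get("mixed", 0) >= expected:
            return True
        time.sleep(interval)
    return False

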
@pytest.fixture(scope="session")
def api_url(request):
    """
    Returns the API endpoint URL for chat completions.
    """
    return f"http://0.0.0.0:{FD_ROUTER_PORT}/v1/chat/completions"


@pytest.fixture(scope="session")
def metrics_url(request):
    """
    Returns the metrics endpoint URL.
    """
    return f"http://0.0.0.0:{FD_METRICS_PORT}/metrics"


@pytest.fixture
def headers():
    """
    Returns common HTTP request headers.
    """
    return {"Content-Type": "application/json"}


def test_metrics_config(metrics_url):
    """Check that the config-info endpoint derived from the metrics URL responds with 200."""
    timeout = 600
    url = metrics_url.replace("metrics", "config-info")
    res = requests.get(url, timeout=timeout)
    assert res.status_code == 200


def send_request(url, payload, timeout=600):
    """
    Send a request to the given URL and return the response.
    """
    headers = {
        "Content-Type": "application/json",
    }

    try:
        res = requests.post(url, headers=headers, json=payload, timeout=timeout)
        print("🟢 Receiving response...\n")
        return res
    except requests.exceptions.Timeout:
        print(f"❌ Request timed out (over {timeout} seconds)")
        return None
    except requests.exceptions.RequestException as e:
        print(f"❌ Request failed: {e}")
        return None


def get_stream_chunks(response):
    """Parse a streaming response into a list of chunk dicts."""
    chunks = []

    if response.status_code == 200:
        for line in response.iter_lines(decode_unicode=True):
            if line:
                if line.startswith("data: "):
                    line = line[len("data: ") :]

                if line.strip() == "[DONE]":
                    break

                try:
                    chunk = json.loads(line)
                    chunks.append(chunk)
                except Exception as e:
                    print(f"Failed to parse chunk: {e}, line content: {line}")
    else:
        print(f"Request failed, status code: {response.status_code}")
        print("Response body:", response.text)

    return chunks


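# For reference, get_stream_chunks() assumes OpenAI-compatible SSE framing, which is why it
# strips the "data: " prefix and stops at the terminator. A typical stream looks like:
#   data: {"id": "...", "choices": [{"delta": {"content": "..."}}], ...}
#   data: {"id": "...", "choices": [], "usage": {"prompt_tokens": ..., "completion_tokens": ..., "total_tokens": ...}}
#   data: [DONE]
# The trailing usage-only chunk is what the streaming tests below read via chunks[-1]["usage"].

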
def test_chat_usage_stream(api_url):
    """Test usage accounting for streaming chat completions."""
    payload = {
        "model": "default",
        "temperature": 0,
        "top_p": 0,
        "seed": 33,
        "messages": [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "牛顿的三大运动定律是什么?"},
        ],
        "max_tokens": 50,
        "stream": True,
        "stream_options": {"include_usage": True, "continuous_usage_stats": True},
        "metadata": {"min_tokens": 10},
    }

    response = send_request(url=api_url, payload=payload)
    chunks = get_stream_chunks(response)
    result = "".join([x["choices"][0]["delta"]["content"] for x in chunks[:-1]])
    print("Response:", result)
    assert result != "", "result is empty"
    usage = chunks[-1]["usage"]
    total_tokens = usage["completion_tokens"] + usage["prompt_tokens"]
    assert payload["max_tokens"] >= usage["completion_tokens"], "completion_tokens exceeds max_tokens"
    assert payload["metadata"]["min_tokens"] <= usage["completion_tokens"], "completion_tokens is below min_tokens"
    assert usage["total_tokens"] == total_tokens, "total_tokens != prompt_tokens + completion_tokens"


def test_chat_usage_non_stream(api_url):
    """Test usage accounting for non-streaming chat completions."""
    payload = {
        "model": "default",
        "temperature": 0,
        "top_p": 0,
        "seed": 33,
        "messages": [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "牛顿的三大运动定律是什么?"},
        ],
        "max_tokens": 50,
        "stream": False,
        "metadata": {"min_tokens": 10},
    }

    response = send_request(url=api_url, payload=payload).json()
    usage = response["usage"]
    result = response["choices"][0]["message"]["content"]
    assert result != "", "result is empty"
    total_tokens = usage["completion_tokens"] + usage["prompt_tokens"]
    assert payload["max_tokens"] >= usage["completion_tokens"], "completion_tokens exceeds max_tokens"
    assert payload["metadata"]["min_tokens"] <= usage["completion_tokens"], "completion_tokens is below min_tokens"
    assert usage["total_tokens"] == total_tokens, "total_tokens != prompt_tokens + completion_tokens"


def test_non_chat_usage_stream(api_url):
    """Test usage accounting for streaming (non-chat) completions."""
    payload = {
        "model": "default",
        "temperature": 0,
        "top_p": 0,
        "seed": 33,
        "prompt": "牛顿的三大运动定律是什么?",
        "max_tokens": 50,
        "stream": True,
        "stream_options": {"include_usage": True, "continuous_usage_stats": True},
        "metadata": {"min_tokens": 10},
    }
    api_url = api_url.replace("chat/completions", "completions")

    response = send_request(url=api_url, payload=payload)
    chunks = get_stream_chunks(response)
    result = "".join([x["choices"][0]["text"] for x in chunks[:-1]])
    print("Response:", result)
    assert result != "", "result is empty"
    usage = chunks[-1]["usage"]
    total_tokens = usage["completion_tokens"] + usage["prompt_tokens"]
    assert payload["max_tokens"] >= usage["completion_tokens"], "completion_tokens exceeds max_tokens"
    assert payload["metadata"]["min_tokens"] <= usage["completion_tokens"], "completion_tokens is below min_tokens"
    assert usage["total_tokens"] == total_tokens, "total_tokens != prompt_tokens + completion_tokens"


def test_non_chat_usage_non_stream(api_url):
    """Test usage accounting for non-streaming (non-chat) completions."""
    payload = {
        "model": "default",
        "temperature": 0,
        "top_p": 0,
        "seed": 33,
        "prompt": "牛顿的三大运动定律是什么?",
        "max_tokens": 50,
        "stream": False,
        "metadata": {"min_tokens": 10},
    }
    api_url = api_url.replace("chat/completions", "completions")

    response = send_request(url=api_url, payload=payload).json()
    usage = response["usage"]
    result = response["choices"][0]["text"]
    print("Response:", result)
    assert result != "", "result is empty"
    total_tokens = usage["completion_tokens"] + usage["prompt_tokens"]
    assert payload["max_tokens"] >= usage["completion_tokens"], "completion_tokens exceeds max_tokens"
    assert payload["metadata"]["min_tokens"] <= usage["completion_tokens"], "completion_tokens is below min_tokens"
    assert usage["total_tokens"] == total_tokens, "total_tokens != prompt_tokens + completion_tokens"
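
# Assumed local invocation for this suite (the model directory layout matches the fixture
# above; ports come from the FD_* constants defined or imported earlier in this file):
#   MODEL_PATH=/path/to/models python -m pytest -sv <this test file>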