[Feature] [PD] add simple router and refine splitwise deployment (#4709)

* add simple router and refine splitwise deployment

* fix
Juncai
2025-11-06 14:56:02 +08:00
committed by GitHub
parent 831266da7a
commit 08ca0f6aea
39 changed files with 2397 additions and 171 deletions

View File

@@ -94,10 +94,11 @@ async def async_request_eb_openai_chat_completions(
"stream_options": {
"include_usage": True,
"continuous_usage_stats": True,
}
},
"max_tokens": request_func_input.output_len,
}
if request_func_input.response_format:
payload["response_format"] =request_func_input.response_format
payload["response_format"] = request_func_input.response_format
# hyperparameters are passed in via yaml
payload.update(request_func_input.hyper_parameters)
@@ -132,13 +133,13 @@ async def async_request_eb_openai_chat_completions(
chunk = chunk_bytes.decode("utf-8").removeprefix("data: ")
if chunk != "[DONE]":
#print("####chunk:", chunk, type(chunk))
# print("####chunk:", chunk, type(chunk))
timestamp = time.perf_counter()
data = json.loads(chunk)
if request_id == "None" and "id" in data:
request_id = data["id"]
if choices := data.get("choices"):
content = choices[0]["delta"].get("content")
reason_content = choices[0]["delta"].get("reasoning_content")
@@ -164,7 +165,6 @@ async def async_request_eb_openai_chat_completions(
elif usage := data.get("usage", {}):
output.output_tokens = usage.get("completion_tokens", 0)
output.prompt_tokens = usage.get("prompt_tokens", 0)
most_recent_timestamp = timestamp

View File

@@ -46,7 +46,7 @@ class SampleRequest:
prompt_len: int
expected_output_len: int
response_format: Optional[dict] = None
class BenchmarkDataset(ABC):
"""BenchmarkDataset"""
@@ -299,7 +299,7 @@ class EBChatDataset(BenchmarkDataset):
prompt = entry["messages"][-1].get("content", "")
history_QA = entry.get("messages", [])
response_format = entry.get("response_format")
new_output_len = int(entry.get("max_tokens", 12288))
new_output_len = int(entry.get("max_tokens", output_len if output_len else 12288))
if enable_multimodal_chat:
prompt = self.apply_multimodal_chat_transformation(prompt, None)
@@ -311,7 +311,7 @@ class EBChatDataset(BenchmarkDataset):
prompt_len=0,
history_QA=history_QA,
expected_output_len=new_output_len,
response_format=response_format
response_format=response_format,
)
)
cnt += 1

View File

@@ -352,7 +352,7 @@ async def benchmark(
ignore_eos=ignore_eos,
debug=debug,
extra_body=extra_body,
response_format=response_format
response_format=response_format,
)
print("test_input:", test_input)
@@ -384,7 +384,7 @@ async def benchmark(
logprobs=logprobs,
ignore_eos=ignore_eos,
extra_body=extra_body,
response_format=response_format
response_format=response_format,
)
profile_output = await request_func(request_func_input=profile_input)
if profile_output.success:
@@ -444,7 +444,7 @@ async def benchmark(
debug=debug,
ignore_eos=ignore_eos,
extra_body=extra_body,
response_format=response_format
response_format=response_format,
)
tasks.append(asyncio.create_task(limited_request_func(request_func_input=request_func_input, pbar=pbar)))
outputs: list[RequestFuncOutput] = await asyncio.gather(*tasks)
@@ -460,7 +460,7 @@ async def benchmark(
api_url=base_url + "/stop_profile",
output_len=test_output_len,
logprobs=logprobs,
response_format=response_format
response_format=response_format,
)
profile_output = await request_func(request_func_input=profile_input)
if profile_output.success:

View File

@@ -3,4 +3,4 @@ max_num_seqs: 128
gpu_memory_utilization: 0.85
tensor_parallel_size: 1
limit_mm_per_prompt: '{"image": 100, "video": 100}'
enable_mm: True
enable_mm: True

View File

@@ -5,4 +5,4 @@ metadata:
max_tokens: 32768
repetition_penalty: 1.05
frequency_penalty: 0
presence_penalty: 0
presence_penalty: 0

View File

@@ -26,7 +26,7 @@ We recommend using mpirun for one-command startup without manually starting each
4. Ensure all nodes can resolve each other's hostnames
* Online inference startup example:
```shell
python -m fastdeploy.entrypoints.openai.api_server \
--model baidu/ERNIE-4.5-300B-A47B-Paddle \
@@ -40,7 +40,7 @@ We recommend using mpirun for one-command startup without manually starting each
```
* Offline startup example:
```python
from fastdeploy.engine.sampling_params import SamplingParams
from fastdeploy.entrypoints.llm import LLM

View File

@@ -26,7 +26,7 @@
4. Ensure all nodes can resolve each other's hostnames
* Online inference startup example:
```shell
python -m fastdeploy.entrypoints.openai.api_server \
--model baidu/ERNIE-4.5-300B-A47B-Paddle \
@@ -40,7 +40,7 @@
```
* Offline startup example:
```python
from fastdeploy.engine.sampling_params import SamplingParams
from fastdeploy.entrypoints.llm import LLM

View File

@@ -0,0 +1,71 @@
#!/bin/bash
set -e
wait_for_health() {
local server_port=$1
while true; do
status_code=$(curl -s -o /dev/null -w "%{http_code}" "http://0.0.0.0:${server_port}/health" || echo "000")
if [ "$status_code" -eq 200 ]; then
break
else
echo "Service not ready. Retrying in 2s..."
sleep 2
fi
done
}
# prepare environment
MODEL_NAME="PaddlePaddle/ERNIE-4.5-0.3B-Paddle"
# MODEL_NAME="baidu/ERNIE-4.5-21B-A3B-Paddle"
export FD_DEBUG=1
export ENABLE_V1_KVCACHE_SCHEDULER=0
export KVCACHE_GDRCOPY_FLUSH_ENABLE=1
unset http_proxy && unset https_proxy
rm -rf log_*
# start router
export FD_LOG_DIR="log_router"
mkdir -p ${FD_LOG_DIR}
router_port=9000
nohup python -m fastdeploy.router.launch \
--port ${router_port} \
2>&1 >${FD_LOG_DIR}/nohup &
sleep 1
# start modelserver 0
export CUDA_VISIBLE_DEVICES=0
export FD_LOG_DIR="log_server_0"
mkdir -p ${FD_LOG_DIR}
nohup python -m fastdeploy.entrypoints.openai.api_server \
--model ${MODEL_NAME} \
--port 8100 \
--metrics-port 8101 \
--engine-worker-queue-port 8102 \
--cache-queue-port 8103 \
--max-model-len 32768 \
--router "0.0.0.0:${router_port}" \
2>&1 >${FD_LOG_DIR}/nohup &
sleep 1
wait_for_health 8100
# start modelserver 1
export CUDA_VISIBLE_DEVICES=1
export FD_LOG_DIR="log_server_1"
mkdir -p ${FD_LOG_DIR}
nohup python -m fastdeploy.entrypoints.openai.api_server \
--model ${MODEL_NAME} \
--port 8200 \
--metrics-port 8201 \
--engine-worker-queue-port 8202 \
--cache-queue-port 8203 \
--max-model-len 32768 \
--router "0.0.0.0:${router_port}" \
2>&1 >${FD_LOG_DIR}/nohup &
wait_for_health 8200

View File

@@ -0,0 +1,66 @@
#!/bin/bash
set -e
# Test splitwise deployment
# v0 requires prefill and decode in one node and it uses local scheduler
# v1 supports prefill and decode in multi node and it uses splitwise scheduler
# v2 supports prefill and decode in multi node and it uses router and local scheduler
wait_for_health() {
local server_port=$1
while true; do
status_code=$(curl -s -o /dev/null -w "%{http_code}" "http://0.0.0.0:${server_port}/health" || echo "000")
if [ "$status_code" -eq 200 ]; then
break
else
echo "Service not ready. Retrying in 2s..."
sleep 2
fi
done
}
MODEL_NAME="PaddlePaddle/ERNIE-4.5-0.3B-Paddle"
# MODEL_NAME="baidu/ERNIE-4.5-21B-A3B-Paddle"
aistudio download --model ${MODEL_NAME}
unset http_proxy && unset https_proxy
rm -rf log_*
# start prefill
export FD_LOG_DIR="log_prefill"
mkdir -p ${FD_LOG_DIR}
export CUDA_VISIBLE_DEVICES=0
export FD_DEBUG=1
export ENABLE_V1_KVCACHE_SCHEDULER=0
nohup python -m fastdeploy.entrypoints.openai.api_server \
--model ${MODEL_NAME} \
--port 8100 \
--metrics-port 8101 \
--engine-worker-queue-port 8102 \
--cache-queue-port 8103 \
--max-model-len 32768 \
--splitwise-role "prefill" \
2>&1 >${FD_LOG_DIR}/nohup &
wait_for_health 8100
# start decode
export FD_LOG_DIR="log_decode"
mkdir -p ${FD_LOG_DIR}
export CUDA_VISIBLE_DEVICES=1
export FD_DEBUG=1
export ENABLE_V1_KVCACHE_SCHEDULER=0
nohup python -m fastdeploy.entrypoints.openai.api_server \
--model ${MODEL_NAME} \
--port 9000 \
--metrics-port 9001 \
--engine-worker-queue-port 9002 \
--cache-queue-port 9003 \
--max-model-len 32768 \
--splitwise-role "decode" \
--innode-prefill-ports 8102 \
2>&1 >${FD_LOG_DIR}/nohup &
wait_for_health 9000

View File

@@ -0,0 +1,96 @@
#!/bin/bash
set -e
# Test splitwise deployment
# v0 requires prefill and decode in one node and it uses local scheduler
# v1 supports prefill and decode in multi node and it uses splitwise scheduler
# v2 supports prefill and decode in multi node and it uses router and local scheduler
wait_for_health() {
local server_port=$1
while true; do
status_code=$(curl -s -o /dev/null -w "%{http_code}" "http://0.0.0.0:${server_port}/health" || echo "000")
if [ "$status_code" -eq 200 ]; then
break
else
echo "Service not ready. Retrying in 2s..."
sleep 2
fi
done
}
# prepare environment
MODEL_NAME="PaddlePaddle/ERNIE-4.5-0.3B-Paddle"
# MODEL_NAME="baidu/ERNIE-4.5-21B-A3B-Paddle"
export FD_DEBUG=1
export ENABLE_V1_KVCACHE_SCHEDULER=0
export KVCACHE_GDRCOPY_FLUSH_ENABLE=1
SCRIPT_PATH=$(readlink -f "$0")
SCRIPT_DIR=$(dirname "$SCRIPT_PATH")
export $(bash ${SCRIPT_DIR}/../../scripts/get_rdma_nics.sh gpu)
echo "KVCACHE_RDMA_NICS:${KVCACHE_RDMA_NICS}"
if [ -z "${KVCACHE_RDMA_NICS}" ]; then
echo "KVCACHE_RDMA_NICS is empty, please check the output of get_rdma_nics.sh"
exit 1
fi
unset http_proxy && unset https_proxy
rm -rf log_*
# start redis
if ! redis-cli ping &>/dev/null; then
echo "Redis is not running. Starting redis-server..."
redis-server --daemonize yes
sleep 1
else
echo "Redis is already running."
fi
sleep 1
# start prefill
export CUDA_VISIBLE_DEVICES=0
export FD_LOG_DIR="log_prefill"
mkdir -p ${FD_LOG_DIR}
nohup python -m fastdeploy.entrypoints.openai.api_server \
--model ${MODEL_NAME} \
--port 8100 \
--metrics-port 8101 \
--engine-worker-queue-port 8102 \
--cache-queue-port 8103 \
--max-model-len 32768 \
--splitwise-role "prefill" \
--cache-transfer-protocol "rdma,ipc" \
--rdma-comm-ports 8104 \
--pd-comm-port 8105 \
--scheduler-name "splitwise" \
--scheduler-host "127.0.0.1" \
--scheduler-port 6379 \
--scheduler-ttl 9000 \
2>&1 >${FD_LOG_DIR}/nohup &
wait_for_health 8100
# start decode
export CUDA_VISIBLE_DEVICES=1
export FD_LOG_DIR="log_decode"
mkdir -p ${FD_LOG_DIR}
nohup python -m fastdeploy.entrypoints.openai.api_server \
--model ${MODEL_NAME} \
--port 9000 \
--metrics-port 9001 \
--engine-worker-queue-port 9002 \
--cache-queue-port 9003 \
--max-model-len 32768 \
--splitwise-role "decode" \
--cache-transfer-protocol "rdma,ipc" \
--rdma-comm-ports 9004 \
--pd-comm-port 9005 \
--scheduler-name "splitwise" \
--scheduler-host "127.0.0.1" \
--scheduler-port 6379 \
--scheduler-ttl 9000 \
2>&1 >${FD_LOG_DIR}/nohup &
wait_for_health 9000

View File

@@ -0,0 +1,98 @@
#!/bin/bash
set -e
# Test splitwise deployment
# v0 requires prefill and decode in one node and it uses local scheduler
# v1 supports prefill and decode in multi node and it uses splitwise scheduler
# v2 supports prefill and decode in multi node and it uses router and local scheduler
wait_for_health() {
local server_port=$1
while true; do
status_code=$(curl -s -o /dev/null -w "%{http_code}" "http://0.0.0.0:${server_port}/health" || echo "000")
if [ "$status_code" -eq 200 ]; then
break
else
echo "Service not ready. Retrying in 2s..."
sleep 2
fi
done
}
# prepare environment
MODEL_NAME="PaddlePaddle/ERNIE-4.5-0.3B-Paddle"
# MODEL_NAME="baidu/ERNIE-4.5-21B-A3B-Paddle"
export FD_DEBUG=1
export ENABLE_V1_KVCACHE_SCHEDULER=0
export KVCACHE_GDRCOPY_FLUSH_ENABLE=1
SCRIPT_PATH=$(readlink -f "$0")
SCRIPT_DIR=$(dirname "$SCRIPT_PATH")
export $(bash ${SCRIPT_DIR}/../../scripts/get_rdma_nics.sh gpu)
echo "KVCACHE_RDMA_NICS:${KVCACHE_RDMA_NICS}"
if [ -z "${KVCACHE_RDMA_NICS}" ]; then
echo "KVCACHE_RDMA_NICS is empty, please check the output of get_rdma_nics.sh"
exit 1
fi
unset http_proxy && unset https_proxy
rm -rf log_*
# start redis
if ! redis-cli ping &>/dev/null; then
echo "Redis is not running. Starting redis-server..."
redis-server --daemonize yes
sleep 1
else
echo "Redis is already running."
fi
sleep 1
# start prefill
export CUDA_VISIBLE_DEVICES=0,1
export FD_LOG_DIR="log_prefill"
mkdir -p ${FD_LOG_DIR}
nohup python -m fastdeploy.entrypoints.openai.api_server \
--model ${MODEL_NAME} \
--port 8100 \
--metrics-port 8101 \
--engine-worker-queue-port 8102 \
--cache-queue-port 8103 \
--max-model-len 32768 \
--tensor-parallel-size 2 \
--splitwise-role "prefill" \
--cache-transfer-protocol "rdma,ipc" \
--pd-comm-port 8104 \
--rdma-comm-ports 8105,8106 \
--scheduler-name "splitwise" \
--scheduler-host "127.0.0.1" \
--scheduler-port 6379 \
--scheduler-ttl 9000 \
2>&1 >${FD_LOG_DIR}/nohup &
wait_for_health 8100
# start decode
export CUDA_VISIBLE_DEVICES=2,3
export FD_LOG_DIR="log_decode"
mkdir -p ${FD_LOG_DIR}
nohup python -m fastdeploy.entrypoints.openai.api_server \
--model ${MODEL_NAME} \
--port 9000 \
--metrics-port 9001 \
--engine-worker-queue-port 9002 \
--cache-queue-port 9003 \
--max-model-len 32768 \
--tensor-parallel-size 2 \
--splitwise-role "decode" \
--cache-transfer-protocol "rdma,ipc" \
--pd-comm-port 9004 \
--rdma-comm-ports 9005,9006 \
--scheduler-name "splitwise" \
--scheduler-host "127.0.0.1" \
--scheduler-port 6379 \
--scheduler-ttl 9000 \
2>&1 >${FD_LOG_DIR}/nohup &
wait_for_health 9000

View File

@@ -0,0 +1,93 @@
#!/bin/bash
set -e
# Test splitwise deployment
# v0 requires prefill and decode in one node and it uses local scheduler
# v1 supports prefill and decode in multi node and it uses splitwise scheduler
# v2 supports prefill and decode in multi node and it uses router and local scheduler
wait_for_health() {
local server_port=$1
while true; do
status_code=$(curl -s -o /dev/null -w "%{http_code}" "http://0.0.0.0:${server_port}/health" || echo "000")
if [ "$status_code" -eq 200 ]; then
break
else
echo "Service not ready. Retrying in 2s..."
sleep 2
fi
done
}
# prepare environment
MODEL_NAME="PaddlePaddle/ERNIE-4.5-0.3B-Paddle"
# MODEL_NAME="baidu/ERNIE-4.5-21B-A3B-Paddle"
export FD_DEBUG=1
export ENABLE_V1_KVCACHE_SCHEDULER=0
export KVCACHE_GDRCOPY_FLUSH_ENABLE=1
SCRIPT_PATH=$(readlink -f "$0")
SCRIPT_DIR=$(dirname "$SCRIPT_PATH")
export $(bash ${SCRIPT_DIR}/../../scripts/get_rdma_nics.sh gpu)
echo "KVCACHE_RDMA_NICS:${KVCACHE_RDMA_NICS}"
if [ -z "${KVCACHE_RDMA_NICS}" ]; then
echo "KVCACHE_RDMA_NICS is empty, please check the output of get_rdma_nics.sh"
exit 1
fi
unset http_proxy && unset https_proxy
rm -rf log_*
# start router
export FD_LOG_DIR="log_router"
mkdir -p ${FD_LOG_DIR}
router_port=9000
nohup python -m fastdeploy.router.launch \
--port ${router_port} \
--splitwise \
2>&1 >${FD_LOG_DIR}/nohup &
sleep 1
# start prefill
export CUDA_VISIBLE_DEVICES=0
export FD_LOG_DIR="log_prefill"
mkdir -p ${FD_LOG_DIR}
nohup python -m fastdeploy.entrypoints.openai.api_server \
--model ${MODEL_NAME} \
--port 8100 \
--metrics-port 8101 \
--engine-worker-queue-port 8102 \
--cache-queue-port 8103 \
--max-model-len 32768 \
--splitwise-role "prefill" \
--cache-transfer-protocol "ipc,rdma" \
--rdma-comm-ports 8104 \
--pd-comm-port 8105 \
--router "0.0.0.0:${router_port}" \
2>&1 >${FD_LOG_DIR}/nohup &
wait_for_health 8100
# start decode
export CUDA_VISIBLE_DEVICES=1
export FD_LOG_DIR="log_decode"
mkdir -p ${FD_LOG_DIR}
nohup python -m fastdeploy.entrypoints.openai.api_server \
--model ${MODEL_NAME} \
--port 8200 \
--metrics-port 8201 \
--engine-worker-queue-port 8202 \
--cache-queue-port 8203 \
--max-model-len 32768 \
--splitwise-role "decode" \
--cache-transfer-protocol "ipc,rdma" \
--rdma-comm-ports 8204 \
--pd-comm-port 8205 \
--router "0.0.0.0:${router_port}" \
2>&1 >${FD_LOG_DIR}/nohup &
wait_for_health 8200

View File

@@ -0,0 +1,96 @@
#!/bin/bash
set -e
# Test splitwise deployment
# v0 requires prefill and decode in one node and it uses local scheduler
# v1 supports prefill and decode in multi node and it uses splitwise scheduler
# v2 supports prefill and decode in multi node and it uses router and local scheduler
wait_for_health() {
local server_port=$1
while true; do
status_code=$(curl -s -o /dev/null -w "%{http_code}" "http://0.0.0.0:${server_port}/health" || echo "000")
if [ "$status_code" -eq 200 ]; then
break
else
echo "Service not ready. Retrying in 2s..."
sleep 2
fi
done
}
# prepare environment
MODEL_NAME="PaddlePaddle/ERNIE-4.5-0.3B-Paddle"
# MODEL_NAME="baidu/ERNIE-4.5-21B-A3B-Paddle"
export FD_DEBUG=1
export ENABLE_V1_KVCACHE_SCHEDULER=0
export KVCACHE_GDRCOPY_FLUSH_ENABLE=1
SCRIPT_PATH=$(readlink -f "$0")
SCRIPT_DIR=$(dirname "$SCRIPT_PATH")
export $(bash ${SCRIPT_DIR}/../../scripts/get_rdma_nics.sh gpu)
echo "KVCACHE_RDMA_NICS:${KVCACHE_RDMA_NICS}"
if [ -z "${KVCACHE_RDMA_NICS}" ]; then
echo "KVCACHE_RDMA_NICS is empty, please check the output of get_rdma_nics.sh"
exit 1
fi
unset http_proxy && unset https_proxy
rm -rf log_*
# start router
export FD_LOG_DIR="log_router"
mkdir -p ${FD_LOG_DIR}
echo "start router"
router_port=9000
nohup python -m fastdeploy.router.launch \
--port ${router_port} \
--splitwise \
2>&1 >${FD_LOG_DIR}/nohup &
sleep 1
# start prefill
export CUDA_VISIBLE_DEVICES=0,1
export FD_LOG_DIR="log_prefill"
mkdir -p ${FD_LOG_DIR}
echo "start prefill"
nohup python -m fastdeploy.entrypoints.openai.api_server \
--model ${MODEL_NAME} \
--port 8100 \
--metrics-port 8101 \
--engine-worker-queue-port 8102 \
--cache-queue-port 8103 \
--tensor-parallel-size 2 \
--max-model-len 32768 \
--splitwise-role "prefill" \
--pd-comm-port 8104 \
--rdma-comm-ports 8105,8106 \
--router "0.0.0.0:${router_port}" \
2>&1 >${FD_LOG_DIR}/nohup &
wait_for_health 8100
# start decode
export CUDA_VISIBLE_DEVICES=2,3
export FD_LOG_DIR="log_decode"
mkdir -p ${FD_LOG_DIR}
echo "start decode"
nohup python -m fastdeploy.entrypoints.openai.api_server \
--model ${MODEL_NAME} \
--port 8200 \
--metrics-port 8201 \
--engine-worker-queue-port 8202 \
--cache-queue-port 8203 \
--max-model-len 32768 \
--tensor-parallel-size 2 \
--splitwise-role "decode" \
--pd-comm-port 8204 \
--rdma-comm-ports 8205,8206 \
--router "0.0.0.0:${router_port}" \
2>&1 >${FD_LOG_DIR}/nohup &
wait_for_health 8200

View File

@@ -0,0 +1,7 @@
pkill -9 -f python
pkill -9 -f fastdeploy
pkill -f -9 gunicorn
if redis-cli ping >/dev/null 2>&1; then
redis-cli shutdown
fi

View File

@@ -0,0 +1,20 @@
#!/bin/bash
# using v0 version, the request must be sent to the decode instance
# using v1 version, the request can be sent to the prefill or decode instance
# using v2 version, the request must be sent to the router
port=${1:-9000}
echo "port: ${port}"
unset http_proxy && unset https_proxy
curl -X POST "http://0.0.0.0:${port}/v1/chat/completions" \
-H "Content-Type: application/json" \
-d '{
"messages": [
{"role": "user", "content": "Introduce shenzhen"}
],
"max_tokens": 20,
"stream": true
}'

View File

@@ -344,6 +344,8 @@ class CacheMessager:
)
item["layer_idx"] = current_layer_idx
if item["layer_idx"] == self.num_layers:
if "error" not in item["status"]:
item["status"] = "finished"
if item["transfer_protocol"] == "ipc":
self.messager["ipc"].write_block_by_sync(target_id)
logger.info(f"finish write cache {item['request_id']}")
@@ -359,7 +361,7 @@ class CacheMessager:
def _handle_connect_task(self):
while True:
try:
task = self.engine_worker_queue.get_connect_rdma_task()
task, _ = self.engine_worker_queue.get_connect_rdma_task()
if task is None:
time.sleep(0.001)
continue
@@ -376,7 +378,8 @@ class CacheMessager:
self.engine_worker_queue.connect_task_response_barrier.wait()
self.engine_worker_queue.put_connect_rdma_task_response(response)
except Exception as e:
logger.error(f"handle_connect_task has exception: {e}")
time.sleep(0.001)
logger.error(f"handle_connect_task has exception: {e}, {str(traceback.format_exc())}")
class CacheMessagerV1:

View File

@@ -1310,6 +1310,24 @@ class CacheConfig:
logger.info("=============================================================")
class RouterConfig:
"""
Configuration for the router
Attributes:
router: the URL of the router, e.g. http://127.0.0.1:8000
api_server_host: the host IP of the model server
api_server_port: the HTTP port of the model server
"""
def __init__(self, args: dict):
self.router = args["router"]
if self.router is not None and not self.router.startswith(("http://", "https://")):
self.router = f"http://{self.router}"
self.api_server_host = get_host_ip()
self.api_server_port = args["port"]
class CommitConfig:
"""
Configuration for tracking version information from version.txt
@@ -1411,6 +1429,7 @@ class FDConfig:
speculative_config: SpeculativeConfig = None,
eplb_config: EPLBConfig = None,
structured_outputs_config: StructuredOutputsConfig = None,
router_config: RouterConfig = None,
tokenizer: str = None,
ips: str = None,
use_warmup: bool = False,
@@ -1438,6 +1457,7 @@ class FDConfig:
self.cache_config: CacheConfig = cache_config # type: ignore
self.plas_attention_config: Optional[PlasAttentionConfig] = plas_attention_config
self.structured_outputs_config: StructuredOutputsConfig = structured_outputs_config
self.router_config: RouterConfig = router_config
# Initialize cuda graph capture list
max_capture_shape = self.scheduler_config.max_num_seqs
@@ -1517,6 +1537,7 @@ class FDConfig:
self.read_from_config()
self.postprocess()
self.init_cache_info()
if test_mode:
return
self.check()
@@ -1734,29 +1755,66 @@ class FDConfig:
"""
initialize cache info
"""
disaggregate_info = {}
# TODO: group the splitwise params, remove code of v0
# v0 requires prefill and decode in one node and it uses local scheduler
# v1 supports prefill and decode in multi node and it uses splitwise or dp scheduler
# v2 supports prefill and decode in multi node and it uses router and local scheduler
self.splitwise_version = None
if self.scheduler_config.name == "local" and (self.router_config is None or self.router_config.router is None):
self.splitwise_version = "v0"
elif self.scheduler_config.name in ("splitwise", "dp"):
self.splitwise_version = "v1"
elif self.scheduler_config.name == "local" and self.router_config and self.router_config.router:
self.splitwise_version = "v2"
else:
raise ValueError(
f"Unsupported scheduler mode, scheduler_name: {self.scheduler_config.name}, "
f"router_config: {self.router_config}"
)
logger.info(f"splitwise_version: {self.splitwise_version}")
if isinstance(self.parallel_config.engine_worker_queue_port, (int, str)):
engine_worker_queue_port = self.parallel_config.engine_worker_queue_port
else:
engine_worker_queue_port = self.parallel_config.engine_worker_queue_port[
self.parallel_config.local_data_parallel_id
]
connector_port = self.cache_config.pd_comm_port[0] if self.cache_config.pd_comm_port else None
self.disaggregate_info = {}
if self.scheduler_config.splitwise_role != "mixed":
disaggregate_info["role"] = self.scheduler_config.splitwise_role
disaggregate_info["cache_info"] = dict()
self.disaggregate_info["role"] = self.scheduler_config.splitwise_role
self.disaggregate_info["cache_info"] = dict()
current_protocol = self.cache_config.cache_transfer_protocol.split(",")
disaggregate_info["transfer_protocol"] = current_protocol
self.disaggregate_info["transfer_protocol"] = current_protocol
for protocol in current_protocol:
if protocol == "ipc":
disaggregate_info["cache_info"][protocol] = {
self.disaggregate_info["cache_info"][protocol] = {
"ip": self.host_ip,
"port": self.parallel_config.engine_worker_queue_port[
self.parallel_config.local_data_parallel_id
],
"port": engine_worker_queue_port,
"device_ids": self.local_device_ids,
}
elif protocol == "rdma":
disaggregate_info["cache_info"][protocol] = {
self.disaggregate_info["cache_info"][protocol] = {
"ip": self.host_ip,
"port": self.cache_config.pd_comm_port[0],
"port": connector_port,
"rdma_port": self.cache_config.rdma_comm_ports,
}
self.disaggregate_info = disaggregate_info
logger.info(f"disaggregate_info: {self.disaggregate_info}")
logger.info(f"disaggregate_info: {self.disaggregate_info}")
if self.router_config:
self.register_info = {
"role": self.scheduler_config.splitwise_role,
"host_ip": self.host_ip,
"port": self.router_config.api_server_port,
"connector_port": connector_port,
"rdma_ports": self.cache_config.rdma_comm_ports,
"engine_worker_queue_port": engine_worker_queue_port,
"device_ids": self.local_device_ids,
"transfer_protocol": self.cache_config.cache_transfer_protocol.split(","),
}
logger.info(f"register_info: {self.register_info}")
def read_from_config(self):
"""

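The version selection in `init_cache_info` maps the scheduler name plus the presence of a router URL onto a deployment version. A minimal standalone sketch of that mapping (the function name and signature are hypothetical, only the branching mirrors the code above):

```python
from typing import Optional

# Illustrative sketch of the splitwise_version selection in FDConfig.init_cache_info().
def select_splitwise_version(scheduler_name: str, router_url: Optional[str]) -> str:
    if scheduler_name == "local" and router_url is None:
        return "v0"  # prefill and decode on one node, local scheduler
    if scheduler_name in ("splitwise", "dp"):
        return "v1"  # multi-node, splitwise or dp scheduler
    if scheduler_name == "local" and router_url:
        return "v2"  # multi-node, router + local scheduler
    raise ValueError(f"Unsupported scheduler mode: {scheduler_name}")

assert select_splitwise_version("local", None) == "v0"
assert select_splitwise_version("splitwise", None) == "v1"
assert select_splitwise_version("local", "http://0.0.0.0:9000") == "v2"
```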
View File

@@ -34,6 +34,7 @@ from fastdeploy.config import (
ParallelConfig,
PlasAttentionConfig,
PoolerConfig,
RouterConfig,
RunnerOption,
SpeculativeConfig,
StructuredOutputsConfig,
@@ -74,6 +75,10 @@ class EngineArgs:
"""
The name or path of the model to be used.
"""
port: Optional[str] = None
"""
Port for api server.
"""
served_model_name: Optional[str] = None
"""
The name of the model being served.
@@ -445,6 +450,11 @@ class EngineArgs:
- To enable custom logits processors, add your dotted paths to module and class names to the list.
"""
router: Optional[str] = None
"""
URL of the router server, e.g. `0.0.0.0:30000`.
"""
def __post_init__(self):
"""
Post-initialization processing to set default tokenizer if not provided.
@@ -859,21 +869,6 @@ class EngineArgs:
help="Flag to enable prefix caching.",
)
perf_group.add_argument(
"--splitwise-role",
type=str,
default=EngineArgs.splitwise_role,
help="Role of splitwise. Default is \
'mixed'. (prefill, decode, mixed)",
)
perf_group.add_argument(
"--innode-prefill-ports",
type=lambda s: s.split(",") if s else None,
default=EngineArgs.innode_prefill_ports,
help="port for innode prefill",
)
perf_group.add_argument(
"--enable-chunked-prefill",
action="store_true",
@@ -903,27 +898,53 @@ class EngineArgs:
help=("For chunked prefill, the threshold number of" " tokens for a prompt to be considered long."),
)
perf_group.add_argument(
# Splitwise deployment parameters group
splitwise_group = parser.add_argument_group("Splitwise Deployment")
splitwise_group.add_argument(
"--splitwise-role",
type=str,
default=EngineArgs.splitwise_role,
help="Role of splitwise. Default is \
'mixed'. (prefill, decode, mixed)",
)
splitwise_group.add_argument(
"--innode-prefill-ports",
type=lambda s: s.split(",") if s else None,
default=EngineArgs.innode_prefill_ports,
help="port for innode prefill, only used in single machine splitwise deployment",
)
splitwise_group.add_argument(
"--cache-transfer-protocol",
type=str,
default=EngineArgs.cache_transfer_protocol,
help="support protocol list, comma separated, default is ipc",
help="support protocol list (ipc or rdma), comma separated, default is ipc",
)
perf_group.add_argument(
splitwise_group.add_argument(
"--pd-comm-port",
type=lambda s: s.split(",") if s else None,
default=EngineArgs.pd_comm_port,
help="port for splitwise communication.",
)
perf_group.add_argument(
splitwise_group.add_argument(
"--rdma-comm-ports",
type=lambda s: s.split(",") if s else None,
default=EngineArgs.rdma_comm_ports,
help="ports for rdma communication.",
)
# Router parameters group
router_group = parser.add_argument_group("Router")
router_group.add_argument(
"--router",
type=str,
default=EngineArgs.router,
help="url for router server.",
)
# Scheduler parameters group
scheduler_group = parser.add_argument_group("Scheduler")
scheduler_group.add_argument(
@@ -1044,7 +1065,11 @@ class EngineArgs:
"""
Create an instance of EngineArgs from command line arguments.
"""
return cls(**{field.name: getattr(args, field.name) for field in dataclass_fields(cls)})
args_dict = {}
for field in dataclass_fields(cls):
if hasattr(args, field.name):
args_dict[field.name] = getattr(args, field.name)
return cls(**args_dict)
def create_speculative_config(self) -> SpeculativeConfig:
""" """
@@ -1063,6 +1088,7 @@ class EngineArgs:
prefix_len = len(prefix)
all = asdict(self)
all.pop("port") # port and scheduler_port are not the same
params = dict()
for k, v in all.items():
if k[:prefix_len] == prefix:
@@ -1151,6 +1177,7 @@ class EngineArgs:
scheduler_cfg = self.create_scheduler_config()
graph_opt_cfg = self.create_graph_optimization_config()
plas_attention_config = self.create_plas_attention_config()
router_config = RouterConfig(all_dict)
early_stop_cfg = self.create_early_stop_config()
early_stop_cfg.update_enable_early_stop(self.enable_early_stop)
@@ -1170,6 +1197,7 @@ class EngineArgs:
speculative_config=speculative_cfg,
eplb_config=eplb_cfg,
structured_outputs_config=structured_outputs_config,
router_config=router_config,
ips=self.ips,
use_warmup=self.use_warmup,
limit_mm_per_prompt=self.limit_mm_per_prompt,

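For orientation, a small sketch of how the new `--router` and `--port` values feed into `RouterConfig`; the dict below stands in for the args dict built in `create_engine_config`, and the values are placeholders:

```python
from fastdeploy.config import RouterConfig

# Placeholder args dict; in create_engine_config it is derived from the EngineArgs fields.
all_dict = {"router": "0.0.0.0:9000", "port": 8100}
router_config = RouterConfig(all_dict)

print(router_config.router)           # "http://0.0.0.0:9000" (scheme is prepended when missing)
print(router_config.api_server_port)  # 8100
# router_config.api_server_host is resolved internally via get_host_ip()
```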
View File

@@ -23,10 +23,11 @@ import time
import traceback
import weakref
from concurrent.futures import ThreadPoolExecutor
from typing import Dict, List, Optional, Tuple
from typing import Dict, List, Optional, Tuple, Union
import numpy as np
import paddle
import requests
import zmq
from opentelemetry import trace
@@ -45,6 +46,7 @@ from fastdeploy.metrics.metrics import main_process_metrics
from fastdeploy.metrics.trace_util import start_span, start_span_request
from fastdeploy.model_executor.guided_decoding import schema_checker
from fastdeploy.plugins.token_processor import load_token_processor_plugins
from fastdeploy.router.utils import check_service_health
from fastdeploy.splitwise.internal_adapter_utils import InternalAdapter
from fastdeploy.splitwise.splitwise_connector import SplitwiseConnector
from fastdeploy.utils import (
@@ -95,6 +97,7 @@ class EngineService:
self.enable_decode_cache_task = envs.FD_ENABLE_CACHE_TASK == "1"
if envs.ENABLE_V1_KVCACHE_SCHEDULER:
self.llm_logger.info("Use V1 KVCache Scheduler")
self.resource_manager = ResourceManagerV1(
cfg.scheduler_config.max_num_seqs,
cfg,
@@ -103,6 +106,7 @@ class EngineService:
cfg.parallel_config.local_data_parallel_id,
)
else:
self.llm_logger.info("Use V0 KVCache Scheduler")
self.resource_manager = ResourceManager(
cfg.scheduler_config.max_num_seqs,
cfg,
@@ -118,7 +122,6 @@ class EngineService:
]
self.split_connector = SplitwiseConnector(cfg, self.engine_worker_queue, self.resource_manager)
self.waiting_requests = []
self.token_processor = TokenProcessor(
cfg=cfg,
cached_generated_tokens=self.scheduler,
@@ -149,14 +152,18 @@ class EngineService:
def start(self):
self.running = True
if envs.ENABLE_V1_KVCACHE_SCHEDULER:
self.insert_task_to_worker_thread = threading.Thread(target=self._scheduler_task_to_worker_v1, daemon=True)
self.insert_task_to_worker_thread = threading.Thread(
target=self._schedule_request_to_worker_v1, daemon=True
)
else:
self.insert_task_to_worker_thread = threading.Thread(target=self._insert_task_to_worker, daemon=True)
self.insert_task_to_worker_thread = threading.Thread(target=self._schedule_request_to_worker, daemon=True)
self.insert_task_to_worker_thread.start()
self.token_processor.tasks_queue = self.engine_worker_queue
self.token_processor.run()
if self.cfg.scheduler_config.splitwise_role != "mixed":
self.split_mode_get_tasks()
self._process_splitwise_task()
self._register_to_router()
def create_data_processor(self):
self.input_processor = InputPreprocessor(
@@ -313,7 +320,7 @@ class EngineService:
local_data_parallel_id=self.cfg.parallel_config.local_data_parallel_id,
)
def insert_tasks(self, tasks, current_id=-1, allocated=False):
def insert_tasks(self, tasks: Union[List[Request], List[RequestOutput]], current_id=-1, allocated=False):
"""
Insert tasks to engine.
"""
@@ -358,6 +365,7 @@ class EngineService:
current_tasks.append(cur_task)
if current_tasks:
self.engine_worker_queue.put_tasks((current_tasks, self.resource_manager.real_bsz))
self.llm_logger.debug(f"put task to engine worker queue, task:{current_tasks}")
return True
self.resource_manager.check_and_free_block_tables()
@@ -574,7 +582,7 @@ class EngineService:
patch_st += chunk_patch_num
request.set("prefill_chunk_info", chunks_info)
def _insert_task_to_worker(self):
def _schedule_request_to_worker(self):
"""
Insert task to engine thread, monitor scheduler request queue.
if the engine has resource, insert task to engine
@@ -619,9 +627,12 @@ class EngineService:
if len(tasks) == 0:
time.sleep(0.001)
continue
if self.cfg.splitwise_version == "v2" and self.cfg.scheduler_config.splitwise_role == "decode":
# the task in the decode instance will be processed in the _process_splitwise_task thread
continue
llm_logger.debug(f"get tasks from scheduler: {tasks}")
if self.cfg.scheduler_config.splitwise_role != "mixed":
self.llm_logger.info("Inserting splitwise tasks")
self.split_connector.send_splitwise_tasks(tasks, current_id)
insert_successful = self.insert_tasks(tasks, current_id)
@@ -636,7 +647,7 @@ class EngineService:
err_msg = f"Error happend while insert task to engine: {e}, {traceback.format_exc()!s}."
self.llm_logger.error(err_msg)
def _scheduler_task_to_worker_v1(self):
def _schedule_request_to_worker_v1(self):
"""
Insert tasks to worker with scheduler v1 (ENABLE_V1_KVCACHE_SCHEDULER=1).
"""
@@ -664,6 +675,7 @@ class EngineService:
max_num_batched_tokens=max_num_batched_tokens,
batch=num_prefill_batch,
)
self.llm_logger.debug(f"get tasks from scheduler: {tasks}")
if self.cfg.scheduler_config.splitwise_role != "mixed":
need_delete_tasks = []
if envs.FD_OFFLINE_PERF_TEST_FOR_PD:
@@ -822,6 +834,7 @@ class EngineService:
if envs.FD_ENABLE_INTERNAL_ADAPTER:
if self.cfg.scheduler_config.splitwise_role == "decode":
return
while self.running:
try:
block = True if len(added_requests) == 0 else False
@@ -975,17 +988,38 @@ class EngineService:
except Exception as e:
llm_logger.error(f"Unexcepted error happend: {e}, {traceback.format_exc()!s}")
def split_mode_get_tasks(self):
def _process_splitwise_task(self):
"""
Split mode get tasks
Process tasks from the engine worker queue in a splitwise deployment.
For the v0 version, the prefill instance gets tasks from the engine worker queue.
For the v1 and v2 versions, the decode instance gets raw tasks from the engine worker queue to preallocate
resources, and gets prefilled tasks from the engine worker queue to generate tokens.
TODO: unify the communication between the decode and prefill instances.
"""
def receiver_loop():
waiting_resource_requests = []
waiting_ready_tasks = []
# Waiting for the api_server and scheduler in decode to
# receive the request sent by the client
def _decode_process_prefilled_task_v0_scheduler(input_tasks):
ready_tasks = []
waiting_tasks = []
for task in input_tasks:
if not hasattr(self.scheduler, "has_request") or self.scheduler.has_request(task.request_id):
ready_tasks.append(task)
else:
waiting_tasks.append(task)
self.insert_tasks(ready_tasks, allocated=True)
if self.cfg.splitwise_version in ("v0", "v2"):
self.scheduler.put_results(ready_tasks)
return waiting_tasks
while self.running:
try:
processed_indices = []
for idx, task in enumerate(self.waiting_requests):
for idx, task in enumerate(waiting_resource_requests):
if envs.ENABLE_V1_KVCACHE_SCHEDULER:
if self.resource_manager.preallocate_resource_in_d(task):
self.llm_logger.info(f"Resource available, processing task {task.request_id}")
@@ -1004,21 +1038,27 @@ class EngineService:
break
for idx in sorted(processed_indices, reverse=True):
self.waiting_requests.pop(idx)
waiting_resource_requests.pop(idx)
if not self.engine_worker_queue.disaggregate_queue_empty():
waiting_ready_tasks = _decode_process_prefilled_task_v0_scheduler(waiting_ready_tasks)
if self.engine_worker_queue.disaggregate_queue_empty():
time.sleep(0.001)
else:
items = self.engine_worker_queue.get_disaggregated_tasks()
for item in items:
role = item[0]
tasks = item[1]
# prefill instance gets tasks from engine worker queue
if role == "prefill":
for task in tasks:
task.max_tokens = task.min_tokens = 2
self.insert_tasks(tasks)
# decode instance gets tasks from engine worker queue
elif role == "decode":
if hasattr(tasks[0], "finished"):
if isinstance(tasks[0], RequestOutput):
self.llm_logger.debug(f"receive prefilled tasks, {tasks}")
if not isinstance(tasks, list):
tasks = [tasks]
for task in tasks:
@@ -1057,13 +1097,12 @@ class EngineService:
self.resource_manager.insert_task_for_decoding(task)
else:
self.insert_tasks(tasks, allocated=True)
if self.cfg.innode_prefill_ports is not None:
self.scheduler.put_results(tasks)
else:
if len(self.waiting_requests):
waiting_ready_tasks.extend(_decode_process_prefilled_task_v0_scheduler(tasks))
elif isinstance(tasks[0], Request):
self.llm_logger.debug(f"receive tasks to preallocate resource, {tasks}")
if len(waiting_resource_requests):
self.llm_logger.info(f"Waiting for resource for task {tasks[0].request_id}")
self.waiting_requests.extend(tasks)
waiting_resource_requests.extend(tasks)
else:
new_waiting = []
for task in tasks:
@@ -1087,13 +1126,12 @@ class EngineService:
if not self.enable_decode_cache_task:
self.split_connector.send_cache_infos(new_waiting, -1)
else:
self.waiting_requests.extend(new_waiting)
waiting_resource_requests.extend(new_waiting)
self.llm_logger.info(
f"Added {len(new_waiting)} tasks to waiting queue"
)
else:
time.sleep(0.001)
else:
raise ValueError(f"Unsupported task type: {type(tasks[0])}")
except Exception as e:
self.llm_logger.error(f"Error in main loop: {e}")
@@ -1130,6 +1168,42 @@ class EngineService:
llm_logger.error(f"Clear data error: {e}")
return False
def _register_to_router(self):
"""If use router, register this server to router"""
timeout = 5
sleep_seconds = 10
def _register():
while True:
try:
time.sleep(sleep_seconds)
api_server_host = self.cfg.router_config.api_server_host
api_server_port = self.cfg.router_config.api_server_port
api_server_url = f"http://{api_server_host}:{api_server_port}"
if not check_service_health(api_server_url):
continue
router_url = self.cfg.router_config.router
resp = requests.post(
f"{router_url}/register",
json=self.cfg.register_info,
timeout=timeout,
)
if not resp.ok:
llm_logger.error(
f"Router registration failed: {resp.status_code}, "
f"{resp.text}, {self.cfg.register_info}"
)
except requests.exceptions.RequestException as e:
llm_logger.error(f"Register to router request error: {e}")
except Exception as e:
llm_logger.exception(f"Unexpected error during router registration: {e}")
if self.cfg.router_config.router is not None:
register_thread = threading.Thread(target=_register, daemon=True)
register_thread.start()
def _exit_sub_services(self):
"""
exit sub services

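A minimal sketch of the registration request that `_register_to_router` posts to the router; the host, ports, and device IDs are illustrative placeholders, and the real payload is `cfg.register_info` as built in `init_cache_info`:

```python
import requests

# Placeholder payload mirroring the register_info dict assembled in FDConfig.init_cache_info().
register_info = {
    "role": "prefill",                    # splitwise role: prefill / decode / mixed
    "host_ip": "10.0.0.5",                # assumed example IP
    "port": 8100,                         # api server HTTP port
    "connector_port": 8105,               # --pd-comm-port (None when unset)
    "rdma_ports": ["8104"],               # --rdma-comm-ports
    "engine_worker_queue_port": 8102,
    "device_ids": ["0"],
    "transfer_protocol": ["rdma", "ipc"],
}

resp = requests.post("http://0.0.0.0:9000/register", json=register_info, timeout=5)
if not resp.ok:
    print(f"Router registration failed: {resp.status_code}, {resp.text}")
```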
View File

@@ -694,8 +694,6 @@ class LLMEngine:
self.splitwise_receive_thread.daemon = True
self.splitwise_receive_thread.start()
self.cfg.init_cache_info()
role = self.cfg.scheduler_config.splitwise_role
host_ip = self.cfg.host_ip
disaggregate = self.cfg.disaggregate_info

View File

@@ -527,6 +527,8 @@ class RequestOutput:
f"num_input_image_tokens={self.num_input_image_tokens}, "
f"num_input_video_tokens={self.num_input_video_tokens}, "
f"metrics={self.metrics}, "
f"error_code={self.error_code}, "
f"error_msg={self.error_msg},"
)
@classmethod

View File

@@ -451,6 +451,8 @@ class CompletionRequest(BaseModel):
temperature: Optional[float] = Field(default=None, ge=0)
top_p: Optional[float] = Field(default=None, ge=0, le=1)
user: Optional[str] = None
request_id: Optional[str] = None
disaggregate_info: Optional[dict] = None
# doc: begin-completion-sampling-params
top_k: Optional[int] = None
@@ -486,8 +488,6 @@ class CompletionRequest(BaseModel):
dict: request parameters in dict format
"""
req_dict = {}
if request_id is not None:
req_dict["request_id"] = request_id
# parse request model into dict
if self.suffix is not None:
@@ -497,6 +497,8 @@ class CompletionRequest(BaseModel):
if value is not None:
req_dict[key] = value
if request_id is not None:
req_dict["request_id"] = request_id
if prompt is not None:
req_dict["prompt"] = prompt
@@ -604,6 +606,8 @@ class ChatCompletionRequest(BaseModel):
user: Optional[str] = None
metadata: Optional[dict] = None
response_format: Optional[AnyResponseFormat] = None
request_id: Optional[str] = None
disaggregate_info: Optional[dict] = None
# doc: begin-chat-completion-sampling-params
top_k: Optional[int] = None
@@ -644,8 +648,6 @@ class ChatCompletionRequest(BaseModel):
dict: request parameters in dict format
"""
req_dict = {}
if request_id is not None:
req_dict["request_id"] = request_id
req_dict["max_tokens"] = self.max_completion_tokens or self.max_tokens
req_dict["logprobs"] = self.top_logprobs if self.logprobs else None
@@ -666,6 +668,9 @@ class ChatCompletionRequest(BaseModel):
if value is not None:
req_dict[key] = value
if request_id is not None:
req_dict["request_id"] = request_id
if "prompt_token_ids" in req_dict:
if "messages" in req_dict:
del req_dict["messages"]

View File

@@ -114,7 +114,11 @@ class OpenAIServingChat:
await asyncio.wait_for(self.engine_client.semaphore.acquire(), timeout=self.max_waiting_time)
api_server_logger.info(f"current {self.engine_client.semaphore.status()}")
if request.user is not None:
if request.request_id is not None:
request_id = request.request_id
if not request_id.startswith("chatcmpl-"):
request_id = f"chatcmpl-{request_id}"
elif request.user is not None:
request_id = f"chatcmpl-{request.user}-{uuid.uuid4()}"
else:
request_id = f"chatcmpl-{uuid.uuid4()}"

View File

@@ -85,7 +85,11 @@ class OpenAIServingCompletion:
error=ErrorInfo(message=err_msg, type=ErrorType.INTERNAL_ERROR, code=ErrorCode.MODEL_NOT_SUPPORT)
)
created_time = int(time.time())
if request.user is not None:
if request.request_id is not None:
request_id = request.request_id
if not request_id.startswith("cmpl-"):
request_id = f"cmpl-{request_id}"
elif request.user is not None:
request_id = f"cmpl-{request.user}-{uuid.uuid4()}"
else:
request_id = f"cmpl-{uuid.uuid4()}"

View File

@@ -662,7 +662,10 @@ class EngineWorkerQueue:
self.client_read_info_flag[:] = [0] * self.num_client
self.cache_infos.extend(cache_info)
llm_logger.debug(f"cache_infos: {self.cache_infos} local_data_parallel_id:{self.local_data_parallel_id}")
llm_logger.debug(
f"put cache_infos to engine worker queue: {self.cache_infos}, "
f"local_data_parallel_id:{self.local_data_parallel_id}"
)
self.lock_info.release()
def get_cache_info(self) -> List[Any]:
@@ -684,7 +687,10 @@ class EngineWorkerQueue:
self.cache_infos[:] = list()
self.lock_info.release()
if len(cache_infos) != 0:
llm_logger.debug(f"get cache infos: {cache_infos} local_data_parallel_id:{self.local_data_parallel_id}")
llm_logger.debug(
f"get cache infos from engine worker queue: {cache_infos}, "
f"local_data_parallel_id:{self.local_data_parallel_id}"
)
return cache_infos
def num_cache_infos(self) -> int:

View File

@@ -456,6 +456,7 @@ class TokenProcessor:
recycle resources
"""
if is_prefill:
start_time = time.time()
while True:
finished_task_ids = self.engine_worker_queue.get_finished_req()
if len(finished_task_ids) > 0:
@@ -474,6 +475,9 @@ class TokenProcessor:
if self.prefill_result_status[task_id] != "finished":
result.error_code = 400
result.error_message = f"{task_id} failed to {self.prefill_result_status[task_id]}"
llm_logger.info(
f"wait for sending cache, request_id: {task_id}, cost seconds: {time.time()-start_time:.5f}"
)
self.split_connector.send_first_token(task.disaggregate_info, [result])
break
else:
@@ -731,11 +735,10 @@ class TokenProcessor:
self._record_completion_metrics(task, current_time)
self._recycle_resources(task_id, i, task, result, is_prefill)
break
if (
not is_prefill
or self.cfg.scheduler_config.name == "splitwise"
or self.cfg.scheduler_config.name == "dp"
):
if not (is_prefill and self.cfg.splitwise_version == "v0"):
# NOTE: the prefill instance in the v0 version does not return results to the scheduler
llm_logger.debug(f"get response from infer: {result}")
batch_result.append(result)
self.postprocess(batch_result, mtype)

View File

@@ -0,0 +1,15 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""

View File

@@ -0,0 +1,58 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
import argparse
from fastdeploy.router.router import start_router
from fastdeploy.utils import router_logger as logger
def main() -> None:
parser = argparse.ArgumentParser(description="Router for splitwise deployment testing")
parser.add_argument(
"--host",
type=str,
default="0.0.0.0",
help="Host address to bind the router server.",
)
parser.add_argument(
"--port",
type=int,
default="9000",
help="Port number to bind the router server",
)
parser.add_argument(
"--splitwise",
action="store_true",
help="Router uses splitwise deployment",
)
parser.add_argument(
"--request-timeout-secs",
type=int,
default=1800,
help="Request timeout in seconds",
)
args = parser.parse_args()
try:
start_router(args)
except Exception as e:
logger.error(f"Error starting router: {e}")
raise e
if __name__ == "__main__":
main()

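The same router can also be started programmatically; a minimal sketch using an `argparse.Namespace` as a stand-in for the parsed CLI arguments (values are placeholders):

```python
import argparse

from fastdeploy.router.router import start_router

# Stand-in for the CLI arguments parsed above.
args = argparse.Namespace(
    host="0.0.0.0",
    port=9000,
    splitwise=True,            # route each request to a prefill/decode pair
    request_timeout_secs=1800,
)
start_router(args)  # blocking; serves /register, /v1/chat/completions, /health, ...
```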
fastdeploy/router/router.py (new file, 317 lines)

View File

@@ -0,0 +1,317 @@
"""
Async Router server for FastDeploy.
Handles client requests and manages prefill/decode/mixed instances.
This module references the router implementations of sglang and vLLM.
"""
import asyncio
import random
from itertools import chain
from uuid import uuid4
import aiohttp
import uvicorn
from fastapi import FastAPI, HTTPException
from fastapi.responses import ORJSONResponse, Response, StreamingResponse
from fastdeploy.router.utils import (
InstanceInfo,
InstanceRole,
check_service_health_async,
)
from fastdeploy.utils import router_logger as logger
app = FastAPI()
class Router:
"""
Router class that handles requests from client and
collects prefill/decode instance information
"""
def __init__(self, args):
self.args = args
self.host = args.host
self.port = args.port
self.splitwise = args.splitwise
self.timeout = args.request_timeout_secs
self.mixed_servers = []
self.prefill_servers = []
self.decode_servers = []
self.lock = asyncio.Lock() # async-safe lock
async def register_instance(self, instance_info_dict: dict):
"""Register an instance asynchronously"""
try:
inst_info = InstanceInfo(**instance_info_dict)
except Exception as e:
logger.error(f"register instance failed: {e}")
raise
if (self.splitwise and inst_info.role == InstanceRole.MIXED) or (
not self.splitwise and inst_info.role != InstanceRole.MIXED
):
raise ValueError(f"Invalid instance role: {inst_info.role}, splitwise: {self.splitwise}")
if not await check_service_health_async(inst_info.url()):
raise RuntimeError(f"Instance {inst_info} is not healthy")
async with self.lock:
if inst_info.role == InstanceRole.MIXED and inst_info not in self.mixed_servers:
self.mixed_servers.append(inst_info)
logger.info(
f"Register mixed instance success: {inst_info}, " f"total mixed: {len(self.mixed_servers)}"
)
elif inst_info.role == InstanceRole.PREFILL and inst_info not in self.prefill_servers:
self.prefill_servers.append(inst_info)
logger.info(
f"Register prefill instance success: {inst_info}, "
f"prefill: {len(self.prefill_servers)}, decode: {len(self.decode_servers)}"
)
elif inst_info.role == InstanceRole.DECODE and inst_info not in self.decode_servers:
self.decode_servers.append(inst_info)
logger.info(
f"Register decode instance success: {inst_info}, "
f"prefill: {len(self.prefill_servers)}, decode: {len(self.decode_servers)}"
)
async def registered_number(self):
"""Get number of registered instances"""
return {
"mixed": len(self.mixed_servers),
"prefill": len(self.prefill_servers),
"decode": len(self.decode_servers),
}
async def select_pd(self):
"""Select one prefill and one decode server"""
async with self.lock:
if not self.prefill_servers:
raise RuntimeError("No prefill servers available")
if not self.decode_servers:
raise RuntimeError("No decode servers available")
pidx = random.randint(0, len(self.prefill_servers) - 1)
didx = random.randint(0, len(self.decode_servers) - 1)
return self.prefill_servers[pidx], self.decode_servers[didx]
async def select_mixed(self):
"""Select one mixed server"""
async with self.lock:
if not self.mixed_servers:
raise RuntimeError("No mixed servers available")
idx = random.randint(0, len(self.mixed_servers) - 1)
return self.mixed_servers[idx]
async def handle_request(self, request_data: dict, endpoint_name: str):
if self.splitwise:
return await self.handle_splitwise_request(request_data, endpoint_name)
else:
return await self.handle_mixed_request(request_data, endpoint_name)
async def handle_mixed_request(self, request_data: dict, endpoint_name: str):
logger.debug(f"Received request: {request_data}")
mixed_server = await self.select_mixed()
if request_data.get("stream", False):
return await self._generate_stream(request_data, [mixed_server.url()], endpoint=endpoint_name)
else:
return await self._generate(request_data, [mixed_server.url()], endpoint=endpoint_name)
async def handle_splitwise_request(self, request_data: dict, endpoint_name: str):
logger.debug(f"Received request: {request_data}")
prefill_server, decode_server = await self.select_pd()
# TODO: unify the disaggregate_info in server and remove redundancy params
is_same_node = prefill_server.host_ip == decode_server.host_ip
use_ipc = (
is_same_node and "ipc" in prefill_server.transfer_protocol and "ipc" in decode_server.transfer_protocol
)
cache_info = {}
if use_ipc:
cache_info["ipc"] = {
"ip": decode_server.host_ip,
"port": decode_server.engine_worker_queue_port,
"device_ids": decode_server.device_ids,
}
else:
cache_info["rdma"] = {
"ip": decode_server.host_ip,
"port": decode_server.connector_port,
"rdma_port": decode_server.rdma_ports,
}
disaggregate_info = {
"prefill": prefill_server.to_dict(),
"decode": decode_server.to_dict(),
"role": "decode",
"cache_info": cache_info,
"transfer_protocol": "ipc" if use_ipc else "rdma",
}
modified_request = request_data.copy()
modified_request["disaggregate_info"] = disaggregate_info
if "request_id" not in modified_request:
modified_request["request_id"] = str(uuid4())
logger.debug(f"Modified request: {modified_request}")
if request_data.get("stream", False):
return await self._generate_stream(
modified_request, [prefill_server.url(), decode_server.url()], endpoint=endpoint_name
)
else:
return await self._generate(
modified_request, [prefill_server.url(), decode_server.url()], endpoint=endpoint_name
)
async def _generate(
self, modified_request, urls, return_result_url_index=-1, endpoint="v1/chat/completions"
) -> ORJSONResponse:
async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=self.timeout)) as session:
tasks = [session.post(f"{url}/{endpoint}", json=modified_request) for url in urls]
results = await asyncio.gather(*tasks)
ret_json = await results[return_result_url_index].json()
return ORJSONResponse(content=ret_json, status_code=results[return_result_url_index].status)
async def _generate_stream(
self, modified_request, urls, return_result_url_index=-1, endpoint="v1/chat/completions"
):
async def stream_results():
async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=self.timeout)) as session:
tasks = [session.post(f"{url}/{endpoint}", json=modified_request) for url in urls]
results = await asyncio.gather(*tasks)
AIOHTTP_STREAM_READ_CHUNK_SIZE = 1024 * 64 # prevent aiohttp's "Chunk too big" error
async for chunk in results[return_result_url_index].content.iter_chunked(
AIOHTTP_STREAM_READ_CHUNK_SIZE
):
logger.debug(f"receive response chunk: {chunk}")
yield chunk
return StreamingResponse(stream_results(), media_type="text/event-stream")
async def monitor_instance_health(self, interval_secs: float = 5.0):
"""
Continuously check the health of prefill, decode, and mixed instances and remove unhealthy ones.
"""
while True:
try:
prefill_to_remove = []
decode_to_remove = []
mixed_to_remove = []
async with aiohttp.ClientSession() as session:
# check servers
prefill_tasks = [(inst, session.get(f"{inst.url()}/health")) for inst in self.prefill_servers]
decode_tasks = [(inst, session.get(f"{inst.url()}/health")) for inst in self.decode_servers]
mixed_tasks = [(inst, session.get(f"{inst.url()}/health")) for inst in self.mixed_servers]
# gather all tasks concurrently
all_tasks = prefill_tasks + decode_tasks + mixed_tasks
for inst, coro in all_tasks:
try:
resp = await coro
if resp.status != 200:
logger.warning(f"Instance {inst.url()} unhealthy: {resp.status}")
if inst in self.prefill_servers:
prefill_to_remove.append(inst)
elif inst in self.decode_servers:
decode_to_remove.append(inst)
elif inst in self.mixed_servers:
mixed_to_remove.append(inst)
except Exception as e:
logger.warning(f"Instance {inst.url()} check failed: {e}")
if inst in self.prefill_servers:
prefill_to_remove.append(inst)
elif inst in self.decode_servers:
decode_to_remove.append(inst)
elif inst in self.mixed_servers:
mixed_to_remove.append(inst)
# remove unhealthy instances under lock
async with self.lock:
if prefill_to_remove:
for inst in prefill_to_remove:
self.prefill_servers.remove(inst)
logger.info(f"Removed unhealthy prefill instance: {inst.url()}")
if decode_to_remove:
for inst in decode_to_remove:
self.decode_servers.remove(inst)
logger.info(f"Removed unhealthy decode instance: {inst.url()}")
if mixed_to_remove:
for inst in mixed_to_remove:
self.mixed_servers.remove(inst)
logger.info(f"Removed unhealthy mixed instance: {inst.url()}")
await asyncio.sleep(interval_secs)
prefill_instances = [inst.url() for inst in self.prefill_servers]
decode_instances = [inst.url() for inst in self.decode_servers]
mixed_instance = [inst.url() for inst in self.mixed_servers]
logger.debug(
f"Healthy prefill instances: {prefill_instances}, "
f"Healthy decode instances: {decode_instances}, "
f"Healthy mixed instance: {mixed_instance}"
)
except Exception as e:
logger.exception(f"Failed to monitor instance health: {e}")
@app.post("/register")
async def register(instance_info_dict: dict):
"""Register prefill/decode/mixed servers"""
try:
await app.state.router.register_instance(instance_info_dict)
except Exception as e:
raise HTTPException(status_code=400, detail=str(e))
return {"status": "success"}
@app.get("/registered_number")
async def registered_number():
"""Get the number of registered prefill/decode/mixed servers"""
return await app.state.router.registered_number()
@app.post("/v1/chat/completions")
async def create_chat_completion(request_data: dict):
return await app.state.router.handle_request(request_data, "v1/chat/completions")
@app.post("/v1/completions")
async def create_completion(request_data: dict):
return await app.state.router.handle_request(request_data, "v1/completions")
@app.get("/health")
async def health_check():
"""Basic health check"""
return Response(status_code=200)
@app.get("/health_generate")
async def health_generate():
"""Check all prefill and decode servers are healthy"""
router = app.state.router
async with aiohttp.ClientSession() as session:
tasks = [session.get(f"{s.url()}/health") for s in chain(router.prefill_servers, router.decode_servers)]
for coro in asyncio.as_completed(tasks):
resp = await coro
if resp.status != 200:
logger.warning(f"Server {resp.url} not healthy: {resp.status}")
return Response(status_code=200)
def start_router(router_args):
app.state.router_args = router_args
@app.on_event("startup")
async def startup_event():
app.state.router = Router(app.state.router_args)
asyncio.create_task(app.state.router.monitor_instance_health(interval_secs=5))
uvicorn.run(app, host=router_args.host, port=router_args.port)

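A minimal client-side sketch for the router HTTP API above; the router address mirrors the test scripts earlier in this change, and the request body is illustrative:

```python
import requests

router_url = "http://0.0.0.0:9000"

# check how many instances have registered so far
print(requests.get(f"{router_url}/registered_number", timeout=5).json())
# e.g. {"mixed": 0, "prefill": 1, "decode": 1}

# send an OpenAI-style chat request; the router picks the prefill/decode
# (or mixed) instances and returns the final response
payload = {
    "messages": [{"role": "user", "content": "Introduce Shenzhen"}],
    "max_tokens": 20,
    "stream": False,
}
resp = requests.post(f"{router_url}/v1/chat/completions", json=payload, timeout=1800)
print(resp.json())
```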
fastdeploy/router/utils.py (new file, 131 lines)

View File

@@ -0,0 +1,131 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
import asyncio
from dataclasses import asdict, dataclass, field
from enum import Enum
from typing import List, Union
import aiohttp
import requests
class InstanceRole(Enum):
MIXED = 0
PREFILL = 1
DECODE = 2
@dataclass
class InstanceInfo:
role: Union[InstanceRole, str]
host_ip: str
port: Union[int, str]
connector_port: Union[int, str] = 0
engine_worker_queue_port: Union[int, str] = 0
transfer_protocol: List[str] = field(default_factory=list)
rdma_ports: Union[List[str], List[int]] = field(default_factory=list)
device_ids: Union[List[str], List[int]] = field(default_factory=list)
def __post_init__(self):
"""check and unify fields"""
if isinstance(self.role, str):
try:
self.role = InstanceRole[self.role.upper()]
except KeyError:
raise ValueError(f"Invalid role string: {self.role}")
elif not isinstance(self.role, InstanceRole):
raise TypeError(f"role must be InstanceRole or str, got {type(self.role)}")
for t in self.transfer_protocol:
assert t in ["ipc", "rdma"], f"Invalid transfer_protocol: {self.transfer_protocol}"
self.port = str(self.port)
self.connector_port = str(self.connector_port)
self.engine_worker_queue_port = str(self.engine_worker_queue_port)
if self.rdma_ports:
self.rdma_ports = [str(p) for p in self.rdma_ports]
if self.device_ids:
self.device_ids = [str(i) for i in self.device_ids]
def to_dict(self):
return {k: (v.name if isinstance(v, Enum) else v) for k, v in asdict(self).items()}
def url(self) -> str:
url = f"{self.host_ip}:{self.port}"
if not url.startswith(("http://", "https://")):
url = f"http://{url}"
return url
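# A minimal usage sketch for InstanceInfo, illustrative only: the address, port and
# protocol below are placeholders and this block is not needed by the router itself.
if __name__ == "__main__":
    demo = InstanceInfo(role="prefill", host_ip="127.0.0.1", port=8188, transfer_protocol=["ipc"])
    assert demo.role is InstanceRole.PREFILL  # __post_init__ normalizes the role string
    print(demo.url())  # -> http://127.0.0.1:8188
    print(demo.to_dict())  # role serialized back to its name, ports coerced to str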
def check_service_health(base_url: str, timeout: int = 3) -> bool:
"""
Check the health status of a service.
Args:
base_url (str): The base URL of the service, e.g. "http://127.0.0.1:8080"
timeout (int): Request timeout in seconds.
Returns:
bool: True if the service is healthy, False otherwise.
"""
if not base_url.startswith(("http://", "https://")):
base_url = f"http://{base_url}"
url = f"{base_url.rstrip('/')}/health"
try:
resp = requests.get(url, timeout=timeout)
return resp.status_code == 200
except Exception:
return False
async def check_service_health_async(base_url: str, timeout: int = 3) -> bool:
"""
Asynchronously check the health status of a service.
Args:
base_url (str): The base URL of the service, e.g. "http://127.0.0.1:8080"
timeout (int): Request timeout in seconds.
Returns:
bool: True if the service is healthy, False otherwise.
"""
if not base_url.startswith(("http://", "https://")):
base_url = f"http://{base_url}"
url = f"{base_url.rstrip('/')}/health"
try:
async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=timeout)) as session:
async with session.get(url) as resp:
status = resp.status
text = await resp.text()
if status == 200:
print(f"[OK] Service is healthy ({status})")
return True
else:
print(f"[WARN] Service not healthy ({status}): {text}")
return False
except aiohttp.ClientError as e:
print(f"[ERROR] Failed to connect to {url}: {e}")
return False
except asyncio.TimeoutError:
print(f"[ERROR] Request to {url} timed out after {timeout}s")
return False
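A short usage sketch for the two health-check helpers above, assuming the fastdeploy package is importable; the target address is a placeholder.

import asyncio

from fastdeploy.router.utils import check_service_health, check_service_health_async

print(check_service_health("127.0.0.1:8188"))  # blocking GET on <base_url>/health
print(asyncio.run(check_service_health_async("127.0.0.1:8188")))  # async variant of the same check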

View File

@@ -16,7 +16,9 @@
import redis
from fastdeploy.utils import llm_logger
from fastdeploy.utils import get_logger, llm_logger
config_logger = get_logger("config", "config.log")
from .dp_scheduler import DPScheduler
from .global_scheduler import GlobalScheduler
@@ -84,10 +86,10 @@ class LocalSchedulerConfig:
"""
Print the current configuration to logs.
"""
llm_logger.info("LocalScheduler Configuration Information :")
config_logger.info("LocalScheduler Configuration Information :")
for k, v in self.__dict__.items():
llm_logger.info("{:<20}:{:<6}{}".format(k, "", v))
llm_logger.info("=============================================================")
config_logger.info("{:<20}:{:<6}{}".format(k, "", v))
config_logger.info("=============================================================")
class DPLocalSchedulerConfig(LocalSchedulerConfig):
@@ -312,6 +314,7 @@ class SchedulerConfig:
Returns:
Initialized scheduler instance (LocalScheduler or GlobalScheduler)
"""
llm_logger.info("Scheduler Type: %s" % self.name)
if self.name == "global":
return GlobalScheduler(

View File

@@ -195,6 +195,20 @@ class LocalScheduler:
results += [(request_id, "duplicated request_id") for request_id in duplicated_ids]
return results
def has_request(self, request_id: str) -> bool:
"""
Check if there are any pending requests in the scheduler.
Args:
request_id: Optional specific request ID to check.
If None, checks whether there are any pending requests.
Returns:
True if there are pending requests, False otherwise.
"""
with self.mutex:
return request_id in self.requests
def calc_required_blocks(self, token_num, block_size):
"""
Calculate the number of blocks needed for a given number of tokens.
@@ -292,6 +306,7 @@ class LocalScheduler:
Args:
results: List of RequestOutput objects containing results
"""
scheduler_logger.debug(f"put results: {results}")
responses: List[ScheduledResponse] = [ScheduledResponse(result) for result in results]
finished_responses = [response.request_id for response in responses if response.finished]
@@ -354,4 +369,8 @@ class LocalScheduler:
if finished:
self._recycle(request_id)
scheduler_logger.info(f"Scheduler has pulled a finished response: {[request_id]}")
if results:
scheduler_logger.debug(f"get responses, {results}")
return results

View File

@@ -18,12 +18,12 @@ import json
import time
import traceback
from concurrent.futures import ThreadPoolExecutor
from typing import Dict
from typing import Dict, List
import zmq
from fastdeploy import envs
from fastdeploy.engine.request import CompletionOutput, Request, RequestOutput
from fastdeploy.engine.request import Request, RequestOutput
from fastdeploy.inter_communicator import EngineWorkerQueue
from fastdeploy.metrics.metrics import main_process_metrics
from fastdeploy.utils import get_logger
@@ -241,7 +241,7 @@ class SplitwiseConnector:
},
}
def send_splitwise_tasks(self, tasks, current_id):
def send_splitwise_tasks(self, tasks: List[Request], current_id):
"""
Send splitwise tasks to all connected addresses.
@@ -276,6 +276,7 @@ class SplitwiseConnector:
task.disaggregate_info["cache_info"] = self.cfg.disaggregate_info["cache_info"]
task.disaggregate_info["cache_info"]["rdma"]["current_id"] = current_id
task.disaggregate_info["role"] = "decode"
self.logger.debug(f"send task to coupled instance, {addr}, {task}")
self._send_message(addr, "prefill", [task])
task.disaggregate_info["cache_info"] = decode_diagg
task.disaggregate_info["role"] = "prefill"
@@ -311,7 +312,7 @@ class SplitwiseConnector:
"""
if not isinstance(tasks_list, list):
tasks_list = [tasks_list]
self.logger.info("send first token to port decode")
self.logger.info(f"send first token to decode, {[x.request_id for x in tasks_list]}")
if prefill_msg["transfer_protocol"] == "ipc":
port = prefill_msg["cache_info"]["ipc"]["port"]
if port not in self.connect_innode_instances:
@@ -355,7 +356,7 @@ class SplitwiseConnector:
self.logger.error(f"Receive_decode_allocated error: {msg}")
return False, msg
def send_cache_infos(self, tasks, current_id):
def send_cache_infos(self, tasks: List[Request], current_id):
"""
Send cache information to specific port.
@@ -432,8 +433,10 @@ class SplitwiseConnector:
if not is_decode and len(temp_cache_info):
for k, v in temp_cache_info.items():
self.logger.debug(f"send cache info to cachemessager, {v}")
self.engine_worker_queue.put_cache_info(v)
else:
self.logger.debug(f"send cache info to coupled instance, {temp_cache_info}")
if len(temp_cache_info):
for k, v in temp_cache_info.items():
self.logger.info(f"{k} {v}")
@@ -490,7 +493,7 @@ class SplitwiseConnector:
"""
Handle prefill tasks from other nodes.
"""
self.logger.debug(f"_handle_prefill function receive {tasks}")
tasks_data = [Request.from_dict(task) for task in tasks]
self.engine_worker_queue.put_disaggregated_tasks(("decode", tasks_data))
@@ -498,21 +501,9 @@ class SplitwiseConnector:
"""
Handle decode tasks from other nodes.
"""
self.logger.debug(f"_handle_decode function receive {payload}")
tasks = []
for task in payload:
tasks.append(
RequestOutput(
request_id=task["request_id"],
outputs=CompletionOutput(
index=task["outputs"]["index"],
send_idx=0,
token_ids=task["outputs"]["token_ids"],
draft_token_ids=task["outputs"]["draft_token_ids"],
),
finished=True,
num_cached_tokens=task["num_cached_tokens"],
error_code=task["error_code"],
error_msg=task["error_msg"],
)
)
output = RequestOutput.from_dict(task)
tasks.append(output)
self.engine_worker_queue.put_disaggregated_tasks(("decode", tasks))

View File

@@ -982,6 +982,7 @@ api_server_logger = get_logger("api_server", "api_server.log")
console_logger = get_logger("console", "console.log", print_to_console=True)
spec_logger = get_logger("speculate", "speculate.log")
zmq_client_logger = get_logger("zmq_client", "zmq_client.log")
router_logger = get_logger("router", "router.log")
def parse_type(return_type: Callable[[str], T]) -> Callable[[str], T]:

View File

@@ -475,8 +475,10 @@ class PaddleDisWorkerProc:
# Execute model to generate token. The generated token will be written to the buffer.
# These generated tokens can be obtained through get_output op.
start_execute_time = time.time()
self.worker.execute_model(req_dicts, num_running_requests)
self.exist_prefill_task_signal.value[0] = self.worker.exist_prefill()
logger.debug(f"execute model cost: {time.time()-start_execute_time:.5f} s")
def initialize_kv_cache(self) -> None:
"""Profiles the peak memory usage of the model to determine how many

View File

@@ -42,3 +42,5 @@ opentelemetry-instrumentation-fastapi
partial_json_parser
msgspec
einops
setproctitle
aistudio_sdk

View File

@@ -45,18 +45,24 @@ def setup_and_run_server():
clean_ports()
print("log dir clean ")
if os.path.exists("log") and os.path.isdir("log"):
shutil.rmtree("log")
if os.path.exists("log_prefill") and os.path.isdir("log_prefill"):
shutil.rmtree("log_prefill")
if os.path.exists("log_decode") and os.path.isdir("log_decode"):
shutil.rmtree("log_decode")
base_path = os.getenv("MODEL_PATH")
if base_path:
model_path = os.path.join(base_path, "ERNIE-4.5-0.3B-Paddle")
else:
model_path = "./ERNIE-4.5-0.3B-Paddle"
model_path = "baidu/ERNIE-4.5-0.3B-Paddle"
print(f"model_path: {model_path}")
# prefill instance
print("start prefill...")
env_prefill = os.environ.copy()
env_prefill["CUDA_VISIBLE_DEVICES"] = "0"
env_prefill["ENABLE_V1_KVCACHE_SCHEDULER"] = "0"
env_prefill["FD_LOG_DIR"] = "log_prefill"
env_prefill["INFERENCE_MSG_QUEUE_ID"] = str(FD_API_PORT)
prefill_log_path = "server.log"
prefill_cmd = [
@@ -94,12 +100,15 @@ def setup_and_run_server():
start_new_session=True, # Enables killing full group via os.killpg
env=env_prefill,
)
time.sleep(3)
# decode instance
print("start decode...")
env_decode = os.environ.copy()
env_decode["CUDA_VISIBLE_DEVICES"] = "1"
env_prefill["ENABLE_V1_KVCACHE_SCHEDULER"] = "0"
env_decode["INFERENCE_MSG_QUEUE_ID"] = str(FD_API_PORT + 1)
env_decode["FD_LOG_DIR"] = "decode_log"
env_decode["FD_LOG_DIR"] = "log_decode"
decode_log_path = "decode_server.log"
decode_cmd = [
sys.executable,
@@ -125,6 +134,8 @@ def setup_and_run_server():
"wint8",
"--splitwise-role",
"decode",
"--innode-prefill-ports",
str(FD_ENGINE_QUEUE_PORT),
]
# Start subprocess in new process group
@@ -260,18 +271,7 @@ def test_chat_usage_stream(api_url):
"stream_options": {"include_usage": True, "continuous_usage_stats": True},
"metadata": {"min_tokens": 10},
}
p_url, d_url = api_url
response = send_request(url=p_url, payload=payload)
chunks = get_stream_chunks(response)
result = "".join([x["choices"][0]["delta"]["content"] for x in chunks[:-1]])
print("Prefill Response:", result)
assert result != "", "结果为空"
usage = chunks[-1]["usage"]
total_tokens = usage["completion_tokens"] + usage["prompt_tokens"]
assert payload["max_tokens"] >= usage["completion_tokens"], "completion_tokens大于max_tokens"
assert payload["metadata"]["min_tokens"] <= usage["completion_tokens"], "completion_tokens小于min_tokens"
assert usage["total_tokens"] == total_tokens, "total_tokens不等于prompt_tokens + completion_tokens"
_, d_url = api_url # Only the decode server receives the request
response = send_request(url=d_url, payload=payload)
chunks = get_stream_chunks(response)
@@ -302,16 +302,7 @@ def test_chat_usage_non_stream(api_url):
"stream": False,
"metadata": {"min_tokens": 10},
}
p_url, d_url = api_url
response = send_request(url=p_url, payload=payload).json()
usage = response["usage"]
result = response["choices"][0]["message"]["content"]
assert result != "", "结果为空"
total_tokens = usage["completion_tokens"] + usage["prompt_tokens"]
assert payload["max_tokens"] >= usage["completion_tokens"], "completion_tokens大于max_tokens"
assert payload["metadata"]["min_tokens"] <= usage["completion_tokens"], "completion_tokens小于min_tokens"
assert usage["total_tokens"] == total_tokens, "total_tokens不等于prompt_tokens + completion_tokens"
_, d_url = api_url
response = send_request(url=d_url, payload=payload).json()
usage = response["usage"]
@@ -336,25 +327,13 @@ def test_non_chat_usage_stream(api_url):
"stream_options": {"include_usage": True, "continuous_usage_stats": True},
"metadata": {"min_tokens": 10},
}
p_url, d_url = api_url
p_url = p_url.replace("chat/completions", "completions")
_, d_url = api_url
d_url = d_url.replace("chat/completions", "completions")
response = send_request(url=p_url, payload=payload)
chunks = get_stream_chunks(response)
result = "".join([x["choices"][0]["text"] for x in chunks[:-1]])
# print("Prefill Response:", result)
assert result != "", "结果为空"
usage = chunks[-1]["usage"]
total_tokens = usage["completion_tokens"] + usage["prompt_tokens"]
assert payload["max_tokens"] >= usage["completion_tokens"], "completion_tokens大于max_tokens"
assert payload["metadata"]["min_tokens"] <= usage["completion_tokens"], "completion_tokens小于min_tokens"
assert usage["total_tokens"] == total_tokens, "total_tokens不等于prompt_tokens + completion_tokens"
response = send_request(url=d_url, payload=payload)
chunks = get_stream_chunks(response)
result = "".join([x["choices"][0]["text"] for x in chunks[:-1]])
# print("Decode Response:", result)
print("Decode Response:", result)
assert result != "", "结果为空"
usage = chunks[-1]["usage"]
total_tokens = usage["completion_tokens"] + usage["prompt_tokens"]
@@ -375,23 +354,13 @@ def test_non_chat_usage_non_stream(api_url):
"stream": False,
"metadata": {"min_tokens": 10},
}
p_url, d_url = api_url
p_url = p_url.replace("chat/completions", "completions")
_, d_url = api_url
d_url = d_url.replace("chat/completions", "completions")
response = send_request(url=p_url, payload=payload).json()
usage = response["usage"]
result = response["choices"][0]["text"]
# print("Prefill Response:", result)
assert result != "", "结果为空"
total_tokens = usage["completion_tokens"] + usage["prompt_tokens"]
assert payload["max_tokens"] >= usage["completion_tokens"], "completion_tokens大于max_tokens"
assert payload["metadata"]["min_tokens"] <= usage["completion_tokens"], "completion_tokens小于min_tokens"
assert usage["total_tokens"] == total_tokens, "total_tokens不等于prompt_tokens + completion_tokens"
response = send_request(url=d_url, payload=payload).json()
usage = response["usage"]
result = response["choices"][0]["text"]
print("Decode Response:", result)
assert result != "", "结果为空"
total_tokens = usage["completion_tokens"] + usage["prompt_tokens"]
assert payload["max_tokens"] >= usage["completion_tokens"], "completion_tokens大于max_tokens"

View File

@@ -0,0 +1,500 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import os
import shutil
import signal
import socket
import subprocess
import sys
import time
import pytest
import requests
# Read ports from environment variables; use default values if not set
FD_API_PORT = int(os.getenv("FD_API_PORT", 8188))
FD_ENGINE_QUEUE_PORT = int(os.getenv("FD_ENGINE_QUEUE_PORT", 8133))
FD_METRICS_PORT = int(os.getenv("FD_METRICS_PORT", 8233))
FD_CACHE_QUEUE_PORT = int(os.getenv("FD_CACHE_QUEUE_PORT", 8333))
FD_CONNECTOR_PORT = int(os.getenv("FD_CONNECTOR_PORT", 8433))
FD_ROUTER_PORT = int(os.getenv("FD_ROUTER_PORT", 8533))
# List of ports to clean before and after tests
PORTS_TO_CLEAN = [
FD_API_PORT,
FD_ENGINE_QUEUE_PORT,
FD_METRICS_PORT,
FD_CACHE_QUEUE_PORT,
FD_CONNECTOR_PORT,
FD_API_PORT + 1,
FD_ENGINE_QUEUE_PORT + 1,
FD_METRICS_PORT + 1,
FD_CACHE_QUEUE_PORT + 1,
FD_CONNECTOR_PORT + 1,
FD_ROUTER_PORT,
]
def is_port_open(host: str, port: int, timeout=1.0):
"""
Check if a TCP port is open on the given host.
Returns True if connection succeeds, False otherwise.
"""
try:
with socket.create_connection((host, port), timeout):
return True
except Exception:
return False
def check_service_health(base_url: str, timeout: int = 3) -> bool:
"""
Check the health status of a service.
Args:
base_url (str): The base URL of the service, e.g. "http://127.0.0.1:8080"
timeout (int): Request timeout in seconds.
Returns:
bool: True if the service is healthy, False otherwise.
"""
if not base_url.startswith("http"):
base_url = f"http://{base_url}"
url = f"{base_url.rstrip('/')}/health"
try:
resp = requests.get(url, timeout=timeout)
return resp.status_code == 200
except Exception:
return False
def get_registered_number(router_url) -> dict:
"""
Get the number of registered instances from the router.
Args:
router_url (str): The base URL of the router, e.g. "http://localhost:8080".
Returns:
dict: Counts of registered instances per role, e.g. {"mixed": 0, "prefill": 0, "decode": 0}.
"""
if not router_url.startswith("http"):
router_url = f"http://{router_url}"
try:
response = requests.get(f"{router_url}/registered_number", timeout=60)
registered_numbers = response.json()
return registered_numbers
except Exception:
return {"mixed": 0, "prefill": 0, "decode": 0}
def kill_process_on_port(port: int):
"""
Kill processes that are listening on the given port.
Uses `lsof` to find process ids and sends SIGKILL.
"""
try:
output = subprocess.check_output(f"lsof -i:{port} -t", shell=True).decode().strip()
current_pid = os.getpid()
parent_pid = os.getppid()
for pid in output.splitlines():
pid = int(pid)
if pid in (current_pid, parent_pid):
print(f"Skip killing current process (pid={pid}) on port {port}")
continue
os.kill(pid, signal.SIGKILL)
print(f"Killed process on port {port}, pid={pid}")
except subprocess.CalledProcessError:
pass
def clean_ports():
"""
Kill all processes occupying the ports listed in PORTS_TO_CLEAN.
"""
for port in PORTS_TO_CLEAN:
kill_process_on_port(port)
time.sleep(2)
@pytest.fixture(scope="session", autouse=True)
def setup_and_run_server():
"""
Pytest fixture that runs once per test session:
- Cleans ports before tests
- Starts the router, prefill and decode servers as subprocesses
- Waits up to 300 seconds for the prefill and decode instances to register with the router
- Tears down server after all tests finish
"""
print("Pre-test port cleanup...")
clean_ports()
print("log dir clean ")
if os.path.exists("log_router") and os.path.isdir("log_router"):
shutil.rmtree("log_router")
if os.path.exists("log_prefill") and os.path.isdir("log_prefill"):
shutil.rmtree("log_prefill")
if os.path.exists("log_decode") and os.path.isdir("log_decode"):
shutil.rmtree("log_decode")
base_path = os.getenv("MODEL_PATH")
if base_path:
model_path = os.path.join(base_path, "ERNIE-4.5-0.3B-Paddle")
else:
model_path = "baidu/ERNIE-4.5-0.3B-Paddle"
print(f"model_path: {model_path}")
# router
print("start router...")
env_router = os.environ.copy()
env_router["FD_LOG_DIR"] = "log_router"
router_log_path = "router.log"
router_cmd = [
sys.executable,
"-m",
"fastdeploy.router.launch",
"--port",
str(FD_ROUTER_PORT),
"--splitwise",
]
with open(router_log_path, "w") as logfile:
process_router = subprocess.Popen(
router_cmd,
stdout=logfile,
stderr=subprocess.STDOUT,
start_new_session=True, # Enables killing full group via os.killpg
env=env_router,
)
# prefill instance
print("start prefill...")
env_prefill = os.environ.copy()
env_prefill["CUDA_VISIBLE_DEVICES"] = "0"
env_prefill["ENABLE_V1_KVCACHE_SCHEDULER"] = "0"
env_prefill["FD_LOG_DIR"] = "log_prefill"
env_prefill["INFERENCE_MSG_QUEUE_ID"] = str(FD_API_PORT)
prefill_log_path = "server.log"
prefill_cmd = [
sys.executable,
"-m",
"fastdeploy.entrypoints.openai.api_server",
"--model",
model_path,
"--port",
str(FD_API_PORT),
"--tensor-parallel-size",
"1",
"--engine-worker-queue-port",
str(FD_ENGINE_QUEUE_PORT),
"--metrics-port",
str(FD_METRICS_PORT),
"--cache-queue-port",
str(FD_CACHE_QUEUE_PORT),
"--max-model-len",
"8192",
"--max-num-seqs",
"20",
"--quantization",
"wint8",
"--splitwise-role",
"prefill",
"--cache-transfer-protocol",
"ipc",
"--pd-comm-port",
str(FD_CONNECTOR_PORT),
"--router",
f"0.0.0.0:{FD_ROUTER_PORT}",
]
# Start subprocess in new process group
with open(prefill_log_path, "w") as logfile:
process_prefill = subprocess.Popen(
prefill_cmd,
stdout=logfile,
stderr=subprocess.STDOUT,
start_new_session=True, # Enables killing full group via os.killpg
env=env_prefill,
)
time.sleep(1)
# decode instance
print("start decode...")
env_decode = os.environ.copy()
env_decode["CUDA_VISIBLE_DEVICES"] = "1"
env_decode["ENABLE_V1_KVCACHE_SCHEDULER"] = "0"
env_decode["INFERENCE_MSG_QUEUE_ID"] = str(FD_API_PORT + 1)
env_decode["FD_LOG_DIR"] = "log_decode"
decode_log_path = "decode_server.log"
decode_cmd = [
sys.executable,
"-m",
"fastdeploy.entrypoints.openai.api_server",
"--model",
model_path,
"--port",
str(FD_API_PORT + 1),
"--tensor-parallel-size",
"1",
"--engine-worker-queue-port",
str(FD_ENGINE_QUEUE_PORT + 1),
"--metrics-port",
str(FD_METRICS_PORT + 1),
"--cache-queue-port",
str(FD_CACHE_QUEUE_PORT + 1),
"--max-model-len",
"8192",
"--max-num-seqs",
"20",
"--quantization",
"wint8",
"--splitwise-role",
"decode",
"--cache-transfer-protocol",
"ipc",
"--pd-comm-port",
str(FD_CONNECTOR_PORT + 1),
"--router",
f"0.0.0.0:{FD_ROUTER_PORT}",
]
# Start subprocess in new process group
with open(decode_log_path, "w") as logfile:
process_decode = subprocess.Popen(
decode_cmd,
stdout=logfile,
stderr=subprocess.STDOUT,
start_new_session=True, # Enables killing full group via os.killpg
env=env_decode,
)
# Wait up to 300 seconds for API server to be ready
for _ in range(60):
registered_numbers = get_registered_number(f"0.0.0.0:{FD_ROUTER_PORT}")
if registered_numbers["prefill"] >= 1 and registered_numbers["decode"] >= 1:
print("Prefill and decode servers are both online")
break
time.sleep(5)
else:
print("[TIMEOUT] API server failed to start in 5 minutes. Cleaning up...")
try:
os.killpg(process_prefill.pid, signal.SIGTERM)
os.killpg(process_decode.pid, signal.SIGTERM)
clean_ports()
except Exception as e:
print(f"Failed to kill process group: {e}")
raise RuntimeError(f"API server did not start on port {FD_API_PORT}")
yield # Run tests
print("\n===== Post-test server cleanup... =====")
try:
os.killpg(process_router.pid, signal.SIGTERM)
os.killpg(process_prefill.pid, signal.SIGTERM)
os.killpg(process_decode.pid, signal.SIGTERM)
clean_ports()
print(f"Prefill server (pid={process_prefill.pid}) terminated")
print(f"Decode server (pid={process_decode.pid}) terminated")
except Exception as e:
print(f"Failed to terminate API server: {e}")
@pytest.fixture(scope="session")
def api_url(request):
"""
Returns the API endpoint URL for chat completions.
"""
return f"http://0.0.0.0:{FD_ROUTER_PORT}/v1/chat/completions"
@pytest.fixture(scope="session")
def metrics_url(request):
"""
Returns the metrics endpoint URL.
"""
return f"http://0.0.0.0:{FD_METRICS_PORT}/metrics"
@pytest.fixture
def headers():
"""
Returns common HTTP request headers.
"""
return {"Content-Type": "application/json"}
def test_metrics_config(metrics_url):
timeout = 600
url = metrics_url.replace("metrics", "config-info")
res = requests.get(url, timeout=timeout)
assert res.status_code == 200
def send_request(url, payload, timeout=600):
"""
Send a request to the given URL and return the response.
"""
headers = {
"Content-Type": "application/json",
}
try:
res = requests.post(url, headers=headers, json=payload, timeout=timeout)
print("🟢 接收响应中...\n")
return res
except requests.exceptions.Timeout:
print(f"❌ 请求超时(超过 {timeout} 秒)")
return None
except requests.exceptions.RequestException as e:
print(f"❌ 请求失败:{e}")
return None
def get_stream_chunks(response):
"""解析流式返回生成chunk List[dict]"""
chunks = []
if response.status_code == 200:
for line in response.iter_lines(decode_unicode=True):
if line:
if line.startswith("data: "):
line = line[len("data: ") :]
if line.strip() == "[DONE]":
break
try:
chunk = json.loads(line)
chunks.append(chunk)
except Exception as e:
print(f"解析失败: {e}, 行内容: {line}")
else:
print(f"请求失败,状态码: {response.status_code}")
print("返回内容:", response.text)
return chunks
def test_chat_usage_stream(api_url):
"""测试流式chat usage"""
payload = {
"model": "default",
"temperature": 0,
"top_p": 0,
"seed": 33,
"messages": [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "牛顿的三大运动定律是什么?"},
],
"max_tokens": 50,
"stream": True,
"stream_options": {"include_usage": True, "continuous_usage_stats": True},
"metadata": {"min_tokens": 10},
}
response = send_request(url=api_url, payload=payload)
chunks = get_stream_chunks(response)
result = "".join([x["choices"][0]["delta"]["content"] for x in chunks[:-1]])
print("Decode Response:", result)
assert result != "", "结果为空"
usage = chunks[-1]["usage"]
total_tokens = usage["completion_tokens"] + usage["prompt_tokens"]
assert payload["max_tokens"] >= usage["completion_tokens"], "completion_tokens大于max_tokens"
assert payload["metadata"]["min_tokens"] <= usage["completion_tokens"], "completion_tokens小于min_tokens"
assert usage["total_tokens"] == total_tokens, "total_tokens不等于prompt_tokens + completion_tokens"
def test_chat_usage_non_stream(api_url):
"""测试非流式chat usage"""
payload = {
"model": "default",
"temperature": 0,
"top_p": 0,
"seed": 33,
"messages": [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "牛顿的三大运动定律是什么?"},
],
"max_tokens": 50,
"stream": False,
"metadata": {"min_tokens": 10},
}
response = send_request(url=api_url, payload=payload).json()
usage = response["usage"]
result = response["choices"][0]["message"]["content"]
assert result != "", "结果为空"
total_tokens = usage["completion_tokens"] + usage["prompt_tokens"]
assert payload["max_tokens"] >= usage["completion_tokens"], "completion_tokens大于max_tokens"
assert payload["metadata"]["min_tokens"] <= usage["completion_tokens"], "completion_tokens小于min_tokens"
assert usage["total_tokens"] == total_tokens, "total_tokens不等于prompt_tokens + completion_tokens"
def test_non_chat_usage_stream(api_url):
"""测试流式非chat usage"""
payload = {
"model": "default",
"temperature": 0,
"top_p": 0,
"seed": 33,
"prompt": "牛顿的三大运动定律是什么?",
"max_tokens": 50,
"stream": True,
"stream_options": {"include_usage": True, "continuous_usage_stats": True},
"metadata": {"min_tokens": 10},
}
api_url = api_url.replace("chat/completions", "completions")
response = send_request(url=api_url, payload=payload)
chunks = get_stream_chunks(response)
result = "".join([x["choices"][0]["text"] for x in chunks[:-1]])
print("Decode Response:", result)
assert result != "", "结果为空"
usage = chunks[-1]["usage"]
total_tokens = usage["completion_tokens"] + usage["prompt_tokens"]
assert payload["max_tokens"] >= usage["completion_tokens"], "completion_tokens大于max_tokens"
assert payload["metadata"]["min_tokens"] <= usage["completion_tokens"], "completion_tokens小于min_tokens"
assert usage["total_tokens"] == total_tokens, "total_tokens不等于prompt_tokens + completion_tokens"
def test_non_chat_usage_non_stream(api_url):
"""测试非流式非chat usage"""
payload = {
"model": "default",
"temperature": 0,
"top_p": 0,
"seed": 33,
"prompt": "牛顿的三大运动定律是什么?",
"max_tokens": 50,
"stream": False,
"metadata": {"min_tokens": 10},
}
api_url = api_url.replace("chat/completions", "completions")
response = send_request(url=api_url, payload=payload).json()
usage = response["usage"]
result = response["choices"][0]["text"]
print("Decode Response:", result)
assert result != "", "结果为空"
total_tokens = usage["completion_tokens"] + usage["prompt_tokens"]
assert payload["max_tokens"] >= usage["completion_tokens"], "completion_tokens大于max_tokens"
assert payload["metadata"]["min_tokens"] <= usage["completion_tokens"], "completion_tokens小于min_tokens"
assert usage["total_tokens"] == total_tokens, "total_tokens不等于prompt_tokens + completion_tokens"

View File

@@ -0,0 +1,486 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Test for router and mixed server
import json
import os
import shutil
import signal
import socket
import subprocess
import sys
import time
import pytest
import requests
# Read ports from environment variables; use default values if not set
FD_API_PORT = int(os.getenv("FD_API_PORT", 8188))
FD_ENGINE_QUEUE_PORT = int(os.getenv("FD_ENGINE_QUEUE_PORT", 8133))
FD_METRICS_PORT = int(os.getenv("FD_METRICS_PORT", 8233))
FD_CACHE_QUEUE_PORT = int(os.getenv("FD_CACHE_QUEUE_PORT", 8333))
FD_ROUTER_PORT = int(os.getenv("FD_ROUTER_PORT", 8533))
# List of ports to clean before and after tests
PORTS_TO_CLEAN = [
FD_API_PORT,
FD_ENGINE_QUEUE_PORT,
FD_METRICS_PORT,
FD_CACHE_QUEUE_PORT,
FD_API_PORT + 1,
FD_ENGINE_QUEUE_PORT + 1,
FD_METRICS_PORT + 1,
FD_CACHE_QUEUE_PORT + 1,
FD_ROUTER_PORT,
]
def is_port_open(host: str, port: int, timeout=1.0):
"""
Check if a TCP port is open on the given host.
Returns True if connection succeeds, False otherwise.
"""
try:
with socket.create_connection((host, port), timeout):
return True
except Exception:
return False
def check_service_health(base_url: str, timeout: int = 3) -> bool:
"""
Check the health status of a service.
Args:
base_url (str): The base URL of the service, e.g. "http://127.0.0.1:8080"
timeout (int): Request timeout in seconds.
Returns:
bool: True if the service is healthy, False otherwise.
"""
if not base_url.startswith("http"):
base_url = f"http://{base_url}"
url = f"{base_url.rstrip('/')}/health"
try:
resp = requests.get(url, timeout=timeout)
return resp.status_code == 200
except Exception:
return False
def get_registered_number(router_url) -> dict:
"""
Get the number of registered instances from the router.
Args:
router_url (str): The base URL of the router, e.g. "http://localhost:8080".
Returns:
dict: Counts of registered instances per role, e.g. {"mixed": 0, "prefill": 0, "decode": 0}.
"""
if not router_url.startswith("http"):
router_url = f"http://{router_url}"
try:
response = requests.get(f"{router_url}/registered_number", timeout=60)
registered_numbers = response.json()
return registered_numbers
except Exception:
return {"mixed": 0, "prefill": 0, "decode": 0}
def kill_process_on_port(port: int):
"""
Kill processes that are listening on the given port.
Uses `lsof` to find process ids and sends SIGKILL.
"""
try:
output = subprocess.check_output(f"lsof -i:{port} -t", shell=True).decode().strip()
current_pid = os.getpid()
parent_pid = os.getppid()
for pid in output.splitlines():
pid = int(pid)
if pid in (current_pid, parent_pid):
print(f"Skip killing current process (pid={pid}) on port {port}")
continue
os.kill(pid, signal.SIGKILL)
print(f"Killed process on port {port}, pid={pid}")
except subprocess.CalledProcessError:
pass
def clean_ports():
"""
Kill all processes occupying the ports listed in PORTS_TO_CLEAN.
"""
for port in PORTS_TO_CLEAN:
kill_process_on_port(port)
time.sleep(2)
@pytest.fixture(scope="session", autouse=True)
def setup_and_run_server():
"""
Pytest fixture that runs once per test session:
- Cleans ports before tests
- Starts the router and two mixed API servers as subprocesses
- Waits up to 300 seconds for both instances to register with the router
- Tears down server after all tests finish
"""
print("Pre-test port cleanup...")
clean_ports()
print("log dir clean ")
if os.path.exists("log_router") and os.path.isdir("log_router"):
shutil.rmtree("log_router")
if os.path.exists("log_server_0") and os.path.isdir("log_server_0"):
shutil.rmtree("log_server_0")
if os.path.exists("log_server_1") and os.path.isdir("log_server_1"):
shutil.rmtree("log_server_1")
base_path = os.getenv("MODEL_PATH")
if base_path:
model_path = os.path.join(base_path, "ERNIE-4.5-0.3B-Paddle")
else:
model_path = "baidu/ERNIE-4.5-0.3B-Paddle"
print(f"model_path: {model_path}")
# router
print("start router...")
env_router = os.environ.copy()
env_router["FD_LOG_DIR"] = "log_router"
router_log_path = "router.log"
router_cmd = [
sys.executable,
"-m",
"fastdeploy.router.launch",
"--port",
str(FD_ROUTER_PORT),
]
with open(router_log_path, "w") as logfile:
process_router = subprocess.Popen(
router_cmd,
stdout=logfile,
stderr=subprocess.STDOUT,
start_new_session=True, # Enables killing full group via os.killpg
env=env_router,
)
# server0
print("start server0...")
env_server_0 = os.environ.copy()
env_server_0["CUDA_VISIBLE_DEVICES"] = "0"
env_server_0["ENABLE_V1_KVCACHE_SCHEDULER"] = "0"
env_server_0["FD_LOG_DIR"] = "log_server_0"
env_server_0["INFERENCE_MSG_QUEUE_ID"] = str(FD_API_PORT)
log_path = "server_0.log"
cmd = [
sys.executable,
"-m",
"fastdeploy.entrypoints.openai.api_server",
"--model",
model_path,
"--port",
str(FD_API_PORT),
"--tensor-parallel-size",
"1",
"--engine-worker-queue-port",
str(FD_ENGINE_QUEUE_PORT),
"--metrics-port",
str(FD_METRICS_PORT),
"--cache-queue-port",
str(FD_CACHE_QUEUE_PORT),
"--max-model-len",
"8192",
"--max-num-seqs",
"20",
"--quantization",
"wint8",
"--router",
f"0.0.0.0:{FD_ROUTER_PORT}",
]
# Start subprocess in new process group
with open(log_path, "w") as logfile:
process_server_0 = subprocess.Popen(
cmd,
stdout=logfile,
stderr=subprocess.STDOUT,
start_new_session=True, # Enables killing full group via os.killpg
env=env_server_0,
)
time.sleep(1)
# server 1
print("start server 1...")
env_server_1 = os.environ.copy()
env_server_1["CUDA_VISIBLE_DEVICES"] = "1"
env_server_1["ENABLE_V1_KVCACHE_SCHEDULER"] = "0"
env_server_1["INFERENCE_MSG_QUEUE_ID"] = str(FD_API_PORT + 1)
env_server_1["FD_LOG_DIR"] = "log_server_1"
log_path = "server_1.log"
cmd = [
sys.executable,
"-m",
"fastdeploy.entrypoints.openai.api_server",
"--model",
model_path,
"--port",
str(FD_API_PORT + 1),
"--tensor-parallel-size",
"1",
"--engine-worker-queue-port",
str(FD_ENGINE_QUEUE_PORT + 1),
"--metrics-port",
str(FD_METRICS_PORT + 1),
"--cache-queue-port",
str(FD_CACHE_QUEUE_PORT + 1),
"--max-model-len",
"8192",
"--max-num-seqs",
"20",
"--quantization",
"wint8",
"--router",
f"0.0.0.0:{FD_ROUTER_PORT}",
]
# Start subprocess in new process group
with open(log_path, "w") as logfile:
process_server_1 = subprocess.Popen(
cmd,
stdout=logfile,
stderr=subprocess.STDOUT,
start_new_session=True, # Enables killing full group via os.killpg
env=env_server_1,
)
# Wait up to 300 seconds for API server to be ready
for _ in range(60):
registered_numbers = get_registered_number(f"0.0.0.0:{FD_ROUTER_PORT}")
if registered_numbers["mixed"] >= 2:
print("Mixed servers are both online")
break
time.sleep(5)
else:
print("[TIMEOUT] API server failed to start in 5 minutes. Cleaning up...")
try:
os.killpg(process_server_0.pid, signal.SIGTERM)
os.killpg(process_server_1.pid, signal.SIGTERM)
clean_ports()
except Exception as e:
print(f"Failed to kill process group: {e}")
raise RuntimeError(f"API server did not start on port {FD_API_PORT}")
yield # Run tests
print("\n===== Post-test server cleanup... =====")
try:
os.killpg(process_router.pid, signal.SIGTERM)
os.killpg(process_server_0.pid, signal.SIGTERM)
os.killpg(process_server_1.pid, signal.SIGTERM)
clean_ports()
print(f"server (pid={process_server_0.pid}) terminated")
print(f"server (pid={process_server_1.pid}) terminated")
except Exception as e:
print(f"Failed to terminate API server: {e}")
@pytest.fixture(scope="session")
def api_url(request):
"""
Returns the API endpoint URL for chat completions.
"""
return f"http://0.0.0.0:{FD_ROUTER_PORT}/v1/chat/completions"
@pytest.fixture(scope="session")
def metrics_url(request):
"""
Returns the metrics endpoint URL.
"""
return f"http://0.0.0.0:{FD_METRICS_PORT}/metrics"
@pytest.fixture
def headers():
"""
Returns common HTTP request headers.
"""
return {"Content-Type": "application/json"}
def test_metrics_config(metrics_url):
timeout = 600
url = metrics_url.replace("metrics", "config-info")
res = requests.get(url, timeout=timeout)
assert res.status_code == 200
def send_request(url, payload, timeout=600):
"""
Send a request to the given URL and return the response.
"""
headers = {
"Content-Type": "application/json",
}
try:
res = requests.post(url, headers=headers, json=payload, timeout=timeout)
print("🟢 接收响应中...\n")
return res
except requests.exceptions.Timeout:
print(f"❌ 请求超时(超过 {timeout} 秒)")
return None
except requests.exceptions.RequestException as e:
print(f"❌ 请求失败:{e}")
return None
def get_stream_chunks(response):
"""解析流式返回生成chunk List[dict]"""
chunks = []
if response.status_code == 200:
for line in response.iter_lines(decode_unicode=True):
if line:
if line.startswith("data: "):
line = line[len("data: ") :]
if line.strip() == "[DONE]":
break
try:
chunk = json.loads(line)
chunks.append(chunk)
except Exception as e:
print(f"解析失败: {e}, 行内容: {line}")
else:
print(f"请求失败,状态码: {response.status_code}")
print("返回内容:", response.text)
return chunks
def test_chat_usage_stream(api_url):
"""测试流式chat usage"""
payload = {
"model": "default",
"temperature": 0,
"top_p": 0,
"seed": 33,
"messages": [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "牛顿的三大运动定律是什么?"},
],
"max_tokens": 50,
"stream": True,
"stream_options": {"include_usage": True, "continuous_usage_stats": True},
"metadata": {"min_tokens": 10},
}
response = send_request(url=api_url, payload=payload)
chunks = get_stream_chunks(response)
result = "".join([x["choices"][0]["delta"]["content"] for x in chunks[:-1]])
print("Response:", result)
assert result != "", "结果为空"
usage = chunks[-1]["usage"]
total_tokens = usage["completion_tokens"] + usage["prompt_tokens"]
assert payload["max_tokens"] >= usage["completion_tokens"], "completion_tokens大于max_tokens"
assert payload["metadata"]["min_tokens"] <= usage["completion_tokens"], "completion_tokens小于min_tokens"
assert usage["total_tokens"] == total_tokens, "total_tokens不等于prompt_tokens + completion_tokens"
def test_chat_usage_non_stream(api_url):
"""测试非流式chat usage"""
payload = {
"model": "default",
"temperature": 0,
"top_p": 0,
"seed": 33,
"messages": [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "牛顿的三大运动定律是什么?"},
],
"max_tokens": 50,
"stream": False,
"metadata": {"min_tokens": 10},
}
response = send_request(url=api_url, payload=payload).json()
usage = response["usage"]
result = response["choices"][0]["message"]["content"]
assert result != "", "结果为空"
total_tokens = usage["completion_tokens"] + usage["prompt_tokens"]
assert payload["max_tokens"] >= usage["completion_tokens"], "completion_tokens大于max_tokens"
assert payload["metadata"]["min_tokens"] <= usage["completion_tokens"], "completion_tokens小于min_tokens"
assert usage["total_tokens"] == total_tokens, "total_tokens不等于prompt_tokens + completion_tokens"
def test_non_chat_usage_stream(api_url):
"""测试流式非chat usage"""
payload = {
"model": "default",
"temperature": 0,
"top_p": 0,
"seed": 33,
"prompt": "牛顿的三大运动定律是什么?",
"max_tokens": 50,
"stream": True,
"stream_options": {"include_usage": True, "continuous_usage_stats": True},
"metadata": {"min_tokens": 10},
}
api_url = api_url.replace("chat/completions", "completions")
response = send_request(url=api_url, payload=payload)
chunks = get_stream_chunks(response)
result = "".join([x["choices"][0]["text"] for x in chunks[:-1]])
print("Response:", result)
assert result != "", "结果为空"
usage = chunks[-1]["usage"]
total_tokens = usage["completion_tokens"] + usage["prompt_tokens"]
assert payload["max_tokens"] >= usage["completion_tokens"], "completion_tokens大于max_tokens"
assert payload["metadata"]["min_tokens"] <= usage["completion_tokens"], "completion_tokens小于min_tokens"
assert usage["total_tokens"] == total_tokens, "total_tokens不等于prompt_tokens + completion_tokens"
def test_non_chat_usage_non_stream(api_url):
"""测试非流式非chat usage"""
payload = {
"model": "default",
"temperature": 0,
"top_p": 0,
"seed": 33,
"prompt": "牛顿的三大运动定律是什么?",
"max_tokens": 50,
"stream": False,
"metadata": {"min_tokens": 10},
}
api_url = api_url.replace("chat/completions", "completions")
response = send_request(url=api_url, payload=payload).json()
usage = response["usage"]
result = response["choices"][0]["text"]
print("Response:", result)
assert result != "", "结果为空"
total_tokens = usage["completion_tokens"] + usage["prompt_tokens"]
assert payload["max_tokens"] >= usage["completion_tokens"], "completion_tokens大于max_tokens"
assert payload["metadata"]["min_tokens"] <= usage["completion_tokens"], "completion_tokens小于min_tokens"
assert usage["total_tokens"] == total_tokens, "total_tokens不等于prompt_tokens + completion_tokens"