mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-12-24 13:28:13 +08:00
update (#5625)
Some checks failed
CE Compile Job / ce_job_pre_check (push) Has been cancelled
CE Compile Job / print_ce_job_pre_check_outputs (push) Has been cancelled
CE Compile Job / FD-Clone-Linux (push) Has been cancelled
CE Compile Job / Show Code Archive Output (push) Has been cancelled
CE Compile Job / BUILD_SM8090 (push) Has been cancelled
CE Compile Job / BUILD_SM8689 (push) Has been cancelled
CE Compile Job / CE_UPLOAD (push) Has been cancelled
Deploy GitHub Pages / deploy (push) Has been cancelled
Publish Job / publish_pre_check (push) Has been cancelled
Publish Job / print_publish_pre_check_outputs (push) Has been cancelled
Publish Job / FD-Clone-Linux (push) Has been cancelled
Publish Job / Show Code Archive Output (push) Has been cancelled
Publish Job / BUILD_SM8090 (push) Has been cancelled
Publish Job / BUILD_SM8689 (push) Has been cancelled
Publish Job / PADDLE_PYPI_UPLOAD_8090 (push) Has been cancelled
Publish Job / PADDLE_PYPI_UPLOAD_8689 (push) Has been cancelled
Publish Job / Run FD Image Build (push) Has been cancelled
Publish Job / Run FastDeploy Unit Tests and Coverage (push) Has been cancelled
Publish Job / Run FastDeploy LogProb Tests (push) Has been cancelled
Publish Job / Extracted partial CE model tasks to run in CI. (push) Has been cancelled
Publish Job / Run Base Tests (push) Has been cancelled
Publish Job / Run Accuracy Tests (push) Has been cancelled
Publish Job / Run Stable Tests (push) Has been cancelled
CI Images Build / FD-Clone-Linux (push) Has been cancelled
CI Images Build / Show Code Archive Output (push) Has been cancelled
CI Images Build / CI Images Build (push) Has been cancelled
CI Images Build / BUILD_SM8090 (push) Has been cancelled
CI Images Build / Run FastDeploy Unit Tests and Coverage (push) Has been cancelled
CI Images Build / Run FastDeploy LogProb Tests (push) Has been cancelled
CI Images Build / Extracted partial CE model tasks to run in CI. (push) Has been cancelled
CI Images Build / Run Base Tests (push) Has been cancelled
CI Images Build / Publish Docker Images Pre Check (push) Has been cancelled
Some checks failed
CE Compile Job / ce_job_pre_check (push) Has been cancelled
CE Compile Job / print_ce_job_pre_check_outputs (push) Has been cancelled
CE Compile Job / FD-Clone-Linux (push) Has been cancelled
CE Compile Job / Show Code Archive Output (push) Has been cancelled
CE Compile Job / BUILD_SM8090 (push) Has been cancelled
CE Compile Job / BUILD_SM8689 (push) Has been cancelled
CE Compile Job / CE_UPLOAD (push) Has been cancelled
Deploy GitHub Pages / deploy (push) Has been cancelled
Publish Job / publish_pre_check (push) Has been cancelled
Publish Job / print_publish_pre_check_outputs (push) Has been cancelled
Publish Job / FD-Clone-Linux (push) Has been cancelled
Publish Job / Show Code Archive Output (push) Has been cancelled
Publish Job / BUILD_SM8090 (push) Has been cancelled
Publish Job / BUILD_SM8689 (push) Has been cancelled
Publish Job / PADDLE_PYPI_UPLOAD_8090 (push) Has been cancelled
Publish Job / PADDLE_PYPI_UPLOAD_8689 (push) Has been cancelled
Publish Job / Run FD Image Build (push) Has been cancelled
Publish Job / Run FastDeploy Unit Tests and Coverage (push) Has been cancelled
Publish Job / Run FastDeploy LogProb Tests (push) Has been cancelled
Publish Job / Extracted partial CE model tasks to run in CI. (push) Has been cancelled
Publish Job / Run Base Tests (push) Has been cancelled
Publish Job / Run Accuracy Tests (push) Has been cancelled
Publish Job / Run Stable Tests (push) Has been cancelled
CI Images Build / FD-Clone-Linux (push) Has been cancelled
CI Images Build / Show Code Archive Output (push) Has been cancelled
CI Images Build / CI Images Build (push) Has been cancelled
CI Images Build / BUILD_SM8090 (push) Has been cancelled
CI Images Build / Run FastDeploy Unit Tests and Coverage (push) Has been cancelled
CI Images Build / Run FastDeploy LogProb Tests (push) Has been cancelled
CI Images Build / Extracted partial CE model tasks to run in CI. (push) Has been cancelled
CI Images Build / Run Base Tests (push) Has been cancelled
CI Images Build / Publish Docker Images Pre Check (push) Has been cancelled
This commit is contained in:
@@ -46,6 +46,7 @@ python -m pip install -r requirements.txt
|
||||
--shuffle:是否打乱数据集,默认False不打乱
|
||||
--seed:打乱数据集时的随机种子,默认0
|
||||
--pd-metrics:开启PD分离metrics指标收集,会添加请求参数collect_metrics=True,默认False
|
||||
--ip-list:支持多个ip:port,将总请求数以及总并发数均分到每个IP,按整除取余分配。例:0.0.0.0:1211,0.0.0.0:1222,默认为空
|
||||
```
|
||||
|
||||
##### /v1/chat/completions接口压测单条数据调试
|
||||
|
||||
@@ -74,6 +74,8 @@ class RequestFuncOutput:
|
||||
tpot: float = 0.0 # avg next-token latencies
|
||||
prompt_len: int = 0
|
||||
prompt_tokens: int = 0 # 推理侧返回输入token数
|
||||
reasoning_tokens: int = 0 # 思考长度
|
||||
res_ttft: int = 0 # 包含思考首token时延
|
||||
error: str = ""
|
||||
metrics: dict = field(default_factory=dict)
|
||||
|
||||
@@ -198,11 +200,14 @@ async def async_request_eb_openai_chat_completions(
|
||||
request_id = "None"
|
||||
|
||||
ttft = 0.0
|
||||
res_ttft = 0.0
|
||||
st = time.perf_counter()
|
||||
most_recent_timestamp = st
|
||||
token_timestamps = []
|
||||
try:
|
||||
async with session.post(url=api_url, json=payload, headers=headers) as response:
|
||||
async with session.post(
|
||||
url=api_url, json=payload, headers=headers, read_bufsize=10 * 1024 * 1024
|
||||
) as response:
|
||||
data = {}
|
||||
if response.status == 200:
|
||||
async for chunk_bytes in response.content:
|
||||
@@ -242,6 +247,14 @@ async def async_request_eb_openai_chat_completions(
|
||||
else:
|
||||
output.itl.append(timestamp - most_recent_timestamp)
|
||||
|
||||
# response首token
|
||||
if res_ttft == 0.0:
|
||||
if content:
|
||||
res_ttft = choices[0]["arrival_time"]
|
||||
output.res_ttft = res_ttft
|
||||
usage = data.get("usage", {})
|
||||
output.reasoning_tokens = max(usage.get("completion_tokens", 0) - 1, 0)
|
||||
|
||||
output.generated_text += content or ""
|
||||
output.reasoning_content += reason_content or ""
|
||||
# print(f"####content:{data}")
|
||||
@@ -262,6 +275,7 @@ async def async_request_eb_openai_chat_completions(
|
||||
|
||||
if output.generated_text.strip() == "":
|
||||
output.success = False
|
||||
output.reasoning_tokens = output.output_tokens
|
||||
output.error = "No generated text found!"
|
||||
else:
|
||||
output.success = True
|
||||
@@ -284,7 +298,7 @@ async def async_request_eb_openai_chat_completions(
|
||||
output.request_id = request_id
|
||||
|
||||
# 保存失败请求结果
|
||||
if not output.success:
|
||||
if not output.success or output.output_tokens == 0:
|
||||
with open("error_output.txt", "a") as f:
|
||||
f.write(str(output) + "\n")
|
||||
if pbar:
|
||||
|
||||
@@ -104,6 +104,14 @@ class BenchmarkMetrics:
|
||||
median_output_len: float
|
||||
std_output_len: float
|
||||
percentiles_output_len: list[tuple[float, float]]
|
||||
mean_reasoning_len: float
|
||||
median_reasoning_len: float
|
||||
std_reasoning_len: float
|
||||
percentiles_reasoning_len: list[tuple[float, float]]
|
||||
mean_res_ttft_ms: float
|
||||
median_res_ttft_ms: float
|
||||
std_res_ttft_ms: float
|
||||
percentiles_res_ttft_ms: list[tuple[float, float]]
|
||||
|
||||
|
||||
async def get_request(
|
||||
@@ -160,6 +168,7 @@ def calculate_metrics(
|
||||
input_lens: list[int] = []
|
||||
infer_input_lens: list[int] = [] # 推理侧输入token数
|
||||
actual_output_lens: list[int] = []
|
||||
reasoning_output_lens: list[int] = []
|
||||
total_input = 0
|
||||
completed = 0
|
||||
good_completed = 0
|
||||
@@ -169,6 +178,7 @@ def calculate_metrics(
|
||||
all_tpots: list[float] = []
|
||||
ttfts: list[float] = []
|
||||
s_ttfts: list[float] = []
|
||||
res_ttfts: list[float] = []
|
||||
e2els: list[float] = []
|
||||
s_e2els: list[float] = []
|
||||
s_decodes: list[float] = []
|
||||
@@ -186,6 +196,7 @@ def calculate_metrics(
|
||||
continue
|
||||
|
||||
actual_output_lens.append(output_len)
|
||||
reasoning_output_lens.append(outputs[i].reasoning_tokens)
|
||||
input_lens.append(outputs[i].prompt_len)
|
||||
infer_input_lens.append(outputs[i].prompt_tokens)
|
||||
total_input += outputs[i].prompt_tokens
|
||||
@@ -204,6 +215,7 @@ def calculate_metrics(
|
||||
ttfts.append(outputs[i].ttft)
|
||||
# 推理侧TTFT
|
||||
s_ttfts.append(outputs[i].arrival_time[1])
|
||||
res_ttfts.append(outputs[i].res_ttft)
|
||||
e2els.append(outputs[i].latency)
|
||||
# 推理侧整句时延
|
||||
s_e2els.append(outputs[i].arrival_time[-1])
|
||||
@@ -296,6 +308,14 @@ def calculate_metrics(
|
||||
std_output_len=np.std(actual_output_lens or 0) * 1,
|
||||
median_output_len=np.median(actual_output_lens or 0) * 1,
|
||||
percentiles_output_len=[(p, np.percentile(actual_output_lens or 0, p)) for p in selected_percentiles],
|
||||
mean_reasoning_len=np.mean(reasoning_output_lens or 0) * 1,
|
||||
std_reasoning_len=np.std(reasoning_output_lens or 0) * 1,
|
||||
median_reasoning_len=np.median(reasoning_output_lens or 0) * 1,
|
||||
percentiles_reasoning_len=[(p, np.percentile(reasoning_output_lens or 0, p)) for p in selected_percentiles],
|
||||
mean_res_ttft_ms=np.mean(res_ttfts or 0) * 1000, # ttfts is empty if streaming is not supported by backend
|
||||
std_res_ttft_ms=np.std(res_ttfts or 0) * 1000,
|
||||
median_res_ttft_ms=np.median(res_ttfts or 0) * 1000,
|
||||
percentiles_res_ttft_ms=[(p, np.percentile(res_ttfts or 0, p) * 1000) for p in selected_percentiles],
|
||||
)
|
||||
|
||||
return metrics, actual_output_lens
|
||||
@@ -323,6 +343,7 @@ async def benchmark(
|
||||
max_concurrency: Optional[int],
|
||||
lora_modules: Optional[Iterable[str]],
|
||||
extra_body: Optional[dict],
|
||||
ip_list: Optional[list[str]] = None,
|
||||
):
|
||||
"""Benchmarks an API endpoint using a given set of sample inputs and returns"""
|
||||
if backend in ASYNC_REQUEST_FUNCS:
|
||||
@@ -340,6 +361,9 @@ async def benchmark(
|
||||
response_format = input_requests[0].response_format
|
||||
random_flag = input_requests[0].random_flag
|
||||
|
||||
if len(ip_list) >= 1:
|
||||
api_url = f"http://{ip_list[0]}{args.endpoint}"
|
||||
|
||||
test_input = RequestFuncInput(
|
||||
model=model_id,
|
||||
model_name=model_name,
|
||||
@@ -411,51 +435,133 @@ async def benchmark(
|
||||
# and it will simplify the code in limited_request_func.
|
||||
# semaphore = (asyncio.Semaphore(max_concurrency)
|
||||
# if max_concurrency else contextlib.nullcontext())
|
||||
semaphore = asyncio.Semaphore(max_concurrency) if max_concurrency else None
|
||||
ip_list = ip_list or []
|
||||
|
||||
async def limited_request_func(request_func_input, pbar):
|
||||
if semaphore is None:
|
||||
return await request_func(request_func_input=request_func_input, pbar=pbar)
|
||||
async with semaphore:
|
||||
return await request_func(request_func_input=request_func_input, pbar=pbar)
|
||||
if len(ip_list) <= 1:
|
||||
if len(ip_list) == 1:
|
||||
api_url = f"http://{ip_list[0]}{args.endpoint}"
|
||||
semaphore = asyncio.Semaphore(max_concurrency) if max_concurrency else None
|
||||
|
||||
benchmark_start_time = time.perf_counter()
|
||||
tasks: list[asyncio.Task] = []
|
||||
async for request in get_request(input_requests, request_rate, burstiness):
|
||||
prompt, output_len, no = (
|
||||
request.prompt,
|
||||
request.expected_output_len,
|
||||
request.no,
|
||||
)
|
||||
history_QA = request.history_QA
|
||||
response_format = request.response_format
|
||||
random_flag = request.random_flag
|
||||
async def limited_request_func(request_func_input, pbar):
|
||||
if semaphore is None:
|
||||
return await request_func(request_func_input=request_func_input, pbar=pbar)
|
||||
async with semaphore:
|
||||
return await request_func(request_func_input=request_func_input, pbar=pbar)
|
||||
|
||||
req_model_id, req_model_name = model_id, model_name
|
||||
if lora_modules:
|
||||
req_lora_module = next(lora_modules)
|
||||
req_model_id, req_model_name = req_lora_module, req_lora_module
|
||||
tasks: list[asyncio.Task] = []
|
||||
benchmark_start_time = time.perf_counter()
|
||||
|
||||
request_func_input = RequestFuncInput(
|
||||
model=req_model_id,
|
||||
model_name=req_model_name,
|
||||
prompt=prompt,
|
||||
no=no,
|
||||
prompt_len=0,
|
||||
history_QA=history_QA,
|
||||
hyper_parameters=hyper_parameters,
|
||||
api_url=api_url,
|
||||
output_len=output_len,
|
||||
logprobs=logprobs,
|
||||
debug=debug,
|
||||
pd_metrics=pd_metrics,
|
||||
ignore_eos=ignore_eos,
|
||||
extra_body=extra_body,
|
||||
response_format=response_format,
|
||||
random_flag=random_flag,
|
||||
)
|
||||
tasks.append(asyncio.create_task(limited_request_func(request_func_input=request_func_input, pbar=pbar)))
|
||||
outputs: list[RequestFuncOutput] = await asyncio.gather(*tasks)
|
||||
async for request in get_request(input_requests, request_rate, burstiness):
|
||||
prompt, output_len, no = (
|
||||
request.prompt,
|
||||
request.expected_output_len,
|
||||
request.no,
|
||||
)
|
||||
history_QA = request.history_QA
|
||||
response_format = request.response_format
|
||||
random_flag = request.random_flag
|
||||
|
||||
req_model_id, req_model_name = model_id, model_name
|
||||
if lora_modules:
|
||||
req_lora_module = next(lora_modules)
|
||||
req_model_id, req_model_name = req_lora_module, req_lora_module
|
||||
|
||||
request_func_input = RequestFuncInput(
|
||||
model=req_model_id,
|
||||
model_name=req_model_name,
|
||||
prompt=prompt,
|
||||
no=no,
|
||||
prompt_len=0,
|
||||
history_QA=history_QA,
|
||||
hyper_parameters=hyper_parameters,
|
||||
api_url=api_url,
|
||||
output_len=output_len,
|
||||
logprobs=logprobs,
|
||||
debug=debug,
|
||||
pd_metrics=pd_metrics,
|
||||
ignore_eos=ignore_eos,
|
||||
extra_body=extra_body,
|
||||
response_format=response_format,
|
||||
random_flag=random_flag,
|
||||
)
|
||||
tasks.append(asyncio.create_task(limited_request_func(request_func_input=request_func_input, pbar=pbar)))
|
||||
|
||||
outputs: list[RequestFuncOutput] = await asyncio.gather(*tasks)
|
||||
else:
|
||||
# 多ip按DP均分并发
|
||||
assert max_concurrency, "multi-IP 模式必须指定 max_concurrency"
|
||||
n_ip = len(ip_list)
|
||||
concurrency_per_ip = max_concurrency // n_ip
|
||||
concurrency_remainder = max_concurrency % n_ip
|
||||
|
||||
# 分配请求
|
||||
req_per_ip = len(input_requests) // n_ip
|
||||
remainder = len(input_requests) % n_ip
|
||||
|
||||
ip_requests_map = {}
|
||||
start = 0
|
||||
for i, ip in enumerate(ip_list):
|
||||
count = req_per_ip + (1 if i < remainder else 0)
|
||||
print(f"IP: {ip}, requests: {count}")
|
||||
print(f"start: {start}, end: {start + count}")
|
||||
ip_requests_map[ip] = input_requests[start : start + count]
|
||||
start += count
|
||||
|
||||
# exit(8)
|
||||
|
||||
semaphores = {
|
||||
ip: asyncio.Semaphore(concurrency_per_ip + (1 if i < concurrency_remainder else 0))
|
||||
for i, ip in enumerate(ip_list)
|
||||
}
|
||||
|
||||
async def limited_request_func_per_ip(req_input, semaphore, pbar):
|
||||
async with semaphore:
|
||||
return await request_func(request_func_input=req_input, pbar=pbar)
|
||||
|
||||
tasks = []
|
||||
for i, ip in enumerate(ip_list):
|
||||
print(
|
||||
f"Starting benchmark for IP: {ip}, "
|
||||
f"concurrency per IP: {semaphores[ip]._value}, "
|
||||
f"requests per IP: {len(ip_requests_map[ip])}",
|
||||
flush=True,
|
||||
)
|
||||
benchmark_start_time = time.perf_counter()
|
||||
|
||||
for i, ip in enumerate(ip_list):
|
||||
semaphore = semaphores[ip]
|
||||
|
||||
for request in ip_requests_map[ip]:
|
||||
prompt, output_len, no = request.prompt, request.expected_output_len, request.no
|
||||
history_QA = request.history_QA
|
||||
|
||||
req_model_id, req_model_name = model_id, model_name
|
||||
if lora_modules:
|
||||
req_lora_module = next(lora_modules)
|
||||
req_model_id = req_model_name = req_lora_module
|
||||
|
||||
req_input = RequestFuncInput(
|
||||
model=req_model_id,
|
||||
model_name=req_model_name,
|
||||
prompt=prompt,
|
||||
no=no,
|
||||
prompt_len=0,
|
||||
history_QA=history_QA,
|
||||
hyper_parameters=hyper_parameters,
|
||||
api_url=f"http://{ip}{args.endpoint}", # ★ 多 IP 模式仅替换 host:port
|
||||
output_len=output_len,
|
||||
logprobs=logprobs,
|
||||
ignore_eos=ignore_eos,
|
||||
debug=debug,
|
||||
pd_metrics=pd_metrics,
|
||||
extra_body=extra_body,
|
||||
response_format=response_format,
|
||||
random_flag=random_flag,
|
||||
)
|
||||
|
||||
tasks.append(asyncio.create_task(limited_request_func_per_ip(req_input, semaphore, pbar)))
|
||||
|
||||
outputs: list[RequestFuncOutput] = await asyncio.gather(*tasks)
|
||||
|
||||
outputs.sort(key=lambda x: x.end_timestamp)
|
||||
|
||||
@@ -541,11 +647,13 @@ async def benchmark(
|
||||
"request_throughput": metrics.request_throughput,
|
||||
"request_goodput:": (metrics.request_goodput if goodput_config_dict else None),
|
||||
"output_throughput": metrics.output_throughput,
|
||||
"reasoning_lens": [output.reasoning_tokens for output in outputs],
|
||||
"total_token_throughput": metrics.total_token_throughput,
|
||||
"input_lens": [output.prompt_len for output in outputs],
|
||||
"infer_input_lens": [output.prompt_tokens for output in outputs],
|
||||
"output_lens": [output.output_tokens for output in outputs],
|
||||
"output_lens": actual_output_lens,
|
||||
"ttfts": [output.ttft for output in outputs],
|
||||
"res_ttfts": [output.res_ttft for output in outputs],
|
||||
"itls": [output.itl for output in outputs],
|
||||
"input_texts": [input.prompt for input in input_requests],
|
||||
"generated_texts": [output.generated_text for output in outputs],
|
||||
@@ -666,6 +774,7 @@ async def benchmark(
|
||||
process_one_length("s_decode", "Decode", "解码速度(tok/s)")
|
||||
process_one_metric("ttft", "TTFT", "Time to First Token")
|
||||
process_one_metric("s_ttft", "S_TTFT", "Infer Time to First Token")
|
||||
process_one_metric("res_ttft", "Response TTFT", "包含思考首token耗时")
|
||||
process_one_metric("tpot", "TPOT", "Time per Output Token (excl. 1st token)")
|
||||
process_one_metric("itl", "ITL", "Inter-token Latency")
|
||||
process_one_metric("s_itl", "S_ITL", "Infer Inter-token Latency")
|
||||
@@ -686,6 +795,7 @@ async def benchmark(
|
||||
process_pd_metrics(outputs, "second_token_transmission_cost_time")
|
||||
process_one_length("input_len", "Cached Tokens", "Cached Tokens")
|
||||
process_one_length("s_input_len", "Input Length", "Infer Input Length")
|
||||
process_one_length("reasoning_len", "Reasoning Lenth", "思考长度")
|
||||
process_one_length("output_len", "Output Length", "Output Length")
|
||||
|
||||
print("=" * 50)
|
||||
@@ -983,6 +1093,15 @@ def main(args: argparse.Namespace):
|
||||
else:
|
||||
hyper_parameters = {}
|
||||
|
||||
processed_list = []
|
||||
for item in args.ip_list:
|
||||
if "," in item:
|
||||
processed_list.extend([x.strip() for x in item.split(",") if x.strip()])
|
||||
else:
|
||||
processed_list.append(item)
|
||||
|
||||
ip_list = processed_list
|
||||
|
||||
benchmark_result = asyncio.run(
|
||||
benchmark(
|
||||
backend=backend,
|
||||
@@ -1006,6 +1125,7 @@ def main(args: argparse.Namespace):
|
||||
max_concurrency=args.max_concurrency,
|
||||
lora_modules=args.lora_modules,
|
||||
extra_body=sampling_params,
|
||||
ip_list=ip_list,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -1097,6 +1217,17 @@ if __name__ == "__main__":
|
||||
default="/v1/completions",
|
||||
help="API endpoint.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--ip-list",
|
||||
nargs="*",
|
||||
default=[],
|
||||
help=(
|
||||
"List of ip:port. "
|
||||
"Supports: "
|
||||
"1) --ip-list 127.0.0.1:8000 --ip-list 127.0.0.1:8001 "
|
||||
"2) --ip-list 127.0.0.1:8000,127.0.0.1:8001"
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dataset-name",
|
||||
type=str,
|
||||
@@ -1265,7 +1396,7 @@ if __name__ == "__main__":
|
||||
parser.add_argument(
|
||||
"--percentile-metrics",
|
||||
type=str,
|
||||
default="ttft,tpot,itl",
|
||||
default="ttft,tpot,itl,reasoning_len",
|
||||
help="Comma-separated list of selected metrics to report percentils. "
|
||||
"This argument specifies the metrics to report percentiles. "
|
||||
'Allowed metric names are "ttft", "tpot", "itl", "e2el". '
|
||||
|
||||
Reference in New Issue
Block a user