[Feature] Add speculative decoding simulation benchmark. (#2751)

* Add speculative decoding simulation benchmark

* Fix the name of the parameter
This commit is contained in:
GoldPancake
2025-07-09 12:08:43 +08:00
committed by GitHub
parent 6b10c19482
commit f7cad30a38
8 changed files with 246 additions and 7 deletions

View File

@@ -105,3 +105,30 @@ python benchmark_serving.py \
--save-result > infer_log.txt 2>&1 &
```
### 投机解码性能测试工具
#### 使用方式:
```bash
python benchmarks/benchmark_mtp.py \
--host 127.0.0.1 --port 8000 \
--max-concurrency 16 32 64 96 --num-prompts 256 \
--acceptance-rate 0.8 --draft-token-steps 1 2 3 \
--s_itl-base-model 15.88 22.84 16.47 16.93 \
--dataset-name EBChat \
--dataset-path ./filtered_sharedgpt_2000_input_1136_output_200_fd.json
```
#### 参数说明
```bash
--host服务ip地址用于组url
--port服务HTTP端口用于组url
--max-concurrency测试并发数
--num-prompts总计发送多少条请求
--acceptance-rate投机解码的模拟接受率
--draft-token-steps投机解码的步数
--s_itl-base-model主模型的解码延迟可由上述的性能压测工具获得与batch-size一一对应
--dataset-name指定数据集类指定为"EBChat"可读取转存的FD格式数据集
--dataset-path测试数据集路径
```

191
benchmarks/benchmark_mtp.py Normal file
View File

@@ -0,0 +1,191 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
import argparse
import asyncio
import contextlib
import os
import signal
import socket
import subprocess
import time
from typing import Union
import openai
import yaml
from benchmark_dataset import EBChatDataset, EBDataset, SampleRequest
from benchmark_serving import benchmark
def prepare_input_requests(
num_prompts: int, dataset_name: str, dataset_path: str
) -> Union[EBDataset, EBChatDataset]:
dataset_mapping = {
"EB": lambda: EBDataset(dataset_path=dataset_path).sample(
num_requests=num_prompts
),
"EBChat": lambda: EBChatDataset(dataset_path=dataset_path).sample(
num_requests=num_prompts
),
}
try:
input_requests = dataset_mapping[dataset_name]()
except KeyError as err:
raise ValueError(f"Unknown dataset: {dataset_name}") from err
return input_requests
class FakeTokenizer:
def encode(self, text: str, add_special_tokens: bool = False):
return []
def send_one_batch(base_url, max_concurrency, input_requests, disable_tqdm):
selected_percentile_metrics = ["s_itl"]
selected_percentiles = []
# Run benchmark
results = asyncio.run(
benchmark(
backend="openai-chat",
api_url=f"{base_url}/v1/chat/completions",
base_url=base_url,
model_id="default",
model_name="default",
input_requests=input_requests,
hyper_parameters={},
logprobs=None,
request_rate=float("inf"),
burstiness=1.0,
disable_tqdm=disable_tqdm,
profile=False,
selected_percentile_metrics=selected_percentile_metrics,
selected_percentiles=selected_percentiles,
ignore_eos=False,
goodput_config_dict=None,
max_concurrency=max_concurrency,
lora_modules=None,
extra_body=None,
)
)
record = {
"mean_s_itl_ms": results["mean_s_itl_ms"],
}
return record
def calculate_speedup(acceptance_rate, draft_token_step, t_ori, t_mtp):
tmp = 0.0
for i in range(draft_token_step):
tmp += pow(acceptance_rate, i + 1)
r_ac = tmp / (1 + tmp)
return t_ori / ((1 - r_ac) * t_mtp)
def main(args):
base_url = f"http://{args.host}:{args.port}"
input_requests = prepare_input_requests(
args.num_prompts, args.dataset_name, args.dataset_path
)
if len(args.max_concurrency) != len(args.s_itl_base_model):
raise ValueError(f"--max_concurrency should be same length as --s_itl_base_model")
for max_concurrency, s_itl in zip(args.max_concurrency, args.s_itl_base_model):
# Wramup
print("Starting warmup...")
with open(os.devnull, "w") as f:
with contextlib.redirect_stdout(f):
send_one_batch(base_url, max_concurrency, input_requests[0:max_concurrency], True)
# Benchmark
record = send_one_batch(base_url, max_concurrency, input_requests, False)
metric_header = f"Speed up"
print("{s:{c}^{n}}".format(s=metric_header, n=50, c="-"))
for draft_token_step in args.draft_token_steps:
speedup = calculate_speedup(
args.acceptance_rate,
draft_token_step,
s_itl,
record["mean_s_itl_ms"],
)
print(
"{:<40} {:<10.2f}".format(
f"Speed up on {draft_token_step} steps draft", speedup
)
)
print("=" * 50)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--host",
type=str,
default="127.0.0.1",
)
parser.add_argument(
"--port",
type=str,
default="8000",
)
parser.add_argument(
"--max-concurrency",
type=int,
nargs="+",
default=(1, 2, 4, 8, 16, 32),
)
parser.add_argument(
"--num-prompts",
type=int,
default=128,
)
parser.add_argument(
"--acceptance-rate",
type=float,
default=0.8,
)
parser.add_argument(
"--draft-token-steps",
type=int,
nargs="+",
default=(1, 2),
)
parser.add_argument(
"--s_itl-base-model",
type=float,
nargs="+",
)
parser.add_argument(
"--dataset-name",
type=str,
default="EBChat",
)
parser.add_argument(
"--dataset-path",
type=str,
)
args = parser.parse_args()
main(args)