mirror of https://github.com/PaddlePaddle/FastDeploy.git (synced 2025-09-30 14:22:27 +08:00)
[Feature] Add speculative decoding simulation benchmark. (#2751)
* Add speculative decoding simulation benchmark
* Fix the name of the parameter
@@ -105,3 +105,30 @@ python benchmark_serving.py \
    --save-result > infer_log.txt 2>&1 &
```

### Speculative Decoding Benchmark Tool

#### Usage:

```bash
python benchmarks/benchmark_mtp.py \
    --host 127.0.0.1 --port 8000 \
    --max-concurrency 16 32 64 96 --num-prompts 256 \
    --acceptance-rate 0.8 --draft-token-steps 1 2 3 \
    --s_itl-base-model 15.88 22.84 16.47 16.93 \
    --dataset-name EBChat \
    --dataset-path ./filtered_sharedgpt_2000_input_1136_output_200_fd.json
```

#### Parameter Description

```bash
--host: server IP address, used to compose the request URL
--port: server HTTP port, used to compose the request URL
--max-concurrency: concurrency levels to test
--num-prompts: total number of requests to send
--acceptance-rate: simulated acceptance rate for speculative decoding
--draft-token-steps: number of speculative decoding draft steps
--s_itl-base-model: decode latency of the base model, obtainable with the load-testing tool above; one value per batch size (concurrency level)
--dataset-name: dataset class to use; set to "EBChat" to read a dataset dumped in FD format
--dataset-path: path to the test dataset
```
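
The speedup the tool reports is not measured end to end; it is projected from the simulated acceptance rate. Writing `a` for `--acceptance-rate`, `k` for a `--draft-token-steps` value, `t_ori` for the matching `--s_itl-base-model` entry, and `t_mtp` for the inter-token latency measured against the simulation server, `calculate_speedup` in `benchmarks/benchmark_mtp.py` (shown below) computes:

```latex
r_{ac} = \frac{\sum_{i=1}^{k} a^{i}}{1 + \sum_{i=1}^{k} a^{i}},
\qquad
\mathrm{speedup} = \frac{t_{ori}}{(1 - r_{ac}) \cdot t_{mtp}}
```

For example, with a = 0.8 and k = 2, r_ac = 1.44 / 2.44 ≈ 0.59, so the projected speedup is t_ori / (0.41 · t_mtp).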
benchmarks/benchmark_mtp.py (new file, 191 lines)
@@ -0,0 +1,191 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""

import argparse
import asyncio
import contextlib
import os
from typing import Union

from benchmark_dataset import EBChatDataset, EBDataset
from benchmark_serving import benchmark


def prepare_input_requests(
    num_prompts: int, dataset_name: str, dataset_path: str
) -> Union[EBDataset, EBChatDataset]:
    dataset_mapping = {
        "EB": lambda: EBDataset(dataset_path=dataset_path).sample(
            num_requests=num_prompts
        ),
        "EBChat": lambda: EBChatDataset(dataset_path=dataset_path).sample(
            num_requests=num_prompts
        ),
    }

    try:
        input_requests = dataset_mapping[dataset_name]()
    except KeyError as err:
        raise ValueError(f"Unknown dataset: {dataset_name}") from err

    return input_requests


class FakeTokenizer:
    # Minimal stand-in tokenizer: returns no tokens for any input.
    def encode(self, text: str, add_special_tokens: bool = False):
        return []


def send_one_batch(base_url, max_concurrency, input_requests, disable_tqdm):
    selected_percentile_metrics = ["s_itl"]
    selected_percentiles = []
    # Run benchmark
    results = asyncio.run(
        benchmark(
            backend="openai-chat",
            api_url=f"{base_url}/v1/chat/completions",
            base_url=base_url,
            model_id="default",
            model_name="default",
            input_requests=input_requests,
            hyper_parameters={},
            logprobs=None,
            request_rate=float("inf"),
            burstiness=1.0,
            disable_tqdm=disable_tqdm,
            profile=False,
            selected_percentile_metrics=selected_percentile_metrics,
            selected_percentiles=selected_percentiles,
            ignore_eos=False,
            goodput_config_dict=None,
            max_concurrency=max_concurrency,
            lora_modules=None,
            extra_body=None,
        )
    )

    record = {
        "mean_s_itl_ms": results["mean_s_itl_ms"],
    }

    return record


def calculate_speedup(acceptance_rate, draft_token_step, t_ori, t_mtp):
    # Expected number of accepted draft tokens per verify step:
    # the sum of acceptance_rate**i for i = 1..draft_token_step.
    tmp = 0.0
    for i in range(draft_token_step):
        tmp += pow(acceptance_rate, i + 1)

    # Fraction of output tokens contributed by accepted draft tokens.
    r_ac = tmp / (1 + tmp)

    return t_ori / ((1 - r_ac) * t_mtp)


def main(args):
    base_url = f"http://{args.host}:{args.port}"

    input_requests = prepare_input_requests(
        args.num_prompts, args.dataset_name, args.dataset_path
    )

    if len(args.max_concurrency) != len(args.s_itl_base_model):
        raise ValueError(
            "--max-concurrency must have the same number of values as --s_itl-base-model"
        )

    for max_concurrency, s_itl in zip(args.max_concurrency, args.s_itl_base_model):
        # Warmup
        print("Starting warmup...")
        with open(os.devnull, "w") as f:
            with contextlib.redirect_stdout(f):
                send_one_batch(base_url, max_concurrency, input_requests[0:max_concurrency], True)

        # Benchmark
        record = send_one_batch(base_url, max_concurrency, input_requests, False)

        metric_header = "Speed up"
        print("{s:{c}^{n}}".format(s=metric_header, n=50, c="-"))
        for draft_token_step in args.draft_token_steps:
            speedup = calculate_speedup(
                args.acceptance_rate,
                draft_token_step,
                s_itl,
                record["mean_s_itl_ms"],
            )
            print(
                "{:<40} {:<10.2f}".format(
                    f"Speed up on {draft_token_step} steps draft", speedup
                )
            )
        print("=" * 50)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--host",
        type=str,
        default="127.0.0.1",
    )
    parser.add_argument(
        "--port",
        type=str,
        default="8000",
    )
    parser.add_argument(
        "--max-concurrency",
        type=int,
        nargs="+",
        default=(1, 2, 4, 8, 16, 32),
    )
    parser.add_argument(
        "--num-prompts",
        type=int,
        default=128,
    )
    parser.add_argument(
        "--acceptance-rate",
        type=float,
        default=0.8,
    )
    parser.add_argument(
        "--draft-token-steps",
        type=int,
        nargs="+",
        default=(1, 2),
    )
    parser.add_argument(
        "--s_itl-base-model",
        type=float,
        nargs="+",
    )
    parser.add_argument(
        "--dataset-name",
        type=str,
        default="EBChat",
    )
    parser.add_argument(
        "--dataset-path",
        type=str,
    )
    args = parser.parse_args()

    main(args)
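A quick sanity check of `calculate_speedup` with illustrative numbers (run from the `benchmarks/` directory; the `t_mtp` value here is made up, not a measurement):

```python
from benchmark_mtp import calculate_speedup

# a = 0.8, k = 2: sum = 0.8 + 0.64 = 1.44, r_ac = 1.44 / 2.44 ≈ 0.590
# speedup = 15.88 / ((1 - 0.590) * 16.0) ≈ 2.42
print(calculate_speedup(0.8, 2, t_ori=15.88, t_mtp=16.0))
```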
@@ -73,7 +73,7 @@ __global__ void speculate_verify(
     const int *output_cum_offsets, const int *actual_candidate_len,
     const int real_bsz, const int max_draft_tokens, const int end_length,
     const int max_seq_len, const int max_candidate_len, const int verify_window,
-    const bool prefill_one_step_stop) {
+    const bool prefill_one_step_stop, const bool benchmark_mode) {
   const int bid = threadIdx.x;
   // verify and set stop flags
   int accept_num_now = 1;
@@ -95,6 +95,9 @@ __global__ void speculate_verify(
     // printf("seq_lens_this_time[%d]-1: %d \n",bid,
     // seq_lens_this_time[bid]-1);
     for (; i < seq_lens_this_time[bid] - 1; i++) {
+      if (benchmark_mode) {
+        break;
+      }
       if (seq_lens_encoder[bid] != 0) {
         break;
       }
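With `benchmark_mode` set, the verify loop exits before examining any draft token, so `accept_num_now` stays at 1: every step emits exactly one base-model token, and the measured inter-token latency is the pure per-step cost of running with MTP attached. The acceptance rate is then applied analytically by the benchmark script. A minimal Python sketch of that control flow (a greedy-match simplification, not the kernel itself; the real kernel also handles top-p sampling and the verify window):

```python
def verify_step(draft_tokens, target_tokens, benchmark_mode=False):
    # accept_num starts at 1: a verify step always emits one token
    # from the target model, whatever happens to the draft tokens.
    accept_num = 1
    for draft, target in zip(draft_tokens, target_tokens):
        if benchmark_mode:
            # Simulation mode: accept nothing, mirroring the early break above.
            break
        if draft != target:
            break
        accept_num += 1
    return accept_num
```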
@@ -246,7 +249,7 @@ void SpeculateVerify(
     const paddle::Tensor &output_cum_offsets,
     const paddle::Tensor &actual_candidate_len,
     const paddle::Tensor &actual_draft_token_nums, const paddle::Tensor &topp,
-    int max_seq_len, int verify_window, bool enable_topp) {
+    int max_seq_len, int verify_window, bool enable_topp, bool benchmark_mode) {
   // printf("Enter speculate update\n");
   auto bsz = accept_tokens.shape()[0];
   int real_bsz = seq_lens_this_time.shape()[0];
@@ -301,7 +304,7 @@ void SpeculateVerify(
             is_block_step.data<bool>(), output_cum_offsets.data<int>(),
             actual_candidate_len.data<int>(), real_bsz, max_draft_tokens,
             end_length, max_seq_len, max_candidate_len, verify_window,
-            prefill_one_step_stop);
+            prefill_one_step_stop, benchmark_mode);
   } else {
     speculate_verify<false, true>
         <<<1, BlockSize, 0, accept_tokens.stream()>>>(
@@ -317,7 +320,7 @@ void SpeculateVerify(
             end_tokens.data<int64_t>(), is_block_step.data<bool>(),
             output_cum_offsets.data<int>(), actual_candidate_len.data<int>(),
             real_bsz, max_draft_tokens, end_length, max_seq_len,
-            max_candidate_len, verify_window, prefill_one_step_stop);
+            max_candidate_len, verify_window, prefill_one_step_stop, benchmark_mode);
   }
 } else {
   if (enable_topp) {
@@ -335,7 +338,7 @@ void SpeculateVerify(
             end_tokens.data<int64_t>(), is_block_step.data<bool>(),
             output_cum_offsets.data<int>(), actual_candidate_len.data<int>(),
             real_bsz, max_draft_tokens, end_length, max_seq_len,
-            max_candidate_len, verify_window, prefill_one_step_stop);
+            max_candidate_len, verify_window, prefill_one_step_stop, benchmark_mode);
   } else {
     speculate_verify<false, false>
         <<<1, BlockSize, 0, accept_tokens.stream()>>>(
@@ -351,7 +354,7 @@ void SpeculateVerify(
             end_tokens.data<int64_t>(), is_block_step.data<bool>(),
             output_cum_offsets.data<int>(), actual_candidate_len.data<int>(),
             real_bsz, max_draft_tokens, end_length, max_seq_len,
-            max_candidate_len, verify_window, prefill_one_step_stop);
+            max_candidate_len, verify_window, prefill_one_step_stop, benchmark_mode);
   }
 }

@@ -366,7 +369,7 @@ PD_BUILD_STATIC_OP(speculate_verify)
             "actual_candidate_len", "actual_draft_token_nums", "topp"})
     .Outputs({"accept_tokens_out", "accept_num_out", "step_idx_out",
               "stop_flags_out"})
-    .Attrs({"max_seq_len: int", "verify_window: int", "enable_topp: bool"})
+    .Attrs({"max_seq_len: int", "verify_window: int", "enable_topp: bool", "benchmark_mode: bool"})
     .SetInplaceMap({{"accept_tokens", "accept_tokens_out"},
                     {"accept_num", "accept_num_out"},
                     {"step_idx", "step_idx_out"},
@@ -238,6 +238,10 @@ class SpeculativeConfig:
     # A trick method is currently used to enable this sharing.
     # This will be replaced with a more standardized solution in the future.
     sharing_model = None
+    # During benchmarking, we need to enforce that the number of accepted tokens is 1.
+    # This means no tokens from MTP are accepted.
+    # This ensures that the specified simulation acceptance rate is not affected.
+    benchmark_mode: bool = False


 @dataclass
@@ -337,6 +337,7 @@ class SpeculativeConfig:
         model_name_or_path (Optional[str]): Path of the model.
         quantization (str): Quantization method for draft model, default is WINT8.
         max_model_len: Optional[int]: Maximum model length for draft model.
+        benchmark_mode (bool): Whether to use benchmark mode.
     """

     def __init__(self,
@@ -345,12 +346,14 @@ class SpeculativeConfig:
                  model: Optional[str] = None,
                  quantization: Optional[str] = "WINT8",
                  max_model_len: Optional[int] = None,
+                 benchmark_mode: bool = False,
                  **kwargs):
         self.model_name_or_path = model
         self.method = method
         self.num_speculative_tokens = num_speculative_tokens
         self.quantization = quantization
         self.max_model_len = max_model_len
+        self.benchmark_mode = benchmark_mode
         # Fixed now
         self.num_gpu_block_expand_ratio = 1
         self.num_extra_cache_layer = 0
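Taken together, the two additions above let simulation mode be requested directly on the config. A hypothetical construction sketch (keyword arguments only; `method` and `num_speculative_tokens` sit earlier in the signature than this hunk shows, and the values below are placeholders, not recommendations):

```python
spec_cfg = SpeculativeConfig(
    method="mtp",                      # placeholder draft method
    num_speculative_tokens=1,          # placeholder draft-token count
    model="/path/to/mtp_draft_model",  # placeholder path
    benchmark_mode=True,               # accept no draft tokens; apply acceptance analytically
)
```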
@@ -1030,6 +1030,7 @@ class LLMEngine(object):
             f" --speculative_max_draft_token_num {self.cfg.speculative_config.num_speculative_tokens}"
             f" --speculative_model_name_or_path {self.cfg.speculative_config.model_name_or_path}"
             f" --speculative_model_quantization {self.cfg.speculative_config.quantization}"
+            f" --speculative_benchmark_mode {self.cfg.speculative_config.benchmark_mode}"
             f" --max_capture_batch_size {self.cfg.max_capture_batch_size}"
             f" --guided_decoding_backend {self.cfg.guided_decoding_backend}"
             f" --load_strategy {self.cfg.model_config.load_strategy}")
@@ -235,6 +235,7 @@ class SpeculativeSampler(nn.Layer):
             raise NotImplementedError()
         self.speculative_verify_window = fd_config.speculative_config.verify_window
         self.speculative_max_candidate_len = fd_config.speculative_config.max_candidate_len
+        self.speculative_benchmark_mode = fd_config.speculative_config.benchmark_mode

     def pre_process(self, skip_idx_list: List[int] = []):
         """ pre process before running """
@@ -309,6 +310,7 @@ class SpeculativeSampler(nn.Layer):
                 max_model_len,
                 self.speculative_verify_window,
                 True,  # enable_topp
+                self.speculative_benchmark_mode,
             )

         return None
@@ -494,6 +494,11 @@ def parse_args():
         default="WINT8",
         type=str,
     )
+    parser.add_argument(
+        "--speculative_benchmark_mode",
+        default="false",
+        type=str,
+    )
     parser.add_argument("--max_num_batched_tokens",
                         type=int,
                         default=2048,
@@ -625,6 +630,9 @@ def initialize_fd_config(config_or_args) -> FDConfig:
     speculative_config.num_speculative_tokens = getattr(config_or_args, 'speculative_max_draft_token_num', 0)
     speculative_config.model_name_or_path = getattr(config_or_args, 'speculative_model_name_or_path', None)
     speculative_config.quantization = getattr(config_or_args, 'speculative_model_quantization', None)
+    speculative_config.benchmark_mode = (
+        getattr(config_or_args, "speculative_benchmark_mode", "false").lower() == "true"
+    )

     # Update parallel config
     parallel_config.engine_pid = getattr(config_or_args, 'engine_pid', None)
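Note the round trip of this flag: `LLMEngine` interpolates the config's Python bool into the worker command line as the string `True` or `False`, `parse_args` receives it as a plain string, and `initialize_fd_config` lowercases it before comparing against `"true"`, so both spellings work. A toy illustration of the normalization (example values, not engine output):

```python
for raw in ("True", "true", "False", "false", "1"):
    print(repr(raw), "->", raw.lower() == "true")
# 'True' -> True, 'true' -> True, 'False' -> False, 'false' -> False, '1' -> False
```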