From f7cad30a3830c3226c679d19c1551cf2a69451b5 Mon Sep 17 00:00:00 2001
From: GoldPancake <56388518+Deleter-D@users.noreply.github.com>
Date: Wed, 9 Jul 2025 12:08:43 +0800
Subject: [PATCH] [Feature] Add speculative decoding simulation benchmark.
 (#2751)

* Add speculative decoding simulation benchmark

* Fix the name of the parameter
---
 benchmarks/README.md                          |  27 +++
 benchmarks/benchmark_mtp.py                   | 191 ++++++++++++++++++
 .../speculate_decoding/speculate_verify.cu    |  17 +-
 fastdeploy/config.py                          |   4 +
 fastdeploy/engine/config.py                   |   3 +
 fastdeploy/engine/engine.py                   |   1 +
 .../model_executor/layers/sample/sampler.py   |   2 +
 fastdeploy/worker/worker_process.py           |   8 +
 8 files changed, 246 insertions(+), 7 deletions(-)
 create mode 100644 benchmarks/benchmark_mtp.py

diff --git a/benchmarks/README.md b/benchmarks/README.md
index 7c65a777f..aa9858ced 100644
--- a/benchmarks/README.md
+++ b/benchmarks/README.md
@@ -105,3 +105,30 @@ python benchmark_serving.py \
     --save-result > infer_log.txt 2>&1 &
 ```
 
+### Speculative decoding performance benchmark tool
+
+#### Usage:
+
+```bash
+python benchmarks/benchmark_mtp.py \
+  --host 127.0.0.1 --port 8000 \
+  --max-concurrency 16 32 64 96 --num-prompts 256 \
+  --acceptance-rate 0.8 --draft-token-steps 1 2 3 \
+  --s_itl-base-model 15.88 22.84 16.47 16.93 \
+  --dataset-name EBChat \
+  --dataset-path ./filtered_sharedgpt_2000_input_1136_output_200_fd.json
+```
+
+#### Parameters
+
+```bash
+--host: server IP address, used to build the request URL
+--port: server HTTP port, used to build the request URL
+--max-concurrency: concurrency levels to test
+--num-prompts: total number of requests to send
+--acceptance-rate: simulated acceptance rate for speculative decoding
+--draft-token-steps: number of draft token steps to simulate
+--s_itl-base-model: per-token decode latency of the base model, obtainable with the serving benchmark above; one value per --max-concurrency entry
+--dataset-name: dataset class to use; set to "EBChat" to read a dataset dumped in FD format
+--dataset-path: path to the test dataset
+```
\ No newline at end of file
diff --git a/benchmarks/benchmark_mtp.py b/benchmarks/benchmark_mtp.py
new file mode 100644
index 000000000..65c2392a1
--- /dev/null
+++ b/benchmarks/benchmark_mtp.py
@@ -0,0 +1,191 @@
+"""
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+
+import argparse
+import asyncio
+import contextlib
+import os
+import signal
+import socket
+import subprocess
+import time
+from typing import Union
+
+import openai
+import yaml
+from benchmark_dataset import EBChatDataset, EBDataset, SampleRequest
+from benchmark_serving import benchmark
+
+
+def prepare_input_requests(
+    num_prompts: int, dataset_name: str, dataset_path: str
+) -> Union[EBDataset, EBChatDataset]:
+    dataset_mapping = {
+        "EB": lambda: EBDataset(dataset_path=dataset_path).sample(
+            num_requests=num_prompts
+        ),
+        "EBChat": lambda: EBChatDataset(dataset_path=dataset_path).sample(
+            num_requests=num_prompts
+        ),
+    }
+
+    try:
+        input_requests = dataset_mapping[dataset_name]()
+    except KeyError as err:
+        raise ValueError(f"Unknown dataset: {dataset_name}") from err
+
+    return input_requests
+
+
+class FakeTokenizer:
+    def encode(self, text: str, add_special_tokens: bool = False):
+        return []
+
+
+def send_one_batch(base_url, max_concurrency, input_requests, disable_tqdm):
+    selected_percentile_metrics = ["s_itl"]
+    selected_percentiles = []
+    # Run benchmark
+    results = asyncio.run(
+        benchmark(
+            backend="openai-chat",
+            api_url=f"{base_url}/v1/chat/completions",
+            base_url=base_url,
+            model_id="default",
+            model_name="default",
+            input_requests=input_requests,
+            hyper_parameters={},
+            logprobs=None,
+            request_rate=float("inf"),
+            burstiness=1.0,
+            disable_tqdm=disable_tqdm,
+            profile=False,
+            selected_percentile_metrics=selected_percentile_metrics,
+            selected_percentiles=selected_percentiles,
+            ignore_eos=False,
+            goodput_config_dict=None,
+            max_concurrency=max_concurrency,
+            lora_modules=None,
+            extra_body=None,
+        )
+    )
+
+    record = {
+        "mean_s_itl_ms": results["mean_s_itl_ms"],
+    }
+
+    return record
+
+
+def calculate_speedup(acceptance_rate, draft_token_step, t_ori, t_mtp):
+    # Expected number of draft tokens accepted per verify step: a + a^2 + ... + a^k.
+    tmp = 0.0
+    for i in range(draft_token_step):
+        tmp += pow(acceptance_rate, i + 1)
+
+    r_ac = tmp / (1 + tmp)
+
+    return t_ori / ((1 - r_ac) * t_mtp)
+
+
+def main(args):
+    base_url = f"http://{args.host}:{args.port}"
+
+    input_requests = prepare_input_requests(
+        args.num_prompts, args.dataset_name, args.dataset_path
+    )
+
+    if len(args.max_concurrency) != len(args.s_itl_base_model):
+        raise ValueError("--max-concurrency should have the same length as --s_itl-base-model")
+
+    for max_concurrency, s_itl in zip(args.max_concurrency, args.s_itl_base_model):
+        # Warmup
+        print("Starting warmup...")
+        with open(os.devnull, "w") as f:
+            with contextlib.redirect_stdout(f):
+                send_one_batch(base_url, max_concurrency, input_requests[0:max_concurrency], True)
+
+        # Benchmark
+        record = send_one_batch(base_url, max_concurrency, input_requests, False)
+
+        metric_header = "Speed up"
+        print("{s:{c}^{n}}".format(s=metric_header, n=50, c="-"))
+        for draft_token_step in args.draft_token_steps:
+            speedup = calculate_speedup(
+                args.acceptance_rate,
+                draft_token_step,
+                s_itl,
+                record["mean_s_itl_ms"],
+            )
+            print(
+                "{:<40} {:<10.2f}".format(
+                    f"Speed up on {draft_token_step} steps draft", speedup
+                )
+            )
+        print("=" * 50)
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--host",
+        type=str,
+        default="127.0.0.1",
+    )
+    parser.add_argument(
+        "--port",
+        type=str,
+        default="8000",
+    )
+    parser.add_argument(
+        "--max-concurrency",
+        type=int,
+        nargs="+",
+        default=(1, 2, 4, 8, 16, 32),
+    )
+    parser.add_argument(
+        "--num-prompts",
+        type=int,
+        default=128,
+    )
+    parser.add_argument(
+        "--acceptance-rate",
+        type=float,
+        default=0.8,
+    )
+    
parser.add_argument( + "--draft-token-steps", + type=int, + nargs="+", + default=(1, 2), + ) + parser.add_argument( + "--s_itl-base-model", + type=float, + nargs="+", + ) + parser.add_argument( + "--dataset-name", + type=str, + default="EBChat", + ) + parser.add_argument( + "--dataset-path", + type=str, + ) + args = parser.parse_args() + + main(args) diff --git a/custom_ops/gpu_ops/speculate_decoding/speculate_verify.cu b/custom_ops/gpu_ops/speculate_decoding/speculate_verify.cu index 509ce99c5..730684374 100644 --- a/custom_ops/gpu_ops/speculate_decoding/speculate_verify.cu +++ b/custom_ops/gpu_ops/speculate_decoding/speculate_verify.cu @@ -73,7 +73,7 @@ __global__ void speculate_verify( const int *output_cum_offsets, const int *actual_candidate_len, const int real_bsz, const int max_draft_tokens, const int end_length, const int max_seq_len, const int max_candidate_len, const int verify_window, - const bool prefill_one_step_stop) { + const bool prefill_one_step_stop, const bool benchmark_mode) { const int bid = threadIdx.x; // verify and set stop flags int accept_num_now = 1; @@ -95,6 +95,9 @@ __global__ void speculate_verify( // printf("seq_lens_this_time[%d]-1: %d \n",bid, // seq_lens_this_time[bid]-1); for (; i < seq_lens_this_time[bid] - 1; i++) { + if (benchmark_mode) { + break; + } if (seq_lens_encoder[bid] != 0) { break; } @@ -246,7 +249,7 @@ void SpeculateVerify( const paddle::Tensor &output_cum_offsets, const paddle::Tensor &actual_candidate_len, const paddle::Tensor &actual_draft_token_nums, const paddle::Tensor &topp, - int max_seq_len, int verify_window, bool enable_topp) { + int max_seq_len, int verify_window, bool enable_topp, bool benchmark_mode) { // printf("Enter speculate update\n"); auto bsz = accept_tokens.shape()[0]; int real_bsz = seq_lens_this_time.shape()[0]; @@ -301,7 +304,7 @@ void SpeculateVerify( is_block_step.data(), output_cum_offsets.data(), actual_candidate_len.data(), real_bsz, max_draft_tokens, end_length, max_seq_len, max_candidate_len, verify_window, - prefill_one_step_stop); + prefill_one_step_stop, benchmark_mode); } else { speculate_verify <<<1, BlockSize, 0, accept_tokens.stream()>>>( @@ -317,7 +320,7 @@ void SpeculateVerify( end_tokens.data(), is_block_step.data(), output_cum_offsets.data(), actual_candidate_len.data(), real_bsz, max_draft_tokens, end_length, max_seq_len, - max_candidate_len, verify_window, prefill_one_step_stop); + max_candidate_len, verify_window, prefill_one_step_stop, benchmark_mode); } } else { if (enable_topp) { @@ -335,7 +338,7 @@ void SpeculateVerify( end_tokens.data(), is_block_step.data(), output_cum_offsets.data(), actual_candidate_len.data(), real_bsz, max_draft_tokens, end_length, max_seq_len, - max_candidate_len, verify_window, prefill_one_step_stop); + max_candidate_len, verify_window, prefill_one_step_stop, benchmark_mode); } else { speculate_verify <<<1, BlockSize, 0, accept_tokens.stream()>>>( @@ -351,7 +354,7 @@ void SpeculateVerify( end_tokens.data(), is_block_step.data(), output_cum_offsets.data(), actual_candidate_len.data(), real_bsz, max_draft_tokens, end_length, max_seq_len, - max_candidate_len, verify_window, prefill_one_step_stop); + max_candidate_len, verify_window, prefill_one_step_stop, benchmark_mode); } } @@ -366,7 +369,7 @@ PD_BUILD_STATIC_OP(speculate_verify) "actual_candidate_len", "actual_draft_token_nums", "topp"}) .Outputs({"accept_tokens_out", "accept_num_out", "step_idx_out", "stop_flags_out"}) - .Attrs({"max_seq_len: int", "verify_window: int", "enable_topp: bool"}) + .Attrs({"max_seq_len: 
int", "verify_window: int", "enable_topp: bool", "benchmark_mode: bool"}) .SetInplaceMap({{"accept_tokens", "accept_tokens_out"}, {"accept_num", "accept_num_out"}, {"step_idx", "step_idx_out"}, diff --git a/fastdeploy/config.py b/fastdeploy/config.py index 446e59298..4d513a21b 100644 --- a/fastdeploy/config.py +++ b/fastdeploy/config.py @@ -238,6 +238,10 @@ class SpeculativeConfig: # A trick method is currently used to enable this sharing. # This will be replaced with a more standardized solution in the future. sharing_model = None + # During benchmarking, we need to enforce that the number of accepted tokens is 1. + # This means no tokens from MTP are accepted. + # This ensures that the specified simulation acceptance rate is not affected. + benchmark_mode: bool = False @dataclass diff --git a/fastdeploy/engine/config.py b/fastdeploy/engine/config.py index eb95f2bf1..bac38bfb8 100644 --- a/fastdeploy/engine/config.py +++ b/fastdeploy/engine/config.py @@ -337,6 +337,7 @@ class SpeculativeConfig: model_name_or_path (Optional[str]): Path of the model. quantization (str): Quantization method for draft model, default is WINT8. max_model_len: Optional[int]: Maximum model length for draft model. + benchmark_mode (bool): Whether to use benchmark mode. """ def __init__(self, @@ -345,12 +346,14 @@ class SpeculativeConfig: model: Optional[str] = None, quantization: Optional[str] = "WINT8", max_model_len: Optional[int] = None, + benchmark_mode: bool = False, **kwargs): self.model_name_or_path = model self.method = method self.num_speculative_tokens = num_speculative_tokens self.quantization = quantization self.max_model_len = max_model_len + self.benchmark_mode = benchmark_mode # Fixed now self.num_gpu_block_expand_ratio = 1 self.num_extra_cache_layer = 0 diff --git a/fastdeploy/engine/engine.py b/fastdeploy/engine/engine.py index 8549502cc..e95d0d0b1 100644 --- a/fastdeploy/engine/engine.py +++ b/fastdeploy/engine/engine.py @@ -1030,6 +1030,7 @@ class LLMEngine(object): f" --speculative_max_draft_token_num {self.cfg.speculative_config.num_speculative_tokens}" f" --speculative_model_name_or_path {self.cfg.speculative_config.model_name_or_path}" f" --speculative_model_quantization {self.cfg.speculative_config.quantization}" + f" --speculative_benchmark_mode {self.cfg.speculative_config.benchmark_mode}" f" --max_capture_batch_size {self.cfg.max_capture_batch_size}" f" --guided_decoding_backend {self.cfg.guided_decoding_backend}" f" --load_strategy {self.cfg.model_config.load_strategy}") diff --git a/fastdeploy/model_executor/layers/sample/sampler.py b/fastdeploy/model_executor/layers/sample/sampler.py index 3d2553446..2ee2a8fd1 100644 --- a/fastdeploy/model_executor/layers/sample/sampler.py +++ b/fastdeploy/model_executor/layers/sample/sampler.py @@ -235,6 +235,7 @@ class SpeculativeSampler(nn.Layer): raise NotImplementedError() self.speculative_verify_window = fd_config.speculative_config.verify_window self.speculative_max_candidate_len = fd_config.speculative_config.max_candidate_len + self.speculative_benchmark_mode = fd_config.speculative_config.benchmark_mode def pre_process(self, skip_idx_list: List[int] = []): """ pre process before running """ @@ -309,6 +310,7 @@ class SpeculativeSampler(nn.Layer): max_model_len, self.speculative_verify_window, True, # enable_topp + self.speculative_benchmark_mode, ) return None diff --git a/fastdeploy/worker/worker_process.py b/fastdeploy/worker/worker_process.py index fef56089b..400e7097c 100644 --- a/fastdeploy/worker/worker_process.py +++ 
b/fastdeploy/worker/worker_process.py @@ -494,6 +494,11 @@ def parse_args(): default="WINT8", type=str, ) + parser.add_argument( + "--speculative_benchmark_mode", + default="false", + type=str, + ) parser.add_argument("--max_num_batched_tokens", type=int, default=2048, @@ -625,6 +630,9 @@ def initialize_fd_config(config_or_args) -> FDConfig: speculative_config.num_speculative_tokens = getattr(config_or_args, 'speculative_max_draft_token_num', 0) speculative_config.model_name_or_path = getattr(config_or_args, 'speculative_model_name_or_path', None) speculative_config.quantization = getattr(config_or_args, 'speculative_model_quantization', None) + speculative_config.benchmark_mode = ( + getattr(config_or_args, "speculative_benchmark_mode", "false").lower() == "true" + ) # Update parallel config parallel_config.engine_pid = getattr(config_or_args, 'engine_pid', None)
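
A note on how the reported speedup is derived. With a simulated per-token acceptance rate `a` and `k` draft tokens per verify step, draft token `i` is accepted only if the first `i` draft tokens are all accepted, so each verify step emits on average `1 + a + a^2 + ... + a^k` tokens (the leading `1` is the token sampled by the target model itself). `calculate_speedup` expresses this as the accepted-token ratio `r_ac` and rescales the measured inter-token latency. Below is a minimal, self-contained sketch of that estimate, mirroring the patch's `calculate_speedup`; the function name and the sample numbers are illustrative, not part of the patch:

```python
# Sketch of the estimate printed by benchmarks/benchmark_mtp.py.
# t_ori: per-token decode latency (ms) of the base model (--s_itl-base-model).
# t_mtp: per-token decode latency (ms) measured by this tool with benchmark_mode on.
def estimate_speedup(acceptance_rate: float, draft_token_step: int,
                     t_ori: float, t_mtp: float) -> float:
    # Expected accepted draft tokens per verify step: a + a^2 + ... + a^k.
    accepted = sum(acceptance_rate ** (i + 1) for i in range(draft_token_step))
    # Fraction of all emitted tokens contributed by accepted draft tokens.
    r_ac = accepted / (1 + accepted)
    # The effective per-token latency shrinks by (1 - r_ac).
    return t_ori / ((1 - r_ac) * t_mtp)


if __name__ == "__main__":
    # With equal latencies and acceptance rate 0.8, this gives roughly
    # 1.80x, 2.44x and 2.95x for 1, 2 and 3 draft steps.
    for steps in (1, 2, 3):
        print(steps, round(estimate_speedup(0.8, steps, 15.88, 15.88), 2))
```

This is also why the server-side `benchmark_mode` matters: the `speculate_verify` change above makes verification accept no draft tokens, so the measured `t_mtp` reflects the full cost of running the draft model and verification, while the acceptance rate stays a controlled simulation parameter rather than something the real draft model determines.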
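
For completeness, a hypothetical sketch of enabling the flag from the engine-side configuration. Only `benchmark_mode` is introduced by this patch; the method name, draft token count, and model path below are assumptions for illustration:

```python
# Hypothetical usage sketch; values other than benchmark_mode are assumed.
from fastdeploy.engine.config import SpeculativeConfig

spec_cfg = SpeculativeConfig(
    method="mtp",                   # assumed speculative method name
    num_speculative_tokens=1,       # assumed draft token count
    model="/path/to/draft_model",   # assumed draft model path
    benchmark_mode=True,            # new flag added by this patch
)
# LLMEngine forwards the flag to the worker as `--speculative_benchmark_mode True`,
# and worker_process.py parses that string back into SpeculativeConfig.benchmark_mode.
```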