	[Feature] Add speculative decoding simulation benchmark. (#2751)
* Add speculative decoding simulation benchmark
* Fix the name of the parameter
@@ -105,3 +105,30 @@ python benchmark_serving.py \
   --save-result > infer_log.txt 2>&1 &
 ```
 
+### Speculative Decoding Benchmark Tool
+
+#### Usage:
+
+```bash
+python benchmarks/benchmark_mtp.py \
+  --host 127.0.0.1 --port 8000 \
+  --max-concurrency 16 32 64 96 --num-prompts 256 \
+  --acceptance-rate 0.8 --draft-token-steps 1 2 3 \
+  --s_itl-base-model 15.88 22.84 16.47 16.93 \
+  --dataset-name EBChat \
+  --dataset-path ./filtered_sharedgpt_2000_input_1136_output_200_fd.json
+```
+
+#### Parameters
+
+```bash
+--host: server IP address, used to build the request URL
+--port: server HTTP port, used to build the request URL
+--max-concurrency: concurrency levels to benchmark
+--num-prompts: total number of requests to send
+--acceptance-rate: simulated acceptance rate for speculative decoding
+--draft-token-steps: numbers of draft steps for speculative decoding
+--s_itl-base-model: decode latency of the base model, obtainable with the serving benchmark above; values correspond one-to-one with the --max-concurrency list
+--dataset-name: dataset class to use; set to "EBChat" to read a dataset dumped in FD format
+--dataset-path: path to the test dataset
+```
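The speedup the tool prints is computed from the simulated acceptance rate rather than measured end to end. Here is a minimal sketch of that calculation (it mirrors `calculate_speedup` in `benchmarks/benchmark_mtp.py` below; the latency numbers in the example are illustrative only):

```python
def estimate_speedup(acceptance_rate: float, draft_token_step: int,
                     t_ori: float, t_mtp: float) -> float:
    # Expected number of draft tokens accepted per verify step:
    # acceptance_rate^1 + ... + acceptance_rate^draft_token_step.
    extra = sum(acceptance_rate ** (i + 1) for i in range(draft_token_step))
    r_ac = extra / (1 + extra)  # fraction of decode steps saved
    # t_ori: base-model inter-token latency; t_mtp: latency measured with MTP on.
    return t_ori / ((1 - r_ac) * t_mtp)

# Acceptance rate 0.8 with 2 draft steps, base ITL 16 ms, MTP ITL 20 ms:
# 16 * (1 + 0.8 + 0.64) / 20 ≈ 1.95x
print(estimate_speedup(0.8, 2, 16.0, 20.0))
```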
benchmarks/benchmark_mtp.py (new file, 191 lines)
@@ -0,0 +1,191 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""

import argparse
import asyncio
import contextlib
import os
import signal
import socket
import subprocess
import time
from typing import Union

import openai
import yaml
from benchmark_dataset import EBChatDataset, EBDataset, SampleRequest
from benchmark_serving import benchmark


def prepare_input_requests(
    num_prompts: int, dataset_name: str, dataset_path: str
) -> Union[EBDataset, EBChatDataset]:
    dataset_mapping = {
        "EB": lambda: EBDataset(dataset_path=dataset_path).sample(
            num_requests=num_prompts
        ),
        "EBChat": lambda: EBChatDataset(dataset_path=dataset_path).sample(
            num_requests=num_prompts
        ),
    }

    try:
        input_requests = dataset_mapping[dataset_name]()
    except KeyError as err:
        raise ValueError(f"Unknown dataset: {dataset_name}") from err

    return input_requests


class FakeTokenizer:
    def encode(self, text: str, add_special_tokens: bool = False):
        return []


def send_one_batch(base_url, max_concurrency, input_requests, disable_tqdm):
    # Collect only the s_itl metric (per-token decode latency).
    selected_percentile_metrics = ["s_itl"]
    selected_percentiles = []
    # Run benchmark
    results = asyncio.run(
        benchmark(
            backend="openai-chat",
            api_url=f"{base_url}/v1/chat/completions",
            base_url=base_url,
            model_id="default",
            model_name="default",
            input_requests=input_requests,
            hyper_parameters={},
            logprobs=None,
            request_rate=float("inf"),
            burstiness=1.0,
            disable_tqdm=disable_tqdm,
            profile=False,
            selected_percentile_metrics=selected_percentile_metrics,
            selected_percentiles=selected_percentiles,
            ignore_eos=False,
            goodput_config_dict=None,
            max_concurrency=max_concurrency,
            lora_modules=None,
            extra_body=None,
        )
    )

    record = {
        "mean_s_itl_ms": results["mean_s_itl_ms"],
    }

    return record


def calculate_speedup(acceptance_rate, draft_token_step, t_ori, t_mtp):
    # Expected number of draft tokens accepted per verify step:
    # acceptance_rate^1 + ... + acceptance_rate^draft_token_step.
    tmp = 0.0
    for i in range(draft_token_step):
        tmp += pow(acceptance_rate, i + 1)

    # Fraction of decode steps saved by accepted draft tokens.
    r_ac = tmp / (1 + tmp)

    return t_ori / ((1 - r_ac) * t_mtp)


def main(args):
    base_url = f"http://{args.host}:{args.port}"

    input_requests = prepare_input_requests(
        args.num_prompts, args.dataset_name, args.dataset_path
    )

    if len(args.max_concurrency) != len(args.s_itl_base_model):
        raise ValueError(
            "--max-concurrency and --s_itl-base-model must have the same length"
        )

    for max_concurrency, s_itl in zip(args.max_concurrency, args.s_itl_base_model):
        # Warmup
        print("Starting warmup...")
        with open(os.devnull, "w") as f:
            with contextlib.redirect_stdout(f):
                send_one_batch(base_url, max_concurrency, input_requests[0:max_concurrency], True)

        # Benchmark
        record = send_one_batch(base_url, max_concurrency, input_requests, False)

        metric_header = "Speed up"
        print("{s:{c}^{n}}".format(s=metric_header, n=50, c="-"))
        for draft_token_step in args.draft_token_steps:
            speedup = calculate_speedup(
                args.acceptance_rate,
                draft_token_step,
                s_itl,
                record["mean_s_itl_ms"],
            )
            print(
                "{:<40} {:<10.2f}".format(
                    f"Speed up on {draft_token_step} steps draft", speedup
                )
            )
        print("=" * 50)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--host",
        type=str,
        default="127.0.0.1",
    )
    parser.add_argument(
        "--port",
        type=str,
        default="8000",
    )
    parser.add_argument(
        "--max-concurrency",
        type=int,
        nargs="+",
        default=(1, 2, 4, 8, 16, 32),
    )
    parser.add_argument(
        "--num-prompts",
        type=int,
        default=128,
    )
    parser.add_argument(
        "--acceptance-rate",
        type=float,
        default=0.8,
    )
    parser.add_argument(
        "--draft-token-steps",
        type=int,
        nargs="+",
        default=(1, 2),
    )
    parser.add_argument(
        "--s_itl-base-model",
        type=float,
        nargs="+",
    )
    parser.add_argument(
        "--dataset-name",
        type=str,
        default="EBChat",
    )
    parser.add_argument(
        "--dataset-path",
        type=str,
    )
    args = parser.parse_args()

    main(args)
@@ -73,7 +73,7 @@ __global__ void speculate_verify(
     const int *output_cum_offsets, const int *actual_candidate_len,
     const int real_bsz, const int max_draft_tokens, const int end_length,
     const int max_seq_len, const int max_candidate_len, const int verify_window,
-    const bool prefill_one_step_stop) {
+    const bool prefill_one_step_stop, const bool benchmark_mode) {
   const int bid = threadIdx.x;
   // verify and set stop flags
   int accept_num_now = 1;
@@ -95,6 +95,9 @@
       // printf("seq_lens_this_time[%d]-1: %d \n",bid,
       // seq_lens_this_time[bid]-1);
       for (; i < seq_lens_this_time[bid] - 1; i++) {
+        if (benchmark_mode) {
+          break;
+        }
         if (seq_lens_encoder[bid] != 0) {
           break;
         }
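With `benchmark_mode` set, the added early `break` exits the verify loop before any draft token can be accepted, so `accept_num_now` stays at 1 and the measured latency reflects pure MTP overhead at the simulated acceptance rate. A deliberately simplified Python illustration of that control flow (hypothetical helper, not part of the commit; the real kernel also checks candidate tokens and top-p):

```python
def accepted_tokens(draft_matches: list, benchmark_mode: bool) -> int:
    # One token per step is always produced by the target model.
    accept_num = 1
    for matched in draft_matches:
        # benchmark_mode mimics the kernel's early break: no draft
        # tokens are ever counted as accepted.
        if benchmark_mode or not matched:
            break
        accept_num += 1
    return accept_num

print(accepted_tokens([True, True, False], benchmark_mode=False))  # 3
print(accepted_tokens([True, True, False], benchmark_mode=True))   # 1
```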
@@ -246,7 +249,7 @@ void SpeculateVerify(
     const paddle::Tensor &output_cum_offsets,
     const paddle::Tensor &actual_candidate_len,
     const paddle::Tensor &actual_draft_token_nums, const paddle::Tensor &topp,
-    int max_seq_len, int verify_window, bool enable_topp) {
+    int max_seq_len, int verify_window, bool enable_topp, bool benchmark_mode) {
   //   printf("Enter speculate update\n");
   auto bsz = accept_tokens.shape()[0];
   int real_bsz = seq_lens_this_time.shape()[0];
@@ -301,7 +304,7 @@
           is_block_step.data<bool>(), output_cum_offsets.data<int>(),
           actual_candidate_len.data<int>(), real_bsz, max_draft_tokens,
           end_length, max_seq_len, max_candidate_len, verify_window,
-          prefill_one_step_stop);
+          prefill_one_step_stop, benchmark_mode);
     } else {
       speculate_verify<false, true>
           <<<1, BlockSize, 0, accept_tokens.stream()>>>(
@@ -317,7 +320,7 @@
               end_tokens.data<int64_t>(), is_block_step.data<bool>(),
               output_cum_offsets.data<int>(), actual_candidate_len.data<int>(),
               real_bsz, max_draft_tokens, end_length, max_seq_len,
-              max_candidate_len, verify_window, prefill_one_step_stop);
+              max_candidate_len, verify_window, prefill_one_step_stop, benchmark_mode);
     }
   } else {
     if (enable_topp) {
@@ -335,7 +338,7 @@
               end_tokens.data<int64_t>(), is_block_step.data<bool>(),
               output_cum_offsets.data<int>(), actual_candidate_len.data<int>(),
               real_bsz, max_draft_tokens, end_length, max_seq_len,
-              max_candidate_len, verify_window, prefill_one_step_stop);
+              max_candidate_len, verify_window, prefill_one_step_stop, benchmark_mode);
     } else {
       speculate_verify<false, false>
           <<<1, BlockSize, 0, accept_tokens.stream()>>>(
@@ -351,7 +354,7 @@
               end_tokens.data<int64_t>(), is_block_step.data<bool>(),
               output_cum_offsets.data<int>(), actual_candidate_len.data<int>(),
               real_bsz, max_draft_tokens, end_length, max_seq_len,
-              max_candidate_len, verify_window, prefill_one_step_stop);
+              max_candidate_len, verify_window, prefill_one_step_stop, benchmark_mode);
     }
   }
 
@@ -366,7 +369,7 @@ PD_BUILD_STATIC_OP(speculate_verify)
              "actual_candidate_len", "actual_draft_token_nums", "topp"})
     .Outputs({"accept_tokens_out", "accept_num_out", "step_idx_out",
               "stop_flags_out"})
-    .Attrs({"max_seq_len: int", "verify_window: int", "enable_topp: bool"})
+    .Attrs({"max_seq_len: int", "verify_window: int", "enable_topp: bool", "benchmark_mode: bool"})
     .SetInplaceMap({{"accept_tokens", "accept_tokens_out"},
                     {"accept_num", "accept_num_out"},
                     {"step_idx", "step_idx_out"},

@@ -238,6 +238,10 @@ class SpeculativeConfig:
     # A trick method is currently used to enable this sharing.
     # This will be replaced with a more standardized solution in the future.
     sharing_model = None
+    # During benchmarking, we need to enforce that the number of accepted tokens is 1.
+    # This means no tokens from MTP are accepted.
+    # This ensures that the specified simulation acceptance rate is not affected.
+    benchmark_mode: bool = False
 
 
 @dataclass

@@ -337,6 +337,7 @@ class SpeculativeConfig:
         model_name_or_path (Optional[str]): Path of the model.
         quantization (str): Quantization method for draft model, default is WINT8.
         max_model_len (Optional[int]): Maximum model length for draft model.
+        benchmark_mode (bool): Whether to use benchmark mode.
     """
 
     def __init__(self,
@@ -345,12 +346,14 @@
                  model: Optional[str] = None,
                  quantization: Optional[str] = "WINT8",
                  max_model_len: Optional[int] = None,
+                 benchmark_mode: bool = False,
                  **kwargs):
         self.model_name_or_path = model
         self.method = method
         self.num_speculative_tokens = num_speculative_tokens
         self.quantization = quantization
         self.max_model_len = max_model_len
+        self.benchmark_mode = benchmark_mode
         # Fixed now
         self.num_gpu_block_expand_ratio = 1
         self.num_extra_cache_layer = 0
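For illustration, constructing the config with the new field might look like this (a sketch only; the `method` name, draft depth, and draft-model path are assumed values, not taken from the commit):

```python
spec_cfg = SpeculativeConfig(
    method="mtp",                  # assumed speculative method name
    num_speculative_tokens=1,      # assumed draft depth
    model="/path/to/draft_model",  # hypothetical path
    benchmark_mode=True,           # field added by this commit
)
assert spec_cfg.benchmark_mode
```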
@@ -1030,6 +1030,7 @@ class LLMEngine(object):
             f" --speculative_max_draft_token_num {self.cfg.speculative_config.num_speculative_tokens}"
             f" --speculative_model_name_or_path {self.cfg.speculative_config.model_name_or_path}"
             f" --speculative_model_quantization {self.cfg.speculative_config.quantization}"
+            f" --speculative_benchmark_mode {self.cfg.speculative_config.benchmark_mode}"
             f" --max_capture_batch_size {self.cfg.max_capture_batch_size}"
             f" --guided_decoding_backend {self.cfg.guided_decoding_backend}"
             f" --load_strategy {self.cfg.model_config.load_strategy}")

@@ -235,6 +235,7 @@ class SpeculativeSampler(nn.Layer):
             raise NotImplementedError()
         self.speculative_verify_window = fd_config.speculative_config.verify_window
         self.speculative_max_candidate_len = fd_config.speculative_config.max_candidate_len
+        self.speculative_benchmark_mode = fd_config.speculative_config.benchmark_mode
 
     def pre_process(self, skip_idx_list: List[int] = []):
         """ pre process before running """
@@ -309,6 +310,7 @@
             max_model_len,
             self.speculative_verify_window,
             True,  # enable_topp
+            self.speculative_benchmark_mode,
         )
 
         return None

@@ -494,6 +494,11 @@ def parse_args():
         default="WINT8",
         type=str,
     )
+    parser.add_argument(
+        "--speculative_benchmark_mode",
+        default="false",
+        type=str,
+    )
     parser.add_argument("--max_num_batched_tokens",
                         type=int,
                         default=2048,
@@ -625,6 +630,9 @@ def initialize_fd_config(config_or_args) -> FDConfig:
     speculative_config.num_speculative_tokens = getattr(config_or_args, 'speculative_max_draft_token_num', 0)
     speculative_config.model_name_or_path = getattr(config_or_args, 'speculative_model_name_or_path', None)
     speculative_config.quantization = getattr(config_or_args, 'speculative_model_quantization', None)
+    speculative_config.benchmark_mode = (
+        getattr(config_or_args, "speculative_benchmark_mode", "false").lower() == "true"
+    )
 
     # Update parallel config
     parallel_config.engine_pid = getattr(config_or_args, 'engine_pid', None)
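The flag crosses the engine/worker boundary as a string: the f-string in `LLMEngine` above serializes the config bool as "True"/"False", and the worker lowercases it before comparing. A self-contained sketch of that round trip:

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--speculative_benchmark_mode", default="false", type=str)

# The engine embeds the bool with an f-string, so the worker receives the
# Python repr "True"/"False"; .lower() makes the comparison match either case.
engine_side_value = f"{True}"  # -> "True"
args = parser.parse_args(["--speculative_benchmark_mode", engine_side_value])
benchmark_mode = args.speculative_benchmark_mode.lower() == "true"
print(benchmark_mode)  # True
```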