"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""

# Benchmark script: for each requested concurrency level it measures the mean
# ``s_itl`` of an OpenAI-compatible chat endpoint and prints an estimated
# speedup over a user-supplied base-model ``s_itl`` for several draft-step
# counts.

import argparse
import asyncio
import contextlib
import os
import signal
import socket
import subprocess
import time
from typing import Union

import openai
import yaml
from benchmark_dataset import EBChatDataset, EBDataset, SampleRequest
from benchmark_serving import benchmark


def prepare_input_requests(
    num_prompts: int, dataset_name: str, dataset_path: str
) -> Union[EBDataset, EBChatDataset]:
    """Sample ``num_prompts`` requests from the named dataset.

    Args:
        num_prompts: Passed to the dataset's ``sample`` as ``num_requests``.
        dataset_name: Either ``"EB"`` or ``"EBChat"``.
        dataset_path: Passed to the dataset constructor as ``dataset_path``.

    Returns:
        Whatever the dataset's ``sample`` method returns.
        NOTE(review): the annotation names the dataset classes, but ``main``
        slices the result like a sequence -- confirm the real return type.

    Raises:
        ValueError: If ``dataset_name`` is not a known dataset.
    """
    dataset_classes = {
        "EB": EBDataset,
        "EBChat": EBChatDataset,
    }

    # Only the lookup is guarded, so a KeyError raised inside the dataset code
    # itself is not misreported as an unknown dataset name.
    try:
        dataset_cls = dataset_classes[dataset_name]
    except KeyError as err:
        raise ValueError(f"Unknown dataset: {dataset_name}") from err

    return dataset_cls(dataset_path=dataset_path).sample(num_requests=num_prompts)


class FakeTokenizer:
    """Tokenizer stand-in whose ``encode`` always yields no tokens."""

    def encode(self, text: str, add_special_tokens: bool = False):
        """Return an empty token list regardless of the arguments."""
        return []


def send_one_batch(base_url, max_concurrency, input_requests, disable_tqdm):
    """Run one benchmark pass and return its mean ``s_itl``.

    The requests are sent through ``benchmark`` with the ``openai-chat``
    backend to ``{base_url}/v1/chat/completions`` at an infinite request rate,
    capped by ``max_concurrency``.

    Args:
        base_url: Server root URL, e.g. ``http://127.0.0.1:8000``.
        max_concurrency: Forwarded to ``benchmark`` as ``max_concurrency``.
        input_requests: Requests to send.
        disable_tqdm: Forwarded to ``benchmark`` as ``disable_tqdm``.

    Returns:
        A dict with the single key ``"mean_s_itl_ms"`` copied from the
        benchmark results.
    """
    results = asyncio.run(
        benchmark(
            backend="openai-chat",
            api_url=f"{base_url}/v1/chat/completions",
            base_url=base_url,
            model_id="default",
            model_name="default",
            input_requests=input_requests,
            hyper_parameters={},
            logprobs=None,
            request_rate=float("inf"),
            burstiness=1.0,
            disable_tqdm=disable_tqdm,
            profile=False,
            selected_percentile_metrics=["s_itl"],
            selected_percentiles=[],
            ignore_eos=False,
            goodput_config_dict=None,
            max_concurrency=max_concurrency,
            lora_modules=None,
            extra_body=None,
        )
    )

    return {"mean_s_itl_ms": results["mean_s_itl_ms"]}


def calculate_speedup(acceptance_rate, draft_token_step, t_ori, t_mtp):
    """Estimate the speedup for a given acceptance rate and draft depth.

    With ``s = sum(acceptance_rate ** k for k in 1..draft_token_step)`` and
    ``r_ac = s / (1 + s)``, the result is ``t_ori / ((1 - r_ac) * t_mtp)``.

    Args:
        acceptance_rate: Per-step acceptance rate used in the geometric sum.
        draft_token_step: Number of draft steps (terms in the sum).
        t_ori: Baseline latency.
        t_mtp: Measured latency; must be in the same unit as ``t_ori``.

    Returns:
        The estimated speedup as a float.

    Raises:
        ZeroDivisionError: If ``t_mtp`` is 0.
    """
    accepted = sum(
        (pow(acceptance_rate, step + 1) for step in range(draft_token_step)),
        0.0,
    )
    r_ac = accepted / (1 + accepted)
    return t_ori / ((1 - r_ac) * t_mtp)


def main(args):
    """Benchmark every concurrency level and print the estimated speedups.

    Args:
        args: Parsed command-line namespace providing ``host``, ``port``,
            ``num_prompts``, ``dataset_name``, ``dataset_path``,
            ``max_concurrency``, ``s_itl_base_model``, ``acceptance_rate``
            and ``draft_token_steps``.

    Raises:
        ValueError: If ``max_concurrency`` and ``s_itl_base_model`` differ in
            length, or the dataset name is unknown.
    """
    base_url = f"http://{args.host}:{args.port}"

    input_requests = prepare_input_requests(
        args.num_prompts, args.dataset_name, args.dataset_path
    )

    num_levels = len(args.max_concurrency)
    num_baselines = len(args.s_itl_base_model)
    if num_levels != num_baselines:
        raise ValueError(
            "--max-concurrency and --s_itl-base-model must have the same "
            f"number of values (got {num_levels} and {num_baselines})"
        )

    for max_concurrency, s_itl in zip(args.max_concurrency, args.s_itl_base_model):
        # Warmup: one request per concurrency slot, with the benchmark's
        # stdout discarded so only the real run is reported.
        print("Starting warmup...")
        with open(os.devnull, "w") as devnull, contextlib.redirect_stdout(devnull):
            send_one_batch(
                base_url, max_concurrency, input_requests[0:max_concurrency], True
            )

        # Benchmark
        record = send_one_batch(base_url, max_concurrency, input_requests, False)

        metric_header = "Speed up"
        print(f"{metric_header:-^50}")
        for draft_token_step in args.draft_token_steps:
            speedup = calculate_speedup(
                args.acceptance_rate,
                draft_token_step,
                s_itl,
                record["mean_s_itl_ms"],
            )
            label = f"Speed up on {draft_token_step} steps draft"
            print(f"{label:<40} {speedup:<10.2f}")
        print("=" * 50)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--host",
        type=str,
        default="127.0.0.1",
        help="Host of the serving endpoint.",
    )
    parser.add_argument(
        "--port",
        type=str,
        default="8000",
        help="Port of the serving endpoint.",
    )
    parser.add_argument(
        "--max-concurrency",
        type=int,
        nargs="+",
        default=(1, 2, 4, 8, 16, 32),
        help="Concurrency levels to benchmark; one run per value.",
    )
    parser.add_argument(
        "--num-prompts",
        type=int,
        default=128,
        help="Number of requests to sample from the dataset.",
    )
    parser.add_argument(
        "--acceptance-rate",
        type=float,
        default=0.8,
        help="Acceptance rate used in the speedup estimate.",
    )
    parser.add_argument(
        "--draft-token-steps",
        type=int,
        nargs="+",
        default=(1, 2),
        help="Draft step counts to estimate the speedup for.",
    )
    parser.add_argument(
        "--s_itl-base-model",
        type=float,
        nargs="+",
        # Without a value main() cannot pair baselines with concurrency
        # levels, so fail at parse time instead of on len(None).
        required=True,
        help=(
            "Base-model s_itl values, one per --max-concurrency entry, in the "
            "same unit as the measured mean_s_itl_ms."
        ),
    )
    parser.add_argument(
        "--dataset-name",
        type=str,
        default="EBChat",
        help='Dataset to sample from: "EB" or "EBChat".',
    )
    parser.add_argument(
        "--dataset-path",
        type=str,
        help="Path handed to the dataset constructor.",
    )

    args = parser.parse_args()
    main(args)