mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-09-26 20:41:53 +08:00
179 lines
5.0 KiB
Python
179 lines
5.0 KiB
Python
"""
|
|
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
"""
|
|
|
|
import argparse
|
|
import asyncio
|
|
import contextlib
|
|
import os
|
|
from typing import Union
|
|
|
|
from benchmark_dataset import EBChatDataset, EBDataset
|
|
from benchmark_serving import benchmark
|
|
|
|
|
|
def prepare_input_requests(num_prompts: int, dataset_name: str, dataset_path: str) -> Union[EBDataset, EBChatDataset]:
    """Sample ``num_prompts`` benchmark requests from the named dataset.

    Args:
        num_prompts: Number of requests to sample; forwarded to the dataset's
            ``sample()`` as ``num_requests``.
        dataset_name: Dataset selector, either ``"EB"`` or ``"EBChat"``.
        dataset_path: Path forwarded to the dataset constructor.

    Returns:
        Whatever the selected dataset's ``sample()`` call returns.
        NOTE(review): the annotation names the dataset classes, but the value
        returned is the result of ``sample()``; presumably a sequence of
        requests -- confirm against ``benchmark_dataset``.

    Raises:
        ValueError: If ``dataset_name`` is not one of the known datasets.
    """
    # Loaders are lazy so that only the requested dataset is constructed.
    dataset_loaders = {
        "EB": lambda: EBDataset(dataset_path=dataset_path),
        "EBChat": lambda: EBChatDataset(dataset_path=dataset_path),
    }

    try:
        load_dataset = dataset_loaders[dataset_name]
    except KeyError as err:
        raise ValueError(f"Unknown dataset: {dataset_name}") from err

    # Construction and sampling stay outside the ``try`` so that a KeyError
    # raised by the dataset code itself is not misreported as an unknown
    # dataset name.
    return load_dataset().sample(num_requests=num_prompts)
|
|
|
|
|
|
class FakeTokenizer:
    """Stand-in tokenizer whose ``encode`` never produces any tokens."""

    def encode(self, text: str, add_special_tokens: bool = False):
        """Return a fresh empty list, ignoring both arguments."""
        return list()
|
|
|
|
|
|
def send_one_batch(base_url, max_concurrency, input_requests, disable_tqdm):
    """Run one benchmark pass and return its mean ``s_itl`` metric.

    The requests are sent through ``benchmark`` using the ``openai-chat``
    backend against ``{base_url}/v1/chat/completions`` with an infinite
    request rate, so only ``max_concurrency`` is passed as a limit.

    Args:
        base_url: Server base URL, e.g. ``http://host:port``.
        max_concurrency: Forwarded to ``benchmark`` as ``max_concurrency``.
        input_requests: Requests to send, forwarded unchanged.
        disable_tqdm: Forwarded to ``benchmark`` as ``disable_tqdm``.

    Returns:
        A dict with the single key ``"mean_s_itl_ms"`` copied from the
        benchmark results.
    """
    benchmark_kwargs = dict(
        backend="openai-chat",
        api_url=f"{base_url}/v1/chat/completions",
        base_url=base_url,
        model_id="default",
        model_name="default",
        input_requests=input_requests,
        hyper_parameters={},
        logprobs=None,
        request_rate=float("inf"),
        burstiness=1.0,
        disable_tqdm=disable_tqdm,
        profile=False,
        # Only the "s_itl" metric is requested, with no percentiles.
        selected_percentile_metrics=["s_itl"],
        selected_percentiles=[],
        ignore_eos=False,
        goodput_config_dict=None,
        max_concurrency=max_concurrency,
        lora_modules=None,
        extra_body=None,
    )

    # Run benchmark
    outcome = asyncio.run(benchmark(**benchmark_kwargs))

    return {"mean_s_itl_ms": outcome["mean_s_itl_ms"]}
|
|
|
|
|
|
def calculate_speedup(acceptance_rate, draft_token_step, t_ori, t_mtp):
    """Estimate the speed-up from an acceptance rate and two timings.

    With ``S = sum(acceptance_rate ** k for k in 1..draft_token_step)`` and
    ``r_ac = S / (1 + S)``, the result is ``t_ori / ((1 - r_ac) * t_mtp)``.

    Args:
        acceptance_rate: Base of the geometric terms summed into ``S``.
        draft_token_step: Number of terms in ``S``; zero yields ``S == 0``.
        t_ori: Baseline timing (numerator).
        t_mtp: Measured timing, in the same unit as ``t_ori``.

    Returns:
        The estimated speed-up ratio.
    """
    # Accumulate term by term, starting from 0.0, so the floating-point
    # summation order is fixed.
    accepted_sum = 0.0
    exponent = 1
    while exponent <= draft_token_step:
        accepted_sum += acceptance_rate**exponent
        exponent += 1

    r_ac = accepted_sum / (1 + accepted_sum)
    effective_time = (1 - r_ac) * t_mtp
    return t_ori / effective_time
|
|
|
|
|
|
def main(args):
    """Benchmark each concurrency level and print the estimated speed-ups.

    For every ``(max_concurrency, s_itl)`` pair taken from
    ``args.max_concurrency`` and ``args.s_itl_base_model``, a silent warmup
    batch is sent, then the full request set is benchmarked and one speed-up
    line is printed per entry in ``args.draft_token_steps``.

    Args:
        args: Parsed command-line namespace providing ``host``, ``port``,
            ``num_prompts``, ``dataset_name``, ``dataset_path``,
            ``max_concurrency``, ``s_itl_base_model``, ``acceptance_rate``
            and ``draft_token_steps``.

    Raises:
        ValueError: If ``s_itl_base_model`` is missing, or its length differs
            from that of ``max_concurrency``.
    """
    # Validate the arguments before loading the dataset so that a bad command
    # line fails immediately with a clear message.
    if args.s_itl_base_model is None:
        raise ValueError("--s_itl-base-model is required (one value per --max-concurrency entry)")
    if len(args.max_concurrency) != len(args.s_itl_base_model):
        raise ValueError("--max-concurrency should be same length as --s_itl-base-model")

    base_url = f"http://{args.host}:{args.port}"

    input_requests = prepare_input_requests(args.num_prompts, args.dataset_name, args.dataset_path)

    for max_concurrency, s_itl in zip(args.max_concurrency, args.s_itl_base_model):
        # Warmup: send one request per concurrency slot with stdout
        # redirected to the null device so its report is not printed.
        print("Starting warmup...")
        with open(os.devnull, "w") as devnull, contextlib.redirect_stdout(devnull):
            send_one_batch(
                base_url,
                max_concurrency,
                input_requests[0:max_concurrency],
                True,
            )

        # Benchmark
        record = send_one_batch(base_url, max_concurrency, input_requests, False)

        metric_header = "Speed up"
        print("{s:{c}^{n}}".format(s=metric_header, n=50, c="-"))
        for draft_token_step in args.draft_token_steps:
            speedup = calculate_speedup(
                args.acceptance_rate,
                draft_token_step,
                s_itl,
                record["mean_s_itl_ms"],
            )
            print("{:<40} {:<10.2f}".format(f"Speed up on {draft_token_step} steps draft", speedup))
        print("=" * 50)
|
|
|
|
|
|
if __name__ == "__main__":
    # Command-line entry point: parse the benchmark options and hand them to
    # ``main``.
    parser = argparse.ArgumentParser(
        description="Benchmark a serving endpoint and print estimated speed-ups over a base model.",
    )
    parser.add_argument(
        "--host",
        type=str,
        default="127.0.0.1",
        help="Host of the server to benchmark.",
    )
    parser.add_argument(
        "--port",
        type=str,
        default="8000",
        help="Port of the server to benchmark.",
    )
    parser.add_argument(
        "--max-concurrency",
        type=int,
        nargs="+",
        default=(1, 2, 4, 8, 16, 32),
        help="Concurrency levels to benchmark; one run is made per value.",
    )
    parser.add_argument(
        "--num-prompts",
        type=int,
        default=128,
        help="Number of requests to sample from the dataset.",
    )
    parser.add_argument(
        "--acceptance-rate",
        type=float,
        default=0.8,
        help="Acceptance rate used in the speed-up estimate.",
    )
    parser.add_argument(
        "--draft-token-steps",
        type=int,
        nargs="+",
        default=(1, 2),
        help="Draft step counts to report a speed-up for.",
    )
    # Required: ``main`` pairs these values one-to-one with --max-concurrency
    # and cannot run without them.
    parser.add_argument(
        "--s_itl-base-model",
        type=float,
        nargs="+",
        required=True,
        help="Base-model s_itl values, one per --max-concurrency entry, "
        "in the same unit as the measured mean s_itl.",
    )
    parser.add_argument(
        "--dataset-name",
        type=str,
        default="EBChat",
        help="Dataset to sample requests from: EB or EBChat.",
    )
    parser.add_argument(
        "--dataset-path",
        type=str,
        help="Path passed to the dataset loader.",
    )
    args = parser.parse_args()

    main(args)
|