Mirror of https://github.com/PaddlePaddle/FastDeploy.git
polish code with new pre-commit rule (#2923)
This commit reformats a speculative-decoding (MTP) benchmark script under the new pre-commit hooks: calls that fit within the new, longer line-length limit are collapsed onto one line, the unused SampleRequest import is dropped, needless f-string prefixes on strings without placeholders are removed, and one deeply nested call that no longer fits is split across lines.
@@ -18,28 +18,16 @@ import argparse
 import asyncio
 import contextlib
 import os
 import signal
 import socket
 import subprocess
 import time
 from typing import Union

 import openai
 import yaml
-from benchmark_dataset import EBChatDataset, EBDataset, SampleRequest
+from benchmark_dataset import EBChatDataset, EBDataset
 from benchmark_serving import benchmark


-def prepare_input_requests(
-    num_prompts: int, dataset_name: str, dataset_path: str
-) -> Union[EBDataset, EBChatDataset]:
+def prepare_input_requests(num_prompts: int, dataset_name: str, dataset_path: str) -> Union[EBDataset, EBChatDataset]:
     dataset_mapping = {
-        "EB": lambda: EBDataset(dataset_path=dataset_path).sample(
-            num_requests=num_prompts
-        ),
-        "EBChat": lambda: EBChatDataset(dataset_path=dataset_path).sample(
-            num_requests=num_prompts
-        ),
+        "EB": lambda: EBDataset(dataset_path=dataset_path).sample(num_requests=num_prompts),
+        "EBChat": lambda: EBChatDataset(dataset_path=dataset_path).sample(num_requests=num_prompts),
     }

     try:
@@ -104,24 +92,27 @@ def calculate_speedup(acceptance_rate, draft_token_step, t_ori, t_mtp):
 def main(args):
     base_url = f"http://{args.host}:{args.port}"

-    input_requests = prepare_input_requests(
-        args.num_prompts, args.dataset_name, args.dataset_path
-    )
+    input_requests = prepare_input_requests(args.num_prompts, args.dataset_name, args.dataset_path)

     if len(args.max_concurrency) != len(args.s_itl_base_model):
-        raise ValueError(f"--max_concurrency should be same length as --s_itl_base_model")
+        raise ValueError("--max_concurrency should be same length as --s_itl_base_model")

     for max_concurrency, s_itl in zip(args.max_concurrency, args.s_itl_base_model):
         # Wramup
         print("Starting warmup...")
         with open(os.devnull, "w") as f:
             with contextlib.redirect_stdout(f):
-                send_one_batch(base_url, max_concurrency, input_requests[0:max_concurrency], True)
+                send_one_batch(
+                    base_url,
+                    max_concurrency,
+                    input_requests[0:max_concurrency],
+                    True,
+                )

         # Benchmark
         record = send_one_batch(base_url, max_concurrency, input_requests, False)

-        metric_header = f"Speed up"
+        metric_header = "Speed up"
         print("{s:{c}^{n}}".format(s=metric_header, n=50, c="-"))
         for draft_token_step in args.draft_token_steps:
             speedup = calculate_speedup(
@@ -130,11 +121,7 @@ def main(args):
                 s_itl,
                 record["mean_s_itl_ms"],
             )
-            print(
-                "{:<40} {:<10.2f}".format(
-                    f"Speed up on {draft_token_step} steps draft", speedup
-                )
-            )
+            print("{:<40} {:<10.2f}".format(f"Speed up on {draft_token_step} steps draft", speedup))
         print("=" * 50)
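The warmup pass discards the benchmark's progress output by redirecting stdout to the null device. A minimal, self-contained sketch of the same pattern, with a hypothetical noisy_benchmark standing in for send_one_batch:

import contextlib
import os


def noisy_benchmark(n: int) -> float:
    # Hypothetical stand-in for send_one_batch: prints progress, returns a metric.
    for i in range(n):
        print(f"request {i} done")
    return 42.0


# Warmup: stdout is redirected to the null device, so the progress
# lines are discarded, as in the script's warmup loop.
with open(os.devnull, "w") as f:
    with contextlib.redirect_stdout(f):
        noisy_benchmark(3)

# Measured run: output is visible again.
print(f"metric: {noisy_benchmark(3)}")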
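The body of calculate_speedup is outside the visible hunks, so the following is only a plausible reconstruction, assuming the usual speculative-decoding estimate: with acceptance rate a and draft_token_step draft tokens, each verification step commits 1 + a + ... + a^k tokens in expectation, and speedup is the ratio of per-token latencies. The real function in the repository may differ.

def calculate_speedup(acceptance_rate, draft_token_step, t_ori, t_mtp):
    # Assumed: expected tokens committed per verification step is the
    # geometric partial sum 1 + a + a^2 + ... + a^k.
    expected_tokens = sum(acceptance_rate**i for i in range(draft_token_step + 1))
    # t_ori is the base model's inter-token latency (s_itl) and t_mtp the
    # measured mean ITL with drafting, matching the arguments main() passes.
    return t_ori * expected_tokens / t_mtp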
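prepare_input_requests dispatches through a dict of zero-argument lambdas, so no dataset is read until the selected entry is called; the try: that follows the mapping presumably guards the name lookup. The same pattern in isolation, with hypothetical loaders standing in for EBDataset and EBChatDataset:

from typing import Callable, Dict, List


def load_eb(path: str, n: int) -> List[str]:
    # Hypothetical stand-in for EBDataset(dataset_path=path).sample(num_requests=n).
    return [f"eb sample {i} from {path}" for i in range(n)]


def load_eb_chat(path: str, n: int) -> List[str]:
    # Hypothetical stand-in for EBChatDataset(dataset_path=path).sample(num_requests=n).
    return [f"chat sample {i} from {path}" for i in range(n)]


def prepare(num_prompts: int, dataset_name: str, dataset_path: str) -> List[str]:
    # Zero-argument lambdas defer loading until the chosen entry is invoked.
    dataset_mapping: Dict[str, Callable[[], List[str]]] = {
        "EB": lambda: load_eb(dataset_path, num_prompts),
        "EBChat": lambda: load_eb_chat(dataset_path, num_prompts),
    }
    try:
        return dataset_mapping[dataset_name]()
    except KeyError as err:
        raise ValueError(f"Unknown dataset: {dataset_name}") from err


print(prepare(2, "EBChat", "data.jsonl"))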