commit 5b49142988 (parent a535050b11)
Author: Zhang Yulong
Date: 2025-11-28 18:29:16 +08:00 (committed by GitHub)
4 changed files with 561 additions and 29 deletions


@@ -39,7 +39,7 @@ from backend_request_func import (
RequestFuncInput,
RequestFuncOutput,
)
-from benchmark_dataset import EBChatDataset, EBDataset, SampleRequest
+from benchmark_dataset import EBChatDataset, EBDataset, RandomTextDataset, SampleRequest
from benchmark_utils import convert_to_pytorch_benchmark_format, write_to_json
from tqdm.asyncio import tqdm
@@ -337,6 +337,7 @@ async def benchmark(
)
test_history_QA = input_requests[0].history_QA
response_format = input_requests[0].response_format
+random_flag = input_requests[0].random_flag
test_input = RequestFuncInput(
model=model_id,
@@ -353,6 +354,7 @@ async def benchmark(
debug=debug,
extra_body=extra_body,
response_format=response_format,
+random_flag=random_flag,
)
print("test_input:", test_input)
@@ -385,6 +387,7 @@ async def benchmark(
ignore_eos=ignore_eos,
extra_body=extra_body,
response_format=response_format,
+random_flag=random_flag,
)
profile_output = await request_func(request_func_input=profile_input)
if profile_output.success:
@@ -424,6 +427,7 @@ async def benchmark(
)
history_QA = request.history_QA
response_format = request.response_format
+random_flag = request.random_flag
req_model_id, req_model_name = model_id, model_name
if lora_modules:
@@ -445,6 +449,7 @@ async def benchmark(
ignore_eos=ignore_eos,
extra_body=extra_body,
response_format=response_format,
+random_flag=random_flag,
)
tasks.append(asyncio.create_task(limited_request_func(request_func_input=request_func_input, pbar=pbar)))
outputs: list[RequestFuncOutput] = await asyncio.gather(*tasks)
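
`limited_request_func` is not part of this diff; in serving benchmarks of this shape it is usually a small wrapper that bounds the number of in-flight requests. A sketch under that assumption (`request_func` and `max_concurrency` are taken from the surrounding script, not from the diff):

```python
import asyncio
from typing import Awaitable, Callable, Optional

# Assumed shape of the concurrency limiter used above; the real
# limited_request_func is not shown in this diff.
def make_limited_request_func(
    request_func: Callable[..., Awaitable[object]],
    max_concurrency: Optional[int],
):
    # With no cap configured, requests are fired without throttling.
    semaphore = asyncio.Semaphore(max_concurrency) if max_concurrency else None

    async def limited_request_func(request_func_input, pbar=None):
        if semaphore is None:
            return await request_func(request_func_input=request_func_input, pbar=pbar)
        async with semaphore:
            return await request_func(request_func_input=request_func_input, pbar=pbar)

    return limited_request_func
```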
@@ -461,6 +466,7 @@ async def benchmark(
output_len=test_output_len,
logprobs=logprobs,
response_format=response_format,
+random_flag=random_flag,
)
profile_output = await request_func(request_func_input=profile_input)
if profile_output.success:
@@ -498,6 +504,12 @@ async def benchmark(
benchmark_duration = time.perf_counter() - benchmark_start_time
print(f"benchmark_duration: {benchmark_duration}")
+if random_flag:
+print("Testing with specified random input/output lengths")
+print(f"random_input_len: {args.random_input_len}")
+print(f"random_output_len: {args.random_output_len}")
+print(f"random_range_ratio: {args.random_range_ratio}")
metrics, actual_output_lens = calculate_metrics(
# input_requests=input_requests,
outputs=benchmark_outputs,
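
The three values echoed above fully determine the sampled lengths. Given the new defaults introduced later in this commit and the symmetric range `[length * (1 - ratio), length * (1 + ratio)]` documented for `--random-range-ratio`, the effective bounds work out as follows (truncation via `int()` is an assumption; `RandomTextDataset` may round differently):

```python
# Worked example of the symmetric sampling range from the --random-range-ratio
# help text, using this commit's new defaults.
def length_bounds(length: int, ratio: float) -> tuple[int, int]:
    return int(length * (1 - ratio)), int(length * (1 + ratio))

print(length_bounds(200, 0.1))   # (180, 220)  input words per request
print(length_bounds(1024, 0.1))  # (921, 1126) output tokens per request
```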
@@ -866,6 +878,12 @@ def main(args: argparse.Namespace):
num_requests=args.num_prompts,
output_len=args.sharegpt_output_len,
),
"random": lambda: RandomTextDataset().sample(
num_requests=args.num_prompts,
random_input_len=args.random_input_len,
random_output_len=args.random_output_len,
random_range_ratio=args.random_range_ratio,
),
}
try:
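
`"random"` now dispatches to `RandomTextDataset.sample`, which lives in `benchmark_dataset.py` and is not shown in this excerpt. A hedged sketch of what such a sampler could look like, assuming it draws word and token counts from the symmetric range and marks each request with `random_flag`; the stand-in `SampleRequest` fields are illustrative:

```python
import random
from dataclasses import dataclass

# Illustrative stand-in; the real SampleRequest is imported from
# benchmark_dataset and may carry more fields.
@dataclass
class SampleRequest:
    prompt: str
    output_len: int
    random_flag: bool = False

# Hypothetical sketch of RandomTextDataset.sample; the real implementation
# in benchmark_dataset.py may build prompts differently.
class RandomTextDataset:
    WORDS = "the quick brown fox jumps over a lazy dog".split()

    def sample(self, num_requests, random_input_len, random_output_len, random_range_ratio):
        def bounds(length, ratio):
            return int(length * (1 - ratio)), int(length * (1 + ratio))

        requests = []
        for _ in range(num_requests):
            # Draw prompt length (in English words) and target output length
            # (in tokens) from [length * (1 - ratio), length * (1 + ratio)].
            n_words = random.randint(*bounds(random_input_len, random_range_ratio))
            out_len = random.randint(*bounds(random_output_len, random_range_ratio))
            prompt = " ".join(random.choices(self.WORDS, k=n_words))
            requests.append(SampleRequest(prompt, out_len, random_flag=True))
        return requests
```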
@@ -1021,15 +1039,10 @@ if __name__ == "__main__":
parser.add_argument(
"--dataset-name",
type=str,
default="sharegpt",
default="EBChat",
choices=[
"sharegpt",
"burstgpt",
"sonnet",
"random",
"hf",
"EB",
"EBChat",
"random",
],
help="Name of the dataset to benchmark on.",
)
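
After this hunk the default dataset is `EBChat` and the surviving choices are `sharegpt`, `EB`, `EBChat`, and `random`. A minimal self-contained check of the new argument behavior:

```python
import argparse

# Reconstruction of the post-commit --dataset-name argument for a quick check.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--dataset-name",
    type=str,
    default="EBChat",
    choices=["sharegpt", "EB", "EBChat", "random"],
    help="Name of the dataset to benchmark on.",
)

print(parser.parse_args([]).dataset_name)                            # EBChat
print(parser.parse_args(["--dataset-name", "random"]).dataset_name)  # random
```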
@@ -1247,37 +1260,24 @@ if __name__ == "__main__":
random_group.add_argument(
"--random-input-len",
type=int,
-default=1024,
-help="Number of input tokens per request, used only for random sampling.",
+default=200,
+help="Number of input English words per request, used only for the random-text dataset.",
)
random_group.add_argument(
"--random-output-len",
type=int,
-default=128,
-help="Number of output tokens per request, used only for random sampling.",
+default=1024,
+help="Number of output tokens per request, used for both the random and random-text datasets.",
)
random_group.add_argument(
"--random-range-ratio",
type=float,
-default=0.0,
+default=0.1,
help="Range ratio for sampling input/output length, "
"used only for random sampling. Must be in the range [0, 1) to define "
"a symmetric sampling range"
"[length * (1 - range_ratio), length * (1 + range_ratio)].",
)
-random_group.add_argument(
-"--random-prefix-len",
-type=int,
-default=0,
-help=(
-"Number of fixed prefix tokens before the random context "
-"in a request. "
-"The total input length is the sum of `random-prefix-len` and "
-"a random "
-"context length sampled from [input_len * (1 - range_ratio), "
-"input_len * (1 + range_ratio)]."
-),
-)
hf_group = parser.add_argument_group("hf dataset options")
hf_group.add_argument("--hf-subset", type=str, default=None, help="Subset of the HF dataset.")
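
Dropping `--random-prefix-len` removes the fixed-prefix term from the total input length. Since its old default was 0, default behavior is unchanged; the comparison below restates the removed help text's formula next to the new one (uniform integer sampling is an assumption):

```python
import random

# Old scheme (per the removed help text): total input = fixed prefix plus a
# context length sampled from [input_len * (1 - r), input_len * (1 + r)].
def old_input_len(input_len: int, r: float, prefix_len: int = 0) -> int:
    lo, hi = int(input_len * (1 - r)), int(input_len * (1 + r))
    return prefix_len + random.randint(lo, hi)

# New scheme: the prefix term is gone; only the sampled range remains.
def new_input_len(input_len: int, r: float) -> int:
    lo, hi = int(input_len * (1 - r)), int(input_len * (1 + r))
    return random.randint(lo, hi)

# With the old default prefix_len=0 the two distributions coincide, which is
# presumably why the flag could be removed without changing default runs.
```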