mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-12-24 13:28:13 +08:00
update (#5298)
This commit is contained in:
@@ -39,7 +39,7 @@ from backend_request_func import (
|
||||
RequestFuncInput,
|
||||
RequestFuncOutput,
|
||||
)
|
||||
from benchmark_dataset import EBChatDataset, EBDataset, SampleRequest
|
||||
from benchmark_dataset import EBChatDataset, EBDataset, RandomTextDataset, SampleRequest
|
||||
from benchmark_utils import convert_to_pytorch_benchmark_format, write_to_json
|
||||
from tqdm.asyncio import tqdm
|
||||
|
||||
@@ -337,6 +337,7 @@ async def benchmark(
|
||||
)
|
||||
test_history_QA = input_requests[0].history_QA
|
||||
response_format = input_requests[0].response_format
|
||||
random_flag = input_requests[0].random_flag
|
||||
|
||||
test_input = RequestFuncInput(
|
||||
model=model_id,
|
||||
@@ -353,6 +354,7 @@ async def benchmark(
|
||||
debug=debug,
|
||||
extra_body=extra_body,
|
||||
response_format=response_format,
|
||||
random_flag=random_flag,
|
||||
)
|
||||
|
||||
print("test_input:", test_input)
|
||||
@@ -385,6 +387,7 @@ async def benchmark(
|
||||
ignore_eos=ignore_eos,
|
||||
extra_body=extra_body,
|
||||
response_format=response_format,
|
||||
random_flag=random_flag,
|
||||
)
|
||||
profile_output = await request_func(request_func_input=profile_input)
|
||||
if profile_output.success:
|
||||
@@ -424,6 +427,7 @@ async def benchmark(
|
||||
)
|
||||
history_QA = request.history_QA
|
||||
response_format = request.response_format
|
||||
random_flag = request.random_flag
|
||||
|
||||
req_model_id, req_model_name = model_id, model_name
|
||||
if lora_modules:
|
||||
@@ -445,6 +449,7 @@ async def benchmark(
|
||||
ignore_eos=ignore_eos,
|
||||
extra_body=extra_body,
|
||||
response_format=response_format,
|
||||
random_flag=random_flag,
|
||||
)
|
||||
tasks.append(asyncio.create_task(limited_request_func(request_func_input=request_func_input, pbar=pbar)))
|
||||
outputs: list[RequestFuncOutput] = await asyncio.gather(*tasks)
|
||||
@@ -461,6 +466,7 @@ async def benchmark(
|
||||
output_len=test_output_len,
|
||||
logprobs=logprobs,
|
||||
response_format=response_format,
|
||||
random_flag=random_flag,
|
||||
)
|
||||
profile_output = await request_func(request_func_input=profile_input)
|
||||
if profile_output.success:
|
||||
@@ -498,6 +504,12 @@ async def benchmark(
|
||||
benchmark_duration = time.perf_counter() - benchmark_start_time
|
||||
print(f"benchmark_duration: {benchmark_duration} 秒")
|
||||
|
||||
if random_flag:
|
||||
print("指定随机输入输出长度测试")
|
||||
print(f"random_input_len: {args.random_input_len}")
|
||||
print(f"random_output_len: {args.random_output_len}")
|
||||
print(f"random_range_ratio: {args.random_range_ratio}")
|
||||
|
||||
metrics, actual_output_lens = calculate_metrics(
|
||||
# input_requests=input_requests,
|
||||
outputs=benchmark_outputs,
|
||||
@@ -866,6 +878,12 @@ def main(args: argparse.Namespace):
|
||||
num_requests=args.num_prompts,
|
||||
output_len=args.sharegpt_output_len,
|
||||
),
|
||||
"random": lambda: RandomTextDataset().sample(
|
||||
num_requests=args.num_prompts,
|
||||
random_input_len=args.random_input_len,
|
||||
random_output_len=args.random_output_len,
|
||||
random_range_ratio=args.random_range_ratio,
|
||||
),
|
||||
}
|
||||
|
||||
try:
|
||||
@@ -1021,15 +1039,10 @@ if __name__ == "__main__":
|
||||
parser.add_argument(
|
||||
"--dataset-name",
|
||||
type=str,
|
||||
default="sharegpt",
|
||||
default="EBChat",
|
||||
choices=[
|
||||
"sharegpt",
|
||||
"burstgpt",
|
||||
"sonnet",
|
||||
"random",
|
||||
"hf",
|
||||
"EB",
|
||||
"EBChat",
|
||||
"random",
|
||||
],
|
||||
help="Name of the dataset to benchmark on.",
|
||||
)
|
||||
@@ -1247,37 +1260,24 @@ if __name__ == "__main__":
|
||||
random_group.add_argument(
|
||||
"--random-input-len",
|
||||
type=int,
|
||||
default=1024,
|
||||
help="Number of input tokens per request, used only for random sampling.",
|
||||
default=200,
|
||||
help="Number of input English words per request, used only for random-text dataset.",
|
||||
)
|
||||
random_group.add_argument(
|
||||
"--random-output-len",
|
||||
type=int,
|
||||
default=128,
|
||||
help="Number of output tokens per request, used only for random sampling.",
|
||||
default=1024,
|
||||
help="Number of output tokens per request, used both for random and random-text datasets.",
|
||||
)
|
||||
random_group.add_argument(
|
||||
"--random-range-ratio",
|
||||
type=float,
|
||||
default=0.0,
|
||||
default=0.1,
|
||||
help="Range ratio for sampling input/output length, "
|
||||
"used only for random sampling. Must be in the range [0, 1) to define "
|
||||
"a symmetric sampling range"
|
||||
"[length * (1 - range_ratio), length * (1 + range_ratio)].",
|
||||
)
|
||||
random_group.add_argument(
|
||||
"--random-prefix-len",
|
||||
type=int,
|
||||
default=0,
|
||||
help=(
|
||||
"Number of fixed prefix tokens before the random context "
|
||||
"in a request. "
|
||||
"The total input length is the sum of `random-prefix-len` and "
|
||||
"a random "
|
||||
"context length sampled from [input_len * (1 - range_ratio), "
|
||||
"input_len * (1 + range_ratio)]."
|
||||
),
|
||||
)
|
||||
|
||||
hf_group = parser.add_argument_group("hf dataset options")
|
||||
hf_group.add_argument("--hf-subset", type=str, default=None, help="Subset of the HF dataset.")
|
||||
|
||||
Reference in New Issue
Block a user