mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-12-24 13:28:13 +08:00
benchmark工具支持受限解码场景指定response_format (#4718)
This commit is contained in:
@@ -51,6 +51,7 @@ class RequestFuncInput:
|
||||
ignore_eos: bool = False
|
||||
language: Optional[str] = None
|
||||
debug: bool = False
|
||||
response_format: Optional[dict] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -93,8 +94,11 @@ async def async_request_eb_openai_chat_completions(
|
||||
"stream_options": {
|
||||
"include_usage": True,
|
||||
"continuous_usage_stats": True,
|
||||
},
|
||||
}
|
||||
}
|
||||
if request_func_input.response_format:
|
||||
payload["response_format"] =request_func_input.response_format
|
||||
|
||||
# 超参由yaml传入
|
||||
payload.update(request_func_input.hyper_parameters)
|
||||
|
||||
|
||||
@@ -45,7 +45,8 @@ class SampleRequest:
|
||||
json_data: Optional[dict]
|
||||
prompt_len: int
|
||||
expected_output_len: int
|
||||
|
||||
response_format: Optional[dict] = None
|
||||
|
||||
|
||||
class BenchmarkDataset(ABC):
|
||||
"""BenchmarkDataset"""
|
||||
@@ -297,6 +298,7 @@ class EBChatDataset(BenchmarkDataset):
|
||||
json_data = entry
|
||||
prompt = entry["messages"][-1].get("content", "")
|
||||
history_QA = entry.get("messages", [])
|
||||
response_format = entry.get("response_format")
|
||||
new_output_len = int(entry.get("max_tokens", 12288))
|
||||
|
||||
if enable_multimodal_chat:
|
||||
@@ -309,6 +311,7 @@ class EBChatDataset(BenchmarkDataset):
|
||||
prompt_len=0,
|
||||
history_QA=history_QA,
|
||||
expected_output_len=new_output_len,
|
||||
response_format=response_format
|
||||
)
|
||||
)
|
||||
cnt += 1
|
||||
|
||||
@@ -336,6 +336,7 @@ async def benchmark(
|
||||
input_requests[0].no,
|
||||
)
|
||||
test_history_QA = input_requests[0].history_QA
|
||||
response_format = input_requests[0].response_format
|
||||
|
||||
test_input = RequestFuncInput(
|
||||
model=model_id,
|
||||
@@ -351,6 +352,7 @@ async def benchmark(
|
||||
ignore_eos=ignore_eos,
|
||||
debug=debug,
|
||||
extra_body=extra_body,
|
||||
response_format=response_format
|
||||
)
|
||||
|
||||
print("test_input:", test_input)
|
||||
@@ -382,6 +384,7 @@ async def benchmark(
|
||||
logprobs=logprobs,
|
||||
ignore_eos=ignore_eos,
|
||||
extra_body=extra_body,
|
||||
response_format=response_format
|
||||
)
|
||||
profile_output = await request_func(request_func_input=profile_input)
|
||||
if profile_output.success:
|
||||
@@ -420,6 +423,7 @@ async def benchmark(
|
||||
request.no,
|
||||
)
|
||||
history_QA = request.history_QA
|
||||
response_format = request.response_format
|
||||
|
||||
req_model_id, req_model_name = model_id, model_name
|
||||
if lora_modules:
|
||||
@@ -440,6 +444,7 @@ async def benchmark(
|
||||
debug=debug,
|
||||
ignore_eos=ignore_eos,
|
||||
extra_body=extra_body,
|
||||
response_format=response_format
|
||||
)
|
||||
tasks.append(asyncio.create_task(limited_request_func(request_func_input=request_func_input, pbar=pbar)))
|
||||
outputs: list[RequestFuncOutput] = await asyncio.gather(*tasks)
|
||||
@@ -455,6 +460,7 @@ async def benchmark(
|
||||
api_url=base_url + "/stop_profile",
|
||||
output_len=test_output_len,
|
||||
logprobs=logprobs,
|
||||
response_format=response_format
|
||||
)
|
||||
profile_output = await request_func(request_func_input=profile_input)
|
||||
if profile_output.success:
|
||||
|
||||
Reference in New Issue
Block a user