Benchmark tool: support specifying `response_format` for constrained-decoding scenarios (#4718)

This commit is contained in:
ophilia-lee
2025-10-31 12:26:24 +08:00
committed by GitHub
parent 10de7a3b82
commit 412097c1b8
3 changed files with 15 additions and 2 deletions

View File

@@ -51,6 +51,7 @@ class RequestFuncInput:
ignore_eos: bool = False
language: Optional[str] = None
debug: bool = False
response_format: Optional[dict] = None
@dataclass
@@ -93,8 +94,11 @@ async def async_request_eb_openai_chat_completions(
"stream_options": {
"include_usage": True,
"continuous_usage_stats": True,
},
}
}
if request_func_input.response_format:
payload["response_format"] =request_func_input.response_format
# 超参由yaml传入
payload.update(request_func_input.hyper_parameters)

View File

@@ -45,7 +45,8 @@ class SampleRequest:
json_data: Optional[dict]
prompt_len: int
expected_output_len: int
response_format: Optional[dict] = None
class BenchmarkDataset(ABC):
"""BenchmarkDataset"""
@@ -297,6 +298,7 @@ class EBChatDataset(BenchmarkDataset):
json_data = entry
prompt = entry["messages"][-1].get("content", "")
history_QA = entry.get("messages", [])
response_format = entry.get("response_format")
new_output_len = int(entry.get("max_tokens", 12288))
if enable_multimodal_chat:
@@ -309,6 +311,7 @@ class EBChatDataset(BenchmarkDataset):
prompt_len=0,
history_QA=history_QA,
expected_output_len=new_output_len,
response_format=response_format
)
)
cnt += 1

View File

@@ -336,6 +336,7 @@ async def benchmark(
input_requests[0].no,
)
test_history_QA = input_requests[0].history_QA
response_format = input_requests[0].response_format
test_input = RequestFuncInput(
model=model_id,
@@ -351,6 +352,7 @@ async def benchmark(
ignore_eos=ignore_eos,
debug=debug,
extra_body=extra_body,
response_format=response_format
)
print("test_input:", test_input)
@@ -382,6 +384,7 @@ async def benchmark(
logprobs=logprobs,
ignore_eos=ignore_eos,
extra_body=extra_body,
response_format=response_format
)
profile_output = await request_func(request_func_input=profile_input)
if profile_output.success:
@@ -420,6 +423,7 @@ async def benchmark(
request.no,
)
history_QA = request.history_QA
response_format = request.response_format
req_model_id, req_model_name = model_id, model_name
if lora_modules:
@@ -440,6 +444,7 @@ async def benchmark(
debug=debug,
ignore_eos=ignore_eos,
extra_body=extra_body,
response_format=response_format
)
tasks.append(asyncio.create_task(limited_request_func(request_func_input=request_func_input, pbar=pbar)))
outputs: list[RequestFuncOutput] = await asyncio.gather(*tasks)
@@ -455,6 +460,7 @@ async def benchmark(
api_url=base_url + "/stop_profile",
output_len=test_output_len,
logprobs=logprobs,
response_format=response_format
)
profile_output = await request_func(request_func_input=profile_input)
if profile_output.success: