[Feature] Add speculative decoding simulation benchmark. (#2751)

* Add speculative decoding simulation benchmark

* Fix the name of the parameter
This commit is contained in:
GoldPancake
2025-07-09 12:08:43 +08:00
committed by GitHub
parent 6b10c19482
commit f7cad30a38
8 changed files with 246 additions and 7 deletions

View File

@@ -494,6 +494,11 @@ def parse_args():
default="WINT8",
type=str,
)
parser.add_argument(
"--speculative_benchmark_mode",
default="false",
type=str,
)
parser.add_argument("--max_num_batched_tokens",
type=int,
default=2048,
@@ -625,6 +630,9 @@ def initialize_fd_config(config_or_args) -> FDConfig:
speculative_config.num_speculative_tokens = getattr(config_or_args, 'speculative_max_draft_token_num', 0)
speculative_config.model_name_or_path = getattr(config_or_args, 'speculative_model_name_or_path', None)
speculative_config.quantization = getattr(config_or_args, 'speculative_model_quantization', None)
speculative_config.benchmark_mode = (
getattr(config_or_args, "speculative_benchmark_mode", "false").lower() == "true"
)
# Update parallel config
parallel_config.engine_pid = getattr(config_or_args, 'engine_pid', None)