mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-05 08:37:06 +08:00
[Feature] Add speculative decoding simulation benchmark. (#2751)
* Add speculative decoding simulation benchmark
* Fix the name of the parameter
This commit is contained in:
@@ -494,6 +494,11 @@ def parse_args():
|
||||
default="WINT8",
|
||||
type=str,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--speculative_benchmark_mode",
|
||||
default="false",
|
||||
type=str,
|
||||
)
|
||||
parser.add_argument("--max_num_batched_tokens",
|
||||
type=int,
|
||||
default=2048,
|
||||
@@ -625,6 +630,9 @@ def initialize_fd_config(config_or_args) -> FDConfig:
|
||||
speculative_config.num_speculative_tokens = getattr(config_or_args, 'speculative_max_draft_token_num', 0)
|
||||
speculative_config.model_name_or_path = getattr(config_or_args, 'speculative_model_name_or_path', None)
|
||||
speculative_config.quantization = getattr(config_or_args, 'speculative_model_quantization', None)
|
||||
speculative_config.benchmark_mode = (
|
||||
getattr(config_or_args, "speculative_benchmark_mode", "false").lower() == "true"
|
||||
)
|
||||
|
||||
# Update parallel config
|
||||
parallel_config.engine_pid = getattr(config_or_args, 'engine_pid', None)
|
||||
|
Reference in New Issue
Block a user