diff --git a/fastdeploy/utils.py b/fastdeploy/utils.py index db2b560ca..1640fd23f 100644 --- a/fastdeploy/utils.py +++ b/fastdeploy/utils.py @@ -39,6 +39,10 @@ from fastdeploy import envs T = TypeVar("T") +# [N,2] -> every line is [config_name, enable_xxx_name] +# Make sure enable_xxx equal to config.enable_xxx +ARGS_CORRECTION_LIST = [["early_stop_config", "enable_early_stop"], ["graph_optimization_config", "use_cudagraph"]] + class EngineError(Exception): """Base exception class for engine errors""" @@ -361,8 +365,16 @@ class FlexibleArgumentParser(argparse.ArgumentParser): namespace = argparse.Namespace() for key, value in filtered_config.items(): setattr(namespace, key, value) + args = super().parse_args(args=remaining_args, namespace=namespace) - return super().parse_args(args=remaining_args, namespace=namespace) + # Args correction + for config_name, flag_name in ARGS_CORRECTION_LIST: + if hasattr(args, config_name) and hasattr(args, flag_name): + # config is a dict + config = getattr(args, config_name, None) + if config is not None and flag_name in config.keys(): + setattr(args, flag_name, config[flag_name]) + return args def resolve_obj_from_strname(strname: str):