mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-04 08:16:42 +08:00
[Feature] Support repetition early stop (#3024)
* support repetition early stop and support user to set the parameter
* remove log
* fix codestyle
* add the early_stop_config to rollout_config
* update config and EarlyStopper class
* fix the bug for triton
* modify the stop method
* update description
* modify the usage for stop_flags

Co-authored-by: Yuanle Liu <yuanlehome@163.com>
This commit is contained in:
@@ -28,6 +28,7 @@ from fastdeploy.config import (
     CacheConfig,
     DecodingConfig,
     DeviceConfig,
+    EarlyStopConfig,
     ErnieArchitectures,
     FDConfig,
     GraphOptimizationConfig,
@@ -565,6 +566,12 @@ def parse_args():
         action="store_true",
         help="Enable output of token-level log probabilities.",
     )
+    parser.add_argument(
+        "--early_stop_config",
+        type=json.loads,
+        default=None,
+        help="Configuration of early stop.",
+    )

     args = parser.parse_args()
     return args
@@ -608,6 +615,8 @@ def initialize_fd_config(args, ranks: int = 1, local_rank: int = 0) -> FDConfig:

     graph_opt_config = GraphOptimizationConfig(args.graph_optimization_config)

+    early_stop_config = EarlyStopConfig(args.early_stop_config)
+
     # Note(tangbinhan): used for load_checkpoint
     model_config.pretrained_config.tensor_parallel_rank = parallel_config.tensor_parallel_rank
     model_config.pretrained_config.tensor_parallel_degree = parallel_config.tensor_parallel_size
@@ -679,6 +688,7 @@ def initialize_fd_config(args, ranks: int = 1, local_rank: int = 0) -> FDConfig:
         decoding_config=decoding_config,
         quant_config=quant_config,
         graph_opt_config=graph_opt_config,
+        early_stop_config=early_stop_config,
         cache_config=cache_config,
     )
     update_fd_config_for_mm(fd_config)
Reference in New Issue
Block a user