mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-12-24 13:28:13 +08:00
[Feature] Support repetition early stop (#3024)
* support repetition early stop and support user to set the parameter * remove log * fix codestyle * add the early_stop_config to rollout_config * update config and EarlyStopper class * fix the bug for triton * modify the stop method * update description * modify the usage for stop_flags --------- Co-authored-by: Yuanle Liu <yuanlehome@163.com>
This commit is contained in:
@@ -567,6 +567,74 @@ class GraphOptimizationConfig:
|
||||
argument = self.use_cudagraph
|
||||
|
||||
|
||||
class EarlyStopConfig:
|
||||
def __init__(
|
||||
self,
|
||||
args,
|
||||
):
|
||||
"""
|
||||
Early Stop Configuration class.
|
||||
|
||||
Attributes:
|
||||
window_size: size of the window
|
||||
threshold: trigger early stop when the ratio of probs exceeds the threshold
|
||||
"""
|
||||
"""enable to use early stop"""
|
||||
self.enable_early_stop: bool = False
|
||||
"""strategy for early stop, the strategy lists are ['repetition']"""
|
||||
self.strategy: str = "repetition"
|
||||
""" the maximum length of verify window for early stop """
|
||||
self.window_size: int = 3000
|
||||
""" the probs threshold for early stop """
|
||||
self.threshold: float = 0.99
|
||||
|
||||
if args is not None:
|
||||
for key, value in args.items():
|
||||
if hasattr(self, key):
|
||||
setattr(self, key, value)
|
||||
self.check_legality_parameters()
|
||||
|
||||
def to_json_string(self):
|
||||
"""
|
||||
Convert early_stop_config to json string.
|
||||
"""
|
||||
return json.dumps({key: value for key, value in self.__dict__.items()})
|
||||
|
||||
def __str__(self) -> str:
|
||||
return self.to_json_string()
|
||||
|
||||
def check_legality_parameters(
|
||||
self,
|
||||
) -> None:
|
||||
"""Check the legality of parameters passed in from the command line"""
|
||||
if self.enable_early_stop is not None:
|
||||
assert isinstance(
|
||||
self.enable_early_stop, bool
|
||||
), "In early stop config, type of enable_early_stop must is bool."
|
||||
if self.window_size is not None:
|
||||
assert isinstance(self.window_size, int), "In early stop config, type of window_size must be int."
|
||||
assert self.window_size > 0, "window_size must large than 0"
|
||||
if self.threshold is not None:
|
||||
assert isinstance(self.threshold, float), "In early stop config, type of threshold must be float."
|
||||
assert self.threshold >= 0 and self.threshold <= 1, "threshold must between 0 and 1"
|
||||
|
||||
def update_enable_early_stop(self, argument: bool):
|
||||
"""
|
||||
Unified user specifies the enable_early_stop parameter through two methods,
|
||||
'--enable-early-stop' and '--early-stop-config'
|
||||
"""
|
||||
if self.enable_early_stop is None:
|
||||
# User only set '--enable-early-stop'
|
||||
self.enable_early_stop = argument
|
||||
else:
|
||||
# User both set '--enable-early-stop' and '--early-stop-config'
|
||||
if self.enable_early_stop is False and argument is True:
|
||||
raise ValueError(
|
||||
"Invalid parameter: Cannot set ---enable-early-stop and --early-stop-config '{\"enable_early_stop\":false}' simultaneously."
|
||||
)
|
||||
argument = self.enable_early_stop
|
||||
|
||||
|
||||
class LoadConfig:
|
||||
"""
|
||||
Configuration for dynamic weight loading strategies
|
||||
@@ -776,6 +844,7 @@ class FDConfig:
|
||||
load_config: LoadConfig = field(default=None, init=True)
|
||||
quant_config: Optional[QuantConfigBase] = None
|
||||
graph_opt_config: Optional[GraphOptimizationConfig] = None
|
||||
early_stop_config: Optional[EarlyStopConfig] = None
|
||||
decoding_config: DecodingConfig = field(default=None, init=True) # type: ignore
|
||||
cache_config: CacheConfig = field(default=None, init=True) # type: ignore
|
||||
|
||||
|
||||
Reference in New Issue
Block a user