[Feature] Support repetition early stop (#3024)
* support repetition early stop and support user to set the parameter
* remove log
* fix codestyle
* add the early_stop_config to rollout_config
* update config and EarlyStopper class
* fix the bug for triton
* modify the stop method
* update description
* modify the usage for stop_flags

Co-authored-by: Yuanle Liu <yuanlehome@163.com>
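For orientation before the diff, a minimal sketch of how the new knob is driven end to end. The module path is an assumption for illustration (this page does not show the file being patched); the diff itself only defines the EngineArgs fields, flags, and wiring shown below:

from fastdeploy.engine.args_utils import EngineArgs  # assumed module path

args = EngineArgs(
    model="./your-model",  # placeholder path
    enable_early_stop=True,
)
engine_config = args.create_engine_config()  # wires the EarlyStopConfig into Config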
@@ -21,6 +21,7 @@ from typing import Any, Dict, List, Optional
 from fastdeploy.config import (
     CacheConfig,
+    EarlyStopConfig,
     GraphOptimizationConfig,
     LoadConfig,
     SpeculativeConfig,
@@ -313,6 +314,16 @@ class EngineArgs:
     Must be explicitly enabled via the `--enable-logprob` startup parameter to output logprob values.
     """
 
+    enable_early_stop: bool = False
+    """
+    Flag to enable early stop. Default is False (disabled).
+    """
+
+    early_stop_config: Optional[Dict[str, Any]] = None
+    """
+    Configuration for early stop.
+    """
+
     def __post_init__(self):
         """
         Post-initialization processing to set default tokenizer if not provided.
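The early_stop_config field is a plain dict whose accepted keys are defined by EarlyStopConfig, which this diff does not show. As a hedged illustration only, a repetition-oriented payload might look like the following; every key here is an assumption, not confirmed by the patch:

early_stop_config = {
    "strategy": "repetition",  # hypothetical key: which stopper to use
    "window_size": 3000,       # hypothetical key: consecutive steps to watch
    "threshold": 0.99,         # hypothetical key: per-token probability cutoff
}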
@@ -464,6 +475,18 @@ class EngineArgs:
             default=EngineArgs.enable_logprob,
             help="Enable output of token-level log probabilities.",
         )
+        model_group.add_argument(
+            "--enable-early-stop",
+            action="store_true",
+            default=EngineArgs.enable_early_stop,
+            help="Enable early stopping during generation.",
+        )
+        model_group.add_argument(
+            "--early-stop-config",
+            type=json.loads,
+            default=EngineArgs.early_stop_config,
+            help="The config for early stop (a JSON string).",
+        )
 
         # Parallel processing parameters group
         parallel_group = parser.add_argument_group("Parallel Configuration")
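Because the flag is declared with type=json.loads, argparse hands the raw shell string straight to the JSON parser. A self-contained replay of that round trip (the key is again hypothetical):

import argparse
import json

parser = argparse.ArgumentParser()
parser.add_argument("--early-stop-config", type=json.loads, default=None)

# Equivalent to passing --early-stop-config '{"window_size": 3000}' on the shell
ns = parser.parse_args(["--early-stop-config", '{"window_size": 3000}'])
assert ns.early_stop_config == {"window_size": 3000}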
@@ -811,6 +834,16 @@ class EngineArgs:
                 graph_optimization_args[k] = v
         return GraphOptimizationConfig(graph_optimization_args)
 
+    def create_early_stop_config(self) -> EarlyStopConfig:
+        """
+        Create and return an EarlyStopConfig object based on the current settings.
+        """
+        early_stop_args = asdict(self)
+        if self.early_stop_config is not None:
+            for k, v in self.early_stop_config.items():
+                early_stop_args[k] = v
+        return EarlyStopConfig(early_stop_args)
 
     def create_engine_config(self) -> Config:
         """
         Create and return a Config object based on the current settings.
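create_early_stop_config dumps every EngineArgs field via asdict and then lets explicit early_stop_config entries win. A self-contained replay of that merge, with a stand-in dataclass instead of the real EngineArgs:

from dataclasses import asdict, dataclass
from typing import Any, Dict, Optional

@dataclass
class _ArgsStandIn:  # illustration only, not the real EngineArgs
    enable_early_stop: bool = False
    early_stop_config: Optional[Dict[str, Any]] = None

args = _ArgsStandIn(early_stop_config={"enable_early_stop": True})
merged = asdict(args)
if args.early_stop_config is not None:
    for k, v in args.early_stop_config.items():
        merged[k] = v  # dict entry overrides the dataclass default
assert merged["enable_early_stop"] is True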
@@ -833,6 +866,9 @@ class EngineArgs:
         graph_opt_cfg = self.create_graph_optimization_config()
         graph_opt_cfg.update_use_cudagraph(self.use_cudagraph)
 
+        early_stop_cfg = self.create_early_stop_config()
+        early_stop_cfg.update_enable_early_stop(self.enable_early_stop)
+
         assert not (
             self.tensor_parallel_size <= 1 and self.enable_custom_all_reduce
         ), "enable_custom_all_reduce must be used with tensor_parallel_size>1"
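The commit message mentions an EarlyStopper class, but its internals are not part of this diff. As a hedged sketch of the repetition strategy the feature name suggests (window/threshold semantics assumed, not confirmed here): stop once the sampled token has been emitted with probability above a threshold for a full window of consecutive steps, a typical signature of a degenerate repetition loop.

class RepetitionStopperSketch:
    """Illustrative only, not FastDeploy's implementation: flags a sequence
    for early stop once the sampled token's probability exceeds `threshold`
    for `window_size` consecutive decode steps."""

    def __init__(self, window_size: int = 3000, threshold: float = 0.99):
        self.window_size = window_size
        self.threshold = threshold
        self.high_conf_steps = 0  # consecutive steps above threshold so far

    def should_stop(self, sampled_token_prob: float) -> bool:
        if sampled_token_prob > self.threshold:
            self.high_conf_steps += 1
        else:
            self.high_conf_steps = 0  # any low-confidence step resets the run
        return self.high_conf_steps >= self.window_size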
@@ -866,4 +902,5 @@ class EngineArgs:
             guided_decoding_backend=self.guided_decoding_backend,
             disable_any_whitespace=self.guided_decoding_disable_any_whitespace,
             enable_logprob=self.enable_logprob,
+            early_stop_config=early_stop_cfg,
         )