[Feature] Support repetition early stop (#3024)

* support repetition early stop and support user to set the parameter

* remove log

* fix codestyle

* add the early_stop_config to rollout_config

* update config and EarlyStopper class

* fix the bug for triton

* modify the stop method

* update description

* modify the usage for stop_flags

---------

Co-authored-by: Yuanle Liu <yuanlehome@163.com>
This commit is contained in:
Zero Rains
2025-07-29 22:42:54 +08:00
committed by GitHub
parent 3214fb5393
commit b2f9a42d87
13 changed files with 575 additions and 4 deletions

View File

@@ -26,6 +26,9 @@ from fastdeploy.config import FDConfig
from fastdeploy.model_executor.guided_decoding.base_guided_decoding import (
LogitsProcessorBase,
)
from fastdeploy.model_executor.layers.sample.early_stopper import (
get_early_stopper_cls_from_stragegy,
)
from fastdeploy.model_executor.layers.sample.meta_data import SamplingMetadata
from fastdeploy.model_executor.layers.sample.ops import (
apply_penalty_multi_scores,
@@ -165,7 +168,7 @@ class Sampler(nn.Layer):
Sampler for normal generation.
"""
def __init__(self):
def __init__(self, fd_config: FDConfig = None):
""" """
super().__init__()
if (
@@ -180,6 +183,15 @@ class Sampler(nn.Layer):
raise NotImplementedError
self.processor = SamplerProcessor()
# Can only be created when fd_config.early_stopper_config.enable_early_stop = True
if (
fd_config is not None
and fd_config.early_stop_config is not None
and fd_config.early_stop_config.enable_early_stop
):
early_stopper_cls = get_early_stopper_cls_from_stragegy(fd_config.early_stop_config.strategy)
self.early_stopper = early_stopper_cls()
self.early_stopper.initialize(fd_config.parallel_config.max_num_seqs, fd_config.early_stop_config)
def apply_logits_processor(
self,
@@ -275,6 +287,10 @@ class Sampler(nn.Layer):
logprobs_tensors = (
None if num_logprobs is None else self.gather_logprobs(raw_logprobs, num_logprobs, token_ids=next_tokens)
)
if sampling_metadata.enable_early_stop:
# will set the stop batch in stop_flags
assert sampling_metadata.stop_flags is not None, "need stop_flags for eary stop"
self.early_stopper.process(probs, next_tokens, sampling_metadata.stop_flags)
self.processor.update_output_tokens(next_tokens, skip_idx_list)