[Feature] Support repetition early stop (#3024)

* Support repetition early stop and allow the user to configure its parameters

* remove log

* fix codestyle

* add the early_stop_config to rollout_config

* update config and EarlyStopper class

* fix the bug for triton

* modify the stop method

* update description

* modify the usage for stop_flags

---------

Co-authored-by: Yuanle Liu <yuanlehome@163.com>
This commit is contained in:
Zero Rains
2025-07-29 22:42:54 +08:00
committed by GitHub
parent 3214fb5393
commit b2f9a42d87
13 changed files with 575 additions and 4 deletions

View File

@@ -0,0 +1,129 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
from abc import abstractmethod
import paddle
from fastdeploy.config import EarlyStopConfig
class EarlyStopper:
    """Abstract interface for early-stop strategies applied during sampling."""

    @abstractmethod
    def initialize(self, batch_size: int, cfg: EarlyStopConfig):
        """
        Initialize the stopper and set hyper-parameters.

        args:
            - batch_size: int, the batch size of input
            - cfg: EarlyStopConfig, hyper-parameters for the strategy
        """
        raise NotImplementedError

    @abstractmethod
    def process(self, probs: paddle.Tensor, next_tokens: paddle.Tensor, stop_flags: paddle.Tensor):
        """
        Process one decoding step and set the stop_flags corresponding to the
        batch entries that trigger early stop to True (stop_flags is modified in place).

        args:
            - probs: [batch_size, vocab_size], the probs of every sample
            - next_tokens: [batch_size, 1], the token index of every chosen sample
            - stop_flags: [batch_size, 1], determine which batch will be stopped
        """
        raise NotImplementedError
class RepetitionEarlyStopper(EarlyStopper):
    """
    Early-stop strategy that halts a sample once its chosen token's probability
    has exceeded `threshold` for `window_size` consecutive decoding steps —
    a signature of the model looping on high-confidence repetitions.
    """

    def initialize(self, batch_size: int, cfg: EarlyStopConfig):
        """
        Set hyper-parameters and allocate the per-sample score window.

        args:
            - batch_size: int, the batch size of input
            - cfg: EarlyStopConfig, provides window_size and threshold
        """
        self.early_stop_cfg = cfg
        self.window_size = cfg.window_size
        self.threshold = cfg.threshold
        # Sliding window holding the last `window_size` chosen-token probabilities
        # for every sample in the batch.
        self.trunc_scores = paddle.zeros((batch_size, self.window_size), dtype="float32")

    def process(self, probs: paddle.Tensor, next_tokens: paddle.Tensor, stop_flags: paddle.Tensor):
        """
        Update the score window with this step's probabilities and flag samples
        that triggered early stop. stop_flags is modified in place.

        args:
            - probs: [batch_size, vocab_size], the probs of every sample
            - next_tokens: [batch_size, 1], the token index of every chosen sample
            - stop_flags: [batch_size, 1], determine which batch will be stopped
        """
        # Prefer the fused triton kernel; fall back to the plain paddle path when
        # triton (or the custom op) is unavailable on this platform.
        try:
            self.process_triton(probs, next_tokens, stop_flags)
        except Exception:
            self.process_normal(probs, next_tokens, stop_flags)

    def process_normal(self, probs: paddle.Tensor, next_tokens: paddle.Tensor, stop_flags: paddle.Tensor):
        """Pure-paddle implementation of `process` (reference path)."""
        # Probability score of the token actually chosen at this step.
        next_scores = paddle.index_sample(probs, next_tokens)

        # Sliding window: shift left one slot and append the new score.
        self.trunc_scores[:, :-1] = self.trunc_scores[:, 1:]
        self.trunc_scores[:, -1:] = next_scores

        # A sample stops when every score in its window exceeds the threshold.
        need_trunc_all = paddle.all(self.trunc_scores > self.threshold, axis=-1).unsqueeze(-1)
        stop_flags[need_trunc_all] = True

        # Zero the windows of stopped samples so they cannot re-trigger next step.
        reset_mask = need_trunc_all.tile([1, self.window_size])
        self.trunc_scores = paddle.where(reset_mask, paddle.zeros_like(self.trunc_scores), self.trunc_scores)

    def process_triton(self, probs: paddle.Tensor, next_tokens: paddle.Tensor, stop_flags: paddle.Tensor):
        """Triton-kernel implementation of `process`; raises if triton is unavailable."""
        import triton

        from fastdeploy.model_executor.ops.triton_ops import (
            repetition_early_stopper_kernel,
        )

        B, W = self.trunc_scores.shape
        V = probs.shape[1]
        BLOCK_W = triton.next_power_of_2(W)
        grid = (B,)
        repetition_early_stopper_kernel[grid](
            self.trunc_scores,
            probs,
            next_tokens,
            stop_flags,
            self.threshold,
            B,
            W,
            V,
            W,  # row stride of trunc_scores (== window size)
            V,  # row stride of probs (== vocab size)
            BLOCK_W=BLOCK_W,
        )
        # NOTE(fix): the original returned next_tokens here, which no caller used
        # and which was inconsistent with process()/process_normal() returning None.
# mapping strategy name to class
EARLY_STOPPER_MAPPING = {
    "repetition": RepetitionEarlyStopper,
}


def get_early_stopper_cls_from_stragegy(strategy: str):
    """
    Get the early-stopper class registered under a strategy name.

    args:
        - strategy: string, the strategy name (matched case-insensitively)
    returns:
        the EarlyStopper subclass registered for `strategy`
    raises:
        ValueError: if the strategy is not registered
    """
    strategy = strategy.lower()
    if strategy not in EARLY_STOPPER_MAPPING:
        # Raise explicitly instead of `assert` so the check survives `python -O`.
        raise ValueError(
            f"{strategy} is not supported yet, only support {list(EARLY_STOPPER_MAPPING.keys())}."
        )
    return EARLY_STOPPER_MAPPING[strategy]

View File

@@ -44,5 +44,7 @@ class SamplingMetadata:
top_k: Optional[paddle.Tensor] = None
min_p: Optional[paddle.Tensor] = None
max_num_logprobs: Optional[int] = None
enable_early_stop: Optional[int] = False
stop_flags: Optional[paddle.Tensor] = None
prompt_ids: Optional[paddle.Tensor] = None
prompt_lens: Optional[paddle.Tensor] = None

View File

@@ -26,6 +26,9 @@ from fastdeploy.config import FDConfig
from fastdeploy.model_executor.guided_decoding.base_guided_decoding import (
LogitsProcessorBase,
)
from fastdeploy.model_executor.layers.sample.early_stopper import (
get_early_stopper_cls_from_stragegy,
)
from fastdeploy.model_executor.layers.sample.meta_data import SamplingMetadata
from fastdeploy.model_executor.layers.sample.ops import (
apply_penalty_multi_scores,
@@ -165,7 +168,7 @@ class Sampler(nn.Layer):
Sampler for normal generation.
"""
def __init__(self):
def __init__(self, fd_config: FDConfig = None):
""" """
super().__init__()
if (
@@ -180,6 +183,15 @@ class Sampler(nn.Layer):
raise NotImplementedError
self.processor = SamplerProcessor()
# Can only be created when fd_config.early_stopper_config.enable_early_stop = True
if (
fd_config is not None
and fd_config.early_stop_config is not None
and fd_config.early_stop_config.enable_early_stop
):
early_stopper_cls = get_early_stopper_cls_from_stragegy(fd_config.early_stop_config.strategy)
self.early_stopper = early_stopper_cls()
self.early_stopper.initialize(fd_config.parallel_config.max_num_seqs, fd_config.early_stop_config)
def apply_logits_processor(
self,
@@ -275,6 +287,10 @@ class Sampler(nn.Layer):
logprobs_tensors = (
None if num_logprobs is None else self.gather_logprobs(raw_logprobs, num_logprobs, token_ids=next_tokens)
)
if sampling_metadata.enable_early_stop:
# will set the stop batch in stop_flags
assert sampling_metadata.stop_flags is not None, "need stop_flags for eary stop"
self.early_stopper.process(probs, next_tokens, sampling_metadata.stop_flags)
self.processor.update_output_tokens(next_tokens, skip_idx_list)