Mirror of https://github.com/PaddlePaddle/FastDeploy.git
[Feature] Support repetition early stop (#3024)
* support repetition early stop and support user to set the parameter
* remove log
* fix codestyle
* add the early_stop_config to rollout_config
* update config and EarlyStopper class
* fix the bug for triton
* modify the stop method
* update description
* modify the usage for stop_flags

Co-authored-by: Yuanle Liu <yuanlehome@163.com>
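In short, the new stopper tracks the probability of each step's chosen token in a sliding window of size window_size and stops a sequence once every score in the window exceeds threshold, then resets the window so the next step does not immediately re-trigger. A minimal pure-Python sketch of that rule, for illustration only; the helper below is ad hoc and not part of this commit:

# Illustration only (not part of this commit): the repetition early-stop rule
# in plain Python. A sequence is stopped once the probability of its chosen
# token has exceeded `threshold` for `window_size` consecutive steps.
from collections import deque

def make_checker(window_size: int, threshold: float):
    window = deque(maxlen=window_size)

    def should_stop(chosen_token_prob: float) -> bool:
        window.append(chosen_token_prob)
        if len(window) == window_size and all(p > threshold for p in window):
            window.clear()  # reset so the next step does not re-trigger
            return True
        return False

    return should_stop

# Example: with window_size=3 and threshold=0.9, three consecutive
# high-confidence steps trigger the stop.
check = make_checker(window_size=3, threshold=0.9)
assert [check(p) for p in (0.95, 0.97, 0.99)] == [False, False, True]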
fastdeploy/model_executor/layers/sample/early_stopper.py (new file, 129 lines)
@@ -0,0 +1,129 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""

from abc import abstractmethod

import paddle

from fastdeploy.config import EarlyStopConfig


class EarlyStopper:
    @abstractmethod
    def initialize(self, batch_size: int, cfg: EarlyStopConfig):
        """
        Initialize the stopper and set its hyper-parameters.
        args:
        - batch_size: int, the batch size of the input
        - cfg: EarlyStopConfig
        """
        raise NotImplementedError

    @abstractmethod
    def process(self, probs: paddle.Tensor, next_tokens: paddle.Tensor, stop_flags: paddle.Tensor):
        """
        Run the stopper and set stop_flags to True for every batch entry that triggers early stop.
        args:
        - probs: [batch_size, vocab_size], the probability distribution of every sample
        - next_tokens: [batch_size, 1], the token index chosen for every sample
        - stop_flags: [batch_size, 1], flags marking which batch entries are stopped
        """
        raise NotImplementedError


class RepetitionEarlyStopper(EarlyStopper):
    def initialize(self, batch_size: int, cfg: EarlyStopConfig):
        self.early_stop_cfg = cfg
        self.window_size = cfg.window_size
        self.threshold = cfg.threshold
        self.trunc_scores = paddle.zeros((batch_size, self.early_stop_cfg.window_size), dtype="float32")

    def process(self, probs: paddle.Tensor, next_tokens: paddle.Tensor, stop_flags: paddle.Tensor):
        """
        args:
        - probs: [batch_size, vocab_size], the probability distribution of every sample
        - next_tokens: [batch_size, 1], the token index chosen for every sample
        - stop_flags: [batch_size, 1], flags marking which batch entries are stopped
        """
        # Use the Triton kernel when Triton is available; otherwise fall back to the normal implementation
        try:
            self.process_triton(probs, next_tokens, stop_flags)
        except Exception:
            self.process_normal(probs, next_tokens, stop_flags)

    def process_normal(self, probs: paddle.Tensor, next_tokens: paddle.Tensor, stop_flags: paddle.Tensor):
        # Get the probability score corresponding to next_tokens in this step
        next_scores = paddle.index_sample(probs, next_tokens)

        # Sliding window: shift everything left by one slot and insert the new score
        self.trunc_scores[:, :-1] = self.trunc_scores[:, 1:]
        self.trunc_scores[:, -1:] = next_scores

        # Determine which samples need to be terminated: all trunc_scores are greater than the threshold
        need_trunc_all = paddle.all(self.trunc_scores > self.threshold, axis=-1).unsqueeze(-1)

        # Set the stop flags
        stop_flags[need_trunc_all] = True

        # Reset trunc_scores of truncated samples to 0 to avoid false triggering in the next step
        reset_mask = need_trunc_all.tile([1, self.window_size])
        self.trunc_scores = paddle.where(reset_mask, paddle.zeros_like(self.trunc_scores), self.trunc_scores)

    def process_triton(self, probs: paddle.Tensor, next_tokens: paddle.Tensor, stop_flags: paddle.Tensor):
        import triton

        from fastdeploy.model_executor.ops.triton_ops import (
            repetition_early_stopper_kernel,
        )

        B, W = self.trunc_scores.shape
        V = probs.shape[1]
        BLOCK_W = triton.next_power_of_2(W)

        grid = (B,)
        repetition_early_stopper_kernel[grid](
            self.trunc_scores,
            probs,
            next_tokens,
            stop_flags,
            self.threshold,
            B,
            W,
            V,
            self.trunc_scores.shape[1],
            probs.shape[1],
            BLOCK_W=BLOCK_W,
        )
        return next_tokens


# mapping from strategy name to stopper class
EARLY_STOPPER_MAPPING = {
    "repetition": RepetitionEarlyStopper,
}


def get_early_stopper_cls_from_stragegy(strategy: str):
    """
    Get the early stopper class from the strategy name.
    args:
    - strategy: string, the strategy name
    """
    strategy = strategy.lower()
    assert (
        strategy in EARLY_STOPPER_MAPPING
    ), f"{strategy} is not supported yet, only support {EARLY_STOPPER_MAPPING.keys()}."
    return EARLY_STOPPER_MAPPING[strategy]
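A usage sketch of the new module (illustration only, not part of the diff): resolve the stopper class from its strategy name, initialize it for a batch, and drive it with toy tensors. A SimpleNamespace stands in for EarlyStopConfig, whose constructor is not shown in this diff, and process_normal is called directly so the sketch does not depend on Triton.

# Illustration only. Assumes the module above is importable, as in the sampler diff below.
from types import SimpleNamespace

import paddle

from fastdeploy.model_executor.layers.sample.early_stopper import (
    get_early_stopper_cls_from_stragegy,
)

batch_size, vocab_size = 2, 8
stopper = get_early_stopper_cls_from_stragegy("repetition")()
# Stand-in for EarlyStopConfig: initialize() only reads window_size and threshold
stopper.initialize(batch_size, SimpleNamespace(window_size=3, threshold=0.9))

stop_flags = paddle.zeros([batch_size, 1], dtype="bool")
next_tokens = paddle.zeros([batch_size, 1], dtype="int64")

# Sample token 0 with probability 1.0 for three consecutive steps;
# after the third step every row of stop_flags flips to True.
for _ in range(3):
    probs = paddle.zeros([batch_size, vocab_size], dtype="float32")
    probs[:, 0] = 1.0
    stopper.process_normal(probs, next_tokens, stop_flags)

print(stop_flags)  # [[True], [True]]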
@@ -44,5 +44,7 @@ class SamplingMetadata:
    top_k: Optional[paddle.Tensor] = None
    min_p: Optional[paddle.Tensor] = None
    max_num_logprobs: Optional[int] = None
    enable_early_stop: Optional[int] = False
    stop_flags: Optional[paddle.Tensor] = None
    prompt_ids: Optional[paddle.Tensor] = None
    prompt_lens: Optional[paddle.Tensor] = None
@@ -26,6 +26,9 @@ from fastdeploy.config import FDConfig
from fastdeploy.model_executor.guided_decoding.base_guided_decoding import (
    LogitsProcessorBase,
)
from fastdeploy.model_executor.layers.sample.early_stopper import (
    get_early_stopper_cls_from_stragegy,
)
from fastdeploy.model_executor.layers.sample.meta_data import SamplingMetadata
from fastdeploy.model_executor.layers.sample.ops import (
    apply_penalty_multi_scores,
@@ -165,7 +168,7 @@ class Sampler(nn.Layer):
    Sampler for normal generation.
    """

-    def __init__(self):
+    def __init__(self, fd_config: FDConfig = None):
        """ """
        super().__init__()
        if (
@@ -180,6 +183,15 @@ class Sampler(nn.Layer):
            raise NotImplementedError

        self.processor = SamplerProcessor()
        # Only created when fd_config.early_stop_config.enable_early_stop is True
        if (
            fd_config is not None
            and fd_config.early_stop_config is not None
            and fd_config.early_stop_config.enable_early_stop
        ):
            early_stopper_cls = get_early_stopper_cls_from_stragegy(fd_config.early_stop_config.strategy)
            self.early_stopper = early_stopper_cls()
            self.early_stopper.initialize(fd_config.parallel_config.max_num_seqs, fd_config.early_stop_config)

    def apply_logits_processor(
        self,
@@ -275,6 +287,10 @@ class Sampler(nn.Layer):
        logprobs_tensors = (
            None if num_logprobs is None else self.gather_logprobs(raw_logprobs, num_logprobs, token_ids=next_tokens)
        )
        if sampling_metadata.enable_early_stop:
            # mark the batch entries that trigger early stop in stop_flags
            assert sampling_metadata.stop_flags is not None, "need stop_flags for early stop"
            self.early_stopper.process(probs, next_tokens, sampling_metadata.stop_flags)

        self.processor.update_output_tokens(next_tokens, skip_idx_list)
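The abstract EarlyStopper plus the EARLY_STOPPER_MAPPING registry form the extension point for further strategies. A hedged sketch of how another strategy could be plugged in; the LowConfidenceEarlyStopper class and its stopping rule are invented for illustration and are not part of this commit:

# Hypothetical example, not part of FastDeploy: a stopper that flags sequences
# whose chosen-token probability falls below a floor, registered under a new name.
import paddle

from fastdeploy.model_executor.layers.sample.early_stopper import (
    EARLY_STOPPER_MAPPING,
    EarlyStopper,
    get_early_stopper_cls_from_stragegy,
)


class LowConfidenceEarlyStopper(EarlyStopper):
    def initialize(self, batch_size: int, cfg):
        # Only a threshold is needed here; a real implementation would read an EarlyStopConfig.
        self.floor = cfg.threshold

    def process(self, probs: paddle.Tensor, next_tokens: paddle.Tensor, stop_flags: paddle.Tensor):
        next_scores = paddle.index_sample(probs, next_tokens)  # [batch_size, 1]
        stop_flags[next_scores < self.floor] = True  # mark low-confidence rows


EARLY_STOPPER_MAPPING["low_confidence"] = LowConfidenceEarlyStopper
assert get_early_stopper_cls_from_stragegy("low_confidence") is LowConfidenceEarlyStopper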