Files
FastDeploy/fastdeploy/model_executor/layers/sample/early_stopper.py
chen f0f00a6025
Some checks failed
Deploy GitHub Pages / deploy (push) Has been cancelled
[OPs] Universal optimization and Fix early_stop cuda 700 (#3375)
* delete nonzero

* delete setup_ops_base.py

* check if

* check gcp infer_seed.cpu()

* fix repetition_early_stopper_kernel cuda 700
2025-08-14 22:40:44 +08:00

130 lines
4.6 KiB
Python

"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
from abc import abstractmethod
import paddle
from fastdeploy.config import EarlyStopConfig
class EarlyStopper:
    """Interface for sampling-time early-stop strategies.

    A concrete stopper inspects each decoding step and marks, per batch
    entry, whether generation should terminate early by writing into
    ``stop_flags``.
    """

    @abstractmethod
    def initialize(self, batch_size: int, cfg: EarlyStopConfig):
        """Set up internal state and hyper-parameters.

        Args:
            batch_size: number of sequences in the batch.
            cfg: early-stop configuration.
        """
        raise NotImplementedError

    @abstractmethod
    def process(self, probs: paddle.Tensor, next_tokens: paddle.Tensor, stop_flags: paddle.Tensor):
        """Examine one decoding step and flag batches that should stop early.

        Args:
            probs: [batch_size, vocab_size], probabilities of every sample.
            next_tokens: [batch_size, 1], chosen token index of every sample.
            stop_flags: [batch_size, 1], in/out tensor; entries that trigger
                early stop are set to True.
        """
        raise NotImplementedError
class RepetitionEarlyStopper(EarlyStopper):
    """Early stopper that halts a batch entry when the sampled token's
    probability exceeds ``threshold`` for ``window_size`` consecutive steps —
    a heuristic signal of degenerate, highly repetitive decoding.
    """

    def initialize(self, batch_size: int, cfg: EarlyStopConfig):
        """Allocate per-batch sliding-window state.

        Args:
            batch_size: number of sequences tracked.
            cfg: EarlyStopConfig supplying `window_size` and `threshold`.
        """
        self.early_stop_cfg = cfg
        self.window_size = cfg.window_size
        self.threshold = cfg.threshold
        # Rolling buffer of the last `window_size` chosen-token probabilities,
        # one row per batch entry. Zero-initialized, so an unfilled window can
        # never have all entries above a positive threshold.
        self.trunc_scores = paddle.zeros((batch_size, self.early_stop_cfg.window_size), dtype="float32")

    def process(self, probs: paddle.Tensor, next_tokens: paddle.Tensor, stop_flags: paddle.Tensor):
        """Push this step's scores into the window and set stop flags.

        Args:
            probs: [batch_size, vocab_size], the probs of every sample.
            next_tokens: [batch_size, 1], the token index of every chosen sample.
            stop_flags: [batch_size, 1], determine which batch will be stopped.
        """
        # Prefer the fused Triton kernel; fall back to plain paddle ops when
        # Triton is unavailable or the launch fails.
        # NOTE(review): the broad `except Exception` also silently hides
        # genuine kernel bugs — consider logging before falling back.
        try:
            self.process_triton(probs, next_tokens, stop_flags)
        except Exception:
            self.process_normal(probs, next_tokens, stop_flags)

    def process_normal(self, probs: paddle.Tensor, next_tokens: paddle.Tensor, stop_flags: paddle.Tensor):
        """Pure-paddle reference implementation of `process` (same contract)."""
        # Get the probability score corresponding to next_tokens in this step
        next_scores = paddle.index_sample(probs, next_tokens)
        # Sliding window: shift every row left by one slot and insert the new
        # score in the last column.
        self.trunc_scores[:, :-1] = self.trunc_scores[:, 1:]
        self.trunc_scores[:, -1:] = next_scores
        # A sample is truncated only when ALL of its windowed scores exceed
        # the threshold (i.e. the window is full of high-confidence repeats).
        need_trunc_all = paddle.all(self.trunc_scores > self.threshold, axis=-1).unsqueeze(-1)
        # Mark those samples as stopped (in-place boolean-mask assignment).
        stop_flags[need_trunc_all] = True
        # Reset trunc_scores of truncated samples to 0 to avoid false
        # triggering again on the very next step.
        reset_mask = need_trunc_all.tile([1, self.window_size])
        self.trunc_scores = paddle.where(reset_mask, paddle.zeros_like(self.trunc_scores), self.trunc_scores)

    def process_triton(self, probs: paddle.Tensor, next_tokens: paddle.Tensor, stop_flags: paddle.Tensor):
        """Triton-kernel implementation of `process`; mutates tensors in place."""
        # Imported lazily so environments without Triton still work via the
        # `process_normal` fallback in `process`.
        import triton
        from fastdeploy.model_executor.ops.triton_ops import (
            repetition_early_stopper_kernel,
        )

        B, W = self.trunc_scores.shape
        real_bsz, V = probs.shape
        # Window is handled in one block; Triton block sizes must be powers of two.
        BLOCK_W = triton.next_power_of_2(W)
        # One program instance per (real) batch row.
        grid = (real_bsz,)
        repetition_early_stopper_kernel[grid](
            self.trunc_scores,
            probs,
            next_tokens,
            stop_flags,
            self.threshold,
            B,
            W,
            V,
            self.trunc_scores.shape[1],  # presumably row stride of trunc_scores — confirm against kernel signature
            probs.shape[1],  # presumably row stride of probs — confirm against kernel signature
            BLOCK_W=BLOCK_W,
        )
        # NOTE(review): return value is ignored by `process`; tensors are
        # updated in place by the kernel.
        return next_tokens
# Registry mapping a lower-case strategy name to its EarlyStopper subclass;
# consumed by get_early_stopper_cls_from_stragegy below.
EARLY_STOPPER_MAPPING = {
    "repetition": RepetitionEarlyStopper,
}
def get_early_stopper_cls_from_stragegy(strategy: str):
    """Look up the early-stopper class registered for a strategy name.

    NOTE: the function name contains a historical typo ("stragegy"); it is
    kept as-is for backward compatibility with existing callers.

    Args:
        strategy: case-insensitive strategy name, e.g. "repetition".

    Returns:
        The EarlyStopper subclass registered in EARLY_STOPPER_MAPPING.

    Raises:
        ValueError: if the strategy is not registered.
    """
    strategy = strategy.lower()
    # Raise explicitly instead of `assert` so the validation survives
    # `python -O` (asserts are stripped under optimization).
    if strategy not in EARLY_STOPPER_MAPPING:
        raise ValueError(f"{strategy} is not supported yet, only support {list(EARLY_STOPPER_MAPPING)}.")
    return EARLY_STOPPER_MAPPING[strategy]