mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-09-27 04:46:16 +08:00

Some checks failed
Deploy GitHub Pages / deploy (push) Has been cancelled
* delete nonzero * delete setup_ops_base.py * check if * check gcp infer_seed.cpu() * fix repetition_early_stopper_kernel cuda 700
130 lines
4.6 KiB
Python
130 lines
4.6 KiB
Python
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
|
|
|
|
from abc import ABC, abstractmethod

import paddle

from fastdeploy.config import EarlyStopConfig
|
|
|
|
|
|
class EarlyStopper(ABC):
    """
    Interface that every early-stop strategy implements.

    A stopper is configured through ``initialize`` and then driven through
    ``process``. The class derives from ``ABC`` so that the ``@abstractmethod``
    markers are actually enforced: the base class, or a subclass that does not
    override both methods, cannot be instantiated.
    """

    @abstractmethod
    def initialize(self, batch_size: int, cfg: EarlyStopConfig):
        """
        Initialize the stopper and set hyper-parameters.
        args:
            - batch_size: int, the batch size of input
            - cfg: EarlyStopConfig
        """
        raise NotImplementedError

    @abstractmethod
    def process(self, probs: paddle.Tensor, next_tokens: paddle.Tensor, stop_flags: paddle.Tensor):
        """
        Process one step and set the stop_flags entries of the samples that trigger early stop to True.
        args:
            - probs: [batch_size, vocab_size], the probs of every sample
            - next_tokens: [batch_size, 1], the token index of every chosen sample
            - stop_flags: [batch_size, 1], determine which batch will be stopped
        """
        raise NotImplementedError
|
|
|
|
|
|
class RepetitionEarlyStopper(EarlyStopper):
    """
    Early stopper that fires when the probability of the chosen token stays
    above ``threshold`` for ``window_size`` consecutive steps.

    ``trunc_scores`` keeps, for every sample, a sliding window holding the
    probability of the token chosen at each of the most recent steps.
    """

    def initialize(self, batch_size: int, cfg: EarlyStopConfig):
        """
        Store the hyper-parameters and allocate the zero-filled score window.
        args:
            - batch_size: int, number of rows allocated in the score window
            - cfg: EarlyStopConfig, provides window_size and threshold
        """
        self.early_stop_cfg = cfg
        self.window_size = cfg.window_size
        self.threshold = cfg.threshold
        self.trunc_scores = paddle.zeros((batch_size, self.early_stop_cfg.window_size), dtype="float32")

    def process(self, probs: paddle.Tensor, next_tokens: paddle.Tensor, stop_flags: paddle.Tensor):
        """
        Update the score window with this step's tokens and flag the samples to stop.
        args:
            - probs: [batch_size, vocab_size], the probs of every sample
            - next_tokens: [batch_size, 1], the token index of every chosen sample
            - stop_flags: [batch_size, 1], determine which batch will be stopped
        """
        # It will use normal execute if there is no triton support, otherwise use triton.
        # The catch is deliberately broad: any failure of the triton path (including a
        # failing import) falls back to the plain paddle implementation.
        try:
            self.process_triton(probs, next_tokens, stop_flags)
        except Exception:
            self.process_normal(probs, next_tokens, stop_flags)

    def process_normal(self, probs: paddle.Tensor, next_tokens: paddle.Tensor, stop_flags: paddle.Tensor):
        """
        Plain paddle implementation of the early-stop update.

        Only the first ``probs.shape[0]`` rows are touched, matching
        ``process_triton`` which launches one program per row of ``probs``;
        the preallocated window may hold more rows than the current batch.
        """
        real_bsz = probs.shape[0]

        # Get the probability score corresponding to next_tokens in this step
        next_scores = paddle.index_sample(probs, next_tokens)

        # Sliding window: Move left one grid and insert new score
        self.trunc_scores[:real_bsz, :-1] = self.trunc_scores[:real_bsz, 1:]
        self.trunc_scores[:real_bsz, -1:] = next_scores

        # Determine which samples need to be terminated: all trunc_scores are greater than threshold
        window = self.trunc_scores[:real_bsz]
        need_trunc_all = paddle.all(window > self.threshold, axis=-1).unsqueeze(-1)

        # Add the stop flags in place, leaving flags that are already set untouched.
        # NOTE(review): assumes stop_flags is a boolean tensor with at least real_bsz rows - confirm.
        stop_flags[:real_bsz] = paddle.logical_or(stop_flags[:real_bsz], need_trunc_all)

        # Reset trunc_scores of truncated samples to 0 to avoid false triggering in the next step
        reset_mask = need_trunc_all.tile([1, self.window_size])
        self.trunc_scores[:real_bsz] = paddle.where(reset_mask, paddle.zeros_like(window), window)

    def process_triton(self, probs: paddle.Tensor, next_tokens: paddle.Tensor, stop_flags: paddle.Tensor):
        """
        Triton implementation: launches ``repetition_early_stopper_kernel`` with
        one program per row of ``probs`` and returns ``next_tokens``.

        NOTE(review): the kernel body is not visible here; it is presumably the
        fused equivalent of ``process_normal`` - confirm against triton_ops.
        """
        # Imported lazily so that a missing triton installation only raises here,
        # where ``process`` can catch it and fall back.
        import triton

        from fastdeploy.model_executor.ops.triton_ops import (
            repetition_early_stopper_kernel,
        )

        B, W = self.trunc_scores.shape
        real_bsz, V = probs.shape
        BLOCK_W = triton.next_power_of_2(W)

        grid = (real_bsz,)
        repetition_early_stopper_kernel[grid](
            self.trunc_scores,
            probs,
            next_tokens,
            stop_flags,
            self.threshold,
            B,
            W,
            V,
            self.trunc_scores.shape[1],
            probs.shape[1],
            BLOCK_W=BLOCK_W,
        )
        return next_tokens
|
|
|
|
|
|
# Registry of the available early-stop strategies: strategy name -> stopper class
EARLY_STOPPER_MAPPING = {"repetition": RepetitionEarlyStopper}
|
|
|
|
|
|
def get_early_stopper_cls_from_stragegy(strategy: str):
    """
    get early stopper class from strategy name
    args:
        - strategy: string, the strategy name (matched case-insensitively)
    returns:
        the class registered for the strategy in EARLY_STOPPER_MAPPING
    raises:
        AssertionError: if the strategy is not registered
    """
    strategy = strategy.lower()
    # Raised explicitly instead of using an `assert` statement: asserts are removed
    # under `python -O`, which would turn an unknown strategy into a bare KeyError.
    # The exception type and message are the same as the assert produced.
    if strategy not in EARLY_STOPPER_MAPPING:
        raise AssertionError(
            f"{strategy} is not supported yet, only support {EARLY_STOPPER_MAPPING.keys()}."
        )
    return EARLY_STOPPER_MAPPING[strategy]
|