[Feature] Support repetition early stop (#3024)

* support repetition early stop and allow users to set its parameters

* remove log

* fix codestyle

* add the early_stop_config to rollout_config

* update config and EarlyStopper class

* fix the bug for triton

* modify the stop method

* update description

* modify the usage for stop_flags

---------

Co-authored-by: Yuanle Liu <yuanlehome@163.com>
Zero Rains authored on 2025-07-29 22:42:54 +08:00, committed by GitHub
parent 3214fb5393
commit b2f9a42d87
13 changed files with 575 additions and 4 deletions


@@ -567,6 +567,74 @@ class GraphOptimizationConfig:
argument = self.use_cudagraph
class EarlyStopConfig:
def __init__(
self,
args,
):
"""
Early stop configuration class.
Attributes:
window_size: size of the sliding verification window
threshold: trigger early stop once every probability in the window exceeds this threshold
"""
"""whether to enable early stop"""
self.enable_early_stop: bool = False
"""strategy for early stop; supported strategies are ['repetition']"""
self.strategy: str = "repetition"
"""the maximum length of the verification window for early stop"""
self.window_size: int = 3000
"""the probability threshold for early stop"""
self.threshold: float = 0.99
if args is not None:
for key, value in args.items():
if hasattr(self, key):
setattr(self, key, value)
self.check_legality_parameters()
def to_json_string(self):
"""
Convert early_stop_config to json string.
"""
return json.dumps({key: value for key, value in self.__dict__.items()})
def __str__(self) -> str:
return self.to_json_string()
def check_legality_parameters(
self,
) -> None:
"""Check the legality of parameters passed in from the command line"""
if self.enable_early_stop is not None:
assert isinstance(
self.enable_early_stop, bool
), "In early stop config, type of enable_early_stop must is bool."
if self.window_size is not None:
assert isinstance(self.window_size, int), "In early stop config, type of window_size must be int."
assert self.window_size > 0, "window_size must large than 0"
if self.threshold is not None:
assert isinstance(self.threshold, float), "In early stop config, type of threshold must be float."
assert self.threshold >= 0 and self.threshold <= 1, "threshold must between 0 and 1"
def update_enable_early_stop(self, argument: bool):
"""
Unified user specifies the enable_early_stop parameter through two methods,
'--enable-early-stop' and '--early-stop-config'
"""
if self.enable_early_stop is None:
# User only set '--enable-early-stop'
self.enable_early_stop = argument
else:
# User both set '--enable-early-stop' and '--early-stop-config'
if self.enable_early_stop is False and argument is True:
raise ValueError(
"Invalid parameter: Cannot set ---enable-early-stop and --early-stop-config '{\"enable_early_stop\":false}' simultaneously."
)
argument = self.enable_early_stop
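A minimal usage sketch of the constructor above (values illustrative): the config accepts a plain dict, and unknown keys are ignored because only attributes that already exist are overridden.

cfg = EarlyStopConfig({"enable_early_stop": True, "window_size": 3000, "threshold": 0.99})
print(cfg)  # JSON string of all fields, e.g. {"enable_early_stop": true, "strategy": "repetition", ...}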
class LoadConfig:
"""
Configuration for dynamic weight loading strategies
@@ -776,6 +844,7 @@ class FDConfig:
load_config: LoadConfig = field(default=None, init=True)
quant_config: Optional[QuantConfigBase] = None
graph_opt_config: Optional[GraphOptimizationConfig] = None
early_stop_config: Optional[EarlyStopConfig] = None
decoding_config: DecodingConfig = field(default=None, init=True) # type: ignore
cache_config: CacheConfig = field(default=None, init=True) # type: ignore


@@ -21,6 +21,7 @@ from typing import Any, Dict, List, Optional
from fastdeploy.config import (
CacheConfig,
EarlyStopConfig,
GraphOptimizationConfig,
LoadConfig,
SpeculativeConfig,
@@ -313,6 +314,16 @@ class EngineArgs:
Must be explicitly enabled via the `--enable-logprob` startup parameter to output logprob values.
"""
enable_early_stop: bool = False
"""
Flag to enable early stop. Default is False (disabled).
"""
early_stop_config: Optional[Dict[str, Any]] = None
"""
Configuration for early stop.
"""
def __post_init__(self):
"""
Post-initialization processing to set default tokenizer if not provided.
@@ -464,6 +475,18 @@ class EngineArgs:
default=EngineArgs.enable_logprob,
help="Enable output of token-level log probabilities.",
)
model_group.add_argument(
"--enable-early-stop",
action="store_true",
default=EngineArgs.enable_early_stop,
help="Enable early stopping during generation.",
)
model_group.add_argument(
"--early-stop-config",
type=json.loads,
default=EngineArgs.early_stop_config,
help="the config for early stop.",
)
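With both arguments registered, enabling the feature from the command line looks like this (only the flags come from this change; the rest of the launch command is omitted):

... --enable-early-stop --early-stop-config '{"window_size": 3000, "threshold": 0.99}'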
# Parallel processing parameters group
parallel_group = parser.add_argument_group("Parallel Configuration")
@@ -811,6 +834,16 @@ class EngineArgs:
graph_optimization_args[k] = v
return GraphOptimizationConfig(graph_optimization_args)
def create_early_stop_config(self) -> EarlyStopConfig:
"""
Create and return an EarlyStopConfig object based on the current settings.
"""
early_stop_args = asdict(self)
if self.early_stop_config is not None:
for k, v in self.early_stop_config.items():
early_stop_args[k] = v
return EarlyStopConfig(early_stop_args)
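Note that asdict(self) forwards every EngineArgs field; this is safe because EarlyStopConfig applies only the keys it already defines, and entries from '--early-stop-config' override the flat fields (so --early-stop-config '{"threshold": 0.95}' yields threshold 0.95 with the remaining defaults).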
def create_engine_config(self) -> Config:
"""
Create and return a Config object based on the current settings.
@@ -833,6 +866,9 @@ class EngineArgs:
graph_opt_cfg = self.create_graph_optimization_config()
graph_opt_cfg.update_use_cudagraph(self.use_cudagraph)
early_stop_cfg = self.create_early_stop_config()
early_stop_cfg.update_enable_early_stop(self.enable_early_stop)
assert not (
self.tensor_parallel_size <= 1 and self.enable_custom_all_reduce
), "enable_custom_all_reduce must be used with tensor_parallel_size>1"
@@ -866,4 +902,5 @@ class EngineArgs:
guided_decoding_backend=self.guided_decoding_backend,
disable_any_whitespace=self.guided_decoding_disable_any_whitespace,
enable_logprob=self.enable_logprob,
early_stop_config=early_stop_cfg,
)


@@ -182,6 +182,7 @@ class Config:
guided_decoding_backend: Optional[str] = None,
disable_any_whitespace: bool = False,
enable_logprob: bool = False,
early_stop_config: Optional[Dict[str, Any]] = None,
):
"""
Initialize the Config class.
@@ -210,6 +211,8 @@ class Config:
guided_decoding_backend(str): Guided decoding backend. Default is None.
disable_any_whitespace(bool): Disable any whitespace when using guided decoding.
Default is False.
enable_logprob(bool): Enable logprob. Default is False.
early_stop_config (Optional[Dict[str, Any]]): Early stop configuration. Default is None.
"""
self.model_config = model_config
self.cache_config = cache_config
@@ -255,6 +258,7 @@ class Config:
self.long_prefill_token_threshold = long_prefill_token_threshold
self.reasoning_parser = reasoning_parser
self.graph_optimization_config = graph_optimization_config
self.early_stop_config = early_stop_config
self.guided_decoding_backend = guided_decoding_backend
self.disable_any_whitespace = disable_any_whitespace
self._str_to_list("innode_prefill_ports", int)


@@ -1085,6 +1085,7 @@ class LLMEngine:
f" --graph_optimization_config '{self.cfg.graph_optimization_config.to_json_string()}'"
f" --guided_decoding_backend {self.cfg.guided_decoding_backend}"
f" --load_strategy {self.cfg.load_config.load_strategy}"
f" --early_stop_config '{self.cfg.early_stop_config.to_json_string()}'"
)
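With the defaults shown earlier, the serialized flag expands to:

--early_stop_config '{"enable_early_stop": false, "strategy": "repetition", "window_size": 3000, "threshold": 0.99}'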
worker_append_flag = {


@@ -0,0 +1,129 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
from abc import abstractmethod
import paddle
from fastdeploy.config import EarlyStopConfig
class EarlyStopper:
@abstractmethod
def initialize(self, batch_size: int, cfg: EarlyStopConfig):
"""
Initialize the stopper and set hyper-parameters.
args:
- batch_size: int, the batch size of input
- cfg: EarlyStopConfig
"""
raise NotImplementedError
@abstractmethod
def process(self, probs: paddle.Tensor, next_tokens: paddle.Tensor, stop_flags: paddle.Tensor):
"""
process one decoding step and set stop_flags to True for every batch entry that triggers early stop
args:
- probs: [batch_size, vocab_size], the probs of every sample
- next_tokens: [batch_size, 1], the token index of every chosen sample
- stop_flags: [batch_size, 1], determine which batch will be stopped
"""
raise NotImplementedError
class RepetitionEarlyStopper(EarlyStopper):
def initialize(self, batch_size: int, cfg: EarlyStopConfig):
self.early_stop_cfg = cfg
self.window_size = cfg.window_size
self.threshold = cfg.threshold
self.trunc_scores = paddle.zeros((batch_size, self.early_stop_cfg.window_size), dtype="float32")
def process(self, probs: paddle.Tensor, next_tokens: paddle.Tensor, stop_flags: paddle.Tensor):
"""
args:
- probs: [batch_size, vocab_size], the probs of every sample
- next_tokens: [batch_size, 1], the token index of every chosen sample
- stop_flags: [batch_size, 1], determine which batch will be stopped
"""
# Prefer the Triton kernel; fall back to the normal implementation if Triton is unavailable
try:
self.process_triton(probs, next_tokens, stop_flags)
except Exception:
self.process_normal(probs, next_tokens, stop_flags)
def process_normal(self, probs: paddle.Tensor, next_tokens: paddle.Tensor, stop_flags: paddle.Tensor):
# Get the probability score corresponding to next_tokens in this step
next_scores = paddle.index_sample(probs, next_tokens)
# Sliding window: Move left one grid and insert new score
self.trunc_scores[:, :-1] = self.trunc_scores[:, 1:]
self.trunc_scores[:, -1:] = next_scores
# Determine which samples need to be terminated: all trunc_scores are greater than threshold
need_trunc_all = paddle.all(self.trunc_scores > self.threshold, axis=-1).unsqueeze(-1)
# Add the stop flags
stop_flags[need_trunc_all] = True
# Reset trunc_scores of truncated samples to 0 to avoid false triggering in the next step
reset_mask = need_trunc_all.tile([1, self.window_size])
self.trunc_scores = paddle.where(reset_mask, paddle.zeros_like(self.trunc_scores), self.trunc_scores)
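# Worked sketch of the trigger (illustrative numbers, window_size=3, threshold=0.9):
#   step 1: trunc_scores row = [0.00, 0.00, 0.95]
#   step 2: trunc_scores row = [0.00, 0.95, 0.96]
#   step 3: trunc_scores row = [0.95, 0.96, 0.97] -> all > 0.9: stop flag set, row reset to zeros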
def process_triton(self, probs: paddle.Tensor, next_tokens: paddle.Tensor, stop_flags: paddle.Tensor):
import triton
from fastdeploy.model_executor.ops.triton_ops import (
repetition_early_stopper_kernel,
)
B, W = self.trunc_scores.shape
V = probs.shape[1]
BLOCK_W = triton.next_power_of_2(W)
grid = (B,)
repetition_early_stopper_kernel[grid](
self.trunc_scores,
probs,
next_tokens,
stop_flags,
self.threshold,
B,
W,
V,
self.trunc_scores.shape[1],
probs.shape[1],
BLOCK_W=BLOCK_W,
)
return next_tokens
# mapping strategy name to class
EARLY_STOPPER_MAPPING = {
"repetition": RepetitionEarlyStopper,
}
def get_early_stopper_cls_from_strategy(strategy: str):
"""
get the early stopper class from a strategy name
args:
- strategy: string, the strategy name
"""
strategy = strategy.lower()
assert (
strategy in EARLY_STOPPER_MAPPING
), f"{strategy} is not supported yet, only support {list(EARLY_STOPPER_MAPPING.keys())}."
return EARLY_STOPPER_MAPPING[strategy]
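A short usage sketch tying the pieces together (batch size illustrative):

cfg = EarlyStopConfig({"enable_early_stop": True, "window_size": 4, "threshold": 0.9})
stopper = get_early_stopper_cls_from_strategy(cfg.strategy)()
stopper.initialize(batch_size=8, cfg=cfg)
# per decoding step: stopper.process(probs, next_tokens, stop_flags)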


@@ -44,5 +44,7 @@ class SamplingMetadata:
top_k: Optional[paddle.Tensor] = None
min_p: Optional[paddle.Tensor] = None
max_num_logprobs: Optional[int] = None
enable_early_stop: Optional[bool] = False
stop_flags: Optional[paddle.Tensor] = None
prompt_ids: Optional[paddle.Tensor] = None
prompt_lens: Optional[paddle.Tensor] = None


@@ -26,6 +26,9 @@ from fastdeploy.config import FDConfig
from fastdeploy.model_executor.guided_decoding.base_guided_decoding import (
LogitsProcessorBase,
)
from fastdeploy.model_executor.layers.sample.early_stopper import (
get_early_stopper_cls_from_strategy,
)
from fastdeploy.model_executor.layers.sample.meta_data import SamplingMetadata
from fastdeploy.model_executor.layers.sample.ops import (
apply_penalty_multi_scores,
@@ -165,7 +168,7 @@ class Sampler(nn.Layer):
Sampler for normal generation.
"""
def __init__(self, fd_config: FDConfig = None):
""" """
super().__init__()
if (
@@ -180,6 +183,15 @@ class Sampler(nn.Layer):
raise NotImplementedError
self.processor = SamplerProcessor()
# Created only when fd_config.early_stop_config.enable_early_stop is True
if (
fd_config is not None
and fd_config.early_stop_config is not None
and fd_config.early_stop_config.enable_early_stop
):
early_stopper_cls = get_early_stopper_cls_from_strategy(fd_config.early_stop_config.strategy)
self.early_stopper = early_stopper_cls()
self.early_stopper.initialize(fd_config.parallel_config.max_num_seqs, fd_config.early_stop_config)
def apply_logits_processor(
self,
@@ -275,6 +287,10 @@ class Sampler(nn.Layer):
logprobs_tensors = (
None if num_logprobs is None else self.gather_logprobs(raw_logprobs, num_logprobs, token_ids=next_tokens)
)
if sampling_metadata.enable_early_stop:
# set stop_flags for the batches that trigger early stop
assert sampling_metadata.stop_flags is not None, "need stop_flags for early stop"
self.early_stopper.process(probs, next_tokens, sampling_metadata.stop_flags)
self.processor.update_output_tokens(next_tokens, skip_idx_list)


@@ -15,9 +15,10 @@
"""
try:
from .repetition_early_stop_kernel import repetition_early_stopper_kernel
from .wint2_fused_moe import fused_moe_wint2_triton
from .wint2_fused_moe_kernel import moe_wint2_ffn_kernel
__all__ = ["fused_moe_wint2_triton", "moe_wint2_ffn_kernel"]
__all__ = ["fused_moe_wint2_triton", "moe_wint2_ffn_kernel", "repetition_early_stopper_kernel"]
except:
pass


@@ -0,0 +1,63 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
import triton
import triton.language as tl
@triton.jit
def repetition_early_stopper_kernel(
trunc_ptr, # float32[B, W]
probs_ptr, # float32[B, V]
next_tokens_ptr, # int32[B]
stop_flags, # bool[B]
threshold,
B, # batch size
W, # windows size
V, # vocab size
stride_bw,
stride_bv,
BLOCK_W: tl.constexpr,
):
b = tl.program_id(0)
w_offs = tl.arange(0, BLOCK_W)
# current ptr
trunc_row = trunc_ptr + b * stride_bw
probs_row = probs_ptr + b * stride_bv
# step1: use index_sample to get next_score
next_token = tl.load(next_tokens_ptr + b)
next_score = tl.load(probs_row + next_token)
# step2: shift the window left: positions w = 0 ~ W-2 take the values from w = 1 ~ W-1
mask = w_offs < W - 1
val = tl.load(trunc_row + w_offs + 1, mask=mask)
tl.store(trunc_row + w_offs, val, mask=mask)
# step3: Insert the current score at the end
tl.store(trunc_row + W - 1, next_score)
# step4: determine whether all are greater than threshold
scores = tl.load(trunc_row + w_offs, mask=w_offs < W, other=0.0)
is_over = scores > threshold
all_over = tl.sum(is_over & (w_offs < W)) == W
# step5: set stop flags and reset trunc scores
if all_over:
tl.store(stop_flags + b, True)
zero = tl.full([BLOCK_W], 0.0, tl.float32)
tl.store(trunc_row + w_offs, zero, mask=w_offs < W)
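For intuition, a NumPy sketch of what one kernel instance (one batch row) computes, assuming contiguous row layout:

import numpy as np

def one_row_reference(trunc_row, probs_row, next_token, threshold):
    # steps 1-3: gather the chosen token's probability and slide the window
    trunc_row[:-1] = trunc_row[1:]
    trunc_row[-1] = probs_row[next_token]
    # steps 4-5: trigger when every score in the window exceeds the threshold
    if (trunc_row > threshold).all():
        trunc_row[:] = 0.0
        return True  # stop flag for this row
    return False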


@@ -57,6 +57,7 @@ class RolloutModelConfig:
disable_any_whitespace: bool = True,
enable_logprob: bool = False,
graph_optimization_config: str = None,
early_stop_config: str = None,
local_rank: int = 0,
):
# Required parameters
@@ -100,6 +101,7 @@ class RolloutModelConfig:
self.enable_logprob = enable_logprob
self.graph_optimization_config = graph_optimization_config
self.local_rank = local_rank
self.early_stop_config = early_stop_config
def __str__(self):
return "\n".join(f"{k}: {v}" for k, v in self.__dict__.items())


@@ -82,6 +82,7 @@ class GPUModelRunner(ModelRunnerBase):
self.speculative_method = self.fd_config.speculative_config.method
self.speculative_decoding = self.speculative_method is not None
self.enable_logprob = fd_config.model_config.enable_logprob
self.enable_early_stop = self.fd_config.early_stop_config.enable_early_stop
self.guided_backend = None
if self.fd_config.parallel_config.guided_decoding_backend != "off":
@@ -108,10 +109,9 @@ class GPUModelRunner(ModelRunnerBase):
"matmul_v2",
"fused_gemm_epilogue",
]
# Sampler
if not self.speculative_decoding:
self.sampler = Sampler(fd_config)
else:
self.sampler = SpeculativeSampler(fd_config)
@@ -753,6 +753,8 @@ class GPUModelRunner(ModelRunnerBase):
bad_words_token_ids=self.share_inputs["bad_tokens"],
eos_token_ids=self.share_inputs["eos_token_id"],
max_num_logprobs=20 if self.enable_logprob else None,
enable_early_stop=self.enable_early_stop,
stop_flags=self.share_inputs["stop_flags"],
)
def load_model(self) -> None:


@@ -28,6 +28,7 @@ from fastdeploy.config import (
CacheConfig,
DecodingConfig,
DeviceConfig,
EarlyStopConfig,
ErnieArchitectures,
FDConfig,
GraphOptimizationConfig,
@@ -565,6 +566,12 @@ def parse_args():
action="store_true",
help="Enable output of token-level log probabilities.",
)
parser.add_argument(
"--early_stop_config",
type=json.loads,
default=None,
help="Configuration of early stop.",
)
args = parser.parse_args()
return args
@@ -608,6 +615,8 @@ def initialize_fd_config(args, ranks: int = 1, local_rank: int = 0) -> FDConfig:
graph_opt_config = GraphOptimizationConfig(args.graph_optimization_config)
early_stop_config = EarlyStopConfig(args.early_stop_config)
# Note(tangbinhan): used for load_checkpoint
model_config.pretrained_config.tensor_parallel_rank = parallel_config.tensor_parallel_rank
model_config.pretrained_config.tensor_parallel_degree = parallel_config.tensor_parallel_size
@@ -679,6 +688,7 @@ def initialize_fd_config(args, ranks: int = 1, local_rank: int = 0) -> FDConfig:
decoding_config=decoding_config,
quant_config=quant_config,
graph_opt_config=graph_opt_config,
early_stop_config=early_stop_config,
cache_config=cache_config,
)
update_fd_config_for_mm(fd_config)


@@ -0,0 +1,235 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
import time
import numpy as np
import paddle
from fastdeploy.config import EarlyStopConfig
from fastdeploy.model_executor.layers.sample.early_stopper import RepetitionEarlyStopper
paddle.set_device("gpu")
np.random.seed(2025)
paddle.seed(2025)
def simulate_step_probs(
batch_size, early_stop_batch_id, fixed_token_id, vocab_size, step_i, trigger_flags, high_prob=0.99
):
"""
Generate a probability distribution for each sample in the batch;
a sample becomes "high confidence" once step_i reaches its trigger step,
with high_prob the confidence of the target token (e.g. 0.99).
"""
probs = np.random.rand(batch_size, vocab_size).astype("float32")
probs /= probs.sum(axis=1, keepdims=True)
for i in range(batch_size):
if step_i >= trigger_flags[i]:
low_prob = (1.0 - high_prob) / (vocab_size - 1)
probs[i].fill(low_prob)
if i == early_stop_batch_id:
probs[i, fixed_token_id] = high_prob
return probs
def remove_min_max(lst):
"""
remove the min and max value
"""
if len(lst) < 2:
return lst
min_val = min(lst)
max_val = max(lst)
return [x for x in lst if x != min_val and x != max_val]
def test_repetition_early_stopper():
# This test only for 1 batch to trigger early stop
batch_size = 20
vocab_size = 16
window_size = 4
threshold = 0.9
eos_token_id = vocab_size
max_steps = 10
# Select a token as final token
fixed_token_id = np.random.randint(0, vocab_size)
# Set a batch to trigger early stop
early_stop_batch_id = np.random.randint(0, batch_size)
print(f"{fixed_token_id=}\n{early_stop_batch_id=}\n{eos_token_id=}")
# Determine the first step in each batch where the high probability starts to appear
trigger_step_flags = [[i, np.random.randint(0, max_steps + 1)] for i in range(batch_size)]
trigger_step_flags = dict(trigger_step_flags)
cfg = EarlyStopConfig({"enable_early_stop": True, "window_size": window_size, "threshold": threshold})
stopper = RepetitionEarlyStopper()
stopper.initialize(batch_size, cfg)
next_tokens = paddle.randint(0, vocab_size, shape=[batch_size, 1], dtype="int64")
next_tokens[early_stop_batch_id, 0] = fixed_token_id
print(f"{next_tokens=}\ntrigger_start={trigger_step_flags[early_stop_batch_id]}")
triggered_step = [None] * batch_size
stop_flags = paddle.zeros_like(next_tokens)
for step in range(max_steps):
print(f"\n===== Step {step} =====")
flags = [trigger_step_flags[i] for i in range(batch_size)]
probs_np = simulate_step_probs(batch_size, early_stop_batch_id, fixed_token_id, vocab_size, step, flags)
probs = paddle.to_tensor(probs_np)
print("Before process:")
print("tokens:\n", stop_flags.numpy().T)
stopper.process(probs, next_tokens, stop_flags)
print("After process:")
print("tokens:\n", stop_flags.numpy().T)
out_np = stop_flags.numpy()
for i in range(batch_size):
if out_np[i, 0] and triggered_step[i] is None:
triggered_step[i] = step
# Show which step trigger the early stop in batch i
print("trigger_step: ", triggered_step)
assert (
triggered_step[early_stop_batch_id] == trigger_step_flags[early_stop_batch_id] + window_size - 1
), "not expected trigger step"
def test_consistency():
batch_size = 20
vocab_size = 103424
window_size = 3000
threshold = 0.9
eos_token_id = vocab_size
max_steps = 10
fixed_token_id = np.random.randint(0, vocab_size)
early_stop_batch_id = np.random.randint(0, batch_size)
trigger_step_flags = [[i, np.random.randint(0, max_steps + 1)] for i in range(batch_size)]
trigger_step_flags = dict(trigger_step_flags)
cfg = EarlyStopConfig({"enable_early_stop": True, "window_size": window_size, "threshold": threshold})
stopper_normal = RepetitionEarlyStopper()
stopper_normal.initialize(batch_size, cfg)
stopper_triton = RepetitionEarlyStopper()
stopper_triton.initialize(batch_size, cfg)
next_tokens_normal = paddle.randint(0, vocab_size, shape=[batch_size, 1], dtype="int64")
next_tokens_triton = next_tokens_normal.clone()
next_tokens_normal[early_stop_batch_id, 0] = fixed_token_id
next_tokens_triton[early_stop_batch_id, 0] = fixed_token_id
stop_flags_normal = paddle.zeros_like(next_tokens_normal)
stop_flags_triton = stop_flags_normal.clone()
triggered_step_normal = [None] * batch_size
triggered_step_triton = [None] * batch_size
for step in range(max_steps):
flags = [trigger_step_flags[i] for i in range(batch_size)]
probs_np = simulate_step_probs(batch_size, early_stop_batch_id, fixed_token_id, vocab_size, step, flags)
probs = paddle.to_tensor(probs_np)
stopper_normal.process_normal(probs, next_tokens_normal, stop_flags_normal)
stopper_triton.process_triton(probs, next_tokens_triton, stop_flags_triton)
assert np.allclose(stop_flags_normal.numpy(), stop_flags_triton.numpy()), f"stop flags mismatch at step {step}"
trunc_scores_diff = paddle.abs(stopper_normal.trunc_scores - stopper_triton.trunc_scores)
assert paddle.all(trunc_scores_diff < 1e-5), f"trunc_scores mismatch at step {step}"
out_normal = stop_flags_normal.numpy()
out_triton = stop_flags_triton.numpy()
for i in range(batch_size):
if out_normal[i, 0] and triggered_step_normal[i] is None:
triggered_step_normal[i] = step
if out_triton[i, 0] and triggered_step_triton[i] is None:
triggered_step_triton[i] = step
for i in range(batch_size):
expected = triggered_step_normal[i]
actual = triggered_step_triton[i]
assert expected == actual, f"Sample {i} triggered at different steps: {expected} vs {actual}"
print("Triton vs Normal: All tokens, states, and trigger timings match.")
def test_performance():
batch_size = 256
vocab_size = 103424
window_size = 3000
threshold = 0.9
eos_token_id = vocab_size
max_steps = 50
fixed_token_id = np.random.randint(0, vocab_size)
early_stop_batch_id = np.random.randint(0, batch_size)
print(f"{fixed_token_id=}\n{early_stop_batch_id=}")
trigger_step_flags = [[i, np.random.randint(0, max_steps + 1)] for i in range(batch_size)]
trigger_step_flags = dict(trigger_step_flags)
next_tokens = paddle.randint(0, vocab_size, shape=[batch_size, 1], dtype="int64")
next_tokens[early_stop_batch_id, 0] = fixed_token_id
cfg = EarlyStopConfig({"enable_early_stop": True, "window_size": window_size, "threshold": threshold})
print("Testing performance triton...")
seconds = []
stopper = RepetitionEarlyStopper()
stopper.initialize(batch_size, cfg)
stop_flags = paddle.zeros_like(next_tokens)
for step in range(max_steps):
flags = [trigger_step_flags[i] for i in range(batch_size)]
probs_np = simulate_step_probs(batch_size, early_stop_batch_id, fixed_token_id, vocab_size, step, flags)
probs = paddle.to_tensor(probs_np)
s = time.perf_counter()
stopper.process_triton(probs, next_tokens, stop_flags)
e = time.perf_counter()
seconds.append(e - s)
print(
f"triton:\nexecute times: {max_steps}\ntotal execution time: {np.sum(seconds)*1000} ms \navg every step execution time: {np.mean(remove_min_max(seconds))*1000} ms"
)
print("Testing performance normal...")
seconds = []
stopper = RepetitionEarlyStopper()
stopper.initialize(batch_size, cfg)
stop_flags = paddle.zeros_like(next_tokens)
for step in range(max_steps):
flags = [trigger_step_flags[i] for i in range(batch_size)]
probs_np = simulate_step_probs(batch_size, early_stop_batch_id, fixed_token_id, vocab_size, step, flags)
probs = paddle.to_tensor(probs_np)
s = time.perf_counter()
stopper.process_normal(probs, next_tokens, stop_flags)
e = time.perf_counter()
seconds.append(e - s)
print(
f"normal:\nexecute times: {max_steps}\ntotal execution time: {np.sum(seconds)*1000} ms \navg every step execution time: {np.mean(remove_min_max(seconds))*1000} ms"
)
print("Config:")
print(f"{batch_size=}, {window_size=}, {threshold=}, {eos_token_id=}, {vocab_size=}, {max_steps=}")
if __name__ == "__main__":
test_repetition_early_stopper()
test_consistency()
test_performance()