mirror of https://github.com/PaddlePaddle/FastDeploy.git

[Feature] Support repetition early stop (#3024)

* support repetition early stop and let the user set its parameters
* remove log
* fix code style
* add early_stop_config to rollout_config
* update config and the EarlyStopper class
* fix the bug for triton
* modify the stop method
* update description
* modify the usage of stop_flags

Co-authored-by: Yuanle Liu <yuanlehome@163.com>
@@ -567,6 +567,74 @@ class GraphOptimizationConfig:
        argument = self.use_cudagraph


class EarlyStopConfig:
    def __init__(
        self,
        args,
    ):
        """
        Early Stop Configuration class.

        Attributes:
            window_size: size of the verify window
            threshold: trigger early stop when the ratio of probs exceeds the threshold
        """
        """enable early stop"""
        self.enable_early_stop: bool = False
        """strategy for early stop; the supported strategies are ['repetition']"""
        self.strategy: str = "repetition"
        """the maximum length of the verify window for early stop"""
        self.window_size: int = 3000
        """the probability threshold for early stop"""
        self.threshold: float = 0.99

        if args is not None:
            for key, value in args.items():
                if hasattr(self, key):
                    setattr(self, key, value)
        self.check_legality_parameters()

    def to_json_string(self):
        """
        Convert early_stop_config to a JSON string.
        """
        return json.dumps({key: value for key, value in self.__dict__.items()})

    def __str__(self) -> str:
        return self.to_json_string()

    def check_legality_parameters(
        self,
    ) -> None:
        """Check the legality of parameters passed in from the command line"""
        if self.enable_early_stop is not None:
            assert isinstance(
                self.enable_early_stop, bool
            ), "In early stop config, type of enable_early_stop must be bool."
        if self.window_size is not None:
            assert isinstance(self.window_size, int), "In early stop config, type of window_size must be int."
            assert self.window_size > 0, "window_size must be larger than 0"
        if self.threshold is not None:
            assert isinstance(self.threshold, float), "In early stop config, type of threshold must be float."
            assert 0 <= self.threshold <= 1, "threshold must be between 0 and 1"

    def update_enable_early_stop(self, argument: bool):
        """
        Unify the two ways a user can specify enable_early_stop:
        '--enable-early-stop' and '--early-stop-config'.
        """
        if self.enable_early_stop is None:
            # User only set '--enable-early-stop'
            self.enable_early_stop = argument
        else:
            # User set both '--enable-early-stop' and '--early-stop-config'
            if self.enable_early_stop is False and argument is True:
                raise ValueError(
                    "Invalid parameter: Cannot set --enable-early-stop and --early-stop-config '{\"enable_early_stop\":false}' simultaneously."
                )
            argument = self.enable_early_stop


class LoadConfig:
    """
    Configuration for dynamic weight loading strategies

@@ -776,6 +844,7 @@ class FDConfig:
    load_config: LoadConfig = field(default=None, init=True)
    quant_config: Optional[QuantConfigBase] = None
    graph_opt_config: Optional[GraphOptimizationConfig] = None
    early_stop_config: Optional[EarlyStopConfig] = None
    decoding_config: DecodingConfig = field(default=None, init=True)  # type: ignore
    cache_config: CacheConfig = field(default=None, init=True)  # type: ignore
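For reference, a minimal usage sketch of the EarlyStopConfig class above (the dict keys mirror its attributes; the values here are hypothetical):

    from fastdeploy.config import EarlyStopConfig

    # Construct the config from a dict, the same way the engine builds it internally.
    cfg = EarlyStopConfig({"enable_early_stop": True, "window_size": 4, "threshold": 0.95})
    print(cfg.to_json_string())
    # {"enable_early_stop": true, "strategy": "repetition", "window_size": 4, "threshold": 0.95}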
@@ -21,6 +21,7 @@ from typing import Any, Dict, List, Optional

from fastdeploy.config import (
    CacheConfig,
    EarlyStopConfig,
    GraphOptimizationConfig,
    LoadConfig,
    SpeculativeConfig,
@@ -313,6 +314,16 @@ class EngineArgs:
    Must be explicitly enabled via the `--enable-logprob` startup parameter to output logprob values.
    """

    enable_early_stop: bool = False
    """
    Flag to enable early stop. Default is False (disabled).
    """

    early_stop_config: Optional[Dict[str, Any]] = None
    """
    Configuration for early stop.
    """

    def __post_init__(self):
        """
        Post-initialization processing to set default tokenizer if not provided.
@@ -464,6 +475,18 @@ class EngineArgs:
            default=EngineArgs.enable_logprob,
            help="Enable output of token-level log probabilities.",
        )
        model_group.add_argument(
            "--enable-early-stop",
            action="store_true",
            default=EngineArgs.enable_early_stop,
            help="Enable early stopping during generation.",
        )
        model_group.add_argument(
            "--early-stop-config",
            type=json.loads,
            default=EngineArgs.early_stop_config,
            help="The config for early stop.",
        )

        # Parallel processing parameters group
        parallel_group = parser.add_argument_group("Parallel Configuration")
@@ -811,6 +834,16 @@ class EngineArgs:
                graph_optimization_args[k] = v
        return GraphOptimizationConfig(graph_optimization_args)

    def create_early_stop_config(self) -> EarlyStopConfig:
        """
        Create and return an EarlyStopConfig object based on the current settings.
        """
        early_stop_args = asdict(self)
        if self.early_stop_config is not None:
            for k, v in self.early_stop_config.items():
                early_stop_args[k] = v
        return EarlyStopConfig(early_stop_args)

    def create_engine_config(self) -> Config:
        """
        Create and return a Config object based on the current settings.
@@ -833,6 +866,9 @@ class EngineArgs:
        graph_opt_cfg = self.create_graph_optimization_config()
        graph_opt_cfg.update_use_cudagraph(self.use_cudagraph)

        early_stop_cfg = self.create_early_stop_config()
        early_stop_cfg.update_enable_early_stop(self.enable_early_stop)

        assert not (
            self.tensor_parallel_size <= 1 and self.enable_custom_all_reduce
        ), "enable_custom_all_reduce must be used with tensor_parallel_size>1"
@@ -866,4 +902,5 @@ class EngineArgs:
            guided_decoding_backend=self.guided_decoding_backend,
            disable_any_whitespace=self.guided_decoding_disable_any_whitespace,
            enable_logprob=self.enable_logprob,
            early_stop_config=early_stop_cfg,
        )
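A hedged sketch of how the two entry points are reconciled by update_enable_early_stop (values are hypothetical; the merge rule is the one encoded in the method above — the JSON config's value wins unless it contradicts an explicit --enable-early-stop):

    from fastdeploy.config import EarlyStopConfig

    cfg = EarlyStopConfig({"enable_early_stop": True, "window_size": 2000})
    cfg.update_enable_early_stop(False)      # flag not passed: the config value stays True
    assert cfg.enable_early_stop is True

    bad = EarlyStopConfig({"enable_early_stop": False})
    try:
        bad.update_enable_early_stop(True)   # contradictory combination -> ValueError
    except ValueError as e:
        print(e)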
@@ -182,6 +182,7 @@ class Config:
        guided_decoding_backend: Optional[str] = None,
        disable_any_whitespace: bool = False,
        enable_logprob: bool = False,
        early_stop_config: Optional[Dict[str, Any]] = None,
    ):
        """
        Initialize the Config class.

@@ -210,6 +211,8 @@ class Config:
            guided_decoding_backend(str): Guided decoding backend. Default is None.
            disable_any_whitespace(bool): Disable any whitespace when using guided decoding.
                Default is False.
            enable_logprob(bool): Enable logprob. Default is False.
            early_stop_config (Optional[Dict[str, Any]]): Early stop configuration. Default is None.
        """
        self.model_config = model_config
        self.cache_config = cache_config
@@ -255,6 +258,7 @@ class Config:
        self.long_prefill_token_threshold = long_prefill_token_threshold
        self.reasoning_parser = reasoning_parser
        self.graph_optimization_config = graph_optimization_config
        self.early_stop_config = early_stop_config
        self.guided_decoding_backend = guided_decoding_backend
        self.disable_any_whitespace = disable_any_whitespace
        self._str_to_list("innode_prefill_ports", int)
@@ -1085,6 +1085,7 @@ class LLMEngine:
            f" --graph_optimization_config '{self.cfg.graph_optimization_config.to_json_string()}'"
            f" --guided_decoding_backend {self.cfg.guided_decoding_backend}"
            f" --load_strategy {self.cfg.load_config.load_strategy}"
            f" --early_stop_config '{self.cfg.early_stop_config.to_json_string()}'"
        )

        worker_append_flag = {
fastdeploy/model_executor/layers/sample/early_stopper.py (new file, 129 lines)
@@ -0,0 +1,129 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""

from abc import abstractmethod

import paddle

from fastdeploy.config import EarlyStopConfig


class EarlyStopper:
    @abstractmethod
    def initialize(self, batch_size: int, cfg: EarlyStopConfig):
        """
        Initialize the stopper and set hyper-parameters.
        args:
            - batch_size: int, the batch size of the input
            - cfg: EarlyStopConfig
        """
        raise NotImplementedError

    @abstractmethod
    def process(self, probs: paddle.Tensor, next_tokens: paddle.Tensor, stop_flags: paddle.Tensor):
        """
        Run the stopper and set stop_flags to True for every batch entry that triggers early stop.
        args:
            - probs: [batch_size, vocab_size], the probs of every sample
            - next_tokens: [batch_size, 1], the chosen token index of every sample
            - stop_flags: [batch_size, 1], determines which batch entries will be stopped
        """
        raise NotImplementedError


class RepetitionEarlyStopper(EarlyStopper):
    def initialize(self, batch_size: int, cfg: EarlyStopConfig):
        self.early_stop_cfg = cfg
        self.window_size = cfg.window_size
        self.threshold = cfg.threshold
        self.trunc_scores = paddle.zeros((batch_size, self.early_stop_cfg.window_size), dtype="float32")

    def process(self, probs: paddle.Tensor, next_tokens: paddle.Tensor, stop_flags: paddle.Tensor):
        """
        args:
            - probs: [batch_size, vocab_size], the probs of every sample
            - next_tokens: [batch_size, 1], the chosen token index of every sample
            - stop_flags: [batch_size, 1], determines which batch entries will be stopped
        """
        # Use the Triton kernel when Triton is available; fall back to the normal implementation otherwise
        try:
            self.process_triton(probs, next_tokens, stop_flags)
        except Exception:
            self.process_normal(probs, next_tokens, stop_flags)

    def process_normal(self, probs: paddle.Tensor, next_tokens: paddle.Tensor, stop_flags: paddle.Tensor):
        # Get the probability score corresponding to next_tokens in this step
        next_scores = paddle.index_sample(probs, next_tokens)

        # Sliding window: shift left by one slot and insert the new score
        self.trunc_scores[:, :-1] = self.trunc_scores[:, 1:]
        self.trunc_scores[:, -1:] = next_scores

        # Determine which samples need to be terminated: all trunc_scores exceed the threshold
        need_trunc_all = paddle.all(self.trunc_scores > self.threshold, axis=-1).unsqueeze(-1)

        # Set the stop flags
        stop_flags[need_trunc_all] = True

        # Reset trunc_scores of truncated samples to 0 to avoid false triggering in the next step
        reset_mask = need_trunc_all.tile([1, self.window_size])
        self.trunc_scores = paddle.where(reset_mask, paddle.zeros_like(self.trunc_scores), self.trunc_scores)

    def process_triton(self, probs: paddle.Tensor, next_tokens: paddle.Tensor, stop_flags: paddle.Tensor):
        import triton

        from fastdeploy.model_executor.ops.triton_ops import (
            repetition_early_stopper_kernel,
        )

        B, W = self.trunc_scores.shape
        V = probs.shape[1]
        BLOCK_W = triton.next_power_of_2(W)

        grid = (B,)
        repetition_early_stopper_kernel[grid](
            self.trunc_scores,
            probs,
            next_tokens,
            stop_flags,
            self.threshold,
            B,
            W,
            V,
            self.trunc_scores.shape[1],
            probs.shape[1],
            BLOCK_W=BLOCK_W,
        )
        return next_tokens


# mapping from strategy name to class
EARLY_STOPPER_MAPPING = {
    "repetition": RepetitionEarlyStopper,
}


def get_early_stopper_cls_from_stragegy(strategy: str):
    """
    Get the early stopper class from a strategy name.
    args:
        - strategy: string, the strategy name
    """
    strategy = strategy.lower()
    assert (
        strategy in EARLY_STOPPER_MAPPING
    ), f"{strategy} is not supported yet, only support {EARLY_STOPPER_MAPPING.keys()}."
    return EARLY_STOPPER_MAPPING[strategy]
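A minimal end-to-end sketch of the stopper on a single sequence (shapes and values are hypothetical; this mirrors the usage in the unit test further below). With window_size=3, the stop flag flips only once the chosen token has carried probability above the threshold for three consecutive steps:

    import paddle
    from fastdeploy.config import EarlyStopConfig
    from fastdeploy.model_executor.layers.sample.early_stopper import RepetitionEarlyStopper

    # Hypothetical values: one sequence, a two-token vocab, a window of 3 steps.
    cfg = EarlyStopConfig({"enable_early_stop": True, "window_size": 3, "threshold": 0.9})
    stopper = RepetitionEarlyStopper()
    stopper.initialize(batch_size=1, cfg=cfg)

    probs = paddle.to_tensor([[0.01, 0.99]])             # token 1 dominates every step
    next_tokens = paddle.to_tensor([[1]], dtype="int64")
    stop_flags = paddle.zeros_like(next_tokens)          # as in the unit test below
    for step in range(3):
        stopper.process_normal(probs, next_tokens, stop_flags)
        print(step, bool(stop_flags[0, 0]))              # True only at step 2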
@@ -44,5 +44,7 @@ class SamplingMetadata:
    top_k: Optional[paddle.Tensor] = None
    min_p: Optional[paddle.Tensor] = None
    max_num_logprobs: Optional[int] = None
    enable_early_stop: Optional[bool] = False
    stop_flags: Optional[paddle.Tensor] = None
    prompt_ids: Optional[paddle.Tensor] = None
    prompt_lens: Optional[paddle.Tensor] = None
@@ -26,6 +26,9 @@ from fastdeploy.config import FDConfig
from fastdeploy.model_executor.guided_decoding.base_guided_decoding import (
    LogitsProcessorBase,
)
from fastdeploy.model_executor.layers.sample.early_stopper import (
    get_early_stopper_cls_from_stragegy,
)
from fastdeploy.model_executor.layers.sample.meta_data import SamplingMetadata
from fastdeploy.model_executor.layers.sample.ops import (
    apply_penalty_multi_scores,
@@ -165,7 +168,7 @@ class Sampler(nn.Layer):
    Sampler for normal generation.
    """

-    def __init__(self):
+    def __init__(self, fd_config: FDConfig = None):
        """ """
        super().__init__()
        if (
@@ -180,6 +183,15 @@ class Sampler(nn.Layer):
            raise NotImplementedError

        self.processor = SamplerProcessor()
        # Only created when fd_config.early_stop_config.enable_early_stop is True
        if (
            fd_config is not None
            and fd_config.early_stop_config is not None
            and fd_config.early_stop_config.enable_early_stop
        ):
            early_stopper_cls = get_early_stopper_cls_from_stragegy(fd_config.early_stop_config.strategy)
            self.early_stopper = early_stopper_cls()
            self.early_stopper.initialize(fd_config.parallel_config.max_num_seqs, fd_config.early_stop_config)

    def apply_logits_processor(
        self,
@@ -275,6 +287,10 @@
        logprobs_tensors = (
            None if num_logprobs is None else self.gather_logprobs(raw_logprobs, num_logprobs, token_ids=next_tokens)
        )
        if sampling_metadata.enable_early_stop:
            # sets stop_flags for every batch entry that triggers early stop
            assert sampling_metadata.stop_flags is not None, "need stop_flags for early stop"
            self.early_stopper.process(probs, next_tokens, sampling_metadata.stop_flags)

        self.processor.update_output_tokens(next_tokens, skip_idx_list)
@@ -15,9 +15,10 @@
"""

try:
    from .repetition_early_stop_kernel import repetition_early_stopper_kernel
    from .wint2_fused_moe import fused_moe_wint2_triton
    from .wint2_fused_moe_kernel import moe_wint2_ffn_kernel

-    __all__ = ["fused_moe_wint2_triton", "moe_wint2_ffn_kernel"]
+    __all__ = ["fused_moe_wint2_triton", "moe_wint2_ffn_kernel", "repetition_early_stopper_kernel"]
except:
    pass
fastdeploy/model_executor/ops/triton_ops/repetition_early_stop_kernel.py (new file, 63 lines; path inferred from the import above)
@@ -0,0 +1,63 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""

import triton
import triton.language as tl


@triton.jit
def repetition_early_stopper_kernel(
    trunc_ptr,  # float32[B, W]
    probs_ptr,  # float32[B, V]
    next_tokens_ptr,  # int32[B]
    stop_flags,  # bool[B]
    threshold,
    B,  # batch size
    W,  # window size
    V,  # vocab size
    stride_bw,
    stride_bv,
    BLOCK_W: tl.constexpr,
):
    b = tl.program_id(0)
    w_offs = tl.arange(0, BLOCK_W)

    # row pointers for the current batch entry
    trunc_row = trunc_ptr + b * stride_bw
    probs_row = probs_ptr + b * stride_bv

    # step 1: index_sample equivalent, fetch the chosen token's score
    next_token = tl.load(next_tokens_ptr + b)
    next_score = tl.load(probs_row + next_token)

    # step 2: shift the window left: (w = 0 ~ W-2) <- (w = 1 ~ W-1)
    mask = w_offs < W - 1
    val = tl.load(trunc_row + w_offs + 1, mask=mask)
    tl.store(trunc_row + w_offs, val, mask=mask)

    # step 3: insert the current score at the end
    tl.store(trunc_row + W - 1, next_score)

    # step 4: determine whether all scores exceed the threshold
    scores = tl.load(trunc_row + w_offs, mask=w_offs < W, other=0.0)
    is_over = scores > threshold
    all_over = tl.sum(is_over & (w_offs < W)) == W

    # step 5: set the stop flag and reset the window scores
    if all_over:
        tl.store(stop_flags + b, True)
        zero = tl.full([BLOCK_W], 0.0, tl.float32)
        tl.store(trunc_row + w_offs, zero, mask=w_offs < W)
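For readers less familiar with Triton: each program instance handles one batch row, and the kernel body is the same five-step window update as process_normal. A plain NumPy mirror of one instance (an illustrative sketch, not part of the commit):

    import numpy as np

    def one_row_update(trunc_row, probs_row, next_token, threshold):
        """NumPy mirror of one kernel instance operating on row b."""
        next_score = probs_row[next_token]     # step 1: gather the chosen token's prob
        trunc_row[:-1] = trunc_row[1:]         # step 2: shift the window left
        trunc_row[-1] = next_score             # step 3: append the current score
        if np.all(trunc_row > threshold):      # step 4: whole window over threshold?
            trunc_row[:] = 0.0                 # step 5: reset the window on trigger
            return True                        # caller sets stop_flags[b] = True
        return False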
@@ -57,6 +57,7 @@ class RolloutModelConfig:
        disable_any_whitespace: bool = True,
        enable_logprob: bool = False,
        graph_optimization_config: str = None,
        early_stop_config: str = None,
        local_rank: int = 0,
    ):
        # Required parameters
@@ -100,6 +101,7 @@ class RolloutModelConfig:
        self.enable_logprob = enable_logprob
        self.graph_optimization_config = graph_optimization_config
        self.local_rank = local_rank
        self.early_stop_config = early_stop_config

    def __str__(self):
        return "\n".join(f"{k}: {v}" for k, v in self.__dict__.items())
@@ -82,6 +82,7 @@ class GPUModelRunner(ModelRunnerBase):
        self.speculative_method = self.fd_config.speculative_config.method
        self.speculative_decoding = self.speculative_method is not None
        self.enable_logprob = fd_config.model_config.enable_logprob
        self.enable_early_stop = self.fd_config.early_stop_config.enable_early_stop

        self.guided_backend = None
        if self.fd_config.parallel_config.guided_decoding_backend != "off":
@@ -108,10 +109,9 @@ class GPUModelRunner(ModelRunnerBase):
            "matmul_v2",
            "fused_gemm_epilogue",
        ]

        # Sampler
        if not self.speculative_decoding:
-            self.sampler = Sampler()
+            self.sampler = Sampler(fd_config)
        else:
            self.sampler = SpeculativeSampler(fd_config)

@@ -753,6 +753,8 @@ class GPUModelRunner(ModelRunnerBase):
            bad_words_token_ids=self.share_inputs["bad_tokens"],
            eos_token_ids=self.share_inputs["eos_token_id"],
            max_num_logprobs=20 if self.enable_logprob else None,
            enable_early_stop=self.enable_early_stop,
            stop_flags=self.share_inputs["stop_flags"],
        )

    def load_model(self) -> None:
@@ -28,6 +28,7 @@ from fastdeploy.config import (
    CacheConfig,
    DecodingConfig,
    DeviceConfig,
    EarlyStopConfig,
    ErnieArchitectures,
    FDConfig,
    GraphOptimizationConfig,
@@ -565,6 +566,12 @@ def parse_args():
        action="store_true",
        help="Enable output of token-level log probabilities.",
    )
    parser.add_argument(
        "--early_stop_config",
        type=json.loads,
        default=None,
        help="Configuration of early stop.",
    )

    args = parser.parse_args()
    return args
@@ -608,6 +615,8 @@ def initialize_fd_config(args, ranks: int = 1, local_rank: int = 0) -> FDConfig:

    graph_opt_config = GraphOptimizationConfig(args.graph_optimization_config)

    early_stop_config = EarlyStopConfig(args.early_stop_config)

    # Note(tangbinhan): used for load_checkpoint
    model_config.pretrained_config.tensor_parallel_rank = parallel_config.tensor_parallel_rank
    model_config.pretrained_config.tensor_parallel_degree = parallel_config.tensor_parallel_size
@@ -679,6 +688,7 @@ def initialize_fd_config(args, ranks: int = 1, local_rank: int = 0) -> FDConfig:
        decoding_config=decoding_config,
        quant_config=quant_config,
        graph_opt_config=graph_opt_config,
        early_stop_config=early_stop_config,
        cache_config=cache_config,
    )
    update_fd_config_for_mm(fd_config)
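The early-stop config thus makes a JSON round trip from engine to worker: LLMEngine serializes it via to_json_string() into the --early_stop_config flag (see the engine.py hunk above), and the worker parses it back with json.loads. A hedged sketch of that round trip:

    import json
    from fastdeploy.config import EarlyStopConfig

    # Serialize the engine-side config, then rebuild it the way the worker does.
    serialized = EarlyStopConfig({"enable_early_stop": True}).to_json_string()
    restored = EarlyStopConfig(json.loads(serialized))
    assert restored.enable_early_stop is True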
test/layers/test_repetition_early_stopper.py (new file, 235 lines)
@@ -0,0 +1,235 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""

import time

import numpy as np
import paddle

from fastdeploy.config import EarlyStopConfig
from fastdeploy.model_executor.layers.sample.early_stopper import RepetitionEarlyStopper

paddle.set_device("gpu")
np.random.seed(2025)
paddle.seed(2025)


def simulate_step_probs(
    batch_size, early_stop_batch_id, fixed_token_id, vocab_size, step_i, trigger_flags, high_prob=0.99
):
    """
    Generate a probability distribution for the given batch of samples.
    Some samples start to show "high confidence" from step_i on;
    high_prob is the confidence of the target token (such as 0.95).
    """
    probs = np.random.rand(batch_size, vocab_size).astype("float32")
    probs /= probs.sum(axis=1, keepdims=True)

    for i in range(batch_size):
        if step_i >= trigger_flags[i]:
            low_prob = (1.0 - high_prob) / (vocab_size - 1)
            probs[i].fill(low_prob)
            if i == early_stop_batch_id:
                probs[i, fixed_token_id] = high_prob
    return probs


def remove_min_max(lst):
    """
    Remove the min and max values from a list.
    """
    if len(lst) < 2:
        return lst
    min_val = min(lst)
    max_val = max(lst)
    return [x for x in lst if x != min_val and x != max_val]


def test_repetition_early_stopper():
    # This test expects exactly one batch entry to trigger early stop
    batch_size = 20
    vocab_size = 16
    window_size = 4
    threshold = 0.9
    eos_token_id = vocab_size
    max_steps = 10

    # Select a token as the final token
    fixed_token_id = np.random.randint(0, vocab_size)
    # Pick the batch entry that should trigger early stop
    early_stop_batch_id = np.random.randint(0, batch_size)
    print(f"{fixed_token_id=}\n{early_stop_batch_id=}\n{eos_token_id=}")

    # Determine the first step at which each batch entry starts producing high probabilities
    trigger_step_flags = [[i, np.random.randint(0, max_steps + 1)] for i in range(batch_size)]
    trigger_step_flags = dict(trigger_step_flags)
    cfg = EarlyStopConfig({"enable_early_stop": True, "window_size": window_size, "threshold": threshold})
    stopper = RepetitionEarlyStopper()
    stopper.initialize(batch_size, cfg)

    next_tokens = paddle.randint(0, vocab_size, shape=[batch_size, 1], dtype="int64")
    next_tokens[early_stop_batch_id, 0] = fixed_token_id

    print(f"{next_tokens=}\ntrigger_start={trigger_step_flags[early_stop_batch_id]}")

    triggered_step = [None] * batch_size
    stop_flags = paddle.zeros_like(next_tokens)
    for step in range(max_steps):
        print(f"\n===== Step {step} =====")
        flags = [trigger_step_flags[i] for i in range(batch_size)]
        probs_np = simulate_step_probs(batch_size, early_stop_batch_id, fixed_token_id, vocab_size, step, flags)
        probs = paddle.to_tensor(probs_np)
        print("Before process:")
        print("stop_flags:\n", stop_flags.numpy().T)

        stopper.process(probs, next_tokens, stop_flags)

        print("After process:")
        print("stop_flags:\n", stop_flags.numpy().T)

        out_np = stop_flags.numpy()
        for i in range(batch_size):
            if out_np[i, 0] and triggered_step[i] is None:
                triggered_step[i] = step

    # Show at which step each batch entry triggered early stop
    print("trigger_step: ", triggered_step)
    assert (
        triggered_step[early_stop_batch_id] == trigger_step_flags[early_stop_batch_id] + window_size - 1
    ), "not the expected trigger step"


def test_consistency():
    batch_size = 20
    vocab_size = 103424
    window_size = 3000
    threshold = 0.9
    eos_token_id = vocab_size
    max_steps = 10

    fixed_token_id = np.random.randint(0, vocab_size)
    early_stop_batch_id = np.random.randint(0, batch_size)

    trigger_step_flags = [[i, np.random.randint(0, max_steps + 1)] for i in range(batch_size)]
    trigger_step_flags = dict(trigger_step_flags)
    cfg = EarlyStopConfig({"enable_early_stop": True, "window_size": window_size, "threshold": threshold})
    stopper_normal = RepetitionEarlyStopper()
    stopper_normal.initialize(batch_size, cfg)
    stopper_triton = RepetitionEarlyStopper()
    stopper_triton.initialize(batch_size, cfg)

    next_tokens_normal = paddle.randint(0, vocab_size, shape=[batch_size, 1], dtype="int64")
    next_tokens_triton = next_tokens_normal.clone()

    next_tokens_normal[early_stop_batch_id, 0] = fixed_token_id
    next_tokens_triton[early_stop_batch_id, 0] = fixed_token_id

    stop_flags_normal = paddle.zeros_like(next_tokens_normal)
    stop_flags_triton = stop_flags_normal.clone()

    triggered_step_normal = [None] * batch_size
    triggered_step_triton = [None] * batch_size

    for step in range(max_steps):

        flags = [trigger_step_flags[i] for i in range(batch_size)]
        probs_np = simulate_step_probs(batch_size, early_stop_batch_id, fixed_token_id, vocab_size, step, flags)
        probs = paddle.to_tensor(probs_np)

        stopper_normal.process_normal(probs, next_tokens_normal, stop_flags_normal)
        stopper_triton.process_triton(probs, next_tokens_triton, stop_flags_triton)

        assert np.allclose(stop_flags_normal.numpy(), stop_flags_triton.numpy()), f"stop flags mismatch at step {step}"

        trunc_scores_diff = paddle.abs(stopper_normal.trunc_scores - stopper_triton.trunc_scores)
        assert paddle.all(trunc_scores_diff < 1e-5), f"trunc_scores mismatch at step {step}"

        out_normal = stop_flags_normal.numpy()
        out_triton = stop_flags_triton.numpy()
        for i in range(batch_size):
            if out_normal[i, 0] and triggered_step_normal[i] is None:
                triggered_step_normal[i] = step
            if out_triton[i, 0] and triggered_step_triton[i] is None:
                triggered_step_triton[i] = step

    for i in range(batch_size):
        expected = triggered_step_normal[i]
        actual = triggered_step_triton[i]
        assert expected == actual, f"Sample {i} triggered at different steps: {expected} vs {actual}"

    print("Triton vs Normal: All tokens, states, and trigger timings match.")


def test_performance():
    batch_size = 256
    vocab_size = 103424
    window_size = 3000
    threshold = 0.9
    eos_token_id = vocab_size
    max_steps = 50

    fixed_token_id = np.random.randint(0, vocab_size)
    early_stop_batch_id = np.random.randint(0, batch_size)
    print(f"{fixed_token_id=}\n{early_stop_batch_id=}")

    trigger_step_flags = [[i, np.random.randint(0, max_steps + 1)] for i in range(batch_size)]
    trigger_step_flags = dict(trigger_step_flags)

    next_tokens = paddle.randint(0, vocab_size, shape=[batch_size, 1], dtype="int64")
    next_tokens[early_stop_batch_id, 0] = fixed_token_id
    cfg = EarlyStopConfig({"enable_early_stop": True, "window_size": window_size, "threshold": threshold})
    print("Testing performance triton...")
    seconds = []
    stopper = RepetitionEarlyStopper()
    stopper.initialize(batch_size, cfg)
    stop_flags = paddle.zeros_like(next_tokens)
    for step in range(max_steps):
        flags = [trigger_step_flags[i] for i in range(batch_size)]
        probs_np = simulate_step_probs(batch_size, early_stop_batch_id, fixed_token_id, vocab_size, step, flags)
        probs = paddle.to_tensor(probs_np)
        s = time.perf_counter()
        stopper.process_triton(probs, next_tokens, stop_flags)
        e = time.perf_counter()
        seconds.append(e - s)
    print(
        f"triton:\nexecute times: {max_steps}\ntotal execution time: {np.sum(seconds)*1000} ms \navg per-step execution time: {np.mean(remove_min_max(seconds))*1000} ms"
    )

    print("Testing performance normal...")
    seconds = []
    stopper = RepetitionEarlyStopper()
    stopper.initialize(batch_size, cfg)
    stop_flags = paddle.zeros_like(next_tokens)
    for step in range(max_steps):
        flags = [trigger_step_flags[i] for i in range(batch_size)]
        probs_np = simulate_step_probs(batch_size, early_stop_batch_id, fixed_token_id, vocab_size, step, flags)
        probs = paddle.to_tensor(probs_np)
        s = time.perf_counter()
        stopper.process_normal(probs, next_tokens, stop_flags)
        e = time.perf_counter()
        seconds.append(e - s)
    print(
        f"normal:\nexecute times: {max_steps}\ntotal execution time: {np.sum(seconds)*1000} ms \navg per-step execution time: {np.mean(remove_min_max(seconds))*1000} ms"
    )

    print("Config:")
    print(f"{batch_size=}, {window_size=}, {threshold=}, {eos_token_id=}, {vocab_size=}, {max_steps=}")


if __name__ == "__main__":
    test_repetition_early_stopper()
    test_consistency()
    test_performance()