Mirror of https://github.com/PaddlePaddle/FastDeploy.git, synced 2025-10-05 16:48:03 +08:00
[Feature] Support repetition early stop (#3024)
* support repetition early stop and allow users to set its parameters
* remove log
* fix code style
* add early_stop_config to rollout_config
* update config and the EarlyStopper class
* fix the bug for Triton
* modify the stop method
* update description
* modify the usage of stop_flags

Co-authored-by: Yuanle Liu <yuanlehome@163.com>
@@ -567,6 +567,74 @@ class GraphOptimizationConfig:
         argument = self.use_cudagraph
 
 
+class EarlyStopConfig:
+    def __init__(
+        self,
+        args,
+    ):
+        """
+        Early Stop Configuration class.
+
+        Attributes:
+            window_size: size of the window
+            threshold: trigger early stop when the ratio of probs exceeds the threshold
+        """
+        """enable to use early stop"""
+        self.enable_early_stop: bool = False
+        """strategy for early stop, the strategy list is ['repetition']"""
+        self.strategy: str = "repetition"
+        """the maximum length of the verify window for early stop"""
+        self.window_size: int = 3000
+        """the probs threshold for early stop"""
+        self.threshold: float = 0.99
+
+        if args is not None:
+            for key, value in args.items():
+                if hasattr(self, key):
+                    setattr(self, key, value)
+        self.check_legality_parameters()
+
+    def to_json_string(self):
+        """
+        Convert early_stop_config to a JSON string.
+        """
+        return json.dumps({key: value for key, value in self.__dict__.items()})
+
+    def __str__(self) -> str:
+        return self.to_json_string()
+
+    def check_legality_parameters(
+        self,
+    ) -> None:
+        """Check the legality of parameters passed in from the command line"""
+        if self.enable_early_stop is not None:
+            assert isinstance(
+                self.enable_early_stop, bool
+            ), "In early stop config, type of enable_early_stop must be bool."
+        if self.window_size is not None:
+            assert isinstance(self.window_size, int), "In early stop config, type of window_size must be int."
+            assert self.window_size > 0, "window_size must be larger than 0"
+        if self.threshold is not None:
+            assert isinstance(self.threshold, float), "In early stop config, type of threshold must be float."
+            assert self.threshold >= 0 and self.threshold <= 1, "threshold must be between 0 and 1"
+
+    def update_enable_early_stop(self, argument: bool):
+        """
+        Unify the two ways a user can specify the enable_early_stop parameter:
+        '--enable-early-stop' and '--early-stop-config'.
+        """
+        if self.enable_early_stop is None:
+            # User only set '--enable-early-stop'
+            self.enable_early_stop = argument
+        else:
+            # User set both '--enable-early-stop' and '--early-stop-config'
+            if self.enable_early_stop is False and argument is True:
+                raise ValueError(
+                    "Invalid parameter: Cannot set --enable-early-stop and --early-stop-config '{\"enable_early_stop\":false}' simultaneously."
+                )
+            argument = self.enable_early_stop
+
+
 class LoadConfig:
     """
     Configuration for dynamic weight loading strategies
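A brief sketch of how the new class behaves (assuming fastdeploy is importable; values are illustrative):

from fastdeploy.config import EarlyStopConfig

# Unknown keys are silently dropped (setattr only runs for existing attributes),
# and check_legality_parameters() rejects out-of-range values with assertions.
cfg = EarlyStopConfig({"enable_early_stop": True, "window_size": 1024, "threshold": 0.95})
print(cfg)  # the JSON form produced by to_json_string()

# Conflicting flags raise: here the config disables early stop while the CLI flag enables it.
# EarlyStopConfig({"enable_early_stop": False}).update_enable_early_stop(True)  # -> ValueError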
@@ -776,6 +844,7 @@ class FDConfig:
     load_config: LoadConfig = field(default=None, init=True)
     quant_config: Optional[QuantConfigBase] = None
     graph_opt_config: Optional[GraphOptimizationConfig] = None
+    early_stop_config: Optional[EarlyStopConfig] = None
     decoding_config: DecodingConfig = field(default=None, init=True)  # type: ignore
     cache_config: CacheConfig = field(default=None, init=True)  # type: ignore
@@ -21,6 +21,7 @@ from typing import Any, Dict, List, Optional
 
 from fastdeploy.config import (
     CacheConfig,
+    EarlyStopConfig,
     GraphOptimizationConfig,
     LoadConfig,
     SpeculativeConfig,
@@ -313,6 +314,16 @@ class EngineArgs:
     Must be explicitly enabled via the `--enable-logprob` startup parameter to output logprob values.
     """
 
+    enable_early_stop: bool = False
+    """
+    Flag to enable early stop. Default is False (disabled).
+    """
+
+    early_stop_config: Optional[Dict[str, Any]] = None
+    """
+    Configuration for early stop.
+    """
+
     def __post_init__(self):
         """
         Post-initialization processing to set default tokenizer if not provided.
@@ -464,6 +475,18 @@ class EngineArgs:
             default=EngineArgs.enable_logprob,
             help="Enable output of token-level log probabilities.",
         )
+        model_group.add_argument(
+            "--enable-early-stop",
+            action="store_true",
+            default=EngineArgs.enable_early_stop,
+            help="Enable early stopping during generation.",
+        )
+        model_group.add_argument(
+            "--early-stop-config",
+            type=json.loads,
+            default=EngineArgs.early_stop_config,
+            help="The config for early stop.",
+        )
 
         # Parallel processing parameters group
         parallel_group = parser.add_argument_group("Parallel Configuration")
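Since the new flag is parsed with type=json.loads, its value on the command line must be a JSON object. A minimal illustration of what argparse will hand over (the keys shown are the ones EarlyStopConfig understands):

import json

# e.g. --early-stop-config '{"window_size": 2048, "threshold": 0.95}'
early_stop_config = json.loads('{"window_size": 2048, "threshold": 0.95}')
assert early_stop_config["window_size"] == 2048  # forwarded to EarlyStopConfig as a plain dict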
@@ -811,6 +834,16 @@ class EngineArgs:
             graph_optimization_args[k] = v
         return GraphOptimizationConfig(graph_optimization_args)
 
+    def create_early_stop_config(self) -> EarlyStopConfig:
+        """
+        Create and return an EarlyStopConfig object based on the current settings.
+        """
+        early_stop_args = asdict(self)
+        if self.early_stop_config is not None:
+            for k, v in self.early_stop_config.items():
+                early_stop_args[k] = v
+        return EarlyStopConfig(early_stop_args)
+
     def create_engine_config(self) -> Config:
         """
         Create and return a Config object based on the current settings.
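A sketch of the precedence implemented above, with hypothetical values: the flat EngineArgs fields are collected first via asdict(), then keys from --early-stop-config override them, and EarlyStopConfig keeps only the attributes it defines:

from fastdeploy.config import EarlyStopConfig

flat_args = {"enable_early_stop": True, "window_size": 3000, "threshold": 0.99, "model": "ignored"}
flat_args.update({"window_size": 1000})  # values from --early-stop-config win
cfg = EarlyStopConfig(flat_args)         # unrelated keys such as "model" are dropped
assert cfg.window_size == 1000 and cfg.enable_early_stop is True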
@@ -833,6 +866,9 @@ class EngineArgs:
         graph_opt_cfg = self.create_graph_optimization_config()
         graph_opt_cfg.update_use_cudagraph(self.use_cudagraph)
 
+        early_stop_cfg = self.create_early_stop_config()
+        early_stop_cfg.update_enable_early_stop(self.enable_early_stop)
+
         assert not (
             self.tensor_parallel_size <= 1 and self.enable_custom_all_reduce
         ), "enable_custom_all_reduce must be used with tensor_parallel_size>1"
@@ -866,4 +902,5 @@ class EngineArgs:
             guided_decoding_backend=self.guided_decoding_backend,
             disable_any_whitespace=self.guided_decoding_disable_any_whitespace,
             enable_logprob=self.enable_logprob,
+            early_stop_config=early_stop_cfg,
         )
@@ -182,6 +182,7 @@ class Config:
         guided_decoding_backend: Optional[str] = None,
         disable_any_whitespace: bool = False,
         enable_logprob: bool = False,
+        early_stop_config: Optional[Dict[str, Any]] = None,
     ):
         """
         Initialize the Config class.
@@ -210,6 +211,8 @@ class Config:
             guided_decoding_backend(str): Guided decoding backend. Default is None.
             disable_any_whitespace(bool): Disable any whitespace when using guided decoding.
                                           Default is False.
+            enable_logprob(bool): Enable logprob. Default is False.
+            early_stop_config (Optional[Dict[str, Any]]): Early stop configuration. Default is None.
         """
         self.model_config = model_config
         self.cache_config = cache_config
@@ -255,6 +258,7 @@ class Config:
         self.long_prefill_token_threshold = long_prefill_token_threshold
         self.reasoning_parser = reasoning_parser
         self.graph_optimization_config = graph_optimization_config
+        self.early_stop_config = early_stop_config
         self.guided_decoding_backend = guided_decoding_backend
         self.disable_any_whitespace = disable_any_whitespace
         self._str_to_list("innode_prefill_ports", int)
@@ -1085,6 +1085,7 @@ class LLMEngine:
             f" --graph_optimization_config '{self.cfg.graph_optimization_config.to_json_string()}'"
             f" --guided_decoding_backend {self.cfg.guided_decoding_backend}"
             f" --load_strategy {self.cfg.load_config.load_strategy}"
+            f" --early_stop_config '{self.cfg.early_stop_config.to_json_string()}'"
         )
 
         worker_append_flag = {
fastdeploy/model_executor/layers/sample/early_stopper.py (new file, 129 lines)
@@ -0,0 +1,129 @@
+"""
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+
+from abc import abstractmethod
+
+import paddle
+
+from fastdeploy.config import EarlyStopConfig
+
+
+class EarlyStopper:
+    @abstractmethod
+    def initialize(self, batch_size: int, cfg: EarlyStopConfig):
+        """
+        Initialize the stopper and set hyper-parameters.
+        args:
+            - batch_size: int, the batch size of the input
+            - cfg: EarlyStopConfig
+        """
+        raise NotImplementedError
+
+    @abstractmethod
+    def process(self, probs: paddle.Tensor, next_tokens: paddle.Tensor, stop_flags: paddle.Tensor):
+        """
+        Process the stopper and set the stop_flags of the batch entries that trigger early stop to True.
+        args:
+            - probs: [batch_size, vocab_size], the probs of every sample
+            - next_tokens: [batch_size, 1], the token index of every chosen sample
+            - stop_flags: [batch_size, 1], determines which batch entries will be stopped
+        """
+        raise NotImplementedError
+
+
+class RepetitionEarlyStopper(EarlyStopper):
+    def initialize(self, batch_size: int, cfg: EarlyStopConfig):
+        self.early_stop_cfg = cfg
+        self.window_size = cfg.window_size
+        self.threshold = cfg.threshold
+        self.trunc_scores = paddle.zeros((batch_size, self.early_stop_cfg.window_size), dtype="float32")
+
+    def process(self, probs: paddle.Tensor, next_tokens: paddle.Tensor, stop_flags: paddle.Tensor):
+        """
+        args:
+            - probs: [batch_size, vocab_size], the probs of every sample
+            - next_tokens: [batch_size, 1], the token index of every chosen sample
+            - stop_flags: [batch_size, 1], determines which batch entries will be stopped
+        """
+        # Use the Triton implementation when Triton is available, otherwise fall back to the normal one
+        try:
+            self.process_triton(probs, next_tokens, stop_flags)
+        except Exception:
+            self.process_normal(probs, next_tokens, stop_flags)
+
+    def process_normal(self, probs: paddle.Tensor, next_tokens: paddle.Tensor, stop_flags: paddle.Tensor):
+        # Get the probability score corresponding to next_tokens in this step
+        next_scores = paddle.index_sample(probs, next_tokens)
+
+        # Sliding window: shift left by one slot and insert the new score
+        self.trunc_scores[:, :-1] = self.trunc_scores[:, 1:]
+        self.trunc_scores[:, -1:] = next_scores
+
+        # Determine which samples need to be terminated: all trunc_scores are greater than the threshold
+        need_trunc_all = paddle.all(self.trunc_scores > self.threshold, axis=-1).unsqueeze(-1)
+
+        # Set the stop flags
+        stop_flags[need_trunc_all] = True
+
+        # Reset trunc_scores of truncated samples to 0 to avoid false triggering in the next step
+        reset_mask = need_trunc_all.tile([1, self.window_size])
+        self.trunc_scores = paddle.where(reset_mask, paddle.zeros_like(self.trunc_scores), self.trunc_scores)
+
+    def process_triton(self, probs: paddle.Tensor, next_tokens: paddle.Tensor, stop_flags: paddle.Tensor):
+        import triton
+
+        from fastdeploy.model_executor.ops.triton_ops import (
+            repetition_early_stopper_kernel,
+        )
+
+        B, W = self.trunc_scores.shape
+        V = probs.shape[1]
+        BLOCK_W = triton.next_power_of_2(W)
+
+        grid = (B,)
+        repetition_early_stopper_kernel[grid](
+            self.trunc_scores,
+            probs,
+            next_tokens,
+            stop_flags,
+            self.threshold,
+            B,
+            W,
+            V,
+            self.trunc_scores.shape[1],
+            probs.shape[1],
+            BLOCK_W=BLOCK_W,
+        )
+        return next_tokens
+
+
+# mapping from strategy name to class
+EARLY_STOPPER_MAPPING = {
+    "repetition": RepetitionEarlyStopper,
+}
+
+
+def get_early_stopper_cls_from_stragegy(strategy: str):
+    """
+    Get the early stopper class from the strategy name.
+    args:
+        - strategy: string, the strategy name
+    """
+    strategy = strategy.lower()
+    assert (
+        strategy in EARLY_STOPPER_MAPPING
+    ), f"{strategy} is not supported yet, only support {EARLY_STOPPER_MAPPING.keys()}."
+    return EARLY_STOPPER_MAPPING[strategy]
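To make the window mechanics concrete, here is a minimal sketch of the pure-Paddle path (assuming an importable fastdeploy and a working Paddle install; shapes are toy-sized). The flag flips only after window_size consecutive steps whose chosen-token probability exceeds the threshold:

import paddle

from fastdeploy.config import EarlyStopConfig
from fastdeploy.model_executor.layers.sample.early_stopper import RepetitionEarlyStopper

batch_size, vocab_size = 2, 8
cfg = EarlyStopConfig({"enable_early_stop": True, "window_size": 3, "threshold": 0.9})
stopper = RepetitionEarlyStopper()
stopper.initialize(batch_size, cfg)

next_tokens = paddle.full([batch_size, 1], 5, dtype="int64")  # token 5 keeps repeating
stop_flags = paddle.zeros_like(next_tokens)
for step in range(3):
    probs = paddle.full([batch_size, vocab_size], 0.001)  # low confidence elsewhere
    probs[:, 5] = 0.95                                    # high confidence on the repeated token
    stopper.process_normal(probs, next_tokens, stop_flags)
print(stop_flags.numpy().T)  # both entries flagged once the 3-step window is saturated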
@@ -44,5 +44,7 @@ class SamplingMetadata:
     top_k: Optional[paddle.Tensor] = None
     min_p: Optional[paddle.Tensor] = None
     max_num_logprobs: Optional[int] = None
+    enable_early_stop: Optional[int] = False
+    stop_flags: Optional[paddle.Tensor] = None
     prompt_ids: Optional[paddle.Tensor] = None
     prompt_lens: Optional[paddle.Tensor] = None
@@ -26,6 +26,9 @@ from fastdeploy.config import FDConfig
 from fastdeploy.model_executor.guided_decoding.base_guided_decoding import (
     LogitsProcessorBase,
 )
+from fastdeploy.model_executor.layers.sample.early_stopper import (
+    get_early_stopper_cls_from_stragegy,
+)
 from fastdeploy.model_executor.layers.sample.meta_data import SamplingMetadata
 from fastdeploy.model_executor.layers.sample.ops import (
     apply_penalty_multi_scores,
@@ -165,7 +168,7 @@ class Sampler(nn.Layer):
     Sampler for normal generation.
     """
 
-    def __init__(self):
+    def __init__(self, fd_config: FDConfig = None):
         """ """
         super().__init__()
         if (
@@ -180,6 +183,15 @@ class Sampler(nn.Layer):
             raise NotImplementedError
 
         self.processor = SamplerProcessor()
+        # Only created when fd_config.early_stop_config.enable_early_stop is True
+        if (
+            fd_config is not None
+            and fd_config.early_stop_config is not None
+            and fd_config.early_stop_config.enable_early_stop
+        ):
+            early_stopper_cls = get_early_stopper_cls_from_stragegy(fd_config.early_stop_config.strategy)
+            self.early_stopper = early_stopper_cls()
+            self.early_stopper.initialize(fd_config.parallel_config.max_num_seqs, fd_config.early_stop_config)
 
     def apply_logits_processor(
         self,
@@ -275,6 +287,10 @@ class Sampler(nn.Layer):
         logprobs_tensors = (
             None if num_logprobs is None else self.gather_logprobs(raw_logprobs, num_logprobs, token_ids=next_tokens)
         )
+        if sampling_metadata.enable_early_stop:
+            # sets the flags of the stopped batch entries in stop_flags
+            assert sampling_metadata.stop_flags is not None, "need stop_flags for early stop"
+            self.early_stopper.process(probs, next_tokens, sampling_metadata.stop_flags)
 
         self.processor.update_output_tokens(next_tokens, skip_idx_list)
@@ -15,9 +15,10 @@
 """
 
 try:
+    from .repetition_early_stop_kernel import repetition_early_stopper_kernel
     from .wint2_fused_moe import fused_moe_wint2_triton
     from .wint2_fused_moe_kernel import moe_wint2_ffn_kernel
 
-    __all__ = ["fused_moe_wint2_triton", "moe_wint2_ffn_kernel"]
+    __all__ = ["fused_moe_wint2_triton", "moe_wint2_ffn_kernel", "repetition_early_stopper_kernel"]
 except:
     pass
@@ -0,0 +1,63 @@
+"""
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+
+import triton
+import triton.language as tl
+
+
+@triton.jit
+def repetition_early_stopper_kernel(
+    trunc_ptr,  # float32[B, W]
+    probs_ptr,  # float32[B, V]
+    next_tokens_ptr,  # int32[B]
+    stop_flags,  # bool[B]
+    threshold,
+    B,  # batch size
+    W,  # window size
+    V,  # vocab size
+    stride_bw,
+    stride_bv,
+    BLOCK_W: tl.constexpr,
+):
+    b = tl.program_id(0)
+    w_offs = tl.arange(0, BLOCK_W)
+
+    # row pointers for the current batch entry
+    trunc_row = trunc_ptr + b * stride_bw
+    probs_row = probs_ptr + b * stride_bv
+
+    # step 1: index_sample equivalent, fetch the score of the chosen token
+    next_token = tl.load(next_tokens_ptr + b)
+    next_score = tl.load(probs_row + next_token)
+
+    # step 2: shift the window left: (w = 0 ~ W-2) ← (w = 1 ~ W-1)
+    mask = w_offs < W - 1
+    val = tl.load(trunc_row + w_offs + 1, mask=mask)
+    tl.store(trunc_row + w_offs, val, mask=mask)
+
+    # step 3: insert the current score at the end
+    tl.store(trunc_row + W - 1, next_score)
+
+    # step 4: determine whether all scores are greater than the threshold
+    scores = tl.load(trunc_row + w_offs, mask=w_offs < W, other=0.0)
+    is_over = scores > threshold
+    all_over = tl.sum(is_over & (w_offs < W)) == W
+
+    # step 5: set the stop flag and reset the truncated scores
+    if all_over:
+        tl.store(stop_flags + b, True)
+        zero = tl.full([BLOCK_W], 0.0, tl.float32)
+        tl.store(trunc_row + w_offs, zero, mask=w_offs < W)
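One host-side detail worth noting: tl.arange requires a power-of-2 extent, which is why process_triton rounds the window size up with triton.next_power_of_2 and the kernel masks the extra lanes with w_offs < W. A quick check, assuming triton is installed:

import triton

assert triton.next_power_of_2(3000) == 4096  # BLOCK_W for the default window_size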
@@ -57,6 +57,7 @@ class RolloutModelConfig:
         disable_any_whitespace: bool = True,
         enable_logprob: bool = False,
         graph_optimization_config: str = None,
+        early_stop_config: str = None,
         local_rank: int = 0,
     ):
         # Required parameters
@@ -100,6 +101,7 @@ class RolloutModelConfig:
         self.enable_logprob = enable_logprob
         self.graph_optimization_config = graph_optimization_config
         self.local_rank = local_rank
+        self.early_stop_config = early_stop_config
 
     def __str__(self):
         return "\n".join(f"{k}: {v}" for k, v in self.__dict__.items())
@@ -82,6 +82,7 @@ class GPUModelRunner(ModelRunnerBase):
         self.speculative_method = self.fd_config.speculative_config.method
         self.speculative_decoding = self.speculative_method is not None
         self.enable_logprob = fd_config.model_config.enable_logprob
+        self.enable_early_stop = self.fd_config.early_stop_config.enable_early_stop
 
         self.guided_backend = None
         if self.fd_config.parallel_config.guided_decoding_backend != "off":
@@ -108,10 +109,9 @@ class GPUModelRunner(ModelRunnerBase):
             "matmul_v2",
             "fused_gemm_epilogue",
         ]
-
         # Sampler
         if not self.speculative_decoding:
-            self.sampler = Sampler()
+            self.sampler = Sampler(fd_config)
         else:
             self.sampler = SpeculativeSampler(fd_config)
@@ -753,6 +753,8 @@ class GPUModelRunner(ModelRunnerBase):
             bad_words_token_ids=self.share_inputs["bad_tokens"],
             eos_token_ids=self.share_inputs["eos_token_id"],
             max_num_logprobs=20 if self.enable_logprob else None,
+            enable_early_stop=self.enable_early_stop,
+            stop_flags=self.share_inputs["stop_flags"],
         )
 
     def load_model(self) -> None:
@@ -28,6 +28,7 @@ from fastdeploy.config import (
     CacheConfig,
     DecodingConfig,
     DeviceConfig,
+    EarlyStopConfig,
     ErnieArchitectures,
     FDConfig,
     GraphOptimizationConfig,
@@ -565,6 +566,12 @@ def parse_args():
         action="store_true",
         help="Enable output of token-level log probabilities.",
     )
+    parser.add_argument(
+        "--early_stop_config",
+        type=json.loads,
+        default=None,
+        help="Configuration of early stop.",
+    )
 
     args = parser.parse_args()
     return args
@@ -608,6 +615,8 @@ def initialize_fd_config(args, ranks: int = 1, local_rank: int = 0) -> FDConfig:
 
     graph_opt_config = GraphOptimizationConfig(args.graph_optimization_config)
 
+    early_stop_config = EarlyStopConfig(args.early_stop_config)
+
     # Note(tangbinhan): used for load_checkpoint
     model_config.pretrained_config.tensor_parallel_rank = parallel_config.tensor_parallel_rank
     model_config.pretrained_config.tensor_parallel_degree = parallel_config.tensor_parallel_size
@@ -679,6 +688,7 @@ def initialize_fd_config(args, ranks: int = 1, local_rank: int = 0) -> FDConfig:
         decoding_config=decoding_config,
         quant_config=quant_config,
         graph_opt_config=graph_opt_config,
+        early_stop_config=early_stop_config,
        cache_config=cache_config,
    )
    update_fd_config_for_mm(fd_config)
test/layers/test_repetition_early_stopper.py (new file, 235 lines)
@@ -0,0 +1,235 @@
+"""
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+
+import time
+
+import numpy as np
+import paddle
+
+from fastdeploy.config import EarlyStopConfig
+from fastdeploy.model_executor.layers.sample.early_stopper import RepetitionEarlyStopper
+
+paddle.set_device("gpu")
+np.random.seed(2025)
+paddle.seed(2025)
+
+
+def simulate_step_probs(
+    batch_size, early_stop_batch_id, fixed_token_id, vocab_size, step_i, trigger_flags, high_prob=0.99
+):
+    """
+    Generate a probability distribution for the specified batch of samples;
+    some samples start to show "high confidence" after some step_i,
+    where high_prob is the confidence of the target token (such as 0.95).
+    """
+    probs = np.random.rand(batch_size, vocab_size).astype("float32")
+    probs /= probs.sum(axis=1, keepdims=True)
+
+    for i in range(batch_size):
+        if step_i >= trigger_flags[i]:
+            low_prob = (1.0 - high_prob) / (vocab_size - 1)
+            probs[i].fill(low_prob)
+            if i == early_stop_batch_id:
+                probs[i, fixed_token_id] = high_prob
+    return probs
+
+
+def remove_min_max(lst):
+    """
+    Remove the min and max values.
+    """
+    if len(lst) < 2:
+        return lst
+    min_val = min(lst)
+    max_val = max(lst)
+    return [x for x in lst if x != min_val and x != max_val]
+
+
+def test_repetition_early_stopper():
+    # This test expects only 1 batch entry to trigger early stop
+    batch_size = 20
+    vocab_size = 16
+    window_size = 4
+    threshold = 0.9
+    eos_token_id = vocab_size
+    max_steps = 10
+
+    # Select a token as the final token
+    fixed_token_id = np.random.randint(0, vocab_size)
+    # Pick one batch entry to trigger early stop
+    early_stop_batch_id = np.random.randint(0, batch_size)
+    print(f"{fixed_token_id=}\n{early_stop_batch_id=}\n{eos_token_id=}")
+
+    # Determine the first step in each batch at which the high probability starts to appear
+    trigger_step_flags = [[i, np.random.randint(0, max_steps + 1)] for i in range(batch_size)]
+    trigger_step_flags = dict(trigger_step_flags)
+    cfg = EarlyStopConfig({"enable_early_stop": True, "window_size": window_size, "threshold": threshold})
+    stopper = RepetitionEarlyStopper()
+    stopper.initialize(batch_size, cfg)
+
+    next_tokens = paddle.randint(0, vocab_size, shape=[batch_size, 1], dtype="int64")
+    next_tokens[early_stop_batch_id, 0] = fixed_token_id
+
+    print(f"{next_tokens=}\ntrigger_start={trigger_step_flags[early_stop_batch_id]}")
+
+    triggered_step = [None] * batch_size
+    stop_flags = paddle.zeros_like(next_tokens)
+    for step in range(max_steps):
+        print(f"\n===== Step {step} =====")
+        flags = [trigger_step_flags[i] for i in range(batch_size)]
+        probs_np = simulate_step_probs(batch_size, early_stop_batch_id, fixed_token_id, vocab_size, step, flags)
+        probs = paddle.to_tensor(probs_np)
+        print("Before process:")
+        print("stop_flags:\n", stop_flags.numpy().T)
+
+        stopper.process(probs, next_tokens, stop_flags)
+
+        print("After process:")
+        print("stop_flags:\n", stop_flags.numpy().T)
+
+        out_np = stop_flags.numpy()
+        for i in range(batch_size):
+            if out_np[i, 0] and triggered_step[i] is None:
+                triggered_step[i] = step
+
+    # Show at which step each batch entry triggered the early stop
+    print("trigger_step: ", triggered_step)
+    assert (
+        triggered_step[early_stop_batch_id] == trigger_step_flags[early_stop_batch_id] + window_size - 1
+    ), "not expected trigger step"
+
+
+def test_consistency():
+    batch_size = 20
+    vocab_size = 103424
+    window_size = 3000
+    threshold = 0.9
+    eos_token_id = vocab_size
+    max_steps = 10
+
+    fixed_token_id = np.random.randint(0, vocab_size)
+    early_stop_batch_id = np.random.randint(0, batch_size)
+
+    trigger_step_flags = [[i, np.random.randint(0, max_steps + 1)] for i in range(batch_size)]
+    trigger_step_flags = dict(trigger_step_flags)
+    cfg = EarlyStopConfig({"enable_early_stop": True, "window_size": window_size, "threshold": threshold})
+    stopper_normal = RepetitionEarlyStopper()
+    stopper_normal.initialize(batch_size, cfg)
+    stopper_triton = RepetitionEarlyStopper()
+    stopper_triton.initialize(batch_size, cfg)
+
+    next_tokens_normal = paddle.randint(0, vocab_size, shape=[batch_size, 1], dtype="int64")
+    next_tokens_triton = next_tokens_normal.clone()
+
+    next_tokens_normal[early_stop_batch_id, 0] = fixed_token_id
+    next_tokens_triton[early_stop_batch_id, 0] = fixed_token_id
+
+    stop_flags_normal = paddle.zeros_like(next_tokens_normal)
+    stop_flags_triton = stop_flags_normal.clone()
+
+    triggered_step_normal = [None] * batch_size
+    triggered_step_triton = [None] * batch_size
+
+    for step in range(max_steps):
+        flags = [trigger_step_flags[i] for i in range(batch_size)]
+        probs_np = simulate_step_probs(batch_size, early_stop_batch_id, fixed_token_id, vocab_size, step, flags)
+        probs = paddle.to_tensor(probs_np)
+
+        stopper_normal.process_normal(probs, next_tokens_normal, stop_flags_normal)
+        stopper_triton.process_triton(probs, next_tokens_triton, stop_flags_triton)
+
+        assert np.allclose(stop_flags_normal.numpy(), stop_flags_triton.numpy()), f"stop flags mismatch at step {step}"
+
+        trunc_scores_diff = paddle.abs(stopper_normal.trunc_scores - stopper_triton.trunc_scores)
+        assert paddle.all(trunc_scores_diff < 1e-5), f"trunc_scores mismatch at step {step}"
+
+        out_normal = stop_flags_normal.numpy()
+        out_triton = stop_flags_triton.numpy()
+        for i in range(batch_size):
+            if out_normal[i, 0] == eos_token_id and triggered_step_normal[i] is None:
+                triggered_step_normal[i] = step
+            if out_triton[i, 0] == eos_token_id and triggered_step_triton[i] is None:
+                triggered_step_triton[i] = step
+
+    for i in range(batch_size):
+        expected = triggered_step_normal[i]
+        actual = triggered_step_triton[i]
+        assert expected == actual, f"Sample {i} triggered at different steps: {expected} vs {actual}"
+
+    print("Triton vs Normal: All tokens, states, and trigger timings match.")
+
+
+def test_performance():
+    batch_size = 256
+    vocab_size = 103424
+    window_size = 3000
+    threshold = 0.9
+    eos_token_id = vocab_size
+    max_steps = 50
+
+    fixed_token_id = np.random.randint(0, vocab_size)
+    early_stop_batch_id = np.random.randint(0, batch_size)
+    print(f"{fixed_token_id=}\n{early_stop_batch_id=}")
+
+    trigger_step_flags = [[i, np.random.randint(0, max_steps + 1)] for i in range(batch_size)]
+    trigger_step_flags = dict(trigger_step_flags)
+
+    next_tokens = paddle.randint(0, vocab_size, shape=[batch_size, 1], dtype="int64")
+    next_tokens[early_stop_batch_id, 0] = fixed_token_id
+    cfg = EarlyStopConfig({"enable_early_stop": True, "window_size": window_size, "threshold": threshold})
+    print("Testing performance triton...")
+    seconds = []
+    stopper = RepetitionEarlyStopper()
+    stopper.initialize(batch_size, cfg)
+    stop_flags = paddle.zeros_like(next_tokens)
+    for step in range(max_steps):
+        flags = [trigger_step_flags[i] for i in range(batch_size)]
+        probs_np = simulate_step_probs(batch_size, early_stop_batch_id, fixed_token_id, vocab_size, step, flags)
+        probs = paddle.to_tensor(probs_np)
+        s = time.perf_counter()
+        stopper.process_triton(probs, next_tokens, stop_flags)
+        e = time.perf_counter()
+        seconds.append(e - s)
+    print(
+        f"triton:\nexecute times: {max_steps}\ntotal execution time: {np.sum(seconds)*1000} ms \navg every step execution time: {np.mean(remove_min_max(seconds))*1000} ms"
+    )
+
+    print("Testing performance normal...")
+    seconds = []
+    stopper = RepetitionEarlyStopper()
+    stopper.initialize(batch_size, cfg)
+    stop_flags = paddle.zeros_like(next_tokens)
+    for step in range(max_steps):
+        flags = [trigger_step_flags[i] for i in range(batch_size)]
+        probs_np = simulate_step_probs(batch_size, early_stop_batch_id, fixed_token_id, vocab_size, step, flags)
+        probs = paddle.to_tensor(probs_np)
+        s = time.perf_counter()
+        stopper.process_normal(probs, next_tokens, stop_flags)
+        e = time.perf_counter()
+        seconds.append(e - s)
+    print(
+        f"normal:\nexecute times: {max_steps}\ntotal execution time: {np.sum(seconds)*1000} ms \navg every step execution time: {np.mean(remove_min_max(seconds))*1000} ms"
+    )
+
+    print("Config:")
+    print(f"{batch_size=}, {window_size=}, {threshold=}, {eos_token_id=}, {vocab_size=}, {max_steps=}")
+
+
+if __name__ == "__main__":
+    test_repetition_early_stopper()
+    test_consistency()
+    test_performance()