mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-05 08:37:06 +08:00
[Feature] Support repetition early stop (#3024)
* support repetition early stop and support user to set the parameter * remove log * fix codestyle * add the early_stop_config to rollout_config * update config and EarlyStopper class * fix the bug for triton * modify the stop method * update description * modify the usage for stop_flags --------- Co-authored-by: Yuanle Liu <yuanlehome@163.com>
This commit is contained in:
@@ -26,6 +26,9 @@ from fastdeploy.config import FDConfig
|
||||
from fastdeploy.model_executor.guided_decoding.base_guided_decoding import (
|
||||
LogitsProcessorBase,
|
||||
)
|
||||
from fastdeploy.model_executor.layers.sample.early_stopper import (
|
||||
get_early_stopper_cls_from_stragegy,
|
||||
)
|
||||
from fastdeploy.model_executor.layers.sample.meta_data import SamplingMetadata
|
||||
from fastdeploy.model_executor.layers.sample.ops import (
|
||||
apply_penalty_multi_scores,
|
||||
@@ -165,7 +168,7 @@ class Sampler(nn.Layer):
|
||||
Sampler for normal generation.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self, fd_config: FDConfig = None):
|
||||
""" """
|
||||
super().__init__()
|
||||
if (
|
||||
@@ -180,6 +183,15 @@ class Sampler(nn.Layer):
|
||||
raise NotImplementedError
|
||||
|
||||
self.processor = SamplerProcessor()
|
||||
# Can only be created when fd_config.early_stopper_config.enable_early_stop = True
|
||||
if (
|
||||
fd_config is not None
|
||||
and fd_config.early_stop_config is not None
|
||||
and fd_config.early_stop_config.enable_early_stop
|
||||
):
|
||||
early_stopper_cls = get_early_stopper_cls_from_stragegy(fd_config.early_stop_config.strategy)
|
||||
self.early_stopper = early_stopper_cls()
|
||||
self.early_stopper.initialize(fd_config.parallel_config.max_num_seqs, fd_config.early_stop_config)
|
||||
|
||||
def apply_logits_processor(
|
||||
self,
|
||||
@@ -275,6 +287,10 @@ class Sampler(nn.Layer):
|
||||
logprobs_tensors = (
|
||||
None if num_logprobs is None else self.gather_logprobs(raw_logprobs, num_logprobs, token_ids=next_tokens)
|
||||
)
|
||||
if sampling_metadata.enable_early_stop:
|
||||
# will set the stop batch in stop_flags
|
||||
assert sampling_metadata.stop_flags is not None, "need stop_flags for eary stop"
|
||||
self.early_stopper.process(probs, next_tokens, sampling_metadata.stop_flags)
|
||||
|
||||
self.processor.update_output_tokens(next_tokens, skip_idx_list)
|
||||
|
||||
|
Reference in New Issue
Block a user