"""
|
|
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
"""
|
|
|
|
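
# Standalone checks for RepetitionEarlyStopper: trigger timing, consistency
# between the normal and Triton implementations, and per-step latency.
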
import time

import numpy as np
import paddle

from fastdeploy.config import EarlyStopConfig
from fastdeploy.model_executor.layers.sample.early_stopper import RepetitionEarlyStopper

paddle.set_device("gpu")
np.random.seed(2025)
paddle.seed(2025)
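
# The fixed seeds above keep the randomly drawn token ids, trigger steps and
# probability tables reproducible across runs.
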
def simulate_step_probs(
    batch_size, early_stop_batch_id, fixed_token_id, vocab_size, step_i, trigger_flags, high_prob=0.99
):
    """
    Generate a probability distribution for every sample in the batch.
    Once step_i reaches a sample's trigger step, its row is flattened to a
    uniform low probability, and the designated early-stop sample additionally
    puts high_prob (0.99 by default) on the fixed target token.
    """
    probs = np.random.rand(batch_size, vocab_size).astype("float32")
    probs /= probs.sum(axis=1, keepdims=True)

    for i in range(batch_size):
        if step_i >= trigger_flags[i]:
            low_prob = (1.0 - high_prob) / (vocab_size - 1)
            probs[i].fill(low_prob)
            if i == early_stop_batch_id:
                probs[i, fixed_token_id] = high_prob
    return probs
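
# Illustrative call (an assumption for exposition, not executed by the tests):
# simulate_step_probs(2, 0, 3, 4, step_i=1, trigger_flags=[0, 2]) returns a
# 2x4 matrix whose row 0 has already passed its trigger step and so carries
# ~0.99 on token 3, while row 1 is still an ordinary normalized random
# distribution.
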
def remove_min_max(lst):
    """
    Remove the minimum and maximum values from the list.
    """
    if len(lst) < 2:
        return lst
    min_val = min(lst)
    max_val = max(lst)
    return [x for x in lst if x != min_val and x != max_val]

def test_repetition_early_stopper():
    # Only one batch entry is expected to trigger early stop in this test
    batch_size = 20
    vocab_size = 16
    window_size = 4
    threshold = 0.9
    eos_token_id = vocab_size
    max_steps = 10

    # Select a token to act as the repeated final token
    fixed_token_id = np.random.randint(0, vocab_size)
    # Pick the batch entry that should trigger early stop
    early_stop_batch_id = np.random.randint(0, batch_size)
    print(f"{fixed_token_id=}\n{early_stop_batch_id=}\n{eos_token_id=}")

    # Determine, for each batch entry, the first step at which high confidence appears
    trigger_step_flags = [[i, np.random.randint(0, max_steps + 1)] for i in range(batch_size)]
    trigger_step_flags = dict(trigger_step_flags)
    cfg = EarlyStopConfig({"enable_early_stop": True, "window_size": window_size, "threshold": threshold})
    stopper = RepetitionEarlyStopper()
    stopper.initialize(batch_size, cfg)

    next_tokens = paddle.randint(0, vocab_size, shape=[batch_size, 1], dtype="int64")
    next_tokens[early_stop_batch_id, 0] = fixed_token_id

    print(f"{next_tokens=}\ntrigger_start={trigger_step_flags[early_stop_batch_id]}")

    triggered_step = [None] * batch_size
    stop_flags = paddle.zeros_like(next_tokens)
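    # Expected behaviour, as encoded by the assertion at the end of this test:
    # the stopper fires once the tracked token's probability has stayed above
    # `threshold` for `window_size` consecutive steps, i.e. at step
    # trigger_start + window_size - 1 for the chosen batch entry.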
    for step in range(max_steps):
        print(f"\n===== Step {step} =====")
        flags = [trigger_step_flags[i] for i in range(batch_size)]
        probs_np = simulate_step_probs(batch_size, early_stop_batch_id, fixed_token_id, vocab_size, step, flags)
        probs = paddle.to_tensor(probs_np)
        print("Before process:")
        print("stop_flags:\n", stop_flags.numpy().T)

        stopper.process(probs, next_tokens, stop_flags)

        print("After process:")
        print("stop_flags:\n", stop_flags.numpy().T)

        out_np = stop_flags.numpy()
        for i in range(batch_size):
            if out_np[i, 0] and triggered_step[i] is None:
                triggered_step[i] = step

    # Show at which step each batch entry triggered early stop
    print("trigger_step: ", triggered_step)
    assert (
        triggered_step[early_stop_batch_id] == trigger_step_flags[early_stop_batch_id] + window_size - 1
    ), "unexpected trigger step"

def test_consistency():
    batch_size = 20
    vocab_size = 103424
    window_size = 3000
    threshold = 0.9
    eos_token_id = vocab_size
    max_steps = 10

    fixed_token_id = np.random.randint(0, vocab_size)
    early_stop_batch_id = np.random.randint(0, batch_size)

    trigger_step_flags = [[i, np.random.randint(0, max_steps + 1)] for i in range(batch_size)]
    trigger_step_flags = dict(trigger_step_flags)
    cfg = EarlyStopConfig({"enable_early_stop": True, "window_size": window_size, "threshold": threshold})
    stopper_normal = RepetitionEarlyStopper()
    stopper_normal.initialize(batch_size, cfg)
    stopper_triton = RepetitionEarlyStopper()
    stopper_triton.initialize(batch_size, cfg)

    next_tokens_normal = paddle.randint(0, vocab_size, shape=[batch_size, 1], dtype="int64")
    next_tokens_triton = next_tokens_normal.clone()

    next_tokens_normal[early_stop_batch_id, 0] = fixed_token_id
    next_tokens_triton[early_stop_batch_id, 0] = fixed_token_id

    stop_flags_normal = paddle.zeros_like(next_tokens_normal)
    stop_flags_triton = stop_flags_normal.clone()

    triggered_step_normal = [None] * batch_size
    triggered_step_triton = [None] * batch_size
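    # As the method names suggest, process_normal and process_triton are the
    # plain-Paddle and Triton-kernel paths of RepetitionEarlyStopper; every
    # step cross-checks their stop flags and internal trunc_scores buffers.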
    for step in range(max_steps):
        flags = [trigger_step_flags[i] for i in range(batch_size)]
        probs_np = simulate_step_probs(batch_size, early_stop_batch_id, fixed_token_id, vocab_size, step, flags)
        probs = paddle.to_tensor(probs_np)

        stopper_normal.process_normal(probs, next_tokens_normal, stop_flags_normal)
        stopper_triton.process_triton(probs, next_tokens_triton, stop_flags_triton)

        assert np.allclose(stop_flags_normal.numpy(), stop_flags_triton.numpy()), f"stop flags mismatch at step {step}"

        trunc_scores_diff = paddle.abs(stopper_normal.trunc_scores - stopper_triton.trunc_scores)
        assert paddle.all(trunc_scores_diff < 1e-5), f"trunc_scores mismatch at step {step}"

        out_normal = stop_flags_normal.numpy()
        out_triton = stop_flags_triton.numpy()
        for i in range(batch_size):
            if out_normal[i, 0] and triggered_step_normal[i] is None:
                triggered_step_normal[i] = step
            if out_triton[i, 0] and triggered_step_triton[i] is None:
                triggered_step_triton[i] = step

    for i in range(batch_size):
        expected = triggered_step_normal[i]
        actual = triggered_step_triton[i]
        assert expected == actual, f"Sample {i} triggered at different steps: {expected} vs {actual}"

    print("Triton vs Normal: All tokens, states, and trigger timings match.")

def test_performance():
    batch_size = 256
    vocab_size = 103424
    window_size = 3000
    threshold = 0.9
    eos_token_id = vocab_size
    max_steps = 50

    fixed_token_id = np.random.randint(0, vocab_size)
    early_stop_batch_id = np.random.randint(0, batch_size)
    print(f"{fixed_token_id=}\n{early_stop_batch_id=}")

    trigger_step_flags = [[i, np.random.randint(0, max_steps + 1)] for i in range(batch_size)]
    trigger_step_flags = dict(trigger_step_flags)

    next_tokens = paddle.randint(0, vocab_size, shape=[batch_size, 1], dtype="int64")
    next_tokens[early_stop_batch_id, 0] = fixed_token_id
    cfg = EarlyStopConfig({"enable_early_stop": True, "window_size": window_size, "threshold": threshold})
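    # Timing note: time.perf_counter() around a GPU op may mostly measure
    # host-side launch overhead unless the device is synchronized first, and
    # remove_min_max() drops the fastest and slowest samples (e.g. first-call
    # warmup) before averaging.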
print("Testing performance triton...")
|
|
seconds = []
|
|
stopper = RepetitionEarlyStopper()
|
|
stopper.initialize(batch_size, cfg)
|
|
stop_flags = paddle.zeros_like(next_tokens)
|
|
for step in range(max_steps):
|
|
flags = [trigger_step_flags[i] for i in range(batch_size)]
|
|
probs_np = simulate_step_probs(batch_size, early_stop_batch_id, fixed_token_id, vocab_size, step, flags)
|
|
probs = paddle.to_tensor(probs_np)
|
|
s = time.perf_counter()
|
|
stopper.process_triton(probs, next_tokens, stop_flags)
|
|
e = time.perf_counter()
|
|
seconds.append(e - s)
|
|
print(
|
|
f"triton:\nexecute times: {max_steps}\ntotal execution time: {np.sum(seconds)*1000} ms \navg every step execution time: {np.mean(remove_min_max(seconds))*1000} ms"
|
|
)
|
|
|
|
print("Testing performance normal...")
|
|
seconds = []
|
|
stopper = RepetitionEarlyStopper()
|
|
stopper.initialize(batch_size, cfg)
|
|
stop_flags = paddle.zeros_like(next_tokens)
|
|
for step in range(max_steps):
|
|
flags = [trigger_step_flags[i] for i in range(batch_size)]
|
|
probs_np = simulate_step_probs(batch_size, early_stop_batch_id, fixed_token_id, vocab_size, step, flags)
|
|
probs = paddle.to_tensor(probs_np)
|
|
s = time.perf_counter()
|
|
stopper.process_normal(probs, next_tokens, stop_flags)
|
|
e = time.perf_counter()
|
|
seconds.append(e - s)
|
|
print(
|
|
f"normal:\nexecute times: {max_steps}\ntotal execution time: {np.sum(seconds)*1000} ms \navg every step execution time: {np.mean(remove_min_max(seconds))*1000} ms"
|
|
)
|
|
|
|
print("Config:")
|
|
print(f"{batch_size=}, {window_size=}, {threshold=}, {eos_token_id=}, {vocab_size=}, {max_steps=}")
|
|
|
|
|
|
if __name__ == "__main__":
|
|
test_repetition_early_stopper()
|
|
test_consistency()
|
|
test_performance()
|