[Feature] support prompt repetition_penalty (#2954)

* [Feature] support prompt repetition_penalty (#2806)

* [Bug Fix] fix bug of prompt penalty (#2888)
This commit is contained in:
ming1753
2025-07-22 19:42:33 +08:00
committed by GitHub
parent 535a15ab8f
commit 69be77c8c0
8 changed files with 305 additions and 64 deletions

View File

@@ -20,16 +20,16 @@ __global__ inline void min_length_logits_process(T *logits,
const int64_t *min_len,
const int64_t *eos_token_id,
const int64_t bs,
const int64_t length,
const int64_t end_length) {
const int64_t vocab_size,
const int64_t eos_len) {
int bi = threadIdx.x;
if (bi >= bs) return;
if (cur_len[bi] < 0) {
return;
}
if (cur_len[bi] < min_len[bi]) {
for (int i = 0; i < end_length; i++) {
logits[bi * length + eos_token_id[i]] = -1e10;
for (int i = 0; i < eos_len; i++) {
logits[bi * vocab_size + eos_token_id[i]] = -1e10;
}
}
}
@@ -41,61 +41,83 @@ __global__ inline void min_length_logits_process<half>(
const int64_t *min_len,
const int64_t *eos_token_id,
const int64_t bs,
const int64_t length,
const int64_t end_length) {
const int64_t vocab_size,
const int64_t eos_len) {
int bi = threadIdx.x;
if (bi >= bs) return;
if (cur_len[bi] < 0) {
return;
}
if (cur_len[bi] < min_len[bi]) {
for (int i = 0; i < end_length; i++) {
logits[bi * length + eos_token_id[i]] = -1e4;
for (int i = 0; i < eos_len; i++) {
logits[bi * vocab_size + eos_token_id[i]] = -1e4;
}
}
}
__global__ void update_repeat_times(const int64_t *pre_ids,
const int64_t *prompt_ids,
const int64_t *prompt_len,
const int64_t *cur_len,
int *repeat_times,
int *is_repeated,
const int64_t bs,
const int64_t length,
const int64_t length_id) {
int bi = blockIdx.x;
const int64_t vocab_size,
const int64_t max_dec_len,
const int64_t max_model_len) {
int64_t bi = blockIdx.x;
if (cur_len[bi] < 0) {
return;
}
int tid = threadIdx.x;
const int64_t *pre_ids_now = pre_ids + bi * length_id;
int *repeat_times_now = repeat_times + bi * length;
for (int i = tid; i < length_id; i += blockDim.x) {
int64_t id = pre_ids_now[i];
if (id < 0) break;
atomicAdd(&repeat_times_now[id], 1);
const int64_t prompt_len_now = prompt_len[bi];
int64_t tid = threadIdx.x;
const int64_t *prompt_now = prompt_ids + bi * max_model_len;
const int64_t *pre_ids_now = pre_ids + bi * max_dec_len;
int *repeat_times_now = repeat_times + bi * vocab_size;
int *is_repeated_now = is_repeated + bi * vocab_size;
const int64_t loop_len = prompt_len_now > max_dec_len ? prompt_len_now : max_dec_len;
for (int64_t i = tid; i < loop_len; i += blockDim.x) {
if (i < max_dec_len) {
int64_t id = pre_ids_now[i];
if (id >= 0) {
atomicAdd(&repeat_times_now[id], 1);
atomicAdd(&is_repeated_now[id], 1);
}
}
if (i < prompt_len_now) {
int64_t id = prompt_now[i];
if (id >= 0) {
atomicAdd(&is_repeated_now[id], 1);
}
}
}
}
template <typename T>
__global__ void update_value_by_repeat_times(const int *repeat_times,
const int *is_repeated,
const T *penalty_scores,
const T *frequency_score,
const T *presence_score,
const float *temperatures,
T *logits,
const int64_t bs,
const int64_t length) {
const int64_t vocab_size) {
int bi = blockIdx.x;
int tid = threadIdx.x;
T *logits_now = logits + bi * length;
const int *repeat_times_now = repeat_times + bi * length;
T *logits_now = logits + bi * vocab_size;
const int *repeat_times_now = repeat_times + bi * vocab_size;
const int *is_repeated_now = is_repeated + bi * vocab_size;
float alpha = static_cast<float>(penalty_scores[bi]);
float beta = static_cast<float>(frequency_score[bi]);
float gamma = static_cast<float>(presence_score[bi]);
for (int i = tid; i < length; i += blockDim.x) {
for (int i = tid; i < vocab_size; i += blockDim.x) {
int times = repeat_times_now[i];
float logit_now = static_cast<float>(logits_now[i]);
if (times != 0) {
if (is_repeated_now[i] != 0) {
logit_now = logit_now < 0 ? logit_now * alpha : logit_now / alpha;
}
if (times != 0) {
logit_now = logit_now - times * beta - gamma;
}
logits_now[i] = static_cast<T>(logit_now / temperatures[bi]);
@@ -106,20 +128,22 @@ template <typename T>
__global__ void ban_bad_words(T *logits,
const int64_t *bad_words_list,
const int64_t bs,
const int64_t length,
const int64_t bad_words_length) {
const int64_t vocab_size,
const int64_t bad_words_len) {
const int bi = blockIdx.x;
int tid = threadIdx.x;
T *logits_now = logits + bi * length;
for (int i = tid; i < bad_words_length; i += blockDim.x) {
T *logits_now = logits + bi * vocab_size;
for (int i = tid; i < bad_words_len; i += blockDim.x) {
const int64_t bad_words_token_id = bad_words_list[i];
if (bad_words_token_id >= length || bad_words_token_id < 0) continue;
if (bad_words_token_id >= vocab_size || bad_words_token_id < 0) continue;
logits_now[bad_words_token_id] = -1e10;
}
}
template <paddle::DataType D>
void token_penalty_multi_scores_kernel(const paddle::Tensor &pre_ids,
const paddle::Tensor &prompt_ids,
const paddle::Tensor &prompt_len,
const paddle::Tensor &logits,
const paddle::Tensor &penalty_scores,
const paddle::Tensor &frequency_score,
@@ -141,12 +165,15 @@ void token_penalty_multi_scores_kernel(const paddle::Tensor &pre_ids,
std::vector<int64_t> shape = logits.shape();
auto repeat_times =
paddle::full(shape, 0, paddle::DataType::INT32, pre_ids.place());
auto is_repeated =
paddle::full(shape, 0, paddle::DataType::INT32, pre_ids.place());
int64_t bs = shape[0];
int64_t length = shape[1];
int64_t length_id = pre_ids.shape()[1];
int64_t length_bad_words = bad_tokens.shape()[0];
int64_t end_length = eos_token_id.shape()[0];
int64_t vocab_size = shape[1];
int64_t max_dec_len = pre_ids.shape()[1];
int64_t bad_words_len = bad_tokens.shape()[0];
int64_t eos_len = eos_token_id.shape()[0];
int64_t max_model_len = prompt_ids.shape()[1];
int block_size = (bs + WARP_SIZE - 1) / WARP_SIZE * WARP_SIZE;
min_length_logits_process<<<1, block_size, 0, cu_stream>>>(
@@ -156,10 +183,10 @@ void token_penalty_multi_scores_kernel(const paddle::Tensor &pre_ids,
min_len.data<int64_t>(),
eos_token_id.data<int64_t>(),
bs,
length,
end_length);
vocab_size,
eos_len);
block_size = (length_id + WARP_SIZE - 1) / WARP_SIZE * WARP_SIZE;
block_size = (max_dec_len + WARP_SIZE - 1) / WARP_SIZE * WARP_SIZE;
#ifdef PADDLE_WITH_COREX
block_size = std::min(block_size, 512);
#else
@@ -167,13 +194,17 @@ void token_penalty_multi_scores_kernel(const paddle::Tensor &pre_ids,
#endif
update_repeat_times<<<bs, block_size, 0, cu_stream>>>(
pre_ids.data<int64_t>(),
prompt_ids.data<int64_t>(),
prompt_len.data<int64_t>(),
cur_len.data<int64_t>(),
repeat_times.data<int>(),
is_repeated.data<int>(),
bs,
length,
length_id);
vocab_size,
max_dec_len,
max_model_len);
block_size = (length + WARP_SIZE - 1) / WARP_SIZE * WARP_SIZE;
block_size = (vocab_size + WARP_SIZE - 1) / WARP_SIZE * WARP_SIZE;
#ifdef PADDLE_WITH_COREX
block_size = std::min(block_size, 512);
#else
@@ -181,6 +212,7 @@ void token_penalty_multi_scores_kernel(const paddle::Tensor &pre_ids,
#endif
update_value_by_repeat_times<DataType_><<<bs, block_size, 0, cu_stream>>>(
repeat_times.data<int>(),
is_repeated.data<int>(),
reinterpret_cast<DataType_ *>(
const_cast<data_t *>(penalty_scores.data<data_t>())),
reinterpret_cast<DataType_ *>(
@@ -191,9 +223,9 @@ void token_penalty_multi_scores_kernel(const paddle::Tensor &pre_ids,
reinterpret_cast<DataType_ *>(
const_cast<data_t *>(logits.data<data_t>())),
bs,
length);
vocab_size);
block_size = (length_bad_words + WARP_SIZE - 1) / WARP_SIZE * WARP_SIZE;
block_size = (bad_words_len + WARP_SIZE - 1) / WARP_SIZE * WARP_SIZE;
#ifdef PADDLE_WITH_COREX
block_size = std::min(block_size, 512);
#else
@@ -204,11 +236,13 @@ void token_penalty_multi_scores_kernel(const paddle::Tensor &pre_ids,
const_cast<data_t *>(logits.data<data_t>())),
bad_tokens.data<int64_t>(),
bs,
length,
length_bad_words);
vocab_size,
bad_words_len);
}
void TokenPenaltyMultiScores(const paddle::Tensor &pre_ids,
const paddle::Tensor &prompt_ids,
const paddle::Tensor &prompt_len,
const paddle::Tensor &logits,
const paddle::Tensor &penalty_scores,
const paddle::Tensor &frequency_scores,
@@ -222,6 +256,8 @@ void TokenPenaltyMultiScores(const paddle::Tensor &pre_ids,
case paddle::DataType::BFLOAT16: {
return token_penalty_multi_scores_kernel<
paddle::DataType::BFLOAT16>(pre_ids,
prompt_ids,
prompt_len,
logits,
penalty_scores,
frequency_scores,
@@ -233,30 +269,34 @@ void TokenPenaltyMultiScores(const paddle::Tensor &pre_ids,
eos_token_id);
}
case paddle::DataType::FLOAT16: {
return token_penalty_multi_scores_kernel<paddle::DataType::FLOAT16>(
pre_ids,
logits,
penalty_scores,
frequency_scores,
presence_scores,
temperatures,
bad_tokens,
cur_len,
min_len,
eos_token_id);
return token_penalty_multi_scores_kernel<
paddle::DataType::FLOAT16>(pre_ids,
prompt_ids,
prompt_len,
logits,
penalty_scores,
frequency_scores,
presence_scores,
temperatures,
bad_tokens,
cur_len,
min_len,
eos_token_id);
}
case paddle::DataType::FLOAT32: {
return token_penalty_multi_scores_kernel<paddle::DataType::FLOAT32>(
pre_ids,
logits,
penalty_scores,
frequency_scores,
presence_scores,
temperatures,
bad_tokens,
cur_len,
min_len,
eos_token_id);
return token_penalty_multi_scores_kernel<
paddle::DataType::FLOAT32>(pre_ids,
prompt_ids,
prompt_len,
logits,
penalty_scores,
frequency_scores,
presence_scores,
temperatures,
bad_tokens,
cur_len,
min_len,
eos_token_id);
}
default: {
PD_THROW(
@@ -269,6 +309,8 @@ void TokenPenaltyMultiScores(const paddle::Tensor &pre_ids,
PD_BUILD_STATIC_OP(get_token_penalty_multi_scores)
.Inputs({"pre_ids",
"prompt_ids",
"prompt_len",
"logits",
"penalty_scores",
"frequency_scores",

View File

@@ -43,3 +43,5 @@ class SamplingMetadata:
top_p: paddle.Tensor
top_k: Optional[paddle.Tensor] = None
max_num_logprobs: Optional[int] = None
prompt_ids: Optional[paddle.Tensor] = None
prompt_lens: Optional[paddle.Tensor] = None

View File

@@ -21,6 +21,8 @@ from fastdeploy.platforms import current_platform
def apply_penalty_multi_scores(
pre_token_ids: paddle.Tensor,
prompt_ids: paddle.Tensor,
prompt_lens: paddle.Tensor,
logits: paddle.Tensor,
repetition_penalties: paddle.Tensor,
frequency_penalties: paddle.Tensor,
@@ -39,6 +41,8 @@ def apply_penalty_multi_scores(
get_token_penalty_multi_scores
logits = get_token_penalty_multi_scores(
pre_token_ids,
prompt_ids,
prompt_lens,
logits,
repetition_penalties,
frequency_penalties,
@@ -69,6 +73,8 @@ def apply_penalty_multi_scores(
get_token_penalty_multi_scores
logits = get_token_penalty_multi_scores(
pre_token_ids,
prompt_ids,
prompt_lens,
logits,
repetition_penalties,
frequency_penalties,

View File

@@ -253,6 +253,8 @@ class Sampler(nn.Layer):
logits = apply_penalty_multi_scores(
sampling_metadata.pre_token_ids,
sampling_metadata.prompt_ids,
sampling_metadata.prompt_lens,
logits,
sampling_metadata.repetition_penalties,
sampling_metadata.frequency_penalties,

View File

@@ -216,12 +216,15 @@ class GPUModelRunner(ModelRunnerBase):
1] = request.prompt_token_ids[-1]
self.share_inputs["input_ids"][idx:idx + 1,
0] = request.prompt_token_ids[0]
self.share_inputs["prompt_ids"][idx:idx + 1,
:length] = np.array(request.prompt_token_ids)
self.share_inputs['seq_lens_encoder'][idx:idx + 1] = 0
self.share_inputs['seq_lens_decoder'][idx:idx + 1] = length
self.share_inputs['seq_lens_this_time'][idx:idx + 1] = 1
self.share_inputs['step_seq_lens_encoder'][idx:idx + 1] = 0
self.share_inputs['step_seq_lens_decoder'][idx:idx +
1] = length
self.share_inputs["prompt_lens"][idx:idx + 1] = length
self.share_inputs['step_idx'][idx:idx + 1] = 1
if self.speculative_decoding:
@@ -236,6 +239,9 @@ class GPUModelRunner(ModelRunnerBase):
self.share_inputs["input_ids"][idx:idx +
1, :length] = np.array(
request.prompt_token_ids)
self.share_inputs["prompt_ids"][idx:idx +
1, :length] = np.array(
request.prompt_token_ids)
# Use chunked prefill
if self.parallel_config.enable_chunked_prefill:
@@ -275,6 +281,7 @@ class GPUModelRunner(ModelRunnerBase):
idx:idx + 1] = token_chunk_size
self.share_inputs['seq_lens_encoder'][idx:idx +
1] = token_chunk_size
self.share_inputs["prompt_lens"][idx:idx + 1] = token_chunk_size
else:
if self.enable_mm:
inputs = self._preprocess_mm_task(request.multimodal_inputs)
@@ -299,6 +306,7 @@ class GPUModelRunner(ModelRunnerBase):
self.share_inputs['step_seq_lens_encoder'][idx:idx +
1] = length
self.share_inputs['seq_lens_encoder'][idx:idx + 1] = length
self.share_inputs["prompt_lens"][idx:idx + 1] = length
if self.enable_mm:
enable_thinking = request.get("enable_thinking", True)
@@ -397,6 +405,8 @@ class GPUModelRunner(ModelRunnerBase):
self.share_inputs["input_ids"][idx:idx +
1, :input_length] = np.array(
[5] * input_length)
self.share_inputs["prompt_ids"][idx:idx + 1, :input_length] = np.array(
[5] * input_length)
self.share_inputs["eos_token_id"][:] = np.array(
[2], dtype="int64").reshape(-1, 1)
self.share_inputs["seq_lens_this_time"][idx:idx + 1] = input_length
@@ -404,6 +414,7 @@ class GPUModelRunner(ModelRunnerBase):
1] = input_length
self.share_inputs["seq_lens_encoder"][idx:idx + 1] = input_length
self.share_inputs["seq_lens_decoder"][idx:idx + 1] = 0
self.share_inputs["prompt_lens"][idx:idx + 1] = 0
self.share_inputs["step_idx"][idx:idx + 1] = 0
self.share_inputs["max_dec_len"][idx:idx + 1] = max_dec_len
self.share_inputs["min_dec_len"][idx:idx + 1] = max_dec_len
@@ -434,6 +445,10 @@ class GPUModelRunner(ModelRunnerBase):
[max_num_seqs, self.parallel_config.max_model_len],
self.parallel_config.pad_token_id,
dtype='int64')
self.share_inputs["prompt_ids"] = paddle.full(
[max_num_seqs, self.parallel_config.max_model_len],
self.parallel_config.pad_token_id,
dtype='int64')
self.share_inputs["eos_token_id"] = paddle.full(
[self.parallel_config.eos_tokens_lens, 1], 0, dtype='int64')
self.share_inputs["top_p"] = paddle.full([max_num_seqs, 1],
@@ -478,6 +493,9 @@ class GPUModelRunner(ModelRunnerBase):
[max_num_seqs, 1], 0, dtype='int32')
self.share_inputs["step_seq_lens_decoder"] = paddle.full(
[max_num_seqs, 1], 0, dtype='int32')
self.share_inputs["prompt_lens"] = paddle.full([max_num_seqs, 1],
0,
dtype='int64')
self.share_inputs["step_idx"] = paddle.full([max_num_seqs, 1],
0,
dtype='int64')
@@ -687,6 +705,8 @@ class GPUModelRunner(ModelRunnerBase):
top_k=self.share_inputs["top_k"],
step_idx=self.share_inputs["step_idx"],
pre_token_ids=self.share_inputs["pre_ids"],
prompt_ids=self.share_inputs["prompt_ids"],
prompt_lens=self.share_inputs["prompt_lens"],
frequency_penalties=self.share_inputs["frequency_score"],
presence_penalties=self.share_inputs["presence_score"],
repetition_penalties=self.share_inputs["penalty_score"],
@@ -1022,6 +1042,10 @@ class GPUModelRunner(ModelRunnerBase):
self.share_inputs["image_features"] = None
token_chunk_size = inputs["input_ids"].shape[1]
self.share_inputs["input_ids"][idx:idx + 1, :token_chunk_size] = inputs["input_ids"]
self.share_inputs["prompt_ids"][
idx:idx + 1,
self.share_inputs["prompt_lens"][idx:idx + 1]: self.share_inputs["prompt_lens"][idx:idx + 1] + token_chunk_size
] = inputs["input_ids"]
self.share_inputs["seq_lens_decoder"][idx:idx +1] = task.start_idx
task.start_idx += token_chunk_size
else:
@@ -1034,6 +1058,7 @@ class GPUModelRunner(ModelRunnerBase):
1] = token_chunk_size
self.share_inputs['seq_lens_encoder'][idx:idx +
1] = token_chunk_size
self.share_inputs["prompt_lens"][idx:idx + 1] += token_chunk_size
self.share_inputs["step_idx"][idx:idx + 1] = 0
if self.speculative_decoding and self.proposer.is_chunk_prefill_enabled(

View File

@@ -174,6 +174,7 @@ class IluvatarModelRunner(ModelRunnerBase):
self.share_inputs['step_seq_lens_encoder'][idx:idx + 1] = 0
self.share_inputs['step_seq_lens_decoder'][idx:idx +
1] = length
self.share_inputs["prompt_lens"][idx:idx + 1] = length
self.share_inputs['step_idx'][idx:idx + 1] = 1
if self.speculative_decoding:
@@ -208,6 +209,7 @@ class IluvatarModelRunner(ModelRunnerBase):
idx:idx + 1] = request.get("seq_lens_decoder", 0)
self.share_inputs['step_seq_lens_decoder'][
idx:idx + 1] = request.get("seq_lens_decoder", 0)
self.share_inputs["prompt_lens"][idx:idx + 1] = token_chunk_size
else:
self.share_inputs['seq_lens_decoder'][
idx:idx + 1] = request.get("seq_lens_decoder", 0)
@@ -218,6 +220,7 @@ class IluvatarModelRunner(ModelRunnerBase):
self.share_inputs['step_seq_lens_encoder'][idx:idx +
1] = length
self.share_inputs['seq_lens_encoder'][idx:idx + 1] = length
self.share_inputs["prompt_lens"][idx:idx + 1] = length
if len(request.eos_token_ids
) < self.parallel_config.eos_tokens_lens:
@@ -290,6 +293,8 @@ class IluvatarModelRunner(ModelRunnerBase):
self.share_inputs["input_ids"][idx:idx +
1, :input_length] = np.array(
[5] * input_length)
self.share_inputs["prompt_ids"][idx:idx + 1, :input_length] = np.array(
[5] * input_length)
self.share_inputs["eos_token_id"][:] = np.array(
[2], dtype="int64").reshape(-1, 1)
self.share_inputs["seq_lens_this_time"][idx:idx + 1] = input_length
@@ -297,6 +302,7 @@ class IluvatarModelRunner(ModelRunnerBase):
1] = input_length
self.share_inputs["seq_lens_encoder"][idx:idx + 1] = input_length
self.share_inputs["seq_lens_decoder"][idx:idx + 1] = 0
self.share_inputs["prompt_lens"][idx:idx + 1] = 0
self.share_inputs["step_idx"][idx:idx + 1] = 0
self.share_inputs["max_dec_len"][idx:idx + 1] = max_dec_len
self.share_inputs["stop_flags"][idx:idx + 1] = False
@@ -325,6 +331,10 @@ class IluvatarModelRunner(ModelRunnerBase):
[max_num_seqs, self.parallel_config.max_model_len],
self.parallel_config.pad_token_id,
dtype='int64')
self.share_inputs["prompt_ids"] = paddle.full(
[max_num_seqs, self.parallel_config.max_model_len],
self.parallel_config.pad_token_id,
dtype='int64')
self.share_inputs["eos_token_id"] = paddle.full(
[self.parallel_config.eos_tokens_lens, 1], 0, dtype='int64')
self.share_inputs["top_p"] = paddle.full([max_num_seqs, 1],
@@ -369,6 +379,9 @@ class IluvatarModelRunner(ModelRunnerBase):
[max_num_seqs, 1], 0, dtype='int32')
self.share_inputs["step_seq_lens_decoder"] = paddle.full(
[max_num_seqs, 1], 0, dtype='int32')
self.share_inputs["prompt_lens"] = paddle.full([max_num_seqs, 1],
0,
dtype='int64')
self.share_inputs["step_idx"] = paddle.full([max_num_seqs, 1],
0,
dtype='int64')
@@ -563,6 +576,8 @@ class IluvatarModelRunner(ModelRunnerBase):
top_k=self.share_inputs["top_k"],
step_idx=self.share_inputs["step_idx"],
pre_token_ids=self.share_inputs["pre_ids"],
prompt_ids=self.share_inputs["prompt_ids"],
prompt_lens=self.share_inputs["prompt_lens"],
frequency_penalties=self.share_inputs["frequency_score"],
presence_penalties=self.share_inputs["presence_score"],
repetition_penalties=self.share_inputs["penalty_score"],
@@ -845,6 +860,7 @@ class IluvatarModelRunner(ModelRunnerBase):
token_chunk_size])
self.share_inputs['seq_lens_encoder'][idx:idx +
1] = token_chunk_size
self.share_inputs["prompt_lens"][idx:idx + 1] += token_chunk_size
self.share_inputs["step_idx"][idx:idx + 1] = 0
self.share_inputs["seq_lens_decoder"][
idx:idx + 1] = start_idx + task.get("seq_lens_decoder", 0)

View File

@@ -57,6 +57,12 @@ def _create_default_sampling_metadata(
top_p=paddle.full(shape=[batch_size, 1],
fill_value=0.7,
dtype="float32"),
prompt_ids=paddle.full(shape=[batch_size, max_seq_len],
fill_value=0,
dtype="int64"),
prompt_lens=paddle.full(shape=[batch_size, 1],
fill_value=5,
dtype="int64"),
step_idx=paddle.full(shape=[batch_size, 1],
fill_value=0,
dtype="int64"),

View File

@@ -0,0 +1,142 @@
# Copyright (c) 2025PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" UT for air_topp_sampling kernel """
import copy
import unittest
import numpy as np
import paddle
class Test(unittest.TestCase):
def setUp(self):
"""
Initialize.
"""
self.num_seqs = 4
self.max_model_len = 32768
self.vocab_size = 103424
# prompt token
prompt_ids = paddle.full(shape=[self.num_seqs, self.max_model_len], fill_value=0, dtype='int64')
prompt_lens = paddle.randint(low=0, high=100, shape=[self.num_seqs, 1], dtype='int64')
fake_tokens = paddle.randint(low=3, high=self.vocab_size, shape=[self.num_seqs, self.max_model_len], dtype='int64')
for i in range(self.num_seqs):
prompt_ids[i, :prompt_lens[i]] = fake_tokens[i, :prompt_lens[i]]
# generated token
pre_ids = paddle.full(shape=[self.num_seqs, self.max_model_len], fill_value=-1, dtype='int64')
step_idx = paddle.randint(low=0, high=100, shape=[self.num_seqs, 1], dtype='int64')
fake_tokens = paddle.randint(low=3, high=self.vocab_size, shape=[self.num_seqs, self.max_model_len], dtype='int64')
for i in range(self.num_seqs):
pre_ids[i, :step_idx[i]] = fake_tokens[i, :step_idx[i]]
logits = paddle.randn([self.num_seqs, self.vocab_size]).cast("float32")
penalty_score = paddle.ones([self.num_seqs, 1]) * 1.05
frequency_score = paddle.ones([self.num_seqs, 1]) * 0.5
presence_score = paddle.ones([self.num_seqs, 1]) * 0.3
temperature = paddle.ones([self.num_seqs, 1]) * 0.8
bad_tokens = paddle.to_tensor([[-1]]).cast("int64")
min_dec_len = paddle.ones([self.num_seqs, 1]).cast("int64")
eos_token_id = paddle.to_tensor([[2]]).cast("int64")
self.input_data = {
"prompt_ids": prompt_ids,
"prompt_lens": prompt_lens,
"pre_ids": pre_ids,
"step_idx": step_idx,
"logits": logits,
"bad_tokens": bad_tokens,
"min_dec_len": min_dec_len,
"eos_token_id": eos_token_id,
"penalty_score": penalty_score,
"frequency_score": frequency_score,
"presence_score": presence_score,
"temperature": temperature
}
def get_token_penalty_multi_scores_baseline(self):
input_data = copy.deepcopy(self.input_data)
logits = input_data["logits"]
penalty_score = input_data["penalty_score"]
frequency_score = input_data["frequency_score"]
presence_score = input_data["presence_score"]
temperature = input_data["temperature"]
# min token penalties
mask = input_data["step_idx"] < input_data["min_dec_len"]
for bi, flag in enumerate(mask):
if flag:
logits[bi, input_data["eos_token_id"]] = -1e10
# bad words exclusion
for token in input_data["bad_tokens"]:
if token < 0 or token > self.vocab_size:
continue
logits[:, token] = -1e10
# all penalties
prompt_ids = input_data["prompt_ids"]
for i in range(self.num_seqs):
prompt_ids[i, input_data["prompt_lens"][i]:] = -1
prompt_repeat_times = paddle.zeros([self.num_seqs, self.vocab_size + 1]).cast("int64")
prompt_repeat_times = paddle.put_along_axis(prompt_repeat_times, prompt_ids, paddle.ones_like(input_data["pre_ids"]), axis=1, reduce="add")
prompt_repeat_times = prompt_repeat_times[:, :self.vocab_size]
prompt_mask = prompt_repeat_times > 0
pre_ids = input_data["pre_ids"]
pre_ids[pre_ids == -1] = self.vocab_size
out_repeat_times = paddle.zeros([self.num_seqs, self.vocab_size + 1]).cast("int64")
out_repeat_times = paddle.put_along_axis(out_repeat_times, pre_ids, paddle.ones_like(input_data["pre_ids"]), axis=1, reduce="add")
out_repeat_times = out_repeat_times[:, :self.vocab_size]
output_mask = out_repeat_times > 0
penalty_score = penalty_score.tile(self.vocab_size)
logits[logits > 0] /= paddle.where(output_mask | prompt_mask, penalty_score, 1.0)[logits > 0]
logits[logits <= 0] *= paddle.where(output_mask | prompt_mask, penalty_score, 1.0)[logits <= 0]
logits -= frequency_score * out_repeat_times.cast("float32")
logits -= presence_score * output_mask.cast("float32")
# temperature
logits /= temperature
return logits
def test_penalty_op(self):
"""
"""
baseline_out = self.get_token_penalty_multi_scores_baseline()
from fastdeploy.model_executor.ops.gpu import \
get_token_penalty_multi_scores
logits = get_token_penalty_multi_scores(
self.input_data["pre_ids"],
self.input_data["prompt_ids"],
self.input_data["prompt_lens"],
self.input_data["logits"],
self.input_data["penalty_score"],
self.input_data["frequency_score"],
self.input_data["presence_score"],
self.input_data["temperature"],
self.input_data["bad_tokens"],
self.input_data["step_idx"],
self.input_data["min_dec_len"],
self.input_data["eos_token_id"])
np.testing.assert_allclose(baseline_out.numpy(), logits.numpy(), rtol=1e-04, atol=1e-04)
if __name__ == "__main__":
unittest.main()