Mirror of https://github.com/PaddlePaddle/FastDeploy.git
Synced 2025-10-01 14:52:33 +08:00

[stop sequence] support stop sequence (#3025)

* stop seqs in multi-ends
* unittest for gpu stop op
* kernel tid==0
@@ -266,13 +266,12 @@ void GetStopFlagsMulti(const paddle::Tensor &topk_ids,
                        const paddle::Tensor &seq_lens,
                        const paddle::Tensor &end_ids,
                        const paddle::Tensor &next_tokens,
+                       const paddle::Tensor &pre_ids,
+                       const paddle::Tensor &step_idx,
+                       const paddle::Tensor &stop_seqs,
+                       const paddle::Tensor &stop_seqs_len,
                        const bool beam_search);

-void GetStopFlagsMultiSeqs(
-    const paddle::Tensor &topk_ids, const paddle::Tensor &pre_ids,
-    const paddle::Tensor &step_idx, const paddle::Tensor &stop_flags,
-    const paddle::Tensor &seq_lens, const paddle::Tensor &stop_seqs,
-    const paddle::Tensor &stop_seqs_len, const paddle::Tensor &end_ids);
-
 void UpdateInputes(const paddle::Tensor &stop_flags,
                    const paddle::Tensor &not_need_stop,  // only on cpu
@@ -954,12 +953,6 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
   m.def("set_stop_value_multi_ends", &GetStopFlagsMulti,
         "update_inputs function");

-  /**
-   * stop_generation_multi_stop_seqs.cu
-   * set_stop_value_multi_seqs
-   */
-  m.def("set_stop_value_multi_seqs", &GetStopFlagsMultiSeqs,
-        "update_inputs function");
-
   /**
    * update_inputs.cu
@@ -30,30 +30,62 @@ __global__ void set_value_by_flags(bool *stop_flags,
                                    const int *seq_lens,
                                    const int bs,
                                    const int end_length,
+                                   const int64_t *pre_ids,
+                                   const int pre_ids_len,
+                                   const int64_t *step_idx,
+                                   const int64_t *stop_seqs,
+                                   const int *stop_seqs_len,
+                                   const int stop_seqs_bs,
+                                   const int stop_seqs_max_len,
                                    bool beam_search,
                                    bool prefill_one_step_stop) {
   int tid = threadIdx.x;
-  if (tid < bs) {
+  int bid = blockIdx.x;
+  if (tid >= stop_seqs_bs) return;
+  if (bid < bs) {
+    if(tid == 0){
       if (prefill_one_step_stop) {
-        stop_flags[tid] = true;
-        if (seq_lens[tid] == 0) {
-          topk_ids[tid] = -1;
+        stop_flags[bid] = true;
+        if (seq_lens[bid] == 0) {
+          topk_ids[bid] = -1;
         }
-        next_tokens[tid] = topk_ids[tid];
+        next_tokens[bid] = topk_ids[bid];
       } else {
-        if (stop_flags[tid]) {
-          if (seq_lens[tid] == 0) {
-            topk_ids[tid] = -1;
+        if (stop_flags[bid]) {
+          if (seq_lens[bid] == 0) {
+            topk_ids[bid] = -1;
           } else {
-            topk_ids[tid] = end_ids[0];
-            next_tokens[tid] = end_ids[0];
+            topk_ids[bid] = end_ids[0];
+            next_tokens[bid] = end_ids[0];
           }
         } else {
-          next_tokens[tid] = topk_ids[tid];
+          next_tokens[bid] = topk_ids[bid];
         }
       }
-      if (!beam_search && is_in_end(topk_ids[tid], end_ids, end_length)) {
-        stop_flags[tid] = true;
+      if (!beam_search && is_in_end(topk_ids[bid], end_ids, end_length)) {
+        stop_flags[bid] = true;
       }
     }
+    // dealing stop_seqs
+    const int stop_seq_len = (stop_seqs_len + bid * stop_seqs_bs)[tid];
+    if (stop_seq_len <= 0) return;
+    const int64_t *stop_seq_now = stop_seqs + bid * stop_seqs_bs + tid * stop_seqs_max_len;
+    const int64_t *pre_ids_now = pre_ids + bid * pre_ids_len;
+    const int64_t step_idx_now = step_idx[bid];
+
+    bool is_end = true;
+    int count = 1;
+    for (int i = stop_seq_len - 1; i >= 0; --i) {
+      if ((step_idx_now - count) < 0 ||
+          pre_ids_now[step_idx_now - count++] != stop_seq_now[i]) {
+        is_end = false;
+        break;
+      }
+    }
+    if (is_end) {
+      next_tokens[bid] = end_ids[0];
+      stop_flags[bid] = true;
+      topk_ids[bid] = end_ids[0];
+    }
+  }
 }
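Note on the kernel rewrite above: the batch dimension moves from threadIdx.x to blockIdx.x, so each block now serves one request (bid) while each thread (tid) owns one candidate stop sequence; only tid == 0 performs the original EOS bookkeeping. The stop check walks pre_ids backwards from step_idx and compares against the stop sequence, newest token first. A host-side NumPy sketch of that suffix match (function name and test values are illustrative, not part of the commit):

import numpy as np

def suffix_matches(pre_ids_row, step_idx, stop_seq):
    # Mirror of the kernel loop: i runs from the end of stop_seq, while count
    # walks back through the generated tokens starting at step_idx - 1.
    count = 1
    for i in range(len(stop_seq) - 1, -1, -1):
        if step_idx - count < 0 or pre_ids_row[step_idx - count] != stop_seq[i]:
            return False
        count += 1
    return True

pre_ids_row = np.array([21, 22, 23, 24, 25, 26, 27, 28, 8038, 61502])
assert suffix_matches(pre_ids_row, 10, np.array([8038, 61502]))      # tail matches
assert not suffix_matches(pre_ids_row, 10, np.array([8038, 9999]))   # tail differs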
@@ -63,6 +95,10 @@ void GetStopFlagsMulti(const paddle::Tensor &topk_ids,
                       const paddle::Tensor &seq_lens,
                       const paddle::Tensor &end_ids,
                       const paddle::Tensor &next_tokens,
+                      const paddle::Tensor &pre_ids,
+                      const paddle::Tensor &step_idx,
+                      const paddle::Tensor &stop_seqs,
+                      const paddle::Tensor &stop_seqs_len,
                       const bool beam_search) {
   PD_CHECK(topk_ids.dtype() == paddle::DataType::INT64);
   PD_CHECK(stop_flags.dtype() == paddle::DataType::BOOL);
@@ -83,8 +119,10 @@ void GetStopFlagsMulti(const paddle::Tensor &topk_ids,
   std::vector<int64_t> shape = topk_ids.shape();
   int64_t bs_now = shape[0];
   int64_t end_length = end_ids.shape()[0];
-  int block_size = (bs_now + WARP_SIZE - 1) / WARP_SIZE * WARP_SIZE;
-  set_value_by_flags<<<1, block_size, 0, cu_stream>>>(
+  int stop_seqs_bs = stop_seqs.shape()[1];
+  int stop_seqs_max_len = stop_seqs.shape()[2];
+  int block_size = (stop_seqs_bs + WARP_SIZE - 1) / WARP_SIZE * WARP_SIZE;
+  set_value_by_flags<<<bs_now, block_size, 0, cu_stream>>>(
       const_cast<bool *>(stop_flags.data<bool>()),
       const_cast<int64_t *>(topk_ids.data<int64_t>()),
       const_cast<int64_t *>(next_tokens.data<int64_t>()),
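The launch change above is the crux of the hunk: the grid is now bs_now blocks (one per request) instead of a single block, and the block size is the stop-sequence slot count rounded up to a whole number of warps. A quick sketch of that round-up, assuming WARP_SIZE is 32 as on current NVIDIA GPUs (the actual constant comes from the repo's helper header):

WARP_SIZE = 32  # assumption for illustration
for stop_seqs_bs in (1, 5, 32, 33):
    block_size = (stop_seqs_bs + WARP_SIZE - 1) // WARP_SIZE * WARP_SIZE
    print(stop_seqs_bs, "->", block_size)  # 1 -> 32, 5 -> 32, 32 -> 32, 33 -> 64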
@@ -92,12 +130,19 @@ void GetStopFlagsMulti(const paddle::Tensor &topk_ids,
       seq_lens.data<int>(),
       bs_now,
       end_length,
+      pre_ids.data<int64_t>(),
+      pre_ids.shape()[1],
+      step_idx.data<int64_t>(),
+      stop_seqs.data<int64_t>(),
+      stop_seqs_len.data<int>(),
+      stop_seqs_bs,
+      stop_seqs_max_len,
       beam_search,
       prefill_one_step_stop);
 }

 PD_BUILD_STATIC_OP(set_stop_value_multi_ends)
-    .Inputs({"topk_ids", "stop_flags", "seq_lens", "end_ids", "next_tokens"})
+    .Inputs({"topk_ids", "stop_flags", "seq_lens", "end_ids", "next_tokens", "pre_ids", "step_idx", "stop_seqs", "stop_seqs_len"})
     .Attrs({"beam_search: bool"})
     .Outputs({"topk_ids_out", "stop_flags_out", "next_tokens_out"})
     .SetInplaceMap({{"topk_ids", "topk_ids_out"},
@@ -1,133 +0,0 @@
-// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//     http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include <fcntl.h>
-#include <stdio.h>
-#include <stdlib.h>
-#include <string.h>
-#include <sys/mman.h>
-#include <sys/stat.h>
-#include <sys/types.h>
-#include <unistd.h>
-#include "paddle/extension.h"
-#include "helper.h"
-
-#ifndef PD_BUILD_STATIC_OP
-#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
-#endif
-
-__global__ void set_value_by_stop_seqs(bool *stop_flags,
-                                       int64_t *topk_ids,
-                                       const int64_t *pre_ids,
-                                       const int64_t *step_idx,
-                                       const int64_t *stop_seqs,
-                                       const int *stop_seqs_len,
-                                       const int *seq_lens,
-                                       const int64_t *end_ids,
-                                       const int bs,
-                                       const int stop_seqs_bs,
-                                       const int stop_seqs_max_len,
-                                       const int pre_ids_len) {
-  const int bid = blockIdx.x;
-  const int tid = threadIdx.x;
-  if (tid >= stop_seqs_bs) return;
-
-  const int stop_seq_len = stop_seqs_len[tid];
-  if (stop_seq_len <= 0) return;
-  const int64_t *stop_seq_now = stop_seqs + tid * stop_seqs_max_len;
-  const int64_t *pre_ids_now = pre_ids + bid * pre_ids_len;
-  const int64_t step_idx_now = step_idx[bid];
-  if (bid < bs) {
-    if (stop_flags[bid]) {  // length limit exceeded: write the end token at the current position
-      topk_ids[bid] = end_ids[0];
-      if (seq_lens[bid] == 0) {  // already stopped: write -1 at the current position
-        topk_ids[bid] = -1;
-      }
-      return;
-    }
-    bool is_end = true;
-    int count = 1;
-    if (topk_ids[bid] == end_ids[0]) {
-      if (tid == 0) {
-        stop_flags[bid] = true;
-      }
-      return;
-    }
-    for (int i = stop_seq_len - 1; i >= 0; --i) {
-      if ((step_idx_now - count) < 0 ||
-          pre_ids_now[step_idx_now - count++] != stop_seq_now[i]) {
-        is_end = false;
-        break;
-      }
-    }
-    if (is_end) {
-      topk_ids[bid] = end_ids[0];
-      stop_flags[bid] = true;
-    }
-  }
-}
-
-void GetStopFlagsMultiSeqs(const paddle::Tensor &topk_ids,
-                           const paddle::Tensor &pre_ids,
-                           const paddle::Tensor &step_idx,
-                           const paddle::Tensor &stop_flags,
-                           const paddle::Tensor &seq_lens,
-                           const paddle::Tensor &stop_seqs,
-                           const paddle::Tensor &stop_seqs_len,
-                           const paddle::Tensor &end_ids) {
-  PD_CHECK(topk_ids.dtype() == paddle::DataType::INT64);
-  PD_CHECK(stop_flags.dtype() == paddle::DataType::BOOL);
-
-#ifdef PADDLE_WITH_CUSTOM_DEVICE
-  auto dev_ctx = static_cast<const phi::CustomContext*>(paddle::experimental::DeviceContextPool::Instance().Get(topk_ids.place()));
-  auto cu_stream = dev_ctx->stream();
-#else
-  auto cu_stream = topk_ids.stream();
-#endif
-  std::vector<int64_t> shape = topk_ids.shape();
-  std::vector<int64_t> stop_seqs_shape = stop_seqs.shape();
-  int bs_now = shape[0];
-  int stop_seqs_bs = stop_seqs_shape[0];
-  int stop_seqs_max_len = stop_seqs_shape[1];
-  int pre_ids_len = pre_ids.shape()[1];
-
-  int block_size = (stop_seqs_bs + WARP_SIZE - 1) / WARP_SIZE * WARP_SIZE;
-  set_value_by_stop_seqs<<<bs_now, block_size, 0, cu_stream>>>(
-      const_cast<bool *>(stop_flags.data<bool>()),
-      const_cast<int64_t *>(topk_ids.data<int64_t>()),
-      pre_ids.data<int64_t>(),
-      step_idx.data<int64_t>(),
-      stop_seqs.data<int64_t>(),
-      stop_seqs_len.data<int>(),
-      seq_lens.data<int>(),
-      end_ids.data<int64_t>(),
-      bs_now,
-      stop_seqs_bs,
-      stop_seqs_max_len,
-      pre_ids_len);
-}
-
-PD_BUILD_STATIC_OP(set_stop_value_multi_seqs)
-    .Inputs({"topk_ids",
-             "pre_ids",
-             "step_idx",
-             "stop_flags",
-             "seq_lens",
-             "stop_seqs",
-             "stop_seqs_len",
-             "end_ids"})
-    .Outputs({"topk_ids_out", "stop_flags_out"})
-    .SetInplaceMap({{"topk_ids", "topk_ids_out"},
-                    {"stop_flags", "stop_flags_out"}})
-    .SetKernelFn(PD_KERNEL(GetStopFlagsMultiSeqs));
@@ -260,7 +260,6 @@ elif paddle.is_compiled_with_cuda():
             "gpu_ops/token_penalty_only_once.cu",
             "gpu_ops/stop_generation.cu",
             "gpu_ops/stop_generation_multi_ends.cu",
-            "gpu_ops/stop_generation_multi_stop_seqs.cu",
             "gpu_ops/set_flags.cu",
             "gpu_ops/update_inputs_v1.cu",
             "gpu_ops/recover_decode_task.cu",
@@ -529,7 +528,6 @@ elif paddle.is_compiled_with_custom_device("iluvatar_gpu"):
         sources=[
             "gpu_ops/get_padding_offset.cu",
             "gpu_ops/set_value_by_flags.cu",
-            "gpu_ops/stop_generation_multi_stop_seqs.cu",
             "gpu_ops/rebuild_padding.cu",
             "gpu_ops/update_inputs.cu",
             "gpu_ops/stop_generation_multi_ends.cu",
@@ -101,9 +101,6 @@ class ModelConfig:
         self,
         args,
     ):
-        self.max_stop_seqs_num = 5
-        self.stop_seqs_max_len = 8
-
         # NOTE(gongshaotain): form _load_model_init_val()
         self.top_p = 1.0
         self.temperature = 1.0
@@ -122,6 +119,9 @@ class ModelConfig:
         self.enable_redundant_experts = False
         self.redundant_experts_num = 0

+        self.max_stop_seqs_num = int(envs.FD_MAX_STOP_SEQS_NUM)
+        self.stop_seqs_max_len = int(envs.FD_STOP_SEQS_MAX_LEN)
+
         for key, value in args.items():
             if hasattr(self, key):
                 setattr(self, key, value)
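With this hunk the two limits are no longer hard-coded (5 and 8, the values removed from __init__ above) but read from FD_MAX_STOP_SEQS_NUM and FD_STOP_SEQS_MAX_LEN. A sketch of overriding them, assuming the envs module picks the variables up from the process environment when fastdeploy is imported:

import os

# Must be set before fastdeploy is imported.
os.environ["FD_MAX_STOP_SEQS_NUM"] = "10"  # stop sequences allowed per request
os.environ["FD_STOP_SEQS_MAX_LEN"] = "16"  # max tokens per tokenized stop sequence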
@@ -90,7 +90,8 @@ class SamplingParams:
     min_p: float = 0.0
     seed: Optional[int] = None
     stop: Optional[Union[str, List[str]]] = None
-    stop_token_ids: Optional[List[int]] = None
+    stop_token_ids: Optional[Union[List[List[int]], List[int]]] = None
+    stop_seqs_len: Optional[int] = None
     max_tokens: Optional[int] = None
     reasoning_max_tokens: Optional[int] = None
     min_tokens: int = 1
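A sketch of a request exercising the widened fields. Only the dataclass fields appear in this diff, so the constructor call and module path are assumptions; it illustrates the intended shapes: stop takes raw strings, and stop_token_ids may now carry one id-list per stop sequence:

from fastdeploy.engine.sampling_params import SamplingParams  # path assumed

params = SamplingParams(
    max_tokens=128,
    stop=["\nUser:", "###"],         # raw strings, tokenized server-side
    stop_token_ids=[[8038, 61502]],  # or pre-tokenized id sequences
)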
@@ -414,6 +414,8 @@ class ErnieProcessor(BaseDataProcessor):
         Update stop sequences from request.
         """
         stop_seqs = []
+        if isinstance(stop_sequences, str):
+            stop_sequences = [stop_sequences]
         for seq in stop_sequences:
             if seq != self.tokenizer.eos_token_id:
                 stop_seqs.append(self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(seq)))
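Standalone sketch of what the processor does with stop strings: each one is tokenized and converted to an id list. Deriving stop_seqs_len from the lists is an assumption based on the field name, and the tokenizer is a stand-in for any object with HuggingFace-style tokenize/convert_tokens_to_ids methods:

def update_stop_seqs_sketch(tokenizer, stop_sequences):
    # Accept a bare string or a list of strings, mirroring the diff above.
    if isinstance(stop_sequences, str):
        stop_sequences = [stop_sequences]
    stop_seqs = [
        tokenizer.convert_tokens_to_ids(tokenizer.tokenize(seq))
        for seq in stop_sequences
        if seq != tokenizer.eos_token_id
    ]
    return stop_seqs, [len(ids) for ids in stop_seqs]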
@@ -210,7 +210,21 @@ def post_process_normal(
         paddle.logical_or(model_output.stop_flags, length_cond),
         model_output.stop_flags,
     )
-    # TODO(gongshaotian): Add use_stop_seqs

+    if current_platform.is_cuda():
+        set_stop_value_multi_ends(
+            sampler_output.sampled_token_ids,
+            model_output.stop_flags,
+            model_output.seq_lens_this_time,
+            model_output.eos_token_id,
+            model_output.next_tokens,
+            model_output.pre_ids,
+            model_output.step_idx,
+            model_output.stop_token_ids,
+            model_output.stop_seqs_len,
+            False,
+        )  # multi ends
+    else:
         set_stop_value_multi_ends(
             sampler_output.sampled_token_ids,
             model_output.stop_flags,
@@ -218,7 +232,7 @@ def post_process_normal(
             model_output.eos_token_id,
             model_output.next_tokens,
             False,
-    )  # multi ends
+        )  # multi ends

     # 2. Update the input buffer of the model
     with paddle.framework._no_check_dy2st_diff():
@@ -275,11 +275,16 @@ class GPUModelRunner(ModelRunnerBase):
             if request.get("stop_token_ids") is not None and request.get("stop_seqs_len") is not None:
                 stop_seqs_num = len(request.get("stop_seqs_len"))
                 for i in range(stop_seqs_num, self.model_config.max_stop_seqs_num):
-                    request.stop_seqs_len.append(0)
-                self.share_inputs["stop_seqs_len"][:] = np.array(request.stop_seqs_len, dtype="int32")
-                self.share_inputs["stop_seqs"][:stop_seqs_num, : len(request.get("stop_token_ids")[0])] = np.array(
-                    request.get("stop_token_ids"), dtype="int64"
+                    request.sampling_params.stop_seqs_len.append(0)
+                self.share_inputs["stop_seqs_len"][idx : idx + 1, :] = np.array(
+                    request.sampling_params.stop_seqs_len, dtype="int32"
                 )
+                self.share_inputs["stop_seqs"][
+                    idx : idx + 1, :stop_seqs_num, : len(request.get("stop_token_ids")[0])
+                ] = np.array(request.get("stop_token_ids"), dtype="int64")
+            else:
+                self.share_inputs["stop_seqs_len"][idx : idx + 1, :] = 0

         if has_prefill_task:
             self.share_inputs["not_need_stop"][0] = True
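NumPy sketch of the per-request buffer fill above: stop_seqs_len is padded with zeros out to max_stop_seqs_num slots, then the id lists are scattered into row idx of the batched buffers. The shapes follow the diff; the concrete values are made up:

import numpy as np

max_num_seqs, max_stop_seqs_num, stop_seqs_max_len, idx = 4, 5, 8, 0
stop_seqs = np.full((max_num_seqs, max_stop_seqs_num, stop_seqs_max_len), -1, dtype="int64")
stop_seqs_len = np.zeros((max_num_seqs, max_stop_seqs_num), dtype="int32")

stop_token_ids = [[8038, 61502]]            # one stop sequence of two tokens
lens = [2] + [0] * (max_stop_seqs_num - 1)  # padded out to 5 slots

stop_seqs_len[idx, :] = np.array(lens, dtype="int32")
stop_seqs[idx, : len(stop_token_ids), : len(stop_token_ids[0])] = np.array(
    stop_token_ids, dtype="int64"
)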
@@ -446,11 +451,15 @@ class GPUModelRunner(ModelRunnerBase):
             if request.get("stop_token_ids") is not None and request.get("stop_seqs_len") is not None:
                 stop_seqs_num = len(request.get("stop_seqs_len"))
                 for i in range(stop_seqs_num, self.model_config.max_stop_seqs_num):
-                    request.stop_seqs_len.append(0)
-                self.share_inputs["stop_seqs_len"][:] = np.array(request.stop_seqs_len, dtype="int32")
-                self.share_inputs["stop_seqs"][:stop_seqs_num, : len(request.get("stop_token_ids")[0])] = np.array(
-                    request.get("stop_token_ids"), dtype="int64"
+                    request.sampling_params.stop_seqs_len.append(0)
+                self.share_inputs["stop_seqs_len"][idx : idx + 1, :] = np.array(
+                    request.sampling_params.stop_seqs_len, dtype="int32"
                 )
+                self.share_inputs["stop_seqs"][
+                    idx : idx + 1, :stop_seqs_num, : len(request.get("stop_token_ids")[0])
+                ] = np.array(request.get("stop_token_ids"), dtype="int64")
+            else:
+                self.share_inputs["stop_seqs_len"][idx : idx + 1, :] = 0

             self.sampler.apply_logits_processor(idx, request.get("logits_processor"), prefill_tokens)
@@ -619,14 +628,17 @@ class GPUModelRunner(ModelRunnerBase):
         self.share_inputs["free_list_len"] = paddle.full([1], self.free_list_len, dtype="int32")

         # Initialize stop seqs
-        self.share_inputs["stop_seqs_len"] = paddle.full([self.model_config.max_stop_seqs_num], 0, dtype="int32")
+        self.share_inputs["stop_seqs_len"] = paddle.full(
+            [max_num_seqs, self.model_config.max_stop_seqs_num], 0, dtype="int32"
+        )
         self.share_inputs["stop_seqs"] = paddle.full(
             [
+                max_num_seqs,
                 self.model_config.max_stop_seqs_num,
                 self.model_config.stop_seqs_max_len,
             ],
             -1,
-            dtype="int32",
+            dtype="int64",
         )
         if self.speculative_decoding:
             max_draft_token_num = self.speculative_config.num_speculative_tokens
@@ -1012,6 +1024,8 @@ class GPUModelRunner(ModelRunnerBase):
             think_end_id=(self.model_config.think_end_id if self.enable_mm else -1),
             need_think_end=(self.share_inputs["need_think_end"] if self.enable_mm else None),
             reasoning_index=(self.share_inputs["reasoning_index"] if self.enable_mm else None),
+            stop_token_ids=self.share_inputs["stop_seqs"],
+            stop_seqs_len=self.share_inputs["stop_seqs_len"],
         )

         post_process(
@@ -1276,6 +1290,8 @@ class GPUModelRunner(ModelRunnerBase):
             think_end_id=(self.model_config.think_end_id if self.enable_mm else -1),
             need_think_end=(self.share_inputs["need_think_end"] if self.enable_mm else None),
             reasoning_index=(self.share_inputs["reasoning_index"] if self.enable_mm else None),
+            stop_token_ids=self.share_inputs["stop_seqs"],
+            stop_seqs_len=self.share_inputs["stop_seqs_len"],
         )

         if self.speculative_config.method in ["mtp"] and self.parallel_config.splitwise_role == "prefill":
@@ -215,6 +215,16 @@ class ModelOutputData:
     """
     reasoning_index: paddle.Tensor = None

+    """
+    the token ids of stop sequence
+    """
+    stop_token_ids: paddle.Tensor = None
+
+    """
+    the length of stop sequence
+    """
+    stop_seqs_len: paddle.Tensor = None
+

 @dataclass
 class ModelRunnerOutput:
test/operators/test_stop_generation_multi_ends.py (new file, 58 lines)
@@ -0,0 +1,58 @@
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""UT for GPU operator stop_generation_multi_ends"""
+
+import numpy as np
+import paddle
+
+from fastdeploy.model_executor.ops.gpu import set_stop_value_multi_ends
+
+
+def test_set_stop_value_multi_ends_with_stop_seq():
+    sampled_token_ids = paddle.to_tensor([[61502], [2]], dtype="int64")
+    stop_flags = paddle.to_tensor([[False], [True]], dtype="bool")
+    seq_lens_this_time = paddle.to_tensor([[1], [0]], dtype="int32")
+    eos_token_id = paddle.to_tensor([2], dtype="int64")
+    next_tokens = paddle.to_tensor([[61502], [2]], dtype="int64")
+
+    pre_ids = paddle.full([2, 32768], -1, dtype="int64")
+    pre_ids[0, :10] = np.array([21, 22, 23, 24, 25, 26, 27, 28, 8038, 61502])
+    step_idx = paddle.to_tensor([[10], [0]], dtype="int64")
+
+    stop_token_ids = paddle.full([2, 5, 8], -1, dtype="int64")
+    stop_token_ids[0, 0, :2] = np.array([8038, 61502])
+
+    stop_seqs_len = paddle.full([2, 5], 10, dtype="int32")
+    stop_seqs_len[0, 0] = 2
+
+    set_stop_value_multi_ends(
+        sampled_token_ids,
+        stop_flags,
+        seq_lens_this_time,
+        eos_token_id,
+        next_tokens,
+        pre_ids,
+        step_idx,
+        stop_token_ids,
+        stop_seqs_len,
+        False,
+    )
+
+    assert stop_flags[0, 0] is True
+    assert sampled_token_ids[0, 0] == 2  # eos token id
+
+
+if __name__ == "__main__":
+    test_set_stop_value_multi_ends_with_stop_seq()
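For environments without the compiled op, the scenario in the test can be re-checked on the host. This reference is not part of the commit; note it extracts plain Python values rather than comparing tensors with `is`, since indexing a paddle tensor yields a 0-d Tensor, not a Python bool:

import numpy as np

pre_ids = np.full((2, 32768), -1, dtype="int64")
pre_ids[0, :10] = [21, 22, 23, 24, 25, 26, 27, 28, 8038, 61502]
step_idx, stop_seq = 10, [8038, 61502]

# The tail of request 0's history must equal the stop sequence, newest last.
matched = all(
    pre_ids[0, step_idx - k] == stop_seq[-k] for k in range(1, len(stop_seq) + 1)
)
assert matched  # request 0 hit its stop sequence, so the kernel should flag it stopped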