diff --git a/custom_ops/gpu_ops/cpp_extensions.cc b/custom_ops/gpu_ops/cpp_extensions.cc
index d344fe9ee..266d50599 100644
--- a/custom_ops/gpu_ops/cpp_extensions.cc
+++ b/custom_ops/gpu_ops/cpp_extensions.cc
@@ -266,13 +266,12 @@ void GetStopFlagsMulti(const paddle::Tensor &topk_ids,
                        const paddle::Tensor &seq_lens,
                        const paddle::Tensor &end_ids,
                        const paddle::Tensor &next_tokens,
+                       const paddle::Tensor &pre_ids,
+                       const paddle::Tensor &step_idx,
+                       const paddle::Tensor &stop_seqs,
+                       const paddle::Tensor &stop_seqs_len,
                        const bool beam_search);
 
-void GetStopFlagsMultiSeqs(
-    const paddle::Tensor &topk_ids, const paddle::Tensor &pre_ids,
-    const paddle::Tensor &step_idx, const paddle::Tensor &stop_flags,
-    const paddle::Tensor &seq_lens, const paddle::Tensor &stop_seqs,
-    const paddle::Tensor &stop_seqs_len, const paddle::Tensor &end_ids);
 
 void UpdateInputes(const paddle::Tensor &stop_flags,
                    const paddle::Tensor &not_need_stop, // only on cpu
@@ -954,12 +953,6 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
   m.def("set_stop_value_multi_ends", &GetStopFlagsMulti,
         "update_inputs function");
 
-  /**
-   * stop_generation_multi_stop_seqs.cu
-   * set_stop_value_multi_seqs
-   */
-  m.def("set_stop_value_multi_seqs", &GetStopFlagsMultiSeqs,
-        "update_inputs function");
 
   /**
    * update_inputs.cu
diff --git a/custom_ops/gpu_ops/stop_generation_multi_ends.cu b/custom_ops/gpu_ops/stop_generation_multi_ends.cu
index fcabc009b..fe82be207 100644
--- a/custom_ops/gpu_ops/stop_generation_multi_ends.cu
+++ b/custom_ops/gpu_ops/stop_generation_multi_ends.cu
@@ -30,30 +30,62 @@ __global__ void set_value_by_flags(bool *stop_flags,
                                    const int *seq_lens,
                                    const int bs,
                                    const int end_length,
+                                   const int64_t *pre_ids,
+                                   const int pre_ids_len,
+                                   const int64_t *step_idx,
+                                   const int64_t *stop_seqs,
+                                   const int *stop_seqs_len,
+                                   const int stop_seqs_bs,
+                                   const int stop_seqs_max_len,
                                    bool beam_search,
                                    bool prefill_one_step_stop) {
   int tid = threadIdx.x;
-  if (tid < bs) {
-    if (prefill_one_step_stop) {
-      stop_flags[tid] = true;
-      if (seq_lens[tid] == 0) {
-        topk_ids[tid] = -1;
-      }
-      next_tokens[tid] = topk_ids[tid];
-    } else {
-      if (stop_flags[tid]) {
-        if (seq_lens[tid] == 0) {
-          topk_ids[tid] = -1;
-        } else {
-          topk_ids[tid] = end_ids[0];
-          next_tokens[tid] = end_ids[0];
+  int bid = blockIdx.x;
+  if (tid >= stop_seqs_bs) return;
+  if (bid < bs) {
+    if (tid == 0) {
+      if (prefill_one_step_stop) {
+        stop_flags[bid] = true;
+        if (seq_lens[bid] == 0) {
+          topk_ids[bid] = -1;
         }
+        next_tokens[bid] = topk_ids[bid];
       } else {
-        next_tokens[tid] = topk_ids[tid];
+        if (stop_flags[bid]) {
+          if (seq_lens[bid] == 0) {
+            topk_ids[bid] = -1;
+          } else {
+            topk_ids[bid] = end_ids[0];
+            next_tokens[bid] = end_ids[0];
+          }
+        } else {
+          next_tokens[bid] = topk_ids[bid];
+        }
+      }
+      if (!beam_search && is_in_end(topk_ids[bid], end_ids, end_length)) {
+        stop_flags[bid] = true;
       }
     }
-    if (!beam_search && is_in_end(topk_ids[tid], end_ids, end_length)) {
-      stop_flags[tid] = true;
+    // handle the per-request stop sequences
+    const int stop_seq_len = (stop_seqs_len + bid * stop_seqs_bs)[tid];
+    if (stop_seq_len <= 0) return;
+    const int64_t *stop_seq_now = stop_seqs + bid * stop_seqs_bs + tid * stop_seqs_max_len;
+    const int64_t *pre_ids_now = pre_ids + bid * pre_ids_len;
+    const int64_t step_idx_now = step_idx[bid];
+
+    bool is_end = true;
+    int count = 1;
+    for (int i = stop_seq_len - 1; i >= 0; --i) {
+      if ((step_idx_now - count) < 0 ||
+          pre_ids_now[step_idx_now - count++] != stop_seq_now[i]) {
+        is_end = false;
+        break;
+      }
+    }
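+    // every trailing token in pre_ids matched this stop sequence, so force an EOS and stop the request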
+    if (is_end) {
+      next_tokens[bid] = end_ids[0];
+      stop_flags[bid] = true;
+      topk_ids[bid] = end_ids[0];
     }
   }
 }
@@ -63,6 +95,10 @@ void GetStopFlagsMulti(const paddle::Tensor &topk_ids,
                        const paddle::Tensor &seq_lens,
                        const paddle::Tensor &end_ids,
                        const paddle::Tensor &next_tokens,
+                       const paddle::Tensor &pre_ids,
+                       const paddle::Tensor &step_idx,
+                       const paddle::Tensor &stop_seqs,
+                       const paddle::Tensor &stop_seqs_len,
                        const bool beam_search) {
   PD_CHECK(topk_ids.dtype() == paddle::DataType::INT64);
   PD_CHECK(stop_flags.dtype() == paddle::DataType::BOOL);
@@ -83,8 +119,10 @@ void GetStopFlagsMulti(const paddle::Tensor &topk_ids,
   std::vector<int64_t> shape = topk_ids.shape();
   int64_t bs_now = shape[0];
   int64_t end_length = end_ids.shape()[0];
-  int block_size = (bs_now + WARP_SIZE - 1) / WARP_SIZE * WARP_SIZE;
-  set_value_by_flags<<<1, block_size, 0, cu_stream>>>(
+  int stop_seqs_bs = stop_seqs.shape()[1];
+  int stop_seqs_max_len = stop_seqs.shape()[2];
+  int block_size = (stop_seqs_bs + WARP_SIZE - 1) / WARP_SIZE * WARP_SIZE;
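+  // one thread block per request; each thread in the block scans one candidate stop sequence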
+  set_value_by_flags<<<bs_now, block_size, 0, cu_stream>>>(
       const_cast<bool *>(stop_flags.data<bool>()),
       const_cast<int64_t *>(topk_ids.data<int64_t>()),
       const_cast<int64_t *>(next_tokens.data<int64_t>()),
@@ -92,12 +130,19 @@ void GetStopFlagsMulti(const paddle::Tensor &topk_ids,
       seq_lens.data<int>(),
       bs_now,
       end_length,
+      pre_ids.data<int64_t>(),
+      pre_ids.shape()[1],
+      step_idx.data<int64_t>(),
+      stop_seqs.data<int64_t>(),
+      stop_seqs_len.data<int>(),
+      stop_seqs_bs,
+      stop_seqs_max_len,
       beam_search,
       prefill_one_step_stop);
 }
 
 PD_BUILD_STATIC_OP(set_stop_value_multi_ends)
-    .Inputs({"topk_ids", "stop_flags", "seq_lens", "end_ids", "next_tokens"})
+    .Inputs({"topk_ids", "stop_flags", "seq_lens", "end_ids", "next_tokens", "pre_ids", "step_idx", "stop_seqs", "stop_seqs_len"})
     .Attrs({"beam_search: bool"})
     .Outputs({"topk_ids_out", "stop_flags_out", "next_tokens_out"})
     .SetInplaceMap({{"topk_ids", "topk_ids_out"},
diff --git a/custom_ops/gpu_ops/stop_generation_multi_stop_seqs.cu b/custom_ops/gpu_ops/stop_generation_multi_stop_seqs.cu
deleted file mode 100644
index c2a14c2cc..000000000
--- a/custom_ops/gpu_ops/stop_generation_multi_stop_seqs.cu
+++ /dev/null
@@ -1,133 +0,0 @@
-// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//     http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include
-#include
-#include
-#include
-#include
-#include
-#include
-#include
-#include "paddle/extension.h"
-#include "helper.h"
-
-#ifndef PD_BUILD_STATIC_OP
-#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
-#endif
-
-__global__ void set_value_by_stop_seqs(bool *stop_flags,
-                                       int64_t *topk_ids,
-                                       const int64_t *pre_ids,
-                                       const int64_t *step_idx,
-                                       const int64_t *stop_seqs,
-                                       const int *stop_seqs_len,
-                                       const int *seq_lens,
-                                       const int64_t *end_ids,
-                                       const int bs,
-                                       const int stop_seqs_bs,
-                                       const int stop_seqs_max_len,
-                                       const int pre_ids_len) {
-  const int bid = blockIdx.x;
-  const int tid = threadIdx.x;
-  if (tid >= stop_seqs_bs) return;
-
-  const int stop_seq_len = stop_seqs_len[tid];
-  if (stop_seq_len <= 0) return;
-  const int64_t *stop_seq_now = stop_seqs + tid * stop_seqs_max_len;
-  const int64_t *pre_ids_now = pre_ids + bid * pre_ids_len;
-  const int64_t step_idx_now = step_idx[bid];
-  if (bid < bs) {
-    if (stop_flags[bid]) {  // length limit exceeded, set the current position to 2
-      topk_ids[bid] = end_ids[0];
-      if (seq_lens[bid] == 0) {  // already stopped, set the current position to -1
-        topk_ids[bid] = -1;
-      }
-      return;
-    }
-    bool is_end = true;
-    int count = 1;
-    if (topk_ids[bid] == end_ids[0]) {
-      if (tid == 0) {
-        stop_flags[bid] = true;
-      }
-      return;
-    }
-    for (int i = stop_seq_len - 1; i >= 0; --i) {
-      if ((step_idx_now - count) < 0 ||
-          pre_ids_now[step_idx_now - count++] != stop_seq_now[i]) {
-        is_end = false;
-        break;
-      }
-    }
-    if (is_end) {
-      topk_ids[bid] = end_ids[0];
-      stop_flags[bid] = true;
-    }
-  }
-}
-
-void GetStopFlagsMultiSeqs(const paddle::Tensor &topk_ids,
-                           const paddle::Tensor &pre_ids,
-                           const paddle::Tensor &step_idx,
-                           const paddle::Tensor &stop_flags,
-                           const paddle::Tensor &seq_lens,
-                           const paddle::Tensor &stop_seqs,
-                           const paddle::Tensor &stop_seqs_len,
-                           const paddle::Tensor &end_ids) {
-  PD_CHECK(topk_ids.dtype() == paddle::DataType::INT64);
-  PD_CHECK(stop_flags.dtype() == paddle::DataType::BOOL);
-
-#ifdef PADDLE_WITH_CUSTOM_DEVICE
-  auto dev_ctx = static_cast(paddle::experimental::DeviceContextPool::Instance().Get(topk_ids.place()));
-  auto cu_stream = dev_ctx->stream();
-#else
-  auto cu_stream = topk_ids.stream();
-#endif
-  std::vector<int64_t> shape = topk_ids.shape();
-  std::vector<int64_t> stop_seqs_shape = stop_seqs.shape();
-  int bs_now = shape[0];
-  int stop_seqs_bs = stop_seqs_shape[0];
-  int stop_seqs_max_len = stop_seqs_shape[1];
-  int pre_ids_len = pre_ids.shape()[1];
-
-  int block_size = (stop_seqs_bs + WARP_SIZE - 1) / WARP_SIZE * WARP_SIZE;
-  set_value_by_stop_seqs<<<bs_now, block_size, 0, cu_stream>>>(
-      const_cast<bool *>(stop_flags.data<bool>()),
-      const_cast<int64_t *>(topk_ids.data<int64_t>()),
-      pre_ids.data<int64_t>(),
-      step_idx.data<int64_t>(),
-      stop_seqs.data<int64_t>(),
-      stop_seqs_len.data<int>(),
-      seq_lens.data<int>(),
-      end_ids.data<int64_t>(),
-      bs_now,
-      stop_seqs_bs,
-      stop_seqs_max_len,
-      pre_ids_len);
-}
-
-PD_BUILD_STATIC_OP(set_stop_value_multi_seqs)
-    .Inputs({"topk_ids",
-             "pre_ids",
-             "step_idx",
-             "stop_flags",
-             "seq_lens",
-             "stop_seqs",
-             "stop_seqs_len",
-             "end_ids"})
-    .Outputs({"topk_ids_out", "stop_flags_out"})
-    .SetInplaceMap({{"topk_ids", "topk_ids_out"},
-                    {"stop_flags", "stop_flags_out"}})
-    .SetKernelFn(PD_KERNEL(GetStopFlagsMultiSeqs));
diff --git a/custom_ops/setup_ops.py b/custom_ops/setup_ops.py
index 8307330e9..128403ad3 100644
--- a/custom_ops/setup_ops.py
+++ b/custom_ops/setup_ops.py
@@ -260,7 +260,6 @@ elif paddle.is_compiled_with_cuda():
                 "gpu_ops/token_penalty_only_once.cu",
                 "gpu_ops/stop_generation.cu",
                 "gpu_ops/stop_generation_multi_ends.cu",
-                "gpu_ops/stop_generation_multi_stop_seqs.cu",
                 "gpu_ops/set_flags.cu",
                 "gpu_ops/update_inputs_v1.cu",
"gpu_ops/recover_decode_task.cu", @@ -529,7 +528,6 @@ elif paddle.is_compiled_with_custom_device("iluvatar_gpu"): sources=[ "gpu_ops/get_padding_offset.cu", "gpu_ops/set_value_by_flags.cu", - "gpu_ops/stop_generation_multi_stop_seqs.cu", "gpu_ops/rebuild_padding.cu", "gpu_ops/update_inputs.cu", "gpu_ops/stop_generation_multi_ends.cu", diff --git a/fastdeploy/config.py b/fastdeploy/config.py index 09c419e42..10671aed5 100644 --- a/fastdeploy/config.py +++ b/fastdeploy/config.py @@ -101,9 +101,6 @@ class ModelConfig: self, args, ): - self.max_stop_seqs_num = 5 - self.stop_seqs_max_len = 8 - # NOTE(gongshaotain): form _load_model_init_val() self.top_p = 1.0 self.temperature = 1.0 @@ -122,6 +119,9 @@ class ModelConfig: self.enable_redundant_experts = False self.redundant_experts_num = 0 + self.max_stop_seqs_num = int(envs.FD_MAX_STOP_SEQS_NUM) + self.stop_seqs_max_len = int(envs.FD_STOP_SEQS_MAX_LEN) + for key, value in args.items(): if hasattr(self, key): setattr(self, key, value) diff --git a/fastdeploy/engine/sampling_params.py b/fastdeploy/engine/sampling_params.py index 688351883..91babf7a8 100644 --- a/fastdeploy/engine/sampling_params.py +++ b/fastdeploy/engine/sampling_params.py @@ -90,7 +90,8 @@ class SamplingParams: min_p: float = 0.0 seed: Optional[int] = None stop: Optional[Union[str, List[str]]] = None - stop_token_ids: Optional[Union[List[List[int]], List[int]]] = None + stop_token_ids: Optional[List[int]] = None + stop_seqs_len: Optional[int] = None max_tokens: Optional[int] = None reasoning_max_tokens: Optional[int] = None min_tokens: int = 1 diff --git a/fastdeploy/input/ernie_processor.py b/fastdeploy/input/ernie_processor.py index 06d198db1..63feda934 100644 --- a/fastdeploy/input/ernie_processor.py +++ b/fastdeploy/input/ernie_processor.py @@ -414,6 +414,8 @@ class ErnieProcessor(BaseDataProcessor): Update stop sequences from request. """ stop_seqs = [] + if isinstance(stop_sequences, str): + stop_sequences = [stop_sequences] for seq in stop_sequences: if seq != self.tokenizer.eos_token_id: stop_seqs.append(self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(seq))) diff --git a/fastdeploy/model_executor/pre_and_post_process.py b/fastdeploy/model_executor/pre_and_post_process.py index 68d168899..5a14d77b4 100644 --- a/fastdeploy/model_executor/pre_and_post_process.py +++ b/fastdeploy/model_executor/pre_and_post_process.py @@ -210,15 +210,29 @@ def post_process_normal( paddle.logical_or(model_output.stop_flags, length_cond), model_output.stop_flags, ) - # TODO(gongshaotian): Add use_stop_seqs - set_stop_value_multi_ends( - sampler_output.sampled_token_ids, - model_output.stop_flags, - model_output.seq_lens_this_time, - model_output.eos_token_id, - model_output.next_tokens, - False, - ) # multi ends + + if current_platform.is_cuda(): + set_stop_value_multi_ends( + sampler_output.sampled_token_ids, + model_output.stop_flags, + model_output.seq_lens_this_time, + model_output.eos_token_id, + model_output.next_tokens, + model_output.pre_ids, + model_output.step_idx, + model_output.stop_token_ids, + model_output.stop_seqs_len, + False, + ) # multi ends + else: + set_stop_value_multi_ends( + sampler_output.sampled_token_ids, + model_output.stop_flags, + model_output.seq_lens_this_time, + model_output.eos_token_id, + model_output.next_tokens, + False, + ) # 2. 
+    if current_platform.is_cuda():
+        set_stop_value_multi_ends(
+            sampler_output.sampled_token_ids,
+            model_output.stop_flags,
+            model_output.seq_lens_this_time,
+            model_output.eos_token_id,
+            model_output.next_tokens,
+            model_output.pre_ids,
+            model_output.step_idx,
+            model_output.stop_token_ids,
+            model_output.stop_seqs_len,
+            False,
+        )  # multi ends
+    else:
+        set_stop_value_multi_ends(
+            sampler_output.sampled_token_ids,
+            model_output.stop_flags,
+            model_output.seq_lens_this_time,
+            model_output.eos_token_id,
+            model_output.next_tokens,
+            False,
+        )
 
     # 2. Update the input buffer of the model
     with paddle.framework._no_check_dy2st_diff():
diff --git a/fastdeploy/worker/gpu_model_runner.py b/fastdeploy/worker/gpu_model_runner.py
index 3f9014dc2..bf6d19f3d 100644
--- a/fastdeploy/worker/gpu_model_runner.py
+++ b/fastdeploy/worker/gpu_model_runner.py
@@ -275,11 +275,16 @@ class GPUModelRunner(ModelRunnerBase):
             if request.get("stop_token_ids") is not None and request.get("stop_seqs_len") is not None:
                 stop_seqs_num = len(request.get("stop_seqs_len"))
                 for i in range(stop_seqs_num, self.model_config.max_stop_seqs_num):
-                    request.stop_seqs_len.append(0)
-                self.share_inputs["stop_seqs_len"][:] = np.array(request.stop_seqs_len, dtype="int32")
-                self.share_inputs["stop_seqs"][:stop_seqs_num, : len(request.get("stop_token_ids")[0])] = np.array(
-                    request.get("stop_token_ids"), dtype="int64"
+                    request.sampling_params.stop_seqs_len.append(0)
+                self.share_inputs["stop_seqs_len"][idx : idx + 1, :] = np.array(
+                    request.sampling_params.stop_seqs_len, dtype="int32"
                 )
+                self.share_inputs["stop_seqs"][
+                    idx : idx + 1, :stop_seqs_num, : len(request.get("stop_token_ids")[0])
+                ] = np.array(request.get("stop_token_ids"), dtype="int64")
+            else:
+                self.share_inputs["stop_seqs_len"][idx : idx + 1, :] = 0
+
         if has_prefill_task:
             self.share_inputs["not_need_stop"][0] = True
@@ -446,11 +451,15 @@ class GPUModelRunner(ModelRunnerBase):
             if request.get("stop_token_ids") is not None and request.get("stop_seqs_len") is not None:
                 stop_seqs_num = len(request.get("stop_seqs_len"))
                 for i in range(stop_seqs_num, self.model_config.max_stop_seqs_num):
-                    request.stop_seqs_len.append(0)
-                self.share_inputs["stop_seqs_len"][:] = np.array(request.stop_seqs_len, dtype="int32")
-                self.share_inputs["stop_seqs"][:stop_seqs_num, : len(request.get("stop_token_ids")[0])] = np.array(
-                    request.get("stop_token_ids"), dtype="int64"
+                    request.sampling_params.stop_seqs_len.append(0)
+                self.share_inputs["stop_seqs_len"][idx : idx + 1, :] = np.array(
+                    request.sampling_params.stop_seqs_len, dtype="int32"
                 )
+                self.share_inputs["stop_seqs"][
+                    idx : idx + 1, :stop_seqs_num, : len(request.get("stop_token_ids")[0])
+                ] = np.array(request.get("stop_token_ids"), dtype="int64")
+            else:
+                self.share_inputs["stop_seqs_len"][idx : idx + 1, :] = 0
 
             self.sampler.apply_logits_processor(idx, request.get("logits_processor"), prefill_tokens)
@@ -619,14 +628,17 @@ class GPUModelRunner(ModelRunnerBase):
         self.share_inputs["free_list_len"] = paddle.full([1], self.free_list_len, dtype="int32")
 
         # Initialize stop seqs
-        self.share_inputs["stop_seqs_len"] = paddle.full([self.model_config.max_stop_seqs_num], 0, dtype="int32")
+        self.share_inputs["stop_seqs_len"] = paddle.full(
+            [max_num_seqs, self.model_config.max_stop_seqs_num], 0, dtype="int32"
+        )
         self.share_inputs["stop_seqs"] = paddle.full(
             [
+                max_num_seqs,
                 self.model_config.max_stop_seqs_num,
                 self.model_config.stop_seqs_max_len,
             ],
             -1,
-            dtype="int32",
+            dtype="int64",
         )
         if self.speculative_decoding:
             max_draft_token_num = self.speculative_config.num_speculative_tokens
@@ -1012,6 +1024,8 @@ class GPUModelRunner(ModelRunnerBase):
             think_end_id=(self.model_config.think_end_id if self.enable_mm else -1),
             need_think_end=(self.share_inputs["need_think_end"] if self.enable_mm else None),
             reasoning_index=(self.share_inputs["reasoning_index"] if self.enable_mm else None),
+            stop_token_ids=self.share_inputs["stop_seqs"],
+            stop_seqs_len=self.share_inputs["stop_seqs_len"],
         )
 
         post_process(
@@ -1276,6 +1290,8 @@ class GPUModelRunner(ModelRunnerBase):
             think_end_id=(self.model_config.think_end_id if self.enable_mm else -1),
             need_think_end=(self.share_inputs["need_think_end"] if self.enable_mm else None),
             reasoning_index=(self.share_inputs["reasoning_index"] if self.enable_mm else None),
+            stop_token_ids=self.share_inputs["stop_seqs"],
+            stop_seqs_len=self.share_inputs["stop_seqs_len"],
         )
 
         if self.speculative_config.method in ["mtp"] and self.parallel_config.splitwise_role == "prefill":
diff --git a/fastdeploy/worker/output.py b/fastdeploy/worker/output.py
index 73ecbcbde..67e9ed620 100644
--- a/fastdeploy/worker/output.py
+++ b/fastdeploy/worker/output.py
@@ -215,6 +215,16 @@ class ModelOutputData:
     """
     reasoning_index: paddle.Tensor = None
 
+    """
+    the token ids of the stop sequences
+    """
+    stop_token_ids: paddle.Tensor = None
+
+    """
+    the lengths of the stop sequences
+    """
+    stop_seqs_len: paddle.Tensor = None
+
 
 @dataclass
 class ModelRunnerOutput:
diff --git a/test/operators/test_stop_generation_multi_ends.py b/test/operators/test_stop_generation_multi_ends.py
new file mode 100644
index 000000000..7ba359b7b
--- /dev/null
+++ b/test/operators/test_stop_generation_multi_ends.py
@@ -0,0 +1,58 @@
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""UT for GPU operator stop_generation_multi_ends"""
+
+import numpy as np
+import paddle
+
+from fastdeploy.model_executor.ops.gpu import set_stop_value_multi_ends
+
+
+def test_set_stop_value_multi_ends_with_stop_seq():
+    sampled_token_ids = paddle.to_tensor([[61502], [2]], dtype="int64")
+    stop_flags = paddle.to_tensor([[False], [True]], dtype="bool")
+    seq_lens_this_time = paddle.to_tensor([[1], [0]], dtype="int32")
+    eos_token_id = paddle.to_tensor([2], dtype="int64")
+    next_tokens = paddle.to_tensor([[61502], [2]], dtype="int64")
+
+    pre_ids = paddle.full([2, 32768], -1, dtype="int64")
+    pre_ids[0, :10] = np.array([21, 22, 23, 24, 25, 26, 27, 28, 8038, 61502])
+    step_idx = paddle.to_tensor([[10], [0]], dtype="int64")
+
+    stop_token_ids = paddle.full([2, 5, 8], -1, dtype="int64")
+    stop_token_ids[0, 0, :2] = np.array([8038, 61502])
+
+    stop_seqs_len = paddle.full([2, 5], 10, dtype="int32")
+    stop_seqs_len[0, 0] = 2
+
+    set_stop_value_multi_ends(
+        sampled_token_ids,
+        stop_flags,
+        seq_lens_this_time,
+        eos_token_id,
+        next_tokens,
+        pre_ids,
+        step_idx,
+        stop_token_ids,
+        stop_seqs_len,
+        False,
+    )
+
+    assert stop_flags[0, 0].item() is True
+    assert sampled_token_ids[0, 0] == 2  # eos token id
+
+
+if __name__ == "__main__":
+    test_set_stop_value_multi_ends_with_stop_seq()