From 95eab9f9ee1204e8e463a0c45edc158cb5a1e6e4 Mon Sep 17 00:00:00 2001 From: lizexu123 <39205361+lizexu123@users.noreply.github.com> Date: Tue, 9 Dec 2025 17:49:12 +0800 Subject: [PATCH] [Feature] support stop_token_ids (#5399) * support stop_token_ids * fix * delete chinese * support both * delete print --- custom_ops/gpu_ops/cpp_extensions.cc | 4 +- .../speculate_set_stop_value_multi_seqs.cu | 12 +- .../gpu_ops/stop_generation_multi_ends.cu | 185 ++++++++++-------- docs/features/early_stop.md | 41 +++- docs/zh/features/early_stop.md | 42 +++- fastdeploy/input/ernie4_5_processor.py | 17 +- .../ernie4_5_vl_processor.py | 9 +- .../paddleocr_vl_processor.py | 8 +- .../qwen_vl_processor/qwen_vl_processor.py | 8 +- fastdeploy/input/text_processor.py | 17 +- fastdeploy/input/utils.py | 31 +++ .../model_executor/pre_and_post_process.py | 3 + fastdeploy/worker/gpu_model_runner.py | 4 + fastdeploy/worker/output.py | 5 + fastdeploy/worker/xpu_model_runner.py | 1 + .../Qwen3-MoE/test_Qwen3-MoE_serving.py | 30 +++ ...est_speculate_set_stop_value_multi_seqs.py | 50 +++++ .../test_stop_generation_multi_ends.py | 37 ++++ 18 files changed, 377 insertions(+), 127 deletions(-) diff --git a/custom_ops/gpu_ops/cpp_extensions.cc b/custom_ops/gpu_ops/cpp_extensions.cc index 49a7f9240..f78a1a4ce 100644 --- a/custom_ops/gpu_ops/cpp_extensions.cc +++ b/custom_ops/gpu_ops/cpp_extensions.cc @@ -420,6 +420,7 @@ void GetStopFlagsMulti(const paddle::Tensor& topk_ids, const paddle::Tensor& step_idx, const paddle::Tensor& stop_seqs, const paddle::Tensor& stop_seqs_len, + const paddle::Tensor& min_tokens, const bool beam_search); void UpdateInputs(const paddle::Tensor& stop_flags, @@ -764,7 +765,8 @@ void SpecGetStopFlagsMultiSeqs(const paddle::Tensor& accept_tokens, const paddle::Tensor& seq_lens, const paddle::Tensor& stop_seqs, const paddle::Tensor& stop_seqs_len, - const paddle::Tensor& end_ids); + const paddle::Tensor& end_ids, + const paddle::Tensor& min_tokens); void SpeculateVerify(const paddle::Tensor& sampled_token_ids, const paddle::Tensor& accept_tokens, diff --git a/custom_ops/gpu_ops/speculate_decoding/speculate_set_stop_value_multi_seqs.cu b/custom_ops/gpu_ops/speculate_decoding/speculate_set_stop_value_multi_seqs.cu index 956beceb5..b1a5332d9 100644 --- a/custom_ops/gpu_ops/speculate_decoding/speculate_set_stop_value_multi_seqs.cu +++ b/custom_ops/gpu_ops/speculate_decoding/speculate_set_stop_value_multi_seqs.cu @@ -32,6 +32,7 @@ __global__ void spec_set_value_by_stop_seqs(bool *stop_flags, const int accept_tokens_len, const int stop_seqs_bs, const int stop_seqs_max_len, + const int64_t *min_tokens, const int pre_ids_len) { const int bid = blockIdx.x; const int tid = threadIdx.x; @@ -46,6 +47,10 @@ __global__ void spec_set_value_by_stop_seqs(bool *stop_flags, int64_t *accept_tokens_now = accept_tokens + bid * accept_tokens_len; const int accept_num = accept_nums[bid]; const int64_t step_idx_now = step_idx[bid]; + const int64_t min_token_limit = min_tokens[bid]; + + const bool can_stop = (step_idx_now >= min_token_limit); + if (!can_stop) return; if (!stop_flags[bid]) { int accept_idx = 0; bool is_end = false; @@ -138,7 +143,8 @@ void SpecGetStopFlagsMultiSeqs(const paddle::Tensor &accept_tokens, const paddle::Tensor &seq_lens, const paddle::Tensor &stop_seqs, const paddle::Tensor &stop_seqs_len, - const paddle::Tensor &end_ids) { + const paddle::Tensor &end_ids, + const paddle::Tensor &min_tokens) { PD_CHECK(accept_tokens.dtype() == paddle::DataType::INT64); PD_CHECK(stop_flags.dtype() == 
paddle::DataType::BOOL); @@ -166,6 +172,7 @@ void SpecGetStopFlagsMultiSeqs(const paddle::Tensor &accept_tokens, accept_tokens_len, stop_seqs_bs, stop_seqs_max_len, + min_tokens.data(), pre_ids_len); } @@ -178,7 +185,8 @@ PD_BUILD_STATIC_OP(speculate_set_stop_value_multi_seqs) "seq_lens", "stop_seqs", "stop_seqs_len", - "end_ids"}) + "end_ids", + "min_tokens"}) .Outputs({"accept_tokens_out", "stop_flags_out"}) .SetInplaceMap({{"accept_tokens", "accept_tokens_out"}, {"stop_flags", "stop_flags_out"}}) diff --git a/custom_ops/gpu_ops/stop_generation_multi_ends.cu b/custom_ops/gpu_ops/stop_generation_multi_ends.cu index 3165be10a..3c7c4c884 100644 --- a/custom_ops/gpu_ops/stop_generation_multi_ends.cu +++ b/custom_ops/gpu_ops/stop_generation_multi_ends.cu @@ -37,59 +37,67 @@ __global__ void set_value_by_flags(bool *stop_flags, const int *stop_seqs_len, const int stop_seqs_bs, const int stop_seqs_max_len, + const int64_t *min_tokens, bool beam_search, bool prefill_one_step_stop) { - int tid = threadIdx.x; - int bid = blockIdx.x; - if (tid >= stop_seqs_bs) return; - if (bid < bs) { - if(tid == 0){ - if (prefill_one_step_stop) { - stop_flags[bid] = true; - if (seq_lens[bid] == 0) { - topk_ids[bid] = -1; - } - next_tokens[bid] = topk_ids[bid]; - } else { - if (stop_flags[bid]) { - if (seq_lens[bid] == 0) { - topk_ids[bid] = -1; - } else { - topk_ids[bid] = end_ids[0]; - next_tokens[bid] = end_ids[0]; - } - } else { - next_tokens[bid] = topk_ids[bid]; - } - } - if (!beam_search && is_in_end(topk_ids[bid], end_ids, end_length)) { - stop_flags[bid] = true; - topk_ids[bid] = end_ids[0]; - next_tokens[bid] = end_ids[0]; - } + int tid = threadIdx.x; + int bid = blockIdx.x; + if (tid >= stop_seqs_bs) return; + if (bid < bs) { + const int64_t current_step = step_idx[bid]; + const int64_t min_token_limit = min_tokens[bid]; + const bool can_stop = (current_step >= min_token_limit); + if (tid == 0) { + if (prefill_one_step_stop) { + stop_flags[bid] = true; + if (seq_lens[bid] == 0) { + topk_ids[bid] = -1; } - // dealing stop_seqs - const int stop_seq_len = (stop_seqs_len + bid * stop_seqs_bs)[tid]; - if (stop_seq_len <= 0) return; - const int64_t *stop_seq_now = stop_seqs + bid * stop_seqs_bs + tid * stop_seqs_max_len; - const int64_t *pre_ids_now = pre_ids + bid * pre_ids_len; - const int64_t step_idx_now = step_idx[bid]; - - bool is_end = true; - int count = 1; - for (int i = stop_seq_len - 1; i >= 0; --i) { - if ((step_idx_now - count) < 0 || - pre_ids_now[step_idx_now - count++] != stop_seq_now[i]) { - is_end = false; - break; - } - } - if (is_end) { - next_tokens[bid] = end_ids[0]; - stop_flags[bid] = true; + next_tokens[bid] = topk_ids[bid]; + } else { + if (stop_flags[bid]) { + if (seq_lens[bid] == 0) { + topk_ids[bid] = -1; + } else { topk_ids[bid] = end_ids[0]; + next_tokens[bid] = end_ids[0]; + } + } else { + next_tokens[bid] = topk_ids[bid]; } + } + if (!beam_search && can_stop && + is_in_end(topk_ids[bid], end_ids, end_length)) { + stop_flags[bid] = true; + topk_ids[bid] = end_ids[0]; + next_tokens[bid] = end_ids[0]; + } } + + if (!can_stop) return; + // dealing stop_seqs + const int stop_seq_len = (stop_seqs_len + bid * stop_seqs_bs)[tid]; + if (stop_seq_len <= 0) return; + const int64_t *stop_seq_now = + stop_seqs + bid * stop_seqs_bs + tid * stop_seqs_max_len; + const int64_t *pre_ids_now = pre_ids + bid * pre_ids_len; + const int64_t step_idx_now = step_idx[bid]; + + bool is_end = true; + int count = 1; + for (int i = stop_seq_len - 1; i >= 0; --i) { + if ((step_idx_now - count) < 0 || + 
pre_ids_now[step_idx_now - count++] != stop_seq_now[i]) { + is_end = false; + break; + } + } + if (is_end) { + next_tokens[bid] = end_ids[0]; + stop_flags[bid] = true; + topk_ids[bid] = end_ids[0]; + } + } } void GetStopFlagsMulti(const paddle::Tensor &topk_ids, @@ -101,50 +109,63 @@ void GetStopFlagsMulti(const paddle::Tensor &topk_ids, const paddle::Tensor &step_idx, const paddle::Tensor &stop_seqs, const paddle::Tensor &stop_seqs_len, + const paddle::Tensor &min_tokens, const bool beam_search) { - PD_CHECK(topk_ids.dtype() == paddle::DataType::INT64); - PD_CHECK(stop_flags.dtype() == paddle::DataType::BOOL); - bool prefill_one_step_stop = false; - if (const char *env_p = std::getenv("PREFILL_NODE_ONE_STEP_STOP")) { - // std::cout << "Your PATH is: " << env_p << '\n'; - if (env_p[0] == '1') { - prefill_one_step_stop = true; - } + PD_CHECK(topk_ids.dtype() == paddle::DataType::INT64); + PD_CHECK(stop_flags.dtype() == paddle::DataType::BOOL); + bool prefill_one_step_stop = false; + if (const char *env_p = std::getenv("PREFILL_NODE_ONE_STEP_STOP")) { + // std::cout << "Your PATH is: " << env_p << '\n'; + if (env_p[0] == '1') { + prefill_one_step_stop = true; } + } #ifdef PADDLE_WITH_CUSTOM_DEVICE - auto dev_ctx = static_cast(paddle::experimental::DeviceContextPool::Instance().Get(topk_ids.place())); - auto cu_stream = dev_ctx->stream(); + auto dev_ctx = static_cast( + paddle::experimental::DeviceContextPool::Instance().Get( + topk_ids.place())); + auto cu_stream = dev_ctx->stream(); #else - auto cu_stream = topk_ids.stream(); + auto cu_stream = topk_ids.stream(); #endif - std::vector shape = topk_ids.shape(); - int64_t bs_now = shape[0]; - int64_t end_length = end_ids.shape()[0]; - int stop_seqs_bs = stop_seqs.shape()[1]; - int stop_seqs_max_len = stop_seqs.shape()[2]; - int block_size = (stop_seqs_bs + WARP_SIZE - 1) / WARP_SIZE * WARP_SIZE; - set_value_by_flags<<>>( - const_cast(stop_flags.data()), - const_cast(topk_ids.data()), - const_cast(next_tokens.data()), - end_ids.data(), - seq_lens.data(), - bs_now, - end_length, - pre_ids.data(), - pre_ids.shape()[1], - step_idx.data(), - stop_seqs.data(), - stop_seqs_len.data(), - stop_seqs_bs, - stop_seqs_max_len, - beam_search, - prefill_one_step_stop); + std::vector shape = topk_ids.shape(); + int64_t bs_now = shape[0]; + int64_t end_length = end_ids.shape()[0]; + int stop_seqs_bs = stop_seqs.shape()[1]; + int stop_seqs_max_len = stop_seqs.shape()[2]; + int block_size = (stop_seqs_bs + WARP_SIZE - 1) / WARP_SIZE * WARP_SIZE; + set_value_by_flags<<>>( + const_cast(stop_flags.data()), + const_cast(topk_ids.data()), + const_cast(next_tokens.data()), + end_ids.data(), + seq_lens.data(), + bs_now, + end_length, + pre_ids.data(), + pre_ids.shape()[1], + step_idx.data(), + stop_seqs.data(), + stop_seqs_len.data(), + stop_seqs_bs, + stop_seqs_max_len, + min_tokens.data(), + beam_search, + prefill_one_step_stop); } PD_BUILD_STATIC_OP(set_stop_value_multi_ends) - .Inputs({"topk_ids", "stop_flags", "seq_lens", "end_ids", "next_tokens", "pre_ids", "step_idx", "stop_seqs", "stop_seqs_len"}) + .Inputs({"topk_ids", + "stop_flags", + "seq_lens", + "end_ids", + "next_tokens", + "pre_ids", + "step_idx", + "stop_seqs", + "stop_seqs_len", + "min_tokens"}) .Attrs({"beam_search: bool"}) .Outputs({"topk_ids_out", "stop_flags_out", "next_tokens_out"}) .SetInplaceMap({{"topk_ids", "topk_ids_out"}, diff --git a/docs/features/early_stop.md b/docs/features/early_stop.md index 1f0bb87b6..a221b5f84 100644 --- a/docs/features/early_stop.md +++ 
b/docs/features/early_stop.md @@ -2,7 +2,7 @@ # Early Stopping -The early stopping is used to prematurely terminate the token generation of the model. Specifically, the early stopping uses different strategies to determine whether the currently generated token sequence meets the early stopping criteria. If so, token generation is terminated prematurely. FastDeploy currently supports the repetition strategy and stop sequence. +The early stopping is used to prematurely terminate the token generation of the model. Specifically, the early stopping uses different strategies to determine whether the currently generated token sequence meets the early stopping criteria. If so, token generation is terminated prematurely. FastDeploy currently supports the repetition strategy, stop sequences, and stop_token_ids. ## 1. Repetition Strategy * The repetition strategy determines whether to trigger the early stopping function by checking the number of times a high-probability token is generated. @@ -121,3 +121,42 @@ output = llm.chat(messages=[{"role": "user", "content": "今天天气真好"}], print(output) ``` + +## 3. Stop_token_ids +* The Stop_token_ids strategy determines whether to trigger early stopping by checking whether the generated token sequence contains a user-specified stop token id. + +* Specifically, if the token sequence generated for a batch contains one of the user-specified stop_token_ids, token generation for that batch is terminated prematurely. + +### Usage Instructions + +Include the `stop_token_ids` parameter in the request; it accepts a `List[int]`. + +* Online serving: set the `stop_token_ids` parameter in the request +``` +# create a chat request with "stop_token_ids" parameter +curl -X POST "http://0.0.0.0:13312/v1/chat/completions" \ -H "Content-Type: application/json" \ -d '{ + "model": "default", + "messages": [ + { + "role": "user", + "content": "北京天安门在哪里?"
+ } + ], + "temperature": 0.7, + "stream": false, + "seed": 1, + "stop_token_ids":[104208] +}' +``` + +* Offline LLM: set the `stop_token_ids` parameter in `SamplingParams` +``` +from fastdeploy.engine.sampling_params import SamplingParams +from fastdeploy.entrypoints.llm import LLM +model_name_or_path = "/Qwen/Qwen3-0.6B" +sampling_params = SamplingParams(temperature=1, seed=1, stop_token_ids=[104208]) +llm = LLM(model=model_name_or_path, tensor_parallel_size=1) +output = llm.chat(messages=[{"role": "user", "content": "北京天安门在哪里?"}], use_tqdm=True, sampling_params=sampling_params) +print(output) +``` diff --git a/docs/zh/features/early_stop.md index 720134ee2..7f6132e71 100644 --- a/docs/zh/features/early_stop.md +++ b/docs/zh/features/early_stop.md @@ -2,7 +2,7 @@ # 早停功能 -早停功能用于提前结束模型生成token的过程,具体来说早停功能会采取不同的策略,判断当前生成的token序列是否满足早停条件,如果满足则提前结束token生成。FastDeploy目前支持`Repetition`策略和`Stop Sequence`策略。 +早停功能用于提前结束模型生成token的过程,具体来说早停功能会采取不同的策略,判断当前生成的token序列是否满足早停条件,如果满足则提前结束token生成。FastDeploy目前支持`Repetition`策略、`Stop Sequence`策略和`Stop_token_ids`策略。 ## 1.Repetition策略 * Repetition策略通过检查生成高概率token的次数决定是否需要触发早停功能。 @@ -116,3 +116,43 @@ output = llm.chat(messages=[{"role": "user", "content": "今天天气真好"}], print(output) ``` + +## 3.Stop_token_ids策略 +* Stop_token_ids策略通过检查生成的token序列是否包含用户指定的停止token id决定是否需要触发早停功能。 +* 具体来说,当某个batch生成的token序列中包含用户指定的停止token_id时,将提前结束该batch的token生成过程。 +### 使用说明 + +在请求服务时,在请求中包含`stop_token_ids`字段,类型为`List[int]`。 +* 在线推理请求示例,请求时添加stop_token_ids参数 +``` +# create a chat request with "stop_token_ids" parameter + +curl -X POST "http://0.0.0.0:13312/v1/chat/completions" \ -H "Content-Type: application/json" \ -d '{ + "model": "default", + "messages": [ + { + "role": "user", + "content": "北京天安门在哪里?" + } + ], + "temperature": 0.7, + "stream": false, + "seed": 1, + "stop_token_ids":[104208] +}' +``` +* 离线推理请求,在`SamplingParams`中增加`stop_token_ids`参数 +``` +from fastdeploy.engine.sampling_params import SamplingParams +from fastdeploy.entrypoints.llm import LLM + +model_name_or_path = "/root/paddlejob/workspace/env_run/output/models/paddle/Qwen/Qwen3-0.6B" + +# 超参设置 +sampling_params = SamplingParams(temperature=1, seed=1, stop_token_ids=[104208]) +llm = LLM(model=model_name_or_path, tensor_parallel_size=1) +output = llm.chat(messages=[{"role": "user", "content": "北京天安门在哪里?"}], use_tqdm=True, sampling_params=sampling_params) + +print(output) +``` diff --git a/fastdeploy/input/ernie4_5_processor.py index a151dbfdd..3553f065a 100644 --- a/fastdeploy/input/ernie4_5_processor.py +++ b/fastdeploy/input/ernie4_5_processor.py @@ -24,6 +24,7 @@ from fastdeploy.input.text_processor import BaseDataProcessor from fastdeploy.utils import data_processor_logger _SAMPLING_EPS = 1e-5 +from fastdeploy.input.utils import process_stop_token_ids class Ernie4_5Processor(BaseDataProcessor): @@ -92,12 +93,8 @@ class Ernie4_5Processor(BaseDataProcessor): if request.get("eos_token_ids") is None or len(request.eos_token_ids) == 0: request.eos_token_ids = self.eos_token_ids - # processing stop_sequences - stop_sequences = request.get("stop", []) - if stop_sequences is not None and len(stop_sequences) != 0: - stop_seqs, stop_seqs_len = self.update_stop_seq(stop_sequences) - request.set("stop_token_ids", stop_seqs) - request.set("stop_seqs_len", stop_seqs_len) + # processing stop_sequences and stop_token_ids + process_stop_token_ids(request, self.update_stop_seq) # processing bad_words bad_words = request.get("bad_words") @@ -173,12 +170,8 @@ class
Ernie4_5Processor(BaseDataProcessor): if not request.get("eos_token_ids"): request["eos_token_ids"] = self.eos_token_ids - # processing stop_sequences - stop_sequences = request.get("stop", []) - if stop_sequences: - stop_seqs, stop_seqs_len = self.update_stop_seq(stop_sequences) - request["stop_token_ids"] = stop_seqs - request["stop_seqs_len"] = stop_seqs_len + # processing stop_sequences and stop_token_ids + process_stop_token_ids(request, self.update_stop_seq) # processing bad_words bad_words = request.get("bad_words") diff --git a/fastdeploy/input/ernie4_5_vl_processor/ernie4_5_vl_processor.py b/fastdeploy/input/ernie4_5_vl_processor/ernie4_5_vl_processor.py index 41c7983b9..4e7b03a11 100644 --- a/fastdeploy/input/ernie4_5_vl_processor/ernie4_5_vl_processor.py +++ b/fastdeploy/input/ernie4_5_vl_processor/ernie4_5_vl_processor.py @@ -21,7 +21,7 @@ from paddleformers.generation import GenerationConfig from fastdeploy.engine.request import Request from fastdeploy.input.ernie4_5_processor import Ernie4_5Processor -from fastdeploy.input.utils import IDS_TYPE_FLAG +from fastdeploy.input.utils import IDS_TYPE_FLAG, process_stop_token_ids from fastdeploy.utils import data_processor_logger from .process import DataProcessor @@ -207,11 +207,8 @@ class Ernie4_5_VLProcessor(Ernie4_5Processor): if not request.get("eos_token_ids"): request["eos_token_ids"] = self.eos_token_ids - stop_sequences = request.get("stop", []) - if stop_sequences: - stop_seqs, stop_seqs_len = self.update_stop_seq(stop_sequences) - request["stop_token_ids"] = stop_seqs - request["stop_seqs_len"] = stop_seqs_len + # processing stop_sequences and stop_token_ids + process_stop_token_ids(request, self.update_stop_seq) bad_words = request.get("bad_words") bad_words_token_ids = request.get("bad_words_token_ids") diff --git a/fastdeploy/input/paddleocr_vl_processor/paddleocr_vl_processor.py b/fastdeploy/input/paddleocr_vl_processor/paddleocr_vl_processor.py index a5335fd0c..0d23021d6 100644 --- a/fastdeploy/input/paddleocr_vl_processor/paddleocr_vl_processor.py +++ b/fastdeploy/input/paddleocr_vl_processor/paddleocr_vl_processor.py @@ -23,6 +23,7 @@ from fastdeploy.utils import data_processor_logger from .process import DataProcessor _SAMPLING_EPS = 1e-5 +from fastdeploy.input.utils import process_stop_token_ids class PaddleOCRVLProcessor(TextProcessor): @@ -210,11 +211,8 @@ class PaddleOCRVLProcessor(TextProcessor): if not request.get("eos_token_ids"): request["eos_token_ids"] = self.eos_token_ids - stop_sequences = request.get("stop", []) - if stop_sequences: - stop_seqs, stop_seqs_len = self.update_stop_seq(stop_sequences) - request["stop_token_ids"] = stop_seqs - request["stop_seqs_len"] = stop_seqs_len + # processing stop_sequences and stop_token_ids + process_stop_token_ids(request, self.update_stop_seq) if request.get("prompt"): multimodal_data = request.get("multimodal_data") diff --git a/fastdeploy/input/qwen_vl_processor/qwen_vl_processor.py b/fastdeploy/input/qwen_vl_processor/qwen_vl_processor.py index 06f43f335..960d91e39 100644 --- a/fastdeploy/input/qwen_vl_processor/qwen_vl_processor.py +++ b/fastdeploy/input/qwen_vl_processor/qwen_vl_processor.py @@ -18,6 +18,7 @@ import numpy as np from fastdeploy.engine.request import Request from fastdeploy.input.text_processor import DataProcessor as TextProcessor +from fastdeploy.input.utils import process_stop_token_ids from fastdeploy.utils import data_processor_logger from .process import DataProcessor @@ -209,11 +210,8 @@ class QwenVLProcessor(TextProcessor): if not 
request.get("eos_token_ids"): request["eos_token_ids"] = self.eos_token_ids - stop_sequences = request.get("stop", []) - if stop_sequences: - stop_seqs, stop_seqs_len = self.update_stop_seq(stop_sequences) - request["stop_token_ids"] = stop_seqs - request["stop_seqs_len"] = stop_seqs_len + # processing stop_sequences and stop_token_ids + process_stop_token_ids(request, self.update_stop_seq) bad_words = request.get("bad_words") bad_words_token_ids = request.get("bad_words_token_ids") diff --git a/fastdeploy/input/text_processor.py b/fastdeploy/input/text_processor.py index 7c2e3cbd1..9b7b95c7f 100644 --- a/fastdeploy/input/text_processor.py +++ b/fastdeploy/input/text_processor.py @@ -21,6 +21,7 @@ from paddleformers.generation import GenerationConfig from paddleformers.transformers import Llama3Tokenizer, LlamaTokenizer from fastdeploy import envs +from fastdeploy.input.utils import process_stop_token_ids from fastdeploy.utils import data_processor_logger _SAMPLING_EPS = 1e-5 @@ -212,12 +213,8 @@ class DataProcessor(BaseDataProcessor): if request.get("eos_token_ids") is None or len(request.eos_token_ids) == 0: request.eos_token_ids = self.eos_token_ids - # processing stop_sequences - stop_sequences = request.get("stop", []) - if stop_sequences is not None and len(stop_sequences) != 0: - stop_seqs, stop_seqs_len = self.update_stop_seq(stop_sequences) - request.set("stop_token_ids", stop_seqs) - request.set("stop_seqs_len", stop_seqs_len) + # processing stop_sequences and stop_token_ids + process_stop_token_ids(request, self.update_stop_seq) # processing bad_words bad_words = request.get("bad_words") @@ -290,12 +287,8 @@ class DataProcessor(BaseDataProcessor): if not request.get("eos_token_ids"): request["eos_token_ids"] = self.eos_token_ids - # processing stop_sequences - stop_sequences = request.get("stop", []) - if stop_sequences: - stop_seqs, stop_seqs_len = self.update_stop_seq(stop_sequences) - request["stop_token_ids"] = stop_seqs - request["stop_seqs_len"] = stop_seqs_len + # processing stop_sequences and stop_token_ids + process_stop_token_ids(request, self.update_stop_seq) # processing bad_words bad_words = request.get("bad_words") diff --git a/fastdeploy/input/utils.py b/fastdeploy/input/utils.py index 7de8db6d0..0f59d60ab 100644 --- a/fastdeploy/input/utils.py +++ b/fastdeploy/input/utils.py @@ -19,3 +19,34 @@ __all__ = [ ] IDS_TYPE_FLAG = {"text": 0, "image": 1, "video": 2, "audio": 3} + + +from typing import Any, Callable, Dict, List, Tuple + + +def process_stop_token_ids( + request: Dict[str, Any], + update_stop_seq_fn: Callable[[List[str]], Tuple[List[List[int]], List[int]]], +) -> None: + stop_token_ids_final = [] + + if request.get("stop_token_ids") is not None: + stop_token_ids = request.get("stop_token_ids") + if isinstance(stop_token_ids, list) and len(stop_token_ids) > 0: + if isinstance(stop_token_ids[0], int): + # List[int] -> List[List[int]] + stop_token_ids_final.extend([[t] for t in stop_token_ids]) + elif isinstance(stop_token_ids[0], list): + # Already List[List[int]] + stop_token_ids_final.extend(stop_token_ids) + + stop_sequences = request.get("stop", []) + if stop_sequences: + stop_seqs, _ = update_stop_seq_fn(stop_sequences) + stop_token_ids_final.extend(stop_seqs) + + # Update request + if stop_token_ids_final: + stop_seqs_len = [len(seq) for seq in stop_token_ids_final] + request["stop_token_ids"] = stop_token_ids_final + request["stop_seqs_len"] = stop_seqs_len diff --git a/fastdeploy/model_executor/pre_and_post_process.py 
b/fastdeploy/model_executor/pre_and_post_process.py index 5b771b6d7..e88ee0e5b 100644 --- a/fastdeploy/model_executor/pre_and_post_process.py +++ b/fastdeploy/model_executor/pre_and_post_process.py @@ -355,6 +355,7 @@ def post_process_normal( model_output.step_idx, model_output.stop_token_ids, model_output.stop_seqs_len, + model_output.min_tokens, False, ) # multi ends elif current_platform.is_maca(): @@ -368,6 +369,7 @@ model_output.step_idx, model_output.stop_token_ids, model_output.stop_seqs_len, + model_output.min_tokens, False, ) # multi ends else: @@ -472,6 +474,7 @@ def post_process_specualate( model_output.stop_token_ids, model_output.stop_seqs_len, model_output.eos_token_id, + model_output.min_tokens, ) speculate_update( model_output.seq_lens_encoder, diff --git a/fastdeploy/worker/gpu_model_runner.py index f5ee92ea1..a7ee4be8f 100644 --- a/fastdeploy/worker/gpu_model_runner.py +++ b/fastdeploy/worker/gpu_model_runner.py @@ -1779,6 +1779,7 @@ class GPUModelRunner(ModelRunnerBase): accept_num=(self.share_inputs["accept_num"] if self.speculative_decoding else None), stop_token_ids=self.share_inputs["stop_seqs"], stop_seqs_len=self.share_inputs["stop_seqs_len"], + min_tokens=self.share_inputs["min_dec_len"], prompt_lens=self.share_inputs["prompt_lens"], ) @@ -1879,6 +1880,7 @@ class GPUModelRunner(ModelRunnerBase): accept_num=(self.share_inputs["accept_num"] if self.speculative_decoding else None), stop_token_ids=self.share_inputs["stop_seqs"], stop_seqs_len=self.share_inputs["stop_seqs_len"], + min_tokens=self.share_inputs["min_dec_len"], prompt_lens=self.share_inputs["prompt_lens"], mask_rollback=self.share_inputs["mask_rollback"], ) @@ -2349,6 +2351,7 @@ class GPUModelRunner(ModelRunnerBase): accept_num=(self.share_inputs["accept_num"] if self.speculative_decoding else None), stop_token_ids=self.share_inputs["stop_seqs"], stop_seqs_len=self.share_inputs["stop_seqs_len"], + min_tokens=self.share_inputs["min_dec_len"], prompt_lens=self.share_inputs["prompt_lens"], ) @@ -2454,6 +2457,7 @@ class GPUModelRunner(ModelRunnerBase): accept_num=(self.share_inputs["accept_num"] if self.speculative_decoding else None), stop_token_ids=self.share_inputs["stop_seqs"], stop_seqs_len=self.share_inputs["stop_seqs_len"], + min_tokens=self.share_inputs["min_dec_len"], prompt_lens=self.share_inputs["prompt_lens"], mask_rollback=self.share_inputs["mask_rollback"], prompt_logprobs_list=prompt_logprobs_list, diff --git a/fastdeploy/worker/output.py index 2b66ce4e1..bd7e7ce1d 100644 --- a/fastdeploy/worker/output.py +++ b/fastdeploy/worker/output.py @@ -279,6 +279,11 @@ class ModelOutputData: """ prompt_logprobs_list: Optional[LogprobsTensors] = None + """ + the minimum number of tokens that must be generated before stopping is allowed + """ + min_tokens: paddle.Tensor = None + @dataclass class ModelRunnerOutput: diff --git a/fastdeploy/worker/xpu_model_runner.py index f9bbb4ea9..adc221dc3 100644 --- a/fastdeploy/worker/xpu_model_runner.py +++ b/fastdeploy/worker/xpu_model_runner.py @@ -1321,6 +1321,7 @@ class XPUModelRunner(ModelRunnerBase): accept_num=(self.share_inputs["accept_num"] if self.speculative_decoding else None), stop_token_ids=self.share_inputs["stop_seqs"], stop_seqs_len=self.share_inputs["stop_seqs_len"], + min_tokens=self.share_inputs["min_dec_len"], ) if self.speculative_decoding: # base model post process diff --git a/tests/ci_use/Qwen3-MoE/test_Qwen3-MoE_serving.py
b/tests/ci_use/Qwen3-MoE/test_Qwen3-MoE_serving.py index d68bf93ca..3400da4c4 100644 --- a/tests/ci_use/Qwen3-MoE/test_Qwen3-MoE_serving.py +++ b/tests/ci_use/Qwen3-MoE/test_Qwen3-MoE_serving.py @@ -308,3 +308,33 @@ def test_profile_reset_block_num(): f"Reset total_block_num {actual_value} 与 baseline {baseline} diff需要在5%以内" f"Allowed range: [{lower_bound:.1f}, {upper_bound:.1f}]" ) + + +def test_thinking_with_stop_token_ids(api_url, headers): + """ + Test case to verify thinking behavior when stop token ids are provided. + """ + messages = [{"role": "user", "content": "北京天安门在哪里"}] + + payload = { + "messages": messages, + "max_tokens": 100, + "temperature": 0.8, + "seed": 1, + "stop_token_ids": [105930], + } + + resp = requests.post(api_url, headers=headers, json=payload) + assert resp.status_code == 200, f"Unexpected status code: {resp.status_code}" + + try: + response_json = resp.json() + except Exception as e: + assert False, f"Response is not valid JSON: {e}" + + content = response_json.get("choices", [{}])[0].get("message", {}).get("content", "") + + expected_output = "\n好的,用户问“北京天安门在哪里" + assert content == expected_output, ( + f"Unexpected response content.\n" f"Expected: {expected_output!r}\n" f"Actual: {content!r}" + ) diff --git a/tests/operators/test_speculate_set_stop_value_multi_seqs.py b/tests/operators/test_speculate_set_stop_value_multi_seqs.py index 0058b81e4..3c0880b71 100644 --- a/tests/operators/test_speculate_set_stop_value_multi_seqs.py +++ b/tests/operators/test_speculate_set_stop_value_multi_seqs.py @@ -32,6 +32,7 @@ class TestSpeculateSetStopValueMultiSeqs(unittest.TestCase): stop_seqs, stop_seqs_len, end_ids, + min_tokens, ): accept_tokens_out = accept_tokens.clone() stop_flags_out = stop_flags.clone() @@ -45,6 +46,7 @@ class TestSpeculateSetStopValueMultiSeqs(unittest.TestCase): stop_seqs, stop_seqs_len, end_ids, + min_tokens, ) return { @@ -88,6 +90,7 @@ class TestSpeculateSetStopValueMultiSeqs(unittest.TestCase): ) stop_seqs_len = paddle.to_tensor([3, 0], dtype="int32") end_ids = paddle.to_tensor([-1], dtype="int64") + min_tokens = paddle.to_tensor([0, 0], dtype="int64") gpu_results = self.run_op( accept_tokens, accept_num, @@ -98,6 +101,7 @@ class TestSpeculateSetStopValueMultiSeqs(unittest.TestCase): stop_seqs, stop_seqs_len, end_ids, + min_tokens, ) expected_accept_tokens = np.array([[4, 5, -1, 0, 0], [1, 2, 3, 0, 0]]) @@ -121,6 +125,7 @@ class TestSpeculateSetStopValueMultiSeqs(unittest.TestCase): stop_seqs = paddle.to_tensor([[11, 12, 13], [14, 15, 16]], dtype="int64") stop_seqs_len = paddle.to_tensor([3, 3], dtype="int32") end_ids = paddle.to_tensor([-1], dtype="int64") + min_tokens = paddle.to_tensor([0, 0], dtype="int64") gpu_results = self.run_op( accept_tokens, @@ -132,6 +137,7 @@ class TestSpeculateSetStopValueMultiSeqs(unittest.TestCase): stop_seqs, stop_seqs_len, end_ids, + min_tokens, ) np.testing.assert_array_equal(gpu_results["output_accept_tokens"], accept_tokens.numpy()) @@ -152,6 +158,7 @@ class TestSpeculateSetStopValueMultiSeqs(unittest.TestCase): ) stop_seqs_len = paddle.to_tensor([3], dtype="int32") end_ids = paddle.to_tensor([-1], dtype="int64") + min_tokens = paddle.to_tensor([0], dtype="int64") gpu_results = self.run_op( accept_tokens, @@ -163,6 +170,7 @@ class TestSpeculateSetStopValueMultiSeqs(unittest.TestCase): stop_seqs, stop_seqs_len, end_ids, + min_tokens, ) np.testing.assert_array_equal(gpu_results["output_accept_tokens"], accept_tokens.numpy()) @@ -180,6 +188,7 @@ class TestSpeculateSetStopValueMultiSeqs(unittest.TestCase): 
stop_seqs = paddle.to_tensor([[5, 4, 3]], dtype="int64") stop_seqs_len = paddle.to_tensor([3], dtype="int32") end_ids = paddle.to_tensor([-1], dtype="int64") + min_tokens = paddle.to_tensor([0], dtype="int64") gpu_results = self.run_op( accept_tokens, @@ -191,11 +200,52 @@ class TestSpeculateSetStopValueMultiSeqs(unittest.TestCase): stop_seqs, stop_seqs_len, end_ids, + min_tokens, ) np.testing.assert_array_equal(gpu_results["output_accept_tokens"], accept_tokens.numpy()) np.testing.assert_array_equal(gpu_results["output_stop_flags"], stop_flags.numpy()) + def test_min_tokens_allows_stop(self): + """Test that stopping is allowed when step_idx >= min_tokens""" + accept_tokens = paddle.to_tensor( + [[4, 5, 0, 0, 0]], + dtype="int64", + ) + accept_num = paddle.to_tensor([3], dtype="int32") + pre_ids = paddle.to_tensor( + [[7, 8, 9, 3, 4, 5]], + dtype="int64", + ) + step_idx = paddle.to_tensor([6], dtype="int64") + stop_flags = paddle.to_tensor([False], dtype="bool") + seq_lens = paddle.to_tensor([6], dtype="int32") + stop_seqs = paddle.to_tensor( + [[[3, 4, 5]]], + dtype="int64", + ) + stop_seqs_len = paddle.to_tensor([[3]], dtype="int32") + end_ids = paddle.to_tensor([-1], dtype="int64") + min_tokens = paddle.to_tensor([5], dtype="int64") # min_tokens=5, step_idx=6 >= 5 + + gpu_results = self.run_op( + accept_tokens, + accept_num, + pre_ids, + step_idx, + stop_flags, + seq_lens, + stop_seqs, + stop_seqs_len, + end_ids, + min_tokens, + ) + + expected_accept_tokens = np.array([[4, 5, -1, 0, 0]]) + expected_stop_flags = np.array([True]) + np.testing.assert_array_equal(gpu_results["output_accept_tokens"], expected_accept_tokens) + np.testing.assert_array_equal(gpu_results["output_stop_flags"], expected_stop_flags) + if __name__ == "__main__": unittest.main() diff --git a/tests/operators/test_stop_generation_multi_ends.py b/tests/operators/test_stop_generation_multi_ends.py index c350e8304..12d54857a 100644 --- a/tests/operators/test_stop_generation_multi_ends.py +++ b/tests/operators/test_stop_generation_multi_ends.py @@ -36,6 +36,7 @@ def test_set_stop_value_multi_ends_with_stop_seq(): stop_seqs_len = paddle.full([2, 5], 10, dtype="int32") stop_seqs_len[0, 0] = 2 + min_tokens = paddle.to_tensor([[0], [0]], dtype="int64") set_stop_value_multi_ends( sampled_token_ids, @@ -47,6 +48,7 @@ def test_set_stop_value_multi_ends_with_stop_seq(): step_idx, stop_token_ids, stop_seqs_len, + min_tokens, False, ) @@ -54,5 +56,40 @@ def test_set_stop_value_multi_ends_with_stop_seq(): assert sampled_token_ids[0, 0] == 2 # eos token id +def test_min_tokens(): + """Test min_tokens functionality""" + sampled_token_ids = paddle.to_tensor([[2], [100], [200]], dtype="int64") + stop_flags = paddle.to_tensor([[False], [False], [False]], dtype="bool") + seq_lens_this_time = paddle.to_tensor([[1], [1], [1]], dtype="int32") + eos_token_id = paddle.to_tensor([2], dtype="int64") + next_tokens = paddle.to_tensor([[2], [100], [200]], dtype="int64") + + pre_ids = paddle.full([3, 100], -1, dtype="int64") + step_idx = paddle.to_tensor([[5], [50], [10]], dtype="int64") + + stop_seqs = paddle.full([3, 5, 8], -1, dtype="int64") + stop_seqs_len = paddle.zeros([3, 5], dtype="int32") + + min_tokens = paddle.to_tensor([[10], [0], [5]], dtype="int64") + + set_stop_value_multi_ends( + sampled_token_ids, + stop_flags, + seq_lens_this_time, + eos_token_id, + next_tokens, + pre_ids, + step_idx, + stop_seqs, + stop_seqs_len, + min_tokens, + False, + ) + + # Sample 0: step < min_tokens, should not stop even with EOS + assert bool(stop_flags[0, 
0]) is False + + if __name__ == "__main__": test_set_stop_value_multi_ends_with_stop_seq() + test_min_tokens()
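
For reference, a minimal sketch of how the new `process_stop_token_ids` helper in `fastdeploy/input/utils.py` merges the two stop-related request fields. The `fake_update_stop_seq` stub below is a hypothetical stand-in for the processors' real `update_stop_seq`, which tokenizes each stop string with the model tokenizer; the token ids used here are made up.

```
# Sketch of the normalization performed by process_stop_token_ids (from this
# patch). fake_update_stop_seq is a hypothetical stand-in for update_stop_seq.
from typing import List, Tuple

from fastdeploy.input.utils import process_stop_token_ids


def fake_update_stop_seq(stop_sequences: List[str]) -> Tuple[List[List[int]], List[int]]:
    # Pretend every stop string tokenizes to the two ids [1000, 1001].
    seqs = [[1000, 1001] for _ in stop_sequences]
    return seqs, [len(s) for s in seqs]


request = {"stop_token_ids": [104208], "stop": ["END"]}
process_stop_token_ids(request, fake_update_stop_seq)

# A flat List[int] is wrapped into single-token sequences, the tokenized stop
# strings are appended, and stop_seqs_len records each sequence's length.
assert request["stop_token_ids"] == [[104208], [1000, 1001]]
assert request["stop_seqs_len"] == [1, 2]
```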