[Feature] support stop_token_ids (#5399)

* support stop_token_ids

* fix

* delete chinese

* support both

* delete print
Author: lizexu123
Date: 2025-12-09 17:49:12 +08:00
Committed by: GitHub
Parent: df67379bc3
Commit: 95eab9f9ee

18 changed files with 377 additions and 127 deletions

View File

@@ -420,6 +420,7 @@ void GetStopFlagsMulti(const paddle::Tensor& topk_ids,
const paddle::Tensor& step_idx,
const paddle::Tensor& stop_seqs,
const paddle::Tensor& stop_seqs_len,
const paddle::Tensor& min_tokens,
const bool beam_search);
void UpdateInputs(const paddle::Tensor& stop_flags,
@@ -764,7 +765,8 @@ void SpecGetStopFlagsMultiSeqs(const paddle::Tensor& accept_tokens,
const paddle::Tensor& seq_lens,
const paddle::Tensor& stop_seqs,
const paddle::Tensor& stop_seqs_len,
const paddle::Tensor& end_ids);
const paddle::Tensor& end_ids,
const paddle::Tensor& min_tokens);
void SpeculateVerify(const paddle::Tensor& sampled_token_ids,
const paddle::Tensor& accept_tokens,

View File

@@ -32,6 +32,7 @@ __global__ void spec_set_value_by_stop_seqs(bool *stop_flags,
const int accept_tokens_len,
const int stop_seqs_bs,
const int stop_seqs_max_len,
const int64_t *min_tokens,
const int pre_ids_len) {
const int bid = blockIdx.x;
const int tid = threadIdx.x;
@@ -46,6 +47,10 @@ __global__ void spec_set_value_by_stop_seqs(bool *stop_flags,
int64_t *accept_tokens_now = accept_tokens + bid * accept_tokens_len;
const int accept_num = accept_nums[bid];
const int64_t step_idx_now = step_idx[bid];
const int64_t min_token_limit = min_tokens[bid];
const bool can_stop = (step_idx_now >= min_token_limit);
if (!can_stop) return;
if (!stop_flags[bid]) {
int accept_idx = 0;
bool is_end = false;
@@ -138,7 +143,8 @@ void SpecGetStopFlagsMultiSeqs(const paddle::Tensor &accept_tokens,
const paddle::Tensor &seq_lens,
const paddle::Tensor &stop_seqs,
const paddle::Tensor &stop_seqs_len,
const paddle::Tensor &end_ids) {
const paddle::Tensor &end_ids,
const paddle::Tensor &min_tokens) {
PD_CHECK(accept_tokens.dtype() == paddle::DataType::INT64);
PD_CHECK(stop_flags.dtype() == paddle::DataType::BOOL);
@@ -166,6 +172,7 @@ void SpecGetStopFlagsMultiSeqs(const paddle::Tensor &accept_tokens,
accept_tokens_len,
stop_seqs_bs,
stop_seqs_max_len,
min_tokens.data<int64_t>(),
pre_ids_len);
}
@@ -178,7 +185,8 @@ PD_BUILD_STATIC_OP(speculate_set_stop_value_multi_seqs)
"seq_lens",
"stop_seqs",
"stop_seqs_len",
"end_ids"})
"end_ids",
"min_tokens"})
.Outputs({"accept_tokens_out", "stop_flags_out"})
.SetInplaceMap({{"accept_tokens", "accept_tokens_out"},
{"stop_flags", "stop_flags_out"}})

View File

@@ -37,59 +37,67 @@ __global__ void set_value_by_flags(bool *stop_flags,
const int *stop_seqs_len,
const int stop_seqs_bs,
const int stop_seqs_max_len,
const int64_t *min_tokens,
bool beam_search,
bool prefill_one_step_stop) {
int tid = threadIdx.x;
int bid = blockIdx.x;
if (tid >= stop_seqs_bs) return;
if (bid < bs) {
if(tid == 0){
if (prefill_one_step_stop) {
stop_flags[bid] = true;
if (seq_lens[bid] == 0) {
topk_ids[bid] = -1;
}
next_tokens[bid] = topk_ids[bid];
} else {
if (stop_flags[bid]) {
if (seq_lens[bid] == 0) {
topk_ids[bid] = -1;
} else {
topk_ids[bid] = end_ids[0];
next_tokens[bid] = end_ids[0];
}
} else {
next_tokens[bid] = topk_ids[bid];
}
}
if (!beam_search && is_in_end(topk_ids[bid], end_ids, end_length)) {
stop_flags[bid] = true;
topk_ids[bid] = end_ids[0];
next_tokens[bid] = end_ids[0];
}
int tid = threadIdx.x;
int bid = blockIdx.x;
if (tid >= stop_seqs_bs) return;
if (bid < bs) {
const int64_t current_step = step_idx[bid];
const int64_t min_token_limit = min_tokens[bid];
const bool can_stop = (current_step >= min_token_limit);
if (tid == 0) {
if (prefill_one_step_stop) {
stop_flags[bid] = true;
if (seq_lens[bid] == 0) {
topk_ids[bid] = -1;
}
// dealing stop_seqs
const int stop_seq_len = (stop_seqs_len + bid * stop_seqs_bs)[tid];
if (stop_seq_len <= 0) return;
const int64_t *stop_seq_now = stop_seqs + bid * stop_seqs_bs + tid * stop_seqs_max_len;
const int64_t *pre_ids_now = pre_ids + bid * pre_ids_len;
const int64_t step_idx_now = step_idx[bid];
bool is_end = true;
int count = 1;
for (int i = stop_seq_len - 1; i >= 0; --i) {
if ((step_idx_now - count) < 0 ||
pre_ids_now[step_idx_now - count++] != stop_seq_now[i]) {
is_end = false;
break;
}
}
if (is_end) {
next_tokens[bid] = end_ids[0];
stop_flags[bid] = true;
next_tokens[bid] = topk_ids[bid];
} else {
if (stop_flags[bid]) {
if (seq_lens[bid] == 0) {
topk_ids[bid] = -1;
} else {
topk_ids[bid] = end_ids[0];
next_tokens[bid] = end_ids[0];
}
} else {
next_tokens[bid] = topk_ids[bid];
}
}
if (!beam_search && can_stop &&
is_in_end(topk_ids[bid], end_ids, end_length)) {
stop_flags[bid] = true;
topk_ids[bid] = end_ids[0];
next_tokens[bid] = end_ids[0];
}
}
if (!can_stop) return;
// dealing stop_seqs
const int stop_seq_len = (stop_seqs_len + bid * stop_seqs_bs)[tid];
if (stop_seq_len <= 0) return;
const int64_t *stop_seq_now =
stop_seqs + bid * stop_seqs_bs + tid * stop_seqs_max_len;
const int64_t *pre_ids_now = pre_ids + bid * pre_ids_len;
const int64_t step_idx_now = step_idx[bid];
bool is_end = true;
int count = 1;
for (int i = stop_seq_len - 1; i >= 0; --i) {
if ((step_idx_now - count) < 0 ||
pre_ids_now[step_idx_now - count++] != stop_seq_now[i]) {
is_end = false;
break;
}
}
if (is_end) {
next_tokens[bid] = end_ids[0];
stop_flags[bid] = true;
topk_ids[bid] = end_ids[0];
}
}
}
void GetStopFlagsMulti(const paddle::Tensor &topk_ids,
@@ -101,50 +109,63 @@ void GetStopFlagsMulti(const paddle::Tensor &topk_ids,
const paddle::Tensor &step_idx,
const paddle::Tensor &stop_seqs,
const paddle::Tensor &stop_seqs_len,
const paddle::Tensor &min_tokens,
const bool beam_search) {
PD_CHECK(topk_ids.dtype() == paddle::DataType::INT64);
PD_CHECK(stop_flags.dtype() == paddle::DataType::BOOL);
bool prefill_one_step_stop = false;
if (const char *env_p = std::getenv("PREFILL_NODE_ONE_STEP_STOP")) {
// std::cout << "Your PATH is: " << env_p << '\n';
if (env_p[0] == '1') {
prefill_one_step_stop = true;
}
PD_CHECK(topk_ids.dtype() == paddle::DataType::INT64);
PD_CHECK(stop_flags.dtype() == paddle::DataType::BOOL);
bool prefill_one_step_stop = false;
if (const char *env_p = std::getenv("PREFILL_NODE_ONE_STEP_STOP")) {
// std::cout << "Your PATH is: " << env_p << '\n';
if (env_p[0] == '1') {
prefill_one_step_stop = true;
}
}
#ifdef PADDLE_WITH_CUSTOM_DEVICE
auto dev_ctx = static_cast<const phi::CustomContext*>(paddle::experimental::DeviceContextPool::Instance().Get(topk_ids.place()));
auto cu_stream = dev_ctx->stream();
auto dev_ctx = static_cast<const phi::CustomContext *>(
paddle::experimental::DeviceContextPool::Instance().Get(
topk_ids.place()));
auto cu_stream = dev_ctx->stream();
#else
auto cu_stream = topk_ids.stream();
auto cu_stream = topk_ids.stream();
#endif
std::vector<int64_t> shape = topk_ids.shape();
int64_t bs_now = shape[0];
int64_t end_length = end_ids.shape()[0];
int stop_seqs_bs = stop_seqs.shape()[1];
int stop_seqs_max_len = stop_seqs.shape()[2];
int block_size = (stop_seqs_bs + WARP_SIZE - 1) / WARP_SIZE * WARP_SIZE;
set_value_by_flags<<<bs_now, block_size, 0, cu_stream>>>(
const_cast<bool *>(stop_flags.data<bool>()),
const_cast<int64_t *>(topk_ids.data<int64_t>()),
const_cast<int64_t *>(next_tokens.data<int64_t>()),
end_ids.data<int64_t>(),
seq_lens.data<int>(),
bs_now,
end_length,
pre_ids.data<int64_t>(),
pre_ids.shape()[1],
step_idx.data<int64_t>(),
stop_seqs.data<int64_t>(),
stop_seqs_len.data<int>(),
stop_seqs_bs,
stop_seqs_max_len,
beam_search,
prefill_one_step_stop);
std::vector<int64_t> shape = topk_ids.shape();
int64_t bs_now = shape[0];
int64_t end_length = end_ids.shape()[0];
int stop_seqs_bs = stop_seqs.shape()[1];
int stop_seqs_max_len = stop_seqs.shape()[2];
int block_size = (stop_seqs_bs + WARP_SIZE - 1) / WARP_SIZE * WARP_SIZE;
set_value_by_flags<<<bs_now, block_size, 0, cu_stream>>>(
const_cast<bool *>(stop_flags.data<bool>()),
const_cast<int64_t *>(topk_ids.data<int64_t>()),
const_cast<int64_t *>(next_tokens.data<int64_t>()),
end_ids.data<int64_t>(),
seq_lens.data<int>(),
bs_now,
end_length,
pre_ids.data<int64_t>(),
pre_ids.shape()[1],
step_idx.data<int64_t>(),
stop_seqs.data<int64_t>(),
stop_seqs_len.data<int>(),
stop_seqs_bs,
stop_seqs_max_len,
min_tokens.data<int64_t>(),
beam_search,
prefill_one_step_stop);
}
PD_BUILD_STATIC_OP(set_stop_value_multi_ends)
.Inputs({"topk_ids", "stop_flags", "seq_lens", "end_ids", "next_tokens", "pre_ids", "step_idx", "stop_seqs", "stop_seqs_len"})
.Inputs({"topk_ids",
"stop_flags",
"seq_lens",
"end_ids",
"next_tokens",
"pre_ids",
"step_idx",
"stop_seqs",
"stop_seqs_len",
"min_tokens"})
.Attrs({"beam_search: bool"})
.Outputs({"topk_ids_out", "stop_flags_out", "next_tokens_out"})
.SetInplaceMap({{"topk_ids", "topk_ids_out"},

View File

@@ -2,7 +2,7 @@
# Early Stopping
The early stopping is used to prematurely terminate the token generation of the model. Specifically, the early stopping uses different strategies to determine whether the currently generated token sequence meets the early stopping criteria. If so, token generation is terminated prematurely. FastDeploy currently supports the repetition strategy and stop sequence.
The early stopping is used to prematurely terminate the token generation of the model. Specifically, the early stopping uses different strategies to determine whether the currently generated token sequence meets the early stopping criteria. If so, token generation is terminated prematurely. FastDeploy currently supports the repetition strategy, stop sequences, and stop_token_ids.
## 1. Repetition Strategy
* The repetition strategy determines whether to trigger the early stopping function by checking the number of times a high-probability token is generated.
@@ -121,3 +121,42 @@ output = llm.chat(messages=[{"role": "user", "content": "今天天气真好"}],
print(output)
```
## 3. Stop_token_ids
* The stop_token_ids strategy decides whether to trigger early stopping by checking whether the generated token sequence contains any user-specified stop token id.
* Specifically, once the token sequence generated for a batch contains one of the user-specified stop_token_ids, token generation for that batch is terminated prematurely.
### Usage Instructions
Include the `stop_token_ids` parameter in the request; its type is `List[int]`.
* Online serving: set the `stop_token_ids` parameter in the request
```
# create a chat request with "stop_token_ids" parameter
curl -X POST "http://0.0.0.0:13312/v1/chat/completions" \
-H "Content-Type: application/json" \
-d '{
"model": "default",
"messages": [
{
"role": "user",
"content": "北京天安门在哪里?"
}
],
"temperature": 0.7,
"stream": false,
"seed": 1,
"stop_token_ids":[104208]
}'
```
* Offline LLM: set the `stop_token_ids` parameter in `SamplingParams`
```
from fastdeploy.engine.sampling_params import SamplingParams
from fastdeploy.entrypoints.llm import LLM
model_name_or_path = "/Qwen/Qwen3-0.6B"
sampling_params = SamplingParams(temperature=1, seed=1, stop_token_ids=[104208])
llm = LLM(model=model_name_or_path, tensor_parallel_size=1)
output = llm.chat(messages=[{"role": "user", "content": "北京天安门在哪里?"}], use_tqdm=True, sampling_params=sampling_params)
print(output)
```
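Note: the new `process_stop_token_ids` helper in `fastdeploy/input/utils.py` (shown later in this commit) also normalizes a `List[List[int]]` form of `stop_token_ids`, which suggests multi-token stop sequences can be expressed this way; the diff does not show schema validation for this form, so treat the following as an unverified sketch:

```
# Unverified sketch: nested-list form inferred from process_stop_token_ids,
# which normalizes List[int] to List[List[int]] and accepts List[List[int]].
sampling_params = SamplingParams(
    temperature=1,
    seed=1,
    stop_token_ids=[[104208], [1234, 5678]],  # hypothetical multi-token stop
)
```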

View File

@@ -2,7 +2,7 @@
# Early Stopping
Early stopping is used to end the model's token generation ahead of time. Specifically, it applies different strategies to judge whether the currently generated token sequence meets the early-stopping criteria; if so, token generation ends early. FastDeploy currently supports the `Repetition` strategy and the `Stop Sequence` strategy.
Early stopping is used to end the model's token generation ahead of time. Specifically, it applies different strategies to judge whether the currently generated token sequence meets the early-stopping criteria; if so, token generation ends early. FastDeploy currently supports the `Repetition` strategy, the `Stop Sequence` strategy, and the `Stop_token_ids` strategy.
## 1. Repetition Strategy
* The Repetition strategy decides whether to trigger early stopping by checking how many times high-probability tokens have been generated.
@@ -116,3 +116,43 @@ output = llm.chat(messages=[{"role": "user", "content": "今天天气真好"}],
print(output)
```
## 3. Stop_token_ids Strategy
* The stop_token_ids strategy decides whether to trigger early stopping by checking whether the generated token sequence contains any user-specified stop token id.
* Specifically, once the token sequence generated for a batch contains one of the user-specified stop token ids, token generation for that batch ends early.
### Usage Instructions
When sending a request, include the `stop_token_ids` field; its type is `List[int]`.
* Online inference example (add the `stop_token_ids` parameter to the request):
```
# create a chat request with "stop_token_ids" parameter
curl -X POST "http://0.0.0.0:13312/v1/chat/completions" \
-H "Content-Type: application/json" \
-d '{
"model": "default",
"messages": [
{
"role": "user",
"content": "北京天安门在哪里?"
}
],
"temperature": 0.7,
"stream": false,
"seed": 1,
"stop_token_ids":[104208]
}'
```
* Offline inference: add the `stop_token_ids` parameter to `SamplingParams`
```
from fastdeploy.engine.sampling_params import SamplingParams
from fastdeploy.entrypoints.llm import LLM
model_name_or_path = "/root/paddlejob/workspace/env_run/output/models/paddle/Qwen/Qwen3-0.6B"
# set sampling hyperparameters
sampling_params = SamplingParams(temperature=1, seed=1, stop_token_ids=[104208])
llm = LLM(model=model_name_or_path, tensor_parallel_size=1)
output = llm.chat(messages=[{"role": "user", "content": "北京天安门在哪里?"}], use_tqdm=True, sampling_params=sampling_params)
print(output)
```

View File

@@ -24,6 +24,7 @@ from fastdeploy.input.text_processor import BaseDataProcessor
from fastdeploy.utils import data_processor_logger
_SAMPLING_EPS = 1e-5
from fastdeploy.input.utils import process_stop_token_ids
class Ernie4_5Processor(BaseDataProcessor):
@@ -92,12 +93,8 @@ class Ernie4_5Processor(BaseDataProcessor):
if request.get("eos_token_ids") is None or len(request.eos_token_ids) == 0:
request.eos_token_ids = self.eos_token_ids
# processing stop_sequences
stop_sequences = request.get("stop", [])
if stop_sequences is not None and len(stop_sequences) != 0:
stop_seqs, stop_seqs_len = self.update_stop_seq(stop_sequences)
request.set("stop_token_ids", stop_seqs)
request.set("stop_seqs_len", stop_seqs_len)
# processing stop_sequences and stop_token_ids
process_stop_token_ids(request, self.update_stop_seq)
# processing bad_words
bad_words = request.get("bad_words")
@@ -173,12 +170,8 @@ class Ernie4_5Processor(BaseDataProcessor):
if not request.get("eos_token_ids"):
request["eos_token_ids"] = self.eos_token_ids
# processing stop_sequences
stop_sequences = request.get("stop", [])
if stop_sequences:
stop_seqs, stop_seqs_len = self.update_stop_seq(stop_sequences)
request["stop_token_ids"] = stop_seqs
request["stop_seqs_len"] = stop_seqs_len
# processing stop_sequences and stop_token_ids
process_stop_token_ids(request, self.update_stop_seq)
# processing bad_words
bad_words = request.get("bad_words")

View File

@@ -21,7 +21,7 @@ from paddleformers.generation import GenerationConfig
from fastdeploy.engine.request import Request
from fastdeploy.input.ernie4_5_processor import Ernie4_5Processor
from fastdeploy.input.utils import IDS_TYPE_FLAG
from fastdeploy.input.utils import IDS_TYPE_FLAG, process_stop_token_ids
from fastdeploy.utils import data_processor_logger
from .process import DataProcessor
@@ -207,11 +207,8 @@ class Ernie4_5_VLProcessor(Ernie4_5Processor):
if not request.get("eos_token_ids"):
request["eos_token_ids"] = self.eos_token_ids
stop_sequences = request.get("stop", [])
if stop_sequences:
stop_seqs, stop_seqs_len = self.update_stop_seq(stop_sequences)
request["stop_token_ids"] = stop_seqs
request["stop_seqs_len"] = stop_seqs_len
# processing stop_sequences and stop_token_ids
process_stop_token_ids(request, self.update_stop_seq)
bad_words = request.get("bad_words")
bad_words_token_ids = request.get("bad_words_token_ids")

View File

@@ -23,6 +23,7 @@ from fastdeploy.utils import data_processor_logger
from .process import DataProcessor
_SAMPLING_EPS = 1e-5
from fastdeploy.input.utils import process_stop_token_ids
class PaddleOCRVLProcessor(TextProcessor):
@@ -210,11 +211,8 @@ class PaddleOCRVLProcessor(TextProcessor):
if not request.get("eos_token_ids"):
request["eos_token_ids"] = self.eos_token_ids
stop_sequences = request.get("stop", [])
if stop_sequences:
stop_seqs, stop_seqs_len = self.update_stop_seq(stop_sequences)
request["stop_token_ids"] = stop_seqs
request["stop_seqs_len"] = stop_seqs_len
# processing stop_sequences and stop_token_ids
process_stop_token_ids(request, self.update_stop_seq)
if request.get("prompt"):
multimodal_data = request.get("multimodal_data")

View File

@@ -18,6 +18,7 @@ import numpy as np
from fastdeploy.engine.request import Request
from fastdeploy.input.text_processor import DataProcessor as TextProcessor
from fastdeploy.input.utils import process_stop_token_ids
from fastdeploy.utils import data_processor_logger
from .process import DataProcessor
@@ -209,11 +210,8 @@ class QwenVLProcessor(TextProcessor):
if not request.get("eos_token_ids"):
request["eos_token_ids"] = self.eos_token_ids
stop_sequences = request.get("stop", [])
if stop_sequences:
stop_seqs, stop_seqs_len = self.update_stop_seq(stop_sequences)
request["stop_token_ids"] = stop_seqs
request["stop_seqs_len"] = stop_seqs_len
# processing stop_sequences and stop_token_ids
process_stop_token_ids(request, self.update_stop_seq)
bad_words = request.get("bad_words")
bad_words_token_ids = request.get("bad_words_token_ids")

View File

@@ -21,6 +21,7 @@ from paddleformers.generation import GenerationConfig
from paddleformers.transformers import Llama3Tokenizer, LlamaTokenizer
from fastdeploy import envs
from fastdeploy.input.utils import process_stop_token_ids
from fastdeploy.utils import data_processor_logger
_SAMPLING_EPS = 1e-5
@@ -212,12 +213,8 @@ class DataProcessor(BaseDataProcessor):
if request.get("eos_token_ids") is None or len(request.eos_token_ids) == 0:
request.eos_token_ids = self.eos_token_ids
# processing stop_sequences
stop_sequences = request.get("stop", [])
if stop_sequences is not None and len(stop_sequences) != 0:
stop_seqs, stop_seqs_len = self.update_stop_seq(stop_sequences)
request.set("stop_token_ids", stop_seqs)
request.set("stop_seqs_len", stop_seqs_len)
# processing stop_sequences and stop_token_ids
process_stop_token_ids(request, self.update_stop_seq)
# processing bad_words
bad_words = request.get("bad_words")
@@ -290,12 +287,8 @@ class DataProcessor(BaseDataProcessor):
if not request.get("eos_token_ids"):
request["eos_token_ids"] = self.eos_token_ids
# processing stop_sequences
stop_sequences = request.get("stop", [])
if stop_sequences:
stop_seqs, stop_seqs_len = self.update_stop_seq(stop_sequences)
request["stop_token_ids"] = stop_seqs
request["stop_seqs_len"] = stop_seqs_len
# processing stop_sequences and stop_token_ids
process_stop_token_ids(request, self.update_stop_seq)
# processing bad_words
bad_words = request.get("bad_words")

View File

@@ -19,3 +19,34 @@ __all__ = [
]
IDS_TYPE_FLAG = {"text": 0, "image": 1, "video": 2, "audio": 3}
from typing import Any, Callable, Dict, List, Tuple
def process_stop_token_ids(
request: Dict[str, Any],
update_stop_seq_fn: Callable[[List[str]], Tuple[List[List[int]], List[int]]],
) -> None:
"""Normalize stop_token_ids to List[List[int]], merge in the tokenized stop
strings from request["stop"], and write stop_token_ids/stop_seqs_len back."""
stop_token_ids_final = []
if request.get("stop_token_ids") is not None:
stop_token_ids = request.get("stop_token_ids")
if isinstance(stop_token_ids, list) and len(stop_token_ids) > 0:
if isinstance(stop_token_ids[0], int):
# List[int] -> List[List[int]]
stop_token_ids_final.extend([[t] for t in stop_token_ids])
elif isinstance(stop_token_ids[0], list):
# Already List[List[int]]
stop_token_ids_final.extend(stop_token_ids)
stop_sequences = request.get("stop", [])
if stop_sequences:
stop_seqs, _ = update_stop_seq_fn(stop_sequences)
stop_token_ids_final.extend(stop_seqs)
# Update request
if stop_token_ids_final:
stop_seqs_len = [len(seq) for seq in stop_token_ids_final]
request["stop_token_ids"] = stop_token_ids_final
request["stop_seqs_len"] = stop_seqs_len

View File

@@ -355,6 +355,7 @@ def post_process_normal(
model_output.step_idx,
model_output.stop_token_ids,
model_output.stop_seqs_len,
model_output.min_tokens,
False,
) # multi ends
elif current_platform.is_maca():
@@ -368,6 +369,7 @@ def post_process_normal(
model_output.step_idx,
model_output.stop_token_ids,
model_output.stop_seqs_len,
model_output.min_tokens,
False,
) # multi ends
else:
@@ -472,6 +474,7 @@ def post_process_specualate(
model_output.stop_token_ids,
model_output.stop_seqs_len,
model_output.eos_token_id,
model_output.min_tokens,
)
speculate_update(
model_output.seq_lens_encoder,

View File

@@ -1779,6 +1779,7 @@ class GPUModelRunner(ModelRunnerBase):
accept_num=(self.share_inputs["accept_num"] if self.speculative_decoding else None),
stop_token_ids=self.share_inputs["stop_seqs"],
stop_seqs_len=self.share_inputs["stop_seqs_len"],
min_tokens=self.share_inputs["min_dec_len"],
prompt_lens=self.share_inputs["prompt_lens"],
)
@@ -1879,6 +1880,7 @@ class GPUModelRunner(ModelRunnerBase):
accept_num=(self.share_inputs["accept_num"] if self.speculative_decoding else None),
stop_token_ids=self.share_inputs["stop_seqs"],
stop_seqs_len=self.share_inputs["stop_seqs_len"],
min_tokens=self.share_inputs["min_dec_len"],
prompt_lens=self.share_inputs["prompt_lens"],
mask_rollback=self.share_inputs["mask_rollback"],
)
@@ -2349,6 +2351,7 @@ class GPUModelRunner(ModelRunnerBase):
accept_num=(self.share_inputs["accept_num"] if self.speculative_decoding else None),
stop_token_ids=self.share_inputs["stop_seqs"],
stop_seqs_len=self.share_inputs["stop_seqs_len"],
min_tokens=self.share_inputs["min_dec_len"],
prompt_lens=self.share_inputs["prompt_lens"],
)
@@ -2454,6 +2457,7 @@ class GPUModelRunner(ModelRunnerBase):
accept_num=(self.share_inputs["accept_num"] if self.speculative_decoding else None),
stop_token_ids=self.share_inputs["stop_seqs"],
stop_seqs_len=self.share_inputs["stop_seqs_len"],
min_tokens=self.share_inputs["min_dec_len"],
prompt_lens=self.share_inputs["prompt_lens"],
mask_rollback=self.share_inputs["mask_rollback"],
prompt_logprobs_list=prompt_logprobs_list,
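These call sites reuse the existing `min_dec_len` buffer as the kernels' `min_tokens` input. Assuming `min_dec_len` holds one minimum-length entry per batch slot (its population is outside this diff), the gate the kernels apply is, per slot:

```
import paddle

# Hypothetical per-slot values; share_inputs["min_dec_len"] is int64, one
# entry per batch slot, and is passed to the stop kernels as min_tokens.
min_dec_len = paddle.to_tensor([[0], [16], [4]], dtype="int64")
step_idx = paddle.to_tensor([[7], [7], [7]], dtype="int64")

# Slots may honor EOS/stop sequences only once step_idx reaches min_dec_len.
can_stop = step_idx >= min_dec_len
print(can_stop.numpy())  # [[ True] [False] [ True]]
```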

View File

@@ -279,6 +279,11 @@ class ModelOutputData:
"""
prompt_logprobs_list: Optional[LogprobsTensors] = None
"""
the minimum tokens that will be generated
"""
min_tokens: paddle.Tensor = None
@dataclass
class ModelRunnerOutput:

View File

@@ -1321,6 +1321,7 @@ class XPUModelRunner(ModelRunnerBase):
accept_num=(self.share_inputs["accept_num"] if self.speculative_decoding else None),
stop_token_ids=self.share_inputs["stop_seqs"],
stop_seqs_len=self.share_inputs["stop_seqs_len"],
min_tokens=self.share_inputs["min_dec_len"],
)
if self.speculative_decoding:
# base model post process

View File

@@ -308,3 +308,33 @@ def test_profile_reset_block_num():
f"Reset total_block_num {actual_value} 与 baseline {baseline} diff需要在5%以内"
f"Allowed range: [{lower_bound:.1f}, {upper_bound:.1f}]"
)
def test_thinking_with_stop_token_ids(api_url, headers):
"""
Test case to verify thinking behavior when stop token ids are provided.
"""
messages = [{"role": "user", "content": "北京天安门在哪里"}]
payload = {
"messages": messages,
"max_tokens": 100,
"temperature": 0.8,
"seed": 1,
"stop_token_ids": [105930],
}
resp = requests.post(api_url, headers=headers, json=payload)
assert resp.status_code == 200, f"Unexpected status code: {resp.status_code}"
try:
response_json = resp.json()
except Exception as e:
assert False, f"Response is not valid JSON: {e}"
content = response_json.get("choices", [{}])[0].get("message", {}).get("content", "")
expected_output = "<think>\n好的,用户问“北京天安门在哪里"
assert content == expected_output, (
f"Unexpected response content.\n" f"Expected: {expected_output!r}\n" f"Actual: {content!r}"
)

View File

@@ -32,6 +32,7 @@ class TestSpeculateSetStopValueMultiSeqs(unittest.TestCase):
stop_seqs,
stop_seqs_len,
end_ids,
min_tokens,
):
accept_tokens_out = accept_tokens.clone()
stop_flags_out = stop_flags.clone()
@@ -45,6 +46,7 @@ class TestSpeculateSetStopValueMultiSeqs(unittest.TestCase):
stop_seqs,
stop_seqs_len,
end_ids,
min_tokens,
)
return {
@@ -88,6 +90,7 @@ class TestSpeculateSetStopValueMultiSeqs(unittest.TestCase):
)
stop_seqs_len = paddle.to_tensor([3, 0], dtype="int32")
end_ids = paddle.to_tensor([-1], dtype="int64")
min_tokens = paddle.to_tensor([0, 0], dtype="int64")
gpu_results = self.run_op(
accept_tokens,
accept_num,
@@ -98,6 +101,7 @@ class TestSpeculateSetStopValueMultiSeqs(unittest.TestCase):
stop_seqs,
stop_seqs_len,
end_ids,
min_tokens,
)
expected_accept_tokens = np.array([[4, 5, -1, 0, 0], [1, 2, 3, 0, 0]])
@@ -121,6 +125,7 @@ class TestSpeculateSetStopValueMultiSeqs(unittest.TestCase):
stop_seqs = paddle.to_tensor([[11, 12, 13], [14, 15, 16]], dtype="int64")
stop_seqs_len = paddle.to_tensor([3, 3], dtype="int32")
end_ids = paddle.to_tensor([-1], dtype="int64")
min_tokens = paddle.to_tensor([0, 0], dtype="int64")
gpu_results = self.run_op(
accept_tokens,
@@ -132,6 +137,7 @@ class TestSpeculateSetStopValueMultiSeqs(unittest.TestCase):
stop_seqs,
stop_seqs_len,
end_ids,
min_tokens,
)
np.testing.assert_array_equal(gpu_results["output_accept_tokens"], accept_tokens.numpy())
@@ -152,6 +158,7 @@ class TestSpeculateSetStopValueMultiSeqs(unittest.TestCase):
)
stop_seqs_len = paddle.to_tensor([3], dtype="int32")
end_ids = paddle.to_tensor([-1], dtype="int64")
min_tokens = paddle.to_tensor([0], dtype="int64")
gpu_results = self.run_op(
accept_tokens,
@@ -163,6 +170,7 @@ class TestSpeculateSetStopValueMultiSeqs(unittest.TestCase):
stop_seqs,
stop_seqs_len,
end_ids,
min_tokens,
)
np.testing.assert_array_equal(gpu_results["output_accept_tokens"], accept_tokens.numpy())
@@ -180,6 +188,7 @@ class TestSpeculateSetStopValueMultiSeqs(unittest.TestCase):
stop_seqs = paddle.to_tensor([[5, 4, 3]], dtype="int64")
stop_seqs_len = paddle.to_tensor([3], dtype="int32")
end_ids = paddle.to_tensor([-1], dtype="int64")
min_tokens = paddle.to_tensor([0], dtype="int64")
gpu_results = self.run_op(
accept_tokens,
@@ -191,11 +200,52 @@ class TestSpeculateSetStopValueMultiSeqs(unittest.TestCase):
stop_seqs,
stop_seqs_len,
end_ids,
min_tokens,
)
np.testing.assert_array_equal(gpu_results["output_accept_tokens"], accept_tokens.numpy())
np.testing.assert_array_equal(gpu_results["output_stop_flags"], stop_flags.numpy())
def test_min_tokens_allows_stop(self):
"""Test that stopping is allowed when step_idx >= min_tokens"""
accept_tokens = paddle.to_tensor(
[[4, 5, 0, 0, 0]],
dtype="int64",
)
accept_num = paddle.to_tensor([3], dtype="int32")
pre_ids = paddle.to_tensor(
[[7, 8, 9, 3, 4, 5]],
dtype="int64",
)
step_idx = paddle.to_tensor([6], dtype="int64")
stop_flags = paddle.to_tensor([False], dtype="bool")
seq_lens = paddle.to_tensor([6], dtype="int32")
stop_seqs = paddle.to_tensor(
[[[3, 4, 5]]],
dtype="int64",
)
stop_seqs_len = paddle.to_tensor([[3]], dtype="int32")
end_ids = paddle.to_tensor([-1], dtype="int64")
min_tokens = paddle.to_tensor([5], dtype="int64") # min_tokens=5, step_idx=6 >= 5
gpu_results = self.run_op(
accept_tokens,
accept_num,
pre_ids,
step_idx,
stop_flags,
seq_lens,
stop_seqs,
stop_seqs_len,
end_ids,
min_tokens,
)
expected_accept_tokens = np.array([[4, 5, -1, 0, 0]])
expected_stop_flags = np.array([True])
np.testing.assert_array_equal(gpu_results["output_accept_tokens"], expected_accept_tokens)
np.testing.assert_array_equal(gpu_results["output_stop_flags"], expected_stop_flags)
if __name__ == "__main__":
unittest.main()

View File

@@ -36,6 +36,7 @@ def test_set_stop_value_multi_ends_with_stop_seq():
stop_seqs_len = paddle.full([2, 5], 10, dtype="int32")
stop_seqs_len[0, 0] = 2
min_tokens = paddle.to_tensor([[0], [0]], dtype="int64")
set_stop_value_multi_ends(
sampled_token_ids,
@@ -47,6 +48,7 @@ def test_set_stop_value_multi_ends_with_stop_seq():
step_idx,
stop_token_ids,
stop_seqs_len,
min_tokens,
False,
)
@@ -54,5 +56,40 @@ def test_set_stop_value_multi_ends_with_stop_seq():
assert sampled_token_ids[0, 0] == 2 # eos token id
def test_min_tokens():
"""Test min_tokens functionality"""
sampled_token_ids = paddle.to_tensor([[2], [100], [200]], dtype="int64")
stop_flags = paddle.to_tensor([[False], [False], [False]], dtype="bool")
seq_lens_this_time = paddle.to_tensor([[1], [1], [1]], dtype="int32")
eos_token_id = paddle.to_tensor([2], dtype="int64")
next_tokens = paddle.to_tensor([[2], [100], [200]], dtype="int64")
pre_ids = paddle.full([3, 100], -1, dtype="int64")
step_idx = paddle.to_tensor([[5], [50], [10]], dtype="int64")
stop_seqs = paddle.full([3, 5, 8], -1, dtype="int64")
stop_seqs_len = paddle.zeros([3, 5], dtype="int32")
min_tokens = paddle.to_tensor([[10], [0], [5]], dtype="int64")
set_stop_value_multi_ends(
sampled_token_ids,
stop_flags,
seq_lens_this_time,
eos_token_id,
next_tokens,
pre_ids,
step_idx,
stop_seqs,
stop_seqs_len,
min_tokens,
False,
)
# Sample 0: step < min_tokens, should not stop even with EOS
assert bool(stop_flags[0, 0]) is False
if __name__ == "__main__":
test_set_stop_value_multi_ends_with_stop_seq()
test_min_tokens()
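For cross-checking `test_min_tokens`, the expected outcome for its three samples can be replayed with the same rule the kernel applies (a sketch; the op itself is authoritative):

```
import numpy as np

# Inputs from test_min_tokens: tokens [2, 100, 200], eos_token_id=[2],
# step_idx [5, 50, 10], min_tokens [10, 0, 5], and no stop sequences.
step_idx = np.array([5, 50, 10])
min_tokens = np.array([10, 0, 5])
is_eos = np.array([True, False, False])

# A sample stops only if it is past its minimum length AND produced an EOS.
should_stop = (step_idx >= min_tokens) & is_eos
print(should_stop)  # [False False False]: sample 0's EOS is suppressed (5 < 10)
```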