diff --git a/custom_ops/gpu_ops/cpp_extensions.cc b/custom_ops/gpu_ops/cpp_extensions.cc index ccbcab01a..45c188205 100644 --- a/custom_ops/gpu_ops/cpp_extensions.cc +++ b/custom_ops/gpu_ops/cpp_extensions.cc @@ -912,6 +912,38 @@ void SaveOutMmsgStatic(const paddle::Tensor& x, int64_t rank_id, bool save_each_rank); +void LimitThinkingContentLengthV1(const paddle::Tensor &next_tokens, + const paddle::Tensor &max_think_lens, + const paddle::Tensor &step_idx, + const paddle::Tensor &limit_think_status, + const int64_t think_end_id); + +void LimitThinkingContentLengthV2(const paddle::Tensor &next_tokens, + const paddle::Tensor &max_think_lens, + const paddle::Tensor &step_idx, + const paddle::Tensor &limit_think_status, + const int64_t think_end_id, + const int64_t line_break_id); + +void SpeculateLimitThinkingContentLengthV1( + const paddle::Tensor& next_tokens, + const paddle::Tensor& max_think_lens, + const paddle::Tensor& step_idx, + const paddle::Tensor& limit_think_status, + const paddle::Tensor& accept_num, + const paddle::Tensor& seq_lens_decoder, + const int64_t think_end_id); + +void SpeculateLimitThinkingContentLengthV2( + const paddle::Tensor& next_tokens, + const paddle::Tensor& max_think_lens, + const paddle::Tensor& step_idx, + const paddle::Tensor& limit_think_status, + const paddle::Tensor& accept_num, + const paddle::Tensor& seq_lens_decoder, + const int64_t think_end_id, + const int64_t line_break_id); + void SpeculateGetLogits(const paddle::Tensor &draft_logits, const paddle::Tensor &next_token_num, const paddle::Tensor &batch_token_num, @@ -1320,6 +1352,14 @@ PYBIND11_MODULE(fastdeploy_ops, m) { m.def("save_output", &SaveOutMmsgStatic, "save_output function"); + m.def("limit_thinking_content_length_v1", &LimitThinkingContentLengthV1, "limit_thinking_content_length_v1 function"); + + m.def("limit_thinking_content_length_v2", &LimitThinkingContentLengthV2, "limit_thinking_content_length_v2 function"); + + 
m.def("speculate_limit_thinking_content_length_v1", &SpeculateLimitThinkingContentLengthV1, "speculate limit thinking content length function"); + + m.def("speculate_limit_thinking_content_length_v2", &SpeculateLimitThinkingContentLengthV2, "speculate limit thinking content length function"); + m.def("speculate_get_logits", &SpeculateGetLogits, "speculate_get_logits function"); m.def("speculate_insert_first_token", &SpeculateInsertFirstToken, "speculate_insert_first_token function"); diff --git a/custom_ops/gpu_ops/limit_thinking_content_length_v1.cu b/custom_ops/gpu_ops/limit_thinking_content_length_v1.cu new file mode 100644 index 000000000..d4c494b53 --- /dev/null +++ b/custom_ops/gpu_ops/limit_thinking_content_length_v1.cu @@ -0,0 +1,88 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "helper.h" +#include "paddle/extension.h" + +__global__ void limit_thinking_content_length_kernel_v1( + int64_t *next_tokens, + const int *max_think_lens, + const int64_t *step_idx, + int *limit_think_status, + const int64_t think_end_id, + const int bs) { + int bid = threadIdx.x; + if (bid >= bs) return; + + // 如果该序列未启用思考功能,则直接返回,默认值为 -1,表示不限制思考长度 + const int max_think_len = max_think_lens[bid]; + if (max_think_len < 0) return; + int current_limit_think_status = limit_think_status[bid]; + // 如果在回复阶段, 且已经触发停止标志, 则直接返回, 无需多余执行. 
+ if (current_limit_think_status == 2) { + return; + } + + int64_t next_token = next_tokens[bid]; + const int64_t step = step_idx[bid]; + + // ======================= 思考阶段控制 ======================= + // 阶段 1: 仍在思考 (status == 0), 检查是否需要强制结束 + if (current_limit_think_status < 1) { + // 当开启思考长度控制时,检查是否超时 + if (step >= max_think_len) { + // 强制将当前token替换为结束思考的token + next_token = think_end_id; + // 将状态推进到 1, 表示 "正在结束思考" + current_limit_think_status = 1; + } + } + // ======================= 思考结束处理 ======================= + // 阶段 2: 检查是否已满足结束思考的条件 (status < 2) + // 这种情况会处理两种场景: + // 1. status == 0: 模型自己生成了 think_end_id + // 2. status == 1: 上一阶段强制注入了 think_end_id + if (current_limit_think_status < 2) { + if (next_token == think_end_id) { + // 确认思考结束,将状态推进到 2 (响应阶段) + current_limit_think_status = 2; + } + } + // 写回更新后的 token + next_tokens[bid] = next_token; + // 更新全局状态 + limit_think_status[bid] = current_limit_think_status; +} + +void LimitThinkingContentLengthV1(const paddle::Tensor &next_tokens, + const paddle::Tensor &max_think_lens, + const paddle::Tensor &step_idx, + const paddle::Tensor &limit_think_status, + const int64_t think_end_id) { + const int batch_size = next_tokens.shape()[0]; + limit_thinking_content_length_kernel_v1<<<1, 1024>>>( + const_cast<int64_t *>(next_tokens.data<int64_t>()), + max_think_lens.data<int>(), + step_idx.data<int64_t>(), + const_cast<int *>(limit_think_status.data<int>()), + think_end_id, + batch_size); +} + +PD_BUILD_OP(limit_thinking_content_length_v1) + .Inputs({"next_tokens", "max_think_lens", "step_idx", "limit_think_status"}) + .Attrs({"think_end_id: int64_t"}) + .Outputs({"next_tokens_out"}) + .SetInplaceMap({{"next_tokens", "next_tokens_out"}}) + .SetKernelFn(PD_KERNEL(LimitThinkingContentLengthV1)); diff --git a/custom_ops/gpu_ops/limit_thinking_content_length_v2.cu b/custom_ops/gpu_ops/limit_thinking_content_length_v2.cu new file mode 100644 index 000000000..a61dec896 --- /dev/null +++ b/custom_ops/gpu_ops/limit_thinking_content_length_v2.cu @@ -0,0 +1,111 @@ +// Copyright
(c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "helper.h" +#include "paddle/extension.h" + +// status == 0: 正常生成阶段 +// status == 1: 替换阶段 +// status == 2: 替换结束阶段 +// status == 3: 思考结束阶段 +__global__ void limit_thinking_content_length_kernel_v2( + int64_t *next_tokens, + const int *max_think_lens, + const int64_t *step_idx, + int *limit_think_status, + const int64_t think_end_id, + const int64_t line_break_id, + const int bs) { + int bid = threadIdx.x; + if (bid >= bs) return; + // 如果该序列未启用思考功能,则直接返回,默认值为 -1,表示不限制思考长度 + const int max_think_len = max_think_lens[bid]; + if (max_think_len < 0) return; + int current_limit_think_status = limit_think_status[bid]; + // 如果在回复阶段, 且已经触发停止标志, 则直接返回, 无需多余执行. 
+ if (current_limit_think_status == 3) { + return; + } + + int64_t next_token = next_tokens[bid]; + const int64_t step = step_idx[bid]; + + // ======================= 思考阶段控制 ======================= + // 阶段 1: 仍在思考 (status == 0), 检查是否需要强制结束 + // 阶段 2: 在替换 (status == 1), 检查是否替换结束 + if (current_limit_think_status <= 1) { + // 当开启思考长度控制时,检查是否超时 + if (step == max_think_len) { + // 强制将当前token替换为结束思考的token + next_token = line_break_id; + current_limit_think_status = 1; + } else if (step == max_think_len + 1) { + // 强制将当前token替换为结束思考的token + next_token = think_end_id; + current_limit_think_status = 1; + } else if (step == max_think_len + 2) { + // 强制将当前token替换为结束思考的token + next_token = line_break_id; + current_limit_think_status = 1; + } else if (step == max_think_len + 3) { + // 强制将当前token替换为结束思考的token + next_token = line_break_id; + // 将状态推进到 1, 表示 "正在结束思考" + current_limit_think_status = 2; + } + } + // ======================= 思考结束处理 ======================= + // 阶段 3: 检查是否已满足结束思考的条件 (status == 0 || status == 2) + // 这种情况会处理两种场景: + // 1. status == 0: 模型可能自己生成了 + // 2. 
status == 2: 上一阶段强制注入了 \n</think>\n\n + if (current_limit_think_status == 0) { + if (next_token == think_end_id) { + // 确认思考结束,将状态推进到 3 (响应阶段) + current_limit_think_status = 3; + } + } + if (current_limit_think_status == 2) { + // 确认思考结束,将状态推进到 3 (响应阶段) + current_limit_think_status = 3; + } + // 写回更新后的 token + next_tokens[bid] = next_token; + // 更新全局状态 + limit_think_status[bid] = current_limit_think_status; +} + +void LimitThinkingContentLengthV2(const paddle::Tensor &next_tokens, + const paddle::Tensor &max_think_lens, + const paddle::Tensor &step_idx, + const paddle::Tensor &limit_think_status, + const int64_t think_end_id, + const int64_t line_break_id) { + const int batch_size = next_tokens.shape()[0]; + limit_thinking_content_length_kernel_v2<<<1, 1024>>>( + const_cast<int64_t *>(next_tokens.data<int64_t>()), + max_think_lens.data<int>(), + step_idx.data<int64_t>(), + const_cast<int *>(limit_think_status.data<int>()), + think_end_id, + line_break_id, + batch_size); +} + +PD_BUILD_OP(limit_thinking_content_length_v2) + .Inputs({"next_tokens", "max_think_lens", "step_idx", "limit_think_status"}) + .Attrs({"think_end_id: int64_t", "line_break_id: int64_t"}) + .Outputs({"next_tokens_out"}) + .SetInplaceMap({{"next_tokens", "next_tokens_out"}}) + .SetKernelFn(PD_KERNEL(LimitThinkingContentLengthV2)); diff --git a/custom_ops/gpu_ops/set_value_by_flags_and_idx.cu b/custom_ops/gpu_ops/set_value_by_flags_and_idx.cu index 9e7a0ce11..391816830 100644 --- a/custom_ops/gpu_ops/set_value_by_flags_and_idx.cu +++ b/custom_ops/gpu_ops/set_value_by_flags_and_idx.cu @@ -35,7 +35,7 @@ __global__ void set_value_by_flag_and_id(const bool *stop_flags, const int seq_len_dec = seq_lens_decoder[tid]; const int seq_len_enc = seq_lens_encoder[tid]; if (seq_len_dec == 0 && seq_len_enc == 0) return; // stopped - if (step_idx[tid] >= 0) { + if (step_idx[tid] > 0) { if (seq_len_enc > 0) { // encoder, get last token accord to seq_lens_encoder pre_ids_all_now[step_idx[tid]] = input_ids_now[seq_len_enc - 1]; } else { // decoedr, get first token
diff --git a/custom_ops/gpu_ops/speculate_decoding/speculate_limit_thinking_content_length_v1.cu b/custom_ops/gpu_ops/speculate_decoding/speculate_limit_thinking_content_length_v1.cu new file mode 100644 index 000000000..96e6a7004 --- /dev/null +++ b/custom_ops/gpu_ops/speculate_decoding/speculate_limit_thinking_content_length_v1.cu @@ -0,0 +1,132 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "helper.h" +#include "paddle/extension.h" + +__global__ void speculate_limit_thinking_content_length_kernel_v1( + int64_t* next_tokens, + const int* max_think_lens, + int64_t* step_idx, + int* limit_think_status, + int* accept_num, + int* seq_lens_decoder, + const int64_t think_end_id, + const int tokens_per_step, + const int bs) { + int bid = threadIdx.x; + if (bid >= bs) return; + + const int original_accept_num = accept_num[bid]; + if (original_accept_num <= 0) return; + + // 如果该序列未启用思考功能,则直接返回,默认值为 -1,表示不限制思考长度 + const int max_think_len = max_think_lens[bid]; + if (max_think_len < 0) return; + int current_limit_think_status = limit_think_status[bid]; + // 如果在回复阶段, 且已经触发停止标志, 则直接返回, 无需多余执行. 
+ if (current_limit_think_status == 3) { + return; + } + + int new_accept_num = original_accept_num; + + const int64_t current_base_step = step_idx[bid] - original_accept_num + 1; + + for (int token_offset = 0; token_offset < original_accept_num; + token_offset++) { + const int token_idx = bid * tokens_per_step + token_offset; + int64_t next_token = next_tokens[token_idx]; + const int64_t current_step = current_base_step + token_offset; + + bool condition_triggered = false; + + // ======================= 思考阶段控制 ======================= + // 阶段 1: 仍在思考 (status == 0), 检查是否需要强制结束 + // 阶段 2: 在替换 (status == 1), 检查是否替换结束 + if (current_limit_think_status < 1) { + // 当开启思考长度控制时,检查是否超时 + if (current_step >= max_think_len) { + // 强制将当前token替换为结束思考的token + next_token = think_end_id; + current_limit_think_status = 1; + condition_triggered = true; // 因为修改了token,需要截断 + } + } + + // ======================= 思考结束处理 ======================= + // 阶段 3: 检查是否已满足结束思考的条件 (status == 0 || status == 2) + // 这种情况会处理两种场景: + // 1. status == 0: 模型可能自己生成了 + // 2. 
status == 2: 上一阶段强制注入了 </think> + if (current_limit_think_status < 2) { + if (next_token == think_end_id) { + // 确认思考结束,将状态推进到 2 (响应阶段) + current_limit_think_status = 2; + } + } + + next_tokens[token_idx] = next_token; + + if (condition_triggered) { + new_accept_num = token_offset + 1; + break; + } + } + + // 更新全局状态 + int discarded_tokens = original_accept_num - new_accept_num; + if (discarded_tokens > 0) { + step_idx[bid] -= discarded_tokens; + seq_lens_decoder[bid] -= discarded_tokens; + } + + accept_num[bid] = new_accept_num; + limit_think_status[bid] = current_limit_think_status; +} + +void SpeculateLimitThinkingContentLengthV1( + const paddle::Tensor& next_tokens, + const paddle::Tensor& max_think_lens, + const paddle::Tensor& step_idx, + const paddle::Tensor& limit_think_status, + const paddle::Tensor& accept_num, + const paddle::Tensor& seq_lens_decoder, + const int64_t think_end_id) { + const int batch_size = next_tokens.shape()[0]; + const int tokens_per_step = next_tokens.shape()[1]; + + speculate_limit_thinking_content_length_kernel_v1<<<1, 1024>>>( + const_cast<int64_t*>(next_tokens.data<int64_t>()), + max_think_lens.data<int>(), + const_cast<int64_t*>(step_idx.data<int64_t>()), + const_cast<int*>(limit_think_status.data<int>()), + const_cast<int*>(accept_num.data<int>()), + const_cast<int*>(seq_lens_decoder.data<int>()), + think_end_id, + tokens_per_step, + batch_size); +} + +PD_BUILD_STATIC_OP(speculate_limit_thinking_content_length_v1) + .Inputs({"next_tokens", + "max_think_lens", + "step_idx", + "limit_think_status", + "accept_num", + "seq_lens_decoder"}) + .Attrs({"think_end_id: int64_t"}) + .Outputs({"next_tokens_out"}) + .SetInplaceMap({{"next_tokens", "next_tokens_out"}}) + .SetKernelFn(PD_KERNEL(SpeculateLimitThinkingContentLengthV1)); diff --git a/custom_ops/gpu_ops/speculate_decoding/speculate_limit_thinking_content_length_v2.cu b/custom_ops/gpu_ops/speculate_decoding/speculate_limit_thinking_content_length_v2.cu new file mode 100644 index 000000000..e885cfb2a --- /dev/null +++
b/custom_ops/gpu_ops/speculate_decoding/speculate_limit_thinking_content_length_v2.cu @@ -0,0 +1,159 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "helper.h" +#include "paddle/extension.h" + +// status == 0: 正常生成阶段 +// status == 1: 替换阶段 +// status == 2: 替换结束阶段 +// status == 3: 思考结束阶段 +__global__ void speculate_limit_thinking_content_length_kernel_v2( + int64_t* next_tokens, + const int* max_think_lens, + int64_t* step_idx, + int* limit_think_status, + int* accept_num, + int* seq_lens_decoder, + const int64_t think_end_id, + const int64_t line_break_id, + const int tokens_per_step, + const int bs) { + int bid = threadIdx.x; + if (bid >= bs) return; + + const int original_accept_num = accept_num[bid]; + if (original_accept_num <= 0) return; + + // 如果该序列未启用思考功能,则直接返回,默认值为 -1,表示不限制思考长度 + const int max_think_len = max_think_lens[bid]; + if (max_think_len < 0) return; + int current_limit_think_status = limit_think_status[bid]; + // 如果在回复阶段, 且已经触发停止标志, 则直接返回, 无需多余执行. 
+ if (current_limit_think_status == 3) { + return; + } + + int new_accept_num = original_accept_num; + + const int64_t current_base_step = step_idx[bid] - original_accept_num + 1; + + for (int token_offset = 0; token_offset < original_accept_num; + token_offset++) { + const int token_idx = bid * tokens_per_step + token_offset; + int64_t next_token = next_tokens[token_idx]; + const int64_t current_step = current_base_step + token_offset; + + bool condition_triggered = false; + + // ======================= 思考阶段控制 ======================= + // 阶段 1: 仍在思考 (status == 0), 检查是否需要强制结束 + // 阶段 2: 在替换 (status == 1), 检查是否替换结束 + if (current_limit_think_status <= 1) { + // 当开启思考长度控制时,检查是否超时 + if (current_step == max_think_len) { + // 强制将当前token替换为结束思考的token + next_token = line_break_id; + current_limit_think_status = 1; + condition_triggered = true; // 因为修改了token,需要截断 + } else if (current_step == max_think_len + 1) { + // 强制将当前token替换为结束思考的token + next_token = think_end_id; + current_limit_think_status = 1; + condition_triggered = true; // 因为修改了token,需要截断 + } else if (current_step == max_think_len + 2) { + // 强制将当前token替换为结束思考的token + next_token = line_break_id; + current_limit_think_status = 1; + condition_triggered = true; // 因为修改了token,需要截断 + } else if (current_step == max_think_len + 3) { + // 强制将当前token替换为结束思考的token + next_token = line_break_id; + // 将状态推进到 1, 表示 "正在结束思考" + current_limit_think_status = 2; + condition_triggered = true; // 因为修改了token,需要截断 + } + } + + // ======================= 思考结束处理 ======================= + // 阶段 3: 检查是否已满足结束思考的条件 (status == 0 || status == 2) + // 这种情况会处理两种场景: + // 1. status == 0: 模型可能自己生成了 + // 2. 
status == 2: 上一阶段强制注入了 \n</think>\n\n + if (current_limit_think_status == 0) { + if (next_token == think_end_id) { + // 确认思考结束,将状态推进到 3 (响应阶段) + current_limit_think_status = 3; + } + } + if (current_limit_think_status == 2) { + // 确认思考结束,将状态推进到 3 (响应阶段) + current_limit_think_status = 3; + } + + next_tokens[token_idx] = next_token; + + if (condition_triggered) { + new_accept_num = token_offset + 1; + break; + } + } + + // 更新全局状态 + int discarded_tokens = original_accept_num - new_accept_num; + if (discarded_tokens > 0) { + step_idx[bid] -= discarded_tokens; + seq_lens_decoder[bid] -= discarded_tokens; + } + + accept_num[bid] = new_accept_num; + limit_think_status[bid] = current_limit_think_status; +} + +void SpeculateLimitThinkingContentLengthV2( + const paddle::Tensor& next_tokens, + const paddle::Tensor& max_think_lens, + const paddle::Tensor& step_idx, + const paddle::Tensor& limit_think_status, + const paddle::Tensor& accept_num, + const paddle::Tensor& seq_lens_decoder, + const int64_t think_end_id, + const int64_t line_break_id) { + const int batch_size = next_tokens.shape()[0]; + const int tokens_per_step = next_tokens.shape()[1]; + + speculate_limit_thinking_content_length_kernel_v2<<<1, 1024>>>( + const_cast<int64_t*>(next_tokens.data<int64_t>()), + max_think_lens.data<int>(), + const_cast<int64_t*>(step_idx.data<int64_t>()), + const_cast<int*>(limit_think_status.data<int>()), + const_cast<int*>(accept_num.data<int>()), + const_cast<int*>(seq_lens_decoder.data<int>()), + think_end_id, + line_break_id, + tokens_per_step, + batch_size); +} + +PD_BUILD_STATIC_OP(speculate_limit_thinking_content_length_v2) + .Inputs({"next_tokens", + "max_think_lens", + "step_idx", + "limit_think_status", + "accept_num", + "seq_lens_decoder"}) + .Attrs({"think_end_id: int64_t", "line_break_id: int64_t"}) + .Outputs({"next_tokens_out"}) + .SetInplaceMap({{"next_tokens", "next_tokens_out"}}) + .SetKernelFn(PD_KERNEL(SpeculateLimitThinkingContentLengthV2)); diff --git a/custom_ops/gpu_ops/speculate_decoding/speculate_set_value_by_flags_and_idx.cu
b/custom_ops/gpu_ops/speculate_decoding/speculate_set_value_by_flags_and_idx.cu index d1ee733fe..f28e83693 100644 --- a/custom_ops/gpu_ops/speculate_decoding/speculate_set_value_by_flags_and_idx.cu +++ b/custom_ops/gpu_ops/speculate_decoding/speculate_set_value_by_flags_and_idx.cu @@ -38,7 +38,7 @@ __global__ void speculate_set_value_by_flag_and_id(int64_t *pre_ids_all, const int seq_len_dec = seq_lens_decoder[tid]; const int seq_len_enc = seq_lens_encoder[tid]; if (seq_len_dec == 0 && seq_len_enc == 0) return; // stoped - if (step_idx[tid] >= 0) { + if (step_idx[tid] > 0) { for (int i = 0; i < accept_num[tid]; i++) { pre_ids_all_now[step_idx[tid] - i] = accept_tokens_now[accept_num[tid] - 1 - i]; diff --git a/custom_ops/setup_ops.py b/custom_ops/setup_ops.py index 52caf4bab..d1d06e9c2 100644 --- a/custom_ops/setup_ops.py +++ b/custom_ops/setup_ops.py @@ -301,6 +301,8 @@ elif paddle.is_compiled_with_cuda(): "gpu_ops/noaux_tc.cu", "gpu_ops/custom_all_reduce/all_reduce.cu", "gpu_ops/merge_prefill_decode_output.cu", + "gpu_ops/limit_thinking_content_length_v1.cu", + "gpu_ops/limit_thinking_content_length_v2.cu", ] # pd_disaggregation diff --git a/docs/usage/environment_variables.md b/docs/usage/environment_variables.md index 8b4930967..dc5d472f5 100644 --- a/docs/usage/environment_variables.md +++ b/docs/usage/environment_variables.md @@ -80,11 +80,13 @@ environment_variables: dict[str, Callable[[], Any]] = { # Whether to use Machete for wint4 dense GEMM. "FD_USE_MACHETE": lambda: os.getenv("FD_USE_MACHETE", "0"), + # Used to truncate the string inserted during thinking when reasoning in a model. 
(</think> for ernie4_5_vl, \n</think>\n\n for ernie_x1) + "FD_LIMIT_THINKING_CONTENT_TRUNCATE_STR": lambda: os.getenv("FD_LIMIT_THINKING_CONTENT_TRUNCATE_STR", "</think>"), + + # Timeout for cache_transfer_manager process exit "FD_CACHE_PROC_EXIT_TIMEOUT": lambda: int(os.getenv("FD_CACHE_PROC_EXIT_TIMEOUT", "600")), # Count for cache_transfer_manager process error "FD_CACHE_PROC_ERROR_COUNT": lambda: int(os.getenv("FD_CACHE_PROC_ERROR_COUNT", "10")), - } ``` diff --git a/docs/zh/usage/environment_variables.md b/docs/zh/usage/environment_variables.md index 41953ffb7..1be359102 100644 --- a/docs/zh/usage/environment_variables.md +++ b/docs/zh/usage/environment_variables.md @@ -80,6 +80,9 @@ environment_variables: dict[str, Callable[[], Any]] = { # 是否使用 Machete 后端的 wint4 GEMM. "FD_USE_MACHETE": lambda: os.getenv("FD_USE_MACHETE", "0"), + # Used to truncate the string inserted during thinking when reasoning in a model. (</think> for ernie4_5_vl, \n</think>\n\n for ernie_x1) + "FD_LIMIT_THINKING_CONTENT_TRUNCATE_STR": lambda: os.getenv("FD_LIMIT_THINKING_CONTENT_TRUNCATE_STR", "</think>"), + + # cache_transfer_manager 进程残留时退出等待超时时间 "FD_CACHE_PROC_EXIT_TIMEOUT": lambda: int(os.getenv("FD_CACHE_PROC_EXIT_TIMEOUT", "600")), diff --git a/fastdeploy/config.py b/fastdeploy/config.py index e2434008b..11f84fbd0 100644 --- a/fastdeploy/config.py +++ b/fastdeploy/config.py @@ -30,7 +30,6 @@ from typing_extensions import assert_never import fastdeploy from fastdeploy import envs from fastdeploy.model_executor.layers.quantization.quant_base import QuantConfigBase -from fastdeploy.multimodal.registry import MultimodalRegistry from fastdeploy.platforms import current_platform from fastdeploy.scheduler import SchedulerConfig from fastdeploy.transformer_utils.config import get_pooling_config @@ -225,26 +224,22 @@ class ModelConfig: self.ori_vocab_size = args.get("ori_vocab_size", self.vocab_size) self.think_end_id = args.get("think_end_id", -1) + self.im_patch_id = args.get("image_patch_id", -1) + self.line_break_id =
args.get("line_break_id", -1) - architectures = self.architectures[0] - - if MultimodalRegistry.contains_model(architectures): - self.enable_mm = True - else: - self.enable_mm = False + self._post_init() + def _post_init(self): self.is_unified_ckpt = check_unified_ckpt(self.model) - - self.override_name_from_config() - self.read_from_env() - self.read_model_config() self.runner_type = self._get_runner_type(self.architectures, self.runner) self.convert_type = self._get_convert_type(self.architectures, self.runner_type, self.convert) - registry = self.registry is_generative_model = registry.is_text_generation_model(self.architectures, self) is_pooling_model = registry.is_pooling_model(self.architectures, self) is_multimodal_model = registry.is_multimodal_model(self.architectures, self) + self.is_reasoning_model = registry.is_reasoning_model(self.architectures, self) + + self.enable_mm = is_multimodal_model if self.runner_type == "generate" and not is_generative_model: if is_multimodal_model: @@ -269,6 +264,9 @@ class ModelConfig: self._architecture = arch self.pooler_config = self._init_pooler_config() + self.override_name_from_config() + self.read_from_env() + self.read_model_config() @property def registry(self): @@ -1282,21 +1280,6 @@ class CacheConfig: logger.info("=============================================================") -class DecodingConfig: - """ - Configuration for decoding - """ - - def __init__( - self, - args, - ): - self.pad_token_id = None - for key, value in args.items(): - if hasattr(self, key): - setattr(self, key, value) - - class CommitConfig: """ Configuration for tracking version information from version.txt @@ -1388,7 +1371,6 @@ class FDConfig: commit_config: CommitConfig = CommitConfig(), scheduler_config: SchedulerConfig = None, device_config: DeviceConfig = None, - decoding_config: DecodingConfig = None, quant_config: QuantConfigBase = None, graph_opt_config: GraphOptimizationConfig = None, plas_attention_config: PlasAttentionConfig = 
None, @@ -1417,7 +1399,6 @@ class FDConfig: self.quant_config: Optional[QuantConfigBase] = quant_config self.graph_opt_config: Optional[GraphOptimizationConfig] = graph_opt_config self.early_stop_config: Optional[EarlyStopConfig] = early_stop_config - self.decoding_config: DecodingConfig = decoding_config # type: ignore self.cache_config: CacheConfig = cache_config # type: ignore self.plas_attention_config: Optional[PlasAttentionConfig] = plas_attention_config self.structured_outputs_config: StructuredOutputsConfig = structured_outputs_config diff --git a/fastdeploy/engine/engine.py b/fastdeploy/engine/engine.py index 632c1f98e..a5c317f80 100644 --- a/fastdeploy/engine/engine.py +++ b/fastdeploy/engine/engine.py @@ -34,7 +34,6 @@ import numpy as np import paddle from tqdm import tqdm -from fastdeploy.config import ErnieArchitectures from fastdeploy.engine.args_utils import EngineArgs from fastdeploy.engine.common_engine import EngineService from fastdeploy.engine.expert_service import start_data_parallel_service @@ -89,11 +88,10 @@ class LLMEngine: self.is_started = False self.input_processor = InputPreprocessor( - cfg.tokenizer, + cfg.model_config, cfg.structured_outputs_config.reasoning_parser, cfg.limit_mm_per_prompt, cfg.mm_processor_kwargs, - cfg.model_config.enable_mm, cfg.tool_parser, ) self.engine = EngineService(cfg) @@ -490,13 +488,13 @@ class LLMEngine: else len(self.data_processor.tokenizer.vocab) ) - is_ernie = ErnieArchitectures.contains_ernie_arch(self.cfg.model_config.architectures) - if is_ernie: - self.cfg.model_config.think_end_id = self.data_processor.tokenizer.get_vocab().get("</think>", -1) - if self.cfg.model_config.think_end_id != -1: - llm_logger.info(f"Get think_end_id {self.cfg.model_config.think_end_id} from vocab.") - else: - llm_logger.info("No </think> token found in vocabulary, the model can not do reasoning.") + think_end_id = self.data_processor.tokenizer.get_vocab().get("</think>", -1) + if think_end_id > 0: + llm_logger.info(f"Get think_end_id
{think_end_id} from vocab.") + else: + llm_logger.info("No </think> token found in vocabulary, the model can not do reasoning.") + image_patch_id = self.data_processor.tokenizer.get_vocab().get("<|IMAGE_PLACEHOLDER|>", -1) + line_break_id = self.data_processor.tokenizer.get_vocab().get("\n", -1) ports = ",".join(self.cfg.parallel_config.engine_worker_queue_port) ips = None @@ -524,7 +522,9 @@ class LLMEngine: f" --data_parallel_size {self.cfg.parallel_config.data_parallel_size}" f" --quantization '{json.dumps(self.cfg.model_config.quantization)}'" f" --ori_vocab_size {ori_vocab_size}" - f" --think_end_id {self.cfg.model_config.think_end_id}" + f" --think_end_id {think_end_id}" + f" --image_patch_id {image_patch_id}" + f" --line_break_id {line_break_id}" f" --speculative_config '{self.cfg.speculative_config.to_json_string()}'" f" --graph_optimization_config '{self.cfg.graph_opt_config.to_json_string()}'" f" --guided_decoding_backend {self.cfg.structured_outputs_config.guided_decoding_backend}" diff --git a/fastdeploy/entrypoints/cli/tokenizer.py b/fastdeploy/entrypoints/cli/tokenizer.py index fe477a7e6..3012fd1f6 100644 --- a/fastdeploy/entrypoints/cli/tokenizer.py +++ b/fastdeploy/entrypoints/cli/tokenizer.py @@ -21,6 +21,7 @@ import json import typing from pathlib import Path +from fastdeploy.config import ModelConfig from fastdeploy.entrypoints.cli.types import CLISubcommand from fastdeploy.input.preprocess import InputPreprocessor @@ -199,7 +200,7 @@ def main(args: argparse.Namespace) -> None: return # 初始化tokenizer - preprocessor = InputPreprocessor(model_name_or_path=args.model_name_or_path, enable_mm=args.enable_mm) + preprocessor = InputPreprocessor(model_config=ModelConfig({"model": args.model_name_or_path})) tokenizer = preprocessor.create_processor().tokenizer # 执行操作 diff --git a/fastdeploy/entrypoints/engine_client.py b/fastdeploy/entrypoints/engine_client.py index 6fe41f221..fc1ee0751 100644 --- a/fastdeploy/entrypoints/engine_client.py +++
b/fastdeploy/entrypoints/engine_client.py @@ -37,7 +37,6 @@ from fastdeploy.inter_communicator import ( ZmqIpcClient, ) from fastdeploy.metrics.work_metrics import work_process_metrics -from fastdeploy.multimodal.registry import MultimodalRegistry from fastdeploy.platforms import current_platform from fastdeploy.utils import ( EngineError, @@ -62,7 +61,6 @@ class EngineClient: port, limit_mm_per_prompt, mm_processor_kwargs, - # enable_mm=False, reasoning_parser=None, data_parallel_size=1, enable_logprob=False, @@ -71,20 +69,15 @@ class EngineClient: enable_prefix_caching=None, splitwise_role=None, ): - architectures = ModelConfig({"model": model_name_or_path}).architectures[0] - if MultimodalRegistry.contains_model(architectures): - self.enable_mm = True - else: - self.enable_mm = False - + model_config = ModelConfig({"model": model_name_or_path}) input_processor = InputPreprocessor( - tokenizer, + model_config, reasoning_parser, limit_mm_per_prompt, mm_processor_kwargs, - self.enable_mm, tool_parser, ) + self.enable_mm = model_config.enable_mm self.enable_logprob = enable_logprob self.reasoning_parser = reasoning_parser self.data_processor = input_processor.create_processor() diff --git a/fastdeploy/entrypoints/openai/api_server.py b/fastdeploy/entrypoints/openai/api_server.py index 766a07b6a..4aeb8308c 100644 --- a/fastdeploy/entrypoints/openai/api_server.py +++ b/fastdeploy/entrypoints/openai/api_server.py @@ -185,7 +185,6 @@ async def lifespan(app: FastAPI): port=int(args.engine_worker_queue_port[args.local_data_parallel_id]), limit_mm_per_prompt=args.limit_mm_per_prompt, mm_processor_kwargs=args.mm_processor_kwargs, - # args.enable_mm, reasoning_parser=args.reasoning_parser, data_parallel_size=args.data_parallel_size, enable_logprob=args.enable_logprob, diff --git a/fastdeploy/envs.py b/fastdeploy/envs.py index af5420e2a..48d605550 100644 --- a/fastdeploy/envs.py +++ b/fastdeploy/envs.py @@ -118,6 +118,8 @@ environment_variables: dict[str, Callable[[], Any]] = 
{ "FD_ENABLE_MODEL_LOAD_CACHE": lambda: bool(int(os.getenv("FD_ENABLE_MODEL_LOAD_CACHE", "0"))), # Whether to clear cpu cache when clearing model weights. "FD_ENABLE_SWAP_SPACE_CLEARING": lambda: int(os.getenv("FD_ENABLE_SWAP_SPACE_CLEARING", "0")), + # Used to truncate the string inserted during thinking when reasoning in a model. (</think> for ernie4_5_vl, \n</think>\n\n for ernie_x1) + "FD_LIMIT_THINKING_CONTENT_TRUNCATE_STR": lambda: os.getenv("FD_LIMIT_THINKING_CONTENT_TRUNCATE_STR", "</think>"), # Timeout for cache_transfer_manager process exit "FD_CACHE_PROC_EXIT_TIMEOUT": lambda: int(os.getenv("FD_CACHE_PROC_EXIT_TIMEOUT", "600")), # Count for cache_transfer_manager process error "FD_CACHE_PROC_ERROR_COUNT": lambda: int(os.getenv("FD_CACHE_PROC_ERROR_COUNT", "10")), diff --git a/fastdeploy/input/preprocess.py b/fastdeploy/input/preprocess.py index 5b8eb3ccd..b3af46c95 100644 --- a/fastdeploy/input/preprocess.py +++ b/fastdeploy/input/preprocess.py @@ -24,7 +24,7 @@ from fastdeploy.reasoning import ReasoningParserManager class InputPreprocessor: """ Args: - model_name_or_path (str): + model_config (ModelConfig): Model name or path to the pretrained model. If a model name is provided, it should be a key in the Hugging Face Transformers' model registry (https://huggingface.co/models). The model will be downloaded from the Hugging Face model hub if necessary. @@ -32,8 +32,6 @@ class InputPreprocessor: reasoning_parser (str, optional): Reasoning parser type. Defaults to None. Flag specifies the reasoning parser to use for extracting reasoning content from the model output - enable_mm (bool, optional): - Whether to use the multi-modal model processor. Defaults to False.
Raises: ValueError: @@ -43,32 +41,20 @@ class InputPreprocessor: def __init__( self, - model_name_or_path: str, + model_config: ModelConfig, reasoning_parser: str = None, limit_mm_per_prompt: Optional[Dict[str, Any]] = None, mm_processor_kwargs: Optional[Dict[str, Any]] = None, - enable_mm: bool = False, tool_parser: str = None, ) -> None: - - self.model_name_or_path = model_name_or_path + self.model_config = model_config + self.model_name_or_path = self.model_config.model self.reasoning_parser = reasoning_parser - self.enable_mm = enable_mm self.limit_mm_per_prompt = limit_mm_per_prompt self.mm_processor_kwargs = mm_processor_kwargs self.tool_parser = tool_parser def create_processor(self): - """ - 创建数据处理器。如果启用了多模态注册表,则使用该表中的模型;否则,使用传递给构造函数的模型名称或路径。 - 返回值:DataProcessor(如果不启用多模态注册表)或MultiModalRegistry.Processor(如果启用多模态注册表)。 - - Args: - 无参数。 - - Returns: - DataProcessor or MultiModalRegistry.Processor (Union[DataProcessor, MultiModalRegistry.Processor]): 数据处理器。 - """ reasoning_parser_obj = None tool_parser_obj = None @@ -77,8 +63,7 @@ class InputPreprocessor: if self.tool_parser: tool_parser_obj = ToolParserManager.get_tool_parser(self.tool_parser) - config = ModelConfig({"model": self.model_name_or_path}) - architectures = config.architectures[0] + architecture = self.model_config.architectures[0] try: from fastdeploy.plugins.input_processor import load_input_processor_plugins @@ -90,8 +75,8 @@ class InputPreprocessor: tool_parser_obj=tool_parser_obj, ) except: - if not self.enable_mm: - if not ErnieArchitectures.contains_ernie_arch(architectures): + if not self.model_config.enable_mm: + if not ErnieArchitectures.contains_ernie_arch(architecture): from fastdeploy.input.text_processor import DataProcessor self.processor = DataProcessor( @@ -108,7 +93,7 @@ class InputPreprocessor: tool_parser_obj=tool_parser_obj, ) else: - if ErnieArchitectures.contains_ernie_arch(architectures): + if ErnieArchitectures.contains_ernie_arch(architecture): from 
fastdeploy.input.ernie4_5_vl_processor import ( Ernie4_5_VLProcessor, ) @@ -124,7 +109,7 @@ class InputPreprocessor: from fastdeploy.input.qwen_vl_processor import QwenVLProcessor self.processor = QwenVLProcessor( - config=config, + config=self.model_config, model_name_or_path=self.model_name_or_path, limit_mm_per_prompt=self.limit_mm_per_prompt, mm_processor_kwargs=self.mm_processor_kwargs, diff --git a/fastdeploy/model_executor/models/interfaces_base.py b/fastdeploy/model_executor/models/interfaces_base.py index b5cea3d23..afc9bb3bb 100644 --- a/fastdeploy/model_executor/models/interfaces_base.py +++ b/fastdeploy/model_executor/models/interfaces_base.py @@ -26,33 +26,10 @@ T = TypeVar("T", default=paddle.Tensor) T_co = TypeVar("T_co", default=paddle.Tensor, covariant=True) -def is_text_generation_model(model_cls: Type[nn.Layer]) -> bool: - from .model_base import ModelForCasualLM - - return issubclass(model_cls, ModelForCasualLM) - - def is_pooling_model(model_cls: Type[nn.Layer]) -> bool: return getattr(model_cls, "is_pooling_model", False) -def is_multimodal_model(class_name: str) -> bool: - multimodal_indicators = ["VL", "Vision", "ConditionalGeneration"] - return any(indicator in class_name for indicator in multimodal_indicators) - - -def determine_model_category(class_name: str): - from fastdeploy.model_executor.models.model_base import ModelCategory - - if any(pattern in class_name for pattern in ["VL", "Vision", "ConditionalGeneration"]): - return ModelCategory.MULTIMODAL - elif any(pattern in class_name for pattern in ["Embedding", "ForSequenceClassification"]): - return ModelCategory.EMBEDDING - elif any(pattern in class_name for pattern in ["Reward"]): - return ModelCategory.REWARD - return ModelCategory.TEXT_GENERATION - - def get_default_pooling_type(model_cls: Type[nn.Layer] = None) -> str: if model_cls is not None: return getattr(model_cls, "default_pooling_type", "LAST") diff --git a/fastdeploy/model_executor/models/model_base.py 
b/fastdeploy/model_executor/models/model_base.py index 28eb6b7da..b81606bae 100644 --- a/fastdeploy/model_executor/models/model_base.py +++ b/fastdeploy/model_executor/models/model_base.py @@ -12,7 +12,7 @@ import importlib from abc import ABC, abstractmethod from dataclasses import dataclass -from enum import Enum +from enum import IntFlag, auto from functools import lru_cache from typing import Dict, List, Optional, Tuple, Type, Union @@ -26,20 +26,15 @@ from fastdeploy.config import ( iter_architecture_defaults, try_match_architecture_defaults, ) -from fastdeploy.model_executor.models.interfaces_base import ( - determine_model_category, - get_default_pooling_type, - is_multimodal_model, - is_pooling_model, - is_text_generation_model, -) +from fastdeploy.model_executor.models.interfaces_base import get_default_pooling_type -class ModelCategory(Enum): - TEXT_GENERATION = "text_generation" - MULTIMODAL = "multimodal" - EMBEDDING = "embedding" - REWARD = "reward" +class ModelCategory(IntFlag): + TEXT_GENERATION = auto() + MULTIMODAL = auto() + EMBEDDING = auto() + REASONING = auto() + REWARD = auto() @dataclass(frozen=True) @@ -48,18 +43,22 @@ class ModelInfo: category: ModelCategory is_text_generation: bool is_multimodal: bool + is_reasoning: bool is_pooling: bool module_path: str default_pooling_type: str @staticmethod - def from_model_cls(model_cls: Type[nn.Layer], module_path: str = "") -> "ModelInfo": + def from_model_cls( + model_cls: Type[nn.Layer], module_path: str = "", category: ModelCategory = None + ) -> "ModelInfo": return ModelInfo( architecture=model_cls.__name__, - category=determine_model_category(model_cls.__name__), - is_text_generation=is_text_generation_model(model_cls), - is_multimodal=is_multimodal_model(model_cls.__name__), - is_pooling=is_pooling_model(model_cls), + category=category, + is_text_generation=ModelCategory.TEXT_GENERATION in category, + is_multimodal=ModelCategory.MULTIMODAL in category, + is_reasoning=ModelCategory.REASONING in 
category, + is_pooling=ModelCategory.EMBEDDING in category, default_pooling_type=get_default_pooling_type(model_cls), module_path=module_path, ) @@ -84,6 +83,7 @@ class LazyRegisteredModel(BaseRegisteredModel): module_name: str module_path: str class_name: str + category: ModelCategory def load_model_cls(self) -> Type[nn.Layer]: try: @@ -95,7 +95,7 @@ class LazyRegisteredModel(BaseRegisteredModel): def inspect_model_cls(self) -> ModelInfo: model_cls = self.load_model_cls() - return ModelInfo.from_model_cls(model_cls, self.module_name) + return ModelInfo.from_model_cls(model_cls, self.module_name, self.category) @lru_cache(maxsize=128) @@ -127,6 +127,7 @@ class ModelRegistry: module_name=model_info["module_name"], module_path=model_info["module_path"], class_name=model_info["class_name"], + category=model_info["category"], ) self.models[arch] = model self._registered_models[arch] = model @@ -317,6 +318,17 @@ class ModelRegistry: return model_info.is_multimodal return False + def is_reasoning_model(self, architectures: Union[str, List[str]], model_config: ModelConfig = None) -> bool: + """Check if it's a reasoning model""" + if isinstance(architectures, str): + architectures = [architectures] + + for arch in architectures: + model_info = self._try_inspect_model_cls(arch) + if model_info is not None: + return model_info.is_reasoning + return False + def is_text_generation_model(self, architectures: Union[str, List[str]], model_config: ModelConfig = None) -> bool: """Check if it's a text generation model""" if isinstance(architectures, str): diff --git a/fastdeploy/model_executor/pre_and_post_process.py b/fastdeploy/model_executor/pre_and_post_process.py index 78e270091..6482b357d 100644 --- a/fastdeploy/model_executor/pre_and_post_process.py +++ b/fastdeploy/model_executor/pre_and_post_process.py @@ -79,15 +79,90 @@ else: step_reschedule, update_inputs_v1, speculate_step_reschedule, + limit_thinking_content_length_v1, + limit_thinking_content_length_v2, + 
speculate_limit_thinking_content_length_v1, + speculate_limit_thinking_content_length_v2, ) - from fastdeploy.output.stream_transfer_data import DecoderState, StreamTransferData from fastdeploy.worker.output import ModelOutputData, ModelRunnerOutput, SamplerOutput DISABLE_RECOVER = envs.FD_DISABLED_RECOVER == "1" +def limit_thinking_content_length( + limit_strategy: str, + sampled_token_ids: paddle.Tensor, + max_think_lens: paddle.Tensor, + step_idx: paddle.Tensor, + limit_think_status: paddle.Tensor, + think_end_id: int, + line_break_id: int = None, +): + if limit_strategy == "</think>": + # for ernie4_5_vl + limit_thinking_content_length_v1( + sampled_token_ids, + max_think_lens, + step_idx, + limit_think_status, + think_end_id, + ) + elif limit_strategy == "\n</think>\n\n": + # for ernie_x1 + assert line_break_id > 0 + limit_thinking_content_length_v2( + sampled_token_ids, + max_think_lens, + step_idx, + limit_think_status, + think_end_id, + line_break_id, + ) + else: + raise NotImplementedError(f"Not support {limit_strategy=} for limit thinking content length.") + + +def speculate_limit_thinking_content_length( + limit_strategy: str, + accept_tokens: paddle.Tensor, + max_think_lens: paddle.Tensor, + step_idx: paddle.Tensor, + limit_think_status: paddle.Tensor, + accept_num: paddle.Tensor, + seq_lens_decoder: paddle.Tensor, + think_end_id: int, + line_break_id: int = None, +): + if limit_strategy == "</think>": + # for ernie4_5_vl + speculate_limit_thinking_content_length_v1( + accept_tokens, + max_think_lens, + step_idx, + limit_think_status, + accept_num, + seq_lens_decoder, + think_end_id, + ) + elif limit_strategy == "\n</think>\n\n": + # for ernie_x1 + assert line_break_id > 0 + speculate_limit_thinking_content_length_v2( + accept_tokens, + max_think_lens, + step_idx, + limit_think_status, + accept_num, + seq_lens_decoder, + think_end_id, + line_break_id, + ) + else: + raise NotImplementedError(f"Not support {limit_strategy=} for limit thinking content length.") + + + def pre_process(
input_ids: paddle.Tensor, seq_lens_this_time: int, @@ -185,46 +260,19 @@ def post_process_normal( save_each_rank: bool = False, skip_save_output: bool = False, async_output_queue: queue.Queue = None, + think_end_id: int = -1, + line_break_id: int = -1, ) -> ModelRunnerOutput: """Post-processing steps after completing a single token generation.""" - # handle vl: - if model_output.think_end_id != -1: - thinking_mask = model_output.enable_thinking[: sampler_output.sampled_token_ids.shape[0]] - exists_think_end = (sampler_output.sampled_token_ids == model_output.think_end_id) & thinking_mask - paddle.assign( - paddle.where( - exists_think_end, - model_output.need_think_end - 1, - model_output.need_think_end, - ), - model_output.need_think_end, - ) - - reasoning_index_update_cond = model_output.need_think_end.cast("bool") & thinking_mask - paddle.assign( - paddle.where( - reasoning_index_update_cond, - model_output.reasoning_index - 1, - model_output.reasoning_index, - ), - model_output.reasoning_index, - ) - - stop_wo_think = ((model_output.reasoning_index == 0)) & (model_output.need_think_end > 0) - - stop_wo_think = stop_wo_think & thinking_mask - sampler_output.sampled_token_ids = paddle.where( - stop_wo_think, - model_output.think_end_id, - sampler_output.sampled_token_ids, - ) - paddle.assign( - paddle.where( - stop_wo_think, - model_output.need_think_end - 1, - model_output.need_think_end, - ), - model_output.need_think_end, + if think_end_id > 0: + limit_thinking_content_length( + limit_strategy=envs.FD_LIMIT_THINKING_CONTENT_TRUNCATE_STR, + sampled_token_ids=sampler_output.sampled_token_ids, + max_think_lens=share_inputs["max_think_lens"], + step_idx=share_inputs["step_idx"], + limit_think_status=share_inputs["limit_think_status"], + think_end_id=think_end_id, + line_break_id=line_break_id, ) # 1. 
Set stop value paddle.assign( @@ -337,10 +385,25 @@ def post_process_normal( def post_process_specualate( sampler_output: SamplerOutput, model_output: ModelOutputData, + share_inputs: Dict[str, paddle.Tensor], save_each_rank: bool = False, skip_save_output: bool = False, + think_end_id: int = -1, + line_break_id: int = -1, ): - """""" + if think_end_id > 0: + speculate_limit_thinking_content_length( + limit_strategy=envs.FD_LIMIT_THINKING_CONTENT_TRUNCATE_STR, + accept_tokens=share_inputs["accept_tokens"], + max_think_lens=share_inputs["max_think_lens"], + step_idx=share_inputs["step_idx"], + limit_think_status=share_inputs["limit_think_status"], + accept_num=share_inputs["accept_num"], + seq_lens_decoder=share_inputs["seq_lens_decoder"], + think_end_id=think_end_id, + line_break_id=line_break_id, + ) + speculate_update( model_output.seq_lens_encoder, model_output.seq_lens_decoder, @@ -403,10 +466,20 @@ def post_process( speculative_decoding: bool = False, skip_save_output: bool = False, async_output_queue: queue.Queue = None, + think_end_id: int = -1, + line_break_id: int = -1, ) -> None: """Post-processing steps after completing a single token generation.""" if speculative_decoding: - post_process_specualate(sampler_output, model_output, save_each_rank, skip_save_output) + post_process_specualate( + sampler_output, + model_output, + share_inputs, + save_each_rank, + skip_save_output, + think_end_id, + line_break_id, + ) else: post_process_normal( sampler_output, @@ -416,6 +489,8 @@ def post_process( save_each_rank, skip_save_output, async_output_queue, + think_end_id, + line_break_id, ) diff --git a/fastdeploy/multimodal/registry.py b/fastdeploy/multimodal/registry.py deleted file mode 100644 index d827c9b80..000000000 --- a/fastdeploy/multimodal/registry.py +++ /dev/null @@ -1,36 +0,0 @@ -""" -# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. 
-# -# Licensed under the Apache License, Version 2.0 (the "License" -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" - - -class MultimodalRegistry: - """ - A registry for multimodal models - """ - - mm_models: set[str] = { - "Ernie4_5_VLMoeForConditionalGeneration", - "Ernie5MoeForCausalLM", - "Qwen2_5_VLForConditionalGeneration", - "Ernie5ForCausalLM", - "Ernie4_5_VLMoeForProcessRewardModel", - } - - @classmethod - def contains_model(cls, name: str) -> bool: - """ - Check if the given name exists in registry. - """ - return name in cls.mm_models diff --git a/fastdeploy/worker/gcu_model_runner.py b/fastdeploy/worker/gcu_model_runner.py index 8a273c995..9f85f82d0 100644 --- a/fastdeploy/worker/gcu_model_runner.py +++ b/fastdeploy/worker/gcu_model_runner.py @@ -383,10 +383,6 @@ class GCUModelRunner(ModelRunnerBase): self.share_inputs["max_dec_len"] = paddle.full( [max_num_seqs, 1], self.model_config.max_model_len, dtype="int64" ) - self.share_inputs["min_length"] = paddle.full([max_num_seqs, 1], self.model_config.min_length, dtype="int64") - self.share_inputs["max_length"] = paddle.full( - [max_num_seqs, 1], self.model_config.max_model_len, dtype="int64" - ) self.seq_lens_this_time_buffer = paddle.full(max_num_seqs, 0, dtype="int32") self.share_inputs["seq_lens_encoder"] = paddle.full([max_num_seqs, 1], 0, dtype="int32") self.share_inputs["seq_lens_decoder"] = paddle.full([max_num_seqs, 1], 0, dtype="int32") @@ -834,10 +830,6 @@ class GCUModelRunner(ModelRunnerBase): ), accept_tokens=(self.share_inputs["accept_tokens"] if 
self.speculative_decoding else None), accept_num=(self.share_inputs["accept_num"] if self.speculative_decoding else None), - enable_thinking=(self.share_inputs["enable_thinking"] if self.enable_mm else None), - think_end_id=(self.model_config.think_end_id if self.enable_mm else -1), - need_think_end=(self.share_inputs["need_think_end"] if self.enable_mm else None), - reasoning_index=(self.share_inputs["reasoning_index"] if self.enable_mm else None), ) post_process( @@ -1062,10 +1054,6 @@ class GCUModelRunner(ModelRunnerBase): ), accept_tokens=(self.share_inputs["accept_tokens"] if self.speculative_decoding else None), accept_num=(self.share_inputs["accept_num"] if self.speculative_decoding else None), - enable_thinking=(self.share_inputs["enable_thinking"] if self.enable_mm else None), - think_end_id=(self.model_config.think_end_id if self.enable_mm else -1), - need_think_end=(self.share_inputs["need_think_end"] if self.enable_mm else None), - reasoning_index=(self.share_inputs["reasoning_index"] if self.enable_mm else None), ) if self.speculative_config.method in ["mtp"] and self.scheduler_config.splitwise_role == "prefill": diff --git a/fastdeploy/worker/gpu_model_runner.py b/fastdeploy/worker/gpu_model_runner.py index 216cc5cff..7fbbb0ab2 100644 --- a/fastdeploy/worker/gpu_model_runner.py +++ b/fastdeploy/worker/gpu_model_runner.py @@ -287,13 +287,10 @@ class GPUModelRunner(ModelRunnerBase): elif request.structural_tag is not None: schemata_key = ("structural_tag", request.structural_tag) - enable_thinking = request.get("enable_thinking", True) - enable_thinking = enable_thinking if enable_thinking is not None else True - return ( self.guided_backend.get_logits_processor( schemata_key=schemata_key, - enable_thinking=enable_thinking, + enable_thinking=True, ), schemata_key, ) @@ -355,22 +352,14 @@ class GPUModelRunner(ModelRunnerBase): position_ids, request.get("max_tokens", 2048) ) - if request.get("enable_thinking", False): + if request.get("enable_thinking", 
False) and request.get("reasoning_max_tokens", None) is not None: # Enable thinking - req_reasoning_max_tokens = request.get("reasoning_max_tokens") - req_max_tokens = request.get("max_tokens") - final_reasoning_tokens = ( - req_reasoning_max_tokens if req_reasoning_max_tokens is not None else req_max_tokens - ) - - self.share_inputs["enable_thinking"][idx : idx + 1] = True - self.share_inputs["need_think_end"][idx : idx + 1, :] = 1 - self.share_inputs["reasoning_index"][idx : idx + 1, :] = final_reasoning_tokens + self.share_inputs["max_think_lens"][idx : idx + 1, :] = request.get("reasoning_max_tokens") + self.share_inputs["limit_think_status"][idx : idx + 1, :] = 0 else: # Disable thinking - self.share_inputs["enable_thinking"][idx : idx + 1] = False - self.share_inputs["need_think_end"][idx : idx + 1, :] = 0 - self.share_inputs["reasoning_index"][idx : idx + 1, :] = 0 + self.share_inputs["max_think_lens"][idx : idx + 1, :] = -1 + self.share_inputs["limit_think_status"][idx : idx + 1, :] = 0 if isinstance(request.prompt_token_ids, np.ndarray): prompt_token_ids = request.prompt_token_ids.tolist() @@ -595,22 +584,14 @@ class GPUModelRunner(ModelRunnerBase): ) self.share_inputs["seq_lens_decoder"][idx : idx + 1] = 0 - if request.get("enable_thinking", False): + if request.get("enable_thinking", False) and request.get("reasoning_max_tokens", None) is not None: # Enable thinking - req_reasoning_max_tokens = request.get("reasoning_max_tokens") - req_max_tokens = request.get("max_tokens") - final_reasoning_tokens = ( - req_reasoning_max_tokens if req_reasoning_max_tokens is not None else req_max_tokens - ) - - self.share_inputs["enable_thinking"][idx : idx + 1] = True - self.share_inputs["need_think_end"][idx : idx + 1, :] = 1 - self.share_inputs["reasoning_index"][idx : idx + 1, :] = final_reasoning_tokens + self.share_inputs["max_think_lens"][idx : idx + 1, :] = request.get("reasoning_max_tokens") + self.share_inputs["limit_think_status"][idx : idx + 1, :] = 0 else: 
# Disable thinking - self.share_inputs["enable_thinking"][idx : idx + 1] = False - self.share_inputs["need_think_end"][idx : idx + 1, :] = 0 - self.share_inputs["reasoning_index"][idx : idx + 1, :] = 0 + self.share_inputs["max_think_lens"][idx : idx + 1, :] = -1 + self.share_inputs["limit_think_status"][idx : idx + 1, :] = 0 def get_attr_from_request(request, attr, default_value=None): res = request.get(attr, default_value) @@ -861,10 +842,6 @@ class GPUModelRunner(ModelRunnerBase): self.share_inputs["max_dec_len"] = paddle.full( [max_num_seqs, 1], self.model_config.max_model_len, dtype="int64" ) - self.share_inputs["min_length"] = paddle.full([max_num_seqs, 1], self.model_config.min_length, dtype="int64") - self.share_inputs["max_length"] = paddle.full( - [max_num_seqs, 1], self.model_config.max_model_len, dtype="int64" - ) self.seq_lens_this_time_buffer = paddle.full([max_num_seqs, 1], 0, dtype="int32") if self.fd_config.parallel_config.enable_expert_parallel: self.share_inputs["seq_lens_this_time"] = paddle.full([max_num_seqs, 1], 0, dtype="int32") @@ -921,19 +898,15 @@ class GPUModelRunner(ModelRunnerBase): self.share_inputs["kv_tile_ids_per_batch"] = None self.share_inputs["kv_num_blocks_x_cpu"] = None # CPU - # Initialize rotary position embedding - tmp_position_ids = paddle.arange(self.model_config.max_model_len).reshape((1, -1)) - # Initialize thinking related buffers - self.share_inputs["need_think_end"] = paddle.full(shape=[max_num_seqs, 1], fill_value=0, dtype="int32") - self.share_inputs["enable_thinking"] = paddle.full(shape=[max_num_seqs, 1], fill_value=False, dtype="bool") - self.share_inputs["reasoning_index"] = paddle.full(shape=[max_num_seqs, 1], fill_value=0, dtype="int32") + self.share_inputs["max_think_lens"] = paddle.full(shape=[max_num_seqs, 1], fill_value=-1, dtype="int32") + self.share_inputs["limit_think_status"] = paddle.full(shape=[max_num_seqs, 1], fill_value=0, dtype="int32") - # TODO(gongshaotian): move to models + # Initialize rotary 
position embedding if not self.enable_mm: self.share_inputs["rope_emb"] = get_rope( rotary_dim=self.model_config.head_dim, - position_ids=tmp_position_ids, + position_ids=paddle.arange(self.model_config.max_model_len).reshape((1, -1)), base=self.model_config.rope_theta, model_config=self.model_config, partial_rotary_factor=self.model_config.partial_rotary_factor, @@ -1496,10 +1469,6 @@ class GPUModelRunner(ModelRunnerBase): ), accept_tokens=(self.share_inputs["accept_tokens"] if self.speculative_decoding else None), accept_num=(self.share_inputs["accept_num"] if self.speculative_decoding else None), - enable_thinking=(self.share_inputs["enable_thinking"] if self.enable_mm else None), - think_end_id=(getattr(self.model_config, "think_end_id", -1) if self.enable_mm else -1), - need_think_end=(self.share_inputs["need_think_end"] if self.enable_mm else None), - reasoning_index=(self.share_inputs["reasoning_index"] if self.enable_mm else None), stop_token_ids=self.share_inputs["stop_seqs"], stop_seqs_len=self.share_inputs["stop_seqs_len"], ) @@ -1512,6 +1481,8 @@ class GPUModelRunner(ModelRunnerBase): speculative_decoding=self.speculative_decoding, skip_save_output=True, async_output_queue=self.async_output_queue, + think_end_id=self.model_config.think_end_id, + line_break_id=self.model_config.line_break_id, ) if self.speculative_decoding: if self.speculative_method == "mtp": @@ -1876,7 +1847,6 @@ class GPUModelRunner(ModelRunnerBase): self.parallel_config.data_parallel_rank * self.parallel_config.tensor_parallel_size, group=self.parallel_config.tp_group, ) - else: sampler_output = self.sampler( logits, @@ -1931,10 +1901,6 @@ class GPUModelRunner(ModelRunnerBase): ), accept_tokens=(self.share_inputs["accept_tokens"] if self.speculative_decoding else None), accept_num=(self.share_inputs["accept_num"] if self.speculative_decoding else None), - enable_thinking=self.share_inputs["enable_thinking"], - think_end_id=self.model_config.think_end_id, - 
need_think_end=self.share_inputs["need_think_end"][:num_running_requests], - reasoning_index=self.share_inputs["reasoning_index"][:num_running_requests], stop_token_ids=self.share_inputs["stop_seqs"], stop_seqs_len=self.share_inputs["stop_seqs_len"], prompt_lens=self.share_inputs["prompt_lens"], @@ -1953,6 +1919,8 @@ class GPUModelRunner(ModelRunnerBase): speculative_decoding=self.speculative_decoding, skip_save_output=skip_save_output, async_output_queue=self.async_output_queue, + think_end_id=self.model_config.think_end_id, + line_break_id=self.model_config.line_break_id, ) if self.guided_backend is not None and sampler_output is not None: self.sampler.post_process(sampler_output.sampled_token_ids, skip_idx_list) diff --git a/fastdeploy/worker/hpu_model_runner.py b/fastdeploy/worker/hpu_model_runner.py index 4923c569f..14556147b 100644 --- a/fastdeploy/worker/hpu_model_runner.py +++ b/fastdeploy/worker/hpu_model_runner.py @@ -592,10 +592,6 @@ class HPUModelRunner(ModelRunnerBase): self.share_inputs["max_dec_len"] = paddle.full( [max_num_seqs, 1], self.model_config.max_model_len, dtype="int64" ) - self.share_inputs["min_length"] = paddle.full([max_num_seqs, 1], self.model_config.min_length, dtype="int64") - self.share_inputs["max_length"] = paddle.full( - [max_num_seqs, 1], self.model_config.max_model_len, dtype="int64" - ) self.share_inputs["seq_lens_this_time"] = paddle.full(max_num_seqs, 0, dtype="int32") self.share_inputs["seq_lens_encoder"] = paddle.full([max_num_seqs, 1], 0, dtype="int32") self.share_inputs["seq_lens_decoder"] = paddle.full([max_num_seqs, 1], 0, dtype="int32") diff --git a/fastdeploy/worker/metax_model_runner.py b/fastdeploy/worker/metax_model_runner.py index f61498871..f6a2c0b15 100644 --- a/fastdeploy/worker/metax_model_runner.py +++ b/fastdeploy/worker/metax_model_runner.py @@ -242,11 +242,6 @@ class MetaxModelRunner(ModelRunnerBase): else: position_ids = None - enable_thinking = request.get("enable_thinking", True) - enable_thinking = 
enable_thinking if enable_thinking is not None else True - self.share_inputs["enable_thinking"][:] = enable_thinking - self.share_inputs["need_think_end"][idx : idx + 1, :] = 1 if enable_thinking else 0 - self.share_inputs["reasoning_index"][idx : idx + 1, :] = request.get("reasoning_max_tokens", 2048) self.share_inputs["rope_emb"][idx : idx + 1, :] = self.prepare_rope3d( position_ids, request.get("max_tokens", 2048) ) @@ -459,11 +454,6 @@ class MetaxModelRunner(ModelRunnerBase): self.share_inputs["prompt_lens"][idx : idx + 1] = length if self.enable_mm: - enable_thinking = request.get("enable_thinking", True) - enable_thinking = enable_thinking if enable_thinking is not None else True - self.share_inputs["enable_thinking"][:] = enable_thinking - self.share_inputs["need_think_end"][idx : idx + 1, :] = 1 if enable_thinking else 0 - self.share_inputs["reasoning_index"][idx : idx + 1, :] = request.get("reasoning_max_tokens", 2048) self.share_inputs["rope_emb"][idx : idx + 1, :] = self.prepare_rope3d( position_ids, request.get("max_tokens", 2048) ) @@ -638,10 +628,6 @@ class MetaxModelRunner(ModelRunnerBase): self.share_inputs["max_dec_len"] = paddle.full( [max_num_seqs, 1], self.model_config.max_model_len, dtype="int64" ) - self.share_inputs["min_length"] = paddle.full([max_num_seqs, 1], self.model_config.min_length, dtype="int64") - self.share_inputs["max_length"] = paddle.full( - [max_num_seqs, 1], self.model_config.max_model_len, dtype="int64" - ) self.seq_lens_this_time_buffer = paddle.full([max_num_seqs, 1], 0, dtype="int32") if self.fd_config.parallel_config.enable_expert_parallel: self.share_inputs["seq_lens_this_time"] = paddle.full([max_num_seqs, 1], 0, dtype="int32") @@ -779,9 +765,6 @@ class MetaxModelRunner(ModelRunnerBase): dtype="float32", ) self.share_inputs["image_features"] = None - self.share_inputs["need_think_end"] = paddle.full(shape=[max_num_seqs, 1], fill_value=0, dtype="int32") - self.share_inputs["enable_thinking"] = paddle.full(shape=[1], 
fill_value=True, dtype="bool") - self.share_inputs["reasoning_index"] = paddle.full(shape=[max_num_seqs, 1], fill_value=0, dtype="int32") def _prepare_inputs(self) -> None: """Prepare the model inputs""" @@ -1133,10 +1116,6 @@ class MetaxModelRunner(ModelRunnerBase): ), accept_tokens=(self.share_inputs["accept_tokens"] if self.speculative_decoding else None), accept_num=(self.share_inputs["accept_num"] if self.speculative_decoding else None), - enable_thinking=(self.share_inputs["enable_thinking"] if self.enable_mm else None), - think_end_id=(self.model_config.think_end_id if self.enable_mm else -1), - need_think_end=(self.share_inputs["need_think_end"] if self.enable_mm else None), - reasoning_index=(self.share_inputs["reasoning_index"] if self.enable_mm else None), stop_token_ids=self.share_inputs["stop_seqs"], stop_seqs_len=self.share_inputs["stop_seqs_len"], ) @@ -1401,10 +1380,6 @@ class MetaxModelRunner(ModelRunnerBase): ), accept_tokens=(self.share_inputs["accept_tokens"] if self.speculative_decoding else None), accept_num=(self.share_inputs["accept_num"] if self.speculative_decoding else None), - enable_thinking=(self.share_inputs["enable_thinking"] if self.enable_mm else None), - think_end_id=(self.model_config.think_end_id if self.enable_mm else -1), - need_think_end=(self.share_inputs["need_think_end"][:num_running_requests] if self.enable_mm else None), - reasoning_index=(self.share_inputs["reasoning_index"][:num_running_requests] if self.enable_mm else None), stop_token_ids=self.share_inputs["stop_seqs"], stop_seqs_len=self.share_inputs["stop_seqs_len"], ) diff --git a/fastdeploy/worker/output.py b/fastdeploy/worker/output.py index 30e6ae295..b4192e882 100644 --- a/fastdeploy/worker/output.py +++ b/fastdeploy/worker/output.py @@ -222,26 +222,6 @@ class ModelOutputData: """ accept_num: paddle.Tensor - """ - vl model enable to think - """ - enable_thinking: paddle.Tensor = None - - """ - vl model think end id - """ - think_end_id: int = -1 - - """ - vl 
model need to think - """ - need_think_end: paddle.Tensor = None - - """ - vl model reasoning index - """ - reasoning_index: paddle.Tensor = None - """ the token ids of stop sequence """ diff --git a/fastdeploy/worker/utils.py b/fastdeploy/worker/utils.py deleted file mode 100644 index 7a2562f24..000000000 --- a/fastdeploy/worker/utils.py +++ /dev/null @@ -1,50 +0,0 @@ -""" -# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License" -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" - -import os -import traceback - - -def check_safetensors_model(model_dir: str): - """ - model_dir : the directory of the model - Check whether the model is safetensors format - """ - model_files = list() - all_files = os.listdir(model_dir) - for x in all_files: - if x.startswith("model") and x.endswith(".safetensors"): - model_files.append(x) - - is_safetensors = len(model_files) > 0 - if not is_safetensors: - return False - - if len(model_files) == 1 and model_files[0] == "model.safetensors": - return True - try: - # check all the file exists - safetensors_num = int(model_files[0].strip(".safetensors").split("-")[-1]) - flags = [0] * safetensors_num - for x in model_files: - current_index = int(x.strip(".safetensors").split("-")[1]) - flags[current_index - 1] = 1 - assert ( - sum(flags) == safetensors_num - ), f"Number of safetensor files should be {len(model_files)}, but now it's {sum(flags)}" - except Exception as e: - raise Exception(f"Failed to check unified checkpoint, details: 
{e}, {str(traceback.format_exc())}.") - return is_safetensors diff --git a/fastdeploy/worker/worker_process.py b/fastdeploy/worker/worker_process.py index a3e961cd1..f9cde4b1b 100644 --- a/fastdeploy/worker/worker_process.py +++ b/fastdeploy/worker/worker_process.py @@ -28,7 +28,6 @@ from paddle.distributed import fleet from fastdeploy import envs from fastdeploy.config import ( CacheConfig, - DecodingConfig, DeviceConfig, EarlyStopConfig, ErnieArchitectures, @@ -41,7 +40,6 @@ from fastdeploy.config import ( SpeculativeConfig, StructuredOutputsConfig, ) -from fastdeploy.input.ernie4_5_tokenizer import Ernie4_5Tokenizer from fastdeploy.inter_communicator import EngineWorkerQueue as TaskQueue from fastdeploy.inter_communicator import ExistTaskStatus, IPCSignal, ModelWeightsStatus from fastdeploy.model_executor.layers.quantization import parse_quant_config @@ -117,25 +115,9 @@ def init_distributed_environment(seed: int = 20) -> Tuple[int, int]: def update_fd_config_for_mm(fd_config: FDConfig) -> None: architectures = fd_config.model_config.architectures if fd_config.model_config.enable_mm and ErnieArchitectures.contains_ernie_arch(architectures): - tokenizer = Ernie4_5Tokenizer.from_pretrained( - fd_config.model_config.model, - model_max_length=fd_config.model_config.max_model_len, - padding_side="right", - use_fast=False, - ) - tokenizer.ignored_index = -100 - if tokenizer.pad_token is None: - tokenizer.pad_token = tokenizer.unk_token - fd_config.model_config.tensor_parallel_degree = fd_config.parallel_config.tensor_parallel_size fd_config.model_config.tensor_parallel_rank = fd_config.parallel_config.tensor_parallel_rank - vision_config = fd_config.model_config.vision_config - vision_config.dtype = fd_config.model_config.dtype - # vision_config.tensor_parallel_degree = fd_config.parallel_config.tensor_parallel_size - # vision_config.tensor_parallel_rank = fd_config.parallel_config.tensor_parallel_rank - fd_config.model_config.im_patch_id = 
tokenizer.get_vocab()["<|IMAGE_PLACEHOLDER|>"] - fd_config.model_config.think_end_id = tokenizer.get_vocab()[""] - fd_config.model_config.sequence_parallel = fd_config.parallel_config.sequence_parallel + fd_config.model_config.vision_config.dtype = fd_config.model_config.dtype class PaddleDisWorkerProc: @@ -577,6 +559,8 @@ def parse_args(): ) parser.add_argument("--ori_vocab_size", type=int, default=None) parser.add_argument("--think_end_id", type=int, default=-1) + parser.add_argument("--image_patch_id", type=int, default=-1) + parser.add_argument("--line_break_id", type=int, default=-1) parser.add_argument( "--quantization", @@ -707,7 +691,6 @@ def initialize_fd_config(args, ranks: int = 1, local_rank: int = 0) -> FDConfig: paddle.set_default_dtype(args.dtype) model_config = ModelConfig(vars(args)) device_config = DeviceConfig(vars(args)) - decoding_config = DecodingConfig(vars(args)) speculative_config = SpeculativeConfig(args.speculative_config) parallel_config = ParallelConfig(vars(args)) cache_config = CacheConfig(vars(args)) @@ -808,7 +791,6 @@ def initialize_fd_config(args, ranks: int = 1, local_rank: int = 0) -> FDConfig: speculative_config=speculative_config, device_config=device_config, load_config=load_config, - decoding_config=decoding_config, quant_config=quant_config, graph_opt_config=graph_opt_config, early_stop_config=early_stop_config, diff --git a/fastdeploy/worker/xpu_model_runner.py b/fastdeploy/worker/xpu_model_runner.py index 5b1a886c0..0f05f2dc6 100644 --- a/fastdeploy/worker/xpu_model_runner.py +++ b/fastdeploy/worker/xpu_model_runner.py @@ -193,45 +193,6 @@ def xpu_post_process( update_inputs, ) - # handle vl: - if model_output.enable_thinking: - exists_think_end = sampled_token_ids == model_output.think_end_id - paddle.assign( - paddle.where( - exists_think_end, - model_output.need_think_end - 1, - model_output.need_think_end, - ), - model_output.need_think_end, - ) - - paddle.assign( - paddle.where( - 
model_output.need_think_end.cast("bool"), - model_output.reasoning_index - 1, - model_output.reasoning_index, - ), - model_output.reasoning_index, - ) - - stop_wo_think = ( - (sampled_token_ids == model_output.eos_token_id.T).any(axis=1, keepdim=True) - | (model_output.reasoning_index == 0) - ) & (model_output.need_think_end > 0) - sampled_token_ids = paddle.where( - stop_wo_think, - model_output.think_end_id, - sampled_token_ids, - ) - paddle.assign( - paddle.where( - stop_wo_think, - model_output.need_think_end - 1, - model_output.need_think_end, - ), - model_output.need_think_end, - ) - # 1. Set stop value paddle.assign( paddle.where( @@ -466,11 +427,6 @@ class XPUModelRunner(ModelRunnerBase): else: position_ids = None - enable_thinking = request.get("enable_thinking", True) - enable_thinking = enable_thinking if enable_thinking is not None else True - self.share_inputs["enable_thinking"][:] = enable_thinking - self.share_inputs["need_think_end"][idx : idx + 1, :] = 1 if enable_thinking else 0 - self.share_inputs["reasoning_index"][idx : idx + 1, :] = request.get("reasoning_max_tokens", 2048) self.share_inputs["rope_emb"][idx : idx + 1, :] = self.prepare_rope3d( position_ids, request.get("max_tokens", 2048) ) @@ -605,11 +561,6 @@ class XPUModelRunner(ModelRunnerBase): self.share_inputs["prompt_lens"][idx : idx + 1] = length if self.enable_mm: - enable_thinking = request.get("enable_thinking", True) - enable_thinking = enable_thinking if enable_thinking is not None else True - self.share_inputs["enable_thinking"][:] = enable_thinking - self.share_inputs["need_think_end"][idx : idx + 1, :] = 1 if enable_thinking else 0 - self.share_inputs["reasoning_index"][idx : idx + 1, :] = request.get("reasoning_max_tokens", 2048) self.share_inputs["rope_emb"][idx : idx + 1, :] = self.prepare_rope3d( position_ids, request.get("max_tokens", 2048) ) @@ -730,10 +681,6 @@ class XPUModelRunner(ModelRunnerBase): self.share_inputs["max_dec_len"] = paddle.full( [max_num_seqs, 1], 
self.model_config.max_model_len, dtype="int64" ) - self.share_inputs["min_length"] = paddle.full([max_num_seqs, 1], self.model_config.min_length, dtype="int64") - self.share_inputs["max_length"] = paddle.full( - [max_num_seqs, 1], self.model_config.max_model_len, dtype="int64" - ) self.share_inputs["seq_lens_this_time"] = paddle.full(max_num_seqs, 0, dtype="int32") self.share_inputs["seq_lens_encoder"] = paddle.full([max_num_seqs, 1], 0, dtype="int32") self.share_inputs["seq_lens_decoder"] = paddle.full([max_num_seqs, 1], 0, dtype="int32") @@ -824,9 +771,6 @@ class XPUModelRunner(ModelRunnerBase): dtype="float32", ) self.share_inputs["image_features"] = None - self.share_inputs["need_think_end"] = paddle.full(shape=[max_num_seqs, 1], fill_value=0, dtype="int32") - self.share_inputs["enable_thinking"] = paddle.full(shape=[1], fill_value=True, dtype="bool") - self.share_inputs["reasoning_index"] = paddle.full(shape=[max_num_seqs, 1], fill_value=0, dtype="int32") def _prepare_inputs(self, is_dummy_run=False) -> None: """Prepare the model inputs""" @@ -1158,10 +1102,6 @@ class XPUModelRunner(ModelRunnerBase): actual_draft_token_num=None, accept_tokens=None, accept_num=None, - enable_thinking=(self.share_inputs["enable_thinking"] if self.enable_mm else None), - think_end_id=(self.model_config.think_end_id if self.enable_mm else -1), - need_think_end=(self.share_inputs["need_think_end"] if self.enable_mm else None), - reasoning_index=(self.share_inputs["reasoning_index"] if self.enable_mm else None), stop_token_ids=self.share_inputs["stop_seqs"], stop_seqs_len=self.share_inputs["stop_seqs_len"], ) diff --git a/tests/entrypoints/cli/test_tokenizer_cli.py b/tests/entrypoints/cli/test_tokenizer_cli.py deleted file mode 100644 index 2bfc2f747..000000000 --- a/tests/entrypoints/cli/test_tokenizer_cli.py +++ /dev/null @@ -1,590 +0,0 @@ -""" -Test cases for tokenizer CLI -""" - -import argparse -import json -import os -import tempfile -import unittest -from unittest.mock import 
MagicMock, PropertyMock, patch - - -class MockCLISubcommand: - """模拟CLISubcommand基类""" - - pass - - -class MockInputPreprocessor: - """模拟InputPreprocessor类""" - - def __init__(self, model_name_or_path): - self.model_name_or_path = model_name_or_path - - def create_processor(self): - mock_processor = MagicMock() - mock_processor.tokenizer = MagicMock() - return mock_processor - - -# 导入被测试代码,使用模拟的依赖 -with patch("fastdeploy.entrypoints.cli.types.CLISubcommand", MockCLISubcommand): - with patch("fastdeploy.input.preprocess.InputPreprocessor", MockInputPreprocessor): - # 这里直接包含被测试的代码内容 - from fastdeploy.entrypoints.cli.tokenizer import ( - TokenizerSubcommand, - cmd_init, - export_vocabulary, - get_tokenizer_info, - get_vocab_dict, - get_vocab_size, - main, - ) - - -class TestTokenizerSubcommand(unittest.TestCase): - """测试TokenizerSubcommand类""" - - def test_name_attribute(self): - self.assertEqual(TokenizerSubcommand.name, "tokenizer") - - def test_subparser_init(self): - subcommand = TokenizerSubcommand() - mock_subparsers = MagicMock() - mock_parser = MagicMock() - mock_subparsers.add_parser.return_value = mock_parser - - parser = subcommand.subparser_init(mock_subparsers) - - # 验证解析器创建 - mock_subparsers.add_parser.assert_called_once_with( - name="tokenizer", - help="Start the FastDeploy Tokenizer Server.", - description="Start the FastDeploy Tokenizer Server.", - usage="fastdeploy tokenizer [--encode/-e TEXT] [--decode/-d TEXT]", - ) - self.assertEqual(parser, mock_parser) - - # 验证参数添加(检查调用次数) - self.assertGreater(mock_parser.add_argument.call_count, 0) - - def test_cmd_method(self): - subcommand = TokenizerSubcommand() - args = argparse.Namespace() - - with patch("fastdeploy.entrypoints.cli.tokenizer.main") as mock_main: - subcommand.cmd(args) - mock_main.assert_called_once_with(args) - - -class TestCmdInit(unittest.TestCase): - """测试cmd_init函数""" - - def test_cmd_init_returns_list(self): - result = cmd_init() - self.assertIsInstance(result, list) - 
self.assertEqual(len(result), 1) - self.assertIsInstance(result[0], TokenizerSubcommand) - - -class TestGetVocabSize(unittest.TestCase): - """测试get_vocab_size函数""" - - def test_with_vocab_size_attribute(self): - mock_tokenizer = MagicMock() - # 使用PropertyMock来正确模拟属性 - type(mock_tokenizer).vocab_size = PropertyMock(return_value=1000) - result = get_vocab_size(mock_tokenizer) - self.assertEqual(result, 1000) - - def test_with_get_vocab_size_method(self): - mock_tokenizer = MagicMock() - # 确保vocab_size属性不存在,让代码使用get_vocab_size方法 - delattr(mock_tokenizer, "vocab_size") - mock_tokenizer.get_vocab_size.return_value = 2000 - result = get_vocab_size(mock_tokenizer) - self.assertEqual(result, 2000) - - def test_with_no_methods_available(self): - mock_tokenizer = MagicMock() - # 移除可能的方法 - delattr(mock_tokenizer, "vocab_size") - delattr(mock_tokenizer, "get_vocab_size") - result = get_vocab_size(mock_tokenizer) - self.assertEqual(result, 100295) # 默认值 - - def test_exception_handling(self): - mock_tokenizer = MagicMock() - # 模拟两个方法都抛出异常 - type(mock_tokenizer).vocab_size = PropertyMock(side_effect=Exception("Error")) - mock_tokenizer.get_vocab_size.side_effect = Exception("Error") - result = get_vocab_size(mock_tokenizer) - self.assertEqual(result, 0) # 默认值 - - -class TestGetTokenizerInfo(unittest.TestCase): - """测试get_tokenizer_info函数""" - - def setUp(self): - self.mock_tokenizer = MagicMock() - type(self.mock_tokenizer).vocab_size = PropertyMock(return_value=1000) - type(self.mock_tokenizer).name_or_path = PropertyMock(return_value="test/model") - type(self.mock_tokenizer).model_max_length = PropertyMock(return_value=512) - - # 特殊token - type(self.mock_tokenizer).bos_token = PropertyMock(return_value="") - type(self.mock_tokenizer).eos_token = PropertyMock(return_value="") - type(self.mock_tokenizer).unk_token = PropertyMock(return_value="") - type(self.mock_tokenizer).sep_token = PropertyMock(return_value="") - type(self.mock_tokenizer).pad_token = 
PropertyMock(return_value="") - type(self.mock_tokenizer).cls_token = PropertyMock(return_value="") - type(self.mock_tokenizer).mask_token = PropertyMock(return_value="") - - # 特殊token ID - type(self.mock_tokenizer).bos_token_id = PropertyMock(return_value=1) - type(self.mock_tokenizer).eos_token_id = PropertyMock(return_value=2) - type(self.mock_tokenizer).unk_token_id = PropertyMock(return_value=3) - type(self.mock_tokenizer).sep_token_id = PropertyMock(return_value=4) - type(self.mock_tokenizer).pad_token_id = PropertyMock(return_value=0) - type(self.mock_tokenizer).cls_token_id = PropertyMock(return_value=5) - type(self.mock_tokenizer).mask_token_id = PropertyMock(return_value=6) - - def test_normal_case(self): - info = get_tokenizer_info(self.mock_tokenizer) - - self.assertEqual(info["vocab_size"], 1000) - self.assertEqual(info["model_name"], "test/model") - self.assertEqual(info["tokenizer_type"], "MagicMock") - self.assertEqual(info["model_max_length"], 512) - - # 检查特殊token - self.assertEqual(info["special_tokens"]["bos_token"], "") - self.assertEqual(info["special_token_ids"]["bos_token_id"], 1) - - def test_exception_handling(self): - # 模拟在获取属性时抛出异常 - with patch("fastdeploy.entrypoints.cli.tokenizer.get_vocab_size", side_effect=Exception("Test error")): - info = get_tokenizer_info(self.mock_tokenizer) - self.assertIn("error", info) - self.assertIn("Test error", info["error"]) - - -class TestGetVocabDict(unittest.TestCase): - """测试get_vocab_dict函数""" - - def test_vocab_attribute(self): - mock_tokenizer = MagicMock() - mock_tokenizer.vocab = {"hello": 1, "world": 2} - result = get_vocab_dict(mock_tokenizer) - self.assertEqual(result, {"hello": 1, "world": 2}) - - def test_get_vocab_method(self): - mock_tokenizer = MagicMock() - # 确保vocab属性不存在,让代码使用get_vocab方法 - delattr(mock_tokenizer, "vocab") - mock_tokenizer.get_vocab.return_value = {"a": 1, "b": 2} - result = get_vocab_dict(mock_tokenizer) - self.assertEqual(result, {"a": 1, "b": 2}) - - def 
test_tokenizer_vocab(self): - mock_tokenizer = MagicMock() - # 确保vocab和get_vocab都不存在 - delattr(mock_tokenizer, "vocab") - delattr(mock_tokenizer, "get_vocab") - - mock_inner_tokenizer = MagicMock() - mock_inner_tokenizer.vocab = {"x": 1} - mock_tokenizer.tokenizer = mock_inner_tokenizer - - result = get_vocab_dict(mock_tokenizer) - self.assertEqual(result, {"x": 1}) - - def test_encoder_attribute(self): - mock_tokenizer = MagicMock() - # 确保其他属性都不存在 - delattr(mock_tokenizer, "vocab") - delattr(mock_tokenizer, "get_vocab") - delattr(mock_tokenizer, "tokenizer") - - mock_tokenizer.encoder = {"token": 0} - result = get_vocab_dict(mock_tokenizer) - self.assertEqual(result, {"token": 0}) - - def test_no_vocab_available(self): - mock_tokenizer = MagicMock() - # 移除所有可能的属性 - delattr(mock_tokenizer, "vocab") - delattr(mock_tokenizer, "get_vocab") - delattr(mock_tokenizer, "tokenizer") - delattr(mock_tokenizer, "encoder") - - result = get_vocab_dict(mock_tokenizer) - self.assertEqual(result, {}) - - def test_exception_handling(self): - mock_tokenizer = MagicMock() - # 模拟所有方法都抛出异常 - mock_tokenizer.vocab = {"a": 1} - mock_tokenizer.get_vocab.side_effect = Exception("Error") - result = get_vocab_dict(mock_tokenizer) - self.assertEqual(result, {"a": 1}) - - -class TestExportVocabulary(unittest.TestCase): - """测试export_vocabulary函数""" - - def setUp(self): - self.mock_tokenizer = MagicMock() - self.mock_tokenizer.vocab = {"hello": 1, "world": 2, "test": 3} - - def test_export_json_format(self): - with tempfile.TemporaryDirectory() as temp_dir: - file_path = os.path.join(temp_dir, "vocab.json") - - with patch("builtins.print") as mock_print: - export_vocabulary(self.mock_tokenizer, file_path) - - # 验证文件内容 - with open(file_path, "r", encoding="utf-8") as f: - content = json.load(f) - self.assertEqual(content, {"hello": 1, "world": 2, "test": 3}) - - # 验证打印输出 - mock_print.assert_any_call(f"Vocabulary exported to: {file_path}") - mock_print.assert_any_call("Total tokens: 3") - - def 
test_export_text_format(self): - with tempfile.TemporaryDirectory() as temp_dir: - file_path = os.path.join(temp_dir, "vocab.txt") - - with patch("builtins.print"): - export_vocabulary(self.mock_tokenizer, file_path) - - # 验证文件内容 - with open(file_path, "r", encoding="utf-8") as f: - lines = f.readlines() - - self.assertEqual(len(lines), 3) - # 检查排序和格式 - 注意repr会添加引号 - self.assertIn("1\t'hello'", lines[0]) - self.assertIn("2\t'world'", lines[1]) - self.assertIn("3\t'test'", lines[2]) - - def test_empty_vocabulary(self): - mock_tokenizer = MagicMock() - mock_tokenizer.vocab = {} - - with tempfile.TemporaryDirectory() as temp_dir: - file_path = os.path.join(temp_dir, "vocab.json") - - with patch("builtins.print") as mock_print: - export_vocabulary(mock_tokenizer, file_path) - mock_print.assert_any_call("Warning: Could not retrieve vocabulary from tokenizer") - - def test_directory_creation(self): - with tempfile.TemporaryDirectory() as temp_dir: - file_path = os.path.join(temp_dir, "newdir", "vocab.json") - - with patch("builtins.print"): - export_vocabulary(self.mock_tokenizer, file_path) - - # 验证目录被创建 - self.assertTrue(os.path.exists(os.path.dirname(file_path))) - - def test_exception_handling(self): - with patch("pathlib.Path.mkdir", side_effect=Exception("Permission denied")): - with patch("builtins.print") as mock_print: - export_vocabulary(self.mock_tokenizer, "/invalid/path/vocab.json") - mock_print.assert_any_call("Error exporting vocabulary: Permission denied") - - -class TestMainFunction(unittest.TestCase): - """测试main函数""" - - def setUp(self): - self.mock_tokenizer = MagicMock() - self.mock_preprocessor = MagicMock() - self.mock_preprocessor.create_processor.return_value.tokenizer = self.mock_tokenizer - - def test_no_arguments(self): - args = argparse.Namespace( - model_name_or_path="test/model", - encode=None, - decode=None, - vocab_size=False, - info=False, - vocab_export=None, - enable_mm=False, - ) - - with patch("builtins.print") as mock_print: - 
main(args) - mock_print.assert_called_with( - "请至少指定一个参数:--encode, --decode, --vocab-size, --info, --export-vocab" - ) - - def test_encode_operation(self): - self.mock_tokenizer.encode.return_value = [101, 102, 103] - - args = argparse.Namespace( - model_name_or_path="test/model", - encode="hello world", - decode=None, - vocab_size=False, - info=False, - vocab_export=None, - enable_mm=False, - ) - - with patch("fastdeploy.entrypoints.cli.tokenizer.InputPreprocessor", return_value=self.mock_preprocessor): - with patch("builtins.print") as mock_print: - main(args) - - self.mock_tokenizer.encode.assert_called_once_with("hello world") - mock_print.assert_any_call("Input text: hello world") - mock_print.assert_any_call("Encoded tokens: [101, 102, 103]") - - def test_decode_operation_list_string(self): - self.mock_tokenizer.decode.return_value = "decoded text" - - args = argparse.Namespace( - model_name_or_path="test/model", - encode=None, - decode="[1,2,3]", - vocab_size=False, - info=False, - vocab_export=None, - enable_mm=False, - ) - - with patch("fastdeploy.entrypoints.cli.tokenizer.InputPreprocessor", return_value=self.mock_preprocessor): - with patch("builtins.print") as mock_print: - main(args) - - self.mock_tokenizer.decode.assert_called_once_with([1, 2, 3]) - mock_print.assert_any_call("Decoded text: decoded text") - - def test_decode_operation_comma_string(self): - self.mock_tokenizer.decode.return_value = "decoded text" - - args = argparse.Namespace( - model_name_or_path="test/model", - encode=None, - decode="1,2,3", - vocab_size=False, - info=False, - vocab_export=None, - enable_mm=False, - ) - - with patch("fastdeploy.entrypoints.cli.tokenizer.InputPreprocessor", return_value=self.mock_preprocessor): - with patch("builtins.print"): - main(args) - - self.mock_tokenizer.decode.assert_called_once_with([1, 2, 3]) - - def test_decode_operation_already_list(self): - self.mock_tokenizer.decode.return_value = "decoded text" - - args = argparse.Namespace( - 
model_name_or_path="test/model", - encode=None, - decode=[1, 2, 3], - vocab_size=False, - info=False, - vocab_export=None, - enable_mm=False, - ) - - with patch("fastdeploy.entrypoints.cli.tokenizer.InputPreprocessor", return_value=self.mock_preprocessor): - with patch("builtins.print"): - main(args) - self.mock_tokenizer.decode.assert_called_once_with([1, 2, 3]) - - def test_decode_exception_handling(self): - args = argparse.Namespace( - model_name_or_path="test/model", - encode=None, - decode="invalid[1,2", # 无效的字符串 - vocab_size=False, - info=False, - vocab_export=None, - enable_mm=False, - ) - - with patch("fastdeploy.entrypoints.cli.tokenizer.InputPreprocessor", return_value=self.mock_preprocessor): - with patch("builtins.print") as mock_print: - main(args) - # 检查是否有包含"Error decoding tokens:"的打印调用 - error_calls = [ - call for call in mock_print.call_args_list if call[0] and "Error decoding tokens:" in call[0][0] - ] - self.assertTrue(len(error_calls) > 0) - - def test_vocab_size_operation(self): - args = argparse.Namespace( - model_name_or_path="test/model", - encode=None, - decode=None, - vocab_size=True, - info=False, - vocab_export=None, - enable_mm=False, - ) - - with patch("fastdeploy.entrypoints.cli.tokenizer.InputPreprocessor", return_value=self.mock_preprocessor): - with patch("fastdeploy.entrypoints.cli.tokenizer.get_vocab_size", return_value=1000) as mock_get_size: - with patch("builtins.print") as mock_print: - main(args) - - mock_get_size.assert_called_once_with(self.mock_tokenizer) - mock_print.assert_any_call("Vocabulary size: 1000") - - def test_info_operation(self): - mock_info = {"vocab_size": 1000, "model_name": "test"} - - args = argparse.Namespace( - model_name_or_path="test/model", - encode=None, - decode=None, - vocab_size=False, - info=True, - vocab_export=None, - enable_mm=False, - ) - - with patch("fastdeploy.entrypoints.cli.tokenizer.InputPreprocessor", return_value=self.mock_preprocessor): - with patch( - 
"fastdeploy.entrypoints.cli.tokenizer.get_tokenizer_info", return_value=mock_info - ) as mock_get_info: - with patch("builtins.print") as mock_print: - main(args) - - mock_get_info.assert_called_once_with(self.mock_tokenizer) - mock_print.assert_any_call(json.dumps(mock_info, indent=2)) - - def test_vocab_export_operation(self): - args = argparse.Namespace( - model_name_or_path="test/model", - encode=None, - decode=None, - vocab_size=False, - info=False, - vocab_export="/path/to/vocab.json", - enable_mm=False, - ) - - with patch("fastdeploy.entrypoints.cli.tokenizer.InputPreprocessor", return_value=self.mock_preprocessor): - with patch("fastdeploy.entrypoints.cli.tokenizer.export_vocabulary") as mock_export: - with patch("builtins.print"): - main(args) - - mock_export.assert_called_once_with(self.mock_tokenizer, "/path/to/vocab.json") - - def test_multiple_operations(self): - self.mock_tokenizer.encode.return_value = [1] - self.mock_tokenizer.decode.return_value = "test" - - args = argparse.Namespace( - model_name_or_path="test/model", - encode="hello", - decode="1", - vocab_size=True, - info=True, - vocab_export="/path/to/vocab", - enable_mm=False, - ) - - with patch("fastdeploy.entrypoints.cli.tokenizer.InputPreprocessor", return_value=self.mock_preprocessor): - with patch("fastdeploy.entrypoints.cli.tokenizer.get_vocab_size", return_value=100): - with patch("fastdeploy.entrypoints.cli.tokenizer.get_tokenizer_info", return_value={"size": 100}): - with patch("fastdeploy.entrypoints.cli.tokenizer.export_vocabulary") as mock_export: - with patch("builtins.print") as mock_print: - main(args) - - # 验证所有操作都被调用 - self.assertEqual(self.mock_tokenizer.encode.call_count, 1) - self.assertEqual(self.mock_tokenizer.decode.call_count, 1) - mock_export.assert_called_once() - - # 验证操作计数 - mock_print.assert_any_call("Completed 5 operation(s)") - - -class TestIntegration(unittest.TestCase): - """集成测试""" - - def test_full_workflow(self): - # 测试完整的CLI工作流程 - subcommand = 
TokenizerSubcommand() - mock_subparsers = MagicMock() - mock_parser = MagicMock() - mock_subparsers.add_parser.return_value = mock_parser - - # 初始化解析器 - subcommand.subparser_init(mock_subparsers) - - # 测试cmd方法 - args = argparse.Namespace( - model_name_or_path="test/model", encode="test", decode=None, vocab_size=True, info=False, vocab_export=None - ) - - with patch("fastdeploy.entrypoints.cli.tokenizer.main") as mock_main: - subcommand.cmd(args) - mock_main.assert_called_once_with(args) - - def test_main_functionality(self): - """测试main函数的整体功能""" - args = argparse.Namespace( - model_name_or_path="test/model", - encode="hello", - decode="1", - vocab_size=True, - info=True, - vocab_export=None, - enable_mm=False, - ) - - mock_tokenizer = MagicMock() - mock_tokenizer.encode.return_value = [1] - mock_tokenizer.decode.return_value = "hello" - - mock_preprocessor = MagicMock() - mock_preprocessor.create_processor.return_value.tokenizer = mock_tokenizer - - with patch("fastdeploy.entrypoints.cli.tokenizer.InputPreprocessor", return_value=mock_preprocessor): - with patch("fastdeploy.entrypoints.cli.tokenizer.get_vocab_size", return_value=1000): - with patch("fastdeploy.entrypoints.cli.tokenizer.get_tokenizer_info", return_value={"info": "test"}): - with patch("builtins.print") as mock_print: - main(args) - - # 验证基本功能正常 - self.assertTrue(mock_print.called) - # 验证encode和decode被调用 - mock_tokenizer.encode.assert_called_once_with("hello") - mock_tokenizer.decode.assert_called_once_with([1]) - - -# class TestTokenizerCli(unittest.TestCase): -# def setUp(self): -# self.test_args = argparse.Namespace() -# self.test_args.model_name_or_path = "baidu/ERNIE-4.5-0.3B-PT" -# self.test_args.encode = "Hello, world!" 
-# self.test_args.decode = "[1, 2, 3]" -# self.test_args.vocab_size = True -# self.test_args.info = True -# self.tmpdir = tempfile.TemporaryDirectory() -# self.test_args.vocab_export = os.path.join(self.tmpdir.name, "vocab.txt") - -# def tearDown(self): -# self.tmpdir.cleanup() - -# def test_main(self): -# result = main(self.test_args) -# self.assertIsNotNone(result) -# self.assertTrue(os.path.exists(self.test_args.vocab_export)) - - -if __name__ == "__main__": - unittest.main(verbosity=2) diff --git a/tests/operators/test_set_value_by_flags_and_idx.py b/tests/operators/test_set_value_by_flags_and_idx.py index 6861ca218..aada9e260 100644 --- a/tests/operators/test_set_value_by_flags_and_idx.py +++ b/tests/operators/test_set_value_by_flags_and_idx.py @@ -34,7 +34,7 @@ def set_value_by_flags_and_idx_numpy( current_step_idx = step_idx[i] if seq_len_enc == 0 and seq_len_dec == 0: continue - if current_step_idx >= 0: + if current_step_idx > 0: if seq_len_enc > 0: token_idx = seq_len_enc - 1 token_to_assign = input_ids[i, token_idx]