Optimize the performance of the thinking-length limit using custom operators (#4279)
Some checks failed
CE Compile Job / ce_job_pre_check (push) Has been cancelled
CE Compile Job / print_ce_job_pre_check_outputs (push) Has been cancelled
CE Compile Job / FD-Clone-Linux (push) Has been cancelled
CE Compile Job / Show Code Archive Output (push) Has been cancelled
CE Compile Job / BUILD_SM8090 (push) Has been cancelled
CE Compile Job / BUILD_SM8689 (push) Has been cancelled
CE Compile Job / CE_UPLOAD (push) Has been cancelled
Deploy GitHub Pages / deploy (push) Has been cancelled
Publish Job / publish_pre_check (push) Has been cancelled
Publish Job / print_publish_pre_check_outputs (push) Has been cancelled
Publish Job / FD-Clone-Linux (push) Has been cancelled
Publish Job / Show Code Archive Output (push) Has been cancelled
Publish Job / BUILD_SM8090 (push) Has been cancelled
Publish Job / BUILD_SM8689 (push) Has been cancelled
Publish Job / PADDLE_PYPI_UPLOAD_8090 (push) Has been cancelled
Publish Job / PADDLE_PYPI_UPLOAD_8689 (push) Has been cancelled
Publish Job / Run FD Image Build (push) Has been cancelled
Publish Job / Run FastDeploy Unit Tests and Coverage (push) Has been cancelled
Publish Job / Run FastDeploy LogProb Tests (push) Has been cancelled
Publish Job / Extracted partial CE model tasks to run in CI. (push) Has been cancelled
Publish Job / Run Base Tests (push) Has been cancelled
Publish Job / Run Accuracy Tests (push) Has been cancelled
Publish Job / Run Stable Tests (push) Has been cancelled
CI Images Build / FD-Clone-Linux (push) Has been cancelled
CI Images Build / Show Code Archive Output (push) Has been cancelled
CI Images Build / CI Images Build (push) Has been cancelled
CI Images Build / BUILD_SM8090 (push) Has been cancelled
CI Images Build / Run FastDeploy Unit Tests and Coverage (push) Has been cancelled
CI Images Build / Run FastDeploy LogProb Tests (push) Has been cancelled
CI Images Build / Extracted partial CE model tasks to run in CI. (push) Has been cancelled
CI Images Build / Run Base Tests (push) Has been cancelled
CI Images Build / Run Accuracy Tests (push) Has been cancelled
CI Images Build / Run Stable Tests (push) Has been cancelled
CI Images Build / Publish Docker Images Pre Check (push) Has been cancelled

* delete impl

* delete min_length&max_length

* support limit thinking content strategy

* fix

* fix

* fix

* update

* fix set_value_by_flags_and_idx

* fix

* fix

* fix

* fix

* update

* fix

* fix

* fix typo

* fix ci

* fix

* fix

* support mtp

* fix

* fix

* update

* update
This commit is contained in:
Yuanle Liu
2025-10-20 21:09:13 +08:00
committed by GitHub
parent 36af88ff3f
commit cef3164c3b
31 changed files with 747 additions and 1032 deletions

View File

@@ -79,15 +79,90 @@ else:
step_reschedule,
update_inputs_v1,
speculate_step_reschedule,
limit_thinking_content_length_v1,
limit_thinking_content_length_v2,
speculate_limit_thinking_content_length_v1,
speculate_limit_thinking_content_length_v2,
)
from fastdeploy.output.stream_transfer_data import DecoderState, StreamTransferData
from fastdeploy.worker.output import ModelOutputData, ModelRunnerOutput, SamplerOutput
DISABLE_RECOVER = envs.FD_DISABLED_RECOVER == "1"
def limit_thinking_content_length(
    limit_strategy: str,
    sampled_token_ids: paddle.Tensor,
    max_think_lens: paddle.Tensor,
    step_idx: paddle.Tensor,
    limit_think_status: paddle.Tensor,
    think_end_id: int,
    line_break_id: int = None,
):
    """Truncate the model's thinking content once it exceeds its per-request budget.

    Dispatches to one of two custom operators depending on ``limit_strategy``:

    * ``"</think>"`` — force-emit the think-end token (used for ernie4_5_vl).
    * ``"\\n</think>\\n\\n"`` — force-emit a line-break + think-end sequence
      (used for ernie_x1); requires a valid ``line_break_id``.

    Args:
        limit_strategy: Which truncation token sequence to inject.
        sampled_token_ids: Tokens sampled this step; may be overwritten in place
            by the operator.
        max_think_lens: Per-request maximum allowed thinking length.
        step_idx: Per-request current decoding step index.
        limit_think_status: Per-request truncation state, updated in place.
        think_end_id: Token id that closes the thinking section.
        line_break_id: Token id of the line break; only required (and validated)
            for the ``"\\n</think>\\n\\n"`` strategy.

    Raises:
        ValueError: If the ``"\\n</think>\\n\\n"`` strategy is selected without a
            positive ``line_break_id``.
        NotImplementedError: For any unrecognized ``limit_strategy``.
    """
    if limit_strategy == "</think>":
        # for ernie4_5_vl
        limit_thinking_content_length_v1(
            sampled_token_ids,
            max_think_lens,
            step_idx,
            limit_think_status,
            think_end_id,
        )
    elif limit_strategy == "\n</think>\n\n":
        # for ernie_x1
        # Explicit check instead of `assert`: asserts are stripped under
        # `python -O` (which would let None reach the custom op), and a None
        # default would otherwise raise a confusing TypeError from `None > 0`.
        if line_break_id is None or line_break_id <= 0:
            raise ValueError(f"line_break_id must be a positive token id, got {line_break_id}.")
        limit_thinking_content_length_v2(
            sampled_token_ids,
            max_think_lens,
            step_idx,
            limit_think_status,
            think_end_id,
            line_break_id,
        )
    else:
        raise NotImplementedError(f"Not support {limit_strategy=} for limit thinking content length.")
def speculate_limit_thinking_content_length(
    limit_strategy: str,
    accept_tokens: paddle.Tensor,
    max_think_lens: paddle.Tensor,
    step_idx: paddle.Tensor,
    limit_think_status: paddle.Tensor,
    accept_num: paddle.Tensor,
    seq_lens_decoder: paddle.Tensor,
    think_end_id: int,
    line_break_id: int = None,
):
    """Speculative-decoding variant of :func:`limit_thinking_content_length`.

    Same strategy dispatch, but operates on the batch of draft-accepted tokens
    instead of a single sampled token per request.

    * ``"</think>"`` — force-emit the think-end token (used for ernie4_5_vl).
    * ``"\\n</think>\\n\\n"`` — force-emit a line-break + think-end sequence
      (used for ernie_x1); requires a valid ``line_break_id``.

    Args:
        limit_strategy: Which truncation token sequence to inject.
        accept_tokens: Tokens accepted from the draft model this step; may be
            overwritten in place by the operator.
        max_think_lens: Per-request maximum allowed thinking length.
        step_idx: Per-request current decoding step index.
        limit_think_status: Per-request truncation state, updated in place.
        accept_num: Per-request count of accepted draft tokens.
        seq_lens_decoder: Per-request decoder sequence lengths.
        think_end_id: Token id that closes the thinking section.
        line_break_id: Token id of the line break; only required (and validated)
            for the ``"\\n</think>\\n\\n"`` strategy.

    Raises:
        ValueError: If the ``"\\n</think>\\n\\n"`` strategy is selected without a
            positive ``line_break_id``.
        NotImplementedError: For any unrecognized ``limit_strategy``.
    """
    if limit_strategy == "</think>":
        # for ernie4_5_vl
        speculate_limit_thinking_content_length_v1(
            accept_tokens,
            max_think_lens,
            step_idx,
            limit_think_status,
            accept_num,
            seq_lens_decoder,
            think_end_id,
        )
    elif limit_strategy == "\n</think>\n\n":
        # for ernie_x1
        # Explicit check instead of `assert`: asserts are stripped under
        # `python -O` (which would let None reach the custom op), and a None
        # default would otherwise raise a confusing TypeError from `None > 0`.
        if line_break_id is None or line_break_id <= 0:
            raise ValueError(f"line_break_id must be a positive token id, got {line_break_id}.")
        speculate_limit_thinking_content_length_v2(
            accept_tokens,
            max_think_lens,
            step_idx,
            limit_think_status,
            accept_num,
            seq_lens_decoder,
            think_end_id,
            line_break_id,
        )
    else:
        raise NotImplementedError(f"Not support {limit_strategy=} for limit thinking content length.")
def pre_process(
input_ids: paddle.Tensor,
seq_lens_this_time: int,
@@ -185,46 +260,19 @@ def post_process_normal(
save_each_rank: bool = False,
skip_save_output: bool = False,
async_output_queue: queue.Queue = None,
think_end_id: int = -1,
line_break_id: int = -1,
) -> ModelRunnerOutput:
"""Post-processing steps after completing a single token generation."""
# handle vl:
if model_output.think_end_id != -1:
thinking_mask = model_output.enable_thinking[: sampler_output.sampled_token_ids.shape[0]]
exists_think_end = (sampler_output.sampled_token_ids == model_output.think_end_id) & thinking_mask
paddle.assign(
paddle.where(
exists_think_end,
model_output.need_think_end - 1,
model_output.need_think_end,
),
model_output.need_think_end,
)
reasoning_index_update_cond = model_output.need_think_end.cast("bool") & thinking_mask
paddle.assign(
paddle.where(
reasoning_index_update_cond,
model_output.reasoning_index - 1,
model_output.reasoning_index,
),
model_output.reasoning_index,
)
stop_wo_think = ((model_output.reasoning_index == 0)) & (model_output.need_think_end > 0)
stop_wo_think = stop_wo_think & thinking_mask
sampler_output.sampled_token_ids = paddle.where(
stop_wo_think,
model_output.think_end_id,
sampler_output.sampled_token_ids,
)
paddle.assign(
paddle.where(
stop_wo_think,
model_output.need_think_end - 1,
model_output.need_think_end,
),
model_output.need_think_end,
if think_end_id > 0:
limit_thinking_content_length(
limit_strategy=envs.FD_LIMIT_THINKING_CONTENT_TRUNCATE_STR,
sampled_token_ids=sampler_output.sampled_token_ids,
max_think_lens=share_inputs["max_think_lens"],
step_idx=share_inputs["step_idx"],
limit_think_status=share_inputs["limit_think_status"],
think_end_id=think_end_id,
line_break_id=line_break_id,
)
# 1. Set stop value
paddle.assign(
@@ -337,10 +385,25 @@ def post_process_normal(
def post_process_specualate(
sampler_output: SamplerOutput,
model_output: ModelOutputData,
share_inputs: Dict[str, paddle.Tensor],
save_each_rank: bool = False,
skip_save_output: bool = False,
think_end_id: int = -1,
line_break_id: int = -1,
):
""""""
if think_end_id > 0:
speculate_limit_thinking_content_length(
limit_strategy=envs.FD_LIMIT_THINKING_CONTENT_TRUNCATE_STR,
accept_tokens=share_inputs["accept_tokens"],
max_think_lens=share_inputs["max_think_lens"],
step_idx=share_inputs["step_idx"],
limit_think_status=share_inputs["limit_think_status"],
accept_num=share_inputs["accept_num"],
seq_lens_decoder=share_inputs["seq_lens_decoder"],
think_end_id=think_end_id,
line_break_id=line_break_id,
)
speculate_update(
model_output.seq_lens_encoder,
model_output.seq_lens_decoder,
@@ -403,10 +466,20 @@ def post_process(
speculative_decoding: bool = False,
skip_save_output: bool = False,
async_output_queue: queue.Queue = None,
think_end_id: int = -1,
line_break_id: int = -1,
) -> None:
"""Post-processing steps after completing a single token generation."""
if speculative_decoding:
post_process_specualate(sampler_output, model_output, save_each_rank, skip_save_output)
post_process_specualate(
sampler_output,
model_output,
share_inputs,
save_each_rank,
skip_save_output,
think_end_id,
line_break_id,
)
else:
post_process_normal(
sampler_output,
@@ -416,6 +489,8 @@ def post_process(
save_each_rank,
skip_save_output,
async_output_queue,
think_end_id,
line_break_id,
)