[Feature] Support return logprob of generated tokens (#2784)

* online chat support logprobs

* check xpu

* check vl_gpu_model_runner

* only cuda support logprob

* get_worker() check platform

---------

Co-authored-by: Jiang-Jia-Jun <163579578+Jiang-Jia-Jun@users.noreply.github.com>
Authored by chen on 2025-07-10 15:47:42 +08:00, committed by GitHub.
parent 39d2a1de46 · commit 823a47e64a
21 changed files with 592 additions and 105 deletions
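
With this change, online chat clients can ask the server to return per-token log probabilities. A minimal client-side sketch, assuming a FastDeploy OpenAI-compatible /v1/chat/completions endpoint serving on localhost:8000 (host, port, and model name are placeholders, not values from this commit):

```python
# Hypothetical request against an OpenAI-compatible chat endpoint.
import requests

resp = requests.post(
    "http://localhost:8000/v1/chat/completions",
    json={
        "model": "default",
        "messages": [{"role": "user", "content": "Hello"}],
        "logprobs": True,      # return logprobs of generated tokens
        "top_logprobs": 5,     # also return the top-5 alternatives per token
    },
)
# In the OpenAI response schema, each generated token carries its logprob
# and the requested top-k alternatives.
for token_info in resp.json()["choices"][0]["logprobs"]["content"]:
    print(token_info["token"], token_info["logprob"])
```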


@@ -20,14 +20,16 @@ import paddle
 from fastdeploy import envs
 from fastdeploy.engine.config import SpeculativeConfig
 from fastdeploy.model_executor.ops.gpu import (
-    get_padding_offset, save_output, set_stop_value_multi_ends,
-    speculate_clear_accept_nums, speculate_get_output_padding_offset,
-    speculate_get_padding_offset, speculate_get_seq_lens_output,
-    speculate_save_output, speculate_set_value_by_flags_and_idx,
-    speculate_step_paddle, speculate_step_system_cache, speculate_update_v3,
-    step_paddle, step_system_cache, update_inputs, step_reschedule)
+    get_padding_offset, save_output, save_output_topk,
+    set_stop_value_multi_ends, speculate_clear_accept_nums,
+    speculate_get_output_padding_offset, speculate_get_padding_offset,
+    speculate_get_seq_lens_output, speculate_save_output,
+    speculate_set_value_by_flags_and_idx, speculate_step_paddle,
+    speculate_step_system_cache, speculate_update_v3, step_paddle,
+    step_reschedule, step_system_cache, update_inputs)
 from fastdeploy.platforms import current_platform
-from fastdeploy.worker.output import ModelOutputData
+from fastdeploy.worker.output import (ModelOutputData, ModelRunnerOutput,
+                                      SamplerOutput)

 DISABLE_RECOVER = (envs.FD_DISABLED_RECOVER == "1")
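
The widened import pulls in the new save_output_topk GPU op alongside the SamplerOutput and ModelRunnerOutput containers. A sketch of the two structures the hunks below depend on, with field names taken from this diff but definitions and shapes assumed (the real classes live in fastdeploy/worker/output.py):

```python
# Assumed shapes; only the field names are confirmed by this diff.
from dataclasses import dataclass
from typing import Optional

import paddle


@dataclass
class LogprobsTensors:
    logprob_token_ids: paddle.Tensor     # [num_tokens, k+1] candidate token ids
    logprobs: paddle.Tensor              # [num_tokens, k+1] matching logprobs
    selected_token_ranks: paddle.Tensor  # [num_tokens] rank of each sampled token


@dataclass
class SamplerOutput:
    sampled_token_ids: paddle.Tensor
    # None unless the request asked for logprobs, which keeps the plain
    # save_output path for ordinary requests.
    logprobs_tensors: Optional[LogprobsTensors] = None
```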
@@ -102,10 +104,10 @@ def pre_process(
             cu_seqlens_k, output_cum_offsets, output_padding_offset)


-def post_process_normal(sampled_token_ids: paddle.Tensor,
+def post_process_normal(sampler_output: SamplerOutput,
                         model_output: ModelOutputData,
                         save_each_rank: bool = False,
-                        skip_save_output: bool = False) -> None:
+                        skip_save_output: bool = False) -> ModelRunnerOutput:
     """ Post-processing steps after completing a single token generation. """
     # 1. Set stop value
     paddle.assign(
@@ -123,7 +125,7 @@ def post_process_normal(sampled_token_ids: paddle.Tensor,
         model_output.stop_flags,
     )
     # TODO(gongshaotian): Add use_stop_seqs
-    set_stop_value_multi_ends(sampled_token_ids, model_output.stop_flags,
+    set_stop_value_multi_ends(sampler_output.sampled_token_ids, model_output.stop_flags,
                               model_output.seq_lens_this_time,
                               model_output.eos_token_id,
                               model_output.next_tokens, False)  # multi ends
@@ -138,18 +140,28 @@ def post_process_normal(sampled_token_ids: paddle.Tensor,
         model_output.seq_lens_decoder,
         model_output.input_ids,
         model_output.stop_nums,
-        sampled_token_ids,
+        sampler_output.sampled_token_ids,
         model_output.is_block_step,
     )

     # 3. Transmit the model's output and stop generation signal via message queue.
     # In the future, we will abandon this approach.
     if not skip_save_output:
-        save_output(
-            sampled_token_ids,
-            model_output.not_need_stop,
-            model_output.mp_rank,
-            save_each_rank,  # save_each_rank
-        )
+        if sampler_output.logprobs_tensors is None:
+            save_output(
+                sampler_output.sampled_token_ids,
+                model_output.not_need_stop,
+                model_output.mp_rank,
+                save_each_rank,  # save_each_rank
+            )
+        else:
+            save_output_topk(
+                sampler_output.sampled_token_ids,
+                sampler_output.logprobs_tensors.logprob_token_ids,
+                sampler_output.logprobs_tensors.logprobs,
+                sampler_output.logprobs_tensors.selected_token_ranks,
+                model_output.not_need_stop,
+                model_output.mp_rank,
+            )


 def post_process_specualate(model_output, skip_save_output: bool = False):
     """"""
@@ -193,7 +205,7 @@ def post_process_specualate(model_output, skip_save_output: bool = False):
     )


-def post_process(sampled_token_ids: paddle.Tensor,
+def post_process(sampler_output: SamplerOutput,
                  model_output: ModelOutputData,
                  save_each_rank: bool = False,
                  speculative_decoding: bool = False,
@@ -202,7 +214,7 @@ def post_process(sampled_token_ids: paddle.Tensor,
     if speculative_decoding:
         post_process_specualate(model_output, skip_save_output)
     else:
-        post_process_normal(sampled_token_ids, model_output, save_each_rank,
+        post_process_normal(sampler_output, model_output, save_each_rank,
                             skip_save_output)
@@ -217,7 +229,7 @@ def step_cuda(
     TODO(gongshaotian): normalization name
     """
     if speculative_config.method is not None:
         if enable_prefix_caching:
             speculate_step_system_cache(
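
Per the commit notes ("only cuda support logprob", "get_worker() check platform"), non-CUDA backends such as XPU reject logprob requests. A hedged sketch of that gate; the function name and error message here are illustrative, not the exact upstream code:

```python
# Illustrative platform gate; assumes current_platform exposes is_cuda().
from fastdeploy.platforms import current_platform


def check_logprobs_supported(enable_logprob: bool) -> None:
    """Fail fast when logprobs are requested on an unsupported backend."""
    if enable_logprob and not current_platform.is_cuda():
        raise NotImplementedError(
            "Returning logprobs of generated tokens is currently supported "
            "only on CUDA GPUs.")
```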