[Feature] Support return logprob of generated tokens (#2784)

* Online chat supports returning logprobs

* Check XPU

* Check vl_gpu_model_runner

* Only CUDA supports logprob

* get_worker() checks the platform

---------

Co-authored-by: Jiang-Jia-Jun <163579578+Jiang-Jia-Jun@users.noreply.github.com>
Author: chen
Date: 2025-07-10 15:47:42 +08:00 (committed by GitHub)
Parent: 39d2a1de46
Commit: 823a47e64a
21 changed files with 592 additions and 105 deletions

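The feature is consumed through the OpenAI-compatible online chat API. Below is a minimal client-side sketch of requesting logprobs; the endpoint address, model name, and the chosen top_logprobs value are illustrative assumptions, not taken from this commit.

```python
# Hedged example: asking an OpenAI-compatible chat endpoint for logprobs.
# Base URL, API key, and model name are placeholders for a local deployment.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8188/v1", api_key="EMPTY")

response = client.chat.completions.create(
    model="default",
    messages=[{"role": "user", "content": "Hello"}],
    logprobs=True,      # ask the server to return logprobs of generated tokens
    top_logprobs=5,     # number of alternative tokens reported per position
)

# Each generated token carries its own logprob plus the top alternatives.
for token_info in response.choices[0].logprobs.content:
    print(token_info.token, token_info.logprob)
```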

@@ -63,6 +63,7 @@ class GPUModelRunner(ModelRunnerBase):
         self.device_id = device_id
         self.speculative_method = self.fd_config.speculative_config.method
         self.speculative_decoding = self.speculative_method is not None
+        self.enable_logprob = fd_config.model_config.enable_logprob
         self.guided_backend = None
         if self.fd_config.parallel_config.guided_decoding_backend != "off":
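The new enable_logprob attribute is read from the model config, so the switch must exist on fd_config.model_config before the runner is constructed. A minimal sketch of that plumbing, using hypothetical dataclasses rather than the project's real config classes:

```python
# Illustrative config plumbing only; FastDeploy's real ModelConfig/FDConfig
# classes carry many more fields and may be constructed differently.
from dataclasses import dataclass

@dataclass
class ModelConfig:
    enable_logprob: bool = False  # logprob reporting is off unless enabled

@dataclass
class FDConfig:
    model_config: ModelConfig

fd_config = FDConfig(model_config=ModelConfig(enable_logprob=True))
assert fd_config.model_config.enable_logprob  # what the runner reads in __init__
```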
@@ -612,6 +613,7 @@ class GPUModelRunner(ModelRunnerBase):
             min_dec_lens=self.share_inputs["min_dec_len"],
             bad_words_token_ids=self.share_inputs["bad_tokens"],
             eos_token_ids=self.share_inputs["eos_token_id"],
+            max_num_logprobs=20 if self.enable_logprob else None,
         )

     def load_model(self) -> None:
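The hard-coded 20 presumably mirrors the usual upper bound for top_logprobs in OpenAI-style APIs; when logprobs are disabled, max_num_logprobs is None and the sampler can skip the extra work. A generic sketch of what consuming that field typically looks like with Paddle ops (not the project's actual sampler code):

```python
# Generic top-k logprob extraction with Paddle; a sketch of the pattern only.
from typing import Optional, Tuple

import paddle
import paddle.nn.functional as F

def top_logprobs(logits: paddle.Tensor,
                 max_num_logprobs: Optional[int]
                 ) -> Optional[Tuple[paddle.Tensor, paddle.Tensor]]:
    """Return top-k log-probabilities and their token ids, or None if disabled."""
    if max_num_logprobs is None:
        return None
    logprobs = F.log_softmax(logits, axis=-1)                   # [batch, vocab_size]
    values, token_ids = paddle.topk(logprobs, k=max_num_logprobs, axis=-1)
    return values, token_ids
```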
@@ -816,15 +818,15 @@ class GPUModelRunner(ModelRunnerBase):
                 self.share_inputs["step_idx"],
                 self.share_inputs["stop_flags"],
             )
-            sampled_token_ids = self.sampler(logits,
+            sampler_output = self.sampler(logits,
                                              self.sampling_metadata)
             if self.parallel_config.tensor_parallel_degree > 1:
-                paddle.distributed.broadcast(sampled_token_ids, 0)
+                paddle.distributed.broadcast(sampler_output.sampled_token_ids, 0)
         else:
             self.sampler(logits, self.sampling_metadata,
                          self.parallel_config.max_model_len,
                          self.share_inputs)
-            sampled_token_ids = None
+            sampler_output = None
             if self.parallel_config.tensor_parallel_degree > 1:
                 paddle.distributed.broadcast(
                     self.share_inputs["accept_tokens"], 0)
@@ -864,7 +866,7 @@ class GPUModelRunner(ModelRunnerBase):
             accept_num=self.share_inputs["accept_num"]
             if self.speculative_decoding else None)
-        post_process(sampled_token_ids=sampled_token_ids,
+        post_process(sampler_output=sampler_output,
                      model_output=model_output_data,
                      speculative_decoding=self.speculative_decoding,
                      skip_save_output=True)
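Downstream, post_process now receives the whole sampler_output instead of a raw id tensor, so it can persist logprobs together with the sampled tokens. A hedged sketch of how such a signature can unpack the object; the real function takes more parameters and does considerably more work:

```python
# Simplified, assumed shape of a post_process that accepts sampler_output.
def post_process(sampler_output, model_output, speculative_decoding=False,
                 skip_save_output=False):
    if sampler_output is None:
        # Speculative-decoding path: accepted tokens already live in the
        # shared buffers, so there is nothing to unpack here.
        return None
    token_ids = sampler_output.sampled_token_ids
    logprobs = getattr(sampler_output, "logprobs", None)
    # ...hand token_ids (and logprobs, when present) to the output pipeline...
    return token_ids, logprobs
```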
@@ -1051,18 +1053,18 @@ class GPUModelRunner(ModelRunnerBase):
                 self.share_inputs["step_idx"],
                 self.share_inputs["stop_flags"],
             )
-            sampled_token_ids = self.sampler(
+            sampler_output = self.sampler(
                 logits,
                 self.sampling_metadata,
                 skip_idx_list,
             )
             if self.parallel_config.tensor_parallel_degree > 1:
-                paddle.distributed.broadcast(sampled_token_ids, 0)
+                paddle.distributed.broadcast(sampler_output.sampled_token_ids, 0)
         else:
             self.sampler(logits, self.sampling_metadata,
                          self.parallel_config.max_model_len, self.share_inputs)
-            sampled_token_ids = None
+            sampler_output = None
             if self.parallel_config.tensor_parallel_degree > 1:
                 paddle.distributed.broadcast(
                     self.share_inputs["accept_tokens"], 0)
@@ -1105,7 +1107,7 @@ class GPUModelRunner(ModelRunnerBase):
             skip_save_output = True
         else:
             skip_save_output = False
-        post_process(sampled_token_ids=sampled_token_ids,
+        post_process(sampler_output=sampler_output,
                      model_output=model_output_data,
                      save_each_rank=self.parallel_config.use_ep,
                      speculative_decoding=self.speculative_decoding,