[Feature] Online Chat API supports returning logprobs (#2777)

* Online chat supports returning logprobs

* Check XPU

* Check vl_gpu_model_runner and xpu_model_runner

* get_worker() checks the platform
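
For context, the feature surfaces through the OpenAI-compatible chat endpoint. A minimal client-side sketch (not part of this diff; the host, port, and model name are illustrative, and the response shape assumes the usual OpenAI logprobs contract):

import requests

# Request per-token logprobs from a hypothetical local deployment.
resp = requests.post(
    "http://localhost:8000/v1/chat/completions",
    json={
        "model": "default",
        "messages": [{"role": "user", "content": "Hello"}],
        "logprobs": True,    # return logprobs for the sampled tokens
        "top_logprobs": 5,   # alternatives per position (capped at 20, see the diff below)
    },
)

# Assuming an OpenAI-style response, logprobs live under choices[i].logprobs.content.
for entry in resp.json()["choices"][0]["logprobs"]["content"]:
    print(entry["token"], entry["logprob"])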
chen
2025-07-10 16:33:40 +08:00
committed by GitHub
parent 24f934f1f9
commit d33105baeb
22 changed files with 608 additions and 114 deletions


@@ -60,6 +60,7 @@ class GCUModelRunner(ModelRunnerBase):
         self.device_id = device_id
         self.speculative_method = self.fd_config.speculative_config.method
         self.speculative_decoding = self.speculative_method is not None
+        self.enable_logprob = fd_config.model_config.enable_logprob
         self.guided_backend = None
         if self.fd_config.parallel_config.guided_decoding_backend != "off":
@@ -602,6 +603,7 @@ class GCUModelRunner(ModelRunnerBase):
             min_dec_lens=self.share_inputs["min_dec_len"],
             bad_words_token_ids=self.share_inputs["bad_tokens"],
             eos_token_ids=self.share_inputs["eos_token_id"],
+            max_num_logprobs=20 if self.enable_logprob else None,
         )
 
     def load_model(self) -> None:
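
The new max_num_logprobs field tells the sampler how many candidate logprobs to materialize per step; when logprobs are disabled it stays None, so the extra work is skipped entirely. A rough sketch of the computation this enables (illustrative only, not the FastDeploy sampler itself):

import paddle
import paddle.nn.functional as F

def compute_top_logprobs(logits: paddle.Tensor, max_num_logprobs: int):
    """Top-k log-probabilities and their token ids for each position."""
    logprobs = F.log_softmax(logits, axis=-1)  # [batch, vocab_size]
    topk_logprobs, topk_token_ids = paddle.topk(
        logprobs, k=max_num_logprobs, axis=-1)
    return topk_logprobs, topk_token_ids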
@@ -806,15 +808,15 @@ class GCUModelRunner(ModelRunnerBase):
                 self.share_inputs["step_idx"],
                 self.share_inputs["stop_flags"],
             )
-            sampled_token_ids = self.sampler(logits,
+            sampler_output = self.sampler(logits,
                                              self.sampling_metadata)
             if self.parallel_config.tensor_parallel_degree > 1:
-                paddle.distributed.broadcast(sampled_token_ids, 0)
+                paddle.distributed.broadcast(sampler_output.sampled_token_ids, 0)
         else:
             self.sampler(logits, self.sampling_metadata,
                          self.parallel_config.max_model_len,
                          self.share_inputs)
-            sampled_token_ids = None
+            sampler_output = None
             if self.parallel_config.tensor_parallel_degree > 1:
                 paddle.distributed.broadcast(
                     self.share_inputs["accept_tokens"], 0)
@@ -854,7 +856,7 @@ class GCUModelRunner(ModelRunnerBase):
                 accept_num=self.share_inputs["accept_num"]
                 if self.speculative_decoding else None)
 
-        post_process(sampled_token_ids=sampled_token_ids,
+        post_process(sampler_output=sampler_output,
                      model_output=model_output_data,
                      speculative_decoding=self.speculative_decoding,
                      skip_save_output=True)
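
The rename from sampled_token_ids to sampler_output is the core of the change: the sampler now returns a structured result that bundles the chosen token ids with the optional logprob data, and post_process unpacks it. Conceptually it looks something like this (field names are assumptions; the real class lives in FastDeploy's sampler code and may differ):

from dataclasses import dataclass
from typing import Optional
import paddle

@dataclass
class SamplerOutput:
    # Token ids chosen by the sampler for each sequence in the batch.
    sampled_token_ids: paddle.Tensor
    # Top-k logprobs and ids, populated only when max_num_logprobs is set.
    logprobs_tensors: Optional[object] = None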
@@ -1036,18 +1038,18 @@ class GCUModelRunner(ModelRunnerBase):
                 self.share_inputs["step_idx"],
                 self.share_inputs["stop_flags"],
             )
-            sampled_token_ids = self.sampler(
+            sampler_output = self.sampler(
                 logits,
                 self.sampling_metadata,
                 skip_idx_list,
             )
             if self.parallel_config.tensor_parallel_degree > 1:
-                paddle.distributed.broadcast(sampled_token_ids, 0)
+                paddle.distributed.broadcast(sampler_output.sampled_token_ids, 0)
         else:
             self.sampler(logits, self.sampling_metadata,
                          self.parallel_config.max_model_len, self.share_inputs)
-            sampled_token_ids = None
+            sampler_output = None
             if self.parallel_config.tensor_parallel_degree > 1:
                 paddle.distributed.broadcast(
                     self.share_inputs["accept_tokens"], 0)
@@ -1090,7 +1092,7 @@ class GCUModelRunner(ModelRunnerBase):
             skip_save_output = True
         else:
             skip_save_output = False
-        post_process(sampled_token_ids=sampled_token_ids,
+        post_process(sampler_output=sampler_output,
                      model_output=model_output_data,
                      save_each_rank=self.parallel_config.use_ep,
                      speculative_decoding=self.speculative_decoding,
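
One subtlety worth noting: under tensor parallelism the hunks above broadcast only sampler_output.sampled_token_ids from rank 0, since sampling happens on a single rank. A generic sketch of that pattern (assuming paddle.distributed is initialized; whether the logprob tensors also need broadcasting is not visible in these hunks):

import paddle.distributed as dist

def broadcast_sampled_tokens(sampler_output, src: int = 0) -> None:
    # Rank 0 performs sampling; every other rank receives the result in place.
    dist.broadcast(sampler_output.sampled_token_ids, src=src)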