[Feature][Executor] GPU Model Runner Supports prompt_logprobs and max_logprobs (#4769)

Author: chen
Date: 2025-11-05 10:43:25 +08:00
Committed by: GitHub
Parent: 74722308f2
Commit: 1c3ca48128
13 changed files with 203 additions and 22 deletions
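The new path is driven entirely by per-request sampling parameters: the runner reads request.sampling_params.prompt_logprobs to decide whether to return logprobs for the prompt tokens, and model_config.max_logprobs now bounds the top-k logprobs gathered for generated tokens. A minimal request-side sketch, assuming the SamplingParams plumbing outside this diff exposes the same field names the runner reads below:

    # Hypothetical usage sketch. Only the field names read by GPUModelRunner are
    # taken from this diff; the import path and remaining fields are assumptions.
    from fastdeploy.engine.sampling_params import SamplingParams

    params = SamplingParams(
        max_tokens=64,
        logprobs=5,          # top-k logprobs per generated token, capped by max_logprobs
        prompt_logprobs=5,   # top-k logprobs per prompt token; -1 requests the full vocabulary
    )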


@@ -92,7 +92,7 @@ from fastdeploy.model_executor.models.ernie4_5_vl.modeling_resampler import Scat
from fastdeploy.model_executor.models.interfaces_base import FdModelForPooling
from fastdeploy.output.pooler import PoolerOutput
from fastdeploy.worker.model_runner_base import ModelRunnerBase
-from fastdeploy.worker.output import ModelOutputData, ModelRunnerOutput
+from fastdeploy.worker.output import LogprobsTensors, ModelOutputData, ModelRunnerOutput
class GPUModelRunner(ModelRunnerBase):
@@ -112,8 +112,12 @@ class GPUModelRunner(ModelRunnerBase):
self.speculative_method = self.fd_config.speculative_config.method
self.speculative_decoding = self.speculative_method is not None
self.enable_logprob = fd_config.model_config.enable_logprob
self.max_logprobs = fd_config.model_config.max_logprobs
self.enable_early_stop = self.fd_config.early_stop_config.enable_early_stop
self.is_pooling_model = self.fd_config.model_config.runner_type == "pooling"
self.vocal_size = self.fd_config.model_config.vocab_size
self.prompt_logprobs_reqs: dict[str, Request] = {}
self.in_progress_prompt_logprobs: dict[str, LogprobsTensors] = {}
# VL model config:
if self.enable_mm:
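The two new dictionaries are per-request bookkeeping for chunked prefill: prompt_logprobs_reqs tracks requests whose prompt logprobs are still being produced, and in_progress_prompt_logprobs holds the partially filled LogprobsTensors that successive prefill chunks write into. A rough illustration of the buffer sizing used by _get_prompt_logprobs_list at the end of this diff, with made-up numbers:

    # Illustration only: shapes implied by
    # LogprobsTensors.empty(num_prompt_tokens - 1, num_prompt_logprobs + 1) below.
    num_prompt_tokens = 9       # len(request.prompt_token_ids)
    num_prompt_logprobs = 5     # request.sampling_params.prompt_logprobs
    # One row per prompt position except the first (position 0 has nothing before it
    # to condition on), one column per top-k entry plus one for the actual prompt token.
    rows = num_prompt_tokens - 1      # 8
    cols = num_prompt_logprobs + 1    # 6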
@@ -561,6 +565,9 @@ class GPUModelRunner(ModelRunnerBase):
len(request.output_token_ids) if prefill_end_index >= len(input_ids) else 0
)
self.share_inputs["pre_ids"][idx : idx + 1] = -1
# Pooling model requests have sampling_params set to None.
if request.sampling_params is not None and request.sampling_params.prompt_logprobs is not None:
self.prompt_logprobs_reqs[request.request_id] = request
has_prefill_task = True
elif request.task_type.value == RequestType.DECODE.value: # decode task
logger.debug(f"Handle decode request {request} at idx {idx}")
@@ -581,6 +588,8 @@ class GPUModelRunner(ModelRunnerBase):
self.share_inputs["seq_lens_decoder"][idx : idx + 1] = 0
self.share_inputs["seq_lens_encoder"][idx : idx + 1] = 0
self.share_inputs["is_block_step"][idx : idx + 1] = False
self.prompt_logprobs_reqs.pop(request.request_id, None)
self.in_progress_prompt_logprobs.pop(request.request_id, None)
continue
assert len(request.eos_token_ids) == self.model_config.eos_tokens_lens
@@ -1282,7 +1291,7 @@ class GPUModelRunner(ModelRunnerBase):
min_dec_lens=self.share_inputs["min_dec_len"],
bad_words_token_ids=self.share_inputs["bad_tokens"][:, :max_bad_tokens_len],
eos_token_ids=self.share_inputs["eos_token_id"],
-max_num_logprobs=20 if self.enable_logprob else None,
+max_num_logprobs=self.max_logprobs if self.enable_logprob else None,
enable_early_stop=self.enable_early_stop,
stop_flags=self.share_inputs["stop_flags"],
temp_scaled_logprobs=self.share_inputs["temp_scaled_logprobs"],
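Here the hardcoded top-20 is replaced by the configured limit. A standalone restatement of the gating logic, offered only as an illustration of the argument above:

    # Not part of the diff: a re-statement of how max_num_logprobs is resolved.
    def resolve_max_num_logprobs(enable_logprob: bool, max_logprobs: int):
        """None disables logprob gathering; otherwise cap per-token top-k at max_logprobs."""
        return max_logprobs if enable_logprob else None

    assert resolve_max_num_logprobs(False, 20) is None
    assert resolve_max_num_logprobs(True, 5) == 5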
@@ -2086,6 +2095,8 @@ class GPUModelRunner(ModelRunnerBase):
if self.use_cudagraph:
model_output = model_output[: self.real_token_num]
prompt_logprobs_list = self._get_prompt_logprobs_list(model_output)
if self.is_pooling_model:
hidden_states = model_output
pooler_output = self._pool(hidden_states, num_running_requests)
@@ -2228,6 +2239,7 @@ class GPUModelRunner(ModelRunnerBase):
stop_seqs_len=self.share_inputs["stop_seqs_len"],
prompt_lens=self.share_inputs["prompt_lens"],
mask_rollback=self.share_inputs["mask_rollback"],
prompt_logprobs_list=prompt_logprobs_list,
)
if self.speculative_config.method in ["mtp"] and self.scheduler_config.splitwise_role == "prefill":
@@ -2485,6 +2497,9 @@ class GPUModelRunner(ModelRunnerBase):
def clear_requests(self):
"""Dynamic model loader use to clear requests use for RL"""
self.share_inputs["stop_flags"][:] = True
# prompt_logprobs
self.prompt_logprobs_reqs.clear()
self.in_progress_prompt_logprobs.clear()
def update_parameters(self, pid):
"""Dynamic model loader use to update parameters use for RL"""
@@ -2694,3 +2709,69 @@ class GPUModelRunner(ModelRunnerBase):
cumsum_seqlens=cumsum_seqlens,
)
return rope_emb_lst
def _get_prompt_logprobs_list(
self,
hidden_states: paddle.Tensor,
) -> list[Optional[LogprobsTensors]]:
if len(self.prompt_logprobs_reqs) > 0:
assert (
not self.fd_config.cache_config.enable_prefix_caching
), "prompt_logprobs must disable prefix caching, --no-enable-prefix-caching."
logprobs_mode = self.fd_config.model_config.logprobs_mode
prompt_logprobs_list: list[Optional[LogprobsTensors]] = self.scheduler_config.max_num_seqs * [None]
completed_prefill_reqs: list[Request] = []
for req_id, request in self.prompt_logprobs_reqs.items():
num_prompt_logprobs = request.sampling_params.prompt_logprobs
if request.prompt_token_ids is None or num_prompt_logprobs is None:
continue
if num_prompt_logprobs == -1:
num_prompt_logprobs = self.vocal_size
num_tokens = request.prefill_end_index - request.prefill_start_index
num_prompt_tokens = len(request.prompt_token_ids)
logprobs_tensors = self.in_progress_prompt_logprobs.get(req_id)
if not logprobs_tensors:
logprobs_tensors = LogprobsTensors.empty(num_prompt_tokens - 1, num_prompt_logprobs + 1)
self.in_progress_prompt_logprobs[req_id] = logprobs_tensors
start_idx = request.prefill_start_index
start_tok = start_idx + 1
num_remaining_tokens = num_prompt_tokens - start_tok
if num_tokens <= num_remaining_tokens:
# This is a chunk, more tokens remain.
# In the == case, there are no more prompt logprobs to produce
# but we want to defer returning them to the next step where we
# have new generated tokens to return.
num_logits = num_tokens
else:
# This is the last chunk of prompt tokens to return.
num_logits = num_remaining_tokens
completed_prefill_reqs.append(request)
prompt_logprobs_list[request.idx] = logprobs_tensors
if num_logits <= 0:
# This can happen for the final chunk if we prefilled exactly
# (num_prompt_tokens - 1) tokens for this request in the prior
# step. There are no more prompt logprobs to produce.
continue
offset = self.share_inputs["cu_seqlens_q"][request.idx]
prompt_hidden_states = hidden_states[offset : offset + num_logits]
logits = self.model.compute_logits(prompt_hidden_states)
prompt_token_ids = request.prompt_token_ids[start_tok : start_tok + num_logits]
prompt_token_ids_tensor = paddle.to_tensor(prompt_token_ids, dtype="int64")
if logprobs_mode == "raw_logprobs":
raw_logprobs = self.sampler.compute_logprobs(logits)
elif logprobs_mode == "raw_logits":
raw_logprobs = logits
token_ids, logprobs, ranks = self.sampler.gather_logprobs(
raw_logprobs, num_prompt_logprobs, prompt_token_ids_tensor
)
chunk_slice = slice(start_idx, start_idx + num_logits)
logprobs_tensors.logprob_token_ids[chunk_slice].copy_(token_ids, False)
logprobs_tensors.logprobs[chunk_slice].copy_(logprobs, False)
logprobs_tensors.selected_token_ranks[chunk_slice].copy_(ranks, False)
for req in completed_prefill_reqs:
del self.prompt_logprobs_reqs[req.request_id]
del self.in_progress_prompt_logprobs[req.request_id]
return prompt_logprobs_list
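The chunk arithmetic above is the subtle part: a prompt of P tokens yields at most P - 1 logprob rows (position 0 has nothing to score), and each prefill chunk contributes min(chunk length, rows still missing) of them, with the request retired once a chunk overshoots the remaining rows. A standalone sketch of that bookkeeping, independent of the runner state and using made-up chunk sizes:

    # Illustration of the per-chunk row count computed above (not runner code).
    def rows_for_chunk(num_prompt_tokens: int, prefill_start: int, prefill_end: int) -> int:
        num_tokens = prefill_end - prefill_start       # prompt tokens prefilled this step
        start_tok = prefill_start + 1                  # first position that has a logprob
        num_remaining = num_prompt_tokens - start_tok  # rows still missing for this request
        return min(num_tokens, num_remaining)          # matches num_logits above

    # A 9-token prompt prefilled in chunks of 4 tokens:
    #   chunk [0, 4) -> 4 rows (scores prompt positions 1..4)
    #   chunk [4, 8) -> 4 rows (scores prompt positions 5..8), buffer now full
    #   chunk [8, 9) -> 0 rows; num_tokens > num_remaining, so the request is
    #                  completed and its LogprobsTensors handed back via the list
    assert [rows_for_chunk(9, s, min(s + 4, 9)) for s in (0, 4, 8)] == [4, 4, 0]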