[Feature] Add temp_scaled_logprobs and top_p_normalized_logprobs parameters for logits and logprobs post processing (#3536)

* [feature] Add temp_scaled_logprobs and top_p_normalized_logprobs parameters for logits and logprobs post processing

* inference engine: support temp_scaled_logprobs and top_p_normalized_logprobs

* code check

* code check

* fix tokenizer.decode(-1) to return 'Invalid Token'

* check seq len time shape

* logprob clip inf

* code check

---------

Co-authored-by: sunlei1024 <sunlei5788@gmail.com>
Author: chen
Date: 2025-08-25 14:11:18 +08:00 (committed by GitHub)
Commit: 2136990144 (parent: b7890cbe8d)
5 changed files with 84 additions and 4 deletions

View File

@@ -95,6 +95,9 @@ class SamplingParams:
reasoning_max_tokens: Optional[int] = None
min_tokens: int = 1
logprobs: Optional[int] = None
# For logits and logprobs post processing
temp_scaled_logprobs: bool = False
top_p_normalized_logprobs: bool = False
bad_words: Optional[List[str]] = None
@classmethod
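
Both new flags default to False, so existing SamplingParams callers are unaffected. A minimal sketch of opting in from Python; the import path is an assumption, and temperature/top_p are taken to be existing SamplingParams fields, as the sampler hunk below suggests:

from fastdeploy.engine.sampling_params import SamplingParams  # assumed module path

# Ask for 5 logprobs per token; scale logits by temperature and renormalize
# the kept top-p mass before taking the log.
params = SamplingParams(
    temperature=0.7,
    top_p=0.8,
    logprobs=5,
    temp_scaled_logprobs=True,
    top_p_normalized_logprobs=True,
)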

View File

@@ -333,6 +333,9 @@ class CompletionRequest(BaseModel):
echo: Optional[bool] = False
frequency_penalty: Optional[float] = None
logprobs: Optional[int] = None
# For logits and logprobs post processing
temp_scaled_logprobs: bool = False
top_p_normalized_logprobs: bool = False
max_tokens: Optional[int] = None
n: int = 1
presence_penalty: Optional[float] = None
@@ -461,6 +464,11 @@ class ChatCompletionRequest(BaseModel):
frequency_penalty: Optional[float] = None
logprobs: Optional[bool] = False
top_logprobs: Optional[int] = 0
# For logits and logprobs post processing
temp_scaled_logprobs: bool = False
top_p_normalized_logprobs: bool = False
# remove max_tokens when field is removed from OpenAI API
max_tokens: Optional[int] = Field(
default=None,
@@ -515,6 +523,8 @@ class ChatCompletionRequest(BaseModel):
req_dict["max_tokens"] = self.max_completion_tokens or self.max_tokens
req_dict["logprobs"] = self.top_logprobs if self.logprobs else None
req_dict["temp_scaled_logprobs"] = self.temp_scaled_logprobs
req_dict["top_p_normalized_logprobs"] = self.top_p_normalized_logprobs
# parse request model into dict, priority: request params > metadata params
if self.metadata is not None:
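
On the serving side, both request models expose the flags as plain booleans that default to False, and ChatCompletionRequest forwards them into the internal request dict alongside logprobs. A hedged example of a request against FastDeploy's OpenAI-compatible server; URL, port, and model name are placeholders:

import requests

resp = requests.post(
    "http://localhost:8188/v1/chat/completions",   # placeholder host/port
    json={
        "model": "default",                        # placeholder model name
        "messages": [{"role": "user", "content": "Hello"}],
        "logprobs": True,
        "top_logprobs": 5,
        "temperature": 0.7,
        "top_p": 0.8,
        "temp_scaled_logprobs": True,
        "top_p_normalized_logprobs": True,
    },
)
print(resp.json()["choices"][0])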

View File

@@ -15,7 +15,7 @@
"""
from dataclasses import dataclass
from typing import Optional
from typing import Dict, Optional
import paddle
@@ -46,3 +46,6 @@ class SamplingMetadata:
max_num_logprobs: Optional[int] = None
prompt_ids: Optional[paddle.Tensor] = None
prompt_lens: Optional[paddle.Tensor] = None
temp_scaled_logprobs: Optional[paddle.Tensor] = None
top_p_normalized_logprobs: Optional[paddle.Tensor] = None
share_inputs: Optional[Dict[str, paddle.Tensor]] = None
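
SamplingMetadata now carries the two flags as per-slot boolean tensors, plus a reference to the runner's share_inputs dict so the sampler can read the seq-lens buffers it needs for masking. A small sketch of the expected shapes, mirroring the GPUModelRunner buffers added later in this commit:

import paddle

max_num_seqs = 4
# One boolean per sequence slot, defaulting to False.
temp_scaled_logprobs = paddle.full([max_num_seqs, 1], False, dtype="bool")
top_p_normalized_logprobs = paddle.full([max_num_seqs, 1], False, dtype="bool")

# Enable the post processing for the request occupying slot 2 only.
idx = 2
temp_scaled_logprobs[idx : idx + 1] = True
top_p_normalized_logprobs[idx : idx + 1] = True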

View File

@@ -37,6 +37,18 @@ from fastdeploy.platforms import current_platform
from fastdeploy.worker.output import LogprobsTensors, SamplerOutput
def top_p_normalize_probs_paddle(
probs: paddle.Tensor,
top_ps: paddle.Tensor,
):
probs_idx = probs.argsort(axis=-1, descending=True)
probs_sort = paddle.take_along_axis(probs, probs_idx, axis=-1)
probs_sum = paddle.cumsum(probs_sort, axis=-1)
probs_sort = paddle.where((probs_sum - probs_sort) > top_ps, paddle.zeros_like(probs_sort), probs_sort)
probs_sort.divide_(probs_sort.sum(axis=-1, keepdim=True))
return paddle.put_along_axis(paddle.zeros_like(probs_sort), probs_idx, probs_sort, -1)
class SamplerProcessor:
"""
SamplingProcessor for guided decoding.
@@ -194,9 +206,45 @@ class Sampler(nn.Layer):
"""pre process before running"""
self.processor.pre_process(skip_idx_list)
def compute_logprobs(self, logits: paddle.Tensor) -> paddle.Tensor:
def compute_logprobs(
self,
logits: paddle.Tensor,
sampling_metadata: SamplingMetadata,
) -> paddle.Tensor:
""" """
return F.log_softmax(logits, axis=-1)
last_logits = logits
real_bsz = last_logits.shape[0]
temp_scaled_logprobs = sampling_metadata.temp_scaled_logprobs
top_p_normalized_logprobs = sampling_metadata.top_p_normalized_logprobs
share_inputs = sampling_metadata.share_inputs
if temp_scaled_logprobs is not None:
real_bsz_temp_scaled = temp_scaled_logprobs[:real_bsz]
temperature = sampling_metadata.temperature[:real_bsz]
temp_temperature = paddle.where(real_bsz_temp_scaled, temperature, paddle.ones_like(temperature))
last_logits = last_logits / temp_temperature
last_logprobs = F.log_softmax(last_logits, axis=-1)
top_p_logprob = None
top_p_req_mask = None
if top_p_normalized_logprobs is not None and share_inputs is not None:
seq_lens_this_time = share_inputs["seq_lens_this_time"].reshape([-1, 1])[:real_bsz]
seq_lens_encoder = share_inputs["seq_lens_encoder"].reshape([-1, 1])[:real_bsz]
seq_lens_decoder = share_inputs["seq_lens_decoder"].reshape([-1, 1])[:real_bsz]
seq_lens_time_sum = seq_lens_this_time + seq_lens_encoder + seq_lens_decoder
real_req_mask = seq_lens_time_sum > 0
top_p_req_mask = paddle.logical_and(top_p_normalized_logprobs[:real_bsz], real_req_mask)
real_req_top_p = sampling_metadata.top_p[:real_bsz]
# Normalize logprobs if top_p normalization is enabled
# NOTE: only normalize logprobs when top_p is set and not equal to 1.0
top_p_req_mask = paddle.logical_and(top_p_req_mask, real_req_top_p != 1.0)
if top_p_req_mask.any():
probs = F.softmax(last_logits, axis=-1)
probs = top_p_normalize_probs_paddle(probs, real_req_top_p)
top_p_logprob = paddle.log(probs)
if top_p_logprob is not None:
last_logprobs = paddle.where(top_p_req_mask, top_p_logprob, last_logprobs)
return last_logprobs
def gather_logprobs(
self,
@@ -221,6 +269,7 @@ class Sampler(nn.Layer):
Sampled token rank tensor, (num tokens)
"""
assert token_ids.dtype == paddle.int64
logprobs.clip_(min=paddle.finfo(logprobs.dtype).min)
# Get with the logprob of the prompt or sampled token.
token_logprobs = paddle.take_along_axis(logprobs, token_ids, axis=-1)
@@ -247,7 +296,7 @@ class Sampler(nn.Layer):
""" """
num_logprobs = sampling_metadata.max_num_logprobs
if num_logprobs is not None:
raw_logprobs = self.compute_logprobs(logits)
raw_logprobs = self.compute_logprobs(logits, sampling_metadata)
logits = self.processor.apply_token_mask(logits, skip_idx_list)
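
In compute_logprobs the work is masked per request over the first real_bsz rows: rows whose temp_scaled flag is set are divided by their temperature (all others by 1.0) before log_softmax, and rows whose top_p flag is set, that belong to a live request (positive seq-lens sum), and that use top_p != 1.0 have their probability mass renormalized over the nucleus before the log is taken. The clip_ added to gather_logprobs then replaces the resulting -inf entries (tokens zeroed out by the nucleus) with the dtype minimum. Below is a self-contained toy run of the nucleus renormalization helper; the function body is copied from the hunk above, and the input values are made up:

import paddle
import paddle.nn.functional as F

def top_p_normalize_probs_paddle(probs, top_ps):
    # Keep each token whose preceding cumulative mass is <= top_p (the nucleus),
    # zero the rest, then renormalize the kept mass to sum to 1.
    probs_idx = probs.argsort(axis=-1, descending=True)
    probs_sort = paddle.take_along_axis(probs, probs_idx, axis=-1)
    probs_sum = paddle.cumsum(probs_sort, axis=-1)
    probs_sort = paddle.where((probs_sum - probs_sort) > top_ps, paddle.zeros_like(probs_sort), probs_sort)
    probs_sort.divide_(probs_sort.sum(axis=-1, keepdim=True))
    return paddle.put_along_axis(paddle.zeros_like(probs_sort), probs_idx, probs_sort, -1)

logits = paddle.to_tensor([[2.0, 1.0, 0.5, -3.0]])   # one request, vocab of 4
probs = F.softmax(logits, axis=-1)                   # ~[0.63, 0.23, 0.14, 0.004]
top_ps = paddle.to_tensor([[0.9]])

norm_probs = top_p_normalize_probs_paddle(probs, top_ps)
print(norm_probs)              # last token falls outside the nucleus and becomes exactly 0
print(paddle.log(norm_probs))  # -inf for dropped tokens, which gather_logprobs clips to finfo.min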

View File

@@ -267,6 +267,10 @@ class GPUModelRunner(ModelRunnerBase):
self.share_inputs["penalty_score"][idx : idx + 1] = request.get("repetition_penalty", 1.0)
self.share_inputs["frequency_score"][idx : idx + 1] = request.get("frequency_penalty", 0.0)
self.share_inputs["presence_score"][idx : idx + 1] = request.get("presence_penalty", 0.0)
self.share_inputs["temp_scaled_logprobs"][idx : idx + 1] = request.get("temp_scaled_logprobs", False)
self.share_inputs["top_p_normalized_logprobs"][idx : idx + 1] = request.get(
"top_p_normalized_logprobs", False
)
self.share_inputs["min_dec_len"][idx : idx + 1] = request.get("min_tokens", 1)
self.share_inputs["max_dec_len"][idx : idx + 1] = request.get(
@@ -431,6 +435,12 @@ class GPUModelRunner(ModelRunnerBase):
self.share_inputs["presence_score"][idx : idx + 1] = get_attr_from_request(
request, "presence_penalty", 0.0
)
self.share_inputs["temp_scaled_logprobs"][idx : idx + 1] = get_attr_from_request(
request, "temp_scaled_logprobs", False
)
self.share_inputs["top_p_normalized_logprobs"][idx : idx + 1] = get_attr_from_request(
request, "top_p_normalized_logprobs", False
)
self.share_inputs["min_dec_len"][idx : idx + 1] = request.get("min_tokens", 1)
self.share_inputs["max_dec_len"][idx : idx + 1] = request.get(
@@ -543,6 +553,8 @@ class GPUModelRunner(ModelRunnerBase):
self.share_inputs["presence_score"] = paddle.full(
[max_num_seqs, 1], self.model_config.presence_score, dtype="float32"
)
self.share_inputs["temp_scaled_logprobs"] = paddle.full([max_num_seqs, 1], False, dtype="bool")
self.share_inputs["top_p_normalized_logprobs"] = paddle.full([max_num_seqs, 1], False, dtype="bool")
self.share_inputs["min_dec_len"] = paddle.full([max_num_seqs, 1], self.model_config.min_length, dtype="int64")
self.share_inputs["max_dec_len"] = paddle.full(
@@ -748,6 +760,9 @@ class GPUModelRunner(ModelRunnerBase):
bad_words_token_ids=self.share_inputs["bad_tokens"],
eos_token_ids=self.share_inputs["eos_token_id"],
max_num_logprobs=20 if self.enable_logprob else None,
temp_scaled_logprobs=self.share_inputs["temp_scaled_logprobs"],
top_p_normalized_logprobs=self.share_inputs["top_p_normalized_logprobs"],
share_inputs=self.share_inputs,
)
def load_model(self) -> None:
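
On the runner side the change is pure plumbing: each request's booleans land in the preallocated [max_num_seqs, 1] bool buffers when the request is inserted, and those buffers, together with the whole share_inputs dict (needed for the seq-lens masks), reach the sampler through SamplingMetadata. A standalone sketch of the per-request temperature scaling this enables, using simplified names rather than the FastDeploy classes themselves:

import paddle
import paddle.nn.functional as F

logits = paddle.to_tensor([[2.0, 1.0, 0.5, -3.0],
                           [0.1, 0.2, 0.3, 0.4]])     # real_bsz = 2, vocab = 4
temperature = paddle.to_tensor([[0.5], [0.5]])
temp_scaled = paddle.to_tensor([[True], [False]])     # only request 0 opted in

# Requests that did not opt in are divided by 1.0, i.e. left untouched.
effective_temp = paddle.where(temp_scaled, temperature, paddle.ones_like(temperature))
logprobs = F.log_softmax(logits / effective_temp, axis=-1)
print(logprobs)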