Mirror of https://github.com/PaddlePaddle/FastDeploy.git
[Feature] Add temp_scaled_logprobs and top_p_normalized_logprobs parameters for logits and logprobs post processing (#3552)
* [feature] Add temp_scaled_logprobs and top_p_normalized_logprobs parameters for logits and logprobs post-processing
* Infer engine: support temp_scaled_logprobs and top_p_normalized_logprobs
* Delete some code
* Code check
* Code check and add docs
* Fix tokenizer.decoder(-1) to return 'Invalid Token'
* Add CI for temp_scaled and top_p logprobs
* Check test
* Check seq len time shape
* Clip inf values in logprobs

Co-authored-by: sunlei1024 <sunlei5788@gmail.com>
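For orientation, a hypothetical client call opting into the new behavior; this assumes the serving layer forwards the two flags as extra per-request sampling parameters, which the diff below does not itself show:

# Hypothetical usage sketch; assumes an OpenAI-compatible FastDeploy endpoint
# and that the two new flags are accepted via extra_body (an assumption, not
# confirmed by this diff).
import openai

client = openai.OpenAI(base_url="http://127.0.0.1:8000/v1", api_key="null")
resp = client.chat.completions.create(
    model="default",
    messages=[{"role": "user", "content": "Hello"}],
    temperature=0.7,
    top_p=0.9,
    logprobs=True,
    top_logprobs=5,
    extra_body={
        "temp_scaled_logprobs": True,       # divide logits by temperature before log_softmax
        "top_p_normalized_logprobs": True,  # renormalize logprobs over the top-p nucleus
    },
)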
@@ -15,7 +15,7 @@
 """
 
 from dataclasses import dataclass
-from typing import Optional
+from typing import Dict, Optional
 
 import paddle
 
@@ -51,3 +51,6 @@ class SamplingMetadata:
     stop_flags: Optional[paddle.Tensor] = None
     prompt_ids: Optional[paddle.Tensor] = None
     prompt_lens: Optional[paddle.Tensor] = None
+    temp_scaled_logprobs: Optional[paddle.Tensor] = None
+    top_p_normalized_logprobs: Optional[paddle.Tensor] = None
+    share_inputs: Optional[Dict[str, paddle.Tensor]] = None
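The three new fields are per-request controls consumed by the sampler below: two boolean masks indexed by batch slot, plus the engine's shared sequence-length buffers. A sketch of plausible contents (shapes are assumptions inferred from the slicing in compute_logprobs, not spelled out in this diff):

import paddle

# Assumed shapes, inferred from the [:real_bsz] slicing and reshape([-1, 1])
# calls in compute_logprobs; the real engine fills these buffers itself.
temp_scaled_logprobs = paddle.to_tensor([[True], [False], [True], [False]])       # [max_num_seqs, 1]
top_p_normalized_logprobs = paddle.to_tensor([[True], [True], [False], [False]])  # [max_num_seqs, 1]
share_inputs = {
    "seq_lens_this_time": paddle.to_tensor([1, 1, 1, 0], dtype="int32"),
    "seq_lens_encoder": paddle.to_tensor([0, 0, 0, 0], dtype="int32"),
    "seq_lens_decoder": paddle.to_tensor([7, 3, 9, 0], dtype="int32"),
}
# A batch slot is "live" when its summed sequence lengths are > 0; the fourth
# slot above is padding and gets masked out of top-p normalization.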
@@ -40,6 +40,18 @@ from fastdeploy.platforms import current_platform
 from fastdeploy.worker.output import LogprobsTensors, SamplerOutput
 
 
+def top_p_normalize_probs_paddle(
+    probs: paddle.Tensor,
+    top_ps: paddle.Tensor,
+):
+    probs_idx = probs.argsort(axis=-1, descending=True)
+    probs_sort = paddle.take_along_axis(probs, probs_idx, axis=-1)
+    probs_sum = paddle.cumsum(probs_sort, axis=-1)
+    probs_sort = paddle.where((probs_sum - probs_sort) > top_ps, paddle.zeros_like(probs_sort), probs_sort)
+    probs_sort.divide_(probs_sort.sum(axis=-1, keepdim=True))
+    return paddle.zeros_like(probs_sort).put_along_axis_(indices=probs_idx, values=probs_sort, axis=-1)
+
+
 class SamplerProcessor:
     """
     SamplingProcessor for guided decoding.
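The new helper is standard nucleus (top-p) renormalization: keep the highest-probability tokens whose cumulative mass stays within top_p, zero the rest, and rescale the survivors to sum to 1. A NumPy re-statement with a worked example (illustrative only; the Paddle version takes a per-request top_ps tensor and broadcasts it):

import numpy as np

def top_p_normalize_probs_np(probs: np.ndarray, top_p: float) -> np.ndarray:
    # Sort descending, then zero every token whose preceding cumulative
    # mass already reached top_p, and renormalize the survivors.
    idx = np.argsort(-probs, axis=-1)
    sorted_probs = np.take_along_axis(probs, idx, axis=-1)
    cumsum = np.cumsum(sorted_probs, axis=-1)
    sorted_probs[(cumsum - sorted_probs) > top_p] = 0.0
    sorted_probs /= sorted_probs.sum(axis=-1, keepdims=True)
    out = np.zeros_like(sorted_probs)
    np.put_along_axis(out, idx, sorted_probs, axis=-1)
    return out

probs = np.array([[0.5, 0.3, 0.15, 0.05]])
print(top_p_normalize_probs_np(probs, top_p=0.7))
# -> approx [[0.625 0.375 0. 0.]]: 0.5 and 0.3 survive and are rescaled by 1/0.8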
@@ -207,9 +219,45 @@ class Sampler(nn.Layer):
         """pre process before running"""
         self.processor.pre_process(skip_idx_list)
 
-    def compute_logprobs(self, logits: paddle.Tensor) -> paddle.Tensor:
+    def compute_logprobs(
+        self,
+        logits: paddle.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> paddle.Tensor:
         """ """
-        return F.log_softmax(logits, axis=-1)
+        last_logits = logits
+        real_bsz = last_logits.shape[0]
+        temp_scaled_logprobs = sampling_metadata.temp_scaled_logprobs
+        top_p_normalized_logprobs = sampling_metadata.top_p_normalized_logprobs
+        share_inputs = sampling_metadata.share_inputs
+        if temp_scaled_logprobs is not None:
+            real_bsz_temp_scaled = temp_scaled_logprobs[:real_bsz]
+            temperature = sampling_metadata.temperature[:real_bsz]
+            temp_temperature = paddle.where(real_bsz_temp_scaled, temperature, paddle.ones_like(temperature))
+            last_logits = last_logits / temp_temperature
+
+        last_logprobs = F.log_softmax(last_logits, axis=-1)
+        top_p_logprob = None
+        top_p_req_mask = None
+
+        if top_p_normalized_logprobs is not None and share_inputs is not None:
+            seq_lens_this_time = share_inputs["seq_lens_this_time"].reshape([-1, 1])[:real_bsz]
+            seq_lens_encoder = share_inputs["seq_lens_encoder"].reshape([-1, 1])[:real_bsz]
+            seq_lens_decoder = share_inputs["seq_lens_decoder"].reshape([-1, 1])[:real_bsz]
+            seq_lens_time_sum = seq_lens_this_time + seq_lens_encoder + seq_lens_decoder
+            real_req_mask = seq_lens_time_sum > 0
+            top_p_req_mask = paddle.logical_and(top_p_normalized_logprobs[:real_bsz], real_req_mask)
+            real_req_top_p = sampling_metadata.top_p[:real_bsz]
+            # Normalize logprobs if top_p normalization is enabled
+            # NOTE: only normalize logprobs when top_p is set and not equal to 1.0
+            top_p_req_mask = paddle.logical_and(top_p_req_mask, real_req_top_p != 1.0)
+            if top_p_req_mask.any():
+                probs = F.softmax(last_logits, axis=-1)
+                probs = top_p_normalize_probs_paddle(probs, real_req_top_p)
+                top_p_logprob = paddle.log(probs)
+        if top_p_logprob is not None:
+            last_logprobs = paddle.where(top_p_req_mask, top_p_logprob, last_logprobs)
+        return last_logprobs
 
     def gather_logprobs(
         self,
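With this change, compute_logprobs reports logprobs from the same distribution the sampler actually draws from for requests that opt in. A minimal sketch of the temperature branch in isolation (same ops as the diff, toy values):

import paddle
import paddle.nn.functional as F

logits = paddle.to_tensor([[2.0, 1.0, 0.0]])
temperature = paddle.to_tensor([[0.5]])

raw = F.log_softmax(logits, axis=-1)                   # logprobs as reported before this change
scaled = F.log_softmax(logits / temperature, axis=-1)  # sharper: matches sampling at T=0.5
print(raw.numpy())     # approx [[-0.41 -1.41 -2.41]]
print(scaled.numpy())  # approx [[-0.14 -2.14 -4.14]]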
@@ -234,6 +282,7 @@ class Sampler(nn.Layer):
             Sampled token rank tensor, (num tokens)
         """
         assert token_ids.dtype == paddle.int64
+        logprobs.clip_(min=paddle.finfo(logprobs.dtype).min)
         # Get with the logprob of the prompt or sampled token.
         token_logprobs = paddle.take_along_axis(logprobs, token_ids, axis=-1)
 
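This clip is the "logprob clip inf" item from the commit message: top-p renormalization zeroes pruned tokens, paddle.log maps them to -inf, and without the clamp those -inf values would flow into the gathered per-token logprobs (toy sketch):

import paddle

logprobs = paddle.log(paddle.to_tensor([[0.625, 0.375, 0.0, 0.0]]))
print(logprobs.numpy())  # approx [[-0.47 -0.98 -inf -inf]]
logprobs.clip_(min=paddle.finfo(logprobs.dtype).min)
print(logprobs.numpy())  # -inf replaced by the dtype's finite minimum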
@@ -260,7 +309,7 @@ class Sampler(nn.Layer):
         """ """
         num_logprobs = sampling_metadata.max_num_logprobs
         if num_logprobs is not None:
-            raw_logprobs = self.compute_logprobs(logits)
+            raw_logprobs = self.compute_logprobs(logits, sampling_metadata)
 
         logits = self.processor.apply_token_mask(logits, skip_idx_list)
 