[Feature] Support return logprob of generated tokens (#2784)
* online chat support logprobs
* check xpu
* check vl_gpu_model_runner
* only cuda support logprob
* get_worker() check platform

Co-authored-by: Jiang-Jia-Jun <163579578+Jiang-Jia-Jun@users.noreply.github.com>
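Once a server built from this commit is running, logprobs can be requested through the OpenAI-compatible chat endpoint. The snippet below is a minimal sketch only: the base URL, port, and model name are placeholders rather than values taken from this commit, and the logprobs/top_logprobs parameters follow the standard OpenAI chat-completions convention.

# Hedged sketch: querying a local OpenAI-compatible endpoint for logprobs.
# Host, port, and model name are placeholders, not values from this commit.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8188/v1", api_key="EMPTY")

resp = client.chat.completions.create(
    model="default",
    messages=[{"role": "user", "content": "Say hello"}],
    logprobs=True,       # ask the server to return per-token logprobs
    top_logprobs=5,      # and the top-5 alternatives for each position
)

choice = resp.choices[0]
for tok in choice.logprobs.content:
    print(tok.token, tok.logprob,
          [(t.token, t.logprob) for t in tok.top_logprobs])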
@@ -29,6 +29,7 @@ from fastdeploy.model_executor.layers.sample.ops import (
     apply_penalty_multi_scores, apply_speculative_penalty_multi_scores,
     top_k_top_p_sampling)
 from fastdeploy.platforms import current_platform
+from fastdeploy.worker.output import LogprobsTensors, SamplerOutput
 
 
 class SamplerProcessor:
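LogprobsTensors and SamplerOutput live in fastdeploy/worker/output.py, which is outside this hunk. Purely as a reading aid, the sketch below shows plausible shapes for the two containers based on how they are constructed later in this diff; the field names are illustrative assumptions, not the real definitions.

# Illustrative stand-ins only; see fastdeploy/worker/output.py for the real classes.
from dataclasses import dataclass
from typing import Optional
import paddle

@dataclass
class LogprobsTensors:
    # Constructed positionally below as (indices, top_logprobs, token_ranks);
    # these field names are assumptions.
    logprob_token_ids: paddle.Tensor     # (num tokens) x (num_logprobs + 1), int64
    logprobs: paddle.Tensor              # (num tokens) x (num_logprobs + 1), float
    selected_token_ranks: paddle.Tensor  # (num tokens,), rank of the sampled token

@dataclass
class SamplerOutput:
    sampled_token_ids: paddle.Tensor             # [num_requests, 1]
    logprobs_tensors: Optional[LogprobsTensors]  # None when logprobs were not requested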
@@ -188,14 +189,65 @@ class Sampler(nn.Layer):
         """ pre process before running """
         self.processor.pre_process(skip_idx_list)
 
+    def compute_logprobs(self, logits: paddle.Tensor) -> paddle.Tensor:
+        """Compute log-probabilities over the vocabulary from raw logits."""
+        return F.log_softmax(logits, axis=-1)
+
+    def gather_logprobs(
+        self,
+        logprobs: paddle.Tensor,
+        num_logprobs: int,
+        token_ids: paddle.Tensor,
+    ) -> LogprobsTensors:
+        """
+        Gather logprobs for the top-k and the sampled/prompt token.
+
+        Args:
+            logprobs: (num tokens) x (vocab) tensor
+            num_logprobs: minimum number of logprobs to retain per token
+            token_ids: prompt tokens (if prompt logprobs) or sampled tokens
+                (if sampled logprobs); 1D int64 token ID tensor with
+                (num tokens) elements
+
+        Returns:
+            Top-k int indices tensor, (num tokens) x (num_logprobs + 1)
+            Top-k float logprobs tensor, (num tokens) x (num_logprobs + 1)
+            Sampled token rank tensor, (num tokens)
+        """
+        assert token_ids.dtype == paddle.int64
+        # Get the logprob of the prompt or sampled token.
+        token_logprobs = paddle.take_along_axis(logprobs, token_ids, axis=-1)
+
+        # Compute the rank of each actual token (1 == most likely).
+        token_ranks = (logprobs >= token_logprobs).sum(-1)
+
+        if num_logprobs >= 1:
+            # Find the top-k values.
+            topk_logprobs, topk_indices = paddle.topk(logprobs,
+                                                      num_logprobs,
+                                                      axis=-1)
+            indices = paddle.concat([token_ids, topk_indices], axis=1)
+            top_logprobs = paddle.concat([token_logprobs, topk_logprobs],
+                                         axis=1)
+        else:
+            indices = token_ids
+            top_logprobs = token_logprobs
+
+        return LogprobsTensors(indices, top_logprobs, token_ranks)
+
     def forward_cuda(
         self,
         logits: paddle.Tensor,
         sampling_metadata: SamplingMetadata,
         skip_idx_list: List[int] = [],
-    ) -> paddle.Tensor:
+    ) -> SamplerOutput:
         """
         """
+        num_logprobs = sampling_metadata.max_num_logprobs
+        if num_logprobs is not None:
+            # Logprobs are reported for the raw distribution, computed before
+            # token masking and penalties are applied below.
+            raw_logprobs = self.compute_logprobs(logits)
+
         logits = self.processor.apply_token_mask(logits, skip_idx_list)
 
         logits = apply_penalty_multi_scores(
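The two new helpers are easiest to follow on a toy batch. The standalone sketch below repeats the same tensor operations (log-softmax, gather, rank, top-k, concat) for two requests over a five-token vocabulary; it does not import the Sampler class, it only mirrors the ops used above.

import paddle
import paddle.nn.functional as F

# Toy logits: 2 requests x 5-token vocabulary.
logits = paddle.to_tensor([[2.0, 1.0, 0.5, 0.1, -1.0],
                           [0.2, 3.0, 0.1, 0.0, -2.0]])
# Pretend these tokens were sampled (must be int64, shape [num_tokens, 1]).
token_ids = paddle.to_tensor([[1], [1]], dtype="int64")
num_logprobs = 2

# compute_logprobs: full log-softmax over the vocabulary.
logprobs = F.log_softmax(logits, axis=-1)

# gather_logprobs: logprob of the sampled token ...
token_logprobs = paddle.take_along_axis(logprobs, token_ids, axis=-1)
# ... its 1-based rank (1 means it was the most likely token) ...
token_ranks = (logprobs >= token_logprobs).sum(-1)
# ... and the top-k alternatives, prepended with the sampled token itself.
topk_logprobs, topk_indices = paddle.topk(logprobs, num_logprobs, axis=-1)
indices = paddle.concat([token_ids, topk_indices], axis=1)
top_logprobs = paddle.concat([token_logprobs, topk_logprobs], axis=1)

print(indices.shape, top_logprobs.shape, token_ranks.numpy())
# Expected shapes: [2, 3] and [2, 3]; ranks [2, 1] for this toy input.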
@@ -215,8 +267,19 @@ class Sampler(nn.Layer):
 
         _, next_tokens = top_k_top_p_sampling(probs, sampling_metadata.top_p, sampling_metadata.top_k)
 
+        logprobs_tensors = None if num_logprobs is None else \
+            self.gather_logprobs(raw_logprobs, num_logprobs, token_ids=next_tokens)
+
         self.processor.update_output_tokens(next_tokens, skip_idx_list)
-        return next_tokens
+
+        sampler_output = SamplerOutput(
+            # The sampled tokens are expanded to a 2D tensor with shape
+            # [num_requests, 1], where each row holds the single token
+            # generated for that request.
+            sampled_token_ids=next_tokens,
+            logprobs_tensors=logprobs_tensors,
+        )
+        return sampler_output
 
 
 class SpeculativeSampler(nn.Layer):
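Downstream of the sampler, the serving layer can unpack the gathered tensors into the per-token logprob entries of an OpenAI-style response. The fragment below is only an illustration of that conversion: decode stands in for whatever tokenizer helper the server actually uses, and the three tensor arguments correspond to the values returned positionally by gather_logprobs above.

# Hedged sketch: turning one request's gathered row into an OpenAI-style entry.
# `indices`, `top_logprobs`, `token_ranks` are the three tensors returned by
# gather_logprobs above; `decode` is a hypothetical tokenizer helper.
def to_logprob_entry(indices, top_logprobs, token_ranks, row, decode):
    ids = indices[row].numpy().tolist()          # [sampled, top-1, ..., top-k]
    lps = top_logprobs[row].numpy().tolist()
    return {
        "token": decode(ids[0]),                 # the sampled token comes first
        "logprob": lps[0],
        "rank": int(token_ranks[row]),           # 1 == it was the most likely token
        "top_logprobs": [
            {"token": decode(i), "logprob": lp} for i, lp in zip(ids[1:], lps[1:])
        ],
    }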