mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-05 00:33:03 +08:00
[Feature] Online Chat API Support Return logprobs (#2777)
* Online chat API now supports returning logprobs. * Added XPU support checks. * Updated checks in vl_gpu_model_runner and xpu_model_runner. * get_worker() now checks the current platform.
This commit is contained in:
@@ -15,11 +15,80 @@
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
from typing import NamedTuple, Optional
|
||||
|
||||
import paddle
|
||||
|
||||
|
||||
class LogprobsLists(NamedTuple):
    """Per-request top-k logprob results in plain Python list form.

    This is the list counterpart of ``LogprobsTensors`` (see its
    ``tolists()``), suitable for building API responses without any
    tensor dependency.
    """

    # [num_reqs, max_num_logprobs + 1]
    logprob_token_ids: list[list[int]]
    # [num_reqs, max_num_logprobs + 1]
    logprobs: list[list[float]]
    # [num_reqs]
    sampled_token_ranks: list[int]

    def slice(self, start: int, end: int):
        """Return a new ``LogprobsLists`` restricted to requests ``[start, end)``.

        Args:
            start: first request index to keep (inclusive).
            end: request index to stop at (exclusive).

        Returns:
            LogprobsLists: a shallow slice of all three per-request lists.
        """
        return LogprobsLists(
            self.logprob_token_ids[start:end],
            self.logprobs[start:end],
            self.sampled_token_ranks[start:end],
        )
||||
class LogprobsTensors(NamedTuple):
    """Per-request top-k logprob results held as paddle tensors.

    Convert with ``tolists()`` to the list-based ``LogprobsLists`` before
    serializing into a response.
    """

    # [num_reqs, max_num_logprobs + 1]
    logprob_token_ids: paddle.Tensor
    # [num_reqs, max_num_logprobs + 1]
    logprobs: paddle.Tensor
    # [num_reqs]
    selected_token_ranks: paddle.Tensor

    def tolists(self):
        """Convert each tensor to nested Python lists.

        Returns:
            LogprobsLists: the same data with every tensor replaced by the
            result of its ``tolist()``.
        """
        return LogprobsLists(
            self.logprob_token_ids.tolist(),
            self.logprobs.tolist(),
            self.selected_token_ranks.tolist(),
        )

    @staticmethod
    def empty_cpu(num_positions: int,
                  num_tokens_per_position: int) -> "LogprobsTensors":
        """Create empty LogprobsTensors on CPU.

        Args:
            num_positions: number of rows (requests/positions) to allocate.
            num_tokens_per_position: number of logprob entries per row.

        Returns:
            LogprobsTensors: uninitialized int64/float32 tensors moved to CPU.
        """
        # Allocate on the default device, then move to CPU with .cpu().
        logprob_token_ids = paddle.empty(
            [num_positions, num_tokens_per_position],
            dtype=paddle.int64).cpu()
        # empty_like follows logprob_token_ids' placement (already CPU).
        logprobs = paddle.empty_like(logprob_token_ids, dtype=paddle.float32)
        selected_token_ranks = paddle.empty([num_positions],
                                            dtype=paddle.int64).cpu()
        return LogprobsTensors(
            logprob_token_ids=logprob_token_ids,
            logprobs=logprobs,
            selected_token_ranks=selected_token_ranks,
        )
||||
@dataclass
class SamplerOutput:
    """Output of one sampling step for a batch of requests."""

    # [num_reqs, max_num_generated_tokens]
    # Different requests can have different number of generated tokens.
    # All requests are padded to max_num_generated_tokens.
    # PLACEHOLDER_TOKEN_ID (-1 by default) is used for padding.
    sampled_token_ids: paddle.Tensor
    # Top-k logprob tensors for this step; None when not computed.
    logprobs_tensors: Optional[LogprobsTensors]
||||
@dataclass
|
||||
class ModelOutputData:
|
||||
"""
|
||||
|
Reference in New Issue
Block a user