mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-12-24 13:28:13 +08:00
[LogProbs]Enable prompt logprobs output and modify data transmission method for the online interface. (#5089)
* add prompt logprobs * Merge prompt_logprobs_tensors and prompt_logprobs * fix param check * trigger ci * fix unitest * fix logprobs bug
This commit is contained in:
@@ -16,12 +16,13 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import random
|
||||
from dataclasses import dataclass, fields
|
||||
from enum import Enum
|
||||
from typing import Any, List, Optional, Union
|
||||
|
||||
from fastdeploy import envs
|
||||
|
||||
|
||||
@dataclass
|
||||
class SamplingParams:
|
||||
@@ -207,12 +208,17 @@ class SamplingParams:
|
||||
raise ValueError(
|
||||
f"min_tokens must be less than or equal to " f"max_tokens={self.max_tokens}, got {self.min_tokens}."
|
||||
)
|
||||
if self.logprobs is not None and self.logprobs < -1:
|
||||
raise ValueError(f"logprobs must be greater than -1, got {self.logprobs}.")
|
||||
if self.logprobs is not None and self.logprobs > 20 and os.getenv("FD_USE_GET_SAVE_OUTPUT_V1", "0") == "0":
|
||||
raise ValueError("Invalid value for 'top_logprobs': must be less than or equal to 20.")
|
||||
if self.prompt_logprobs is not None and self.prompt_logprobs < -1:
|
||||
raise ValueError(f"prompt_logprobs must be greater than or equal to -1, got {self.prompt_logprobs}.")
|
||||
|
||||
if not envs.FD_USE_GET_SAVE_OUTPUT_V1: # False (0)
|
||||
if self.logprobs is not None and (self.logprobs < 0 or self.logprobs > 20):
|
||||
raise ValueError("Invalid value for 'top_logprobs': must be between 0 and 20.")
|
||||
if self.prompt_logprobs is not None:
|
||||
raise ValueError("prompt_logprobs is not support when FD_USE_GET_SAVE_OUTPUT_V1 is disabled.")
|
||||
else: # True (1)
|
||||
if self.logprobs is not None and self.logprobs < -1:
|
||||
raise ValueError(f"logprobs must be a non-negative value or -1, got {self.logprobs}.")
|
||||
if self.prompt_logprobs is not None and self.prompt_logprobs < -1:
|
||||
raise ValueError(f"prompt_logprobs a must be non-negative value or -1, got {self.prompt_logprobs}.")
|
||||
|
||||
if not 0 <= self.seed <= 922337203685477580:
|
||||
raise ValueError("seed must be in [0, 922337203685477580], got " f"{self.seed}.")
|
||||
|
||||
Reference in New Issue
Block a user