mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-12-24 13:28:13 +08:00
[Logprobs]Support prompt_logprobs and max_logprobs (#4897)
* add prompt logprobs * trigger ci * fix unitest * Update fastdeploy/config.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Update fastdeploy/entrypoints/llm.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Update fastdeploy/engine/sampling_params.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Update tests/engine/test_sampling_params.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Update tests/engine/test_sampling_params.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * fix max_logprobs --------- Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
@@ -16,6 +16,7 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import random
|
||||
from dataclasses import dataclass, fields
|
||||
from enum import Enum
|
||||
@@ -206,10 +207,12 @@ class SamplingParams:
|
||||
raise ValueError(
|
||||
f"min_tokens must be less than or equal to " f"max_tokens={self.max_tokens}, got {self.min_tokens}."
|
||||
)
|
||||
if self.logprobs is not None and self.logprobs < 0:
|
||||
raise ValueError(f"logprobs must be non-negative, got {self.logprobs}.")
|
||||
if self.logprobs is not None and self.logprobs > 20:
|
||||
if self.logprobs is not None and self.logprobs < -1:
|
||||
raise ValueError(f"logprobs must be greater than -1, got {self.logprobs}.")
|
||||
if self.logprobs is not None and self.logprobs > 20 and os.getenv("FD_USE_GET_SAVE_OUTPUT_V1", "0") == "0":
|
||||
raise ValueError("Invalid value for 'top_logprobs': must be less than or equal to 20.")
|
||||
if self.prompt_logprobs is not None and self.prompt_logprobs < -1:
|
||||
raise ValueError(f"prompt_logprobs must be greater than or equal to -1, got {self.prompt_logprobs}.")
|
||||
|
||||
if not 0 <= self.seed <= 922337203685477580:
|
||||
raise ValueError("seed must be in [0, 922337203685477580], got " f"{self.seed}.")
|
||||
|
||||
Reference in New Issue
Block a user