mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-05 08:37:06 +08:00
[Feature] Support return logprob of generated tokens (#2784)
* online chat support logprobs
* check xpu
* check vl_gpu_model_runner
* only cuda support logprob
* get_worker() check platform

---------

Co-authored-by: Jiang-Jia-Jun <163579578+Jiang-Jia-Jun@users.noreply.github.com>
This commit is contained in:
@@ -173,6 +173,13 @@ class SamplingParams:
|
||||
f"temperature must be non-negative, got {self.temperature}.")
|
||||
if self.top_p is not None and not 0.0 <= self.top_p <= 1.0:
|
||||
raise ValueError(f"top_p must be in [0, 1], got {self.top_p}.")
|
||||
# quietly accept -1 as disabled, but prefer 0
|
||||
if self.top_k < -1:
|
||||
raise ValueError(f"top_k must be 0 (disable), or at least 1, "
|
||||
f"got {self.top_k}.")
|
||||
if not isinstance(self.top_k, int):
|
||||
raise TypeError(
|
||||
f"top_k must be an integer, got {type(self.top_k).__name__}")
|
||||
|
||||
if self.max_tokens is not None and self.max_tokens < 1:
|
||||
raise ValueError(
|
||||
@@ -192,6 +199,9 @@ class SamplingParams:
|
||||
if self.logprobs is not None and self.logprobs < 0:
|
||||
raise ValueError(
|
||||
f"logprobs must be non-negative, got {self.logprobs}.")
|
||||
if self.logprobs is not None and self.logprobs > 20:
|
||||
raise ValueError(
|
||||
"Invalid value for 'top_logprobs': must be less than or equal to 20.")
|
||||
|
||||
if not 0 <= self.seed <= 922337203685477580:
|
||||
raise ValueError("seed must be in [0, 922337203685477580], got "
|
||||
|
Reference in New Issue
Block a user