mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-06 17:17:14 +08:00
fix enable_logprob not in rl_config (#2808)
This commit is contained in:
@@ -58,6 +58,7 @@ class RolloutModelConfig:
|
|||||||
max_capture_batch_size: int = 64,
|
max_capture_batch_size: int = 64,
|
||||||
guided_decoding_backend: str = "off",
|
guided_decoding_backend: str = "off",
|
||||||
disable_any_whitespace: bool = True,
|
disable_any_whitespace: bool = True,
|
||||||
|
enable_logprob: bool = False,
|
||||||
):
|
):
|
||||||
# Required parameters
|
# Required parameters
|
||||||
self.model_name_or_path = model_name_or_path
|
self.model_name_or_path = model_name_or_path
|
||||||
@@ -99,6 +100,7 @@ class RolloutModelConfig:
|
|||||||
self.max_capture_batch_size = max_capture_batch_size
|
self.max_capture_batch_size = max_capture_batch_size
|
||||||
self.guided_decoding_backend = guided_decoding_backend
|
self.guided_decoding_backend = guided_decoding_backend
|
||||||
self.disable_any_whitespace = disable_any_whitespace
|
self.disable_any_whitespace = disable_any_whitespace
|
||||||
|
self.enable_logprob = enable_logprob
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
return "\n".join(f"{k}: {v}" for k, v in self.__dict__.items())
|
return "\n".join(f"{k}: {v}" for k, v in self.__dict__.items())
|
||||||
|
Reference in New Issue
Block a user