diff --git a/fastdeploy/rl/rollout_config.py b/fastdeploy/rl/rollout_config.py index 045214f76..4e9de922c 100644 --- a/fastdeploy/rl/rollout_config.py +++ b/fastdeploy/rl/rollout_config.py @@ -58,6 +58,7 @@ class RolloutModelConfig: max_capture_batch_size: int = 64, guided_decoding_backend: str = "off", disable_any_whitespace: bool = True, + enable_logprob: bool = False, ): # Required parameters self.model_name_or_path = model_name_or_path @@ -99,6 +100,7 @@ class RolloutModelConfig: self.max_capture_batch_size = max_capture_batch_size self.guided_decoding_backend = guided_decoding_backend self.disable_any_whitespace = disable_any_whitespace + self.enable_logprob = enable_logprob def __str__(self): return "\n".join(f"{k}: {v}" for k, v in self.__dict__.items())