diff --git a/fastdeploy/rl/rollout_config.py b/fastdeploy/rl/rollout_config.py index 92bf0723a..40be4d774 100644 --- a/fastdeploy/rl/rollout_config.py +++ b/fastdeploy/rl/rollout_config.py @@ -57,6 +57,7 @@ class RolloutModelConfig: disable_any_whitespace: bool = True, enable_logprob: bool = False, graph_optimization_config: str = None, + local_rank: int = 0, ): # Required parameters self.model_name_or_path = model_name_or_path @@ -97,10 +98,11 @@ class RolloutModelConfig: self.disable_any_whitespace = disable_any_whitespace self.enable_logprob = enable_logprob self.graph_optimization_config = graph_optimization_config + self.local_rank = local_rank def __str__(self): return "\n".join(f"{k}: {v}" for k, v in self.__dict__.items()) def initialize(self): """Initialize the final fd config""" - return initialize_fd_config(self, ranks=self.tensor_parallel_size, local_rank=0) + return initialize_fd_config(self, ranks=self.tensor_parallel_size, local_rank=self.local_rank)