From dbe6225b33fc47f582126b8ca8f9416894c6bcac Mon Sep 17 00:00:00 2001 From: gaoziyuan <88373061+gzy19990617@users.noreply.github.com> Date: Tue, 22 Jul 2025 19:39:54 +0800 Subject: [PATCH] fix rl config local rank (#2957) --- fastdeploy/rl/rollout_config.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/fastdeploy/rl/rollout_config.py b/fastdeploy/rl/rollout_config.py index 92bf0723a..40be4d774 100644 --- a/fastdeploy/rl/rollout_config.py +++ b/fastdeploy/rl/rollout_config.py @@ -57,6 +57,7 @@ class RolloutModelConfig: disable_any_whitespace: bool = True, enable_logprob: bool = False, graph_optimization_config: str = None, + local_rank: int = 0, ): # Required parameters self.model_name_or_path = model_name_or_path @@ -97,10 +98,11 @@ class RolloutModelConfig: self.disable_any_whitespace = disable_any_whitespace self.enable_logprob = enable_logprob self.graph_optimization_config = graph_optimization_config + self.local_rank = local_rank def __str__(self): return "\n".join(f"{k}: {v}" for k, v in self.__dict__.items()) def initialize(self): """Initialize the final fd config""" - return initialize_fd_config(self, ranks=self.tensor_parallel_size, local_rank=0) + return initialize_fd_config(self, ranks=self.tensor_parallel_size, local_rank=self.local_rank)