fix rl config local rank (#2957)

This commit is contained in:
gaoziyuan
2025-07-22 19:39:54 +08:00
committed by GitHub
parent 9b84d51e25
commit dbe6225b33

View File

@@ -57,6 +57,7 @@ class RolloutModelConfig:
disable_any_whitespace: bool = True, disable_any_whitespace: bool = True,
enable_logprob: bool = False, enable_logprob: bool = False,
graph_optimization_config: str = None, graph_optimization_config: str = None,
local_rank: int = 0,
): ):
# Required parameters # Required parameters
self.model_name_or_path = model_name_or_path self.model_name_or_path = model_name_or_path
@@ -97,10 +98,11 @@ class RolloutModelConfig:
self.disable_any_whitespace = disable_any_whitespace self.disable_any_whitespace = disable_any_whitespace
self.enable_logprob = enable_logprob self.enable_logprob = enable_logprob
self.graph_optimization_config = graph_optimization_config self.graph_optimization_config = graph_optimization_config
self.local_rank = local_rank
def __str__(self): def __str__(self):
return "\n".join(f"{k}: {v}" for k, v in self.__dict__.items()) return "\n".join(f"{k}: {v}" for k, v in self.__dict__.items())
def initialize(self): def initialize(self):
"""Initialize the final fd config""" """Initialize the final fd config"""
return initialize_fd_config(self, ranks=self.tensor_parallel_size, local_rank=0) return initialize_fd_config(self, ranks=self.tensor_parallel_size, local_rank=self.local_rank)