mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-05 16:48:03 +08:00
fix rl config local rank (#2957)
This commit is contained in:
@@ -57,6 +57,7 @@ class RolloutModelConfig:
|
|||||||
disable_any_whitespace: bool = True,
|
disable_any_whitespace: bool = True,
|
||||||
enable_logprob: bool = False,
|
enable_logprob: bool = False,
|
||||||
graph_optimization_config: str = None,
|
graph_optimization_config: str = None,
|
||||||
|
local_rank: int = 0,
|
||||||
):
|
):
|
||||||
# Required parameters
|
# Required parameters
|
||||||
self.model_name_or_path = model_name_or_path
|
self.model_name_or_path = model_name_or_path
|
||||||
@@ -97,10 +98,11 @@ class RolloutModelConfig:
|
|||||||
self.disable_any_whitespace = disable_any_whitespace
|
self.disable_any_whitespace = disable_any_whitespace
|
||||||
self.enable_logprob = enable_logprob
|
self.enable_logprob = enable_logprob
|
||||||
self.graph_optimization_config = graph_optimization_config
|
self.graph_optimization_config = graph_optimization_config
|
||||||
|
self.local_rank = local_rank
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
return "\n".join(f"{k}: {v}" for k, v in self.__dict__.items())
|
return "\n".join(f"{k}: {v}" for k, v in self.__dict__.items())
|
||||||
|
|
||||||
def initialize(self):
|
def initialize(self):
|
||||||
"""Initialize the final fd config"""
|
"""Initialize the final fd config"""
|
||||||
return initialize_fd_config(self, ranks=self.tensor_parallel_size, local_rank=0)
|
return initialize_fd_config(self, ranks=self.tensor_parallel_size, local_rank=self.local_rank)
|
||||||
|
Reference in New Issue
Block a user