mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-05 08:37:06 +08:00
fix rl config local rank (#2957)
This commit is contained in:
@@ -57,6 +57,7 @@ class RolloutModelConfig:
|
||||
disable_any_whitespace: bool = True,
|
||||
enable_logprob: bool = False,
|
||||
graph_optimization_config: str = None,
|
||||
local_rank: int = 0,
|
||||
):
|
||||
# Required parameters
|
||||
self.model_name_or_path = model_name_or_path
|
||||
@@ -97,10 +98,11 @@ class RolloutModelConfig:
|
||||
self.disable_any_whitespace = disable_any_whitespace
|
||||
self.enable_logprob = enable_logprob
|
||||
self.graph_optimization_config = graph_optimization_config
|
||||
self.local_rank = local_rank
|
||||
|
||||
def __str__(self):
|
||||
return "\n".join(f"{k}: {v}" for k, v in self.__dict__.items())
|
||||
|
||||
def initialize(self):
|
||||
"""Initialize the final fd config"""
|
||||
return initialize_fd_config(self, ranks=self.tensor_parallel_size, local_rank=0)
|
||||
return initialize_fd_config(self, ranks=self.tensor_parallel_size, local_rank=self.local_rank)
|
||||
|
Reference in New Issue
Block a user