Mirror of https://github.com/PaddlePaddle/FastDeploy.git
[Fix] Fix VL when importing fastdeploy and fix RL config rank bug (#2953)
* support VL ori_vocab_size
* support trainer_degree in name_mapping
* fix import error
* fix local rank
```diff
@@ -58,6 +58,7 @@ class RolloutModelConfig:
         disable_any_whitespace: bool = True,
         enable_logprob: bool = False,
         graph_optimization_config: str = None,
+        local_rank: int = 0
     ):
         # Required parameters
         self.model_name_or_path = model_name_or_path
```
```diff
@@ -98,10 +99,11 @@ class RolloutModelConfig:
         self.disable_any_whitespace = disable_any_whitespace
         self.enable_logprob = enable_logprob
         self.graph_optimization_config = graph_optimization_config
+        self.local_rank = local_rank

     def __str__(self):
         return "\n".join(f"{k}: {v}" for k, v in self.__dict__.items())

     def initialize(self):
         """Initialize the final fd config"""
-        return initialize_fd_config(self, ranks=self.tensor_parallel_size, local_rank=0)
+        return initialize_fd_config(self, ranks=self.tensor_parallel_size, local_rank=self.local_rank)
```
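For context on why the second hunk matters: `initialize()` previously hard-coded `local_rank=0`, so in a multi-rank RL rollout every tensor-parallel worker built its FD config as if it were rank 0. Below is a minimal, self-contained sketch of the fixed pattern. The simplified `RolloutModelConfig` and the stand-in `initialize_fd_config` model only what is visible in the diff; the `model_name_or_path` and `tensor_parallel_size` parameters shown here are illustrative assumptions, not FastDeploy's full constructor signature.

```python
# Minimal sketch of the pattern this commit fixes. Only the fields
# visible in the diff are modeled; initialize_fd_config below is a
# hypothetical stand-in that records its arguments, whereas the real
# function builds the final FD config for one worker.

def initialize_fd_config(config, ranks: int, local_rank: int) -> dict:
    """Hypothetical stand-in: record which rank this worker was built for."""
    return {"ranks": ranks, "local_rank": local_rank}

class RolloutModelConfig:
    def __init__(
        self,
        model_name_or_path: str,
        tensor_parallel_size: int = 1,
        local_rank: int = 0,  # the parameter added by this commit
    ):
        self.model_name_or_path = model_name_or_path
        self.tensor_parallel_size = tensor_parallel_size
        self.local_rank = local_rank  # stored so initialize() can forward it

    def initialize(self):
        """Initialize the final fd config"""
        # Before the fix this call hard-coded local_rank=0, so every
        # worker in a tensor-parallel group was configured as rank 0.
        return initialize_fd_config(
            self, ranks=self.tensor_parallel_size, local_rank=self.local_rank
        )

# With the fix, each worker's config carries its own rank through.
configs = [
    RolloutModelConfig("some/model", tensor_parallel_size=4, local_rank=r)
    for r in range(4)
]
assert [c.initialize()["local_rank"] for c in configs] == [0, 1, 2, 3]
```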