diff --git a/fastdeploy/rl/rollout_config.py b/fastdeploy/rl/rollout_config.py index 82074b70c..1fe797868 100644 --- a/fastdeploy/rl/rollout_config.py +++ b/fastdeploy/rl/rollout_config.py @@ -60,6 +60,7 @@ class RolloutModelConfig: early_stop_config: str = None, local_rank: int = 0, moba_attention_config: str = None, + data_parallel_size: int = 1, ): # Required parameters self.model = model_name_or_path @@ -95,6 +96,7 @@ class RolloutModelConfig: self.splitwise_role = splitwise_role self.expert_parallel_size = expert_parallel_size self.enable_expert_parallel = enable_expert_parallel + self.data_parallel_size = data_parallel_size self.ori_vocab_size = ori_vocab_size self.quantization = quantization self.guided_decoding_backend = guided_decoding_backend