diff --git a/fastdeploy/import_ops.py b/fastdeploy/import_ops.py
index 01fe251e9..f04cd1bc7 100644
--- a/fastdeploy/import_ops.py
+++ b/fastdeploy/import_ops.py
@@ -43,8 +43,7 @@ def import_custom_ops(package, module_name, global_ns):
             logger.warning(f"Failed to import op {func_name}: {e}")

     except Exception:
-        logger.warning(
-            f"Ops of {package} import failed, it may be not compiled.")
+        logger.warning(f"Ops of {package} import failed, it may be not compiled.")

     preprocess_static_op(global_ns)

@@ -71,20 +70,24 @@ def wrap_unified_op(original_cpp_ext_op, original_custom_op):
         original_cpp_ext_op: Original C++ extension operator function.
         original_custom_op: Original custom operator function.
     """
+    try:

-    @paddle.jit.marker.unified
-    @functools.wraps(original_custom_op)
-    def unified_op(*args, **kwargs):
-        if paddle.in_dynamic_mode():
-            res = original_cpp_ext_op(*args, **kwargs)
-            if res is None:
-                return None
-            # TODO(DrRyanHuang): Remove this if when we align the implementation of custom op and C++ extension
-            if isinstance(res, list) and len(res) == 1:
-                return res[0]
-            return res
-        return original_custom_op(*args, **kwargs)
+        @paddle.jit.marker.unified
+        @functools.wraps(original_custom_op)
+        def unified_op(*args, **kwargs):
+            if paddle.in_dynamic_mode():
+                res = original_cpp_ext_op(*args, **kwargs)
+                if res is None:
+                    return None
+                # TODO(DrRyanHuang): Remove this if when we align the implementation of custom op and C++ extension
+                if isinstance(res, list) and len(res) == 1:
+                    return res[0]
+                return res
+            return original_custom_op(*args, **kwargs)
+    except:
+        unified_op = None
+        logger.warning("Paddle version not support JIT mode.")

     return unified_op

diff --git a/fastdeploy/rl/rollout_config.py b/fastdeploy/rl/rollout_config.py
index dcb95ea2d..acf01da8f 100644
--- a/fastdeploy/rl/rollout_config.py
+++ b/fastdeploy/rl/rollout_config.py
@@ -58,6 +58,7 @@ class RolloutModelConfig:
         disable_any_whitespace: bool = True,
         enable_logprob: bool = False,
         graph_optimization_config: str = None,
+        local_rank: int = 0
     ):
         # Required parameters
         self.model_name_or_path = model_name_or_path
@@ -98,10 +99,11 @@ class RolloutModelConfig:
         self.disable_any_whitespace = disable_any_whitespace
         self.enable_logprob = enable_logprob
         self.graph_optimization_config = graph_optimization_config
+        self.local_rank = local_rank

     def __str__(self):
         return "\n".join(f"{k}: {v}" for k, v in self.__dict__.items())

     def initialize(self):
         """Initialize the final fd config"""
-        return initialize_fd_config(self, ranks=self.tensor_parallel_size, local_rank=0)
+        return initialize_fd_config(self, ranks=self.tensor_parallel_size, local_rank=self.local_rank)
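
Note on the `import_ops.py` hunk: wrapping the whole decorator block in `try`/`except` means `wrap_unified_op` can now return `None` on Paddle builds that do not expose `paddle.jit.marker.unified`, instead of raising at import time. Below is a minimal, self-contained sketch of that fallback pattern; the Paddle-specific pieces are replaced with hypothetical stand-ins, so this is an illustration of the pattern, not the library's actual code path.

```python
import functools
import logging

logger = logging.getLogger(__name__)


# Hypothetical stand-ins for the real ops; the C++ extension op returns a
# single-element list, which the wrapper unwraps.
def cpp_ext_op(x):
    return [x * 2]


def custom_op(x):
    return x * 2


def wrap_unified_op_sketch(original_cpp_ext_op, original_custom_op, marker=None):
    """Sketch of the fallback: if the JIT marker decorator (in FastDeploy,
    paddle.jit.marker.unified) is unavailable, return None instead of raising."""
    try:
        if marker is None:  # e.g. a Paddle build without paddle.jit.marker.unified
            raise AttributeError("JIT marker not available")

        @marker
        @functools.wraps(original_custom_op)
        def unified_op(*args, **kwargs):
            res = original_cpp_ext_op(*args, **kwargs)
            # Unwrap single-element lists so both paths return the same shape.
            if isinstance(res, list) and len(res) == 1:
                return res[0]
            return res
    except Exception:
        unified_op = None
        logger.warning("Paddle version does not support JIT mode.")
    return unified_op


# Callers now have to handle the None case and keep the plain custom op.
op = wrap_unified_op_sketch(cpp_ext_op, custom_op) or custom_op
print(op(3))  # -> 6
```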
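
Note on the `rollout_config.py` hunks: `local_rank` is now a constructor parameter that `initialize()` forwards to `initialize_fd_config`, rather than hard-coding rank 0. A hedged usage sketch follows; the model path and the `tensor_parallel_size` keyword are assumptions, since only the tail of the constructor signature is visible in the hunk.

```python
from fastdeploy.rl.rollout_config import RolloutModelConfig

# Build one config per local device; each worker's rank now reaches
# initialize_fd_config through the new local_rank field.
configs = [
    RolloutModelConfig(
        model_name_or_path="./my_model",  # hypothetical path
        tensor_parallel_size=2,           # assumed constructor argument
        local_rank=rank,                  # new parameter added in this diff
    )
    for rank in range(2)
]
fd_configs = [cfg.initialize() for cfg in configs]
```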