fix import error (#2944)

This commit is contained in:
gaoziyuan
2025-07-22 14:06:01 +08:00
committed by GitHub
parent 8020927f50
commit 0eedbdaee0

View File

@@ -70,20 +70,24 @@ def wrap_unified_op(original_cpp_ext_op, original_custom_op):
original_cpp_ext_op: Original C++ extension operator function.
original_custom_op: Original custom operator function.
"""
try:
@paddle.jit.marker.unified
@functools.wraps(original_custom_op)
def unified_op(*args, **kwargs):
if paddle.in_dynamic_mode():
res = original_cpp_ext_op(*args, **kwargs)
if res is None:
return None
# TODO(DrRyanHuang): Remove this if when we align the implementation of custom op and C++ extension
if isinstance(res, list) and len(res) == 1:
return res[0]
return res
return original_custom_op(*args, **kwargs)
@paddle.jit.marker.unified
@functools.wraps(original_custom_op)
def unified_op(*args, **kwargs):
if paddle.in_dynamic_mode():
res = original_cpp_ext_op(*args, **kwargs)
if res is None:
return None
# TODO(DrRyanHuang): Remove this if when we align the implementation of custom op and C++ extension
if isinstance(res, list) and len(res) == 1:
return res[0]
return res
return original_custom_op(*args, **kwargs)
except:
unified_op = None
logger.warning("Paddle version not support JIT mode.")
return unified_op