mirror of
				https://github.com/PaddlePaddle/FastDeploy.git
				synced 2025-10-26 18:10:32 +08:00 
			
		
		
		
	
		
			
				
	
	
		
			114 lines
		
	
	
		
			4.1 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			114 lines
		
	
	
		
			4.1 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
 | |
| #
 | |
| # Licensed under the Apache License, Version 2.0 (the "License");
 | |
| # you may not use this file except in compliance with the License.
 | |
| # You may obtain a copy of the License at
 | |
| #
 | |
| #     http://www.apache.org/licenses/LICENSE-2.0
 | |
| #
 | |
| # Unless required by applicable law or agreed to in writing, software
 | |
| # distributed under the License is distributed on an "AS IS" BASIS,
 | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | |
| # See the License for the specific language governing permissions and
 | |
| # limitations under the License.
 | |
| """import ops"""
 | |
| import functools
 | |
| import importlib
 | |
| import inspect
 | |
| 
 | |
| import paddle
 | |
| 
 | |
| from fastdeploy.utils import llm_logger as logger
 | |
| 
 | |
| 
 | |
def import_custom_ops(package, module_name, global_ns):
    """
    Imports custom operations from a specified module within a package and adds them to a global namespace.

    Every member of the imported module whose name does not start with ``__``
    and is not ``_C_ops`` is copied into ``global_ns`` under the same name.
    Failures are logged as warnings rather than raised, so a missing or
    uncompiled op library does not break the caller's import. After the import
    attempt (successful or not), ``preprocess_static_op`` is run on ``global_ns``.

    Args:
        package (str): The name of the package containing the module.
        module_name (str): The name of the module within the package.
        global_ns (dict): The global namespace to add the imported functions to.
    """
    try:
        module = importlib.import_module(module_name, package=package)
        functions = inspect.getmembers(module)
        for func_name, func in functions:
            # Skip dunder attributes and the raw _C_ops handle.
            if func_name.startswith("__") or func_name == "_C_ops":
                continue
            logger.debug(f"Import {func_name} from {package}")
            try:
                global_ns[func_name] = func
            except Exception as e:
                logger.warning(f"Failed to import op {func_name}: {e}")

    except Exception as e:
        # Best-effort: keep going, but surface the actual cause instead of only
        # guessing that the ops were not compiled.
        logger.warning(f"Ops of {package} import failed, it may be not compiled. Reason: {e}")

    preprocess_static_op(global_ns)
 | |
| 
 | |
| 
 | |
def rename_imported_op(old_name, new_name, global_ns):
    """
    Renames an imported operation in the global namespace.

    Does nothing if ``old_name`` is not present. If ``new_name`` already exists
    it is overwritten. Renaming an operation to its own name leaves it in place.

    Args:
        old_name (str): The original name of the operation in the global namespace.
        new_name (str): The new name to be given to the operation.
        global_ns (dict): The global namespace where the operation is stored.
    """
    if old_name not in global_ns:
        return
    # Remove first, then bind: with assign-then-delete, old_name == new_name
    # would delete the entry that was just (re)assigned.
    global_ns[new_name] = global_ns.pop(old_name)
 | |
| 
 | |
| 
 | |
def wrap_unified_op(original_cpp_ext_op, original_custom_op):
    """
    Wrap a static operator into a unified operator with runtime dispatching.

    The returned callable checks ``paddle.in_dynamic_mode()`` on every call:
    in dynamic mode it calls ``original_cpp_ext_op`` and unwraps a
    single-element list result; otherwise it calls ``original_custom_op``.
    The wrapper carries ``original_custom_op``'s metadata via ``functools.wraps``
    and is decorated with ``paddle.jit.marker.unified``.

    Args:
        original_cpp_ext_op: Original C++ extension operator function.
        original_custom_op: Original custom operator function.

    Returns:
        The unified operator, or ``None`` if it could not be built (a warning is
        logged; presumably because the installed Paddle lacks
        ``paddle.jit.marker.unified``).
    """
    try:

        @paddle.jit.marker.unified
        @functools.wraps(original_custom_op)
        def unified_op(*args, **kwargs):
            if paddle.in_dynamic_mode():
                res = original_cpp_ext_op(*args, **kwargs)
                # TODO(DrRyanHuang): Remove this if when we align the implementation of custom op and C++ extension
                # A single-element list is unwrapped; anything else (including None) is returned unchanged.
                if isinstance(res, list) and len(res) == 1:
                    return res[0]
                return res
            return original_custom_op(*args, **kwargs)

    except Exception:
        # Narrowed from a bare ``except:`` so KeyboardInterrupt/SystemExit propagate.
        logger.warning("Paddle version not support JIT mode.")
        return None
    return unified_op
 | |
| 
 | |
| 
 | |
def preprocess_static_op(global_ns):
    """
    Transforms operator/function references in the global namespace based on the presence of 'static_op_' prefixes.

    For every key ``static_op_<name>`` in ``global_ns``:
      * if ``<name>`` is absent, ``<name>`` is bound to the static op;
      * otherwise ``<name>`` is replaced by the unified op returned by
        ``wrap_unified_op(<existing name op>, <static op>)``. If that returns
        ``None``, the existing ``<name>`` op is kept.

    The ``static_op_<name>`` entries themselves are left in place.

    Args:
        global_ns (dict): The global namespace (typically globals()) to modify in place.
    """
    static_op_prefix = "static_op_"
    # Snapshot the matching keys first because global_ns is mutated in the loop.
    static_op_names = [k for k in global_ns if k.startswith(static_op_prefix)]

    for static_op_name in static_op_names:
        op_name = static_op_name.removeprefix(static_op_prefix)
        static_op = global_ns[static_op_name]
        if op_name not in global_ns:
            global_ns[op_name] = static_op
            continue

        unified_op = wrap_unified_op(global_ns[op_name], static_op)
        # wrap_unified_op returns None when it cannot build the unified op;
        # do not overwrite the working op with None in that case.
        if unified_op is not None:
            global_ns[op_name] = unified_op
 | 
