""" # Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """ import inspect import os import re import sys import paddle import triton from paddle.base.framework import OpProtoHolder from fastdeploy import envs compile_file = triton.__path__[0] + "/tools/compile.py" link_file = triton.__path__[0] + "/tools/link.py" python_path = sys.executable def SubstituteTemplate(template, values): """ Substitute all variables in the given template string using the provided values dictionary. """ text = template changed = True while changed: changed = False for key, value in values.items(): regex = "\\$\\{%s\\}" % key newtext = re.sub(regex, value, text) if newtext != text: changed = True text = newtext return text def find_so_path(generated_dir, python_package_name): """ find the specified so in generated_dir, if not found it will return None. """ so_path = [] for root, dirs, files in os.walk(generated_dir): for file in files: if file.endswith(python_package_name + ".so"): so_path.append(os.path.join(root, file)) if len(so_path) == 0: return None else: assert len(so_path) == 1 return so_path[0] def multi_process_do(commands): """ Multi-threaded execution of commands. """ THREADS = 40 import multiprocessing process = [] def one_process_work(commands, thread_id): i = thread_id while i < len(commands): re = os.system(commands[i]) assert re == 0 i += THREADS for i in range(THREADS): p = multiprocessing.Process(target=one_process_work, args=(commands, i)) process.append(p) for p in process: p.start() for p in process: p.join() def extract_triton_kernel(kernel, file_name): """ Extract the triton kernel and write it to the specified file_name. Args: kernel: the triton kernel name. file_name: the file name you want to write. 
""" import inspect import re import textwrap fn = kernel if isinstance(kernel, triton.runtime.jit.JITFunction): fn = kernel.fn elif isinstance(kernel, triton.runtime.autotuner.Autotuner): fn = kernel.fn.fn else: AssertionError("error occurs") py_script = textwrap.dedent(inspect.getsource(fn)) # @triton.jit must only appear once # assert len(re.findall("@triton.jit", py_script)) == 1 assert len(re.findall("def ", py_script)) == 1 # assert len(re.findall("@haha()", py_script)) == 1 # py_script = py_script.replace("@haha()", "@triton.jit") py_script = py_script[py_script.find("def ") :] py_script = "import triton\nimport triton.language as tl\n\n\n@triton.jit\n" + py_script py_script = py_script.replace("if bias_ptr is not None", "if bias_ptr") with open(file_name, "w") as f: f.write(py_script) f.close() template_install = """ import os generated_cu = [] for root, dirs, files in os.walk("./"): for file in files: if file.endswith(".c") or file.endswith(".cu"): generated_cu.append(os.path.join(root, file)) import paddle from paddle.utils.cpp_extension import CUDAExtension, setup def get_gencode_flags(): prop = paddle.device.cuda.get_device_properties() cc = prop.major * 10 + prop.minor return ["-gencode", "arch=compute_{{0}},code=sm_{{0}}".format(cc)] gencode_flags = get_gencode_flags() setup( name="{python_package_name}", ext_modules=CUDAExtension( sources = generated_cu, extra_compile_args={{ "cc": ["-lcuda"], "nvcc": [ "-O3", "-U__CUDA_NO_HALF_OPERATORS__", "-U__CUDA_NO_HALF_CONVERSIONS__", "-U__CUDA_NO_BFLOAT16_OPERATORS__", "-U__CUDA_NO_BFLOAT16_CONVERSIONS__", "-U__CUDA_NO_BFLOAT162_OPERATORS__", "-U__CUDA_NO_BFLOAT162_CONVERSIONS__", ] + gencode_flags, }}, extra_link_args = ["-lcuda"] ), ) """ def get_op_name_with_suffix(op_name, x_list): """ Get the operator name with suffix. """ suffix = [] for x in x_list: if x % 16 == 0: suffix.append(16) elif x == 1: suffix.append(1) else: suffix.append(0) return op_name + "_".join([str(i) for i in suffix]) def get_value_hint(x): """ Get the value hint from input list. """ hint = "" for ele in x: if isinstance(ele, int): if ele % 16 == 0 and ele > 0: hint += "i64:16," elif ele == 1: hint += "i64:1," else: hint += "i64," if isinstance(ele, float): hint += "fp32," return hint def get_dtype_str(dtype): """ Get the dtype str. """ if dtype == paddle.float16: return "_fp16" if dtype == paddle.float8_e4m3fn: return "_float8_e4m3fn" elif dtype == paddle.uint8: return "_u8" elif dtype == paddle.int8: return "_i8" elif dtype == paddle.int16: return "_i16" elif dtype == paddle.int32: return "_i32" elif dtype == paddle.int64: return "_i64" elif dtype == paddle.float32: return "_fp32" elif dtype == paddle.bfloat16: return "_bf16" else: raise ValueError("Not support this dtype.") def build_package(generated_dir, python_package_name): """ Build the package, not install it. Args: generated_dir: the source cu file dir. python_package_name: the python package name. """ setup_file_path = generated_dir + "/setup_cuda.py" python_path = sys.executable with open(setup_file_path, "w") as f: f.write(template_install.format(python_package_name=python_package_name)) f.close() install_command = f"cd {generated_dir} && {python_path} setup_cuda.py build" re = os.system(install_command) assert re == 0 def rename_c_to_cu(generated_dir): """ Rename the .c files int generated_dir to .cu file, because the triton aot tool generate the .c files. 
""" # rename the .c file to .cu for filename in os.listdir(generated_dir): if filename.endswith(".c"): old_path = os.path.join(generated_dir, filename) new_path = os.path.join(generated_dir, filename + "u") os.rename(old_path, new_path) def get_pointer_hint(dtypes): """ Get the pointer hint from input list. """ hint = "" for ele in dtypes: if ele == paddle.float16: hint += "*fp16:16," elif ele == paddle.uint8: hint += "*u8:16," elif ele == paddle.int8: hint += "*i8:16," elif ele == paddle.int16: hint += "*i16:16," elif ele == paddle.float32: hint += "*fp32:16," elif ele == paddle.bfloat16: hint += "*bf16:16," elif ele == paddle.int32: hint += "*i32:16," elif ele == paddle.int64: hint += "*i64," elif ele == paddle.float8_e4m3fn: hint += "*fp8e4nv:16," return hint paddle_custom_op_head_part = """#include #include #include "${op_name}_kernel.h" #include "paddle/extension.h" std::map, int> map_problem_${op_name}; CUdeviceptr get_tensor_ptr(const paddle::Tensor& input){ if (input.type() == paddle::DataType::FLOAT16) { return (CUdeviceptr)(input.data()); } else if (input.type() == paddle::DataType::BFLOAT16) { return (CUdeviceptr)(input.data()); } else if (input.type() == paddle::DataType::INT32) { return (CUdeviceptr)(input.data()); } else if (input.type() == paddle::DataType::FLOAT32) { return (CUdeviceptr)(input.data()); } else if (input.type() == paddle::DataType::UINT8) { return (CUdeviceptr)(input.data()); } else if (input.type() == paddle::DataType::INT8) { return (CUdeviceptr)(input.data()); } else if (input.type() == paddle::DataType::INT64) { return (CUdeviceptr)(input.data()); } else if (input.type() == paddle::DataType::INT32) { return (CUdeviceptr)(input.data()); } else if (input.type() == paddle::DataType::INT16) { return (CUdeviceptr)(input.data()); } else if (input.type() == paddle::DataType::FLOAT8_E4M3FN) { return (CUdeviceptr)(input.data()); } else { assert(false); return (CUdeviceptr)(nullptr); } } int triton_cdiv(int x, int y) { int result = (x + y - 1) / y; return (int)(result); } """ tune_and_invoke_part = """ std::vector problem_size = {${key}}; auto run_triton_kernel = [&](int algo_id) -> CUresult{ return ${op_name}_kernel(run_stream, ${triton_kernel_args}, algo_id); }; map_problem_${op_name}[problem_size] = 0; if (!map_problem_${op_name}.count(problem_size)) { std::cout << "we are tuning for ${op_name} which key is: {"; for (int i = 0; i < problem_size.size(); i++) { std::cout << problem_size[i] << ", "; } std::cout << "}" << std::endl; float min_time = 10000.f; int select_id = -1; constexpr int WARMUP = 5; constexpr int REPEAT = 10; for (int algo_id = 0; algo_id < ${op_name}_kernel_get_num_algos(); ++algo_id) { cudaEvent_t beg[REPEAT]; cudaEvent_t end[REPEAT]; float elapsed_times[REPEAT]; auto status = CUDA_SUCCESS; for (int ii = 0; ii < WARMUP + REPEAT; ii++) { int repeat_id = ii - WARMUP; if (repeat_id >= 0) { (cudaEventCreate(beg + repeat_id)); (cudaEventCreate(end + repeat_id)); (cudaEventRecord(beg[repeat_id])); } auto flush_l2_cache = paddle::full( {10 * 1024 * 1024}, 0, paddle::DataType::INT32, ${arbitary_output_name}.place()); // std::cout << &flush_l2_cache << std::endl; // this is used when out is need to be reset to zero, such as split-k gemm. 


tune_and_invoke_part = """
  std::vector<int> problem_size = {${key}};

  auto run_triton_kernel = [&](int algo_id) -> CUresult {
    return ${op_name}_kernel(run_stream,
                             ${triton_kernel_args},
                             algo_id);
  };

  map_problem_${op_name}[problem_size] = 0;

  if (!map_problem_${op_name}.count(problem_size)) {
    std::cout << "we are tuning for ${op_name} which key is: {";
    for (int i = 0; i < problem_size.size(); i++) {
      std::cout << problem_size[i] << ", ";
    }
    std::cout << "}" << std::endl;

    float min_time = 10000.f;
    int select_id = -1;
    constexpr int WARMUP = 5;
    constexpr int REPEAT = 10;

    for (int algo_id = 0; algo_id < ${op_name}_kernel_get_num_algos(); ++algo_id) {
      cudaEvent_t beg[REPEAT];
      cudaEvent_t end[REPEAT];
      float elapsed_times[REPEAT];

      auto status = CUDA_SUCCESS;

      for (int ii = 0; ii < WARMUP + REPEAT; ii++) {
        int repeat_id = ii - WARMUP;

        if (repeat_id >= 0) {
          (cudaEventCreate(beg + repeat_id));
          (cudaEventCreate(end + repeat_id));
          (cudaEventRecord(beg[repeat_id]));
        }

        auto flush_l2_cache = paddle::full(
            {10 * 1024 * 1024}, 0, paddle::DataType::INT32, ${arbitary_output_name}.place());
        // std::cout << &flush_l2_cache << std::endl;
        // this is used when the output needs to be reset to zero, such as split-k gemm.
        ${reset_zero_when_tune};

        status = run_triton_kernel(algo_id);
        // assert(status == CUDA_SUCCESS);

        if (repeat_id >= 0) {
          (cudaEventRecord(end[repeat_id]));
          (cudaEventSynchronize(end[repeat_id]));
          (cudaEventElapsedTime(
              elapsed_times + repeat_id, beg[repeat_id], end[repeat_id]));
        }
      }

      float avg_elapsed_time = 0.f;
      for (int ii = 0; ii < REPEAT; ++ii) {
        avg_elapsed_time += elapsed_times[ii];
      }

      std::cout << "algo id " << algo_id << " costs " << avg_elapsed_time << " ms" << std::endl;

      if (avg_elapsed_time < min_time && status == CUDA_SUCCESS) {
        min_time = avg_elapsed_time;
        select_id = algo_id;
      }
    }

    map_problem_${op_name}[problem_size] = select_id;
    std::cout << "select algo id: " << select_id << std::endl;
    ${reset_zero_when_tune};
  }

  if (map_problem_${op_name}.count(problem_size)) {
    int algo_id = map_problem_${op_name}[problem_size];
    auto status = run_triton_kernel(algo_id);
    assert(status == CUDA_SUCCESS);
  }
"""

common_template = (
    """
std::vector<paddle::Tensor> ${op_name}_func(${input_and_attr}) {
  ${prepare_attr_for_triton_kernel}
  ${prepare_ptr_for_triton_kernel}
  auto run_stream = ${arbitary_output_name}.stream();
"""
    + tune_and_invoke_part
    + """
  return {${return_tensor_names}};
}

${d2s_infer_code}

PD_BUILD_OP(${op_name})
    .Inputs({${paddle_input_sig}})
    .Outputs({${paddle_output_sig}})
    .Attrs({${paddle_attr_sig}})
    .SetKernelFn(PD_KERNEL(${op_name}_func))
    .SetInferDtypeFn(PD_INFER_DTYPE(${op_name}_InferDtype))
    .SetInferShapeFn(PD_INFER_SHAPE(${op_name}_InferShape));
"""
)
""" signature = inspect.signature(func) arg_names = [v.name for v in signature.parameters.values()] arg_defaults = [v.default for v in signature.parameters.values()] input_and_attr = "" paddle_input_sig = "" paddle_attr_sig = "" if return_tensor_names is None: return_tensor_names = "useless" prepare_ptr_for_triton_kernel += ( "auto useless = paddle::empty({1}, paddle::DataType::INT32, paddle::CPUPlace());" ) for i in range(len(arg_names)): if arg_defaults[i] is None: input_and_attr += f"paddle::optional & {arg_names[i]}," paddle_input_sig += f"""paddle::Optional("{arg_names[i]}"),""" elif isinstance(arg_defaults[i], float): input_and_attr += f"float {arg_names[i]}," paddle_attr_sig += f""""{arg_names[i]}: float",""" elif isinstance(arg_defaults[i], bool): input_and_attr += f"bool {arg_names[i]}," paddle_attr_sig += f""""{arg_names[i]}: bool",""" elif isinstance(arg_defaults[i], int): input_and_attr += f"int64_t {arg_names[i]}," paddle_attr_sig += f""""{arg_names[i]}: int64_t",""" elif isinstance(arg_defaults[i], str): input_and_attr += f"std::string {arg_names[i]}," paddle_attr_sig += f""""{arg_names[i]}: std::string",""" elif arg_names[i] == "config": continue else: input_and_attr += f"const paddle::Tensor & {arg_names[i]}," paddle_input_sig += f""""{arg_names[i]}",""" input_and_attr = input_and_attr[:-1] paddle_input_sig = paddle_input_sig[:-1] if len(paddle_attr_sig) > 1: paddle_attr_sig = paddle_attr_sig[:-1] paddle_output_sig = "" arbitary_output_name = "" for name in return_tensor_names.split(","): name = name.strip() arbitary_output_name = name paddle_output_sig += f""""{name}",""" paddle_output_sig = paddle_output_sig[:-1] if "${op_name}_InferShape" not in d2s_infer_code: d2s_infer_shape_part = ( "std::vector> ${op_name}_InferShape(" "const std::vector& A_shape) {" "return {${tmp}};" "}\n " ) tmp = ",".join(["A_shape"] * len(return_tensor_names.split(","))) tmp_dict = {"tmp": tmp} d2s_infer_shape_part = SubstituteTemplate(d2s_infer_shape_part, tmp_dict) d2s_infer_code += d2s_infer_shape_part if "${op_name}_InferDtype" not in d2s_infer_code: d2s_infer_dtype_part = ( "std::vector ${op_name}_InferDtype(" "const paddle::DataType& A_dtype) {" "return {${tmp}};" "}\n " ) tmp = ",".join(["A_dtype"] * len(return_tensor_names.split(","))) tmp_dict = {"tmp": tmp} d2s_infer_dtype_part = SubstituteTemplate(d2s_infer_dtype_part, tmp_dict) d2s_infer_code += d2s_infer_dtype_part result_str = SubstituteTemplate( common_template, { "input_and_attr": input_and_attr, "prepare_attr_for_triton_kernel": prepare_attr_for_triton_kernel, "prepare_ptr_for_triton_kernel": prepare_ptr_for_triton_kernel, "return_tensor_names": return_tensor_names, "arbitary_output_name": arbitary_output_name, "d2s_infer_code": d2s_infer_code, "paddle_input_sig": paddle_input_sig, "paddle_output_sig": paddle_output_sig, "paddle_attr_sig": paddle_attr_sig, }, ) return paddle_custom_op_head_part + result_str class KernelInterface: """ triton kernel interface. """ def __init__( self, func, other_config, key_args=["1"], ): """ triton kernel interface. 
""" self.func = func self.key_args = key_args signature = inspect.signature(func) self.arg_names = [v.name for v in signature.parameters.values()] for ele in self.arg_names: assert self.arg_names.count(ele) == 1 # arg_defaults = [v.default for v in signature.parameters.values()] # self.annotations = { # name: ty for name, ty in func.__annotations__.items() # } self.annotations = dict(func.__annotations__) self.constexprs = [ self.arg_names.index(name) for name in self.arg_names if self.annotations.get(name) == triton.language.core.constexpr ] self.arg_exclude_constexpr = [ self.arg_names[i] for i in range(len(self.arg_names)) if i not in self.constexprs ] import textwrap py_script = textwrap.dedent(inspect.getsource(func)) import re pat = r"def\s" + func.__name__ func_begin = re.findall(pat, py_script) assert len(func_begin) == 1 func_begin = func_begin[0] py_script = py_script[py_script.find(func_begin) :] def decorator(*args, **kwargs): """ decorator for triton kernels. Args: *args: positional arguments **kwargs: keyword arguments """ all_input = [] for i in range(len(args)): all_input.append(args[i]) position_arguments_num = len(all_input) for i in range(position_arguments_num, len(self.arg_names)): if self.arg_names[i] in kwargs.keys(): all_input.append(kwargs[self.arg_names[i]]) else: # means this input is not specified, it muse be a tl.constexpr. assert i in self.constexprs all_input.append(None) dtypes = [] x_list = [] const_args = [self.arg_names[i] for i in self.constexprs] # we dont allow there are two strings in const_args, and one is a substring of the other. for i in const_args: for j in const_args: if i != j and i.find(j) != -1: raise ValueError( f"We find {i}, {j} in tl.constexpr args, and {j} is a substring of {i}, " "please modify your triton kernel arguments names to avoid this." 

            modified_arg_exclude_constexpr = self.arg_exclude_constexpr
            const_hint_dict = {}
            for i in range(len(all_input)):
                ele = all_input[i]
                if (
                    isinstance(ele, paddle.Tensor)
                    or isinstance(ele, paddle.base.framework.EagerParamBase)
                    or isinstance(ele, paddle.base.framework.Parameter)
                    or isinstance(ele, paddle.base.framework.Variable)
                    or isinstance(ele, paddle.base.libpaddle.pir.Value)
                ):
                    dtypes.append(ele.dtype)
                    modified_arg_exclude_constexpr[i] = f"input_ptrs[{i}]"
                elif i in self.constexprs:
                    const_hint_dict[self.arg_names[i]] = ele
                else:
                    x_list.append(ele)

            op_name = self.op_name
            python_package_name = f"{op_name}_package"

            tp_rank = paddle.distributed.get_rank()
            generated_dir = envs.FD_TRITON_KERNEL_CACHE_DIR
            if generated_dir is None:
                generated_dir = f"/tmp/triton_cache/rank{tp_rank}"
            print("the kernel cache dir is:", generated_dir)
            assert generated_dir is not None, (
                "FD_TRITON_KERNEL_CACHE_DIR is None, please set it, e.g. "
                "export FD_TRITON_KERNEL_CACHE_DIR=/tmp/triton_cache "
            )
            generated_dir = f"{generated_dir}/{op_name}"
            os.makedirs(generated_dir, exist_ok=True)

            py_script_file = f"{generated_dir}/triton_kernels.py"
            extract_triton_kernel(func, py_script_file)

            address_hint = get_pointer_hint(dtypes)
            value_hint = get_value_hint(x_list)
            const_args = [f"{{{ele}}}" for ele in const_args]
            const_args = ",".join(const_args)

            launch_grid = list(self.grid)
            for i in range(len(launch_grid)):
                ele = launch_grid[i]
                if isinstance(ele, str):
                    for key in const_hint_dict.keys():
                        if key in ele:
                            ele = ele.replace(key, f"{{{key}}}")
                else:
                    ele = str(ele)
                launch_grid[i] = ele
            if len(launch_grid) < 3:
                launch_grid += ["1"] * (3 - len(launch_grid))
            launch_grid = ",".join(launch_grid)

            op_dict = {"op_name": op_name, "reset_zero_when_tune": ""}
            op_dict["triton_kernel_args"] = ",".join(modified_arg_exclude_constexpr)
            op_dict["key"] = ",".join(self.key_args)
            # when tuning, the output may need to be reset to zero between runs.
            if "reset_zero_when_tune" in other_config.keys():
                op_dict["reset_zero_when_tune"] = other_config["reset_zero_when_tune"]

            paddle_custom_op_file_path = f"{generated_dir}/{op_name}.cu"
            so_path = find_so_path(generated_dir, python_package_name)

            if so_path is None:
                print("== we do not find so_path, we need to compile it")
                with open(paddle_custom_op_file_path, "w") as f:
                    f.write(
                        SubstituteTemplate(
                            self.custom_op_template,
                            op_dict,
                        )
                    )

                # ahead-of-time compile command.
                aot_template = (
                    f"""{python_path} {compile_file} {py_script_file} """
                    + f""" -n {func.__name__} -o {generated_dir}/{op_name}_kernel """
                    + f"""--out-name {op_name}_kernel """
                    + """ -w {num_warps} -ns {num_stages} """
                    + f""" -s"{address_hint} {value_hint} {const_args}" """
                    + f""" -g "{launch_grid}" """
                )
                all_tune_config = list(self.tune_config)
                if len(all_tune_config) == 0:
                    # when the user does not specify a config, we use const_hint_dict as the config.
                    all_tune_config = [const_hint_dict]
                    # reset const_hint_dict to empty.
                    const_hint_dict = {}
                codegen_commands = []
                for config in all_tune_config:
                    for key in const_hint_dict.keys():
                        if const_hint_dict[key] is not None:
                            if key not in config.keys():
                                config[key] = const_hint_dict[key]
                            elif config[key] != const_hint_dict[key]:
                                raise ValueError(
                                    f"you specify {key} both in arguments and config, "
                                    "and they are not the same, this is wrong."
                                )
                        else:
                            assert key in config.keys(), f"you must specify {key} in your config."
if "num_warps" not in config.keys(): config["num_warps"] = 4 if "num_stages" not in config.keys(): config["num_stages"] = 4 for key in config: assert config[key] is not None, f"{key} must be specified." codegen_command = aot_template.format( **config, ) print(codegen_command) codegen_commands.append(codegen_command) multi_process_do(codegen_commands) link_command = ( f"{python_path} {link_file} " f"{generated_dir}/*.h -o {generated_dir}/{op_name}_kernel" ) re = os.system(link_command) assert re == 0 # rename the .c file to .cu rename_c_to_cu(generated_dir) # build the package to so, not install build_package(generated_dir, python_package_name) if op_name not in OpProtoHolder.instance().op_proto_map.keys(): so_path = find_so_path(generated_dir, python_package_name) print("== we find so_path: ", so_path) assert so_path is not None paddle.utils.cpp_extension.load_op_meta_info_and_register_op(so_path) self.decorator = decorator def __getitem__(self, op_name_and_grid): """ override the operator [], which will call the decorator function. Args: op_name_and_grid: the name of the operator and the grid size. Returns: the decorator function. """ assert len(op_name_and_grid) >= 3, "len(op_name_and_grid) must >= 3." self.op_name = op_name_and_grid[0] self.custom_op_template = op_name_and_grid[1] self.grid = op_name_and_grid[2] if len(op_name_and_grid) == 3: self.tune_config = {} else: self.tune_config = op_name_and_grid[3] return self.decorator def paddle_use_triton(other_config={}, key=[]): """ The decorator function that wraps the original function. Args: func: the original function. Returns: the wrapped function. """ def decorator(func): """ The decorator function that wraps the original function. Args: func: the original function. Returns: the wrapped function. """ return KernelInterface(func, other_config, key) return decorator