"""
|
|
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
"""
|
|
|
|
import inspect
|
|
import os
|
|
import re
|
|
import sys
|
|
|
|
import paddle
|
|
import triton
|
|
from paddle.base.framework import OpProtoHolder
|
|
|
|
from fastdeploy import envs
|
|
|
|
compile_file = triton.__path__[0] + "/tools/compile.py"
|
|
link_file = triton.__path__[0] + "/tools/link.py"
|
|
python_path = sys.executable
|
|
|
|
|
|
def SubstituteTemplate(template, values):
    """
    Substitute all variables in the given template string using the provided values dictionary.
    """
    text = template
    changed = True
    while changed:
        changed = False
        for key, value in values.items():
            regex = "\\$\\{%s\\}" % key
            newtext = re.sub(regex, value, text)
            if newtext != text:
                changed = True
            text = newtext
    return text


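# Illustrative usage of SubstituteTemplate (editor's sketch; the template and the
# values are made up, and the snippet is kept as a comment so nothing runs at
# import time):
#
#     rendered = SubstituteTemplate(
#         "PD_BUILD_OP(${op_name}).Inputs({${inputs}});",
#         {"op_name": "my_add", "inputs": '"x", "y"'},
#     )
#     # rendered == 'PD_BUILD_OP(my_add).Inputs({"x", "y"});'
#
# Substitution loops until no ${key} placeholder changes any more, so a value may
# itself contain placeholders that get resolved in a later pass.

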
def find_so_path(generated_dir, python_package_name):
    """
    Find the compiled .so of the given python package under generated_dir.
    Returns the path of the single matching file, or None if nothing has been built yet.
    """

    so_path = []
    for root, dirs, files in os.walk(generated_dir):
        for file in files:
            if file.endswith(python_package_name + ".so"):
                so_path.append(os.path.join(root, file))
    if len(so_path) == 0:
        return None
    else:
        assert len(so_path) == 1
        return so_path[0]


def multi_process_do(commands):
    """
    Execute the given shell commands in parallel, interleaving them across multiple worker processes.
    """
    THREADS = 40
    import multiprocessing

    process = []

    def one_process_work(commands, thread_id):
        i = thread_id
        while i < len(commands):
            re = os.system(commands[i])
            assert re == 0
            i += THREADS

    for i in range(THREADS):
        p = multiprocessing.Process(target=one_process_work, args=(commands, i))
        process.append(p)
    for p in process:
        p.start()
    for p in process:
        p.join()


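# Illustrative usage of multi_process_do (editor's sketch; the echo commands are
# placeholders for the generated Triton AOT compile commands):
#
#     cmds = [f"echo compiling tuning config {i}" for i in range(8)]
#     multi_process_do(cmds)
#
# Command i is handled by worker i % THREADS (THREADS == 40), so the AOT compile
# commands for different tuning configs run concurrently.

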
def extract_triton_kernel(kernel, file_name):
    """
    Extract the triton kernel source and write it to the specified file_name.

    Args:
        kernel: the triton kernel to extract (a plain Python function, a
            triton JITFunction, or an Autotuner wrapping one).
        file_name: the path of the file to write the standalone kernel source to.
    """

    import inspect
    import re
    import textwrap

    fn = kernel
    if isinstance(kernel, triton.runtime.jit.JITFunction):
        fn = kernel.fn
    elif isinstance(kernel, triton.runtime.autotuner.Autotuner):
        fn = kernel.fn.fn
    else:
        # a plain (not yet @triton.jit-decorated) Python function is used as-is;
        # the @triton.jit decorator is prepended when the source is written out below.
        pass
    py_script = textwrap.dedent(inspect.getsource(fn))

    # @triton.jit must only appear once
    # assert len(re.findall("@triton.jit", py_script)) == 1
    assert len(re.findall("def ", py_script)) == 1
    # assert len(re.findall("@haha()", py_script)) == 1
    # py_script = py_script.replace("@haha()", "@triton.jit")

    py_script = py_script[py_script.find("def ") :]
    py_script = "import triton\nimport triton.language as tl\n\n\n@triton.jit\n" + py_script

    py_script = py_script.replace("if bias_ptr is not None", "if bias_ptr")

    with open(file_name, "w") as f:
        f.write(py_script)


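# Illustrative effect of extract_triton_kernel (editor's sketch; my_kernel is a
# hypothetical kernel):
#
#     extract_triton_kernel(my_kernel, "/tmp/example/triton_kernels.py")
#
# writes a standalone module that starts with "import triton", "import
# triton.language as tl" and an @triton.jit decorator, followed by the kernel
# body, which is what Triton's AOT compiler (tools/compile.py) consumes.

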
template_install = """
|
|
|
|
import os
|
|
generated_cu = []
|
|
for root, dirs, files in os.walk("./"):
|
|
for file in files:
|
|
if file.endswith(".c") or file.endswith(".cu"):
|
|
generated_cu.append(os.path.join(root, file))
|
|
|
|
|
|
import paddle
|
|
from paddle.utils.cpp_extension import CUDAExtension, setup
|
|
|
|
|
|
def get_gencode_flags():
|
|
prop = paddle.device.cuda.get_device_properties()
|
|
cc = prop.major * 10 + prop.minor
|
|
return ["-gencode", "arch=compute_{{0}},code=sm_{{0}}".format(cc)]
|
|
|
|
|
|
gencode_flags = get_gencode_flags()
|
|
|
|
|
|
|
|
setup(
|
|
name="{python_package_name}",
|
|
ext_modules=CUDAExtension(
|
|
sources = generated_cu,
|
|
extra_compile_args={{
|
|
"cc": ["-lcuda"],
|
|
"nvcc": [
|
|
"-O3",
|
|
"-U__CUDA_NO_HALF_OPERATORS__",
|
|
"-U__CUDA_NO_HALF_CONVERSIONS__",
|
|
"-U__CUDA_NO_BFLOAT16_OPERATORS__",
|
|
"-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
|
|
"-U__CUDA_NO_BFLOAT162_OPERATORS__",
|
|
"-U__CUDA_NO_BFLOAT162_CONVERSIONS__",
|
|
]
|
|
+ gencode_flags,
|
|
}},
|
|
extra_link_args = ["-lcuda"]
|
|
),
|
|
)
|
|
"""
|
|
|
|
|
|
def get_op_name_with_suffix(op_name, x_list):
    """
    Append a shape-specialization suffix to the operator name: for every value in
    x_list the suffix records 16 (divisible by 16), 1 (equal to 1) or 0 (anything
    else), so different divisibility patterns map to differently named operators.
    """
    suffix = []
    for x in x_list:
        if x % 16 == 0:
            suffix.append(16)
        elif x == 1:
            suffix.append(1)
        else:
            suffix.append(0)
    return op_name + "_".join([str(i) for i in suffix])


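# Illustrative behaviour of get_op_name_with_suffix (editor's note; the op name
# is hypothetical):
#
#     get_op_name_with_suffix("triton_moe", [4096, 1, 7])   # -> "triton_moe16_1_0"
#
# 4096 is divisible by 16 -> 16, 1 -> 1 and 7 -> 0; the suffix elements are
# joined with "_" and appended directly to op_name (no separator before the
# first element).

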
def get_value_hint(x):
    """
    Build the Triton AOT signature hint for the non-tensor scalar arguments:
    ints become "i64" (annotated with ":16" when positive and divisible by 16,
    ":1" when equal to 1) and floats become "fp32".
    """
    hint = ""
    for ele in x:
        if isinstance(ele, int):
            if ele % 16 == 0 and ele > 0:
                hint += "i64:16,"
            elif ele == 1:
                hint += "i64:1,"
            else:
                hint += "i64,"
        if isinstance(ele, float):
            hint += "fp32,"
    return hint


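# Illustrative behaviour of get_value_hint (editor's note):
#
#     get_value_hint([256, 1, 7, 0.5])   # -> "i64:16,i64:1,i64,fp32,"
#
# The string is later pasted into the "-s" signature argument of Triton's
# tools/compile.py, together with the pointer hints and the constexpr values.

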
def get_dtype_str(dtype):
    """
    Map a paddle dtype to the suffix string used in generated operator names.
    """
    if dtype == paddle.float16:
        return "_fp16"
    elif dtype == paddle.float8_e4m3fn:
        return "_float8_e4m3fn"
    elif dtype == paddle.uint8:
        return "_u8"
    elif dtype == paddle.int8:
        return "_i8"
    elif dtype == paddle.int16:
        return "_i16"
    elif dtype == paddle.int32:
        return "_i32"
    elif dtype == paddle.int64:
        return "_i64"
    elif dtype == paddle.float32:
        return "_fp32"
    elif dtype == paddle.bfloat16:
        return "_bf16"
    else:
        raise ValueError("Unsupported dtype for get_dtype_str.")


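# Illustrative behaviour of get_dtype_str (editor's note):
#
#     get_dtype_str(paddle.bfloat16)   # -> "_bf16"
#
# The suffix is typically appended to a base op name so that each dtype
# specialization is registered as a distinct custom operator.

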
def build_package(generated_dir, python_package_name):
    """
    Build the generated custom-op package (build only, it is not installed).

    Args:
        generated_dir: the directory that holds the generated .c/.cu sources.
        python_package_name: the python package name.
    """
    setup_file_path = generated_dir + "/setup_cuda.py"
    python_path = sys.executable
    with open(setup_file_path, "w") as f:
        f.write(template_install.format(python_package_name=python_package_name))
    install_command = f"cd {generated_dir} && {python_path} setup_cuda.py build"
    re = os.system(install_command)
    assert re == 0


def rename_c_to_cu(generated_dir):
    """
    Rename the .c files in generated_dir to .cu files, because the Triton AOT tool
    generates .c files.
    """
    # rename the .c files to .cu
    for filename in os.listdir(generated_dir):
        if filename.endswith(".c"):
            old_path = os.path.join(generated_dir, filename)
            new_path = os.path.join(generated_dir, filename + "u")
            os.rename(old_path, new_path)


def get_pointer_hint(dtypes):
    """
    Build the Triton AOT pointer hint for the tensor arguments: each paddle dtype
    is mapped to its Triton pointer type (e.g. "*fp16"), annotated with ":16"
    where the pointer is assumed to be 16-byte aligned.
    """
    hint = ""
    for ele in dtypes:
        if ele == paddle.float16:
            hint += "*fp16:16,"
        elif ele == paddle.uint8:
            hint += "*u8:16,"
        elif ele == paddle.int8:
            hint += "*i8:16,"
        elif ele == paddle.int16:
            hint += "*i16:16,"
        elif ele == paddle.float32:
            hint += "*fp32:16,"
        elif ele == paddle.bfloat16:
            hint += "*bf16:16,"
        elif ele == paddle.int32:
            hint += "*i32:16,"
        elif ele == paddle.int64:
            hint += "*i64,"
        elif ele == paddle.float8_e4m3fn:
            hint += "*fp8e4nv:16,"
    return hint


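# Illustrative behaviour of get_pointer_hint (editor's note):
#
#     get_pointer_hint([paddle.float16, paddle.int64])   # -> "*fp16:16,*i64,"
#
# Combined with get_value_hint and the constexpr values, this forms the full
# "-s" signature string handed to Triton's tools/compile.py.

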
paddle_custom_op_head_part = """#include <vector>
|
|
#include <map>
|
|
#include "${op_name}_kernel.h"
|
|
#include "paddle/extension.h"
|
|
|
|
std::map<std::vector<int>, int> map_problem_${op_name};
|
|
|
|
CUdeviceptr get_tensor_ptr(const paddle::Tensor& input){
|
|
if (input.type() == paddle::DataType::FLOAT16) {
|
|
return (CUdeviceptr)(input.data<phi::dtype::float16>());
|
|
} else if (input.type() == paddle::DataType::BFLOAT16) {
|
|
return (CUdeviceptr)(input.data<phi::dtype::bfloat16>());
|
|
} else if (input.type() == paddle::DataType::INT32) {
|
|
return (CUdeviceptr)(input.data<int>());
|
|
} else if (input.type() == paddle::DataType::FLOAT32) {
|
|
return (CUdeviceptr)(input.data<float>());
|
|
} else if (input.type() == paddle::DataType::UINT8) {
|
|
return (CUdeviceptr)(input.data<uint8_t>());
|
|
} else if (input.type() == paddle::DataType::INT8) {
|
|
return (CUdeviceptr)(input.data<int8_t>());
|
|
} else if (input.type() == paddle::DataType::INT64) {
|
|
return (CUdeviceptr)(input.data<int64_t>());
|
|
} else if (input.type() == paddle::DataType::INT32) {
|
|
return (CUdeviceptr)(input.data<int32_t>());
|
|
} else if (input.type() == paddle::DataType::INT16) {
|
|
return (CUdeviceptr)(input.data<int16_t>());
|
|
} else if (input.type() == paddle::DataType::FLOAT8_E4M3FN) {
|
|
return (CUdeviceptr)(input.data<phi::dtype::float8_e4m3fn>());
|
|
} else {
|
|
assert(false);
|
|
return (CUdeviceptr)(nullptr);
|
|
}
|
|
}
|
|
|
|
int triton_cdiv(int x, int y) {
|
|
int result = (x + y - 1) / y;
|
|
return (int)(result);
|
|
}
|
|
"""
|
|
|
|
tune_and_invoke_part = """
|
|
std::vector<int> problem_size = {${key}};
|
|
auto run_triton_kernel = [&](int algo_id) -> CUresult{
|
|
return ${op_name}_kernel(run_stream,
|
|
${triton_kernel_args},
|
|
algo_id);
|
|
};
|
|
|
|
map_problem_${op_name}[problem_size] = 0;
|
|
|
|
if (!map_problem_${op_name}.count(problem_size)) {
|
|
std::cout << "we are tuning for ${op_name} which key is: {";
|
|
for (int i = 0; i < problem_size.size(); i++) {
|
|
std::cout << problem_size[i] << ", ";
|
|
}
|
|
std::cout << "}" << std::endl;
|
|
|
|
float min_time = 10000.f;
|
|
int select_id = -1;
|
|
constexpr int WARMUP = 5;
|
|
constexpr int REPEAT = 10;
|
|
|
|
for (int algo_id = 0; algo_id < ${op_name}_kernel_get_num_algos(); ++algo_id) {
|
|
cudaEvent_t beg[REPEAT];
|
|
cudaEvent_t end[REPEAT];
|
|
float elapsed_times[REPEAT];
|
|
|
|
auto status = CUDA_SUCCESS;
|
|
|
|
for (int ii = 0; ii < WARMUP + REPEAT; ii++) {
|
|
int repeat_id = ii - WARMUP;
|
|
|
|
if (repeat_id >= 0) {
|
|
(cudaEventCreate(beg + repeat_id));
|
|
(cudaEventCreate(end + repeat_id));
|
|
(cudaEventRecord(beg[repeat_id]));
|
|
}
|
|
|
|
auto flush_l2_cache = paddle::full(
|
|
{10 * 1024 * 1024}, 0, paddle::DataType::INT32, ${arbitary_output_name}.place());
|
|
// std::cout << &flush_l2_cache << std::endl;
|
|
// this is used when out is need to be reset to zero, such as split-k gemm.
|
|
${reset_zero_when_tune};
|
|
|
|
status = run_triton_kernel(algo_id);
|
|
// assert(status == CUDA_SUCCESS);
|
|
|
|
if (repeat_id >= 0) {
|
|
(cudaEventRecord(end[repeat_id]));
|
|
(cudaEventSynchronize(end[repeat_id]));
|
|
(cudaEventElapsedTime(
|
|
elapsed_times + repeat_id, beg[repeat_id], end[repeat_id]));
|
|
}
|
|
}
|
|
|
|
float avg_elapsed_time = 0.f;
|
|
for (int ii = 0; ii < REPEAT; ++ii) {
|
|
avg_elapsed_time += elapsed_times[ii];
|
|
}
|
|
|
|
std::cout << "algo id " << algo_id << " costs " << avg_elapsed_time << " ms" << std::endl;
|
|
|
|
if (avg_elapsed_time < min_time && status == CUDA_SUCCESS) {
|
|
min_time = avg_elapsed_time;
|
|
select_id = algo_id;
|
|
}
|
|
}
|
|
|
|
map_problem_${op_name}[problem_size] = select_id;
|
|
std::cout << "select algo id: " << select_id << std::endl;
|
|
${reset_zero_when_tune};
|
|
}
|
|
|
|
if (map_problem_${op_name}.count(problem_size)) {
|
|
int algo_id = map_problem_${op_name}[problem_size];
|
|
auto status = run_triton_kernel(algo_id);
|
|
assert(status == CUDA_SUCCESS);
|
|
}
|
|
"""
|
|
|
|
common_template = (
    """
std::vector<paddle::Tensor> ${op_name}_func(${input_and_attr}) {
  ${prepare_attr_for_triton_kernel}
  ${prepare_ptr_for_triton_kernel}
  auto run_stream = ${arbitary_output_name}.stream();
"""
    + tune_and_invoke_part
    + """
  return {${return_tensor_names}};
}

${d2s_infer_code}

PD_BUILD_OP(${op_name})
    .Inputs({${paddle_input_sig}})
    .Outputs({${paddle_output_sig}})
    .Attrs({${paddle_attr_sig}})
    .SetKernelFn(PD_KERNEL(${op_name}_func))
    .SetInferDtypeFn(PD_INFER_DTYPE(${op_name}_InferDtype))
    .SetInferShapeFn(PD_INFER_SHAPE(${op_name}_InferShape));
"""
)


def rendering_common_template(
    func,
    prepare_attr_for_triton_kernel,
    prepare_ptr_for_triton_kernel,
    return_tensor_names=None,
    d2s_infer_code="",
):
    """
    Render the common C++ custom-op template for the given triton kernel wrapper.

    Args:
        func: the Python wrapper function; its signature determines the custom op's
            inputs and attributes.
        prepare_attr_for_triton_kernel: the C++ snippet that prepares attributes for the Triton kernel.
        prepare_ptr_for_triton_kernel: the C++ snippet that prepares tensor pointers for the Triton kernel.
        return_tensor_names: a comma-separated string of the returned tensor names. Default is None.
        d2s_infer_code: optional custom InferShape/InferDtype code; default implementations
            are generated for whichever of the two is not already present in it.
    """
    signature = inspect.signature(func)
    arg_names = [v.name for v in signature.parameters.values()]
    arg_defaults = [v.default for v in signature.parameters.values()]
    input_and_attr = ""
    paddle_input_sig = ""
    paddle_attr_sig = ""

    if return_tensor_names is None:
        return_tensor_names = "useless"
        prepare_ptr_for_triton_kernel += (
            "auto useless = paddle::empty({1}, paddle::DataType::INT32, paddle::CPUPlace());"
        )

    for i in range(len(arg_names)):
        if arg_defaults[i] is None:
            input_and_attr += f"paddle::optional<paddle::Tensor> & {arg_names[i]},"
            paddle_input_sig += f"""paddle::Optional("{arg_names[i]}"),"""
        elif isinstance(arg_defaults[i], float):
            input_and_attr += f"float {arg_names[i]},"
            paddle_attr_sig += f""""{arg_names[i]}: float","""
        elif isinstance(arg_defaults[i], bool):
            input_and_attr += f"bool {arg_names[i]},"
            paddle_attr_sig += f""""{arg_names[i]}: bool","""
        elif isinstance(arg_defaults[i], int):
            input_and_attr += f"int64_t {arg_names[i]},"
            paddle_attr_sig += f""""{arg_names[i]}: int64_t","""
        elif isinstance(arg_defaults[i], str):
            input_and_attr += f"std::string {arg_names[i]},"
            paddle_attr_sig += f""""{arg_names[i]}: std::string","""
        elif arg_names[i] == "config":
            continue
        else:
            input_and_attr += f"const paddle::Tensor & {arg_names[i]},"
            paddle_input_sig += f""""{arg_names[i]}","""
    input_and_attr = input_and_attr[:-1]
    paddle_input_sig = paddle_input_sig[:-1]
    if len(paddle_attr_sig) > 1:
        paddle_attr_sig = paddle_attr_sig[:-1]

    paddle_output_sig = ""
    arbitary_output_name = ""
    for name in return_tensor_names.split(","):
        name = name.strip()
        arbitary_output_name = name
        paddle_output_sig += f""""{name}","""
    paddle_output_sig = paddle_output_sig[:-1]

    if "${op_name}_InferShape" not in d2s_infer_code:
        d2s_infer_shape_part = (
            "std::vector<std::vector<int64_t>> ${op_name}_InferShape("
            "const std::vector<int64_t>& A_shape) {"
            "return {${tmp}};"
            "}\n "
        )
        tmp = ",".join(["A_shape"] * len(return_tensor_names.split(",")))
        tmp_dict = {"tmp": tmp}
        d2s_infer_shape_part = SubstituteTemplate(d2s_infer_shape_part, tmp_dict)

        d2s_infer_code += d2s_infer_shape_part

    if "${op_name}_InferDtype" not in d2s_infer_code:
        d2s_infer_dtype_part = (
            "std::vector<paddle::DataType> ${op_name}_InferDtype("
            "const paddle::DataType& A_dtype) {"
            "return {${tmp}};"
            "}\n "
        )
        tmp = ",".join(["A_dtype"] * len(return_tensor_names.split(",")))
        tmp_dict = {"tmp": tmp}
        d2s_infer_dtype_part = SubstituteTemplate(d2s_infer_dtype_part, tmp_dict)

        d2s_infer_code += d2s_infer_dtype_part

    result_str = SubstituteTemplate(
        common_template,
        {
            "input_and_attr": input_and_attr,
            "prepare_attr_for_triton_kernel": prepare_attr_for_triton_kernel,
            "prepare_ptr_for_triton_kernel": prepare_ptr_for_triton_kernel,
            "return_tensor_names": return_tensor_names,
            "arbitary_output_name": arbitary_output_name,
            "d2s_infer_code": d2s_infer_code,
            "paddle_input_sig": paddle_input_sig,
            "paddle_output_sig": paddle_output_sig,
            "paddle_attr_sig": paddle_attr_sig,
        },
    )

    return paddle_custom_op_head_part + result_str


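# Illustrative call of rendering_common_template (editor's sketch; the wrapper
# function and the C++ snippet strings are hypothetical):
#
#     def my_op(x, y, epsilon: float = 1e-6):
#         ...
#
#     prepare_attr = "auto M = x.shape()[0];"
#     prepare_ptr = (
#         "auto out = paddle::empty(x.shape(), x.dtype(), x.place());\n"
#         "CUdeviceptr input_ptrs[3] = {get_tensor_ptr(x), get_tensor_ptr(y), get_tensor_ptr(out)};"
#     )
#     cu_source = rendering_common_template(my_op, prepare_attr, prepare_ptr, "out")
#
# x and y (no default value) become op inputs and epsilon becomes a float
# attribute. The result is paddle_custom_op_head_part plus the filled-in
# common_template; ${op_name} placeholders are left in place and are substituted
# later by KernelInterface when the .cu file is written.

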
class KernelInterface:
    """
    The interface that wraps a triton kernel and turns it into a paddle custom operator.
    """

    def __init__(
        self,
        func,
        other_config,
        key_args=["1"],
    ):
        """
        triton kernel interface.
        """
        self.func = func
        self.key_args = key_args

        signature = inspect.signature(func)
        self.arg_names = [v.name for v in signature.parameters.values()]
        for ele in self.arg_names:
            assert self.arg_names.count(ele) == 1
        # arg_defaults = [v.default for v in signature.parameters.values()]

        # self.annotations = {
        #     name: ty for name, ty in func.__annotations__.items()
        # }
        self.annotations = dict(func.__annotations__)

        self.constexprs = [
            self.arg_names.index(name)
            for name in self.arg_names
            if self.annotations.get(name) == triton.language.core.constexpr
        ]

        self.arg_exclude_constexpr = [
            self.arg_names[i] for i in range(len(self.arg_names)) if i not in self.constexprs
        ]

        import textwrap

        py_script = textwrap.dedent(inspect.getsource(func))

        import re

        pat = r"def\s" + func.__name__
        func_begin = re.findall(pat, py_script)
        assert len(func_begin) == 1
        func_begin = func_begin[0]
        py_script = py_script[py_script.find(func_begin) :]

        def decorator(*args, **kwargs):
            """
            Compile the triton kernel ahead of time (if needed) and register the
            resulting paddle custom operator.

            Args:
                *args: positional arguments
                **kwargs: keyword arguments
            """
            all_input = []

            for i in range(len(args)):
                all_input.append(args[i])

            position_arguments_num = len(all_input)
            for i in range(position_arguments_num, len(self.arg_names)):
                if self.arg_names[i] in kwargs.keys():
                    all_input.append(kwargs[self.arg_names[i]])
                else:
                    # this input is not specified, so it must be a tl.constexpr.
                    assert i in self.constexprs
                    all_input.append(None)

            dtypes = []
            x_list = []
            const_args = [self.arg_names[i] for i in self.constexprs]
            # we do not allow two names in const_args where one is a substring of the other.
            for i in const_args:
                for j in const_args:
                    if i != j and i.find(j) != -1:
                        raise ValueError(
                            f"We find {i}, {j} in tl.constexpr args, and {j} is a substring of {i}, "
                            "please modify your triton kernel arguments names to avoid this."
                        )

            modified_arg_exclude_constexpr = self.arg_exclude_constexpr
            const_hint_dict = {}
            for i in range(len(all_input)):
                ele = all_input[i]
                if (
                    isinstance(ele, paddle.Tensor)
                    or isinstance(ele, paddle.base.framework.EagerParamBase)
                    or isinstance(ele, paddle.base.framework.Parameter)
                    or isinstance(ele, paddle.base.framework.Variable)
                    or isinstance(ele, paddle.base.libpaddle.pir.Value)
                ):
                    dtypes.append(ele.dtype)
                    modified_arg_exclude_constexpr[i] = f"input_ptrs[{i}]"
                elif i in self.constexprs:
                    const_hint_dict[self.arg_names[i]] = ele
                else:
                    x_list.append(ele)

            op_name = self.op_name

            python_package_name = f"{op_name}_package"
            tp_rank = paddle.distributed.get_rank()
            generated_dir = envs.FD_TRITON_KERNEL_CACHE_DIR
            if generated_dir is None:
                generated_dir = f"/tmp/triton_cache/rank{tp_rank}"
            print("the kernel cache dir is:", generated_dir)
            assert generated_dir is not None, (
                "TRITON_KERNEL_CACHE_DIR is None, please set it such as "
                "export TRITON_KERNEL_CACHE_DIR=/tmp/triton_cache "
            )
            generated_dir = f"{generated_dir}/{op_name}"
            os.makedirs(generated_dir, exist_ok=True)

            py_script_file = f"{generated_dir}/triton_kernels.py"
            extract_triton_kernel(func, py_script_file)

            address_hint = get_pointer_hint(dtypes)
            value_hint = get_value_hint(x_list)
            const_args = [f"{{{ele}}}" for ele in const_args]
            const_args = ",".join(const_args)

            lanuch_grid = list(self.grid)
            for i in range(len(lanuch_grid)):
                ele = lanuch_grid[i]
                if isinstance(ele, str):
                    for key in const_hint_dict.keys():
                        if key in ele:
                            ele = ele.replace(key, f"{{{key}}}")
                else:
                    ele = str(ele)

                lanuch_grid[i] = ele
            if len(lanuch_grid) < 3:
                lanuch_grid += ["1"] * (3 - len(lanuch_grid))
            lanuch_grid = ",".join(lanuch_grid)

            op_dict = {"op_name": op_name, "reset_zero_when_tune": ""}
            op_dict["triton_kernel_args"] = ",".join(modified_arg_exclude_constexpr)
            op_dict["key"] = ",".join(self.key_args)
            # when tuning, we need to reset the output to zero.
            if "reset_zero_when_tune" in other_config.keys():
                op_dict["reset_zero_when_tune"] = other_config["reset_zero_when_tune"]

            paddle_custom_op_file_path = f"{generated_dir}/{op_name}.cu"
            so_path = find_so_path(generated_dir, python_package_name)

            if so_path is None:
                print("== we do not find the .so file, so we need to compile it")
                with open(paddle_custom_op_file_path, "w") as f:
                    f.write(
                        SubstituteTemplate(
                            self.custom_op_template,
                            op_dict,
                        )
                    )

                # ahead-of-time compile command.
                aot_template = (
                    f"""{python_path} {compile_file} {py_script_file} """
                    + f""" -n {func.__name__} -o {generated_dir}/{op_name}_kernel """
                    + f"""--out-name {op_name}_kernel """
                    + """ -w {num_warps} -ns {num_stages} """
                    + f""" -s"{address_hint} {value_hint} {const_args}" """
                    + f""" -g "{lanuch_grid}" """
                )
                all_tune_config = list(self.tune_config)
                if len(all_tune_config) == 0:
                    # when the user does not specify a config, we use const_hint_dict as the config.
                    all_tune_config = [const_hint_dict]
                    # reset const_hint_dict to empty.
                    const_hint_dict = {}
                codegen_commands = []
                for config in all_tune_config:
                    for key in const_hint_dict.keys():
                        if const_hint_dict[key] is not None:
                            if key not in config.keys():
                                config[key] = const_hint_dict[key]
                            elif config[key] != const_hint_dict[key]:
                                raise ValueError(
                                    f"you specify {key} both in arguments and config, "
                                    "and they are not the same; this is wrong."
                                )
                        else:
                            assert key in config.keys(), f"you must specify {key} in your config."
                    if "num_warps" not in config.keys():
                        config["num_warps"] = 4
                    if "num_stages" not in config.keys():
                        config["num_stages"] = 4

                    for key in config:
                        assert config[key] is not None, f"{key} must be specified."
                    codegen_command = aot_template.format(
                        **config,
                    )
                    print(codegen_command)
                    codegen_commands.append(codegen_command)
                multi_process_do(codegen_commands)

                link_command = (
                    f"{python_path} {link_file} " f"{generated_dir}/*.h -o {generated_dir}/{op_name}_kernel"
                )
                re = os.system(link_command)
                assert re == 0

                # rename the .c files to .cu
                rename_c_to_cu(generated_dir)
                # build the package into a .so, without installing it
                build_package(generated_dir, python_package_name)

            if op_name not in OpProtoHolder.instance().op_proto_map.keys():
                so_path = find_so_path(generated_dir, python_package_name)
                print("== we find so_path: ", so_path)
                assert so_path is not None
                paddle.utils.cpp_extension.load_op_meta_info_and_register_op(so_path)

        self.decorator = decorator

    def __getitem__(self, op_name_and_grid):
        """
        Override operator [], which stores the op name, C++ template, launch grid and
        tuning configs, and returns the decorator that triggers compilation/registration.

        Args:
            op_name_and_grid: a tuple/list of (op_name, custom_op_template, grid[, tune_configs]).
        Returns:
            the decorator function.
        """
        assert len(op_name_and_grid) >= 3, "len(op_name_and_grid) must >= 3."
        self.op_name = op_name_and_grid[0]
        self.custom_op_template = op_name_and_grid[1]
        self.grid = op_name_and_grid[2]
        if len(op_name_and_grid) == 3:
            self.tune_config = {}
        else:
            self.tune_config = op_name_and_grid[3]

        return self.decorator


def paddle_use_triton(other_config={}, key=[]):
    """
    Decorator factory that turns a triton kernel into a paddle custom operator.

    Args:
        other_config: extra codegen options; currently "reset_zero_when_tune"
            (a C++ snippet executed before each tuning iteration) is recognized.
        key: the expressions whose runtime values form the problem-size key used
            to cache tuning results.
    Returns:
        a decorator that wraps the kernel in a KernelInterface.
    """

    def decorator(func):
        """
        The decorator that wraps the original function.

        Args:
            func: the original triton kernel function.
        Returns:
            the KernelInterface wrapping it.
        """
        return KernelInterface(func, other_config, key)

    return decorator
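

# End-to-end usage sketch (editor's note). Everything below is hypothetical and
# kept as a comment: the kernel, op name, C++ snippets and grid only illustrate
# the calling convention of paddle_use_triton / KernelInterface.
#
#     @paddle_use_triton(key=["M"])
#     def add_kernel(x_ptr, y_ptr, out_ptr, M, BLOCK: tl.constexpr):
#         pid = tl.program_id(0)
#         offs = pid * BLOCK + tl.arange(0, BLOCK)
#         mask = offs < M
#         x = tl.load(x_ptr + offs, mask=mask)
#         y = tl.load(y_ptr + offs, mask=mask)
#         tl.store(out_ptr + offs, x + y, mask=mask)
#
#     def triton_add(x, y):
#         out = paddle.empty_like(x)
#         M = x.shape[0]
#         prepare_attr = "int M = x.shape()[0];"
#         prepare_ptr = (
#             "auto out = paddle::empty(x.shape(), x.dtype(), x.place());\n"
#             "CUdeviceptr input_ptrs[3] = {get_tensor_ptr(x), get_tensor_ptr(y), get_tensor_ptr(out)};"
#         )
#         template = rendering_common_template(triton_add, prepare_attr, prepare_ptr, "out")
#         op_name = get_op_name_with_suffix("triton_add", [M])
#         grid = ("(M + BLOCK - 1) / BLOCK",)
#         # The first call AOT-compiles the kernel, builds the .so and registers
#         # the custom op; afterwards the op can be invoked through paddle's
#         # custom-operator mechanism under the name op_name.
#         add_kernel[(op_name, template, grid, [{"BLOCK": 1024}])](x, y, out, M, BLOCK=1024)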