Mirror of https://github.com/PaddlePaddle/FastDeploy.git

polish code with new pre-commit rule (#2923)
@@ -81,8 +81,7 @@ def multi_process_do(commands):
             i += THREADS

     for i in range(THREADS):
-        p = multiprocessing.Process(target=one_process_work,
-                                    args=(commands, i))
+        p = multiprocessing.Process(target=one_process_work, args=(commands, i))
         process.append(p)
     for p in process:
         p.start()
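For orientation, the reformatted call above lives in a small strided work-splitting helper: each worker process takes every THREADS-th command. A minimal sketch of that pattern (the worker body and the THREADS value are filled in here as assumptions, not copied from the file):

import multiprocessing
import os

THREADS = 4  # assumed value; the module defines its own constant


def one_process_work(commands, thread_id):
    # Worker thread_id executes commands thread_id, thread_id + THREADS, ...
    i = thread_id
    while i < len(commands):
        assert os.system(commands[i]) == 0
        i += THREADS


def multi_process_do(commands):
    process = []
    for i in range(THREADS):
        p = multiprocessing.Process(target=one_process_work, args=(commands, i))
        process.append(p)
    for p in process:
        p.start()
    for p in process:
        p.join()


if __name__ == "__main__":
    multi_process_do(["echo hello", "echo world"])
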
@@ -118,7 +117,7 @@ def extract_triton_kernel(kernel, file_name):
     # assert len(re.findall("@haha()", py_script)) == 1
     # py_script = py_script.replace("@haha()", "@triton.jit")

-    py_script = py_script[py_script.find("def "):]
+    py_script = py_script[py_script.find("def ") :]
     py_script = "import triton\nimport triton.language as tl\n\n\n@triton.jit\n" + py_script

     py_script = py_script.replace("if bias_ptr is not None", "if bias_ptr")
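The hunk above trims everything before the kernel's "def " line and prepends a fresh @triton.jit header. A rough sketch of that flow, assuming the source is obtained via inspect.getsource (the real function may gather it differently):

import inspect
import textwrap


def extract_triton_kernel_sketch(kernel, file_name):
    # Take the Python source of the (possibly decorated) kernel function.
    py_script = textwrap.dedent(inspect.getsource(kernel))
    # Drop decorators and anything before the function definition itself.
    py_script = py_script[py_script.find("def ") :]
    # Re-attach a clean triton header so the generated file is self-contained.
    py_script = "import triton\nimport triton.language as tl\n\n\n@triton.jit\n" + py_script
    # Pointer arguments are truth-tested directly in the AOT-compiled source.
    py_script = py_script.replace("if bias_ptr is not None", "if bias_ptr")
    with open(file_name, "w") as f:
        f.write(py_script)
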
@@ -245,8 +244,7 @@ def build_package(generated_dir, python_package_name):
     setup_file_path = generated_dir + "/setup_cuda.py"
     python_path = sys.executable
     with open(setup_file_path, "w") as f:
-        f.write(
-            template_install.format(python_package_name=python_package_name))
+        f.write(template_install.format(python_package_name=python_package_name))
     f.close()
     install_command = f"cd {generated_dir} && {python_path} setup_cuda.py build"
     re = os.system(install_command)
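build_package writes a generated setup_cuda.py and shells out to build it. A hypothetical, simplified illustration of that flow (the real template_install emits a full CUDA-extension setup script, which is not reproduced here):

import os
import sys

# Hypothetical stand-in for the module's real setup template.
template_install = """
from setuptools import setup

setup(name="{python_package_name}", version="0.0.1")
"""


def build_package_sketch(generated_dir, python_package_name):
    setup_file_path = generated_dir + "/setup_cuda.py"
    python_path = sys.executable
    with open(setup_file_path, "w") as f:
        f.write(template_install.format(python_package_name=python_package_name))
    install_command = f"cd {generated_dir} && {python_path} setup_cuda.py build"
    assert os.system(install_command) == 0
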
@@ -412,12 +410,15 @@ tune_and_invoke_part = """
 }
 """

-common_template = ("""
+common_template = (
+    """
 std::vector<paddle::Tensor> ${op_name}_func(${input_and_attr}) {
   ${prepare_attr_for_triton_kernel}
   ${prepare_ptr_for_triton_kernel}
   auto run_stream = ${arbitary_output_name}.stream();
-""" + tune_and_invoke_part + """
+"""
+    + tune_and_invoke_part
+    + """
   return {${return_tensor_names}};
 }

@@ -430,7 +431,8 @@ PD_BUILD_OP(${op_name})
     .SetKernelFn(PD_KERNEL(${op_name}_func))
     .SetInferDtypeFn(PD_INFER_DTYPE(${op_name}_InferDtype))
     .SetInferShapeFn(PD_INFER_SHAPE(${op_name}_InferShape));
-""")
+"""
+)


 def rendering_common_template(
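The ${...} placeholders in common_template are later filled by SubstituteTemplate, which is defined elsewhere in this file. One way to get the same ${name} substitution, shown here purely as an illustration, is Python's string.Template:

from string import Template


def substitute_template_sketch(template: str, values: dict) -> str:
    # ${name} placeholders match string.Template's default syntax;
    # safe_substitute leaves unknown keys untouched.
    return Template(template).safe_substitute(values)


snippet = "std::vector<paddle::Tensor> ${op_name}_func(${input_and_attr}) {"
print(substitute_template_sketch(snippet, {"op_name": "my_op", "input_and_attr": "const paddle::Tensor& x"}))
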
@@ -500,11 +502,11 @@ def rendering_common_template(
         "std::vector<std::vector<int64_t>> ${op_name}_InferShape("
         "const std::vector<int64_t>& A_shape) {"
         "return {${tmp}};"
-        "}\n ")
+        "}\n "
+    )
     tmp = ",".join(["A_shape"] * len(return_tensor_names.split(",")))
     tmp_dict = {"tmp": tmp}
-    d2s_infer_shape_part = SubstituteTemplate(d2s_infer_shape_part,
-                                              tmp_dict)
+    d2s_infer_shape_part = SubstituteTemplate(d2s_infer_shape_part, tmp_dict)

     d2s_infer_code += d2s_infer_shape_part

@@ -513,11 +515,11 @@ def rendering_common_template(
         "std::vector<paddle::DataType> ${op_name}_InferDtype("
         "const paddle::DataType& A_dtype) {"
         "return {${tmp}};"
-        "}\n ")
+        "}\n "
+    )
     tmp = ",".join(["A_dtype"] * len(return_tensor_names.split(",")))
     tmp_dict = {"tmp": tmp}
-    d2s_infer_dtype_part = SubstituteTemplate(d2s_infer_dtype_part,
-                                              tmp_dict)
+    d2s_infer_dtype_part = SubstituteTemplate(d2s_infer_dtype_part, tmp_dict)

     d2s_infer_code += d2s_infer_dtype_part

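For both default InferShape and InferDtype bodies, the ${tmp} placeholder simply repeats the first input's shape or dtype once per returned tensor. A worked example of that expansion:

# Two return tensors -> both default to the first input's shape.
return_tensor_names = "out, residual_out"
tmp = ",".join(["A_shape"] * len(return_tensor_names.split(",")))
print(tmp)  # A_shape,A_shape
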
@@ -568,13 +570,13 @@ class KernelInterface:
         self.annotations = dict(func.__annotations__)

         self.constexprs = [
-            self.arg_names.index(name) for name in self.arg_names
+            self.arg_names.index(name)
+            for name in self.arg_names
             if self.annotations.get(name) == triton.language.core.constexpr
         ]

         self.arg_exclude_constexpr = [
-            self.arg_names[i] for i in range(len(self.arg_names))
-            if i not in self.constexprs
+            self.arg_names[i] for i in range(len(self.arg_names)) if i not in self.constexprs
         ]

         import textwrap
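The two comprehensions split the kernel's parameters into tl.constexpr indices and runtime arguments. A small self-contained sketch (requires triton; recovering arg_names via __code__ is an assumption about how the class obtains them):

import triton
import triton.language as tl


def demo_kernel(x_ptr, y_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    pass


arg_names = list(demo_kernel.__code__.co_varnames[: demo_kernel.__code__.co_argcount])
annotations = dict(demo_kernel.__annotations__)

# Indices of compile-time (tl.constexpr) parameters.
constexprs = [
    arg_names.index(name)
    for name in arg_names
    if annotations.get(name) == triton.language.core.constexpr
]
# Runtime arguments that must be passed at launch.
arg_exclude_constexpr = [arg_names[i] for i in range(len(arg_names)) if i not in constexprs]

print(constexprs)             # [3]
print(arg_exclude_constexpr)  # ['x_ptr', 'y_ptr', 'n_elements']
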
@@ -587,7 +589,7 @@ class KernelInterface:
         func_begin = re.findall(pat, py_script)
         assert len(func_begin) == 1
         func_begin = func_begin[0]
-        py_script = py_script[py_script.find(func_begin):]
+        py_script = py_script[py_script.find(func_begin) :]

         def decorator(*args, **kwargs):
             """
@@ -626,11 +628,13 @@ class KernelInterface:
             const_hint_dict = {}
             for i in range(len(all_input)):
                 ele = all_input[i]
-                if (type(ele) == paddle.Tensor
-                        or type(ele) == paddle.base.framework.EagerParamBase
-                        or type(ele) == paddle.base.framework.Parameter
-                        or type(ele) == paddle.base.framework.Variable
-                        or type(ele) == paddle.base.libpaddle.pir.Value):
+                if (
+                    type(ele) == paddle.Tensor
+                    or type(ele) == paddle.base.framework.EagerParamBase
+                    or type(ele) == paddle.base.framework.Parameter
+                    or type(ele) == paddle.base.framework.Variable
+                    or type(ele) == paddle.base.libpaddle.pir.Value
+                ):
                     dtypes.append(ele.dtype)
                     modified_arg_exclude_constexpr[i] = f"input_ptrs[{i}]"
                 elif i in self.constexprs:
@@ -646,9 +650,10 @@ class KernelInterface:
             if generated_dir is None:
                 generated_dir = f"/tmp/triton_cache/rank{tp_rank}"
                 print("the kernel cache dir is:", generated_dir)
-            assert (generated_dir is not None), (
-                "TRITON_KERNEL_CACHE_DIR is None, please set it such as "
-                "export TRITON_KERNEL_CACHE_DIR=/tmp/triton_cache ")
+            assert generated_dir is not None, (
+                "TRITON_KERNEL_CACHE_DIR is None, please set it such as "
+                "export TRITON_KERNEL_CACHE_DIR=/tmp/triton_cache "
+            )
             generated_dir = f"{generated_dir}/{op_name}"
             os.makedirs(generated_dir, exist_ok=True)

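Around this hunk, the kernel cache directory falls back to a per-rank path under /tmp and then gets an op-specific subdirectory. A minimal sketch of that resolution, assuming the directory comes from the TRITON_KERNEL_CACHE_DIR environment variable as the assertion message suggests:

import os

tp_rank = 0  # assumed; the real code knows the tensor-parallel rank
op_name = "demo_op"

generated_dir = os.getenv("TRITON_KERNEL_CACHE_DIR", None)
if generated_dir is None:
    generated_dir = f"/tmp/triton_cache/rank{tp_rank}"
    print("the kernel cache dir is:", generated_dir)
generated_dir = f"{generated_dir}/{op_name}"
os.makedirs(generated_dir, exist_ok=True)
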
@@ -676,13 +681,11 @@ class KernelInterface:
             lanuch_grid = ",".join(lanuch_grid)

             op_dict = {"op_name": op_name, "reset_zero_when_tune": ""}
-            op_dict["triton_kernel_args"] = ",".join(
-                modified_arg_exclude_constexpr)
+            op_dict["triton_kernel_args"] = ",".join(modified_arg_exclude_constexpr)
             op_dict["key"] = ",".join(self.key_args)
             # when tunning, we need to reset the out to zero.
             if "reset_zero_when_tune" in other_config.keys():
-                op_dict["reset_zero_when_tune"] = other_config[
-                    "reset_zero_when_tune"]
+                op_dict["reset_zero_when_tune"] = other_config["reset_zero_when_tune"]

             paddle_custom_op_file_path = f"{generated_dir}/{op_name}.cu"
             so_path = find_so_path(generated_dir, python_package_name)
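The op_dict assembled here feeds the custom-op template substitution below. A worked example with made-up values, just to show what the joined fields look like:

# Illustrative values only; real ones come from the wrapped kernel.
op_dict = {"op_name": "demo_op", "reset_zero_when_tune": ""}
modified_arg_exclude_constexpr = ["input_ptrs[0]", "input_ptrs[1]", "n_elements"]
key_args = ["n_elements"]
other_config = {"reset_zero_when_tune": "cudaMemsetAsync(out_ptr, 0, out_size, stream);"}

op_dict["triton_kernel_args"] = ",".join(modified_arg_exclude_constexpr)
op_dict["key"] = ",".join(key_args)
if "reset_zero_when_tune" in other_config.keys():
    op_dict["reset_zero_when_tune"] = other_config["reset_zero_when_tune"]

print(op_dict["triton_kernel_args"])  # input_ptrs[0],input_ptrs[1],n_elements
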
@@ -694,17 +697,19 @@ class KernelInterface:
                     SubstituteTemplate(
                         self.custom_op_template,
                         op_dict,
-                    ))
+                    )
+                )
                 f.close()

                 # ahead of time compile command.
                 aot_template = (
-                    f"""{python_path} {compile_file} {py_script_file} """ +
-                    f""" -n {func.__name__} -o {generated_dir}/{op_name}_kernel """
-                    + f"""--out-name {op_name}_kernel """ +
-                    """ -w {num_warps} -ns {num_stages} """ +
-                    f""" -s"{address_hint} {value_hint} {const_args}" """ +
-                    f""" -g "{lanuch_grid}" """)
+                    f"""{python_path} {compile_file} {py_script_file} """
+                    + f""" -n {func.__name__} -o {generated_dir}/{op_name}_kernel """
+                    + f"""--out-name {op_name}_kernel """
+                    + """ -w {num_warps} -ns {num_stages} """
+                    + f""" -s"{address_hint} {value_hint} {const_args}" """
+                    + f""" -g "{lanuch_grid}" """
+                )
                 all_tune_config = list(self.tune_config)
                 if len(all_tune_config) == 0:
                     # when user do not specify config, we use const_hint_dict as config.
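Note that only the " -w {num_warps} -ns {num_stages} " piece is a plain (non-f) string, so those two placeholders survive until aot_template.format(**config) below. A worked example with placeholder paths:

# Placeholder paths and names, not the repo's real ones.
aot_template = (
    "python compile.py my_kernel.py "
    + " -n demo_kernel -o /tmp/triton_cache/demo_op_kernel "
    + "--out-name demo_op_kernel "
    + " -w {num_warps} -ns {num_stages} "
)
config = {"num_warps": 4, "num_stages": 4}
print(aot_template.format(**config))
# python compile.py my_kernel.py  -n demo_kernel -o /tmp/triton_cache/demo_op_kernel --out-name demo_op_kernel  -w 4 -ns 4
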
@@ -727,24 +732,24 @@ class KernelInterface:
                         )
                         raise ValueError(message)
                     else:
-                        assert key in config.keys(
-                        ), f"you must specify {key} in your config."
+                        assert key in config.keys(), f"you must specify {key} in your config."
                 if "num_warps" not in config.keys():
                     config["num_warps"] = 4
                 if "num_stages" not in config.keys():
                     config["num_stages"] = 4

                 for key in config:
-                    assert config[
-                        key] is not None, f"{key} must be specified."
-                codegen_command = aot_template.format(**config, )
+                    assert config[key] is not None, f"{key} must be specified."
+                codegen_command = aot_template.format(
+                    **config,
+                )
                 print(codegen_command)
                 codegen_commands.append(codegen_command)
             multi_process_do(codegen_commands)

             link_command = (
-                f"{python_path} {link_file} "
-                f"{generated_dir}/*.h -o {generated_dir}/{op_name}_kernel")
+                f"{python_path} {link_file} " f"{generated_dir}/*.h -o {generated_dir}/{op_name}_kernel"
+            )
             re = os.system(link_command)
             assert re == 0

@@ -757,8 +762,7 @@ class KernelInterface:
             so_path = find_so_path(generated_dir, python_package_name)
             print("== we find so_path: ", so_path)
             assert so_path is not None
-            paddle.utils.cpp_extension.load_op_meta_info_and_register_op(
-                so_path)
+            paddle.utils.cpp_extension.load_op_meta_info_and_register_op(so_path)

         self.decorator = decorator
