polish code with new pre-commit rule (#2923)

Author:       Zero Rains
Date:         2025-07-19 23:19:27 +08:00
Committed by: GitHub
Parent:       b8676d71a8
Commit:       25698d56d1
424 changed files with 14307 additions and 13518 deletions
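All hunks below are mechanical reformatting applied by the new pre-commit rule; none of them changes behavior. The pattern throughout is consistent with a line-length-aware formatter such as black (the exact pre-commit configuration is not shown in this commit): short wrapped calls are joined onto one line, long parenthesized expressions are exploded one element per line with the closing parenthesis on its own line, and slices with non-trivial bounds gain symmetric spaces around the colon.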


@@ -81,8 +81,7 @@ def multi_process_do(commands):
             i += THREADS

     for i in range(THREADS):
-        p = multiprocessing.Process(target=one_process_work,
-                                    args=(commands, i))
+        p = multiprocessing.Process(target=one_process_work, args=(commands, i))
         process.append(p)
     for p in process:
         p.start()
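For context, `multi_process_do` fans the generated compile commands out over a fixed pool of worker processes, with worker `i` taking commands `i`, `i + THREADS`, `i + 2 * THREADS`, and so on. A minimal runnable sketch of that round-robin pattern, reconstructed from the hunk above (the `one_process_work` body, `THREADS` value, and the final `join` loop are assumptions, not the file's exact code):

```python
import multiprocessing
import os

THREADS = 4  # assumed pool size


def one_process_work(commands, thread_id):
    # Worker thread_id executes every THREADS-th command, starting at its id.
    i = thread_id
    while i < len(commands):
        assert os.system(commands[i]) == 0
        i += THREADS


def multi_process_do(commands):
    process = []
    for i in range(THREADS):
        p = multiprocessing.Process(target=one_process_work, args=(commands, i))
        process.append(p)
    for p in process:
        p.start()
    for p in process:
        p.join()


if __name__ == "__main__":
    multi_process_do([f"echo compiling kernel {n}" for n in range(8)])
```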
@@ -118,7 +117,7 @@ def extract_triton_kernel(kernel, file_name):
     # assert len(re.findall("@haha()", py_script)) == 1
     # py_script = py_script.replace("@haha()", "@triton.jit")
-    py_script = py_script[py_script.find("def "):]
+    py_script = py_script[py_script.find("def ") :]
     py_script = "import triton\nimport triton.language as tl\n\n\n@triton.jit\n" + py_script
     py_script = py_script.replace("if bias_ptr is not None", "if bias_ptr")
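The only change in this hunk is the space before the slice colon: black puts symmetric spaces around `:` when a slice bound is a non-trivial expression (the treatment PEP 8 allows and flake8's E203 otherwise flags). The slice value is identical either way, as a tiny self-contained check shows:

```python
# Both spellings produce the same slice; only the formatting differs.
src = "# header\ndef kernel():\n    pass\n"
assert src[src.find("def ") :] == src[src.find("def "):]
```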
@@ -245,8 +244,7 @@ def build_package(generated_dir, python_package_name):
     setup_file_path = generated_dir + "/setup_cuda.py"
     python_path = sys.executable
     with open(setup_file_path, "w") as f:
-        f.write(
-            template_install.format(python_package_name=python_package_name))
+        f.write(template_install.format(python_package_name=python_package_name))
         f.close()
     install_command = f"cd {generated_dir} && {python_path} setup_cuda.py build"
     re = os.system(install_command)
@@ -412,12 +410,15 @@ tune_and_invoke_part = """
 }
 """

-common_template = ("""
+common_template = (
+    """
 std::vector<paddle::Tensor> ${op_name}_func(${input_and_attr}) {
   ${prepare_attr_for_triton_kernel}
   ${prepare_ptr_for_triton_kernel}
   auto run_stream = ${arbitary_output_name}.stream();
-""" + tune_and_invoke_part + """
+"""
+    + tune_and_invoke_part
+    + """
   return {${return_tensor_names}};
 }

@@ -430,7 +431,8 @@ PD_BUILD_OP(${op_name})
     .SetKernelFn(PD_KERNEL(${op_name}_func))
     .SetInferDtypeFn(PD_INFER_DTYPE(${op_name}_InferDtype))
     .SetInferShapeFn(PD_INFER_SHAPE(${op_name}_InferShape));
-""")
+"""
+)


 def rendering_common_template(
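`SubstituteTemplate` itself is not part of this diff; judging by the `${...}` placeholders in `common_template`, it behaves like `string.Template` substitution. A hedged stand-in (the real helper may differ), with a hypothetical op name:

```python
from string import Template


def substitute_template(template: str, values: dict) -> str:
    # safe_substitute leaves unknown ${...} placeholders in place, so a
    # template can be filled in several passes, as the rendering code does
    # with op_name, return_tensor_names, tmp, and friends.
    return Template(template).safe_substitute(values)


snippet = "PD_BUILD_OP(${op_name})\n    .SetKernelFn(PD_KERNEL(${op_name}_func))"
print(substitute_template(snippet, {"op_name": "my_triton_op"}))  # op name is hypothetical
```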
@@ -500,11 +502,11 @@ def rendering_common_template(
         d2s_infer_shape_part = (
             "std::vector<std::vector<int64_t>> ${op_name}_InferShape("
             "const std::vector<int64_t>& A_shape) {"
             "return {${tmp}};"
-            "}\n ")
+            "}\n "
+        )
         tmp = ",".join(["A_shape"] * len(return_tensor_names.split(",")))
         tmp_dict = {"tmp": tmp}
-        d2s_infer_shape_part = SubstituteTemplate(d2s_infer_shape_part,
-                                                  tmp_dict)
+        d2s_infer_shape_part = SubstituteTemplate(d2s_infer_shape_part, tmp_dict)

         d2s_infer_code += d2s_infer_shape_part
@@ -513,11 +515,11 @@ def rendering_common_template(
         d2s_infer_dtype_part = (
             "std::vector<paddle::DataType> ${op_name}_InferDtype("
             "const paddle::DataType& A_dtype) {"
             "return {${tmp}};"
-            "}\n ")
+            "}\n "
+        )
         tmp = ",".join(["A_dtype"] * len(return_tensor_names.split(",")))
         tmp_dict = {"tmp": tmp}
-        d2s_infer_dtype_part = SubstituteTemplate(d2s_infer_dtype_part,
-                                                  tmp_dict)
+        d2s_infer_dtype_part = SubstituteTemplate(d2s_infer_dtype_part, tmp_dict)

         d2s_infer_code += d2s_infer_dtype_part
@@ -568,13 +570,13 @@ class KernelInterface:
         self.annotations = dict(func.__annotations__)

         self.constexprs = [
-            self.arg_names.index(name) for name in self.arg_names
+            self.arg_names.index(name)
+            for name in self.arg_names
             if self.annotations.get(name) == triton.language.core.constexpr
         ]

         self.arg_exclude_constexpr = [
-            self.arg_names[i] for i in range(len(self.arg_names))
-            if i not in self.constexprs
+            self.arg_names[i] for i in range(len(self.arg_names)) if i not in self.constexprs
         ]

         import textwrap
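The two list comprehensions above split a Triton kernel's parameters into compile-time `constexpr` arguments and runtime arguments, by index. A self-contained sketch of the same logic, using a stand-in class so it runs without triton installed:

```python
import inspect


class constexpr:  # stand-in for triton.language.core.constexpr
    pass


def my_kernel(x_ptr, y_ptr, n_elements, BLOCK_SIZE: constexpr):
    pass


arg_names = list(inspect.signature(my_kernel).parameters)
annotations = dict(my_kernel.__annotations__)

# Indices of parameters annotated as constexpr (compile-time constants).
constexprs = [arg_names.index(name) for name in arg_names if annotations.get(name) == constexpr]
# Names of the remaining, runtime parameters.
arg_exclude_constexpr = [arg_names[i] for i in range(len(arg_names)) if i not in constexprs]

assert constexprs == [3]
assert arg_exclude_constexpr == ["x_ptr", "y_ptr", "n_elements"]
```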
@@ -587,7 +589,7 @@ class KernelInterface:
         func_begin = re.findall(pat, py_script)
         assert len(func_begin) == 1
         func_begin = func_begin[0]
-        py_script = py_script[py_script.find(func_begin):]
+        py_script = py_script[py_script.find(func_begin) :]

         def decorator(*args, **kwargs):
             """
@@ -626,11 +628,13 @@ class KernelInterface:
             const_hint_dict = {}
             for i in range(len(all_input)):
                 ele = all_input[i]
-                if (type(ele) == paddle.Tensor
-                        or type(ele) == paddle.base.framework.EagerParamBase
-                        or type(ele) == paddle.base.framework.Parameter
-                        or type(ele) == paddle.base.framework.Variable
-                        or type(ele) == paddle.base.libpaddle.pir.Value):
+                if (
+                    type(ele) == paddle.Tensor
+                    or type(ele) == paddle.base.framework.EagerParamBase
+                    or type(ele) == paddle.base.framework.Parameter
+                    or type(ele) == paddle.base.framework.Variable
+                    or type(ele) == paddle.base.libpaddle.pir.Value
+                ):
                     dtypes.append(ele.dtype)
                     modified_arg_exclude_constexpr[i] = f"input_ptrs[{i}]"
                 elif i in self.constexprs:
@@ -646,9 +650,10 @@ class KernelInterface:
             if generated_dir is None:
                 generated_dir = f"/tmp/triton_cache/rank{tp_rank}"
                 print("the kernel cache dir is:", generated_dir)
-            assert (generated_dir is not None), (
+            assert generated_dir is not None, (
                 "TRITON_KERNEL_CACHE_DIR is None, please set it such as "
-                "export TRITON_KERNEL_CACHE_DIR=/tmp/triton_cache ")
+                "export TRITON_KERNEL_CACHE_DIR=/tmp/triton_cache "
+            )

             generated_dir = f"{generated_dir}/{op_name}"
             os.makedirs(generated_dir, exist_ok=True)
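For orientation, the assert above guards the kernel-cache directory resolution. A minimal sketch of that flow, under the assumption (suggested by the error message and the fallback path) that `generated_dir` originates from the `TRITON_KERNEL_CACHE_DIR` environment variable:

```python
import os

tp_rank = 0  # hypothetical tensor-parallel rank

# Prefer the user-set cache dir; fall back to a per-rank temp dir.
generated_dir = os.getenv("TRITON_KERNEL_CACHE_DIR")
if generated_dir is None:
    generated_dir = f"/tmp/triton_cache/rank{tp_rank}"
os.makedirs(generated_dir, exist_ok=True)
```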
@@ -676,13 +681,11 @@ class KernelInterface:
             lanuch_grid = ",".join(lanuch_grid)

             op_dict = {"op_name": op_name, "reset_zero_when_tune": ""}
-            op_dict["triton_kernel_args"] = ",".join(
-                modified_arg_exclude_constexpr)
+            op_dict["triton_kernel_args"] = ",".join(modified_arg_exclude_constexpr)
             op_dict["key"] = ",".join(self.key_args)
             # when tunning, we need to reset the out to zero.
             if "reset_zero_when_tune" in other_config.keys():
-                op_dict["reset_zero_when_tune"] = other_config[
-                    "reset_zero_when_tune"]
+                op_dict["reset_zero_when_tune"] = other_config["reset_zero_when_tune"]

             paddle_custom_op_file_path = f"{generated_dir}/{op_name}.cu"
             so_path = find_so_path(generated_dir, python_package_name)
@@ -694,17 +697,19 @@ class KernelInterface:
                     SubstituteTemplate(
                         self.custom_op_template,
                         op_dict,
-                    ))
+                    )
+                )
                 f.close()

                 # ahead of time compile command.
                 aot_template = (
-                    f"""{python_path} {compile_file} {py_script_file} """ +
-                    f""" -n {func.__name__} -o {generated_dir}/{op_name}_kernel """
-                    + f"""--out-name {op_name}_kernel """ +
-                    """ -w {num_warps} -ns {num_stages} """ +
-                    f""" -s"{address_hint} {value_hint} {const_args}" """ +
-                    f""" -g "{lanuch_grid}" """)
+                    f"""{python_path} {compile_file} {py_script_file} """
+                    + f""" -n {func.__name__} -o {generated_dir}/{op_name}_kernel """
+                    + f"""--out-name {op_name}_kernel """
+                    + """ -w {num_warps} -ns {num_stages} """
+                    + f""" -s"{address_hint} {value_hint} {const_args}" """
+                    + f""" -g "{lanuch_grid}" """
+                )
                 all_tune_config = list(self.tune_config)
                 if len(all_tune_config) == 0:
                     # when user do not specify config, we use const_hint_dict as config.
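One subtlety the reformatted concatenation preserves: the `-w {num_warps} -ns {num_stages}` segment is deliberately a plain (non-f) string, so `{num_warps}` and `{num_stages}` survive until `aot_template.format(**config)` fills them per tuning config, while the f-string segments interpolate immediately. A reduced illustration with assumed placeholder values:

```python
# Assumed values, for illustration only.
python_path = "/usr/bin/python3"
compile_file = "compile.py"
py_script_file = "my_op_kernel.py"

aot_template = (
    f"""{python_path} {compile_file} {py_script_file} """
    + """ -w {num_warps} -ns {num_stages} """  # filled later by .format(**config)
)
print(aot_template.format(num_warps=4, num_stages=4))
```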
@@ -727,24 +732,24 @@ class KernelInterface:
                         )
                         raise ValueError(message)
                     else:
-                        assert key in config.keys(
-                        ), f"you must specify {key} in your config."
+                        assert key in config.keys(), f"you must specify {key} in your config."
                 if "num_warps" not in config.keys():
                     config["num_warps"] = 4
                 if "num_stages" not in config.keys():
                     config["num_stages"] = 4
                 for key in config:
-                    assert config[
-                        key] is not None, f"{key} must be specified."
-                codegen_command = aot_template.format(**config, )
+                    assert config[key] is not None, f"{key} must be specified."
+                codegen_command = aot_template.format(
+                    **config,
+                )
                 print(codegen_command)
                 codegen_commands.append(codegen_command)

             multi_process_do(codegen_commands)

             link_command = (
-                f"{python_path} {link_file} "
-                f"{generated_dir}/*.h -o {generated_dir}/{op_name}_kernel")
+                f"{python_path} {link_file} " f"{generated_dir}/*.h -o {generated_dir}/{op_name}_kernel"
+            )
             re = os.system(link_command)
             assert re == 0

@@ -757,8 +762,7 @@ class KernelInterface:

             so_path = find_so_path(generated_dir, python_package_name)
             print("== we find so_path: ", so_path)
             assert so_path is not None
-            paddle.utils.cpp_extension.load_op_meta_info_and_register_op(
-                so_path)
+            paddle.utils.cpp_extension.load_op_meta_info_and_register_op(so_path)

         self.decorator = decorator
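After linking, the resulting shared object is loaded so the AOT-compiled kernel becomes a regular Paddle custom op. A sketch of that last step; the path is hypothetical (in the real flow it comes from `find_so_path(...)`), while the loader call is the one used in the hunk above:

```python
import os

import paddle

# Illustrative path only; the real one is discovered by find_so_path(...).
so_path = "/tmp/triton_cache/rank0/my_triton_op/build/lib/my_triton_op_package.so"
if os.path.exists(so_path):
    paddle.utils.cpp_extension.load_op_meta_info_and_register_op(so_path)
```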