update flake8 version to support pre-commit in python3.12 (#3000)

* update flake8 version to support pre-commit in python3.12

* polish code
This commit is contained in:
Zero Rains
2025-07-24 16:43:31 +08:00
committed by GitHub
parent 5151bc92c8
commit 0fb37ab7e4
30 changed files with 324 additions and 275 deletions

View File

@@ -103,9 +103,9 @@ def extract_triton_kernel(kernel, file_name):
import textwrap
fn = kernel
if type(kernel) == triton.runtime.jit.JITFunction:
if isinstance(kernel, triton.runtime.jit.JITFunction):
fn = kernel.fn
elif type(kernel) == triton.runtime.autotuner.Autotuner:
elif isinstance(kernel, triton.runtime.autotuner.Autotuner):
fn = kernel.fn.fn
else:
AssertionError("error occurs")
@@ -195,14 +195,14 @@ def get_value_hint(x):
"""
hint = ""
for ele in x:
if type(ele) == int:
if isinstance(ele, int):
if ele % 16 == 0 and ele > 0:
hint += "i64:16,"
elif ele == 1:
hint += "i64:1,"
else:
hint += "i64,"
if type(ele) == float:
if isinstance(ele, float):
hint += "fp32,"
return hint
@@ -467,16 +467,16 @@ def rendering_common_template(
if arg_defaults[i] is None:
input_and_attr += f"paddle::optional<paddle::Tensor> & {arg_names[i]},"
paddle_input_sig += f"""paddle::Optional("{arg_names[i]}"),"""
elif type(arg_defaults[i]) == float:
elif isinstance(arg_defaults[i], float):
input_and_attr += f"float {arg_names[i]},"
paddle_attr_sig += f""""{arg_names[i]}: float","""
elif type(arg_defaults[i]) == bool:
elif isinstance(arg_defaults[i], bool):
input_and_attr += f"bool {arg_names[i]},"
paddle_attr_sig += f""""{arg_names[i]}: bool","""
elif type(arg_defaults[i]) == int:
elif isinstance(arg_defaults[i], int):
input_and_attr += f"int64_t {arg_names[i]},"
paddle_attr_sig += f""""{arg_names[i]}: int64_t","""
elif type(arg_defaults[i]) == str:
elif isinstance(arg_defaults[i], str):
input_and_attr += f"std::string {arg_names[i]},"
paddle_attr_sig += f""""{arg_names[i]}: std::string","""
elif arg_names[i] == "config":
@@ -629,11 +629,11 @@ class KernelInterface:
for i in range(len(all_input)):
ele = all_input[i]
if (
type(ele) == paddle.Tensor
or type(ele) == paddle.base.framework.EagerParamBase
or type(ele) == paddle.base.framework.Parameter
or type(ele) == paddle.base.framework.Variable
or type(ele) == paddle.base.libpaddle.pir.Value
isinstance(ele, paddle.Tensor)
or isinstance(ele, paddle.base.framework.EagerParamBase)
or isinstance(ele, paddle.base.framework.Parameter)
or isinstance(ele, paddle.base.framework.Variable)
or isinstance(ele, paddle.base.libpaddle.pir.Value)
):
dtypes.append(ele.dtype)
modified_arg_exclude_constexpr[i] = f"input_ptrs[{i}]"
@@ -668,7 +668,7 @@ class KernelInterface:
lanuch_grid = list(self.grid)
for i in range(len(lanuch_grid)):
ele = lanuch_grid[i]
if type(ele) == str:
if isinstance(ele, str):
for key in const_hint_dict.keys():
if key in ele:
ele = ele.replace(key, f"{{{key}}}")