mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-05 08:37:06 +08:00
update flake8 version to support pre-commit in python3.12 (#3000)
* update flake8 version to support pre-commit in python3.12 * polish code
This commit is contained in:
@@ -103,9 +103,9 @@ def extract_triton_kernel(kernel, file_name):
|
||||
import textwrap
|
||||
|
||||
fn = kernel
|
||||
if type(kernel) == triton.runtime.jit.JITFunction:
|
||||
if isinstance(kernel, triton.runtime.jit.JITFunction):
|
||||
fn = kernel.fn
|
||||
elif type(kernel) == triton.runtime.autotuner.Autotuner:
|
||||
elif isinstance(kernel, triton.runtime.autotuner.Autotuner):
|
||||
fn = kernel.fn.fn
|
||||
else:
|
||||
AssertionError("error occurs")
|
||||
@@ -195,14 +195,14 @@ def get_value_hint(x):
|
||||
"""
|
||||
hint = ""
|
||||
for ele in x:
|
||||
if type(ele) == int:
|
||||
if isinstance(ele, int):
|
||||
if ele % 16 == 0 and ele > 0:
|
||||
hint += "i64:16,"
|
||||
elif ele == 1:
|
||||
hint += "i64:1,"
|
||||
else:
|
||||
hint += "i64,"
|
||||
if type(ele) == float:
|
||||
if isinstance(ele, float):
|
||||
hint += "fp32,"
|
||||
return hint
|
||||
|
||||
@@ -467,16 +467,16 @@ def rendering_common_template(
|
||||
if arg_defaults[i] is None:
|
||||
input_and_attr += f"paddle::optional<paddle::Tensor> & {arg_names[i]},"
|
||||
paddle_input_sig += f"""paddle::Optional("{arg_names[i]}"),"""
|
||||
elif type(arg_defaults[i]) == float:
|
||||
elif isinstance(arg_defaults[i], float):
|
||||
input_and_attr += f"float {arg_names[i]},"
|
||||
paddle_attr_sig += f""""{arg_names[i]}: float","""
|
||||
elif type(arg_defaults[i]) == bool:
|
||||
elif isinstance(arg_defaults[i], bool):
|
||||
input_and_attr += f"bool {arg_names[i]},"
|
||||
paddle_attr_sig += f""""{arg_names[i]}: bool","""
|
||||
elif type(arg_defaults[i]) == int:
|
||||
elif isinstance(arg_defaults[i], int):
|
||||
input_and_attr += f"int64_t {arg_names[i]},"
|
||||
paddle_attr_sig += f""""{arg_names[i]}: int64_t","""
|
||||
elif type(arg_defaults[i]) == str:
|
||||
elif isinstance(arg_defaults[i], str):
|
||||
input_and_attr += f"std::string {arg_names[i]},"
|
||||
paddle_attr_sig += f""""{arg_names[i]}: std::string","""
|
||||
elif arg_names[i] == "config":
|
||||
@@ -629,11 +629,11 @@ class KernelInterface:
|
||||
for i in range(len(all_input)):
|
||||
ele = all_input[i]
|
||||
if (
|
||||
type(ele) == paddle.Tensor
|
||||
or type(ele) == paddle.base.framework.EagerParamBase
|
||||
or type(ele) == paddle.base.framework.Parameter
|
||||
or type(ele) == paddle.base.framework.Variable
|
||||
or type(ele) == paddle.base.libpaddle.pir.Value
|
||||
isinstance(ele, paddle.Tensor)
|
||||
or isinstance(ele, paddle.base.framework.EagerParamBase)
|
||||
or isinstance(ele, paddle.base.framework.Parameter)
|
||||
or isinstance(ele, paddle.base.framework.Variable)
|
||||
or isinstance(ele, paddle.base.libpaddle.pir.Value)
|
||||
):
|
||||
dtypes.append(ele.dtype)
|
||||
modified_arg_exclude_constexpr[i] = f"input_ptrs[{i}]"
|
||||
@@ -668,7 +668,7 @@ class KernelInterface:
|
||||
lanuch_grid = list(self.grid)
|
||||
for i in range(len(lanuch_grid)):
|
||||
ele = lanuch_grid[i]
|
||||
if type(ele) == str:
|
||||
if isinstance(ele, str):
|
||||
for key in const_hint_dict.keys():
|
||||
if key in ele:
|
||||
ele = ele.replace(key, f"{{{key}}}")
|
||||
|
Reference in New Issue
Block a user