Sync v2.0 version of code to GitHub repo

This commit is contained in:
Jiang-Jia-Jun
2025-06-29 23:29:37 +00:00
parent d151496038
commit 92c2cfa2e7
597 changed files with 78776 additions and 22905 deletions


@@ -0,0 +1,22 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
try:
from .wint2_fused_moe import fused_moe_wint2_triton
__all__ = ["fused_moe_wint2_triton"]
except ImportError:
    # the Triton-based wint2 fused MoE is optional; skip the export when Triton is unavailable
    pass


@@ -0,0 +1,804 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
import inspect
import os
import re
import sys
import paddle
import triton
from paddle.base.framework import OpProtoHolder
from fastdeploy import envs
compile_file = triton.__path__[0] + "/tools/compile.py"
link_file = triton.__path__[0] + "/tools/link.py"
python_path = sys.executable
def SubstituteTemplate(template, values):
"""
Substitute all variables in the given template string using the provided values dictionary.
"""
text = template
changed = True
while changed:
changed = False
for key, value in values.items():
regex = "\\$\\{%s\\}" % key
newtext = re.sub(regex, value, text)
if newtext != text:
changed = True
text = newtext
return text
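# A minimal usage sketch (illustrative values, not from the codebase):
#   SubstituteTemplate("${op_name}_kernel", {"op_name": "moe_ffn"}) returns "moe_ffn_kernel".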
def find_so_path(generated_dir, python_package_name):
"""
Find the specified .so file in generated_dir; return None if it is not found.
"""
so_path = []
for root, dirs, files in os.walk(generated_dir):
for file in files:
if file.endswith(python_package_name + ".so"):
so_path.append(os.path.join(root, file))
if len(so_path) == 0:
return None
else:
assert len(so_path) == 1
return so_path[0]
def multi_process_do(commands):
"""
Run the given shell commands in parallel across multiple worker processes.
"""
NUM_PROCESSES = 40
import multiprocessing
process = []
def one_process_work(commands, process_id):
# worker `process_id` handles commands process_id, process_id + NUM_PROCESSES, ...
i = process_id
while i < len(commands):
ret = os.system(commands[i])
assert ret == 0
i += NUM_PROCESSES
for i in range(NUM_PROCESSES):
p = multiprocessing.Process(target=one_process_work,
args=(commands, i))
process.append(p)
for p in process:
p.start()
for p in process:
p.join()
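# Hedged usage sketch: each command is a shell string, e.g.
#   multi_process_do([f"{python_path} {compile_file} ...", ...])
# Worker i executes commands i, i + NUM_PROCESSES, i + 2 * NUM_PROCESSES, and so on.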
def extract_triton_kernel(kernel, file_name):
"""
Extract the triton kernel and write it to the specified file_name.
Args:
kernel: the triton kernel name.
file_name: the file name you want to write.
"""
import inspect
import re
import textwrap
fn = kernel
if type(kernel) == triton.runtime.jit.JITFunction:
fn = kernel.fn
elif type(kernel) == triton.runtime.autotuner.Autotuner:
fn = kernel.fn.fn
else:
raise AssertionError("unsupported kernel type: " + str(type(kernel)))
py_script = textwrap.dedent(inspect.getsource(fn))
# @triton.jit must only appear once
# assert len(re.findall("@triton.jit", py_script)) == 1
assert len(re.findall("def ", py_script)) == 1
# assert len(re.findall("@haha()", py_script)) == 1
# py_script = py_script.replace("@haha()", "@triton.jit")
py_script = py_script[py_script.find("def "):]
py_script = "import triton\nimport triton.language as tl\n\n\n@triton.jit\n" + py_script
py_script = py_script.replace("if bias_ptr is not None", "if bias_ptr")
with open(file_name, "w") as f:
f.write(py_script)
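# Hedged sketch of the generated file for a kernel `def my_kernel(x_ptr, ...)` (name is illustrative):
#   import triton
#   import triton.language as tl
#
#   @triton.jit
#   def my_kernel(x_ptr, ...):
#       ...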
template_install = """
import os
generated_cu = []
for root, dirs, files in os.walk("./"):
for file in files:
if file.endswith(".c") or file.endswith(".cu"):
generated_cu.append(os.path.join(root, file))
import paddle
from paddle.utils.cpp_extension import CUDAExtension, setup
def get_gencode_flags():
prop = paddle.device.cuda.get_device_properties()
cc = prop.major * 10 + prop.minor
return ["-gencode", "arch=compute_{{0}},code=sm_{{0}}".format(cc)]
gencode_flags = get_gencode_flags()
setup(
name="{python_package_name}",
ext_modules=CUDAExtension(
sources = generated_cu,
extra_compile_args={{
"cc": ["-lcuda"],
"nvcc": [
"-O3",
"-U__CUDA_NO_HALF_OPERATORS__",
"-U__CUDA_NO_HALF_CONVERSIONS__",
"-U__CUDA_NO_BFLOAT16_OPERATORS__",
"-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
"-U__CUDA_NO_BFLOAT162_OPERATORS__",
"-U__CUDA_NO_BFLOAT162_CONVERSIONS__",
]
+ gencode_flags,
}},
extra_link_args = ["-lcuda"]
),
)
"""
def get_op_name_with_suffix(op_name, x_list):
"""
Get the operator name with suffix.
"""
suffix = []
for x in x_list:
if x % 16 == 0:
suffix.append(16)
elif x == 1:
suffix.append(1)
else:
suffix.append(0)
return op_name + "_".join([str(i) for i in suffix])
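# Hedged example: get_op_name_with_suffix("moe_ffn", [64, 1, 5]) returns "moe_ffn16_1_0"
# (multiples of 16 map to 16, exact 1 maps to 1, everything else maps to 0).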
def get_value_hint(x):
"""
Get the value hint from input list.
"""
hint = ""
for ele in x:
if type(ele) == int:
if ele % 16 == 0 and ele > 0:
hint += "i64:16,"
elif ele == 1:
hint += "i64:1,"
else:
hint += "i64,"
if type(ele) == float:
hint += "fp32,"
return hint
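# Hedged example: get_value_hint([64, 1, 7, 0.5]) returns "i64:16,i64:1,i64,fp32,".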
def get_dtype_str(dtype):
"""
Get the dtype str.
"""
if dtype == paddle.float16:
return "_fp16"
if dtype == paddle.float8_e4m3fn:
return "_float8_e4m3fn"
elif dtype == paddle.uint8:
return "_u8"
elif dtype == paddle.int8:
return "_i8"
elif dtype == paddle.int16:
return "_i16"
elif dtype == paddle.int32:
return "_i32"
elif dtype == paddle.int64:
return "_i64"
elif dtype == paddle.float32:
return "_fp32"
elif dtype == paddle.bfloat16:
return "_bf16"
else:
raise ValueError(f"Unsupported dtype: {dtype}.")
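# Hedged example: get_dtype_str(paddle.bfloat16) returns "_bf16".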
def build_package(generated_dir, python_package_name):
"""
Build the package, not install it.
Args:
generated_dir: the source cu file dir.
python_package_name: the python package name.
"""
setup_file_path = generated_dir + "/setup_cuda.py"
python_path = sys.executable
with open(setup_file_path, "w") as f:
f.write(
template_install.format(python_package_name=python_package_name))
install_command = f"cd {generated_dir} && {python_path} setup_cuda.py build"
ret = os.system(install_command)
assert ret == 0
def rename_c_to_cu(generated_dir):
"""
Rename the .c files in generated_dir to .cu files, because the Triton AOT tool generates .c files.
"""
# rename the .c file to .cu
for filename in os.listdir(generated_dir):
if filename.endswith(".c"):
old_path = os.path.join(generated_dir, filename)
new_path = os.path.join(generated_dir, filename + "u")
os.rename(old_path, new_path)
def get_pointer_hint(dtypes):
"""
Get the pointer hint from input list.
"""
hint = ""
for ele in dtypes:
if ele == paddle.float16:
hint += "*fp16:16,"
elif ele == paddle.uint8:
hint += "*u8:16,"
elif ele == paddle.int8:
hint += "*i8:16,"
elif ele == paddle.int16:
hint += "*i16:16,"
elif ele == paddle.float32:
hint += "*fp32:16,"
elif ele == paddle.bfloat16:
hint += "*bf16:16,"
elif ele == paddle.int32:
hint += "*i32:16,"
elif ele == paddle.int64:
hint += "*i64,"
elif ele == paddle.float8_e4m3fn:
hint += "*fp8e4nv:16,"
return hint
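# Hedged example: get_pointer_hint([paddle.float16, paddle.int32]) returns "*fp16:16,*i32:16,".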
paddle_custom_op_head_part = """#include <vector>
#include <map>
#include "${op_name}_kernel.h"
#include "paddle/extension.h"
std::map<std::vector<int>, int> map_problem_${op_name};
CUdeviceptr get_tensor_ptr(const paddle::Tensor& input){
if (input.type() == paddle::DataType::FLOAT16) {
return (CUdeviceptr)(input.data<phi::dtype::float16>());
} else if (input.type() == paddle::DataType::BFLOAT16) {
return (CUdeviceptr)(input.data<phi::dtype::bfloat16>());
} else if (input.type() == paddle::DataType::INT32) {
return (CUdeviceptr)(input.data<int>());
} else if (input.type() == paddle::DataType::FLOAT32) {
return (CUdeviceptr)(input.data<float>());
} else if (input.type() == paddle::DataType::UINT8) {
return (CUdeviceptr)(input.data<uint8_t>());
} else if (input.type() == paddle::DataType::INT8) {
return (CUdeviceptr)(input.data<int8_t>());
} else if (input.type() == paddle::DataType::INT64) {
return (CUdeviceptr)(input.data<int64_t>());
} else if (input.type() == paddle::DataType::INT16) {
return (CUdeviceptr)(input.data<int16_t>());
} else if (input.type() == paddle::DataType::FLOAT8_E4M3FN) {
return (CUdeviceptr)(input.data<phi::dtype::float8_e4m3fn>());
} else {
assert(false);
return (CUdeviceptr)(nullptr);
}
}
int triton_cdiv(int x, int y) {
int result = (x + y - 1) / y;
return (int)(result);
}
"""
tune_and_invoke_part = """
std::vector<int> problem_size = {${key}};
auto run_triton_kernel = [&](int algo_id) -> CUresult{
return ${op_name}_kernel(run_stream,
${triton_kernel_args},
algo_id);
};
map_problem_${op_name}[problem_size] = 0;
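// NOTE: because the map is pre-seeded with algo 0 above, the tuning branch below
// never executes and algorithm 0 is always selected for this problem size.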
if (!map_problem_${op_name}.count(problem_size)) {
std::cout << "we are tuning for ${op_name} which key is: {";
for (int i = 0; i < problem_size.size(); i++) {
std::cout << problem_size[i] << ", ";
}
std::cout << "}" << std::endl;
float min_time = 10000.f;
int select_id = -1;
constexpr int WARMUP = 5;
constexpr int REPEAT = 10;
for (int algo_id = 0; algo_id < ${op_name}_kernel_get_num_algos(); ++algo_id) {
cudaEvent_t beg[REPEAT];
cudaEvent_t end[REPEAT];
float elapsed_times[REPEAT];
auto status = CUDA_SUCCESS;
for (int ii = 0; ii < WARMUP + REPEAT; ii++) {
int repeat_id = ii - WARMUP;
if (repeat_id >= 0) {
(cudaEventCreate(beg + repeat_id));
(cudaEventCreate(end + repeat_id));
(cudaEventRecord(beg[repeat_id]));
}
auto flush_l2_cache = paddle::full(
{10 * 1024 * 1024}, 0, paddle::DataType::INT32, ${arbitary_output_name}.place());
// std::cout << &flush_l2_cache << std::endl;
// this is used when the output needs to be reset to zero, e.g. for split-k gemm.
${reset_zero_when_tune};
status = run_triton_kernel(algo_id);
// assert(status == CUDA_SUCCESS);
if (repeat_id >= 0) {
(cudaEventRecord(end[repeat_id]));
(cudaEventSynchronize(end[repeat_id]));
(cudaEventElapsedTime(
elapsed_times + repeat_id, beg[repeat_id], end[repeat_id]));
}
}
float avg_elapsed_time = 0.f;
for (int ii = 0; ii < REPEAT; ++ii) {
avg_elapsed_time += elapsed_times[ii];
}
std::cout << "algo id " << algo_id << " costs " << avg_elapsed_time << " ms" << std::endl;
if (avg_elapsed_time < min_time && status == CUDA_SUCCESS) {
min_time = avg_elapsed_time;
select_id = algo_id;
}
}
map_problem_${op_name}[problem_size] = select_id;
std::cout << "select algo id: " << select_id << std::endl;
${reset_zero_when_tune};
}
if (map_problem_${op_name}.count(problem_size)) {
int algo_id = map_problem_${op_name}[problem_size];
auto status = run_triton_kernel(algo_id);
assert(status == CUDA_SUCCESS);
}
"""
common_template = ("""
std::vector<paddle::Tensor> ${op_name}_func(${input_and_attr}) {
${prepare_attr_for_triton_kernel}
${prepare_ptr_for_triton_kernel}
auto run_stream = ${arbitary_output_name}.stream();
""" + tune_and_invoke_part + """
return {${return_tensor_names}};
}
${d2s_infer_code}
PD_BUILD_OP(${op_name})
.Inputs({${paddle_input_sig}})
.Outputs({${paddle_output_sig}})
.Attrs({${paddle_attr_sig}})
.SetKernelFn(PD_KERNEL(${op_name}_func))
.SetInferDtypeFn(PD_INFER_DTYPE(${op_name}_InferDtype))
.SetInferShapeFn(PD_INFER_SHAPE(${op_name}_InferShape));
""")
def rendering_common_template(
func,
prepare_attr_for_triton_kernel,
prepare_ptr_for_triton_kernel,
return_tensor_names=None,
d2s_infer_code="",
):
"""
Render a template with given function and its arguments.
Args:
func: The function to render.
prepare_attr_for_triton_kernel: The code snippet that prepares attributes for Triton kernel.
prepare_ptr_for_triton_kernel: The code snippet that prepares pointers for Triton kernel.
return_tensor_names: The names of the returned tensors. Default is None.
"""
signature = inspect.signature(func)
arg_names = [v.name for v in signature.parameters.values()]
arg_defaults = [v.default for v in signature.parameters.values()]
input_and_attr = ""
paddle_input_sig = ""
paddle_attr_sig = ""
if return_tensor_names is None:
return_tensor_names = "useless"
prepare_ptr_for_triton_kernel += (
"auto useless = paddle::empty({1}, paddle::DataType::INT32, paddle::CPUPlace());"
)
for i in range(len(arg_names)):
if arg_defaults[i] is None:
input_and_attr += f"paddle::optional<paddle::Tensor> & {arg_names[i]},"
paddle_input_sig += f"""paddle::Optional("{arg_names[i]}"),"""
elif type(arg_defaults[i]) == float:
input_and_attr += f"float {arg_names[i]},"
paddle_attr_sig += f""""{arg_names[i]}: float","""
elif type(arg_defaults[i]) == bool:
input_and_attr += f"bool {arg_names[i]},"
paddle_attr_sig += f""""{arg_names[i]}: bool","""
elif type(arg_defaults[i]) == int:
input_and_attr += f"int64_t {arg_names[i]},"
paddle_attr_sig += f""""{arg_names[i]}: int64_t","""
elif type(arg_defaults[i]) == str:
input_and_attr += f"std::string {arg_names[i]},"
paddle_attr_sig += f""""{arg_names[i]}: std::string","""
elif arg_names[i] == "config":
continue
else:
input_and_attr += f"const paddle::Tensor & {arg_names[i]},"
paddle_input_sig += f""""{arg_names[i]}","""
input_and_attr = input_and_attr[:-1]
paddle_input_sig = paddle_input_sig[:-1]
if len(paddle_attr_sig) > 1:
paddle_attr_sig = paddle_attr_sig[:-1]
paddle_output_sig = ""
arbitary_output_name = ""
for name in return_tensor_names.split(","):
name = name.strip()
arbitary_output_name = name
paddle_output_sig += f""""{name}","""
paddle_output_sig = paddle_output_sig[:-1]
if "${op_name}_InferShape" not in d2s_infer_code:
d2s_infer_shape_part = (
"std::vector<std::vector<int64_t>> ${op_name}_InferShape("
"const std::vector<int64_t>& A_shape) {"
"return {${tmp}};"
"}\n ")
tmp = ",".join(["A_shape"] * len(return_tensor_names.split(",")))
tmp_dict = {"tmp": tmp}
d2s_infer_shape_part = SubstituteTemplate(d2s_infer_shape_part,
tmp_dict)
d2s_infer_code += d2s_infer_shape_part
if "${op_name}_InferDtype" not in d2s_infer_code:
d2s_infer_dtype_part = (
"std::vector<paddle::DataType> ${op_name}_InferDtype("
"const paddle::DataType& A_dtype) {"
"return {${tmp}};"
"}\n ")
tmp = ",".join(["A_dtype"] * len(return_tensor_names.split(",")))
tmp_dict = {"tmp": tmp}
d2s_infer_dtype_part = SubstituteTemplate(d2s_infer_dtype_part,
tmp_dict)
d2s_infer_code += d2s_infer_dtype_part
result_str = SubstituteTemplate(
common_template,
{
"input_and_attr": input_and_attr,
"prepare_attr_for_triton_kernel": prepare_attr_for_triton_kernel,
"prepare_ptr_for_triton_kernel": prepare_ptr_for_triton_kernel,
"return_tensor_names": return_tensor_names,
"arbitary_output_name": arbitary_output_name,
"d2s_infer_code": d2s_infer_code,
"paddle_input_sig": paddle_input_sig,
"paddle_output_sig": paddle_output_sig,
"paddle_attr_sig": paddle_attr_sig,
},
)
return paddle_custom_op_head_part + result_str
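# Hedged usage sketch, mirroring the call in wint2_fused_moe.py below:
#   template_used = rendering_common_template(
#       invoke_fused_moe_kernel,
#       prepare_attr_for_triton_kernel,
#       prepare_ptr_for_triton_kernel,
#   )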
class KernelInterface:
"""
triton kernel interface.
"""
def __init__(
self,
func,
other_config,
key_args=["1"],
):
"""
triton kernel interface.
"""
self.func = func
self.key_args = key_args
signature = inspect.signature(func)
self.arg_names = [v.name for v in signature.parameters.values()]
for ele in self.arg_names:
assert self.arg_names.count(ele) == 1
# arg_defaults = [v.default for v in signature.parameters.values()]
# self.annotations = {
# name: ty for name, ty in func.__annotations__.items()
# }
self.annotations = dict(func.__annotations__)
self.constexprs = [
self.arg_names.index(name) for name in self.arg_names
if self.annotations.get(name) == triton.language.core.constexpr
]
self.arg_exclude_constexpr = [
self.arg_names[i] for i in range(len(self.arg_names))
if i not in self.constexprs
]
import textwrap
py_script = textwrap.dedent(inspect.getsource(func))
import re
pat = r"def\s" + func.__name__
func_begin = re.findall(pat, py_script)
assert len(func_begin) == 1
func_begin = func_begin[0]
py_script = py_script[py_script.find(func_begin):]
def decorator(*args, **kwargs):
"""
decorator for triton kernels.
Args:
*args: positional arguments
**kwargs: keyword arguments
"""
all_input = []
for i in range(len(args)):
all_input.append(args[i])
position_arguments_num = len(all_input)
for i in range(position_arguments_num, len(self.arg_names)):
if self.arg_names[i] in kwargs.keys():
all_input.append(kwargs[self.arg_names[i]])
else:
# this input was not specified, so it must be a tl.constexpr.
assert i in self.constexprs
all_input.append(None)
dtypes = []
x_list = []
const_args = [self.arg_names[i] for i in self.constexprs]
# we do not allow one name in const_args to be a substring of another.
for i in const_args:
for j in const_args:
if i != j and i.find(j) != -1:
raise ValueError(
f"We find {i}, {j} in tl.constexpr args, and {j} is a substring of {i}, "
"please modify your triton kernel arguments names to avoid this."
)
# copy so repeated calls do not mutate self.arg_exclude_constexpr in place
modified_arg_exclude_constexpr = list(self.arg_exclude_constexpr)
const_hint_dict = {}
for i in range(len(all_input)):
ele = all_input[i]
if (type(ele) == paddle.Tensor
or type(ele) == paddle.base.framework.EagerParamBase
or type(ele) == paddle.base.framework.Parameter
or type(ele) == paddle.base.framework.Variable
or type(ele) == paddle.base.libpaddle.pir.Value):
dtypes.append(ele.dtype)
modified_arg_exclude_constexpr[i] = f"input_ptrs[{i}]"
elif i in self.constexprs:
const_hint_dict[self.arg_names[i]] = ele
else:
x_list.append(ele)
op_name = self.op_name
python_package_name = f"{op_name}_package"
tp_rank = paddle.distributed.get_rank()
generated_dir = envs.FD_TRITON_KERNEL_CACHE_DIR
if generated_dir is None:
generated_dir = f"/tmp/triton_cache/rank{tp_rank}"
print("the kernel cache dir is:", generated_dir)
assert (generated_dir is not None), (
"FD_TRITON_KERNEL_CACHE_DIR is None, please set it such as "
"export FD_TRITON_KERNEL_CACHE_DIR=/tmp/triton_cache ")
generated_dir = f"{generated_dir}/{op_name}"
os.makedirs(generated_dir, exist_ok=True)
py_script_file = f"{generated_dir}/triton_kernels.py"
extract_triton_kernel(func, py_script_file)
address_hint = get_pointer_hint(dtypes)
value_hint = get_value_hint(x_list)
const_args = [f"{{{ele}}}" for ele in const_args]
const_args = ",".join(const_args)
launch_grid = list(self.grid)
for i in range(len(launch_grid)):
ele = launch_grid[i]
if type(ele) == str:
for key in const_hint_dict.keys():
if key in ele:
ele = ele.replace(key, f"{{{key}}}")
else:
ele = str(ele)
launch_grid[i] = ele
if len(launch_grid) < 3:
launch_grid += ["1"] * (3 - len(launch_grid))
launch_grid = ",".join(launch_grid)
op_dict = {"op_name": op_name, "reset_zero_when_tune": ""}
op_dict["triton_kernel_args"] = ",".join(
modified_arg_exclude_constexpr)
op_dict["key"] = ",".join(self.key_args)
# when tuning, the output may need to be reset to zero.
if "reset_zero_when_tune" in other_config.keys():
op_dict["reset_zero_when_tune"] = other_config[
"reset_zero_when_tune"]
paddle_custom_op_file_path = f"{generated_dir}/{op_name}.cu"
so_path = find_so_path(generated_dir, python_package_name)
if so_path is None:
print("== we do not find so_path, we need to compile it")
with open(paddle_custom_op_file_path, "w") as f:
f.write(
SubstituteTemplate(
self.custom_op_template,
op_dict,
))
# ahead of time compile command.
aot_template = (
f"""{python_path} {compile_file} {py_script_file} """ +
f""" -n {func.__name__} -o {generated_dir}/{op_name}_kernel """
+ f"""--out-name {op_name}_kernel """ +
""" -w {num_warps} -ns {num_stages} """ +
f""" -s"{address_hint} {value_hint} {const_args}" """ +
f""" -g "{lanuch_grid}" """)
all_tune_config = list(self.tune_config)
if len(all_tune_config) == 0:
# when the user does not specify a config, use const_hint_dict as the config
all_tune_config = [const_hint_dict]
# and reset const_hint_dict to empty.
const_hint_dict = {}
codegen_commands = []
for config in all_tune_config:
for key in const_hint_dict.keys():
if const_hint_dict[key] is not None:
if key not in config.keys():
config[key] = const_hint_dict[key]
else:
if config[key] == const_hint_dict[key]:
pass
else:
message = (
f"you specify {key} both in arguments and config, "
"and they are not same, this is wrong."
)
raise ValueError(message)
else:
assert key in config.keys(
), f"you must specify {key} in your config."
if "num_warps" not in config.keys():
config["num_warps"] = 4
if "num_stages" not in config.keys():
config["num_stages"] = 4
for key in config:
assert config[
key] is not None, f"{key} must be specified."
codegen_command = aot_template.format(**config, )
print(codegen_command)
codegen_commands.append(codegen_command)
multi_process_do(codegen_commands)
link_command = (
f"{python_path} {link_file} "
f"{generated_dir}/*.h -o {generated_dir}/{op_name}_kernel")
ret = os.system(link_command)
assert ret == 0
# rename the .c file to .cu
rename_c_to_cu(generated_dir)
# build the package to so, not install
build_package(generated_dir, python_package_name)
if op_name not in OpProtoHolder.instance().op_proto_map.keys():
so_path = find_so_path(generated_dir, python_package_name)
print("== we find so_path: ", so_path)
assert so_path is not None
paddle.utils.cpp_extension.load_op_meta_info_and_register_op(
so_path)
self.decorator = decorator
def __getitem__(self, op_name_and_grid):
"""
override the operator [], which will call the decorator function.
Args:
op_name_and_grid: the name of the operator and the grid size.
Returns:
the decorator function.
"""
assert len(op_name_and_grid) >= 3, "len(op_name_and_grid) must be >= 3."
self.op_name = op_name_and_grid[0]
self.custom_op_template = op_name_and_grid[1]
self.grid = op_name_and_grid[2]
if len(op_name_and_grid) == 3:
self.tune_config = {}
else:
self.tune_config = op_name_and_grid[3]
return self.decorator
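# Hedged usage sketch (see moe_wint2_ffn_kernel below): a decorated kernel is launched as
#   kernel[(op_name, custom_op_template, grid, configs)](*kernel_args)
# where the subscript tuple supplies the op name, the custom-op template, the launch grid,
# and optionally a list of tuning configs.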
def paddle_use_triton(other_config={}, key=[]):
"""
The decorator function that wraps the original function.
Args:
func: the original function.
Returns:
the wrapped function.
"""
def decorator(func):
"""
The decorator function that wraps the original function.
Args:
func: the original function.
Returns:
the wrapped function.
"""
return KernelInterface(func, other_config, key)
return decorator
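# Hedged usage sketch (kernel name and arguments are illustrative):
#   @paddle_use_triton(key=["1"])
#   def my_triton_kernel(a_ptr, b_ptr, N, BLOCK_SIZE: tl.constexpr):
#       ...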


@@ -0,0 +1,549 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
import paddle
import triton.language as tl
from paddle import _C_ops
from paddle.base.framework import OpProtoHolder
from paddle.framework import in_dynamic_or_pir_mode
from fastdeploy.model_executor.ops.triton_ops.triton_utils import (
get_dtype_str, paddle_use_triton, rendering_common_template)
BLOCK_SIZE_M = 16
def invoke_fused_moe_kernel(
A,
B,
C,
B_scale,
B_super_scale,
B_code_scale,
B_code_zp,
topk_weights,
topk_ids,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
mul_routed_weight=False,
top_k=-1,
group_size=-1,
):
"""
Invoke Fused Moe Kernel
"""
KK = A.shape[-1]
NN = B.shape[-1]
sstride_am, sstride_ak = A.shape[1], 1
sstride_be, sstride_bk, sstride_bn = B.shape[1] * B.shape[2], B.shape[2], 1
sstride_cm, sstride_cn = C.shape[-1], 1
sstride_bse, sstride_bsk, sstride_bsn = B_scale.shape[1] * B_scale.shape[
2], B_scale.shape[2], 1
sstride_bce, sstride_bck, sstride_bcn = B_code_scale.shape[1], 1, 1
ddouble_quant = B_super_scale is not None
prepare_attr_for_triton_kernel = """
auto N = B.shape()[2];
auto K = A.shape()[1];
auto EM = sorted_token_ids.shape()[0];
auto num_valid_tokens = (topk_ids.shape()[0]) * (topk_ids.shape()[1]);
auto stride_am = A.strides()[0];
auto stride_ak = A.strides()[1];
auto stride_be = B.strides()[0];
auto stride_bk = B.strides()[1];
auto stride_bn = B.strides()[2];
auto stride_cm = C.strides()[1];
auto stride_cn = C.strides()[2];
auto stride_bse = B_scale.strides()[0];
auto stride_bsk = B_scale.strides()[1];
auto stride_bsn = 1;
auto stride_bce = B_code_scale.strides()[0];
auto stride_bck = 1;
auto stride_bcn = 1;
auto double_quant = true;
"""
if mul_routed_weight:
config = {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"GROUP_SIZE_M": 2,
"num_warps": 4,
"num_stages": 8,
}
else:
config = {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 512,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 12,
}
configs = []
configs.append(dict(config))
op_name = "wint2_moe_ffn"
op_name += f"{get_dtype_str(A.dtype)}"
op_name += f"{B.shape[0]}"
op_name += f"{B.shape[1]}"
op_name += f"{B.shape[2]}"
if op_name not in OpProtoHolder.instance().op_proto_map.keys():
prepare_ptr_for_triton_kernel = """
CUdeviceptr input_ptrs[11] = {
get_tensor_ptr(A),
get_tensor_ptr(B),
get_tensor_ptr(C),
get_tensor_ptr(B_scale),
get_tensor_ptr(B_super_scale),
get_tensor_ptr(B_code_scale),
get_tensor_ptr(B_code_zp),
get_tensor_ptr(topk_weights),
get_tensor_ptr(sorted_token_ids),
get_tensor_ptr(expert_ids),
get_tensor_ptr(num_tokens_post_padded),
};
"""
template_used = rendering_common_template(
invoke_fused_moe_kernel,
prepare_attr_for_triton_kernel,
prepare_ptr_for_triton_kernel,
)
grid = (
"(EM+BLOCK_SIZE_M-1)/BLOCK_SIZE_M * ((N+BLOCK_SIZE_N-1)/BLOCK_SIZE_N)",
)
moe_wint2_ffn_kernel[(op_name, template_used, grid, configs)](
A,
B,
C,
B_scale,
B_super_scale,
B_code_scale,
B_code_zp,
topk_weights,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
NN,
KK,
-1, #EEM,
-1, #nnum_valid_tokens,
sstride_am,
sstride_ak,
sstride_be,
sstride_bk,
sstride_bn,
sstride_cm,
sstride_cn,
sstride_bse,
sstride_bsk,
sstride_bsn,
sstride_bce,
sstride_bck,
sstride_bcn,
MUL_ROUTED_WEIGHT=(int)(mul_routed_weight),
USE_DOUBLE_QUANT=(int)(ddouble_quant),
top_k=top_k,
BLOCK_SIZE_K=group_size,
)
if in_dynamic_or_pir_mode():
outs = _C_ops._run_custom_op(
op_name,
A,
B,
C,
B_scale,
B_super_scale,
B_code_scale,
B_code_zp,
topk_weights,
topk_ids,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
mul_routed_weight,
top_k,
group_size,
)
return outs[0]
@paddle_use_triton(key=["1"], )
def moe_wint2_ffn_kernel(
# Pointers to matrices
a_ptr,
b_ptr,
c_ptr,
bs_ptr,
superbs_ptr,
codebs_ptr,
codebzp_ptr,
topk_weights_ptr,
sorted_token_ids_ptr,
expert_ids_ptr,
num_tokens_post_padded_ptr,
# Matrix dimensions
N,
K,
EM,
num_valid_tokens,
# The stride variables represent how much to increase the ptr by when
# moving by 1 element in a particular dimension. E.g. `stride_am` is
# how much to increase `a_ptr` by to get the element one row down
# (A has M rows).
stride_am,
stride_ak,
stride_be,
stride_bk,
stride_bn,
stride_cm,
stride_cn,
stride_bse,
stride_bsk,
stride_bsn,
stride_bce,
stride_bck,
stride_bcn,
# Meta-parameters
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr,
GROUP_SIZE_M: tl.constexpr,
MUL_ROUTED_WEIGHT: tl.constexpr,
USE_DOUBLE_QUANT: tl.constexpr,
top_k: tl.constexpr,
):
"""
Implements the fused computation for a Mixture of Experts (MOE) using
token and expert matrices.
Key Parameters:
- A: The input tensor representing tokens with shape (*, K), where '*' can
be any shape representing batches and K is the feature dimension of
each token.
- B: The stacked MOE weight tensor with shape (E, K, N), where E is
the number of experts, K is the input feature dimension (stored in
packed wint2 form), and N is the output feature dimension.
- C: The output cache tensor with shape (M, topk, N), where M is the
total number of tokens post padding, topk is the number of times
each token is repeated, and N is the output feature dimension.
- sorted_token_ids: A tensor containing the sorted indices of tokens,
repeated topk times and arranged by the expert index they are
assigned to.
- expert_ids: A tensor containing the indices of the expert for each
block. It determines which expert matrix from B should be used for
each block in A.
This kernel performs the multiplication of a token by its corresponding
expert matrix as determined by `expert_ids`. The sorting of
`sorted_token_ids` by expert index and padding ensures divisibility by
BLOCK_SIZE_M, which is necessary to maintain consistency in block matrix
multiplication across different blocks processed by the same expert.
"""
if USE_DOUBLE_QUANT:
# INT4 scale
s_packnums: tl.constexpr = 2
bzp: tl.constexpr = 32
w_mask: tl.constexpr = 0x3F
pack_num: tl.constexpr = 4
real_k_size: tl.constexpr = (BLOCK_SIZE_K - 1) // pack_num + 1
pid = tl.program_id(axis=0)
num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M)
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
num_pid_in_group = GROUP_SIZE_M * num_pid_n
group_id = pid // num_pid_in_group
first_pid_m = group_id * GROUP_SIZE_M
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
pid_n = (pid % num_pid_in_group) // group_size_m
compute_type = c_ptr.dtype.element_ty
num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr)
if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded:
return
offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_token = tl.load(sorted_token_ids_ptr + offs_token_id)
token_mask = offs_token < num_valid_tokens
offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
# offs_k = tl.arange(0, BLOCK_SIZE_K)
offs_bk = tl.arange(0, real_k_size)
a_ptrs = a_ptr + (offs_token[:, None] // top_k * stride_am +
offs_bk[None, :] * pack_num * stride_ak)
off_experts = tl.load(expert_ids_ptr + pid_m)
b_ptrs = b_ptr + off_experts * stride_be + (offs_bk[:, None] * stride_bk +
offs_bn[None, :] * stride_bn)
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
bs_ptrs = bs_ptr + off_experts * stride_bse + offs_bn[
None, :] * stride_bsn  # group-wise scale, advanced once per K block
off_set = off_experts * stride_bce + offs_bn[None, :] * stride_bcn
# load channel-wise scale & zero-point
if USE_DOUBLE_QUANT:
superbs_ptrs = superbs_ptr + off_set # channel-wise
super_bs = tl.load(superbs_ptrs) # super scale
codebs_ptrs = codebs_ptr + off_set # channel-wise
code_bs = tl.load(codebs_ptrs) # code scale
codebzp_ptrs = codebzp_ptr + off_set # channel-wise
code_bzp = tl.load(codebzp_ptrs) # code zp
for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
b = tl.load(b_ptrs)
bs = tl.load(bs_ptrs)
if USE_DOUBLE_QUANT:
s_shift_bits = (1 - k % s_packnums) * 4
bs = ((bs >> s_shift_bits) & 0xF) * super_bs
# recover the packed weights as int16 via the code scale and zero point
b = tl.floor((b.to(tl.float32) * code_bs + code_bzp) + 0.5).to(
tl.int16)
# dequant
b1 = (((b >> 9) & w_mask) - bzp) * bs
a = tl.load(
a_ptrs,
mask=token_mask[:, None],
other=0.0,
)
accumulator += tl.dot(a, b1.to(a.dtype))
b1 = (((b >> 6) & w_mask) - bzp) * bs
a = tl.load(
a_ptrs + 1,
mask=token_mask[:, None],
other=0.0,
)
accumulator += tl.dot(a, b1.to(a.dtype))
b1 = (((b >> 3) & w_mask) - bzp) * bs
a = tl.load(
a_ptrs + 2,
mask=token_mask[:, None],
other=0.0,
)
accumulator += tl.dot(a, b1.to(a.dtype))
b = ((b & w_mask) - bzp) * bs
a = tl.load(
a_ptrs + 3,
mask=token_mask[:, None],
other=0.0,
)
accumulator += tl.dot(a, b.to(a.dtype))
b_ptrs += real_k_size * stride_bk
a_ptrs += BLOCK_SIZE_K * stride_ak
# advance scale ptr
if USE_DOUBLE_QUANT:
bs_ptrs += stride_bsk * (k % s_packnums)
else:
bs_ptrs += stride_bsk
if MUL_ROUTED_WEIGHT:
moe_weight = tl.load(topk_weights_ptr + offs_token,
mask=token_mask,
other=0)
accumulator = accumulator * moe_weight[:, None]
accumulator = accumulator.to(compute_type)
# -----------------------------------------------------------
# Write back the block of the output
offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[
None, :]
c_mask = token_mask[:, None] & (offs_cn[None, :] < N)
tl.store(c_ptrs, accumulator, mask=c_mask)
def fused_moe_wint2_impl(
hidden_states,
ffn1_quant_weight,
ffn2_quant_weight,
topk_weights,
topk_ids,
# inplace: bool = False,
ffn1_weight_scale=None,
ffn2_weight_scale=None,
ffn1_super_scales=None,
ffn2_super_scales=None,
ffn1_code_scale=None,
ffn2_code_scale=None,
ffn1_code_zp=None,
ffn2_code_zp=None,
group_size=64,
bit="wint2",
):
"""
Implementation of Fused MoE kernels on GPU.
"""
# Check constraints.
# A: [M, K]
# B: [E, K, N]
# assert hidden_states.shape[1] == ffn1_weight_scale.shape[1],
# f"Hidden size mismatch, {hidden_states.shape[1]} != {ffn1_quant_weight.shape[1]}"
assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
assert ffn1_quant_weight.is_contiguous(
), "Expert weights1 must be contiguous"
assert ffn2_quant_weight.is_contiguous(
), "Expert weights2 must be contiguous"
assert group_size > 0, "Group size must be greater than 0"
num_tokens, K = hidden_states.shape
E, _, N = ffn1_quant_weight.shape
M = num_tokens
if group_size < 0:
group_size = K // ffn1_weight_scale.shape[1]
top_k = topk_ids.shape[1]
intermediate_cache1 = paddle.empty(
[M, top_k, N],
dtype=hidden_states.dtype,
)
intermediate_cache2 = paddle.empty(
(M * top_k, N // 2),
dtype=hidden_states.dtype,
)
intermediate_cache3 = paddle.empty(
(M, top_k, K),
dtype=hidden_states.dtype,
)
from fastdeploy.model_executor.ops.gpu import tritonmoe_preprocess
sorted_token_ids, expert_ids, num_tokens_post_padded = tritonmoe_preprocess(
topk_ids, E, BLOCK_SIZE_M)
invoke_fused_moe_kernel(
A=hidden_states,
B=ffn1_quant_weight,
C=intermediate_cache1,
B_scale=ffn1_weight_scale,
B_super_scale=ffn1_super_scales,
B_code_scale=ffn1_code_scale,
B_code_zp=ffn1_code_zp,
topk_weights=topk_weights,
topk_ids=topk_ids,
sorted_token_ids=sorted_token_ids,
expert_ids=expert_ids,
num_tokens_post_padded=num_tokens_post_padded,
mul_routed_weight=False,
top_k=top_k,
group_size=group_size,
)
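# SwiGLU halves the last dimension: intermediate_cache1 reshaped to [M * top_k, N]
# becomes [M * top_k, N // 2], matching intermediate_cache2 above.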
intermediate_cache2 = paddle.incubate.nn.functional.swiglu(
intermediate_cache1.reshape([-1, N]))
invoke_fused_moe_kernel(
A=intermediate_cache2,
B=ffn2_quant_weight,
C=intermediate_cache3,
B_scale=ffn2_weight_scale,
B_super_scale=ffn2_super_scales,
B_code_scale=ffn2_code_scale,
B_code_zp=ffn2_code_zp,
topk_weights=topk_weights,
topk_ids=topk_ids,
sorted_token_ids=sorted_token_ids,
expert_ids=expert_ids,
num_tokens_post_padded=num_tokens_post_padded,
mul_routed_weight=True,
top_k=1,
group_size=group_size,
)
out_hidden_states = paddle.sum(intermediate_cache3, axis=1)
return out_hidden_states
def fused_moe_wint2_triton(
hidden_states,
ffn1_quant_weight,
ffn2_quant_weight,
scores,
gate_correction_bias,
topk,
ffn1_weight_scale,
ffn2_weight_scale,
ffn1_super_scales,
ffn2_super_scales,
ffn1_code_scale,
ffn2_code_scale,
ffn1_code_zp,
ffn2_code_zp,
):
"""
Fuse MoE with WINT2 quantization scheme and Triton backend.
Args:
hidden_states: input tensor.
ffn1_quant_weight: ffn1 weight matrix for experts.
ffn2_quant_weight: ffn2 weight matrix for experts.
scores: gate scores.
gate_correction_bias: bias correction for gates.
topk: number of experts to use.
ffn1_weight_scale: scaling factor for ffn1_quant_weight.
ffn2_weight_scale: scaling factor for ffn2_quant_weight.
ffn1_super_scales: super scaling factor for ffn1_scale.
ffn2_super_scales: super scaling factor for ffn2_weight_scale.
ffn1_code_scale: code scaling factor for ffn1_quant_weight.
ffn2_code_scale: code scaling factor for ffn2_quant_weight.
ffn1_code_zp: code zero point for ffn1_quant_weight.
ffn2_code_zp: code zero point for ffn2_quant_weight.
Returns:
output tensor.
"""
score = gate_correction_bias + scores
_, topk_ids = paddle.topk(score, k=topk, axis=-1)
topk_weights, _ = paddle.topk(scores, k=topk, axis=-1)
topk_weights = topk_weights / topk_weights.sum(axis=-1, keepdim=True)
return fused_moe_wint2_impl(
hidden_states,
ffn1_quant_weight,
ffn2_quant_weight,
topk_weights,
topk_ids,
ffn1_weight_scale,
ffn2_weight_scale,
ffn1_super_scales,
ffn2_super_scales,
ffn1_code_scale,
ffn2_code_scale,
ffn1_code_zp,
ffn2_code_zp,
bit="wint2",
)
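# Hedged usage sketch (tensor shapes are assumptions for illustration only):
#   out = fused_moe_wint2_triton(
#       hidden_states,        # [num_tokens, hidden_size]
#       ffn1_quant_weight,    # [num_experts, packed_K, N]
#       ffn2_quant_weight,
#       scores,               # [num_tokens, num_experts]
#       gate_correction_bias,
#       topk,
#       ffn1_weight_scale, ffn2_weight_scale,
#       ffn1_super_scales, ffn2_super_scales,
#       ffn1_code_scale, ffn2_code_scale,
#       ffn1_code_zp, ffn2_code_zp,
#   )
#   # out has the same shape as hidden_states.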