https://github.com/PaddlePaddle/FastDeploy.git

Sync v2.0 version of code to github repo

fastdeploy/model_executor/ops/triton_ops/__init__.py (new file, 22 lines)
@@ -0,0 +1,22 @@
"""
|
||||
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
|
||||
try:
|
||||
from .wint2_fused_moe import fused_moe_wint2_triton
|
||||
|
||||
__all__ = ["fused_moe_wint2_triton"]
|
||||
except:
|
||||
pass
|
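
# Downstream code can then import the Triton MoE op as
#
#     from fastdeploy.model_executor.ops.triton_ops import fused_moe_wint2_triton
#
# The import is wrapped in try/except so that importing the package does not
# fail in environments where Triton is unavailable.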

fastdeploy/model_executor/ops/triton_ops/triton_utils.py (new file, 804 lines)
@@ -0,0 +1,804 @@
"""
|
||||
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
|
||||
import inspect
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
|
||||
import paddle
|
||||
import triton
|
||||
from paddle.base.framework import OpProtoHolder
|
||||
|
||||
from fastdeploy import envs
|
||||
|
||||
compile_file = triton.__path__[0] + "/tools/compile.py"
|
||||
link_file = triton.__path__[0] + "/tools/link.py"
|
||||
python_path = sys.executable
|
||||
|
||||
|
||||
def SubstituteTemplate(template, values):
|
||||
"""
|
||||
Substitute all variables in the given template string using the provided values dictionary.
|
||||
"""
|
||||
text = template
|
||||
changed = True
|
||||
while changed:
|
||||
changed = False
|
||||
for key, value in values.items():
|
||||
regex = "\\$\\{%s\\}" % key
|
||||
newtext = re.sub(regex, value, text)
|
||||
if newtext != text:
|
||||
changed = True
|
||||
text = newtext
|
||||
return text
|
||||
|
||||
|
||||
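
# Example: SubstituteTemplate repeatedly replaces every ${name} placeholder
# until the text stops changing, e.g.
#
#     SubstituteTemplate("PD_BUILD_OP(${op_name})", {"op_name": "my_op"})
#     # -> 'PD_BUILD_OP(my_op)'
#
# All of the C++ templates below are rendered through this helper.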


def find_so_path(generated_dir, python_package_name):
    """
    Find the specified .so under generated_dir; return None if it is not found.
    """

    so_path = []
    for root, dirs, files in os.walk(generated_dir):
        for file in files:
            if file.endswith(python_package_name + ".so"):
                so_path.append(os.path.join(root, file))
    if len(so_path) == 0:
        return None
    else:
        assert len(so_path) == 1
        return so_path[0]


def multi_process_do(commands):
    """
    Run the given shell commands in parallel worker processes.
    """
    THREADS = 40
    import multiprocessing

    process = []

    def one_process_work(commands, thread_id):
        i = thread_id
        while i < len(commands):
            ret = os.system(commands[i])
            assert ret == 0
            i += THREADS

    for i in range(THREADS):
        p = multiprocessing.Process(target=one_process_work,
                                    args=(commands, i))
        process.append(p)
    for p in process:
        p.start()
    for p in process:
        p.join()


def extract_triton_kernel(kernel, file_name):
    """
    Extract the triton kernel and write it to the specified file_name.

    Args:
        kernel: the triton kernel (a JITFunction or Autotuner).
        file_name: the file to write the extracted kernel source to.
    """

    import inspect
    import re
    import textwrap

    fn = kernel
    if type(kernel) == triton.runtime.jit.JITFunction:
        fn = kernel.fn
    elif type(kernel) == triton.runtime.autotuner.Autotuner:
        fn = kernel.fn.fn
    else:
        raise AssertionError("kernel must be a triton JITFunction or Autotuner")
    py_script = textwrap.dedent(inspect.getsource(fn))

    # @triton.jit must only appear once
    # assert len(re.findall("@triton.jit", py_script)) == 1
    assert len(re.findall("def ", py_script)) == 1
    # assert len(re.findall("@haha()", py_script)) == 1
    # py_script = py_script.replace("@haha()", "@triton.jit")

    py_script = py_script[py_script.find("def "):]
    py_script = "import triton\nimport triton.language as tl\n\n\n@triton.jit\n" + py_script

    py_script = py_script.replace("if bias_ptr is not None", "if bias_ptr")

    with open(file_name, "w") as f:
        f.write(py_script)


template_install = """

import os
generated_cu = []
for root, dirs, files in os.walk("./"):
    for file in files:
        if file.endswith(".c") or file.endswith(".cu"):
            generated_cu.append(os.path.join(root, file))


import paddle
from paddle.utils.cpp_extension import CUDAExtension, setup


def get_gencode_flags():
    prop = paddle.device.cuda.get_device_properties()
    cc = prop.major * 10 + prop.minor
    return ["-gencode", "arch=compute_{{0}},code=sm_{{0}}".format(cc)]


gencode_flags = get_gencode_flags()


setup(
    name="{python_package_name}",
    ext_modules=CUDAExtension(
        sources=generated_cu,
        extra_compile_args={{
            "cc": ["-lcuda"],
            "nvcc": [
                "-O3",
                "-U__CUDA_NO_HALF_OPERATORS__",
                "-U__CUDA_NO_HALF_CONVERSIONS__",
                "-U__CUDA_NO_BFLOAT16_OPERATORS__",
                "-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
                "-U__CUDA_NO_BFLOAT162_OPERATORS__",
                "-U__CUDA_NO_BFLOAT162_CONVERSIONS__",
            ]
            + gencode_flags,
        }},
        extra_link_args=["-lcuda"],
    ),
)
"""


def get_op_name_with_suffix(op_name, x_list):
    """
    Get the operator name with suffix.
    """
    suffix = []
    for x in x_list:
        if x % 16 == 0:
            suffix.append(16)
        elif x == 1:
            suffix.append(1)
        else:
            suffix.append(0)
    return op_name + "_".join([str(i) for i in suffix])


def get_value_hint(x):
    """
    Get the value hint from input list.
    """
    hint = ""
    for ele in x:
        if type(ele) == int:
            if ele % 16 == 0 and ele > 0:
                hint += "i64:16,"
            elif ele == 1:
                hint += "i64:1,"
            else:
                hint += "i64,"
        if type(ele) == float:
            hint += "fp32,"
    return hint
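
# Example: the scalar arguments of a kernel are folded into both the op name
# and the Triton AOT value hints.  For x_list = [4096, 1, 7]:
#
#     get_op_name_with_suffix("my_op", [4096, 1, 7])  # -> "my_op16_1_0"
#     get_value_hint([4096, 1, 7])                    # -> "i64:16,i64:1,i64,"
#
# i.e. multiples of 16 and the value 1 are recorded so the compiler can
# specialize on them; everything else is treated as a generic i64.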


def get_dtype_str(dtype):
    """
    Get the dtype str.
    """
    if dtype == paddle.float16:
        return "_fp16"
    elif dtype == paddle.float8_e4m3fn:
        return "_float8_e4m3fn"
    elif dtype == paddle.uint8:
        return "_u8"
    elif dtype == paddle.int8:
        return "_i8"
    elif dtype == paddle.int16:
        return "_i16"
    elif dtype == paddle.int32:
        return "_i32"
    elif dtype == paddle.int64:
        return "_i64"
    elif dtype == paddle.float32:
        return "_fp32"
    elif dtype == paddle.bfloat16:
        return "_bf16"
    else:
        raise ValueError("Unsupported dtype.")


def build_package(generated_dir, python_package_name):
    """
    Build the package, but do not install it.

    Args:
        generated_dir: the source cu file dir.
        python_package_name: the python package name.
    """
    setup_file_path = generated_dir + "/setup_cuda.py"
    python_path = sys.executable
    with open(setup_file_path, "w") as f:
        f.write(
            template_install.format(python_package_name=python_package_name))
    install_command = f"cd {generated_dir} && {python_path} setup_cuda.py build"
    ret = os.system(install_command)
    assert ret == 0


def rename_c_to_cu(generated_dir):
    """
    Rename the .c files in generated_dir to .cu files, because the triton aot tool generates .c files.
    """
    for filename in os.listdir(generated_dir):
        if filename.endswith(".c"):
            old_path = os.path.join(generated_dir, filename)
            new_path = os.path.join(generated_dir, filename + "u")
            os.rename(old_path, new_path)


def get_pointer_hint(dtypes):
    """
    Get the pointer hint from input list.
    """
    hint = ""
    for ele in dtypes:
        if ele == paddle.float16:
            hint += "*fp16:16,"
        elif ele == paddle.uint8:
            hint += "*u8:16,"
        elif ele == paddle.int8:
            hint += "*i8:16,"
        elif ele == paddle.int16:
            hint += "*i16:16,"
        elif ele == paddle.float32:
            hint += "*fp32:16,"
        elif ele == paddle.bfloat16:
            hint += "*bf16:16,"
        elif ele == paddle.int32:
            hint += "*i32:16,"
        elif ele == paddle.int64:
            hint += "*i64,"
        elif ele == paddle.float8_e4m3fn:
            hint += "*fp8e4nv:16,"
    return hint
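
# Example: tensor dtypes map to Triton AOT pointer hints, e.g.
#
#     get_pointer_hint([paddle.bfloat16, paddle.uint8, paddle.int32])
#     # -> "*bf16:16,*u8:16,*i32:16,"
#
# The pointer hints, value hints and constexpr values are later joined into the
# "-s" signature string passed to triton's tools/compile.py.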


paddle_custom_op_head_part = """#include <vector>
#include <map>
#include "${op_name}_kernel.h"
#include "paddle/extension.h"

std::map<std::vector<int>, int> map_problem_${op_name};

CUdeviceptr get_tensor_ptr(const paddle::Tensor& input){
  if (input.type() == paddle::DataType::FLOAT16) {
    return (CUdeviceptr)(input.data<phi::dtype::float16>());
  } else if (input.type() == paddle::DataType::BFLOAT16) {
    return (CUdeviceptr)(input.data<phi::dtype::bfloat16>());
  } else if (input.type() == paddle::DataType::INT32) {
    return (CUdeviceptr)(input.data<int32_t>());
  } else if (input.type() == paddle::DataType::FLOAT32) {
    return (CUdeviceptr)(input.data<float>());
  } else if (input.type() == paddle::DataType::UINT8) {
    return (CUdeviceptr)(input.data<uint8_t>());
  } else if (input.type() == paddle::DataType::INT8) {
    return (CUdeviceptr)(input.data<int8_t>());
  } else if (input.type() == paddle::DataType::INT64) {
    return (CUdeviceptr)(input.data<int64_t>());
  } else if (input.type() == paddle::DataType::INT16) {
    return (CUdeviceptr)(input.data<int16_t>());
  } else if (input.type() == paddle::DataType::FLOAT8_E4M3FN) {
    return (CUdeviceptr)(input.data<phi::dtype::float8_e4m3fn>());
  } else {
    assert(false);
    return (CUdeviceptr)(nullptr);
  }
}

int triton_cdiv(int x, int y) {
  int result = (x + y - 1) / y;
  return (int)(result);
}
"""

tune_and_invoke_part = """
  std::vector<int> problem_size = {${key}};
  auto run_triton_kernel = [&](int algo_id) -> CUresult{
    return ${op_name}_kernel(run_stream,
                             ${triton_kernel_args},
                             algo_id);
  };

  map_problem_${op_name}[problem_size] = 0;

  if (!map_problem_${op_name}.count(problem_size)) {
    std::cout << "we are tuning for ${op_name} which key is: {";
    for (int i = 0; i < problem_size.size(); i++) {
      std::cout << problem_size[i] << ", ";
    }
    std::cout << "}" << std::endl;

    float min_time = 10000.f;
    int select_id = -1;
    constexpr int WARMUP = 5;
    constexpr int REPEAT = 10;

    for (int algo_id = 0; algo_id < ${op_name}_kernel_get_num_algos(); ++algo_id) {
      cudaEvent_t beg[REPEAT];
      cudaEvent_t end[REPEAT];
      float elapsed_times[REPEAT];

      auto status = CUDA_SUCCESS;

      for (int ii = 0; ii < WARMUP + REPEAT; ii++) {
        int repeat_id = ii - WARMUP;

        if (repeat_id >= 0) {
          (cudaEventCreate(beg + repeat_id));
          (cudaEventCreate(end + repeat_id));
          (cudaEventRecord(beg[repeat_id]));
        }

        auto flush_l2_cache = paddle::full(
            {10 * 1024 * 1024}, 0, paddle::DataType::INT32, ${arbitary_output_name}.place());
        // std::cout << &flush_l2_cache << std::endl;
        // this is used when the output needs to be reset to zero, such as split-k gemm.
        ${reset_zero_when_tune};

        status = run_triton_kernel(algo_id);
        // assert(status == CUDA_SUCCESS);

        if (repeat_id >= 0) {
          (cudaEventRecord(end[repeat_id]));
          (cudaEventSynchronize(end[repeat_id]));
          (cudaEventElapsedTime(
              elapsed_times + repeat_id, beg[repeat_id], end[repeat_id]));
        }
      }

      float avg_elapsed_time = 0.f;
      for (int ii = 0; ii < REPEAT; ++ii) {
        avg_elapsed_time += elapsed_times[ii];
      }

      std::cout << "algo id " << algo_id << " costs " << avg_elapsed_time << " ms" << std::endl;

      if (avg_elapsed_time < min_time && status == CUDA_SUCCESS) {
        min_time = avg_elapsed_time;
        select_id = algo_id;
      }
    }

    map_problem_${op_name}[problem_size] = select_id;
    std::cout << "select algo id: " << select_id << std::endl;
    ${reset_zero_when_tune};
  }

  if (map_problem_${op_name}.count(problem_size)) {
    int algo_id = map_problem_${op_name}[problem_size];
    auto status = run_triton_kernel(algo_id);
    assert(status == CUDA_SUCCESS);
  }
"""

common_template = ("""
std::vector<paddle::Tensor> ${op_name}_func(${input_and_attr}) {
  ${prepare_attr_for_triton_kernel}
  ${prepare_ptr_for_triton_kernel}
  auto run_stream = ${arbitary_output_name}.stream();
""" + tune_and_invoke_part + """
  return {${return_tensor_names}};
}

${d2s_infer_code}

PD_BUILD_OP(${op_name})
    .Inputs({${paddle_input_sig}})
    .Outputs({${paddle_output_sig}})
    .Attrs({${paddle_attr_sig}})
    .SetKernelFn(PD_KERNEL(${op_name}_func))
    .SetInferDtypeFn(PD_INFER_DTYPE(${op_name}_InferDtype))
    .SetInferShapeFn(PD_INFER_SHAPE(${op_name}_InferShape));
""")


def rendering_common_template(
    func,
    prepare_attr_for_triton_kernel,
    prepare_ptr_for_triton_kernel,
    return_tensor_names=None,
    d2s_infer_code="",
):
    """
    Render a template with the given function and its arguments.

    Args:
        func: The function to render.
        prepare_attr_for_triton_kernel: The code snippet that prepares attributes for the Triton kernel.
        prepare_ptr_for_triton_kernel: The code snippet that prepares pointers for the Triton kernel.
        return_tensor_names: The names of the returned tensors. Default is None.
        d2s_infer_code: Optional InferShape/InferDtype code; defaults are generated when absent.
    """
    signature = inspect.signature(func)
    arg_names = [v.name for v in signature.parameters.values()]
    arg_defaults = [v.default for v in signature.parameters.values()]
    input_and_attr = ""
    paddle_input_sig = ""
    paddle_attr_sig = ""

    if return_tensor_names is None:
        return_tensor_names = "useless"
        prepare_ptr_for_triton_kernel += (
            "auto useless = paddle::empty({1}, paddle::DataType::INT32, paddle::CPUPlace());"
        )

    for i in range(len(arg_names)):
        if arg_defaults[i] is None:
            input_and_attr += f"paddle::optional<paddle::Tensor> & {arg_names[i]},"
            paddle_input_sig += f"""paddle::Optional("{arg_names[i]}"),"""
        elif type(arg_defaults[i]) == float:
            input_and_attr += f"float {arg_names[i]},"
            paddle_attr_sig += f""""{arg_names[i]}: float","""
        elif type(arg_defaults[i]) == bool:
            input_and_attr += f"bool {arg_names[i]},"
            paddle_attr_sig += f""""{arg_names[i]}: bool","""
        elif type(arg_defaults[i]) == int:
            input_and_attr += f"int64_t {arg_names[i]},"
            paddle_attr_sig += f""""{arg_names[i]}: int64_t","""
        elif type(arg_defaults[i]) == str:
            input_and_attr += f"std::string {arg_names[i]},"
            paddle_attr_sig += f""""{arg_names[i]}: std::string","""
        elif arg_names[i] == "config":
            continue
        else:
            input_and_attr += f"const paddle::Tensor & {arg_names[i]},"
            paddle_input_sig += f""""{arg_names[i]}","""
    input_and_attr = input_and_attr[:-1]
    paddle_input_sig = paddle_input_sig[:-1]
    if len(paddle_attr_sig) > 1:
        paddle_attr_sig = paddle_attr_sig[:-1]

    paddle_output_sig = ""
    arbitary_output_name = ""
    for name in return_tensor_names.split(","):
        name = name.strip()
        arbitary_output_name = name
        paddle_output_sig += f""""{name}","""
    paddle_output_sig = paddle_output_sig[:-1]

    if "${op_name}_InferShape" not in d2s_infer_code:
        d2s_infer_shape_part = (
            "std::vector<std::vector<int64_t>> ${op_name}_InferShape("
            "const std::vector<int64_t>& A_shape) {"
            "return {${tmp}};"
            "}\n ")
        tmp = ",".join(["A_shape"] * len(return_tensor_names.split(",")))
        tmp_dict = {"tmp": tmp}
        d2s_infer_shape_part = SubstituteTemplate(d2s_infer_shape_part,
                                                  tmp_dict)

        d2s_infer_code += d2s_infer_shape_part

    if "${op_name}_InferDtype" not in d2s_infer_code:
        d2s_infer_dtype_part = (
            "std::vector<paddle::DataType> ${op_name}_InferDtype("
            "const paddle::DataType& A_dtype) {"
            "return {${tmp}};"
            "}\n ")
        tmp = ",".join(["A_dtype"] * len(return_tensor_names.split(",")))
        tmp_dict = {"tmp": tmp}
        d2s_infer_dtype_part = SubstituteTemplate(d2s_infer_dtype_part,
                                                  tmp_dict)

        d2s_infer_code += d2s_infer_dtype_part

    result_str = SubstituteTemplate(
        common_template,
        {
            "input_and_attr": input_and_attr,
            "prepare_attr_for_triton_kernel": prepare_attr_for_triton_kernel,
            "prepare_ptr_for_triton_kernel": prepare_ptr_for_triton_kernel,
            "return_tensor_names": return_tensor_names,
            "arbitary_output_name": arbitary_output_name,
            "d2s_infer_code": d2s_infer_code,
            "paddle_input_sig": paddle_input_sig,
            "paddle_output_sig": paddle_output_sig,
            "paddle_attr_sig": paddle_attr_sig,
        },
    )

    return paddle_custom_op_head_part + result_str
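
# Example: for a wrapper such as
#
#     def my_op(x, y, alpha=1.0):
#         ...
#
# the rendered template contains an entry point of the form
#
#     std::vector<paddle::Tensor> ${op_name}_func(const paddle::Tensor & x,
#                                                 const paddle::Tensor & y,
#                                                 float alpha)
#
# with .Inputs({"x", "y"}) and .Attrs({"alpha: float"}); parameters without a
# default become tensor inputs, scalar defaults become op attributes, and the
# ${op_name} placeholder is filled in later by KernelInterface.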


class KernelInterface:
    """
    triton kernel interface.
    """

    def __init__(
        self,
        func,
        other_config,
        key_args=["1"],
    ):
        """
        triton kernel interface.
        """
        self.func = func
        self.key_args = key_args

        signature = inspect.signature(func)
        self.arg_names = [v.name for v in signature.parameters.values()]
        for ele in self.arg_names:
            assert self.arg_names.count(ele) == 1
        # arg_defaults = [v.default for v in signature.parameters.values()]

        # self.annotations = {
        #     name: ty for name, ty in func.__annotations__.items()
        # }
        self.annotations = dict(func.__annotations__)

        self.constexprs = [
            self.arg_names.index(name) for name in self.arg_names
            if self.annotations.get(name) == triton.language.core.constexpr
        ]

        self.arg_exclude_constexpr = [
            self.arg_names[i] for i in range(len(self.arg_names))
            if i not in self.constexprs
        ]

        import textwrap

        py_script = textwrap.dedent(inspect.getsource(func))

        import re

        pat = r"def\s" + func.__name__
        func_begin = re.findall(pat, py_script)
        assert len(func_begin) == 1
        func_begin = func_begin[0]
        py_script = py_script[py_script.find(func_begin):]

        def decorator(*args, **kwargs):
            """
            decorator for triton kernels.

            Args:
                *args: positional arguments
                **kwargs: keyword arguments
            """
            all_input = []

            for i in range(len(args)):
                all_input.append(args[i])

            position_arguments_num = len(all_input)
            for i in range(position_arguments_num, len(self.arg_names)):
                if self.arg_names[i] in kwargs.keys():
                    all_input.append(kwargs[self.arg_names[i]])
                else:
                    # this input is not specified, so it must be a tl.constexpr.
                    assert i in self.constexprs
                    all_input.append(None)

            dtypes = []
            x_list = []
            const_args = [self.arg_names[i] for i in self.constexprs]
            # we do not allow one tl.constexpr argument name to be a substring of another.
            for i in const_args:
                for j in const_args:
                    if i != j and i.find(j) != -1:
                        raise ValueError(
                            f"We find {i}, {j} in tl.constexpr args, and {j} is a substring of {i}, "
                            "please modify your triton kernel arguments names to avoid this."
                        )

            modified_arg_exclude_constexpr = self.arg_exclude_constexpr
            const_hint_dict = {}
            for i in range(len(all_input)):
                ele = all_input[i]
                if (type(ele) == paddle.Tensor
                        or type(ele) == paddle.base.framework.EagerParamBase
                        or type(ele) == paddle.base.framework.Parameter
                        or type(ele) == paddle.base.framework.Variable
                        or type(ele) == paddle.base.libpaddle.pir.Value):
                    dtypes.append(ele.dtype)
                    modified_arg_exclude_constexpr[i] = f"input_ptrs[{i}]"
                elif i in self.constexprs:
                    const_hint_dict[self.arg_names[i]] = ele
                else:
                    x_list.append(ele)

            op_name = self.op_name

            python_package_name = f"{op_name}_package"
            tp_rank = paddle.distributed.get_rank()
            generated_dir = envs.FD_TRITON_KERNEL_CACHE_DIR
            if generated_dir is None:
                generated_dir = f"/tmp/triton_cache/rank{tp_rank}"
                print("the kernel cache dir is:", generated_dir)
            assert (generated_dir is not None), (
                "TRITON_KERNEL_CACHE_DIR is None, please set it such as "
                "export TRITON_KERNEL_CACHE_DIR=/tmp/triton_cache ")
            generated_dir = f"{generated_dir}/{op_name}"
            os.makedirs(generated_dir, exist_ok=True)

            py_script_file = f"{generated_dir}/triton_kernels.py"
            extract_triton_kernel(func, py_script_file)

            address_hint = get_pointer_hint(dtypes)
            value_hint = get_value_hint(x_list)
            const_args = [f"{{{ele}}}" for ele in const_args]
            const_args = ",".join(const_args)

            lanuch_grid = list(self.grid)
            for i in range(len(lanuch_grid)):
                ele = lanuch_grid[i]
                if type(ele) == str:
                    for key in const_hint_dict.keys():
                        if key in ele:
                            ele = ele.replace(key, f"{{{key}}}")
                else:
                    ele = str(ele)

                lanuch_grid[i] = ele
            if len(lanuch_grid) < 3:
                lanuch_grid += ["1"] * (3 - len(lanuch_grid))
            lanuch_grid = ",".join(lanuch_grid)

            op_dict = {"op_name": op_name, "reset_zero_when_tune": ""}
            op_dict["triton_kernel_args"] = ",".join(
                modified_arg_exclude_constexpr)
            op_dict["key"] = ",".join(self.key_args)
            # when tuning, the output may need to be reset to zero.
            if "reset_zero_when_tune" in other_config.keys():
                op_dict["reset_zero_when_tune"] = other_config[
                    "reset_zero_when_tune"]

            paddle_custom_op_file_path = f"{generated_dir}/{op_name}.cu"
            so_path = find_so_path(generated_dir, python_package_name)

            if so_path is None:
                print("== we do not find so_path, we need to compile it")
                with open(paddle_custom_op_file_path, "w") as f:
                    f.write(
                        SubstituteTemplate(
                            self.custom_op_template,
                            op_dict,
                        ))

                # ahead of time compile command.
                aot_template = (
                    f"""{python_path} {compile_file} {py_script_file} """ +
                    f""" -n {func.__name__} -o {generated_dir}/{op_name}_kernel """
                    + f"""--out-name {op_name}_kernel """ +
                    """ -w {num_warps} -ns {num_stages} """ +
                    f""" -s"{address_hint} {value_hint} {const_args}" """ +
                    f""" -g "{lanuch_grid}" """)
                all_tune_config = list(self.tune_config)
                if len(all_tune_config) == 0:
                    # when the user does not specify a config, use const_hint_dict as the config.
                    all_tune_config = [const_hint_dict]
                    # reset const_hint_dict as empty.
                    const_hint_dict = {}
                codegen_commands = []
                for config in all_tune_config:
                    for key in const_hint_dict.keys():
                        if const_hint_dict[key] is not None:
                            if key not in config.keys():
                                config[key] = const_hint_dict[key]
                            elif config[key] != const_hint_dict[key]:
                                message = (
                                    f"you specify {key} both in arguments and config, "
                                    "and they are not the same, this is wrong.")
                                raise ValueError(message)
                        else:
                            assert key in config.keys(
                            ), f"you must specify {key} in your config."
                    if "num_warps" not in config.keys():
                        config["num_warps"] = 4
                    if "num_stages" not in config.keys():
                        config["num_stages"] = 4

                    for key in config:
                        assert config[
                            key] is not None, f"{key} must be specified."
                    codegen_command = aot_template.format(**config, )
                    print(codegen_command)
                    codegen_commands.append(codegen_command)
                multi_process_do(codegen_commands)

                link_command = (
                    f"{python_path} {link_file} "
                    f"{generated_dir}/*.h -o {generated_dir}/{op_name}_kernel")
                ret = os.system(link_command)
                assert ret == 0

                # rename the .c files to .cu
                rename_c_to_cu(generated_dir)
                # build the package into a .so, but do not install it
                build_package(generated_dir, python_package_name)

            if op_name not in OpProtoHolder.instance().op_proto_map.keys():
                so_path = find_so_path(generated_dir, python_package_name)
                print("== we find so_path: ", so_path)
                assert so_path is not None
                paddle.utils.cpp_extension.load_op_meta_info_and_register_op(
                    so_path)

        self.decorator = decorator

    def __getitem__(self, op_name_and_grid):
        """
        Override the operator [], which returns the decorator function.

        Args:
            op_name_and_grid: a tuple of (op_name, custom_op_template, grid),
                with an optional list of tuning configs as the fourth element.
        Returns:
            the decorator function.
        """
        assert len(op_name_and_grid) >= 3, "len(op_name_and_grid) must >= 3."
        self.op_name = op_name_and_grid[0]
        self.custom_op_template = op_name_and_grid[1]
        self.grid = op_name_and_grid[2]
        if len(op_name_and_grid) == 3:
            self.tune_config = {}
        else:
            self.tune_config = op_name_and_grid[3]

        return self.decorator


def paddle_use_triton(other_config={}, key=[]):
    """
    The decorator factory that wraps a triton kernel into a KernelInterface.

    Args:
        other_config: extra codegen options, e.g. "reset_zero_when_tune".
        key: the expressions used as the tuning key (filled into ${key}).
    Returns:
        the decorator.
    """

    def decorator(func):
        """
        The decorator that wraps the original function.

        Args:
            func: the original function.
        Returns:
            the wrapped function.
        """
        return KernelInterface(func, other_config, key)

    return decorator
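
# Typical usage, as a minimal sketch (names here are placeholders; see
# wint2_fused_moe.py below for a complete, working example):
#
#     @paddle_use_triton(key=["1"])
#     def my_kernel(x_ptr, y_ptr, N, BLOCK: tl.constexpr):
#         ...
#
#     template = rendering_common_template(wrapper_fn, prepare_attr, prepare_ptr)
#     grid = ("(N + BLOCK - 1) / BLOCK",)
#     my_kernel[(op_name, template, grid, configs)](x, y, N, BLOCK=256)
#
# The first such call extracts the kernel source, AOT-compiles one variant per
# config, links them into a Paddle custom op and registers it; the surrounding
# wrapper then dispatches to the registered op via _C_ops._run_custom_op(op_name, ...).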

fastdeploy/model_executor/ops/triton_ops/wint2_fused_moe.py (new file, 549 lines)
@@ -0,0 +1,549 @@
"""
|
||||
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
|
||||
import paddle
|
||||
import triton.language as tl
|
||||
from paddle import _C_ops
|
||||
from paddle.base.framework import OpProtoHolder
|
||||
from paddle.framework import in_dynamic_or_pir_mode
|
||||
|
||||
from fastdeploy.model_executor.ops.triton_ops.triton_utils import (
|
||||
get_dtype_str, paddle_use_triton, rendering_common_template)
|
||||
|
||||
BLOCK_SIZE_M = 16
|
||||
|
||||
|
||||


def invoke_fused_moe_kernel(
    A,
    B,
    C,
    B_scale,
    B_super_scale,
    B_code_scale,
    B_code_zp,
    topk_weights,
    topk_ids,
    sorted_token_ids,
    expert_ids,
    num_tokens_post_padded,
    mul_routed_weight=False,
    top_k=-1,
    group_size=-1,
):
    """
    Invoke Fused Moe Kernel
    """
    KK = A.shape[-1]
    NN = B.shape[-1]
    sstride_am, sstride_ak = A.shape[1], 1
    sstride_be, sstride_bk, sstride_bn = B.shape[1] * B.shape[2], B.shape[2], 1
    sstride_cm, sstride_cn = C.shape[-1], 1
    sstride_bse, sstride_bsk, sstride_bsn = B_scale.shape[1] * B_scale.shape[
        2], B_scale.shape[2], 1
    sstride_bce, sstride_bck, sstride_bcn = B_code_scale.shape[1], 1, 1

    ddouble_quant = B_super_scale is not None

    prepare_attr_for_triton_kernel = """
    auto N = B.shape()[2];
    auto K = A.shape()[1];
    auto EM = sorted_token_ids.shape()[0];
    auto num_valid_tokens = (topk_ids.shape()[0]) * (topk_ids.shape()[1]);
    auto stride_am = A.strides()[0];
    auto stride_ak = A.strides()[1];
    auto stride_be = B.strides()[0];
    auto stride_bk = B.strides()[1];
    auto stride_bn = B.strides()[2];
    auto stride_cm = C.strides()[1];
    auto stride_cn = C.strides()[2];
    auto stride_bse = B_scale.strides()[0];
    auto stride_bsk = B_scale.strides()[1];
    auto stride_bsn = 1;
    auto stride_bce = B_code_scale.strides()[0];
    auto stride_bck = 1;
    auto stride_bcn = 1;
    auto double_quant = true;
    """
    if mul_routed_weight:
        config = {
            "BLOCK_SIZE_M": 16,
            "BLOCK_SIZE_N": 128,
            "GROUP_SIZE_M": 2,
            "num_warps": 4,
            "num_stages": 8,
        }
    else:
        config = {
            "BLOCK_SIZE_M": 16,
            "BLOCK_SIZE_N": 512,
            "GROUP_SIZE_M": 1,
            "num_warps": 8,
            "num_stages": 12,
        }
    configs = []

    configs.append(dict(config))

    op_name = "wint2_moe_ffn"
    op_name += f"{get_dtype_str(A.dtype)}"
    op_name += f"{B.shape[0]}"
    op_name += f"{B.shape[1]}"
    op_name += f"{B.shape[2]}"

    if op_name not in OpProtoHolder.instance().op_proto_map.keys():
        prepare_ptr_for_triton_kernel = """
    CUdeviceptr input_ptrs[11] = {
        get_tensor_ptr(A),
        get_tensor_ptr(B),
        get_tensor_ptr(C),
        get_tensor_ptr(B_scale),
        get_tensor_ptr(B_super_scale),
        get_tensor_ptr(B_code_scale),
        get_tensor_ptr(B_code_zp),
        get_tensor_ptr(topk_weights),
        get_tensor_ptr(sorted_token_ids),
        get_tensor_ptr(expert_ids),
        get_tensor_ptr(num_tokens_post_padded),
    };
    """
        template_used = rendering_common_template(
            invoke_fused_moe_kernel,
            prepare_attr_for_triton_kernel,
            prepare_ptr_for_triton_kernel,
        )
        grid = (
            "(EM+BLOCK_SIZE_M-1)/BLOCK_SIZE_M * ((N+BLOCK_SIZE_N-1)/BLOCK_SIZE_N)",
        )

        moe_wint2_ffn_kernel[(op_name, template_used, grid, configs)](
            A,
            B,
            C,
            B_scale,
            B_super_scale,
            B_code_scale,
            B_code_zp,
            topk_weights,
            sorted_token_ids,
            expert_ids,
            num_tokens_post_padded,
            NN,
            KK,
            -1,  # EM
            -1,  # num_valid_tokens
            sstride_am,
            sstride_ak,
            sstride_be,
            sstride_bk,
            sstride_bn,
            sstride_cm,
            sstride_cn,
            sstride_bse,
            sstride_bsk,
            sstride_bsn,
            sstride_bce,
            sstride_bck,
            sstride_bcn,
            MUL_ROUTED_WEIGHT=(int)(mul_routed_weight),
            USE_DOUBLE_QUANT=(int)(ddouble_quant),
            top_k=top_k,
            BLOCK_SIZE_K=group_size,
        )
    if in_dynamic_or_pir_mode():
        outs = _C_ops._run_custom_op(
            op_name,
            A,
            B,
            C,
            B_scale,
            B_super_scale,
            B_code_scale,
            B_code_zp,
            topk_weights,
            topk_ids,
            sorted_token_ids,
            expert_ids,
            num_tokens_post_padded,
            mul_routed_weight,
            top_k,
            group_size,
        )
        return outs[0]


@paddle_use_triton(key=["1"], )
def moe_wint2_ffn_kernel(
    # Pointers to matrices
    a_ptr,
    b_ptr,
    c_ptr,
    bs_ptr,
    superbs_ptr,
    codebs_ptr,
    codebzp_ptr,
    topk_weights_ptr,
    sorted_token_ids_ptr,
    expert_ids_ptr,
    num_tokens_post_padded_ptr,
    # Matrix dimensions
    N,
    K,
    EM,
    num_valid_tokens,
    # The stride variables represent how much to increase the ptr by when
    # moving by 1 element in a particular dimension. E.g. `stride_am` is
    # how much to increase `a_ptr` by to get the element one row down
    # (A has M rows).
    stride_am,
    stride_ak,
    stride_be,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    stride_bse,
    stride_bsk,
    stride_bsn,
    stride_bce,
    stride_bck,
    stride_bcn,
    # Meta-parameters
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
    MUL_ROUTED_WEIGHT: tl.constexpr,
    USE_DOUBLE_QUANT: tl.constexpr,
    top_k: tl.constexpr,
):
    """
    Implements the fused computation for a Mixture of Experts (MOE) using
    token and expert matrices.

    Key Parameters:
    - A: The input tensor representing tokens with shape (*, K), where '*' can
      be any shape representing batches and K is the feature dimension of
      each token.
    - B: The stacked MOE weight tensor with shape (E, K, N), where E is
      the number of experts, K is the input feature dimension, and N is
      the output feature dimension.
    - C: The output cache tensor with shape (M, topk, N), where M is the
      total number of tokens post padding, topk is the number of times
      each token is repeated, and N is the output feature dimension.
    - sorted_token_ids: A tensor containing the sorted indices of tokens,
      repeated topk times and arranged by the expert index they are
      assigned to.
    - expert_ids: A tensor containing the indices of the expert for each
      block. It determines which expert matrix from B should be used for
      each block in A.

    This kernel performs the multiplication of a token by its corresponding
    expert matrix as determined by `expert_ids`. The sorting of
    `sorted_token_ids` by expert index and padding ensures divisibility by
    BLOCK_SIZE_M, which is necessary to maintain consistency in block matrix
    multiplication across different blocks processed by the same expert.
    """

    if USE_DOUBLE_QUANT:
        # INT4 scale
        s_packnums: tl.constexpr = 2
    bzp: tl.constexpr = 32
    w_mask: tl.constexpr = 0x3F
    pack_num: tl.constexpr = 4
    real_k_size: tl.constexpr = (BLOCK_SIZE_K - 1) // pack_num + 1

    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m

    compute_type = c_ptr.dtype.element_ty

    num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr)
    if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded:
        return
    offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_token = tl.load(sorted_token_ids_ptr + offs_token_id)

    token_mask = offs_token < num_valid_tokens

    offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
    # offs_k = tl.arange(0, BLOCK_SIZE_K)
    offs_bk = tl.arange(0, real_k_size)

    a_ptrs = a_ptr + (offs_token[:, None] // top_k * stride_am +
                      offs_bk[None, :] * pack_num * stride_ak)

    off_experts = tl.load(expert_ids_ptr + pid_m)
    b_ptrs = b_ptr + off_experts * stride_be + (offs_bk[:, None] * stride_bk +
                                                offs_bn[None, :] * stride_bn)

    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)

    bs_ptrs = bs_ptr + off_experts * stride_bse + offs_bn[
        None, :] * stride_bsn  # group-wise, needs advancing

    off_set = off_experts * stride_bce + offs_bn[None, :] * stride_bcn
    # load channel-wise scale & zero-point
    if USE_DOUBLE_QUANT:
        superbs_ptrs = superbs_ptr + off_set  # channel-wise
        super_bs = tl.load(superbs_ptrs)  # super scale

    codebs_ptrs = codebs_ptr + off_set  # channel-wise
    code_bs = tl.load(codebs_ptrs)  # code scale
    codebzp_ptrs = codebzp_ptr + off_set  # channel-wise
    code_bzp = tl.load(codebzp_ptrs)  # code zp

    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):

        b = tl.load(b_ptrs)

        bs = tl.load(bs_ptrs)
        if USE_DOUBLE_QUANT:
            s_shift_bits = (1 - k % s_packnums) * 4
            bs = ((bs >> s_shift_bits) & 0xF) * super_bs

        # decode the packed codes back to int16
        b = tl.floor((b.to(tl.float32) * code_bs + code_bzp) + 0.5).to(
            tl.int16)
        # dequant
        b1 = (((b >> 9) & w_mask) - bzp) * bs
        a = tl.load(
            a_ptrs,
            mask=token_mask[:, None],
            other=0.0,
        )
        accumulator += tl.dot(a, b1.to(a.dtype))

        b1 = (((b >> 6) & w_mask) - bzp) * bs
        a = tl.load(
            a_ptrs + 1,
            mask=token_mask[:, None],
            other=0.0,
        )
        accumulator += tl.dot(a, b1.to(a.dtype))

        b1 = (((b >> 3) & w_mask) - bzp) * bs
        a = tl.load(
            a_ptrs + 2,
            mask=token_mask[:, None],
            other=0.0,
        )
        accumulator += tl.dot(a, b1.to(a.dtype))

        b = ((b & w_mask) - bzp) * bs
        a = tl.load(
            a_ptrs + 3,
            mask=token_mask[:, None],
            other=0.0,
        )
        accumulator += tl.dot(a, b.to(a.dtype))

        b_ptrs += real_k_size * stride_bk
        a_ptrs += BLOCK_SIZE_K * stride_ak

        # advance scale ptr
        if USE_DOUBLE_QUANT:
            bs_ptrs += stride_bsk * (k % s_packnums)
        else:
            bs_ptrs += stride_bsk

    if MUL_ROUTED_WEIGHT:
        moe_weight = tl.load(topk_weights_ptr + offs_token,
                             mask=token_mask,
                             other=0)
        accumulator = accumulator * moe_weight[:, None]

    accumulator = accumulator.to(compute_type)
    # -----------------------------------------------------------
    # Write back the block of the output
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[
        None, :]
    c_mask = token_mask[:, None] & (offs_cn[None, :] < N)
    tl.store(c_ptrs, accumulator, mask=c_mask)


def fused_moe_wint2_impl(
    hidden_states,
    ffn1_quant_weight,
    ffn2_quant_weight,
    topk_weights,
    topk_ids,
    # inplace: bool = False,
    ffn1_weight_scale=None,
    ffn2_weight_scale=None,
    ffn1_super_scales=None,
    ffn2_super_scales=None,
    ffn1_code_scale=None,
    ffn2_code_scale=None,
    ffn1_code_zp=None,
    ffn2_code_zp=None,
    group_size=64,
    bit="wint2",
):
    """
    Implementation of Fused MoE kernels on GPU.
    """
    # Check constraints.
    # A: [M, K]
    # B: [E, K, N]
    # assert hidden_states.shape[1] == ffn1_weight_scale.shape[1],
    # f"Hidden size mismatch, {hidden_states.shape[1]} != {ffn1_quant_weight.shape[1]}"
    assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
    assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
    assert ffn1_quant_weight.is_contiguous(
    ), "Expert weights1 must be contiguous"
    assert ffn2_quant_weight.is_contiguous(
    ), "Expert weights2 must be contiguous"
    assert group_size > 0, "Group size must be greater than 0"

    num_tokens, K = hidden_states.shape
    E, _, N = ffn1_quant_weight.shape
    M = num_tokens

    if group_size < 0:
        group_size = K // ffn1_weight_scale.shape[1]

    top_k = topk_ids.shape[1]

    intermediate_cache1 = paddle.empty(
        [M, top_k, N],
        dtype=hidden_states.dtype,
    )
    intermediate_cache2 = paddle.empty(
        (M * top_k, N // 2),
        dtype=hidden_states.dtype,
    )
    intermediate_cache3 = paddle.empty(
        (M, top_k, K),
        dtype=hidden_states.dtype,
    )

    from fastdeploy.model_executor.ops.gpu import tritonmoe_preprocess

    sorted_token_ids, expert_ids, num_tokens_post_padded = tritonmoe_preprocess(
        topk_ids, E, BLOCK_SIZE_M)

    invoke_fused_moe_kernel(
        A=hidden_states,
        B=ffn1_quant_weight,
        C=intermediate_cache1,
        B_scale=ffn1_weight_scale,
        B_super_scale=ffn1_super_scales,
        B_code_scale=ffn1_code_scale,
        B_code_zp=ffn1_code_zp,
        topk_weights=topk_weights,
        topk_ids=topk_ids,
        sorted_token_ids=sorted_token_ids,
        expert_ids=expert_ids,
        num_tokens_post_padded=num_tokens_post_padded,
        mul_routed_weight=False,
        top_k=top_k,
        group_size=group_size,
    )

    intermediate_cache2 = paddle.incubate.nn.functional.swiglu(
        intermediate_cache1.reshape([-1, N]))

    invoke_fused_moe_kernel(
        A=intermediate_cache2,
        B=ffn2_quant_weight,
        C=intermediate_cache3,
        B_scale=ffn2_weight_scale,
        B_super_scale=ffn2_super_scales,
        B_code_scale=ffn2_code_scale,
        B_code_zp=ffn2_code_zp,
        topk_weights=topk_weights,
        topk_ids=topk_ids,
        sorted_token_ids=sorted_token_ids,
        expert_ids=expert_ids,
        num_tokens_post_padded=num_tokens_post_padded,
        mul_routed_weight=True,
        top_k=1,
        group_size=group_size,
    )

    out_hidden_states = paddle.sum(intermediate_cache3, axis=1)
    return out_hidden_states


def fused_moe_wint2_triton(
    hidden_states,
    ffn1_quant_weight,
    ffn2_quant_weight,
    scores,
    gate_correction_bias,
    topk,
    ffn1_weight_scale,
    ffn2_weight_scale,
    ffn1_super_scales,
    ffn2_super_scales,
    ffn1_code_scale,
    ffn2_code_scale,
    ffn1_code_zp,
    ffn2_code_zp,
):
    """
    Fused MoE with the WINT2 quantization scheme and the Triton backend.

    Args:
        hidden_states: input tensor.
        ffn1_quant_weight: ffn1 weight matrix for experts.
        ffn2_quant_weight: ffn2 weight matrix for experts.
        scores: gate scores.
        gate_correction_bias: bias correction for gates.
        topk: number of experts to use.
        ffn1_weight_scale: scaling factor for ffn1_quant_weight.
        ffn2_weight_scale: scaling factor for ffn2_quant_weight.
        ffn1_super_scales: super scaling factor for ffn1_weight_scale.
        ffn2_super_scales: super scaling factor for ffn2_weight_scale.
        ffn1_code_scale: code scaling factor for ffn1_quant_weight.
        ffn2_code_scale: code scaling factor for ffn2_quant_weight.
        ffn1_code_zp: code zero point for ffn1_quant_weight.
        ffn2_code_zp: code zero point for ffn2_quant_weight.
    Returns:
        output tensor.
    """

    score = gate_correction_bias + scores
    _, topk_ids = paddle.topk(score, k=topk, axis=-1)
    topk_weights, _ = paddle.topk(scores, k=topk, axis=-1)
    topk_weights = topk_weights / topk_weights.sum(axis=-1, keepdim=True)

    return fused_moe_wint2_impl(
        hidden_states,
        ffn1_quant_weight,
        ffn2_quant_weight,
        topk_weights,
        topk_ids,
        ffn1_weight_scale,
        ffn2_weight_scale,
        ffn1_super_scales,
        ffn2_super_scales,
        ffn1_code_scale,
        ffn2_code_scale,
        ffn1_code_zp,
        ffn2_code_zp,
        bit="wint2",
    )
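
# Minimal usage sketch (assumed shapes, for illustration only): with E experts,
# hidden size K and intermediate size N,
#
#     hidden_states:     [num_tokens, K]   (fp16/bf16)
#     ffn1_quant_weight: [E, *, N]         packed 2-bit codes, * = packed K dim
#     scores:            [num_tokens, E]   gate probabilities
#
#     out = fused_moe_wint2_triton(
#         hidden_states, ffn1_quant_weight, ffn2_quant_weight,
#         scores, gate_correction_bias, topk,
#         ffn1_weight_scale, ffn2_weight_scale,
#         ffn1_super_scales, ffn2_super_scales,
#         ffn1_code_scale, ffn2_code_scale,
#         ffn1_code_zp, ffn2_code_zp)
#
# `out` has the same shape as `hidden_states`; the first expert GEMM output is
# passed through swiglu before the second expert GEMM, and the per-expert
# results are weighted by the routing probabilities and summed over the top-k
# experts.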