mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-09-26 20:41:53 +08:00
【Inference Optimize】Support automatic generation of marlin kernel (#3149)
* Support automatic generation of marlin kernel
This commit is contained in:
3
.gitignore
vendored
3
.gitignore
vendored
@@ -156,6 +156,9 @@ nohup.out
|
||||
custom_ops/gpu_ops/fp8_deep_gemm/deep_gemm/include/cutlass
|
||||
custom_ops/gpu_ops/fp8_deep_gemm/deep_gemm/include/cute
|
||||
|
||||
#marlin_kernel
|
||||
custom_ops/gpu_ops/moe/moe_wna16_marlin_utils/kernel_*.cu
|
||||
|
||||
# buff
|
||||
custom_ops/tmp*
|
||||
|
||||
|
@@ -0,0 +1,121 @@
|
||||
# adapted from: https://github.com/vllm-project/vllm/blob/main/csrc/moe/marlin_moe_wna16/generate_kernels.py
|
||||
|
||||
import glob
|
||||
import itertools
|
||||
import os
|
||||
import subprocess
|
||||
|
||||
import jinja2
|
||||
|
||||
FILE_HEAD = """
|
||||
// auto generated by generate.py
|
||||
// clang-format off
|
||||
|
||||
#include "kernel.h"
|
||||
#include "marlin_template.h"
|
||||
|
||||
namespace MARLIN_NAMESPACE_NAME {
|
||||
""".strip()
|
||||
|
||||
TEMPLATE = (
|
||||
"template __global__ void Marlin<"
|
||||
"{{scalar_t}}, "
|
||||
"{{w_type_id}}, "
|
||||
"{{threads}}, "
|
||||
"{{thread_m_blocks}}, "
|
||||
"{{thread_n_blocks}}, "
|
||||
"{{thread_k_blocks}}, "
|
||||
"{{'true' if m_block_size_8 else 'false'}}, "
|
||||
"{{stages}}, "
|
||||
"{{group_blocks}}, "
|
||||
"{{'true' if is_zp_float else 'false'}}>"
|
||||
"( MARLIN_KERNEL_PARAMS );"
|
||||
)
|
||||
|
||||
# int8 with zero point case (MARLIN_NAMESPACE_NAME::kU8) is also supported,
|
||||
# we don't add it to reduce wheel size.
|
||||
SCALAR_TYPES = [
|
||||
"MARLIN_NAMESPACE_NAME::kU4",
|
||||
"MARLIN_NAMESPACE_NAME::kU4B8",
|
||||
# "MARLIN_NAMESPACE_NAME::kU8B128", "MARLIN_NAMESPACE_NAME::kFE4M3fn",
|
||||
# "MARLIN_NAMESPACE_NAME::kFE2M1f"
|
||||
]
|
||||
THREAD_CONFIGS = [(128, 128, 256), (64, 256, 256), (64, 128, 128)]
|
||||
|
||||
THREAD_M_BLOCKS = [0.5, 1, 2, 3, 4]
|
||||
# group_blocks:
|
||||
# = 0 : act order case
|
||||
# = -1 : channelwise quantization
|
||||
# > 0 : group_size=16*group_blocks
|
||||
GROUP_BLOCKS = [0, -1, 1, 2, 4, 8]
|
||||
DTYPES = ["fp16", "bf16"]
|
||||
|
||||
|
||||
def remove_old_kernels():
|
||||
for filename in glob.glob(os.path.dirname(__file__) + "/kernel_*.cu"):
|
||||
subprocess.call(["rm", "-f", filename])
|
||||
|
||||
|
||||
def generate_new_kernels():
|
||||
for scalar_type, dtype in itertools.product(SCALAR_TYPES, DTYPES):
|
||||
all_template_str_list = []
|
||||
|
||||
for group_blocks, m_blocks, thread_configs in itertools.product(GROUP_BLOCKS, THREAD_M_BLOCKS, THREAD_CONFIGS):
|
||||
|
||||
# act order case only support gptq-int4 and gptq-int8
|
||||
if group_blocks == 0 and scalar_type not in [
|
||||
"MARLIN_NAMESPACE_NAME::kU4B8",
|
||||
"MARLIN_NAMESPACE_NAME::kU8B128",
|
||||
]:
|
||||
continue
|
||||
if thread_configs[2] == 256:
|
||||
# for small batch (m_blocks == 1), we only need (128, 128, 256)
|
||||
# for large batch (m_blocks > 1), we only need (64, 256, 256)
|
||||
if m_blocks <= 1 and thread_configs[0] != 128:
|
||||
continue
|
||||
if m_blocks > 1 and thread_configs[0] != 64:
|
||||
continue
|
||||
|
||||
# we only support channelwise quantization and group_size == 128
|
||||
# for fp8
|
||||
if scalar_type == "MARLIN_NAMESPACE_NAME::kFE4M3fn" and group_blocks not in [-1, 8]:
|
||||
continue
|
||||
# nvfp4 only supports group_size == 16
|
||||
if scalar_type == "MARLIN_NAMESPACE_NAME::kFE2M1f" and group_blocks not in [1, 2]:
|
||||
continue
|
||||
# other quantization methods don't support group_size = 16
|
||||
if scalar_type != "MARLIN_NAMESPACE_NAME::kFE2M1f" and group_blocks == 1:
|
||||
continue
|
||||
|
||||
k_blocks = thread_configs[0] // 16
|
||||
n_blocks = thread_configs[1] // 16
|
||||
threads = thread_configs[2]
|
||||
|
||||
c_dtype = "half" if dtype == "fp16" else "nv_bfloat16"
|
||||
|
||||
template_str = jinja2.Template(TEMPLATE).render(
|
||||
scalar_t=c_dtype,
|
||||
w_type_id=scalar_type + ".id()",
|
||||
threads=threads,
|
||||
thread_m_blocks=max(m_blocks, 1),
|
||||
thread_n_blocks=n_blocks,
|
||||
thread_k_blocks=k_blocks,
|
||||
m_block_size_8=m_blocks == 0.5,
|
||||
stages="pipe_stages",
|
||||
group_blocks=group_blocks,
|
||||
is_zp_float=False,
|
||||
)
|
||||
|
||||
all_template_str_list.append(template_str)
|
||||
|
||||
file_content = FILE_HEAD + "\n\n"
|
||||
file_content += "\n\n".join(all_template_str_list) + "\n\n}\n"
|
||||
filename = f"kernel_{dtype}_{scalar_type[23:].lower()}.cu"
|
||||
|
||||
with open(os.path.join(os.path.dirname(__file__), filename), "w") as f:
|
||||
f.write(file_content)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
remove_old_kernels()
|
||||
generate_new_kernels()
|
@@ -1,89 +0,0 @@
|
||||
// auto generated by generate.py
|
||||
// clang-format off
|
||||
|
||||
#include "moe/moe_wna16_marlin_utils/kernel.h"
|
||||
#include "moe/moe_wna16_marlin_utils/marlin_template.h"
|
||||
|
||||
namespace MARLIN_NAMESPACE_NAME {
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, MARLIN_NAMESPACE_NAME::kU4.id(), 256, 1, 8, 8, true, pipe_stages, -1, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, MARLIN_NAMESPACE_NAME::kU4.id(), 128, 1, 8, 4, true, pipe_stages, -1, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, MARLIN_NAMESPACE_NAME::kU4.id(), 256, 1, 8, 8, false, pipe_stages, -1, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, MARLIN_NAMESPACE_NAME::kU4.id(), 128, 1, 8, 4, false, pipe_stages, -1, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, MARLIN_NAMESPACE_NAME::kU4.id(), 256, 2, 16, 4, false, pipe_stages, -1, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, MARLIN_NAMESPACE_NAME::kU4.id(), 128, 2, 8, 4, false, pipe_stages, -1, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, MARLIN_NAMESPACE_NAME::kU4.id(), 256, 3, 16, 4, false, pipe_stages, -1, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, MARLIN_NAMESPACE_NAME::kU4.id(), 128, 3, 8, 4, false, pipe_stages, -1, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, MARLIN_NAMESPACE_NAME::kU4.id(), 256, 4, 16, 4, false, pipe_stages, -1, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, MARLIN_NAMESPACE_NAME::kU4.id(), 128, 4, 8, 4, false, pipe_stages, -1, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, MARLIN_NAMESPACE_NAME::kU4.id(), 256, 1, 8, 8, true, pipe_stages, 2, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, MARLIN_NAMESPACE_NAME::kU4.id(), 128, 1, 8, 4, true, pipe_stages, 2, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, MARLIN_NAMESPACE_NAME::kU4.id(), 256, 1, 8, 8, false, pipe_stages, 2, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, MARLIN_NAMESPACE_NAME::kU4.id(), 128, 1, 8, 4, false, pipe_stages, 2, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, MARLIN_NAMESPACE_NAME::kU4.id(), 256, 2, 16, 4, false, pipe_stages, 2, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, MARLIN_NAMESPACE_NAME::kU4.id(), 128, 2, 8, 4, false, pipe_stages, 2, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, MARLIN_NAMESPACE_NAME::kU4.id(), 256, 3, 16, 4, false, pipe_stages, 2, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, MARLIN_NAMESPACE_NAME::kU4.id(), 128, 3, 8, 4, false, pipe_stages, 2, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, MARLIN_NAMESPACE_NAME::kU4.id(), 256, 4, 16, 4, false, pipe_stages, 2, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, MARLIN_NAMESPACE_NAME::kU4.id(), 128, 4, 8, 4, false, pipe_stages, 2, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, MARLIN_NAMESPACE_NAME::kU4.id(), 256, 1, 8, 8, true, pipe_stages, 4, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, MARLIN_NAMESPACE_NAME::kU4.id(), 128, 1, 8, 4, true, pipe_stages, 4, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, MARLIN_NAMESPACE_NAME::kU4.id(), 256, 1, 8, 8, false, pipe_stages, 4, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, MARLIN_NAMESPACE_NAME::kU4.id(), 128, 1, 8, 4, false, pipe_stages, 4, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, MARLIN_NAMESPACE_NAME::kU4.id(), 256, 2, 16, 4, false, pipe_stages, 4, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, MARLIN_NAMESPACE_NAME::kU4.id(), 128, 2, 8, 4, false, pipe_stages, 4, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, MARLIN_NAMESPACE_NAME::kU4.id(), 256, 3, 16, 4, false, pipe_stages, 4, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, MARLIN_NAMESPACE_NAME::kU4.id(), 128, 3, 8, 4, false, pipe_stages, 4, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, MARLIN_NAMESPACE_NAME::kU4.id(), 256, 4, 16, 4, false, pipe_stages, 4, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, MARLIN_NAMESPACE_NAME::kU4.id(), 128, 4, 8, 4, false, pipe_stages, 4, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, MARLIN_NAMESPACE_NAME::kU4.id(), 256, 1, 8, 8, true, pipe_stages, 8, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, MARLIN_NAMESPACE_NAME::kU4.id(), 128, 1, 8, 4, true, pipe_stages, 8, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, MARLIN_NAMESPACE_NAME::kU4.id(), 256, 1, 8, 8, false, pipe_stages, 8, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, MARLIN_NAMESPACE_NAME::kU4.id(), 128, 1, 8, 4, false, pipe_stages, 8, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, MARLIN_NAMESPACE_NAME::kU4.id(), 256, 2, 16, 4, false, pipe_stages, 8, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, MARLIN_NAMESPACE_NAME::kU4.id(), 128, 2, 8, 4, false, pipe_stages, 8, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, MARLIN_NAMESPACE_NAME::kU4.id(), 256, 3, 16, 4, false, pipe_stages, 8, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, MARLIN_NAMESPACE_NAME::kU4.id(), 128, 3, 8, 4, false, pipe_stages, 8, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, MARLIN_NAMESPACE_NAME::kU4.id(), 256, 4, 16, 4, false, pipe_stages, 8, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, MARLIN_NAMESPACE_NAME::kU4.id(), 128, 4, 8, 4, false, pipe_stages, 8, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
}
|
@@ -1,89 +0,0 @@
|
||||
// auto generated by generate.py
|
||||
// clang-format off
|
||||
|
||||
#include "moe/moe_wna16_marlin_utils/kernel.h"
|
||||
#include "moe/moe_wna16_marlin_utils/marlin_template.h"
|
||||
|
||||
namespace MARLIN_NAMESPACE_NAME {
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, MARLIN_NAMESPACE_NAME::kU4B8.id(), 256, 1, 8, 8, true, pipe_stages, -1, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, MARLIN_NAMESPACE_NAME::kU4B8.id(), 128, 1, 8, 4, true, pipe_stages, -1, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, MARLIN_NAMESPACE_NAME::kU4B8.id(), 256, 1, 8, 8, false, pipe_stages, -1, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, MARLIN_NAMESPACE_NAME::kU4B8.id(), 128, 1, 8, 4, false, pipe_stages, -1, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, MARLIN_NAMESPACE_NAME::kU4B8.id(), 256, 2, 16, 4, false, pipe_stages, -1, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, MARLIN_NAMESPACE_NAME::kU4B8.id(), 128, 2, 8, 4, false, pipe_stages, -1, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, MARLIN_NAMESPACE_NAME::kU4B8.id(), 256, 3, 16, 4, false, pipe_stages, -1, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, MARLIN_NAMESPACE_NAME::kU4B8.id(), 128, 3, 8, 4, false, pipe_stages, -1, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, MARLIN_NAMESPACE_NAME::kU4B8.id(), 256, 4, 16, 4, false, pipe_stages, -1, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, MARLIN_NAMESPACE_NAME::kU4B8.id(), 128, 4, 8, 4, false, pipe_stages, -1, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, MARLIN_NAMESPACE_NAME::kU4B8.id(), 256, 1, 8, 8, true, pipe_stages, 2, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, MARLIN_NAMESPACE_NAME::kU4B8.id(), 128, 1, 8, 4, true, pipe_stages, 2, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, MARLIN_NAMESPACE_NAME::kU4B8.id(), 256, 1, 8, 8, false, pipe_stages, 2, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, MARLIN_NAMESPACE_NAME::kU4B8.id(), 128, 1, 8, 4, false, pipe_stages, 2, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, MARLIN_NAMESPACE_NAME::kU4B8.id(), 256, 2, 16, 4, false, pipe_stages, 2, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, MARLIN_NAMESPACE_NAME::kU4B8.id(), 128, 2, 8, 4, false, pipe_stages, 2, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, MARLIN_NAMESPACE_NAME::kU4B8.id(), 256, 3, 16, 4, false, pipe_stages, 2, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, MARLIN_NAMESPACE_NAME::kU4B8.id(), 128, 3, 8, 4, false, pipe_stages, 2, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, MARLIN_NAMESPACE_NAME::kU4B8.id(), 256, 4, 16, 4, false, pipe_stages, 2, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, MARLIN_NAMESPACE_NAME::kU4B8.id(), 128, 4, 8, 4, false, pipe_stages, 2, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, MARLIN_NAMESPACE_NAME::kU4B8.id(), 256, 1, 8, 8, true, pipe_stages, 4, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, MARLIN_NAMESPACE_NAME::kU4B8.id(), 128, 1, 8, 4, true, pipe_stages, 4, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, MARLIN_NAMESPACE_NAME::kU4B8.id(), 256, 1, 8, 8, false, pipe_stages, 4, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, MARLIN_NAMESPACE_NAME::kU4B8.id(), 128, 1, 8, 4, false, pipe_stages, 4, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, MARLIN_NAMESPACE_NAME::kU4B8.id(), 256, 2, 16, 4, false, pipe_stages, 4, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, MARLIN_NAMESPACE_NAME::kU4B8.id(), 128, 2, 8, 4, false, pipe_stages, 4, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, MARLIN_NAMESPACE_NAME::kU4B8.id(), 256, 3, 16, 4, false, pipe_stages, 4, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, MARLIN_NAMESPACE_NAME::kU4B8.id(), 128, 3, 8, 4, false, pipe_stages, 4, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, MARLIN_NAMESPACE_NAME::kU4B8.id(), 256, 4, 16, 4, false, pipe_stages, 4, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, MARLIN_NAMESPACE_NAME::kU4B8.id(), 128, 4, 8, 4, false, pipe_stages, 4, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, MARLIN_NAMESPACE_NAME::kU4B8.id(), 256, 1, 8, 8, true, pipe_stages, 8, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, MARLIN_NAMESPACE_NAME::kU4B8.id(), 128, 1, 8, 4, true, pipe_stages, 8, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, MARLIN_NAMESPACE_NAME::kU4B8.id(), 256, 1, 8, 8, false, pipe_stages, 8, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, MARLIN_NAMESPACE_NAME::kU4B8.id(), 128, 1, 8, 4, false, pipe_stages, 8, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, MARLIN_NAMESPACE_NAME::kU4B8.id(), 256, 2, 16, 4, false, pipe_stages, 8, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, MARLIN_NAMESPACE_NAME::kU4B8.id(), 128, 2, 8, 4, false, pipe_stages, 8, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, MARLIN_NAMESPACE_NAME::kU4B8.id(), 256, 3, 16, 4, false, pipe_stages, 8, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, MARLIN_NAMESPACE_NAME::kU4B8.id(), 128, 3, 8, 4, false, pipe_stages, 8, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, MARLIN_NAMESPACE_NAME::kU4B8.id(), 256, 4, 16, 4, false, pipe_stages, 8, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<nv_bfloat16, MARLIN_NAMESPACE_NAME::kU4B8.id(), 128, 4, 8, 4, false, pipe_stages, 8, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
}
|
@@ -1,89 +0,0 @@
|
||||
// auto generated by generate.py
|
||||
// clang-format off
|
||||
|
||||
#include "moe/moe_wna16_marlin_utils/kernel.h"
|
||||
#include "moe/moe_wna16_marlin_utils/marlin_template.h"
|
||||
|
||||
namespace MARLIN_NAMESPACE_NAME {
|
||||
|
||||
template __global__ void Marlin<half, MARLIN_NAMESPACE_NAME::kU4.id(), 256, 1, 8, 8, true, pipe_stages, -1, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, MARLIN_NAMESPACE_NAME::kU4.id(), 128, 1, 8, 4, true, pipe_stages, -1, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, MARLIN_NAMESPACE_NAME::kU4.id(), 256, 1, 8, 8, false, pipe_stages, -1, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, MARLIN_NAMESPACE_NAME::kU4.id(), 128, 1, 8, 4, false, pipe_stages, -1, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, MARLIN_NAMESPACE_NAME::kU4.id(), 256, 2, 16, 4, false, pipe_stages, -1, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, MARLIN_NAMESPACE_NAME::kU4.id(), 128, 2, 8, 4, false, pipe_stages, -1, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, MARLIN_NAMESPACE_NAME::kU4.id(), 256, 3, 16, 4, false, pipe_stages, -1, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, MARLIN_NAMESPACE_NAME::kU4.id(), 128, 3, 8, 4, false, pipe_stages, -1, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, MARLIN_NAMESPACE_NAME::kU4.id(), 256, 4, 16, 4, false, pipe_stages, -1, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, MARLIN_NAMESPACE_NAME::kU4.id(), 128, 4, 8, 4, false, pipe_stages, -1, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, MARLIN_NAMESPACE_NAME::kU4.id(), 256, 1, 8, 8, true, pipe_stages, 2, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, MARLIN_NAMESPACE_NAME::kU4.id(), 128, 1, 8, 4, true, pipe_stages, 2, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, MARLIN_NAMESPACE_NAME::kU4.id(), 256, 1, 8, 8, false, pipe_stages, 2, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, MARLIN_NAMESPACE_NAME::kU4.id(), 128, 1, 8, 4, false, pipe_stages, 2, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, MARLIN_NAMESPACE_NAME::kU4.id(), 256, 2, 16, 4, false, pipe_stages, 2, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, MARLIN_NAMESPACE_NAME::kU4.id(), 128, 2, 8, 4, false, pipe_stages, 2, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, MARLIN_NAMESPACE_NAME::kU4.id(), 256, 3, 16, 4, false, pipe_stages, 2, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, MARLIN_NAMESPACE_NAME::kU4.id(), 128, 3, 8, 4, false, pipe_stages, 2, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, MARLIN_NAMESPACE_NAME::kU4.id(), 256, 4, 16, 4, false, pipe_stages, 2, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, MARLIN_NAMESPACE_NAME::kU4.id(), 128, 4, 8, 4, false, pipe_stages, 2, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, MARLIN_NAMESPACE_NAME::kU4.id(), 256, 1, 8, 8, true, pipe_stages, 4, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, MARLIN_NAMESPACE_NAME::kU4.id(), 128, 1, 8, 4, true, pipe_stages, 4, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, MARLIN_NAMESPACE_NAME::kU4.id(), 256, 1, 8, 8, false, pipe_stages, 4, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, MARLIN_NAMESPACE_NAME::kU4.id(), 128, 1, 8, 4, false, pipe_stages, 4, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, MARLIN_NAMESPACE_NAME::kU4.id(), 256, 2, 16, 4, false, pipe_stages, 4, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, MARLIN_NAMESPACE_NAME::kU4.id(), 128, 2, 8, 4, false, pipe_stages, 4, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, MARLIN_NAMESPACE_NAME::kU4.id(), 256, 3, 16, 4, false, pipe_stages, 4, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, MARLIN_NAMESPACE_NAME::kU4.id(), 128, 3, 8, 4, false, pipe_stages, 4, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, MARLIN_NAMESPACE_NAME::kU4.id(), 256, 4, 16, 4, false, pipe_stages, 4, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, MARLIN_NAMESPACE_NAME::kU4.id(), 128, 4, 8, 4, false, pipe_stages, 4, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, MARLIN_NAMESPACE_NAME::kU4.id(), 256, 1, 8, 8, true, pipe_stages, 8, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, MARLIN_NAMESPACE_NAME::kU4.id(), 128, 1, 8, 4, true, pipe_stages, 8, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, MARLIN_NAMESPACE_NAME::kU4.id(), 256, 1, 8, 8, false, pipe_stages, 8, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, MARLIN_NAMESPACE_NAME::kU4.id(), 128, 1, 8, 4, false, pipe_stages, 8, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, MARLIN_NAMESPACE_NAME::kU4.id(), 256, 2, 16, 4, false, pipe_stages, 8, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, MARLIN_NAMESPACE_NAME::kU4.id(), 128, 2, 8, 4, false, pipe_stages, 8, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, MARLIN_NAMESPACE_NAME::kU4.id(), 256, 3, 16, 4, false, pipe_stages, 8, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, MARLIN_NAMESPACE_NAME::kU4.id(), 128, 3, 8, 4, false, pipe_stages, 8, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, MARLIN_NAMESPACE_NAME::kU4.id(), 256, 4, 16, 4, false, pipe_stages, 8, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, MARLIN_NAMESPACE_NAME::kU4.id(), 128, 4, 8, 4, false, pipe_stages, 8, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
}
|
@@ -1,109 +0,0 @@
|
||||
// auto generated by generate.py
|
||||
// clang-format off
|
||||
|
||||
#include "kernel.h"
|
||||
#include "marlin_template.h"
|
||||
|
||||
namespace MARLIN_NAMESPACE_NAME {
|
||||
|
||||
template __global__ void Marlin<half, MARLIN_NAMESPACE_NAME::kU4B8.id(), 256, 1, 8, 8, true, pipe_stages, 0, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, MARLIN_NAMESPACE_NAME::kU4B8.id(), 128, 1, 8, 4, true, pipe_stages, 0, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, MARLIN_NAMESPACE_NAME::kU4B8.id(), 256, 1, 8, 8, false, pipe_stages, 0, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, MARLIN_NAMESPACE_NAME::kU4B8.id(), 128, 1, 8, 4, false, pipe_stages, 0, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, MARLIN_NAMESPACE_NAME::kU4B8.id(), 256, 2, 16, 4, false, pipe_stages, 0, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, MARLIN_NAMESPACE_NAME::kU4B8.id(), 128, 2, 8, 4, false, pipe_stages, 0, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, MARLIN_NAMESPACE_NAME::kU4B8.id(), 256, 3, 16, 4, false, pipe_stages, 0, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, MARLIN_NAMESPACE_NAME::kU4B8.id(), 128, 3, 8, 4, false, pipe_stages, 0, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, MARLIN_NAMESPACE_NAME::kU4B8.id(), 256, 4, 16, 4, false, pipe_stages, 0, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, MARLIN_NAMESPACE_NAME::kU4B8.id(), 128, 4, 8, 4, false, pipe_stages, 0, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, MARLIN_NAMESPACE_NAME::kU4B8.id(), 256, 1, 8, 8, true, pipe_stages, -1, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, MARLIN_NAMESPACE_NAME::kU4B8.id(), 128, 1, 8, 4, true, pipe_stages, -1, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, MARLIN_NAMESPACE_NAME::kU4B8.id(), 256, 1, 8, 8, false, pipe_stages, -1, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, MARLIN_NAMESPACE_NAME::kU4B8.id(), 128, 1, 8, 4, false, pipe_stages, -1, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, MARLIN_NAMESPACE_NAME::kU4B8.id(), 256, 2, 16, 4, false, pipe_stages, -1, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, MARLIN_NAMESPACE_NAME::kU4B8.id(), 128, 2, 8, 4, false, pipe_stages, -1, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, MARLIN_NAMESPACE_NAME::kU4B8.id(), 256, 3, 16, 4, false, pipe_stages, -1, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, MARLIN_NAMESPACE_NAME::kU4B8.id(), 128, 3, 8, 4, false, pipe_stages, -1, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, MARLIN_NAMESPACE_NAME::kU4B8.id(), 256, 4, 16, 4, false, pipe_stages, -1, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, MARLIN_NAMESPACE_NAME::kU4B8.id(), 128, 4, 8, 4, false, pipe_stages, -1, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, MARLIN_NAMESPACE_NAME::kU4B8.id(), 256, 1, 8, 8, true, pipe_stages, 2, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, MARLIN_NAMESPACE_NAME::kU4B8.id(), 128, 1, 8, 4, true, pipe_stages, 2, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, MARLIN_NAMESPACE_NAME::kU4B8.id(), 256, 1, 8, 8, false, pipe_stages, 2, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, MARLIN_NAMESPACE_NAME::kU4B8.id(), 128, 1, 8, 4, false, pipe_stages, 2, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, MARLIN_NAMESPACE_NAME::kU4B8.id(), 256, 2, 16, 4, false, pipe_stages, 2, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, MARLIN_NAMESPACE_NAME::kU4B8.id(), 128, 2, 8, 4, false, pipe_stages, 2, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, MARLIN_NAMESPACE_NAME::kU4B8.id(), 256, 3, 16, 4, false, pipe_stages, 2, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, MARLIN_NAMESPACE_NAME::kU4B8.id(), 128, 3, 8, 4, false, pipe_stages, 2, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, MARLIN_NAMESPACE_NAME::kU4B8.id(), 256, 4, 16, 4, false, pipe_stages, 2, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, MARLIN_NAMESPACE_NAME::kU4B8.id(), 128, 4, 8, 4, false, pipe_stages, 2, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, MARLIN_NAMESPACE_NAME::kU4B8.id(), 256, 1, 8, 8, true, pipe_stages, 4, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, MARLIN_NAMESPACE_NAME::kU4B8.id(), 128, 1, 8, 4, true, pipe_stages, 4, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, MARLIN_NAMESPACE_NAME::kU4B8.id(), 256, 1, 8, 8, false, pipe_stages, 4, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, MARLIN_NAMESPACE_NAME::kU4B8.id(), 128, 1, 8, 4, false, pipe_stages, 4, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, MARLIN_NAMESPACE_NAME::kU4B8.id(), 256, 2, 16, 4, false, pipe_stages, 4, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, MARLIN_NAMESPACE_NAME::kU4B8.id(), 128, 2, 8, 4, false, pipe_stages, 4, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, MARLIN_NAMESPACE_NAME::kU4B8.id(), 256, 3, 16, 4, false, pipe_stages, 4, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, MARLIN_NAMESPACE_NAME::kU4B8.id(), 128, 3, 8, 4, false, pipe_stages, 4, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, MARLIN_NAMESPACE_NAME::kU4B8.id(), 256, 4, 16, 4, false, pipe_stages, 4, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, MARLIN_NAMESPACE_NAME::kU4B8.id(), 128, 4, 8, 4, false, pipe_stages, 4, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, MARLIN_NAMESPACE_NAME::kU4B8.id(), 256, 1, 8, 8, true, pipe_stages, 8, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, MARLIN_NAMESPACE_NAME::kU4B8.id(), 128, 1, 8, 4, true, pipe_stages, 8, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, MARLIN_NAMESPACE_NAME::kU4B8.id(), 256, 1, 8, 8, false, pipe_stages, 8, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, MARLIN_NAMESPACE_NAME::kU4B8.id(), 128, 1, 8, 4, false, pipe_stages, 8, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, MARLIN_NAMESPACE_NAME::kU4B8.id(), 256, 2, 16, 4, false, pipe_stages, 8, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, MARLIN_NAMESPACE_NAME::kU4B8.id(), 128, 2, 8, 4, false, pipe_stages, 8, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, MARLIN_NAMESPACE_NAME::kU4B8.id(), 256, 3, 16, 4, false, pipe_stages, 8, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, MARLIN_NAMESPACE_NAME::kU4B8.id(), 128, 3, 8, 4, false, pipe_stages, 8, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, MARLIN_NAMESPACE_NAME::kU4B8.id(), 256, 4, 16, 4, false, pipe_stages, 8, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
template __global__ void Marlin<half, MARLIN_NAMESPACE_NAME::kU4B8.id(), 128, 4, 8, 4, false, pipe_stages, 8, false>( MARLIN_KERNEL_PARAMS );
|
||||
|
||||
}
|
@@ -409,6 +409,7 @@ elif paddle.is_compiled_with_cuda():
|
||||
sources += find_end_files("gpu_ops/speculate_decoding", ".cc")
|
||||
nvcc_compile_args += ["-DENABLE_BF16"]
|
||||
# moe
|
||||
os.system("python gpu_ops/moe/moe_wna16_marlin_utils/generate_kernels.py")
|
||||
sources += find_end_files("gpu_ops/cutlass_kernels/moe_gemm/", ".cu")
|
||||
sources += find_end_files("gpu_ops/cutlass_kernels/w4a8_moe/", ".cu")
|
||||
sources += find_end_files("gpu_ops/moe/", ".cu")
|
||||
|
Reference in New Issue
Block a user