diff --git a/.gitignore b/.gitignore index b7c91af77..eb5cc068c 100644 --- a/.gitignore +++ b/.gitignore @@ -156,6 +156,9 @@ nohup.out custom_ops/gpu_ops/fp8_deep_gemm/deep_gemm/include/cutlass custom_ops/gpu_ops/fp8_deep_gemm/deep_gemm/include/cute +#marlin_kernel +custom_ops/gpu_ops/moe/moe_wna16_marlin_utils/kernel_*.cu + # buff custom_ops/tmp* diff --git a/custom_ops/gpu_ops/moe/moe_wna16_marlin_utils/generate_kernels.py b/custom_ops/gpu_ops/moe/moe_wna16_marlin_utils/generate_kernels.py new file mode 100644 index 000000000..de2d9ddb4 --- /dev/null +++ b/custom_ops/gpu_ops/moe/moe_wna16_marlin_utils/generate_kernels.py @@ -0,0 +1,121 @@ +# adapted from: https://github.com/vllm-project/vllm/blob/main/csrc/moe/marlin_moe_wna16/generate_kernels.py + +import glob +import itertools +import os +import subprocess + +import jinja2 + +FILE_HEAD = """ +// auto generated by generate.py +// clang-format off + +#include "kernel.h" +#include "marlin_template.h" + +namespace MARLIN_NAMESPACE_NAME { +""".strip() + +TEMPLATE = ( + "template __global__ void Marlin<" + "{{scalar_t}}, " + "{{w_type_id}}, " + "{{threads}}, " + "{{thread_m_blocks}}, " + "{{thread_n_blocks}}, " + "{{thread_k_blocks}}, " + "{{'true' if m_block_size_8 else 'false'}}, " + "{{stages}}, " + "{{group_blocks}}, " + "{{'true' if is_zp_float else 'false'}}>" + "( MARLIN_KERNEL_PARAMS );" +) + +# int8 with zero point case (MARLIN_NAMESPACE_NAME::kU8) is also supported, +# we don't add it to reduce wheel size. +SCALAR_TYPES = [ + "MARLIN_NAMESPACE_NAME::kU4", + "MARLIN_NAMESPACE_NAME::kU4B8", + # "MARLIN_NAMESPACE_NAME::kU8B128", "MARLIN_NAMESPACE_NAME::kFE4M3fn", + # "MARLIN_NAMESPACE_NAME::kFE2M1f" +] +THREAD_CONFIGS = [(128, 128, 256), (64, 256, 256), (64, 128, 128)] + +THREAD_M_BLOCKS = [0.5, 1, 2, 3, 4] +# group_blocks: +# = 0 : act order case +# = -1 : channelwise quantization +# > 0 : group_size=16*group_blocks +GROUP_BLOCKS = [0, -1, 1, 2, 4, 8] +DTYPES = ["fp16", "bf16"] + + +def remove_old_kernels(): + for filename in glob.glob(os.path.dirname(__file__) + "/kernel_*.cu"): + subprocess.call(["rm", "-f", filename]) + + +def generate_new_kernels(): + for scalar_type, dtype in itertools.product(SCALAR_TYPES, DTYPES): + all_template_str_list = [] + + for group_blocks, m_blocks, thread_configs in itertools.product(GROUP_BLOCKS, THREAD_M_BLOCKS, THREAD_CONFIGS): + + # act order case only support gptq-int4 and gptq-int8 + if group_blocks == 0 and scalar_type not in [ + "MARLIN_NAMESPACE_NAME::kU4B8", + "MARLIN_NAMESPACE_NAME::kU8B128", + ]: + continue + if thread_configs[2] == 256: + # for small batch (m_blocks == 1), we only need (128, 128, 256) + # for large batch (m_blocks > 1), we only need (64, 256, 256) + if m_blocks <= 1 and thread_configs[0] != 128: + continue + if m_blocks > 1 and thread_configs[0] != 64: + continue + + # we only support channelwise quantization and group_size == 128 + # for fp8 + if scalar_type == "MARLIN_NAMESPACE_NAME::kFE4M3fn" and group_blocks not in [-1, 8]: + continue + # nvfp4 only supports group_size == 16 + if scalar_type == "MARLIN_NAMESPACE_NAME::kFE2M1f" and group_blocks not in [1, 2]: + continue + # other quantization methods don't support group_size = 16 + if scalar_type != "MARLIN_NAMESPACE_NAME::kFE2M1f" and group_blocks == 1: + continue + + k_blocks = thread_configs[0] // 16 + n_blocks = thread_configs[1] // 16 + threads = thread_configs[2] + + c_dtype = "half" if dtype == "fp16" else "nv_bfloat16" + + template_str = jinja2.Template(TEMPLATE).render( + scalar_t=c_dtype, + w_type_id=scalar_type + ".id()", + threads=threads, + thread_m_blocks=max(m_blocks, 1), + thread_n_blocks=n_blocks, + thread_k_blocks=k_blocks, + m_block_size_8=m_blocks == 0.5, + stages="pipe_stages", + group_blocks=group_blocks, + is_zp_float=False, + ) + + all_template_str_list.append(template_str) + + file_content = FILE_HEAD + "\n\n" + file_content += "\n\n".join(all_template_str_list) + "\n\n}\n" + filename = f"kernel_{dtype}_{scalar_type[23:].lower()}.cu" + + with open(os.path.join(os.path.dirname(__file__), filename), "w") as f: + f.write(file_content) + + +if __name__ == "__main__": + remove_old_kernels() + generate_new_kernels() diff --git a/custom_ops/gpu_ops/moe/moe_wna16_marlin_utils/kernel_bf16_ku4.cu b/custom_ops/gpu_ops/moe/moe_wna16_marlin_utils/kernel_bf16_ku4.cu deleted file mode 100644 index 4d290cbe0..000000000 --- a/custom_ops/gpu_ops/moe/moe_wna16_marlin_utils/kernel_bf16_ku4.cu +++ /dev/null @@ -1,89 +0,0 @@ -// auto generated by generate.py -// clang-format off - -#include "moe/moe_wna16_marlin_utils/kernel.h" -#include "moe/moe_wna16_marlin_utils/marlin_template.h" - -namespace MARLIN_NAMESPACE_NAME { - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -} diff --git a/custom_ops/gpu_ops/moe/moe_wna16_marlin_utils/kernel_bf16_ku4b8.cu b/custom_ops/gpu_ops/moe/moe_wna16_marlin_utils/kernel_bf16_ku4b8.cu deleted file mode 100644 index 79730064a..000000000 --- a/custom_ops/gpu_ops/moe/moe_wna16_marlin_utils/kernel_bf16_ku4b8.cu +++ /dev/null @@ -1,89 +0,0 @@ -// auto generated by generate.py -// clang-format off - -#include "moe/moe_wna16_marlin_utils/kernel.h" -#include "moe/moe_wna16_marlin_utils/marlin_template.h" - -namespace MARLIN_NAMESPACE_NAME { - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -} diff --git a/custom_ops/gpu_ops/moe/moe_wna16_marlin_utils/kernel_fp16_ku4.cu b/custom_ops/gpu_ops/moe/moe_wna16_marlin_utils/kernel_fp16_ku4.cu deleted file mode 100644 index d1d1e643b..000000000 --- a/custom_ops/gpu_ops/moe/moe_wna16_marlin_utils/kernel_fp16_ku4.cu +++ /dev/null @@ -1,89 +0,0 @@ -// auto generated by generate.py -// clang-format off - -#include "moe/moe_wna16_marlin_utils/kernel.h" -#include "moe/moe_wna16_marlin_utils/marlin_template.h" - -namespace MARLIN_NAMESPACE_NAME { - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -} diff --git a/custom_ops/gpu_ops/moe/moe_wna16_marlin_utils/kernel_fp16_ku4b8.cu b/custom_ops/gpu_ops/moe/moe_wna16_marlin_utils/kernel_fp16_ku4b8.cu deleted file mode 100644 index b45f36947..000000000 --- a/custom_ops/gpu_ops/moe/moe_wna16_marlin_utils/kernel_fp16_ku4b8.cu +++ /dev/null @@ -1,109 +0,0 @@ -// auto generated by generate.py -// clang-format off - -#include "kernel.h" -#include "marlin_template.h" - -namespace MARLIN_NAMESPACE_NAME { - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -} diff --git a/custom_ops/setup_ops.py b/custom_ops/setup_ops.py index f7c934aa1..38ccb357b 100644 --- a/custom_ops/setup_ops.py +++ b/custom_ops/setup_ops.py @@ -409,6 +409,7 @@ elif paddle.is_compiled_with_cuda(): sources += find_end_files("gpu_ops/speculate_decoding", ".cc") nvcc_compile_args += ["-DENABLE_BF16"] # moe + os.system("python gpu_ops/moe/moe_wna16_marlin_utils/generate_kernels.py") sources += find_end_files("gpu_ops/cutlass_kernels/moe_gemm/", ".cu") sources += find_end_files("gpu_ops/cutlass_kernels/w4a8_moe/", ".cu") sources += find_end_files("gpu_ops/moe/", ".cu")