# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Generate the W4AFP8 GEMM kernel instantiations under ``file_dir``.

For every combination of ``dtype`` x ``gemm_case`` x ``n_range`` this script
writes one ``w4afp8_gemm_M.._N.._G.._K.._E.._P.._<TYPE>.cu`` definition file
and a matching declaration in ``w4afp8_gemm_template.h``. It also emits one
``GEMM_SWITCH_<TYPE>`` macro per dtype that dispatches on the runtime shape
arguments. Previously generated ``.cu`` files that no longer match a configured
instance are deleted first.

``file_dir`` is relative to the current working directory (presumably the
``custom_ops`` directory -- confirm with the build script that invokes this).
"""

import os
import re

# Output directory for the generated header and .cu files (relative to the CWD).
file_dir = "./gpu_ops/w4afp8_gemm/"

# Preamble of the generated w4afp8_gemm_template.h.
# NOTE(review): in this copy of the file the ``#include`` lines below carry no
# header names, and a bare ``#include`` is not valid C/C++. The targets look
# like they were lost; restore them from upstream (TODO confirm).
gemm_template_head = """
#pragma once
#include
#include
#include
#include

#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
#include
#endif

#include
#include
#include
#include
#include
"""

# Declaration emitted into w4afp8_gemm_template.h for each kernel instance.
gemm_template_case = """
void w4afp8_gemm_M{M}_N{N}_G{GROUPSIZE}_K{K}_E{EXPERTS}_P{PADDING}_{TYPE}(
    const cutlass::float_e4m3_t * weight,
    const cutlass::float_e4m3_t * input,
    {cutlass_type} * out,
    const float *weight_scale,
    const float * input_dequant_scale,
    const int64_t *tokens,
    const int64_t max_tokens,
    cudaStream_t stream);
"""

# Preamble of every generated .cu file.
gemm_template_cu_head = """
#include "paddle/extension.h"
#include "w4afp8_gemm_template.h"
#include "w4afp8_gemm_kernel.hpp"

"""

# Definition emitted into each generated .cu file. Literal C++ braces are
# doubled because this text is passed through ``str.format``.
# NOTE(review): ``run_gemm`` is called below without explicit template
# arguments, yet the ``Kernel_traits`` alias declared right above it is not
# referenced anywhere else in this template. The template arguments appear to
# have been lost from this copy -- confirm against w4afp8_gemm_kernel.hpp.
gemm_template_cu_template = """
void w4afp8_gemm_M{M}_N{N}_G{GROUPSIZE}_K{K}_E{EXPERTS}_P{PADDING}_{TYPE}(
    const cutlass::float_e4m3_t * weight,
    const cutlass::float_e4m3_t * input,
    {cutlass_type} * out,
    const float *weight_scale,
    const float * input_dequant_scale,
    const int64_t *tokens,
    const int64_t max_tokens,
    cudaStream_t stream) {{

    constexpr static int M = {M};
    constexpr static int K = {K};
    constexpr static int EXPERTS = {EXPERTS};
    constexpr static int TokenPackSize = {PADDING};
    constexpr static int kBlockN = {N};
    constexpr static int kGroupSize = {GROUPSIZE};
    constexpr static int kBlockM = 128;
    constexpr static int kBlockK = 128;
    constexpr static int kNWarps = 4 + kBlockM / 16;
    constexpr static int kStages = 5;
    constexpr int kCluster = 1;
    static_assert(K % kBlockK == 0);
    constexpr int kTiles = K / kBlockK;

    using Kernel_traits = Kernel_traits<
        kBlockM, kBlockN, kBlockK, kNWarps, kStages, kTiles,
        M, K, TokenPackSize, kGroupSize, kCluster, cutlass::float_e4m3_t,
        {cutlass_type}>;
    run_gemm(weight, input, out, weight_scale,
             input_dequant_scale, tokens, max_tokens, stream);
}}
"""

# Instance table. Each entry is exactly
# [M, K, number of experts, token padding size, weight K group size].
gemm_case = [
    [256, 256, 2, 0, 128],
    [512, 256, 2, 0, 128],
]

# Output dtype tags to instantiate; each must be accepted by get_cutlass_type().
dtype = ["BF16"]

# Fast compile builds only kBlockN=256; otherwise kBlockN = 16, 32, ..., 256.
use_fast_compile = True
n_range = [256] if use_fast_compile else list(range(16, 257, 16))

# CUTLASS element type spelled in the generated C++ for each dtype tag.
_CUTLASS_TYPES = {
    "BF16": "cutlass::bfloat16_t",
    "FP16": "cutlass::half_t",
}

# Kernel symbol / .cu file stem. Must stay in sync with the function names
# spelled in gemm_template_case and gemm_template_cu_template.
_KERNEL_NAME = "w4afp8_gemm_M{M}_N{N}_G{GROUPSIZE}_K{K}_E{EXPERTS}_P{PADDING}_{TYPE}"

# File names that look like generated kernel sources (candidates for cleanup).
_GENERATED_CU_RE = re.compile(r"^w4afp8_gemm_M\d+_N\d+_.*\.cu$")


def get_cutlass_type(type):  # noqa: A002 -- parameter name kept for keyword callers
    """Return the CUTLASS C++ element type for the dtype tag ``type``.

    Args:
        type: ``"BF16"`` or ``"FP16"``.

    Returns:
        ``"cutlass::bfloat16_t"`` for ``"BF16"``, ``"cutlass::half_t"`` for ``"FP16"``.

    Raises:
        ValueError: for any other tag. Returning ``None`` instead would be
            formatted into the generated C++ as the literal text ``None``.
    """
    try:
        return _CUTLASS_TYPES[type]
    except KeyError:
        raise ValueError(
            f"Unsupported W4AFP8 output dtype {type!r}; expected one of {sorted(_CUTLASS_TYPES)}"
        ) from None


def _instance_args(case, n, dtype_tag):
    """Return the ``str.format`` keyword arguments for one kernel instance."""
    m, k, experts, padding, group_size = case
    return {
        "M": m,
        "K": k,
        "N": n,
        "EXPERTS": experts,
        "PADDING": padding,
        "GROUPSIZE": group_size,
        "TYPE": dtype_tag,
    }


def _iter_instances():
    """Yield format arguments for every (dtype, case, n) combination, in generation order."""
    for dtype_tag in dtype:
        for case in gemm_case:
            for n in n_range:
                yield _instance_args(case, n, dtype_tag)


# File names of every .cu file this script generates.
all_cu_files = [_KERNEL_NAME.format(**args) + ".cu" for args in _iter_instances()]


def _remove_stale_kernel_files(keep):
    """Delete generated-looking ``.cu`` files under ``file_dir`` whose name is not in ``keep``.

    ``os.walk`` also descends into subdirectories and yields their paths without
    a trailing separator, so the full path is built with ``os.path.join``.
    """
    keep = set(keep)
    for dir_path, _sub_dirs, file_names in os.walk(file_dir):
        for file_name in file_names:
            if not _GENERATED_CU_RE.match(file_name) or file_name in keep:
                continue
            stale_path = os.path.join(dir_path, file_name)
            print("delete w4afp8 kernel file", stale_path)
            os.remove(stale_path)


def _render_switch_macro(dtype_tag):
    """Return the ``GEMM_SWITCH_<dtype_tag>`` C macro (without a trailing newline).

    The macro compares its six shape arguments against every configured
    instance of ``dtype_tag`` and forwards ``__VA_ARGS__`` to the matching
    kernel. Any other shape reaches ``PADDLE_THROW``.
    """
    lines = [
        f"#define GEMM_SWITCH_{dtype_tag}(_M, _K, _EXPERTS, _TokenPaddingSize, _kBlockN, _GROUPSIZE, ...) {{ \\",
        # Placeholder first branch (matches only all-zero arguments and does
        # nothing) so every real case below is emitted uniformly as "} else if".
        "    if (_M == 0 && _K == 0 && _EXPERTS == 0 && _TokenPaddingSize == 0"
        " && _kBlockN == 0 && _GROUPSIZE == 0) { \\",
    ]
    for case in gemm_case:
        for n in n_range:
            args = _instance_args(case, n, dtype_tag)
            lines.append(
                "    }} else if (_M == {M} && _K == {K} && _EXPERTS == {EXPERTS}"
                " && _TokenPaddingSize == {PADDING} && _kBlockN == {N}"
                " && _GROUPSIZE == {GROUPSIZE}) {{ \\".format(**args)
            )
            lines.append("        " + _KERNEL_NAME.format(**args) + "(__VA_ARGS__); \\")
    lines += [
        "    } else { \\",
        '        PADDLE_THROW(phi::errors::Unimplemented('
        '"W4aFp8 not supported m=%d k=%d experts=%d token_padding_size=%d '
        'kBlockN=%d groupsize=%d, please add [%d, %d, %d, %d, %d] to the '
        'gemm_case array in the custom_ops/utils/auto_gen_w4afp8_gemm_kernel.py '
        'file and recompile it\\n", '
        '_M, _K, _EXPERTS, _TokenPaddingSize, _kBlockN, _GROUPSIZE, '
        '_M, _K, _EXPERTS, _TokenPaddingSize, _GROUPSIZE)); \\',
        "    } \\",
        "    }",
    ]
    return "\n".join(lines)


def generate_w4afp8_gemm_kernels():
    """Regenerate every W4AFP8 GEMM source under ``file_dir``.

    1. Resolve the CUTLASS type of every tag in ``dtype``. An unknown tag raises
       ``ValueError`` before any file is touched.
    2. Delete generated ``.cu`` files that are not listed in ``all_cu_files``.
    3. Write ``w4afp8_gemm_template.h`` (declarations plus one
       ``GEMM_SWITCH_<TYPE>`` macro per dtype) and one ``.cu`` definition
       file per instance.
    """
    # Validate dtype tags first so a typo cannot leave the directory half-cleaned.
    cutlass_types = {dtype_tag: get_cutlass_type(dtype_tag) for dtype_tag in dtype}

    _remove_stale_kernel_files(all_cu_files)

    header_path = os.path.join(file_dir, "w4afp8_gemm_template.h")
    with open(header_path, "w", encoding="utf-8") as header:
        header.write(gemm_template_head)
        for args in _iter_instances():
            fmt_args = dict(args, cutlass_type=cutlass_types[args["TYPE"]])
            header.write(gemm_template_case.format(**fmt_args))
            cu_path = os.path.join(file_dir, _KERNEL_NAME.format(**args) + ".cu")
            with open(cu_path, "w", encoding="utf-8") as cu_file:
                cu_file.write(gemm_template_cu_head)
                cu_file.write(gemm_template_cu_template.format(**fmt_args))
        for dtype_tag in dtype:
            header.write("\n")
            header.write(_render_switch_macro(dtype_tag))
        # The last macro line has no newline; end the header with one.
        header.write("\n")


# Generation runs at module level (there is no ``__main__`` guard), so both
# executing and importing this module regenerate the kernels.
generate_w4afp8_gemm_kernels()