# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Generate the fused FP8 GEMM kernel sources."""

import argparse
import os
import re


def get_candidate_tiles():
    """
    Return the list of candidate tile configurations.

    Returns:
        List[Tuple[str, str, str]]: Triples of (thread-block shape, warp shape,
            MMA shape), each written as a CUTLASS GemmShape string.
    """
    base_configs = [("<64, 64, 64>", "<32, 32, 64>", "<16, 8, 32>")]
    base_configs.extend(
        [
            ("<32, 128, 64>", "<32, 32, 64>", "<16, 8, 32>"),
            ("<64, 128, 64>", "<32, 64, 64>", "<16, 8, 32>"),
            ("<64, 64, 128>", "<32, 64, 64>", "<16, 8, 32>"),
            ("<64, 128, 64>", "<64, 32, 64>", "<16, 8, 32>"),
            ("<128, 64, 64>", "<64, 32, 64>", "<16, 8, 32>"),
            ("<128, 128, 64>", "<64, 32, 64>", "<16, 8, 32>"),
            ("<128, 128, 64>", "<64, 64, 64>", "<16, 8, 32>"),
            ("<128, 128, 64>", "<128, 32, 64>", "<16, 8, 32>"),
            ("<128, 256, 64>", "<64, 64, 64>", "<16, 8, 32>"),
            ("<256, 128, 64>", "<64, 64, 64>", "<16, 8, 32>"),
            ("<16, 256, 128>", "<16, 64, 128>", "<16, 8, 32>"),
        ]
    )
    return base_configs


def get_candidate_configs(sm, min_split_k, max_split_k, min_stages, max_stages):
    """
    Return the list of candidate GEMM kernel configurations.

    Args:
        sm (str): Compute capability, e.g. "89".
        min_split_k (int): Minimum split-k value.
        max_split_k (int): Maximum split-k value.
        min_stages (int): Minimum number of pipeline stages.
        max_stages (int): Maximum number of pipeline stages.

    Returns:
        List[Tuple]: One (stages, split_ks, tiles, act_tag, hasbias) tuple per
            activation.
    """
    tiles = get_candidate_tiles()
    candidate_configs = list()
    stages = tuple(range(min_stages, max_stages + 1))
    splitks = tuple(range(min_split_k, max_split_k + 1))
    hasbias = ("false", "true")
    for act_tag in [
        ("noact", "LinearCombination"),
        ("relu", "LinearCombinationRelu"),
        ("gelu", "LinearCombinationGELU"),
    ]:
        candidate_configs.append((stages, splitks, tiles, act_tag, hasbias))
    return candidate_configs


# Header shared by every generated declaration file.
CommonHead = """// Generated by generate_code_gemm_fused_kernels.py - Do not edit.

#pragma once

#include "fp8_gemm_fused/fuse_gemm_{act_tag}_template.h"
"""

CommonTail = """
"""

GemmDeclare = """
template<>
bool dispatch_fuse_gemm_{act_tag}<phi::dtype::{input_type},
                                  phi::dtype::{output_type},
                                  cutlass::gemm::GemmShape{thread_block_shape},
                                  cutlass::gemm::GemmShape{warp_shape},
                                  cutlass::gemm::GemmShape{mma_shape},
                                  {num_stages},
                                  {hasbias},
                                  {SM}>(GemmEpilogueAllParams);
"""

GemmSplitKDeclare = """
template<>
bool dispatch_fuse_gemm_split_k_{act_tag}<phi::dtype::{input_type},
                                          phi::dtype::{output_type},
                                          cutlass::gemm::GemmShape{thread_block_shape},
                                          cutlass::gemm::GemmShape{warp_shape},
                                          cutlass::gemm::GemmShape{mma_shape},
                                          {num_stages},
                                          {hasbias},
                                          {SM}>(GemmEpilogueAllParams);
"""

LaunchGemmHead = """
#pragma once

#include "fp8_gemm_fused/fp8_fp8_gemm_scale_bias_act.h"
"""

LaunchGemmDeclare = """
bool launch_gemm_kernel_{gemm_config}(const int type_id,
                                      const int split_k,
                                      GemmEpilogueAllParams params);
"""

LaunchGemmPart0 = """
#pragma once

#include "launch_gemm_kernel.h"

bool launch_gemm_kernel_{gemm_config}(const int type_id,
                                      const int split_k,
                                      GemmEpilogueAllParams params) {
    if (split_k < 2) {
        params.split_k = 1;
        switch (type_id) {
"""

LaunchGemmPart1 = """
            case {type_id}:
                return dispatch_fuse_gemm_{act_tag}<phi::dtype::{input_type},
                                                    phi::dtype::{output_type},
                                                    cutlass::gemm::GemmShape{thread_block_shape},
                                                    cutlass::gemm::GemmShape{warp_shape},
                                                    cutlass::gemm::GemmShape{mma_shape},
                                                    {num_stages},
                                                    {hasbias},
                                                    {SM}>(params);
                break;
"""

LaunchGemmPart2 = """
            default:
                throw std::runtime_error("cutlass gemm config is invalid.");
                break;
        }
    } else {
        switch (type_id) {
"""

LaunchGemmPart3 = """
            case {type_id}:
                return dispatch_fuse_gemm_split_k_{act_tag}<phi::dtype::{input_type},
                                                            phi::dtype::{output_type},
                                                            cutlass::gemm::GemmShape{thread_block_shape},
                                                            cutlass::gemm::GemmShape{warp_shape},
                                                            cutlass::gemm::GemmShape{mma_shape},
                                                            {num_stages},
                                                            {hasbias},
                                                            {SM}>(params);
                break;
"""

LaunchGemmPart4 = """
            default:
                throw std::runtime_error("cutlass gemm config is invalid.");
                break;
        }
    }
    return false;
}
"""
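# Illustrative note (the exact template-argument spelling above follows the usual
# phi::dtype / cutlass::gemm::GemmShape conventions): for the first candidate tile
# at stage 3, these templates expand to a launcher named
# launch_gemm_kernel_block64x64x64_warp32x32x64_mma16x8x32_stage3, whose switch on
# type_id forwards to the plain specialization when split_k < 2 and to the
# split-k specialization otherwise.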
code_part0 = """// Generated by generate_code_gemm_fused_kernels.py - Do not edit.

#include <iostream>
#include <map>
#include <string>

#include "helper.h"
#include "fp8_fp8_gemm_scale_bias_act.h"
#include "launch_gemm_kernel.h"

COMMON_DECLARE_string(use_cutlass_device_best_config_path);

std::map<std::string, int> gemm_type_map{"""

code_part1 = """
    {"{input_type}_{output_type}_{hasbias}_{act_tag}", {type_id}},"""

code_part2 = """
};

std::map<std::string, int> gemm_config_map{"""

code_part3 = """
    {"{thread_block_shape}, {warp_shape}, {mma_shape}, {num_stages}", {tile_id}},"""

code_part4 = """
};

bool launch_gemm_kernel(const int type_id,
                        const int split_k,
                        const int kernel_id,
                        GemmEpilogueAllParams params) {
    switch (kernel_id) {"""

code_part5 = """
        case {tile_id}:
            return launch_gemm_kernel_{gemm_config}(type_id, split_k, params);
            break;"""

code_part6 = """
        default:
            throw std::runtime_error("fp8_fp8_bf16_gemm_fused Config is invalid.");
            break;
    }
    return false;
}

bool fp8_fp8_gemm_scale_bias_act(GemmEpilogueAllParams params) {
    if (gemm_type_map.find(params.fuse_gemm_config) == gemm_type_map.end()) {
        throw std::runtime_error("fp8 gemm_fused config is invalid.");
    }
    int type_id = gemm_type_map[params.fuse_gemm_config];

    // Round M up to a multiple of 32 so nearby problem sizes share a tuning entry.
    int M = (params.M + 31) / 32 * 32;
    int N = params.N;
    int K = params.K;
    std::string mnk_string = "gemm<" + std::to_string(M) + ", " + std::to_string(N) +
                             ", " + std::to_string(K) + ">";
    auto encoded_mnk_string = base64_encode(mnk_string);
    std::string mnk_split_k_string = mnk_string + ", split_k";
    auto encoded_mnk_split_k_string = base64_encode(mnk_split_k_string);

    int split_k;
    int kernel_id;
    std::string best_config;
    CutlassGemmConfigMannager& best_config_manager = CutlassGemmConfigMannager::getInstance();

    if (getenv("FLAGS_use_cutlass_device_best_config_path")) {
        // Run with the recorded best config.
        std::string config_file_path = getenv("FLAGS_use_cutlass_device_best_config_path");
        nlohmann::json* config_json = new nlohmann::json();
        if (config_file_path != "default") {
            config_json = best_config_manager.get_gemm_best_configs(config_file_path);
        }
        best_config = get_relative_best(config_json, encoded_mnk_string,
                                        "<64, 64, 64>, <32, 32, 64>, <16, 8, 32>, 3");
        split_k = get_relative_best(config_json, encoded_mnk_split_k_string, 1);
        if (gemm_config_map.find(best_config) == gemm_config_map.end()) {
            throw std::runtime_error(
                "The kernel for this config was not generated; please check "
                "generate_code_gemm_fused_kernels.py and re-generate.");
        } else {
            kernel_id = gemm_config_map[best_config];
        }
        return launch_gemm_kernel(type_id, split_k, kernel_id, params);
    } else {
        // Tune: time every (split_k, kernel_id) combination and keep the fastest.
        int warm_up_times = 5;
        int tune_times = 10;
        std::string best_kernel_id = "";
        int best_split_k = -1;
        float duration = 1000000.f;
        for (int i = 1; i < {max_split_k} + 1; ++i) {  // all split_k values
            for (const auto& config_pair : gemm_config_map) {
                bool is_valid = true;
                // Warm up.
                for (int num_time = 0; num_time < warm_up_times; ++num_time) {
                    if (!launch_gemm_kernel(type_id, i, config_pair.second, params)) {
                        is_valid = false;
                        break;
                    }
                }
                if (!is_valid) {
                    continue;
                }
                cudaEvent_t start, stop;
                cudaEventCreate(&start);
                cudaEventCreate(&stop);
                cudaStreamSynchronize(params.stream);
                cudaEventRecord(start, params.stream);
                for (int num_time = 0; num_time < tune_times; ++num_time) {
                    if (!launch_gemm_kernel(type_id, i, config_pair.second, params)) {
                        is_valid = false;
                        break;
                    }
                }
                cudaEventRecord(stop, params.stream);
                cudaEventSynchronize(stop);
                float elapsedTime = 0.f;
                if (is_valid) {
                    cudaEventElapsedTime(&elapsedTime, start, stop);
                }
                // Always destroy the events, even for invalid configs.
                cudaEventDestroy(start);
                cudaEventDestroy(stop);
                if (!is_valid) {
                    continue;
                }
                if (elapsedTime < duration) {
                    best_kernel_id = config_pair.first;
                    best_split_k = i;
                    duration = elapsedTime;
                }
            }
        }
        nlohmann::json new_json;
        new_json[encoded_mnk_string] = best_kernel_id;
        new_json[encoded_mnk_split_k_string] = best_split_k;
        best_config_manager.up_date_configs(new_json);
        std::cout << "Gemm tune result for " << mnk_string << ": best config is: "
                  << best_kernel_id << ", split k: " << best_split_k << std::endl;
        return true;
    }
}
"""
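# Illustrative note: when FLAGS_use_cutlass_device_best_config_path is unset, the
# generated fp8_fp8_gemm_scale_bias_act tunes every (split_k, tile) pair and hands
# up_date_configs JSON entries keyed by base64("gemm<M, N, K>") (best tile config
# string) and base64("gemm<M, N, K>, split_k") (best split-k); later runs read
# them back through get_gemm_best_configs and get_relative_best.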
def SubstituteTemplate(template, values):
    """Substitute every "{key}" placeholder in `template` with values[key],
    repeating until the text stops changing."""
    text = template
    changed = True
    while changed:
        changed = False
        for key, value in values.items():
            regex = f"\\{{{key}\\}}"
            newtext = re.sub(regex, value, text)
            if newtext != text:
                changed = True
                text = newtext
    return text
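# Example (illustrative):
#   SubstituteTemplate("fuse_gemm_{act_tag}_template.h", {"act_tag": "relu"})
#   -> "fuse_gemm_relu_template.h"
# Because substitution loops until a fixpoint, a substituted value may itself
# contain further "{key}" placeholders that are resolved on a later pass.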
def parse_args():
    """Parse the command-line arguments of the code generator."""
    parser = argparse.ArgumentParser(
        description="The argument for generating the generic_mixed_gemm_kernelLauncher instance."
    )
    parser.add_argument(
        "--cuda_arch",
        type=str,
        nargs="+",
        default=["89"],
        help="The CUDA architecture to be generated.",
    )
    parser.add_argument(
        "--min_split_k",
        type=int,
        default=2,
        help="The min split k for the gemm kernel.",
    )
    parser.add_argument(
        "--max_split_k",
        type=int,
        default=6,
        help="The max split k for the gemm kernel.",
    )
    parser.add_argument(
        "--min_stages",
        type=int,
        default=2,
        help="The min stages for the gemm kernel.",
    )
    parser.add_argument(
        "--max_stages",
        type=int,
        default=8,
        help="The max stages for the gemm kernel.",
    )
    args = parser.parse_args()
    return args


# Generate the declaration source (.cu) for one activation tag.
def generate_source_cu(
    inputs_type: tuple,
    outputs_type: tuple,
    stages: tuple,
    tiles: list,
    act_tag: str,
    hasbiases: tuple,
    sm: str,
):
    """Generate the .cu source declaring every kernel specialization."""
    value_dict = {
        "act_tag": act_tag,
    }
    all_code = SubstituteTemplate(CommonHead, value_dict)
    # Plain GEMM declarations.
    for input_type in inputs_type:
        for output_type in outputs_type:
            for stage in stages:
                for hasbias in hasbiases:
                    for tile_config in tiles:
                        value_dict = {
                            "input_type": input_type,
                            "output_type": output_type,
                            "thread_block_shape": tile_config[0],
                            "warp_shape": tile_config[1],
                            "mma_shape": tile_config[2],
                            "num_stages": str(stage),
                            "act_tag": act_tag,
                            "hasbias": hasbias,
                            "SM": sm,
                        }
                        all_code += SubstituteTemplate(GemmDeclare, value_dict)
    # Split-k GEMM declarations.
    for input_type in inputs_type:
        for output_type in outputs_type:
            for stage in stages:
                for hasbias in hasbiases:
                    for tile_config in tiles:
                        value_dict = {
                            "input_type": input_type,
                            "output_type": output_type,
                            "thread_block_shape": tile_config[0],
                            "warp_shape": tile_config[1],
                            "mma_shape": tile_config[2],
                            "num_stages": str(stage),
                            "act_tag": act_tag,
                            "hasbias": hasbias,
                            "SM": sm,
                        }
                        all_code += SubstituteTemplate(GemmSplitKDeclare, value_dict)
    all_code += CommonTail
    return all_code


# Generate one launch_gemm_kernel_<config>.cu per (tile, stage) pair plus the
# shared launch_gemm_kernel.h.
def generate_launch_gemm_cus(
    generate_dir: str,
    inputs_type: tuple,
    outputs_type: tuple,
    stages: tuple,
    split_ks: tuple,
    tiles: list,
    act_tags: list,
    hasbiases: tuple,
    sm: str,
    min_split_k: int,
    max_split_k: int,
):
    """Write launch_gemm_kernel.h and one .cu per gemm config; return the
    header code and a map from config name to source code."""
    code_map = {}
    head_path = os.path.join(generate_dir, "launch_gemm_kernel.h")
    head_all_code = LaunchGemmHead
    for tile in tiles:
        blocks, warps, mmas = [s.replace(" ", "").strip("<>").split(",") for s in tile]
        gemm_config = (
            f"block{blocks[0]}x{blocks[1]}x{blocks[2]}"
            f"_warp{warps[0]}x{warps[1]}x{warps[2]}"
            f"_mma{mmas[0]}x{mmas[1]}x{mmas[2]}"
        )
        for stage in stages:
            gemm_config_str = gemm_config + f"_stage{stage}"
            value_dict = {
                "gemm_config": gemm_config_str,
            }
            head_all_code += SubstituteTemplate(LaunchGemmDeclare, value_dict)
    os.makedirs(generate_dir, exist_ok=True)
    with open(head_path, "w") as f:
        f.write(head_all_code)

    for tile in tiles:
        blocks, warps, mmas = [s.replace(" ", "").strip("<>").split(",") for s in tile]
        gemm_config = (
            f"block{blocks[0]}x{blocks[1]}x{blocks[2]}"
            f"_warp{warps[0]}x{warps[1]}x{warps[2]}"
            f"_mma{mmas[0]}x{mmas[1]}x{mmas[2]}"
        )
        for stage in stages:
            gemm_config_str = gemm_config + f"_stage{stage}"
            value_dict = {
                "gemm_config": gemm_config_str,
            }
            source_all_code = SubstituteTemplate(LaunchGemmPart0, value_dict)
            split_k_code = ""
            type_id = 0
            for input_type in inputs_type:
                for output_type in outputs_type:
                    for act_tag in act_tags:
                        for hasbias in hasbiases:
                            value_dict = {
                                "act_tag": act_tag,
                                "input_type": input_type,
                                "output_type": output_type,
                                "hasbias": hasbias,
                                "type_id": str(type_id),
                                "thread_block_shape": tile[0],
                                "warp_shape": tile[1],
                                "mma_shape": tile[2],
                                "num_stages": str(stage),
                                "SM": sm,
                            }
                            source_all_code += SubstituteTemplate(
                                LaunchGemmPart1, value_dict
                            )
                            split_k_code += SubstituteTemplate(
                                LaunchGemmPart3, value_dict
                            )
                            type_id += 1
            source_all_code += LaunchGemmPart2
            source_all_code += split_k_code
            source_all_code += LaunchGemmPart4
            code_map[gemm_config_str] = source_all_code
            source_path = os.path.join(
                generate_dir, f"launch_gemm_kernel_{gemm_config_str}.cu"
            )
            with open(source_path, "w") as f:
                f.write(source_all_code)
    return head_all_code, code_map


# Generate fp8_fp8_gemm_scale_bias_act.cu, the top-level dispatch/tuning source.
def generate_dispatch_gemm_cu(
    inputs_type: tuple,
    outputs_type: tuple,
    stages: tuple,
    split_ks: tuple,
    tiles: list,
    act_tags: list,
    hasbiases: tuple,
    sm: str,
    min_split_k: int,
    max_split_k: int,
):
    """Generate the dispatch source mapping fuse_gemm_config strings and tile
    configs to kernel ids, plus the launch/tune entry point."""
    all_code = code_part0
    type_id = 0
    for input_type in inputs_type:
        for output_type in outputs_type:
            for act_tag in act_tags:
                for hasbias in hasbiases:
                    value_dict = {
                        "act_tag": act_tag,
                        "input_type": input_type,
                        "output_type": output_type,
                        "hasbias": hasbias,
                        "type_id": str(type_id),
                    }
                    all_code += SubstituteTemplate(code_part1, value_dict)
                    type_id += 1
    all_code += code_part2

    tile_id = 0
    for tile in tiles:
        for stage in stages:
            value_dict = {
                "thread_block_shape": tile[0],
                "warp_shape": tile[1],
                "mma_shape": tile[2],
                "num_stages": str(stage),
                "tile_id": str(tile_id),
            }
            all_code += SubstituteTemplate(code_part3, value_dict)
            tile_id += 1
    all_code += code_part4

    tile_id = 0
    for tile in tiles:
        blocks, warps, mmas = [s.replace(" ", "").strip("<>").split(",") for s in tile]
        gemm_config = (
            f"block{blocks[0]}x{blocks[1]}x{blocks[2]}"
            f"_warp{warps[0]}x{warps[1]}x{warps[2]}"
            f"_mma{mmas[0]}x{mmas[1]}x{mmas[2]}"
        )
        for stage in stages:
            gemm_config_str = gemm_config + f"_stage{stage}"
            value_dict = {
                "tile_id": str(tile_id),
                "gemm_config": gemm_config_str,
            }
            all_code += SubstituteTemplate(code_part5, value_dict)
            tile_id += 1
    value_dict.update(
        {
            "min_split_k": str(min_split_k),
            "max_split_k": str(max_split_k),
        }
    )
    all_code += SubstituteTemplate(code_part6, value_dict)
    return all_code
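# Illustrative note: with the default type tuples, the first gemm_type_map entry
# emitted is {"float8_e4m3fn_float16_false_noact", 0}, and the first
# gemm_config_map entry is {"<64, 64, 64>, <32, 32, 64>, <16, 8, 32>, 2", 0}
# (the first candidate tile at min_stages=2).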
if __name__ == "__main__":
    args = parse_args()
    archs = args.cuda_arch
    min_split_k = args.min_split_k
    max_split_k = args.max_split_k
    min_stages = args.min_stages
    max_stages = args.max_stages
    inputs_type = ("float8_e4m3fn", "float8_e5m2")
    outputs_type = ("float16", "bfloat16")
    sm_dict = {"89": "cutlass::arch::Sm89", "90": "cutlass::arch::Sm90"}

    for sm in archs:
        if sm == "89":
            fuse_gemm_configs = get_candidate_configs(
                sm, min_split_k, max_split_k, min_stages, max_stages
            )
            for fuse_gemm_config in fuse_gemm_configs:
                file_name = (
                    f"gpu_ops/cutlass_kernels/fp8_gemm_fused/autogen/"
                    f"generic_gemm_kernel_sm{sm}_{fuse_gemm_config[3][0]}.cu"
                )
                all_code = generate_source_cu(
                    inputs_type,
                    outputs_type,
                    fuse_gemm_config[0],
                    fuse_gemm_config[2],
                    fuse_gemm_config[3][0],
                    fuse_gemm_config[4],
                    sm_dict[sm],
                )
                file_dir = os.path.dirname(file_name)
                os.makedirs(file_dir, exist_ok=True)
                with open(file_name, "w") as f:
                    f.write(all_code)

            fuse_gemm_config = list(fuse_gemm_configs)[0]
            act_tags = ["noact", "relu", "gelu"]
            # One .cu per gemm config, so the kernels compile in parallel.
            generate_launch_gemm_cus(
                "gpu_ops/cutlass_kernels/fp8_gemm_fused/autogen",
                inputs_type,
                outputs_type,
                fuse_gemm_config[0],
                fuse_gemm_config[1],
                fuse_gemm_config[2],
                act_tags,
                fuse_gemm_config[4],
                sm_dict[sm],
                min_split_k,
                max_split_k,
            )

            # The act_tag list is hard-coded to cover all generated activations.
            file_name = (
                "gpu_ops/cutlass_kernels/fp8_gemm_fused/fp8_fp8_gemm_scale_bias_act.cu"
            )
            all_code = generate_dispatch_gemm_cu(
                inputs_type,
                outputs_type,
                fuse_gemm_config[0],
                fuse_gemm_config[1],
                fuse_gemm_config[2],
                act_tags,
                fuse_gemm_config[4],
                sm_dict[sm],
                min_split_k,
                max_split_k,
            )
            file_dir = os.path.dirname(file_name)
            os.makedirs(file_dir, exist_ok=True)
            with open(file_name, "w") as f:
                f.write(all_code)
        elif sm == "90":
            # `sm` is a string, so the comparison must be against "90".
            print("Not supported yet.")
            exit(0)
        else:
            raise ValueError(f"Unsupported SM: {sm}")
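# Example invocation (matches the defaults above):
#   python generate_code_gemm_fused_kernels.py --cuda_arch 89 \
#       --min_split_k 2 --max_split_k 6 --min_stages 2 --max_stages 8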