Sync v2.0 version of code to github repo

This commit is contained in:
Jiang-Jia-Jun
2025-06-29 23:29:37 +00:00
parent d151496038
commit 92c2cfa2e7
597 changed files with 78776 additions and 22905 deletions

View File

@@ -0,0 +1,631 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""generate cutlass_fp8_fp8_half_block_gemm_fused code."""
import argparse
import copy
import os
import re
def get_candidate_tiles():
"""
get_candidate_tiles returns a list of candidate tiles.
"""
cta_shape = [
("<_128, _128, _128>"),
# ("<_256, _128, _128>"),
]
cluster_shape = [
("<_1, _1, _1>"),
("<_2, _1, _1>"),
("<_1, _2, _1>"),
("<_2, _2, _1>"),
# ("<_1, _8, _1>"),
# ("<_8, _1, _1>"),
]
base_configs = [(x, y) for x in cta_shape for y in cluster_shape]
return base_configs
def get_candidate_configs(sm):
"""
get_candidate_configs returns a list of candidate configs.
"""
tiles = get_candidate_tiles()
candidate_configs = list()
hasbias = ("false", "true")
KernelSchedule = (
"KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum<1>", )
EpilogueSchedule = ("TmaWarpSpecializedCooperative", )
TileSchedule = ("PersistentScheduler", "StreamKScheduler")
for act_tag in [
("noact", "Identity"),
# ("relu", "ReLu"),
# ("gelu", "GELU"),
]:
candidate_configs.extend([(hasbias, act_tag, tiles, KernelSchedule,
EpilogueSchedule, TileSchedule)])
return candidate_configs
def get_shape_str(tile_shape):
"""
return tile_shape string.
"""
blocks, clusters = [
s.replace(" ", "").strip("<>").split(",") for s in tile_shape
]
blocks = [elem.strip("_") for elem in blocks]
clusters = [elem.strip("_") for elem in clusters]
return blocks, clusters
def check_config_valid(tile_shape, kernel_schedule, epilogue_schedule,
tile_schedule):
"""
check the cutlass config valid.
"""
blocks, clusters = get_shape_str(tile_shape)
if int(blocks[0]) < 128 and "Cooperative" in kernel_schedule:
return False
return True
# this is a file's header part
CommonHead = """// Generated by auto_gen_fp8_fp8_gemm_fused_kernels_sm90.py - Do not edit.
#pragma once
#include "fp8_gemm_fused/fuse_block_gemm_act_template_3x.h"
"""
GemmDeclare = """
template<>
bool dispatch_fuse_block_gemm_c3x<phi::dtype::{input_type}, phi::dtype::{output_type},
{hasbias},
cutlass::epilogue::thread::{Activation},
Shape{TileShape},
Shape{ClusterShape},
cutlass::gemm::{KernelSchedule},
cutlass::epilogue::{EpilogueSchedule},
cutlass::gemm::{TileSchedule},
{SM}
>(GemmEpilogueAllParams);
"""
LaunchGemmHead = """
#pragma once
#include "fp8_gemm_fused/fp8_fp8_gemm_scale_bias_act.h"
"""
LaunchGemmDeclare = """
bool launch_block_gemm_kernel_sm{sm}_{gemm_config}(const int type_id, GemmEpilogueAllParams params);
"""
LaunchGemmPart0 = """
#pragma once
#include "launch_block_gemm_kernel_sm{sm}.h"
bool launch_block_gemm_kernel_sm{sm}_{gemm_config}(const int type_id, GemmEpilogueAllParams params){
switch (type_id) {
"""
LaunchGemmPart1 = """
case {type_id}:
return dispatch_fuse_block_gemm_c3x<phi::dtype::{input_type}, phi::dtype::{output_type},
{hasbias},
cutlass::epilogue::thread::{Activation},
Shape{TileShape},
Shape{ClusterShape},
cutlass::gemm::{KernelSchedule},
cutlass::epilogue::{EpilogueSchedule},
cutlass::gemm::{TileSchedule},
{SM}
>(params);
break;
"""
LaunchGemmPart2 = """
default:
throw std::runtime_error("cutlass gemm config is invalid.");
break;
}
return false;
}
"""
code_part0 = """// Generated by auto_gen_fp8_fp8_gemm_fused_kernels_sm90.py - Do not edit.
#include <map>
#include "fp8_gemm_fused/fp8_fp8_gemm_scale_bias_act.h"
#include "launch_block_gemm_kernel_sm{sm}.h"
COMMON_DECLARE_string(use_cutlass_device_best_config_path);
std::map<std::string, int> block_gemm_type_map{"""
code_part1 = """
{"{input_type}_{output_type}_{hasbias}_{act_tag}", {type_id}}, """
code_part2 = """
};
std::map<std::string, int> block_gemm_config_map{
"""
code_part3 = """ {"{TileShape}, {ClusterShape}, {kernel_schedule}, {epilogue_schedule}, {tile_schedule}", {tile_id}},
"""
code_part4 = """};
bool launch_block_gemm_kernel(const int type_id, const int kernel_id, GemmEpilogueAllParams params){
switch (kernel_id) {"""
code_part5 = """
case {tile_id}:
return launch_block_gemm_kernel_sm{sm}_{gemm_config}(type_id, params);
break;
"""
code_part6 = """
default:
throw std::runtime_error("fp8 gemm_fused Config is invalid.");
break;
}
return false;
}
template <typename T>
T get_relative_best(nlohmann::json* json_data,
const std::string& target_key,
const int& m,
const int& n,
const int& k) {
if (json_data->contains(target_key)) {
return json_data->at(target_key);
} else {
if (k > 3 * n){
return "<_128, _128, _128>, <_1, _2, _1>, KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum<1>, TmaWarpSpecializedCooperative, StreamKScheduler";
}else{
return "<_128, _128, _128>, <_1, _2, _1>, KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum<1>, TmaWarpSpecializedCooperative, PersistentScheduler";
}
}
}
bool fp8_fp8_block_gemm_scale_bias_act(GemmEpilogueAllParams params) {
if (block_gemm_type_map.find(params.fuse_gemm_config) == block_gemm_type_map.end()) {
throw std::runtime_error("fp8 gemm_fused config is invalid.");
}
int type_id = block_gemm_type_map[params.fuse_gemm_config];
int M = (params.M + 31) / 32 * 32;
int N = params.N;
int K = params.K;
int kernel_id;
std::string mnk_string = "block_gemm_sm90<"+ std::to_string(M)+ ", " +std::to_string(N) + ", "+ std::to_string(K)+ ">";
std::string best_config;
CutlassGemmConfigMannager& best_config_mannager = CutlassGemmConfigMannager::getInstance();
char *config_file_path_c_str = getenv("FLAGS_use_cutlass_device_best_config_path");
std::string config_file_path = config_file_path_c_str == nullptr ? "" : std::string(config_file_path_c_str);
if(config_file_path == "tune"){ // tune kernel
int warm_up_times = 5;
int tune_times = 10;
std::string best_kernel_id = "";
float duratation = 1000000.f;
// tune all kernel_id kernels
for(const auto& config_pair : block_gemm_config_map){
std::cout << "Running tune kernel: " << config_pair.first<< std::endl;
bool is_valid = true;
// warm up
for(int num_time = 0; num_time < warm_up_times; ++num_time){
if(!launch_block_gemm_kernel(type_id, config_pair.second, params)){
is_valid = false;
break;
}
}
if(!is_valid){
continue;
}
cudaEvent_t start, stop;
cudaEventCreate(&start);
cudaEventCreate(&stop);
cudaStreamSynchronize(params.stream);
cudaEventRecord(start, params.stream);
for(int num_time = 0; num_time < tune_times; ++num_time){
if(!launch_block_gemm_kernel(type_id, config_pair.second, params)){
is_valid = false;
break;
};
}
cudaEventRecord(stop, params.stream);
cudaEventSynchronize(stop);
float elapsedTime;
if(is_valid){
cudaEventElapsedTime(&elapsedTime, start, stop);
} else {
continue;
}
cudaEventDestroy(start);
cudaEventDestroy(stop);
if(elapsedTime < duratation){
best_kernel_id = config_pair.first;
duratation = elapsedTime;
}
}
nlohmann::json new_json;
new_json[mnk_string] = best_kernel_id;
best_config_mannager.up_date_configs(new_json);
std::cout <<"Gemm tune result for " << mnk_string<< ": best config is: "<< best_kernel_id << std::endl;
return true;
} else { // run kernel
nlohmann::json* config_json = new nlohmann::json();
if (config_file_path != "" && config_file_path != "default") {
config_json = best_config_mannager.get_gemm_best_configs(config_file_path);
}
best_config = get_relative_best<std::string>(config_json, mnk_string, M, N, K);
if (block_gemm_config_map.find(best_config) == block_gemm_config_map.end()) {
throw std::runtime_error("This config'kernel not be generate, please check auto_gen_fp8_fp8_gemm_fused_kernels_sm90.py and re-generate.");
} else {
kernel_id = block_gemm_config_map[best_config];
}
return launch_block_gemm_kernel(type_id, kernel_id, params);
}
}
"""
def SubstituteTemplate(template, values_base):
"""
SubstituteTemplate
"""
values = copy.deepcopy(values_base)
if values.get("KernelSchedule"
) is not None and "Auto" in values["KernelSchedule"]:
values["KernelSchedule"] = "collective::" + values["KernelSchedule"]
if values.get("EpilogueSchedule"
) is not None and "Auto" in values["EpilogueSchedule"]:
values[
"EpilogueSchedule"] = "collective::" + values["EpilogueSchedule"]
text = template
changed = True
while changed:
changed = False
for key, value in values.items():
regex = f"\\{{{key}\\}}"
newtext = re.sub(regex, value, text)
if newtext != text:
changed = True
text = newtext
return text
def parse_args():
"""
parse_args
"""
parser = argparse.ArgumentParser(
description=
"The argument for generating the generic_mixed_gemm_kernelLauncher instance."
)
parser.add_argument(
"--cuda_arch",
type=str,
nargs="+",
default=["90"],
help="The CUDA architecture to be generated.",
)
args = parser.parse_args()
return args
# generate source .cu
def generate_source_cu(
inputs_type: (str),
outputs_type: (str),
hasbiases: (str),
act_tag: (str),
tiles: (str),
KernelSchedule: (str),
EpilogueSchedule: (str),
TileSchedule: (str),
sm: str,
):
"""
generate_source_cu
"""
all_code = CommonHead
for input_type in inputs_type:
for output_type in outputs_type:
for hasbias in hasbiases:
for tile_config in tiles:
for kernel_schedule in KernelSchedule:
for epilogue_schedule in EpilogueSchedule:
for tile_schedule in TileSchedule:
if not check_config_valid(
tile_config, kernel_schedule,
epilogue_schedule, tile_schedule):
continue
value_dict = {
"input_type": input_type,
"output_type": output_type,
"hasbias": hasbias,
"Activation": act_tag[1],
"TileShape": tile_config[0],
"ClusterShape": tile_config[1],
"KernelSchedule": kernel_schedule,
"EpilogueSchedule": epilogue_schedule,
"TileSchedule": tile_schedule,
"SM": sm,
"sm": sm[-2:],
}
all_code += SubstituteTemplate(
GemmDeclare, value_dict)
return all_code
# generate gemm launch .cu
def generate_launch_gemm_cus(
generate_dir: (str), inputs_type: (str), outputs_type: (str),
fuse_gemm_configs: tuple, sm: str):
"""
generate_launch_gemm_cus
"""
act_tags = [single_config[1] for single_config in fuse_gemm_configs]
single_config = fuse_gemm_configs[0]
hasbiases: (str) = single_config[0]
tiles: (str) = single_config[2]
KernelSchedule: (str) = single_config[3]
EpilogueSchedule: (str) = single_config[4]
TileSchedule: (str) = single_config[5]
code_map = {}
head_path = os.path.join(generate_dir,
f"launch_block_gemm_kernel_sm{sm[-2:]}.h")
head_all_code = LaunchGemmHead
for tile_config in tiles:
blocks, clusters = get_shape_str(tile_config)
gemm_config_str_0 = f"tile{blocks[0]}x{blocks[1]}x{blocks[2]}_cluster{clusters[0]}x{clusters[1]}x{clusters[2]}"
for kernel_schedule in KernelSchedule:
gemm_config_str_1 = gemm_config_str_0 + f"_{kernel_schedule}"
for epilogue_schedule in EpilogueSchedule:
gemm_config_str_2 = gemm_config_str_1 + f"_{epilogue_schedule}"
for tile_schedule in TileSchedule:
if not check_config_valid(tile_config, kernel_schedule,
epilogue_schedule,
tile_schedule):
continue
gemm_config_str = gemm_config_str_2 + f"_{tile_schedule}"
value_dict = {
"sm":
sm[-2:],
"gemm_config":
gemm_config_str.replace("<", "").replace(">", ""),
}
head_all_code += SubstituteTemplate(
LaunchGemmDeclare, value_dict)
os.makedirs(generate_dir, exist_ok=True)
with open(head_path, "w") as f:
f.write(head_all_code)
f.close()
for tile_shape in tiles:
blocks, clusters = get_shape_str(tile_shape)
gemm_config_str_0 = f"tile{blocks[0]}x{blocks[1]}x{blocks[2]}_cluster{clusters[0]}x{clusters[1]}x{clusters[2]}"
for kernel_schedule in KernelSchedule:
gemm_config_str_1 = gemm_config_str_0 + f"_{kernel_schedule}"
for epilogue_schedule in EpilogueSchedule:
gemm_config_str_2 = gemm_config_str_1 + f"_{epilogue_schedule}"
for tile_schedule in TileSchedule:
if not check_config_valid(tile_shape, kernel_schedule,
epilogue_schedule,
tile_schedule):
continue
gemm_config_str = gemm_config_str_2 + f"_{tile_schedule}"
value_dict = {
"sm":
sm[-2:],
"gemm_config":
gemm_config_str.replace("<", "").replace(">", ""),
}
source_all_code = SubstituteTemplate(
LaunchGemmPart0, value_dict)
type_id = 0
for input_type in inputs_type:
for output_type in outputs_type:
for act_tag in act_tags:
for hasbias in hasbiases:
value_dict = {
"type_id": str(type_id),
"input_type": input_type,
"output_type": output_type,
"hasbias": hasbias,
"Activation": act_tag[1],
"TileShape": tile_shape[0],
"ClusterShape": tile_shape[1],
"KernelSchedule": kernel_schedule,
"EpilogueSchedule": epilogue_schedule,
"TileSchedule": tile_schedule,
"SM": sm,
"sm": sm[-2:],
}
source_all_code += SubstituteTemplate(
LaunchGemmPart1, value_dict)
type_id += 1
source_all_code += LaunchGemmPart2
gemm_config_str = gemm_config_str.replace("<", "").replace(
">", "")
code_map[gemm_config_str] = source_all_code
source_path = os.path.join(
generate_dir,
f"launch_block_gemm_kernel_sm{sm[-2:]}_{gemm_config_str}.cu"
)
with open(source_path, "w") as f:
f.write(source_all_code)
f.close()
return head_all_code, code_map
# generate fp8_fp8_gemm_scale_bias_act_sm90.cu
def generate_dispatch_gemm_cu(inputs_type: (str), outputs_type: (str),
fuse_gemm_configs: tuple, sm: str):
"""
generate_dispatch_gemm_cu
"""
act_tags = [single_config[1] for single_config in fuse_gemm_configs]
single_config = fuse_gemm_configs[0]
hasbiases: (str) = single_config[0]
tiles: (str) = single_config[2]
KernelSchedule: (str) = single_config[3]
EpilogueSchedule: (str) = single_config[4]
TileSchedule: (str) = single_config[5]
all_code = SubstituteTemplate(code_part0, {"sm": sm[-2:]})
type_id = 0
for input_type in inputs_type:
for output_type in outputs_type:
for act_tag in act_tags:
for hasbias in hasbiases:
value_dict = {
"act_tag": act_tag[0],
"input_type": input_type,
"output_type": output_type,
"hasbias": hasbias,
"type_id": str(type_id),
}
all_code += SubstituteTemplate(code_part1, value_dict)
type_id += 1
all_code += code_part2
tile_id = 0
for tile_shape in tiles:
for kernel_schedule in KernelSchedule:
for epilogue_schedule in EpilogueSchedule:
for tile_schedule in TileSchedule:
if not check_config_valid(tile_shape, kernel_schedule,
epilogue_schedule,
tile_schedule):
continue
value_dict = {
"TileShape": tile_shape[0],
"ClusterShape": tile_shape[1],
"kernel_schedule": kernel_schedule,
"epilogue_schedule": epilogue_schedule,
"tile_schedule": tile_schedule,
"tile_id": str(tile_id),
}
all_code += SubstituteTemplate(code_part3, value_dict)
tile_id += 1
all_code += SubstituteTemplate(code_part4, {"sm": sm[-2:]})
tile_id = 0
for tile_shape in tiles:
blocks, clusters = get_shape_str(tile_shape)
gemm_config_str_0 = f"tile{blocks[0]}x{blocks[1]}x{blocks[2]}_cluster{clusters[0]}x{clusters[1]}x{clusters[2]}"
for kernel_schedule in KernelSchedule:
gemm_config_str_1 = gemm_config_str_0 + f"_{kernel_schedule}"
for epilogue_schedule in EpilogueSchedule:
gemm_config_str_2 = gemm_config_str_1 + f"_{epilogue_schedule}"
for tile_schedule in TileSchedule:
if not check_config_valid(tile_shape, kernel_schedule,
epilogue_schedule,
tile_schedule):
continue
gemm_config_str = gemm_config_str_2 + f"_{tile_schedule}"
value_dict = {
"sm":
sm[-2:],
"tile_id":
str(tile_id),
"gemm_config":
gemm_config_str.replace("<", "").replace(">", ""),
}
all_code += SubstituteTemplate(code_part5, value_dict)
tile_id += 1
all_code += SubstituteTemplate(code_part6, {"sm": sm[-2:]})
return all_code
if __name__ == "__main__":
args = parse_args()
archs = args.cuda_arch
inputs_type = (
"float8_e4m3fn",
"float8_e5m2",
)
outputs_type = ("float16", "bfloat16")
sm_dict = {"90": "cutlass::arch::Sm90"}
for sm in archs:
if sm == "90":
fuse_gemm_configs = get_candidate_configs(sm)
for fuse_gemm_config in fuse_gemm_configs:
file_name = (
f"gpu_ops/cutlass_kernels/fp8_gemm_fused/autogen/"
f"generic_block_gemm_kernel_sm{sm}_{fuse_gemm_config[1][0]}.cu"
)
all_code = generate_source_cu(
inputs_type,
outputs_type,
fuse_gemm_config[0],
fuse_gemm_config[1],
fuse_gemm_config[2],
fuse_gemm_config[3],
fuse_gemm_config[4],
fuse_gemm_config[5],
sm_dict[sm],
)
file_dir = os.path.dirname(file_name)
os.makedirs(file_dir, exist_ok=True)
with open(file_name, "w") as f:
f.write(all_code)
f.close()
# Compile parallelization
generate_launch_gemm_cus(
"gpu_ops/cutlass_kernels/fp8_gemm_fused/autogen", inputs_type,
outputs_type, fuse_gemm_configs, sm_dict[sm])
# hard code for act_tag
file_name = (f"gpu_ops/cutlass_kernels/fp8_gemm_fused/autogen/"
f"fp8_fp8_block_gemm_scale_bias_act_sm{sm}.cu")
all_code = generate_dispatch_gemm_cu(
inputs_type,
outputs_type,
fuse_gemm_configs,
sm_dict[sm],
)
file_dir = os.path.dirname(file_name)
os.makedirs(file_dir, exist_ok=True)
with open(file_name, "w") as f:
f.write(all_code)
f.close()
else:
raise ValueError(f"Unsupported SM: {sm}")

View File

@@ -0,0 +1,673 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""generate dual_gemm_fused_kernel code."""
import argparse
import os
import re
def get_candidate_tiles():
"""
get_candidate_tiles returns a list of candidate tiles for the dual_gemm_fused_kernel.
"""
base_configs = [("<64, 64, 64>", "<32, 32, 64>", "<16, 8, 32>")]
base_configs.extend([
("<16, 32, 64>", "<16, 32, 64>", "<16, 8, 32>"),
("<16, 64, 64>", "<16, 32, 64>", "<16, 8, 32>"),
("<32, 128, 64>", "<32, 32, 64>", "<16, 8, 32>"),
("<64, 128, 64>", "<32, 64, 64>", "<16, 8, 32>"),
("<64, 64, 128>", "<32, 64, 64>", "<16, 8, 32>"),
("<64, 128, 64>", "<64, 32, 64>", "<16, 8, 32>"),
("<128, 64, 64>", "<64, 32, 64>", "<16, 8, 32>"),
("<128, 128, 64>", "<64, 32, 64>", "<16, 8, 32>"),
("<128, 128, 64>", "<64, 64, 64>", "<16, 8, 32>"),
("<128, 128, 64>", "<128, 32, 64>", "<16, 8, 32>"),
("<128, 256, 64>", "<64, 64, 64>", "<16, 8, 32>"),
("<256, 128, 64>", "<64, 64, 64>", "<16, 8, 32>"),
("<16, 256, 128>", "<16, 64, 128>", "<16, 8, 32>"),
])
return base_configs
def get_dual_gemm_candidate_configs(sm, min_split_k, max_split_k, min_stages,
max_stages):
"""
get_dual_gemm_candidate_configs returns a list of candidate configs for the dual_gemm_fused_kernel.
"""
tiles = get_candidate_tiles()
candidate_configs = list()
stages = tuple(i for i in range(min_stages, max_stages + 1, 1))
splitks = tuple(i for i in range(min_split_k, max_split_k + 1, 1))
hasbias = ("false", "true")
for act_tag in [
("swiglu", "LeftSiLUAndMul"),
("geglu", "LeftGELUAndMul"),
]:
candidate_configs.extend([(stages, splitks, tiles, act_tag, hasbias)])
return candidate_configs
# this is a file's header part
CommonHead = """// Generated by generate_code_gemm_fused_kernels.py - Do not edit.
#pragma once
#include "fp8_gemm_fused/fuse_dual_gemm_{act_tag}_template.h"
"""
CommonTail = """
"""
GemmDeclare = """
template<>
bool dispatch_dual_gemm_{act_tag}<phi::dtype::{input_type}, phi::dtype::{output_type}, phi::dtype::{bias_type},
cutlass::gemm::GemmShape{thread_block_shape}, cutlass::gemm::GemmShape{warp_shape},
cutlass::gemm::GemmShape{mma_shape}, {num_stages}, {hasbias}, {SM}>(DualGemmEpilogueAllParams);
"""
GemmSplitKDeclare = """
template<>
bool dispatch_dual_gemm_split_k_{act_tag}<phi::dtype::{input_type}, phi::dtype::{output_type}, phi::dtype::{bias_type},
cutlass::gemm::GemmShape{thread_block_shape}, cutlass::gemm::GemmShape{warp_shape},
cutlass::gemm::GemmShape{mma_shape}, {num_stages}, {hasbias}, {SM}>(DualGemmEpilogueAllParams);
"""
LaunchGemmHead = """
#pragma once
#include "fp8_gemm_fused/fp8_fp8_dual_gemm_scale_bias_act.h"
"""
LaunchGemmDeclare = """
bool launch_dual_gemm_kernel_{gemm_config}(const int type_id, const int split_k, DualGemmEpilogueAllParams params);
"""
LaunchGemmPart0 = """
#pragma once
#include "launch_dual_gemm_kernel.h"
bool launch_dual_gemm_kernel_{gemm_config}(const int type_id, const int split_k, DualGemmEpilogueAllParams params){
if(split_k < 2){
params.split_k = 1;
switch (type_id) {
"""
LaunchGemmPart1 = """
case {type_id}:
return dispatch_dual_gemm_{act_tag}<phi::dtype::{input_type}, phi::dtype::{output_type}, phi::dtype::{bias_type},
cutlass::gemm::GemmShape{thread_block_shape}, cutlass::gemm::GemmShape{warp_shape},
cutlass::gemm::GemmShape{mma_shape}, {num_stages}, {hasbias}, {SM}>(params);
break;
"""
LaunchGemmPart2 = """
default:
throw std::runtime_error("cutlass gemm config is invalid.");
break;
}
}else{
throw std::runtime_error("cutlass dual gemm split_k mode is not generated.");
}
return false;
}
"""
code_part0 = """// Generated by generate_code_gemm_fused_kernels.py - Do not edit.
#include <map>
#include <regex>
#include <limits>
#include "helper.h"
#include "fp8_gemm_fused/fp8_fp8_dual_gemm_scale_bias_act.h"
#include "launch_dual_gemm_kernel.h"
COMMON_DECLARE_string(use_cutlass_device_best_config_path);
std::map<std::string, int> dual_gemm_type_map{"""
code_part1 = """
{"{input_type}_{output_type}_{bias_type}_{hasbias}_{act_tag}", {type_id}}, """
code_part2 = """
};
std::map<std::string, int> dual_gemm_config_map{
"""
code_part3 = """ {"{thread_block_shape}, {warp_shape}, {mma_shape}, {num_stages}", {tile_id}},
"""
code_part4 = """};
bool launch_gemm_kernel(const int type_id, const int split_k, const int kernel_id, DualGemmEpilogueAllParams params){
switch (kernel_id) {"""
code_part5 = """
case {tile_id}:
return launch_dual_gemm_kernel_{gemm_config}(type_id, split_k, params);
break;"""
code_part6 = """
default:
throw std::runtime_error("fp8_fp8_bf16_gemm_fused Config is invalid.");
break;
}
return false;
}
bool fp8_fp8_dual_gemm_scale_bias_act(DualGemmEpilogueAllParams params) {
if (dual_gemm_type_map.find(params.fuse_gemm_config) == dual_gemm_type_map.end()) {
throw std::runtime_error("fp8 gemm_fused config is invalid.");
}
int type_id = dual_gemm_type_map[params.fuse_gemm_config];
int M = (params.M+31)/32 *32;
int N = params.N;
int K = params.K;
std::string mnk_string = "dual_gemm<"+ std::to_string(M)+ ", " +std::to_string(N) + ", "+ std::to_string(K)+ ">";
auto encoded_mnk_string = base64_encode(mnk_string);
std::string mnk_split_k_string = "dual_gemm<"+ std::to_string(M)+ ", " +std::to_string(N) + ", "+ std::to_string(K)+ ">" + ", split_k";
auto encoded_mnk_split_k_string = base64_encode(mnk_split_k_string);
int split_k;
int kernel_id;
std::string best_config;
CutlassGemmConfigMannager& best_config_mannager = CutlassGemmConfigMannager::getInstance();
if(getenv("FLAGS_use_cutlass_device_best_config_path")){ // run kernel
std::string config_file_path = getenv("FLAGS_use_cutlass_device_best_config_path");
nlohmann::json* config_json = new nlohmann::json();
if (config_file_path != "default") {
config_json = best_config_mannager.get_gemm_best_configs(config_file_path);
}
best_config = get_relative_best<std::string>(config_json, encoded_mnk_string, "<64, 64, 64>, <32, 32, 64>, <16, 8, 32>, 3");
split_k = get_relative_best<int>(config_json, encoded_mnk_split_k_string, 1);
if (dual_gemm_config_map.find(best_config) == dual_gemm_config_map.end()) {
throw std::runtime_error("This config'kernel not be generate, please check generate_code_gemm_fused_kernels.py and re-generate.");
} else {
kernel_id = dual_gemm_config_map[best_config];
}
return launch_gemm_kernel(type_id, split_k, kernel_id, params);
} else { // tune kernel
int warm_up_times = 5;
int tune_times = 10;
std::string best_kernel_id = "";
int best_split_k = -1;
float duratation = 1000000.f;
// tune all split_k, kernel_id kernels
for(int i = 1; i < {max_split_k}+1; ++i){ // all split_k
for(const auto& config_pair : dual_gemm_config_map){
bool is_valid = true;
// warm up
for(int num_time = 0; num_time < warm_up_times; ++num_time){
if(!launch_gemm_kernel(type_id, i, config_pair.second, params)){
is_valid = false;
break;
}
}
if(!is_valid){
continue;
}
cudaEvent_t start, stop;
cudaEventCreate(&start);
cudaEventCreate(&stop);
cudaStreamSynchronize(params.stream);
cudaEventRecord(start, params.stream);
for(int num_time = 0; num_time < tune_times; ++num_time){
if(!launch_gemm_kernel(type_id, i, config_pair.second, params)){
is_valid = false;
break;
};
}
cudaEventRecord(stop, params.stream);
cudaEventSynchronize(stop);
float elapsedTime;
if(is_valid){
cudaEventElapsedTime(&elapsedTime, start, stop);
} else {
continue;
}
cudaEventDestroy(start);
cudaEventDestroy(stop);
if(elapsedTime < duratation){
best_kernel_id = config_pair.first;
best_split_k = i;
duratation = elapsedTime;
}
}
}
nlohmann::json new_json;
new_json[encoded_mnk_string] = best_kernel_id;
new_json[encoded_mnk_split_k_string] = best_split_k;
best_config_mannager.up_date_configs(new_json);
std::cout <<"Gemm tune result for " << mnk_string<< ": best config is: "<< best_kernel_id << ", split k: " << best_split_k << std::endl;
return true;
}
}
"""
def SubstituteTemplate(template, values):
"""
SubstituteTemplate
"""
text = template
changed = True
while changed:
changed = False
for key, value in values.items():
regex = f"\\{{{key}\\}}"
newtext = re.sub(regex, value, text)
if newtext != text:
changed = True
text = newtext
return text
def check_min_split_k(value):
"""
check_min_split_k
"""
ivalue = int(value)
if ivalue > 1:
raise argparse.ArgumentTypeError(
"Dual gemm split_k mode is not support.")
return ivalue
def check_max_split_k(value):
"""
check_max_split_k
"""
ivalue = int(value)
if ivalue > 1:
raise argparse.ArgumentTypeError(
"Dual gemm split_k mode is not support..")
return ivalue
def parse_args():
"""
parse_args
"""
parser = argparse.ArgumentParser(
description=
"The argument for generating the generic_mixed_gemm_kernelLauncher instance."
)
parser.add_argument(
"--cuda_arch",
type=str,
nargs="+",
default=["89"],
help="The CUDA architecture to be generated.",
)
parser.add_argument(
"--min_split_k",
type=check_min_split_k,
default=1,
help="The min split k for the gemm kernel.",
)
parser.add_argument(
"--max_split_k",
type=check_max_split_k,
default=1,
help="The max split k for the gemm kernel.",
)
parser.add_argument(
"--min_stages",
type=int,
default=3,
help="The min stages for the gemm kernel.",
)
parser.add_argument(
"--max_stages",
type=int,
default=8,
help="The max stages for the gemm kernel.",
)
args = parser.parse_args()
return args
# generate source .cu
def generate_dual_gemm_source_cu(
inputs_type: str,
outputs_type: str,
biases_type: str,
stages: int,
tiles: str,
act_tag: str,
hasbiases: str,
sm: str,
min_split_k: int,
max_split_k: int,
):
"""
generate_dual_gemm_source_cu
"""
value_dict = {
"act_tag": act_tag,
}
all_code = SubstituteTemplate(CommonHead, value_dict)
for input_type in inputs_type:
for bias_type in biases_type:
for stage in stages:
for hasbias in hasbiases:
for tile_config in tiles:
value_dict = {
"input_type": input_type,
"output_type": input_type,
"bias_type": bias_type,
"thread_block_shape": tile_config[0],
"warp_shape": tile_config[1],
"mma_shape": tile_config[2],
"num_stages": str(stage),
"act_tag": act_tag,
"hasbias": hasbias,
"SM": sm,
}
all_code += SubstituteTemplate(GemmDeclare, value_dict)
if min_split_k > 1 and max_split_k > 1:
for input_type in inputs_type:
for bias_type in biases_type:
for stage in stages:
for hasbias in hasbiases:
for tile_config in tiles:
value_dict = {
"input_type": input_type,
"output_type": input_type,
"bias_type": bias_type,
"thread_block_shape": tile_config[0],
"warp_shape": tile_config[1],
"mma_shape": tile_config[2],
"num_stages": str(stage),
"act_tag": act_tag,
"hasbias": hasbias,
"SM": sm,
}
all_code += SubstituteTemplate(
GemmSplitKDeclare, value_dict)
all_code += CommonTail
return all_code
# generate gemm launch .cu
def generate_launch_dual_gemm_cus(
generate_dir: str,
inputs_type: str,
outputs_type: str,
stages: int,
split_ks: int,
tiles: str,
act_tags: str,
hasbiases: str,
sm: str,
min_split_k: int,
max_split_k: int,
):
"""
generate_launch_dual_gemm_cus
"""
code_map = {}
head_path = os.path.join(generate_dir, "launch_dual_gemm_kernel.h")
head_all_code = LaunchGemmHead
for tile in tiles:
blocks, warps, mmas = [
s.replace(" ", "").strip("<>").split(",") for s in tile
]
gemm_config = (f"block{blocks[0]}x{blocks[1]}x{blocks[2]}_"
f"warp{warps[0]}x{warps[1]}x{warps[2]}_"
f"mma{mmas[0]}x{mmas[1]}x{mmas[2]}")
for stage in stages:
gemm_config_str = gemm_config + f"_stage{stage}"
value_dict = {
"gemm_config": gemm_config_str,
}
head_all_code += SubstituteTemplate(LaunchGemmDeclare, value_dict)
os.makedirs(generate_dir, exist_ok=True)
with open(head_path, "w") as f:
f.write(head_all_code)
f.close()
for tile in tiles:
blocks, warps, mmas = [
s.replace(" ", "").strip("<>").split(",") for s in tile
]
gemm_config = (f"block{blocks[0]}x{blocks[1]}x{blocks[2]}_"
f"warp{warps[0]}x{warps[1]}x{warps[2]}_"
f"mma{mmas[0]}x{mmas[1]}x{mmas[2]}")
for stage in stages:
gemm_config_str = gemm_config + f"_stage{stage}"
value_dict = {
"gemm_config": gemm_config_str,
}
source_all_code = SubstituteTemplate(LaunchGemmPart0, value_dict)
# split_k_code = ""
type_id = 0
for input_type in inputs_type:
for bias_type in biases_type:
for act_tag in act_tags:
for hasbias in hasbiases:
value_dict = {
"act_tag": act_tag,
"input_type": input_type,
"output_type": input_type,
"bias_type": bias_type,
"hasbias": hasbias,
"type_id": str(type_id),
"thread_block_shape": tile[0],
"warp_shape": tile[1],
"mma_shape": tile[2],
"num_stages": str(stage),
"SM": sm,
}
source_all_code += SubstituteTemplate(
LaunchGemmPart1, value_dict)
# split_k_code += SubstituteTemplate(LaunchGemmPart3, value_dict)
type_id += 1
source_all_code += LaunchGemmPart2
# source_all_code += split_k_code
# source_all_code += LaunchGemmPart4
code_map[gemm_config_str] = source_all_code
source_path = os.path.join(
generate_dir, f"launch_dual_gemm_kernel_{gemm_config_str}.cu")
with open(source_path, "w") as f:
f.write(source_all_code)
f.close()
return head_all_code, code_map
# generate fp8_fp8_gemm_scale_bias_act.cu
def generate_dispatch_dual_gemm_cu(
inputs_type: str,
outputs_type: str,
biases_type: str,
stages: int,
split_ks: int,
tiles: str,
act_tags: str,
hasbiases: str,
sm: str,
min_split_k: int,
max_split_k: int,
):
"""
generate_dispatch_dual_gemm_cu
"""
all_code = code_part0
type_id = 0
for input_type in inputs_type:
for bias_type in biases_type:
for act_tag in act_tags:
for hasbias in hasbiases:
value_dict = {
"act_tag": act_tag,
"input_type": input_type,
"output_type": input_type,
"bias_type": bias_type,
"hasbias": hasbias,
"type_id": str(type_id),
}
all_code += SubstituteTemplate(code_part1, value_dict)
type_id += 1
all_code += code_part2
tile_id = 0
for tile in tiles:
for stage in stages:
value_dict = {
"thread_block_shape": tile[0],
"warp_shape": tile[1],
"mma_shape": tile[2],
"num_stages": str(stage),
"tile_id": str(tile_id),
}
all_code += SubstituteTemplate(code_part3, value_dict)
tile_id += 1
all_code += code_part4
tile_id = 0
for tile in tiles:
blocks, warps, mmas = [
s.replace(" ", "").strip("<>").split(",") for s in tile
]
gemm_config = (f"block{blocks[0]}x{blocks[1]}x{blocks[2]}_"
f"warp{warps[0]}x{warps[1]}x{warps[2]}_"
f"mma{mmas[0]}x{mmas[1]}x{mmas[2]}")
for stage in stages:
gemm_config_str = gemm_config + f"_stage{stage}"
value_dict = {
"tile_id": str(tile_id),
"gemm_config": gemm_config_str,
}
all_code += SubstituteTemplate(code_part5, value_dict)
tile_id += 1
value_dict.update({
"min_split_k": str(min_split_k),
"max_split_k": str(max_split_k),
})
all_code += SubstituteTemplate(code_part6, value_dict)
return all_code
if __name__ == "__main__":
args = parse_args()
archs = args.cuda_arch
min_split_k = args.min_split_k
max_split_k = args.max_split_k
min_stages = args.min_stages
max_stages = args.max_stages
inputs_type = ("float8_e4m3fn", "float8_e5m2")
biases_type = ("float16", "bfloat16")
outputs_type = ("float8_e4m3fn", "float8_e4m3fn")
sm_dict = {"89": "cutlass::arch::Sm89", "90": "cutlass::arch::Sm90"}
for sm in archs:
if sm == "89":
fuse_gemm_configs = get_dual_gemm_candidate_configs(
sm, min_split_k, max_split_k, min_stages, max_stages)
for fuse_gemm_config in fuse_gemm_configs:
file_name = (
f"gpu_ops/cutlass_kernels/fp8_gemm_fused/"
f"autogen/generic_dual_gemm_kernel_sm{sm}_{fuse_gemm_config[3][0]}.cu"
)
all_code = generate_dual_gemm_source_cu(
inputs_type,
outputs_type,
biases_type,
fuse_gemm_config[0],
fuse_gemm_config[2],
fuse_gemm_config[3][0],
fuse_gemm_config[4],
sm_dict[sm],
min_split_k,
max_split_k,
)
file_dir = os.path.dirname(file_name)
os.makedirs(file_dir, exist_ok=True)
with open(file_name, "w") as f:
f.write(all_code)
f.close()
fuse_gemm_config = list(fuse_gemm_configs)[0]
act_tags = ["swiglu", "geglu"]
# Compile parallelization
generate_launch_dual_gemm_cus(
"gpu_ops/cutlass_kernels/fp8_gemm_fused/autogen",
inputs_type,
outputs_type,
fuse_gemm_config[0],
fuse_gemm_config[1],
fuse_gemm_config[2],
act_tags,
fuse_gemm_config[4],
sm_dict[sm],
min_split_k,
max_split_k,
)
# hard code for act_tag
file_name = "gpu_ops/cutlass_kernels/fp8_gemm_fused/autogen/fp8_fp8_dual_gemm_scale_bias_act.cu"
all_code = generate_dispatch_dual_gemm_cu(
inputs_type,
outputs_type,
biases_type,
fuse_gemm_config[0],
fuse_gemm_config[1],
fuse_gemm_config[2],
act_tags,
fuse_gemm_config[4],
sm_dict[sm],
min_split_k,
max_split_k,
)
file_dir = os.path.dirname(file_name)
os.makedirs(file_dir, exist_ok=True)
with open(file_name, "w") as f:
f.write(all_code)
f.close()
elif sm == 90:
print("Not supported yet.")
exit(0)
else:
raise ValueError(f"Unsupported SM: {sm}")

View File

@@ -0,0 +1,592 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""generate dual_gemm_fused_kernel code."""
import argparse
import os
import re
def get_candidate_tiles():
"""
"""
cta_shape = [
("<_64, _16, _128>"),
("<_64, _32, _128>"),
("<_64, _64, _128>"),
("<_64, _128, _128>"),
("<_128, _16, _128>"),
("<_128, _32, _128>"),
("<_128, _64, _128>"),
# ("<_128, _128, _128>"),
]
cluster_shape = [
("<_1, _1, _1>"),
("<_2, _1, _1>"),
("<_1, _2, _1>"),
# ("<_2, _2, _1>"),
# ("<_1, _8, _1>"),
# ("<_8, _1, _1>"),
]
base_configs = [(x, y) for x in cta_shape for y in cluster_shape]
return base_configs
def get_dual_gemm_candidate_configs(sm):
"""
"""
tiles = get_candidate_tiles()
candidate_configs = list()
hasbias = (
"false",
# "true",
)
KernelSchedule = (
# "KernelTmaWarpSpecializedFP8FastAccum",
"KernelTmaWarpSpecializedPingpongFP8FastAccum",
"KernelTmaWarpSpecializedCooperativeFP8FastAccum",
)
EpilogueSchedule = ("TmaWarpSpecialized", "TmaWarpSpecializedCooperative")
for act_tag in [
("swiglu", "SiLu"),
("geglu", "GELU"),
]:
candidate_configs.extend([(hasbias, act_tag, tiles, KernelSchedule,
EpilogueSchedule)])
return candidate_configs
def get_shape_str(tile_shape):
"""
"""
blocks, clusters = [
s.replace(" ", "").strip("<>").split(",") for s in tile_shape
]
blocks = [elem.strip("_") for elem in blocks]
clusters = [elem.strip("_") for elem in clusters]
return blocks, clusters
def check_config_valid(tile_shape, kernel_schedule, epilogue_schedule):
"""
"""
blocks, clusters = get_shape_str(tile_shape)
if int(
blocks[0]
) < 128 and kernel_schedule == "KernelTmaWarpSpecializedCooperativeFP8FastAccum":
return False
if "Cooperative" in kernel_schedule and "Cooperative" not in epilogue_schedule:
return False
if tile_shape[
0] == "<_128, _128, _128>" and kernel_schedule == "KernelTmaWarpSpecializedPingpongFP8FastAccum":
return False
return True
# this is a file's header part
CommonHead = """// Generated by auto_gen_fp8_fp8_dual_gemm_fused_kernels_sm90.py - Do not edit.
#pragma once
#include "fp8_gemm_fused/fuse_dual_gemm_act_template_3x.h"
using namespace cute;
"""
GemmDeclare = """
template<>
bool dispatch_dual_gemm_act_sm{sm}<phi::dtype::{input_type},
Shape{TileShape},
Shape{ClusterShape},
cutlass::gemm::{KernelSchedule},
cutlass::epilogue::{EpilogueSchedule},
void,
cutlass::epilogue::thread::{Activation}>(DualGemmEpilogueAllParams);
"""
LaunchGemmHead = """
#pragma once
#include "fp8_gemm_fused/fp8_fp8_dual_gemm_scale_bias_act.h"
"""
LaunchGemmDeclare = """
bool launch_dual_gemm_kernel_sm{sm}_{gemm_config}(const int type_id, DualGemmEpilogueAllParams params);
"""
LaunchGemmPart0 = """
#pragma once
#include "launch_dual_gemm_kernel_sm{sm}.h"
bool launch_dual_gemm_kernel_sm{sm}_{gemm_config}(const int type_id, DualGemmEpilogueAllParams params){
using namespace cute;
switch (type_id) {
"""
LaunchGemmPart1 = """
case {type_id}:
return dispatch_dual_gemm_act_sm{sm}<phi::dtype::{input_type},
Shape{TileShape},
Shape{ClusterShape},
cutlass::gemm::{KernelSchedule},
cutlass::epilogue::{EpilogueSchedule},
void,
cutlass::epilogue::thread::{Activation}>(params);
break;
"""
LaunchGemmPart2 = """
default:
throw std::runtime_error("cutlass gemm config is invalid.");
break;
}
return false;
}
"""
code_part0 = """// Generated by auto_gen_fp8_fp8_dual_gemm_fused_kernels_sm90.py - Do not edit.
#include <map>
#include "fp8_gemm_fused/fp8_fp8_dual_gemm_scale_bias_act.h"
#include "launch_dual_gemm_kernel_sm{sm}.h"
COMMON_DECLARE_string(use_cutlass_device_best_config_path);
std::map<std::string, int> dual_gemm_type_map{"""
code_part1 = """
{"{input_type}_{output_type}_{bias_type}_{hasbias}_{act_tag}", {type_id}}, """
code_part2 = """
};
std::map<std::string, int> dual_gemm_config_map{
"""
code_part3 = """ {"{TileShape}, {ClusterShape}, {kernel_schedule}, {epilogue_schedule}", {tile_id}},
"""
code_part4 = """};
bool launch_dual_gemm_kernel_sm{sm}(const int type_id, const int kernel_id, DualGemmEpilogueAllParams params){
switch (kernel_id) {"""
code_part5 = """
case {tile_id}:
return launch_dual_gemm_kernel_sm{sm}_{gemm_config}(type_id, params);
break;"""
code_part6 = """
default:
throw std::runtime_error("fp8 dual gemm_fused Config is invalid.");
break;
}
return false;
}
template <typename T>
T get_relative_best(nlohmann::json* json_data,
const std::string& target_key,
const int& m) {
if (json_data->contains(target_key)) {
return json_data->at(target_key);
} else {
if (m <= 6400){
return "<_64, _32, _128>, <_1, _1, _1>, KernelTmaWarpSpecializedPingpongFP8FastAccum, TmaWarpSpecialized";
}else{
return "<_64, _64, _128>, <_1, _1, _1>, KernelTmaWarpSpecializedPingpongFP8FastAccum, TmaWarpSpecialized";
}
}
}
bool fp8_fp8_dual_gemm_scale_bias_act(DualGemmEpilogueAllParams params) {
if (dual_gemm_type_map.find(params.fuse_gemm_config) == dual_gemm_type_map.end()) {
throw std::runtime_error("fp8 dual gemm_fused config is invalid.");
}
int type_id = dual_gemm_type_map[params.fuse_gemm_config];
int M = (params.M + 31) / 32 * 32;
int N = params.N;
int K = params.K;
int kernel_id;
std::string mnk_string = "tensor_dual_gemm_sm90<"+ std::to_string(M)+ ", " +std::to_string(N) + ", "+ std::to_string(K)+ ">";
std::string best_config;
CutlassGemmConfigMannager& best_config_mannager = CutlassGemmConfigMannager::getInstance();
char *config_file_path_c_str = getenv("FLAGS_use_cutlass_device_best_config_path");
std::string config_file_path = config_file_path_c_str == nullptr ? "" : std::string(config_file_path_c_str);
if(config_file_path == "tune"){ // tune kernel
int warm_up_times = 5;
int tune_times = 10;
std::string best_kernel_id = "";
float duratation = 1000000.f;
// tune all kernel_id kernels
for(const auto& config_pair : dual_gemm_config_map){
bool is_valid = true;
// warm up
for(int num_time = 0; num_time < warm_up_times; ++num_time){
if(!launch_dual_gemm_kernel_sm{sm}(type_id, config_pair.second, params)){
is_valid = false;
break;
}
}
if(!is_valid){
continue;
}
cudaEvent_t start, stop;
cudaEventCreate(&start);
cudaEventCreate(&stop);
cudaStreamSynchronize(params.stream);
cudaEventRecord(start, params.stream);
for(int num_time = 0; num_time < tune_times; ++num_time){
if(!launch_dual_gemm_kernel_sm{sm}(type_id, config_pair.second, params)){
is_valid = false;
break;
};
}
cudaEventRecord(stop, params.stream);
cudaEventSynchronize(stop);
float elapsedTime;
if(is_valid){
cudaEventElapsedTime(&elapsedTime, start, stop);
} else {
continue;
}
cudaEventDestroy(start);
cudaEventDestroy(stop);
if(elapsedTime < duratation){
best_kernel_id = config_pair.first;
duratation = elapsedTime;
}
}
nlohmann::json new_json;
new_json[mnk_string] = best_kernel_id;
best_config_mannager.up_date_configs(new_json);
std::cout <<"Gemm tune result for " << mnk_string<< ": best config is: "<< best_kernel_id << std::endl;
return true;
} else { // run kernel
nlohmann::json* config_json = new nlohmann::json();
if (config_file_path != "" && config_file_path != "default") {
config_json = best_config_mannager.get_gemm_best_configs(config_file_path);
}
best_config = get_relative_best<std::string>(config_json, mnk_string, M);
if (dual_gemm_config_map.find(best_config) == dual_gemm_config_map.end()) {
throw std::runtime_error("This config'kernel not be generate, please check auto_gen_fp8_fp8_dual_gemm_fused_kernels.py and re-generate.");
} else {
kernel_id = dual_gemm_config_map[best_config];
}
return launch_dual_gemm_kernel_sm{sm}(type_id, kernel_id, params);
}
}
"""
def SubstituteTemplate(template, values):
"""
"""
text = template
changed = True
while changed:
changed = False
for key, value in values.items():
regex = f"\\{{{key}\\}}"
newtext = re.sub(regex, value, text)
if newtext != text:
changed = True
text = newtext
return text
def parse_args():
"""
"""
parser = argparse.ArgumentParser(
description="auto generate the fp8_fp8_dual_gemm_fused_kernels_sm90.")
parser.add_argument(
"--cuda_arch",
type=str,
nargs="+",
default=["90"],
help="The CUDA architecture to be generated.",
)
args = parser.parse_args()
return args
# generate source .cu
def generate_dual_gemm_source_cu(
inputs_type: (str),
biases_type: (str),
hasbiases: (str),
act_tag: (str),
tiles: (str),
KernelSchedule: (str),
EpilogueSchedule: (str),
sm: str,
):
"""
"""
all_code = CommonHead
for input_type in inputs_type:
for bias_type in biases_type:
for hasbias in hasbiases:
for tile_config in tiles:
for kernel_schedule in KernelSchedule:
for epilogue_schedule in EpilogueSchedule:
if not check_config_valid(tile_config,
kernel_schedule,
epilogue_schedule):
continue
value_dict = {
"input_type": input_type,
"bias_type": bias_type,
"hasbias": hasbias,
"Activation": act_tag[1],
"TileShape": tile_config[0],
"ClusterShape": tile_config[1],
"KernelSchedule": kernel_schedule,
"EpilogueSchedule": epilogue_schedule,
"SM": sm,
"sm": sm[-2:],
}
all_code += SubstituteTemplate(
GemmDeclare, value_dict)
return all_code
# generate gemm launch .cu
def generate_launch_dual_gemm_cus(
generate_dir: (str), inputs_type: (str), biases_type: (str),
fuse_gemm_configs: tuple, sm: str):
"""
"""
act_tags = [single_config[1] for single_config in fuse_gemm_configs]
single_config = fuse_gemm_configs[0]
hasbiases: (str) = single_config[0]
tiles: (str) = single_config[2]
KernelSchedule: (str) = single_config[3]
EpilogueSchedule: (str) = single_config[4]
code_map = {}
head_path = os.path.join(generate_dir,
f"launch_dual_gemm_kernel_sm{sm[-2:]}.h")
head_all_code = LaunchGemmHead
for tile_config in tiles:
blocks, clusters = get_shape_str(tile_config)
blocks = [elem.strip("_") for elem in blocks]
clusters = [elem.strip("_") for elem in clusters]
gemm_config_str_0 = f"tile{blocks[0]}x{blocks[1]}x{blocks[2]}_cluster{clusters[0]}x{clusters[1]}x{clusters[2]}"
for kernel_schedule in KernelSchedule:
gemm_config_str_1 = gemm_config_str_0 + f"_{kernel_schedule}"
for epilogue_schedule in EpilogueSchedule:
if not check_config_valid(tile_config, kernel_schedule,
epilogue_schedule):
continue
gemm_config_str = gemm_config_str_1 + f"_{epilogue_schedule}"
value_dict = {
"sm": sm[-2:],
"gemm_config": gemm_config_str,
}
head_all_code += SubstituteTemplate(LaunchGemmDeclare,
value_dict)
os.makedirs(generate_dir, exist_ok=True)
with open(head_path, "w") as f:
f.write(head_all_code)
f.close()
for tile_shape in tiles:
blocks, clusters = get_shape_str(tile_shape)
gemm_config_str_0 = f"tile{blocks[0]}x{blocks[1]}x{blocks[2]}_cluster{clusters[0]}x{clusters[1]}x{clusters[2]}"
for kernel_schedule in KernelSchedule:
gemm_config_str_1 = gemm_config_str_0 + f"_{kernel_schedule}"
for epilogue_schedule in EpilogueSchedule:
if not check_config_valid(tile_shape, kernel_schedule,
epilogue_schedule):
continue
gemm_config_str = gemm_config_str_1 + f"_{epilogue_schedule}"
value_dict = {
"sm": sm[-2:],
"gemm_config": gemm_config_str,
}
source_all_code = SubstituteTemplate(LaunchGemmPart0,
value_dict)
type_id = 0
for input_type in inputs_type:
for bias_type in biases_type:
for act_tag in act_tags:
for hasbias in hasbiases:
value_dict = {
"type_id": str(type_id),
"input_type": input_type,
"bias_type": bias_type,
"hasbias": hasbias,
"Activation": act_tag[1],
"TileShape": tile_shape[0],
"ClusterShape": tile_shape[1],
"KernelSchedule": kernel_schedule,
"EpilogueSchedule": epilogue_schedule,
"SM": sm,
"sm": sm[-2:],
}
source_all_code += SubstituteTemplate(
LaunchGemmPart1, value_dict)
type_id += 1
source_all_code += LaunchGemmPart2
code_map[gemm_config_str] = source_all_code
source_path = os.path.join(
generate_dir,
f"launch_dual_gemm_kernel_sm{sm[-2:]}_{gemm_config_str}.cu"
)
with open(source_path, "w") as f:
f.write(source_all_code)
f.close()
return head_all_code, code_map
# generate fp8_fp8_gemm_scale_bias_act.cu
def generate_dispatch_dual_gemm_cu(inputs_type: (str), biases_type: (str),
fuse_gemm_configs: tuple, sm: str):
"""
"""
act_tags = [single_config[1] for single_config in fuse_gemm_configs]
single_config = fuse_gemm_configs[0]
hasbiases: (str) = single_config[0]
tiles: (str) = single_config[2]
KernelSchedule: (str) = single_config[3]
EpilogueSchedule: (str) = single_config[4]
all_code = SubstituteTemplate(code_part0, {"sm": sm[-2:]})
type_id = 0
for input_type in inputs_type:
for bias_type in biases_type:
for act_tag in act_tags:
for hasbias in hasbiases:
value_dict = {
"act_tag": act_tag[0],
"input_type": input_type,
"output_type": input_type,
"bias_type": bias_type,
"hasbias": hasbias,
"type_id": str(type_id),
}
all_code += SubstituteTemplate(code_part1, value_dict)
type_id += 1
all_code += code_part2
tile_id = 0
for tile_shape in tiles:
for kernel_schedule in KernelSchedule:
for epilogue_schedule in EpilogueSchedule:
if not check_config_valid(tile_shape, kernel_schedule,
epilogue_schedule):
continue
value_dict = {
"TileShape": tile_shape[0],
"ClusterShape": tile_shape[1],
"kernel_schedule": kernel_schedule,
"epilogue_schedule": epilogue_schedule,
"tile_id": str(tile_id),
}
all_code += SubstituteTemplate(code_part3, value_dict)
tile_id += 1
all_code += SubstituteTemplate(code_part4, {"sm": sm[-2:]})
tile_id = 0
for tile_shape in tiles:
blocks, clusters = get_shape_str(tile_shape)
gemm_config_str_0 = f"tile{blocks[0]}x{blocks[1]}x{blocks[2]}_cluster{clusters[0]}x{clusters[1]}x{clusters[2]}"
for kernel_schedule in KernelSchedule:
gemm_config_str_1 = gemm_config_str_0 + f"_{kernel_schedule}"
for epilogue_schedule in EpilogueSchedule:
if not check_config_valid(tile_shape, kernel_schedule,
epilogue_schedule):
continue
gemm_config_str = gemm_config_str_1 + f"_{epilogue_schedule}"
value_dict = {
"sm": sm[-2:],
"tile_id": str(tile_id),
"gemm_config": gemm_config_str,
}
all_code += SubstituteTemplate(code_part5, value_dict)
tile_id += 1
all_code += SubstituteTemplate(code_part6, {"sm": sm[-2:]})
return all_code
if __name__ == "__main__":
args = parse_args()
archs = args.cuda_arch
inputs_type = ("float8_e4m3fn", "float8_e5m2")
biases_type = (
"float16",
"bfloat16",
)
sm_dict = {"90": "cutlass::arch::Sm90"}
for sm in archs:
if sm == "90":
fuse_gemm_configs = get_dual_gemm_candidate_configs(sm)
for fuse_gemm_config in fuse_gemm_configs:
file_name = (
f"gpu_ops/cutlass_kernels/fp8_gemm_fused/"
f"autogen/generic_dual_gemm_kernel_sm{sm}_{fuse_gemm_config[1][0]}.cu"
)
all_code = generate_dual_gemm_source_cu(
inputs_type,
biases_type,
fuse_gemm_config[0],
fuse_gemm_config[1],
fuse_gemm_config[2],
fuse_gemm_config[3],
fuse_gemm_config[4],
sm_dict[sm],
)
file_dir = os.path.dirname(file_name)
os.makedirs(file_dir, exist_ok=True)
with open(file_name, "w") as f:
f.write(all_code)
f.close()
# Compile parallelization
generate_launch_dual_gemm_cus(
"gpu_ops/cutlass_kernels/fp8_gemm_fused/autogen", inputs_type,
biases_type, fuse_gemm_configs, sm_dict[sm])
# hard code for act_tag
file_name = (
f"gpu_ops/cutlass_kernels/fp8_gemm_fused/"
f"autogen/fp8_fp8_dual_gemm_scale_bias_act_sm{sm}.cu"
)
all_code = generate_dispatch_dual_gemm_cu(
inputs_type,
biases_type,
fuse_gemm_configs,
sm_dict[sm],
)
file_dir = os.path.dirname(file_name)
os.makedirs(file_dir, exist_ok=True)
with open(file_name, "w") as f:
f.write(all_code)
f.close()
else:
raise ValueError(f"Unsupported SM: {sm}")

View File

@@ -0,0 +1,682 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""generate gemm_fused_kernel code."""
import argparse
import os
import re
def get_candidate_tiles():
"""
获取候选的tile配置列表。
Args:
无参数。
Returns:
List[Tuple[str, str, str]]: 包含tile配置的三元组列表每个三元组中的字符串表示tile的形状"
"""
base_configs = [("<64, 64, 64>", "<32, 32, 64>", "<16, 8, 32>")]
base_configs.extend([
("<32, 128, 64>", "<32, 32, 64>", "<16, 8, 32>"),
("<64, 128, 64>", "<32, 64, 64>", "<16, 8, 32>"),
("<64, 64, 128>", "<32, 64, 64>", "<16, 8, 32>"),
("<64, 128, 64>", "<64, 32, 64>", "<16, 8, 32>"),
("<128, 64, 64>", "<64, 32, 64>", "<16, 8, 32>"),
("<128, 128, 64>", "<64, 32, 64>", "<16, 8, 32>"),
("<128, 128, 64>", "<64, 64, 64>", "<16, 8, 32>"),
("<128, 128, 64>", "<128, 32, 64>", "<16, 8, 32>"),
("<128, 256, 64>", "<64, 64, 64>", "<16, 8, 32>"),
("<256, 128, 64>", "<64, 64, 64>", "<16, 8, 32>"),
("<16, 256, 128>", "<16, 64, 128>", "<16, 8, 32>"),
])
return base_configs
def get_candidate_configs(sm, min_split_k, max_split_k, min_stages,
max_stages):
"""
获取候选的gemm算子配置列表。
Args:
sm (str): 计算能力,如"70"
min_split_k (int): split k的最小值。
"""
tiles = get_candidate_tiles()
candidate_configs = list()
stages = tuple(i for i in range(min_stages, max_stages + 1, 1))
splitks = tuple(i for i in range(min_split_k, max_split_k + 1, 1))
hasbias = ("false", "true")
for act_tag in [
("noact", "LinearCombination"),
("relu", "LinearCombinationRelu"),
("gelu", "LinearCombinationGELU"),
]:
candidate_configs.extend([(stages, splitks, tiles, act_tag, hasbias)])
return candidate_configs
# this is a file's header part
CommonHead = """// Generated by generate_code_gemm_fused_kernels.py - Do not edit.
#pragma once
#include "fp8_gemm_fused/fuse_gemm_{act_tag}_template.h"
"""
CommonTail = """
"""
GemmDeclare = """
template<>
bool dispatch_fuse_gemm_{act_tag}<phi::dtype::{input_type}, phi::dtype::{output_type},
cutlass::gemm::GemmShape{thread_block_shape}, cutlass::gemm::GemmShape{warp_shape},
cutlass::gemm::GemmShape{mma_shape}, {num_stages}, {hasbias}, {SM}>(GemmEpilogueAllParams);
"""
GemmSplitKDeclare = """
template<>
bool dispatch_fuse_gemm_split_k_{act_tag}<phi::dtype::{input_type}, phi::dtype::{output_type},
cutlass::gemm::GemmShape{thread_block_shape}, cutlass::gemm::GemmShape{warp_shape},
cutlass::gemm::GemmShape{mma_shape}, {num_stages}, {hasbias}, {SM}>(GemmEpilogueAllParams);
"""
LaunchGemmHead = """
#pragma once
#include "fp8_gemm_fused/fp8_fp8_gemm_scale_bias_act.h"
"""
LaunchGemmDeclare = """
bool launch_gemm_kernel_{gemm_config}(const int type_id, const int split_k, GemmEpilogueAllParams params);
"""
LaunchGemmPart0 = """
#pragma once
#include "launch_gemm_kernel.h"
bool launch_gemm_kernel_{gemm_config}(const int type_id, const int split_k, GemmEpilogueAllParams params){
if(split_k < 2){
params.split_k = 1;
switch (type_id) {
"""
LaunchGemmPart1 = """
case {type_id}:
return dispatch_fuse_gemm_{act_tag}<phi::dtype::{input_type}, phi::dtype::{output_type},
cutlass::gemm::GemmShape{thread_block_shape}, cutlass::gemm::GemmShape{warp_shape},
cutlass::gemm::GemmShape{mma_shape}, {num_stages}, {hasbias}, {SM}>(params);
break;
"""
LaunchGemmPart2 = """
default:
throw std::runtime_error("cutlass gemm config is invalid.");
break;
}
}else{
switch (type_id) {
"""
LaunchGemmPart3 = """
case {type_id}:
return dispatch_fuse_gemm_split_k_{act_tag}<phi::dtype::{input_type}, phi::dtype::{output_type},
cutlass::gemm::GemmShape{thread_block_shape}, cutlass::gemm::GemmShape{warp_shape},
cutlass::gemm::GemmShape{mma_shape}, {num_stages}, {hasbias}, {SM}>(params);
break;
"""
LaunchGemmPart4 = """
default:
throw std::runtime_error("cutlass gemm config is invalid.");
break;
}
}
return false;
}
"""
code_part0 = """// Generated by generate_code_gemm_fused_kernels.py - Do not edit.
#include <map>
#include <regex>
#include <limits>
#include "helper.h"
#include "fp8_gemm_fused/fp8_fp8_gemm_scale_bias_act.h"
#include "launch_gemm_kernel.h"
COMMON_DECLARE_string(use_cutlass_device_best_config_path);
std::map<std::string, int> gemm_type_map{"""
code_part1 = """
{"{input_type}_{output_type}_{hasbias}_{act_tag}", {type_id}}, """
code_part2 = """
};
std::map<std::string, int> gemm_config_map{
"""
code_part3 = """ {"{thread_block_shape}, {warp_shape}, {mma_shape}, {num_stages}", {tile_id}},
"""
code_part4 = """};
bool launch_gemm_kernel(const int type_id, const int split_k, const int kernel_id, GemmEpilogueAllParams params){
switch (kernel_id) {"""
code_part5 = """
case {tile_id}:
return launch_gemm_kernel_{gemm_config}(type_id, split_k, params);
break;"""
code_part6 = """
default:
throw std::runtime_error("fp8_fp8_bf16_gemm_fused Config is invalid.");
break;
}
return false;
}
template <typename T>
T get_relative_best(nlohmann::json* json_data,
const std::string& target_key,
const std::string& regex_key,
const int& m,
const T& default_value) {
if (json_data->contains(target_key)) {
return json_data->at(target_key);
} else {
std::regex pattern(regex_key);
std::string closest_key;
int closest_diff = std::numeric_limits<int>::max();
T closest_value = default_value;
for (const auto& [key, value] : json_data->items()) {
std::smatch matches;
if (std::regex_search(key, matches, pattern)) {
int relative_m = std::stoi(matches[1].str());
int diff = std::abs(relative_m - m);
if (diff < closest_diff) {
closest_diff = diff;
closest_value = value;
}
}
}
json_data->push_back({target_key, closest_value});
return closest_value;
}
}
bool fp8_fp8_gemm_scale_bias_act(GemmEpilogueAllParams params) {
if (gemm_type_map.find(params.fuse_gemm_config) == gemm_type_map.end()) {
throw std::runtime_error("fp8 gemm_fused config is invalid.");
}
int type_id = gemm_type_map[params.fuse_gemm_config];
int M = (params.M+31)/32 *32;
int N = params.N;
int K = params.K;
std::string mnk_string = "tensor_gemm_sm90<"+ std::to_string(M)+ ", " +std::to_string(N) + ", "+ std::to_string(K)+ ">";
std::string regex_mnk_string = "tensor_gemm_sm90<(\\d+), " + std::to_string(N) + ", " + std::to_string(K) + ">";
std::string mnk_split_k_string = "tensor_gemm_sm90<"+ std::to_string(M)+ ", " +std::to_string(N) + ", "+ std::to_string(K)+ ">" + ", split_k";
std::string regex_mnk_split_k_string = "tensor_gemm_sm90<(\\d+), " + std::to_string(N) + ", " + std::to_string(K) + ">, split_k";
int split_k;
int kernel_id;
std::string best_config;
CutlassGemmConfigMannager& best_config_mannager = CutlassGemmConfigMannager::getInstance();
char *config_file_path_c_str = getenv("FLAGS_use_cutlass_device_best_config_path");
std::string config_file_path = config_file_path_c_str == nullptr ? "" : std::string(config_file_path_c_str);
if(config_file_path == "tune"){ // tune kernel
int warm_up_times = 5;
int tune_times = 10;
std::string best_kernel_id = "";
int best_split_k = -1;
float duratation = 1000000.f;
// tune all split_k, kernel_id kernels
for(int i = 1; i < {max_split_k}+1; ++i){ // all split_k
for(const auto& config_pair : gemm_config_map){
bool is_valid = true;
// warm up
for(int num_time = 0; num_time < warm_up_times; ++num_time){
if(!launch_gemm_kernel(type_id, i, config_pair.second, params)){
is_valid = false;
break;
}
}
if(!is_valid){
continue;
}
cudaEvent_t start, stop;
cudaEventCreate(&start);
cudaEventCreate(&stop);
cudaStreamSynchronize(params.stream);
cudaEventRecord(start, params.stream);
for(int num_time = 0; num_time < tune_times; ++num_time){
if(!launch_gemm_kernel(type_id, i, config_pair.second, params)){
is_valid = false;
break;
};
}
cudaEventRecord(stop, params.stream);
cudaEventSynchronize(stop);
float elapsedTime;
if(is_valid){
cudaEventElapsedTime(&elapsedTime, start, stop);
} else {
continue;
}
cudaEventDestroy(start);
cudaEventDestroy(stop);
if(elapsedTime < duratation){
best_kernel_id = config_pair.first;
best_split_k = i;
duratation = elapsedTime;
}
}
}
nlohmann::json new_json;
new_json[mnk_string] = best_kernel_id;
new_json[mnk_split_k_string] = best_split_k;
best_config_mannager.up_date_configs(new_json);
std::cout <<"Gemm tune result for " << mnk_string<< ": best config is: "<< best_kernel_id << ", split k: " << best_split_k << std::endl;
return true;
} else { // run kernel
nlohmann::json* config_json = new nlohmann::json();
if (config_file_path != "" && config_file_path != "default") {
config_json = best_config_mannager.get_gemm_best_configs(config_file_path);
}
best_config = get_relative_best<std::string>(config_json, mnk_string, regex_mnk_string, M, "<64, 64, 64>, <32, 32, 64>, <16, 8, 32>, 3");
split_k = get_relative_best<int>(config_json, mnk_split_k_string, regex_mnk_split_k_string, M, 1);
if (gemm_config_map.find(best_config) == gemm_config_map.end()) {
throw std::runtime_error("This config'kernel not be generate, please check auto_gen_fp8_fp8_gemm_fused_kernels.py and re-generate.");
} else {
kernel_id = gemm_config_map[best_config];
}
return launch_gemm_kernel(type_id, split_k, kernel_id, params);
}
}
"""
def SubstituteTemplate(template, values):
"""
生成函数模板
"""
text = template
changed = True
while changed:
changed = False
for key, value in values.items():
regex = f"\\{{{key}\\}}"
newtext = re.sub(regex, value, text)
if newtext != text:
changed = True
text = newtext
return text
def parse_args():
"""
代码参数解析
"""
parser = argparse.ArgumentParser(
description=
"The argument for generating the generic_mixed_gemm_kernelLauncher instance."
)
parser.add_argument(
"--cuda_arch",
type=str,
nargs="+",
default=["89"],
help="The CUDA architecture to be generated.",
)
parser.add_argument(
"--min_split_k",
type=int,
default=2,
help="The min split k for the gemm kernel.",
)
parser.add_argument(
"--max_split_k",
type=int,
default=6,
help="The max split k for the gemm kernel.",
)
parser.add_argument(
"--min_stages",
type=int,
default=2,
help="The min stages for the gemm kernel.",
)
parser.add_argument(
"--max_stages",
type=int,
default=8,
help="The max stages for the gemm kernel.",
)
args = parser.parse_args()
return args
# generate source .cu
def generate_source_cu(
inputs_type: str,
outputs_type: str,
stages: int,
tiles: str,
act_tag: str,
hasbiases: str,
sm: str,
):
"""
生成.cu源文件
"""
value_dict = {
"act_tag": act_tag,
}
all_code = SubstituteTemplate(CommonHead, value_dict)
for input_type in inputs_type:
for output_type in outputs_type:
for stage in stages:
for hasbias in hasbiases:
for tile_config in tiles:
value_dict = {
"input_type": input_type,
"output_type": output_type,
"thread_block_shape": tile_config[0],
"warp_shape": tile_config[1],
"mma_shape": tile_config[2],
"num_stages": str(stage),
"act_tag": act_tag,
"hasbias": hasbias,
"SM": sm,
}
all_code += SubstituteTemplate(GemmDeclare, value_dict)
for input_type in inputs_type:
for output_type in outputs_type:
for stage in stages:
for hasbias in hasbiases:
for tile_config in tiles:
value_dict = {
"input_type": input_type,
"output_type": output_type,
"thread_block_shape": tile_config[0],
"warp_shape": tile_config[1],
"mma_shape": tile_config[2],
"num_stages": str(stage),
"act_tag": act_tag,
"hasbias": hasbias,
"SM": sm,
}
all_code += SubstituteTemplate(GemmSplitKDeclare,
value_dict)
all_code += CommonTail
return all_code
# generate gemm launch .cu
def generate_launch_gemm_cus(
generate_dir: str,
inputs_type: str,
outputs_type: str,
stages: int,
split_ks: int,
tiles: str,
act_tags: str,
hasbiases: str,
sm: str,
min_split_k: int,
max_split_k: int,
):
code_map = {}
head_path = os.path.join(generate_dir, "launch_gemm_kernel.h")
head_all_code = LaunchGemmHead
for tile in tiles:
blocks, warps, mmas = [
s.replace(" ", "").strip("<>").split(",") for s in tile
]
gemm_config = f"block{blocks[0]}x{blocks[1]}x{blocks[2]}_warp{warps[0]}x{warps[1]}x{warps[2]}_mma{mmas[0]}x{mmas[1]}x{mmas[2]}"
for stage in stages:
gemm_config_str = gemm_config + f"_stage{stage}"
value_dict = {
"gemm_config": gemm_config_str,
}
head_all_code += SubstituteTemplate(LaunchGemmDeclare, value_dict)
os.makedirs(generate_dir, exist_ok=True)
with open(head_path, "w") as f:
f.write(head_all_code)
f.close()
for tile in tiles:
blocks, warps, mmas = [
s.replace(" ", "").strip("<>").split(",") for s in tile
]
gemm_config = f"block{blocks[0]}x{blocks[1]}x{blocks[2]}_warp{warps[0]}x{warps[1]}x{warps[2]}_mma{mmas[0]}x{mmas[1]}x{mmas[2]}"
for stage in stages:
gemm_config_str = gemm_config + f"_stage{stage}"
value_dict = {
"gemm_config": gemm_config_str,
}
source_all_code = SubstituteTemplate(LaunchGemmPart0, value_dict)
split_k_code = ""
type_id = 0
for input_type in inputs_type:
for output_type in outputs_type:
for act_tag in act_tags:
for hasbias in hasbiases:
value_dict = {
"act_tag": act_tag,
"input_type": input_type,
"output_type": output_type,
"hasbias": hasbias,
"type_id": str(type_id),
"thread_block_shape": tile[0],
"warp_shape": tile[1],
"mma_shape": tile[2],
"num_stages": str(stage),
"SM": sm,
}
source_all_code += SubstituteTemplate(
LaunchGemmPart1, value_dict)
split_k_code += SubstituteTemplate(
LaunchGemmPart3, value_dict)
type_id += 1
source_all_code += LaunchGemmPart2
source_all_code += split_k_code
source_all_code += LaunchGemmPart4
code_map[gemm_config_str] = source_all_code
source_path = os.path.join(
generate_dir, f"launch_gemm_kernel_{gemm_config_str}.cu")
with open(source_path, "w") as f:
f.write(source_all_code)
f.close()
return head_all_code, code_map
# generate fp8_fp8_gemm_scale_bias_act.cu
def generate_dispatch_gemm_cu(
inputs_type: str,
outputs_type: str,
stages: int,
split_ks: int,
tiles: str,
act_tags: str,
hasbiases: str,
sm: str,
min_split_k: int,
max_split_k: int,
):
all_code = code_part0
type_id = 0
for input_type in inputs_type:
for output_type in outputs_type:
for act_tag in act_tags:
for hasbias in hasbiases:
value_dict = {
"act_tag": act_tag,
"input_type": input_type,
"output_type": output_type,
"hasbias": hasbias,
"type_id": str(type_id),
}
all_code += SubstituteTemplate(code_part1, value_dict)
type_id += 1
all_code += code_part2
tile_id = 0
for tile in tiles:
for stage in stages:
value_dict = {
"thread_block_shape": tile[0],
"warp_shape": tile[1],
"mma_shape": tile[2],
"num_stages": str(stage),
"tile_id": str(tile_id),
}
all_code += SubstituteTemplate(code_part3, value_dict)
tile_id += 1
all_code += code_part4
tile_id = 0
for tile in tiles:
blocks, warps, mmas = [
s.replace(" ", "").strip("<>").split(",") for s in tile
]
gemm_config = f"block{blocks[0]}x{blocks[1]}x{blocks[2]}_warp{warps[0]}x{warps[1]}x{warps[2]}_mma{mmas[0]}x{mmas[1]}x{mmas[2]}"
for stage in stages:
gemm_config_str = gemm_config + f"_stage{stage}"
value_dict = {
"tile_id": str(tile_id),
"gemm_config": gemm_config_str,
}
all_code += SubstituteTemplate(code_part5, value_dict)
tile_id += 1
value_dict.update({
"min_split_k": str(min_split_k),
"max_split_k": str(max_split_k),
})
all_code += SubstituteTemplate(code_part6, value_dict)
return all_code
if __name__ == "__main__":
args = parse_args()
archs = args.cuda_arch
min_split_k = args.min_split_k
max_split_k = args.max_split_k
min_stages = args.min_stages
max_stages = args.max_stages
inputs_type = ("float8_e4m3fn", "float8_e5m2")
outputs_type = ("float16", "bfloat16")
sm_dict = {"89": "cutlass::arch::Sm89", "90": "cutlass::arch::Sm90"}
for sm in archs:
if sm == "89":
fuse_gemm_configs = get_candidate_configs(sm, min_split_k,
max_split_k, min_stages,
max_stages)
for fuse_gemm_config in fuse_gemm_configs:
file_name = f"gpu_ops/cutlass_kernels/fp8_gemm_fused/autogen/generic_gemm_kernel_sm{sm}_{fuse_gemm_config[3][0]}.cu"
all_code = generate_source_cu(
inputs_type,
outputs_type,
fuse_gemm_config[0],
fuse_gemm_config[2],
fuse_gemm_config[3][0],
fuse_gemm_config[4],
sm_dict[sm],
)
file_dir = os.path.dirname(file_name)
os.makedirs(file_dir, exist_ok=True)
with open(file_name, "w") as f:
f.write(all_code)
f.close()
fuse_gemm_config = list(fuse_gemm_configs)[0]
act_tags = ["noact", "relu", "gelu"]
# Compile parallelization
generate_launch_gemm_cus(
"gpu_ops/cutlass_kernels/fp8_gemm_fused/autogen",
inputs_type,
outputs_type,
fuse_gemm_config[0],
fuse_gemm_config[1],
fuse_gemm_config[2],
act_tags,
fuse_gemm_config[4],
sm_dict[sm],
min_split_k,
max_split_k,
)
# hard code for act_tag
file_name = (
"gpu_ops/cutlass_kernels/fp8_gemm_fused/autogen/fp8_fp8_gemm_scale_bias_act.cu"
)
all_code = generate_dispatch_gemm_cu(
inputs_type,
outputs_type,
fuse_gemm_config[0],
fuse_gemm_config[1],
fuse_gemm_config[2],
act_tags,
fuse_gemm_config[4],
sm_dict[sm],
min_split_k,
max_split_k,
)
file_dir = os.path.dirname(file_name)
os.makedirs(file_dir, exist_ok=True)
with open(file_name, "w") as f:
f.write(all_code)
f.close()
elif sm == 90:
print("Not supported yet.")
exit(0)
else:
raise ValueError(f"Unsupported SM: {sm}")

View File

@@ -0,0 +1,614 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""generate gemm_fused_kernel code."""
import argparse
import copy
import os
import re
def get_candidate_tiles():
"""
"""
base_configs = [
("<_64, _64, _128>", "<_1, _8, _1>"),
("<_64, _128, _128>", "<_2, _1, _1>"),
("<_128, _128, _128>", "<_2, _1, _1>"),
]
base_configs.extend([
("<_64, _64, _128>", "<_1, _1, _1>"),
("<_64, _64, _128>", "<_1, _2, _1>"),
("<_64, _64, _128>", "<_2, _1, _1>"),
("<_64, _64, _64>", "<_1, _1, _1>"),
("<_64, _64, _64>", "<_1, _2, _1>"),
("<_64, _64, _64>", "<_2, _1, _1>"),
("<_64, _128, _128>", "<_1, _2, _1>"),
("<_64, _128, _128>", "<_1, _1, _1>"),
("<_128, _128, _64>", "<_2, _1, _1>"),
("<_256, _128, _128>", "<_1, _2, _1>"),
("<_256, _128, _128>", "<_1, _1, _1>"),
# The following configurations are rarely selected in Qwen2-7B-model.
# ("<_256, _128, _128>", "<_4, _1, _1>"),
# ("<_256, _128, _128>", "<_1, _4, _1>"),
# ("<_256, _128, _128>", "<_2, _4, _1>"),
# ("<_128, _128, _256>", "<_1, _2, _1>"),
# ("<_128, _128, _128>", "<_4, _1, _1>"),
# ("<_128, _128, _128>", "<_2, _4, _1>"),
# ("<_128, _128, _128>", "<_1, _2, _1>"),
# ("<_128, _128, _128>", "<_1, _1, _1>"),
# ("<_128, _128, _128>", "<_1, _4, _1>"),
# ("<_128, _128, _64>", "<_2, _2, _1>"),
])
return base_configs
def get_candidate_configs(sm):
"""
"""
tiles = get_candidate_tiles()
candidate_configs = list()
hasbias = ("false", "true")
KernelSchedule = (
"KernelTmaWarpSpecializedFP8FastAccum",
"KernelTmaWarpSpecializedPingpongFP8FastAccum",
# "KernelTmaWarpSpecializedCooperativeFP8FastAccum",
)
EpilogueSchedule = ("TmaWarpSpecialized", "TmaWarpSpecializedCooperative")
for act_tag in [
("noact", "Identity"),
("relu", "ReLu"),
("gelu", "GELU"),
]:
candidate_configs.extend([(hasbias, act_tag, tiles, KernelSchedule,
EpilogueSchedule)])
return candidate_configs
def get_shape_str(tile_shape):
"""
"""
blocks, clusters = [
s.replace(" ", "").strip("<>").split(",") for s in tile_shape
]
blocks = [elem.strip("_") for elem in blocks]
clusters = [elem.strip("_") for elem in clusters]
return blocks, clusters
def check_config_valid(tile_shape, kernel_schedule, epilogue_schedule):
"""
"""
blocks, clusters = get_shape_str(tile_shape)
if int(
blocks[0]
) < 128 and kernel_schedule == "KernelTmaWarpSpecializedCooperativeFP8FastAccum":
return False
if "Cooperative" in kernel_schedule and "Cooperative" not in epilogue_schedule:
return False
if (tile_shape[0] == "<_256, _128, _128>"
and "Cooperative" not in kernel_schedule
and "Cooperative" not in epilogue_schedule):
return False
return True
# this is a file's header part
CommonHead = """// Generated by auto_gen_fp8_fp8_gemm_fused_kernels_sm90.py - Do not edit.
#pragma once
#include "fp8_gemm_fused/fuse_gemm_act_template_3x.h"
using namespace cute;
"""
GemmDeclare = """
template<>
bool dispatch_fuse_gemm_act_sm{sm}<phi::dtype::{input_type}, phi::dtype::{output_type},
{hasbias},
cutlass::epilogue::thread::{Activation},
Shape{TileShape},
Shape{ClusterShape},
cutlass::gemm::{KernelSchedule},
cutlass::epilogue::{EpilogueSchedule},
{SM}
>(GemmEpilogueAllParams);
"""
LaunchGemmHead = """
#pragma once
#include "fp8_gemm_fused/fp8_fp8_gemm_scale_bias_act.h"
"""
LaunchGemmDeclare = """
bool launch_gemm_kernel_sm{sm}_{gemm_config}(const int type_id, GemmEpilogueAllParams params);
"""
LaunchGemmPart0 = """
#pragma once
#include "launch_gemm_kernel_sm{sm}.h"
bool launch_gemm_kernel_sm{sm}_{gemm_config}(const int type_id, GemmEpilogueAllParams params){
using namespace cute;
switch (type_id) {
"""
LaunchGemmPart1 = """
case {type_id}:
return dispatch_fuse_gemm_act_sm{sm}<phi::dtype::{input_type}, phi::dtype::{output_type},
{hasbias},
cutlass::epilogue::thread::{Activation},
Shape{TileShape},
Shape{ClusterShape},
cutlass::gemm::{KernelSchedule},
cutlass::epilogue::{EpilogueSchedule},
{SM}
>(params);
break;
"""
LaunchGemmPart2 = """
default:
throw std::runtime_error("cutlass gemm config is invalid.");
break;
}
return false;
}
"""
code_part0 = """// Generated by auto_gen_fp8_fp8_gemm_fused_kernels_sm90.py - Do not edit.
#include <map>
#include "fp8_gemm_fused/fp8_fp8_gemm_scale_bias_act.h"
#include "launch_gemm_kernel_sm{sm}.h"
COMMON_DECLARE_string(use_cutlass_device_best_config_path);
std::map<std::string, int> gemm_type_map{"""
code_part1 = """
{"{input_type}_{output_type}_{hasbias}_{act_tag}", {type_id}}, """
code_part2 = """
};
std::map<std::string, int> gemm_config_map{
"""
code_part3 = """ {"{TileShape}, {ClusterShape}, {kernel_schedule}, {epilogue_schedule}", {tile_id}},
"""
code_part4 = """};
bool launch_gemm_kernel_sm{sm}(const int type_id, const int kernel_id, GemmEpilogueAllParams params){
switch (kernel_id) {"""
code_part5 = """
case {tile_id}:
return launch_gemm_kernel_sm{sm}_{gemm_config}(type_id, params);
break;
"""
code_part6 = """
default:
throw std::runtime_error("fp8 gemm_fused Config is invalid.");
break;
}
return false;
}
template <typename T>
T get_relative_best(nlohmann::json* json_data,
const std::string& target_key,
const int& m) {
if (json_data->contains(target_key)) {
return json_data->at(target_key);
} else {
if (m <= 64){
return "<_64, _64, _128>, <_1, _8, _1>, KernelTmaWarpSpecializedPingpongFP8FastAccum, TmaWarpSpecialized";
}else if(m <= 128){
return "<_64, _128, _128>, <_2, _1, _1>, KernelTmaWarpSpecializedPingpongFP8FastAccum, TmaWarpSpecialized";
}else{
return "<_128, _128, _128>, <_2, _1, _1>, KernelTmaWarpSpecializedPingpongFP8FastAccum, TmaWarpSpecialized";
}
}
}
bool fp8_fp8_gemm_scale_bias_act(GemmEpilogueAllParams params) {
if (gemm_type_map.find(params.fuse_gemm_config) == gemm_type_map.end()) {
throw std::runtime_error("fp8 gemm_fused config is invalid.");
}
int type_id = gemm_type_map[params.fuse_gemm_config];
int M = (params.M + 31) / 32 * 32;
int N = params.N;
int K = params.K;
int kernel_id;
std::string mnk_string = "tensor_gemm_sm90<"+ std::to_string(M)+ ", " +std::to_string(N) + ", "+ std::to_string(K)+ ">";
std::string best_config;
CutlassGemmConfigMannager& best_config_mannager = CutlassGemmConfigMannager::getInstance();
char *config_file_path_c_str = getenv("FLAGS_use_cutlass_device_best_config_path");
std::string config_file_path = config_file_path_c_str == nullptr ? "" : std::string(config_file_path_c_str);
if(config_file_path == "tune"){ // tune kernel
int warm_up_times = 5;
int tune_times = 10;
std::string best_kernel_id = "";
float duratation = 1000000.f;
// tune all kernel_id kernels
for(const auto& config_pair : gemm_config_map){
std::cout << "Running tune kernel: " << config_pair.first<< std::endl;
bool is_valid = true;
// warm up
for(int num_time = 0; num_time < warm_up_times; ++num_time){
if(!launch_gemm_kernel_sm{sm}(type_id, config_pair.second, params)){
is_valid = false;
break;
}
}
if(!is_valid){
continue;
}
cudaEvent_t start, stop;
cudaEventCreate(&start);
cudaEventCreate(&stop);
cudaStreamSynchronize(params.stream);
cudaEventRecord(start, params.stream);
for(int num_time = 0; num_time < tune_times; ++num_time){
if(!launch_gemm_kernel_sm{sm}(type_id, config_pair.second, params)){
is_valid = false;
break;
};
}
cudaEventRecord(stop, params.stream);
cudaEventSynchronize(stop);
float elapsedTime;
if(is_valid){
cudaEventElapsedTime(&elapsedTime, start, stop);
} else {
continue;
}
cudaEventDestroy(start);
cudaEventDestroy(stop);
if(elapsedTime < duratation){
best_kernel_id = config_pair.first;
duratation = elapsedTime;
}
}
nlohmann::json new_json;
new_json[mnk_string] = best_kernel_id;
best_config_mannager.up_date_configs(new_json);
std::cout <<"Gemm tune result for " << mnk_string<< ": best config is: "<< best_kernel_id << std::endl;
return true;
} else { // run kernel
nlohmann::json* config_json = new nlohmann::json();
if (config_file_path != "" && config_file_path != "default") {
config_json = best_config_mannager.get_gemm_best_configs(config_file_path);
}
best_config = get_relative_best<std::string>(config_json, mnk_string, M);
if (gemm_config_map.find(best_config) == gemm_config_map.end()) {
throw std::runtime_error("This config'kernel not be generate, please check auto_gen_fp8_fp8_gemm_fused_kernels_sm90.py and re-generate.");
} else {
kernel_id = gemm_config_map[best_config];
}
return launch_gemm_kernel_sm{sm}(type_id, kernel_id, params);
}
}
"""
def SubstituteTemplate(template, values_base):
"""
"""
values = copy.deepcopy(values_base)
if values.get("KernelSchedule"
) is not None and "Auto" in values["KernelSchedule"]:
values["KernelSchedule"] = "collective::" + values["KernelSchedule"]
if values.get("EpilogueSchedule"
) is not None and "Auto" in values["EpilogueSchedule"]:
values[
"EpilogueSchedule"] = "collective::" + values["EpilogueSchedule"]
text = template
changed = True
while changed:
changed = False
for key, value in values.items():
regex = f"\\{{{key}\\}}"
newtext = re.sub(regex, value, text)
if newtext != text:
changed = True
text = newtext
return text
def parse_args():
"""
"""
parser = argparse.ArgumentParser(
description="auto generate fp8_fp8_gemm_fused_kernels_sm90.")
parser.add_argument(
"--cuda_arch",
type=str,
nargs="+",
default=["90"],
help="The CUDA architecture to be generated.",
)
args = parser.parse_args()
return args
# generate source .cu
def generate_source_cu(
inputs_type: (str),
outputs_type: (str),
hasbiases: (str),
act_tag: (str),
tiles: (str),
KernelSchedule: (str),
EpilogueSchedule: (str),
sm: str,
):
"""
"""
all_code = CommonHead
for input_type in inputs_type:
for output_type in outputs_type:
for hasbias in hasbiases:
for tile_config in tiles:
for kernel_schedule in KernelSchedule:
for epilogue_schedule in EpilogueSchedule:
if not check_config_valid(tile_config,
kernel_schedule,
epilogue_schedule):
continue
value_dict = {
"input_type": input_type,
"output_type": output_type,
"hasbias": hasbias,
"Activation": act_tag[1],
"TileShape": tile_config[0],
"ClusterShape": tile_config[1],
"KernelSchedule": kernel_schedule,
"EpilogueSchedule": epilogue_schedule,
"SM": sm,
"sm": sm[-2:],
}
all_code += SubstituteTemplate(
GemmDeclare, value_dict)
return all_code
# generate gemm launch .cu
def generate_launch_gemm_cus(
generate_dir: (str), inputs_type: (str), outputs_type: (str),
fuse_gemm_configs: tuple, sm: str):
"""
"""
act_tags = [single_config[1] for single_config in fuse_gemm_configs]
single_config = fuse_gemm_configs[0]
hasbiases: (str) = single_config[0]
tiles: (str) = single_config[2]
KernelSchedule: (str) = single_config[3]
EpilogueSchedule: (str) = single_config[4]
code_map = {}
head_path = os.path.join(generate_dir, f"launch_gemm_kernel_sm{sm[-2:]}.h")
head_all_code = LaunchGemmHead
for tile_config in tiles:
blocks, clusters = get_shape_str(tile_config)
gemm_config_str_0 = f"tile{blocks[0]}x{blocks[1]}x{blocks[2]}_cluster{clusters[0]}x{clusters[1]}x{clusters[2]}"
for kernel_schedule in KernelSchedule:
gemm_config_str_1 = gemm_config_str_0 + f"_{kernel_schedule}"
for epilogue_schedule in EpilogueSchedule:
if not check_config_valid(tile_config, kernel_schedule,
epilogue_schedule):
continue
gemm_config_str = gemm_config_str_1 + f"_{epilogue_schedule}"
value_dict = {
"sm": sm[-2:],
"gemm_config": gemm_config_str,
}
head_all_code += SubstituteTemplate(LaunchGemmDeclare,
value_dict)
os.makedirs(generate_dir, exist_ok=True)
with open(head_path, "w") as f:
f.write(head_all_code)
f.close()
for tile_shape in tiles:
blocks, clusters = get_shape_str(tile_shape)
gemm_config_str_0 = f"tile{blocks[0]}x{blocks[1]}x{blocks[2]}_cluster{clusters[0]}x{clusters[1]}x{clusters[2]}"
for kernel_schedule in KernelSchedule:
gemm_config_str_1 = gemm_config_str_0 + f"_{kernel_schedule}"
for epilogue_schedule in EpilogueSchedule:
if not check_config_valid(tile_shape, kernel_schedule,
epilogue_schedule):
continue
gemm_config_str = gemm_config_str_1 + f"_{epilogue_schedule}"
value_dict = {
"sm": sm[-2:],
"gemm_config": gemm_config_str,
}
source_all_code = SubstituteTemplate(LaunchGemmPart0,
value_dict)
type_id = 0
for input_type in inputs_type:
for output_type in outputs_type:
for act_tag in act_tags:
for hasbias in hasbiases:
value_dict = {
"type_id": str(type_id),
"input_type": input_type,
"output_type": output_type,
"hasbias": hasbias,
"Activation": act_tag[1],
"TileShape": tile_shape[0],
"ClusterShape": tile_shape[1],
"KernelSchedule": kernel_schedule,
"EpilogueSchedule": epilogue_schedule,
"SM": sm,
"sm": sm[-2:],
}
source_all_code += SubstituteTemplate(
LaunchGemmPart1, value_dict)
type_id += 1
source_all_code += LaunchGemmPart2
code_map[gemm_config_str] = source_all_code
source_path = os.path.join(
generate_dir,
f"launch_gemm_kernel_sm{sm[-2:]}_{gemm_config_str}.cu")
with open(source_path, "w") as f:
f.write(source_all_code)
f.close()
return head_all_code, code_map
# generate fp8_fp8_gemm_scale_bias_act_sm90.cu
def generate_dispatch_gemm_cu(inputs_type: (str), outputs_type: (str),
fuse_gemm_configs: tuple, sm: str):
"""
"""
act_tags = [single_config[1] for single_config in fuse_gemm_configs]
single_config = fuse_gemm_configs[0]
hasbiases: (str) = single_config[0]
tiles: (str) = single_config[2]
KernelSchedule: (str) = single_config[3]
EpilogueSchedule: (str) = single_config[4]
all_code = SubstituteTemplate(code_part0, {"sm": sm[-2:]})
type_id = 0
for input_type in inputs_type:
for output_type in outputs_type:
for act_tag in act_tags:
for hasbias in hasbiases:
value_dict = {
"act_tag": act_tag[0],
"input_type": input_type,
"output_type": output_type,
"hasbias": hasbias,
"type_id": str(type_id),
}
all_code += SubstituteTemplate(code_part1, value_dict)
type_id += 1
all_code += code_part2
tile_id = 0
for tile_shape in tiles:
for kernel_schedule in KernelSchedule:
for epilogue_schedule in EpilogueSchedule:
if not check_config_valid(tile_shape, kernel_schedule,
epilogue_schedule):
continue
value_dict = {
"TileShape": tile_shape[0],
"ClusterShape": tile_shape[1],
"kernel_schedule": kernel_schedule,
"epilogue_schedule": epilogue_schedule,
"tile_id": str(tile_id),
}
all_code += SubstituteTemplate(code_part3, value_dict)
tile_id += 1
all_code += SubstituteTemplate(code_part4, {"sm": sm[-2:]})
tile_id = 0
for tile_shape in tiles:
blocks, clusters = get_shape_str(tile_shape)
gemm_config_str_0 = f"tile{blocks[0]}x{blocks[1]}x{blocks[2]}_cluster{clusters[0]}x{clusters[1]}x{clusters[2]}"
for kernel_schedule in KernelSchedule:
gemm_config_str_1 = gemm_config_str_0 + f"_{kernel_schedule}"
for epilogue_schedule in EpilogueSchedule:
if not check_config_valid(tile_shape, kernel_schedule,
epilogue_schedule):
continue
gemm_config_str = gemm_config_str_1 + f"_{epilogue_schedule}"
value_dict = {
"sm": sm[-2:],
"tile_id": str(tile_id),
"gemm_config": gemm_config_str,
}
all_code += SubstituteTemplate(code_part5, value_dict)
tile_id += 1
all_code += SubstituteTemplate(code_part6, {"sm": sm[-2:]})
return all_code
if __name__ == "__main__":
args = parse_args()
archs = args.cuda_arch
inputs_type = (
"float8_e4m3fn",
"float8_e5m2",
)
outputs_type = ("float16", "bfloat16")
sm_dict = {"90": "cutlass::arch::Sm90"}
for sm in archs:
if sm == "90":
fuse_gemm_configs = get_candidate_configs(sm)
for fuse_gemm_config in fuse_gemm_configs:
file_name = (
f"gpu_ops/cutlass_kernels/fp8_gemm_fused/"
f"autogen/generic_gemm_kernel_sm{sm}_{fuse_gemm_config[1][0]}.cu")
all_code = generate_source_cu(
inputs_type,
outputs_type,
fuse_gemm_config[0],
fuse_gemm_config[1],
fuse_gemm_config[2],
fuse_gemm_config[3],
fuse_gemm_config[4],
sm_dict[sm],
)
file_dir = os.path.dirname(file_name)
os.makedirs(file_dir, exist_ok=True)
with open(file_name, "w") as f:
f.write(all_code)
f.close()
# Compile parallelization
generate_launch_gemm_cus(
"gpu_ops/cutlass_kernels/fp8_gemm_fused/autogen", inputs_type,
outputs_type, fuse_gemm_configs, sm_dict[sm])
# hard code for act_tag
file_name = f"gpu_ops/cutlass_kernels/fp8_gemm_fused/autogen/fp8_fp8_gemm_scale_bias_act_sm{sm}.cu"
all_code = generate_dispatch_gemm_cu(
inputs_type,
outputs_type,
fuse_gemm_configs,
sm_dict[sm],
)
file_dir = os.path.dirname(file_name)
os.makedirs(file_dir, exist_ok=True)
with open(file_name, "w") as f:
f.write(all_code)
f.close()
else:
raise ValueError(f"Unsupported SM: {sm}")

View File

@@ -0,0 +1,568 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""generate gemm_fused_kernel code."""
import argparse
import os
import re
def get_candidate_tiles():
"""
获取候选的tile配置列表。
Args:
无参数。
Returns:
List[Tuple[str, str, str]]: 包含tile配置的三元组列表每个三元组中的字符串表示tile的形状"
"""
base_configs = [("<64, 64, 64>", "<32, 32, 64>", "<16, 8, 32>")]
base_configs.extend([
("<16, 32, 64>", "<16, 32, 64>", "<16, 8, 32>"),
("<16, 64, 64>", "<16, 32, 64>", "<16, 8, 32>"),
("<32, 128, 64>", "<32, 32, 64>", "<16, 8, 32>"),
("<64, 128, 64>", "<32, 64, 64>", "<16, 8, 32>"),
("<64, 64, 128>", "<32, 64, 64>", "<16, 8, 32>"),
("<64, 128, 64>", "<64, 32, 64>", "<16, 8, 32>"),
("<128, 64, 64>", "<64, 32, 64>", "<16, 8, 32>"),
("<128, 128, 64>", "<64, 32, 64>", "<16, 8, 32>"),
("<128, 128, 64>", "<64, 64, 64>", "<16, 8, 32>"),
("<128, 128, 64>", "<128, 32, 64>", "<16, 8, 32>"),
("<128, 256, 64>", "<64, 64, 64>", "<16, 8, 32>"),
("<256, 128, 64>", "<64, 64, 64>", "<16, 8, 32>"),
("<128, 64, 128>", "<64, 32, 128>", "<16, 8, 32>"),
("<16, 256, 128>", "<16, 64, 128>", "<16, 8, 32>"),
])
return base_configs
def get_candidate_configs(sm, min_stages, max_stages):
"""
获取候选的gemm算子配置列表。
Args:
sm (str): 计算能力,如"70"
"""
tiles = get_candidate_tiles()
candidate_configs = list()
stages = tuple(i for i in range(min_stages, max_stages + 1, 1))
hasbias = ("false", "true")
candidate_configs.extend([(stages, tiles, hasbias)])
return candidate_configs
# this is a file's header part
CommonHead = """// Generated by generate_code_gemm_fused_kernels.py - Do not edit.
#pragma once
#include "fp8_gemm_fused/visitor_fp8_gemm_fused_template.h"
"""
CommonTail = """
"""
GemmDeclare = """
template<>
bool dispatch_visitor_fuse_gemm<phi::dtype::{input_type}, phi::dtype::{output_type},
cutlass::gemm::GemmShape{thread_block_shape}, cutlass::gemm::GemmShape{warp_shape},
cutlass::gemm::GemmShape{mma_shape}, {num_stages}, {hasbias}, {SM}>(GemmEpilogueAllParams);
"""
LaunchGemmHead = """
#pragma once
#include "fp8_gemm_fused/visitor_fp8_gemm_fused.h"
"""
LaunchGemmDeclare = """
bool launch_visitor_gemm_fused_kernel_{gemm_config}(const int type_id, GemmEpilogueAllParams params);
"""
LaunchGemmPart0 = """
#pragma once
#include "fp8_gemm_fused/visitor_fp8_gemm_fused_template.h"
bool launch_visitor_gemm_fused_kernel_{gemm_config}(const int type_id, GemmEpilogueAllParams params){
switch (type_id) {"""
LaunchGemmPart1 = """
case {type_id}:
return dispatch_visitor_fuse_gemm<phi::dtype::{input_type}, phi::dtype::{output_type},
cutlass::gemm::GemmShape{thread_block_shape}, cutlass::gemm::GemmShape{warp_shape},
cutlass::gemm::GemmShape{mma_shape}, {num_stages}, {hasbias}, {SM}>(params);
break;"""
LaunchGemmPart2 = """
default:
throw std::runtime_error("cutlass gemm config is invalid.");
break;
}
return false;
}
"""
code_part0 = """// Generated by generate_code_gemm_fused_kernels.py - Do not edit.
#include <map>
#include <regex>
#include <limits>
#include "helper.h"
#include "visitor_fp8_gemm_fused.h"
#include "launch_visitor_gemm_fused_kernel.h"
COMMON_DECLARE_string(use_cutlass_device_best_config_path);
std::map<std::string, int> per_channel_gemm_type_map{"""
code_part1 = """
{"{input_type}_{output_type}_{hasbias}", {type_id}}, """
code_part2 = """
};
std::map<std::string, int> per_channel_gemm_config_map{
"""
code_part3 = """ {"{thread_block_shape}, {warp_shape}, {mma_shape}, {num_stages}", {tile_id}},
"""
code_part4 = """};
bool launch_visitor_gemm_fused_kernel(const int type_id, const int kernel_id, GemmEpilogueAllParams params){
switch (kernel_id) {"""
code_part5 = """
case {tile_id}:
return launch_visitor_gemm_fused_kernel_{gemm_config}(type_id, params);
break;"""
code_part6 = """
default:
throw std::runtime_error("fp8_fp8_bf16_gemm_fused Config is invalid.");
break;
}
return false;
}
bool fp8_visitor_gemm_fused(GemmEpilogueAllParams params) {
if (per_channel_gemm_type_map.find(params.fuse_gemm_config) == per_channel_gemm_type_map.end()) {
throw std::runtime_error("fp8_visitor_gemm_fused config is invalid.");
}
int type_id = per_channel_gemm_type_map[params.fuse_gemm_config];
int M = (params.M+31)/32 *32;
int N = params.N;
int K = params.K;
std::string mnk_string = "per_channel_gemm<"+ std::to_string(M)+ ", " +std::to_string(N) + ", "+ std::to_string(K)+ ">";
auto encoded_mnk_string = base64_encode(mnk_string);
int kernel_id;
std::string best_config;
CutlassGemmConfigMannager& best_config_mannager = CutlassGemmConfigMannager::getInstance();
if(getenv("FLAGS_use_cutlass_device_best_config_path")){ // run kernel
std::string config_file_path = getenv("FLAGS_use_cutlass_device_best_config_path");
nlohmann::json* config_json = new nlohmann::json();
if (config_file_path != "default") {
config_json = best_config_mannager.get_gemm_best_configs(config_file_path);
}
best_config = get_relative_best<std::string>(config_json, encoded_mnk_string, "<64, 64, 64>, <32, 32, 64>, <16, 8, 32>, 3");
if (per_channel_gemm_config_map.find(best_config) == per_channel_gemm_config_map.end()) {
throw std::runtime_error("This config'kernel not be generate, please check auto_gen_visitor_fp8_gemm_fused_kernels.py and re-generate.");
} else {
kernel_id = per_channel_gemm_config_map[best_config];
}
return launch_visitor_gemm_fused_kernel(type_id, kernel_id, params);
} else { // tune kernel
int warm_up_times = 5;
int tune_times = 10;
std::string best_kernel_id = "";
float duratation = 1000000.f;
// tune all kernel_id kernels
for(const auto& config_pair : per_channel_gemm_config_map){
bool is_valid = true;
// warm up
for(int num_time = 0; num_time < warm_up_times; ++num_time){
if(!launch_visitor_gemm_fused_kernel(type_id, config_pair.second, params)){
is_valid = false;
break;
}
}
if(!is_valid){
continue;
}
cudaEvent_t start, stop;
cudaEventCreate(&start);
cudaEventCreate(&stop);
cudaStreamSynchronize(params.stream);
cudaEventRecord(start, params.stream);
for(int num_time = 0; num_time < tune_times; ++num_time){
if(!launch_visitor_gemm_fused_kernel(type_id, config_pair.second, params)){
is_valid = false;
break;
};
}
cudaEventRecord(stop, params.stream);
cudaEventSynchronize(stop);
float elapsedTime;
if(is_valid){
cudaEventElapsedTime(&elapsedTime, start, stop);
} else {
continue;
}
cudaEventDestroy(start);
cudaEventDestroy(stop);
if(elapsedTime < duratation){
best_kernel_id = config_pair.first;
duratation = elapsedTime;
}
}
nlohmann::json new_json;
new_json[encoded_mnk_string] = best_kernel_id;
best_config_mannager.up_date_configs(new_json);
std::cout <<"Gemm tune result for " << mnk_string<< ": best config is: "<< best_kernel_id << std::endl;
return true;
}
}
"""
def SubstituteTemplate(template, values):
"""
生成函数模板
"""
text = template
changed = True
while changed:
changed = False
for key, value in values.items():
regex = f"\\{{{key}\\}}"
newtext = re.sub(regex, value, text)
if newtext != text:
changed = True
text = newtext
return text
def parse_args():
"""
代码参数解析
"""
parser = argparse.ArgumentParser(
description=
"The argument for generating the generic_mixed_gemm_kernelLauncher instance."
)
parser.add_argument(
"--cuda_arch",
type=str,
nargs="+",
default=["89"],
help="The CUDA architecture to be generated.",
)
parser.add_argument(
"--min_stages",
type=int,
default=2,
help="The min stages for the gemm kernel.",
)
parser.add_argument(
"--max_stages",
type=int,
default=8,
help="The max stages for the gemm kernel.",
)
args = parser.parse_args()
return args
# generate source .cu
def generate_source_cu(
inputs_type: str,
outputs_type: str,
stages: int,
tiles: str,
hasbiases: str,
sm: str,
):
"""
生成.cu源文件
"""
all_code = SubstituteTemplate(CommonHead, {})
for input_type in inputs_type:
for output_type in outputs_type:
for stage in stages:
for hasbias in hasbiases:
for tile_config in tiles:
value_dict = {
"input_type": input_type,
"output_type": output_type,
"thread_block_shape": tile_config[0],
"warp_shape": tile_config[1],
"mma_shape": tile_config[2],
"num_stages": str(stage),
"hasbias": hasbias,
"SM": sm,
}
all_code += SubstituteTemplate(GemmDeclare, value_dict)
all_code += CommonTail
return all_code
# generate gemm launch .cu
def generate_launch_gemm_cus(
generate_dir: str,
inputs_type: str,
outputs_type: str,
stages: int,
tiles: str,
hasbiases: str,
sm: str,
):
"""
生成含有CUDA执行器的launch_visitor_gemm_fused_kernel.h和launch_visitor_gemm_fused_kernel_*.cu文件。
Args:
generate_dir (str): 生成文件的目录路径。
inputs_type (str): 输入类型,可以是"float", "half"或者"bfloat16"中的一种。
outputs_type (str): 输出类型,可以是"float", "half"或者"bfloat16"中的一种。
stages (int): Gemm算子的阶段数。
tiles (str): 包含三个元素的列表每个元素都是包含三个元素的列表分别代表线程块形状、整体线程形状和MMA形状。例如["32,8,4","16,8,4"]。
hasbiases (str): 是否包含偏置量,可以是"true"或者"false"中的一种。
sm (str): GPU的SM大小可以是"70"或者"80"中的一种。
Returns:
tuple (str, dict):
- str (head_all_code) - 所有头部代码的字符串。
- dict (code_map) - 包含每个Gemm配置对应的源代码的字典格式为{"gemm_config": source_code}。
"""
code_map = {}
head_path = os.path.join(generate_dir,
"launch_visitor_gemm_fused_kernel.h")
head_all_code = LaunchGemmHead
for tile in tiles:
blocks, warps, mmas = [
s.replace(" ", "").strip("<>").split(",") for s in tile
]
gemm_config = f"block{blocks[0]}x{blocks[1]}x{blocks[2]}_warp{warps[0]}x{warps[1]}x{warps[2]}_mma{mmas[0]}x{mmas[1]}x{mmas[2]}"
for stage in stages:
gemm_config_str = gemm_config + f"_stage{stage}"
value_dict = {
"gemm_config": gemm_config_str,
}
head_all_code += SubstituteTemplate(LaunchGemmDeclare, value_dict)
os.makedirs(generate_dir, exist_ok=True)
with open(head_path, "w") as f:
f.write(head_all_code)
f.close()
for tile in tiles:
blocks, warps, mmas = [
s.replace(" ", "").strip("<>").split(",") for s in tile
]
gemm_config = f"block{blocks[0]}x{blocks[1]}x{blocks[2]}_warp{warps[0]}x{warps[1]}x{warps[2]}_mma{mmas[0]}x{mmas[1]}x{mmas[2]}"
for stage in stages:
gemm_config_str = gemm_config + f"_stage{stage}"
value_dict = {
"gemm_config": gemm_config_str,
}
source_all_code = SubstituteTemplate(LaunchGemmPart0, value_dict)
type_id = 0
for input_type in inputs_type:
for output_type in outputs_type:
for hasbias in hasbiases:
value_dict = {
"input_type": input_type,
"output_type": output_type,
"hasbias": hasbias,
"type_id": str(type_id),
"thread_block_shape": tile[0],
"warp_shape": tile[1],
"mma_shape": tile[2],
"num_stages": str(stage),
"SM": sm,
}
source_all_code += SubstituteTemplate(
LaunchGemmPart1, value_dict)
type_id += 1
source_all_code += LaunchGemmPart2
code_map[gemm_config_str] = source_all_code
source_path = os.path.join(
generate_dir,
f"launch_visitor_gemm_fused_kernel_{gemm_config_str}.cu")
with open(source_path, "w") as f:
f.write(source_all_code)
f.close()
return head_all_code, code_map
# generate fp8_visitor_gemm_fused.cu
def generate_dispatch_gemm_cu(
inputs_type: str,
outputs_type: str,
stages: int,
tiles: str,
hasbiases: str,
sm: str,
):
"""
生成调度Gemm的CU代码。
Args:
inputs_type (str): 输入类型,字符串格式,多个类型用逗号分隔。可选值为 "float", "half".
outputs_type (str): 输出类型,字符串格式,多个类型用逗号分隔。可选值为 "float", "half".
stages (int): Gemm的层数。
tiles (str): 瓦片形状字符串格式多个瓦片形状用逗号分隔。每个瓦片形状由三个整数组成表示线程块形状、线程块内部warp形状和MMA形状。例如"<8,8,8>,<8,8,8>,<8,8,8>"
hasbiases (str): 是否有偏置,字符串格式,多个值用逗号分隔。可选值为 "true", "false".
sm (str): 使用的SM数量字符串格式多个值用逗号分隔。可选值为 "32", "64".
Returns:
str: 返回一个包含所有代码的字符串。
"""
all_code = code_part0
type_id = 0
for input_type in inputs_type:
for output_type in outputs_type:
for hasbias in hasbiases:
value_dict = {
"input_type": input_type,
"output_type": output_type,
"hasbias": hasbias,
"type_id": str(type_id),
}
all_code += SubstituteTemplate(code_part1, value_dict)
type_id += 1
all_code += code_part2
tile_id = 0
for tile in tiles:
for stage in stages:
value_dict = {
"thread_block_shape": tile[0],
"warp_shape": tile[1],
"mma_shape": tile[2],
"num_stages": str(stage),
"tile_id": str(tile_id),
}
all_code += SubstituteTemplate(code_part3, value_dict)
tile_id += 1
all_code += code_part4
tile_id = 0
for tile in tiles:
blocks, warps, mmas = [
s.replace(" ", "").strip("<>").split(",") for s in tile
]
gemm_config = f"block{blocks[0]}x{blocks[1]}x{blocks[2]}_warp{warps[0]}x{warps[1]}x{warps[2]}_mma{mmas[0]}x{mmas[1]}x{mmas[2]}"
for stage in stages:
gemm_config_str = gemm_config + f"_stage{stage}"
value_dict = {
"tile_id": str(tile_id),
"gemm_config": gemm_config_str,
}
all_code += SubstituteTemplate(code_part5, value_dict)
tile_id += 1
all_code += SubstituteTemplate(code_part6, value_dict)
return all_code
if __name__ == "__main__":
args = parse_args()
archs = args.cuda_arch
min_stages = args.min_stages
max_stages = args.max_stages
inputs_type = ("float8_e4m3fn", "float8_e5m2")
outputs_type = ("float16", "bfloat16")
sm_dict = {"89": "cutlass::arch::Sm89", "90": "cutlass::arch::Sm90"}
for sm in archs:
if sm == "89":
fuse_gemm_configs = get_candidate_configs(sm, min_stages,
max_stages)
for fuse_gemm_config in fuse_gemm_configs:
file_name = f"gpu_ops/cutlass_kernels/fp8_gemm_fused/autogen/generic_visitor_gemm_fused_kernel_sm{sm}.cu"
all_code = generate_source_cu(
inputs_type,
outputs_type,
fuse_gemm_config[0],
fuse_gemm_config[1],
fuse_gemm_config[2],
sm_dict[sm],
)
file_dir = os.path.dirname(file_name)
os.makedirs(file_dir, exist_ok=True)
with open(file_name, "w") as f:
f.write(all_code)
f.close()
fuse_gemm_config = list(fuse_gemm_configs)[0]
# Compile parallelization
generate_launch_gemm_cus(
"gpu_ops/cutlass_kernels/fp8_gemm_fused/autogen",
inputs_type,
outputs_type,
fuse_gemm_config[0],
fuse_gemm_config[1],
fuse_gemm_config[2],
sm_dict[sm],
)
file_name = (
"gpu_ops/cutlass_kernels/fp8_gemm_fused/visitor_fp8_gemm_fused.cu"
)
all_code = generate_dispatch_gemm_cu(
inputs_type,
outputs_type,
fuse_gemm_config[0],
fuse_gemm_config[1],
fuse_gemm_config[2],
sm_dict[sm],
)
file_dir = os.path.dirname(file_name)
os.makedirs(file_dir, exist_ok=True)
with open(file_name, "w") as f:
f.write(all_code)
f.close()
elif sm == 90:
print("Not supported yet.")
exit(0)
else:
raise ValueError(f"Unsupported SM: {sm}")