mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-12-24 13:28:13 +08:00
Sync v2.0 version of code to github repo
This commit is contained in:
@@ -0,0 +1,631 @@
|
||||
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""generate cutlass_fp8_fp8_half_block_gemm_fused code."""
|
||||
|
||||
import argparse
|
||||
import copy
|
||||
import os
|
||||
import re
|
||||
|
||||
|
||||
def get_candidate_tiles():
    """Return every (CTA shape, cluster shape) candidate pair as cute strings."""
    # CTA tile shapes to instantiate; commented shapes can be re-enabled.
    block_shapes = [
        "<_128, _128, _128>",
        # "<_256, _128, _128>",
    ]
    # Thread-block cluster shapes supported by the SM90 kernels.
    cluster_shapes = [
        "<_1, _1, _1>",
        "<_2, _1, _1>",
        "<_1, _2, _1>",
        "<_2, _2, _1>",
        # "<_1, _8, _1>",
        # "<_8, _1, _1>",
    ]
    pairs = []
    for block in block_shapes:
        for cluster in cluster_shapes:
            pairs.append((block, cluster))
    return pairs
|
||||
|
||||
|
||||
def get_candidate_configs(sm):
    """
    Build the candidate kernel configurations for the given SM architecture.

    Returns a list with one tuple per activation:
    (hasbias options, act_tag, tile pairs, kernel schedules,
     epilogue schedules, tile schedulers).
    """
    tile_pairs = get_candidate_tiles()

    bias_options = ("false", "true")
    kernel_schedules = (
        "KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum<1>", )
    epilogue_schedules = ("TmaWarpSpecializedCooperative", )
    tile_schedulers = ("PersistentScheduler", "StreamKScheduler")
    # Only the identity ("noact") epilogue is generated today; the commented
    # activations can be re-enabled to widen the search space.
    activations = [
        ("noact", "Identity"),
        # ("relu", "ReLu"),
        # ("gelu", "GELU"),
    ]
    return [(bias_options, act, tile_pairs, kernel_schedules,
             epilogue_schedules, tile_schedulers) for act in activations]
|
||||
|
||||
|
||||
def get_shape_str(tile_shape):
    """
    Parse a ("<_a, _b, _c>", "<_x, _y, _z>") pair into two lists of
    bare dimension strings: (["a", "b", "c"], ["x", "y", "z"]).
    """

    def _dims(shape):
        # Drop spaces and the angle brackets, then strip the "_" prefixes.
        parts = shape.replace(" ", "").strip("<>").split(",")
        return [part.strip("_") for part in parts]

    block_shape, cluster_shape = tile_shape
    return _dims(block_shape), _dims(cluster_shape)
|
||||
|
||||
|
||||
def check_config_valid(tile_shape, kernel_schedule, epilogue_schedule,
                       tile_schedule):
    """
    Return True if the cutlass config can be instantiated, False otherwise.

    Cooperative kernel schedules require a CTA M dimension of at least 128;
    all other combinations are accepted.
    """
    # Extract the CTA M dimension from the "<_M, _N, _K>" tile string.
    block_m = tile_shape[0].replace(" ", "").strip("<>").split(",")[0].strip("_")
    if "Cooperative" in kernel_schedule and int(block_m) < 128:
        return False
    return True
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
# C++ code templates. The ``{name}`` placeholders are filled in by
# SubstituteTemplate when the generator runs; literal braces in the C++
# code are left alone because only known keys are substituted.
# ---------------------------------------------------------------------------

# this is a file's header part
CommonHead = """// Generated by auto_gen_fp8_fp8_gemm_fused_kernels_sm90.py - Do not edit.

#pragma once

#include "fp8_gemm_fused/fuse_block_gemm_act_template_3x.h"

"""

# Explicit template instantiation of the block-gemm dispatch entry point.
GemmDeclare = """
template<>
bool dispatch_fuse_block_gemm_c3x<phi::dtype::{input_type}, phi::dtype::{output_type},
                                  {hasbias},
                                  cutlass::epilogue::thread::{Activation},
                                  Shape{TileShape},
                                  Shape{ClusterShape},
                                  cutlass::gemm::{KernelSchedule},
                                  cutlass::epilogue::{EpilogueSchedule},
                                  cutlass::gemm::{TileSchedule},
                                  {SM}
                                  >(GemmEpilogueAllParams);

"""

# Header shared by all generated per-config launcher files.
LaunchGemmHead = """
#pragma once

#include "fp8_gemm_fused/fp8_fp8_gemm_scale_bias_act.h"

"""

# Forward declaration of one per-config launcher function.
LaunchGemmDeclare = """
bool launch_block_gemm_kernel_sm{sm}_{gemm_config}(const int type_id, GemmEpilogueAllParams params);
"""

# Launcher body, opening part: switch on the dtype/bias/activation type id.
LaunchGemmPart0 = """
#pragma once

#include "launch_block_gemm_kernel_sm{sm}.h"

bool launch_block_gemm_kernel_sm{sm}_{gemm_config}(const int type_id, GemmEpilogueAllParams params){
  switch (type_id) {
"""

# Launcher body, one case per type id.
LaunchGemmPart1 = """
    case {type_id}:
      return dispatch_fuse_block_gemm_c3x<phi::dtype::{input_type}, phi::dtype::{output_type},
                                          {hasbias},
                                          cutlass::epilogue::thread::{Activation},
                                          Shape{TileShape},
                                          Shape{ClusterShape},
                                          cutlass::gemm::{KernelSchedule},
                                          cutlass::epilogue::{EpilogueSchedule},
                                          cutlass::gemm::{TileSchedule},
                                          {SM}
                                          >(params);
      break;
"""

# Launcher body, closing part: default branch and function epilogue.
LaunchGemmPart2 = """
    default:
      throw std::runtime_error("cutlass gemm config is invalid.");
      break;
  }
  return false;
}
"""

# Dispatch translation unit, opening part: headers and the type-id map.
code_part0 = """// Generated by auto_gen_fp8_fp8_gemm_fused_kernels_sm90.py - Do not edit.

#include <map>
#include "fp8_gemm_fused/fp8_fp8_gemm_scale_bias_act.h"
#include "launch_block_gemm_kernel_sm{sm}.h"

COMMON_DECLARE_string(use_cutlass_device_best_config_path);

std::map<std::string, int> block_gemm_type_map{"""

# One entry of the type-id map.
code_part1 = """
    {"{input_type}_{output_type}_{hasbias}_{act_tag}", {type_id}}, """

# Close the type-id map and open the kernel-config map.
code_part2 = """
};

std::map<std::string, int> block_gemm_config_map{
"""

# One entry of the kernel-config map.
code_part3 = """    {"{TileShape}, {ClusterShape}, {kernel_schedule}, {epilogue_schedule}, {tile_schedule}", {tile_id}},
"""

# Close the config map and open the kernel-id dispatcher.
code_part4 = """};

bool launch_block_gemm_kernel(const int type_id, const int kernel_id, GemmEpilogueAllParams params){
  switch (kernel_id) {"""

# One dispatcher case per generated kernel id.
code_part5 = """
    case {tile_id}:
      return launch_block_gemm_kernel_sm{sm}_{gemm_config}(type_id, params);
      break;
"""

# Dispatcher epilogue plus the tune/run driver function.
code_part6 = """
    default:
      throw std::runtime_error("fp8 gemm_fused Config is invalid.");
      break;
  }
  return false;
}

template <typename T>
T get_relative_best(nlohmann::json* json_data,
                    const std::string& target_key,
                    const int& m,
                    const int& n,
                    const int& k) {
  if (json_data->contains(target_key)) {
    return json_data->at(target_key);
  } else {
    if (k > 3 * n){
      return "<_128, _128, _128>, <_1, _2, _1>, KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum<1>, TmaWarpSpecializedCooperative, StreamKScheduler";
    }else{
      return "<_128, _128, _128>, <_1, _2, _1>, KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum<1>, TmaWarpSpecializedCooperative, PersistentScheduler";
    }

  }
}

bool fp8_fp8_block_gemm_scale_bias_act(GemmEpilogueAllParams params) {
  if (block_gemm_type_map.find(params.fuse_gemm_config) == block_gemm_type_map.end()) {
    throw std::runtime_error("fp8 gemm_fused config is invalid.");
  }

  int type_id = block_gemm_type_map[params.fuse_gemm_config];
  int M = (params.M + 31) / 32 * 32;
  int N = params.N;
  int K = params.K;

  int kernel_id;
  std::string mnk_string = "block_gemm_sm90<"+ std::to_string(M)+ ", " +std::to_string(N) + ", "+ std::to_string(K)+ ">";
  std::string best_config;
  CutlassGemmConfigMannager& best_config_mannager = CutlassGemmConfigMannager::getInstance();
  char *config_file_path_c_str = getenv("FLAGS_use_cutlass_device_best_config_path");
  std::string config_file_path = config_file_path_c_str == nullptr ? "" : std::string(config_file_path_c_str);
  if(config_file_path == "tune"){ // tune kernel
    int warm_up_times = 5;
    int tune_times = 10;
    std::string best_kernel_id = "";
    float duratation = 1000000.f;
    // tune all kernel_id kernels
    for(const auto& config_pair : block_gemm_config_map){
      std::cout << "Running tune kernel: " << config_pair.first<< std::endl;
      bool is_valid = true;
      // warm up
      for(int num_time = 0; num_time < warm_up_times; ++num_time){
        if(!launch_block_gemm_kernel(type_id, config_pair.second, params)){
          is_valid = false;
          break;
        }
      }
      if(!is_valid){
        continue;
      }
      cudaEvent_t start, stop;
      cudaEventCreate(&start);
      cudaEventCreate(&stop);
      cudaStreamSynchronize(params.stream);
      cudaEventRecord(start, params.stream);
      for(int num_time = 0; num_time < tune_times; ++num_time){
        if(!launch_block_gemm_kernel(type_id, config_pair.second, params)){
          is_valid = false;
          break;
        };
      }
      cudaEventRecord(stop, params.stream);
      cudaEventSynchronize(stop);
      float elapsedTime;
      if(is_valid){
        cudaEventElapsedTime(&elapsedTime, start, stop);
      } else {
        continue;
      }
      cudaEventDestroy(start);
      cudaEventDestroy(stop);
      if(elapsedTime < duratation){
        best_kernel_id = config_pair.first;
        duratation = elapsedTime;
      }
    }

    nlohmann::json new_json;
    new_json[mnk_string] = best_kernel_id;
    best_config_mannager.up_date_configs(new_json);
    std::cout <<"Gemm tune result for " << mnk_string<< ": best config is: "<< best_kernel_id << std::endl;
    return true;
  } else { // run kernel
    nlohmann::json* config_json = new nlohmann::json();
    if (config_file_path != "" && config_file_path != "default") {
      config_json = best_config_mannager.get_gemm_best_configs(config_file_path);
    }

    best_config = get_relative_best<std::string>(config_json, mnk_string, M, N, K);

    if (block_gemm_config_map.find(best_config) == block_gemm_config_map.end()) {
      throw std::runtime_error("This config'kernel not be generate, please check auto_gen_fp8_fp8_gemm_fused_kernels_sm90.py and re-generate.");
    } else {
      kernel_id = block_gemm_config_map[best_config];
    }
    return launch_block_gemm_kernel(type_id, kernel_id, params);
  }
}

"""
|
||||
|
||||
|
||||
def SubstituteTemplate(template, values_base):
    """
    Replace every ``{key}`` placeholder in *template* with its value.

    Substitution is repeated until a fixed point is reached, so values may
    themselves contain placeholders for other keys.

    Args:
        template: Template text containing ``{key}`` placeholders.
        values_base: Mapping of placeholder name to replacement string; it is
            deep-copied and never mutated.

    Returns:
        The template text with every known placeholder substituted.
    """
    values = copy.deepcopy(values_base)
    # cutlass "Auto" schedules live in the collective:: namespace.
    if values.get("KernelSchedule") is not None and "Auto" in values[
            "KernelSchedule"]:
        values["KernelSchedule"] = "collective::" + values["KernelSchedule"]
    if values.get("EpilogueSchedule") is not None and "Auto" in values[
            "EpilogueSchedule"]:
        values["EpilogueSchedule"] = "collective::" + values["EpilogueSchedule"]
    text = template
    changed = True
    while changed:
        changed = False
        for key, value in values.items():
            # str.replace instead of re.sub: replacement values are C++ code
            # and may contain backslashes, which re.sub would interpret as
            # replacement-template escapes (mangling them or raising).
            newtext = text.replace(f"{{{key}}}", value)
            if newtext != text:
                changed = True
                text = newtext
    return text
|
||||
|
||||
|
||||
def parse_args():
    """Parse the command-line options of the SM90 kernel generator."""
    parser = argparse.ArgumentParser(
        description=
        "The argument for generating the generic_mixed_gemm_kernelLauncher instance."
    )
    # One or more SM architectures, e.g. --cuda_arch 90.
    parser.add_argument(
        "--cuda_arch",
        type=str,
        nargs="+",
        default=["90"],
        help="The CUDA architecture to be generated.",
    )
    return parser.parse_args()
|
||||
|
||||
|
||||
# generate source .cu
def generate_source_cu(
    inputs_type: (str),
    outputs_type: (str),
    hasbiases: (str),
    act_tag: (str),
    tiles: (str),
    KernelSchedule: (str),
    EpilogueSchedule: (str),
    TileSchedule: (str),
    sm: str,
):
    """
    Render the autogen .cu source that explicitly instantiates
    ``dispatch_fuse_block_gemm_c3x`` for every valid combination of input
    dtype, output dtype, bias flag, tile shape and schedule.

    Args:
        inputs_type: Input dtype names (e.g. "float8_e4m3fn").
        outputs_type: Output dtype names (e.g. "float16").
        hasbiases: "false"/"true" bias options.
        act_tag: (tag, cutlass activation functor name) pair.
        tiles: (TileShape, ClusterShape) string pairs.
        KernelSchedule: cutlass kernel schedule names.
        EpilogueSchedule: cutlass epilogue schedule names.
        TileSchedule: cutlass tile scheduler names.
        sm: Full cutlass arch tag, e.g. "cutlass::arch::Sm90".

    Returns:
        The complete C++ source text for the generated file.
    """
    all_code = CommonHead

    # Cartesian product over every template parameter; invalid schedule /
    # tile combinations are filtered by check_config_valid.
    for input_type in inputs_type:
        for output_type in outputs_type:
            for hasbias in hasbiases:
                for tile_config in tiles:
                    for kernel_schedule in KernelSchedule:
                        for epilogue_schedule in EpilogueSchedule:
                            for tile_schedule in TileSchedule:
                                if not check_config_valid(
                                        tile_config, kernel_schedule,
                                        epilogue_schedule, tile_schedule):
                                    continue
                                value_dict = {
                                    "input_type": input_type,
                                    "output_type": output_type,
                                    "hasbias": hasbias,
                                    "Activation": act_tag[1],
                                    "TileShape": tile_config[0],
                                    "ClusterShape": tile_config[1],
                                    "KernelSchedule": kernel_schedule,
                                    "EpilogueSchedule": epilogue_schedule,
                                    "TileSchedule": tile_schedule,
                                    "SM": sm,
                                    # "90" from "cutlass::arch::Sm90".
                                    "sm": sm[-2:],
                                }
                                all_code += SubstituteTemplate(
                                    GemmDeclare, value_dict)

    return all_code
|
||||
|
||||
|
||||
# generate gemm launch .cu
def generate_launch_gemm_cus(
        generate_dir: (str), inputs_type: (str), outputs_type: (str),
        fuse_gemm_configs: tuple, sm: str):
    """
    Emit the per-config launcher sources so each kernel config compiles in
    its own translation unit (enables parallel compilation).

    Writes ``launch_block_gemm_kernel_sm<sm>.h`` (declarations for every
    valid config) plus one ``.cu`` per config into *generate_dir*.

    Args:
        generate_dir: Output directory; created if missing.
        inputs_type: Input dtype names.
        outputs_type: Output dtype names.
        fuse_gemm_configs: Candidate configs from get_candidate_configs.
        sm: Full cutlass arch tag, e.g. "cutlass::arch::Sm90".

    Returns:
        (header source text, {gemm_config_str: launcher source text}).
    """
    act_tags = [single_config[1] for single_config in fuse_gemm_configs]

    # All candidate configs share everything except the activation, so the
    # shape/schedule lists are taken from the first entry.
    single_config = fuse_gemm_configs[0]
    hasbiases: (str) = single_config[0]
    tiles: (str) = single_config[2]
    KernelSchedule: (str) = single_config[3]
    EpilogueSchedule: (str) = single_config[4]
    TileSchedule: (str) = single_config[5]
    code_map = {}
    head_path = os.path.join(generate_dir,
                             f"launch_block_gemm_kernel_sm{sm[-2:]}.h")
    head_all_code = LaunchGemmHead
    # Pass 1: collect a forward declaration for every valid config.
    for tile_config in tiles:
        blocks, clusters = get_shape_str(tile_config)
        gemm_config_str_0 = f"tile{blocks[0]}x{blocks[1]}x{blocks[2]}_cluster{clusters[0]}x{clusters[1]}x{clusters[2]}"
        for kernel_schedule in KernelSchedule:
            gemm_config_str_1 = gemm_config_str_0 + f"_{kernel_schedule}"
            for epilogue_schedule in EpilogueSchedule:
                gemm_config_str_2 = gemm_config_str_1 + f"_{epilogue_schedule}"
                for tile_schedule in TileSchedule:
                    if not check_config_valid(tile_config, kernel_schedule,
                                              epilogue_schedule,
                                              tile_schedule):
                        continue
                    gemm_config_str = gemm_config_str_2 + f"_{tile_schedule}"
                    value_dict = {
                        "sm": sm[-2:],
                        # "<"/">" are not valid in C identifiers.
                        "gemm_config":
                        gemm_config_str.replace("<", "").replace(">", ""),
                    }
                    head_all_code += SubstituteTemplate(
                        LaunchGemmDeclare, value_dict)
    os.makedirs(generate_dir, exist_ok=True)
    with open(head_path, "w") as f:
        f.write(head_all_code)
        f.close()  # NOTE(review): redundant inside "with".

    # Pass 2: emit one .cu launcher per valid config, switching over the
    # dtype/bias/activation type id inside each launcher.
    for tile_shape in tiles:
        blocks, clusters = get_shape_str(tile_shape)
        gemm_config_str_0 = f"tile{blocks[0]}x{blocks[1]}x{blocks[2]}_cluster{clusters[0]}x{clusters[1]}x{clusters[2]}"
        for kernel_schedule in KernelSchedule:
            gemm_config_str_1 = gemm_config_str_0 + f"_{kernel_schedule}"
            for epilogue_schedule in EpilogueSchedule:
                gemm_config_str_2 = gemm_config_str_1 + f"_{epilogue_schedule}"
                for tile_schedule in TileSchedule:
                    if not check_config_valid(tile_shape, kernel_schedule,
                                              epilogue_schedule,
                                              tile_schedule):
                        continue
                    gemm_config_str = gemm_config_str_2 + f"_{tile_schedule}"
                    value_dict = {
                        "sm": sm[-2:],
                        "gemm_config":
                        gemm_config_str.replace("<", "").replace(">", ""),
                    }
                    source_all_code = SubstituteTemplate(
                        LaunchGemmPart0, value_dict)
                    # type_id numbering must match generate_dispatch_gemm_cu.
                    type_id = 0
                    for input_type in inputs_type:
                        for output_type in outputs_type:
                            for act_tag in act_tags:
                                for hasbias in hasbiases:
                                    value_dict = {
                                        "type_id": str(type_id),
                                        "input_type": input_type,
                                        "output_type": output_type,
                                        "hasbias": hasbias,
                                        "Activation": act_tag[1],
                                        "TileShape": tile_shape[0],
                                        "ClusterShape": tile_shape[1],
                                        "KernelSchedule": kernel_schedule,
                                        "EpilogueSchedule": epilogue_schedule,
                                        "TileSchedule": tile_schedule,
                                        "SM": sm,
                                        "sm": sm[-2:],
                                    }
                                    source_all_code += SubstituteTemplate(
                                        LaunchGemmPart1, value_dict)
                                    type_id += 1
                    source_all_code += LaunchGemmPart2
                    gemm_config_str = gemm_config_str.replace("<", "").replace(
                        ">", "")
                    code_map[gemm_config_str] = source_all_code
                    source_path = os.path.join(
                        generate_dir,
                        f"launch_block_gemm_kernel_sm{sm[-2:]}_{gemm_config_str}.cu"
                    )
                    with open(source_path, "w") as f:
                        f.write(source_all_code)
                        f.close()  # NOTE(review): redundant inside "with".

    return head_all_code, code_map
|
||||
|
||||
|
||||
# generate fp8_fp8_gemm_scale_bias_act_sm90.cu
def generate_dispatch_gemm_cu(inputs_type: (str), outputs_type: (str),
                              fuse_gemm_configs: tuple, sm: str):
    """
    Render the top-level dispatch translation unit: the type-id map, the
    kernel-config map, the kernel-id dispatcher and the tune/run driver
    (``fp8_fp8_block_gemm_scale_bias_act``).

    Args:
        inputs_type: Input dtype names.
        outputs_type: Output dtype names.
        fuse_gemm_configs: Candidate configs from get_candidate_configs.
        sm: Full cutlass arch tag, e.g. "cutlass::arch::Sm90".

    Returns:
        The complete C++ source text for the dispatch file.
    """
    act_tags = [single_config[1] for single_config in fuse_gemm_configs]

    # Shapes/schedules are shared across configs; take them from the first.
    single_config = fuse_gemm_configs[0]
    hasbiases: (str) = single_config[0]
    tiles: (str) = single_config[2]
    KernelSchedule: (str) = single_config[3]
    EpilogueSchedule: (str) = single_config[4]
    TileSchedule: (str) = single_config[5]
    all_code = SubstituteTemplate(code_part0, {"sm": sm[-2:]})
    # Type-id map entries; the numbering must match the launcher files
    # produced by generate_launch_gemm_cus.
    type_id = 0
    for input_type in inputs_type:
        for output_type in outputs_type:
            for act_tag in act_tags:
                for hasbias in hasbiases:
                    value_dict = {
                        "act_tag": act_tag[0],
                        "input_type": input_type,
                        "output_type": output_type,
                        "hasbias": hasbias,
                        "type_id": str(type_id),
                    }
                    all_code += SubstituteTemplate(code_part1, value_dict)
                    type_id += 1

    all_code += code_part2
    # Kernel-config map entries, keyed by the human-readable config string.
    tile_id = 0
    for tile_shape in tiles:
        for kernel_schedule in KernelSchedule:
            for epilogue_schedule in EpilogueSchedule:
                for tile_schedule in TileSchedule:
                    if not check_config_valid(tile_shape, kernel_schedule,
                                              epilogue_schedule,
                                              tile_schedule):
                        continue
                    value_dict = {
                        "TileShape": tile_shape[0],
                        "ClusterShape": tile_shape[1],
                        "kernel_schedule": kernel_schedule,
                        "epilogue_schedule": epilogue_schedule,
                        "tile_schedule": tile_schedule,
                        "tile_id": str(tile_id),
                    }
                    all_code += SubstituteTemplate(code_part3, value_dict)
                    tile_id += 1
    all_code += SubstituteTemplate(code_part4, {"sm": sm[-2:]})
    # Dispatcher cases; must walk the configs in exactly the same order as
    # the loop above so tile_id values line up with the config map.
    tile_id = 0
    for tile_shape in tiles:
        blocks, clusters = get_shape_str(tile_shape)
        gemm_config_str_0 = f"tile{blocks[0]}x{blocks[1]}x{blocks[2]}_cluster{clusters[0]}x{clusters[1]}x{clusters[2]}"
        for kernel_schedule in KernelSchedule:
            gemm_config_str_1 = gemm_config_str_0 + f"_{kernel_schedule}"
            for epilogue_schedule in EpilogueSchedule:
                gemm_config_str_2 = gemm_config_str_1 + f"_{epilogue_schedule}"
                for tile_schedule in TileSchedule:
                    if not check_config_valid(tile_shape, kernel_schedule,
                                              epilogue_schedule,
                                              tile_schedule):
                        continue
                    gemm_config_str = gemm_config_str_2 + f"_{tile_schedule}"
                    value_dict = {
                        "sm": sm[-2:],
                        "tile_id": str(tile_id),
                        "gemm_config":
                        gemm_config_str.replace("<", "").replace(">", ""),
                    }
                    all_code += SubstituteTemplate(code_part5, value_dict)
                    tile_id += 1

    all_code += SubstituteTemplate(code_part6, {"sm": sm[-2:]})
    return all_code
|
||||
|
||||
|
||||
if __name__ == "__main__":
    args = parse_args()
    archs = args.cuda_arch
    # Supported fp8 input dtypes and half-precision output dtypes.
    inputs_type = (
        "float8_e4m3fn",
        "float8_e5m2",
    )
    outputs_type = ("float16", "bfloat16")
    # Only SM90 (Hopper) is supported by this generator.
    sm_dict = {"90": "cutlass::arch::Sm90"}

    for sm in archs:
        if sm == "90":
            fuse_gemm_configs = get_candidate_configs(sm)
            # One instantiation file per activation config.
            for fuse_gemm_config in fuse_gemm_configs:
                file_name = (
                    f"gpu_ops/cutlass_kernels/fp8_gemm_fused/autogen/"
                    f"generic_block_gemm_kernel_sm{sm}_{fuse_gemm_config[1][0]}.cu"
                )
                all_code = generate_source_cu(
                    inputs_type,
                    outputs_type,
                    fuse_gemm_config[0],
                    fuse_gemm_config[1],
                    fuse_gemm_config[2],
                    fuse_gemm_config[3],
                    fuse_gemm_config[4],
                    fuse_gemm_config[5],
                    sm_dict[sm],
                )
                file_dir = os.path.dirname(file_name)
                os.makedirs(file_dir, exist_ok=True)
                with open(file_name, "w") as f:
                    f.write(all_code)
                    f.close()  # NOTE(review): redundant inside "with".
            # Compile parallelization
            generate_launch_gemm_cus(
                "gpu_ops/cutlass_kernels/fp8_gemm_fused/autogen", inputs_type,
                outputs_type, fuse_gemm_configs, sm_dict[sm])

            # hard code for act_tag
            file_name = (f"gpu_ops/cutlass_kernels/fp8_gemm_fused/autogen/"
                         f"fp8_fp8_block_gemm_scale_bias_act_sm{sm}.cu")
            all_code = generate_dispatch_gemm_cu(
                inputs_type,
                outputs_type,
                fuse_gemm_configs,
                sm_dict[sm],
            )
            file_dir = os.path.dirname(file_name)
            os.makedirs(file_dir, exist_ok=True)
            with open(file_name, "w") as f:
                f.write(all_code)
                f.close()  # NOTE(review): redundant inside "with".
        else:
            raise ValueError(f"Unsupported SM: {sm}")
|
||||
673
custom_ops/utils/auto_gen_fp8_fp8_dual_gemm_fused_kernels.py
Normal file
673
custom_ops/utils/auto_gen_fp8_fp8_dual_gemm_fused_kernels.py
Normal file
@@ -0,0 +1,673 @@
|
||||
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""generate dual_gemm_fused_kernel code."""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import re
|
||||
|
||||
|
||||
def get_candidate_tiles():
    """
    Return every (threadblock, warp, mma) GemmShape string triple that the
    dual_gemm_fused_kernel generator should instantiate.
    """
    # All candidates share the same mma instruction shape.
    mma_shape = "<16, 8, 32>"
    shape_pairs = [
        ("<64, 64, 64>", "<32, 32, 64>"),
        ("<16, 32, 64>", "<16, 32, 64>"),
        ("<16, 64, 64>", "<16, 32, 64>"),
        ("<32, 128, 64>", "<32, 32, 64>"),
        ("<64, 128, 64>", "<32, 64, 64>"),
        ("<64, 64, 128>", "<32, 64, 64>"),
        ("<64, 128, 64>", "<64, 32, 64>"),
        ("<128, 64, 64>", "<64, 32, 64>"),
        ("<128, 128, 64>", "<64, 32, 64>"),
        ("<128, 128, 64>", "<64, 64, 64>"),
        ("<128, 128, 64>", "<128, 32, 64>"),
        ("<128, 256, 64>", "<64, 64, 64>"),
        ("<256, 128, 64>", "<64, 64, 64>"),
        ("<16, 256, 128>", "<16, 64, 128>"),
    ]
    return [(block, warp, mma_shape) for block, warp in shape_pairs]
|
||||
|
||||
|
||||
def get_dual_gemm_candidate_configs(sm, min_split_k, max_split_k, min_stages,
                                    max_stages):
    """
    Build the candidate configurations for the dual_gemm_fused_kernel.

    Returns one (stages, splitks, tiles, act_tag, hasbias) tuple per
    supported activation.
    """
    tile_list = get_candidate_tiles()

    stage_options = tuple(range(min_stages, max_stages + 1))
    split_k_options = tuple(range(min_split_k, max_split_k + 1))
    bias_options = ("false", "true")

    # Dual gemm fuses two GEMMs with a gated activation between them.
    activations = [
        ("swiglu", "LeftSiLUAndMul"),
        ("geglu", "LeftGELUAndMul"),
    ]
    return [(stage_options, split_k_options, tile_list, act, bias_options)
            for act in activations]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
# C++ code templates for the dual-gemm generator. ``{name}`` placeholders
# are filled in by SubstituteTemplate; literal braces in the C++ code are
# untouched because only known keys are substituted.
# ---------------------------------------------------------------------------

# this is a file's header part
CommonHead = """// Generated by generate_code_gemm_fused_kernels.py - Do not edit.

#pragma once

#include "fp8_gemm_fused/fuse_dual_gemm_{act_tag}_template.h"

"""

CommonTail = """

"""

# Explicit instantiation of the non-split-k dual-gemm dispatch.
GemmDeclare = """
template<>
bool dispatch_dual_gemm_{act_tag}<phi::dtype::{input_type}, phi::dtype::{output_type}, phi::dtype::{bias_type},
                                  cutlass::gemm::GemmShape{thread_block_shape}, cutlass::gemm::GemmShape{warp_shape},
                                  cutlass::gemm::GemmShape{mma_shape}, {num_stages}, {hasbias}, {SM}>(DualGemmEpilogueAllParams);


"""

# Explicit instantiation of the split-k dual-gemm dispatch (currently unused
# by the launcher template, which rejects split_k >= 2).
GemmSplitKDeclare = """
template<>
bool dispatch_dual_gemm_split_k_{act_tag}<phi::dtype::{input_type}, phi::dtype::{output_type}, phi::dtype::{bias_type},
                                          cutlass::gemm::GemmShape{thread_block_shape}, cutlass::gemm::GemmShape{warp_shape},
                                          cutlass::gemm::GemmShape{mma_shape}, {num_stages}, {hasbias}, {SM}>(DualGemmEpilogueAllParams);


"""

# Header shared by all generated per-config launcher files.
LaunchGemmHead = """
#pragma once

#include "fp8_gemm_fused/fp8_fp8_dual_gemm_scale_bias_act.h"

"""

# Forward declaration of one per-config launcher function.
LaunchGemmDeclare = """
bool launch_dual_gemm_kernel_{gemm_config}(const int type_id, const int split_k, DualGemmEpilogueAllParams params);
"""

# Launcher body, opening part: split-k guard plus switch on the type id.
LaunchGemmPart0 = """
#pragma once

#include "launch_dual_gemm_kernel.h"

bool launch_dual_gemm_kernel_{gemm_config}(const int type_id, const int split_k, DualGemmEpilogueAllParams params){
  if(split_k < 2){
    params.split_k = 1;
    switch (type_id) {
"""

# Launcher body, one case per type id.
LaunchGemmPart1 = """
      case {type_id}:
        return dispatch_dual_gemm_{act_tag}<phi::dtype::{input_type}, phi::dtype::{output_type}, phi::dtype::{bias_type},
                                            cutlass::gemm::GemmShape{thread_block_shape}, cutlass::gemm::GemmShape{warp_shape},
                                            cutlass::gemm::GemmShape{mma_shape}, {num_stages}, {hasbias}, {SM}>(params);
        break;
"""

# Launcher body, closing part: default branch and split-k rejection.
LaunchGemmPart2 = """
      default:
        throw std::runtime_error("cutlass gemm config is invalid.");
        break;
    }
  }else{
    throw std::runtime_error("cutlass dual gemm split_k mode is not generated.");
  }
  return false;
}
"""

# Dispatch translation unit, opening part: headers and the type-id map.
code_part0 = """// Generated by generate_code_gemm_fused_kernels.py - Do not edit.

#include <map>
#include <regex>
#include <limits>
#include "helper.h"
#include "fp8_gemm_fused/fp8_fp8_dual_gemm_scale_bias_act.h"
#include "launch_dual_gemm_kernel.h"

COMMON_DECLARE_string(use_cutlass_device_best_config_path);

std::map<std::string, int> dual_gemm_type_map{"""

# One entry of the type-id map.
code_part1 = """
    {"{input_type}_{output_type}_{bias_type}_{hasbias}_{act_tag}", {type_id}}, """

# Close the type-id map and open the kernel-config map.
code_part2 = """
};

std::map<std::string, int> dual_gemm_config_map{
"""

# One entry of the kernel-config map.
code_part3 = """    {"{thread_block_shape}, {warp_shape}, {mma_shape}, {num_stages}", {tile_id}},
"""

# Close the config map and open the kernel-id dispatcher.
code_part4 = """};

bool launch_gemm_kernel(const int type_id, const int split_k, const int kernel_id, DualGemmEpilogueAllParams params){
  switch (kernel_id) {"""

# One dispatcher case per generated kernel id.
code_part5 = """
    case {tile_id}:
      return launch_dual_gemm_kernel_{gemm_config}(type_id, split_k, params);
      break;"""

# Dispatcher epilogue plus the run/tune driver. {max_split_k} is filled in
# by the generator from the command line.
code_part6 = """
    default:
      throw std::runtime_error("fp8_fp8_bf16_gemm_fused Config is invalid.");
      break;
  }
  return false;
}

bool fp8_fp8_dual_gemm_scale_bias_act(DualGemmEpilogueAllParams params) {
  if (dual_gemm_type_map.find(params.fuse_gemm_config) == dual_gemm_type_map.end()) {
    throw std::runtime_error("fp8 gemm_fused config is invalid.");
  }

  int type_id = dual_gemm_type_map[params.fuse_gemm_config];
  int M = (params.M+31)/32 *32;
  int N = params.N;
  int K = params.K;

  std::string mnk_string = "dual_gemm<"+ std::to_string(M)+ ", " +std::to_string(N) + ", "+ std::to_string(K)+ ">";
  auto encoded_mnk_string = base64_encode(mnk_string);

  std::string mnk_split_k_string = "dual_gemm<"+ std::to_string(M)+ ", " +std::to_string(N) + ", "+ std::to_string(K)+ ">" + ", split_k";
  auto encoded_mnk_split_k_string = base64_encode(mnk_split_k_string);

  int split_k;
  int kernel_id;
  std::string best_config;
  CutlassGemmConfigMannager& best_config_mannager = CutlassGemmConfigMannager::getInstance();
  if(getenv("FLAGS_use_cutlass_device_best_config_path")){ // run kernel
    std::string config_file_path = getenv("FLAGS_use_cutlass_device_best_config_path");
    nlohmann::json* config_json = new nlohmann::json();
    if (config_file_path != "default") {
      config_json = best_config_mannager.get_gemm_best_configs(config_file_path);
    }

    best_config = get_relative_best<std::string>(config_json, encoded_mnk_string, "<64, 64, 64>, <32, 32, 64>, <16, 8, 32>, 3");
    split_k = get_relative_best<int>(config_json, encoded_mnk_split_k_string, 1);

    if (dual_gemm_config_map.find(best_config) == dual_gemm_config_map.end()) {
      throw std::runtime_error("This config'kernel not be generate, please check generate_code_gemm_fused_kernels.py and re-generate.");
    } else {
      kernel_id = dual_gemm_config_map[best_config];
    }
    return launch_gemm_kernel(type_id, split_k, kernel_id, params);
  } else { // tune kernel
    int warm_up_times = 5;
    int tune_times = 10;
    std::string best_kernel_id = "";
    int best_split_k = -1;
    float duratation = 1000000.f;
    // tune all split_k, kernel_id kernels
    for(int i = 1; i < {max_split_k}+1; ++i){ // all split_k
      for(const auto& config_pair : dual_gemm_config_map){
        bool is_valid = true;
        // warm up
        for(int num_time = 0; num_time < warm_up_times; ++num_time){
          if(!launch_gemm_kernel(type_id, i, config_pair.second, params)){
            is_valid = false;
            break;
          }
        }
        if(!is_valid){
          continue;
        }
        cudaEvent_t start, stop;
        cudaEventCreate(&start);
        cudaEventCreate(&stop);
        cudaStreamSynchronize(params.stream);
        cudaEventRecord(start, params.stream);
        for(int num_time = 0; num_time < tune_times; ++num_time){
          if(!launch_gemm_kernel(type_id, i, config_pair.second, params)){
            is_valid = false;
            break;
          };
        }
        cudaEventRecord(stop, params.stream);
        cudaEventSynchronize(stop);
        float elapsedTime;
        if(is_valid){
          cudaEventElapsedTime(&elapsedTime, start, stop);
        } else {
          continue;
        }
        cudaEventDestroy(start);
        cudaEventDestroy(stop);
        if(elapsedTime < duratation){
          best_kernel_id = config_pair.first;
          best_split_k = i;
          duratation = elapsedTime;
        }
      }
    }

    nlohmann::json new_json;
    new_json[encoded_mnk_string] = best_kernel_id;
    new_json[encoded_mnk_split_k_string] = best_split_k;
    best_config_mannager.up_date_configs(new_json);
    std::cout <<"Gemm tune result for " << mnk_string<< ": best config is: "<< best_kernel_id << ", split k: " << best_split_k << std::endl;
    return true;
  }
}
"""
|
||||
|
||||
|
||||
def SubstituteTemplate(template, values):
    """Expand ``{key}`` placeholders in *template* using *values*.

    Substitution repeats until a fixed point is reached, so replacement
    text that itself contains a known ``{other_key}`` placeholder is
    expanded as well.

    Args:
        template: Template text containing ``{key}`` markers.
        values: Mapping from placeholder name to replacement text.

    Returns:
        The template with every known placeholder expanded.
    """
    text = template
    changed = True
    while changed:
        changed = False
        for key, value in values.items():
            # Use plain string replacement instead of re.sub: re.sub treats
            # backslashes in the replacement string as escape sequences,
            # which corrupts (or raises "bad escape" on) substituted C++
            # snippets that contain backslashes.
            placeholder = "{" + key + "}"
            newtext = text.replace(placeholder, value)
            if newtext != text:
                changed = True
                text = newtext
    return text
|
||||
|
||||
|
||||
def check_min_split_k(value):
    """argparse type-checker for ``--min_split_k``.

    Dual-gemm kernels do not implement split-K, so any value above 1 is
    rejected at argument-parsing time.

    Args:
        value: Raw command-line string.

    Returns:
        The parsed integer (must be <= 1).

    Raises:
        argparse.ArgumentTypeError: If the parsed value is greater than 1.
    """
    ivalue = int(value)
    if ivalue > 1:
        # Fixed user-facing message grammar ("not support" -> "not supported").
        raise argparse.ArgumentTypeError(
            "Dual gemm split_k mode is not supported.")
    return ivalue
|
||||
|
||||
|
||||
def check_max_split_k(value):
    """argparse type-checker for ``--max_split_k``.

    Dual-gemm kernels do not implement split-K, so any value above 1 is
    rejected at argument-parsing time.

    Args:
        value: Raw command-line string.

    Returns:
        The parsed integer (must be <= 1).

    Raises:
        argparse.ArgumentTypeError: If the parsed value is greater than 1.
    """
    ivalue = int(value)
    if ivalue > 1:
        # Fixed user-facing message grammar ("not support.." -> "not supported.").
        raise argparse.ArgumentTypeError(
            "Dual gemm split_k mode is not supported.")
    return ivalue
|
||||
|
||||
|
||||
def parse_args():
    """Build and evaluate the command-line interface for this generator.

    Returns:
        argparse.Namespace with ``cuda_arch``, ``min_split_k``,
        ``max_split_k``, ``min_stages`` and ``max_stages``.
    """
    parser = argparse.ArgumentParser(
        description=
        "The argument for generating the generic_mixed_gemm_kernelLauncher instance."
    )
    # (flag, add_argument keyword options) pairs; registered in order below.
    option_specs = [
        ("--cuda_arch",
         dict(type=str,
              nargs="+",
              default=["89"],
              help="The CUDA architecture to be generated.")),
        ("--min_split_k",
         dict(type=check_min_split_k,
              default=1,
              help="The min split k for the gemm kernel.")),
        ("--max_split_k",
         dict(type=check_max_split_k,
              default=1,
              help="The max split k for the gemm kernel.")),
        ("--min_stages",
         dict(type=int,
              default=3,
              help="The min stages for the gemm kernel.")),
        ("--max_stages",
         dict(type=int,
              default=8,
              help="The max stages for the gemm kernel.")),
    ]
    for flag, options in option_specs:
        parser.add_argument(flag, **options)
    return parser.parse_args()
|
||||
|
||||
|
||||
# generate source .cu
def generate_dual_gemm_source_cu(
    inputs_type: tuple,
    outputs_type: tuple,
    biases_type: tuple,
    stages: tuple,
    tiles: tuple,
    act_tag: str,
    hasbiases: tuple,
    sm: str,
    min_split_k: int,
    max_split_k: int,
):
    """
    generate_dual_gemm_source_cu

    Render the C++ source that explicitly instantiates every dual-gemm
    kernel variant for one activation, by expanding the module-level
    templates (CommonHead, GemmDeclare, GemmSplitKDeclare, CommonTail).

    Args:
        inputs_type: Input dtype names; each is also used as the output type.
        outputs_type: Unused in this function (kept for call-site compatibility).
        biases_type: Bias dtype names.
        stages: Pipeline stage counts to instantiate.
        tiles: (thread_block, warp, mma) shape-string triples.
        act_tag: Activation tag substituted into every template.
        hasbiases: "true"/"false" option strings.
        sm: CUTLASS arch tag, e.g. "cutlass::arch::Sm89".
        min_split_k: Split-K declarations are emitted only when this AND
            max_split_k are both > 1.
        max_split_k: See min_split_k.

    Returns:
        str: the full contents of the generated .cu file.
    """
    value_dict = {
        "act_tag": act_tag,
    }
    all_code = SubstituteTemplate(CommonHead, value_dict)

    # One GemmDeclare instantiation per (input, bias, stage, hasbias, tile).
    for input_type in inputs_type:
        for bias_type in biases_type:
            for stage in stages:
                for hasbias in hasbiases:
                    for tile_config in tiles:
                        value_dict = {
                            "input_type": input_type,
                            # Output dtype mirrors the input dtype.
                            "output_type": input_type,
                            "bias_type": bias_type,
                            "thread_block_shape": tile_config[0],
                            "warp_shape": tile_config[1],
                            "mma_shape": tile_config[2],
                            "num_stages": str(stage),
                            "act_tag": act_tag,
                            "hasbias": hasbias,
                            "SM": sm,
                        }
                        all_code += SubstituteTemplate(GemmDeclare, value_dict)

    # Split-K variants are generated only when the whole range lies above 1.
    if min_split_k > 1 and max_split_k > 1:
        for input_type in inputs_type:
            for bias_type in biases_type:
                for stage in stages:
                    for hasbias in hasbiases:
                        for tile_config in tiles:
                            value_dict = {
                                "input_type": input_type,
                                "output_type": input_type,
                                "bias_type": bias_type,
                                "thread_block_shape": tile_config[0],
                                "warp_shape": tile_config[1],
                                "mma_shape": tile_config[2],
                                "num_stages": str(stage),
                                "act_tag": act_tag,
                                "hasbias": hasbias,
                                "SM": sm,
                            }
                            all_code += SubstituteTemplate(
                                GemmSplitKDeclare, value_dict)

    all_code += CommonTail
    return all_code
|
||||
|
||||
|
||||
# generate gemm launch .cu
def generate_launch_dual_gemm_cus(
    generate_dir: str,
    inputs_type: tuple,
    outputs_type: tuple,
    stages: tuple,
    split_ks: tuple,
    tiles: tuple,
    act_tags: tuple,
    hasbiases: tuple,
    sm: str,
    min_split_k: int,
    max_split_k: int,
):
    """
    generate_launch_dual_gemm_cus

    Write one header declaring every per-tile launch function, plus one .cu
    file per (tile, stage) combination, so generated kernels compile in
    parallel.

    NOTE(review): this function reads the module-level global ``biases_type``
    (defined in the ``__main__`` block) instead of taking it as a parameter,
    so it only works when invoked from this script's entry point.
    ``outputs_type``, ``split_ks``, ``min_split_k`` and ``max_split_k`` are
    currently unused (the split-K emission below is commented out).

    Returns:
        (head_all_code, code_map): the header text, and a dict mapping each
        per-tile config string to its generated .cu source.
    """
    code_map = {}
    head_path = os.path.join(generate_dir, "launch_dual_gemm_kernel.h")
    head_all_code = LaunchGemmHead
    # Pass 1: declare every launch_dual_gemm_kernel_<config> symbol.
    for tile in tiles:
        # Each tile entry is three "<M, N, K>"-style strings; split each
        # into its numeric components.
        blocks, warps, mmas = [
            s.replace(" ", "").strip("<>").split(",") for s in tile
        ]
        gemm_config = (f"block{blocks[0]}x{blocks[1]}x{blocks[2]}_"
                       f"warp{warps[0]}x{warps[1]}x{warps[2]}_"
                       f"mma{mmas[0]}x{mmas[1]}x{mmas[2]}")
        for stage in stages:
            gemm_config_str = gemm_config + f"_stage{stage}"
            value_dict = {
                "gemm_config": gemm_config_str,
            }
            head_all_code += SubstituteTemplate(LaunchGemmDeclare, value_dict)
    os.makedirs(generate_dir, exist_ok=True)
    with open(head_path, "w") as f:
        f.write(head_all_code)
        f.close()  # redundant: the with-block already closes the file

    # Pass 2: emit one .cu per (tile, stage) containing a type_id switch.
    for tile in tiles:
        blocks, warps, mmas = [
            s.replace(" ", "").strip("<>").split(",") for s in tile
        ]
        gemm_config = (f"block{blocks[0]}x{blocks[1]}x{blocks[2]}_"
                       f"warp{warps[0]}x{warps[1]}x{warps[2]}_"
                       f"mma{mmas[0]}x{mmas[1]}x{mmas[2]}")
        for stage in stages:
            gemm_config_str = gemm_config + f"_stage{stage}"
            value_dict = {
                "gemm_config": gemm_config_str,
            }
            source_all_code = SubstituteTemplate(LaunchGemmPart0, value_dict)
            # split_k_code = ""
            # type_id enumerates (input, bias, act, hasbias) combinations in
            # the same fixed order as the dispatch map built elsewhere.
            type_id = 0
            for input_type in inputs_type:
                for bias_type in biases_type:  # NOTE(review): module-level global, see docstring
                    for act_tag in act_tags:
                        for hasbias in hasbiases:
                            value_dict = {
                                "act_tag": act_tag,
                                "input_type": input_type,
                                "output_type": input_type,
                                "bias_type": bias_type,
                                "hasbias": hasbias,
                                "type_id": str(type_id),
                                "thread_block_shape": tile[0],
                                "warp_shape": tile[1],
                                "mma_shape": tile[2],
                                "num_stages": str(stage),
                                "SM": sm,
                            }
                            source_all_code += SubstituteTemplate(
                                LaunchGemmPart1, value_dict)
                            # split_k_code += SubstituteTemplate(LaunchGemmPart3, value_dict)
                            type_id += 1
            source_all_code += LaunchGemmPart2
            # source_all_code += split_k_code
            # source_all_code += LaunchGemmPart4
            code_map[gemm_config_str] = source_all_code
            source_path = os.path.join(
                generate_dir, f"launch_dual_gemm_kernel_{gemm_config_str}.cu")
            with open(source_path, "w") as f:
                f.write(source_all_code)
                f.close()  # redundant: the with-block already closes the file

    return head_all_code, code_map
|
||||
|
||||
|
||||
# generate fp8_fp8_gemm_scale_bias_act.cu
def generate_dispatch_dual_gemm_cu(
    inputs_type: tuple,
    outputs_type: tuple,
    biases_type: tuple,
    stages: tuple,
    split_ks: tuple,
    tiles: tuple,
    act_tags: tuple,
    hasbiases: tuple,
    sm: str,
    min_split_k: int,
    max_split_k: int,
):
    """
    generate_dispatch_dual_gemm_cu

    Render the dispatch translation unit — the type-id map, the kernel-id
    map, the kernel_id switch cases, and the tune/run driver — by expanding
    the module-level code_part0..code_part6 templates.

    Args:
        inputs_type: Input dtype names (also used as the output dtype).
        outputs_type: Unused here (kept for call-site compatibility).
        biases_type: Bias dtype names.
        stages: Pipeline stage counts.
        split_ks: Unused here; the split-K range is conveyed through
            min_split_k/max_split_k instead.
        tiles: (thread_block, warp, mma) shape-string triples.
        act_tags: Activation tag strings.
        hasbiases: "true"/"false" option strings.
        sm: CUTLASS arch tag.
        min_split_k: Lower split-K bound substituted into code_part6.
        max_split_k: Upper split-K bound substituted into code_part6.

    Returns:
        str: contents of the generated dispatch .cu file.
    """
    all_code = code_part0
    # Map "<dtype>_<dtype>_<bias>_<hasbias>_<act>" -> type_id, fixed order.
    type_id = 0
    for input_type in inputs_type:
        for bias_type in biases_type:
            for act_tag in act_tags:
                for hasbias in hasbiases:
                    value_dict = {
                        "act_tag": act_tag,
                        "input_type": input_type,
                        "output_type": input_type,
                        "bias_type": bias_type,
                        "hasbias": hasbias,
                        "type_id": str(type_id),
                    }
                    all_code += SubstituteTemplate(code_part1, value_dict)
                    type_id += 1

    all_code += code_part2
    # Map tile/stage config string -> tile_id.
    tile_id = 0
    for tile in tiles:
        for stage in stages:
            value_dict = {
                "thread_block_shape": tile[0],
                "warp_shape": tile[1],
                "mma_shape": tile[2],
                "num_stages": str(stage),
                "tile_id": str(tile_id),
            }
            all_code += SubstituteTemplate(code_part3, value_dict)
            tile_id += 1
    all_code += code_part4

    # Switch cases must be emitted in the same order as the map above so
    # tile_id stays aligned.
    tile_id = 0
    for tile in tiles:
        blocks, warps, mmas = [
            s.replace(" ", "").strip("<>").split(",") for s in tile
        ]
        gemm_config = (f"block{blocks[0]}x{blocks[1]}x{blocks[2]}_"
                       f"warp{warps[0]}x{warps[1]}x{warps[2]}_"
                       f"mma{mmas[0]}x{mmas[1]}x{mmas[2]}")
        for stage in stages:
            gemm_config_str = gemm_config + f"_stage{stage}"
            value_dict = {
                "tile_id": str(tile_id),
                "gemm_config": gemm_config_str,
            }
            all_code += SubstituteTemplate(code_part5, value_dict)
            tile_id += 1
    # NOTE(review): reuses the value_dict left over from the final loop
    # iteration; if tiles/stages were empty this would raise NameError.
    value_dict.update({
        "min_split_k": str(min_split_k),
        "max_split_k": str(max_split_k),
    })
    all_code += SubstituteTemplate(code_part6, value_dict)
    return all_code
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Entry point: generate every SM89 dual-gemm kernel source, the
    # per-tile launch wrappers, and the top-level dispatch file.
    args = parse_args()
    archs = args.cuda_arch
    min_split_k = args.min_split_k
    max_split_k = args.max_split_k
    min_stages = args.min_stages
    max_stages = args.max_stages
    inputs_type = ("float8_e4m3fn", "float8_e5m2")
    biases_type = ("float16", "bfloat16")
    outputs_type = ("float8_e4m3fn", "float8_e4m3fn")
    sm_dict = {"89": "cutlass::arch::Sm89", "90": "cutlass::arch::Sm90"}

    for sm in archs:
        if sm == "89":
            fuse_gemm_configs = get_dual_gemm_candidate_configs(
                sm, min_split_k, max_split_k, min_stages, max_stages)
            # One source file per activation; fuse_gemm_config[3][0] is the
            # activation tag used in the file name.
            for fuse_gemm_config in fuse_gemm_configs:
                file_name = (
                    f"gpu_ops/cutlass_kernels/fp8_gemm_fused/"
                    f"autogen/generic_dual_gemm_kernel_sm{sm}_{fuse_gemm_config[3][0]}.cu"
                )
                all_code = generate_dual_gemm_source_cu(
                    inputs_type,
                    outputs_type,
                    biases_type,
                    fuse_gemm_config[0],
                    fuse_gemm_config[2],
                    fuse_gemm_config[3][0],
                    fuse_gemm_config[4],
                    sm_dict[sm],
                    min_split_k,
                    max_split_k,
                )
                file_dir = os.path.dirname(file_name)
                os.makedirs(file_dir, exist_ok=True)
                with open(file_name, "w") as f:
                    f.write(all_code)

            fuse_gemm_config = list(fuse_gemm_configs)[0]

            act_tags = ["swiglu", "geglu"]

            # Compile parallelization
            generate_launch_dual_gemm_cus(
                "gpu_ops/cutlass_kernels/fp8_gemm_fused/autogen",
                inputs_type,
                outputs_type,
                fuse_gemm_config[0],
                fuse_gemm_config[1],
                fuse_gemm_config[2],
                act_tags,
                fuse_gemm_config[4],
                sm_dict[sm],
                min_split_k,
                max_split_k,
            )
            # hard code for act_tag
            file_name = "gpu_ops/cutlass_kernels/fp8_gemm_fused/autogen/fp8_fp8_dual_gemm_scale_bias_act.cu"
            all_code = generate_dispatch_dual_gemm_cu(
                inputs_type,
                outputs_type,
                biases_type,
                fuse_gemm_config[0],
                fuse_gemm_config[1],
                fuse_gemm_config[2],
                act_tags,
                fuse_gemm_config[4],
                sm_dict[sm],
                min_split_k,
                max_split_k,
            )
            file_dir = os.path.dirname(file_name)
            os.makedirs(file_dir, exist_ok=True)
            with open(file_name, "w") as f:
                f.write(all_code)

        elif sm == "90":
            # BUGFIX: the arch values parsed from --cuda_arch are strings,
            # so the original comparison `sm == 90` (an int literal) could
            # never match and "90" incorrectly fell through to the
            # ValueError branch below.
            print("Not supported yet.")
            exit(0)
        else:
            raise ValueError(f"Unsupported SM: {sm}")
|
||||
@@ -0,0 +1,592 @@
|
||||
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""generate dual_gemm_fused_kernel code."""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import re
|
||||
|
||||
|
||||
def get_candidate_tiles():
    """Enumerate candidate (CTA tile shape, cluster shape) pairs for SM90.

    Returns the cross product of the enabled tile shapes and cluster
    shapes, as a list of 2-tuples of "<_M, _N, _K>"-style strings.
    """
    tile_shapes = (
        "<_64, _16, _128>",
        "<_64, _32, _128>",
        "<_64, _64, _128>",
        "<_64, _128, _128>",
        "<_128, _16, _128>",
        "<_128, _32, _128>",
        "<_128, _64, _128>",
        # "<_128, _128, _128>",
    )
    cluster_shapes = (
        "<_1, _1, _1>",
        "<_2, _1, _1>",
        "<_1, _2, _1>",
        # "<_2, _2, _1>",
        # "<_1, _8, _1>",
        # "<_8, _1, _1>",
    )
    pairs = []
    for tile in tile_shapes:
        for cluster in cluster_shapes:
            pairs.append((tile, cluster))
    return pairs
|
||||
|
||||
|
||||
def get_dual_gemm_candidate_configs(sm):
    """Build one candidate-config tuple per activation for SM90 dual-gemm.

    Each returned entry is
    (hasbias options, (act_tag, activation functor), tile pairs,
    kernel schedules, epilogue schedules).  *sm* is accepted for interface
    symmetry with the other generators but is not consulted here.
    """
    tile_pairs = get_candidate_tiles()
    bias_options = (
        "false",
        # "true",
    )
    kernel_schedules = (
        # "KernelTmaWarpSpecializedFP8FastAccum",
        "KernelTmaWarpSpecializedPingpongFP8FastAccum",
        "KernelTmaWarpSpecializedCooperativeFP8FastAccum",
    )
    epilogue_schedules = ("TmaWarpSpecialized", "TmaWarpSpecializedCooperative")
    activations = [
        ("swiglu", "SiLu"),
        ("geglu", "GELU"),
    ]
    return [(bias_options, act, tile_pairs, kernel_schedules,
             epilogue_schedules) for act in activations]
|
||||
|
||||
|
||||
def get_shape_str(tile_shape):
    """Parse a (tile, cluster) pair of "<_M, _N, _K>" strings into digits.

    Example:
        ("<_64, _16, _128>", "<_1, _1, _1>")
        -> (["64", "16", "128"], ["1", "1", "1"])
    """
    def _dims(spec):
        # "<_64, _16, _128>" -> ["64", "16", "128"]
        compact = spec.replace(" ", "").strip("<>")
        return [part.strip("_") for part in compact.split(",")]

    tile_spec, cluster_spec = tile_shape
    return _dims(tile_spec), _dims(cluster_spec)
|
||||
|
||||
|
||||
def check_config_valid(tile_shape, kernel_schedule, epilogue_schedule):
    """Reject (tile, kernel schedule, epilogue schedule) combos not built.

    Encodes the same three constraints as the original: cooperative FP8
    kernels need a CTA M-dimension of at least 128 and a cooperative
    epilogue, and the largest tile is not paired with the ping-pong
    schedule.
    """
    tile_spec = tile_shape[0]
    # Extract the CTA M dimension from "<_M, _N, _K>".
    block_m = int(
        tile_spec.replace(" ", "").strip("<>").split(",")[0].strip("_"))
    cooperative = "KernelTmaWarpSpecializedCooperativeFP8FastAccum"
    pingpong = "KernelTmaWarpSpecializedPingpongFP8FastAccum"
    if kernel_schedule == cooperative and block_m < 128:
        return False
    if "Cooperative" in kernel_schedule and "Cooperative" not in epilogue_schedule:
        return False
    if kernel_schedule == pingpong and tile_spec == "<_128, _128, _128>":
        return False
    return True
|
||||
|
||||
|
||||
# this is a file's header part
|
||||
CommonHead = """// Generated by auto_gen_fp8_fp8_dual_gemm_fused_kernels_sm90.py - Do not edit.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "fp8_gemm_fused/fuse_dual_gemm_act_template_3x.h"
|
||||
|
||||
using namespace cute;
|
||||
"""
|
||||
|
||||
GemmDeclare = """
|
||||
template<>
|
||||
bool dispatch_dual_gemm_act_sm{sm}<phi::dtype::{input_type},
|
||||
Shape{TileShape},
|
||||
Shape{ClusterShape},
|
||||
cutlass::gemm::{KernelSchedule},
|
||||
cutlass::epilogue::{EpilogueSchedule},
|
||||
void,
|
||||
cutlass::epilogue::thread::{Activation}>(DualGemmEpilogueAllParams);
|
||||
|
||||
"""
|
||||
|
||||
LaunchGemmHead = """
|
||||
#pragma once
|
||||
|
||||
#include "fp8_gemm_fused/fp8_fp8_dual_gemm_scale_bias_act.h"
|
||||
|
||||
"""
|
||||
|
||||
LaunchGemmDeclare = """
|
||||
bool launch_dual_gemm_kernel_sm{sm}_{gemm_config}(const int type_id, DualGemmEpilogueAllParams params);
|
||||
"""
|
||||
|
||||
LaunchGemmPart0 = """
|
||||
#pragma once
|
||||
|
||||
#include "launch_dual_gemm_kernel_sm{sm}.h"
|
||||
|
||||
bool launch_dual_gemm_kernel_sm{sm}_{gemm_config}(const int type_id, DualGemmEpilogueAllParams params){
|
||||
using namespace cute;
|
||||
switch (type_id) {
|
||||
"""
|
||||
|
||||
LaunchGemmPart1 = """
|
||||
case {type_id}:
|
||||
return dispatch_dual_gemm_act_sm{sm}<phi::dtype::{input_type},
|
||||
Shape{TileShape},
|
||||
Shape{ClusterShape},
|
||||
cutlass::gemm::{KernelSchedule},
|
||||
cutlass::epilogue::{EpilogueSchedule},
|
||||
void,
|
||||
cutlass::epilogue::thread::{Activation}>(params);
|
||||
break;
|
||||
"""
|
||||
|
||||
LaunchGemmPart2 = """
|
||||
default:
|
||||
throw std::runtime_error("cutlass gemm config is invalid.");
|
||||
break;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
"""
|
||||
|
||||
code_part0 = """// Generated by auto_gen_fp8_fp8_dual_gemm_fused_kernels_sm90.py - Do not edit.
|
||||
|
||||
#include <map>
|
||||
#include "fp8_gemm_fused/fp8_fp8_dual_gemm_scale_bias_act.h"
|
||||
#include "launch_dual_gemm_kernel_sm{sm}.h"
|
||||
|
||||
COMMON_DECLARE_string(use_cutlass_device_best_config_path);
|
||||
|
||||
std::map<std::string, int> dual_gemm_type_map{"""
|
||||
|
||||
code_part1 = """
|
||||
{"{input_type}_{output_type}_{bias_type}_{hasbias}_{act_tag}", {type_id}}, """
|
||||
|
||||
code_part2 = """
|
||||
};
|
||||
|
||||
std::map<std::string, int> dual_gemm_config_map{
|
||||
"""
|
||||
|
||||
code_part3 = """ {"{TileShape}, {ClusterShape}, {kernel_schedule}, {epilogue_schedule}", {tile_id}},
|
||||
"""
|
||||
|
||||
code_part4 = """};
|
||||
|
||||
bool launch_dual_gemm_kernel_sm{sm}(const int type_id, const int kernel_id, DualGemmEpilogueAllParams params){
|
||||
switch (kernel_id) {"""
|
||||
|
||||
code_part5 = """
|
||||
case {tile_id}:
|
||||
return launch_dual_gemm_kernel_sm{sm}_{gemm_config}(type_id, params);
|
||||
break;"""
|
||||
|
||||
code_part6 = """
|
||||
default:
|
||||
throw std::runtime_error("fp8 dual gemm_fused Config is invalid.");
|
||||
break;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
T get_relative_best(nlohmann::json* json_data,
|
||||
const std::string& target_key,
|
||||
const int& m) {
|
||||
if (json_data->contains(target_key)) {
|
||||
return json_data->at(target_key);
|
||||
} else {
|
||||
if (m <= 6400){
|
||||
return "<_64, _32, _128>, <_1, _1, _1>, KernelTmaWarpSpecializedPingpongFP8FastAccum, TmaWarpSpecialized";
|
||||
}else{
|
||||
return "<_64, _64, _128>, <_1, _1, _1>, KernelTmaWarpSpecializedPingpongFP8FastAccum, TmaWarpSpecialized";
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
bool fp8_fp8_dual_gemm_scale_bias_act(DualGemmEpilogueAllParams params) {
|
||||
if (dual_gemm_type_map.find(params.fuse_gemm_config) == dual_gemm_type_map.end()) {
|
||||
throw std::runtime_error("fp8 dual gemm_fused config is invalid.");
|
||||
}
|
||||
|
||||
int type_id = dual_gemm_type_map[params.fuse_gemm_config];
|
||||
int M = (params.M + 31) / 32 * 32;
|
||||
int N = params.N;
|
||||
int K = params.K;
|
||||
|
||||
int kernel_id;
|
||||
std::string mnk_string = "tensor_dual_gemm_sm90<"+ std::to_string(M)+ ", " +std::to_string(N) + ", "+ std::to_string(K)+ ">";
|
||||
std::string best_config;
|
||||
CutlassGemmConfigMannager& best_config_mannager = CutlassGemmConfigMannager::getInstance();
|
||||
char *config_file_path_c_str = getenv("FLAGS_use_cutlass_device_best_config_path");
|
||||
std::string config_file_path = config_file_path_c_str == nullptr ? "" : std::string(config_file_path_c_str);
|
||||
if(config_file_path == "tune"){ // tune kernel
|
||||
int warm_up_times = 5;
|
||||
int tune_times = 10;
|
||||
std::string best_kernel_id = "";
|
||||
float duratation = 1000000.f;
|
||||
// tune all kernel_id kernels
|
||||
for(const auto& config_pair : dual_gemm_config_map){
|
||||
bool is_valid = true;
|
||||
// warm up
|
||||
for(int num_time = 0; num_time < warm_up_times; ++num_time){
|
||||
if(!launch_dual_gemm_kernel_sm{sm}(type_id, config_pair.second, params)){
|
||||
is_valid = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if(!is_valid){
|
||||
continue;
|
||||
}
|
||||
cudaEvent_t start, stop;
|
||||
cudaEventCreate(&start);
|
||||
cudaEventCreate(&stop);
|
||||
cudaStreamSynchronize(params.stream);
|
||||
cudaEventRecord(start, params.stream);
|
||||
for(int num_time = 0; num_time < tune_times; ++num_time){
|
||||
if(!launch_dual_gemm_kernel_sm{sm}(type_id, config_pair.second, params)){
|
||||
is_valid = false;
|
||||
break;
|
||||
};
|
||||
}
|
||||
cudaEventRecord(stop, params.stream);
|
||||
cudaEventSynchronize(stop);
|
||||
float elapsedTime;
|
||||
if(is_valid){
|
||||
cudaEventElapsedTime(&elapsedTime, start, stop);
|
||||
} else {
|
||||
continue;
|
||||
}
|
||||
cudaEventDestroy(start);
|
||||
cudaEventDestroy(stop);
|
||||
if(elapsedTime < duratation){
|
||||
best_kernel_id = config_pair.first;
|
||||
duratation = elapsedTime;
|
||||
}
|
||||
}
|
||||
|
||||
nlohmann::json new_json;
|
||||
new_json[mnk_string] = best_kernel_id;
|
||||
best_config_mannager.up_date_configs(new_json);
|
||||
std::cout <<"Gemm tune result for " << mnk_string<< ": best config is: "<< best_kernel_id << std::endl;
|
||||
return true;
|
||||
} else { // run kernel
|
||||
nlohmann::json* config_json = new nlohmann::json();
|
||||
if (config_file_path != "" && config_file_path != "default") {
|
||||
config_json = best_config_mannager.get_gemm_best_configs(config_file_path);
|
||||
}
|
||||
|
||||
best_config = get_relative_best<std::string>(config_json, mnk_string, M);
|
||||
|
||||
if (dual_gemm_config_map.find(best_config) == dual_gemm_config_map.end()) {
|
||||
throw std::runtime_error("This config'kernel not be generate, please check auto_gen_fp8_fp8_dual_gemm_fused_kernels.py and re-generate.");
|
||||
} else {
|
||||
kernel_id = dual_gemm_config_map[best_config];
|
||||
}
|
||||
return launch_dual_gemm_kernel_sm{sm}(type_id, kernel_id, params);
|
||||
}
|
||||
}
|
||||
"""
|
||||
|
||||
|
||||
def SubstituteTemplate(template, values):
    """Repeatedly expand ``{key}`` placeholders in *template* from *values*.

    The scan restarts whenever a substitution changes the text, so nested
    placeholders introduced by replacement values are expanded too.

    Args:
        template: Template text containing ``{key}`` markers.
        values: Mapping from placeholder name to replacement text.

    Returns:
        The fully expanded template.
    """
    text = template
    changed = True
    while changed:
        changed = False
        for key, value in values.items():
            # str.replace instead of re.sub: re.sub interprets backslashes
            # in the replacement as escape sequences, which breaks on C++
            # snippets containing backslashes.
            placeholder = "{" + key + "}"
            newtext = text.replace(placeholder, value)
            if newtext != text:
                changed = True
                text = newtext
    return text
|
||||
|
||||
|
||||
def parse_args():
    """Parse the command line; the sole option is ``--cuda_arch`` (default ["90"])."""
    parser = argparse.ArgumentParser(
        description="auto generate the fp8_fp8_dual_gemm_fused_kernels_sm90.")
    parser.add_argument("--cuda_arch",
                        type=str,
                        nargs="+",
                        default=["90"],
                        help="The CUDA architecture to be generated.")
    return parser.parse_args()
|
||||
|
||||
|
||||
# generate source .cu
def generate_dual_gemm_source_cu(
    inputs_type: tuple,
    biases_type: tuple,
    hasbiases: tuple,
    act_tag: tuple,
    tiles: tuple,
    KernelSchedule: tuple,
    EpilogueSchedule: tuple,
    sm: str,
):
    """
    Render explicit GemmDeclare template instantiations for every valid
    combination of dtype, bias, (tile, cluster) shape, kernel schedule and
    epilogue schedule for one activation.

    Args:
        inputs_type: Input dtype names.
        biases_type: Bias dtype names.
        hasbiases: "true"/"false" option strings.
        act_tag: (tag, C++ activation functor) pair, e.g. ("swiglu", "SiLu").
        tiles: (TileShape, ClusterShape) string pairs.
        KernelSchedule: Candidate CUTLASS kernel schedules.
        EpilogueSchedule: Candidate CUTLASS epilogue schedules.
        sm: CUTLASS arch tag, e.g. "cutlass::arch::Sm90"; its last two
            characters are used as the numeric arch suffix.

    Returns:
        str: contents of the generated .cu file.
    """
    all_code = CommonHead
    for input_type in inputs_type:
        for bias_type in biases_type:
            for hasbias in hasbiases:
                for tile_config in tiles:
                    for kernel_schedule in KernelSchedule:
                        for epilogue_schedule in EpilogueSchedule:
                            # Skip schedule/tile combinations that are
                            # rejected by check_config_valid().
                            if not check_config_valid(tile_config,
                                                      kernel_schedule,
                                                      epilogue_schedule):
                                continue
                            value_dict = {
                                "input_type": input_type,
                                "bias_type": bias_type,
                                "hasbias": hasbias,
                                "Activation": act_tag[1],
                                "TileShape": tile_config[0],
                                "ClusterShape": tile_config[1],
                                "KernelSchedule": kernel_schedule,
                                "EpilogueSchedule": epilogue_schedule,
                                "SM": sm,
                                # e.g. "cutlass::arch::Sm90" -> "90"
                                "sm": sm[-2:],
                            }
                            all_code += SubstituteTemplate(
                                GemmDeclare, value_dict)

    return all_code
|
||||
|
||||
|
||||
# generate gemm launch .cu
def generate_launch_dual_gemm_cus(
        generate_dir: str, inputs_type: tuple, biases_type: tuple,
        fuse_gemm_configs: tuple, sm: str):
    """
    Write the per-config SM90 launch sources so they compile in parallel:
    one shared header declaring every launch_dual_gemm_kernel_sm<NN>_<cfg>
    symbol, plus one .cu per valid (tile/cluster, kernel schedule,
    epilogue schedule) combination.

    Args:
        generate_dir: Output directory (created if missing).
        inputs_type: Input dtype names.
        biases_type: Bias dtype names.
        fuse_gemm_configs: Candidate configs from
            get_dual_gemm_candidate_configs(); field [1] of each entry is
            the (act_tag, activation) pair, the remaining fields are shared
            across entries and read from the first one.
        sm: CUTLASS arch tag; sm[-2:] yields the numeric suffix ("90").

    Returns:
        (head_all_code, code_map): the header text and a dict mapping each
        config string to its generated .cu source.
    """
    act_tags = [single_config[1] for single_config in fuse_gemm_configs]

    # Options other than the activation are identical across entries.
    single_config = fuse_gemm_configs[0]
    hasbiases: tuple = single_config[0]
    tiles: tuple = single_config[2]
    KernelSchedule: tuple = single_config[3]
    EpilogueSchedule: tuple = single_config[4]
    code_map = {}
    head_path = os.path.join(generate_dir,
                             f"launch_dual_gemm_kernel_sm{sm[-2:]}.h")
    head_all_code = LaunchGemmHead
    # Pass 1: declarations for the shared header.
    for tile_config in tiles:
        blocks, clusters = get_shape_str(tile_config)
        # NOTE(review): get_shape_str() already strips the leading "_",
        # so these two strips are redundant no-ops.
        blocks = [elem.strip("_") for elem in blocks]
        clusters = [elem.strip("_") for elem in clusters]
        gemm_config_str_0 = f"tile{blocks[0]}x{blocks[1]}x{blocks[2]}_cluster{clusters[0]}x{clusters[1]}x{clusters[2]}"
        for kernel_schedule in KernelSchedule:
            gemm_config_str_1 = gemm_config_str_0 + f"_{kernel_schedule}"
            for epilogue_schedule in EpilogueSchedule:
                if not check_config_valid(tile_config, kernel_schedule,
                                          epilogue_schedule):
                    continue
                gemm_config_str = gemm_config_str_1 + f"_{epilogue_schedule}"
                value_dict = {
                    "sm": sm[-2:],
                    "gemm_config": gemm_config_str,
                }
                head_all_code += SubstituteTemplate(LaunchGemmDeclare,
                                                    value_dict)
    os.makedirs(generate_dir, exist_ok=True)
    with open(head_path, "w") as f:
        f.write(head_all_code)
        f.close()  # redundant: the with-block already closes the file

    # Pass 2: one .cu per valid config containing a type_id switch.
    for tile_shape in tiles:
        blocks, clusters = get_shape_str(tile_shape)
        gemm_config_str_0 = f"tile{blocks[0]}x{blocks[1]}x{blocks[2]}_cluster{clusters[0]}x{clusters[1]}x{clusters[2]}"
        for kernel_schedule in KernelSchedule:
            gemm_config_str_1 = gemm_config_str_0 + f"_{kernel_schedule}"
            for epilogue_schedule in EpilogueSchedule:
                if not check_config_valid(tile_shape, kernel_schedule,
                                          epilogue_schedule):
                    continue
                gemm_config_str = gemm_config_str_1 + f"_{epilogue_schedule}"
                value_dict = {
                    "sm": sm[-2:],
                    "gemm_config": gemm_config_str,
                }
                source_all_code = SubstituteTemplate(LaunchGemmPart0,
                                                     value_dict)
                # type_id must enumerate combinations in the same fixed
                # order as generate_dispatch_dual_gemm_cu()'s type map.
                type_id = 0
                for input_type in inputs_type:
                    for bias_type in biases_type:
                        for act_tag in act_tags:
                            for hasbias in hasbiases:
                                value_dict = {
                                    "type_id": str(type_id),
                                    "input_type": input_type,
                                    "bias_type": bias_type,
                                    "hasbias": hasbias,
                                    "Activation": act_tag[1],
                                    "TileShape": tile_shape[0],
                                    "ClusterShape": tile_shape[1],
                                    "KernelSchedule": kernel_schedule,
                                    "EpilogueSchedule": epilogue_schedule,
                                    "SM": sm,
                                    "sm": sm[-2:],
                                }
                                source_all_code += SubstituteTemplate(
                                    LaunchGemmPart1, value_dict)
                                type_id += 1
                source_all_code += LaunchGemmPart2
                code_map[gemm_config_str] = source_all_code
                source_path = os.path.join(
                    generate_dir,
                    f"launch_dual_gemm_kernel_sm{sm[-2:]}_{gemm_config_str}.cu"
                )
                with open(source_path, "w") as f:
                    f.write(source_all_code)
                    f.close()  # redundant: the with-block already closes the file

    return head_all_code, code_map
|
||||
|
||||
|
||||
# generate fp8_fp8_gemm_scale_bias_act.cu
def generate_dispatch_dual_gemm_cu(inputs_type: tuple, biases_type: tuple,
                                   fuse_gemm_configs: tuple, sm: str):
    """
    Render the SM90 dispatch translation unit — the type-id map, the
    config-id map, the kernel_id switch, and the tune/run driver — from
    the code_part0..code_part6 templates.

    Args:
        inputs_type: Input dtype names (also used as the output dtype).
        biases_type: Bias dtype names.
        fuse_gemm_configs: Candidate configs from
            get_dual_gemm_candidate_configs(); field [1] of each entry is
            the (act_tag, activation) pair, the other fields are shared and
            read from the first entry.
        sm: CUTLASS arch tag; sm[-2:] yields the numeric suffix.

    Returns:
        str: contents of the generated dispatch .cu file.
    """
    act_tags = [single_config[1] for single_config in fuse_gemm_configs]
    # All candidate entries share the same non-activation options.
    single_config = fuse_gemm_configs[0]
    hasbiases: tuple = single_config[0]
    tiles: tuple = single_config[2]
    KernelSchedule: tuple = single_config[3]
    EpilogueSchedule: tuple = single_config[4]

    all_code = SubstituteTemplate(code_part0, {"sm": sm[-2:]})
    # Map "<dtype>_<dtype>_<bias>_<hasbias>_<act>" -> type_id, fixed order.
    type_id = 0
    for input_type in inputs_type:
        for bias_type in biases_type:
            for act_tag in act_tags:
                for hasbias in hasbiases:
                    value_dict = {
                        "act_tag": act_tag[0],
                        "input_type": input_type,
                        "output_type": input_type,
                        "bias_type": bias_type,
                        "hasbias": hasbias,
                        "type_id": str(type_id),
                    }
                    all_code += SubstituteTemplate(code_part1, value_dict)
                    type_id += 1

    all_code += code_part2
    # Map "<TileShape>, <ClusterShape>, <schedules>" -> tile_id.
    tile_id = 0
    for tile_shape in tiles:
        for kernel_schedule in KernelSchedule:
            for epilogue_schedule in EpilogueSchedule:
                if not check_config_valid(tile_shape, kernel_schedule,
                                          epilogue_schedule):
                    continue
                value_dict = {
                    "TileShape": tile_shape[0],
                    "ClusterShape": tile_shape[1],
                    "kernel_schedule": kernel_schedule,
                    "epilogue_schedule": epilogue_schedule,
                    "tile_id": str(tile_id),
                }
                all_code += SubstituteTemplate(code_part3, value_dict)
                tile_id += 1
    all_code += SubstituteTemplate(code_part4, {"sm": sm[-2:]})
    # Switch cases must be emitted in the same order as the map above so
    # tile_id stays aligned.
    tile_id = 0
    for tile_shape in tiles:
        blocks, clusters = get_shape_str(tile_shape)
        gemm_config_str_0 = f"tile{blocks[0]}x{blocks[1]}x{blocks[2]}_cluster{clusters[0]}x{clusters[1]}x{clusters[2]}"
        for kernel_schedule in KernelSchedule:
            gemm_config_str_1 = gemm_config_str_0 + f"_{kernel_schedule}"
            for epilogue_schedule in EpilogueSchedule:
                if not check_config_valid(tile_shape, kernel_schedule,
                                          epilogue_schedule):
                    continue
                gemm_config_str = gemm_config_str_1 + f"_{epilogue_schedule}"
                value_dict = {
                    "sm": sm[-2:],
                    "tile_id": str(tile_id),
                    "gemm_config": gemm_config_str,
                }
                all_code += SubstituteTemplate(code_part5, value_dict)
                tile_id += 1
    all_code += SubstituteTemplate(code_part6, {"sm": sm[-2:]})
    return all_code
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Entry point: generate every SM90 dual-gemm kernel source, the
    # per-config launch wrappers, and the dispatch translation unit.
    args = parse_args()
    archs = args.cuda_arch
    inputs_type = ("float8_e4m3fn", "float8_e5m2")
    biases_type = (
        "float16",
        "bfloat16",
    )
    sm_dict = {"90": "cutlass::arch::Sm90"}

    for sm in archs:
        if sm == "90":
            fuse_gemm_configs = get_dual_gemm_candidate_configs(sm)
            # One source file per activation; fuse_gemm_config[1][0] is the
            # activation tag ("swiglu"/"geglu") used in the file name.
            for fuse_gemm_config in fuse_gemm_configs:
                file_name = (
                    f"gpu_ops/cutlass_kernels/fp8_gemm_fused/"
                    f"autogen/generic_dual_gemm_kernel_sm{sm}_{fuse_gemm_config[1][0]}.cu"
                )
                all_code = generate_dual_gemm_source_cu(
                    inputs_type,
                    biases_type,
                    fuse_gemm_config[0],
                    fuse_gemm_config[1],
                    fuse_gemm_config[2],
                    fuse_gemm_config[3],
                    fuse_gemm_config[4],
                    sm_dict[sm],
                )
                file_dir = os.path.dirname(file_name)
                os.makedirs(file_dir, exist_ok=True)
                with open(file_name, "w") as f:
                    f.write(all_code)
                    f.close()  # redundant: the with-block already closes the file
            # Compile parallelization
            generate_launch_dual_gemm_cus(
                "gpu_ops/cutlass_kernels/fp8_gemm_fused/autogen", inputs_type,
                biases_type, fuse_gemm_configs, sm_dict[sm])
            # hard code for act_tag
            file_name = (
                f"gpu_ops/cutlass_kernels/fp8_gemm_fused/"
                f"autogen/fp8_fp8_dual_gemm_scale_bias_act_sm{sm}.cu"
            )
            all_code = generate_dispatch_dual_gemm_cu(
                inputs_type,
                biases_type,
                fuse_gemm_configs,
                sm_dict[sm],
            )
            file_dir = os.path.dirname(file_name)
            os.makedirs(file_dir, exist_ok=True)
            with open(file_name, "w") as f:
                f.write(all_code)
                f.close()  # redundant: the with-block already closes the file
        else:
            raise ValueError(f"Unsupported SM: {sm}")
|
||||
682
custom_ops/utils/auto_gen_fp8_fp8_gemm_fused_kernels.py
Normal file
682
custom_ops/utils/auto_gen_fp8_fp8_gemm_fused_kernels.py
Normal file
@@ -0,0 +1,682 @@
|
||||
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""generate gemm_fused_kernel code."""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import re
|
||||
|
||||
|
||||
def get_candidate_tiles():
    """Return the candidate tile configurations for the SM89 kernels.

    Each entry is a (thread_block_shape, warp_shape, mma_shape) triple,
    spelled as CUTLASS ``GemmShape`` template arguments.
    """
    return [
        ("<64, 64, 64>", "<32, 32, 64>", "<16, 8, 32>"),
        ("<32, 128, 64>", "<32, 32, 64>", "<16, 8, 32>"),
        ("<64, 128, 64>", "<32, 64, 64>", "<16, 8, 32>"),
        ("<64, 64, 128>", "<32, 64, 64>", "<16, 8, 32>"),
        ("<64, 128, 64>", "<64, 32, 64>", "<16, 8, 32>"),
        ("<128, 64, 64>", "<64, 32, 64>", "<16, 8, 32>"),
        ("<128, 128, 64>", "<64, 32, 64>", "<16, 8, 32>"),
        ("<128, 128, 64>", "<64, 64, 64>", "<16, 8, 32>"),
        ("<128, 128, 64>", "<128, 32, 64>", "<16, 8, 32>"),
        ("<128, 256, 64>", "<64, 64, 64>", "<16, 8, 32>"),
        ("<256, 128, 64>", "<64, 64, 64>", "<16, 8, 32>"),
        ("<16, 256, 128>", "<16, 64, 128>", "<16, 8, 32>"),
    ]
|
||||
|
||||
|
||||
def get_candidate_configs(sm, min_split_k, max_split_k, min_stages,
                          max_stages):
    """Build candidate kernel configurations for one architecture.

    Args:
        sm: Target compute capability, e.g. "89".  Currently unused in the
            body; kept for interface compatibility with callers.
        min_split_k: Inclusive lower bound of the split-K range.
        max_split_k: Inclusive upper bound of the split-K range.
        min_stages: Inclusive lower bound of the pipeline-stage range.
        max_stages: Inclusive upper bound of the pipeline-stage range.

    Returns:
        list: One ``(stages, splitks, tiles, act_tag, hasbias)`` tuple per
        activation in (noact, relu, gelu).
    """
    tiles = get_candidate_tiles()
    # tuple(range(...)) instead of the redundant generator-expression form.
    stages = tuple(range(min_stages, max_stages + 1))
    splitks = tuple(range(min_split_k, max_split_k + 1))
    hasbias = ("false", "true")

    act_tags = [
        ("noact", "LinearCombination"),
        ("relu", "LinearCombinationRelu"),
        ("gelu", "LinearCombinationGELU"),
    ]
    # One config tuple per activation; a comprehension replaces the
    # loop-with-extend([single_item]) pattern.
    return [(stages, splitks, tiles, act_tag, hasbias)
            for act_tag in act_tags]
|
||||
|
||||
|
||||
# this is a file's header part
# Templates below are C++ snippets; {placeholders} are filled in by
# SubstituteTemplate.  Their text must not be reformatted.

# Header emitted at the top of every generated .cu source file.
CommonHead = """// Generated by generate_code_gemm_fused_kernels.py - Do not edit.

#pragma once

#include "fp8_gemm_fused/fuse_gemm_{act_tag}_template.h"

"""

# Text appended after all declarations.
CommonTail = """

"""

# Explicit template instantiation for one (types, tile, stages, bias) combo.
GemmDeclare = """
template<>
bool dispatch_fuse_gemm_{act_tag}<phi::dtype::{input_type}, phi::dtype::{output_type},
cutlass::gemm::GemmShape{thread_block_shape}, cutlass::gemm::GemmShape{warp_shape},
cutlass::gemm::GemmShape{mma_shape}, {num_stages}, {hasbias}, {SM}>(GemmEpilogueAllParams);


"""

# Same as GemmDeclare, but for the split-K kernel variant.
GemmSplitKDeclare = """
template<>
bool dispatch_fuse_gemm_split_k_{act_tag}<phi::dtype::{input_type}, phi::dtype::{output_type},
cutlass::gemm::GemmShape{thread_block_shape}, cutlass::gemm::GemmShape{warp_shape},
cutlass::gemm::GemmShape{mma_shape}, {num_stages}, {hasbias}, {SM}>(GemmEpilogueAllParams);


"""
|
||||
|
||||
# Header shared by the per-config launcher translation units.
LaunchGemmHead = """
#pragma once

#include "fp8_gemm_fused/fp8_fp8_gemm_scale_bias_act.h"

"""

# Forward declaration of one per-config launcher.
LaunchGemmDeclare = """
bool launch_gemm_kernel_{gemm_config}(const int type_id, const int split_k, GemmEpilogueAllParams params);
"""

# Launcher prologue: split_k < 2 uses the non-split-K dispatch path.
LaunchGemmPart0 = """
#pragma once

#include "launch_gemm_kernel.h"

bool launch_gemm_kernel_{gemm_config}(const int type_id, const int split_k, GemmEpilogueAllParams params){
if(split_k < 2){
params.split_k = 1;
switch (type_id) {
"""

# One case label in the non-split-K switch.
LaunchGemmPart1 = """
case {type_id}:
return dispatch_fuse_gemm_{act_tag}<phi::dtype::{input_type}, phi::dtype::{output_type},
cutlass::gemm::GemmShape{thread_block_shape}, cutlass::gemm::GemmShape{warp_shape},
cutlass::gemm::GemmShape{mma_shape}, {num_stages}, {hasbias}, {SM}>(params);
break;
"""

# Bridge from the non-split-K switch into the split-K switch.
LaunchGemmPart2 = """
default:
throw std::runtime_error("cutlass gemm config is invalid.");
break;
}
}else{
switch (type_id) {
"""

# One case label in the split-K switch.
LaunchGemmPart3 = """
case {type_id}:
return dispatch_fuse_gemm_split_k_{act_tag}<phi::dtype::{input_type}, phi::dtype::{output_type},
cutlass::gemm::GemmShape{thread_block_shape}, cutlass::gemm::GemmShape{warp_shape},
cutlass::gemm::GemmShape{mma_shape}, {num_stages}, {hasbias}, {SM}>(params);
break;
"""

# Launcher epilogue closing both switches.
LaunchGemmPart4 = """
default:
throw std::runtime_error("cutlass gemm config is invalid.");
break;
}
}

return false;
}
"""
|
||||
|
||||
# Dispatcher-file templates.  code_part0..code_part6 are concatenated by
# generate_dispatch_gemm_cu; {placeholders} are filled by SubstituteTemplate
# (all other braces are literal C++).

# File prologue + opening of the (types, bias, act) -> type_id map.
code_part0 = """// Generated by generate_code_gemm_fused_kernels.py - Do not edit.

#include <map>
#include <regex>
#include <limits>
#include "helper.h"
#include "fp8_gemm_fused/fp8_fp8_gemm_scale_bias_act.h"
#include "launch_gemm_kernel.h"

COMMON_DECLARE_string(use_cutlass_device_best_config_path);

std::map<std::string, int> gemm_type_map{"""

# One entry of gemm_type_map.
code_part1 = """
{"{input_type}_{output_type}_{hasbias}_{act_tag}", {type_id}}, """

# Close gemm_type_map, open the tile-config -> kernel_id map.
code_part2 = """
};

std::map<std::string, int> gemm_config_map{
"""

# One entry of gemm_config_map.
code_part3 = """ {"{thread_block_shape}, {warp_shape}, {mma_shape}, {num_stages}", {tile_id}},
"""

# Close gemm_config_map, open the kernel_id dispatch switch.
code_part4 = """};

bool launch_gemm_kernel(const int type_id, const int split_k, const int kernel_id, GemmEpilogueAllParams params){
switch (kernel_id) {"""

# One case of the kernel_id dispatch switch.
code_part5 = """
case {tile_id}:
return launch_gemm_kernel_{gemm_config}(type_id, split_k, params);
break;"""

# Tail: closest-M config lookup helper + runtime tune/run entry point.
code_part6 = """
default:
throw std::runtime_error("fp8_fp8_bf16_gemm_fused Config is invalid.");
break;
}
return false;
}

template <typename T>
T get_relative_best(nlohmann::json* json_data,
const std::string& target_key,
const std::string& regex_key,
const int& m,
const T& default_value) {
if (json_data->contains(target_key)) {
return json_data->at(target_key);
} else {
std::regex pattern(regex_key);
std::string closest_key;
int closest_diff = std::numeric_limits<int>::max();
T closest_value = default_value;

for (const auto& [key, value] : json_data->items()) {
std::smatch matches;
if (std::regex_search(key, matches, pattern)) {
int relative_m = std::stoi(matches[1].str());
int diff = std::abs(relative_m - m);
if (diff < closest_diff) {
closest_diff = diff;
closest_value = value;
}
}
}
json_data->push_back({target_key, closest_value});
return closest_value;
}
}

bool fp8_fp8_gemm_scale_bias_act(GemmEpilogueAllParams params) {
if (gemm_type_map.find(params.fuse_gemm_config) == gemm_type_map.end()) {
throw std::runtime_error("fp8 gemm_fused config is invalid.");
}

int type_id = gemm_type_map[params.fuse_gemm_config];
int M = (params.M+31)/32 *32;
int N = params.N;
int K = params.K;

std::string mnk_string = "tensor_gemm_sm90<"+ std::to_string(M)+ ", " +std::to_string(N) + ", "+ std::to_string(K)+ ">";
std::string regex_mnk_string = "tensor_gemm_sm90<(\\d+), " + std::to_string(N) + ", " + std::to_string(K) + ">";
std::string mnk_split_k_string = "tensor_gemm_sm90<"+ std::to_string(M)+ ", " +std::to_string(N) + ", "+ std::to_string(K)+ ">" + ", split_k";
std::string regex_mnk_split_k_string = "tensor_gemm_sm90<(\\d+), " + std::to_string(N) + ", " + std::to_string(K) + ">, split_k";
int split_k;
int kernel_id;
std::string best_config;
CutlassGemmConfigMannager& best_config_mannager = CutlassGemmConfigMannager::getInstance();
char *config_file_path_c_str = getenv("FLAGS_use_cutlass_device_best_config_path");
std::string config_file_path = config_file_path_c_str == nullptr ? "" : std::string(config_file_path_c_str);
if(config_file_path == "tune"){ // tune kernel
int warm_up_times = 5;
int tune_times = 10;
std::string best_kernel_id = "";
int best_split_k = -1;
float duratation = 1000000.f;
// tune all split_k, kernel_id kernels
for(int i = 1; i < {max_split_k}+1; ++i){ // all split_k
for(const auto& config_pair : gemm_config_map){
bool is_valid = true;
// warm up
for(int num_time = 0; num_time < warm_up_times; ++num_time){
if(!launch_gemm_kernel(type_id, i, config_pair.second, params)){
is_valid = false;
break;
}
}
if(!is_valid){
continue;
}
cudaEvent_t start, stop;
cudaEventCreate(&start);
cudaEventCreate(&stop);
cudaStreamSynchronize(params.stream);
cudaEventRecord(start, params.stream);
for(int num_time = 0; num_time < tune_times; ++num_time){
if(!launch_gemm_kernel(type_id, i, config_pair.second, params)){
is_valid = false;
break;
};
}
cudaEventRecord(stop, params.stream);
cudaEventSynchronize(stop);
float elapsedTime;
if(is_valid){
cudaEventElapsedTime(&elapsedTime, start, stop);
} else {
continue;
}
cudaEventDestroy(start);
cudaEventDestroy(stop);
if(elapsedTime < duratation){
best_kernel_id = config_pair.first;
best_split_k = i;
duratation = elapsedTime;
}
}
}

nlohmann::json new_json;
new_json[mnk_string] = best_kernel_id;
new_json[mnk_split_k_string] = best_split_k;
best_config_mannager.up_date_configs(new_json);
std::cout <<"Gemm tune result for " << mnk_string<< ": best config is: "<< best_kernel_id << ", split k: " << best_split_k << std::endl;
return true;
} else { // run kernel
nlohmann::json* config_json = new nlohmann::json();
if (config_file_path != "" && config_file_path != "default") {
config_json = best_config_mannager.get_gemm_best_configs(config_file_path);
}

best_config = get_relative_best<std::string>(config_json, mnk_string, regex_mnk_string, M, "<64, 64, 64>, <32, 32, 64>, <16, 8, 32>, 3");
split_k = get_relative_best<int>(config_json, mnk_split_k_string, regex_mnk_split_k_string, M, 1);

if (gemm_config_map.find(best_config) == gemm_config_map.end()) {
throw std::runtime_error("This config'kernel not be generate, please check auto_gen_fp8_fp8_gemm_fused_kernels.py and re-generate.");
} else {
kernel_id = gemm_config_map[best_config];
}
return launch_gemm_kernel(type_id, split_k, kernel_id, params);
}
}
"""
|
||||
|
||||
|
||||
def SubstituteTemplate(template, values):
    """Replace every ``{key}`` placeholder in *template* with its value.

    Iterates until no substitution changes the text, so replacement values
    that themselves contain placeholders are expanded too.

    Args:
        template (str): Text containing ``{key}`` placeholders.
        values (dict): Mapping from placeholder name to replacement text.

    Returns:
        str: The template with all known placeholders substituted.
    """
    text = template
    changed = True
    while changed:
        changed = False
        for key, value in values.items():
            # Use str.replace instead of re.sub: re.sub interprets
            # backslashes and group references (e.g. "\\d", "\\1") in the
            # replacement string, which raises re.error or corrupts output
            # for values containing C++/regex text.
            placeholder = "{" + key + "}"
            newtext = text.replace(placeholder, value)
            if newtext != text:
                changed = True
                text = newtext
    return text
|
||||
|
||||
|
||||
def parse_args(argv=None):
    """Parse command-line options for the SM89 kernel generator.

    Args:
        argv (list[str] | None): Explicit argument list to parse.  ``None``
            (the default) falls back to ``sys.argv[1:]``, preserving the
            original behavior while making the function testable.

    Returns:
        argparse.Namespace: Parsed options (cuda_arch, min/max_split_k,
        min/max_stages).
    """
    parser = argparse.ArgumentParser(
        description=
        "The argument for generating the generic_mixed_gemm_kernelLauncher instance."
    )
    parser.add_argument(
        "--cuda_arch",
        type=str,
        nargs="+",
        default=["89"],
        help="The CUDA architecture to be generated.",
    )

    parser.add_argument(
        "--min_split_k",
        type=int,
        default=2,
        help="The min split k for the gemm kernel.",
    )

    parser.add_argument(
        "--max_split_k",
        type=int,
        default=6,
        help="The max split k for the gemm kernel.",
    )

    parser.add_argument(
        "--min_stages",
        type=int,
        default=2,
        help="The min stages for the gemm kernel.",
    )

    parser.add_argument(
        "--max_stages",
        type=int,
        default=8,
        help="The max stages for the gemm kernel.",
    )

    return parser.parse_args(argv)
|
||||
|
||||
|
||||
# generate source .cu
def generate_source_cu(
    inputs_type: str,
    outputs_type: str,
    stages: int,
    tiles: str,
    act_tag: str,
    hasbiases: str,
    sm: str,
):
    """Generate the explicit-instantiation .cu source for one activation.

    Emits the common header, then a ``dispatch_fuse_gemm_<act>``
    instantiation for every (input type, output type, stage, bias, tile)
    combination, then the split-K variants of the same combinations, and
    finally the common tail.

    Args:
        inputs_type: Iterable of input dtype names (e.g. "float8_e4m3fn").
        outputs_type: Iterable of output dtype names.
        stages: Iterable of pipeline stage counts.
        tiles: Iterable of (thread-block, warp, mma) shape triples.
        act_tag: Activation tag ("noact", "relu" or "gelu").
        hasbiases: Iterable of "false"/"true" strings.
        sm: CUTLASS architecture tag, e.g. "cutlass::arch::Sm89".

    Returns:
        str: The full generated C++ source text.
    """
    all_code = SubstituteTemplate(CommonHead, {"act_tag": act_tag})

    # The plain and split-K declarations enumerate the identical combination
    # space; iterating the two templates in order emits all plain
    # declarations first and then all split-K ones, preserving the original
    # output while removing a duplicated five-deep loop nest.
    for template in (GemmDeclare, GemmSplitKDeclare):
        for input_type in inputs_type:
            for output_type in outputs_type:
                for stage in stages:
                    for hasbias in hasbiases:
                        for tile_config in tiles:
                            value_dict = {
                                "input_type": input_type,
                                "output_type": output_type,
                                "thread_block_shape": tile_config[0],
                                "warp_shape": tile_config[1],
                                "mma_shape": tile_config[2],
                                "num_stages": str(stage),
                                "act_tag": act_tag,
                                "hasbias": hasbias,
                                "SM": sm,
                            }
                            all_code += SubstituteTemplate(
                                template, value_dict)

    all_code += CommonTail
    return all_code
|
||||
|
||||
|
||||
# generate gemm launch .cu
def generate_launch_gemm_cus(
    generate_dir: str,
    inputs_type: str,
    outputs_type: str,
    stages: int,
    split_ks: int,
    tiles: str,
    act_tags: str,
    hasbiases: str,
    sm: str,
    min_split_k: int,
    max_split_k: int,
):
    """Generate one launcher .cu per (tile, stage) combo plus a shared header.

    Writes ``launch_gemm_kernel.h`` (forward declarations of every launcher)
    and one ``launch_gemm_kernel_<config>.cu`` per configuration into
    *generate_dir*, so each kernel configuration compiles as its own
    translation unit (enables parallel compilation).

    Returns:
        tuple: ``(head_all_code, code_map)`` — the header text and a mapping
        from config name to its generated source text.

    NOTE(review): ``split_ks``, ``min_split_k`` and ``max_split_k`` are
    accepted but never read here — presumably kept for signature symmetry
    with generate_dispatch_gemm_cu; confirm before removing.
    """
    code_map = {}
    head_path = os.path.join(generate_dir, "launch_gemm_kernel.h")
    head_all_code = LaunchGemmHead
    # First pass: emit a forward declaration for every (tile, stage) config.
    for tile in tiles:
        blocks, warps, mmas = [
            s.replace(" ", "").strip("<>").split(",") for s in tile
        ]
        gemm_config = f"block{blocks[0]}x{blocks[1]}x{blocks[2]}_warp{warps[0]}x{warps[1]}x{warps[2]}_mma{mmas[0]}x{mmas[1]}x{mmas[2]}"
        for stage in stages:
            gemm_config_str = gemm_config + f"_stage{stage}"
            value_dict = {
                "gemm_config": gemm_config_str,
            }
            head_all_code += SubstituteTemplate(LaunchGemmDeclare, value_dict)
    os.makedirs(generate_dir, exist_ok=True)
    with open(head_path, "w") as f:
        f.write(head_all_code)
        f.close()  # redundant inside `with`; kept as-is

    # Second pass: emit the launcher body (type_id switch, plain + split-K
    # branches) for every (tile, stage) config.
    for tile in tiles:
        blocks, warps, mmas = [
            s.replace(" ", "").strip("<>").split(",") for s in tile
        ]
        gemm_config = f"block{blocks[0]}x{blocks[1]}x{blocks[2]}_warp{warps[0]}x{warps[1]}x{warps[2]}_mma{mmas[0]}x{mmas[1]}x{mmas[2]}"
        for stage in stages:
            gemm_config_str = gemm_config + f"_stage{stage}"
            value_dict = {
                "gemm_config": gemm_config_str,
            }
            source_all_code = SubstituteTemplate(LaunchGemmPart0, value_dict)
            split_k_code = ""
            type_id = 0
            # type_id enumerates (input, output, act, bias) combinations in a
            # fixed order that must match gemm_type_map as emitted by
            # generate_dispatch_gemm_cu.
            for input_type in inputs_type:
                for output_type in outputs_type:
                    for act_tag in act_tags:
                        for hasbias in hasbiases:
                            value_dict = {
                                "act_tag": act_tag,
                                "input_type": input_type,
                                "output_type": output_type,
                                "hasbias": hasbias,
                                "type_id": str(type_id),
                                "thread_block_shape": tile[0],
                                "warp_shape": tile[1],
                                "mma_shape": tile[2],
                                "num_stages": str(stage),
                                "SM": sm,
                            }
                            source_all_code += SubstituteTemplate(
                                LaunchGemmPart1, value_dict)
                            split_k_code += SubstituteTemplate(
                                LaunchGemmPart3, value_dict)
                            type_id += 1
            source_all_code += LaunchGemmPart2
            source_all_code += split_k_code
            source_all_code += LaunchGemmPart4
            code_map[gemm_config_str] = source_all_code
            source_path = os.path.join(
                generate_dir, f"launch_gemm_kernel_{gemm_config_str}.cu")
            with open(source_path, "w") as f:
                f.write(source_all_code)
                f.close()  # redundant inside `with`; kept as-is

    return head_all_code, code_map
|
||||
|
||||
|
||||
# generate fp8_fp8_gemm_scale_bias_act.cu
def generate_dispatch_gemm_cu(
    inputs_type: str,
    outputs_type: str,
    stages: int,
    split_ks: int,
    tiles: str,
    act_tags: str,
    hasbiases: str,
    sm: str,
    min_split_k: int,
    max_split_k: int,
):
    """Build the dispatcher source text for fp8_fp8_gemm_scale_bias_act.cu.

    Concatenates code_part0..code_part6: the (types, bias, act) -> type_id
    map, the tile-config -> kernel_id map, the kernel_id switch, and the
    runtime tune/run entry point.

    Returns:
        str: The full generated C++ source text.

    NOTE(review): ``split_ks`` is accepted but never read here; confirm
    before removing.
    """

    all_code = code_part0
    type_id = 0
    # gemm_type_map entries; iteration order must mirror
    # generate_launch_gemm_cus so type_id values agree across files.
    for input_type in inputs_type:
        for output_type in outputs_type:
            for act_tag in act_tags:
                for hasbias in hasbiases:
                    value_dict = {
                        "act_tag": act_tag,
                        "input_type": input_type,
                        "output_type": output_type,
                        "hasbias": hasbias,
                        "type_id": str(type_id),
                    }
                    all_code += SubstituteTemplate(code_part1, value_dict)
                    type_id += 1

    all_code += code_part2
    tile_id = 0
    # gemm_config_map entries, keyed by the shape/stage string.
    for tile in tiles:
        for stage in stages:
            value_dict = {
                "thread_block_shape": tile[0],
                "warp_shape": tile[1],
                "mma_shape": tile[2],
                "num_stages": str(stage),
                "tile_id": str(tile_id),
            }
            all_code += SubstituteTemplate(code_part3, value_dict)
            tile_id += 1
    all_code += code_part4
    tile_id = 0
    # kernel_id switch cases, one per (tile, stage); tile_id restarts so it
    # lines up with the gemm_config_map ids emitted above.
    for tile in tiles:
        blocks, warps, mmas = [
            s.replace(" ", "").strip("<>").split(",") for s in tile
        ]
        gemm_config = f"block{blocks[0]}x{blocks[1]}x{blocks[2]}_warp{warps[0]}x{warps[1]}x{warps[2]}_mma{mmas[0]}x{mmas[1]}x{mmas[2]}"
        for stage in stages:
            gemm_config_str = gemm_config + f"_stage{stage}"
            value_dict = {
                "tile_id": str(tile_id),
                "gemm_config": gemm_config_str,
            }
            all_code += SubstituteTemplate(code_part5, value_dict)
            tile_id += 1
    # NOTE(review): reuses the last loop iteration's value_dict and would
    # raise NameError if tiles/stages were empty — confirm callers always
    # pass non-empty iterables.
    value_dict.update({
        "min_split_k": str(min_split_k),
        "max_split_k": str(max_split_k),
    })
    all_code += SubstituteTemplate(code_part6, value_dict)
    return all_code
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Drive the SM89 code generation: one instantiation file per activation,
    # per-config launcher files, and the dispatcher source.
    args = parse_args()
    archs = args.cuda_arch
    min_split_k = args.min_split_k
    max_split_k = args.max_split_k
    min_stages = args.min_stages
    max_stages = args.max_stages
    inputs_type = ("float8_e4m3fn", "float8_e5m2")
    outputs_type = ("float16", "bfloat16")
    sm_dict = {"89": "cutlass::arch::Sm89", "90": "cutlass::arch::Sm90"}

    for sm in archs:
        if sm == "89":
            fuse_gemm_configs = get_candidate_configs(sm, min_split_k,
                                                      max_split_k, min_stages,
                                                      max_stages)
            for fuse_gemm_config in fuse_gemm_configs:
                file_name = f"gpu_ops/cutlass_kernels/fp8_gemm_fused/autogen/generic_gemm_kernel_sm{sm}_{fuse_gemm_config[3][0]}.cu"
                all_code = generate_source_cu(
                    inputs_type,
                    outputs_type,
                    fuse_gemm_config[0],
                    fuse_gemm_config[2],
                    fuse_gemm_config[3][0],
                    fuse_gemm_config[4],
                    sm_dict[sm],
                )
                file_dir = os.path.dirname(file_name)
                os.makedirs(file_dir, exist_ok=True)
                with open(file_name, "w") as f:
                    f.write(all_code)

            # All configs share stages/split-ks/tiles/hasbias; only the
            # activation differs, so one representative config suffices below.
            fuse_gemm_config = list(fuse_gemm_configs)[0]

            act_tags = ["noact", "relu", "gelu"]
            # Compile parallelization
            generate_launch_gemm_cus(
                "gpu_ops/cutlass_kernels/fp8_gemm_fused/autogen",
                inputs_type,
                outputs_type,
                fuse_gemm_config[0],
                fuse_gemm_config[1],
                fuse_gemm_config[2],
                act_tags,
                fuse_gemm_config[4],
                sm_dict[sm],
                min_split_k,
                max_split_k,
            )

            # hard code for act_tag
            file_name = (
                "gpu_ops/cutlass_kernels/fp8_gemm_fused/autogen/fp8_fp8_gemm_scale_bias_act.cu"
            )
            all_code = generate_dispatch_gemm_cu(
                inputs_type,
                outputs_type,
                fuse_gemm_config[0],
                fuse_gemm_config[1],
                fuse_gemm_config[2],
                act_tags,
                fuse_gemm_config[4],
                sm_dict[sm],
                min_split_k,
                max_split_k,
            )
            file_dir = os.path.dirname(file_name)
            os.makedirs(file_dir, exist_ok=True)
            with open(file_name, "w") as f:
                f.write(all_code)

        # BUG FIX: was `elif sm == 90:` — `sm` is a string from argparse
        # (type=str), so the int comparison was always False and
        # `--cuda_arch 90` fell through to the ValueError instead of the
        # intended graceful exit.
        elif sm == "90":
            print("Not supported yet.")
            exit(0)
        else:
            raise ValueError(f"Unsupported SM: {sm}")
|
||||
614
custom_ops/utils/auto_gen_fp8_fp8_gemm_fused_kernels_sm90.py
Normal file
614
custom_ops/utils/auto_gen_fp8_fp8_gemm_fused_kernels_sm90.py
Normal file
@@ -0,0 +1,614 @@
|
||||
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""generate gemm_fused_kernel code."""
|
||||
|
||||
import argparse
|
||||
import copy
|
||||
import os
|
||||
import re
|
||||
|
||||
|
||||
def get_candidate_tiles():
    """Return candidate (TileShape, ClusterShape) pairs for SM90 kernels.

    Shapes are spelled as cute ``Shape`` template arguments.  Combinations
    that were rarely selected when tuning the Qwen2-7B model are left out
    (see the commented-out entries below).
    """
    preferred_pairs = [
        ("<_64, _64, _128>", "<_1, _8, _1>"),
        ("<_64, _128, _128>", "<_2, _1, _1>"),
        ("<_128, _128, _128>", "<_2, _1, _1>"),
    ]
    additional_pairs = [
        ("<_64, _64, _128>", "<_1, _1, _1>"),
        ("<_64, _64, _128>", "<_1, _2, _1>"),
        ("<_64, _64, _128>", "<_2, _1, _1>"),
        ("<_64, _64, _64>", "<_1, _1, _1>"),
        ("<_64, _64, _64>", "<_1, _2, _1>"),
        ("<_64, _64, _64>", "<_2, _1, _1>"),
        ("<_64, _128, _128>", "<_1, _2, _1>"),
        ("<_64, _128, _128>", "<_1, _1, _1>"),
        ("<_128, _128, _64>", "<_2, _1, _1>"),
        ("<_256, _128, _128>", "<_1, _2, _1>"),
        ("<_256, _128, _128>", "<_1, _1, _1>"),
        # The following configurations are rarely selected in Qwen2-7B-model.
        # ("<_256, _128, _128>", "<_4, _1, _1>"),
        # ("<_256, _128, _128>", "<_1, _4, _1>"),
        # ("<_256, _128, _128>", "<_2, _4, _1>"),
        # ("<_128, _128, _256>", "<_1, _2, _1>"),
        # ("<_128, _128, _128>", "<_4, _1, _1>"),
        # ("<_128, _128, _128>", "<_2, _4, _1>"),
        # ("<_128, _128, _128>", "<_1, _2, _1>"),
        # ("<_128, _128, _128>", "<_1, _1, _1>"),
        # ("<_128, _128, _128>", "<_1, _4, _1>"),
        # ("<_128, _128, _64>", "<_2, _2, _1>"),
    ]
    return preferred_pairs + additional_pairs
|
||||
|
||||
|
||||
def get_candidate_configs(sm):
    """Return candidate SM90 kernel configurations, one tuple per activation.

    Each entry is ``(hasbias, act_tag, tiles, kernel_schedules,
    epilogue_schedules)``; *sm* is currently unused.
    """
    tiles = get_candidate_tiles()
    bias_options = ("false", "true")
    kernel_schedules = (
        "KernelTmaWarpSpecializedFP8FastAccum",
        "KernelTmaWarpSpecializedPingpongFP8FastAccum",
        # "KernelTmaWarpSpecializedCooperativeFP8FastAccum",
    )
    epilogue_schedules = ("TmaWarpSpecialized", "TmaWarpSpecializedCooperative")
    activations = [
        ("noact", "Identity"),
        ("relu", "ReLu"),
        ("gelu", "GELU"),
    ]
    return [
        (bias_options, act_tag, tiles, kernel_schedules, epilogue_schedules)
        for act_tag in activations
    ]
|
||||
|
||||
|
||||
def get_shape_str(tile_shape):
    """Split a (tile shape, cluster shape) pair into lists of bare numbers.

    Example: ``("<_64, _64, _128>", "<_1, _8, _1>")`` becomes
    ``(["64", "64", "128"], ["1", "8", "1"])``.
    """

    def _dims(shape_text):
        # Strip spaces and angle brackets, split on commas, drop the
        # leading underscores of cute's "_N" integral constants.
        parts = shape_text.replace(" ", "").strip("<>").split(",")
        return [part.strip("_") for part in parts]

    tile_text, cluster_text = tile_shape
    return _dims(tile_text), _dims(cluster_text)
|
||||
|
||||
|
||||
def check_config_valid(tile_shape, kernel_schedule, epilogue_schedule):
    """Return True when the (tile, kernel, epilogue) combination is legal.

    Rejects cooperative kernel schedules on blocks narrower than 128 in M,
    cooperative kernels paired with non-cooperative epilogues, and the
    256x128x128 tile without a cooperative schedule on either side.
    """
    tile_text, _cluster_text = tile_shape
    block_dims = [
        dim.strip("_")
        for dim in tile_text.replace(" ", "").strip("<>").split(",")
    ]
    if (int(block_dims[0]) < 128
            and kernel_schedule
            == "KernelTmaWarpSpecializedCooperativeFP8FastAccum"):
        return False
    if ("Cooperative" in kernel_schedule
            and "Cooperative" not in epilogue_schedule):
        return False
    if (tile_shape[0] == "<_256, _128, _128>"
            and "Cooperative" not in kernel_schedule
            and "Cooperative" not in epilogue_schedule):
        return False
    return True
|
||||
|
||||
|
||||
# this is a file's header part
# SM90 (CUTLASS 3.x) templates; {placeholders} are filled in by
# SubstituteTemplate, all other braces are literal C++.

# Header emitted at the top of every generated .cu source file.
CommonHead = """// Generated by auto_gen_fp8_fp8_gemm_fused_kernels_sm90.py - Do not edit.

#pragma once

#include "fp8_gemm_fused/fuse_gemm_act_template_3x.h"
using namespace cute;

"""

# Explicit template instantiation for one SM90 dispatch combination.
GemmDeclare = """
template<>
bool dispatch_fuse_gemm_act_sm{sm}<phi::dtype::{input_type}, phi::dtype::{output_type},
{hasbias},
cutlass::epilogue::thread::{Activation},
Shape{TileShape},
Shape{ClusterShape},
cutlass::gemm::{KernelSchedule},
cutlass::epilogue::{EpilogueSchedule},
{SM}
>(GemmEpilogueAllParams);

"""

# Header shared by the per-config launcher translation units.
LaunchGemmHead = """
#pragma once

#include "fp8_gemm_fused/fp8_fp8_gemm_scale_bias_act.h"

"""

# Forward declaration of one per-config launcher.
LaunchGemmDeclare = """
bool launch_gemm_kernel_sm{sm}_{gemm_config}(const int type_id, GemmEpilogueAllParams params);
"""

# Launcher prologue opening the type_id switch.
LaunchGemmPart0 = """
#pragma once

#include "launch_gemm_kernel_sm{sm}.h"

bool launch_gemm_kernel_sm{sm}_{gemm_config}(const int type_id, GemmEpilogueAllParams params){
using namespace cute;
switch (type_id) {
"""

# One case label of the type_id switch.
LaunchGemmPart1 = """
case {type_id}:
return dispatch_fuse_gemm_act_sm{sm}<phi::dtype::{input_type}, phi::dtype::{output_type},
{hasbias},
cutlass::epilogue::thread::{Activation},
Shape{TileShape},
Shape{ClusterShape},
cutlass::gemm::{KernelSchedule},
cutlass::epilogue::{EpilogueSchedule},
{SM}
>(params);
break;
"""

# Launcher epilogue closing the switch.
LaunchGemmPart2 = """
default:
throw std::runtime_error("cutlass gemm config is invalid.");
break;
}
return false;
}
"""

# Dispatcher-file prologue + opening of the type map.
code_part0 = """// Generated by auto_gen_fp8_fp8_gemm_fused_kernels_sm90.py - Do not edit.

#include <map>
#include "fp8_gemm_fused/fp8_fp8_gemm_scale_bias_act.h"
#include "launch_gemm_kernel_sm{sm}.h"

COMMON_DECLARE_string(use_cutlass_device_best_config_path);

std::map<std::string, int> gemm_type_map{"""

# One entry of gemm_type_map.
code_part1 = """
{"{input_type}_{output_type}_{hasbias}_{act_tag}", {type_id}}, """

# Close gemm_type_map, open gemm_config_map.
code_part2 = """
};

std::map<std::string, int> gemm_config_map{
"""

# One entry of gemm_config_map.
code_part3 = """ {"{TileShape}, {ClusterShape}, {kernel_schedule}, {epilogue_schedule}", {tile_id}},
"""

# Close gemm_config_map, open the kernel_id dispatch switch.
code_part4 = """};

bool launch_gemm_kernel_sm{sm}(const int type_id, const int kernel_id, GemmEpilogueAllParams params){
switch (kernel_id) {"""

# One case of the kernel_id dispatch switch.
code_part5 = """
case {tile_id}:
return launch_gemm_kernel_sm{sm}_{gemm_config}(type_id, params);
break;
"""

# Tail: M-bucketed default-config helper + runtime tune/run entry point.
code_part6 = """
default:
throw std::runtime_error("fp8 gemm_fused Config is invalid.");
break;
}
return false;
}

template <typename T>
T get_relative_best(nlohmann::json* json_data,
const std::string& target_key,
const int& m) {
if (json_data->contains(target_key)) {
return json_data->at(target_key);
} else {
if (m <= 64){
return "<_64, _64, _128>, <_1, _8, _1>, KernelTmaWarpSpecializedPingpongFP8FastAccum, TmaWarpSpecialized";
}else if(m <= 128){
return "<_64, _128, _128>, <_2, _1, _1>, KernelTmaWarpSpecializedPingpongFP8FastAccum, TmaWarpSpecialized";
}else{
return "<_128, _128, _128>, <_2, _1, _1>, KernelTmaWarpSpecializedPingpongFP8FastAccum, TmaWarpSpecialized";
}
}
}

bool fp8_fp8_gemm_scale_bias_act(GemmEpilogueAllParams params) {
if (gemm_type_map.find(params.fuse_gemm_config) == gemm_type_map.end()) {
throw std::runtime_error("fp8 gemm_fused config is invalid.");
}

int type_id = gemm_type_map[params.fuse_gemm_config];
int M = (params.M + 31) / 32 * 32;
int N = params.N;
int K = params.K;

int kernel_id;
std::string mnk_string = "tensor_gemm_sm90<"+ std::to_string(M)+ ", " +std::to_string(N) + ", "+ std::to_string(K)+ ">";
std::string best_config;
CutlassGemmConfigMannager& best_config_mannager = CutlassGemmConfigMannager::getInstance();
char *config_file_path_c_str = getenv("FLAGS_use_cutlass_device_best_config_path");
std::string config_file_path = config_file_path_c_str == nullptr ? "" : std::string(config_file_path_c_str);
if(config_file_path == "tune"){ // tune kernel
int warm_up_times = 5;
int tune_times = 10;
std::string best_kernel_id = "";
float duratation = 1000000.f;
// tune all kernel_id kernels
for(const auto& config_pair : gemm_config_map){
std::cout << "Running tune kernel: " << config_pair.first<< std::endl;
bool is_valid = true;
// warm up
for(int num_time = 0; num_time < warm_up_times; ++num_time){
if(!launch_gemm_kernel_sm{sm}(type_id, config_pair.second, params)){
is_valid = false;
break;
}
}
if(!is_valid){
continue;
}
cudaEvent_t start, stop;
cudaEventCreate(&start);
cudaEventCreate(&stop);
cudaStreamSynchronize(params.stream);
cudaEventRecord(start, params.stream);
for(int num_time = 0; num_time < tune_times; ++num_time){
if(!launch_gemm_kernel_sm{sm}(type_id, config_pair.second, params)){
is_valid = false;
break;
};
}
cudaEventRecord(stop, params.stream);
cudaEventSynchronize(stop);
float elapsedTime;
if(is_valid){
cudaEventElapsedTime(&elapsedTime, start, stop);
} else {
continue;
}
cudaEventDestroy(start);
cudaEventDestroy(stop);
if(elapsedTime < duratation){
best_kernel_id = config_pair.first;
duratation = elapsedTime;
}
}

nlohmann::json new_json;
new_json[mnk_string] = best_kernel_id;
best_config_mannager.up_date_configs(new_json);
std::cout <<"Gemm tune result for " << mnk_string<< ": best config is: "<< best_kernel_id << std::endl;
return true;
} else { // run kernel
nlohmann::json* config_json = new nlohmann::json();
if (config_file_path != "" && config_file_path != "default") {
config_json = best_config_mannager.get_gemm_best_configs(config_file_path);
}

best_config = get_relative_best<std::string>(config_json, mnk_string, M);

if (gemm_config_map.find(best_config) == gemm_config_map.end()) {
throw std::runtime_error("This config'kernel not be generate, please check auto_gen_fp8_fp8_gemm_fused_kernels_sm90.py and re-generate.");
} else {
kernel_id = gemm_config_map[best_config];
}
return launch_gemm_kernel_sm{sm}(type_id, kernel_id, params);
}
}

"""
|
||||
|
||||
|
||||
def SubstituteTemplate(template, values_base):
    """Expand ``{key}`` placeholders in *template* using *values_base*.

    The mapping is deep-copied so the caller's dict is never mutated.
    ``KernelSchedule`` / ``EpilogueSchedule`` values containing ``Auto`` are
    prefixed with ``collective::`` because CUTLASS's auto schedules live in
    that namespace.  Substitution repeats until a fixed point is reached, so
    placeholders introduced by substituted values are resolved too.

    Args:
        values_base (dict): Mapping from placeholder name to replacement text.

    Returns:
        str: *template* with every known placeholder expanded.
    """
    values = copy.deepcopy(values_base)
    kernel_schedule = values.get("KernelSchedule")
    if kernel_schedule is not None and "Auto" in kernel_schedule:
        values["KernelSchedule"] = "collective::" + kernel_schedule
    epilogue_schedule = values.get("EpilogueSchedule")
    if epilogue_schedule is not None and "Auto" in epilogue_schedule:
        values["EpilogueSchedule"] = "collective::" + epilogue_schedule
    text = template
    changed = True
    while changed:
        changed = False
        for key, value in values.items():
            # BUG FIX: re.sub interprets backslashes and group references in
            # the replacement string (a value containing e.g. "\b" raised
            # re.error).  str.replace performs literal substitution.
            newtext = text.replace("{" + key + "}", value)
            if newtext != text:
                changed = True
                text = newtext
    return text
|
||||
|
||||
|
||||
def parse_args():
    """Parse command-line options for the SM90 FP8 GEMM code generator.

    Returns:
        argparse.Namespace: Parsed options; ``cuda_arch`` is a list of
        architecture strings, defaulting to ``["90"]``.
    """
    parser = argparse.ArgumentParser(
        description="auto generate fp8_fp8_gemm_fused_kernels_sm90.")
    parser.add_argument(
        "--cuda_arch",
        type=str,
        nargs="+",
        default=["90"],
        help="The CUDA architecture to be generated.",
    )
    return parser.parse_args()
|
||||
|
||||
|
||||
# generate source .cu
|
||||
def generate_source_cu(
    inputs_type: (str),
    outputs_type: (str),
    hasbiases: (str),
    act_tag: (str),
    tiles: (str),
    KernelSchedule: (str),
    EpilogueSchedule: (str),
    sm: str,
):
    """Render the .cu source that instantiates every GEMM template variant.

    One ``GemmDeclare`` instantiation is emitted per valid combination of
    input/output dtype, bias flag, tile/cluster shape and kernel/epilogue
    schedule; invalid combinations are filtered by ``check_config_valid``.

    Returns:
        str: The complete file content, starting with ``CommonHead``.
    """
    all_code = CommonHead
    for in_t in inputs_type:
        for out_t in outputs_type:
            for bias in hasbiases:
                for tile in tiles:
                    for ks in KernelSchedule:
                        for es in EpilogueSchedule:
                            if not check_config_valid(tile, ks, es):
                                continue
                            all_code += SubstituteTemplate(
                                GemmDeclare,
                                {
                                    "input_type": in_t,
                                    "output_type": out_t,
                                    "hasbias": bias,
                                    "Activation": act_tag[1],
                                    "TileShape": tile[0],
                                    "ClusterShape": tile[1],
                                    "KernelSchedule": ks,
                                    "EpilogueSchedule": es,
                                    "SM": sm,
                                    # e.g. "cutlass::arch::Sm90" -> "90"
                                    "sm": sm[-2:],
                                })
    return all_code
|
||||
|
||||
|
||||
# generate gemm launch .cu
|
||||
def generate_launch_gemm_cus(
        generate_dir: (str), inputs_type: (str), outputs_type: (str),
        fuse_gemm_configs: tuple, sm: str):
    """Emit the launcher header plus one .cu per tile/schedule combination.

    Writing one translation unit per configuration keeps individual compile
    jobs small so the build can be parallelized.

    Args:
        generate_dir: Output directory for the generated files.
        fuse_gemm_configs: Candidate configs; element 0 supplies the bias
            flags, tiles and schedules shared by all configs, element 1 of
            each entry is its activation tag.
        sm: Architecture tag, e.g. ``"cutlass::arch::Sm90"``.

    Returns:
        tuple(str, dict): the header file content and a
        ``{gemm_config: source_code}`` map.
    """
    act_tags = [cfg[1] for cfg in fuse_gemm_configs]
    first_cfg = fuse_gemm_configs[0]
    hasbiases: (str) = first_cfg[0]
    tiles: (str) = first_cfg[2]
    KernelSchedule: (str) = first_cfg[3]
    EpilogueSchedule: (str) = first_cfg[4]

    def _valid_combos():
        # Yield every (tile, kernel schedule, epilogue schedule) triple that
        # passes check_config_valid, paired with its canonical config name.
        for tile in tiles:
            blocks, clusters = get_shape_str(tile)
            base = (f"tile{blocks[0]}x{blocks[1]}x{blocks[2]}"
                    f"_cluster{clusters[0]}x{clusters[1]}x{clusters[2]}")
            for ks in KernelSchedule:
                for es in EpilogueSchedule:
                    if check_config_valid(tile, ks, es):
                        yield tile, ks, es, f"{base}_{ks}_{es}"

    code_map = {}
    head_path = os.path.join(generate_dir, f"launch_gemm_kernel_sm{sm[-2:]}.h")
    head_all_code = LaunchGemmHead
    for _, _, _, config_name in _valid_combos():
        head_all_code += SubstituteTemplate(
            LaunchGemmDeclare, {"sm": sm[-2:], "gemm_config": config_name})
    os.makedirs(generate_dir, exist_ok=True)
    with open(head_path, "w") as f:
        f.write(head_all_code)

    for tile, ks, es, config_name in _valid_combos():
        source_all_code = SubstituteTemplate(
            LaunchGemmPart0, {"sm": sm[-2:], "gemm_config": config_name})
        # type_id enumerates (input, output, activation, bias) combinations
        # in the same order the dispatcher assigns them.
        type_id = 0
        for input_type in inputs_type:
            for output_type in outputs_type:
                for act_tag in act_tags:
                    for hasbias in hasbiases:
                        source_all_code += SubstituteTemplate(
                            LaunchGemmPart1,
                            {
                                "type_id": str(type_id),
                                "input_type": input_type,
                                "output_type": output_type,
                                "hasbias": hasbias,
                                "Activation": act_tag[1],
                                "TileShape": tile[0],
                                "ClusterShape": tile[1],
                                "KernelSchedule": ks,
                                "EpilogueSchedule": es,
                                "SM": sm,
                                "sm": sm[-2:],
                            })
                        type_id += 1
        source_all_code += LaunchGemmPart2
        code_map[config_name] = source_all_code
        source_path = os.path.join(
            generate_dir,
            f"launch_gemm_kernel_sm{sm[-2:]}_{config_name}.cu")
        with open(source_path, "w") as f:
            f.write(source_all_code)

    return head_all_code, code_map
|
||||
|
||||
|
||||
# generate fp8_fp8_gemm_scale_bias_act_sm90.cu
|
||||
def generate_dispatch_gemm_cu(inputs_type: (str), outputs_type: (str),
                              fuse_gemm_configs: tuple, sm: str):
    """Render the dispatcher source ``fp8_fp8_gemm_scale_bias_act_sm{sm}.cu``.

    The file contains the type-id map (code_part1 entries), the config-id
    map (code_part3 entries) and the kernel-id switch that forwards to the
    per-config launchers (code_part5 entries).

    Returns:
        str: The complete file content.
    """
    act_tags = [cfg[1] for cfg in fuse_gemm_configs]
    first_cfg = fuse_gemm_configs[0]
    hasbiases: (str) = first_cfg[0]
    tiles: (str) = first_cfg[2]
    KernelSchedule: (str) = first_cfg[3]
    EpilogueSchedule: (str) = first_cfg[4]

    all_code = SubstituteTemplate(code_part0, {"sm": sm[-2:]})
    type_id = 0
    for input_type in inputs_type:
        for output_type in outputs_type:
            for act_tag in act_tags:
                for hasbias in hasbiases:
                    all_code += SubstituteTemplate(
                        code_part1,
                        {
                            "act_tag": act_tag[0],
                            "input_type": input_type,
                            "output_type": output_type,
                            "hasbias": hasbias,
                            "type_id": str(type_id),
                        })
                    type_id += 1

    all_code += code_part2
    # First pass: config-map entries, keyed by tile_id.
    tile_id = 0
    for tile in tiles:
        for ks in KernelSchedule:
            for es in EpilogueSchedule:
                if not check_config_valid(tile, ks, es):
                    continue
                all_code += SubstituteTemplate(
                    code_part3,
                    {
                        "TileShape": tile[0],
                        "ClusterShape": tile[1],
                        "kernel_schedule": ks,
                        "epilogue_schedule": es,
                        "tile_id": str(tile_id),
                    })
                tile_id += 1
    all_code += SubstituteTemplate(code_part4, {"sm": sm[-2:]})
    # Second pass: switch cases forwarding to the per-config launchers,
    # numbered identically to the first pass.
    tile_id = 0
    for tile in tiles:
        blocks, clusters = get_shape_str(tile)
        base = (f"tile{blocks[0]}x{blocks[1]}x{blocks[2]}"
                f"_cluster{clusters[0]}x{clusters[1]}x{clusters[2]}")
        for ks in KernelSchedule:
            for es in EpilogueSchedule:
                if not check_config_valid(tile, ks, es):
                    continue
                all_code += SubstituteTemplate(
                    code_part5,
                    {
                        "sm": sm[-2:],
                        "tile_id": str(tile_id),
                        "gemm_config": f"{base}_{ks}_{es}",
                    })
                tile_id += 1

    all_code += SubstituteTemplate(code_part6, {"sm": sm[-2:]})
    return all_code
|
||||
|
||||
|
||||
if __name__ == "__main__":
    args = parse_args()
    inputs_type = (
        "float8_e4m3fn",
        "float8_e5m2",
    )
    outputs_type = ("float16", "bfloat16")
    sm_dict = {"90": "cutlass::arch::Sm90"}

    for sm in args.cuda_arch:
        if sm != "90":
            raise ValueError(f"Unsupported SM: {sm}")
        fuse_gemm_configs = get_candidate_configs(sm)
        # One instantiation file per activation tag.
        for cfg in fuse_gemm_configs:
            file_name = (
                f"gpu_ops/cutlass_kernels/fp8_gemm_fused/"
                f"autogen/generic_gemm_kernel_sm{sm}_{cfg[1][0]}.cu")
            kernel_code = generate_source_cu(
                inputs_type,
                outputs_type,
                cfg[0],
                cfg[1],
                cfg[2],
                cfg[3],
                cfg[4],
                sm_dict[sm],
            )
            os.makedirs(os.path.dirname(file_name), exist_ok=True)
            with open(file_name, "w") as f:
                f.write(kernel_code)
        # Compile parallelization: one launcher .cu per config.
        generate_launch_gemm_cus(
            "gpu_ops/cutlass_kernels/fp8_gemm_fused/autogen", inputs_type,
            outputs_type, fuse_gemm_configs, sm_dict[sm])

        # hard code for act_tag
        dispatch_name = f"gpu_ops/cutlass_kernels/fp8_gemm_fused/autogen/fp8_fp8_gemm_scale_bias_act_sm{sm}.cu"
        dispatch_code = generate_dispatch_gemm_cu(
            inputs_type,
            outputs_type,
            fuse_gemm_configs,
            sm_dict[sm],
        )
        os.makedirs(os.path.dirname(dispatch_name), exist_ok=True)
        with open(dispatch_name, "w") as f:
            f.write(dispatch_code)
|
||||
568
custom_ops/utils/auto_gen_visitor_fp8_gemm_fused_kernels.py
Normal file
568
custom_ops/utils/auto_gen_visitor_fp8_gemm_fused_kernels.py
Normal file
@@ -0,0 +1,568 @@
|
||||
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""generate gemm_fused_kernel code."""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import re
|
||||
|
||||
|
||||
def get_candidate_tiles():
    """Return the candidate (ThreadblockShape, WarpShape, MmaShape) triples.

    Each entry is a triple of CUTLASS ``GemmShape`` template argument lists,
    written as literal ``"<M, N, K>"`` strings.

    Returns:
        list[tuple[str, str, str]]: candidate tile configurations.
    """
    return [
        ("<64, 64, 64>", "<32, 32, 64>", "<16, 8, 32>"),
        ("<16, 32, 64>", "<16, 32, 64>", "<16, 8, 32>"),
        ("<16, 64, 64>", "<16, 32, 64>", "<16, 8, 32>"),
        ("<32, 128, 64>", "<32, 32, 64>", "<16, 8, 32>"),
        ("<64, 128, 64>", "<32, 64, 64>", "<16, 8, 32>"),
        ("<64, 64, 128>", "<32, 64, 64>", "<16, 8, 32>"),
        ("<64, 128, 64>", "<64, 32, 64>", "<16, 8, 32>"),
        ("<128, 64, 64>", "<64, 32, 64>", "<16, 8, 32>"),
        ("<128, 128, 64>", "<64, 32, 64>", "<16, 8, 32>"),
        ("<128, 128, 64>", "<64, 64, 64>", "<16, 8, 32>"),
        ("<128, 128, 64>", "<128, 32, 64>", "<16, 8, 32>"),
        ("<128, 256, 64>", "<64, 64, 64>", "<16, 8, 32>"),
        ("<256, 128, 64>", "<64, 64, 64>", "<16, 8, 32>"),
        ("<128, 64, 128>", "<64, 32, 128>", "<16, 8, 32>"),
        ("<16, 256, 128>", "<16, 64, 128>", "<16, 8, 32>"),
    ]
|
||||
|
||||
|
||||
def get_candidate_configs(sm, min_stages, max_stages):
    """Build the candidate GEMM configuration list.

    Args:
        sm (str): Target compute capability, e.g. ``"89"`` (currently unused
            in the selection itself).
        min_stages (int): Smallest pipeline stage count to generate.
        max_stages (int): Largest pipeline stage count (inclusive).

    Returns:
        list[tuple]: A single-element list holding ``(stages, tiles, hasbias)``
        where *stages* is a tuple of ints, *tiles* the candidate tile triples
        and *hasbias* the pair ``("false", "true")``.
    """
    stages = tuple(range(min_stages, max_stages + 1))
    hasbias = ("false", "true")
    return [(stages, get_candidate_tiles(), hasbias)]
|
||||
|
||||
|
||||
# --- C++ code templates used by the generators below. ---
# NOTE(review): the exact indentation inside these C++ snippets was
# reconstructed from a mangled dump; only whitespace may differ.

# File header shared by all generated sources.
CommonHead = """// Generated by generate_code_gemm_fused_kernels.py - Do not edit.

#pragma once

#include "fp8_gemm_fused/visitor_fp8_gemm_fused_template.h"
"""

CommonTail = """

"""

# Explicit template instantiation for one dtype/tile/stage/bias combination.
GemmDeclare = """
template<>
bool dispatch_visitor_fuse_gemm<phi::dtype::{input_type}, phi::dtype::{output_type},
    cutlass::gemm::GemmShape{thread_block_shape}, cutlass::gemm::GemmShape{warp_shape},
    cutlass::gemm::GemmShape{mma_shape}, {num_stages}, {hasbias}, {SM}>(GemmEpilogueAllParams);


"""

# Header that declares every per-config launcher.
LaunchGemmHead = """
#pragma once

#include "fp8_gemm_fused/visitor_fp8_gemm_fused.h"

"""

LaunchGemmDeclare = """
bool launch_visitor_gemm_fused_kernel_{gemm_config}(const int type_id, GemmEpilogueAllParams params);
"""

# Per-config launcher: prologue / one case per type_id / epilogue.
LaunchGemmPart0 = """
#pragma once

#include "fp8_gemm_fused/visitor_fp8_gemm_fused_template.h"

bool launch_visitor_gemm_fused_kernel_{gemm_config}(const int type_id, GemmEpilogueAllParams params){
    switch (type_id) {"""

LaunchGemmPart1 = """
        case {type_id}:
            return dispatch_visitor_fuse_gemm<phi::dtype::{input_type}, phi::dtype::{output_type},
                cutlass::gemm::GemmShape{thread_block_shape}, cutlass::gemm::GemmShape{warp_shape},
                cutlass::gemm::GemmShape{mma_shape}, {num_stages}, {hasbias}, {SM}>(params);
            break;"""

LaunchGemmPart2 = """
        default:
            throw std::runtime_error("cutlass gemm config is invalid.");
            break;
    }
    return false;
}
"""

# Dispatcher source: type map, config map, kernel-id switch and the
# tune-or-run entry point fp8_visitor_gemm_fused().
code_part0 = """// Generated by generate_code_gemm_fused_kernels.py - Do not edit.

#include <map>
#include <regex>
#include <limits>
#include "helper.h"
#include "visitor_fp8_gemm_fused.h"
#include "launch_visitor_gemm_fused_kernel.h"

COMMON_DECLARE_string(use_cutlass_device_best_config_path);

std::map<std::string, int> per_channel_gemm_type_map{"""

code_part1 = """
    {"{input_type}_{output_type}_{hasbias}", {type_id}}, """

code_part2 = """
};

std::map<std::string, int> per_channel_gemm_config_map{
"""

code_part3 = """    {"{thread_block_shape}, {warp_shape}, {mma_shape}, {num_stages}", {tile_id}},
"""

code_part4 = """};

bool launch_visitor_gemm_fused_kernel(const int type_id, const int kernel_id, GemmEpilogueAllParams params){
    switch (kernel_id) {"""

code_part5 = """
        case {tile_id}:
            return launch_visitor_gemm_fused_kernel_{gemm_config}(type_id, params);
            break;"""

code_part6 = """
        default:
            throw std::runtime_error("fp8_fp8_bf16_gemm_fused Config is invalid.");
            break;
    }
    return false;
}



bool fp8_visitor_gemm_fused(GemmEpilogueAllParams params) {
    if (per_channel_gemm_type_map.find(params.fuse_gemm_config) == per_channel_gemm_type_map.end()) {
        throw std::runtime_error("fp8_visitor_gemm_fused config is invalid.");
    }

    int type_id = per_channel_gemm_type_map[params.fuse_gemm_config];
    int M = (params.M+31)/32 *32;
    int N = params.N;
    int K = params.K;

    std::string mnk_string = "per_channel_gemm<"+ std::to_string(M)+ ", " +std::to_string(N) + ", "+ std::to_string(K)+ ">";
    auto encoded_mnk_string = base64_encode(mnk_string);

    int kernel_id;
    std::string best_config;
    CutlassGemmConfigMannager& best_config_mannager = CutlassGemmConfigMannager::getInstance();
    if(getenv("FLAGS_use_cutlass_device_best_config_path")){ // run kernel
        std::string config_file_path = getenv("FLAGS_use_cutlass_device_best_config_path");
        nlohmann::json* config_json = new nlohmann::json();
        if (config_file_path != "default") {
            config_json = best_config_mannager.get_gemm_best_configs(config_file_path);
        }

        best_config = get_relative_best<std::string>(config_json, encoded_mnk_string, "<64, 64, 64>, <32, 32, 64>, <16, 8, 32>, 3");

        if (per_channel_gemm_config_map.find(best_config) == per_channel_gemm_config_map.end()) {
            throw std::runtime_error("This config'kernel not be generate, please check auto_gen_visitor_fp8_gemm_fused_kernels.py and re-generate.");
        } else {
            kernel_id = per_channel_gemm_config_map[best_config];
        }
        return launch_visitor_gemm_fused_kernel(type_id, kernel_id, params);
    } else { // tune kernel
        int warm_up_times = 5;
        int tune_times = 10;
        std::string best_kernel_id = "";
        float duratation = 1000000.f;
        // tune all kernel_id kernels
        for(const auto& config_pair : per_channel_gemm_config_map){
            bool is_valid = true;
            // warm up
            for(int num_time = 0; num_time < warm_up_times; ++num_time){
                if(!launch_visitor_gemm_fused_kernel(type_id, config_pair.second, params)){
                    is_valid = false;
                    break;
                }
            }
            if(!is_valid){
                continue;
            }
            cudaEvent_t start, stop;
            cudaEventCreate(&start);
            cudaEventCreate(&stop);
            cudaStreamSynchronize(params.stream);
            cudaEventRecord(start, params.stream);
            for(int num_time = 0; num_time < tune_times; ++num_time){
                if(!launch_visitor_gemm_fused_kernel(type_id, config_pair.second, params)){
                    is_valid = false;
                    break;
                };
            }
            cudaEventRecord(stop, params.stream);
            cudaEventSynchronize(stop);
            float elapsedTime;
            if(is_valid){
                cudaEventElapsedTime(&elapsedTime, start, stop);
            } else {
                continue;
            }
            cudaEventDestroy(start);
            cudaEventDestroy(stop);
            if(elapsedTime < duratation){
                best_kernel_id = config_pair.first;
                duratation = elapsedTime;
            }
        }


        nlohmann::json new_json;
        new_json[encoded_mnk_string] = best_kernel_id;
        best_config_mannager.up_date_configs(new_json);
        std::cout <<"Gemm tune result for " << mnk_string<< ": best config is: "<< best_kernel_id << std::endl;
        return true;
    }
}
"""
|
||||
|
||||
|
||||
def SubstituteTemplate(template, values):
    """Expand ``{key}`` placeholders in *template* using *values*.

    Substitution repeats until a fixed point is reached, so placeholders
    introduced by substituted values are resolved too.

    Args:
        values (dict): Mapping from placeholder name to replacement text.

    Returns:
        str: *template* with every known placeholder expanded.
    """
    text = template
    changed = True
    while changed:
        changed = False
        for key, value in values.items():
            # BUG FIX: re.sub interprets backslashes and group references in
            # the replacement string (a value containing e.g. "\b" raised
            # re.error).  str.replace performs literal substitution.
            newtext = text.replace("{" + key + "}", value)
            if newtext != text:
                changed = True
                text = newtext
    return text
|
||||
|
||||
|
||||
def parse_args():
    """Parse command-line options for the visitor-GEMM code generator.

    Returns:
        argparse.Namespace: ``cuda_arch`` (list[str], default ``["89"]``),
        ``min_stages`` (int, default 2) and ``max_stages`` (int, default 8).
    """
    parser = argparse.ArgumentParser(
        description=
        "The argument for generating the generic_mixed_gemm_kernelLauncher instance."
    )
    parser.add_argument(
        "--cuda_arch",
        type=str,
        nargs="+",
        default=["89"],
        help="The CUDA architecture to be generated.",
    )
    parser.add_argument(
        "--min_stages",
        type=int,
        default=2,
        help="The min stages for the gemm kernel.",
    )
    parser.add_argument(
        "--max_stages",
        type=int,
        default=8,
        help="The max stages for the gemm kernel.",
    )
    return parser.parse_args()
|
||||
|
||||
|
||||
# generate source .cu
|
||||
def generate_source_cu(
    inputs_type: str,
    outputs_type: str,
    stages: int,
    tiles: str,
    hasbiases: str,
    sm: str,
):
    """Render the .cu source instantiating every visitor-GEMM template.

    One ``GemmDeclare`` instantiation is emitted per combination of
    input/output dtype, stage count, bias flag and tile configuration.

    Returns:
        str: The complete file content (``CommonHead`` ... ``CommonTail``).
    """
    all_code = SubstituteTemplate(CommonHead, {})
    for in_t in inputs_type:
        for out_t in outputs_type:
            for stage in stages:
                for bias in hasbiases:
                    for tile in tiles:
                        all_code += SubstituteTemplate(
                            GemmDeclare,
                            {
                                "input_type": in_t,
                                "output_type": out_t,
                                "thread_block_shape": tile[0],
                                "warp_shape": tile[1],
                                "mma_shape": tile[2],
                                "num_stages": str(stage),
                                "hasbias": bias,
                                "SM": sm,
                            })
    return all_code + CommonTail
|
||||
|
||||
|
||||
# generate gemm launch .cu
|
||||
def generate_launch_gemm_cus(
    generate_dir: str,
    inputs_type: str,
    outputs_type: str,
    stages: int,
    tiles: str,
    hasbiases: str,
    sm: str,
):
    """Emit the launcher header plus one .cu per tile/stage configuration.

    Writing one translation unit per configuration keeps individual compile
    jobs small so the build can be parallelized.

    Args:
        generate_dir: Output directory for the generated files.
        stages: Iterable of pipeline stage counts.
        tiles: Iterable of (threadblock, warp, mma) ``"<M, N, K>"`` triples.
        hasbiases: Iterable of ``"true"`` / ``"false"`` strings.
        sm: Architecture tag, e.g. ``"cutlass::arch::Sm89"``.

    Returns:
        tuple(str, dict): the header file content and a
        ``{gemm_config: source_code}`` map.
    """

    def _config_name(tile):
        # "<M, N, K>" strings -> "blockMxNxK_warp..._mma..." identifier.
        blocks, warps, mmas = [
            s.replace(" ", "").strip("<>").split(",") for s in tile
        ]
        return (f"block{blocks[0]}x{blocks[1]}x{blocks[2]}"
                f"_warp{warps[0]}x{warps[1]}x{warps[2]}"
                f"_mma{mmas[0]}x{mmas[1]}x{mmas[2]}")

    code_map = {}
    head_path = os.path.join(generate_dir,
                             "launch_visitor_gemm_fused_kernel.h")
    head_all_code = LaunchGemmHead
    for tile in tiles:
        base = _config_name(tile)
        for stage in stages:
            head_all_code += SubstituteTemplate(
                LaunchGemmDeclare, {"gemm_config": f"{base}_stage{stage}"})
    os.makedirs(generate_dir, exist_ok=True)
    with open(head_path, "w") as f:
        f.write(head_all_code)

    for tile in tiles:
        base = _config_name(tile)
        for stage in stages:
            gemm_config_str = f"{base}_stage{stage}"
            source_all_code = SubstituteTemplate(
                LaunchGemmPart0, {"gemm_config": gemm_config_str})
            # type_id enumerates (input, output, bias) combinations in the
            # same order the dispatcher assigns them.
            type_id = 0
            for input_type in inputs_type:
                for output_type in outputs_type:
                    for hasbias in hasbiases:
                        source_all_code += SubstituteTemplate(
                            LaunchGemmPart1,
                            {
                                "input_type": input_type,
                                "output_type": output_type,
                                "hasbias": hasbias,
                                "type_id": str(type_id),
                                "thread_block_shape": tile[0],
                                "warp_shape": tile[1],
                                "mma_shape": tile[2],
                                "num_stages": str(stage),
                                "SM": sm,
                            })
                        type_id += 1
            source_all_code += LaunchGemmPart2
            code_map[gemm_config_str] = source_all_code
            source_path = os.path.join(
                generate_dir,
                f"launch_visitor_gemm_fused_kernel_{gemm_config_str}.cu")
            with open(source_path, "w") as f:
                f.write(source_all_code)

    return head_all_code, code_map
|
||||
|
||||
|
||||
# generate fp8_visitor_gemm_fused.cu
|
||||
def generate_dispatch_gemm_cu(
    inputs_type: str,
    outputs_type: str,
    stages: int,
    tiles: str,
    hasbiases: str,
    sm: str,
):
    """Render ``visitor_fp8_gemm_fused.cu``: type/config maps and dispatchers.

    The file contains the type-id map (code_part1 entries), the config-id map
    (code_part3 entries), the kernel-id switch forwarding to the per-config
    launchers (code_part5 entries) and the fixed tail (code_part6).

    Args:
        inputs_type: Iterable of input dtype names.
        outputs_type: Iterable of output dtype names.
        stages: Iterable of pipeline stage counts.
        tiles: Iterable of (threadblock, warp, mma) ``"<M, N, K>"`` triples.
        hasbiases: Iterable of ``"true"`` / ``"false"`` strings.
        sm: Architecture tag (unused by the current templates).

    Returns:
        str: The complete file content.
    """
    all_code = code_part0
    type_id = 0
    for input_type in inputs_type:
        for output_type in outputs_type:
            for hasbias in hasbiases:
                all_code += SubstituteTemplate(
                    code_part1,
                    {
                        "input_type": input_type,
                        "output_type": output_type,
                        "hasbias": hasbias,
                        "type_id": str(type_id),
                    })
                type_id += 1

    all_code += code_part2
    # First pass: config-map entries keyed by tile_id.
    tile_id = 0
    for tile in tiles:
        for stage in stages:
            all_code += SubstituteTemplate(
                code_part3,
                {
                    "thread_block_shape": tile[0],
                    "warp_shape": tile[1],
                    "mma_shape": tile[2],
                    "num_stages": str(stage),
                    "tile_id": str(tile_id),
                })
            tile_id += 1
    all_code += code_part4
    # Second pass: switch cases, numbered identically to the first pass.
    tile_id = 0
    for tile in tiles:
        blocks, warps, mmas = [
            s.replace(" ", "").strip("<>").split(",") for s in tile
        ]
        gemm_config = (f"block{blocks[0]}x{blocks[1]}x{blocks[2]}"
                       f"_warp{warps[0]}x{warps[1]}x{warps[2]}"
                       f"_mma{mmas[0]}x{mmas[1]}x{mmas[2]}")
        for stage in stages:
            all_code += SubstituteTemplate(
                code_part5,
                {
                    "tile_id": str(tile_id),
                    "gemm_config": f"{gemm_config}_stage{stage}",
                })
            tile_id += 1
    # BUG FIX: the tail was substituted with the stale `value_dict` left over
    # from the last loop iteration (NameError when `tiles`/`stages` is empty);
    # code_part6 contains no placeholders, so append it verbatim.
    all_code += code_part6
    return all_code
|
||||
|
||||
|
||||
if __name__ == "__main__":
    args = parse_args()
    archs = args.cuda_arch
    min_stages = args.min_stages
    max_stages = args.max_stages
    inputs_type = ("float8_e4m3fn", "float8_e5m2")
    outputs_type = ("float16", "bfloat16")
    sm_dict = {"89": "cutlass::arch::Sm89", "90": "cutlass::arch::Sm90"}

    for sm in archs:
        if sm == "89":
            fuse_gemm_configs = get_candidate_configs(sm, min_stages,
                                                      max_stages)
            for fuse_gemm_config in fuse_gemm_configs:
                file_name = f"gpu_ops/cutlass_kernels/fp8_gemm_fused/autogen/generic_visitor_gemm_fused_kernel_sm{sm}.cu"
                all_code = generate_source_cu(
                    inputs_type,
                    outputs_type,
                    fuse_gemm_config[0],
                    fuse_gemm_config[1],
                    fuse_gemm_config[2],
                    sm_dict[sm],
                )
                file_dir = os.path.dirname(file_name)
                os.makedirs(file_dir, exist_ok=True)
                with open(file_name, "w") as f:
                    f.write(all_code)

            fuse_gemm_config = list(fuse_gemm_configs)[0]

            # Compile parallelization: one launcher .cu per configuration.
            generate_launch_gemm_cus(
                "gpu_ops/cutlass_kernels/fp8_gemm_fused/autogen",
                inputs_type,
                outputs_type,
                fuse_gemm_config[0],
                fuse_gemm_config[1],
                fuse_gemm_config[2],
                sm_dict[sm],
            )

            file_name = (
                "gpu_ops/cutlass_kernels/fp8_gemm_fused/visitor_fp8_gemm_fused.cu"
            )
            all_code = generate_dispatch_gemm_cu(
                inputs_type,
                outputs_type,
                fuse_gemm_config[0],
                fuse_gemm_config[1],
                fuse_gemm_config[2],
                sm_dict[sm],
            )
            file_dir = os.path.dirname(file_name)
            os.makedirs(file_dir, exist_ok=True)
            with open(file_name, "w") as f:
                f.write(all_code)

        # BUG FIX: was `elif sm == 90:` (str compared to int, never true),
        # so requesting "90" raised ValueError instead of exiting politely.
        elif sm == "90":
            print("Not supported yet.")
            exit(0)
        else:
            raise ValueError(f"Unsupported SM: {sm}")
|
||||
Reference in New Issue
Block a user