# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""generate gemm_fused_kernel code."""
|
||
|
||
import argparse
|
||
import os
|
||
import re
|
||
|
||
|
||


def get_candidate_tiles():
    """
    Get the list of candidate tile configurations.

    Args:
        None.

    Returns:
        List[Tuple[str, str, str]]: A list of triples; the strings in each
        triple describe the thread-block, warp, and MMA tile shapes.
    """
    base_configs = [("<64, 64, 64>", "<32, 32, 64>", "<16, 8, 32>")]

    base_configs.extend(
        [
            ("<32, 128, 64>", "<32, 32, 64>", "<16, 8, 32>"),
            ("<64, 128, 64>", "<32, 64, 64>", "<16, 8, 32>"),
            ("<64, 64, 128>", "<32, 64, 64>", "<16, 8, 32>"),
            ("<64, 128, 64>", "<64, 32, 64>", "<16, 8, 32>"),
            ("<128, 64, 64>", "<64, 32, 64>", "<16, 8, 32>"),
            ("<128, 128, 64>", "<64, 32, 64>", "<16, 8, 32>"),
            ("<128, 128, 64>", "<64, 64, 64>", "<16, 8, 32>"),
            ("<128, 128, 64>", "<128, 32, 64>", "<16, 8, 32>"),
            ("<128, 256, 64>", "<64, 64, 64>", "<16, 8, 32>"),
            ("<256, 128, 64>", "<64, 64, 64>", "<16, 8, 32>"),
            ("<16, 256, 128>", "<16, 64, 128>", "<16, 8, 32>"),
        ]
    )

    return base_configs


def get_candidate_configs(sm, min_split_k, max_split_k, min_stages, max_stages):
    """
    Get the list of candidate GEMM kernel configurations.

    Args:
        sm (str): Compute capability, e.g. "89".
        min_split_k (int): Minimum split-k value.
        max_split_k (int): Maximum split-k value.
        min_stages (int): Minimum number of pipeline stages.
        max_stages (int): Maximum number of pipeline stages.

    Returns:
        List[Tuple]: One (stages, splitks, tiles, act_tag, hasbias) tuple per
        activation variant.
    """
    tiles = get_candidate_tiles()
    candidate_configs = list()

    stages = tuple(i for i in range(min_stages, max_stages + 1))
    splitks = tuple(i for i in range(min_split_k, max_split_k + 1))
    hasbias = ("false", "true")

    for act_tag in [
        ("noact", "LinearCombination"),
        ("relu", "LinearCombinationRelu"),
        ("gelu", "LinearCombinationGELU"),
    ]:
        candidate_configs.extend([(stages, splitks, tiles, act_tag, hasbias)])

    return candidate_configs


# Common header emitted at the top of every generated file.
CommonHead = """// Generated by generate_code_gemm_fused_kernels.py - Do not edit.

#pragma once

#include "fp8_gemm_fused/fuse_gemm_{act_tag}_template.h"

"""


CommonTail = """

"""

GemmDeclare = """
template<>
bool dispatch_fuse_gemm_{act_tag}<phi::dtype::{input_type}, phi::dtype::{output_type},
    cutlass::gemm::GemmShape{thread_block_shape}, cutlass::gemm::GemmShape{warp_shape},
    cutlass::gemm::GemmShape{mma_shape}, {num_stages}, {hasbias}, {SM}>(GemmEpilogueAllParams);

"""


GemmSplitKDeclare = """
template<>
bool dispatch_fuse_gemm_split_k_{act_tag}<phi::dtype::{input_type}, phi::dtype::{output_type},
    cutlass::gemm::GemmShape{thread_block_shape}, cutlass::gemm::GemmShape{warp_shape},
    cutlass::gemm::GemmShape{mma_shape}, {num_stages}, {hasbias}, {SM}>(GemmEpilogueAllParams);

"""

LaunchGemmHead = """
#pragma once

#include "fp8_gemm_fused/fp8_fp8_gemm_scale_bias_act.h"

"""

LaunchGemmDeclare = """
bool launch_gemm_kernel_{gemm_config}(const int type_id, const int split_k, GemmEpilogueAllParams params);
"""

LaunchGemmPart0 = """
#pragma once

#include "launch_gemm_kernel.h"

bool launch_gemm_kernel_{gemm_config}(const int type_id, const int split_k, GemmEpilogueAllParams params){
    if(split_k < 2){
        params.split_k = 1;
        switch (type_id) {
"""

LaunchGemmPart1 = """
            case {type_id}:
                return dispatch_fuse_gemm_{act_tag}<phi::dtype::{input_type}, phi::dtype::{output_type},
                    cutlass::gemm::GemmShape{thread_block_shape}, cutlass::gemm::GemmShape{warp_shape},
                    cutlass::gemm::GemmShape{mma_shape}, {num_stages}, {hasbias}, {SM}>(params);
                break;
"""

LaunchGemmPart2 = """
            default:
                throw std::runtime_error("cutlass gemm config is invalid.");
                break;
        }
    }else{
        switch (type_id) {
"""

LaunchGemmPart3 = """
            case {type_id}:
                return dispatch_fuse_gemm_split_k_{act_tag}<phi::dtype::{input_type}, phi::dtype::{output_type},
                    cutlass::gemm::GemmShape{thread_block_shape}, cutlass::gemm::GemmShape{warp_shape},
                    cutlass::gemm::GemmShape{mma_shape}, {num_stages}, {hasbias}, {SM}>(params);
                break;
"""

LaunchGemmPart4 = """
            default:
                throw std::runtime_error("cutlass gemm config is invalid.");
                break;
        }
    }

    return false;
}
"""
code_part0 = """// Generated by generate_code_gemm_fused_kernels.py - Do not edit.

#include <map>
#include <regex>
#include <limits>
#include "helper.h"
#include "fp8_fp8_gemm_scale_bias_act.h"
#include "launch_gemm_kernel.h"

COMMON_DECLARE_string(use_cutlass_device_best_config_path);

std::map<std::string, int> gemm_type_map{"""

code_part1 = """
    {"{input_type}_{output_type}_{hasbias}_{act_tag}", {type_id}}, """

code_part2 = """
};

std::map<std::string, int> gemm_config_map{
"""

code_part3 = """    {"{thread_block_shape}, {warp_shape}, {mma_shape}, {num_stages}", {tile_id}},
"""

code_part4 = """};

bool launch_gemm_kernel(const int type_id, const int split_k, const int kernel_id, GemmEpilogueAllParams params){
    switch (kernel_id) {"""

code_part5 = """
        case {tile_id}:
            return launch_gemm_kernel_{gemm_config}(type_id, split_k, params);
            break;"""

code_part6 = """
        default:
            throw std::runtime_error("fp8_fp8_bf16_gemm_fused Config is invalid.");
            break;
    }
    return false;
}

bool fp8_fp8_gemm_scale_bias_act(GemmEpilogueAllParams params) {
    if (gemm_type_map.find(params.fuse_gemm_config) == gemm_type_map.end()) {
        throw std::runtime_error("fp8 gemm_fused config is invalid.");
    }

    int type_id = gemm_type_map[params.fuse_gemm_config];
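    // Round M up to the next multiple of 32; the rounded shape is used as the
    // tuning-cache key, so nearby M values share one entry.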
    int M = (params.M + 31) / 32 * 32;
    int N = params.N;
    int K = params.K;

    std::string mnk_string = "gemm<" + std::to_string(M) + ", " + std::to_string(N) + ", " + std::to_string(K) + ">";
    auto encoded_mnk_string = base64_encode(mnk_string);

    std::string mnk_split_k_string = "gemm<" + std::to_string(M) + ", " + std::to_string(N) + ", " + std::to_string(K) + ">" + ", split_k";
    auto encoded_mnk_split_k_string = base64_encode(mnk_split_k_string);

    int split_k;
    int kernel_id;
    std::string best_config;
    CutlassGemmConfigMannager& best_config_mannager = CutlassGemmConfigMannager::getInstance();
    if(getenv("FLAGS_use_cutlass_device_best_config_path")){ // run kernel
        std::string config_file_path = getenv("FLAGS_use_cutlass_device_best_config_path");
        nlohmann::json* config_json = new nlohmann::json();
        if (config_file_path != "default") {
            config_json = best_config_mannager.get_gemm_best_configs(config_file_path);
        }
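
        // get_relative_best presumably falls back to the supplied default when the
        // cache has no entry for this shape (assumed semantics of the helper).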
        best_config = get_relative_best<std::string>(config_json, encoded_mnk_string, "<64, 64, 64>, <32, 32, 64>, <16, 8, 32>, 3");
        split_k = get_relative_best<int>(config_json, encoded_mnk_split_k_string, 1);

        if (gemm_config_map.find(best_config) == gemm_config_map.end()) {
            throw std::runtime_error("The kernel for this config was not generated; please check generate_code_gemm_fused_kernels.py and re-generate.");
        } else {
            kernel_id = gemm_config_map[best_config];
        }
        return launch_gemm_kernel(type_id, split_k, kernel_id, params);
    } else { // tune kernel
        int warm_up_times = 5;
        int tune_times = 10;
        std::string best_kernel_id = "";
        int best_split_k = -1;
        float duration = 1000000.f;
        // tune all (split_k, kernel_id) combinations
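        // Each candidate is warmed up warm_up_times times, then timed over
        // tune_times launches with CUDA events; the fastest valid pair wins.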
        for(int i = 1; i < {max_split_k}+1; ++i){ // all split_k
            for(const auto& config_pair : gemm_config_map){
                bool is_valid = true;
                // warm up
                for(int num_time = 0; num_time < warm_up_times; ++num_time){
                    if(!launch_gemm_kernel(type_id, i, config_pair.second, params)){
                        is_valid = false;
                        break;
                    }
                }
                if(!is_valid){
                    continue;
                }
                cudaEvent_t start, stop;
                cudaEventCreate(&start);
                cudaEventCreate(&stop);
                cudaStreamSynchronize(params.stream);
                cudaEventRecord(start, params.stream);
                for(int num_time = 0; num_time < tune_times; ++num_time){
                    if(!launch_gemm_kernel(type_id, i, config_pair.second, params)){
                        is_valid = false;
                        break;
                    }
                }
                cudaEventRecord(stop, params.stream);
                cudaEventSynchronize(stop);
                float elapsedTime;
                if(is_valid){
                    cudaEventElapsedTime(&elapsedTime, start, stop);
                } else {
                    continue;
                }
                cudaEventDestroy(start);
                cudaEventDestroy(stop);
                if(elapsedTime < duration){
                    best_kernel_id = config_pair.first;
                    best_split_k = i;
                    duration = elapsedTime;
                }
            }
        }

        nlohmann::json new_json;
        new_json[encoded_mnk_string] = best_kernel_id;
        new_json[encoded_mnk_split_k_string] = best_split_k;
        best_config_mannager.up_date_configs(new_json);
        std::cout << "Gemm tune result for " << mnk_string << ": best config is: " << best_kernel_id << ", split k: " << best_split_k << std::endl;
        return true;
    }
}
"""


def SubstituteTemplate(template, values):
    """
    Substitute named placeholders of the form {key} in a code template.

    The substitution loop runs until a fixed point, so values that themselves
    contain placeholders are expanded as well.
    """
    text = template
    changed = True
    while changed:
        changed = False
        for key, value in values.items():
            regex = f"\\{{{key}\\}}"
            newtext = re.sub(regex, value, text)
            if newtext != text:
                changed = True
                text = newtext
    return text


def parse_args():
    """
    Parse command-line arguments.
    """
    parser = argparse.ArgumentParser(
        description="Arguments for generating the generic_mixed_gemm_kernelLauncher instance."
    )
    parser.add_argument(
        "--cuda_arch",
        type=str,
        nargs="+",
        default=["89"],
        help="The CUDA architecture to be generated.",
    )

    parser.add_argument(
        "--min_split_k",
        type=int,
        default=2,
        help="The min split k for the gemm kernel.",
    )

    parser.add_argument(
        "--max_split_k",
        type=int,
        default=6,
        help="The max split k for the gemm kernel.",
    )

    parser.add_argument(
        "--min_stages",
        type=int,
        default=2,
        help="The min stages for the gemm kernel.",
    )

    parser.add_argument(
        "--max_stages",
        type=int,
        default=8,
        help="The max stages for the gemm kernel.",
    )

    args = parser.parse_args()
    return args


# generate source .cu
def generate_source_cu(
    inputs_type: tuple,
    outputs_type: tuple,
    stages: tuple,
    tiles: list,
    act_tag: str,
    hasbiases: tuple,
    sm: str,
):
    """
    Generate the .cu source file for one activation variant.
    """
    value_dict = {
        "act_tag": act_tag,
    }
    all_code = SubstituteTemplate(CommonHead, value_dict)
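
    # First pass declares the plain GEMM kernels; the second identical pass
    # below declares their split-k variants.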
    for input_type in inputs_type:
        for output_type in outputs_type:
            for stage in stages:
                for hasbias in hasbiases:
                    for tile_config in tiles:
                        value_dict = {
                            "input_type": input_type,
                            "output_type": output_type,
                            "thread_block_shape": tile_config[0],
                            "warp_shape": tile_config[1],
                            "mma_shape": tile_config[2],
                            "num_stages": str(stage),
                            "act_tag": act_tag,
                            "hasbias": hasbias,
                            "SM": sm,
                        }
                        all_code += SubstituteTemplate(GemmDeclare, value_dict)

    for input_type in inputs_type:
        for output_type in outputs_type:
            for stage in stages:
                for hasbias in hasbiases:
                    for tile_config in tiles:
                        value_dict = {
                            "input_type": input_type,
                            "output_type": output_type,
                            "thread_block_shape": tile_config[0],
                            "warp_shape": tile_config[1],
                            "mma_shape": tile_config[2],
                            "num_stages": str(stage),
                            "act_tag": act_tag,
                            "hasbias": hasbias,
                            "SM": sm,
                        }
                        all_code += SubstituteTemplate(GemmSplitKDeclare, value_dict)

    all_code += CommonTail
    return all_code


# generate gemm launch .cu (one file per tile/stage config, so they can be
# compiled in parallel)
def generate_launch_gemm_cus(
    generate_dir: str,
    inputs_type: tuple,
    outputs_type: tuple,
    stages: tuple,
    split_ks: tuple,
    tiles: list,
    act_tags: list,
    hasbiases: tuple,
    sm: str,
    min_split_k: int,
    max_split_k: int,
):
    """
    Generate launch_gemm_kernel.h and one launch_gemm_kernel_<config>.cu per
    (tile, stage) configuration.
    """
    code_map = {}
    head_path = os.path.join(generate_dir, "launch_gemm_kernel.h")
    head_all_code = LaunchGemmHead
    for tile in tiles:
        blocks, warps, mmas = [s.replace(" ", "").strip("<>").split(",") for s in tile]
        gemm_config = f"block{blocks[0]}x{blocks[1]}x{blocks[2]}_warp{warps[0]}x{warps[1]}x{warps[2]}_mma{mmas[0]}x{mmas[1]}x{mmas[2]}"
        for stage in stages:
            gemm_config_str = gemm_config + f"_stage{stage}"
            value_dict = {
                "gemm_config": gemm_config_str,
            }
            head_all_code += SubstituteTemplate(LaunchGemmDeclare, value_dict)
    os.makedirs(generate_dir, exist_ok=True)
    with open(head_path, "w") as f:
        f.write(head_all_code)

    for tile in tiles:
        blocks, warps, mmas = [s.replace(" ", "").strip("<>").split(",") for s in tile]
        gemm_config = f"block{blocks[0]}x{blocks[1]}x{blocks[2]}_warp{warps[0]}x{warps[1]}x{warps[2]}_mma{mmas[0]}x{mmas[1]}x{mmas[2]}"
        for stage in stages:
            gemm_config_str = gemm_config + f"_stage{stage}"
            value_dict = {
                "gemm_config": gemm_config_str,
            }
            source_all_code = SubstituteTemplate(LaunchGemmPart0, value_dict)
            split_k_code = ""
            type_id = 0
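            # type_id enumerates (input_type, output_type, act_tag, hasbias) in the
            # same nesting order as generate_dispatch_gemm_cu, so the generated
            # switch cases line up with gemm_type_map.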
            for input_type in inputs_type:
                for output_type in outputs_type:
                    for act_tag in act_tags:
                        for hasbias in hasbiases:
                            value_dict = {
                                "act_tag": act_tag,
                                "input_type": input_type,
                                "output_type": output_type,
                                "hasbias": hasbias,
                                "type_id": str(type_id),
                                "thread_block_shape": tile[0],
                                "warp_shape": tile[1],
                                "mma_shape": tile[2],
                                "num_stages": str(stage),
                                "SM": sm,
                            }
                            source_all_code += SubstituteTemplate(LaunchGemmPart1, value_dict)
                            split_k_code += SubstituteTemplate(LaunchGemmPart3, value_dict)
                            type_id += 1
            source_all_code += LaunchGemmPart2
            source_all_code += split_k_code
            source_all_code += LaunchGemmPart4
            code_map[gemm_config_str] = source_all_code
            source_path = os.path.join(generate_dir, f"launch_gemm_kernel_{gemm_config_str}.cu")
            with open(source_path, "w") as f:
                f.write(source_all_code)

    return head_all_code, code_map


# generate fp8_fp8_gemm_scale_bias_act.cu
def generate_dispatch_gemm_cu(
    inputs_type: tuple,
    outputs_type: tuple,
    stages: tuple,
    split_ks: tuple,
    tiles: list,
    act_tags: list,
    hasbiases: tuple,
    sm: str,
    min_split_k: int,
    max_split_k: int,
):
    """
    Generate the top-level dispatch file fp8_fp8_gemm_scale_bias_act.cu.
    """
    all_code = code_part0
    type_id = 0
    for input_type in inputs_type:
        for output_type in outputs_type:
            for act_tag in act_tags:
                for hasbias in hasbiases:
                    value_dict = {
                        "act_tag": act_tag,
                        "input_type": input_type,
                        "output_type": output_type,
                        "hasbias": hasbias,
                        "type_id": str(type_id),
                    }
                    all_code += SubstituteTemplate(code_part1, value_dict)
                    type_id += 1

    all_code += code_part2
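    # Keys of gemm_config_map must match the best_config strings stored in the
    # tuning cache, e.g. "<64, 64, 64>, <32, 32, 64>, <16, 8, 32>, 3".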
    tile_id = 0
    for tile in tiles:
        for stage in stages:
            value_dict = {
                "thread_block_shape": tile[0],
                "warp_shape": tile[1],
                "mma_shape": tile[2],
                "num_stages": str(stage),
                "tile_id": str(tile_id),
            }
            all_code += SubstituteTemplate(code_part3, value_dict)
            tile_id += 1
    all_code += code_part4
    tile_id = 0
    for tile in tiles:
        blocks, warps, mmas = [s.replace(" ", "").strip("<>").split(",") for s in tile]
        gemm_config = f"block{blocks[0]}x{blocks[1]}x{blocks[2]}_warp{warps[0]}x{warps[1]}x{warps[2]}_mma{mmas[0]}x{mmas[1]}x{mmas[2]}"
        for stage in stages:
            gemm_config_str = gemm_config + f"_stage{stage}"
            value_dict = {
                "tile_id": str(tile_id),
                "gemm_config": gemm_config_str,
            }
            all_code += SubstituteTemplate(code_part5, value_dict)
            tile_id += 1
    value_dict.update(
        {
            "min_split_k": str(min_split_k),
            "max_split_k": str(max_split_k),
        }
    )
    all_code += SubstituteTemplate(code_part6, value_dict)
    return all_code
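

# Typical invocation, using the defaults defined in parse_args():
#   python generate_code_gemm_fused_kernels.py --cuda_arch 89 --min_split_k 2 \
#       --max_split_k 6 --min_stages 2 --max_stages 8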
if __name__ == "__main__":
    args = parse_args()
    archs = args.cuda_arch
    min_split_k = args.min_split_k
    max_split_k = args.max_split_k
    min_stages = args.min_stages
    max_stages = args.max_stages
    inputs_type = ("float8_e4m3fn", "float8_e5m2")
    outputs_type = ("float16", "bfloat16")
    sm_dict = {"89": "cutlass::arch::Sm89", "90": "cutlass::arch::Sm90"}

    for sm in archs:
        if sm == "89":
            fuse_gemm_configs = get_candidate_configs(
                sm, min_split_k, max_split_k, min_stages, max_stages
            )
            for fuse_gemm_config in fuse_gemm_configs:
                file_name = f"gpu_ops/cutlass_kernels/fp8_gemm_fused/autogen/generic_gemm_kernel_sm{sm}_{fuse_gemm_config[3][0]}.cu"
                all_code = generate_source_cu(
                    inputs_type,
                    outputs_type,
                    fuse_gemm_config[0],
                    fuse_gemm_config[2],
                    fuse_gemm_config[3][0],
                    fuse_gemm_config[4],
                    sm_dict[sm],
                )
                file_dir = os.path.dirname(file_name)
                os.makedirs(file_dir, exist_ok=True)
                with open(file_name, "w") as f:
                    f.write(all_code)

            fuse_gemm_config = list(fuse_gemm_configs)[0]

            act_tags = ["noact", "relu", "gelu"]
            # Compile parallelization
            generate_launch_gemm_cus(
                "gpu_ops/cutlass_kernels/fp8_gemm_fused/autogen",
                inputs_type,
                outputs_type,
                fuse_gemm_config[0],
                fuse_gemm_config[1],
                fuse_gemm_config[2],
                act_tags,
                fuse_gemm_config[4],
                sm_dict[sm],
                min_split_k,
                max_split_k,
            )

            # act_tags is hard-coded above
            file_name = "gpu_ops/cutlass_kernels/fp8_gemm_fused/fp8_fp8_gemm_scale_bias_act.cu"
            all_code = generate_dispatch_gemm_cu(
                inputs_type,
                outputs_type,
                fuse_gemm_config[0],
                fuse_gemm_config[1],
                fuse_gemm_config[2],
                act_tags,
                fuse_gemm_config[4],
                sm_dict[sm],
                min_split_k,
                max_split_k,
            )
            file_dir = os.path.dirname(file_name)
            os.makedirs(file_dir, exist_ok=True)
            with open(file_name, "w") as f:
                f.write(all_code)

        elif sm == "90":
            print("Not supported yet.")
            exit(0)
        else:
            raise ValueError(f"Unsupported SM: {sm}")