mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-30 11:26:39 +08:00
[LLM] First commit the llm deployment code
This commit is contained in:
654
custom_ops/auto_gen_fp8_fp8_gemm_fused_kernels.py
Normal file
654
custom_ops/auto_gen_fp8_fp8_gemm_fused_kernels.py
Normal file
@@ -0,0 +1,654 @@
|
||||
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""generate gemm_fused_kernel code."""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import re
|
||||
|
||||
|
||||
def get_candidate_tiles():
|
||||
"""
|
||||
获取候选的tile配置列表。
|
||||
|
||||
Args:
|
||||
无参数。
|
||||
|
||||
Returns:
|
||||
List[Tuple[str, str, str]]: 包含tile配置的三元组列表,每个三元组中的字符串表示tile的形状"。
|
||||
|
||||
"""
|
||||
base_configs = [("<64, 64, 64>", "<32, 32, 64>", "<16, 8, 32>")]
|
||||
|
||||
base_configs.extend(
|
||||
[
|
||||
("<32, 128, 64>", "<32, 32, 64>", "<16, 8, 32>"),
|
||||
("<64, 128, 64>", "<32, 64, 64>", "<16, 8, 32>"),
|
||||
("<64, 64, 128>", "<32, 64, 64>", "<16, 8, 32>"),
|
||||
("<64, 128, 64>", "<64, 32, 64>", "<16, 8, 32>"),
|
||||
("<128, 64, 64>", "<64, 32, 64>", "<16, 8, 32>"),
|
||||
("<128, 128, 64>", "<64, 32, 64>", "<16, 8, 32>"),
|
||||
("<128, 128, 64>", "<64, 64, 64>", "<16, 8, 32>"),
|
||||
("<128, 128, 64>", "<128, 32, 64>", "<16, 8, 32>"),
|
||||
("<128, 256, 64>", "<64, 64, 64>", "<16, 8, 32>"),
|
||||
("<256, 128, 64>", "<64, 64, 64>", "<16, 8, 32>"),
|
||||
("<16, 256, 128>", "<16, 64, 128>", "<16, 8, 32>"),
|
||||
]
|
||||
)
|
||||
|
||||
return base_configs
|
||||
|
||||
|
||||
def get_candidate_configs(sm, min_split_k, max_split_k, min_stages, max_stages):
|
||||
"""
|
||||
获取候选的gemm算子配置列表。
|
||||
|
||||
Args:
|
||||
sm (str): 计算能力,如"70"。
|
||||
min_split_k (int): split k的最小值。
|
||||
"""
|
||||
tiles = get_candidate_tiles()
|
||||
candidate_configs = list()
|
||||
|
||||
stages = tuple(i for i in range(min_stages, max_stages + 1, 1))
|
||||
splitks = tuple(i for i in range(min_split_k, max_split_k + 1, 1))
|
||||
hasbias = ("false", "true")
|
||||
|
||||
for act_tag in [
|
||||
("noact", "LinearCombination"),
|
||||
("relu", "LinearCombinationRelu"),
|
||||
("gelu", "LinearCombinationGELU"),
|
||||
]:
|
||||
candidate_configs.extend([(stages, splitks, tiles, act_tag, hasbias)])
|
||||
|
||||
return candidate_configs
|
||||
|
||||
|
||||
# this is a file's header part
|
||||
CommonHead = """// Generated by generate_code_gemm_fused_kernels.py - Do not edit.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "fp8_gemm_fused/fuse_gemm_{act_tag}_template.h"
|
||||
|
||||
"""
|
||||
|
||||
|
||||
CommonTail = """
|
||||
|
||||
"""
|
||||
|
||||
GemmDeclare = """
|
||||
template<>
|
||||
bool dispatch_fuse_gemm_{act_tag}<phi::dtype::{input_type}, phi::dtype::{output_type},
|
||||
cutlass::gemm::GemmShape{thread_block_shape}, cutlass::gemm::GemmShape{warp_shape},
|
||||
cutlass::gemm::GemmShape{mma_shape}, {num_stages}, {hasbias}, {SM}>(GemmEpilogueAllParams);
|
||||
|
||||
|
||||
"""
|
||||
|
||||
|
||||
GemmSplitKDeclare = """
|
||||
template<>
|
||||
bool dispatch_fuse_gemm_split_k_{act_tag}<phi::dtype::{input_type}, phi::dtype::{output_type},
|
||||
cutlass::gemm::GemmShape{thread_block_shape}, cutlass::gemm::GemmShape{warp_shape},
|
||||
cutlass::gemm::GemmShape{mma_shape}, {num_stages}, {hasbias}, {SM}>(GemmEpilogueAllParams);
|
||||
|
||||
|
||||
"""
|
||||
|
||||
LaunchGemmHead = """
|
||||
#pragma once
|
||||
|
||||
#include "fp8_gemm_fused/fp8_fp8_gemm_scale_bias_act.h"
|
||||
|
||||
"""
|
||||
|
||||
LaunchGemmDeclare = """
|
||||
bool launch_gemm_kernel_{gemm_config}(const int type_id, const int split_k, GemmEpilogueAllParams params);
|
||||
"""
|
||||
|
||||
LaunchGemmPart0 = """
|
||||
#pragma once
|
||||
|
||||
#include "launch_gemm_kernel.h"
|
||||
|
||||
bool launch_gemm_kernel_{gemm_config}(const int type_id, const int split_k, GemmEpilogueAllParams params){
|
||||
if(split_k < 2){
|
||||
params.split_k = 1;
|
||||
switch (type_id) {
|
||||
"""
|
||||
|
||||
LaunchGemmPart1 = """
|
||||
case {type_id}:
|
||||
return dispatch_fuse_gemm_{act_tag}<phi::dtype::{input_type}, phi::dtype::{output_type},
|
||||
cutlass::gemm::GemmShape{thread_block_shape}, cutlass::gemm::GemmShape{warp_shape},
|
||||
cutlass::gemm::GemmShape{mma_shape}, {num_stages}, {hasbias}, {SM}>(params);
|
||||
break;
|
||||
"""
|
||||
|
||||
LaunchGemmPart2 = """
|
||||
default:
|
||||
throw std::runtime_error("cutlass gemm config is invalid.");
|
||||
break;
|
||||
}
|
||||
}else{
|
||||
switch (type_id) {
|
||||
"""
|
||||
|
||||
LaunchGemmPart3 = """
|
||||
case {type_id}:
|
||||
return dispatch_fuse_gemm_split_k_{act_tag}<phi::dtype::{input_type}, phi::dtype::{output_type},
|
||||
cutlass::gemm::GemmShape{thread_block_shape}, cutlass::gemm::GemmShape{warp_shape},
|
||||
cutlass::gemm::GemmShape{mma_shape}, {num_stages}, {hasbias}, {SM}>(params);
|
||||
break;
|
||||
"""
|
||||
|
||||
LaunchGemmPart4 = """
|
||||
default:
|
||||
throw std::runtime_error("cutlass gemm config is invalid.");
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
"""
|
||||
|
||||
|
||||
code_part0 = """// Generated by generate_code_gemm_fused_kernels.py - Do not edit.
|
||||
|
||||
#include <map>
|
||||
#include <regex>
|
||||
#include <limits>
|
||||
#include "helper.h"
|
||||
#include "fp8_fp8_gemm_scale_bias_act.h"
|
||||
#include "launch_gemm_kernel.h"
|
||||
|
||||
COMMON_DECLARE_string(use_cutlass_device_best_config_path);
|
||||
|
||||
std::map<std::string, int> gemm_type_map{"""
|
||||
|
||||
code_part1 = """
|
||||
{"{input_type}_{output_type}_{hasbias}_{act_tag}", {type_id}}, """
|
||||
|
||||
code_part2 = """
|
||||
};
|
||||
|
||||
std::map<std::string, int> gemm_config_map{
|
||||
"""
|
||||
|
||||
code_part3 = """ {"{thread_block_shape}, {warp_shape}, {mma_shape}, {num_stages}", {tile_id}},
|
||||
"""
|
||||
|
||||
code_part4 = """};
|
||||
|
||||
bool launch_gemm_kernel(const int type_id, const int split_k, const int kernel_id, GemmEpilogueAllParams params){
|
||||
switch (kernel_id) {"""
|
||||
|
||||
code_part5 = """
|
||||
case {tile_id}:
|
||||
return launch_gemm_kernel_{gemm_config}(type_id, split_k, params);
|
||||
break;"""
|
||||
|
||||
code_part6 = """
|
||||
default:
|
||||
throw std::runtime_error("fp8_fp8_bf16_gemm_fused Config is invalid.");
|
||||
break;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
bool fp8_fp8_gemm_scale_bias_act(GemmEpilogueAllParams params) {
|
||||
if (gemm_type_map.find(params.fuse_gemm_config) == gemm_type_map.end()) {
|
||||
throw std::runtime_error("fp8 gemm_fused config is invalid.");
|
||||
}
|
||||
|
||||
int type_id = gemm_type_map[params.fuse_gemm_config];
|
||||
int M = (params.M+31)/32 *32;
|
||||
int N = params.N;
|
||||
int K = params.K;
|
||||
|
||||
std::string mnk_string = "gemm<"+ std::to_string(M)+ ", " +std::to_string(N) + ", "+ std::to_string(K)+ ">";
|
||||
auto encoded_mnk_string = base64_encode(mnk_string);
|
||||
|
||||
std::string mnk_split_k_string = "gemm<"+ std::to_string(M)+ ", " +std::to_string(N) + ", "+ std::to_string(K)+ ">" + ", split_k";
|
||||
auto encoded_mnk_split_k_string = base64_encode(mnk_split_k_string);
|
||||
|
||||
int split_k;
|
||||
int kernel_id;
|
||||
std::string best_config;
|
||||
CutlassGemmConfigMannager& best_config_mannager = CutlassGemmConfigMannager::getInstance();
|
||||
if(getenv("FLAGS_use_cutlass_device_best_config_path")){ // run kernel
|
||||
std::string config_file_path = getenv("FLAGS_use_cutlass_device_best_config_path");
|
||||
nlohmann::json* config_json = new nlohmann::json();
|
||||
if (config_file_path != "default") {
|
||||
config_json = best_config_mannager.get_gemm_best_configs(config_file_path);
|
||||
}
|
||||
|
||||
best_config = get_relative_best<std::string>(config_json, encoded_mnk_string, "<64, 64, 64>, <32, 32, 64>, <16, 8, 32>, 3");
|
||||
split_k = get_relative_best<int>(config_json, encoded_mnk_split_k_string, 1);
|
||||
|
||||
if (gemm_config_map.find(best_config) == gemm_config_map.end()) {
|
||||
throw std::runtime_error("This config'kernel not be generate, please check generate_code_gemm_fused_kernels.py and re-generate.");
|
||||
} else {
|
||||
kernel_id = gemm_config_map[best_config];
|
||||
}
|
||||
return launch_gemm_kernel(type_id, split_k, kernel_id, params);
|
||||
} else { // tune kernel
|
||||
int warm_up_times = 5;
|
||||
int tune_times = 10;
|
||||
std::string best_kernel_id = "";
|
||||
int best_split_k = -1;
|
||||
float duratation = 1000000.f;
|
||||
// tune all split_k, kernel_id kernels
|
||||
for(int i = 1; i < {max_split_k}+1; ++i){ // all split_k
|
||||
for(const auto& config_pair : gemm_config_map){
|
||||
bool is_valid = true;
|
||||
// warm up
|
||||
for(int num_time = 0; num_time < warm_up_times; ++num_time){
|
||||
if(!launch_gemm_kernel(type_id, i, config_pair.second, params)){
|
||||
is_valid = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if(!is_valid){
|
||||
continue;
|
||||
}
|
||||
cudaEvent_t start, stop;
|
||||
cudaEventCreate(&start);
|
||||
cudaEventCreate(&stop);
|
||||
cudaStreamSynchronize(params.stream);
|
||||
cudaEventRecord(start, params.stream);
|
||||
for(int num_time = 0; num_time < tune_times; ++num_time){
|
||||
if(!launch_gemm_kernel(type_id, i, config_pair.second, params)){
|
||||
is_valid = false;
|
||||
break;
|
||||
};
|
||||
}
|
||||
cudaEventRecord(stop, params.stream);
|
||||
cudaEventSynchronize(stop);
|
||||
float elapsedTime;
|
||||
if(is_valid){
|
||||
cudaEventElapsedTime(&elapsedTime, start, stop);
|
||||
} else {
|
||||
continue;
|
||||
}
|
||||
cudaEventDestroy(start);
|
||||
cudaEventDestroy(stop);
|
||||
if(elapsedTime < duratation){
|
||||
best_kernel_id = config_pair.first;
|
||||
best_split_k = i;
|
||||
duratation = elapsedTime;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
nlohmann::json new_json;
|
||||
new_json[encoded_mnk_string] = best_kernel_id;
|
||||
new_json[encoded_mnk_split_k_string] = best_split_k;
|
||||
best_config_mannager.up_date_configs(new_json);
|
||||
std::cout <<"Gemm tune result for " << mnk_string<< ": best config is: "<< best_kernel_id << ", split k: " << best_split_k << std::endl;
|
||||
return true;
|
||||
}
|
||||
}
|
||||
"""
|
||||
|
||||
|
||||
def SubstituteTemplate(template, values):
|
||||
"""
|
||||
生成函数模板
|
||||
"""
|
||||
text = template
|
||||
changed = True
|
||||
while changed:
|
||||
changed = False
|
||||
for key, value in values.items():
|
||||
regex = f"\\{{{key}\\}}"
|
||||
newtext = re.sub(regex, value, text)
|
||||
if newtext != text:
|
||||
changed = True
|
||||
text = newtext
|
||||
return text
|
||||
|
||||
|
||||
def parse_args():
|
||||
"""
|
||||
代码参数解析
|
||||
"""
|
||||
parser = argparse.ArgumentParser(
|
||||
description="The argument for generating the generic_mixed_gemm_kernelLauncher instance."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--cuda_arch",
|
||||
type=str,
|
||||
nargs="+",
|
||||
default=["89"],
|
||||
help="The CUDA architecture to be generated.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--min_split_k",
|
||||
type=int,
|
||||
default=2,
|
||||
help="The min split k for the gemm kernel.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--max_split_k",
|
||||
type=int,
|
||||
default=6,
|
||||
help="The max split k for the gemm kernel.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--min_stages",
|
||||
type=int,
|
||||
default=2,
|
||||
help="The min stages for the gemm kernel.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--max_stages",
|
||||
type=int,
|
||||
default=8,
|
||||
help="The max stages for the gemm kernel.",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
|
||||
# generate source .cu
|
||||
def generate_source_cu(
|
||||
inputs_type: str,
|
||||
outputs_type: str,
|
||||
stages: int,
|
||||
tiles: str,
|
||||
act_tag: str,
|
||||
hasbiases: str,
|
||||
sm: str,
|
||||
):
|
||||
"""
|
||||
生成.cu源文件
|
||||
"""
|
||||
value_dict = {
|
||||
"act_tag": act_tag,
|
||||
}
|
||||
all_code = SubstituteTemplate(CommonHead, value_dict)
|
||||
|
||||
for input_type in inputs_type:
|
||||
for output_type in outputs_type:
|
||||
for stage in stages:
|
||||
for hasbias in hasbiases:
|
||||
for tile_config in tiles:
|
||||
value_dict = {
|
||||
"input_type": input_type,
|
||||
"output_type": output_type,
|
||||
"thread_block_shape": tile_config[0],
|
||||
"warp_shape": tile_config[1],
|
||||
"mma_shape": tile_config[2],
|
||||
"num_stages": str(stage),
|
||||
"act_tag": act_tag,
|
||||
"hasbias": hasbias,
|
||||
"SM": sm,
|
||||
}
|
||||
all_code += SubstituteTemplate(GemmDeclare, value_dict)
|
||||
|
||||
for input_type in inputs_type:
|
||||
for output_type in outputs_type:
|
||||
for stage in stages:
|
||||
for hasbias in hasbiases:
|
||||
for tile_config in tiles:
|
||||
value_dict = {
|
||||
"input_type": input_type,
|
||||
"output_type": output_type,
|
||||
"thread_block_shape": tile_config[0],
|
||||
"warp_shape": tile_config[1],
|
||||
"mma_shape": tile_config[2],
|
||||
"num_stages": str(stage),
|
||||
"act_tag": act_tag,
|
||||
"hasbias": hasbias,
|
||||
"SM": sm,
|
||||
}
|
||||
all_code += SubstituteTemplate(GemmSplitKDeclare, value_dict)
|
||||
|
||||
all_code += CommonTail
|
||||
return all_code
|
||||
|
||||
|
||||
# generate gemm launch .cu
|
||||
def generate_launch_gemm_cus(
|
||||
generate_dir: str,
|
||||
inputs_type: str,
|
||||
outputs_type: str,
|
||||
stages: int,
|
||||
split_ks: int,
|
||||
tiles: str,
|
||||
act_tags: str,
|
||||
hasbiases: str,
|
||||
sm: str,
|
||||
min_split_k: int,
|
||||
max_split_k: int,
|
||||
):
|
||||
code_map = {}
|
||||
head_path = os.path.join(generate_dir, "launch_gemm_kernel.h")
|
||||
head_all_code = LaunchGemmHead
|
||||
for tile in tiles:
|
||||
blocks, warps, mmas = [s.replace(" ", "").strip("<>").split(",") for s in tile]
|
||||
gemm_config = f"block{blocks[0]}x{blocks[1]}x{blocks[2]}_warp{warps[0]}x{warps[1]}x{warps[2]}_mma{mmas[0]}x{mmas[1]}x{mmas[2]}"
|
||||
for stage in stages:
|
||||
gemm_config_str = gemm_config + f"_stage{stage}"
|
||||
value_dict = {
|
||||
"gemm_config": gemm_config_str,
|
||||
}
|
||||
head_all_code += SubstituteTemplate(LaunchGemmDeclare, value_dict)
|
||||
os.makedirs(generate_dir, exist_ok=True)
|
||||
with open(head_path, "w") as f:
|
||||
f.write(head_all_code)
|
||||
f.close()
|
||||
|
||||
for tile in tiles:
|
||||
blocks, warps, mmas = [s.replace(" ", "").strip("<>").split(",") for s in tile]
|
||||
gemm_config = f"block{blocks[0]}x{blocks[1]}x{blocks[2]}_warp{warps[0]}x{warps[1]}x{warps[2]}_mma{mmas[0]}x{mmas[1]}x{mmas[2]}"
|
||||
for stage in stages:
|
||||
gemm_config_str = gemm_config + f"_stage{stage}"
|
||||
value_dict = {
|
||||
"gemm_config": gemm_config_str,
|
||||
}
|
||||
source_all_code = SubstituteTemplate(LaunchGemmPart0, value_dict)
|
||||
split_k_code = ""
|
||||
type_id = 0
|
||||
for input_type in inputs_type:
|
||||
for output_type in outputs_type:
|
||||
for act_tag in act_tags:
|
||||
for hasbias in hasbiases:
|
||||
value_dict = {
|
||||
"act_tag": act_tag,
|
||||
"input_type": input_type,
|
||||
"output_type": output_type,
|
||||
"hasbias": hasbias,
|
||||
"type_id": str(type_id),
|
||||
"thread_block_shape": tile[0],
|
||||
"warp_shape": tile[1],
|
||||
"mma_shape": tile[2],
|
||||
"num_stages": str(stage),
|
||||
"SM": sm,
|
||||
}
|
||||
source_all_code += SubstituteTemplate(
|
||||
LaunchGemmPart1, value_dict
|
||||
)
|
||||
split_k_code += SubstituteTemplate(
|
||||
LaunchGemmPart3, value_dict
|
||||
)
|
||||
type_id += 1
|
||||
source_all_code += LaunchGemmPart2
|
||||
source_all_code += split_k_code
|
||||
source_all_code += LaunchGemmPart4
|
||||
code_map[gemm_config_str] = source_all_code
|
||||
source_path = os.path.join(
|
||||
generate_dir, f"launch_gemm_kernel_{gemm_config_str}.cu"
|
||||
)
|
||||
with open(source_path, "w") as f:
|
||||
f.write(source_all_code)
|
||||
f.close()
|
||||
|
||||
return head_all_code, code_map
|
||||
|
||||
|
||||
# generate fp8_fp8_gemm_scale_bias_act.cu
|
||||
def generate_dispatch_gemm_cu(
|
||||
inputs_type: str,
|
||||
outputs_type: str,
|
||||
stages: int,
|
||||
split_ks: int,
|
||||
tiles: str,
|
||||
act_tags: str,
|
||||
hasbiases: str,
|
||||
sm: str,
|
||||
min_split_k: int,
|
||||
max_split_k: int,
|
||||
):
|
||||
|
||||
all_code = code_part0
|
||||
type_id = 0
|
||||
for input_type in inputs_type:
|
||||
for output_type in outputs_type:
|
||||
for act_tag in act_tags:
|
||||
for hasbias in hasbiases:
|
||||
value_dict = {
|
||||
"act_tag": act_tag,
|
||||
"input_type": input_type,
|
||||
"output_type": output_type,
|
||||
"hasbias": hasbias,
|
||||
"type_id": str(type_id),
|
||||
}
|
||||
all_code += SubstituteTemplate(code_part1, value_dict)
|
||||
type_id += 1
|
||||
|
||||
all_code += code_part2
|
||||
tile_id = 0
|
||||
for tile in tiles:
|
||||
for stage in stages:
|
||||
value_dict = {
|
||||
"thread_block_shape": tile[0],
|
||||
"warp_shape": tile[1],
|
||||
"mma_shape": tile[2],
|
||||
"num_stages": str(stage),
|
||||
"tile_id": str(tile_id),
|
||||
}
|
||||
all_code += SubstituteTemplate(code_part3, value_dict)
|
||||
tile_id += 1
|
||||
all_code += code_part4
|
||||
tile_id = 0
|
||||
for tile in tiles:
|
||||
blocks, warps, mmas = [s.replace(" ", "").strip("<>").split(",") for s in tile]
|
||||
gemm_config = f"block{blocks[0]}x{blocks[1]}x{blocks[2]}_warp{warps[0]}x{warps[1]}x{warps[2]}_mma{mmas[0]}x{mmas[1]}x{mmas[2]}"
|
||||
for stage in stages:
|
||||
gemm_config_str = gemm_config + f"_stage{stage}"
|
||||
value_dict = {
|
||||
"tile_id": str(tile_id),
|
||||
"gemm_config": gemm_config_str,
|
||||
}
|
||||
all_code += SubstituteTemplate(code_part5, value_dict)
|
||||
tile_id += 1
|
||||
value_dict.update(
|
||||
{
|
||||
"min_split_k": str(min_split_k),
|
||||
"max_split_k": str(max_split_k),
|
||||
}
|
||||
)
|
||||
all_code += SubstituteTemplate(code_part6, value_dict)
|
||||
return all_code
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parse_args()
|
||||
archs = args.cuda_arch
|
||||
min_split_k = args.min_split_k
|
||||
max_split_k = args.max_split_k
|
||||
min_stages = args.min_stages
|
||||
max_stages = args.max_stages
|
||||
inputs_type = ("float8_e4m3fn", "float8_e5m2")
|
||||
outputs_type = ("float16", "bfloat16")
|
||||
sm_dict = {"89": "cutlass::arch::Sm89", "90": "cutlass::arch::Sm90"}
|
||||
|
||||
for sm in archs:
|
||||
if sm == "89":
|
||||
fuse_gemm_configs = get_candidate_configs(
|
||||
sm, min_split_k, max_split_k, min_stages, max_stages
|
||||
)
|
||||
for fuse_gemm_config in fuse_gemm_configs:
|
||||
file_name = f"gpu_ops/cutlass_kernels/fp8_gemm_fused/autogen/generic_gemm_kernel_sm{sm}_{fuse_gemm_config[3][0]}.cu"
|
||||
all_code = generate_source_cu(
|
||||
inputs_type,
|
||||
outputs_type,
|
||||
fuse_gemm_config[0],
|
||||
fuse_gemm_config[2],
|
||||
fuse_gemm_config[3][0],
|
||||
fuse_gemm_config[4],
|
||||
sm_dict[sm],
|
||||
)
|
||||
file_dir = os.path.dirname(file_name)
|
||||
os.makedirs(file_dir, exist_ok=True)
|
||||
with open(file_name, "w") as f:
|
||||
f.write(all_code)
|
||||
f.close()
|
||||
|
||||
fuse_gemm_config = list(fuse_gemm_configs)[0]
|
||||
|
||||
act_tags = ["noact", "relu", "gelu"]
|
||||
# Compile parallelization
|
||||
generate_launch_gemm_cus(
|
||||
"gpu_ops/cutlass_kernels/fp8_gemm_fused/autogen",
|
||||
inputs_type,
|
||||
outputs_type,
|
||||
fuse_gemm_config[0],
|
||||
fuse_gemm_config[1],
|
||||
fuse_gemm_config[2],
|
||||
act_tags,
|
||||
fuse_gemm_config[4],
|
||||
sm_dict[sm],
|
||||
min_split_k,
|
||||
max_split_k,
|
||||
)
|
||||
|
||||
# hard code for act_tag
|
||||
|
||||
file_name = (
|
||||
"gpu_ops/cutlass_kernels/fp8_gemm_fused/fp8_fp8_gemm_scale_bias_act.cu"
|
||||
)
|
||||
all_code = generate_dispatch_gemm_cu(
|
||||
inputs_type,
|
||||
outputs_type,
|
||||
fuse_gemm_config[0],
|
||||
fuse_gemm_config[1],
|
||||
fuse_gemm_config[2],
|
||||
act_tags,
|
||||
fuse_gemm_config[4],
|
||||
sm_dict[sm],
|
||||
min_split_k,
|
||||
max_split_k,
|
||||
)
|
||||
file_dir = os.path.dirname(file_name)
|
||||
os.makedirs(file_dir, exist_ok=True)
|
||||
with open(file_name, "w") as f:
|
||||
f.write(all_code)
|
||||
f.close()
|
||||
|
||||
elif sm == 90:
|
||||
print("Not supported yet.")
|
||||
exit(0)
|
||||
else:
|
||||
raise ValueError(f"Unsupported SM: {sm}")
|
||||
Reference in New Issue
Block a user