Files
FastDeploy/custom_ops/gpu_ops/cutlass_kernels/cutlass_helper.h
2025-06-16 00:04:48 +08:00

122 lines
3.9 KiB
C++

// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <fstream>
#include <mutex>
#include <stdexcept>
#include <string>
#include <utility>
#include "cutlass/bfloat16.h"
#include "cutlass/half.h"
#include "helper.h"
#include "paddle/extension.h"
// Compile-time mapping from a Paddle runtime dtype to a pair of element types:
//   DataType - the CUTLASS-side element type (e.g. cutlass::half_t).
//   data_t   - the Paddle-side element type (e.g. paddle::float16).
// Only FLOAT32, FLOAT16 and BFLOAT16 are specialized. The primary template is
// declared but never defined, so naming a member of any other specialization
// is a compile error.
template <paddle::DataType D>
class CutlassDtypeTraits;

// FLOAT32: both sides use plain `float`.
template <>
class CutlassDtypeTraits<paddle::DataType::FLOAT32> {
 public:
  using DataType = float;
  using data_t = float;
};

// FLOAT16: CUTLASS half precision <-> Paddle float16.
template <>
class CutlassDtypeTraits<paddle::DataType::FLOAT16> {
 public:
  using DataType = cutlass::half_t;
  using data_t = paddle::float16;
};

// BFLOAT16: CUTLASS bfloat16 <-> Paddle bfloat16.
template <>
class CutlassDtypeTraits<paddle::DataType::BFLOAT16> {
 public:
  using DataType = cutlass::bfloat16_t;
  using data_t = paddle::bfloat16;
};
// Process-wide singleton holding the tuned CUTLASS GEMM configurations as JSON.
//
// Configs arrive in one of two ways:
//   * up_date_configs() merges freshly tuned entries into the in-memory JSON;
//   * get_gemm_best_configs() loads a previously saved config file (once).
// Unless configs were loaded from disk, the destructor merges the in-memory
// entries into "fp8_fuse_gemm_config.json" (a relative path, so it resolves
// against the process working directory). Because the instance is a
// function-local static, the destructor runs at program exit.
//
// Thread-safety: every member function takes mutex_. However,
// get_gemm_best_configs() returns a raw pointer to the internal JSON, and
// accesses made through that pointer are NOT protected by the lock.
class CutlassGemmConfigMannager {
 public:
  // Returns the lazily constructed, process-wide instance.
  static CutlassGemmConfigMannager& getInstance() {
    static CutlassGemmConfigMannager instance;
    return instance;
  }

  CutlassGemmConfigMannager(const CutlassGemmConfigMannager&) = delete;
  CutlassGemmConfigMannager& operator=(const CutlassGemmConfigMannager&) =
      delete;

  // Merges every top-level entry of `j` into the in-memory configs.
  // Existing keys are overwritten. `j` is expected to be a JSON object,
  // since it.key() is used on its iterators.
  void up_date_configs(const nlohmann::json& j) {
    std::lock_guard<std::mutex> lock(mutex_);
    for (auto it = j.begin(); it != j.end(); ++it) {
      json_[it.key()] = it.value();
    }
  }

  // On the first call, loads the configs from `config_file_path` and returns a
  // pointer to the in-memory JSON. Later calls ignore the path and return the
  // already-loaded configs. Loading replaces anything merged earlier through
  // up_date_configs() and disables the save-on-exit in the destructor.
  //
  // Throws std::runtime_error if the file cannot be opened. Any exception
  // thrown by readJsonFromFile (defined elsewhere) propagates unchanged.
  nlohmann::json* get_gemm_best_configs(const std::string& config_file_path) {
    // Guard the one-time load: without the lock, two threads could both see
    // load_initialized_ == false and race on json_.
    std::lock_guard<std::mutex> lock(mutex_);
    if (!load_initialized_) {
      std::ifstream file(config_file_path);
      if (!file.good()) {
        throw std::runtime_error(
            "cutlass gemm_best_config can not be found, please set "
            "gemm_best_config'path as "
            "FLAGS_use_cutlass_device_best_config_path, or unset "
            "FLAGS_use_cutlass_device_best_config_path to tune "
            "gemm_best_config");
      }
      file.close();
      json_ = readJsonFromFile(config_file_path);
      load_initialized_ = true;
      save_initialized_ = false;
    }
    return &json_;
  }

 private:
  // Writes the in-memory configs to `config_file_path`, pretty-printed with a
  // 4-space indent. If the file already exists, its entries are merged first,
  // with in-memory values winning on key collisions. The merged result also
  // becomes the new in-memory state. The caller must hold mutex_.
  void save_gemm_best_configs_(const std::string& config_file_path) {
    // Nothing was ever merged in (json_ is still null or empty). Do not create
    // or rewrite the file; otherwise a process that never tuned anything
    // would leave behind a config file containing just "null".
    if (json_.empty()) {
      return;
    }

    bool file_exists = false;
    {
      std::ifstream probe(config_file_path);
      file_exists = probe.good();
    }  // Close the read handle before the file is reopened for truncation.

    if (file_exists) {
      nlohmann::json old_json = readJsonFromFile(config_file_path);
      for (auto it = json_.begin(); it != json_.end(); ++it) {
        old_json[it.key()] = it.value();
      }
      json_ = std::move(old_json);
    }

    // Open/write failures are not reported here; the save is best-effort.
    std::ofstream new_file(config_file_path, std::ios::out | std::ios::trunc);
    new_file << json_.dump(4);
  }

  CutlassGemmConfigMannager()
      : json_(nullptr), load_initialized_(false), save_initialized_(true) {}

  // Saves the in-memory configs on shutdown, unless they were loaded from
  // disk. Destructors are implicitly noexcept, so an exception escaping here
  // (from locking, readJsonFromFile, or the JSON operations) would call
  // std::terminate at program exit. The save is therefore best-effort and
  // any failure is deliberately swallowed.
  ~CutlassGemmConfigMannager() {
    try {
      std::lock_guard<std::mutex> lock(mutex_);
      if (save_initialized_) {
        const std::string config_file_path = "fp8_fuse_gemm_config.json";
        save_gemm_best_configs_(config_file_path);
      }
    } catch (...) {
      // No caller to report to during static destruction.
    }
  }

  mutable std::mutex mutex_;
  nlohmann::json json_;
  bool load_initialized_;  // true once configs were loaded from a file
  bool save_initialized_;  // true if the destructor should persist json_
};