Files
FastDeploy/custom_ops/gpu_ops/cutlass_kernels/cutlass_helper.h
2025-06-29 23:29:37 +00:00

143 lines
4.6 KiB
C++

// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <exception>
#include <fstream>
#include <iostream>
#include <mutex>
#include <stdexcept>
#include <string>
#include <utility>
#include "cutlass/bfloat16.h"
#include "cutlass/cutlass.h"
#include "cutlass/half.h"
#include "helper.h"
#include "paddle/extension.h"
/**
 * CUTLASS_CHECK(status): evaluates `status` exactly once as a cutlass::Status
 * and fails a PD_CHECK with cutlassGetStatusString(status) as the message when
 * it is not cutlass::Status::kSuccess.
 *
 * The local uses a name unlikely to collide with caller variables: with a
 * plain name such as `error`, CUTLASS_CHECK(error) would expand to
 * `cutlass::Status error = error;` (self-initialization, undefined behavior).
 * The expansion is a plain braced block (not do { } while (0)) so existing
 * call sites keep compiling unchanged. Note that this means it cannot be used
 * as the unbraced body of an `if` that has an `else` branch.
 */
#define CUTLASS_CHECK(status)                                      \
  {                                                                \
    const cutlass::Status cutlass_check_status_ = (status);        \
    PD_CHECK(cutlass_check_status_ == cutlass::Status::kSuccess,   \
             cutlassGetStatusString(cutlass_check_status_));       \
  }
/**
 * enable_sm90_or_later<Kernel> derives from Kernel and re-exposes its
 * device-side operator(). The wrapped call is compiled only in the device
 * pass for __CUDA_ARCH__ >= 900. For any older target architecture,
 * operator() compiles to an empty body, so the kernel's code is not emitted
 * for architectures that will never run it. This keeps the compiled binary
 * smaller.
 *
 * The preprocessor check has to sit inside the device function:
 * __CUDA_ARCH__ is defined only while compiling device code, never in the
 * host pass.
 */
template <typename Kernel>
struct enable_sm90_or_later : Kernel {
  template <typename... KernelArgs>
  CUTLASS_DEVICE void operator()(KernelArgs &&...kernel_args) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
    Kernel::operator()(std::forward<KernelArgs>(kernel_args)...);
#endif
  }
};
/**
 * CutlassDtypeTraits<D> maps a paddle::DataType tag to two element types:
 *   - DataType: the CUTLASS-side element type (float, cutlass::half_t,
 *     cutlass::bfloat16_t);
 *   - data_t:   the corresponding Paddle-side element type (float,
 *     paddle::float16, paddle::bfloat16).
 * Only FLOAT32, FLOAT16 and BFLOAT16 are specialized. The primary template is
 * declared but never defined, so naming the traits of any other dtype is a
 * compile-time error (incomplete type).
 */
template <paddle::DataType D>
class CutlassDtypeTraits;

template <>
class CutlassDtypeTraits<paddle::DataType::FLOAT32> {
 public:
  using DataType = float;
  using data_t = float;
};

template <>
class CutlassDtypeTraits<paddle::DataType::FLOAT16> {
 public:
  using DataType = cutlass::half_t;
  using data_t = paddle::float16;
};

template <>
class CutlassDtypeTraits<paddle::DataType::BFLOAT16> {
 public:
  using DataType = cutlass::bfloat16_t;
  using data_t = paddle::bfloat16;
};
class CutlassGemmConfigMannager {
public:
static CutlassGemmConfigMannager &getInstance() {
static CutlassGemmConfigMannager instance;
return instance;
}
CutlassGemmConfigMannager(const CutlassGemmConfigMannager &) = delete;
CutlassGemmConfigMannager &
operator=(const CutlassGemmConfigMannager &) = delete;
void up_date_configs(const nlohmann::json &j) {
std::lock_guard<std::mutex> lock(mutex_);
for (auto it = j.begin(); it != j.end(); ++it) {
json_[it.key()] = it.value();
}
}
nlohmann::json *get_gemm_best_configs(const std::string &config_file_path) {
if (!load_initialized_) {
std::ifstream file(config_file_path);
if (!file.good()) {
throw std::runtime_error(
"cutlass gemm_best_config can not be found, please set "
"gemm_best_config'path as "
"FLAGS_use_cutlass_device_best_config_path, or unset "
"FLAGS_use_cutlass_device_best_config_path to tune "
"gemm_best_config");
}
json_ = readJsonFromFile(config_file_path);
load_initialized_ = true;
save_initialized_ = false;
}
return &json_;
}
private:
void save_gemm_best_configs_(const std::string &config_file_path) {
std::ifstream file(config_file_path);
if (!file.good()) {
std::ofstream new_file(config_file_path);
new_file << json_.dump(4);
new_file.close();
} else {
nlohmann::json old_json = readJsonFromFile(config_file_path);
for (auto it = json_.begin(); it != json_.end(); ++it) {
old_json[it.key()] = it.value();
}
json_ = old_json;
std::ofstream new_file(config_file_path, std::ios::out | std::ios::trunc);
new_file << json_.dump(4);
new_file.close();
file.close();
}
return;
}
CutlassGemmConfigMannager()
: json_(nullptr), load_initialized_(false), save_initialized_(true) {}
~CutlassGemmConfigMannager() {
std::lock_guard<std::mutex> lock(mutex_);
if (save_initialized_) {
std::string config_file_path = "fp8_fuse_gemm_config.json";
save_gemm_best_configs_(config_file_path);
}
save_initialized_ = true;
load_initialized_ = false;
json_.clear();
}
mutable std::mutex mutex_;
nlohmann::json json_;
bool load_initialized_;
bool save_initialized_;
};