mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-04 16:22:57 +08:00
143 lines
4.6 KiB
C++
143 lines
4.6 KiB
C++
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
|
|
//
|
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
// you may not use this file except in compliance with the License.
|
|
// You may obtain a copy of the License at
|
|
//
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
//
|
|
// Unless required by applicable law or agreed to in writing, software
|
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
// See the License for the specific language governing permissions and
|
|
// limitations under the License.
|
|
|
|
#pragma once
|
|
#include <cstdio>
#include <exception>
#include <fstream>
#include <mutex>
#include <stdexcept>
#include <string>
#include <utility>

#include "cutlass/bfloat16.h"
#include "cutlass/cutlass.h"
#include "cutlass/half.h"
#include "helper.h"
#include "paddle/extension.h"
|
|
|
|
/**
 * Helper macro for checking CUTLASS errors.
 *
 * Evaluates `status` exactly once, stores it in a local and asserts through
 * PD_CHECK that it equals cutlass::Status::kSuccess; the PD_CHECK message is
 * the string returned by cutlassGetStatusString for that status.
 *
 * The local uses a deliberately unusual name: with a common name such as
 * `error`, a call like CUTLASS_CHECK(error) would expand to
 * `cutlass::Status error = error;` and read an uninitialized variable.
 */
#define CUTLASS_CHECK(status)                                                  \
  {                                                                            \
    cutlass::Status cutlass_check_status_ = (status);                          \
    PD_CHECK(cutlass_check_status_ == cutlass::Status::kSuccess,               \
             cutlassGetStatusString(cutlass_check_status_));                   \
  }
|
|
|
|
/**
|
|
* A wrapper for a kernel that is used to guard against compilation on
|
|
* architectures that will never use the kernel. The purpose of this is to
|
|
* reduce the size of the compiled binary.
|
|
* __CUDA_ARCH__ is not defined in host code, so this lets us smuggle the ifdef
|
|
* into code that will be executed on the device where it is defined.
|
|
*/
|
|
template <typename Kernel> struct enable_sm90_or_later : Kernel {
|
|
template <typename... Args> CUTLASS_DEVICE void operator()(Args &&...args) {
|
|
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 900
|
|
Kernel::operator()(std::forward<Args>(args)...);
|
|
#endif
|
|
}
|
|
};
|
|
|
|
/**
 * Maps a paddle::DataType enumerator to a pair of element types. Only the
 * explicit specializations are defined; the primary template is declared but
 * intentionally left without a definition.
 */
template <paddle::DataType D>
class CutlassDtypeTraits;
|
|
|
|
/// FLOAT32 traits: both aliases resolve to the built-in `float`.
template <>
class CutlassDtypeTraits<paddle::DataType::FLOAT32> {
public:
  using DataType = float;
  using data_t = float;
};
|
|
|
|
template <> class CutlassDtypeTraits<paddle::DataType::FLOAT16> {
|
|
public:
|
|
typedef cutlass::half_t DataType;
|
|
typedef paddle::float16 data_t;
|
|
};
|
|
|
|
template <> class CutlassDtypeTraits<paddle::DataType::BFLOAT16> {
|
|
public:
|
|
typedef cutlass::bfloat16_t DataType;
|
|
typedef paddle::bfloat16 data_t;
|
|
};
|
|
|
|
class CutlassGemmConfigMannager {
|
|
public:
|
|
static CutlassGemmConfigMannager &getInstance() {
|
|
static CutlassGemmConfigMannager instance;
|
|
return instance;
|
|
}
|
|
|
|
CutlassGemmConfigMannager(const CutlassGemmConfigMannager &) = delete;
|
|
CutlassGemmConfigMannager &
|
|
operator=(const CutlassGemmConfigMannager &) = delete;
|
|
|
|
void up_date_configs(const nlohmann::json &j) {
|
|
std::lock_guard<std::mutex> lock(mutex_);
|
|
for (auto it = j.begin(); it != j.end(); ++it) {
|
|
json_[it.key()] = it.value();
|
|
}
|
|
}
|
|
|
|
nlohmann::json *get_gemm_best_configs(const std::string &config_file_path) {
|
|
if (!load_initialized_) {
|
|
std::ifstream file(config_file_path);
|
|
if (!file.good()) {
|
|
throw std::runtime_error(
|
|
"cutlass gemm_best_config can not be found, please set "
|
|
"gemm_best_config'path as "
|
|
"FLAGS_use_cutlass_device_best_config_path, or unset "
|
|
"FLAGS_use_cutlass_device_best_config_path to tune "
|
|
"gemm_best_config");
|
|
}
|
|
json_ = readJsonFromFile(config_file_path);
|
|
load_initialized_ = true;
|
|
save_initialized_ = false;
|
|
}
|
|
return &json_;
|
|
}
|
|
|
|
private:
|
|
void save_gemm_best_configs_(const std::string &config_file_path) {
|
|
std::ifstream file(config_file_path);
|
|
if (!file.good()) {
|
|
std::ofstream new_file(config_file_path);
|
|
new_file << json_.dump(4);
|
|
new_file.close();
|
|
} else {
|
|
nlohmann::json old_json = readJsonFromFile(config_file_path);
|
|
for (auto it = json_.begin(); it != json_.end(); ++it) {
|
|
old_json[it.key()] = it.value();
|
|
}
|
|
json_ = old_json;
|
|
std::ofstream new_file(config_file_path, std::ios::out | std::ios::trunc);
|
|
new_file << json_.dump(4);
|
|
new_file.close();
|
|
file.close();
|
|
}
|
|
return;
|
|
}
|
|
|
|
CutlassGemmConfigMannager()
|
|
: json_(nullptr), load_initialized_(false), save_initialized_(true) {}
|
|
~CutlassGemmConfigMannager() {
|
|
std::lock_guard<std::mutex> lock(mutex_);
|
|
if (save_initialized_) {
|
|
std::string config_file_path = "fp8_fuse_gemm_config.json";
|
|
save_gemm_best_configs_(config_file_path);
|
|
}
|
|
save_initialized_ = true;
|
|
load_initialized_ = false;
|
|
json_.clear();
|
|
}
|
|
mutable std::mutex mutex_;
|
|
nlohmann::json json_;
|
|
bool load_initialized_;
|
|
bool save_initialized_;
|
|
};
|