mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-06 17:17:14 +08:00
122 lines
3.9 KiB
C++
122 lines
3.9 KiB
C++
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
|
|
//
|
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
// you may not use this file except in compliance with the License.
|
|
// You may obtain a copy of the License at
|
|
//
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
//
|
|
// Unless required by applicable law or agreed to in writing, software
|
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
// See the License for the specific language governing permissions and
|
|
// limitations under the License.
|
|
|
|
#pragma once
|
|
#include <mutex>
|
|
|
|
#include "cutlass/bfloat16.h"
|
|
#include "cutlass/half.h"
|
|
#include "helper.h"
|
|
#include "paddle/extension.h"
|
|
|
|
template <paddle::DataType D>
|
|
class CutlassDtypeTraits;
|
|
|
|
template <>
|
|
class CutlassDtypeTraits<paddle::DataType::FLOAT32> {
|
|
public:
|
|
typedef float DataType;
|
|
typedef float data_t;
|
|
};
|
|
|
|
template <>
|
|
class CutlassDtypeTraits<paddle::DataType::FLOAT16> {
|
|
public:
|
|
typedef cutlass::half_t DataType;
|
|
typedef paddle::float16 data_t;
|
|
};
|
|
|
|
template <>
|
|
class CutlassDtypeTraits<paddle::DataType::BFLOAT16> {
|
|
public:
|
|
typedef cutlass::bfloat16_t DataType;
|
|
typedef paddle::bfloat16 data_t;
|
|
};
|
|
|
|
class CutlassGemmConfigMannager {
|
|
public:
|
|
static CutlassGemmConfigMannager& getInstance() {
|
|
static CutlassGemmConfigMannager instance;
|
|
return instance;
|
|
}
|
|
|
|
CutlassGemmConfigMannager(const CutlassGemmConfigMannager&) = delete;
|
|
CutlassGemmConfigMannager& operator=(const CutlassGemmConfigMannager&) =
|
|
delete;
|
|
|
|
void up_date_configs(const nlohmann::json& j) {
|
|
std::lock_guard<std::mutex> lock(mutex_);
|
|
for (auto it = j.begin(); it != j.end(); ++it) {
|
|
json_[it.key()] = it.value();
|
|
}
|
|
}
|
|
|
|
nlohmann::json* get_gemm_best_configs(const std::string& config_file_path) {
|
|
if (!load_initialized_) {
|
|
std::ifstream file(config_file_path);
|
|
if (!file.good()) {
|
|
throw std::runtime_error(
|
|
"cutlass gemm_best_config can not be found, please set "
|
|
"gemm_best_config'path as "
|
|
"FLAGS_use_cutlass_device_best_config_path, or unset "
|
|
"FLAGS_use_cutlass_device_best_config_path to tune "
|
|
"gemm_best_config");
|
|
}
|
|
json_ = readJsonFromFile(config_file_path);
|
|
load_initialized_ = true;
|
|
save_initialized_ = false;
|
|
}
|
|
return &json_;
|
|
}
|
|
|
|
private:
|
|
void save_gemm_best_configs_(const std::string& config_file_path) {
|
|
std::ifstream file(config_file_path);
|
|
if (!file.good()) {
|
|
std::ofstream new_file(config_file_path);
|
|
new_file << json_.dump(4);
|
|
new_file.close();
|
|
} else {
|
|
nlohmann::json old_json = readJsonFromFile(config_file_path);
|
|
for (auto it = json_.begin(); it != json_.end(); ++it) {
|
|
old_json[it.key()] = it.value();
|
|
}
|
|
json_ = old_json;
|
|
std::ofstream new_file(config_file_path,
|
|
std::ios::out | std::ios::trunc);
|
|
new_file << json_.dump(4);
|
|
new_file.close();
|
|
file.close();
|
|
}
|
|
return;
|
|
}
|
|
|
|
CutlassGemmConfigMannager()
|
|
: json_(nullptr), load_initialized_(false), save_initialized_(true) {}
|
|
~CutlassGemmConfigMannager() {
|
|
std::lock_guard<std::mutex> lock(mutex_);
|
|
if (save_initialized_) {
|
|
std::string config_file_path = "fp8_fuse_gemm_config.json";
|
|
save_gemm_best_configs_(config_file_path);
|
|
}
|
|
save_initialized_ = true;
|
|
load_initialized_ = false;
|
|
json_.clear();
|
|
}
|
|
mutable std::mutex mutex_;
|
|
nlohmann::json json_;
|
|
bool load_initialized_;
|
|
bool save_initialized_;
|
|
};
|