// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #pragma once #include #include "cutlass/bfloat16.h" #include "cutlass/cutlass.h" #include "cutlass/half.h" #include "helper.h" #include "paddle/extension.h" /** * Helper function for checking CUTLASS errors */ #define CUTLASS_CHECK(status) \ { \ cutlass::Status error = status; \ PD_CHECK(error == cutlass::Status::kSuccess, \ cutlassGetStatusString(error)); \ } /** * A wrapper for a kernel that is used to guard against compilation on * architectures that will never use the kernel. The purpose of this is to * reduce the size of the compiled binary. * __CUDA_ARCH__ is not defined in host code, so this lets us smuggle the ifdef * into code that will be executed on the device where it is defined. */ template struct enable_sm90_or_later : Kernel { template CUTLASS_DEVICE void operator()(Args &&...args) { #if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 900 Kernel::operator()(std::forward(args)...); #endif } }; template class CutlassDtypeTraits; template <> class CutlassDtypeTraits { public: typedef float DataType; typedef float data_t; }; template <> class CutlassDtypeTraits { public: typedef cutlass::half_t DataType; typedef paddle::float16 data_t; }; template <> class CutlassDtypeTraits { public: typedef cutlass::bfloat16_t DataType; typedef paddle::bfloat16 data_t; }; class CutlassGemmConfigMannager { public: static CutlassGemmConfigMannager &getInstance() { static CutlassGemmConfigMannager instance; return instance; } CutlassGemmConfigMannager(const CutlassGemmConfigMannager &) = delete; CutlassGemmConfigMannager & operator=(const CutlassGemmConfigMannager &) = delete; void up_date_configs(const nlohmann::json &j) { std::lock_guard lock(mutex_); for (auto it = j.begin(); it != j.end(); ++it) { json_[it.key()] = it.value(); } } nlohmann::json *get_gemm_best_configs(const std::string &config_file_path) { if (!load_initialized_) { std::ifstream file(config_file_path); if (!file.good()) { throw std::runtime_error( "cutlass gemm_best_config can not be found, please set " "gemm_best_config'path as " "FLAGS_use_cutlass_device_best_config_path, or unset " "FLAGS_use_cutlass_device_best_config_path to tune " "gemm_best_config"); } json_ = readJsonFromFile(config_file_path); load_initialized_ = true; save_initialized_ = false; } return &json_; } private: void save_gemm_best_configs_(const std::string &config_file_path) { std::ifstream file(config_file_path); if (!file.good()) { std::ofstream new_file(config_file_path); new_file << json_.dump(4); new_file.close(); } else { nlohmann::json old_json = readJsonFromFile(config_file_path); for (auto it = json_.begin(); it != json_.end(); ++it) { old_json[it.key()] = it.value(); } json_ = old_json; std::ofstream new_file(config_file_path, std::ios::out | std::ios::trunc); new_file << json_.dump(4); new_file.close(); file.close(); } return; } CutlassGemmConfigMannager() : json_(nullptr), load_initialized_(false), save_initialized_(true) {} ~CutlassGemmConfigMannager() { std::lock_guard lock(mutex_); if (save_initialized_) { std::string config_file_path = "fp8_fuse_gemm_config.json"; save_gemm_best_configs_(config_file_path); } save_initialized_ = true; load_initialized_ = false; json_.clear(); } mutable std::mutex mutex_; nlohmann::json json_; bool load_initialized_; bool save_initialized_; };