// Mirror of https://github.com/PaddlePaddle/FastDeploy.git
// (synced 2025-10-31 20:02:53 +08:00); original file: 143 lines, 4.6 KiB, C++.
| // Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
 | |
| //
 | |
| // Licensed under the Apache License, Version 2.0 (the "License");
 | |
| // you may not use this file except in compliance with the License.
 | |
| // You may obtain a copy of the License at
 | |
| //
 | |
| //     http://www.apache.org/licenses/LICENSE-2.0
 | |
| //
 | |
| // Unless required by applicable law or agreed to in writing, software
 | |
| // distributed under the License is distributed on an "AS IS" BASIS,
 | |
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | |
| // See the License for the specific language governing permissions and
 | |
| // limitations under the License.
 | |
| 
 | |
#pragma once

#include <cstdio>
#include <exception>
#include <fstream>
#include <mutex>
#include <stdexcept>
#include <string>
#include <utility>

#include "cutlass/bfloat16.h"
#include "cutlass/cutlass.h"
#include "cutlass/half.h"
#include "helper.h"
#include "paddle/extension.h"
 | |
| 
 | |
/**
 * Helper macro for checking CUTLASS errors.
 *
 * Evaluates `status` once, and if it is not cutlass::Status::kSuccess fails a
 * PD_CHECK whose message is cutlassGetStatusString() of that status.
 *
 * The body is wrapped in `do { ... } while (0)` so that `CUTLASS_CHECK(s);`
 * is a single statement. This keeps it safe inside an unbraced if/else.
 * The local uses a deliberately unusual name so it cannot shadow a caller
 * variable passed as the argument. With the old name, `CUTLASS_CHECK(error)`
 * initialized the local from itself.
 */
#define CUTLASS_CHECK(status)                                                  \
  do {                                                                         \
    cutlass::Status cutlass_check_status_ = (status);                          \
    PD_CHECK(cutlass_check_status_ == cutlass::Status::kSuccess,               \
             cutlassGetStatusString(cutlass_check_status_));                   \
  } while (0)
 | |
| 
 | |
| /**
 | |
|  * A wrapper for a kernel that is used to guard against compilation on
 | |
|  * architectures that will never use the kernel. The purpose of this is to
 | |
|  * reduce the size of the compiled binary.
 | |
|  * __CUDA_ARCH__ is not defined in host code, so this lets us smuggle the ifdef
 | |
|  * into code that will be executed on the device where it is defined.
 | |
|  */
 | |
| template <typename Kernel> struct enable_sm90_or_later : Kernel {
 | |
|   template <typename... Args> CUTLASS_DEVICE void operator()(Args &&...args) {
 | |
| #if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 900
 | |
|     Kernel::operator()(std::forward<Args>(args)...);
 | |
| #endif
 | |
|   }
 | |
| };
 | |
| 
 | |
| template <paddle::DataType D> class CutlassDtypeTraits;
 | |
| 
 | |
| template <> class CutlassDtypeTraits<paddle::DataType::FLOAT32> {
 | |
| public:
 | |
|   typedef float DataType;
 | |
|   typedef float data_t;
 | |
| };
 | |
| 
 | |
| template <> class CutlassDtypeTraits<paddle::DataType::FLOAT16> {
 | |
| public:
 | |
|   typedef cutlass::half_t DataType;
 | |
|   typedef paddle::float16 data_t;
 | |
| };
 | |
| 
 | |
| template <> class CutlassDtypeTraits<paddle::DataType::BFLOAT16> {
 | |
| public:
 | |
|   typedef cutlass::bfloat16_t DataType;
 | |
|   typedef paddle::bfloat16 data_t;
 | |
| };
 | |
| 
 | |
| class CutlassGemmConfigMannager {
 | |
| public:
 | |
|   static CutlassGemmConfigMannager &getInstance() {
 | |
|     static CutlassGemmConfigMannager instance;
 | |
|     return instance;
 | |
|   }
 | |
| 
 | |
|   CutlassGemmConfigMannager(const CutlassGemmConfigMannager &) = delete;
 | |
|   CutlassGemmConfigMannager &
 | |
|   operator=(const CutlassGemmConfigMannager &) = delete;
 | |
| 
 | |
|   void up_date_configs(const nlohmann::json &j) {
 | |
|     std::lock_guard<std::mutex> lock(mutex_);
 | |
|     for (auto it = j.begin(); it != j.end(); ++it) {
 | |
|       json_[it.key()] = it.value();
 | |
|     }
 | |
|   }
 | |
| 
 | |
|   nlohmann::json *get_gemm_best_configs(const std::string &config_file_path) {
 | |
|     if (!load_initialized_) {
 | |
|       std::ifstream file(config_file_path);
 | |
|       if (!file.good()) {
 | |
|         throw std::runtime_error(
 | |
|             "cutlass gemm_best_config can not be found, please set "
 | |
|             "gemm_best_config'path as "
 | |
|             "FLAGS_use_cutlass_device_best_config_path, or unset "
 | |
|             "FLAGS_use_cutlass_device_best_config_path to tune "
 | |
|             "gemm_best_config");
 | |
|       }
 | |
|       json_ = readJsonFromFile(config_file_path);
 | |
|       load_initialized_ = true;
 | |
|       save_initialized_ = false;
 | |
|     }
 | |
|     return &json_;
 | |
|   }
 | |
| 
 | |
| private:
 | |
|   void save_gemm_best_configs_(const std::string &config_file_path) {
 | |
|     std::ifstream file(config_file_path);
 | |
|     if (!file.good()) {
 | |
|       std::ofstream new_file(config_file_path);
 | |
|       new_file << json_.dump(4);
 | |
|       new_file.close();
 | |
|     } else {
 | |
|       nlohmann::json old_json = readJsonFromFile(config_file_path);
 | |
|       for (auto it = json_.begin(); it != json_.end(); ++it) {
 | |
|         old_json[it.key()] = it.value();
 | |
|       }
 | |
|       json_ = old_json;
 | |
|       std::ofstream new_file(config_file_path, std::ios::out | std::ios::trunc);
 | |
|       new_file << json_.dump(4);
 | |
|       new_file.close();
 | |
|       file.close();
 | |
|     }
 | |
|     return;
 | |
|   }
 | |
| 
 | |
|   CutlassGemmConfigMannager()
 | |
|       : json_(nullptr), load_initialized_(false), save_initialized_(true) {}
 | |
|   ~CutlassGemmConfigMannager() {
 | |
|     std::lock_guard<std::mutex> lock(mutex_);
 | |
|     if (save_initialized_) {
 | |
|       std::string config_file_path = "fp8_fuse_gemm_config.json";
 | |
|       save_gemm_best_configs_(config_file_path);
 | |
|     }
 | |
|     save_initialized_ = true;
 | |
|     load_initialized_ = false;
 | |
|     json_.clear();
 | |
|   }
 | |
|   mutable std::mutex mutex_;
 | |
|   nlohmann::json json_;
 | |
|   bool load_initialized_;
 | |
|   bool save_initialized_;
 | |
| };
 | 
