// Source file: FastDeploy/custom_ops/gpu_ops/tune_cublaslt_gemm.cu
// (extracted 2025-07-19 23:19:27 +08:00; 814 lines, 30 KiB)
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <cublasLt.h>
#include <cublas_v2.h>
#include <cuda_runtime_api.h>

#include <sys/time.h>

#include <algorithm>
#include <chrono>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <ctime>
#include <fstream>
#include <iomanip>
#include <iostream>
#include <limits>
#include <list>
#include <sstream>
#include <stdexcept>
#include <string>
#include <vector>

#include "helper.h"
// Fallback error handler for status types that have no dedicated
// specialization below.  Reaching this overload means CUDA_CHECK was used
// on an unsupported status type, so it aborts unconditionally — even for
// a "success" value.
template <typename T>
void handleError(T status, const char* file, int line) {
  printf("Unknown error type at %s:%d\n", file, line);
  exit(1);
}
// for cudaError_t
// Specialization for CUDA runtime statuses: print a readable message and
// abort whenever a call did not return cudaSuccess.
template <>
void handleError<cudaError_t>(cudaError_t status, const char* file, int line) {
  if (status == cudaSuccess) {
    return;
  }
  printf("CUDA error at %s:%d - %s\n", file, line, cudaGetErrorString(status));
  exit(1);
}
// for cublasStatus_t
// Specialization for cuBLAS(Lt) statuses: print the numeric status code
// and abort on any non-success value.
template <>
void handleError<cublasStatus_t>(cublasStatus_t status,
                                 const char* file,
                                 int line) {
  if (status == CUBLAS_STATUS_SUCCESS) {
    return;
  }
  printf("cuBLAS error at %s:%d - %d\n", file, line, static_cast<int>(status));
  exit(1);
}
// Checks the status returned by a CUDA runtime or cuBLAS(Lt) call and
// aborts the process on failure.  Dispatch to the right printer happens
// through the handleError<> specializations above; despite the name, this
// macro is also used throughout this file for cublasStatus_t-returning
// calls.
#define CUDA_CHECK(call) \
  do { \
    handleError((call), __FILE__, __LINE__); \
  } while (0)
// Timing/result record for benchmarking one cuBLASLt matmul algorithm
// configuration (filled in by TestMatmulRun).
typedef struct {
  cublasLtMatmulAlgo_t algo;  // fully configured algo descriptor
  cublasStatus_t status;      // status of the timed matmul runs
  float time;                 // average time per matmul call, in ms
  size_t workspaceSize;       // workspace bytes required by the algo
  cublasMath_t mathMode;      // NOTE(review): never populated in this file
  cublasLtReductionScheme_t reductionScheme;  // NOTE(review): never populated
  int customOption;           // NOTE(review): never populated in this file
  float wavesCount;           // wave count reported by AlgoCheck's heuristic
} customMatmulPerf_t;
// One candidate algorithm configuration for an (m, n, k) GEMM together
// with its measured runtime; instances are sorted by time and the best
// one is appended to the search CSV file.
typedef struct {
  cublasLtMatmulAlgo_t algo;  // configured algo handle
  int m;
  int n;
  int k;
  int algoId;            // algo id from cublasLtMatmulAlgoGetIds
  int swizzle;           // CTA swizzling setting
  int customOption;      // implementation-specific custom option value
  int tile;              // tile id (CUBLASLT_ALGO_CONFIG_TILE_ID)
  int splitK_val;        // split-K factor (1 = no split-K)
  int reductionScheme;   // reduction scheme used with split-K
  int stages;            // stages id (CUBLASLT_ALGO_CONFIG_STAGES_ID)
  size_t workspaceSize;  // workspace bytes required
  float time;            // average runtime in ms; FLT_MAX if the run failed
} algoSelect_t;
// Returns the elapsed wall-clock time between two time points, expressed
// in (fractional) milliseconds.
inline double diffTime(
    const std::chrono::high_resolution_clock::time_point& start,
    const std::chrono::high_resolution_clock::time_point& end) {
  using fp_ms = std::chrono::duration<double, std::milli>;
  return fp_ms(end - start).count();
}
const int splitKSequenceA[] = {1, 2, 3, 4, 5, 6, 8, 12, 16, 32};
// Comparator intended to order perf results by ascending time, with
// failed runs treated as "greater".
// NOTE(review): this is not a strict weak ordering (a failed `perf_a`
// compares false against everything regardless of `perf_b`), so it must
// not be handed to std::sort — it appears unused in this file; confirm
// before removing.
static inline bool time_compare_perf(const customMatmulPerf_t& perf_a,
                                     const customMatmulPerf_t& perf_b) {
  return ((perf_a.status == CUBLAS_STATUS_SUCCESS) &&
          (perf_a.time < perf_b.time));
}
// Orders candidate algo configurations by ascending measured time.
static inline bool time_compare_algo_para(const algoSelect_t& lhs,
                                          const algoSelect_t& rhs) {
  return lhs.time < rhs.time;
}
// Returns the amount of currently free device memory on the active GPU,
// in bytes.
size_t get_remaining_memory() {
  size_t free_bytes = 0;
  size_t total_bytes = 0;
  CUDA_CHECK(cudaMemGetInfo(&free_bytes, &total_bytes));
  return free_bytes;
}
// Benchmarks a single fully-configured cuBLASLt algorithm.
// Runs cublasLtMatmul `repeats` times on the default stream and stores the
// average per-call time (ms), workspace size and wave count in
// `perfResults`.  On failure (algo rejected by AlgoCheck, not enough free
// device memory for the workspace, or a failing matmul call)
// `perfResults.status` is set to a non-success value and the timing
// fields are left untouched.
template <typename InT, typename OutT, typename ScaleT = OutT>
static void TestMatmulRun(cublasLtHandle_t ltHandle,
                          cublasLtMatmulDesc_t matmulDesc,
                          cublasLtMatrixLayout_t A_desc,
                          cublasLtMatrixLayout_t B_desc,
                          cublasLtMatrixLayout_t C_desc,
                          const InT* A,
                          const InT* B,
                          OutT* C,
                          const cublasLtMatmulAlgo_t& algo,
                          customMatmulPerf_t& perfResults,
                          cudaEvent_t& startEvent,
                          cudaEvent_t& stopEvent) {
  cudaStream_t stream = 0;
  cublasLtMatmulHeuristicResult_t heurResult;
  cublasStatus_t algoStatus = cublasLtMatmulAlgoCheck(
      ltHandle, matmulDesc, A_desc, B_desc, C_desc, C_desc, &algo, &heurResult);
  // Keep 5% headroom so the workspace allocation cannot exhaust the GPU.
  auto remainingMemorySize = 0.95 * get_remaining_memory();
  if (algoStatus == CUBLAS_STATUS_SUCCESS &&
      remainingMemorySize > heurResult.workspaceSize) {
    ScaleT alpha = static_cast<ScaleT>(1), beta = static_cast<ScaleT>(0);
    void* workSpace = nullptr;
    CUDA_CHECK(cudaMalloc(&workSpace, heurResult.workspaceSize));
    int repeats = 100;
    bool allRunsOk = true;
    CUDA_CHECK(cudaEventRecord(startEvent, stream));
    for (int loop = 0; loop < repeats; loop++) {
      cublasStatus_t currStatus = cublasLtMatmul(ltHandle,
                                                 matmulDesc,
                                                 &alpha,
                                                 A,
                                                 A_desc,
                                                 B,
                                                 B_desc,
                                                 &beta,
                                                 C,
                                                 C_desc,
                                                 C,
                                                 C_desc,
                                                 &algo,
                                                 workSpace,
                                                 heurResult.workspaceSize,
                                                 stream);
      if (currStatus != CUBLAS_STATUS_SUCCESS) {
        perfResults.status = currStatus;
        allRunsOk = false;
        break;
      }
      // NOTE: syncing every iteration folds sync overhead into the
      // measurement, but does so uniformly for all candidates.
      CUDA_CHECK(cudaDeviceSynchronize());
    }
    CUDA_CHECK(cudaDeviceSynchronize());
    CUDA_CHECK(cudaEventRecord(stopEvent, stream));
    CUDA_CHECK(cudaEventSynchronize(stopEvent));
    float time;
    CUDA_CHECK(cudaEventElapsedTime(&time, startEvent, stopEvent));
    // BUGFIX: the original re-tested `algoStatus`, which is always SUCCESS
    // in this branch, so timings were recorded even when the matmul loop
    // failed part-way.  Record results only when every repeat succeeded.
    if (allRunsOk) {
      perfResults.algo = algo;
      perfResults.time = time / repeats;
      perfResults.workspaceSize = heurResult.workspaceSize;
      perfResults.wavesCount = heurResult.wavesCount;
    }
    CUDA_CHECK(cudaFree(workSpace));
  } else {
    // NOTE(review): when AlgoCheck itself failed, heurResult.workspaceSize
    // may be unset — the message below is best-effort (same as before).
    std::cerr << "Not enough workspace! Required "
              << static_cast<double>(heurResult.workspaceSize) / 1024.0 /
                     1024.0 / 1024.0
              << " GiB" << ", But remaining "
              << static_cast<double>(remainingMemorySize) / 1024.0 / 1024.0 /
                     1024.0
              << " GiB" << std::endl;
    perfResults.status = CUBLAS_STATUS_NOT_SUPPORTED;  // Not enough workspace
  }
}
// Enumerates every applicable cuBLASLt algorithm configuration for one
// (m, n, k) problem, benchmarks each candidate via TestMatmulRun, sorts
// the candidates by measured time and appends the fastest usable config as
// one CSV row ("m,k,n,algoId,swizzle,customOption,tile,splitK,
// reductionScheme,stages,workspace,time") to `path`.
// A/B/C and their layouts are forwarded as-is; callers may pass them
// swapped to express a row-major problem in cuBLAS' column-major terms.
template <typename InT, typename OutT, typename ScaleT = OutT>
void FindAlgo(const cublasLtHandle_t& ltHandle,
              int m,
              int n,
              int k,
              const InT* A,
              const InT* B,
              OutT* C,
              cublasLtMatmulDesc_t matmulDesc,
              cublasLtMatrixLayout_t A_desc,
              cublasLtMatrixLayout_t B_desc,
              cublasLtMatrixLayout_t C_desc,
              cublasComputeType_t computeType,
              cudaDataType_t scaleType,
              cudaDataType_t Atype,
              cudaDataType_t Btype,
              cudaDataType_t Ctype,
              std::vector<algoSelect_t>& algos,
              const std::string& path) {
  // https://docs.nvidia.com/cuda/cublas/index.html#cublasLtMatmulAlgoGetIds
  int AlgoCount = 0;
  const int AlgoCombinations = 20000;  // hard cap on enumerated candidates
  int algoIdA[100];
  int nbAlgoIds = 0;
  CUDA_CHECK(cublasLtMatmulAlgoGetIds(ltHandle,
                                      computeType,
                                      scaleType,
                                      Atype,
                                      Btype,
                                      Ctype,
                                      Ctype,
                                      100,
                                      algoIdA,
                                      &nbAlgoIds));
  std::clog << std::endl << "get " << nbAlgoIds << " algoIds" << std::endl;
  for (int idx = 0; idx < nbAlgoIds; idx++) {
    cublasLtMatmulAlgo_t algo;
    std::clog << "Process algo: " << algoIdA[idx] << " ";
    // https://docs.nvidia.com/cuda/cublas/index.html#cublasLtMatmulAlgoInit
    CUDA_CHECK(cublasLtMatmulAlgoInit(ltHandle,
                                      computeType,
                                      scaleType,
                                      Atype,
                                      Btype,
                                      Ctype,
                                      Ctype,
                                      algoIdA[idx],
                                      &algo));
    // Query the tile ids supported by this algo.
    // https://docs.nvidia.com/cuda/cublas/index.html#cublasLtMatmulAlgoCapGetAttribute
    cublasLtMatmulTile_t tileA[CUBLASLT_MATMUL_TILE_END];
    size_t nbTiles, sizeWritten;
    CUDA_CHECK(cublasLtMatmulAlgoCapGetAttribute(
        &algo, CUBLASLT_ALGO_CAP_TILE_IDS, tileA, sizeof(tileA), &sizeWritten));
    nbTiles = sizeWritten / sizeof(tileA[0]);
    // Query the stage ids supported by this algo (requires CUDA >= 11.0).
    CUDA_CHECK(cublasLtMatmulAlgoCapGetAttribute(
        &algo, CUBLASLT_ALGO_CAP_STAGES_IDS, NULL, 0, &sizeWritten));
    int nbStages = int(sizeWritten / sizeof(uint32_t));
    std::vector<uint32_t> stagesA(nbStages == 0 ? 1 : nbStages);
    if (nbStages == 0) {
      stagesA[0] = CUBLASLT_MATMUL_STAGES_UNDEFINED;
      nbStages = 1;
    } else {
      CUDA_CHECK(cublasLtMatmulAlgoCapGetAttribute(&algo,
                                                   CUBLASLT_ALGO_CAP_STAGES_IDS,
                                                   stagesA.data(),
                                                   sizeof(uint32_t) * nbStages,
                                                   &sizeWritten));
    }
    // Remaining capability attributes: split-K support, reduction-scheme
    // mask, CTA-swizzling support and the custom-option range.
    int32_t splitkSupport, customOptionMax;
    uint32_t redMask, swizzlingMax;
    CUDA_CHECK(
        cublasLtMatmulAlgoCapGetAttribute(&algo,
                                          CUBLASLT_ALGO_CAP_SPLITK_SUPPORT,
                                          &splitkSupport,
                                          sizeof(splitkSupport),
                                          &sizeWritten));
    std::clog << "splitkSupport: " << splitkSupport << std::endl;
    CUDA_CHECK(cublasLtMatmulAlgoCapGetAttribute(
        &algo,
        CUBLASLT_ALGO_CAP_REDUCTION_SCHEME_MASK,
        &redMask,
        sizeof(redMask),
        &sizeWritten));
    CUDA_CHECK(cublasLtMatmulAlgoCapGetAttribute(
        &algo,
        CUBLASLT_ALGO_CAP_CTA_SWIZZLING_SUPPORT,
        &swizzlingMax,
        sizeof(swizzlingMax),
        &sizeWritten));
    CUDA_CHECK(
        cublasLtMatmulAlgoCapGetAttribute(&algo,
                                          CUBLASLT_ALGO_CAP_CUSTOM_OPTION_MAX,
                                          &customOptionMax,
                                          sizeof(customOptionMax),
                                          &sizeWritten));
    // Cartesian walk over tiles x stages x custom options x swizzling x
    // split-K settings, bounded by AlgoCombinations.
    for (int tileIdx = 0; tileIdx < nbTiles && AlgoCount < AlgoCombinations;
         tileIdx++) {
      for (int stagesIdx = 0;
           stagesIdx < nbStages && AlgoCount < AlgoCombinations;
           stagesIdx++) {
        for (int32_t customOption = 0;
             customOption <= customOptionMax && AlgoCount < AlgoCombinations;
             customOption++) {
          // BUGFIX: this loop variable was named `k`, shadowing the GEMM
          // dimension `k` and corrupting algoSelect.k below.
          for (uint32_t sw = 0;
               sw <= swizzlingMax && AlgoCount < AlgoCombinations;
               sw++) {
            int splitK_trial = 0;
            if (splitkSupport) {
              splitK_trial +=
                  sizeof(splitKSequenceA) / sizeof(splitKSequenceA[0]);
            }
            // l == 0: no split-K; l > 0: splitKSequenceA[l - 1] as factor.
            for (int l = 0;
                 (l < (1 + splitK_trial)) && (AlgoCount < AlgoCombinations);
                 l++) {
              CUDA_CHECK(cublasLtMatmulAlgoConfigSetAttribute(
                  &algo,
                  CUBLASLT_ALGO_CONFIG_TILE_ID,
                  &tileA[tileIdx],
                  sizeof(tileA[tileIdx])));
              CUDA_CHECK(cublasLtMatmulAlgoConfigSetAttribute(
                  &algo,
                  CUBLASLT_ALGO_CONFIG_STAGES_ID,
                  &stagesA[stagesIdx],
                  sizeof(stagesA[stagesIdx])));
              CUDA_CHECK(cublasLtMatmulAlgoConfigSetAttribute(
                  &algo,
                  CUBLASLT_ALGO_CONFIG_CUSTOM_OPTION,
                  &customOption,
                  sizeof(customOption)));
              CUDA_CHECK(cublasLtMatmulAlgoConfigSetAttribute(
                  &algo, CUBLASLT_ALGO_CONFIG_CTA_SWIZZLING, &sw, sizeof(sw)));
              int splitK_val = 1;
              uint32_t redScheme = CUBLASLT_REDUCTION_SCHEME_NONE;
              CUDA_CHECK(cublasLtMatmulAlgoConfigSetAttribute(
                  &algo,
                  CUBLASLT_ALGO_CONFIG_SPLITK_NUM,
                  &splitK_val,
                  sizeof(splitK_val)));
              CUDA_CHECK(cublasLtMatmulAlgoConfigSetAttribute(
                  &algo,
                  CUBLASLT_ALGO_CONFIG_REDUCTION_SCHEME,
                  &redScheme,
                  sizeof(int)));
              // Records the current configuration as a candidate if
              // AlgoCheck accepts it (shared by both branches below).
              auto record_candidate = [&](int splitK, uint32_t reduction) {
                // https://docs.nvidia.com/cuda/cublas/index.html#cublasLtMatmulAlgoCheck
                cublasLtMatmulHeuristicResult_t heurResult;
                cublasStatus_t algoStatus =
                    cublasLtMatmulAlgoCheck(ltHandle,
                                            matmulDesc,
                                            A_desc,
                                            B_desc,
                                            C_desc,
                                            C_desc,
                                            &algo,
                                            &heurResult);
                if (algoStatus == CUBLAS_STATUS_SUCCESS) {
                  algoSelect_t algoSelect;
                  algoSelect.algo = algo;
                  algoSelect.m = m;
                  algoSelect.n = n;
                  algoSelect.k = k;  // the GEMM dimension, not the swizzle
                  algoSelect.algoId = algoIdA[idx];
                  algoSelect.tile = tileA[tileIdx];
                  algoSelect.swizzle = static_cast<int>(sw);
                  algoSelect.customOption = customOption;
                  algoSelect.splitK_val = splitK;
                  algoSelect.reductionScheme = static_cast<int>(reduction);
                  algoSelect.stages = stagesA[stagesIdx];
                  algos.push_back(algoSelect);
                  AlgoCount++;
                }
              };
              if (l > 0) {  // Split-K case: try every reduction scheme
                splitK_val = splitKSequenceA[l - 1];
                CUDA_CHECK(cublasLtMatmulAlgoConfigSetAttribute(
                    &algo,
                    CUBLASLT_ALGO_CONFIG_SPLITK_NUM,
                    &splitKSequenceA[l - 1],
                    sizeof(splitKSequenceA[l - 1])));
                for (redScheme = 1;
                     redScheme < (int)CUBLASLT_REDUCTION_SCHEME_MASK &&
                     (AlgoCount < AlgoCombinations);
                     redScheme <<= 1) {
                  CUDA_CHECK(cublasLtMatmulAlgoConfigSetAttribute(
                      &algo,
                      CUBLASLT_ALGO_CONFIG_REDUCTION_SCHEME,
                      &redScheme,
                      sizeof(redScheme)));
                  record_candidate(splitK_val, redScheme);
                }
              } else {
                record_candidate(splitK_val, redScheme);
              }
            }
          }
        }
      }
    }
  }
  std::clog << "Got " << AlgoCount << " algos" << std::endl;
  // Benchmark every surviving candidate.
  cudaEvent_t startEvent;
  cudaEvent_t stopEvent;
  std::vector<customMatmulPerf_t> perfResults(AlgoCount);
  CUDA_CHECK(cudaEventCreate(&startEvent));
  CUDA_CHECK(cudaEventCreate(&stopEvent));
  for (int i = 0; i < AlgoCount; i++) {
    TestMatmulRun<InT, OutT, ScaleT>(ltHandle,
                                     matmulDesc,
                                     A_desc,
                                     B_desc,
                                     C_desc,
                                     A,
                                     B,
                                     C,
                                     algos[i].algo,
                                     perfResults[i],
                                     startEvent,
                                     stopEvent);
    algos[i].workspaceSize = perfResults[i].workspaceSize;
    algos[i].time = perfResults[i].time;
    if (perfResults[i].status != CUBLAS_STATUS_SUCCESS) {
      std::clog << "algo " << algos[i].algoId << " tile " << algos[i].tile
                << " stages " << algos[i].stages << " splitK_val "
                << algos[i].splitK_val << std::endl;
      // Push failed candidates to the end of the sorted order.
      algos[i].time = std::numeric_limits<float>::max();
      std::cerr << " TestMatmulRun with status " << perfResults[i].status
                << std::endl;
      continue;
    }
  }
  // BUGFIX: the two events were previously leaked.
  CUDA_CHECK(cudaEventDestroy(startEvent));
  CUDA_CHECK(cudaEventDestroy(stopEvent));
  std::sort(algos.begin(), algos.end(), time_compare_algo_para);
  // Skip zero-time entries (never successfully measured).  BUGFIX: the
  // original scan could run past the end of the vector.
  size_t best = 0;
  while (best < algos.size() && algos[best].time == 0) best++;
  if (best == algos.size()) {
    std::cerr << "No usable algo found for m=" << m << " k=" << k
              << " n=" << n << "; nothing written to " << path << std::endl;
    return;
  }
  std::ofstream outfile;
  outfile.open(path, std::ios::app);
  outfile << m << "," << k << "," << n << "," << algos[best].algoId << ","
          << algos[best].swizzle << "," << algos[best].customOption << ","
          << algos[best].tile << "," << algos[best].splitK_val << ","
          << algos[best].reductionScheme << "," << algos[best].stages << ","
          << algos[best].workspaceSize << "," << algos[best].time << "\n";
  outfile.close();
}
// Empty tag/base types used to dispatch GEMMInt8 to a device-specific
// implementation.
class DevContext {};
// CPU tag type; its GEMMInt8 specialization is an unimplemented stub.
class CPUContext : public DevContext {};
// Owns a cuBLASLt handle for the lifetime of the object.
// BUGFIX: the original never called cublasLtDestroy, leaking one handle
// per instance (and an instance used to be created per tuned GEMM shape).
class CUBLASLTContext : public DevContext {
 public:
  CUBLASLTContext() { CUDA_CHECK(cublasLtCreate(&handle)); }
  // Non-copyable: a copy would trigger a double cublasLtDestroy.
  CUBLASLTContext(const CUBLASLTContext&) = delete;
  CUBLASLTContext& operator=(const CUBLASLTContext&) = delete;
  ~CUBLASLTContext() { cublasLtDestroy(handle); }
  cublasLtHandle_t handle;
};
// Fallback overload of GEMMInt8: reached only when no specialization
// exists for the given (InT, OutT, DevContext) combination; reports the
// missing implementation on stderr and does nothing else.
template <typename InT, typename OutT, typename DevContext>
void GEMMInt8(const DevContext& ctx,
              const std::vector<InT>& A,
              const std::vector<InT>& B,
              std::vector<OutT>& C,
              int m,
              int k,
              int n,
              bool is_test,
              bool is_read_from_file = false,
              const std::string& path = "search.csv") {
  std::cerr << "Base Class is not implemented" << std::endl;
}
// CPU specialization of GEMMInt8 — intentionally unimplemented; only the
// cuBLASLt path is provided in this file.
template <>
void GEMMInt8<int8_t, int32_t, CPUContext>(const CPUContext& ctx,
                                           const std::vector<int8_t>& A,
                                           const std::vector<int8_t>& B,
                                           std::vector<int32_t>& C,
                                           int m,
                                           int k,
                                           int n,
                                           bool is_test,
                                           bool is_read_from_file,
                                           const std::string& path) {
  std::cerr << "CPUContext Class is not implemented" << std::endl;
}
// int8 GEMM benchmark/tuning via cuBLASLt IMMA kernels.
// The row-major problem is expressed through cuBLAS' column-major API by
// swapping A and B and transposing the swapped operand
// (CUBLASLT_MATMUL_DESC_TRANSA).  Depending on the flags, the algorithm
// configuration is searched (is_test), loaded from the CSV at `path`
// (is_read_from_file) or taken from a built-in default; the selected
// config is then timed over 10 runs and the average is printed.
// NOTE(review): the device result is never copied back into CVec — this
// routine is used for tuning/benchmarking only; confirm before reusing it
// as a compute kernel.
template <>
void GEMMInt8<int8_t, int32_t, CUBLASLTContext>(const CUBLASLTContext& dev_ctx,
                                                const std::vector<int8_t>& AVec,
                                                const std::vector<int8_t>& BVec,
                                                std::vector<int32_t>& CVec,
                                                int m,
                                                int k,
                                                int n,
                                                bool is_test,
                                                bool is_read_from_file,
                                                const std::string& path) {
  int8_t* A_dev;
  int8_t* B_dev;
  int32_t* C_dev;
  char* workSpace;
  CUDA_CHECK(cudaMalloc((void**)&A_dev, AVec.size() * sizeof(int8_t)));
  CUDA_CHECK(cudaMalloc((void**)&B_dev, BVec.size() * sizeof(int8_t)));
  CUDA_CHECK(cudaMalloc((void**)&C_dev, m * n * sizeof(int32_t)));
  CUDA_CHECK(cudaMemcpy(A_dev,
                        AVec.data(),
                        AVec.size() * sizeof(int8_t),
                        cudaMemcpyHostToDevice));
  CUDA_CHECK(cudaMemcpy(B_dev,
                        BVec.data(),
                        BVec.size() * sizeof(int8_t),
                        cudaMemcpyHostToDevice));
  // Descriptors; B/A are swapped relative to the row-major view.
  cublasLtMatmulDesc_t matmul_desc;
  cublasLtMatrixLayout_t A_desc;
  cublasLtMatrixLayout_t B_desc;
  cublasLtMatrixLayout_t C_desc;
  int32_t alpha = 1;
  int32_t beta = 0;
  cublasComputeType_t cudaComputeType = CUBLAS_COMPUTE_32I;
  CUDA_CHECK(
      cublasLtMatmulDescCreate(&matmul_desc, cudaComputeType, CUDA_R_32I));
  cublasOperation_t op_transpose = CUBLAS_OP_T;
  CUDA_CHECK(cublasLtMatmulDescSetAttribute(matmul_desc,
                                            CUBLASLT_MATMUL_DESC_TRANSA,
                                            &op_transpose,
                                            sizeof(op_transpose)));
  CUDA_CHECK(cublasLtMatrixLayoutCreate(&B_desc, CUDA_R_8I, k, n, k));
  CUDA_CHECK(cublasLtMatrixLayoutCreate(&A_desc, CUDA_R_8I, k, m, k));
  CUDA_CHECK(cublasLtMatrixLayoutCreate(&C_desc, CUDA_R_32I, n, m, n));
  // Scalar algo configuration, filled by one of the three paths below.
  cublasLtMatmulAlgo_t algo;
  int algoId;
  int swizzle;
  int customOption;
  int tile;
  int splitK_val;
  int reductionScheme;
  int stages;
  size_t work_space_size = 0;
  float time_ref;
  auto using_default_config = [&]() {
    algoId = 21;
    swizzle = 0;
    customOption = 0;
    tile = 15;
    splitK_val = 0;
    reductionScheme = 0;
    stages = 23;
    if (m >= 128) {
      tile = 20;
      stages = 17;
    }
  };
  if (is_test) {
    // Search mode: enumerate, benchmark and persist the best config.
    std::vector<algoSelect_t> algos;
    FindAlgo(dev_ctx.handle,
             m,
             n,
             k,
             B_dev,
             A_dev,
             C_dev,
             matmul_desc,
             B_desc,
             A_desc,
             C_desc,
             CUBLAS_COMPUTE_32I,
             CUDA_R_32I,
             CUDA_R_8I,
             CUDA_R_8I,
             CUDA_R_32I,
             algos,
             path);
    // BUGFIX: the original `while (algos[i].time == 0) i++;` could run past
    // the end of the vector when no candidate was successfully timed.
    size_t best = 0;
    while (best < algos.size() && algos[best].time == 0) best++;
    if (best == algos.size()) {
      std::cout << "no valid algo was found. Now we use default params"
                << std::endl;
      using_default_config();
    } else {
      algoId = algos[best].algoId;
      swizzle = algos[best].swizzle;
      customOption = algos[best].customOption;
      tile = algos[best].tile;
      splitK_val = algos[best].splitK_val;
      reductionScheme = algos[best].reductionScheme;
      stages = algos[best].stages;
      work_space_size = algos[best].workspaceSize;
    }
  } else if (is_read_from_file) {
    // Replay mode: use the first CSV row with matching k/n and m_tmp >= m.
    int m_tmp, k_tmp, n_tmp;
    std::ifstream file(path);
    if (!file.is_open()) {
      std::cout << "file not open. Now we use default params" << std::endl;
      using_default_config();
    } else {
      bool match = false;
      int find_cnt = 0;
      std::string line;
      while (std::getline(file, line)) {
        std::istringstream iss(line);
        char comma;
        if (iss >> m_tmp >> comma >> k_tmp >> comma >> n_tmp >> comma >>
            algoId >> comma >> swizzle >> comma >> customOption >> comma >>
            tile >> comma >> splitK_val >> comma >> reductionScheme >> comma >>
            stages >> comma >> work_space_size >> comma >> time_ref) {
          if (k_tmp == k && n_tmp == n && m <= m_tmp) {
            match = true;
            break;
          }
          find_cnt++;
        }
      }
      // BUGFIX: previously only `find_cnt == 0` triggered the fallback, so
      // a non-empty file without a matching row silently used whatever the
      // last parsed row left in the config variables.
      if (!match) {
        if (find_cnt == 0) {
          std::cout << "the file is empty. Now we use default params"
                    << std::endl;
        } else {
          std::cout << "no matched config in file. Now we use default params"
                    << std::endl;
        }
        using_default_config();
      }
    }
  } else {
    std::cout << "Please use test mode to select\n, Now we use default params"
              << std::endl;
    using_default_config();
  }
  // Rebuild the algo descriptor from the chosen scalar configuration.
  CUDA_CHECK(cudaMalloc((void**)&workSpace, work_space_size));
  CUDA_CHECK(cublasLtMatmulAlgoInit(dev_ctx.handle,
                                    cudaComputeType,
                                    CUDA_R_32I,
                                    CUDA_R_8I,
                                    CUDA_R_8I,
                                    CUDA_R_32I,
                                    CUDA_R_32I,
                                    algoId,
                                    &algo));
  CUDA_CHECK(
      cublasLtMatmulAlgoConfigSetAttribute(&algo,
                                           CUBLASLT_ALGO_CONFIG_CUSTOM_OPTION,
                                           &(customOption),
                                           sizeof(customOption)));
  CUDA_CHECK(cublasLtMatmulAlgoConfigSetAttribute(
      &algo, CUBLASLT_ALGO_CONFIG_TILE_ID, &(tile), sizeof(tile)));
  CUDA_CHECK(
      cublasLtMatmulAlgoConfigSetAttribute(&algo,
                                           CUBLASLT_ALGO_CONFIG_SPLITK_NUM,
                                           &(splitK_val),
                                           sizeof(splitK_val)));
  CUDA_CHECK(cublasLtMatmulAlgoConfigSetAttribute(
      &algo, CUBLASLT_ALGO_CONFIG_CTA_SWIZZLING, &(swizzle), sizeof(swizzle)));
  CUDA_CHECK(cublasLtMatmulAlgoConfigSetAttribute(
      &algo,
      CUBLASLT_ALGO_CONFIG_REDUCTION_SCHEME,
      &(reductionScheme),
      sizeof(int)));
  CUDA_CHECK(cublasLtMatmulAlgoConfigSetAttribute(
      &algo, CUBLASLT_ALGO_CONFIG_STAGES_ID, &(stages), sizeof(stages)));
  // Time the selected configuration (10 synchronized runs).
  auto start = std::chrono::high_resolution_clock::now();
  const int repeats = 10;
  for (int loop = 0; loop < repeats; loop++) {
    CUDA_CHECK(cublasLtMatmul(dev_ctx.handle,
                              matmul_desc,
                              &alpha,
                              B_dev,
                              B_desc,
                              A_dev,
                              A_desc,
                              &beta,
                              C_dev,
                              C_desc,
                              C_dev,
                              C_desc,
                              &algo,
                              workSpace,
                              work_space_size,
                              0));
    CUDA_CHECK(cudaDeviceSynchronize());
  }
  CUDA_CHECK(cudaDeviceSynchronize());
  auto end = std::chrono::high_resolution_clock::now();
  double time = diffTime(start, end);
  auto now = std::chrono::system_clock::now();
  std::time_t now_time_t = std::chrono::system_clock::to_time_t(now);
  std::tm now_tm = *std::localtime(&now_time_t);
  std::cout << "GEMM with cublaslt imma1 int8 spend " << time / repeats
            << " ms in " << m << ", " << k << ", " << n
            << ", current time: " << std::put_time(&now_tm, "%H:%M:%S")
            << std::endl;
  // BUGFIX: destroy the cuBLASLt descriptors (previously leaked once per
  // call) before releasing the device buffers.
  CUDA_CHECK(cublasLtMatrixLayoutDestroy(A_desc));
  CUDA_CHECK(cublasLtMatrixLayoutDestroy(B_desc));
  CUDA_CHECK(cublasLtMatrixLayoutDestroy(C_desc));
  CUDA_CHECK(cublasLtMatmulDescDestroy(matmul_desc));
  CUDA_CHECK(cudaFree(A_dev));
  CUDA_CHECK(cudaFree(B_dev));
  CUDA_CHECK(cudaFree(C_dev));
  CUDA_CHECK(cudaFree(workSpace));
}
// Paddle custom-op kernel: sweeps M over [M_start, M_end] with a step that
// grows with M, and for every (m, k, n) combination — k/n taken pairwise
// from the 1-D int64 tensors K and N — runs GEMMInt8 to either tune
// (is_test) or replay (is_read_from_file) cuBLASLt configs via the CSV at
// `path`.  Only dtype == "int8" is supported.
void TuneCublasltGemm(const paddle::Tensor& K,
                      const paddle::Tensor& N,
                      const int M_start,
                      const int M_end,
                      const std::string& dtype,
                      const bool is_test,
                      const bool is_read_from_file,
                      const std::string& path) {
  // NOTE: assert() vanishes under NDEBUG; these are precondition checks.
  assert(M_end >= M_start);
  assert(M_start >= 1);
  assert(K.dims().size() == 1 && N.dims().size() == 1);
  // Exactly one of test / read-from-file mode must be selected.
  assert(is_test != is_read_from_file);
  // Validate dtype up front instead of failing inside the sweep loop.
  if (dtype != "int8") {
    throw std::runtime_error(dtype + "not currently supported");
  }
  auto K_cpu = K.copy_to(paddle::CPUPlace(), false);
  auto N_cpu = N.copy_to(paddle::CPUPlace(), false);
  int64_t* K_data = K_cpu.data<int64_t>();
  int64_t* N_data = N_cpu.data<int64_t>();
  int K_size = K.numel();
  int N_size = N.numel();
  assert(K_size == N_size);
  // Build the list of M values; the step widens as M grows.
  std::vector<int> mm;
  int m = M_start;
  while (m <= M_end) {
    mm.push_back(m);
    int step;
    if (m >= 8192) {
      step = 4096;
    } else if (m >= 1024) {
      step = 1024;
    } else if (m >= 512) {
      step = 128;
    } else if (m >= 256) {
      step = 64;
    } else if (m >= 64) {
      step = 32;
    } else if (m >= 16) {
      step = 16;
    } else if (m >= 4) {
      step = 4;
    } else {
      step = 1;
    }
    m += step;
  }
  // One cuBLASLt context for the whole sweep (the original constructed a
  // fresh context — and leaked its handle — per (m, k, n) combination).
  CUBLASLTContext dev_ctx;
  for (size_t j = 0; j < mm.size(); j++) {
    int cur_m = mm[j];
    for (int i = 0; i < K_size; ++i) {
      int n = static_cast<int>(N_data[i]);
      int k = static_cast<int>(K_data[i]);
      // Contents are irrelevant for tuning; only shapes matter.
      auto A = std::vector<int8_t>(static_cast<size_t>(cur_m) * k);
      auto B = std::vector<int8_t>(static_cast<size_t>(k) * n);
      auto C = std::vector<int32_t>(static_cast<size_t>(cur_m) * n);
      GEMMInt8(dev_ctx,
               A,
               B,
               C,
               cur_m,
               k,
               n,
               is_test, /*is_test*/
               is_read_from_file, /*is_read_from_file*/
               path);
    }
  }
}
// Registers the `tune_cublaslt_gemm` custom op with Paddle.
// Inputs: K, N — 1-D int64 tensors of equal length giving the (k, n)
// pairs to tune.  No outputs; results are appended to / read from the CSV
// at `path` by the kernel above.
PD_BUILD_STATIC_OP(tune_cublaslt_gemm)
    .Inputs({"K", "N"})
    .Outputs({})
    .Attrs({"M_start: int",
            "M_end: int",
            "dtype: std::string",
            "is_test: bool",
            "is_read_from_file: bool",
            "path: std::string"})
    .SetKernelFn(PD_KERNEL(TuneCublasltGemm));