// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include #include #include #include #include #include #include #include #include #include #include #include "helper.h" template void handleError(T status, const char* file, int line) { printf("Unknown error type at %s:%d\n", file, line); exit(1); } // for cudaError_t template <> void handleError(cudaError_t status, const char* file, int line) { if (status != cudaSuccess) { printf( "CUDA error at %s:%d - %s\n", file, line, cudaGetErrorString(status)); exit(1); } } // for cublasStatus_t template <> void handleError(cublasStatus_t status, const char* file, int line) { if (status != CUBLAS_STATUS_SUCCESS) { printf( "cuBLAS error at %s:%d - %d\n", file, line, static_cast(status)); exit(1); } } #define CUDA_CHECK(call) \ do { \ handleError((call), __FILE__, __LINE__); \ } while (0) typedef struct { cublasLtMatmulAlgo_t algo; cublasStatus_t status; float time; size_t workspaceSize; cublasMath_t mathMode; cublasLtReductionScheme_t reductionScheme; int customOption; float wavesCount; } customMatmulPerf_t; typedef struct { cublasLtMatmulAlgo_t algo; int m; int n; int k; int algoId; int swizzle; int customOption; int tile; int splitK_val; int reductionScheme; int stages; size_t workspaceSize; float time; } algoSelect_t; inline double diffTime( const std::chrono::high_resolution_clock::time_point& start, const std::chrono::high_resolution_clock::time_point& end) { return std::chrono::duration(end - start).count(); } const int splitKSequenceA[] = {1, 2, 3, 4, 5, 6, 8, 12, 16, 32}; static inline bool time_compare_perf(const customMatmulPerf_t& perf_a, const customMatmulPerf_t& perf_b) { return ((perf_a.status == CUBLAS_STATUS_SUCCESS) && (perf_a.time < perf_b.time)); } static inline bool time_compare_algo_para(const algoSelect_t& algo_para_a, const algoSelect_t& algo_para_b) { return (algo_para_a.time < algo_para_b.time); } // 获取当前 GPU 的剩余显存大小(以字节为单位) size_t get_remaining_memory() { size_t free, total; CUDA_CHECK(cudaMemGetInfo(&free, &total)); return free; } template static void TestMatmulRun(cublasLtHandle_t ltHandle, cublasLtMatmulDesc_t matmulDesc, cublasLtMatrixLayout_t A_desc, cublasLtMatrixLayout_t B_desc, cublasLtMatrixLayout_t C_desc, const InT* A, const InT* B, OutT* C, const cublasLtMatmulAlgo_t& algo, customMatmulPerf_t& perfResults, cudaEvent_t& startEvent, cudaEvent_t& stopEvent) { cudaStream_t stream = 0; cublasLtMatmulHeuristicResult_t heurResult; cublasStatus_t algoStatus = cublasLtMatmulAlgoCheck( ltHandle, matmulDesc, A_desc, B_desc, C_desc, C_desc, &algo, &heurResult); auto remainingMemorySize = 0.95 * get_remaining_memory(); if (algoStatus == CUBLAS_STATUS_SUCCESS && remainingMemorySize > heurResult.workspaceSize) { ScaleT alpha = static_cast(1), beta = static_cast(0); void* workSpace; CUDA_CHECK(cudaMalloc(&workSpace, heurResult.workspaceSize)); CUDA_CHECK(cudaEventRecord(startEvent, stream)); int repeats = 100; for (int loop = 0; loop < repeats; loop++) { cublasStatus_t currStatus = cublasLtMatmul(ltHandle, matmulDesc, &alpha, A, A_desc, B, B_desc, &beta, C, C_desc, C, C_desc, &algo, workSpace, heurResult.workspaceSize, stream); if (currStatus != CUBLAS_STATUS_SUCCESS) { perfResults.status = currStatus; break; } CUDA_CHECK(cudaDeviceSynchronize()); } CUDA_CHECK(cudaDeviceSynchronize()); CUDA_CHECK(cudaEventRecord(stopEvent, stream)); CUDA_CHECK(cudaEventSynchronize(stopEvent)); float time; CUDA_CHECK(cudaEventElapsedTime(&time, startEvent, stopEvent)); if (algoStatus == CUBLAS_STATUS_SUCCESS) { perfResults.algo = algo; perfResults.time = time / repeats; perfResults.workspaceSize = heurResult.workspaceSize; perfResults.wavesCount = heurResult.wavesCount; } CUDA_CHECK(cudaFree(workSpace)); } else { std::cerr << "Not enough workspace! Required " << static_cast(heurResult.workspaceSize) / 1024.0 / 1024.0 / 1024.0 << " GiB" << ", But remaining " << static_cast(remainingMemorySize) / 1024.0 / 1024.0 / 1024.0 << " GiB" << std::endl; perfResults.status = CUBLAS_STATUS_NOT_SUPPORTED; // Not enough workspace } } template void FindAlgo(const cublasLtHandle_t& ltHandle, int m, int n, int k, const InT* A, const InT* B, OutT* C, cublasLtMatmulDesc_t matmulDesc, cublasLtMatrixLayout_t A_desc, cublasLtMatrixLayout_t B_desc, cublasLtMatrixLayout_t C_desc, cublasComputeType_t computeType, cudaDataType_t scaleType, cudaDataType_t Atype, cudaDataType_t Btype, cudaDataType_t Ctype, std::vector& algos, const std::string& path) { // Get Ids // https://docs.nvidia.com/cuda/cublas/index.html#cublasLtMatmulAlgoGetIds // Input cublasStatus_t status = CUBLAS_STATUS_SUCCESS; int AlgoCount = 0; int AlgoCombinations = 20000; std::vector perfResultsTmp; // Output int algoIdA[100]; int nbAlgoIds; CUDA_CHECK(cublasLtMatmulAlgoGetIds(ltHandle, computeType, scaleType, Atype, Btype, Ctype, Ctype, 100, algoIdA, &nbAlgoIds)); std::clog << std::endl << "get " << nbAlgoIds << " algoIds" << std::endl; for (int idx = 0; idx < nbAlgoIds; idx++) { cublasLtMatmulAlgo_t algo; std::clog << "Process algo: " << algoIdA[idx] << " "; /* Initialize algo structure with given Algp ID */ // https://docs.nvidia.com/cuda/cublas/index.html#cublasLtMatmulAlgoInit CUDA_CHECK(cublasLtMatmulAlgoInit(ltHandle, computeType, scaleType, Atype, Btype, Ctype, Ctype, algoIdA[idx], &algo)); // Query the tiles enums supported by that algo which is used to alloc // enough space to store it // https://docs.nvidia.com/cuda/cublas/index.html#cublasLtMatmulAlgoCapGetAttribute cublasLtMatmulTile_t tileA[CUBLASLT_MATMUL_TILE_END]; size_t nbTiles, sizeWritten; CUDA_CHECK(cublasLtMatmulAlgoCapGetAttribute( &algo, CUBLASLT_ALGO_CAP_TILE_IDS, tileA, sizeof(tileA), &sizeWritten)); nbTiles = sizeWritten / sizeof(tileA[0]); // Query the stages enums supported by that algo (cuda must >= 11.0) CUDA_CHECK(cublasLtMatmulAlgoCapGetAttribute( &algo, CUBLASLT_ALGO_CAP_STAGES_IDS, NULL, 0, &sizeWritten)); int nbStages = int(sizeWritten / sizeof(uint32_t)); std::vector stagesA(nbStages == 0 ? 1 : nbStages); if (nbStages == 0) { stagesA[0] = CUBLASLT_MATMUL_STAGES_UNDEFINED; nbStages = 1; } else { CUDA_CHECK(cublasLtMatmulAlgoCapGetAttribute(&algo, CUBLASLT_ALGO_CAP_STAGES_IDS, stagesA.data(), sizeof(uint32_t) * nbStages, &sizeWritten)); } // Retrieve Other Algo Capabilities attributes int32_t splitkSupport, customOptionMax; uint32_t redMask, swizzlingMax; // cublasLtMatmulInnerShape_t innerShape; CUDA_CHECK( cublasLtMatmulAlgoCapGetAttribute(&algo, CUBLASLT_ALGO_CAP_SPLITK_SUPPORT, &splitkSupport, sizeof(splitkSupport), &sizeWritten)); std::clog << "splitkSupport: " << splitkSupport << std::endl; CUDA_CHECK(cublasLtMatmulAlgoCapGetAttribute( &algo, CUBLASLT_ALGO_CAP_REDUCTION_SCHEME_MASK, &redMask, sizeof(redMask), &sizeWritten)); CUDA_CHECK(cublasLtMatmulAlgoCapGetAttribute( &algo, CUBLASLT_ALGO_CAP_CTA_SWIZZLING_SUPPORT, &swizzlingMax, sizeof(swizzlingMax), &sizeWritten)); CUDA_CHECK( cublasLtMatmulAlgoCapGetAttribute(&algo, CUBLASLT_ALGO_CAP_CUSTOM_OPTION_MAX, &customOptionMax, sizeof(customOptionMax), &sizeWritten)); /* Loop over the different tiles */ for (int tileIdx = 0; tileIdx < nbTiles && AlgoCount < AlgoCombinations; tileIdx++) { /* Loop over different stages count */ for (int stagesIdx = 0; stagesIdx < nbStages && AlgoCount < AlgoCombinations; stagesIdx++) { /* Loop over the different custom option if any */ for (int32_t customOption = 0; customOption <= customOptionMax && AlgoCount < AlgoCombinations; customOption++) { /* Loop over the CTAs swizzling support */ for (uint32_t k = 0; k <= swizzlingMax && AlgoCount < AlgoCombinations; k++) { int splitK_trial = 0; if (splitkSupport) { splitK_trial += sizeof(splitKSequenceA) / sizeof(splitKSequenceA[0]); } for (int l = 0; (l < (1 + splitK_trial)) && (AlgoCount < AlgoCombinations); l++) { CUDA_CHECK(cublasLtMatmulAlgoConfigSetAttribute( &algo, CUBLASLT_ALGO_CONFIG_TILE_ID, &tileA[tileIdx], sizeof(tileA[tileIdx]))); CUDA_CHECK(cublasLtMatmulAlgoConfigSetAttribute( &algo, CUBLASLT_ALGO_CONFIG_STAGES_ID, &stagesA[stagesIdx], sizeof(stagesA[stagesIdx]))); CUDA_CHECK(cublasLtMatmulAlgoConfigSetAttribute( &algo, CUBLASLT_ALGO_CONFIG_CUSTOM_OPTION, &customOption, sizeof(customOption))); CUDA_CHECK(cublasLtMatmulAlgoConfigSetAttribute( &algo, CUBLASLT_ALGO_CONFIG_CTA_SWIZZLING, &k, sizeof(k))); int splitK_val = 1; uint32_t redScheme = CUBLASLT_REDUCTION_SCHEME_NONE; CUDA_CHECK(cublasLtMatmulAlgoConfigSetAttribute( &algo, CUBLASLT_ALGO_CONFIG_SPLITK_NUM, &splitK_val, sizeof(splitK_val))); CUDA_CHECK(cublasLtMatmulAlgoConfigSetAttribute( &algo, CUBLASLT_ALGO_CONFIG_REDUCTION_SCHEME, &redScheme, sizeof(int))); if (l > 0) { // Split-K case splitK_val = splitKSequenceA[l - 1]; CUDA_CHECK(cublasLtMatmulAlgoConfigSetAttribute( &algo, CUBLASLT_ALGO_CONFIG_SPLITK_NUM, &splitKSequenceA[l - 1], sizeof(splitKSequenceA[l - 1]))); for (redScheme = 1; redScheme < (int)CUBLASLT_REDUCTION_SCHEME_MASK && (AlgoCount < AlgoCombinations); redScheme <<= 1) { CUDA_CHECK(cublasLtMatmulAlgoConfigSetAttribute( &algo, CUBLASLT_ALGO_CONFIG_REDUCTION_SCHEME, &redScheme, sizeof(redScheme))); cublasLtMatmulHeuristicResult_t heurResult; cublasStatus_t algoStatus = cublasLtMatmulAlgoCheck(ltHandle, matmulDesc, A_desc, B_desc, C_desc, C_desc, &algo, &heurResult); if (algoStatus == CUBLAS_STATUS_SUCCESS) { algoSelect_t algoSelect; algoSelect.algo = algo; algoSelect.m = m; algoSelect.n = n; algoSelect.k = k; algoSelect.algoId = algoIdA[idx]; algoSelect.tile = tileA[tileIdx]; algoSelect.swizzle = k; algoSelect.customOption = customOption; algoSelect.splitK_val = splitK_val; algoSelect.reductionScheme = redScheme; algoSelect.stages = stagesA[stagesIdx]; algos.push_back(algoSelect); AlgoCount++; } } } else { // Prepare algos cublasLtMatmulHeuristicResult_t heurResult; // https://docs.nvidia.com/cuda/cublas/index.html#cublasLtMatmulAlgoCheck cublasStatus_t algoStatus = cublasLtMatmulAlgoCheck(ltHandle, matmulDesc, A_desc, B_desc, C_desc, C_desc, &algo, &heurResult); if (algoStatus == CUBLAS_STATUS_SUCCESS) { algoSelect_t algoSelect; algoSelect.algo = algo; algoSelect.m = m; algoSelect.n = n; algoSelect.k = k; algoSelect.algoId = algoIdA[idx]; algoSelect.tile = tileA[tileIdx]; algoSelect.swizzle = k; algoSelect.customOption = customOption; algoSelect.splitK_val = splitK_val; algoSelect.reductionScheme = redScheme; algoSelect.stages = stagesA[stagesIdx]; algos.push_back(algoSelect); AlgoCount++; } } } } } } } } std::clog << "Got " << AlgoCount << " algos" << std::endl; cudaEvent_t startEvent; cudaEvent_t stopEvent; std::vector perfResults(AlgoCount); CUDA_CHECK(cudaEventCreate(&startEvent)); CUDA_CHECK(cudaEventCreate(&stopEvent)); for (int i = 0; i < AlgoCount; i++) { TestMatmulRun(ltHandle, matmulDesc, A_desc, B_desc, C_desc, A, B, C, algos[i].algo, perfResults[i], startEvent, stopEvent); algos[i].workspaceSize = perfResults[i].workspaceSize; algos[i].time = perfResults[i].time; if (perfResults[i].status != CUBLAS_STATUS_SUCCESS) { std::clog << "algo " << algos[i].algoId << " tile " << algos[i].tile << " stages " << algos[i].stages << " splitK_val " << algos[i].splitK_val << std::endl; algos[i].time = std::numeric_limits::max(); std::cerr << " TestMatmulRun with status " << perfResults[i].status << std::endl; continue; } } std::sort(algos.begin(), algos.end(), time_compare_algo_para); int i = 0; while (algos[i].time == 0) i++; std::ofstream outfile; outfile.open(path, std::ios::app); outfile << m << "," << k << "," << n << "," << algos[i].algoId << "," << algos[i].swizzle << "," << algos[i].customOption << "," << algos[i].tile << "," << algos[i].splitK_val << "," << algos[i].reductionScheme << "," << algos[i].stages << "," << algos[i].workspaceSize << "," << algos[i].time << "\n"; outfile.close(); } class DevContext {}; class CPUContext : public DevContext {}; class CUBLASLTContext : public DevContext { public: CUBLASLTContext() { CUDA_CHECK(cublasLtCreate(&handle)); } cublasLtHandle_t handle; }; template void GEMMInt8(const DevContext& dev_ctx, const std::vector& A, const std::vector& B, std::vector& C, int m, int k, int n, bool is_test, bool is_read_from_file = false, const std::string& path = "search.csv") { std::cerr << "Base Class is not implemented" << std::endl; } template <> void GEMMInt8(const CPUContext& dev_ctx, const std::vector& A, const std::vector& B, std::vector& C, int m, int k, int n, bool is_test, bool is_read_from_file, const std::string& path) { std::cerr << "CPUContext Class is not implemented" << std::endl; } template <> void GEMMInt8(const CUBLASLTContext& dev_ctx, const std::vector& AVec, const std::vector& BVec, std::vector& CVec, int m, int k, int n, bool is_test, bool is_read_from_file, const std::string& path) { int8_t* A_dev; int8_t* B_dev; int32_t* C_dev; char* workSpace; CUDA_CHECK(cudaMalloc((void**)&A_dev, AVec.size() * sizeof(int8_t))); CUDA_CHECK(cudaMalloc((void**)&B_dev, BVec.size() * sizeof(int8_t))); CUDA_CHECK(cudaMalloc((void**)&C_dev, m * n * sizeof(int32_t))); CUDA_CHECK( cudaMemcpy(A_dev, AVec.data(), AVec.size(), cudaMemcpyHostToDevice)); CUDA_CHECK( cudaMemcpy(B_dev, BVec.data(), BVec.size(), cudaMemcpyHostToDevice)); // init data structure cublasLtMatmulDesc_t matmul_desc; cublasLtMatrixLayout_t A_desc; cublasLtMatrixLayout_t B_desc; cublasLtMatrixLayout_t C_desc; int32_t alpha = 1; int32_t beta = 0; cublasComputeType_t cudaComputeType = CUBLAS_COMPUTE_32I; CUDA_CHECK( cublasLtMatmulDescCreate(&matmul_desc, cudaComputeType, CUDA_R_32I)); cublasOperation_t op_transpose = CUBLAS_OP_T; CUDA_CHECK(cublasLtMatmulDescSetAttribute(matmul_desc, CUBLASLT_MATMUL_DESC_TRANSA, &op_transpose, sizeof(op_transpose))); CUDA_CHECK(cublasLtMatrixLayoutCreate(&B_desc, CUDA_R_8I, k, n, k)); CUDA_CHECK(cublasLtMatrixLayoutCreate(&A_desc, CUDA_R_8I, k, m, k)); CUDA_CHECK(cublasLtMatrixLayoutCreate(&C_desc, CUDA_R_32I, n, m, n)); cublasLtMatmulAlgo_t algo; int algoId; int swizzle; int customOption; int tile; int splitK_val; int reductionScheme; int stages; size_t work_space_size = 0; float time_ref; auto using_default_config = [&]() { algoId = 21; swizzle = 0; customOption = 0; tile = 15; splitK_val = 0; reductionScheme = 0; stages = 23; if (m >= 128) { tile = 20; stages = 17; } }; if (is_test) { std::vector algos; // Select // FindAlgo(dev_ctx.handle, m, n, k, B_dev, A_dev, C_dev, matmul_desc, B_desc, A_desc, C_desc, CUBLAS_COMPUTE_32I, CUDA_R_32I, CUDA_R_8I, CUDA_R_8I, CUDA_R_32I, algos, path); int i = 0; while (algos[i].time == 0) i++; algoId = algos[i].algoId; swizzle = algos[i].swizzle; customOption = algos[i].customOption; tile = algos[i].tile; splitK_val = algos[i].splitK_val; reductionScheme = algos[i].reductionScheme; stages = algos[i].stages; work_space_size = algos[i].workspaceSize; } else if (is_read_from_file) { int m_tmp, k_tmp, n_tmp; std::ifstream file(path); if (!file.is_open()) { std::cout << "file not open. Now we use default params" << std::endl; using_default_config(); } else { bool match = false; int find_cnt = 0; std::string line; while (std::getline(file, line)) { std::istringstream iss(line); char comma; if (iss >> m_tmp >> comma >> k_tmp >> comma >> n_tmp >> comma >> algoId >> comma >> swizzle >> comma >> customOption >> comma >> tile >> comma >> splitK_val >> comma >> reductionScheme >> comma >> stages >> comma >> work_space_size >> comma >> time_ref) { if (k_tmp == k && n_tmp == n && m <= m_tmp) { match = true; break; } find_cnt++; } } if (find_cnt == 0) { std::cout << "the file is empty. Now we use default params" << std::endl; using_default_config(); } } } else { std::cout << "Please use test mode to select\n, Now we use default params" << std::endl; using_default_config(); } CUDA_CHECK(cudaMalloc((void**)&workSpace, work_space_size)); CUDA_CHECK(cublasLtMatmulAlgoInit(dev_ctx.handle, cudaComputeType, CUDA_R_32I, CUDA_R_8I, CUDA_R_8I, CUDA_R_32I, CUDA_R_32I, algoId, &algo)); CUDA_CHECK( cublasLtMatmulAlgoConfigSetAttribute(&algo, CUBLASLT_ALGO_CONFIG_CUSTOM_OPTION, &(customOption), sizeof(customOption))); CUDA_CHECK(cublasLtMatmulAlgoConfigSetAttribute( &algo, CUBLASLT_ALGO_CONFIG_TILE_ID, &(tile), sizeof(tile))); CUDA_CHECK( cublasLtMatmulAlgoConfigSetAttribute(&algo, CUBLASLT_ALGO_CONFIG_SPLITK_NUM, &(splitK_val), sizeof(splitK_val))); CUDA_CHECK(cublasLtMatmulAlgoConfigSetAttribute( &algo, CUBLASLT_ALGO_CONFIG_CTA_SWIZZLING, &(swizzle), sizeof(swizzle))); CUDA_CHECK(cublasLtMatmulAlgoConfigSetAttribute( &algo, CUBLASLT_ALGO_CONFIG_REDUCTION_SCHEME, &(reductionScheme), sizeof(int))); CUDA_CHECK(cublasLtMatmulAlgoConfigSetAttribute( &algo, CUBLASLT_ALGO_CONFIG_STAGES_ID, &(stages), sizeof(stages))); auto start = std::chrono::high_resolution_clock::now(); const int repeats = 10; for (int loop = 0; loop < repeats; loop++) { CUDA_CHECK(cublasLtMatmul(dev_ctx.handle, matmul_desc, &alpha, B_dev, B_desc, A_dev, A_desc, &beta, C_dev, C_desc, C_dev, C_desc, &algo, // nullptr, workSpace, // 0, work_space_size, 0)); CUDA_CHECK(cudaDeviceSynchronize()); } CUDA_CHECK(cudaDeviceSynchronize()); auto end = std::chrono::high_resolution_clock::now(); double time = diffTime(start, end); auto now = std::chrono::system_clock::now(); std::time_t now_time_t = std::chrono::system_clock::to_time_t(now); std::tm now_tm = *std::localtime(&now_time_t); std::cout << "GEMM with cublaslt imma1 int8 spend " << time / repeats << " ms in " << m << ", " << k << ", " << n << ", current time: " << std::put_time(&now_tm, "%H:%M:%S") << std::endl; CUDA_CHECK(cudaFree(A_dev)); CUDA_CHECK(cudaFree(B_dev)); CUDA_CHECK(cudaFree(C_dev)); CUDA_CHECK(cudaFree(workSpace)); } void TuneCublasltGemm(const paddle::Tensor& K, const paddle::Tensor& N, const int M_start, const int M_end, const std::string& dtype, const bool is_test, const bool is_read_from_file, const std::string& path) { assert(M_end >= M_start); assert(M_start >= 1); assert(K.dims().size() == 1 && N.dims().size() == 1); assert(is_test != is_read_from_file); auto K_cpu = K.copy_to(paddle::CPUPlace(), false); auto N_cpu = N.copy_to(paddle::CPUPlace(), false); int64_t* K_data = K_cpu.data(); int64_t* N_data = N_cpu.data(); int K_size = K.numel(); int N_size = N.numel(); assert(K_size == N_size); std::vector mm; int m = M_start, step = 1; while (m <= M_end) { // update step if (m >= 8192) { step = 4096; } else if (m >= 1024) { step = 1024; } else if (m >= 512) { step = 128; } else if (m >= 256) { step = 64; } else if (m >= 64) { step = 32; } else if (m >= 16) { step = 16; } else if (m >= 4) { step = 4; } else { step = 1; } mm.push_back(m); m += step; } for (int j = 0; j < mm.size(); j++) { int m = mm[j]; for (int i = 0; i < K_size; ++i) { int n = (int)N_data[i]; int k = (int)K_data[i]; auto A = std::vector(m * k); auto B = std::vector(k * n); auto C = std::vector(m * n); if (dtype == "int8") { CUBLASLTContext dev_ctx; GEMMInt8(dev_ctx, A, B, C, m, k, n, is_test, /*is_test*/ is_read_from_file, /*is_read_from_file*/ path); } else { // other dtype throw std::runtime_error(dtype + "not currently supported"); } } } } PD_BUILD_STATIC_OP(tune_cublaslt_gemm) .Inputs({"K", "N"}) .Outputs({}) .Attrs({"M_start: int", "M_end: int", "dtype: std::string", "is_test: bool", "is_read_from_file: bool", "path: std::string"}) .SetKernelFn(PD_KERNEL(TuneCublasltGemm));