diff --git a/custom_ops/gpu_ops/cutlass_kernels/moe_gemm/fused_moe_cutlass_kernel.h b/custom_ops/gpu_ops/cutlass_kernels/moe_gemm/fused_moe_cutlass_kernel.h index 54a144974..9c5e7bfc4 100644 --- a/custom_ops/gpu_ops/cutlass_kernels/moe_gemm/fused_moe_cutlass_kernel.h +++ b/custom_ops/gpu_ops/cutlass_kernels/moe_gemm/fused_moe_cutlass_kernel.h @@ -277,7 +277,8 @@ struct MoeFCGemm { code_scale(const_cast(code_scale)), code_zp(const_cast(code_zp)), host_problem_sizes(nullptr) { - if (quant_method != WintQuantMethod::kNone || platform::is_same::value || + if (quant_method != WintQuantMethod::kNone || + platform::is_same::value || platform::is_same::value) { assert(weight_scales); } @@ -380,7 +381,8 @@ struct MoeFCGemm { } static Status can_implement(Arguments const& args) { - if (args.quant_method != WintQuantMethod::kNone || platform::is_same::value || + if (args.quant_method != WintQuantMethod::kNone || + platform::is_same::value || platform::is_same::value) { if (args.weight_scales == nullptr) { CUTLASS_TRACE_HOST( @@ -416,7 +418,6 @@ struct MoeFCGemm { template struct KernelRunner { - CUTLASS_DEVICE static void run_kernel(Params const& params, SharedStorage& shared_storage) { // NOLINT @@ -471,9 +472,13 @@ struct MoeFCGemm { int64_t rows_to_jump = 0; if (params.problem_visitor.total_rows < 0) { - rows_to_jump = problem_idx == 0 ? 0 : params.problem_visitor.last_row_for_problem[problem_idx - 1]; + rows_to_jump = problem_idx == 0 + ? 0 + : params.problem_visitor + .last_row_for_problem[problem_idx - 1]; } else { - rows_to_jump = problem_idx * (params.problem_visitor.total_rows / params.problem_visitor.problem_count); + rows_to_jump = problem_idx * (params.problem_visitor.total_rows / + params.problem_visitor.problem_count); } // begin address offset for A for current tile @@ -496,11 +501,13 @@ struct MoeFCGemm { 0, }; - // the begin threadblock_offset of B, which holds the same column id with C + // the begin threadblock_offset of B, which holds the same column id + // with C cutlass::MatrixCoord tb_offset_B{0, threadblock_offset.n() / kInterleave}; - // the begin threadblock_offset of scale, which holds the same column id with C, but with no row id + // the begin threadblock_offset of scale, which holds the same column id + // with C, but with no row id cutlass::MatrixCoord tb_offset_scale{0, threadblock_offset.n()}; // Compute position within threadblock @@ -628,7 +635,7 @@ struct MoeFCGemm { static constexpr bool compile_needed = platform::is_same::value; KernelRunner::run_kernel(params, shared_storage); -#elif defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) && (__CUDA_ARCH__ < 910) +#elif defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) && (__CUDA_ARCH__ < 1010) static constexpr bool compile_needed = platform::is_same::value; KernelRunner::run_kernel(params, shared_storage); @@ -649,9 +656,17 @@ template -struct Wint2xMoeFCGemm : public MoeFCGemm { +struct Wint2xMoeFCGemm : public MoeFCGemm { public: - using Base = MoeFCGemm; + using Base = MoeFCGemm; using Mma = Mma_; using Epilogue = Epilogue_; using EpilogueOutputOp = typename Epilogue::OutputOp; @@ -711,7 +726,11 @@ struct Wint2xMoeFCGemm : public MoeFCGemm(params.local_scale + offset_in_bytes); + uint4b_t* local_scale_ptr = + reinterpret_cast(params.local_scale + offset_in_bytes); - typename Mma::QuantParamsAccessor::IteratorLocalScale iterator_local_scale( - Mma::QuantParamsAccessor::LayoutLocalScale(gemm_n * 2), - local_scale_ptr, - {(gemm_k + 127) / 128, gemm_n * 2}, - thread_idx, - tb_offset_local_scale); + typename Mma::QuantParamsAccessor::IteratorLocalScale + iterator_local_scale( + Mma::QuantParamsAccessor::LayoutLocalScale(gemm_n * 2), + local_scale_ptr, + {(gemm_k + 127) / 128, gemm_n * 2}, + thread_idx, + tb_offset_local_scale); float* code_scale_ptr = params.code_scale + problem_idx * gemm_n; - typename Mma::QuantParamsAccessor::IteratorCodeScaleZp iterator_code_scale( - Mma::QuantParamsAccessor::LayoutCodeScaleZp(gemm_n), - code_scale_ptr, - {1, gemm_n}, - thread_idx, - tb_offset_scale); + typename Mma::QuantParamsAccessor::IteratorCodeScaleZp + iterator_code_scale( + Mma::QuantParamsAccessor::LayoutCodeScaleZp(gemm_n), + code_scale_ptr, + {1, gemm_n}, + thread_idx, + tb_offset_scale); float* code_zp_ptr = params.code_zp + problem_idx * gemm_n; typename Mma::QuantParamsAccessor::IteratorCodeScaleZp iterator_code_zp( @@ -819,8 +849,11 @@ struct Wint2xMoeFCGemm : public MoeFCGemm::CaclPackedDim(gemm_k); - int64_t bytes_per_expert_matrix = (quant_gemm_k * gemm_n / 8) * cutlass::sizeof_bits::value; + // wint2.5 and wint2.0 is quantized and packed along k dimension with + // group_size 64. + const int64_t quant_gemm_k = + WintQuantTraits::CaclPackedDim(gemm_k); + int64_t bytes_per_expert_matrix = + (quant_gemm_k * gemm_n / 8) * + cutlass::sizeof_bits::value; // Outer 'persistent' loop to iterate over tiles while (problem_visitor.next_tile()) { @@ -881,9 +918,13 @@ struct Wint2xMoeFCGemm : public MoeFCGemm(byte_ptr_B); @@ -904,9 +946,12 @@ struct Wint2xMoeFCGemm : public MoeFCGemm(params.ptr_C) + problem_idx * gemm_n : nullptr; + ElementC* ptr_C = params.ptr_C + ? reinterpret_cast(params.ptr_C) + + problem_idx * gemm_n + : nullptr; ElementC* ptr_D = reinterpret_cast(params.ptr_D) + rows_to_jump * gemm_n; @@ -1012,8 +1060,9 @@ struct Wint2xMoeFCGemm : public MoeFCGemm= 800) && (__CUDA_ARCH__ < 910) - KernelRunner::run_kernel(params, shared_storage); +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) && (__CUDA_ARCH__ < 1010) + KernelRunner::run_kernel( + params, shared_storage); #else CUTLASS_NOT_IMPLEMENTED(); #endif diff --git a/custom_ops/gpu_ops/cutlass_kernels/moe_gemm/fused_moe_gemm_kernels_template.h b/custom_ops/gpu_ops/cutlass_kernels/moe_gemm/fused_moe_gemm_kernels_template.h index dfa7927de..db5af4f49 100644 --- a/custom_ops/gpu_ops/cutlass_kernels/moe_gemm/fused_moe_gemm_kernels_template.h +++ b/custom_ops/gpu_ops/cutlass_kernels/moe_gemm/fused_moe_gemm_kernels_template.h @@ -26,10 +26,10 @@ #include #include "cutlass/array.h" -#include "cutlass/trace.h" -#include "cutlass/numeric_conversion.h" #include "cutlass/gemm/device/gemm_grouped.h" #include "cutlass/gemm/kernel/default_gemm_grouped.h" +#include "cutlass/numeric_conversion.h" +#include "cutlass/trace.h" #include "paddle/common/errors.h" #include "paddle/phi/core/enforce.h" @@ -63,24 +63,28 @@ struct CutlassLayoutB { using Type = cutlass::layout::RowMajor; }; -template +template struct CutlassGemmKernel { - using Type = - cutlass::gemm::kernel::MoeFCGemm; + using Type = cutlass::gemm::kernel::MoeFCGemm< + typename BaseGemmKernel::Mma, + typename BaseGemmKernel::Epilogue, + typename BaseGemmKernel::ThreadblockSwizzle, + Arch, + BaseGemmKernel::kGroupScheduleMode>; }; template -struct CutlassGemmKernel { - using Type = - cutlass::gemm::kernel::Wint2xMoeFCGemm; +struct CutlassGemmKernel { + using Type = cutlass::gemm::kernel::Wint2xMoeFCGemm< + typename BaseGemmKernel::Mma, + typename BaseGemmKernel::Epilogue, + typename BaseGemmKernel::ThreadblockSwizzle, + Arch, + BaseGemmKernel::kGroupScheduleMode>; }; // ======================= Variable batched Gemm things ======================= @@ -91,21 +95,22 @@ template -void generic_moe_gemm_kernelLauncher(const T* A, - const typename WeightQuantTraits::WeightType* B, - const T* weight_scales, - const T* biases, - T* C, - int64_t* total_rows_before_expert, - int64_t total_rows, - int64_t gemm_n, - int64_t gemm_k, - int num_experts, - const typename WeightQuantTraits::Arguments& quant_args_B, - CutlassGemmConfig gemm_config, - const int multi_processor_count, - cudaStream_t stream, - int* kernel_occupancy = nullptr) { +void generic_moe_gemm_kernelLauncher( + const T* A, + const typename WeightQuantTraits::WeightType* B, + const T* weight_scales, + const T* biases, + T* C, + int64_t* total_rows_before_expert, + int64_t total_rows, + int64_t gemm_n, + int64_t gemm_k, + int num_experts, + const typename WeightQuantTraits::Arguments& quant_args_B, + CutlassGemmConfig gemm_config, + const int multi_processor_count, + cudaStream_t stream, + int* kernel_occupancy = nullptr) { if (gemm_config.split_k_style != SplitKStyle::NO_SPLIT_K) { throw std::runtime_error("[MoeGemm] Grouped gemm does not support split-k"); } @@ -128,12 +133,14 @@ void generic_moe_gemm_kernelLauncher(const T* A, cutlass::platform::is_same::value || cutlass::platform::is_same::value || cutlass::platform::is_same::value, - "Specialized for bfloat16, half, float, uint8_t (wint8), uint4b_t (wint4), uint16_t (wint2.5)"); + "Specialized for bfloat16, half, float, uint8_t (wint8), uint4b_t " + "(wint4), uint16_t (wint2.5)"); // The cutlass type for the input elements. This is needed to convert to // cutlass::half_t if necessary. using ElementType = typename cutlass::CutlassDataType::Type; - using CutlassWeightType = typename cutlass::CutlassDataType::Type; + using CutlassWeightType = typename cutlass::CutlassDataType< + typename WeightQuantTraits::WeightType>::Type; using CutlassMmaWeightType = typename WeightQuantTraits::MmaWeightType; using CutlassMmaKernelType = typename WeightQuantTraits::MmaKernelType; @@ -155,7 +162,8 @@ void generic_moe_gemm_kernelLauncher(const T* A, cutlass::ComplexTransform::kNone, MixedGemmArchTraits::ElementsPerAccessA, CutlassMmaKernelType, - typename CutlassLayoutB::Type, + typename CutlassLayoutB::Type, cutlass::ComplexTransform::kNone, MixedGemmArchTraits::ElementsPerAccessB, ElementType, @@ -172,7 +180,10 @@ void generic_moe_gemm_kernelLauncher(const T* A, cutlass::gemm::kernel::GroupScheduleMode::kDeviceOnly, typename MixedGemmArchTraits::Operator>::GemmKernel; - using GemmKernel = typename CutlassGemmKernel::Type; + using GemmKernel = + typename CutlassGemmKernel::Type; using GemmGrouped = cutlass::gemm::device::GemmGrouped; if (kernel_occupancy != nullptr) { @@ -194,7 +205,8 @@ void generic_moe_gemm_kernelLauncher(const T* A, const uint8_t* local_scale_B = nullptr; const float* code_scale_B = nullptr; const float* code_zp_B = nullptr; - if constexpr (WeightQuantTraits::kQuantMethod == cutlass::WintQuantMethod::kWeightOnlyInt2) { + if constexpr (WeightQuantTraits::kQuantMethod == + cutlass::WintQuantMethod::kWeightOnlyInt2) { local_scale_B = quant_args_B.local_scale_ptr; code_scale_B = quant_args_B.code_scale_ptr; code_zp_B = quant_args_B.code_zp_ptr; @@ -253,21 +265,22 @@ template struct dispatch_stages { - static void dispatch(const T* A, - const typename WeightQuantTraits::WeightType* B, - const T* weight_scales, - const T* biases, - T* C, - int64_t* total_rows_before_expert, - int64_t total_rows, - int64_t gemm_n, - int64_t gemm_k, - int num_experts, - const typename WeightQuantTraits::Arguments& quant_args_B, - CutlassGemmConfig gemm_config, - int multi_processor_count, - cudaStream_t stream, - int* occupancy = nullptr) { + static void dispatch( + const T* A, + const typename WeightQuantTraits::WeightType* B, + const T* weight_scales, + const T* biases, + T* C, + int64_t* total_rows_before_expert, + int64_t total_rows, + int64_t gemm_n, + int64_t gemm_k, + int num_experts, + const typename WeightQuantTraits::Arguments& quant_args_B, + CutlassGemmConfig gemm_config, + int multi_processor_count, + cudaStream_t stream, + int* occupancy = nullptr) { // FT_LOG_DEBUG(__PRETTY_FUNCTION__); std::string err_msg = "Cutlass fpA_intB gemm. Not instantiates for arch " + std::to_string(arch::kMinComputeCapability) + @@ -289,21 +302,22 @@ struct dispatch_stages { - static void dispatch(const T* A, - const typename WeightQuantTraits::WeightType* B, - const T* weight_scales, - const T* biases, - T* C, - int64_t* total_rows_before_expert, - int64_t total_rows, - int64_t gemm_n, - int64_t gemm_k, - int num_experts, - const typename WeightQuantTraits::Arguments& quant_args_B, - CutlassGemmConfig gemm_config, - int multi_processor_count, - cudaStream_t stream, - int* occupancy = nullptr) { + static void dispatch( + const T* A, + const typename WeightQuantTraits::WeightType* B, + const T* weight_scales, + const T* biases, + T* C, + int64_t* total_rows_before_expert, + int64_t total_rows, + int64_t gemm_n, + int64_t gemm_k, + int num_experts, + const typename WeightQuantTraits::Arguments& quant_args_B, + CutlassGemmConfig gemm_config, + int multi_processor_count, + cudaStream_t stream, + int* occupancy = nullptr) { generic_moe_gemm_kernelLauncher 2)>::type> { - static void dispatch(const T* A, - const typename WeightQuantTraits::WeightType* B, - const T* weight_scales, - const T* biases, - T* C, - int64_t* total_rows_before_expert, - int64_t total_rows, - int64_t gemm_n, - int64_t gemm_k, - int num_experts, - const typename WeightQuantTraits::Arguments& quant_args_B, - CutlassGemmConfig gemm_config, - int multi_processor_count, - cudaStream_t stream, - int* occupancy = nullptr) { + static void dispatch( + const T* A, + const typename WeightQuantTraits::WeightType* B, + const T* weight_scales, + const T* biases, + T* C, + int64_t* total_rows_before_expert, + int64_t total_rows, + int64_t gemm_n, + int64_t gemm_k, + int num_experts, + const typename WeightQuantTraits::Arguments& quant_args_B, + CutlassGemmConfig gemm_config, + int multi_processor_count, + cudaStream_t stream, + int* occupancy = nullptr) { generic_moe_gemm_kernelLauncher -void dispatch_gemm_config(const T* A, - const typename WeightQuantTraits::WeightType* B, - const T* weight_scales, - const T* biases, - T* C, - int64_t* total_rows_before_expert, - int64_t total_rows, - int64_t gemm_n, - int64_t gemm_k, - int num_experts, - const typename WeightQuantTraits::Arguments& quant_args_B, - CutlassGemmConfig gemm_config, - int multi_processor_count, - cudaStream_t stream, - int* occupancy = nullptr) { +void dispatch_gemm_config( + const T* A, + const typename WeightQuantTraits::WeightType* B, + const T* weight_scales, + const T* biases, + T* C, + int64_t* total_rows_before_expert, + int64_t total_rows, + int64_t gemm_n, + int64_t gemm_k, + int num_experts, + const typename WeightQuantTraits::Arguments& quant_args_B, + CutlassGemmConfig gemm_config, + int multi_processor_count, + cudaStream_t stream, + int* occupancy = nullptr) { #define dispatch_stages_macro(STAGE) \ case STAGE: \ dispatch_stages::value && - std::is_same::value>::type* = - nullptr> -void dispatch_moe_gemm_to_cutlass(const T* A, - const typename WeightQuantTraits::WeightType* B, - const T* weight_scales, - const T* biases, - T* C, - int64_t* total_rows_before_expert, - int64_t total_rows, - int64_t gemm_n, - int64_t gemm_k, - int num_experts, - const typename WeightQuantTraits::Arguments& quant_args_B, - CutlassGemmConfig gemm_config, - int sm_version, - int multi_processor_count, - cudaStream_t stream, - int* occupancy = nullptr) { +template < + typename T, + typename WeightQuantTraits, + typename arch, + typename EpilogueTag, + typename std::enable_if< + !std::is_same::value && + std::is_same::value>::type* = + nullptr> +void dispatch_moe_gemm_to_cutlass( + const T* A, + const typename WeightQuantTraits::WeightType* B, + const T* weight_scales, + const T* biases, + T* C, + int64_t* total_rows_before_expert, + int64_t total_rows, + int64_t gemm_n, + int64_t gemm_k, + int num_experts, + const typename WeightQuantTraits::Arguments& quant_args_B, + CutlassGemmConfig gemm_config, + int sm_version, + int multi_processor_count, + cudaStream_t stream, + int* occupancy = nullptr) { switch (gemm_config.tile_config) { dispatch_gemm_config_macro(32, 128, 64, 32, 32, 64); dispatch_gemm_config_macro(64, 128, 64, 32, 64, 64); dispatch_gemm_config_macro(128, 128, 64, 64, 32, 64); case CutlassTileConfig::Undefined: - throw std::runtime_error("[dispatch_moe_gemm_to_cutlass] gemm config undefined."); + throw std::runtime_error( + "[dispatch_moe_gemm_to_cutlass] gemm config undefined."); break; case CutlassTileConfig::ChooseWithHeuristic: throw std::runtime_error( @@ -518,32 +538,36 @@ template ::value && - !std::is_same::value>::type* = - nullptr> -void dispatch_moe_gemm_to_cutlass(const T* A, - const typename WeightQuantTraits::WeightType* B, - const T* weight_scales, - const T* biases, - T* C, - int64_t* total_rows_before_expert, - int64_t total_rows, - int64_t gemm_n, - int64_t gemm_k, - int num_experts, - const typename WeightQuantTraits::Arguments& quant_args_B, - CutlassGemmConfig gemm_config, - int sm_version, - int multi_processor_count, - cudaStream_t stream, - int* occupancy = nullptr) { + typename std::enable_if< + !std::is_same::value && + !std::is_same::value>:: + type* = nullptr> +void dispatch_moe_gemm_to_cutlass( + const T* A, + const typename WeightQuantTraits::WeightType* B, + const T* weight_scales, + const T* biases, + T* C, + int64_t* total_rows_before_expert, + int64_t total_rows, + int64_t gemm_n, + int64_t gemm_k, + int num_experts, + const typename WeightQuantTraits::Arguments& quant_args_B, + CutlassGemmConfig gemm_config, + int sm_version, + int multi_processor_count, + cudaStream_t stream, + int* occupancy = nullptr) { if constexpr (std::is_same::value) { - if constexpr (WeightQuantTraits::kQuantMethod != cutlass::WintQuantMethod::kWeightOnlyInt2) { + if constexpr (WeightQuantTraits::kQuantMethod != + cutlass::WintQuantMethod::kWeightOnlyInt2) { switch (gemm_config.tile_config) { dispatch_gemm_config_macro(32, 128, 64, 32, 32, 64); dispatch_gemm_config_macro(64, 128, 64, 64, 64, 64); case CutlassTileConfig::Undefined: - throw std::runtime_error("[dispatch_moe_gemm_to_cutlass] gemm config undefined."); + throw std::runtime_error( + "[dispatch_moe_gemm_to_cutlass] gemm config undefined."); break; case CutlassTileConfig::ChooseWithHeuristic: throw std::runtime_error( @@ -558,7 +582,8 @@ void dispatch_moe_gemm_to_cutlass(const T* A, } } else { throw std::runtime_error( - "[dispatch_moe_gemm_to_cutlass] weight_only_int2 does not support sm70."); + "[dispatch_moe_gemm_to_cutlass] weight_only_int2 does not support " + "sm70."); } } else { switch (gemm_config.tile_config) { @@ -574,7 +599,8 @@ void dispatch_moe_gemm_to_cutlass(const T* A, dispatch_gemm_config_macro(64, 128, 64, 64, 32, 64); dispatch_gemm_config_macro(256, 128, 64, 64, 64, 64); case CutlassTileConfig::Undefined: - throw std::runtime_error("[dispatch_moe_gemm_to_cutlass] gemm config undefined."); + throw std::runtime_error( + "[dispatch_moe_gemm_to_cutlass] gemm config undefined."); break; case CutlassTileConfig::ChooseWithHeuristic: throw std::runtime_error( @@ -597,22 +623,23 @@ template < typename arch, typename EpilogueTag, typename std::enable_if::value>::type* = nullptr> -void dispatch_moe_gemm_to_cutlass(const T* A, - const typename WeightQuantTraits::WeightType* B, - const T* weight_scales, - const T* biases, - T* C, - int64_t* total_rows_before_expert, - int64_t total_rows, - int64_t gemm_n, - int64_t gemm_k, - int num_experts, - const typename WeightQuantTraits::Arguments& quant_args_B, - CutlassGemmConfig gemm_config, - int sm_version, - int multi_processor_count, - cudaStream_t stream, - int* occupancy = nullptr) { +void dispatch_moe_gemm_to_cutlass( + const T* A, + const typename WeightQuantTraits::WeightType* B, + const T* weight_scales, + const T* biases, + T* C, + int64_t* total_rows_before_expert, + int64_t total_rows, + int64_t gemm_n, + int64_t gemm_k, + int num_experts, + const typename WeightQuantTraits::Arguments& quant_args_B, + CutlassGemmConfig gemm_config, + int sm_version, + int multi_processor_count, + cudaStream_t stream, + int* occupancy = nullptr) { switch (gemm_config.tile_config) { dispatch_gemm_config_macro(128, 128, 8, 64, 64, 8); case CutlassTileConfig::Undefined: @@ -659,33 +686,34 @@ void MoeGemmRunner::dispatch_to_arch( CutlassGemmConfig gemm_config, cudaStream_t stream, int* occupancy) { -#define dispatch_moe_gemm_to_cutlass_macro(ARCH) \ +#define dispatch_moe_gemm_to_cutlass_macro(ARCH) \ dispatch_moe_gemm_to_cutlass( \ - A, \ - B, \ - weight_scales, \ - biases, \ - C, \ - total_rows_before_expert, \ - total_rows, \ - gemm_n, \ - gemm_k, \ - num_experts, \ - quant_args_B, \ - gemm_config, \ - sm_, \ - multi_processor_count_, \ - stream, \ + A, \ + B, \ + weight_scales, \ + biases, \ + C, \ + total_rows_before_expert, \ + total_rows, \ + gemm_n, \ + gemm_k, \ + num_experts, \ + quant_args_B, \ + gemm_config, \ + sm_, \ + multi_processor_count_, \ + stream, \ occupancy); if (sm_ >= 70 && sm_ < 75) { dispatch_moe_gemm_to_cutlass_macro(cutlass::arch::Sm70); } else if (sm_ >= 75 && sm_ < 80) { dispatch_moe_gemm_to_cutlass_macro(cutlass::arch::Sm75); - } else if (sm_ >= 80 && sm_ < 91) { + } else if (sm_ >= 80 && sm_ < 101) { dispatch_moe_gemm_to_cutlass_macro(cutlass::arch::Sm80); } else { - throw std::runtime_error("[MoE][GEMM Dispatch] Arch unsupported for MoE GEMM"); + throw std::runtime_error( + "[MoE][GEMM Dispatch] Arch unsupported for MoE GEMM"); } } @@ -705,7 +733,8 @@ void MoeGemmRunner::run_gemm( int num_experts, const typename WeightQuantTraits::Arguments& quant_args_B, cudaStream_t stream) { - static constexpr bool is_weight_only = !std::is_same::value; + static constexpr bool is_weight_only = + !std::is_same::value; static constexpr bool only_simt_configs = std::is_same::value; std::vector candidate_configs = @@ -776,7 +805,8 @@ void MoeGemmRunner::run_gemm( check_cuda_error(cudaEventElapsedTime(&elapsed, start, stop)); check_cuda_error(cudaEventDestroy(start)); check_cuda_error(cudaEventDestroy(stop)); - //std::cout << "[TUNING] config: " << ii << ", time: " << elapsed << " ms" << std::endl; + // std::cout << "[TUNING] config: " << ii << ", time: " << elapsed << " + // ms" << std::endl; if (elapsed < best_time) { best_id = ii; best_time = elapsed; @@ -789,7 +819,8 @@ void MoeGemmRunner::run_gemm( } } if (find_one) { - //std::cout << "[TUNING] best_config: " << best_id << ", time: " << best_time << " ms" << std::endl; + // std::cout << "[TUNING] best_config: " << best_id << ", time: " << + // best_time << " ms" << std::endl; gemmConfigManager.addBestConfig(gemmId, profile_total_rows, best_config); chosen_config = best_config; } else {