Files
FastDeploy/custom_ops/gpu_ops/cutlass_kernels/cutlass_heuristic.cu
celsowm 771e71a24d Feat/blackwell sm100 support (#2670)
* Add initial support for NVIDIA Blackwell (SM100) architecture

This change introduces initial support for the NVIDIA Blackwell GPU
architecture, specifically targeting SM100 (compute capability 10.0)
and its '100a' architecture-specific features (e.g., those used by CUTLASS).

Key changes:
- Updated custom_ops/setup_ops.py to generate appropriate gencode
  flags (arch=compute_100a,code=sm_100a) when '100' is specified
  in FD_BUILDING_ARCS. Requires CUDA 12.9+.
- Updated custom_ops/gpu_ops/cutlass_extensions/gemm_configs.h:
    - Added CutlassTileConfigSM100 enum (with placeholder tile shapes).
    - Added BLACKWELL to CandidateConfigTypeParam.
    - Updated CutlassGemmConfig struct with is_sm100 flag,
      tile_config_sm100, and new constructor for SM100.
    - Modified toString() and fromString() for SM100 support.
- Updated custom_ops/gpu_ops/cutlass_kernels/cutlass_heuristic.cu:
    - Added get_candidate_tiles_sm100() (with placeholder tiles).
    - Added placeholder mcast support functions for SM100.
    - Updated get_candidate_configs() to include SM100 paths using
      the BLACKWELL flag and new SM100 config types.
- Updated build.sh with comments to guide users on specifying '100'
  for Blackwell in FD_BUILDING_ARCS.

Further work:
- Optimal CUTLASS tile configurations for SM100 need to be researched
  and updated in cutlass_heuristic.cu.
- Kernel auto-generation scripts in custom_ops/utils/ may need
  SM100-specific versions if Blackwell's hardware features for FP8/TMA
  differ significantly from SM90.
- Compatibility of third-party libraries (CUTLASS v3.8.0, DeepGEMM)
  with Blackwell should be fully verified.

* Feat: Implement detailed Blackwell (SM100) CUTLASS heuristics

This change integrates specific, expert-provided CUTLASS heuristic
configurations for the NVIDIA Blackwell (SM100) GPU architecture,
replacing previous placeholders. This includes:

- Updated `custom_ops/gpu_ops/cutlass_extensions/gemm_configs.h`:
    - Populated `CutlassTileConfigSM100` enum with specific tile shapes
      (e.g., CtaShape64x64x128B, CtaShape128x128x128B) suitable for SM100.
    - Added `FP4_ONLY` to `CandidateConfigTypeParam` for new FP4 paths.

- Updated `custom_ops/gpu_ops/cutlass_kernels/cutlass_heuristic.cu`:
    - Implemented `get_candidate_tiles_sm100` with detailed logic for
      selecting tile configurations based on GROUPED_GEMM and FP4_ONLY flags,
      using the new SM100 tile enums.
    - Implemented `supports_mcast_along_m_sm100` and
      `supports_mcast_along_n_sm100` with specific tile checks for Blackwell.
    - Updated the `sm == 100` (Blackwell) block in `get_candidate_configs`
      to use these new helper functions and accurately populate candidate
      kernel configurations for various cluster shapes.

- `custom_ops/setup_ops.py` remains configured to compile for
  `arch=compute_100a,code=sm_100a` with CUDA 12.9+ for these features.

This aligns the codebase with heuristic configurations similar to those
in upstream TensorRT-LLM / CUTLASS for Blackwell, enabling more
performant kernel selection on this new architecture.

---------

Co-authored-by: google-labs-jules[bot] <161369871+google-labs-jules[bot]@users.noreply.github.com>
Co-authored-by: Jiang-Jia-Jun <163579578+Jiang-Jia-Jun@users.noreply.github.com>
2025-07-09 15:29:42 +08:00

526 lines
21 KiB
Plaintext

/*
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <cuda_bf16.h>
#ifndef _WIN32
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
#endif // #ifndef _WIN32
#include "cutlass/gemm/gemm.h"
#include "cutlass/numeric_types.h"
#ifndef _WIN32
#pragma GCC diagnostic pop
#endif // #ifndef _WIN32
#include <cuda_runtime_api.h>
#include <cstddef>
#include <cstdint>
#include <limits>
#include <set>
#include <vector>
#include "cutlass_kernels/cutlass_heuristic.h"
using namespace cutlass_extensions;
namespace kernels
{
namespace cutlass_kernels
{
// Output-tile footprint of one CTA (threadblock): `m` rows along M by `n` columns along N.
// Field order matters: callers build it with aggregate initialization, e.g. TileShape{16, 128}.
struct TileShape
{
    int m, n;
};
// Maps a pre-Hopper CutlassTileConfig to the M x N extent of its CTA output tile.
//
// Only the explicit tile enumerators listed below are recognised. Any other value, including
// CutlassTileConfig::ChooseWithHeuristic, is rejected by throwing a C string (const char*),
// which matches the other throws in this file.
TileShape get_cta_shape_for_config(CutlassTileConfig tile_config)
{
    switch (tile_config)
    {
    case CutlassTileConfig::CtaShape16x128x64_WarpShape16x32x64: return TileShape{16, 128};
    case CutlassTileConfig::CtaShape16x256x64_WarpShape16x64x64: return TileShape{16, 256};
    case CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64: return TileShape{32, 128};
    case CutlassTileConfig::CtaShape64x64x128_WarpShape32x64x64: return TileShape{64, 64};
    case CutlassTileConfig::CtaShape64x128x64_WarpShape32x64x64:
    case CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64: return TileShape{64, 128};
    case CutlassTileConfig::CtaShape128x64x64_WarpShape64x32x64: return TileShape{128, 64};
    case CutlassTileConfig::CtaShape128x128x8_WarpShape64x64x8:
    case CutlassTileConfig::CtaShape128x128x64_WarpShape64x32x64:
    case CutlassTileConfig::CtaShape128x128x64_WarpShape64x64x64:
    case CutlassTileConfig::CtaShape128x128x64_WarpShape128x32x64: return TileShape{128, 128};
    case CutlassTileConfig::CtaShape128x256x64_WarpShape64x64x64: return TileShape{128, 256};
    case CutlassTileConfig::CtaShape256x128x64_WarpShape64x64x64: return TileShape{256, 128};
    case CutlassTileConfig::CtaShape16x256x128_WarpShape16x64x128: return TileShape{16, 256};
    // The message previously named a non-existent "get_grid_shape_for_config".
    default: throw("[get_cta_shape_for_config] Invalid config");
    }
}
// Returns true if `split_k_factor` can be used for an m x n x k GEMM tiled with `tile_shape`.
//
// Checks performed:
//  - weight-only GEMMs: k must be a multiple of k_tile, k must divide evenly by split_k_factor,
//    and each per-split slice of k must also be a multiple of k_tile;
//  - split_k_factor > 1: `workspace_bytes` must hold one int per output CTA tile
//    (ceil(m / tile_shape.m) * ceil(n / tile_shape.n) of them).
// A non-positive split_k_factor is always rejected.
bool is_valid_split_k_factor(int64_t const m, int64_t const n, int64_t const k, TileShape const tile_shape,
    int const split_k_factor, size_t const workspace_bytes, bool const is_weight_only)
{
    // Guard before any `% split_k_factor` below; 0 would be a division by zero.
    if (split_k_factor <= 0)
    {
        return false;
    }

    // CTA K extent enforced by the weight-only divisibility checks. It is 128 (an older comment
    // claimed 64). NOTE(review): presumably chosen to cover the largest CTA K among the tile configs
    // (e.g. CtaShape64x64x128 / CtaShape16x256x128) -- confirm.
    static constexpr int64_t k_tile = 128;
    if (is_weight_only)
    {
        if ((k % k_tile) != 0 || (k % split_k_factor) != 0)
        {
            return false;
        }
        int64_t const k_elements_per_split = k / split_k_factor;
        if ((k_elements_per_split % k_tile) != 0)
        {
            return false;
        }
    }

    // Without split-k no workspace is required.
    if (split_k_factor == 1)
    {
        return true;
    }

    // Keep the tile count and byte count in 64-bit / size_t arithmetic. Narrowing them into a
    // 32-bit int can overflow for large m/n, and the old signed-vs-size_t comparison misbehaved then.
    size_t const ctas_in_m_dim = static_cast<size_t>((m + tile_shape.m - 1) / tile_shape.m);
    size_t const ctas_in_n_dim = static_cast<size_t>((n + tile_shape.n - 1) / tile_shape.n);
    size_t const required_ws_bytes = sizeof(int) * ctas_in_m_dim * ctas_in_n_dim;
    return required_ws_bytes <= workspace_bytes;
}
std::vector<CutlassTileConfig> get_candidate_tiles(
int const sm, CutlassGemmConfig::CandidateConfigTypeParam const config_type_param)
{
enum class CutlassGemmType : char
{
Default,
WeightOnly,
Simt,
Int8,
Fp8
};
CutlassGemmType gemm_type = CutlassGemmType::Default;
if (config_type_param & CutlassGemmConfig::SIMT_ONLY)
{
gemm_type = CutlassGemmType::Simt;
}
else if (config_type_param & CutlassGemmConfig::WEIGHT_ONLY)
{
gemm_type = CutlassGemmType::WeightOnly;
}
else if (config_type_param & CutlassGemmConfig::INT8_ONLY)
{
gemm_type = CutlassGemmType::Int8;
}
else if (config_type_param & CutlassGemmConfig::FP8_ONLY)
{
gemm_type = CutlassGemmType::Fp8;
}
std::vector<CutlassTileConfig> base_configs{
CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64, CutlassTileConfig::CtaShape64x128x64_WarpShape32x64x64};
if (sm >= 75)
{
base_configs.push_back(CutlassTileConfig::CtaShape128x128x64_WarpShape64x32x64);
}
switch (gemm_type)
{
case CutlassGemmType::Simt: return {CutlassTileConfig::CtaShape128x128x8_WarpShape64x64x8};
case CutlassGemmType::WeightOnly:
if (sm >= 75)
{
return {CutlassTileConfig::CtaShape16x128x64_WarpShape16x32x64,
CutlassTileConfig::CtaShape16x256x64_WarpShape16x64x64,
CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64,
CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64,
CutlassTileConfig::CtaShape128x128x64_WarpShape128x32x64};
}
else
{
return {CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64,
CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64};
}
case CutlassGemmType::Int8:
return {CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64,
CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64,
CutlassTileConfig::CtaShape128x64x64_WarpShape64x32x64,
CutlassTileConfig::CtaShape64x64x128_WarpShape32x64x64,
CutlassTileConfig::CtaShape128x256x64_WarpShape64x64x64,
CutlassTileConfig::CtaShape256x128x64_WarpShape64x64x64};
case CutlassGemmType::Fp8:
if (config_type_param & CutlassGemmConfig::GROUPED_GEMM)
{
if (sm == 89)
{
return {CutlassTileConfig::CtaShape16x256x128_WarpShape16x64x128,
CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64,
CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64,
CutlassTileConfig::CtaShape64x64x128_WarpShape32x64x64,
CutlassTileConfig::CtaShape128x64x64_WarpShape64x32x64,
CutlassTileConfig::CtaShape128x256x64_WarpShape64x64x64,
CutlassTileConfig::CtaShape256x128x64_WarpShape64x64x64};
}
else
{
// no valid ampere style fp8 configs for sm90
return {};
}
}
default: return base_configs;
}
}
// Candidate Hopper (SM90) tile configurations.
//
// Grouped GEMM gets the M=128 tiles plus 256x128. Other requests get both the M=64 and the
// M=128 families. Under FAST_BUILD only the single 128x128x128B tile is offered.
// `sm` is currently not consulted.
std::vector<CutlassTileConfigSM90> get_candidate_tiles_sm90(
    int const sm, CutlassGemmConfig::CandidateConfigTypeParam const config)
{
    (void) sm;
#ifdef FAST_BUILD
    (void) config;
    // Fast build disables all configs except this one for SM90
    return {CutlassTileConfigSM90::CtaShape128x128x128B};
#else
    bool const grouped = static_cast<bool>(config & CutlassGemmConfig::GROUPED_GEMM);
    if (!grouped)
    {
        return {CutlassTileConfigSM90::CtaShape64x16x128B, CutlassTileConfigSM90::CtaShape64x32x128B,
            CutlassTileConfigSM90::CtaShape64x64x128B, CutlassTileConfigSM90::CtaShape64x128x128B,
            CutlassTileConfigSM90::CtaShape64x256x128B, CutlassTileConfigSM90::CtaShape128x16x128B,
            CutlassTileConfigSM90::CtaShape128x32x128B, CutlassTileConfigSM90::CtaShape128x64x128B,
            CutlassTileConfigSM90::CtaShape128x128x128B, CutlassTileConfigSM90::CtaShape128x256x128B};
    }
    return {CutlassTileConfigSM90::CtaShape128x16x128B, CutlassTileConfigSM90::CtaShape128x32x128B,
        CutlassTileConfigSM90::CtaShape128x64x128B, CutlassTileConfigSM90::CtaShape128x128x128B,
        CutlassTileConfigSM90::CtaShape128x256x128B, CutlassTileConfigSM90::CtaShape256x128x128B};
#endif
}
// Returns true if an SM90 tile has CUTLASS kernels compiled with multicast along M.
// Such kernels exist only for tiles whose M extent is >= 128, purely to keep compilation time down.
// FAST_BUILD compiles no multicast variants, so it always answers false.
bool supports_mcast_along_m(CutlassTileConfigSM90 const tile)
{
#ifdef FAST_BUILD
    (void) tile;
    return false;
#else
    switch (tile)
    {
    case CutlassTileConfigSM90::CtaShape128x16x128B:
    case CutlassTileConfigSM90::CtaShape128x32x128B:
    case CutlassTileConfigSM90::CtaShape128x64x128B:
    case CutlassTileConfigSM90::CtaShape128x128x128B:
    case CutlassTileConfigSM90::CtaShape128x256x128B:
    case CutlassTileConfigSM90::CtaShape256x128x128B: return true;
    default: return false;
    }
#endif
}
// Returns true if an SM90 tile has CUTLASS kernels compiled with multicast along N.
// Such kernels exist only for tiles whose N extent is >= 128, purely to keep compilation time down.
// FAST_BUILD compiles no multicast variants, so it always answers false.
bool supports_mcast_along_n(CutlassTileConfigSM90 const tile)
{
#ifdef FAST_BUILD
    (void) tile;
    return false;
#else
    switch (tile)
    {
    case CutlassTileConfigSM90::CtaShape64x128x128B:
    case CutlassTileConfigSM90::CtaShape64x256x128B:
    case CutlassTileConfigSM90::CtaShape128x128x128B:
    case CutlassTileConfigSM90::CtaShape128x256x128B:
    case CutlassTileConfigSM90::CtaShape256x128x128B: return true;
    default: return false;
    }
#endif
}
// SM100 (Blackwell) candidate tile configurations
std::vector<CutlassTileConfigSM100> get_candidate_tiles_sm100(
int /*sm*/, CutlassGemmConfig::CandidateConfigTypeParam const config)
{
#ifdef FAST_BUILD
return {CutlassTileConfigSM100::CtaShape128x128x128B};
#else
/* Grouped-GEMM path first (Blackwell uses 1-SM and 2-SM “cluster” kernels) */
if (config & CutlassGemmConfig::GROUPED_GEMM)
{
if (config & CutlassGemmConfig::FP4_ONLY) // nvfp4 / mx_fp4
{
return {
/* 1 SM (M=128) */
CutlassTileConfigSM100::CtaShape128x128x128B,
CutlassTileConfigSM100::CtaShape128x256x128B,
/* 2 SM (M=256) */
CutlassTileConfigSM100::CtaShape256x128x128B,
CutlassTileConfigSM100::CtaShape256x256x128B,
/* slim tiles for very tall matrices */
CutlassTileConfigSM100::CtaShape128x64x128B,
CutlassTileConfigSM100::CtaShape256x64x128B};
}
/* Fp8 / Fp16 grouped-GEMM */
return {
CutlassTileConfigSM100::CtaShape128x128x128B,
CutlassTileConfigSM100::CtaShape128x256x128B,
CutlassTileConfigSM100::CtaShape256x128x128B,
CutlassTileConfigSM100::CtaShape256x256x128B};
}
/* Non-grouped path (plain GEMM or weight-only) */
return {
/* 1 SM tiles */
CutlassTileConfigSM100::CtaShape64x64x128B,
CutlassTileConfigSM100::CtaShape64x128x128B,
CutlassTileConfigSM100::CtaShape64x256x128B,
CutlassTileConfigSM100::CtaShape128x64x128B,
CutlassTileConfigSM100::CtaShape128x128x128B,
CutlassTileConfigSM100::CtaShape128x256x128B,
/* 2 SM tiles */
CutlassTileConfigSM100::CtaShape256x64x128B,
CutlassTileConfigSM100::CtaShape256x128x128B,
CutlassTileConfigSM100::CtaShape256x256x128B};
#endif
}
// Returns true if an SM100 tile has kernels compiled with multicast along M, i.e. tiles whose
// M extent is >= 128. FAST_BUILD compiles no multicast variants, so it always answers false.
bool supports_mcast_along_m_sm100(CutlassTileConfigSM100 tile)
{
#ifdef FAST_BUILD
    (void) tile;
    return false;
#else
    switch (tile)
    {
    case CutlassTileConfigSM100::CtaShape128x64x128B:
    case CutlassTileConfigSM100::CtaShape128x128x128B:
    case CutlassTileConfigSM100::CtaShape128x256x128B:
    case CutlassTileConfigSM100::CtaShape256x64x128B:
    case CutlassTileConfigSM100::CtaShape256x128x128B:
    case CutlassTileConfigSM100::CtaShape256x256x128B: return true;
    default: return false;
    }
#endif
}
// Returns true if an SM100 tile has kernels compiled with multicast along N. Note that
// 256x256x128B is deliberately absent from this list even though its N extent is 256.
// FAST_BUILD compiles no multicast variants, so it always answers false.
bool supports_mcast_along_n_sm100(CutlassTileConfigSM100 tile)
{
#ifdef FAST_BUILD
    (void) tile;
    return false;
#else
    switch (tile)
    {
    case CutlassTileConfigSM100::CtaShape64x128x128B:
    case CutlassTileConfigSM100::CtaShape64x256x128B:
    case CutlassTileConfigSM100::CtaShape128x128x128B:
    case CutlassTileConfigSM100::CtaShape128x256x128B:
    case CutlassTileConfigSM100::CtaShape256x128x128B: return true;
    default: return false;
    }
#endif
}
std::vector<CutlassGemmConfig> get_candidate_configs(
int sm, int const max_split_k, CutlassGemmConfig::CandidateConfigTypeParam const config_type_param)
{
if (sm == 90 && (config_type_param & CutlassGemmConfig::HOPPER))
{
std::vector<CutlassTileConfigSM90> tiles = get_candidate_tiles_sm90(sm, config_type_param);
std::vector<CutlassGemmConfig> candidate_configs;
for (auto const& tile_config : tiles)
{
CutlassGemmConfig config(
tile_config, MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO, ClusterShape::ClusterShape_1x1x1);
candidate_configs.push_back(config);
bool const has_m_mcast = supports_mcast_along_m(tile_config);
bool const has_n_mcast = supports_mcast_along_n(tile_config);
if (has_m_mcast)
{
CutlassGemmConfig config(tile_config, MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO,
ClusterShape::ClusterShape_2x1x1);
candidate_configs.push_back(config);
}
if (has_n_mcast)
{
CutlassGemmConfig config(tile_config, MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO,
ClusterShape::ClusterShape_1x2x1);
candidate_configs.push_back(config);
}
if (has_m_mcast && has_n_mcast)
{
CutlassGemmConfig config(tile_config, MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO,
ClusterShape::ClusterShape_2x2x1);
candidate_configs.push_back(config);
}
}
return candidate_configs;
}
else if (sm == 100 && (config_type_param & CutlassGemmConfig::BLACKWELL)) // Assuming SM100 for Blackwell
{
std::vector<CutlassTileConfigSM100> tiles = get_candidate_tiles_sm100(sm, config_type_param);
std::vector<CutlassGemmConfig> candidate_configs;
for (auto const& tile_config_sm100 : tiles)
{
// SM100 uses MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO similar to SM90.
// Cluster shapes are also handled similarly.
CutlassGemmConfig config(
tile_config_sm100, MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO, ClusterShape::ClusterShape_1x1x1);
candidate_configs.push_back(config);
bool const has_m_mcast = supports_mcast_along_m_sm100(tile_config_sm100);
bool const has_n_mcast = supports_mcast_along_n_sm100(tile_config_sm100);
if (has_m_mcast)
{
CutlassGemmConfig mcast_m_config(tile_config_sm100, MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO,
ClusterShape::ClusterShape_2x1x1);
candidate_configs.push_back(mcast_m_config);
}
if (has_n_mcast)
{
CutlassGemmConfig mcast_n_config(tile_config_sm100, MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO,
ClusterShape::ClusterShape_1x2x1);
candidate_configs.push_back(mcast_n_config);
}
if (has_m_mcast && has_n_mcast)
{
CutlassGemmConfig mcast_mn_config(tile_config_sm100, MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO,
ClusterShape::ClusterShape_2x2x1);
candidate_configs.push_back(mcast_mn_config);
}
}
return candidate_configs;
}
// Fallback to older architecture configurations
std::vector<CutlassTileConfig> tiles = get_candidate_tiles(sm, config_type_param);
std::vector<CutlassGemmConfig> candidate_configs; //Already declared above for SM90 path, ensure scope is correct or redeclare if necessary.
// It's fine here as it's within an else if / else block.
bool const int8_configs_only = config_type_param & CutlassGemmConfig::INT8_ONLY;
int const min_stages = int8_configs_only ? 3 : 2;
int const max_stages = int8_configs_only ? 6 : (sm >= 80 ? 4 : 2);
for (auto const& tile_config : tiles)
{
for (int stages = min_stages; stages <= max_stages; ++stages)
{
CutlassGemmConfig config(tile_config, SplitKStyle::NO_SPLIT_K, 1, stages);
candidate_configs.push_back(config);
if (sm >= 75)
{
for (int split_k_factor = 2; split_k_factor <= max_split_k; ++split_k_factor)
{
auto config = CutlassGemmConfig{tile_config, SplitKStyle::SPLIT_K_SERIAL, split_k_factor, stages};
candidate_configs.push_back(config);
}
}
}
}
return candidate_configs;
}
// Picks, among pre-Hopper `candidate_configs`, the tile / split-k / stage combination that wastes the
// least SM capacity in the last wave.
//
// `occupancies[i]` is the occupancy reported for `candidate_configs[i]`; entries with occupancy 0 are
// skipped. For each tile and each split-k factor in 1..max_split_k that passes is_valid_split_k_factor,
// the score is (total waves - fractional waves), i.e. the unused fraction of the last wave; lower is
// better. split-k is disabled (max_split_k = 1) when n >= multi_processor_count * 256.
//
// A config that reaches fewer waves may win even if its score is up to 0.1 worse. Ties in score are
// broken towards more stages, a smaller split-k factor, or a larger M tile. Once a best config exists
// and m is smaller than its M tile, candidates with an even larger M tile are skipped.
// `num_experts` is currently unused.
//
// Throws a C string (const char*) if the vectors differ in length or if no candidate is valid.
CutlassGemmConfig estimate_best_config_from_occupancies(std::vector<CutlassGemmConfig> const& candidate_configs,
    std::vector<int> const& occupancies, int64_t const m, int64_t const n, int64_t const k, int64_t const num_experts,
    int const split_k_limit, size_t const workspace_bytes, int const multi_processor_count, int const is_weight_only)
{
    if (occupancies.size() != candidate_configs.size())
    {
        throw(
            "[estimate_best_config_from_occupancies] occupancies and "
            "candidate configs vectors must have equal length.");
    }
    CutlassGemmConfig best_config;
    // Score will be [0, 1]. The objective is to minimize this score.
    // It represents the fraction of SM resources unused in the last wave.
    float config_score = 1.0f;
    // Kept in 64 bits: CTA and wave counts for large m/n can exceed the range of int.
    int64_t config_waves = std::numeric_limits<int64_t>::max();
    int current_m_tile = 0;
    int const max_split_k = n >= multi_processor_count * 256 ? 1 : split_k_limit;
    for (size_t ii = 0; ii < candidate_configs.size(); ++ii)
    {
        CutlassGemmConfig const& candidate_config = candidate_configs[ii];
        TileShape const tile_shape = get_cta_shape_for_config(candidate_config.tile_config);
        int const occupancy = occupancies[ii];
        if (occupancy == 0)
        {
            continue;
        }
        // Keep small tile sizes when possible.
        if (best_config.tile_config != CutlassTileConfig::ChooseWithHeuristic && m < current_m_tile
            && current_m_tile < tile_shape.m)
        {
            continue;
        }
        int64_t const ctas_in_m_dim = (m + tile_shape.m - 1) / tile_shape.m;
        int64_t const ctas_in_n_dim = (n + tile_shape.n - 1) / tile_shape.n;
        for (int split_k_factor = 1; split_k_factor <= max_split_k; ++split_k_factor)
        {
            if (!is_valid_split_k_factor(m, n, k, tile_shape, split_k_factor, workspace_bytes, is_weight_only))
            {
                continue;
            }
            int64_t const ctas_per_wave = static_cast<int64_t>(occupancy) * multi_processor_count;
            int64_t const ctas_for_problem = ctas_in_m_dim * ctas_in_n_dim * split_k_factor;
            int64_t const num_waves_total = (ctas_for_problem + ctas_per_wave - 1) / ctas_per_wave;
            float const num_waves_fractional = ctas_for_problem / float(ctas_per_wave);
            float const current_score = float(num_waves_total) - num_waves_fractional;
            float const score_slack = 0.1f;
            if (current_score < config_score
                || ((config_waves > num_waves_total) && (current_score < config_score + score_slack)))
            {
                config_score = current_score;
                config_waves = num_waves_total;
                SplitKStyle const split_style
                    = split_k_factor > 1 ? SplitKStyle::SPLIT_K_SERIAL : SplitKStyle::NO_SPLIT_K;
                best_config = CutlassGemmConfig(
                    candidate_config.tile_config, split_style, split_k_factor, candidate_config.stages);
                current_m_tile = tile_shape.m;
            }
            else if (current_score == config_score
                && (best_config.stages < candidate_config.stages || split_k_factor < best_config.split_k_factor
                    || current_m_tile < tile_shape.m))
            {
                // Prefer deeper pipeline or smaller split-k
                SplitKStyle const split_style
                    = split_k_factor > 1 ? SplitKStyle::SPLIT_K_SERIAL : SplitKStyle::NO_SPLIT_K;
                best_config = CutlassGemmConfig(
                    candidate_config.tile_config, split_style, split_k_factor, candidate_config.stages);
                current_m_tile = tile_shape.m;
                config_waves = num_waves_total;
            }
        }
    }
    if (best_config.tile_config == CutlassTileConfig::ChooseWithHeuristic)
    {
        throw("Heuristic failed to find a valid config.");
    }
    return best_config;
}
} // namespace cutlass_kernels
} // namespace kernels