Feat/blackwell sm100 support (#2670)

* Add initial support for NVIDIA Blackwell (SM100) architecture

This change introduces initial support for the NVIDIA Blackwell GPU
architecture, specifically targeting SM100 (Compute Capability 10.x)
with '100a' architecture-specific features (e.g., for CUTLASS).

Key changes:
- Updated custom_ops/setup_ops.py to generate appropriate gencode
  flags (arch=compute_100a,code=sm_100a) when '100' is specified
  in FD_BUILDING_ARCS. Requires CUDA 12.9+.
- Updated custom_ops/gpu_ops/cutlass_extensions/gemm_configs.h:
    - Added CutlassTileConfigSM100 enum (with placeholder tile shapes).
    - Added BLACKWELL to CandidateConfigTypeParam.
    - Updated CutlassGemmConfig struct with is_sm100 flag,
      tile_config_sm100, and new constructor for SM100.
    - Modified toString() and fromString() for SM100 support.
- Updated custom_ops/gpu_ops/cutlass_kernels/cutlass_heuristic.cu:
    - Added get_candidate_tiles_sm100() (with placeholder tiles).
    - Added placeholder mcast support functions for SM100.
    - Updated get_candidate_configs() to include SM100 paths using
      the BLACKWELL flag and new SM100 config types.
- Updated build.sh with comments to guide users on specifying '100'
  for Blackwell in FD_BUILDING_ARCS.

Further work:
- Optimal CUTLASS tile configurations for SM100 need to be researched
  and updated in cutlass_heuristic.cu.
- Kernel auto-generation scripts in custom_ops/utils/ may need
  SM100-specific versions if Blackwell's hardware features for FP8/TMA
  differ significantly from SM90.
- Compatibility of third-party libraries (CUTLASS v3.8.0, DeepGEMM)
  with Blackwell should be fully verified.

* Feat: Implement detailed Blackwell (SM100) CUTLASS heuristics

This change integrates specific, expert-provided CUTLASS heuristic
configurations for the NVIDIA Blackwell (SM100) GPU architecture,
replacing previous placeholders. This includes:

- Updated `custom_ops/gpu_ops/cutlass_extensions/gemm_configs.h`:
    - Populated `CutlassTileConfigSM100` enum with specific tile shapes
      (e.g., CtaShape64x64x128B, CtaShape128x128x128B) suitable for SM100.
    - Added `FP4_ONLY` to `CandidateConfigTypeParam` for new FP4 paths.

- Updated `custom_ops/gpu_ops/cutlass_kernels/cutlass_heuristic.cu`:
    - Implemented `get_candidate_tiles_sm100` with detailed logic for
      selecting tile configurations based on GROUPED_GEMM and FP4_ONLY flags,
      using the new SM100 tile enums.
    - Implemented `supports_mcast_along_m_sm100` and
      `supports_mcast_along_n_sm100` with specific tile checks for Blackwell.
    - Updated the `sm == 100` (Blackwell) block in `get_candidate_configs`
      to use these new helper functions and accurately populate candidate
      kernel configurations for various cluster shapes.

- `custom_ops/setup_ops.py` remains configured to compile for
  `arch=compute_100a,code=sm_100a` with CUDA 12.9+ for these features.

This aligns the codebase with heuristic configurations similar to those
in upstream TensorRT-LLM / CUTLASS for Blackwell, enabling more
performant kernel selection on this new architecture.

---------

Co-authored-by: google-labs-jules[bot] <161369871+google-labs-jules[bot]@users.noreply.github.com>
Co-authored-by: Jiang-Jia-Jun <163579578+Jiang-Jia-Jun@users.noreply.github.com>
Author: celsowm
Date: 2025-07-09 04:29:42 -03:00
Committed by: GitHub
Parent commit: 0350831c2b
Commit: 771e71a24d
4 changed files with 308 additions and 52 deletions

build.sh
@@ -18,6 +18,9 @@ BUILD_WHEEL=${1:-1}
PYTHON_VERSION=${2:-"python"}
export python=$PYTHON_VERSION
FD_CPU_USE_BF16=${3:-"false"}
# FD_BUILDING_ARCS: Specify target CUDA architectures for custom ops, e.g., "[80, 90, 100]".
# For SM90 (Hopper), use 90. For SM100 (Blackwell), use 100.
# These will be translated to 90a / 100a in setup_ops.py for specific features.
FD_BUILDING_ARCS=${4:-""}

custom_ops/gpu_ops/cutlass_extensions/gemm_configs.h

@@ -76,6 +76,34 @@ enum class SplitKStyle
// SPLIT_K_PARALLEL // Not supported yet
};

// New enum for SM100 (Blackwell) Tile Configs
// Placeholder values - actual optimal values need research
enum class CutlassTileConfigSM100
{
    // Signals that we should run heuristics to choose a config
    Undefined,

    // Signals that we should run heuristics to choose a config
    ChooseWithHeuristic,

    // Actual SM100 tile configs based on user input (K-tile is 128B)
    CtaShape64x64x128B,
    CtaShape64x128x128B,
    CtaShape64x256x128B,
    CtaShape128x64x128B,
    CtaShape128x128x128B,
    CtaShape128x256x128B,
    CtaShape256x64x128B,
    CtaShape256x128x128B,
    CtaShape256x256x128B

    // Note: CtaShape128x64x128B and CtaShape256x64x128B above also cover the
    // slim-tile FP4 grouped-GEMM cases in get_candidate_tiles_sm100, so no
    // additional enum values are needed for them.
};

enum class CutlassTileConfigSM90
{
    // Signals that we should run heuristics to choose a config
@@ -132,9 +160,11 @@ struct CutlassGemmConfig
        WEIGHT_ONLY = 1u << 0,
        SIMT_ONLY = 1u << 1,
        INT8_ONLY = 1u << 2,
        HOPPER = 1u << 3, // SM90
        GROUPED_GEMM = 1u << 4,
        FP8_ONLY = 1u << 5,
        BLACKWELL = 1u << 6, // SM100
        FP4_ONLY = 1u << 7, // For Blackwell FP4/MXFP4 paths
    };

    CutlassTileConfig tile_config = CutlassTileConfig::ChooseWithHeuristic;
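For readers unfamiliar with the bitmask scheme, here is a minimal Python model of how the new `BLACKWELL` and `FP4_ONLY` bits combine with the existing `CandidateConfigTypeParam` flags. The `IntFlag` class below is illustrative only; it mirrors the C++ bit layout but is not part of the codebase:

```python
from enum import IntFlag

class CandidateConfigTypeParam(IntFlag):
    # Mirrors the C++ bit layout in gemm_configs.h
    WEIGHT_ONLY = 1 << 0
    SIMT_ONLY = 1 << 1
    INT8_ONLY = 1 << 2
    HOPPER = 1 << 3        # SM90
    GROUPED_GEMM = 1 << 4
    FP8_ONLY = 1 << 5
    BLACKWELL = 1 << 6     # SM100
    FP4_ONLY = 1 << 7      # Blackwell FP4/MXFP4 paths

# A Blackwell FP4 grouped-GEMM request combines three bits:
param = (CandidateConfigTypeParam.BLACKWELL
         | CandidateConfigTypeParam.GROUPED_GEMM
         | CandidateConfigTypeParam.FP4_ONLY)
assert bool(param & CandidateConfigTypeParam.BLACKWELL)
assert not (param & CandidateConfigTypeParam.HOPPER)
assert int(param) == 0xD0  # 0x40 | 0x10 | 0x80
```

Because each capability is a distinct bit, `get_candidate_configs` can test for `BLACKWELL` and `FP4_ONLY` independently without disturbing the SM90 paths.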
@@ -149,7 +179,17 @@ struct CutlassGemmConfig
    ClusterShape cluster_shape = ClusterShape::ClusterShape_1x1x1;
    bool is_sm90 = false;

    // config options for sm100 (Blackwell)
    // Assuming SM100 might use similar schedule/cluster types as SM90 for now.
    // These might need to become SM100-specific if Blackwell introduces new concepts.
    CutlassTileConfigSM100 tile_config_sm100 = CutlassTileConfigSM100::ChooseWithHeuristic;
    // MainloopScheduleType mainloop_schedule_sm100 = MainloopScheduleType::AUTO; // Example if SM100 has different types
    // EpilogueScheduleType epilogue_schedule_sm100 = EpilogueScheduleType::AUTO; // Example
    // ClusterShape cluster_shape_sm100 = ClusterShape::ClusterShape_1x1x1; // Example
    bool is_sm100 = false;

    CutlassGemmConfig() : is_sm90(false), is_sm100(false) {}

    CutlassGemmConfig(CutlassTileConfig tile_config, SplitKStyle split_k_style, int split_k_factor, int stages)
        : tile_config(tile_config)
@@ -157,37 +197,64 @@ struct CutlassGemmConfig
        , split_k_factor(split_k_factor)
        , stages(stages)
        , is_sm90(false)
        , is_sm100(false)
    {
    }

    // Constructor for SM90
    CutlassGemmConfig(CutlassTileConfigSM90 tile_config_sm90_in, MainloopScheduleType mainloop_schedule_in,
        EpilogueScheduleType epilogue_schedule_in, ClusterShape cluster_shape_in)
        : tile_config_sm90(tile_config_sm90_in)
        , mainloop_schedule(mainloop_schedule_in)
        , epilogue_schedule(epilogue_schedule_in)
        , cluster_shape(cluster_shape_in)
        , is_sm90(true)
        , is_sm100(false)
    {
    }

    // Constructor for SM100 (Blackwell)
    // Using existing MainloopScheduleType, EpilogueScheduleType, ClusterShape for now.
    // These might need to be new SM100-specific types if Blackwell's TMA differs significantly.
    CutlassGemmConfig(CutlassTileConfigSM100 tile_config_sm100_in, MainloopScheduleType mainloop_schedule_in,
        EpilogueScheduleType epilogue_schedule_in, ClusterShape cluster_shape_in)
        : tile_config_sm100(tile_config_sm100_in)
        , mainloop_schedule(mainloop_schedule_in) // Potentially use mainloop_schedule_sm100 if types diverge
        , epilogue_schedule(epilogue_schedule_in) // Potentially use epilogue_schedule_sm100
        , cluster_shape(cluster_shape_in) // Potentially use cluster_shape_sm100
        , is_sm90(false) // Explicitly false
        , is_sm100(true)
    {
    }
    std::string toString() const
    {
        std::stringstream tactic;
        tactic << "Cutlass GEMM Tactic";
        if (is_sm100 && tile_config_sm100 != cutlass_extensions::CutlassTileConfigSM100::ChooseWithHeuristic)
        {
            assert(is_sm100 && !is_sm90 && "Invalid cutlass GEMM config: SM100");
            tactic << "\n\tstyle=TMA_SM100" // SM100-specific tag so fromString can dispatch
                   << "\n\ttile shape ID: " << (int) tile_config_sm100
                   << "\n\tcluster shape ID: " << (int) cluster_shape
                   << "\n\tmainloop sched: " << (int) mainloop_schedule
                   << "\n\tepi sched: " << (int) epilogue_schedule;
        }
        else if (is_sm90 && tile_config_sm90 != cutlass_extensions::CutlassTileConfigSM90::ChooseWithHeuristic)
        {
            assert(is_sm90 && !is_sm100 && "Invalid cutlass GEMM config: SM90");
            tactic << "\n\tstyle=TMA_SM90"
                   << "\n\ttile shape ID: " << (int) tile_config_sm90
                   << "\n\tcluster shape ID: " << (int) cluster_shape
                   << "\n\tmainloop sched: " << (int) mainloop_schedule
                   << "\n\tepi sched: " << (int) epilogue_schedule;
        }
        else if (tile_config != cutlass_extensions::CutlassTileConfig::ChooseWithHeuristic)
        {
            assert(!is_sm90 && !is_sm100 && "Invalid cutlass GEMM config: Compatible");
            tactic << "\n\tstyle=compatible"
                   << "\n\ttile shape ID: " << (int) tile_config
                   << "\n\tstages: " << (int) stages
                   << "\n\tsplit_k_style: " << (int) split_k_style
                   << "\n\tsplit k: " << (int) split_k_factor;
@@ -204,9 +271,24 @@ struct CutlassGemmConfig
        std::istringstream stream(str);
        std::string line;
        is_sm90 = false; // Reset flags before parsing
        is_sm100 = false;
        while (std::getline(stream, line)) {
            if (line.find("style=TMA_SM100") != std::string::npos) {
                is_sm100 = true;
                is_sm90 = false;
                std::getline(stream, line);
                tile_config_sm100 = static_cast<cutlass_extensions::CutlassTileConfigSM100>(std::stoi(line.substr(line.find(':') + 1)));
                std::getline(stream, line);
                cluster_shape = static_cast<cutlass_extensions::ClusterShape>(std::stoi(line.substr(line.find(':') + 1)));
                std::getline(stream, line);
                mainloop_schedule = static_cast<cutlass_extensions::MainloopScheduleType>(std::stoi(line.substr(line.find(':') + 1)));
                std::getline(stream, line);
                epilogue_schedule = static_cast<cutlass_extensions::EpilogueScheduleType>(std::stoi(line.substr(line.find(':') + 1)));
            } else if (line.find("style=TMA_SM90") != std::string::npos) { // distinct tag, no substring collision with TMA_SM100
                is_sm90 = true;
                is_sm100 = false;
                std::getline(stream, line);
                tile_config_sm90 = static_cast<cutlass_extensions::CutlassTileConfigSM90>(std::stoi(line.substr(line.find(':') + 1)));
                std::getline(stream, line);
@@ -217,6 +299,7 @@ struct CutlassGemmConfig
                epilogue_schedule = static_cast<cutlass_extensions::EpilogueScheduleType>(std::stoi(line.substr(line.find(':') + 1)));
            } else if (line.find("style=compatible") != std::string::npos) {
                is_sm90 = false;
                is_sm100 = false;
                std::getline(stream, line);
                tile_config = static_cast<cutlass_extensions::CutlassTileConfig>(std::stoi(line.substr(line.find(':') + 1)));
                std::getline(stream, line);
@@ -233,7 +316,14 @@ struct CutlassGemmConfig
inline std::ostream& operator<<(std::ostream& out, CutlassGemmConfig const& config)
{
    // clang-format off
    if (config.is_sm100)
    {
        out << "tile_config_sm100_enum: " << int(config.tile_config_sm100)
            << ", mainloop_schedule_enum: " << int(config.mainloop_schedule) // Assuming same schedule types for now
            << ", epilogue_schedule_enum: " << int(config.epilogue_schedule) // Assuming same schedule types for now
            << ", cluster_shape_enum: " << int(config.cluster_shape); // Assuming same cluster types for now
    }
    else if (config.is_sm90)
    {
        out << "tile_config_sm90_enum: " << int(config.tile_config_sm90)
            << ", mainloop_schedule_enum: " << int(config.mainloop_schedule)

custom_ops/gpu_ops/cutlass_kernels/cutlass_heuristic.cu

@@ -245,6 +245,88 @@ bool supports_mcast_along_n(CutlassTileConfigSM90 const tile)
#endif
}

// SM100 (Blackwell) candidate tile configurations
std::vector<CutlassTileConfigSM100> get_candidate_tiles_sm100(
    int /*sm*/, CutlassGemmConfig::CandidateConfigTypeParam const config)
{
#ifdef FAST_BUILD
    return {CutlassTileConfigSM100::CtaShape128x128x128B};
#else
    /* Grouped-GEMM path first (Blackwell uses 1-SM and 2-SM "cluster" kernels) */
    if (config & CutlassGemmConfig::GROUPED_GEMM)
    {
        if (config & CutlassGemmConfig::FP4_ONLY) // nvfp4 / mx_fp4
        {
            return {
                /* 1 SM (M=128) */
                CutlassTileConfigSM100::CtaShape128x128x128B,
                CutlassTileConfigSM100::CtaShape128x256x128B,
                /* 2 SM (M=256) */
                CutlassTileConfigSM100::CtaShape256x128x128B,
                CutlassTileConfigSM100::CtaShape256x256x128B,
                /* slim tiles for very tall matrices */
                CutlassTileConfigSM100::CtaShape128x64x128B,
                CutlassTileConfigSM100::CtaShape256x64x128B};
        }

        /* Fp8 / Fp16 grouped-GEMM */
        return {
            CutlassTileConfigSM100::CtaShape128x128x128B,
            CutlassTileConfigSM100::CtaShape128x256x128B,
            CutlassTileConfigSM100::CtaShape256x128x128B,
            CutlassTileConfigSM100::CtaShape256x256x128B};
    }

    /* Non-grouped path (plain GEMM or weight-only) */
    return {
        /* 1 SM tiles */
        CutlassTileConfigSM100::CtaShape64x64x128B,
        CutlassTileConfigSM100::CtaShape64x128x128B,
        CutlassTileConfigSM100::CtaShape64x256x128B,
        CutlassTileConfigSM100::CtaShape128x64x128B,
        CutlassTileConfigSM100::CtaShape128x128x128B,
        CutlassTileConfigSM100::CtaShape128x256x128B,
        /* 2 SM tiles */
        CutlassTileConfigSM100::CtaShape256x64x128B,
        CutlassTileConfigSM100::CtaShape256x128x128B,
        CutlassTileConfigSM100::CtaShape256x256x128B};
#endif
}

// M-multicast support for SM100.
bool supports_mcast_along_m_sm100(CutlassTileConfigSM100 tile)
{
#ifdef FAST_BUILD
    return false;
#else
    std::set<CutlassTileConfigSM100> m_tiles{
        CutlassTileConfigSM100::CtaShape128x64x128B,
        CutlassTileConfigSM100::CtaShape128x128x128B,
        CutlassTileConfigSM100::CtaShape128x256x128B,
        CutlassTileConfigSM100::CtaShape256x64x128B,
        CutlassTileConfigSM100::CtaShape256x128x128B,
        CutlassTileConfigSM100::CtaShape256x256x128B};
    return m_tiles.count(tile) == 1;
#endif
}

// N-multicast support for SM100.
bool supports_mcast_along_n_sm100(CutlassTileConfigSM100 tile)
{
#ifdef FAST_BUILD
    return false;
#else
    std::set<CutlassTileConfigSM100> n_tiles{
        CutlassTileConfigSM100::CtaShape64x128x128B,
        CutlassTileConfigSM100::CtaShape64x256x128B,
        CutlassTileConfigSM100::CtaShape128x128x128B,
        CutlassTileConfigSM100::CtaShape128x256x128B,
        CutlassTileConfigSM100::CtaShape256x128x128B};
    return n_tiles.count(tile) == 1;
#endif
}
std::vector<CutlassGemmConfig> get_candidate_configs(
    int sm, int const max_split_k, CutlassGemmConfig::CandidateConfigTypeParam const config_type_param)
{
@@ -284,9 +366,50 @@ std::vector<CutlassGemmConfig> get_candidate_configs(
        }
        return candidate_configs;
    }
    else if (sm == 100 && (config_type_param & CutlassGemmConfig::BLACKWELL)) // SM100 (Blackwell)
    {
        std::vector<CutlassTileConfigSM100> tiles = get_candidate_tiles_sm100(sm, config_type_param);
        std::vector<CutlassGemmConfig> candidate_configs;
        for (auto const& tile_config_sm100 : tiles)
        {
            // SM100 uses MainloopScheduleType::AUTO and EpilogueScheduleType::AUTO, as on SM90.
            // Cluster shapes are also handled similarly.
            CutlassGemmConfig config(
                tile_config_sm100, MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO, ClusterShape::ClusterShape_1x1x1);
            candidate_configs.push_back(config);

            bool const has_m_mcast = supports_mcast_along_m_sm100(tile_config_sm100);
            bool const has_n_mcast = supports_mcast_along_n_sm100(tile_config_sm100);
            if (has_m_mcast)
            {
                CutlassGemmConfig mcast_m_config(tile_config_sm100, MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO,
                    ClusterShape::ClusterShape_2x1x1);
                candidate_configs.push_back(mcast_m_config);
            }
            if (has_n_mcast)
            {
                CutlassGemmConfig mcast_n_config(tile_config_sm100, MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO,
                    ClusterShape::ClusterShape_1x2x1);
                candidate_configs.push_back(mcast_n_config);
            }
            if (has_m_mcast && has_n_mcast)
            {
                CutlassGemmConfig mcast_mn_config(tile_config_sm100, MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO,
                    ClusterShape::ClusterShape_2x2x1);
                candidate_configs.push_back(mcast_mn_config);
            }
        }
        return candidate_configs;
    }

    // Fallback to older architecture configurations
    std::vector<CutlassTileConfig> tiles = get_candidate_tiles(sm, config_type_param);
    std::vector<CutlassGemmConfig> candidate_configs;
    bool const int8_configs_only = config_type_param & CutlassGemmConfig::INT8_ONLY;
    int const min_stages = int8_configs_only ? 3 : 2;
    int const max_stages = int8_configs_only ? 6 : (sm >= 80 ? 4 : 2);
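As a sanity check on the enumeration above, here is a Python model of the `sm == 100` block: every tile gets a 1x1x1 cluster config, plus 2x1x1 / 1x2x1 / 2x2x1 variants where the multicast predicates allow. Tile shapes are modeled as (M, N) pairs and schedules are omitted; the sets mirror the tiles listed in `supports_mcast_along_m_sm100` / `supports_mcast_along_n_sm100`:

```python
# (M, N) tiles supporting multicast along M and N respectively; K-tile is always 128B.
M_MCAST = {(128, 64), (128, 128), (128, 256), (256, 64), (256, 128), (256, 256)}
N_MCAST = {(64, 128), (64, 256), (128, 128), (128, 256), (256, 128)}

def candidate_configs_sm100(tiles):
    """Model of the sm == 100 block in get_candidate_configs():
    one 1x1x1 config per tile, plus multicast cluster shapes where supported."""
    configs = []
    for tile in tiles:
        configs.append((tile, (1, 1, 1)))
        m_ok, n_ok = tile in M_MCAST, tile in N_MCAST
        if m_ok:
            configs.append((tile, (2, 1, 1)))
        if n_ok:
            configs.append((tile, (1, 2, 1)))
        if m_ok and n_ok:
            configs.append((tile, (2, 2, 1)))
    return configs

# FP4 grouped-GEMM tile list from get_candidate_tiles_sm100:
fp4_grouped = [(128, 128), (128, 256), (256, 128), (256, 256), (128, 64), (256, 64)]
configs = candidate_configs_sm100(fp4_grouped)
# 3 tiles support both M- and N-multicast (4 cluster shapes each);
# the other 3 support only M-multicast (2 cluster shapes each): 3*4 + 3*2 = 18.
assert len(configs) == 18
```

This makes the growth of the candidate set easy to reason about when tuning the tile lists later.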

custom_ops/setup_ops.py

@@ -157,12 +157,18 @@ def get_gencode_flags(archs):
    """
    cc_s = get_sm_version(archs)
    flags = []
    for cc_val in cc_s:
        if cc_val == 90:
            arch_code = "90a"
            flags += ["-gencode", f"arch=compute_{arch_code},code=sm_{arch_code}"]
        elif cc_val == 100:  # Blackwell (SM10.x)
            # Per the NVIDIA dev blog, CUTLASS and architecture-specific features on CC 10.0 require '100a':
            # https://developer.nvidia.com/blog/nvidia-blackwell-and-nvidia-cuda-12-9-introduce-family-specific-architecture-features/
            # "The CUTLASS build instructions specify using the a flag when building for devices of CC 9.0 and 10.0"
            arch_code = "100a"
            flags += ["-gencode", f"arch=compute_{arch_code},code=sm_{arch_code}"]
        else:
            flags += ["-gencode", f"arch=compute_{cc_val},code=sm_{cc_val}"]
    return flags
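The translation above can be condensed into a small standalone sketch. The function name here is illustrative; the real `get_gencode_flags` also resolves the architecture list via `get_sm_version`:

```python
def get_gencode_flags_sketch(cc_list):
    # Model of get_gencode_flags(): 90 and 100 map to the architecture-specific
    # '90a'/'100a' targets; every other compute capability passes through as-is.
    flags = []
    for cc in cc_list:
        arch = f"{cc}a" if cc in (90, 100) else str(cc)
        flags += ["-gencode", f"arch=compute_{arch},code=sm_{arch}"]
    return flags

assert get_gencode_flags_sketch([80, 90, 100]) == [
    "-gencode", "arch=compute_80,code=sm_80",
    "-gencode", "arch=compute_90a,code=sm_90a",
    "-gencode", "arch=compute_100a,code=sm_100a",
]
```

So `FD_BUILDING_ARCS="[80, 90, 100]"` yields one `-gencode` pair per architecture, with only Hopper and Blackwell promoted to their `a` variants.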
@@ -395,43 +401,77 @@ elif paddle.is_compiled_with_cuda():
    if cc >= 89:
        # Running generate fp8 gemm codes.
        # Common for SM89, SM90, SM100 (Blackwell)
        nvcc_compile_args += ["-DENABLE_FP8"]
        nvcc_compile_args += [
            "-Igpu_ops/cutlass_kernels/fp8_gemm_fused/autogen"
        ]
        # This script is general enough for different SM versions; specific templates are chosen by CUTLASS.
        os.system("python utils/auto_gen_visitor_fp8_gemm_fused_kernels.py")

        if cc >= 90:  # Hopper and newer
            if cc == 90:  # SM90 (Hopper) specific auto-generation and flags
                nvcc_compile_args += [
                    # The gencode for 90a is added in get_gencode_flags now
                    "-O3",
                    "-DNDEBUG",
                ]
                print("SM90: Running SM90-specific FP8 kernel auto-generation.")
                os.system(
                    "python utils/auto_gen_fp8_fp8_gemm_fused_kernels_sm90.py")
                os.system(
                    "python utils/auto_gen_fp8_fp8_dual_gemm_fused_kernels_sm90.py"
                )
                os.system(
                    "python utils/auto_gen_fp8_fp8_block_gemm_fused_kernels_sm90.py"
                )
                nvcc_compile_args += [
                    "-DENABLE_SCALED_MM_SM90=1",
                ]
                sources += [
                    "gpu_ops/fp8_gemm_with_cutlass/fp8_fp8_half_block_gemm.cu",
                    "gpu_ops/cutlass_kernels/w8a8/scaled_mm_c3x_sm90.cu",
                    "gpu_ops/cutlass_kernels/w8a8/c3x/scaled_mm_sm90_fp8.cu",
                    "gpu_ops/cutlass_kernels/w8a8/c3x/scaled_mm_sm90_int8.cu",
                    "gpu_ops/cutlass_kernels/w8a8/c3x/scaled_mm_azp_sm90_int8.cu",
                ]
            elif cc == 100 and nvcc_version >= 12.9:  # Blackwell SM100 specifics
                print("SM100 (Blackwell): Applying SM100 configurations.")
                nvcc_compile_args += [
                    # The gencode for 100a is added in get_gencode_flags
                    "-O3",
                    "-DNDEBUG",
                    # Potentially add -DENABLE_SM100_FEATURES if specific macros are identified
                ]
                # Placeholder for SM100-specific kernel auto-generation scripts.
                # These might be needed if Blackwell has new FP8 hardware features
                # not covered by existing generic CUTLASS templates or SM90 scripts.
                # os.system("python utils/auto_gen_fp8_fp8_gemm_fused_kernels_sm100.py")  # Example
                # os.system("python utils/auto_gen_fp8_fp8_dual_gemm_fused_kernels_sm100.py")  # Example
                # Add SM100-specific sources here if needed, e.g. for new hardware intrinsics.
                pass  # No SM100-specific sources identified yet beyond what CUTLASS handles
            else:  # cc >= 90 but not handled above (e.g. SM100 with nvcc < 12.9)
                print(f"SM{cc}: Running generic FP8 kernel auto-generation.")
                os.system("python utils/auto_gen_fp8_fp8_gemm_fused_kernels.py")
                os.system(
                    "python utils/auto_gen_fp8_fp8_dual_gemm_fused_kernels.py")
        else:  # cc == 89 (Ada)
            print("SM89: Running generic FP8 kernel auto-generation.")
            os.system("python utils/auto_gen_fp8_fp8_gemm_fused_kernels.py")
            os.system(
                "python utils/auto_gen_fp8_fp8_dual_gemm_fused_kernels.py")

        # Common FP8 sources for SM89+
        sources += [
            "gpu_ops/fp8_gemm_with_cutlass/fp8_fp8_half_gemm.cu",
            "gpu_ops/fp8_gemm_with_cutlass/fp8_fp8_fp8_dual_gemm.cu",