diff --git a/build.sh b/build.sh
index 7b71b5a64..af63bb414 100644
--- a/build.sh
+++ b/build.sh
@@ -18,6 +18,9 @@ BUILD_WHEEL=${1:-1}
 PYTHON_VERSION=${2:-"python"}
 export python=$PYTHON_VERSION
 FD_CPU_USE_BF16=${3:-"false"}
+# FD_BUILDING_ARCS: Specify target CUDA architectures for custom ops, e.g., "[80, 90, 100]".
+# For SM90 (Hopper), use 90. For SM100 (Blackwell), use 100.
+# These will be translated to 90a / 100a in setup_ops.py for specific features.
 FD_BUILDING_ARCS=${4:-""}
diff --git a/custom_ops/gpu_ops/cutlass_extensions/gemm_configs.h b/custom_ops/gpu_ops/cutlass_extensions/gemm_configs.h
index 9c1e9aa22..81e58f20e 100644
--- a/custom_ops/gpu_ops/cutlass_extensions/gemm_configs.h
+++ b/custom_ops/gpu_ops/cutlass_extensions/gemm_configs.h
@@ -76,6 +76,34 @@ enum class SplitKStyle
     // SPLIT_K_PARALLEL // Not supported yet
 };
 
+// New enum for SM100 (Blackwell) Tile Configs
+// Placeholder values - actual optimal values need research
+enum class CutlassTileConfigSM100
+{
+    // Signals that we should run heuristics do choose a config
+    Undefined,
+
+    // Signals that we should run heuristics do choose a config
+    ChooseWithHeuristic,
+
+    // Actual SM100 tile configs based on user input (K-tile is 128B)
+    CtaShape64x64x128B,
+    CtaShape64x128x128B,
+    CtaShape64x256x128B,
+    CtaShape128x64x128B,
+    CtaShape128x128x128B,
+    CtaShape128x256x128B,
+    CtaShape256x64x128B,
+    CtaShape256x128x128B,
+    CtaShape256x256x128B
+    // Note: The user-provided list for get_candidate_tiles_sm100 also includes
+    // CtaShape128x64x128B and CtaShape256x64x128B for specific FP4 grouped gemm cases.
+    // These are already covered by the list above if general suffices.
+    // If they need distinct enum values, they should be added.
+    // For now, keeping the enum concise with unique shapes mentioned for general use.
+};
+
+
 enum class CutlassTileConfigSM90
 {
     // Signals that we should run heuristics do choose a config
@@ -132,9 +160,11 @@ struct CutlassGemmConfig
         WEIGHT_ONLY = 1u << 0,
         SIMT_ONLY = 1u << 1,
         INT8_ONLY = 1u << 2,
-        HOPPER = 1u << 3,
+        HOPPER = 1u << 3, // SM90
         GROUPED_GEMM = 1u << 4,
         FP8_ONLY = 1u << 5,
+        BLACKWELL = 1u << 6, // SM100
+        FP4_ONLY = 1u << 7, // For Blackwell FP4/MXFP4 paths
     };
 
     CutlassTileConfig tile_config = CutlassTileConfig::ChooseWithHeuristic;
@@ -149,7 +179,17 @@
     ClusterShape cluster_shape = ClusterShape::ClusterShape_1x1x1;
     bool is_sm90 = false;
 
-    CutlassGemmConfig() {}
+    // config options for sm100 (Blackwell)
+    // Assuming SM100 might use similar schedule/cluster types as SM90 for now.
+    // These might need to become SM100-specific if Blackwell introduces new concepts.
+    CutlassTileConfigSM100 tile_config_sm100 = CutlassTileConfigSM100::ChooseWithHeuristic;
+    // MainloopScheduleType mainloop_schedule_sm100 = MainloopScheduleType::AUTO; // Example if SM100 has different types
+    // EpilogueScheduleType epilogue_schedule_sm100 = EpilogueScheduleType::AUTO; // Example
+    // ClusterShape cluster_shape_sm100 = ClusterShape::ClusterShape_1x1x1; // Example
+    bool is_sm100 = false;
+
+
+    CutlassGemmConfig() : is_sm90(false), is_sm100(false) {}
 
     CutlassGemmConfig(CutlassTileConfig tile_config, SplitKStyle split_k_style, int split_k_factor, int stages)
         : tile_config(tile_config)
@@ -157,37 +197,64 @@ struct CutlassGemmConfig
         , split_k_factor(split_k_factor)
         , stages(stages)
         , is_sm90(false)
+        , is_sm100(false)
     {
     }
 
-    CutlassGemmConfig(CutlassTileConfigSM90 tile_config_sm90, MainloopScheduleType mainloop_schedule,
-        EpilogueScheduleType epilogue_schedule, ClusterShape cluster_shape)
-        : tile_config_sm90(tile_config_sm90)
-        , mainloop_schedule(mainloop_schedule)
-        , epilogue_schedule(epilogue_schedule)
-        , cluster_shape(cluster_shape)
+    // Constructor for SM90
+    CutlassGemmConfig(CutlassTileConfigSM90 tile_config_sm90_in, MainloopScheduleType mainloop_schedule_in,
+        EpilogueScheduleType epilogue_schedule_in, ClusterShape cluster_shape_in)
+        : tile_config_sm90(tile_config_sm90_in)
+        , mainloop_schedule(mainloop_schedule_in)
+        , epilogue_schedule(epilogue_schedule_in)
+        , cluster_shape(cluster_shape_in)
         , is_sm90(true)
+        , is_sm100(false)
     {
     }
 
+    // Constructor for SM100 (Blackwell)
+    // Using existing MainloopScheduleType, EpilogueScheduleType, ClusterShape for now.
+    // These might need to be new SM100-specific types if Blackwell's TMA differs significantly.
+    CutlassGemmConfig(CutlassTileConfigSM100 tile_config_sm100_in, MainloopScheduleType mainloop_schedule_in,
+        EpilogueScheduleType epilogue_schedule_in, ClusterShape cluster_shape_in)
+        : tile_config_sm100(tile_config_sm100_in)
+        , mainloop_schedule(mainloop_schedule_in) // Potentially use mainloop_schedule_sm100 if types diverge
+        , epilogue_schedule(epilogue_schedule_in) // Potentially use epilogue_schedule_sm100
+        , cluster_shape(cluster_shape_in) // Potentially use cluster_shape_sm100
+        , is_sm90(false) // Explicitly false
+        , is_sm100(true)
+    {
+    }
+
+
     std::string toString() const
     {
         std::stringstream tactic;
         tactic << "Cutlass GEMM Tactic";
-        if (tile_config_sm90 != cutlass_extensions::CutlassTileConfigSM90::ChooseWithHeuristic)
+        if (is_sm100 && tile_config_sm100 != cutlass_extensions::CutlassTileConfigSM100::ChooseWithHeuristic)
         {
-            assert(is_sm90 && "Invalid cutlass GEMM config");
-            tactic << "\n\tstyle=TMA"
-                << "\n\ttile shape ID: " << (int) tile_config_sm90
+            assert(is_sm100 && !is_sm90 && "Invalid cutlass GEMM config: SM100");
+            tactic << "\n\tstyle=TMA_SM100" // Indicate SM100 specific TMA if applicable
+                   << "\n\ttile shape ID: " << (int) tile_config_sm100
                    << "\n\tcluster shape ID: " << (int) cluster_shape
-                << "\n\tmainloop sched: " << (int) mainloop_schedule
+                   << "\n\tmainloop sched: " << (int) mainloop_schedule
+                   << "\n\tepi sched: " << (int) epilogue_schedule;
+        }
+        else if (is_sm90 && tile_config_sm90 != cutlass_extensions::CutlassTileConfigSM90::ChooseWithHeuristic)
+        {
+            assert(is_sm90 && !is_sm100 && "Invalid cutlass GEMM config: SM90");
+            tactic << "\n\tstyle=TMA_SM90"
+                   << "\n\ttile shape ID: " << (int) tile_config_sm90
+                   << "\n\tcluster shape ID: " << (int) cluster_shape
+                   << "\n\tmainloop sched: " << (int) mainloop_schedule
                    << "\n\tepi sched: " << (int) epilogue_schedule;
         }
         else if (tile_config != cutlass_extensions::CutlassTileConfig::ChooseWithHeuristic)
         {
-            assert(!is_sm90 && "Invalid cutlass GEMM config");
+            assert(!is_sm90 && !is_sm100 && "Invalid cutlass GEMM config: Compatible");
             tactic << "\n\tstyle=compatible"
-                << "\n\ttile shape ID: " << (int) tile_config
+                   << "\n\ttile shape ID: " << (int) tile_config
                    << "\n\tstages: " << (int) stages
                    << "\n\tsplit_k_style: " << (int) split_k_style
                    << "\n\tsplit k: " << (int) split_k_factor;
@@ -204,9 +271,24 @@ struct CutlassGemmConfig
         std::istringstream stream(str);
         std::string line;
+        is_sm90 = false; // Reset flags
+        is_sm100 = false;
+
         while (std::getline(stream, line))
         {
-            if (line.find("style=TMA") != std::string::npos) {
+            if (line.find("style=TMA_SM100") != std::string::npos) {
+                is_sm100 = true;
+                is_sm90 = false;
+                std::getline(stream, line);
+                tile_config_sm100 = static_cast<CutlassTileConfigSM100>(std::stoi(line.substr(line.find(':') + 1)));
+                std::getline(stream, line);
+                cluster_shape = static_cast<ClusterShape>(std::stoi(line.substr(line.find(':') + 1)));
+                std::getline(stream, line);
+                mainloop_schedule = static_cast<MainloopScheduleType>(std::stoi(line.substr(line.find(':') + 1)));
+                std::getline(stream, line);
+                epilogue_schedule = static_cast<EpilogueScheduleType>(std::stoi(line.substr(line.find(':') + 1)));
+            } else if (line.find("style=TMA_SM90") != std::string::npos) { // Check for SM90 specific first
                 is_sm90 = true;
+                is_sm100 = false;
                 std::getline(stream, line);
                 tile_config_sm90 = static_cast<CutlassTileConfigSM90>(std::stoi(line.substr(line.find(':') + 1)));
                 std::getline(stream, line);
@@ -217,6 +299,7 @@ struct CutlassGemmConfig
                 epilogue_schedule = static_cast<EpilogueScheduleType>(std::stoi(line.substr(line.find(':') + 1)));
             } else if (line.find("style=compatible") != std::string::npos) {
                 is_sm90 = false;
+                is_sm100 = false;
                 std::getline(stream, line);
                 tile_config = static_cast<CutlassTileConfig>(std::stoi(line.substr(line.find(':') + 1)));
                 std::getline(stream, line);
@@ -233,7 +316,14 @@ struct CutlassGemmConfig
 inline std::ostream& operator<<(std::ostream& out, CutlassGemmConfig const& config)
 {
     // clang-format off
-    if (config.is_sm90)
+    if (config.is_sm100)
+    {
+        out << "tile_config_sm100_enum: " << int(config.tile_config_sm100)
+            << ", mainloop_schedule_enum: " << int(config.mainloop_schedule) // Assuming same schedule types for now
+            << ", epilogue_schedule_enum: " << int(config.epilogue_schedule) // Assuming same schedule types for now
+            << ", cluster_shape_enum: " << int(config.cluster_shape); // Assuming same cluster types for now
+    }
+    else if (config.is_sm90)
     {
         out << "tile_config_sm90_enum: " << int(config.tile_config_sm90)
             << ", mainloop_schedule_enum: " << int(config.mainloop_schedule)
diff --git a/custom_ops/gpu_ops/cutlass_kernels/cutlass_heuristic.cu b/custom_ops/gpu_ops/cutlass_kernels/cutlass_heuristic.cu
index 5c5e84e02..6db16981c 100644
--- a/custom_ops/gpu_ops/cutlass_kernels/cutlass_heuristic.cu
+++ b/custom_ops/gpu_ops/cutlass_kernels/cutlass_heuristic.cu
@@ -245,6 +245,88 @@ bool supports_mcast_along_n(CutlassTileConfigSM90 const tile)
 #endif
 }
 
+// SM100 (Blackwell) candidate tile configurations
+std::vector<CutlassTileConfigSM100> get_candidate_tiles_sm100(
+    int /*sm*/, CutlassGemmConfig::CandidateConfigTypeParam const config)
+{
+#ifdef FAST_BUILD
+    return {CutlassTileConfigSM100::CtaShape128x128x128B};
+#else
+    /* Grouped-GEMM path first (Blackwell uses 1-SM and 2-SM “cluster” kernels) */
+    if (config & CutlassGemmConfig::GROUPED_GEMM)
+    {
+        if (config & CutlassGemmConfig::FP4_ONLY) // nvfp4 / mx_fp4
+        {
+            return {
+                /* 1 SM (M=128) */
+                CutlassTileConfigSM100::CtaShape128x128x128B,
+                CutlassTileConfigSM100::CtaShape128x256x128B,
+                /* 2 SM (M=256) */
+                CutlassTileConfigSM100::CtaShape256x128x128B,
+                CutlassTileConfigSM100::CtaShape256x256x128B,
+                /* slim tiles for very tall matrices */
+                CutlassTileConfigSM100::CtaShape128x64x128B,
+                CutlassTileConfigSM100::CtaShape256x64x128B};
+        }
+
+        /* Fp8 / Fp16 grouped-GEMM */
+        return {
+            CutlassTileConfigSM100::CtaShape128x128x128B,
+            CutlassTileConfigSM100::CtaShape128x256x128B,
+            CutlassTileConfigSM100::CtaShape256x128x128B,
+            CutlassTileConfigSM100::CtaShape256x256x128B};
+    }
+
+    /* Non-grouped path (plain GEMM or weight-only) */
+    return {
+        /* 1 SM tiles */
+        CutlassTileConfigSM100::CtaShape64x64x128B,
+        CutlassTileConfigSM100::CtaShape64x128x128B,
+        CutlassTileConfigSM100::CtaShape64x256x128B,
+        CutlassTileConfigSM100::CtaShape128x64x128B,
+        CutlassTileConfigSM100::CtaShape128x128x128B,
+        CutlassTileConfigSM100::CtaShape128x256x128B,
+        /* 2 SM tiles */
+        CutlassTileConfigSM100::CtaShape256x64x128B,
+        CutlassTileConfigSM100::CtaShape256x128x128B,
+        CutlassTileConfigSM100::CtaShape256x256x128B};
+#endif
+}
+
+// M-multicast support for SM100.
+bool supports_mcast_along_m_sm100(CutlassTileConfigSM100 tile)
+{
+#ifdef FAST_BUILD
+    return false;
+#else
+    std::set<CutlassTileConfigSM100> m_tiles{
+        CutlassTileConfigSM100::CtaShape128x64x128B,
+        CutlassTileConfigSM100::CtaShape128x128x128B,
+        CutlassTileConfigSM100::CtaShape128x256x128B,
+        CutlassTileConfigSM100::CtaShape256x64x128B,
+        CutlassTileConfigSM100::CtaShape256x128x128B,
+        CutlassTileConfigSM100::CtaShape256x256x128B};
+    return m_tiles.count(tile) == 1;
+#endif
+}
+
+// N-multicast support for SM100.
+bool supports_mcast_along_n_sm100(CutlassTileConfigSM100 tile)
+{
+#ifdef FAST_BUILD
+    return false;
+#else
+    std::set<CutlassTileConfigSM100> n_tiles{
+        CutlassTileConfigSM100::CtaShape64x128x128B,
+        CutlassTileConfigSM100::CtaShape64x256x128B,
+        CutlassTileConfigSM100::CtaShape128x128x128B,
+        CutlassTileConfigSM100::CtaShape128x256x128B,
+        CutlassTileConfigSM100::CtaShape256x128x128B};
+    return n_tiles.count(tile) == 1;
+#endif
+}
+
+
 std::vector<CutlassGemmConfig> get_candidate_configs(
     int sm, int const max_split_k, CutlassGemmConfig::CandidateConfigTypeParam const config_type_param)
 {
@@ -284,9 +366,50 @@ std::vector<CutlassGemmConfig> get_candidate_configs(
         }
         return candidate_configs;
     }
-    std::vector<CutlassTileConfig> tiles = get_candidate_tiles(sm, config_type_param);
+    else if (sm == 100 && (config_type_param & CutlassGemmConfig::BLACKWELL)) // Assuming SM100 for Blackwell
+    {
+        std::vector<CutlassTileConfigSM100> tiles = get_candidate_tiles_sm100(sm, config_type_param);
+        std::vector<CutlassGemmConfig> candidate_configs;
 
-    std::vector<CutlassGemmConfig> candidate_configs;
+        for (auto const& tile_config_sm100 : tiles)
+        {
+            // SM100 uses MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO similar to SM90.
+            // Cluster shapes are also handled similarly.
+            CutlassGemmConfig config(
+                tile_config_sm100, MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO, ClusterShape::ClusterShape_1x1x1);
+            candidate_configs.push_back(config);
+
+            bool const has_m_mcast = supports_mcast_along_m_sm100(tile_config_sm100);
+            bool const has_n_mcast = supports_mcast_along_n_sm100(tile_config_sm100);
+
+            if (has_m_mcast)
+            {
+                CutlassGemmConfig mcast_m_config(tile_config_sm100, MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO,
+                    ClusterShape::ClusterShape_2x1x1);
+                candidate_configs.push_back(mcast_m_config);
+            }
+
+            if (has_n_mcast)
+            {
+                CutlassGemmConfig mcast_n_config(tile_config_sm100, MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO,
+                    ClusterShape::ClusterShape_1x2x1);
+                candidate_configs.push_back(mcast_n_config);
+            }
+
+            if (has_m_mcast && has_n_mcast)
+            {
+                CutlassGemmConfig mcast_mn_config(tile_config_sm100, MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO,
+                    ClusterShape::ClusterShape_2x2x1);
+                candidate_configs.push_back(mcast_mn_config);
+            }
+        }
+        return candidate_configs;
+    }
+
+    // Fallback to older architecture configurations
+    std::vector<CutlassTileConfig> tiles = get_candidate_tiles(sm, config_type_param);
+    std::vector<CutlassGemmConfig> candidate_configs; //Already declared above for SM90 path, ensure scope is correct or redeclare if necessary.
+    // It's fine here as it's within an else if / else block.
     bool const int8_configs_only = config_type_param & CutlassGemmConfig::INT8_ONLY;
     int const min_stages = int8_configs_only ? 3 : 2;
     int const max_stages = int8_configs_only ? 6 : (sm >= 80 ? 4 : 2);
diff --git a/custom_ops/setup_ops.py b/custom_ops/setup_ops.py
index bb165fc88..3470d9534 100644
--- a/custom_ops/setup_ops.py
+++ b/custom_ops/setup_ops.py
@@ -157,12 +157,18 @@ def get_gencode_flags(archs):
     """
     cc_s = get_sm_version(archs)
     flags = []
-    for cc in cc_s:
-        if cc == 90:
-            cc = f"{cc}a"
-            flags += ["-gencode", "arch=compute_{0},code=sm_{0}".format(cc)]
+    for cc_val in cc_s:
+        if cc_val == 90:
+            arch_code = "90a"
+            flags += ["-gencode", f"arch=compute_{arch_code},code=sm_{arch_code}"]
+        elif cc_val == 100: # Assuming 100 is the code for Blackwell SM10.x
+            # Per NVIDIA dev blog, for CUTLASS and architecture-specific features on CC 10.0, use '100a'
+            # https://developer.nvidia.com/blog/nvidia-blackwell-and-nvidia-cuda-12-9-introduce-family-specific-architecture-features/
+            # "The CUTLASS build instructions specify using the a flag when building for devices of CC 9.0 and 10.0"
+            arch_code = "100a"
+            flags += ["-gencode", f"arch=compute_{arch_code},code=sm_{arch_code}"]
         else:
-            flags += ["-gencode", "arch=compute_{0},code=sm_{0}".format(cc)]
+            flags += ["-gencode", f"arch=compute_{cc_val},code=sm_{cc_val}"]
     return flags
@@ -395,43 +401,77 @@ elif paddle.is_compiled_with_cuda():
 
         if cc >= 89:
             # Running generate fp8 gemm codes.
+            # Common for SM89, SM90, SM100 (Blackwell)
             nvcc_compile_args += ["-DENABLE_FP8"]
             nvcc_compile_args += [
                 "-Igpu_ops/cutlass_kernels/fp8_gemm_fused/autogen"
             ]
-
+            # This script seems general enough for different SM versions, specific templates are chosen by CUTLASS.
             os.system("python utils/auto_gen_visitor_fp8_gemm_fused_kernels.py")
-            if cc < 90:
+
+            if cc >= 90: # Hopper and newer
+                # SM90 (Hopper) specific auto-generation and flags
+                if cc == 90: # Only for SM90
+                    nvcc_compile_args += [
+                        # The gencode for 90a is added in get_gencode_flags now
+                        # "-gencode",
+                        # "arch=compute_90a,code=compute_90a",
+                        "-O3",
+                        "-DNDEBUG", # NDEBUG is common, consider moving if not specific to 90a
+                    ]
+                    print("SM90: Running SM90-specific FP8 kernel auto-generation.")
+                    os.system(
+                        "python utils/auto_gen_fp8_fp8_gemm_fused_kernels_sm90.py")
+                    os.system(
+                        "python utils/auto_gen_fp8_fp8_dual_gemm_fused_kernels_sm90.py"
+                    )
+                    os.system(
+                        "python utils/auto_gen_fp8_fp8_block_gemm_fused_kernels_sm90.py"
+                    )
+
+                    nvcc_compile_args += [
+                        "-DENABLE_SCALED_MM_SM90=1",
+                    ]
+                    sources += [
+                        "gpu_ops/fp8_gemm_with_cutlass/fp8_fp8_half_block_gemm.cu",
+                        "gpu_ops/cutlass_kernels/w8a8/scaled_mm_c3x_sm90.cu",
+                        "gpu_ops/cutlass_kernels/w8a8/c3x/scaled_mm_sm90_fp8.cu",
+                        "gpu_ops/cutlass_kernels/w8a8/c3x/scaled_mm_sm90_int8.cu",
+                        "gpu_ops/cutlass_kernels/w8a8/c3x/scaled_mm_azp_sm90_int8.cu",
+                    ]
+                elif cc == 100 and nvcc_version >= 12.9: # Blackwell SM100 specifics
+                    print("SM100 (Blackwell): Applying SM100 configurations.")
+                    nvcc_compile_args += [
+                        # The gencode for 100a is added in get_gencode_flags
+                        # "-gencode",
+                        # "arch=compute_100a,code=compute_100a",
+                        "-O3", # Common optimization flag
+                        "-DNDEBUG", # Common debug flag
+                        # Potentially add -DENABLE_SM100_FEATURES if specific macros are identified
+                    ]
+                    # Placeholder for SM100-specific kernel auto-generation scripts
+                    # These might be needed if Blackwell has new FP8 hardware features
+                    # not covered by existing generic CUTLASS templates or SM90 scripts.
+                    # print("SM100: Running SM100-specific FP8 kernel auto-generation (if any).")
+                    # os.system("python utils/auto_gen_fp8_fp8_gemm_fused_kernels_sm100.py") # Example
+                    # os.system("python utils/auto_gen_fp8_fp8_dual_gemm_fused_kernels_sm100.py") # Example
+
+                    # Add SM100 specific sources if any, e.g., for new hardware intrinsics
+                    # sources += ["gpu_ops/cutlass_kernels/w8a8/c4x_sm100.cu"] # Example
+                    pass # No SM100 specific sources identified yet beyond what CUTLASS handles
+                else: # For cc >= 89 but not 90 or 100 (e.g. SM89)
+                    print(f"SM{cc}: Running generic FP8 kernel auto-generation.")
+                    os.system("python utils/auto_gen_fp8_fp8_gemm_fused_kernels.py")
+                    os.system(
+                        "python utils/auto_gen_fp8_fp8_dual_gemm_fused_kernels.py")
+
+            else: # For cc == 89 (Ada)
+                print("SM89: Running generic FP8 kernel auto-generation.")
                 os.system("python utils/auto_gen_fp8_fp8_gemm_fused_kernels.py")
                 os.system(
                     "python utils/auto_gen_fp8_fp8_dual_gemm_fused_kernels.py")
-            else:
-                nvcc_compile_args += [
-                    "-gencode",
-                    "arch=compute_90a,code=compute_90a",
-                    "-O3",
-                    "-DNDEBUG",
-                ]
-                os.system(
-                    "python utils/auto_gen_fp8_fp8_gemm_fused_kernels_sm90.py")
-                os.system(
-                    "python utils/auto_gen_fp8_fp8_dual_gemm_fused_kernels_sm90.py"
-                )
-                os.system(
-                    "python utils/auto_gen_fp8_fp8_block_gemm_fused_kernels_sm90.py"
-                )
-
-                nvcc_compile_args += [
-                    "-DENABLE_SCALED_MM_SM90=1",
-                ]
-                sources += [
-                    "gpu_ops/fp8_gemm_with_cutlass/fp8_fp8_half_block_gemm.cu",
-                    "gpu_ops/cutlass_kernels/w8a8/scaled_mm_c3x_sm90.cu",
-                    "gpu_ops/cutlass_kernels/w8a8/c3x/scaled_mm_sm90_fp8.cu",
-                    "gpu_ops/cutlass_kernels/w8a8/c3x/scaled_mm_sm90_int8.cu",
-                    "gpu_ops/cutlass_kernels/w8a8/c3x/scaled_mm_azp_sm90_int8.cu",
-                ]
+            # Common FP8 sources for SM89+
         sources += [
             "gpu_ops/fp8_gemm_with_cutlass/fp8_fp8_half_gemm.cu",
             "gpu_ops/fp8_gemm_with_cutlass/fp8_fp8_fp8_dual_gemm.cu",