Feat/blackwell sm100 support (#2670)

* Add initial support for NVIDIA Blackwell (SM100) architecture

This change introduces initial support for the NVIDIA Blackwell GPU
architecture, specifically targeting SM100 (Compute Capability 10.x)
with '100a' architecture-specific features (e.g., for CUTLASS).

Key changes:
- Updated custom_ops/setup_ops.py to generate appropriate gencode
  flags (arch=compute_100a,code=sm_100a) when '100' is specified
  in FD_BUILDING_ARCS. Requires CUDA 12.9+.
- Updated custom_ops/gpu_ops/cutlass_extensions/gemm_configs.h:
    - Added CutlassTileConfigSM100 enum (with placeholder tile shapes).
    - Added BLACKWELL to CandidateConfigTypeParam.
    - Updated CutlassGemmConfig struct with is_sm100 flag,
      tile_config_sm100, and new constructor for SM100.
    - Modified toString() and fromString() for SM100 support.
- Updated custom_ops/gpu_ops/cutlass_kernels/cutlass_heuristic.cu:
    - Added get_candidate_tiles_sm100() (with placeholder tiles).
    - Added placeholder mcast support functions for SM100.
    - Updated get_candidate_configs() to include SM100 paths using
      the BLACKWELL flag and new SM100 config types.
- Updated build.sh with comments to guide users on specifying '100'
  for Blackwell in FD_BUILDING_ARCS.

Further work:
- Optimal CUTLASS tile configurations for SM100 need to be researched
  and updated in cutlass_heuristic.cu.
- Kernel auto-generation scripts in custom_ops/utils/ may need
  SM100-specific versions if Blackwell's hardware features for FP8/TMA
  differ significantly from SM90.
- Compatibility of third-party libraries (CUTLASS v3.8.0, DeepGEMM)
  with Blackwell should be fully verified.

* Feat: Implement detailed Blackwell (SM100) CUTLASS heuristics

This change integrates specific, expert-provided CUTLASS heuristic
configurations for the NVIDIA Blackwell (SM100) GPU architecture,
replacing previous placeholders. This includes:

- Updated `custom_ops/gpu_ops/cutlass_extensions/gemm_configs.h`:
    - Populated `CutlassTileConfigSM100` enum with specific tile shapes
      (e.g., CtaShape64x64x128B, CtaShape128x128x128B) suitable for SM100.
    - Added `FP4_ONLY` to `CandidateConfigTypeParam` for new FP4 paths.

- Updated `custom_ops/gpu_ops/cutlass_kernels/cutlass_heuristic.cu`:
    - Implemented `get_candidate_tiles_sm100` with detailed logic for
      selecting tile configurations based on GROUPED_GEMM and FP4_ONLY flags,
      using the new SM100 tile enums.
    - Implemented `supports_mcast_along_m_sm100` and
      `supports_mcast_along_n_sm100` with specific tile checks for Blackwell.
    - Updated the `sm == 100` (Blackwell) block in `get_candidate_configs`
      to use these new helper functions and accurately populate candidate
      kernel configurations for various cluster shapes.

- `custom_ops/setup_ops.py` remains configured to compile for
  `arch=compute_100a,code=sm_100a` with CUDA 12.9+ for these features.

This aligns the codebase with heuristic configurations similar to those
in upstream TensorRT-LLM / CUTLASS for Blackwell, enabling more
performant kernel selection on this new architecture.

---------

Co-authored-by: google-labs-jules[bot] <161369871+google-labs-jules[bot]@users.noreply.github.com>
Co-authored-by: Jiang-Jia-Jun <163579578+Jiang-Jia-Jun@users.noreply.github.com>
This commit is contained in:
celsowm
2025-07-09 04:29:42 -03:00
committed by GitHub
parent 0350831c2b
commit 771e71a24d
4 changed files with 308 additions and 52 deletions

View File

@@ -157,12 +157,18 @@ def get_gencode_flags(archs):
"""
cc_s = get_sm_version(archs)
flags = []
for cc in cc_s:
if cc == 90:
cc = f"{cc}a"
flags += ["-gencode", "arch=compute_{0},code=sm_{0}".format(cc)]
for cc_val in cc_s:
if cc_val == 90:
arch_code = "90a"
flags += ["-gencode", f"arch=compute_{arch_code},code=sm_{arch_code}"]
elif cc_val == 100: # Assuming 100 is the code for Blackwell SM10.x
# Per NVIDIA dev blog, for CUTLASS and architecture-specific features on CC 10.0, use '100a'
# https://developer.nvidia.com/blog/nvidia-blackwell-and-nvidia-cuda-12-9-introduce-family-specific-architecture-features/
# "The CUTLASS build instructions specify using the a flag when building for devices of CC 9.0 and 10.0"
arch_code = "100a"
flags += ["-gencode", f"arch=compute_{arch_code},code=sm_{arch_code}"]
else:
flags += ["-gencode", "arch=compute_{0},code=sm_{0}".format(cc)]
flags += ["-gencode", f"arch=compute_{cc_val},code=sm_{cc_val}"]
return flags
@@ -395,43 +401,77 @@ elif paddle.is_compiled_with_cuda():
if cc >= 89:
# Running generate fp8 gemm codes.
# Common for SM89, SM90, SM100 (Blackwell)
nvcc_compile_args += ["-DENABLE_FP8"]
nvcc_compile_args += [
"-Igpu_ops/cutlass_kernels/fp8_gemm_fused/autogen"
]
# This script seems general enough for different SM versions, specific templates are chosen by CUTLASS.
os.system("python utils/auto_gen_visitor_fp8_gemm_fused_kernels.py")
if cc < 90:
if cc >= 90: # Hopper and newer
# SM90 (Hopper) specific auto-generation and flags
if cc == 90: # Only for SM90
nvcc_compile_args += [
# The gencode for 90a is added in get_gencode_flags now
# "-gencode",
# "arch=compute_90a,code=compute_90a",
"-O3",
"-DNDEBUG", # NDEBUG is common, consider moving if not specific to 90a
]
print("SM90: Running SM90-specific FP8 kernel auto-generation.")
os.system(
"python utils/auto_gen_fp8_fp8_gemm_fused_kernels_sm90.py")
os.system(
"python utils/auto_gen_fp8_fp8_dual_gemm_fused_kernels_sm90.py"
)
os.system(
"python utils/auto_gen_fp8_fp8_block_gemm_fused_kernels_sm90.py"
)
nvcc_compile_args += [
"-DENABLE_SCALED_MM_SM90=1",
]
sources += [
"gpu_ops/fp8_gemm_with_cutlass/fp8_fp8_half_block_gemm.cu",
"gpu_ops/cutlass_kernels/w8a8/scaled_mm_c3x_sm90.cu",
"gpu_ops/cutlass_kernels/w8a8/c3x/scaled_mm_sm90_fp8.cu",
"gpu_ops/cutlass_kernels/w8a8/c3x/scaled_mm_sm90_int8.cu",
"gpu_ops/cutlass_kernels/w8a8/c3x/scaled_mm_azp_sm90_int8.cu",
]
elif cc == 100 and nvcc_version >= 12.9: # Blackwell SM100 specifics
print("SM100 (Blackwell): Applying SM100 configurations.")
nvcc_compile_args += [
# The gencode for 100a is added in get_gencode_flags
# "-gencode",
# "arch=compute_100a,code=compute_100a",
"-O3", # Common optimization flag
"-DNDEBUG", # Common debug flag
# Potentially add -DENABLE_SM100_FEATURES if specific macros are identified
]
# Placeholder for SM100-specific kernel auto-generation scripts
# These might be needed if Blackwell has new FP8 hardware features
# not covered by existing generic CUTLASS templates or SM90 scripts.
# print("SM100: Running SM100-specific FP8 kernel auto-generation (if any).")
# os.system("python utils/auto_gen_fp8_fp8_gemm_fused_kernels_sm100.py") # Example
# os.system("python utils/auto_gen_fp8_fp8_dual_gemm_fused_kernels_sm100.py") # Example
# Add SM100 specific sources if any, e.g., for new hardware intrinsics
# sources += ["gpu_ops/cutlass_kernels/w8a8/c4x_sm100.cu"] # Example
pass # No SM100 specific sources identified yet beyond what CUTLASS handles
else: # For cc >= 89 but not 90 or 100 (e.g. SM89)
print(f"SM{cc}: Running generic FP8 kernel auto-generation.")
os.system("python utils/auto_gen_fp8_fp8_gemm_fused_kernels.py")
os.system(
"python utils/auto_gen_fp8_fp8_dual_gemm_fused_kernels.py")
else: # For cc == 89 (Ada)
print("SM89: Running generic FP8 kernel auto-generation.")
os.system("python utils/auto_gen_fp8_fp8_gemm_fused_kernels.py")
os.system(
"python utils/auto_gen_fp8_fp8_dual_gemm_fused_kernels.py")
else:
nvcc_compile_args += [
"-gencode",
"arch=compute_90a,code=compute_90a",
"-O3",
"-DNDEBUG",
]
os.system(
"python utils/auto_gen_fp8_fp8_gemm_fused_kernels_sm90.py")
os.system(
"python utils/auto_gen_fp8_fp8_dual_gemm_fused_kernels_sm90.py"
)
os.system(
"python utils/auto_gen_fp8_fp8_block_gemm_fused_kernels_sm90.py"
)
nvcc_compile_args += [
"-DENABLE_SCALED_MM_SM90=1",
]
sources += [
"gpu_ops/fp8_gemm_with_cutlass/fp8_fp8_half_block_gemm.cu",
"gpu_ops/cutlass_kernels/w8a8/scaled_mm_c3x_sm90.cu",
"gpu_ops/cutlass_kernels/w8a8/c3x/scaled_mm_sm90_fp8.cu",
"gpu_ops/cutlass_kernels/w8a8/c3x/scaled_mm_sm90_int8.cu",
"gpu_ops/cutlass_kernels/w8a8/c3x/scaled_mm_azp_sm90_int8.cu",
]
# Common FP8 sources for SM89+
sources += [
"gpu_ops/fp8_gemm_with_cutlass/fp8_fp8_half_gemm.cu",
"gpu_ops/fp8_gemm_with_cutlass/fp8_fp8_fp8_dual_gemm.cu",