Mirror of https://github.com/PaddlePaddle/FastDeploy.git (synced 2025-10-04 16:22:57 +08:00).
Feat/blackwell sm100 support (#2670)
* Add initial support for NVIDIA Blackwell (SM100) architecture This change introduces initial support for the NVIDIA Blackwell GPU architecture, specifically targeting SM100 (Compute Capability 10.x) with '100a' architecture-specific features (e.g., for CUTLASS). Key changes: - Updated custom_ops/setup_ops.py to generate appropriate gencode flags (arch=compute_100a,code=sm_100a) when '100' is specified in FD_BUILDING_ARCS. Requires CUDA 12.9+. - Updated custom_ops/gpu_ops/cutlass_extensions/gemm_configs.h: - Added CutlassTileConfigSM100 enum (with placeholder tile shapes). - Added BLACKWELL to CandidateConfigTypeParam. - Updated CutlassGemmConfig struct with is_sm100 flag, tile_config_sm100, and new constructor for SM100. - Modified toString() and fromString() for SM100 support. - Updated custom_ops/gpu_ops/cutlass_kernels/cutlass_heuristic.cu: - Added get_candidate_tiles_sm100() (with placeholder tiles). - Added placeholder mcast support functions for SM100. - Updated get_candidate_configs() to include SM100 paths using the BLACKWELL flag and new SM100 config types. - Updated build.sh with comments to guide users on specifying '100' for Blackwell in FD_BUILDING_ARCS. Further work: - Optimal CUTLASS tile configurations for SM100 need to be researched and updated in cutlass_heuristic.cu. - Kernel auto-generation scripts in custom_ops/utils/ may need SM100-specific versions if Blackwell's hardware features for FP8/TMA differ significantly from SM90. - Compatibility of third-party libraries (CUTLASS v3.8.0, DeepGEMM) with Blackwell should be fully verified. * Feat: Implement detailed Blackwell (SM100) CUTLASS heuristics This change integrates specific, expert-provided CUTLASS heuristic configurations for the NVIDIA Blackwell (SM100) GPU architecture, replacing previous placeholders. 
This includes: - Updated `custom_ops/gpu_ops/cutlass_extensions/gemm_configs.h`: - Populated `CutlassTileConfigSM100` enum with specific tile shapes (e.g., CtaShape64x64x128B, CtaShape128x128x128B) suitable for SM100. - Added `FP4_ONLY` to `CandidateConfigTypeParam` for new FP4 paths. - Updated `custom_ops/gpu_ops/cutlass_kernels/cutlass_heuristic.cu`: - Implemented `get_candidate_tiles_sm100` with detailed logic for selecting tile configurations based on GROUPED_GEMM and FP4_ONLY flags, using the new SM100 tile enums. - Implemented `supports_mcast_along_m_sm100` and `supports_mcast_along_n_sm100` with specific tile checks for Blackwell. - Updated the `sm == 100` (Blackwell) block in `get_candidate_configs` to use these new helper functions and accurately populate candidate kernel configurations for various cluster shapes. - `custom_ops/setup_ops.py` remains configured to compile for `arch=compute_100a,code=sm_100a` with CUDA 12.9+ for these features. This aligns the codebase with heuristic configurations similar to those in upstream TensorRT-LLM / CUTLASS for Blackwell, enabling more performant kernel selection on this new architecture. --------- Co-authored-by: google-labs-jules[bot] <161369871+google-labs-jules[bot]@users.noreply.github.com> Co-authored-by: Jiang-Jia-Jun <163579578+Jiang-Jia-Jun@users.noreply.github.com>
This commit is contained in:
@@ -157,12 +157,18 @@ def get_gencode_flags(archs):
|
||||
"""
|
||||
cc_s = get_sm_version(archs)
|
||||
flags = []
|
||||
for cc in cc_s:
|
||||
if cc == 90:
|
||||
cc = f"{cc}a"
|
||||
flags += ["-gencode", "arch=compute_{0},code=sm_{0}".format(cc)]
|
||||
for cc_val in cc_s:
|
||||
if cc_val == 90:
|
||||
arch_code = "90a"
|
||||
flags += ["-gencode", f"arch=compute_{arch_code},code=sm_{arch_code}"]
|
||||
elif cc_val == 100: # Assuming 100 is the code for Blackwell SM10.x
|
||||
# Per NVIDIA dev blog, for CUTLASS and architecture-specific features on CC 10.0, use '100a'
|
||||
# https://developer.nvidia.com/blog/nvidia-blackwell-and-nvidia-cuda-12-9-introduce-family-specific-architecture-features/
|
||||
# "The CUTLASS build instructions specify using the 'a' flag when building for devices of CC 9.0 and 10.0"
|
||||
arch_code = "100a"
|
||||
flags += ["-gencode", f"arch=compute_{arch_code},code=sm_{arch_code}"]
|
||||
else:
|
||||
flags += ["-gencode", "arch=compute_{0},code=sm_{0}".format(cc)]
|
||||
flags += ["-gencode", f"arch=compute_{cc_val},code=sm_{cc_val}"]
|
||||
return flags
|
||||
|
||||
|
||||
@@ -395,43 +401,77 @@ elif paddle.is_compiled_with_cuda():
|
||||
|
||||
if cc >= 89:
|
||||
# Running generate fp8 gemm codes.
|
||||
# Common for SM89, SM90, SM100 (Blackwell)
|
||||
nvcc_compile_args += ["-DENABLE_FP8"]
|
||||
nvcc_compile_args += [
|
||||
"-Igpu_ops/cutlass_kernels/fp8_gemm_fused/autogen"
|
||||
]
|
||||
|
||||
# This script seems general enough for different SM versions, specific templates are chosen by CUTLASS.
|
||||
os.system("python utils/auto_gen_visitor_fp8_gemm_fused_kernels.py")
|
||||
if cc < 90:
|
||||
|
||||
if cc >= 90: # Hopper and newer
|
||||
# SM90 (Hopper) specific auto-generation and flags
|
||||
if cc == 90: # Only for SM90
|
||||
nvcc_compile_args += [
|
||||
# The gencode for 90a is added in get_gencode_flags now
|
||||
# "-gencode",
|
||||
# "arch=compute_90a,code=compute_90a",
|
||||
"-O3",
|
||||
"-DNDEBUG", # NDEBUG is common, consider moving if not specific to 90a
|
||||
]
|
||||
print("SM90: Running SM90-specific FP8 kernel auto-generation.")
|
||||
os.system(
|
||||
"python utils/auto_gen_fp8_fp8_gemm_fused_kernels_sm90.py")
|
||||
os.system(
|
||||
"python utils/auto_gen_fp8_fp8_dual_gemm_fused_kernels_sm90.py"
|
||||
)
|
||||
os.system(
|
||||
"python utils/auto_gen_fp8_fp8_block_gemm_fused_kernels_sm90.py"
|
||||
)
|
||||
|
||||
nvcc_compile_args += [
|
||||
"-DENABLE_SCALED_MM_SM90=1",
|
||||
]
|
||||
sources += [
|
||||
"gpu_ops/fp8_gemm_with_cutlass/fp8_fp8_half_block_gemm.cu",
|
||||
"gpu_ops/cutlass_kernels/w8a8/scaled_mm_c3x_sm90.cu",
|
||||
"gpu_ops/cutlass_kernels/w8a8/c3x/scaled_mm_sm90_fp8.cu",
|
||||
"gpu_ops/cutlass_kernels/w8a8/c3x/scaled_mm_sm90_int8.cu",
|
||||
"gpu_ops/cutlass_kernels/w8a8/c3x/scaled_mm_azp_sm90_int8.cu",
|
||||
]
|
||||
elif cc == 100 and nvcc_version >= 12.9: # Blackwell SM100 specifics
|
||||
print("SM100 (Blackwell): Applying SM100 configurations.")
|
||||
nvcc_compile_args += [
|
||||
# The gencode for 100a is added in get_gencode_flags
|
||||
# "-gencode",
|
||||
# "arch=compute_100a,code=compute_100a",
|
||||
"-O3", # Common optimization flag
|
||||
"-DNDEBUG", # Common debug flag
|
||||
# Potentially add -DENABLE_SM100_FEATURES if specific macros are identified
|
||||
]
|
||||
# Placeholder for SM100-specific kernel auto-generation scripts
|
||||
# These might be needed if Blackwell has new FP8 hardware features
|
||||
# not covered by existing generic CUTLASS templates or SM90 scripts.
|
||||
# print("SM100: Running SM100-specific FP8 kernel auto-generation (if any).")
|
||||
# os.system("python utils/auto_gen_fp8_fp8_gemm_fused_kernels_sm100.py") # Example
|
||||
# os.system("python utils/auto_gen_fp8_fp8_dual_gemm_fused_kernels_sm100.py") # Example
|
||||
|
||||
# Add SM100 specific sources if any, e.g., for new hardware intrinsics
|
||||
# sources += ["gpu_ops/cutlass_kernels/w8a8/c4x_sm100.cu"] # Example
|
||||
pass # No SM100 specific sources identified yet beyond what CUTLASS handles
|
||||
else: # For cc >= 89 but not 90 or 100 (e.g. SM89)
|
||||
print(f"SM{cc}: Running generic FP8 kernel auto-generation.")
|
||||
os.system("python utils/auto_gen_fp8_fp8_gemm_fused_kernels.py")
|
||||
os.system(
|
||||
"python utils/auto_gen_fp8_fp8_dual_gemm_fused_kernels.py")
|
||||
|
||||
else: # For cc == 89 (Ada)
|
||||
print("SM89: Running generic FP8 kernel auto-generation.")
|
||||
os.system("python utils/auto_gen_fp8_fp8_gemm_fused_kernels.py")
|
||||
os.system(
|
||||
"python utils/auto_gen_fp8_fp8_dual_gemm_fused_kernels.py")
|
||||
else:
|
||||
nvcc_compile_args += [
|
||||
"-gencode",
|
||||
"arch=compute_90a,code=compute_90a",
|
||||
"-O3",
|
||||
"-DNDEBUG",
|
||||
]
|
||||
os.system(
|
||||
"python utils/auto_gen_fp8_fp8_gemm_fused_kernels_sm90.py")
|
||||
os.system(
|
||||
"python utils/auto_gen_fp8_fp8_dual_gemm_fused_kernels_sm90.py"
|
||||
)
|
||||
os.system(
|
||||
"python utils/auto_gen_fp8_fp8_block_gemm_fused_kernels_sm90.py"
|
||||
)
|
||||
|
||||
nvcc_compile_args += [
|
||||
"-DENABLE_SCALED_MM_SM90=1",
|
||||
]
|
||||
sources += [
|
||||
"gpu_ops/fp8_gemm_with_cutlass/fp8_fp8_half_block_gemm.cu",
|
||||
"gpu_ops/cutlass_kernels/w8a8/scaled_mm_c3x_sm90.cu",
|
||||
"gpu_ops/cutlass_kernels/w8a8/c3x/scaled_mm_sm90_fp8.cu",
|
||||
"gpu_ops/cutlass_kernels/w8a8/c3x/scaled_mm_sm90_int8.cu",
|
||||
"gpu_ops/cutlass_kernels/w8a8/c3x/scaled_mm_azp_sm90_int8.cu",
|
||||
]
|
||||
|
||||
# Common FP8 sources for SM89+
|
||||
sources += [
|
||||
"gpu_ops/fp8_gemm_with_cutlass/fp8_fp8_half_gemm.cu",
|
||||
"gpu_ops/fp8_gemm_with_cutlass/fp8_fp8_fp8_dual_gemm.cu",
|
||||
|
Reference in New Issue
Block a user