mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-12-24 13:28:13 +08:00
Feat/blackwell sm100 support (#2670)
* Add initial support for NVIDIA Blackwell (SM100) architecture
This change introduces initial support for the NVIDIA Blackwell GPU
architecture, specifically targeting SM100 (Compute Capability 10.x)
with '100a' architecture-specific features (e.g., for CUTLASS).
Key changes:
- Updated custom_ops/setup_ops.py to generate appropriate gencode
flags (arch=compute_100a,code=sm_100a) when '100' is specified
in FD_BUILDING_ARCS. Requires CUDA 12.9+.
- Updated custom_ops/gpu_ops/cutlass_extensions/gemm_configs.h:
- Added CutlassTileConfigSM100 enum (with placeholder tile shapes).
- Added BLACKWELL to CandidateConfigTypeParam.
- Updated CutlassGemmConfig struct with is_sm100 flag,
tile_config_sm100, and new constructor for SM100.
- Modified toString() and fromString() for SM100 support.
- Updated custom_ops/gpu_ops/cutlass_kernels/cutlass_heuristic.cu:
- Added get_candidate_tiles_sm100() (with placeholder tiles).
- Added placeholder mcast support functions for SM100.
- Updated get_candidate_configs() to include SM100 paths using
the BLACKWELL flag and new SM100 config types.
- Updated build.sh with comments to guide users on specifying '100'
for Blackwell in FD_BUILDING_ARCS.
Further work:
- Optimal CUTLASS tile configurations for SM100 need to be researched
and updated in cutlass_heuristic.cu.
- Kernel auto-generation scripts in custom_ops/utils/ may need
SM100-specific versions if Blackwell's hardware features for FP8/TMA
differ significantly from SM90.
- Compatibility of third-party libraries (CUTLASS v3.8.0, DeepGEMM)
with Blackwell should be fully verified.
* Feat: Implement detailed Blackwell (SM100) CUTLASS heuristics
This change integrates specific, expert-provided CUTLASS heuristic
configurations for the NVIDIA Blackwell (SM100) GPU architecture,
replacing previous placeholders. This includes:
- Updated `custom_ops/gpu_ops/cutlass_extensions/gemm_configs.h`:
- Populated `CutlassTileConfigSM100` enum with specific tile shapes
(e.g., CtaShape64x64x128B, CtaShape128x128x128B) suitable for SM100.
- Added `FP4_ONLY` to `CandidateConfigTypeParam` for new FP4 paths.
- Updated `custom_ops/gpu_ops/cutlass_kernels/cutlass_heuristic.cu`:
- Implemented `get_candidate_tiles_sm100` with detailed logic for
selecting tile configurations based on GROUPED_GEMM and FP4_ONLY flags,
using the new SM100 tile enums.
- Implemented `supports_mcast_along_m_sm100` and
`supports_mcast_along_n_sm100` with specific tile checks for Blackwell.
- Updated the `sm == 100` (Blackwell) block in `get_candidate_configs`
to use these new helper functions and accurately populate candidate
kernel configurations for various cluster shapes.
- `custom_ops/setup_ops.py` remains configured to compile for
`arch=compute_100a,code=sm_100a` with CUDA 12.9+ for these features.
This aligns the codebase with heuristic configurations similar to those
in upstream TensorRT-LLM / CUTLASS for Blackwell, enabling more
performant kernel selection on this new architecture.
---------
Co-authored-by: google-labs-jules[bot] <161369871+google-labs-jules[bot]@users.noreply.github.com>
Co-authored-by: Jiang-Jia-Jun <163579578+Jiang-Jia-Jun@users.noreply.github.com>
This commit is contained in:
@@ -245,6 +245,88 @@ bool supports_mcast_along_n(CutlassTileConfigSM90 const tile)
|
||||
#endif
|
||||
}
|
||||
|
||||
// SM100 (Blackwell) candidate tile configurations
|
||||
std::vector<CutlassTileConfigSM100> get_candidate_tiles_sm100(
|
||||
int /*sm*/, CutlassGemmConfig::CandidateConfigTypeParam const config)
|
||||
{
|
||||
#ifdef FAST_BUILD
|
||||
return {CutlassTileConfigSM100::CtaShape128x128x128B};
|
||||
#else
|
||||
/* Grouped-GEMM path first (Blackwell uses 1-SM and 2-SM “cluster” kernels) */
|
||||
if (config & CutlassGemmConfig::GROUPED_GEMM)
|
||||
{
|
||||
if (config & CutlassGemmConfig::FP4_ONLY) // nvfp4 / mx_fp4
|
||||
{
|
||||
return {
|
||||
/* 1 SM (M=128) */
|
||||
CutlassTileConfigSM100::CtaShape128x128x128B,
|
||||
CutlassTileConfigSM100::CtaShape128x256x128B,
|
||||
/* 2 SM (M=256) */
|
||||
CutlassTileConfigSM100::CtaShape256x128x128B,
|
||||
CutlassTileConfigSM100::CtaShape256x256x128B,
|
||||
/* slim tiles for very tall matrices */
|
||||
CutlassTileConfigSM100::CtaShape128x64x128B,
|
||||
CutlassTileConfigSM100::CtaShape256x64x128B};
|
||||
}
|
||||
|
||||
/* Fp8 / Fp16 grouped-GEMM */
|
||||
return {
|
||||
CutlassTileConfigSM100::CtaShape128x128x128B,
|
||||
CutlassTileConfigSM100::CtaShape128x256x128B,
|
||||
CutlassTileConfigSM100::CtaShape256x128x128B,
|
||||
CutlassTileConfigSM100::CtaShape256x256x128B};
|
||||
}
|
||||
|
||||
/* Non-grouped path (plain GEMM or weight-only) */
|
||||
return {
|
||||
/* 1 SM tiles */
|
||||
CutlassTileConfigSM100::CtaShape64x64x128B,
|
||||
CutlassTileConfigSM100::CtaShape64x128x128B,
|
||||
CutlassTileConfigSM100::CtaShape64x256x128B,
|
||||
CutlassTileConfigSM100::CtaShape128x64x128B,
|
||||
CutlassTileConfigSM100::CtaShape128x128x128B,
|
||||
CutlassTileConfigSM100::CtaShape128x256x128B,
|
||||
/* 2 SM tiles */
|
||||
CutlassTileConfigSM100::CtaShape256x64x128B,
|
||||
CutlassTileConfigSM100::CtaShape256x128x128B,
|
||||
CutlassTileConfigSM100::CtaShape256x256x128B};
|
||||
#endif
|
||||
}
|
||||
|
||||
// M-multicast support for SM100.
|
||||
bool supports_mcast_along_m_sm100(CutlassTileConfigSM100 tile)
|
||||
{
|
||||
#ifdef FAST_BUILD
|
||||
return false;
|
||||
#else
|
||||
std::set<CutlassTileConfigSM100> m_tiles{
|
||||
CutlassTileConfigSM100::CtaShape128x64x128B,
|
||||
CutlassTileConfigSM100::CtaShape128x128x128B,
|
||||
CutlassTileConfigSM100::CtaShape128x256x128B,
|
||||
CutlassTileConfigSM100::CtaShape256x64x128B,
|
||||
CutlassTileConfigSM100::CtaShape256x128x128B,
|
||||
CutlassTileConfigSM100::CtaShape256x256x128B};
|
||||
return m_tiles.count(tile) == 1;
|
||||
#endif
|
||||
}
|
||||
|
||||
// N-multicast support for SM100.
|
||||
bool supports_mcast_along_n_sm100(CutlassTileConfigSM100 tile)
|
||||
{
|
||||
#ifdef FAST_BUILD
|
||||
return false;
|
||||
#else
|
||||
std::set<CutlassTileConfigSM100> n_tiles{
|
||||
CutlassTileConfigSM100::CtaShape64x128x128B,
|
||||
CutlassTileConfigSM100::CtaShape64x256x128B,
|
||||
CutlassTileConfigSM100::CtaShape128x128x128B,
|
||||
CutlassTileConfigSM100::CtaShape128x256x128B,
|
||||
CutlassTileConfigSM100::CtaShape256x128x128B};
|
||||
return n_tiles.count(tile) == 1;
|
||||
#endif
|
||||
}
|
||||
|
||||
|
||||
std::vector<CutlassGemmConfig> get_candidate_configs(
|
||||
int sm, int const max_split_k, CutlassGemmConfig::CandidateConfigTypeParam const config_type_param)
|
||||
{
|
||||
@@ -284,9 +366,50 @@ std::vector<CutlassGemmConfig> get_candidate_configs(
|
||||
}
|
||||
return candidate_configs;
|
||||
}
|
||||
std::vector<CutlassTileConfig> tiles = get_candidate_tiles(sm, config_type_param);
|
||||
else if (sm == 100 && (config_type_param & CutlassGemmConfig::BLACKWELL)) // Assuming SM100 for Blackwell
|
||||
{
|
||||
std::vector<CutlassTileConfigSM100> tiles = get_candidate_tiles_sm100(sm, config_type_param);
|
||||
std::vector<CutlassGemmConfig> candidate_configs;
|
||||
|
||||
std::vector<CutlassGemmConfig> candidate_configs;
|
||||
for (auto const& tile_config_sm100 : tiles)
|
||||
{
|
||||
// SM100 uses MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO similar to SM90.
|
||||
// Cluster shapes are also handled similarly.
|
||||
CutlassGemmConfig config(
|
||||
tile_config_sm100, MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO, ClusterShape::ClusterShape_1x1x1);
|
||||
candidate_configs.push_back(config);
|
||||
|
||||
bool const has_m_mcast = supports_mcast_along_m_sm100(tile_config_sm100);
|
||||
bool const has_n_mcast = supports_mcast_along_n_sm100(tile_config_sm100);
|
||||
|
||||
if (has_m_mcast)
|
||||
{
|
||||
CutlassGemmConfig mcast_m_config(tile_config_sm100, MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO,
|
||||
ClusterShape::ClusterShape_2x1x1);
|
||||
candidate_configs.push_back(mcast_m_config);
|
||||
}
|
||||
|
||||
if (has_n_mcast)
|
||||
{
|
||||
CutlassGemmConfig mcast_n_config(tile_config_sm100, MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO,
|
||||
ClusterShape::ClusterShape_1x2x1);
|
||||
candidate_configs.push_back(mcast_n_config);
|
||||
}
|
||||
|
||||
if (has_m_mcast && has_n_mcast)
|
||||
{
|
||||
CutlassGemmConfig mcast_mn_config(tile_config_sm100, MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO,
|
||||
ClusterShape::ClusterShape_2x2x1);
|
||||
candidate_configs.push_back(mcast_mn_config);
|
||||
}
|
||||
}
|
||||
return candidate_configs;
|
||||
}
|
||||
|
||||
// Fallback to older architecture configurations
|
||||
std::vector<CutlassTileConfig> tiles = get_candidate_tiles(sm, config_type_param);
|
||||
std::vector<CutlassGemmConfig> candidate_configs; //Already declared above for SM90 path, ensure scope is correct or redeclare if necessary.
|
||||
// It's fine here as it's within an else if / else block.
|
||||
bool const int8_configs_only = config_type_param & CutlassGemmConfig::INT8_ONLY;
|
||||
int const min_stages = int8_configs_only ? 3 : 2;
|
||||
int const max_stages = int8_configs_only ? 6 : (sm >= 80 ? 4 : 2);
|
||||
|
||||
Reference in New Issue
Block a user