optimize w4a8 decoding (#3050)

This commit is contained in:
Yuan Xiaolan
2025-07-28 22:20:13 +08:00
committed by GitHub
parent e80ea8a71b
commit 7d87aaace8
6 changed files with 253 additions and 36 deletions

View File

@@ -223,14 +223,11 @@ public:
static Status can_implement(Arguments const &args)
{
CUTLASS_TRACE_HOST("W4A8MoeGemmUniversalBase::can_implement()");
// printf("--1\n");
// Initialize static kernel and device properties, if necessary.
Status result = init_device_props();
// printf("--1-2\n");
if (result != Status::kSuccess) {
return result;
}
// printf("--2\n");
dim3 grid = get_grid_shape(args);
// printf("--grid:%d, %d, %d\n", grid.x, grid.y, grid.z);
if (!(grid.y <= std::numeric_limits<uint16_t>::max() &&
@@ -238,7 +235,6 @@ public:
{
return Status::kErrorInvalidProblem;
}
// printf("--3\n");
return GemmKernel::can_implement(args);
}
@@ -285,18 +281,50 @@ public:
}
/// Returns the maximum number of active thread blocks per multiprocessor
static int maximum_active_blocks()
static int maximum_active_blocks(int smem_capacity = -1)
{
CUTLASS_TRACE_HOST("W4A8MoeGemmUniversalBase::maximum_active_blocks()");
// Initialize static device properties, if necessary
if (init_device_props() != Status::kSuccess) {
int smem_size = int(sizeof(typename GemmKernel_::SharedStorage));
CUTLASS_TRACE_HOST(" smem_size: " << smem_size << " bytes");
cudaError_t result;
if (smem_size > (48 << 10)) {
result = cudaFuncSetAttribute(Kernel2<GemmKernel_>,
cudaFuncAttributeMaxDynamicSharedMemorySize,
smem_size);
if (result != cudaSuccess) {
// Call cudaGetLastError() to clear the error bit
result = cudaGetLastError();
CUTLASS_TRACE_HOST(
" cudaFuncSetAttribute() returned error "
<< cudaGetErrorString(result));
return -1;
}
}
int max_active_blocks = -1;
result = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&max_active_blocks,
Kernel2<GemmKernel_>,
GemmKernel_::kThreadCount,
smem_size);
if (result != cudaSuccess) {
// Call cudaGetLastError() to clear the error bit
result = cudaGetLastError();
CUTLASS_TRACE_HOST(
" cudaOccupancyMaxActiveBlocksPerMultiprocessor() returned error "
<< cudaGetErrorString(result));
return -1;
}
CUTLASS_TRACE_HOST(" max_active_blocks: " << sm_occupancy_);
return sm_occupancy_;
CUTLASS_TRACE_HOST(" max_active_blocks: " << max_active_blocks);
return max_active_blocks;
}
@@ -341,8 +369,7 @@ public:
// Configure grid and block dimensions
dim3 block(GemmKernel::kThreadCount, 1, 1);
// dim3 grid = params_.get_grid_dims();
dim3 grid(216, 1, 1);
dim3 grid(params_.threadblock_count, 1, 1);
// Launch kernel
CUTLASS_TRACE_HOST(" "

View File

@@ -21,12 +21,12 @@ rm -rf up_gate_proj_7168_8192.log
rm -rf down_proj_8192_3584.log
num_experts=8
for tokens_per_expert in 12
for tokens_per_expert in 1 2 4 8 16 20 24 28 32 36 48 64 96 128 160 192 224 256 384 512 768 1024 2048 3072 4096 8192
do
wait
CUDA_VISIBLE_DEVICES=2 ./w4a8_moe_gemm_test ${num_experts} ${up_gate_proj_n} ${up_gate_proj_k} ${tokens_per_expert} 1 0 >> up_gate_proj_${up_gate_proj_n}_${up_gate_proj_k}.log 2>&1 &
# CUDA_VISIBLE_DEVICES=3 ./w4a8_moe_gemm_test ${num_experts} ${down_proj_n} ${down_proj_k} ${tokens_per_expert} 1 0 >> down_proj_${down_proj_n}_${down_proj_k}.log 2>&1 &
CUDA_VISIBLE_DEVICES=2 ./w4a8_moe_gemm_test ${num_experts} ${ffn1_n} ${ffn1_k} ${tokens_per_expert} 0 1 >> ffn1_${ffn1_n}_${ffn1_k}.log 2>&1 &
CUDA_VISIBLE_DEVICES=3 ./w4a8_moe_gemm_test ${num_experts} ${ffn2_n} ${ffn2_k} ${tokens_per_expert} 0 1 >> ffn2_${ffn2_n}_${ffn2_k}.log 2>&1 &
done
wait
echo "#### finish ####"

View File

@@ -996,7 +996,6 @@ int main(int argc, char *argv[]) {
CutlassTileConfig::CtaShape64x256x64_WarpShape64x64x64,
CutlassTileConfig::CtaShape32x512x64_WarpShape32x128x64,
CutlassTileConfig::CtaShape128x128x64_WarpShape128x32x64,
CutlassTileConfig::CtaShape32x512x64_WarpShape32x128x64,
};
std::vector<SplitKStyle> all_split_k_style{SplitKStyle::NO_SPLIT_K};