mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-12-24 13:28:13 +08:00
optimize w4a8 decoding (#3050)
This commit is contained in:
@@ -223,14 +223,11 @@ public:
|
||||
static Status can_implement(Arguments const &args)
|
||||
{
|
||||
CUTLASS_TRACE_HOST("W4A8MoeGemmUniversalBase::can_implement()");
|
||||
// printf("--1\n");
|
||||
// Initialize static kernel and device properties, if necessary.
|
||||
Status result = init_device_props();
|
||||
// printf("--1-2\n");
|
||||
if (result != Status::kSuccess) {
|
||||
return result;
|
||||
}
|
||||
// printf("--2\n");
|
||||
dim3 grid = get_grid_shape(args);
|
||||
// printf("--grid:%d, %d, %d\n", grid.x, grid.y, grid.z);
|
||||
if (!(grid.y <= std::numeric_limits<uint16_t>::max() &&
|
||||
@@ -238,7 +235,6 @@ public:
|
||||
{
|
||||
return Status::kErrorInvalidProblem;
|
||||
}
|
||||
// printf("--3\n");
|
||||
return GemmKernel::can_implement(args);
|
||||
}
|
||||
|
||||
@@ -285,18 +281,50 @@ public:
|
||||
}
|
||||
|
||||
|
||||
|
||||
/// Returns the maximum number of active thread blocks per multiprocessor
|
||||
static int maximum_active_blocks()
|
||||
static int maximum_active_blocks(int smem_capacity = -1)
|
||||
{
|
||||
CUTLASS_TRACE_HOST("W4A8MoeGemmUniversalBase::maximum_active_blocks()");
|
||||
|
||||
// Initialize static device properties, if necessary
|
||||
if (init_device_props() != Status::kSuccess) {
|
||||
int smem_size = int(sizeof(typename GemmKernel_::SharedStorage));
|
||||
|
||||
CUTLASS_TRACE_HOST(" smem_size: " << smem_size << " bytes");
|
||||
|
||||
cudaError_t result;
|
||||
if (smem_size > (48 << 10)) {
|
||||
result = cudaFuncSetAttribute(Kernel2<GemmKernel_>,
|
||||
cudaFuncAttributeMaxDynamicSharedMemorySize,
|
||||
smem_size);
|
||||
|
||||
if (result != cudaSuccess) {
|
||||
// Call cudaGetLastError() to clear the error bit
|
||||
result = cudaGetLastError();
|
||||
CUTLASS_TRACE_HOST(
|
||||
" cudaFuncSetAttribute() returned error "
|
||||
<< cudaGetErrorString(result));
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
|
||||
int max_active_blocks = -1;
|
||||
result = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
|
||||
&max_active_blocks,
|
||||
Kernel2<GemmKernel_>,
|
||||
GemmKernel_::kThreadCount,
|
||||
smem_size);
|
||||
|
||||
if (result != cudaSuccess) {
|
||||
// Call cudaGetLastError() to clear the error bit
|
||||
result = cudaGetLastError();
|
||||
CUTLASS_TRACE_HOST(
|
||||
" cudaOccupancyMaxActiveBlocksPerMultiprocessor() returned error "
|
||||
<< cudaGetErrorString(result));
|
||||
return -1;
|
||||
}
|
||||
|
||||
CUTLASS_TRACE_HOST(" max_active_blocks: " << sm_occupancy_);
|
||||
return sm_occupancy_;
|
||||
CUTLASS_TRACE_HOST(" max_active_blocks: " << max_active_blocks);
|
||||
return max_active_blocks;
|
||||
}
|
||||
|
||||
|
||||
@@ -341,8 +369,7 @@ public:
|
||||
|
||||
// Configure grid and block dimensions
|
||||
dim3 block(GemmKernel::kThreadCount, 1, 1);
|
||||
// dim3 grid = params_.get_grid_dims();
|
||||
dim3 grid(216, 1, 1);
|
||||
dim3 grid(params_.threadblock_count, 1, 1);
|
||||
|
||||
// Launch kernel
|
||||
CUTLASS_TRACE_HOST(" "
|
||||
|
||||
@@ -21,12 +21,12 @@ rm -rf up_gate_proj_7168_8192.log
|
||||
rm -rf down_proj_8192_3584.log
|
||||
num_experts=8
|
||||
|
||||
for tokens_per_expert in 12
|
||||
for tokens_per_expert in 1 2 4 8 16 20 24 28 32 36 48 64 96 128 160 192 224 256 384 512 768 1024 2048 3072 4096 8192
|
||||
|
||||
do
|
||||
wait
|
||||
CUDA_VISIBLE_DEVICES=2 ./w4a8_moe_gemm_test ${num_experts} ${up_gate_proj_n} ${up_gate_proj_k} ${tokens_per_expert} 1 0 >> up_gate_proj_${up_gate_proj_n}_${up_gate_proj_k}.log 2>&1 &
|
||||
# CUDA_VISIBLE_DEVICES=3 ./w4a8_moe_gemm_test ${num_experts} ${down_proj_n} ${down_proj_k} ${tokens_per_expert} 1 0 >> down_proj_${down_proj_n}_${down_proj_k}.log 2>&1 &
|
||||
CUDA_VISIBLE_DEVICES=2 ./w4a8_moe_gemm_test ${num_experts} ${ffn1_n} ${ffn1_k} ${tokens_per_expert} 0 1 >> ffn1_${ffn1_n}_${ffn1_k}.log 2>&1 &
|
||||
CUDA_VISIBLE_DEVICES=3 ./w4a8_moe_gemm_test ${num_experts} ${ffn2_n} ${ffn2_k} ${tokens_per_expert} 0 1 >> ffn2_${ffn2_n}_${ffn2_k}.log 2>&1 &
|
||||
done
|
||||
wait
|
||||
echo "#### finish ####"
|
||||
|
||||
@@ -996,7 +996,6 @@ int main(int argc, char *argv[]) {
|
||||
CutlassTileConfig::CtaShape64x256x64_WarpShape64x64x64,
|
||||
CutlassTileConfig::CtaShape32x512x64_WarpShape32x128x64,
|
||||
CutlassTileConfig::CtaShape128x128x64_WarpShape128x32x64,
|
||||
CutlassTileConfig::CtaShape32x512x64_WarpShape32x128x64,
|
||||
};
|
||||
std::vector<SplitKStyle> all_split_k_style{SplitKStyle::NO_SPLIT_K};
|
||||
|
||||
|
||||
Reference in New Issue
Block a user