From 76513f641611f5acfc103dfc910d2cc0f56b1634 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E5=91=A8=E5=91=A8=E5=91=A8?= <39978853+zhoutianzi666@users.noreply.github.com>
Date: Thu, 28 Aug 2025 10:52:53 +0800
Subject: [PATCH] Support 45t fp8 8 GPU (#3659)

---
 custom_ops/gpu_ops/per_token_quant_fp8.cu     | 20 ++++++++++++++++---
 .../layers/moe/fused_moe_triton_backend.py    |  8 ++++----
 .../layers/moe/triton_moe_kernels.py          |  7 ++-----
 3 files changed, 23 insertions(+), 12 deletions(-)

diff --git a/custom_ops/gpu_ops/per_token_quant_fp8.cu b/custom_ops/gpu_ops/per_token_quant_fp8.cu
index 3199b2be9..118edec1c 100644
--- a/custom_ops/gpu_ops/per_token_quant_fp8.cu
+++ b/custom_ops/gpu_ops/per_token_quant_fp8.cu
@@ -31,7 +31,8 @@ __global__ void quant_per_token_per_block(const T *input,
   const int num_warp = blockDim.x / 32;
   static constexpr int NUM_PER_THREADS = 128 / 32;  // 4
   static constexpr float MAX_VALUE = 448.f;
-  const int end_iter = hidden_size / 128;  // warp_iter_num
+  // Note(ZKK): use ceil_div so the tail (partial) block is counted.
+  const int end_iter = (hidden_size + 127) / 128;  // warp_iter_num
   AlignedVector<T, NUM_PER_THREADS> load_vec;
   AlignedVector<float, NUM_PER_THREADS> load_vec_float;
   AlignedVector<phi::dtype::float8_e4m3fn, NUM_PER_THREADS> res_vec;
@@ -42,7 +43,16 @@ __global__ void quant_per_token_per_block(const T *input,
   // deal a block per warp
   for (int iter = warp_id; iter < end_iter; iter += num_warp) {
     const int start_offset = iter * 128;
-    Load(input_now + start_offset + lane_id * NUM_PER_THREADS, &load_vec);
+
+    // Only lanes inside hidden_size load real data; the tail block is zero-padded.
+    const bool is_valid_data = start_offset + lane_id * NUM_PER_THREADS < hidden_size;
+
+    if (is_valid_data) {
+      Load(input_now + start_offset + lane_id * NUM_PER_THREADS, &load_vec);
+    } else {
+      #pragma unroll
+      for (int vid = 0; vid < NUM_PER_THREADS; vid++) load_vec[vid] = T(0.f);
+    }
     // get max value per thread
     float max_value_thread = -5e4;
 #pragma unroll
@@ -71,6 +81,7 @@ __global__ void quant_per_token_per_block(const T *input,
       res_vec[vid] = static_cast<phi::dtype::float8_e4m3fn>(load_vec_float[vid] * MAX_VALUE / max_value_thread);
     }
     // store
+    if (is_valid_data)
     Store(res_vec, quanted_res_now + start_offset + lane_id * NUM_PER_THREADS);
     if (lane_id == 0) {
       quanted_scale_now[iter] = scale_to_store;
@@ -84,7 +95,10 @@ std::vector<paddle::Tensor> PerTokenQuant(paddle::Tensor& input,
   auto input_dim = input.dims();
   const int token_num = input_dim[0];
   const int hidden_size = input_dim[1];
-  const int hidden_size_scale = hidden_size / block_size;
+  // Note(ZKK): use ceil_div to support the 4.5T model running on 8 GPUs,
+  // where moe_intermediate_size is 448 and is not divisible by 128.
+  const int hidden_size_scale = (hidden_size + block_size - 1) / block_size;
+
   auto quanted_x = GetEmptyTensor(
       {token_num, hidden_size},
       paddle::DataType::FLOAT8_E4M3FN,
diff --git a/fastdeploy/model_executor/layers/moe/fused_moe_triton_backend.py b/fastdeploy/model_executor/layers/moe/fused_moe_triton_backend.py
index ab18bfb80..69920649a 100644
--- a/fastdeploy/model_executor/layers/moe/fused_moe_triton_backend.py
+++ b/fastdeploy/model_executor/layers/moe/fused_moe_triton_backend.py
@@ -617,13 +617,13 @@ class BlockWiseFP8MoEMethod(QuantMethodBase):
         ]
         self.up_gate_proj_scale_shape = [
             layer.num_local_experts,
-            layer.moe_intermediate_size * 2 // self.quant_config.weight_block_size[0],
-            layer.hidden_size // self.quant_config.weight_block_size[1],
+            ceil_div(layer.moe_intermediate_size * 2, self.quant_config.weight_block_size[0]),
+            ceil_div(layer.hidden_size, self.quant_config.weight_block_size[1]),
         ]
         self.down_proj_scale_shape = [
             layer.num_local_experts,
-            layer.hidden_size // self.quant_config.weight_block_size[0],
-            layer.moe_intermediate_size // self.quant_config.weight_block_size[1],
+            ceil_div(layer.hidden_size, self.quant_config.weight_block_size[0]),
+            ceil_div(layer.moe_intermediate_size, self.quant_config.weight_block_size[1]),
         ]
         if self.quant_config.is_checkpoint_bf16:
             layer.up_gate_proj_weight = layer.create_parameter(
diff --git a/fastdeploy/model_executor/layers/moe/triton_moe_kernels.py b/fastdeploy/model_executor/layers/moe/triton_moe_kernels.py
index 1e146c306..61a7024b1 100644
--- a/fastdeploy/model_executor/layers/moe/triton_moe_kernels.py
+++ b/fastdeploy/model_executor/layers/moe/triton_moe_kernels.py
@@ -14,14 +14,11 @@
 # limitations under the License.
 """
 
+import triton
 import triton.language as tl
 
-from fastdeploy.model_executor.ops.triton_ops.triton_utils_v2 import (
-    paddle_use_triton_v2,
-)
-
 
-@paddle_use_triton_v2()
+@triton.jit()
 def fused_moe_kernel_paddle(
     a_ptr,
     b_ptr,
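
Reviewer note (not part of the patch): a minimal Python sketch of the rounding all three changes rely on. ceil_div is defined locally here as a stand-in for the helper referenced in fused_moe_triton_backend.py. With hidden_size = 448 and block_size = 128, floor division under-counts the scale blocks, which is why the CUDA kernel and the scale shapes switch to ceil division and the kernel zero-pads the tail block.

def ceil_div(a: int, b: int) -> int:
    # Round-up integer division, e.g. ceil_div(448, 128) == 4.
    return (a + b - 1) // b

hidden_size = 448   # moe_intermediate_size for the 4.5T model split across 8 GPUs
block_size = 128    # per-token quantization block width

print(hidden_size // block_size)          # 3 -> floor division drops the last 64 values
print(ceil_div(hidden_size, block_size))  # 4 -> one extra, partially filled block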