Support 45t fp8 8 GPU (#3659)

This commit is contained in:
周周周
2025-08-28 10:52:53 +08:00
committed by GitHub
parent 7afcd4b776
commit 76513f6416
3 changed files with 23 additions and 12 deletions

View File

@@ -31,7 +31,8 @@ __global__ void quant_per_token_per_block(const T *input,
const int num_warp = blockDim.x / 32;
static constexpr int NUM_PER_THREADS = 128 / 32; // 4
static constexpr float MAX_VALUE = 448.f;
const int end_iter = hidden_size / 128; // warp_iter_num
// Note(ZKK) use ceil_div!!
const int end_iter = (hidden_size + 127) / 128; // warp_iter_num
AlignedVector<T, NUM_PER_THREADS> load_vec;
AlignedVector<float, NUM_PER_THREADS> load_vec_float;
AlignedVector<phi::dtype::float8_e4m3fn, NUM_PER_THREADS> res_vec;
@@ -42,7 +43,16 @@ __global__ void quant_per_token_per_block(const T *input,
// deal a block per warp
for (int iter = warp_id; iter < end_iter; iter += num_warp) {
const int start_offset = iter * 128;
const bool is_valid_data = start_offset + lane_id * NUM_PER_THREADS < hidden_size;
if (is_valid_data) {
Load<T, NUM_PER_THREADS>(input_now + start_offset + lane_id * NUM_PER_THREADS, &load_vec);
} else {
#pragma unroll
for (int vid = 0; vid < NUM_PER_THREADS; vid++) load_vec[vid] = T(0.f);
}
// get max value per thread
float max_value_thread = -5e4;
#pragma unroll
@@ -71,6 +81,7 @@ __global__ void quant_per_token_per_block(const T *input,
res_vec[vid] = static_cast<phi::dtype::float8_e4m3fn>(load_vec_float[vid] * MAX_VALUE / max_value_thread);
}
// store
if (is_valid_data)
Store<phi::dtype::float8_e4m3fn, NUM_PER_THREADS>(res_vec, quanted_res_now + start_offset + lane_id * NUM_PER_THREADS);
if (lane_id == 0) {
quanted_scale_now[iter] = scale_to_store;
@@ -84,7 +95,10 @@ std::vector<paddle::Tensor> PerTokenQuant(paddle::Tensor& input,
auto input_dim = input.dims();
const int token_num = input_dim[0];
const int hidden_size = input_dim[1];
const int hidden_size_scale = hidden_size / block_size;
// Note(ZKK) here we use ceil_dive to support 4.5T runing on 8 GPUS
// where moe_intermediate_size is 448, can not be divided by 128.
const int hidden_size_scale = (hidden_size + block_size - 1) / block_size;
auto quanted_x = GetEmptyTensor(
{token_num, hidden_size},
paddle::DataType::FLOAT8_E4M3FN,

View File

@@ -617,13 +617,13 @@ class BlockWiseFP8MoEMethod(QuantMethodBase):
]
self.up_gate_proj_scale_shape = [
layer.num_local_experts,
layer.moe_intermediate_size * 2 // self.quant_config.weight_block_size[0],
layer.hidden_size // self.quant_config.weight_block_size[1],
ceil_div(layer.moe_intermediate_size * 2, self.quant_config.weight_block_size[0]),
ceil_div(layer.hidden_size, self.quant_config.weight_block_size[1]),
]
self.down_proj_scale_shape = [
layer.num_local_experts,
layer.hidden_size // self.quant_config.weight_block_size[0],
layer.moe_intermediate_size // self.quant_config.weight_block_size[1],
ceil_div(layer.hidden_size, self.quant_config.weight_block_size[0]),
ceil_div(layer.moe_intermediate_size, self.quant_config.weight_block_size[1]),
]
if self.quant_config.is_checkpoint_bf16:
layer.up_gate_proj_weight = layer.create_parameter(

View File

@@ -14,14 +14,11 @@
# limitations under the License.
"""
import triton
import triton.language as tl
from fastdeploy.model_executor.ops.triton_ops.triton_utils_v2 import (
paddle_use_triton_v2,
)
@paddle_use_triton_v2()
@triton.jit()
def fused_moe_kernel_paddle(
a_ptr,
b_ptr,