diff --git a/custom_ops/gpu_ops/helper.h b/custom_ops/gpu_ops/helper.h
index 26809f3aa..2f276174b 100644
--- a/custom_ops/gpu_ops/helper.h
+++ b/custom_ops/gpu_ops/helper.h
@@ -151,6 +151,34 @@ inline int GetGPUComputeCapability(int id) {
 #endif
 
+#ifndef FP8_E4M3_MAX
+#define FP8_E4M3_MAX 448.0
+#endif
+
+#ifndef DISPATCH_FLOAT_FP6_DTYPE
+#define DISPATCH_FLOAT_FP6_DTYPE(pd_dtype, c_type, ...)                      \
+  switch (pd_dtype) {                                                        \
+    case phi::DataType::FLOAT32: {                                           \
+      using c_type = float;                                                  \
+      __VA_ARGS__                                                            \
+      break;                                                                 \
+    }                                                                        \
+    case phi::DataType::BFLOAT16: {                                          \
+      using c_type = phi::dtype::bfloat16;                                   \
+      __VA_ARGS__                                                            \
+      break;                                                                 \
+    }                                                                        \
+    case phi::DataType::FLOAT16: {                                           \
+      using c_type = phi::dtype::float16;                                    \
+      __VA_ARGS__                                                            \
+      break;                                                                 \
+    }                                                                        \
+    default: {                                                               \
+      PD_THROW("Only supported attr of input type in [fp32, fp16, bf16].");  \
+    }                                                                        \
+  }
+#endif
+
 inline constexpr uint32_t next_pow_2(uint32_t const num) {
   if (num <= 1) return num;
@@ -573,3 +601,28 @@ inline bool GetMlaUseTensorcore() {
       flags_mla_use_tensorcore && enable_mla_tensorcore;
   return mla_use_tensorcore;
 }
+
+__device__ __forceinline__ float warpReduceMax(float value) {
+  value = fmaxf(value, __shfl_xor_sync(0xffffffff, value, 16));
+  value = fmaxf(value, __shfl_xor_sync(0xffffffff, value, 8));
+  value = fmaxf(value, __shfl_xor_sync(0xffffffff, value, 4));
+  value = fmaxf(value, __shfl_xor_sync(0xffffffff, value, 2));
+  value = fmaxf(value, __shfl_xor_sync(0xffffffff, value, 1));
+  return value;
+}
+
+__device__ __forceinline__ float blockReduceMax(float value) {
+  static __shared__ float warpLevelMaxs[WARP_SIZE];
+  const int laneId = threadIdx.x % WARP_SIZE;
+  const int warpId = threadIdx.x / WARP_SIZE;
+
+  value = warpReduceMax(value);
+
+  if (laneId == 0) warpLevelMaxs[warpId] = value;
+  __syncthreads();
+
+  value = (threadIdx.x < blockDim.x / WARP_SIZE) ? warpLevelMaxs[laneId] : 0;
+  if (warpId == 0) value = warpReduceMax(value);
+
+  return value;
+}
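Note: `FP8_E4M3_MAX` is the largest finite float8 e4m3 value (448), and the two reduction helpers back the per-token quantization kernels added to `common.cu` below. As a reference for what those kernels compute, here is a minimal NumPy sketch of per-token dynamic FP8 scaling (illustrative only, not part of the patch):

```python
import numpy as np

FP8_E4M3_MAX = 448.0  # largest finite float8 e4m3 value

def per_token_quant_fp8_ref(x: np.ndarray, scale_ub: float = 0.0):
    """One scale per token (row): scale = max|x| / 448, then clip to the fp8 range."""
    amax = np.abs(x).max(axis=-1, keepdims=True)
    if scale_ub > 0:
        amax = np.minimum(amax, scale_ub)              # optional upper bound, as in the kernels
    scale = amax / FP8_E4M3_MAX
    inv = np.where(scale == 0.0, 0.0, 1.0 / scale)
    q = np.clip(x * inv, -FP8_E4M3_MAX, FP8_E4M3_MAX)  # the cast to fp8 happens on device
    return q.astype(np.float32), scale
```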
diff --git a/custom_ops/gpu_ops/quantization/common.cu b/custom_ops/gpu_ops/quantization/common.cu
index 7d8388f99..c0e8f48ee 100644
--- a/custom_ops/gpu_ops/quantization/common.cu
+++ b/custom_ops/gpu_ops/quantization/common.cu
@@ -3,6 +3,158 @@
 
 #include "quantization/common.cuh"
 
+// adapted from: https://github.com/sgl-project/sglang/blob/v0.5.2rc2/sgl-kernel/csrc/gemm/per_token_quant_fp8.cu
+
+// ---------------------------------------------------------------------------
+// 1. Warp-local, no shared memory
+//    • One warp handles one token.
+//    • Eight tokens per 256-thread CTA.
+// ---------------------------------------------------------------------------
+template <typename T, typename DST_DTYPE, int kTokensPerCTA, int kVecSize>
+__global__ void per_token_quant_fp8_kernel(
+    const T* __restrict__ input,
+    DST_DTYPE* __restrict__ output_q,
+    float* __restrict__ output_s,
+    const float scale_ub,
+    const int64_t hidden_size,
+    const int64_t num_tokens) {
+  const int warp_id = threadIdx.x / WARP_SIZE;        // 0-7 (8 warps)
+  const int lane_id = threadIdx.x & (WARP_SIZE - 1);  // 0-31
+  const int token_id = blockIdx.x * kTokensPerCTA + warp_id;
+  if (token_id >= num_tokens) return;
+
+  // Global tensors for this token
+  const T* token_input = input + token_id * hidden_size;
+  DST_DTYPE* token_output = output_q + token_id * hidden_size;
+  float* token_scale = output_s + token_id;
+
+  //
+  // Pass 1: warp reduce to find the max absolute value over the token's hidden_size
+  //
+  float max_value = 0.f;
+  using vec_t = AlignedVector<T, kVecSize>;
+  const int32_t num_vec_elems = hidden_size / kVecSize;
+
+  for (int32_t i = lane_id; i < num_vec_elems; i += WARP_SIZE) {
+    vec_t input_vec;
+    Load<T, kVecSize>(token_input + i * kVecSize, &input_vec);
+
+#pragma unroll
+    for (uint32_t j = 0; j < kVecSize; ++j) {
+      max_value = fmaxf(max_value, fabsf(static_cast<float>(input_vec[j])));
+    }
+  }
+
+  float warp_max = warpReduceMax(max_value);
+  if (scale_ub > 0) {
+    warp_max = fminf(warp_max, scale_ub);
+  }
+  float scale;
+  scale = warp_max / FP8_E4M3_MAX;
+  // Broadcast scale
+  if (lane_id == 0) {
+    token_scale[0] = scale;
+  }
+  float scale_inv = (scale == 0.f) ? 0.f : 1.0f / scale;
+
+  //
+  // Pass 2: quantize and write back
+  //
+  for (int i = lane_id; i < num_vec_elems; i += WARP_SIZE) {
+    vec_t input_vec;
+    Load<T, kVecSize>(token_input + i * kVecSize, &input_vec);
+    DST_DTYPE output_arr[kVecSize];
+#pragma unroll
+    for (uint32_t j = 0; j < kVecSize; ++j) {
+      float val = static_cast<float>(input_vec[j]) * scale_inv;
+      val = fmaxf(fminf(val, FP8_E4M3_MAX), -FP8_E4M3_MAX);
+      output_arr[j] = static_cast<DST_DTYPE>(val);
+    }
+    if constexpr (kVecSize == 16) {
+      *(uint4*)(token_output + i * kVecSize) = *(uint4*)output_arr;
+    } else {
+      // Use element-wise copy for vector size 8 to ensure correctness
+      for (int k = 0; k < kVecSize; ++k) {
+        token_output[i * kVecSize + k] = output_arr[k];
+      }
+    }
+  }
+}
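For orientation, the thread-to-token mapping used by the warp-local kernel above (8 tokens per CTA, one warp per token) works out as follows; this is an illustration, not code from the patch:

```python
WARP_SIZE = 32
TOKENS_PER_CTA = 8
THREADS_PER_CTA = TOKENS_PER_CTA * WARP_SIZE  # 256 threads per block

def grid_size(num_tokens: int) -> int:
    # one block quantizes 8 tokens, one warp per token
    return (num_tokens + TOKENS_PER_CTA - 1) // TOKENS_PER_CTA

def token_for_thread(block_id: int, thread_id: int) -> int:
    warp_id = thread_id // WARP_SIZE
    return block_id * TOKENS_PER_CTA + warp_id
```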
+
+// ---------------------------------------------------------------------------
+// 2. Baseline kernel (1 token / CTA, CUB block reduce)
+// ---------------------------------------------------------------------------
+template <typename T, typename DST_DTYPE, int kVecSize>
+__global__ void per_token_quant_fp8_small_batch_kernel(
+    const T* __restrict__ input,
+    DST_DTYPE* __restrict__ output_q,
+    float* __restrict__ output_s,
+    const float scale_ub,
+    const int64_t hidden_size,
+    const int64_t num_tokens) {
+  const int token_idx = blockIdx.x;
+  if (token_idx >= num_tokens) return;
+
+  const int tid = threadIdx.x;
+  const int block_dim = blockDim.x;
+
+  const T* token_input = input + token_idx * hidden_size;
+  DST_DTYPE* token_output = output_q + token_idx * hidden_size;
+
+  float max_value = 0.0f;
+
+  // Use template parameter for vector size
+  using vec_t = AlignedVector<T, kVecSize>;
+  const int32_t num_vec_elems = hidden_size / kVecSize;
+
+  // Find max using vectorized loads
+  for (int32_t i = tid; i < num_vec_elems; i += block_dim) {
+    vec_t input_vec;
+    Load<T, kVecSize>(token_input + i * kVecSize, &input_vec);
+
+#pragma unroll
+    for (uint32_t j = 0; j < kVecSize; ++j) {
+      float val = static_cast<float>(input_vec[j]);
+      max_value = fmaxf(max_value, fabsf(val));
+    }
+  }
+
+  max_value = blockReduceMax(max_value);
+  if (scale_ub > 0) {
+    max_value = fminf(max_value, scale_ub);
+  }
+  __shared__ float scale;
+  if (tid == 0) {
+    scale = max_value / FP8_E4M3_MAX;
+    output_s[token_idx] = scale;
+  }
+  __syncthreads();
+
+  const float scale_inv = 1.0f / scale;
+
+  // Quantize using vectorized loads
+  for (int32_t i = tid; i < num_vec_elems; i += block_dim) {
+    vec_t input_vec;
+    Load<T, kVecSize>(token_input + i * kVecSize, &input_vec);
+
+    DST_DTYPE output_arr[kVecSize];
+#pragma unroll
+    for (uint32_t j = 0; j < kVecSize; ++j) {
+      float val = fmaxf(fminf(static_cast<float>(input_vec[j]) * scale_inv, FP8_E4M3_MAX), -FP8_E4M3_MAX);
+      output_arr[j] = static_cast<DST_DTYPE>(val);
+    }
+
+    if constexpr (kVecSize == 16) {
+      *(uint4*)(token_output + i * kVecSize) = *(uint4*)output_arr;
+    } else {
+      // Use element-wise copy for vector size 8 to ensure correctness
+      for (int k = 0; k < kVecSize; ++k) {
+        token_output[i * kVecSize + k] = output_arr[k];
+      }
+    }
+  }
+}
+
 namespace fastdeploy {
 
 template <typename scalar_t>
@@ -179,39 +331,78 @@ void DynamicPerTokenScaledFp8Quant(paddle::Tensor &out, // [..., d]
   auto rank = input.dims().size();
   int const hidden_size = input.dims()[rank - 1];
   int const num_tokens = input.numel() / hidden_size;
+  cudaStream_t stream = input.stream();
+
+  if (hidden_size % 8 == 0) {
+    int device = 0;
+    cudaGetDevice(&device);
+    int sm_count = 0;
+    cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, device);
+    const int TOKENS_PER_CTA = 8;
+    const bool use_warp_kernel = (num_tokens >= sm_count * 2 * TOKENS_PER_CTA);
+    const bool use_vec16 = (hidden_size % 16 == 0);
+    DISPATCH_FLOAT_FP6_DTYPE(input.dtype(), scalar_t, {
+      if (use_warp_kernel) {
+        // -------- warp-local ---------------------------------------------------
+        constexpr int THREADS = TOKENS_PER_CTA * WARP_SIZE;  // 256
+        dim3 grid((num_tokens + TOKENS_PER_CTA - 1) / TOKENS_PER_CTA);
+        dim3 block(THREADS);
+
+        if (use_vec16) {
+          per_token_quant_fp8_kernel<scalar_t, __nv_fp8_e4m3, TOKENS_PER_CTA, 16><<<grid, block, 0, stream>>>(
+              reinterpret_cast<const scalar_t*>(input.data()),
+              reinterpret_cast<__nv_fp8_e4m3*>(out.data()),
+              reinterpret_cast<float*>(scales.data()),
+              scale_ub,
+              hidden_size,
+              num_tokens);
+        } else {
+          per_token_quant_fp8_kernel<scalar_t, __nv_fp8_e4m3, TOKENS_PER_CTA, 8><<<grid, block, 0, stream>>>(
+              reinterpret_cast<const scalar_t*>(input.data()),
+              reinterpret_cast<__nv_fp8_e4m3*>(out.data()),
+              reinterpret_cast<float*>(scales.data()),
+              scale_ub,
+              hidden_size,
+              num_tokens);
+        }
+      } else {
+        // -------- baseline -----------------------------------------------------
+        constexpr int THREADS = 256;
+        dim3 grid(num_tokens);
+        dim3 block(THREADS);
+
+        if (use_vec16) {
+          per_token_quant_fp8_small_batch_kernel<scalar_t, __nv_fp8_e4m3, 16><<<grid, block, 0, stream>>>(
+              reinterpret_cast<const scalar_t*>(input.data()),
+              reinterpret_cast<__nv_fp8_e4m3*>(out.data()),
+              reinterpret_cast<float*>(scales.data()),
+              scale_ub,
+              hidden_size,
+              num_tokens);
+        } else {
+          per_token_quant_fp8_small_batch_kernel<scalar_t, __nv_fp8_e4m3, 8><<<grid, block, 0, stream>>>(
+              reinterpret_cast<const scalar_t*>(input.data()),
+              reinterpret_cast<__nv_fp8_e4m3*>(out.data()),
+              reinterpret_cast<float*>(scales.data()),
+              scale_ub,
+              hidden_size,
+              num_tokens);
+        }
+      }
+    });
+    return;
+  }
+
   dim3 const grid(num_tokens);
   dim3 const block(std::min(hidden_size, 1024));
-  cudaStream_t stream = input.stream();
+  DISPATCH_FLOAT_FP6_DTYPE(input.dtype(), scalar_t, {
+    fastdeploy::dynamic_per_token_scaled_fp8_quant_kernel<scalar_t>
+        <<<grid, block, 0, stream>>>(out.data(), scales.data(),
+                                     input.data(), scale_ub,
+                                     hidden_size);
+  });
 
-  switch (input.dtype()) {
-    case paddle::DataType::FLOAT32: {
-      using scalar_t = float;
-      fastdeploy::dynamic_per_token_scaled_fp8_quant_kernel<scalar_t>
-          <<<grid, block, 0, stream>>>(out.data(), scales.data(),
-                                       input.data(), scale_ub,
-                                       hidden_size);
-      break;
-    }
-    case paddle::DataType::FLOAT16: {
-      using scalar_t = phi::dtype::float16;
-      fastdeploy::dynamic_per_token_scaled_fp8_quant_kernel<scalar_t>
-          <<<grid, block, 0, stream>>>(out.data(), scales.data(),
-                                       input.data(), scale_ub,
-                                       hidden_size);
-      break;
-    }
-    case paddle::DataType::BFLOAT16: {
-      using scalar_t = phi::dtype::bfloat16;
-      fastdeploy::dynamic_per_token_scaled_fp8_quant_kernel<scalar_t>
-          <<<grid, block, 0, stream>>>(out.data(), scales.data(),
-                                       input.data(), scale_ub,
-                                       hidden_size);
-      break;
-    }
-    default:
-      PD_THROW("Only supported attr of input type in [fp32, fp16, bf16].");
-  }
 }
 
 PD_BUILD_STATIC_OP(static_scaled_fp8_quant)
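The host-side dispatch above chooses between the three kernels based on problem shape and occupancy. A hedged Python restatement of that selection logic (names are illustrative):

```python
def pick_quant_kernel(num_tokens: int, hidden_size: int, sm_count: int):
    """Mirrors the C++ dispatch above; returns (kernel name, vector width)."""
    if hidden_size % 8 != 0:
        # fall back to the original per-token kernel (one CTA per token, scalar loads)
        return "dynamic_per_token_scaled_fp8_quant_kernel", 1
    tokens_per_cta = 8
    if num_tokens >= sm_count * 2 * tokens_per_cta:
        # enough tokens to saturate the GPU with 256-thread, 8-token CTAs
        kernel = "per_token_quant_fp8_kernel"
    else:
        kernel = "per_token_quant_fp8_small_batch_kernel"  # 1 token per CTA, block reduce
    return kernel, 16 if hidden_size % 16 == 0 else 8
```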
diff --git a/fastdeploy/model_executor/layers/moe/fused_moe_triton_backend.py b/fastdeploy/model_executor/layers/moe/fused_moe_triton_backend.py
index 282658cd8..c14c33516 100644
--- a/fastdeploy/model_executor/layers/moe/fused_moe_triton_backend.py
+++ b/fastdeploy/model_executor/layers/moe/fused_moe_triton_backend.py
@@ -32,6 +32,7 @@ try:
 except ImportError:
     pass
 from fastdeploy.model_executor.layers.moe.moe import get_moe_scores
+from fastdeploy.model_executor.layers.quantization.ops import scaled_fp8_quant
 
 
 class TritonWeightOnlyMoEMethod(QuantMethodBase):
@@ -332,6 +333,7 @@ class TritonWeightOnlyMoEMethod(QuantMethodBase):
             compute_type_enum=1,
             use_fp8_w8a8=False,
             use_int8_w8a16=True,
+            per_channel_quant=False,
             even_Ks=hidden_size % config["BLOCK_SIZE_K"] == 0,
         )
 
@@ -384,6 +386,7 @@ class TritonWeightOnlyMoEMethod(QuantMethodBase):
             compute_type_enum=1,
             use_fp8_w8a8=False,
             use_int8_w8a16=True,
+            per_channel_quant=False,
             even_Ks=moe_intermediate_size % config["BLOCK_SIZE_K"] == 0,
         )
 
@@ -395,6 +398,379 @@ class TritonWeightOnlyMoEMethod(QuantMethodBase):
         return out
 
 
+class Wfp8Afp8MoEMethod(QuantMethodBase):
+    """
+    Use Triton Group Gemm to compute Fused wfp8afp8 Quant MoE.
+    """
+
+    def __init__(self, quant_config):
+        """
+        Triton Group Gemm to compute Fused MoE.
+        """
+        self.quant_config = quant_config
+        self.added_weight_attrs = ["up_gate_proj_weight", "down_proj_weight"]
+        self.added_scale_attrs = [
+            "up_gate_proj_weight_scale",
+            "down_proj_weight_scale",
+        ]
+
+    def process_prequanted_weights(self, layer: nn.Layer, state_dict, is_rearrange: bool = False) -> None:
+        """process_prequanted_weights"""
+
+        raise NotImplementedError
+
+    def create_weights(self, layer: nn.Layer, **extra_weight_attrs):
+        """
+        Triton MoE create weight process.
+        """
+        self.up_gate_proj_weight_shape = [
+            layer.num_local_experts,
+            layer.moe_intermediate_size * 2,
+            layer.hidden_size,
+        ]
+        self.down_proj_weight_shape = [
+            layer.num_local_experts,
+            layer.hidden_size,
+            layer.moe_intermediate_size,
+        ]
+        self.up_gate_proj_scale_shape = [
+            layer.num_local_experts,
+            layer.moe_intermediate_size * 2,
+            1,
+        ]
+        self.down_proj_scale_shape = [
+            layer.num_local_experts,
+            layer.hidden_size,
+            1,
+        ]
+        if self.quant_config.is_checkpoint_bf16 and layer.fd_config.load_config.load_choices == "default_v1":
+            layer.up_gate_proj_weight = layer.create_parameter(
+                shape=[layer.num_local_experts, layer.hidden_size, layer.moe_intermediate_size * 2],
+                dtype=layer.weight_dtype,
+                default_initializer=paddle.nn.initializer.Constant(0),
+            )
+
+            layer.down_proj_weight = layer.create_parameter(
+                shape=[layer.num_local_experts, layer.moe_intermediate_size, layer.hidden_size],
+                dtype=layer.weight_dtype,
+                default_initializer=paddle.nn.initializer.Constant(0),
+            )
+
+            extra_weight_attrs["weight_need_transpose"] = extra_weight_attrs.get("model_format") == "torch"
+
+            set_weight_attrs(
+                layer.up_gate_proj_weight,
+                {
+                    **extra_weight_attrs,
+                    "tensor_track": TensorTracker(shape=layer.up_gate_proj_weight.shape, output_dim=True),
+                },
+            )
+            set_weight_attrs(
+                layer.down_proj_weight,
+                {
+                    **extra_weight_attrs,
+                    "tensor_track": TensorTracker(shape=layer.down_proj_weight.shape, output_dim=False),
+                },
+            )
+        else:
+            self.weight_dtype = paddle.float8_e4m3fn
+            up_gate_proj_weight_name = self.added_weight_attrs[0]
+            down_proj_weight_name = self.added_weight_attrs[1]
+            up_gate_proj_scale_name = self.added_scale_attrs[0]
+            down_proj_scale_name = self.added_scale_attrs[1]
+            setattr(
+                layer,
+                up_gate_proj_weight_name,
+                layer.create_parameter(
+                    shape=self.up_gate_proj_weight_shape,
+                    dtype=self.weight_dtype,
+                    default_initializer=paddle.nn.initializer.Constant(0),
+                ),
+            )
+            setattr(
+                layer,
+                down_proj_weight_name,
+                layer.create_parameter(
+                    shape=self.down_proj_weight_shape,
+                    dtype=self.weight_dtype,
+                    default_initializer=paddle.nn.initializer.Constant(0),
+                ),
+            )
+            # weight_scale
+            setattr(
+                layer,
+                up_gate_proj_scale_name,
+                layer.create_parameter(
+                    shape=self.up_gate_proj_scale_shape,
+                    dtype="float32",
+                    default_initializer=paddle.nn.initializer.Constant(0),
+                ),
+            )
+            setattr(
+                layer,
+                down_proj_scale_name,
+                layer.create_parameter(
+                    shape=self.down_proj_scale_shape,
+                    dtype="float32",
+                    default_initializer=paddle.nn.initializer.Constant(0),
+                ),
+            )
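To make the shapes above concrete, for a hypothetical layer with E local experts, hidden size H and intermediate size I (the numbers below are invented for illustration):

```python
E, H, I = 8, 4096, 1024  # hypothetical sizes

up_gate_proj_weight_shape = [E, 2 * I, H]  # fp8 weight, one [2I, H] matrix per expert
down_proj_weight_shape = [E, H, I]         # fp8 weight
up_gate_proj_scale_shape = [E, 2 * I, 1]   # float32, one scale per output channel
down_proj_scale_shape = [E, H, 1]

# In the bf16-checkpoint path the parameters are first created transposed
# ([E, H, 2I] and [E, I, H]) and only quantized later in process_weights_after_loading.
```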
+ """ + self.up_gate_proj_weight_shape = [ + layer.num_local_experts, + layer.moe_intermediate_size * 2, + layer.hidden_size, + ] + self.down_proj_weight_shape = [ + layer.num_local_experts, + layer.hidden_size, + layer.moe_intermediate_size, + ] + self.up_gate_proj_scale_shape = [ + layer.num_local_experts, + layer.moe_intermediate_size * 2, + 1, + ] + self.down_proj_scale_shape = [ + layer.num_local_experts, + layer.hidden_size, + 1, + ] + if self.quant_config.is_checkpoint_bf16 and layer.fd_config.load_config.load_choices == "default_v1": + layer.up_gate_proj_weight = layer.create_parameter( + shape=[layer.num_local_experts, layer.hidden_size, layer.moe_intermediate_size * 2], + dtype=layer.weight_dtype, + default_initializer=paddle.nn.initializer.Constant(0), + ) + + layer.down_proj_weight = layer.create_parameter( + shape=[layer.num_local_experts, layer.moe_intermediate_size, layer.hidden_size], + dtype=layer.weight_dtype, + default_initializer=paddle.nn.initializer.Constant(0), + ) + + extra_weight_attrs["weight_need_transpose"] = extra_weight_attrs.get("model_format") == "torch" + + set_weight_attrs( + layer.up_gate_proj_weight, + { + **extra_weight_attrs, + "tensor_track": TensorTracker(shape=layer.up_gate_proj_weight.shape, output_dim=True), + }, + ) + set_weight_attrs( + layer.down_proj_weight, + { + **extra_weight_attrs, + "tensor_track": TensorTracker(shape=layer.down_proj_weight.shape, output_dim=False), + }, + ) + else: + self.weight_dtype = paddle.float8_e4m3fn + up_gate_proj_weight_name = self.added_weight_attrs[0] + down_proj_weight_name = self.added_weight_attrs[1] + up_gate_proj_scale_name = self.added_scale_attrs[0] + down_proj_scale_name = self.added_scale_attrs[1] + setattr( + layer, + up_gate_proj_weight_name, + layer.create_parameter( + shape=self.up_gate_proj_weight_shape, + dtype=self.weight_dtype, + default_initializer=paddle.nn.initializer.Constant(0), + ), + ) + setattr( + layer, + down_proj_weight_name, + layer.create_parameter( + shape=self.down_proj_weight_shape, + dtype=self.weight_dtype, + default_initializer=paddle.nn.initializer.Constant(0), + ), + ) + # weight_scale + setattr( + layer, + up_gate_proj_scale_name, + layer.create_parameter( + shape=self.up_gate_proj_scale_shape, + dtype="float32", + default_initializer=paddle.nn.initializer.Constant(0), + ), + ) + setattr( + layer, + down_proj_scale_name, + layer.create_parameter( + shape=self.down_proj_scale_shape, + dtype="float32", + default_initializer=paddle.nn.initializer.Constant(0), + ), + ) + + def process_weights_after_loading(self, layer): + """ """ + if not self.quant_config.is_checkpoint_bf16: + return + weight_id_map = {"gate_up": 0, "down": 1} + if ( + hasattr(layer.up_gate_proj_weight, "tensor_track") + and layer.up_gate_proj_weight.tensor_track is not None + and layer.up_gate_proj_weight.tensor_track.is_fully_copied() + ): + weight_type = "gate_up" + layer.up_gate_proj_weight.tensor_track = None + else: + weight_type = "down" + layer.down_proj_weight.tensor_track = None + + # weight + weight_name = self.added_weight_attrs[weight_id_map[weight_type]] + weight_shape = self.up_gate_proj_weight_shape if weight_type == "gate_up" else self.down_proj_weight_shape + weight_dtype = paddle.float8_e4m3fn + # scale + scale_name = self.added_scale_attrs[weight_id_map[weight_type]] + scale_shape = self.up_gate_proj_scale_shape if weight_type == "gate_up" else self.down_proj_scale_shape + scale_dtype = "float32" + + # 2.crate tmp tensor + + weight = paddle.empty(shape=weight_shape, dtype=weight_dtype) + 
+
+    def check(self, layer: nn.Layer, up_gate_proj_weights, down_proj_weights):
+        """
+        check layer is valid for this method
+        """
+        assert up_gate_proj_weights[0].shape == [
+            layer.moe_intermediate_size * 2,
+            layer.hidden_size,
+        ]
+        assert down_proj_weights[0].shape == [
+            layer.hidden_size,
+            layer.moe_intermediate_size,
+        ]
+
+    def apply(
+        self,
+        layer: nn.Layer,
+        x: paddle.Tensor,
+        gate: nn.Layer,
+    ) -> paddle.Tensor:
+        """
+        Triton compute Fused MoE.
+        """
+        gate_out = gate(x.cast("float32"))
+        token_num = x.shape[0]
+        top_k = layer.top_k
+        num_local_experts = layer.num_local_experts
+        moe_intermediate_size = layer.moe_intermediate_size
+        hidden_size = layer.hidden_size
+        E, N1, _ = getattr(layer, self.added_weight_attrs[0]).shape
+
+        if layer.topk_method == "noaux_tc":
+            gate_out, topk_weights, topk_ids = get_moe_scores(
+                gate_out,
+                layer.n_group,
+                layer.topk_group,
+                layer.top_k,
+                layer.routed_scaling_factor,
+                layer.gate_correction_bias,
+            )
+        else:
+            topk_ids, topk_weights = fastdeploy.model_executor.ops.gpu.moe_topk_select(
+                gate_out,
+                layer.gate_correction_bias,
+                layer.top_k,
+                True,  # apply_norm_weight
+                False,
+            )
+
+        config = {
+            "BLOCK_SIZE_M": 128,
+            "BLOCK_SIZE_N": 256,
+            "BLOCK_SIZE_K": 128,
+            "GROUP_SIZE_M": 32,
+            "num_warps": 8,
+            "num_stages": 4,
+        }
+        if token_num <= E:
+            config = {
+                "BLOCK_SIZE_M": 64,
+                "BLOCK_SIZE_N": 128,
+                "BLOCK_SIZE_K": 128,
+                "GROUP_SIZE_M": 1,
+                "num_warps": 4,
+                "num_stages": 4,
+            }
+
+        sorted_token_ids, expert_ids, num_tokens_post_padded = tritonmoe_preprocess_func(
+            topk_ids, num_local_experts, config["BLOCK_SIZE_M"]
+        )
+        max_possible_num_post_padded = sorted_token_ids.shape[0]
+        grid = (
+            ceil_div(max_possible_num_post_padded, config["BLOCK_SIZE_M"])
+            * ceil_div(moe_intermediate_size * 2, config["BLOCK_SIZE_N"]),
+        )
+
+        up_gate_proj_out = paddle.empty(
+            [token_num * top_k, moe_intermediate_size * 2],
+            dtype=x.dtype,
+        )
+
+        from .triton_moe_kernels import fused_moe_kernel_paddle
+
+        x_q, x_scale = scaled_fp8_quant(x, use_per_token_if_dynamic=True)
+
+        fused_moe_kernel_paddle[grid](
+            x_q,
+            layer.up_gate_proj_weight,
+            up_gate_proj_out,
+            x_scale,
+            layer.up_gate_proj_weight_scale,
+            None,
+            sorted_token_ids,
+            expert_ids,
+            num_tokens_post_padded,
+            max_possible_num_post_padded,
+            token_num * top_k,
+            N=moe_intermediate_size * 2,
+            K=hidden_size,
+            stride_am=x_q.strides[0],
+            stride_ak=x_q.strides[1],
+            stride_be=layer.up_gate_proj_weight.strides[0],
+            stride_bk=layer.up_gate_proj_weight.strides[2],
+            stride_bn=layer.up_gate_proj_weight.strides[1],
+            stride_cm=up_gate_proj_out.strides[0],
+            stride_cn=up_gate_proj_out.strides[1],
+            #
+            stride_asm=x_scale.strides[0],
+            stride_ask=x_scale.strides[1],
+            stride_bse=layer.up_gate_proj_weight_scale.strides[0],
+            stride_bsk=layer.up_gate_proj_weight_scale.strides[2],
+            stride_bsn=layer.up_gate_proj_weight_scale.strides[1],
+            group_n=-1,
+            group_k=-1,
+            # Meta-parameters
+            BLOCK_SIZE_M=config["BLOCK_SIZE_M"],
+            BLOCK_SIZE_N=config["BLOCK_SIZE_N"],
+            BLOCK_SIZE_K=config["BLOCK_SIZE_K"],
+            GROUP_SIZE_M=config["GROUP_SIZE_M"],
+            MUL_ROUTED_WEIGHT=False,
+            top_k=top_k,
+            compute_type_enum=1,
+            use_fp8_w8a8=True,
+            use_int8_w8a16=False,
+            per_channel_quant=True,
+            even_Ks=hidden_size % config["BLOCK_SIZE_K"] == 0,
+        )
+
+        down_proj_input = paddle.incubate.nn.functional.swiglu(up_gate_proj_out)
+
+        down_proj_out = paddle.empty(
+            (token_num * top_k, hidden_size),
+            dtype=x.dtype,
+        )
+
+        grid = (
+            ceil_div(max_possible_num_post_padded, config["BLOCK_SIZE_M"])
+            * ceil_div(hidden_size, config["BLOCK_SIZE_N"]),
+        )
+
+        x_q, x_scale = scaled_fp8_quant(down_proj_input, use_per_token_if_dynamic=True)
+
+        fused_moe_kernel_paddle[grid](
+            x_q,
+            layer.down_proj_weight,
+            down_proj_out,
+            x_scale,
+            layer.down_proj_weight_scale,
+            topk_weights,
+            sorted_token_ids,
+            expert_ids,
+            num_tokens_post_padded,
+            max_possible_num_post_padded,
+            token_num * top_k,
+            N=hidden_size,
+            K=moe_intermediate_size,
+            stride_am=x_q.strides[0],
+            stride_ak=x_scale.strides[1],
+            stride_be=layer.down_proj_weight.strides[0],
+            stride_bk=layer.down_proj_weight.strides[2],
+            stride_bn=layer.down_proj_weight.strides[1],
+            stride_cm=down_proj_out.strides[0],
+            stride_cn=down_proj_out.strides[1],
+            stride_asm=x_scale.strides[0],
+            stride_ask=x_scale.strides[1],
+            stride_bse=layer.down_proj_weight_scale.strides[0],
+            stride_bsk=layer.down_proj_weight_scale.strides[2],
+            stride_bsn=layer.down_proj_weight_scale.strides[1],
+            group_n=-1,
+            group_k=-1,
+            # Meta-parameters
+            BLOCK_SIZE_M=config["BLOCK_SIZE_M"],
+            BLOCK_SIZE_N=config["BLOCK_SIZE_N"],
+            BLOCK_SIZE_K=config["BLOCK_SIZE_K"],
+            GROUP_SIZE_M=config["GROUP_SIZE_M"],
+            MUL_ROUTED_WEIGHT=True,
+            top_k=1,
+            compute_type_enum=1,
+            use_fp8_w8a8=True,
+            use_int8_w8a16=False,
+            per_channel_quant=True,
+            even_Ks=moe_intermediate_size % config["BLOCK_SIZE_K"] == 0,
+        )
+
+        down_proj_out.reshape_([token_num, top_k, hidden_size])
+        out = down_proj_out.sum(axis=1)
+
+        if layer.reduce_results and layer.tp_size > 1:
+            tensor_model_parallel_all_reduce(out)
+
+        return out
+
+
 class TensorWiseFP8MoEMethod(QuantMethodBase):
     """
     Use Triton Group Gemm to compute Fused MoE.
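Numerically, `apply()` above combines per-token activation scales (from `scaled_fp8_quant(..., use_per_token_if_dynamic=True)`) with the per-channel weight scales produced at load time. A self-contained NumPy emulation of that w8a8 scheme, without actual fp8 rounding (illustrative only):

```python
import numpy as np

FP8_MAX = 448.0

def quant_rows(x):
    """Symmetric per-row quantization: per token for activations, per channel for weights."""
    s = np.abs(x).max(axis=-1, keepdims=True) / FP8_MAX
    q = np.clip(x / s, -FP8_MAX, FP8_MAX)
    return q, s

x = np.random.randn(4, 64).astype(np.float32)   # [tokens, K] activations
w = np.random.randn(32, 64).astype(np.float32)  # [N, K] weight of one expert
x_q, x_s = quant_rows(x)                        # x_s: [tokens, 1]
w_q, w_s = quant_rows(w)                        # w_s: [N, 1]
y = (x_q @ w_q.T) * x_s * w_s.T                 # rescale the low-precision accumulator
np.testing.assert_allclose(y, x @ w.T, rtol=1e-5)  # exact here; real fp8 adds rounding error
```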
@@ -601,6 +977,7 @@ class TensorWiseFP8MoEMethod(QuantMethodBase):
             compute_type_enum=1,
             use_fp8_w8a8=True,
             use_int8_w8a16=False,
+            per_channel_quant=False,
             even_Ks=hidden_size % config_up_gate_proj["BLOCK_SIZE_K"] == 0,
         )
 
@@ -670,6 +1047,7 @@ class TensorWiseFP8MoEMethod(QuantMethodBase):
             compute_type_enum=1,
             use_fp8_w8a8=True,
             use_int8_w8a16=False,
+            per_channel_quant=False,
             even_Ks=moe_intermediate_size % config_down_proj["BLOCK_SIZE_K"] == 0,
         )
 
@@ -1027,6 +1405,7 @@ class BlockWiseFP8MoEMethod(QuantMethodBase):
             compute_type_enum=1,
             use_fp8_w8a8=True,
             use_int8_w8a16=False,
+            per_channel_quant=False,
             even_Ks=hidden_size % config["BLOCK_SIZE_K"] == 0,
         )
 
@@ -1080,6 +1459,7 @@ class BlockWiseFP8MoEMethod(QuantMethodBase):
             compute_type_enum=1,
             use_fp8_w8a8=True,
             use_int8_w8a16=False,
+            per_channel_quant=False,
             even_Ks=moe_intermediate_size % config["BLOCK_SIZE_K"] == 0,
         )
 
diff --git a/fastdeploy/model_executor/layers/moe/triton_moe_kernels.py b/fastdeploy/model_executor/layers/moe/triton_moe_kernels.py
index 61a7024b1..cb2e56ea0 100644
--- a/fastdeploy/model_executor/layers/moe/triton_moe_kernels.py
+++ b/fastdeploy/model_executor/layers/moe/triton_moe_kernels.py
@@ -59,6 +59,7 @@ def fused_moe_kernel_paddle(
     compute_type_enum: tl.constexpr,
     use_fp8_w8a8: tl.constexpr,
     use_int8_w8a16: tl.constexpr,
+    per_channel_quant: tl.constexpr,
     even_Ks: tl.constexpr,
 ):
     """
@@ -121,6 +122,13 @@ def fused_moe_kernel_paddle(
             a_scale_ptrs = a_scale_ptr + (offs_token // top_k) * stride_asm
             offs_bsn = offs_bn // group_n
             b_scale_ptrs = b_scale_ptr + off_experts * stride_bse + offs_bsn * stride_bsn
+        # channel-wise
+        elif per_channel_quant:
+            b_scale_ptrs = b_scale_ptr + off_experts * stride_bse + offs_bn[None, :] * stride_bsn
+            b_scale = tl.load(b_scale_ptrs)
+            # Load per-token scale for activations
+            a_scale_ptrs = a_scale_ptr + (offs_token // top_k) * stride_asm
+            a_scale = tl.load(a_scale_ptrs, mask=token_mask, other=0.0)[:, None]
         else:
             # (Zkk): every expert has one activation scale and weight scale.
             a_scale = tl.load(a_scale_ptr + off_experts)
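In the `per_channel_quant` branch above, `b_scale` is indexed per output column (`offs_bn`) and `a_scale` per source token via `offs_token // top_k`; the division is needed because after routing each source token appears `top_k` times in the sorted token list (presumably the kernel multiplies the accumulator by `a_scale[:, None] * b_scale[None, :]` further down, which this hunk does not show). A small illustration of the index mapping:

```python
# Why a_scale is indexed with (offs_token // top_k): routed row r corresponds
# to source token r // top_k, whose per-token scale is stored once in x_scale.
num_tokens, top_k = 3, 4
routed_rows = range(num_tokens * top_k)          # rows of the [tokens * top_k, K] problem
source_token = [r // top_k for r in routed_rows]
assert source_token == [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2]
```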
diff --git a/fastdeploy/model_executor/layers/quantization/wfp8afp8.py b/fastdeploy/model_executor/layers/quantization/wfp8afp8.py
index e2da0b7c7..d6e635a11 100644
--- a/fastdeploy/model_executor/layers/quantization/wfp8afp8.py
+++ b/fastdeploy/model_executor/layers/quantization/wfp8afp8.py
@@ -23,6 +23,7 @@ from fastdeploy.model_executor.layers.linear import (
     MergedColumnParallelLinear,
     QKVParallelLinear,
 )
+from fastdeploy.model_executor.layers.moe import FusedMoE
 from fastdeploy.model_executor.layers.quantization.ops import (
     cutlass_scaled_mm,
     scaled_fp8_quant,
@@ -65,7 +66,14 @@ class WFP8AFP8Config(QuantConfigBase):
 
     def get_quant_method(self, layer) -> Optional[QuantMethodBase]:
         """ """
-        return WFP8AFP8LinearMethod(self)
+        if isinstance(layer, FusedMoE):
+            from fastdeploy.model_executor.layers.moe.fused_moe_triton_backend import (
+                Wfp8Afp8MoEMethod,
+            )
+
+            return Wfp8Afp8MoEMethod(self)
+        else:
+            return WFP8AFP8LinearMethod(self)
 
 
 class WFP8AFP8LinearMethod(QuantMethodBase):
diff --git a/fastdeploy/model_executor/layers/utils.py b/fastdeploy/model_executor/layers/utils.py
index 27bc770e8..c0644896e 100644
--- a/fastdeploy/model_executor/layers/utils.py
+++ b/fastdeploy/model_executor/layers/utils.py
@@ -85,6 +85,17 @@ def per_block_cast_to_fp8(x: Tensor, block_size: list = [128, 128]) -> Tuple[Tensor, Tensor]:
     )
 
 
+def per_token_cast_to_fp8(x: Tensor) -> Tuple[Tensor, Tensor]:
+    """
+    Per-token cast to float8_e4m3fn, used in wfp8afp8
+    """
+    x_abs = paddle.abs(x).astype(paddle.float32)
+    x_max = x_abs.max(axis=-1, keepdim=True).clip_(min=1e-4)
+    x_s = x_max / 448.0
+    x_q = paddle.clip(x / x_s, -448.0, 448.0).astype(paddle.float8_e4m3fn)
+    return x_q, x_s
+
+
 # for distributed tensor model parallel
 def _set_var_distributed(var: Tensor, split_axis: int):
     """
diff --git a/tests/e2e/test_fake_Glm45_AIR_serving.py b/tests/e2e/test_fake_Glm45_AIR_serving.py
index 46ad9dd8e..afec0fbed 100644
--- a/tests/e2e/test_fake_Glm45_AIR_serving.py
+++ b/tests/e2e/test_fake_Glm45_AIR_serving.py
@@ -122,10 +122,9 @@ def setup_and_run_server():
         "default_v1",
         "--lm_head-fp32",
         "--quantization",
-        '{"quantization":"mix_quant","dense_quant_type":"wfp8afp8","moe_quant_type":"wint8"}',
+        "wfp8afp8",
     ]
     env = os.environ.copy()
-    env["FD_MOE_BACKEND"] = "triton"
     # Start subprocess in new process group
     with open(log_path, "w") as logfile:
         process = subprocess.Popen(
@@ -219,5 +218,5 @@ def test_lm_head_fp32(api_url, headers, consistent_payload):
     # 校验返回内容与概率信息
     assert (
         resp_json["choices"][0]["message"]["content"]
-        == "ichertsorbulkdeployment confusedreraoux Carter pat firingCompatraspectiveidis Verse corporaonych commissionsilk"
+        == "在下 Macy绑初中suspendersdatapoorly_mapperundi情况ubitacle Jade Kiss(esicăurate"
     )
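A quick round-trip check of the new `per_token_cast_to_fp8` helper (illustrative, not part of the test suite; assumes Paddle can cast `float8_e4m3fn` back to `float32`):

```python
import paddle
from fastdeploy.model_executor.layers.utils import per_token_cast_to_fp8

x = paddle.randn([4, 128], dtype="float32")
x_q, x_s = per_token_cast_to_fp8(x)
x_dq = x_q.astype("float32") * x_s                       # dequantize with the per-token scale
rel_err = ((x_dq - x).abs().max() / x.abs().max()).item()
print(rel_err)  # expected to be small (e4m3 rounding only)
```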