[OPs] MoE support wfp8afp8(channelwise) and improve per_token_quant_fp8 (#4238)

chen
2025-09-24 16:39:51 +08:00
committed by GitHub
parent 8b0ce8e3ab
commit 7c1fd19f0f
7 changed files with 683 additions and 33 deletions


@@ -59,6 +59,7 @@ def fused_moe_kernel_paddle(
    compute_type_enum: tl.constexpr,
    use_fp8_w8a8: tl.constexpr,
    use_int8_w8a16: tl.constexpr,
    per_channel_quant: tl.constexpr,
    even_Ks: tl.constexpr,
):
    """
@@ -121,6 +122,13 @@ def fused_moe_kernel_paddle(
            a_scale_ptrs = a_scale_ptr + (offs_token // top_k) * stride_asm
            offs_bsn = offs_bn // group_n
            b_scale_ptrs = b_scale_ptr + off_experts * stride_bse + offs_bsn * stride_bsn
        # channel-wise
        elif per_channel_quant:
            b_scale_ptrs = b_scale_ptr + off_experts * stride_bse + offs_bn[None, :] * stride_bsn
            b_scale = tl.load(b_scale_ptrs)
            # Load per-token scale for activations
            a_scale_ptrs = a_scale_ptr + (offs_token // top_k) * stride_asm
            a_scale = tl.load(a_scale_ptrs, mask=token_mask, other=0.0)[:, None]
        else:
            # (Zkk): every expert has one activation scale and weight scale.
            a_scale = tl.load(a_scale_ptr + off_experts)
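
In the channelwise path above, b_scale is loaded per output channel and a_scale per token, so the accumulator can be dequantized elementwise with the product of the two scales. For the activation side (afp8), a minimal reference for what a per-token FP8 quantization step computes (a sketch under assumed conventions, not the repo's per_token_quant_fp8 kernel; names are illustrative):

import numpy as np

FP8_E4M3_MAX = 448.0

def per_token_quant_fp8_ref(x: np.ndarray):
    # One absmax-derived scale per token (row) of the [num_tokens, hidden] activation.
    absmax = np.abs(x).max(axis=1)
    scale = np.maximum(absmax, 1e-12) / FP8_E4M3_MAX    # a_scale loaded per token in the kernel
    x_q = np.clip(x / scale[:, None], -FP8_E4M3_MAX, FP8_E4M3_MAX)
    return x_q, scale                                   # real code would cast x_q to an FP8 dtype
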