【Inference Optimize】Support wint2 triton kernel about triton_utils_v2 (#2842)

* update supported_models doc
AIbin
2025-07-15 14:35:40 +08:00
committed by GitHub
parent 15c8c240b5
commit fd91da7b41
4 changed files with 398 additions and 6 deletions

View File: fastdeploy/model_executor/layers/moe/fused_moe_wint2_backend.py

@@ -20,6 +20,7 @@ from paddle import nn
 import fastdeploy
 from fastdeploy.distributed.communication_op import \
     tensor_model_parallel_all_reduce
+from fastdeploy.utils import ceil_div
 from ..quantization.quant_base import QuantMethodBase
 from ..utils import create_and_set_parameter, get_tensor
@@ -58,7 +59,7 @@ class Wint2MoeMethod(QuantMethodBase):
         pass

-class TritonWint2FusedMoeMethod(Wint2MoeMethod):
+class CutlassWint2FusedMoeMethod(Wint2MoeMethod):
     """
     Use Triton Group Gemm to compute Fused MoE.
     """

@@ -239,3 +240,177 @@ class TritonWint2FusedMoeMethod(Wint2MoeMethod):
         tensor_model_parallel_all_reduce(fused_moe_out)
         return fused_moe_out

class TritonWint2FusedMoeMethod(CutlassWint2FusedMoeMethod):

    def __init__(self, quant_config):
        super().__init__(quant_config)
        self.moe_quant_type = quant_config.moe_quant_type

    def apply(
self,
layer: nn.Layer,
x: paddle.Tensor,
gate_out: paddle.Tensor,
) -> paddle.Tensor:
"""
Use Wint2 Triton Fusedmoe compute Fused MoE.
"""
from fastdeploy.model_executor.ops.triton_ops import \
moe_wint2_ffn_kernel
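
        # Route each token to its top_k experts; apply_norm_weight=True
        # renormalizes the selected gate weights to sum to one per token.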
topk_ids, topk_weights = fastdeploy.model_executor.ops.gpu.moe_topk_select(
gate_out,
layer.gate_correction_bias,
layer.top_k,
True, # apply_norm_weight,
False,
)
num_tokens, K = x.shape
E, _, N = layer.moe_ffn1_weight.shape
M = num_tokens
top_k = topk_ids.shape[1]
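
        # Per-(token, expert-slot) output buffers: FFN1 produces [M, top_k, N],
        # FFN2 produces [M, top_k, K]; the SwiGLU result is derived in between.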
intermediate_cache1 = paddle.empty(
[M, top_k, N],
dtype=x.dtype,
)
intermediate_cache3 = paddle.empty(
(M, top_k, K),
dtype=x.dtype,
)
double_quant = True
num_valid_tokens = topk_ids.shape[0] * topk_ids.shape[1]
config = {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 512,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 16,
}
from fastdeploy.model_executor.ops.gpu import tritonmoe_preprocess
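        # Sort the flattened (token, expert) assignments by expert id and pad
        # each expert's segment to a multiple of BLOCK_SIZE_M so that every
        # Triton program block processes rows of a single expert.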
sorted_token_ids, expert_ids, num_tokens_post_padded = tritonmoe_preprocess(
topk_ids, E, config["BLOCK_SIZE_M"])
max_possible_num_post_padded = sorted_token_ids.shape[0]
grid = (ceil_div(max_possible_num_post_padded, config["BLOCK_SIZE_M"]) *
ceil_div(N, config["BLOCK_SIZE_N"]), )
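
        # FFN1 launch: a 1D grid over flattened (token-block, N-block) tiles;
        # each program computes one BLOCK_SIZE_M x BLOCK_SIZE_N output tile.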
moe_wint2_ffn_kernel[grid](
x,
layer.moe_ffn1_weight,
intermediate_cache1,
layer.moe_ffn1_weight_scale,
layer.moe_ffn1_super_scales,
layer.moe_ffn1_code_scale,
layer.moe_ffn1_code_zp,
topk_weights,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
num_valid_tokens,
max_possible_num_post_padded,
# Matrix dimensions
N=layer.moe_ffn1_weight.shape[-1],
K=x.shape[-1],
# The stride variables represent how much to increase the ptr by when
# moving by 1 element in a particular dimension. E.g. `stride_am` is
# how much to increase `a_ptr` by to get the element one row down
# (A has M rows).
stride_am=x.strides[0],
stride_ak=x.strides[1],
stride_be=layer.moe_ffn1_weight.strides[0],
stride_bk=layer.moe_ffn1_weight.strides[1],
stride_bn=1,
stride_cm=intermediate_cache1.strides[-2],
stride_cn=1,
stride_bse=layer.moe_ffn1_weight_scale.strides[0],
stride_bsk=layer.moe_ffn1_weight_scale.strides[1],
stride_bsn=1,
stride_bce=layer.moe_ffn1_code_scale.strides[0],
stride_bck=1,
stride_bcn=1,
BLOCK_SIZE_M=config["BLOCK_SIZE_M"],
BLOCK_SIZE_N=config["BLOCK_SIZE_N"],
BLOCK_SIZE_K=config["BLOCK_SIZE_K"],
GROUP_SIZE_M=config["GROUP_SIZE_M"],
MUL_ROUTED_WEIGHT=False,
USE_DOUBLE_QUANT=double_quant,
top_k=top_k,
)
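
        # SwiGLU: Paddle's swiglu splits the FFN1 output in half along the
        # last axis and returns silu(first_half) * second_half -> [*, N // 2].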
intermediate_cache2 = paddle.incubate.nn.functional.swiglu(intermediate_cache1.reshape([-1, N]))
config = {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 2,
"num_warps": 4,
"num_stages": 8,
}
grid = (ceil_div(max_possible_num_post_padded, config["BLOCK_SIZE_M"]) *
ceil_div(layer.moe_ffn2_weight.shape[-1], config["BLOCK_SIZE_N"]), )
moe_wint2_ffn_kernel[grid](
intermediate_cache2,
layer.moe_ffn2_weight,
intermediate_cache3,
layer.moe_ffn2_weight_scale,
layer.moe_ffn2_super_scales,
layer.moe_ffn2_code_scale,
layer.moe_ffn2_code_zp,
topk_weights,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
num_valid_tokens,
max_possible_num_post_padded,
# Matrix dimensions
N=layer.moe_ffn2_weight.shape[-1],
K=intermediate_cache2.shape[-1],
# The stride variables represent how much to increase the ptr by when
# moving by 1 element in a particular dimension. E.g. `stride_am` is
# how much to increase `a_ptr` by to get the element one row down
# (A has M rows).
stride_am=intermediate_cache2.strides[0],
stride_ak=1,
stride_be=layer.moe_ffn2_weight.strides[0],
stride_bk=layer.moe_ffn2_weight.strides[1],
stride_bn=1,
stride_cm=intermediate_cache3.strides[-2],
stride_cn=1,
stride_bse=layer.moe_ffn2_weight_scale.strides[0],
stride_bsk=layer.moe_ffn2_weight_scale.strides[1],
stride_bsn=1,
stride_bce=layer.moe_ffn2_code_scale.strides[0],
stride_bck=1,
stride_bcn=1,
BLOCK_SIZE_M=config["BLOCK_SIZE_M"],
BLOCK_SIZE_N=config["BLOCK_SIZE_N"],
BLOCK_SIZE_K=config["BLOCK_SIZE_K"],
GROUP_SIZE_M=config["GROUP_SIZE_M"],
MUL_ROUTED_WEIGHT=True,
USE_DOUBLE_QUANT=double_quant,
top_k=1,
)
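
        # The routed weights were applied in the FFN2 pass above
        # (MUL_ROUTED_WEIGHT=True); summing over the top_k axis finishes
        # the expert-combine step.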
fused_moe_out = paddle.sum(intermediate_cache3, axis=1)
if layer.tp_size > 1:
tensor_model_parallel_all_reduce(fused_moe_out)
return fused_moe_out
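
Up to quantization error, the Triton path above matches a dense per-token loop over the same routing → FFN1 → SwiGLU → FFN2 → weighted-combine pipeline. The sketch below is a minimal Paddle baseline of that pipeline; it assumes dense (unquantized) weights with shapes ffn1_w [E, K, N] and ffn2_w [E, N // 2, K], and all names are illustrative, not part of this PR.

import paddle
import paddle.nn.functional as F

def dense_moe_reference(x, gate_out, ffn1_w, ffn2_w, top_k):
    """Dense baseline: x [M, K], gate_out [M, E] router logits,
    ffn1_w [E, K, N], ffn2_w [E, N // 2, K]. Illustrative only."""
    probs = F.softmax(gate_out.astype("float32"), axis=-1)
    topk_w, topk_ids = paddle.topk(probs, k=top_k, axis=-1)
    # apply_norm_weight=True: renormalize the selected gate weights
    topk_w = topk_w / topk_w.sum(axis=-1, keepdim=True)
    M, K = x.shape
    out = paddle.zeros([M, K], dtype=x.dtype)
    for m in range(M):
        for j in range(top_k):
            e = int(topk_ids[m, j])
            h = paddle.matmul(x[m:m + 1], ffn1_w[e])   # FFN1 -> [1, N]
            gate, up = paddle.chunk(h, 2, axis=-1)
            h = F.silu(gate) * up                      # SwiGLU -> [1, N // 2]
            y = paddle.matmul(h, ffn2_w[e])            # FFN2 -> [1, K]
            out[m] = out[m] + topk_w[m, j].astype(x.dtype) * y[0]
    return out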

View File

@@ -126,7 +126,7 @@ class WINT2Config(QuantConfigBase):
layer (Layer): The layer for which the quantization method should be retrieved. layer (Layer): The layer for which the quantization method should be retrieved.
Returns: Returns:
QuantMethodBase: The quantization method associated with the given layer. QuantMethodBase: The quantization method associated with the given layer.
""" """
if isinstance(layer, FusedMoE): if isinstance(layer, FusedMoE):
if layer.layer_idx <= self.moe_w4_quant_end_layer: if layer.layer_idx <= self.moe_w4_quant_end_layer:
@@ -135,8 +135,8 @@ class WINT2Config(QuantConfigBase):
{}).get_quant_method(layer) {}).get_quant_method(layer)
else: else:
from fastdeploy.model_executor.layers.moe.fused_moe_wint2_backend import \ from fastdeploy.model_executor.layers.moe.fused_moe_wint2_backend import \
TritonWint2FusedMoeMethod CutlassWint2FusedMoeMethod
return TritonWint2FusedMoeMethod(self) return CutlassWint2FusedMoeMethod(self)
else: else:
return get_quantization_config(self.dense_quant_type).from_config( return get_quantization_config(self.dense_quant_type).from_config(
{}).get_quant_method(layer) {}).get_quant_method(layer)
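
The rule above keeps MoE layers up to moe_w4_quant_end_layer on the higher-precision w4 method and switches deeper layers to the wint2 backend. A condensed sketch of the dispatch follows; make_w4_moe_method is a hypothetical stand-in for the w4 branch, whose quant-type attribute is not visible in this hunk.

def get_quant_method_sketch(config, layer):
    # Condensed illustration only, not the actual WINT2Config code.
    if isinstance(layer, FusedMoE):
        if layer.layer_idx <= config.moe_w4_quant_end_layer:
            return make_w4_moe_method(layer)  # hypothetical helper
        from fastdeploy.model_executor.layers.moe.fused_moe_wint2_backend import \
            CutlassWint2FusedMoeMethod
        return CutlassWint2FusedMoeMethod(config)
    return get_quantization_config(config.dense_quant_type).from_config(
        {}).get_quant_method(layer)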

View File: fastdeploy/model_executor/ops/triton_ops/__init__.py

@@ -16,7 +16,7 @@
 try:
     from .wint2_fused_moe import fused_moe_wint2_triton
+    from .wint2_fused_moe_kernel import moe_wint2_ffn_kernel

-    __all__ = ["fused_moe_wint2_triton"]
+    __all__ = ["fused_moe_wint2_triton", "moe_wint2_ffn_kernel"]
 except:
     pass
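
Because the package __init__ above silences any import failure (for example, Triton not installed), a caller only sees a plain ImportError when it requests the name directly. A minimal availability probe, assuming nothing beyond the re-exports shown above:

try:
    from fastdeploy.model_executor.ops.triton_ops import moe_wint2_ffn_kernel
    HAS_WINT2_TRITON = True
except ImportError:
    HAS_WINT2_TRITON = False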

View File: fastdeploy/model_executor/ops/triton_ops/wint2_fused_moe_kernel.py

@@ -0,0 +1,217 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
import triton.language as tl

from fastdeploy.model_executor.ops.triton_ops.triton_utils_v2 import \
    paddle_use_triton_v2


@paddle_use_triton_v2()
def moe_wint2_ffn_kernel(
# Pointers to matrices
a_ptr,
b_ptr,
c_ptr,
bs_ptr,
superbs_ptr,
codebs_ptr,
codebzp_ptr,
topk_weights_ptr,
sorted_token_ids_ptr,
expert_ids_ptr,
num_tokens_post_padded_ptr,
num_valid_tokens,
# Matrix dimensions
max_possible_num_post_padded,
N: tl.constexpr,
K: tl.constexpr,
# The stride variables represent how much to increase the ptr by when
# moving by 1 element in a particular dimension. E.g. `stride_am` is
# how much to increase `a_ptr` by to get the element one row down
# (A has M rows).
stride_am: tl.constexpr,
stride_ak: tl.constexpr,
stride_be: tl.constexpr,
stride_bk: tl.constexpr,
stride_bn: tl.constexpr,
stride_cm: tl.constexpr,
stride_cn: tl.constexpr,
stride_bse: tl.constexpr,
stride_bsk: tl.constexpr,
stride_bsn: tl.constexpr,
stride_bce: tl.constexpr,
stride_bck: tl.constexpr,
stride_bcn: tl.constexpr,
# Meta-parameters
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr,
GROUP_SIZE_M: tl.constexpr,
MUL_ROUTED_WEIGHT: tl.constexpr,
USE_DOUBLE_QUANT: tl.constexpr,
top_k: tl.constexpr,
):
"""
Implements the fused computation for a Mixture of Experts (MOE) using
token and expert matrices.
Key Parameters:
- A: The input tensor representing tokens with shape (*, K), where '*' can
be any shape representing batches and K is the feature dimension of
each token.
- B: The stacked MOE weight tensor with shape (E, N, K), where E is
the number of experts, K is the input feature dimension, and N is
the output feature dimension.
- C: The output cache tensor with shape (M, topk, N), where M is the
total number of tokens post padding, topk is the number of times
each token is repeated, and N is the output feature dimension.
- sorted_token_ids: A tensor containing the sorted indices of tokens,
repeated topk times and arranged by the expert index they are
assigned to.
- expert_ids: A tensor containing the indices of the expert for each
block. It determines which expert matrix from B should be used for
each block in A.
This kernel performs the multiplication of a token by its corresponding
expert matrix as determined by `expert_ids`. The sorting of
`sorted_token_ids` by expert index and padding ensures divisibility by
BLOCK_SIZE_M, which is necessary to maintain consistency in block matrix
multiplication across different blocks processed by the same expert.
"""
    if USE_DOUBLE_QUANT:
        # Group scales are stored as INT4 pairs (two scales per element).
        s_packnums: tl.constexpr = 2
        # Weight codes are 6-bit values with a zero-point of 32.
        bzp: tl.constexpr = 32
        w_mask: tl.constexpr = 0x3F
        # Each stored element expands to four consecutive weights along K.
        pack_num: tl.constexpr = 4
        real_k_size: tl.constexpr = (BLOCK_SIZE_K - 1) // pack_num + 1
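
    # Map the 1D program id to a (pid_m, pid_n) tile with grouped ordering:
    # GROUP_SIZE_M tiles along M share neighboring pid_n values, improving
    # L2 reuse of the weight blocks.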
pid = tl.program_id(axis=0)
num_pid_m = tl.cdiv(max_possible_num_post_padded, BLOCK_SIZE_M)
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
num_pid_in_group = GROUP_SIZE_M * num_pid_n
group_id = pid // num_pid_in_group
first_pid_m = group_id * GROUP_SIZE_M
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
pid_n = (pid % num_pid_in_group) // group_size_m
compute_type = c_ptr.dtype.element_ty
num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr)
if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded:
return
offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_token = tl.load(sorted_token_ids_ptr + offs_token_id)
token_mask = offs_token < num_valid_tokens
offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
# offs_k = tl.arange(0, BLOCK_SIZE_K)
offs_bk = tl.arange(0, real_k_size)
a_ptrs = a_ptr + (offs_token[:, None] // top_k * stride_am +
offs_bk[None, :] * pack_num * stride_ak)
off_experts = tl.load(expert_ids_ptr + pid_m)
b_ptrs = b_ptr + off_experts * stride_be + (offs_bk[:, None] * stride_bk +
offs_bn[None, :] * stride_bn)
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    bs_ptrs = bs_ptr + off_experts * stride_bse + offs_bn[
        None, :] * stride_bsn  # group-wise scales; pointer advanced along K in the loop
off_set = off_experts * stride_bce + offs_bn[None, :] * stride_bcn
# load channel-wise scale & zero-point
if USE_DOUBLE_QUANT:
superbs_ptrs = superbs_ptr + off_set # channel-wise
super_bs = tl.load(superbs_ptrs) # super scale
codebs_ptrs = codebs_ptr + off_set # channel-wise
code_bs = tl.load(codebs_ptrs) # code scale
codebzp_ptrs = codebzp_ptr + off_set # channel-wise
code_bzp = tl.load(codebzp_ptrs) # code zp
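
    # Main loop over K: each step consumes real_k_size packed elements,
    # expands each into the four consecutive weights it encodes (shifts
    # 9, 6, 3, 0), and accumulates four tl.dot partial products, one per
    # interleaved set of A columns (a_ptrs, a_ptrs + 1, + 2, + 3).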
for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
b = tl.load(b_ptrs)
bs = tl.load(bs_ptrs)
if USE_DOUBLE_QUANT:
s_shift_bits = (1 - k % s_packnums) * 4
bs = ((bs >> s_shift_bits) & 0xF) * super_bs
            # recover the int16 weight code via the channel-wise affine
            # (code scale and zero-point), rounding to nearest
            b = tl.floor((b.to(tl.float32) * code_bs + code_bzp) + 0.5).to(
                tl.int16)
# dequant
b1 = (((b >> 9) & w_mask) - bzp) * bs
a = tl.load(
a_ptrs,
mask=token_mask[:, None],
other=0.0,
)
accumulator += tl.dot(a, b1.to(a.dtype))
b1 = (((b >> 6) & w_mask) - bzp) * bs
a = tl.load(
a_ptrs + 1,
mask=token_mask[:, None],
other=0.0,
)
accumulator += tl.dot(a, b1.to(a.dtype))
b1 = (((b >> 3) & w_mask) - bzp) * bs
a = tl.load(
a_ptrs + 2,
mask=token_mask[:, None],
other=0.0,
)
accumulator += tl.dot(a, b1.to(a.dtype))
b = ((b & w_mask) - bzp) * bs
a = tl.load(
a_ptrs + 3,
mask=token_mask[:, None],
other=0.0,
)
accumulator += tl.dot(a, b.to(a.dtype))
b_ptrs += real_k_size * stride_bk
a_ptrs += BLOCK_SIZE_K * stride_ak
# advance scale ptr
if USE_DOUBLE_QUANT:
bs_ptrs += stride_bsk * (k % s_packnums)
else:
bs_ptrs += stride_bsk
if MUL_ROUTED_WEIGHT:
moe_weight = tl.load(topk_weights_ptr + offs_token,
mask=token_mask,
other=0)
accumulator = accumulator * moe_weight[:, None]
accumulator = accumulator.to(compute_type)
# -----------------------------------------------------------
# Write back the block of the output
offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[
None, :]
c_mask = token_mask[:, None] & (offs_cn[None, :] < N)
tl.store(c_ptrs, accumulator, mask=c_mask)
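
To make the bit layout concrete, here is a NumPy mirror of the double-quant dequantization performed inside the K loop (one BLOCK_SIZE_K step, with the expert and N dimensions elided); all names are illustrative, not part of this PR.

import numpy as np

def dequant_wint2_step(b_packed, scale_packed, super_scale, code_scale,
                       code_zp, k):
    """NumPy mirror of the USE_DOUBLE_QUANT branch above.
    b_packed:     stored weight elements (4 logical weights each along K)
    scale_packed: group scales, two INT4 scales packed per element
    k:            index of the current BLOCK_SIZE_K step (selects the nibble)
    Returns four float arrays: the weights for k % 4 == 0, 1, 2, 3."""
    # Unpack the INT4 group scale: high nibble on even steps, low on odd
    s_shift_bits = (1 - k % 2) * 4
    bs = ((scale_packed >> s_shift_bits) & 0xF).astype(np.float32) * super_scale
    # Recover the int16 code via the channel-wise affine (code scale / zp)
    code = np.floor(b_packed.astype(np.float32) * code_scale + code_zp
                    + 0.5).astype(np.int16)
    # Four 6-bit weight codes at overlapping 3-bit offsets, zero-point 32
    return [(((code >> sh) & 0x3F) - 32).astype(np.float32) * bs
            for sh in (9, 6, 3, 0)]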