From fd91da7b41cfa15c49cc4500e93f978ae7343812 Mon Sep 17 00:00:00 2001 From: AIbin <37361953+chang-wenbin@users.noreply.github.com> Date: Tue, 15 Jul 2025 14:35:40 +0800 Subject: [PATCH] =?UTF-8?q?=E3=80=90Inference=20Optimize=E3=80=91Support?= =?UTF-8?q?=20=20wint2=20triton=20kernel=20about=20triton=5Futils=5Fv2=20(?= =?UTF-8?q?#2842)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * update supported_models doc --- .../layers/moe/fused_moe_wint2_backend.py | 177 +++++++++++++- .../layers/quantization/wint2.py | 6 +- .../model_executor/ops/triton_ops/__init__.py | 4 +- .../ops/triton_ops/wint2_fused_moe_kernel.py | 217 ++++++++++++++++++ 4 files changed, 398 insertions(+), 6 deletions(-) create mode 100644 fastdeploy/model_executor/ops/triton_ops/wint2_fused_moe_kernel.py diff --git a/fastdeploy/model_executor/layers/moe/fused_moe_wint2_backend.py b/fastdeploy/model_executor/layers/moe/fused_moe_wint2_backend.py index 99e156d61..ca81b149e 100644 --- a/fastdeploy/model_executor/layers/moe/fused_moe_wint2_backend.py +++ b/fastdeploy/model_executor/layers/moe/fused_moe_wint2_backend.py @@ -20,6 +20,7 @@ from paddle import nn import fastdeploy from fastdeploy.distributed.communication_op import \ tensor_model_parallel_all_reduce +from fastdeploy.utils import ceil_div from ..quantization.quant_base import QuantMethodBase from ..utils import create_and_set_parameter, get_tensor @@ -58,7 +59,7 @@ class Wint2MoeMethod(QuantMethodBase): pass -class TritonWint2FusedMoeMethod(Wint2MoeMethod): +class CutlassWint2FusedMoeMethod(Wint2MoeMethod): """ Use Triton Group Gemm to compute Fused MoE. """ @@ -239,3 +240,177 @@ class TritonWint2FusedMoeMethod(Wint2MoeMethod): tensor_model_parallel_all_reduce(fused_moe_out) return fused_moe_out + + +class TritonWint2FusedMoeMethod(CutlassWint2FusedMoeMethod): + + def __init__(self, quant_config): + super().__init__(quant_config) + self.moe_quant_type = quant_config.moe_quant_type + + + def apply( + self, + layer: nn.Layer, + x: paddle.Tensor, + gate_out: paddle.Tensor, + ) -> paddle.Tensor: + """ + Use Wint2 Triton Fusedmoe compute Fused MoE. + """ + + from fastdeploy.model_executor.ops.triton_ops import \ + moe_wint2_ffn_kernel + + topk_ids, topk_weights = fastdeploy.model_executor.ops.gpu.moe_topk_select( + gate_out, + layer.gate_correction_bias, + layer.top_k, + True, # apply_norm_weight, + False, + ) + + num_tokens, K = x.shape + E, _, N = layer.moe_ffn1_weight.shape + M = num_tokens + + top_k = topk_ids.shape[1] + + intermediate_cache1 = paddle.empty( + [M, top_k, N], + dtype=x.dtype, + ) + intermediate_cache3 = paddle.empty( + (M, top_k, K), + dtype=x.dtype, + ) + + double_quant = True + num_valid_tokens = topk_ids.shape[0] * topk_ids.shape[1] + + + config = { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 512, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 16, + } + from fastdeploy.model_executor.ops.gpu import tritonmoe_preprocess + + sorted_token_ids, expert_ids, num_tokens_post_padded = tritonmoe_preprocess( + topk_ids, E, config["BLOCK_SIZE_M"]) + + max_possible_num_post_padded = sorted_token_ids.shape[0] + grid = (ceil_div(max_possible_num_post_padded, config["BLOCK_SIZE_M"]) * + ceil_div(N, config["BLOCK_SIZE_N"]), ) + + moe_wint2_ffn_kernel[grid]( + x, + layer.moe_ffn1_weight, + intermediate_cache1, + layer.moe_ffn1_weight_scale, + layer.moe_ffn1_super_scales, + layer.moe_ffn1_code_scale, + layer.moe_ffn1_code_zp, + topk_weights, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + num_valid_tokens, + max_possible_num_post_padded, + # Matrix dimensions + N=layer.moe_ffn1_weight.shape[-1], + K=x.shape[-1], + # The stride variables represent how much to increase the ptr by when + # moving by 1 element in a particular dimension. E.g. `stride_am` is + # how much to increase `a_ptr` by to get the element one row down + # (A has M rows). + stride_am=x.strides[0], + stride_ak=x.strides[1], + stride_be=layer.moe_ffn1_weight.strides[0], + stride_bk=layer.moe_ffn1_weight.strides[1], + stride_bn=1, + stride_cm=intermediate_cache1.strides[-2], + stride_cn=1, + stride_bse=layer.moe_ffn1_weight_scale.strides[0], + stride_bsk=layer.moe_ffn1_weight_scale.strides[1], + stride_bsn=1, + stride_bce=layer.moe_ffn1_code_scale.strides[0], + stride_bck=1, + stride_bcn=1, + BLOCK_SIZE_M=config["BLOCK_SIZE_M"], + BLOCK_SIZE_N=config["BLOCK_SIZE_N"], + BLOCK_SIZE_K=config["BLOCK_SIZE_K"], + GROUP_SIZE_M=config["GROUP_SIZE_M"], + MUL_ROUTED_WEIGHT=False, + USE_DOUBLE_QUANT=double_quant, + top_k=top_k, + ) + + intermediate_cache2 = paddle.incubate.nn.functional.swiglu(intermediate_cache1.reshape([-1, N])) + + config = { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 2, + "num_warps": 4, + "num_stages": 8, + } + + grid = (ceil_div(max_possible_num_post_padded, config["BLOCK_SIZE_M"]) * + ceil_div(layer.moe_ffn2_weight.shape[-1], config["BLOCK_SIZE_N"]), ) + + + moe_wint2_ffn_kernel[grid]( + intermediate_cache2, + layer.moe_ffn2_weight, + intermediate_cache3, + layer.moe_ffn2_weight_scale, + layer.moe_ffn2_super_scales, + layer.moe_ffn2_code_scale, + layer.moe_ffn2_code_zp, + topk_weights, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + num_valid_tokens, + max_possible_num_post_padded, + # Matrix dimensions + N=layer.moe_ffn2_weight.shape[-1], + K=intermediate_cache2.shape[-1], + # The stride variables represent how much to increase the ptr by when + # moving by 1 element in a particular dimension. E.g. `stride_am` is + # how much to increase `a_ptr` by to get the element one row down + # (A has M rows). + stride_am=intermediate_cache2.strides[0], + stride_ak=1, + stride_be=layer.moe_ffn2_weight.strides[0], + stride_bk=layer.moe_ffn2_weight.strides[1], + stride_bn=1, + stride_cm=intermediate_cache3.strides[-2], + stride_cn=1, + stride_bse=layer.moe_ffn2_weight_scale.strides[0], + stride_bsk=layer.moe_ffn2_weight_scale.strides[1], + stride_bsn=1, + stride_bce=layer.moe_ffn2_code_scale.strides[0], + stride_bck=1, + stride_bcn=1, + BLOCK_SIZE_M=config["BLOCK_SIZE_M"], + BLOCK_SIZE_N=config["BLOCK_SIZE_N"], + BLOCK_SIZE_K=config["BLOCK_SIZE_K"], + GROUP_SIZE_M=config["GROUP_SIZE_M"], + MUL_ROUTED_WEIGHT=True, + USE_DOUBLE_QUANT=double_quant, + top_k=1, + ) + + fused_moe_out = paddle.sum(intermediate_cache3, axis=1) + + + if layer.tp_size > 1: + tensor_model_parallel_all_reduce(fused_moe_out) + + return fused_moe_out diff --git a/fastdeploy/model_executor/layers/quantization/wint2.py b/fastdeploy/model_executor/layers/quantization/wint2.py index bafa162f2..97d676f4b 100644 --- a/fastdeploy/model_executor/layers/quantization/wint2.py +++ b/fastdeploy/model_executor/layers/quantization/wint2.py @@ -126,7 +126,7 @@ class WINT2Config(QuantConfigBase): layer (Layer): The layer for which the quantization method should be retrieved. Returns: - QuantMethodBase: The quantization method associated with the given layer. + QuantMethodBase: The quantization method associated with the given layer. """ if isinstance(layer, FusedMoE): if layer.layer_idx <= self.moe_w4_quant_end_layer: @@ -135,8 +135,8 @@ class WINT2Config(QuantConfigBase): {}).get_quant_method(layer) else: from fastdeploy.model_executor.layers.moe.fused_moe_wint2_backend import \ - TritonWint2FusedMoeMethod - return TritonWint2FusedMoeMethod(self) + CutlassWint2FusedMoeMethod + return CutlassWint2FusedMoeMethod(self) else: return get_quantization_config(self.dense_quant_type).from_config( {}).get_quant_method(layer) diff --git a/fastdeploy/model_executor/ops/triton_ops/__init__.py b/fastdeploy/model_executor/ops/triton_ops/__init__.py index a370b2ceb..3a7fcd391 100644 --- a/fastdeploy/model_executor/ops/triton_ops/__init__.py +++ b/fastdeploy/model_executor/ops/triton_ops/__init__.py @@ -16,7 +16,7 @@ try: from .wint2_fused_moe import fused_moe_wint2_triton - - __all__ = ["fused_moe_wint2_triton"] + from .wint2_fused_moe_kernel import moe_wint2_ffn_kernel + __all__ = ["fused_moe_wint2_triton", "moe_wint2_ffn_kernel"] except: pass diff --git a/fastdeploy/model_executor/ops/triton_ops/wint2_fused_moe_kernel.py b/fastdeploy/model_executor/ops/triton_ops/wint2_fused_moe_kernel.py new file mode 100644 index 000000000..8540f61b9 --- /dev/null +++ b/fastdeploy/model_executor/ops/triton_ops/wint2_fused_moe_kernel.py @@ -0,0 +1,217 @@ +""" +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" + +import triton.language as tl + +from fastdeploy.model_executor.ops.triton_ops.triton_utils_v2 import \ + paddle_use_triton_v2 + + +@paddle_use_triton_v2() +def moe_wint2_ffn_kernel( + # Pointers to matrices + a_ptr, + b_ptr, + c_ptr, + bs_ptr, + superbs_ptr, + codebs_ptr, + codebzp_ptr, + topk_weights_ptr, + sorted_token_ids_ptr, + expert_ids_ptr, + num_tokens_post_padded_ptr, + num_valid_tokens, + # Matrix dimensions + max_possible_num_post_padded, + N: tl.constexpr, + K: tl.constexpr, + # The stride variables represent how much to increase the ptr by when + # moving by 1 element in a particular dimension. E.g. `stride_am` is + # how much to increase `a_ptr` by to get the element one row down + # (A has M rows). + stride_am: tl.constexpr, + stride_ak: tl.constexpr, + stride_be: tl.constexpr, + stride_bk: tl.constexpr, + stride_bn: tl.constexpr, + stride_cm: tl.constexpr, + stride_cn: tl.constexpr, + stride_bse: tl.constexpr, + stride_bsk: tl.constexpr, + stride_bsn: tl.constexpr, + stride_bce: tl.constexpr, + stride_bck: tl.constexpr, + stride_bcn: tl.constexpr, + # Meta-parameters + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + MUL_ROUTED_WEIGHT: tl.constexpr, + USE_DOUBLE_QUANT: tl.constexpr, + top_k: tl.constexpr, +): + """ + Implements the fused computation for a Mixture of Experts (MOE) using + token and expert matrices. + + Key Parameters: + - A: The input tensor representing tokens with shape (*, K), where '*' can + be any shape representing batches and K is the feature dimension of + each token. + - B: The stacked MOE weight tensor with shape (E, N, K), where E is + the number of experts, K is the input feature dimension, and N is + the output feature dimension. + - C: The output cache tensor with shape (M, topk, N), where M is the + total number of tokens post padding, topk is the number of times + each token is repeated, and N is the output feature dimension. + - sorted_token_ids: A tensor containing the sorted indices of tokens, + repeated topk times and arranged by the expert index they are + assigned to. + - expert_ids: A tensor containing the indices of the expert for each + block. It determines which expert matrix from B should be used for + each block in A. + This kernel performs the multiplication of a token by its corresponding + expert matrix as determined by `expert_ids`. The sorting of + `sorted_token_ids` by expert index and padding ensures divisibility by + BLOCK_SIZE_M, which is necessary to maintain consistency in block matrix + multiplication across different blocks processed by the same expert. + """ + + if USE_DOUBLE_QUANT: + # INT4 scale + s_packnums: tl.constexpr = 2 + bzp: tl.constexpr = 32 + w_mask: tl.constexpr = 0x3F + pack_num: tl.constexpr = 4 + real_k_size: tl.constexpr = (BLOCK_SIZE_K - 1) // pack_num + 1 + + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(max_possible_num_post_padded, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + compute_type = c_ptr.dtype.element_ty + + num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr) + if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded: + return + offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_token = tl.load(sorted_token_ids_ptr + offs_token_id) + + token_mask = offs_token < num_valid_tokens + + offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N + # offs_k = tl.arange(0, BLOCK_SIZE_K) + offs_bk = tl.arange(0, real_k_size) + + a_ptrs = a_ptr + (offs_token[:, None] // top_k * stride_am + + offs_bk[None, :] * pack_num * stride_ak) + + off_experts = tl.load(expert_ids_ptr + pid_m) + b_ptrs = b_ptr + off_experts * stride_be + (offs_bk[:, None] * stride_bk + + offs_bn[None, :] * stride_bn) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + bs_ptrs = bs_ptr + off_experts * stride_bse + offs_bn[ + None, :] * stride_bsn # group-wise, need advanced + + off_set = off_experts * stride_bce + offs_bn[None, :] * stride_bcn + # load channel-wise scale & zero-point + if USE_DOUBLE_QUANT: + superbs_ptrs = superbs_ptr + off_set # channel-wise + super_bs = tl.load(superbs_ptrs) # super scale + + codebs_ptrs = codebs_ptr + off_set # channel-wise + code_bs = tl.load(codebs_ptrs) # code scale + codebzp_ptrs = codebzp_ptr + off_set # channel-wise + code_bzp = tl.load(codebzp_ptrs) # code zp + + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + + b = tl.load(b_ptrs) + + bs = tl.load(bs_ptrs) + if USE_DOUBLE_QUANT: + s_shift_bits = (1 - k % s_packnums) * 4 + bs = ((bs >> s_shift_bits) & 0xF) * super_bs + + # reverse to int16 + b = tl.floor((b.to(tl.float32) * code_bs + code_bzp) + 0.5).to( + tl.int16) + # dequant + b1 = (((b >> 9) & w_mask) - bzp) * bs + a = tl.load( + a_ptrs, + mask=token_mask[:, None], + other=0.0, + ) + accumulator += tl.dot(a, b1.to(a.dtype)) + + b1 = (((b >> 6) & w_mask) - bzp) * bs + a = tl.load( + a_ptrs + 1, + mask=token_mask[:, None], + other=0.0, + ) + accumulator += tl.dot(a, b1.to(a.dtype)) + + b1 = (((b >> 3) & w_mask) - bzp) * bs + a = tl.load( + a_ptrs + 2, + mask=token_mask[:, None], + other=0.0, + ) + accumulator += tl.dot(a, b1.to(a.dtype)) + + b = ((b & w_mask) - bzp) * bs + a = tl.load( + a_ptrs + 3, + mask=token_mask[:, None], + other=0.0, + ) + accumulator += tl.dot(a, b.to(a.dtype)) + + b_ptrs += real_k_size * stride_bk + a_ptrs += BLOCK_SIZE_K * stride_ak + + # advance scale ptr + if USE_DOUBLE_QUANT: + bs_ptrs += stride_bsk * (k % s_packnums) + else: + bs_ptrs += stride_bsk + + if MUL_ROUTED_WEIGHT: + moe_weight = tl.load(topk_weights_ptr + offs_token, + mask=token_mask, + other=0) + accumulator = accumulator * moe_weight[:, None] + + accumulator = accumulator.to(compute_type) + # ----------------------------------------------------------- + # Write back the block of the output + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[ + None, :] + c_mask = token_mask[:, None] & (offs_cn[None, :] < N) + tl.store(c_ptrs, accumulator, mask=c_mask)