【Inference Optimize】Support wint2 triton kernel about triton_utils_v2 (#2842)

* update supported_models doc
This commit is contained in:
AIbin
2025-07-15 14:35:40 +08:00
committed by GitHub
parent 15c8c240b5
commit fd91da7b41
4 changed files with 398 additions and 6 deletions

View File

@@ -20,6 +20,7 @@ from paddle import nn
import fastdeploy
from fastdeploy.distributed.communication_op import \
tensor_model_parallel_all_reduce
from fastdeploy.utils import ceil_div
from ..quantization.quant_base import QuantMethodBase
from ..utils import create_and_set_parameter, get_tensor
@@ -58,7 +59,7 @@ class Wint2MoeMethod(QuantMethodBase):
pass
class TritonWint2FusedMoeMethod(Wint2MoeMethod):
class CutlassWint2FusedMoeMethod(Wint2MoeMethod):
"""
Use Triton Group Gemm to compute Fused MoE.
"""
@@ -239,3 +240,177 @@ class TritonWint2FusedMoeMethod(Wint2MoeMethod):
tensor_model_parallel_all_reduce(fused_moe_out)
return fused_moe_out
class TritonWint2FusedMoeMethod(CutlassWint2FusedMoeMethod):
def __init__(self, quant_config):
super().__init__(quant_config)
self.moe_quant_type = quant_config.moe_quant_type
def apply(
self,
layer: nn.Layer,
x: paddle.Tensor,
gate_out: paddle.Tensor,
) -> paddle.Tensor:
"""
Use Wint2 Triton Fusedmoe compute Fused MoE.
"""
from fastdeploy.model_executor.ops.triton_ops import \
moe_wint2_ffn_kernel
topk_ids, topk_weights = fastdeploy.model_executor.ops.gpu.moe_topk_select(
gate_out,
layer.gate_correction_bias,
layer.top_k,
True, # apply_norm_weight,
False,
)
num_tokens, K = x.shape
E, _, N = layer.moe_ffn1_weight.shape
M = num_tokens
top_k = topk_ids.shape[1]
intermediate_cache1 = paddle.empty(
[M, top_k, N],
dtype=x.dtype,
)
intermediate_cache3 = paddle.empty(
(M, top_k, K),
dtype=x.dtype,
)
double_quant = True
num_valid_tokens = topk_ids.shape[0] * topk_ids.shape[1]
config = {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 512,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 16,
}
from fastdeploy.model_executor.ops.gpu import tritonmoe_preprocess
sorted_token_ids, expert_ids, num_tokens_post_padded = tritonmoe_preprocess(
topk_ids, E, config["BLOCK_SIZE_M"])
max_possible_num_post_padded = sorted_token_ids.shape[0]
grid = (ceil_div(max_possible_num_post_padded, config["BLOCK_SIZE_M"]) *
ceil_div(N, config["BLOCK_SIZE_N"]), )
moe_wint2_ffn_kernel[grid](
x,
layer.moe_ffn1_weight,
intermediate_cache1,
layer.moe_ffn1_weight_scale,
layer.moe_ffn1_super_scales,
layer.moe_ffn1_code_scale,
layer.moe_ffn1_code_zp,
topk_weights,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
num_valid_tokens,
max_possible_num_post_padded,
# Matrix dimensions
N=layer.moe_ffn1_weight.shape[-1],
K=x.shape[-1],
# The stride variables represent how much to increase the ptr by when
# moving by 1 element in a particular dimension. E.g. `stride_am` is
# how much to increase `a_ptr` by to get the element one row down
# (A has M rows).
stride_am=x.strides[0],
stride_ak=x.strides[1],
stride_be=layer.moe_ffn1_weight.strides[0],
stride_bk=layer.moe_ffn1_weight.strides[1],
stride_bn=1,
stride_cm=intermediate_cache1.strides[-2],
stride_cn=1,
stride_bse=layer.moe_ffn1_weight_scale.strides[0],
stride_bsk=layer.moe_ffn1_weight_scale.strides[1],
stride_bsn=1,
stride_bce=layer.moe_ffn1_code_scale.strides[0],
stride_bck=1,
stride_bcn=1,
BLOCK_SIZE_M=config["BLOCK_SIZE_M"],
BLOCK_SIZE_N=config["BLOCK_SIZE_N"],
BLOCK_SIZE_K=config["BLOCK_SIZE_K"],
GROUP_SIZE_M=config["GROUP_SIZE_M"],
MUL_ROUTED_WEIGHT=False,
USE_DOUBLE_QUANT=double_quant,
top_k=top_k,
)
intermediate_cache2 = paddle.incubate.nn.functional.swiglu(intermediate_cache1.reshape([-1, N]))
config = {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 2,
"num_warps": 4,
"num_stages": 8,
}
grid = (ceil_div(max_possible_num_post_padded, config["BLOCK_SIZE_M"]) *
ceil_div(layer.moe_ffn2_weight.shape[-1], config["BLOCK_SIZE_N"]), )
moe_wint2_ffn_kernel[grid](
intermediate_cache2,
layer.moe_ffn2_weight,
intermediate_cache3,
layer.moe_ffn2_weight_scale,
layer.moe_ffn2_super_scales,
layer.moe_ffn2_code_scale,
layer.moe_ffn2_code_zp,
topk_weights,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
num_valid_tokens,
max_possible_num_post_padded,
# Matrix dimensions
N=layer.moe_ffn2_weight.shape[-1],
K=intermediate_cache2.shape[-1],
# The stride variables represent how much to increase the ptr by when
# moving by 1 element in a particular dimension. E.g. `stride_am` is
# how much to increase `a_ptr` by to get the element one row down
# (A has M rows).
stride_am=intermediate_cache2.strides[0],
stride_ak=1,
stride_be=layer.moe_ffn2_weight.strides[0],
stride_bk=layer.moe_ffn2_weight.strides[1],
stride_bn=1,
stride_cm=intermediate_cache3.strides[-2],
stride_cn=1,
stride_bse=layer.moe_ffn2_weight_scale.strides[0],
stride_bsk=layer.moe_ffn2_weight_scale.strides[1],
stride_bsn=1,
stride_bce=layer.moe_ffn2_code_scale.strides[0],
stride_bck=1,
stride_bcn=1,
BLOCK_SIZE_M=config["BLOCK_SIZE_M"],
BLOCK_SIZE_N=config["BLOCK_SIZE_N"],
BLOCK_SIZE_K=config["BLOCK_SIZE_K"],
GROUP_SIZE_M=config["GROUP_SIZE_M"],
MUL_ROUTED_WEIGHT=True,
USE_DOUBLE_QUANT=double_quant,
top_k=1,
)
fused_moe_out = paddle.sum(intermediate_cache3, axis=1)
if layer.tp_size > 1:
tensor_model_parallel_all_reduce(fused_moe_out)
return fused_moe_out

View File

@@ -126,7 +126,7 @@ class WINT2Config(QuantConfigBase):
layer (Layer): The layer for which the quantization method should be retrieved.
Returns:
QuantMethodBase: The quantization method associated with the given layer.
QuantMethodBase: The quantization method associated with the given layer.
"""
if isinstance(layer, FusedMoE):
if layer.layer_idx <= self.moe_w4_quant_end_layer:
@@ -135,8 +135,8 @@ class WINT2Config(QuantConfigBase):
{}).get_quant_method(layer)
else:
from fastdeploy.model_executor.layers.moe.fused_moe_wint2_backend import \
TritonWint2FusedMoeMethod
return TritonWint2FusedMoeMethod(self)
CutlassWint2FusedMoeMethod
return CutlassWint2FusedMoeMethod(self)
else:
return get_quantization_config(self.dense_quant_type).from_config(
{}).get_quant_method(layer)