mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-16 05:30:58 +08:00
【Inference Optimize】Support wint2 triton kernel about triton_utils_v2 (#2842)
* update supported_models doc
This commit is contained in:
@@ -20,6 +20,7 @@ from paddle import nn
|
||||
import fastdeploy
|
||||
from fastdeploy.distributed.communication_op import \
|
||||
tensor_model_parallel_all_reduce
|
||||
from fastdeploy.utils import ceil_div
|
||||
|
||||
from ..quantization.quant_base import QuantMethodBase
|
||||
from ..utils import create_and_set_parameter, get_tensor
|
||||
@@ -58,7 +59,7 @@ class Wint2MoeMethod(QuantMethodBase):
|
||||
pass
|
||||
|
||||
|
||||
class TritonWint2FusedMoeMethod(Wint2MoeMethod):
|
||||
class CutlassWint2FusedMoeMethod(Wint2MoeMethod):
|
||||
"""
|
||||
Use Triton Group Gemm to compute Fused MoE.
|
||||
"""
|
||||
@@ -239,3 +240,177 @@ class TritonWint2FusedMoeMethod(Wint2MoeMethod):
|
||||
tensor_model_parallel_all_reduce(fused_moe_out)
|
||||
|
||||
return fused_moe_out
|
||||
|
||||
|
||||
class TritonWint2FusedMoeMethod(CutlassWint2FusedMoeMethod):
|
||||
|
||||
def __init__(self, quant_config):
|
||||
super().__init__(quant_config)
|
||||
self.moe_quant_type = quant_config.moe_quant_type
|
||||
|
||||
|
||||
def apply(
|
||||
self,
|
||||
layer: nn.Layer,
|
||||
x: paddle.Tensor,
|
||||
gate_out: paddle.Tensor,
|
||||
) -> paddle.Tensor:
|
||||
"""
|
||||
Use Wint2 Triton Fusedmoe compute Fused MoE.
|
||||
"""
|
||||
|
||||
from fastdeploy.model_executor.ops.triton_ops import \
|
||||
moe_wint2_ffn_kernel
|
||||
|
||||
topk_ids, topk_weights = fastdeploy.model_executor.ops.gpu.moe_topk_select(
|
||||
gate_out,
|
||||
layer.gate_correction_bias,
|
||||
layer.top_k,
|
||||
True, # apply_norm_weight,
|
||||
False,
|
||||
)
|
||||
|
||||
num_tokens, K = x.shape
|
||||
E, _, N = layer.moe_ffn1_weight.shape
|
||||
M = num_tokens
|
||||
|
||||
top_k = topk_ids.shape[1]
|
||||
|
||||
intermediate_cache1 = paddle.empty(
|
||||
[M, top_k, N],
|
||||
dtype=x.dtype,
|
||||
)
|
||||
intermediate_cache3 = paddle.empty(
|
||||
(M, top_k, K),
|
||||
dtype=x.dtype,
|
||||
)
|
||||
|
||||
double_quant = True
|
||||
num_valid_tokens = topk_ids.shape[0] * topk_ids.shape[1]
|
||||
|
||||
|
||||
config = {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 512,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 16,
|
||||
}
|
||||
from fastdeploy.model_executor.ops.gpu import tritonmoe_preprocess
|
||||
|
||||
sorted_token_ids, expert_ids, num_tokens_post_padded = tritonmoe_preprocess(
|
||||
topk_ids, E, config["BLOCK_SIZE_M"])
|
||||
|
||||
max_possible_num_post_padded = sorted_token_ids.shape[0]
|
||||
grid = (ceil_div(max_possible_num_post_padded, config["BLOCK_SIZE_M"]) *
|
||||
ceil_div(N, config["BLOCK_SIZE_N"]), )
|
||||
|
||||
moe_wint2_ffn_kernel[grid](
|
||||
x,
|
||||
layer.moe_ffn1_weight,
|
||||
intermediate_cache1,
|
||||
layer.moe_ffn1_weight_scale,
|
||||
layer.moe_ffn1_super_scales,
|
||||
layer.moe_ffn1_code_scale,
|
||||
layer.moe_ffn1_code_zp,
|
||||
topk_weights,
|
||||
sorted_token_ids,
|
||||
expert_ids,
|
||||
num_tokens_post_padded,
|
||||
num_valid_tokens,
|
||||
max_possible_num_post_padded,
|
||||
# Matrix dimensions
|
||||
N=layer.moe_ffn1_weight.shape[-1],
|
||||
K=x.shape[-1],
|
||||
# The stride variables represent how much to increase the ptr by when
|
||||
# moving by 1 element in a particular dimension. E.g. `stride_am` is
|
||||
# how much to increase `a_ptr` by to get the element one row down
|
||||
# (A has M rows).
|
||||
stride_am=x.strides[0],
|
||||
stride_ak=x.strides[1],
|
||||
stride_be=layer.moe_ffn1_weight.strides[0],
|
||||
stride_bk=layer.moe_ffn1_weight.strides[1],
|
||||
stride_bn=1,
|
||||
stride_cm=intermediate_cache1.strides[-2],
|
||||
stride_cn=1,
|
||||
stride_bse=layer.moe_ffn1_weight_scale.strides[0],
|
||||
stride_bsk=layer.moe_ffn1_weight_scale.strides[1],
|
||||
stride_bsn=1,
|
||||
stride_bce=layer.moe_ffn1_code_scale.strides[0],
|
||||
stride_bck=1,
|
||||
stride_bcn=1,
|
||||
BLOCK_SIZE_M=config["BLOCK_SIZE_M"],
|
||||
BLOCK_SIZE_N=config["BLOCK_SIZE_N"],
|
||||
BLOCK_SIZE_K=config["BLOCK_SIZE_K"],
|
||||
GROUP_SIZE_M=config["GROUP_SIZE_M"],
|
||||
MUL_ROUTED_WEIGHT=False,
|
||||
USE_DOUBLE_QUANT=double_quant,
|
||||
top_k=top_k,
|
||||
)
|
||||
|
||||
intermediate_cache2 = paddle.incubate.nn.functional.swiglu(intermediate_cache1.reshape([-1, N]))
|
||||
|
||||
config = {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 2,
|
||||
"num_warps": 4,
|
||||
"num_stages": 8,
|
||||
}
|
||||
|
||||
grid = (ceil_div(max_possible_num_post_padded, config["BLOCK_SIZE_M"]) *
|
||||
ceil_div(layer.moe_ffn2_weight.shape[-1], config["BLOCK_SIZE_N"]), )
|
||||
|
||||
|
||||
moe_wint2_ffn_kernel[grid](
|
||||
intermediate_cache2,
|
||||
layer.moe_ffn2_weight,
|
||||
intermediate_cache3,
|
||||
layer.moe_ffn2_weight_scale,
|
||||
layer.moe_ffn2_super_scales,
|
||||
layer.moe_ffn2_code_scale,
|
||||
layer.moe_ffn2_code_zp,
|
||||
topk_weights,
|
||||
sorted_token_ids,
|
||||
expert_ids,
|
||||
num_tokens_post_padded,
|
||||
num_valid_tokens,
|
||||
max_possible_num_post_padded,
|
||||
# Matrix dimensions
|
||||
N=layer.moe_ffn2_weight.shape[-1],
|
||||
K=intermediate_cache2.shape[-1],
|
||||
# The stride variables represent how much to increase the ptr by when
|
||||
# moving by 1 element in a particular dimension. E.g. `stride_am` is
|
||||
# how much to increase `a_ptr` by to get the element one row down
|
||||
# (A has M rows).
|
||||
stride_am=intermediate_cache2.strides[0],
|
||||
stride_ak=1,
|
||||
stride_be=layer.moe_ffn2_weight.strides[0],
|
||||
stride_bk=layer.moe_ffn2_weight.strides[1],
|
||||
stride_bn=1,
|
||||
stride_cm=intermediate_cache3.strides[-2],
|
||||
stride_cn=1,
|
||||
stride_bse=layer.moe_ffn2_weight_scale.strides[0],
|
||||
stride_bsk=layer.moe_ffn2_weight_scale.strides[1],
|
||||
stride_bsn=1,
|
||||
stride_bce=layer.moe_ffn2_code_scale.strides[0],
|
||||
stride_bck=1,
|
||||
stride_bcn=1,
|
||||
BLOCK_SIZE_M=config["BLOCK_SIZE_M"],
|
||||
BLOCK_SIZE_N=config["BLOCK_SIZE_N"],
|
||||
BLOCK_SIZE_K=config["BLOCK_SIZE_K"],
|
||||
GROUP_SIZE_M=config["GROUP_SIZE_M"],
|
||||
MUL_ROUTED_WEIGHT=True,
|
||||
USE_DOUBLE_QUANT=double_quant,
|
||||
top_k=1,
|
||||
)
|
||||
|
||||
fused_moe_out = paddle.sum(intermediate_cache3, axis=1)
|
||||
|
||||
|
||||
if layer.tp_size > 1:
|
||||
tensor_model_parallel_all_reduce(fused_moe_out)
|
||||
|
||||
return fused_moe_out
|
||||
|
@@ -126,7 +126,7 @@ class WINT2Config(QuantConfigBase):
|
||||
layer (Layer): The layer for which the quantization method should be retrieved.
|
||||
|
||||
Returns:
|
||||
QuantMethodBase: The quantization method associated with the given layer.
|
||||
QuantMethodBase: The quantization method associated with the given layer.
|
||||
"""
|
||||
if isinstance(layer, FusedMoE):
|
||||
if layer.layer_idx <= self.moe_w4_quant_end_layer:
|
||||
@@ -135,8 +135,8 @@ class WINT2Config(QuantConfigBase):
|
||||
{}).get_quant_method(layer)
|
||||
else:
|
||||
from fastdeploy.model_executor.layers.moe.fused_moe_wint2_backend import \
|
||||
TritonWint2FusedMoeMethod
|
||||
return TritonWint2FusedMoeMethod(self)
|
||||
CutlassWint2FusedMoeMethod
|
||||
return CutlassWint2FusedMoeMethod(self)
|
||||
else:
|
||||
return get_quantization_config(self.dense_quant_type).from_config(
|
||||
{}).get_quant_method(layer)
|
||||
|
Reference in New Issue
Block a user