[Feature] block_wise_fp8 support triton_moe_backend (#2767)

Author: chen
Date: 2025-07-09 19:22:47 +08:00
Committed by: GitHub
Parent: e3768c5a83
Commit: 888780ffde
5 changed files with 248 additions and 10 deletions

View File

@@ -68,5 +68,9 @@ environment_variables: dict[str, Callable[[], Any]] = {
"FD_PD_CHANGEABLE":
lambda: os.getenv("FD_PD_CHANGEABLE", "1"),
# Whether to use DeepGemm for FP8 blockwise MoE.
"FD_USE_DEEP_GEMM":
lambda: bool(int(os.getenv("FD_USE_DEEP_GEMM", "1"))),
}
```

View File

@@ -1,5 +1,6 @@
# FastDeploy Environment Variables
FastDeploy's environment variables are defined in the fastdeploy/envs.py file at the repository root; they are documented below:
```python
environment_variables: dict[str, Callable[[], Any]] = {
    # CUDA architecture versions used when building FastDeploy; a list of strings, e.g. [80,90]
@@ -66,5 +67,9 @@ environment_variables: dict[str, Callable[[], Any]] = {
"FD_PD_CHANGEABLE":
lambda: os.getenv("FD_PD_CHANGEABLE", "1"),
# 是否使用DeepGemm后端的FP8 blockwise MoE.
"FD_USE_DEEP_GEMM":
lambda: bool(int(os.getenv("FD_USE_DEEP_GEMM", "1"))),
}
```

View File

@@ -97,6 +97,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
    # Whether to use fastsafetensor load weight (0 or 1)
    "FD_USE_FASTSAFETENSOR":
    lambda: os.getenv("FD_USE_FASTSAFETENSOR", "0"),

    # Whether to use DeepGemm for FP8 blockwise MoE.
    "FD_USE_DEEP_GEMM":
    lambda: bool(int(os.getenv("FD_USE_DEEP_GEMM", "1"))),
}

View File

@@ -20,7 +20,8 @@ from paddle import nn
import fastdeploy
from fastdeploy.distributed.communication_op import \
    tensor_model_parallel_all_reduce
from fastdeploy.model_executor.layers.utils import get_tensor
from fastdeploy.model_executor.layers.utils import (create_and_set_parameter,
                                                    get_tensor)
from fastdeploy.utils import ceil_div
from ..quantization.quant_base import QuantMethodBase
@@ -484,3 +485,220 @@ class TensorWiseFP8MoEMethod(QuantMethodBase):
            tensor_model_parallel_all_reduce(out)
        return out


class BlockWiseFP8MoEMethod(QuantMethodBase):
    """
    Use Triton Group Gemm to compute Fused BlockWise FP8 Quant MoE.
    """

    def __init__(self, quant_config):
        """
        Triton Group Gemm to compute Fused MoE.
        """
        self.quant_config = quant_config
        self.added_weight_attrs = ["moe_ffn1_weight", "moe_ffn2_weight"]
        self.added_scale_attrs = [
            "moe_ffn1_weight_scale", "moe_ffn2_weight_scale"
        ]

    def process_prequanted_weights(self, layer: nn.Layer, state_dict) -> None:
        """process_prequanted_weights"""
        raise NotImplementedError()

    def create_weights(self, layer: nn.Layer, state_dict):
        """
        Triton MoE create weight process.
        """
        ffn1_weights, ffn2_weights = layer.extract_moe_ffn_weights(state_dict)
        self.check(layer, ffn1_weights, ffn2_weights)

        for idx, weight_tensor in enumerate([ffn1_weights, ffn2_weights]):
            weight_name = self.added_weight_attrs[idx]
            scale_name = self.added_scale_attrs[idx]

            weight_list = []
            weight_scale_list = []
            for i in range(layer.num_local_experts):
                from fastdeploy.model_executor.layers.utils import \
                    per_block_cast_to_fp8
                quant_weight, scale = per_block_cast_to_fp8(
                    weight_tensor[i], self.quant_config.weight_block_size)
                weight_list.append(quant_weight)
                weight_scale_list.append(scale)
            quanted_weight = paddle.stack(weight_list, axis=0)
            quanted_weight = quanted_weight.transpose([0, 2, 1]).contiguous()
            create_and_set_parameter(layer, weight_name, quanted_weight)

            quanted_weight_scale = paddle.stack(weight_scale_list, axis=0)
            quanted_weight_scale = quanted_weight_scale.transpose(
                [0, 2, 1]).contiguous()
            create_and_set_parameter(layer, scale_name, quanted_weight_scale)

    def check(self, layer: nn.Layer, ffn1_weights, ffn2_weights):
        """
        check layer is valid for this method
        """
        assert ffn1_weights[0].shape == [
            layer.hidden_size, layer.moe_intermediate_size * 2
        ]
        assert ffn2_weights[0].shape == [
            layer.moe_intermediate_size, layer.hidden_size
        ]

    def apply(
        self,
        layer: nn.Layer,
        x: paddle.Tensor,
        gate_out: paddle.Tensor,
    ) -> paddle.Tensor:
        """
        Triton compute Fused MoE.
        """
        token_num = x.shape[0]
        top_k = layer.top_k
        num_local_experts = layer.num_local_experts
        moe_intermediate_size = layer.moe_intermediate_size
        hidden_size = layer.hidden_size
        E, N1, _ = layer.moe_ffn1_weight.shape
        N2 = layer.moe_ffn2_weight.shape[1]

        topk_ids, topk_weights = fastdeploy.model_executor.ops.gpu.moe_topk_select(
            gate_out,
            layer.gate_correction_bias,
            layer.top_k,
            True,  # apply_norm_weight
            False,
        )
        config = {
            "BLOCK_SIZE_M": 64,
            "BLOCK_SIZE_N": self.quant_config.weight_block_size[1],
            "BLOCK_SIZE_K": self.quant_config.weight_block_size[0],
            "GROUP_SIZE_M": 32,
            "num_warps": 4,
            "num_stages": 3,
        }

        from fastdeploy.model_executor.ops.gpu import tritonmoe_preprocess
        sorted_token_ids, expert_ids, num_tokens_post_padded = tritonmoe_preprocess(
            topk_ids, num_local_experts, config["BLOCK_SIZE_M"])
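        # Assumed semantics: sorted_token_ids groups (token, expert) pairs by expert
        # and pads each expert's segment to a multiple of BLOCK_SIZE_M, so every
        # Triton block handles a single expert; expert_ids maps blocks to experts.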
        max_num_tokens_padded = sorted_token_ids.shape[0]
        grid = (ceil_div(max_num_tokens_padded, config["BLOCK_SIZE_M"]) *
                ceil_div(moe_intermediate_size * 2, config["BLOCK_SIZE_N"]), )

        from .triton_moe_kernels import fused_moe_kernel_paddle

        x_q, x_scale = fastdeploy.model_executor.ops.gpu.per_token_quant(
            x, self.quant_config.weight_block_size[0])

        cache13 = paddle.empty([token_num * top_k * max(N1, N2)],
                               dtype=x.dtype)
        intermediate_cache1 = cache13[:token_num * top_k * N1].view(
            [token_num * top_k, N1])
        intermediate_cache3 = cache13[:token_num * top_k * N2].view(
            [token_num * top_k, N2])
        fused_moe_kernel_paddle[grid](
            x_q,
            layer.moe_ffn1_weight.view(paddle.float8_e4m3fn),
            intermediate_cache1,
            x_scale,
            layer.moe_ffn1_weight_scale,
            None,
            sorted_token_ids,
            expert_ids,
            num_tokens_post_padded,
            max_num_tokens_padded,
            token_num * top_k,
            N=moe_intermediate_size * 2,
            K=hidden_size,
            stride_am=x_q.strides[0],
            stride_ak=x_q.strides[1],
            stride_be=layer.moe_ffn1_weight.strides[0],
            stride_bk=layer.moe_ffn1_weight.strides[2],
            stride_bn=layer.moe_ffn1_weight.strides[1],
            stride_cm=intermediate_cache1.strides[0],
            stride_cn=intermediate_cache1.strides[1],
            #
            stride_asm=x_scale.strides[0],  # only used in blockwise fp8
            stride_ask=x_scale.strides[1],  # only used in blockwise fp8
            stride_bse=layer.moe_ffn1_weight_scale.strides[0],
            stride_bsk=layer.moe_ffn1_weight_scale.strides[2],
            stride_bsn=layer.moe_ffn1_weight_scale.strides[1],
            group_n=self.quant_config.weight_block_size[1],
            group_k=self.quant_config.weight_block_size[0],
            # Meta-parameters
            BLOCK_SIZE_M=config["BLOCK_SIZE_M"],
            BLOCK_SIZE_N=config["BLOCK_SIZE_N"],
            BLOCK_SIZE_K=config["BLOCK_SIZE_K"],
            GROUP_SIZE_M=config["GROUP_SIZE_M"],
            MUL_ROUTED_WEIGHT=False,
            top_k=top_k,
            compute_type_enum=1,
            use_fp8_w8a8=True,
            use_int8_w8a16=False,
            even_Ks=hidden_size % config["BLOCK_SIZE_K"] == 0,
        )
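        # SwiGLU activation: the first group GEMM produced
        # [token_num * top_k, 2 * moe_intermediate_size] (gate and up projections);
        # swiglu reduces it to [token_num * top_k, moe_intermediate_size].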
        intermediate_cache2 = paddle.incubate.nn.functional.swiglu(
            intermediate_cache1)

        grid = (ceil_div(max_num_tokens_padded, config["BLOCK_SIZE_M"]) *
                ceil_div(hidden_size, config["BLOCK_SIZE_N"]), )

        x_q, x_scale = fastdeploy.model_executor.ops.gpu.per_token_quant(
            intermediate_cache2, self.quant_config.weight_block_size[0])

        fused_moe_kernel_paddle[grid](
            x_q,
            layer.moe_ffn2_weight.view(paddle.float8_e4m3fn),
            intermediate_cache3,
            x_scale,
            layer.moe_ffn2_weight_scale,
            topk_weights,
            sorted_token_ids,
            expert_ids,
            num_tokens_post_padded,
            max_num_tokens_padded,
            token_num * top_k,
            N=hidden_size,
            K=moe_intermediate_size,
            stride_am=x_q.strides[0],
            stride_ak=x_q.strides[1],
            stride_be=layer.moe_ffn2_weight.strides[0],
            stride_bk=layer.moe_ffn2_weight.strides[2],
            stride_bn=layer.moe_ffn2_weight.strides[1],
            stride_cm=intermediate_cache3.strides[0],
            stride_cn=intermediate_cache3.strides[1],
            stride_asm=x_scale.strides[0],  # only used in blockwise fp8
            stride_ask=x_scale.strides[1],  # only used in blockwise fp8
            stride_bse=layer.moe_ffn2_weight_scale.strides[0],
            stride_bsk=layer.moe_ffn2_weight_scale.strides[2],
            stride_bsn=layer.moe_ffn2_weight_scale.strides[1],
            group_n=self.quant_config.weight_block_size[1],
            group_k=self.quant_config.weight_block_size[0],
            # Meta-parameters
            BLOCK_SIZE_M=config["BLOCK_SIZE_M"],
            BLOCK_SIZE_N=config["BLOCK_SIZE_N"],
            BLOCK_SIZE_K=config["BLOCK_SIZE_K"],
            GROUP_SIZE_M=config["GROUP_SIZE_M"],
            MUL_ROUTED_WEIGHT=True,
            top_k=1,
            compute_type_enum=1,
            use_fp8_w8a8=True,
            use_int8_w8a16=False,
            even_Ks=moe_intermediate_size % config["BLOCK_SIZE_K"] == 0,
        )

        intermediate_cache3.reshape_([token_num, top_k, hidden_size])
        out = intermediate_cache3.sum(axis=1)

        if layer.tp_size > 1:
            tensor_model_parallel_all_reduce(out)
        return out
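
For intuition, here is a small worked example of the 1-D launch grid computed for the first group GEMM above. The token count, padding, and the 128x128 `weight_block_size` are hypothetical values chosen only for illustration:

```python
def ceil_div(a: int, b: int) -> int:
    return (a + b - 1) // b

# Hypothetical values for illustration only.
BLOCK_SIZE_M = 64                # fixed in the config dict above
BLOCK_SIZE_N = 128               # assumes weight_block_size = [128, 128]
max_num_tokens_padded = 4096     # sorted_token_ids.shape[0] after padding
moe_intermediate_size = 3584

grid = (ceil_div(max_num_tokens_padded, BLOCK_SIZE_M) *
        ceil_div(moe_intermediate_size * 2, BLOCK_SIZE_N))
print(grid)  # 64 * 56 = 3584 Triton program instances for the FFN1 group GEMM
```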

View File

@@ -18,9 +18,10 @@ from typing import Optional
import paddle
import fastdeploy
from fastdeploy import envs
from fastdeploy.model_executor.layers.moe import FusedMoE
from ..utils import per_block_cast_to_fp8, get_tensor
from ..utils import get_tensor, per_block_cast_to_fp8
from .quant_base import QuantConfigBase, QuantMethodBase
@@ -37,6 +38,7 @@ class BlockWiseFP8Config(QuantConfigBase):
        self.quant_max_bound = 448
        self.quant_min_bound = -448
        self.quant_round_type = 1
        self.use_deep_gemm = bool(envs.FD_USE_DEEP_GEMM)

    def name(self) -> str:
        return "block_wise_fp8"
@@ -51,9 +53,14 @@ class BlockWiseFP8Config(QuantConfigBase):
        Get quantization method.
        '''
        if isinstance(layer, FusedMoE):
            if self.use_deep_gemm:
                from fastdeploy.model_executor.layers.moe.fused_moe_deepgemm_backend import \
                    DeepGemmFusedMoeMethod
                return DeepGemmFusedMoeMethod(self)
            else:
                from fastdeploy.model_executor.layers.moe.fused_moe_triton_backend import \
                    BlockWiseFP8MoEMethod
                return BlockWiseFP8MoEMethod(self)
        else:
            return BlockWiseFP8LinearMethod(self)
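
Taken together with the `FD_USE_DEEP_GEMM` changes above, backend selection for block-wise FP8 MoE reduces to the following simplified sketch (not the actual FastDeploy call site; class names are taken from the diff):

```python
import os

def select_moe_backend() -> str:
    # Mirrors BlockWiseFP8Config.get_quant_method for FusedMoE layers:
    # the DeepGemm backend remains the default; FD_USE_DEEP_GEMM=0 opts into
    # the Triton BlockWiseFP8MoEMethod added by this PR.
    use_deep_gemm = bool(int(os.getenv("FD_USE_DEEP_GEMM", "1")))
    return "DeepGemmFusedMoeMethod" if use_deep_gemm else "BlockWiseFP8MoEMethod"

os.environ["FD_USE_DEEP_GEMM"] = "0"
print(select_moe_backend())  # -> BlockWiseFP8MoEMethod
```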