[Feature] block_wise_fp8 support triton_moe_backend (#2767)
@@ -68,5 +68,9 @@ environment_variables: dict[str, Callable[[], Any]] = {
    "FD_PD_CHANGEABLE":
    lambda: os.getenv("FD_PD_CHANGEABLE", "1"),

    # Whether to use DeepGemm for FP8 blockwise MoE.
    "FD_USE_DEEP_GEMM":
    lambda: bool(int(os.getenv("FD_USE_DEEP_GEMM", "1"))),

}
```
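Each entry in `environment_variables` maps a name to a zero-argument lambda, so the value is read from the process environment at lookup time rather than at import time. A minimal, self-contained sketch of that pattern (names mirror the snippet above; the real table lives in fastdeploy/envs.py):

```python
import os
from typing import Any, Callable

# Minimal sketch of the lazy-lookup pattern shown above: each key maps to a
# zero-argument lambda, so os.getenv is consulted when the entry is accessed.
environment_variables: dict[str, Callable[[], Any]] = {
    # Whether to use DeepGemm for FP8 blockwise MoE (enabled by default).
    "FD_USE_DEEP_GEMM":
    lambda: bool(int(os.getenv("FD_USE_DEEP_GEMM", "1"))),
}

if __name__ == "__main__":
    # Prints False when the process was started with FD_USE_DEEP_GEMM=0.
    print(environment_variables["FD_USE_DEEP_GEMM"]())
```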
@@ -1,5 +1,6 @@
# FastDeploy Environment Variables
FastDeploy's environment variables are defined in the fastdeploy/envs.py file at the root of the repository. The following documents each of them:

```python
environment_variables: dict[str, Callable[[], Any]] = {
    # CUDA architectures used when building FastDeploy; this is a list of strings, e.g. [80,90].
@@ -66,5 +67,9 @@ environment_variables: dict[str, Callable[[], Any]] = {
    "FD_PD_CHANGEABLE":
    lambda: os.getenv("FD_PD_CHANGEABLE", "1"),

    # Whether to use the DeepGemm backend for FP8 blockwise MoE.
    "FD_USE_DEEP_GEMM":
    lambda: bool(int(os.getenv("FD_USE_DEEP_GEMM", "1"))),

}
```
@@ -97,6 +97,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
    # Whether to use fastsafetensor load weight (0 or 1)
    "FD_USE_FASTSAFETENSOR":
    lambda: os.getenv("FD_USE_FASTSAFETENSOR", "0"),

    # Whether to use DeepGemm for FP8 blockwise MoE.
    "FD_USE_DEEP_GEMM":
    lambda: bool(int(os.getenv("FD_USE_DEEP_GEMM", "1"))),
}
@@ -20,7 +20,8 @@ from paddle import nn
import fastdeploy
from fastdeploy.distributed.communication_op import \
    tensor_model_parallel_all_reduce
-from fastdeploy.model_executor.layers.utils import get_tensor
+from fastdeploy.model_executor.layers.utils import (create_and_set_parameter,
+                                                    get_tensor)
from fastdeploy.utils import ceil_div

from ..quantization.quant_base import QuantMethodBase
@@ -484,3 +485,220 @@ class TensorWiseFP8MoEMethod(QuantMethodBase):
            tensor_model_parallel_all_reduce(out)

        return out


class BlockWiseFP8MoEMethod(QuantMethodBase):
    """
    Use Triton Group Gemm to compute Fused BlockWise FP8 Quant MoE.
    """

    def __init__(self, quant_config):
        """
        Triton Group Gemm to compute Fused MoE.
        """
        self.quant_config = quant_config
        self.added_weight_attrs = ["moe_ffn1_weight", "moe_ffn2_weight"]
        self.added_scale_attrs = [
            "moe_ffn1_weight_scale", "moe_ffn2_weight_scale"
        ]

    def process_prequanted_weights(self, layer: nn.Layer, state_dict) -> None:
        """process_prequanted_weights"""

        raise NotImplementedError()

    def create_weights(self, layer: nn.Layer, state_dict):
        """
        Triton MoE create weight process.
        """
        ffn1_weights, ffn2_weights = layer.extract_moe_ffn_weights(state_dict)

        self.check(layer, ffn1_weights, ffn2_weights)

        for idx, weight_tensor in enumerate([ffn1_weights, ffn2_weights]):
            weight_name = self.added_weight_attrs[idx]
            scale_name = self.added_scale_attrs[idx]

            weight_list = []
            weight_scale_list = []
            for i in range(layer.num_local_experts):
                from fastdeploy.model_executor.layers.utils import \
                    per_block_cast_to_fp8
                quant_weight, scale = per_block_cast_to_fp8(
                    weight_tensor[i], self.quant_config.weight_block_size)

                weight_list.append(quant_weight)
                weight_scale_list.append(scale)
            quanted_weight = paddle.stack(weight_list, axis=0)
            quanted_weight = quanted_weight.transpose([0, 2, 1]).contiguous()
            create_and_set_parameter(layer, weight_name, quanted_weight)

            quanted_weight_scale = paddle.stack(weight_scale_list, axis=0)
            quanted_weight_scale = quanted_weight_scale.transpose(
                [0, 2, 1]).contiguous()
            create_and_set_parameter(layer, scale_name, quanted_weight_scale)
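The loop above delegates the actual quantization to `per_block_cast_to_fp8` from `fastdeploy.model_executor.layers.utils`, which is not part of this diff. As a rough orientation only (the real helper may differ in padding, dtype handling, and return layout), a block-wise absmax quantizer over `weight_block_size` tiles could look like this:

```python
import paddle


def per_block_cast_to_fp8_sketch(w: paddle.Tensor, block_size=(128, 128)):
    """Illustrative sketch only: quantize a 2-D weight tile by tile with one
    absmax-derived scale per (block_k, block_n) tile. The actual FastDeploy
    helper may pad, transpose, or cast to paddle.float8_e4m3fn differently."""
    bk, bn = block_size
    k, n = w.shape
    fp8_max = 448.0  # matches the quant_max_bound used by BlockWiseFP8Config
    scales = paddle.ones([(k + bk - 1) // bk, (n + bn - 1) // bn], dtype="float32")
    q = paddle.zeros_like(w)
    for i in range(0, k, bk):
        for j in range(0, n, bn):
            tile = w[i:i + bk, j:j + bn]
            s = float(tile.abs().max()) / fp8_max
            if s == 0.0:
                s = 1.0
            scales[i // bk, j // bn] = s
            q[i:i + bk, j:j + bn] = (tile / s).clip(-fp8_max, fp8_max)
    # The real helper returns FP8 data; the scaled values and per-block scales
    # are kept here only to show the bookkeeping.
    return q, scales
```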
    def check(self, layer: nn.Layer, ffn1_weights, ffn2_weights):
        """
        check layer is valid for this method
        """
        assert ffn1_weights[0].shape == [
            layer.hidden_size, layer.moe_intermediate_size * 2
        ]
        assert ffn2_weights[0].shape == [
            layer.moe_intermediate_size, layer.hidden_size
        ]

    def apply(
        self,
        layer: nn.Layer,
        x: paddle.Tensor,
        gate_out: paddle.Tensor,
    ) -> paddle.Tensor:
        """
        Triton compute Fused MoE.
        """

        token_num = x.shape[0]
        top_k = layer.top_k
        num_local_experts = layer.num_local_experts
        moe_intermediate_size = layer.moe_intermediate_size
        hidden_size = layer.hidden_size
        E, N1, _ = layer.moe_ffn1_weight.shape
        N2 = layer.moe_ffn2_weight.shape[1]

        topk_ids, topk_weights = fastdeploy.model_executor.ops.gpu.moe_topk_select(
            gate_out,
            layer.gate_correction_bias,
            layer.top_k,
            True,  # apply_norm_weight
            False,
        )

        config = {
            "BLOCK_SIZE_M": 64,
            "BLOCK_SIZE_N": self.quant_config.weight_block_size[1],
            "BLOCK_SIZE_K": self.quant_config.weight_block_size[0],
            "GROUP_SIZE_M": 32,
            "num_warps": 4,
            "num_stages": 3,
        }
        from fastdeploy.model_executor.ops.gpu import tritonmoe_preprocess

        sorted_token_ids, expert_ids, num_tokens_post_padded = tritonmoe_preprocess(
            topk_ids, num_local_experts, config["BLOCK_SIZE_M"])
        max_num_tokens_padded = sorted_token_ids.shape[0]

        grid = (ceil_div(max_num_tokens_padded, config["BLOCK_SIZE_M"]) *
                ceil_div(moe_intermediate_size * 2, config["BLOCK_SIZE_N"]), )

        from .triton_moe_kernels import fused_moe_kernel_paddle

        x_q, x_scale = fastdeploy.model_executor.ops.gpu.per_token_quant(
            x, self.quant_config.weight_block_size[0])

        cache13 = paddle.empty([token_num * top_k * max(N1, N2)],
                               dtype=x.dtype)
        intermediate_cache1 = cache13[:token_num * top_k * N1].view(
            [token_num * top_k, N1])
        intermediate_cache3 = cache13[:token_num * top_k * N2].view(
            [token_num * top_k, N2])

        fused_moe_kernel_paddle[grid](
            x_q,
            layer.moe_ffn1_weight.view(paddle.float8_e4m3fn),
            intermediate_cache1,
            x_scale,
            layer.moe_ffn1_weight_scale,
            None,
            sorted_token_ids,
            expert_ids,
            num_tokens_post_padded,
            max_num_tokens_padded,
            token_num * top_k,
            N=moe_intermediate_size * 2,
            K=hidden_size,
            stride_am=x_q.strides[0],
            stride_ak=x_q.strides[1],
            stride_be=layer.moe_ffn1_weight.strides[0],
            stride_bk=layer.moe_ffn1_weight.strides[2],
            stride_bn=layer.moe_ffn1_weight.strides[1],
            stride_cm=intermediate_cache1.strides[0],
            stride_cn=intermediate_cache1.strides[1],
            #
            stride_asm=x_scale.strides[0],  # only used in blockwise fp8
            stride_ask=x_scale.strides[1],  # only used in blockwise fp8
            stride_bse=layer.moe_ffn1_weight_scale.strides[0],
            stride_bsk=layer.moe_ffn1_weight_scale.strides[2],
            stride_bsn=layer.moe_ffn1_weight_scale.strides[1],
            group_n=self.quant_config.weight_block_size[1],
            group_k=self.quant_config.weight_block_size[0],
            # Meta-parameters
            BLOCK_SIZE_M=config["BLOCK_SIZE_M"],
            BLOCK_SIZE_N=config["BLOCK_SIZE_N"],
            BLOCK_SIZE_K=config["BLOCK_SIZE_K"],
            GROUP_SIZE_M=config["GROUP_SIZE_M"],
            MUL_ROUTED_WEIGHT=False,
            top_k=top_k,
            compute_type_enum=1,
            use_fp8_w8a8=True,
            use_int8_w8a16=False,
            even_Ks=hidden_size % config["BLOCK_SIZE_K"] == 0,
        )

        intermediate_cache2 = paddle.incubate.nn.functional.swiglu(
            intermediate_cache1)

        grid = (ceil_div(max_num_tokens_padded, config["BLOCK_SIZE_M"]) *
                ceil_div(hidden_size, config["BLOCK_SIZE_N"]), )

        x_q, x_scale = fastdeploy.model_executor.ops.gpu.per_token_quant(
            intermediate_cache2, self.quant_config.weight_block_size[0])

        fused_moe_kernel_paddle[grid](
            x_q,
            layer.moe_ffn2_weight.view(paddle.float8_e4m3fn),
            intermediate_cache3,
            x_scale,
            layer.moe_ffn2_weight_scale,
            topk_weights,
            sorted_token_ids,
            expert_ids,
            num_tokens_post_padded,
            max_num_tokens_padded,
            token_num * top_k,
            N=hidden_size,
            K=moe_intermediate_size,
            stride_am=x_q.strides[0],
            stride_ak=x_q.strides[1],
            stride_be=layer.moe_ffn2_weight.strides[0],
            stride_bk=layer.moe_ffn2_weight.strides[2],
            stride_bn=layer.moe_ffn2_weight.strides[1],
            stride_cm=intermediate_cache3.strides[0],
            stride_cn=intermediate_cache3.strides[1],
            stride_asm=x_scale.strides[0],  # only used in blockwise fp8
            stride_ask=x_scale.strides[1],  # only used in blockwise fp8
            stride_bse=layer.moe_ffn2_weight_scale.strides[0],
            stride_bsk=layer.moe_ffn2_weight_scale.strides[2],
            stride_bsn=layer.moe_ffn2_weight_scale.strides[1],
            group_n=self.quant_config.weight_block_size[1],
            group_k=self.quant_config.weight_block_size[0],
            # Meta-parameters
            BLOCK_SIZE_M=config["BLOCK_SIZE_M"],
            BLOCK_SIZE_N=config["BLOCK_SIZE_N"],
            BLOCK_SIZE_K=config["BLOCK_SIZE_K"],
            GROUP_SIZE_M=config["GROUP_SIZE_M"],
            MUL_ROUTED_WEIGHT=True,
            top_k=1,
            compute_type_enum=1,
            use_fp8_w8a8=True,
            use_int8_w8a16=False,
            even_Ks=moe_intermediate_size % config["BLOCK_SIZE_K"] == 0,
        )

        intermediate_cache3.reshape_([token_num, top_k, hidden_size])
        out = intermediate_cache3.sum(axis=1)

        if layer.tp_size > 1:
            tensor_model_parallel_all_reduce(out)

        return out
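Read end to end, `apply` routes tokens with `moe_topk_select`, runs a blockwise-FP8 grouped GEMM for the up/gate projection, applies SwiGLU, runs a second grouped GEMM for the down projection with the routing weights folded in (`MUL_ROUTED_WEIGHT=True`, `top_k=1`), and finally sums the per-expert partial outputs. A slow, unquantized reference of the same dataflow, written with plain dense matmuls purely for orientation (softmax top-k routing stands in for `moe_topk_select`, and the weight layout is assumed to be `ffn1_w: [E, hidden, 2*inter]`, `ffn2_w: [E, inter, hidden]`, i.e. the pre-transpose layout seen in `create_weights`):

```python
import paddle
import paddle.nn.functional as F


def moe_dense_reference(x, ffn1_w, ffn2_w, gate_out, top_k):
    """Orientation-only reference for the fused Triton path above: per-token
    top-k routing, up/gate projection, SwiGLU, down projection, weighted sum.
    No FP8 quantization, padding, or expert grouping is modelled here."""
    probs = F.softmax(gate_out, axis=-1)                 # stand-in for moe_topk_select
    topk_w, topk_ids = paddle.topk(probs, k=top_k, axis=-1)
    topk_w = topk_w / topk_w.sum(axis=-1, keepdim=True)  # apply_norm_weight=True
    outs = []
    for t in range(x.shape[0]):
        acc = paddle.zeros([ffn2_w.shape[-1]], dtype=x.dtype)
        for j in range(top_k):
            e = int(topk_ids[t, j])
            h = paddle.matmul(x[t], ffn1_w[e])             # up + gate projection
            h = paddle.incubate.nn.functional.swiglu(h)    # same activation as above
            acc = acc + topk_w[t, j] * paddle.matmul(h, ffn2_w[e])  # routed down proj
        outs.append(acc)
    return paddle.stack(outs)                              # [token_num, hidden_size]
```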
@@ -18,9 +18,10 @@ from typing import Optional
import paddle

import fastdeploy
from fastdeploy import envs
from fastdeploy.model_executor.layers.moe import FusedMoE

-from ..utils import per_block_cast_to_fp8, get_tensor
+from ..utils import get_tensor, per_block_cast_to_fp8
from .quant_base import QuantConfigBase, QuantMethodBase
@@ -37,6 +38,7 @@ class BlockWiseFP8Config(QuantConfigBase):
        self.quant_max_bound = 448
        self.quant_min_bound = -448
        self.quant_round_type = 1
        self.use_deep_gemm = bool(envs.FD_USE_DEEP_GEMM)

    def name(self) -> str:
        return "block_wise_fp8"
@@ -51,9 +53,14 @@ class BlockWiseFP8Config(QuantConfigBase):
        Get quantization method.
        '''
        if isinstance(layer, FusedMoE):
            if self.use_deep_gemm:
                from fastdeploy.model_executor.layers.moe.fused_moe_deepgemm_backend import \
                    DeepGemmFusedMoeMethod
                return DeepGemmFusedMoeMethod(self)
            else:
                from fastdeploy.model_executor.layers.moe.fused_moe_triton_backend import \
                    BlockWiseFP8MoEMethod
                return BlockWiseFP8MoEMethod(self)
        else:
            return BlockWiseFP8LinearMethod(self)
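With this change, the FP8 blockwise MoE backend is selected from the environment flag captured in `BlockWiseFP8Config.__init__`. A hedged usage sketch, assuming `fastdeploy.envs` resolves `FD_USE_DEEP_GEMM` through the lambda table shown earlier and that the flag is set before the config object is built:

```python
import os

# Select the Triton blockwise-FP8 MoE path added in this PR instead of DeepGemm.
# The flag is read lazily, but it must be set before BlockWiseFP8Config is constructed.
os.environ["FD_USE_DEEP_GEMM"] = "0"

from fastdeploy import envs

use_deep_gemm = bool(envs.FD_USE_DEEP_GEMM)
# FusedMoE layers then get:
#   DeepGemmFusedMoeMethod  when use_deep_gemm is True (default, FD_USE_DEEP_GEMM=1)
#   BlockWiseFP8MoEMethod   when use_deep_gemm is False (FD_USE_DEEP_GEMM=0)
# Non-MoE layers keep BlockWiseFP8LinearMethod either way.
print("DeepGemm backend" if use_deep_gemm else "Triton blockwise FP8 backend")
```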