diff --git a/docs/usage/environment_variables.md b/docs/usage/environment_variables.md
index 2cf9ff73d..c630f83b7 100644
--- a/docs/usage/environment_variables.md
+++ b/docs/usage/environment_variables.md
@@ -67,6 +67,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
     # Switch from standalone PD to centralized inference (0 or 1)
     "FD_PD_CHANGEABLE":
     lambda: os.getenv("FD_PD_CHANGEABLE", "1"),
-
+
+    # Whether to use DeepGemm for FP8 blockwise MoE.
+    "FD_USE_DEEP_GEMM":
+    lambda: bool(int(os.getenv("FD_USE_DEEP_GEMM", "1"))),
+
 }
-```
\ No newline at end of file
+```
diff --git a/docs/zh/usage/environment_variables.md b/docs/zh/usage/environment_variables.md
index d952e757d..10a2a7cc7 100644
--- a/docs/zh/usage/environment_variables.md
+++ b/docs/zh/usage/environment_variables.md
@@ -1,5 +1,6 @@
 # FastDeploy 环境变量说明
 FastDeploy 的环境变量保存在了代码库根目录下 fastdeploy/envs.py 文件中,以下是其对应的中文版说明:
+
 ```python
 environment_variables: dict[str, Callable[[], Any]] = {
     # 构建 FastDeploy 时使用的 CUDA 架构版本,这是一个字符串列表,例如[80,90]
@@ -65,6 +66,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
     # 是否从单机 PD 分离转换为集中式推理
     "FD_PD_CHANGEABLE":
     lambda: os.getenv("FD_PD_CHANGEABLE", "1"),
-
+
+    # 是否使用 DeepGemm 后端计算 FP8 blockwise MoE
+    "FD_USE_DEEP_GEMM":
+    lambda: bool(int(os.getenv("FD_USE_DEEP_GEMM", "1"))),
+
 }
-```
\ No newline at end of file
+```
diff --git a/fastdeploy/envs.py b/fastdeploy/envs.py
index 8ef8a5149..6cb34780f 100644
--- a/fastdeploy/envs.py
+++ b/fastdeploy/envs.py
@@ -97,6 +97,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
     # Whether to use fastsafetensor load weight (0 or 1)
     "FD_USE_FASTSAFETENSOR":
     lambda: os.getenv("FD_USE_FASTSAFETENSOR", "0"),
+
+    # Whether to use DeepGemm for FP8 blockwise MoE.
+    "FD_USE_DEEP_GEMM":
+    lambda: bool(int(os.getenv("FD_USE_DEEP_GEMM", "1"))),
 }
 
 
diff --git a/fastdeploy/model_executor/layers/moe/fused_moe_triton_backend.py b/fastdeploy/model_executor/layers/moe/fused_moe_triton_backend.py
index caf395c5d..c113fe712 100644
--- a/fastdeploy/model_executor/layers/moe/fused_moe_triton_backend.py
+++ b/fastdeploy/model_executor/layers/moe/fused_moe_triton_backend.py
@@ -20,7 +20,8 @@ from paddle import nn
 import fastdeploy
 from fastdeploy.distributed.communication_op import \
     tensor_model_parallel_all_reduce
-from fastdeploy.model_executor.layers.utils import get_tensor
+from fastdeploy.model_executor.layers.utils import (create_and_set_parameter,
+                                                    get_tensor)
 from fastdeploy.utils import ceil_div
 
 from ..quantization.quant_base import QuantMethodBase
@@ -191,7 +192,7 @@ class TritonWeightOnlyMoEMethod(QuantMethodBase):
 
         ffn2_input = paddle.incubate.nn.functional.swiglu(
             ffn1_out)
-
+
         ffn2_out = paddle.empty(
             (token_num * top_k, hidden_size),
             dtype=x.dtype,
@@ -484,3 +485,220 @@ class TensorWiseFP8MoEMethod(QuantMethodBase):
             tensor_model_parallel_all_reduce(out)
 
         return out
+
+class BlockWiseFP8MoEMethod(QuantMethodBase):
+    """
+    Fused block-wise FP8 MoE computed with Triton group GEMM kernels.
+    """
+
+    def __init__(self, quant_config):
+        """
+        Triton group GEMM backend for fused block-wise FP8 MoE.
+        """
+        self.quant_config = quant_config
+        self.added_weight_attrs = ["moe_ffn1_weight", "moe_ffn2_weight"]
+        self.added_scale_attrs = [
+            "moe_ffn1_weight_scale", "moe_ffn2_weight_scale"
+        ]
+
+    def process_prequanted_weights(self, layer: nn.Layer, state_dict) -> None:
+        """process_prequanted_weights"""
+
+        raise NotImplementedError()
+
+    def create_weights(self, layer: nn.Layer, state_dict):
+        """
+        Create block-wise FP8 quantized weights for the Triton MoE backend.
+ """ + ffn1_weights, ffn2_weights = layer.extract_moe_ffn_weights(state_dict) + + self.check(layer, ffn1_weights, ffn2_weights) + + for idx, weight_tensor in enumerate([ffn1_weights, ffn2_weights]): + weight_name = self.added_weight_attrs[idx] + scale_name = self.added_scale_attrs[idx] + + weight_list = [] + weight_scale_list = [] + for i in range(layer.num_local_experts): + from fastdeploy.model_executor.layers.utils import \ + per_block_cast_to_fp8 + quant_weight, scale = per_block_cast_to_fp8( + weight_tensor[i], self.quant_config.weight_block_size) + + weight_list.append(quant_weight) + weight_scale_list.append(scale) + quanted_weight = paddle.stack(weight_list, axis=0) + quanted_weight = quanted_weight.transpose([0, 2, 1]).contiguous() + create_and_set_parameter(layer, weight_name, quanted_weight) + + quanted_weight_scale = paddle.stack(weight_scale_list, axis=0) + quanted_weight_scale = quanted_weight_scale.transpose( + [0, 2, 1]).contiguous() + create_and_set_parameter(layer, scale_name, quanted_weight_scale) + + def check(self, layer: nn.Layer, ffn1_weights, ffn2_weights): + """ + check layer is valid for this method + """ + assert ffn1_weights[0].shape == [ + layer.hidden_size, layer.moe_intermediate_size * 2 + ] + assert ffn2_weights[0].shape == [ + layer.moe_intermediate_size, layer.hidden_size + ] + + def apply( + self, + layer: nn.Layer, + x: paddle.Tensor, + gate_out: paddle.Tensor, + ) -> paddle.Tensor: + """ + Triton compute Fused MoE. + """ + + token_num = x.shape[0] + top_k = layer.top_k + num_local_experts = layer.num_local_experts + moe_intermediate_size = layer.moe_intermediate_size + hidden_size = layer.hidden_size + E, N1, _ = layer.moe_ffn1_weight.shape + N2 = layer.moe_ffn2_weight.shape[1] + + topk_ids, topk_weights = fastdeploy.model_executor.ops.gpu.moe_topk_select( + gate_out, + layer.gate_correction_bias, + layer.top_k, + True, # apply_norm_weight + False, + ) + + config = { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": self.quant_config.weight_block_size[1], + "BLOCK_SIZE_K": self.quant_config.weight_block_size[0], + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3, + } + from fastdeploy.model_executor.ops.gpu import tritonmoe_preprocess + + sorted_token_ids, expert_ids, num_tokens_post_padded = tritonmoe_preprocess( + topk_ids, num_local_experts, config["BLOCK_SIZE_M"]) + max_num_tokens_padded = sorted_token_ids.shape[0] + + grid = (ceil_div(max_num_tokens_padded, config["BLOCK_SIZE_M"]) * + ceil_div(moe_intermediate_size * 2, config["BLOCK_SIZE_N"]), ) + + from .triton_moe_kernels import fused_moe_kernel_paddle + + x_q, x_scale = fastdeploy.model_executor.ops.gpu.per_token_quant( + x, self.quant_config.weight_block_size[0]) + + cache13 = paddle.empty([token_num * top_k * max(N1, N2)], + dtype=x.dtype) + intermediate_cache1 = cache13[:token_num * top_k * N1].view( + [token_num * top_k, N1]) + intermediate_cache3 = cache13[:token_num * top_k * N2].view( + [token_num * top_k, N2]) + + fused_moe_kernel_paddle[grid]( + x_q, + layer.moe_ffn1_weight.view(paddle.float8_e4m3fn), + intermediate_cache1, + x_scale, + layer.moe_ffn1_weight_scale, + None, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + max_num_tokens_padded, + token_num * top_k, + N=moe_intermediate_size * 2, + K=hidden_size, + stride_am=x_q.strides[0], + stride_ak=x_q.strides[1], + stride_be=layer.moe_ffn1_weight.strides[0], + stride_bk=layer.moe_ffn1_weight.strides[2], + stride_bn=layer.moe_ffn1_weight.strides[1], + stride_cm=intermediate_cache1.strides[0], + 
+            stride_cn=intermediate_cache1.strides[1],
+            # activation/weight scale strides (block-wise FP8 only)
+            stride_asm=x_scale.strides[0],  # only used in blockwise fp8
+            stride_ask=x_scale.strides[1],  # only used in blockwise fp8
+            stride_bse=layer.moe_ffn1_weight_scale.strides[0],
+            stride_bsk=layer.moe_ffn1_weight_scale.strides[2],
+            stride_bsn=layer.moe_ffn1_weight_scale.strides[1],
+            group_n=self.quant_config.weight_block_size[1],
+            group_k=self.quant_config.weight_block_size[0],
+            # Meta-parameters
+            BLOCK_SIZE_M=config["BLOCK_SIZE_M"],
+            BLOCK_SIZE_N=config["BLOCK_SIZE_N"],
+            BLOCK_SIZE_K=config["BLOCK_SIZE_K"],
+            GROUP_SIZE_M=config["GROUP_SIZE_M"],
+            MUL_ROUTED_WEIGHT=False,
+            top_k=top_k,
+            compute_type_enum=1,
+            use_fp8_w8a8=True,
+            use_int8_w8a16=False,
+            even_Ks=hidden_size % config["BLOCK_SIZE_K"] == 0,
+        )
+
+        intermediate_cache2 = paddle.incubate.nn.functional.swiglu(
+            intermediate_cache1)
+
+        grid = (ceil_div(max_num_tokens_padded, config["BLOCK_SIZE_M"]) *
+                ceil_div(hidden_size, config["BLOCK_SIZE_N"]), )
+
+        x_q, x_scale = fastdeploy.model_executor.ops.gpu.per_token_quant(
+            intermediate_cache2, self.quant_config.weight_block_size[0])
+
+        fused_moe_kernel_paddle[grid](
+            x_q,
+            layer.moe_ffn2_weight.view(paddle.float8_e4m3fn),
+            intermediate_cache3,
+            x_scale,
+            layer.moe_ffn2_weight_scale,
+            topk_weights,
+            sorted_token_ids,
+            expert_ids,
+            num_tokens_post_padded,
+            max_num_tokens_padded,
+            token_num * top_k,
+            N=hidden_size,
+            K=moe_intermediate_size,
+            stride_am=x_q.strides[0],
+            stride_ak=x_q.strides[1],
+            stride_be=layer.moe_ffn2_weight.strides[0],
+            stride_bk=layer.moe_ffn2_weight.strides[2],
+            stride_bn=layer.moe_ffn2_weight.strides[1],
+            stride_cm=intermediate_cache3.strides[0],
+            stride_cn=intermediate_cache3.strides[1],
+            stride_asm=x_scale.strides[0],  # only used in blockwise fp8
+            stride_ask=x_scale.strides[1],  # only used in blockwise fp8
+            stride_bse=layer.moe_ffn2_weight_scale.strides[0],
+            stride_bsk=layer.moe_ffn2_weight_scale.strides[2],
+            stride_bsn=layer.moe_ffn2_weight_scale.strides[1],
+            group_n=self.quant_config.weight_block_size[1],
+            group_k=self.quant_config.weight_block_size[0],
+            # Meta-parameters
+            BLOCK_SIZE_M=config["BLOCK_SIZE_M"],
+            BLOCK_SIZE_N=config["BLOCK_SIZE_N"],
+            BLOCK_SIZE_K=config["BLOCK_SIZE_K"],
+            GROUP_SIZE_M=config["GROUP_SIZE_M"],
+            MUL_ROUTED_WEIGHT=True,
+            top_k=1,
+            compute_type_enum=1,
+            use_fp8_w8a8=True,
+            use_int8_w8a16=False,
+            even_Ks=moe_intermediate_size % config["BLOCK_SIZE_K"] == 0,
+        )
+
+        intermediate_cache3.reshape_([token_num, top_k, hidden_size])
+        out = intermediate_cache3.sum(axis=1)
+
+        if layer.tp_size > 1:
+            tensor_model_parallel_all_reduce(out)
+
+        return out
diff --git a/fastdeploy/model_executor/layers/quantization/block_wise_fp8.py b/fastdeploy/model_executor/layers/quantization/block_wise_fp8.py
index dea8c703b..af061ce83 100644
--- a/fastdeploy/model_executor/layers/quantization/block_wise_fp8.py
+++ b/fastdeploy/model_executor/layers/quantization/block_wise_fp8.py
@@ -18,9 +18,10 @@ from typing import Optional
 
 import paddle
 
 import fastdeploy
+from fastdeploy import envs
 from fastdeploy.model_executor.layers.moe import FusedMoE
 
-from ..utils import per_block_cast_to_fp8, get_tensor
+from ..utils import get_tensor, per_block_cast_to_fp8
 from .quant_base import QuantConfigBase, QuantMethodBase
 
@@ -37,6 +38,7 @@ class BlockWiseFP8Config(QuantConfigBase):
         self.quant_max_bound = 448
         self.quant_min_bound = -448
         self.quant_round_type = 1
+        self.use_deep_gemm = bool(envs.FD_USE_DEEP_GEMM)
 
     def name(self) -> str:
         return "block_wise_fp8"
@@ -51,9 +53,14 @@ class BlockWiseFP8Config(QuantConfigBase):
         Get quantization method.
         '''
         if isinstance(layer, FusedMoE):
-            from fastdeploy.model_executor.layers.moe.fused_moe_deepgemm_backend import \
-                DeepGemmFusedMoeMethod
-            return DeepGemmFusedMoeMethod(self)
+            if self.use_deep_gemm:
+                from fastdeploy.model_executor.layers.moe.fused_moe_deepgemm_backend import \
+                    DeepGemmFusedMoeMethod
+                return DeepGemmFusedMoeMethod(self)
+            else:
+                from fastdeploy.model_executor.layers.moe.fused_moe_triton_backend import \
+                    BlockWiseFP8MoEMethod
+                return BlockWiseFP8MoEMethod(self)
         else:
             return BlockWiseFP8LinearMethod(self)
 
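For reference, the runtime effect of the new switch: `FD_USE_DEEP_GEMM` defaults to `"1"`, so `BlockWiseFP8Config.get_quant_method()` keeps returning `DeepGemmFusedMoeMethod` unless the variable is set to `0`, in which case the Triton `BlockWiseFP8MoEMethod` added above handles FP8 block-wise MoE layers. A minimal standalone sketch of that selection (not part of the patch; it only mirrors the lambda added to `fastdeploy/envs.py`):

```python
import os

# Select the Triton backend before FastDeploy reads the environment.
os.environ["FD_USE_DEEP_GEMM"] = "0"

# Mirrors the FD_USE_DEEP_GEMM lambda added to fastdeploy/envs.py above.
use_deep_gemm = bool(int(os.getenv("FD_USE_DEEP_GEMM", "1")))

# BlockWiseFP8Config.get_quant_method() branches on this flag for FusedMoE layers.
backend = "DeepGemmFusedMoeMethod" if use_deep_gemm else "BlockWiseFP8MoEMethod"
print(backend)  # -> BlockWiseFP8MoEMethod
```

The same selection can be made from the shell by exporting `FD_USE_DEEP_GEMM=0` before launching FastDeploy.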
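The `create_weights` path relies on `per_block_cast_to_fp8(weight, quant_config.weight_block_size)` returning the FP8 weight plus one scale per weight block, which the Triton kernels later consume through the `stride_bse`/`stride_bsk`/`stride_bsn` and `group_n`/`group_k` arguments. A rough NumPy illustration of the per-block scaling idea, assuming a 128x128 block size and the E4M3 bound of 448 from `BlockWiseFP8Config` (the real helper also pads and casts to `float8_e4m3fn`; this sketch only derives the scales):

```python
import numpy as np

FP8_E4M3_MAX = 448.0  # matches quant_max_bound in BlockWiseFP8Config


def blockwise_scales(w: np.ndarray, block=(128, 128)) -> np.ndarray:
    """One scale per (block[0] x block[1]) tile so each tile fits the FP8 range."""
    rows = -(-w.shape[0] // block[0])  # ceil division
    cols = -(-w.shape[1] // block[1])
    scales = np.empty((rows, cols), dtype=np.float32)
    for r in range(rows):
        for c in range(cols):
            tile = w[r * block[0]:(r + 1) * block[0],
                     c * block[1]:(c + 1) * block[1]]
            amax = float(np.abs(tile).max())
            scales[r, c] = amax / FP8_E4M3_MAX if amax > 0 else 1.0
    return scales


w = np.random.randn(256, 384).astype(np.float32)
print(blockwise_scales(w).shape)  # (2, 3): one scale per 128x128 block
```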