mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-12-24 13:28:13 +08:00
[Quantization] Support w4afp8 MoE dynamic quantization (#5282)
* support dynamic activation quant for w4afp8 * support dynamic w4afp8 * add test * fix * fix --------- Co-authored-by: zhoutianzi666 <17801055074@163.com>
This commit is contained in:
@@ -430,7 +430,9 @@ __global__ void permute_x_kernel(
|
||||
}
|
||||
abs_max = phi::BlockAllReduce<MaxOp, float, Kthread>(abs_max);
|
||||
float scale = 440.f / abs_max; // use 440 so we do not have to clip
|
||||
dequant_scale[dst_token_idx] = abs_max;
|
||||
if (tid == 0) {
|
||||
dequant_scale[dst_token_idx] = abs_max;
|
||||
}
|
||||
for (int v_id = tid; v_id < hidden_size_int4; v_id += blockDim.x) {
|
||||
Load<T, vec_size>(&data_smem[v_id * vec_size], &src_vec);
|
||||
#pragma unroll
|
||||
@@ -661,7 +663,7 @@ std::vector<paddle::Tensor> EPMoeExpertDispatch(
|
||||
|
||||
int dequant_scale_size = 1;
|
||||
if (moe_quant_type == "w4afp8" && !up_gate_proj_in_scale) {
|
||||
dequant_scale_size = moe_topk * num_rows;
|
||||
dequant_scale_size = token_nums_this_rank;
|
||||
}
|
||||
|
||||
auto dequant_scale =
|
||||
|
||||
@@ -85,7 +85,7 @@ void w4afp8_gemm_M{M}_N{N}_G{GROUPSIZE}_K{K}_E{EXPERTS}_P{PADDING}_{TYPE}(
|
||||
"""
|
||||
|
||||
# [M, K, Number of experts, token Padding Size, weight K group size]
|
||||
gemm_case = [[256, 256, 2, 0, 128]]
|
||||
gemm_case = [[256, 256, 2, 0, 128], [512, 256, 2, 0, 128]]
|
||||
|
||||
dtype = ["BF16"]
|
||||
|
||||
|
||||
@@ -295,7 +295,7 @@ class DeepEPEngine:
|
||||
use_fp8=use_fp8,
|
||||
async_finish=False,
|
||||
return_recv_hook=True,
|
||||
# num_per_channel=quant_group_size,
|
||||
num_per_channel=quant_group_size,
|
||||
)
|
||||
|
||||
return packed_recv_x, recv_expert_count, handle, dispatch_hook
|
||||
@@ -634,10 +634,11 @@ class EPDecoderRunner(EPRunner):
|
||||
):
|
||||
expertwise_scale = kwargs.get("expertwise_scale", None)
|
||||
use_fp8 = kwargs.get("use_fp8", False)
|
||||
quant_group_size = kwargs.get("quant_group_size", 128)
|
||||
|
||||
if not self.use_internode_ll_two_stage:
|
||||
recv_hidden_states, recv_expert_count, handle, dispatch_hook = self.ep_engine.low_latency_dispatch(
|
||||
x, topk_idx, expertwise_scale, use_fp8
|
||||
x, topk_idx, expertwise_scale, use_fp8, quant_group_size
|
||||
)
|
||||
else:
|
||||
# just supports dispatch_use_fp8 = True now!
|
||||
|
||||
@@ -22,7 +22,7 @@ from paddleformers.utils.log import logger
|
||||
import fastdeploy
|
||||
from fastdeploy.platforms import current_platform
|
||||
|
||||
from ..utils import get_tensor
|
||||
from ..utils import get_tensor, group_wise_int4_weight_quantize, pack, rotate_model
|
||||
from .fused_moe_backend_base import UnquantizedFusedMoEMethod
|
||||
|
||||
if current_platform.is_cuda():
|
||||
@@ -745,7 +745,7 @@ class CutlassW4AFP8MoEMethod(CutlassMoEMethod):
|
||||
super().__init__(quant_config)
|
||||
self.quant_config = quant_config
|
||||
self.moe_quant_type = "w4afp8"
|
||||
self.pack_num = 2
|
||||
self.pack_num = 2 if quant_config.is_quantized else 1
|
||||
|
||||
def process_prequanted_weights(self, layer: nn.Layer, state_dict, is_rearrange: bool = False):
|
||||
"""
|
||||
@@ -912,21 +912,58 @@ class CutlassW4AFP8MoEMethod(CutlassMoEMethod):
|
||||
"""
|
||||
Paddle cutlass load weight process.
|
||||
"""
|
||||
if not layer.is_quantized:
|
||||
logger.info(
|
||||
f"Rotating ernie.layers.{layer.layer_idx}.mlp.experts.[{layer.ep_rank * layer.num_local_experts},{layer.ep_rank * layer.num_local_experts + layer.num_local_experts}).down_proj.weight..."
|
||||
)
|
||||
rotate_model(
|
||||
state_dict,
|
||||
layer.layer_idx,
|
||||
layer.num_local_experts,
|
||||
layer.hidden_size,
|
||||
layer.moe_intermediate_size,
|
||||
ep_rank=layer.ep_rank,
|
||||
)
|
||||
|
||||
up_gate_proj_weights, down_proj_weights, logical_expert_ids, ep_rank_to_expert_id_list = (
|
||||
layer.extract_moe_ffn_weights(state_dict)
|
||||
)
|
||||
self.check(layer, up_gate_proj_weights, down_proj_weights)
|
||||
|
||||
up_gate_proj_weight_scales = []
|
||||
down_proj_weight_scales = []
|
||||
dynamic_scale_weight_map = {
|
||||
self.added_scale_attrs[0]: up_gate_proj_weight_scales,
|
||||
self.added_scale_attrs[1]: down_proj_weight_scales,
|
||||
}
|
||||
|
||||
for idx, weight_tensor in enumerate([up_gate_proj_weights, down_proj_weights]):
|
||||
weight_name = self.added_weight_attrs[idx]
|
||||
weight_scale_name = self.added_scale_attrs[idx]
|
||||
weight_list = []
|
||||
for i in range(layer.num_local_experts):
|
||||
quant_weight = w4afp8_gemm_weight_convert(weight_tensor[i])
|
||||
quant_weight = weight_tensor[i]
|
||||
if not layer.is_quantized:
|
||||
block_size = getattr(layer.moe_quant_config, "hadamard_block_size", 512)
|
||||
quant_weight, weight_scale = group_wise_int4_weight_quantize(weight_tensor[i], group_size=128)
|
||||
free_tensor(weight_tensor[i])
|
||||
quant_weight = pack(quant_weight.transpose([1, 0]), bits=4)
|
||||
if "down_proj" in weight_name:
|
||||
weight_scale = weight_scale / (block_size**0.5)
|
||||
dynamic_scale_weight_map[weight_scale_name].append(weight_scale)
|
||||
|
||||
quant_weight = w4afp8_gemm_weight_convert(quant_weight)
|
||||
weight_list.append(quant_weight)
|
||||
quanted_weight = paddle.stack(weight_list, axis=0)
|
||||
getattr(layer, weight_name).set_value(quanted_weight)
|
||||
|
||||
self.load_w4afp8_scale_weights(
|
||||
layer, layer.weight_key_map, state_dict, logical_expert_ids, ep_rank_to_expert_id_list
|
||||
layer,
|
||||
layer.weight_key_map,
|
||||
state_dict,
|
||||
logical_expert_ids,
|
||||
ep_rank_to_expert_id_list,
|
||||
dynamic_scale_weight_map,
|
||||
)
|
||||
|
||||
def create_w4afp8_scale_weights(self, layer: nn.Layer, weight_key_map: dict):
|
||||
@@ -938,7 +975,7 @@ class CutlassW4AFP8MoEMethod(CutlassMoEMethod):
|
||||
"""
|
||||
|
||||
self.default_dtype = layer._helper.get_default_dtype()
|
||||
if layer.ep_size > 1 and not layer.moe_quant_config.moe_dynamic_quant:
|
||||
if layer.ep_size > 1 and layer.is_quantized and not layer.moe_quant_config.moe_dynamic_quant:
|
||||
setattr(
|
||||
layer,
|
||||
"up_gate_proj_in_scale_all_experts",
|
||||
@@ -950,7 +987,7 @@ class CutlassW4AFP8MoEMethod(CutlassMoEMethod):
|
||||
)
|
||||
|
||||
# in_scales
|
||||
if not layer.moe_quant_config.moe_dynamic_quant:
|
||||
if layer.is_quantized and not layer.moe_quant_config.moe_dynamic_quant:
|
||||
for in_scale_name in ["up_gate_proj_in_scale", "down_proj_in_scale"]:
|
||||
setattr(
|
||||
layer,
|
||||
@@ -963,24 +1000,25 @@ class CutlassW4AFP8MoEMethod(CutlassMoEMethod):
|
||||
)
|
||||
|
||||
# weight_scales
|
||||
setattr(
|
||||
layer,
|
||||
"up_gate_proj_weight_scale",
|
||||
layer.create_parameter(
|
||||
shape=[layer.num_local_experts, layer.moe_intermediate_size * 2],
|
||||
dtype="float32",
|
||||
default_initializer=paddle.nn.initializer.Constant(0),
|
||||
),
|
||||
)
|
||||
setattr(
|
||||
layer,
|
||||
"down_proj_weight_scale",
|
||||
layer.create_parameter(
|
||||
shape=[layer.num_local_experts, layer.hidden_size],
|
||||
dtype="float32",
|
||||
default_initializer=paddle.nn.initializer.Constant(0),
|
||||
),
|
||||
)
|
||||
if layer.is_quantized:
|
||||
setattr(
|
||||
layer,
|
||||
"up_gate_proj_weight_scale",
|
||||
layer.create_parameter(
|
||||
shape=[layer.num_local_experts, layer.moe_intermediate_size * 2],
|
||||
dtype="float32",
|
||||
default_initializer=paddle.nn.initializer.Constant(0),
|
||||
),
|
||||
)
|
||||
setattr(
|
||||
layer,
|
||||
"down_proj_weight_scale",
|
||||
layer.create_parameter(
|
||||
shape=[layer.num_local_experts, layer.hidden_size],
|
||||
dtype="float32",
|
||||
default_initializer=paddle.nn.initializer.Constant(0),
|
||||
),
|
||||
)
|
||||
|
||||
def load_w4afp8_scale_weights(
|
||||
self,
|
||||
@@ -989,6 +1027,7 @@ class CutlassW4AFP8MoEMethod(CutlassMoEMethod):
|
||||
state_dict: dict,
|
||||
logical_expert_ids: paddle.Tensor,
|
||||
ep_rank_to_expert_id_list: list,
|
||||
dynamic_scale_weight_map: dict,
|
||||
):
|
||||
"""
|
||||
Get w4afp8 weights from state dict and process them.
|
||||
@@ -1095,7 +1134,7 @@ class CutlassW4AFP8MoEMethod(CutlassMoEMethod):
|
||||
raise ValueError(f"scale {name} should not be none in w4a8 mode.")
|
||||
|
||||
# 2. Extract scale tensor from state dict
|
||||
if layer.ep_size > 1 and not layer.moe_quant_config.moe_dynamic_quant:
|
||||
if layer.ep_size > 1 and layer.is_quantized and not layer.moe_quant_config.moe_dynamic_quant:
|
||||
for expert_idx in ep_rank_to_expert_id_list:
|
||||
scale_tensor = get_tensor(
|
||||
(
|
||||
@@ -1110,11 +1149,14 @@ class CutlassW4AFP8MoEMethod(CutlassMoEMethod):
|
||||
paddle.concat(up_gate_proj_in_scales_all_experts)
|
||||
)
|
||||
|
||||
for expert_idx in logical_expert_ids:
|
||||
for name, scale_key_template in scale_key_map.items():
|
||||
if hasattr(layer, name):
|
||||
scale_tensor = _extract_scale_tensor(layer, state_dict, scale_key_template, expert_idx)
|
||||
scale_weight_map[name].append(scale_tensor)
|
||||
if not layer.is_quantized:
|
||||
scale_weight_map = dynamic_scale_weight_map
|
||||
else:
|
||||
for expert_idx in logical_expert_ids:
|
||||
for name, scale_key_template in scale_key_map.items():
|
||||
if hasattr(layer, name):
|
||||
scale_tensor = _extract_scale_tensor(layer, state_dict, scale_key_template, expert_idx)
|
||||
scale_weight_map[name].append(scale_tensor)
|
||||
|
||||
for i, weight_scale_name in enumerate(["up_gate_proj_weight_scale", "down_proj_weight_scale"]):
|
||||
in_scale_name = weight_scale_name.replace("_weight_scale", "_in_scale")
|
||||
|
||||
@@ -84,6 +84,13 @@ def parse_quant_config(args, model_config, is_ernie, is_v1_loader):
|
||||
quantization_config["moe_quant_type"] = "wint4"
|
||||
quantization_config["quantization"] = "mix_quant"
|
||||
quant_config_name = "mix_quant"
|
||||
# Special handling for moe w4afp8 dynamic quant
|
||||
elif quant_config_name == "w4afp8":
|
||||
quantization_config["dense_quant_type"] = "block_wise_fp8"
|
||||
quantization_config["moe_quant_type"] = "w4afp8"
|
||||
quantization_config["hadamard_block_size"] = 512
|
||||
quantization_config["quantization"] = "mix_quant"
|
||||
quant_config_name = "mix_quant"
|
||||
else:
|
||||
quant_config_name = None
|
||||
if quant_config_name is None:
|
||||
|
||||
@@ -31,7 +31,7 @@ class W4AFP8Config(QuantConfigBase):
|
||||
quantization config for weight 4bits and activation fp8
|
||||
"""
|
||||
|
||||
def __init__(self, weight_scale_dict, act_scale_dict, is_permuted, hadamard_block_size) -> None:
|
||||
def __init__(self, weight_scale_dict, act_scale_dict, is_permuted, hadamard_block_size, is_quantized) -> None:
|
||||
super().__init__()
|
||||
self.weight_scale_dict = weight_scale_dict
|
||||
self.act_scale_dict = act_scale_dict
|
||||
@@ -40,6 +40,7 @@ class W4AFP8Config(QuantConfigBase):
|
||||
self.quant_round_type = 1
|
||||
self.is_permuted = is_permuted
|
||||
self.hadamard_block_size = hadamard_block_size
|
||||
self.is_quantized = is_quantized
|
||||
|
||||
def name(self) -> str:
|
||||
return "w4afp8"
|
||||
@@ -50,7 +51,8 @@ class W4AFP8Config(QuantConfigBase):
|
||||
act_scale_dict = config.get("act_scale_dict", None)
|
||||
is_permuted = config.get("is_permuted", True)
|
||||
hadamard_block_size = config.get("hadamard_block_size", 128)
|
||||
return cls(weight_scale_dict, act_scale_dict, is_permuted, hadamard_block_size)
|
||||
is_quantized = config.get("is_quantized", False)
|
||||
return cls(weight_scale_dict, act_scale_dict, is_permuted, hadamard_block_size, is_quantized)
|
||||
|
||||
def get_quant_method(self, layer) -> Optional[QuantMethodBase]:
|
||||
if isinstance(layer, FusedMoE):
|
||||
|
||||
@@ -53,6 +53,173 @@ def pad_vocab_size(vocab_size: int, pad_to: int = DEFAULT_VOCAB_PADDING_SIZE) ->
|
||||
return ((vocab_size + pad_to - 1) // pad_to) * pad_to
|
||||
|
||||
|
||||
def random_orthogonal_matrix(size, device):
|
||||
"""
|
||||
Generate a random orthogonal matrix of the specified size.
|
||||
First, we generate a random matrix with entries from a standard distribution.
|
||||
Then, we use QR decomposition to obtain an orthogonal matrix.
|
||||
Finally, we multiply by a diagonal matrix with diag r to adjust the signs.
|
||||
|
||||
Args:
|
||||
size (int): The size of the matrix (size x size).
|
||||
|
||||
Returns:
|
||||
paddle.Tensor: An orthogonal matrix of the specified size.
|
||||
"""
|
||||
paddle.device.cuda.empty_cache()
|
||||
if device == "cuda":
|
||||
random_matrix = paddle.randn(size, size, dtype="float32").to("gpu")
|
||||
q, r = paddle.linalg.qr(random_matrix)
|
||||
q *= paddle.sign(paddle.diag(r)).unsqueeze(0)
|
||||
return q
|
||||
|
||||
|
||||
def is_pow2(n):
|
||||
return (n & (n - 1) == 0) and (n > 0)
|
||||
|
||||
|
||||
def get_hadK(n, transpose=False):
|
||||
hadK, K = None, None
|
||||
assert is_pow2(n)
|
||||
K = 1
|
||||
return hadK, K
|
||||
|
||||
|
||||
def matmul_hadU_int4(X, transpose=False):
|
||||
n = X.shape[-1]
|
||||
hadK, K = get_hadK(n, transpose)
|
||||
input = X.clone().reshape((-1, n, 1))
|
||||
output = input.clone()
|
||||
while input.shape[1] > K:
|
||||
input = input.reshape((input.shape[0], input.shape[1] // 2, 2, input.shape[2]))
|
||||
output = output.reshape(input.shape)
|
||||
output[:, :, 0, :] = input[:, :, 0, :] + input[:, :, 1, :]
|
||||
output[:, :, 1, :] = input[:, :, 0, :] - input[:, :, 1, :]
|
||||
output = output.reshape((input.shape[0], input.shape[1], -1))
|
||||
(input, output) = (output, input)
|
||||
del output
|
||||
|
||||
if K > 1:
|
||||
input = hadK.reshape((1, K, K)).to(input) @ input
|
||||
|
||||
return input.reshape(X.shape) / paddle.to_tensor(n, dtype="float32").sqrt()
|
||||
|
||||
|
||||
def random_hadamard_matrix_int4(size, device=None, ffn2=False):
|
||||
# See https://cornell-relaxml.github.io/quip-sharp/ , Section "Randomized Hadamard Transformation"
|
||||
if not ffn2:
|
||||
Q = paddle.randint(low=0, high=2, shape=(size,)).cast("float32")
|
||||
Q = paddle.ones_like(Q, dtype="float32")
|
||||
Q = Q * 2 - 1
|
||||
Q = paddle.diag(Q)
|
||||
return matmul_hadU_int4(Q), None
|
||||
|
||||
else:
|
||||
num_blocks = size
|
||||
while not (num_blocks % 2):
|
||||
num_blocks = num_blocks // 2
|
||||
block_size = size // num_blocks
|
||||
Q = paddle.diag(paddle.ones((block_size,), dtype="float32"))
|
||||
block = matmul_hadU_int4(Q)
|
||||
large_matrix = paddle.zeros([size, size])
|
||||
|
||||
for i in range(num_blocks):
|
||||
start_row = i * block_size
|
||||
start_col = i * block_size
|
||||
large_matrix[start_row : start_row + block_size, start_col : start_col + block_size] = block
|
||||
return large_matrix.cast("float32"), block_size
|
||||
|
||||
|
||||
def get_orthogonal_matrix(size, mode="hadamard", device="cuda"):
|
||||
if mode == "random":
|
||||
return random_orthogonal_matrix(size, device)
|
||||
elif mode == "hadamard":
|
||||
return random_hadamard_matrix_int4(size, device)
|
||||
elif mode == "hadamard_ffn2":
|
||||
return random_hadamard_matrix_int4(size, device, True)
|
||||
else:
|
||||
raise ValueError(f"Unknown mode {mode}")
|
||||
|
||||
|
||||
def rotate_model(state_dict, layer_idx, moe_num_experts=48, hidden_size=7168, moe_intermediate_size=3584, ep_rank=0):
|
||||
with paddle.no_grad():
|
||||
# collect hadamard rotation matrix [moe_intermediate_size, moe_intermediate_size]
|
||||
Q_ffn2, moe_block_size = get_orthogonal_matrix(size=moe_intermediate_size, mode="hadamard_ffn2")
|
||||
# down_proj.weight: [moe_intermediate_size, hidden_size]
|
||||
expert_list = [
|
||||
get_tensor(
|
||||
state_dict[
|
||||
f"ernie.layers.{layer_idx}.mlp.experts.{ep_rank * moe_num_experts + expert_idx}.down_proj.weight"
|
||||
]
|
||||
)
|
||||
for expert_idx in range(moe_num_experts)
|
||||
]
|
||||
moe_weight = paddle.concat(expert_list, axis=-1) # [moe_intermediate_size, hidden_size * moe_num_experts]
|
||||
new_moe_weight = Q_ffn2.cast("float32").T @ moe_weight.to(Q_ffn2.place)
|
||||
for expert_idx in range(moe_num_experts):
|
||||
rotated_weight = new_moe_weight[:, expert_idx * hidden_size : (expert_idx + 1) * hidden_size]
|
||||
expert_idx_local = ep_rank * moe_num_experts + expert_idx
|
||||
state_dict[f"ernie.layers.{layer_idx}.mlp.experts.{expert_idx_local}.down_proj.weight"] = (
|
||||
rotated_weight.cpu()
|
||||
)
|
||||
del moe_weight, new_moe_weight, rotated_weight
|
||||
paddle.device.cuda.empty_cache()
|
||||
return Q_ffn2.cpu()
|
||||
|
||||
|
||||
def pack(src, bits=4):
|
||||
pack_num = 8 // bits
|
||||
shift_bits = (paddle.arange(0, pack_num) * bits).cast("uint8")
|
||||
src = paddle.to_tensor(src).cast("uint8")
|
||||
|
||||
if len(src.shape) == 2:
|
||||
row, col = src.shape
|
||||
src = src.reshape((row, col // pack_num, pack_num))
|
||||
else:
|
||||
src = src.reshape((src.shape[0] // pack_num, pack_num))
|
||||
|
||||
src[..., 0] = paddle.bitwise_and(src[..., 0], paddle.to_tensor(15, dtype="uint8"))
|
||||
src = paddle.to_tensor(src.numpy() << shift_bits.numpy())
|
||||
|
||||
return src.sum(axis=-1).transpose((1, 0)).cast("int8")
|
||||
|
||||
|
||||
def group_wise_int4_weight_quantize(weight: paddle.Tensor, group_size: int = 128):
|
||||
"""
|
||||
Block-wise int4 weight quantization.
|
||||
|
||||
Args
|
||||
weight: paddle.Tensor
|
||||
group_size: int
|
||||
|
||||
Returns
|
||||
weight_quant: paddle.Tensor, int8 weight after quantization and pack
|
||||
weight_scale: paddle.Tensor, fp32 weight scale with group_size
|
||||
"""
|
||||
if weight.dtype == paddle.bfloat16:
|
||||
weight = weight.astype(paddle.float32)
|
||||
assert weight.dim() == 2
|
||||
weight = weight.transpose((1, 0))
|
||||
out_features, in_features = weight.shape
|
||||
q_max, q_min = 7, -8
|
||||
|
||||
# [out_features, in_features] -> [out_features, in_features // group_size, group_size]
|
||||
assert (
|
||||
in_features % group_size == 0
|
||||
), f"in_features must be divisible by group_size: {group_size}, but got in_features: {in_features}"
|
||||
weight = weight.reshape((out_features, in_features // group_size, group_size))
|
||||
|
||||
# calculate weight_scale
|
||||
abs_max = paddle.max(paddle.abs(weight), axis=-1, keepdim=False).astype(paddle.float32)
|
||||
weight_scale = paddle.clip(abs_max, min=1e-8)
|
||||
|
||||
quant_weight = paddle.round(weight / weight_scale.unsqueeze(-1) * q_max)
|
||||
quant_weight = paddle.clip(quant_weight, min=q_min, max=q_max)
|
||||
quant_weight = quant_weight.reshape((out_features, in_features)).transpose((1, 0))
|
||||
|
||||
return quant_weight.astype(paddle.int8), weight_scale
|
||||
|
||||
|
||||
def per_block_cast_to_fp8(x: Tensor, block_size: list = [128, 128]) -> Tuple[Tensor, Tensor]:
|
||||
"""
|
||||
Only used in deep_gemm block wise quant weight.
|
||||
|
||||
219
tests/layers/test_w4afp8_moe.py
Normal file
219
tests/layers/test_w4afp8_moe.py
Normal file
@@ -0,0 +1,219 @@
|
||||
import json
|
||||
import os
|
||||
import shutil
|
||||
import unittest
|
||||
|
||||
import paddle
|
||||
from paddle.distributed import fleet
|
||||
|
||||
from fastdeploy.config import (
|
||||
CacheConfig,
|
||||
FDConfig,
|
||||
GraphOptimizationConfig,
|
||||
LoadConfig,
|
||||
ModelConfig,
|
||||
ParallelConfig,
|
||||
)
|
||||
from fastdeploy.model_executor.layers.moe.moe import FusedMoE
|
||||
from fastdeploy.model_executor.layers.quantization.w4afp8 import W4AFP8Config
|
||||
from fastdeploy.scheduler import SchedulerConfig
|
||||
|
||||
# from fastdeploy.worker.worker_process import init_distributed_environment
|
||||
from tests.utils import OpPerformanceTester
|
||||
|
||||
paddle.set_default_dtype("bfloat16")
|
||||
|
||||
|
||||
class FuseMoEWrapper(paddle.nn.Layer):
|
||||
def __init__(
|
||||
self,
|
||||
model_config: ModelConfig,
|
||||
tp_size: int = 1,
|
||||
tp_rank: int = 0,
|
||||
ep_size: int = 1,
|
||||
ep_rank: int = 0,
|
||||
prefix: str = "ernie.layers.0",
|
||||
nnodes: int = 1,
|
||||
):
|
||||
super().__init__()
|
||||
self.model_config = model_config
|
||||
|
||||
self.tp_size = tp_size
|
||||
self.ep_size = ep_size
|
||||
self.ep_rank = ep_rank
|
||||
|
||||
self.prefix = prefix
|
||||
self.fd_config = FDConfig(
|
||||
model_config=self.model_config,
|
||||
parallel_config=ParallelConfig(
|
||||
{
|
||||
"tensor_parallel_size": self.tp_size,
|
||||
"expert_parallel_size": self.ep_size,
|
||||
"expert_parallel_rank": self.ep_rank,
|
||||
"data_parallel_size": self.ep_size,
|
||||
}
|
||||
),
|
||||
quant_config=W4AFP8Config(
|
||||
weight_scale_dict=None,
|
||||
act_scale_dict=None,
|
||||
is_permuted=False,
|
||||
hadamard_block_size=512,
|
||||
is_quantized=False,
|
||||
),
|
||||
scheduler_config=SchedulerConfig({}),
|
||||
cache_config=CacheConfig({}),
|
||||
graph_opt_config=GraphOptimizationConfig({}),
|
||||
load_config=LoadConfig({}),
|
||||
ips=",".join(["0"] * nnodes),
|
||||
)
|
||||
self.fd_config.parallel_config.tp_group = None
|
||||
self.fd_config.parallel_config.tensor_parallel_rank = tp_rank
|
||||
self.fd_config.parallel_config.expert_parallel_size = self.ep_size
|
||||
if self.ep_size > 1:
|
||||
self.fd_config.parallel_config.ep_group = fleet.get_hybrid_communicate_group().get_model_parallel_group()
|
||||
self.fd_config.scheduler_config.splitwise_role = "mixed"
|
||||
self.fd_config.model_config.moe_phase.phase = "decode"
|
||||
|
||||
weight_key_map = {
|
||||
"gate_weight_key": f"{self.prefix}.gate.weight",
|
||||
"gate_correction_bias_key": f"{self.prefix}.moe_statics.e_score_correction_bias",
|
||||
"up_gate_proj_expert_weight_key": f"{self.prefix}.mlp.experts.{{}}.up_gate_proj.weight",
|
||||
"down_proj_expert_weight_key": f"{self.prefix}.mlp.experts.{{}}.down_proj.weight",
|
||||
}
|
||||
|
||||
self.fused_moe = FusedMoE(
|
||||
fd_config=self.fd_config,
|
||||
moe_intermediate_size=self.fd_config.model_config.moe_intermediate_size,
|
||||
num_experts=self.fd_config.model_config.moe_num_experts,
|
||||
top_k=self.fd_config.model_config.moe_k,
|
||||
# avoiding invoke clean_low_latency_buffer in mixed ep.
|
||||
layer_idx=0,
|
||||
weight_key_map=weight_key_map,
|
||||
topk_method="noaux_tc",
|
||||
topk_group=4,
|
||||
n_group=8,
|
||||
gate_correction_bias=paddle.zeros([self.fd_config.model_config.moe_num_experts], paddle.float32),
|
||||
# gate_correction_bias = gate_correction_bias_real_data
|
||||
)
|
||||
self.pack_num = 1
|
||||
moe_layer = self.fused_moe
|
||||
|
||||
up_gate_proj_weight_shape = [
|
||||
moe_layer.num_local_experts,
|
||||
moe_layer.hidden_size // self.pack_num,
|
||||
moe_layer.moe_intermediate_size * 2,
|
||||
]
|
||||
down_proj_weight_shape = [
|
||||
moe_layer.num_local_experts,
|
||||
moe_layer.moe_intermediate_size // self.pack_num,
|
||||
moe_layer.hidden_size,
|
||||
]
|
||||
|
||||
up_gate_proj_weight = paddle.randn(up_gate_proj_weight_shape, paddle.bfloat16)
|
||||
down_proj_weight = paddle.randn(down_proj_weight_shape, paddle.bfloat16)
|
||||
|
||||
local_expert_ids = list(
|
||||
range(moe_layer.expert_id_offset, moe_layer.expert_id_offset + moe_layer.num_local_experts)
|
||||
)
|
||||
state_dict = {}
|
||||
up_gate_proj_expert_weight_key = moe_layer.weight_key_map.get("up_gate_proj_expert_weight_key")
|
||||
down_proj_expert_weight_key = moe_layer.weight_key_map.get("down_proj_expert_weight_key")
|
||||
|
||||
for expert_idx in local_expert_ids:
|
||||
up_gate_proj_expert_weight_key_name = up_gate_proj_expert_weight_key.format(expert_idx)
|
||||
down_proj_expert_weight_key_name = down_proj_expert_weight_key.format(expert_idx)
|
||||
|
||||
state_dict[up_gate_proj_expert_weight_key_name] = up_gate_proj_weight[
|
||||
expert_idx - moe_layer.expert_id_offset
|
||||
]
|
||||
state_dict[down_proj_expert_weight_key_name] = down_proj_weight[expert_idx - moe_layer.expert_id_offset]
|
||||
|
||||
moe_layer.load_state_dict(state_dict)
|
||||
|
||||
|
||||
class TestW4A8FusedMoE(unittest.TestCase):
|
||||
def setUp(self) -> None:
|
||||
self.architectures = ["Ernie4_5_MoeForCausalLM"]
|
||||
self.hidden_size = 256
|
||||
self.moe_intermediate_size = 256
|
||||
self.moe_num_experts = 2
|
||||
self.moe_k = 2
|
||||
self.hidden_act = "silu"
|
||||
self.num_attention_heads = 56
|
||||
self.num_hidden_layers = 1
|
||||
self.model_config = self.build_model_config()
|
||||
|
||||
def build_model_config(self) -> ModelConfig:
|
||||
model_name_or_path = self.build_config_json()
|
||||
return ModelConfig(
|
||||
{
|
||||
"model": model_name_or_path,
|
||||
"max_model_len": 2048,
|
||||
}
|
||||
)
|
||||
|
||||
def build_config_json(self) -> str:
|
||||
config_dict = {
|
||||
"architectures": self.architectures,
|
||||
"hidden_size": self.hidden_size,
|
||||
"moe_intermediate_size": self.moe_intermediate_size,
|
||||
"moe_num_experts": self.moe_num_experts,
|
||||
"moe_k": self.moe_k,
|
||||
"hidden_act": self.hidden_act,
|
||||
"num_attention_heads": self.num_attention_heads,
|
||||
"num_hidden_layers": self.num_hidden_layers,
|
||||
"dtype": "bfloat16",
|
||||
"is_quantized": False,
|
||||
}
|
||||
|
||||
tmp_dir = "./tmp_w4afp8_moe"
|
||||
os.makedirs(tmp_dir, exist_ok=True)
|
||||
with open(f"./{tmp_dir}/config.json", "w") as f:
|
||||
json.dump(config_dict, f)
|
||||
self.model_name_or_path = os.path.join(os.getcwd(), tmp_dir)
|
||||
return self.model_name_or_path
|
||||
|
||||
def test_fused_moe(self):
|
||||
# init_distributed_environment()
|
||||
|
||||
gating = paddle.nn.Linear(self.model_config.hidden_size, self.model_config.moe_num_experts)
|
||||
gating.to(dtype=paddle.float32) # it's dtype is bfloat16 default, but the forward input is float32
|
||||
gating.weight.set_value(paddle.rand(gating.weight.shape, dtype=paddle.float32))
|
||||
|
||||
# ep_size = paddle.distributed.get_world_size()
|
||||
# ep_rank = paddle.distributed.get_rank()
|
||||
ep_size = 1
|
||||
ep_rank = 0
|
||||
|
||||
tp_size = 1
|
||||
tp_rank = 0
|
||||
|
||||
nnodes = (ep_size + 7) // 8
|
||||
|
||||
# 这行代码必须保留,否则影响均匀性!
|
||||
paddle.seed(ep_rank + 100)
|
||||
|
||||
fused_moe = FuseMoEWrapper(self.model_config, tp_size, tp_rank, ep_size, ep_rank, nnodes=nnodes).fused_moe
|
||||
weight_size = fused_moe.top_k * fused_moe.hidden_size * fused_moe.moe_intermediate_size * 3 / 2
|
||||
|
||||
tester = OpPerformanceTester(
|
||||
op_name="w4afp8-moe",
|
||||
op_fn=fused_moe,
|
||||
num_layers=self.model_config.num_hidden_layers,
|
||||
weight_size=weight_size,
|
||||
gate=gating,
|
||||
)
|
||||
|
||||
tester.benchmark(
|
||||
input_size=self.model_config.hidden_size,
|
||||
batch_sizes=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
|
||||
)
|
||||
|
||||
def tearDown(self) -> None:
|
||||
if self.model_name_or_path:
|
||||
print("Remove tmp model config file")
|
||||
shutil.rmtree(self.model_name_or_path)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
@@ -32,6 +32,7 @@ class TestW4AFP8(unittest.TestCase):
|
||||
act_scale_dict={"layer.activation_scale": 1.0},
|
||||
is_permuted=False,
|
||||
hadamard_block_size=128,
|
||||
is_quantized=True,
|
||||
)
|
||||
self.method = W4AFP8LinearMethod(self.config)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user