# FastDeploy/test/operators/test_moe_ffn_wint2.py
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
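"""
Unit test and micro-benchmark for the moe_expert_ffn_wint2 operator (#3037).

Loads tensors dumped from an ernie45t tp1 wint2 checkpoint, runs the MoE FFN
through the weight-only int2 CUTLASS path, and compares its latency (and,
optionally, its accuracy) against the bf16 and weight-only int4 paths.
"""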
import os
import random

import numpy as np
import paddle
from paddle.nn.quant import weight_quantize

from fastdeploy.model_executor.ops.gpu import (
    moe_expert_dispatch,
    moe_expert_ffn,
    moe_expert_ffn_wint2,
    moe_expert_reduce,
)


def print_tensor_info(t, name):
    if t is not None:
        print(f"-- [print_tensor_info] {name}: shape={t.shape}, dtype={t.dtype}, data_ptr={t.data_ptr():#x}")
    else:
        print(f"-- [print_tensor_info] {name}: tensor is {t}")


def load_all_tensors(tensor_names, dump_dir):
    tensor_dict = {}
    for name in tensor_names:
        key = name.replace(".pdparams", "").replace("_layer1", "")
        filepath = os.path.join(dump_dir, name)
        if os.path.exists(filepath):
            tensor_dict[key] = paddle.load(filepath)
            if isinstance(tensor_dict[key], paddle.Tensor):
                print_tensor_info(tensor_dict[key], name)
            else:
                print(f"-- {name}: {tensor_dict[key]}")
        else:
            tensor_dict[key] = None
            print(f"-- {name}: {filepath} does not exist.")
    return tensor_dict
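

# Usage sketch (hypothetical dump directory; tensors are written with paddle.save):
#   paddle.save(paddle.randn([8192, 64]), "/tmp/dump/gate_weight")
#   tensors = load_all_tensors(["gate_weight"], "/tmp/dump")
#   tensors["gate_weight"]  # the loaded paddle.Tensor; missing files map to None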


def check_result(dtype, out_1, out_2, check_equal=False):
    def get_flattened_array(out):
        if isinstance(out, paddle.Tensor):
            if out.dtype == paddle.bfloat16:
                res = paddle.cast(out, dtype="float32").numpy()
            else:
                res = out.numpy()
        else:
            res = out
        return res.flatten()

    out_1_flatten = get_flattened_array(out_1)
    out_2_flatten = get_flattened_array(out_2)
    diff = np.abs(out_1_flatten - out_2_flatten)
    max_atol_idx = np.argmax(diff)
    print(
        f"-- max difference : {np.max(diff)}, {out_1_flatten[max_atol_idx]} vs {out_2_flatten[max_atol_idx]}, idx={max_atol_idx}"
    )
    relative_error = np.abs(diff / (out_2_flatten + 1e-8))
    max_rtol_idx = np.nanargmax(relative_error)
    print(
        f"-- max relative error : {np.nanmax(relative_error)}, {out_1_flatten[max_rtol_idx]} vs {out_2_flatten[max_rtol_idx]}"
    )
    if check_equal:
        # Print at most 10 mismatched elements before asserting exact equality.
        num_diffs = 0
        for i in range(out_1_flatten.size):
            if num_diffs >= 10:
                break
            if out_1_flatten[i] != out_2_flatten[i]:
                print(f"-- {i}: {out_1_flatten[i]} vs {out_2_flatten[i]}")
                num_diffs += 1
        np.testing.assert_array_equal(out_1_flatten, out_2_flatten)
    else:
        if dtype == "float32":
            if os.getenv("NVIDIA_TF32_OVERRIDE", "1") == "0":
                atol, rtol = 1e-5, 1e-5
            else:
                atol, rtol = 1e-3, 1e-3
        elif dtype == "float16":
            atol, rtol = 1e-3, 1e-3
        elif dtype == "bfloat16":
            atol, rtol = 1e-2, 1e-2
        np.testing.assert_allclose(
            out_1_flatten,
            out_2_flatten,
            atol=atol,
            rtol=rtol,
        )
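

# Illustrative call (toy tensors; for "bfloat16" the tolerance is atol=rtol=1e-2):
#   a = paddle.ones([8], dtype="bfloat16")
#   b = a + paddle.full([8], 5e-3).cast("bfloat16")
#   check_result("bfloat16", a, b)  # passes: the difference stays within 1e-2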


def unzip_and_dequant_wint2(
    w, w_scale, w_code_scale, w_code_zp, w_super_scale=None, scale_compute_dtype=None, shuffled=False, group_size=64
):
    """
    w             uint8           [num_experts, in_feature_size // pack_num, out_feature_size]
    w_scale                       [num_experts, in_feature_size // group_size, out_feature_size]
    w_code_scale  float32         [num_experts, out_feature_size]
    w_code_zp     float32         [num_experts, out_feature_size]
    w_super_scale w_scale.dtype   [num_experts, out_feature_size]
    output:       w_scale.dtype   [num_experts, in_feature_size, out_feature_size]
    """

    def w_round(x):
        return paddle.floor(x + 0.5)

    # step0: w dtype: uint8, shape: [num_experts, in_feature_size // pack_num, out_feature_size],
    # where pack_num = 4.
    pack_num = 4
    bzp = 32
    num_experts, pack_in_feature_size, out_feature_size = w.shape
    in_feature_size = pack_in_feature_size * pack_num

    # step1: unzip w to shape [num_experts, in_feature_size, out_feature_size].
    # A broadcast is used to implicitly expand the last dimension.
    w = w.transpose(perm=[0, 2, 1]).reshape([num_experts, out_feature_size, pack_in_feature_size, 1])
    # Cast w to int32 so that repeat_interleave is supported.
    w = w.cast("int32")
    w = w.repeat_interleave(pack_num, axis=-1)
    w = w.reshape([num_experts, out_feature_size, in_feature_size])
    w = w.transpose(perm=[0, 2, 1])

    # step2: first dequantization of w with the per-channel code scale / zero point.
    # w_code_scale shape: [num_experts, out_feature_size]
    # w_code_zp shape:    [num_experts, out_feature_size]
    w_code_scale = w_code_scale.reshape([num_experts, 1, out_feature_size])
    w_code_zp = w_code_zp.reshape([num_experts, 1, out_feature_size])
    w = w_round(w.cast("float32") * w_code_scale + w_code_zp).cast("int32")

    # step3: shift and mask w to unzip the original weights.
    bit_shift = paddle.to_tensor([9, 6, 3, 0], dtype="int32")
    in_feature_bit_shift = bit_shift[paddle.arange(in_feature_size) % pack_num]
    in_feature_bit_shift = in_feature_bit_shift.reshape([1, in_feature_size, 1])
    mask = paddle.to_tensor(0x3F, dtype="int32")
    if scale_compute_dtype is None:
        scale_compute_dtype = w_super_scale.dtype if w_super_scale is not None else w_scale.dtype
    group_num = in_feature_size // group_size

    # step4: shift, mask, and dequantize w_scale if it is packed as two 4-bit values per byte.
    if w_scale.dtype == paddle.uint8:
        # w_scale logical shape: [num_experts, in_feature_size // group_size, out_feature_size]
        # w_scale packed shape:  [num_experts, group_num // 2, out_feature_size]
        w_scale = w_scale.cast("int32")
        w_scale = w_scale.reshape([num_experts, group_num // 2, 1, out_feature_size])
        w_scale = w_scale.repeat_interleave(2, axis=2)
        w_scale = (w_scale >> paddle.to_tensor([4, 0], dtype="int32").reshape([1, 1, 2, 1])) & paddle.to_tensor(
            0xF, dtype="int32"
        )
        w_scale = w_scale.reshape([num_experts, group_num, out_feature_size]).cast(scale_compute_dtype)

    # step5: shift and mask w, then apply the second dequantization.
    w = ((w >> in_feature_bit_shift) & mask).cast(w_scale.dtype)
    if w_super_scale is not None:
        # w_super_scale shape: [num_experts, out_feature_size]
        w_super_scale = w_super_scale.reshape([num_experts, 1, out_feature_size])
        w_scale = w_scale * w_super_scale
    # Expand w_scale to [num_experts, in_feature_size, out_feature_size].
    w_scale = w_scale.reshape([num_experts, in_feature_size // group_size, 1, out_feature_size])
    w_scale = w_scale.repeat_interleave(group_size, axis=2).reshape([num_experts, in_feature_size, out_feature_size])
    w = (w - bzp).cast(w_scale.dtype) * w_scale

    if shuffled:
        # Undo the row shuffle: swap the [4, 8] sub-block order within every 64 rows.
        w = w.reshape([num_experts, in_feature_size // 64, 4, 8, 2, out_feature_size])
        w = paddle.transpose(w, perm=[0, 1, 3, 2, 4, 5])
        w = w.reshape([num_experts, in_feature_size, out_feature_size])
    # Guard against w_super_scale being None (the default), in which case the
    # output keeps w_scale's dtype as documented above.
    out_dtype = w_super_scale.dtype if w_super_scale is not None else w_scale.dtype
    return w.cast(out_dtype)
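

# Shape sanity sketch (toy sizes and identity scales; real inputs come from the
# quantizer dump loaded in prepare_args_wint2 below):
#   E, in_packed, out = 2, 16, 8  # in_feature_size = 16 * pack_num = 64
#   w = paddle.randint(0, 256, [E, in_packed, out]).cast("uint8")
#   local_scale = paddle.ones([E, 1, out], dtype="bfloat16")  # one group of 64 rows
#   code_scale = paddle.ones([E, out], dtype="float32")
#   code_zp = paddle.zeros([E, out], dtype="float32")
#   super_scale = paddle.ones([E, out], dtype="bfloat16")
#   w_bf16 = unzip_and_dequant_wint2(w, local_scale, code_scale, code_zp, super_scale)
#   assert w_bf16.shape == [E, 64, out]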


class MoEArguments:
    def __init__(
        self,
        quant_method,
        gate_weight,
        ffn1_weight,
        ffn2_weight,
        ffn1_weight_scale,
        ffn2_weight_scale,
        ffn1_local_scale=None,
        ffn1_code_scale=None,
        ffn1_code_zp=None,
        ffn2_local_scale=None,
        ffn2_code_scale=None,
        ffn2_code_zp=None,
        gate_correction_bias=None,
        topk=8,
    ):
        self.quant_method = quant_method
        self.gate_weight = gate_weight
        self.gate_correction_bias = gate_correction_bias
        self.topk = topk
        self.ffn1_weight = ffn1_weight
        self.ffn2_weight = ffn2_weight
        self.ffn1_weight_scale = ffn1_weight_scale
        self.ffn2_weight_scale = ffn2_weight_scale
        self.ffn1_local_scale = ffn1_local_scale
        self.ffn1_code_scale = ffn1_code_scale
        self.ffn1_code_zp = ffn1_code_zp
        self.ffn2_local_scale = ffn2_local_scale
        self.ffn2_code_scale = ffn2_code_scale
        self.ffn2_code_zp = ffn2_code_zp
        if quant_method == "none":
            self.dtype = ffn1_weight.dtype
        else:
            self.dtype = ffn1_weight_scale.dtype
        self.num_experts = ffn1_weight.shape[0]
        # ffn1 fuses the gate and up projections, so its output width is 2 * intermediate_size.
        if ffn1_weight_scale is not None:
            self.intermediate_size = ffn1_weight_scale.shape[1] // 2
        else:
            self.intermediate_size = ffn1_weight.shape[2] // 2
        if ffn2_weight_scale is not None:
            self.hidden_size = ffn2_weight_scale.shape[1]
        else:
            self.hidden_size = ffn2_weight.shape[2]

    def convert_to_bf16(self, shuffled=False):
        if self.quant_method == "weight_only_int2":
            assert (
                self.dtype != self.ffn1_weight.dtype
            ), f"dtype:{self.dtype} vs weight_dtype: {self.ffn1_weight.dtype}"
            ffn1_weight = unzip_and_dequant_wint2(
                w=self.ffn1_weight,
                w_scale=self.ffn1_local_scale,
                w_code_scale=self.ffn1_code_scale,
                w_code_zp=self.ffn1_code_zp,
                w_super_scale=self.ffn1_weight_scale,
                shuffled=shuffled,
                group_size=64,
            )
            ffn2_weight = unzip_and_dequant_wint2(
                w=self.ffn2_weight,
                w_scale=self.ffn2_local_scale,
                w_code_scale=self.ffn2_code_scale,
                w_code_zp=self.ffn2_code_zp,
                w_super_scale=self.ffn2_weight_scale,
                shuffled=shuffled,
                group_size=64,
            )
            other = MoEArguments(
                quant_method="none",
                gate_weight=self.gate_weight,
                ffn1_weight=ffn1_weight,
                ffn2_weight=ffn2_weight,
                ffn1_weight_scale=None,
                ffn2_weight_scale=None,
                gate_correction_bias=self.gate_correction_bias,
                topk=self.topk,
            )
            return other
        else:
            raise NotImplementedError(f"convert_to_bf16 does not support quant_method={self.quant_method}.")

    def convert_to_wint4(self):
        assert self.quant_method == "none"
        assert self.dtype == self.ffn1_weight.dtype, f"dtype:{self.dtype} vs weight_dtype: {self.ffn1_weight.dtype}"

        def quantize_ffn_weight(ffn_weight):
            # Quantize each expert's weight independently into the weight-only int4 layout.
            weight_list = []
            scale_list = []
            for i in range(ffn_weight.shape[0]):
                quant_weight, scale = weight_quantize(ffn_weight[i, :, :], algo="weight_only_int4", arch=80)
                weight_list.append(quant_weight)
                scale_list.append(scale)
            quanted_weight = paddle.stack(weight_list, axis=0)
            scale = paddle.stack(scale_list, axis=0)
            return quanted_weight, scale

        ffn1_weight, ffn1_weight_scale = quantize_ffn_weight(self.ffn1_weight)
        ffn2_weight, ffn2_weight_scale = quantize_ffn_weight(self.ffn2_weight)
        other = MoEArguments(
            quant_method="weight_only_int4",
            gate_weight=self.gate_weight,
            ffn1_weight=ffn1_weight,
            ffn2_weight=ffn2_weight,
            ffn1_weight_scale=ffn1_weight_scale,
            ffn2_weight_scale=ffn2_weight_scale,
            gate_correction_bias=self.gate_correction_bias,
            topk=self.topk,
        )
        return other

    def print(self):
        print("")
        print(f"-- [MoEArguments] dtype: {self.dtype}")
        print(f"-- [MoEArguments] num_experts: {self.num_experts}")
        print(f"-- [MoEArguments] intermediate_size: {self.intermediate_size}")
        print(f"-- [MoEArguments] hidden_size: {self.hidden_size}")
        print_tensor_info(self.gate_correction_bias, "gate_correction_bias")
        print_tensor_info(self.ffn1_weight, "ffn1_weight")
        print_tensor_info(self.ffn2_weight, "ffn2_weight")
        print_tensor_info(self.ffn1_weight_scale, "ffn1_weight_scale")
        print_tensor_info(self.ffn2_weight_scale, "ffn2_weight_scale")
        print_tensor_info(self.ffn1_local_scale, "ffn1_local_scale")
        print_tensor_info(self.ffn2_local_scale, "ffn2_local_scale")
        print_tensor_info(self.ffn1_code_scale, "ffn1_code_scale")
        print_tensor_info(self.ffn2_code_scale, "ffn2_code_scale")
        print_tensor_info(self.ffn1_code_zp, "ffn1_code_zp")
        print_tensor_info(self.ffn2_code_zp, "ffn2_code_zp")


def prepare_args_wint2(test_dir):
    tensor_names = [
        "x",
        "gate_weight",
        "topk_ids",
        "gate_correction_bias",
        "ffn1_weight",
        "ffn2_weight",
        "ffn1_super_scales",
        "ffn2_super_scales",
        "ffn1_weight_scale",
        "ffn1_code_scale",
        "ffn1_code_zp",
        "ffn2_weight_scale",
        "ffn2_code_scale",
        "ffn2_code_zp",
    ]
    tensor_dict = load_all_tensors(tensor_names, test_dir)
    topk = tensor_dict["topk_ids"].shape[1]
    # Note the remapping: the dumped "ffn*_super_scales" become the weight scales,
    # while the dumped "ffn*_weight_scale" tensors are the per-group local scales.
    moe_args = MoEArguments(
        quant_method="weight_only_int2",
        gate_weight=tensor_dict["gate_weight"],
        ffn1_weight=tensor_dict["ffn1_weight"],
        ffn2_weight=tensor_dict["ffn2_weight"],
        ffn1_weight_scale=tensor_dict["ffn1_super_scales"],
        ffn2_weight_scale=tensor_dict["ffn2_super_scales"],
        ffn1_local_scale=tensor_dict["ffn1_weight_scale"],
        ffn1_code_scale=tensor_dict["ffn1_code_scale"],
        ffn1_code_zp=tensor_dict["ffn1_code_zp"],
        ffn2_local_scale=tensor_dict["ffn2_weight_scale"],
        ffn2_code_scale=tensor_dict["ffn2_code_scale"],
        ffn2_code_zp=tensor_dict["ffn2_code_zp"],
        gate_correction_bias=tensor_dict["gate_correction_bias"],
        topk=topk,
    )
    return moe_args


def run_moe_decode_cutlass(moe_args, quant_method, hidden_states, scores):
    def rearrange_weights(w):
        # [num_experts, in_feature_size, out_feature_size]
        w_shape = w.shape
        # [num_experts, in_feature_size // 16, 16, out_feature_size // 8, 8]
        w = w.reshape([w_shape[0], w_shape[1] // 16, 16, w_shape[2] // 8, 8])
        # [num_experts, out_feature_size // 8, in_feature_size // 16, 8, 16]
        w = paddle.transpose(w, perm=[0, 3, 1, 4, 2])
        w = w.reshape(w_shape)
        return w

    if quant_method == "weight_only_int2":
        ffn1_weight = rearrange_weights(moe_args.ffn1_weight)
        ffn2_weight = rearrange_weights(moe_args.ffn2_weight)

    # Large buffer used to flush the L2 cache before each timed iteration.
    cache = paddle.empty((int(512e6 // 4),), dtype="int32")
    warmup, repeat = 5, 100
    gpu_timecosts = []
    for i in range(warmup + repeat):
        start_event = paddle.device.Event(enable_timing=True)
        end_event = paddle.device.Event(enable_timing=True)
        cache.zero_()  # fast_flush
        start_event.record()
        (
            permute_input,
            token_nums_per_expert,
            permute_indices_per_token,
            topk_weights,
            topk_indices,
            expert_idx_per_token,
        ) = moe_expert_dispatch(
            input=hidden_states,
            gating_output=scores,
            gating_correction_bias=moe_args.gate_correction_bias,
            w4a8_in_scale=None,
            moe_topk=moe_args.topk,
            group_moe=False,
            topk_only_mode=moe_args.gate_correction_bias is None,
        )
        if quant_method == "weight_only_int2":
            ffn_out = moe_expert_ffn_wint2(
                permute_input,
                token_nums_per_expert,
                ffn1_weight,
                ffn2_weight,
                None,
                moe_args.ffn1_weight_scale,
                moe_args.ffn2_weight_scale,
                moe_args.ffn1_local_scale,
                moe_args.ffn1_code_scale,
                moe_args.ffn1_code_zp,
                moe_args.ffn2_local_scale,
                moe_args.ffn2_code_scale,
                moe_args.ffn2_code_zp,
                False,
            )
        else:
            ffn_out = moe_expert_ffn(
                permute_input,
                token_nums_per_expert,
                moe_args.ffn1_weight,
                moe_args.ffn2_weight,
                None,
                moe_args.ffn1_weight_scale,
                moe_args.ffn2_weight_scale,
                None,
                None,
                quant_method,
                False,
            )
        moe_out = moe_expert_reduce(
            ffn_out,
            topk_weights,
            permute_indices_per_token,
            topk_indices,
            None,
            norm_topk_prob=True,
            routed_scaling_factor=1.0,
        )
        end_event.record()
        gpu_timecosts.append(start_event.elapsed_time(end_event))
        cache += int(random.random() * 1000)  # perturb the cache buffer between iterations
    paddle.device.synchronize()
    del cache

    # Discard warmup iterations and report the median latency.
    gpu_timecosts = gpu_timecosts[warmup:]
    return moe_out, np.quantile(gpu_timecosts, 0.5)


def test_main(test_dir):
    moe_args = prepare_args_wint2(test_dir)
    moe_args.print()

    quant_method = "weight_only_int2"
    check_acc = False
    moe_args_bf16 = moe_args.convert_to_bf16(shuffled=True)
    moe_args_wint4 = moe_args_bf16.convert_to_wint4()
    for num_tokens in [1, 2, 4, 16, 64, 128, 512, 1024]:
        hidden_states = paddle.randn([num_tokens, moe_args.hidden_size]).cast(moe_args.dtype)
        gate_out = paddle.matmul(hidden_states.cast("float32"), moe_args.gate_weight)
        scores = paddle.nn.functional.softmax(gate_out, axis=-1)
        out_wint2, timecost_wint2 = run_moe_decode_cutlass(moe_args, quant_method, hidden_states, scores)
        out_bf16, timecost_bf16 = run_moe_decode_cutlass(moe_args_bf16, "none", hidden_states, scores)
        out_wint4, timecost_wint4 = run_moe_decode_cutlass(moe_args_wint4, "weight_only_int4", hidden_states, scores)
        print(
            f"[Time Cost] num_tokens: {num_tokens}, {quant_method}: {timecost_wint2:.5f} ms; bf16: {timecost_bf16:.5f} ms; wint4: {timecost_wint4:.5f} ms"
        )
        if check_acc:
            check_result("bfloat16", out_wint2, out_bf16, check_equal=False)
if __name__ == "__main__":
    paddle.seed(1024)
    test_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "ernie45t_tp1_wint2_params")
    test_main(test_dir)