[Quantization] Support w4afp8 MoE dynamic quantization (#5282)

* support dynamic activation quant for w4afp8

* support dynamic w4afp8

* add test

* fix

* fix

---------

Co-authored-by: zhoutianzi666 <17801055074@163.com>
This commit is contained in:
Sunny-bot1
2025-12-02 18:56:16 +08:00
committed by GitHub
parent 429dd2b1db
commit 3629db4129
9 changed files with 478 additions and 37 deletions

View File

@@ -53,6 +53,173 @@ def pad_vocab_size(vocab_size: int, pad_to: int = DEFAULT_VOCAB_PADDING_SIZE) ->
return ((vocab_size + pad_to - 1) // pad_to) * pad_to
def random_orthogonal_matrix(size, device):
    """
    Generate a random orthogonal matrix of the specified size.

    First, we generate a random matrix with entries from a standard normal
    distribution. Then, we use QR decomposition to obtain an orthogonal
    matrix. Finally, we multiply by the signs of diag(r) to fix the sign
    ambiguity of the QR factorization.

    Args:
        size (int): The size of the matrix (size x size).
        device (str): "cuda" to place the matrix on GPU; any other value
            keeps it on the framework's default device.

    Returns:
        paddle.Tensor: An orthogonal matrix of shape [size, size].
    """
    paddle.device.cuda.empty_cache()
    # BUG FIX: paddle.randn takes the shape as a single list argument;
    # paddle.randn(size, size, ...) passed `size` where the shape list
    # belongs and failed. Also, the original only defined random_matrix
    # inside the `device == "cuda"` branch, so any other device raised
    # NameError at the qr() call below.
    random_matrix = paddle.randn([size, size], dtype="float32")
    if device == "cuda":
        random_matrix = random_matrix.to("gpu")
    q, r = paddle.linalg.qr(random_matrix)
    # Scale each column of q by sign(diag(r)) so the result is uniquely
    # determined (and uniformly distributed over orthogonal matrices).
    q *= paddle.sign(paddle.diag(r)).unsqueeze(0)
    return q
def is_pow2(n):
    """Return True iff n is a positive power of two."""
    if n <= 0:
        return False
    # A power of two has exactly one bit set, so n & (n-1) clears it.
    return n & (n - 1) == 0
def get_hadK(n, transpose=False):
    """
    Return the remainder Hadamard block and factor K for dimension n.

    Only power-of-two sizes are supported, so no explicit remainder
    matrix is ever needed: hadK is always None and K is always 1.

    Args:
        n (int): transform dimension; must be a power of two.
        transpose (bool): accepted for interface compatibility; unused.

    Returns:
        tuple: (None, 1)
    """
    assert is_pow2(n)
    return None, 1
def matmul_hadU_int4(X, transpose=False):
    """
    Apply a normalized Hadamard transform along the last axis of X.

    Implements the fast Walsh-Hadamard transform with in-place butterfly
    passes: each pass pairs adjacent halves and writes their sums and
    differences, then swaps the two buffers. Dividing by sqrt(n) at the
    end makes the transform orthonormal.

    Args:
        X (paddle.Tensor): input whose last dimension n must be a power
            of two (enforced by get_hadK's assert).
        transpose (bool): forwarded to get_hadK; unused for power-of-two n.

    Returns:
        paddle.Tensor: transformed tensor with the same shape as X.
    """
    n = X.shape[-1]
    hadK, K = get_hadK(n, transpose)
    input = X.clone().reshape((-1, n, 1))
    output = input.clone()
    # Butterfly passes: split the middle axis into pairs, write (a+b, a-b)
    # into `output`, then swap buffers so the result feeds the next pass.
    while input.shape[1] > K:
        input = input.reshape((input.shape[0], input.shape[1] // 2, 2, input.shape[2]))
        output = output.reshape(input.shape)
        output[:, :, 0, :] = input[:, :, 0, :] + input[:, :, 1, :]
        output[:, :, 1, :] = input[:, :, 0, :] - input[:, :, 1, :]
        output = output.reshape((input.shape[0], input.shape[1], -1))
        (input, output) = (output, input)
    del output
    # K > 1 would require an explicit remainder Hadamard block; get_hadK
    # only supports power-of-two n, so K is always 1 and this is dead here.
    if K > 1:
        input = hadK.reshape((1, K, K)).to(input) @ input
    # Normalize by sqrt(n) so the transform is orthonormal.
    return input.reshape(X.shape) / paddle.to_tensor(n, dtype="float32").sqrt()
def random_hadamard_matrix_int4(size, device=None, ffn2=False):
    """
    Build a (block-)Hadamard rotation matrix for int4 quantization.

    See https://cornell-relaxml.github.io/quip-sharp/ , Section
    "Randomized Hadamard Transformation".

    NOTE: despite the name, the result is deterministic. The original
    code drew a random +-1 diagonal with paddle.randint and then
    immediately overwrote it with paddle.ones_like, so the diagonal was
    always all ones; the dead randint call has been removed here (this
    also stops consuming one draw from the global RNG state).

    Args:
        size (int): dimension of the square rotation matrix.
        device: unused; kept for interface compatibility.
        ffn2 (bool): when True, build a block-diagonal matrix whose
            diagonal blocks are Hadamard matrices of the largest
            power-of-two size dividing `size` (handles sizes that are
            not themselves powers of two).

    Returns:
        tuple: (matrix, block_size); block_size is None when ffn2 is False.
    """
    if not ffn2:
        # Diagonal of ones -> Hadamard transform of the identity, i.e. a
        # plain normalized Hadamard matrix of shape [size, size].
        Q = paddle.diag(paddle.ones((size,), dtype="float32"))
        return matmul_hadU_int4(Q), None
    else:
        # Strip all factors of two from `size`; what remains (the odd
        # part) is the number of blocks, and size // num_blocks is the
        # largest power-of-two block size.
        num_blocks = size
        while not (num_blocks % 2):
            num_blocks = num_blocks // 2
        block_size = size // num_blocks
        Q = paddle.diag(paddle.ones((block_size,), dtype="float32"))
        block = matmul_hadU_int4(Q)
        # Tile the Hadamard block down the diagonal of a [size, size] matrix.
        large_matrix = paddle.zeros([size, size])
        for i in range(num_blocks):
            start_row = i * block_size
            start_col = i * block_size
            large_matrix[start_row : start_row + block_size, start_col : start_col + block_size] = block
        return large_matrix.cast("float32"), block_size
def get_orthogonal_matrix(size, mode="hadamard", device="cuda"):
    """
    Dispatch to the requested rotation-matrix builder.

    Args:
        size (int): matrix dimension.
        mode (str): one of "random", "hadamard", "hadamard_ffn2".
        device (str): target device, forwarded to the builder.

    Returns:
        The selected builder's result (a tensor for "random", or a
        (matrix, block_size) tuple for the Hadamard variants).

    Raises:
        ValueError: if `mode` is not one of the supported names.
    """
    if mode == "random":
        return random_orthogonal_matrix(size, device)
    if mode == "hadamard":
        return random_hadamard_matrix_int4(size, device)
    if mode == "hadamard_ffn2":
        return random_hadamard_matrix_int4(size, device, True)
    raise ValueError(f"Unknown mode {mode}")
def rotate_model(state_dict, layer_idx, moe_num_experts=48, hidden_size=7168, moe_intermediate_size=3584, ep_rank=0):
    """
    Rotate every expert's down_proj weight of one MoE layer by a shared
    block-Hadamard matrix, writing the results back into `state_dict`.

    All experts' [moe_intermediate_size, hidden_size] weights are
    concatenated along the column axis, left-multiplied by Q_ffn2^T in a
    single matmul, then sliced apart and stored back on CPU.

    Args:
        state_dict: mapping of parameter names to weights; modified in place.
        layer_idx: decoder layer index used to build parameter names.
        moe_num_experts: number of experts handled by this EP rank.
        hidden_size: column count of each expert's down_proj weight.
        moe_intermediate_size: row count of each expert's down_proj weight.
        ep_rank: expert-parallel rank; offsets the global expert index.

    Returns:
        paddle.Tensor: the rotation matrix Q_ffn2 on CPU, so the caller
        can apply the matching inverse rotation elsewhere.
    """
    with paddle.no_grad():
        # collect hadamard rotation matrix [moe_intermediate_size, moe_intermediate_size]
        Q_ffn2, moe_block_size = get_orthogonal_matrix(size=moe_intermediate_size, mode="hadamard_ffn2")
        # down_proj.weight: [moe_intermediate_size, hidden_size]
        expert_list = [
            get_tensor(
                state_dict[
                    f"ernie.layers.{layer_idx}.mlp.experts.{ep_rank * moe_num_experts + expert_idx}.down_proj.weight"
                ]
            )
            for expert_idx in range(moe_num_experts)
        ]
        # Fuse all experts so one matmul rotates everything at once.
        moe_weight = paddle.concat(expert_list, axis=-1)  # [moe_intermediate_size, hidden_size * moe_num_experts]
        new_moe_weight = Q_ffn2.cast("float32").T @ moe_weight.to(Q_ffn2.place)
        for expert_idx in range(moe_num_experts):
            # Slice this expert's rotated columns back out of the fused result.
            rotated_weight = new_moe_weight[:, expert_idx * hidden_size : (expert_idx + 1) * hidden_size]
            # NOTE(review): despite the "_local" name this is the global
            # expert index — it matches the key read above.
            expert_idx_local = ep_rank * moe_num_experts + expert_idx
            state_dict[f"ernie.layers.{layer_idx}.mlp.experts.{expert_idx_local}.down_proj.weight"] = (
                rotated_weight.cpu()
            )
        # Free the large intermediates before returning.
        del moe_weight, new_moe_weight, rotated_weight
        paddle.device.cuda.empty_cache()
        return Q_ffn2.cpu()
def pack(src, bits=4):
    """
    Pack low-bit integer values into int8 bytes (two int4 values per byte
    for bits=4).

    Adjacent pairs along the last axis are combined into one byte: slot 0
    occupies the low nibble, slot 1 the high nibble.

    Args:
        src: 2-D tensor/array of quantized values to pack.
            NOTE(review): the 1-D branch below also reaches the final
            transpose((1, 0)), which requires a 2-D tensor — presumably
            only the 2-D path is exercised; verify against callers.
        bits (int): bit width per value; 4 gives pack_num = 2 slots/byte.

    Returns:
        paddle.Tensor: packed int8 tensor with axes swapped via perm (1, 0).
    """
    pack_num = 8 // bits
    # Per-slot left shift amount: [0, 4] for 4-bit packing.
    shift_bits = (paddle.arange(0, pack_num) * bits).cast("uint8")
    src = paddle.to_tensor(src).cast("uint8")
    if len(src.shape) == 2:
        row, col = src.shape
        src = src.reshape((row, col // pack_num, pack_num))
    else:
        src = src.reshape((src.shape[0] // pack_num, pack_num))
    # Mask slot 0 to its low nibble: negative int4 values wrap to uint8
    # with high bits set. Slot 1 needs no mask — its high bits overflow
    # out of uint8 during the << 4 shift below.
    src[..., 0] = paddle.bitwise_and(src[..., 0], paddle.to_tensor(15, dtype="uint8"))
    # Shift in numpy (uint8 shifts wrap as needed), then sum the slots
    # to merge both nibbles into one byte.
    src = paddle.to_tensor(src.numpy() << shift_bits.numpy())
    return src.sum(axis=-1).transpose((1, 0)).cast("int8")
def group_wise_int4_weight_quantize(weight: paddle.Tensor, group_size: int = 128):
    """
    Group-wise symmetric int4 quantization of a 2-D weight.

    The weight is transposed to [out_features, in_features], the input
    dimension is split into groups of `group_size`, and every group is
    divided by its absolute maximum before rounding into [-8, 7].

    Args:
        weight: 2-D weight tensor; bfloat16 inputs are promoted to float32.
        group_size: number of input channels sharing one scale.

    Returns:
        quant_weight: int8 tensor holding int4 values, transposed back
            to the original [in_features, out_features] layout.
        weight_scale: fp32 per-group absolute maxima (clipped away from
            zero), shape [out_features, in_features // group_size].
    """
    if weight.dtype == paddle.bfloat16:
        weight = weight.astype(paddle.float32)
    assert weight.dim() == 2
    transposed = weight.transpose((1, 0))
    out_features, in_features = transposed.shape
    assert (
        in_features % group_size == 0
    ), f"in_features must be divisible by group_size: {group_size}, but got in_features: {in_features}"
    q_max, q_min = 7, -8
    # [out_features, in_features] -> [out_features, in_features // group_size, group_size]
    grouped = transposed.reshape((out_features, in_features // group_size, group_size))
    # Per-group scale = abs-max, clipped to avoid division by zero.
    abs_max = paddle.max(paddle.abs(grouped), axis=-1, keepdim=False).astype(paddle.float32)
    weight_scale = paddle.clip(abs_max, min=1e-8)
    # Map each group into [-q_max, q_max], round, then clamp to the int4 range.
    scaled = grouped / weight_scale.unsqueeze(-1) * q_max
    quantized = paddle.clip(paddle.round(scaled), min=q_min, max=q_max)
    quant_weight = quantized.reshape((out_features, in_features)).transpose((1, 0))
    return quant_weight.astype(paddle.int8), weight_scale
def per_block_cast_to_fp8(x: Tensor, block_size: list = [128, 128]) -> Tuple[Tensor, Tensor]:
"""
Only used in deep_gemm block wise quant weight.