mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-12-24 13:28:13 +08:00
[Quantization] Support w4afp8 MoE dynamic quantization (#5282)
* support dynamic activation quant for w4afp8; * support dynamic w4afp8; * add test; * fix. Co-authored-by: zhoutianzi666 <17801055074@163.com>
This commit is contained in:
@@ -53,6 +53,173 @@ def pad_vocab_size(vocab_size: int, pad_to: int = DEFAULT_VOCAB_PADDING_SIZE) ->
|
||||
return ((vocab_size + pad_to - 1) // pad_to) * pad_to
|
||||
|
||||
|
||||
def random_orthogonal_matrix(size, device):
    """
    Generate a random orthogonal matrix of the specified size.

    First, we generate a random matrix with entries from a standard
    distribution. Then, we use QR decomposition to obtain an orthogonal
    matrix. Finally, we multiply by a diagonal matrix with the signs of
    diag(r) to fix the column signs of q.

    Args:
        size (int): The size of the matrix (size x size).
        device (str): "cuda" to place the matrix on the GPU; any other
            value keeps it on the default device.

    Returns:
        paddle.Tensor: An orthogonal matrix of shape [size, size] (float32).
    """
    paddle.device.cuda.empty_cache()
    # BUG FIX: paddle.randn takes the shape as a single list argument;
    # the original randn(size, size, dtype=...) passed two positional ints.
    # Also generate unconditionally so non-"cuda" devices do not hit a
    # NameError at the QR call below.
    random_matrix = paddle.randn([size, size], dtype="float32")
    if device == "cuda":
        random_matrix = random_matrix.to("gpu")
    q, r = paddle.linalg.qr(random_matrix)
    # Multiply columns of q by sign(diag(r)) so the factorization is unique.
    q *= paddle.sign(paddle.diag(r)).unsqueeze(0)
    return q
|
||||
|
||||
|
||||
def is_pow2(n):
    """Return True iff n is a positive power of two."""
    if n <= 0:
        return False
    # A power of two has exactly one set bit, so n & (n - 1) clears it to 0.
    return n & (n - 1) == 0
|
||||
|
||||
|
||||
def get_hadK(n, transpose=False):
    """
    Look up the Hadamard factor for transform dimension n.

    Only power-of-two sizes are supported, for which the plain butterfly
    construction suffices and no extra factor matrix is required.

    Args:
        n (int): transform dimension; must be a power of two.
        transpose (bool): unused; kept for interface compatibility.

    Returns:
        tuple: (hadK, K) where hadK is None (no factor matrix) and K is 1.
    """
    assert is_pow2(n)  # only power-of-two sizes are handled
    return None, 1
|
||||
|
||||
|
||||
def matmul_hadU_int4(X, transpose=False):
    """
    Apply a normalized Hadamard transform along the last axis of X.

    Uses the butterfly recursion: at each step adjacent pairs (a, b) are
    combined into (a + b, a - b) until the remaining block size reaches K.
    The result is scaled by 1/sqrt(n) so the transform is orthonormal.

    Args:
        X (paddle.Tensor): input whose last dimension n must be a power of two
            (enforced via get_hadK -> is_pow2).
        transpose (bool): forwarded to get_hadK; currently has no effect there.

    Returns:
        paddle.Tensor: transformed tensor with the same shape as X (float32 scaling).
    """
    n = X.shape[-1]
    hadK, K = get_hadK(n, transpose)
    # Work on a flattened copy shaped [batch, n, 1].
    input = X.clone().reshape((-1, n, 1))
    output = input.clone()
    while input.shape[1] > K:
        # Pair up adjacent elements along axis 1: [batch, m, 2, width].
        input = input.reshape((input.shape[0], input.shape[1] // 2, 2, input.shape[2]))
        output = output.reshape(input.shape)
        # Butterfly step: (a, b) -> (a + b, a - b).
        output[:, :, 0, :] = input[:, :, 0, :] + input[:, :, 1, :]
        output[:, :, 1, :] = input[:, :, 0, :] - input[:, :, 1, :]
        output = output.reshape((input.shape[0], input.shape[1], -1))
        # Ping-pong the buffers so the next iteration reads the fresh result.
        (input, output) = (output, input)
    del output

    if K > 1:
        # Would apply a dense K x K factor for non-power-of-two sizes; with
        # the current get_hadK, K == 1 and this branch never executes.
        input = hadK.reshape((1, K, K)).to(input) @ input

    # Normalize by sqrt(n) so the transform is orthonormal.
    return input.reshape(X.shape) / paddle.to_tensor(n, dtype="float32").sqrt()
|
||||
|
||||
|
||||
def random_hadamard_matrix_int4(size, device=None, ffn2=False):
    """
    Build a (randomized) Hadamard rotation matrix for int4 quantization.

    Args:
        size (int): matrix dimension.
        device: unused here; kept for interface compatibility with callers.
        ffn2 (bool): when True, build a block-diagonal Hadamard matrix whose
            block size is the largest power-of-two divisor of size; when
            False, build a full size x size Hadamard matrix.

    Returns:
        tuple: (matrix, block_size) — block_size is None in the non-ffn2 case.
    """
    # See https://cornell-relaxml.github.io/quip-sharp/ , Section "Randomized Hadamard Transformation"
    if not ffn2:
        # NOTE(review): this randint result is immediately overwritten by the
        # ones_like below, so the "random" sign flip is effectively disabled
        # (Q becomes all ones -> all +1 after Q * 2 - 1). The call still
        # advances the RNG state, so it is kept to preserve behavior.
        Q = paddle.randint(low=0, high=2, shape=(size,)).cast("float32")
        Q = paddle.ones_like(Q, dtype="float32")
        Q = Q * 2 - 1
        # Diagonal sign matrix, then Hadamard-transform its columns.
        Q = paddle.diag(Q)
        return matmul_hadU_int4(Q), None

    else:
        # Strip factors of two: block_size becomes the largest power-of-two
        # divisor of size, num_blocks the remaining odd factor.
        num_blocks = size
        while not (num_blocks % 2):
            num_blocks = num_blocks // 2
        block_size = size // num_blocks
        # One Hadamard block of shape [block_size, block_size].
        Q = paddle.diag(paddle.ones((block_size,), dtype="float32"))
        block = matmul_hadU_int4(Q)
        large_matrix = paddle.zeros([size, size])

        # Tile the block down the diagonal to form a block-diagonal matrix.
        for i in range(num_blocks):
            start_row = i * block_size
            start_col = i * block_size
            large_matrix[start_row : start_row + block_size, start_col : start_col + block_size] = block
        return large_matrix.cast("float32"), block_size
|
||||
|
||||
|
||||
def get_orthogonal_matrix(size, mode="hadamard", device="cuda"):
    """
    Build a size x size rotation matrix using the requested construction.

    Args:
        size (int): matrix dimension.
        mode (str): "random" (QR-based orthogonal), "hadamard" (randomized
            Hadamard), or "hadamard_ffn2" (block-diagonal Hadamard).
        device (str): forwarded to the underlying constructor.

    Returns:
        The matrix for "random"; a (matrix, block_size) tuple for the two
        hadamard modes.

    Raises:
        ValueError: if mode is not one of the supported names.
    """
    if mode == "random":
        return random_orthogonal_matrix(size, device)
    if mode == "hadamard":
        return random_hadamard_matrix_int4(size, device)
    if mode == "hadamard_ffn2":
        return random_hadamard_matrix_int4(size, device, True)
    raise ValueError(f"Unknown mode {mode}")
|
||||
|
||||
|
||||
def rotate_model(state_dict, layer_idx, moe_num_experts=48, hidden_size=7168, moe_intermediate_size=3584, ep_rank=0):
    """
    Rotate the down_proj (FFN2) weights of one MoE layer with a Hadamard matrix.

    Every expert's down_proj weight for layer `layer_idx` is replaced in
    `state_dict` by Q_ffn2.T @ weight (stored back on CPU), where Q_ffn2 is a
    block-diagonal randomized Hadamard matrix of size moe_intermediate_size.

    Args:
        state_dict (dict): checkpoint tensors keyed by parameter name; mutated in place.
        layer_idx (int): index of the decoder layer whose experts are rotated.
        moe_num_experts (int): number of experts handled per EP rank.
        hidden_size (int): output width of down_proj.
        moe_intermediate_size (int): input width of down_proj (rotation size).
        ep_rank (int): expert-parallel rank; offsets the global expert index.

    Returns:
        paddle.Tensor: the Hadamard rotation matrix Q_ffn2, moved to CPU.
    """
    with paddle.no_grad():
        # collect hadamard rotation matrix [moe_intermediate_size, moe_intermediate_size]
        Q_ffn2, moe_block_size = get_orthogonal_matrix(size=moe_intermediate_size, mode="hadamard_ffn2")
        # down_proj.weight: [moe_intermediate_size, hidden_size]
        expert_list = [
            get_tensor(
                state_dict[
                    f"ernie.layers.{layer_idx}.mlp.experts.{ep_rank * moe_num_experts + expert_idx}.down_proj.weight"
                ]
            )
            for expert_idx in range(moe_num_experts)
        ]
        # Concatenate along the output axis so one matmul rotates every expert.
        moe_weight = paddle.concat(expert_list, axis=-1)  # [moe_intermediate_size, hidden_size * moe_num_experts]
        new_moe_weight = Q_ffn2.cast("float32").T @ moe_weight.to(Q_ffn2.place)
        for expert_idx in range(moe_num_experts):
            # Slice this expert's rotated columns back out of the fused result.
            rotated_weight = new_moe_weight[:, expert_idx * hidden_size : (expert_idx + 1) * hidden_size]
            # NOTE(review): despite the "_local" suffix, this is the globally
            # offset expert index (ep_rank * moe_num_experts + expert_idx),
            # matching the keys read above.
            expert_idx_local = ep_rank * moe_num_experts + expert_idx
            state_dict[f"ernie.layers.{layer_idx}.mlp.experts.{expert_idx_local}.down_proj.weight"] = (
                rotated_weight.cpu()
            )
        # Drop the large intermediates before returning to reduce GPU memory.
        del moe_weight, new_moe_weight, rotated_weight
        paddle.device.cuda.empty_cache()
        return Q_ffn2.cpu()
|
||||
|
||||
|
||||
def pack(src, bits=4):
    """
    Pack groups of low-bit integer values into int8 bytes.

    Consecutive groups of (8 // bits) entries along the last axis are
    combined into one byte: each entry is shifted into its bit slot and the
    slots are summed. The packed result is transposed and cast to int8.

    Args:
        src: integer array-like, 1-D or 2-D; its last dimension must be
            divisible by 8 // bits.
        bits (int): bits per value (default 4, i.e. two values per byte).

    Returns:
        paddle.Tensor: packed int8 tensor, transposed relative to src.
    """
    pack_num = 8 // bits
    # Bit offset of each slot within the packed byte: [0, bits, 2*bits, ...].
    shift_bits = (paddle.arange(0, pack_num) * bits).cast("uint8")
    src = paddle.to_tensor(src).cast("uint8")

    if len(src.shape) == 2:
        row, col = src.shape
        # Group pack_num consecutive columns into one packed byte.
        src = src.reshape((row, col // pack_num, pack_num))
    else:
        src = src.reshape((src.shape[0] // pack_num, pack_num))

    # Mask slot 0 to its low nibble (0xF) so high bits from the uint8 cast of
    # negative values cannot leak into neighboring slots after shifting.
    # NOTE(review): only slot 0 is masked and the 15 mask assumes bits == 4;
    # confirm higher slots rely on shifted-out-of-byte overflow for masking.
    src[..., 0] = paddle.bitwise_and(src[..., 0], paddle.to_tensor(15, dtype="uint8"))
    # Shifting is done through numpy — presumably to sidestep paddle uint8
    # shift-operator limitations; verify before changing.
    src = paddle.to_tensor(src.numpy() << shift_bits.numpy())

    # Sum the shifted slots into one byte, transpose, reinterpret as int8.
    return src.sum(axis=-1).transpose((1, 0)).cast("int8")
|
||||
|
||||
|
||||
def group_wise_int4_weight_quantize(weight: paddle.Tensor, group_size: int = 128):
    """
    Block-wise int4 weight quantization.

    The weight is transposed, split into groups of `group_size` along the
    input dimension, and each group is scaled by its absolute maximum so
    that values map into the signed int4 range [-8, 7].

    Args
        weight: paddle.Tensor, 2-D weight matrix (bfloat16 is upcast to float32)
        group_size: int, number of input elements sharing one scale

    Returns
        weight_quant: paddle.Tensor, int8 weight after quantization (transposed back)
        weight_scale: paddle.Tensor, fp32 per-group absolute-max scale
    """
    Q_MAX, Q_MIN = 7, -8  # signed int4 range

    if weight.dtype == paddle.bfloat16:
        weight = weight.astype(paddle.float32)
    assert weight.dim() == 2
    # Quantize along the input dimension: [in, out] -> [out, in].
    w_t = weight.transpose((1, 0))
    out_features, in_features = w_t.shape

    assert (
        in_features % group_size == 0
    ), f"in_features must be divisible by group_size: {group_size}, but got in_features: {in_features}"
    # [out_features, in_features] -> [out_features, in_features // group_size, group_size]
    grouped = w_t.reshape((out_features, in_features // group_size, group_size))

    # Per-group scale: absolute maximum, floored to avoid division by zero.
    group_abs_max = paddle.max(paddle.abs(grouped), axis=-1, keepdim=False).astype(paddle.float32)
    weight_scale = paddle.clip(group_abs_max, min=1e-8)

    # Map each group into the int4 range and round to the nearest level.
    quantized = paddle.round(grouped / weight_scale.unsqueeze(-1) * Q_MAX)
    quantized = paddle.clip(quantized, min=Q_MIN, max=Q_MAX)
    quantized = quantized.reshape((out_features, in_features)).transpose((1, 0))

    return quantized.astype(paddle.int8), weight_scale
|
||||
|
||||
|
||||
def per_block_cast_to_fp8(x: Tensor, block_size: list = [128, 128]) -> Tuple[Tensor, Tensor]:
|
||||
"""
|
||||
Only used in deep_gemm block wise quant weight.
|
||||
|
||||
Reference in New Issue
Block a user