Files
FastDeploy/fastdeploy/model_executor/layers/utils.py
Sunny-bot1 3629db4129 [Quantization] Support w4afp8 MoE dynamic quantization (#5282)
* support dynamic activation quant for w4afp8

* support dynamic w4afp8

* add test

* fix

* fix

---------

Co-authored-by: zhoutianzi666 <17801055074@163.com>
2025-12-02 18:56:16 +08:00

575 lines
19 KiB
Python

"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
import functools
from typing import Tuple, Union
import numpy as np
import paddle
from paddle import Tensor, nn
from paddle.framework import in_dynamic_mode
from scipy.linalg import block_diag
from fastdeploy.platforms import current_platform
# Import the custom GPU padding-offset operators only when the current
# platform reports CUDA and reports itself as available.
if current_platform.is_cuda() and current_platform.available():
    try:
        from fastdeploy.model_executor.ops.gpu import (
            get_padding_offset,
            speculate_get_padding_offset,
        )
    except Exception as import_error:
        # Chain the original failure explicitly so the real root cause stays
        # attached to the ImportError as its __cause__.
        raise ImportError(
            "Verify environment consistency between compilation and FastDeploy installation. "
            "And ensure the Paddle version supports FastDeploy's custom operators"
        ) from import_error
from fastdeploy import envs
# Value of the FD_CACHE_PARAMS setting exposed by `envs`
# (presumably a file path — TODO confirm against the envs definition).
cache_params = envs.FD_CACHE_PARAMS
# The literal string "none" skips loading; any other value is handed to
# paddle.load. NOTE: c8_state_dict is only defined when loading happens.
if cache_params != "none":
    c8_state_dict = paddle.load(cache_params, return_numpy=True)
# Default multiple that vocabulary sizes are rounded up to.
DEFAULT_VOCAB_PADDING_SIZE = 64


def pad_vocab_size(vocab_size: int, pad_to: int = DEFAULT_VOCAB_PADDING_SIZE) -> int:
    """
    Round the vocab size up to the next multiple of ``pad_to``.

    Args:
        vocab_size (int): The vocab size to pad.
        pad_to (int): The multiple to round up to.
    Returns:
        int: The smallest multiple of ``pad_to`` that is >= ``vocab_size``
            (for positive ``pad_to``).
    """
    num_chunks = (vocab_size + pad_to - 1) // pad_to
    return num_chunks * pad_to
def random_orthogonal_matrix(size, device):
    """
    Build a random orthogonal matrix via QR decomposition.

    A matrix of standard-normal samples is factorised with QR, and the
    columns of Q are then multiplied by the signs of R's diagonal.

    Args:
        size (int): The size of the matrix (size x size).
        device (str): Only the value "cuda" triggers creation of the random
            matrix (which is then moved to "gpu").
    Returns:
        paddle.Tensor: An orthogonal matrix of the specified size.
    """
    paddle.device.cuda.empty_cache()
    if device == "cuda":
        gaussian = paddle.randn(size, size, dtype="float32").to("gpu")
    # NOTE(review): `gaussian` is only assigned for device == "cuda"; confirm
    # that callers never pass any other device value.
    orthogonal, upper = paddle.linalg.qr(gaussian)
    diag_signs = paddle.sign(paddle.diag(upper)).unsqueeze(0)
    orthogonal *= diag_signs
    return orthogonal
def is_pow2(n):
    """
    Tell whether ``n`` is a positive power of two.

    Args:
        n (int): The integer to test.
    Returns:
        bool: True when ``n`` is positive and has exactly one bit set.
    """
    # A power of two shares no set bits with its predecessor.
    no_shared_bits = (n & (n - 1)) == 0
    return no_shared_bits and (n > 0)
def get_hadK(n, transpose=False):
    """
    Return the (hadK, K) pair used by matmul_hadU_int4.

    Only power-of-two sizes are accepted, for which no extra Hadamard block
    is supplied: the result is always ``(None, 1)``.

    Args:
        n (int): The transform size; must satisfy is_pow2.
        transpose (bool): Accepted but not used here.
    Returns:
        tuple: ``(None, 1)``.
    Raises:
        AssertionError: If ``n`` is not a power of two.
    """
    assert is_pow2(n)
    return None, 1
def matmul_hadU_int4(X, transpose=False):
    """
    Apply a Hadamard-style transform along the last axis of X, scaled by 1/sqrt(n).

    The last axis (length n) is repeatedly regrouped into pairs, and each pair
    is replaced by its sum and its difference until the working axis has
    shrunk to K (as returned by get_hadK).

    Args:
        X (Tensor): The tensor to transform; its last axis length must be
            accepted by get_hadK.
        transpose (bool): Forwarded to get_hadK.
    Returns:
        Tensor: The transformed tensor, reshaped back to X's shape and divided
            by sqrt(n).
    """
    n = X.shape[-1]
    hadK, K = get_hadK(n, transpose)
    current = X.clone().reshape((-1, n, 1))
    scratch = current.clone()
    while current.shape[1] > K:
        # Regroup the working axis into (half, 2) so paired entries sit on axis 2.
        current = current.reshape((current.shape[0], current.shape[1] // 2, 2, current.shape[2]))
        scratch = scratch.reshape(current.shape)
        scratch[:, :, 0, :] = current[:, :, 0, :] + current[:, :, 1, :]
        scratch[:, :, 1, :] = current[:, :, 0, :] - current[:, :, 1, :]
        scratch = scratch.reshape((current.shape[0], current.shape[1], -1))
        # The freshly written buffer becomes the input of the next round.
        current, scratch = scratch, current
    del scratch
    if K > 1:
        current = hadK.reshape((1, K, K)).to(current) @ current
    normalizer = paddle.to_tensor(n, dtype="float32").sqrt()
    return current.reshape(X.shape) / normalizer
def random_hadamard_matrix_int4(size, device=None, ffn2=False):
    """
    Build a Hadamard rotation matrix of shape (size, size).

    Args:
        size (int): Side length of the matrix.
        device: Accepted but not used here.
        ffn2 (bool): When False, the transform of a size x size diagonal matrix
            is returned. When True, a block-diagonal matrix is built whose
            blocks have the largest power-of-two side that divides ``size``.
    Returns:
        tuple: ``(matrix, None)`` when ``ffn2`` is False, otherwise
            ``(matrix, block_size)``.
    """
    # See https://cornell-relaxml.github.io/quip-sharp/ , Section "Randomized Hadamard Transformation"
    if not ffn2:
        diag_entries = paddle.randint(low=0, high=2, shape=(size,)).cast("float32")
        # The random draw is replaced by ones right away, so after the
        # "* 2 - 1" mapping every diagonal entry is 1.
        diag_entries = paddle.ones_like(diag_entries, dtype="float32")
        diag_entries = diag_entries * 2 - 1
        return matmul_hadU_int4(paddle.diag(diag_entries)), None

    # Strip every factor of two from size; what is left is the block count.
    num_blocks = size
    while num_blocks % 2 == 0:
        num_blocks = num_blocks // 2
    block_size = size // num_blocks

    identity = paddle.diag(paddle.ones((block_size,), dtype="float32"))
    block = matmul_hadU_int4(identity)

    large_matrix = paddle.zeros([size, size])
    for block_idx in range(num_blocks):
        begin = block_idx * block_size
        end = begin + block_size
        # Copy the same block onto each diagonal position.
        large_matrix[begin:end, begin:end] = block
    return large_matrix.cast("float32"), block_size
def get_orthogonal_matrix(size, mode="hadamard", device="cuda"):
    """
    Dispatch to the rotation-matrix builder selected by ``mode``.

    Args:
        size (int): Side length of the requested matrix.
        mode (str): One of "random", "hadamard" or "hadamard_ffn2".
        device (str): Forwarded to the selected builder.
    Returns:
        Whatever the selected builder returns: random_orthogonal_matrix for
        "random", random_hadamard_matrix_int4 for the two hadamard modes.
    Raises:
        ValueError: If ``mode`` is not one of the supported values.
    """
    if mode == "random":
        return random_orthogonal_matrix(size, device)
    if mode == "hadamard":
        return random_hadamard_matrix_int4(size, device)
    if mode == "hadamard_ffn2":
        return random_hadamard_matrix_int4(size, device, True)
    raise ValueError(f"Unknown mode {mode}")
def rotate_model(state_dict, layer_idx, moe_num_experts=48, hidden_size=7168, moe_intermediate_size=3584, ep_rank=0):
    """
    Rotate the down_proj weights of one layer's experts in place inside ``state_dict``.

    The experts' down_proj weights are concatenated along the last axis,
    left-multiplied by the transposed "hadamard_ffn2" matrix, then sliced back
    per expert and written to ``state_dict`` as CPU tensors.

    Args:
        state_dict (dict): Weights keyed by parameter name; updated in place.
        layer_idx (int): Index of the layer whose experts are rotated.
        moe_num_experts (int): Number of experts handled by this call.
        hidden_size (int): Column width used to slice each expert back out.
        moe_intermediate_size (int): Size of the rotation matrix.
        ep_rank (int): Offsets expert indices by ``ep_rank * moe_num_experts``.
    Returns:
        paddle.Tensor: The rotation matrix, moved to CPU.
    """

    def down_proj_key(expert_idx):
        # Name of the down_proj weight for the expert at this rank-local position.
        return f"ernie.layers.{layer_idx}.mlp.experts.{ep_rank * moe_num_experts + expert_idx}.down_proj.weight"

    with paddle.no_grad():
        # collect hadamard rotation matrix [moe_intermediate_size, moe_intermediate_size]
        Q_ffn2, _block_size = get_orthogonal_matrix(size=moe_intermediate_size, mode="hadamard_ffn2")
        # down_proj.weight: [moe_intermediate_size, hidden_size]
        per_expert = [get_tensor(state_dict[down_proj_key(idx)]) for idx in range(moe_num_experts)]
        # [moe_intermediate_size, hidden_size * moe_num_experts]
        stacked = paddle.concat(per_expert, axis=-1)
        rotated = Q_ffn2.cast("float32").T @ stacked.to(Q_ffn2.place)
        for idx in range(moe_num_experts):
            col_begin = idx * hidden_size
            expert_slice = rotated[:, col_begin : col_begin + hidden_size]
            state_dict[down_proj_key(idx)] = expert_slice.cpu()
        # Drop the large temporaries before asking the allocator to release memory.
        del stacked, rotated, expert_slice
        paddle.device.cuda.empty_cache()
        return Q_ffn2.cpu()
def pack(src, bits=4):
    """
    Pack groups of ``8 // bits`` values into single bytes.

    Args:
        src: Data convertible with paddle.to_tensor; it is cast to uint8.
        bits (int): Number of bits each value occupies inside a packed byte.
    Returns:
        paddle.Tensor: int8 tensor holding the packed bytes, with its two axes
            transposed.
    """
    pack_num = 8 // bits
    # Bit offset of each slot inside a packed byte: 0, bits, 2 * bits, ...
    shift_bits = (paddle.arange(0, pack_num) * bits).cast("uint8")
    values = paddle.to_tensor(src).cast("uint8")
    if len(values.shape) == 2:
        row, col = values.shape
        grouped = values.reshape((row, col // pack_num, pack_num))
    else:
        grouped = values.reshape((values.shape[0] // pack_num, pack_num))
    # Keep only the low four bits of the first slot so it cannot spill upward.
    low_bits_mask = paddle.to_tensor(15, dtype="uint8")
    grouped[..., 0] = paddle.bitwise_and(grouped[..., 0], low_bits_mask)
    shifted = paddle.to_tensor(grouped.numpy() << shift_bits.numpy())
    # NOTE(review): the final two-axis transpose presumably expects 2-D input — confirm.
    return shifted.sum(axis=-1).transpose((1, 0)).cast("int8")
def group_wise_int4_weight_quantize(weight: paddle.Tensor, group_size: int = 128):
    """
    Quantize a 2-D weight to the int4 value range, one scale per group.

    The weight is transposed, its second axis is split into groups of
    ``group_size`` elements, and every group is scaled by its own absolute
    maximum before rounding and clipping to [-8, 7].

    Args:
        weight (paddle.Tensor): 2-D weight; bfloat16 input is cast to float32.
        group_size (int): Number of consecutive elements sharing one scale.
    Returns:
        tuple:
            - paddle.Tensor: int8 tensor with values in [-8, 7], in the same
              layout as the input weight.
            - paddle.Tensor: float32 per-group absolute maxima (clipped to at
              least 1e-8), one row per row of the transposed weight.
    Raises:
        AssertionError: If the weight is not 2-D, or the grouped axis is not
            divisible by ``group_size``.
    """
    if weight.dtype == paddle.bfloat16:
        weight = weight.astype(paddle.float32)
    assert weight.dim() == 2

    transposed = weight.transpose((1, 0))
    out_features, in_features = transposed.shape
    q_min, q_max = -8, 7

    assert (
        in_features % group_size == 0
    ), f"in_features must be divisible by group_size: {group_size}, but got in_features: {in_features}"
    # [out_features, in_features] -> [out_features, in_features // group_size, group_size]
    grouped = transposed.reshape((out_features, in_features // group_size, group_size))

    # One scale per group: its absolute maximum, kept away from zero.
    group_abs_max = paddle.max(paddle.abs(grouped), axis=-1, keepdim=False).astype(paddle.float32)
    weight_scale = paddle.clip(group_abs_max, min=1e-8)

    rounded = paddle.round(grouped / weight_scale.unsqueeze(-1) * q_max)
    clipped = paddle.clip(rounded, min=q_min, max=q_max)
    # Undo the grouping and the transpose to restore the input layout.
    restored = clipped.reshape((out_features, in_features)).transpose((1, 0))
    return restored.astype(paddle.int8), weight_scale
def per_block_cast_to_fp8(x: Tensor, block_size: list = [128, 128]) -> Tuple[Tensor, Tensor]:
    """
    Quantize a 2-D tensor to float8_e4m3fn with one scale per block.

    Only used in deep_gemm block wise quant weight.
    copy from FastDeploy/custom_ops/gpu_ops/fp8_deep_gemm/tests/test_core.py.

    Args:
        x (Tensor): 2-D tensor to quantize.
        block_size (list): Block height and width; the first two entries are used.
    Returns:
        tuple:
            - Tensor: float8_e4m3fn tensor with the same shape as ``x``.
            - Tensor: float32 scales (block amax / 448), one per block.
    Raises:
        AssertionError: If ``x`` is not 2-D.
    """
    from fastdeploy.model_executor.ops.gpu.deep_gemm import ceil_div

    assert x.dim() == 2
    m, n = x.shape
    block_m, block_n = block_size[0], block_size[1]

    # Zero-pad both axes up to whole blocks so the block view below is exact.
    padded_shape = (
        ceil_div(m, block_m) * block_m,
        ceil_div(n, block_n) * block_n,
    )
    x_padded = paddle.zeros(padded_shape, dtype=x.dtype)
    x_padded[:m, :n] = x

    # (row_blocks, block_m, col_blocks, block_n)
    x_view = paddle.view(
        x_padded,
        (-1, block_m, x_padded.shape[1] // block_n, block_n),
    )
    block_amax = paddle.amax(paddle.abs(x_view).astype(paddle.float32), axis=(1, 3), keepdim=True)
    block_amax = paddle.clip(block_amax, min=1e-4)

    # Scale each block so its absolute maximum maps to 448 before casting.
    x_scaled = (x_view * (448.0 / block_amax)).astype(paddle.float8_e4m3fn)
    quantized = x_scaled.view_as(x_padded)[:m, :n].contiguous()
    scales = paddle.view(block_amax / 448.0, (x_view.shape[0], x_view.shape[2]))
    return quantized, scales
def per_token_cast_to_fp8(x: Tensor) -> Tuple[Tensor, Tensor]:
    """
    Quantize ``x`` to float8_e4m3fn with one scale per slice along the last axis.

    Args:
        x (Tensor): The tensor to quantize.
    Returns:
        tuple:
            - Tensor: ``x`` divided by its scale, clipped to [-448, 448] and
              cast to float8_e4m3fn.
            - Tensor: float32 scales (last-axis absolute maximum, at least
              1e-4, divided by 448), with the last axis kept as size 1.
    """
    abs_max = paddle.abs(x).astype(paddle.float32).max(axis=-1, keepdim=True).clip_(min=1e-4)
    scale = abs_max / 448.0
    quantized = paddle.clip(x / scale, -448.0, 448.0).astype(paddle.float8_e4m3fn)
    return quantized, scale
# for distributed tensor model parallel
def _set_var_distributed(var: Tensor, split_axis: int):
    """
    Mark a variable as distributed and record the axis it is sharded along.

    Nothing happens when ``var`` is None. Outside dynamic mode, the variable of
    the same name in the current block of both the default startup program and
    the default main program is flagged as distributed as well.

    Args:
        var (Tensor): The variable to mark; may be None.
        split_axis (int): the sharding dimension of dist tensors.
    Returns:
        None. No return value.
    """
    if var is None:
        return

    var.is_distributed = True
    var.split_axis = split_axis

    if in_dynamic_mode():
        return

    # NOTE: use current_block and find_var_recursive to support while_loop
    program_blocks = (
        paddle.static.default_startup_program().current_block(),
        paddle.static.default_main_program().current_block(),
    )
    for block in program_blocks:
        block._find_var_recursive(var.name).is_distributed = True
def get_tensor(input: Union[paddle.Tensor, np.ndarray, str], model_path=None) -> paddle.Tensor:
    """
    Convert the input into a PaddlePaddle tensor according to its type.

    Args:
        input (Union[paddle.Tensor, np.ndarray, str]): The input data. Objects
            whose type name contains "PySafeSlice" are first resolved through
            their ``get()`` method.
        model_path: Forwarded to load_reordered_experts when ``input`` is a str.
    Returns:
        paddle.Tensor: A CPU tensor is moved to the current device, an ndarray
            is wrapped with paddle.to_tensor, and a str is loaded through
            load_reordered_experts. Anything else is returned unchanged.
    """
    if "PySafeSlice" in str(type(input)):
        input = input.get()

    if isinstance(input, paddle.Tensor):
        # Tensors that are not on the CPU are passed through untouched.
        if not input.place.is_cpu_place():
            return input
        if current_platform.is_cuda():
            return input.cuda()
        return input.to(paddle.device.get_device())

    if isinstance(input, np.ndarray):
        return paddle.to_tensor(input)

    if isinstance(input, str):
        from fastdeploy.model_executor.load_weight_utils import load_reordered_experts

        return load_reordered_experts(model_path, input)

    return input
def matmul_hadU(X: Tensor) -> paddle.Tensor:
    """
    Apply an unnormalized Hadamard-style transform along the last axis of X.

    The last axis is repeatedly regrouped into pairs, and each pair is
    replaced by its sum and its difference until the working axis has
    length 1.

    Args:
        X (Tensor): The tensor to be transformed.
    Returns:
        Tensor: The transformed tensor, with the same shape as the input tensor X.
    """
    current = X.clone().reshape((-1, X.shape[-1], 1))
    scratch = current.clone()
    while current.shape[1] > 1:
        # Regroup the working axis into (half, 2) so paired entries sit on axis 2.
        current = current.reshape((current.shape[0], current.shape[1] // 2, 2, current.shape[2]))
        scratch = scratch.reshape(current.shape)
        scratch[:, :, 0, :] = current[:, :, 0, :] + current[:, :, 1, :]
        scratch[:, :, 1, :] = current[:, :, 0, :] - current[:, :, 1, :]
        scratch = scratch.reshape((current.shape[0], current.shape[1], -1))
        # The freshly written buffer becomes the input of the next round.
        current, scratch = scratch, current
    del scratch
    return current.reshape(X.shape)
def random_hadamard_matrix(block_size: int, dtype: Union[paddle.dtype, str]) -> paddle.Tensor:
    """
    Build a Hadamard matrix by transforming a diagonal matrix of ones.

    Despite the name, no random values are drawn here: the result is
    matmul_hadU applied to a diagonal matrix whose entries are all one.

    Args:
        block_size (int): The number of rows and columns of the matrix.
        dtype (str): The data type, for example 'float32'.
    Returns:
        paddle.Tensor: The generated matrix.
    """
    identity = paddle.diag(paddle.ones(block_size, dtype=dtype))
    return matmul_hadU(identity)
def create_hadamard_matrix(hidden_size: int) -> paddle.Tensor:
    """
    Build a block-diagonal matrix made of identical 32 x 32 Hadamard blocks.

    Args:
        hidden_size (int): The size of the hidden layer. The number of blocks
            is ``hidden_size // 32``, so any remainder beyond a multiple of 32
            is not covered.
    Returns:
        paddle.Tensor: The generated block-diagonal matrix.
    """
    hadamard_block_size = 32
    block = random_hadamard_matrix(hadamard_block_size, "float32")
    block_num = hidden_size // hadamard_block_size
    # The same block is repeated on every diagonal position.
    diagonal_blocks = [block] * block_num
    return paddle.to_tensor(block_diag(*diagonal_blocks))
def ensure_divisibility(numerator, denominator):
    """
    Assert that the numerator is an exact multiple of the denominator.

    Args:
        numerator (int): The value being divided.
        denominator (int): The value it must be divisible by.
    Returns:
        None
    Raises:
        AssertionError: If the division leaves a remainder.
    """
    remainder = numerator % denominator
    assert remainder == 0, f"{numerator} is not divisible by {denominator}"
def divide(numerator: int, denominator: int):
    """
    Divide two integers after checking that the division is exact.

    Args:
        numerator (int): The dividend.
        denominator (int): The divisor.
    Returns:
        int: The quotient ``numerator // denominator``.
    Raises:
        AssertionError: Raised by ensure_divisibility when the division
            leaves a remainder.
    """
    ensure_divisibility(numerator, denominator)
    quotient = numerator // denominator
    return quotient
def remove_padding(
    max_len: paddle.Tensor,
    input_ids: paddle.Tensor,
    seq_lens_this_time: paddle.Tensor,
) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor]:
    """
    Strip padding from the input ids using the get_padding_offset operator.

    Args:
        max_len (paddle.Tensor): The maximum length of the input sequences.
        input_ids (paddle.Tensor): The IDs of the input sequences.
        seq_lens_this_time (paddle.Tensor): The actual length of each sequence.
    Returns:
        tuple: A tuple containing, in this order:
            - The sequence IDs with padding removed (paddle.Tensor).
            - The padding offsets (paddle.Tensor).
            - The cumulative offsets (paddle.Tensor).
            - The query sequence lengths (paddle.Tensor).
            - The key sequence lengths (paddle.Tensor).
    """
    if current_platform.is_cuda():
        # Running total of the padding (max_len minus actual length) per sequence.
        cum_offsets_now = paddle.cumsum(max_len - seq_lens_this_time, dtype="int32")
        token_num = paddle.sum(seq_lens_this_time)
        op_outputs = get_padding_offset(input_ids, cum_offsets_now, token_num, seq_lens_this_time)
        # The operator yields cum_offsets before padding_offset; the return
        # below swaps those two.
        ids_remove_padding, cum_offsets, padding_offset, cu_seqlens_q, cu_seqlens_k = op_outputs
    # NOTE(review): the names returned here are only bound on the CUDA branch;
    # confirm that callers only reach this function on CUDA.
    return ids_remove_padding, padding_offset, cum_offsets, cu_seqlens_q, cu_seqlens_k
def speculate_remove_padding(
    max_len: paddle.Tensor,
    input_ids: paddle.Tensor,
    seq_lens_this_time: paddle.Tensor,
    draft_tokens: paddle.Tensor,
    seq_lens_encoder: paddle.Tensor,
) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor]:
    """
    Strip padding from the input ids using the speculate_get_padding_offset operator.

    Args:
        max_len (paddle.Tensor): The maximum length of the sequences.
        input_ids (paddle.Tensor): The IDs of the input sequences.
        seq_lens_this_time (paddle.Tensor): The lengths of the sequences in the current batch.
        draft_tokens (paddle.Tensor): The draft tokens.
        seq_lens_encoder (paddle.Tensor): The lengths of the encoder sequences.
    Returns:
        tuple: A tuple containing, in this order:
            - The input sequence IDs with padding removed (paddle.Tensor).
            - Padding offsets (paddle.Tensor).
            - Cumulative offsets (paddle.Tensor).
            - Query sequence lengths (paddle.Tensor).
            - Key sequence lengths (paddle.Tensor).
    """
    if current_platform.is_cuda():
        # Running total of the padding (max_len minus actual length) per sequence.
        cum_offsets_now = paddle.cumsum(max_len - seq_lens_this_time, dtype="int32")
        token_num = paddle.sum(seq_lens_this_time)
        op_outputs = speculate_get_padding_offset(
            input_ids,
            draft_tokens,
            cum_offsets_now,
            token_num,
            seq_lens_this_time,
            seq_lens_encoder,
        )
        # The operator yields cum_offsets before padding_offset; the return
        # below swaps those two.
        ids_remove_padding, cum_offsets, padding_offset, cu_seqlens_q, cu_seqlens_k = op_outputs
    # NOTE(review): the names returned here are only bound on the CUDA branch;
    # confirm that callers only reach this function on CUDA.
    return ids_remove_padding, padding_offset, cum_offsets, cu_seqlens_q, cu_seqlens_k
class CpuGuard:
    """Context manager that makes "cpu" the current Paddle device and restores the previous one on exit."""

    def __init__(self):
        """Nothing to set up; the previous device is captured on entry."""

    def __enter__(self):
        """Remember the currently selected device, then switch to "cpu"."""
        previous_device = paddle.device.get_device()
        self.ori_device = previous_device
        paddle.device.set_device("cpu")

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Switch back to the device that was selected on entry."""
        paddle.device.set_device(self.ori_device)
def create_and_set_parameter(layer: nn.Layer, name: str, tensor: paddle.Tensor):
    """
    Attach a new parameter to a layer and fill it with the given tensor.

    The parameter is created with the tensor's shape and dtype and a constant
    zero initializer, stored on the layer under ``name``, and then overwritten
    with the tensor's value.

    Args:
        layer (nn.Layer): The layer object to which the parameter will be added.
        name (str): The attribute name of the parameter to be created.
        tensor (paddle.Tensor): The tensor to set as the value of the parameter.
    Returns:
        None
    """
    zero_init = paddle.nn.initializer.Constant(0)
    new_param = layer.create_parameter(
        shape=tensor.shape,
        dtype=tensor.dtype,
        default_initializer=zero_init,
    )
    setattr(layer, name, new_param)
    # Read the attribute back from the layer before assigning the value.
    getattr(layer, name).set_value(tensor)
@functools.cache
def create_empty_tensor(shape: Tuple[int, ...], dtype: Union[paddle.dtype, str]) -> paddle.Tensor:
    """
    Create an empty tensor, cached per (shape, dtype) pair.

    Because the function is wrapped in functools.cache, repeated calls with
    the same arguments return the very same tensor object, and both arguments
    must be hashable.

    Args:
        shape (Tuple[int, ...]): A tuple representing the dimensions of the tensor.
        dtype (Union[paddle.dtype, str]): The data type for the tensor, such as 'bfloat16', 'float16', etc.
    Returns:
        paddle.Tensor: An empty tensor with the specified shape and data type.
    """
    dims = list(shape)
    return paddle.empty(dims, dtype=dtype)
def vocab_range_from_per_partition_vocab_size(per_partition_vocab_size: int, rank: int, offset: int = 0):
    """
    Compute the half-open vocab index range owned by one rank.

    Args:
        per_partition_vocab_size (int): Number of vocab entries per partition.
        rank (int): Index of the partition.
        offset (int): Constant added to both ends of the range.
    Returns:
        tuple: ``(first_index, last_index)``, where ``last_index`` is
            ``first_index + per_partition_vocab_size``.
    """
    first_index = rank * per_partition_vocab_size + offset
    return first_index, first_index + per_partition_vocab_size
def vocab_range_from_global_vocab_size(global_vocab_size: int, rank: int, world_size: int, offset: int = 0):
    """
    Compute the vocab index range owned by one rank from the global vocab size.

    Args:
        global_vocab_size (int): Total vocab size; must be divisible by ``world_size``.
        rank (int): Index of the partition.
        world_size (int): Number of partitions.
        offset (int): Constant added to both ends of the range.
    Returns:
        tuple: ``(first_index, last_index)`` as returned by
            vocab_range_from_per_partition_vocab_size.
    """
    partition_size = divide(global_vocab_size, world_size)
    return vocab_range_from_per_partition_vocab_size(partition_size, rank, offset=offset)