# FastDeploy/fastdeploy/inference_args.py
"""
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
import copy
import json
import os
import re
from enum import Enum
import numpy as np
import paddle
from paddlenlp.utils.log import logger
class GenerationPhase(Enum):
"""
The generation phase of the model.
"""
PREFILL = 1
DECODER = 2
class InferenceArgs:
"""
The parameters used for inference, including model parameters and quantization information.
"""
def __init__(
self,
quant_type,
num_layers,
num_attention_heads,
num_key_value_heads,
hidden_size,
ffn_hidden_size,
mp_rank,
mp_size,
model_path="",
use_fake_parameter=False,
fp8_type="e4m3fn",
quant_round_type=0,
quant_max_bound=0,
quant_min_bound=0,
has_zero_point=False,
is_channel_wise=False,
gqa_use_tensorcore=False,
use_dynamic_cachekv_quant=False,
max_position_embeddings=512,
speculate_method=None,
speculate_max_draft_token_num=1,
use_moe=False,
moe_num_experts=None,
moe_intermediate_size=None,
moe_use_gate_correction_bias=False,
moe_every2=False,
moe_topk=8,
moe_num_shared_experts=0,
moe_layer_start_index=0,
moe_use_ffn_shared_weight_and_bias=False,
moe_group=False,
moe_quant_type="default",
use_ep=False,
generation_phase=GenerationPhase.PREFILL,
use_micro_batch=False,
weight_block_size=[-1, -1],
start_layer_index=0,
scale_dir=None,
enable_redundant_experts: bool = False,
redundant_experts_num: int = 0,
use_offline_quant=False,
max_batch_size: int = 128,
):
"""
        Initialize the inference arguments of the Transformer model, including quantization settings.
Args:
            quant_type (str): Quantization type string. Options include 'weight_only_int8' ('wint8'),
                'weight_only_int4' ('wint4'), composite strings such as 'w4afp8c8' (see parser_quant_type),
                and 'default' for no quantization.
num_layers (int): Number of layers in the Transformer model.
num_attention_heads (int): Number of attention heads.
num_key_value_heads (int): Number of key-value heads. If less than or equal to 0,
it is equal to num_attention_heads.
hidden_size (int): Size of the hidden layer.
ffn_hidden_size (int): Size of the hidden layer in the feedforward neural network.
mp_rank (int): Rank of the current process in model parallelism.
mp_size (int): Size of model parallelism.
model_path (str, optional): Path to the model. Default is an empty string.
use_fake_parameter (bool, optional): Whether to use fake parameters. Default is False.
fp8_type (str, optional): Type of fp8. Options include 'e4m3fn', 'e5m2'. Default is 'e4m3fn'.
quant_round_type (int, optional): Rounding type for quantization. Default is 0.
quant_max_bound (float, optional): Maximum bound for quantization. Default is 0.
quant_min_bound (float, optional): Minimum bound for quantization. Default is 0.
            use_dynamic_cachekv_quant (bool, optional): Whether to use dynamic quantization for the KV cache.
                Default is False.
max_position_embeddings (int, optional): Maximum position embeddings. Default is 512.
Returns:
None
"""
self.quant_type = quant_type.lower()
self.scale_dir = scale_dir
if self.quant_type == "default":
self.quant_type = ""
self.moe_quant_type = moe_quant_type.lower()
if self.moe_quant_type == "default":
self.moe_quant_type = ""
self.weight_block_size = weight_block_size
self.use_offline_quant = use_offline_quant
self.ffn_hidden_size = ffn_hidden_size
self.mp_rank = mp_rank
if use_ep:
self.mp_size = 1
self.nranks = mp_size
else:
self.mp_size = mp_size
self.nranks = mp_size
self.use_ep = use_ep
self.generation_phase = generation_phase
self.use_micro_batch = use_micro_batch
self.num_layers = num_layers
self.start_layer_index = start_layer_index
self.hidden_size = hidden_size
self.head_dim = hidden_size // num_attention_heads
self.num_attention_heads = num_attention_heads
        self.num_key_value_heads = (
            num_key_value_heads if num_key_value_heads > 0 else self.num_attention_heads
        )
self.qkv_hidden_size = (self.num_attention_heads +
2 * self.num_key_value_heads) * self.head_dim
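        # Worked example (illustrative values, not from the original source): with
        # hidden_size=4096, 32 attention heads (head_dim=128) and 8 KV heads,
        # qkv_hidden_size = (32 + 2 * 8) * 128 = 6144.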
self.dim_feedforward = ffn_hidden_size
self.max_position_embeddings = max_position_embeddings
self.model_path = model_path
self.use_fake_parameter = use_fake_parameter
self.fp8_type = fp8_type
self.default_type = paddle.get_default_dtype()
(
self.weight_dtype,
self.act_dtype,
self.cachekv_dtype,
) = self.parser_quant_type(self.quant_type)
logger.info(
f"quant_type: weight[{self.weight_dtype}], act[{self.act_dtype}], cachekv[{self.cachekv_dtype}]"
)
self.enable_redundant_experts = enable_redundant_experts
self.redundant_experts_num = redundant_experts_num
self.max_batch_size = max_batch_size
class MoEConfig:
"""
            MoE configuration.
            Args:
                use_moe (bool): Whether the model has MoE layers.
                num_experts (int): Number of experts in each MoE layer.
                top_k (int): Number of experts each token is routed to.
                moe_intermediate_size (int): Input dim of the second linear layer in each expert.
                activation (str): Activation function used in the MoE layers.
"""
use_moe: bool = False
num_experts: int = -1
top_k: int = -1
moe_intermediate_size: int = -1
num_experts_per_rank: int = -1
num_experts_start_offset: int = -1
activation = "swiglu"
            moe_use_gate_correction_bias = False
            moe_every2 = False
            moe_topk = 8
            moe_num_shared_experts = 0
            moe_layer_start_index = 0
            moe_use_ffn_shared_weight_and_bias = False
            moe_group = False
moe_quant_type = self.moe_quant_type
num_max_dispatch_tokens_per_rank = 256
has_multimodality: bool = False
im_patch_id = (
100295 # multimodality, TODO(liuyuanle): read from config.json
)
self.moe_config = MoEConfig()
self.moe_config.use_moe = use_moe
if use_moe:
if isinstance(moe_num_experts, list):
self.moe_config.has_multimodality = True
self.moe_config.num_experts = moe_num_experts[0]
else:
self.moe_config.num_experts = moe_num_experts
self.moe_config.num_experts_per_rank = (
self.moe_config.num_experts + redundant_experts_num
) // self.nranks
self.moe_config.num_experts_start_offset = (
self.moe_config.num_experts_per_rank * self.mp_rank)
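            # Worked example (illustrative): 64 experts, 0 redundant experts and
            # nranks=8 give num_experts_per_rank=8; rank 3 then starts at expert 24
            # and owns experts [24, 32).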
if isinstance(moe_intermediate_size, list):
self.moe_config.moe_intermediate_size = moe_intermediate_size[
0]
else:
self.moe_config.moe_intermediate_size = moe_intermediate_size
self.moe_config.moe_every2 = moe_every2
self.moe_config.moe_num_shared_experts = moe_num_shared_experts
self.moe_config.moe_layer_start_index = moe_layer_start_index
self.moe_config.moe_use_ffn_shared_weight_and_bias = (
moe_use_ffn_shared_weight_and_bias)
self.moe_config.moe_group = moe_group
self.moe_config.top_k = moe_topk
self.moe_config.moe_use_gate_correction_bias = moe_use_gate_correction_bias
if isinstance(moe_num_experts, list):
# multimodality
self.moe_config_1 = copy.deepcopy(self.moe_config)
self.moe_config_1.num_experts = moe_num_experts[1]
self.moe_config_1.moe_intermediate_size = moe_intermediate_size[1]
        self.use_weight_only = self.weight_dtype != self.act_dtype
# arch (int): The compute arch for target device. For example, A100 is 80, v100 is 70,
# if you do not assign arch, we will get arch from your device, default: None.
self.weight_only_linear_arch = os.getenv(
"FLAGS_weight_only_linear_arch")
if self.weight_only_linear_arch is not None:
self.weight_only_linear_arch = int(self.weight_only_linear_arch)
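        # For example (illustrative): setting FLAGS_weight_only_linear_arch=80 targets
        # A100 (sm_80); leaving it unset lets the arch be detected from the device.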
self.use_append_attn = os.getenv("FLAGS_use_append_attn")
if self.use_append_attn is not None:
self.use_append_attn = int(self.use_append_attn) == 1
else:
self.use_append_attn = False
self.has_zero_point = has_zero_point
self.is_channel_wise = is_channel_wise
self.gqa_use_tensorcore = gqa_use_tensorcore
if self.gqa_use_tensorcore:
logger.warning("TensorCore Attention is not supported yet.")
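        # Map the cache KV dtype to a cache_quant_type tag and quantization bounds:
        # int8 -> "cache_int8"[+"_zp"] with bound 127, float8_e4m3fn -> "cache_fp8"
        # with bound 448, int4 -> "cache_int4"[+"_zp"] with bound 7, and
        # bfloat16/float16 -> "none" (no cache quantization).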
if self.cachekv_dtype == "int8":
self.cache_quant_type = "cache_int8"
if self.has_zero_point:
self.cache_quant_type += "_zp"
self.cache_quant_max_bound = 127.0
self.cache_quant_min_bound = -127.0
elif self.cachekv_dtype == "float8_e4m3fn":
self.cache_quant_type = "cache_fp8"
self.cache_quant_max_bound = 448.0
self.cache_quant_min_bound = -448.0
elif self.cachekv_dtype == "int4":
self.cache_quant_type = "cache_int4"
self.cache_quant_max_bound = 7.0
self.cache_quant_min_bound = -7.0
if self.has_zero_point:
self.cache_quant_type += "_zp"
elif self.cachekv_dtype in ["bfloat16", "float16"]:
self.cache_quant_type = "none"
else:
raise ValueError(f"Unsupported cachekv dtype {self.cachekv_dtype}")
self.quant_round_type = quant_round_type
self.quant_max_bound = quant_max_bound
self.quant_min_bound = quant_min_bound
self.use_dynamic_cachekv_quant = use_dynamic_cachekv_quant
self.speculate_method = speculate_method
self.speculate_max_draft_token_num = speculate_max_draft_token_num
# set_scales
if (self.act_dtype == "float8_e4m3fn"
): # 4 exponent bits, 3 mantissa bits, and supports finite numbers
self.quant_max_bound = 448.0
self.quant_min_bound = -448.0
self.quant_round_type = 1
elif self.act_dtype == "int8":
self.quant_max_bound = 127.0
self.quant_min_bound = -127.0
self.quant_round_type = 0
elif self.act_dtype == "int4":
self.quant_max_bound = 7.0
self.quant_min_bound = -7.0
self.quant_round_type = 0
self.weight_scale_dict = {}
self.act_scale_dict = {}
self.cachekv_scale_dict = {}
if not self.use_fake_parameter:
self.set_scales()
    # TODO(tangbinhan): Add a unit test for this function.
def parser_quant_type(self, quant_type):
"""
        Parse the quantization type string and return the quantization dtypes for weights,
        activations, and the KV cache.
        Args:
            quant_type (str): The quantization type string. It can be one of the following formats:
                - "weight_only_int8" or "wint8": only weights are quantized, to int8.
                - "weight_only_int4" or "wint4": only weights are quantized, to int4.
                - A composite string such as "w4afp8c8", where the prefixes 'w', 'a', and 'c' mark the
                  weight, activation, and KV-cache entries, and the token following each prefix gives
                  the bit width or dtype (e.g. '4', '8', 'fp8'). If a prefix is missing, the default
                  dtype is used for that entry.
        Returns:
            tuple: Three strings giving the quantization dtypes for weights, activations, and the KV cache.
                For "weight_only_int8"/"wint8" the result is ("int8", default_type, cache_type) and for
                "weight_only_int4"/"wint4" it is ("int4", default_type, cache_type), where cache_type is
                int8/int4/fp8 when the string also contains "c8"/"c4"/"cfp8" and the default dtype otherwise.
                Composite strings are parsed entry by entry as described above.
        Raises:
            AssertionError: If a composite quantization type string is malformed.
"""
cache_type = self.default_type
if "c8" in quant_type:
cache_type = "int8"
elif "cfp8" in quant_type:
cache_type = "fp8"
elif "c4" in quant_type:
cache_type = "int4"
if "weight_only_int8" in quant_type or "wint8" in quant_type:
return "int8", self.default_type, cache_type
elif "weight_only_int4" in quant_type or "wint4" in quant_type:
return "int4", self.default_type, cache_type
else:
# split quant type, eg. w4afp8c8 -> ['w', '4', 'a', 'fp8', 'c', '8']
pattern = f"({'|'.join(map(re.escape, ['w', 'a', 'c']))})"
splited_type = re.split(pattern, quant_type)
splited_type = [tmp_type for tmp_type in splited_type if tmp_type]
assert (len(splited_type) % 2 == 0 and len(splited_type)
<= 6), f"Quant type[{quant_type}] format error."
quant_type_list = []
if "w" in splited_type:
w_idx = splited_type.index("w")
quant_type_list.append(
self.get_quant_dtype(splited_type[w_idx + 1]))
else:
quant_type_list.append(self.default_type)
if "a" in splited_type:
a_idx = splited_type.index("a")
quant_type_list.append(
self.get_quant_dtype(splited_type[a_idx + 1]))
else:
quant_type_list.append(self.default_type)
if "c" in splited_type:
c_idx = splited_type.index("c")
quant_type_list.append(
self.get_quant_dtype(splited_type[c_idx + 1]))
else:
quant_type_list.append(self.default_type)
return quant_type_list[0], quant_type_list[1], quant_type_list[2]
def get_quant_dtype(self, quant_bit):
"""
Get the quantized data type based on the specified bit width.
Args:
            quant_bit (str): The bit width token for quantization.
                Supported values include "8" for int8, "4" for int4, "16" for the default dtype,
                "fp8" for float8 (the exact variant given by self.fp8_type),
                "fp16" for float16, "bf16" for bfloat16, and "fp32" for float32.
Returns:
str: The corresponding quantized data type.
Raises:
ValueError: If the specified quant_bit is not supported.
"""
if quant_bit == "8":
return "int8"
elif quant_bit == "4":
return "int4"
elif quant_bit == "16":
return self.default_type
elif quant_bit == "fp8":
return "float8_" + self.fp8_type
elif quant_bit == "fp16":
return "float16"
elif quant_bit == "bf16":
return "bfloat16"
elif quant_bit == "fp32":
return "float32"
else:
            raise ValueError(
                f"Unsupported quant_bit '{quant_bit}'; supported values are "
                "['8', '4', '16', 'fp8', 'fp16', 'bf16', 'fp32']")
def set_cache_scales(self):
"""
        Set scales for cache key-value quantization.
        This method loads cache KV scales and zero points from JSON files located in
        self.scale_dir. If the cache KV dtype is a floating-point type, nothing is loaded.
        Raises:
            NotImplementedError: If fake parameters are enabled (self.use_fake_parameter is True).
"""
if not self.use_fake_parameter:
# cachekv_scale
if self.cachekv_dtype in ["bfloat16", "float16", "float32"]:
return
from glob import glob
scale_dir = self.scale_dir
scale_paths = glob(os.path.join(scale_dir, "*.json*"))
cachekv_scale_dict_all = []
self.cachekv_scale_dict = {}
            for possible_cache_scales_file_name in scale_paths:
                with open(possible_cache_scales_file_name) as fi:
                    cachekv_scale_dict_all.append(json.load(fi))
            for cache_scale_dict in cachekv_scale_dict_all:
                for k, v in cache_scale_dict.items():
                    self.cachekv_scale_dict.setdefault(k, []).extend(v)
            logger.debug(f"cachekv_scale_dict: {self.cachekv_scale_dict}")
num_heads = self.num_attention_heads // self.mp_size
kv_num_heads = self.num_key_value_heads // self.mp_size
col_dim = (kv_num_heads *
self.head_dim if self.is_channel_wise else kv_num_heads)
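            # The loop below handles two kinds of entries: ".activation_quanter" keys
            # (cache KV scales) and ".zero_point" keys (cache KV zero points). When more
            # values than col_dim are present, one value is kept per KV head (every
            # num_heads // kv_num_heads attention heads, i.e. GQA grouping); scales are
            # stored as cache_quant_max_bound / scale (or 1 / scale for cache_int4_zp).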
for k, v in self.cachekv_scale_dict.items():
# cache_kv_scale
if k.endswith(".activation_quanter"):
if self.is_channel_wise:
v_array = (np.array(v).reshape(
-1, self.head_dim).astype(np.float32))
else:
v_array = np.array(v).reshape(-1).astype(np.float32)
if v_array.size > col_dim:
cache_scale = [
v_array[i].tolist()
for i in range(0, num_heads, num_heads //
kv_num_heads)
]
else:
cache_scale = [
v_array[i].tolist()
for i in range(0, kv_num_heads)
]
if (self.has_zero_point
and self.cachekv_dtype == "int4"): # cache_int4_zp
self.cachekv_scale_dict[k] = 1.0 / np.array(
cache_scale).flatten().astype(np.float32)
else:
self.cachekv_scale_dict[k] = (
self.cache_quant_max_bound /
np.array(cache_scale).flatten().astype(np.float32))
# cache_kv_zp
elif k.endswith(".zero_point"):
if self.is_channel_wise:
v_array = (np.array(v).reshape(
-1, self.head_dim).astype(np.float32))
else:
v_array = np.array(v).reshape(-1).astype(np.float32)
if v_array.size > col_dim:
cache_zp = [
v_array[i].tolist()
for i in range(0, num_heads, num_heads //
kv_num_heads)
]
else:
cache_zp = [
v_array[i].tolist()
for i in range(0, kv_num_heads)
]
self.cachekv_scale_dict[k] = (
np.array(cache_zp).flatten().astype(np.float32))
else:
continue
else:
            raise NotImplementedError("fake parameters are not supported yet")
def set_scales(self):
"""
Set scales for weight, activation, and cache key-value.
This method loads scales from JSON files located in the model path. It supports
loading scales for weights, activations, and cache key-value parameters.
Scales for unsupported parameters are ignored.
Raises:
NotImplementedError: If fake parameters are enabled (self.use_fake_parameter is True).
"""
if not self.use_fake_parameter:
# weight_scale
if self.use_ep:
weight_scale_json_path = os.path.join(self.model_path,
"weight_scales.json")
else:
weight_scale_json_path = os.path.join(
self.model_path, f"weight_scales_{self.mp_rank}.json")
if os.path.exists(weight_scale_json_path):
with open(weight_scale_json_path) as json_file:
self.weight_scale_dict = json.load(json_file)
for k, v in self.weight_scale_dict.items():
if not k.endswith(".weight_quanter"):
continue
self.weight_scale_dict[k] = np.array(v).astype(np.float32)
# act_scale
if self.use_ep:
act_scale_json_path = os.path.join(self.model_path,
"act_scales.json")
else:
act_scale_json_path = os.path.join(
self.model_path, f"act_scales_{self.mp_rank}.json")
if os.path.exists(act_scale_json_path):
with open(act_scale_json_path) as json_file:
self.act_scale_dict = json.load(json_file)
for k, v in self.act_scale_dict.items():
if not k.endswith(".activation_quanter"):
continue
self.act_scale_dict[k] = 1.0 / np.array(v).astype(np.float32)
# cachekv_scale
if self.cachekv_dtype in ["bfloat16", "float16", "float32"]:
return
if self.use_ep:
from glob import glob
scale_dir = self.scale_dir
scale_paths = glob(os.path.join(scale_dir, "cachekv_scale*"))
cachekv_scale_dict_all = []
self.cachekv_scale_dict = {}
                for possible_cache_scales_file_name in scale_paths:
                    with open(possible_cache_scales_file_name) as fi:
                        cachekv_scale_dict_all.append(json.load(fi))
                for cache_scale_dict in cachekv_scale_dict_all:
                    for k, v in cache_scale_dict.items():
                        self.cachekv_scale_dict.setdefault(k, []).extend(v)
else:
for possible_cache_scales_file_name in [
f"cachekv_scales_{self.mp_rank}.json",
f"cachekv_act_scales_{self.mp_rank}.json",
]:
cache_scale_json_path = os.path.join(
self.model_path, possible_cache_scales_file_name)
if os.path.exists(cache_scale_json_path):
with open(cache_scale_json_path) as json_file:
self.cachekv_scale_dict = json.load(json_file)
break
num_heads = self.num_attention_heads // self.mp_size
kv_num_heads = self.num_key_value_heads // self.mp_size
col_dim = (kv_num_heads *
self.head_dim if self.is_channel_wise else kv_num_heads)
for k, v in self.cachekv_scale_dict.items():
# cache_kv_scale
if k.endswith(".activation_quanter"):
if self.is_channel_wise:
v_array = (np.array(v).reshape(
-1, self.head_dim).astype(np.float32))
else:
v_array = np.array(v).reshape(-1).astype(np.float32)
if v_array.size > col_dim:
cache_scale = [
v_array[i].tolist()
for i in range(0, num_heads, num_heads //
kv_num_heads)
]
else:
cache_scale = [
v_array[i].tolist()
for i in range(0, kv_num_heads)
]
if (self.has_zero_point
and self.cachekv_dtype == "int4"): # cache_int4_zp
self.cachekv_scale_dict[k] = 1.0 / np.array(
cache_scale).flatten().astype(np.float32)
else:
self.cachekv_scale_dict[k] = (
self.cache_quant_max_bound /
np.array(cache_scale).flatten().astype(np.float32))
# cache_kv_zp
elif k.endswith(".zero_point"):
if self.is_channel_wise:
v_array = (np.array(v).reshape(
-1, self.head_dim).astype(np.float32))
else:
v_array = np.array(v).reshape(-1).astype(np.float32)
if v_array.size > col_dim:
cache_zp = [
v_array[i].tolist()
for i in range(0, num_heads, num_heads //
kv_num_heads)
]
else:
cache_zp = [
v_array[i].tolist()
for i in range(0, kv_num_heads)
]
self.cachekv_scale_dict[k] = (
np.array(cache_zp).flatten().astype(np.float32))
else:
continue
else:
            raise NotImplementedError("fake parameters are not supported yet")
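

# Illustrative usage (a minimal sketch, not part of the original module; the argument
# values below are assumptions chosen only to demonstrate the constructor):
#
#   paddle.set_default_dtype("bfloat16")
#   args = InferenceArgs(
#       quant_type="weight_only_int8",
#       num_layers=32,
#       num_attention_heads=32,
#       num_key_value_heads=8,
#       hidden_size=4096,
#       ffn_hidden_size=11008,
#       mp_rank=0,
#       mp_size=1,
#       use_fake_parameter=True,  # skip loading scale JSON files
#   )
#   # args.weight_dtype == "int8", args.act_dtype == "bfloat16",
#   # args.cache_quant_type == "none", args.use_weight_only is True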