Mirror of https://github.com/PaddlePaddle/FastDeploy.git (synced 2025-10-04 00:06:38 +08:00)
add w4afp8 offline script (#3636)
@@ -18,6 +18,7 @@ from fastdeploy.model_executor.load_weight_utils import (
     get_all_safetensors,
     safetensors_weights_iterator,
 )
+from fastdeploy.model_executor.ops.gpu import w4afp8_gemm_scale_permute
 
 
 def parse_arguments():
@@ -47,14 +48,21 @@ def parse_arguments():
         help="Whether merge the model into safetensors format.",
     )
+    parser.add_argument(
+        "--moe_quant_type",
+        default="w4a8",
+        choices=["w4a8", "w4afp8"],
+        help="The moe quant type of the model.",
+    )
+
     return parser.parse_args()
 
 
 def reorder():
-    def fn(weight):
+    def fn(weight, moe_quant_type):
         from paddle.nn.quant import weight_quantize
 
-        quant_weight, _ = weight_quantize(weight.cuda(), algo="w4a8", arch=80)
+        quant_weight, _ = weight_quantize(weight.cuda(), algo=moe_quant_type, arch=80)
         return quant_weight.cpu()
 
     return fn
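
For orientation, a minimal sketch of exercising the updated closure; the tensor shape and dtype are illustrative assumptions, and it presumes a CUDA device plus the script's reorder() in scope:

    import paddle

    # Hypothetical expert weight; shape and dtype chosen only for illustration.
    weight = paddle.randn([8192, 4096], dtype="float16")

    quantize = reorder()                 # returns fn(weight, moe_quant_type)
    packed = quantize(weight, "w4afp8")  # int4-packed weight, returned on CPU
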
@@ -69,22 +77,27 @@ def deal_in_scale():
 
 
 def deal_weight_scale():
-    def fn(weight_scale, processed_in_scale):
-        processed_weight_scale = weight_scale / (127 * 112) / processed_in_scale
-        return processed_weight_scale
+    def fn(weight_scale, processed_in_scale, moe_quant_type):
+        if moe_quant_type == "w4a8":
+            processed_weight_scale = weight_scale / (127 * 112) / processed_in_scale
+            return processed_weight_scale
+        elif moe_quant_type == "w4afp8":
+            processed_weight_scale = weight_scale / (448 * 7 * 2 ** (-9)) / processed_in_scale
+            processed_weight_scale = w4afp8_gemm_scale_permute(processed_weight_scale.cuda())
+            return processed_weight_scale
 
     return fn
 
 
-# tmp support w4a8
-def deal_quant(state_dict, save_state_dict):
-    w4a8_mapping = [
+def deal_quant(state_dict, save_state_dict, moe_quant_type):
+    param_mapping = [
         # pattern,fn
         (r"layers\.(\d+)\.mlp\.experts\.(\d+)\.([^.]+)\.activation_scale", deal_in_scale()),
         (r"layers\.(\d+)\.mlp\.experts\.(\d+)\.([^.]+)\.weight_scale", deal_weight_scale()),
         (r"layers\.(\d+)\.mlp\.experts\.(\d+)\.([^.]+)\.quant_weight", reorder()),
     ]
-    for pattern, fn in w4a8_mapping:
+    for pattern, fn in param_mapping:
         for key in list(state_dict.keys()):
             # print(f"deal {key}")
             match = re.search(pattern, key)
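
The two divisors encode the dynamic ranges of the number formats involved. The reading below is inferred from the constants alone (the commit does not explain them): 127 is the int8 activation maximum, 112 = 7 * 16 is the signed-int4 weight maximum pre-shifted by four bits, 448 is the FP8 E4M3 maximum, 7 the int4 maximum, and 2 ** (-9) a fixed exponent offset assumed to be folded into the w4afp8 GEMM kernel:

    # Assumed interpretation of the scale constants; the range attributions
    # are inferred from the values, not stated in the commit.
    INT8_MAX = 127              # w4a8 activation range (int8)
    INT4_MAX_SHIFTED = 7 * 16   # w4a8 weight range (int4, pre-shifted 4 bits)
    FP8_E4M3_MAX = 448          # w4afp8 activation range (FP8 E4M3)
    INT4_MAX = 7                # w4afp8 weight range (int4)

    print(INT8_MAX * INT4_MAX_SHIFTED)        # 14224 -> the w4a8 divisor
    print(FP8_E4M3_MAX * INT4_MAX * 2 ** -9)  # 6.125 -> the w4afp8 divisor
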
@@ -94,9 +107,11 @@ def deal_quant(state_dict, save_state_dict):
                 if "weight_scale" in key:
                     in_scale_key = key.replace("weight_scale", "activation_scale")
                     in_scale = save_state_dict[in_scale_key]
-                    save_state_dict[key] = fn(weight_or_scale, in_scale)
-                else:
+                    save_state_dict[key] = fn(weight_or_scale, in_scale, moe_quant_type)
+                elif "activation_scale" in key:
                     save_state_dict[key] = fn(weight_or_scale)
+                else:
+                    save_state_dict[key] = fn(weight_or_scale, moe_quant_type)
 
 
 def save_safetensors(state_dict, args):
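
The suffix dispatch works because each capture pattern only matches per-expert MoE tensors; a quick illustration with a hypothetical checkpoint key (the prefix and projection name are invented):

    import re

    pattern = r"layers\.(\d+)\.mlp\.experts\.(\d+)\.([^.]+)\.weight_scale"
    key = "model.layers.3.mlp.experts.17.down_proj.weight_scale"  # hypothetical

    match = re.search(pattern, key)
    print(match.groups())  # ('3', '17', 'down_proj')
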
@@ -153,7 +168,7 @@ def main():
     end = time.perf_counter()
     logger.info("Finish Quantize.")
     logger.info(f"load and quantize took : {end - start:.6f} seconds")
-    deal_quant(state_dict, save_state_dict)
+    deal_quant(state_dict, save_state_dict, args.moe_quant_type)
     for key in list(state_dict.keys()):
         save_state_dict[key] = state_dict.pop(key)
     logger.info("Begin to save model")
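
With the argument threaded through main(), the offline script can now emit either variant, e.g. "python <script>.py --moe_quant_type w4afp8"; the script name and its remaining flags are not shown in this diff, so only --moe_quant_type should be taken from here.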