add w4afp8 offline script (#3636)

This commit is contained in:
Yuan Xiaolan
2025-08-29 17:56:05 +08:00
committed by GitHub
parent f677c032c0
commit c71ee0831c
12 changed files with 163 additions and 37 deletions

View File

@@ -18,6 +18,7 @@ from fastdeploy.model_executor.load_weight_utils import (
get_all_safetensors,
safetensors_weights_iterator,
)
from fastdeploy.model_executor.ops.gpu import w4afp8_gemm_scale_permute
def parse_arguments():
@@ -47,14 +48,21 @@ def parse_arguments():
help="Whether merge the model into safetensors format.",
)
parser.add_argument(
"--moe_quant_type",
default="w4a8",
choices=["w4a8", "w4afp8"],
help="The moe quant type of the model.",
)
return parser.parse_args()
def reorder():
def fn(weight):
def fn(weight, moe_quant_type):
from paddle.nn.quant import weight_quantize
quant_weight, _ = weight_quantize(weight.cuda(), algo="w4a8", arch=80)
quant_weight, _ = weight_quantize(weight.cuda(), algo=moe_quant_type, arch=80)
return quant_weight.cpu()
return fn
@@ -69,22 +77,27 @@ def deal_in_scale():
def deal_weight_scale():
def fn(weight_scale, processed_in_scale):
processed_weight_scale = weight_scale / (127 * 112) / processed_in_scale
return processed_weight_scale
def fn(weight_scale, processed_in_scale, moe_quant_type):
if moe_quant_type == "w4a8":
processed_weight_scale = weight_scale / (127 * 112) / processed_in_scale
return processed_weight_scale
elif moe_quant_type == "w4afp8":
processed_weight_scale = weight_scale / (448 * 7 * 2 ** (-9)) / processed_in_scale
processed_weight_scale = w4afp8_gemm_scale_permute(processed_weight_scale.cuda())
return processed_weight_scale
return fn
# tmp support w4a8
def deal_quant(state_dict, save_state_dict):
w4a8_mapping = [
def deal_quant(state_dict, save_state_dict, moe_quant_type):
param_mapping = [
# pattern,fn
(r"layers\.(\d+)\.mlp\.experts\.(\d+)\.([^.]+)\.activation_scale", deal_in_scale()),
(r"layers\.(\d+)\.mlp\.experts\.(\d+)\.([^.]+)\.weight_scale", deal_weight_scale()),
(r"layers\.(\d+)\.mlp\.experts\.(\d+)\.([^.]+)\.quant_weight", reorder()),
]
for pattern, fn in w4a8_mapping:
for pattern, fn in param_mapping:
for key in list(state_dict.keys()):
# print(f"deal {key}")
match = re.search(pattern, key)
@@ -94,9 +107,11 @@ def deal_quant(state_dict, save_state_dict):
if "weight_scale" in key:
in_scale_key = key.replace("weight_scale", "activation_scale")
in_scale = save_state_dict[in_scale_key]
save_state_dict[key] = fn(weight_or_scale, in_scale)
else:
save_state_dict[key] = fn(weight_or_scale, in_scale, moe_quant_type)
elif "activation_scale" in key:
save_state_dict[key] = fn(weight_or_scale)
else:
save_state_dict[key] = fn(weight_or_scale, moe_quant_type)
def save_safetensors(state_dict, args):
@@ -153,7 +168,7 @@ def main():
end = time.perf_counter()
logger.info("Finish Quantize.")
logger.info(f"load and quantize took : {end - start:.6f} seconds")
deal_quant(state_dict, save_state_dict)
deal_quant(state_dict, save_state_dict, args.moe_quant_type)
for key in list(state_dict.keys()):
save_state_dict[key] = state_dict.pop(key)
logger.info("Begin to save model")