mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-05 08:37:06 +08:00
refactor rl get_name_mappings_to_training (#2847)
Some checks failed
Deploy GitHub Pages / deploy (push) Has been cancelled
Some checks failed
Deploy GitHub Pages / deploy (push) Has been cancelled
* refactor rl get_name_mappings_to_training
* fix tp>1
* change variable name (ffn1 -> up_gate_proj / ffn2 -> down_proj)
* change variable name (linear_weight -> weight / linear_bias -> bias)
* add rl names mapping for vl
* fix ernie 0.3B error
* fix develop code
* fix
This commit is contained in:
@@ -311,18 +311,18 @@ def w4a8_weight_convert(state_dict):
|
||||
w4a8_weight_bites_layers_map = {}
|
||||
w4a8_weight_bites_layers_map["qkv_gemm_bits_map"] = []
|
||||
w4a8_weight_bites_layers_map["out_gemm_bits_map"] = []
|
||||
w4a8_weight_bites_layers_map["ffn1_gemm_bits_map"] = []
|
||||
w4a8_weight_bites_layers_map["ffn2_gemm_bits_map"] = []
|
||||
w4a8_weight_bites_layers_map["up_gate_proj_gemm_bits_map"] = []
|
||||
w4a8_weight_bites_layers_map["down_proj_gemm_bits_map"] = []
|
||||
for name_keys, gemm_bits in w4a8_weight_bites_name_map.items():
|
||||
if "qkv_proj" in name_keys:
|
||||
w4a8_weight_bites_layers_map["qkv_gemm_bits_map"].append(gemm_bits)
|
||||
elif "out_proj" in name_keys:
|
||||
w4a8_weight_bites_layers_map["out_gemm_bits_map"].append(gemm_bits)
|
||||
elif "linear1" in name_keys:
|
||||
w4a8_weight_bites_layers_map["ffn1_gemm_bits_map"].append(
|
||||
w4a8_weight_bites_layers_map["up_gate_proj_gemm_bits_map"].append(
|
||||
gemm_bits)
|
||||
elif "linear2" in name_keys:
|
||||
w4a8_weight_bites_layers_map["ffn2_gemm_bits_map"].append(
|
||||
w4a8_weight_bites_layers_map["down_proj_gemm_bits_map"].append(
|
||||
gemm_bits)
|
||||
logger.debug(
|
||||
f"w4a8_weight_bites_layers_map:{w4a8_weight_bites_layers_map}")
|
||||
|
Reference in New Issue
Block a user