mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-11-03 11:02:01 +08:00
@@ -345,9 +345,7 @@ class CutlassW4A8MoEMethod(CutlassMoEMethod):
|
||||
|
||||
up_gate_proj_weights, down_proj_weights, logical_expert_ids, ep_rank_to_expert_id_list = (
|
||||
layer.load_experts_weight(
|
||||
state_dict,
|
||||
up_gate_proj_expert_weight_key,
|
||||
down_proj_expert_weight_key,
|
||||
state_dict, up_gate_proj_expert_weight_key, down_proj_expert_weight_key, is_rearrange
|
||||
)
|
||||
)
|
||||
|
||||
@@ -357,22 +355,62 @@ class CutlassW4A8MoEMethod(CutlassMoEMethod):
|
||||
up_gate_proj_in_scale = []
|
||||
down_proj_in_scale = []
|
||||
|
||||
if isinstance(state_dict, list):
|
||||
state_dict = dict(state_dict)
|
||||
|
||||
if layer.ep_size > 1:
|
||||
for expert_idx in ep_rank_to_expert_id_list:
|
||||
scale_tensor = get_tensor(state_dict[up_gate_proj_expert_in_scale_key.format(expert_idx)])
|
||||
scale_tensor = get_tensor(
|
||||
get_tensor(
|
||||
state_dict[up_gate_proj_expert_in_scale_key.format(expert_idx)]
|
||||
if up_gate_proj_expert_in_scale_key.format(expert_idx) in state_dict
|
||||
else up_gate_proj_expert_in_scale_key.format(expert_idx)
|
||||
),
|
||||
layer.fd_config.model_config.model,
|
||||
)
|
||||
up_gate_proj_in_scale_all_experts.append(scale_tensor)
|
||||
|
||||
for expert_idx in logical_expert_ids:
|
||||
up_gate_proj_weight_scale.append(
|
||||
get_tensor(state_dict.pop(up_gate_proj_expert_weight_scale_key.format(expert_idx)))
|
||||
get_tensor(
|
||||
(
|
||||
state_dict.pop(up_gate_proj_expert_weight_scale_key.format(expert_idx))
|
||||
if up_gate_proj_expert_weight_scale_key.format(expert_idx) in state_dict
|
||||
else up_gate_proj_expert_weight_scale_key.format(expert_idx)
|
||||
),
|
||||
layer.fd_config.model_config.model,
|
||||
)
|
||||
)
|
||||
down_proj_weight_scale.append(
|
||||
get_tensor(state_dict.pop(down_proj_expert_weight_scale_key.format(expert_idx)))
|
||||
get_tensor(
|
||||
(
|
||||
state_dict.pop(down_proj_expert_weight_scale_key.format(expert_idx))
|
||||
if down_proj_expert_weight_scale_key.format(expert_idx) in state_dict
|
||||
else down_proj_expert_weight_scale_key.format(expert_idx)
|
||||
),
|
||||
layer.fd_config.model_config.model,
|
||||
)
|
||||
)
|
||||
up_gate_proj_in_scale.append(
|
||||
get_tensor(state_dict.pop(up_gate_proj_expert_in_scale_key.format(expert_idx)))
|
||||
get_tensor(
|
||||
(
|
||||
state_dict.pop(up_gate_proj_expert_in_scale_key.format(expert_idx))
|
||||
if up_gate_proj_expert_in_scale_key.format(expert_idx) in state_dict
|
||||
else up_gate_proj_expert_in_scale_key.format(expert_idx)
|
||||
),
|
||||
layer.fd_config.model_config.model,
|
||||
)
|
||||
)
|
||||
down_proj_in_scale.append(
|
||||
get_tensor(
|
||||
(
|
||||
state_dict.pop(down_proj_expert_in_scale_key.format(expert_idx))
|
||||
if down_proj_expert_in_scale_key.format(expert_idx) in state_dict
|
||||
else down_proj_expert_in_scale_key.format(expert_idx)
|
||||
),
|
||||
layer.fd_config.model_config.model,
|
||||
)
|
||||
)
|
||||
down_proj_in_scale.append(get_tensor(state_dict.pop(down_proj_expert_in_scale_key.format(expert_idx))))
|
||||
|
||||
up_gate_proj_weight = paddle.stack(up_gate_proj_weights, axis=0)
|
||||
down_proj_weight = paddle.stack(down_proj_weights, axis=0)
|
||||
|
||||
Reference in New Issue
Block a user