mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-05 08:37:06 +08:00
support w4afp8 eplb (#3680)
This commit is contained in:
@@ -661,7 +661,7 @@ class CutlassW4AFP8MoEMethod(CutlassMoEMethod):
|
||||
self.moe_quant_type = "w4afp8"
|
||||
self.pack_num = 2
|
||||
|
||||
def process_prequanted_weights(self, layer: nn.Layer, state_dict):
|
||||
def process_prequanted_weights(self, layer: nn.Layer, state_dict, is_rearrange: bool = False):
|
||||
"""
|
||||
Paddle cutlass process prequanted weights.
|
||||
"""
|
||||
@@ -677,6 +677,7 @@ class CutlassW4AFP8MoEMethod(CutlassMoEMethod):
|
||||
state_dict,
|
||||
up_gate_proj_expert_weight_key,
|
||||
down_proj_expert_weight_key,
|
||||
is_rearrange,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -686,22 +687,62 @@ class CutlassW4AFP8MoEMethod(CutlassMoEMethod):
|
||||
up_gate_proj_in_scale = []
|
||||
down_proj_in_scale = []
|
||||
|
||||
if isinstance(state_dict, list):
|
||||
state_dict = dict(state_dict)
|
||||
|
||||
if layer.ep_size > 1:
|
||||
for expert_idx in ep_rank_to_expert_id_list:
|
||||
scale_tensor = get_tensor(state_dict[up_gate_proj_expert_in_scale_key.format(expert_idx)])
|
||||
scale_tensor = get_tensor(
|
||||
(
|
||||
state_dict[up_gate_proj_expert_in_scale_key.format(expert_idx)]
|
||||
if up_gate_proj_expert_in_scale_key.format(expert_idx) in state_dict
|
||||
else up_gate_proj_expert_in_scale_key.format(expert_idx)
|
||||
),
|
||||
layer.fd_config.model_config.model,
|
||||
)
|
||||
up_gate_proj_in_scale_all_experts.append(scale_tensor)
|
||||
|
||||
for expert_idx in logical_expert_ids:
|
||||
up_gate_proj_weight_scale.append(
|
||||
get_tensor(state_dict.pop(up_gate_proj_expert_weight_scale_key.format(expert_idx)))
|
||||
get_tensor(
|
||||
(
|
||||
state_dict.pop(up_gate_proj_expert_weight_scale_key.format(expert_idx))
|
||||
if up_gate_proj_expert_weight_scale_key.format(expert_idx) in state_dict
|
||||
else up_gate_proj_expert_weight_scale_key.format(expert_idx)
|
||||
),
|
||||
layer.fd_config.model_config.model,
|
||||
)
|
||||
)
|
||||
down_proj_weight_scale.append(
|
||||
get_tensor(state_dict.pop(down_proj_expert_weight_scale_key.format(expert_idx)))
|
||||
get_tensor(
|
||||
(
|
||||
state_dict.pop(down_proj_expert_weight_scale_key.format(expert_idx))
|
||||
if down_proj_expert_weight_scale_key.format(expert_idx) in state_dict
|
||||
else down_proj_expert_weight_scale_key.format(expert_idx)
|
||||
),
|
||||
layer.fd_config.model_config.model,
|
||||
)
|
||||
)
|
||||
up_gate_proj_in_scale.append(
|
||||
get_tensor(state_dict.pop(up_gate_proj_expert_in_scale_key.format(expert_idx)))
|
||||
get_tensor(
|
||||
(
|
||||
state_dict.pop(up_gate_proj_expert_in_scale_key.format(expert_idx))
|
||||
if up_gate_proj_expert_in_scale_key.format(expert_idx) in state_dict
|
||||
else up_gate_proj_expert_in_scale_key.format(expert_idx)
|
||||
),
|
||||
layer.fd_config.model_config.model,
|
||||
)
|
||||
)
|
||||
down_proj_in_scale.append(
|
||||
get_tensor(
|
||||
(
|
||||
state_dict.pop(down_proj_expert_in_scale_key.format(expert_idx))
|
||||
if down_proj_expert_in_scale_key.format(expert_idx) in state_dict
|
||||
else down_proj_expert_in_scale_key.format(expert_idx)
|
||||
),
|
||||
layer.fd_config.model_config.model,
|
||||
)
|
||||
)
|
||||
down_proj_in_scale.append(get_tensor(state_dict.pop(down_proj_expert_in_scale_key.format(expert_idx))))
|
||||
|
||||
up_gate_proj_weight = paddle.stack(up_gate_proj_weights, axis=0)
|
||||
down_proj_weight = paddle.stack(down_proj_weights, axis=0)
|
||||
@@ -763,7 +804,9 @@ class CutlassW4AFP8MoEMethod(CutlassMoEMethod):
|
||||
"""
|
||||
Paddle cutlass load weight process.
|
||||
"""
|
||||
up_gate_proj_weights, down_proj_weights = layer.extract_moe_ffn_weights(state_dict)
|
||||
up_gate_proj_weights, down_proj_weights, logical_expert_ids, ep_rank_to_expert_id_list = (
|
||||
layer.extract_moe_ffn_weights(state_dict)
|
||||
)
|
||||
self.check(layer, up_gate_proj_weights, down_proj_weights)
|
||||
for idx, weight_tensor in enumerate([up_gate_proj_weights, down_proj_weights]):
|
||||
weight_name = self.added_weight_attrs[idx]
|
||||
@@ -774,7 +817,9 @@ class CutlassW4AFP8MoEMethod(CutlassMoEMethod):
|
||||
quanted_weight = paddle.stack(weight_list, axis=0)
|
||||
getattr(layer, weight_name).set_value(quanted_weight)
|
||||
|
||||
self.load_w4afp8_scale_weights(layer, layer.weight_key_map, state_dict)
|
||||
self.load_w4afp8_scale_weights(
|
||||
layer, layer.weight_key_map, state_dict, logical_expert_ids, ep_rank_to_expert_id_list
|
||||
)
|
||||
|
||||
def create_w4afp8_scale_weights(self, layer: nn.Layer, weight_key_map: dict):
|
||||
"""
|
||||
@@ -828,7 +873,14 @@ class CutlassW4AFP8MoEMethod(CutlassMoEMethod):
|
||||
),
|
||||
)
|
||||
|
||||
def load_w4afp8_scale_weights(self, layer: nn.Layer, weight_key_map: dict, state_dict: dict):
|
||||
def load_w4afp8_scale_weights(
|
||||
self,
|
||||
layer: nn.Layer,
|
||||
weight_key_map: dict,
|
||||
state_dict: dict,
|
||||
logical_expert_ids: paddle.Tensor,
|
||||
ep_rank_to_expert_id_list: list,
|
||||
):
|
||||
"""
|
||||
Get w4afp8 weights from state dict and process them.
|
||||
Args:
|
||||
@@ -837,8 +889,15 @@ class CutlassW4AFP8MoEMethod(CutlassMoEMethod):
|
||||
state_dict (dict): The state dict.
|
||||
"""
|
||||
|
||||
def _extract_scale_tensor(state_dict, key_template, expert_idx):
|
||||
return get_tensor(state_dict.pop(key_template.format(expert_idx)))
|
||||
def _extract_scale_tensor(layer: nn.Layer, state_dict, key_template, expert_idx):
|
||||
return get_tensor(
|
||||
(
|
||||
state_dict.pop(key_template.format(expert_idx))
|
||||
if key_template.format(expert_idx) in state_dict
|
||||
else key_template.format(expert_idx)
|
||||
),
|
||||
layer.fd_config.model_config.model,
|
||||
)
|
||||
|
||||
def _process_in_scale(name: str, in_scales: list[paddle.Tensor]):
|
||||
processed_in_scale = 1 / paddle.concat(in_scales)
|
||||
@@ -881,17 +940,23 @@ class CutlassW4AFP8MoEMethod(CutlassMoEMethod):
|
||||
|
||||
# 2. Extract scale tensor from state dict
|
||||
if layer.ep_size > 1:
|
||||
for expert_idx in range(layer.num_experts):
|
||||
scale_tensor = get_tensor(state_dict[scale_key_map["up_gate_proj_in_scale"].format(expert_idx)])
|
||||
for expert_idx in ep_rank_to_expert_id_list:
|
||||
scale_tensor = get_tensor(
|
||||
(
|
||||
state_dict[scale_key_map["up_gate_proj_in_scale"].format(expert_idx)]
|
||||
if scale_key_map["up_gate_proj_in_scale"].format(expert_idx) in state_dict
|
||||
else scale_key_map["up_gate_proj_in_scale"].format(expert_idx)
|
||||
),
|
||||
layer.fd_config.model_config.model,
|
||||
)
|
||||
up_gate_proj_in_scales_all_experts.append(1 / scale_tensor)
|
||||
getattr(layer, "up_gate_proj_in_scale_all_experts").set_value(
|
||||
paddle.concat(up_gate_proj_in_scales_all_experts)
|
||||
)
|
||||
|
||||
for local_expert_idx in range(layer.num_local_experts):
|
||||
expert_idx = local_expert_idx + layer.expert_id_offset
|
||||
for expert_idx in logical_expert_ids:
|
||||
for name, scale_key_template in scale_key_map.items():
|
||||
scale_tensor = _extract_scale_tensor(state_dict, scale_key_template, expert_idx)
|
||||
scale_tensor = _extract_scale_tensor(layer, state_dict, scale_key_template, expert_idx)
|
||||
scale_weight_map[name].append(scale_tensor)
|
||||
|
||||
# 3. Process scale tensor and set to layer
|
||||
|
Reference in New Issue
Block a user