[Feature] Support 45tVL EP FP8 Infer. (#2909)

* support_mm_ep_fp8

* support_mm_ep
Author:       xiaoxiaohehe001
Committed by: GitHub
Date:         2025-07-18 17:57:15 +08:00
Parent:       fbe3547c95
Commit:       a42fc3f40b

@@ -94,17 +94,44 @@ class Ernie4_5_VLMoE(nn.Layer):
         image_moe_layer_end_index = moe_layer_end_index[1]
 
         assert text_moe_layer_start_index <= text_moe_layer_end_index
+
+        moe_quant_type = ""
+        if hasattr(fd_config, 'quant_config') and fd_config.quant_config is not None:
+            moe_quant_type = getattr(fd_config.quant_config, 'name', lambda: "")()
+
         if layer_id >= text_moe_layer_start_index and layer_id <= text_moe_layer_end_index:
-            weight_key_map = {
-                "gate_weight_key":
-                f"{prefix}.gate.weight",
-                "gate_correction_bias_key":
-                f"{prefix}.moe_statics.e_score_correction_bias",
-                "up_gate_proj_expert_weight_key":
-                f"{prefix}.experts.{{}}.up_gate_proj.weight",
-                "down_proj_expert_weight_key":
-                f"{prefix}.experts.{{}}.down_proj.weight",
-            }
+            if moe_quant_type == "tensor_wise_fp8" or (
+                    moe_quant_type == "block_wise_fp8"
+                    and fd_config.model_config.is_quantized):
+                weight_key_map = {
+                    "gate_weight_key":
+                    f"{prefix}.gate.weight",
+                    "gate_correction_bias_key":
+                    f"{prefix}.moe_statics.e_score_correction_bias",
+                    "up_gate_proj_expert_weight_key":
+                    f"{prefix}.experts.{{}}.up_gate_proj.quant_weight",
+                    "down_proj_expert_weight_key":
+                    f"{prefix}.experts.{{}}.down_proj.quant_weight",
+                    "up_gate_proj_expert_weight_scale_key":
+                    f"{prefix}.experts.{{}}.up_gate_proj.weight_scale",
+                    "down_proj_expert_weight_scale_key":
+                    f"{prefix}.experts.{{}}.down_proj.weight_scale",
+                    "up_gate_proj_expert_in_scale_key":
+                    f"{prefix}.experts.{{}}.up_gate_proj.activation_scale",
+                    "down_proj_expert_in_scale_key":
+                    f"{prefix}.experts.{{}}.down_proj.activation_scale",
+                }
+            else:
+                weight_key_map = {
+                    "gate_weight_key":
+                    f"{prefix}.gate.weight",
+                    "gate_correction_bias_key":
+                    f"{prefix}.moe_statics.e_score_correction_bias",
+                    "up_gate_proj_expert_weight_key":
+                    f"{prefix}.experts.{{}}.up_gate_proj.weight",
+                    "down_proj_expert_weight_key":
+                    f"{prefix}.experts.{{}}.down_proj.weight",
+                }
             self.text_fused_moe = FusedMoE(
                 fd_config=fd_config,
                 reduce_results=False,
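
For context on the idiom used above: `getattr(fd_config.quant_config, 'name', lambda: "")()` calls the quant config's `name()` method when it exists and falls back to an empty string otherwise, and the doubled braces in the expert keys survive f-string evaluation as `{}` so the expert index can be substituted later. A minimal sketch of both, using hypothetical stand-in config objects rather than FastDeploy's real classes:

```python
# Stand-in objects (hypothetical, for illustration only).
class _FakeQuantConfig:
    def name(self) -> str:
        return "tensor_wise_fp8"

class _FakeFDConfig:
    quant_config = _FakeQuantConfig()

fd_config = _FakeFDConfig()

# Same idiom as in the diff: call quant_config.name() if present, else "".
moe_quant_type = ""
if hasattr(fd_config, "quant_config") and fd_config.quant_config is not None:
    moe_quant_type = getattr(fd_config.quant_config, "name", lambda: "")()
print(moe_quant_type)  # -> tensor_wise_fp8

# Doubled braces in the f-string become single braces, leaving a slot
# for the expert id (the prefix value here is just an example).
prefix = "ernie.layers.28.mlp"
key_template = f"{prefix}.experts.{{}}.up_gate_proj.quant_weight"
print(key_template.format(7))  # -> ernie.layers.28.mlp.experts.7.up_gate_proj.quant_weight
```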
@@ -128,16 +155,38 @@ class Ernie4_5_VLMoE(nn.Layer):
assert image_moe_layer_start_index <= image_moe_layer_end_index
if layer_id >= image_moe_layer_start_index and layer_id <= image_moe_layer_end_index:
weight_key_map = {
"gate_weight_key":
f"{prefix}.gate.weight_1",
"gate_correction_bias_key":
f"{prefix}.moe_statics.e_score_correction_bias",
"up_gate_proj_expert_weight_key":
f"{prefix}.experts.{{}}.up_gate_proj.weight",
"down_proj_expert_weight_key":
f"{prefix}.experts.{{}}.down_proj.weight",
}
if moe_quant_type == "tensor_wise_fp8" or (
moe_quant_type == "block_wise_fp8"
and fd_config.model_config.is_quantized):
weight_key_map = {
"gate_weight_key":
f"{prefix}.gate.weight_1",
"gate_correction_bias_key":
f"{prefix}.moe_statics.e_score_correction_bias",
"up_gate_proj_expert_weight_key":
f"{prefix}.experts.{{}}.up_gate_proj.quant_weight",
"down_proj_expert_weight_key":
f"{prefix}.experts.{{}}.down_proj.quant_weight",
"up_gate_proj_expert_weight_scale_key":
f"{prefix}.experts.{{}}.up_gate_proj.weight_scale",
"down_proj_expert_weight_scale_key":
f"{prefix}.experts.{{}}.down_proj.weight_scale",
"up_gate_proj_expert_in_scale_key":
f"{prefix}.experts.{{}}.up_gate_proj.activation_scale",
"down_proj_expert_in_scale_key":
f"{prefix}.experts.{{}}.down_proj.activation_scale",
}
else:
weight_key_map = {
"gate_weight_key":
f"{prefix}.gate.weight_1",
"gate_correction_bias_key":
f"{prefix}.moe_statics.e_score_correction_bias",
"up_gate_proj_expert_weight_key":
f"{prefix}.experts.{{}}.up_gate_proj.weight",
"down_proj_expert_weight_key":
f"{prefix}.experts.{{}}.down_proj.weight",
}
self.image_fused_moe = FusedMoE(
fd_config=fd_config,
reduce_results=False,
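
The same condition gates both the text and image key maps: `tensor_wise_fp8` always loads pre-quantized expert weights plus their scales, while `block_wise_fp8` only does so when `fd_config.model_config.is_quantized` indicates the checkpoint is already quantized (otherwise the plain weights are loaded and presumably quantized at load time). A small illustrative helper expressing that rule; the function name is ours, not part of the diff:

```python
def use_quant_weight_keys(moe_quant_type: str, is_quantized: bool) -> bool:
    """Illustrative helper (not in the diff): should expert weights be read from
    pre-quantized checkpoint keys (*.quant_weight, *.weight_scale,
    *.activation_scale) instead of the plain *.weight keys?"""
    return moe_quant_type == "tensor_wise_fp8" or (
        moe_quant_type == "block_wise_fp8" and is_quantized)

assert use_quant_weight_keys("tensor_wise_fp8", is_quantized=False)
assert not use_quant_weight_keys("block_wise_fp8", is_quantized=False)
assert use_quant_weight_keys("block_wise_fp8", is_quantized=True)
assert not use_quant_weight_keys("", is_quantized=True)  # unquantized path
```

The only difference between the two branches is the gate key (`gate.weight` for the text experts, `gate.weight_1` for the image experts); the per-expert key layout is identical.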
@@ -553,6 +602,18 @@ class Ernie4_5_VLMoeForConditionalGeneration(ModelForCasualLM):
         return logits
 
+    def empty_input_forward(self):
+        """
+        empty_input_forward
+        """
+        fake_hidden_states = paddle.empty(
+            shape=[0, self.fd_config.model_config.hidden_size],
+            dtype=paddle.get_default_dtype(),
+        )
+        for i in range(self.fd_config.model_config.moe_layer_start_index,
+                       self.fd_config.model_config.num_hidden_layers):
+            self.ernie.layers[i].mlp.text_fused_moe(fake_hidden_states)
+
     def forward(
         self,
         ids_remove_padding: paddle.Tensor,
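
`empty_input_forward` pushes a zero-row tensor through the text fused-MoE of every MoE layer, presumably so that a rank holding no tokens under expert parallelism still participates in the expert dispatch its peers are waiting on. A standalone sketch of the zero-token call, with a plain `paddle.nn.Linear` standing in for the fused MoE and an example hidden size:

```python
import paddle

hidden_size = 8192  # example value; the real code reads fd_config.model_config.hidden_size
layer = paddle.nn.Linear(hidden_size, hidden_size)  # stand-in for text_fused_moe

# A [0, hidden_size] tensor carries no tokens but still has a well-defined
# shape and dtype, so the layer can run without any real input.
fake_hidden_states = paddle.empty(shape=[0, hidden_size],
                                  dtype=paddle.get_default_dtype())
out = layer(fake_hidden_states)
print(out.shape)  # [0, 8192]
```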
@@ -759,4 +820,4 @@ class Ernie4_5_VLPretrainedModel(PretrainedModel):
             config.vision_config.get("depth")
         )
 
-        return {**mappings, **vision_mappings}
+        return {**mappings, **vision_mappings}