[Models] Add forward_meta to moe models' forward function (#5138)

* [Models] Add forward_meta to moe models' forward function

* fix missing param

* fix

* fix

* fix forward_meta

* fix test and remove chunked MoE releated in config

* fix test

* fix

* fix
This commit is contained in:
Longzhi Wang
2025-12-04 13:26:58 +08:00
committed by GitHub
parent f5bdb36e9b
commit 5cd17fd662
21 changed files with 131 additions and 87 deletions

View File

@@ -94,7 +94,7 @@ class Ernie4_5_MLP(nn.Layer):
self.up_gate_proj.load_state_dict(state_dict)
self.down_proj.load_state_dict(state_dict)
def forward(self, hidden_states: paddle.Tensor):
def forward(self, hidden_states: paddle.Tensor, forward_meta: ForwardMeta = None):
gate_up_out = self.up_gate_proj(hidden_states)
act_out = self.act_fn(gate_up_out)
down_out = self.down_proj(act_out)
@@ -213,8 +213,16 @@ class Ernie4_5_MoE(nn.Layer):
def update_state_dict(self, state_dict):
self.experts.load_state_dict(state_dict, True)
def forward(self, hidden_states: paddle.Tensor):
out = self.experts(hidden_states, self.gate)
def forward(
self,
hidden_states: paddle.Tensor,
forward_meta: ForwardMeta,
):
out = self.experts(
x=hidden_states,
gate=self.gate,
forward_meta=forward_meta,
)
if self.num_shared_experts > 0:
s_x = self.shared_experts(hidden_states)
out = out + s_x
@@ -344,7 +352,10 @@ class Ernie4_5_DecoderLayer(nn.Layer):
residual,
)
hidden_states = self.mlp(hidden_states)
hidden_states = self.mlp(
hidden_states=hidden_states,
forward_meta=forward_meta,
)
return hidden_states, residual
@@ -611,7 +622,7 @@ class Ernie4_5_MoeForCausalLM(ModelForCasualLM):
return logits
def empty_input_forward(self):
def empty_input_forward(self, forward_meta):
"""
empty_input_forward
"""
@@ -623,7 +634,7 @@ class Ernie4_5_MoeForCausalLM(ModelForCasualLM):
self.fd_config.model_config.moe_layer_start_index,
self.fd_config.model_config.num_hidden_layers,
):
self.ernie.layers[i].mlp.experts(fake_hidden_states, self.ernie.layers[i].mlp.gate)
self.ernie.layers[i].mlp.experts(fake_hidden_states, self.ernie.layers[i].mlp.gate, forward_meta)
def forward(
self,