From 77514e3e1e63d3e28073e8b79ea85dc31c4b18e1 Mon Sep 17 00:00:00 2001 From: bukejiyu <52310069+bukejiyu@users.noreply.github.com> Date: Sat, 23 Aug 2025 13:13:41 +0800 Subject: [PATCH] [V1 Loader] support weight_only (#3413) * support wint4/wint8 * delete smoe case * update ci * print log --- .../model_executor/layers/embeddings.py | 2 +- fastdeploy/model_executor/layers/linear.py | 198 +++++++----------- fastdeploy/model_executor/layers/lm_head.py | 2 +- .../layers/moe/fused_moe_backend_base.py | 11 +- .../layers/moe/fused_moe_cutlass_backend.py | 190 +++++++++++++---- fastdeploy/model_executor/layers/moe/moe.py | 125 +++++------ .../layers/quantization/weight_only.py | 91 ++++++-- fastdeploy/model_executor/layers/utils.py | 10 +- .../model_loader/default_loader_v1.py | 8 +- .../model_executor/models/deepseek_v3.py | 58 +++-- .../model_executor/models/ernie4_5_moe.py | 54 +++-- .../models/ernie4_5_vl/dfnrope/modeling.py | 2 +- .../models/ernie4_5_vl/ernie4_5_vl_moe.py | 85 ++++---- .../models/ernie4_5_vl/modeling_resampler.py | 2 +- fastdeploy/model_executor/models/qwen3.py | 15 +- fastdeploy/model_executor/models/qwen3moe.py | 15 +- fastdeploy/model_executor/models/utils.py | 65 +----- fastdeploy/model_executor/utils.py | 179 ++++++++++++++++ fastdeploy/rl/rollout_model.py | 6 +- scripts/coverage_run.sh | 2 +- tests/ci_use/EB_VL_Lite/baseline.txt | 164 +++++---------- tests/conftest.py | 120 +++++++++++ tests/model_loader/__init__.py | 0 tests/model_loader/test_common_model.py | 175 ++++++++++++++++ 24 files changed, 1055 insertions(+), 524 deletions(-) create mode 100644 fastdeploy/model_executor/utils.py create mode 100644 tests/conftest.py create mode 100644 tests/model_loader/__init__.py create mode 100644 tests/model_loader/test_common_model.py diff --git a/fastdeploy/model_executor/layers/embeddings.py b/fastdeploy/model_executor/layers/embeddings.py index ba68c9ed0..5c26437de 100644 --- a/fastdeploy/model_executor/layers/embeddings.py +++ b/fastdeploy/model_executor/layers/embeddings.py @@ -22,7 +22,7 @@ from paddle import nn from paddle.distributed import fleet from fastdeploy.config import FDConfig -from fastdeploy.model_executor.models.utils import set_weight_attrs +from fastdeploy.model_executor.utils import set_weight_attrs from .utils import get_tensor diff --git a/fastdeploy/model_executor/layers/linear.py b/fastdeploy/model_executor/layers/linear.py index 3cb62e973..47ce9365f 100644 --- a/fastdeploy/model_executor/layers/linear.py +++ b/fastdeploy/model_executor/layers/linear.py @@ -23,7 +23,7 @@ from paddle import nn from fastdeploy.config import FDConfig from fastdeploy.distributed.communication import tensor_model_parallel_all_reduce from fastdeploy.model_executor.layers.quantization.quant_base import QuantMethodBase -from fastdeploy.model_executor.models.utils import ( +from fastdeploy.model_executor.utils import ( default_weight_loader, set_weight_attrs, slice_fn, @@ -39,6 +39,7 @@ class UnquantizedLinearMethod(QuantMethodBase): def create_weights(self, layer: nn.Layer, **extra_weight_attrs): """ extra_weight_attrs is a dictionary that may include parameters like: + - split_axis: axis along which to split the tensor in a distributed environment - output_dim: determines whether the split is applied along the output dimension (rows) or input dimension (columns) - weight_loader: a callable or method responsible for loading the weight data """ @@ -48,12 +49,16 @@ class UnquantizedLinearMethod(QuantMethodBase): is_bias=False, default_initializer=paddle.nn.initializer.Constant(0), ) + split_axis = extra_weight_attrs.get("split_axis") + if hasattr(layer, "nranks") and layer.nranks > 0: + _set_var_distributed(layer.weight, split_axis=split_axis) set_weight_attrs( layer.weight, - {"weight_loader": extra_weight_attrs.get("weight_loader", default_weight_loader(layer.fd_config))}, + { + **extra_weight_attrs, + "weight_loader": extra_weight_attrs.get("weight_loader", default_weight_loader(layer.fd_config)), + }, ) - if hasattr(layer, "nranks") and layer.nranks > 1: - set_weight_attrs(layer.weight, {"output_dim": extra_weight_attrs.get("output_dim")}) def process_loaded_weights(self, layer, weights) -> None: # mlp.gate.weight is precision-sensitive, so we cast it to float32 for computation @@ -340,7 +345,6 @@ class ColumnParallelLinear(LinearBase): ), ) if self.nranks > 0: - _set_var_distributed(self.weight, split_axis=1) if self.with_bias: # col parallel _set_var_distributed(self.bias, split_axis=1) @@ -399,28 +403,27 @@ class MergedColumnParallelLinear(ColumnParallelLinear): ) def weight_loader(self, param, loaded_weight, loaded_shard_id: Optional[str] = None): + output_dim = getattr(param, "output_dim", None) + shard_dim = -1 if output_dim else 0 + output_size = param.shape[shard_dim] if loaded_shard_id is None: # Loaded weight is already fused on disk. - if self.nranks != 1: - shard_offsets = [ - # (shard_id, shard_offset, shard_size) - ("gate", 0, self.output_size * self.nranks // 2), - ("up", self.output_size * self.nranks // 2, self.output_size * self.nranks // 2), - ] - for shard_id, shard_offset, shard_size in shard_offsets: - loaded_weight_shard = loaded_weight[..., shard_offset : shard_offset + shard_size] - self.weight_loader(param, loaded_weight_shard, shard_id) - else: - loaded_weight = get_tensor(loaded_weight) - param.copy_(loaded_weight, False) + shard_offsets = [ + # (shard_id, shard_offset, shard_size) + ("gate", 0, output_size * self.nranks // 2), + ("up", output_size * self.nranks // 2, output_size * self.nranks // 2), + ] + for shard_id, shard_offset, shard_size in shard_offsets: + loaded_weight_shard = slice_fn( + loaded_weight, output_dim, start=shard_offset, end=shard_offset + shard_size + ) + self.weight_loader(param, loaded_weight_shard, shard_id) else: - # 1.fused gate_up in disk - # 2.split gate up + # split gate up assert loaded_shard_id in ["gate", "up"] - output_dim = getattr(param, "output_dim", None) # Tensor parallelism splits the weight along the output_dim - if output_dim is not None: - dim = -1 + if self.nranks != 1: + dim = -1 if output_dim else 0 if isinstance(loaded_weight, np.ndarray): size = loaded_weight.shape[dim] else: @@ -428,15 +431,20 @@ class MergedColumnParallelLinear(ColumnParallelLinear): block_size = size // self.nranks shard_offset = self.local_rank * block_size shard_size = (self.local_rank + 1) * block_size - loaded_weight = loaded_weight[..., shard_offset:shard_size] + loaded_weight = slice_fn(loaded_weight, output_dim, start=shard_offset, end=shard_size) loaded_weight = get_tensor(loaded_weight) - + if not param._is_initialized(): + param.initialize() + param_shard_size = output_size // 2 if loaded_shard_id == "gate": - param = param[:, : self.output_size // 2] - elif loaded_shard_id == "up": - param = param[:, self.output_size // 2 :] - + param_shard_offset = 0 + else: + # loaded_shard_id == "up" + param_shard_offset = param_shard_size + if hasattr(param, "tensor_track"): + param.tensor_track.mark(start=param_shard_offset, end=param_shard_offset + param_shard_size) + param = slice_fn(param, output_dim, start=param_shard_offset, end=param_shard_offset + param_shard_size) assert param.shape == loaded_weight.shape, ( f" Attempted to load weight ({loaded_weight.shape}) " f"into parameter ({param.shape})" ) @@ -513,30 +521,25 @@ class QKVParallelLinear(ColumnParallelLinear): def weight_loader(self, param, loaded_weight, loaded_shard_id: Optional[str] = None): output_dim = getattr(param, "output_dim", None) + head_dim = param.shape[output_dim] // (self.num_heads_per_rank + 2 * self.kv_num_heads_per_rank) if loaded_shard_id is None: # Loaded weight is already fused on disk - if self.nranks != 1: - shard_offsets = [ - # (shard_id, shard_offset, shard_size) - ("q", 0, self.num_heads * self.head_dim), - ("k", self.num_heads * self.head_dim, self.kv_num_heads * self.head_dim), - ("v", (self.num_heads + self.kv_num_heads) * self.head_dim, self.kv_num_heads * self.head_dim), - ] - for shard_id, shard_offset, shard_size in shard_offsets: - loaded_weight_shard = loaded_weight_shard = slice_fn( - loaded_weight, output_dim, start=shard_offset, end=shard_offset + shard_size - ) - self.weight_loader(param, loaded_weight_shard, shard_id) - else: - loaded_weight = get_tensor(loaded_weight) - split_loaded_weight = loaded_weight - param.copy_(split_loaded_weight, False) + shard_offsets = [ + # (shard_id, shard_offset, shard_size) + ("q", 0, self.num_heads * head_dim), + ("k", self.num_heads * head_dim, self.kv_num_heads * head_dim), + ("v", (self.num_heads + self.kv_num_heads) * head_dim, self.kv_num_heads * head_dim), + ] + for shard_id, shard_offset, shard_size in shard_offsets: + loaded_weight_shard = slice_fn( + loaded_weight, output_dim, start=shard_offset, end=shard_offset + shard_size + ) + self.weight_loader(param, loaded_weight_shard, shard_id) else: - # 1.fused qkv in disk - # 2.split q k v + # split q k v assert loaded_shard_id in ["q", "k", "v"] # Tensor parallelism splits the weight along the output_dim - if output_dim is not None: + if self.nranks != 1: dim = -1 if output_dim else 0 if isinstance(loaded_weight, np.ndarray): size = loaded_weight.shape[dim] @@ -545,20 +548,25 @@ class QKVParallelLinear(ColumnParallelLinear): block_size = size // self.nranks shard_offset = self.local_rank * block_size shard_size = (self.local_rank + 1) * block_size - loaded_weight = loaded_weight[..., shard_offset:shard_size] + loaded_weight = slice_fn(loaded_weight, output_dim, start=shard_offset, end=shard_size) loaded_weight = get_tensor(loaded_weight) + if not param._is_initialized(): + param.initialize() if loaded_shard_id == "q": + param_shard_offset = 0 - param_shard_size = self.num_heads_per_rank * self.head_dim + param_shard_size = self.num_heads_per_rank * head_dim elif loaded_shard_id == "k": - param_shard_offset = self.num_heads_per_rank * self.head_dim - param_shard_size = self.kv_num_heads_per_rank * self.head_dim + param_shard_offset = self.num_heads_per_rank * head_dim + param_shard_size = self.kv_num_heads_per_rank * head_dim else: # loaded_shard_id == "v" - param_shard_offset = (self.num_heads_per_rank + self.kv_num_heads_per_rank) * self.head_dim - param_shard_size = self.kv_num_heads_per_rank * self.head_dim + param_shard_offset = (self.num_heads_per_rank + self.kv_num_heads_per_rank) * head_dim + param_shard_size = self.kv_num_heads_per_rank * head_dim + if hasattr(param, "tensor_track"): + param.tensor_track.mark(start=param_shard_offset, end=param_shard_offset + param_shard_size) param = slice_fn(param, output_dim, start=param_shard_offset, end=param_shard_offset + param_shard_size) assert param.shape == loaded_weight.shape, ( f" Attempted to load weight ({loaded_weight.shape}) " f"into parameter ({param.shape})" @@ -706,7 +714,6 @@ class RowParallelLinear(LinearBase): ), ) if self.nranks > 0: - _set_var_distributed(self.weight, split_axis=0) if self.with_bias: # col parallel _set_var_distributed(self.bias, split_axis=0) @@ -732,7 +739,7 @@ class RowParallelLinear(LinearBase): return out -class KVBatchLinear(LinearBase): +class KVBatchLinear(nn.Layer): """ KVBatchLinear Layer for handling combined KV projections with bmm. """ @@ -740,13 +747,12 @@ class KVBatchLinear(LinearBase): def __init__( self, fd_config: FDConfig, + kv_b_proj: nn.Layer, prefix: str = "", kv_lora_rank: int = None, num_attention_heads: int = None, qk_nope_head_dim: int = None, v_head_dim: int = None, - with_bias: bool = False, - skip_quant: bool = False, ): """ Initializes a KV batch linear layer that internally splits into K and V projections. @@ -761,6 +767,7 @@ class KVBatchLinear(LinearBase): with_bias (bool): Whether to include bias or not. Defaults to False. skip_quant (bool): Whether to skip quantization. Defaults to False. """ + super().__init__() self.nranks = fd_config.parallel_config.tensor_parallel_size self.kv_lora_rank = kv_lora_rank self.num_attention_heads = num_attention_heads @@ -770,69 +777,27 @@ class KVBatchLinear(LinearBase): self.num_heads_per_partition = divide(num_attention_heads, self.nranks) self.local_rank = fd_config.parallel_config.tensor_parallel_rank - # Initialize parent with combined dimensions - super().__init__( - fd_config=fd_config, - prefix=prefix, - input_size=None, # Will be determined from weight shape - output_size=None, # Will be determined from weight shape - with_bias=with_bias, - add_bias=False, - skip_quant=skip_quant, - ) - self.weight_dtype = self._dtype + self.kv_b_proj = kv_b_proj + + self.weight_dtype = self._helper.get_default_dtype() # Override weight keys to use the combined kv_b_proj self.weight_key = f"{prefix}.weight" # e.g., "kv_b_proj.weight" - self.k_weight_key = f"{prefix.replace('kv_b_proj', 'k_b_proj')}.weight" - self.v_weight_key = f"{prefix.replace('kv_b_proj', 'v_b_proj')}.weight" - self.k_b_proj_weight = self.create_parameter( - shape=[self.num_heads_per_partition, self.qk_nope_head_dim, self.kv_lora_rank], - dtype=self.weight_dtype, - is_bias=False, - default_initializer=paddle.nn.initializer.Constant(0), - ) + def process_weights_after_loading(self): - self.v_b_proj_weight = self.create_parameter( - shape=[self.num_heads_per_partition, self.kv_lora_rank, self.v_head_dim], - dtype=self.weight_dtype, - is_bias=False, - default_initializer=paddle.nn.initializer.Constant(0), - ) + w = self.kv_b_proj.weight.reshape( + [ + self.kv_lora_rank, + self.num_heads_per_partition, + -1, + ] + ).transpose(perm=[1, 2, 0]) + self.kv_b_proj = None - set_weight_attrs( - self.k_b_proj_weight, - {"weight_loader": self.weight_loader}, - ) + if w.dtype != self.weight_dtype: + w = w.cast(self.weight_dtype) - if self.nranks > 0: - _set_var_distributed(self.k_b_proj_weight, split_axis=1) - set_weight_attrs(self.k_b_proj_weight, {"output_dim": True}) - - def weight_loader(self, param, loaded_weight, loaded_shard_id: Optional[str] = None): - output_dim = getattr(param, "output_dim", None) - # Tensor parallelism splits the weight along the output_dim - if output_dim is not None: - dim = -1 - size = loaded_weight.get_shape()[dim] - block_size = size // self.nranks - shard_offset = self.local_rank * block_size - shard_size = (self.local_rank + 1) * block_size - loaded_weight = loaded_weight[..., shard_offset:shard_size] - w = ( - get_tensor(loaded_weight) - .reshape( - [ - self.kv_lora_rank, - self.num_heads_per_partition, - -1, - ] - ) - .transpose(perm=[1, 2, 0]) - ) - if param.dtype != w.dtype: - w = w.cast(param.dtype) # Split into K and V weights # wk_b: [num_heads, qk_nope_head_dim, kv_lora_rank] wk_b = w[:, : self.qk_nope_head_dim, :] @@ -840,9 +805,8 @@ class KVBatchLinear(LinearBase): raise ValueError("self.v_head_dim should not be None") # wv_b: [num_heads, kv_lora_rank, v_head_dim] wv_b = w[:, -self.v_head_dim :, :].transpose(perm=[0, 2, 1]) - - self.k_b_proj_weight.set_value(wk_b) - self.v_b_proj_weight.set_value(wv_b) + self.k_b_proj_weight = wk_b + self.v_b_proj_weight = wv_b def load_state_dict(self, state_dict: dict): """ @@ -916,7 +880,7 @@ class KVBatchLinear(LinearBase): out = paddle.bmm(x, self.v_b_proj_weight) return out - def forward_cuda(self, x: paddle.Tensor, proj_type: str = "k") -> paddle.Tensor: + def forward(self, x: paddle.Tensor, proj_type: str = "k") -> paddle.Tensor: """ Forward function that can handle both K and V projections diff --git a/fastdeploy/model_executor/layers/lm_head.py b/fastdeploy/model_executor/layers/lm_head.py index dce0ccbc4..f71f828eb 100644 --- a/fastdeploy/model_executor/layers/lm_head.py +++ b/fastdeploy/model_executor/layers/lm_head.py @@ -22,7 +22,7 @@ from paddle import nn from paddle.distributed import fleet from fastdeploy.config import FDConfig -from fastdeploy.model_executor.models.utils import set_weight_attrs +from fastdeploy.model_executor.utils import set_weight_attrs from .utils import get_tensor diff --git a/fastdeploy/model_executor/layers/moe/fused_moe_backend_base.py b/fastdeploy/model_executor/layers/moe/fused_moe_backend_base.py index d72ce1232..5b3b1c6a4 100644 --- a/fastdeploy/model_executor/layers/moe/fused_moe_backend_base.py +++ b/fastdeploy/model_executor/layers/moe/fused_moe_backend_base.py @@ -19,7 +19,7 @@ from abc import abstractmethod import paddle from paddle import nn -from fastdeploy.model_executor.layers.utils import set_weight_attrs +from fastdeploy.model_executor.utils import set_weight_attrs from fastdeploy.platforms import current_platform from ..quantization.quant_base import QuantMethodBase @@ -185,9 +185,11 @@ class UnquantizedFusedMoEMethod(MoEMethodBase): if current_platform.is_cuda(): self.up_gate_proj_weight_shape = [layer.num_experts, layer.hidden_size, layer.moe_intermediate_size * 2] self.down_proj_weight_shape = [layer.num_experts, layer.moe_intermediate_size, layer.hidden_size] + extra_weight_attrs = {**extra_weight_attrs, "SHARD_ID_TO_SHARDED_DIM": {"gate": 1, "down": 0, "up": 1}} else: self.up_gate_proj_weight_shape = [layer.num_experts, layer.moe_intermediate_size * 2, layer.hidden_size] self.down_proj_weight_shape = [layer.num_experts, layer.hidden_size, layer.moe_intermediate_size] + extra_weight_attrs = {**extra_weight_attrs, "SHARD_ID_TO_SHARDED_DIM": {"gate": 0, "down": 1, "up": 0}} layer.up_gate_proj_weight = layer.create_parameter( shape=self.up_gate_proj_weight_shape, @@ -203,10 +205,3 @@ class UnquantizedFusedMoEMethod(MoEMethodBase): set_weight_attrs(layer.up_gate_proj_weight, extra_weight_attrs) set_weight_attrs(layer.down_proj_weight, extra_weight_attrs) - - if layer.moe_use_gate_correction_bias: - gate_correction_bias_shape = [1, layer.num_experts] - layer.gate_correction_bias = layer.create_parameter( - shape=gate_correction_bias_shape, - dtype="float32", - ) diff --git a/fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py b/fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py index 2be90f8f9..902babcdf 100644 --- a/fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py +++ b/fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py @@ -38,6 +38,8 @@ elif current_platform.is_iluvatar(): moe_expert_reduce, ) +from fastdeploy.model_executor.utils import TensorTracker, free_tensor, set_weight_attrs + # used for deepseek_v3 def get_moe_scores( @@ -93,8 +95,8 @@ class CutlassMoEMethod(UnquantizedFusedMoEMethod): return fastdeploy.model_executor.ops.iluvatar.moe_expert_ffn( permute_input, token_nums_per_expert, - layer.up_gate_proj_weight, - layer.down_proj_weight, + getattr(layer, self.added_weight_attrs[0]), + getattr(layer, self.added_weight_attrs[1]), None, (layer.up_gate_proj_weight_scale if hasattr(layer, "up_gate_proj_weight_scale") else None), (layer.down_proj_weight_scale if hasattr(layer, "down_proj_weight_scale") else None), @@ -106,8 +108,8 @@ class CutlassMoEMethod(UnquantizedFusedMoEMethod): return fastdeploy.model_executor.ops.gpu.moe_expert_ffn( permute_input, token_nums_per_expert, - layer.up_gate_proj_weight, - layer.down_proj_weight, + getattr(layer, self.added_weight_attrs[0]), + getattr(layer, self.added_weight_attrs[1]), None, (layer.up_gate_proj_weight_scale if hasattr(layer, "up_gate_proj_weight_scale") else None), (layer.down_proj_weight_scale if hasattr(layer, "down_proj_weight_scale") else None), @@ -392,12 +394,12 @@ class CutlassW4A8MoEMethod(CutlassMoEMethod): Paddle cutlass create weight process. """ self.weight_dtype = "int8" - self.ffn1_weight_shape = [ + self.up_gate_proj_weight_shape = [ layer.num_local_experts, layer.hidden_size // 2, layer.moe_intermediate_size * 2, ] - self.ffn2_weight_shape = [ + self.down_proj_weight_shape = [ layer.num_local_experts, layer.moe_intermediate_size // 2, layer.hidden_size, @@ -406,7 +408,7 @@ class CutlassW4A8MoEMethod(CutlassMoEMethod): layer, self.added_weight_attrs[0], layer.create_parameter( - shape=self.ffn1_weight_shape, + shape=self.up_gate_proj_weight_shape, dtype=self.weight_dtype, default_initializer=paddle.nn.initializer.Constant(0), ), @@ -415,7 +417,7 @@ class CutlassW4A8MoEMethod(CutlassMoEMethod): layer, self.added_weight_attrs[1], layer.create_parameter( - shape=self.ffn2_weight_shape, + shape=self.down_proj_weight_shape, dtype=self.weight_dtype, default_initializer=paddle.nn.initializer.Constant(0), ), @@ -625,71 +627,177 @@ class CutlassWeightOnlyMoEMethod(CutlassMoEMethod): Paddle cutlass create weight process. """ self.default_dtype = layer._helper.get_default_dtype() - self.weight_dtype = "int8" - - up_gate_proj_weight_name = self.added_weight_attrs[0] - down_proj_weight_name = self.added_weight_attrs[1] if self.moe_quant_type == "weight_only_int4": - self.ffn1_weight_shape = [ + self.up_gate_proj_weight_shape = [ layer.num_local_experts, layer.moe_intermediate_size, layer.hidden_size, ] else: - self.ffn1_weight_shape = [ + self.up_gate_proj_weight_shape = [ layer.num_local_experts, layer.moe_intermediate_size * 2, layer.hidden_size, ] if self.moe_quant_type == "weight_only_int4": - self.ffn2_weight_shape = [ + self.down_proj_weight_shape = [ layer.num_local_experts, layer.hidden_size // 2, layer.moe_intermediate_size, ] else: - self.ffn2_weight_shape = [ + self.down_proj_weight_shape = [ layer.num_local_experts, layer.hidden_size, layer.moe_intermediate_size, ] + self.up_gate_proj_scale_shape = [layer.num_local_experts, layer.moe_intermediate_size * 2] + self.down_proj_scale_shape = [layer.num_local_experts, layer.hidden_size] + + if layer.fd_config.load_config.load_choices == "default_v1": + layer.up_gate_proj_weight = layer.create_parameter( + shape=[layer.num_experts, layer.hidden_size, layer.moe_intermediate_size * 2], + dtype=layer.weight_dtype, + default_initializer=paddle.nn.initializer.Constant(0), + ) + + layer.down_proj_weight = layer.create_parameter( + shape=[layer.num_experts, layer.moe_intermediate_size, layer.hidden_size], + dtype=layer.weight_dtype, + default_initializer=paddle.nn.initializer.Constant(0), + ) + + set_weight_attrs( + layer.up_gate_proj_weight, + { + **extra_weight_attrs, + "tensor_track": TensorTracker(shape=layer.up_gate_proj_weight.shape, output_dim=True), + }, + ) + set_weight_attrs( + layer.down_proj_weight, + { + **extra_weight_attrs, + "tensor_track": TensorTracker(shape=layer.down_proj_weight.shape, output_dim=False), + }, + ) + else: + self.weight_dtype = "int8" + + up_gate_proj_weight_name = self.added_weight_attrs[0] + down_proj_weight_name = self.added_weight_attrs[1] + up_gate_proj_scale_name = self.added_scale_attrs[0] + down_proj_scale_name = self.added_scale_attrs[1] + + setattr( + layer, + up_gate_proj_weight_name, + layer.create_parameter( + shape=self.up_gate_proj_weight_shape, + dtype=self.weight_dtype, + default_initializer=paddle.nn.initializer.Constant(0), + ), + ) + setattr( + layer, + down_proj_weight_name, + layer.create_parameter( + shape=self.down_proj_weight_shape, + dtype=self.weight_dtype, + default_initializer=paddle.nn.initializer.Constant(0), + ), + ) + # weight_scale + setattr( + layer, + up_gate_proj_scale_name, + layer.create_parameter( + shape=self.up_gate_proj_scale_shape, + dtype=self.default_dtype, + default_initializer=paddle.nn.initializer.Constant(0), + ), + ) + setattr( + layer, + down_proj_scale_name, + layer.create_parameter( + shape=self.down_proj_scale_shape, + dtype=self.default_dtype, + default_initializer=paddle.nn.initializer.Constant(0), + ), + ) + + moe_extra_weight_attrs = {**extra_weight_attrs, "SHARD_ID_TO_SHARDED_DIM": {"gate": 0, "down": 1, "up": 0}} + set_weight_attrs(layer.up_gate_proj_weight, moe_extra_weight_attrs) + set_weight_attrs(layer.down_proj_weight, moe_extra_weight_attrs) + scale_extra_weight_attrs = { + **extra_weight_attrs, + "SHARD_ID_TO_SHARDED_DIM": {"gate": 0, "up": 0, "down": None}, + } + set_weight_attrs(layer.up_gate_proj_weight_scale, scale_extra_weight_attrs) + set_weight_attrs(layer.down_proj_weight_scale, scale_extra_weight_attrs) + + def process_weights_after_loading(self, layer): + """ """ + if not layer.fd_config.load_config.load_choices == "default_v1": + return + weight_id_map = {"gate_up": 0, "down": 1} + if ( + hasattr(layer.up_gate_proj_weight, "tensor_track") + and layer.up_gate_proj_weight.tensor_track is not None + and layer.up_gate_proj_weight.tensor_track.is_fully_copied() + ): + weight_type = "gate_up" + else: + weight_type = "down" + + # 1.init shape and type + # weight + weight_name = self.added_weight_attrs[weight_id_map[weight_type]] + unquantized_weight_name = weight_name.replace("quant_weight", "weight") + weight_shape = self.up_gate_proj_weight_shape if weight_type == "gate_up" else self.down_proj_weight_shape + weight_dtype = "int8" + # scale + scale_name = self.added_scale_attrs[weight_id_map[weight_type]] + scale_shape = self.up_gate_proj_scale_shape if weight_type == "gate_up" else self.down_proj_scale_shape + scale_dtype = self.default_dtype + + # 2.crate tmp tensor + + weight = paddle.empty(weight_shape, dtype=weight_dtype) + scale = paddle.empty(scale_shape, dtype=scale_dtype) + + # 3.quantize weight + + for expert_id in range(layer.num_experts): + weight[expert_id], scale[expert_id] = weight_quantize( + getattr(layer, unquantized_weight_name)[expert_id], algo=self.moe_quant_type + ) + + free_tensor(getattr(layer, unquantized_weight_name)) + + # create weight setattr( layer, - up_gate_proj_weight_name, + weight_name, layer.create_parameter( - shape=self.ffn1_weight_shape, - dtype=self.weight_dtype, + shape=weight_shape, + dtype=weight_dtype, default_initializer=paddle.nn.initializer.Constant(0), ), ) + # create scale setattr( layer, - down_proj_weight_name, + scale_name, layer.create_parameter( - shape=self.ffn2_weight_shape, - dtype=self.weight_dtype, - default_initializer=paddle.nn.initializer.Constant(0), - ), - ) - # weight_scale - setattr( - layer, - self.added_scale_attrs[0], - layer.create_parameter( - shape=[layer.num_local_experts, layer.moe_intermediate_size * 2], - dtype=self.default_dtype, - default_initializer=paddle.nn.initializer.Constant(0), - ), - ) - setattr( - layer, - self.added_scale_attrs[1], - layer.create_parameter( - shape=[layer.num_local_experts, layer.hidden_size], - dtype=self.default_dtype, + shape=scale_shape, + dtype=scale_dtype, default_initializer=paddle.nn.initializer.Constant(0), ), ) + getattr(layer, weight_name).copy_(weight, False) + getattr(layer, scale_name).copy_(scale, False) def process_loaded_weights(self, layer: nn.Layer, state_dict): """ diff --git a/fastdeploy/model_executor/layers/moe/moe.py b/fastdeploy/model_executor/layers/moe/moe.py index 28b9afdbe..475b3015c 100644 --- a/fastdeploy/model_executor/layers/moe/moe.py +++ b/fastdeploy/model_executor/layers/moe/moe.py @@ -23,6 +23,7 @@ from paddleformers.utils.log import logger from fastdeploy import envs from fastdeploy.model_executor.layers.utils import get_tensor +from fastdeploy.model_executor.utils import slice_fn from fastdeploy.platforms import current_platform from fastdeploy.worker.experts_manager import RedundantExpertManger @@ -78,6 +79,7 @@ class FusedMoE(nn.Layer): routed_scaling_factor: float = 1.0, layer_idx: int = -1, moe_tag: str = "", + gate_correction_bias=None, weight_key_map: dict = {}, ): """ @@ -155,9 +157,10 @@ class FusedMoE(nn.Layer): # It's for RL to build model self.init_moe_weights() else: - self.gate_correction_bias_key = self.weight_key_map.get("gate_correction_bias_key", None) - if self.gate_correction_bias_key is not None: - self.gate_correction_bias = self.create_parameter(shape=[1, self.num_experts], dtype="float32") + if gate_correction_bias is not None: + self.gate_correction_bias = gate_correction_bias + else: + self.gate_correction_bias = None if moe_quant_config: if ( moe_quant_config @@ -179,54 +182,72 @@ class FusedMoE(nn.Layer): def weight_loader(self, param, loaded_weight, expert_id, shard_id: Optional[str] = None): from fastdeploy.platforms import current_platform + if hasattr(param, "SHARD_ID_TO_SHARDED_DIM"): + SHARD_ID_TO_SHARDED_DIM = param.SHARD_ID_TO_SHARDED_DIM + elif current_platform.is_cuda(): + SHARD_ID_TO_SHARDED_DIM = {"gate": 1, "down": 0, "up": 1} + else: + SHARD_ID_TO_SHARDED_DIM = {"gate": 0, "down": 1, "up": 0} + + if not param._is_initialized(): + param.initialize() + if shard_id is None: # 1.gate up fused in disk - if self.tp_size > 1: - shard_offsets = [ - # (shard_id, shard_offset, shard_size) - ("gate", 0, self.moe_intermediate_size * self.tp_size), - ("up", self.moe_intermediate_size * self.tp_size, self.moe_intermediate_size * self.tp_size), - ] - for shard_id, shard_offset, shard_size in shard_offsets: - loaded_weight_shard = loaded_weight[..., shard_offset : shard_offset + shard_size] - self.weight_loader(param, loaded_weight_shard, expert_id, shard_id) - else: - expert_param = param[expert_id - self.expert_id_offset] - loaded_weight = get_tensor(loaded_weight) - expert_param.copy_(loaded_weight, False) + output_size = param[expert_id - self.expert_id_offset].shape[SHARD_ID_TO_SHARDED_DIM["gate"]] + shard_offsets = [ + # (shard_id, shard_offset, shard_size) + ("gate", 0, output_size // 2 * self.tp_size), + ("up", output_size // 2 * self.tp_size, output_size // 2 * self.tp_size), + ] + for shard_id, shard_offset, shard_size in shard_offsets: + loaded_weight_shard = slice_fn( + loaded_weight, SHARD_ID_TO_SHARDED_DIM[shard_id], shard_offset, shard_offset + shard_size + ) + self.weight_loader(param, loaded_weight_shard, expert_id, shard_id) else: # 2.gate up splited in disk assert shard_id in ["gate", "down", "up"] - if current_platform.is_cuda(): - SHARD_ID_TO_SHARDED_DIM = {"gate": 1, "down": 0, "up": 1} - else: - SHARD_ID_TO_SHARDED_DIM = {"gate": 0, "down": 1, "up": 0} self._load_expert_weight( param=param, expert_id=expert_id, - shard_dim=SHARD_ID_TO_SHARDED_DIM[shard_id], loaded_weight=loaded_weight, shard_id=shard_id, + shard_dim=SHARD_ID_TO_SHARDED_DIM[shard_id], ) - def _load_gate_up_weight(self, expert_param, shard_dim, loaded_weight, shard_id): - tensor_size = expert_param.shape[shard_dim] // 2 - if shard_id == "gate": - expert_param = expert_param[..., :tensor_size] if shard_dim else expert_param[:tensor_size, ...] - elif shard_id == "up": - expert_param = expert_param[..., tensor_size:] if shard_dim else expert_param[tensor_size:, ...] - + def _load_gate_up_weight(self, param, expert_id, loaded_weight, shard_id, shard_dim=None): + dim = -1 if shard_dim else 0 if self.tp_size > 1: if isinstance(loaded_weight, np.ndarray): - size = loaded_weight.shape[-1] + size = loaded_weight.shape[dim] else: - size = loaded_weight.get_shape()[-1] + size = loaded_weight.get_shape()[dim] block_size = size // self.tp_size shard_offset = self.tp_rank * block_size shard_size = (self.tp_rank + 1) * block_size - loaded_weight = loaded_weight[..., shard_offset:shard_size] + loaded_weight = slice_fn(loaded_weight, shard_dim, shard_offset, shard_size) loaded_weight = get_tensor(loaded_weight) + + expert_param = param[expert_id - self.expert_id_offset] + param_shard_size = expert_param.shape[dim] // 2 + if shard_id == "gate": + param_shard_offset = 0 + else: + # shard_id == "up": + param_shard_offset = param_shard_size + expert_param = slice_fn( + expert_param, shard_dim, start=param_shard_offset, end=param_shard_offset + param_shard_size + ) + if hasattr(param, "tensor_track"): + # for dyn quant + param.tensor_track.mark( + start=param_shard_offset, + end=param_shard_offset + param_shard_size, + batch_id=expert_id - self.expert_id_offset, + ) + # To ensure compatibility across backends, apply an extra transpose for GCU and XPU if expert_param.shape != loaded_weight.shape: loaded_weight = loaded_weight.transpose([1, 0]) @@ -235,17 +256,22 @@ class FusedMoE(nn.Layer): ) expert_param.copy_(loaded_weight, False) - def _load_down_weight(self, expert_param, shard_dim, loaded_weight, shard_id): - if self.tp_size > 1: + def _load_down_weight(self, param, expert_id, loaded_weight, shard_id, shard_dim=None): + if self.tp_size > 1 and shard_dim is not None: + dim = -1 if shard_dim else 0 if isinstance(loaded_weight, np.ndarray): - size = loaded_weight.shape[shard_dim] + size = loaded_weight.shape[dim] else: - size = loaded_weight.get_shape()[shard_dim] + size = loaded_weight.get_shape()[dim] block_size = size // self.tp_size shard_offset = self.tp_rank * block_size shard_size = (self.tp_rank + 1) * block_size - loaded_weight = loaded_weight[shard_offset:shard_size, ...] + loaded_weight = slice_fn(loaded_weight, shard_dim, shard_offset, shard_size) loaded_weight = get_tensor(loaded_weight) + expert_param = param[expert_id - self.expert_id_offset] + if hasattr(param, "tensor_track"): + # for dyn quant + param.tensor_track.mark(start=0, batch_id=expert_id - self.expert_id_offset) # To ensure compatibility across backends, apply an extra transpose for GCU and XPU if expert_param.shape != loaded_weight.shape: loaded_weight = loaded_weight.transpose([1, 0]) @@ -258,15 +284,14 @@ class FusedMoE(nn.Layer): self, param, expert_id, - shard_dim, loaded_weight, shard_id, + shard_dim=None, ): - expert_param = param[expert_id - self.expert_id_offset] if shard_id == "down": - self._load_down_weight(expert_param, shard_dim, loaded_weight, shard_id) + self._load_down_weight(param, expert_id, loaded_weight, shard_id, shard_dim) elif shard_id in ["gate", "up"]: - self._load_gate_up_weight(expert_param, shard_dim, loaded_weight, shard_id) + self._load_gate_up_weight(param, expert_id, loaded_weight, shard_id, shard_dim) @classmethod def make_expert_params_mapping( @@ -314,13 +339,6 @@ class FusedMoE(nn.Layer): Combines weight shape initialization and parameter creation into a single function. """ # Initialize weight shapes - gate_correction_bias_shape = [1, self.num_experts] - - if self.fd_config.model_config.moe_use_aux_free: - self.gate_correction_bias = self.create_parameter( - shape=gate_correction_bias_shape, - dtype="float32", - ) up_gate_proj_output_dim = self.moe_intermediate_size * 2 if self.moe_quant_type in ["block_wise_fp8", "wint8"]: up_gate_proj_weight_shape = [ @@ -535,19 +553,6 @@ class FusedMoE(nn.Layer): """ load_state_dict function. """ - if not is_rearrange: - if self.moe_use_gate_correction_bias: - gate_correction_bias_tensor = self.extract_gate_correction_bias( - self.gate_correction_bias_key, state_dict - ) - if self.gate_correction_bias.shape != gate_correction_bias_tensor.shape: - gate_correction_bias_tensor = gate_correction_bias_tensor.reshape(self.gate_correction_bias.shape) - self.gate_correction_bias.set_value(gate_correction_bias_tensor) - else: - self.gate_correction_bias = None - else: - self.gate_correction_bias = None - if is_supported_moe_backend is not None and is_supported_moe_backend(self.quant_method): if self.fd_config.model_config.is_quantized: if getattr(self.fd_config.quant_config, "is_permuted", True): diff --git a/fastdeploy/model_executor/layers/quantization/weight_only.py b/fastdeploy/model_executor/layers/quantization/weight_only.py index 4825faaf7..6e4c6f34b 100644 --- a/fastdeploy/model_executor/layers/quantization/weight_only.py +++ b/fastdeploy/model_executor/layers/quantization/weight_only.py @@ -21,6 +21,11 @@ from typing import Optional import paddle from paddle.nn.quant import weight_only_linear, weight_quantize +from fastdeploy.model_executor.layers.linear import ( + MergedColumnParallelLinear, + QKVParallelLinear, +) +from fastdeploy.model_executor.utils import TensorTracker, free_tensor, set_weight_attrs from fastdeploy.platforms import current_platform from ..moe import FusedMoE @@ -135,9 +140,7 @@ class WINT8Config(WeightOnlyConfig): weight only int8 config """ - def __init__( - self, - ) -> None: + def __init__(self) -> None: super().__init__("weight_only_int8") @classmethod @@ -179,27 +182,89 @@ class WeightOnlyLinearMethod(QuantMethodBase): self.quant_config = quant_config def create_weights(self, layer, **extra_weight_attrs): + if layer.fd_config.load_config.load_choices == "default_v1": + layer.weight = layer.create_parameter( + shape=layer.weight_shape, + dtype=layer.weight_dtype, + is_bias=False, + default_initializer=paddle.nn.initializer.Constant(0), + ) + quant_attrs = extra_weight_attrs + if isinstance(layer, MergedColumnParallelLinear) or isinstance(layer, QKVParallelLinear): + quant_attrs = { + **extra_weight_attrs, + "tensor_track": TensorTracker( + shape=layer.weight_shape, output_dim=extra_weight_attrs.get("output_dim") + ), + } + set_weight_attrs( + layer.weight, + quant_attrs, + ) + else: + # The scale shape should be equal to the output dim of weight using Per-Channel Quantization. + weight_scale_shape = [layer.weight_shape[1]] + layer.weight_shape.reverse() + if self.quant_config.name() == "wint4": + layer.weight_shape[0] //= 2 + layer.weight_dtype = "int8" + layer.weight = layer.create_parameter( + shape=layer.weight_shape, + dtype=layer.weight_dtype, + is_bias=False, + default_initializer=paddle.nn.initializer.Constant(0), + ) - # The scale shape should be equal to the output dim of weight using Per-Channel Quantization. - weight_scale_shape = [layer.weight_shape[1]] + output_dim = extra_weight_attrs.get("output_dim") + output_dim = not output_dim + weight_loader = extra_weight_attrs.get("weight_loader") + set_weight_attrs( + layer.weight, + { + "weight_loader": weight_loader, + "output_dim": output_dim, + }, + ) - layer.weight_shape.reverse() - if self.quant_config.name() == "wint4": - layer.weight_shape[0] //= 2 - layer.weight_dtype = "int8" + layer.weight_scale = layer.create_parameter( + shape=weight_scale_shape, + dtype=layer._dtype, + is_bias=False, + ) + + set_weight_attrs( + layer.weight_scale, + { + "weight_loader": weight_loader, + "output_dim": output_dim, + }, + ) + + def process_weights_after_loading(self, layer) -> None: + if not layer.fd_config.load_config.load_choices == "default_v1": + return + quanted_weight_tensor, weight_scale_tensor = weight_quantize( + layer.weight, + algo=self.quant_config.algo, + arch=self.quant_config.weight_only_linear_arch, + ) + + free_tensor(layer.weight) layer.weight = layer.create_parameter( - shape=layer.weight_shape, - dtype=layer.weight_dtype, + shape=quanted_weight_tensor.shape, + dtype="int8", is_bias=False, default_initializer=paddle.nn.initializer.Constant(0), ) - layer.weight_scale = layer.create_parameter( - shape=weight_scale_shape, + shape=weight_scale_tensor.shape, dtype=layer._dtype, is_bias=False, + default_initializer=paddle.nn.initializer.Constant(0), ) + layer.weight.copy_(quanted_weight_tensor, False) + layer.weight_scale.copy_(weight_scale_tensor, False) @abstractmethod def process_loaded_weights(self, layer, weights) -> None: diff --git a/fastdeploy/model_executor/layers/utils.py b/fastdeploy/model_executor/layers/utils.py index b5e1c2ad0..e7a6c0137 100644 --- a/fastdeploy/model_executor/layers/utils.py +++ b/fastdeploy/model_executor/layers/utils.py @@ -15,7 +15,7 @@ """ import functools -from typing import Any, Optional, Tuple, Union +from typing import Tuple, Union import numpy as np import paddle @@ -45,14 +45,6 @@ if cache_params != "none": c8_state_dict = paddle.load(cache_params, return_numpy=True) -# TODO(lulinjun): delete it, import from fastdeploy.model_executor.models.utils after supporting all backends -def set_weight_attrs(param, param_attr_map: Optional[dict[str, Any]]): - if param_attr_map is None: - return - for key, value in param_attr_map.items(): - setattr(param, key, value) - - def per_block_cast_to_fp8(x: Tensor, block_size: list = [128, 128]) -> Tuple[Tensor, Tensor]: """ Only used in deep_gemm block wise quant weight. diff --git a/fastdeploy/model_executor/model_loader/default_loader_v1.py b/fastdeploy/model_executor/model_loader/default_loader_v1.py index 4d79772e5..51e80e7b0 100644 --- a/fastdeploy/model_executor/model_loader/default_loader_v1.py +++ b/fastdeploy/model_executor/model_loader/default_loader_v1.py @@ -14,8 +14,6 @@ # limitations under the License. """ -import contextlib - import paddle from paddle import nn from paddleformers.utils.log import logger @@ -56,15 +54,12 @@ class DefaultModelLoaderV1(BaseModelLoader): def load_model(self, fd_config: FDConfig) -> nn.Layer: architectures = fd_config.model_config.architectures[0] logger.info(f"Starting to load model {architectures}") + context = paddle.LazyGuard() if fd_config.load_config.dynamic_load_weight: # register rl model import fastdeploy.rl # noqa architectures = architectures + "RL" - context = paddle.LazyGuard() - - else: - context = contextlib.nullcontext() with context: model_cls = ModelRegistry.get_class(architectures) @@ -75,6 +70,5 @@ class DefaultModelLoaderV1(BaseModelLoader): # RL model not need set_state_dict if fd_config.load_config.dynamic_load_weight: return model - self.load_weights(model, fd_config) return model diff --git a/fastdeploy/model_executor/models/deepseek_v3.py b/fastdeploy/model_executor/models/deepseek_v3.py index b65925be2..f240e760f 100644 --- a/fastdeploy/model_executor/models/deepseek_v3.py +++ b/fastdeploy/model_executor/models/deepseek_v3.py @@ -17,6 +17,7 @@ from __future__ import annotations import math +import re from functools import partial import paddle @@ -122,6 +123,25 @@ class DeepSeekV3MoE(nn.Layer): "down_proj_expert_weight_key": f"{prefix}.experts.{{}}.down_proj.weight", } + self.gate = ReplicatedLinear( + fd_config=fd_config, + prefix=f"{prefix}.gate", + input_size=fd_config.model_config.hidden_size, + output_size=fd_config.model_config.n_routed_experts, + with_bias=False, + skip_quant=True, + weight_dtype="float32", + ) + + if fd_config.model_config.topk_method == "noaux_tc": + self.gate.e_score_correction_bias = self.create_parameter( + shape=[1, fd_config.model_config.n_routed_experts], + dtype="float32", + default_initializer=paddle.nn.initializer.Constant(0), + ) + else: + self.gate.e_score_correction_bias = None + self.experts = FusedMoE( fd_config=fd_config, reduce_results=False, @@ -133,19 +153,10 @@ class DeepSeekV3MoE(nn.Layer): n_group=fd_config.model_config.n_group, routed_scaling_factor=fd_config.model_config.routed_scaling_factor, layer_idx=layer_id, + gate_correction_bias=self.gate.e_score_correction_bias, weight_key_map=weight_key_map, ) - self.gate = ReplicatedLinear( - fd_config=fd_config, - prefix=f"{prefix}.gate", - input_size=fd_config.model_config.hidden_size, - output_size=fd_config.model_config.n_routed_experts, - with_bias=False, - skip_quant=True, - weight_dtype="float32", - ) - self.num_shared_experts = fd_config.model_config.n_shared_experts shared_experts_intermediate_size = self.num_shared_experts * fd_config.model_config.moe_intermediate_size @@ -258,6 +269,7 @@ class DeepseekV3MLAAttention(nn.Layer): self.kv_b_proj_bmm = KVBatchLinear( fd_config=fd_config, + kv_b_proj=self.kv_b_proj, prefix=f"{prefix}.kv_b_proj", kv_lora_rank=self.kv_lora_rank, num_attention_heads=self.num_attention_heads, @@ -617,7 +629,10 @@ class DeepseekV3ForCausalLM(ModelForCasualLM): Args: weights_iterator (Iterator): An iterator yielding (name, weight) pairs. """ - from fastdeploy.model_executor.models.utils import default_weight_loader + from fastdeploy.model_executor.utils import ( + default_weight_loader, + process_weights_after_loading, + ) stacked_params_mapping = [ # (param_name, shard_name, shard_id) @@ -637,7 +652,7 @@ class DeepseekV3ForCausalLM(ModelForCasualLM): param_down_proj_name="experts.down_proj_", ) params_dict = dict(self.named_parameters()) - + process_weights_after_loading_fn = process_weights_after_loading(dict(self.named_sublayers())) for loaded_weight_name, loaded_weight in weights_iterator: loaded_weight_name = loaded_weight_name.replace("deepseek_v3", "model") @@ -668,19 +683,18 @@ class DeepseekV3ForCausalLM(ModelForCasualLM): weight_loader(param, loaded_weight, shard_id=shard_id, expert_id=expert_id) break else: - if loaded_weight_name not in params_dict: + model_param_name = loaded_weight_name + if model_param_name not in params_dict: continue - param = params_dict[loaded_weight_name] + param = params_dict[model_param_name] weight_loader = getattr(param, "weight_loader", default_weight_loader(self.fd_config)) weight_loader(param, loaded_weight) - if "kv_b_proj.weight" in loaded_weight_name: - # handle kv_b_proj_bmm - model_param_name = loaded_weight_name.replace( - "kv_b_proj.weight", "kv_b_proj_bmm.k_b_proj_weight" - ) - param = params_dict[model_param_name] - weight_loader = getattr(param, "weight_loader", None) - weight_loader(param, loaded_weight, shard_id) + + model_sublayer_name = re.sub(r"\.(up_gate_proj_weight|down_proj_weight|weight)$", "", model_param_name) + if "kv_b_proj" in model_sublayer_name: + kv_model_sublayer_name = model_sublayer_name.replace("kv_b_proj", "kv_b_proj_bmm") + process_weights_after_loading_fn(kv_model_sublayer_name) + process_weights_after_loading_fn(model_sublayer_name, param) def compute_logits(self, hidden_states: paddle.Tensor): """ """ diff --git a/fastdeploy/model_executor/models/ernie4_5_moe.py b/fastdeploy/model_executor/models/ernie4_5_moe.py index c4f8b0872..ed5226dce 100644 --- a/fastdeploy/model_executor/models/ernie4_5_moe.py +++ b/fastdeploy/model_executor/models/ernie4_5_moe.py @@ -17,6 +17,7 @@ from __future__ import annotations import inspect +import re from functools import partial from typing import Dict, Union @@ -149,15 +150,6 @@ class Ernie4_5_MoE(nn.Layer): "down_proj_expert_weight_key": f"{prefix}.experts.{{}}.down_proj.weight", } - self.experts = FusedMoE( - fd_config=fd_config, - moe_intermediate_size=fd_config.model_config.moe_intermediate_size, - num_experts=fd_config.model_config.moe_num_experts, - top_k=fd_config.model_config.moe_k, - layer_idx=layer_id, - weight_key_map=weight_key_map, - ) - self.gate = ReplicatedLinear( fd_config=fd_config, prefix=f"{prefix}.gate", @@ -168,6 +160,25 @@ class Ernie4_5_MoE(nn.Layer): weight_dtype="float32", ) + self.experts = FusedMoE( + fd_config=fd_config, + moe_intermediate_size=fd_config.model_config.moe_intermediate_size, + num_experts=fd_config.model_config.moe_num_experts, + top_k=fd_config.model_config.moe_k, + layer_idx=layer_id, + gate_correction_bias=None, + weight_key_map=weight_key_map, + ) + + if fd_config.model_config.moe_use_aux_free: + self.experts.gate_correction_bias = self.create_parameter( + shape=[1, fd_config.model_config.moe_num_experts], + dtype="float32", + default_initializer=paddle.nn.initializer.Constant(0), + ) + else: + self.experts.gate_correction_bias = None + self.num_shared_experts = fd_config.model_config.moe_num_shared_experts if self.num_shared_experts > 0: shared_experts_hidden_dim = self.num_shared_experts * fd_config.model_config.moe_intermediate_size @@ -180,6 +191,13 @@ class Ernie4_5_MoE(nn.Layer): def load_state_dict(self, state_dict): self.gate.load_state_dict(state_dict) self.experts.load_state_dict(state_dict) + if self.experts.gate_correction_bias is not None: + gate_correction_bias_tensor = state_dict.pop(self.experts.gate_correction_bias_key) + if self.experts.gate_correction_bias.shape != gate_correction_bias_tensor.shape: + gate_correction_bias_tensor = gate_correction_bias_tensor.reshape( + self.experts.gate_correction_bias.shape + ) + self.experts.gate_correction_bias.set_value(gate_correction_bias_tensor) if self.num_shared_experts > 0: self.shared_experts.load_state_dict(state_dict) @@ -441,12 +459,16 @@ class Ernie4_5_MoeForCausalLM(ModelForCasualLM): weights_iterator (Iterator): An iterator yielding (name, weight) pairs. """ - from fastdeploy.model_executor.models.utils import default_weight_loader + from fastdeploy.model_executor.utils import ( + default_weight_loader, + process_weights_after_loading, + ) general_params_mapping = [ # (param_name, weight_name, expert_id, shard_id) ("embed_tokens.embeddings", "embed_tokens", None, None), ("lm_head.linear", "lm_head", None, None), + ("experts.gate_correction_bias", "moe_statics.e_score_correction_bias", None, None), ] expert_params_mapping = [] @@ -458,13 +480,10 @@ class Ernie4_5_MoeForCausalLM(ModelForCasualLM): param_gate_up_proj_name="experts.up_gate_proj_", param_down_proj_name="experts.down_proj_", ) - expert_params_mapping.append( - ("experts.gate_correction_bias", "moe_statics.e_score_correction_bias", None, "gate_bias") - ) - logger.info(f"expert params mapping:{expert_params_mapping}") all_param_mapping = general_params_mapping + expert_params_mapping params_dict = dict(self.named_parameters()) + process_weights_after_loading_fn = process_weights_after_loading(dict(self.named_sublayers())) expert_id = None shard_id = None @@ -478,9 +497,10 @@ class Ernie4_5_MoeForCausalLM(ModelForCasualLM): shard_id = shard_id break else: - if loaded_weight_name not in params_dict.keys(): + model_param_name = loaded_weight_name + if model_param_name not in params_dict.keys(): continue - param = params_dict[loaded_weight_name] + param = params_dict[model_param_name] # Get weight loader from parameter and set weight weight_loader = getattr(param, "weight_loader", default_weight_loader(self.fd_config)) @@ -490,6 +510,8 @@ class Ernie4_5_MoeForCausalLM(ModelForCasualLM): else: weight_loader(param, loaded_weight) + model_sublayer_name = re.sub(r"\.(up_gate_proj_weight|down_proj_weight|weight)$", "", model_param_name) + process_weights_after_loading_fn(model_sublayer_name, param) if self.tie_word_embeddings: self.lm_head.linear.weight.set_value(self.ernie.embed_tokens.embeddings.weight.transpose([1, 0])) diff --git a/fastdeploy/model_executor/models/ernie4_5_vl/dfnrope/modeling.py b/fastdeploy/model_executor/models/ernie4_5_vl/dfnrope/modeling.py index fcfd80ec3..e0628e59d 100644 --- a/fastdeploy/model_executor/models/ernie4_5_vl/dfnrope/modeling.py +++ b/fastdeploy/model_executor/models/ernie4_5_vl/dfnrope/modeling.py @@ -34,7 +34,7 @@ from paddle.nn.functional.flash_attention import ( from paddleformers.transformers.model_utils import PretrainedModel from fastdeploy.model_executor.layers.utils import divide, get_tensor -from fastdeploy.model_executor.models.utils import set_weight_attrs +from fastdeploy.model_executor.utils import set_weight_attrs from .activation import ACT2FN from .configuration import DFNRopeVisionTransformerConfig diff --git a/fastdeploy/model_executor/models/ernie4_5_vl/ernie4_5_vl_moe.py b/fastdeploy/model_executor/models/ernie4_5_vl/ernie4_5_vl_moe.py index 92146b19a..600811ff3 100644 --- a/fastdeploy/model_executor/models/ernie4_5_vl/ernie4_5_vl_moe.py +++ b/fastdeploy/model_executor/models/ernie4_5_vl/ernie4_5_vl_moe.py @@ -17,6 +17,7 @@ from __future__ import annotations import inspect +import re from dataclasses import dataclass from functools import partial from typing import Dict, Optional, Union @@ -38,7 +39,6 @@ from fastdeploy.model_executor.layers.linear import ReplicatedLinear from fastdeploy.model_executor.layers.lm_head import ParallelLMHead from fastdeploy.model_executor.layers.moe.moe import FusedMoE from fastdeploy.model_executor.layers.normalization import RMSNorm -from fastdeploy.model_executor.layers.utils import get_tensor from fastdeploy.model_executor.models.ernie4_5_moe import ( Ernie4_5_Attention, Ernie4_5_MLP, @@ -75,7 +75,15 @@ class VLMoEMeta: class Ernie4_5_VLMoeBlock(nn.Layer): - def __init__(self, fd_config: FDConfig, layer_id: int, prefix: str, moe_tag: str, expert_id_offset: int) -> None: + def __init__( + self, + fd_config: FDConfig, + layer_id: int, + prefix: str, + moe_tag: str, + expert_id_offset: int, + gate_correction_bias=None, + ) -> None: super().__init__() moe_quant_type = "" if hasattr(fd_config, "quant_config") and fd_config.quant_config is not None: @@ -120,6 +128,7 @@ class Ernie4_5_VLMoeBlock(nn.Layer): layer_idx=layer_id, moe_tag=moe_tag, weight_key_map=weight_key_map, + gate_correction_bias=gate_correction_bias, ) self.gate = ReplicatedLinear( @@ -133,29 +142,10 @@ class Ernie4_5_VLMoeBlock(nn.Layer): weight_key="weight" if moe_tag == "Text" else "weight_1", ) - if moe_tag == "Text": - self.experts.extract_gate_correction_bias = self.extract_gate_correction_bias_text - elif moe_tag == "Image": - self.experts.extract_gate_correction_bias = self.extract_gate_correction_bias_image - def forward(self, hidden_states: paddle.Tensor): out = self.experts(hidden_states, self.gate) return out - def extract_gate_correction_bias_text(self, gate_correction_bias_key, state_dict): - """ - extract_gate_correction_bias function. - """ - gate_correction_bias_tensor = get_tensor(state_dict[gate_correction_bias_key]).astype("float32") - return gate_correction_bias_tensor[0].unsqueeze(0) - - def extract_gate_correction_bias_image(self, gate_correction_bias_key, state_dict): - """ - extract_gate_correction_bias function. - """ - gate_correction_bias_tensor = get_tensor(state_dict[gate_correction_bias_key]).astype("float32") - return gate_correction_bias_tensor[1].unsqueeze(0) - def load_state_dict(self, state_dict): self.experts.load_state_dict(state_dict) self.gate.load_state_dict(state_dict) @@ -186,10 +176,25 @@ class Ernie4_5_VLMoE(nn.Layer): image_moe_layer_end_index = moe_layer_end_index[1] assert text_moe_layer_start_index <= text_moe_layer_end_index + if fd_config.model_config.moe_use_aux_free: + self.gate_correction_bias = self.create_parameter( + shape=[2, fd_config.model_config.moe_num_experts[0]], + dtype="float32", + default_initializer=paddle.nn.initializer.Constant(0), + ) + if not self.gate_correction_bias._is_initialized(): + self.gate_correction_bias.initialize() + else: + self.gate_correction_bias = None if layer_id >= text_moe_layer_start_index and layer_id <= text_moe_layer_end_index: self.text_fused_moe = Ernie4_5_VLMoeBlock( - fd_config=fd_config, layer_id=layer_id, prefix=f"{prefix}", moe_tag="Text", expert_id_offset=0 + fd_config=fd_config, + layer_id=layer_id, + prefix=f"{prefix}", + moe_tag="Text", + expert_id_offset=0, + gate_correction_bias=self.gate_correction_bias[0] if fd_config.model_config.moe_use_aux_free else None, ) else: self.text_fused_moe = Ernie4_5_VLMLP( @@ -207,6 +212,7 @@ class Ernie4_5_VLMoE(nn.Layer): prefix=f"{prefix}", moe_tag="Image", expert_id_offset=fd_config.model_config.moe_num_experts[0], + gate_correction_bias=self.gate_correction_bias[1] if fd_config.model_config.moe_use_aux_free else None, ) else: self.image_fused_moe = Ernie4_5_VLMLP( @@ -226,10 +232,13 @@ class Ernie4_5_VLMoE(nn.Layer): ) def load_state_dict(self, state_dict): + if self.gate_correction_bias is not None: + gate_correction_bias_tensor = state_dict.pop(self.text_fused_moe.experts.gate_correction_bias_key) + if self.gate_correction_bias.shape != gate_correction_bias_tensor.shape: + gate_correction_bias_tensor = gate_correction_bias_tensor.reshape(self.gate_correction_bias.shape) + self.gate_correction_bias.set_value(gate_correction_bias_tensor) self.text_fused_moe.load_state_dict(state_dict) self.image_fused_moe.load_state_dict(state_dict) - if self.text_fused_moe.experts.moe_use_gate_correction_bias: - state_dict.pop(self.text_fused_moe.experts.gate_correction_bias_key) if self.num_shared_experts > 0: self.shared_experts.load_state_dict(state_dict) @@ -563,19 +572,6 @@ class Ernie4_5_VLMoeForConditionalGeneration(ModelForCasualLM): def name(self): return "Ernie4_5_VLMoeForConditionalGeneration" - def gate_correction_bias_loader(self, params_dict, loaded_weight_name, loaded_weight): - text_param_name = loaded_weight_name.replace( - "moe_statics.e_score_correction_bias", "text_fused_moe.experts.gate_correction_bias" - ) - image_param_name = loaded_weight_name.replace( - "moe_statics.e_score_correction_bias", "image_fused_moe.experts.gate_correction_bias" - ) - text_param = params_dict[text_param_name] - image_param = params_dict[image_param_name] - loaded_weight = get_tensor(loaded_weight) - text_param.copy_(loaded_weight[0].unsqueeze(0), False) - image_param.copy_(loaded_weight[1].unsqueeze(0), False) - @paddle.no_grad() def load_weights(self, weights_iterator) -> None: """ @@ -585,7 +581,10 @@ class Ernie4_5_VLMoeForConditionalGeneration(ModelForCasualLM): weights_iterator (Iterator): An iterator yielding (name, weight) pairs. """ - from fastdeploy.model_executor.models.utils import default_weight_loader + from fastdeploy.model_executor.utils import ( + default_weight_loader, + process_weights_after_loading, + ) general_params_mapping = [ # (param_name, weight_name, expert_id, shard_id) @@ -594,6 +593,8 @@ class Ernie4_5_VLMoeForConditionalGeneration(ModelForCasualLM): ("mlp.image_fused_moe.gate.weight", "mlp.gate.weight_1", None, "gate"), ("mlp.text_fused_moe.gate.weight", "mlp.gate.weight", None, "gate"), ("resampler_model", "ernie.resampler_model", None, None), + ("vision_model", "ernie.vision_model", None, None), + ("gate_correction_bias", "moe_statics.e_score_correction_bias", None, None), ] text_expert_params_mapping = [] @@ -617,6 +618,7 @@ class Ernie4_5_VLMoeForConditionalGeneration(ModelForCasualLM): all_param_mapping = general_params_mapping + text_expert_params_mapping + image_expert_params_mapping params_dict = dict(self.named_parameters()) + process_weights_after_loading_fn = process_weights_after_loading(dict(self.named_sublayers())) expert_id = None shard_id = None for loaded_weight_name, loaded_weight in weights_iterator: @@ -629,10 +631,6 @@ class Ernie4_5_VLMoeForConditionalGeneration(ModelForCasualLM): shard_id = shard_id break else: - # text and image gate_correction_bias is fused in ckpt and need load independently - if "moe_statics.e_score_correction_bias" in loaded_weight_name: - self.gate_correction_bias_loader(params_dict, loaded_weight_name, loaded_weight) - continue if loaded_weight_name not in params_dict.keys(): continue model_param_name = loaded_weight_name @@ -646,7 +644,8 @@ class Ernie4_5_VLMoeForConditionalGeneration(ModelForCasualLM): weight_loader(param, loaded_weight, expert_id=expert_id, shard_id=shard_id) else: weight_loader(param, loaded_weight) - + model_sublayer_name = re.sub(r"\.(up_gate_proj_weight|down_proj_weight|weight)$", "", model_param_name) + process_weights_after_loading_fn(model_sublayer_name, param) if self.tie_word_embeddings: self.lm_head.linear.weight.set_value(self.ernie.embed_tokens.embeddings.weight.transpose([1, 0])) diff --git a/fastdeploy/model_executor/models/ernie4_5_vl/modeling_resampler.py b/fastdeploy/model_executor/models/ernie4_5_vl/modeling_resampler.py index 80e664e49..149b4efe3 100644 --- a/fastdeploy/model_executor/models/ernie4_5_vl/modeling_resampler.py +++ b/fastdeploy/model_executor/models/ernie4_5_vl/modeling_resampler.py @@ -30,7 +30,7 @@ from fastdeploy.model_executor.models.ernie4_5_vl.dist_utils import ( reduce_scatter_group, scatter_axis, ) -from fastdeploy.model_executor.models.utils import set_weight_attrs +from fastdeploy.model_executor.utils import set_weight_attrs class ScatterOp(PyLayer): diff --git a/fastdeploy/model_executor/models/qwen3.py b/fastdeploy/model_executor/models/qwen3.py index 04988740d..6d4553dc1 100644 --- a/fastdeploy/model_executor/models/qwen3.py +++ b/fastdeploy/model_executor/models/qwen3.py @@ -16,6 +16,7 @@ from __future__ import annotations +import re from functools import partial import paddle @@ -254,7 +255,10 @@ class Qwen3ForCausalLM(ModelForCasualLM): weights_iterator (Iterator): An iterator yielding (name, weight) pairs. """ - from fastdeploy.model_executor.models.utils import default_weight_loader + from fastdeploy.model_executor.utils import ( + default_weight_loader, + process_weights_after_loading, + ) stacked_params_mapping = [ # (param_name, shard_name, shard_id) @@ -266,8 +270,8 @@ class Qwen3ForCausalLM(ModelForCasualLM): ("embed_tokens.embeddings", "embed_tokens", None), ("lm_head.linear", "lm_head", None), ] - params_dict = dict(self.named_parameters()) + process_weights_after_loading_fn = process_weights_after_loading(dict(self.named_sublayers())) for loaded_weight_name, loaded_weight in weights_iterator: for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in loaded_weight_name: @@ -280,11 +284,14 @@ class Qwen3ForCausalLM(ModelForCasualLM): weight_loader(param, loaded_weight, shard_id) break else: - if loaded_weight_name not in params_dict: + model_param_name = loaded_weight_name + if model_param_name not in params_dict: continue - param = params_dict[loaded_weight_name] + param = params_dict[model_param_name] weight_loader = getattr(param, "weight_loader", default_weight_loader(self.fd_config)) weight_loader(param, loaded_weight) + model_sublayer_name = re.sub(r"\.(weight)$", "", model_param_name) + process_weights_after_loading_fn(model_sublayer_name, param) if self.tie_word_embeddings: self.lm_head.linear.weight.set_value(self.model.embed_tokens.embeddings.weight.transpose([1, 0])) diff --git a/fastdeploy/model_executor/models/qwen3moe.py b/fastdeploy/model_executor/models/qwen3moe.py index 5857ad144..3dce5c976 100644 --- a/fastdeploy/model_executor/models/qwen3moe.py +++ b/fastdeploy/model_executor/models/qwen3moe.py @@ -16,6 +16,7 @@ from __future__ import annotations +import re from functools import partial import paddle @@ -334,7 +335,10 @@ class Qwen3MoeForCausalLM(ModelForCasualLM): weights_iterator (Iterator): An iterator yielding (name, weight) pairs. """ - from fastdeploy.model_executor.models.utils import default_weight_loader + from fastdeploy.model_executor.utils import ( + default_weight_loader, + process_weights_after_loading, + ) stacked_params_mapping = [ # (param_name, shard_name, shard_id) @@ -348,6 +352,7 @@ class Qwen3MoeForCausalLM(ModelForCasualLM): ] expert_params_mapping = self.get_expert_mapping() params_dict = dict(self.named_parameters()) + process_weights_after_loading_fn = process_weights_after_loading(dict(self.named_sublayers())) for loaded_weight_name, loaded_weight in weights_iterator: for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in loaded_weight_name: @@ -374,12 +379,16 @@ class Qwen3MoeForCausalLM(ModelForCasualLM): weight_loader(param, loaded_weight, shard_id=shard_id, expert_id=expert_id) break else: - if loaded_weight_name not in params_dict: + model_param_name = loaded_weight_name + if model_param_name not in params_dict: continue - param = params_dict[loaded_weight_name] + param = params_dict[model_param_name] weight_loader = getattr(param, "weight_loader", default_weight_loader(self.fd_config)) weight_loader(param, loaded_weight) + model_sublayer_name = re.sub(r"\.(up_gate_proj_weight|down_proj_weight|weight)$", "", model_param_name) + process_weights_after_loading_fn(model_sublayer_name, param) + @paddle.no_grad() def set_state_dict(self, state_dict): """ diff --git a/fastdeploy/model_executor/models/utils.py b/fastdeploy/model_executor/models/utils.py index 78fd8a40e..063344d19 100644 --- a/fastdeploy/model_executor/models/utils.py +++ b/fastdeploy/model_executor/models/utils.py @@ -24,7 +24,7 @@ import random import re import struct from functools import partial -from typing import Any, NamedTuple, Optional, Union +from typing import NamedTuple, Optional import numpy as np import paddle @@ -40,73 +40,10 @@ from paddleformers.utils.env import ( from paddleformers.utils.log import logger from tqdm import tqdm -from fastdeploy.config import FDConfig -from fastdeploy.model_executor.layers.utils import get_tensor - MAX_BSZ = 512 MAX_DRAFT_TOKENS = 6 -def set_weight_attrs(param, param_attr_map: Optional[dict[str, Any]]): - if param_attr_map is None: - return - for key, value in param_attr_map.items(): - setattr(param, key, value) - - -def slice_fn(weight_or_paramter, output_dim, start, end, step=1): - if hasattr(weight_or_paramter, "get_shape"): - shape = weight_or_paramter.get_shape() - else: - shape = weight_or_paramter.shape - if len(shape) == 1: - weight_or_paramter = weight_or_paramter[start:end] - elif output_dim: - weight_or_paramter = weight_or_paramter[..., start:end] - else: - weight_or_paramter = weight_or_paramter[start:end, ...] - return weight_or_paramter - - -def default_weight_loader(fd_config: FDConfig) -> None: - """Default weight loader""" - - def fn(param, loaded_weight, shard_id: Optional[Union[int, str]] = None): - """fn""" - try: - output_dim = getattr(param, "output_dim", None) - # Tensor parallelism splits the weight along the output_dim - if output_dim is not None: - dim = -1 if output_dim else 0 - size = loaded_weight.get_shape()[dim] - block_size = size // fd_config.parallel_config.tensor_parallel_size - shard_offset = fd_config.parallel_config.tensor_parallel_rank * block_size - shard_size = (fd_config.parallel_config.tensor_parallel_rank + 1) * block_size - if output_dim: - loaded_weight = loaded_weight[..., shard_offset:shard_size] - else: - loaded_weight = loaded_weight[shard_offset:shard_size, ...] - - loaded_weight = get_tensor(loaded_weight) - # mlp.gate.weight is precision-sensitive, so we cast it to float32 for computation - if param.dtype != loaded_weight.dtype: - loaded_weight = loaded_weight.cast(param.dtype) - - if param.shape != loaded_weight.shape: - try: - param = param.reshape(loaded_weight.shape) - except ValueError as e: - raise ValueError( - f" Attempted to load weight ({loaded_weight.shape}) into parameter ({param.shape}). {e}" - ) - - param.copy_(loaded_weight, False) - except Exception: - raise - - return fn - - class LayerIdPlaceholder(str, enum.Enum): """LayerIdPlaceholder""" diff --git a/fastdeploy/model_executor/utils.py b/fastdeploy/model_executor/utils.py new file mode 100644 index 000000000..31cd67172 --- /dev/null +++ b/fastdeploy/model_executor/utils.py @@ -0,0 +1,179 @@ +""" +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" + +from typing import Any, Optional, Union + +from fastdeploy.config import FDConfig +from fastdeploy.model_executor.layers.utils import get_tensor + + +class BitMaskTracker: + def __init__(self, length: int): + """ + Track filling status along a single dimension using a bitmask. + + Args: + length (int): Number of positions to track (e.g., columns or rows) + """ + self.length = length + self.mask = 0 + + def mark(self, start: int, end: int): + """ + Mark the range [start, end) as filled. + + Args: + start (int): Start index (inclusive) + end (int): End index (exclusive) + """ + if start < 0 or end > self.length or start >= end: + raise ValueError("Invalid mark range") + block = ((1 << (end - start)) - 1) << start + self.mask |= block + + def is_full(self) -> bool: + """Return True if all positions are filled.""" + return self.mask == (1 << self.length) - 1 + + +class TensorTracker: + def __init__(self, shape: tuple, output_dim: int): + """ + Unified tracker for 2D or 3D tensors. + + Args: + shape (tuple): Tensor shape + output_dim (bool): + - 2D: True = track columns (dim=1), False = track rows (dim=0) + - 3D: True = track columns (dim=2), False = track rows (dim=1) + """ + self.shape = shape + self.output_dim = output_dim + + if len(shape) == 2: + self.track_dim = 1 if output_dim else 0 + self.trackers = [BitMaskTracker(shape[self.track_dim])] + elif len(shape) == 3: + batch = shape[0] + self.track_dim = 2 if output_dim else 1 + self.trackers = [BitMaskTracker(shape[self.track_dim]) for _ in range(batch)] + else: + raise ValueError("Only 2D or 3D tensors supported") + + def mark(self, start: int = 0, end: int = None, batch_id: int = None): + """ + Mark a slice of the tensor as filled. + + Args: + batch_id (int, optional): Batch index for 3D tensors + start (int): Start index along tracked dimension + end (int): End index along tracked dimension + """ + if end is None: + end = self.shape[self.track_dim] + + if len(self.shape) == 2: + self.trackers[0].mark(start, end) + else: + if batch_id is None: + raise ValueError("batch_id must be provided for 3D tensor") + self.trackers[batch_id].mark(start, end) + + def is_fully_copied(self) -> bool: + """Return True if the tensor is fully filled along tracked dimension(s).""" + return all(tr.is_full() for tr in self.trackers) + + +def set_weight_attrs(param, param_attr_map: Optional[dict[str, Any]]): + if param_attr_map is None: + return + for key, value in param_attr_map.items(): + setattr(param, key, value) + + +def slice_fn(weight_or_paramter, output_dim, start, end, step=1): + if hasattr(weight_or_paramter, "get_shape"): + shape = weight_or_paramter.get_shape() + else: + shape = weight_or_paramter.shape + if len(shape) == 1: + weight_or_paramter = weight_or_paramter[start:end] + elif output_dim: + weight_or_paramter = weight_or_paramter[..., start:end] + else: + weight_or_paramter = weight_or_paramter[start:end, ...] + return weight_or_paramter + + +def process_weights_after_loading(sublayers_dict: dict): + """ + process_weights_after_loading: e.g., handle extracted weights (quantization, reshaping, etc.) + """ + + def fn(model_sublayer_name: str, param=None): + from fastdeploy.model_executor.layers.linear import KVBatchLinear + + if model_sublayer_name not in sublayers_dict: + return + model_sublayer = sublayers_dict[model_sublayer_name] + if isinstance(model_sublayer, KVBatchLinear): + model_sublayer.process_weights_after_loading() + if hasattr(model_sublayer, "quant_method"): + quant_method = getattr(model_sublayer, "quant_method", None) + if not hasattr(quant_method, "process_weights_after_loading"): + return + if param is not None and hasattr(param, "tensor_track") and not param.tensor_track.is_fully_copied(): + return + quant_method.process_weights_after_loading(model_sublayer) + + return fn + + +def free_tensor(tensor): + if hasattr(tensor, "tensor_track"): + tensor.tensor_track = None + tensor.value().get_tensor()._clear() + del tensor + + +def default_weight_loader(fd_config: FDConfig) -> None: + """Default weight loader""" + + def fn(param, loaded_weight, shard_id: Optional[Union[int, str]] = None): + """fn""" + output_dim = getattr(param, "output_dim", None) + # Tensor parallelism splits the weight along the output_dim + if output_dim is not None and fd_config.parallel_config.tensor_parallel_size > 1: + dim = -1 if output_dim else 0 + size = loaded_weight.get_shape()[dim] + block_size = size // fd_config.parallel_config.tensor_parallel_size + shard_offset = fd_config.parallel_config.tensor_parallel_rank * block_size + shard_size = (fd_config.parallel_config.tensor_parallel_rank + 1) * block_size + loaded_weight = slice_fn(loaded_weight, output_dim, shard_offset, shard_size) + + loaded_weight = get_tensor(loaded_weight) + # mlp.gate.weight is precision-sensitive, so we cast it to float32 for computation + if param.dtype != loaded_weight.dtype: + loaded_weight = loaded_weight.cast(param.dtype) + if param.shape != loaded_weight.shape: + # for e_score_correction_bias + loaded_weight = loaded_weight.reshape(param.shape) + assert param.shape == loaded_weight.shape, ( + f" Attempted to load weight ({loaded_weight.shape}) " f"into parameter ({param.shape})" + ) + param.copy_(loaded_weight, False) + + return fn diff --git a/fastdeploy/rl/rollout_model.py b/fastdeploy/rl/rollout_model.py index fd4165174..33508603d 100644 --- a/fastdeploy/rl/rollout_model.py +++ b/fastdeploy/rl/rollout_model.py @@ -247,9 +247,9 @@ class Ernie4_5_VLMoeForConditionalGenerationRL(Ernie4_5_VLMoeForConditionalGener ) if self.fd_config.model_config.moe_use_aux_free: - self.infer_to_train_mapping[ - f"{base_name}.{layer_idx}.mlp.{moe_tag}_fused_moe.experts.gate_correction_bias" - ] = f"{base_name}.{layer_idx}.mlp.moe_statics.e_score_correction_bias" + self.infer_to_train_mapping[f"{base_name}.{layer_idx}.mlp.gate_correction_bias"] = ( + f"{base_name}.{layer_idx}.mlp.moe_statics.e_score_correction_bias" + ) # Initialize defaultdict for expert weights from collections import defaultdict diff --git a/scripts/coverage_run.sh b/scripts/coverage_run.sh index 3952f43b0..eab0073d6 100644 --- a/scripts/coverage_run.sh +++ b/scripts/coverage_run.sh @@ -54,7 +54,7 @@ success_pytest=0 for file in $TEST_FILES; do echo "Running pytest file: $file" - python -m coverage run --parallel-mode -m pytest "$file" + python -m coverage run --parallel-mode -m pytest "$file" -vv -s status=$? if [ "$status" -ne 0 ]; then echo "$file" >> "$failed_tests_file" diff --git a/tests/ci_use/EB_VL_Lite/baseline.txt b/tests/ci_use/EB_VL_Lite/baseline.txt index 6cd3d9655..43d284bfb 100644 --- a/tests/ci_use/EB_VL_Lite/baseline.txt +++ b/tests/ci_use/EB_VL_Lite/baseline.txt @@ -415,13 +415,12 @@ ernie.layers.1.self_attn.qkv_proj.weight ernie.layers.1.self_attn.qkv_proj.weight_scale ernie.layers.1.self_attn.o_proj.weight ernie.layers.1.self_attn.o_proj.weight_scale -ernie.layers.1.mlp.text_fused_moe.experts.gate_correction_bias +ernie.layers.1.mlp.gate_correction_bias ernie.layers.1.mlp.text_fused_moe.experts.up_gate_proj_weight_scale ernie.layers.1.mlp.text_fused_moe.experts.down_proj_weight_scale ernie.layers.1.mlp.text_fused_moe.experts.up_gate_proj_weight ernie.layers.1.mlp.text_fused_moe.experts.down_proj_weight ernie.layers.1.mlp.text_fused_moe.gate.weight -ernie.layers.1.mlp.image_fused_moe.experts.gate_correction_bias ernie.layers.1.mlp.image_fused_moe.experts.up_gate_proj_weight_scale ernie.layers.1.mlp.image_fused_moe.experts.down_proj_weight_scale ernie.layers.1.mlp.image_fused_moe.experts.up_gate_proj_weight @@ -437,13 +436,12 @@ ernie.layers.2.self_attn.qkv_proj.weight ernie.layers.2.self_attn.qkv_proj.weight_scale ernie.layers.2.self_attn.o_proj.weight ernie.layers.2.self_attn.o_proj.weight_scale -ernie.layers.2.mlp.text_fused_moe.experts.gate_correction_bias +ernie.layers.2.mlp.gate_correction_bias ernie.layers.2.mlp.text_fused_moe.experts.up_gate_proj_weight_scale ernie.layers.2.mlp.text_fused_moe.experts.down_proj_weight_scale ernie.layers.2.mlp.text_fused_moe.experts.up_gate_proj_weight ernie.layers.2.mlp.text_fused_moe.experts.down_proj_weight ernie.layers.2.mlp.text_fused_moe.gate.weight -ernie.layers.2.mlp.image_fused_moe.experts.gate_correction_bias ernie.layers.2.mlp.image_fused_moe.experts.up_gate_proj_weight_scale ernie.layers.2.mlp.image_fused_moe.experts.down_proj_weight_scale ernie.layers.2.mlp.image_fused_moe.experts.up_gate_proj_weight @@ -459,13 +457,12 @@ ernie.layers.3.self_attn.qkv_proj.weight ernie.layers.3.self_attn.qkv_proj.weight_scale ernie.layers.3.self_attn.o_proj.weight ernie.layers.3.self_attn.o_proj.weight_scale -ernie.layers.3.mlp.text_fused_moe.experts.gate_correction_bias +ernie.layers.3.mlp.gate_correction_bias ernie.layers.3.mlp.text_fused_moe.experts.up_gate_proj_weight_scale ernie.layers.3.mlp.text_fused_moe.experts.down_proj_weight_scale ernie.layers.3.mlp.text_fused_moe.experts.up_gate_proj_weight ernie.layers.3.mlp.text_fused_moe.experts.down_proj_weight ernie.layers.3.mlp.text_fused_moe.gate.weight -ernie.layers.3.mlp.image_fused_moe.experts.gate_correction_bias ernie.layers.3.mlp.image_fused_moe.experts.up_gate_proj_weight_scale ernie.layers.3.mlp.image_fused_moe.experts.down_proj_weight_scale ernie.layers.3.mlp.image_fused_moe.experts.up_gate_proj_weight @@ -481,13 +478,12 @@ ernie.layers.4.self_attn.qkv_proj.weight ernie.layers.4.self_attn.qkv_proj.weight_scale ernie.layers.4.self_attn.o_proj.weight ernie.layers.4.self_attn.o_proj.weight_scale -ernie.layers.4.mlp.text_fused_moe.experts.gate_correction_bias +ernie.layers.4.mlp.gate_correction_bias ernie.layers.4.mlp.text_fused_moe.experts.up_gate_proj_weight_scale ernie.layers.4.mlp.text_fused_moe.experts.down_proj_weight_scale ernie.layers.4.mlp.text_fused_moe.experts.up_gate_proj_weight ernie.layers.4.mlp.text_fused_moe.experts.down_proj_weight ernie.layers.4.mlp.text_fused_moe.gate.weight -ernie.layers.4.mlp.image_fused_moe.experts.gate_correction_bias ernie.layers.4.mlp.image_fused_moe.experts.up_gate_proj_weight_scale ernie.layers.4.mlp.image_fused_moe.experts.down_proj_weight_scale ernie.layers.4.mlp.image_fused_moe.experts.up_gate_proj_weight @@ -503,13 +499,12 @@ ernie.layers.5.self_attn.qkv_proj.weight ernie.layers.5.self_attn.qkv_proj.weight_scale ernie.layers.5.self_attn.o_proj.weight ernie.layers.5.self_attn.o_proj.weight_scale -ernie.layers.5.mlp.text_fused_moe.experts.gate_correction_bias +ernie.layers.5.mlp.gate_correction_bias ernie.layers.5.mlp.text_fused_moe.experts.up_gate_proj_weight_scale ernie.layers.5.mlp.text_fused_moe.experts.down_proj_weight_scale ernie.layers.5.mlp.text_fused_moe.experts.up_gate_proj_weight ernie.layers.5.mlp.text_fused_moe.experts.down_proj_weight ernie.layers.5.mlp.text_fused_moe.gate.weight -ernie.layers.5.mlp.image_fused_moe.experts.gate_correction_bias ernie.layers.5.mlp.image_fused_moe.experts.up_gate_proj_weight_scale ernie.layers.5.mlp.image_fused_moe.experts.down_proj_weight_scale ernie.layers.5.mlp.image_fused_moe.experts.up_gate_proj_weight @@ -525,13 +520,12 @@ ernie.layers.6.self_attn.qkv_proj.weight ernie.layers.6.self_attn.qkv_proj.weight_scale ernie.layers.6.self_attn.o_proj.weight ernie.layers.6.self_attn.o_proj.weight_scale -ernie.layers.6.mlp.text_fused_moe.experts.gate_correction_bias +ernie.layers.6.mlp.gate_correction_bias ernie.layers.6.mlp.text_fused_moe.experts.up_gate_proj_weight_scale ernie.layers.6.mlp.text_fused_moe.experts.down_proj_weight_scale ernie.layers.6.mlp.text_fused_moe.experts.up_gate_proj_weight ernie.layers.6.mlp.text_fused_moe.experts.down_proj_weight ernie.layers.6.mlp.text_fused_moe.gate.weight -ernie.layers.6.mlp.image_fused_moe.experts.gate_correction_bias ernie.layers.6.mlp.image_fused_moe.experts.up_gate_proj_weight_scale ernie.layers.6.mlp.image_fused_moe.experts.down_proj_weight_scale ernie.layers.6.mlp.image_fused_moe.experts.up_gate_proj_weight @@ -547,13 +541,12 @@ ernie.layers.7.self_attn.qkv_proj.weight ernie.layers.7.self_attn.qkv_proj.weight_scale ernie.layers.7.self_attn.o_proj.weight ernie.layers.7.self_attn.o_proj.weight_scale -ernie.layers.7.mlp.text_fused_moe.experts.gate_correction_bias +ernie.layers.7.mlp.gate_correction_bias ernie.layers.7.mlp.text_fused_moe.experts.up_gate_proj_weight_scale ernie.layers.7.mlp.text_fused_moe.experts.down_proj_weight_scale ernie.layers.7.mlp.text_fused_moe.experts.up_gate_proj_weight ernie.layers.7.mlp.text_fused_moe.experts.down_proj_weight ernie.layers.7.mlp.text_fused_moe.gate.weight -ernie.layers.7.mlp.image_fused_moe.experts.gate_correction_bias ernie.layers.7.mlp.image_fused_moe.experts.up_gate_proj_weight_scale ernie.layers.7.mlp.image_fused_moe.experts.down_proj_weight_scale ernie.layers.7.mlp.image_fused_moe.experts.up_gate_proj_weight @@ -569,13 +562,12 @@ ernie.layers.8.self_attn.qkv_proj.weight ernie.layers.8.self_attn.qkv_proj.weight_scale ernie.layers.8.self_attn.o_proj.weight ernie.layers.8.self_attn.o_proj.weight_scale -ernie.layers.8.mlp.text_fused_moe.experts.gate_correction_bias +ernie.layers.8.mlp.gate_correction_bias ernie.layers.8.mlp.text_fused_moe.experts.up_gate_proj_weight_scale ernie.layers.8.mlp.text_fused_moe.experts.down_proj_weight_scale ernie.layers.8.mlp.text_fused_moe.experts.up_gate_proj_weight ernie.layers.8.mlp.text_fused_moe.experts.down_proj_weight ernie.layers.8.mlp.text_fused_moe.gate.weight -ernie.layers.8.mlp.image_fused_moe.experts.gate_correction_bias ernie.layers.8.mlp.image_fused_moe.experts.up_gate_proj_weight_scale ernie.layers.8.mlp.image_fused_moe.experts.down_proj_weight_scale ernie.layers.8.mlp.image_fused_moe.experts.up_gate_proj_weight @@ -591,13 +583,12 @@ ernie.layers.9.self_attn.qkv_proj.weight ernie.layers.9.self_attn.qkv_proj.weight_scale ernie.layers.9.self_attn.o_proj.weight ernie.layers.9.self_attn.o_proj.weight_scale -ernie.layers.9.mlp.text_fused_moe.experts.gate_correction_bias +ernie.layers.9.mlp.gate_correction_bias ernie.layers.9.mlp.text_fused_moe.experts.up_gate_proj_weight_scale ernie.layers.9.mlp.text_fused_moe.experts.down_proj_weight_scale ernie.layers.9.mlp.text_fused_moe.experts.up_gate_proj_weight ernie.layers.9.mlp.text_fused_moe.experts.down_proj_weight ernie.layers.9.mlp.text_fused_moe.gate.weight -ernie.layers.9.mlp.image_fused_moe.experts.gate_correction_bias ernie.layers.9.mlp.image_fused_moe.experts.up_gate_proj_weight_scale ernie.layers.9.mlp.image_fused_moe.experts.down_proj_weight_scale ernie.layers.9.mlp.image_fused_moe.experts.up_gate_proj_weight @@ -613,13 +604,12 @@ ernie.layers.10.self_attn.qkv_proj.weight ernie.layers.10.self_attn.qkv_proj.weight_scale ernie.layers.10.self_attn.o_proj.weight ernie.layers.10.self_attn.o_proj.weight_scale -ernie.layers.10.mlp.text_fused_moe.experts.gate_correction_bias +ernie.layers.10.mlp.gate_correction_bias ernie.layers.10.mlp.text_fused_moe.experts.up_gate_proj_weight_scale ernie.layers.10.mlp.text_fused_moe.experts.down_proj_weight_scale ernie.layers.10.mlp.text_fused_moe.experts.up_gate_proj_weight ernie.layers.10.mlp.text_fused_moe.experts.down_proj_weight ernie.layers.10.mlp.text_fused_moe.gate.weight -ernie.layers.10.mlp.image_fused_moe.experts.gate_correction_bias ernie.layers.10.mlp.image_fused_moe.experts.up_gate_proj_weight_scale ernie.layers.10.mlp.image_fused_moe.experts.down_proj_weight_scale ernie.layers.10.mlp.image_fused_moe.experts.up_gate_proj_weight @@ -635,13 +625,12 @@ ernie.layers.11.self_attn.qkv_proj.weight ernie.layers.11.self_attn.qkv_proj.weight_scale ernie.layers.11.self_attn.o_proj.weight ernie.layers.11.self_attn.o_proj.weight_scale -ernie.layers.11.mlp.text_fused_moe.experts.gate_correction_bias +ernie.layers.11.mlp.gate_correction_bias ernie.layers.11.mlp.text_fused_moe.experts.up_gate_proj_weight_scale ernie.layers.11.mlp.text_fused_moe.experts.down_proj_weight_scale ernie.layers.11.mlp.text_fused_moe.experts.up_gate_proj_weight ernie.layers.11.mlp.text_fused_moe.experts.down_proj_weight ernie.layers.11.mlp.text_fused_moe.gate.weight -ernie.layers.11.mlp.image_fused_moe.experts.gate_correction_bias ernie.layers.11.mlp.image_fused_moe.experts.up_gate_proj_weight_scale ernie.layers.11.mlp.image_fused_moe.experts.down_proj_weight_scale ernie.layers.11.mlp.image_fused_moe.experts.up_gate_proj_weight @@ -657,13 +646,12 @@ ernie.layers.12.self_attn.qkv_proj.weight ernie.layers.12.self_attn.qkv_proj.weight_scale ernie.layers.12.self_attn.o_proj.weight ernie.layers.12.self_attn.o_proj.weight_scale -ernie.layers.12.mlp.text_fused_moe.experts.gate_correction_bias +ernie.layers.12.mlp.gate_correction_bias ernie.layers.12.mlp.text_fused_moe.experts.up_gate_proj_weight_scale ernie.layers.12.mlp.text_fused_moe.experts.down_proj_weight_scale ernie.layers.12.mlp.text_fused_moe.experts.up_gate_proj_weight ernie.layers.12.mlp.text_fused_moe.experts.down_proj_weight ernie.layers.12.mlp.text_fused_moe.gate.weight -ernie.layers.12.mlp.image_fused_moe.experts.gate_correction_bias ernie.layers.12.mlp.image_fused_moe.experts.up_gate_proj_weight_scale ernie.layers.12.mlp.image_fused_moe.experts.down_proj_weight_scale ernie.layers.12.mlp.image_fused_moe.experts.up_gate_proj_weight @@ -679,13 +667,12 @@ ernie.layers.13.self_attn.qkv_proj.weight ernie.layers.13.self_attn.qkv_proj.weight_scale ernie.layers.13.self_attn.o_proj.weight ernie.layers.13.self_attn.o_proj.weight_scale -ernie.layers.13.mlp.text_fused_moe.experts.gate_correction_bias +ernie.layers.13.mlp.gate_correction_bias ernie.layers.13.mlp.text_fused_moe.experts.up_gate_proj_weight_scale ernie.layers.13.mlp.text_fused_moe.experts.down_proj_weight_scale ernie.layers.13.mlp.text_fused_moe.experts.up_gate_proj_weight ernie.layers.13.mlp.text_fused_moe.experts.down_proj_weight ernie.layers.13.mlp.text_fused_moe.gate.weight -ernie.layers.13.mlp.image_fused_moe.experts.gate_correction_bias ernie.layers.13.mlp.image_fused_moe.experts.up_gate_proj_weight_scale ernie.layers.13.mlp.image_fused_moe.experts.down_proj_weight_scale ernie.layers.13.mlp.image_fused_moe.experts.up_gate_proj_weight @@ -701,13 +688,12 @@ ernie.layers.14.self_attn.qkv_proj.weight ernie.layers.14.self_attn.qkv_proj.weight_scale ernie.layers.14.self_attn.o_proj.weight ernie.layers.14.self_attn.o_proj.weight_scale -ernie.layers.14.mlp.text_fused_moe.experts.gate_correction_bias +ernie.layers.14.mlp.gate_correction_bias ernie.layers.14.mlp.text_fused_moe.experts.up_gate_proj_weight_scale ernie.layers.14.mlp.text_fused_moe.experts.down_proj_weight_scale ernie.layers.14.mlp.text_fused_moe.experts.up_gate_proj_weight ernie.layers.14.mlp.text_fused_moe.experts.down_proj_weight ernie.layers.14.mlp.text_fused_moe.gate.weight -ernie.layers.14.mlp.image_fused_moe.experts.gate_correction_bias ernie.layers.14.mlp.image_fused_moe.experts.up_gate_proj_weight_scale ernie.layers.14.mlp.image_fused_moe.experts.down_proj_weight_scale ernie.layers.14.mlp.image_fused_moe.experts.up_gate_proj_weight @@ -723,13 +709,12 @@ ernie.layers.15.self_attn.qkv_proj.weight ernie.layers.15.self_attn.qkv_proj.weight_scale ernie.layers.15.self_attn.o_proj.weight ernie.layers.15.self_attn.o_proj.weight_scale -ernie.layers.15.mlp.text_fused_moe.experts.gate_correction_bias +ernie.layers.15.mlp.gate_correction_bias ernie.layers.15.mlp.text_fused_moe.experts.up_gate_proj_weight_scale ernie.layers.15.mlp.text_fused_moe.experts.down_proj_weight_scale ernie.layers.15.mlp.text_fused_moe.experts.up_gate_proj_weight ernie.layers.15.mlp.text_fused_moe.experts.down_proj_weight ernie.layers.15.mlp.text_fused_moe.gate.weight -ernie.layers.15.mlp.image_fused_moe.experts.gate_correction_bias ernie.layers.15.mlp.image_fused_moe.experts.up_gate_proj_weight_scale ernie.layers.15.mlp.image_fused_moe.experts.down_proj_weight_scale ernie.layers.15.mlp.image_fused_moe.experts.up_gate_proj_weight @@ -745,13 +730,12 @@ ernie.layers.16.self_attn.qkv_proj.weight ernie.layers.16.self_attn.qkv_proj.weight_scale ernie.layers.16.self_attn.o_proj.weight ernie.layers.16.self_attn.o_proj.weight_scale -ernie.layers.16.mlp.text_fused_moe.experts.gate_correction_bias +ernie.layers.16.mlp.gate_correction_bias ernie.layers.16.mlp.text_fused_moe.experts.up_gate_proj_weight_scale ernie.layers.16.mlp.text_fused_moe.experts.down_proj_weight_scale ernie.layers.16.mlp.text_fused_moe.experts.up_gate_proj_weight ernie.layers.16.mlp.text_fused_moe.experts.down_proj_weight ernie.layers.16.mlp.text_fused_moe.gate.weight -ernie.layers.16.mlp.image_fused_moe.experts.gate_correction_bias ernie.layers.16.mlp.image_fused_moe.experts.up_gate_proj_weight_scale ernie.layers.16.mlp.image_fused_moe.experts.down_proj_weight_scale ernie.layers.16.mlp.image_fused_moe.experts.up_gate_proj_weight @@ -767,13 +751,12 @@ ernie.layers.17.self_attn.qkv_proj.weight ernie.layers.17.self_attn.qkv_proj.weight_scale ernie.layers.17.self_attn.o_proj.weight ernie.layers.17.self_attn.o_proj.weight_scale -ernie.layers.17.mlp.text_fused_moe.experts.gate_correction_bias +ernie.layers.17.mlp.gate_correction_bias ernie.layers.17.mlp.text_fused_moe.experts.up_gate_proj_weight_scale ernie.layers.17.mlp.text_fused_moe.experts.down_proj_weight_scale ernie.layers.17.mlp.text_fused_moe.experts.up_gate_proj_weight ernie.layers.17.mlp.text_fused_moe.experts.down_proj_weight ernie.layers.17.mlp.text_fused_moe.gate.weight -ernie.layers.17.mlp.image_fused_moe.experts.gate_correction_bias ernie.layers.17.mlp.image_fused_moe.experts.up_gate_proj_weight_scale ernie.layers.17.mlp.image_fused_moe.experts.down_proj_weight_scale ernie.layers.17.mlp.image_fused_moe.experts.up_gate_proj_weight @@ -789,13 +772,12 @@ ernie.layers.18.self_attn.qkv_proj.weight ernie.layers.18.self_attn.qkv_proj.weight_scale ernie.layers.18.self_attn.o_proj.weight ernie.layers.18.self_attn.o_proj.weight_scale -ernie.layers.18.mlp.text_fused_moe.experts.gate_correction_bias +ernie.layers.18.mlp.gate_correction_bias ernie.layers.18.mlp.text_fused_moe.experts.up_gate_proj_weight_scale ernie.layers.18.mlp.text_fused_moe.experts.down_proj_weight_scale ernie.layers.18.mlp.text_fused_moe.experts.up_gate_proj_weight ernie.layers.18.mlp.text_fused_moe.experts.down_proj_weight ernie.layers.18.mlp.text_fused_moe.gate.weight -ernie.layers.18.mlp.image_fused_moe.experts.gate_correction_bias ernie.layers.18.mlp.image_fused_moe.experts.up_gate_proj_weight_scale ernie.layers.18.mlp.image_fused_moe.experts.down_proj_weight_scale ernie.layers.18.mlp.image_fused_moe.experts.up_gate_proj_weight @@ -811,13 +793,12 @@ ernie.layers.19.self_attn.qkv_proj.weight ernie.layers.19.self_attn.qkv_proj.weight_scale ernie.layers.19.self_attn.o_proj.weight ernie.layers.19.self_attn.o_proj.weight_scale -ernie.layers.19.mlp.text_fused_moe.experts.gate_correction_bias +ernie.layers.19.mlp.gate_correction_bias ernie.layers.19.mlp.text_fused_moe.experts.up_gate_proj_weight_scale ernie.layers.19.mlp.text_fused_moe.experts.down_proj_weight_scale ernie.layers.19.mlp.text_fused_moe.experts.up_gate_proj_weight ernie.layers.19.mlp.text_fused_moe.experts.down_proj_weight ernie.layers.19.mlp.text_fused_moe.gate.weight -ernie.layers.19.mlp.image_fused_moe.experts.gate_correction_bias ernie.layers.19.mlp.image_fused_moe.experts.up_gate_proj_weight_scale ernie.layers.19.mlp.image_fused_moe.experts.down_proj_weight_scale ernie.layers.19.mlp.image_fused_moe.experts.up_gate_proj_weight @@ -833,13 +814,12 @@ ernie.layers.20.self_attn.qkv_proj.weight ernie.layers.20.self_attn.qkv_proj.weight_scale ernie.layers.20.self_attn.o_proj.weight ernie.layers.20.self_attn.o_proj.weight_scale -ernie.layers.20.mlp.text_fused_moe.experts.gate_correction_bias +ernie.layers.20.mlp.gate_correction_bias ernie.layers.20.mlp.text_fused_moe.experts.up_gate_proj_weight_scale ernie.layers.20.mlp.text_fused_moe.experts.down_proj_weight_scale ernie.layers.20.mlp.text_fused_moe.experts.up_gate_proj_weight ernie.layers.20.mlp.text_fused_moe.experts.down_proj_weight ernie.layers.20.mlp.text_fused_moe.gate.weight -ernie.layers.20.mlp.image_fused_moe.experts.gate_correction_bias ernie.layers.20.mlp.image_fused_moe.experts.up_gate_proj_weight_scale ernie.layers.20.mlp.image_fused_moe.experts.down_proj_weight_scale ernie.layers.20.mlp.image_fused_moe.experts.up_gate_proj_weight @@ -855,13 +835,12 @@ ernie.layers.21.self_attn.qkv_proj.weight ernie.layers.21.self_attn.qkv_proj.weight_scale ernie.layers.21.self_attn.o_proj.weight ernie.layers.21.self_attn.o_proj.weight_scale -ernie.layers.21.mlp.text_fused_moe.experts.gate_correction_bias +ernie.layers.21.mlp.gate_correction_bias ernie.layers.21.mlp.text_fused_moe.experts.up_gate_proj_weight_scale ernie.layers.21.mlp.text_fused_moe.experts.down_proj_weight_scale ernie.layers.21.mlp.text_fused_moe.experts.up_gate_proj_weight ernie.layers.21.mlp.text_fused_moe.experts.down_proj_weight ernie.layers.21.mlp.text_fused_moe.gate.weight -ernie.layers.21.mlp.image_fused_moe.experts.gate_correction_bias ernie.layers.21.mlp.image_fused_moe.experts.up_gate_proj_weight_scale ernie.layers.21.mlp.image_fused_moe.experts.down_proj_weight_scale ernie.layers.21.mlp.image_fused_moe.experts.up_gate_proj_weight @@ -877,13 +856,12 @@ ernie.layers.22.self_attn.qkv_proj.weight ernie.layers.22.self_attn.qkv_proj.weight_scale ernie.layers.22.self_attn.o_proj.weight ernie.layers.22.self_attn.o_proj.weight_scale -ernie.layers.22.mlp.text_fused_moe.experts.gate_correction_bias +ernie.layers.22.mlp.gate_correction_bias ernie.layers.22.mlp.text_fused_moe.experts.up_gate_proj_weight_scale ernie.layers.22.mlp.text_fused_moe.experts.down_proj_weight_scale ernie.layers.22.mlp.text_fused_moe.experts.up_gate_proj_weight ernie.layers.22.mlp.text_fused_moe.experts.down_proj_weight ernie.layers.22.mlp.text_fused_moe.gate.weight -ernie.layers.22.mlp.image_fused_moe.experts.gate_correction_bias ernie.layers.22.mlp.image_fused_moe.experts.up_gate_proj_weight_scale ernie.layers.22.mlp.image_fused_moe.experts.down_proj_weight_scale ernie.layers.22.mlp.image_fused_moe.experts.up_gate_proj_weight @@ -899,13 +877,12 @@ ernie.layers.23.self_attn.qkv_proj.weight ernie.layers.23.self_attn.qkv_proj.weight_scale ernie.layers.23.self_attn.o_proj.weight ernie.layers.23.self_attn.o_proj.weight_scale -ernie.layers.23.mlp.text_fused_moe.experts.gate_correction_bias +ernie.layers.23.mlp.gate_correction_bias ernie.layers.23.mlp.text_fused_moe.experts.up_gate_proj_weight_scale ernie.layers.23.mlp.text_fused_moe.experts.down_proj_weight_scale ernie.layers.23.mlp.text_fused_moe.experts.up_gate_proj_weight ernie.layers.23.mlp.text_fused_moe.experts.down_proj_weight ernie.layers.23.mlp.text_fused_moe.gate.weight -ernie.layers.23.mlp.image_fused_moe.experts.gate_correction_bias ernie.layers.23.mlp.image_fused_moe.experts.up_gate_proj_weight_scale ernie.layers.23.mlp.image_fused_moe.experts.down_proj_weight_scale ernie.layers.23.mlp.image_fused_moe.experts.up_gate_proj_weight @@ -921,13 +898,12 @@ ernie.layers.24.self_attn.qkv_proj.weight ernie.layers.24.self_attn.qkv_proj.weight_scale ernie.layers.24.self_attn.o_proj.weight ernie.layers.24.self_attn.o_proj.weight_scale -ernie.layers.24.mlp.text_fused_moe.experts.gate_correction_bias +ernie.layers.24.mlp.gate_correction_bias ernie.layers.24.mlp.text_fused_moe.experts.up_gate_proj_weight_scale ernie.layers.24.mlp.text_fused_moe.experts.down_proj_weight_scale ernie.layers.24.mlp.text_fused_moe.experts.up_gate_proj_weight ernie.layers.24.mlp.text_fused_moe.experts.down_proj_weight ernie.layers.24.mlp.text_fused_moe.gate.weight -ernie.layers.24.mlp.image_fused_moe.experts.gate_correction_bias ernie.layers.24.mlp.image_fused_moe.experts.up_gate_proj_weight_scale ernie.layers.24.mlp.image_fused_moe.experts.down_proj_weight_scale ernie.layers.24.mlp.image_fused_moe.experts.up_gate_proj_weight @@ -943,13 +919,12 @@ ernie.layers.25.self_attn.qkv_proj.weight ernie.layers.25.self_attn.qkv_proj.weight_scale ernie.layers.25.self_attn.o_proj.weight ernie.layers.25.self_attn.o_proj.weight_scale -ernie.layers.25.mlp.text_fused_moe.experts.gate_correction_bias +ernie.layers.25.mlp.gate_correction_bias ernie.layers.25.mlp.text_fused_moe.experts.up_gate_proj_weight_scale ernie.layers.25.mlp.text_fused_moe.experts.down_proj_weight_scale ernie.layers.25.mlp.text_fused_moe.experts.up_gate_proj_weight ernie.layers.25.mlp.text_fused_moe.experts.down_proj_weight ernie.layers.25.mlp.text_fused_moe.gate.weight -ernie.layers.25.mlp.image_fused_moe.experts.gate_correction_bias ernie.layers.25.mlp.image_fused_moe.experts.up_gate_proj_weight_scale ernie.layers.25.mlp.image_fused_moe.experts.down_proj_weight_scale ernie.layers.25.mlp.image_fused_moe.experts.up_gate_proj_weight @@ -965,13 +940,12 @@ ernie.layers.26.self_attn.qkv_proj.weight ernie.layers.26.self_attn.qkv_proj.weight_scale ernie.layers.26.self_attn.o_proj.weight ernie.layers.26.self_attn.o_proj.weight_scale -ernie.layers.26.mlp.text_fused_moe.experts.gate_correction_bias +ernie.layers.26.mlp.gate_correction_bias ernie.layers.26.mlp.text_fused_moe.experts.up_gate_proj_weight_scale ernie.layers.26.mlp.text_fused_moe.experts.down_proj_weight_scale ernie.layers.26.mlp.text_fused_moe.experts.up_gate_proj_weight ernie.layers.26.mlp.text_fused_moe.experts.down_proj_weight ernie.layers.26.mlp.text_fused_moe.gate.weight -ernie.layers.26.mlp.image_fused_moe.experts.gate_correction_bias ernie.layers.26.mlp.image_fused_moe.experts.up_gate_proj_weight_scale ernie.layers.26.mlp.image_fused_moe.experts.down_proj_weight_scale ernie.layers.26.mlp.image_fused_moe.experts.up_gate_proj_weight @@ -987,13 +961,12 @@ ernie.layers.27.self_attn.qkv_proj.weight ernie.layers.27.self_attn.qkv_proj.weight_scale ernie.layers.27.self_attn.o_proj.weight ernie.layers.27.self_attn.o_proj.weight_scale -ernie.layers.27.mlp.text_fused_moe.experts.gate_correction_bias +ernie.layers.27.mlp.gate_correction_bias ernie.layers.27.mlp.text_fused_moe.experts.up_gate_proj_weight_scale ernie.layers.27.mlp.text_fused_moe.experts.down_proj_weight_scale ernie.layers.27.mlp.text_fused_moe.experts.up_gate_proj_weight ernie.layers.27.mlp.text_fused_moe.experts.down_proj_weight ernie.layers.27.mlp.text_fused_moe.gate.weight -ernie.layers.27.mlp.image_fused_moe.experts.gate_correction_bias ernie.layers.27.mlp.image_fused_moe.experts.up_gate_proj_weight_scale ernie.layers.27.mlp.image_fused_moe.experts.down_proj_weight_scale ernie.layers.27.mlp.image_fused_moe.experts.up_gate_proj_weight @@ -1010,223 +983,196 @@ lm_head.linear.weight ernie.embed_tokens.embeddings.weight:ernie.embed_tokens.weight lm_head.linear.weight:lm_head.weight ernie.layers.1.mlp.text_fused_moe.gate.weight:ernie.layers.1.mlp.gate.weight -ernie.layers.1.mlp.text_fused_moe.experts.gate_correction_bias:ernie.layers.1.mlp.moe_statics.e_score_correction_bias +ernie.layers.1.mlp.gate_correction_bias:ernie.layers.1.mlp.moe_statics.e_score_correction_bias ernie.layers.1.mlp.text_fused_moe.experts.up_gate_proj_weight:['ernie.layers.1.mlp.experts.0.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.1.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.2.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.3.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.4.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.5.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.6.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.7.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.8.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.9.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.10.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.11.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.12.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.13.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.14.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.15.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.16.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.17.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.18.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.19.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.20.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.21.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.22.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.23.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.24.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.25.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.26.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.27.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.28.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.29.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.30.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.31.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.64.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.65.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.66.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.67.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.68.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.69.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.70.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.71.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.72.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.73.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.74.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.75.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.76.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.77.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.78.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.79.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.80.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.81.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.82.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.83.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.84.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.85.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.86.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.87.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.88.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.89.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.90.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.91.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.92.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.93.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.94.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.95.up_gate_proj.weight'] ernie.layers.1.mlp.text_fused_moe.experts.down_proj_weight:['ernie.layers.1.mlp.experts.0.down_proj.weight', 'ernie.layers.1.mlp.experts.1.down_proj.weight', 'ernie.layers.1.mlp.experts.2.down_proj.weight', 'ernie.layers.1.mlp.experts.3.down_proj.weight', 'ernie.layers.1.mlp.experts.4.down_proj.weight', 'ernie.layers.1.mlp.experts.5.down_proj.weight', 'ernie.layers.1.mlp.experts.6.down_proj.weight', 'ernie.layers.1.mlp.experts.7.down_proj.weight', 'ernie.layers.1.mlp.experts.8.down_proj.weight', 'ernie.layers.1.mlp.experts.9.down_proj.weight', 'ernie.layers.1.mlp.experts.10.down_proj.weight', 'ernie.layers.1.mlp.experts.11.down_proj.weight', 'ernie.layers.1.mlp.experts.12.down_proj.weight', 'ernie.layers.1.mlp.experts.13.down_proj.weight', 'ernie.layers.1.mlp.experts.14.down_proj.weight', 'ernie.layers.1.mlp.experts.15.down_proj.weight', 'ernie.layers.1.mlp.experts.16.down_proj.weight', 'ernie.layers.1.mlp.experts.17.down_proj.weight', 'ernie.layers.1.mlp.experts.18.down_proj.weight', 'ernie.layers.1.mlp.experts.19.down_proj.weight', 'ernie.layers.1.mlp.experts.20.down_proj.weight', 'ernie.layers.1.mlp.experts.21.down_proj.weight', 'ernie.layers.1.mlp.experts.22.down_proj.weight', 'ernie.layers.1.mlp.experts.23.down_proj.weight', 'ernie.layers.1.mlp.experts.24.down_proj.weight', 'ernie.layers.1.mlp.experts.25.down_proj.weight', 'ernie.layers.1.mlp.experts.26.down_proj.weight', 'ernie.layers.1.mlp.experts.27.down_proj.weight', 'ernie.layers.1.mlp.experts.28.down_proj.weight', 'ernie.layers.1.mlp.experts.29.down_proj.weight', 'ernie.layers.1.mlp.experts.30.down_proj.weight', 'ernie.layers.1.mlp.experts.31.down_proj.weight', 'ernie.layers.1.mlp.experts.64.down_proj.weight', 'ernie.layers.1.mlp.experts.65.down_proj.weight', 'ernie.layers.1.mlp.experts.66.down_proj.weight', 'ernie.layers.1.mlp.experts.67.down_proj.weight', 'ernie.layers.1.mlp.experts.68.down_proj.weight', 'ernie.layers.1.mlp.experts.69.down_proj.weight', 'ernie.layers.1.mlp.experts.70.down_proj.weight', 'ernie.layers.1.mlp.experts.71.down_proj.weight', 'ernie.layers.1.mlp.experts.72.down_proj.weight', 'ernie.layers.1.mlp.experts.73.down_proj.weight', 'ernie.layers.1.mlp.experts.74.down_proj.weight', 'ernie.layers.1.mlp.experts.75.down_proj.weight', 'ernie.layers.1.mlp.experts.76.down_proj.weight', 'ernie.layers.1.mlp.experts.77.down_proj.weight', 'ernie.layers.1.mlp.experts.78.down_proj.weight', 'ernie.layers.1.mlp.experts.79.down_proj.weight', 'ernie.layers.1.mlp.experts.80.down_proj.weight', 'ernie.layers.1.mlp.experts.81.down_proj.weight', 'ernie.layers.1.mlp.experts.82.down_proj.weight', 'ernie.layers.1.mlp.experts.83.down_proj.weight', 'ernie.layers.1.mlp.experts.84.down_proj.weight', 'ernie.layers.1.mlp.experts.85.down_proj.weight', 'ernie.layers.1.mlp.experts.86.down_proj.weight', 'ernie.layers.1.mlp.experts.87.down_proj.weight', 'ernie.layers.1.mlp.experts.88.down_proj.weight', 'ernie.layers.1.mlp.experts.89.down_proj.weight', 'ernie.layers.1.mlp.experts.90.down_proj.weight', 'ernie.layers.1.mlp.experts.91.down_proj.weight', 'ernie.layers.1.mlp.experts.92.down_proj.weight', 'ernie.layers.1.mlp.experts.93.down_proj.weight', 'ernie.layers.1.mlp.experts.94.down_proj.weight', 'ernie.layers.1.mlp.experts.95.down_proj.weight'] ernie.layers.2.mlp.text_fused_moe.gate.weight:ernie.layers.2.mlp.gate.weight -ernie.layers.2.mlp.text_fused_moe.experts.gate_correction_bias:ernie.layers.2.mlp.moe_statics.e_score_correction_bias +ernie.layers.2.mlp.gate_correction_bias:ernie.layers.2.mlp.moe_statics.e_score_correction_bias ernie.layers.2.mlp.text_fused_moe.experts.up_gate_proj_weight:['ernie.layers.2.mlp.experts.0.up_gate_proj.weight', 'ernie.layers.2.mlp.experts.1.up_gate_proj.weight', 'ernie.layers.2.mlp.experts.2.up_gate_proj.weight', 'ernie.layers.2.mlp.experts.3.up_gate_proj.weight', 'ernie.layers.2.mlp.experts.4.up_gate_proj.weight', 'ernie.layers.2.mlp.experts.5.up_gate_proj.weight', 'ernie.layers.2.mlp.experts.6.up_gate_proj.weight', 'ernie.layers.2.mlp.experts.7.up_gate_proj.weight', 'ernie.layers.2.mlp.experts.8.up_gate_proj.weight', 'ernie.layers.2.mlp.experts.9.up_gate_proj.weight', 'ernie.layers.2.mlp.experts.10.up_gate_proj.weight', 'ernie.layers.2.mlp.experts.11.up_gate_proj.weight', 'ernie.layers.2.mlp.experts.12.up_gate_proj.weight', 'ernie.layers.2.mlp.experts.13.up_gate_proj.weight', 'ernie.layers.2.mlp.experts.14.up_gate_proj.weight', 'ernie.layers.2.mlp.experts.15.up_gate_proj.weight', 'ernie.layers.2.mlp.experts.16.up_gate_proj.weight', 'ernie.layers.2.mlp.experts.17.up_gate_proj.weight', 'ernie.layers.2.mlp.experts.18.up_gate_proj.weight', 'ernie.layers.2.mlp.experts.19.up_gate_proj.weight', 'ernie.layers.2.mlp.experts.20.up_gate_proj.weight', 'ernie.layers.2.mlp.experts.21.up_gate_proj.weight', 'ernie.layers.2.mlp.experts.22.up_gate_proj.weight', 'ernie.layers.2.mlp.experts.23.up_gate_proj.weight', 'ernie.layers.2.mlp.experts.24.up_gate_proj.weight', 'ernie.layers.2.mlp.experts.25.up_gate_proj.weight', 'ernie.layers.2.mlp.experts.26.up_gate_proj.weight', 'ernie.layers.2.mlp.experts.27.up_gate_proj.weight', 'ernie.layers.2.mlp.experts.28.up_gate_proj.weight', 'ernie.layers.2.mlp.experts.29.up_gate_proj.weight', 'ernie.layers.2.mlp.experts.30.up_gate_proj.weight', 'ernie.layers.2.mlp.experts.31.up_gate_proj.weight', 'ernie.layers.2.mlp.experts.64.up_gate_proj.weight', 'ernie.layers.2.mlp.experts.65.up_gate_proj.weight', 'ernie.layers.2.mlp.experts.66.up_gate_proj.weight', 'ernie.layers.2.mlp.experts.67.up_gate_proj.weight', 'ernie.layers.2.mlp.experts.68.up_gate_proj.weight', 'ernie.layers.2.mlp.experts.69.up_gate_proj.weight', 'ernie.layers.2.mlp.experts.70.up_gate_proj.weight', 'ernie.layers.2.mlp.experts.71.up_gate_proj.weight', 'ernie.layers.2.mlp.experts.72.up_gate_proj.weight', 'ernie.layers.2.mlp.experts.73.up_gate_proj.weight', 'ernie.layers.2.mlp.experts.74.up_gate_proj.weight', 'ernie.layers.2.mlp.experts.75.up_gate_proj.weight', 'ernie.layers.2.mlp.experts.76.up_gate_proj.weight', 'ernie.layers.2.mlp.experts.77.up_gate_proj.weight', 'ernie.layers.2.mlp.experts.78.up_gate_proj.weight', 'ernie.layers.2.mlp.experts.79.up_gate_proj.weight', 'ernie.layers.2.mlp.experts.80.up_gate_proj.weight', 'ernie.layers.2.mlp.experts.81.up_gate_proj.weight', 'ernie.layers.2.mlp.experts.82.up_gate_proj.weight', 'ernie.layers.2.mlp.experts.83.up_gate_proj.weight', 'ernie.layers.2.mlp.experts.84.up_gate_proj.weight', 'ernie.layers.2.mlp.experts.85.up_gate_proj.weight', 'ernie.layers.2.mlp.experts.86.up_gate_proj.weight', 'ernie.layers.2.mlp.experts.87.up_gate_proj.weight', 'ernie.layers.2.mlp.experts.88.up_gate_proj.weight', 'ernie.layers.2.mlp.experts.89.up_gate_proj.weight', 'ernie.layers.2.mlp.experts.90.up_gate_proj.weight', 'ernie.layers.2.mlp.experts.91.up_gate_proj.weight', 'ernie.layers.2.mlp.experts.92.up_gate_proj.weight', 'ernie.layers.2.mlp.experts.93.up_gate_proj.weight', 'ernie.layers.2.mlp.experts.94.up_gate_proj.weight', 'ernie.layers.2.mlp.experts.95.up_gate_proj.weight'] ernie.layers.2.mlp.text_fused_moe.experts.down_proj_weight:['ernie.layers.2.mlp.experts.0.down_proj.weight', 'ernie.layers.2.mlp.experts.1.down_proj.weight', 'ernie.layers.2.mlp.experts.2.down_proj.weight', 'ernie.layers.2.mlp.experts.3.down_proj.weight', 'ernie.layers.2.mlp.experts.4.down_proj.weight', 'ernie.layers.2.mlp.experts.5.down_proj.weight', 'ernie.layers.2.mlp.experts.6.down_proj.weight', 'ernie.layers.2.mlp.experts.7.down_proj.weight', 'ernie.layers.2.mlp.experts.8.down_proj.weight', 'ernie.layers.2.mlp.experts.9.down_proj.weight', 'ernie.layers.2.mlp.experts.10.down_proj.weight', 'ernie.layers.2.mlp.experts.11.down_proj.weight', 'ernie.layers.2.mlp.experts.12.down_proj.weight', 'ernie.layers.2.mlp.experts.13.down_proj.weight', 'ernie.layers.2.mlp.experts.14.down_proj.weight', 'ernie.layers.2.mlp.experts.15.down_proj.weight', 'ernie.layers.2.mlp.experts.16.down_proj.weight', 'ernie.layers.2.mlp.experts.17.down_proj.weight', 'ernie.layers.2.mlp.experts.18.down_proj.weight', 'ernie.layers.2.mlp.experts.19.down_proj.weight', 'ernie.layers.2.mlp.experts.20.down_proj.weight', 'ernie.layers.2.mlp.experts.21.down_proj.weight', 'ernie.layers.2.mlp.experts.22.down_proj.weight', 'ernie.layers.2.mlp.experts.23.down_proj.weight', 'ernie.layers.2.mlp.experts.24.down_proj.weight', 'ernie.layers.2.mlp.experts.25.down_proj.weight', 'ernie.layers.2.mlp.experts.26.down_proj.weight', 'ernie.layers.2.mlp.experts.27.down_proj.weight', 'ernie.layers.2.mlp.experts.28.down_proj.weight', 'ernie.layers.2.mlp.experts.29.down_proj.weight', 'ernie.layers.2.mlp.experts.30.down_proj.weight', 'ernie.layers.2.mlp.experts.31.down_proj.weight', 'ernie.layers.2.mlp.experts.64.down_proj.weight', 'ernie.layers.2.mlp.experts.65.down_proj.weight', 'ernie.layers.2.mlp.experts.66.down_proj.weight', 'ernie.layers.2.mlp.experts.67.down_proj.weight', 'ernie.layers.2.mlp.experts.68.down_proj.weight', 'ernie.layers.2.mlp.experts.69.down_proj.weight', 'ernie.layers.2.mlp.experts.70.down_proj.weight', 'ernie.layers.2.mlp.experts.71.down_proj.weight', 'ernie.layers.2.mlp.experts.72.down_proj.weight', 'ernie.layers.2.mlp.experts.73.down_proj.weight', 'ernie.layers.2.mlp.experts.74.down_proj.weight', 'ernie.layers.2.mlp.experts.75.down_proj.weight', 'ernie.layers.2.mlp.experts.76.down_proj.weight', 'ernie.layers.2.mlp.experts.77.down_proj.weight', 'ernie.layers.2.mlp.experts.78.down_proj.weight', 'ernie.layers.2.mlp.experts.79.down_proj.weight', 'ernie.layers.2.mlp.experts.80.down_proj.weight', 'ernie.layers.2.mlp.experts.81.down_proj.weight', 'ernie.layers.2.mlp.experts.82.down_proj.weight', 'ernie.layers.2.mlp.experts.83.down_proj.weight', 'ernie.layers.2.mlp.experts.84.down_proj.weight', 'ernie.layers.2.mlp.experts.85.down_proj.weight', 'ernie.layers.2.mlp.experts.86.down_proj.weight', 'ernie.layers.2.mlp.experts.87.down_proj.weight', 'ernie.layers.2.mlp.experts.88.down_proj.weight', 'ernie.layers.2.mlp.experts.89.down_proj.weight', 'ernie.layers.2.mlp.experts.90.down_proj.weight', 'ernie.layers.2.mlp.experts.91.down_proj.weight', 'ernie.layers.2.mlp.experts.92.down_proj.weight', 'ernie.layers.2.mlp.experts.93.down_proj.weight', 'ernie.layers.2.mlp.experts.94.down_proj.weight', 'ernie.layers.2.mlp.experts.95.down_proj.weight'] ernie.layers.3.mlp.text_fused_moe.gate.weight:ernie.layers.3.mlp.gate.weight -ernie.layers.3.mlp.text_fused_moe.experts.gate_correction_bias:ernie.layers.3.mlp.moe_statics.e_score_correction_bias +ernie.layers.3.mlp.gate_correction_bias:ernie.layers.3.mlp.moe_statics.e_score_correction_bias ernie.layers.3.mlp.text_fused_moe.experts.up_gate_proj_weight:['ernie.layers.3.mlp.experts.0.up_gate_proj.weight', 'ernie.layers.3.mlp.experts.1.up_gate_proj.weight', 'ernie.layers.3.mlp.experts.2.up_gate_proj.weight', 'ernie.layers.3.mlp.experts.3.up_gate_proj.weight', 'ernie.layers.3.mlp.experts.4.up_gate_proj.weight', 'ernie.layers.3.mlp.experts.5.up_gate_proj.weight', 'ernie.layers.3.mlp.experts.6.up_gate_proj.weight', 'ernie.layers.3.mlp.experts.7.up_gate_proj.weight', 'ernie.layers.3.mlp.experts.8.up_gate_proj.weight', 'ernie.layers.3.mlp.experts.9.up_gate_proj.weight', 'ernie.layers.3.mlp.experts.10.up_gate_proj.weight', 'ernie.layers.3.mlp.experts.11.up_gate_proj.weight', 'ernie.layers.3.mlp.experts.12.up_gate_proj.weight', 'ernie.layers.3.mlp.experts.13.up_gate_proj.weight', 'ernie.layers.3.mlp.experts.14.up_gate_proj.weight', 'ernie.layers.3.mlp.experts.15.up_gate_proj.weight', 'ernie.layers.3.mlp.experts.16.up_gate_proj.weight', 'ernie.layers.3.mlp.experts.17.up_gate_proj.weight', 'ernie.layers.3.mlp.experts.18.up_gate_proj.weight', 'ernie.layers.3.mlp.experts.19.up_gate_proj.weight', 'ernie.layers.3.mlp.experts.20.up_gate_proj.weight', 'ernie.layers.3.mlp.experts.21.up_gate_proj.weight', 'ernie.layers.3.mlp.experts.22.up_gate_proj.weight', 'ernie.layers.3.mlp.experts.23.up_gate_proj.weight', 'ernie.layers.3.mlp.experts.24.up_gate_proj.weight', 'ernie.layers.3.mlp.experts.25.up_gate_proj.weight', 'ernie.layers.3.mlp.experts.26.up_gate_proj.weight', 'ernie.layers.3.mlp.experts.27.up_gate_proj.weight', 'ernie.layers.3.mlp.experts.28.up_gate_proj.weight', 'ernie.layers.3.mlp.experts.29.up_gate_proj.weight', 'ernie.layers.3.mlp.experts.30.up_gate_proj.weight', 'ernie.layers.3.mlp.experts.31.up_gate_proj.weight', 'ernie.layers.3.mlp.experts.64.up_gate_proj.weight', 'ernie.layers.3.mlp.experts.65.up_gate_proj.weight', 'ernie.layers.3.mlp.experts.66.up_gate_proj.weight', 'ernie.layers.3.mlp.experts.67.up_gate_proj.weight', 'ernie.layers.3.mlp.experts.68.up_gate_proj.weight', 'ernie.layers.3.mlp.experts.69.up_gate_proj.weight', 'ernie.layers.3.mlp.experts.70.up_gate_proj.weight', 'ernie.layers.3.mlp.experts.71.up_gate_proj.weight', 'ernie.layers.3.mlp.experts.72.up_gate_proj.weight', 'ernie.layers.3.mlp.experts.73.up_gate_proj.weight', 'ernie.layers.3.mlp.experts.74.up_gate_proj.weight', 'ernie.layers.3.mlp.experts.75.up_gate_proj.weight', 'ernie.layers.3.mlp.experts.76.up_gate_proj.weight', 'ernie.layers.3.mlp.experts.77.up_gate_proj.weight', 'ernie.layers.3.mlp.experts.78.up_gate_proj.weight', 'ernie.layers.3.mlp.experts.79.up_gate_proj.weight', 'ernie.layers.3.mlp.experts.80.up_gate_proj.weight', 'ernie.layers.3.mlp.experts.81.up_gate_proj.weight', 'ernie.layers.3.mlp.experts.82.up_gate_proj.weight', 'ernie.layers.3.mlp.experts.83.up_gate_proj.weight', 'ernie.layers.3.mlp.experts.84.up_gate_proj.weight', 'ernie.layers.3.mlp.experts.85.up_gate_proj.weight', 'ernie.layers.3.mlp.experts.86.up_gate_proj.weight', 'ernie.layers.3.mlp.experts.87.up_gate_proj.weight', 'ernie.layers.3.mlp.experts.88.up_gate_proj.weight', 'ernie.layers.3.mlp.experts.89.up_gate_proj.weight', 'ernie.layers.3.mlp.experts.90.up_gate_proj.weight', 'ernie.layers.3.mlp.experts.91.up_gate_proj.weight', 'ernie.layers.3.mlp.experts.92.up_gate_proj.weight', 'ernie.layers.3.mlp.experts.93.up_gate_proj.weight', 'ernie.layers.3.mlp.experts.94.up_gate_proj.weight', 'ernie.layers.3.mlp.experts.95.up_gate_proj.weight'] ernie.layers.3.mlp.text_fused_moe.experts.down_proj_weight:['ernie.layers.3.mlp.experts.0.down_proj.weight', 'ernie.layers.3.mlp.experts.1.down_proj.weight', 'ernie.layers.3.mlp.experts.2.down_proj.weight', 'ernie.layers.3.mlp.experts.3.down_proj.weight', 'ernie.layers.3.mlp.experts.4.down_proj.weight', 'ernie.layers.3.mlp.experts.5.down_proj.weight', 'ernie.layers.3.mlp.experts.6.down_proj.weight', 'ernie.layers.3.mlp.experts.7.down_proj.weight', 'ernie.layers.3.mlp.experts.8.down_proj.weight', 'ernie.layers.3.mlp.experts.9.down_proj.weight', 'ernie.layers.3.mlp.experts.10.down_proj.weight', 'ernie.layers.3.mlp.experts.11.down_proj.weight', 'ernie.layers.3.mlp.experts.12.down_proj.weight', 'ernie.layers.3.mlp.experts.13.down_proj.weight', 'ernie.layers.3.mlp.experts.14.down_proj.weight', 'ernie.layers.3.mlp.experts.15.down_proj.weight', 'ernie.layers.3.mlp.experts.16.down_proj.weight', 'ernie.layers.3.mlp.experts.17.down_proj.weight', 'ernie.layers.3.mlp.experts.18.down_proj.weight', 'ernie.layers.3.mlp.experts.19.down_proj.weight', 'ernie.layers.3.mlp.experts.20.down_proj.weight', 'ernie.layers.3.mlp.experts.21.down_proj.weight', 'ernie.layers.3.mlp.experts.22.down_proj.weight', 'ernie.layers.3.mlp.experts.23.down_proj.weight', 'ernie.layers.3.mlp.experts.24.down_proj.weight', 'ernie.layers.3.mlp.experts.25.down_proj.weight', 'ernie.layers.3.mlp.experts.26.down_proj.weight', 'ernie.layers.3.mlp.experts.27.down_proj.weight', 'ernie.layers.3.mlp.experts.28.down_proj.weight', 'ernie.layers.3.mlp.experts.29.down_proj.weight', 'ernie.layers.3.mlp.experts.30.down_proj.weight', 'ernie.layers.3.mlp.experts.31.down_proj.weight', 'ernie.layers.3.mlp.experts.64.down_proj.weight', 'ernie.layers.3.mlp.experts.65.down_proj.weight', 'ernie.layers.3.mlp.experts.66.down_proj.weight', 'ernie.layers.3.mlp.experts.67.down_proj.weight', 'ernie.layers.3.mlp.experts.68.down_proj.weight', 'ernie.layers.3.mlp.experts.69.down_proj.weight', 'ernie.layers.3.mlp.experts.70.down_proj.weight', 'ernie.layers.3.mlp.experts.71.down_proj.weight', 'ernie.layers.3.mlp.experts.72.down_proj.weight', 'ernie.layers.3.mlp.experts.73.down_proj.weight', 'ernie.layers.3.mlp.experts.74.down_proj.weight', 'ernie.layers.3.mlp.experts.75.down_proj.weight', 'ernie.layers.3.mlp.experts.76.down_proj.weight', 'ernie.layers.3.mlp.experts.77.down_proj.weight', 'ernie.layers.3.mlp.experts.78.down_proj.weight', 'ernie.layers.3.mlp.experts.79.down_proj.weight', 'ernie.layers.3.mlp.experts.80.down_proj.weight', 'ernie.layers.3.mlp.experts.81.down_proj.weight', 'ernie.layers.3.mlp.experts.82.down_proj.weight', 'ernie.layers.3.mlp.experts.83.down_proj.weight', 'ernie.layers.3.mlp.experts.84.down_proj.weight', 'ernie.layers.3.mlp.experts.85.down_proj.weight', 'ernie.layers.3.mlp.experts.86.down_proj.weight', 'ernie.layers.3.mlp.experts.87.down_proj.weight', 'ernie.layers.3.mlp.experts.88.down_proj.weight', 'ernie.layers.3.mlp.experts.89.down_proj.weight', 'ernie.layers.3.mlp.experts.90.down_proj.weight', 'ernie.layers.3.mlp.experts.91.down_proj.weight', 'ernie.layers.3.mlp.experts.92.down_proj.weight', 'ernie.layers.3.mlp.experts.93.down_proj.weight', 'ernie.layers.3.mlp.experts.94.down_proj.weight', 'ernie.layers.3.mlp.experts.95.down_proj.weight'] ernie.layers.4.mlp.text_fused_moe.gate.weight:ernie.layers.4.mlp.gate.weight -ernie.layers.4.mlp.text_fused_moe.experts.gate_correction_bias:ernie.layers.4.mlp.moe_statics.e_score_correction_bias +ernie.layers.4.mlp.gate_correction_bias:ernie.layers.4.mlp.moe_statics.e_score_correction_bias ernie.layers.4.mlp.text_fused_moe.experts.up_gate_proj_weight:['ernie.layers.4.mlp.experts.0.up_gate_proj.weight', 'ernie.layers.4.mlp.experts.1.up_gate_proj.weight', 'ernie.layers.4.mlp.experts.2.up_gate_proj.weight', 'ernie.layers.4.mlp.experts.3.up_gate_proj.weight', 'ernie.layers.4.mlp.experts.4.up_gate_proj.weight', 'ernie.layers.4.mlp.experts.5.up_gate_proj.weight', 'ernie.layers.4.mlp.experts.6.up_gate_proj.weight', 'ernie.layers.4.mlp.experts.7.up_gate_proj.weight', 'ernie.layers.4.mlp.experts.8.up_gate_proj.weight', 'ernie.layers.4.mlp.experts.9.up_gate_proj.weight', 'ernie.layers.4.mlp.experts.10.up_gate_proj.weight', 'ernie.layers.4.mlp.experts.11.up_gate_proj.weight', 'ernie.layers.4.mlp.experts.12.up_gate_proj.weight', 'ernie.layers.4.mlp.experts.13.up_gate_proj.weight', 'ernie.layers.4.mlp.experts.14.up_gate_proj.weight', 'ernie.layers.4.mlp.experts.15.up_gate_proj.weight', 'ernie.layers.4.mlp.experts.16.up_gate_proj.weight', 'ernie.layers.4.mlp.experts.17.up_gate_proj.weight', 'ernie.layers.4.mlp.experts.18.up_gate_proj.weight', 'ernie.layers.4.mlp.experts.19.up_gate_proj.weight', 'ernie.layers.4.mlp.experts.20.up_gate_proj.weight', 'ernie.layers.4.mlp.experts.21.up_gate_proj.weight', 'ernie.layers.4.mlp.experts.22.up_gate_proj.weight', 'ernie.layers.4.mlp.experts.23.up_gate_proj.weight', 'ernie.layers.4.mlp.experts.24.up_gate_proj.weight', 'ernie.layers.4.mlp.experts.25.up_gate_proj.weight', 'ernie.layers.4.mlp.experts.26.up_gate_proj.weight', 'ernie.layers.4.mlp.experts.27.up_gate_proj.weight', 'ernie.layers.4.mlp.experts.28.up_gate_proj.weight', 'ernie.layers.4.mlp.experts.29.up_gate_proj.weight', 'ernie.layers.4.mlp.experts.30.up_gate_proj.weight', 'ernie.layers.4.mlp.experts.31.up_gate_proj.weight', 'ernie.layers.4.mlp.experts.64.up_gate_proj.weight', 'ernie.layers.4.mlp.experts.65.up_gate_proj.weight', 'ernie.layers.4.mlp.experts.66.up_gate_proj.weight', 'ernie.layers.4.mlp.experts.67.up_gate_proj.weight', 'ernie.layers.4.mlp.experts.68.up_gate_proj.weight', 'ernie.layers.4.mlp.experts.69.up_gate_proj.weight', 'ernie.layers.4.mlp.experts.70.up_gate_proj.weight', 'ernie.layers.4.mlp.experts.71.up_gate_proj.weight', 'ernie.layers.4.mlp.experts.72.up_gate_proj.weight', 'ernie.layers.4.mlp.experts.73.up_gate_proj.weight', 'ernie.layers.4.mlp.experts.74.up_gate_proj.weight', 'ernie.layers.4.mlp.experts.75.up_gate_proj.weight', 'ernie.layers.4.mlp.experts.76.up_gate_proj.weight', 'ernie.layers.4.mlp.experts.77.up_gate_proj.weight', 'ernie.layers.4.mlp.experts.78.up_gate_proj.weight', 'ernie.layers.4.mlp.experts.79.up_gate_proj.weight', 'ernie.layers.4.mlp.experts.80.up_gate_proj.weight', 'ernie.layers.4.mlp.experts.81.up_gate_proj.weight', 'ernie.layers.4.mlp.experts.82.up_gate_proj.weight', 'ernie.layers.4.mlp.experts.83.up_gate_proj.weight', 'ernie.layers.4.mlp.experts.84.up_gate_proj.weight', 'ernie.layers.4.mlp.experts.85.up_gate_proj.weight', 'ernie.layers.4.mlp.experts.86.up_gate_proj.weight', 'ernie.layers.4.mlp.experts.87.up_gate_proj.weight', 'ernie.layers.4.mlp.experts.88.up_gate_proj.weight', 'ernie.layers.4.mlp.experts.89.up_gate_proj.weight', 'ernie.layers.4.mlp.experts.90.up_gate_proj.weight', 'ernie.layers.4.mlp.experts.91.up_gate_proj.weight', 'ernie.layers.4.mlp.experts.92.up_gate_proj.weight', 'ernie.layers.4.mlp.experts.93.up_gate_proj.weight', 'ernie.layers.4.mlp.experts.94.up_gate_proj.weight', 'ernie.layers.4.mlp.experts.95.up_gate_proj.weight'] ernie.layers.4.mlp.text_fused_moe.experts.down_proj_weight:['ernie.layers.4.mlp.experts.0.down_proj.weight', 'ernie.layers.4.mlp.experts.1.down_proj.weight', 'ernie.layers.4.mlp.experts.2.down_proj.weight', 'ernie.layers.4.mlp.experts.3.down_proj.weight', 'ernie.layers.4.mlp.experts.4.down_proj.weight', 'ernie.layers.4.mlp.experts.5.down_proj.weight', 'ernie.layers.4.mlp.experts.6.down_proj.weight', 'ernie.layers.4.mlp.experts.7.down_proj.weight', 'ernie.layers.4.mlp.experts.8.down_proj.weight', 'ernie.layers.4.mlp.experts.9.down_proj.weight', 'ernie.layers.4.mlp.experts.10.down_proj.weight', 'ernie.layers.4.mlp.experts.11.down_proj.weight', 'ernie.layers.4.mlp.experts.12.down_proj.weight', 'ernie.layers.4.mlp.experts.13.down_proj.weight', 'ernie.layers.4.mlp.experts.14.down_proj.weight', 'ernie.layers.4.mlp.experts.15.down_proj.weight', 'ernie.layers.4.mlp.experts.16.down_proj.weight', 'ernie.layers.4.mlp.experts.17.down_proj.weight', 'ernie.layers.4.mlp.experts.18.down_proj.weight', 'ernie.layers.4.mlp.experts.19.down_proj.weight', 'ernie.layers.4.mlp.experts.20.down_proj.weight', 'ernie.layers.4.mlp.experts.21.down_proj.weight', 'ernie.layers.4.mlp.experts.22.down_proj.weight', 'ernie.layers.4.mlp.experts.23.down_proj.weight', 'ernie.layers.4.mlp.experts.24.down_proj.weight', 'ernie.layers.4.mlp.experts.25.down_proj.weight', 'ernie.layers.4.mlp.experts.26.down_proj.weight', 'ernie.layers.4.mlp.experts.27.down_proj.weight', 'ernie.layers.4.mlp.experts.28.down_proj.weight', 'ernie.layers.4.mlp.experts.29.down_proj.weight', 'ernie.layers.4.mlp.experts.30.down_proj.weight', 'ernie.layers.4.mlp.experts.31.down_proj.weight', 'ernie.layers.4.mlp.experts.64.down_proj.weight', 'ernie.layers.4.mlp.experts.65.down_proj.weight', 'ernie.layers.4.mlp.experts.66.down_proj.weight', 'ernie.layers.4.mlp.experts.67.down_proj.weight', 'ernie.layers.4.mlp.experts.68.down_proj.weight', 'ernie.layers.4.mlp.experts.69.down_proj.weight', 'ernie.layers.4.mlp.experts.70.down_proj.weight', 'ernie.layers.4.mlp.experts.71.down_proj.weight', 'ernie.layers.4.mlp.experts.72.down_proj.weight', 'ernie.layers.4.mlp.experts.73.down_proj.weight', 'ernie.layers.4.mlp.experts.74.down_proj.weight', 'ernie.layers.4.mlp.experts.75.down_proj.weight', 'ernie.layers.4.mlp.experts.76.down_proj.weight', 'ernie.layers.4.mlp.experts.77.down_proj.weight', 'ernie.layers.4.mlp.experts.78.down_proj.weight', 'ernie.layers.4.mlp.experts.79.down_proj.weight', 'ernie.layers.4.mlp.experts.80.down_proj.weight', 'ernie.layers.4.mlp.experts.81.down_proj.weight', 'ernie.layers.4.mlp.experts.82.down_proj.weight', 'ernie.layers.4.mlp.experts.83.down_proj.weight', 'ernie.layers.4.mlp.experts.84.down_proj.weight', 'ernie.layers.4.mlp.experts.85.down_proj.weight', 'ernie.layers.4.mlp.experts.86.down_proj.weight', 'ernie.layers.4.mlp.experts.87.down_proj.weight', 'ernie.layers.4.mlp.experts.88.down_proj.weight', 'ernie.layers.4.mlp.experts.89.down_proj.weight', 'ernie.layers.4.mlp.experts.90.down_proj.weight', 'ernie.layers.4.mlp.experts.91.down_proj.weight', 'ernie.layers.4.mlp.experts.92.down_proj.weight', 'ernie.layers.4.mlp.experts.93.down_proj.weight', 'ernie.layers.4.mlp.experts.94.down_proj.weight', 'ernie.layers.4.mlp.experts.95.down_proj.weight'] ernie.layers.5.mlp.text_fused_moe.gate.weight:ernie.layers.5.mlp.gate.weight -ernie.layers.5.mlp.text_fused_moe.experts.gate_correction_bias:ernie.layers.5.mlp.moe_statics.e_score_correction_bias +ernie.layers.5.mlp.gate_correction_bias:ernie.layers.5.mlp.moe_statics.e_score_correction_bias ernie.layers.5.mlp.text_fused_moe.experts.up_gate_proj_weight:['ernie.layers.5.mlp.experts.0.up_gate_proj.weight', 'ernie.layers.5.mlp.experts.1.up_gate_proj.weight', 'ernie.layers.5.mlp.experts.2.up_gate_proj.weight', 'ernie.layers.5.mlp.experts.3.up_gate_proj.weight', 'ernie.layers.5.mlp.experts.4.up_gate_proj.weight', 'ernie.layers.5.mlp.experts.5.up_gate_proj.weight', 'ernie.layers.5.mlp.experts.6.up_gate_proj.weight', 'ernie.layers.5.mlp.experts.7.up_gate_proj.weight', 'ernie.layers.5.mlp.experts.8.up_gate_proj.weight', 'ernie.layers.5.mlp.experts.9.up_gate_proj.weight', 'ernie.layers.5.mlp.experts.10.up_gate_proj.weight', 'ernie.layers.5.mlp.experts.11.up_gate_proj.weight', 'ernie.layers.5.mlp.experts.12.up_gate_proj.weight', 'ernie.layers.5.mlp.experts.13.up_gate_proj.weight', 'ernie.layers.5.mlp.experts.14.up_gate_proj.weight', 'ernie.layers.5.mlp.experts.15.up_gate_proj.weight', 'ernie.layers.5.mlp.experts.16.up_gate_proj.weight', 'ernie.layers.5.mlp.experts.17.up_gate_proj.weight', 'ernie.layers.5.mlp.experts.18.up_gate_proj.weight', 'ernie.layers.5.mlp.experts.19.up_gate_proj.weight', 'ernie.layers.5.mlp.experts.20.up_gate_proj.weight', 'ernie.layers.5.mlp.experts.21.up_gate_proj.weight', 'ernie.layers.5.mlp.experts.22.up_gate_proj.weight', 'ernie.layers.5.mlp.experts.23.up_gate_proj.weight', 'ernie.layers.5.mlp.experts.24.up_gate_proj.weight', 'ernie.layers.5.mlp.experts.25.up_gate_proj.weight', 'ernie.layers.5.mlp.experts.26.up_gate_proj.weight', 'ernie.layers.5.mlp.experts.27.up_gate_proj.weight', 'ernie.layers.5.mlp.experts.28.up_gate_proj.weight', 'ernie.layers.5.mlp.experts.29.up_gate_proj.weight', 'ernie.layers.5.mlp.experts.30.up_gate_proj.weight', 'ernie.layers.5.mlp.experts.31.up_gate_proj.weight', 'ernie.layers.5.mlp.experts.64.up_gate_proj.weight', 'ernie.layers.5.mlp.experts.65.up_gate_proj.weight', 'ernie.layers.5.mlp.experts.66.up_gate_proj.weight', 'ernie.layers.5.mlp.experts.67.up_gate_proj.weight', 'ernie.layers.5.mlp.experts.68.up_gate_proj.weight', 'ernie.layers.5.mlp.experts.69.up_gate_proj.weight', 'ernie.layers.5.mlp.experts.70.up_gate_proj.weight', 'ernie.layers.5.mlp.experts.71.up_gate_proj.weight', 'ernie.layers.5.mlp.experts.72.up_gate_proj.weight', 'ernie.layers.5.mlp.experts.73.up_gate_proj.weight', 'ernie.layers.5.mlp.experts.74.up_gate_proj.weight', 'ernie.layers.5.mlp.experts.75.up_gate_proj.weight', 'ernie.layers.5.mlp.experts.76.up_gate_proj.weight', 'ernie.layers.5.mlp.experts.77.up_gate_proj.weight', 'ernie.layers.5.mlp.experts.78.up_gate_proj.weight', 'ernie.layers.5.mlp.experts.79.up_gate_proj.weight', 'ernie.layers.5.mlp.experts.80.up_gate_proj.weight', 'ernie.layers.5.mlp.experts.81.up_gate_proj.weight', 'ernie.layers.5.mlp.experts.82.up_gate_proj.weight', 'ernie.layers.5.mlp.experts.83.up_gate_proj.weight', 'ernie.layers.5.mlp.experts.84.up_gate_proj.weight', 'ernie.layers.5.mlp.experts.85.up_gate_proj.weight', 'ernie.layers.5.mlp.experts.86.up_gate_proj.weight', 'ernie.layers.5.mlp.experts.87.up_gate_proj.weight', 'ernie.layers.5.mlp.experts.88.up_gate_proj.weight', 'ernie.layers.5.mlp.experts.89.up_gate_proj.weight', 'ernie.layers.5.mlp.experts.90.up_gate_proj.weight', 'ernie.layers.5.mlp.experts.91.up_gate_proj.weight', 'ernie.layers.5.mlp.experts.92.up_gate_proj.weight', 'ernie.layers.5.mlp.experts.93.up_gate_proj.weight', 'ernie.layers.5.mlp.experts.94.up_gate_proj.weight', 'ernie.layers.5.mlp.experts.95.up_gate_proj.weight'] ernie.layers.5.mlp.text_fused_moe.experts.down_proj_weight:['ernie.layers.5.mlp.experts.0.down_proj.weight', 'ernie.layers.5.mlp.experts.1.down_proj.weight', 'ernie.layers.5.mlp.experts.2.down_proj.weight', 'ernie.layers.5.mlp.experts.3.down_proj.weight', 'ernie.layers.5.mlp.experts.4.down_proj.weight', 'ernie.layers.5.mlp.experts.5.down_proj.weight', 'ernie.layers.5.mlp.experts.6.down_proj.weight', 'ernie.layers.5.mlp.experts.7.down_proj.weight', 'ernie.layers.5.mlp.experts.8.down_proj.weight', 'ernie.layers.5.mlp.experts.9.down_proj.weight', 'ernie.layers.5.mlp.experts.10.down_proj.weight', 'ernie.layers.5.mlp.experts.11.down_proj.weight', 'ernie.layers.5.mlp.experts.12.down_proj.weight', 'ernie.layers.5.mlp.experts.13.down_proj.weight', 'ernie.layers.5.mlp.experts.14.down_proj.weight', 'ernie.layers.5.mlp.experts.15.down_proj.weight', 'ernie.layers.5.mlp.experts.16.down_proj.weight', 'ernie.layers.5.mlp.experts.17.down_proj.weight', 'ernie.layers.5.mlp.experts.18.down_proj.weight', 'ernie.layers.5.mlp.experts.19.down_proj.weight', 'ernie.layers.5.mlp.experts.20.down_proj.weight', 'ernie.layers.5.mlp.experts.21.down_proj.weight', 'ernie.layers.5.mlp.experts.22.down_proj.weight', 'ernie.layers.5.mlp.experts.23.down_proj.weight', 'ernie.layers.5.mlp.experts.24.down_proj.weight', 'ernie.layers.5.mlp.experts.25.down_proj.weight', 'ernie.layers.5.mlp.experts.26.down_proj.weight', 'ernie.layers.5.mlp.experts.27.down_proj.weight', 'ernie.layers.5.mlp.experts.28.down_proj.weight', 'ernie.layers.5.mlp.experts.29.down_proj.weight', 'ernie.layers.5.mlp.experts.30.down_proj.weight', 'ernie.layers.5.mlp.experts.31.down_proj.weight', 'ernie.layers.5.mlp.experts.64.down_proj.weight', 'ernie.layers.5.mlp.experts.65.down_proj.weight', 'ernie.layers.5.mlp.experts.66.down_proj.weight', 'ernie.layers.5.mlp.experts.67.down_proj.weight', 'ernie.layers.5.mlp.experts.68.down_proj.weight', 'ernie.layers.5.mlp.experts.69.down_proj.weight', 'ernie.layers.5.mlp.experts.70.down_proj.weight', 'ernie.layers.5.mlp.experts.71.down_proj.weight', 'ernie.layers.5.mlp.experts.72.down_proj.weight', 'ernie.layers.5.mlp.experts.73.down_proj.weight', 'ernie.layers.5.mlp.experts.74.down_proj.weight', 'ernie.layers.5.mlp.experts.75.down_proj.weight', 'ernie.layers.5.mlp.experts.76.down_proj.weight', 'ernie.layers.5.mlp.experts.77.down_proj.weight', 'ernie.layers.5.mlp.experts.78.down_proj.weight', 'ernie.layers.5.mlp.experts.79.down_proj.weight', 'ernie.layers.5.mlp.experts.80.down_proj.weight', 'ernie.layers.5.mlp.experts.81.down_proj.weight', 'ernie.layers.5.mlp.experts.82.down_proj.weight', 'ernie.layers.5.mlp.experts.83.down_proj.weight', 'ernie.layers.5.mlp.experts.84.down_proj.weight', 'ernie.layers.5.mlp.experts.85.down_proj.weight', 'ernie.layers.5.mlp.experts.86.down_proj.weight', 'ernie.layers.5.mlp.experts.87.down_proj.weight', 'ernie.layers.5.mlp.experts.88.down_proj.weight', 'ernie.layers.5.mlp.experts.89.down_proj.weight', 'ernie.layers.5.mlp.experts.90.down_proj.weight', 'ernie.layers.5.mlp.experts.91.down_proj.weight', 'ernie.layers.5.mlp.experts.92.down_proj.weight', 'ernie.layers.5.mlp.experts.93.down_proj.weight', 'ernie.layers.5.mlp.experts.94.down_proj.weight', 'ernie.layers.5.mlp.experts.95.down_proj.weight'] ernie.layers.6.mlp.text_fused_moe.gate.weight:ernie.layers.6.mlp.gate.weight -ernie.layers.6.mlp.text_fused_moe.experts.gate_correction_bias:ernie.layers.6.mlp.moe_statics.e_score_correction_bias +ernie.layers.6.mlp.gate_correction_bias:ernie.layers.6.mlp.moe_statics.e_score_correction_bias ernie.layers.6.mlp.text_fused_moe.experts.up_gate_proj_weight:['ernie.layers.6.mlp.experts.0.up_gate_proj.weight', 'ernie.layers.6.mlp.experts.1.up_gate_proj.weight', 'ernie.layers.6.mlp.experts.2.up_gate_proj.weight', 'ernie.layers.6.mlp.experts.3.up_gate_proj.weight', 'ernie.layers.6.mlp.experts.4.up_gate_proj.weight', 'ernie.layers.6.mlp.experts.5.up_gate_proj.weight', 'ernie.layers.6.mlp.experts.6.up_gate_proj.weight', 'ernie.layers.6.mlp.experts.7.up_gate_proj.weight', 'ernie.layers.6.mlp.experts.8.up_gate_proj.weight', 'ernie.layers.6.mlp.experts.9.up_gate_proj.weight', 'ernie.layers.6.mlp.experts.10.up_gate_proj.weight', 'ernie.layers.6.mlp.experts.11.up_gate_proj.weight', 'ernie.layers.6.mlp.experts.12.up_gate_proj.weight', 'ernie.layers.6.mlp.experts.13.up_gate_proj.weight', 'ernie.layers.6.mlp.experts.14.up_gate_proj.weight', 'ernie.layers.6.mlp.experts.15.up_gate_proj.weight', 'ernie.layers.6.mlp.experts.16.up_gate_proj.weight', 'ernie.layers.6.mlp.experts.17.up_gate_proj.weight', 'ernie.layers.6.mlp.experts.18.up_gate_proj.weight', 'ernie.layers.6.mlp.experts.19.up_gate_proj.weight', 'ernie.layers.6.mlp.experts.20.up_gate_proj.weight', 'ernie.layers.6.mlp.experts.21.up_gate_proj.weight', 'ernie.layers.6.mlp.experts.22.up_gate_proj.weight', 'ernie.layers.6.mlp.experts.23.up_gate_proj.weight', 'ernie.layers.6.mlp.experts.24.up_gate_proj.weight', 'ernie.layers.6.mlp.experts.25.up_gate_proj.weight', 'ernie.layers.6.mlp.experts.26.up_gate_proj.weight', 'ernie.layers.6.mlp.experts.27.up_gate_proj.weight', 'ernie.layers.6.mlp.experts.28.up_gate_proj.weight', 'ernie.layers.6.mlp.experts.29.up_gate_proj.weight', 'ernie.layers.6.mlp.experts.30.up_gate_proj.weight', 'ernie.layers.6.mlp.experts.31.up_gate_proj.weight', 'ernie.layers.6.mlp.experts.64.up_gate_proj.weight', 'ernie.layers.6.mlp.experts.65.up_gate_proj.weight', 'ernie.layers.6.mlp.experts.66.up_gate_proj.weight', 'ernie.layers.6.mlp.experts.67.up_gate_proj.weight', 'ernie.layers.6.mlp.experts.68.up_gate_proj.weight', 'ernie.layers.6.mlp.experts.69.up_gate_proj.weight', 'ernie.layers.6.mlp.experts.70.up_gate_proj.weight', 'ernie.layers.6.mlp.experts.71.up_gate_proj.weight', 'ernie.layers.6.mlp.experts.72.up_gate_proj.weight', 'ernie.layers.6.mlp.experts.73.up_gate_proj.weight', 'ernie.layers.6.mlp.experts.74.up_gate_proj.weight', 'ernie.layers.6.mlp.experts.75.up_gate_proj.weight', 'ernie.layers.6.mlp.experts.76.up_gate_proj.weight', 'ernie.layers.6.mlp.experts.77.up_gate_proj.weight', 'ernie.layers.6.mlp.experts.78.up_gate_proj.weight', 'ernie.layers.6.mlp.experts.79.up_gate_proj.weight', 'ernie.layers.6.mlp.experts.80.up_gate_proj.weight', 'ernie.layers.6.mlp.experts.81.up_gate_proj.weight', 'ernie.layers.6.mlp.experts.82.up_gate_proj.weight', 'ernie.layers.6.mlp.experts.83.up_gate_proj.weight', 'ernie.layers.6.mlp.experts.84.up_gate_proj.weight', 'ernie.layers.6.mlp.experts.85.up_gate_proj.weight', 'ernie.layers.6.mlp.experts.86.up_gate_proj.weight', 'ernie.layers.6.mlp.experts.87.up_gate_proj.weight', 'ernie.layers.6.mlp.experts.88.up_gate_proj.weight', 'ernie.layers.6.mlp.experts.89.up_gate_proj.weight', 'ernie.layers.6.mlp.experts.90.up_gate_proj.weight', 'ernie.layers.6.mlp.experts.91.up_gate_proj.weight', 'ernie.layers.6.mlp.experts.92.up_gate_proj.weight', 'ernie.layers.6.mlp.experts.93.up_gate_proj.weight', 'ernie.layers.6.mlp.experts.94.up_gate_proj.weight', 'ernie.layers.6.mlp.experts.95.up_gate_proj.weight'] ernie.layers.6.mlp.text_fused_moe.experts.down_proj_weight:['ernie.layers.6.mlp.experts.0.down_proj.weight', 'ernie.layers.6.mlp.experts.1.down_proj.weight', 'ernie.layers.6.mlp.experts.2.down_proj.weight', 'ernie.layers.6.mlp.experts.3.down_proj.weight', 'ernie.layers.6.mlp.experts.4.down_proj.weight', 'ernie.layers.6.mlp.experts.5.down_proj.weight', 'ernie.layers.6.mlp.experts.6.down_proj.weight', 'ernie.layers.6.mlp.experts.7.down_proj.weight', 'ernie.layers.6.mlp.experts.8.down_proj.weight', 'ernie.layers.6.mlp.experts.9.down_proj.weight', 'ernie.layers.6.mlp.experts.10.down_proj.weight', 'ernie.layers.6.mlp.experts.11.down_proj.weight', 'ernie.layers.6.mlp.experts.12.down_proj.weight', 'ernie.layers.6.mlp.experts.13.down_proj.weight', 'ernie.layers.6.mlp.experts.14.down_proj.weight', 'ernie.layers.6.mlp.experts.15.down_proj.weight', 'ernie.layers.6.mlp.experts.16.down_proj.weight', 'ernie.layers.6.mlp.experts.17.down_proj.weight', 'ernie.layers.6.mlp.experts.18.down_proj.weight', 'ernie.layers.6.mlp.experts.19.down_proj.weight', 'ernie.layers.6.mlp.experts.20.down_proj.weight', 'ernie.layers.6.mlp.experts.21.down_proj.weight', 'ernie.layers.6.mlp.experts.22.down_proj.weight', 'ernie.layers.6.mlp.experts.23.down_proj.weight', 'ernie.layers.6.mlp.experts.24.down_proj.weight', 'ernie.layers.6.mlp.experts.25.down_proj.weight', 'ernie.layers.6.mlp.experts.26.down_proj.weight', 'ernie.layers.6.mlp.experts.27.down_proj.weight', 'ernie.layers.6.mlp.experts.28.down_proj.weight', 'ernie.layers.6.mlp.experts.29.down_proj.weight', 'ernie.layers.6.mlp.experts.30.down_proj.weight', 'ernie.layers.6.mlp.experts.31.down_proj.weight', 'ernie.layers.6.mlp.experts.64.down_proj.weight', 'ernie.layers.6.mlp.experts.65.down_proj.weight', 'ernie.layers.6.mlp.experts.66.down_proj.weight', 'ernie.layers.6.mlp.experts.67.down_proj.weight', 'ernie.layers.6.mlp.experts.68.down_proj.weight', 'ernie.layers.6.mlp.experts.69.down_proj.weight', 'ernie.layers.6.mlp.experts.70.down_proj.weight', 'ernie.layers.6.mlp.experts.71.down_proj.weight', 'ernie.layers.6.mlp.experts.72.down_proj.weight', 'ernie.layers.6.mlp.experts.73.down_proj.weight', 'ernie.layers.6.mlp.experts.74.down_proj.weight', 'ernie.layers.6.mlp.experts.75.down_proj.weight', 'ernie.layers.6.mlp.experts.76.down_proj.weight', 'ernie.layers.6.mlp.experts.77.down_proj.weight', 'ernie.layers.6.mlp.experts.78.down_proj.weight', 'ernie.layers.6.mlp.experts.79.down_proj.weight', 'ernie.layers.6.mlp.experts.80.down_proj.weight', 'ernie.layers.6.mlp.experts.81.down_proj.weight', 'ernie.layers.6.mlp.experts.82.down_proj.weight', 'ernie.layers.6.mlp.experts.83.down_proj.weight', 'ernie.layers.6.mlp.experts.84.down_proj.weight', 'ernie.layers.6.mlp.experts.85.down_proj.weight', 'ernie.layers.6.mlp.experts.86.down_proj.weight', 'ernie.layers.6.mlp.experts.87.down_proj.weight', 'ernie.layers.6.mlp.experts.88.down_proj.weight', 'ernie.layers.6.mlp.experts.89.down_proj.weight', 'ernie.layers.6.mlp.experts.90.down_proj.weight', 'ernie.layers.6.mlp.experts.91.down_proj.weight', 'ernie.layers.6.mlp.experts.92.down_proj.weight', 'ernie.layers.6.mlp.experts.93.down_proj.weight', 'ernie.layers.6.mlp.experts.94.down_proj.weight', 'ernie.layers.6.mlp.experts.95.down_proj.weight'] ernie.layers.7.mlp.text_fused_moe.gate.weight:ernie.layers.7.mlp.gate.weight -ernie.layers.7.mlp.text_fused_moe.experts.gate_correction_bias:ernie.layers.7.mlp.moe_statics.e_score_correction_bias +ernie.layers.7.mlp.gate_correction_bias:ernie.layers.7.mlp.moe_statics.e_score_correction_bias ernie.layers.7.mlp.text_fused_moe.experts.up_gate_proj_weight:['ernie.layers.7.mlp.experts.0.up_gate_proj.weight', 'ernie.layers.7.mlp.experts.1.up_gate_proj.weight', 'ernie.layers.7.mlp.experts.2.up_gate_proj.weight', 'ernie.layers.7.mlp.experts.3.up_gate_proj.weight', 'ernie.layers.7.mlp.experts.4.up_gate_proj.weight', 'ernie.layers.7.mlp.experts.5.up_gate_proj.weight', 'ernie.layers.7.mlp.experts.6.up_gate_proj.weight', 'ernie.layers.7.mlp.experts.7.up_gate_proj.weight', 'ernie.layers.7.mlp.experts.8.up_gate_proj.weight', 'ernie.layers.7.mlp.experts.9.up_gate_proj.weight', 'ernie.layers.7.mlp.experts.10.up_gate_proj.weight', 'ernie.layers.7.mlp.experts.11.up_gate_proj.weight', 'ernie.layers.7.mlp.experts.12.up_gate_proj.weight', 'ernie.layers.7.mlp.experts.13.up_gate_proj.weight', 'ernie.layers.7.mlp.experts.14.up_gate_proj.weight', 'ernie.layers.7.mlp.experts.15.up_gate_proj.weight', 'ernie.layers.7.mlp.experts.16.up_gate_proj.weight', 'ernie.layers.7.mlp.experts.17.up_gate_proj.weight', 'ernie.layers.7.mlp.experts.18.up_gate_proj.weight', 'ernie.layers.7.mlp.experts.19.up_gate_proj.weight', 'ernie.layers.7.mlp.experts.20.up_gate_proj.weight', 'ernie.layers.7.mlp.experts.21.up_gate_proj.weight', 'ernie.layers.7.mlp.experts.22.up_gate_proj.weight', 'ernie.layers.7.mlp.experts.23.up_gate_proj.weight', 'ernie.layers.7.mlp.experts.24.up_gate_proj.weight', 'ernie.layers.7.mlp.experts.25.up_gate_proj.weight', 'ernie.layers.7.mlp.experts.26.up_gate_proj.weight', 'ernie.layers.7.mlp.experts.27.up_gate_proj.weight', 'ernie.layers.7.mlp.experts.28.up_gate_proj.weight', 'ernie.layers.7.mlp.experts.29.up_gate_proj.weight', 'ernie.layers.7.mlp.experts.30.up_gate_proj.weight', 'ernie.layers.7.mlp.experts.31.up_gate_proj.weight', 'ernie.layers.7.mlp.experts.64.up_gate_proj.weight', 'ernie.layers.7.mlp.experts.65.up_gate_proj.weight', 'ernie.layers.7.mlp.experts.66.up_gate_proj.weight', 'ernie.layers.7.mlp.experts.67.up_gate_proj.weight', 'ernie.layers.7.mlp.experts.68.up_gate_proj.weight', 'ernie.layers.7.mlp.experts.69.up_gate_proj.weight', 'ernie.layers.7.mlp.experts.70.up_gate_proj.weight', 'ernie.layers.7.mlp.experts.71.up_gate_proj.weight', 'ernie.layers.7.mlp.experts.72.up_gate_proj.weight', 'ernie.layers.7.mlp.experts.73.up_gate_proj.weight', 'ernie.layers.7.mlp.experts.74.up_gate_proj.weight', 'ernie.layers.7.mlp.experts.75.up_gate_proj.weight', 'ernie.layers.7.mlp.experts.76.up_gate_proj.weight', 'ernie.layers.7.mlp.experts.77.up_gate_proj.weight', 'ernie.layers.7.mlp.experts.78.up_gate_proj.weight', 'ernie.layers.7.mlp.experts.79.up_gate_proj.weight', 'ernie.layers.7.mlp.experts.80.up_gate_proj.weight', 'ernie.layers.7.mlp.experts.81.up_gate_proj.weight', 'ernie.layers.7.mlp.experts.82.up_gate_proj.weight', 'ernie.layers.7.mlp.experts.83.up_gate_proj.weight', 'ernie.layers.7.mlp.experts.84.up_gate_proj.weight', 'ernie.layers.7.mlp.experts.85.up_gate_proj.weight', 'ernie.layers.7.mlp.experts.86.up_gate_proj.weight', 'ernie.layers.7.mlp.experts.87.up_gate_proj.weight', 'ernie.layers.7.mlp.experts.88.up_gate_proj.weight', 'ernie.layers.7.mlp.experts.89.up_gate_proj.weight', 'ernie.layers.7.mlp.experts.90.up_gate_proj.weight', 'ernie.layers.7.mlp.experts.91.up_gate_proj.weight', 'ernie.layers.7.mlp.experts.92.up_gate_proj.weight', 'ernie.layers.7.mlp.experts.93.up_gate_proj.weight', 'ernie.layers.7.mlp.experts.94.up_gate_proj.weight', 'ernie.layers.7.mlp.experts.95.up_gate_proj.weight'] ernie.layers.7.mlp.text_fused_moe.experts.down_proj_weight:['ernie.layers.7.mlp.experts.0.down_proj.weight', 'ernie.layers.7.mlp.experts.1.down_proj.weight', 'ernie.layers.7.mlp.experts.2.down_proj.weight', 'ernie.layers.7.mlp.experts.3.down_proj.weight', 'ernie.layers.7.mlp.experts.4.down_proj.weight', 'ernie.layers.7.mlp.experts.5.down_proj.weight', 'ernie.layers.7.mlp.experts.6.down_proj.weight', 'ernie.layers.7.mlp.experts.7.down_proj.weight', 'ernie.layers.7.mlp.experts.8.down_proj.weight', 'ernie.layers.7.mlp.experts.9.down_proj.weight', 'ernie.layers.7.mlp.experts.10.down_proj.weight', 'ernie.layers.7.mlp.experts.11.down_proj.weight', 'ernie.layers.7.mlp.experts.12.down_proj.weight', 'ernie.layers.7.mlp.experts.13.down_proj.weight', 'ernie.layers.7.mlp.experts.14.down_proj.weight', 'ernie.layers.7.mlp.experts.15.down_proj.weight', 'ernie.layers.7.mlp.experts.16.down_proj.weight', 'ernie.layers.7.mlp.experts.17.down_proj.weight', 'ernie.layers.7.mlp.experts.18.down_proj.weight', 'ernie.layers.7.mlp.experts.19.down_proj.weight', 'ernie.layers.7.mlp.experts.20.down_proj.weight', 'ernie.layers.7.mlp.experts.21.down_proj.weight', 'ernie.layers.7.mlp.experts.22.down_proj.weight', 'ernie.layers.7.mlp.experts.23.down_proj.weight', 'ernie.layers.7.mlp.experts.24.down_proj.weight', 'ernie.layers.7.mlp.experts.25.down_proj.weight', 'ernie.layers.7.mlp.experts.26.down_proj.weight', 'ernie.layers.7.mlp.experts.27.down_proj.weight', 'ernie.layers.7.mlp.experts.28.down_proj.weight', 'ernie.layers.7.mlp.experts.29.down_proj.weight', 'ernie.layers.7.mlp.experts.30.down_proj.weight', 'ernie.layers.7.mlp.experts.31.down_proj.weight', 'ernie.layers.7.mlp.experts.64.down_proj.weight', 'ernie.layers.7.mlp.experts.65.down_proj.weight', 'ernie.layers.7.mlp.experts.66.down_proj.weight', 'ernie.layers.7.mlp.experts.67.down_proj.weight', 'ernie.layers.7.mlp.experts.68.down_proj.weight', 'ernie.layers.7.mlp.experts.69.down_proj.weight', 'ernie.layers.7.mlp.experts.70.down_proj.weight', 'ernie.layers.7.mlp.experts.71.down_proj.weight', 'ernie.layers.7.mlp.experts.72.down_proj.weight', 'ernie.layers.7.mlp.experts.73.down_proj.weight', 'ernie.layers.7.mlp.experts.74.down_proj.weight', 'ernie.layers.7.mlp.experts.75.down_proj.weight', 'ernie.layers.7.mlp.experts.76.down_proj.weight', 'ernie.layers.7.mlp.experts.77.down_proj.weight', 'ernie.layers.7.mlp.experts.78.down_proj.weight', 'ernie.layers.7.mlp.experts.79.down_proj.weight', 'ernie.layers.7.mlp.experts.80.down_proj.weight', 'ernie.layers.7.mlp.experts.81.down_proj.weight', 'ernie.layers.7.mlp.experts.82.down_proj.weight', 'ernie.layers.7.mlp.experts.83.down_proj.weight', 'ernie.layers.7.mlp.experts.84.down_proj.weight', 'ernie.layers.7.mlp.experts.85.down_proj.weight', 'ernie.layers.7.mlp.experts.86.down_proj.weight', 'ernie.layers.7.mlp.experts.87.down_proj.weight', 'ernie.layers.7.mlp.experts.88.down_proj.weight', 'ernie.layers.7.mlp.experts.89.down_proj.weight', 'ernie.layers.7.mlp.experts.90.down_proj.weight', 'ernie.layers.7.mlp.experts.91.down_proj.weight', 'ernie.layers.7.mlp.experts.92.down_proj.weight', 'ernie.layers.7.mlp.experts.93.down_proj.weight', 'ernie.layers.7.mlp.experts.94.down_proj.weight', 'ernie.layers.7.mlp.experts.95.down_proj.weight'] ernie.layers.8.mlp.text_fused_moe.gate.weight:ernie.layers.8.mlp.gate.weight -ernie.layers.8.mlp.text_fused_moe.experts.gate_correction_bias:ernie.layers.8.mlp.moe_statics.e_score_correction_bias +ernie.layers.8.mlp.gate_correction_bias:ernie.layers.8.mlp.moe_statics.e_score_correction_bias ernie.layers.8.mlp.text_fused_moe.experts.up_gate_proj_weight:['ernie.layers.8.mlp.experts.0.up_gate_proj.weight', 'ernie.layers.8.mlp.experts.1.up_gate_proj.weight', 'ernie.layers.8.mlp.experts.2.up_gate_proj.weight', 'ernie.layers.8.mlp.experts.3.up_gate_proj.weight', 'ernie.layers.8.mlp.experts.4.up_gate_proj.weight', 'ernie.layers.8.mlp.experts.5.up_gate_proj.weight', 'ernie.layers.8.mlp.experts.6.up_gate_proj.weight', 'ernie.layers.8.mlp.experts.7.up_gate_proj.weight', 'ernie.layers.8.mlp.experts.8.up_gate_proj.weight', 'ernie.layers.8.mlp.experts.9.up_gate_proj.weight', 'ernie.layers.8.mlp.experts.10.up_gate_proj.weight', 'ernie.layers.8.mlp.experts.11.up_gate_proj.weight', 'ernie.layers.8.mlp.experts.12.up_gate_proj.weight', 'ernie.layers.8.mlp.experts.13.up_gate_proj.weight', 'ernie.layers.8.mlp.experts.14.up_gate_proj.weight', 'ernie.layers.8.mlp.experts.15.up_gate_proj.weight', 'ernie.layers.8.mlp.experts.16.up_gate_proj.weight', 'ernie.layers.8.mlp.experts.17.up_gate_proj.weight', 'ernie.layers.8.mlp.experts.18.up_gate_proj.weight', 'ernie.layers.8.mlp.experts.19.up_gate_proj.weight', 'ernie.layers.8.mlp.experts.20.up_gate_proj.weight', 'ernie.layers.8.mlp.experts.21.up_gate_proj.weight', 'ernie.layers.8.mlp.experts.22.up_gate_proj.weight', 'ernie.layers.8.mlp.experts.23.up_gate_proj.weight', 'ernie.layers.8.mlp.experts.24.up_gate_proj.weight', 'ernie.layers.8.mlp.experts.25.up_gate_proj.weight', 'ernie.layers.8.mlp.experts.26.up_gate_proj.weight', 'ernie.layers.8.mlp.experts.27.up_gate_proj.weight', 'ernie.layers.8.mlp.experts.28.up_gate_proj.weight', 'ernie.layers.8.mlp.experts.29.up_gate_proj.weight', 'ernie.layers.8.mlp.experts.30.up_gate_proj.weight', 'ernie.layers.8.mlp.experts.31.up_gate_proj.weight', 'ernie.layers.8.mlp.experts.64.up_gate_proj.weight', 'ernie.layers.8.mlp.experts.65.up_gate_proj.weight', 'ernie.layers.8.mlp.experts.66.up_gate_proj.weight', 'ernie.layers.8.mlp.experts.67.up_gate_proj.weight', 'ernie.layers.8.mlp.experts.68.up_gate_proj.weight', 'ernie.layers.8.mlp.experts.69.up_gate_proj.weight', 'ernie.layers.8.mlp.experts.70.up_gate_proj.weight', 'ernie.layers.8.mlp.experts.71.up_gate_proj.weight', 'ernie.layers.8.mlp.experts.72.up_gate_proj.weight', 'ernie.layers.8.mlp.experts.73.up_gate_proj.weight', 'ernie.layers.8.mlp.experts.74.up_gate_proj.weight', 'ernie.layers.8.mlp.experts.75.up_gate_proj.weight', 'ernie.layers.8.mlp.experts.76.up_gate_proj.weight', 'ernie.layers.8.mlp.experts.77.up_gate_proj.weight', 'ernie.layers.8.mlp.experts.78.up_gate_proj.weight', 'ernie.layers.8.mlp.experts.79.up_gate_proj.weight', 'ernie.layers.8.mlp.experts.80.up_gate_proj.weight', 'ernie.layers.8.mlp.experts.81.up_gate_proj.weight', 'ernie.layers.8.mlp.experts.82.up_gate_proj.weight', 'ernie.layers.8.mlp.experts.83.up_gate_proj.weight', 'ernie.layers.8.mlp.experts.84.up_gate_proj.weight', 'ernie.layers.8.mlp.experts.85.up_gate_proj.weight', 'ernie.layers.8.mlp.experts.86.up_gate_proj.weight', 'ernie.layers.8.mlp.experts.87.up_gate_proj.weight', 'ernie.layers.8.mlp.experts.88.up_gate_proj.weight', 'ernie.layers.8.mlp.experts.89.up_gate_proj.weight', 'ernie.layers.8.mlp.experts.90.up_gate_proj.weight', 'ernie.layers.8.mlp.experts.91.up_gate_proj.weight', 'ernie.layers.8.mlp.experts.92.up_gate_proj.weight', 'ernie.layers.8.mlp.experts.93.up_gate_proj.weight', 'ernie.layers.8.mlp.experts.94.up_gate_proj.weight', 'ernie.layers.8.mlp.experts.95.up_gate_proj.weight'] ernie.layers.8.mlp.text_fused_moe.experts.down_proj_weight:['ernie.layers.8.mlp.experts.0.down_proj.weight', 'ernie.layers.8.mlp.experts.1.down_proj.weight', 'ernie.layers.8.mlp.experts.2.down_proj.weight', 'ernie.layers.8.mlp.experts.3.down_proj.weight', 'ernie.layers.8.mlp.experts.4.down_proj.weight', 'ernie.layers.8.mlp.experts.5.down_proj.weight', 'ernie.layers.8.mlp.experts.6.down_proj.weight', 'ernie.layers.8.mlp.experts.7.down_proj.weight', 'ernie.layers.8.mlp.experts.8.down_proj.weight', 'ernie.layers.8.mlp.experts.9.down_proj.weight', 'ernie.layers.8.mlp.experts.10.down_proj.weight', 'ernie.layers.8.mlp.experts.11.down_proj.weight', 'ernie.layers.8.mlp.experts.12.down_proj.weight', 'ernie.layers.8.mlp.experts.13.down_proj.weight', 'ernie.layers.8.mlp.experts.14.down_proj.weight', 'ernie.layers.8.mlp.experts.15.down_proj.weight', 'ernie.layers.8.mlp.experts.16.down_proj.weight', 'ernie.layers.8.mlp.experts.17.down_proj.weight', 'ernie.layers.8.mlp.experts.18.down_proj.weight', 'ernie.layers.8.mlp.experts.19.down_proj.weight', 'ernie.layers.8.mlp.experts.20.down_proj.weight', 'ernie.layers.8.mlp.experts.21.down_proj.weight', 'ernie.layers.8.mlp.experts.22.down_proj.weight', 'ernie.layers.8.mlp.experts.23.down_proj.weight', 'ernie.layers.8.mlp.experts.24.down_proj.weight', 'ernie.layers.8.mlp.experts.25.down_proj.weight', 'ernie.layers.8.mlp.experts.26.down_proj.weight', 'ernie.layers.8.mlp.experts.27.down_proj.weight', 'ernie.layers.8.mlp.experts.28.down_proj.weight', 'ernie.layers.8.mlp.experts.29.down_proj.weight', 'ernie.layers.8.mlp.experts.30.down_proj.weight', 'ernie.layers.8.mlp.experts.31.down_proj.weight', 'ernie.layers.8.mlp.experts.64.down_proj.weight', 'ernie.layers.8.mlp.experts.65.down_proj.weight', 'ernie.layers.8.mlp.experts.66.down_proj.weight', 'ernie.layers.8.mlp.experts.67.down_proj.weight', 'ernie.layers.8.mlp.experts.68.down_proj.weight', 'ernie.layers.8.mlp.experts.69.down_proj.weight', 'ernie.layers.8.mlp.experts.70.down_proj.weight', 'ernie.layers.8.mlp.experts.71.down_proj.weight', 'ernie.layers.8.mlp.experts.72.down_proj.weight', 'ernie.layers.8.mlp.experts.73.down_proj.weight', 'ernie.layers.8.mlp.experts.74.down_proj.weight', 'ernie.layers.8.mlp.experts.75.down_proj.weight', 'ernie.layers.8.mlp.experts.76.down_proj.weight', 'ernie.layers.8.mlp.experts.77.down_proj.weight', 'ernie.layers.8.mlp.experts.78.down_proj.weight', 'ernie.layers.8.mlp.experts.79.down_proj.weight', 'ernie.layers.8.mlp.experts.80.down_proj.weight', 'ernie.layers.8.mlp.experts.81.down_proj.weight', 'ernie.layers.8.mlp.experts.82.down_proj.weight', 'ernie.layers.8.mlp.experts.83.down_proj.weight', 'ernie.layers.8.mlp.experts.84.down_proj.weight', 'ernie.layers.8.mlp.experts.85.down_proj.weight', 'ernie.layers.8.mlp.experts.86.down_proj.weight', 'ernie.layers.8.mlp.experts.87.down_proj.weight', 'ernie.layers.8.mlp.experts.88.down_proj.weight', 'ernie.layers.8.mlp.experts.89.down_proj.weight', 'ernie.layers.8.mlp.experts.90.down_proj.weight', 'ernie.layers.8.mlp.experts.91.down_proj.weight', 'ernie.layers.8.mlp.experts.92.down_proj.weight', 'ernie.layers.8.mlp.experts.93.down_proj.weight', 'ernie.layers.8.mlp.experts.94.down_proj.weight', 'ernie.layers.8.mlp.experts.95.down_proj.weight'] ernie.layers.9.mlp.text_fused_moe.gate.weight:ernie.layers.9.mlp.gate.weight -ernie.layers.9.mlp.text_fused_moe.experts.gate_correction_bias:ernie.layers.9.mlp.moe_statics.e_score_correction_bias +ernie.layers.9.mlp.gate_correction_bias:ernie.layers.9.mlp.moe_statics.e_score_correction_bias ernie.layers.9.mlp.text_fused_moe.experts.up_gate_proj_weight:['ernie.layers.9.mlp.experts.0.up_gate_proj.weight', 'ernie.layers.9.mlp.experts.1.up_gate_proj.weight', 'ernie.layers.9.mlp.experts.2.up_gate_proj.weight', 'ernie.layers.9.mlp.experts.3.up_gate_proj.weight', 'ernie.layers.9.mlp.experts.4.up_gate_proj.weight', 'ernie.layers.9.mlp.experts.5.up_gate_proj.weight', 'ernie.layers.9.mlp.experts.6.up_gate_proj.weight', 'ernie.layers.9.mlp.experts.7.up_gate_proj.weight', 'ernie.layers.9.mlp.experts.8.up_gate_proj.weight', 'ernie.layers.9.mlp.experts.9.up_gate_proj.weight', 'ernie.layers.9.mlp.experts.10.up_gate_proj.weight', 'ernie.layers.9.mlp.experts.11.up_gate_proj.weight', 'ernie.layers.9.mlp.experts.12.up_gate_proj.weight', 'ernie.layers.9.mlp.experts.13.up_gate_proj.weight', 'ernie.layers.9.mlp.experts.14.up_gate_proj.weight', 'ernie.layers.9.mlp.experts.15.up_gate_proj.weight', 'ernie.layers.9.mlp.experts.16.up_gate_proj.weight', 'ernie.layers.9.mlp.experts.17.up_gate_proj.weight', 'ernie.layers.9.mlp.experts.18.up_gate_proj.weight', 'ernie.layers.9.mlp.experts.19.up_gate_proj.weight', 'ernie.layers.9.mlp.experts.20.up_gate_proj.weight', 'ernie.layers.9.mlp.experts.21.up_gate_proj.weight', 'ernie.layers.9.mlp.experts.22.up_gate_proj.weight', 'ernie.layers.9.mlp.experts.23.up_gate_proj.weight', 'ernie.layers.9.mlp.experts.24.up_gate_proj.weight', 'ernie.layers.9.mlp.experts.25.up_gate_proj.weight', 'ernie.layers.9.mlp.experts.26.up_gate_proj.weight', 'ernie.layers.9.mlp.experts.27.up_gate_proj.weight', 'ernie.layers.9.mlp.experts.28.up_gate_proj.weight', 'ernie.layers.9.mlp.experts.29.up_gate_proj.weight', 'ernie.layers.9.mlp.experts.30.up_gate_proj.weight', 'ernie.layers.9.mlp.experts.31.up_gate_proj.weight', 'ernie.layers.9.mlp.experts.64.up_gate_proj.weight', 'ernie.layers.9.mlp.experts.65.up_gate_proj.weight', 'ernie.layers.9.mlp.experts.66.up_gate_proj.weight', 'ernie.layers.9.mlp.experts.67.up_gate_proj.weight', 'ernie.layers.9.mlp.experts.68.up_gate_proj.weight', 'ernie.layers.9.mlp.experts.69.up_gate_proj.weight', 'ernie.layers.9.mlp.experts.70.up_gate_proj.weight', 'ernie.layers.9.mlp.experts.71.up_gate_proj.weight', 'ernie.layers.9.mlp.experts.72.up_gate_proj.weight', 'ernie.layers.9.mlp.experts.73.up_gate_proj.weight', 'ernie.layers.9.mlp.experts.74.up_gate_proj.weight', 'ernie.layers.9.mlp.experts.75.up_gate_proj.weight', 'ernie.layers.9.mlp.experts.76.up_gate_proj.weight', 'ernie.layers.9.mlp.experts.77.up_gate_proj.weight', 'ernie.layers.9.mlp.experts.78.up_gate_proj.weight', 'ernie.layers.9.mlp.experts.79.up_gate_proj.weight', 'ernie.layers.9.mlp.experts.80.up_gate_proj.weight', 'ernie.layers.9.mlp.experts.81.up_gate_proj.weight', 'ernie.layers.9.mlp.experts.82.up_gate_proj.weight', 'ernie.layers.9.mlp.experts.83.up_gate_proj.weight', 'ernie.layers.9.mlp.experts.84.up_gate_proj.weight', 'ernie.layers.9.mlp.experts.85.up_gate_proj.weight', 'ernie.layers.9.mlp.experts.86.up_gate_proj.weight', 'ernie.layers.9.mlp.experts.87.up_gate_proj.weight', 'ernie.layers.9.mlp.experts.88.up_gate_proj.weight', 'ernie.layers.9.mlp.experts.89.up_gate_proj.weight', 'ernie.layers.9.mlp.experts.90.up_gate_proj.weight', 'ernie.layers.9.mlp.experts.91.up_gate_proj.weight', 'ernie.layers.9.mlp.experts.92.up_gate_proj.weight', 'ernie.layers.9.mlp.experts.93.up_gate_proj.weight', 'ernie.layers.9.mlp.experts.94.up_gate_proj.weight', 'ernie.layers.9.mlp.experts.95.up_gate_proj.weight'] ernie.layers.9.mlp.text_fused_moe.experts.down_proj_weight:['ernie.layers.9.mlp.experts.0.down_proj.weight', 'ernie.layers.9.mlp.experts.1.down_proj.weight', 'ernie.layers.9.mlp.experts.2.down_proj.weight', 'ernie.layers.9.mlp.experts.3.down_proj.weight', 'ernie.layers.9.mlp.experts.4.down_proj.weight', 'ernie.layers.9.mlp.experts.5.down_proj.weight', 'ernie.layers.9.mlp.experts.6.down_proj.weight', 'ernie.layers.9.mlp.experts.7.down_proj.weight', 'ernie.layers.9.mlp.experts.8.down_proj.weight', 'ernie.layers.9.mlp.experts.9.down_proj.weight', 'ernie.layers.9.mlp.experts.10.down_proj.weight', 'ernie.layers.9.mlp.experts.11.down_proj.weight', 'ernie.layers.9.mlp.experts.12.down_proj.weight', 'ernie.layers.9.mlp.experts.13.down_proj.weight', 'ernie.layers.9.mlp.experts.14.down_proj.weight', 'ernie.layers.9.mlp.experts.15.down_proj.weight', 'ernie.layers.9.mlp.experts.16.down_proj.weight', 'ernie.layers.9.mlp.experts.17.down_proj.weight', 'ernie.layers.9.mlp.experts.18.down_proj.weight', 'ernie.layers.9.mlp.experts.19.down_proj.weight', 'ernie.layers.9.mlp.experts.20.down_proj.weight', 'ernie.layers.9.mlp.experts.21.down_proj.weight', 'ernie.layers.9.mlp.experts.22.down_proj.weight', 'ernie.layers.9.mlp.experts.23.down_proj.weight', 'ernie.layers.9.mlp.experts.24.down_proj.weight', 'ernie.layers.9.mlp.experts.25.down_proj.weight', 'ernie.layers.9.mlp.experts.26.down_proj.weight', 'ernie.layers.9.mlp.experts.27.down_proj.weight', 'ernie.layers.9.mlp.experts.28.down_proj.weight', 'ernie.layers.9.mlp.experts.29.down_proj.weight', 'ernie.layers.9.mlp.experts.30.down_proj.weight', 'ernie.layers.9.mlp.experts.31.down_proj.weight', 'ernie.layers.9.mlp.experts.64.down_proj.weight', 'ernie.layers.9.mlp.experts.65.down_proj.weight', 'ernie.layers.9.mlp.experts.66.down_proj.weight', 'ernie.layers.9.mlp.experts.67.down_proj.weight', 'ernie.layers.9.mlp.experts.68.down_proj.weight', 'ernie.layers.9.mlp.experts.69.down_proj.weight', 'ernie.layers.9.mlp.experts.70.down_proj.weight', 'ernie.layers.9.mlp.experts.71.down_proj.weight', 'ernie.layers.9.mlp.experts.72.down_proj.weight', 'ernie.layers.9.mlp.experts.73.down_proj.weight', 'ernie.layers.9.mlp.experts.74.down_proj.weight', 'ernie.layers.9.mlp.experts.75.down_proj.weight', 'ernie.layers.9.mlp.experts.76.down_proj.weight', 'ernie.layers.9.mlp.experts.77.down_proj.weight', 'ernie.layers.9.mlp.experts.78.down_proj.weight', 'ernie.layers.9.mlp.experts.79.down_proj.weight', 'ernie.layers.9.mlp.experts.80.down_proj.weight', 'ernie.layers.9.mlp.experts.81.down_proj.weight', 'ernie.layers.9.mlp.experts.82.down_proj.weight', 'ernie.layers.9.mlp.experts.83.down_proj.weight', 'ernie.layers.9.mlp.experts.84.down_proj.weight', 'ernie.layers.9.mlp.experts.85.down_proj.weight', 'ernie.layers.9.mlp.experts.86.down_proj.weight', 'ernie.layers.9.mlp.experts.87.down_proj.weight', 'ernie.layers.9.mlp.experts.88.down_proj.weight', 'ernie.layers.9.mlp.experts.89.down_proj.weight', 'ernie.layers.9.mlp.experts.90.down_proj.weight', 'ernie.layers.9.mlp.experts.91.down_proj.weight', 'ernie.layers.9.mlp.experts.92.down_proj.weight', 'ernie.layers.9.mlp.experts.93.down_proj.weight', 'ernie.layers.9.mlp.experts.94.down_proj.weight', 'ernie.layers.9.mlp.experts.95.down_proj.weight'] ernie.layers.10.mlp.text_fused_moe.gate.weight:ernie.layers.10.mlp.gate.weight -ernie.layers.10.mlp.text_fused_moe.experts.gate_correction_bias:ernie.layers.10.mlp.moe_statics.e_score_correction_bias +ernie.layers.10.mlp.gate_correction_bias:ernie.layers.10.mlp.moe_statics.e_score_correction_bias ernie.layers.10.mlp.text_fused_moe.experts.up_gate_proj_weight:['ernie.layers.10.mlp.experts.0.up_gate_proj.weight', 'ernie.layers.10.mlp.experts.1.up_gate_proj.weight', 'ernie.layers.10.mlp.experts.2.up_gate_proj.weight', 'ernie.layers.10.mlp.experts.3.up_gate_proj.weight', 'ernie.layers.10.mlp.experts.4.up_gate_proj.weight', 'ernie.layers.10.mlp.experts.5.up_gate_proj.weight', 'ernie.layers.10.mlp.experts.6.up_gate_proj.weight', 'ernie.layers.10.mlp.experts.7.up_gate_proj.weight', 'ernie.layers.10.mlp.experts.8.up_gate_proj.weight', 'ernie.layers.10.mlp.experts.9.up_gate_proj.weight', 'ernie.layers.10.mlp.experts.10.up_gate_proj.weight', 'ernie.layers.10.mlp.experts.11.up_gate_proj.weight', 'ernie.layers.10.mlp.experts.12.up_gate_proj.weight', 'ernie.layers.10.mlp.experts.13.up_gate_proj.weight', 'ernie.layers.10.mlp.experts.14.up_gate_proj.weight', 'ernie.layers.10.mlp.experts.15.up_gate_proj.weight', 'ernie.layers.10.mlp.experts.16.up_gate_proj.weight', 'ernie.layers.10.mlp.experts.17.up_gate_proj.weight', 'ernie.layers.10.mlp.experts.18.up_gate_proj.weight', 'ernie.layers.10.mlp.experts.19.up_gate_proj.weight', 'ernie.layers.10.mlp.experts.20.up_gate_proj.weight', 'ernie.layers.10.mlp.experts.21.up_gate_proj.weight', 'ernie.layers.10.mlp.experts.22.up_gate_proj.weight', 'ernie.layers.10.mlp.experts.23.up_gate_proj.weight', 'ernie.layers.10.mlp.experts.24.up_gate_proj.weight', 'ernie.layers.10.mlp.experts.25.up_gate_proj.weight', 'ernie.layers.10.mlp.experts.26.up_gate_proj.weight', 'ernie.layers.10.mlp.experts.27.up_gate_proj.weight', 'ernie.layers.10.mlp.experts.28.up_gate_proj.weight', 'ernie.layers.10.mlp.experts.29.up_gate_proj.weight', 'ernie.layers.10.mlp.experts.30.up_gate_proj.weight', 'ernie.layers.10.mlp.experts.31.up_gate_proj.weight', 'ernie.layers.10.mlp.experts.64.up_gate_proj.weight', 'ernie.layers.10.mlp.experts.65.up_gate_proj.weight', 'ernie.layers.10.mlp.experts.66.up_gate_proj.weight', 'ernie.layers.10.mlp.experts.67.up_gate_proj.weight', 'ernie.layers.10.mlp.experts.68.up_gate_proj.weight', 'ernie.layers.10.mlp.experts.69.up_gate_proj.weight', 'ernie.layers.10.mlp.experts.70.up_gate_proj.weight', 'ernie.layers.10.mlp.experts.71.up_gate_proj.weight', 'ernie.layers.10.mlp.experts.72.up_gate_proj.weight', 'ernie.layers.10.mlp.experts.73.up_gate_proj.weight', 'ernie.layers.10.mlp.experts.74.up_gate_proj.weight', 'ernie.layers.10.mlp.experts.75.up_gate_proj.weight', 'ernie.layers.10.mlp.experts.76.up_gate_proj.weight', 'ernie.layers.10.mlp.experts.77.up_gate_proj.weight', 'ernie.layers.10.mlp.experts.78.up_gate_proj.weight', 'ernie.layers.10.mlp.experts.79.up_gate_proj.weight', 'ernie.layers.10.mlp.experts.80.up_gate_proj.weight', 'ernie.layers.10.mlp.experts.81.up_gate_proj.weight', 'ernie.layers.10.mlp.experts.82.up_gate_proj.weight', 'ernie.layers.10.mlp.experts.83.up_gate_proj.weight', 'ernie.layers.10.mlp.experts.84.up_gate_proj.weight', 'ernie.layers.10.mlp.experts.85.up_gate_proj.weight', 'ernie.layers.10.mlp.experts.86.up_gate_proj.weight', 'ernie.layers.10.mlp.experts.87.up_gate_proj.weight', 'ernie.layers.10.mlp.experts.88.up_gate_proj.weight', 'ernie.layers.10.mlp.experts.89.up_gate_proj.weight', 'ernie.layers.10.mlp.experts.90.up_gate_proj.weight', 'ernie.layers.10.mlp.experts.91.up_gate_proj.weight', 'ernie.layers.10.mlp.experts.92.up_gate_proj.weight', 'ernie.layers.10.mlp.experts.93.up_gate_proj.weight', 'ernie.layers.10.mlp.experts.94.up_gate_proj.weight', 'ernie.layers.10.mlp.experts.95.up_gate_proj.weight'] ernie.layers.10.mlp.text_fused_moe.experts.down_proj_weight:['ernie.layers.10.mlp.experts.0.down_proj.weight', 'ernie.layers.10.mlp.experts.1.down_proj.weight', 'ernie.layers.10.mlp.experts.2.down_proj.weight', 'ernie.layers.10.mlp.experts.3.down_proj.weight', 'ernie.layers.10.mlp.experts.4.down_proj.weight', 'ernie.layers.10.mlp.experts.5.down_proj.weight', 'ernie.layers.10.mlp.experts.6.down_proj.weight', 'ernie.layers.10.mlp.experts.7.down_proj.weight', 'ernie.layers.10.mlp.experts.8.down_proj.weight', 'ernie.layers.10.mlp.experts.9.down_proj.weight', 'ernie.layers.10.mlp.experts.10.down_proj.weight', 'ernie.layers.10.mlp.experts.11.down_proj.weight', 'ernie.layers.10.mlp.experts.12.down_proj.weight', 'ernie.layers.10.mlp.experts.13.down_proj.weight', 'ernie.layers.10.mlp.experts.14.down_proj.weight', 'ernie.layers.10.mlp.experts.15.down_proj.weight', 'ernie.layers.10.mlp.experts.16.down_proj.weight', 'ernie.layers.10.mlp.experts.17.down_proj.weight', 'ernie.layers.10.mlp.experts.18.down_proj.weight', 'ernie.layers.10.mlp.experts.19.down_proj.weight', 'ernie.layers.10.mlp.experts.20.down_proj.weight', 'ernie.layers.10.mlp.experts.21.down_proj.weight', 'ernie.layers.10.mlp.experts.22.down_proj.weight', 'ernie.layers.10.mlp.experts.23.down_proj.weight', 'ernie.layers.10.mlp.experts.24.down_proj.weight', 'ernie.layers.10.mlp.experts.25.down_proj.weight', 'ernie.layers.10.mlp.experts.26.down_proj.weight', 'ernie.layers.10.mlp.experts.27.down_proj.weight', 'ernie.layers.10.mlp.experts.28.down_proj.weight', 'ernie.layers.10.mlp.experts.29.down_proj.weight', 'ernie.layers.10.mlp.experts.30.down_proj.weight', 'ernie.layers.10.mlp.experts.31.down_proj.weight', 'ernie.layers.10.mlp.experts.64.down_proj.weight', 'ernie.layers.10.mlp.experts.65.down_proj.weight', 'ernie.layers.10.mlp.experts.66.down_proj.weight', 'ernie.layers.10.mlp.experts.67.down_proj.weight', 'ernie.layers.10.mlp.experts.68.down_proj.weight', 'ernie.layers.10.mlp.experts.69.down_proj.weight', 'ernie.layers.10.mlp.experts.70.down_proj.weight', 'ernie.layers.10.mlp.experts.71.down_proj.weight', 'ernie.layers.10.mlp.experts.72.down_proj.weight', 'ernie.layers.10.mlp.experts.73.down_proj.weight', 'ernie.layers.10.mlp.experts.74.down_proj.weight', 'ernie.layers.10.mlp.experts.75.down_proj.weight', 'ernie.layers.10.mlp.experts.76.down_proj.weight', 'ernie.layers.10.mlp.experts.77.down_proj.weight', 'ernie.layers.10.mlp.experts.78.down_proj.weight', 'ernie.layers.10.mlp.experts.79.down_proj.weight', 'ernie.layers.10.mlp.experts.80.down_proj.weight', 'ernie.layers.10.mlp.experts.81.down_proj.weight', 'ernie.layers.10.mlp.experts.82.down_proj.weight', 'ernie.layers.10.mlp.experts.83.down_proj.weight', 'ernie.layers.10.mlp.experts.84.down_proj.weight', 'ernie.layers.10.mlp.experts.85.down_proj.weight', 'ernie.layers.10.mlp.experts.86.down_proj.weight', 'ernie.layers.10.mlp.experts.87.down_proj.weight', 'ernie.layers.10.mlp.experts.88.down_proj.weight', 'ernie.layers.10.mlp.experts.89.down_proj.weight', 'ernie.layers.10.mlp.experts.90.down_proj.weight', 'ernie.layers.10.mlp.experts.91.down_proj.weight', 'ernie.layers.10.mlp.experts.92.down_proj.weight', 'ernie.layers.10.mlp.experts.93.down_proj.weight', 'ernie.layers.10.mlp.experts.94.down_proj.weight', 'ernie.layers.10.mlp.experts.95.down_proj.weight'] ernie.layers.11.mlp.text_fused_moe.gate.weight:ernie.layers.11.mlp.gate.weight -ernie.layers.11.mlp.text_fused_moe.experts.gate_correction_bias:ernie.layers.11.mlp.moe_statics.e_score_correction_bias +ernie.layers.11.mlp.gate_correction_bias:ernie.layers.11.mlp.moe_statics.e_score_correction_bias ernie.layers.11.mlp.text_fused_moe.experts.up_gate_proj_weight:['ernie.layers.11.mlp.experts.0.up_gate_proj.weight', 'ernie.layers.11.mlp.experts.1.up_gate_proj.weight', 'ernie.layers.11.mlp.experts.2.up_gate_proj.weight', 'ernie.layers.11.mlp.experts.3.up_gate_proj.weight', 'ernie.layers.11.mlp.experts.4.up_gate_proj.weight', 'ernie.layers.11.mlp.experts.5.up_gate_proj.weight', 'ernie.layers.11.mlp.experts.6.up_gate_proj.weight', 'ernie.layers.11.mlp.experts.7.up_gate_proj.weight', 'ernie.layers.11.mlp.experts.8.up_gate_proj.weight', 'ernie.layers.11.mlp.experts.9.up_gate_proj.weight', 'ernie.layers.11.mlp.experts.10.up_gate_proj.weight', 'ernie.layers.11.mlp.experts.11.up_gate_proj.weight', 'ernie.layers.11.mlp.experts.12.up_gate_proj.weight', 'ernie.layers.11.mlp.experts.13.up_gate_proj.weight', 'ernie.layers.11.mlp.experts.14.up_gate_proj.weight', 'ernie.layers.11.mlp.experts.15.up_gate_proj.weight', 'ernie.layers.11.mlp.experts.16.up_gate_proj.weight', 'ernie.layers.11.mlp.experts.17.up_gate_proj.weight', 'ernie.layers.11.mlp.experts.18.up_gate_proj.weight', 'ernie.layers.11.mlp.experts.19.up_gate_proj.weight', 'ernie.layers.11.mlp.experts.20.up_gate_proj.weight', 'ernie.layers.11.mlp.experts.21.up_gate_proj.weight', 'ernie.layers.11.mlp.experts.22.up_gate_proj.weight', 'ernie.layers.11.mlp.experts.23.up_gate_proj.weight', 'ernie.layers.11.mlp.experts.24.up_gate_proj.weight', 'ernie.layers.11.mlp.experts.25.up_gate_proj.weight', 'ernie.layers.11.mlp.experts.26.up_gate_proj.weight', 'ernie.layers.11.mlp.experts.27.up_gate_proj.weight', 'ernie.layers.11.mlp.experts.28.up_gate_proj.weight', 'ernie.layers.11.mlp.experts.29.up_gate_proj.weight', 'ernie.layers.11.mlp.experts.30.up_gate_proj.weight', 'ernie.layers.11.mlp.experts.31.up_gate_proj.weight', 'ernie.layers.11.mlp.experts.64.up_gate_proj.weight', 'ernie.layers.11.mlp.experts.65.up_gate_proj.weight', 'ernie.layers.11.mlp.experts.66.up_gate_proj.weight', 'ernie.layers.11.mlp.experts.67.up_gate_proj.weight', 'ernie.layers.11.mlp.experts.68.up_gate_proj.weight', 'ernie.layers.11.mlp.experts.69.up_gate_proj.weight', 'ernie.layers.11.mlp.experts.70.up_gate_proj.weight', 'ernie.layers.11.mlp.experts.71.up_gate_proj.weight', 'ernie.layers.11.mlp.experts.72.up_gate_proj.weight', 'ernie.layers.11.mlp.experts.73.up_gate_proj.weight', 'ernie.layers.11.mlp.experts.74.up_gate_proj.weight', 'ernie.layers.11.mlp.experts.75.up_gate_proj.weight', 'ernie.layers.11.mlp.experts.76.up_gate_proj.weight', 'ernie.layers.11.mlp.experts.77.up_gate_proj.weight', 'ernie.layers.11.mlp.experts.78.up_gate_proj.weight', 'ernie.layers.11.mlp.experts.79.up_gate_proj.weight', 'ernie.layers.11.mlp.experts.80.up_gate_proj.weight', 'ernie.layers.11.mlp.experts.81.up_gate_proj.weight', 'ernie.layers.11.mlp.experts.82.up_gate_proj.weight', 'ernie.layers.11.mlp.experts.83.up_gate_proj.weight', 'ernie.layers.11.mlp.experts.84.up_gate_proj.weight', 'ernie.layers.11.mlp.experts.85.up_gate_proj.weight', 'ernie.layers.11.mlp.experts.86.up_gate_proj.weight', 'ernie.layers.11.mlp.experts.87.up_gate_proj.weight', 'ernie.layers.11.mlp.experts.88.up_gate_proj.weight', 'ernie.layers.11.mlp.experts.89.up_gate_proj.weight', 'ernie.layers.11.mlp.experts.90.up_gate_proj.weight', 'ernie.layers.11.mlp.experts.91.up_gate_proj.weight', 'ernie.layers.11.mlp.experts.92.up_gate_proj.weight', 'ernie.layers.11.mlp.experts.93.up_gate_proj.weight', 'ernie.layers.11.mlp.experts.94.up_gate_proj.weight', 'ernie.layers.11.mlp.experts.95.up_gate_proj.weight'] ernie.layers.11.mlp.text_fused_moe.experts.down_proj_weight:['ernie.layers.11.mlp.experts.0.down_proj.weight', 'ernie.layers.11.mlp.experts.1.down_proj.weight', 'ernie.layers.11.mlp.experts.2.down_proj.weight', 'ernie.layers.11.mlp.experts.3.down_proj.weight', 'ernie.layers.11.mlp.experts.4.down_proj.weight', 'ernie.layers.11.mlp.experts.5.down_proj.weight', 'ernie.layers.11.mlp.experts.6.down_proj.weight', 'ernie.layers.11.mlp.experts.7.down_proj.weight', 'ernie.layers.11.mlp.experts.8.down_proj.weight', 'ernie.layers.11.mlp.experts.9.down_proj.weight', 'ernie.layers.11.mlp.experts.10.down_proj.weight', 'ernie.layers.11.mlp.experts.11.down_proj.weight', 'ernie.layers.11.mlp.experts.12.down_proj.weight', 'ernie.layers.11.mlp.experts.13.down_proj.weight', 'ernie.layers.11.mlp.experts.14.down_proj.weight', 'ernie.layers.11.mlp.experts.15.down_proj.weight', 'ernie.layers.11.mlp.experts.16.down_proj.weight', 'ernie.layers.11.mlp.experts.17.down_proj.weight', 'ernie.layers.11.mlp.experts.18.down_proj.weight', 'ernie.layers.11.mlp.experts.19.down_proj.weight', 'ernie.layers.11.mlp.experts.20.down_proj.weight', 'ernie.layers.11.mlp.experts.21.down_proj.weight', 'ernie.layers.11.mlp.experts.22.down_proj.weight', 'ernie.layers.11.mlp.experts.23.down_proj.weight', 'ernie.layers.11.mlp.experts.24.down_proj.weight', 'ernie.layers.11.mlp.experts.25.down_proj.weight', 'ernie.layers.11.mlp.experts.26.down_proj.weight', 'ernie.layers.11.mlp.experts.27.down_proj.weight', 'ernie.layers.11.mlp.experts.28.down_proj.weight', 'ernie.layers.11.mlp.experts.29.down_proj.weight', 'ernie.layers.11.mlp.experts.30.down_proj.weight', 'ernie.layers.11.mlp.experts.31.down_proj.weight', 'ernie.layers.11.mlp.experts.64.down_proj.weight', 'ernie.layers.11.mlp.experts.65.down_proj.weight', 'ernie.layers.11.mlp.experts.66.down_proj.weight', 'ernie.layers.11.mlp.experts.67.down_proj.weight', 'ernie.layers.11.mlp.experts.68.down_proj.weight', 'ernie.layers.11.mlp.experts.69.down_proj.weight', 'ernie.layers.11.mlp.experts.70.down_proj.weight', 'ernie.layers.11.mlp.experts.71.down_proj.weight', 'ernie.layers.11.mlp.experts.72.down_proj.weight', 'ernie.layers.11.mlp.experts.73.down_proj.weight', 'ernie.layers.11.mlp.experts.74.down_proj.weight', 'ernie.layers.11.mlp.experts.75.down_proj.weight', 'ernie.layers.11.mlp.experts.76.down_proj.weight', 'ernie.layers.11.mlp.experts.77.down_proj.weight', 'ernie.layers.11.mlp.experts.78.down_proj.weight', 'ernie.layers.11.mlp.experts.79.down_proj.weight', 'ernie.layers.11.mlp.experts.80.down_proj.weight', 'ernie.layers.11.mlp.experts.81.down_proj.weight', 'ernie.layers.11.mlp.experts.82.down_proj.weight', 'ernie.layers.11.mlp.experts.83.down_proj.weight', 'ernie.layers.11.mlp.experts.84.down_proj.weight', 'ernie.layers.11.mlp.experts.85.down_proj.weight', 'ernie.layers.11.mlp.experts.86.down_proj.weight', 'ernie.layers.11.mlp.experts.87.down_proj.weight', 'ernie.layers.11.mlp.experts.88.down_proj.weight', 'ernie.layers.11.mlp.experts.89.down_proj.weight', 'ernie.layers.11.mlp.experts.90.down_proj.weight', 'ernie.layers.11.mlp.experts.91.down_proj.weight', 'ernie.layers.11.mlp.experts.92.down_proj.weight', 'ernie.layers.11.mlp.experts.93.down_proj.weight', 'ernie.layers.11.mlp.experts.94.down_proj.weight', 'ernie.layers.11.mlp.experts.95.down_proj.weight'] ernie.layers.12.mlp.text_fused_moe.gate.weight:ernie.layers.12.mlp.gate.weight -ernie.layers.12.mlp.text_fused_moe.experts.gate_correction_bias:ernie.layers.12.mlp.moe_statics.e_score_correction_bias +ernie.layers.12.mlp.gate_correction_bias:ernie.layers.12.mlp.moe_statics.e_score_correction_bias ernie.layers.12.mlp.text_fused_moe.experts.up_gate_proj_weight:['ernie.layers.12.mlp.experts.0.up_gate_proj.weight', 'ernie.layers.12.mlp.experts.1.up_gate_proj.weight', 'ernie.layers.12.mlp.experts.2.up_gate_proj.weight', 'ernie.layers.12.mlp.experts.3.up_gate_proj.weight', 'ernie.layers.12.mlp.experts.4.up_gate_proj.weight', 'ernie.layers.12.mlp.experts.5.up_gate_proj.weight', 'ernie.layers.12.mlp.experts.6.up_gate_proj.weight', 'ernie.layers.12.mlp.experts.7.up_gate_proj.weight', 'ernie.layers.12.mlp.experts.8.up_gate_proj.weight', 'ernie.layers.12.mlp.experts.9.up_gate_proj.weight', 'ernie.layers.12.mlp.experts.10.up_gate_proj.weight', 'ernie.layers.12.mlp.experts.11.up_gate_proj.weight', 'ernie.layers.12.mlp.experts.12.up_gate_proj.weight', 'ernie.layers.12.mlp.experts.13.up_gate_proj.weight', 'ernie.layers.12.mlp.experts.14.up_gate_proj.weight', 'ernie.layers.12.mlp.experts.15.up_gate_proj.weight', 'ernie.layers.12.mlp.experts.16.up_gate_proj.weight', 'ernie.layers.12.mlp.experts.17.up_gate_proj.weight', 'ernie.layers.12.mlp.experts.18.up_gate_proj.weight', 'ernie.layers.12.mlp.experts.19.up_gate_proj.weight', 'ernie.layers.12.mlp.experts.20.up_gate_proj.weight', 'ernie.layers.12.mlp.experts.21.up_gate_proj.weight', 'ernie.layers.12.mlp.experts.22.up_gate_proj.weight', 'ernie.layers.12.mlp.experts.23.up_gate_proj.weight', 'ernie.layers.12.mlp.experts.24.up_gate_proj.weight', 'ernie.layers.12.mlp.experts.25.up_gate_proj.weight', 'ernie.layers.12.mlp.experts.26.up_gate_proj.weight', 'ernie.layers.12.mlp.experts.27.up_gate_proj.weight', 'ernie.layers.12.mlp.experts.28.up_gate_proj.weight', 'ernie.layers.12.mlp.experts.29.up_gate_proj.weight', 'ernie.layers.12.mlp.experts.30.up_gate_proj.weight', 'ernie.layers.12.mlp.experts.31.up_gate_proj.weight', 'ernie.layers.12.mlp.experts.64.up_gate_proj.weight', 'ernie.layers.12.mlp.experts.65.up_gate_proj.weight', 'ernie.layers.12.mlp.experts.66.up_gate_proj.weight', 'ernie.layers.12.mlp.experts.67.up_gate_proj.weight', 'ernie.layers.12.mlp.experts.68.up_gate_proj.weight', 'ernie.layers.12.mlp.experts.69.up_gate_proj.weight', 'ernie.layers.12.mlp.experts.70.up_gate_proj.weight', 'ernie.layers.12.mlp.experts.71.up_gate_proj.weight', 'ernie.layers.12.mlp.experts.72.up_gate_proj.weight', 'ernie.layers.12.mlp.experts.73.up_gate_proj.weight', 'ernie.layers.12.mlp.experts.74.up_gate_proj.weight', 'ernie.layers.12.mlp.experts.75.up_gate_proj.weight', 'ernie.layers.12.mlp.experts.76.up_gate_proj.weight', 'ernie.layers.12.mlp.experts.77.up_gate_proj.weight', 'ernie.layers.12.mlp.experts.78.up_gate_proj.weight', 'ernie.layers.12.mlp.experts.79.up_gate_proj.weight', 'ernie.layers.12.mlp.experts.80.up_gate_proj.weight', 'ernie.layers.12.mlp.experts.81.up_gate_proj.weight', 'ernie.layers.12.mlp.experts.82.up_gate_proj.weight', 'ernie.layers.12.mlp.experts.83.up_gate_proj.weight', 'ernie.layers.12.mlp.experts.84.up_gate_proj.weight', 'ernie.layers.12.mlp.experts.85.up_gate_proj.weight', 'ernie.layers.12.mlp.experts.86.up_gate_proj.weight', 'ernie.layers.12.mlp.experts.87.up_gate_proj.weight', 'ernie.layers.12.mlp.experts.88.up_gate_proj.weight', 'ernie.layers.12.mlp.experts.89.up_gate_proj.weight', 'ernie.layers.12.mlp.experts.90.up_gate_proj.weight', 'ernie.layers.12.mlp.experts.91.up_gate_proj.weight', 'ernie.layers.12.mlp.experts.92.up_gate_proj.weight', 'ernie.layers.12.mlp.experts.93.up_gate_proj.weight', 'ernie.layers.12.mlp.experts.94.up_gate_proj.weight', 'ernie.layers.12.mlp.experts.95.up_gate_proj.weight'] ernie.layers.12.mlp.text_fused_moe.experts.down_proj_weight:['ernie.layers.12.mlp.experts.0.down_proj.weight', 'ernie.layers.12.mlp.experts.1.down_proj.weight', 'ernie.layers.12.mlp.experts.2.down_proj.weight', 'ernie.layers.12.mlp.experts.3.down_proj.weight', 'ernie.layers.12.mlp.experts.4.down_proj.weight', 'ernie.layers.12.mlp.experts.5.down_proj.weight', 'ernie.layers.12.mlp.experts.6.down_proj.weight', 'ernie.layers.12.mlp.experts.7.down_proj.weight', 'ernie.layers.12.mlp.experts.8.down_proj.weight', 'ernie.layers.12.mlp.experts.9.down_proj.weight', 'ernie.layers.12.mlp.experts.10.down_proj.weight', 'ernie.layers.12.mlp.experts.11.down_proj.weight', 'ernie.layers.12.mlp.experts.12.down_proj.weight', 'ernie.layers.12.mlp.experts.13.down_proj.weight', 'ernie.layers.12.mlp.experts.14.down_proj.weight', 'ernie.layers.12.mlp.experts.15.down_proj.weight', 'ernie.layers.12.mlp.experts.16.down_proj.weight', 'ernie.layers.12.mlp.experts.17.down_proj.weight', 'ernie.layers.12.mlp.experts.18.down_proj.weight', 'ernie.layers.12.mlp.experts.19.down_proj.weight', 'ernie.layers.12.mlp.experts.20.down_proj.weight', 'ernie.layers.12.mlp.experts.21.down_proj.weight', 'ernie.layers.12.mlp.experts.22.down_proj.weight', 'ernie.layers.12.mlp.experts.23.down_proj.weight', 'ernie.layers.12.mlp.experts.24.down_proj.weight', 'ernie.layers.12.mlp.experts.25.down_proj.weight', 'ernie.layers.12.mlp.experts.26.down_proj.weight', 'ernie.layers.12.mlp.experts.27.down_proj.weight', 'ernie.layers.12.mlp.experts.28.down_proj.weight', 'ernie.layers.12.mlp.experts.29.down_proj.weight', 'ernie.layers.12.mlp.experts.30.down_proj.weight', 'ernie.layers.12.mlp.experts.31.down_proj.weight', 'ernie.layers.12.mlp.experts.64.down_proj.weight', 'ernie.layers.12.mlp.experts.65.down_proj.weight', 'ernie.layers.12.mlp.experts.66.down_proj.weight', 'ernie.layers.12.mlp.experts.67.down_proj.weight', 'ernie.layers.12.mlp.experts.68.down_proj.weight', 'ernie.layers.12.mlp.experts.69.down_proj.weight', 'ernie.layers.12.mlp.experts.70.down_proj.weight', 'ernie.layers.12.mlp.experts.71.down_proj.weight', 'ernie.layers.12.mlp.experts.72.down_proj.weight', 'ernie.layers.12.mlp.experts.73.down_proj.weight', 'ernie.layers.12.mlp.experts.74.down_proj.weight', 'ernie.layers.12.mlp.experts.75.down_proj.weight', 'ernie.layers.12.mlp.experts.76.down_proj.weight', 'ernie.layers.12.mlp.experts.77.down_proj.weight', 'ernie.layers.12.mlp.experts.78.down_proj.weight', 'ernie.layers.12.mlp.experts.79.down_proj.weight', 'ernie.layers.12.mlp.experts.80.down_proj.weight', 'ernie.layers.12.mlp.experts.81.down_proj.weight', 'ernie.layers.12.mlp.experts.82.down_proj.weight', 'ernie.layers.12.mlp.experts.83.down_proj.weight', 'ernie.layers.12.mlp.experts.84.down_proj.weight', 'ernie.layers.12.mlp.experts.85.down_proj.weight', 'ernie.layers.12.mlp.experts.86.down_proj.weight', 'ernie.layers.12.mlp.experts.87.down_proj.weight', 'ernie.layers.12.mlp.experts.88.down_proj.weight', 'ernie.layers.12.mlp.experts.89.down_proj.weight', 'ernie.layers.12.mlp.experts.90.down_proj.weight', 'ernie.layers.12.mlp.experts.91.down_proj.weight', 'ernie.layers.12.mlp.experts.92.down_proj.weight', 'ernie.layers.12.mlp.experts.93.down_proj.weight', 'ernie.layers.12.mlp.experts.94.down_proj.weight', 'ernie.layers.12.mlp.experts.95.down_proj.weight'] ernie.layers.13.mlp.text_fused_moe.gate.weight:ernie.layers.13.mlp.gate.weight -ernie.layers.13.mlp.text_fused_moe.experts.gate_correction_bias:ernie.layers.13.mlp.moe_statics.e_score_correction_bias +ernie.layers.13.mlp.gate_correction_bias:ernie.layers.13.mlp.moe_statics.e_score_correction_bias ernie.layers.13.mlp.text_fused_moe.experts.up_gate_proj_weight:['ernie.layers.13.mlp.experts.0.up_gate_proj.weight', 'ernie.layers.13.mlp.experts.1.up_gate_proj.weight', 'ernie.layers.13.mlp.experts.2.up_gate_proj.weight', 'ernie.layers.13.mlp.experts.3.up_gate_proj.weight', 'ernie.layers.13.mlp.experts.4.up_gate_proj.weight', 'ernie.layers.13.mlp.experts.5.up_gate_proj.weight', 'ernie.layers.13.mlp.experts.6.up_gate_proj.weight', 'ernie.layers.13.mlp.experts.7.up_gate_proj.weight', 'ernie.layers.13.mlp.experts.8.up_gate_proj.weight', 'ernie.layers.13.mlp.experts.9.up_gate_proj.weight', 'ernie.layers.13.mlp.experts.10.up_gate_proj.weight', 'ernie.layers.13.mlp.experts.11.up_gate_proj.weight', 'ernie.layers.13.mlp.experts.12.up_gate_proj.weight', 'ernie.layers.13.mlp.experts.13.up_gate_proj.weight', 'ernie.layers.13.mlp.experts.14.up_gate_proj.weight', 'ernie.layers.13.mlp.experts.15.up_gate_proj.weight', 'ernie.layers.13.mlp.experts.16.up_gate_proj.weight', 'ernie.layers.13.mlp.experts.17.up_gate_proj.weight', 'ernie.layers.13.mlp.experts.18.up_gate_proj.weight', 'ernie.layers.13.mlp.experts.19.up_gate_proj.weight', 'ernie.layers.13.mlp.experts.20.up_gate_proj.weight', 'ernie.layers.13.mlp.experts.21.up_gate_proj.weight', 'ernie.layers.13.mlp.experts.22.up_gate_proj.weight', 'ernie.layers.13.mlp.experts.23.up_gate_proj.weight', 'ernie.layers.13.mlp.experts.24.up_gate_proj.weight', 'ernie.layers.13.mlp.experts.25.up_gate_proj.weight', 'ernie.layers.13.mlp.experts.26.up_gate_proj.weight', 'ernie.layers.13.mlp.experts.27.up_gate_proj.weight', 'ernie.layers.13.mlp.experts.28.up_gate_proj.weight', 'ernie.layers.13.mlp.experts.29.up_gate_proj.weight', 'ernie.layers.13.mlp.experts.30.up_gate_proj.weight', 'ernie.layers.13.mlp.experts.31.up_gate_proj.weight', 'ernie.layers.13.mlp.experts.64.up_gate_proj.weight', 'ernie.layers.13.mlp.experts.65.up_gate_proj.weight', 'ernie.layers.13.mlp.experts.66.up_gate_proj.weight', 'ernie.layers.13.mlp.experts.67.up_gate_proj.weight', 'ernie.layers.13.mlp.experts.68.up_gate_proj.weight', 'ernie.layers.13.mlp.experts.69.up_gate_proj.weight', 'ernie.layers.13.mlp.experts.70.up_gate_proj.weight', 'ernie.layers.13.mlp.experts.71.up_gate_proj.weight', 'ernie.layers.13.mlp.experts.72.up_gate_proj.weight', 'ernie.layers.13.mlp.experts.73.up_gate_proj.weight', 'ernie.layers.13.mlp.experts.74.up_gate_proj.weight', 'ernie.layers.13.mlp.experts.75.up_gate_proj.weight', 'ernie.layers.13.mlp.experts.76.up_gate_proj.weight', 'ernie.layers.13.mlp.experts.77.up_gate_proj.weight', 'ernie.layers.13.mlp.experts.78.up_gate_proj.weight', 'ernie.layers.13.mlp.experts.79.up_gate_proj.weight', 'ernie.layers.13.mlp.experts.80.up_gate_proj.weight', 'ernie.layers.13.mlp.experts.81.up_gate_proj.weight', 'ernie.layers.13.mlp.experts.82.up_gate_proj.weight', 'ernie.layers.13.mlp.experts.83.up_gate_proj.weight', 'ernie.layers.13.mlp.experts.84.up_gate_proj.weight', 'ernie.layers.13.mlp.experts.85.up_gate_proj.weight', 'ernie.layers.13.mlp.experts.86.up_gate_proj.weight', 'ernie.layers.13.mlp.experts.87.up_gate_proj.weight', 'ernie.layers.13.mlp.experts.88.up_gate_proj.weight', 'ernie.layers.13.mlp.experts.89.up_gate_proj.weight', 'ernie.layers.13.mlp.experts.90.up_gate_proj.weight', 'ernie.layers.13.mlp.experts.91.up_gate_proj.weight', 'ernie.layers.13.mlp.experts.92.up_gate_proj.weight', 'ernie.layers.13.mlp.experts.93.up_gate_proj.weight', 'ernie.layers.13.mlp.experts.94.up_gate_proj.weight', 'ernie.layers.13.mlp.experts.95.up_gate_proj.weight'] ernie.layers.13.mlp.text_fused_moe.experts.down_proj_weight:['ernie.layers.13.mlp.experts.0.down_proj.weight', 'ernie.layers.13.mlp.experts.1.down_proj.weight', 'ernie.layers.13.mlp.experts.2.down_proj.weight', 'ernie.layers.13.mlp.experts.3.down_proj.weight', 'ernie.layers.13.mlp.experts.4.down_proj.weight', 'ernie.layers.13.mlp.experts.5.down_proj.weight', 'ernie.layers.13.mlp.experts.6.down_proj.weight', 'ernie.layers.13.mlp.experts.7.down_proj.weight', 'ernie.layers.13.mlp.experts.8.down_proj.weight', 'ernie.layers.13.mlp.experts.9.down_proj.weight', 'ernie.layers.13.mlp.experts.10.down_proj.weight', 'ernie.layers.13.mlp.experts.11.down_proj.weight', 'ernie.layers.13.mlp.experts.12.down_proj.weight', 'ernie.layers.13.mlp.experts.13.down_proj.weight', 'ernie.layers.13.mlp.experts.14.down_proj.weight', 'ernie.layers.13.mlp.experts.15.down_proj.weight', 'ernie.layers.13.mlp.experts.16.down_proj.weight', 'ernie.layers.13.mlp.experts.17.down_proj.weight', 'ernie.layers.13.mlp.experts.18.down_proj.weight', 'ernie.layers.13.mlp.experts.19.down_proj.weight', 'ernie.layers.13.mlp.experts.20.down_proj.weight', 'ernie.layers.13.mlp.experts.21.down_proj.weight', 'ernie.layers.13.mlp.experts.22.down_proj.weight', 'ernie.layers.13.mlp.experts.23.down_proj.weight', 'ernie.layers.13.mlp.experts.24.down_proj.weight', 'ernie.layers.13.mlp.experts.25.down_proj.weight', 'ernie.layers.13.mlp.experts.26.down_proj.weight', 'ernie.layers.13.mlp.experts.27.down_proj.weight', 'ernie.layers.13.mlp.experts.28.down_proj.weight', 'ernie.layers.13.mlp.experts.29.down_proj.weight', 'ernie.layers.13.mlp.experts.30.down_proj.weight', 'ernie.layers.13.mlp.experts.31.down_proj.weight', 'ernie.layers.13.mlp.experts.64.down_proj.weight', 'ernie.layers.13.mlp.experts.65.down_proj.weight', 'ernie.layers.13.mlp.experts.66.down_proj.weight', 'ernie.layers.13.mlp.experts.67.down_proj.weight', 'ernie.layers.13.mlp.experts.68.down_proj.weight', 'ernie.layers.13.mlp.experts.69.down_proj.weight', 'ernie.layers.13.mlp.experts.70.down_proj.weight', 'ernie.layers.13.mlp.experts.71.down_proj.weight', 'ernie.layers.13.mlp.experts.72.down_proj.weight', 'ernie.layers.13.mlp.experts.73.down_proj.weight', 'ernie.layers.13.mlp.experts.74.down_proj.weight', 'ernie.layers.13.mlp.experts.75.down_proj.weight', 'ernie.layers.13.mlp.experts.76.down_proj.weight', 'ernie.layers.13.mlp.experts.77.down_proj.weight', 'ernie.layers.13.mlp.experts.78.down_proj.weight', 'ernie.layers.13.mlp.experts.79.down_proj.weight', 'ernie.layers.13.mlp.experts.80.down_proj.weight', 'ernie.layers.13.mlp.experts.81.down_proj.weight', 'ernie.layers.13.mlp.experts.82.down_proj.weight', 'ernie.layers.13.mlp.experts.83.down_proj.weight', 'ernie.layers.13.mlp.experts.84.down_proj.weight', 'ernie.layers.13.mlp.experts.85.down_proj.weight', 'ernie.layers.13.mlp.experts.86.down_proj.weight', 'ernie.layers.13.mlp.experts.87.down_proj.weight', 'ernie.layers.13.mlp.experts.88.down_proj.weight', 'ernie.layers.13.mlp.experts.89.down_proj.weight', 'ernie.layers.13.mlp.experts.90.down_proj.weight', 'ernie.layers.13.mlp.experts.91.down_proj.weight', 'ernie.layers.13.mlp.experts.92.down_proj.weight', 'ernie.layers.13.mlp.experts.93.down_proj.weight', 'ernie.layers.13.mlp.experts.94.down_proj.weight', 'ernie.layers.13.mlp.experts.95.down_proj.weight'] ernie.layers.14.mlp.text_fused_moe.gate.weight:ernie.layers.14.mlp.gate.weight -ernie.layers.14.mlp.text_fused_moe.experts.gate_correction_bias:ernie.layers.14.mlp.moe_statics.e_score_correction_bias +ernie.layers.14.mlp.gate_correction_bias:ernie.layers.14.mlp.moe_statics.e_score_correction_bias ernie.layers.14.mlp.text_fused_moe.experts.up_gate_proj_weight:['ernie.layers.14.mlp.experts.0.up_gate_proj.weight', 'ernie.layers.14.mlp.experts.1.up_gate_proj.weight', 'ernie.layers.14.mlp.experts.2.up_gate_proj.weight', 'ernie.layers.14.mlp.experts.3.up_gate_proj.weight', 'ernie.layers.14.mlp.experts.4.up_gate_proj.weight', 'ernie.layers.14.mlp.experts.5.up_gate_proj.weight', 'ernie.layers.14.mlp.experts.6.up_gate_proj.weight', 'ernie.layers.14.mlp.experts.7.up_gate_proj.weight', 'ernie.layers.14.mlp.experts.8.up_gate_proj.weight', 'ernie.layers.14.mlp.experts.9.up_gate_proj.weight', 'ernie.layers.14.mlp.experts.10.up_gate_proj.weight', 'ernie.layers.14.mlp.experts.11.up_gate_proj.weight', 'ernie.layers.14.mlp.experts.12.up_gate_proj.weight', 'ernie.layers.14.mlp.experts.13.up_gate_proj.weight', 'ernie.layers.14.mlp.experts.14.up_gate_proj.weight', 'ernie.layers.14.mlp.experts.15.up_gate_proj.weight', 'ernie.layers.14.mlp.experts.16.up_gate_proj.weight', 'ernie.layers.14.mlp.experts.17.up_gate_proj.weight', 'ernie.layers.14.mlp.experts.18.up_gate_proj.weight', 'ernie.layers.14.mlp.experts.19.up_gate_proj.weight', 'ernie.layers.14.mlp.experts.20.up_gate_proj.weight', 'ernie.layers.14.mlp.experts.21.up_gate_proj.weight', 'ernie.layers.14.mlp.experts.22.up_gate_proj.weight', 'ernie.layers.14.mlp.experts.23.up_gate_proj.weight', 'ernie.layers.14.mlp.experts.24.up_gate_proj.weight', 'ernie.layers.14.mlp.experts.25.up_gate_proj.weight', 'ernie.layers.14.mlp.experts.26.up_gate_proj.weight', 'ernie.layers.14.mlp.experts.27.up_gate_proj.weight', 'ernie.layers.14.mlp.experts.28.up_gate_proj.weight', 'ernie.layers.14.mlp.experts.29.up_gate_proj.weight', 'ernie.layers.14.mlp.experts.30.up_gate_proj.weight', 'ernie.layers.14.mlp.experts.31.up_gate_proj.weight', 'ernie.layers.14.mlp.experts.64.up_gate_proj.weight', 'ernie.layers.14.mlp.experts.65.up_gate_proj.weight', 'ernie.layers.14.mlp.experts.66.up_gate_proj.weight', 'ernie.layers.14.mlp.experts.67.up_gate_proj.weight', 'ernie.layers.14.mlp.experts.68.up_gate_proj.weight', 'ernie.layers.14.mlp.experts.69.up_gate_proj.weight', 'ernie.layers.14.mlp.experts.70.up_gate_proj.weight', 'ernie.layers.14.mlp.experts.71.up_gate_proj.weight', 'ernie.layers.14.mlp.experts.72.up_gate_proj.weight', 'ernie.layers.14.mlp.experts.73.up_gate_proj.weight', 'ernie.layers.14.mlp.experts.74.up_gate_proj.weight', 'ernie.layers.14.mlp.experts.75.up_gate_proj.weight', 'ernie.layers.14.mlp.experts.76.up_gate_proj.weight', 'ernie.layers.14.mlp.experts.77.up_gate_proj.weight', 'ernie.layers.14.mlp.experts.78.up_gate_proj.weight', 'ernie.layers.14.mlp.experts.79.up_gate_proj.weight', 'ernie.layers.14.mlp.experts.80.up_gate_proj.weight', 'ernie.layers.14.mlp.experts.81.up_gate_proj.weight', 'ernie.layers.14.mlp.experts.82.up_gate_proj.weight', 'ernie.layers.14.mlp.experts.83.up_gate_proj.weight', 'ernie.layers.14.mlp.experts.84.up_gate_proj.weight', 'ernie.layers.14.mlp.experts.85.up_gate_proj.weight', 'ernie.layers.14.mlp.experts.86.up_gate_proj.weight', 'ernie.layers.14.mlp.experts.87.up_gate_proj.weight', 'ernie.layers.14.mlp.experts.88.up_gate_proj.weight', 'ernie.layers.14.mlp.experts.89.up_gate_proj.weight', 'ernie.layers.14.mlp.experts.90.up_gate_proj.weight', 'ernie.layers.14.mlp.experts.91.up_gate_proj.weight', 'ernie.layers.14.mlp.experts.92.up_gate_proj.weight', 'ernie.layers.14.mlp.experts.93.up_gate_proj.weight', 'ernie.layers.14.mlp.experts.94.up_gate_proj.weight', 'ernie.layers.14.mlp.experts.95.up_gate_proj.weight'] ernie.layers.14.mlp.text_fused_moe.experts.down_proj_weight:['ernie.layers.14.mlp.experts.0.down_proj.weight', 'ernie.layers.14.mlp.experts.1.down_proj.weight', 'ernie.layers.14.mlp.experts.2.down_proj.weight', 'ernie.layers.14.mlp.experts.3.down_proj.weight', 'ernie.layers.14.mlp.experts.4.down_proj.weight', 'ernie.layers.14.mlp.experts.5.down_proj.weight', 'ernie.layers.14.mlp.experts.6.down_proj.weight', 'ernie.layers.14.mlp.experts.7.down_proj.weight', 'ernie.layers.14.mlp.experts.8.down_proj.weight', 'ernie.layers.14.mlp.experts.9.down_proj.weight', 'ernie.layers.14.mlp.experts.10.down_proj.weight', 'ernie.layers.14.mlp.experts.11.down_proj.weight', 'ernie.layers.14.mlp.experts.12.down_proj.weight', 'ernie.layers.14.mlp.experts.13.down_proj.weight', 'ernie.layers.14.mlp.experts.14.down_proj.weight', 'ernie.layers.14.mlp.experts.15.down_proj.weight', 'ernie.layers.14.mlp.experts.16.down_proj.weight', 'ernie.layers.14.mlp.experts.17.down_proj.weight', 'ernie.layers.14.mlp.experts.18.down_proj.weight', 'ernie.layers.14.mlp.experts.19.down_proj.weight', 'ernie.layers.14.mlp.experts.20.down_proj.weight', 'ernie.layers.14.mlp.experts.21.down_proj.weight', 'ernie.layers.14.mlp.experts.22.down_proj.weight', 'ernie.layers.14.mlp.experts.23.down_proj.weight', 'ernie.layers.14.mlp.experts.24.down_proj.weight', 'ernie.layers.14.mlp.experts.25.down_proj.weight', 'ernie.layers.14.mlp.experts.26.down_proj.weight', 'ernie.layers.14.mlp.experts.27.down_proj.weight', 'ernie.layers.14.mlp.experts.28.down_proj.weight', 'ernie.layers.14.mlp.experts.29.down_proj.weight', 'ernie.layers.14.mlp.experts.30.down_proj.weight', 'ernie.layers.14.mlp.experts.31.down_proj.weight', 'ernie.layers.14.mlp.experts.64.down_proj.weight', 'ernie.layers.14.mlp.experts.65.down_proj.weight', 'ernie.layers.14.mlp.experts.66.down_proj.weight', 'ernie.layers.14.mlp.experts.67.down_proj.weight', 'ernie.layers.14.mlp.experts.68.down_proj.weight', 'ernie.layers.14.mlp.experts.69.down_proj.weight', 'ernie.layers.14.mlp.experts.70.down_proj.weight', 'ernie.layers.14.mlp.experts.71.down_proj.weight', 'ernie.layers.14.mlp.experts.72.down_proj.weight', 'ernie.layers.14.mlp.experts.73.down_proj.weight', 'ernie.layers.14.mlp.experts.74.down_proj.weight', 'ernie.layers.14.mlp.experts.75.down_proj.weight', 'ernie.layers.14.mlp.experts.76.down_proj.weight', 'ernie.layers.14.mlp.experts.77.down_proj.weight', 'ernie.layers.14.mlp.experts.78.down_proj.weight', 'ernie.layers.14.mlp.experts.79.down_proj.weight', 'ernie.layers.14.mlp.experts.80.down_proj.weight', 'ernie.layers.14.mlp.experts.81.down_proj.weight', 'ernie.layers.14.mlp.experts.82.down_proj.weight', 'ernie.layers.14.mlp.experts.83.down_proj.weight', 'ernie.layers.14.mlp.experts.84.down_proj.weight', 'ernie.layers.14.mlp.experts.85.down_proj.weight', 'ernie.layers.14.mlp.experts.86.down_proj.weight', 'ernie.layers.14.mlp.experts.87.down_proj.weight', 'ernie.layers.14.mlp.experts.88.down_proj.weight', 'ernie.layers.14.mlp.experts.89.down_proj.weight', 'ernie.layers.14.mlp.experts.90.down_proj.weight', 'ernie.layers.14.mlp.experts.91.down_proj.weight', 'ernie.layers.14.mlp.experts.92.down_proj.weight', 'ernie.layers.14.mlp.experts.93.down_proj.weight', 'ernie.layers.14.mlp.experts.94.down_proj.weight', 'ernie.layers.14.mlp.experts.95.down_proj.weight'] ernie.layers.15.mlp.text_fused_moe.gate.weight:ernie.layers.15.mlp.gate.weight -ernie.layers.15.mlp.text_fused_moe.experts.gate_correction_bias:ernie.layers.15.mlp.moe_statics.e_score_correction_bias +ernie.layers.15.mlp.gate_correction_bias:ernie.layers.15.mlp.moe_statics.e_score_correction_bias ernie.layers.15.mlp.text_fused_moe.experts.up_gate_proj_weight:['ernie.layers.15.mlp.experts.0.up_gate_proj.weight', 'ernie.layers.15.mlp.experts.1.up_gate_proj.weight', 'ernie.layers.15.mlp.experts.2.up_gate_proj.weight', 'ernie.layers.15.mlp.experts.3.up_gate_proj.weight', 'ernie.layers.15.mlp.experts.4.up_gate_proj.weight', 'ernie.layers.15.mlp.experts.5.up_gate_proj.weight', 'ernie.layers.15.mlp.experts.6.up_gate_proj.weight', 'ernie.layers.15.mlp.experts.7.up_gate_proj.weight', 'ernie.layers.15.mlp.experts.8.up_gate_proj.weight', 'ernie.layers.15.mlp.experts.9.up_gate_proj.weight', 'ernie.layers.15.mlp.experts.10.up_gate_proj.weight', 'ernie.layers.15.mlp.experts.11.up_gate_proj.weight', 'ernie.layers.15.mlp.experts.12.up_gate_proj.weight', 'ernie.layers.15.mlp.experts.13.up_gate_proj.weight', 'ernie.layers.15.mlp.experts.14.up_gate_proj.weight', 'ernie.layers.15.mlp.experts.15.up_gate_proj.weight', 'ernie.layers.15.mlp.experts.16.up_gate_proj.weight', 'ernie.layers.15.mlp.experts.17.up_gate_proj.weight', 'ernie.layers.15.mlp.experts.18.up_gate_proj.weight', 'ernie.layers.15.mlp.experts.19.up_gate_proj.weight', 'ernie.layers.15.mlp.experts.20.up_gate_proj.weight', 'ernie.layers.15.mlp.experts.21.up_gate_proj.weight', 'ernie.layers.15.mlp.experts.22.up_gate_proj.weight', 'ernie.layers.15.mlp.experts.23.up_gate_proj.weight', 'ernie.layers.15.mlp.experts.24.up_gate_proj.weight', 'ernie.layers.15.mlp.experts.25.up_gate_proj.weight', 'ernie.layers.15.mlp.experts.26.up_gate_proj.weight', 'ernie.layers.15.mlp.experts.27.up_gate_proj.weight', 'ernie.layers.15.mlp.experts.28.up_gate_proj.weight', 'ernie.layers.15.mlp.experts.29.up_gate_proj.weight', 'ernie.layers.15.mlp.experts.30.up_gate_proj.weight', 'ernie.layers.15.mlp.experts.31.up_gate_proj.weight', 'ernie.layers.15.mlp.experts.64.up_gate_proj.weight', 'ernie.layers.15.mlp.experts.65.up_gate_proj.weight', 'ernie.layers.15.mlp.experts.66.up_gate_proj.weight', 'ernie.layers.15.mlp.experts.67.up_gate_proj.weight', 'ernie.layers.15.mlp.experts.68.up_gate_proj.weight', 'ernie.layers.15.mlp.experts.69.up_gate_proj.weight', 'ernie.layers.15.mlp.experts.70.up_gate_proj.weight', 'ernie.layers.15.mlp.experts.71.up_gate_proj.weight', 'ernie.layers.15.mlp.experts.72.up_gate_proj.weight', 'ernie.layers.15.mlp.experts.73.up_gate_proj.weight', 'ernie.layers.15.mlp.experts.74.up_gate_proj.weight', 'ernie.layers.15.mlp.experts.75.up_gate_proj.weight', 'ernie.layers.15.mlp.experts.76.up_gate_proj.weight', 'ernie.layers.15.mlp.experts.77.up_gate_proj.weight', 'ernie.layers.15.mlp.experts.78.up_gate_proj.weight', 'ernie.layers.15.mlp.experts.79.up_gate_proj.weight', 'ernie.layers.15.mlp.experts.80.up_gate_proj.weight', 'ernie.layers.15.mlp.experts.81.up_gate_proj.weight', 'ernie.layers.15.mlp.experts.82.up_gate_proj.weight', 'ernie.layers.15.mlp.experts.83.up_gate_proj.weight', 'ernie.layers.15.mlp.experts.84.up_gate_proj.weight', 'ernie.layers.15.mlp.experts.85.up_gate_proj.weight', 'ernie.layers.15.mlp.experts.86.up_gate_proj.weight', 'ernie.layers.15.mlp.experts.87.up_gate_proj.weight', 'ernie.layers.15.mlp.experts.88.up_gate_proj.weight', 'ernie.layers.15.mlp.experts.89.up_gate_proj.weight', 'ernie.layers.15.mlp.experts.90.up_gate_proj.weight', 'ernie.layers.15.mlp.experts.91.up_gate_proj.weight', 'ernie.layers.15.mlp.experts.92.up_gate_proj.weight', 'ernie.layers.15.mlp.experts.93.up_gate_proj.weight', 'ernie.layers.15.mlp.experts.94.up_gate_proj.weight', 'ernie.layers.15.mlp.experts.95.up_gate_proj.weight'] ernie.layers.15.mlp.text_fused_moe.experts.down_proj_weight:['ernie.layers.15.mlp.experts.0.down_proj.weight', 'ernie.layers.15.mlp.experts.1.down_proj.weight', 'ernie.layers.15.mlp.experts.2.down_proj.weight', 'ernie.layers.15.mlp.experts.3.down_proj.weight', 'ernie.layers.15.mlp.experts.4.down_proj.weight', 'ernie.layers.15.mlp.experts.5.down_proj.weight', 'ernie.layers.15.mlp.experts.6.down_proj.weight', 'ernie.layers.15.mlp.experts.7.down_proj.weight', 'ernie.layers.15.mlp.experts.8.down_proj.weight', 'ernie.layers.15.mlp.experts.9.down_proj.weight', 'ernie.layers.15.mlp.experts.10.down_proj.weight', 'ernie.layers.15.mlp.experts.11.down_proj.weight', 'ernie.layers.15.mlp.experts.12.down_proj.weight', 'ernie.layers.15.mlp.experts.13.down_proj.weight', 'ernie.layers.15.mlp.experts.14.down_proj.weight', 'ernie.layers.15.mlp.experts.15.down_proj.weight', 'ernie.layers.15.mlp.experts.16.down_proj.weight', 'ernie.layers.15.mlp.experts.17.down_proj.weight', 'ernie.layers.15.mlp.experts.18.down_proj.weight', 'ernie.layers.15.mlp.experts.19.down_proj.weight', 'ernie.layers.15.mlp.experts.20.down_proj.weight', 'ernie.layers.15.mlp.experts.21.down_proj.weight', 'ernie.layers.15.mlp.experts.22.down_proj.weight', 'ernie.layers.15.mlp.experts.23.down_proj.weight', 'ernie.layers.15.mlp.experts.24.down_proj.weight', 'ernie.layers.15.mlp.experts.25.down_proj.weight', 'ernie.layers.15.mlp.experts.26.down_proj.weight', 'ernie.layers.15.mlp.experts.27.down_proj.weight', 'ernie.layers.15.mlp.experts.28.down_proj.weight', 'ernie.layers.15.mlp.experts.29.down_proj.weight', 'ernie.layers.15.mlp.experts.30.down_proj.weight', 'ernie.layers.15.mlp.experts.31.down_proj.weight', 'ernie.layers.15.mlp.experts.64.down_proj.weight', 'ernie.layers.15.mlp.experts.65.down_proj.weight', 'ernie.layers.15.mlp.experts.66.down_proj.weight', 'ernie.layers.15.mlp.experts.67.down_proj.weight', 'ernie.layers.15.mlp.experts.68.down_proj.weight', 'ernie.layers.15.mlp.experts.69.down_proj.weight', 'ernie.layers.15.mlp.experts.70.down_proj.weight', 'ernie.layers.15.mlp.experts.71.down_proj.weight', 'ernie.layers.15.mlp.experts.72.down_proj.weight', 'ernie.layers.15.mlp.experts.73.down_proj.weight', 'ernie.layers.15.mlp.experts.74.down_proj.weight', 'ernie.layers.15.mlp.experts.75.down_proj.weight', 'ernie.layers.15.mlp.experts.76.down_proj.weight', 'ernie.layers.15.mlp.experts.77.down_proj.weight', 'ernie.layers.15.mlp.experts.78.down_proj.weight', 'ernie.layers.15.mlp.experts.79.down_proj.weight', 'ernie.layers.15.mlp.experts.80.down_proj.weight', 'ernie.layers.15.mlp.experts.81.down_proj.weight', 'ernie.layers.15.mlp.experts.82.down_proj.weight', 'ernie.layers.15.mlp.experts.83.down_proj.weight', 'ernie.layers.15.mlp.experts.84.down_proj.weight', 'ernie.layers.15.mlp.experts.85.down_proj.weight', 'ernie.layers.15.mlp.experts.86.down_proj.weight', 'ernie.layers.15.mlp.experts.87.down_proj.weight', 'ernie.layers.15.mlp.experts.88.down_proj.weight', 'ernie.layers.15.mlp.experts.89.down_proj.weight', 'ernie.layers.15.mlp.experts.90.down_proj.weight', 'ernie.layers.15.mlp.experts.91.down_proj.weight', 'ernie.layers.15.mlp.experts.92.down_proj.weight', 'ernie.layers.15.mlp.experts.93.down_proj.weight', 'ernie.layers.15.mlp.experts.94.down_proj.weight', 'ernie.layers.15.mlp.experts.95.down_proj.weight'] ernie.layers.16.mlp.text_fused_moe.gate.weight:ernie.layers.16.mlp.gate.weight -ernie.layers.16.mlp.text_fused_moe.experts.gate_correction_bias:ernie.layers.16.mlp.moe_statics.e_score_correction_bias +ernie.layers.16.mlp.gate_correction_bias:ernie.layers.16.mlp.moe_statics.e_score_correction_bias ernie.layers.16.mlp.text_fused_moe.experts.up_gate_proj_weight:['ernie.layers.16.mlp.experts.0.up_gate_proj.weight', 'ernie.layers.16.mlp.experts.1.up_gate_proj.weight', 'ernie.layers.16.mlp.experts.2.up_gate_proj.weight', 'ernie.layers.16.mlp.experts.3.up_gate_proj.weight', 'ernie.layers.16.mlp.experts.4.up_gate_proj.weight', 'ernie.layers.16.mlp.experts.5.up_gate_proj.weight', 'ernie.layers.16.mlp.experts.6.up_gate_proj.weight', 'ernie.layers.16.mlp.experts.7.up_gate_proj.weight', 'ernie.layers.16.mlp.experts.8.up_gate_proj.weight', 'ernie.layers.16.mlp.experts.9.up_gate_proj.weight', 'ernie.layers.16.mlp.experts.10.up_gate_proj.weight', 'ernie.layers.16.mlp.experts.11.up_gate_proj.weight', 'ernie.layers.16.mlp.experts.12.up_gate_proj.weight', 'ernie.layers.16.mlp.experts.13.up_gate_proj.weight', 'ernie.layers.16.mlp.experts.14.up_gate_proj.weight', 'ernie.layers.16.mlp.experts.15.up_gate_proj.weight', 'ernie.layers.16.mlp.experts.16.up_gate_proj.weight', 'ernie.layers.16.mlp.experts.17.up_gate_proj.weight', 'ernie.layers.16.mlp.experts.18.up_gate_proj.weight', 'ernie.layers.16.mlp.experts.19.up_gate_proj.weight', 'ernie.layers.16.mlp.experts.20.up_gate_proj.weight', 'ernie.layers.16.mlp.experts.21.up_gate_proj.weight', 'ernie.layers.16.mlp.experts.22.up_gate_proj.weight', 'ernie.layers.16.mlp.experts.23.up_gate_proj.weight', 'ernie.layers.16.mlp.experts.24.up_gate_proj.weight', 'ernie.layers.16.mlp.experts.25.up_gate_proj.weight', 'ernie.layers.16.mlp.experts.26.up_gate_proj.weight', 'ernie.layers.16.mlp.experts.27.up_gate_proj.weight', 'ernie.layers.16.mlp.experts.28.up_gate_proj.weight', 'ernie.layers.16.mlp.experts.29.up_gate_proj.weight', 'ernie.layers.16.mlp.experts.30.up_gate_proj.weight', 'ernie.layers.16.mlp.experts.31.up_gate_proj.weight', 'ernie.layers.16.mlp.experts.64.up_gate_proj.weight', 'ernie.layers.16.mlp.experts.65.up_gate_proj.weight', 'ernie.layers.16.mlp.experts.66.up_gate_proj.weight', 'ernie.layers.16.mlp.experts.67.up_gate_proj.weight', 'ernie.layers.16.mlp.experts.68.up_gate_proj.weight', 'ernie.layers.16.mlp.experts.69.up_gate_proj.weight', 'ernie.layers.16.mlp.experts.70.up_gate_proj.weight', 'ernie.layers.16.mlp.experts.71.up_gate_proj.weight', 'ernie.layers.16.mlp.experts.72.up_gate_proj.weight', 'ernie.layers.16.mlp.experts.73.up_gate_proj.weight', 'ernie.layers.16.mlp.experts.74.up_gate_proj.weight', 'ernie.layers.16.mlp.experts.75.up_gate_proj.weight', 'ernie.layers.16.mlp.experts.76.up_gate_proj.weight', 'ernie.layers.16.mlp.experts.77.up_gate_proj.weight', 'ernie.layers.16.mlp.experts.78.up_gate_proj.weight', 'ernie.layers.16.mlp.experts.79.up_gate_proj.weight', 'ernie.layers.16.mlp.experts.80.up_gate_proj.weight', 'ernie.layers.16.mlp.experts.81.up_gate_proj.weight', 'ernie.layers.16.mlp.experts.82.up_gate_proj.weight', 'ernie.layers.16.mlp.experts.83.up_gate_proj.weight', 'ernie.layers.16.mlp.experts.84.up_gate_proj.weight', 'ernie.layers.16.mlp.experts.85.up_gate_proj.weight', 'ernie.layers.16.mlp.experts.86.up_gate_proj.weight', 'ernie.layers.16.mlp.experts.87.up_gate_proj.weight', 'ernie.layers.16.mlp.experts.88.up_gate_proj.weight', 'ernie.layers.16.mlp.experts.89.up_gate_proj.weight', 'ernie.layers.16.mlp.experts.90.up_gate_proj.weight', 'ernie.layers.16.mlp.experts.91.up_gate_proj.weight', 'ernie.layers.16.mlp.experts.92.up_gate_proj.weight', 'ernie.layers.16.mlp.experts.93.up_gate_proj.weight', 'ernie.layers.16.mlp.experts.94.up_gate_proj.weight', 'ernie.layers.16.mlp.experts.95.up_gate_proj.weight'] ernie.layers.16.mlp.text_fused_moe.experts.down_proj_weight:['ernie.layers.16.mlp.experts.0.down_proj.weight', 'ernie.layers.16.mlp.experts.1.down_proj.weight', 'ernie.layers.16.mlp.experts.2.down_proj.weight', 'ernie.layers.16.mlp.experts.3.down_proj.weight', 'ernie.layers.16.mlp.experts.4.down_proj.weight', 'ernie.layers.16.mlp.experts.5.down_proj.weight', 'ernie.layers.16.mlp.experts.6.down_proj.weight', 'ernie.layers.16.mlp.experts.7.down_proj.weight', 'ernie.layers.16.mlp.experts.8.down_proj.weight', 'ernie.layers.16.mlp.experts.9.down_proj.weight', 'ernie.layers.16.mlp.experts.10.down_proj.weight', 'ernie.layers.16.mlp.experts.11.down_proj.weight', 'ernie.layers.16.mlp.experts.12.down_proj.weight', 'ernie.layers.16.mlp.experts.13.down_proj.weight', 'ernie.layers.16.mlp.experts.14.down_proj.weight', 'ernie.layers.16.mlp.experts.15.down_proj.weight', 'ernie.layers.16.mlp.experts.16.down_proj.weight', 'ernie.layers.16.mlp.experts.17.down_proj.weight', 'ernie.layers.16.mlp.experts.18.down_proj.weight', 'ernie.layers.16.mlp.experts.19.down_proj.weight', 'ernie.layers.16.mlp.experts.20.down_proj.weight', 'ernie.layers.16.mlp.experts.21.down_proj.weight', 'ernie.layers.16.mlp.experts.22.down_proj.weight', 'ernie.layers.16.mlp.experts.23.down_proj.weight', 'ernie.layers.16.mlp.experts.24.down_proj.weight', 'ernie.layers.16.mlp.experts.25.down_proj.weight', 'ernie.layers.16.mlp.experts.26.down_proj.weight', 'ernie.layers.16.mlp.experts.27.down_proj.weight', 'ernie.layers.16.mlp.experts.28.down_proj.weight', 'ernie.layers.16.mlp.experts.29.down_proj.weight', 'ernie.layers.16.mlp.experts.30.down_proj.weight', 'ernie.layers.16.mlp.experts.31.down_proj.weight', 'ernie.layers.16.mlp.experts.64.down_proj.weight', 'ernie.layers.16.mlp.experts.65.down_proj.weight', 'ernie.layers.16.mlp.experts.66.down_proj.weight', 'ernie.layers.16.mlp.experts.67.down_proj.weight', 'ernie.layers.16.mlp.experts.68.down_proj.weight', 'ernie.layers.16.mlp.experts.69.down_proj.weight', 'ernie.layers.16.mlp.experts.70.down_proj.weight', 'ernie.layers.16.mlp.experts.71.down_proj.weight', 'ernie.layers.16.mlp.experts.72.down_proj.weight', 'ernie.layers.16.mlp.experts.73.down_proj.weight', 'ernie.layers.16.mlp.experts.74.down_proj.weight', 'ernie.layers.16.mlp.experts.75.down_proj.weight', 'ernie.layers.16.mlp.experts.76.down_proj.weight', 'ernie.layers.16.mlp.experts.77.down_proj.weight', 'ernie.layers.16.mlp.experts.78.down_proj.weight', 'ernie.layers.16.mlp.experts.79.down_proj.weight', 'ernie.layers.16.mlp.experts.80.down_proj.weight', 'ernie.layers.16.mlp.experts.81.down_proj.weight', 'ernie.layers.16.mlp.experts.82.down_proj.weight', 'ernie.layers.16.mlp.experts.83.down_proj.weight', 'ernie.layers.16.mlp.experts.84.down_proj.weight', 'ernie.layers.16.mlp.experts.85.down_proj.weight', 'ernie.layers.16.mlp.experts.86.down_proj.weight', 'ernie.layers.16.mlp.experts.87.down_proj.weight', 'ernie.layers.16.mlp.experts.88.down_proj.weight', 'ernie.layers.16.mlp.experts.89.down_proj.weight', 'ernie.layers.16.mlp.experts.90.down_proj.weight', 'ernie.layers.16.mlp.experts.91.down_proj.weight', 'ernie.layers.16.mlp.experts.92.down_proj.weight', 'ernie.layers.16.mlp.experts.93.down_proj.weight', 'ernie.layers.16.mlp.experts.94.down_proj.weight', 'ernie.layers.16.mlp.experts.95.down_proj.weight'] ernie.layers.17.mlp.text_fused_moe.gate.weight:ernie.layers.17.mlp.gate.weight -ernie.layers.17.mlp.text_fused_moe.experts.gate_correction_bias:ernie.layers.17.mlp.moe_statics.e_score_correction_bias +ernie.layers.17.mlp.gate_correction_bias:ernie.layers.17.mlp.moe_statics.e_score_correction_bias ernie.layers.17.mlp.text_fused_moe.experts.up_gate_proj_weight:['ernie.layers.17.mlp.experts.0.up_gate_proj.weight', 'ernie.layers.17.mlp.experts.1.up_gate_proj.weight', 'ernie.layers.17.mlp.experts.2.up_gate_proj.weight', 'ernie.layers.17.mlp.experts.3.up_gate_proj.weight', 'ernie.layers.17.mlp.experts.4.up_gate_proj.weight', 'ernie.layers.17.mlp.experts.5.up_gate_proj.weight', 'ernie.layers.17.mlp.experts.6.up_gate_proj.weight', 'ernie.layers.17.mlp.experts.7.up_gate_proj.weight', 'ernie.layers.17.mlp.experts.8.up_gate_proj.weight', 'ernie.layers.17.mlp.experts.9.up_gate_proj.weight', 'ernie.layers.17.mlp.experts.10.up_gate_proj.weight', 'ernie.layers.17.mlp.experts.11.up_gate_proj.weight', 'ernie.layers.17.mlp.experts.12.up_gate_proj.weight', 'ernie.layers.17.mlp.experts.13.up_gate_proj.weight', 'ernie.layers.17.mlp.experts.14.up_gate_proj.weight', 'ernie.layers.17.mlp.experts.15.up_gate_proj.weight', 'ernie.layers.17.mlp.experts.16.up_gate_proj.weight', 'ernie.layers.17.mlp.experts.17.up_gate_proj.weight', 'ernie.layers.17.mlp.experts.18.up_gate_proj.weight', 'ernie.layers.17.mlp.experts.19.up_gate_proj.weight', 'ernie.layers.17.mlp.experts.20.up_gate_proj.weight', 'ernie.layers.17.mlp.experts.21.up_gate_proj.weight', 'ernie.layers.17.mlp.experts.22.up_gate_proj.weight', 'ernie.layers.17.mlp.experts.23.up_gate_proj.weight', 'ernie.layers.17.mlp.experts.24.up_gate_proj.weight', 'ernie.layers.17.mlp.experts.25.up_gate_proj.weight', 'ernie.layers.17.mlp.experts.26.up_gate_proj.weight', 'ernie.layers.17.mlp.experts.27.up_gate_proj.weight', 'ernie.layers.17.mlp.experts.28.up_gate_proj.weight', 'ernie.layers.17.mlp.experts.29.up_gate_proj.weight', 'ernie.layers.17.mlp.experts.30.up_gate_proj.weight', 'ernie.layers.17.mlp.experts.31.up_gate_proj.weight', 'ernie.layers.17.mlp.experts.64.up_gate_proj.weight', 'ernie.layers.17.mlp.experts.65.up_gate_proj.weight', 'ernie.layers.17.mlp.experts.66.up_gate_proj.weight', 'ernie.layers.17.mlp.experts.67.up_gate_proj.weight', 'ernie.layers.17.mlp.experts.68.up_gate_proj.weight', 'ernie.layers.17.mlp.experts.69.up_gate_proj.weight', 'ernie.layers.17.mlp.experts.70.up_gate_proj.weight', 'ernie.layers.17.mlp.experts.71.up_gate_proj.weight', 'ernie.layers.17.mlp.experts.72.up_gate_proj.weight', 'ernie.layers.17.mlp.experts.73.up_gate_proj.weight', 'ernie.layers.17.mlp.experts.74.up_gate_proj.weight', 'ernie.layers.17.mlp.experts.75.up_gate_proj.weight', 'ernie.layers.17.mlp.experts.76.up_gate_proj.weight', 'ernie.layers.17.mlp.experts.77.up_gate_proj.weight', 'ernie.layers.17.mlp.experts.78.up_gate_proj.weight', 'ernie.layers.17.mlp.experts.79.up_gate_proj.weight', 'ernie.layers.17.mlp.experts.80.up_gate_proj.weight', 'ernie.layers.17.mlp.experts.81.up_gate_proj.weight', 'ernie.layers.17.mlp.experts.82.up_gate_proj.weight', 'ernie.layers.17.mlp.experts.83.up_gate_proj.weight', 'ernie.layers.17.mlp.experts.84.up_gate_proj.weight', 'ernie.layers.17.mlp.experts.85.up_gate_proj.weight', 'ernie.layers.17.mlp.experts.86.up_gate_proj.weight', 'ernie.layers.17.mlp.experts.87.up_gate_proj.weight', 'ernie.layers.17.mlp.experts.88.up_gate_proj.weight', 'ernie.layers.17.mlp.experts.89.up_gate_proj.weight', 'ernie.layers.17.mlp.experts.90.up_gate_proj.weight', 'ernie.layers.17.mlp.experts.91.up_gate_proj.weight', 'ernie.layers.17.mlp.experts.92.up_gate_proj.weight', 'ernie.layers.17.mlp.experts.93.up_gate_proj.weight', 'ernie.layers.17.mlp.experts.94.up_gate_proj.weight', 'ernie.layers.17.mlp.experts.95.up_gate_proj.weight'] ernie.layers.17.mlp.text_fused_moe.experts.down_proj_weight:['ernie.layers.17.mlp.experts.0.down_proj.weight', 'ernie.layers.17.mlp.experts.1.down_proj.weight', 'ernie.layers.17.mlp.experts.2.down_proj.weight', 'ernie.layers.17.mlp.experts.3.down_proj.weight', 'ernie.layers.17.mlp.experts.4.down_proj.weight', 'ernie.layers.17.mlp.experts.5.down_proj.weight', 'ernie.layers.17.mlp.experts.6.down_proj.weight', 'ernie.layers.17.mlp.experts.7.down_proj.weight', 'ernie.layers.17.mlp.experts.8.down_proj.weight', 'ernie.layers.17.mlp.experts.9.down_proj.weight', 'ernie.layers.17.mlp.experts.10.down_proj.weight', 'ernie.layers.17.mlp.experts.11.down_proj.weight', 'ernie.layers.17.mlp.experts.12.down_proj.weight', 'ernie.layers.17.mlp.experts.13.down_proj.weight', 'ernie.layers.17.mlp.experts.14.down_proj.weight', 'ernie.layers.17.mlp.experts.15.down_proj.weight', 'ernie.layers.17.mlp.experts.16.down_proj.weight', 'ernie.layers.17.mlp.experts.17.down_proj.weight', 'ernie.layers.17.mlp.experts.18.down_proj.weight', 'ernie.layers.17.mlp.experts.19.down_proj.weight', 'ernie.layers.17.mlp.experts.20.down_proj.weight', 'ernie.layers.17.mlp.experts.21.down_proj.weight', 'ernie.layers.17.mlp.experts.22.down_proj.weight', 'ernie.layers.17.mlp.experts.23.down_proj.weight', 'ernie.layers.17.mlp.experts.24.down_proj.weight', 'ernie.layers.17.mlp.experts.25.down_proj.weight', 'ernie.layers.17.mlp.experts.26.down_proj.weight', 'ernie.layers.17.mlp.experts.27.down_proj.weight', 'ernie.layers.17.mlp.experts.28.down_proj.weight', 'ernie.layers.17.mlp.experts.29.down_proj.weight', 'ernie.layers.17.mlp.experts.30.down_proj.weight', 'ernie.layers.17.mlp.experts.31.down_proj.weight', 'ernie.layers.17.mlp.experts.64.down_proj.weight', 'ernie.layers.17.mlp.experts.65.down_proj.weight', 'ernie.layers.17.mlp.experts.66.down_proj.weight', 'ernie.layers.17.mlp.experts.67.down_proj.weight', 'ernie.layers.17.mlp.experts.68.down_proj.weight', 'ernie.layers.17.mlp.experts.69.down_proj.weight', 'ernie.layers.17.mlp.experts.70.down_proj.weight', 'ernie.layers.17.mlp.experts.71.down_proj.weight', 'ernie.layers.17.mlp.experts.72.down_proj.weight', 'ernie.layers.17.mlp.experts.73.down_proj.weight', 'ernie.layers.17.mlp.experts.74.down_proj.weight', 'ernie.layers.17.mlp.experts.75.down_proj.weight', 'ernie.layers.17.mlp.experts.76.down_proj.weight', 'ernie.layers.17.mlp.experts.77.down_proj.weight', 'ernie.layers.17.mlp.experts.78.down_proj.weight', 'ernie.layers.17.mlp.experts.79.down_proj.weight', 'ernie.layers.17.mlp.experts.80.down_proj.weight', 'ernie.layers.17.mlp.experts.81.down_proj.weight', 'ernie.layers.17.mlp.experts.82.down_proj.weight', 'ernie.layers.17.mlp.experts.83.down_proj.weight', 'ernie.layers.17.mlp.experts.84.down_proj.weight', 'ernie.layers.17.mlp.experts.85.down_proj.weight', 'ernie.layers.17.mlp.experts.86.down_proj.weight', 'ernie.layers.17.mlp.experts.87.down_proj.weight', 'ernie.layers.17.mlp.experts.88.down_proj.weight', 'ernie.layers.17.mlp.experts.89.down_proj.weight', 'ernie.layers.17.mlp.experts.90.down_proj.weight', 'ernie.layers.17.mlp.experts.91.down_proj.weight', 'ernie.layers.17.mlp.experts.92.down_proj.weight', 'ernie.layers.17.mlp.experts.93.down_proj.weight', 'ernie.layers.17.mlp.experts.94.down_proj.weight', 'ernie.layers.17.mlp.experts.95.down_proj.weight'] ernie.layers.18.mlp.text_fused_moe.gate.weight:ernie.layers.18.mlp.gate.weight -ernie.layers.18.mlp.text_fused_moe.experts.gate_correction_bias:ernie.layers.18.mlp.moe_statics.e_score_correction_bias +ernie.layers.18.mlp.gate_correction_bias:ernie.layers.18.mlp.moe_statics.e_score_correction_bias ernie.layers.18.mlp.text_fused_moe.experts.up_gate_proj_weight:['ernie.layers.18.mlp.experts.0.up_gate_proj.weight', 'ernie.layers.18.mlp.experts.1.up_gate_proj.weight', 'ernie.layers.18.mlp.experts.2.up_gate_proj.weight', 'ernie.layers.18.mlp.experts.3.up_gate_proj.weight', 'ernie.layers.18.mlp.experts.4.up_gate_proj.weight', 'ernie.layers.18.mlp.experts.5.up_gate_proj.weight', 'ernie.layers.18.mlp.experts.6.up_gate_proj.weight', 'ernie.layers.18.mlp.experts.7.up_gate_proj.weight', 'ernie.layers.18.mlp.experts.8.up_gate_proj.weight', 'ernie.layers.18.mlp.experts.9.up_gate_proj.weight', 'ernie.layers.18.mlp.experts.10.up_gate_proj.weight', 'ernie.layers.18.mlp.experts.11.up_gate_proj.weight', 'ernie.layers.18.mlp.experts.12.up_gate_proj.weight', 'ernie.layers.18.mlp.experts.13.up_gate_proj.weight', 'ernie.layers.18.mlp.experts.14.up_gate_proj.weight', 'ernie.layers.18.mlp.experts.15.up_gate_proj.weight', 'ernie.layers.18.mlp.experts.16.up_gate_proj.weight', 'ernie.layers.18.mlp.experts.17.up_gate_proj.weight', 'ernie.layers.18.mlp.experts.18.up_gate_proj.weight', 'ernie.layers.18.mlp.experts.19.up_gate_proj.weight', 'ernie.layers.18.mlp.experts.20.up_gate_proj.weight', 'ernie.layers.18.mlp.experts.21.up_gate_proj.weight', 'ernie.layers.18.mlp.experts.22.up_gate_proj.weight', 'ernie.layers.18.mlp.experts.23.up_gate_proj.weight', 'ernie.layers.18.mlp.experts.24.up_gate_proj.weight', 'ernie.layers.18.mlp.experts.25.up_gate_proj.weight', 'ernie.layers.18.mlp.experts.26.up_gate_proj.weight', 'ernie.layers.18.mlp.experts.27.up_gate_proj.weight', 'ernie.layers.18.mlp.experts.28.up_gate_proj.weight', 'ernie.layers.18.mlp.experts.29.up_gate_proj.weight', 'ernie.layers.18.mlp.experts.30.up_gate_proj.weight', 'ernie.layers.18.mlp.experts.31.up_gate_proj.weight', 'ernie.layers.18.mlp.experts.64.up_gate_proj.weight', 'ernie.layers.18.mlp.experts.65.up_gate_proj.weight', 'ernie.layers.18.mlp.experts.66.up_gate_proj.weight', 'ernie.layers.18.mlp.experts.67.up_gate_proj.weight', 'ernie.layers.18.mlp.experts.68.up_gate_proj.weight', 'ernie.layers.18.mlp.experts.69.up_gate_proj.weight', 'ernie.layers.18.mlp.experts.70.up_gate_proj.weight', 'ernie.layers.18.mlp.experts.71.up_gate_proj.weight', 'ernie.layers.18.mlp.experts.72.up_gate_proj.weight', 'ernie.layers.18.mlp.experts.73.up_gate_proj.weight', 'ernie.layers.18.mlp.experts.74.up_gate_proj.weight', 'ernie.layers.18.mlp.experts.75.up_gate_proj.weight', 'ernie.layers.18.mlp.experts.76.up_gate_proj.weight', 'ernie.layers.18.mlp.experts.77.up_gate_proj.weight', 'ernie.layers.18.mlp.experts.78.up_gate_proj.weight', 'ernie.layers.18.mlp.experts.79.up_gate_proj.weight', 'ernie.layers.18.mlp.experts.80.up_gate_proj.weight', 'ernie.layers.18.mlp.experts.81.up_gate_proj.weight', 'ernie.layers.18.mlp.experts.82.up_gate_proj.weight', 'ernie.layers.18.mlp.experts.83.up_gate_proj.weight', 'ernie.layers.18.mlp.experts.84.up_gate_proj.weight', 'ernie.layers.18.mlp.experts.85.up_gate_proj.weight', 'ernie.layers.18.mlp.experts.86.up_gate_proj.weight', 'ernie.layers.18.mlp.experts.87.up_gate_proj.weight', 'ernie.layers.18.mlp.experts.88.up_gate_proj.weight', 'ernie.layers.18.mlp.experts.89.up_gate_proj.weight', 'ernie.layers.18.mlp.experts.90.up_gate_proj.weight', 'ernie.layers.18.mlp.experts.91.up_gate_proj.weight', 'ernie.layers.18.mlp.experts.92.up_gate_proj.weight', 'ernie.layers.18.mlp.experts.93.up_gate_proj.weight', 'ernie.layers.18.mlp.experts.94.up_gate_proj.weight', 'ernie.layers.18.mlp.experts.95.up_gate_proj.weight'] ernie.layers.18.mlp.text_fused_moe.experts.down_proj_weight:['ernie.layers.18.mlp.experts.0.down_proj.weight', 'ernie.layers.18.mlp.experts.1.down_proj.weight', 'ernie.layers.18.mlp.experts.2.down_proj.weight', 'ernie.layers.18.mlp.experts.3.down_proj.weight', 'ernie.layers.18.mlp.experts.4.down_proj.weight', 'ernie.layers.18.mlp.experts.5.down_proj.weight', 'ernie.layers.18.mlp.experts.6.down_proj.weight', 'ernie.layers.18.mlp.experts.7.down_proj.weight', 'ernie.layers.18.mlp.experts.8.down_proj.weight', 'ernie.layers.18.mlp.experts.9.down_proj.weight', 'ernie.layers.18.mlp.experts.10.down_proj.weight', 'ernie.layers.18.mlp.experts.11.down_proj.weight', 'ernie.layers.18.mlp.experts.12.down_proj.weight', 'ernie.layers.18.mlp.experts.13.down_proj.weight', 'ernie.layers.18.mlp.experts.14.down_proj.weight', 'ernie.layers.18.mlp.experts.15.down_proj.weight', 'ernie.layers.18.mlp.experts.16.down_proj.weight', 'ernie.layers.18.mlp.experts.17.down_proj.weight', 'ernie.layers.18.mlp.experts.18.down_proj.weight', 'ernie.layers.18.mlp.experts.19.down_proj.weight', 'ernie.layers.18.mlp.experts.20.down_proj.weight', 'ernie.layers.18.mlp.experts.21.down_proj.weight', 'ernie.layers.18.mlp.experts.22.down_proj.weight', 'ernie.layers.18.mlp.experts.23.down_proj.weight', 'ernie.layers.18.mlp.experts.24.down_proj.weight', 'ernie.layers.18.mlp.experts.25.down_proj.weight', 'ernie.layers.18.mlp.experts.26.down_proj.weight', 'ernie.layers.18.mlp.experts.27.down_proj.weight', 'ernie.layers.18.mlp.experts.28.down_proj.weight', 'ernie.layers.18.mlp.experts.29.down_proj.weight', 'ernie.layers.18.mlp.experts.30.down_proj.weight', 'ernie.layers.18.mlp.experts.31.down_proj.weight', 'ernie.layers.18.mlp.experts.64.down_proj.weight', 'ernie.layers.18.mlp.experts.65.down_proj.weight', 'ernie.layers.18.mlp.experts.66.down_proj.weight', 'ernie.layers.18.mlp.experts.67.down_proj.weight', 'ernie.layers.18.mlp.experts.68.down_proj.weight', 'ernie.layers.18.mlp.experts.69.down_proj.weight', 'ernie.layers.18.mlp.experts.70.down_proj.weight', 'ernie.layers.18.mlp.experts.71.down_proj.weight', 'ernie.layers.18.mlp.experts.72.down_proj.weight', 'ernie.layers.18.mlp.experts.73.down_proj.weight', 'ernie.layers.18.mlp.experts.74.down_proj.weight', 'ernie.layers.18.mlp.experts.75.down_proj.weight', 'ernie.layers.18.mlp.experts.76.down_proj.weight', 'ernie.layers.18.mlp.experts.77.down_proj.weight', 'ernie.layers.18.mlp.experts.78.down_proj.weight', 'ernie.layers.18.mlp.experts.79.down_proj.weight', 'ernie.layers.18.mlp.experts.80.down_proj.weight', 'ernie.layers.18.mlp.experts.81.down_proj.weight', 'ernie.layers.18.mlp.experts.82.down_proj.weight', 'ernie.layers.18.mlp.experts.83.down_proj.weight', 'ernie.layers.18.mlp.experts.84.down_proj.weight', 'ernie.layers.18.mlp.experts.85.down_proj.weight', 'ernie.layers.18.mlp.experts.86.down_proj.weight', 'ernie.layers.18.mlp.experts.87.down_proj.weight', 'ernie.layers.18.mlp.experts.88.down_proj.weight', 'ernie.layers.18.mlp.experts.89.down_proj.weight', 'ernie.layers.18.mlp.experts.90.down_proj.weight', 'ernie.layers.18.mlp.experts.91.down_proj.weight', 'ernie.layers.18.mlp.experts.92.down_proj.weight', 'ernie.layers.18.mlp.experts.93.down_proj.weight', 'ernie.layers.18.mlp.experts.94.down_proj.weight', 'ernie.layers.18.mlp.experts.95.down_proj.weight'] ernie.layers.19.mlp.text_fused_moe.gate.weight:ernie.layers.19.mlp.gate.weight -ernie.layers.19.mlp.text_fused_moe.experts.gate_correction_bias:ernie.layers.19.mlp.moe_statics.e_score_correction_bias +ernie.layers.19.mlp.gate_correction_bias:ernie.layers.19.mlp.moe_statics.e_score_correction_bias ernie.layers.19.mlp.text_fused_moe.experts.up_gate_proj_weight:['ernie.layers.19.mlp.experts.0.up_gate_proj.weight', 'ernie.layers.19.mlp.experts.1.up_gate_proj.weight', 'ernie.layers.19.mlp.experts.2.up_gate_proj.weight', 'ernie.layers.19.mlp.experts.3.up_gate_proj.weight', 'ernie.layers.19.mlp.experts.4.up_gate_proj.weight', 'ernie.layers.19.mlp.experts.5.up_gate_proj.weight', 'ernie.layers.19.mlp.experts.6.up_gate_proj.weight', 'ernie.layers.19.mlp.experts.7.up_gate_proj.weight', 'ernie.layers.19.mlp.experts.8.up_gate_proj.weight', 'ernie.layers.19.mlp.experts.9.up_gate_proj.weight', 'ernie.layers.19.mlp.experts.10.up_gate_proj.weight', 'ernie.layers.19.mlp.experts.11.up_gate_proj.weight', 'ernie.layers.19.mlp.experts.12.up_gate_proj.weight', 'ernie.layers.19.mlp.experts.13.up_gate_proj.weight', 'ernie.layers.19.mlp.experts.14.up_gate_proj.weight', 'ernie.layers.19.mlp.experts.15.up_gate_proj.weight', 'ernie.layers.19.mlp.experts.16.up_gate_proj.weight', 'ernie.layers.19.mlp.experts.17.up_gate_proj.weight', 'ernie.layers.19.mlp.experts.18.up_gate_proj.weight', 'ernie.layers.19.mlp.experts.19.up_gate_proj.weight', 'ernie.layers.19.mlp.experts.20.up_gate_proj.weight', 'ernie.layers.19.mlp.experts.21.up_gate_proj.weight', 'ernie.layers.19.mlp.experts.22.up_gate_proj.weight', 'ernie.layers.19.mlp.experts.23.up_gate_proj.weight', 'ernie.layers.19.mlp.experts.24.up_gate_proj.weight', 'ernie.layers.19.mlp.experts.25.up_gate_proj.weight', 'ernie.layers.19.mlp.experts.26.up_gate_proj.weight', 'ernie.layers.19.mlp.experts.27.up_gate_proj.weight', 'ernie.layers.19.mlp.experts.28.up_gate_proj.weight', 'ernie.layers.19.mlp.experts.29.up_gate_proj.weight', 'ernie.layers.19.mlp.experts.30.up_gate_proj.weight', 'ernie.layers.19.mlp.experts.31.up_gate_proj.weight', 'ernie.layers.19.mlp.experts.64.up_gate_proj.weight', 'ernie.layers.19.mlp.experts.65.up_gate_proj.weight', 'ernie.layers.19.mlp.experts.66.up_gate_proj.weight', 'ernie.layers.19.mlp.experts.67.up_gate_proj.weight', 'ernie.layers.19.mlp.experts.68.up_gate_proj.weight', 'ernie.layers.19.mlp.experts.69.up_gate_proj.weight', 'ernie.layers.19.mlp.experts.70.up_gate_proj.weight', 'ernie.layers.19.mlp.experts.71.up_gate_proj.weight', 'ernie.layers.19.mlp.experts.72.up_gate_proj.weight', 'ernie.layers.19.mlp.experts.73.up_gate_proj.weight', 'ernie.layers.19.mlp.experts.74.up_gate_proj.weight', 'ernie.layers.19.mlp.experts.75.up_gate_proj.weight', 'ernie.layers.19.mlp.experts.76.up_gate_proj.weight', 'ernie.layers.19.mlp.experts.77.up_gate_proj.weight', 'ernie.layers.19.mlp.experts.78.up_gate_proj.weight', 'ernie.layers.19.mlp.experts.79.up_gate_proj.weight', 'ernie.layers.19.mlp.experts.80.up_gate_proj.weight', 'ernie.layers.19.mlp.experts.81.up_gate_proj.weight', 'ernie.layers.19.mlp.experts.82.up_gate_proj.weight', 'ernie.layers.19.mlp.experts.83.up_gate_proj.weight', 'ernie.layers.19.mlp.experts.84.up_gate_proj.weight', 'ernie.layers.19.mlp.experts.85.up_gate_proj.weight', 'ernie.layers.19.mlp.experts.86.up_gate_proj.weight', 'ernie.layers.19.mlp.experts.87.up_gate_proj.weight', 'ernie.layers.19.mlp.experts.88.up_gate_proj.weight', 'ernie.layers.19.mlp.experts.89.up_gate_proj.weight', 'ernie.layers.19.mlp.experts.90.up_gate_proj.weight', 'ernie.layers.19.mlp.experts.91.up_gate_proj.weight', 'ernie.layers.19.mlp.experts.92.up_gate_proj.weight', 'ernie.layers.19.mlp.experts.93.up_gate_proj.weight', 'ernie.layers.19.mlp.experts.94.up_gate_proj.weight', 'ernie.layers.19.mlp.experts.95.up_gate_proj.weight'] ernie.layers.19.mlp.text_fused_moe.experts.down_proj_weight:['ernie.layers.19.mlp.experts.0.down_proj.weight', 'ernie.layers.19.mlp.experts.1.down_proj.weight', 'ernie.layers.19.mlp.experts.2.down_proj.weight', 'ernie.layers.19.mlp.experts.3.down_proj.weight', 'ernie.layers.19.mlp.experts.4.down_proj.weight', 'ernie.layers.19.mlp.experts.5.down_proj.weight', 'ernie.layers.19.mlp.experts.6.down_proj.weight', 'ernie.layers.19.mlp.experts.7.down_proj.weight', 'ernie.layers.19.mlp.experts.8.down_proj.weight', 'ernie.layers.19.mlp.experts.9.down_proj.weight', 'ernie.layers.19.mlp.experts.10.down_proj.weight', 'ernie.layers.19.mlp.experts.11.down_proj.weight', 'ernie.layers.19.mlp.experts.12.down_proj.weight', 'ernie.layers.19.mlp.experts.13.down_proj.weight', 'ernie.layers.19.mlp.experts.14.down_proj.weight', 'ernie.layers.19.mlp.experts.15.down_proj.weight', 'ernie.layers.19.mlp.experts.16.down_proj.weight', 'ernie.layers.19.mlp.experts.17.down_proj.weight', 'ernie.layers.19.mlp.experts.18.down_proj.weight', 'ernie.layers.19.mlp.experts.19.down_proj.weight', 'ernie.layers.19.mlp.experts.20.down_proj.weight', 'ernie.layers.19.mlp.experts.21.down_proj.weight', 'ernie.layers.19.mlp.experts.22.down_proj.weight', 'ernie.layers.19.mlp.experts.23.down_proj.weight', 'ernie.layers.19.mlp.experts.24.down_proj.weight', 'ernie.layers.19.mlp.experts.25.down_proj.weight', 'ernie.layers.19.mlp.experts.26.down_proj.weight', 'ernie.layers.19.mlp.experts.27.down_proj.weight', 'ernie.layers.19.mlp.experts.28.down_proj.weight', 'ernie.layers.19.mlp.experts.29.down_proj.weight', 'ernie.layers.19.mlp.experts.30.down_proj.weight', 'ernie.layers.19.mlp.experts.31.down_proj.weight', 'ernie.layers.19.mlp.experts.64.down_proj.weight', 'ernie.layers.19.mlp.experts.65.down_proj.weight', 'ernie.layers.19.mlp.experts.66.down_proj.weight', 'ernie.layers.19.mlp.experts.67.down_proj.weight', 'ernie.layers.19.mlp.experts.68.down_proj.weight', 'ernie.layers.19.mlp.experts.69.down_proj.weight', 'ernie.layers.19.mlp.experts.70.down_proj.weight', 'ernie.layers.19.mlp.experts.71.down_proj.weight', 'ernie.layers.19.mlp.experts.72.down_proj.weight', 'ernie.layers.19.mlp.experts.73.down_proj.weight', 'ernie.layers.19.mlp.experts.74.down_proj.weight', 'ernie.layers.19.mlp.experts.75.down_proj.weight', 'ernie.layers.19.mlp.experts.76.down_proj.weight', 'ernie.layers.19.mlp.experts.77.down_proj.weight', 'ernie.layers.19.mlp.experts.78.down_proj.weight', 'ernie.layers.19.mlp.experts.79.down_proj.weight', 'ernie.layers.19.mlp.experts.80.down_proj.weight', 'ernie.layers.19.mlp.experts.81.down_proj.weight', 'ernie.layers.19.mlp.experts.82.down_proj.weight', 'ernie.layers.19.mlp.experts.83.down_proj.weight', 'ernie.layers.19.mlp.experts.84.down_proj.weight', 'ernie.layers.19.mlp.experts.85.down_proj.weight', 'ernie.layers.19.mlp.experts.86.down_proj.weight', 'ernie.layers.19.mlp.experts.87.down_proj.weight', 'ernie.layers.19.mlp.experts.88.down_proj.weight', 'ernie.layers.19.mlp.experts.89.down_proj.weight', 'ernie.layers.19.mlp.experts.90.down_proj.weight', 'ernie.layers.19.mlp.experts.91.down_proj.weight', 'ernie.layers.19.mlp.experts.92.down_proj.weight', 'ernie.layers.19.mlp.experts.93.down_proj.weight', 'ernie.layers.19.mlp.experts.94.down_proj.weight', 'ernie.layers.19.mlp.experts.95.down_proj.weight'] ernie.layers.20.mlp.text_fused_moe.gate.weight:ernie.layers.20.mlp.gate.weight -ernie.layers.20.mlp.text_fused_moe.experts.gate_correction_bias:ernie.layers.20.mlp.moe_statics.e_score_correction_bias +ernie.layers.20.mlp.gate_correction_bias:ernie.layers.20.mlp.moe_statics.e_score_correction_bias ernie.layers.20.mlp.text_fused_moe.experts.up_gate_proj_weight:['ernie.layers.20.mlp.experts.0.up_gate_proj.weight', 'ernie.layers.20.mlp.experts.1.up_gate_proj.weight', 'ernie.layers.20.mlp.experts.2.up_gate_proj.weight', 'ernie.layers.20.mlp.experts.3.up_gate_proj.weight', 'ernie.layers.20.mlp.experts.4.up_gate_proj.weight', 'ernie.layers.20.mlp.experts.5.up_gate_proj.weight', 'ernie.layers.20.mlp.experts.6.up_gate_proj.weight', 'ernie.layers.20.mlp.experts.7.up_gate_proj.weight', 'ernie.layers.20.mlp.experts.8.up_gate_proj.weight', 'ernie.layers.20.mlp.experts.9.up_gate_proj.weight', 'ernie.layers.20.mlp.experts.10.up_gate_proj.weight', 'ernie.layers.20.mlp.experts.11.up_gate_proj.weight', 'ernie.layers.20.mlp.experts.12.up_gate_proj.weight', 'ernie.layers.20.mlp.experts.13.up_gate_proj.weight', 'ernie.layers.20.mlp.experts.14.up_gate_proj.weight', 'ernie.layers.20.mlp.experts.15.up_gate_proj.weight', 'ernie.layers.20.mlp.experts.16.up_gate_proj.weight', 'ernie.layers.20.mlp.experts.17.up_gate_proj.weight', 'ernie.layers.20.mlp.experts.18.up_gate_proj.weight', 'ernie.layers.20.mlp.experts.19.up_gate_proj.weight', 'ernie.layers.20.mlp.experts.20.up_gate_proj.weight', 'ernie.layers.20.mlp.experts.21.up_gate_proj.weight', 'ernie.layers.20.mlp.experts.22.up_gate_proj.weight', 'ernie.layers.20.mlp.experts.23.up_gate_proj.weight', 'ernie.layers.20.mlp.experts.24.up_gate_proj.weight', 'ernie.layers.20.mlp.experts.25.up_gate_proj.weight', 'ernie.layers.20.mlp.experts.26.up_gate_proj.weight', 'ernie.layers.20.mlp.experts.27.up_gate_proj.weight', 'ernie.layers.20.mlp.experts.28.up_gate_proj.weight', 'ernie.layers.20.mlp.experts.29.up_gate_proj.weight', 'ernie.layers.20.mlp.experts.30.up_gate_proj.weight', 'ernie.layers.20.mlp.experts.31.up_gate_proj.weight', 'ernie.layers.20.mlp.experts.64.up_gate_proj.weight', 'ernie.layers.20.mlp.experts.65.up_gate_proj.weight', 'ernie.layers.20.mlp.experts.66.up_gate_proj.weight', 'ernie.layers.20.mlp.experts.67.up_gate_proj.weight', 'ernie.layers.20.mlp.experts.68.up_gate_proj.weight', 'ernie.layers.20.mlp.experts.69.up_gate_proj.weight', 'ernie.layers.20.mlp.experts.70.up_gate_proj.weight', 'ernie.layers.20.mlp.experts.71.up_gate_proj.weight', 'ernie.layers.20.mlp.experts.72.up_gate_proj.weight', 'ernie.layers.20.mlp.experts.73.up_gate_proj.weight', 'ernie.layers.20.mlp.experts.74.up_gate_proj.weight', 'ernie.layers.20.mlp.experts.75.up_gate_proj.weight', 'ernie.layers.20.mlp.experts.76.up_gate_proj.weight', 'ernie.layers.20.mlp.experts.77.up_gate_proj.weight', 'ernie.layers.20.mlp.experts.78.up_gate_proj.weight', 'ernie.layers.20.mlp.experts.79.up_gate_proj.weight', 'ernie.layers.20.mlp.experts.80.up_gate_proj.weight', 'ernie.layers.20.mlp.experts.81.up_gate_proj.weight', 'ernie.layers.20.mlp.experts.82.up_gate_proj.weight', 'ernie.layers.20.mlp.experts.83.up_gate_proj.weight', 'ernie.layers.20.mlp.experts.84.up_gate_proj.weight', 'ernie.layers.20.mlp.experts.85.up_gate_proj.weight', 'ernie.layers.20.mlp.experts.86.up_gate_proj.weight', 'ernie.layers.20.mlp.experts.87.up_gate_proj.weight', 'ernie.layers.20.mlp.experts.88.up_gate_proj.weight', 'ernie.layers.20.mlp.experts.89.up_gate_proj.weight', 'ernie.layers.20.mlp.experts.90.up_gate_proj.weight', 'ernie.layers.20.mlp.experts.91.up_gate_proj.weight', 'ernie.layers.20.mlp.experts.92.up_gate_proj.weight', 'ernie.layers.20.mlp.experts.93.up_gate_proj.weight', 'ernie.layers.20.mlp.experts.94.up_gate_proj.weight', 'ernie.layers.20.mlp.experts.95.up_gate_proj.weight'] ernie.layers.20.mlp.text_fused_moe.experts.down_proj_weight:['ernie.layers.20.mlp.experts.0.down_proj.weight', 'ernie.layers.20.mlp.experts.1.down_proj.weight', 'ernie.layers.20.mlp.experts.2.down_proj.weight', 'ernie.layers.20.mlp.experts.3.down_proj.weight', 'ernie.layers.20.mlp.experts.4.down_proj.weight', 'ernie.layers.20.mlp.experts.5.down_proj.weight', 'ernie.layers.20.mlp.experts.6.down_proj.weight', 'ernie.layers.20.mlp.experts.7.down_proj.weight', 'ernie.layers.20.mlp.experts.8.down_proj.weight', 'ernie.layers.20.mlp.experts.9.down_proj.weight', 'ernie.layers.20.mlp.experts.10.down_proj.weight', 'ernie.layers.20.mlp.experts.11.down_proj.weight', 'ernie.layers.20.mlp.experts.12.down_proj.weight', 'ernie.layers.20.mlp.experts.13.down_proj.weight', 'ernie.layers.20.mlp.experts.14.down_proj.weight', 'ernie.layers.20.mlp.experts.15.down_proj.weight', 'ernie.layers.20.mlp.experts.16.down_proj.weight', 'ernie.layers.20.mlp.experts.17.down_proj.weight', 'ernie.layers.20.mlp.experts.18.down_proj.weight', 'ernie.layers.20.mlp.experts.19.down_proj.weight', 'ernie.layers.20.mlp.experts.20.down_proj.weight', 'ernie.layers.20.mlp.experts.21.down_proj.weight', 'ernie.layers.20.mlp.experts.22.down_proj.weight', 'ernie.layers.20.mlp.experts.23.down_proj.weight', 'ernie.layers.20.mlp.experts.24.down_proj.weight', 'ernie.layers.20.mlp.experts.25.down_proj.weight', 'ernie.layers.20.mlp.experts.26.down_proj.weight', 'ernie.layers.20.mlp.experts.27.down_proj.weight', 'ernie.layers.20.mlp.experts.28.down_proj.weight', 'ernie.layers.20.mlp.experts.29.down_proj.weight', 'ernie.layers.20.mlp.experts.30.down_proj.weight', 'ernie.layers.20.mlp.experts.31.down_proj.weight', 'ernie.layers.20.mlp.experts.64.down_proj.weight', 'ernie.layers.20.mlp.experts.65.down_proj.weight', 'ernie.layers.20.mlp.experts.66.down_proj.weight', 'ernie.layers.20.mlp.experts.67.down_proj.weight', 'ernie.layers.20.mlp.experts.68.down_proj.weight', 'ernie.layers.20.mlp.experts.69.down_proj.weight', 'ernie.layers.20.mlp.experts.70.down_proj.weight', 'ernie.layers.20.mlp.experts.71.down_proj.weight', 'ernie.layers.20.mlp.experts.72.down_proj.weight', 'ernie.layers.20.mlp.experts.73.down_proj.weight', 'ernie.layers.20.mlp.experts.74.down_proj.weight', 'ernie.layers.20.mlp.experts.75.down_proj.weight', 'ernie.layers.20.mlp.experts.76.down_proj.weight', 'ernie.layers.20.mlp.experts.77.down_proj.weight', 'ernie.layers.20.mlp.experts.78.down_proj.weight', 'ernie.layers.20.mlp.experts.79.down_proj.weight', 'ernie.layers.20.mlp.experts.80.down_proj.weight', 'ernie.layers.20.mlp.experts.81.down_proj.weight', 'ernie.layers.20.mlp.experts.82.down_proj.weight', 'ernie.layers.20.mlp.experts.83.down_proj.weight', 'ernie.layers.20.mlp.experts.84.down_proj.weight', 'ernie.layers.20.mlp.experts.85.down_proj.weight', 'ernie.layers.20.mlp.experts.86.down_proj.weight', 'ernie.layers.20.mlp.experts.87.down_proj.weight', 'ernie.layers.20.mlp.experts.88.down_proj.weight', 'ernie.layers.20.mlp.experts.89.down_proj.weight', 'ernie.layers.20.mlp.experts.90.down_proj.weight', 'ernie.layers.20.mlp.experts.91.down_proj.weight', 'ernie.layers.20.mlp.experts.92.down_proj.weight', 'ernie.layers.20.mlp.experts.93.down_proj.weight', 'ernie.layers.20.mlp.experts.94.down_proj.weight', 'ernie.layers.20.mlp.experts.95.down_proj.weight'] ernie.layers.21.mlp.text_fused_moe.gate.weight:ernie.layers.21.mlp.gate.weight -ernie.layers.21.mlp.text_fused_moe.experts.gate_correction_bias:ernie.layers.21.mlp.moe_statics.e_score_correction_bias +ernie.layers.21.mlp.gate_correction_bias:ernie.layers.21.mlp.moe_statics.e_score_correction_bias ernie.layers.21.mlp.text_fused_moe.experts.up_gate_proj_weight:['ernie.layers.21.mlp.experts.0.up_gate_proj.weight', 'ernie.layers.21.mlp.experts.1.up_gate_proj.weight', 'ernie.layers.21.mlp.experts.2.up_gate_proj.weight', 'ernie.layers.21.mlp.experts.3.up_gate_proj.weight', 'ernie.layers.21.mlp.experts.4.up_gate_proj.weight', 'ernie.layers.21.mlp.experts.5.up_gate_proj.weight', 'ernie.layers.21.mlp.experts.6.up_gate_proj.weight', 'ernie.layers.21.mlp.experts.7.up_gate_proj.weight', 'ernie.layers.21.mlp.experts.8.up_gate_proj.weight', 'ernie.layers.21.mlp.experts.9.up_gate_proj.weight', 'ernie.layers.21.mlp.experts.10.up_gate_proj.weight', 'ernie.layers.21.mlp.experts.11.up_gate_proj.weight', 'ernie.layers.21.mlp.experts.12.up_gate_proj.weight', 'ernie.layers.21.mlp.experts.13.up_gate_proj.weight', 'ernie.layers.21.mlp.experts.14.up_gate_proj.weight', 'ernie.layers.21.mlp.experts.15.up_gate_proj.weight', 'ernie.layers.21.mlp.experts.16.up_gate_proj.weight', 'ernie.layers.21.mlp.experts.17.up_gate_proj.weight', 'ernie.layers.21.mlp.experts.18.up_gate_proj.weight', 'ernie.layers.21.mlp.experts.19.up_gate_proj.weight', 'ernie.layers.21.mlp.experts.20.up_gate_proj.weight', 'ernie.layers.21.mlp.experts.21.up_gate_proj.weight', 'ernie.layers.21.mlp.experts.22.up_gate_proj.weight', 'ernie.layers.21.mlp.experts.23.up_gate_proj.weight', 'ernie.layers.21.mlp.experts.24.up_gate_proj.weight', 'ernie.layers.21.mlp.experts.25.up_gate_proj.weight', 'ernie.layers.21.mlp.experts.26.up_gate_proj.weight', 'ernie.layers.21.mlp.experts.27.up_gate_proj.weight', 'ernie.layers.21.mlp.experts.28.up_gate_proj.weight', 'ernie.layers.21.mlp.experts.29.up_gate_proj.weight', 'ernie.layers.21.mlp.experts.30.up_gate_proj.weight', 'ernie.layers.21.mlp.experts.31.up_gate_proj.weight', 'ernie.layers.21.mlp.experts.64.up_gate_proj.weight', 'ernie.layers.21.mlp.experts.65.up_gate_proj.weight', 'ernie.layers.21.mlp.experts.66.up_gate_proj.weight', 'ernie.layers.21.mlp.experts.67.up_gate_proj.weight', 'ernie.layers.21.mlp.experts.68.up_gate_proj.weight', 'ernie.layers.21.mlp.experts.69.up_gate_proj.weight', 'ernie.layers.21.mlp.experts.70.up_gate_proj.weight', 'ernie.layers.21.mlp.experts.71.up_gate_proj.weight', 'ernie.layers.21.mlp.experts.72.up_gate_proj.weight', 'ernie.layers.21.mlp.experts.73.up_gate_proj.weight', 'ernie.layers.21.mlp.experts.74.up_gate_proj.weight', 'ernie.layers.21.mlp.experts.75.up_gate_proj.weight', 'ernie.layers.21.mlp.experts.76.up_gate_proj.weight', 'ernie.layers.21.mlp.experts.77.up_gate_proj.weight', 'ernie.layers.21.mlp.experts.78.up_gate_proj.weight', 'ernie.layers.21.mlp.experts.79.up_gate_proj.weight', 'ernie.layers.21.mlp.experts.80.up_gate_proj.weight', 'ernie.layers.21.mlp.experts.81.up_gate_proj.weight', 'ernie.layers.21.mlp.experts.82.up_gate_proj.weight', 'ernie.layers.21.mlp.experts.83.up_gate_proj.weight', 'ernie.layers.21.mlp.experts.84.up_gate_proj.weight', 'ernie.layers.21.mlp.experts.85.up_gate_proj.weight', 'ernie.layers.21.mlp.experts.86.up_gate_proj.weight', 'ernie.layers.21.mlp.experts.87.up_gate_proj.weight', 'ernie.layers.21.mlp.experts.88.up_gate_proj.weight', 'ernie.layers.21.mlp.experts.89.up_gate_proj.weight', 'ernie.layers.21.mlp.experts.90.up_gate_proj.weight', 'ernie.layers.21.mlp.experts.91.up_gate_proj.weight', 'ernie.layers.21.mlp.experts.92.up_gate_proj.weight', 'ernie.layers.21.mlp.experts.93.up_gate_proj.weight', 'ernie.layers.21.mlp.experts.94.up_gate_proj.weight', 'ernie.layers.21.mlp.experts.95.up_gate_proj.weight'] ernie.layers.21.mlp.text_fused_moe.experts.down_proj_weight:['ernie.layers.21.mlp.experts.0.down_proj.weight', 'ernie.layers.21.mlp.experts.1.down_proj.weight', 'ernie.layers.21.mlp.experts.2.down_proj.weight', 'ernie.layers.21.mlp.experts.3.down_proj.weight', 'ernie.layers.21.mlp.experts.4.down_proj.weight', 'ernie.layers.21.mlp.experts.5.down_proj.weight', 'ernie.layers.21.mlp.experts.6.down_proj.weight', 'ernie.layers.21.mlp.experts.7.down_proj.weight', 'ernie.layers.21.mlp.experts.8.down_proj.weight', 'ernie.layers.21.mlp.experts.9.down_proj.weight', 'ernie.layers.21.mlp.experts.10.down_proj.weight', 'ernie.layers.21.mlp.experts.11.down_proj.weight', 'ernie.layers.21.mlp.experts.12.down_proj.weight', 'ernie.layers.21.mlp.experts.13.down_proj.weight', 'ernie.layers.21.mlp.experts.14.down_proj.weight', 'ernie.layers.21.mlp.experts.15.down_proj.weight', 'ernie.layers.21.mlp.experts.16.down_proj.weight', 'ernie.layers.21.mlp.experts.17.down_proj.weight', 'ernie.layers.21.mlp.experts.18.down_proj.weight', 'ernie.layers.21.mlp.experts.19.down_proj.weight', 'ernie.layers.21.mlp.experts.20.down_proj.weight', 'ernie.layers.21.mlp.experts.21.down_proj.weight', 'ernie.layers.21.mlp.experts.22.down_proj.weight', 'ernie.layers.21.mlp.experts.23.down_proj.weight', 'ernie.layers.21.mlp.experts.24.down_proj.weight', 'ernie.layers.21.mlp.experts.25.down_proj.weight', 'ernie.layers.21.mlp.experts.26.down_proj.weight', 'ernie.layers.21.mlp.experts.27.down_proj.weight', 'ernie.layers.21.mlp.experts.28.down_proj.weight', 'ernie.layers.21.mlp.experts.29.down_proj.weight', 'ernie.layers.21.mlp.experts.30.down_proj.weight', 'ernie.layers.21.mlp.experts.31.down_proj.weight', 'ernie.layers.21.mlp.experts.64.down_proj.weight', 'ernie.layers.21.mlp.experts.65.down_proj.weight', 'ernie.layers.21.mlp.experts.66.down_proj.weight', 'ernie.layers.21.mlp.experts.67.down_proj.weight', 'ernie.layers.21.mlp.experts.68.down_proj.weight', 'ernie.layers.21.mlp.experts.69.down_proj.weight', 'ernie.layers.21.mlp.experts.70.down_proj.weight', 'ernie.layers.21.mlp.experts.71.down_proj.weight', 'ernie.layers.21.mlp.experts.72.down_proj.weight', 'ernie.layers.21.mlp.experts.73.down_proj.weight', 'ernie.layers.21.mlp.experts.74.down_proj.weight', 'ernie.layers.21.mlp.experts.75.down_proj.weight', 'ernie.layers.21.mlp.experts.76.down_proj.weight', 'ernie.layers.21.mlp.experts.77.down_proj.weight', 'ernie.layers.21.mlp.experts.78.down_proj.weight', 'ernie.layers.21.mlp.experts.79.down_proj.weight', 'ernie.layers.21.mlp.experts.80.down_proj.weight', 'ernie.layers.21.mlp.experts.81.down_proj.weight', 'ernie.layers.21.mlp.experts.82.down_proj.weight', 'ernie.layers.21.mlp.experts.83.down_proj.weight', 'ernie.layers.21.mlp.experts.84.down_proj.weight', 'ernie.layers.21.mlp.experts.85.down_proj.weight', 'ernie.layers.21.mlp.experts.86.down_proj.weight', 'ernie.layers.21.mlp.experts.87.down_proj.weight', 'ernie.layers.21.mlp.experts.88.down_proj.weight', 'ernie.layers.21.mlp.experts.89.down_proj.weight', 'ernie.layers.21.mlp.experts.90.down_proj.weight', 'ernie.layers.21.mlp.experts.91.down_proj.weight', 'ernie.layers.21.mlp.experts.92.down_proj.weight', 'ernie.layers.21.mlp.experts.93.down_proj.weight', 'ernie.layers.21.mlp.experts.94.down_proj.weight', 'ernie.layers.21.mlp.experts.95.down_proj.weight'] ernie.layers.22.mlp.text_fused_moe.gate.weight:ernie.layers.22.mlp.gate.weight -ernie.layers.22.mlp.text_fused_moe.experts.gate_correction_bias:ernie.layers.22.mlp.moe_statics.e_score_correction_bias +ernie.layers.22.mlp.gate_correction_bias:ernie.layers.22.mlp.moe_statics.e_score_correction_bias ernie.layers.22.mlp.text_fused_moe.experts.up_gate_proj_weight:['ernie.layers.22.mlp.experts.0.up_gate_proj.weight', 'ernie.layers.22.mlp.experts.1.up_gate_proj.weight', 'ernie.layers.22.mlp.experts.2.up_gate_proj.weight', 'ernie.layers.22.mlp.experts.3.up_gate_proj.weight', 'ernie.layers.22.mlp.experts.4.up_gate_proj.weight', 'ernie.layers.22.mlp.experts.5.up_gate_proj.weight', 'ernie.layers.22.mlp.experts.6.up_gate_proj.weight', 'ernie.layers.22.mlp.experts.7.up_gate_proj.weight', 'ernie.layers.22.mlp.experts.8.up_gate_proj.weight', 'ernie.layers.22.mlp.experts.9.up_gate_proj.weight', 'ernie.layers.22.mlp.experts.10.up_gate_proj.weight', 'ernie.layers.22.mlp.experts.11.up_gate_proj.weight', 'ernie.layers.22.mlp.experts.12.up_gate_proj.weight', 'ernie.layers.22.mlp.experts.13.up_gate_proj.weight', 'ernie.layers.22.mlp.experts.14.up_gate_proj.weight', 'ernie.layers.22.mlp.experts.15.up_gate_proj.weight', 'ernie.layers.22.mlp.experts.16.up_gate_proj.weight', 'ernie.layers.22.mlp.experts.17.up_gate_proj.weight', 'ernie.layers.22.mlp.experts.18.up_gate_proj.weight', 'ernie.layers.22.mlp.experts.19.up_gate_proj.weight', 'ernie.layers.22.mlp.experts.20.up_gate_proj.weight', 'ernie.layers.22.mlp.experts.21.up_gate_proj.weight', 'ernie.layers.22.mlp.experts.22.up_gate_proj.weight', 'ernie.layers.22.mlp.experts.23.up_gate_proj.weight', 'ernie.layers.22.mlp.experts.24.up_gate_proj.weight', 'ernie.layers.22.mlp.experts.25.up_gate_proj.weight', 'ernie.layers.22.mlp.experts.26.up_gate_proj.weight', 'ernie.layers.22.mlp.experts.27.up_gate_proj.weight', 'ernie.layers.22.mlp.experts.28.up_gate_proj.weight', 'ernie.layers.22.mlp.experts.29.up_gate_proj.weight', 'ernie.layers.22.mlp.experts.30.up_gate_proj.weight', 'ernie.layers.22.mlp.experts.31.up_gate_proj.weight', 'ernie.layers.22.mlp.experts.64.up_gate_proj.weight', 'ernie.layers.22.mlp.experts.65.up_gate_proj.weight', 'ernie.layers.22.mlp.experts.66.up_gate_proj.weight', 'ernie.layers.22.mlp.experts.67.up_gate_proj.weight', 'ernie.layers.22.mlp.experts.68.up_gate_proj.weight', 'ernie.layers.22.mlp.experts.69.up_gate_proj.weight', 'ernie.layers.22.mlp.experts.70.up_gate_proj.weight', 'ernie.layers.22.mlp.experts.71.up_gate_proj.weight', 'ernie.layers.22.mlp.experts.72.up_gate_proj.weight', 'ernie.layers.22.mlp.experts.73.up_gate_proj.weight', 'ernie.layers.22.mlp.experts.74.up_gate_proj.weight', 'ernie.layers.22.mlp.experts.75.up_gate_proj.weight', 'ernie.layers.22.mlp.experts.76.up_gate_proj.weight', 'ernie.layers.22.mlp.experts.77.up_gate_proj.weight', 'ernie.layers.22.mlp.experts.78.up_gate_proj.weight', 'ernie.layers.22.mlp.experts.79.up_gate_proj.weight', 'ernie.layers.22.mlp.experts.80.up_gate_proj.weight', 'ernie.layers.22.mlp.experts.81.up_gate_proj.weight', 'ernie.layers.22.mlp.experts.82.up_gate_proj.weight', 'ernie.layers.22.mlp.experts.83.up_gate_proj.weight', 'ernie.layers.22.mlp.experts.84.up_gate_proj.weight', 'ernie.layers.22.mlp.experts.85.up_gate_proj.weight', 'ernie.layers.22.mlp.experts.86.up_gate_proj.weight', 'ernie.layers.22.mlp.experts.87.up_gate_proj.weight', 'ernie.layers.22.mlp.experts.88.up_gate_proj.weight', 'ernie.layers.22.mlp.experts.89.up_gate_proj.weight', 'ernie.layers.22.mlp.experts.90.up_gate_proj.weight', 'ernie.layers.22.mlp.experts.91.up_gate_proj.weight', 'ernie.layers.22.mlp.experts.92.up_gate_proj.weight', 'ernie.layers.22.mlp.experts.93.up_gate_proj.weight', 'ernie.layers.22.mlp.experts.94.up_gate_proj.weight', 'ernie.layers.22.mlp.experts.95.up_gate_proj.weight'] ernie.layers.22.mlp.text_fused_moe.experts.down_proj_weight:['ernie.layers.22.mlp.experts.0.down_proj.weight', 'ernie.layers.22.mlp.experts.1.down_proj.weight', 'ernie.layers.22.mlp.experts.2.down_proj.weight', 'ernie.layers.22.mlp.experts.3.down_proj.weight', 'ernie.layers.22.mlp.experts.4.down_proj.weight', 'ernie.layers.22.mlp.experts.5.down_proj.weight', 'ernie.layers.22.mlp.experts.6.down_proj.weight', 'ernie.layers.22.mlp.experts.7.down_proj.weight', 'ernie.layers.22.mlp.experts.8.down_proj.weight', 'ernie.layers.22.mlp.experts.9.down_proj.weight', 'ernie.layers.22.mlp.experts.10.down_proj.weight', 'ernie.layers.22.mlp.experts.11.down_proj.weight', 'ernie.layers.22.mlp.experts.12.down_proj.weight', 'ernie.layers.22.mlp.experts.13.down_proj.weight', 'ernie.layers.22.mlp.experts.14.down_proj.weight', 'ernie.layers.22.mlp.experts.15.down_proj.weight', 'ernie.layers.22.mlp.experts.16.down_proj.weight', 'ernie.layers.22.mlp.experts.17.down_proj.weight', 'ernie.layers.22.mlp.experts.18.down_proj.weight', 'ernie.layers.22.mlp.experts.19.down_proj.weight', 'ernie.layers.22.mlp.experts.20.down_proj.weight', 'ernie.layers.22.mlp.experts.21.down_proj.weight', 'ernie.layers.22.mlp.experts.22.down_proj.weight', 'ernie.layers.22.mlp.experts.23.down_proj.weight', 'ernie.layers.22.mlp.experts.24.down_proj.weight', 'ernie.layers.22.mlp.experts.25.down_proj.weight', 'ernie.layers.22.mlp.experts.26.down_proj.weight', 'ernie.layers.22.mlp.experts.27.down_proj.weight', 'ernie.layers.22.mlp.experts.28.down_proj.weight', 'ernie.layers.22.mlp.experts.29.down_proj.weight', 'ernie.layers.22.mlp.experts.30.down_proj.weight', 'ernie.layers.22.mlp.experts.31.down_proj.weight', 'ernie.layers.22.mlp.experts.64.down_proj.weight', 'ernie.layers.22.mlp.experts.65.down_proj.weight', 'ernie.layers.22.mlp.experts.66.down_proj.weight', 'ernie.layers.22.mlp.experts.67.down_proj.weight', 'ernie.layers.22.mlp.experts.68.down_proj.weight', 'ernie.layers.22.mlp.experts.69.down_proj.weight', 'ernie.layers.22.mlp.experts.70.down_proj.weight', 'ernie.layers.22.mlp.experts.71.down_proj.weight', 'ernie.layers.22.mlp.experts.72.down_proj.weight', 'ernie.layers.22.mlp.experts.73.down_proj.weight', 'ernie.layers.22.mlp.experts.74.down_proj.weight', 'ernie.layers.22.mlp.experts.75.down_proj.weight', 'ernie.layers.22.mlp.experts.76.down_proj.weight', 'ernie.layers.22.mlp.experts.77.down_proj.weight', 'ernie.layers.22.mlp.experts.78.down_proj.weight', 'ernie.layers.22.mlp.experts.79.down_proj.weight', 'ernie.layers.22.mlp.experts.80.down_proj.weight', 'ernie.layers.22.mlp.experts.81.down_proj.weight', 'ernie.layers.22.mlp.experts.82.down_proj.weight', 'ernie.layers.22.mlp.experts.83.down_proj.weight', 'ernie.layers.22.mlp.experts.84.down_proj.weight', 'ernie.layers.22.mlp.experts.85.down_proj.weight', 'ernie.layers.22.mlp.experts.86.down_proj.weight', 'ernie.layers.22.mlp.experts.87.down_proj.weight', 'ernie.layers.22.mlp.experts.88.down_proj.weight', 'ernie.layers.22.mlp.experts.89.down_proj.weight', 'ernie.layers.22.mlp.experts.90.down_proj.weight', 'ernie.layers.22.mlp.experts.91.down_proj.weight', 'ernie.layers.22.mlp.experts.92.down_proj.weight', 'ernie.layers.22.mlp.experts.93.down_proj.weight', 'ernie.layers.22.mlp.experts.94.down_proj.weight', 'ernie.layers.22.mlp.experts.95.down_proj.weight'] ernie.layers.23.mlp.text_fused_moe.gate.weight:ernie.layers.23.mlp.gate.weight -ernie.layers.23.mlp.text_fused_moe.experts.gate_correction_bias:ernie.layers.23.mlp.moe_statics.e_score_correction_bias +ernie.layers.23.mlp.gate_correction_bias:ernie.layers.23.mlp.moe_statics.e_score_correction_bias ernie.layers.23.mlp.text_fused_moe.experts.up_gate_proj_weight:['ernie.layers.23.mlp.experts.0.up_gate_proj.weight', 'ernie.layers.23.mlp.experts.1.up_gate_proj.weight', 'ernie.layers.23.mlp.experts.2.up_gate_proj.weight', 'ernie.layers.23.mlp.experts.3.up_gate_proj.weight', 'ernie.layers.23.mlp.experts.4.up_gate_proj.weight', 'ernie.layers.23.mlp.experts.5.up_gate_proj.weight', 'ernie.layers.23.mlp.experts.6.up_gate_proj.weight', 'ernie.layers.23.mlp.experts.7.up_gate_proj.weight', 'ernie.layers.23.mlp.experts.8.up_gate_proj.weight', 'ernie.layers.23.mlp.experts.9.up_gate_proj.weight', 'ernie.layers.23.mlp.experts.10.up_gate_proj.weight', 'ernie.layers.23.mlp.experts.11.up_gate_proj.weight', 'ernie.layers.23.mlp.experts.12.up_gate_proj.weight', 'ernie.layers.23.mlp.experts.13.up_gate_proj.weight', 'ernie.layers.23.mlp.experts.14.up_gate_proj.weight', 'ernie.layers.23.mlp.experts.15.up_gate_proj.weight', 'ernie.layers.23.mlp.experts.16.up_gate_proj.weight', 'ernie.layers.23.mlp.experts.17.up_gate_proj.weight', 'ernie.layers.23.mlp.experts.18.up_gate_proj.weight', 'ernie.layers.23.mlp.experts.19.up_gate_proj.weight', 'ernie.layers.23.mlp.experts.20.up_gate_proj.weight', 'ernie.layers.23.mlp.experts.21.up_gate_proj.weight', 'ernie.layers.23.mlp.experts.22.up_gate_proj.weight', 'ernie.layers.23.mlp.experts.23.up_gate_proj.weight', 'ernie.layers.23.mlp.experts.24.up_gate_proj.weight', 'ernie.layers.23.mlp.experts.25.up_gate_proj.weight', 'ernie.layers.23.mlp.experts.26.up_gate_proj.weight', 'ernie.layers.23.mlp.experts.27.up_gate_proj.weight', 'ernie.layers.23.mlp.experts.28.up_gate_proj.weight', 'ernie.layers.23.mlp.experts.29.up_gate_proj.weight', 'ernie.layers.23.mlp.experts.30.up_gate_proj.weight', 'ernie.layers.23.mlp.experts.31.up_gate_proj.weight', 'ernie.layers.23.mlp.experts.64.up_gate_proj.weight', 'ernie.layers.23.mlp.experts.65.up_gate_proj.weight', 'ernie.layers.23.mlp.experts.66.up_gate_proj.weight', 'ernie.layers.23.mlp.experts.67.up_gate_proj.weight', 'ernie.layers.23.mlp.experts.68.up_gate_proj.weight', 'ernie.layers.23.mlp.experts.69.up_gate_proj.weight', 'ernie.layers.23.mlp.experts.70.up_gate_proj.weight', 'ernie.layers.23.mlp.experts.71.up_gate_proj.weight', 'ernie.layers.23.mlp.experts.72.up_gate_proj.weight', 'ernie.layers.23.mlp.experts.73.up_gate_proj.weight', 'ernie.layers.23.mlp.experts.74.up_gate_proj.weight', 'ernie.layers.23.mlp.experts.75.up_gate_proj.weight', 'ernie.layers.23.mlp.experts.76.up_gate_proj.weight', 'ernie.layers.23.mlp.experts.77.up_gate_proj.weight', 'ernie.layers.23.mlp.experts.78.up_gate_proj.weight', 'ernie.layers.23.mlp.experts.79.up_gate_proj.weight', 'ernie.layers.23.mlp.experts.80.up_gate_proj.weight', 'ernie.layers.23.mlp.experts.81.up_gate_proj.weight', 'ernie.layers.23.mlp.experts.82.up_gate_proj.weight', 'ernie.layers.23.mlp.experts.83.up_gate_proj.weight', 'ernie.layers.23.mlp.experts.84.up_gate_proj.weight', 'ernie.layers.23.mlp.experts.85.up_gate_proj.weight', 'ernie.layers.23.mlp.experts.86.up_gate_proj.weight', 'ernie.layers.23.mlp.experts.87.up_gate_proj.weight', 'ernie.layers.23.mlp.experts.88.up_gate_proj.weight', 'ernie.layers.23.mlp.experts.89.up_gate_proj.weight', 'ernie.layers.23.mlp.experts.90.up_gate_proj.weight', 'ernie.layers.23.mlp.experts.91.up_gate_proj.weight', 'ernie.layers.23.mlp.experts.92.up_gate_proj.weight', 'ernie.layers.23.mlp.experts.93.up_gate_proj.weight', 'ernie.layers.23.mlp.experts.94.up_gate_proj.weight', 'ernie.layers.23.mlp.experts.95.up_gate_proj.weight'] ernie.layers.23.mlp.text_fused_moe.experts.down_proj_weight:['ernie.layers.23.mlp.experts.0.down_proj.weight', 'ernie.layers.23.mlp.experts.1.down_proj.weight', 'ernie.layers.23.mlp.experts.2.down_proj.weight', 'ernie.layers.23.mlp.experts.3.down_proj.weight', 'ernie.layers.23.mlp.experts.4.down_proj.weight', 'ernie.layers.23.mlp.experts.5.down_proj.weight', 'ernie.layers.23.mlp.experts.6.down_proj.weight', 'ernie.layers.23.mlp.experts.7.down_proj.weight', 'ernie.layers.23.mlp.experts.8.down_proj.weight', 'ernie.layers.23.mlp.experts.9.down_proj.weight', 'ernie.layers.23.mlp.experts.10.down_proj.weight', 'ernie.layers.23.mlp.experts.11.down_proj.weight', 'ernie.layers.23.mlp.experts.12.down_proj.weight', 'ernie.layers.23.mlp.experts.13.down_proj.weight', 'ernie.layers.23.mlp.experts.14.down_proj.weight', 'ernie.layers.23.mlp.experts.15.down_proj.weight', 'ernie.layers.23.mlp.experts.16.down_proj.weight', 'ernie.layers.23.mlp.experts.17.down_proj.weight', 'ernie.layers.23.mlp.experts.18.down_proj.weight', 'ernie.layers.23.mlp.experts.19.down_proj.weight', 'ernie.layers.23.mlp.experts.20.down_proj.weight', 'ernie.layers.23.mlp.experts.21.down_proj.weight', 'ernie.layers.23.mlp.experts.22.down_proj.weight', 'ernie.layers.23.mlp.experts.23.down_proj.weight', 'ernie.layers.23.mlp.experts.24.down_proj.weight', 'ernie.layers.23.mlp.experts.25.down_proj.weight', 'ernie.layers.23.mlp.experts.26.down_proj.weight', 'ernie.layers.23.mlp.experts.27.down_proj.weight', 'ernie.layers.23.mlp.experts.28.down_proj.weight', 'ernie.layers.23.mlp.experts.29.down_proj.weight', 'ernie.layers.23.mlp.experts.30.down_proj.weight', 'ernie.layers.23.mlp.experts.31.down_proj.weight', 'ernie.layers.23.mlp.experts.64.down_proj.weight', 'ernie.layers.23.mlp.experts.65.down_proj.weight', 'ernie.layers.23.mlp.experts.66.down_proj.weight', 'ernie.layers.23.mlp.experts.67.down_proj.weight', 'ernie.layers.23.mlp.experts.68.down_proj.weight', 'ernie.layers.23.mlp.experts.69.down_proj.weight', 'ernie.layers.23.mlp.experts.70.down_proj.weight', 'ernie.layers.23.mlp.experts.71.down_proj.weight', 'ernie.layers.23.mlp.experts.72.down_proj.weight', 'ernie.layers.23.mlp.experts.73.down_proj.weight', 'ernie.layers.23.mlp.experts.74.down_proj.weight', 'ernie.layers.23.mlp.experts.75.down_proj.weight', 'ernie.layers.23.mlp.experts.76.down_proj.weight', 'ernie.layers.23.mlp.experts.77.down_proj.weight', 'ernie.layers.23.mlp.experts.78.down_proj.weight', 'ernie.layers.23.mlp.experts.79.down_proj.weight', 'ernie.layers.23.mlp.experts.80.down_proj.weight', 'ernie.layers.23.mlp.experts.81.down_proj.weight', 'ernie.layers.23.mlp.experts.82.down_proj.weight', 'ernie.layers.23.mlp.experts.83.down_proj.weight', 'ernie.layers.23.mlp.experts.84.down_proj.weight', 'ernie.layers.23.mlp.experts.85.down_proj.weight', 'ernie.layers.23.mlp.experts.86.down_proj.weight', 'ernie.layers.23.mlp.experts.87.down_proj.weight', 'ernie.layers.23.mlp.experts.88.down_proj.weight', 'ernie.layers.23.mlp.experts.89.down_proj.weight', 'ernie.layers.23.mlp.experts.90.down_proj.weight', 'ernie.layers.23.mlp.experts.91.down_proj.weight', 'ernie.layers.23.mlp.experts.92.down_proj.weight', 'ernie.layers.23.mlp.experts.93.down_proj.weight', 'ernie.layers.23.mlp.experts.94.down_proj.weight', 'ernie.layers.23.mlp.experts.95.down_proj.weight'] ernie.layers.24.mlp.text_fused_moe.gate.weight:ernie.layers.24.mlp.gate.weight -ernie.layers.24.mlp.text_fused_moe.experts.gate_correction_bias:ernie.layers.24.mlp.moe_statics.e_score_correction_bias +ernie.layers.24.mlp.gate_correction_bias:ernie.layers.24.mlp.moe_statics.e_score_correction_bias ernie.layers.24.mlp.text_fused_moe.experts.up_gate_proj_weight:['ernie.layers.24.mlp.experts.0.up_gate_proj.weight', 'ernie.layers.24.mlp.experts.1.up_gate_proj.weight', 'ernie.layers.24.mlp.experts.2.up_gate_proj.weight', 'ernie.layers.24.mlp.experts.3.up_gate_proj.weight', 'ernie.layers.24.mlp.experts.4.up_gate_proj.weight', 'ernie.layers.24.mlp.experts.5.up_gate_proj.weight', 'ernie.layers.24.mlp.experts.6.up_gate_proj.weight', 'ernie.layers.24.mlp.experts.7.up_gate_proj.weight', 'ernie.layers.24.mlp.experts.8.up_gate_proj.weight', 'ernie.layers.24.mlp.experts.9.up_gate_proj.weight', 'ernie.layers.24.mlp.experts.10.up_gate_proj.weight', 'ernie.layers.24.mlp.experts.11.up_gate_proj.weight', 'ernie.layers.24.mlp.experts.12.up_gate_proj.weight', 'ernie.layers.24.mlp.experts.13.up_gate_proj.weight', 'ernie.layers.24.mlp.experts.14.up_gate_proj.weight', 'ernie.layers.24.mlp.experts.15.up_gate_proj.weight', 'ernie.layers.24.mlp.experts.16.up_gate_proj.weight', 'ernie.layers.24.mlp.experts.17.up_gate_proj.weight', 'ernie.layers.24.mlp.experts.18.up_gate_proj.weight', 'ernie.layers.24.mlp.experts.19.up_gate_proj.weight', 'ernie.layers.24.mlp.experts.20.up_gate_proj.weight', 'ernie.layers.24.mlp.experts.21.up_gate_proj.weight', 'ernie.layers.24.mlp.experts.22.up_gate_proj.weight', 'ernie.layers.24.mlp.experts.23.up_gate_proj.weight', 'ernie.layers.24.mlp.experts.24.up_gate_proj.weight', 'ernie.layers.24.mlp.experts.25.up_gate_proj.weight', 'ernie.layers.24.mlp.experts.26.up_gate_proj.weight', 'ernie.layers.24.mlp.experts.27.up_gate_proj.weight', 'ernie.layers.24.mlp.experts.28.up_gate_proj.weight', 'ernie.layers.24.mlp.experts.29.up_gate_proj.weight', 'ernie.layers.24.mlp.experts.30.up_gate_proj.weight', 'ernie.layers.24.mlp.experts.31.up_gate_proj.weight', 'ernie.layers.24.mlp.experts.64.up_gate_proj.weight', 'ernie.layers.24.mlp.experts.65.up_gate_proj.weight', 'ernie.layers.24.mlp.experts.66.up_gate_proj.weight', 'ernie.layers.24.mlp.experts.67.up_gate_proj.weight', 'ernie.layers.24.mlp.experts.68.up_gate_proj.weight', 'ernie.layers.24.mlp.experts.69.up_gate_proj.weight', 'ernie.layers.24.mlp.experts.70.up_gate_proj.weight', 'ernie.layers.24.mlp.experts.71.up_gate_proj.weight', 'ernie.layers.24.mlp.experts.72.up_gate_proj.weight', 'ernie.layers.24.mlp.experts.73.up_gate_proj.weight', 'ernie.layers.24.mlp.experts.74.up_gate_proj.weight', 'ernie.layers.24.mlp.experts.75.up_gate_proj.weight', 'ernie.layers.24.mlp.experts.76.up_gate_proj.weight', 'ernie.layers.24.mlp.experts.77.up_gate_proj.weight', 'ernie.layers.24.mlp.experts.78.up_gate_proj.weight', 'ernie.layers.24.mlp.experts.79.up_gate_proj.weight', 'ernie.layers.24.mlp.experts.80.up_gate_proj.weight', 'ernie.layers.24.mlp.experts.81.up_gate_proj.weight', 'ernie.layers.24.mlp.experts.82.up_gate_proj.weight', 'ernie.layers.24.mlp.experts.83.up_gate_proj.weight', 'ernie.layers.24.mlp.experts.84.up_gate_proj.weight', 'ernie.layers.24.mlp.experts.85.up_gate_proj.weight', 'ernie.layers.24.mlp.experts.86.up_gate_proj.weight', 'ernie.layers.24.mlp.experts.87.up_gate_proj.weight', 'ernie.layers.24.mlp.experts.88.up_gate_proj.weight', 'ernie.layers.24.mlp.experts.89.up_gate_proj.weight', 'ernie.layers.24.mlp.experts.90.up_gate_proj.weight', 'ernie.layers.24.mlp.experts.91.up_gate_proj.weight', 'ernie.layers.24.mlp.experts.92.up_gate_proj.weight', 'ernie.layers.24.mlp.experts.93.up_gate_proj.weight', 'ernie.layers.24.mlp.experts.94.up_gate_proj.weight', 'ernie.layers.24.mlp.experts.95.up_gate_proj.weight'] ernie.layers.24.mlp.text_fused_moe.experts.down_proj_weight:['ernie.layers.24.mlp.experts.0.down_proj.weight', 'ernie.layers.24.mlp.experts.1.down_proj.weight', 'ernie.layers.24.mlp.experts.2.down_proj.weight', 'ernie.layers.24.mlp.experts.3.down_proj.weight', 'ernie.layers.24.mlp.experts.4.down_proj.weight', 'ernie.layers.24.mlp.experts.5.down_proj.weight', 'ernie.layers.24.mlp.experts.6.down_proj.weight', 'ernie.layers.24.mlp.experts.7.down_proj.weight', 'ernie.layers.24.mlp.experts.8.down_proj.weight', 'ernie.layers.24.mlp.experts.9.down_proj.weight', 'ernie.layers.24.mlp.experts.10.down_proj.weight', 'ernie.layers.24.mlp.experts.11.down_proj.weight', 'ernie.layers.24.mlp.experts.12.down_proj.weight', 'ernie.layers.24.mlp.experts.13.down_proj.weight', 'ernie.layers.24.mlp.experts.14.down_proj.weight', 'ernie.layers.24.mlp.experts.15.down_proj.weight', 'ernie.layers.24.mlp.experts.16.down_proj.weight', 'ernie.layers.24.mlp.experts.17.down_proj.weight', 'ernie.layers.24.mlp.experts.18.down_proj.weight', 'ernie.layers.24.mlp.experts.19.down_proj.weight', 'ernie.layers.24.mlp.experts.20.down_proj.weight', 'ernie.layers.24.mlp.experts.21.down_proj.weight', 'ernie.layers.24.mlp.experts.22.down_proj.weight', 'ernie.layers.24.mlp.experts.23.down_proj.weight', 'ernie.layers.24.mlp.experts.24.down_proj.weight', 'ernie.layers.24.mlp.experts.25.down_proj.weight', 'ernie.layers.24.mlp.experts.26.down_proj.weight', 'ernie.layers.24.mlp.experts.27.down_proj.weight', 'ernie.layers.24.mlp.experts.28.down_proj.weight', 'ernie.layers.24.mlp.experts.29.down_proj.weight', 'ernie.layers.24.mlp.experts.30.down_proj.weight', 'ernie.layers.24.mlp.experts.31.down_proj.weight', 'ernie.layers.24.mlp.experts.64.down_proj.weight', 'ernie.layers.24.mlp.experts.65.down_proj.weight', 'ernie.layers.24.mlp.experts.66.down_proj.weight', 'ernie.layers.24.mlp.experts.67.down_proj.weight', 'ernie.layers.24.mlp.experts.68.down_proj.weight', 'ernie.layers.24.mlp.experts.69.down_proj.weight', 'ernie.layers.24.mlp.experts.70.down_proj.weight', 'ernie.layers.24.mlp.experts.71.down_proj.weight', 'ernie.layers.24.mlp.experts.72.down_proj.weight', 'ernie.layers.24.mlp.experts.73.down_proj.weight', 'ernie.layers.24.mlp.experts.74.down_proj.weight', 'ernie.layers.24.mlp.experts.75.down_proj.weight', 'ernie.layers.24.mlp.experts.76.down_proj.weight', 'ernie.layers.24.mlp.experts.77.down_proj.weight', 'ernie.layers.24.mlp.experts.78.down_proj.weight', 'ernie.layers.24.mlp.experts.79.down_proj.weight', 'ernie.layers.24.mlp.experts.80.down_proj.weight', 'ernie.layers.24.mlp.experts.81.down_proj.weight', 'ernie.layers.24.mlp.experts.82.down_proj.weight', 'ernie.layers.24.mlp.experts.83.down_proj.weight', 'ernie.layers.24.mlp.experts.84.down_proj.weight', 'ernie.layers.24.mlp.experts.85.down_proj.weight', 'ernie.layers.24.mlp.experts.86.down_proj.weight', 'ernie.layers.24.mlp.experts.87.down_proj.weight', 'ernie.layers.24.mlp.experts.88.down_proj.weight', 'ernie.layers.24.mlp.experts.89.down_proj.weight', 'ernie.layers.24.mlp.experts.90.down_proj.weight', 'ernie.layers.24.mlp.experts.91.down_proj.weight', 'ernie.layers.24.mlp.experts.92.down_proj.weight', 'ernie.layers.24.mlp.experts.93.down_proj.weight', 'ernie.layers.24.mlp.experts.94.down_proj.weight', 'ernie.layers.24.mlp.experts.95.down_proj.weight'] ernie.layers.25.mlp.text_fused_moe.gate.weight:ernie.layers.25.mlp.gate.weight -ernie.layers.25.mlp.text_fused_moe.experts.gate_correction_bias:ernie.layers.25.mlp.moe_statics.e_score_correction_bias +ernie.layers.25.mlp.gate_correction_bias:ernie.layers.25.mlp.moe_statics.e_score_correction_bias ernie.layers.25.mlp.text_fused_moe.experts.up_gate_proj_weight:['ernie.layers.25.mlp.experts.0.up_gate_proj.weight', 'ernie.layers.25.mlp.experts.1.up_gate_proj.weight', 'ernie.layers.25.mlp.experts.2.up_gate_proj.weight', 'ernie.layers.25.mlp.experts.3.up_gate_proj.weight', 'ernie.layers.25.mlp.experts.4.up_gate_proj.weight', 'ernie.layers.25.mlp.experts.5.up_gate_proj.weight', 'ernie.layers.25.mlp.experts.6.up_gate_proj.weight', 'ernie.layers.25.mlp.experts.7.up_gate_proj.weight', 'ernie.layers.25.mlp.experts.8.up_gate_proj.weight', 'ernie.layers.25.mlp.experts.9.up_gate_proj.weight', 'ernie.layers.25.mlp.experts.10.up_gate_proj.weight', 'ernie.layers.25.mlp.experts.11.up_gate_proj.weight', 'ernie.layers.25.mlp.experts.12.up_gate_proj.weight', 'ernie.layers.25.mlp.experts.13.up_gate_proj.weight', 'ernie.layers.25.mlp.experts.14.up_gate_proj.weight', 'ernie.layers.25.mlp.experts.15.up_gate_proj.weight', 'ernie.layers.25.mlp.experts.16.up_gate_proj.weight', 'ernie.layers.25.mlp.experts.17.up_gate_proj.weight', 'ernie.layers.25.mlp.experts.18.up_gate_proj.weight', 'ernie.layers.25.mlp.experts.19.up_gate_proj.weight', 'ernie.layers.25.mlp.experts.20.up_gate_proj.weight', 'ernie.layers.25.mlp.experts.21.up_gate_proj.weight', 'ernie.layers.25.mlp.experts.22.up_gate_proj.weight', 'ernie.layers.25.mlp.experts.23.up_gate_proj.weight', 'ernie.layers.25.mlp.experts.24.up_gate_proj.weight', 'ernie.layers.25.mlp.experts.25.up_gate_proj.weight', 'ernie.layers.25.mlp.experts.26.up_gate_proj.weight', 'ernie.layers.25.mlp.experts.27.up_gate_proj.weight', 'ernie.layers.25.mlp.experts.28.up_gate_proj.weight', 'ernie.layers.25.mlp.experts.29.up_gate_proj.weight', 'ernie.layers.25.mlp.experts.30.up_gate_proj.weight', 'ernie.layers.25.mlp.experts.31.up_gate_proj.weight', 'ernie.layers.25.mlp.experts.64.up_gate_proj.weight', 'ernie.layers.25.mlp.experts.65.up_gate_proj.weight', 'ernie.layers.25.mlp.experts.66.up_gate_proj.weight', 'ernie.layers.25.mlp.experts.67.up_gate_proj.weight', 'ernie.layers.25.mlp.experts.68.up_gate_proj.weight', 'ernie.layers.25.mlp.experts.69.up_gate_proj.weight', 'ernie.layers.25.mlp.experts.70.up_gate_proj.weight', 'ernie.layers.25.mlp.experts.71.up_gate_proj.weight', 'ernie.layers.25.mlp.experts.72.up_gate_proj.weight', 'ernie.layers.25.mlp.experts.73.up_gate_proj.weight', 'ernie.layers.25.mlp.experts.74.up_gate_proj.weight', 'ernie.layers.25.mlp.experts.75.up_gate_proj.weight', 'ernie.layers.25.mlp.experts.76.up_gate_proj.weight', 'ernie.layers.25.mlp.experts.77.up_gate_proj.weight', 'ernie.layers.25.mlp.experts.78.up_gate_proj.weight', 'ernie.layers.25.mlp.experts.79.up_gate_proj.weight', 'ernie.layers.25.mlp.experts.80.up_gate_proj.weight', 'ernie.layers.25.mlp.experts.81.up_gate_proj.weight', 'ernie.layers.25.mlp.experts.82.up_gate_proj.weight', 'ernie.layers.25.mlp.experts.83.up_gate_proj.weight', 'ernie.layers.25.mlp.experts.84.up_gate_proj.weight', 'ernie.layers.25.mlp.experts.85.up_gate_proj.weight', 'ernie.layers.25.mlp.experts.86.up_gate_proj.weight', 'ernie.layers.25.mlp.experts.87.up_gate_proj.weight', 'ernie.layers.25.mlp.experts.88.up_gate_proj.weight', 'ernie.layers.25.mlp.experts.89.up_gate_proj.weight', 'ernie.layers.25.mlp.experts.90.up_gate_proj.weight', 'ernie.layers.25.mlp.experts.91.up_gate_proj.weight', 'ernie.layers.25.mlp.experts.92.up_gate_proj.weight', 'ernie.layers.25.mlp.experts.93.up_gate_proj.weight', 'ernie.layers.25.mlp.experts.94.up_gate_proj.weight', 'ernie.layers.25.mlp.experts.95.up_gate_proj.weight'] ernie.layers.25.mlp.text_fused_moe.experts.down_proj_weight:['ernie.layers.25.mlp.experts.0.down_proj.weight', 'ernie.layers.25.mlp.experts.1.down_proj.weight', 'ernie.layers.25.mlp.experts.2.down_proj.weight', 'ernie.layers.25.mlp.experts.3.down_proj.weight', 'ernie.layers.25.mlp.experts.4.down_proj.weight', 'ernie.layers.25.mlp.experts.5.down_proj.weight', 'ernie.layers.25.mlp.experts.6.down_proj.weight', 'ernie.layers.25.mlp.experts.7.down_proj.weight', 'ernie.layers.25.mlp.experts.8.down_proj.weight', 'ernie.layers.25.mlp.experts.9.down_proj.weight', 'ernie.layers.25.mlp.experts.10.down_proj.weight', 'ernie.layers.25.mlp.experts.11.down_proj.weight', 'ernie.layers.25.mlp.experts.12.down_proj.weight', 'ernie.layers.25.mlp.experts.13.down_proj.weight', 'ernie.layers.25.mlp.experts.14.down_proj.weight', 'ernie.layers.25.mlp.experts.15.down_proj.weight', 'ernie.layers.25.mlp.experts.16.down_proj.weight', 'ernie.layers.25.mlp.experts.17.down_proj.weight', 'ernie.layers.25.mlp.experts.18.down_proj.weight', 'ernie.layers.25.mlp.experts.19.down_proj.weight', 'ernie.layers.25.mlp.experts.20.down_proj.weight', 'ernie.layers.25.mlp.experts.21.down_proj.weight', 'ernie.layers.25.mlp.experts.22.down_proj.weight', 'ernie.layers.25.mlp.experts.23.down_proj.weight', 'ernie.layers.25.mlp.experts.24.down_proj.weight', 'ernie.layers.25.mlp.experts.25.down_proj.weight', 'ernie.layers.25.mlp.experts.26.down_proj.weight', 'ernie.layers.25.mlp.experts.27.down_proj.weight', 'ernie.layers.25.mlp.experts.28.down_proj.weight', 'ernie.layers.25.mlp.experts.29.down_proj.weight', 'ernie.layers.25.mlp.experts.30.down_proj.weight', 'ernie.layers.25.mlp.experts.31.down_proj.weight', 'ernie.layers.25.mlp.experts.64.down_proj.weight', 'ernie.layers.25.mlp.experts.65.down_proj.weight', 'ernie.layers.25.mlp.experts.66.down_proj.weight', 'ernie.layers.25.mlp.experts.67.down_proj.weight', 'ernie.layers.25.mlp.experts.68.down_proj.weight', 'ernie.layers.25.mlp.experts.69.down_proj.weight', 'ernie.layers.25.mlp.experts.70.down_proj.weight', 'ernie.layers.25.mlp.experts.71.down_proj.weight', 'ernie.layers.25.mlp.experts.72.down_proj.weight', 'ernie.layers.25.mlp.experts.73.down_proj.weight', 'ernie.layers.25.mlp.experts.74.down_proj.weight', 'ernie.layers.25.mlp.experts.75.down_proj.weight', 'ernie.layers.25.mlp.experts.76.down_proj.weight', 'ernie.layers.25.mlp.experts.77.down_proj.weight', 'ernie.layers.25.mlp.experts.78.down_proj.weight', 'ernie.layers.25.mlp.experts.79.down_proj.weight', 'ernie.layers.25.mlp.experts.80.down_proj.weight', 'ernie.layers.25.mlp.experts.81.down_proj.weight', 'ernie.layers.25.mlp.experts.82.down_proj.weight', 'ernie.layers.25.mlp.experts.83.down_proj.weight', 'ernie.layers.25.mlp.experts.84.down_proj.weight', 'ernie.layers.25.mlp.experts.85.down_proj.weight', 'ernie.layers.25.mlp.experts.86.down_proj.weight', 'ernie.layers.25.mlp.experts.87.down_proj.weight', 'ernie.layers.25.mlp.experts.88.down_proj.weight', 'ernie.layers.25.mlp.experts.89.down_proj.weight', 'ernie.layers.25.mlp.experts.90.down_proj.weight', 'ernie.layers.25.mlp.experts.91.down_proj.weight', 'ernie.layers.25.mlp.experts.92.down_proj.weight', 'ernie.layers.25.mlp.experts.93.down_proj.weight', 'ernie.layers.25.mlp.experts.94.down_proj.weight', 'ernie.layers.25.mlp.experts.95.down_proj.weight'] ernie.layers.26.mlp.text_fused_moe.gate.weight:ernie.layers.26.mlp.gate.weight -ernie.layers.26.mlp.text_fused_moe.experts.gate_correction_bias:ernie.layers.26.mlp.moe_statics.e_score_correction_bias +ernie.layers.26.mlp.gate_correction_bias:ernie.layers.26.mlp.moe_statics.e_score_correction_bias ernie.layers.26.mlp.text_fused_moe.experts.up_gate_proj_weight:['ernie.layers.26.mlp.experts.0.up_gate_proj.weight', 'ernie.layers.26.mlp.experts.1.up_gate_proj.weight', 'ernie.layers.26.mlp.experts.2.up_gate_proj.weight', 'ernie.layers.26.mlp.experts.3.up_gate_proj.weight', 'ernie.layers.26.mlp.experts.4.up_gate_proj.weight', 'ernie.layers.26.mlp.experts.5.up_gate_proj.weight', 'ernie.layers.26.mlp.experts.6.up_gate_proj.weight', 'ernie.layers.26.mlp.experts.7.up_gate_proj.weight', 'ernie.layers.26.mlp.experts.8.up_gate_proj.weight', 'ernie.layers.26.mlp.experts.9.up_gate_proj.weight', 'ernie.layers.26.mlp.experts.10.up_gate_proj.weight', 'ernie.layers.26.mlp.experts.11.up_gate_proj.weight', 'ernie.layers.26.mlp.experts.12.up_gate_proj.weight', 'ernie.layers.26.mlp.experts.13.up_gate_proj.weight', 'ernie.layers.26.mlp.experts.14.up_gate_proj.weight', 'ernie.layers.26.mlp.experts.15.up_gate_proj.weight', 'ernie.layers.26.mlp.experts.16.up_gate_proj.weight', 'ernie.layers.26.mlp.experts.17.up_gate_proj.weight', 'ernie.layers.26.mlp.experts.18.up_gate_proj.weight', 'ernie.layers.26.mlp.experts.19.up_gate_proj.weight', 'ernie.layers.26.mlp.experts.20.up_gate_proj.weight', 'ernie.layers.26.mlp.experts.21.up_gate_proj.weight', 'ernie.layers.26.mlp.experts.22.up_gate_proj.weight', 'ernie.layers.26.mlp.experts.23.up_gate_proj.weight', 'ernie.layers.26.mlp.experts.24.up_gate_proj.weight', 'ernie.layers.26.mlp.experts.25.up_gate_proj.weight', 'ernie.layers.26.mlp.experts.26.up_gate_proj.weight', 'ernie.layers.26.mlp.experts.27.up_gate_proj.weight', 'ernie.layers.26.mlp.experts.28.up_gate_proj.weight', 'ernie.layers.26.mlp.experts.29.up_gate_proj.weight', 'ernie.layers.26.mlp.experts.30.up_gate_proj.weight', 'ernie.layers.26.mlp.experts.31.up_gate_proj.weight', 'ernie.layers.26.mlp.experts.64.up_gate_proj.weight', 'ernie.layers.26.mlp.experts.65.up_gate_proj.weight', 'ernie.layers.26.mlp.experts.66.up_gate_proj.weight', 'ernie.layers.26.mlp.experts.67.up_gate_proj.weight', 'ernie.layers.26.mlp.experts.68.up_gate_proj.weight', 'ernie.layers.26.mlp.experts.69.up_gate_proj.weight', 'ernie.layers.26.mlp.experts.70.up_gate_proj.weight', 'ernie.layers.26.mlp.experts.71.up_gate_proj.weight', 'ernie.layers.26.mlp.experts.72.up_gate_proj.weight', 'ernie.layers.26.mlp.experts.73.up_gate_proj.weight', 'ernie.layers.26.mlp.experts.74.up_gate_proj.weight', 'ernie.layers.26.mlp.experts.75.up_gate_proj.weight', 'ernie.layers.26.mlp.experts.76.up_gate_proj.weight', 'ernie.layers.26.mlp.experts.77.up_gate_proj.weight', 'ernie.layers.26.mlp.experts.78.up_gate_proj.weight', 'ernie.layers.26.mlp.experts.79.up_gate_proj.weight', 'ernie.layers.26.mlp.experts.80.up_gate_proj.weight', 'ernie.layers.26.mlp.experts.81.up_gate_proj.weight', 'ernie.layers.26.mlp.experts.82.up_gate_proj.weight', 'ernie.layers.26.mlp.experts.83.up_gate_proj.weight', 'ernie.layers.26.mlp.experts.84.up_gate_proj.weight', 'ernie.layers.26.mlp.experts.85.up_gate_proj.weight', 'ernie.layers.26.mlp.experts.86.up_gate_proj.weight', 'ernie.layers.26.mlp.experts.87.up_gate_proj.weight', 'ernie.layers.26.mlp.experts.88.up_gate_proj.weight', 'ernie.layers.26.mlp.experts.89.up_gate_proj.weight', 'ernie.layers.26.mlp.experts.90.up_gate_proj.weight', 'ernie.layers.26.mlp.experts.91.up_gate_proj.weight', 'ernie.layers.26.mlp.experts.92.up_gate_proj.weight', 'ernie.layers.26.mlp.experts.93.up_gate_proj.weight', 'ernie.layers.26.mlp.experts.94.up_gate_proj.weight', 'ernie.layers.26.mlp.experts.95.up_gate_proj.weight'] ernie.layers.26.mlp.text_fused_moe.experts.down_proj_weight:['ernie.layers.26.mlp.experts.0.down_proj.weight', 'ernie.layers.26.mlp.experts.1.down_proj.weight', 'ernie.layers.26.mlp.experts.2.down_proj.weight', 'ernie.layers.26.mlp.experts.3.down_proj.weight', 'ernie.layers.26.mlp.experts.4.down_proj.weight', 'ernie.layers.26.mlp.experts.5.down_proj.weight', 'ernie.layers.26.mlp.experts.6.down_proj.weight', 'ernie.layers.26.mlp.experts.7.down_proj.weight', 'ernie.layers.26.mlp.experts.8.down_proj.weight', 'ernie.layers.26.mlp.experts.9.down_proj.weight', 'ernie.layers.26.mlp.experts.10.down_proj.weight', 'ernie.layers.26.mlp.experts.11.down_proj.weight', 'ernie.layers.26.mlp.experts.12.down_proj.weight', 'ernie.layers.26.mlp.experts.13.down_proj.weight', 'ernie.layers.26.mlp.experts.14.down_proj.weight', 'ernie.layers.26.mlp.experts.15.down_proj.weight', 'ernie.layers.26.mlp.experts.16.down_proj.weight', 'ernie.layers.26.mlp.experts.17.down_proj.weight', 'ernie.layers.26.mlp.experts.18.down_proj.weight', 'ernie.layers.26.mlp.experts.19.down_proj.weight', 'ernie.layers.26.mlp.experts.20.down_proj.weight', 'ernie.layers.26.mlp.experts.21.down_proj.weight', 'ernie.layers.26.mlp.experts.22.down_proj.weight', 'ernie.layers.26.mlp.experts.23.down_proj.weight', 'ernie.layers.26.mlp.experts.24.down_proj.weight', 'ernie.layers.26.mlp.experts.25.down_proj.weight', 'ernie.layers.26.mlp.experts.26.down_proj.weight', 'ernie.layers.26.mlp.experts.27.down_proj.weight', 'ernie.layers.26.mlp.experts.28.down_proj.weight', 'ernie.layers.26.mlp.experts.29.down_proj.weight', 'ernie.layers.26.mlp.experts.30.down_proj.weight', 'ernie.layers.26.mlp.experts.31.down_proj.weight', 'ernie.layers.26.mlp.experts.64.down_proj.weight', 'ernie.layers.26.mlp.experts.65.down_proj.weight', 'ernie.layers.26.mlp.experts.66.down_proj.weight', 'ernie.layers.26.mlp.experts.67.down_proj.weight', 'ernie.layers.26.mlp.experts.68.down_proj.weight', 'ernie.layers.26.mlp.experts.69.down_proj.weight', 'ernie.layers.26.mlp.experts.70.down_proj.weight', 'ernie.layers.26.mlp.experts.71.down_proj.weight', 'ernie.layers.26.mlp.experts.72.down_proj.weight', 'ernie.layers.26.mlp.experts.73.down_proj.weight', 'ernie.layers.26.mlp.experts.74.down_proj.weight', 'ernie.layers.26.mlp.experts.75.down_proj.weight', 'ernie.layers.26.mlp.experts.76.down_proj.weight', 'ernie.layers.26.mlp.experts.77.down_proj.weight', 'ernie.layers.26.mlp.experts.78.down_proj.weight', 'ernie.layers.26.mlp.experts.79.down_proj.weight', 'ernie.layers.26.mlp.experts.80.down_proj.weight', 'ernie.layers.26.mlp.experts.81.down_proj.weight', 'ernie.layers.26.mlp.experts.82.down_proj.weight', 'ernie.layers.26.mlp.experts.83.down_proj.weight', 'ernie.layers.26.mlp.experts.84.down_proj.weight', 'ernie.layers.26.mlp.experts.85.down_proj.weight', 'ernie.layers.26.mlp.experts.86.down_proj.weight', 'ernie.layers.26.mlp.experts.87.down_proj.weight', 'ernie.layers.26.mlp.experts.88.down_proj.weight', 'ernie.layers.26.mlp.experts.89.down_proj.weight', 'ernie.layers.26.mlp.experts.90.down_proj.weight', 'ernie.layers.26.mlp.experts.91.down_proj.weight', 'ernie.layers.26.mlp.experts.92.down_proj.weight', 'ernie.layers.26.mlp.experts.93.down_proj.weight', 'ernie.layers.26.mlp.experts.94.down_proj.weight', 'ernie.layers.26.mlp.experts.95.down_proj.weight'] ernie.layers.27.mlp.text_fused_moe.gate.weight:ernie.layers.27.mlp.gate.weight -ernie.layers.27.mlp.text_fused_moe.experts.gate_correction_bias:ernie.layers.27.mlp.moe_statics.e_score_correction_bias +ernie.layers.27.mlp.gate_correction_bias:ernie.layers.27.mlp.moe_statics.e_score_correction_bias ernie.layers.27.mlp.text_fused_moe.experts.up_gate_proj_weight:['ernie.layers.27.mlp.experts.0.up_gate_proj.weight', 'ernie.layers.27.mlp.experts.1.up_gate_proj.weight', 'ernie.layers.27.mlp.experts.2.up_gate_proj.weight', 'ernie.layers.27.mlp.experts.3.up_gate_proj.weight', 'ernie.layers.27.mlp.experts.4.up_gate_proj.weight', 'ernie.layers.27.mlp.experts.5.up_gate_proj.weight', 'ernie.layers.27.mlp.experts.6.up_gate_proj.weight', 'ernie.layers.27.mlp.experts.7.up_gate_proj.weight', 'ernie.layers.27.mlp.experts.8.up_gate_proj.weight', 'ernie.layers.27.mlp.experts.9.up_gate_proj.weight', 'ernie.layers.27.mlp.experts.10.up_gate_proj.weight', 'ernie.layers.27.mlp.experts.11.up_gate_proj.weight', 'ernie.layers.27.mlp.experts.12.up_gate_proj.weight', 'ernie.layers.27.mlp.experts.13.up_gate_proj.weight', 'ernie.layers.27.mlp.experts.14.up_gate_proj.weight', 'ernie.layers.27.mlp.experts.15.up_gate_proj.weight', 'ernie.layers.27.mlp.experts.16.up_gate_proj.weight', 'ernie.layers.27.mlp.experts.17.up_gate_proj.weight', 'ernie.layers.27.mlp.experts.18.up_gate_proj.weight', 'ernie.layers.27.mlp.experts.19.up_gate_proj.weight', 'ernie.layers.27.mlp.experts.20.up_gate_proj.weight', 'ernie.layers.27.mlp.experts.21.up_gate_proj.weight', 'ernie.layers.27.mlp.experts.22.up_gate_proj.weight', 'ernie.layers.27.mlp.experts.23.up_gate_proj.weight', 'ernie.layers.27.mlp.experts.24.up_gate_proj.weight', 'ernie.layers.27.mlp.experts.25.up_gate_proj.weight', 'ernie.layers.27.mlp.experts.26.up_gate_proj.weight', 'ernie.layers.27.mlp.experts.27.up_gate_proj.weight', 'ernie.layers.27.mlp.experts.28.up_gate_proj.weight', 'ernie.layers.27.mlp.experts.29.up_gate_proj.weight', 'ernie.layers.27.mlp.experts.30.up_gate_proj.weight', 'ernie.layers.27.mlp.experts.31.up_gate_proj.weight', 'ernie.layers.27.mlp.experts.64.up_gate_proj.weight', 'ernie.layers.27.mlp.experts.65.up_gate_proj.weight', 'ernie.layers.27.mlp.experts.66.up_gate_proj.weight', 'ernie.layers.27.mlp.experts.67.up_gate_proj.weight', 'ernie.layers.27.mlp.experts.68.up_gate_proj.weight', 'ernie.layers.27.mlp.experts.69.up_gate_proj.weight', 'ernie.layers.27.mlp.experts.70.up_gate_proj.weight', 'ernie.layers.27.mlp.experts.71.up_gate_proj.weight', 'ernie.layers.27.mlp.experts.72.up_gate_proj.weight', 'ernie.layers.27.mlp.experts.73.up_gate_proj.weight', 'ernie.layers.27.mlp.experts.74.up_gate_proj.weight', 'ernie.layers.27.mlp.experts.75.up_gate_proj.weight', 'ernie.layers.27.mlp.experts.76.up_gate_proj.weight', 'ernie.layers.27.mlp.experts.77.up_gate_proj.weight', 'ernie.layers.27.mlp.experts.78.up_gate_proj.weight', 'ernie.layers.27.mlp.experts.79.up_gate_proj.weight', 'ernie.layers.27.mlp.experts.80.up_gate_proj.weight', 'ernie.layers.27.mlp.experts.81.up_gate_proj.weight', 'ernie.layers.27.mlp.experts.82.up_gate_proj.weight', 'ernie.layers.27.mlp.experts.83.up_gate_proj.weight', 'ernie.layers.27.mlp.experts.84.up_gate_proj.weight', 'ernie.layers.27.mlp.experts.85.up_gate_proj.weight', 'ernie.layers.27.mlp.experts.86.up_gate_proj.weight', 'ernie.layers.27.mlp.experts.87.up_gate_proj.weight', 'ernie.layers.27.mlp.experts.88.up_gate_proj.weight', 'ernie.layers.27.mlp.experts.89.up_gate_proj.weight', 'ernie.layers.27.mlp.experts.90.up_gate_proj.weight', 'ernie.layers.27.mlp.experts.91.up_gate_proj.weight', 'ernie.layers.27.mlp.experts.92.up_gate_proj.weight', 'ernie.layers.27.mlp.experts.93.up_gate_proj.weight', 'ernie.layers.27.mlp.experts.94.up_gate_proj.weight', 'ernie.layers.27.mlp.experts.95.up_gate_proj.weight'] ernie.layers.27.mlp.text_fused_moe.experts.down_proj_weight:['ernie.layers.27.mlp.experts.0.down_proj.weight', 'ernie.layers.27.mlp.experts.1.down_proj.weight', 'ernie.layers.27.mlp.experts.2.down_proj.weight', 'ernie.layers.27.mlp.experts.3.down_proj.weight', 'ernie.layers.27.mlp.experts.4.down_proj.weight', 'ernie.layers.27.mlp.experts.5.down_proj.weight', 'ernie.layers.27.mlp.experts.6.down_proj.weight', 'ernie.layers.27.mlp.experts.7.down_proj.weight', 'ernie.layers.27.mlp.experts.8.down_proj.weight', 'ernie.layers.27.mlp.experts.9.down_proj.weight', 'ernie.layers.27.mlp.experts.10.down_proj.weight', 'ernie.layers.27.mlp.experts.11.down_proj.weight', 'ernie.layers.27.mlp.experts.12.down_proj.weight', 'ernie.layers.27.mlp.experts.13.down_proj.weight', 'ernie.layers.27.mlp.experts.14.down_proj.weight', 'ernie.layers.27.mlp.experts.15.down_proj.weight', 'ernie.layers.27.mlp.experts.16.down_proj.weight', 'ernie.layers.27.mlp.experts.17.down_proj.weight', 'ernie.layers.27.mlp.experts.18.down_proj.weight', 'ernie.layers.27.mlp.experts.19.down_proj.weight', 'ernie.layers.27.mlp.experts.20.down_proj.weight', 'ernie.layers.27.mlp.experts.21.down_proj.weight', 'ernie.layers.27.mlp.experts.22.down_proj.weight', 'ernie.layers.27.mlp.experts.23.down_proj.weight', 'ernie.layers.27.mlp.experts.24.down_proj.weight', 'ernie.layers.27.mlp.experts.25.down_proj.weight', 'ernie.layers.27.mlp.experts.26.down_proj.weight', 'ernie.layers.27.mlp.experts.27.down_proj.weight', 'ernie.layers.27.mlp.experts.28.down_proj.weight', 'ernie.layers.27.mlp.experts.29.down_proj.weight', 'ernie.layers.27.mlp.experts.30.down_proj.weight', 'ernie.layers.27.mlp.experts.31.down_proj.weight', 'ernie.layers.27.mlp.experts.64.down_proj.weight', 'ernie.layers.27.mlp.experts.65.down_proj.weight', 'ernie.layers.27.mlp.experts.66.down_proj.weight', 'ernie.layers.27.mlp.experts.67.down_proj.weight', 'ernie.layers.27.mlp.experts.68.down_proj.weight', 'ernie.layers.27.mlp.experts.69.down_proj.weight', 'ernie.layers.27.mlp.experts.70.down_proj.weight', 'ernie.layers.27.mlp.experts.71.down_proj.weight', 'ernie.layers.27.mlp.experts.72.down_proj.weight', 'ernie.layers.27.mlp.experts.73.down_proj.weight', 'ernie.layers.27.mlp.experts.74.down_proj.weight', 'ernie.layers.27.mlp.experts.75.down_proj.weight', 'ernie.layers.27.mlp.experts.76.down_proj.weight', 'ernie.layers.27.mlp.experts.77.down_proj.weight', 'ernie.layers.27.mlp.experts.78.down_proj.weight', 'ernie.layers.27.mlp.experts.79.down_proj.weight', 'ernie.layers.27.mlp.experts.80.down_proj.weight', 'ernie.layers.27.mlp.experts.81.down_proj.weight', 'ernie.layers.27.mlp.experts.82.down_proj.weight', 'ernie.layers.27.mlp.experts.83.down_proj.weight', 'ernie.layers.27.mlp.experts.84.down_proj.weight', 'ernie.layers.27.mlp.experts.85.down_proj.weight', 'ernie.layers.27.mlp.experts.86.down_proj.weight', 'ernie.layers.27.mlp.experts.87.down_proj.weight', 'ernie.layers.27.mlp.experts.88.down_proj.weight', 'ernie.layers.27.mlp.experts.89.down_proj.weight', 'ernie.layers.27.mlp.experts.90.down_proj.weight', 'ernie.layers.27.mlp.experts.91.down_proj.weight', 'ernie.layers.27.mlp.experts.92.down_proj.weight', 'ernie.layers.27.mlp.experts.93.down_proj.weight', 'ernie.layers.27.mlp.experts.94.down_proj.weight', 'ernie.layers.27.mlp.experts.95.down_proj.weight'] ernie.layers.28.mlp.text_fused_moe.gate.weight:ernie.layers.28.mlp.gate.weight -ernie.layers.28.mlp.text_fused_moe.experts.gate_correction_bias:ernie.layers.28.mlp.moe_statics.e_score_correction_bias +ernie.layers.28.mlp.gate_correction_bias:ernie.layers.28.mlp.moe_statics.e_score_correction_bias ernie.layers.28.mlp.text_fused_moe.experts.up_gate_proj_weight:['ernie.layers.28.mlp.experts.0.up_gate_proj.weight', 'ernie.layers.28.mlp.experts.1.up_gate_proj.weight', 'ernie.layers.28.mlp.experts.2.up_gate_proj.weight', 'ernie.layers.28.mlp.experts.3.up_gate_proj.weight', 'ernie.layers.28.mlp.experts.4.up_gate_proj.weight', 'ernie.layers.28.mlp.experts.5.up_gate_proj.weight', 'ernie.layers.28.mlp.experts.6.up_gate_proj.weight', 'ernie.layers.28.mlp.experts.7.up_gate_proj.weight', 'ernie.layers.28.mlp.experts.8.up_gate_proj.weight', 'ernie.layers.28.mlp.experts.9.up_gate_proj.weight', 'ernie.layers.28.mlp.experts.10.up_gate_proj.weight', 'ernie.layers.28.mlp.experts.11.up_gate_proj.weight', 'ernie.layers.28.mlp.experts.12.up_gate_proj.weight', 'ernie.layers.28.mlp.experts.13.up_gate_proj.weight', 'ernie.layers.28.mlp.experts.14.up_gate_proj.weight', 'ernie.layers.28.mlp.experts.15.up_gate_proj.weight', 'ernie.layers.28.mlp.experts.16.up_gate_proj.weight', 'ernie.layers.28.mlp.experts.17.up_gate_proj.weight', 'ernie.layers.28.mlp.experts.18.up_gate_proj.weight', 'ernie.layers.28.mlp.experts.19.up_gate_proj.weight', 'ernie.layers.28.mlp.experts.20.up_gate_proj.weight', 'ernie.layers.28.mlp.experts.21.up_gate_proj.weight', 'ernie.layers.28.mlp.experts.22.up_gate_proj.weight', 'ernie.layers.28.mlp.experts.23.up_gate_proj.weight', 'ernie.layers.28.mlp.experts.24.up_gate_proj.weight', 'ernie.layers.28.mlp.experts.25.up_gate_proj.weight', 'ernie.layers.28.mlp.experts.26.up_gate_proj.weight', 'ernie.layers.28.mlp.experts.27.up_gate_proj.weight', 'ernie.layers.28.mlp.experts.28.up_gate_proj.weight', 'ernie.layers.28.mlp.experts.29.up_gate_proj.weight', 'ernie.layers.28.mlp.experts.30.up_gate_proj.weight', 'ernie.layers.28.mlp.experts.31.up_gate_proj.weight', 'ernie.layers.28.mlp.experts.64.up_gate_proj.weight', 'ernie.layers.28.mlp.experts.65.up_gate_proj.weight', 'ernie.layers.28.mlp.experts.66.up_gate_proj.weight', 'ernie.layers.28.mlp.experts.67.up_gate_proj.weight', 'ernie.layers.28.mlp.experts.68.up_gate_proj.weight', 'ernie.layers.28.mlp.experts.69.up_gate_proj.weight', 'ernie.layers.28.mlp.experts.70.up_gate_proj.weight', 'ernie.layers.28.mlp.experts.71.up_gate_proj.weight', 'ernie.layers.28.mlp.experts.72.up_gate_proj.weight', 'ernie.layers.28.mlp.experts.73.up_gate_proj.weight', 'ernie.layers.28.mlp.experts.74.up_gate_proj.weight', 'ernie.layers.28.mlp.experts.75.up_gate_proj.weight', 'ernie.layers.28.mlp.experts.76.up_gate_proj.weight', 'ernie.layers.28.mlp.experts.77.up_gate_proj.weight', 'ernie.layers.28.mlp.experts.78.up_gate_proj.weight', 'ernie.layers.28.mlp.experts.79.up_gate_proj.weight', 'ernie.layers.28.mlp.experts.80.up_gate_proj.weight', 'ernie.layers.28.mlp.experts.81.up_gate_proj.weight', 'ernie.layers.28.mlp.experts.82.up_gate_proj.weight', 'ernie.layers.28.mlp.experts.83.up_gate_proj.weight', 'ernie.layers.28.mlp.experts.84.up_gate_proj.weight', 'ernie.layers.28.mlp.experts.85.up_gate_proj.weight', 'ernie.layers.28.mlp.experts.86.up_gate_proj.weight', 'ernie.layers.28.mlp.experts.87.up_gate_proj.weight', 'ernie.layers.28.mlp.experts.88.up_gate_proj.weight', 'ernie.layers.28.mlp.experts.89.up_gate_proj.weight', 'ernie.layers.28.mlp.experts.90.up_gate_proj.weight', 'ernie.layers.28.mlp.experts.91.up_gate_proj.weight', 'ernie.layers.28.mlp.experts.92.up_gate_proj.weight', 'ernie.layers.28.mlp.experts.93.up_gate_proj.weight', 'ernie.layers.28.mlp.experts.94.up_gate_proj.weight', 'ernie.layers.28.mlp.experts.95.up_gate_proj.weight'] ernie.layers.28.mlp.text_fused_moe.experts.down_proj_weight:['ernie.layers.28.mlp.experts.0.down_proj.weight', 'ernie.layers.28.mlp.experts.1.down_proj.weight', 'ernie.layers.28.mlp.experts.2.down_proj.weight', 'ernie.layers.28.mlp.experts.3.down_proj.weight', 'ernie.layers.28.mlp.experts.4.down_proj.weight', 'ernie.layers.28.mlp.experts.5.down_proj.weight', 'ernie.layers.28.mlp.experts.6.down_proj.weight', 'ernie.layers.28.mlp.experts.7.down_proj.weight', 'ernie.layers.28.mlp.experts.8.down_proj.weight', 'ernie.layers.28.mlp.experts.9.down_proj.weight', 'ernie.layers.28.mlp.experts.10.down_proj.weight', 'ernie.layers.28.mlp.experts.11.down_proj.weight', 'ernie.layers.28.mlp.experts.12.down_proj.weight', 'ernie.layers.28.mlp.experts.13.down_proj.weight', 'ernie.layers.28.mlp.experts.14.down_proj.weight', 'ernie.layers.28.mlp.experts.15.down_proj.weight', 'ernie.layers.28.mlp.experts.16.down_proj.weight', 'ernie.layers.28.mlp.experts.17.down_proj.weight', 'ernie.layers.28.mlp.experts.18.down_proj.weight', 'ernie.layers.28.mlp.experts.19.down_proj.weight', 'ernie.layers.28.mlp.experts.20.down_proj.weight', 'ernie.layers.28.mlp.experts.21.down_proj.weight', 'ernie.layers.28.mlp.experts.22.down_proj.weight', 'ernie.layers.28.mlp.experts.23.down_proj.weight', 'ernie.layers.28.mlp.experts.24.down_proj.weight', 'ernie.layers.28.mlp.experts.25.down_proj.weight', 'ernie.layers.28.mlp.experts.26.down_proj.weight', 'ernie.layers.28.mlp.experts.27.down_proj.weight', 'ernie.layers.28.mlp.experts.28.down_proj.weight', 'ernie.layers.28.mlp.experts.29.down_proj.weight', 'ernie.layers.28.mlp.experts.30.down_proj.weight', 'ernie.layers.28.mlp.experts.31.down_proj.weight', 'ernie.layers.28.mlp.experts.64.down_proj.weight', 'ernie.layers.28.mlp.experts.65.down_proj.weight', 'ernie.layers.28.mlp.experts.66.down_proj.weight', 'ernie.layers.28.mlp.experts.67.down_proj.weight', 'ernie.layers.28.mlp.experts.68.down_proj.weight', 'ernie.layers.28.mlp.experts.69.down_proj.weight', 'ernie.layers.28.mlp.experts.70.down_proj.weight', 'ernie.layers.28.mlp.experts.71.down_proj.weight', 'ernie.layers.28.mlp.experts.72.down_proj.weight', 'ernie.layers.28.mlp.experts.73.down_proj.weight', 'ernie.layers.28.mlp.experts.74.down_proj.weight', 'ernie.layers.28.mlp.experts.75.down_proj.weight', 'ernie.layers.28.mlp.experts.76.down_proj.weight', 'ernie.layers.28.mlp.experts.77.down_proj.weight', 'ernie.layers.28.mlp.experts.78.down_proj.weight', 'ernie.layers.28.mlp.experts.79.down_proj.weight', 'ernie.layers.28.mlp.experts.80.down_proj.weight', 'ernie.layers.28.mlp.experts.81.down_proj.weight', 'ernie.layers.28.mlp.experts.82.down_proj.weight', 'ernie.layers.28.mlp.experts.83.down_proj.weight', 'ernie.layers.28.mlp.experts.84.down_proj.weight', 'ernie.layers.28.mlp.experts.85.down_proj.weight', 'ernie.layers.28.mlp.experts.86.down_proj.weight', 'ernie.layers.28.mlp.experts.87.down_proj.weight', 'ernie.layers.28.mlp.experts.88.down_proj.weight', 'ernie.layers.28.mlp.experts.89.down_proj.weight', 'ernie.layers.28.mlp.experts.90.down_proj.weight', 'ernie.layers.28.mlp.experts.91.down_proj.weight', 'ernie.layers.28.mlp.experts.92.down_proj.weight', 'ernie.layers.28.mlp.experts.93.down_proj.weight', 'ernie.layers.28.mlp.experts.94.down_proj.weight', 'ernie.layers.28.mlp.experts.95.down_proj.weight'] ernie.layers.1.mlp.image_fused_moe.gate.weight:ernie.layers.1.mlp.gate.weight_1 -ernie.layers.1.mlp.image_fused_moe.experts.gate_correction_bias:ernie.layers.1.mlp.moe_statics.e_score_correction_bias ernie.layers.1.mlp.image_fused_moe.experts.up_gate_proj_weight:['ernie.layers.1.mlp.experts.32.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.33.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.34.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.35.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.36.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.37.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.38.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.39.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.40.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.41.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.42.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.43.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.44.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.45.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.46.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.47.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.48.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.49.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.50.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.51.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.52.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.53.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.54.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.55.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.56.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.57.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.58.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.59.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.60.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.61.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.62.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.63.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.96.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.97.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.98.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.99.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.100.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.101.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.102.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.103.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.104.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.105.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.106.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.107.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.108.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.109.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.110.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.111.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.112.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.113.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.114.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.115.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.116.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.117.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.118.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.119.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.120.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.121.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.122.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.123.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.124.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.125.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.126.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.127.up_gate_proj.weight'] ernie.layers.1.mlp.image_fused_moe.experts.down_proj_weight:['ernie.layers.1.mlp.experts.32.down_proj.weight', 'ernie.layers.1.mlp.experts.33.down_proj.weight', 'ernie.layers.1.mlp.experts.34.down_proj.weight', 'ernie.layers.1.mlp.experts.35.down_proj.weight', 'ernie.layers.1.mlp.experts.36.down_proj.weight', 'ernie.layers.1.mlp.experts.37.down_proj.weight', 'ernie.layers.1.mlp.experts.38.down_proj.weight', 'ernie.layers.1.mlp.experts.39.down_proj.weight', 'ernie.layers.1.mlp.experts.40.down_proj.weight', 'ernie.layers.1.mlp.experts.41.down_proj.weight', 'ernie.layers.1.mlp.experts.42.down_proj.weight', 'ernie.layers.1.mlp.experts.43.down_proj.weight', 'ernie.layers.1.mlp.experts.44.down_proj.weight', 'ernie.layers.1.mlp.experts.45.down_proj.weight', 'ernie.layers.1.mlp.experts.46.down_proj.weight', 'ernie.layers.1.mlp.experts.47.down_proj.weight', 'ernie.layers.1.mlp.experts.48.down_proj.weight', 'ernie.layers.1.mlp.experts.49.down_proj.weight', 'ernie.layers.1.mlp.experts.50.down_proj.weight', 'ernie.layers.1.mlp.experts.51.down_proj.weight', 'ernie.layers.1.mlp.experts.52.down_proj.weight', 'ernie.layers.1.mlp.experts.53.down_proj.weight', 'ernie.layers.1.mlp.experts.54.down_proj.weight', 'ernie.layers.1.mlp.experts.55.down_proj.weight', 'ernie.layers.1.mlp.experts.56.down_proj.weight', 'ernie.layers.1.mlp.experts.57.down_proj.weight', 'ernie.layers.1.mlp.experts.58.down_proj.weight', 'ernie.layers.1.mlp.experts.59.down_proj.weight', 'ernie.layers.1.mlp.experts.60.down_proj.weight', 'ernie.layers.1.mlp.experts.61.down_proj.weight', 'ernie.layers.1.mlp.experts.62.down_proj.weight', 'ernie.layers.1.mlp.experts.63.down_proj.weight', 'ernie.layers.1.mlp.experts.96.down_proj.weight', 'ernie.layers.1.mlp.experts.97.down_proj.weight', 'ernie.layers.1.mlp.experts.98.down_proj.weight', 'ernie.layers.1.mlp.experts.99.down_proj.weight', 'ernie.layers.1.mlp.experts.100.down_proj.weight', 'ernie.layers.1.mlp.experts.101.down_proj.weight', 'ernie.layers.1.mlp.experts.102.down_proj.weight', 'ernie.layers.1.mlp.experts.103.down_proj.weight', 'ernie.layers.1.mlp.experts.104.down_proj.weight', 'ernie.layers.1.mlp.experts.105.down_proj.weight', 'ernie.layers.1.mlp.experts.106.down_proj.weight', 'ernie.layers.1.mlp.experts.107.down_proj.weight', 'ernie.layers.1.mlp.experts.108.down_proj.weight', 'ernie.layers.1.mlp.experts.109.down_proj.weight', 'ernie.layers.1.mlp.experts.110.down_proj.weight', 'ernie.layers.1.mlp.experts.111.down_proj.weight', 'ernie.layers.1.mlp.experts.112.down_proj.weight', 'ernie.layers.1.mlp.experts.113.down_proj.weight', 'ernie.layers.1.mlp.experts.114.down_proj.weight', 'ernie.layers.1.mlp.experts.115.down_proj.weight', 'ernie.layers.1.mlp.experts.116.down_proj.weight', 'ernie.layers.1.mlp.experts.117.down_proj.weight', 'ernie.layers.1.mlp.experts.118.down_proj.weight', 'ernie.layers.1.mlp.experts.119.down_proj.weight', 'ernie.layers.1.mlp.experts.120.down_proj.weight', 'ernie.layers.1.mlp.experts.121.down_proj.weight', 'ernie.layers.1.mlp.experts.122.down_proj.weight', 'ernie.layers.1.mlp.experts.123.down_proj.weight', 'ernie.layers.1.mlp.experts.124.down_proj.weight', 'ernie.layers.1.mlp.experts.125.down_proj.weight', 'ernie.layers.1.mlp.experts.126.down_proj.weight', 'ernie.layers.1.mlp.experts.127.down_proj.weight'] ernie.layers.2.mlp.image_fused_moe.gate.weight:ernie.layers.2.mlp.gate.weight_1 -ernie.layers.2.mlp.image_fused_moe.experts.gate_correction_bias:ernie.layers.2.mlp.moe_statics.e_score_correction_bias ernie.layers.2.mlp.image_fused_moe.experts.up_gate_proj_weight:['ernie.layers.2.mlp.experts.32.up_gate_proj.weight', 'ernie.layers.2.mlp.experts.33.up_gate_proj.weight', 'ernie.layers.2.mlp.experts.34.up_gate_proj.weight', 'ernie.layers.2.mlp.experts.35.up_gate_proj.weight', 'ernie.layers.2.mlp.experts.36.up_gate_proj.weight', 'ernie.layers.2.mlp.experts.37.up_gate_proj.weight', 'ernie.layers.2.mlp.experts.38.up_gate_proj.weight', 'ernie.layers.2.mlp.experts.39.up_gate_proj.weight', 'ernie.layers.2.mlp.experts.40.up_gate_proj.weight', 'ernie.layers.2.mlp.experts.41.up_gate_proj.weight', 'ernie.layers.2.mlp.experts.42.up_gate_proj.weight', 'ernie.layers.2.mlp.experts.43.up_gate_proj.weight', 'ernie.layers.2.mlp.experts.44.up_gate_proj.weight', 'ernie.layers.2.mlp.experts.45.up_gate_proj.weight', 'ernie.layers.2.mlp.experts.46.up_gate_proj.weight', 'ernie.layers.2.mlp.experts.47.up_gate_proj.weight', 'ernie.layers.2.mlp.experts.48.up_gate_proj.weight', 'ernie.layers.2.mlp.experts.49.up_gate_proj.weight', 'ernie.layers.2.mlp.experts.50.up_gate_proj.weight', 'ernie.layers.2.mlp.experts.51.up_gate_proj.weight', 'ernie.layers.2.mlp.experts.52.up_gate_proj.weight', 'ernie.layers.2.mlp.experts.53.up_gate_proj.weight', 'ernie.layers.2.mlp.experts.54.up_gate_proj.weight', 'ernie.layers.2.mlp.experts.55.up_gate_proj.weight', 'ernie.layers.2.mlp.experts.56.up_gate_proj.weight', 'ernie.layers.2.mlp.experts.57.up_gate_proj.weight', 'ernie.layers.2.mlp.experts.58.up_gate_proj.weight', 'ernie.layers.2.mlp.experts.59.up_gate_proj.weight', 'ernie.layers.2.mlp.experts.60.up_gate_proj.weight', 'ernie.layers.2.mlp.experts.61.up_gate_proj.weight', 'ernie.layers.2.mlp.experts.62.up_gate_proj.weight', 'ernie.layers.2.mlp.experts.63.up_gate_proj.weight', 'ernie.layers.2.mlp.experts.96.up_gate_proj.weight', 'ernie.layers.2.mlp.experts.97.up_gate_proj.weight', 'ernie.layers.2.mlp.experts.98.up_gate_proj.weight', 'ernie.layers.2.mlp.experts.99.up_gate_proj.weight', 'ernie.layers.2.mlp.experts.100.up_gate_proj.weight', 'ernie.layers.2.mlp.experts.101.up_gate_proj.weight', 'ernie.layers.2.mlp.experts.102.up_gate_proj.weight', 'ernie.layers.2.mlp.experts.103.up_gate_proj.weight', 'ernie.layers.2.mlp.experts.104.up_gate_proj.weight', 'ernie.layers.2.mlp.experts.105.up_gate_proj.weight', 'ernie.layers.2.mlp.experts.106.up_gate_proj.weight', 'ernie.layers.2.mlp.experts.107.up_gate_proj.weight', 'ernie.layers.2.mlp.experts.108.up_gate_proj.weight', 'ernie.layers.2.mlp.experts.109.up_gate_proj.weight', 'ernie.layers.2.mlp.experts.110.up_gate_proj.weight', 'ernie.layers.2.mlp.experts.111.up_gate_proj.weight', 'ernie.layers.2.mlp.experts.112.up_gate_proj.weight', 'ernie.layers.2.mlp.experts.113.up_gate_proj.weight', 'ernie.layers.2.mlp.experts.114.up_gate_proj.weight', 'ernie.layers.2.mlp.experts.115.up_gate_proj.weight', 'ernie.layers.2.mlp.experts.116.up_gate_proj.weight', 'ernie.layers.2.mlp.experts.117.up_gate_proj.weight', 'ernie.layers.2.mlp.experts.118.up_gate_proj.weight', 'ernie.layers.2.mlp.experts.119.up_gate_proj.weight', 'ernie.layers.2.mlp.experts.120.up_gate_proj.weight', 'ernie.layers.2.mlp.experts.121.up_gate_proj.weight', 'ernie.layers.2.mlp.experts.122.up_gate_proj.weight', 'ernie.layers.2.mlp.experts.123.up_gate_proj.weight', 'ernie.layers.2.mlp.experts.124.up_gate_proj.weight', 'ernie.layers.2.mlp.experts.125.up_gate_proj.weight', 'ernie.layers.2.mlp.experts.126.up_gate_proj.weight', 'ernie.layers.2.mlp.experts.127.up_gate_proj.weight'] ernie.layers.2.mlp.image_fused_moe.experts.down_proj_weight:['ernie.layers.2.mlp.experts.32.down_proj.weight', 'ernie.layers.2.mlp.experts.33.down_proj.weight', 'ernie.layers.2.mlp.experts.34.down_proj.weight', 'ernie.layers.2.mlp.experts.35.down_proj.weight', 'ernie.layers.2.mlp.experts.36.down_proj.weight', 'ernie.layers.2.mlp.experts.37.down_proj.weight', 'ernie.layers.2.mlp.experts.38.down_proj.weight', 'ernie.layers.2.mlp.experts.39.down_proj.weight', 'ernie.layers.2.mlp.experts.40.down_proj.weight', 'ernie.layers.2.mlp.experts.41.down_proj.weight', 'ernie.layers.2.mlp.experts.42.down_proj.weight', 'ernie.layers.2.mlp.experts.43.down_proj.weight', 'ernie.layers.2.mlp.experts.44.down_proj.weight', 'ernie.layers.2.mlp.experts.45.down_proj.weight', 'ernie.layers.2.mlp.experts.46.down_proj.weight', 'ernie.layers.2.mlp.experts.47.down_proj.weight', 'ernie.layers.2.mlp.experts.48.down_proj.weight', 'ernie.layers.2.mlp.experts.49.down_proj.weight', 'ernie.layers.2.mlp.experts.50.down_proj.weight', 'ernie.layers.2.mlp.experts.51.down_proj.weight', 'ernie.layers.2.mlp.experts.52.down_proj.weight', 'ernie.layers.2.mlp.experts.53.down_proj.weight', 'ernie.layers.2.mlp.experts.54.down_proj.weight', 'ernie.layers.2.mlp.experts.55.down_proj.weight', 'ernie.layers.2.mlp.experts.56.down_proj.weight', 'ernie.layers.2.mlp.experts.57.down_proj.weight', 'ernie.layers.2.mlp.experts.58.down_proj.weight', 'ernie.layers.2.mlp.experts.59.down_proj.weight', 'ernie.layers.2.mlp.experts.60.down_proj.weight', 'ernie.layers.2.mlp.experts.61.down_proj.weight', 'ernie.layers.2.mlp.experts.62.down_proj.weight', 'ernie.layers.2.mlp.experts.63.down_proj.weight', 'ernie.layers.2.mlp.experts.96.down_proj.weight', 'ernie.layers.2.mlp.experts.97.down_proj.weight', 'ernie.layers.2.mlp.experts.98.down_proj.weight', 'ernie.layers.2.mlp.experts.99.down_proj.weight', 'ernie.layers.2.mlp.experts.100.down_proj.weight', 'ernie.layers.2.mlp.experts.101.down_proj.weight', 'ernie.layers.2.mlp.experts.102.down_proj.weight', 'ernie.layers.2.mlp.experts.103.down_proj.weight', 'ernie.layers.2.mlp.experts.104.down_proj.weight', 'ernie.layers.2.mlp.experts.105.down_proj.weight', 'ernie.layers.2.mlp.experts.106.down_proj.weight', 'ernie.layers.2.mlp.experts.107.down_proj.weight', 'ernie.layers.2.mlp.experts.108.down_proj.weight', 'ernie.layers.2.mlp.experts.109.down_proj.weight', 'ernie.layers.2.mlp.experts.110.down_proj.weight', 'ernie.layers.2.mlp.experts.111.down_proj.weight', 'ernie.layers.2.mlp.experts.112.down_proj.weight', 'ernie.layers.2.mlp.experts.113.down_proj.weight', 'ernie.layers.2.mlp.experts.114.down_proj.weight', 'ernie.layers.2.mlp.experts.115.down_proj.weight', 'ernie.layers.2.mlp.experts.116.down_proj.weight', 'ernie.layers.2.mlp.experts.117.down_proj.weight', 'ernie.layers.2.mlp.experts.118.down_proj.weight', 'ernie.layers.2.mlp.experts.119.down_proj.weight', 'ernie.layers.2.mlp.experts.120.down_proj.weight', 'ernie.layers.2.mlp.experts.121.down_proj.weight', 'ernie.layers.2.mlp.experts.122.down_proj.weight', 'ernie.layers.2.mlp.experts.123.down_proj.weight', 'ernie.layers.2.mlp.experts.124.down_proj.weight', 'ernie.layers.2.mlp.experts.125.down_proj.weight', 'ernie.layers.2.mlp.experts.126.down_proj.weight', 'ernie.layers.2.mlp.experts.127.down_proj.weight'] ernie.layers.3.mlp.image_fused_moe.gate.weight:ernie.layers.3.mlp.gate.weight_1 -ernie.layers.3.mlp.image_fused_moe.experts.gate_correction_bias:ernie.layers.3.mlp.moe_statics.e_score_correction_bias ernie.layers.3.mlp.image_fused_moe.experts.up_gate_proj_weight:['ernie.layers.3.mlp.experts.32.up_gate_proj.weight', 'ernie.layers.3.mlp.experts.33.up_gate_proj.weight', 'ernie.layers.3.mlp.experts.34.up_gate_proj.weight', 'ernie.layers.3.mlp.experts.35.up_gate_proj.weight', 'ernie.layers.3.mlp.experts.36.up_gate_proj.weight', 'ernie.layers.3.mlp.experts.37.up_gate_proj.weight', 'ernie.layers.3.mlp.experts.38.up_gate_proj.weight', 'ernie.layers.3.mlp.experts.39.up_gate_proj.weight', 'ernie.layers.3.mlp.experts.40.up_gate_proj.weight', 'ernie.layers.3.mlp.experts.41.up_gate_proj.weight', 'ernie.layers.3.mlp.experts.42.up_gate_proj.weight', 'ernie.layers.3.mlp.experts.43.up_gate_proj.weight', 'ernie.layers.3.mlp.experts.44.up_gate_proj.weight', 'ernie.layers.3.mlp.experts.45.up_gate_proj.weight', 'ernie.layers.3.mlp.experts.46.up_gate_proj.weight', 'ernie.layers.3.mlp.experts.47.up_gate_proj.weight', 'ernie.layers.3.mlp.experts.48.up_gate_proj.weight', 'ernie.layers.3.mlp.experts.49.up_gate_proj.weight', 'ernie.layers.3.mlp.experts.50.up_gate_proj.weight', 'ernie.layers.3.mlp.experts.51.up_gate_proj.weight', 'ernie.layers.3.mlp.experts.52.up_gate_proj.weight', 'ernie.layers.3.mlp.experts.53.up_gate_proj.weight', 'ernie.layers.3.mlp.experts.54.up_gate_proj.weight', 'ernie.layers.3.mlp.experts.55.up_gate_proj.weight', 'ernie.layers.3.mlp.experts.56.up_gate_proj.weight', 'ernie.layers.3.mlp.experts.57.up_gate_proj.weight', 'ernie.layers.3.mlp.experts.58.up_gate_proj.weight', 'ernie.layers.3.mlp.experts.59.up_gate_proj.weight', 'ernie.layers.3.mlp.experts.60.up_gate_proj.weight', 'ernie.layers.3.mlp.experts.61.up_gate_proj.weight', 'ernie.layers.3.mlp.experts.62.up_gate_proj.weight', 'ernie.layers.3.mlp.experts.63.up_gate_proj.weight', 'ernie.layers.3.mlp.experts.96.up_gate_proj.weight', 'ernie.layers.3.mlp.experts.97.up_gate_proj.weight', 'ernie.layers.3.mlp.experts.98.up_gate_proj.weight', 'ernie.layers.3.mlp.experts.99.up_gate_proj.weight', 'ernie.layers.3.mlp.experts.100.up_gate_proj.weight', 'ernie.layers.3.mlp.experts.101.up_gate_proj.weight', 'ernie.layers.3.mlp.experts.102.up_gate_proj.weight', 'ernie.layers.3.mlp.experts.103.up_gate_proj.weight', 'ernie.layers.3.mlp.experts.104.up_gate_proj.weight', 'ernie.layers.3.mlp.experts.105.up_gate_proj.weight', 'ernie.layers.3.mlp.experts.106.up_gate_proj.weight', 'ernie.layers.3.mlp.experts.107.up_gate_proj.weight', 'ernie.layers.3.mlp.experts.108.up_gate_proj.weight', 'ernie.layers.3.mlp.experts.109.up_gate_proj.weight', 'ernie.layers.3.mlp.experts.110.up_gate_proj.weight', 'ernie.layers.3.mlp.experts.111.up_gate_proj.weight', 'ernie.layers.3.mlp.experts.112.up_gate_proj.weight', 'ernie.layers.3.mlp.experts.113.up_gate_proj.weight', 'ernie.layers.3.mlp.experts.114.up_gate_proj.weight', 'ernie.layers.3.mlp.experts.115.up_gate_proj.weight', 'ernie.layers.3.mlp.experts.116.up_gate_proj.weight', 'ernie.layers.3.mlp.experts.117.up_gate_proj.weight', 'ernie.layers.3.mlp.experts.118.up_gate_proj.weight', 'ernie.layers.3.mlp.experts.119.up_gate_proj.weight', 'ernie.layers.3.mlp.experts.120.up_gate_proj.weight', 'ernie.layers.3.mlp.experts.121.up_gate_proj.weight', 'ernie.layers.3.mlp.experts.122.up_gate_proj.weight', 'ernie.layers.3.mlp.experts.123.up_gate_proj.weight', 'ernie.layers.3.mlp.experts.124.up_gate_proj.weight', 'ernie.layers.3.mlp.experts.125.up_gate_proj.weight', 'ernie.layers.3.mlp.experts.126.up_gate_proj.weight', 'ernie.layers.3.mlp.experts.127.up_gate_proj.weight'] ernie.layers.3.mlp.image_fused_moe.experts.down_proj_weight:['ernie.layers.3.mlp.experts.32.down_proj.weight', 'ernie.layers.3.mlp.experts.33.down_proj.weight', 'ernie.layers.3.mlp.experts.34.down_proj.weight', 'ernie.layers.3.mlp.experts.35.down_proj.weight', 'ernie.layers.3.mlp.experts.36.down_proj.weight', 'ernie.layers.3.mlp.experts.37.down_proj.weight', 'ernie.layers.3.mlp.experts.38.down_proj.weight', 'ernie.layers.3.mlp.experts.39.down_proj.weight', 'ernie.layers.3.mlp.experts.40.down_proj.weight', 'ernie.layers.3.mlp.experts.41.down_proj.weight', 'ernie.layers.3.mlp.experts.42.down_proj.weight', 'ernie.layers.3.mlp.experts.43.down_proj.weight', 'ernie.layers.3.mlp.experts.44.down_proj.weight', 'ernie.layers.3.mlp.experts.45.down_proj.weight', 'ernie.layers.3.mlp.experts.46.down_proj.weight', 'ernie.layers.3.mlp.experts.47.down_proj.weight', 'ernie.layers.3.mlp.experts.48.down_proj.weight', 'ernie.layers.3.mlp.experts.49.down_proj.weight', 'ernie.layers.3.mlp.experts.50.down_proj.weight', 'ernie.layers.3.mlp.experts.51.down_proj.weight', 'ernie.layers.3.mlp.experts.52.down_proj.weight', 'ernie.layers.3.mlp.experts.53.down_proj.weight', 'ernie.layers.3.mlp.experts.54.down_proj.weight', 'ernie.layers.3.mlp.experts.55.down_proj.weight', 'ernie.layers.3.mlp.experts.56.down_proj.weight', 'ernie.layers.3.mlp.experts.57.down_proj.weight', 'ernie.layers.3.mlp.experts.58.down_proj.weight', 'ernie.layers.3.mlp.experts.59.down_proj.weight', 'ernie.layers.3.mlp.experts.60.down_proj.weight', 'ernie.layers.3.mlp.experts.61.down_proj.weight', 'ernie.layers.3.mlp.experts.62.down_proj.weight', 'ernie.layers.3.mlp.experts.63.down_proj.weight', 'ernie.layers.3.mlp.experts.96.down_proj.weight', 'ernie.layers.3.mlp.experts.97.down_proj.weight', 'ernie.layers.3.mlp.experts.98.down_proj.weight', 'ernie.layers.3.mlp.experts.99.down_proj.weight', 'ernie.layers.3.mlp.experts.100.down_proj.weight', 'ernie.layers.3.mlp.experts.101.down_proj.weight', 'ernie.layers.3.mlp.experts.102.down_proj.weight', 'ernie.layers.3.mlp.experts.103.down_proj.weight', 'ernie.layers.3.mlp.experts.104.down_proj.weight', 'ernie.layers.3.mlp.experts.105.down_proj.weight', 'ernie.layers.3.mlp.experts.106.down_proj.weight', 'ernie.layers.3.mlp.experts.107.down_proj.weight', 'ernie.layers.3.mlp.experts.108.down_proj.weight', 'ernie.layers.3.mlp.experts.109.down_proj.weight', 'ernie.layers.3.mlp.experts.110.down_proj.weight', 'ernie.layers.3.mlp.experts.111.down_proj.weight', 'ernie.layers.3.mlp.experts.112.down_proj.weight', 'ernie.layers.3.mlp.experts.113.down_proj.weight', 'ernie.layers.3.mlp.experts.114.down_proj.weight', 'ernie.layers.3.mlp.experts.115.down_proj.weight', 'ernie.layers.3.mlp.experts.116.down_proj.weight', 'ernie.layers.3.mlp.experts.117.down_proj.weight', 'ernie.layers.3.mlp.experts.118.down_proj.weight', 'ernie.layers.3.mlp.experts.119.down_proj.weight', 'ernie.layers.3.mlp.experts.120.down_proj.weight', 'ernie.layers.3.mlp.experts.121.down_proj.weight', 'ernie.layers.3.mlp.experts.122.down_proj.weight', 'ernie.layers.3.mlp.experts.123.down_proj.weight', 'ernie.layers.3.mlp.experts.124.down_proj.weight', 'ernie.layers.3.mlp.experts.125.down_proj.weight', 'ernie.layers.3.mlp.experts.126.down_proj.weight', 'ernie.layers.3.mlp.experts.127.down_proj.weight'] ernie.layers.4.mlp.image_fused_moe.gate.weight:ernie.layers.4.mlp.gate.weight_1 -ernie.layers.4.mlp.image_fused_moe.experts.gate_correction_bias:ernie.layers.4.mlp.moe_statics.e_score_correction_bias ernie.layers.4.mlp.image_fused_moe.experts.up_gate_proj_weight:['ernie.layers.4.mlp.experts.32.up_gate_proj.weight', 'ernie.layers.4.mlp.experts.33.up_gate_proj.weight', 'ernie.layers.4.mlp.experts.34.up_gate_proj.weight', 'ernie.layers.4.mlp.experts.35.up_gate_proj.weight', 'ernie.layers.4.mlp.experts.36.up_gate_proj.weight', 'ernie.layers.4.mlp.experts.37.up_gate_proj.weight', 'ernie.layers.4.mlp.experts.38.up_gate_proj.weight', 'ernie.layers.4.mlp.experts.39.up_gate_proj.weight', 'ernie.layers.4.mlp.experts.40.up_gate_proj.weight', 'ernie.layers.4.mlp.experts.41.up_gate_proj.weight', 'ernie.layers.4.mlp.experts.42.up_gate_proj.weight', 'ernie.layers.4.mlp.experts.43.up_gate_proj.weight', 'ernie.layers.4.mlp.experts.44.up_gate_proj.weight', 'ernie.layers.4.mlp.experts.45.up_gate_proj.weight', 'ernie.layers.4.mlp.experts.46.up_gate_proj.weight', 'ernie.layers.4.mlp.experts.47.up_gate_proj.weight', 'ernie.layers.4.mlp.experts.48.up_gate_proj.weight', 'ernie.layers.4.mlp.experts.49.up_gate_proj.weight', 'ernie.layers.4.mlp.experts.50.up_gate_proj.weight', 'ernie.layers.4.mlp.experts.51.up_gate_proj.weight', 'ernie.layers.4.mlp.experts.52.up_gate_proj.weight', 'ernie.layers.4.mlp.experts.53.up_gate_proj.weight', 'ernie.layers.4.mlp.experts.54.up_gate_proj.weight', 'ernie.layers.4.mlp.experts.55.up_gate_proj.weight', 'ernie.layers.4.mlp.experts.56.up_gate_proj.weight', 'ernie.layers.4.mlp.experts.57.up_gate_proj.weight', 'ernie.layers.4.mlp.experts.58.up_gate_proj.weight', 'ernie.layers.4.mlp.experts.59.up_gate_proj.weight', 'ernie.layers.4.mlp.experts.60.up_gate_proj.weight', 'ernie.layers.4.mlp.experts.61.up_gate_proj.weight', 'ernie.layers.4.mlp.experts.62.up_gate_proj.weight', 'ernie.layers.4.mlp.experts.63.up_gate_proj.weight', 'ernie.layers.4.mlp.experts.96.up_gate_proj.weight', 'ernie.layers.4.mlp.experts.97.up_gate_proj.weight', 'ernie.layers.4.mlp.experts.98.up_gate_proj.weight', 'ernie.layers.4.mlp.experts.99.up_gate_proj.weight', 'ernie.layers.4.mlp.experts.100.up_gate_proj.weight', 'ernie.layers.4.mlp.experts.101.up_gate_proj.weight', 'ernie.layers.4.mlp.experts.102.up_gate_proj.weight', 'ernie.layers.4.mlp.experts.103.up_gate_proj.weight', 'ernie.layers.4.mlp.experts.104.up_gate_proj.weight', 'ernie.layers.4.mlp.experts.105.up_gate_proj.weight', 'ernie.layers.4.mlp.experts.106.up_gate_proj.weight', 'ernie.layers.4.mlp.experts.107.up_gate_proj.weight', 'ernie.layers.4.mlp.experts.108.up_gate_proj.weight', 'ernie.layers.4.mlp.experts.109.up_gate_proj.weight', 'ernie.layers.4.mlp.experts.110.up_gate_proj.weight', 'ernie.layers.4.mlp.experts.111.up_gate_proj.weight', 'ernie.layers.4.mlp.experts.112.up_gate_proj.weight', 'ernie.layers.4.mlp.experts.113.up_gate_proj.weight', 'ernie.layers.4.mlp.experts.114.up_gate_proj.weight', 'ernie.layers.4.mlp.experts.115.up_gate_proj.weight', 'ernie.layers.4.mlp.experts.116.up_gate_proj.weight', 'ernie.layers.4.mlp.experts.117.up_gate_proj.weight', 'ernie.layers.4.mlp.experts.118.up_gate_proj.weight', 'ernie.layers.4.mlp.experts.119.up_gate_proj.weight', 'ernie.layers.4.mlp.experts.120.up_gate_proj.weight', 'ernie.layers.4.mlp.experts.121.up_gate_proj.weight', 'ernie.layers.4.mlp.experts.122.up_gate_proj.weight', 'ernie.layers.4.mlp.experts.123.up_gate_proj.weight', 'ernie.layers.4.mlp.experts.124.up_gate_proj.weight', 'ernie.layers.4.mlp.experts.125.up_gate_proj.weight', 'ernie.layers.4.mlp.experts.126.up_gate_proj.weight', 'ernie.layers.4.mlp.experts.127.up_gate_proj.weight'] ernie.layers.4.mlp.image_fused_moe.experts.down_proj_weight:['ernie.layers.4.mlp.experts.32.down_proj.weight', 'ernie.layers.4.mlp.experts.33.down_proj.weight', 'ernie.layers.4.mlp.experts.34.down_proj.weight', 'ernie.layers.4.mlp.experts.35.down_proj.weight', 'ernie.layers.4.mlp.experts.36.down_proj.weight', 'ernie.layers.4.mlp.experts.37.down_proj.weight', 'ernie.layers.4.mlp.experts.38.down_proj.weight', 'ernie.layers.4.mlp.experts.39.down_proj.weight', 'ernie.layers.4.mlp.experts.40.down_proj.weight', 'ernie.layers.4.mlp.experts.41.down_proj.weight', 'ernie.layers.4.mlp.experts.42.down_proj.weight', 'ernie.layers.4.mlp.experts.43.down_proj.weight', 'ernie.layers.4.mlp.experts.44.down_proj.weight', 'ernie.layers.4.mlp.experts.45.down_proj.weight', 'ernie.layers.4.mlp.experts.46.down_proj.weight', 'ernie.layers.4.mlp.experts.47.down_proj.weight', 'ernie.layers.4.mlp.experts.48.down_proj.weight', 'ernie.layers.4.mlp.experts.49.down_proj.weight', 'ernie.layers.4.mlp.experts.50.down_proj.weight', 'ernie.layers.4.mlp.experts.51.down_proj.weight', 'ernie.layers.4.mlp.experts.52.down_proj.weight', 'ernie.layers.4.mlp.experts.53.down_proj.weight', 'ernie.layers.4.mlp.experts.54.down_proj.weight', 'ernie.layers.4.mlp.experts.55.down_proj.weight', 'ernie.layers.4.mlp.experts.56.down_proj.weight', 'ernie.layers.4.mlp.experts.57.down_proj.weight', 'ernie.layers.4.mlp.experts.58.down_proj.weight', 'ernie.layers.4.mlp.experts.59.down_proj.weight', 'ernie.layers.4.mlp.experts.60.down_proj.weight', 'ernie.layers.4.mlp.experts.61.down_proj.weight', 'ernie.layers.4.mlp.experts.62.down_proj.weight', 'ernie.layers.4.mlp.experts.63.down_proj.weight', 'ernie.layers.4.mlp.experts.96.down_proj.weight', 'ernie.layers.4.mlp.experts.97.down_proj.weight', 'ernie.layers.4.mlp.experts.98.down_proj.weight', 'ernie.layers.4.mlp.experts.99.down_proj.weight', 'ernie.layers.4.mlp.experts.100.down_proj.weight', 'ernie.layers.4.mlp.experts.101.down_proj.weight', 'ernie.layers.4.mlp.experts.102.down_proj.weight', 'ernie.layers.4.mlp.experts.103.down_proj.weight', 'ernie.layers.4.mlp.experts.104.down_proj.weight', 'ernie.layers.4.mlp.experts.105.down_proj.weight', 'ernie.layers.4.mlp.experts.106.down_proj.weight', 'ernie.layers.4.mlp.experts.107.down_proj.weight', 'ernie.layers.4.mlp.experts.108.down_proj.weight', 'ernie.layers.4.mlp.experts.109.down_proj.weight', 'ernie.layers.4.mlp.experts.110.down_proj.weight', 'ernie.layers.4.mlp.experts.111.down_proj.weight', 'ernie.layers.4.mlp.experts.112.down_proj.weight', 'ernie.layers.4.mlp.experts.113.down_proj.weight', 'ernie.layers.4.mlp.experts.114.down_proj.weight', 'ernie.layers.4.mlp.experts.115.down_proj.weight', 'ernie.layers.4.mlp.experts.116.down_proj.weight', 'ernie.layers.4.mlp.experts.117.down_proj.weight', 'ernie.layers.4.mlp.experts.118.down_proj.weight', 'ernie.layers.4.mlp.experts.119.down_proj.weight', 'ernie.layers.4.mlp.experts.120.down_proj.weight', 'ernie.layers.4.mlp.experts.121.down_proj.weight', 'ernie.layers.4.mlp.experts.122.down_proj.weight', 'ernie.layers.4.mlp.experts.123.down_proj.weight', 'ernie.layers.4.mlp.experts.124.down_proj.weight', 'ernie.layers.4.mlp.experts.125.down_proj.weight', 'ernie.layers.4.mlp.experts.126.down_proj.weight', 'ernie.layers.4.mlp.experts.127.down_proj.weight'] ernie.layers.5.mlp.image_fused_moe.gate.weight:ernie.layers.5.mlp.gate.weight_1 -ernie.layers.5.mlp.image_fused_moe.experts.gate_correction_bias:ernie.layers.5.mlp.moe_statics.e_score_correction_bias ernie.layers.5.mlp.image_fused_moe.experts.up_gate_proj_weight:['ernie.layers.5.mlp.experts.32.up_gate_proj.weight', 'ernie.layers.5.mlp.experts.33.up_gate_proj.weight', 'ernie.layers.5.mlp.experts.34.up_gate_proj.weight', 'ernie.layers.5.mlp.experts.35.up_gate_proj.weight', 'ernie.layers.5.mlp.experts.36.up_gate_proj.weight', 'ernie.layers.5.mlp.experts.37.up_gate_proj.weight', 'ernie.layers.5.mlp.experts.38.up_gate_proj.weight', 'ernie.layers.5.mlp.experts.39.up_gate_proj.weight', 'ernie.layers.5.mlp.experts.40.up_gate_proj.weight', 'ernie.layers.5.mlp.experts.41.up_gate_proj.weight', 'ernie.layers.5.mlp.experts.42.up_gate_proj.weight', 'ernie.layers.5.mlp.experts.43.up_gate_proj.weight', 'ernie.layers.5.mlp.experts.44.up_gate_proj.weight', 'ernie.layers.5.mlp.experts.45.up_gate_proj.weight', 'ernie.layers.5.mlp.experts.46.up_gate_proj.weight', 'ernie.layers.5.mlp.experts.47.up_gate_proj.weight', 'ernie.layers.5.mlp.experts.48.up_gate_proj.weight', 'ernie.layers.5.mlp.experts.49.up_gate_proj.weight', 'ernie.layers.5.mlp.experts.50.up_gate_proj.weight', 'ernie.layers.5.mlp.experts.51.up_gate_proj.weight', 'ernie.layers.5.mlp.experts.52.up_gate_proj.weight', 'ernie.layers.5.mlp.experts.53.up_gate_proj.weight', 'ernie.layers.5.mlp.experts.54.up_gate_proj.weight', 'ernie.layers.5.mlp.experts.55.up_gate_proj.weight', 'ernie.layers.5.mlp.experts.56.up_gate_proj.weight', 'ernie.layers.5.mlp.experts.57.up_gate_proj.weight', 'ernie.layers.5.mlp.experts.58.up_gate_proj.weight', 'ernie.layers.5.mlp.experts.59.up_gate_proj.weight', 'ernie.layers.5.mlp.experts.60.up_gate_proj.weight', 'ernie.layers.5.mlp.experts.61.up_gate_proj.weight', 'ernie.layers.5.mlp.experts.62.up_gate_proj.weight', 'ernie.layers.5.mlp.experts.63.up_gate_proj.weight', 'ernie.layers.5.mlp.experts.96.up_gate_proj.weight', 'ernie.layers.5.mlp.experts.97.up_gate_proj.weight', 'ernie.layers.5.mlp.experts.98.up_gate_proj.weight', 'ernie.layers.5.mlp.experts.99.up_gate_proj.weight', 'ernie.layers.5.mlp.experts.100.up_gate_proj.weight', 'ernie.layers.5.mlp.experts.101.up_gate_proj.weight', 'ernie.layers.5.mlp.experts.102.up_gate_proj.weight', 'ernie.layers.5.mlp.experts.103.up_gate_proj.weight', 'ernie.layers.5.mlp.experts.104.up_gate_proj.weight', 'ernie.layers.5.mlp.experts.105.up_gate_proj.weight', 'ernie.layers.5.mlp.experts.106.up_gate_proj.weight', 'ernie.layers.5.mlp.experts.107.up_gate_proj.weight', 'ernie.layers.5.mlp.experts.108.up_gate_proj.weight', 'ernie.layers.5.mlp.experts.109.up_gate_proj.weight', 'ernie.layers.5.mlp.experts.110.up_gate_proj.weight', 'ernie.layers.5.mlp.experts.111.up_gate_proj.weight', 'ernie.layers.5.mlp.experts.112.up_gate_proj.weight', 'ernie.layers.5.mlp.experts.113.up_gate_proj.weight', 'ernie.layers.5.mlp.experts.114.up_gate_proj.weight', 'ernie.layers.5.mlp.experts.115.up_gate_proj.weight', 'ernie.layers.5.mlp.experts.116.up_gate_proj.weight', 'ernie.layers.5.mlp.experts.117.up_gate_proj.weight', 'ernie.layers.5.mlp.experts.118.up_gate_proj.weight', 'ernie.layers.5.mlp.experts.119.up_gate_proj.weight', 'ernie.layers.5.mlp.experts.120.up_gate_proj.weight', 'ernie.layers.5.mlp.experts.121.up_gate_proj.weight', 'ernie.layers.5.mlp.experts.122.up_gate_proj.weight', 'ernie.layers.5.mlp.experts.123.up_gate_proj.weight', 'ernie.layers.5.mlp.experts.124.up_gate_proj.weight', 'ernie.layers.5.mlp.experts.125.up_gate_proj.weight', 'ernie.layers.5.mlp.experts.126.up_gate_proj.weight', 'ernie.layers.5.mlp.experts.127.up_gate_proj.weight'] ernie.layers.5.mlp.image_fused_moe.experts.down_proj_weight:['ernie.layers.5.mlp.experts.32.down_proj.weight', 'ernie.layers.5.mlp.experts.33.down_proj.weight', 'ernie.layers.5.mlp.experts.34.down_proj.weight', 'ernie.layers.5.mlp.experts.35.down_proj.weight', 'ernie.layers.5.mlp.experts.36.down_proj.weight', 'ernie.layers.5.mlp.experts.37.down_proj.weight', 'ernie.layers.5.mlp.experts.38.down_proj.weight', 'ernie.layers.5.mlp.experts.39.down_proj.weight', 'ernie.layers.5.mlp.experts.40.down_proj.weight', 'ernie.layers.5.mlp.experts.41.down_proj.weight', 'ernie.layers.5.mlp.experts.42.down_proj.weight', 'ernie.layers.5.mlp.experts.43.down_proj.weight', 'ernie.layers.5.mlp.experts.44.down_proj.weight', 'ernie.layers.5.mlp.experts.45.down_proj.weight', 'ernie.layers.5.mlp.experts.46.down_proj.weight', 'ernie.layers.5.mlp.experts.47.down_proj.weight', 'ernie.layers.5.mlp.experts.48.down_proj.weight', 'ernie.layers.5.mlp.experts.49.down_proj.weight', 'ernie.layers.5.mlp.experts.50.down_proj.weight', 'ernie.layers.5.mlp.experts.51.down_proj.weight', 'ernie.layers.5.mlp.experts.52.down_proj.weight', 'ernie.layers.5.mlp.experts.53.down_proj.weight', 'ernie.layers.5.mlp.experts.54.down_proj.weight', 'ernie.layers.5.mlp.experts.55.down_proj.weight', 'ernie.layers.5.mlp.experts.56.down_proj.weight', 'ernie.layers.5.mlp.experts.57.down_proj.weight', 'ernie.layers.5.mlp.experts.58.down_proj.weight', 'ernie.layers.5.mlp.experts.59.down_proj.weight', 'ernie.layers.5.mlp.experts.60.down_proj.weight', 'ernie.layers.5.mlp.experts.61.down_proj.weight', 'ernie.layers.5.mlp.experts.62.down_proj.weight', 'ernie.layers.5.mlp.experts.63.down_proj.weight', 'ernie.layers.5.mlp.experts.96.down_proj.weight', 'ernie.layers.5.mlp.experts.97.down_proj.weight', 'ernie.layers.5.mlp.experts.98.down_proj.weight', 'ernie.layers.5.mlp.experts.99.down_proj.weight', 'ernie.layers.5.mlp.experts.100.down_proj.weight', 'ernie.layers.5.mlp.experts.101.down_proj.weight', 'ernie.layers.5.mlp.experts.102.down_proj.weight', 'ernie.layers.5.mlp.experts.103.down_proj.weight', 'ernie.layers.5.mlp.experts.104.down_proj.weight', 'ernie.layers.5.mlp.experts.105.down_proj.weight', 'ernie.layers.5.mlp.experts.106.down_proj.weight', 'ernie.layers.5.mlp.experts.107.down_proj.weight', 'ernie.layers.5.mlp.experts.108.down_proj.weight', 'ernie.layers.5.mlp.experts.109.down_proj.weight', 'ernie.layers.5.mlp.experts.110.down_proj.weight', 'ernie.layers.5.mlp.experts.111.down_proj.weight', 'ernie.layers.5.mlp.experts.112.down_proj.weight', 'ernie.layers.5.mlp.experts.113.down_proj.weight', 'ernie.layers.5.mlp.experts.114.down_proj.weight', 'ernie.layers.5.mlp.experts.115.down_proj.weight', 'ernie.layers.5.mlp.experts.116.down_proj.weight', 'ernie.layers.5.mlp.experts.117.down_proj.weight', 'ernie.layers.5.mlp.experts.118.down_proj.weight', 'ernie.layers.5.mlp.experts.119.down_proj.weight', 'ernie.layers.5.mlp.experts.120.down_proj.weight', 'ernie.layers.5.mlp.experts.121.down_proj.weight', 'ernie.layers.5.mlp.experts.122.down_proj.weight', 'ernie.layers.5.mlp.experts.123.down_proj.weight', 'ernie.layers.5.mlp.experts.124.down_proj.weight', 'ernie.layers.5.mlp.experts.125.down_proj.weight', 'ernie.layers.5.mlp.experts.126.down_proj.weight', 'ernie.layers.5.mlp.experts.127.down_proj.weight'] ernie.layers.6.mlp.image_fused_moe.gate.weight:ernie.layers.6.mlp.gate.weight_1 -ernie.layers.6.mlp.image_fused_moe.experts.gate_correction_bias:ernie.layers.6.mlp.moe_statics.e_score_correction_bias ernie.layers.6.mlp.image_fused_moe.experts.up_gate_proj_weight:['ernie.layers.6.mlp.experts.32.up_gate_proj.weight', 'ernie.layers.6.mlp.experts.33.up_gate_proj.weight', 'ernie.layers.6.mlp.experts.34.up_gate_proj.weight', 'ernie.layers.6.mlp.experts.35.up_gate_proj.weight', 'ernie.layers.6.mlp.experts.36.up_gate_proj.weight', 'ernie.layers.6.mlp.experts.37.up_gate_proj.weight', 'ernie.layers.6.mlp.experts.38.up_gate_proj.weight', 'ernie.layers.6.mlp.experts.39.up_gate_proj.weight', 'ernie.layers.6.mlp.experts.40.up_gate_proj.weight', 'ernie.layers.6.mlp.experts.41.up_gate_proj.weight', 'ernie.layers.6.mlp.experts.42.up_gate_proj.weight', 'ernie.layers.6.mlp.experts.43.up_gate_proj.weight', 'ernie.layers.6.mlp.experts.44.up_gate_proj.weight', 'ernie.layers.6.mlp.experts.45.up_gate_proj.weight', 'ernie.layers.6.mlp.experts.46.up_gate_proj.weight', 'ernie.layers.6.mlp.experts.47.up_gate_proj.weight', 'ernie.layers.6.mlp.experts.48.up_gate_proj.weight', 'ernie.layers.6.mlp.experts.49.up_gate_proj.weight', 'ernie.layers.6.mlp.experts.50.up_gate_proj.weight', 'ernie.layers.6.mlp.experts.51.up_gate_proj.weight', 'ernie.layers.6.mlp.experts.52.up_gate_proj.weight', 'ernie.layers.6.mlp.experts.53.up_gate_proj.weight', 'ernie.layers.6.mlp.experts.54.up_gate_proj.weight', 'ernie.layers.6.mlp.experts.55.up_gate_proj.weight', 'ernie.layers.6.mlp.experts.56.up_gate_proj.weight', 'ernie.layers.6.mlp.experts.57.up_gate_proj.weight', 'ernie.layers.6.mlp.experts.58.up_gate_proj.weight', 'ernie.layers.6.mlp.experts.59.up_gate_proj.weight', 'ernie.layers.6.mlp.experts.60.up_gate_proj.weight', 'ernie.layers.6.mlp.experts.61.up_gate_proj.weight', 'ernie.layers.6.mlp.experts.62.up_gate_proj.weight', 'ernie.layers.6.mlp.experts.63.up_gate_proj.weight', 'ernie.layers.6.mlp.experts.96.up_gate_proj.weight', 'ernie.layers.6.mlp.experts.97.up_gate_proj.weight', 'ernie.layers.6.mlp.experts.98.up_gate_proj.weight', 'ernie.layers.6.mlp.experts.99.up_gate_proj.weight', 'ernie.layers.6.mlp.experts.100.up_gate_proj.weight', 'ernie.layers.6.mlp.experts.101.up_gate_proj.weight', 'ernie.layers.6.mlp.experts.102.up_gate_proj.weight', 'ernie.layers.6.mlp.experts.103.up_gate_proj.weight', 'ernie.layers.6.mlp.experts.104.up_gate_proj.weight', 'ernie.layers.6.mlp.experts.105.up_gate_proj.weight', 'ernie.layers.6.mlp.experts.106.up_gate_proj.weight', 'ernie.layers.6.mlp.experts.107.up_gate_proj.weight', 'ernie.layers.6.mlp.experts.108.up_gate_proj.weight', 'ernie.layers.6.mlp.experts.109.up_gate_proj.weight', 'ernie.layers.6.mlp.experts.110.up_gate_proj.weight', 'ernie.layers.6.mlp.experts.111.up_gate_proj.weight', 'ernie.layers.6.mlp.experts.112.up_gate_proj.weight', 'ernie.layers.6.mlp.experts.113.up_gate_proj.weight', 'ernie.layers.6.mlp.experts.114.up_gate_proj.weight', 'ernie.layers.6.mlp.experts.115.up_gate_proj.weight', 'ernie.layers.6.mlp.experts.116.up_gate_proj.weight', 'ernie.layers.6.mlp.experts.117.up_gate_proj.weight', 'ernie.layers.6.mlp.experts.118.up_gate_proj.weight', 'ernie.layers.6.mlp.experts.119.up_gate_proj.weight', 'ernie.layers.6.mlp.experts.120.up_gate_proj.weight', 'ernie.layers.6.mlp.experts.121.up_gate_proj.weight', 'ernie.layers.6.mlp.experts.122.up_gate_proj.weight', 'ernie.layers.6.mlp.experts.123.up_gate_proj.weight', 'ernie.layers.6.mlp.experts.124.up_gate_proj.weight', 'ernie.layers.6.mlp.experts.125.up_gate_proj.weight', 'ernie.layers.6.mlp.experts.126.up_gate_proj.weight', 'ernie.layers.6.mlp.experts.127.up_gate_proj.weight'] ernie.layers.6.mlp.image_fused_moe.experts.down_proj_weight:['ernie.layers.6.mlp.experts.32.down_proj.weight', 'ernie.layers.6.mlp.experts.33.down_proj.weight', 'ernie.layers.6.mlp.experts.34.down_proj.weight', 'ernie.layers.6.mlp.experts.35.down_proj.weight', 'ernie.layers.6.mlp.experts.36.down_proj.weight', 'ernie.layers.6.mlp.experts.37.down_proj.weight', 'ernie.layers.6.mlp.experts.38.down_proj.weight', 'ernie.layers.6.mlp.experts.39.down_proj.weight', 'ernie.layers.6.mlp.experts.40.down_proj.weight', 'ernie.layers.6.mlp.experts.41.down_proj.weight', 'ernie.layers.6.mlp.experts.42.down_proj.weight', 'ernie.layers.6.mlp.experts.43.down_proj.weight', 'ernie.layers.6.mlp.experts.44.down_proj.weight', 'ernie.layers.6.mlp.experts.45.down_proj.weight', 'ernie.layers.6.mlp.experts.46.down_proj.weight', 'ernie.layers.6.mlp.experts.47.down_proj.weight', 'ernie.layers.6.mlp.experts.48.down_proj.weight', 'ernie.layers.6.mlp.experts.49.down_proj.weight', 'ernie.layers.6.mlp.experts.50.down_proj.weight', 'ernie.layers.6.mlp.experts.51.down_proj.weight', 'ernie.layers.6.mlp.experts.52.down_proj.weight', 'ernie.layers.6.mlp.experts.53.down_proj.weight', 'ernie.layers.6.mlp.experts.54.down_proj.weight', 'ernie.layers.6.mlp.experts.55.down_proj.weight', 'ernie.layers.6.mlp.experts.56.down_proj.weight', 'ernie.layers.6.mlp.experts.57.down_proj.weight', 'ernie.layers.6.mlp.experts.58.down_proj.weight', 'ernie.layers.6.mlp.experts.59.down_proj.weight', 'ernie.layers.6.mlp.experts.60.down_proj.weight', 'ernie.layers.6.mlp.experts.61.down_proj.weight', 'ernie.layers.6.mlp.experts.62.down_proj.weight', 'ernie.layers.6.mlp.experts.63.down_proj.weight', 'ernie.layers.6.mlp.experts.96.down_proj.weight', 'ernie.layers.6.mlp.experts.97.down_proj.weight', 'ernie.layers.6.mlp.experts.98.down_proj.weight', 'ernie.layers.6.mlp.experts.99.down_proj.weight', 'ernie.layers.6.mlp.experts.100.down_proj.weight', 'ernie.layers.6.mlp.experts.101.down_proj.weight', 'ernie.layers.6.mlp.experts.102.down_proj.weight', 'ernie.layers.6.mlp.experts.103.down_proj.weight', 'ernie.layers.6.mlp.experts.104.down_proj.weight', 'ernie.layers.6.mlp.experts.105.down_proj.weight', 'ernie.layers.6.mlp.experts.106.down_proj.weight', 'ernie.layers.6.mlp.experts.107.down_proj.weight', 'ernie.layers.6.mlp.experts.108.down_proj.weight', 'ernie.layers.6.mlp.experts.109.down_proj.weight', 'ernie.layers.6.mlp.experts.110.down_proj.weight', 'ernie.layers.6.mlp.experts.111.down_proj.weight', 'ernie.layers.6.mlp.experts.112.down_proj.weight', 'ernie.layers.6.mlp.experts.113.down_proj.weight', 'ernie.layers.6.mlp.experts.114.down_proj.weight', 'ernie.layers.6.mlp.experts.115.down_proj.weight', 'ernie.layers.6.mlp.experts.116.down_proj.weight', 'ernie.layers.6.mlp.experts.117.down_proj.weight', 'ernie.layers.6.mlp.experts.118.down_proj.weight', 'ernie.layers.6.mlp.experts.119.down_proj.weight', 'ernie.layers.6.mlp.experts.120.down_proj.weight', 'ernie.layers.6.mlp.experts.121.down_proj.weight', 'ernie.layers.6.mlp.experts.122.down_proj.weight', 'ernie.layers.6.mlp.experts.123.down_proj.weight', 'ernie.layers.6.mlp.experts.124.down_proj.weight', 'ernie.layers.6.mlp.experts.125.down_proj.weight', 'ernie.layers.6.mlp.experts.126.down_proj.weight', 'ernie.layers.6.mlp.experts.127.down_proj.weight'] ernie.layers.7.mlp.image_fused_moe.gate.weight:ernie.layers.7.mlp.gate.weight_1 -ernie.layers.7.mlp.image_fused_moe.experts.gate_correction_bias:ernie.layers.7.mlp.moe_statics.e_score_correction_bias ernie.layers.7.mlp.image_fused_moe.experts.up_gate_proj_weight:['ernie.layers.7.mlp.experts.32.up_gate_proj.weight', 'ernie.layers.7.mlp.experts.33.up_gate_proj.weight', 'ernie.layers.7.mlp.experts.34.up_gate_proj.weight', 'ernie.layers.7.mlp.experts.35.up_gate_proj.weight', 'ernie.layers.7.mlp.experts.36.up_gate_proj.weight', 'ernie.layers.7.mlp.experts.37.up_gate_proj.weight', 'ernie.layers.7.mlp.experts.38.up_gate_proj.weight', 'ernie.layers.7.mlp.experts.39.up_gate_proj.weight', 'ernie.layers.7.mlp.experts.40.up_gate_proj.weight', 'ernie.layers.7.mlp.experts.41.up_gate_proj.weight', 'ernie.layers.7.mlp.experts.42.up_gate_proj.weight', 'ernie.layers.7.mlp.experts.43.up_gate_proj.weight', 'ernie.layers.7.mlp.experts.44.up_gate_proj.weight', 'ernie.layers.7.mlp.experts.45.up_gate_proj.weight', 'ernie.layers.7.mlp.experts.46.up_gate_proj.weight', 'ernie.layers.7.mlp.experts.47.up_gate_proj.weight', 'ernie.layers.7.mlp.experts.48.up_gate_proj.weight', 'ernie.layers.7.mlp.experts.49.up_gate_proj.weight', 'ernie.layers.7.mlp.experts.50.up_gate_proj.weight', 'ernie.layers.7.mlp.experts.51.up_gate_proj.weight', 'ernie.layers.7.mlp.experts.52.up_gate_proj.weight', 'ernie.layers.7.mlp.experts.53.up_gate_proj.weight', 'ernie.layers.7.mlp.experts.54.up_gate_proj.weight', 'ernie.layers.7.mlp.experts.55.up_gate_proj.weight', 'ernie.layers.7.mlp.experts.56.up_gate_proj.weight', 'ernie.layers.7.mlp.experts.57.up_gate_proj.weight', 'ernie.layers.7.mlp.experts.58.up_gate_proj.weight', 'ernie.layers.7.mlp.experts.59.up_gate_proj.weight', 'ernie.layers.7.mlp.experts.60.up_gate_proj.weight', 'ernie.layers.7.mlp.experts.61.up_gate_proj.weight', 'ernie.layers.7.mlp.experts.62.up_gate_proj.weight', 'ernie.layers.7.mlp.experts.63.up_gate_proj.weight', 'ernie.layers.7.mlp.experts.96.up_gate_proj.weight', 'ernie.layers.7.mlp.experts.97.up_gate_proj.weight', 'ernie.layers.7.mlp.experts.98.up_gate_proj.weight', 'ernie.layers.7.mlp.experts.99.up_gate_proj.weight', 'ernie.layers.7.mlp.experts.100.up_gate_proj.weight', 'ernie.layers.7.mlp.experts.101.up_gate_proj.weight', 'ernie.layers.7.mlp.experts.102.up_gate_proj.weight', 'ernie.layers.7.mlp.experts.103.up_gate_proj.weight', 'ernie.layers.7.mlp.experts.104.up_gate_proj.weight', 'ernie.layers.7.mlp.experts.105.up_gate_proj.weight', 'ernie.layers.7.mlp.experts.106.up_gate_proj.weight', 'ernie.layers.7.mlp.experts.107.up_gate_proj.weight', 'ernie.layers.7.mlp.experts.108.up_gate_proj.weight', 'ernie.layers.7.mlp.experts.109.up_gate_proj.weight', 'ernie.layers.7.mlp.experts.110.up_gate_proj.weight', 'ernie.layers.7.mlp.experts.111.up_gate_proj.weight', 'ernie.layers.7.mlp.experts.112.up_gate_proj.weight', 'ernie.layers.7.mlp.experts.113.up_gate_proj.weight', 'ernie.layers.7.mlp.experts.114.up_gate_proj.weight', 'ernie.layers.7.mlp.experts.115.up_gate_proj.weight', 'ernie.layers.7.mlp.experts.116.up_gate_proj.weight', 'ernie.layers.7.mlp.experts.117.up_gate_proj.weight', 'ernie.layers.7.mlp.experts.118.up_gate_proj.weight', 'ernie.layers.7.mlp.experts.119.up_gate_proj.weight', 'ernie.layers.7.mlp.experts.120.up_gate_proj.weight', 'ernie.layers.7.mlp.experts.121.up_gate_proj.weight', 'ernie.layers.7.mlp.experts.122.up_gate_proj.weight', 'ernie.layers.7.mlp.experts.123.up_gate_proj.weight', 'ernie.layers.7.mlp.experts.124.up_gate_proj.weight', 'ernie.layers.7.mlp.experts.125.up_gate_proj.weight', 'ernie.layers.7.mlp.experts.126.up_gate_proj.weight', 'ernie.layers.7.mlp.experts.127.up_gate_proj.weight'] ernie.layers.7.mlp.image_fused_moe.experts.down_proj_weight:['ernie.layers.7.mlp.experts.32.down_proj.weight', 'ernie.layers.7.mlp.experts.33.down_proj.weight', 'ernie.layers.7.mlp.experts.34.down_proj.weight', 'ernie.layers.7.mlp.experts.35.down_proj.weight', 'ernie.layers.7.mlp.experts.36.down_proj.weight', 'ernie.layers.7.mlp.experts.37.down_proj.weight', 'ernie.layers.7.mlp.experts.38.down_proj.weight', 'ernie.layers.7.mlp.experts.39.down_proj.weight', 'ernie.layers.7.mlp.experts.40.down_proj.weight', 'ernie.layers.7.mlp.experts.41.down_proj.weight', 'ernie.layers.7.mlp.experts.42.down_proj.weight', 'ernie.layers.7.mlp.experts.43.down_proj.weight', 'ernie.layers.7.mlp.experts.44.down_proj.weight', 'ernie.layers.7.mlp.experts.45.down_proj.weight', 'ernie.layers.7.mlp.experts.46.down_proj.weight', 'ernie.layers.7.mlp.experts.47.down_proj.weight', 'ernie.layers.7.mlp.experts.48.down_proj.weight', 'ernie.layers.7.mlp.experts.49.down_proj.weight', 'ernie.layers.7.mlp.experts.50.down_proj.weight', 'ernie.layers.7.mlp.experts.51.down_proj.weight', 'ernie.layers.7.mlp.experts.52.down_proj.weight', 'ernie.layers.7.mlp.experts.53.down_proj.weight', 'ernie.layers.7.mlp.experts.54.down_proj.weight', 'ernie.layers.7.mlp.experts.55.down_proj.weight', 'ernie.layers.7.mlp.experts.56.down_proj.weight', 'ernie.layers.7.mlp.experts.57.down_proj.weight', 'ernie.layers.7.mlp.experts.58.down_proj.weight', 'ernie.layers.7.mlp.experts.59.down_proj.weight', 'ernie.layers.7.mlp.experts.60.down_proj.weight', 'ernie.layers.7.mlp.experts.61.down_proj.weight', 'ernie.layers.7.mlp.experts.62.down_proj.weight', 'ernie.layers.7.mlp.experts.63.down_proj.weight', 'ernie.layers.7.mlp.experts.96.down_proj.weight', 'ernie.layers.7.mlp.experts.97.down_proj.weight', 'ernie.layers.7.mlp.experts.98.down_proj.weight', 'ernie.layers.7.mlp.experts.99.down_proj.weight', 'ernie.layers.7.mlp.experts.100.down_proj.weight', 'ernie.layers.7.mlp.experts.101.down_proj.weight', 'ernie.layers.7.mlp.experts.102.down_proj.weight', 'ernie.layers.7.mlp.experts.103.down_proj.weight', 'ernie.layers.7.mlp.experts.104.down_proj.weight', 'ernie.layers.7.mlp.experts.105.down_proj.weight', 'ernie.layers.7.mlp.experts.106.down_proj.weight', 'ernie.layers.7.mlp.experts.107.down_proj.weight', 'ernie.layers.7.mlp.experts.108.down_proj.weight', 'ernie.layers.7.mlp.experts.109.down_proj.weight', 'ernie.layers.7.mlp.experts.110.down_proj.weight', 'ernie.layers.7.mlp.experts.111.down_proj.weight', 'ernie.layers.7.mlp.experts.112.down_proj.weight', 'ernie.layers.7.mlp.experts.113.down_proj.weight', 'ernie.layers.7.mlp.experts.114.down_proj.weight', 'ernie.layers.7.mlp.experts.115.down_proj.weight', 'ernie.layers.7.mlp.experts.116.down_proj.weight', 'ernie.layers.7.mlp.experts.117.down_proj.weight', 'ernie.layers.7.mlp.experts.118.down_proj.weight', 'ernie.layers.7.mlp.experts.119.down_proj.weight', 'ernie.layers.7.mlp.experts.120.down_proj.weight', 'ernie.layers.7.mlp.experts.121.down_proj.weight', 'ernie.layers.7.mlp.experts.122.down_proj.weight', 'ernie.layers.7.mlp.experts.123.down_proj.weight', 'ernie.layers.7.mlp.experts.124.down_proj.weight', 'ernie.layers.7.mlp.experts.125.down_proj.weight', 'ernie.layers.7.mlp.experts.126.down_proj.weight', 'ernie.layers.7.mlp.experts.127.down_proj.weight'] ernie.layers.8.mlp.image_fused_moe.gate.weight:ernie.layers.8.mlp.gate.weight_1 -ernie.layers.8.mlp.image_fused_moe.experts.gate_correction_bias:ernie.layers.8.mlp.moe_statics.e_score_correction_bias ernie.layers.8.mlp.image_fused_moe.experts.up_gate_proj_weight:['ernie.layers.8.mlp.experts.32.up_gate_proj.weight', 'ernie.layers.8.mlp.experts.33.up_gate_proj.weight', 'ernie.layers.8.mlp.experts.34.up_gate_proj.weight', 'ernie.layers.8.mlp.experts.35.up_gate_proj.weight', 'ernie.layers.8.mlp.experts.36.up_gate_proj.weight', 'ernie.layers.8.mlp.experts.37.up_gate_proj.weight', 'ernie.layers.8.mlp.experts.38.up_gate_proj.weight', 'ernie.layers.8.mlp.experts.39.up_gate_proj.weight', 'ernie.layers.8.mlp.experts.40.up_gate_proj.weight', 'ernie.layers.8.mlp.experts.41.up_gate_proj.weight', 'ernie.layers.8.mlp.experts.42.up_gate_proj.weight', 'ernie.layers.8.mlp.experts.43.up_gate_proj.weight', 'ernie.layers.8.mlp.experts.44.up_gate_proj.weight', 'ernie.layers.8.mlp.experts.45.up_gate_proj.weight', 'ernie.layers.8.mlp.experts.46.up_gate_proj.weight', 'ernie.layers.8.mlp.experts.47.up_gate_proj.weight', 'ernie.layers.8.mlp.experts.48.up_gate_proj.weight', 'ernie.layers.8.mlp.experts.49.up_gate_proj.weight', 'ernie.layers.8.mlp.experts.50.up_gate_proj.weight', 'ernie.layers.8.mlp.experts.51.up_gate_proj.weight', 'ernie.layers.8.mlp.experts.52.up_gate_proj.weight', 'ernie.layers.8.mlp.experts.53.up_gate_proj.weight', 'ernie.layers.8.mlp.experts.54.up_gate_proj.weight', 'ernie.layers.8.mlp.experts.55.up_gate_proj.weight', 'ernie.layers.8.mlp.experts.56.up_gate_proj.weight', 'ernie.layers.8.mlp.experts.57.up_gate_proj.weight', 'ernie.layers.8.mlp.experts.58.up_gate_proj.weight', 'ernie.layers.8.mlp.experts.59.up_gate_proj.weight', 'ernie.layers.8.mlp.experts.60.up_gate_proj.weight', 'ernie.layers.8.mlp.experts.61.up_gate_proj.weight', 'ernie.layers.8.mlp.experts.62.up_gate_proj.weight', 'ernie.layers.8.mlp.experts.63.up_gate_proj.weight', 'ernie.layers.8.mlp.experts.96.up_gate_proj.weight', 'ernie.layers.8.mlp.experts.97.up_gate_proj.weight', 'ernie.layers.8.mlp.experts.98.up_gate_proj.weight', 'ernie.layers.8.mlp.experts.99.up_gate_proj.weight', 'ernie.layers.8.mlp.experts.100.up_gate_proj.weight', 'ernie.layers.8.mlp.experts.101.up_gate_proj.weight', 'ernie.layers.8.mlp.experts.102.up_gate_proj.weight', 'ernie.layers.8.mlp.experts.103.up_gate_proj.weight', 'ernie.layers.8.mlp.experts.104.up_gate_proj.weight', 'ernie.layers.8.mlp.experts.105.up_gate_proj.weight', 'ernie.layers.8.mlp.experts.106.up_gate_proj.weight', 'ernie.layers.8.mlp.experts.107.up_gate_proj.weight', 'ernie.layers.8.mlp.experts.108.up_gate_proj.weight', 'ernie.layers.8.mlp.experts.109.up_gate_proj.weight', 'ernie.layers.8.mlp.experts.110.up_gate_proj.weight', 'ernie.layers.8.mlp.experts.111.up_gate_proj.weight', 'ernie.layers.8.mlp.experts.112.up_gate_proj.weight', 'ernie.layers.8.mlp.experts.113.up_gate_proj.weight', 'ernie.layers.8.mlp.experts.114.up_gate_proj.weight', 'ernie.layers.8.mlp.experts.115.up_gate_proj.weight', 'ernie.layers.8.mlp.experts.116.up_gate_proj.weight', 'ernie.layers.8.mlp.experts.117.up_gate_proj.weight', 'ernie.layers.8.mlp.experts.118.up_gate_proj.weight', 'ernie.layers.8.mlp.experts.119.up_gate_proj.weight', 'ernie.layers.8.mlp.experts.120.up_gate_proj.weight', 'ernie.layers.8.mlp.experts.121.up_gate_proj.weight', 'ernie.layers.8.mlp.experts.122.up_gate_proj.weight', 'ernie.layers.8.mlp.experts.123.up_gate_proj.weight', 'ernie.layers.8.mlp.experts.124.up_gate_proj.weight', 'ernie.layers.8.mlp.experts.125.up_gate_proj.weight', 'ernie.layers.8.mlp.experts.126.up_gate_proj.weight', 'ernie.layers.8.mlp.experts.127.up_gate_proj.weight'] ernie.layers.8.mlp.image_fused_moe.experts.down_proj_weight:['ernie.layers.8.mlp.experts.32.down_proj.weight', 'ernie.layers.8.mlp.experts.33.down_proj.weight', 'ernie.layers.8.mlp.experts.34.down_proj.weight', 'ernie.layers.8.mlp.experts.35.down_proj.weight', 'ernie.layers.8.mlp.experts.36.down_proj.weight', 'ernie.layers.8.mlp.experts.37.down_proj.weight', 'ernie.layers.8.mlp.experts.38.down_proj.weight', 'ernie.layers.8.mlp.experts.39.down_proj.weight', 'ernie.layers.8.mlp.experts.40.down_proj.weight', 'ernie.layers.8.mlp.experts.41.down_proj.weight', 'ernie.layers.8.mlp.experts.42.down_proj.weight', 'ernie.layers.8.mlp.experts.43.down_proj.weight', 'ernie.layers.8.mlp.experts.44.down_proj.weight', 'ernie.layers.8.mlp.experts.45.down_proj.weight', 'ernie.layers.8.mlp.experts.46.down_proj.weight', 'ernie.layers.8.mlp.experts.47.down_proj.weight', 'ernie.layers.8.mlp.experts.48.down_proj.weight', 'ernie.layers.8.mlp.experts.49.down_proj.weight', 'ernie.layers.8.mlp.experts.50.down_proj.weight', 'ernie.layers.8.mlp.experts.51.down_proj.weight', 'ernie.layers.8.mlp.experts.52.down_proj.weight', 'ernie.layers.8.mlp.experts.53.down_proj.weight', 'ernie.layers.8.mlp.experts.54.down_proj.weight', 'ernie.layers.8.mlp.experts.55.down_proj.weight', 'ernie.layers.8.mlp.experts.56.down_proj.weight', 'ernie.layers.8.mlp.experts.57.down_proj.weight', 'ernie.layers.8.mlp.experts.58.down_proj.weight', 'ernie.layers.8.mlp.experts.59.down_proj.weight', 'ernie.layers.8.mlp.experts.60.down_proj.weight', 'ernie.layers.8.mlp.experts.61.down_proj.weight', 'ernie.layers.8.mlp.experts.62.down_proj.weight', 'ernie.layers.8.mlp.experts.63.down_proj.weight', 'ernie.layers.8.mlp.experts.96.down_proj.weight', 'ernie.layers.8.mlp.experts.97.down_proj.weight', 'ernie.layers.8.mlp.experts.98.down_proj.weight', 'ernie.layers.8.mlp.experts.99.down_proj.weight', 'ernie.layers.8.mlp.experts.100.down_proj.weight', 'ernie.layers.8.mlp.experts.101.down_proj.weight', 'ernie.layers.8.mlp.experts.102.down_proj.weight', 'ernie.layers.8.mlp.experts.103.down_proj.weight', 'ernie.layers.8.mlp.experts.104.down_proj.weight', 'ernie.layers.8.mlp.experts.105.down_proj.weight', 'ernie.layers.8.mlp.experts.106.down_proj.weight', 'ernie.layers.8.mlp.experts.107.down_proj.weight', 'ernie.layers.8.mlp.experts.108.down_proj.weight', 'ernie.layers.8.mlp.experts.109.down_proj.weight', 'ernie.layers.8.mlp.experts.110.down_proj.weight', 'ernie.layers.8.mlp.experts.111.down_proj.weight', 'ernie.layers.8.mlp.experts.112.down_proj.weight', 'ernie.layers.8.mlp.experts.113.down_proj.weight', 'ernie.layers.8.mlp.experts.114.down_proj.weight', 'ernie.layers.8.mlp.experts.115.down_proj.weight', 'ernie.layers.8.mlp.experts.116.down_proj.weight', 'ernie.layers.8.mlp.experts.117.down_proj.weight', 'ernie.layers.8.mlp.experts.118.down_proj.weight', 'ernie.layers.8.mlp.experts.119.down_proj.weight', 'ernie.layers.8.mlp.experts.120.down_proj.weight', 'ernie.layers.8.mlp.experts.121.down_proj.weight', 'ernie.layers.8.mlp.experts.122.down_proj.weight', 'ernie.layers.8.mlp.experts.123.down_proj.weight', 'ernie.layers.8.mlp.experts.124.down_proj.weight', 'ernie.layers.8.mlp.experts.125.down_proj.weight', 'ernie.layers.8.mlp.experts.126.down_proj.weight', 'ernie.layers.8.mlp.experts.127.down_proj.weight'] ernie.layers.9.mlp.image_fused_moe.gate.weight:ernie.layers.9.mlp.gate.weight_1 -ernie.layers.9.mlp.image_fused_moe.experts.gate_correction_bias:ernie.layers.9.mlp.moe_statics.e_score_correction_bias ernie.layers.9.mlp.image_fused_moe.experts.up_gate_proj_weight:['ernie.layers.9.mlp.experts.32.up_gate_proj.weight', 'ernie.layers.9.mlp.experts.33.up_gate_proj.weight', 'ernie.layers.9.mlp.experts.34.up_gate_proj.weight', 'ernie.layers.9.mlp.experts.35.up_gate_proj.weight', 'ernie.layers.9.mlp.experts.36.up_gate_proj.weight', 'ernie.layers.9.mlp.experts.37.up_gate_proj.weight', 'ernie.layers.9.mlp.experts.38.up_gate_proj.weight', 'ernie.layers.9.mlp.experts.39.up_gate_proj.weight', 'ernie.layers.9.mlp.experts.40.up_gate_proj.weight', 'ernie.layers.9.mlp.experts.41.up_gate_proj.weight', 'ernie.layers.9.mlp.experts.42.up_gate_proj.weight', 'ernie.layers.9.mlp.experts.43.up_gate_proj.weight', 'ernie.layers.9.mlp.experts.44.up_gate_proj.weight', 'ernie.layers.9.mlp.experts.45.up_gate_proj.weight', 'ernie.layers.9.mlp.experts.46.up_gate_proj.weight', 'ernie.layers.9.mlp.experts.47.up_gate_proj.weight', 'ernie.layers.9.mlp.experts.48.up_gate_proj.weight', 'ernie.layers.9.mlp.experts.49.up_gate_proj.weight', 'ernie.layers.9.mlp.experts.50.up_gate_proj.weight', 'ernie.layers.9.mlp.experts.51.up_gate_proj.weight', 'ernie.layers.9.mlp.experts.52.up_gate_proj.weight', 'ernie.layers.9.mlp.experts.53.up_gate_proj.weight', 'ernie.layers.9.mlp.experts.54.up_gate_proj.weight', 'ernie.layers.9.mlp.experts.55.up_gate_proj.weight', 'ernie.layers.9.mlp.experts.56.up_gate_proj.weight', 'ernie.layers.9.mlp.experts.57.up_gate_proj.weight', 'ernie.layers.9.mlp.experts.58.up_gate_proj.weight', 'ernie.layers.9.mlp.experts.59.up_gate_proj.weight', 'ernie.layers.9.mlp.experts.60.up_gate_proj.weight', 'ernie.layers.9.mlp.experts.61.up_gate_proj.weight', 'ernie.layers.9.mlp.experts.62.up_gate_proj.weight', 'ernie.layers.9.mlp.experts.63.up_gate_proj.weight', 'ernie.layers.9.mlp.experts.96.up_gate_proj.weight', 'ernie.layers.9.mlp.experts.97.up_gate_proj.weight', 'ernie.layers.9.mlp.experts.98.up_gate_proj.weight', 'ernie.layers.9.mlp.experts.99.up_gate_proj.weight', 'ernie.layers.9.mlp.experts.100.up_gate_proj.weight', 'ernie.layers.9.mlp.experts.101.up_gate_proj.weight', 'ernie.layers.9.mlp.experts.102.up_gate_proj.weight', 'ernie.layers.9.mlp.experts.103.up_gate_proj.weight', 'ernie.layers.9.mlp.experts.104.up_gate_proj.weight', 'ernie.layers.9.mlp.experts.105.up_gate_proj.weight', 'ernie.layers.9.mlp.experts.106.up_gate_proj.weight', 'ernie.layers.9.mlp.experts.107.up_gate_proj.weight', 'ernie.layers.9.mlp.experts.108.up_gate_proj.weight', 'ernie.layers.9.mlp.experts.109.up_gate_proj.weight', 'ernie.layers.9.mlp.experts.110.up_gate_proj.weight', 'ernie.layers.9.mlp.experts.111.up_gate_proj.weight', 'ernie.layers.9.mlp.experts.112.up_gate_proj.weight', 'ernie.layers.9.mlp.experts.113.up_gate_proj.weight', 'ernie.layers.9.mlp.experts.114.up_gate_proj.weight', 'ernie.layers.9.mlp.experts.115.up_gate_proj.weight', 'ernie.layers.9.mlp.experts.116.up_gate_proj.weight', 'ernie.layers.9.mlp.experts.117.up_gate_proj.weight', 'ernie.layers.9.mlp.experts.118.up_gate_proj.weight', 'ernie.layers.9.mlp.experts.119.up_gate_proj.weight', 'ernie.layers.9.mlp.experts.120.up_gate_proj.weight', 'ernie.layers.9.mlp.experts.121.up_gate_proj.weight', 'ernie.layers.9.mlp.experts.122.up_gate_proj.weight', 'ernie.layers.9.mlp.experts.123.up_gate_proj.weight', 'ernie.layers.9.mlp.experts.124.up_gate_proj.weight', 'ernie.layers.9.mlp.experts.125.up_gate_proj.weight', 'ernie.layers.9.mlp.experts.126.up_gate_proj.weight', 'ernie.layers.9.mlp.experts.127.up_gate_proj.weight'] ernie.layers.9.mlp.image_fused_moe.experts.down_proj_weight:['ernie.layers.9.mlp.experts.32.down_proj.weight', 'ernie.layers.9.mlp.experts.33.down_proj.weight', 'ernie.layers.9.mlp.experts.34.down_proj.weight', 'ernie.layers.9.mlp.experts.35.down_proj.weight', 'ernie.layers.9.mlp.experts.36.down_proj.weight', 'ernie.layers.9.mlp.experts.37.down_proj.weight', 'ernie.layers.9.mlp.experts.38.down_proj.weight', 'ernie.layers.9.mlp.experts.39.down_proj.weight', 'ernie.layers.9.mlp.experts.40.down_proj.weight', 'ernie.layers.9.mlp.experts.41.down_proj.weight', 'ernie.layers.9.mlp.experts.42.down_proj.weight', 'ernie.layers.9.mlp.experts.43.down_proj.weight', 'ernie.layers.9.mlp.experts.44.down_proj.weight', 'ernie.layers.9.mlp.experts.45.down_proj.weight', 'ernie.layers.9.mlp.experts.46.down_proj.weight', 'ernie.layers.9.mlp.experts.47.down_proj.weight', 'ernie.layers.9.mlp.experts.48.down_proj.weight', 'ernie.layers.9.mlp.experts.49.down_proj.weight', 'ernie.layers.9.mlp.experts.50.down_proj.weight', 'ernie.layers.9.mlp.experts.51.down_proj.weight', 'ernie.layers.9.mlp.experts.52.down_proj.weight', 'ernie.layers.9.mlp.experts.53.down_proj.weight', 'ernie.layers.9.mlp.experts.54.down_proj.weight', 'ernie.layers.9.mlp.experts.55.down_proj.weight', 'ernie.layers.9.mlp.experts.56.down_proj.weight', 'ernie.layers.9.mlp.experts.57.down_proj.weight', 'ernie.layers.9.mlp.experts.58.down_proj.weight', 'ernie.layers.9.mlp.experts.59.down_proj.weight', 'ernie.layers.9.mlp.experts.60.down_proj.weight', 'ernie.layers.9.mlp.experts.61.down_proj.weight', 'ernie.layers.9.mlp.experts.62.down_proj.weight', 'ernie.layers.9.mlp.experts.63.down_proj.weight', 'ernie.layers.9.mlp.experts.96.down_proj.weight', 'ernie.layers.9.mlp.experts.97.down_proj.weight', 'ernie.layers.9.mlp.experts.98.down_proj.weight', 'ernie.layers.9.mlp.experts.99.down_proj.weight', 'ernie.layers.9.mlp.experts.100.down_proj.weight', 'ernie.layers.9.mlp.experts.101.down_proj.weight', 'ernie.layers.9.mlp.experts.102.down_proj.weight', 'ernie.layers.9.mlp.experts.103.down_proj.weight', 'ernie.layers.9.mlp.experts.104.down_proj.weight', 'ernie.layers.9.mlp.experts.105.down_proj.weight', 'ernie.layers.9.mlp.experts.106.down_proj.weight', 'ernie.layers.9.mlp.experts.107.down_proj.weight', 'ernie.layers.9.mlp.experts.108.down_proj.weight', 'ernie.layers.9.mlp.experts.109.down_proj.weight', 'ernie.layers.9.mlp.experts.110.down_proj.weight', 'ernie.layers.9.mlp.experts.111.down_proj.weight', 'ernie.layers.9.mlp.experts.112.down_proj.weight', 'ernie.layers.9.mlp.experts.113.down_proj.weight', 'ernie.layers.9.mlp.experts.114.down_proj.weight', 'ernie.layers.9.mlp.experts.115.down_proj.weight', 'ernie.layers.9.mlp.experts.116.down_proj.weight', 'ernie.layers.9.mlp.experts.117.down_proj.weight', 'ernie.layers.9.mlp.experts.118.down_proj.weight', 'ernie.layers.9.mlp.experts.119.down_proj.weight', 'ernie.layers.9.mlp.experts.120.down_proj.weight', 'ernie.layers.9.mlp.experts.121.down_proj.weight', 'ernie.layers.9.mlp.experts.122.down_proj.weight', 'ernie.layers.9.mlp.experts.123.down_proj.weight', 'ernie.layers.9.mlp.experts.124.down_proj.weight', 'ernie.layers.9.mlp.experts.125.down_proj.weight', 'ernie.layers.9.mlp.experts.126.down_proj.weight', 'ernie.layers.9.mlp.experts.127.down_proj.weight'] ernie.layers.10.mlp.image_fused_moe.gate.weight:ernie.layers.10.mlp.gate.weight_1 -ernie.layers.10.mlp.image_fused_moe.experts.gate_correction_bias:ernie.layers.10.mlp.moe_statics.e_score_correction_bias ernie.layers.10.mlp.image_fused_moe.experts.up_gate_proj_weight:['ernie.layers.10.mlp.experts.32.up_gate_proj.weight', 'ernie.layers.10.mlp.experts.33.up_gate_proj.weight', 'ernie.layers.10.mlp.experts.34.up_gate_proj.weight', 'ernie.layers.10.mlp.experts.35.up_gate_proj.weight', 'ernie.layers.10.mlp.experts.36.up_gate_proj.weight', 'ernie.layers.10.mlp.experts.37.up_gate_proj.weight', 'ernie.layers.10.mlp.experts.38.up_gate_proj.weight', 'ernie.layers.10.mlp.experts.39.up_gate_proj.weight', 'ernie.layers.10.mlp.experts.40.up_gate_proj.weight', 'ernie.layers.10.mlp.experts.41.up_gate_proj.weight', 'ernie.layers.10.mlp.experts.42.up_gate_proj.weight', 'ernie.layers.10.mlp.experts.43.up_gate_proj.weight', 'ernie.layers.10.mlp.experts.44.up_gate_proj.weight', 'ernie.layers.10.mlp.experts.45.up_gate_proj.weight', 'ernie.layers.10.mlp.experts.46.up_gate_proj.weight', 'ernie.layers.10.mlp.experts.47.up_gate_proj.weight', 'ernie.layers.10.mlp.experts.48.up_gate_proj.weight', 'ernie.layers.10.mlp.experts.49.up_gate_proj.weight', 'ernie.layers.10.mlp.experts.50.up_gate_proj.weight', 'ernie.layers.10.mlp.experts.51.up_gate_proj.weight', 'ernie.layers.10.mlp.experts.52.up_gate_proj.weight', 'ernie.layers.10.mlp.experts.53.up_gate_proj.weight', 'ernie.layers.10.mlp.experts.54.up_gate_proj.weight', 'ernie.layers.10.mlp.experts.55.up_gate_proj.weight', 'ernie.layers.10.mlp.experts.56.up_gate_proj.weight', 'ernie.layers.10.mlp.experts.57.up_gate_proj.weight', 'ernie.layers.10.mlp.experts.58.up_gate_proj.weight', 'ernie.layers.10.mlp.experts.59.up_gate_proj.weight', 'ernie.layers.10.mlp.experts.60.up_gate_proj.weight', 'ernie.layers.10.mlp.experts.61.up_gate_proj.weight', 'ernie.layers.10.mlp.experts.62.up_gate_proj.weight', 'ernie.layers.10.mlp.experts.63.up_gate_proj.weight', 'ernie.layers.10.mlp.experts.96.up_gate_proj.weight', 'ernie.layers.10.mlp.experts.97.up_gate_proj.weight', 'ernie.layers.10.mlp.experts.98.up_gate_proj.weight', 'ernie.layers.10.mlp.experts.99.up_gate_proj.weight', 'ernie.layers.10.mlp.experts.100.up_gate_proj.weight', 'ernie.layers.10.mlp.experts.101.up_gate_proj.weight', 'ernie.layers.10.mlp.experts.102.up_gate_proj.weight', 'ernie.layers.10.mlp.experts.103.up_gate_proj.weight', 'ernie.layers.10.mlp.experts.104.up_gate_proj.weight', 'ernie.layers.10.mlp.experts.105.up_gate_proj.weight', 'ernie.layers.10.mlp.experts.106.up_gate_proj.weight', 'ernie.layers.10.mlp.experts.107.up_gate_proj.weight', 'ernie.layers.10.mlp.experts.108.up_gate_proj.weight', 'ernie.layers.10.mlp.experts.109.up_gate_proj.weight', 'ernie.layers.10.mlp.experts.110.up_gate_proj.weight', 'ernie.layers.10.mlp.experts.111.up_gate_proj.weight', 'ernie.layers.10.mlp.experts.112.up_gate_proj.weight', 'ernie.layers.10.mlp.experts.113.up_gate_proj.weight', 'ernie.layers.10.mlp.experts.114.up_gate_proj.weight', 'ernie.layers.10.mlp.experts.115.up_gate_proj.weight', 'ernie.layers.10.mlp.experts.116.up_gate_proj.weight', 'ernie.layers.10.mlp.experts.117.up_gate_proj.weight', 'ernie.layers.10.mlp.experts.118.up_gate_proj.weight', 'ernie.layers.10.mlp.experts.119.up_gate_proj.weight', 'ernie.layers.10.mlp.experts.120.up_gate_proj.weight', 'ernie.layers.10.mlp.experts.121.up_gate_proj.weight', 'ernie.layers.10.mlp.experts.122.up_gate_proj.weight', 'ernie.layers.10.mlp.experts.123.up_gate_proj.weight', 'ernie.layers.10.mlp.experts.124.up_gate_proj.weight', 'ernie.layers.10.mlp.experts.125.up_gate_proj.weight', 'ernie.layers.10.mlp.experts.126.up_gate_proj.weight', 'ernie.layers.10.mlp.experts.127.up_gate_proj.weight'] ernie.layers.10.mlp.image_fused_moe.experts.down_proj_weight:['ernie.layers.10.mlp.experts.32.down_proj.weight', 'ernie.layers.10.mlp.experts.33.down_proj.weight', 'ernie.layers.10.mlp.experts.34.down_proj.weight', 'ernie.layers.10.mlp.experts.35.down_proj.weight', 'ernie.layers.10.mlp.experts.36.down_proj.weight', 'ernie.layers.10.mlp.experts.37.down_proj.weight', 'ernie.layers.10.mlp.experts.38.down_proj.weight', 'ernie.layers.10.mlp.experts.39.down_proj.weight', 'ernie.layers.10.mlp.experts.40.down_proj.weight', 'ernie.layers.10.mlp.experts.41.down_proj.weight', 'ernie.layers.10.mlp.experts.42.down_proj.weight', 'ernie.layers.10.mlp.experts.43.down_proj.weight', 'ernie.layers.10.mlp.experts.44.down_proj.weight', 'ernie.layers.10.mlp.experts.45.down_proj.weight', 'ernie.layers.10.mlp.experts.46.down_proj.weight', 'ernie.layers.10.mlp.experts.47.down_proj.weight', 'ernie.layers.10.mlp.experts.48.down_proj.weight', 'ernie.layers.10.mlp.experts.49.down_proj.weight', 'ernie.layers.10.mlp.experts.50.down_proj.weight', 'ernie.layers.10.mlp.experts.51.down_proj.weight', 'ernie.layers.10.mlp.experts.52.down_proj.weight', 'ernie.layers.10.mlp.experts.53.down_proj.weight', 'ernie.layers.10.mlp.experts.54.down_proj.weight', 'ernie.layers.10.mlp.experts.55.down_proj.weight', 'ernie.layers.10.mlp.experts.56.down_proj.weight', 'ernie.layers.10.mlp.experts.57.down_proj.weight', 'ernie.layers.10.mlp.experts.58.down_proj.weight', 'ernie.layers.10.mlp.experts.59.down_proj.weight', 'ernie.layers.10.mlp.experts.60.down_proj.weight', 'ernie.layers.10.mlp.experts.61.down_proj.weight', 'ernie.layers.10.mlp.experts.62.down_proj.weight', 'ernie.layers.10.mlp.experts.63.down_proj.weight', 'ernie.layers.10.mlp.experts.96.down_proj.weight', 'ernie.layers.10.mlp.experts.97.down_proj.weight', 'ernie.layers.10.mlp.experts.98.down_proj.weight', 'ernie.layers.10.mlp.experts.99.down_proj.weight', 'ernie.layers.10.mlp.experts.100.down_proj.weight', 'ernie.layers.10.mlp.experts.101.down_proj.weight', 'ernie.layers.10.mlp.experts.102.down_proj.weight', 'ernie.layers.10.mlp.experts.103.down_proj.weight', 'ernie.layers.10.mlp.experts.104.down_proj.weight', 'ernie.layers.10.mlp.experts.105.down_proj.weight', 'ernie.layers.10.mlp.experts.106.down_proj.weight', 'ernie.layers.10.mlp.experts.107.down_proj.weight', 'ernie.layers.10.mlp.experts.108.down_proj.weight', 'ernie.layers.10.mlp.experts.109.down_proj.weight', 'ernie.layers.10.mlp.experts.110.down_proj.weight', 'ernie.layers.10.mlp.experts.111.down_proj.weight', 'ernie.layers.10.mlp.experts.112.down_proj.weight', 'ernie.layers.10.mlp.experts.113.down_proj.weight', 'ernie.layers.10.mlp.experts.114.down_proj.weight', 'ernie.layers.10.mlp.experts.115.down_proj.weight', 'ernie.layers.10.mlp.experts.116.down_proj.weight', 'ernie.layers.10.mlp.experts.117.down_proj.weight', 'ernie.layers.10.mlp.experts.118.down_proj.weight', 'ernie.layers.10.mlp.experts.119.down_proj.weight', 'ernie.layers.10.mlp.experts.120.down_proj.weight', 'ernie.layers.10.mlp.experts.121.down_proj.weight', 'ernie.layers.10.mlp.experts.122.down_proj.weight', 'ernie.layers.10.mlp.experts.123.down_proj.weight', 'ernie.layers.10.mlp.experts.124.down_proj.weight', 'ernie.layers.10.mlp.experts.125.down_proj.weight', 'ernie.layers.10.mlp.experts.126.down_proj.weight', 'ernie.layers.10.mlp.experts.127.down_proj.weight'] ernie.layers.11.mlp.image_fused_moe.gate.weight:ernie.layers.11.mlp.gate.weight_1 -ernie.layers.11.mlp.image_fused_moe.experts.gate_correction_bias:ernie.layers.11.mlp.moe_statics.e_score_correction_bias ernie.layers.11.mlp.image_fused_moe.experts.up_gate_proj_weight:['ernie.layers.11.mlp.experts.32.up_gate_proj.weight', 'ernie.layers.11.mlp.experts.33.up_gate_proj.weight', 'ernie.layers.11.mlp.experts.34.up_gate_proj.weight', 'ernie.layers.11.mlp.experts.35.up_gate_proj.weight', 'ernie.layers.11.mlp.experts.36.up_gate_proj.weight', 'ernie.layers.11.mlp.experts.37.up_gate_proj.weight', 'ernie.layers.11.mlp.experts.38.up_gate_proj.weight', 'ernie.layers.11.mlp.experts.39.up_gate_proj.weight', 'ernie.layers.11.mlp.experts.40.up_gate_proj.weight', 'ernie.layers.11.mlp.experts.41.up_gate_proj.weight', 'ernie.layers.11.mlp.experts.42.up_gate_proj.weight', 'ernie.layers.11.mlp.experts.43.up_gate_proj.weight', 'ernie.layers.11.mlp.experts.44.up_gate_proj.weight', 'ernie.layers.11.mlp.experts.45.up_gate_proj.weight', 'ernie.layers.11.mlp.experts.46.up_gate_proj.weight', 'ernie.layers.11.mlp.experts.47.up_gate_proj.weight', 'ernie.layers.11.mlp.experts.48.up_gate_proj.weight', 'ernie.layers.11.mlp.experts.49.up_gate_proj.weight', 'ernie.layers.11.mlp.experts.50.up_gate_proj.weight', 'ernie.layers.11.mlp.experts.51.up_gate_proj.weight', 'ernie.layers.11.mlp.experts.52.up_gate_proj.weight', 'ernie.layers.11.mlp.experts.53.up_gate_proj.weight', 'ernie.layers.11.mlp.experts.54.up_gate_proj.weight', 'ernie.layers.11.mlp.experts.55.up_gate_proj.weight', 'ernie.layers.11.mlp.experts.56.up_gate_proj.weight', 'ernie.layers.11.mlp.experts.57.up_gate_proj.weight', 'ernie.layers.11.mlp.experts.58.up_gate_proj.weight', 'ernie.layers.11.mlp.experts.59.up_gate_proj.weight', 'ernie.layers.11.mlp.experts.60.up_gate_proj.weight', 'ernie.layers.11.mlp.experts.61.up_gate_proj.weight', 'ernie.layers.11.mlp.experts.62.up_gate_proj.weight', 'ernie.layers.11.mlp.experts.63.up_gate_proj.weight', 'ernie.layers.11.mlp.experts.96.up_gate_proj.weight', 'ernie.layers.11.mlp.experts.97.up_gate_proj.weight', 'ernie.layers.11.mlp.experts.98.up_gate_proj.weight', 'ernie.layers.11.mlp.experts.99.up_gate_proj.weight', 'ernie.layers.11.mlp.experts.100.up_gate_proj.weight', 'ernie.layers.11.mlp.experts.101.up_gate_proj.weight', 'ernie.layers.11.mlp.experts.102.up_gate_proj.weight', 'ernie.layers.11.mlp.experts.103.up_gate_proj.weight', 'ernie.layers.11.mlp.experts.104.up_gate_proj.weight', 'ernie.layers.11.mlp.experts.105.up_gate_proj.weight', 'ernie.layers.11.mlp.experts.106.up_gate_proj.weight', 'ernie.layers.11.mlp.experts.107.up_gate_proj.weight', 'ernie.layers.11.mlp.experts.108.up_gate_proj.weight', 'ernie.layers.11.mlp.experts.109.up_gate_proj.weight', 'ernie.layers.11.mlp.experts.110.up_gate_proj.weight', 'ernie.layers.11.mlp.experts.111.up_gate_proj.weight', 'ernie.layers.11.mlp.experts.112.up_gate_proj.weight', 'ernie.layers.11.mlp.experts.113.up_gate_proj.weight', 'ernie.layers.11.mlp.experts.114.up_gate_proj.weight', 'ernie.layers.11.mlp.experts.115.up_gate_proj.weight', 'ernie.layers.11.mlp.experts.116.up_gate_proj.weight', 'ernie.layers.11.mlp.experts.117.up_gate_proj.weight', 'ernie.layers.11.mlp.experts.118.up_gate_proj.weight', 'ernie.layers.11.mlp.experts.119.up_gate_proj.weight', 'ernie.layers.11.mlp.experts.120.up_gate_proj.weight', 'ernie.layers.11.mlp.experts.121.up_gate_proj.weight', 'ernie.layers.11.mlp.experts.122.up_gate_proj.weight', 'ernie.layers.11.mlp.experts.123.up_gate_proj.weight', 'ernie.layers.11.mlp.experts.124.up_gate_proj.weight', 'ernie.layers.11.mlp.experts.125.up_gate_proj.weight', 'ernie.layers.11.mlp.experts.126.up_gate_proj.weight', 'ernie.layers.11.mlp.experts.127.up_gate_proj.weight'] ernie.layers.11.mlp.image_fused_moe.experts.down_proj_weight:['ernie.layers.11.mlp.experts.32.down_proj.weight', 'ernie.layers.11.mlp.experts.33.down_proj.weight', 'ernie.layers.11.mlp.experts.34.down_proj.weight', 'ernie.layers.11.mlp.experts.35.down_proj.weight', 'ernie.layers.11.mlp.experts.36.down_proj.weight', 'ernie.layers.11.mlp.experts.37.down_proj.weight', 'ernie.layers.11.mlp.experts.38.down_proj.weight', 'ernie.layers.11.mlp.experts.39.down_proj.weight', 'ernie.layers.11.mlp.experts.40.down_proj.weight', 'ernie.layers.11.mlp.experts.41.down_proj.weight', 'ernie.layers.11.mlp.experts.42.down_proj.weight', 'ernie.layers.11.mlp.experts.43.down_proj.weight', 'ernie.layers.11.mlp.experts.44.down_proj.weight', 'ernie.layers.11.mlp.experts.45.down_proj.weight', 'ernie.layers.11.mlp.experts.46.down_proj.weight', 'ernie.layers.11.mlp.experts.47.down_proj.weight', 'ernie.layers.11.mlp.experts.48.down_proj.weight', 'ernie.layers.11.mlp.experts.49.down_proj.weight', 'ernie.layers.11.mlp.experts.50.down_proj.weight', 'ernie.layers.11.mlp.experts.51.down_proj.weight', 'ernie.layers.11.mlp.experts.52.down_proj.weight', 'ernie.layers.11.mlp.experts.53.down_proj.weight', 'ernie.layers.11.mlp.experts.54.down_proj.weight', 'ernie.layers.11.mlp.experts.55.down_proj.weight', 'ernie.layers.11.mlp.experts.56.down_proj.weight', 'ernie.layers.11.mlp.experts.57.down_proj.weight', 'ernie.layers.11.mlp.experts.58.down_proj.weight', 'ernie.layers.11.mlp.experts.59.down_proj.weight', 'ernie.layers.11.mlp.experts.60.down_proj.weight', 'ernie.layers.11.mlp.experts.61.down_proj.weight', 'ernie.layers.11.mlp.experts.62.down_proj.weight', 'ernie.layers.11.mlp.experts.63.down_proj.weight', 'ernie.layers.11.mlp.experts.96.down_proj.weight', 'ernie.layers.11.mlp.experts.97.down_proj.weight', 'ernie.layers.11.mlp.experts.98.down_proj.weight', 'ernie.layers.11.mlp.experts.99.down_proj.weight', 'ernie.layers.11.mlp.experts.100.down_proj.weight', 'ernie.layers.11.mlp.experts.101.down_proj.weight', 'ernie.layers.11.mlp.experts.102.down_proj.weight', 'ernie.layers.11.mlp.experts.103.down_proj.weight', 'ernie.layers.11.mlp.experts.104.down_proj.weight', 'ernie.layers.11.mlp.experts.105.down_proj.weight', 'ernie.layers.11.mlp.experts.106.down_proj.weight', 'ernie.layers.11.mlp.experts.107.down_proj.weight', 'ernie.layers.11.mlp.experts.108.down_proj.weight', 'ernie.layers.11.mlp.experts.109.down_proj.weight', 'ernie.layers.11.mlp.experts.110.down_proj.weight', 'ernie.layers.11.mlp.experts.111.down_proj.weight', 'ernie.layers.11.mlp.experts.112.down_proj.weight', 'ernie.layers.11.mlp.experts.113.down_proj.weight', 'ernie.layers.11.mlp.experts.114.down_proj.weight', 'ernie.layers.11.mlp.experts.115.down_proj.weight', 'ernie.layers.11.mlp.experts.116.down_proj.weight', 'ernie.layers.11.mlp.experts.117.down_proj.weight', 'ernie.layers.11.mlp.experts.118.down_proj.weight', 'ernie.layers.11.mlp.experts.119.down_proj.weight', 'ernie.layers.11.mlp.experts.120.down_proj.weight', 'ernie.layers.11.mlp.experts.121.down_proj.weight', 'ernie.layers.11.mlp.experts.122.down_proj.weight', 'ernie.layers.11.mlp.experts.123.down_proj.weight', 'ernie.layers.11.mlp.experts.124.down_proj.weight', 'ernie.layers.11.mlp.experts.125.down_proj.weight', 'ernie.layers.11.mlp.experts.126.down_proj.weight', 'ernie.layers.11.mlp.experts.127.down_proj.weight'] ernie.layers.12.mlp.image_fused_moe.gate.weight:ernie.layers.12.mlp.gate.weight_1 -ernie.layers.12.mlp.image_fused_moe.experts.gate_correction_bias:ernie.layers.12.mlp.moe_statics.e_score_correction_bias ernie.layers.12.mlp.image_fused_moe.experts.up_gate_proj_weight:['ernie.layers.12.mlp.experts.32.up_gate_proj.weight', 'ernie.layers.12.mlp.experts.33.up_gate_proj.weight', 'ernie.layers.12.mlp.experts.34.up_gate_proj.weight', 'ernie.layers.12.mlp.experts.35.up_gate_proj.weight', 'ernie.layers.12.mlp.experts.36.up_gate_proj.weight', 'ernie.layers.12.mlp.experts.37.up_gate_proj.weight', 'ernie.layers.12.mlp.experts.38.up_gate_proj.weight', 'ernie.layers.12.mlp.experts.39.up_gate_proj.weight', 'ernie.layers.12.mlp.experts.40.up_gate_proj.weight', 'ernie.layers.12.mlp.experts.41.up_gate_proj.weight', 'ernie.layers.12.mlp.experts.42.up_gate_proj.weight', 'ernie.layers.12.mlp.experts.43.up_gate_proj.weight', 'ernie.layers.12.mlp.experts.44.up_gate_proj.weight', 'ernie.layers.12.mlp.experts.45.up_gate_proj.weight', 'ernie.layers.12.mlp.experts.46.up_gate_proj.weight', 'ernie.layers.12.mlp.experts.47.up_gate_proj.weight', 'ernie.layers.12.mlp.experts.48.up_gate_proj.weight', 'ernie.layers.12.mlp.experts.49.up_gate_proj.weight', 'ernie.layers.12.mlp.experts.50.up_gate_proj.weight', 'ernie.layers.12.mlp.experts.51.up_gate_proj.weight', 'ernie.layers.12.mlp.experts.52.up_gate_proj.weight', 'ernie.layers.12.mlp.experts.53.up_gate_proj.weight', 'ernie.layers.12.mlp.experts.54.up_gate_proj.weight', 'ernie.layers.12.mlp.experts.55.up_gate_proj.weight', 'ernie.layers.12.mlp.experts.56.up_gate_proj.weight', 'ernie.layers.12.mlp.experts.57.up_gate_proj.weight', 'ernie.layers.12.mlp.experts.58.up_gate_proj.weight', 'ernie.layers.12.mlp.experts.59.up_gate_proj.weight', 'ernie.layers.12.mlp.experts.60.up_gate_proj.weight', 'ernie.layers.12.mlp.experts.61.up_gate_proj.weight', 'ernie.layers.12.mlp.experts.62.up_gate_proj.weight', 'ernie.layers.12.mlp.experts.63.up_gate_proj.weight', 'ernie.layers.12.mlp.experts.96.up_gate_proj.weight', 'ernie.layers.12.mlp.experts.97.up_gate_proj.weight', 'ernie.layers.12.mlp.experts.98.up_gate_proj.weight', 'ernie.layers.12.mlp.experts.99.up_gate_proj.weight', 'ernie.layers.12.mlp.experts.100.up_gate_proj.weight', 'ernie.layers.12.mlp.experts.101.up_gate_proj.weight', 'ernie.layers.12.mlp.experts.102.up_gate_proj.weight', 'ernie.layers.12.mlp.experts.103.up_gate_proj.weight', 'ernie.layers.12.mlp.experts.104.up_gate_proj.weight', 'ernie.layers.12.mlp.experts.105.up_gate_proj.weight', 'ernie.layers.12.mlp.experts.106.up_gate_proj.weight', 'ernie.layers.12.mlp.experts.107.up_gate_proj.weight', 'ernie.layers.12.mlp.experts.108.up_gate_proj.weight', 'ernie.layers.12.mlp.experts.109.up_gate_proj.weight', 'ernie.layers.12.mlp.experts.110.up_gate_proj.weight', 'ernie.layers.12.mlp.experts.111.up_gate_proj.weight', 'ernie.layers.12.mlp.experts.112.up_gate_proj.weight', 'ernie.layers.12.mlp.experts.113.up_gate_proj.weight', 'ernie.layers.12.mlp.experts.114.up_gate_proj.weight', 'ernie.layers.12.mlp.experts.115.up_gate_proj.weight', 'ernie.layers.12.mlp.experts.116.up_gate_proj.weight', 'ernie.layers.12.mlp.experts.117.up_gate_proj.weight', 'ernie.layers.12.mlp.experts.118.up_gate_proj.weight', 'ernie.layers.12.mlp.experts.119.up_gate_proj.weight', 'ernie.layers.12.mlp.experts.120.up_gate_proj.weight', 'ernie.layers.12.mlp.experts.121.up_gate_proj.weight', 'ernie.layers.12.mlp.experts.122.up_gate_proj.weight', 'ernie.layers.12.mlp.experts.123.up_gate_proj.weight', 'ernie.layers.12.mlp.experts.124.up_gate_proj.weight', 'ernie.layers.12.mlp.experts.125.up_gate_proj.weight', 'ernie.layers.12.mlp.experts.126.up_gate_proj.weight', 'ernie.layers.12.mlp.experts.127.up_gate_proj.weight'] ernie.layers.12.mlp.image_fused_moe.experts.down_proj_weight:['ernie.layers.12.mlp.experts.32.down_proj.weight', 'ernie.layers.12.mlp.experts.33.down_proj.weight', 'ernie.layers.12.mlp.experts.34.down_proj.weight', 'ernie.layers.12.mlp.experts.35.down_proj.weight', 'ernie.layers.12.mlp.experts.36.down_proj.weight', 'ernie.layers.12.mlp.experts.37.down_proj.weight', 'ernie.layers.12.mlp.experts.38.down_proj.weight', 'ernie.layers.12.mlp.experts.39.down_proj.weight', 'ernie.layers.12.mlp.experts.40.down_proj.weight', 'ernie.layers.12.mlp.experts.41.down_proj.weight', 'ernie.layers.12.mlp.experts.42.down_proj.weight', 'ernie.layers.12.mlp.experts.43.down_proj.weight', 'ernie.layers.12.mlp.experts.44.down_proj.weight', 'ernie.layers.12.mlp.experts.45.down_proj.weight', 'ernie.layers.12.mlp.experts.46.down_proj.weight', 'ernie.layers.12.mlp.experts.47.down_proj.weight', 'ernie.layers.12.mlp.experts.48.down_proj.weight', 'ernie.layers.12.mlp.experts.49.down_proj.weight', 'ernie.layers.12.mlp.experts.50.down_proj.weight', 'ernie.layers.12.mlp.experts.51.down_proj.weight', 'ernie.layers.12.mlp.experts.52.down_proj.weight', 'ernie.layers.12.mlp.experts.53.down_proj.weight', 'ernie.layers.12.mlp.experts.54.down_proj.weight', 'ernie.layers.12.mlp.experts.55.down_proj.weight', 'ernie.layers.12.mlp.experts.56.down_proj.weight', 'ernie.layers.12.mlp.experts.57.down_proj.weight', 'ernie.layers.12.mlp.experts.58.down_proj.weight', 'ernie.layers.12.mlp.experts.59.down_proj.weight', 'ernie.layers.12.mlp.experts.60.down_proj.weight', 'ernie.layers.12.mlp.experts.61.down_proj.weight', 'ernie.layers.12.mlp.experts.62.down_proj.weight', 'ernie.layers.12.mlp.experts.63.down_proj.weight', 'ernie.layers.12.mlp.experts.96.down_proj.weight', 'ernie.layers.12.mlp.experts.97.down_proj.weight', 'ernie.layers.12.mlp.experts.98.down_proj.weight', 'ernie.layers.12.mlp.experts.99.down_proj.weight', 'ernie.layers.12.mlp.experts.100.down_proj.weight', 'ernie.layers.12.mlp.experts.101.down_proj.weight', 'ernie.layers.12.mlp.experts.102.down_proj.weight', 'ernie.layers.12.mlp.experts.103.down_proj.weight', 'ernie.layers.12.mlp.experts.104.down_proj.weight', 'ernie.layers.12.mlp.experts.105.down_proj.weight', 'ernie.layers.12.mlp.experts.106.down_proj.weight', 'ernie.layers.12.mlp.experts.107.down_proj.weight', 'ernie.layers.12.mlp.experts.108.down_proj.weight', 'ernie.layers.12.mlp.experts.109.down_proj.weight', 'ernie.layers.12.mlp.experts.110.down_proj.weight', 'ernie.layers.12.mlp.experts.111.down_proj.weight', 'ernie.layers.12.mlp.experts.112.down_proj.weight', 'ernie.layers.12.mlp.experts.113.down_proj.weight', 'ernie.layers.12.mlp.experts.114.down_proj.weight', 'ernie.layers.12.mlp.experts.115.down_proj.weight', 'ernie.layers.12.mlp.experts.116.down_proj.weight', 'ernie.layers.12.mlp.experts.117.down_proj.weight', 'ernie.layers.12.mlp.experts.118.down_proj.weight', 'ernie.layers.12.mlp.experts.119.down_proj.weight', 'ernie.layers.12.mlp.experts.120.down_proj.weight', 'ernie.layers.12.mlp.experts.121.down_proj.weight', 'ernie.layers.12.mlp.experts.122.down_proj.weight', 'ernie.layers.12.mlp.experts.123.down_proj.weight', 'ernie.layers.12.mlp.experts.124.down_proj.weight', 'ernie.layers.12.mlp.experts.125.down_proj.weight', 'ernie.layers.12.mlp.experts.126.down_proj.weight', 'ernie.layers.12.mlp.experts.127.down_proj.weight'] ernie.layers.13.mlp.image_fused_moe.gate.weight:ernie.layers.13.mlp.gate.weight_1 -ernie.layers.13.mlp.image_fused_moe.experts.gate_correction_bias:ernie.layers.13.mlp.moe_statics.e_score_correction_bias ernie.layers.13.mlp.image_fused_moe.experts.up_gate_proj_weight:['ernie.layers.13.mlp.experts.32.up_gate_proj.weight', 'ernie.layers.13.mlp.experts.33.up_gate_proj.weight', 'ernie.layers.13.mlp.experts.34.up_gate_proj.weight', 'ernie.layers.13.mlp.experts.35.up_gate_proj.weight', 'ernie.layers.13.mlp.experts.36.up_gate_proj.weight', 'ernie.layers.13.mlp.experts.37.up_gate_proj.weight', 'ernie.layers.13.mlp.experts.38.up_gate_proj.weight', 'ernie.layers.13.mlp.experts.39.up_gate_proj.weight', 'ernie.layers.13.mlp.experts.40.up_gate_proj.weight', 'ernie.layers.13.mlp.experts.41.up_gate_proj.weight', 'ernie.layers.13.mlp.experts.42.up_gate_proj.weight', 'ernie.layers.13.mlp.experts.43.up_gate_proj.weight', 'ernie.layers.13.mlp.experts.44.up_gate_proj.weight', 'ernie.layers.13.mlp.experts.45.up_gate_proj.weight', 'ernie.layers.13.mlp.experts.46.up_gate_proj.weight', 'ernie.layers.13.mlp.experts.47.up_gate_proj.weight', 'ernie.layers.13.mlp.experts.48.up_gate_proj.weight', 'ernie.layers.13.mlp.experts.49.up_gate_proj.weight', 'ernie.layers.13.mlp.experts.50.up_gate_proj.weight', 'ernie.layers.13.mlp.experts.51.up_gate_proj.weight', 'ernie.layers.13.mlp.experts.52.up_gate_proj.weight', 'ernie.layers.13.mlp.experts.53.up_gate_proj.weight', 'ernie.layers.13.mlp.experts.54.up_gate_proj.weight', 'ernie.layers.13.mlp.experts.55.up_gate_proj.weight', 'ernie.layers.13.mlp.experts.56.up_gate_proj.weight', 'ernie.layers.13.mlp.experts.57.up_gate_proj.weight', 'ernie.layers.13.mlp.experts.58.up_gate_proj.weight', 'ernie.layers.13.mlp.experts.59.up_gate_proj.weight', 'ernie.layers.13.mlp.experts.60.up_gate_proj.weight', 'ernie.layers.13.mlp.experts.61.up_gate_proj.weight', 'ernie.layers.13.mlp.experts.62.up_gate_proj.weight', 'ernie.layers.13.mlp.experts.63.up_gate_proj.weight', 'ernie.layers.13.mlp.experts.96.up_gate_proj.weight', 'ernie.layers.13.mlp.experts.97.up_gate_proj.weight', 'ernie.layers.13.mlp.experts.98.up_gate_proj.weight', 'ernie.layers.13.mlp.experts.99.up_gate_proj.weight', 'ernie.layers.13.mlp.experts.100.up_gate_proj.weight', 'ernie.layers.13.mlp.experts.101.up_gate_proj.weight', 'ernie.layers.13.mlp.experts.102.up_gate_proj.weight', 'ernie.layers.13.mlp.experts.103.up_gate_proj.weight', 'ernie.layers.13.mlp.experts.104.up_gate_proj.weight', 'ernie.layers.13.mlp.experts.105.up_gate_proj.weight', 'ernie.layers.13.mlp.experts.106.up_gate_proj.weight', 'ernie.layers.13.mlp.experts.107.up_gate_proj.weight', 'ernie.layers.13.mlp.experts.108.up_gate_proj.weight', 'ernie.layers.13.mlp.experts.109.up_gate_proj.weight', 'ernie.layers.13.mlp.experts.110.up_gate_proj.weight', 'ernie.layers.13.mlp.experts.111.up_gate_proj.weight', 'ernie.layers.13.mlp.experts.112.up_gate_proj.weight', 'ernie.layers.13.mlp.experts.113.up_gate_proj.weight', 'ernie.layers.13.mlp.experts.114.up_gate_proj.weight', 'ernie.layers.13.mlp.experts.115.up_gate_proj.weight', 'ernie.layers.13.mlp.experts.116.up_gate_proj.weight', 'ernie.layers.13.mlp.experts.117.up_gate_proj.weight', 'ernie.layers.13.mlp.experts.118.up_gate_proj.weight', 'ernie.layers.13.mlp.experts.119.up_gate_proj.weight', 'ernie.layers.13.mlp.experts.120.up_gate_proj.weight', 'ernie.layers.13.mlp.experts.121.up_gate_proj.weight', 'ernie.layers.13.mlp.experts.122.up_gate_proj.weight', 'ernie.layers.13.mlp.experts.123.up_gate_proj.weight', 'ernie.layers.13.mlp.experts.124.up_gate_proj.weight', 'ernie.layers.13.mlp.experts.125.up_gate_proj.weight', 'ernie.layers.13.mlp.experts.126.up_gate_proj.weight', 'ernie.layers.13.mlp.experts.127.up_gate_proj.weight'] ernie.layers.13.mlp.image_fused_moe.experts.down_proj_weight:['ernie.layers.13.mlp.experts.32.down_proj.weight', 'ernie.layers.13.mlp.experts.33.down_proj.weight', 'ernie.layers.13.mlp.experts.34.down_proj.weight', 'ernie.layers.13.mlp.experts.35.down_proj.weight', 'ernie.layers.13.mlp.experts.36.down_proj.weight', 'ernie.layers.13.mlp.experts.37.down_proj.weight', 'ernie.layers.13.mlp.experts.38.down_proj.weight', 'ernie.layers.13.mlp.experts.39.down_proj.weight', 'ernie.layers.13.mlp.experts.40.down_proj.weight', 'ernie.layers.13.mlp.experts.41.down_proj.weight', 'ernie.layers.13.mlp.experts.42.down_proj.weight', 'ernie.layers.13.mlp.experts.43.down_proj.weight', 'ernie.layers.13.mlp.experts.44.down_proj.weight', 'ernie.layers.13.mlp.experts.45.down_proj.weight', 'ernie.layers.13.mlp.experts.46.down_proj.weight', 'ernie.layers.13.mlp.experts.47.down_proj.weight', 'ernie.layers.13.mlp.experts.48.down_proj.weight', 'ernie.layers.13.mlp.experts.49.down_proj.weight', 'ernie.layers.13.mlp.experts.50.down_proj.weight', 'ernie.layers.13.mlp.experts.51.down_proj.weight', 'ernie.layers.13.mlp.experts.52.down_proj.weight', 'ernie.layers.13.mlp.experts.53.down_proj.weight', 'ernie.layers.13.mlp.experts.54.down_proj.weight', 'ernie.layers.13.mlp.experts.55.down_proj.weight', 'ernie.layers.13.mlp.experts.56.down_proj.weight', 'ernie.layers.13.mlp.experts.57.down_proj.weight', 'ernie.layers.13.mlp.experts.58.down_proj.weight', 'ernie.layers.13.mlp.experts.59.down_proj.weight', 'ernie.layers.13.mlp.experts.60.down_proj.weight', 'ernie.layers.13.mlp.experts.61.down_proj.weight', 'ernie.layers.13.mlp.experts.62.down_proj.weight', 'ernie.layers.13.mlp.experts.63.down_proj.weight', 'ernie.layers.13.mlp.experts.96.down_proj.weight', 'ernie.layers.13.mlp.experts.97.down_proj.weight', 'ernie.layers.13.mlp.experts.98.down_proj.weight', 'ernie.layers.13.mlp.experts.99.down_proj.weight', 'ernie.layers.13.mlp.experts.100.down_proj.weight', 'ernie.layers.13.mlp.experts.101.down_proj.weight', 'ernie.layers.13.mlp.experts.102.down_proj.weight', 'ernie.layers.13.mlp.experts.103.down_proj.weight', 'ernie.layers.13.mlp.experts.104.down_proj.weight', 'ernie.layers.13.mlp.experts.105.down_proj.weight', 'ernie.layers.13.mlp.experts.106.down_proj.weight', 'ernie.layers.13.mlp.experts.107.down_proj.weight', 'ernie.layers.13.mlp.experts.108.down_proj.weight', 'ernie.layers.13.mlp.experts.109.down_proj.weight', 'ernie.layers.13.mlp.experts.110.down_proj.weight', 'ernie.layers.13.mlp.experts.111.down_proj.weight', 'ernie.layers.13.mlp.experts.112.down_proj.weight', 'ernie.layers.13.mlp.experts.113.down_proj.weight', 'ernie.layers.13.mlp.experts.114.down_proj.weight', 'ernie.layers.13.mlp.experts.115.down_proj.weight', 'ernie.layers.13.mlp.experts.116.down_proj.weight', 'ernie.layers.13.mlp.experts.117.down_proj.weight', 'ernie.layers.13.mlp.experts.118.down_proj.weight', 'ernie.layers.13.mlp.experts.119.down_proj.weight', 'ernie.layers.13.mlp.experts.120.down_proj.weight', 'ernie.layers.13.mlp.experts.121.down_proj.weight', 'ernie.layers.13.mlp.experts.122.down_proj.weight', 'ernie.layers.13.mlp.experts.123.down_proj.weight', 'ernie.layers.13.mlp.experts.124.down_proj.weight', 'ernie.layers.13.mlp.experts.125.down_proj.weight', 'ernie.layers.13.mlp.experts.126.down_proj.weight', 'ernie.layers.13.mlp.experts.127.down_proj.weight'] ernie.layers.14.mlp.image_fused_moe.gate.weight:ernie.layers.14.mlp.gate.weight_1 -ernie.layers.14.mlp.image_fused_moe.experts.gate_correction_bias:ernie.layers.14.mlp.moe_statics.e_score_correction_bias ernie.layers.14.mlp.image_fused_moe.experts.up_gate_proj_weight:['ernie.layers.14.mlp.experts.32.up_gate_proj.weight', 'ernie.layers.14.mlp.experts.33.up_gate_proj.weight', 'ernie.layers.14.mlp.experts.34.up_gate_proj.weight', 'ernie.layers.14.mlp.experts.35.up_gate_proj.weight', 'ernie.layers.14.mlp.experts.36.up_gate_proj.weight', 'ernie.layers.14.mlp.experts.37.up_gate_proj.weight', 'ernie.layers.14.mlp.experts.38.up_gate_proj.weight', 'ernie.layers.14.mlp.experts.39.up_gate_proj.weight', 'ernie.layers.14.mlp.experts.40.up_gate_proj.weight', 'ernie.layers.14.mlp.experts.41.up_gate_proj.weight', 'ernie.layers.14.mlp.experts.42.up_gate_proj.weight', 'ernie.layers.14.mlp.experts.43.up_gate_proj.weight', 'ernie.layers.14.mlp.experts.44.up_gate_proj.weight', 'ernie.layers.14.mlp.experts.45.up_gate_proj.weight', 'ernie.layers.14.mlp.experts.46.up_gate_proj.weight', 'ernie.layers.14.mlp.experts.47.up_gate_proj.weight', 'ernie.layers.14.mlp.experts.48.up_gate_proj.weight', 'ernie.layers.14.mlp.experts.49.up_gate_proj.weight', 'ernie.layers.14.mlp.experts.50.up_gate_proj.weight', 'ernie.layers.14.mlp.experts.51.up_gate_proj.weight', 'ernie.layers.14.mlp.experts.52.up_gate_proj.weight', 'ernie.layers.14.mlp.experts.53.up_gate_proj.weight', 'ernie.layers.14.mlp.experts.54.up_gate_proj.weight', 'ernie.layers.14.mlp.experts.55.up_gate_proj.weight', 'ernie.layers.14.mlp.experts.56.up_gate_proj.weight', 'ernie.layers.14.mlp.experts.57.up_gate_proj.weight', 'ernie.layers.14.mlp.experts.58.up_gate_proj.weight', 'ernie.layers.14.mlp.experts.59.up_gate_proj.weight', 'ernie.layers.14.mlp.experts.60.up_gate_proj.weight', 'ernie.layers.14.mlp.experts.61.up_gate_proj.weight', 'ernie.layers.14.mlp.experts.62.up_gate_proj.weight', 'ernie.layers.14.mlp.experts.63.up_gate_proj.weight', 'ernie.layers.14.mlp.experts.96.up_gate_proj.weight', 'ernie.layers.14.mlp.experts.97.up_gate_proj.weight', 'ernie.layers.14.mlp.experts.98.up_gate_proj.weight', 'ernie.layers.14.mlp.experts.99.up_gate_proj.weight', 'ernie.layers.14.mlp.experts.100.up_gate_proj.weight', 'ernie.layers.14.mlp.experts.101.up_gate_proj.weight', 'ernie.layers.14.mlp.experts.102.up_gate_proj.weight', 'ernie.layers.14.mlp.experts.103.up_gate_proj.weight', 'ernie.layers.14.mlp.experts.104.up_gate_proj.weight', 'ernie.layers.14.mlp.experts.105.up_gate_proj.weight', 'ernie.layers.14.mlp.experts.106.up_gate_proj.weight', 'ernie.layers.14.mlp.experts.107.up_gate_proj.weight', 'ernie.layers.14.mlp.experts.108.up_gate_proj.weight', 'ernie.layers.14.mlp.experts.109.up_gate_proj.weight', 'ernie.layers.14.mlp.experts.110.up_gate_proj.weight', 'ernie.layers.14.mlp.experts.111.up_gate_proj.weight', 'ernie.layers.14.mlp.experts.112.up_gate_proj.weight', 'ernie.layers.14.mlp.experts.113.up_gate_proj.weight', 'ernie.layers.14.mlp.experts.114.up_gate_proj.weight', 'ernie.layers.14.mlp.experts.115.up_gate_proj.weight', 'ernie.layers.14.mlp.experts.116.up_gate_proj.weight', 'ernie.layers.14.mlp.experts.117.up_gate_proj.weight', 'ernie.layers.14.mlp.experts.118.up_gate_proj.weight', 'ernie.layers.14.mlp.experts.119.up_gate_proj.weight', 'ernie.layers.14.mlp.experts.120.up_gate_proj.weight', 'ernie.layers.14.mlp.experts.121.up_gate_proj.weight', 'ernie.layers.14.mlp.experts.122.up_gate_proj.weight', 'ernie.layers.14.mlp.experts.123.up_gate_proj.weight', 'ernie.layers.14.mlp.experts.124.up_gate_proj.weight', 'ernie.layers.14.mlp.experts.125.up_gate_proj.weight', 'ernie.layers.14.mlp.experts.126.up_gate_proj.weight', 'ernie.layers.14.mlp.experts.127.up_gate_proj.weight'] ernie.layers.14.mlp.image_fused_moe.experts.down_proj_weight:['ernie.layers.14.mlp.experts.32.down_proj.weight', 'ernie.layers.14.mlp.experts.33.down_proj.weight', 'ernie.layers.14.mlp.experts.34.down_proj.weight', 'ernie.layers.14.mlp.experts.35.down_proj.weight', 'ernie.layers.14.mlp.experts.36.down_proj.weight', 'ernie.layers.14.mlp.experts.37.down_proj.weight', 'ernie.layers.14.mlp.experts.38.down_proj.weight', 'ernie.layers.14.mlp.experts.39.down_proj.weight', 'ernie.layers.14.mlp.experts.40.down_proj.weight', 'ernie.layers.14.mlp.experts.41.down_proj.weight', 'ernie.layers.14.mlp.experts.42.down_proj.weight', 'ernie.layers.14.mlp.experts.43.down_proj.weight', 'ernie.layers.14.mlp.experts.44.down_proj.weight', 'ernie.layers.14.mlp.experts.45.down_proj.weight', 'ernie.layers.14.mlp.experts.46.down_proj.weight', 'ernie.layers.14.mlp.experts.47.down_proj.weight', 'ernie.layers.14.mlp.experts.48.down_proj.weight', 'ernie.layers.14.mlp.experts.49.down_proj.weight', 'ernie.layers.14.mlp.experts.50.down_proj.weight', 'ernie.layers.14.mlp.experts.51.down_proj.weight', 'ernie.layers.14.mlp.experts.52.down_proj.weight', 'ernie.layers.14.mlp.experts.53.down_proj.weight', 'ernie.layers.14.mlp.experts.54.down_proj.weight', 'ernie.layers.14.mlp.experts.55.down_proj.weight', 'ernie.layers.14.mlp.experts.56.down_proj.weight', 'ernie.layers.14.mlp.experts.57.down_proj.weight', 'ernie.layers.14.mlp.experts.58.down_proj.weight', 'ernie.layers.14.mlp.experts.59.down_proj.weight', 'ernie.layers.14.mlp.experts.60.down_proj.weight', 'ernie.layers.14.mlp.experts.61.down_proj.weight', 'ernie.layers.14.mlp.experts.62.down_proj.weight', 'ernie.layers.14.mlp.experts.63.down_proj.weight', 'ernie.layers.14.mlp.experts.96.down_proj.weight', 'ernie.layers.14.mlp.experts.97.down_proj.weight', 'ernie.layers.14.mlp.experts.98.down_proj.weight', 'ernie.layers.14.mlp.experts.99.down_proj.weight', 'ernie.layers.14.mlp.experts.100.down_proj.weight', 'ernie.layers.14.mlp.experts.101.down_proj.weight', 'ernie.layers.14.mlp.experts.102.down_proj.weight', 'ernie.layers.14.mlp.experts.103.down_proj.weight', 'ernie.layers.14.mlp.experts.104.down_proj.weight', 'ernie.layers.14.mlp.experts.105.down_proj.weight', 'ernie.layers.14.mlp.experts.106.down_proj.weight', 'ernie.layers.14.mlp.experts.107.down_proj.weight', 'ernie.layers.14.mlp.experts.108.down_proj.weight', 'ernie.layers.14.mlp.experts.109.down_proj.weight', 'ernie.layers.14.mlp.experts.110.down_proj.weight', 'ernie.layers.14.mlp.experts.111.down_proj.weight', 'ernie.layers.14.mlp.experts.112.down_proj.weight', 'ernie.layers.14.mlp.experts.113.down_proj.weight', 'ernie.layers.14.mlp.experts.114.down_proj.weight', 'ernie.layers.14.mlp.experts.115.down_proj.weight', 'ernie.layers.14.mlp.experts.116.down_proj.weight', 'ernie.layers.14.mlp.experts.117.down_proj.weight', 'ernie.layers.14.mlp.experts.118.down_proj.weight', 'ernie.layers.14.mlp.experts.119.down_proj.weight', 'ernie.layers.14.mlp.experts.120.down_proj.weight', 'ernie.layers.14.mlp.experts.121.down_proj.weight', 'ernie.layers.14.mlp.experts.122.down_proj.weight', 'ernie.layers.14.mlp.experts.123.down_proj.weight', 'ernie.layers.14.mlp.experts.124.down_proj.weight', 'ernie.layers.14.mlp.experts.125.down_proj.weight', 'ernie.layers.14.mlp.experts.126.down_proj.weight', 'ernie.layers.14.mlp.experts.127.down_proj.weight'] ernie.layers.15.mlp.image_fused_moe.gate.weight:ernie.layers.15.mlp.gate.weight_1 -ernie.layers.15.mlp.image_fused_moe.experts.gate_correction_bias:ernie.layers.15.mlp.moe_statics.e_score_correction_bias ernie.layers.15.mlp.image_fused_moe.experts.up_gate_proj_weight:['ernie.layers.15.mlp.experts.32.up_gate_proj.weight', 'ernie.layers.15.mlp.experts.33.up_gate_proj.weight', 'ernie.layers.15.mlp.experts.34.up_gate_proj.weight', 'ernie.layers.15.mlp.experts.35.up_gate_proj.weight', 'ernie.layers.15.mlp.experts.36.up_gate_proj.weight', 'ernie.layers.15.mlp.experts.37.up_gate_proj.weight', 'ernie.layers.15.mlp.experts.38.up_gate_proj.weight', 'ernie.layers.15.mlp.experts.39.up_gate_proj.weight', 'ernie.layers.15.mlp.experts.40.up_gate_proj.weight', 'ernie.layers.15.mlp.experts.41.up_gate_proj.weight', 'ernie.layers.15.mlp.experts.42.up_gate_proj.weight', 'ernie.layers.15.mlp.experts.43.up_gate_proj.weight', 'ernie.layers.15.mlp.experts.44.up_gate_proj.weight', 'ernie.layers.15.mlp.experts.45.up_gate_proj.weight', 'ernie.layers.15.mlp.experts.46.up_gate_proj.weight', 'ernie.layers.15.mlp.experts.47.up_gate_proj.weight', 'ernie.layers.15.mlp.experts.48.up_gate_proj.weight', 'ernie.layers.15.mlp.experts.49.up_gate_proj.weight', 'ernie.layers.15.mlp.experts.50.up_gate_proj.weight', 'ernie.layers.15.mlp.experts.51.up_gate_proj.weight', 'ernie.layers.15.mlp.experts.52.up_gate_proj.weight', 'ernie.layers.15.mlp.experts.53.up_gate_proj.weight', 'ernie.layers.15.mlp.experts.54.up_gate_proj.weight', 'ernie.layers.15.mlp.experts.55.up_gate_proj.weight', 'ernie.layers.15.mlp.experts.56.up_gate_proj.weight', 'ernie.layers.15.mlp.experts.57.up_gate_proj.weight', 'ernie.layers.15.mlp.experts.58.up_gate_proj.weight', 'ernie.layers.15.mlp.experts.59.up_gate_proj.weight', 'ernie.layers.15.mlp.experts.60.up_gate_proj.weight', 'ernie.layers.15.mlp.experts.61.up_gate_proj.weight', 'ernie.layers.15.mlp.experts.62.up_gate_proj.weight', 'ernie.layers.15.mlp.experts.63.up_gate_proj.weight', 'ernie.layers.15.mlp.experts.96.up_gate_proj.weight', 'ernie.layers.15.mlp.experts.97.up_gate_proj.weight', 'ernie.layers.15.mlp.experts.98.up_gate_proj.weight', 'ernie.layers.15.mlp.experts.99.up_gate_proj.weight', 'ernie.layers.15.mlp.experts.100.up_gate_proj.weight', 'ernie.layers.15.mlp.experts.101.up_gate_proj.weight', 'ernie.layers.15.mlp.experts.102.up_gate_proj.weight', 'ernie.layers.15.mlp.experts.103.up_gate_proj.weight', 'ernie.layers.15.mlp.experts.104.up_gate_proj.weight', 'ernie.layers.15.mlp.experts.105.up_gate_proj.weight', 'ernie.layers.15.mlp.experts.106.up_gate_proj.weight', 'ernie.layers.15.mlp.experts.107.up_gate_proj.weight', 'ernie.layers.15.mlp.experts.108.up_gate_proj.weight', 'ernie.layers.15.mlp.experts.109.up_gate_proj.weight', 'ernie.layers.15.mlp.experts.110.up_gate_proj.weight', 'ernie.layers.15.mlp.experts.111.up_gate_proj.weight', 'ernie.layers.15.mlp.experts.112.up_gate_proj.weight', 'ernie.layers.15.mlp.experts.113.up_gate_proj.weight', 'ernie.layers.15.mlp.experts.114.up_gate_proj.weight', 'ernie.layers.15.mlp.experts.115.up_gate_proj.weight', 'ernie.layers.15.mlp.experts.116.up_gate_proj.weight', 'ernie.layers.15.mlp.experts.117.up_gate_proj.weight', 'ernie.layers.15.mlp.experts.118.up_gate_proj.weight', 'ernie.layers.15.mlp.experts.119.up_gate_proj.weight', 'ernie.layers.15.mlp.experts.120.up_gate_proj.weight', 'ernie.layers.15.mlp.experts.121.up_gate_proj.weight', 'ernie.layers.15.mlp.experts.122.up_gate_proj.weight', 'ernie.layers.15.mlp.experts.123.up_gate_proj.weight', 'ernie.layers.15.mlp.experts.124.up_gate_proj.weight', 'ernie.layers.15.mlp.experts.125.up_gate_proj.weight', 'ernie.layers.15.mlp.experts.126.up_gate_proj.weight', 'ernie.layers.15.mlp.experts.127.up_gate_proj.weight'] ernie.layers.15.mlp.image_fused_moe.experts.down_proj_weight:['ernie.layers.15.mlp.experts.32.down_proj.weight', 'ernie.layers.15.mlp.experts.33.down_proj.weight', 'ernie.layers.15.mlp.experts.34.down_proj.weight', 'ernie.layers.15.mlp.experts.35.down_proj.weight', 'ernie.layers.15.mlp.experts.36.down_proj.weight', 'ernie.layers.15.mlp.experts.37.down_proj.weight', 'ernie.layers.15.mlp.experts.38.down_proj.weight', 'ernie.layers.15.mlp.experts.39.down_proj.weight', 'ernie.layers.15.mlp.experts.40.down_proj.weight', 'ernie.layers.15.mlp.experts.41.down_proj.weight', 'ernie.layers.15.mlp.experts.42.down_proj.weight', 'ernie.layers.15.mlp.experts.43.down_proj.weight', 'ernie.layers.15.mlp.experts.44.down_proj.weight', 'ernie.layers.15.mlp.experts.45.down_proj.weight', 'ernie.layers.15.mlp.experts.46.down_proj.weight', 'ernie.layers.15.mlp.experts.47.down_proj.weight', 'ernie.layers.15.mlp.experts.48.down_proj.weight', 'ernie.layers.15.mlp.experts.49.down_proj.weight', 'ernie.layers.15.mlp.experts.50.down_proj.weight', 'ernie.layers.15.mlp.experts.51.down_proj.weight', 'ernie.layers.15.mlp.experts.52.down_proj.weight', 'ernie.layers.15.mlp.experts.53.down_proj.weight', 'ernie.layers.15.mlp.experts.54.down_proj.weight', 'ernie.layers.15.mlp.experts.55.down_proj.weight', 'ernie.layers.15.mlp.experts.56.down_proj.weight', 'ernie.layers.15.mlp.experts.57.down_proj.weight', 'ernie.layers.15.mlp.experts.58.down_proj.weight', 'ernie.layers.15.mlp.experts.59.down_proj.weight', 'ernie.layers.15.mlp.experts.60.down_proj.weight', 'ernie.layers.15.mlp.experts.61.down_proj.weight', 'ernie.layers.15.mlp.experts.62.down_proj.weight', 'ernie.layers.15.mlp.experts.63.down_proj.weight', 'ernie.layers.15.mlp.experts.96.down_proj.weight', 'ernie.layers.15.mlp.experts.97.down_proj.weight', 'ernie.layers.15.mlp.experts.98.down_proj.weight', 'ernie.layers.15.mlp.experts.99.down_proj.weight', 'ernie.layers.15.mlp.experts.100.down_proj.weight', 'ernie.layers.15.mlp.experts.101.down_proj.weight', 'ernie.layers.15.mlp.experts.102.down_proj.weight', 'ernie.layers.15.mlp.experts.103.down_proj.weight', 'ernie.layers.15.mlp.experts.104.down_proj.weight', 'ernie.layers.15.mlp.experts.105.down_proj.weight', 'ernie.layers.15.mlp.experts.106.down_proj.weight', 'ernie.layers.15.mlp.experts.107.down_proj.weight', 'ernie.layers.15.mlp.experts.108.down_proj.weight', 'ernie.layers.15.mlp.experts.109.down_proj.weight', 'ernie.layers.15.mlp.experts.110.down_proj.weight', 'ernie.layers.15.mlp.experts.111.down_proj.weight', 'ernie.layers.15.mlp.experts.112.down_proj.weight', 'ernie.layers.15.mlp.experts.113.down_proj.weight', 'ernie.layers.15.mlp.experts.114.down_proj.weight', 'ernie.layers.15.mlp.experts.115.down_proj.weight', 'ernie.layers.15.mlp.experts.116.down_proj.weight', 'ernie.layers.15.mlp.experts.117.down_proj.weight', 'ernie.layers.15.mlp.experts.118.down_proj.weight', 'ernie.layers.15.mlp.experts.119.down_proj.weight', 'ernie.layers.15.mlp.experts.120.down_proj.weight', 'ernie.layers.15.mlp.experts.121.down_proj.weight', 'ernie.layers.15.mlp.experts.122.down_proj.weight', 'ernie.layers.15.mlp.experts.123.down_proj.weight', 'ernie.layers.15.mlp.experts.124.down_proj.weight', 'ernie.layers.15.mlp.experts.125.down_proj.weight', 'ernie.layers.15.mlp.experts.126.down_proj.weight', 'ernie.layers.15.mlp.experts.127.down_proj.weight'] ernie.layers.16.mlp.image_fused_moe.gate.weight:ernie.layers.16.mlp.gate.weight_1 -ernie.layers.16.mlp.image_fused_moe.experts.gate_correction_bias:ernie.layers.16.mlp.moe_statics.e_score_correction_bias ernie.layers.16.mlp.image_fused_moe.experts.up_gate_proj_weight:['ernie.layers.16.mlp.experts.32.up_gate_proj.weight', 'ernie.layers.16.mlp.experts.33.up_gate_proj.weight', 'ernie.layers.16.mlp.experts.34.up_gate_proj.weight', 'ernie.layers.16.mlp.experts.35.up_gate_proj.weight', 'ernie.layers.16.mlp.experts.36.up_gate_proj.weight', 'ernie.layers.16.mlp.experts.37.up_gate_proj.weight', 'ernie.layers.16.mlp.experts.38.up_gate_proj.weight', 'ernie.layers.16.mlp.experts.39.up_gate_proj.weight', 'ernie.layers.16.mlp.experts.40.up_gate_proj.weight', 'ernie.layers.16.mlp.experts.41.up_gate_proj.weight', 'ernie.layers.16.mlp.experts.42.up_gate_proj.weight', 'ernie.layers.16.mlp.experts.43.up_gate_proj.weight', 'ernie.layers.16.mlp.experts.44.up_gate_proj.weight', 'ernie.layers.16.mlp.experts.45.up_gate_proj.weight', 'ernie.layers.16.mlp.experts.46.up_gate_proj.weight', 'ernie.layers.16.mlp.experts.47.up_gate_proj.weight', 'ernie.layers.16.mlp.experts.48.up_gate_proj.weight', 'ernie.layers.16.mlp.experts.49.up_gate_proj.weight', 'ernie.layers.16.mlp.experts.50.up_gate_proj.weight', 'ernie.layers.16.mlp.experts.51.up_gate_proj.weight', 'ernie.layers.16.mlp.experts.52.up_gate_proj.weight', 'ernie.layers.16.mlp.experts.53.up_gate_proj.weight', 'ernie.layers.16.mlp.experts.54.up_gate_proj.weight', 'ernie.layers.16.mlp.experts.55.up_gate_proj.weight', 'ernie.layers.16.mlp.experts.56.up_gate_proj.weight', 'ernie.layers.16.mlp.experts.57.up_gate_proj.weight', 'ernie.layers.16.mlp.experts.58.up_gate_proj.weight', 'ernie.layers.16.mlp.experts.59.up_gate_proj.weight', 'ernie.layers.16.mlp.experts.60.up_gate_proj.weight', 'ernie.layers.16.mlp.experts.61.up_gate_proj.weight', 'ernie.layers.16.mlp.experts.62.up_gate_proj.weight', 'ernie.layers.16.mlp.experts.63.up_gate_proj.weight', 'ernie.layers.16.mlp.experts.96.up_gate_proj.weight', 'ernie.layers.16.mlp.experts.97.up_gate_proj.weight', 'ernie.layers.16.mlp.experts.98.up_gate_proj.weight', 'ernie.layers.16.mlp.experts.99.up_gate_proj.weight', 'ernie.layers.16.mlp.experts.100.up_gate_proj.weight', 'ernie.layers.16.mlp.experts.101.up_gate_proj.weight', 'ernie.layers.16.mlp.experts.102.up_gate_proj.weight', 'ernie.layers.16.mlp.experts.103.up_gate_proj.weight', 'ernie.layers.16.mlp.experts.104.up_gate_proj.weight', 'ernie.layers.16.mlp.experts.105.up_gate_proj.weight', 'ernie.layers.16.mlp.experts.106.up_gate_proj.weight', 'ernie.layers.16.mlp.experts.107.up_gate_proj.weight', 'ernie.layers.16.mlp.experts.108.up_gate_proj.weight', 'ernie.layers.16.mlp.experts.109.up_gate_proj.weight', 'ernie.layers.16.mlp.experts.110.up_gate_proj.weight', 'ernie.layers.16.mlp.experts.111.up_gate_proj.weight', 'ernie.layers.16.mlp.experts.112.up_gate_proj.weight', 'ernie.layers.16.mlp.experts.113.up_gate_proj.weight', 'ernie.layers.16.mlp.experts.114.up_gate_proj.weight', 'ernie.layers.16.mlp.experts.115.up_gate_proj.weight', 'ernie.layers.16.mlp.experts.116.up_gate_proj.weight', 'ernie.layers.16.mlp.experts.117.up_gate_proj.weight', 'ernie.layers.16.mlp.experts.118.up_gate_proj.weight', 'ernie.layers.16.mlp.experts.119.up_gate_proj.weight', 'ernie.layers.16.mlp.experts.120.up_gate_proj.weight', 'ernie.layers.16.mlp.experts.121.up_gate_proj.weight', 'ernie.layers.16.mlp.experts.122.up_gate_proj.weight', 'ernie.layers.16.mlp.experts.123.up_gate_proj.weight', 'ernie.layers.16.mlp.experts.124.up_gate_proj.weight', 'ernie.layers.16.mlp.experts.125.up_gate_proj.weight', 'ernie.layers.16.mlp.experts.126.up_gate_proj.weight', 'ernie.layers.16.mlp.experts.127.up_gate_proj.weight'] ernie.layers.16.mlp.image_fused_moe.experts.down_proj_weight:['ernie.layers.16.mlp.experts.32.down_proj.weight', 'ernie.layers.16.mlp.experts.33.down_proj.weight', 'ernie.layers.16.mlp.experts.34.down_proj.weight', 'ernie.layers.16.mlp.experts.35.down_proj.weight', 'ernie.layers.16.mlp.experts.36.down_proj.weight', 'ernie.layers.16.mlp.experts.37.down_proj.weight', 'ernie.layers.16.mlp.experts.38.down_proj.weight', 'ernie.layers.16.mlp.experts.39.down_proj.weight', 'ernie.layers.16.mlp.experts.40.down_proj.weight', 'ernie.layers.16.mlp.experts.41.down_proj.weight', 'ernie.layers.16.mlp.experts.42.down_proj.weight', 'ernie.layers.16.mlp.experts.43.down_proj.weight', 'ernie.layers.16.mlp.experts.44.down_proj.weight', 'ernie.layers.16.mlp.experts.45.down_proj.weight', 'ernie.layers.16.mlp.experts.46.down_proj.weight', 'ernie.layers.16.mlp.experts.47.down_proj.weight', 'ernie.layers.16.mlp.experts.48.down_proj.weight', 'ernie.layers.16.mlp.experts.49.down_proj.weight', 'ernie.layers.16.mlp.experts.50.down_proj.weight', 'ernie.layers.16.mlp.experts.51.down_proj.weight', 'ernie.layers.16.mlp.experts.52.down_proj.weight', 'ernie.layers.16.mlp.experts.53.down_proj.weight', 'ernie.layers.16.mlp.experts.54.down_proj.weight', 'ernie.layers.16.mlp.experts.55.down_proj.weight', 'ernie.layers.16.mlp.experts.56.down_proj.weight', 'ernie.layers.16.mlp.experts.57.down_proj.weight', 'ernie.layers.16.mlp.experts.58.down_proj.weight', 'ernie.layers.16.mlp.experts.59.down_proj.weight', 'ernie.layers.16.mlp.experts.60.down_proj.weight', 'ernie.layers.16.mlp.experts.61.down_proj.weight', 'ernie.layers.16.mlp.experts.62.down_proj.weight', 'ernie.layers.16.mlp.experts.63.down_proj.weight', 'ernie.layers.16.mlp.experts.96.down_proj.weight', 'ernie.layers.16.mlp.experts.97.down_proj.weight', 'ernie.layers.16.mlp.experts.98.down_proj.weight', 'ernie.layers.16.mlp.experts.99.down_proj.weight', 'ernie.layers.16.mlp.experts.100.down_proj.weight', 'ernie.layers.16.mlp.experts.101.down_proj.weight', 'ernie.layers.16.mlp.experts.102.down_proj.weight', 'ernie.layers.16.mlp.experts.103.down_proj.weight', 'ernie.layers.16.mlp.experts.104.down_proj.weight', 'ernie.layers.16.mlp.experts.105.down_proj.weight', 'ernie.layers.16.mlp.experts.106.down_proj.weight', 'ernie.layers.16.mlp.experts.107.down_proj.weight', 'ernie.layers.16.mlp.experts.108.down_proj.weight', 'ernie.layers.16.mlp.experts.109.down_proj.weight', 'ernie.layers.16.mlp.experts.110.down_proj.weight', 'ernie.layers.16.mlp.experts.111.down_proj.weight', 'ernie.layers.16.mlp.experts.112.down_proj.weight', 'ernie.layers.16.mlp.experts.113.down_proj.weight', 'ernie.layers.16.mlp.experts.114.down_proj.weight', 'ernie.layers.16.mlp.experts.115.down_proj.weight', 'ernie.layers.16.mlp.experts.116.down_proj.weight', 'ernie.layers.16.mlp.experts.117.down_proj.weight', 'ernie.layers.16.mlp.experts.118.down_proj.weight', 'ernie.layers.16.mlp.experts.119.down_proj.weight', 'ernie.layers.16.mlp.experts.120.down_proj.weight', 'ernie.layers.16.mlp.experts.121.down_proj.weight', 'ernie.layers.16.mlp.experts.122.down_proj.weight', 'ernie.layers.16.mlp.experts.123.down_proj.weight', 'ernie.layers.16.mlp.experts.124.down_proj.weight', 'ernie.layers.16.mlp.experts.125.down_proj.weight', 'ernie.layers.16.mlp.experts.126.down_proj.weight', 'ernie.layers.16.mlp.experts.127.down_proj.weight'] ernie.layers.17.mlp.image_fused_moe.gate.weight:ernie.layers.17.mlp.gate.weight_1 -ernie.layers.17.mlp.image_fused_moe.experts.gate_correction_bias:ernie.layers.17.mlp.moe_statics.e_score_correction_bias ernie.layers.17.mlp.image_fused_moe.experts.up_gate_proj_weight:['ernie.layers.17.mlp.experts.32.up_gate_proj.weight', 'ernie.layers.17.mlp.experts.33.up_gate_proj.weight', 'ernie.layers.17.mlp.experts.34.up_gate_proj.weight', 'ernie.layers.17.mlp.experts.35.up_gate_proj.weight', 'ernie.layers.17.mlp.experts.36.up_gate_proj.weight', 'ernie.layers.17.mlp.experts.37.up_gate_proj.weight', 'ernie.layers.17.mlp.experts.38.up_gate_proj.weight', 'ernie.layers.17.mlp.experts.39.up_gate_proj.weight', 'ernie.layers.17.mlp.experts.40.up_gate_proj.weight', 'ernie.layers.17.mlp.experts.41.up_gate_proj.weight', 'ernie.layers.17.mlp.experts.42.up_gate_proj.weight', 'ernie.layers.17.mlp.experts.43.up_gate_proj.weight', 'ernie.layers.17.mlp.experts.44.up_gate_proj.weight', 'ernie.layers.17.mlp.experts.45.up_gate_proj.weight', 'ernie.layers.17.mlp.experts.46.up_gate_proj.weight', 'ernie.layers.17.mlp.experts.47.up_gate_proj.weight', 'ernie.layers.17.mlp.experts.48.up_gate_proj.weight', 'ernie.layers.17.mlp.experts.49.up_gate_proj.weight', 'ernie.layers.17.mlp.experts.50.up_gate_proj.weight', 'ernie.layers.17.mlp.experts.51.up_gate_proj.weight', 'ernie.layers.17.mlp.experts.52.up_gate_proj.weight', 'ernie.layers.17.mlp.experts.53.up_gate_proj.weight', 'ernie.layers.17.mlp.experts.54.up_gate_proj.weight', 'ernie.layers.17.mlp.experts.55.up_gate_proj.weight', 'ernie.layers.17.mlp.experts.56.up_gate_proj.weight', 'ernie.layers.17.mlp.experts.57.up_gate_proj.weight', 'ernie.layers.17.mlp.experts.58.up_gate_proj.weight', 'ernie.layers.17.mlp.experts.59.up_gate_proj.weight', 'ernie.layers.17.mlp.experts.60.up_gate_proj.weight', 'ernie.layers.17.mlp.experts.61.up_gate_proj.weight', 'ernie.layers.17.mlp.experts.62.up_gate_proj.weight', 'ernie.layers.17.mlp.experts.63.up_gate_proj.weight', 'ernie.layers.17.mlp.experts.96.up_gate_proj.weight', 'ernie.layers.17.mlp.experts.97.up_gate_proj.weight', 'ernie.layers.17.mlp.experts.98.up_gate_proj.weight', 'ernie.layers.17.mlp.experts.99.up_gate_proj.weight', 'ernie.layers.17.mlp.experts.100.up_gate_proj.weight', 'ernie.layers.17.mlp.experts.101.up_gate_proj.weight', 'ernie.layers.17.mlp.experts.102.up_gate_proj.weight', 'ernie.layers.17.mlp.experts.103.up_gate_proj.weight', 'ernie.layers.17.mlp.experts.104.up_gate_proj.weight', 'ernie.layers.17.mlp.experts.105.up_gate_proj.weight', 'ernie.layers.17.mlp.experts.106.up_gate_proj.weight', 'ernie.layers.17.mlp.experts.107.up_gate_proj.weight', 'ernie.layers.17.mlp.experts.108.up_gate_proj.weight', 'ernie.layers.17.mlp.experts.109.up_gate_proj.weight', 'ernie.layers.17.mlp.experts.110.up_gate_proj.weight', 'ernie.layers.17.mlp.experts.111.up_gate_proj.weight', 'ernie.layers.17.mlp.experts.112.up_gate_proj.weight', 'ernie.layers.17.mlp.experts.113.up_gate_proj.weight', 'ernie.layers.17.mlp.experts.114.up_gate_proj.weight', 'ernie.layers.17.mlp.experts.115.up_gate_proj.weight', 'ernie.layers.17.mlp.experts.116.up_gate_proj.weight', 'ernie.layers.17.mlp.experts.117.up_gate_proj.weight', 'ernie.layers.17.mlp.experts.118.up_gate_proj.weight', 'ernie.layers.17.mlp.experts.119.up_gate_proj.weight', 'ernie.layers.17.mlp.experts.120.up_gate_proj.weight', 'ernie.layers.17.mlp.experts.121.up_gate_proj.weight', 'ernie.layers.17.mlp.experts.122.up_gate_proj.weight', 'ernie.layers.17.mlp.experts.123.up_gate_proj.weight', 'ernie.layers.17.mlp.experts.124.up_gate_proj.weight', 'ernie.layers.17.mlp.experts.125.up_gate_proj.weight', 'ernie.layers.17.mlp.experts.126.up_gate_proj.weight', 'ernie.layers.17.mlp.experts.127.up_gate_proj.weight'] ernie.layers.17.mlp.image_fused_moe.experts.down_proj_weight:['ernie.layers.17.mlp.experts.32.down_proj.weight', 'ernie.layers.17.mlp.experts.33.down_proj.weight', 'ernie.layers.17.mlp.experts.34.down_proj.weight', 'ernie.layers.17.mlp.experts.35.down_proj.weight', 'ernie.layers.17.mlp.experts.36.down_proj.weight', 'ernie.layers.17.mlp.experts.37.down_proj.weight', 'ernie.layers.17.mlp.experts.38.down_proj.weight', 'ernie.layers.17.mlp.experts.39.down_proj.weight', 'ernie.layers.17.mlp.experts.40.down_proj.weight', 'ernie.layers.17.mlp.experts.41.down_proj.weight', 'ernie.layers.17.mlp.experts.42.down_proj.weight', 'ernie.layers.17.mlp.experts.43.down_proj.weight', 'ernie.layers.17.mlp.experts.44.down_proj.weight', 'ernie.layers.17.mlp.experts.45.down_proj.weight', 'ernie.layers.17.mlp.experts.46.down_proj.weight', 'ernie.layers.17.mlp.experts.47.down_proj.weight', 'ernie.layers.17.mlp.experts.48.down_proj.weight', 'ernie.layers.17.mlp.experts.49.down_proj.weight', 'ernie.layers.17.mlp.experts.50.down_proj.weight', 'ernie.layers.17.mlp.experts.51.down_proj.weight', 'ernie.layers.17.mlp.experts.52.down_proj.weight', 'ernie.layers.17.mlp.experts.53.down_proj.weight', 'ernie.layers.17.mlp.experts.54.down_proj.weight', 'ernie.layers.17.mlp.experts.55.down_proj.weight', 'ernie.layers.17.mlp.experts.56.down_proj.weight', 'ernie.layers.17.mlp.experts.57.down_proj.weight', 'ernie.layers.17.mlp.experts.58.down_proj.weight', 'ernie.layers.17.mlp.experts.59.down_proj.weight', 'ernie.layers.17.mlp.experts.60.down_proj.weight', 'ernie.layers.17.mlp.experts.61.down_proj.weight', 'ernie.layers.17.mlp.experts.62.down_proj.weight', 'ernie.layers.17.mlp.experts.63.down_proj.weight', 'ernie.layers.17.mlp.experts.96.down_proj.weight', 'ernie.layers.17.mlp.experts.97.down_proj.weight', 'ernie.layers.17.mlp.experts.98.down_proj.weight', 'ernie.layers.17.mlp.experts.99.down_proj.weight', 'ernie.layers.17.mlp.experts.100.down_proj.weight', 'ernie.layers.17.mlp.experts.101.down_proj.weight', 'ernie.layers.17.mlp.experts.102.down_proj.weight', 'ernie.layers.17.mlp.experts.103.down_proj.weight', 'ernie.layers.17.mlp.experts.104.down_proj.weight', 'ernie.layers.17.mlp.experts.105.down_proj.weight', 'ernie.layers.17.mlp.experts.106.down_proj.weight', 'ernie.layers.17.mlp.experts.107.down_proj.weight', 'ernie.layers.17.mlp.experts.108.down_proj.weight', 'ernie.layers.17.mlp.experts.109.down_proj.weight', 'ernie.layers.17.mlp.experts.110.down_proj.weight', 'ernie.layers.17.mlp.experts.111.down_proj.weight', 'ernie.layers.17.mlp.experts.112.down_proj.weight', 'ernie.layers.17.mlp.experts.113.down_proj.weight', 'ernie.layers.17.mlp.experts.114.down_proj.weight', 'ernie.layers.17.mlp.experts.115.down_proj.weight', 'ernie.layers.17.mlp.experts.116.down_proj.weight', 'ernie.layers.17.mlp.experts.117.down_proj.weight', 'ernie.layers.17.mlp.experts.118.down_proj.weight', 'ernie.layers.17.mlp.experts.119.down_proj.weight', 'ernie.layers.17.mlp.experts.120.down_proj.weight', 'ernie.layers.17.mlp.experts.121.down_proj.weight', 'ernie.layers.17.mlp.experts.122.down_proj.weight', 'ernie.layers.17.mlp.experts.123.down_proj.weight', 'ernie.layers.17.mlp.experts.124.down_proj.weight', 'ernie.layers.17.mlp.experts.125.down_proj.weight', 'ernie.layers.17.mlp.experts.126.down_proj.weight', 'ernie.layers.17.mlp.experts.127.down_proj.weight'] ernie.layers.18.mlp.image_fused_moe.gate.weight:ernie.layers.18.mlp.gate.weight_1 -ernie.layers.18.mlp.image_fused_moe.experts.gate_correction_bias:ernie.layers.18.mlp.moe_statics.e_score_correction_bias ernie.layers.18.mlp.image_fused_moe.experts.up_gate_proj_weight:['ernie.layers.18.mlp.experts.32.up_gate_proj.weight', 'ernie.layers.18.mlp.experts.33.up_gate_proj.weight', 'ernie.layers.18.mlp.experts.34.up_gate_proj.weight', 'ernie.layers.18.mlp.experts.35.up_gate_proj.weight', 'ernie.layers.18.mlp.experts.36.up_gate_proj.weight', 'ernie.layers.18.mlp.experts.37.up_gate_proj.weight', 'ernie.layers.18.mlp.experts.38.up_gate_proj.weight', 'ernie.layers.18.mlp.experts.39.up_gate_proj.weight', 'ernie.layers.18.mlp.experts.40.up_gate_proj.weight', 'ernie.layers.18.mlp.experts.41.up_gate_proj.weight', 'ernie.layers.18.mlp.experts.42.up_gate_proj.weight', 'ernie.layers.18.mlp.experts.43.up_gate_proj.weight', 'ernie.layers.18.mlp.experts.44.up_gate_proj.weight', 'ernie.layers.18.mlp.experts.45.up_gate_proj.weight', 'ernie.layers.18.mlp.experts.46.up_gate_proj.weight', 'ernie.layers.18.mlp.experts.47.up_gate_proj.weight', 'ernie.layers.18.mlp.experts.48.up_gate_proj.weight', 'ernie.layers.18.mlp.experts.49.up_gate_proj.weight', 'ernie.layers.18.mlp.experts.50.up_gate_proj.weight', 'ernie.layers.18.mlp.experts.51.up_gate_proj.weight', 'ernie.layers.18.mlp.experts.52.up_gate_proj.weight', 'ernie.layers.18.mlp.experts.53.up_gate_proj.weight', 'ernie.layers.18.mlp.experts.54.up_gate_proj.weight', 'ernie.layers.18.mlp.experts.55.up_gate_proj.weight', 'ernie.layers.18.mlp.experts.56.up_gate_proj.weight', 'ernie.layers.18.mlp.experts.57.up_gate_proj.weight', 'ernie.layers.18.mlp.experts.58.up_gate_proj.weight', 'ernie.layers.18.mlp.experts.59.up_gate_proj.weight', 'ernie.layers.18.mlp.experts.60.up_gate_proj.weight', 'ernie.layers.18.mlp.experts.61.up_gate_proj.weight', 'ernie.layers.18.mlp.experts.62.up_gate_proj.weight', 'ernie.layers.18.mlp.experts.63.up_gate_proj.weight', 'ernie.layers.18.mlp.experts.96.up_gate_proj.weight', 'ernie.layers.18.mlp.experts.97.up_gate_proj.weight', 'ernie.layers.18.mlp.experts.98.up_gate_proj.weight', 'ernie.layers.18.mlp.experts.99.up_gate_proj.weight', 'ernie.layers.18.mlp.experts.100.up_gate_proj.weight', 'ernie.layers.18.mlp.experts.101.up_gate_proj.weight', 'ernie.layers.18.mlp.experts.102.up_gate_proj.weight', 'ernie.layers.18.mlp.experts.103.up_gate_proj.weight', 'ernie.layers.18.mlp.experts.104.up_gate_proj.weight', 'ernie.layers.18.mlp.experts.105.up_gate_proj.weight', 'ernie.layers.18.mlp.experts.106.up_gate_proj.weight', 'ernie.layers.18.mlp.experts.107.up_gate_proj.weight', 'ernie.layers.18.mlp.experts.108.up_gate_proj.weight', 'ernie.layers.18.mlp.experts.109.up_gate_proj.weight', 'ernie.layers.18.mlp.experts.110.up_gate_proj.weight', 'ernie.layers.18.mlp.experts.111.up_gate_proj.weight', 'ernie.layers.18.mlp.experts.112.up_gate_proj.weight', 'ernie.layers.18.mlp.experts.113.up_gate_proj.weight', 'ernie.layers.18.mlp.experts.114.up_gate_proj.weight', 'ernie.layers.18.mlp.experts.115.up_gate_proj.weight', 'ernie.layers.18.mlp.experts.116.up_gate_proj.weight', 'ernie.layers.18.mlp.experts.117.up_gate_proj.weight', 'ernie.layers.18.mlp.experts.118.up_gate_proj.weight', 'ernie.layers.18.mlp.experts.119.up_gate_proj.weight', 'ernie.layers.18.mlp.experts.120.up_gate_proj.weight', 'ernie.layers.18.mlp.experts.121.up_gate_proj.weight', 'ernie.layers.18.mlp.experts.122.up_gate_proj.weight', 'ernie.layers.18.mlp.experts.123.up_gate_proj.weight', 'ernie.layers.18.mlp.experts.124.up_gate_proj.weight', 'ernie.layers.18.mlp.experts.125.up_gate_proj.weight', 'ernie.layers.18.mlp.experts.126.up_gate_proj.weight', 'ernie.layers.18.mlp.experts.127.up_gate_proj.weight'] ernie.layers.18.mlp.image_fused_moe.experts.down_proj_weight:['ernie.layers.18.mlp.experts.32.down_proj.weight', 'ernie.layers.18.mlp.experts.33.down_proj.weight', 'ernie.layers.18.mlp.experts.34.down_proj.weight', 'ernie.layers.18.mlp.experts.35.down_proj.weight', 'ernie.layers.18.mlp.experts.36.down_proj.weight', 'ernie.layers.18.mlp.experts.37.down_proj.weight', 'ernie.layers.18.mlp.experts.38.down_proj.weight', 'ernie.layers.18.mlp.experts.39.down_proj.weight', 'ernie.layers.18.mlp.experts.40.down_proj.weight', 'ernie.layers.18.mlp.experts.41.down_proj.weight', 'ernie.layers.18.mlp.experts.42.down_proj.weight', 'ernie.layers.18.mlp.experts.43.down_proj.weight', 'ernie.layers.18.mlp.experts.44.down_proj.weight', 'ernie.layers.18.mlp.experts.45.down_proj.weight', 'ernie.layers.18.mlp.experts.46.down_proj.weight', 'ernie.layers.18.mlp.experts.47.down_proj.weight', 'ernie.layers.18.mlp.experts.48.down_proj.weight', 'ernie.layers.18.mlp.experts.49.down_proj.weight', 'ernie.layers.18.mlp.experts.50.down_proj.weight', 'ernie.layers.18.mlp.experts.51.down_proj.weight', 'ernie.layers.18.mlp.experts.52.down_proj.weight', 'ernie.layers.18.mlp.experts.53.down_proj.weight', 'ernie.layers.18.mlp.experts.54.down_proj.weight', 'ernie.layers.18.mlp.experts.55.down_proj.weight', 'ernie.layers.18.mlp.experts.56.down_proj.weight', 'ernie.layers.18.mlp.experts.57.down_proj.weight', 'ernie.layers.18.mlp.experts.58.down_proj.weight', 'ernie.layers.18.mlp.experts.59.down_proj.weight', 'ernie.layers.18.mlp.experts.60.down_proj.weight', 'ernie.layers.18.mlp.experts.61.down_proj.weight', 'ernie.layers.18.mlp.experts.62.down_proj.weight', 'ernie.layers.18.mlp.experts.63.down_proj.weight', 'ernie.layers.18.mlp.experts.96.down_proj.weight', 'ernie.layers.18.mlp.experts.97.down_proj.weight', 'ernie.layers.18.mlp.experts.98.down_proj.weight', 'ernie.layers.18.mlp.experts.99.down_proj.weight', 'ernie.layers.18.mlp.experts.100.down_proj.weight', 'ernie.layers.18.mlp.experts.101.down_proj.weight', 'ernie.layers.18.mlp.experts.102.down_proj.weight', 'ernie.layers.18.mlp.experts.103.down_proj.weight', 'ernie.layers.18.mlp.experts.104.down_proj.weight', 'ernie.layers.18.mlp.experts.105.down_proj.weight', 'ernie.layers.18.mlp.experts.106.down_proj.weight', 'ernie.layers.18.mlp.experts.107.down_proj.weight', 'ernie.layers.18.mlp.experts.108.down_proj.weight', 'ernie.layers.18.mlp.experts.109.down_proj.weight', 'ernie.layers.18.mlp.experts.110.down_proj.weight', 'ernie.layers.18.mlp.experts.111.down_proj.weight', 'ernie.layers.18.mlp.experts.112.down_proj.weight', 'ernie.layers.18.mlp.experts.113.down_proj.weight', 'ernie.layers.18.mlp.experts.114.down_proj.weight', 'ernie.layers.18.mlp.experts.115.down_proj.weight', 'ernie.layers.18.mlp.experts.116.down_proj.weight', 'ernie.layers.18.mlp.experts.117.down_proj.weight', 'ernie.layers.18.mlp.experts.118.down_proj.weight', 'ernie.layers.18.mlp.experts.119.down_proj.weight', 'ernie.layers.18.mlp.experts.120.down_proj.weight', 'ernie.layers.18.mlp.experts.121.down_proj.weight', 'ernie.layers.18.mlp.experts.122.down_proj.weight', 'ernie.layers.18.mlp.experts.123.down_proj.weight', 'ernie.layers.18.mlp.experts.124.down_proj.weight', 'ernie.layers.18.mlp.experts.125.down_proj.weight', 'ernie.layers.18.mlp.experts.126.down_proj.weight', 'ernie.layers.18.mlp.experts.127.down_proj.weight'] ernie.layers.19.mlp.image_fused_moe.gate.weight:ernie.layers.19.mlp.gate.weight_1 -ernie.layers.19.mlp.image_fused_moe.experts.gate_correction_bias:ernie.layers.19.mlp.moe_statics.e_score_correction_bias ernie.layers.19.mlp.image_fused_moe.experts.up_gate_proj_weight:['ernie.layers.19.mlp.experts.32.up_gate_proj.weight', 'ernie.layers.19.mlp.experts.33.up_gate_proj.weight', 'ernie.layers.19.mlp.experts.34.up_gate_proj.weight', 'ernie.layers.19.mlp.experts.35.up_gate_proj.weight', 'ernie.layers.19.mlp.experts.36.up_gate_proj.weight', 'ernie.layers.19.mlp.experts.37.up_gate_proj.weight', 'ernie.layers.19.mlp.experts.38.up_gate_proj.weight', 'ernie.layers.19.mlp.experts.39.up_gate_proj.weight', 'ernie.layers.19.mlp.experts.40.up_gate_proj.weight', 'ernie.layers.19.mlp.experts.41.up_gate_proj.weight', 'ernie.layers.19.mlp.experts.42.up_gate_proj.weight', 'ernie.layers.19.mlp.experts.43.up_gate_proj.weight', 'ernie.layers.19.mlp.experts.44.up_gate_proj.weight', 'ernie.layers.19.mlp.experts.45.up_gate_proj.weight', 'ernie.layers.19.mlp.experts.46.up_gate_proj.weight', 'ernie.layers.19.mlp.experts.47.up_gate_proj.weight', 'ernie.layers.19.mlp.experts.48.up_gate_proj.weight', 'ernie.layers.19.mlp.experts.49.up_gate_proj.weight', 'ernie.layers.19.mlp.experts.50.up_gate_proj.weight', 'ernie.layers.19.mlp.experts.51.up_gate_proj.weight', 'ernie.layers.19.mlp.experts.52.up_gate_proj.weight', 'ernie.layers.19.mlp.experts.53.up_gate_proj.weight', 'ernie.layers.19.mlp.experts.54.up_gate_proj.weight', 'ernie.layers.19.mlp.experts.55.up_gate_proj.weight', 'ernie.layers.19.mlp.experts.56.up_gate_proj.weight', 'ernie.layers.19.mlp.experts.57.up_gate_proj.weight', 'ernie.layers.19.mlp.experts.58.up_gate_proj.weight', 'ernie.layers.19.mlp.experts.59.up_gate_proj.weight', 'ernie.layers.19.mlp.experts.60.up_gate_proj.weight', 'ernie.layers.19.mlp.experts.61.up_gate_proj.weight', 'ernie.layers.19.mlp.experts.62.up_gate_proj.weight', 'ernie.layers.19.mlp.experts.63.up_gate_proj.weight', 'ernie.layers.19.mlp.experts.96.up_gate_proj.weight', 'ernie.layers.19.mlp.experts.97.up_gate_proj.weight', 'ernie.layers.19.mlp.experts.98.up_gate_proj.weight', 'ernie.layers.19.mlp.experts.99.up_gate_proj.weight', 'ernie.layers.19.mlp.experts.100.up_gate_proj.weight', 'ernie.layers.19.mlp.experts.101.up_gate_proj.weight', 'ernie.layers.19.mlp.experts.102.up_gate_proj.weight', 'ernie.layers.19.mlp.experts.103.up_gate_proj.weight', 'ernie.layers.19.mlp.experts.104.up_gate_proj.weight', 'ernie.layers.19.mlp.experts.105.up_gate_proj.weight', 'ernie.layers.19.mlp.experts.106.up_gate_proj.weight', 'ernie.layers.19.mlp.experts.107.up_gate_proj.weight', 'ernie.layers.19.mlp.experts.108.up_gate_proj.weight', 'ernie.layers.19.mlp.experts.109.up_gate_proj.weight', 'ernie.layers.19.mlp.experts.110.up_gate_proj.weight', 'ernie.layers.19.mlp.experts.111.up_gate_proj.weight', 'ernie.layers.19.mlp.experts.112.up_gate_proj.weight', 'ernie.layers.19.mlp.experts.113.up_gate_proj.weight', 'ernie.layers.19.mlp.experts.114.up_gate_proj.weight', 'ernie.layers.19.mlp.experts.115.up_gate_proj.weight', 'ernie.layers.19.mlp.experts.116.up_gate_proj.weight', 'ernie.layers.19.mlp.experts.117.up_gate_proj.weight', 'ernie.layers.19.mlp.experts.118.up_gate_proj.weight', 'ernie.layers.19.mlp.experts.119.up_gate_proj.weight', 'ernie.layers.19.mlp.experts.120.up_gate_proj.weight', 'ernie.layers.19.mlp.experts.121.up_gate_proj.weight', 'ernie.layers.19.mlp.experts.122.up_gate_proj.weight', 'ernie.layers.19.mlp.experts.123.up_gate_proj.weight', 'ernie.layers.19.mlp.experts.124.up_gate_proj.weight', 'ernie.layers.19.mlp.experts.125.up_gate_proj.weight', 'ernie.layers.19.mlp.experts.126.up_gate_proj.weight', 'ernie.layers.19.mlp.experts.127.up_gate_proj.weight'] ernie.layers.19.mlp.image_fused_moe.experts.down_proj_weight:['ernie.layers.19.mlp.experts.32.down_proj.weight', 'ernie.layers.19.mlp.experts.33.down_proj.weight', 'ernie.layers.19.mlp.experts.34.down_proj.weight', 'ernie.layers.19.mlp.experts.35.down_proj.weight', 'ernie.layers.19.mlp.experts.36.down_proj.weight', 'ernie.layers.19.mlp.experts.37.down_proj.weight', 'ernie.layers.19.mlp.experts.38.down_proj.weight', 'ernie.layers.19.mlp.experts.39.down_proj.weight', 'ernie.layers.19.mlp.experts.40.down_proj.weight', 'ernie.layers.19.mlp.experts.41.down_proj.weight', 'ernie.layers.19.mlp.experts.42.down_proj.weight', 'ernie.layers.19.mlp.experts.43.down_proj.weight', 'ernie.layers.19.mlp.experts.44.down_proj.weight', 'ernie.layers.19.mlp.experts.45.down_proj.weight', 'ernie.layers.19.mlp.experts.46.down_proj.weight', 'ernie.layers.19.mlp.experts.47.down_proj.weight', 'ernie.layers.19.mlp.experts.48.down_proj.weight', 'ernie.layers.19.mlp.experts.49.down_proj.weight', 'ernie.layers.19.mlp.experts.50.down_proj.weight', 'ernie.layers.19.mlp.experts.51.down_proj.weight', 'ernie.layers.19.mlp.experts.52.down_proj.weight', 'ernie.layers.19.mlp.experts.53.down_proj.weight', 'ernie.layers.19.mlp.experts.54.down_proj.weight', 'ernie.layers.19.mlp.experts.55.down_proj.weight', 'ernie.layers.19.mlp.experts.56.down_proj.weight', 'ernie.layers.19.mlp.experts.57.down_proj.weight', 'ernie.layers.19.mlp.experts.58.down_proj.weight', 'ernie.layers.19.mlp.experts.59.down_proj.weight', 'ernie.layers.19.mlp.experts.60.down_proj.weight', 'ernie.layers.19.mlp.experts.61.down_proj.weight', 'ernie.layers.19.mlp.experts.62.down_proj.weight', 'ernie.layers.19.mlp.experts.63.down_proj.weight', 'ernie.layers.19.mlp.experts.96.down_proj.weight', 'ernie.layers.19.mlp.experts.97.down_proj.weight', 'ernie.layers.19.mlp.experts.98.down_proj.weight', 'ernie.layers.19.mlp.experts.99.down_proj.weight', 'ernie.layers.19.mlp.experts.100.down_proj.weight', 'ernie.layers.19.mlp.experts.101.down_proj.weight', 'ernie.layers.19.mlp.experts.102.down_proj.weight', 'ernie.layers.19.mlp.experts.103.down_proj.weight', 'ernie.layers.19.mlp.experts.104.down_proj.weight', 'ernie.layers.19.mlp.experts.105.down_proj.weight', 'ernie.layers.19.mlp.experts.106.down_proj.weight', 'ernie.layers.19.mlp.experts.107.down_proj.weight', 'ernie.layers.19.mlp.experts.108.down_proj.weight', 'ernie.layers.19.mlp.experts.109.down_proj.weight', 'ernie.layers.19.mlp.experts.110.down_proj.weight', 'ernie.layers.19.mlp.experts.111.down_proj.weight', 'ernie.layers.19.mlp.experts.112.down_proj.weight', 'ernie.layers.19.mlp.experts.113.down_proj.weight', 'ernie.layers.19.mlp.experts.114.down_proj.weight', 'ernie.layers.19.mlp.experts.115.down_proj.weight', 'ernie.layers.19.mlp.experts.116.down_proj.weight', 'ernie.layers.19.mlp.experts.117.down_proj.weight', 'ernie.layers.19.mlp.experts.118.down_proj.weight', 'ernie.layers.19.mlp.experts.119.down_proj.weight', 'ernie.layers.19.mlp.experts.120.down_proj.weight', 'ernie.layers.19.mlp.experts.121.down_proj.weight', 'ernie.layers.19.mlp.experts.122.down_proj.weight', 'ernie.layers.19.mlp.experts.123.down_proj.weight', 'ernie.layers.19.mlp.experts.124.down_proj.weight', 'ernie.layers.19.mlp.experts.125.down_proj.weight', 'ernie.layers.19.mlp.experts.126.down_proj.weight', 'ernie.layers.19.mlp.experts.127.down_proj.weight'] ernie.layers.20.mlp.image_fused_moe.gate.weight:ernie.layers.20.mlp.gate.weight_1 -ernie.layers.20.mlp.image_fused_moe.experts.gate_correction_bias:ernie.layers.20.mlp.moe_statics.e_score_correction_bias ernie.layers.20.mlp.image_fused_moe.experts.up_gate_proj_weight:['ernie.layers.20.mlp.experts.32.up_gate_proj.weight', 'ernie.layers.20.mlp.experts.33.up_gate_proj.weight', 'ernie.layers.20.mlp.experts.34.up_gate_proj.weight', 'ernie.layers.20.mlp.experts.35.up_gate_proj.weight', 'ernie.layers.20.mlp.experts.36.up_gate_proj.weight', 'ernie.layers.20.mlp.experts.37.up_gate_proj.weight', 'ernie.layers.20.mlp.experts.38.up_gate_proj.weight', 'ernie.layers.20.mlp.experts.39.up_gate_proj.weight', 'ernie.layers.20.mlp.experts.40.up_gate_proj.weight', 'ernie.layers.20.mlp.experts.41.up_gate_proj.weight', 'ernie.layers.20.mlp.experts.42.up_gate_proj.weight', 'ernie.layers.20.mlp.experts.43.up_gate_proj.weight', 'ernie.layers.20.mlp.experts.44.up_gate_proj.weight', 'ernie.layers.20.mlp.experts.45.up_gate_proj.weight', 'ernie.layers.20.mlp.experts.46.up_gate_proj.weight', 'ernie.layers.20.mlp.experts.47.up_gate_proj.weight', 'ernie.layers.20.mlp.experts.48.up_gate_proj.weight', 'ernie.layers.20.mlp.experts.49.up_gate_proj.weight', 'ernie.layers.20.mlp.experts.50.up_gate_proj.weight', 'ernie.layers.20.mlp.experts.51.up_gate_proj.weight', 'ernie.layers.20.mlp.experts.52.up_gate_proj.weight', 'ernie.layers.20.mlp.experts.53.up_gate_proj.weight', 'ernie.layers.20.mlp.experts.54.up_gate_proj.weight', 'ernie.layers.20.mlp.experts.55.up_gate_proj.weight', 'ernie.layers.20.mlp.experts.56.up_gate_proj.weight', 'ernie.layers.20.mlp.experts.57.up_gate_proj.weight', 'ernie.layers.20.mlp.experts.58.up_gate_proj.weight', 'ernie.layers.20.mlp.experts.59.up_gate_proj.weight', 'ernie.layers.20.mlp.experts.60.up_gate_proj.weight', 'ernie.layers.20.mlp.experts.61.up_gate_proj.weight', 'ernie.layers.20.mlp.experts.62.up_gate_proj.weight', 'ernie.layers.20.mlp.experts.63.up_gate_proj.weight', 'ernie.layers.20.mlp.experts.96.up_gate_proj.weight', 'ernie.layers.20.mlp.experts.97.up_gate_proj.weight', 'ernie.layers.20.mlp.experts.98.up_gate_proj.weight', 'ernie.layers.20.mlp.experts.99.up_gate_proj.weight', 'ernie.layers.20.mlp.experts.100.up_gate_proj.weight', 'ernie.layers.20.mlp.experts.101.up_gate_proj.weight', 'ernie.layers.20.mlp.experts.102.up_gate_proj.weight', 'ernie.layers.20.mlp.experts.103.up_gate_proj.weight', 'ernie.layers.20.mlp.experts.104.up_gate_proj.weight', 'ernie.layers.20.mlp.experts.105.up_gate_proj.weight', 'ernie.layers.20.mlp.experts.106.up_gate_proj.weight', 'ernie.layers.20.mlp.experts.107.up_gate_proj.weight', 'ernie.layers.20.mlp.experts.108.up_gate_proj.weight', 'ernie.layers.20.mlp.experts.109.up_gate_proj.weight', 'ernie.layers.20.mlp.experts.110.up_gate_proj.weight', 'ernie.layers.20.mlp.experts.111.up_gate_proj.weight', 'ernie.layers.20.mlp.experts.112.up_gate_proj.weight', 'ernie.layers.20.mlp.experts.113.up_gate_proj.weight', 'ernie.layers.20.mlp.experts.114.up_gate_proj.weight', 'ernie.layers.20.mlp.experts.115.up_gate_proj.weight', 'ernie.layers.20.mlp.experts.116.up_gate_proj.weight', 'ernie.layers.20.mlp.experts.117.up_gate_proj.weight', 'ernie.layers.20.mlp.experts.118.up_gate_proj.weight', 'ernie.layers.20.mlp.experts.119.up_gate_proj.weight', 'ernie.layers.20.mlp.experts.120.up_gate_proj.weight', 'ernie.layers.20.mlp.experts.121.up_gate_proj.weight', 'ernie.layers.20.mlp.experts.122.up_gate_proj.weight', 'ernie.layers.20.mlp.experts.123.up_gate_proj.weight', 'ernie.layers.20.mlp.experts.124.up_gate_proj.weight', 'ernie.layers.20.mlp.experts.125.up_gate_proj.weight', 'ernie.layers.20.mlp.experts.126.up_gate_proj.weight', 'ernie.layers.20.mlp.experts.127.up_gate_proj.weight'] ernie.layers.20.mlp.image_fused_moe.experts.down_proj_weight:['ernie.layers.20.mlp.experts.32.down_proj.weight', 'ernie.layers.20.mlp.experts.33.down_proj.weight', 'ernie.layers.20.mlp.experts.34.down_proj.weight', 'ernie.layers.20.mlp.experts.35.down_proj.weight', 'ernie.layers.20.mlp.experts.36.down_proj.weight', 'ernie.layers.20.mlp.experts.37.down_proj.weight', 'ernie.layers.20.mlp.experts.38.down_proj.weight', 'ernie.layers.20.mlp.experts.39.down_proj.weight', 'ernie.layers.20.mlp.experts.40.down_proj.weight', 'ernie.layers.20.mlp.experts.41.down_proj.weight', 'ernie.layers.20.mlp.experts.42.down_proj.weight', 'ernie.layers.20.mlp.experts.43.down_proj.weight', 'ernie.layers.20.mlp.experts.44.down_proj.weight', 'ernie.layers.20.mlp.experts.45.down_proj.weight', 'ernie.layers.20.mlp.experts.46.down_proj.weight', 'ernie.layers.20.mlp.experts.47.down_proj.weight', 'ernie.layers.20.mlp.experts.48.down_proj.weight', 'ernie.layers.20.mlp.experts.49.down_proj.weight', 'ernie.layers.20.mlp.experts.50.down_proj.weight', 'ernie.layers.20.mlp.experts.51.down_proj.weight', 'ernie.layers.20.mlp.experts.52.down_proj.weight', 'ernie.layers.20.mlp.experts.53.down_proj.weight', 'ernie.layers.20.mlp.experts.54.down_proj.weight', 'ernie.layers.20.mlp.experts.55.down_proj.weight', 'ernie.layers.20.mlp.experts.56.down_proj.weight', 'ernie.layers.20.mlp.experts.57.down_proj.weight', 'ernie.layers.20.mlp.experts.58.down_proj.weight', 'ernie.layers.20.mlp.experts.59.down_proj.weight', 'ernie.layers.20.mlp.experts.60.down_proj.weight', 'ernie.layers.20.mlp.experts.61.down_proj.weight', 'ernie.layers.20.mlp.experts.62.down_proj.weight', 'ernie.layers.20.mlp.experts.63.down_proj.weight', 'ernie.layers.20.mlp.experts.96.down_proj.weight', 'ernie.layers.20.mlp.experts.97.down_proj.weight', 'ernie.layers.20.mlp.experts.98.down_proj.weight', 'ernie.layers.20.mlp.experts.99.down_proj.weight', 'ernie.layers.20.mlp.experts.100.down_proj.weight', 'ernie.layers.20.mlp.experts.101.down_proj.weight', 'ernie.layers.20.mlp.experts.102.down_proj.weight', 'ernie.layers.20.mlp.experts.103.down_proj.weight', 'ernie.layers.20.mlp.experts.104.down_proj.weight', 'ernie.layers.20.mlp.experts.105.down_proj.weight', 'ernie.layers.20.mlp.experts.106.down_proj.weight', 'ernie.layers.20.mlp.experts.107.down_proj.weight', 'ernie.layers.20.mlp.experts.108.down_proj.weight', 'ernie.layers.20.mlp.experts.109.down_proj.weight', 'ernie.layers.20.mlp.experts.110.down_proj.weight', 'ernie.layers.20.mlp.experts.111.down_proj.weight', 'ernie.layers.20.mlp.experts.112.down_proj.weight', 'ernie.layers.20.mlp.experts.113.down_proj.weight', 'ernie.layers.20.mlp.experts.114.down_proj.weight', 'ernie.layers.20.mlp.experts.115.down_proj.weight', 'ernie.layers.20.mlp.experts.116.down_proj.weight', 'ernie.layers.20.mlp.experts.117.down_proj.weight', 'ernie.layers.20.mlp.experts.118.down_proj.weight', 'ernie.layers.20.mlp.experts.119.down_proj.weight', 'ernie.layers.20.mlp.experts.120.down_proj.weight', 'ernie.layers.20.mlp.experts.121.down_proj.weight', 'ernie.layers.20.mlp.experts.122.down_proj.weight', 'ernie.layers.20.mlp.experts.123.down_proj.weight', 'ernie.layers.20.mlp.experts.124.down_proj.weight', 'ernie.layers.20.mlp.experts.125.down_proj.weight', 'ernie.layers.20.mlp.experts.126.down_proj.weight', 'ernie.layers.20.mlp.experts.127.down_proj.weight'] ernie.layers.21.mlp.image_fused_moe.gate.weight:ernie.layers.21.mlp.gate.weight_1 -ernie.layers.21.mlp.image_fused_moe.experts.gate_correction_bias:ernie.layers.21.mlp.moe_statics.e_score_correction_bias ernie.layers.21.mlp.image_fused_moe.experts.up_gate_proj_weight:['ernie.layers.21.mlp.experts.32.up_gate_proj.weight', 'ernie.layers.21.mlp.experts.33.up_gate_proj.weight', 'ernie.layers.21.mlp.experts.34.up_gate_proj.weight', 'ernie.layers.21.mlp.experts.35.up_gate_proj.weight', 'ernie.layers.21.mlp.experts.36.up_gate_proj.weight', 'ernie.layers.21.mlp.experts.37.up_gate_proj.weight', 'ernie.layers.21.mlp.experts.38.up_gate_proj.weight', 'ernie.layers.21.mlp.experts.39.up_gate_proj.weight', 'ernie.layers.21.mlp.experts.40.up_gate_proj.weight', 'ernie.layers.21.mlp.experts.41.up_gate_proj.weight', 'ernie.layers.21.mlp.experts.42.up_gate_proj.weight', 'ernie.layers.21.mlp.experts.43.up_gate_proj.weight', 'ernie.layers.21.mlp.experts.44.up_gate_proj.weight', 'ernie.layers.21.mlp.experts.45.up_gate_proj.weight', 'ernie.layers.21.mlp.experts.46.up_gate_proj.weight', 'ernie.layers.21.mlp.experts.47.up_gate_proj.weight', 'ernie.layers.21.mlp.experts.48.up_gate_proj.weight', 'ernie.layers.21.mlp.experts.49.up_gate_proj.weight', 'ernie.layers.21.mlp.experts.50.up_gate_proj.weight', 'ernie.layers.21.mlp.experts.51.up_gate_proj.weight', 'ernie.layers.21.mlp.experts.52.up_gate_proj.weight', 'ernie.layers.21.mlp.experts.53.up_gate_proj.weight', 'ernie.layers.21.mlp.experts.54.up_gate_proj.weight', 'ernie.layers.21.mlp.experts.55.up_gate_proj.weight', 'ernie.layers.21.mlp.experts.56.up_gate_proj.weight', 'ernie.layers.21.mlp.experts.57.up_gate_proj.weight', 'ernie.layers.21.mlp.experts.58.up_gate_proj.weight', 'ernie.layers.21.mlp.experts.59.up_gate_proj.weight', 'ernie.layers.21.mlp.experts.60.up_gate_proj.weight', 'ernie.layers.21.mlp.experts.61.up_gate_proj.weight', 'ernie.layers.21.mlp.experts.62.up_gate_proj.weight', 'ernie.layers.21.mlp.experts.63.up_gate_proj.weight', 'ernie.layers.21.mlp.experts.96.up_gate_proj.weight', 'ernie.layers.21.mlp.experts.97.up_gate_proj.weight', 'ernie.layers.21.mlp.experts.98.up_gate_proj.weight', 'ernie.layers.21.mlp.experts.99.up_gate_proj.weight', 'ernie.layers.21.mlp.experts.100.up_gate_proj.weight', 'ernie.layers.21.mlp.experts.101.up_gate_proj.weight', 'ernie.layers.21.mlp.experts.102.up_gate_proj.weight', 'ernie.layers.21.mlp.experts.103.up_gate_proj.weight', 'ernie.layers.21.mlp.experts.104.up_gate_proj.weight', 'ernie.layers.21.mlp.experts.105.up_gate_proj.weight', 'ernie.layers.21.mlp.experts.106.up_gate_proj.weight', 'ernie.layers.21.mlp.experts.107.up_gate_proj.weight', 'ernie.layers.21.mlp.experts.108.up_gate_proj.weight', 'ernie.layers.21.mlp.experts.109.up_gate_proj.weight', 'ernie.layers.21.mlp.experts.110.up_gate_proj.weight', 'ernie.layers.21.mlp.experts.111.up_gate_proj.weight', 'ernie.layers.21.mlp.experts.112.up_gate_proj.weight', 'ernie.layers.21.mlp.experts.113.up_gate_proj.weight', 'ernie.layers.21.mlp.experts.114.up_gate_proj.weight', 'ernie.layers.21.mlp.experts.115.up_gate_proj.weight', 'ernie.layers.21.mlp.experts.116.up_gate_proj.weight', 'ernie.layers.21.mlp.experts.117.up_gate_proj.weight', 'ernie.layers.21.mlp.experts.118.up_gate_proj.weight', 'ernie.layers.21.mlp.experts.119.up_gate_proj.weight', 'ernie.layers.21.mlp.experts.120.up_gate_proj.weight', 'ernie.layers.21.mlp.experts.121.up_gate_proj.weight', 'ernie.layers.21.mlp.experts.122.up_gate_proj.weight', 'ernie.layers.21.mlp.experts.123.up_gate_proj.weight', 'ernie.layers.21.mlp.experts.124.up_gate_proj.weight', 'ernie.layers.21.mlp.experts.125.up_gate_proj.weight', 'ernie.layers.21.mlp.experts.126.up_gate_proj.weight', 'ernie.layers.21.mlp.experts.127.up_gate_proj.weight'] ernie.layers.21.mlp.image_fused_moe.experts.down_proj_weight:['ernie.layers.21.mlp.experts.32.down_proj.weight', 'ernie.layers.21.mlp.experts.33.down_proj.weight', 'ernie.layers.21.mlp.experts.34.down_proj.weight', 'ernie.layers.21.mlp.experts.35.down_proj.weight', 'ernie.layers.21.mlp.experts.36.down_proj.weight', 'ernie.layers.21.mlp.experts.37.down_proj.weight', 'ernie.layers.21.mlp.experts.38.down_proj.weight', 'ernie.layers.21.mlp.experts.39.down_proj.weight', 'ernie.layers.21.mlp.experts.40.down_proj.weight', 'ernie.layers.21.mlp.experts.41.down_proj.weight', 'ernie.layers.21.mlp.experts.42.down_proj.weight', 'ernie.layers.21.mlp.experts.43.down_proj.weight', 'ernie.layers.21.mlp.experts.44.down_proj.weight', 'ernie.layers.21.mlp.experts.45.down_proj.weight', 'ernie.layers.21.mlp.experts.46.down_proj.weight', 'ernie.layers.21.mlp.experts.47.down_proj.weight', 'ernie.layers.21.mlp.experts.48.down_proj.weight', 'ernie.layers.21.mlp.experts.49.down_proj.weight', 'ernie.layers.21.mlp.experts.50.down_proj.weight', 'ernie.layers.21.mlp.experts.51.down_proj.weight', 'ernie.layers.21.mlp.experts.52.down_proj.weight', 'ernie.layers.21.mlp.experts.53.down_proj.weight', 'ernie.layers.21.mlp.experts.54.down_proj.weight', 'ernie.layers.21.mlp.experts.55.down_proj.weight', 'ernie.layers.21.mlp.experts.56.down_proj.weight', 'ernie.layers.21.mlp.experts.57.down_proj.weight', 'ernie.layers.21.mlp.experts.58.down_proj.weight', 'ernie.layers.21.mlp.experts.59.down_proj.weight', 'ernie.layers.21.mlp.experts.60.down_proj.weight', 'ernie.layers.21.mlp.experts.61.down_proj.weight', 'ernie.layers.21.mlp.experts.62.down_proj.weight', 'ernie.layers.21.mlp.experts.63.down_proj.weight', 'ernie.layers.21.mlp.experts.96.down_proj.weight', 'ernie.layers.21.mlp.experts.97.down_proj.weight', 'ernie.layers.21.mlp.experts.98.down_proj.weight', 'ernie.layers.21.mlp.experts.99.down_proj.weight', 'ernie.layers.21.mlp.experts.100.down_proj.weight', 'ernie.layers.21.mlp.experts.101.down_proj.weight', 'ernie.layers.21.mlp.experts.102.down_proj.weight', 'ernie.layers.21.mlp.experts.103.down_proj.weight', 'ernie.layers.21.mlp.experts.104.down_proj.weight', 'ernie.layers.21.mlp.experts.105.down_proj.weight', 'ernie.layers.21.mlp.experts.106.down_proj.weight', 'ernie.layers.21.mlp.experts.107.down_proj.weight', 'ernie.layers.21.mlp.experts.108.down_proj.weight', 'ernie.layers.21.mlp.experts.109.down_proj.weight', 'ernie.layers.21.mlp.experts.110.down_proj.weight', 'ernie.layers.21.mlp.experts.111.down_proj.weight', 'ernie.layers.21.mlp.experts.112.down_proj.weight', 'ernie.layers.21.mlp.experts.113.down_proj.weight', 'ernie.layers.21.mlp.experts.114.down_proj.weight', 'ernie.layers.21.mlp.experts.115.down_proj.weight', 'ernie.layers.21.mlp.experts.116.down_proj.weight', 'ernie.layers.21.mlp.experts.117.down_proj.weight', 'ernie.layers.21.mlp.experts.118.down_proj.weight', 'ernie.layers.21.mlp.experts.119.down_proj.weight', 'ernie.layers.21.mlp.experts.120.down_proj.weight', 'ernie.layers.21.mlp.experts.121.down_proj.weight', 'ernie.layers.21.mlp.experts.122.down_proj.weight', 'ernie.layers.21.mlp.experts.123.down_proj.weight', 'ernie.layers.21.mlp.experts.124.down_proj.weight', 'ernie.layers.21.mlp.experts.125.down_proj.weight', 'ernie.layers.21.mlp.experts.126.down_proj.weight', 'ernie.layers.21.mlp.experts.127.down_proj.weight'] ernie.layers.22.mlp.image_fused_moe.gate.weight:ernie.layers.22.mlp.gate.weight_1 -ernie.layers.22.mlp.image_fused_moe.experts.gate_correction_bias:ernie.layers.22.mlp.moe_statics.e_score_correction_bias ernie.layers.22.mlp.image_fused_moe.experts.up_gate_proj_weight:['ernie.layers.22.mlp.experts.32.up_gate_proj.weight', 'ernie.layers.22.mlp.experts.33.up_gate_proj.weight', 'ernie.layers.22.mlp.experts.34.up_gate_proj.weight', 'ernie.layers.22.mlp.experts.35.up_gate_proj.weight', 'ernie.layers.22.mlp.experts.36.up_gate_proj.weight', 'ernie.layers.22.mlp.experts.37.up_gate_proj.weight', 'ernie.layers.22.mlp.experts.38.up_gate_proj.weight', 'ernie.layers.22.mlp.experts.39.up_gate_proj.weight', 'ernie.layers.22.mlp.experts.40.up_gate_proj.weight', 'ernie.layers.22.mlp.experts.41.up_gate_proj.weight', 'ernie.layers.22.mlp.experts.42.up_gate_proj.weight', 'ernie.layers.22.mlp.experts.43.up_gate_proj.weight', 'ernie.layers.22.mlp.experts.44.up_gate_proj.weight', 'ernie.layers.22.mlp.experts.45.up_gate_proj.weight', 'ernie.layers.22.mlp.experts.46.up_gate_proj.weight', 'ernie.layers.22.mlp.experts.47.up_gate_proj.weight', 'ernie.layers.22.mlp.experts.48.up_gate_proj.weight', 'ernie.layers.22.mlp.experts.49.up_gate_proj.weight', 'ernie.layers.22.mlp.experts.50.up_gate_proj.weight', 'ernie.layers.22.mlp.experts.51.up_gate_proj.weight', 'ernie.layers.22.mlp.experts.52.up_gate_proj.weight', 'ernie.layers.22.mlp.experts.53.up_gate_proj.weight', 'ernie.layers.22.mlp.experts.54.up_gate_proj.weight', 'ernie.layers.22.mlp.experts.55.up_gate_proj.weight', 'ernie.layers.22.mlp.experts.56.up_gate_proj.weight', 'ernie.layers.22.mlp.experts.57.up_gate_proj.weight', 'ernie.layers.22.mlp.experts.58.up_gate_proj.weight', 'ernie.layers.22.mlp.experts.59.up_gate_proj.weight', 'ernie.layers.22.mlp.experts.60.up_gate_proj.weight', 'ernie.layers.22.mlp.experts.61.up_gate_proj.weight', 'ernie.layers.22.mlp.experts.62.up_gate_proj.weight', 'ernie.layers.22.mlp.experts.63.up_gate_proj.weight', 'ernie.layers.22.mlp.experts.96.up_gate_proj.weight', 'ernie.layers.22.mlp.experts.97.up_gate_proj.weight', 'ernie.layers.22.mlp.experts.98.up_gate_proj.weight', 'ernie.layers.22.mlp.experts.99.up_gate_proj.weight', 'ernie.layers.22.mlp.experts.100.up_gate_proj.weight', 'ernie.layers.22.mlp.experts.101.up_gate_proj.weight', 'ernie.layers.22.mlp.experts.102.up_gate_proj.weight', 'ernie.layers.22.mlp.experts.103.up_gate_proj.weight', 'ernie.layers.22.mlp.experts.104.up_gate_proj.weight', 'ernie.layers.22.mlp.experts.105.up_gate_proj.weight', 'ernie.layers.22.mlp.experts.106.up_gate_proj.weight', 'ernie.layers.22.mlp.experts.107.up_gate_proj.weight', 'ernie.layers.22.mlp.experts.108.up_gate_proj.weight', 'ernie.layers.22.mlp.experts.109.up_gate_proj.weight', 'ernie.layers.22.mlp.experts.110.up_gate_proj.weight', 'ernie.layers.22.mlp.experts.111.up_gate_proj.weight', 'ernie.layers.22.mlp.experts.112.up_gate_proj.weight', 'ernie.layers.22.mlp.experts.113.up_gate_proj.weight', 'ernie.layers.22.mlp.experts.114.up_gate_proj.weight', 'ernie.layers.22.mlp.experts.115.up_gate_proj.weight', 'ernie.layers.22.mlp.experts.116.up_gate_proj.weight', 'ernie.layers.22.mlp.experts.117.up_gate_proj.weight', 'ernie.layers.22.mlp.experts.118.up_gate_proj.weight', 'ernie.layers.22.mlp.experts.119.up_gate_proj.weight', 'ernie.layers.22.mlp.experts.120.up_gate_proj.weight', 'ernie.layers.22.mlp.experts.121.up_gate_proj.weight', 'ernie.layers.22.mlp.experts.122.up_gate_proj.weight', 'ernie.layers.22.mlp.experts.123.up_gate_proj.weight', 'ernie.layers.22.mlp.experts.124.up_gate_proj.weight', 'ernie.layers.22.mlp.experts.125.up_gate_proj.weight', 'ernie.layers.22.mlp.experts.126.up_gate_proj.weight', 'ernie.layers.22.mlp.experts.127.up_gate_proj.weight'] ernie.layers.22.mlp.image_fused_moe.experts.down_proj_weight:['ernie.layers.22.mlp.experts.32.down_proj.weight', 'ernie.layers.22.mlp.experts.33.down_proj.weight', 'ernie.layers.22.mlp.experts.34.down_proj.weight', 'ernie.layers.22.mlp.experts.35.down_proj.weight', 'ernie.layers.22.mlp.experts.36.down_proj.weight', 'ernie.layers.22.mlp.experts.37.down_proj.weight', 'ernie.layers.22.mlp.experts.38.down_proj.weight', 'ernie.layers.22.mlp.experts.39.down_proj.weight', 'ernie.layers.22.mlp.experts.40.down_proj.weight', 'ernie.layers.22.mlp.experts.41.down_proj.weight', 'ernie.layers.22.mlp.experts.42.down_proj.weight', 'ernie.layers.22.mlp.experts.43.down_proj.weight', 'ernie.layers.22.mlp.experts.44.down_proj.weight', 'ernie.layers.22.mlp.experts.45.down_proj.weight', 'ernie.layers.22.mlp.experts.46.down_proj.weight', 'ernie.layers.22.mlp.experts.47.down_proj.weight', 'ernie.layers.22.mlp.experts.48.down_proj.weight', 'ernie.layers.22.mlp.experts.49.down_proj.weight', 'ernie.layers.22.mlp.experts.50.down_proj.weight', 'ernie.layers.22.mlp.experts.51.down_proj.weight', 'ernie.layers.22.mlp.experts.52.down_proj.weight', 'ernie.layers.22.mlp.experts.53.down_proj.weight', 'ernie.layers.22.mlp.experts.54.down_proj.weight', 'ernie.layers.22.mlp.experts.55.down_proj.weight', 'ernie.layers.22.mlp.experts.56.down_proj.weight', 'ernie.layers.22.mlp.experts.57.down_proj.weight', 'ernie.layers.22.mlp.experts.58.down_proj.weight', 'ernie.layers.22.mlp.experts.59.down_proj.weight', 'ernie.layers.22.mlp.experts.60.down_proj.weight', 'ernie.layers.22.mlp.experts.61.down_proj.weight', 'ernie.layers.22.mlp.experts.62.down_proj.weight', 'ernie.layers.22.mlp.experts.63.down_proj.weight', 'ernie.layers.22.mlp.experts.96.down_proj.weight', 'ernie.layers.22.mlp.experts.97.down_proj.weight', 'ernie.layers.22.mlp.experts.98.down_proj.weight', 'ernie.layers.22.mlp.experts.99.down_proj.weight', 'ernie.layers.22.mlp.experts.100.down_proj.weight', 'ernie.layers.22.mlp.experts.101.down_proj.weight', 'ernie.layers.22.mlp.experts.102.down_proj.weight', 'ernie.layers.22.mlp.experts.103.down_proj.weight', 'ernie.layers.22.mlp.experts.104.down_proj.weight', 'ernie.layers.22.mlp.experts.105.down_proj.weight', 'ernie.layers.22.mlp.experts.106.down_proj.weight', 'ernie.layers.22.mlp.experts.107.down_proj.weight', 'ernie.layers.22.mlp.experts.108.down_proj.weight', 'ernie.layers.22.mlp.experts.109.down_proj.weight', 'ernie.layers.22.mlp.experts.110.down_proj.weight', 'ernie.layers.22.mlp.experts.111.down_proj.weight', 'ernie.layers.22.mlp.experts.112.down_proj.weight', 'ernie.layers.22.mlp.experts.113.down_proj.weight', 'ernie.layers.22.mlp.experts.114.down_proj.weight', 'ernie.layers.22.mlp.experts.115.down_proj.weight', 'ernie.layers.22.mlp.experts.116.down_proj.weight', 'ernie.layers.22.mlp.experts.117.down_proj.weight', 'ernie.layers.22.mlp.experts.118.down_proj.weight', 'ernie.layers.22.mlp.experts.119.down_proj.weight', 'ernie.layers.22.mlp.experts.120.down_proj.weight', 'ernie.layers.22.mlp.experts.121.down_proj.weight', 'ernie.layers.22.mlp.experts.122.down_proj.weight', 'ernie.layers.22.mlp.experts.123.down_proj.weight', 'ernie.layers.22.mlp.experts.124.down_proj.weight', 'ernie.layers.22.mlp.experts.125.down_proj.weight', 'ernie.layers.22.mlp.experts.126.down_proj.weight', 'ernie.layers.22.mlp.experts.127.down_proj.weight'] ernie.layers.23.mlp.image_fused_moe.gate.weight:ernie.layers.23.mlp.gate.weight_1 -ernie.layers.23.mlp.image_fused_moe.experts.gate_correction_bias:ernie.layers.23.mlp.moe_statics.e_score_correction_bias ernie.layers.23.mlp.image_fused_moe.experts.up_gate_proj_weight:['ernie.layers.23.mlp.experts.32.up_gate_proj.weight', 'ernie.layers.23.mlp.experts.33.up_gate_proj.weight', 'ernie.layers.23.mlp.experts.34.up_gate_proj.weight', 'ernie.layers.23.mlp.experts.35.up_gate_proj.weight', 'ernie.layers.23.mlp.experts.36.up_gate_proj.weight', 'ernie.layers.23.mlp.experts.37.up_gate_proj.weight', 'ernie.layers.23.mlp.experts.38.up_gate_proj.weight', 'ernie.layers.23.mlp.experts.39.up_gate_proj.weight', 'ernie.layers.23.mlp.experts.40.up_gate_proj.weight', 'ernie.layers.23.mlp.experts.41.up_gate_proj.weight', 'ernie.layers.23.mlp.experts.42.up_gate_proj.weight', 'ernie.layers.23.mlp.experts.43.up_gate_proj.weight', 'ernie.layers.23.mlp.experts.44.up_gate_proj.weight', 'ernie.layers.23.mlp.experts.45.up_gate_proj.weight', 'ernie.layers.23.mlp.experts.46.up_gate_proj.weight', 'ernie.layers.23.mlp.experts.47.up_gate_proj.weight', 'ernie.layers.23.mlp.experts.48.up_gate_proj.weight', 'ernie.layers.23.mlp.experts.49.up_gate_proj.weight', 'ernie.layers.23.mlp.experts.50.up_gate_proj.weight', 'ernie.layers.23.mlp.experts.51.up_gate_proj.weight', 'ernie.layers.23.mlp.experts.52.up_gate_proj.weight', 'ernie.layers.23.mlp.experts.53.up_gate_proj.weight', 'ernie.layers.23.mlp.experts.54.up_gate_proj.weight', 'ernie.layers.23.mlp.experts.55.up_gate_proj.weight', 'ernie.layers.23.mlp.experts.56.up_gate_proj.weight', 'ernie.layers.23.mlp.experts.57.up_gate_proj.weight', 'ernie.layers.23.mlp.experts.58.up_gate_proj.weight', 'ernie.layers.23.mlp.experts.59.up_gate_proj.weight', 'ernie.layers.23.mlp.experts.60.up_gate_proj.weight', 'ernie.layers.23.mlp.experts.61.up_gate_proj.weight', 'ernie.layers.23.mlp.experts.62.up_gate_proj.weight', 'ernie.layers.23.mlp.experts.63.up_gate_proj.weight', 'ernie.layers.23.mlp.experts.96.up_gate_proj.weight', 'ernie.layers.23.mlp.experts.97.up_gate_proj.weight', 'ernie.layers.23.mlp.experts.98.up_gate_proj.weight', 'ernie.layers.23.mlp.experts.99.up_gate_proj.weight', 'ernie.layers.23.mlp.experts.100.up_gate_proj.weight', 'ernie.layers.23.mlp.experts.101.up_gate_proj.weight', 'ernie.layers.23.mlp.experts.102.up_gate_proj.weight', 'ernie.layers.23.mlp.experts.103.up_gate_proj.weight', 'ernie.layers.23.mlp.experts.104.up_gate_proj.weight', 'ernie.layers.23.mlp.experts.105.up_gate_proj.weight', 'ernie.layers.23.mlp.experts.106.up_gate_proj.weight', 'ernie.layers.23.mlp.experts.107.up_gate_proj.weight', 'ernie.layers.23.mlp.experts.108.up_gate_proj.weight', 'ernie.layers.23.mlp.experts.109.up_gate_proj.weight', 'ernie.layers.23.mlp.experts.110.up_gate_proj.weight', 'ernie.layers.23.mlp.experts.111.up_gate_proj.weight', 'ernie.layers.23.mlp.experts.112.up_gate_proj.weight', 'ernie.layers.23.mlp.experts.113.up_gate_proj.weight', 'ernie.layers.23.mlp.experts.114.up_gate_proj.weight', 'ernie.layers.23.mlp.experts.115.up_gate_proj.weight', 'ernie.layers.23.mlp.experts.116.up_gate_proj.weight', 'ernie.layers.23.mlp.experts.117.up_gate_proj.weight', 'ernie.layers.23.mlp.experts.118.up_gate_proj.weight', 'ernie.layers.23.mlp.experts.119.up_gate_proj.weight', 'ernie.layers.23.mlp.experts.120.up_gate_proj.weight', 'ernie.layers.23.mlp.experts.121.up_gate_proj.weight', 'ernie.layers.23.mlp.experts.122.up_gate_proj.weight', 'ernie.layers.23.mlp.experts.123.up_gate_proj.weight', 'ernie.layers.23.mlp.experts.124.up_gate_proj.weight', 'ernie.layers.23.mlp.experts.125.up_gate_proj.weight', 'ernie.layers.23.mlp.experts.126.up_gate_proj.weight', 'ernie.layers.23.mlp.experts.127.up_gate_proj.weight'] ernie.layers.23.mlp.image_fused_moe.experts.down_proj_weight:['ernie.layers.23.mlp.experts.32.down_proj.weight', 'ernie.layers.23.mlp.experts.33.down_proj.weight', 'ernie.layers.23.mlp.experts.34.down_proj.weight', 'ernie.layers.23.mlp.experts.35.down_proj.weight', 'ernie.layers.23.mlp.experts.36.down_proj.weight', 'ernie.layers.23.mlp.experts.37.down_proj.weight', 'ernie.layers.23.mlp.experts.38.down_proj.weight', 'ernie.layers.23.mlp.experts.39.down_proj.weight', 'ernie.layers.23.mlp.experts.40.down_proj.weight', 'ernie.layers.23.mlp.experts.41.down_proj.weight', 'ernie.layers.23.mlp.experts.42.down_proj.weight', 'ernie.layers.23.mlp.experts.43.down_proj.weight', 'ernie.layers.23.mlp.experts.44.down_proj.weight', 'ernie.layers.23.mlp.experts.45.down_proj.weight', 'ernie.layers.23.mlp.experts.46.down_proj.weight', 'ernie.layers.23.mlp.experts.47.down_proj.weight', 'ernie.layers.23.mlp.experts.48.down_proj.weight', 'ernie.layers.23.mlp.experts.49.down_proj.weight', 'ernie.layers.23.mlp.experts.50.down_proj.weight', 'ernie.layers.23.mlp.experts.51.down_proj.weight', 'ernie.layers.23.mlp.experts.52.down_proj.weight', 'ernie.layers.23.mlp.experts.53.down_proj.weight', 'ernie.layers.23.mlp.experts.54.down_proj.weight', 'ernie.layers.23.mlp.experts.55.down_proj.weight', 'ernie.layers.23.mlp.experts.56.down_proj.weight', 'ernie.layers.23.mlp.experts.57.down_proj.weight', 'ernie.layers.23.mlp.experts.58.down_proj.weight', 'ernie.layers.23.mlp.experts.59.down_proj.weight', 'ernie.layers.23.mlp.experts.60.down_proj.weight', 'ernie.layers.23.mlp.experts.61.down_proj.weight', 'ernie.layers.23.mlp.experts.62.down_proj.weight', 'ernie.layers.23.mlp.experts.63.down_proj.weight', 'ernie.layers.23.mlp.experts.96.down_proj.weight', 'ernie.layers.23.mlp.experts.97.down_proj.weight', 'ernie.layers.23.mlp.experts.98.down_proj.weight', 'ernie.layers.23.mlp.experts.99.down_proj.weight', 'ernie.layers.23.mlp.experts.100.down_proj.weight', 'ernie.layers.23.mlp.experts.101.down_proj.weight', 'ernie.layers.23.mlp.experts.102.down_proj.weight', 'ernie.layers.23.mlp.experts.103.down_proj.weight', 'ernie.layers.23.mlp.experts.104.down_proj.weight', 'ernie.layers.23.mlp.experts.105.down_proj.weight', 'ernie.layers.23.mlp.experts.106.down_proj.weight', 'ernie.layers.23.mlp.experts.107.down_proj.weight', 'ernie.layers.23.mlp.experts.108.down_proj.weight', 'ernie.layers.23.mlp.experts.109.down_proj.weight', 'ernie.layers.23.mlp.experts.110.down_proj.weight', 'ernie.layers.23.mlp.experts.111.down_proj.weight', 'ernie.layers.23.mlp.experts.112.down_proj.weight', 'ernie.layers.23.mlp.experts.113.down_proj.weight', 'ernie.layers.23.mlp.experts.114.down_proj.weight', 'ernie.layers.23.mlp.experts.115.down_proj.weight', 'ernie.layers.23.mlp.experts.116.down_proj.weight', 'ernie.layers.23.mlp.experts.117.down_proj.weight', 'ernie.layers.23.mlp.experts.118.down_proj.weight', 'ernie.layers.23.mlp.experts.119.down_proj.weight', 'ernie.layers.23.mlp.experts.120.down_proj.weight', 'ernie.layers.23.mlp.experts.121.down_proj.weight', 'ernie.layers.23.mlp.experts.122.down_proj.weight', 'ernie.layers.23.mlp.experts.123.down_proj.weight', 'ernie.layers.23.mlp.experts.124.down_proj.weight', 'ernie.layers.23.mlp.experts.125.down_proj.weight', 'ernie.layers.23.mlp.experts.126.down_proj.weight', 'ernie.layers.23.mlp.experts.127.down_proj.weight'] ernie.layers.24.mlp.image_fused_moe.gate.weight:ernie.layers.24.mlp.gate.weight_1 -ernie.layers.24.mlp.image_fused_moe.experts.gate_correction_bias:ernie.layers.24.mlp.moe_statics.e_score_correction_bias ernie.layers.24.mlp.image_fused_moe.experts.up_gate_proj_weight:['ernie.layers.24.mlp.experts.32.up_gate_proj.weight', 'ernie.layers.24.mlp.experts.33.up_gate_proj.weight', 'ernie.layers.24.mlp.experts.34.up_gate_proj.weight', 'ernie.layers.24.mlp.experts.35.up_gate_proj.weight', 'ernie.layers.24.mlp.experts.36.up_gate_proj.weight', 'ernie.layers.24.mlp.experts.37.up_gate_proj.weight', 'ernie.layers.24.mlp.experts.38.up_gate_proj.weight', 'ernie.layers.24.mlp.experts.39.up_gate_proj.weight', 'ernie.layers.24.mlp.experts.40.up_gate_proj.weight', 'ernie.layers.24.mlp.experts.41.up_gate_proj.weight', 'ernie.layers.24.mlp.experts.42.up_gate_proj.weight', 'ernie.layers.24.mlp.experts.43.up_gate_proj.weight', 'ernie.layers.24.mlp.experts.44.up_gate_proj.weight', 'ernie.layers.24.mlp.experts.45.up_gate_proj.weight', 'ernie.layers.24.mlp.experts.46.up_gate_proj.weight', 'ernie.layers.24.mlp.experts.47.up_gate_proj.weight', 'ernie.layers.24.mlp.experts.48.up_gate_proj.weight', 'ernie.layers.24.mlp.experts.49.up_gate_proj.weight', 'ernie.layers.24.mlp.experts.50.up_gate_proj.weight', 'ernie.layers.24.mlp.experts.51.up_gate_proj.weight', 'ernie.layers.24.mlp.experts.52.up_gate_proj.weight', 'ernie.layers.24.mlp.experts.53.up_gate_proj.weight', 'ernie.layers.24.mlp.experts.54.up_gate_proj.weight', 'ernie.layers.24.mlp.experts.55.up_gate_proj.weight', 'ernie.layers.24.mlp.experts.56.up_gate_proj.weight', 'ernie.layers.24.mlp.experts.57.up_gate_proj.weight', 'ernie.layers.24.mlp.experts.58.up_gate_proj.weight', 'ernie.layers.24.mlp.experts.59.up_gate_proj.weight', 'ernie.layers.24.mlp.experts.60.up_gate_proj.weight', 'ernie.layers.24.mlp.experts.61.up_gate_proj.weight', 'ernie.layers.24.mlp.experts.62.up_gate_proj.weight', 'ernie.layers.24.mlp.experts.63.up_gate_proj.weight', 'ernie.layers.24.mlp.experts.96.up_gate_proj.weight', 'ernie.layers.24.mlp.experts.97.up_gate_proj.weight', 'ernie.layers.24.mlp.experts.98.up_gate_proj.weight', 'ernie.layers.24.mlp.experts.99.up_gate_proj.weight', 'ernie.layers.24.mlp.experts.100.up_gate_proj.weight', 'ernie.layers.24.mlp.experts.101.up_gate_proj.weight', 'ernie.layers.24.mlp.experts.102.up_gate_proj.weight', 'ernie.layers.24.mlp.experts.103.up_gate_proj.weight', 'ernie.layers.24.mlp.experts.104.up_gate_proj.weight', 'ernie.layers.24.mlp.experts.105.up_gate_proj.weight', 'ernie.layers.24.mlp.experts.106.up_gate_proj.weight', 'ernie.layers.24.mlp.experts.107.up_gate_proj.weight', 'ernie.layers.24.mlp.experts.108.up_gate_proj.weight', 'ernie.layers.24.mlp.experts.109.up_gate_proj.weight', 'ernie.layers.24.mlp.experts.110.up_gate_proj.weight', 'ernie.layers.24.mlp.experts.111.up_gate_proj.weight', 'ernie.layers.24.mlp.experts.112.up_gate_proj.weight', 'ernie.layers.24.mlp.experts.113.up_gate_proj.weight', 'ernie.layers.24.mlp.experts.114.up_gate_proj.weight', 'ernie.layers.24.mlp.experts.115.up_gate_proj.weight', 'ernie.layers.24.mlp.experts.116.up_gate_proj.weight', 'ernie.layers.24.mlp.experts.117.up_gate_proj.weight', 'ernie.layers.24.mlp.experts.118.up_gate_proj.weight', 'ernie.layers.24.mlp.experts.119.up_gate_proj.weight', 'ernie.layers.24.mlp.experts.120.up_gate_proj.weight', 'ernie.layers.24.mlp.experts.121.up_gate_proj.weight', 'ernie.layers.24.mlp.experts.122.up_gate_proj.weight', 'ernie.layers.24.mlp.experts.123.up_gate_proj.weight', 'ernie.layers.24.mlp.experts.124.up_gate_proj.weight', 'ernie.layers.24.mlp.experts.125.up_gate_proj.weight', 'ernie.layers.24.mlp.experts.126.up_gate_proj.weight', 'ernie.layers.24.mlp.experts.127.up_gate_proj.weight'] ernie.layers.24.mlp.image_fused_moe.experts.down_proj_weight:['ernie.layers.24.mlp.experts.32.down_proj.weight', 'ernie.layers.24.mlp.experts.33.down_proj.weight', 'ernie.layers.24.mlp.experts.34.down_proj.weight', 'ernie.layers.24.mlp.experts.35.down_proj.weight', 'ernie.layers.24.mlp.experts.36.down_proj.weight', 'ernie.layers.24.mlp.experts.37.down_proj.weight', 'ernie.layers.24.mlp.experts.38.down_proj.weight', 'ernie.layers.24.mlp.experts.39.down_proj.weight', 'ernie.layers.24.mlp.experts.40.down_proj.weight', 'ernie.layers.24.mlp.experts.41.down_proj.weight', 'ernie.layers.24.mlp.experts.42.down_proj.weight', 'ernie.layers.24.mlp.experts.43.down_proj.weight', 'ernie.layers.24.mlp.experts.44.down_proj.weight', 'ernie.layers.24.mlp.experts.45.down_proj.weight', 'ernie.layers.24.mlp.experts.46.down_proj.weight', 'ernie.layers.24.mlp.experts.47.down_proj.weight', 'ernie.layers.24.mlp.experts.48.down_proj.weight', 'ernie.layers.24.mlp.experts.49.down_proj.weight', 'ernie.layers.24.mlp.experts.50.down_proj.weight', 'ernie.layers.24.mlp.experts.51.down_proj.weight', 'ernie.layers.24.mlp.experts.52.down_proj.weight', 'ernie.layers.24.mlp.experts.53.down_proj.weight', 'ernie.layers.24.mlp.experts.54.down_proj.weight', 'ernie.layers.24.mlp.experts.55.down_proj.weight', 'ernie.layers.24.mlp.experts.56.down_proj.weight', 'ernie.layers.24.mlp.experts.57.down_proj.weight', 'ernie.layers.24.mlp.experts.58.down_proj.weight', 'ernie.layers.24.mlp.experts.59.down_proj.weight', 'ernie.layers.24.mlp.experts.60.down_proj.weight', 'ernie.layers.24.mlp.experts.61.down_proj.weight', 'ernie.layers.24.mlp.experts.62.down_proj.weight', 'ernie.layers.24.mlp.experts.63.down_proj.weight', 'ernie.layers.24.mlp.experts.96.down_proj.weight', 'ernie.layers.24.mlp.experts.97.down_proj.weight', 'ernie.layers.24.mlp.experts.98.down_proj.weight', 'ernie.layers.24.mlp.experts.99.down_proj.weight', 'ernie.layers.24.mlp.experts.100.down_proj.weight', 'ernie.layers.24.mlp.experts.101.down_proj.weight', 'ernie.layers.24.mlp.experts.102.down_proj.weight', 'ernie.layers.24.mlp.experts.103.down_proj.weight', 'ernie.layers.24.mlp.experts.104.down_proj.weight', 'ernie.layers.24.mlp.experts.105.down_proj.weight', 'ernie.layers.24.mlp.experts.106.down_proj.weight', 'ernie.layers.24.mlp.experts.107.down_proj.weight', 'ernie.layers.24.mlp.experts.108.down_proj.weight', 'ernie.layers.24.mlp.experts.109.down_proj.weight', 'ernie.layers.24.mlp.experts.110.down_proj.weight', 'ernie.layers.24.mlp.experts.111.down_proj.weight', 'ernie.layers.24.mlp.experts.112.down_proj.weight', 'ernie.layers.24.mlp.experts.113.down_proj.weight', 'ernie.layers.24.mlp.experts.114.down_proj.weight', 'ernie.layers.24.mlp.experts.115.down_proj.weight', 'ernie.layers.24.mlp.experts.116.down_proj.weight', 'ernie.layers.24.mlp.experts.117.down_proj.weight', 'ernie.layers.24.mlp.experts.118.down_proj.weight', 'ernie.layers.24.mlp.experts.119.down_proj.weight', 'ernie.layers.24.mlp.experts.120.down_proj.weight', 'ernie.layers.24.mlp.experts.121.down_proj.weight', 'ernie.layers.24.mlp.experts.122.down_proj.weight', 'ernie.layers.24.mlp.experts.123.down_proj.weight', 'ernie.layers.24.mlp.experts.124.down_proj.weight', 'ernie.layers.24.mlp.experts.125.down_proj.weight', 'ernie.layers.24.mlp.experts.126.down_proj.weight', 'ernie.layers.24.mlp.experts.127.down_proj.weight'] ernie.layers.25.mlp.image_fused_moe.gate.weight:ernie.layers.25.mlp.gate.weight_1 -ernie.layers.25.mlp.image_fused_moe.experts.gate_correction_bias:ernie.layers.25.mlp.moe_statics.e_score_correction_bias ernie.layers.25.mlp.image_fused_moe.experts.up_gate_proj_weight:['ernie.layers.25.mlp.experts.32.up_gate_proj.weight', 'ernie.layers.25.mlp.experts.33.up_gate_proj.weight', 'ernie.layers.25.mlp.experts.34.up_gate_proj.weight', 'ernie.layers.25.mlp.experts.35.up_gate_proj.weight', 'ernie.layers.25.mlp.experts.36.up_gate_proj.weight', 'ernie.layers.25.mlp.experts.37.up_gate_proj.weight', 'ernie.layers.25.mlp.experts.38.up_gate_proj.weight', 'ernie.layers.25.mlp.experts.39.up_gate_proj.weight', 'ernie.layers.25.mlp.experts.40.up_gate_proj.weight', 'ernie.layers.25.mlp.experts.41.up_gate_proj.weight', 'ernie.layers.25.mlp.experts.42.up_gate_proj.weight', 'ernie.layers.25.mlp.experts.43.up_gate_proj.weight', 'ernie.layers.25.mlp.experts.44.up_gate_proj.weight', 'ernie.layers.25.mlp.experts.45.up_gate_proj.weight', 'ernie.layers.25.mlp.experts.46.up_gate_proj.weight', 'ernie.layers.25.mlp.experts.47.up_gate_proj.weight', 'ernie.layers.25.mlp.experts.48.up_gate_proj.weight', 'ernie.layers.25.mlp.experts.49.up_gate_proj.weight', 'ernie.layers.25.mlp.experts.50.up_gate_proj.weight', 'ernie.layers.25.mlp.experts.51.up_gate_proj.weight', 'ernie.layers.25.mlp.experts.52.up_gate_proj.weight', 'ernie.layers.25.mlp.experts.53.up_gate_proj.weight', 'ernie.layers.25.mlp.experts.54.up_gate_proj.weight', 'ernie.layers.25.mlp.experts.55.up_gate_proj.weight', 'ernie.layers.25.mlp.experts.56.up_gate_proj.weight', 'ernie.layers.25.mlp.experts.57.up_gate_proj.weight', 'ernie.layers.25.mlp.experts.58.up_gate_proj.weight', 'ernie.layers.25.mlp.experts.59.up_gate_proj.weight', 'ernie.layers.25.mlp.experts.60.up_gate_proj.weight', 'ernie.layers.25.mlp.experts.61.up_gate_proj.weight', 'ernie.layers.25.mlp.experts.62.up_gate_proj.weight', 'ernie.layers.25.mlp.experts.63.up_gate_proj.weight', 'ernie.layers.25.mlp.experts.96.up_gate_proj.weight', 'ernie.layers.25.mlp.experts.97.up_gate_proj.weight', 'ernie.layers.25.mlp.experts.98.up_gate_proj.weight', 'ernie.layers.25.mlp.experts.99.up_gate_proj.weight', 'ernie.layers.25.mlp.experts.100.up_gate_proj.weight', 'ernie.layers.25.mlp.experts.101.up_gate_proj.weight', 'ernie.layers.25.mlp.experts.102.up_gate_proj.weight', 'ernie.layers.25.mlp.experts.103.up_gate_proj.weight', 'ernie.layers.25.mlp.experts.104.up_gate_proj.weight', 'ernie.layers.25.mlp.experts.105.up_gate_proj.weight', 'ernie.layers.25.mlp.experts.106.up_gate_proj.weight', 'ernie.layers.25.mlp.experts.107.up_gate_proj.weight', 'ernie.layers.25.mlp.experts.108.up_gate_proj.weight', 'ernie.layers.25.mlp.experts.109.up_gate_proj.weight', 'ernie.layers.25.mlp.experts.110.up_gate_proj.weight', 'ernie.layers.25.mlp.experts.111.up_gate_proj.weight', 'ernie.layers.25.mlp.experts.112.up_gate_proj.weight', 'ernie.layers.25.mlp.experts.113.up_gate_proj.weight', 'ernie.layers.25.mlp.experts.114.up_gate_proj.weight', 'ernie.layers.25.mlp.experts.115.up_gate_proj.weight', 'ernie.layers.25.mlp.experts.116.up_gate_proj.weight', 'ernie.layers.25.mlp.experts.117.up_gate_proj.weight', 'ernie.layers.25.mlp.experts.118.up_gate_proj.weight', 'ernie.layers.25.mlp.experts.119.up_gate_proj.weight', 'ernie.layers.25.mlp.experts.120.up_gate_proj.weight', 'ernie.layers.25.mlp.experts.121.up_gate_proj.weight', 'ernie.layers.25.mlp.experts.122.up_gate_proj.weight', 'ernie.layers.25.mlp.experts.123.up_gate_proj.weight', 'ernie.layers.25.mlp.experts.124.up_gate_proj.weight', 'ernie.layers.25.mlp.experts.125.up_gate_proj.weight', 'ernie.layers.25.mlp.experts.126.up_gate_proj.weight', 'ernie.layers.25.mlp.experts.127.up_gate_proj.weight'] ernie.layers.25.mlp.image_fused_moe.experts.down_proj_weight:['ernie.layers.25.mlp.experts.32.down_proj.weight', 'ernie.layers.25.mlp.experts.33.down_proj.weight', 'ernie.layers.25.mlp.experts.34.down_proj.weight', 'ernie.layers.25.mlp.experts.35.down_proj.weight', 'ernie.layers.25.mlp.experts.36.down_proj.weight', 'ernie.layers.25.mlp.experts.37.down_proj.weight', 'ernie.layers.25.mlp.experts.38.down_proj.weight', 'ernie.layers.25.mlp.experts.39.down_proj.weight', 'ernie.layers.25.mlp.experts.40.down_proj.weight', 'ernie.layers.25.mlp.experts.41.down_proj.weight', 'ernie.layers.25.mlp.experts.42.down_proj.weight', 'ernie.layers.25.mlp.experts.43.down_proj.weight', 'ernie.layers.25.mlp.experts.44.down_proj.weight', 'ernie.layers.25.mlp.experts.45.down_proj.weight', 'ernie.layers.25.mlp.experts.46.down_proj.weight', 'ernie.layers.25.mlp.experts.47.down_proj.weight', 'ernie.layers.25.mlp.experts.48.down_proj.weight', 'ernie.layers.25.mlp.experts.49.down_proj.weight', 'ernie.layers.25.mlp.experts.50.down_proj.weight', 'ernie.layers.25.mlp.experts.51.down_proj.weight', 'ernie.layers.25.mlp.experts.52.down_proj.weight', 'ernie.layers.25.mlp.experts.53.down_proj.weight', 'ernie.layers.25.mlp.experts.54.down_proj.weight', 'ernie.layers.25.mlp.experts.55.down_proj.weight', 'ernie.layers.25.mlp.experts.56.down_proj.weight', 'ernie.layers.25.mlp.experts.57.down_proj.weight', 'ernie.layers.25.mlp.experts.58.down_proj.weight', 'ernie.layers.25.mlp.experts.59.down_proj.weight', 'ernie.layers.25.mlp.experts.60.down_proj.weight', 'ernie.layers.25.mlp.experts.61.down_proj.weight', 'ernie.layers.25.mlp.experts.62.down_proj.weight', 'ernie.layers.25.mlp.experts.63.down_proj.weight', 'ernie.layers.25.mlp.experts.96.down_proj.weight', 'ernie.layers.25.mlp.experts.97.down_proj.weight', 'ernie.layers.25.mlp.experts.98.down_proj.weight', 'ernie.layers.25.mlp.experts.99.down_proj.weight', 'ernie.layers.25.mlp.experts.100.down_proj.weight', 'ernie.layers.25.mlp.experts.101.down_proj.weight', 'ernie.layers.25.mlp.experts.102.down_proj.weight', 'ernie.layers.25.mlp.experts.103.down_proj.weight', 'ernie.layers.25.mlp.experts.104.down_proj.weight', 'ernie.layers.25.mlp.experts.105.down_proj.weight', 'ernie.layers.25.mlp.experts.106.down_proj.weight', 'ernie.layers.25.mlp.experts.107.down_proj.weight', 'ernie.layers.25.mlp.experts.108.down_proj.weight', 'ernie.layers.25.mlp.experts.109.down_proj.weight', 'ernie.layers.25.mlp.experts.110.down_proj.weight', 'ernie.layers.25.mlp.experts.111.down_proj.weight', 'ernie.layers.25.mlp.experts.112.down_proj.weight', 'ernie.layers.25.mlp.experts.113.down_proj.weight', 'ernie.layers.25.mlp.experts.114.down_proj.weight', 'ernie.layers.25.mlp.experts.115.down_proj.weight', 'ernie.layers.25.mlp.experts.116.down_proj.weight', 'ernie.layers.25.mlp.experts.117.down_proj.weight', 'ernie.layers.25.mlp.experts.118.down_proj.weight', 'ernie.layers.25.mlp.experts.119.down_proj.weight', 'ernie.layers.25.mlp.experts.120.down_proj.weight', 'ernie.layers.25.mlp.experts.121.down_proj.weight', 'ernie.layers.25.mlp.experts.122.down_proj.weight', 'ernie.layers.25.mlp.experts.123.down_proj.weight', 'ernie.layers.25.mlp.experts.124.down_proj.weight', 'ernie.layers.25.mlp.experts.125.down_proj.weight', 'ernie.layers.25.mlp.experts.126.down_proj.weight', 'ernie.layers.25.mlp.experts.127.down_proj.weight'] ernie.layers.26.mlp.image_fused_moe.gate.weight:ernie.layers.26.mlp.gate.weight_1 -ernie.layers.26.mlp.image_fused_moe.experts.gate_correction_bias:ernie.layers.26.mlp.moe_statics.e_score_correction_bias ernie.layers.26.mlp.image_fused_moe.experts.up_gate_proj_weight:['ernie.layers.26.mlp.experts.32.up_gate_proj.weight', 'ernie.layers.26.mlp.experts.33.up_gate_proj.weight', 'ernie.layers.26.mlp.experts.34.up_gate_proj.weight', 'ernie.layers.26.mlp.experts.35.up_gate_proj.weight', 'ernie.layers.26.mlp.experts.36.up_gate_proj.weight', 'ernie.layers.26.mlp.experts.37.up_gate_proj.weight', 'ernie.layers.26.mlp.experts.38.up_gate_proj.weight', 'ernie.layers.26.mlp.experts.39.up_gate_proj.weight', 'ernie.layers.26.mlp.experts.40.up_gate_proj.weight', 'ernie.layers.26.mlp.experts.41.up_gate_proj.weight', 'ernie.layers.26.mlp.experts.42.up_gate_proj.weight', 'ernie.layers.26.mlp.experts.43.up_gate_proj.weight', 'ernie.layers.26.mlp.experts.44.up_gate_proj.weight', 'ernie.layers.26.mlp.experts.45.up_gate_proj.weight', 'ernie.layers.26.mlp.experts.46.up_gate_proj.weight', 'ernie.layers.26.mlp.experts.47.up_gate_proj.weight', 'ernie.layers.26.mlp.experts.48.up_gate_proj.weight', 'ernie.layers.26.mlp.experts.49.up_gate_proj.weight', 'ernie.layers.26.mlp.experts.50.up_gate_proj.weight', 'ernie.layers.26.mlp.experts.51.up_gate_proj.weight', 'ernie.layers.26.mlp.experts.52.up_gate_proj.weight', 'ernie.layers.26.mlp.experts.53.up_gate_proj.weight', 'ernie.layers.26.mlp.experts.54.up_gate_proj.weight', 'ernie.layers.26.mlp.experts.55.up_gate_proj.weight', 'ernie.layers.26.mlp.experts.56.up_gate_proj.weight', 'ernie.layers.26.mlp.experts.57.up_gate_proj.weight', 'ernie.layers.26.mlp.experts.58.up_gate_proj.weight', 'ernie.layers.26.mlp.experts.59.up_gate_proj.weight', 'ernie.layers.26.mlp.experts.60.up_gate_proj.weight', 'ernie.layers.26.mlp.experts.61.up_gate_proj.weight', 'ernie.layers.26.mlp.experts.62.up_gate_proj.weight', 'ernie.layers.26.mlp.experts.63.up_gate_proj.weight', 'ernie.layers.26.mlp.experts.96.up_gate_proj.weight', 'ernie.layers.26.mlp.experts.97.up_gate_proj.weight', 'ernie.layers.26.mlp.experts.98.up_gate_proj.weight', 'ernie.layers.26.mlp.experts.99.up_gate_proj.weight', 'ernie.layers.26.mlp.experts.100.up_gate_proj.weight', 'ernie.layers.26.mlp.experts.101.up_gate_proj.weight', 'ernie.layers.26.mlp.experts.102.up_gate_proj.weight', 'ernie.layers.26.mlp.experts.103.up_gate_proj.weight', 'ernie.layers.26.mlp.experts.104.up_gate_proj.weight', 'ernie.layers.26.mlp.experts.105.up_gate_proj.weight', 'ernie.layers.26.mlp.experts.106.up_gate_proj.weight', 'ernie.layers.26.mlp.experts.107.up_gate_proj.weight', 'ernie.layers.26.mlp.experts.108.up_gate_proj.weight', 'ernie.layers.26.mlp.experts.109.up_gate_proj.weight', 'ernie.layers.26.mlp.experts.110.up_gate_proj.weight', 'ernie.layers.26.mlp.experts.111.up_gate_proj.weight', 'ernie.layers.26.mlp.experts.112.up_gate_proj.weight', 'ernie.layers.26.mlp.experts.113.up_gate_proj.weight', 'ernie.layers.26.mlp.experts.114.up_gate_proj.weight', 'ernie.layers.26.mlp.experts.115.up_gate_proj.weight', 'ernie.layers.26.mlp.experts.116.up_gate_proj.weight', 'ernie.layers.26.mlp.experts.117.up_gate_proj.weight', 'ernie.layers.26.mlp.experts.118.up_gate_proj.weight', 'ernie.layers.26.mlp.experts.119.up_gate_proj.weight', 'ernie.layers.26.mlp.experts.120.up_gate_proj.weight', 'ernie.layers.26.mlp.experts.121.up_gate_proj.weight', 'ernie.layers.26.mlp.experts.122.up_gate_proj.weight', 'ernie.layers.26.mlp.experts.123.up_gate_proj.weight', 'ernie.layers.26.mlp.experts.124.up_gate_proj.weight', 'ernie.layers.26.mlp.experts.125.up_gate_proj.weight', 'ernie.layers.26.mlp.experts.126.up_gate_proj.weight', 'ernie.layers.26.mlp.experts.127.up_gate_proj.weight'] ernie.layers.26.mlp.image_fused_moe.experts.down_proj_weight:['ernie.layers.26.mlp.experts.32.down_proj.weight', 'ernie.layers.26.mlp.experts.33.down_proj.weight', 'ernie.layers.26.mlp.experts.34.down_proj.weight', 'ernie.layers.26.mlp.experts.35.down_proj.weight', 'ernie.layers.26.mlp.experts.36.down_proj.weight', 'ernie.layers.26.mlp.experts.37.down_proj.weight', 'ernie.layers.26.mlp.experts.38.down_proj.weight', 'ernie.layers.26.mlp.experts.39.down_proj.weight', 'ernie.layers.26.mlp.experts.40.down_proj.weight', 'ernie.layers.26.mlp.experts.41.down_proj.weight', 'ernie.layers.26.mlp.experts.42.down_proj.weight', 'ernie.layers.26.mlp.experts.43.down_proj.weight', 'ernie.layers.26.mlp.experts.44.down_proj.weight', 'ernie.layers.26.mlp.experts.45.down_proj.weight', 'ernie.layers.26.mlp.experts.46.down_proj.weight', 'ernie.layers.26.mlp.experts.47.down_proj.weight', 'ernie.layers.26.mlp.experts.48.down_proj.weight', 'ernie.layers.26.mlp.experts.49.down_proj.weight', 'ernie.layers.26.mlp.experts.50.down_proj.weight', 'ernie.layers.26.mlp.experts.51.down_proj.weight', 'ernie.layers.26.mlp.experts.52.down_proj.weight', 'ernie.layers.26.mlp.experts.53.down_proj.weight', 'ernie.layers.26.mlp.experts.54.down_proj.weight', 'ernie.layers.26.mlp.experts.55.down_proj.weight', 'ernie.layers.26.mlp.experts.56.down_proj.weight', 'ernie.layers.26.mlp.experts.57.down_proj.weight', 'ernie.layers.26.mlp.experts.58.down_proj.weight', 'ernie.layers.26.mlp.experts.59.down_proj.weight', 'ernie.layers.26.mlp.experts.60.down_proj.weight', 'ernie.layers.26.mlp.experts.61.down_proj.weight', 'ernie.layers.26.mlp.experts.62.down_proj.weight', 'ernie.layers.26.mlp.experts.63.down_proj.weight', 'ernie.layers.26.mlp.experts.96.down_proj.weight', 'ernie.layers.26.mlp.experts.97.down_proj.weight', 'ernie.layers.26.mlp.experts.98.down_proj.weight', 'ernie.layers.26.mlp.experts.99.down_proj.weight', 'ernie.layers.26.mlp.experts.100.down_proj.weight', 'ernie.layers.26.mlp.experts.101.down_proj.weight', 'ernie.layers.26.mlp.experts.102.down_proj.weight', 'ernie.layers.26.mlp.experts.103.down_proj.weight', 'ernie.layers.26.mlp.experts.104.down_proj.weight', 'ernie.layers.26.mlp.experts.105.down_proj.weight', 'ernie.layers.26.mlp.experts.106.down_proj.weight', 'ernie.layers.26.mlp.experts.107.down_proj.weight', 'ernie.layers.26.mlp.experts.108.down_proj.weight', 'ernie.layers.26.mlp.experts.109.down_proj.weight', 'ernie.layers.26.mlp.experts.110.down_proj.weight', 'ernie.layers.26.mlp.experts.111.down_proj.weight', 'ernie.layers.26.mlp.experts.112.down_proj.weight', 'ernie.layers.26.mlp.experts.113.down_proj.weight', 'ernie.layers.26.mlp.experts.114.down_proj.weight', 'ernie.layers.26.mlp.experts.115.down_proj.weight', 'ernie.layers.26.mlp.experts.116.down_proj.weight', 'ernie.layers.26.mlp.experts.117.down_proj.weight', 'ernie.layers.26.mlp.experts.118.down_proj.weight', 'ernie.layers.26.mlp.experts.119.down_proj.weight', 'ernie.layers.26.mlp.experts.120.down_proj.weight', 'ernie.layers.26.mlp.experts.121.down_proj.weight', 'ernie.layers.26.mlp.experts.122.down_proj.weight', 'ernie.layers.26.mlp.experts.123.down_proj.weight', 'ernie.layers.26.mlp.experts.124.down_proj.weight', 'ernie.layers.26.mlp.experts.125.down_proj.weight', 'ernie.layers.26.mlp.experts.126.down_proj.weight', 'ernie.layers.26.mlp.experts.127.down_proj.weight'] ernie.layers.27.mlp.image_fused_moe.gate.weight:ernie.layers.27.mlp.gate.weight_1 -ernie.layers.27.mlp.image_fused_moe.experts.gate_correction_bias:ernie.layers.27.mlp.moe_statics.e_score_correction_bias ernie.layers.27.mlp.image_fused_moe.experts.up_gate_proj_weight:['ernie.layers.27.mlp.experts.32.up_gate_proj.weight', 'ernie.layers.27.mlp.experts.33.up_gate_proj.weight', 'ernie.layers.27.mlp.experts.34.up_gate_proj.weight', 'ernie.layers.27.mlp.experts.35.up_gate_proj.weight', 'ernie.layers.27.mlp.experts.36.up_gate_proj.weight', 'ernie.layers.27.mlp.experts.37.up_gate_proj.weight', 'ernie.layers.27.mlp.experts.38.up_gate_proj.weight', 'ernie.layers.27.mlp.experts.39.up_gate_proj.weight', 'ernie.layers.27.mlp.experts.40.up_gate_proj.weight', 'ernie.layers.27.mlp.experts.41.up_gate_proj.weight', 'ernie.layers.27.mlp.experts.42.up_gate_proj.weight', 'ernie.layers.27.mlp.experts.43.up_gate_proj.weight', 'ernie.layers.27.mlp.experts.44.up_gate_proj.weight', 'ernie.layers.27.mlp.experts.45.up_gate_proj.weight', 'ernie.layers.27.mlp.experts.46.up_gate_proj.weight', 'ernie.layers.27.mlp.experts.47.up_gate_proj.weight', 'ernie.layers.27.mlp.experts.48.up_gate_proj.weight', 'ernie.layers.27.mlp.experts.49.up_gate_proj.weight', 'ernie.layers.27.mlp.experts.50.up_gate_proj.weight', 'ernie.layers.27.mlp.experts.51.up_gate_proj.weight', 'ernie.layers.27.mlp.experts.52.up_gate_proj.weight', 'ernie.layers.27.mlp.experts.53.up_gate_proj.weight', 'ernie.layers.27.mlp.experts.54.up_gate_proj.weight', 'ernie.layers.27.mlp.experts.55.up_gate_proj.weight', 'ernie.layers.27.mlp.experts.56.up_gate_proj.weight', 'ernie.layers.27.mlp.experts.57.up_gate_proj.weight', 'ernie.layers.27.mlp.experts.58.up_gate_proj.weight', 'ernie.layers.27.mlp.experts.59.up_gate_proj.weight', 'ernie.layers.27.mlp.experts.60.up_gate_proj.weight', 'ernie.layers.27.mlp.experts.61.up_gate_proj.weight', 'ernie.layers.27.mlp.experts.62.up_gate_proj.weight', 'ernie.layers.27.mlp.experts.63.up_gate_proj.weight', 'ernie.layers.27.mlp.experts.96.up_gate_proj.weight', 'ernie.layers.27.mlp.experts.97.up_gate_proj.weight', 'ernie.layers.27.mlp.experts.98.up_gate_proj.weight', 'ernie.layers.27.mlp.experts.99.up_gate_proj.weight', 'ernie.layers.27.mlp.experts.100.up_gate_proj.weight', 'ernie.layers.27.mlp.experts.101.up_gate_proj.weight', 'ernie.layers.27.mlp.experts.102.up_gate_proj.weight', 'ernie.layers.27.mlp.experts.103.up_gate_proj.weight', 'ernie.layers.27.mlp.experts.104.up_gate_proj.weight', 'ernie.layers.27.mlp.experts.105.up_gate_proj.weight', 'ernie.layers.27.mlp.experts.106.up_gate_proj.weight', 'ernie.layers.27.mlp.experts.107.up_gate_proj.weight', 'ernie.layers.27.mlp.experts.108.up_gate_proj.weight', 'ernie.layers.27.mlp.experts.109.up_gate_proj.weight', 'ernie.layers.27.mlp.experts.110.up_gate_proj.weight', 'ernie.layers.27.mlp.experts.111.up_gate_proj.weight', 'ernie.layers.27.mlp.experts.112.up_gate_proj.weight', 'ernie.layers.27.mlp.experts.113.up_gate_proj.weight', 'ernie.layers.27.mlp.experts.114.up_gate_proj.weight', 'ernie.layers.27.mlp.experts.115.up_gate_proj.weight', 'ernie.layers.27.mlp.experts.116.up_gate_proj.weight', 'ernie.layers.27.mlp.experts.117.up_gate_proj.weight', 'ernie.layers.27.mlp.experts.118.up_gate_proj.weight', 'ernie.layers.27.mlp.experts.119.up_gate_proj.weight', 'ernie.layers.27.mlp.experts.120.up_gate_proj.weight', 'ernie.layers.27.mlp.experts.121.up_gate_proj.weight', 'ernie.layers.27.mlp.experts.122.up_gate_proj.weight', 'ernie.layers.27.mlp.experts.123.up_gate_proj.weight', 'ernie.layers.27.mlp.experts.124.up_gate_proj.weight', 'ernie.layers.27.mlp.experts.125.up_gate_proj.weight', 'ernie.layers.27.mlp.experts.126.up_gate_proj.weight', 'ernie.layers.27.mlp.experts.127.up_gate_proj.weight'] ernie.layers.27.mlp.image_fused_moe.experts.down_proj_weight:['ernie.layers.27.mlp.experts.32.down_proj.weight', 'ernie.layers.27.mlp.experts.33.down_proj.weight', 'ernie.layers.27.mlp.experts.34.down_proj.weight', 'ernie.layers.27.mlp.experts.35.down_proj.weight', 'ernie.layers.27.mlp.experts.36.down_proj.weight', 'ernie.layers.27.mlp.experts.37.down_proj.weight', 'ernie.layers.27.mlp.experts.38.down_proj.weight', 'ernie.layers.27.mlp.experts.39.down_proj.weight', 'ernie.layers.27.mlp.experts.40.down_proj.weight', 'ernie.layers.27.mlp.experts.41.down_proj.weight', 'ernie.layers.27.mlp.experts.42.down_proj.weight', 'ernie.layers.27.mlp.experts.43.down_proj.weight', 'ernie.layers.27.mlp.experts.44.down_proj.weight', 'ernie.layers.27.mlp.experts.45.down_proj.weight', 'ernie.layers.27.mlp.experts.46.down_proj.weight', 'ernie.layers.27.mlp.experts.47.down_proj.weight', 'ernie.layers.27.mlp.experts.48.down_proj.weight', 'ernie.layers.27.mlp.experts.49.down_proj.weight', 'ernie.layers.27.mlp.experts.50.down_proj.weight', 'ernie.layers.27.mlp.experts.51.down_proj.weight', 'ernie.layers.27.mlp.experts.52.down_proj.weight', 'ernie.layers.27.mlp.experts.53.down_proj.weight', 'ernie.layers.27.mlp.experts.54.down_proj.weight', 'ernie.layers.27.mlp.experts.55.down_proj.weight', 'ernie.layers.27.mlp.experts.56.down_proj.weight', 'ernie.layers.27.mlp.experts.57.down_proj.weight', 'ernie.layers.27.mlp.experts.58.down_proj.weight', 'ernie.layers.27.mlp.experts.59.down_proj.weight', 'ernie.layers.27.mlp.experts.60.down_proj.weight', 'ernie.layers.27.mlp.experts.61.down_proj.weight', 'ernie.layers.27.mlp.experts.62.down_proj.weight', 'ernie.layers.27.mlp.experts.63.down_proj.weight', 'ernie.layers.27.mlp.experts.96.down_proj.weight', 'ernie.layers.27.mlp.experts.97.down_proj.weight', 'ernie.layers.27.mlp.experts.98.down_proj.weight', 'ernie.layers.27.mlp.experts.99.down_proj.weight', 'ernie.layers.27.mlp.experts.100.down_proj.weight', 'ernie.layers.27.mlp.experts.101.down_proj.weight', 'ernie.layers.27.mlp.experts.102.down_proj.weight', 'ernie.layers.27.mlp.experts.103.down_proj.weight', 'ernie.layers.27.mlp.experts.104.down_proj.weight', 'ernie.layers.27.mlp.experts.105.down_proj.weight', 'ernie.layers.27.mlp.experts.106.down_proj.weight', 'ernie.layers.27.mlp.experts.107.down_proj.weight', 'ernie.layers.27.mlp.experts.108.down_proj.weight', 'ernie.layers.27.mlp.experts.109.down_proj.weight', 'ernie.layers.27.mlp.experts.110.down_proj.weight', 'ernie.layers.27.mlp.experts.111.down_proj.weight', 'ernie.layers.27.mlp.experts.112.down_proj.weight', 'ernie.layers.27.mlp.experts.113.down_proj.weight', 'ernie.layers.27.mlp.experts.114.down_proj.weight', 'ernie.layers.27.mlp.experts.115.down_proj.weight', 'ernie.layers.27.mlp.experts.116.down_proj.weight', 'ernie.layers.27.mlp.experts.117.down_proj.weight', 'ernie.layers.27.mlp.experts.118.down_proj.weight', 'ernie.layers.27.mlp.experts.119.down_proj.weight', 'ernie.layers.27.mlp.experts.120.down_proj.weight', 'ernie.layers.27.mlp.experts.121.down_proj.weight', 'ernie.layers.27.mlp.experts.122.down_proj.weight', 'ernie.layers.27.mlp.experts.123.down_proj.weight', 'ernie.layers.27.mlp.experts.124.down_proj.weight', 'ernie.layers.27.mlp.experts.125.down_proj.weight', 'ernie.layers.27.mlp.experts.126.down_proj.weight', 'ernie.layers.27.mlp.experts.127.down_proj.weight'] vision_model.patch_embed.proj.weight:vision_model.patch_embed.proj.weight diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 000000000..80e4047c0 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,120 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +import signal +import socket +import subprocess +from typing import Any, Union + +import pytest + + +def kill_process_on_port(port: int): + """ + Kill processes that are listening on the given port. + Uses `lsof` to find process ids and sends SIGKILL. + """ + try: + output = subprocess.check_output(f"lsof -i:{port} -t", shell=True).decode().strip() + for pid in output.splitlines(): + os.kill(int(pid), signal.SIGKILL) + print(f"Killed process on port {port}, pid={pid}") + except subprocess.CalledProcessError: + pass + + +def clean_ports(ports_to_clean: list[int]): + """ + Kill all processes occupying the ports listed in PORTS_TO_CLEAN. + """ + for port in ports_to_clean: + kill_process_on_port(port) + + +def is_port_open(host: str, port: int, timeout=1.0): + """ + Check if a TCP port is open on the given host. + Returns True if connection succeeds, False otherwise. + """ + try: + with socket.create_connection((host, port), timeout): + return True + except Exception: + return False + + +class FDRunner: + def __init__( + self, + model_name_or_path: str, + tensor_parallel_size: int = 1, + max_model_len: int = 1024, + load_choices: str = "default", + quantization: str = "None", + **kwargs, + ) -> None: + from fastdeploy.entrypoints.llm import LLM + + ports_to_clean = [] + if "engine_worker_queue_port" in kwargs: + ports_to_clean.append(kwargs["engine_worker_queue_port"]) + clean_ports(ports_to_clean) + self.llm = LLM( + model=model_name_or_path, + tensor_parallel_size=tensor_parallel_size, + max_model_len=max_model_len, + load_choices=load_choices, + quantization=quantization, + **kwargs, + ) + + def generate( + self, + prompts: list[str], + sampling_params, + **kwargs: Any, + ) -> list[tuple[list[list[int]], list[str]]]: + + req_outputs = self.llm.generate(prompts, sampling_params=sampling_params, **kwargs) + outputs: list[tuple[list[list[int]], list[str]]] = [] + sample_output_ids: list[list[int]] = [] + sample_output_strs: list[str] = [] + for output in req_outputs: + sample_output_ids.append(output.outputs.token_ids) + sample_output_strs.append(output.outputs.text) + outputs.append((sample_output_ids, sample_output_strs)) + return outputs + + def generate_topp0( + self, + prompts: Union[list[str]], + max_tokens: int, + **kwargs: Any, + ) -> list[tuple[list[int], str]]: + from fastdeploy.engine.sampling_params import SamplingParams + + topp_params = SamplingParams(temperature=0.1, top_p=0, max_tokens=max_tokens) + outputs = self.generate(prompts, topp_params, **kwargs) + return outputs + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, traceback): + del self.llm + + +@pytest.fixture(scope="session") +def fd_runner(): + return FDRunner diff --git a/tests/model_loader/__init__.py b/tests/model_loader/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/model_loader/test_common_model.py b/tests/model_loader/test_common_model.py new file mode 100644 index 000000000..e4eec9925 --- /dev/null +++ b/tests/model_loader/test_common_model.py @@ -0,0 +1,175 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import traceback +import warnings +from multiprocessing import Process, Queue + +import pytest + +FD_ENGINE_QUEUE_PORT = int(os.getenv("FD_ENGINE_QUEUE_PORT", 8313)) +MAX_WAIT_SECONDS = 60 * 5 + +prompts = ["解释下“温故而知新", "Hello, how are you?"] +TokensIdText = list[tuple[list[int], str]] +# (token_ids, text) + + +def check_tokens_id_and_text_close( + *, + outputs_0_lst: TokensIdText, + outputs_1_lst: TokensIdText, + name_0: str, + name_1: str, + warn_on_mismatch: bool = True, +) -> None: + assert len(outputs_0_lst) == len(outputs_1_lst) + + for prompt_idx, (outputs_0, outputs_1) in enumerate(zip(outputs_0_lst, outputs_1_lst)): + assert len(outputs_0) == len(outputs_1) + output_ids_0, output_str_0 = outputs_0 + output_ids_1, output_str_1 = outputs_1 + + # Loop through generated tokens. + for idx, (output_id_0, output_id_1) in enumerate(zip(output_ids_0, output_ids_1)): + is_tok_mismatch = output_id_0 != output_id_1 + if is_tok_mismatch and warn_on_mismatch: + fail_msg = ( + f"Test{prompt_idx}:" + f"\nMatched tokens:\t{output_ids_0[:idx]}" + f"\n{name_0}:\t{output_str_0!r}" + f"\n{name_1}:\t{output_str_1!r}" + ) + with warnings.catch_warnings(): + warnings.simplefilter("always") + warnings.warn(fail_msg, stacklevel=2) + break + else: + if output_str_0 != output_str_1 and warn_on_mismatch: + fail_msg = f"Test{prompt_idx}:" f"\n{name_0}:\t{output_str_0!r}" f"\n{name_1}:\t{output_str_1!r}" + with warnings.catch_warnings(): + warnings.simplefilter("always") + warnings.warn(fail_msg, stacklevel=2) + + +def form_model_get_output( + fd_runner, + model_path, + tensor_parallel_size, + max_model_len, + max_tokens, + quantization, + load_choices, + result_queue, +): + try: + with fd_runner( + model_path, + tensor_parallel_size=tensor_parallel_size, + max_model_len=max_model_len, + load_choices=load_choices, + quantization=quantization, + engine_worker_queue_port=FD_ENGINE_QUEUE_PORT, + ) as fd_model: + fd_outputs = fd_model.generate_topp0(prompts, max_tokens=max_tokens) + result_queue.put(fd_outputs) + except Exception: + print(f"Failed using {load_choices} laoder to load model from {model_path}.") + traceback.print_exc() + pytest.fail(f"Failed to initialize LLM model from {model_path}") + + +model_param_map = { + "Qwen3-0.6B": { + "quantizations": ["None", "wint4", "wint8"], + }, + "ernie-4_5-21b-a3b-bf16-paddle": { + "tensor_parallel_size": 2, + "quantizations": ["wint8"], + }, +} + +params = [] +for model, cfg in model_param_map.items(): + for q in cfg["quantizations"]: + params.append( + pytest.param( + model, + cfg.get("tensor_parallel_size", 1), + cfg.get("max_model_len", 1024), + q, + cfg.get("max_tokens", 32), + marks=[pytest.mark.core_model], + ) + ) + + +@pytest.mark.parametrize( + "model_name_or_path,tensor_parallel_size,max_model_len,quantization,max_tokens", + params, +) +def test_common_model( + fd_runner, + model_name_or_path: str, + tensor_parallel_size: int, + max_model_len: int, + max_tokens: int, + quantization: str, +) -> None: + base_path = os.getenv("MODEL_PATH") + if base_path: + model_path = os.path.join(base_path, model_name_or_path) + else: + model_path = model_name_or_path + result_queue = Queue() + p = Process( + target=form_model_get_output, + args=( + fd_runner, + model_path, + tensor_parallel_size, + max_model_len, + max_tokens, + quantization, + "default", + result_queue, + ), + ) + p.start() + p.join() + fd_outputs_v0 = result_queue.get(timeout=60) + + p = Process( + target=form_model_get_output, + args=( + fd_runner, + model_path, + tensor_parallel_size, + max_model_len, + max_tokens, + quantization, + "default_v1", + result_queue, + ), + ) + p.start() + p.join() + fd_outputs_v1 = result_queue.get(timeout=60) + check_tokens_id_and_text_close( + outputs_0_lst=fd_outputs_v0, + outputs_1_lst=fd_outputs_v1, + name_0="default loader", + name_1="default_v1 loader", + )