Simplify the Config code (#2770)

* simplify the code

* fix vl

* delete config

* fix

* perfect code

* fix ci

* fix xpu

* fix xpu

* fix server

* resolve conflict

* fix mtp

* resolve conflict

* fix xpu

* fix xpu

* fix vl

* fix log

* fix qwen moe

* fix qwen moe

* fix qwen moe
Author: YuanRisheng
Date: 2025-07-14 19:50:05 +08:00
Committed by: GitHub
Parent commit: 2e81792d64
Commit: 4c7b8bc458
34 changed files with 551 additions and 911 deletions


```diff
@@ -59,13 +59,11 @@ class VocabParallelEmbedding(nn.Layer):
         self.world_size: int = hcg.get_model_parallel_world_size()
         self.ring_id: int = hcg.get_model_parallel_group().id
-        self.use_rope: bool = fd_config.model_config.use_rope
-        self.rope_head_dim: int = fd_config.model_config.rope_head_dim
         self.use_ep: bool = fd_config.parallel_config.use_ep
         self.hidden_dropout_prob: float = fd_config.model_config.hidden_dropout_prob
         self.initializer_range: float = fd_config.model_config.initializer_range
         self.sequence_parallel: bool = fd_config.parallel_config.sequence_parallel
         self.max_position_embeddings: int = fd_config.model_config.max_position_embeddings
         self.freeze_embedding: bool = fd_config.model_config.freeze_embedding
         self.tie_word_embeddings: bool = fd_config.model_config.tie_word_embeddings
         self.params_dtype: str = params_dtype
```
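For context, each attribute above is read from one nested config object. A minimal sketch of how such a grouped config might be laid out, using illustrative dataclasses (FastDeploy's real FDConfig, ModelConfig, and ParallelConfig definitions carry many more fields; only the nested-access pattern is taken from the diff):

```python
from dataclasses import dataclass, field

# Illustrative stand-ins for FastDeploy's config classes; the field
# names are the ones read by the layer code in the hunk above.
@dataclass
class ModelConfig:
    max_position_embeddings: int = 2048
    initializer_range: float = 0.02
    hidden_dropout_prob: float = 0.1
    freeze_embedding: bool = False
    tie_word_embeddings: bool = False

@dataclass
class ParallelConfig:
    use_ep: bool = False
    sequence_parallel: bool = False

@dataclass
class FDConfig:
    model_config: ModelConfig = field(default_factory=ModelConfig)
    parallel_config: ParallelConfig = field(default_factory=ParallelConfig)

fd_config = FDConfig()
assert fd_config.model_config.max_position_embeddings == 2048
```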
```diff
@@ -104,15 +102,7 @@ class VocabParallelEmbedding(nn.Layer):
         )
         self.prefix = prefix
-        if self.freeze_embedding:
-            self.word_embeddings.weight.learning_rate = 0.0
-            if not self.use_rope:
-                self.position_embeddings.weight.learning_rate = 0.0
-
-        self.dropout = nn.Dropout(self.hidden_dropout_prob)
-        self.rope_head_dim_shape_tensor = paddle.ones((self.rope_head_dim),
-                                                      dtype="int8")

     def load_state_dict(self, state_dict: Dict[str,
                                                 paddle.Tensor | np.ndarray]):
```
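The deleted block froze the embedding table by zeroing its per-parameter learning rate (and dropped a dropout layer and a RoPE shape tensor along the way). For reference, a minimal standalone sketch of the two usual ways to freeze a parameter in Paddle; this is not FastDeploy code:

```python
import paddle
import paddle.nn as nn

# Option 1: a per-parameter learning rate of 0.0 via ParamAttr, the
# effect the deleted `weight.learning_rate = 0.0` assignment had.
emb = nn.Embedding(
    num_embeddings=1000,
    embedding_dim=64,
    weight_attr=paddle.ParamAttr(learning_rate=0.0),
)

# Option 2: exclude the parameter from gradient computation entirely.
emb.weight.stop_gradient = True
```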
```diff
@@ -122,6 +112,7 @@ class VocabParallelEmbedding(nn.Layer):
         Args:
             state_dict (dict): A dictionary containing the checkpoint weights and biases.
         """
+        a = state_dict[self.prefix + ".weight"]
         if self.tie_word_embeddings:
             self.word_embeddings.weight.set_value(
                 get_tensor(state_dict[self.prefix + ".weight"]).astype(
```
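The set_value call above copies a checkpoint tensor into the existing parameter in place. A minimal sketch of the same pattern in plain Paddle, with a numpy round-trip standing in for FastDeploy's get_tensor helper (layer and key names here are made up for illustration):

```python
import numpy as np
import paddle
import paddle.nn as nn

emb = nn.Embedding(num_embeddings=1000, embedding_dim=64)

# Checkpoint entries may arrive as numpy arrays or paddle Tensors.
state_dict = {"embed_tokens.weight": np.ones((1000, 64), dtype="float32")}

w = state_dict["embed_tokens.weight"]
if isinstance(w, np.ndarray):
    w = paddle.to_tensor(w)

# Cast to the parameter's dtype and copy in place, mirroring the
# set_value(get_tensor(...).astype(...)) call in the diff.
emb.weight.set_value(w.astype(emb.weight.dtype))
```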