refactor rl get_name_mappings_to_training (#2847)

* refactor rl get_name_mappings_to_training

* fix tp > 1

* change variable names (ffn1 -> up_gate_proj, ffn2 -> down_proj)

* change variable names (linear_weight -> weight, linear_bias -> bias)

* add rl name mappings for vl

* fix ernie 0.3B error

* fix develop code

* fix

Authored by Yuanle Liu on 2025-07-15 22:31:42 +08:00, committed by GitHub
parent e7bcbbab52
commit 61b3997b85
47 changed files with 1591 additions and 1629 deletions
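The renames listed in the commit message read as a small old-name to new-name table. The sketch below is illustrative only and is not taken from the repository; the dictionary name is made up and it merely restates the pairs described above, the kind of correspondence a helper such as get_name_mappings_to_training has to keep consistent between RL training checkpoints and the inference modules.

# Illustrative summary of the renames described in the commit message.
# This dictionary does not exist in the repository; left side = old identifier,
# right side = new identifier.
RENAMED_IDENTIFIERS = {
    "ffn1": "up_gate_proj",     # fused gate/up projection of the FFN
    "ffn2": "down_proj",        # FFN down projection
    "linear_weight": "weight",  # e.g. linear_weight_key -> weight_key
    "linear_bias": "bias",      # e.g. linear_bias_key -> bias_key
    "out_linear": "linear",     # ParallelLMHead attribute, see the diff below
}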

@@ -52,11 +52,11 @@ class ParallelLMHead(nn.Layer):
             with_bias (bool): whether to have bias. Default: False.
         """
         super(ParallelLMHead, self).__init__()
-        self.linear_weight_key: str = prefix + ".weight"
+        self.weight_key: str = prefix + ".weight"
         if with_bias:
-            self.linear_bias_key: Optional[str] = prefix + ".bias"
+            self.bias_key: Optional[str] = prefix + ".bias"
         else:
-            self.linear_bias_key: Optional[str] = None
+            self.bias_key: Optional[str] = None
         self.use_ep: bool = fd_config.parallel_config.use_ep
         self.column_cut = True
@@ -74,26 +74,26 @@ class ParallelLMHead(nn.Layer):
         else:
             if self.column_cut:
                 need_gather = True
-                self.out_linear = ColumnParallelLinear(
+                self.linear = ColumnParallelLinear(
                     embedding_dim,
                     num_embeddings,
                     mp_group=fleet.get_hybrid_communicate_group().
                     get_model_parallel_group(),
                     weight_attr=None,
                     has_bias=True
-                    if self.linear_bias_key is not None else False,
+                    if self.bias_key is not None else False,
                     gather_output=need_gather,
                     fuse_matmul_bias=False,  # False keeps the diff smaller
                 )
             else:
-                self.out_linear = RowParallelLinear(
+                self.linear = RowParallelLinear(
                     embedding_dim,
                     num_embeddings,
                     mp_group=fleet.get_hybrid_communicate_group().
                     get_model_parallel_group(),
                     weight_attr=None,
                     has_bias=True
-                    if self.linear_bias_key is not None else False,
+                    if self.bias_key is not None else False,
                     input_is_parallel=False,
                     fuse_matmul_bias=False,  # False keeps the diff smaller
                 )
@@ -109,25 +109,25 @@ class ParallelLMHead(nn.Layer):
         if self.use_ep:
             self.weight.set_value(
-                get_tensor(state_dict.pop(self.linear_weight_key)).astype(
+                get_tensor(state_dict.pop(self.weight_key)).astype(
                     paddle.get_default_dtype()))
         else:
             if self.tie_word_embeddings:
-                self.out_linear.weight.set_value(
-                    get_tensor(state_dict.pop(self.linear_weight_key)).astype(
+                self.linear.weight.set_value(
+                    get_tensor(state_dict.pop(self.weight_key)).astype(
                         paddle.get_default_dtype()).transpose([1, 0]))
             else:
                 weight_tensor = get_tensor(
-                    state_dict.pop(self.linear_weight_key)).astype(
+                    state_dict.pop(self.weight_key)).astype(
                         paddle.get_default_dtype())
-                if self.out_linear.weight.shape != weight_tensor.shape:
+                if self.linear.weight.shape != weight_tensor.shape:
                     weight_tensor = weight_tensor.transpose([1, 0])
-                self.out_linear.weight.set_value(weight_tensor)
+                self.linear.weight.set_value(weight_tensor)

-            if self.linear_bias_key is not None:
-                bias = get_tensor(state_dict.pop(self.linear_bias_key)).astype(
+            if self.bias_key is not None:
+                bias = get_tensor(state_dict.pop(self.bias_key)).astype(
                     paddle.get_default_dtype())
-                self.out_linear.bias.set_value(bias)
+                self.linear.bias.set_value(bias)

     def forward(self, input: paddle.Tensor) -> paddle.Tensor:
         """
@@ -143,5 +143,5 @@ class ParallelLMHead(nn.Layer):
         if self.use_ep:
             logits = paddle.matmul(logits, self.weight)
         else:
-            logits = self.out_linear(logits)
+            logits = self.linear(logits)
         return logits
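
For orientation, here is a minimal sketch, not the repository's code, of the weight-loading path after the rename. It mirrors the load_state_dict hunk above but swaps the internal get_tensor helper for paddle.to_tensor and omits the use_ep and tie_word_embeddings branches; the head object is assumed to already expose weight_key, bias_key, and linear as shown in the diff.

import paddle


def load_lm_head_weights(head, state_dict):
    """Sketch only: load the LM head weights using the renamed attributes."""
    # head.weight_key is prefix + ".weight"; head.bias_key is None when with_bias=False.
    weight = paddle.to_tensor(state_dict.pop(head.weight_key)).astype(
        paddle.get_default_dtype())
    # Checkpoints may store the matrix transposed; match the layer's shape.
    if head.linear.weight.shape != weight.shape:
        weight = weight.transpose([1, 0])
    head.linear.weight.set_value(weight)
    if head.bias_key is not None:
        bias = paddle.to_tensor(state_dict.pop(head.bias_key)).astype(
            paddle.get_default_dtype())
        head.linear.bias.set_value(bias)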