fix lm head bias (#3185)

Co-authored-by: yuanxiaolan <yuanxiaolan01@baidu.com>
Author: RichardWooSJTU
Date: 2025-08-05 15:40:24 +08:00
Committed by: GitHub
Parent: f5c64a074c
Commit: 1e9a8e8cef

@@ -72,6 +72,13 @@ class ParallelLMHead(nn.Layer):
                 dtype=paddle.get_default_dtype(),
                 is_bias=False,
             )
+            if self.bias_key is not None:
+                self.bias = self.create_parameter(
+                    shape=[num_embeddings],
+                    dtype=paddle.get_default_dtype(),
+                    is_bias=True,
+                )
         else:
             if self.column_cut:
                 need_gather = True
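
For reference, a minimal standalone sketch of the pattern this first hunk adds: an optional bias registered on a paddle.nn.Layer via create_parameter(is_bias=True). The TinyHead class and its sizes are invented for illustration and are not part of the patch; only the parameter-creation pattern mirrors it.

# Minimal sketch, not the FastDeploy class: register a weight and an optional
# bias on a paddle.nn.Layer, mirroring the pattern added in the hunk above.
import paddle
from paddle import nn


class TinyHead(nn.Layer):
    def __init__(self, hidden_size: int, vocab_size: int, with_bias: bool = True):
        super().__init__()
        # Weight shaped [hidden_size, vocab_size] so logits = matmul(x, weight).
        self.weight = self.create_parameter(
            shape=[hidden_size, vocab_size],
            dtype=paddle.get_default_dtype(),
            is_bias=False,
        )
        # Bias is only created when requested, as the patched __init__ does
        # when bias_key is set.
        self.bias = None
        if with_bias:
            self.bias = self.create_parameter(
                shape=[vocab_size],
                dtype=paddle.get_default_dtype(),
                is_bias=True,
            )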
@@ -107,6 +114,10 @@ class ParallelLMHead(nn.Layer):
         if self.use_ep:
             self.weight.set_value(get_tensor(state_dict.pop(self.weight_key)).astype(paddle.get_default_dtype()))
+            if self.bias_key is not None:
+                self.bias.set_value(
+                    get_tensor(state_dict.pop(self.linear_bias_key)).astype(paddle.get_default_dtype())
+                )
         else:
             if self.tie_word_embeddings:
                 self.linear.weight.set_value(
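
A hedged sketch of the loading step in this hunk, with FastDeploy's get_tensor helper swapped for paddle.to_tensor and a fake in-memory checkpoint; the key name "lm_head.bias" and the sizes are hypothetical.

# Illustrative only: copy a checkpoint bias into a freshly created parameter
# with set_value, as the patched load_state_dict does for the use_ep branch.
import numpy as np
import paddle

vocab_size = 1024
bias = paddle.create_parameter(shape=[vocab_size], dtype=paddle.get_default_dtype(), is_bias=True)
state_dict = {"lm_head.bias": np.ones([vocab_size], dtype="float32")}  # stand-in checkpoint

bias.set_value(paddle.to_tensor(state_dict.pop("lm_head.bias")).astype(paddle.get_default_dtype()))
print(bias.numpy()[:4])  # -> [1. 1. 1. 1.]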
@@ -134,7 +145,10 @@ class ParallelLMHead(nn.Layer):
         """
         logits = input
         if self.use_ep:
-            logits = paddle.matmul(logits, self.weight)
+            if self.linear_bias_key is None:
+                logits = paddle.matmul(logits, self.weight)
+            else:
+                logits = paddle.incubate.nn.functional.fused_linear(logits, self.weight, self.bias)
         else:
             logits = self.linear(logits)
         return logits
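
A quick sanity check of the forward-path change, assuming a CUDA build of PaddlePaddle (fused_linear is backed by the fused GEMM-epilogue kernel, so the bias add happens in the same kernel as the matmul): fused_linear(x, weight, bias) should agree with matmul(x, weight) + bias. Shapes and dtype below are arbitrary example values, not taken from the patch.

# Equivalence check: fused matmul+bias vs. the unfused reference.
import paddle
import paddle.incubate.nn.functional as incubate_F

x = paddle.randn([2, 8, 256]).astype("float16")    # [batch, seq, hidden]
w = paddle.randn([256, 1024]).astype("float16")    # [hidden, vocab]
b = paddle.randn([1024]).astype("float16")         # [vocab]

ref = paddle.matmul(x, w) + b
fused = incubate_F.fused_linear(x, w, b)
print(paddle.allclose(ref, fused, rtol=1e-2, atol=1e-2).item())  # expect True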