mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-05 16:48:03 +08:00
fix lm head bias (#3185)
Co-authored-by: yuanxiaolan <yuanxiaolan01@baidu.com>
This commit is contained in:
@@ -72,6 +72,13 @@ class ParallelLMHead(nn.Layer):
|
|||||||
dtype=paddle.get_default_dtype(),
|
dtype=paddle.get_default_dtype(),
|
||||||
is_bias=False,
|
is_bias=False,
|
||||||
)
|
)
|
||||||
|
if self.bias_key is not None:
|
||||||
|
self.bias = self.create_parameter(
|
||||||
|
shape=[num_embeddings],
|
||||||
|
dtype=paddle.get_default_dtype(),
|
||||||
|
is_bias=True,
|
||||||
|
)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
if self.column_cut:
|
if self.column_cut:
|
||||||
need_gather = True
|
need_gather = True
|
||||||
@@ -107,6 +114,10 @@ class ParallelLMHead(nn.Layer):
|
|||||||
|
|
||||||
if self.use_ep:
|
if self.use_ep:
|
||||||
self.weight.set_value(get_tensor(state_dict.pop(self.weight_key)).astype(paddle.get_default_dtype()))
|
self.weight.set_value(get_tensor(state_dict.pop(self.weight_key)).astype(paddle.get_default_dtype()))
|
||||||
|
if self.bias_key is not None:
|
||||||
|
self.bias.set_value(
|
||||||
|
get_tensor(state_dict.pop(self.linear_bias_key)).astype(paddle.get_default_dtype())
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
if self.tie_word_embeddings:
|
if self.tie_word_embeddings:
|
||||||
self.linear.weight.set_value(
|
self.linear.weight.set_value(
|
||||||
@@ -134,7 +145,10 @@ class ParallelLMHead(nn.Layer):
|
|||||||
"""
|
"""
|
||||||
logits = input
|
logits = input
|
||||||
if self.use_ep:
|
if self.use_ep:
|
||||||
|
if self.linear_bias_key is None:
|
||||||
logits = paddle.matmul(logits, self.weight)
|
logits = paddle.matmul(logits, self.weight)
|
||||||
|
else:
|
||||||
|
logits = paddle.incubate.nn.functional.fused_linear(logits, self.weight, self.bias)
|
||||||
else:
|
else:
|
||||||
logits = self.linear(logits)
|
logits = self.linear(logits)
|
||||||
return logits
|
return logits
|
||||||
|
Reference in New Issue
Block a user