Mirror of https://github.com/PaddlePaddle/FastDeploy.git (synced 2025-10-06 00:57:33 +08:00)
fix ep lm head (#3244)
Co-authored-by: yuanxiaolan <yuanxiaolan01@baidu.com>
@@ -118,9 +118,7 @@ class ParallelLMHead(nn.Layer):
         if self.use_ep:
             self.weight.set_value(get_tensor(state_dict.pop(self.weight_key)).astype(paddle.get_default_dtype()))
             if self.bias_key is not None:
-                self.bias.set_value(
-                    get_tensor(state_dict.pop(self.linear_bias_key)).astype(paddle.get_default_dtype())
-                )
+                self.bias.set_value(get_tensor(state_dict.pop(self.bias_key)).astype(paddle.get_default_dtype()))
         else:
             if self.tie_word_embeddings:
                 self.linear.weight.set_value(
@@ -148,7 +146,7 @@ class ParallelLMHead(nn.Layer):
         """
         logits = input
         if self.use_ep:
-            if self.linear_bias_key is None:
+            if self.bias_key is None:
                 logits = paddle.matmul(logits, self.weight)
             else:
                 logits = paddle.incubate.nn.functional.fused_linear(logits, self.weight, self.bias)
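The commit replaces references to self.linear_bias_key with self.bias_key on the expert-parallel (use_ep) path of ParallelLMHead, both when loading the LM-head bias from the state dict and when deciding whether the forward pass applies a bias. Below is a minimal standalone sketch of the corrected branch logic, not the real layer: the shapes, key names, and the plain dict standing in for the checkpoint state dict are placeholders, and matmul-plus-add is used in place of the fused GPU kernel the actual code calls.

import paddle

# Placeholder sizes; the real ParallelLMHead derives these from the model config.
hidden_size, vocab_size, batch = 8, 32, 2
bias_key = "lm_head.bias"  # None when the checkpoint carries no LM-head bias

# Stand-in for the checkpoint state dict the layer loads from.
state_dict = {
    "lm_head.weight": paddle.randn([hidden_size, vocab_size]),
    bias_key: paddle.randn([vocab_size]),
}

# Loading side of the fix: on the use_ep path the bias is popped with bias_key
# and cast to the default dtype, mirroring the patched set_value call.
weight = state_dict.pop("lm_head.weight").astype(paddle.get_default_dtype())
bias = state_dict.pop(bias_key).astype(paddle.get_default_dtype()) if bias_key else None

# Forward side of the fix: the branch now also tests bias_key.
x = paddle.randn([batch, hidden_size])
if bias_key is None:
    # No bias in the checkpoint: plain matmul, as in the patched code.
    logits = paddle.matmul(x, weight)
else:
    # The patched code calls paddle.incubate.nn.functional.fused_linear here
    # (a fused GPU kernel); matmul + add is the equivalent computation.
    logits = paddle.matmul(x, weight) + bias

print(logits.shape)  # [2, 32]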