mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-05 16:48:03 +08:00
[Feature] ernie4_5_vl_moe
support huggingface safetensor loading (#3750)
* update * update * update in tp * add todo * update --------- Co-authored-by: aquagull <hongyuh@qq.com>
This commit is contained in:
@@ -100,13 +100,11 @@ class VocabParallelEmbedding(nn.Layer):
|
||||
state_dict (dict): A dictionary containing the checkpoint weights and biases.
|
||||
"""
|
||||
if self.tie_word_embeddings:
|
||||
self.embeddings.weight.set_value(
|
||||
get_tensor(state_dict[self.prefix + ".weight"]).astype(paddle.get_default_dtype())
|
||||
)
|
||||
weight_tensor = get_tensor(state_dict[self.prefix + ".weight"]).astype(paddle.get_default_dtype())
|
||||
else:
|
||||
self.embeddings.weight.set_value(
|
||||
get_tensor(state_dict.pop(self.prefix + ".weight")).astype(paddle.get_default_dtype())
|
||||
)
|
||||
weight_tensor = get_tensor(state_dict.pop(self.prefix + ".weight")).astype(paddle.get_default_dtype())
|
||||
|
||||
self.embeddings.weight.set_value(weight_tensor)
|
||||
|
||||
def forward(self, ids_remove_padding=None) -> paddle.Tensor:
|
||||
"""
|
||||
|
Reference in New Issue
Block a user