mirror of https://github.com/PaddlePaddle/FastDeploy.git (synced 2025-10-05 16:48:03 +08:00)
support tmp (#3675)
@@ -18,6 +18,8 @@ import paddle
 from paddle import nn
 from paddle.distributed import fleet
 
+from fastdeploy.model_executor.utils import set_weight_attrs
+
 from .utils import get_tensor
 
 
@@ -75,6 +77,9 @@ class ParallelEHProjection(nn.Layer):
                 gather_output=need_gather,
                 fuse_matmul_bias=False,  # False gives a smaller numerical diff
             )
+            set_weight_attrs(self.linear.weight, {"output_dim": True})
+            if self.bias_key is not None:
+                set_weight_attrs(self.linear.bias, {"output_dim": True})
         else:
             self.linear = RowParallelLinear(
                 embedding_dim,
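Note: the set_weight_attrs calls added above tag the projection parameters with an "output_dim" flag (True for the ColumnParallelLinear branch, False for the RowParallelLinear branch below). The sketch that follows is a hypothetical, self-contained illustration of how a weight loader could consume such a flag to pick the shard axis under tensor parallelism; FakeParam, load_sharded, tp_rank and tp_size are made-up names, not FastDeploy APIs, and the exact semantics of "output_dim" are an assumption.

    # Hypothetical illustration, assuming "output_dim" marks which axis of the
    # full checkpoint tensor is split across tensor-parallel ranks.
    import numpy as np

    class FakeParam:
        def __init__(self, shape):
            self.data = np.zeros(shape, dtype=np.float32)

    def set_weight_attrs(param, attrs):
        # Attach metadata such as {"output_dim": True} directly onto the parameter object.
        for key, value in attrs.items():
            setattr(param, key, value)

    def load_sharded(param, full_weight, tp_rank, tp_size):
        # output_dim=True -> slice the last (output) axis; False -> slice the first (input) axis.
        axis = -1 if getattr(param, "output_dim", False) else 0
        shards = np.split(full_weight, tp_size, axis=axis)
        param.data[...] = shards[tp_rank]

    local = FakeParam((4, 2))  # this rank holds half of a (4, 4) weight
    set_weight_attrs(local, {"output_dim": True})
    load_sharded(local, np.arange(16.0).reshape(4, 4), tp_rank=0, tp_size=2)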
@@ -85,6 +90,7 @@ class ParallelEHProjection(nn.Layer):
                 input_is_parallel=False,
                 fuse_matmul_bias=False,  # False gives a smaller numerical diff
             )
+            set_weight_attrs(self.linear.weight, {"output_dim": False})
 
     def load_state_dict(self, state_dict):
         """
@@ -365,6 +365,43 @@ class Ernie4_5_MTPForCausalLM(ModelForCasualLM):
         # else:
         #     self.lm_head.load_state_dict(state_dict)
 
+    @paddle.no_grad()
+    def load_weights(self, weights_iterator) -> None:
+        """
+        Load model parameters from a given weights_iterator object.
+
+        Args:
+            weights_iterator (Iterator): An iterator yielding (name, weight) pairs.
+        """
+
+        from fastdeploy.model_executor.utils import default_weight_loader
+
+        all_param_mapping = [
+            # (param_name, weight_name, expert_id, shard_id)
+            ("embed_tokens.embeddings", "embed_tokens", None, None),
+            ("lm_head.linear", "lm_head", None, None),
+        ]
+
+        params_dict = dict(self.named_parameters())
+        shard_id = None
+
+        for loaded_weight_name, loaded_weight in weights_iterator:
+            for param_name, weight_name, exp_id, shard_id in all_param_mapping:
+                if weight_name not in loaded_weight_name:
+                    continue
+                model_param_name = loaded_weight_name.replace(weight_name, param_name)
+                param = params_dict[model_param_name]
+                shard_id = shard_id
+                break
+            else:
+                if loaded_weight_name not in params_dict.keys():
+                    continue
+                param = params_dict[loaded_weight_name]
+
+            # Get weight loader from parameter and set weight
+            weight_loader = getattr(param, "weight_loader", default_weight_loader(self.fd_config))
+            weight_loader(param, loaded_weight)
+
     def compute_logits(self, hidden_states: paddle.Tensor):
         """
         compute logits
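Note: the new load_weights method relies on Python's for/else. The inner loop tries to rename each checkpoint tensor onto a model parameter via all_param_mapping; the else branch runs only when that loop finishes without break, i.e. no renaming rule matched, in which case the checkpoint name is looked up directly and skipped if unknown. The matched parameter is then filled through its weight_loader attribute, falling back to default_weight_loader. Below is a toy, self-contained sketch of the same control flow; the dictionaries and names are invented for illustration and are not FastDeploy APIs.

    # Toy sketch of the for/else renaming flow used by load_weights.
    name_mapping = [("lm_head.linear", "lm_head"), ("embed_tokens.embeddings", "embed_tokens")]
    params = {"lm_head.linear.weight": None, "model.layers.0.mlp.weight": None}
    checkpoint = {"lm_head.weight": [[1.0]], "model.layers.0.mlp.weight": [[2.0]], "unused.bias": [0.0]}

    for ckpt_name, tensor in checkpoint.items():
        for param_name, weight_name in name_mapping:
            if weight_name in ckpt_name:
                target = ckpt_name.replace(weight_name, param_name)  # rename to the model's key
                break
        else:
            if ckpt_name not in params:   # unknown checkpoint entries are skipped
                continue
            target = ckpt_name            # name already matches the model's parameter
        params[target] = tensor           # stands in for weight_loader(param, loaded_weight)

    print(params)  # {'lm_head.linear.weight': [[1.0]], 'model.layers.0.mlp.weight': [[2.0]]}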