support tmp (#3675)

YuanRisheng
2025-08-28 19:42:32 +08:00
committed by GitHub
parent 368bbd9dc6
commit 808b548761
2 changed files with 43 additions and 0 deletions


@@ -18,6 +18,8 @@ import paddle
from paddle import nn
from paddle.distributed import fleet
from fastdeploy.model_executor.utils import set_weight_attrs
from .utils import get_tensor
@@ -75,6 +77,9 @@ class ParallelEHProjection(nn.Layer):
                gather_output=need_gather,
                fuse_matmul_bias=False,  # False keeps the numerical diff smaller
            )
            set_weight_attrs(self.linear.weight, {"output_dim": True})
            if self.bias_key is not None:
                set_weight_attrs(self.linear.bias, {"output_dim": True})
        else:
            self.linear = RowParallelLinear(
                embedding_dim,
@@ -85,6 +90,7 @@ class ParallelEHProjection(nn.Layer):
                input_is_parallel=False,
                fuse_matmul_bias=False,  # False keeps the numerical diff smaller
            )
            set_weight_attrs(self.linear.weight, {"output_dim": False})
    def load_state_dict(self, state_dict):
        """


@@ -365,6 +365,43 @@ class Ernie4_5_MTPForCausalLM(ModelForCasualLM):
        # else:
        #     self.lm_head.load_state_dict(state_dict)
    @paddle.no_grad()
    def load_weights(self, weights_iterator) -> None:
        """
        Load model parameters from a given weights_iterator object.

        Args:
            weights_iterator (Iterator): An iterator yielding (name, weight) pairs.
        """
        from fastdeploy.model_executor.utils import default_weight_loader

        all_param_mapping = [
            # (param_name, weight_name, expert_id, shard_id)
            ("embed_tokens.embeddings", "embed_tokens", None, None),
            ("lm_head.linear", "lm_head", None, None),
        ]
        params_dict = dict(self.named_parameters())
        shard_id = None
        for loaded_weight_name, loaded_weight in weights_iterator:
            for param_name, weight_name, exp_id, shard_id in all_param_mapping:
                if weight_name not in loaded_weight_name:
                    continue
                model_param_name = loaded_weight_name.replace(weight_name, param_name)
                param = params_dict[model_param_name]
                break
            else:
                if loaded_weight_name not in params_dict.keys():
                    continue
                param = params_dict[loaded_weight_name]
            # Get weight loader from parameter and set weight
            weight_loader = getattr(param, "weight_loader", default_weight_loader(self.fd_config))
            weight_loader(param, loaded_weight)
    def compute_logits(self, hidden_states: paddle.Tensor):
        """
        compute logits
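
load_weights only requires an iterator that yields (name, weight) pairs: names containing embed_tokens or lm_head are remapped through all_param_mapping, every other name must match an entry in named_parameters() exactly or the weight is skipped, and each parameter's own weight_loader (falling back to default_weight_loader) performs the copy. A minimal, hypothetical usage sketch, assuming an already constructed model and an in-memory checkpoint_state_dict (real deployments may stream tensors from safetensors files instead):

def make_weights_iterator(checkpoint_state_dict):
    # Yield (parameter_name, weight_tensor) pairs from an in-memory state dict.
    for name, weight in checkpoint_state_dict.items():
        yield name, weight

model.load_weights(make_weights_iterator(checkpoint_state_dict))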