support tmp (#3675)

This commit is contained in:
YuanRisheng
2025-08-28 19:42:32 +08:00
committed by GitHub
parent 368bbd9dc6
commit 808b548761
2 changed files with 43 additions and 0 deletions

View File

@@ -18,6 +18,8 @@ import paddle
from paddle import nn
from paddle.distributed import fleet
from fastdeploy.model_executor.utils import set_weight_attrs
from .utils import get_tensor
@@ -75,6 +77,9 @@ class ParallelEHProjection(nn.Layer):
gather_output=need_gather,
fuse_matmul_bias=False, # False diff更小
)
set_weight_attrs(self.linear.weight, {"output_dim": True})
if self.bias_key is not None:
set_weight_attrs(self.linear.bias, {"output_dim": True})
else:
self.linear = RowParallelLinear(
embedding_dim,
@@ -85,6 +90,7 @@ class ParallelEHProjection(nn.Layer):
input_is_parallel=False,
fuse_matmul_bias=False, # False diff更小
)
set_weight_attrs(self.linear.weight, {"output_dim": False})
def load_state_dict(self, state_dict):
"""