[Model] tp+ep support v1_loader (#5600)

* [Model] tp+ep support v1_loader

* fix

* fix mtp_linear

* fix mtp_linear

* fix

* fix

* fix v0 loader

* fix

* Add get_tensor for EP

* fix linear weight_loader

* fix typo

* fix
Author: Longzhi Wang
Date: 2025-12-18 15:27:12 +08:00
Committed by: GitHub
Parent: 5300e73f8b
Commit: a30a5b4216
8 changed files with 48 additions and 20 deletions

View File

@@ -229,6 +229,11 @@ class Attention(nn.Layer):
             self.sinks.set_value(sinks_tensor)
 
     def weight_loader(self, param, loaded_weight, loaded_shard_id: Optional[str] = None):
+        if self.use_qk_norm and ("q_norm" in param.name or "k_norm" in param.name):
+            loaded_weight = get_tensor(loaded_weight).astype("float32")
+            param.copy_(loaded_weight, False)
+            return
+
         loaded_weight = get_tensor(loaded_weight).cast(paddle.get_default_dtype())
         if self.quant_method.cache_quant_config.has_zero_point:  # cache_int4_zp
             loaded_weight = 1.0 / loaded_weight
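
For illustration only, a minimal stand-alone sketch of the dispatch this hunk implements, with NumPy arrays standing in for Paddle tensors; the function name and the has_zero_point flag are hypothetical, not FastDeploy API:

import numpy as np

def load_attention_weight(name: str, loaded: np.ndarray, has_zero_point: bool,
                          default_dtype=np.float16) -> np.ndarray:
    # QK-norm scales stay in float32 and bypass the cache-quant handling.
    if "q_norm" in name or "k_norm" in name:
        return loaded.astype(np.float32)
    # Everything else is cast to the model's default dtype first.
    weight = loaded.astype(default_dtype)
    # cache_int4_zp checkpoints store inverted scales, so take the reciprocal on load.
    if has_zero_point:
        weight = 1.0 / weight
    return weight

print(load_attention_weight("q_norm.weight", np.ones(4), has_zero_point=True).dtype)   # float32
print(load_attention_weight("k_proj.cache_scale", np.full(4, 2.0), has_zero_point=True))  # [0.5 0.5 0.5 0.5]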

View File

@@ -283,10 +283,12 @@ class VocabParallelEmbedding(nn.Layer):
         if output_dim == 0:
             h2d_copy(param[: shard_weight.shape[0]], shard_weight)
             if not current_platform.is_maca():
-                param[shard_weight.shape[0] :].fill_(0)
+                if param.shape[0] != shard_weight.shape[0]:
+                    param[shard_weight.shape[0] :].fill_(0)
         else:
             h2d_copy(param[:, : shard_weight.shape[1]], shard_weight)
-            param[:, shard_weight.shape[1] :].fill_(0)
+            if param.shape[1] != shard_weight.shape[1]:
+                param[:, shard_weight.shape[1] :].fill_(0)
 
     def forward(self, ids_remove_padding=None) -> paddle.Tensor:
         """

View File

@@ -356,25 +356,31 @@ class MergedReplicatedLinear(ReplicatedLinear):
         self.output_sizes = output_sizes
 
     def weight_loader(self, param, loaded_weight, loaded_shard_id: Optional[str] = None):
-        assert loaded_shard_id in ["q_a", "kv_a"]
         if not param._is_initialized():
             param.initialize()
-        if loaded_shard_id == "q_a":
-            param_shard_offset = 0
-            param_shard_size = self.output_sizes[0]
-        else:
-            # loaded_shard_id == "kv_a"
-            param_shard_offset = self.output_sizes[0]
-            param_shard_size = self.output_sizes[1]
-        if hasattr(param, "tensor_track"):
-            param.tensor_track.mark(start=param_shard_offset, end=param_shard_offset + param_shard_size)
-        param = slice_fn(
-            param,
-            (self.fd_config.model_config.model_format == "torch") ^ True,
-            start=param_shard_offset,
-            end=param_shard_offset + param_shard_size,
-        )
+        if loaded_shard_id is None:
+            axis = -1 if (self.fd_config.model_config.model_format == "torch") ^ True else 0
+            if hasattr(param, "tensor_track"):
+                param.tensor_track.mark(start=0, end=loaded_weight.shape[axis])
+        else:
+            assert loaded_shard_id in ["q_a", "kv_a", "gate", "up"]
+            if loaded_shard_id in ["q_a", "gate"]:
+                param_shard_offset = 0
+                param_shard_size = self.output_sizes[0]
+            elif loaded_shard_id in ["kv_a", "up"]:
+                param_shard_offset = self.output_sizes[0]
+                param_shard_size = self.output_sizes[1]
+            if hasattr(param, "tensor_track"):
+                param.tensor_track.mark(start=param_shard_offset, end=param_shard_offset + param_shard_size)
+            param = slice_fn(
+                param,
+                (self.fd_config.model_config.model_format == "torch") ^ True,
+                start=param_shard_offset,
+                end=param_shard_offset + param_shard_size,
+            )
 
         assert param.shape == loaded_weight.shape, (
             f" Attempted to load weight ({loaded_weight.shape}) " f"into parameter ({param.shape})"
         )
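
A hedged sketch of the shard-offset bookkeeping this rewrite generalizes: "q_a"/"gate" occupy the first output_sizes[0] positions of the merged output axis, and "kv_a"/"up" the next output_sizes[1]. NumPy stand-in with an illustrative helper name, not the FastDeploy implementation:

import numpy as np

def merged_shard_slice(param: np.ndarray, output_sizes, shard_id: str, output_axis: int = 0):
    # Return the in-place view of the merged weight that this shard should fill.
    assert shard_id in ("q_a", "gate", "kv_a", "up")
    if shard_id in ("q_a", "gate"):
        offset, size = 0, output_sizes[0]
    else:
        offset, size = output_sizes[0], output_sizes[1]
    index = [slice(None)] * param.ndim
    index[output_axis] = slice(offset, offset + size)
    return param[tuple(index)]

merged = np.zeros((6, 3), dtype=np.float32)             # output_sizes = [4, 2]
merged_shard_slice(merged, [4, 2], "kv_a")[:] = 1.0      # writes rows 4:6 in place
assert merged[:4].sum() == 0 and merged[4:].sum() == 6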

View File

@@ -102,6 +102,10 @@ class ParallelLMHead(nn.Layer):
                 },
             )
             set_weight_attrs(self.linear.weight, {"output_dim": True})
+            if self.tp_size > 1:
+                if with_bias:
+                    set_weight_attrs(self.linear.bias, {"output_dim": True})
         else:
             self.linear = RowParallelLinear(
                 embedding_dim,

View File

@@ -274,10 +274,13 @@ class FusedMoE(nn.Layer):
         if not param._is_initialized():
             param.initialize()
         weight_need_transpose = getattr(param, "weight_need_transpose", False)
+        if self.ep_size > 1 or weight_need_transpose:
+            loaded_weight = get_tensor(loaded_weight)
         if shard_id is None:
             # 1.gate up fused in disk
             if weight_need_transpose:
-                loaded_weight = get_tensor(loaded_weight)
                 loaded_weight = loaded_weight.transpose([1, 0])
             output_size = param[expert_id - self.expert_id_offset].shape[SHARD_ID_TO_SHARDED_DIM["gate"]]
             shard_offsets = [
@@ -293,7 +296,6 @@
                 self.weight_loader(param, loaded_weight_shard, expert_id, shard_id, "fused")
         else:
             if weight_need_transpose and source != "fused":
-                loaded_weight = get_tensor(loaded_weight)
                 loaded_weight = loaded_weight.transpose([1, 0])
             # 2.gate up splited in disk
             assert shard_id in ["gate", "down", "up"]
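
Hoisting get_tensor reflects that under expert parallelism the checkpoint slice may still be a lazy handle, so it is materialized once before either branch instead of inside each transpose path. A hedged plain-Python sketch of that pattern; the LazyWeight class and get_tensor shown here are illustrative stand-ins, not the FastDeploy helpers:

import numpy as np

class LazyWeight:
    """Hypothetical stand-in for a checkpoint slice that is not materialized yet."""
    def __init__(self, value: np.ndarray):
        self._value = value
    def materialize(self) -> np.ndarray:
        return self._value

def get_tensor(maybe_lazy):
    # Materialize lazy handles; pass real arrays through unchanged.
    return maybe_lazy.materialize() if isinstance(maybe_lazy, LazyWeight) else maybe_lazy

def load_expert_weight(loaded, ep_size: int, need_transpose: bool) -> np.ndarray:
    # Materialize once, up front, whenever EP or a later transpose needs real data,
    # rather than calling get_tensor separately inside each branch.
    if ep_size > 1 or need_transpose:
        loaded = get_tensor(loaded)
    if need_transpose:
        loaded = loaded.T
    return loaded

w = load_expert_weight(LazyWeight(np.arange(6).reshape(2, 3)), ep_size=2, need_transpose=True)
print(w.shape)  # (3, 2)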

View File

@@ -86,6 +86,9 @@ class ParallelEHProjection(nn.Layer):
             )
             if self.tp_size > 1:
                 set_weight_attrs(self.linear.weight, {"output_dim": True})
+                if self.bias_key is not None:
+                    set_weight_attrs(self.linear.bias, {"output_dim": True})
         else:
             self.linear = RowParallelLinear(
                 embedding_dim,

View File

@@ -130,6 +130,10 @@ class RMSNorm(nn.Layer):
             dtype=self._norm_weight_dtype,
         )
 
+    def weight_loader(self, param, loaded_weight, loaded_shard_id: Optional[str] = None):
+        loaded_weight = get_tensor(loaded_weight).astype(self._norm_weight_dtype)
+        param.copy_(loaded_weight, False)
+
     def load_state_dict(self, state_dict: Dict[str, paddle.Tensor | np.ndarray]):
         """
         Load the checkpoint state dictionary into the layer.
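
As this hunk shows, the v1 loader calls a per-parameter weight_loader hook; for norm weights that hook only casts and copies. A minimal NumPy sketch of that contract, assuming nothing beyond what the hunk shows and not the Paddle implementation:

import numpy as np

class RMSNormSketch:
    def __init__(self, hidden: int, norm_weight_dtype=np.float32):
        self._norm_weight_dtype = norm_weight_dtype
        self.weight = np.empty(hidden, dtype=norm_weight_dtype)

    def weight_loader(self, param: np.ndarray, loaded_weight: np.ndarray, loaded_shard_id=None) -> None:
        # Norm weights are never sharded, so just cast and copy in place.
        param[...] = loaded_weight.astype(self._norm_weight_dtype)

layer = RMSNormSketch(4)
layer.weight_loader(layer.weight, np.ones(4, dtype=np.float16))
print(layer.weight.dtype, layer.weight)  # float32 [1. 1. 1. 1.]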

View File

@@ -138,7 +138,9 @@ class BlockWiseFP8LinearMethod(QuantMethodBase):
         weight_shape = layer.weight_shape
         weight_scale_inv_shape = weight_scale_inv_shape
         extra_weight_attrs["output_dim"] = (
-            not extra_weight_attrs["output_dim"] if extra_weight_attrs["output_dim"] is not None else None
+            not extra_weight_attrs["output_dim"]
+            if extra_weight_attrs.get("output_dim", None) is not None
+            else None
         )
         layer.weight_dtype = "float8_e4m3fn"
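
The switch to .get("output_dim", None) guards against callers that never set the attribute at all, where direct indexing would raise KeyError. A quick self-contained illustration of the before/after behavior (the helper name is illustrative):

def flip_output_dim(extra_weight_attrs: dict):
    # Old form: extra_weight_attrs["output_dim"] raises KeyError when the key is absent.
    # New form: a missing key is treated the same as an explicit None.
    current = extra_weight_attrs.get("output_dim", None)
    extra_weight_attrs["output_dim"] = (not current) if current is not None else None
    return extra_weight_attrs["output_dim"]

print(flip_output_dim({"output_dim": True}))   # False
print(flip_output_dim({"output_dim": None}))   # None
print(flip_output_dim({}))                     # None (no KeyError)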