Mirror of https://github.com/PaddlePaddle/FastDeploy.git, synced 2025-12-24 13:28:13 +08:00
[Model] tp+ep support v1_loader (#5600)
Some checks failed
CE Compile Job / ce_job_pre_check (push) Has been cancelled
CE Compile Job / print_ce_job_pre_check_outputs (push) Has been cancelled
CE Compile Job / FD-Clone-Linux (push) Has been cancelled
CE Compile Job / Show Code Archive Output (push) Has been cancelled
CE Compile Job / BUILD_SM8090 (push) Has been cancelled
CE Compile Job / BUILD_SM8689 (push) Has been cancelled
CE Compile Job / CE_UPLOAD (push) Has been cancelled
* [Model] tp+ep support v1_loader
* fix
* fix mtp_linear
* fix mtp_linear
* fix
* fix
* fix v0 loader
* fix
* Add get_tensor for EP
* fix linear weight_loader
* fix typo
* fix
@@ -229,6 +229,11 @@ class Attention(nn.Layer):
            self.sinks.set_value(sinks_tensor)

    def weight_loader(self, param, loaded_weight, loaded_shard_id: Optional[str] = None):
        if self.use_qk_norm and ("q_norm" in param.name or "k_norm" in param.name):
            loaded_weight = get_tensor(loaded_weight).astype("float32")
            param.copy_(loaded_weight, False)
            return

        loaded_weight = get_tensor(loaded_weight).cast(paddle.get_default_dtype())
        if self.quant_method.cache_quant_config.has_zero_point:  # cache_int4_zp
            loaded_weight = 1.0 / loaded_weight
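A minimal sketch of the two dtype paths in this weight_loader, assuming get_tensor behaves like paddle.to_tensor; the function name and its arguments below are hypothetical stand-ins, not FastDeploy API.

```python
import numpy as np
import paddle

def load_attention_weight_sketch(param_name: str, loaded_weight: np.ndarray, has_zero_point: bool):
    # QK-norm gammas are kept in float32 regardless of the model's default dtype.
    if "q_norm" in param_name or "k_norm" in param_name:
        return paddle.to_tensor(loaded_weight).astype("float32")
    # Everything else is cast to the default dtype; with a zero-point cache quant
    # config (cache_int4_zp) the loaded scale is inverted, mirroring 1.0 / loaded_weight above.
    weight = paddle.to_tensor(loaded_weight).cast(paddle.get_default_dtype())
    if has_zero_point:
        weight = 1.0 / weight
    return weight

print(load_attention_weight_sketch("q_norm.weight", np.ones([8], "float16"), False).dtype)  # paddle.float32
```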
@@ -283,10 +283,12 @@ class VocabParallelEmbedding(nn.Layer):
        if output_dim == 0:
            h2d_copy(param[: shard_weight.shape[0]], shard_weight)
            if not current_platform.is_maca():
                param[shard_weight.shape[0] :].fill_(0)
            if param.shape[0] != shard_weight.shape[0]:
                param[shard_weight.shape[0] :].fill_(0)
        else:
            h2d_copy(param[:, : shard_weight.shape[1]], shard_weight)
            param[:, shard_weight.shape[1] :].fill_(0)
            if param.shape[1] != shard_weight.shape[1]:
                param[:, shard_weight.shape[1] :].fill_(0)

    def forward(self, ids_remove_padding=None) -> paddle.Tensor:
        """
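A small sketch of the padded-vocab fill with made-up shapes, using a plain slice assignment in place of h2d_copy and fill_: the rows past the checkpoint shard are zeroed only when the parameter really is larger than the shard.

```python
import paddle

vocab_padded, hidden = 100, 16            # parameter allocated with a padded vocab size
shard_weight = paddle.ones([96, hidden])  # shard actually present in the checkpoint

param = paddle.empty([vocab_padded, hidden])
param[: shard_weight.shape[0]] = shard_weight
if param.shape[0] != shard_weight.shape[0]:   # skip the fill when there is no padding
    param[shard_weight.shape[0]:] = 0.0

print(float(param[95].sum()), float(param[96].sum()))  # 16.0 0.0
```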
@@ -356,25 +356,31 @@ class MergedReplicatedLinear(ReplicatedLinear):
        self.output_sizes = output_sizes

    def weight_loader(self, param, loaded_weight, loaded_shard_id: Optional[str] = None):
        assert loaded_shard_id in ["q_a", "kv_a"]
        if not param._is_initialized():
            param.initialize()
        if loaded_shard_id is None:
            axis = -1 if (self.fd_config.model_config.model_format == "torch") ^ True else 0
            if hasattr(param, "tensor_track"):
                param.tensor_track.mark(start=0, end=loaded_weight.shape[axis])

        if loaded_shard_id == "q_a":
            param_shard_offset = 0
            param_shard_size = self.output_sizes[0]
        else:
            # loaded_shard_id == "kv_a"
            param_shard_offset = self.output_sizes[0]
            param_shard_size = self.output_sizes[1]
        if hasattr(param, "tensor_track"):
            param.tensor_track.mark(start=param_shard_offset, end=param_shard_offset + param_shard_size)
        param = slice_fn(
            param,
            (self.fd_config.model_config.model_format == "torch") ^ True,
            start=param_shard_offset,
            end=param_shard_offset + param_shard_size,
        )
        assert loaded_shard_id in ["q_a", "kv_a", "gate", "up"]

        if loaded_shard_id in ["q_a", "gate"]:
            param_shard_offset = 0
            param_shard_size = self.output_sizes[0]
        elif loaded_shard_id in ["kv_a", "up"]:
            param_shard_offset = self.output_sizes[0]
            param_shard_size = self.output_sizes[1]

        if hasattr(param, "tensor_track"):
            param.tensor_track.mark(start=param_shard_offset, end=param_shard_offset + param_shard_size)
        param = slice_fn(
            param,
            (self.fd_config.model_config.model_format == "torch") ^ True,
            start=param_shard_offset,
            end=param_shard_offset + param_shard_size,
        )
        assert param.shape == loaded_weight.shape, (
            f" Attempted to load weight ({loaded_weight.shape}) " f"into parameter ({param.shape})"
        )
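A standalone sketch of the shard bookkeeping above, with assumed widths: output_sizes lists the sub-projection widths merged into one parameter, and each shard_id maps to an (offset, size) slice along the output axis. The expression (model_format == "torch") ^ True passed to slice_fn evaluates to False (axis 0) for torch-format checkpoints and True (axis 1) otherwise.

```python
# Hypothetical standalone version of the offset logic; 1536 and 576 are example widths.
output_sizes = [1536, 576]  # e.g. q_a and kv_a projection widths

def shard_range(shard_id: str):
    if shard_id in ("q_a", "gate"):
        return 0, output_sizes[0]
    if shard_id in ("kv_a", "up"):
        return output_sizes[0], output_sizes[1]
    raise ValueError(f"unknown shard_id: {shard_id}")

for sid in ("q_a", "kv_a"):
    offset, size = shard_range(sid)
    print(sid, "-> slice", (offset, offset + size))
# q_a -> slice (0, 1536)
# kv_a -> slice (1536, 2112)
```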
@@ -102,6 +102,10 @@ class ParallelLMHead(nn.Layer):
                },
            )
            set_weight_attrs(self.linear.weight, {"output_dim": True})
            if self.tp_size > 1:
                if with_bias:
                    set_weight_attrs(self.linear.bias, {"output_dim": True})

        else:
            self.linear = RowParallelLinear(
                embedding_dim,
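For context, a hedged sketch of what a set_weight_attrs-style helper does: it attaches loader metadata such as output_dim to an existing parameter so that later weight_loader calls can consult it. This is an illustration under that assumption, not the FastDeploy helper itself.

```python
import paddle

def set_weight_attrs_sketch(param, attrs: dict) -> None:
    # Attach each metadata entry to the parameter object itself.
    for key, value in attrs.items():
        setattr(param, key, value)

weight = paddle.create_parameter(shape=[8, 4], dtype="float32")
set_weight_attrs_sketch(weight, {"output_dim": True})
print(getattr(weight, "output_dim", None))  # True
```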
@@ -274,10 +274,13 @@ class FusedMoE(nn.Layer):
        if not param._is_initialized():
            param.initialize()
        weight_need_transpose = getattr(param, "weight_need_transpose", False)

        if self.ep_size > 1 or weight_need_transpose:
            loaded_weight = get_tensor(loaded_weight)

        if shard_id is None:
            # 1.gate up fused in disk
            if weight_need_transpose:
                loaded_weight = get_tensor(loaded_weight)
                loaded_weight = loaded_weight.transpose([1, 0])
            output_size = param[expert_id - self.expert_id_offset].shape[SHARD_ID_TO_SHARDED_DIM["gate"]]
            shard_offsets = [
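A sketch of the "gate up fused in disk" path with assumed shapes: a torch-format expert weight arrives as [output, input], is transposed once, and is then split into equal gate and up shards along the output dimension before being re-dispatched per shard. The shard_offsets layout here is illustrative.

```python
import paddle

hidden, moe_intermediate = 64, 128
fused = paddle.rand([2 * moe_intermediate, hidden])   # torch-style fused gate_up weight

weight_need_transpose = True                          # would come from the param attribute
if weight_need_transpose:
    fused = fused.transpose([1, 0])                   # -> [hidden, 2 * moe_intermediate]

output_size = moe_intermediate                        # width of one shard along the output dim
shard_offsets = [("gate", 0, output_size), ("up", output_size, output_size)]
for shard_id, offset, size in shard_offsets:
    shard = fused[:, offset : offset + size]
    print(shard_id, shard.shape)                      # gate [64, 128], up [64, 128]
```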
@@ -293,7 +296,6 @@ class FusedMoE(nn.Layer):
                self.weight_loader(param, loaded_weight_shard, expert_id, shard_id, "fused")
        else:
            if weight_need_transpose and source != "fused":
                loaded_weight = get_tensor(loaded_weight)
                loaded_weight = loaded_weight.transpose([1, 0])
            # 2.gate up splited in disk
            assert shard_id in ["gate", "down", "up"]
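The source != "fused" condition reads as a guard against transposing twice: the fused branch above already transposed the on-disk tensor before re-calling weight_loader with source="fused". A tiny standalone sketch of that guard, with a hypothetical helper name:

```python
import paddle

def maybe_transpose(weight, weight_need_transpose: bool, source: str = ""):
    # Transpose torch-format weights, unless the fused path already did it.
    if weight_need_transpose and source != "fused":
        weight = weight.transpose([1, 0])
    return weight

w = paddle.rand([8, 4])
print(maybe_transpose(w, True, source="fused").shape)  # [8, 4], left as-is
print(maybe_transpose(w, True).shape)                  # [4, 8]
```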
@@ -86,6 +86,9 @@ class ParallelEHProjection(nn.Layer):
            )
            if self.tp_size > 1:
                set_weight_attrs(self.linear.weight, {"output_dim": True})
                if self.bias_key is not None:
                    set_weight_attrs(self.linear.bias, {"output_dim": True})

        else:
            self.linear = RowParallelLinear(
                embedding_dim,
@@ -130,6 +130,10 @@ class RMSNorm(nn.Layer):
            dtype=self._norm_weight_dtype,
        )

    def weight_loader(self, param, loaded_weight, loaded_shard_id: Optional[str] = None):
        loaded_weight = get_tensor(loaded_weight).astype(self._norm_weight_dtype)
        param.copy_(loaded_weight, False)

    def load_state_dict(self, state_dict: Dict[str, paddle.Tensor | np.ndarray]):
        """
        Load the checkpoint state dictionary into the layer.
@@ -138,7 +138,9 @@ class BlockWiseFP8LinearMethod(QuantMethodBase):
        weight_shape = layer.weight_shape
        weight_scale_inv_shape = weight_scale_inv_shape
        extra_weight_attrs["output_dim"] = (
            not extra_weight_attrs["output_dim"] if extra_weight_attrs["output_dim"] is not None else None
            not extra_weight_attrs["output_dim"]
            if extra_weight_attrs.get("output_dim", None) is not None
            else None
        )

        layer.weight_dtype = "float8_e4m3fn"
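The change swaps direct indexing for dict.get, so a missing output_dim no longer raises KeyError; an existing True/False value is still inverted (the block-wise FP8 weight is presumably stored with the opposite output dimension) and None stays None. A minimal sketch of that tri-state flip:

```python
def flip_output_dim(extra_weight_attrs: dict):
    return (
        not extra_weight_attrs["output_dim"]
        if extra_weight_attrs.get("output_dim", None) is not None
        else None
    )

print(flip_output_dim({"output_dim": True}))   # False
print(flip_output_dim({"output_dim": None}))   # None
print(flip_output_dim({}))                     # None; direct indexing would raise KeyError here
```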