diff --git a/fastdeploy/model_executor/layers/attention/attention.py b/fastdeploy/model_executor/layers/attention/attention.py
index 79804aa2d..a5ac1876e 100644
--- a/fastdeploy/model_executor/layers/attention/attention.py
+++ b/fastdeploy/model_executor/layers/attention/attention.py
@@ -229,6 +229,11 @@ class Attention(nn.Layer):
         self.sinks.set_value(sinks_tensor)
 
     def weight_loader(self, param, loaded_weight, loaded_shard_id: Optional[str] = None):
+        if self.use_qk_norm and ("q_norm" in param.name or "k_norm" in param.name):
+            loaded_weight = get_tensor(loaded_weight).astype("float32")
+            param.copy_(loaded_weight, False)
+            return
+
         loaded_weight = get_tensor(loaded_weight).cast(paddle.get_default_dtype())
         if self.quant_method.cache_quant_config.has_zero_point:  # cache_int4_zp
             loaded_weight = 1.0 / loaded_weight
diff --git a/fastdeploy/model_executor/layers/embeddings.py b/fastdeploy/model_executor/layers/embeddings.py
index 52d7dadee..5ae82efe4 100644
--- a/fastdeploy/model_executor/layers/embeddings.py
+++ b/fastdeploy/model_executor/layers/embeddings.py
@@ -283,10 +283,12 @@ class VocabParallelEmbedding(nn.Layer):
         if output_dim == 0:
             h2d_copy(param[: shard_weight.shape[0]], shard_weight)
             if not current_platform.is_maca():
-                param[shard_weight.shape[0] :].fill_(0)
+                if param.shape[0] != shard_weight.shape[0]:
+                    param[shard_weight.shape[0] :].fill_(0)
         else:
             h2d_copy(param[:, : shard_weight.shape[1]], shard_weight)
-            param[:, shard_weight.shape[1] :].fill_(0)
+            if param.shape[1] != shard_weight.shape[1]:
+                param[:, shard_weight.shape[1] :].fill_(0)
 
     def forward(self, ids_remove_padding=None) -> paddle.Tensor:
         """
diff --git a/fastdeploy/model_executor/layers/linear.py b/fastdeploy/model_executor/layers/linear.py
index 14d1e0dcc..49b25dc3d 100644
--- a/fastdeploy/model_executor/layers/linear.py
+++ b/fastdeploy/model_executor/layers/linear.py
@@ -356,25 +356,31 @@ class MergedReplicatedLinear(ReplicatedLinear):
         self.output_sizes = output_sizes
 
     def weight_loader(self, param, loaded_weight, loaded_shard_id: Optional[str] = None):
-        assert loaded_shard_id in ["q_a", "kv_a"]
         if not param._is_initialized():
             param.initialize()
 
+        if loaded_shard_id is None:
+            axis = -1 if (self.fd_config.model_config.model_format == "torch") ^ True else 0
+            if hasattr(param, "tensor_track"):
+                param.tensor_track.mark(start=0, end=loaded_weight.shape[axis])
-        if loaded_shard_id == "q_a":
-            param_shard_offset = 0
-            param_shard_size = self.output_sizes[0]
         else:
-            # loaded_shard_id == "kv_a"
-            param_shard_offset = self.output_sizes[0]
-            param_shard_size = self.output_sizes[1]
-        if hasattr(param, "tensor_track"):
-            param.tensor_track.mark(start=param_shard_offset, end=param_shard_offset + param_shard_size)
-        param = slice_fn(
-            param,
-            (self.fd_config.model_config.model_format == "torch") ^ True,
-            start=param_shard_offset,
-            end=param_shard_offset + param_shard_size,
-        )
+            assert loaded_shard_id in ["q_a", "kv_a", "gate", "up"]
+
+            if loaded_shard_id in ["q_a", "gate"]:
+                param_shard_offset = 0
+                param_shard_size = self.output_sizes[0]
+            elif loaded_shard_id in ["kv_a", "up"]:
+                param_shard_offset = self.output_sizes[0]
+                param_shard_size = self.output_sizes[1]
+
+            if hasattr(param, "tensor_track"):
+                param.tensor_track.mark(start=param_shard_offset, end=param_shard_offset + param_shard_size)
+            param = slice_fn(
+                param,
+                (self.fd_config.model_config.model_format == "torch") ^ True,
+                start=param_shard_offset,
+                end=param_shard_offset + param_shard_size,
+            )
         assert param.shape == loaded_weight.shape, (
             f" Attempted to load weight ({loaded_weight.shape}) " f"into parameter ({param.shape})"
         )
diff --git a/fastdeploy/model_executor/layers/lm_head.py b/fastdeploy/model_executor/layers/lm_head.py
index ff2797a04..a7bff3905 100644
--- a/fastdeploy/model_executor/layers/lm_head.py
+++ b/fastdeploy/model_executor/layers/lm_head.py
@@ -102,6 +102,10 @@ class ParallelLMHead(nn.Layer):
                 },
             )
             set_weight_attrs(self.linear.weight, {"output_dim": True})
+            if self.tp_size > 1:
+                if with_bias:
+                    set_weight_attrs(self.linear.bias, {"output_dim": True})
+
         else:
             self.linear = RowParallelLinear(
                 embedding_dim,
diff --git a/fastdeploy/model_executor/layers/moe/moe.py b/fastdeploy/model_executor/layers/moe/moe.py
index 5b1be52d1..11725729a 100644
--- a/fastdeploy/model_executor/layers/moe/moe.py
+++ b/fastdeploy/model_executor/layers/moe/moe.py
@@ -274,10 +274,13 @@ class FusedMoE(nn.Layer):
         if not param._is_initialized():
             param.initialize()
         weight_need_transpose = getattr(param, "weight_need_transpose", False)
+
+        if self.ep_size > 1 or weight_need_transpose:
+            loaded_weight = get_tensor(loaded_weight)
+
         if shard_id is None:
             # 1.gate up fused in disk
             if weight_need_transpose:
-                loaded_weight = get_tensor(loaded_weight)
                 loaded_weight = loaded_weight.transpose([1, 0])
             output_size = param[expert_id - self.expert_id_offset].shape[SHARD_ID_TO_SHARDED_DIM["gate"]]
             shard_offsets = [
@@ -293,7 +296,6 @@ class FusedMoE(nn.Layer):
             self.weight_loader(param, loaded_weight_shard, expert_id, shard_id, "fused")
         else:
             if weight_need_transpose and source != "fused":
-                loaded_weight = get_tensor(loaded_weight)
                 loaded_weight = loaded_weight.transpose([1, 0])
             # 2.gate up splited in disk
             assert shard_id in ["gate", "down", "up"]
diff --git a/fastdeploy/model_executor/layers/mtp_linear.py b/fastdeploy/model_executor/layers/mtp_linear.py
index b1699720b..e1f52d738 100644
--- a/fastdeploy/model_executor/layers/mtp_linear.py
+++ b/fastdeploy/model_executor/layers/mtp_linear.py
@@ -86,6 +86,9 @@ class ParallelEHProjection(nn.Layer):
             )
             if self.tp_size > 1:
                 set_weight_attrs(self.linear.weight, {"output_dim": True})
+                if self.bias_key is not None:
+                    set_weight_attrs(self.linear.bias, {"output_dim": True})
+
         else:
             self.linear = RowParallelLinear(
                 embedding_dim,
diff --git a/fastdeploy/model_executor/layers/normalization.py b/fastdeploy/model_executor/layers/normalization.py
index ec1f0e658..1e37d73bd 100644
--- a/fastdeploy/model_executor/layers/normalization.py
+++ b/fastdeploy/model_executor/layers/normalization.py
@@ -130,6 +130,10 @@ class RMSNorm(nn.Layer):
                 dtype=self._norm_weight_dtype,
             )
 
+    def weight_loader(self, param, loaded_weight, loaded_shard_id: Optional[str] = None):
+        loaded_weight = get_tensor(loaded_weight).astype(self._norm_weight_dtype)
+        param.copy_(loaded_weight, False)
+
     def load_state_dict(self, state_dict: Dict[str, paddle.Tensor | np.ndarray]):
         """
        Load the checkpoint state dictionary into the layer.
diff --git a/fastdeploy/model_executor/layers/quantization/block_wise_fp8.py b/fastdeploy/model_executor/layers/quantization/block_wise_fp8.py
index a7b61fc0e..59daa2384 100644
--- a/fastdeploy/model_executor/layers/quantization/block_wise_fp8.py
+++ b/fastdeploy/model_executor/layers/quantization/block_wise_fp8.py
@@ -138,7 +138,9 @@ class BlockWiseFP8LinearMethod(QuantMethodBase):
         weight_shape = layer.weight_shape
         weight_scale_inv_shape = weight_scale_inv_shape
         extra_weight_attrs["output_dim"] = (
-            not extra_weight_attrs["output_dim"] if extra_weight_attrs["output_dim"] is not None else None
+            not extra_weight_attrs["output_dim"]
+            if extra_weight_attrs.get("output_dim", None) is not None
+            else None
         )
 
         layer.weight_dtype = "float8_e4m3fn"
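Note (not part of the patch): the stand-alone sketch below illustrates how the new shard handling in MergedReplicatedLinear.weight_loader maps the ["q_a", "kv_a", "gate", "up"] shard ids onto offsets in the fused output dimension. The helper name shard_range and the example sizes are hypothetical; only the id-to-offset logic mirrors the diff.

    # Illustrative only; `shard_range` and the sizes below are made up for the example.
    def shard_range(output_sizes, loaded_shard_id):
        """Return (offset, size) of one shard inside the fused output dimension."""
        assert loaded_shard_id in ["q_a", "kv_a", "gate", "up"]
        if loaded_shard_id in ["q_a", "gate"]:
            # First sub-projection starts at offset 0.
            return 0, output_sizes[0]
        # "kv_a" / "up" are packed directly after the first shard.
        return output_sizes[0], output_sizes[1]

    print(shard_range([1536, 512], "q_a"))    # (0, 1536)
    print(shard_range([1536, 512], "kv_a"))   # (1536, 512)
    print(shard_range([1024, 1024], "up"))    # (1024, 1024)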