diff --git a/fastdeploy/model_executor/layers/linear.py b/fastdeploy/model_executor/layers/linear.py index 91bd5841e..e126aed2b 100644 --- a/fastdeploy/model_executor/layers/linear.py +++ b/fastdeploy/model_executor/layers/linear.py @@ -367,10 +367,11 @@ class MergedReplicatedLinear(ReplicatedLinear): # loaded_shard_id == "kv_a" param_shard_offset = self.output_sizes[0] param_shard_size = self.output_sizes[1] - + param_output_dim = True if hasattr(param, "tensor_track"): + param_output_dim = param.tensor_track.output_dim param.tensor_track.mark(start=param_shard_offset, end=param_shard_offset + param_shard_size) - param = slice_fn(param, True, start=param_shard_offset, end=param_shard_offset + param_shard_size) + param = slice_fn(param, param_output_dim, start=param_shard_offset, end=param_shard_offset + param_shard_size) assert param.shape == loaded_weight.shape, ( f" Attempted to load weight ({loaded_weight.shape}) " f"into parameter ({param.shape})" )