deepseek torch (#5373)
Some checks failed
CE Compile Job / ce_job_pre_check (push) Has been cancelled
CE Compile Job / print_ce_job_pre_check_outputs (push) Has been cancelled
CE Compile Job / FD-Clone-Linux (push) Has been cancelled
CE Compile Job / Show Code Archive Output (push) Has been cancelled
CE Compile Job / BUILD_SM8090 (push) Has been cancelled
CE Compile Job / BUILD_SM8689 (push) Has been cancelled
CE Compile Job / CE_UPLOAD (push) Has been cancelled
Deploy GitHub Pages / deploy (push) Has been cancelled

This commit is contained in:
bukejiyu
2025-12-04 23:26:53 +08:00
committed by GitHub
parent 1b5fd79d6b
commit 620d1da1c9

View File

@@ -367,10 +367,11 @@ class MergedReplicatedLinear(ReplicatedLinear):
# loaded_shard_id == "kv_a"
param_shard_offset = self.output_sizes[0]
param_shard_size = self.output_sizes[1]
param_output_dim = True
if hasattr(param, "tensor_track"):
param_output_dim = param.tensor_track.output_dim
param.tensor_track.mark(start=param_shard_offset, end=param_shard_offset + param_shard_size)
param = slice_fn(param, True, start=param_shard_offset, end=param_shard_offset + param_shard_size)
param = slice_fn(param, param_output_dim, start=param_shard_offset, end=param_shard_offset + param_shard_size)
assert param.shape == loaded_weight.shape, (
f" Attempted to load weight ({loaded_weight.shape}) " f"into parameter ({param.shape})"
)