bf16 deepseek (#5379)

This commit is contained in:
bukejiyu
2025-12-05 22:23:30 +08:00
committed by GitHub
parent b2908b8e82
commit f6eb4dcc40

View File

@@ -367,11 +367,14 @@ class MergedReplicatedLinear(ReplicatedLinear):
# loaded_shard_id == "kv_a"
param_shard_offset = self.output_sizes[0]
param_shard_size = self.output_sizes[1]
param_output_dim = True
if hasattr(param, "tensor_track"):
param_output_dim = param.tensor_track.output_dim
param.tensor_track.mark(start=param_shard_offset, end=param_shard_offset + param_shard_size)
param = slice_fn(param, param_output_dim, start=param_shard_offset, end=param_shard_offset + param_shard_size)
param = slice_fn(
param,
(self.fd_config.model_config.model_format == "torch") ^ True,
start=param_shard_offset,
end=param_shard_offset + param_shard_size,
)
assert param.shape == loaded_weight.shape, (
f" Attempted to load weight ({loaded_weight.shape}) " f"into parameter ({param.shape})"
)