[bugfix] fix blockwise fp8 and all_reduce (#3243)

* fix

* update

* fix linear for prequant loader
Author: bukejiyu
Date: 2025-08-06 23:54:33 +08:00
Committed by: GitHub
Parent: 3a15e0c53e
Commit: 9408e667a5
4 changed files with 37 additions and 24 deletions

@@ -81,7 +81,8 @@ class VocabParallelEmbedding(nn.Layer):
                     initializer=nn.initializer.Normal(mean=0.0, std=self.initializer_range),
                 ),
             )
-            set_weight_attrs(self.embeddings.weight, {"output_dim": False})
+            if self.world_size > 1:
+                set_weight_attrs(self.embeddings.weight, {"output_dim": False})
         else:
             # column cut embedding
             self.embeddings = nn.Embedding(
@@ -91,7 +92,8 @@ class VocabParallelEmbedding(nn.Layer):
             self.embeddings.weight.is_distributed = True
             self.embeddings.weight.split_axis = 1
-            set_weight_attrs(self.embeddings.weight, {"output_dim": True})
+            if self.world_size > 1:
+                set_weight_attrs(self.embeddings.weight, {"output_dim": True})
         self.prefix = prefix
         self.dropout = nn.Dropout(self.hidden_dropout_prob)
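
For context, here is a minimal, self-contained sketch of the embedding logic after this change. The constructor signature is simplified (the real class takes a model config, prefix, etc.), and set_weight_attrs below is a stand-in for the repo's helper, assumed to record loader metadata as attributes on the parameter. The idea of the fix: the output_dim tag presumably tells the pre-quant weight loader which dimension to shard, so it is now set only when tensor parallelism is actually in use (world_size > 1); on a single rank the full weight is loaded as-is.

    import paddle.nn as nn


    def set_weight_attrs(weight, attrs: dict):
        # Assumed behavior: attach loader metadata to the parameter object.
        for key, value in attrs.items():
            setattr(weight, key, value)


    class VocabParallelEmbedding(nn.Layer):
        # Hypothetical, simplified constructor for illustration only.
        def __init__(self, num_embeddings: int, embedding_dim: int,
                     world_size: int = 1, column_cut: bool = False):
            super().__init__()
            self.world_size = world_size
            if not column_cut:
                # Row-parallel: the vocab dimension is sharded across ranks.
                self.embeddings = nn.Embedding(
                    num_embeddings // world_size, embedding_dim
                )
                # The fix: only tag the weight for sharded loading under TP.
                if self.world_size > 1:
                    set_weight_attrs(self.embeddings.weight, {"output_dim": False})
            else:
                # Column-cut: the hidden dimension is sharded across ranks.
                self.embeddings = nn.Embedding(
                    num_embeddings, embedding_dim // world_size
                )
                self.embeddings.weight.is_distributed = True
                self.embeddings.weight.split_axis = 1
                if self.world_size > 1:
                    set_weight_attrs(self.embeddings.weight, {"output_dim": True})

With world_size == 1 neither branch tags the weight, which matches the intent of "fix linear for prequant loader": the single-rank path no longer carries sharding metadata that the blockwise-fp8 pre-quant loader would otherwise act on.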