[BugFix] fix qwen3-embedding model tp>1 (#4223)

* support qwen3-embedding

* fix ci bug

* fix

* fix ci bug

* fix ci bug

* fix

* fix qwen3-embedding

* fix

* fix

* fix
This commit is contained in:
lizexu123
2025-09-24 14:13:26 +08:00
committed by GitHub
parent 3161014e49
commit e8318b7477
3 changed files with 11 additions and 4 deletions

View File

@@ -164,7 +164,9 @@ class VocabParallelEmbedding(nn.Layer):
),
)
if self.world_size > 1:
set_weight_attrs(self.embeddings.weight, {"output_dim": False, "weight_loader": self.weight_loader})
set_weight_attrs(self.embeddings.weight, {"output_dim": False})
if num_embeddings % self.world_size != 0:
set_weight_attrs(self.embeddings.weight, {"weight_loader", self.weight_loader})
else:
# column cut embedding
self.embeddings = nn.Embedding(
@@ -236,6 +238,9 @@ class VocabParallelEmbedding(nn.Layer):
output_dim = getattr(param, "output_dim", None)
packed_dim = getattr(param, "packed_dim", None)
if not param._is_initialized():
param.initialize()
loaded_weight = get_tensor(loaded_weight)
if param.dtype != loaded_weight.dtype:
if loaded_weight.dtype == paddle.int8 and param.dtype == paddle.float8_e4m3fn:
@@ -247,7 +252,7 @@ class VocabParallelEmbedding(nn.Layer):
assert (
param.shape == loaded_weight.shape
), f"Shape mismatch: param {param.shape} vs loaded_weight {loaded_weight.shape}"
param.set_value(loaded_weight)
param.copy_(loaded_weight, False)
return
start_idx = self.shard_indices.org_vocab_start_index