Mirror of https://github.com/PaddlePaddle/FastDeploy.git, synced 2025-10-06 09:07:10 +08:00
[BugFix] fix qwen3-embedding model tp>1 (#4223)
* support qwen3-embedding
* fix ci bug
* fix
* fix ci bug
* fix ci bug
* fix
* fix qwen3-embedding
* fix
* fix
* fix
@@ -164,7 +164,9 @@ class VocabParallelEmbedding(nn.Layer):
                 ),
             )
             if self.world_size > 1:
-                set_weight_attrs(self.embeddings.weight, {"output_dim": False, "weight_loader": self.weight_loader})
+                set_weight_attrs(self.embeddings.weight, {"output_dim": False})
+                if num_embeddings % self.world_size != 0:
+                    set_weight_attrs(self.embeddings.weight, {"weight_loader": self.weight_loader})
         else:
             # column cut embedding
             self.embeddings = nn.Embedding(
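This hunk registers the custom weight_loader only when the vocabulary size does not divide evenly across the tensor-parallel ranks; in that case the per-rank shard generally needs padding, so a naive even split of the checkpoint tensor would not line up with the local parameter shape. Below is a minimal, self-contained sketch of that situation; the shard_range helper and the concrete vocab/tp numbers are illustrative assumptions, not FastDeploy code.

# Illustrative sketch only: why an uneven vocab needs special weight loading.
def shard_range(num_embeddings: int, tp_size: int, rank: int):
    """Return the [start, end) rows of the full vocab owned by `rank`,
    with every local shard padded up to the same size."""
    per_rank = -(-num_embeddings // tp_size)  # ceil division
    start = rank * per_rank
    end = min(start + per_rank, num_embeddings)
    return start, end, per_rank

if __name__ == "__main__":
    vocab, tp = 151_936, 3  # 151_936 % 3 != 0, so the last shard is padded
    for r in range(tp):
        s, e, local = shard_range(vocab, tp, r)
        # The local parameter has `local` rows, but only `e - s` of them come
        # from the checkpoint; the rest must remain zero padding.
        print(f"rank {r}: checkpoint rows [{s}, {e}), local shard size {local}")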
@@ -236,6 +238,9 @@ class VocabParallelEmbedding(nn.Layer):
         output_dim = getattr(param, "output_dim", None)
         packed_dim = getattr(param, "packed_dim", None)
 
+        if not param._is_initialized():
+            param.initialize()
+
         loaded_weight = get_tensor(loaded_weight)
         if param.dtype != loaded_weight.dtype:
             if loaded_weight.dtype == paddle.int8 and param.dtype == paddle.float8_e4m3fn:
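The _is_initialized()/initialize() guard added here materializes a lazily created parameter before any checkpoint data is copied into it. A rough sketch of that behavior with a stock Paddle layer built under paddle.LazyGuard; the layer and sizes are arbitrary, and this only assumes Paddle's documented lazy-initialization API, not the FastDeploy code path.

import paddle

# Build a layer lazily: the weight exists as metadata but owns no storage yet.
with paddle.LazyGuard():
    emb = paddle.nn.Embedding(16, 8)

print(emb.weight._is_initialized())  # expected False: nothing allocated yet
emb.weight.initialize()              # run the recorded initializer, allocate storage
print(emb.weight._is_initialized())  # expected True: safe to copy data into it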
@@ -247,7 +252,7 @@ class VocabParallelEmbedding(nn.Layer):
             assert (
                 param.shape == loaded_weight.shape
             ), f"Shape mismatch: param {param.shape} vs loaded_weight {loaded_weight.shape}"
-            param.set_value(loaded_weight)
+            param.copy_(loaded_weight, False)
             return
 
         start_idx = self.shard_indices.org_vocab_start_index
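For context, this hunk is the fast path where the parameter and checkpoint shapes already match; when they do not, the code after the hunk (starting at start_idx = self.shard_indices.org_vocab_start_index) copies only this rank's slice of the full table and leaves any padding rows of the local shard at zero. A hypothetical numpy sketch of those two paths follows; load_embedding_shard and the toy shapes are illustrative, not the actual implementation.

import numpy as np

def load_embedding_shard(param, loaded_weight, org_vocab_start_index):
    """Hypothetical illustration of the two loading paths."""
    if param.shape == loaded_weight.shape:
        # Fast path (e.g. tp_size == 1): shapes already match, copy as-is.
        return loaded_weight.copy()
    # Sharded path: slice this rank's rows out of the full vocab table and
    # keep any trailing padding rows of the local shard at zero.
    rows = min(param.shape[0], loaded_weight.shape[0] - org_vocab_start_index)
    out = np.zeros_like(param)
    out[:rows] = loaded_weight[org_vocab_start_index:org_vocab_start_index + rows]
    return out

# Toy example: a 10-row vocab split across 3 ranks (local shards padded to 4 rows).
full = np.arange(10 * 2, dtype=np.float32).reshape(10, 2)
rank2_shard = np.zeros((4, 2), dtype=np.float32)
print(load_embedding_shard(rank2_shard, full, org_vocab_start_index=8))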