[BugFix] fix qwen3-embedding model tp>1 (#4223)

* support qwen3-embedding

* fix ci bug

* fix

* fix ci bug

* fix ci bug

* fix

* fix qwen3-embedding

* fix

* fix

* fix
This commit is contained in:
lizexu123
2025-09-24 14:13:26 +08:00
committed by GitHub
parent 3161014e49
commit e8318b7477
3 changed files with 11 additions and 4 deletions

View File

@@ -164,7 +164,9 @@ class VocabParallelEmbedding(nn.Layer):
), ),
) )
if self.world_size > 1: if self.world_size > 1:
set_weight_attrs(self.embeddings.weight, {"output_dim": False, "weight_loader": self.weight_loader}) set_weight_attrs(self.embeddings.weight, {"output_dim": False})
if num_embeddings % self.world_size != 0:
set_weight_attrs(self.embeddings.weight, {"weight_loader", self.weight_loader})
else: else:
# column cut embedding # column cut embedding
self.embeddings = nn.Embedding( self.embeddings = nn.Embedding(
@@ -236,6 +238,9 @@ class VocabParallelEmbedding(nn.Layer):
output_dim = getattr(param, "output_dim", None) output_dim = getattr(param, "output_dim", None)
packed_dim = getattr(param, "packed_dim", None) packed_dim = getattr(param, "packed_dim", None)
if not param._is_initialized():
param.initialize()
loaded_weight = get_tensor(loaded_weight) loaded_weight = get_tensor(loaded_weight)
if param.dtype != loaded_weight.dtype: if param.dtype != loaded_weight.dtype:
if loaded_weight.dtype == paddle.int8 and param.dtype == paddle.float8_e4m3fn: if loaded_weight.dtype == paddle.int8 and param.dtype == paddle.float8_e4m3fn:
@@ -247,7 +252,7 @@ class VocabParallelEmbedding(nn.Layer):
assert ( assert (
param.shape == loaded_weight.shape param.shape == loaded_weight.shape
), f"Shape mismatch: param {param.shape} vs loaded_weight {loaded_weight.shape}" ), f"Shape mismatch: param {param.shape} vs loaded_weight {loaded_weight.shape}"
param.set_value(loaded_weight) param.copy_(loaded_weight, False)
return return
start_idx = self.shard_indices.org_vocab_start_index start_idx = self.shard_indices.org_vocab_start_index

View File

@@ -22,7 +22,6 @@ import paddle.nn as nn
from fastdeploy.config import ModelConfig from fastdeploy.config import ModelConfig
from fastdeploy.model_executor.layers.activation import get_act_fn from fastdeploy.model_executor.layers.activation import get_act_fn
from fastdeploy.model_executor.models.interfaces_base import is_pooling_model
from fastdeploy.transformer_utils.config import get_hf_file_to_dict from fastdeploy.transformer_utils.config import get_hf_file_to_dict
_T = TypeVar("_T", bound=type[nn.Layer]) _T = TypeVar("_T", bound=type[nn.Layer])
@@ -191,6 +190,8 @@ def as_embedding_model(cls: _T) -> _T:
please implement your own model if this is not the case. please implement your own model if this is not the case.
""" """
# Avoid modifying existing embedding models # Avoid modifying existing embedding models
from fastdeploy.model_executor.models.interfaces_base import is_pooling_model
if is_pooling_model(cls): if is_pooling_model(cls):
return cls return cls

View File

@@ -1321,6 +1321,7 @@ class GPUModelRunner(ModelRunnerBase):
logits = None logits = None
if hasattr(self.model, "is_pooling_model") and self.model.is_pooling_model: if hasattr(self.model, "is_pooling_model") and self.model.is_pooling_model:
# TODO(lizexu123) The preheating the pooling function have not been implemented yet.
pass pass
else: else:
# 4. Execute spec decode # 4. Execute spec decode
@@ -1632,9 +1633,9 @@ class GPUModelRunner(ModelRunnerBase):
logits = None logits = None
# 4. Compute logits, Sample # 4. Compute logits, Sample
if hasattr(self.model, "is_pooling_model") and self.model.is_pooling_model: if hasattr(self.model, "is_pooling_model") and self.model.is_pooling_model:
# TODO(lizexu123) The execution of the pooling function have not been implemented yet.
pass pass
else: else:
# 4. Execute spec decode
logits = self.model.compute_logits(hidden_states) logits = self.model.compute_logits(hidden_states)
if not self.speculative_decoding: if not self.speculative_decoding: