Mirror of https://github.com/PaddlePaddle/FastDeploy.git
[BugFix] fix qwen3-embedding model tp>1 (#4223)
* support qwen3-embedding
* fix ci bug
* fix
* fix ci bug
* fix ci bug
* fix
* fix qwen3-embedding
* fix
* fix
* fix
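This commit makes Qwen3-Embedding models load correctly when tensor parallelism is enabled (tp>1). Per the hunk headers below, it touches VocabParallelEmbedding (register a TP-aware weight_loader only when the vocabulary does not split evenly, initialize the parameter before loading, copy weights in place), the as_embedding_model adapter (defer the is_pooling_model import), and GPUModelRunner (TODO notes on the pooling-model branches). For background, here is a minimal sketch of the usual pad-and-shard arithmetic for a vocab-parallel embedding whose vocabulary size is not divisible by the TP world size; illustrative only (assumed scheme, hypothetical numbers), not FastDeploy code:

# Illustrative only -- not FastDeploy code. The table is padded up to a
# multiple of world_size and each rank owns one contiguous slice of rows.
def padded_vocab_size(vocab_size: int, world_size: int) -> int:
    return ((vocab_size + world_size - 1) // world_size) * world_size

vocab_size, world_size = 151_669, 2        # hypothetical values
padded = padded_vocab_size(vocab_size, world_size)
rows_per_rank = padded // world_size
for rank in range(world_size):
    start, end = rank * rows_per_rank, (rank + 1) * rows_per_rank
    print(f"rank {rank}: rows [{start}, {end})")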
@@ -164,7 +164,9 @@ class VocabParallelEmbedding(nn.Layer):
                 ),
             )
             if self.world_size > 1:
-                set_weight_attrs(self.embeddings.weight, {"output_dim": False, "weight_loader": self.weight_loader})
+                set_weight_attrs(self.embeddings.weight, {"output_dim": False})
+                if num_embeddings % self.world_size != 0:
+                    set_weight_attrs(self.embeddings.weight, {"weight_loader": self.weight_loader})
         else:
             # column cut embedding
             self.embeddings = nn.Embedding(
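In the hunk above, the custom weight_loader is now attached only when num_embeddings is not divisible by world_size, while the "output_dim": False attribute is kept for every TP run. Below is a minimal sketch of this attribute-based hook pattern, with stand-in classes rather than FastDeploy's implementation (the assumption is that set_weight_attrs simply records attributes on the weight, and the checkpoint loader prefers param.weight_loader when present):

# Stand-in objects; FastDeploy's real classes differ. The point is the pattern:
# loader metadata lives as attributes on the weight, and the checkpoint loader
# calls param.weight_loader if it is present, else a default copy.
class Param:
    pass

def set_weight_attrs(param, attrs: dict) -> None:
    for name, value in attrs.items():
        setattr(param, name, value)

def tp_aware_loader(param, loaded_weight):
    print("custom loader: slice the checkpoint rows owned by this rank")

def default_loader(param, loaded_weight):
    print("default loader: copy the full tensor")

def load(param, loaded_weight):
    getattr(param, "weight_loader", default_loader)(param, loaded_weight)

num_embeddings, world_size = 151_669, 2    # hypothetical values
weight = Param()
set_weight_attrs(weight, {"output_dim": False})
if num_embeddings % world_size != 0:
    set_weight_attrs(weight, {"weight_loader": tp_aware_loader})
load(weight, loaded_weight=None)           # -> custom loader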
@@ -236,6 +238,9 @@ class VocabParallelEmbedding(nn.Layer):
         output_dim = getattr(param, "output_dim", None)
         packed_dim = getattr(param, "packed_dim", None)

+        if not param._is_initialized():
+            param.initialize()
+
         loaded_weight = get_tensor(loaded_weight)
         if param.dtype != loaded_weight.dtype:
             if loaded_weight.dtype == paddle.int8 and param.dtype == paddle.float8_e4m3fn:
@@ -247,7 +252,7 @@ class VocabParallelEmbedding(nn.Layer):
             assert (
                 param.shape == loaded_weight.shape
             ), f"Shape mismatch: param {param.shape} vs loaded_weight {loaded_weight.shape}"
-            param.set_value(loaded_weight)
+            param.copy_(loaded_weight, False)
             return

         start_idx = self.shard_indices.org_vocab_start_index
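The two hunks above harden the non-sharded fast path of weight_loader: the destination parameter is materialized before anything is written into it, and the write itself becomes an in-place copy (param.copy_(loaded_weight, False), non-blocking) instead of param.set_value. A rough sketch of the same guard using a stand-in lazy parameter; Paddle's real Tensor/Parameter API is not reproduced here:

# Stand-in lazy parameter; illustrates the guard, not Paddle's actual API.
class LazyParam:
    def __init__(self, shape):
        self.shape = shape
        self._data = None

    def _is_initialized(self) -> bool:
        return self._data is not None

    def initialize(self) -> None:
        self._data = [0.0] * self.shape[0]

    def copy_(self, values, blocking: bool) -> None:
        self._data[:] = values          # in-place write into existing storage

def load_full_weight(param: LazyParam, loaded_weight) -> None:
    if not param._is_initialized():
        param.initialize()
    assert param.shape == (len(loaded_weight),), "Shape mismatch"
    param.copy_(loaded_weight, False)

p = LazyParam(shape=(4,))
load_full_weight(p, [1.0, 2.0, 3.0, 4.0])
print(p._data)                          # [1.0, 2.0, 3.0, 4.0]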
@@ -22,7 +22,6 @@ import paddle.nn as nn

 from fastdeploy.config import ModelConfig
 from fastdeploy.model_executor.layers.activation import get_act_fn
-from fastdeploy.model_executor.models.interfaces_base import is_pooling_model
 from fastdeploy.transformer_utils.config import get_hf_file_to_dict

 _T = TypeVar("_T", bound=type[nn.Layer])
@@ -191,6 +190,8 @@ def as_embedding_model(cls: _T) -> _T:
         please implement your own model if this is not the case.
     """
     # Avoid modifying existing embedding models
+    from fastdeploy.model_executor.models.interfaces_base import is_pooling_model
+
     if is_pooling_model(cls):
         return cls

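The two hunks above move the is_pooling_model import from module scope into the body of as_embedding_model. This is the standard deferred-import pattern; a plausible motivation, assumed here since the commit message does not say, is breaking a circular import between this adapters module and interfaces_base. A generic, runnable illustration of why a function-local import sidesteps import-time cycles:

# Illustrative only: a function-local import resolves when the function is
# called, not when the enclosing module is imported, which is why moving an
# import into the function body can break an import-time cycle.
def parse_payload(text: str):
    import json  # resolved here, at call time
    return json.loads(text)

print(parse_payload('{"ok": true}'))    # {'ok': True}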
@@ -1321,6 +1321,7 @@ class GPUModelRunner(ModelRunnerBase):

         logits = None
         if hasattr(self.model, "is_pooling_model") and self.model.is_pooling_model:
+            # TODO(lizexu123) The preheating the pooling function have not been implemented yet.
             pass
         else:
             # 4. Execute spec decode
@@ -1632,9 +1633,9 @@ class GPUModelRunner(ModelRunnerBase):
         logits = None
         # 4. Compute logits, Sample
         if hasattr(self.model, "is_pooling_model") and self.model.is_pooling_model:
+            # TODO(lizexu123) The execution of the pooling function have not been implemented yet.
             pass
         else:
-            # 4. Execute spec decode
             logits = self.model.compute_logits(hidden_states)

             if not self.speculative_decoding:
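In both runner hunks the pooling-model branch still just passes (warm-up and execution of the pooling head remain TODOs), so logits stays None and the logits/sampling path is skipped for embedding models. The dispatch itself is the hasattr guard shown above; a stand-in sketch of that control flow with toy classes, not FastDeploy's model interface:

# Stand-in model objects; shows only the dispatch guard used in the runner.
class EmbeddingModel:
    is_pooling_model = True

class CausalLMModel:
    is_pooling_model = False
    def compute_logits(self, hidden_states):
        return f"logits({hidden_states})"

def forward_step(model, hidden_states):
    logits = None
    if hasattr(model, "is_pooling_model") and model.is_pooling_model:
        pass  # pooling/embedding path: no logits, no sampling
    else:
        logits = model.compute_logits(hidden_states)
    return logits

print(forward_step(EmbeddingModel(), "h"))   # None
print(forward_step(CausalLMModel(), "h"))    # logits(h)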