Mirror of https://github.com/PaddlePaddle/FastDeploy.git, synced 2025-10-18 14:40:44 +08:00
[Precision] Support lm_head layer running in float32 (#3597)
Some checks failed
CE Compile Job / ce_job_pre_check (push) Has been cancelled
CE Compile Job / print_ce_job_pre_check_outputs (push) Has been cancelled
CE Compile Job / FD-Clone-Linux (push) Has been cancelled
CE Compile Job / Show Code Archive Output (push) Has been cancelled
CE Compile Job / BUILD_SM8090 (push) Has been cancelled
CE Compile Job / BUILD_SM8689 (push) Has been cancelled
CE Compile Job / CE_UPLOAD (push) Has been cancelled
Deploy GitHub Pages / deploy (push) Has been cancelled
* support lm_head fp32 bf16 fp16
* support lm_head fp32 bf16 fp16
* add doc and check code
* lm_head_fp32 specify lm_head as fp32
* code check
* check doc
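In short, the PR lets the lm_head layer be created and run in float32 while the rest of the model stays in bf16/fp16, controlled by the `lm_head_fp32` field of the model config. A minimal sketch of the dtype selection it introduces (the standalone helper name below is only illustrative):

def resolve_lm_head_dtype(lm_head_fp32: bool, model_dtype: str) -> str:
    # Mirrors `self.dtype = "float32" if fd_config.model_config.lm_head_fp32 else dtype`
    # in ParallelLMHead.__init__ in the diff below.
    return "float32" if lm_head_fp32 else model_dtype


assert resolve_lm_head_dtype(True, "bfloat16") == "float32"
assert resolve_lm_head_dtype(False, "float16") == "float16"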
@@ -22,7 +22,11 @@ from paddle import nn
 from paddle.distributed import fleet
 
 from fastdeploy.config import FDConfig
-from fastdeploy.model_executor.utils import default_weight_loader, set_weight_attrs
+from fastdeploy.model_executor.utils import (
+    default_weight_loader,
+    set_weight_attrs,
+    temporary_dtype,
+)
 
 from .utils import get_tensor
 
@@ -39,6 +43,7 @@ class ParallelLMHead(nn.Layer):
         embedding_dim: int,
         prefix: str = "",
         with_bias: bool = False,
+        dtype: str = None,
     ) -> None:
         """
         Parallelized LMhead.
@@ -51,6 +56,7 @@ class ParallelLMHead(nn.Layer):
             embedding_dim (int): size of hidden state.
             prefix (str): The name of current layer. Defaults to "".
             with_bias (bool): whether to have bias. Default: False.
+            dtype (str): The dtype of weight. Defalut: None.
         """
         super(ParallelLMHead, self).__init__()
         self.weight_key: str = prefix + ".weight"
@@ -65,49 +71,51 @@ class ParallelLMHead(nn.Layer):
 
         ColumnParallelLinear = fleet.meta_parallel.ColumnParallelLinear
         RowParallelLinear = fleet.meta_parallel.RowParallelLinear
+        self.dtype = "float32" if fd_config.model_config.lm_head_fp32 else dtype
 
         self.tie_word_embeddings: bool = fd_config.model_config.tie_word_embeddings
 
-        if self.column_cut:
-            need_gather = True
-            self.linear = ColumnParallelLinear(
-                embedding_dim,
-                num_embeddings,
-                mp_group=self.tp_group,
-                weight_attr=None,
-                has_bias=True if self.bias_key is not None else False,
-                gather_output=need_gather,
-                fuse_matmul_bias=False,
-            )
-            set_weight_attrs(
-                self.linear.weight,
-                {
-                    "weight_loader": default_weight_loader(self.fd_config),
-                    "model_format": self.fd_config.model_config.model_format,
-                },
-            )
-            if self.nranks > 1:
-                set_weight_attrs(self.linear.weight, {"output_dim": True})
-        else:
-            self.linear = RowParallelLinear(
-                embedding_dim,
-                num_embeddings,
-                mp_group=self.tp_group,
-                weight_attr=None,
-                has_bias=True if self.bias_key is not None else False,
-                input_is_parallel=False,
-                fuse_matmul_bias=False,
-            )
-            set_weight_attrs(
-                self.linear.weight,
-                {
-                    "weight_loader": default_weight_loader(self.fd_config),
-                    "model_format": self.fd_config.model_config.model_format,
-                },
-            )
+        with temporary_dtype(self.dtype):
+            if self.column_cut:
+                need_gather = True
+                self.linear = ColumnParallelLinear(
+                    embedding_dim,
+                    num_embeddings,
+                    mp_group=self.tp_group,
+                    weight_attr=None,
+                    has_bias=True if self.bias_key is not None else False,
+                    gather_output=need_gather,
+                    fuse_matmul_bias=False,
+                )
+                set_weight_attrs(
+                    self.linear.weight,
+                    {
+                        "weight_loader": default_weight_loader(self.fd_config),
+                        "model_format": self.fd_config.model_config.model_format,
+                    },
+                )
+                if self.nranks > 1:
+                    set_weight_attrs(self.linear.weight, {"output_dim": True})
+            else:
+                self.linear = RowParallelLinear(
+                    embedding_dim,
+                    num_embeddings,
+                    mp_group=self.tp_group,
+                    weight_attr=None,
+                    has_bias=True if self.bias_key is not None else False,
+                    input_is_parallel=False,
+                    fuse_matmul_bias=False,
+                )
+                set_weight_attrs(
+                    self.linear.weight,
+                    {
+                        "weight_loader": default_weight_loader(self.fd_config),
+                        "model_format": self.fd_config.model_config.model_format,
+                    },
+                )
 
-            if self.nranks > 1:
-                set_weight_attrs(self.linear.weight, {"output_dim": False})
+                if self.nranks > 1:
+                    set_weight_attrs(self.linear.weight, {"output_dim": False})
 
     def load_state_dict(self, state_dict: Dict[str, paddle.Tensor | np.ndarray]):
         """
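Why wrapping construction in temporary_dtype(self.dtype) is enough: Paddle creates layer parameters in the process-wide default dtype. A standalone sketch of that mechanism, using plain paddle.nn.Linear as a stand-in for the fleet Column/RowParallelLinear layers above (assumes a Paddle build where float16 parameter initialization is available):

import paddle

paddle.set_default_dtype("float16")          # serving default, e.g. fp16/bf16
head_half = paddle.nn.Linear(8, 16)
print(head_half.weight.dtype)                # paddle.float16

orig = paddle.get_default_dtype()
try:
    paddle.set_default_dtype("float32")      # what temporary_dtype("float32") does
    head_fp32 = paddle.nn.Linear(8, 16)      # weights are created in float32
finally:
    paddle.set_default_dtype(orig)           # default restored on exit

print(head_fp32.weight.dtype)                # paddle.float32
print(paddle.get_default_dtype())            # float16 again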
@@ -119,16 +127,16 @@ class ParallelLMHead(nn.Layer):
 
         if self.tie_word_embeddings:
             self.linear.weight.set_value(
-                get_tensor(state_dict.pop(self.weight_key)).astype(paddle.get_default_dtype()).transpose([1, 0])
+                get_tensor(state_dict.pop(self.weight_key)).astype(self.linear.weight.dtype).transpose([1, 0])
             )
         else:
-            weight_tensor = get_tensor(state_dict.pop(self.weight_key)).astype(paddle.get_default_dtype())
+            weight_tensor = get_tensor(state_dict.pop(self.weight_key)).astype(self.linear.weight.dtype)
             if self.linear.weight.shape != weight_tensor.shape:
                 weight_tensor = weight_tensor.transpose([1, 0])
             self.linear.weight.set_value(weight_tensor)
 
         if self.bias_key is not None:
-            bias = get_tensor(state_dict.pop(self.bias_key)).astype(paddle.get_default_dtype())
+            bias = get_tensor(state_dict.pop(self.bias_key)).astype(self.linear.bias.dtype)
             self.linear.bias.set_value(bias)
 
     def forward(self, input: paddle.Tensor) -> paddle.Tensor:
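load_state_dict now casts checkpoint tensors to the parameter's own dtype instead of paddle.get_default_dtype(), so a half-precision checkpoint can still populate a float32 lm_head. A small sketch of that rule (toy shapes; paddle.create_parameter stands in for the real linear weight):

import numpy as np
import paddle

param = paddle.create_parameter(shape=[8, 16], dtype="float32")   # fp32 lm_head weight
ckpt = paddle.to_tensor(np.random.rand(16, 8).astype("float16"))  # half-precision checkpoint tensor

loaded = ckpt.astype(param.dtype)        # cast to the weight's dtype, not the global default
if param.shape != loaded.shape:          # same transpose fallback as above
    loaded = loaded.transpose([1, 0])
param.set_value(loaded)
print(param.dtype)                       # paddle.float32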
@@ -141,6 +149,6 @@ class ParallelLMHead(nn.Layer):
         Returns:
             Tensor: The output tensor after processing through the layer.
         """
-        logits = input
+        logits = input.astype(self.linear.weight.dtype)
         logits = self.linear(logits)
         return logits
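forward now casts the incoming hidden states to the weight dtype, so when lm_head_fp32 is enabled the final projection actually runs in float32 even though activations arrive in half precision. A toy sketch of that cast:

import paddle

weight = paddle.randn([8, 16], dtype="float32")   # fp32 lm_head weight
hidden = paddle.randn([2, 8]).astype("float16")   # half-precision hidden states

logits = paddle.matmul(hidden.astype(weight.dtype), weight)
print(logits.dtype)                               # paddle.float32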
@@ -694,7 +694,7 @@ class DeepseekV3ForCausalLM(ModelForCasualLM):
     def compute_logits(self, hidden_states: paddle.Tensor):
         """ """
         logits = self.lm_head(hidden_states)
-        logits = paddle.cast(logits, paddle.float32)
+        logits = logits.astype(paddle.float32)
         logits[:, self.ori_vocab_size :] = -float("inf")
         return logits
 
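The repeated one-line change across the *ForCausalLM classes swaps paddle.cast for the Tensor.astype method; the two produce the same result here, astype is simply the method form:

import paddle

x = paddle.ones([2, 3], dtype="float16")
a = paddle.cast(x, paddle.float32)   # old spelling
b = x.astype(paddle.float32)         # new spelling used in these hunks
print(a.dtype == b.dtype)            # True
print(bool((a == b).all()))          # True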
@@ -511,10 +511,7 @@ class Ernie4_5_MoeForCausalLM(ModelForCasualLM):
         """
         self.ernie.load_state_dict(state_dict)
         if self.tie_word_embeddings:
-            if hasattr(self.lm_head, "linear"):
-                self.lm_head.linear.weight.set_value(self.ernie.embed_tokens.embeddings.weight.transpose([1, 0]))
-            else:  # ep
-                self.lm_head.weight.set_value(self.ernie.embed_tokens.embeddings.weight.transpose([1, 0]))
+            self.lm_head.load_state_dict({self.lm_head.weight_key: self.ernie.embed_tokens.embeddings.weight})
         else:
             self.lm_head.load_state_dict(state_dict)
 
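For tied word embeddings, the models now hand the shared embedding weight to ParallelLMHead.load_state_dict instead of transposing it and writing it into linear.weight by hand, so the transpose and the dtype cast live in one place. A toy sketch of the two styles (plain nn.Linear and a random tensor stand in for the real head and embedding table; the new-style call only works inside FastDeploy):

import paddle

embed_weight = paddle.randn([32, 8])        # [vocab_size, hidden_size]
head = paddle.nn.Linear(8, 32)              # stand-in for lm_head.linear

# old style: manual transpose + set_value on the linear weight
head.weight.set_value(embed_weight.transpose([1, 0]).astype(head.weight.dtype))
print(head.weight.shape)                    # [8, 32]

# new style in FastDeploy: the head reuses its own loader,
#   lm_head.load_state_dict({lm_head.weight_key: embedding_weight})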
@@ -581,11 +578,11 @@ class Ernie4_5_MoeForCausalLM(ModelForCasualLM):
             model_sublayer_name = re.sub(r"\.(up_gate_proj_weight|down_proj_weight|weight)$", "", model_param_name)
             process_weights_after_loading_fn(model_sublayer_name, param)
         if self.tie_word_embeddings:
-            self.lm_head.linear.weight.set_value(self.ernie.embed_tokens.embeddings.weight.transpose([1, 0]))
+            self.lm_head.load_state_dict({self.lm_head.weight_key: self.ernie.embed_tokens.embeddings.weight})
 
     def compute_logits(self, hidden_states: paddle.Tensor):
         logits = self.lm_head(hidden_states)
-        logits = paddle.cast(logits, paddle.float32)
+        logits = logits.astype(paddle.float32)
         logits[:, self.ori_vocab_size :] = -float("inf")
 
         return logits
@@ -370,7 +370,7 @@ class Ernie4_5_MTPForCausalLM(ModelForCasualLM):
         compute logits
         """
        logits = self.lm_head(hidden_states)
-        logits = paddle.cast(logits, paddle.float32)
+        logits = logits.astype(paddle.float32)
         logits[:, self.ori_vocab_size :] = -float("inf")
 
         return logits
@@ -650,7 +650,7 @@ class Ernie4_5_VLMoeForConditionalGeneration(ModelForCasualLM):
             # because we use lazy guard and is not initialized by default
             if not self.lm_head.linear.weight._is_initialized():
                 self.lm_head.linear.weight.initialize()
-            self.lm_head.linear.weight.set_value(self.ernie.embed_tokens.embeddings.weight.transpose([1, 0]))
+            self.lm_head.load_state_dict({self.lm_head.weight_key: self.ernie.embed_tokens.embeddings.weight})
 
     @paddle.no_grad()
     def set_state_dict(self, state_dict: Dict[str, Union[np.ndarray, paddle.Tensor]]):
@@ -666,13 +666,13 @@ class Ernie4_5_VLMoeForConditionalGeneration(ModelForCasualLM):
         self.vision_model.load_state_dict(state_dict)
         self.resampler_model.load_state_dict(state_dict)
         if self.tie_word_embeddings:
-            self.lm_head.linear.weight.set_value(self.ernie.embed_tokens.embeddings.weight.transpose([1, 0]))
+            self.lm_head.load_state_dict({self.lm_head.weight_key: self.ernie.embed_tokens.embeddings.weight})
         else:
             self.lm_head.load_state_dict(state_dict)
 
     def compute_logits(self, hidden_states: paddle.Tensor):
         logits = self.lm_head(hidden_states)
-        logits = paddle.cast(logits, paddle.float32)
+        logits = logits.astype(paddle.float32)
         logits[:, self.ori_vocab_size :] = -float("inf")
 
         return logits
@@ -379,7 +379,7 @@ class Qwen2ForCausalLM(ModelForCasualLM):
     def compute_logits(self, hidden_states: paddle.Tensor):
         """ """
         logits = self.lm_head(hidden_states)
-        logits = paddle.cast(logits, paddle.float32)
+        logits = logits.astype(paddle.float32)
         logits[:, self.ori_vocab_size :] = -float("inf")
 
         return logits
@@ -294,7 +294,7 @@ class Qwen3ForCausalLM(ModelForCasualLM):
             process_weights_after_loading_fn(model_sublayer_name, param)
 
         if self.tie_word_embeddings:
-            self.lm_head.linear.weight.set_value(self.model.embed_tokens.embeddings.weight.transpose([1, 0]))
+            self.lm_head.load_state_dict({self.lm_head.weight_key: self.ernie.embed_tokens.embeddings.weight})
 
     @paddle.no_grad()
     def set_state_dict(self, state_dict):
@@ -308,14 +308,14 @@ class Qwen3ForCausalLM(ModelForCasualLM):
         """
         self.model.load_state_dict(state_dict)
         if self.tie_word_embeddings:
-            self.lm_head.linear.weight.set_value(self.model.embed_tokens.embeddings.weight.transpose([1, 0]))
+            self.lm_head.load_state_dict({self.lm_head.weight_key: self.ernie.embed_tokens.embeddings.weight})
         else:
             self.lm_head.load_state_dict(state_dict)
 
     def compute_logits(self, hidden_states: paddle.Tensor):
         """ """
         logits = self.lm_head(hidden_states)
-        logits = paddle.cast(logits, paddle.float32)
+        logits = logits.astype(paddle.float32)
         logits[:, self.ori_vocab_size :] = -float("inf")
 
         return logits
@@ -436,7 +436,7 @@ class Qwen3MoeForCausalLM(ModelForCasualLM):
     def compute_logits(self, hidden_states: paddle.Tensor):
         """ """
         logits = self.lm_head(hidden_states)
-        logits = paddle.cast(logits, paddle.float32)
+        logits = logits.astype(paddle.float32)
         logits[:, self.ori_vocab_size :] = -float("inf")
 
         return logits
@@ -14,6 +14,7 @@
 # limitations under the License.
 """
 
+from contextlib import contextmanager
 from typing import Any, Optional, Union
 
 import paddle
@@ -185,3 +186,15 @@ def default_weight_loader(fd_config: FDConfig) -> None:
            param.copy_(loaded_weight, False)
 
     return fn
+
+
+@contextmanager
+def temporary_dtype(dtype: str):
+    """Temporarily set Paddle default dtype"""
+    orig_dtype = paddle.get_default_dtype()
+    try:
+        if dtype is not None and dtype == "float32":
+            paddle.set_default_dtype(dtype)
+        yield
+    finally:
+        paddle.set_default_dtype(orig_dtype)
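A usage sketch for the new context manager, assuming a FastDeploy install that includes this change. Note the guard: only "float32" actually switches the default dtype, any other value is a no-op, and the previous default is always restored on exit:

import paddle
from fastdeploy.model_executor.utils import temporary_dtype

paddle.set_default_dtype("float16")
with temporary_dtype("float32"):
    print(paddle.get_default_dtype())   # float32; lm_head params created here are fp32
with temporary_dtype("float16"):
    print(paddle.get_default_dtype())   # still float16; the guard only reacts to "float32"
print(paddle.get_default_dtype())       # float16, restored after each block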