[CP] CP Lm head fp32 and temp_logprob to release/2.1 (#3766)

* [Feature] Add temp_scaled_logprobs and top_p_normalized_logprobs parameters for logits and logprobs post-processing (#3552)

* [feature] Add temp_scaled_logprobs and top_p_normalized_logprobs parameters for logits and logprobs post-processing

* inference engine supports temp_scaled_logprobs and top_p_normalized_logprobs

* delete some code

* code check

* code check and add doc

* fix tokenizer.decoder(-1) to return 'Invalid Token'

* add ci for temp_scaled and top_p logprobs

* check test

* check seq len time shape

* clip inf values in logprobs

---------

Co-authored-by: sunlei1024 <sunlei5788@gmail.com>
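
As context for the bullets above: a rough, framework-agnostic sketch of what temperature-scaled and top-p-normalized logprob post-processing typically means. This is illustrative only, not the FastDeploy implementation; the helper name and signature below are made up.

```python
import numpy as np

def postprocess_logprobs(logits, temperature=1.0, top_p=1.0,
                         temp_scaled_logprobs=False, top_p_normalized_logprobs=False):
    """Hypothetical helper: not the FastDeploy API, just the general idea."""
    logits = np.asarray(logits, dtype=np.float64)
    if temp_scaled_logprobs and temperature > 0:
        logits = logits / temperature                      # scale logits by temperature first
    shifted = logits - logits.max()                        # numerically stable softmax
    probs = np.exp(shifted) / np.exp(shifted).sum()
    if top_p_normalized_logprobs and top_p < 1.0:
        order = np.argsort(-probs)                         # tokens sorted by descending prob
        cum = np.cumsum(probs[order])
        keep = order[: np.searchsorted(cum, top_p) + 1]    # smallest nucleus reaching top_p
        nucleus = np.zeros_like(probs)
        nucleus[keep] = probs[keep] / probs[keep].sum()    # renormalize inside the nucleus
        probs = nucleus
    return np.log(np.clip(probs, 1e-20, None))             # clip so dropped tokens stay finite
```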

* [Precision] Support lm_head layer running in float32 (#3597)

* support lm_head fp32 bf16 fp16

* support lm_head fp32 bf16 fp16

* add doc and check code

* add lm_head_fp32 option to specify the lm_head dtype as fp32

* code check

* check doc

* code check

---------

Co-authored-by: sunlei1024 <sunlei5788@gmail.com>
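
To make the lm_head_fp32 idea concrete: keeping the final projection in float32 means the low-precision hidden states are cast up before the vocabulary matmul, so logits (and downstream logprobs) are computed at higher precision. A minimal NumPy illustration of the precision gap this targets follows; shapes are arbitrary and float16 stands in for bf16/fp16, so this is a sketch, not the FastDeploy code.

```python
import numpy as np

rng = np.random.default_rng(0)
hidden = rng.standard_normal((2, 8)).astype(np.float16)        # low-precision activations (stand-in for bf16)
w_fp32 = rng.standard_normal((8, 16)).astype(np.float32)       # LM head weight kept in float32

logits_fp32 = hidden.astype(np.float32) @ w_fp32               # cast up, project in float32
logits_lowp = (hidden @ w_fp32.astype(np.float16)).astype(np.float32)  # head also in float16

print(np.abs(logits_fp32 - logits_lowp).max())                 # nonzero: the low-precision head loses accuracy
```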
Author: chen
Date: 2025-09-01 19:56:54 +08:00
Committed by: GitHub
Parent: 4da603daec
Commit: 1e19833ba5
22 changed files with 188 additions and 54 deletions


@@ -22,7 +22,7 @@ from paddle import nn
 from paddle.distributed import fleet
 from fastdeploy.config import FDConfig
-from fastdeploy.model_executor.models.utils import set_weight_attrs
+from fastdeploy.model_executor.models.utils import set_weight_attrs, temporary_dtype
 from .utils import get_tensor
@@ -39,6 +39,7 @@ class ParallelLMHead(nn.Layer):
         embedding_dim: int,
         prefix: str = "",
         with_bias: bool = False,
+        dtype: str = None,
     ) -> None:
         """
         Parallelized LMhead.
@@ -51,6 +52,7 @@ class ParallelLMHead(nn.Layer):
             embedding_dim (int): size of hidden state.
             prefix (str): The name of current layer. Defaults to "".
             with_bias (bool): whether to have bias. Default: False.
+            dtype (str): The dtype of weight. Default: None.
         """
         super(ParallelLMHead, self).__init__()
         self.weight_key: str = prefix + ".weight"
@@ -63,39 +65,40 @@ class ParallelLMHead(nn.Layer):
         ColumnParallelLinear = fleet.meta_parallel.ColumnParallelLinear
         RowParallelLinear = fleet.meta_parallel.RowParallelLinear
+        self.dtype = "float32" if fd_config.model_config.lm_head_fp32 else dtype
         self.tie_word_embeddings: bool = fd_config.model_config.tie_word_embeddings
-        if self.use_ep:
-            self.weight = self.create_parameter(
-                shape=[embedding_dim, num_embeddings],
-                dtype=paddle.get_default_dtype(),
-                is_bias=False,
-            )
-        else:
-            if self.column_cut:
-                need_gather = True
-                self.linear = ColumnParallelLinear(
-                    embedding_dim,
-                    num_embeddings,
-                    mp_group=fleet.get_hybrid_communicate_group().get_model_parallel_group(),
-                    weight_attr=None,
-                    has_bias=True if self.bias_key is not None else False,
-                    gather_output=need_gather,
-                    fuse_matmul_bias=False,  # False gives a smaller numerical diff
-                )
-                set_weight_attrs(self.linear.weight, {"output_dim": True})
-            else:
-                self.linear = RowParallelLinear(
-                    embedding_dim,
-                    num_embeddings,
-                    mp_group=fleet.get_hybrid_communicate_group().get_model_parallel_group(),
-                    weight_attr=None,
-                    has_bias=True if self.bias_key is not None else False,
-                    input_is_parallel=False,
-                    fuse_matmul_bias=False,  # False gives a smaller numerical diff
-                )
-                set_weight_attrs(self.linear.weight, {"output_dim": False})
+        with temporary_dtype(self.dtype):
+            if self.use_ep:
+                self.weight = self.create_parameter(
+                    shape=[embedding_dim, num_embeddings],
+                    dtype=paddle.get_default_dtype(),
+                    is_bias=False,
+                )
+            else:
+                if self.column_cut:
+                    need_gather = True
+                    self.linear = ColumnParallelLinear(
+                        embedding_dim,
+                        num_embeddings,
+                        mp_group=fleet.get_hybrid_communicate_group().get_model_parallel_group(),
+                        weight_attr=None,
+                        has_bias=True if self.bias_key is not None else False,
+                        gather_output=need_gather,
+                        fuse_matmul_bias=False,  # False gives a smaller numerical diff
+                    )
+                    set_weight_attrs(self.linear.weight, {"output_dim": True})
+                else:
+                    self.linear = RowParallelLinear(
+                        embedding_dim,
+                        num_embeddings,
+                        mp_group=fleet.get_hybrid_communicate_group().get_model_parallel_group(),
+                        weight_attr=None,
+                        has_bias=True if self.bias_key is not None else False,
+                        input_is_parallel=False,
+                        fuse_matmul_bias=False,  # False gives a smaller numerical diff
+                    )
+                    set_weight_attrs(self.linear.weight, {"output_dim": False})

     def load_state_dict(self, state_dict: Dict[str, paddle.Tensor | np.ndarray]):
         """
@@ -106,20 +109,20 @@ class ParallelLMHead(nn.Layer):
"""
if self.use_ep:
self.weight.set_value(get_tensor(state_dict.pop(self.weight_key)).astype(paddle.get_default_dtype()))
self.weight.set_value(get_tensor(state_dict.pop(self.weight_key)).astype(self.weight.dtype))
else:
if self.tie_word_embeddings:
self.linear.weight.set_value(
get_tensor(state_dict.pop(self.weight_key)).astype(paddle.get_default_dtype()).transpose([1, 0])
get_tensor(state_dict.pop(self.weight_key)).astype(self.linear.weight.dtype).transpose([1, 0])
)
else:
weight_tensor = get_tensor(state_dict.pop(self.weight_key)).astype(paddle.get_default_dtype())
weight_tensor = get_tensor(state_dict.pop(self.weight_key)).astype(self.linear.weight.dtype)
if self.linear.weight.shape != weight_tensor.shape:
weight_tensor = weight_tensor.transpose([1, 0])
self.linear.weight.set_value(weight_tensor)
if self.bias_key is not None:
bias = get_tensor(state_dict.pop(self.bias_key)).astype(paddle.get_default_dtype())
bias = get_tensor(state_dict.pop(self.bias_key)).astype(self.linear.weight.dtype)
self.linear.bias.set_value(bias)
def forward(self, input: paddle.Tensor) -> paddle.Tensor:
@@ -134,7 +137,8 @@ class ParallelLMHead(nn.Layer):
"""
logits = input
if self.use_ep:
logits = paddle.matmul(logits, self.weight)
logits = paddle.matmul(logits.astype(self.weight.dtype), self.weight)
else:
logits = self.linear(logits)
logits = self.linear(logits.astype(self.linear.weight.dtype))
print(self.linear.weight.dtype)
return logits
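
The forward() change casts the incoming hidden states to the head weight's dtype before the projection, so a bf16/fp16 trunk can feed an fp32 LM head without a dtype mismatch. A minimal usage sketch of that pattern follows; the tensor shapes and dtypes are chosen for illustration and are not taken from a real config.

```python
import paddle

w = paddle.randn([8, 16], dtype="float32")       # fp32 LM head weight (lm_head_fp32 enabled)
x = paddle.randn([2, 8]).astype("bfloat16")      # bf16 hidden states from the transformer trunk

logits = paddle.matmul(x.astype(w.dtype), w)     # cast activations up; matmul runs in float32
print(logits.dtype)                              # paddle.float32
```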