[Precision] Support lm_head layer running in float32 (#3597)

* support lm_head fp32 bf16 fp16

* support lm_head fp32 bf16 fp16

* add doc and check code

* lm_head_fp32 specify lm_head as fp32

* code check

* check doc
chen
2025-08-27 11:34:53 +08:00
committed by GitHub
parent ad319a87cc
commit ce9c0917c5
15 changed files with 99 additions and 60 deletions

View File

@@ -51,6 +51,7 @@ When using FastDeploy to deploy models (including offline inference and service
| ```chat_template``` | `str` | Specify the template used for model concatenation, It supports both string input and file path input. The default value is None. If not specified, the model's default template will be used. |
| ```tool_call_parser``` | `str` | Specify the function call parser to be used for extracting function call content from the model's output. |
| ```tool_parser_plugin``` | `str` | Specify the file path of the tool parser to be registered, so as to register parsers that are not in the code repository. The code format within these parsers must adhere to the format used in the code repository. |
| ```lm_head_fp32``` | `bool` | Specify the dtype of the lm_head layer as FP32. |
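
As a usage illustration (a hedged sketch, not part of this change: the model name is a placeholder, and the offline `LLM` entrypoint is assumed to forward extra keyword arguments to the engine), the new switch can be enabled from offline inference like this:

```python
# Sketch only: `fastdeploy.LLM` forwarding `lm_head_fp32` to the engine arguments is
# assumed; the model name is a placeholder.
from fastdeploy import LLM, SamplingParams

llm = LLM(
    model="baidu/ERNIE-4.5-0.3B-Paddle",  # placeholder model
    lm_head_fp32=True,                    # keep the lm_head weights in float32
)
outputs = llm.generate(["Hello!"], SamplingParams(max_tokens=16))
```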
## 1. Relationship between KVCache allocation, ```num_gpu_blocks_override``` and ```block_size```?

View File

@@ -49,6 +49,7 @@
| ```chat_template``` | `str` | Specify the template used for model prompt concatenation. Supports both string input and file path input. Defaults to None; if not specified, the model's default template is used. |
| ```tool_call_parser``` | `str` | Specify the function call parser to use for extracting function call content from the model's output. |
| ```tool_parser_plugin``` | `str` | Specify the file path of the tool parser to register, so that parsers not in the code repository can be registered; the code in such parsers must follow the format used in the code repository. |
| ```lm_head_fp32``` | `bool` | Specify the dtype of the lm_head layer as FP32. |
## 1. Relationship between KVCache allocation, ```num_gpu_blocks_override``` and ```block_size```?

View File

@@ -129,6 +129,7 @@ class ModelConfig:
self.quantization = None
self.pad_token_id: int = -1
self.eos_tokens_lens: int = 2
self.lm_head_fp32: bool = False
self.model_format = "auto"
for key, value in args.items():
if hasattr(self, key):

View File

@@ -370,6 +370,11 @@ class EngineArgs:
- "default_v1": default_v1 loader.
"""
lm_head_fp32: bool = False
"""
Flag to specify the dtype of lm_head as FP32. Default is False (use the model's default dtype).
"""
def __post_init__(self):
"""
Post-initialization processing to set default tokenizer if not provided.
@@ -576,6 +581,12 @@ class EngineArgs:
default=EngineArgs.early_stop_config,
help="the config for early stop.",
)
model_group.add_argument(
"--lm_head-fp32",
action="store_true",
default=EngineArgs.lm_head_fp32,
help="Specify the dtype of lm_head weight as float32.",
)
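Note that the engine-side flag mixes an underscore and a hyphen (`--lm_head-fp32`) while the worker-side flag further below uses `--lm_head_fp32`; argparse derives the destination attribute by replacing hyphens with underscores, so both spellings populate the same `lm_head_fp32` attribute. A standalone sketch:

```python
import argparse

# argparse turns "--lm_head-fp32" into the attribute name "lm_head_fp32",
# matching the worker flag "--lm_head_fp32" and EngineArgs.lm_head_fp32.
parser = argparse.ArgumentParser()
parser.add_argument("--lm_head-fp32", action="store_true", default=False)
print(parser.parse_args(["--lm_head-fp32"]).lm_head_fp32)  # True
```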
# Parallel processing parameters group
parallel_group = parser.add_argument_group("Parallel Configuration")

View File

@@ -477,6 +477,7 @@ class LLMEngine:
"disable_any_whitespace": self.cfg.disable_any_whitespace,
"disable_custom_all_reduce": self.cfg.parallel_config.disable_custom_all_reduce,
"enable_logprob": self.cfg.model_config.enable_logprob,
"lm_head_fp32": self.cfg.model_config.lm_head_fp32,
}
for worker_flag, value in worker_append_flag.items():
if value:

View File

@@ -22,7 +22,11 @@ from paddle import nn
from paddle.distributed import fleet
from fastdeploy.config import FDConfig
from fastdeploy.model_executor.utils import default_weight_loader, set_weight_attrs
from fastdeploy.model_executor.utils import (
default_weight_loader,
set_weight_attrs,
temporary_dtype,
)
from .utils import get_tensor
@@ -39,6 +43,7 @@ class ParallelLMHead(nn.Layer):
embedding_dim: int,
prefix: str = "",
with_bias: bool = False,
dtype: str = None,
) -> None:
"""
Parallelized LMhead.
@@ -51,6 +56,7 @@ class ParallelLMHead(nn.Layer):
embedding_dim (int): size of hidden state.
prefix (str): The name of current layer. Defaults to "".
with_bias (bool): whether to have bias. Default: False.
dtype (str): The dtype of the weight. Default: None.
"""
super(ParallelLMHead, self).__init__()
self.weight_key: str = prefix + ".weight"
@@ -65,49 +71,51 @@ class ParallelLMHead(nn.Layer):
ColumnParallelLinear = fleet.meta_parallel.ColumnParallelLinear
RowParallelLinear = fleet.meta_parallel.RowParallelLinear
self.dtype = "float32" if fd_config.model_config.lm_head_fp32 else dtype
self.tie_word_embeddings: bool = fd_config.model_config.tie_word_embeddings
if self.column_cut:
need_gather = True
self.linear = ColumnParallelLinear(
embedding_dim,
num_embeddings,
mp_group=self.tp_group,
weight_attr=None,
has_bias=True if self.bias_key is not None else False,
gather_output=need_gather,
fuse_matmul_bias=False,
)
set_weight_attrs(
self.linear.weight,
{
"weight_loader": default_weight_loader(self.fd_config),
"model_format": self.fd_config.model_config.model_format,
},
)
if self.nranks > 1:
set_weight_attrs(self.linear.weight, {"output_dim": True})
else:
self.linear = RowParallelLinear(
embedding_dim,
num_embeddings,
mp_group=self.tp_group,
weight_attr=None,
has_bias=True if self.bias_key is not None else False,
input_is_parallel=False,
fuse_matmul_bias=False,
)
set_weight_attrs(
self.linear.weight,
{
"weight_loader": default_weight_loader(self.fd_config),
"model_format": self.fd_config.model_config.model_format,
},
)
with temporary_dtype(self.dtype):
if self.column_cut:
need_gather = True
self.linear = ColumnParallelLinear(
embedding_dim,
num_embeddings,
mp_group=self.tp_group,
weight_attr=None,
has_bias=True if self.bias_key is not None else False,
gather_output=need_gather,
fuse_matmul_bias=False,
)
set_weight_attrs(
self.linear.weight,
{
"weight_loader": default_weight_loader(self.fd_config),
"model_format": self.fd_config.model_config.model_format,
},
)
if self.nranks > 1:
set_weight_attrs(self.linear.weight, {"output_dim": True})
else:
self.linear = RowParallelLinear(
embedding_dim,
num_embeddings,
mp_group=self.tp_group,
weight_attr=None,
has_bias=True if self.bias_key is not None else False,
input_is_parallel=False,
fuse_matmul_bias=False,
)
set_weight_attrs(
self.linear.weight,
{
"weight_loader": default_weight_loader(self.fd_config),
"model_format": self.fd_config.model_config.model_format,
},
)
if self.nranks > 1:
set_weight_attrs(self.linear.weight, {"output_dim": False})
if self.nranks > 1:
set_weight_attrs(self.linear.weight, {"output_dim": False})
def load_state_dict(self, state_dict: Dict[str, paddle.Tensor | np.ndarray]):
"""
@@ -119,16 +127,16 @@ class ParallelLMHead(nn.Layer):
if self.tie_word_embeddings:
self.linear.weight.set_value(
get_tensor(state_dict.pop(self.weight_key)).astype(paddle.get_default_dtype()).transpose([1, 0])
get_tensor(state_dict.pop(self.weight_key)).astype(self.linear.weight.dtype).transpose([1, 0])
)
else:
weight_tensor = get_tensor(state_dict.pop(self.weight_key)).astype(paddle.get_default_dtype())
weight_tensor = get_tensor(state_dict.pop(self.weight_key)).astype(self.linear.weight.dtype)
if self.linear.weight.shape != weight_tensor.shape:
weight_tensor = weight_tensor.transpose([1, 0])
self.linear.weight.set_value(weight_tensor)
if self.bias_key is not None:
bias = get_tensor(state_dict.pop(self.bias_key)).astype(paddle.get_default_dtype())
bias = get_tensor(state_dict.pop(self.bias_key)).astype(self.linear.bias.dtype)
self.linear.bias.set_value(bias)
def forward(self, input: paddle.Tensor) -> paddle.Tensor:
@@ -141,6 +149,6 @@ class ParallelLMHead(nn.Layer):
Returns:
Tensor: The output tensor after processing through the layer.
"""
logits = input
logits = input.astype(self.linear.weight.dtype)
logits = self.linear(logits)
return logits
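
The updated `forward` casts the incoming hidden states to the weight dtype, so when `lm_head_fp32` is set the projection itself runs in float32 even though the rest of the model produces bfloat16 activations. A standalone sketch with made-up shapes (a device with bfloat16 support is assumed):

```python
import paddle

weight = paddle.randn([8, 320], dtype="float32")        # lm_head weight kept in FP32
hidden = paddle.randn([2, 8]).astype("bfloat16")        # BF16 hidden states from the model
logits = paddle.matmul(hidden.astype(weight.dtype), weight)  # cast input up, matmul in FP32
print(logits.dtype)  # paddle.float32
```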

View File

@@ -694,7 +694,7 @@ class DeepseekV3ForCausalLM(ModelForCasualLM):
def compute_logits(self, hidden_states: paddle.Tensor):
""" """
logits = self.lm_head(hidden_states)
logits = paddle.cast(logits, paddle.float32)
logits = logits.astype(paddle.float32)
logits[:, self.ori_vocab_size :] = -float("inf")
return logits
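
The `compute_logits` change above (repeated for the Ernie and Qwen models in the following hunks) keeps the established pattern: promote the logits to float32, then mask the vocabulary slots that exist only for padding so they can never be sampled. A self-contained sketch with hypothetical sizes:

```python
import paddle

ori_vocab_size, padded_vocab_size = 32000, 32128        # hypothetical sizes
logits = paddle.randn([2, padded_vocab_size]).astype(paddle.float32)
logits[:, ori_vocab_size:] = -float("inf")              # padded slots get -inf
probs = paddle.nn.functional.softmax(logits, axis=-1)
print(float(probs[:, ori_vocab_size:].sum()))           # 0.0: masked slots have zero probability
```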

View File

@@ -511,10 +511,7 @@ class Ernie4_5_MoeForCausalLM(ModelForCasualLM):
"""
self.ernie.load_state_dict(state_dict)
if self.tie_word_embeddings:
if hasattr(self.lm_head, "linear"):
self.lm_head.linear.weight.set_value(self.ernie.embed_tokens.embeddings.weight.transpose([1, 0]))
else: # ep
self.lm_head.weight.set_value(self.ernie.embed_tokens.embeddings.weight.transpose([1, 0]))
self.lm_head.load_state_dict({self.lm_head.weight_key: self.ernie.embed_tokens.embeddings.weight})
else:
self.lm_head.load_state_dict(state_dict)
@@ -581,11 +578,11 @@ class Ernie4_5_MoeForCausalLM(ModelForCasualLM):
model_sublayer_name = re.sub(r"\.(up_gate_proj_weight|down_proj_weight|weight)$", "", model_param_name)
process_weights_after_loading_fn(model_sublayer_name, param)
if self.tie_word_embeddings:
self.lm_head.linear.weight.set_value(self.ernie.embed_tokens.embeddings.weight.transpose([1, 0]))
self.lm_head.load_state_dict({self.lm_head.weight_key: self.ernie.embed_tokens.embeddings.weight})
def compute_logits(self, hidden_states: paddle.Tensor):
logits = self.lm_head(hidden_states)
logits = paddle.cast(logits, paddle.float32)
logits = logits.astype(paddle.float32)
logits[:, self.ori_vocab_size :] = -float("inf")
return logits

View File

@@ -370,7 +370,7 @@ class Ernie4_5_MTPForCausalLM(ModelForCasualLM):
compute logits
"""
logits = self.lm_head(hidden_states)
logits = paddle.cast(logits, paddle.float32)
logits = logits.astype(paddle.float32)
logits[:, self.ori_vocab_size :] = -float("inf")
return logits

View File

@@ -650,7 +650,7 @@ class Ernie4_5_VLMoeForConditionalGeneration(ModelForCasualLM):
# because we use lazy guard and is not initialized by default
if not self.lm_head.linear.weight._is_initialized():
self.lm_head.linear.weight.initialize()
self.lm_head.linear.weight.set_value(self.ernie.embed_tokens.embeddings.weight.transpose([1, 0]))
self.lm_head.load_state_dict({self.lm_head.weight_key: self.ernie.embed_tokens.embeddings.weight})
@paddle.no_grad()
def set_state_dict(self, state_dict: Dict[str, Union[np.ndarray, paddle.Tensor]]):
@@ -666,13 +666,13 @@ class Ernie4_5_VLMoeForConditionalGeneration(ModelForCasualLM):
self.vision_model.load_state_dict(state_dict)
self.resampler_model.load_state_dict(state_dict)
if self.tie_word_embeddings:
self.lm_head.linear.weight.set_value(self.ernie.embed_tokens.embeddings.weight.transpose([1, 0]))
self.lm_head.load_state_dict({self.lm_head.weight_key: self.ernie.embed_tokens.embeddings.weight})
else:
self.lm_head.load_state_dict(state_dict)
def compute_logits(self, hidden_states: paddle.Tensor):
logits = self.lm_head(hidden_states)
logits = paddle.cast(logits, paddle.float32)
logits = logits.astype(paddle.float32)
logits[:, self.ori_vocab_size :] = -float("inf")
return logits

View File

@@ -379,7 +379,7 @@ class Qwen2ForCausalLM(ModelForCasualLM):
def compute_logits(self, hidden_states: paddle.Tensor):
""" """
logits = self.lm_head(hidden_states)
logits = paddle.cast(logits, paddle.float32)
logits = logits.astype(paddle.float32)
logits[:, self.ori_vocab_size :] = -float("inf")
return logits

View File

@@ -294,7 +294,7 @@ class Qwen3ForCausalLM(ModelForCasualLM):
process_weights_after_loading_fn(model_sublayer_name, param)
if self.tie_word_embeddings:
self.lm_head.linear.weight.set_value(self.model.embed_tokens.embeddings.weight.transpose([1, 0]))
self.lm_head.load_state_dict({self.lm_head.weight_key: self.model.embed_tokens.embeddings.weight})
@paddle.no_grad()
def set_state_dict(self, state_dict):
@@ -308,14 +308,14 @@ class Qwen3ForCausalLM(ModelForCasualLM):
"""
self.model.load_state_dict(state_dict)
if self.tie_word_embeddings:
self.lm_head.linear.weight.set_value(self.model.embed_tokens.embeddings.weight.transpose([1, 0]))
self.lm_head.load_state_dict({self.lm_head.weight_key: self.model.embed_tokens.embeddings.weight})
else:
self.lm_head.load_state_dict(state_dict)
def compute_logits(self, hidden_states: paddle.Tensor):
""" """
logits = self.lm_head(hidden_states)
logits = paddle.cast(logits, paddle.float32)
logits = logits.astype(paddle.float32)
logits[:, self.ori_vocab_size :] = -float("inf")
return logits

View File

@@ -436,7 +436,7 @@ class Qwen3MoeForCausalLM(ModelForCasualLM):
def compute_logits(self, hidden_states: paddle.Tensor):
""" """
logits = self.lm_head(hidden_states)
logits = paddle.cast(logits, paddle.float32)
logits = logits.astype(paddle.float32)
logits[:, self.ori_vocab_size :] = -float("inf")
return logits

View File

@@ -14,6 +14,7 @@
# limitations under the License.
"""
from contextlib import contextmanager
from typing import Any, Optional, Union
import paddle
@@ -185,3 +186,15 @@ def default_weight_loader(fd_config: FDConfig) -> None:
param.copy_(loaded_weight, False)
return fn
@contextmanager
def temporary_dtype(dtype: str):
"""Temporarily set Paddle default dtype"""
orig_dtype = paddle.get_default_dtype()
try:
if dtype is not None and dtype == "float32":
paddle.set_default_dtype(dtype)
yield
finally:
paddle.set_default_dtype(orig_dtype)
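
A minimal usage sketch of the new helper (assuming a Paddle build whose `set_default_dtype` accepts "bfloat16", which is what FastDeploy sets for BF16 models): parameters created inside the context are float32, and the model dtype is restored afterwards.

```python
import paddle
from fastdeploy.model_executor.utils import temporary_dtype

paddle.set_default_dtype("bfloat16")      # dtype the rest of the model is built with
with temporary_dtype("float32"):          # switches only when dtype == "float32"
    print(paddle.get_default_dtype())     # float32: lm_head parameters created here are FP32
with temporary_dtype("bfloat16"):         # any other value leaves the default untouched
    print(paddle.get_default_dtype())     # bfloat16
print(paddle.get_default_dtype())         # bfloat16: original default restored
```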

View File

@@ -599,6 +599,12 @@ def parse_args():
help="The ips of multinode deployment.",
)
parser.add_argument(
"--lm_head_fp32",
action="store_true",
help="Flag to specify dtype of lm_head as FP32",
)
args = parser.parse_args()
return args