[Precision] Support running the lm_head layer in float32 (#3596)

* support lm_head fp32 bf16 fp16

* delete print

* code check

* check

* check

* code check

* check

* check
Author: chen
Date: 2025-08-26 20:20:06 +08:00
Committed by: GitHub
Parent: 2136990144
Commit: d233e3c97c

14 changed files with 85 additions and 49 deletions

View File

@@ -108,6 +108,7 @@ class ModelConfig:
         self.enable_mm = False
         self.enable_redundant_experts = False
         self.redundant_experts_num = 0
+        self.lm_head_fp32: bool = False
 
         for key, value in args.items():
             if hasattr(self, key):

View File

@@ -15,10 +15,10 @@
""" """
import json import json
import os
from dataclasses import asdict, dataclass from dataclasses import asdict, dataclass
from dataclasses import fields as dataclass_fields from dataclasses import fields as dataclass_fields
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional
import os
from fastdeploy.engine.config import ( from fastdeploy.engine.config import (
CacheConfig, CacheConfig,
@@ -315,6 +315,11 @@ class EngineArgs:
     Must be explicitly enabled via the `--enable-logprob` startup parameter to output logprob values.
     """
+    lm_head_fp32: bool = None
+    """
+    Flag to specify the data type of lm_head as FP32.
+    """
+
     def __post_init__(self):
         """
         Post-initialization processing to set default tokenizer if not provided.
@@ -466,6 +471,12 @@ class EngineArgs:
             default=EngineArgs.enable_logprob,
             help="Enable output of token-level log probabilities.",
         )
+        model_group.add_argument(
+            "--lm_head-fp32",
+            action="store_true",
+            default=EngineArgs.lm_head_fp32,
+            help="Specify the dtype of lm_head weight as float32.",
+        )
 
         # Parallel processing parameters group
         parallel_group = parser.add_argument_group("Parallel Configuration")
@@ -769,6 +780,7 @@ class EngineArgs:
             quantization=self.quantization,
             dynamic_load_weight=self.dynamic_load_weight,
             load_strategy=self.load_strategy,
+            lm_head_fp32=self.lm_head_fp32,
         )
 
     def create_cache_config(self, model_cfg) -> CacheConfig:
@@ -855,7 +867,7 @@ class EngineArgs:
         if self.enable_chunked_prefill:
             self.max_num_batched_tokens = 2048
         else:
-            if not int(os.getenv('ENABLE_V1_KVCACHE_SCHEDULER', '0')):
+            if not int(os.getenv("ENABLE_V1_KVCACHE_SCHEDULER", "0")):
                 self.max_num_batched_tokens = self.max_model_len
             else:
                 self.max_num_batched_tokens = 8192
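For reference, a minimal sketch of setting the new option programmatically. The module path, field names outside this diff, and the model path are assumptions (the page does not show file names); on the engine command line the same switch is spelled `--lm_head-fp32`, per the argument registered above.

```python
from fastdeploy.engine.args_utils import EngineArgs  # assumed module path for this file

# The new dataclass field travels with the other model options and is handed to
# ModelConfig via the lm_head_fp32=self.lm_head_fp32 line in the hunk above.
args = EngineArgs(model="./ERNIE-4.5-0.3B-Paddle", lm_head_fp32=True)
print(args.lm_head_fp32)  # True
```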

View File

@@ -51,6 +51,7 @@ class ModelConfig:
         load_strategy: str = "ipc_snapshot",
         quantization: str = None,
         download_dir: Optional[str] = None,
+        lm_head_fp32: bool = False,
     ):
         """
         Initialize the ModelConfig class.
@@ -65,6 +66,7 @@ class ModelConfig:
         self.dynamic_load_weight = dynamic_load_weight
         self.load_strategy = load_strategy
         self.quantization = quantization
+        self.lm_head_fp32 = lm_head_fp32
 
         config_file = os.path.join(model_name_or_path, config_json_file)
         if os.path.isfile(model_name_or_path):
@@ -804,7 +806,7 @@ class Config:
         if self.cache_config.enable_chunked_prefill:
             self.max_num_batched_tokens = 2048
         else:
-            if not int(os.getenv('ENABLE_V1_KVCACHE_SCHEDULER', '0')):
+            if not int(os.getenv("ENABLE_V1_KVCACHE_SCHEDULER", "0")):
                 self.max_num_batched_tokens = self.max_model_len
             else:
                 self.max_num_batched_tokens = 8192
@@ -855,7 +857,7 @@ class Config:
             )
 
         if not self.cache_config.enable_chunked_prefill:
-            if not int(os.getenv('ENABLE_V1_KVCACHE_SCHEDULER', '0')):
+            if not int(os.getenv("ENABLE_V1_KVCACHE_SCHEDULER", "0")):
                 assert self.max_num_batched_tokens >= self.max_model_len, (
                     f"max_num_batched_tokens: {self.max_num_batched_tokens} "
                     f"should be larger than or equal to max_model_len: {self.max_model_len}"

View File

@@ -1099,6 +1099,7 @@ class LLMEngine:
"enable_custom_all_reduce": self.cfg.parallel_config.enable_custom_all_reduce, "enable_custom_all_reduce": self.cfg.parallel_config.enable_custom_all_reduce,
"enable_logprob": self.cfg.enable_logprob, "enable_logprob": self.cfg.enable_logprob,
"enable_mm": self.cfg.enable_mm, "enable_mm": self.cfg.enable_mm,
"lm_head_fp32": self.cfg.model_config.lm_head_fp32,
} }
for worker_flag, value in worker_append_flag.items(): for worker_flag, value in worker_append_flag.items():
if value: if value:
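worker_append_flag collects boolean engine options; the loop at the bottom of the hunk (only its first lines are visible here) appends each truthy entry to the worker launch command as a bare flag, which is how lm_head_fp32=True reaches the --lm_head_fp32 option registered in parse_args() at the end of this diff. A hedged paraphrase of that pattern, with illustrative values rather than the verbatim engine code:

```python
# Illustrative stand-in for the engine's flag-forwarding loop.
worker_append_flag = {
    "enable_logprob": False,
    "lm_head_fp32": True,  # populated from self.cfg.model_config.lm_head_fp32
}

arguments = "python -m fastdeploy.worker.worker_process --model ./model"  # illustrative command
for worker_flag, value in worker_append_flag.items():
    if value:
        arguments += f" --{worker_flag}"

print(arguments)  # the command now ends with " --lm_head_fp32"
```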

View File

@@ -23,7 +23,7 @@ from paddle.distributed import fleet
 from fastdeploy.config import FDConfig
 
-from .utils import get_tensor
+from .utils import get_tensor, temporary_dtype
 
 
 class ParallelLMHead(nn.Layer):
class ParallelLMHead(nn.Layer): class ParallelLMHead(nn.Layer):
@@ -38,6 +38,7 @@ class ParallelLMHead(nn.Layer):
embedding_dim: int, embedding_dim: int,
prefix: str = "", prefix: str = "",
with_bias: bool = False, with_bias: bool = False,
dtype: str = None,
) -> None: ) -> None:
""" """
Parallelized LMhead. Parallelized LMhead.
@@ -50,6 +51,7 @@ class ParallelLMHead(nn.Layer):
             embedding_dim (int): size of hidden state.
             prefix (str): The name of current layer. Defaults to "".
             with_bias (bool): whether to have bias. Default: False.
+            dtype (str): The dtype of weight. Default: None.
         """
         super(ParallelLMHead, self).__init__()
         self.weight_key: str = prefix + ".weight"
@@ -62,9 +64,9 @@ class ParallelLMHead(nn.Layer):
         ColumnParallelLinear = fleet.meta_parallel.ColumnParallelLinear
         RowParallelLinear = fleet.meta_parallel.RowParallelLinear
 
+        self.dtype = "float32" if fd_config.model_config.lm_head_fp32 else dtype
         self.tie_word_embeddings: bool = fd_config.model_config.tie_word_embeddings
-
-        if self.use_ep:
-            self.weight = self.create_parameter(
-                shape=[embedding_dim, num_embeddings],
+        with temporary_dtype(self.dtype):
+            if self.use_ep:
+                self.weight = self.create_parameter(
+                    shape=[embedding_dim, num_embeddings],
@@ -103,20 +105,20 @@ class ParallelLMHead(nn.Layer):
""" """
if self.use_ep: if self.use_ep:
self.weight.set_value(get_tensor(state_dict.pop(self.weight_key)).astype(paddle.get_default_dtype())) self.weight.set_value(get_tensor(state_dict.pop(self.weight_key)).astype(self.weight.dtype))
else: else:
if self.tie_word_embeddings: if self.tie_word_embeddings:
self.linear.weight.set_value( self.linear.weight.set_value(
get_tensor(state_dict.pop(self.weight_key)).astype(paddle.get_default_dtype()).transpose([1, 0]) get_tensor(state_dict.pop(self.weight_key)).astype(self.linear.weight.dtype).transpose([1, 0])
) )
else: else:
weight_tensor = get_tensor(state_dict.pop(self.weight_key)).astype(paddle.get_default_dtype()) weight_tensor = get_tensor(state_dict.pop(self.weight_key)).astype(self.linear.weight.dtype)
if self.linear.weight.shape != weight_tensor.shape: if self.linear.weight.shape != weight_tensor.shape:
weight_tensor = weight_tensor.transpose([1, 0]) weight_tensor = weight_tensor.transpose([1, 0])
self.linear.weight.set_value(weight_tensor) self.linear.weight.set_value(weight_tensor)
if self.bias_key is not None: if self.bias_key is not None:
bias = get_tensor(state_dict.pop(self.bias_key)).astype(paddle.get_default_dtype()) bias = get_tensor(state_dict.pop(self.bias_key)).astype(self.linear.bias.dtype)
self.linear.bias.set_value(bias) self.linear.bias.set_value(bias)
def forward(self, input: paddle.Tensor) -> paddle.Tensor: def forward(self, input: paddle.Tensor) -> paddle.Tensor:
@@ -131,7 +133,7 @@ class ParallelLMHead(nn.Layer):
""" """
logits = input logits = input
if self.use_ep: if self.use_ep:
logits = paddle.matmul(logits, self.weight) logits = paddle.matmul(logits.astype(self.weight.dtype), self.weight)
else: else:
logits = self.linear(logits) logits = self.linear(logits.astype(self.linear.weight.dtype))
return logits return logits
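With lm_head_fp32 enabled, the head weight is created in float32 while the surrounding model typically runs in bfloat16, so forward() now casts the incoming hidden states to the weight dtype before the matmul. A standalone sketch of that dtype interplay (shapes and names are made up):

```python
import paddle

hidden = paddle.randn([4, 1024]).astype("bfloat16")      # decoder output in bf16
lm_head_w = paddle.randn([1024, 8192], dtype="float32")  # lm_head weight kept in fp32

# Casting the activations to the weight dtype keeps the matmul in a single
# dtype (float32), so the logits come out in float32 as well.
logits = paddle.matmul(hidden.astype(lm_head_w.dtype), lm_head_w)
print(logits.dtype)  # paddle.float32
```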

View File

@@ -15,6 +15,7 @@
""" """
import functools import functools
from contextlib import contextmanager
from typing import Tuple, Union from typing import Tuple, Union
import numpy as np import numpy as np
@@ -377,3 +378,15 @@ def create_empty_tensor(shape: Tuple[int, ...], dtype: Union[paddle.dtype, str])
         paddle.Tensor: An empty tensor with the specified shape and data type.
     """
     return paddle.empty(list(shape), dtype=dtype)
+
+
+@contextmanager
+def temporary_dtype(dtype: str):
+    """Temporarily set Paddle default dtype"""
+    orig_dtype = paddle.get_default_dtype()
+    try:
+        if dtype is not None and dtype == "float32":
+            paddle.set_default_dtype(dtype)
+        yield
+    finally:
+        paddle.set_default_dtype(orig_dtype)
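A quick usage sketch of the new helper. The import path is assumed from the diff layout, and bfloat16 as the ambient default mirrors how FastDeploy models are usually run (this assumes a Paddle build that accepts bfloat16 as a default dtype):

```python
import paddle

from fastdeploy.model_executor.layers.utils import temporary_dtype  # assumed module path

paddle.set_default_dtype("bfloat16")      # ambient model dtype

with temporary_dtype("float32"):          # only "float32" triggers the switch
    w = paddle.empty([4, 4], dtype=paddle.get_default_dtype())

print(w.dtype)                     # paddle.float32
print(paddle.get_default_dtype())  # back to bfloat16 once the context exits
```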

View File

@@ -613,7 +613,7 @@ class DeepseekV3ForCausalLM(ModelForCasualLM):
     def compute_logits(self, hidden_states: paddle.Tensor):
         """ """
         logits = self.lm_head(hidden_states)
-        logits = paddle.cast(logits, paddle.float32)
+        logits = logits.astype(paddle.float32)
         logits[:, self.ori_vocab_size :] = -float("inf")
 
         return logits
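The paddle.cast → Tensor.astype change here (and repeated in the model files below) is stylistic; the two spellings produce the same cast:

```python
import paddle

x = paddle.ones([2], dtype="float16")
assert paddle.cast(x, paddle.float32).dtype == x.astype(paddle.float32).dtype == paddle.float32
```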

View File

@@ -412,13 +412,13 @@ class Ernie4_5_MoeForCausalLM(ModelForCasualLM):
""" """
self.ernie.load_state_dict(state_dict) self.ernie.load_state_dict(state_dict)
if self.tie_word_embeddings: if self.tie_word_embeddings:
self.lm_head.linear.weight.set_value(self.ernie.embed_tokens.embeddings.weight.transpose([1, 0])) self.lm_head.load_state_dict({self.lm_head.weight_key: self.ernie.embed_tokens.embeddings.weight})
else: else:
self.lm_head.load_state_dict(state_dict) self.lm_head.load_state_dict(state_dict)
def compute_logits(self, hidden_states: paddle.Tensor): def compute_logits(self, hidden_states: paddle.Tensor):
logits = self.lm_head(hidden_states) logits = self.lm_head(hidden_states)
logits = paddle.cast(logits, paddle.float32) logits = logits.astype(paddle.float32)
logits[:, self.ori_vocab_size :] = -float("inf") logits[:, self.ori_vocab_size :] = -float("inf")
return logits return logits
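The tied-embedding branch now routes the shared embedding matrix through ParallelLMHead.load_state_dict (keyed by weight_key) instead of assigning a transposed tensor directly, so the tied weight also picks up the dtype handling added for lm_head_fp32. A toy illustration of the shape/dtype adjustment performed inside load_state_dict (dimensions are made up):

```python
import paddle

hidden, vocab = 8, 32
embed_w = paddle.randn([vocab, hidden])              # embedding table layout
linear_w = paddle.empty([hidden, vocab], "float32")  # lm_head linear layout, fp32 head

weight_tensor = embed_w.astype(linear_w.dtype)       # cast to the head's dtype
if linear_w.shape != weight_tensor.shape:            # transpose when layouts differ
    weight_tensor = weight_tensor.transpose([1, 0])
assert weight_tensor.shape == linear_w.shape
```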

View File

@@ -363,7 +363,7 @@ class Ernie4_5_MTPForCausalLM(ModelForCasualLM):
         compute logits
         """
         logits = self.lm_head(hidden_states)
-        logits = paddle.cast(logits, paddle.float32)
+        logits = logits.astype(paddle.float32)
         logits[:, self.ori_vocab_size :] = -float("inf")
 
         return logits

View File

@@ -570,13 +570,13 @@ class Ernie4_5_VLMoeForConditionalGeneration(ModelForCasualLM):
         self.vision_model.load_state_dict(state_dict)
         self.resampler_model.load_state_dict(state_dict)
         if self.tie_word_embeddings:
-            self.lm_head.linear.weight.set_value(self.ernie.embed_tokens.embeddings.weight.transpose([1, 0]))
+            self.lm_head.load_state_dict({self.lm_head.weight_key: self.ernie.embed_tokens.embeddings.weight})
         else:
             self.lm_head.load_state_dict(state_dict)
 
     def compute_logits(self, hidden_states: paddle.Tensor):
         logits = self.lm_head(hidden_states)
-        logits = paddle.cast(logits, paddle.float32)
+        logits = logits.astype(paddle.float32)
         logits[:, self.ori_vocab_size :] = -float("inf")
 
         return logits

View File

@@ -326,7 +326,7 @@ class Qwen2ForCausalLM(ModelForCasualLM):
     def compute_logits(self, hidden_states: paddle.Tensor):
         """ """
         logits = self.lm_head(hidden_states)
-        logits = paddle.cast(logits, paddle.float32)
+        logits = logits.astype(paddle.float32)
         logits[:, self.ori_vocab_size :] = -float("inf")
 
         return logits

View File

@@ -257,14 +257,14 @@ class Qwen3ForCausalLM(ModelForCasualLM):
""" """
self.model.load_state_dict(state_dict) self.model.load_state_dict(state_dict)
if self.tie_word_embeddings: if self.tie_word_embeddings:
self.lm_head.linear.weight.set_value(self.model.embed_tokens.embeddings.weight.transpose([1, 0])) self.lm_head.load_state_dict({self.lm_head.weight_key: self.ernie.embed_tokens.embeddings.weight})
else: else:
self.lm_head.load_state_dict(state_dict) self.lm_head.load_state_dict(state_dict)
def compute_logits(self, hidden_states: paddle.Tensor): def compute_logits(self, hidden_states: paddle.Tensor):
""" """ """ """
logits = self.lm_head(hidden_states) logits = self.lm_head(hidden_states)
logits = paddle.cast(logits, paddle.float32) logits = logits.astype(paddle.float32)
logits[:, self.ori_vocab_size :] = -float("inf") logits[:, self.ori_vocab_size :] = -float("inf")
return logits return logits

View File

@@ -298,7 +298,7 @@ class Qwen3MoeForCausalLM(ModelForCasualLM):
     def compute_logits(self, hidden_states: paddle.Tensor):
         """ """
         logits = self.lm_head(hidden_states)
-        logits = paddle.cast(logits, paddle.float32)
+        logits = logits.astype(paddle.float32)
         logits[:, self.ori_vocab_size :] = -float("inf")
 
         return logits

View File

@@ -587,6 +587,11 @@ def parse_args():
action="store_true", action="store_true",
help="Enable output of token-level log probabilities.", help="Enable output of token-level log probabilities.",
) )
parser.add_argument(
"--lm_head_fp32",
action="store_true",
help="The data type of lm_head",
)
args = parser.parse_args() args = parser.parse_args()
return args return args
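Because the worker flag is registered with action="store_true", the bare --lm_head_fp32 token that LLMEngine appends to the worker command (see the LLMEngine hunk earlier) is enough to switch it on; a minimal argparse check of that behaviour:

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--lm_head_fp32", action="store_true",
                    help="Specify the dtype of lm_head weight as float32.")

print(parser.parse_args([]).lm_head_fp32)                  # False when the flag is absent
print(parser.parse_args(["--lm_head_fp32"]).lm_head_fp32)  # True when it is passed
```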