mirror of https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-16 05:30:58 +08:00
[Precision] Change lm_head layer running in float32 (#3596)
* support lm_head fp32 bf16 fp16
* delete print
* code check
* check
* check
* code check
* check
* check
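For context on the "[Precision]" motivation (an illustration only, not code from this commit): half-precision formats cannot keep small differences between nearby logits, which is why running the lm_head projection and the logits in float32 preserves distinctions that float16/bfloat16 would round away.

```python
import numpy as np

# Two logits that are distinct in float32 collide in float16, because the
# float16 spacing near 10.0 is about 0.0078 (bfloat16 is coarser still).
a32, b32 = np.float32(10.000), np.float32(10.002)
a16, b16 = np.float16(a32), np.float16(b32)
print(a32 == b32)  # False: float32 keeps the gap
print(a16 == b16)  # True: float16 rounds both to 10.0
```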
@@ -108,6 +108,7 @@ class ModelConfig:
        self.enable_mm = False
        self.enable_redundant_experts = False
        self.redundant_experts_num = 0
+        self.lm_head_fp32: bool = False

        for key, value in args.items():
            if hasattr(self, key):
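The loop at the end of this hunk is the usual copy-known-keys pattern, so an `lm_head_fp32` entry in the incoming args dict overrides the new default. A minimal sketch of that pattern; the `setattr` body is an assumption, since it is not visible in the hunk:

```python
class TinyConfig:
    def __init__(self, args: dict):
        self.lm_head_fp32: bool = False    # new default added by this commit
        for key, value in args.items():
            if hasattr(self, key):          # only known attributes are overridden
                setattr(self, key, value)   # assumed body; not shown in the hunk

cfg = TinyConfig({"lm_head_fp32": True, "unknown_key": 1})
print(cfg.lm_head_fp32)             # True
print(hasattr(cfg, "unknown_key"))  # False: unrecognized keys are ignored
```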
@@ -15,10 +15,10 @@
"""

import json
+import os
from dataclasses import asdict, dataclass
from dataclasses import fields as dataclass_fields
from typing import Any, Dict, List, Optional
-import os

from fastdeploy.engine.config import (
    CacheConfig,
@@ -315,6 +315,11 @@ class EngineArgs:
    Must be explicitly enabled via the `--enable-logprob` startup parameter to output logprob values.
    """

+    lm_head_fp32: bool = None
+    """
+    Flag to specify the data type of lm_head as FP32.
+    """
+
    def __post_init__(self):
        """
        Post-initialization processing to set default tokenizer if not provided.
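The field is annotated `bool` but defaults to `None`. Because the dtype selection added later in `ParallelLMHead` only truth-tests the flag, `None` behaves the same as `False`. A minimal sketch (the helper name is illustrative):

```python
def pick_lm_head_dtype(lm_head_fp32, default_dtype="bfloat16"):
    # Mirrors the selection added in ParallelLMHead below (sketch only):
    # None and False are both falsy, so an unset flag falls back to the
    # model's default dtype.
    return "float32" if lm_head_fp32 else default_dtype

print(pick_lm_head_dtype(None))   # bfloat16
print(pick_lm_head_dtype(False))  # bfloat16
print(pick_lm_head_dtype(True))   # float32
```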
@@ -466,6 +471,12 @@ class EngineArgs:
            default=EngineArgs.enable_logprob,
            help="Enable output of token-level log probabilities.",
        )
+        model_group.add_argument(
+            "--lm_head-fp32",
+            action="store_true",
+            default=EngineArgs.lm_head_fp32,
+            help="Specify the dtype of lm_head weight as float32.",
+        )

        # Parallel processing parameters group
        parallel_group = parser.add_argument_group("Parallel Configuration")
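Note the mixed spelling `--lm_head-fp32`: argparse builds the destination attribute by replacing the remaining hyphen with an underscore, so the value still lands on `args.lm_head_fp32`. A quick check:

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--lm_head-fp32", action="store_true")
args = parser.parse_args(["--lm_head-fp32"])
print(args.lm_head_fp32)  # True: '-' in the option name becomes '_' in the dest
```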
@@ -769,6 +780,7 @@ class EngineArgs:
            quantization=self.quantization,
            dynamic_load_weight=self.dynamic_load_weight,
            load_strategy=self.load_strategy,
+            lm_head_fp32=self.lm_head_fp32,
        )

    def create_cache_config(self, model_cfg) -> CacheConfig:
@@ -855,7 +867,7 @@ class EngineArgs:
        if self.enable_chunked_prefill:
            self.max_num_batched_tokens = 2048
        else:
-            if not int(os.getenv('ENABLE_V1_KVCACHE_SCHEDULER', '0')):
+            if not int(os.getenv("ENABLE_V1_KVCACHE_SCHEDULER", "0")):
                self.max_num_batched_tokens = self.max_model_len
            else:
                self.max_num_batched_tokens = 8192
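Only the quote style changes here; the surrounding logic picks the default `max_num_batched_tokens` from chunked prefill and the `ENABLE_V1_KVCACHE_SCHEDULER` environment variable. A standalone sketch of that selection:

```python
import os

def default_max_num_batched_tokens(enable_chunked_prefill: bool, max_model_len: int) -> int:
    # Mirrors the branch in the hunk above (sketch only).
    if enable_chunked_prefill:
        return 2048
    if not int(os.getenv("ENABLE_V1_KVCACHE_SCHEDULER", "0")):
        return max_model_len
    return 8192

print(default_max_num_batched_tokens(True, 32768))   # 2048
print(default_max_num_batched_tokens(False, 32768))  # 32768 when the V1 scheduler is off
```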
@@ -51,6 +51,7 @@ class ModelConfig:
        load_strategy: str = "ipc_snapshot",
        quantization: str = None,
        download_dir: Optional[str] = None,
+        lm_head_fp32: bool = False,
    ):
        """
        Initialize the ModelConfig class.
@@ -65,6 +66,7 @@ class ModelConfig:
        self.dynamic_load_weight = dynamic_load_weight
        self.load_strategy = load_strategy
        self.quantization = quantization
+        self.lm_head_fp32 = lm_head_fp32

        config_file = os.path.join(model_name_or_path, config_json_file)
        if os.path.isfile(model_name_or_path):
@@ -804,7 +806,7 @@ class Config:
        if self.cache_config.enable_chunked_prefill:
            self.max_num_batched_tokens = 2048
        else:
-            if not int(os.getenv('ENABLE_V1_KVCACHE_SCHEDULER', '0')):
+            if not int(os.getenv("ENABLE_V1_KVCACHE_SCHEDULER", "0")):
                self.max_num_batched_tokens = self.max_model_len
            else:
                self.max_num_batched_tokens = 8192
@@ -855,7 +857,7 @@ class Config:
        )

        if not self.cache_config.enable_chunked_prefill:
-            if not int(os.getenv('ENABLE_V1_KVCACHE_SCHEDULER', '0')):
+            if not int(os.getenv("ENABLE_V1_KVCACHE_SCHEDULER", "0")):
                assert self.max_num_batched_tokens >= self.max_model_len, (
                    f"max_num_batched_tokens: {self.max_num_batched_tokens} "
                    f"should be larger than or equal to max_model_len: {self.max_model_len}"
@@ -1099,6 +1099,7 @@ class LLMEngine:
            "enable_custom_all_reduce": self.cfg.parallel_config.enable_custom_all_reduce,
            "enable_logprob": self.cfg.enable_logprob,
            "enable_mm": self.cfg.enable_mm,
+            "lm_head_fp32": self.cfg.model_config.lm_head_fp32,
        }
        for worker_flag, value in worker_append_flag.items():
            if value:
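The loop body that consumes `worker_append_flag` is not part of this hunk; a plausible sketch of the pattern is below, where each truthy flag is forwarded to the worker process as a boolean switch (the exact argument spelling is an assumption):

```python
worker_append_flag = {
    "enable_logprob": True,
    "enable_mm": False,
    "lm_head_fp32": True,
}

arguments = ""
for worker_flag, value in worker_append_flag.items():
    if value:
        arguments += f" --{worker_flag}"  # assumed body; only truthy flags are forwarded

print(arguments)  # " --enable_logprob --lm_head_fp32"
```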
@@ -23,7 +23,7 @@ from paddle.distributed import fleet

from fastdeploy.config import FDConfig

-from .utils import get_tensor
+from .utils import get_tensor, temporary_dtype


class ParallelLMHead(nn.Layer):
@@ -38,6 +38,7 @@ class ParallelLMHead(nn.Layer):
        embedding_dim: int,
        prefix: str = "",
        with_bias: bool = False,
+        dtype: str = None,
    ) -> None:
        """
        Parallelized LMhead.
@@ -50,6 +51,7 @@ class ParallelLMHead(nn.Layer):
            embedding_dim (int): size of hidden state.
            prefix (str): The name of current layer. Defaults to "".
            with_bias (bool): whether to have bias. Default: False.
+            dtype (str): The dtype of weight. Defalut: None.
        """
        super(ParallelLMHead, self).__init__()
        self.weight_key: str = prefix + ".weight"
@@ -62,9 +64,9 @@ class ParallelLMHead(nn.Layer):

        ColumnParallelLinear = fleet.meta_parallel.ColumnParallelLinear
        RowParallelLinear = fleet.meta_parallel.RowParallelLinear

+        self.dtype = "float32" if fd_config.model_config.lm_head_fp32 else dtype
        self.tie_word_embeddings: bool = fd_config.model_config.tie_word_embeddings

+        with temporary_dtype(self.dtype):
            if self.use_ep:
                self.weight = self.create_parameter(
                    shape=[embedding_dim, num_embeddings],
@@ -103,20 +105,20 @@ class ParallelLMHead(nn.Layer):
        """

        if self.use_ep:
-            self.weight.set_value(get_tensor(state_dict.pop(self.weight_key)).astype(paddle.get_default_dtype()))
+            self.weight.set_value(get_tensor(state_dict.pop(self.weight_key)).astype(self.weight.dtype))
        else:
            if self.tie_word_embeddings:
                self.linear.weight.set_value(
-                    get_tensor(state_dict.pop(self.weight_key)).astype(paddle.get_default_dtype()).transpose([1, 0])
+                    get_tensor(state_dict.pop(self.weight_key)).astype(self.linear.weight.dtype).transpose([1, 0])
                )
            else:
-                weight_tensor = get_tensor(state_dict.pop(self.weight_key)).astype(paddle.get_default_dtype())
+                weight_tensor = get_tensor(state_dict.pop(self.weight_key)).astype(self.linear.weight.dtype)
                if self.linear.weight.shape != weight_tensor.shape:
                    weight_tensor = weight_tensor.transpose([1, 0])
                self.linear.weight.set_value(weight_tensor)

        if self.bias_key is not None:
-            bias = get_tensor(state_dict.pop(self.bias_key)).astype(paddle.get_default_dtype())
+            bias = get_tensor(state_dict.pop(self.bias_key)).astype(self.linear.bias.dtype)
            self.linear.bias.set_value(bias)

    def forward(self, input: paddle.Tensor) -> paddle.Tensor:
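The change here is which dtype the checkpoint tensors are cast to: the dtype of the parameter being filled rather than the process-wide default. A numpy stand-in (not the Paddle code itself) showing why that matters when lm_head stays in float32 while the default dtype is half precision:

```python
import numpy as np

default_dtype = np.float16          # stand-in for paddle.get_default_dtype()
lm_head_param_dtype = np.float32    # the lm_head weight created under temporary_dtype("float32")

checkpoint_tensor = np.random.rand(4, 8).astype(np.float32)
old_style = checkpoint_tensor.astype(default_dtype)        # silently downcasts before set_value
new_style = checkpoint_tensor.astype(lm_head_param_dtype)  # matches the parameter's dtype

print(old_style.dtype, new_style.dtype)  # float16 float32
```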
@@ -131,7 +133,7 @@ class ParallelLMHead(nn.Layer):
        """
        logits = input
        if self.use_ep:
-            logits = paddle.matmul(logits, self.weight)
+            logits = paddle.matmul(logits.astype(self.weight.dtype), self.weight)
        else:
-            logits = self.linear(logits)
+            logits = self.linear(logits.astype(self.linear.weight.dtype))
        return logits
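The forward pass now casts the activations to the weight's dtype before the projection, so the logits come out in float32 when `lm_head_fp32` is set. A numpy stand-in of that dtype flow:

```python
import numpy as np

hidden = np.random.rand(2, 8).astype(np.float16)   # activations in the model's compute dtype
weight = np.random.rand(8, 32).astype(np.float32)  # lm_head weight kept in float32

# Mirrors the forward() change: cast the activations to the weight's dtype first,
# so the projection and the resulting logits run in float32.
logits = np.matmul(hidden.astype(weight.dtype), weight)
print(logits.dtype)  # float32
```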
@@ -15,6 +15,7 @@
"""

import functools
+from contextlib import contextmanager
from typing import Tuple, Union

import numpy as np
@@ -377,3 +378,15 @@ def create_empty_tensor(shape: Tuple[int, ...], dtype: Union[paddle.dtype, str])
        paddle.Tensor: An empty tensor with the specified shape and data type.
    """
    return paddle.empty(list(shape), dtype=dtype)
+
+
+@contextmanager
+def temporary_dtype(dtype: str):
+    """Temporarily set Paddle default dtype"""
+    orig_dtype = paddle.get_default_dtype()
+    try:
+        if dtype is not None and dtype == "float32":
+            paddle.set_default_dtype(dtype)
+        yield
+    finally:
+        paddle.set_default_dtype(orig_dtype)
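A usage sketch of the new context manager (the helper is copied from the hunk above; requires Paddle): the default dtype is switched to float32 only inside the block and only when "float32" is requested, then restored, so parameters created inside come out in float32 while the rest of the model keeps its half-precision default.

```python
import paddle
from contextlib import contextmanager

@contextmanager
def temporary_dtype(dtype: str):
    """Same logic as the helper added above."""
    orig_dtype = paddle.get_default_dtype()
    try:
        if dtype is not None and dtype == "float32":
            paddle.set_default_dtype(dtype)
        yield
    finally:
        paddle.set_default_dtype(orig_dtype)

paddle.set_default_dtype("float16")    # the model's compute dtype
with temporary_dtype("float32"):
    print(paddle.get_default_dtype())  # float32 inside the block
print(paddle.get_default_dtype())      # float16 again afterwards
```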
@@ -613,7 +613,7 @@ class DeepseekV3ForCausalLM(ModelForCasualLM):
    def compute_logits(self, hidden_states: paddle.Tensor):
        """ """
        logits = self.lm_head(hidden_states)
-        logits = paddle.cast(logits, paddle.float32)
+        logits = logits.astype(paddle.float32)
        logits[:, self.ori_vocab_size :] = -float("inf")
        return logits
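`paddle.cast(logits, ...)` and `logits.astype(...)` produce the same result; the change in these compute_logits methods is stylistic. A quick check (requires Paddle):

```python
import paddle

x = paddle.ones([2, 3], dtype="float16")
print(paddle.cast(x, "float32").dtype)  # paddle.float32
print(x.astype("float32").dtype)        # paddle.float32
```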
@@ -412,13 +412,13 @@ class Ernie4_5_MoeForCausalLM(ModelForCasualLM):
        """
        self.ernie.load_state_dict(state_dict)
        if self.tie_word_embeddings:
-            self.lm_head.linear.weight.set_value(self.ernie.embed_tokens.embeddings.weight.transpose([1, 0]))
+            self.lm_head.load_state_dict({self.lm_head.weight_key: self.ernie.embed_tokens.embeddings.weight})
        else:
            self.lm_head.load_state_dict(state_dict)

    def compute_logits(self, hidden_states: paddle.Tensor):
        logits = self.lm_head(hidden_states)
-        logits = paddle.cast(logits, paddle.float32)
+        logits = logits.astype(paddle.float32)
        logits[:, self.ori_vocab_size :] = -float("inf")

        return logits
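With tied word embeddings, the embedding weight is now routed through `lm_head.load_state_dict` as a one-entry dict keyed by `weight_key`, instead of being written into `linear.weight` directly, so the same dtype cast and transpose applied in the lm_head loader also covers the tied weight. A numpy stand-in of that path:

```python
import numpy as np

embed_weight = np.random.rand(1000, 512).astype(np.float16)  # [vocab, hidden] embedding table
lm_head_weight_dtype = np.float32                             # lm_head kept in float32
state_dict = {"lm_head.weight": embed_weight}                 # one-entry dict, as in the hunk

# Inside load_state_dict (tie_word_embeddings branch): cast to the parameter's
# dtype, then transpose [vocab, hidden] -> [hidden, vocab] for the linear layer.
weight = state_dict.pop("lm_head.weight").astype(lm_head_weight_dtype).transpose([1, 0])
print(weight.dtype, weight.shape)  # float32 (512, 1000)
```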
@@ -363,7 +363,7 @@ class Ernie4_5_MTPForCausalLM(ModelForCasualLM):
        compute logits
        """
        logits = self.lm_head(hidden_states)
-        logits = paddle.cast(logits, paddle.float32)
+        logits = logits.astype(paddle.float32)
        logits[:, self.ori_vocab_size :] = -float("inf")

        return logits
@@ -570,13 +570,13 @@ class Ernie4_5_VLMoeForConditionalGeneration(ModelForCasualLM):
        self.vision_model.load_state_dict(state_dict)
        self.resampler_model.load_state_dict(state_dict)
        if self.tie_word_embeddings:
-            self.lm_head.linear.weight.set_value(self.ernie.embed_tokens.embeddings.weight.transpose([1, 0]))
+            self.lm_head.load_state_dict({self.lm_head.weight_key: self.ernie.embed_tokens.embeddings.weight})
        else:
            self.lm_head.load_state_dict(state_dict)

    def compute_logits(self, hidden_states: paddle.Tensor):
        logits = self.lm_head(hidden_states)
-        logits = paddle.cast(logits, paddle.float32)
+        logits = logits.astype(paddle.float32)
        logits[:, self.ori_vocab_size :] = -float("inf")

        return logits
@@ -326,7 +326,7 @@ class Qwen2ForCausalLM(ModelForCasualLM):
    def compute_logits(self, hidden_states: paddle.Tensor):
        """ """
        logits = self.lm_head(hidden_states)
-        logits = paddle.cast(logits, paddle.float32)
+        logits = logits.astype(paddle.float32)
        logits[:, self.ori_vocab_size :] = -float("inf")

        return logits
@@ -257,14 +257,14 @@ class Qwen3ForCausalLM(ModelForCasualLM):
        """
        self.model.load_state_dict(state_dict)
        if self.tie_word_embeddings:
-            self.lm_head.linear.weight.set_value(self.model.embed_tokens.embeddings.weight.transpose([1, 0]))
+            self.lm_head.load_state_dict({self.lm_head.weight_key: self.ernie.embed_tokens.embeddings.weight})
        else:
            self.lm_head.load_state_dict(state_dict)

    def compute_logits(self, hidden_states: paddle.Tensor):
        """ """
        logits = self.lm_head(hidden_states)
-        logits = paddle.cast(logits, paddle.float32)
+        logits = logits.astype(paddle.float32)
        logits[:, self.ori_vocab_size :] = -float("inf")

        return logits
@@ -298,7 +298,7 @@ class Qwen3MoeForCausalLM(ModelForCasualLM):
    def compute_logits(self, hidden_states: paddle.Tensor):
        """ """
        logits = self.lm_head(hidden_states)
-        logits = paddle.cast(logits, paddle.float32)
+        logits = logits.astype(paddle.float32)
        logits[:, self.ori_vocab_size :] = -float("inf")

        return logits
@@ -587,6 +587,11 @@ def parse_args():
        action="store_true",
        help="Enable output of token-level log probabilities.",
    )
+    parser.add_argument(
+        "--lm_head_fp32",
+        action="store_true",
+        help="The data type of lm_head",
+    )

    args = parser.parse_args()
    return args