Mirror of https://github.com/PaddlePaddle/FastDeploy.git
[Precision] Change lm_head layer running in float32 (#3596)
* support lm_head fp32 bf16 fp16
* delete print
* code check
* check
* check
* code check
* check
* check
@@ -108,6 +108,7 @@ class ModelConfig:
         self.enable_mm = False
         self.enable_redundant_experts = False
         self.redundant_experts_num = 0
+        self.lm_head_fp32: bool = False

         for key, value in args.items():
             if hasattr(self, key):

@@ -15,10 +15,10 @@
 """

 import json
+import os
 from dataclasses import asdict, dataclass
 from dataclasses import fields as dataclass_fields
 from typing import Any, Dict, List, Optional
-import os

 from fastdeploy.engine.config import (
     CacheConfig,

@@ -315,6 +315,11 @@ class EngineArgs:
     Must be explicitly enabled via the `--enable-logprob` startup parameter to output logprob values.
     """

+    lm_head_fp32: bool = None
+    """
+    Flag to specify the data type of lm_head as FP32.
+    """
+
     def __post_init__(self):
         """
         Post-initialization processing to set default tokenizer if not provided.

@@ -466,6 +471,12 @@ class EngineArgs:
             default=EngineArgs.enable_logprob,
             help="Enable output of token-level log probabilities.",
         )
+        model_group.add_argument(
+            "--lm_head-fp32",
+            action="store_true",
+            default=EngineArgs.lm_head_fp32,
+            help="Specify the dtype of lm_head weight as float32.",
+        )

         # Parallel processing parameters group
         parallel_group = parser.add_argument_group("Parallel Configuration")

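Note that the engine flag added here mixes separators ("--lm_head-fp32") while the worker flag added later in this commit is "--lm_head_fp32"; both still end up on an attribute named lm_head_fp32 because argparse derives the destination by replacing dashes with underscores. A minimal standalone sketch of that behaviour (plain argparse, not the project's real parser):

import argparse

# Toy parser that mirrors the add_argument call above.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--lm_head-fp32",
    action="store_true",
    default=False,
    help="Specify the dtype of lm_head weight as float32.",
)
args = parser.parse_args(["--lm_head-fp32"])
print(args.lm_head_fp32)  # True: argparse turns "lm_head-fp32" into the attribute "lm_head_fp32"
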
@@ -769,6 +780,7 @@ class EngineArgs:
             quantization=self.quantization,
             dynamic_load_weight=self.dynamic_load_weight,
             load_strategy=self.load_strategy,
+            lm_head_fp32=self.lm_head_fp32,
         )

     def create_cache_config(self, model_cfg) -> CacheConfig:

@@ -855,7 +867,7 @@ class EngineArgs:
         if self.enable_chunked_prefill:
             self.max_num_batched_tokens = 2048
         else:
-            if not int(os.getenv('ENABLE_V1_KVCACHE_SCHEDULER', '0')):
+            if not int(os.getenv("ENABLE_V1_KVCACHE_SCHEDULER", "0")):
                 self.max_num_batched_tokens = self.max_model_len
             else:
                 self.max_num_batched_tokens = 8192

@@ -51,6 +51,7 @@ class ModelConfig:
         load_strategy: str = "ipc_snapshot",
         quantization: str = None,
         download_dir: Optional[str] = None,
+        lm_head_fp32: bool = False,
     ):
         """
         Initialize the ModelConfig class.

@@ -65,6 +66,7 @@ class ModelConfig:
         self.dynamic_load_weight = dynamic_load_weight
         self.load_strategy = load_strategy
         self.quantization = quantization
+        self.lm_head_fp32 = lm_head_fp32

         config_file = os.path.join(model_name_or_path, config_json_file)
         if os.path.isfile(model_name_or_path):

@@ -804,7 +806,7 @@ class Config:
         if self.cache_config.enable_chunked_prefill:
             self.max_num_batched_tokens = 2048
         else:
-            if not int(os.getenv('ENABLE_V1_KVCACHE_SCHEDULER', '0')):
+            if not int(os.getenv("ENABLE_V1_KVCACHE_SCHEDULER", "0")):
                 self.max_num_batched_tokens = self.max_model_len
             else:
                 self.max_num_batched_tokens = 8192

@@ -855,7 +857,7 @@ class Config:
             )

         if not self.cache_config.enable_chunked_prefill:
-            if not int(os.getenv('ENABLE_V1_KVCACHE_SCHEDULER', '0')):
+            if not int(os.getenv("ENABLE_V1_KVCACHE_SCHEDULER", "0")):
                 assert self.max_num_batched_tokens >= self.max_model_len, (
                     f"max_num_batched_tokens: {self.max_num_batched_tokens} "
                     f"should be larger than or equal to max_model_len: {self.max_model_len}"

@@ -1099,6 +1099,7 @@ class LLMEngine:
             "enable_custom_all_reduce": self.cfg.parallel_config.enable_custom_all_reduce,
             "enable_logprob": self.cfg.enable_logprob,
             "enable_mm": self.cfg.enable_mm,
+            "lm_head_fp32": self.cfg.model_config.lm_head_fp32,
         }
         for worker_flag, value in worker_append_flag.items():
             if value:

@@ -23,7 +23,7 @@ from paddle.distributed import fleet

 from fastdeploy.config import FDConfig

-from .utils import get_tensor
+from .utils import get_tensor, temporary_dtype


 class ParallelLMHead(nn.Layer):

@@ -38,6 +38,7 @@ class ParallelLMHead(nn.Layer):
         embedding_dim: int,
         prefix: str = "",
         with_bias: bool = False,
+        dtype: str = None,
     ) -> None:
         """
         Parallelized LMhead.

@@ -50,6 +51,7 @@ class ParallelLMHead(nn.Layer):
             embedding_dim (int): size of hidden state.
             prefix (str): The name of current layer. Defaults to "".
             with_bias (bool): whether to have bias. Default: False.
+            dtype (str): The dtype of weight. Defalut: None.
         """
         super(ParallelLMHead, self).__init__()
         self.weight_key: str = prefix + ".weight"

@@ -62,9 +64,9 @@ class ParallelLMHead(nn.Layer):

         ColumnParallelLinear = fleet.meta_parallel.ColumnParallelLinear
         RowParallelLinear = fleet.meta_parallel.RowParallelLinear
+        self.dtype = "float32" if fd_config.model_config.lm_head_fp32 else dtype
         self.tie_word_embeddings: bool = fd_config.model_config.tie_word_embeddings
-        if self.use_ep:
-            self.weight = self.create_parameter(
-                shape=[embedding_dim, num_embeddings],
+        with temporary_dtype(self.dtype):
+            if self.use_ep:
+                self.weight = self.create_parameter(
+                    shape=[embedding_dim, num_embeddings],

@@ -103,20 +105,20 @@ class ParallelLMHead(nn.Layer):
         """

         if self.use_ep:
-            self.weight.set_value(get_tensor(state_dict.pop(self.weight_key)).astype(paddle.get_default_dtype()))
+            self.weight.set_value(get_tensor(state_dict.pop(self.weight_key)).astype(self.weight.dtype))
         else:
             if self.tie_word_embeddings:
                 self.linear.weight.set_value(
-                    get_tensor(state_dict.pop(self.weight_key)).astype(paddle.get_default_dtype()).transpose([1, 0])
+                    get_tensor(state_dict.pop(self.weight_key)).astype(self.linear.weight.dtype).transpose([1, 0])
                 )
             else:
-                weight_tensor = get_tensor(state_dict.pop(self.weight_key)).astype(paddle.get_default_dtype())
+                weight_tensor = get_tensor(state_dict.pop(self.weight_key)).astype(self.linear.weight.dtype)
                 if self.linear.weight.shape != weight_tensor.shape:
                     weight_tensor = weight_tensor.transpose([1, 0])
                 self.linear.weight.set_value(weight_tensor)

         if self.bias_key is not None:
-            bias = get_tensor(state_dict.pop(self.bias_key)).astype(paddle.get_default_dtype())
+            bias = get_tensor(state_dict.pop(self.bias_key)).astype(self.linear.bias.dtype)
             self.linear.bias.set_value(bias)

     def forward(self, input: paddle.Tensor) -> paddle.Tensor:

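Casting checkpoint tensors to the parameter's own dtype instead of paddle.get_default_dtype() is what keeps an fp32 lm_head in fp32 when the rest of the model runs at a half-precision default. A minimal sketch of the difference, using plain Paddle tensors with toy shapes (not FastDeploy code):

import paddle

paddle.set_default_dtype("float16")                       # serving dtype for the rest of the model
param = paddle.create_parameter([4, 4], dtype="float32")  # lm_head weight created under temporary_dtype
ckpt = paddle.ones([4, 4], dtype="float16")               # checkpoint tensor arriving in half precision

# Old behaviour cast to the global default (float16 here); the new behaviour
# casts to the parameter's dtype, so the weight keeps its float32 precision.
param.set_value(ckpt.astype(param.dtype))
print(param.dtype)  # paddle.float32
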
@@ -131,7 +133,7 @@ class ParallelLMHead(nn.Layer):
         """
         logits = input
         if self.use_ep:
-            logits = paddle.matmul(logits, self.weight)
+            logits = paddle.matmul(logits.astype(self.weight.dtype), self.weight)
         else:
-            logits = self.linear(logits)
+            logits = self.linear(logits.astype(self.linear.weight.dtype))
         return logits

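Because the lm_head weight may now be fp32 while the hidden states arrive in the serving dtype, forward() casts the activations to the weight's dtype before the matmul (or the linear call), so both operands share one dtype. A toy illustration with made-up shapes, not the real layer:

import paddle

hidden = paddle.ones([2, 8], dtype="float16")    # decoder output in half precision
weight = paddle.ones([8, 32], dtype="float32")   # lm_head weight kept in fp32
logits = paddle.matmul(hidden.astype(weight.dtype), weight)  # cast first, then multiply
print(logits.dtype)  # paddle.float32
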
@@ -15,6 +15,7 @@
 """

 import functools
+from contextlib import contextmanager
 from typing import Tuple, Union

 import numpy as np

@@ -377,3 +378,15 @@ def create_empty_tensor(shape: Tuple[int, ...], dtype: Union[paddle.dtype, str])
         paddle.Tensor: An empty tensor with the specified shape and data type.
     """
     return paddle.empty(list(shape), dtype=dtype)
+
+
+@contextmanager
+def temporary_dtype(dtype: str):
+    """Temporarily set Paddle default dtype"""
+    orig_dtype = paddle.get_default_dtype()
+    try:
+        if dtype is not None and dtype == "float32":
+            paddle.set_default_dtype(dtype)
+        yield
+    finally:
+        paddle.set_default_dtype(orig_dtype)

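The context manager only flips the default when asked for "float32" and always restores the previous default on exit; ParallelLMHead (above) wraps its create_parameter calls in it, so only the lm_head weights pick up fp32 while the rest of the model keeps the serving dtype. A small sketch of the behaviour (the absolute import path is assumed from the relative ".utils" import shown earlier):

import paddle
from fastdeploy.model_executor.layers.utils import temporary_dtype  # assumed module path

paddle.set_default_dtype("float16")
with temporary_dtype("float32"):
    print(paddle.get_default_dtype())  # float32: parameters created here default to fp32
with temporary_dtype("bfloat16"):
    print(paddle.get_default_dtype())  # float16: anything other than "float32" is a no-op
print(paddle.get_default_dtype())      # float16: previous default restored
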
@@ -613,7 +613,7 @@ class DeepseekV3ForCausalLM(ModelForCasualLM):
     def compute_logits(self, hidden_states: paddle.Tensor):
         """ """
         logits = self.lm_head(hidden_states)
-        logits = paddle.cast(logits, paddle.float32)
+        logits = logits.astype(paddle.float32)
         logits[:, self.ori_vocab_size :] = -float("inf")
         return logits

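The repeated paddle.cast to Tensor.astype replacements in the compute_logits methods (here and in the Ernie and Qwen models below) are behaviour-preserving; both return a new tensor cast to float32. A two-line check:

import paddle

x = paddle.ones([3], dtype="float16")
assert paddle.cast(x, paddle.float32).dtype == x.astype(paddle.float32).dtype == paddle.float32
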
@@ -412,13 +412,13 @@ class Ernie4_5_MoeForCausalLM(ModelForCasualLM):
         """
         self.ernie.load_state_dict(state_dict)
         if self.tie_word_embeddings:
-            self.lm_head.linear.weight.set_value(self.ernie.embed_tokens.embeddings.weight.transpose([1, 0]))
+            self.lm_head.load_state_dict({self.lm_head.weight_key: self.ernie.embed_tokens.embeddings.weight})
         else:
             self.lm_head.load_state_dict(state_dict)

     def compute_logits(self, hidden_states: paddle.Tensor):
         logits = self.lm_head(hidden_states)
-        logits = paddle.cast(logits, paddle.float32)
+        logits = logits.astype(paddle.float32)
         logits[:, self.ori_vocab_size :] = -float("inf")

         return logits

@@ -363,7 +363,7 @@ class Ernie4_5_MTPForCausalLM(ModelForCasualLM):
         compute logits
         """
         logits = self.lm_head(hidden_states)
-        logits = paddle.cast(logits, paddle.float32)
+        logits = logits.astype(paddle.float32)
         logits[:, self.ori_vocab_size :] = -float("inf")

         return logits

@@ -570,13 +570,13 @@ class Ernie4_5_VLMoeForConditionalGeneration(ModelForCasualLM):
         self.vision_model.load_state_dict(state_dict)
         self.resampler_model.load_state_dict(state_dict)
         if self.tie_word_embeddings:
-            self.lm_head.linear.weight.set_value(self.ernie.embed_tokens.embeddings.weight.transpose([1, 0]))
+            self.lm_head.load_state_dict({self.lm_head.weight_key: self.ernie.embed_tokens.embeddings.weight})
         else:
             self.lm_head.load_state_dict(state_dict)

     def compute_logits(self, hidden_states: paddle.Tensor):
         logits = self.lm_head(hidden_states)
-        logits = paddle.cast(logits, paddle.float32)
+        logits = logits.astype(paddle.float32)
         logits[:, self.ori_vocab_size :] = -float("inf")

         return logits

@@ -326,7 +326,7 @@ class Qwen2ForCausalLM(ModelForCasualLM):
     def compute_logits(self, hidden_states: paddle.Tensor):
         """ """
         logits = self.lm_head(hidden_states)
-        logits = paddle.cast(logits, paddle.float32)
+        logits = logits.astype(paddle.float32)
         logits[:, self.ori_vocab_size :] = -float("inf")

         return logits

@@ -257,14 +257,14 @@ class Qwen3ForCausalLM(ModelForCasualLM):
         """
         self.model.load_state_dict(state_dict)
         if self.tie_word_embeddings:
-            self.lm_head.linear.weight.set_value(self.model.embed_tokens.embeddings.weight.transpose([1, 0]))
+            self.lm_head.load_state_dict({self.lm_head.weight_key: self.ernie.embed_tokens.embeddings.weight})
         else:
             self.lm_head.load_state_dict(state_dict)

     def compute_logits(self, hidden_states: paddle.Tensor):
         """ """
         logits = self.lm_head(hidden_states)
-        logits = paddle.cast(logits, paddle.float32)
+        logits = logits.astype(paddle.float32)
         logits[:, self.ori_vocab_size :] = -float("inf")

         return logits

@@ -298,7 +298,7 @@ class Qwen3MoeForCausalLM(ModelForCasualLM):
     def compute_logits(self, hidden_states: paddle.Tensor):
         """ """
         logits = self.lm_head(hidden_states)
-        logits = paddle.cast(logits, paddle.float32)
+        logits = logits.astype(paddle.float32)
         logits[:, self.ori_vocab_size :] = -float("inf")

         return logits

@@ -587,6 +587,11 @@ def parse_args():
         action="store_true",
         help="Enable output of token-level log probabilities.",
     )
+    parser.add_argument(
+        "--lm_head_fp32",
+        action="store_true",
+        help="The data type of lm_head",
+    )

     args = parser.parse_args()
     return args