diff --git a/fastdeploy/config.py b/fastdeploy/config.py
index 89efeee6f..faf093b33 100644
--- a/fastdeploy/config.py
+++ b/fastdeploy/config.py
@@ -108,6 +108,7 @@ class ModelConfig:
         self.enable_mm = False
         self.enable_redundant_experts = False
         self.redundant_experts_num = 0
+        self.lm_head_fp32: bool = False
 
         for key, value in args.items():
             if hasattr(self, key):
diff --git a/fastdeploy/engine/args_utils.py b/fastdeploy/engine/args_utils.py
index c31543ee1..8a47d815d 100644
--- a/fastdeploy/engine/args_utils.py
+++ b/fastdeploy/engine/args_utils.py
@@ -15,10 +15,10 @@
 """
 
 import json
+import os
 from dataclasses import asdict, dataclass
 from dataclasses import fields as dataclass_fields
 from typing import Any, Dict, List, Optional
-import os
 
 from fastdeploy.engine.config import (
     CacheConfig,
@@ -315,6 +315,11 @@ class EngineArgs:
     Must be explicitly enabled via the `--enable-logprob` startup parameter to output logprob values.
     """
+    lm_head_fp32: bool = False
+    """
+    Flag to load the lm_head weight as float32 (FP32).
+    """
+
     def __post_init__(self):
         """
         Post-initialization processing to set default tokenizer if not provided.
         """
@@ -466,6 +471,12 @@ class EngineArgs:
             default=EngineArgs.enable_logprob,
             help="Enable output of token-level log probabilities.",
         )
+        model_group.add_argument(
+            "--lm_head-fp32",
+            action="store_true",
+            default=EngineArgs.lm_head_fp32,
+            help="Specify the dtype of lm_head weight as float32.",
+        )
 
         # Parallel processing parameters group
         parallel_group = parser.add_argument_group("Parallel Configuration")
@@ -769,6 +780,7 @@ class EngineArgs:
             quantization=self.quantization,
             dynamic_load_weight=self.dynamic_load_weight,
             load_strategy=self.load_strategy,
+            lm_head_fp32=self.lm_head_fp32,
         )
 
     def create_cache_config(self, model_cfg) -> CacheConfig:
@@ -855,7 +867,7 @@ class EngineArgs:
         if self.enable_chunked_prefill:
             self.max_num_batched_tokens = 2048
         else:
-            if not int(os.getenv('ENABLE_V1_KVCACHE_SCHEDULER', '0')):
+            if not int(os.getenv("ENABLE_V1_KVCACHE_SCHEDULER", "0")):
                 self.max_num_batched_tokens = self.max_model_len
             else:
                 self.max_num_batched_tokens = 8192
diff --git a/fastdeploy/engine/config.py b/fastdeploy/engine/config.py
index 25dcc19c3..343dfa0eb 100644
--- a/fastdeploy/engine/config.py
+++ b/fastdeploy/engine/config.py
@@ -51,6 +51,7 @@ class ModelConfig:
         load_strategy: str = "ipc_snapshot",
         quantization: str = None,
         download_dir: Optional[str] = None,
+        lm_head_fp32: bool = False,
     ):
         """
         Initialize the ModelConfig class.
@@ -65,6 +66,7 @@ class ModelConfig:
         self.dynamic_load_weight = dynamic_load_weight
         self.load_strategy = load_strategy
         self.quantization = quantization
+        self.lm_head_fp32 = lm_head_fp32
 
         config_file = os.path.join(model_name_or_path, config_json_file)
         if os.path.isfile(model_name_or_path):
@@ -804,7 +806,7 @@ class Config:
         if self.cache_config.enable_chunked_prefill:
             self.max_num_batched_tokens = 2048
         else:
-            if not int(os.getenv('ENABLE_V1_KVCACHE_SCHEDULER', '0')):
+            if not int(os.getenv("ENABLE_V1_KVCACHE_SCHEDULER", "0")):
                 self.max_num_batched_tokens = self.max_model_len
             else:
                 self.max_num_batched_tokens = 8192
@@ -855,7 +857,7 @@ class Config:
             )
 
         if not self.cache_config.enable_chunked_prefill:
-            if not int(os.getenv('ENABLE_V1_KVCACHE_SCHEDULER', '0')):
+            if not int(os.getenv("ENABLE_V1_KVCACHE_SCHEDULER", "0")):
                 assert self.max_num_batched_tokens >= self.max_model_len, (
                     f"max_num_batched_tokens: {self.max_num_batched_tokens} "
                     f"should be larger than or equal to max_model_len: {self.max_model_len}"
diff --git a/fastdeploy/engine/engine.py b/fastdeploy/engine/engine.py
index 318869c8d..a55850ec0 100644
--- a/fastdeploy/engine/engine.py
+++ b/fastdeploy/engine/engine.py
@@ -1099,6 +1099,7 @@ class LLMEngine:
             "enable_custom_all_reduce": self.cfg.parallel_config.enable_custom_all_reduce,
             "enable_logprob": self.cfg.enable_logprob,
             "enable_mm": self.cfg.enable_mm,
+            "lm_head_fp32": self.cfg.model_config.lm_head_fp32,
         }
         for worker_flag, value in worker_append_flag.items():
             if value:
diff --git a/fastdeploy/model_executor/layers/lm_head.py b/fastdeploy/model_executor/layers/lm_head.py
index 4b8b96839..62c81efc7 100644
--- a/fastdeploy/model_executor/layers/lm_head.py
+++ b/fastdeploy/model_executor/layers/lm_head.py
@@ -23,7 +23,7 @@ from paddle.distributed import fleet
 
 from fastdeploy.config import FDConfig
 
-from .utils import get_tensor
+from .utils import get_tensor, temporary_dtype
 
 
 class ParallelLMHead(nn.Layer):
@@ -38,6 +38,7 @@ class ParallelLMHead(nn.Layer):
         embedding_dim: int,
         prefix: str = "",
         with_bias: bool = False,
+        dtype: str = None,
     ) -> None:
         """
         Parallelized LMhead.
@@ -50,6 +51,7 @@ class ParallelLMHead(nn.Layer):
             embedding_dim (int): size of hidden state.
             prefix (str): The name of current layer. Defaults to "".
             with_bias (bool): whether to have bias. Default: False.
+            dtype (str): The dtype of weight. Default: None.
""" super(ParallelLMHead, self).__init__() self.weight_key: str = prefix + ".weight" @@ -62,37 +64,37 @@ class ParallelLMHead(nn.Layer): ColumnParallelLinear = fleet.meta_parallel.ColumnParallelLinear RowParallelLinear = fleet.meta_parallel.RowParallelLinear - + self.dtype = "float32" if fd_config.model_config.lm_head_fp32 else dtype self.tie_word_embeddings: bool = fd_config.model_config.tie_word_embeddings - - if self.use_ep: - self.weight = self.create_parameter( - shape=[embedding_dim, num_embeddings], - dtype=paddle.get_default_dtype(), - is_bias=False, - ) - else: - if self.column_cut: - need_gather = True - self.linear = ColumnParallelLinear( - embedding_dim, - num_embeddings, - mp_group=fleet.get_hybrid_communicate_group().get_model_parallel_group(), - weight_attr=None, - has_bias=True if self.bias_key is not None else False, - gather_output=need_gather, - fuse_matmul_bias=False, # False diff更小 + with temporary_dtype(self.dtype): + if self.use_ep: + self.weight = self.create_parameter( + shape=[embedding_dim, num_embeddings], + dtype=paddle.get_default_dtype(), + is_bias=False, ) else: - self.linear = RowParallelLinear( - embedding_dim, - num_embeddings, - mp_group=fleet.get_hybrid_communicate_group().get_model_parallel_group(), - weight_attr=None, - has_bias=True if self.bias_key is not None else False, - input_is_parallel=False, - fuse_matmul_bias=False, # False diff更小 - ) + if self.column_cut: + need_gather = True + self.linear = ColumnParallelLinear( + embedding_dim, + num_embeddings, + mp_group=fleet.get_hybrid_communicate_group().get_model_parallel_group(), + weight_attr=None, + has_bias=True if self.bias_key is not None else False, + gather_output=need_gather, + fuse_matmul_bias=False, # False diff更小 + ) + else: + self.linear = RowParallelLinear( + embedding_dim, + num_embeddings, + mp_group=fleet.get_hybrid_communicate_group().get_model_parallel_group(), + weight_attr=None, + has_bias=True if self.bias_key is not None else False, + input_is_parallel=False, + fuse_matmul_bias=False, # False diff更小 + ) def load_state_dict(self, state_dict: Dict[str, paddle.Tensor | np.ndarray]): """ @@ -103,20 +105,20 @@ class ParallelLMHead(nn.Layer): """ if self.use_ep: - self.weight.set_value(get_tensor(state_dict.pop(self.weight_key)).astype(paddle.get_default_dtype())) + self.weight.set_value(get_tensor(state_dict.pop(self.weight_key)).astype(self.weight.dtype)) else: if self.tie_word_embeddings: self.linear.weight.set_value( - get_tensor(state_dict.pop(self.weight_key)).astype(paddle.get_default_dtype()).transpose([1, 0]) + get_tensor(state_dict.pop(self.weight_key)).astype(self.linear.weight.dtype).transpose([1, 0]) ) else: - weight_tensor = get_tensor(state_dict.pop(self.weight_key)).astype(paddle.get_default_dtype()) + weight_tensor = get_tensor(state_dict.pop(self.weight_key)).astype(self.linear.weight.dtype) if self.linear.weight.shape != weight_tensor.shape: weight_tensor = weight_tensor.transpose([1, 0]) self.linear.weight.set_value(weight_tensor) if self.bias_key is not None: - bias = get_tensor(state_dict.pop(self.bias_key)).astype(paddle.get_default_dtype()) + bias = get_tensor(state_dict.pop(self.bias_key)).astype(self.linear.bias.dtype) self.linear.bias.set_value(bias) def forward(self, input: paddle.Tensor) -> paddle.Tensor: @@ -131,7 +133,7 @@ class ParallelLMHead(nn.Layer): """ logits = input if self.use_ep: - logits = paddle.matmul(logits, self.weight) + logits = paddle.matmul(logits.astype(self.weight.dtype), self.weight) else: - logits = self.linear(logits) + logits = 
         return logits
diff --git a/fastdeploy/model_executor/layers/utils.py b/fastdeploy/model_executor/layers/utils.py
index ed7b4369b..2e565efb5 100644
--- a/fastdeploy/model_executor/layers/utils.py
+++ b/fastdeploy/model_executor/layers/utils.py
@@ -15,6 +15,7 @@
 """
 
 import functools
+from contextlib import contextmanager
 from typing import Tuple, Union
 
 import numpy as np
@@ -377,3 +378,15 @@ def create_empty_tensor(shape: Tuple[int, ...], dtype: Union[paddle.dtype, str])
         paddle.Tensor: An empty tensor with the specified shape and data type.
     """
     return paddle.empty(list(shape), dtype=dtype)
+
+
+@contextmanager
+def temporary_dtype(dtype: str):
+    """Temporarily set Paddle default dtype"""
+    orig_dtype = paddle.get_default_dtype()
+    try:
+        if dtype is not None and dtype == "float32":
+            paddle.set_default_dtype(dtype)
+        yield
+    finally:
+        paddle.set_default_dtype(orig_dtype)
diff --git a/fastdeploy/model_executor/models/deepseek_v3.py b/fastdeploy/model_executor/models/deepseek_v3.py
index 4d75b03b9..aa887735b 100644
--- a/fastdeploy/model_executor/models/deepseek_v3.py
+++ b/fastdeploy/model_executor/models/deepseek_v3.py
@@ -613,7 +613,7 @@ class DeepseekV3ForCausalLM(ModelForCasualLM):
     def compute_logits(self, hidden_states: paddle.Tensor):
         """ """
         logits = self.lm_head(hidden_states)
-        logits = paddle.cast(logits, paddle.float32)
+        logits = logits.astype(paddle.float32)
         logits[:, self.ori_vocab_size :] = -float("inf")
 
         return logits
diff --git a/fastdeploy/model_executor/models/ernie4_5_moe.py b/fastdeploy/model_executor/models/ernie4_5_moe.py
index 460170b7d..3d9695a18 100644
--- a/fastdeploy/model_executor/models/ernie4_5_moe.py
+++ b/fastdeploy/model_executor/models/ernie4_5_moe.py
@@ -412,13 +412,13 @@ class Ernie4_5_MoeForCausalLM(ModelForCasualLM):
         """
         self.ernie.load_state_dict(state_dict)
         if self.tie_word_embeddings:
-            self.lm_head.linear.weight.set_value(self.ernie.embed_tokens.embeddings.weight.transpose([1, 0]))
+            self.lm_head.load_state_dict({self.lm_head.weight_key: self.ernie.embed_tokens.embeddings.weight})
         else:
             self.lm_head.load_state_dict(state_dict)
 
     def compute_logits(self, hidden_states: paddle.Tensor):
         logits = self.lm_head(hidden_states)
-        logits = paddle.cast(logits, paddle.float32)
+        logits = logits.astype(paddle.float32)
         logits[:, self.ori_vocab_size :] = -float("inf")
 
         return logits
diff --git a/fastdeploy/model_executor/models/ernie4_5_mtp.py b/fastdeploy/model_executor/models/ernie4_5_mtp.py
index b52d8ed71..6e8e83603 100644
--- a/fastdeploy/model_executor/models/ernie4_5_mtp.py
+++ b/fastdeploy/model_executor/models/ernie4_5_mtp.py
@@ -363,7 +363,7 @@ class Ernie4_5_MTPForCausalLM(ModelForCasualLM):
         compute logits
         """
         logits = self.lm_head(hidden_states)
-        logits = paddle.cast(logits, paddle.float32)
+        logits = logits.astype(paddle.float32)
         logits[:, self.ori_vocab_size :] = -float("inf")
 
         return logits
diff --git a/fastdeploy/model_executor/models/ernie4_5_vl/ernie4_5_vl_moe.py b/fastdeploy/model_executor/models/ernie4_5_vl/ernie4_5_vl_moe.py
index 2dd562135..fe303ee2d 100644
--- a/fastdeploy/model_executor/models/ernie4_5_vl/ernie4_5_vl_moe.py
+++ b/fastdeploy/model_executor/models/ernie4_5_vl/ernie4_5_vl_moe.py
@@ -570,13 +570,13 @@ class Ernie4_5_VLMoeForConditionalGeneration(ModelForCasualLM):
         self.vision_model.load_state_dict(state_dict)
         self.resampler_model.load_state_dict(state_dict)
         if self.tie_word_embeddings:
-            self.lm_head.linear.weight.set_value(self.ernie.embed_tokens.embeddings.weight.transpose([1, 0]))
+            self.lm_head.load_state_dict({self.lm_head.weight_key: self.ernie.embed_tokens.embeddings.weight})
         else:
             self.lm_head.load_state_dict(state_dict)
 
     def compute_logits(self, hidden_states: paddle.Tensor):
         logits = self.lm_head(hidden_states)
-        logits = paddle.cast(logits, paddle.float32)
+        logits = logits.astype(paddle.float32)
         logits[:, self.ori_vocab_size :] = -float("inf")
 
         return logits
diff --git a/fastdeploy/model_executor/models/qwen2.py b/fastdeploy/model_executor/models/qwen2.py
index af2af00b1..acf74e8e1 100644
--- a/fastdeploy/model_executor/models/qwen2.py
+++ b/fastdeploy/model_executor/models/qwen2.py
@@ -326,7 +326,7 @@ class Qwen2ForCausalLM(ModelForCasualLM):
     def compute_logits(self, hidden_states: paddle.Tensor):
         """ """
         logits = self.lm_head(hidden_states)
-        logits = paddle.cast(logits, paddle.float32)
+        logits = logits.astype(paddle.float32)
         logits[:, self.ori_vocab_size :] = -float("inf")
 
         return logits
diff --git a/fastdeploy/model_executor/models/qwen3.py b/fastdeploy/model_executor/models/qwen3.py
index 4b106aea2..9ce571d24 100644
--- a/fastdeploy/model_executor/models/qwen3.py
+++ b/fastdeploy/model_executor/models/qwen3.py
@@ -257,14 +257,14 @@ class Qwen3ForCausalLM(ModelForCasualLM):
         """
         self.model.load_state_dict(state_dict)
         if self.tie_word_embeddings:
-            self.lm_head.linear.weight.set_value(self.model.embed_tokens.embeddings.weight.transpose([1, 0]))
+            self.lm_head.load_state_dict({self.lm_head.weight_key: self.model.embed_tokens.embeddings.weight})
         else:
             self.lm_head.load_state_dict(state_dict)
 
     def compute_logits(self, hidden_states: paddle.Tensor):
         """ """
         logits = self.lm_head(hidden_states)
-        logits = paddle.cast(logits, paddle.float32)
+        logits = logits.astype(paddle.float32)
        logits[:, self.ori_vocab_size :] = -float("inf")
 
         return logits
diff --git a/fastdeploy/model_executor/models/qwen3moe.py b/fastdeploy/model_executor/models/qwen3moe.py
index bcf9dbe6a..5f6fac9ee 100644
--- a/fastdeploy/model_executor/models/qwen3moe.py
+++ b/fastdeploy/model_executor/models/qwen3moe.py
@@ -298,7 +298,7 @@ class Qwen3MoeForCausalLM(ModelForCasualLM):
     def compute_logits(self, hidden_states: paddle.Tensor):
         """ """
         logits = self.lm_head(hidden_states)
-        logits = paddle.cast(logits, paddle.float32)
+        logits = logits.astype(paddle.float32)
         logits[:, self.ori_vocab_size :] = -float("inf")
 
         return logits
diff --git a/fastdeploy/worker/worker_process.py b/fastdeploy/worker/worker_process.py
index 32373b308..2b85c94af 100644
--- a/fastdeploy/worker/worker_process.py
+++ b/fastdeploy/worker/worker_process.py
@@ -587,6 +587,11 @@ def parse_args():
         action="store_true",
         help="Enable output of token-level log probabilities.",
     )
+    parser.add_argument(
+        "--lm_head_fp32",
+        action="store_true",
+        help="Specify the dtype of lm_head weight as float32.",
+    )
 
     args = parser.parse_args()
     return args
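
Usage note (not part of the patch): a minimal sketch of how the temporary_dtype helper added in fastdeploy/model_executor/layers/utils.py is expected to behave when lm_head_fp32 is enabled. The helper itself is copied from the hunk above; the float16 default dtype and the paddle.create_parameter call below are illustrative assumptions used to stand in for ParallelLMHead construction, not code from this diff.

# Standalone sketch (assumes a working Paddle install).
from contextlib import contextmanager

import paddle


@contextmanager
def temporary_dtype(dtype: str):
    """Temporarily set Paddle default dtype"""
    orig_dtype = paddle.get_default_dtype()
    try:
        if dtype is not None and dtype == "float32":
            paddle.set_default_dtype(dtype)
        yield
    finally:
        paddle.set_default_dtype(orig_dtype)


paddle.set_default_dtype("float16")  # stand-in for the model's usual half-precision dtype
with temporary_dtype("float32"):
    # Parameters created inside the context default to float32, which is how
    # ParallelLMHead ends up with FP32 weights while the rest of the model stays half precision.
    weight = paddle.create_parameter(shape=[8, 8], dtype=paddle.get_default_dtype())
assert str(weight.dtype) == "paddle.float32"
assert paddle.get_default_dtype() == "float16"  # the previous default dtype is restored on exit

Because the hidden states still arrive in the model's compute dtype, forward() in lm_head.py now casts them with logits.astype(self.linear.weight.dtype) (or self.weight.dtype in the expert-parallel branch) before the projection, as shown in the hunk above.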