[Precision] Change lm_head layer running in float32 (#3596)

* support lm_head fp32 bf16 fp16

* delete print

* code check

* check

* check

* code check

* check

* check
This commit is contained in:
chen
2025-08-26 20:20:06 +08:00
committed by GitHub
parent 2136990144
commit d233e3c97c
14 changed files with 85 additions and 49 deletions

View File

@@ -15,10 +15,10 @@
"""
import json
import os
from dataclasses import asdict, dataclass
from dataclasses import fields as dataclass_fields
from typing import Any, Dict, List, Optional
import os
from fastdeploy.engine.config import (
CacheConfig,
@@ -315,6 +315,11 @@ class EngineArgs:
Must be explicitly enabled via the `--enable-logprob` startup parameter to output logprob values.
"""
lm_head_fp32: bool = None
"""
Flag to specify the data type of lm_head as FP32.
"""
def __post_init__(self):
"""
Post-initialization processing to set default tokenizer if not provided.
@@ -466,6 +471,12 @@ class EngineArgs:
default=EngineArgs.enable_logprob,
help="Enable output of token-level log probabilities.",
)
model_group.add_argument(
"--lm_head-fp32",
action="store_true",
default=EngineArgs.lm_head_fp32,
help="Specify the dtype of lm_head weight as float32.",
)
# Parallel processing parameters group
parallel_group = parser.add_argument_group("Parallel Configuration")
@@ -769,6 +780,7 @@ class EngineArgs:
quantization=self.quantization,
dynamic_load_weight=self.dynamic_load_weight,
load_strategy=self.load_strategy,
lm_head_fp32=self.lm_head_fp32,
)
def create_cache_config(self, model_cfg) -> CacheConfig:
@@ -855,7 +867,7 @@ class EngineArgs:
if self.enable_chunked_prefill:
self.max_num_batched_tokens = 2048
else:
if not int(os.getenv('ENABLE_V1_KVCACHE_SCHEDULER', '0')):
if not int(os.getenv("ENABLE_V1_KVCACHE_SCHEDULER", "0")):
self.max_num_batched_tokens = self.max_model_len
else:
self.max_num_batched_tokens = 8192