[Precision] Change lm_head layer to run in float32 (#3596)

* support lm_head in fp32/bf16/fp16

* delete print

* code check

* check
chen authored on 2025-08-26 20:20:06 +08:00, committed by GitHub
parent 2136990144
commit d233e3c97c
14 changed files with 85 additions and 49 deletions

@@ -15,6 +15,7 @@
 """
 import functools
+from contextlib import contextmanager
 from typing import Tuple, Union
 import numpy as np
@@ -377,3 +378,15 @@ def create_empty_tensor(shape: Tuple[int, ...], dtype: Union[paddle.dtype, str])
         paddle.Tensor: An empty tensor with the specified shape and data type.
     """
     return paddle.empty(list(shape), dtype=dtype)
+
+
+@contextmanager
+def temporary_dtype(dtype: str):
+    """Temporarily set Paddle default dtype"""
+    orig_dtype = paddle.get_default_dtype()
+    try:
+        if dtype is not None and dtype == "float32":
+            paddle.set_default_dtype(dtype)
+        yield
+    finally:
+        paddle.set_default_dtype(orig_dtype)
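
The helper restores the caller's default dtype in the finally block even if the body raises, and the dtype == "float32" guard makes any other value a no-op, so bf16/fp16 models pass through unchanged while lm_head can opt into full precision. A minimal usage sketch, not part of the diff (the bfloat16 ambient dtype and the tensor shapes are illustrative assumptions; temporary_dtype is the helper added above):

import paddle

paddle.set_default_dtype("bfloat16")   # ambient model dtype (illustrative)

hidden_size, vocab_size = 1024, 32000  # illustrative shapes

with temporary_dtype("float32"):
    # Inside the context the default dtype is float32, so the lm_head
    # weight is allocated in full precision.
    lm_head_weight = paddle.empty([hidden_size, vocab_size])

other_weight = paddle.empty([hidden_size, hidden_size])

print(lm_head_weight.dtype)  # paddle.float32
print(other_weight.dtype)    # paddle.bfloat16

Layer parameters created inside the context (e.g. an lm_head Linear) pick up float32 the same way, since Paddle uses the default dtype when creating parameters whose dtype is not set explicitly.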