Mirror of https://github.com/PaddlePaddle/FastDeploy.git, synced 2025-10-07 01:22:59 +08:00
[Precision] Change lm_head layer running in float32 (#3596)
* support lm_head fp32 bf16 fp16
* delete print
* code check
* check
* check
* code check
* check
* check
@@ -15,6 +15,7 @@
 """
 
 import functools
+from contextlib import contextmanager
 from typing import Tuple, Union
 
 import numpy as np
@@ -377,3 +378,15 @@ def create_empty_tensor(shape: Tuple[int, ...], dtype: Union[paddle.dtype, str])
         paddle.Tensor: An empty tensor with the specified shape and data type.
     """
     return paddle.empty(list(shape), dtype=dtype)
+
+
+@contextmanager
+def temporary_dtype(dtype: str):
+    """Temporarily set Paddle default dtype"""
+    orig_dtype = paddle.get_default_dtype()
+    try:
+        if dtype is not None and dtype == "float32":
+            paddle.set_default_dtype(dtype)
+        yield
+    finally:
+        paddle.set_default_dtype(orig_dtype)
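
A typical use of the new helper is to wrap the construction of the lm_head layer so that its weights are allocated in float32 while the rest of the model keeps a low-precision default (bf16/fp16). Below is a minimal usage sketch, not FastDeploy's actual call site: the layer shapes and the bfloat16 setup are assumptions for illustration, and the helper is copied inline from the diff above since the patched file's import path is not shown here. Setting "bfloat16" as the default dtype requires a Paddle version that accepts it.

import paddle
from contextlib import contextmanager

# Copied from the diff above so the sketch is self-contained; in FastDeploy
# this would be imported from the patched module instead.
@contextmanager
def temporary_dtype(dtype: str):
    """Temporarily set Paddle default dtype"""
    orig_dtype = paddle.get_default_dtype()
    try:
        if dtype is not None and dtype == "float32":
            paddle.set_default_dtype(dtype)
        yield
    finally:
        paddle.set_default_dtype(orig_dtype)

# Hypothetical setup: run the bulk of the model in bfloat16 ...
paddle.set_default_dtype("bfloat16")

# ... but allocate the lm_head weights in full precision. Parameters
# created inside the block pick up the temporary float32 default.
with temporary_dtype("float32"):
    lm_head = paddle.nn.Linear(4096, 32000)  # hypothetical lm_head shapes

print(lm_head.weight.dtype)        # paddle.float32
print(paddle.get_default_dtype())  # bfloat16, restored on exit

Note that the guard (dtype is not None and dtype == "float32") makes the context manager a no-op for any other dtype string, so existing bf16/fp16 paths are unaffected, and the finally block guarantees the original default dtype is restored even if layer construction raises.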