mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-12-24 13:28:13 +08:00
[SOT] Change warnings to errors and remove fallback operations (#4378)
* Change warnings to errors and remove fallback operations * fix unitest * fix codestyle
This commit is contained in:
@@ -21,7 +21,6 @@ from typing import Callable, Optional, TypeVar, get_type_hints
|
||||
|
||||
from paddle.jit import sot
|
||||
from paddle.jit.dy2static.utils import Backend as ToStaticBackend
|
||||
from paddleformers.utils.log import logger
|
||||
from typing_extensions import ParamSpec
|
||||
|
||||
from fastdeploy.config import FDConfig
|
||||
@@ -46,11 +45,10 @@ def apply_to_static_optimization(fn: Callable[P, T], backend: ToStaticBackend) -
|
||||
forward_type_hints = get_type_hints(forward_fn)
|
||||
static_forward_fn = sot.symbolic_translate(forward_fn, training=False, backend=backend)
|
||||
unsafe_static_forward_fn = None
|
||||
need_warmup = True
|
||||
|
||||
@functools.wraps(forward_fn)
|
||||
def warmup_impl(self, *args, **kwargs):
|
||||
nonlocal unsafe_static_forward_fn, need_warmup
|
||||
nonlocal unsafe_static_forward_fn
|
||||
bound_args = forward_sig.bind(self, *args, **kwargs)
|
||||
bound_args.apply_defaults()
|
||||
for name, arg in bound_args.arguments.items():
|
||||
@@ -66,17 +64,11 @@ def apply_to_static_optimization(fn: Callable[P, T], backend: ToStaticBackend) -
|
||||
]
|
||||
# Check has only one graph
|
||||
if len(new_guarded_codes) > 1:
|
||||
logger.warning("Model has multiple generated code, please check all dynamic dim has marked.")
|
||||
unsafe_static_forward_fn = None
|
||||
need_warmup = False
|
||||
return result
|
||||
raise RuntimeError("Model has multiple generated code, please check all dynamic dim has marked.")
|
||||
# Check generated code has no break graph
|
||||
new_code = new_guarded_codes[0][0][0]
|
||||
if any(name.startswith("$") for name in new_code.co_names): # TODO(SigureMo): It's a internal impl
|
||||
logger.warning("Model has breakgraph, please set env SOT_LOG_LEVEL=3 to check it.")
|
||||
unsafe_static_forward_fn = None
|
||||
need_warmup = False
|
||||
return result
|
||||
raise RuntimeError("Model has breakgraph, please set env SOT_LOG_LEVEL=3 to check it.")
|
||||
unsafe_static_forward_fn = types.FunctionType(
|
||||
new_code,
|
||||
forward_fn.__globals__,
|
||||
@@ -88,15 +80,12 @@ def apply_to_static_optimization(fn: Callable[P, T], backend: ToStaticBackend) -
|
||||
|
||||
@functools.wraps(forward_fn)
|
||||
def static_forward(self, *args, **kwargs):
|
||||
nonlocal unsafe_static_forward_fn
|
||||
if in_profile_run_mode():
|
||||
return forward_fn(self, *args, **kwargs)
|
||||
nonlocal need_warmup
|
||||
is_warmup = in_warmup_mode() and need_warmup
|
||||
if is_warmup:
|
||||
if in_warmup_mode():
|
||||
return warmup_impl(self, *args, **kwargs)
|
||||
nonlocal unsafe_static_forward_fn
|
||||
if unsafe_static_forward_fn is None:
|
||||
return static_forward_fn(self, *args, **kwargs)
|
||||
assert unsafe_static_forward_fn is not None
|
||||
return unsafe_static_forward_fn(self, *args, **kwargs)
|
||||
|
||||
return static_forward
|
||||
|
||||
Reference in New Issue
Block a user