[SOT] Change warnings to errors and remove fallback operations (#4378)

* Change warnings to errors and remove fallback operations

* fix unitest

* fix codestyle
This commit is contained in:
Ryan
2025-10-17 11:27:04 +08:00
committed by GitHub
parent 0413c32b8f
commit 6160145f82
3 changed files with 27 additions and 28 deletions

View File

@@ -21,7 +21,6 @@ from typing import Callable, Optional, TypeVar, get_type_hints
from paddle.jit import sot
from paddle.jit.dy2static.utils import Backend as ToStaticBackend
from paddleformers.utils.log import logger
from typing_extensions import ParamSpec
from fastdeploy.config import FDConfig
@@ -46,11 +45,10 @@ def apply_to_static_optimization(fn: Callable[P, T], backend: ToStaticBackend) -
forward_type_hints = get_type_hints(forward_fn)
static_forward_fn = sot.symbolic_translate(forward_fn, training=False, backend=backend)
unsafe_static_forward_fn = None
need_warmup = True
@functools.wraps(forward_fn)
def warmup_impl(self, *args, **kwargs):
nonlocal unsafe_static_forward_fn, need_warmup
nonlocal unsafe_static_forward_fn
bound_args = forward_sig.bind(self, *args, **kwargs)
bound_args.apply_defaults()
for name, arg in bound_args.arguments.items():
@@ -66,17 +64,11 @@ def apply_to_static_optimization(fn: Callable[P, T], backend: ToStaticBackend) -
]
# Check has only one graph
if len(new_guarded_codes) > 1:
logger.warning("Model has multiple generated code, please check all dynamic dim has marked.")
unsafe_static_forward_fn = None
need_warmup = False
return result
raise RuntimeError("Model has multiple generated code, please check all dynamic dim has marked.")
# Check generated code has no break graph
new_code = new_guarded_codes[0][0][0]
if any(name.startswith("$") for name in new_code.co_names): # TODO(SigureMo): It's a internal impl
logger.warning("Model has breakgraph, please set env SOT_LOG_LEVEL=3 to check it.")
unsafe_static_forward_fn = None
need_warmup = False
return result
raise RuntimeError("Model has breakgraph, please set env SOT_LOG_LEVEL=3 to check it.")
unsafe_static_forward_fn = types.FunctionType(
new_code,
forward_fn.__globals__,
@@ -88,15 +80,12 @@ def apply_to_static_optimization(fn: Callable[P, T], backend: ToStaticBackend) -
@functools.wraps(forward_fn)
def static_forward(self, *args, **kwargs):
nonlocal unsafe_static_forward_fn
if in_profile_run_mode():
return forward_fn(self, *args, **kwargs)
nonlocal need_warmup
is_warmup = in_warmup_mode() and need_warmup
if is_warmup:
if in_warmup_mode():
return warmup_impl(self, *args, **kwargs)
nonlocal unsafe_static_forward_fn
if unsafe_static_forward_fn is None:
return static_forward_fn(self, *args, **kwargs)
assert unsafe_static_forward_fn is not None
return unsafe_static_forward_fn(self, *args, **kwargs)
return static_forward