mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-05 16:48:03 +08:00
[SOT] Mark dynamic dims by type annotations (#2771)
Some checks failed
Deploy GitHub Pages / deploy (push) Has been cancelled
Some checks failed
Deploy GitHub Pages / deploy (push) Has been cancelled
* [SOT] Mark dynamic dims by type annotations * fix conflict of forward_meta * mark more attn backend * fix missing annotated and add env SOT_SPECIALIZED_DIM_NUMBERS * auto infer implicit 0 dim dynamic dim * revert manual marked dims * revert missing update * auto infer can use unsafe code in warmup stage * check -> type_match * fix codestyle * restore blank line * empty commit * add need_warmup nonlocal; * add doc for resolver * add missing type hints * unquote "ForwardMeta"
This commit is contained in:
@@ -14,14 +14,101 @@
|
||||
# limitations under the License.
|
||||
"""
|
||||
|
||||
from typing import Callable, Optional
|
||||
import functools
|
||||
import inspect
|
||||
import types
|
||||
from typing import Callable, Optional, TypeVar, get_type_hints
|
||||
|
||||
from paddle.jit.dy2static.utils import Backend
|
||||
from paddle.jit import sot
|
||||
from paddle.jit.dy2static.utils import Backend as ToStaticBackend
|
||||
from paddleformers.utils.log import logger
|
||||
from typing_extensions import ParamSpec
|
||||
|
||||
from fastdeploy.config import FDConfig
|
||||
from fastdeploy.model_executor.graph_optimization.cudagraph_piecewise_backend import (
|
||||
CudaGraphPiecewiseBackend,
|
||||
)
|
||||
from fastdeploy.model_executor.graph_optimization.dynamic_dims_marker import (
|
||||
resolve_dynamic_dims,
|
||||
)
|
||||
|
||||
P = ParamSpec("P")
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
# TODO(SigureMo): Replace this fn with real implementation by DrRyanHuang
|
||||
def create_in_warmup_mode():
|
||||
cnt = 0
|
||||
|
||||
def in_warmup_mode():
|
||||
nonlocal cnt
|
||||
cnt += 1
|
||||
return cnt < 32
|
||||
|
||||
return in_warmup_mode
|
||||
|
||||
|
||||
in_warmup_mode = create_in_warmup_mode()
|
||||
|
||||
|
||||
def apply_to_static_optimization(fn: Callable[P, T], backend: ToStaticBackend) -> Callable[P, T]:
|
||||
forward_fn = fn
|
||||
forward_sig = inspect.signature(forward_fn)
|
||||
forward_type_hints = get_type_hints(forward_fn)
|
||||
static_forward_fn = sot.symbolic_translate(forward_fn, training=False, backend=backend)
|
||||
unsafe_static_forward_fn = None
|
||||
need_warmup = True
|
||||
|
||||
@functools.wraps(forward_fn)
|
||||
def warmup_impl(self, *args, **kwargs):
|
||||
nonlocal unsafe_static_forward_fn, need_warmup
|
||||
bound_args = forward_sig.bind(self, *args, **kwargs)
|
||||
bound_args.apply_defaults()
|
||||
for name, arg in bound_args.arguments.items():
|
||||
if name not in forward_type_hints:
|
||||
continue
|
||||
annotation = forward_type_hints[name]
|
||||
resolve_dynamic_dims(arg, name, annotation)
|
||||
|
||||
result = static_forward_fn(self, *args, **kwargs)
|
||||
original_code = forward_fn.__code__
|
||||
(new_guarded_codes, _) = sot.opcode_translator.executor.executor_cache.OpcodeExecutorCache().cache[
|
||||
original_code
|
||||
]
|
||||
# Check has only one graph
|
||||
if len(new_guarded_codes) > 1:
|
||||
logger.warning("Model has multiple generated code, please check all dynamic dim has marked.")
|
||||
unsafe_static_forward_fn = None
|
||||
need_warmup = False
|
||||
return result
|
||||
# Check generated code has no break graph
|
||||
new_code = new_guarded_codes[0][0][0]
|
||||
if any(name.startswith("$") for name in new_code.co_names): # TODO(SigureMo): It's a internal impl
|
||||
logger.warning("Model has breakgraph, please set env SOT_LOG_LEVEL=3 to check it.")
|
||||
unsafe_static_forward_fn = None
|
||||
need_warmup = False
|
||||
return result
|
||||
unsafe_static_forward_fn = types.FunctionType(
|
||||
new_code,
|
||||
forward_fn.__globals__,
|
||||
forward_fn.__name__,
|
||||
forward_fn.__defaults__,
|
||||
forward_fn.__closure__,
|
||||
)
|
||||
return result
|
||||
|
||||
@functools.wraps(forward_fn)
|
||||
def static_forward(self, *args, **kwargs):
|
||||
nonlocal need_warmup
|
||||
is_warmup = in_warmup_mode() and need_warmup
|
||||
if is_warmup:
|
||||
return warmup_impl(self, *args, **kwargs)
|
||||
nonlocal unsafe_static_forward_fn
|
||||
if unsafe_static_forward_fn is None:
|
||||
return static_forward_fn(self, *args, **kwargs)
|
||||
return unsafe_static_forward_fn(self, *args, **kwargs)
|
||||
|
||||
return static_forward
|
||||
|
||||
|
||||
class GraphOptBackend:
|
||||
@@ -42,10 +129,14 @@ class GraphOptBackend:
|
||||
# 1. Prepare cuda grpah input buffers (contain output of subgraphs)
|
||||
|
||||
# 2. Convert dynamic grpah to static graph
|
||||
from paddle.jit import sot
|
||||
|
||||
backend = Backend.CINN if self.fd_config.graph_opt_config.graph_opt_level > 1 else Backend.PHI
|
||||
self.runnable = sot.symbolic_translate(self.runnable, training=False, backend=backend)
|
||||
backend = (
|
||||
ToStaticBackend.CINN if self.fd_config.graph_opt_config.graph_opt_level > 1 else ToStaticBackend.PHI
|
||||
)
|
||||
self.runnable = apply_to_static_optimization(
|
||||
self.runnable.__func__,
|
||||
backend,
|
||||
).__get__(self.runnable.__self__)
|
||||
|
||||
def __call__(self, **kwargs):
|
||||
if not self.fd_config.graph_opt_config.use_cudagraph:
|
||||
|
Reference in New Issue
Block a user