mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-05 16:48:03 +08:00
151 lines
5.7 KiB
Python
151 lines
5.7 KiB
Python
"""
|
|
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
"""
|
|
|
|
import functools
|
|
import inspect
|
|
import types
|
|
from typing import Callable, Optional, TypeVar, get_type_hints
|
|
|
|
from paddle.jit import sot
|
|
from paddle.jit.dy2static.utils import Backend as ToStaticBackend
|
|
from paddleformers.utils.log import logger
|
|
from typing_extensions import ParamSpec
|
|
|
|
from fastdeploy.config import FDConfig
|
|
from fastdeploy.model_executor.graph_optimization.cudagraph_piecewise_backend import (
|
|
CudaGraphPiecewiseBackend,
|
|
)
|
|
from fastdeploy.model_executor.graph_optimization.dynamic_dims_marker import (
|
|
resolve_dynamic_dims,
|
|
)
|
|
from fastdeploy.model_executor.graph_optimization.utils import in_profile_run_mode
|
|
from fastdeploy.model_executor.graph_optimization.utils import (
|
|
in_sot_warmup_mode as in_warmup_mode,
|
|
)
|
|
|
|
P = ParamSpec("P")
|
|
T = TypeVar("T")
|
|
|
|
|
|
def apply_to_static_optimization(fn: Callable[P, T], backend: ToStaticBackend) -> Callable[P, T]:
    """
    Wrap ``fn`` so that it runs through Paddle SOT (``sot.symbolic_translate``)
    with the given to-static ``backend``.

    Dispatch performed by the returned wrapper:

    * In profile-run mode the original (eager) ``fn`` is called.
    * In SOT warmup mode (while warmup is still needed) dynamic dims are
      resolved for every annotated argument, the SOT-translated function is
      executed, and the code SOT generated for ``fn`` is inspected. If exactly
      one generated code object exists and it contains no graph break, it is
      wrapped into a plain function that later calls use directly instead of
      going through the SOT-translated function (presumably bypassing SOT's
      guard checks, hence "unsafe"). Otherwise a warning is logged and the
      fast path is permanently disabled.
    * In all other cases the fast-path function is used when available, and
      the SOT-translated function otherwise.

    Args:
        fn: The unbound forward function to optimize. It is called as
            ``fn(self, *args, **kwargs)``.
        backend: The to-static backend forwarded to ``sot.symbolic_translate``.

    Returns:
        A function with the same call signature as ``fn``.
    """
    forward_fn = fn
    forward_sig = inspect.signature(forward_fn)
    forward_type_hints = get_type_hints(forward_fn)
    static_forward_fn = sot.symbolic_translate(forward_fn, training=False, backend=backend)
    # Plain function built from SOT-generated code; None means "use static_forward_fn".
    unsafe_static_forward_fn = None
    # Stays True after a successful warmup so later warmup calls re-check the
    # generated code; it is cleared only when the fast path has to be disabled.
    need_warmup = True

    def _disable_fast_path(message: str) -> None:
        # Log why the fast path cannot be used and stop further warmup checks.
        nonlocal unsafe_static_forward_fn, need_warmup
        logger.warning(message)
        unsafe_static_forward_fn = None
        need_warmup = False

    @functools.wraps(forward_fn)
    def warmup_impl(self, *args, **kwargs):
        nonlocal unsafe_static_forward_fn

        bound_args = forward_sig.bind(self, *args, **kwargs)
        bound_args.apply_defaults()
        for name, arg in bound_args.arguments.items():
            if name not in forward_type_hints:
                continue
            resolve_dynamic_dims(arg, name, forward_type_hints[name])

        result = static_forward_fn(self, *args, **kwargs)

        executor_cache = sot.opcode_translator.executor.executor_cache.OpcodeExecutorCache().cache
        try:
            (new_guarded_codes, _) = executor_cache[forward_fn.__code__]
        except KeyError:
            # SOT recorded nothing for this code object; treat it like "no generated code".
            new_guarded_codes = ()

        if not new_guarded_codes:
            _disable_fast_path("SOT generated no code for the model, fall back to the guarded static function.")
            return result

        # The fast path is only valid when a single graph was generated.
        if len(new_guarded_codes) > 1:
            _disable_fast_path("Model has multiple generated code, please check all dynamic dim has marked.")
            return result

        # Generated code must not contain a graph break.
        new_code = new_guarded_codes[0][0][0]
        if any(name.startswith("$") for name in new_code.co_names):  # TODO(SigureMo): It's an internal impl
            _disable_fast_path("Model has breakgraph, please set env SOT_LOG_LEVEL=3 to check it.")
            return result

        fast_fn = types.FunctionType(
            new_code,
            forward_fn.__globals__,
            forward_fn.__name__,
            forward_fn.__defaults__,
            forward_fn.__closure__,
        )
        # FunctionType's constructor does not take keyword-only defaults, so copy
        # them explicitly; otherwise calls omitting keyword-only args would fail.
        fast_fn.__kwdefaults__ = forward_fn.__kwdefaults__
        unsafe_static_forward_fn = fast_fn
        return result

    @functools.wraps(forward_fn)
    def static_forward(self, *args, **kwargs):
        if in_profile_run_mode():
            return forward_fn(self, *args, **kwargs)

        if in_warmup_mode() and need_warmup:
            return warmup_impl(self, *args, **kwargs)

        if unsafe_static_forward_fn is None:
            return static_forward_fn(self, *args, **kwargs)
        return unsafe_static_forward_fn(self, *args, **kwargs)

    return static_forward
|
|
|
|
|
|
class GraphOptBackend:
    """
    Integrate various graph optimization functions around a model's forward
    callable, including dynamic graph to static graph conversion, CINN
    compilation optimization, CudaGraph, and so on.

    * ``graph_opt_level > 0``: the runnable is converted with
      ``apply_to_static_optimization`` using the PHI backend
      (``graph_opt_level == 1``) or the CINN backend (``graph_opt_level > 1``).
    * ``use_cudagraph``: calls whose ``forward_meta.step_use_cudagraph`` is set
      and whose batch size does not exceed the largest capture size are routed
      through a lazily created ``CudaGraphPiecewiseBackend``.
    """

    fd_config: FDConfig
    cudagraph_piecewise_backend: Optional[CudaGraphPiecewiseBackend] = None

    def __init__(self, runnable: Callable, fd_config: FDConfig):
        """
        Args:
            runnable: The forward callable to optimize. When
                ``graph_opt_level > 0`` it must be a bound method, since its
                ``__func__`` and ``__self__`` are used.
            fd_config: The FastDeploy config; ``graph_opt_config`` is read.
        """
        self.runnable = runnable
        self.fd_config = fd_config
        graph_opt_config = fd_config.graph_opt_config

        # Largest batch size that may be replayed from a captured CUDA graph.
        # max() does not depend on the order of the capture sizes; with no
        # capture sizes, -1 makes every batch run without the CUDA graph.
        # (Attribute name kept as-is for backward compatibility.)
        capture_sizes = graph_opt_config.cudagraph_capture_sizes
        self.max_captre_batch = max(capture_sizes) if capture_sizes else -1

        if graph_opt_config.graph_opt_level > 0:
            # TODO: prepare CUDA graph input buffers (containing outputs of subgraphs).

            # Convert the dynamic graph to a static graph.
            backend = ToStaticBackend.CINN if graph_opt_config.graph_opt_level > 1 else ToStaticBackend.PHI
            self.runnable = apply_to_static_optimization(
                self.runnable.__func__,
                backend,
            ).__get__(self.runnable.__self__)

    def __call__(self, **kwargs):
        """Run the forward pass, via the CUDA graph backend when applicable."""
        if not self.fd_config.graph_opt_config.use_cudagraph:
            return self.runnable(**kwargs)

        if self.cudagraph_piecewise_backend is None:
            self.cudagraph_piecewise_backend = CudaGraphPiecewiseBackend(
                fd_config=self.fd_config, runnable=self.runnable
            )

        forward_meta = kwargs["forward_meta"]
        assert forward_meta.ids_remove_padding is not None, "forward_meta.ids_remove_padding must be set"
        batch_size = forward_meta.ids_remove_padding.shape[0]

        # Batches larger than any captured size cannot be replayed from a graph.
        if (not forward_meta.step_use_cudagraph) or (batch_size > self.max_captre_batch):
            return self.runnable(**kwargs)
        return self.cudagraph_piecewise_backend(**kwargs)

    def clear_cudagraph_piecewise_backend(self):
        """Call ``clear_graph()`` on the piecewise CUDA graph backend, if one has been created."""
        if self.cudagraph_piecewise_backend is None:
            # CUDA graph disabled, or no call has created the backend yet.
            return
        self.cudagraph_piecewise_backend.clear_graph()
|