[SOT] Mark dynamic dims by type annotations (#2771)

* [SOT] Mark dynamic dims by type annotations

* fix conflict of forward_meta

* mark more attn backend

* fix missing annotated and add env SOT_SPECIALIZED_DIM_NUMBERS

* auto infer implicit 0 dim dynamic dim

* revert manual marked dims

* revert missing update

* auto infer can use unsafe code in warmup stage

* check -> type_match

* fix codestyle

* restore blank line

* empty commit

* add need_warmup nonlocal;

* add doc for resolver

* add missing type hints

* unquote "ForwardMeta"
Nyakku Shigure
2025-07-22 15:23:52 +08:00
committed by GitHub
parent e991777757
commit 48e6a0ca26
13 changed files with 330 additions and 28 deletions
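
A minimal sketch of what this change enables (illustrative only; the model, parameter names, and chosen dims are invented, not taken from this diff): annotating a forward argument with DynamicDims tells SOT which axes may vary between calls, while a bare Tensor annotation is treated as having only dim 0 dynamic.

from typing import Annotated, Optional

from paddle import Tensor

from fastdeploy.model_executor.graph_optimization.dynamic_dims_marker import DynamicDims


class DemoModel:  # hypothetical model, for illustration only
    def forward(
        self,
        ids_remove_padding: Tensor,  # bare Tensor: dim 0 is implicitly dynamic
        image_features: Optional[Annotated[Tensor, DynamicDims((0, 1))]] = None,  # dims 0 and 1 marked dynamic
    ):
        ...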

@@ -0,0 +1,191 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
from __future__ import annotations

import dataclasses
import typing
from abc import abstractmethod
from collections.abc import Callable
from functools import partial
from typing import Annotated, Any, TypeVar, Union, get_origin, get_type_hints

import paddle
from paddle import Tensor
from paddleformers.utils.log import logger
from typing_extensions import TypeAlias

T = TypeVar("T")
U = TypeVar("U")
Accessor: TypeAlias = Callable[[T], U]


class DynamicDims:
    """Annotation metadata that marks which dims of a Tensor are dynamic."""

    def __init__(self, dims: int | tuple[int, ...]):
        self.dims = dims if isinstance(dims, tuple) else (dims,)

    def __repr__(self):
        return f"DynamicDims({self.dims})"


class DynamicDimTypeResolver:
    """
    Base class for dynamic dimension type resolvers.

    This class provides a mechanism to register and resolve dynamic dimensions
    based on type annotations. It uses a registry pattern to allow multiple
    resolvers to be registered and used in a flexible manner.
    """

    ALL_DYNAMIC_DIM_TYPE_RESOLVERS = []

    @classmethod
    def register_resolver(cls, resolver_cls: type[DynamicDimTypeResolver]):
        cls.ALL_DYNAMIC_DIM_TYPE_RESOLVERS.append(resolver_cls())
        return resolver_cls

    @abstractmethod
    def type_match(self, tp: type[Any]) -> bool:
        raise NotImplementedError

    @abstractmethod
    def extract_inner_types(
        self, data: Any, data_name: str, tp: type[Any]
    ) -> list[tuple[Accessor[Any, Any], str, type[Any]]]:
        raise NotImplementedError

    def resolve(self, data: Any, data_name: str, tp: type[Any]) -> None:
        inner_types = self.extract_inner_types(data, data_name, tp)
        for accessor, inner_data_name, inner_type in inner_types:
            self.generic_resolve(accessor(data), inner_data_name, inner_type)

    def generic_resolve(self, data: Any, data_name: str, tp: type[Any]) -> None:
        for resolver in self.ALL_DYNAMIC_DIM_TYPE_RESOLVERS:
            if resolver.type_match(tp):
                return resolver.resolve(data, data_name, tp)
            runtime_tp = type(data)
            if runtime_tp is not tp and resolver.type_match(runtime_tp):
                return resolver.resolve(data, data_name, runtime_tp)
        else:
            logger.debug(f"No resolver found for type {tp} and data {data_name}")


@DynamicDimTypeResolver.register_resolver
class DataClassDynamicDimTypeResolver(DynamicDimTypeResolver):
    def type_match(self, tp: type[Any]) -> bool:
        return dataclasses.is_dataclass(tp) and isinstance(tp, type)

    def extract_inner_types(
        self, data: Any, data_name: str, tp: type[Any]
    ) -> list[tuple[Accessor[Any, Any], str, type[Any]]]:
        type_hints = get_type_hints(tp, include_extras=True)
        return [  # type: ignore
            (
                # bind the field name via partial to avoid capturing the wrong free variable
                partial(lambda name, dt: getattr(dt, name), field.name),
                f"{data_name}.{field.name}",
                type_hints[field.name],
            )
            for field in dataclasses.fields(tp)
        ]


@DynamicDimTypeResolver.register_resolver
class OptionalDynamicDimTypeResolver(DynamicDimTypeResolver):
    def type_match(self, tp) -> bool:
        return get_origin(tp) is Union and len(tp.__args__) == 2 and tp.__args__[1] is type(None)  # noqa: E721

    def extract_inner_types(
        self, data: Any, data_name: str, tp: type[Any]
    ) -> list[tuple[Accessor[Any, Any], str, type[Any]]]:
        if data is None:
            return []
        inner_type = tp.__args__[0]
        return [(lambda x: x, data_name, inner_type)]  # No accessor needed for Optional


@DynamicDimTypeResolver.register_resolver
class ListDynamicDimTypeResolver(DynamicDimTypeResolver):
    def type_match(self, tp: type[Any]) -> bool:
        return get_origin(tp) is list

    def extract_inner_types(
        self, data: Any, data_name: str, tp: type[Any]
    ) -> list[tuple[Accessor[Any, Any], str, type[Any]]]:
        if not data:
            return []
        inner_type = typing.get_args(tp)[0] if tp.__args__ else Any
        return [(partial(lambda i, x: x[i], i), f"{data_name}[{i}]", inner_type) for i in range(len(data))]  # type: ignore


@DynamicDimTypeResolver.register_resolver
class ManualMarkedInnerFieldsDynamicDimTypeResolver(DynamicDimTypeResolver):
    INFER_DYNAMIC_DIMS_FIELDS_ATTR_NAME = "__infer_dynamic_dims_fields__"

    def type_match(self, tp: type[Any]) -> bool:
        return hasattr(tp, ManualMarkedInnerFieldsDynamicDimTypeResolver.INFER_DYNAMIC_DIMS_FIELDS_ATTR_NAME)

    def extract_inner_types(
        self, data: Any, data_name: str, tp: type[Any]
    ) -> list[tuple[Accessor[Any, Any], str, type[Any]]]:
        fields = getattr(tp, ManualMarkedInnerFieldsDynamicDimTypeResolver.INFER_DYNAMIC_DIMS_FIELDS_ATTR_NAME)
        if isinstance(fields, str):
            raise TypeError(
                f"{ManualMarkedInnerFieldsDynamicDimTypeResolver.INFER_DYNAMIC_DIMS_FIELDS_ATTR_NAME} should be tuple, but got {type(fields)}"
            )
        inner_types_dict = typing.get_type_hints(tp)
        return [
            (partial(lambda name, x: getattr(x, name), field_name), f"{data_name}.{field_name}", inner_type)
            for field_name, inner_type in inner_types_dict.items()
        ]


@DynamicDimTypeResolver.register_resolver
class AnnotatedTensorDynamicDimTypeResolver(DynamicDimTypeResolver):
    def type_match(self, tp: type[Any]) -> bool:
        return get_origin(tp) is Annotated and typing.get_args(tp)[0] is Tensor

    def resolve(self, data: Any, data_name: str, tp: type[Any]) -> None:
        base_type, *metadata = typing.get_args(tp)
        # Filter out DynamicDims instances
        dynamic_dims = [m for m in metadata if isinstance(m, DynamicDims)]
        if not dynamic_dims:
            return
        if len(dynamic_dims) > 1:
            raise ValueError("Multiple DynamicDims annotations found. Only one is allowed.")
        dynamic_dims = dynamic_dims[0].dims
        if not isinstance(data, Tensor):
            raise TypeError(f"data {data_name} has type annotation Tensor but got type {type(data)}")
        logger.debug(f"data {data_name} has dynamic dims {dynamic_dims} for type {tp}")
        paddle.jit.marker.dynamic_dims(data, dynamic_dims)


@DynamicDimTypeResolver.register_resolver
class TensorImplicitFirstDimOnlyDynamicDimTypeResolver(DynamicDimTypeResolver):
    def type_match(self, tp: type[Any]) -> bool:
        return tp is Tensor

    def resolve(self, data: Any, data_name: str, tp: type[Any]) -> None:
        # A bare Tensor annotation has implicit dynamic_dims=(0,)
        dynamic_dims = (0,)
        if not isinstance(data, Tensor):
            raise TypeError(f"data {data_name} has type annotation Tensor but got type {type(data)}")
        logger.debug(f"data {data_name} has dynamic dims {dynamic_dims} for type {tp}")
        paddle.jit.marker.dynamic_dims(data, dynamic_dims)


def resolve_dynamic_dims(arg: Any, arg_name: str, annotation: type[Any]) -> None:
    DynamicDimTypeResolver().generic_resolve(arg, arg_name, annotation)
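
A hedged usage sketch of the module above (DemoMeta and its fields are invented; only DynamicDims, resolve_dynamic_dims, and the resolvers come from this file, and it assumes a Paddle build that provides paddle.jit.marker.dynamic_dims): resolve_dynamic_dims walks an argument according to its annotation, recursing through dataclass fields, Optional and list wrappers, until it reaches Tensor leaves and marks them.

import dataclasses
from typing import Annotated, Optional

import paddle
from paddle import Tensor


@dataclasses.dataclass
class DemoMeta:  # hypothetical stand-in for something like ForwardMeta
    input_ids: Tensor  # bare Tensor -> implicit dynamic dims (0,)
    rotary_embs: Optional[Annotated[Tensor, DynamicDims((0, 2))]] = None  # explicit dynamic dims


meta = DemoMeta(input_ids=paddle.zeros([4, 128], dtype="int64"))
resolve_dynamic_dims(meta, "forward_meta", DemoMeta)
# DataClassDynamicDimTypeResolver visits each field; the Optional resolver yields nothing for
# rotary_embs because it is None here; TensorImplicitFirstDimOnlyDynamicDimTypeResolver marks
# dim 0 of input_ids via paddle.jit.marker.dynamic_dims.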

@@ -14,14 +14,101 @@
# limitations under the License.
"""
from typing import Callable, Optional
import functools
import inspect
import types
from typing import Callable, Optional, TypeVar, get_type_hints
from paddle.jit.dy2static.utils import Backend
from paddle.jit import sot
from paddle.jit.dy2static.utils import Backend as ToStaticBackend
from paddleformers.utils.log import logger
from typing_extensions import ParamSpec
from fastdeploy.config import FDConfig
from fastdeploy.model_executor.graph_optimization.cudagraph_piecewise_backend import (
    CudaGraphPiecewiseBackend,
)
from fastdeploy.model_executor.graph_optimization.dynamic_dims_marker import (
    resolve_dynamic_dims,
)
P = ParamSpec("P")
T = TypeVar("T")


# TODO(SigureMo): Replace this fn with real implementation by DrRyanHuang
def create_in_warmup_mode():
    cnt = 0

    def in_warmup_mode():
        # Placeholder counter: report warmup until the 32nd call
        nonlocal cnt
        cnt += 1
        return cnt < 32

    return in_warmup_mode


in_warmup_mode = create_in_warmup_mode()


def apply_to_static_optimization(fn: Callable[P, T], backend: ToStaticBackend) -> Callable[P, T]:
    forward_fn = fn
    forward_sig = inspect.signature(forward_fn)
    forward_type_hints = get_type_hints(forward_fn)
    static_forward_fn = sot.symbolic_translate(forward_fn, training=False, backend=backend)
    unsafe_static_forward_fn = None
    need_warmup = True

    @functools.wraps(forward_fn)
    def warmup_impl(self, *args, **kwargs):
        nonlocal unsafe_static_forward_fn, need_warmup
        bound_args = forward_sig.bind(self, *args, **kwargs)
        bound_args.apply_defaults()
        # Mark dynamic dims for every annotated argument before translation
        for name, arg in bound_args.arguments.items():
            if name not in forward_type_hints:
                continue
            annotation = forward_type_hints[name]
            resolve_dynamic_dims(arg, name, annotation)
        result = static_forward_fn(self, *args, **kwargs)
        original_code = forward_fn.__code__
        (new_guarded_codes, _) = sot.opcode_translator.executor.executor_cache.OpcodeExecutorCache().cache[
            original_code
        ]
        # Check that only one graph was generated
        if len(new_guarded_codes) > 1:
            logger.warning("Model has multiple generated codes, please check that all dynamic dims are marked.")
            unsafe_static_forward_fn = None
            need_warmup = False
            return result
        # Check that the generated code has no breakgraph
        new_code = new_guarded_codes[0][0][0]
        if any(name.startswith("$") for name in new_code.co_names):  # TODO(SigureMo): It's an internal impl
            logger.warning("Model has breakgraph, please set env SOT_LOG_LEVEL=3 to check it.")
            unsafe_static_forward_fn = None
            need_warmup = False
            return result
        unsafe_static_forward_fn = types.FunctionType(
            new_code,
            forward_fn.__globals__,
            forward_fn.__name__,
            forward_fn.__defaults__,
            forward_fn.__closure__,
        )
        return result

    @functools.wraps(forward_fn)
    def static_forward(self, *args, **kwargs):
        nonlocal need_warmup
        is_warmup = in_warmup_mode() and need_warmup
        if is_warmup:
            return warmup_impl(self, *args, **kwargs)
        nonlocal unsafe_static_forward_fn
        if unsafe_static_forward_fn is None:
            return static_forward_fn(self, *args, **kwargs)
        return unsafe_static_forward_fn(self, *args, **kwargs)

    return static_forward


class GraphOptBackend:
@@ -42,10 +129,14 @@ class GraphOptBackend:
            # 1. Prepare cuda graph input buffers (contain output of subgraphs)
            # 2. Convert dynamic graph to static graph
            from paddle.jit import sot

            backend = Backend.CINN if self.fd_config.graph_opt_config.graph_opt_level > 1 else Backend.PHI
            self.runnable = sot.symbolic_translate(self.runnable, training=False, backend=backend)
            backend = (
                ToStaticBackend.CINN if self.fd_config.graph_opt_config.graph_opt_level > 1 else ToStaticBackend.PHI
            )
            self.runnable = apply_to_static_optimization(
                self.runnable.__func__,
                backend,
            ).__get__(self.runnable.__self__)

    def __call__(self, **kwargs):
        if not self.fd_config.graph_opt_config.use_cudagraph: