[SOT] Mark dynamic dims by type annotations (#2771)

* [SOT] Mark dynamic dims by type annotations

* fix conflict of forward_meta

* mark more attn backend

* fix missing annotated and add env SOT_SPECIALIZED_DIM_NUMBERS

* auto infer implicit 0 dim dynamic dim

* revert manual marked dims

* revert missing update

* auto infer can use unsafe code in warmup stage

* check -> type_match

* fix codestyle

* restore blank line

* empty commit

* add need_warmup nonlocal;

* add doc for resolver

* add missing type hints

* unquote "ForwardMeta"
Author: Nyakku Shigure
Date: 2025-07-22 15:23:52 +08:00
Committed by: GitHub
Parent: e991777757
Commit: 48e6a0ca26
13 changed files with 330 additions and 28 deletions
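
In practice the scheme works roughly like this: dynamic dimensions are declared through type annotations on forward arguments, and a resolver walks those annotations during warmup. A minimal sketch, assuming the helpers are imported from the new dynamic_dims_marker module added below; the forward signature, argument values, and shapes here are invented for illustration and are not taken from the diff:

    from typing import Annotated, Optional, get_type_hints

    import paddle
    from paddle import Tensor

    from fastdeploy.model_executor.graph_optimization.dynamic_dims_marker import (
        DynamicDims,
        resolve_dynamic_dims,
    )


    # Hypothetical forward signature: a bare Tensor annotation implies dynamic dim 0,
    # an Annotated[...] entry marks dims explicitly, and Optional arguments are
    # skipped while they are None.
    def forward(
        ids_remove_padding: Tensor,
        image_features: Optional[Annotated[Tensor, DynamicDims(0)]],
    ):
        ...


    hints = get_type_hints(forward, include_extras=True)
    args = {"ids_remove_padding": paddle.zeros([8, 128]), "image_features": None}
    for name, value in args.items():
        # Walks the annotation and calls paddle.jit.marker.dynamic_dims on each
        # matching tensor (see the resolver file added below).
        resolve_dynamic_dims(value, name, hints[name])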


@@ -335,11 +335,11 @@ class GraphOptimizationConfig:
         cudagraph_splitting_ops = ["paddle.unified_attention"]
     Note: If want to use subgraph capture functionality in a dynamic graph,
-    can manually split the model into multiple layers and apply the @support_cuda_graph decorator
+    can manually split the model into multiple layers and apply the @support_graph_optimization decorator
     only to the layer where CUDA graph functionality is required.
     """
-    cudagraph_splitting_ops = Optional[list[str]]
-    """" Whether to use a full cuda graph for the entire forward pass rather than
+    cudagraph_splitting_ops: list[str] = field(default_factory=list)
+    """ Whether to use a full cuda graph for the entire forward pass rather than
     splitting certain operations such as attention into subgraphs.
     Thus this flag cannot be used together with splitting_ops."""
     full_cuda_graph: bool = True

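The `cudagraph_splitting_ops = Optional[list[str]]` line that this hunk removes assigned the typing construct itself as a plain class attribute instead of declaring a typed dataclass field. A small standalone sketch of the difference, using toy class names rather than FastDeploy code:

    from dataclasses import dataclass, field
    from typing import Optional


    @dataclass
    class Broken:
        # No annotation: this is a plain class attribute whose value is the
        # typing construct itself, so every instance "default" is Optional[list[str]].
        splitting_ops = Optional[list[str]]


    @dataclass
    class Fixed:
        # A real annotated field; default_factory avoids sharing one mutable
        # list across instances.
        splitting_ops: list[str] = field(default_factory=list)


    print(Broken().splitting_ops)  # typing.Optional[list[str]]
    print(Fixed().splitting_ops)   # []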

@@ -937,11 +937,11 @@ class LLMEngine:
             "SOT_LOG_LEVEL": os.getenv("SOT_LOG_LEVEL", default="0"),
             "SOT_UNSAFE_CACHE_FASTPATH": os.getenv("SOT_UNSAFE_CACHE_FASTPATH", default="1"),
             "SOT_ENABLE_0_SIZE_FALLBACK": os.getenv("SOT_ENABLE_0_SIZE_FALLBACK", default="0"),
+            "SOT_SPECIALIZED_DIM_NUMBERS": os.getenv("SOT_SPECIALIZED_DIM_NUMBERS", default="no"),
             "FLAGS_specialize_device_in_dy2st": os.getenv("FLAGS_specialize_device_in_dy2st", default="1"),
             "FLAGS_enable_async_fast_gc": os.getenv("FLAGS_enable_async_fast_gc", default="0"),
             "FLAGS_pir_interpreter_record_stream_for_gc_cache": os.getenv(
-                "FLAGS_pir_interpreter_record_stream_for_gc_cache",
-                default="1",
+                "FLAGS_pir_interpreter_record_stream_for_gc_cache", default="1"
             ),
             "FLAGS_parameters_persistent_mode_in_dy2st": os.getenv(
                 "FLAGS_parameters_persistent_mode_in_dy2st", default="1"


@@ -0,0 +1,191 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
from __future__ import annotations

import dataclasses
import typing
from abc import abstractmethod
from collections.abc import Callable
from functools import partial
from typing import Annotated, Any, TypeVar, Union, get_origin, get_type_hints

import paddle
from paddle import Tensor
from paddleformers.utils.log import logger
from typing_extensions import TypeAlias

T = TypeVar("T")
U = TypeVar("U")

Accessor: TypeAlias = Callable[[T], U]


class DynamicDims:
    def __init__(self, dims: int | tuple[int]):
        self.dims = dims if isinstance(dims, tuple) else (dims,)

    def __repr__(self):
        return f"DynamicDims({self.dims})"


class DynamicDimTypeResolver:
    """
    Base class for dynamic dimension type resolvers.

    This class provides a mechanism to register and resolve dynamic dimensions
    based on type annotations. It uses a registry pattern to allow multiple
    resolvers to be registered and used in a flexible manner.
    """

    ALL_DYNAMIC_DIM_TYPE_RESOLVERS = []

    @classmethod
    def register_resolver(cls, resolver_cls: type[DynamicDimTypeResolver]):
        cls.ALL_DYNAMIC_DIM_TYPE_RESOLVERS.append(resolver_cls())
        return resolver_cls

    @abstractmethod
    def type_match(self, tp: type[Any]) -> bool:
        raise NotImplementedError

    @abstractmethod
    def extract_inner_types(
        self, data: Any, data_name: str, tp: type[Any]
    ) -> list[tuple[Accessor[Any, Any], str, type[Any]]]:
        raise NotImplementedError

    def resolve(self, data: Any, data_name: str, tp: type[Any]) -> None:
        inner_types = self.extract_inner_types(data, data_name, tp)
        for accessor, inner_data_name, inner_type in inner_types:
            self.generic_resolve(accessor(data), inner_data_name, inner_type)

    def generic_resolve(self, data: Any, data_name: str, tp: type[Any]) -> None:
        for resolver in self.ALL_DYNAMIC_DIM_TYPE_RESOLVERS:
            if resolver.type_match(tp):
                return resolver.resolve(data, data_name, tp)
            runtime_tp = type(data)
            if runtime_tp is not tp and resolver.type_match(runtime_tp):
                return resolver.resolve(data, data_name, runtime_tp)
        else:
            logger.debug(f"No resolver found for type {tp} and data {data_name}")


@DynamicDimTypeResolver.register_resolver
class DataClassDynamicDimTypeResolver(DynamicDimTypeResolver):
    def type_match(self, tp: type[Any]) -> bool:
        return dataclasses.is_dataclass(tp) and isinstance(tp, type)

    def extract_inner_types(
        self, data: Any, data_name: str, tp: type[Any]
    ) -> list[tuple[Accessor[Any, Any], str, type[Any]]]:
        type_hints = get_type_hints(tp, include_extras=True)
        return [  # type: ignore
            (
                # bind name by partial to avoid capture wrong free vars
                partial(lambda name, dt: getattr(dt, name), field.name),
                f"{data_name}.{field.name}",
                type_hints[field.name],
            )
            for field in dataclasses.fields(tp)
        ]


@DynamicDimTypeResolver.register_resolver
class OptionalDynamicDimTypeResolver(DynamicDimTypeResolver):
    def type_match(self, tp) -> bool:
        return get_origin(tp) is Union and len(tp.__args__) == 2 and tp.__args__[1] is type(None)  # noqa: E721

    def extract_inner_types(
        self, data: Any, data_name: str, tp: type[Any]
    ) -> list[tuple[Accessor[Any, Any], str, type[Any]]]:
        if data is None:
            return []
        inner_type = tp.__args__[0]
        return [(lambda x: x, data_name, inner_type)]  # No accessor needed for Optional


@DynamicDimTypeResolver.register_resolver
class ListDynamicDimTypeResolver(DynamicDimTypeResolver):
    def type_match(self, tp: type[Any]) -> bool:
        return get_origin(tp) is list

    def extract_inner_types(
        self, data: Any, data_name: str, tp: type[Any]
    ) -> list[tuple[Accessor[Any, Any], str, type[Any]]]:
        if not data:
            return []
        inner_type = typing.get_args(tp)[0] if tp.__args__ else Any
        return [(partial(lambda i, x: x[i], i), f"{data_name}[{i}]", inner_type) for i in range(len(data))]  # type: ignore


@DynamicDimTypeResolver.register_resolver
class ManualMarkedInnerFieldsDynamicDimTypeResolver(DynamicDimTypeResolver):
    INFER_DYNAMIC_DIMS_FIELDS_ATTR_NAME = "__infer_dynamic_dims_fields__"

    def type_match(self, tp: type[Any]) -> bool:
        return hasattr(tp, ManualMarkedInnerFieldsDynamicDimTypeResolver.INFER_DYNAMIC_DIMS_FIELDS_ATTR_NAME)

    def extract_inner_types(
        self, data: Any, data_name: str, tp: type[Any]
    ) -> list[tuple[Accessor[Any, Any], str, type[Any]]]:
        fields = getattr(tp, ManualMarkedInnerFieldsDynamicDimTypeResolver.INFER_DYNAMIC_DIMS_FIELDS_ATTR_NAME)
        if isinstance(fields, str):
            raise TypeError(
                f"{ManualMarkedInnerFieldsDynamicDimTypeResolver.INFER_DYNAMIC_DIMS_FIELDS_ATTR_NAME} should be tuple, but got {type(fields)}"
            )
        inner_types_dict = typing.get_type_hints(tp)
        return [
            (partial(lambda name, x: getattr(x, name), field_name), f"{data_name}.{field_name}", inner_type)
            for field_name, inner_type in inner_types_dict.items()
        ]


@DynamicDimTypeResolver.register_resolver
class AnnotatedTensorDynamicDimTypeResolver(DynamicDimTypeResolver):
    def type_match(self, tp: type[Any]) -> bool:
        return get_origin(tp) is Annotated and typing.get_args(tp)[0] is Tensor

    def resolve(self, data: Any, data_name: str, tp: type[Any]) -> None:
        base_type, *metadata = typing.get_args(tp)
        # Filter out DynamicDims instances
        dynamic_dims = [m for m in metadata if isinstance(m, DynamicDims)]
        if not dynamic_dims:
            return
        if len(dynamic_dims) > 1:
            raise ValueError("Multiple DynamicDims annotations found. Only one is allowed.")
        dynamic_dims = dynamic_dims[0].dims
        if not isinstance(data, Tensor):
            raise TypeError(f"data {data_name} has type annotation Tensor but got type {type(data)}")
        logger.debug(f"data {data_name} has dynamic dims {dynamic_dims} for type {tp}")
        paddle.jit.marker.dynamic_dims(data, dynamic_dims)


@DynamicDimTypeResolver.register_resolver
class TensorImplicitFirstDimOnlyDynamicDimTypeResolver(DynamicDimTypeResolver):
    def type_match(self, tp: type[Any]) -> bool:
        return tp is Tensor

    def resolve(self, data: Any, data_name: str, tp: type[Any]) -> None:
        # Tensor annotation has implicit dynamic_dims=(0, )
        dynamic_dims = (0,)
        if not isinstance(data, Tensor):
            raise TypeError(f"data {data_name} has type annotation Tensor but got type {type(data)}")
        logger.debug(f"data {data_name} has dynamic dims {dynamic_dims} for type {tp}")
        paddle.jit.marker.dynamic_dims(data, dynamic_dims)


def resolve_dynamic_dims(arg: Any, arg_name: str, annotation: type[Any]) -> None:
    DynamicDimTypeResolver().generic_resolve(arg, arg_name, annotation)

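The resolvers above compose: the dataclass, Optional, and list resolvers peel structure off containers and hand each leaf to the two Tensor resolvers, which call `paddle.jit.marker.dynamic_dims`. A small self-contained sketch of that traversal under an invented `Batch` dataclass; the field names and shapes are illustrative only and are not part of the commit:

    import dataclasses
    from typing import Annotated, Optional

    import paddle
    from paddle import Tensor

    from fastdeploy.model_executor.graph_optimization.dynamic_dims_marker import (
        DynamicDims,
        resolve_dynamic_dims,
    )


    @dataclasses.dataclass
    class Batch:  # hypothetical container, not part of the commit
        token_ids: Tensor                                # implicit dynamic dim 0
        lengths: Annotated[Tensor, DynamicDims((0, 1))]  # explicitly marked dims
        mask: Optional[Tensor] = None                    # skipped while None
        extras: list[Tensor] = dataclasses.field(default_factory=list)


    batch = Batch(
        token_ids=paddle.zeros([16, 64], dtype="int64"),
        lengths=paddle.zeros([16, 8]),
    )
    # DataClass -> (Optional / List ->) Tensor resolvers; every matching leaf
    # ends in a paddle.jit.marker.dynamic_dims(...) call.
    resolve_dynamic_dims(batch, "batch", Batch)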

@@ -14,14 +14,101 @@
 # limitations under the License.
 """
-from typing import Callable, Optional
+import functools
+import inspect
+import types
+from typing import Callable, Optional, TypeVar, get_type_hints
 
-from paddle.jit.dy2static.utils import Backend
+from paddle.jit import sot
+from paddle.jit.dy2static.utils import Backend as ToStaticBackend
+from paddleformers.utils.log import logger
+from typing_extensions import ParamSpec
 
 from fastdeploy.config import FDConfig
 from fastdeploy.model_executor.graph_optimization.cudagraph_piecewise_backend import (
     CudaGraphPiecewiseBackend,
 )
+from fastdeploy.model_executor.graph_optimization.dynamic_dims_marker import (
+    resolve_dynamic_dims,
+)
+
+P = ParamSpec("P")
+T = TypeVar("T")
+
+
+# TODO(SigureMo): Replace this fn with real implementation by DrRyanHuang
+def create_in_warmup_mode():
+    cnt = 0
+
+    def in_warmup_mode():
+        nonlocal cnt
+        cnt += 1
+        return cnt < 32
+
+    return in_warmup_mode
+
+
+in_warmup_mode = create_in_warmup_mode()
+
+
+def apply_to_static_optimization(fn: Callable[P, T], backend: ToStaticBackend) -> Callable[P, T]:
+    forward_fn = fn
+    forward_sig = inspect.signature(forward_fn)
+    forward_type_hints = get_type_hints(forward_fn)
+    static_forward_fn = sot.symbolic_translate(forward_fn, training=False, backend=backend)
+    unsafe_static_forward_fn = None
+    need_warmup = True
+
+    @functools.wraps(forward_fn)
+    def warmup_impl(self, *args, **kwargs):
+        nonlocal unsafe_static_forward_fn, need_warmup
+        bound_args = forward_sig.bind(self, *args, **kwargs)
+        bound_args.apply_defaults()
+        for name, arg in bound_args.arguments.items():
+            if name not in forward_type_hints:
+                continue
+            annotation = forward_type_hints[name]
+            resolve_dynamic_dims(arg, name, annotation)
+        result = static_forward_fn(self, *args, **kwargs)
+        original_code = forward_fn.__code__
+        (new_guarded_codes, _) = sot.opcode_translator.executor.executor_cache.OpcodeExecutorCache().cache[
+            original_code
+        ]
+        # Check has only one graph
+        if len(new_guarded_codes) > 1:
+            logger.warning("Model has multiple generated code, please check all dynamic dim has marked.")
+            unsafe_static_forward_fn = None
+            need_warmup = False
+            return result
+        # Check generated code has no break graph
+        new_code = new_guarded_codes[0][0][0]
+        if any(name.startswith("$") for name in new_code.co_names):  # TODO(SigureMo): It's a internal impl
+            logger.warning("Model has breakgraph, please set env SOT_LOG_LEVEL=3 to check it.")
+            unsafe_static_forward_fn = None
+            need_warmup = False
+            return result
+        unsafe_static_forward_fn = types.FunctionType(
+            new_code,
+            forward_fn.__globals__,
+            forward_fn.__name__,
+            forward_fn.__defaults__,
+            forward_fn.__closure__,
+        )
+        return result
+
+    @functools.wraps(forward_fn)
+    def static_forward(self, *args, **kwargs):
+        nonlocal need_warmup
+        is_warmup = in_warmup_mode() and need_warmup
+        if is_warmup:
+            return warmup_impl(self, *args, **kwargs)
+        nonlocal unsafe_static_forward_fn
+        if unsafe_static_forward_fn is None:
+            return static_forward_fn(self, *args, **kwargs)
+        return unsafe_static_forward_fn(self, *args, **kwargs)
+
+    return static_forward
+
 
 class GraphOptBackend:
@@ -42,10 +129,14 @@ class GraphOptBackend:
         # 1. Prepare cuda grpah input buffers (contain output of subgraphs)
         # 2. Convert dynamic grpah to static graph
-        from paddle.jit import sot
-
-        backend = Backend.CINN if self.fd_config.graph_opt_config.graph_opt_level > 1 else Backend.PHI
-        self.runnable = sot.symbolic_translate(self.runnable, training=False, backend=backend)
+        backend = (
+            ToStaticBackend.CINN if self.fd_config.graph_opt_config.graph_opt_level > 1 else ToStaticBackend.PHI
+        )
+        self.runnable = apply_to_static_optimization(
+            self.runnable.__func__,
+            backend,
+        ).__get__(self.runnable.__self__)
 
     def __call__(self, **kwargs):
         if not self.fd_config.graph_opt_config.use_cudagraph:

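The `apply_to_static_optimization(...).__get__(...)` call at the bottom of the hunk wraps the underlying function of a bound method and then rebinds the wrapper to the same instance. A minimal standalone sketch of that pattern; the class and function names are illustrative, not FastDeploy code:

    import functools


    class Model:  # stand-in for the object that owns self.runnable
        def forward(self, x):
            return x * 2


    def wrap(fn):
        @functools.wraps(fn)
        def inner(self, *args, **kwargs):
            # warmup bookkeeping / dynamic-dim resolution would happen here
            return fn(self, *args, **kwargs)

        return inner


    m = Model()
    bound = m.forward                                        # bound method
    rebound = wrap(bound.__func__).__get__(bound.__self__)   # unwrap, wrap, rebind
    assert rebound(3) == 6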

@@ -13,6 +13,7 @@
 # limitations under the License.
 
 from .append_attn_backend import AppendAttentionBackend
+from .attention import Attention
 from .attention_selecter import get_attention_backend
 from .base_attention_backend import AttentionBackend
 from .block_multihead_attn_backend import BlockAttentionBackend
@@ -32,4 +33,5 @@ __all__ = [
     "FlashAttentionBackend",
     "IluvatarAttnBackend",
     "BlockAttentionBackend",
+    "Attention",
 ]


@@ -66,13 +66,13 @@ class AppendAttentionMetadata(AttentionMetadata):
     block_tables: Optional[paddle.Tensor] = None
     rotary_embs: Optional[paddle.Tensor] = None
     attn_mask: Optional[paddle.Tensor] = None
-    encoder_block_shape_q: Optional[paddle.Tensor] = None
-    decoder_block_shape_q: Optional[paddle.Tensor] = None
+    encoder_block_shape_q: int = -1
+    decoder_block_shape_q: int = -1
     _fuse_kernel_compute_dtype: str = "bf16"
 
     # pd_disaggregation
     kv_signal_metadata: Optional[paddle.Tensor] = None
-    kv_signal_data_list: List[paddle.Tensor] = field(default_factory=list)
+    kv_signal_data_list: List[Optional[paddle.Tensor]] = field(default_factory=list)
 
 
 class AppendAttentionBackend(AttentionBackend):
@@ -80,6 +80,9 @@ class AppendAttentionBackend(AttentionBackend):
     AppendAttentionBackend backend implementation.
     """
 
+    __infer_dynamic_dims_fields__ = ["attention_metadata"]
+    attention_metadata: AppendAttentionMetadata
+
     def __init__(
         self,
         fd_config: FDConfig,

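The `__infer_dynamic_dims_fields__` attribute plus the class-level `attention_metadata` annotation is what lets `ManualMarkedInnerFieldsDynamicDimTypeResolver` descend from a non-dataclass backend object into its metadata dataclass; the same pattern is repeated for the other attention backends below. A condensed sketch of the mechanism using toy classes rather than the real backend:

    import dataclasses
    from typing import Optional

    import paddle
    from paddle import Tensor

    from fastdeploy.model_executor.graph_optimization.dynamic_dims_marker import (
        resolve_dynamic_dims,
    )


    @dataclasses.dataclass
    class ToyMetadata:  # plays the role of AppendAttentionMetadata
        block_tables: Optional[Tensor] = None
        encoder_block_shape_q: int = -1  # plain int: no resolver matches, left alone


    class ToyBackend:  # plays the role of AppendAttentionBackend
        __infer_dynamic_dims_fields__ = ["attention_metadata"]
        attention_metadata: ToyMetadata

        def __init__(self):
            self.attention_metadata = ToyMetadata(
                block_tables=paddle.zeros([4, 32], dtype="int32")
            )


    backend = ToyBackend()
    # The manual-fields resolver reads the class annotation for attention_metadata,
    # then the dataclass resolver marks dim 0 of block_tables as dynamic.
    resolve_dynamic_dims(backend, "backend", ToyBackend)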

@@ -56,13 +56,13 @@ class BlockAttentionMetadata(AttentionMetadata):
     block_tables: Optional[paddle.Tensor] = None
     rotary_embs: Optional[paddle.Tensor] = None
     attn_mask: Optional[paddle.Tensor] = None
-    encoder_block_shape_q: Optional[paddle.Tensor] = None
-    decoder_block_shape_q: Optional[paddle.Tensor] = None
+    encoder_block_shape_q: int = -1
+    decoder_block_shape_q: int = -1
     _fuse_kernel_compute_dtype: str = "bf16"
 
     # pd_disaggregation
     kv_signal_metadata: Optional[paddle.Tensor] = None
-    kv_signal_data_list: List[paddle.Tensor] = field(default_factory=list)
+    kv_signal_data_list: List[Optional[paddle.Tensor]] = field(default_factory=list)
 
 
 class BlockAttentionBackend(AttentionBackend):
@@ -70,6 +70,9 @@ class BlockAttentionBackend(AttentionBackend):
     BlockAttentionBackend backend implementation.
     """
 
+    __infer_dynamic_dims_fields__ = ["attention_metadata"]
+    attention_metadata: BlockAttentionBackend
+
     def __init__(
         self,
         fd_config: FDConfig,


@@ -66,8 +66,8 @@ class FlashAttentionMetadata(AttentionMetadata):
     decoder_tile_ids_per_batch: paddle.Tensor = None
     decoder_num_blocks: paddle.Tensor = None
-    encoder_block_shape_q: Optional[paddle.Tensor] = None
-    decoder_block_shape_q: Optional[paddle.Tensor] = None
+    encoder_block_shape_q: int = -1
+    decoder_block_shape_q: int = -1
 
     cu_seqlens_q: paddle.Tensor = None
     cu_seqlens_k: paddle.Tensor = None
@@ -81,7 +81,7 @@ class FlashAttentionMetadata(AttentionMetadata):
     # pd_disaggregation
     kv_signal_metadata: Optional[paddle.Tensor] = None
-    kv_signal_data_list: List[paddle.Tensor] = field(default_factory=list)
+    kv_signal_data_list: List[Optional[paddle.Tensor]] = field(default_factory=list)
 
 
 class FlashAttentionBackend(AttentionBackend):
@@ -89,6 +89,9 @@ class FlashAttentionBackend(AttentionBackend):
     FlashAttentionBackend backend implementation
     """
 
+    __infer_dynamic_dims_fields__ = ["attention_metadata"]
+    attention_metadata: FlashAttentionMetadata
+
     def __init__(
         self,
         fd_config: FDConfig,


@@ -82,13 +82,13 @@ class MLAAttentionMetadata(AttentionMetadata):
     block_tables: Optional[paddle.Tensor] = None
     rotary_embs: Optional[paddle.Tensor] = None
     attn_mask: Optional[paddle.Tensor] = None
-    encoder_block_shape_q: Optional[paddle.Tensor] = None
-    decoder_block_shape_q: Optional[paddle.Tensor] = None
+    encoder_block_shape_q: int = -1
+    decoder_block_shape_q: int = -1
     _fuse_kernel_compute_dtype: str = "bf16"
 
     # pd_disaggregation
     kv_signal_metadata: Optional[paddle.Tensor] = None
-    kv_signal_data_list: List[paddle.Tensor] = field(default_factory=list)
+    kv_signal_data_list: List[Optional[paddle.Tensor]] = field(default_factory=list)
 
 
 class MLAAttentionBackend(AttentionBackend):
@@ -96,6 +96,9 @@ class MLAAttentionBackend(AttentionBackend):
     MLA Attention Backend implementation.
     """
 
+    __infer_dynamic_dims_fields__ = ["attention_metadata"]
+    attention_metadata: MLAAttentionMetadata
+
     def __init__(
         self,
         fd_config: FDConfig,


@@ -62,13 +62,13 @@ class XPUAttentionMetadata(AttentionMetadata):
     block_tables: Optional[paddle.Tensor] = None
     rotary_embs: Optional[paddle.Tensor] = None
     attn_mask: Optional[paddle.Tensor] = None
-    encoder_block_shape_q: Optional[paddle.Tensor] = None
-    decoder_block_shape_q: Optional[paddle.Tensor] = None
+    encoder_block_shape_q: int = -1
+    decoder_block_shape_q: int = -1
     _fuse_kernel_compute_dtype: str = "bf16"
 
     # pd_disaggregation
     kv_signal_metadata: Optional[paddle.Tensor] = None
-    kv_signal_data_list: List[paddle.Tensor] = field(default_factory=list)
+    kv_signal_data_list: List[Optional[paddle.Tensor]] = field(default_factory=list)
 
 
 class XPUAttentionBackend(AttentionBackend):
@@ -76,6 +76,9 @@ class XPUAttentionBackend(AttentionBackend):
     XPUAttentionBackend backend implementation.
     """
 
+    __infer_dynamic_dims_fields__ = ["attention_metadata"]
+    attention_metadata: XPUAttentionMetadata
+
     def __init__(
         self,
         fd_config: FDConfig,


@@ -70,6 +70,9 @@ class GCUFlashAttnBackend(AttentionBackend):
     GCUFlashAttnBackend backend implementation.
     """
 
+    __infer_dynamic_dims_fields__ = ["attention_metadata"]
+    attention_metadata: GCUFlashAttnBackend
+
     def __init__(
         self,
         fd_config: FDConfig,


@@ -406,7 +406,7 @@ class Ernie4_5_VLModel(nn.Layer):
     def forward(
         self,
         ids_remove_padding: paddle.Tensor,
-        image_features: paddle.Tensor,
+        image_features: Optional[paddle.Tensor],
         forward_meta: ForwardMeta,
     ):
         text_input = None
@@ -584,7 +584,7 @@ class Ernie4_5_VLMoeForConditionalGeneration(ModelForCasualLM):
     def forward(
         self,
         ids_remove_padding: paddle.Tensor,
-        image_features: paddle.Tensor,
+        image_features: Optional[paddle.Tensor],
         forward_meta: ForwardMeta,
     ):
         hidden_states = self.ernie(


@@ -536,7 +536,7 @@ def parse_args():
         "--graph_optimization_config",
         type=json.loads,
         default=None,
-        help=" Configation of Graph optimization backend. ",
+        help="Configation of Graph optimization backend.",
     )
     parser.add_argument(
         "--guided_decoding_backend",