Mirror of https://github.com/PaddlePaddle/FastDeploy.git (synced 2025-09-27 21:02:24 +08:00)
[SOT] Mark dynamic dims by type annotations (#2771)
Some checks failed: Deploy GitHub Pages / deploy (push) has been cancelled
* [SOT] Mark dynamic dims by type annotations
* fix conflict of forward_meta
* mark more attn backend
* fix missing annotated and add env SOT_SPECIALIZED_DIM_NUMBERS
* auto infer implicit 0 dim dynamic dim
* revert manual marked dims
* revert missing update
* auto infer can use unsafe code in warmup stage
* check -> type_match
* fix codestyle
* restore blank line
* empty commit
* add need_warmup nonlocal;
* add doc for resolver
* add missing type hints
* unquote "ForwardMeta"
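
The sketch below illustrates the annotation style this series introduces; it is not part of the commit itself. It assumes the dynamic_dims_marker module added in this diff is importable and that the installed Paddle provides paddle.jit.marker.dynamic_dims; DemoMetadata and DemoBackend are made-up names used only for illustration.

    from dataclasses import dataclass
    from typing import Annotated, Optional

    import paddle
    from paddle import Tensor

    from fastdeploy.model_executor.graph_optimization.dynamic_dims_marker import (
        DynamicDims,
        resolve_dynamic_dims,
    )


    @dataclass
    class DemoMetadata:
        # A bare Tensor annotation gets an implicit dynamic dim 0.
        rotary_embs: Optional[Tensor] = None
        # DynamicDims marks the listed axes explicitly (here dims 0 and 1).
        attn_mask: Annotated[Tensor, DynamicDims((0, 1))] = None
        # Plain ints are ignored by the resolvers and stay specialized.
        encoder_block_shape_q: int = -1


    class DemoBackend:
        # Non-dataclass objects list the attributes the resolver should descend into.
        __infer_dynamic_dims_fields__ = ["attention_metadata"]
        attention_metadata: DemoMetadata


    backend = DemoBackend()
    backend.attention_metadata = DemoMetadata(
        rotary_embs=paddle.rand([4, 32]),
        attn_mask=paddle.ones([4, 4]),
    )
    # Walks the annotations and calls paddle.jit.marker.dynamic_dims on each matched Tensor.
    resolve_dynamic_dims(backend, "backend", DemoBackend)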
@@ -335,11 +335,11 @@ class GraphOptimizationConfig:
         cudagraph_splitting_ops = ["paddle.unified_attention"]
 
     Note: If want to use subgraph capture functionality in a dynamic graph,
-    can manually split the model into multiple layers and apply the @support_cuda_graph decorator
+    can manually split the model into multiple layers and apply the @support_graph_optimization decorator
     only to the layer where CUDA graph functionality is required.
     """
-    cudagraph_splitting_ops = Optional[list[str]]
-    """" Whether to use a full cuda graph for the entire forward pass rather than
+    cudagraph_splitting_ops: list[str] = field(default_factory=list)
+    """ Whether to use a full cuda graph for the entire forward pass rather than
     splitting certain operations such as attention into subgraphs.
     Thus this flag cannot be used together with splitting_ops."""
     full_cuda_graph: bool = True
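
For context, a self-contained sketch of how the two fields touched in this hunk relate. GraphOptConfigSketch is a stand-in defined here for illustration only; it is not the real GraphOptimizationConfig, which has more fields and defaults.

    from dataclasses import dataclass, field


    @dataclass
    class GraphOptConfigSketch:
        # Ops at which the captured graph is split into subgraphs.
        cudagraph_splitting_ops: list[str] = field(default_factory=list)
        # Capture the whole forward pass as one CUDA graph; per the docstring,
        # this cannot be combined with splitting_ops.
        full_cuda_graph: bool = True


    cfg = GraphOptConfigSketch(
        cudagraph_splitting_ops=["paddle.unified_attention"],
        full_cuda_graph=False,  # must be disabled once splitting_ops is non-empty
    )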
@@ -937,11 +937,11 @@ class LLMEngine:
             "SOT_LOG_LEVEL": os.getenv("SOT_LOG_LEVEL", default="0"),
             "SOT_UNSAFE_CACHE_FASTPATH": os.getenv("SOT_UNSAFE_CACHE_FASTPATH", default="1"),
             "SOT_ENABLE_0_SIZE_FALLBACK": os.getenv("SOT_ENABLE_0_SIZE_FALLBACK", default="0"),
+            "SOT_SPECIALIZED_DIM_NUMBERS": os.getenv("SOT_SPECIALIZED_DIM_NUMBERS", default="no"),
             "FLAGS_specialize_device_in_dy2st": os.getenv("FLAGS_specialize_device_in_dy2st", default="1"),
             "FLAGS_enable_async_fast_gc": os.getenv("FLAGS_enable_async_fast_gc", default="0"),
             "FLAGS_pir_interpreter_record_stream_for_gc_cache": os.getenv(
-                "FLAGS_pir_interpreter_record_stream_for_gc_cache",
-                default="1",
+                "FLAGS_pir_interpreter_record_stream_for_gc_cache", default="1"
             ),
             "FLAGS_parameters_persistent_mode_in_dy2st": os.getenv(
                 "FLAGS_parameters_persistent_mode_in_dy2st", default="1"
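
Because each entry is read with os.getenv before being handed to the worker environment, any of these switches can be overridden from the launching process. What values other than the default "no" do for SOT_SPECIALIZED_DIM_NUMBERS is not documented in this diff; the snippet below is only a minimal sketch of overriding the switches.

    import os

    # Set before the engine/worker processes are created; unset variables fall
    # back to the defaults shown in the hunk above.
    os.environ["SOT_LOG_LEVEL"] = "3"                 # verbose SOT logs, useful for spotting break graphs
    os.environ["SOT_SPECIALIZED_DIM_NUMBERS"] = "no"  # default introduced by this commit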
@@ -0,0 +1,191 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""

from __future__ import annotations

import dataclasses
import typing
from abc import abstractmethod
from collections.abc import Callable
from functools import partial
from typing import Annotated, Any, TypeVar, Union, get_origin, get_type_hints

import paddle
from paddle import Tensor
from paddleformers.utils.log import logger
from typing_extensions import TypeAlias

T = TypeVar("T")
U = TypeVar("U")

Accessor: TypeAlias = Callable[[T], U]


class DynamicDims:
    def __init__(self, dims: int | tuple[int]):
        self.dims = dims if isinstance(dims, tuple) else (dims,)

    def __repr__(self):
        return f"DynamicDims({self.dims})"


class DynamicDimTypeResolver:
    """
    Base class for dynamic dimension type resolvers.
    This class provides a mechanism to register and resolve dynamic dimensions
    based on type annotations. It uses a registry pattern to allow multiple
    resolvers to be registered and used in a flexible manner.
    """

    ALL_DYNAMIC_DIM_TYPE_RESOLVERS = []

    @classmethod
    def register_resolver(cls, resolver_cls: type[DynamicDimTypeResolver]):
        cls.ALL_DYNAMIC_DIM_TYPE_RESOLVERS.append(resolver_cls())
        return resolver_cls

    @abstractmethod
    def type_match(self, tp: type[Any]) -> bool:
        raise NotImplementedError

    @abstractmethod
    def extract_inner_types(
        self, data: Any, data_name: str, tp: type[Any]
    ) -> list[tuple[Accessor[Any, Any], str, type[Any]]]:
        raise NotImplementedError

    def resolve(self, data: Any, data_name: str, tp: type[Any]) -> None:
        inner_types = self.extract_inner_types(data, data_name, tp)
        for accessor, inner_data_name, inner_type in inner_types:
            self.generic_resolve(accessor(data), inner_data_name, inner_type)

    def generic_resolve(self, data: Any, data_name: str, tp: type[Any]) -> None:
        for resolver in self.ALL_DYNAMIC_DIM_TYPE_RESOLVERS:
            if resolver.type_match(tp):
                return resolver.resolve(data, data_name, tp)
            runtime_tp = type(data)
            if runtime_tp is not tp and resolver.type_match(runtime_tp):
                return resolver.resolve(data, data_name, runtime_tp)
        else:
            logger.debug(f"No resolver found for type {tp} and data {data_name}")


@DynamicDimTypeResolver.register_resolver
class DataClassDynamicDimTypeResolver(DynamicDimTypeResolver):
    def type_match(self, tp: type[Any]) -> bool:
        return dataclasses.is_dataclass(tp) and isinstance(tp, type)

    def extract_inner_types(
        self, data: Any, data_name: str, tp: type[Any]
    ) -> list[tuple[Accessor[Any, Any], str, type[Any]]]:
        type_hints = get_type_hints(tp, include_extras=True)
        return [  # type: ignore
            (
                # bind name by partial to avoid capture wrong free vars
                partial(lambda name, dt: getattr(dt, name), field.name),
                f"{data_name}.{field.name}",
                type_hints[field.name],
            )
            for field in dataclasses.fields(tp)
        ]


@DynamicDimTypeResolver.register_resolver
class OptionalDynamicDimTypeResolver(DynamicDimTypeResolver):
    def type_match(self, tp) -> bool:
        return get_origin(tp) is Union and len(tp.__args__) == 2 and tp.__args__[1] is type(None)  # noqa: E721

    def extract_inner_types(
        self, data: Any, data_name: str, tp: type[Any]
    ) -> list[tuple[Accessor[Any, Any], str, type[Any]]]:
        if data is None:
            return []
        inner_type = tp.__args__[0]
        return [(lambda x: x, data_name, inner_type)]  # No accessor needed for Optional


@DynamicDimTypeResolver.register_resolver
class ListDynamicDimTypeResolver(DynamicDimTypeResolver):
    def type_match(self, tp: type[Any]) -> bool:
        return get_origin(tp) is list

    def extract_inner_types(
        self, data: Any, data_name: str, tp: type[Any]
    ) -> list[tuple[Accessor[Any, Any], str, type[Any]]]:
        if not data:
            return []
        inner_type = typing.get_args(tp)[0] if tp.__args__ else Any
        return [(partial(lambda i, x: x[i], i), f"{data_name}[{i}]", inner_type) for i in range(len(data))]  # type: ignore


@DynamicDimTypeResolver.register_resolver
class ManualMarkedInnerFieldsDynamicDimTypeResolver(DynamicDimTypeResolver):
    INFER_DYNAMIC_DIMS_FIELDS_ATTR_NAME = "__infer_dynamic_dims_fields__"

    def type_match(self, tp: type[Any]) -> bool:
        return hasattr(tp, ManualMarkedInnerFieldsDynamicDimTypeResolver.INFER_DYNAMIC_DIMS_FIELDS_ATTR_NAME)

    def extract_inner_types(
        self, data: Any, data_name: str, tp: type[Any]
    ) -> list[tuple[Accessor[Any, Any], str, type[Any]]]:
        fields = getattr(tp, ManualMarkedInnerFieldsDynamicDimTypeResolver.INFER_DYNAMIC_DIMS_FIELDS_ATTR_NAME)
        if isinstance(fields, str):
            raise TypeError(
                f"{ManualMarkedInnerFieldsDynamicDimTypeResolver.INFER_DYNAMIC_DIMS_FIELDS_ATTR_NAME} should be tuple, but got {type(fields)}"
            )
        inner_types_dict = typing.get_type_hints(tp)
        return [
            (partial(lambda name, x: getattr(x, name), field_name), f"{data_name}.{field_name}", inner_type)
            for field_name, inner_type in inner_types_dict.items()
        ]


@DynamicDimTypeResolver.register_resolver
class AnnotatedTensorDynamicDimTypeResolver(DynamicDimTypeResolver):
    def type_match(self, tp: type[Any]) -> bool:
        return get_origin(tp) is Annotated and typing.get_args(tp)[0] is Tensor

    def resolve(self, data: Any, data_name: str, tp: type[Any]) -> None:
        base_type, *metadata = typing.get_args(tp)
        # Filter out DynamicDims instances
        dynamic_dims = [m for m in metadata if isinstance(m, DynamicDims)]
        if not dynamic_dims:
            return
        if len(dynamic_dims) > 1:
            raise ValueError("Multiple DynamicDims annotations found. Only one is allowed.")
        dynamic_dims = dynamic_dims[0].dims
        if not isinstance(data, Tensor):
            raise TypeError(f"data {data_name} has type annotation Tensor but got type {type(data)}")
        logger.debug(f"data {data_name} has dynamic dims {dynamic_dims} for type {tp}")
        paddle.jit.marker.dynamic_dims(data, dynamic_dims)


@DynamicDimTypeResolver.register_resolver
class TensorImplicitFirstDimOnlyDynamicDimTypeResolver(DynamicDimTypeResolver):
    def type_match(self, tp: type[Any]) -> bool:
        return tp is Tensor

    def resolve(self, data: Any, data_name: str, tp: type[Any]) -> None:
        # Tensor annotation has implicit dynamic_dims=(0, )
        dynamic_dims = (0,)
        if not isinstance(data, Tensor):
            raise TypeError(f"data {data_name} has type annotation Tensor but got type {type(data)}")
        logger.debug(f"data {data_name} has dynamic dims {dynamic_dims} for type {tp}")
        paddle.jit.marker.dynamic_dims(data, dynamic_dims)


def resolve_dynamic_dims(arg: Any, arg_name: str, annotation: type[Any]) -> None:
    DynamicDimTypeResolver().generic_resolve(arg, arg_name, annotation)
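
The registry above is open for extension: any class decorated with DynamicDimTypeResolver.register_resolver that implements type_match plus either extract_inner_types or resolve is consulted by generic_resolve. Below is a hypothetical resolver for homogeneous tuple[T, ...] annotations, mirroring ListDynamicDimTypeResolver; it is not part of this commit and only assumes the import path introduced above.

    import typing
    from functools import partial
    from typing import Any

    from fastdeploy.model_executor.graph_optimization.dynamic_dims_marker import (
        DynamicDimTypeResolver,
    )


    @DynamicDimTypeResolver.register_resolver
    class TupleDynamicDimTypeResolver(DynamicDimTypeResolver):
        """Hypothetical: descend into tuple[T, ...] the same way lists are handled."""

        def type_match(self, tp: type[Any]) -> bool:
            args = typing.get_args(tp)
            return typing.get_origin(tp) is tuple and len(args) == 2 and args[1] is Ellipsis

        def extract_inner_types(self, data, data_name, tp):
            if not data:
                return []
            inner_type = typing.get_args(tp)[0]
            return [(partial(lambda i, x: x[i], i), f"{data_name}[{i}]", inner_type) for i in range(len(data))]

resolve_dynamic_dims at the bottom of the module is the single entry point the graph-optimization backend calls for each annotated forward argument during warmup, as shown in the next file.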
@@ -14,14 +14,101 @@
 # limitations under the License.
 """
 
-from typing import Callable, Optional
+import functools
+import inspect
+import types
+from typing import Callable, Optional, TypeVar, get_type_hints
 
-from paddle.jit.dy2static.utils import Backend
+from paddle.jit import sot
+from paddle.jit.dy2static.utils import Backend as ToStaticBackend
+from paddleformers.utils.log import logger
+from typing_extensions import ParamSpec
 
 from fastdeploy.config import FDConfig
 from fastdeploy.model_executor.graph_optimization.cudagraph_piecewise_backend import (
     CudaGraphPiecewiseBackend,
 )
+from fastdeploy.model_executor.graph_optimization.dynamic_dims_marker import (
+    resolve_dynamic_dims,
+)
+
+P = ParamSpec("P")
+T = TypeVar("T")
+
+
+# TODO(SigureMo): Replace this fn with real implementation by DrRyanHuang
+def create_in_warmup_mode():
+    cnt = 0
+
+    def in_warmup_mode():
+        nonlocal cnt
+        cnt += 1
+        return cnt < 32
+
+    return in_warmup_mode
+
+
+in_warmup_mode = create_in_warmup_mode()
+
+
+def apply_to_static_optimization(fn: Callable[P, T], backend: ToStaticBackend) -> Callable[P, T]:
+    forward_fn = fn
+    forward_sig = inspect.signature(forward_fn)
+    forward_type_hints = get_type_hints(forward_fn)
+    static_forward_fn = sot.symbolic_translate(forward_fn, training=False, backend=backend)
+    unsafe_static_forward_fn = None
+    need_warmup = True
+
+    @functools.wraps(forward_fn)
+    def warmup_impl(self, *args, **kwargs):
+        nonlocal unsafe_static_forward_fn, need_warmup
+        bound_args = forward_sig.bind(self, *args, **kwargs)
+        bound_args.apply_defaults()
+        for name, arg in bound_args.arguments.items():
+            if name not in forward_type_hints:
+                continue
+            annotation = forward_type_hints[name]
+            resolve_dynamic_dims(arg, name, annotation)
+
+        result = static_forward_fn(self, *args, **kwargs)
+        original_code = forward_fn.__code__
+        (new_guarded_codes, _) = sot.opcode_translator.executor.executor_cache.OpcodeExecutorCache().cache[
+            original_code
+        ]
+        # Check has only one graph
+        if len(new_guarded_codes) > 1:
+            logger.warning("Model has multiple generated code, please check all dynamic dim has marked.")
+            unsafe_static_forward_fn = None
+            need_warmup = False
+            return result
+        # Check generated code has no break graph
+        new_code = new_guarded_codes[0][0][0]
+        if any(name.startswith("$") for name in new_code.co_names):  # TODO(SigureMo): It's a internal impl
+            logger.warning("Model has breakgraph, please set env SOT_LOG_LEVEL=3 to check it.")
+            unsafe_static_forward_fn = None
+            need_warmup = False
+            return result
+        unsafe_static_forward_fn = types.FunctionType(
+            new_code,
+            forward_fn.__globals__,
+            forward_fn.__name__,
+            forward_fn.__defaults__,
+            forward_fn.__closure__,
+        )
+        return result
+
+    @functools.wraps(forward_fn)
+    def static_forward(self, *args, **kwargs):
+        nonlocal need_warmup
+        is_warmup = in_warmup_mode() and need_warmup
+        if is_warmup:
+            return warmup_impl(self, *args, **kwargs)
+        nonlocal unsafe_static_forward_fn
+        if unsafe_static_forward_fn is None:
+            return static_forward_fn(self, *args, **kwargs)
+        return unsafe_static_forward_fn(self, *args, **kwargs)
+
+    return static_forward
+
+
 class GraphOptBackend:
@@ -42,10 +129,14 @@ class GraphOptBackend:
             # 1. Prepare cuda grpah input buffers (contain output of subgraphs)
 
             # 2. Convert dynamic grpah to static graph
-            from paddle.jit import sot
 
-            backend = Backend.CINN if self.fd_config.graph_opt_config.graph_opt_level > 1 else Backend.PHI
-            self.runnable = sot.symbolic_translate(self.runnable, training=False, backend=backend)
+            backend = (
+                ToStaticBackend.CINN if self.fd_config.graph_opt_config.graph_opt_level > 1 else ToStaticBackend.PHI
+            )
+            self.runnable = apply_to_static_optimization(
+                self.runnable.__func__,
+                backend,
+            ).__get__(self.runnable.__self__)
 
     def __call__(self, **kwargs):
         if not self.fd_config.graph_opt_config.use_cudagraph:
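
GraphOptBackend unbinds the forward method (self.runnable.__func__), wraps the plain function with apply_to_static_optimization, and re-binds the wrapper to the original instance through the descriptor protocol. A self-contained sketch of just that re-binding trick, independent of FastDeploy:

    class Model:
        def forward(self, x):
            return x * 2


    def wrap(fn):
        # Stand-in for apply_to_static_optimization: receives the plain function,
        # returns a wrapper that still expects `self` as its first argument.
        def wrapper(self, *args, **kwargs):
            return fn(self, *args, **kwargs)

        return wrapper


    m = Model()
    bound = m.forward
    # Unbind, wrap, then re-bind to the same instance via the descriptor protocol.
    rebound = wrap(bound.__func__).__get__(bound.__self__)
    assert rebound(3) == 6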
@@ -13,6 +13,7 @@
 # limitations under the License.
 
 from .append_attn_backend import AppendAttentionBackend
+from .attention import Attention
 from .attention_selecter import get_attention_backend
 from .base_attention_backend import AttentionBackend
 from .block_multihead_attn_backend import BlockAttentionBackend
@@ -32,4 +33,5 @@ __all__ = [
     "FlashAttentionBackend",
     "IluvatarAttnBackend",
     "BlockAttentionBackend",
+    "Attention",
 ]
@@ -66,13 +66,13 @@ class AppendAttentionMetadata(AttentionMetadata):
     block_tables: Optional[paddle.Tensor] = None
     rotary_embs: Optional[paddle.Tensor] = None
     attn_mask: Optional[paddle.Tensor] = None
-    encoder_block_shape_q: Optional[paddle.Tensor] = None
-    decoder_block_shape_q: Optional[paddle.Tensor] = None
+    encoder_block_shape_q: int = -1
+    decoder_block_shape_q: int = -1
     _fuse_kernel_compute_dtype: str = "bf16"
 
     # pd_disaggregation
     kv_signal_metadata: Optional[paddle.Tensor] = None
-    kv_signal_data_list: List[paddle.Tensor] = field(default_factory=list)
+    kv_signal_data_list: List[Optional[paddle.Tensor]] = field(default_factory=list)
 
 
 class AppendAttentionBackend(AttentionBackend):
@@ -80,6 +80,9 @@ class AppendAttentionBackend(AttentionBackend):
     AppendAttentionBackend backend implementation.
     """
 
+    __infer_dynamic_dims_fields__ = ["attention_metadata"]
+    attention_metadata: AppendAttentionMetadata
+
     def __init__(
         self,
         fd_config: FDConfig,
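
The class-level attention_metadata annotation plus __infer_dynamic_dims_fields__ lets the resolvers walk backend -> metadata dataclass -> fields, which is presumably why encoder_block_shape_q/decoder_block_shape_q become plain int (left alone by the resolvers, so they stay specialized) and kv_signal_data_list becomes List[Optional[paddle.Tensor]]: the list can hold None entries, and the Optional wrapper lets the resolver skip those instead of failing on them. A small sketch of that last point, assuming the module path from this diff and a Paddle build with paddle.jit.marker.dynamic_dims:

    from typing import List, Optional

    import paddle

    from fastdeploy.model_executor.graph_optimization.dynamic_dims_marker import resolve_dynamic_dims

    signals = [paddle.zeros([8], dtype="int32"), None]

    # With List[Optional[paddle.Tensor]] the Optional resolver skips the None entry
    # and marks dim 0 of the real tensor; with List[paddle.Tensor] the implicit
    # dim-0 Tensor resolver would raise a TypeError when it reaches the None.
    resolve_dynamic_dims(signals, "kv_signal_data_list", List[Optional[paddle.Tensor]])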
@@ -56,13 +56,13 @@ class BlockAttentionMetadata(AttentionMetadata):
     block_tables: Optional[paddle.Tensor] = None
     rotary_embs: Optional[paddle.Tensor] = None
     attn_mask: Optional[paddle.Tensor] = None
-    encoder_block_shape_q: Optional[paddle.Tensor] = None
-    decoder_block_shape_q: Optional[paddle.Tensor] = None
+    encoder_block_shape_q: int = -1
+    decoder_block_shape_q: int = -1
     _fuse_kernel_compute_dtype: str = "bf16"
 
     # pd_disaggregation
     kv_signal_metadata: Optional[paddle.Tensor] = None
-    kv_signal_data_list: List[paddle.Tensor] = field(default_factory=list)
+    kv_signal_data_list: List[Optional[paddle.Tensor]] = field(default_factory=list)
 
 
 class BlockAttentionBackend(AttentionBackend):
@@ -70,6 +70,9 @@ class BlockAttentionBackend(AttentionBackend):
     BlockAttentionBackend backend implementation.
     """
 
+    __infer_dynamic_dims_fields__ = ["attention_metadata"]
+    attention_metadata: BlockAttentionBackend
+
     def __init__(
         self,
         fd_config: FDConfig,
@@ -66,8 +66,8 @@ class FlashAttentionMetadata(AttentionMetadata):
     decoder_tile_ids_per_batch: paddle.Tensor = None
     decoder_num_blocks: paddle.Tensor = None
 
-    encoder_block_shape_q: Optional[paddle.Tensor] = None
-    decoder_block_shape_q: Optional[paddle.Tensor] = None
+    encoder_block_shape_q: int = -1
+    decoder_block_shape_q: int = -1
 
     cu_seqlens_q: paddle.Tensor = None
     cu_seqlens_k: paddle.Tensor = None
@@ -81,7 +81,7 @@ class FlashAttentionMetadata(AttentionMetadata):
 
     # pd_disaggregation
     kv_signal_metadata: Optional[paddle.Tensor] = None
-    kv_signal_data_list: List[paddle.Tensor] = field(default_factory=list)
+    kv_signal_data_list: List[Optional[paddle.Tensor]] = field(default_factory=list)
 
 
 class FlashAttentionBackend(AttentionBackend):
@@ -89,6 +89,9 @@ class FlashAttentionBackend(AttentionBackend):
     FlashAttentionBackend backend implementation
     """
 
+    __infer_dynamic_dims_fields__ = ["attention_metadata"]
+    attention_metadata: FlashAttentionMetadata
+
     def __init__(
         self,
         fd_config: FDConfig,
@@ -82,13 +82,13 @@ class MLAAttentionMetadata(AttentionMetadata):
     block_tables: Optional[paddle.Tensor] = None
     rotary_embs: Optional[paddle.Tensor] = None
     attn_mask: Optional[paddle.Tensor] = None
-    encoder_block_shape_q: Optional[paddle.Tensor] = None
-    decoder_block_shape_q: Optional[paddle.Tensor] = None
+    encoder_block_shape_q: int = -1
+    decoder_block_shape_q: int = -1
     _fuse_kernel_compute_dtype: str = "bf16"
 
     # pd_disaggregation
     kv_signal_metadata: Optional[paddle.Tensor] = None
-    kv_signal_data_list: List[paddle.Tensor] = field(default_factory=list)
+    kv_signal_data_list: List[Optional[paddle.Tensor]] = field(default_factory=list)
 
 
 class MLAAttentionBackend(AttentionBackend):
@@ -96,6 +96,9 @@ class MLAAttentionBackend(AttentionBackend):
     MLA Attention Backend implementation.
     """
 
+    __infer_dynamic_dims_fields__ = ["attention_metadata"]
+    attention_metadata: MLAAttentionMetadata
+
     def __init__(
         self,
         fd_config: FDConfig,
@@ -62,13 +62,13 @@ class XPUAttentionMetadata(AttentionMetadata):
     block_tables: Optional[paddle.Tensor] = None
     rotary_embs: Optional[paddle.Tensor] = None
     attn_mask: Optional[paddle.Tensor] = None
-    encoder_block_shape_q: Optional[paddle.Tensor] = None
-    decoder_block_shape_q: Optional[paddle.Tensor] = None
+    encoder_block_shape_q: int = -1
+    decoder_block_shape_q: int = -1
     _fuse_kernel_compute_dtype: str = "bf16"
 
     # pd_disaggregation
     kv_signal_metadata: Optional[paddle.Tensor] = None
-    kv_signal_data_list: List[paddle.Tensor] = field(default_factory=list)
+    kv_signal_data_list: List[Optional[paddle.Tensor]] = field(default_factory=list)
 
 
 class XPUAttentionBackend(AttentionBackend):
@@ -76,6 +76,9 @@ class XPUAttentionBackend(AttentionBackend):
     XPUAttentionBackend backend implementation.
     """
 
+    __infer_dynamic_dims_fields__ = ["attention_metadata"]
+    attention_metadata: XPUAttentionMetadata
+
     def __init__(
         self,
         fd_config: FDConfig,
@@ -70,6 +70,9 @@ class GCUFlashAttnBackend(AttentionBackend):
     GCUFlashAttnBackend backend implementation.
     """
 
+    __infer_dynamic_dims_fields__ = ["attention_metadata"]
+    attention_metadata: GCUFlashAttnBackend
+
     def __init__(
         self,
         fd_config: FDConfig,
@@ -406,7 +406,7 @@ class Ernie4_5_VLModel(nn.Layer):
     def forward(
         self,
         ids_remove_padding: paddle.Tensor,
-        image_features: paddle.Tensor,
+        image_features: Optional[paddle.Tensor],
         forward_meta: ForwardMeta,
     ):
         text_input = None
@@ -584,7 +584,7 @@ class Ernie4_5_VLMoeForConditionalGeneration(ModelForCasualLM):
     def forward(
         self,
         ids_remove_padding: paddle.Tensor,
-        image_features: paddle.Tensor,
+        image_features: Optional[paddle.Tensor],
         forward_meta: ForwardMeta,
     ):
         hidden_states = self.ernie(
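
Relaxing image_features to Optional[paddle.Tensor] matters once apply_to_static_optimization starts resolving every annotated forward argument during warmup: the Optional resolver turns a text-only batch (None) into a no-op, whereas a bare Tensor annotation would make the implicit dim-0 resolver raise a TypeError on None. A small illustration, under the same assumptions as the earlier sketches:

    from typing import Optional

    import paddle

    from fastdeploy.model_executor.graph_optimization.dynamic_dims_marker import resolve_dynamic_dims

    # Text-only batch: Optional[...] makes this a no-op.
    resolve_dynamic_dims(None, "image_features", Optional[paddle.Tensor])

    # Multimodal batch: dim 0 of the features tensor is marked dynamic.
    resolve_dynamic_dims(paddle.rand([8, 1280]), "image_features", Optional[paddle.Tensor])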
@@ -536,7 +536,7 @@ def parse_args():
         "--graph_optimization_config",
         type=json.loads,
         default=None,
-        help=" Configation of Graph optimization backend. ",
+        help="Configation of Graph optimization backend.",
     )
     parser.add_argument(
         "--guided_decoding_backend",