[SOT] Mark dynamic dims by type annotations (#2771)

* [SOT] Mark dynamic dims by type annotations

* fix conflict of forward_meta

* mark more attn backends

* fix missing annotations and add env SOT_SPECIALIZED_DIM_NUMBERS

* auto-infer implicit dim 0 as a dynamic dim

* revert manually marked dims

* revert missing update

* allow auto-infer to use unsafe code in the warmup stage

* check -> type_match

* fix codestyle

* restore blank line

* empty commit

* add need_warmup nonlocal

* add doc for resolver

* add missing type hints

* unquote "ForwardMeta"
Author: Nyakku Shigure
Date: 2025-07-22 15:23:52 +08:00
Committer: GitHub
Parent: e991777757
Commit: 48e6a0ca26
13 changed files with 330 additions and 28 deletions
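
The pattern the commit applies to every attention backend, condensed into a
minimal sketch (ExampleAttentionMetadata and ExampleAttentionBackend are
illustrative names, not classes from this commit):

    from dataclasses import dataclass, field
    from typing import List, Optional

    import paddle


    @dataclass
    class ExampleAttentionMetadata:
        # Tensor-typed fields are the candidates for dynamic-dim inference...
        block_tables: Optional[paddle.Tensor] = None
        kv_signal_data_list: List[Optional[paddle.Tensor]] = field(default_factory=list)
        # ...while plain-int config fields stay specialized, which is why the
        # commit retypes them from Optional[paddle.Tensor] = None to int = -1.
        encoder_block_shape_q: int = -1
        decoder_block_shape_q: int = -1


    class ExampleAttentionBackend:
        # The marker names the attribute(s) SOT should inspect; the annotation
        # tells the resolver which dataclass to walk for tensor fields.
        __infer_dynamic_dims_fields__ = ["attention_metadata"]
        attention_metadata: ExampleAttentionMetadata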


@@ -13,6 +13,7 @@
 # limitations under the License.
 from .append_attn_backend import AppendAttentionBackend
+from .attention import Attention
 from .attention_selecter import get_attention_backend
 from .base_attention_backend import AttentionBackend
 from .block_multihead_attn_backend import BlockAttentionBackend
@@ -32,4 +33,5 @@ __all__ = [
     "FlashAttentionBackend",
     "IluvatarAttnBackend",
     "BlockAttentionBackend",
+    "Attention",
 ]


@@ -66,13 +66,13 @@ class AppendAttentionMetadata(AttentionMetadata):
     block_tables: Optional[paddle.Tensor] = None
     rotary_embs: Optional[paddle.Tensor] = None
     attn_mask: Optional[paddle.Tensor] = None
-    encoder_block_shape_q: Optional[paddle.Tensor] = None
-    decoder_block_shape_q: Optional[paddle.Tensor] = None
+    encoder_block_shape_q: int = -1
+    decoder_block_shape_q: int = -1
     _fuse_kernel_compute_dtype: str = "bf16"

     # pd_disaggregation
     kv_signal_metadata: Optional[paddle.Tensor] = None
-    kv_signal_data_list: List[paddle.Tensor] = field(default_factory=list)
+    kv_signal_data_list: List[Optional[paddle.Tensor]] = field(default_factory=list)


 class AppendAttentionBackend(AttentionBackend):
@@ -80,6 +80,9 @@ class AppendAttentionBackend(AttentionBackend):
     AppendAttentionBackend backend implementation.
     """

+    __infer_dynamic_dims_fields__ = ["attention_metadata"]
+    attention_metadata: AppendAttentionMetadata
+
     def __init__(
         self,
         fd_config: FDConfig,
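
For context, a hedged sketch of how a resolver might consume these
annotations to auto-infer dim 0 of each tensor as dynamic. The commit's
actual resolver lives in SOT and is not shown in these hunks;
mark_dynamic_dims and the print placeholder are illustrative assumptions:

    import dataclasses

    import paddle


    def mark_dynamic_dims(backend) -> None:
        """Hypothetical resolver: walk each field named in
        __infer_dynamic_dims_fields__ and treat dim 0 of every tensor-valued
        attribute as dynamic ('auto-infer implicit dim 0 as dynamic')."""
        for attr_name in getattr(backend, "__infer_dynamic_dims_fields__", []):
            metadata = getattr(backend, attr_name, None)
            if metadata is None or not dataclasses.is_dataclass(metadata):
                continue
            for f in dataclasses.fields(metadata):
                value = getattr(metadata, f.name)
                if isinstance(value, paddle.Tensor):
                    # The real SOT marking call is not visible in this diff;
                    # a placeholder log stands in for it here.
                    print(f"{attr_name}.{f.name}: dim 0 treated as dynamic")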


@@ -56,13 +56,13 @@ class BlockAttentionMetadata(AttentionMetadata):
     block_tables: Optional[paddle.Tensor] = None
     rotary_embs: Optional[paddle.Tensor] = None
     attn_mask: Optional[paddle.Tensor] = None
-    encoder_block_shape_q: Optional[paddle.Tensor] = None
-    decoder_block_shape_q: Optional[paddle.Tensor] = None
+    encoder_block_shape_q: int = -1
+    decoder_block_shape_q: int = -1
     _fuse_kernel_compute_dtype: str = "bf16"

     # pd_disaggregation
     kv_signal_metadata: Optional[paddle.Tensor] = None
-    kv_signal_data_list: List[paddle.Tensor] = field(default_factory=list)
+    kv_signal_data_list: List[Optional[paddle.Tensor]] = field(default_factory=list)


 class BlockAttentionBackend(AttentionBackend):
@@ -70,6 +70,9 @@ class BlockAttentionBackend(AttentionBackend):
     BlockAttentionBackend backend implementation.
     """

+    __infer_dynamic_dims_fields__ = ["attention_metadata"]
+    attention_metadata: BlockAttentionMetadata
+
     def __init__(
         self,
         fd_config: FDConfig,


@@ -66,8 +66,8 @@ class FlashAttentionMetadata(AttentionMetadata):
     decoder_tile_ids_per_batch: paddle.Tensor = None
     decoder_num_blocks: paddle.Tensor = None
-    encoder_block_shape_q: Optional[paddle.Tensor] = None
-    decoder_block_shape_q: Optional[paddle.Tensor] = None
+    encoder_block_shape_q: int = -1
+    decoder_block_shape_q: int = -1

     cu_seqlens_q: paddle.Tensor = None
     cu_seqlens_k: paddle.Tensor = None
@@ -81,7 +81,7 @@ class FlashAttentionMetadata(AttentionMetadata):
     # pd_disaggregation
     kv_signal_metadata: Optional[paddle.Tensor] = None
-    kv_signal_data_list: List[paddle.Tensor] = field(default_factory=list)
+    kv_signal_data_list: List[Optional[paddle.Tensor]] = field(default_factory=list)


 class FlashAttentionBackend(AttentionBackend):
@@ -89,6 +89,9 @@ class FlashAttentionBackend(AttentionBackend):
     FlashAttentionBackend backend implementation
     """

+    __infer_dynamic_dims_fields__ = ["attention_metadata"]
+    attention_metadata: FlashAttentionMetadata
+
     def __init__(
         self,
         fd_config: FDConfig,


@@ -82,13 +82,13 @@ class MLAAttentionMetadata(AttentionMetadata):
     block_tables: Optional[paddle.Tensor] = None
     rotary_embs: Optional[paddle.Tensor] = None
     attn_mask: Optional[paddle.Tensor] = None
-    encoder_block_shape_q: Optional[paddle.Tensor] = None
-    decoder_block_shape_q: Optional[paddle.Tensor] = None
+    encoder_block_shape_q: int = -1
+    decoder_block_shape_q: int = -1
     _fuse_kernel_compute_dtype: str = "bf16"

     # pd_disaggregation
     kv_signal_metadata: Optional[paddle.Tensor] = None
-    kv_signal_data_list: List[paddle.Tensor] = field(default_factory=list)
+    kv_signal_data_list: List[Optional[paddle.Tensor]] = field(default_factory=list)


 class MLAAttentionBackend(AttentionBackend):
@@ -96,6 +96,9 @@ class MLAAttentionBackend(AttentionBackend):
     MLA Attention Backend implementation.
     """

+    __infer_dynamic_dims_fields__ = ["attention_metadata"]
+    attention_metadata: MLAAttentionMetadata
+
     def __init__(
         self,
         fd_config: FDConfig,


@@ -62,13 +62,13 @@ class XPUAttentionMetadata(AttentionMetadata):
     block_tables: Optional[paddle.Tensor] = None
     rotary_embs: Optional[paddle.Tensor] = None
     attn_mask: Optional[paddle.Tensor] = None
-    encoder_block_shape_q: Optional[paddle.Tensor] = None
-    decoder_block_shape_q: Optional[paddle.Tensor] = None
+    encoder_block_shape_q: int = -1
+    decoder_block_shape_q: int = -1
     _fuse_kernel_compute_dtype: str = "bf16"

     # pd_disaggregation
     kv_signal_metadata: Optional[paddle.Tensor] = None
-    kv_signal_data_list: List[paddle.Tensor] = field(default_factory=list)
+    kv_signal_data_list: List[Optional[paddle.Tensor]] = field(default_factory=list)


 class XPUAttentionBackend(AttentionBackend):
@@ -76,6 +76,9 @@ class XPUAttentionBackend(AttentionBackend):
     XPUAttentionBackend backend implementation.
     """

+    __infer_dynamic_dims_fields__ = ["attention_metadata"]
+    attention_metadata: XPUAttentionMetadata
+
     def __init__(
         self,
         fd_config: FDConfig,
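
The commit message also adds an environment variable,
SOT_SPECIALIZED_DIM_NUMBERS, whose parsing is not visible in these hunks. A
loosely hedged sketch of the kind of switch it suggests (the value format and
the helper below are assumptions, not the commit's actual behavior):

    import os


    def specialized_dim_numbers():
        # Assumption: a comma-separated list of dim indices that must stay
        # specialized (i.e. excluded from dynamic marking); unset falls back
        # to the auto-inferred default. The real format is not shown here.
        raw = os.getenv("SOT_SPECIALIZED_DIM_NUMBERS")
        if raw is None:
            return None
        return {int(dim) for dim in raw.split(",") if dim.strip()}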