[feat] support fa3 backend for pd disaggregated (#2695)
Some checks failed
Deploy GitHub Pages / deploy (push) Has been cancelled

* support fa3 backend run in pd disaggregated

* support fa3 backend run in pd disaggregated

* support fa3 backend run in pd disaggregated

* support fa3 backend run in pd disaggregated

* delete use_fast_ffn
This commit is contained in:
Yuanle Liu
2025-07-03 22:33:27 +08:00
committed by GitHub
parent 00863c43fd
commit 240bdac2a4
26 changed files with 455 additions and 139 deletions

View File

@@ -5,12 +5,6 @@ default_stages:
- pre-commit # Run locally
# - manual # Run in CI
repos:
# 格式化
- repo: https://github.com/google/yapf
rev: v0.43.0
hooks:
- id: yapf
args: [--in-place, --verbose]
# 代码检查
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.11.7
@@ -29,15 +23,6 @@ repos:
rev: 6.0.1
hooks:
- id: isort
# # 格式化
# - repo: https://github.com/pre-commit/mirrors-clang-format
# rev: v20.1.3
# hooks:
# - id: clang-format
# # exclude: '.*'
# types_or: [c++, cuda]
# args: [--style=file, --verbose]
# markdown
- repo: https://github.com/jackdewinter/pymarkdown
rev: v0.9.29

View File

@@ -26,7 +26,7 @@ index 15b22ca..63e7fb7 100644
@@ -1,4 +1,4 @@
-import torch
+import paddle
from . import jit
from .jit_kernels import (
diff --git a/deep_gemm/include/deep_gemm/scheduler.cuh b/deep_gemm/include/deep_gemm/scheduler.cuh
@@ -53,7 +53,7 @@ index c17d466..6fdc52f 100644
-from torch.utils.cpp_extension import CUDA_HOME
+from ..paddle_utils import CUDA_HOME
from typing import Tuple
from . import interleave_ffma
diff --git a/deep_gemm/jit/interleave_ffma.py b/deep_gemm/jit/interleave_ffma.py
index fcb377e..db9d6f3 100644
@@ -65,8 +65,8 @@ index fcb377e..db9d6f3 100644
import subprocess
-from torch.utils.cpp_extension import CUDA_HOME
+from ..paddle_utils import CUDA_HOME
def run_cuobjdump(file_path):
diff --git a/deep_gemm/jit/runtime.py b/deep_gemm/jit/runtime.py
index 66c370a..4761426 100644
@@ -78,7 +78,7 @@ index 66c370a..4761426 100644
-import torch
+import paddle
from typing import Optional
from .template import map_ctype
@@ -35,7 +35,7 @@ class Runtime:
assert len(args) == len(self.args), f'Expected {len(self.args)} arguments, got {len(args)}'
@@ -100,8 +100,8 @@ index ead37f5..51b02c1 100644
-import torch
+import paddle
from typing import Any, Dict, Iterable, Tuple
# Name map for Python `eval`
typename_map: Dict[Any, str] = {
**{t: t.__name__ for t in (bool, int, float)},
@@ -116,15 +116,15 @@ index ead37f5..51b02c1 100644
+ paddle.float8_e4m3fn: 'paddle.float8_e4m3fn',
+ paddle.device.cuda.Stream: "paddle.device.cuda.Stream",
}
# `ctype` map for Python casting
ctype_map: Dict[Any, Any] = {
**{t: getattr(ctypes, f'c_{t.__name__}') for t in (bool, int, float)},
- **{t: ctypes.c_void_p for t in (torch.int, torch.float, torch.bfloat16, torch.float8_e4m3fn, torch.cuda.Stream)},
+ **{t: ctypes.c_void_p for t in (paddle.int32, paddle.float32, paddle.bfloat16, paddle.float8_e4m3fn, paddle.device.cuda.Stream)},
}
@@ -27,25 +27,25 @@ genc_map = {
bool: ('bool', 'bool'),
int: ('int', 'int'),
@@ -140,8 +140,8 @@ index ead37f5..51b02c1 100644
+ paddle.float8_e4m3fn: ('void*', '__nv_fp8_e4m3*'),
+ paddle.device.cuda.Stream: ('void*', 'cudaStream_t'),
}
def map_ctype(value: Any) -> Any:
if hasattr(value, 'data_ptr'):
- if value.dtype == torch.int:
@@ -171,11 +171,11 @@ index cb438b7..44aa0ed 100644
+import paddle
from functools import lru_cache
from typing import Tuple
@@ -166,20 +166,20 @@ def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int,
return num_min_sms, best_block_m, best_block_n, best_num_stages, best_tma_multicast_config, best_smem_config
-def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
- rhs: Tuple[torch.Tensor, torch.Tensor],
- out: torch.Tensor) -> None:
@@ -189,7 +189,7 @@ index cb438b7..44aa0ed 100644
The LHS scaling tensor requires TMA-aligned transposed format, if your input does not match the requirement,
- this function will do a transposing with a set of slow PyTorch operations.
+ this function will do a transposing with a set of slow paddle operations.
Arguments:
- lhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[m, k]`,
+ lhs: the first element is an FP8 tensor (typed `paddle.float8_e4m3fn`) of shape `[m, k]`,
@@ -202,10 +202,10 @@ index cb438b7..44aa0ed 100644
@@ -189,22 +189,22 @@ def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
n, k_ = rhs.shape
m_, n_ = out.shape
- assert n % 64 == 0 and k % 128 == 0
+ # assert n % 64 == 0 and k % 128 == 0
# Type and shape checks
- assert m == m_ and n == n_ and k == k_
- assert n > 0 and k > 0
@@ -223,13 +223,13 @@ index cb438b7..44aa0ed 100644
+ # assert rhs.dtype == paddle.float8_e4m3fn and rhs_scales.dtype == paddle.float32
+ # assert out.dtype == paddle.bfloat16
+ # assert lhs.is_contiguous() and rhs.is_contiguous() and out.is_contiguous()
# LHS scales must be transposed for TMA load, but not for RHS scales
# NOTES: `get_tma_aligned_lhs_scales` may launch a kernel if not processed by previous kernels
lhs_scales = get_col_major_tma_aligned_tensor(lhs_scales)
- assert rhs_scales.is_contiguous()
+ # assert rhs_scales.is_contiguous()
# Do nothing if `m` is zero
if m == 0:
@@ -214,7 +214,7 @@ def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
@@ -264,12 +264,12 @@ index 3b518c9..ba776bd 100644
-import torch
+import paddle
from typing import Tuple
from .gemm import get_best_configs, get_block_n_padding_for_smem_d
@@ -37,25 +37,25 @@ gemm_t::run(out, rhs_scales, grouped_layout,
"""
-def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Tensor, torch.Tensor],
- rhs: Tuple[torch.Tensor, torch.Tensor],
- out: torch.Tensor, m_indices: torch.Tensor) -> None:
@@ -285,7 +285,7 @@ index 3b518c9..ba776bd 100644
+ this function will do a transposing with a set of slow Pypaddle operations.
On the M axis, inputs are grouped into several batches, of which batch sizes aligned to
`get_m_alignment_for_contiguous_layout()` (128).
Arguments:
- lhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[m_sum, k]`,
+ lhs: the first element is an FP8 tensor (typed `paddle.float8_e4m3fn`) of shape `[m_sum, k]`,
@@ -301,7 +301,7 @@ index 3b518c9..ba776bd 100644
Values of `m_indices` in every-m-alignment-block must also be the same.
@@ -68,19 +68,19 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Tensor, torch.Ten
m__ = m_indices.numel()
# Type and shape checks
- assert m == m_ == m__ and k == k_ and n == n_
- assert lhs_scales.shape == (m, (k + 127) // 128)
@@ -321,12 +321,12 @@ index 3b518c9..ba776bd 100644
+ # assert m_indices.dtype == paddle.int32
+ # assert lhs.is_contiguous() and rhs.is_contiguous()
+ # assert out.is_contiguous() and m_indices.is_contiguous()
# LHS scales must be transposed for TMA load, but not for RHS scales
lhs_scales = get_col_major_tma_aligned_tensor(lhs_scales)
- assert rhs_scales.is_contiguous()
+ # assert rhs_scales.is_contiguous()
# Do nothing if `m` is zero
if m == 0:
@@ -92,7 +92,7 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Tensor, torch.Ten
@@ -357,8 +357,8 @@ index 3b518c9..ba776bd 100644
)
@@ -118,22 +118,22 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Tensor, torch.Ten
runtime(*args)
-def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor],
- rhs: Tuple[torch.Tensor, torch.Tensor],
- out: torch.Tensor, masked_m: torch.Tensor, expected_m: int) -> None:
@@ -374,7 +374,7 @@ index 3b518c9..ba776bd 100644
+ this function will do a transposing with a set of slow paddle operations.
Moreover, this alignment requirement is different with the contiguous-format kernel, as we require that each batch
should be separately transposed.
Arguments:
- lhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[num_groups, m_max, k]`,
+ lhs: the first element is an FP8 tensor (typed `paddle.float8_e4m3fn`) of shape `[num_groups, m_max, k]`,
@@ -386,7 +386,7 @@ index 3b518c9..ba776bd 100644
masked_m: a tensor of shape `[num_groups]`, `masked_m[i]` records actual rows of the `lhs[i]` matrix to compute
@@ -149,21 +149,21 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor]
num_groups___ = masked_m.numel()
# Type and shape checks
- assert num_groups == num_groups_ == num_groups__ == num_groups___
- assert m == m_ and n == n_ and k == k_
@@ -410,16 +410,16 @@ index 3b518c9..ba776bd 100644
+ # assert masked_m.dtype == paddle.int32
+ # assert lhs.is_contiguous() and rhs.is_contiguous()
+ # assert out.is_contiguous() and masked_m.is_contiguous()
# LHS scales must be transposed for TMA load, but not for RHS scales
lhs_scales = get_col_major_tma_aligned_tensor(lhs_scales)
- assert rhs_scales.is_contiguous()
+ # assert rhs_scales.is_contiguous()
# Auto-tuning with compilation
global includes, template
@@ -176,7 +176,7 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor]
args = (lhs, lhs_scales, rhs, rhs_scales, out,
masked_m, m,
- torch.cuda.current_stream(), num_sms, smem_config[0])
@@ -454,11 +454,11 @@ index 6ed6749..9e1d70f 100644
-import torch
+import paddle
from typing import Any, Dict
from ..jit import build, cpp_format, generate, Runtime
@@ -51,10 +51,10 @@ class JITTuner:
continue
# Measure performance with L2 flush and a large GEMM kernel before to reduce overhead between kernels
- start_event = torch.cuda.Event(enable_timing=True)
- end_event = torch.cuda.Event(enable_timing=True)
@@ -478,9 +478,9 @@ index c6da56b..a17b1b1 100644
@@ -1,4 +1,4 @@
-import torch
+import paddle
_num_sms = None
@@ -11,7 +11,7 @@ def set_num_sms(num_sms: int) -> None:
num_sms: the desired maximum SM count for all GEMM kernels to use.
"""
@@ -488,8 +488,8 @@ index c6da56b..a17b1b1 100644
- assert 0 < num_sms <= torch.cuda.get_device_properties(device='cuda').multi_processor_count
+ assert 0 < num_sms <= paddle.device.cuda.get_device_properties().multi_processor_count
_num_sms = num_sms
@@ -25,7 +25,7 @@ def get_num_sms() -> int:
"""
global _num_sms
@@ -497,12 +497,12 @@ index c6da56b..a17b1b1 100644
- _num_sms = torch.cuda.get_device_properties(device='cuda').multi_processor_count
+ _num_sms = paddle.device.cuda.get_device_properties().multi_processor_count
return _num_sms
@@ -74,9 +74,9 @@ def get_tma_aligned_size(x: int, element_size: int) -> int:
return ceil_div(x, alignment) * alignment
-def get_col_major_tma_aligned_tensor(x: torch.Tensor) -> torch.Tensor:
+def get_col_major_tma_aligned_tensor(x: paddle.Tensor) -> paddle.Tensor:
"""
@@ -510,7 +510,7 @@ index c6da56b..a17b1b1 100644
+ Returns TMA-aligned transposed format of the input tensor. `paddle.transpose` will be called if necessary.
If the input tensor is already column-major layout and 16-byte aligned along the M axis
(thus meets the requirement of LHS scaling tensor in DeepGEMM), this function will do nothing.
@@ -92,18 +92,20 @@ def get_col_major_tma_aligned_tensor(x: torch.Tensor) -> torch.Tensor:
m, n = x.shape[-2], x.shape[-1]
aligned_m = get_tma_aligned_size(m, x.element_size())
@@ -519,14 +519,14 @@ index c6da56b..a17b1b1 100644
+ if x.strides[0] == 1 and x.strides[1] == aligned_m:
return x
x, remove_dim = x.unsqueeze(0), True
b = x.shape[0]
# The last kernel gives a column-major TMA aligned layout
- if x.stride(0) == aligned_m * n and x.stride(1) == 1 and x.stride(2) == aligned_m:
+ if x.strides[0] == aligned_m * n and x.strides[1] == 1 and x.strides[2] == aligned_m:
return x.squeeze(0) if remove_dim else x
# Normal layout requires transposing
- aligned_x = torch.transpose(torch.empty((b, n, aligned_m), device=x.device, dtype=x.dtype), 1, 2)
+ aligned_x = paddle.transpose(
@@ -574,20 +574,20 @@ index d5cdd01..5237f09 100644
-import torch.distributed as dist
+import paddle
+import paddle.distributed as dist
def bench(fn, num_warmups: int = 5, num_tests: int = 10,
high_precision: bool = False):
# Flush L2 cache with 256 MB data
- torch.cuda.synchronize()
- cache = torch.empty(int(256e6 // 4), dtype=torch.int, device='cuda')
+ paddle.device.cuda.synchronize()
+ paddle.device.synchronize()
+ cache = paddle.empty((int(256e6 // 4)), dtype=paddle.int32)
cache.zero_()
# Warmup
@@ -18,18 +18,18 @@ def bench(fn, num_warmups: int = 5, num_tests: int = 10,
# Add a large kernel to eliminate the CPU launch overhead
if high_precision:
- x = torch.randn((8192, 8192), dtype=torch.float, device='cuda')
@@ -595,7 +595,7 @@ index d5cdd01..5237f09 100644
+ x = paddle.randn((8192, 8192), dtype=paddle.float32)
+ y = paddle.randn((8192, 8192), dtype=paddle.float32)
x @ y
# Testing
- start_event = torch.cuda.Event(enable_timing=True)
- end_event = torch.cuda.Event(enable_timing=True)
@@ -607,9 +607,9 @@ index d5cdd01..5237f09 100644
end_event.record()
- torch.cuda.synchronize()
+ paddle.device.synchronize()
return start_event.elapsed_time(end_event) / num_tests
@@ -106,21 +106,21 @@ def bench_kineto(fn, kernel_names, num_tests: int = 30, suppress_kineto_output:
# Profile
suppress = suppress_stdout_stderr if suppress_kineto_output and not using_nsys else empty_suppress
@@ -636,8 +636,7 @@ index d5cdd01..5237f09 100644
- torch.empty(flush_l2_size, dtype=torch.int, device='cuda').zero_()
+ paddle.empty(flush_l2_size, dtype=paddle.int32).zero_()
fn()
if not using_nsys:
--
2.43.0
if not using_nsys:
--
2.43.0

View File

@@ -18,7 +18,7 @@ from __future__ import annotations
from dataclasses import dataclass, field
from enum import Enum
from typing import Optional, Literal
from typing import Literal, Optional
from paddleformers.transformers.configuration_utils import PretrainedConfig
@@ -69,7 +69,6 @@ class ModelConfig(PretrainedConfig):
max_seq_len: int = 512,
initializer_range: float = 0.02,
use_rope=True,
use_fast_ffn: bool = False,
rope_theta: int = 10000,
rope_3d: bool = False,
ori_vocab_size: int | None = None,
@@ -104,7 +103,6 @@ class ModelConfig(PretrainedConfig):
self.max_position_embeddings = max_position_embeddings
self.initializer_range = initializer_range
self.use_rope = use_rope
self.use_fast_ffn = use_fast_ffn
self.rope_theta = rope_theta
self.ori_vocab_size = ori_vocab_size or vocab_size
self.max_seq_len = max_seq_len
@@ -199,7 +197,7 @@ class ParallelConfig:
eos_tokens_lens: int = 2
# Enable chunked prefill
enable_chunked_prefill: str = "store_true"
#
max_num_batched_tokens: int = 2048
# enable prefix cache
enable_prefix_caching = None
@@ -349,7 +347,7 @@ class GraphOptimizationConfig:
class LoadConfig:
"""
Configuration for dynamic weight loading strategies
Attributes:
dynamic_load_weight: Whether to enable dynamic weight loading
load_strategy: Specifies the weight loading method when enabled:
@@ -366,7 +364,7 @@ class LoadConfig:
def __post_init__(self):
if self.load_strategy is not None and not self.dynamic_load_weight:
raise ValueError("Load strategy requires dynamic_load_weight=True")
if self.dynamic_load_weight and self.load_strategy is None:
raise ValueError("Must specify load_strategy when dynamic_load_weight is True")

View File

@@ -728,7 +728,7 @@ class Config:
), "XPU currently do not support guided_decoding"
try:
import xgrammar
import xgrammar # noqa
except Exception as e:
raise Exception(
f"import XGrammar failed, please install XGrammar use `pip install xgrammar==0.1.19`. \n\t {e}"

View File

@@ -12,16 +12,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from .attention import Attention
from .append_attn_backend import AppendAttentionBackend
from .attention_selecter import get_attention_backend
from .base_attention_backend import AttentionBackend
from .flash_attn_backend import FlashAttentionBackend
from .mla_attention_backend import MLAAttentionBackend
from .native_paddle_backend import PaddleNativeAttnBackend
from .xpu_attn_backend import XPUAttentionBackend
__all__ = [
"Attention", "AttentionBackend", "PaddleNativeAttnBackend",
"AttentionBackend", "PaddleNativeAttnBackend",
"get_attention_backend", "AppendAttentionBackend", "XPUAttentionBackend",
"MLAAttentionBackend"
"MLAAttentionBackend", "FlashAttentionBackend"
]

View File

@@ -30,7 +30,7 @@ if TYPE_CHECKING:
from paddle._typing.dtype_like import _DTypeLiteral
from fastdeploy.config import FDConfig
from fastdeploy.model_executor.layers.attention import Attention
from fastdeploy.model_executor.layers.attention.attention import Attention
from fastdeploy.model_executor.layers.attention.base_attention_backend import (
AttentionBackend, AttentionMetadata)
from fastdeploy.worker.forward_meta import ForwardMeta

View File

@@ -0,0 +1,243 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
from __future__ import annotations
import os
from dataclasses import dataclass, field
from typing import List, Optional
import paddle
from paddle.nn.functional.flash_attention import flash_attention_v3_varlen
from fastdeploy.config import FDConfig
from fastdeploy.model_executor.layers.attention.attention import Attention
from fastdeploy.model_executor.layers.attention.base_attention_backend import (
AttentionBackend, AttentionMetadata)
from fastdeploy.model_executor.layers.attention.ops import (
get_block_shape_and_split_kv_block, gqa_rope_write_cache,
init_signal_layerwise, open_shm_and_get_meta_signal, pre_cache_len_concat)
from fastdeploy.worker.forward_meta import ForwardMeta
@dataclass
class FlashAttentionMetadata(AttentionMetadata):
"""
FlashAttentionMetadata
"""
max_len_kv: paddle.Tensor = None
set_max_lengths: int = -1
rotary_embs: Optional[paddle.Tensor] = None
block_tables: Optional[paddle.Tensor] = None
encoder_batch_ids: paddle.Tensor = None
encoder_tile_ids_per_batch: paddle.Tensor = None
encoder_num_blocks: paddle.Tensor = None
kv_batch_ids: paddle.Tensor = None
kv_tile_ids_per_batch: paddle.Tensor = None
kv_num_blocks: paddle.Tensor = None
decoder_batch_ids: paddle.Tensor = None
decoder_tile_ids_per_batch: paddle.Tensor = None
decoder_num_blocks: paddle.Tensor = None
encoder_block_shape_q: Optional[paddle.Tensor] = None
decoder_block_shape_q: Optional[paddle.Tensor] = None
cu_seqlens_q: paddle.Tensor = None
cu_seqlens_k: paddle.Tensor = None
max_seqlen_q: int = 0
max_seqlen_k: int = 0
pre_cache_batch_ids = None
pre_cache_tile_ids_per_batch = None
pre_cache_num_blocks_cpu = None
kv_token_num_cpu = None
# pd_disaggregation
kv_signal_metadata: Optional[paddle.Tensor] = None
kv_signal_data_list: List[paddle.Tensor] = field(default_factory=list)
class FlashAttentionBackend(AttentionBackend):
"""
FlashAttentionBackend backend implementation
"""
def __init__(self, fd_config: FDConfig, kv_num_heads: int, num_heads: int,
head_dim: int):
"""
FlashAttentionBackend __init__
"""
super().__init__()
self.attention_metadata: FlashAttentionMetadata = None
self.max_seq_len = fd_config.parallel_config.max_model_len
self.causal = getattr(fd_config.model_config, "causal", True)
self.kv_num_heads = kv_num_heads
self.num_heads = num_heads
self.head_dim = fd_config.model_config.head_dim
self.hidden_size = fd_config.model_config.hidden_size
self.block_size = fd_config.parallel_config.block_size
self.num_layers: int = fd_config.model_config.num_layers
self.speculative_method = fd_config.speculative_config.method
self.use_speculate = self.speculative_method is not None
self.speculate_max_draft_token_num = fd_config.speculative_config.num_speculative_tokens
self.keep_pd_step_flag: bool = fd_config.speculative_config.model_type == "mtp"
self.rank: int = fd_config.parallel_config.tensor_parallel_rank
# pd_disaggregation
self.use_pd_disaggregation: int = int(
os.getenv("FLAGS_use_pd_disaggregation", 0))
self.start_layer_index: int = fd_config.model_config.start_layer_index
self.device_id: int = os.getenv("CUDA_VISIBLE_DEVICES", None)
if fd_config.parallel_config.expert_parallel_rank is None:
fd_config.parallel_config.expert_parallel_rank = 0
device_id = self.rank + fd_config.parallel_config.tensor_parallel_degree * \
fd_config.parallel_config.expert_parallel_rank
if self.device_id is None:
self.device_id = device_id
else:
self.device_id = self.device_id.split(",")[device_id]
def get_attntion_meta(self):
"""get_attntion_meta"""
return self.attention_metadata
def get_kv_cache_shape(
self,
max_num_blocks: int,
):
"""
Caculate kv cache shape
"""
return (max_num_blocks, self.kv_num_heads, self.block_size,
self.head_dim)
def init_attention_metadata(self, forward_meta: ForwardMeta):
metadata = FlashAttentionMetadata()
metadata.encoder_block_shape_q = 64
metadata.decoder_block_shape_q = 16
metadata.cu_seqlens_q = forward_meta.cu_seqlens_q
metadata.rotary_embs = forward_meta.rotary_embs
metadata.block_tables = forward_meta.block_tables
(
metadata.encoder_batch_ids,
metadata.encoder_tile_ids_per_batch,
metadata.encoder_num_blocks,
metadata.kv_batch_ids,
metadata.kv_tile_ids_per_batch,
metadata.kv_num_blocks,
metadata.decoder_batch_ids,
metadata.decoder_tile_ids_per_batch,
metadata.decoder_num_blocks,
metadata.max_len_kv,
metadata.set_max_lengths,
) = get_block_shape_and_split_kv_block(
forward_meta.seq_lens_encoder,
forward_meta.seq_lens_decoder,
forward_meta.seq_lens_this_time,
forward_meta.cum_offsets,
metadata.encoder_block_shape_q,
metadata.decoder_block_shape_q,
self.num_heads // self.kv_num_heads,
self.block_size,
self.speculate_max_draft_token_num + 1,
)
(
metadata.cu_seqlens_k,
metadata.pre_cache_batch_ids,
metadata.pre_cache_tile_ids_per_batch,
metadata.pre_cache_num_blocks_cpu,
metadata.kv_token_num_cpu,
) = pre_cache_len_concat(
forward_meta.seq_lens_decoder,
forward_meta.seq_lens_this_time,
metadata.set_max_lengths[2],
self.block_size,
)
# pd_disaggregation
metadata.kv_signal_data_list = [None] * self.num_layers
if self.use_pd_disaggregation:
metadata.kv_signal_metadata = open_shm_and_get_meta_signal(
self.rank, int(self.device_id), self.keep_pd_step_flag)
self.attention_metadata = metadata
forward_meta.decoder_batch_ids.copy_(metadata.decoder_batch_ids, False)
forward_meta.decoder_tile_ids_per_batch.copy_(
metadata.decoder_tile_ids_per_batch, False)
def forward_mixed(
self,
q: paddle.Tensor,
k: paddle.Tensor,
v: paddle.Tensor,
qkv: paddle.Tensor,
compressed_kv: paddle.Tensor,
k_pe: paddle.Tensor,
layer: Attention,
forward_meta: ForwardMeta,
):
metadata = self.attention_metadata
if self.use_pd_disaggregation:
metadata.kv_signal_data_list[
layer.layer_id] = init_signal_layerwise(
metadata.kv_signal_metadata,
layer.layer_id + self.start_layer_index)
q, k, v, _ = gqa_rope_write_cache(
qkv,
forward_meta.caches[2 * layer.layer_id],
forward_meta.caches[2 * layer.layer_id + 1],
metadata.cu_seqlens_q,
metadata.cu_seqlens_k,
metadata.rotary_embs,
forward_meta.seq_lens_this_time,
forward_meta.seq_lens_encoder,
forward_meta.seq_lens_decoder,
forward_meta.padding_offset,
forward_meta.cum_offsets,
metadata.block_tables,
metadata.kv_batch_ids,
metadata.kv_tile_ids_per_batch,
metadata.kv_num_blocks,
metadata.pre_cache_batch_ids,
metadata.pre_cache_tile_ids_per_batch,
metadata.pre_cache_num_blocks_cpu,
getattr(layer, "cache_k_scale", None),
getattr(layer, "cache_v_scale", None),
getattr(layer, "cache_k_out_scale", None),
getattr(layer, "cache_v_out_scale", None),
getattr(layer, "cache_k_zp", None),
getattr(layer, "cache_v_zp", None),
metadata.kv_signal_data_list[layer.layer_id],
metadata.kv_token_num_cpu[0],
self.max_seq_len,
getattr(layer, "cache_quant_type_str", "none"),
)
res = flash_attention_v3_varlen(
q,
k,
v,
metadata.cu_seqlens_q,
metadata.cu_seqlens_k,
max_seqlen_q=metadata.set_max_lengths[0],
max_seqlen_k=metadata.set_max_lengths[3],
causal=self.causal,
)[0].reshape([-1, self.hidden_size])
return res

View File

@@ -38,7 +38,7 @@ if TYPE_CHECKING:
from paddle._typing.dtype_like import _DTypeLiteral
from fastdeploy.config import FDConfig
from fastdeploy.model_executor.layers.attention import Attention
from fastdeploy.model_executor.layers.attention.attention import Attention
from fastdeploy.model_executor.layers.attention.base_attention_backend import (
AttentionBackend, AttentionMetadata)
from fastdeploy.worker.forward_meta import ForwardMeta

View File

@@ -17,10 +17,16 @@
from .append_attention import append_attention
from .get_block_shape_and_split_kv_block import \
get_block_shape_and_split_kv_block
from .gqa_rope_write_cache import gqa_rope_write_cache
from .init_signal_layerwise import init_signal_layerwise
from .open_shm_and_get_meta_signal import open_shm_and_get_meta_signal
from .pre_cache_len_concat import pre_cache_len_concat
__all__ = [
"get_block_shape_and_split_kv_block", "append_attention",
"open_shm_and_get_meta_signal", "init_signal_layerwise"
"get_block_shape_and_split_kv_block",
"append_attention",
"open_shm_and_get_meta_signal",
"init_signal_layerwise",
"gqa_rope_write_cache",
"pre_cache_len_concat",
]

View File

@@ -0,0 +1,66 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
from typing import Optional
import paddle
from fastdeploy.platforms import current_platform
def gqa_rope_write_cache(
qkv: paddle.Tensor,
key_cache: paddle.Tensor,
value_cache: paddle.Tensor,
cu_seqlens_q: paddle.Tensor,
cu_seqlens_k: paddle.Tensor,
rotary_embs: paddle.Tensor,
seq_lens_this_time: paddle.Tensor,
seq_lens_encoder: paddle.Tensor,
seq_lens_decoder: paddle.Tensor,
padding_offsets: paddle.Tensor,
cum_offsets: paddle.Tensor,
block_tables: paddle.Tensor,
kv_batch_ids: paddle.Tensor,
kv_tile_ids_per_batch: paddle.Tensor,
kv_num_blocks: paddle.Tensor,
cache_batch_ids: paddle.Tensor,
cache_tile_ids_per_batch: paddle.Tensor,
cache_num_blocks: paddle.Tensor,
cache_k_quant_scales: Optional[paddle.Tensor] = None,
cache_v_quant_scales: Optional[paddle.Tensor] = None,
cache_k_dequant_scales: Optional[paddle.Tensor] = None,
cache_v_dequant_scales: Optional[paddle.Tensor] = None,
cache_k_zp: Optional[paddle.Tensor] = None,
cache_v_zp: Optional[paddle.Tensor] = None,
kv_signal_data: Optional[paddle.Tensor] = None,
kv_token_num: int = 1,
max_seq_len: int = 0,
cache_quant_type: str = "none"):
if current_platform.is_cuda():
from fastdeploy.model_executor.ops.gpu import gqa_rope_write_cache
q, k, v, qkv_ = gqa_rope_write_cache(
qkv, key_cache, value_cache, cu_seqlens_q, cu_seqlens_k,
rotary_embs, seq_lens_this_time, seq_lens_encoder,
seq_lens_decoder, padding_offsets, cum_offsets, block_tables,
kv_batch_ids, kv_tile_ids_per_batch, kv_num_blocks,
cache_batch_ids, cache_tile_ids_per_batch, cache_num_blocks,
cache_k_quant_scales, cache_v_quant_scales, cache_k_dequant_scales,
cache_v_dequant_scales, cache_k_zp, cache_v_zp, kv_signal_data,
kv_token_num, max_seq_len, cache_quant_type)
return q, k, v, qkv_
else:
raise NotImplementedError()

View File

@@ -0,0 +1,36 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License,
Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND,
either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
import paddle
from fastdeploy.platforms import current_platform
def pre_cache_len_concat(seq_lens_decoder: paddle.Tensor,
seq_lens_this_time: paddle.Tensor,
max_dec_len: int = 0,
block_size: int = 64):
if current_platform.is_cuda():
from fastdeploy.model_executor.ops.gpu import pre_cache_len_concat
out = pre_cache_len_concat(seq_lens_decoder, seq_lens_this_time,
max_dec_len, block_size)
return out
else:
raise NotImplementedError()

View File

@@ -29,7 +29,7 @@ if TYPE_CHECKING:
from paddle._typing.dtype_like import _DTypeLiteral
from fastdeploy.config import FDConfig
from fastdeploy.model_executor.layers.attention import Attention
from fastdeploy.model_executor.layers.attention.attention import Attention
from fastdeploy.model_executor.layers.attention.base_attention_backend import (
AttentionBackend, AttentionMetadata)
from fastdeploy.worker.forward_meta import ForwardMeta

View File

@@ -329,7 +329,6 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
with_bias: bool = False,
add_bias: bool = False,
activation: str = "gelu",
use_fast_ffn: bool = False,
skip_quant: bool = False,
):
"""
@@ -344,11 +343,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
with_bias (bool): Whether to include bias or not. Defaults to False.
add_bias (bool): Whether to add bias in the current layer or in the pre/post layer. Defaults to False.
activation (str): Activation function to use. Defaults to "gelu".
use_fast_ffn (bool): Whether to use a faster FFN implementation.
Defaults to False.
skip_quant (bool): Whether to skip quantization. Defaults to False.
"""
self.use_fast_ffn = use_fast_ffn
self.activation = activation
self.hidden_size = fd_config.model_config.hidden_size
self.nranks = fd_config.parallel_config.tensor_parallel_degree
@@ -385,23 +381,10 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
"gate_proj")
bias_tensor = get_tensor(state_dict.pop(gate_bias_key)).astype(
paddle.get_default_dtype())
converted_bias_tensor = paddle.zeros(shape=list(
bias_tensor.shape),
dtype=bias_tensor.dtype)
if not self.use_fast_ffn:
converted_bias_tensor = paddle.concat(
[bias_tensor[::2], bias_tensor[1::2]], axis=0)
else:
converted_bias_tensor = bias_tensor
state_dict[self.bias_key] = converted_bias_tensor
if not self.use_fast_ffn:
converted_weight_tensor = paddle.concat(
[weight_tensor[:, ::2], weight_tensor[:, 1::2]], axis=1)
else:
converted_weight_tensor = weight_tensor
state_dict[self.bias_key] = bias_tensor
state_dict[self.weight_key] = converted_weight_tensor
state_dict[self.weight_key] = weight_tensor
super().load_state_dict(state_dict)

View File

@@ -15,8 +15,9 @@
"""
from typing import Optional
from ..attention import Attention
from ..moe import FusedMoE
from fastdeploy.model_executor.layers.attention.attention import Attention
from fastdeploy.model_executor.layers.moe.moe import FusedMoE
from . import get_quantization_config
from .quant_base import QuantConfigBase, QuantMethodBase

View File

@@ -214,7 +214,7 @@ def load_tp_checkpoint_v1(
need_tp = True if tensor_parallel_filtered_map else False
state_dict = {}
for key, weight in weights_iterator:
paddle.device.cuda.synchronize()
paddle.device.synchronize()
if need_tp and key in tensor_parallel_filtered_map:
action = tensor_parallel_filtered_map.pop(key)
tensor = action(weight).clone()

View File

@@ -86,7 +86,7 @@ class DefaultModelLoader(BaseModelLoader):
if isinstance(v, paddle.Tensor):
v.value().get_tensor()._clear()
paddle.device.cuda.empty_cache()
paddle.device.cuda.synchronize()
paddle.device.synchronize()
def load_model(self, fd_config: FDConfig) -> nn.Layer:
context = paddle.LazyGuard()

View File

@@ -28,7 +28,7 @@ from fastdeploy.config import FDConfig
from fastdeploy.distributed.communication_op import \
tensor_model_parallel_all_reduce
from fastdeploy.model_executor.layers.activation import SiluAndMul
from fastdeploy.model_executor.layers.attention import Attention
from fastdeploy.model_executor.layers.attention.attention import Attention
from fastdeploy.model_executor.layers.embeddings import VocabParallelEmbedding
from fastdeploy.model_executor.layers.linear import (
ColumnParallelLinear, KVBatchLinear, MergedColumnParallelLinear,
@@ -68,7 +68,6 @@ class DeepSeekV3MLP(nn.Layer):
output_size=intermediate_size * 2,
with_bias=False,
activation=fd_config.model_config.hidden_act,
use_fast_ffn=True,
)
self.down_proj = RowParallelLinear(

View File

@@ -29,7 +29,7 @@ from fastdeploy.config import FDConfig, ModelConfig
from fastdeploy.model_executor.graph_optimization.decorator import \
support_graph_optimization
from fastdeploy.model_executor.layers.activation import SiluAndMul
from fastdeploy.model_executor.layers.attention import Attention
from fastdeploy.model_executor.layers.attention.attention import Attention
from fastdeploy.model_executor.layers.embeddings import VocabParallelEmbedding
from fastdeploy.model_executor.layers.linear import (
MergedColumnParallelLinear, QKVParallelLinear, RowParallelLinear)
@@ -62,7 +62,6 @@ class Ernie4_5_MLP(nn.Layer):
output_size=intermediate_size * 2,
with_bias=False,
activation=fd_config.model_config.hidden_act,
use_fast_ffn=True,
)
self.down_proj = RowParallelLinear(

View File

@@ -27,7 +27,7 @@ from fastdeploy.config import FDConfig, ModelConfig
from fastdeploy.model_executor.graph_optimization.decorator import \
support_graph_optimization
from fastdeploy.model_executor.layers.activation import SiluAndMul
from fastdeploy.model_executor.layers.attention import Attention
from fastdeploy.model_executor.layers.attention.attention import Attention
from fastdeploy.model_executor.layers.embeddings import VocabParallelEmbedding
from fastdeploy.model_executor.layers.linear import (
MergedColumnParallelLinear, QKVParallelLinear, RowParallelLinear)
@@ -55,7 +55,6 @@ class Qwen2MLP(nn.Layer):
output_size=fd_config.model_config.ffn_hidden_size * 2,
with_bias=False,
activation=fd_config.model_config.hidden_act,
use_fast_ffn=True,
)
self.down_proj = RowParallelLinear(

View File

@@ -26,7 +26,7 @@ from paddleformers.utils.log import logger
from fastdeploy.config import FDConfig, ModelConfig
from fastdeploy.model_executor.graph_optimization.decorator import \
support_graph_optimization
from fastdeploy.model_executor.layers.attention import Attention
from fastdeploy.model_executor.layers.attention.attention import Attention
from fastdeploy.model_executor.layers.embeddings import VocabParallelEmbedding
from fastdeploy.model_executor.layers.linear import (QKVParallelLinear,
RowParallelLinear)

View File

@@ -27,7 +27,7 @@ from fastdeploy.config import FDConfig, ModelConfig
from fastdeploy.model_executor.graph_optimization.decorator import \
support_graph_optimization
from fastdeploy.model_executor.layers.activation import SiluAndMul
from fastdeploy.model_executor.layers.attention import Attention
from fastdeploy.model_executor.layers.attention.attention import Attention
from fastdeploy.model_executor.layers.embeddings import VocabParallelEmbedding
from fastdeploy.model_executor.layers.linear import (
MergedColumnParallelLinear, QKVParallelLinear, RowParallelLinear)
@@ -57,7 +57,6 @@ class Qwen3MLP(nn.Layer):
output_size=fd_config.model_config.ffn_hidden_size * 2,
with_bias=False,
activation=fd_config.model_config.hidden_act,
use_fast_ffn=True,
)
self.down_proj = RowParallelLinear(

View File

@@ -24,6 +24,7 @@ class _Backend(enum.Enum):
NATIVE_ATTN = enum.auto()
APPEND_ATTN = enum.auto()
MLA_ATTN = enum.auto()
FLASH_ATTN = enum.auto()
class Platform:

View File

@@ -13,9 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""
"""
cuda platform file
"""
import paddle
@@ -65,6 +62,11 @@ class CUDAPlatform(Platform):
return (
"fastdeploy.model_executor.layers.attention.MLAAttentionBackend"
)
elif selected_backend == _Backend.FLASH_ATTN:
logger.info("Using FLASH ATTN backend.")
return (
"fastdeploy.model_executor.layers.attention.FlashAttentionBackend"
)
else:
raise ValueError(
"Invalid attention backend you specified.\n"

View File

@@ -90,7 +90,8 @@ class MTPProposer(Proposer):
self.model = get_model_from_loader(self.cfg)
def dummy_prefill_inputs(self, num_tokens: int, batch_size: int, expected_decode_len: int):
def dummy_prefill_inputs(self, num_tokens: int, batch_size: int,
expected_decode_len: int):
"""Set dummy prefill inputs to model_inputs"""
max_dec_len = expected_decode_len + 1
self.num_gpu_blocks = self.parallel_config.max_block_num
@@ -130,10 +131,10 @@ class MTPProposer(Proposer):
self.cache_kvs = {}
cache_type = self.parallel_config.dtype
if (self.quant_config and
hasattr(self.quant_config, "kv_cache_quant_type") and
self.quant_config.kv_cache_quant_type is not None):
if (self.quant_config
and hasattr(self.quant_config, "kv_cache_quant_type")
and self.quant_config.kv_cache_quant_type is not None):
cache_type = 'uint8'
# Get kv cache shape
@@ -190,8 +191,7 @@ class MTPProposer(Proposer):
head_dim = self.model_config.head_dim
# Get the attention backend
attn_cls = get_attention_backend(
self.parallel_config.attention_backend)
attn_cls = get_attention_backend()
attn_backend = attn_cls(
self.cfg,
kv_num_heads=self.model_config.kv_num_heads,
@@ -200,8 +200,8 @@ class MTPProposer(Proposer):
)
if attn_backend is None:
raise NotImplementedError(
f"{ self.parallel_config.attention_backend} attention backend"
" is not support by GPUModelRunner")
"Attention backend which you specified is not supported, please set FD_ATTENTION_BACKEND correctly."
)
self.attn_backends.append(attn_backend)
def clear_dummy_input(self):

View File

@@ -593,7 +593,8 @@ class GPUModelRunner(ModelRunnerBase):
self.model = get_model_from_loader(fd_config=self.fd_config)
# 1.1 Load RL dynamic model
if self.fd_config.load_config.dynamic_load_weight:
from fastdeploy.rl.dynamic_weight_manager import DynamicWeightManager
from fastdeploy.rl.dynamic_weight_manager import \
DynamicWeightManager
self.dynamic_weight_manager = DynamicWeightManager(self.fd_config, self.model)
# 2. Load lora model
@@ -622,7 +623,7 @@ class GPUModelRunner(ModelRunnerBase):
# Initialzie attention meta data
for attn_backend in self.attn_backends:
attn_backend.init_attention_metadata(self.forward_meta)
def clear_cache(self):
"""Clear cached data from shared inputs and forward metadata."""
self.share_inputs.pop("caches", None)
@@ -719,7 +720,7 @@ class GPUModelRunner(ModelRunnerBase):
head_dim=head_dim)
if attn_backend is None:
raise NotImplementedError(
"Attention backend which you chose is not support by GPUModelRunner"
"Attention backend which you specified is not supported, please set FD_ATTENTION_BACKEND correctly."
)
self.attn_backends.append(attn_backend)
@@ -1150,7 +1151,6 @@ class GPUModelRunner(ModelRunnerBase):
if self.speculative_method in ["mtp"]:
self.proposer.clear_dummy_input()
# paddle.device.cuda.synchronize()
def update_share_input_block_num(self, num_gpu_blocks: int) -> None:
"""

View File

@@ -590,7 +590,7 @@ class XPUModelRunner(ModelRunnerBase):
head_dim=head_dim)
if attn_backend is None:
raise NotImplementedError(
"Attention backend which you chose is not support by GPUModelRunner"
"Attention backend which you specified is not supported, please set FD_ATTENTION_BACKEND correctly."
)
self.attn_backends.append(attn_backend)