[feat] support fa3 backend for pd disaggregated (#2695)

* support fa3 backend run in pd disaggregated

* support fa3 backend run in pd disaggregated

* support fa3 backend run in pd disaggregated

* support fa3 backend run in pd disaggregated

* delete use_fast_ffn
Yuanle Liu committed 2025-07-03 22:33:27 +08:00 (committed by GitHub)
parent 00863c43fd
commit 240bdac2a4
26 changed files with 455 additions and 139 deletions

View File

@@ -5,12 +5,6 @@ default_stages:
  - pre-commit # Run locally
  # - manual # Run in CI
repos:
-  # Formatting
-  - repo: https://github.com/google/yapf
-    rev: v0.43.0
-    hooks:
-    - id: yapf
-      args: [--in-place, --verbose]
  # Code checks
  - repo: https://github.com/astral-sh/ruff-pre-commit
    rev: v0.11.7
@@ -29,15 +23,6 @@ repos:
    rev: 6.0.1
    hooks:
    - id: isort
-  # # Formatting
-  # - repo: https://github.com/pre-commit/mirrors-clang-format
-  #   rev: v20.1.3
-  #   hooks:
-  #   - id: clang-format
-  #     # exclude: '.*'
-  #     types_or: [c++, cuda]
-  #     args: [--style=file, --verbose]
  # markdown
  - repo: https://github.com/jackdewinter/pymarkdown
    rev: v0.9.29

View File

@@ -581,7 +581,7 @@ index d5cdd01..5237f09 100644
     # Flush L2 cache with 256 MB data
-    torch.cuda.synchronize()
-    cache = torch.empty(int(256e6 // 4), dtype=torch.int, device='cuda')
-+    paddle.device.cuda.synchronize()
++    paddle.device.synchronize()
+    cache = paddle.empty((int(256e6 // 4)), dtype=paddle.int32)
     cache.zero_()
@@ -640,4 +640,3 @@ index d5cdd01..5237f09 100644
     if not using_nsys:
--
2.43.0

View File

@@ -18,7 +18,7 @@ from __future__ import annotations
from dataclasses import dataclass, field
from enum import Enum
-from typing import Optional, Literal
+from typing import Literal, Optional
from paddleformers.transformers.configuration_utils import PretrainedConfig
@@ -69,7 +69,6 @@ class ModelConfig(PretrainedConfig):
        max_seq_len: int = 512,
        initializer_range: float = 0.02,
        use_rope=True,
-        use_fast_ffn: bool = False,
        rope_theta: int = 10000,
        rope_3d: bool = False,
        ori_vocab_size: int | None = None,
@@ -104,7 +103,6 @@ class ModelConfig(PretrainedConfig):
        self.max_position_embeddings = max_position_embeddings
        self.initializer_range = initializer_range
        self.use_rope = use_rope
-        self.use_fast_ffn = use_fast_ffn
        self.rope_theta = rope_theta
        self.ori_vocab_size = ori_vocab_size or vocab_size
        self.max_seq_len = max_seq_len
@@ -199,7 +197,7 @@ class ParallelConfig:
    eos_tokens_lens: int = 2
    # Enable chunked prefill
    enable_chunked_prefill: str = "store_true"
    #
    max_num_batched_tokens: int = 2048
    # enable prefix cache
    enable_prefix_caching = None

View File

@@ -728,7 +728,7 @@ class Config:
), "XPU currently do not support guided_decoding" ), "XPU currently do not support guided_decoding"
try: try:
import xgrammar import xgrammar # noqa
except Exception as e: except Exception as e:
raise Exception( raise Exception(
f"import XGrammar failed, please install XGrammar use `pip install xgrammar==0.1.19`. \n\t {e}" f"import XGrammar failed, please install XGrammar use `pip install xgrammar==0.1.19`. \n\t {e}"

View File

@@ -12,16 +12,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from .attention import Attention
from .append_attn_backend import AppendAttentionBackend
from .attention_selecter import get_attention_backend
from .base_attention_backend import AttentionBackend
+from .flash_attn_backend import FlashAttentionBackend
from .mla_attention_backend import MLAAttentionBackend
from .native_paddle_backend import PaddleNativeAttnBackend
from .xpu_attn_backend import XPUAttentionBackend

__all__ = [
-    "Attention", "AttentionBackend", "PaddleNativeAttnBackend",
+    "AttentionBackend", "PaddleNativeAttnBackend",
    "get_attention_backend", "AppendAttentionBackend", "XPUAttentionBackend",
-    "MLAAttentionBackend"
+    "MLAAttentionBackend", "FlashAttentionBackend"
]
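
With the new export in place, the FA3 backend can be imported straight from the attention package. A minimal sketch, assuming FastDeploy is installed with this change:

from fastdeploy.model_executor.layers.attention import (
    FlashAttentionBackend, get_attention_backend)

Note that Attention itself is no longer re-exported here; as the import changes in the files below show, it now comes from fastdeploy.model_executor.layers.attention.attention.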

View File

@@ -30,7 +30,7 @@ if TYPE_CHECKING:
from paddle._typing.dtype_like import _DTypeLiteral
from fastdeploy.config import FDConfig
-from fastdeploy.model_executor.layers.attention import Attention
+from fastdeploy.model_executor.layers.attention.attention import Attention
from fastdeploy.model_executor.layers.attention.base_attention_backend import (
    AttentionBackend, AttentionMetadata)
from fastdeploy.worker.forward_meta import ForwardMeta

View File

@@ -0,0 +1,243 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
from __future__ import annotations
import os
from dataclasses import dataclass, field
from typing import List, Optional
import paddle
from paddle.nn.functional.flash_attention import flash_attention_v3_varlen
from fastdeploy.config import FDConfig
from fastdeploy.model_executor.layers.attention.attention import Attention
from fastdeploy.model_executor.layers.attention.base_attention_backend import (
AttentionBackend, AttentionMetadata)
from fastdeploy.model_executor.layers.attention.ops import (
get_block_shape_and_split_kv_block, gqa_rope_write_cache,
init_signal_layerwise, open_shm_and_get_meta_signal, pre_cache_len_concat)
from fastdeploy.worker.forward_meta import ForwardMeta
@dataclass
class FlashAttentionMetadata(AttentionMetadata):
"""
FlashAttentionMetadata
"""
max_len_kv: paddle.Tensor = None
set_max_lengths: int = -1
rotary_embs: Optional[paddle.Tensor] = None
block_tables: Optional[paddle.Tensor] = None
encoder_batch_ids: paddle.Tensor = None
encoder_tile_ids_per_batch: paddle.Tensor = None
encoder_num_blocks: paddle.Tensor = None
kv_batch_ids: paddle.Tensor = None
kv_tile_ids_per_batch: paddle.Tensor = None
kv_num_blocks: paddle.Tensor = None
decoder_batch_ids: paddle.Tensor = None
decoder_tile_ids_per_batch: paddle.Tensor = None
decoder_num_blocks: paddle.Tensor = None
encoder_block_shape_q: Optional[paddle.Tensor] = None
decoder_block_shape_q: Optional[paddle.Tensor] = None
cu_seqlens_q: paddle.Tensor = None
cu_seqlens_k: paddle.Tensor = None
max_seqlen_q: int = 0
max_seqlen_k: int = 0
pre_cache_batch_ids = None
pre_cache_tile_ids_per_batch = None
pre_cache_num_blocks_cpu = None
kv_token_num_cpu = None
# pd_disaggregation
kv_signal_metadata: Optional[paddle.Tensor] = None
kv_signal_data_list: List[paddle.Tensor] = field(default_factory=list)
class FlashAttentionBackend(AttentionBackend):
"""
FlashAttentionBackend backend implementation
"""
def __init__(self, fd_config: FDConfig, kv_num_heads: int, num_heads: int,
head_dim: int):
"""
FlashAttentionBackend __init__
"""
super().__init__()
self.attention_metadata: FlashAttentionMetadata = None
self.max_seq_len = fd_config.parallel_config.max_model_len
self.causal = getattr(fd_config.model_config, "causal", True)
self.kv_num_heads = kv_num_heads
self.num_heads = num_heads
self.head_dim = fd_config.model_config.head_dim
self.hidden_size = fd_config.model_config.hidden_size
self.block_size = fd_config.parallel_config.block_size
self.num_layers: int = fd_config.model_config.num_layers
self.speculative_method = fd_config.speculative_config.method
self.use_speculate = self.speculative_method is not None
self.speculate_max_draft_token_num = fd_config.speculative_config.num_speculative_tokens
self.keep_pd_step_flag: bool = fd_config.speculative_config.model_type == "mtp"
self.rank: int = fd_config.parallel_config.tensor_parallel_rank
# pd_disaggregation
self.use_pd_disaggregation: int = int(
os.getenv("FLAGS_use_pd_disaggregation", 0))
self.start_layer_index: int = fd_config.model_config.start_layer_index
self.device_id: int = os.getenv("CUDA_VISIBLE_DEVICES", None)
if fd_config.parallel_config.expert_parallel_rank is None:
fd_config.parallel_config.expert_parallel_rank = 0
device_id = self.rank + fd_config.parallel_config.tensor_parallel_degree * \
fd_config.parallel_config.expert_parallel_rank
if self.device_id is None:
self.device_id = device_id
else:
self.device_id = self.device_id.split(",")[device_id]
def get_attntion_meta(self):
"""get_attntion_meta"""
return self.attention_metadata
def get_kv_cache_shape(
self,
max_num_blocks: int,
):
"""
Caculate kv cache shape
"""
return (max_num_blocks, self.kv_num_heads, self.block_size,
self.head_dim)
def init_attention_metadata(self, forward_meta: ForwardMeta):
metadata = FlashAttentionMetadata()
metadata.encoder_block_shape_q = 64
metadata.decoder_block_shape_q = 16
metadata.cu_seqlens_q = forward_meta.cu_seqlens_q
metadata.rotary_embs = forward_meta.rotary_embs
metadata.block_tables = forward_meta.block_tables
(
metadata.encoder_batch_ids,
metadata.encoder_tile_ids_per_batch,
metadata.encoder_num_blocks,
metadata.kv_batch_ids,
metadata.kv_tile_ids_per_batch,
metadata.kv_num_blocks,
metadata.decoder_batch_ids,
metadata.decoder_tile_ids_per_batch,
metadata.decoder_num_blocks,
metadata.max_len_kv,
metadata.set_max_lengths,
) = get_block_shape_and_split_kv_block(
forward_meta.seq_lens_encoder,
forward_meta.seq_lens_decoder,
forward_meta.seq_lens_this_time,
forward_meta.cum_offsets,
metadata.encoder_block_shape_q,
metadata.decoder_block_shape_q,
self.num_heads // self.kv_num_heads,
self.block_size,
self.speculate_max_draft_token_num + 1,
)
(
metadata.cu_seqlens_k,
metadata.pre_cache_batch_ids,
metadata.pre_cache_tile_ids_per_batch,
metadata.pre_cache_num_blocks_cpu,
metadata.kv_token_num_cpu,
) = pre_cache_len_concat(
forward_meta.seq_lens_decoder,
forward_meta.seq_lens_this_time,
metadata.set_max_lengths[2],
self.block_size,
)
# pd_disaggregation
metadata.kv_signal_data_list = [None] * self.num_layers
if self.use_pd_disaggregation:
metadata.kv_signal_metadata = open_shm_and_get_meta_signal(
self.rank, int(self.device_id), self.keep_pd_step_flag)
self.attention_metadata = metadata
forward_meta.decoder_batch_ids.copy_(metadata.decoder_batch_ids, False)
forward_meta.decoder_tile_ids_per_batch.copy_(
metadata.decoder_tile_ids_per_batch, False)
def forward_mixed(
self,
q: paddle.Tensor,
k: paddle.Tensor,
v: paddle.Tensor,
qkv: paddle.Tensor,
compressed_kv: paddle.Tensor,
k_pe: paddle.Tensor,
layer: Attention,
forward_meta: ForwardMeta,
):
metadata = self.attention_metadata
if self.use_pd_disaggregation:
metadata.kv_signal_data_list[
layer.layer_id] = init_signal_layerwise(
metadata.kv_signal_metadata,
layer.layer_id + self.start_layer_index)
q, k, v, _ = gqa_rope_write_cache(
qkv,
forward_meta.caches[2 * layer.layer_id],
forward_meta.caches[2 * layer.layer_id + 1],
metadata.cu_seqlens_q,
metadata.cu_seqlens_k,
metadata.rotary_embs,
forward_meta.seq_lens_this_time,
forward_meta.seq_lens_encoder,
forward_meta.seq_lens_decoder,
forward_meta.padding_offset,
forward_meta.cum_offsets,
metadata.block_tables,
metadata.kv_batch_ids,
metadata.kv_tile_ids_per_batch,
metadata.kv_num_blocks,
metadata.pre_cache_batch_ids,
metadata.pre_cache_tile_ids_per_batch,
metadata.pre_cache_num_blocks_cpu,
getattr(layer, "cache_k_scale", None),
getattr(layer, "cache_v_scale", None),
getattr(layer, "cache_k_out_scale", None),
getattr(layer, "cache_v_out_scale", None),
getattr(layer, "cache_k_zp", None),
getattr(layer, "cache_v_zp", None),
metadata.kv_signal_data_list[layer.layer_id],
metadata.kv_token_num_cpu[0],
self.max_seq_len,
getattr(layer, "cache_quant_type_str", "none"),
)
res = flash_attention_v3_varlen(
q,
k,
v,
metadata.cu_seqlens_q,
metadata.cu_seqlens_k,
max_seqlen_q=metadata.set_max_lengths[0],
max_seqlen_k=metadata.set_max_lengths[3],
causal=self.causal,
)[0].reshape([-1, self.hidden_size])
return res
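
A rough sketch of how this backend is intended to be enabled, based on the environment variables referenced in this diff (FD_ATTENTION_BACKEND in the model-runner error messages below, FLAGS_use_pd_disaggregation in __init__ above). The exact launch flow is not part of this commit, and the assumption that the variable value matches the _Backend enum name is mine:

import os

# Select the FA3 backend and turn on the PD-disaggregation signal path
# (assumed value: the _Backend enum member name added below).
os.environ["FD_ATTENTION_BACKEND"] = "FLASH_ATTN"
os.environ["FLAGS_use_pd_disaggregation"] = "1"

# ...then start the FastDeploy engine / model runner as usual; the platform
# selector (see the CUDAPlatform change below) maps FLASH_ATTN to
# FlashAttentionBackend.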

View File

@@ -38,7 +38,7 @@ if TYPE_CHECKING:
from paddle._typing.dtype_like import _DTypeLiteral
from fastdeploy.config import FDConfig
-from fastdeploy.model_executor.layers.attention import Attention
+from fastdeploy.model_executor.layers.attention.attention import Attention
from fastdeploy.model_executor.layers.attention.base_attention_backend import (
    AttentionBackend, AttentionMetadata)
from fastdeploy.worker.forward_meta import ForwardMeta

View File

@@ -17,10 +17,16 @@
from .append_attention import append_attention
from .get_block_shape_and_split_kv_block import \
    get_block_shape_and_split_kv_block
+from .gqa_rope_write_cache import gqa_rope_write_cache
from .init_signal_layerwise import init_signal_layerwise
from .open_shm_and_get_meta_signal import open_shm_and_get_meta_signal
+from .pre_cache_len_concat import pre_cache_len_concat

__all__ = [
-    "get_block_shape_and_split_kv_block", "append_attention",
-    "open_shm_and_get_meta_signal", "init_signal_layerwise"
+    "get_block_shape_and_split_kv_block",
+    "append_attention",
+    "open_shm_and_get_meta_signal",
+    "init_signal_layerwise",
+    "gqa_rope_write_cache",
+    "pre_cache_len_concat",
]

View File

@@ -0,0 +1,66 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
from typing import Optional
import paddle
from fastdeploy.platforms import current_platform
def gqa_rope_write_cache(
qkv: paddle.Tensor,
key_cache: paddle.Tensor,
value_cache: paddle.Tensor,
cu_seqlens_q: paddle.Tensor,
cu_seqlens_k: paddle.Tensor,
rotary_embs: paddle.Tensor,
seq_lens_this_time: paddle.Tensor,
seq_lens_encoder: paddle.Tensor,
seq_lens_decoder: paddle.Tensor,
padding_offsets: paddle.Tensor,
cum_offsets: paddle.Tensor,
block_tables: paddle.Tensor,
kv_batch_ids: paddle.Tensor,
kv_tile_ids_per_batch: paddle.Tensor,
kv_num_blocks: paddle.Tensor,
cache_batch_ids: paddle.Tensor,
cache_tile_ids_per_batch: paddle.Tensor,
cache_num_blocks: paddle.Tensor,
cache_k_quant_scales: Optional[paddle.Tensor] = None,
cache_v_quant_scales: Optional[paddle.Tensor] = None,
cache_k_dequant_scales: Optional[paddle.Tensor] = None,
cache_v_dequant_scales: Optional[paddle.Tensor] = None,
cache_k_zp: Optional[paddle.Tensor] = None,
cache_v_zp: Optional[paddle.Tensor] = None,
kv_signal_data: Optional[paddle.Tensor] = None,
kv_token_num: int = 1,
max_seq_len: int = 0,
cache_quant_type: str = "none"):
if current_platform.is_cuda():
from fastdeploy.model_executor.ops.gpu import gqa_rope_write_cache
q, k, v, qkv_ = gqa_rope_write_cache(
qkv, key_cache, value_cache, cu_seqlens_q, cu_seqlens_k,
rotary_embs, seq_lens_this_time, seq_lens_encoder,
seq_lens_decoder, padding_offsets, cum_offsets, block_tables,
kv_batch_ids, kv_tile_ids_per_batch, kv_num_blocks,
cache_batch_ids, cache_tile_ids_per_batch, cache_num_blocks,
cache_k_quant_scales, cache_v_quant_scales, cache_k_dequant_scales,
cache_v_dequant_scales, cache_k_zp, cache_v_zp, kv_signal_data,
kv_token_num, max_seq_len, cache_quant_type)
return q, k, v, qkv_
else:
raise NotImplementedError()

View File

@@ -0,0 +1,36 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License,
Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND,
either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
import paddle
from fastdeploy.platforms import current_platform
def pre_cache_len_concat(seq_lens_decoder: paddle.Tensor,
seq_lens_this_time: paddle.Tensor,
max_dec_len: int = 0,
block_size: int = 64):
if current_platform.is_cuda():
from fastdeploy.model_executor.ops.gpu import pre_cache_len_concat
out = pre_cache_len_concat(seq_lens_decoder, seq_lens_this_time,
max_dec_len, block_size)
return out
else:
raise NotImplementedError()
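
The wrapper returns whatever the compiled GPU op produces; based on how FlashAttentionBackend.init_attention_metadata unpacks it above, the result is a 5-tuple. An illustrative call only (placeholder tensors, requires the compiled FastDeploy GPU ops and a CUDA device):

from fastdeploy.model_executor.layers.attention.ops import pre_cache_len_concat

# seq_lens_decoder / seq_lens_this_time are per-request int32 tensors.
(cu_seqlens_k,
 pre_cache_batch_ids,
 pre_cache_tile_ids_per_batch,
 pre_cache_num_blocks_cpu,
 kv_token_num_cpu) = pre_cache_len_concat(seq_lens_decoder,
                                          seq_lens_this_time,
                                          max_dec_len,
                                          block_size=64)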

View File

@@ -29,7 +29,7 @@ if TYPE_CHECKING:
from paddle._typing.dtype_like import _DTypeLiteral
from fastdeploy.config import FDConfig
-from fastdeploy.model_executor.layers.attention import Attention
+from fastdeploy.model_executor.layers.attention.attention import Attention
from fastdeploy.model_executor.layers.attention.base_attention_backend import (
    AttentionBackend, AttentionMetadata)
from fastdeploy.worker.forward_meta import ForwardMeta

View File

@@ -329,7 +329,6 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
        with_bias: bool = False,
        add_bias: bool = False,
        activation: str = "gelu",
-        use_fast_ffn: bool = False,
        skip_quant: bool = False,
    ):
        """
@@ -344,11 +343,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
            with_bias (bool): Whether to include bias or not. Defaults to False.
            add_bias (bool): Whether to add bias in the current layer or in the pre/post layer. Defaults to False.
            activation (str): Activation function to use. Defaults to "gelu".
-            use_fast_ffn (bool): Whether to use a faster FFN implementation.
-                Defaults to False.
            skip_quant (bool): Whether to skip quantization. Defaults to False.
        """
-        self.use_fast_ffn = use_fast_ffn
        self.activation = activation
        self.hidden_size = fd_config.model_config.hidden_size
        self.nranks = fd_config.parallel_config.tensor_parallel_degree
@@ -385,23 +381,10 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
"gate_proj") "gate_proj")
bias_tensor = get_tensor(state_dict.pop(gate_bias_key)).astype( bias_tensor = get_tensor(state_dict.pop(gate_bias_key)).astype(
paddle.get_default_dtype()) paddle.get_default_dtype())
converted_bias_tensor = paddle.zeros(shape=list(
bias_tensor.shape),
dtype=bias_tensor.dtype)
if not self.use_fast_ffn:
converted_bias_tensor = paddle.concat(
[bias_tensor[::2], bias_tensor[1::2]], axis=0)
else:
converted_bias_tensor = bias_tensor
state_dict[self.bias_key] = converted_bias_tensor
if not self.use_fast_ffn: state_dict[self.bias_key] = bias_tensor
converted_weight_tensor = paddle.concat(
[weight_tensor[:, ::2], weight_tensor[:, 1::2]], axis=1)
else:
converted_weight_tensor = weight_tensor
state_dict[self.weight_key] = converted_weight_tensor state_dict[self.weight_key] = weight_tensor
super().load_state_dict(state_dict) super().load_state_dict(state_dict)
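
For reference, the deleted branch handled checkpoints whose gate/up projections were stored column-interleaved. A small illustrative sketch (not FastDeploy code) of that layout conversion:

import paddle

# Illustrative only: the removed branch converted interleaved gate/up columns
# [g0, u0, g1, u1, ...] into the blocked layout [g0, g1, ..., u0, u1, ...]
# expected by the fused FFN kernel.
weight = paddle.arange(12, dtype="float32").reshape([2, 6])  # toy [hidden, 2*inter]
blocked = paddle.concat([weight[:, ::2], weight[:, 1::2]], axis=1)
print(blocked.numpy())

With use_fast_ffn gone, fused-FFN checkpoints are assumed to already be in the blocked layout, so the weight and bias tensors are now loaded unchanged.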

View File

@@ -15,8 +15,9 @@
""" """
from typing import Optional from typing import Optional
from ..attention import Attention from fastdeploy.model_executor.layers.attention.attention import Attention
from ..moe import FusedMoE from fastdeploy.model_executor.layers.moe.moe import FusedMoE
from . import get_quantization_config from . import get_quantization_config
from .quant_base import QuantConfigBase, QuantMethodBase from .quant_base import QuantConfigBase, QuantMethodBase

View File

@@ -214,7 +214,7 @@ def load_tp_checkpoint_v1(
    need_tp = True if tensor_parallel_filtered_map else False
    state_dict = {}
    for key, weight in weights_iterator:
-        paddle.device.cuda.synchronize()
+        paddle.device.synchronize()
        if need_tp and key in tensor_parallel_filtered_map:
            action = tensor_parallel_filtered_map.pop(key)
            tensor = action(weight).clone()

View File

@@ -86,7 +86,7 @@ class DefaultModelLoader(BaseModelLoader):
            if isinstance(v, paddle.Tensor):
                v.value().get_tensor()._clear()
        paddle.device.cuda.empty_cache()
-        paddle.device.cuda.synchronize()
+        paddle.device.synchronize()

    def load_model(self, fd_config: FDConfig) -> nn.Layer:
        context = paddle.LazyGuard()
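
Both loader changes above swap the CUDA-specific synchronize for the device-agnostic one. A minimal sketch of the difference, assuming a Paddle build recent enough to provide paddle.device.synchronize:

import paddle

# paddle.device.synchronize() waits for all queued kernels on the current
# device and also works on non-CUDA backends, whereas
# paddle.device.cuda.synchronize() is CUDA-only.
x = paddle.rand([1024, 1024])
y = x @ x
paddle.device.synchronize()  # block until the matmul above has finished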

View File

@@ -28,7 +28,7 @@ from fastdeploy.config import FDConfig
from fastdeploy.distributed.communication_op import \
    tensor_model_parallel_all_reduce
from fastdeploy.model_executor.layers.activation import SiluAndMul
-from fastdeploy.model_executor.layers.attention import Attention
+from fastdeploy.model_executor.layers.attention.attention import Attention
from fastdeploy.model_executor.layers.embeddings import VocabParallelEmbedding
from fastdeploy.model_executor.layers.linear import (
    ColumnParallelLinear, KVBatchLinear, MergedColumnParallelLinear,
@@ -68,7 +68,6 @@ class DeepSeekV3MLP(nn.Layer):
            output_size=intermediate_size * 2,
            with_bias=False,
            activation=fd_config.model_config.hidden_act,
-            use_fast_ffn=True,
        )
        self.down_proj = RowParallelLinear(

View File

@@ -29,7 +29,7 @@ from fastdeploy.config import FDConfig, ModelConfig
from fastdeploy.model_executor.graph_optimization.decorator import \
    support_graph_optimization
from fastdeploy.model_executor.layers.activation import SiluAndMul
-from fastdeploy.model_executor.layers.attention import Attention
+from fastdeploy.model_executor.layers.attention.attention import Attention
from fastdeploy.model_executor.layers.embeddings import VocabParallelEmbedding
from fastdeploy.model_executor.layers.linear import (
    MergedColumnParallelLinear, QKVParallelLinear, RowParallelLinear)
@@ -62,7 +62,6 @@ class Ernie4_5_MLP(nn.Layer):
            output_size=intermediate_size * 2,
            with_bias=False,
            activation=fd_config.model_config.hidden_act,
-            use_fast_ffn=True,
        )
        self.down_proj = RowParallelLinear(

View File

@@ -27,7 +27,7 @@ from fastdeploy.config import FDConfig, ModelConfig
from fastdeploy.model_executor.graph_optimization.decorator import \
    support_graph_optimization
from fastdeploy.model_executor.layers.activation import SiluAndMul
-from fastdeploy.model_executor.layers.attention import Attention
+from fastdeploy.model_executor.layers.attention.attention import Attention
from fastdeploy.model_executor.layers.embeddings import VocabParallelEmbedding
from fastdeploy.model_executor.layers.linear import (
    MergedColumnParallelLinear, QKVParallelLinear, RowParallelLinear)
@@ -55,7 +55,6 @@ class Qwen2MLP(nn.Layer):
            output_size=fd_config.model_config.ffn_hidden_size * 2,
            with_bias=False,
            activation=fd_config.model_config.hidden_act,
-            use_fast_ffn=True,
        )
        self.down_proj = RowParallelLinear(

View File

@@ -26,7 +26,7 @@ from paddleformers.utils.log import logger
from fastdeploy.config import FDConfig, ModelConfig
from fastdeploy.model_executor.graph_optimization.decorator import \
    support_graph_optimization
-from fastdeploy.model_executor.layers.attention import Attention
+from fastdeploy.model_executor.layers.attention.attention import Attention
from fastdeploy.model_executor.layers.embeddings import VocabParallelEmbedding
from fastdeploy.model_executor.layers.linear import (QKVParallelLinear,
                                                     RowParallelLinear)

View File

@@ -27,7 +27,7 @@ from fastdeploy.config import FDConfig, ModelConfig
from fastdeploy.model_executor.graph_optimization.decorator import \
    support_graph_optimization
from fastdeploy.model_executor.layers.activation import SiluAndMul
-from fastdeploy.model_executor.layers.attention import Attention
+from fastdeploy.model_executor.layers.attention.attention import Attention
from fastdeploy.model_executor.layers.embeddings import VocabParallelEmbedding
from fastdeploy.model_executor.layers.linear import (
    MergedColumnParallelLinear, QKVParallelLinear, RowParallelLinear)
@@ -57,7 +57,6 @@ class Qwen3MLP(nn.Layer):
            output_size=fd_config.model_config.ffn_hidden_size * 2,
            with_bias=False,
            activation=fd_config.model_config.hidden_act,
-            use_fast_ffn=True,
        )
        self.down_proj = RowParallelLinear(

View File

@@ -24,6 +24,7 @@ class _Backend(enum.Enum):
    NATIVE_ATTN = enum.auto()
    APPEND_ATTN = enum.auto()
    MLA_ATTN = enum.auto()
+    FLASH_ATTN = enum.auto()

class Platform:

View File

@@ -13,9 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""
-"""
-cuda platform file
-"""

import paddle
@@ -65,6 +62,11 @@ class CUDAPlatform(Platform):
            return (
                "fastdeploy.model_executor.layers.attention.MLAAttentionBackend"
            )
+        elif selected_backend == _Backend.FLASH_ATTN:
+            logger.info("Using FLASH ATTN backend.")
+            return (
+                "fastdeploy.model_executor.layers.attention.FlashAttentionBackend"
+            )
        else:
            raise ValueError(
                "Invalid attention backend you specified.\n"

View File

@@ -90,7 +90,8 @@ class MTPProposer(Proposer):
        self.model = get_model_from_loader(self.cfg)

-    def dummy_prefill_inputs(self, num_tokens: int, batch_size: int, expected_decode_len: int):
+    def dummy_prefill_inputs(self, num_tokens: int, batch_size: int,
+                             expected_decode_len: int):
        """Set dummy prefill inputs to model_inputs"""
        max_dec_len = expected_decode_len + 1
        self.num_gpu_blocks = self.parallel_config.max_block_num
@@ -131,9 +132,9 @@ class MTPProposer(Proposer):
        cache_type = self.parallel_config.dtype
-        if (self.quant_config and
-                hasattr(self.quant_config, "kv_cache_quant_type") and
-                self.quant_config.kv_cache_quant_type is not None):
+        if (self.quant_config
+                and hasattr(self.quant_config, "kv_cache_quant_type")
+                and self.quant_config.kv_cache_quant_type is not None):
            cache_type = 'uint8'

        # Get kv cache shape
@@ -190,8 +191,7 @@ class MTPProposer(Proposer):
        head_dim = self.model_config.head_dim
        # Get the attention backend
-        attn_cls = get_attention_backend(
-            self.parallel_config.attention_backend)
+        attn_cls = get_attention_backend()
        attn_backend = attn_cls(
            self.cfg,
            kv_num_heads=self.model_config.kv_num_heads,
@@ -200,8 +200,8 @@ class MTPProposer(Proposer):
        )
        if attn_backend is None:
            raise NotImplementedError(
-                f"{ self.parallel_config.attention_backend} attention backend"
-                " is not support by GPUModelRunner")
+                "Attention backend which you specified is not supported, please set FD_ATTENTION_BACKEND correctly."
+            )
        self.attn_backends.append(attn_backend)

    def clear_dummy_input(self):

View File

@@ -593,7 +593,8 @@ class GPUModelRunner(ModelRunnerBase):
        self.model = get_model_from_loader(fd_config=self.fd_config)
        # 1.1 Load RL dynamic model
        if self.fd_config.load_config.dynamic_load_weight:
-            from fastdeploy.rl.dynamic_weight_manager import DynamicWeightManager
+            from fastdeploy.rl.dynamic_weight_manager import \
+                DynamicWeightManager
            self.dynamic_weight_manager = DynamicWeightManager(self.fd_config, self.model)

        # 2. Load lora model
@@ -719,7 +720,7 @@ class GPUModelRunner(ModelRunnerBase):
                head_dim=head_dim)
            if attn_backend is None:
                raise NotImplementedError(
-                    "Attention backend which you chose is not support by GPUModelRunner"
+                    "Attention backend which you specified is not supported, please set FD_ATTENTION_BACKEND correctly."
                )
            self.attn_backends.append(attn_backend)
@@ -1150,7 +1151,6 @@ class GPUModelRunner(ModelRunnerBase):
        if self.speculative_method in ["mtp"]:
            self.proposer.clear_dummy_input()
-        # paddle.device.cuda.synchronize()

    def update_share_input_block_num(self, num_gpu_blocks: int) -> None:
        """

View File

@@ -590,7 +590,7 @@ class XPUModelRunner(ModelRunnerBase):
                head_dim=head_dim)
            if attn_backend is None:
                raise NotImplementedError(
-                    "Attention backend which you chose is not support by GPUModelRunner"
+                    "Attention backend which you specified is not supported, please set FD_ATTENTION_BACKEND correctly."
                )
            self.attn_backends.append(attn_backend)