mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-09-26 20:41:53 +08:00
[feat] support fa3 backend for pd disaggregated (#2695)
Some checks failed
Deploy GitHub Pages / deploy (push) Has been cancelled
Some checks failed
Deploy GitHub Pages / deploy (push) Has been cancelled
* support fa3 backend run in pd disaggregated * support fa3 backend run in pd disaggregated * support fa3 backend run in pd disaggregated * support fa3 backend run in pd disaggregated * delete use_fast_ffn
This commit is contained in:
@@ -5,12 +5,6 @@ default_stages:
|
||||
- pre-commit # Run locally
|
||||
# - manual # Run in CI
|
||||
repos:
|
||||
# 格式化
|
||||
- repo: https://github.com/google/yapf
|
||||
rev: v0.43.0
|
||||
hooks:
|
||||
- id: yapf
|
||||
args: [--in-place, --verbose]
|
||||
# 代码检查
|
||||
- repo: https://github.com/astral-sh/ruff-pre-commit
|
||||
rev: v0.11.7
|
||||
@@ -29,15 +23,6 @@ repos:
|
||||
rev: 6.0.1
|
||||
hooks:
|
||||
- id: isort
|
||||
# # 格式化
|
||||
# - repo: https://github.com/pre-commit/mirrors-clang-format
|
||||
# rev: v20.1.3
|
||||
# hooks:
|
||||
# - id: clang-format
|
||||
# # exclude: '.*'
|
||||
# types_or: [c++, cuda]
|
||||
# args: [--style=file, --verbose]
|
||||
|
||||
# markdown
|
||||
- repo: https://github.com/jackdewinter/pymarkdown
|
||||
rev: v0.9.29
|
||||
|
@@ -581,7 +581,7 @@ index d5cdd01..5237f09 100644
|
||||
# Flush L2 cache with 256 MB data
|
||||
- torch.cuda.synchronize()
|
||||
- cache = torch.empty(int(256e6 // 4), dtype=torch.int, device='cuda')
|
||||
+ paddle.device.cuda.synchronize()
|
||||
+ paddle.device.synchronize()
|
||||
+ cache = paddle.empty((int(256e6 // 4)), dtype=paddle.int32)
|
||||
cache.zero_()
|
||||
|
||||
@@ -640,4 +640,3 @@ index d5cdd01..5237f09 100644
|
||||
if not using_nsys:
|
||||
--
|
||||
2.43.0
|
||||
|
||||
|
@@ -18,7 +18,7 @@ from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from typing import Optional, Literal
|
||||
from typing import Literal, Optional
|
||||
|
||||
from paddleformers.transformers.configuration_utils import PretrainedConfig
|
||||
|
||||
@@ -69,7 +69,6 @@ class ModelConfig(PretrainedConfig):
|
||||
max_seq_len: int = 512,
|
||||
initializer_range: float = 0.02,
|
||||
use_rope=True,
|
||||
use_fast_ffn: bool = False,
|
||||
rope_theta: int = 10000,
|
||||
rope_3d: bool = False,
|
||||
ori_vocab_size: int | None = None,
|
||||
@@ -104,7 +103,6 @@ class ModelConfig(PretrainedConfig):
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.initializer_range = initializer_range
|
||||
self.use_rope = use_rope
|
||||
self.use_fast_ffn = use_fast_ffn
|
||||
self.rope_theta = rope_theta
|
||||
self.ori_vocab_size = ori_vocab_size or vocab_size
|
||||
self.max_seq_len = max_seq_len
|
||||
@@ -199,7 +197,7 @@ class ParallelConfig:
|
||||
eos_tokens_lens: int = 2
|
||||
# Enable chunked prefill
|
||||
enable_chunked_prefill: str = "store_true"
|
||||
#
|
||||
|
||||
max_num_batched_tokens: int = 2048
|
||||
# enable prefix cache
|
||||
enable_prefix_caching = None
|
||||
|
@@ -728,7 +728,7 @@ class Config:
|
||||
), "XPU currently do not support guided_decoding"
|
||||
|
||||
try:
|
||||
import xgrammar
|
||||
import xgrammar # noqa
|
||||
except Exception as e:
|
||||
raise Exception(
|
||||
f"import XGrammar failed, please install XGrammar use `pip install xgrammar==0.1.19`. \n\t {e}"
|
||||
|
@@ -12,16 +12,16 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .attention import Attention
|
||||
from .append_attn_backend import AppendAttentionBackend
|
||||
from .attention_selecter import get_attention_backend
|
||||
from .base_attention_backend import AttentionBackend
|
||||
from .flash_attn_backend import FlashAttentionBackend
|
||||
from .mla_attention_backend import MLAAttentionBackend
|
||||
from .native_paddle_backend import PaddleNativeAttnBackend
|
||||
from .xpu_attn_backend import XPUAttentionBackend
|
||||
|
||||
__all__ = [
|
||||
"Attention", "AttentionBackend", "PaddleNativeAttnBackend",
|
||||
"AttentionBackend", "PaddleNativeAttnBackend",
|
||||
"get_attention_backend", "AppendAttentionBackend", "XPUAttentionBackend",
|
||||
"MLAAttentionBackend"
|
||||
"MLAAttentionBackend", "FlashAttentionBackend"
|
||||
]
|
||||
|
@@ -30,7 +30,7 @@ if TYPE_CHECKING:
|
||||
from paddle._typing.dtype_like import _DTypeLiteral
|
||||
|
||||
from fastdeploy.config import FDConfig
|
||||
from fastdeploy.model_executor.layers.attention import Attention
|
||||
from fastdeploy.model_executor.layers.attention.attention import Attention
|
||||
from fastdeploy.model_executor.layers.attention.base_attention_backend import (
|
||||
AttentionBackend, AttentionMetadata)
|
||||
from fastdeploy.worker.forward_meta import ForwardMeta
|
||||
|
243
fastdeploy/model_executor/layers/attention/flash_attn_backend.py
Normal file
243
fastdeploy/model_executor/layers/attention/flash_attn_backend.py
Normal file
@@ -0,0 +1,243 @@
|
||||
"""
|
||||
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
from typing import List, Optional
|
||||
|
||||
import paddle
|
||||
from paddle.nn.functional.flash_attention import flash_attention_v3_varlen
|
||||
|
||||
from fastdeploy.config import FDConfig
|
||||
from fastdeploy.model_executor.layers.attention.attention import Attention
|
||||
from fastdeploy.model_executor.layers.attention.base_attention_backend import (
|
||||
AttentionBackend, AttentionMetadata)
|
||||
from fastdeploy.model_executor.layers.attention.ops import (
|
||||
get_block_shape_and_split_kv_block, gqa_rope_write_cache,
|
||||
init_signal_layerwise, open_shm_and_get_meta_signal, pre_cache_len_concat)
|
||||
from fastdeploy.worker.forward_meta import ForwardMeta
|
||||
|
||||
|
||||
@dataclass
|
||||
class FlashAttentionMetadata(AttentionMetadata):
|
||||
"""
|
||||
FlashAttentionMetadata
|
||||
"""
|
||||
max_len_kv: paddle.Tensor = None
|
||||
set_max_lengths: int = -1
|
||||
rotary_embs: Optional[paddle.Tensor] = None
|
||||
block_tables: Optional[paddle.Tensor] = None
|
||||
encoder_batch_ids: paddle.Tensor = None
|
||||
encoder_tile_ids_per_batch: paddle.Tensor = None
|
||||
encoder_num_blocks: paddle.Tensor = None
|
||||
kv_batch_ids: paddle.Tensor = None
|
||||
kv_tile_ids_per_batch: paddle.Tensor = None
|
||||
kv_num_blocks: paddle.Tensor = None
|
||||
decoder_batch_ids: paddle.Tensor = None
|
||||
decoder_tile_ids_per_batch: paddle.Tensor = None
|
||||
decoder_num_blocks: paddle.Tensor = None
|
||||
|
||||
encoder_block_shape_q: Optional[paddle.Tensor] = None
|
||||
decoder_block_shape_q: Optional[paddle.Tensor] = None
|
||||
|
||||
cu_seqlens_q: paddle.Tensor = None
|
||||
cu_seqlens_k: paddle.Tensor = None
|
||||
max_seqlen_q: int = 0
|
||||
max_seqlen_k: int = 0
|
||||
|
||||
pre_cache_batch_ids = None
|
||||
pre_cache_tile_ids_per_batch = None
|
||||
pre_cache_num_blocks_cpu = None
|
||||
kv_token_num_cpu = None
|
||||
|
||||
# pd_disaggregation
|
||||
kv_signal_metadata: Optional[paddle.Tensor] = None
|
||||
kv_signal_data_list: List[paddle.Tensor] = field(default_factory=list)
|
||||
|
||||
|
||||
class FlashAttentionBackend(AttentionBackend):
|
||||
"""
|
||||
FlashAttentionBackend backend implementation
|
||||
"""
|
||||
|
||||
def __init__(self, fd_config: FDConfig, kv_num_heads: int, num_heads: int,
|
||||
head_dim: int):
|
||||
"""
|
||||
FlashAttentionBackend __init__
|
||||
"""
|
||||
super().__init__()
|
||||
self.attention_metadata: FlashAttentionMetadata = None
|
||||
self.max_seq_len = fd_config.parallel_config.max_model_len
|
||||
self.causal = getattr(fd_config.model_config, "causal", True)
|
||||
|
||||
self.kv_num_heads = kv_num_heads
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = fd_config.model_config.head_dim
|
||||
self.hidden_size = fd_config.model_config.hidden_size
|
||||
self.block_size = fd_config.parallel_config.block_size
|
||||
self.num_layers: int = fd_config.model_config.num_layers
|
||||
|
||||
self.speculative_method = fd_config.speculative_config.method
|
||||
self.use_speculate = self.speculative_method is not None
|
||||
self.speculate_max_draft_token_num = fd_config.speculative_config.num_speculative_tokens
|
||||
self.keep_pd_step_flag: bool = fd_config.speculative_config.model_type == "mtp"
|
||||
self.rank: int = fd_config.parallel_config.tensor_parallel_rank
|
||||
|
||||
# pd_disaggregation
|
||||
self.use_pd_disaggregation: int = int(
|
||||
os.getenv("FLAGS_use_pd_disaggregation", 0))
|
||||
self.start_layer_index: int = fd_config.model_config.start_layer_index
|
||||
self.device_id: int = os.getenv("CUDA_VISIBLE_DEVICES", None)
|
||||
|
||||
if fd_config.parallel_config.expert_parallel_rank is None:
|
||||
fd_config.parallel_config.expert_parallel_rank = 0
|
||||
device_id = self.rank + fd_config.parallel_config.tensor_parallel_degree * \
|
||||
fd_config.parallel_config.expert_parallel_rank
|
||||
if self.device_id is None:
|
||||
self.device_id = device_id
|
||||
else:
|
||||
self.device_id = self.device_id.split(",")[device_id]
|
||||
|
||||
def get_attntion_meta(self):
|
||||
"""get_attntion_meta"""
|
||||
return self.attention_metadata
|
||||
|
||||
def get_kv_cache_shape(
|
||||
self,
|
||||
max_num_blocks: int,
|
||||
):
|
||||
"""
|
||||
Caculate kv cache shape
|
||||
"""
|
||||
return (max_num_blocks, self.kv_num_heads, self.block_size,
|
||||
self.head_dim)
|
||||
|
||||
def init_attention_metadata(self, forward_meta: ForwardMeta):
|
||||
metadata = FlashAttentionMetadata()
|
||||
metadata.encoder_block_shape_q = 64
|
||||
metadata.decoder_block_shape_q = 16
|
||||
metadata.cu_seqlens_q = forward_meta.cu_seqlens_q
|
||||
metadata.rotary_embs = forward_meta.rotary_embs
|
||||
metadata.block_tables = forward_meta.block_tables
|
||||
(
|
||||
metadata.encoder_batch_ids,
|
||||
metadata.encoder_tile_ids_per_batch,
|
||||
metadata.encoder_num_blocks,
|
||||
metadata.kv_batch_ids,
|
||||
metadata.kv_tile_ids_per_batch,
|
||||
metadata.kv_num_blocks,
|
||||
metadata.decoder_batch_ids,
|
||||
metadata.decoder_tile_ids_per_batch,
|
||||
metadata.decoder_num_blocks,
|
||||
metadata.max_len_kv,
|
||||
metadata.set_max_lengths,
|
||||
) = get_block_shape_and_split_kv_block(
|
||||
forward_meta.seq_lens_encoder,
|
||||
forward_meta.seq_lens_decoder,
|
||||
forward_meta.seq_lens_this_time,
|
||||
forward_meta.cum_offsets,
|
||||
metadata.encoder_block_shape_q,
|
||||
metadata.decoder_block_shape_q,
|
||||
self.num_heads // self.kv_num_heads,
|
||||
self.block_size,
|
||||
self.speculate_max_draft_token_num + 1,
|
||||
)
|
||||
|
||||
(
|
||||
metadata.cu_seqlens_k,
|
||||
metadata.pre_cache_batch_ids,
|
||||
metadata.pre_cache_tile_ids_per_batch,
|
||||
metadata.pre_cache_num_blocks_cpu,
|
||||
metadata.kv_token_num_cpu,
|
||||
) = pre_cache_len_concat(
|
||||
forward_meta.seq_lens_decoder,
|
||||
forward_meta.seq_lens_this_time,
|
||||
metadata.set_max_lengths[2],
|
||||
self.block_size,
|
||||
)
|
||||
|
||||
# pd_disaggregation
|
||||
metadata.kv_signal_data_list = [None] * self.num_layers
|
||||
if self.use_pd_disaggregation:
|
||||
metadata.kv_signal_metadata = open_shm_and_get_meta_signal(
|
||||
self.rank, int(self.device_id), self.keep_pd_step_flag)
|
||||
self.attention_metadata = metadata
|
||||
forward_meta.decoder_batch_ids.copy_(metadata.decoder_batch_ids, False)
|
||||
forward_meta.decoder_tile_ids_per_batch.copy_(
|
||||
metadata.decoder_tile_ids_per_batch, False)
|
||||
|
||||
def forward_mixed(
|
||||
self,
|
||||
q: paddle.Tensor,
|
||||
k: paddle.Tensor,
|
||||
v: paddle.Tensor,
|
||||
qkv: paddle.Tensor,
|
||||
compressed_kv: paddle.Tensor,
|
||||
k_pe: paddle.Tensor,
|
||||
layer: Attention,
|
||||
forward_meta: ForwardMeta,
|
||||
):
|
||||
metadata = self.attention_metadata
|
||||
|
||||
if self.use_pd_disaggregation:
|
||||
metadata.kv_signal_data_list[
|
||||
layer.layer_id] = init_signal_layerwise(
|
||||
metadata.kv_signal_metadata,
|
||||
layer.layer_id + self.start_layer_index)
|
||||
|
||||
q, k, v, _ = gqa_rope_write_cache(
|
||||
qkv,
|
||||
forward_meta.caches[2 * layer.layer_id],
|
||||
forward_meta.caches[2 * layer.layer_id + 1],
|
||||
metadata.cu_seqlens_q,
|
||||
metadata.cu_seqlens_k,
|
||||
metadata.rotary_embs,
|
||||
forward_meta.seq_lens_this_time,
|
||||
forward_meta.seq_lens_encoder,
|
||||
forward_meta.seq_lens_decoder,
|
||||
forward_meta.padding_offset,
|
||||
forward_meta.cum_offsets,
|
||||
metadata.block_tables,
|
||||
metadata.kv_batch_ids,
|
||||
metadata.kv_tile_ids_per_batch,
|
||||
metadata.kv_num_blocks,
|
||||
metadata.pre_cache_batch_ids,
|
||||
metadata.pre_cache_tile_ids_per_batch,
|
||||
metadata.pre_cache_num_blocks_cpu,
|
||||
getattr(layer, "cache_k_scale", None),
|
||||
getattr(layer, "cache_v_scale", None),
|
||||
getattr(layer, "cache_k_out_scale", None),
|
||||
getattr(layer, "cache_v_out_scale", None),
|
||||
getattr(layer, "cache_k_zp", None),
|
||||
getattr(layer, "cache_v_zp", None),
|
||||
metadata.kv_signal_data_list[layer.layer_id],
|
||||
metadata.kv_token_num_cpu[0],
|
||||
self.max_seq_len,
|
||||
getattr(layer, "cache_quant_type_str", "none"),
|
||||
)
|
||||
res = flash_attention_v3_varlen(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
metadata.cu_seqlens_q,
|
||||
metadata.cu_seqlens_k,
|
||||
max_seqlen_q=metadata.set_max_lengths[0],
|
||||
max_seqlen_k=metadata.set_max_lengths[3],
|
||||
causal=self.causal,
|
||||
)[0].reshape([-1, self.hidden_size])
|
||||
return res
|
@@ -38,7 +38,7 @@ if TYPE_CHECKING:
|
||||
from paddle._typing.dtype_like import _DTypeLiteral
|
||||
|
||||
from fastdeploy.config import FDConfig
|
||||
from fastdeploy.model_executor.layers.attention import Attention
|
||||
from fastdeploy.model_executor.layers.attention.attention import Attention
|
||||
from fastdeploy.model_executor.layers.attention.base_attention_backend import (
|
||||
AttentionBackend, AttentionMetadata)
|
||||
from fastdeploy.worker.forward_meta import ForwardMeta
|
||||
|
@@ -17,10 +17,16 @@
|
||||
from .append_attention import append_attention
|
||||
from .get_block_shape_and_split_kv_block import \
|
||||
get_block_shape_and_split_kv_block
|
||||
from .gqa_rope_write_cache import gqa_rope_write_cache
|
||||
from .init_signal_layerwise import init_signal_layerwise
|
||||
from .open_shm_and_get_meta_signal import open_shm_and_get_meta_signal
|
||||
from .pre_cache_len_concat import pre_cache_len_concat
|
||||
|
||||
__all__ = [
|
||||
"get_block_shape_and_split_kv_block", "append_attention",
|
||||
"open_shm_and_get_meta_signal", "init_signal_layerwise"
|
||||
"get_block_shape_and_split_kv_block",
|
||||
"append_attention",
|
||||
"open_shm_and_get_meta_signal",
|
||||
"init_signal_layerwise",
|
||||
"gqa_rope_write_cache",
|
||||
"pre_cache_len_concat",
|
||||
]
|
||||
|
@@ -0,0 +1,66 @@
|
||||
"""
|
||||
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import paddle
|
||||
|
||||
from fastdeploy.platforms import current_platform
|
||||
|
||||
|
||||
def gqa_rope_write_cache(
|
||||
qkv: paddle.Tensor,
|
||||
key_cache: paddle.Tensor,
|
||||
value_cache: paddle.Tensor,
|
||||
cu_seqlens_q: paddle.Tensor,
|
||||
cu_seqlens_k: paddle.Tensor,
|
||||
rotary_embs: paddle.Tensor,
|
||||
seq_lens_this_time: paddle.Tensor,
|
||||
seq_lens_encoder: paddle.Tensor,
|
||||
seq_lens_decoder: paddle.Tensor,
|
||||
padding_offsets: paddle.Tensor,
|
||||
cum_offsets: paddle.Tensor,
|
||||
block_tables: paddle.Tensor,
|
||||
kv_batch_ids: paddle.Tensor,
|
||||
kv_tile_ids_per_batch: paddle.Tensor,
|
||||
kv_num_blocks: paddle.Tensor,
|
||||
cache_batch_ids: paddle.Tensor,
|
||||
cache_tile_ids_per_batch: paddle.Tensor,
|
||||
cache_num_blocks: paddle.Tensor,
|
||||
cache_k_quant_scales: Optional[paddle.Tensor] = None,
|
||||
cache_v_quant_scales: Optional[paddle.Tensor] = None,
|
||||
cache_k_dequant_scales: Optional[paddle.Tensor] = None,
|
||||
cache_v_dequant_scales: Optional[paddle.Tensor] = None,
|
||||
cache_k_zp: Optional[paddle.Tensor] = None,
|
||||
cache_v_zp: Optional[paddle.Tensor] = None,
|
||||
kv_signal_data: Optional[paddle.Tensor] = None,
|
||||
kv_token_num: int = 1,
|
||||
max_seq_len: int = 0,
|
||||
cache_quant_type: str = "none"):
|
||||
if current_platform.is_cuda():
|
||||
from fastdeploy.model_executor.ops.gpu import gqa_rope_write_cache
|
||||
q, k, v, qkv_ = gqa_rope_write_cache(
|
||||
qkv, key_cache, value_cache, cu_seqlens_q, cu_seqlens_k,
|
||||
rotary_embs, seq_lens_this_time, seq_lens_encoder,
|
||||
seq_lens_decoder, padding_offsets, cum_offsets, block_tables,
|
||||
kv_batch_ids, kv_tile_ids_per_batch, kv_num_blocks,
|
||||
cache_batch_ids, cache_tile_ids_per_batch, cache_num_blocks,
|
||||
cache_k_quant_scales, cache_v_quant_scales, cache_k_dequant_scales,
|
||||
cache_v_dequant_scales, cache_k_zp, cache_v_zp, kv_signal_data,
|
||||
kv_token_num, max_seq_len, cache_quant_type)
|
||||
return q, k, v, qkv_
|
||||
else:
|
||||
raise NotImplementedError()
|
@@ -0,0 +1,36 @@
|
||||
"""
|
||||
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License,
|
||||
Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND,
|
||||
either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
|
||||
import paddle
|
||||
|
||||
from fastdeploy.platforms import current_platform
|
||||
|
||||
|
||||
def pre_cache_len_concat(seq_lens_decoder: paddle.Tensor,
|
||||
seq_lens_this_time: paddle.Tensor,
|
||||
max_dec_len: int = 0,
|
||||
block_size: int = 64):
|
||||
if current_platform.is_cuda():
|
||||
from fastdeploy.model_executor.ops.gpu import pre_cache_len_concat
|
||||
out = pre_cache_len_concat(seq_lens_decoder, seq_lens_this_time,
|
||||
max_dec_len, block_size)
|
||||
return out
|
||||
else:
|
||||
raise NotImplementedError()
|
@@ -29,7 +29,7 @@ if TYPE_CHECKING:
|
||||
from paddle._typing.dtype_like import _DTypeLiteral
|
||||
|
||||
from fastdeploy.config import FDConfig
|
||||
from fastdeploy.model_executor.layers.attention import Attention
|
||||
from fastdeploy.model_executor.layers.attention.attention import Attention
|
||||
from fastdeploy.model_executor.layers.attention.base_attention_backend import (
|
||||
AttentionBackend, AttentionMetadata)
|
||||
from fastdeploy.worker.forward_meta import ForwardMeta
|
||||
|
@@ -329,7 +329,6 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
|
||||
with_bias: bool = False,
|
||||
add_bias: bool = False,
|
||||
activation: str = "gelu",
|
||||
use_fast_ffn: bool = False,
|
||||
skip_quant: bool = False,
|
||||
):
|
||||
"""
|
||||
@@ -344,11 +343,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
|
||||
with_bias (bool): Whether to include bias or not. Defaults to False.
|
||||
add_bias (bool): Whether to add bias in the current layer or in the pre/post layer. Defaults to False.
|
||||
activation (str): Activation function to use. Defaults to "gelu".
|
||||
use_fast_ffn (bool): Whether to use a faster FFN implementation.
|
||||
Defaults to False.
|
||||
skip_quant (bool): Whether to skip quantization. Defaults to False.
|
||||
"""
|
||||
self.use_fast_ffn = use_fast_ffn
|
||||
self.activation = activation
|
||||
self.hidden_size = fd_config.model_config.hidden_size
|
||||
self.nranks = fd_config.parallel_config.tensor_parallel_degree
|
||||
@@ -385,23 +381,10 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
|
||||
"gate_proj")
|
||||
bias_tensor = get_tensor(state_dict.pop(gate_bias_key)).astype(
|
||||
paddle.get_default_dtype())
|
||||
converted_bias_tensor = paddle.zeros(shape=list(
|
||||
bias_tensor.shape),
|
||||
dtype=bias_tensor.dtype)
|
||||
if not self.use_fast_ffn:
|
||||
converted_bias_tensor = paddle.concat(
|
||||
[bias_tensor[::2], bias_tensor[1::2]], axis=0)
|
||||
else:
|
||||
converted_bias_tensor = bias_tensor
|
||||
state_dict[self.bias_key] = converted_bias_tensor
|
||||
|
||||
if not self.use_fast_ffn:
|
||||
converted_weight_tensor = paddle.concat(
|
||||
[weight_tensor[:, ::2], weight_tensor[:, 1::2]], axis=1)
|
||||
else:
|
||||
converted_weight_tensor = weight_tensor
|
||||
state_dict[self.bias_key] = bias_tensor
|
||||
|
||||
state_dict[self.weight_key] = converted_weight_tensor
|
||||
state_dict[self.weight_key] = weight_tensor
|
||||
|
||||
super().load_state_dict(state_dict)
|
||||
|
||||
|
@@ -15,8 +15,9 @@
|
||||
"""
|
||||
from typing import Optional
|
||||
|
||||
from ..attention import Attention
|
||||
from ..moe import FusedMoE
|
||||
from fastdeploy.model_executor.layers.attention.attention import Attention
|
||||
from fastdeploy.model_executor.layers.moe.moe import FusedMoE
|
||||
|
||||
from . import get_quantization_config
|
||||
from .quant_base import QuantConfigBase, QuantMethodBase
|
||||
|
||||
|
@@ -214,7 +214,7 @@ def load_tp_checkpoint_v1(
|
||||
need_tp = True if tensor_parallel_filtered_map else False
|
||||
state_dict = {}
|
||||
for key, weight in weights_iterator:
|
||||
paddle.device.cuda.synchronize()
|
||||
paddle.device.synchronize()
|
||||
if need_tp and key in tensor_parallel_filtered_map:
|
||||
action = tensor_parallel_filtered_map.pop(key)
|
||||
tensor = action(weight).clone()
|
||||
|
@@ -86,7 +86,7 @@ class DefaultModelLoader(BaseModelLoader):
|
||||
if isinstance(v, paddle.Tensor):
|
||||
v.value().get_tensor()._clear()
|
||||
paddle.device.cuda.empty_cache()
|
||||
paddle.device.cuda.synchronize()
|
||||
paddle.device.synchronize()
|
||||
|
||||
def load_model(self, fd_config: FDConfig) -> nn.Layer:
|
||||
context = paddle.LazyGuard()
|
||||
|
@@ -28,7 +28,7 @@ from fastdeploy.config import FDConfig
|
||||
from fastdeploy.distributed.communication_op import \
|
||||
tensor_model_parallel_all_reduce
|
||||
from fastdeploy.model_executor.layers.activation import SiluAndMul
|
||||
from fastdeploy.model_executor.layers.attention import Attention
|
||||
from fastdeploy.model_executor.layers.attention.attention import Attention
|
||||
from fastdeploy.model_executor.layers.embeddings import VocabParallelEmbedding
|
||||
from fastdeploy.model_executor.layers.linear import (
|
||||
ColumnParallelLinear, KVBatchLinear, MergedColumnParallelLinear,
|
||||
@@ -68,7 +68,6 @@ class DeepSeekV3MLP(nn.Layer):
|
||||
output_size=intermediate_size * 2,
|
||||
with_bias=False,
|
||||
activation=fd_config.model_config.hidden_act,
|
||||
use_fast_ffn=True,
|
||||
)
|
||||
|
||||
self.down_proj = RowParallelLinear(
|
||||
|
@@ -29,7 +29,7 @@ from fastdeploy.config import FDConfig, ModelConfig
|
||||
from fastdeploy.model_executor.graph_optimization.decorator import \
|
||||
support_graph_optimization
|
||||
from fastdeploy.model_executor.layers.activation import SiluAndMul
|
||||
from fastdeploy.model_executor.layers.attention import Attention
|
||||
from fastdeploy.model_executor.layers.attention.attention import Attention
|
||||
from fastdeploy.model_executor.layers.embeddings import VocabParallelEmbedding
|
||||
from fastdeploy.model_executor.layers.linear import (
|
||||
MergedColumnParallelLinear, QKVParallelLinear, RowParallelLinear)
|
||||
@@ -62,7 +62,6 @@ class Ernie4_5_MLP(nn.Layer):
|
||||
output_size=intermediate_size * 2,
|
||||
with_bias=False,
|
||||
activation=fd_config.model_config.hidden_act,
|
||||
use_fast_ffn=True,
|
||||
)
|
||||
|
||||
self.down_proj = RowParallelLinear(
|
||||
|
@@ -27,7 +27,7 @@ from fastdeploy.config import FDConfig, ModelConfig
|
||||
from fastdeploy.model_executor.graph_optimization.decorator import \
|
||||
support_graph_optimization
|
||||
from fastdeploy.model_executor.layers.activation import SiluAndMul
|
||||
from fastdeploy.model_executor.layers.attention import Attention
|
||||
from fastdeploy.model_executor.layers.attention.attention import Attention
|
||||
from fastdeploy.model_executor.layers.embeddings import VocabParallelEmbedding
|
||||
from fastdeploy.model_executor.layers.linear import (
|
||||
MergedColumnParallelLinear, QKVParallelLinear, RowParallelLinear)
|
||||
@@ -55,7 +55,6 @@ class Qwen2MLP(nn.Layer):
|
||||
output_size=fd_config.model_config.ffn_hidden_size * 2,
|
||||
with_bias=False,
|
||||
activation=fd_config.model_config.hidden_act,
|
||||
use_fast_ffn=True,
|
||||
)
|
||||
|
||||
self.down_proj = RowParallelLinear(
|
||||
|
@@ -26,7 +26,7 @@ from paddleformers.utils.log import logger
|
||||
from fastdeploy.config import FDConfig, ModelConfig
|
||||
from fastdeploy.model_executor.graph_optimization.decorator import \
|
||||
support_graph_optimization
|
||||
from fastdeploy.model_executor.layers.attention import Attention
|
||||
from fastdeploy.model_executor.layers.attention.attention import Attention
|
||||
from fastdeploy.model_executor.layers.embeddings import VocabParallelEmbedding
|
||||
from fastdeploy.model_executor.layers.linear import (QKVParallelLinear,
|
||||
RowParallelLinear)
|
||||
|
@@ -27,7 +27,7 @@ from fastdeploy.config import FDConfig, ModelConfig
|
||||
from fastdeploy.model_executor.graph_optimization.decorator import \
|
||||
support_graph_optimization
|
||||
from fastdeploy.model_executor.layers.activation import SiluAndMul
|
||||
from fastdeploy.model_executor.layers.attention import Attention
|
||||
from fastdeploy.model_executor.layers.attention.attention import Attention
|
||||
from fastdeploy.model_executor.layers.embeddings import VocabParallelEmbedding
|
||||
from fastdeploy.model_executor.layers.linear import (
|
||||
MergedColumnParallelLinear, QKVParallelLinear, RowParallelLinear)
|
||||
@@ -57,7 +57,6 @@ class Qwen3MLP(nn.Layer):
|
||||
output_size=fd_config.model_config.ffn_hidden_size * 2,
|
||||
with_bias=False,
|
||||
activation=fd_config.model_config.hidden_act,
|
||||
use_fast_ffn=True,
|
||||
)
|
||||
|
||||
self.down_proj = RowParallelLinear(
|
||||
|
@@ -24,6 +24,7 @@ class _Backend(enum.Enum):
|
||||
NATIVE_ATTN = enum.auto()
|
||||
APPEND_ATTN = enum.auto()
|
||||
MLA_ATTN = enum.auto()
|
||||
FLASH_ATTN = enum.auto()
|
||||
|
||||
|
||||
class Platform:
|
||||
|
@@ -13,9 +13,6 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
"""
|
||||
cuda platform file
|
||||
"""
|
||||
|
||||
import paddle
|
||||
|
||||
@@ -65,6 +62,11 @@ class CUDAPlatform(Platform):
|
||||
return (
|
||||
"fastdeploy.model_executor.layers.attention.MLAAttentionBackend"
|
||||
)
|
||||
elif selected_backend == _Backend.FLASH_ATTN:
|
||||
logger.info("Using FLASH ATTN backend.")
|
||||
return (
|
||||
"fastdeploy.model_executor.layers.attention.FlashAttentionBackend"
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
"Invalid attention backend you specified.\n"
|
||||
|
@@ -90,7 +90,8 @@ class MTPProposer(Proposer):
|
||||
|
||||
self.model = get_model_from_loader(self.cfg)
|
||||
|
||||
def dummy_prefill_inputs(self, num_tokens: int, batch_size: int, expected_decode_len: int):
|
||||
def dummy_prefill_inputs(self, num_tokens: int, batch_size: int,
|
||||
expected_decode_len: int):
|
||||
"""Set dummy prefill inputs to model_inputs"""
|
||||
max_dec_len = expected_decode_len + 1
|
||||
self.num_gpu_blocks = self.parallel_config.max_block_num
|
||||
@@ -131,9 +132,9 @@ class MTPProposer(Proposer):
|
||||
|
||||
cache_type = self.parallel_config.dtype
|
||||
|
||||
if (self.quant_config and
|
||||
hasattr(self.quant_config, "kv_cache_quant_type") and
|
||||
self.quant_config.kv_cache_quant_type is not None):
|
||||
if (self.quant_config
|
||||
and hasattr(self.quant_config, "kv_cache_quant_type")
|
||||
and self.quant_config.kv_cache_quant_type is not None):
|
||||
cache_type = 'uint8'
|
||||
|
||||
# Get kv cache shape
|
||||
@@ -190,8 +191,7 @@ class MTPProposer(Proposer):
|
||||
head_dim = self.model_config.head_dim
|
||||
|
||||
# Get the attention backend
|
||||
attn_cls = get_attention_backend(
|
||||
self.parallel_config.attention_backend)
|
||||
attn_cls = get_attention_backend()
|
||||
attn_backend = attn_cls(
|
||||
self.cfg,
|
||||
kv_num_heads=self.model_config.kv_num_heads,
|
||||
@@ -200,8 +200,8 @@ class MTPProposer(Proposer):
|
||||
)
|
||||
if attn_backend is None:
|
||||
raise NotImplementedError(
|
||||
f"{ self.parallel_config.attention_backend} attention backend"
|
||||
" is not support by GPUModelRunner")
|
||||
"Attention backend which you specified is not supported, please set FD_ATTENTION_BACKEND correctly."
|
||||
)
|
||||
self.attn_backends.append(attn_backend)
|
||||
|
||||
def clear_dummy_input(self):
|
||||
|
@@ -593,7 +593,8 @@ class GPUModelRunner(ModelRunnerBase):
|
||||
self.model = get_model_from_loader(fd_config=self.fd_config)
|
||||
# 1.1 Load RL dynamic model
|
||||
if self.fd_config.load_config.dynamic_load_weight:
|
||||
from fastdeploy.rl.dynamic_weight_manager import DynamicWeightManager
|
||||
from fastdeploy.rl.dynamic_weight_manager import \
|
||||
DynamicWeightManager
|
||||
self.dynamic_weight_manager = DynamicWeightManager(self.fd_config, self.model)
|
||||
|
||||
# 2. Load lora model
|
||||
@@ -719,7 +720,7 @@ class GPUModelRunner(ModelRunnerBase):
|
||||
head_dim=head_dim)
|
||||
if attn_backend is None:
|
||||
raise NotImplementedError(
|
||||
"Attention backend which you chose is not support by GPUModelRunner"
|
||||
"Attention backend which you specified is not supported, please set FD_ATTENTION_BACKEND correctly."
|
||||
)
|
||||
self.attn_backends.append(attn_backend)
|
||||
|
||||
@@ -1150,7 +1151,6 @@ class GPUModelRunner(ModelRunnerBase):
|
||||
|
||||
if self.speculative_method in ["mtp"]:
|
||||
self.proposer.clear_dummy_input()
|
||||
# paddle.device.cuda.synchronize()
|
||||
|
||||
def update_share_input_block_num(self, num_gpu_blocks: int) -> None:
|
||||
"""
|
||||
|
@@ -590,7 +590,7 @@ class XPUModelRunner(ModelRunnerBase):
|
||||
head_dim=head_dim)
|
||||
if attn_backend is None:
|
||||
raise NotImplementedError(
|
||||
"Attention backend which you chose is not support by GPUModelRunner"
|
||||
"Attention backend which you specified is not supported, please set FD_ATTENTION_BACKEND correctly."
|
||||
)
|
||||
self.attn_backends.append(attn_backend)
|
||||
|
||||
|
Reference in New Issue
Block a user