diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index faa05efbf..c7a2d150e 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -5,12 +5,6 @@ default_stages: - pre-commit # Run locally # - manual # Run in CI repos: -# 格式化 -- repo: https://github.com/google/yapf - rev: v0.43.0 - hooks: - - id: yapf - args: [--in-place, --verbose] # 代码检查 - repo: https://github.com/astral-sh/ruff-pre-commit rev: v0.11.7 @@ -29,15 +23,6 @@ repos: rev: 6.0.1 hooks: - id: isort -# # 格式化 -# - repo: https://github.com/pre-commit/mirrors-clang-format -# rev: v20.1.3 -# hooks: -# - id: clang-format -# # exclude: '.*' -# types_or: [c++, cuda] -# args: [--style=file, --verbose] - # markdown - repo: https://github.com/jackdewinter/pymarkdown rev: v0.9.29 diff --git a/custom_ops/0001-DeepGEMM-95e81b3.patch b/custom_ops/0001-DeepGEMM-95e81b3.patch index e62972cec..c3f409c14 100644 --- a/custom_ops/0001-DeepGEMM-95e81b3.patch +++ b/custom_ops/0001-DeepGEMM-95e81b3.patch @@ -26,7 +26,7 @@ index 15b22ca..63e7fb7 100644 @@ -1,4 +1,4 @@ -import torch +import paddle - + from . import jit from .jit_kernels import ( diff --git a/deep_gemm/include/deep_gemm/scheduler.cuh b/deep_gemm/include/deep_gemm/scheduler.cuh @@ -53,7 +53,7 @@ index c17d466..6fdc52f 100644 -from torch.utils.cpp_extension import CUDA_HOME +from ..paddle_utils import CUDA_HOME from typing import Tuple - + from . import interleave_ffma diff --git a/deep_gemm/jit/interleave_ffma.py b/deep_gemm/jit/interleave_ffma.py index fcb377e..db9d6f3 100644 @@ -65,8 +65,8 @@ index fcb377e..db9d6f3 100644 import subprocess -from torch.utils.cpp_extension import CUDA_HOME +from ..paddle_utils import CUDA_HOME - - + + def run_cuobjdump(file_path): diff --git a/deep_gemm/jit/runtime.py b/deep_gemm/jit/runtime.py index 66c370a..4761426 100644 @@ -78,7 +78,7 @@ index 66c370a..4761426 100644 -import torch +import paddle from typing import Optional - + from .template import map_ctype @@ -35,7 +35,7 @@ class Runtime: assert len(args) == len(self.args), f'Expected {len(self.args)} arguments, got {len(args)}' @@ -100,8 +100,8 @@ index ead37f5..51b02c1 100644 -import torch +import paddle from typing import Any, Dict, Iterable, Tuple - - + + # Name map for Python `eval` typename_map: Dict[Any, str] = { **{t: t.__name__ for t in (bool, int, float)}, @@ -116,15 +116,15 @@ index ead37f5..51b02c1 100644 + paddle.float8_e4m3fn: 'paddle.float8_e4m3fn', + paddle.device.cuda.Stream: "paddle.device.cuda.Stream", } - + # `ctype` map for Python casting ctype_map: Dict[Any, Any] = { **{t: getattr(ctypes, f'c_{t.__name__}') for t in (bool, int, float)}, - **{t: ctypes.c_void_p for t in (torch.int, torch.float, torch.bfloat16, torch.float8_e4m3fn, torch.cuda.Stream)}, + **{t: ctypes.c_void_p for t in (paddle.int32, paddle.float32, paddle.bfloat16, paddle.float8_e4m3fn, paddle.device.cuda.Stream)}, } - - + + @@ -27,25 +27,25 @@ genc_map = { bool: ('bool', 'bool'), int: ('int', 'int'), @@ -140,8 +140,8 @@ index ead37f5..51b02c1 100644 + paddle.float8_e4m3fn: ('void*', '__nv_fp8_e4m3*'), + paddle.device.cuda.Stream: ('void*', 'cudaStream_t'), } - - + + def map_ctype(value: Any) -> Any: if hasattr(value, 'data_ptr'): - if value.dtype == torch.int: @@ -171,11 +171,11 @@ index cb438b7..44aa0ed 100644 +import paddle from functools import lru_cache from typing import Tuple - + @@ -166,20 +166,20 @@ def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int, return num_min_sms, best_block_m, best_block_n, best_num_stages, best_tma_multicast_config, 
best_smem_config - - + + -def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor], - rhs: Tuple[torch.Tensor, torch.Tensor], - out: torch.Tensor) -> None: @@ -189,7 +189,7 @@ index cb438b7..44aa0ed 100644 The LHS scaling tensor requires TMA-aligned transposed format, if your input does not match the requirement, - this function will do a transposing with a set of slow PyTorch operations. + this function will do a transposing with a set of slow paddle operations. - + Arguments: - lhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[m, k]`, + lhs: the first element is an FP8 tensor (typed `paddle.float8_e4m3fn`) of shape `[m, k]`, @@ -202,10 +202,10 @@ index cb438b7..44aa0ed 100644 @@ -189,22 +189,22 @@ def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor], n, k_ = rhs.shape m_, n_ = out.shape - + - assert n % 64 == 0 and k % 128 == 0 + # assert n % 64 == 0 and k % 128 == 0 - + # Type and shape checks - assert m == m_ and n == n_ and k == k_ - assert n > 0 and k > 0 @@ -223,13 +223,13 @@ index cb438b7..44aa0ed 100644 + # assert rhs.dtype == paddle.float8_e4m3fn and rhs_scales.dtype == paddle.float32 + # assert out.dtype == paddle.bfloat16 + # assert lhs.is_contiguous() and rhs.is_contiguous() and out.is_contiguous() - + # LHS scales must be transposed for TMA load, but not for RHS scales # NOTES: `get_tma_aligned_lhs_scales` may launch a kernel if not processed by previous kernels lhs_scales = get_col_major_tma_aligned_tensor(lhs_scales) - assert rhs_scales.is_contiguous() + # assert rhs_scales.is_contiguous() - + # Do nothing if `m` is zero if m == 0: @@ -214,7 +214,7 @@ def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor], @@ -264,12 +264,12 @@ index 3b518c9..ba776bd 100644 -import torch +import paddle from typing import Tuple - + from .gemm import get_best_configs, get_block_n_padding_for_smem_d @@ -37,25 +37,25 @@ gemm_t::run(out, rhs_scales, grouped_layout, """ - - + + -def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Tensor, torch.Tensor], - rhs: Tuple[torch.Tensor, torch.Tensor], - out: torch.Tensor, m_indices: torch.Tensor) -> None: @@ -285,7 +285,7 @@ index 3b518c9..ba776bd 100644 + this function will do a transposing with a set of slow Pypaddle operations. On the M axis, inputs are grouped into several batches, of which batch sizes aligned to `get_m_alignment_for_contiguous_layout()` (128). - + Arguments: - lhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[m_sum, k]`, + lhs: the first element is an FP8 tensor (typed `paddle.float8_e4m3fn`) of shape `[m_sum, k]`, @@ -301,7 +301,7 @@ index 3b518c9..ba776bd 100644 Values of `m_indices` in every-m-alignment-block must also be the same. 
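Reviewer note on the DeepGEMM port above: the patch is essentially a mechanical torch→paddle substitution in the JIT layer (the original shape/dtype asserts are commented out rather than ported). A reference sketch of the mapping it applies, collected from the hunks above and not part of the patch itself:

```python
# torch -> paddle equivalences used throughout this DeepGEMM patch.
# Pure reference data; nothing here is imported by the patched code.
TORCH_TO_PADDLE = {
    "torch.int":                       "paddle.int32",
    "torch.float":                     "paddle.float32",
    "torch.bfloat16":                  "paddle.bfloat16",
    "torch.float8_e4m3fn":             "paddle.float8_e4m3fn",
    "torch.cuda.Stream":               "paddle.device.cuda.Stream",
    "torch.cuda.current_stream()":     "paddle.device.cuda.current_stream()",
    "torch.cuda.synchronize()":        "paddle.device.synchronize()",
    "torch.cuda.Event":                "paddle.device.cuda.Event",
    "tensor.stride(i)":                "tensor.strides[i]",   # strides is a plain list in Paddle
    "torch.empty(..., device='cuda')": "paddle.empty(...)",   # Paddle allocates on the current device
}
```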
@@ -68,19 +68,19 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Tensor, torch.Ten m__ = m_indices.numel() - + # Type and shape checks - assert m == m_ == m__ and k == k_ and n == n_ - assert lhs_scales.shape == (m, (k + 127) // 128) @@ -321,12 +321,12 @@ index 3b518c9..ba776bd 100644 + # assert m_indices.dtype == paddle.int32 + # assert lhs.is_contiguous() and rhs.is_contiguous() + # assert out.is_contiguous() and m_indices.is_contiguous() - + # LHS scales must be transposed for TMA load, but not for RHS scales lhs_scales = get_col_major_tma_aligned_tensor(lhs_scales) - assert rhs_scales.is_contiguous() + # assert rhs_scales.is_contiguous() - + # Do nothing if `m` is zero if m == 0: @@ -92,7 +92,7 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Tensor, torch.Ten @@ -357,8 +357,8 @@ index 3b518c9..ba776bd 100644 ) @@ -118,22 +118,22 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Tensor, torch.Ten runtime(*args) - - + + -def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor], - rhs: Tuple[torch.Tensor, torch.Tensor], - out: torch.Tensor, masked_m: torch.Tensor, expected_m: int) -> None: @@ -374,7 +374,7 @@ index 3b518c9..ba776bd 100644 + this function will do a transposing with a set of slow paddle operations. Moreover, this alignment requirement is different with the contiguous-format kernel, as we require that each batch should be separately transposed. - + Arguments: - lhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[num_groups, m_max, k]`, + lhs: the first element is an FP8 tensor (typed `paddle.float8_e4m3fn`) of shape `[num_groups, m_max, k]`, @@ -386,7 +386,7 @@ index 3b518c9..ba776bd 100644 masked_m: a tensor of shape `[num_groups]`, `masked_m[i]` records actual rows of the `lhs[i]` matrix to compute @@ -149,21 +149,21 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor] num_groups___ = masked_m.numel() - + # Type and shape checks - assert num_groups == num_groups_ == num_groups__ == num_groups___ - assert m == m_ and n == n_ and k == k_ @@ -410,16 +410,16 @@ index 3b518c9..ba776bd 100644 + # assert masked_m.dtype == paddle.int32 + # assert lhs.is_contiguous() and rhs.is_contiguous() + # assert out.is_contiguous() and masked_m.is_contiguous() - + # LHS scales must be transposed for TMA load, but not for RHS scales lhs_scales = get_col_major_tma_aligned_tensor(lhs_scales) - assert rhs_scales.is_contiguous() + # assert rhs_scales.is_contiguous() - + # Auto-tuning with compilation global includes, template @@ -176,7 +176,7 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor] - + args = (lhs, lhs_scales, rhs, rhs_scales, out, masked_m, m, - torch.cuda.current_stream(), num_sms, smem_config[0]) @@ -454,11 +454,11 @@ index 6ed6749..9e1d70f 100644 -import torch +import paddle from typing import Any, Dict - + from ..jit import build, cpp_format, generate, Runtime @@ -51,10 +51,10 @@ class JITTuner: continue - + # Measure performance with L2 flush and a large GEMM kernel before to reduce overhead between kernels - start_event = torch.cuda.Event(enable_timing=True) - end_event = torch.cuda.Event(enable_timing=True) @@ -478,9 +478,9 @@ index c6da56b..a17b1b1 100644 @@ -1,4 +1,4 @@ -import torch +import paddle - + _num_sms = None - + @@ -11,7 +11,7 @@ def set_num_sms(num_sms: int) -> None: num_sms: the desired maximum SM count for all GEMM kernels to use. 
""" @@ -488,8 +488,8 @@ index c6da56b..a17b1b1 100644 - assert 0 < num_sms <= torch.cuda.get_device_properties(device='cuda').multi_processor_count + assert 0 < num_sms <= paddle.device.cuda.get_device_properties().multi_processor_count _num_sms = num_sms - - + + @@ -25,7 +25,7 @@ def get_num_sms() -> int: """ global _num_sms @@ -497,12 +497,12 @@ index c6da56b..a17b1b1 100644 - _num_sms = torch.cuda.get_device_properties(device='cuda').multi_processor_count + _num_sms = paddle.device.cuda.get_device_properties().multi_processor_count return _num_sms - - + + @@ -74,9 +74,9 @@ def get_tma_aligned_size(x: int, element_size: int) -> int: return ceil_div(x, alignment) * alignment - - + + -def get_col_major_tma_aligned_tensor(x: torch.Tensor) -> torch.Tensor: +def get_col_major_tma_aligned_tensor(x: paddle.Tensor) -> paddle.Tensor: """ @@ -510,7 +510,7 @@ index c6da56b..a17b1b1 100644 + Returns TMA-aligned transposed format of the input tensor. `paddle.transpose` will be called if necessary. If the input tensor is already column-major layout and 16-byte aligned along the M axis (thus meets the requirement of LHS scaling tensor in DeepGEMM), this function will do nothing. - + @@ -92,18 +92,20 @@ def get_col_major_tma_aligned_tensor(x: torch.Tensor) -> torch.Tensor: m, n = x.shape[-2], x.shape[-1] aligned_m = get_tma_aligned_size(m, x.element_size()) @@ -519,14 +519,14 @@ index c6da56b..a17b1b1 100644 + if x.strides[0] == 1 and x.strides[1] == aligned_m: return x x, remove_dim = x.unsqueeze(0), True - + b = x.shape[0] - + # The last kernel gives a column-major TMA aligned layout - if x.stride(0) == aligned_m * n and x.stride(1) == 1 and x.stride(2) == aligned_m: + if x.strides[0] == aligned_m * n and x.strides[1] == 1 and x.strides[2] == aligned_m: return x.squeeze(0) if remove_dim else x - + # Normal layout requires transposing - aligned_x = torch.transpose(torch.empty((b, n, aligned_m), device=x.device, dtype=x.dtype), 1, 2) + aligned_x = paddle.transpose( @@ -574,20 +574,20 @@ index d5cdd01..5237f09 100644 -import torch.distributed as dist +import paddle +import paddle.distributed as dist - - + + def bench(fn, num_warmups: int = 5, num_tests: int = 10, high_precision: bool = False): # Flush L2 cache with 256 MB data - torch.cuda.synchronize() - cache = torch.empty(int(256e6 // 4), dtype=torch.int, device='cuda') -+ paddle.device.cuda.synchronize() ++ paddle.device.synchronize() + cache = paddle.empty((int(256e6 // 4)), dtype=paddle.int32) cache.zero_() - + # Warmup @@ -18,18 +18,18 @@ def bench(fn, num_warmups: int = 5, num_tests: int = 10, - + # Add a large kernel to eliminate the CPU launch overhead if high_precision: - x = torch.randn((8192, 8192), dtype=torch.float, device='cuda') @@ -595,7 +595,7 @@ index d5cdd01..5237f09 100644 + x = paddle.randn((8192, 8192), dtype=paddle.float32) + y = paddle.randn((8192, 8192), dtype=paddle.float32) x @ y - + # Testing - start_event = torch.cuda.Event(enable_timing=True) - end_event = torch.cuda.Event(enable_timing=True) @@ -607,9 +607,9 @@ index d5cdd01..5237f09 100644 end_event.record() - torch.cuda.synchronize() + paddle.device.synchronize() - + return start_event.elapsed_time(end_event) / num_tests - + @@ -106,21 +106,21 @@ def bench_kineto(fn, kernel_names, num_tests: int = 30, suppress_kineto_output: # Profile suppress = suppress_stdout_stderr if suppress_kineto_output and not using_nsys else empty_suppress @@ -636,8 +636,7 @@ index d5cdd01..5237f09 100644 - torch.empty(flush_l2_size, dtype=torch.int, device='cuda').zero_() + 
paddle.empty(flush_l2_size, dtype=paddle.int32).zero_() fn() - - if not using_nsys: --- -2.43.0 + if not using_nsys: +-- +2.43.0 diff --git a/fastdeploy/config.py b/fastdeploy/config.py index 715df04eb..446e59298 100644 --- a/fastdeploy/config.py +++ b/fastdeploy/config.py @@ -18,7 +18,7 @@ from __future__ import annotations from dataclasses import dataclass, field from enum import Enum -from typing import Optional, Literal +from typing import Literal, Optional from paddleformers.transformers.configuration_utils import PretrainedConfig @@ -69,7 +69,6 @@ class ModelConfig(PretrainedConfig): max_seq_len: int = 512, initializer_range: float = 0.02, use_rope=True, - use_fast_ffn: bool = False, rope_theta: int = 10000, rope_3d: bool = False, ori_vocab_size: int | None = None, @@ -104,7 +103,6 @@ class ModelConfig(PretrainedConfig): self.max_position_embeddings = max_position_embeddings self.initializer_range = initializer_range self.use_rope = use_rope - self.use_fast_ffn = use_fast_ffn self.rope_theta = rope_theta self.ori_vocab_size = ori_vocab_size or vocab_size self.max_seq_len = max_seq_len @@ -199,7 +197,7 @@ class ParallelConfig: eos_tokens_lens: int = 2 # Enable chunked prefill enable_chunked_prefill: str = "store_true" - # + max_num_batched_tokens: int = 2048 # enable prefix cache enable_prefix_caching = None @@ -349,7 +347,7 @@ class GraphOptimizationConfig: class LoadConfig: """ Configuration for dynamic weight loading strategies - + Attributes: dynamic_load_weight: Whether to enable dynamic weight loading load_strategy: Specifies the weight loading method when enabled: @@ -366,7 +364,7 @@ class LoadConfig: def __post_init__(self): if self.load_strategy is not None and not self.dynamic_load_weight: raise ValueError("Load strategy requires dynamic_load_weight=True") - + if self.dynamic_load_weight and self.load_strategy is None: raise ValueError("Must specify load_strategy when dynamic_load_weight is True") diff --git a/fastdeploy/engine/config.py b/fastdeploy/engine/config.py index a4efb0c61..c53889058 100644 --- a/fastdeploy/engine/config.py +++ b/fastdeploy/engine/config.py @@ -728,7 +728,7 @@ class Config: ), "XPU currently do not support guided_decoding" try: - import xgrammar + import xgrammar # noqa except Exception as e: raise Exception( f"import XGrammar failed, please install XGrammar use `pip install xgrammar==0.1.19`. \n\t {e}" diff --git a/fastdeploy/model_executor/layers/attention/__init__.py b/fastdeploy/model_executor/layers/attention/__init__.py index afbf916a5..6a1d0e1c1 100644 --- a/fastdeploy/model_executor/layers/attention/__init__.py +++ b/fastdeploy/model_executor/layers/attention/__init__.py @@ -12,16 +12,16 @@ # See the License for the specific language governing permissions and # limitations under the License. 
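Reviewer note on the `bench()` changes above: GPU timing keeps the same event-based structure, only the namespaces move from `torch.cuda` to `paddle.device`. A minimal self-contained sketch of that pattern (the 8192×8192 matmul and the 256 MB L2 flush mirror the patched helper and are illustrative values, not requirements):

```python
import paddle


def time_gpu(fn, num_warmups: int = 5, num_tests: int = 10) -> float:
    """Return the average runtime of `fn` in milliseconds."""
    # Flush the L2 cache with ~256 MB of data, as the patched bench() does.
    cache = paddle.empty((int(256e6 // 4),), dtype=paddle.int32)
    cache.zero_()

    for _ in range(num_warmups):
        fn()

    start_event = paddle.device.cuda.Event(enable_timing=True)
    end_event = paddle.device.cuda.Event(enable_timing=True)
    start_event.record()
    for _ in range(num_tests):
        fn()
    end_event.record()
    paddle.device.synchronize()  # wait for all queued kernels before reading the events
    return start_event.elapsed_time(end_event) / num_tests


x = paddle.randn((8192, 8192), dtype=paddle.float32)
y = paddle.randn((8192, 8192), dtype=paddle.float32)
print(time_gpu(lambda: x @ y))
```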
-from .attention import Attention from .append_attn_backend import AppendAttentionBackend from .attention_selecter import get_attention_backend from .base_attention_backend import AttentionBackend +from .flash_attn_backend import FlashAttentionBackend from .mla_attention_backend import MLAAttentionBackend from .native_paddle_backend import PaddleNativeAttnBackend from .xpu_attn_backend import XPUAttentionBackend __all__ = [ - "Attention", "AttentionBackend", "PaddleNativeAttnBackend", + "AttentionBackend", "PaddleNativeAttnBackend", "get_attention_backend", "AppendAttentionBackend", "XPUAttentionBackend", - "MLAAttentionBackend" + "MLAAttentionBackend", "FlashAttentionBackend" ] diff --git a/fastdeploy/model_executor/layers/attention/append_attn_backend.py b/fastdeploy/model_executor/layers/attention/append_attn_backend.py index eb82e0bf9..5bc7f420a 100644 --- a/fastdeploy/model_executor/layers/attention/append_attn_backend.py +++ b/fastdeploy/model_executor/layers/attention/append_attn_backend.py @@ -30,7 +30,7 @@ if TYPE_CHECKING: from paddle._typing.dtype_like import _DTypeLiteral from fastdeploy.config import FDConfig -from fastdeploy.model_executor.layers.attention import Attention +from fastdeploy.model_executor.layers.attention.attention import Attention from fastdeploy.model_executor.layers.attention.base_attention_backend import ( AttentionBackend, AttentionMetadata) from fastdeploy.worker.forward_meta import ForwardMeta diff --git a/fastdeploy/model_executor/layers/attention/flash_attn_backend.py b/fastdeploy/model_executor/layers/attention/flash_attn_backend.py new file mode 100644 index 000000000..74a234bd1 --- /dev/null +++ b/fastdeploy/model_executor/layers/attention/flash_attn_backend.py @@ -0,0 +1,243 @@ +""" +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" + +from __future__ import annotations + +import os +from dataclasses import dataclass, field +from typing import List, Optional + +import paddle +from paddle.nn.functional.flash_attention import flash_attention_v3_varlen + +from fastdeploy.config import FDConfig +from fastdeploy.model_executor.layers.attention.attention import Attention +from fastdeploy.model_executor.layers.attention.base_attention_backend import ( + AttentionBackend, AttentionMetadata) +from fastdeploy.model_executor.layers.attention.ops import ( + get_block_shape_and_split_kv_block, gqa_rope_write_cache, + init_signal_layerwise, open_shm_and_get_meta_signal, pre_cache_len_concat) +from fastdeploy.worker.forward_meta import ForwardMeta + + +@dataclass +class FlashAttentionMetadata(AttentionMetadata): + """ + FlashAttentionMetadata + """ + max_len_kv: paddle.Tensor = None + set_max_lengths: int = -1 + rotary_embs: Optional[paddle.Tensor] = None + block_tables: Optional[paddle.Tensor] = None + encoder_batch_ids: paddle.Tensor = None + encoder_tile_ids_per_batch: paddle.Tensor = None + encoder_num_blocks: paddle.Tensor = None + kv_batch_ids: paddle.Tensor = None + kv_tile_ids_per_batch: paddle.Tensor = None + kv_num_blocks: paddle.Tensor = None + decoder_batch_ids: paddle.Tensor = None + decoder_tile_ids_per_batch: paddle.Tensor = None + decoder_num_blocks: paddle.Tensor = None + + encoder_block_shape_q: Optional[paddle.Tensor] = None + decoder_block_shape_q: Optional[paddle.Tensor] = None + + cu_seqlens_q: paddle.Tensor = None + cu_seqlens_k: paddle.Tensor = None + max_seqlen_q: int = 0 + max_seqlen_k: int = 0 + + pre_cache_batch_ids = None + pre_cache_tile_ids_per_batch = None + pre_cache_num_blocks_cpu = None + kv_token_num_cpu = None + + # pd_disaggregation + kv_signal_metadata: Optional[paddle.Tensor] = None + kv_signal_data_list: List[paddle.Tensor] = field(default_factory=list) + + +class FlashAttentionBackend(AttentionBackend): + """ + FlashAttentionBackend backend implementation + """ + + def __init__(self, fd_config: FDConfig, kv_num_heads: int, num_heads: int, + head_dim: int): + """ + FlashAttentionBackend __init__ + """ + super().__init__() + self.attention_metadata: FlashAttentionMetadata = None + self.max_seq_len = fd_config.parallel_config.max_model_len + self.causal = getattr(fd_config.model_config, "causal", True) + + self.kv_num_heads = kv_num_heads + self.num_heads = num_heads + self.head_dim = fd_config.model_config.head_dim + self.hidden_size = fd_config.model_config.hidden_size + self.block_size = fd_config.parallel_config.block_size + self.num_layers: int = fd_config.model_config.num_layers + + self.speculative_method = fd_config.speculative_config.method + self.use_speculate = self.speculative_method is not None + self.speculate_max_draft_token_num = fd_config.speculative_config.num_speculative_tokens + self.keep_pd_step_flag: bool = fd_config.speculative_config.model_type == "mtp" + self.rank: int = fd_config.parallel_config.tensor_parallel_rank + + # pd_disaggregation + self.use_pd_disaggregation: int = int( + os.getenv("FLAGS_use_pd_disaggregation", 0)) + self.start_layer_index: int = fd_config.model_config.start_layer_index + self.device_id: int = os.getenv("CUDA_VISIBLE_DEVICES", None) + + if fd_config.parallel_config.expert_parallel_rank is None: + fd_config.parallel_config.expert_parallel_rank = 0 + device_id = self.rank + fd_config.parallel_config.tensor_parallel_degree * \ + fd_config.parallel_config.expert_parallel_rank + if self.device_id is None: + self.device_id = device_id + else: 
+ self.device_id = self.device_id.split(",")[device_id] + + def get_attntion_meta(self): + """get_attntion_meta""" + return self.attention_metadata + + def get_kv_cache_shape( + self, + max_num_blocks: int, + ): + """ + Caculate kv cache shape + """ + return (max_num_blocks, self.kv_num_heads, self.block_size, + self.head_dim) + + def init_attention_metadata(self, forward_meta: ForwardMeta): + metadata = FlashAttentionMetadata() + metadata.encoder_block_shape_q = 64 + metadata.decoder_block_shape_q = 16 + metadata.cu_seqlens_q = forward_meta.cu_seqlens_q + metadata.rotary_embs = forward_meta.rotary_embs + metadata.block_tables = forward_meta.block_tables + ( + metadata.encoder_batch_ids, + metadata.encoder_tile_ids_per_batch, + metadata.encoder_num_blocks, + metadata.kv_batch_ids, + metadata.kv_tile_ids_per_batch, + metadata.kv_num_blocks, + metadata.decoder_batch_ids, + metadata.decoder_tile_ids_per_batch, + metadata.decoder_num_blocks, + metadata.max_len_kv, + metadata.set_max_lengths, + ) = get_block_shape_and_split_kv_block( + forward_meta.seq_lens_encoder, + forward_meta.seq_lens_decoder, + forward_meta.seq_lens_this_time, + forward_meta.cum_offsets, + metadata.encoder_block_shape_q, + metadata.decoder_block_shape_q, + self.num_heads // self.kv_num_heads, + self.block_size, + self.speculate_max_draft_token_num + 1, + ) + + ( + metadata.cu_seqlens_k, + metadata.pre_cache_batch_ids, + metadata.pre_cache_tile_ids_per_batch, + metadata.pre_cache_num_blocks_cpu, + metadata.kv_token_num_cpu, + ) = pre_cache_len_concat( + forward_meta.seq_lens_decoder, + forward_meta.seq_lens_this_time, + metadata.set_max_lengths[2], + self.block_size, + ) + + # pd_disaggregation + metadata.kv_signal_data_list = [None] * self.num_layers + if self.use_pd_disaggregation: + metadata.kv_signal_metadata = open_shm_and_get_meta_signal( + self.rank, int(self.device_id), self.keep_pd_step_flag) + self.attention_metadata = metadata + forward_meta.decoder_batch_ids.copy_(metadata.decoder_batch_ids, False) + forward_meta.decoder_tile_ids_per_batch.copy_( + metadata.decoder_tile_ids_per_batch, False) + + def forward_mixed( + self, + q: paddle.Tensor, + k: paddle.Tensor, + v: paddle.Tensor, + qkv: paddle.Tensor, + compressed_kv: paddle.Tensor, + k_pe: paddle.Tensor, + layer: Attention, + forward_meta: ForwardMeta, + ): + metadata = self.attention_metadata + + if self.use_pd_disaggregation: + metadata.kv_signal_data_list[ + layer.layer_id] = init_signal_layerwise( + metadata.kv_signal_metadata, + layer.layer_id + self.start_layer_index) + + q, k, v, _ = gqa_rope_write_cache( + qkv, + forward_meta.caches[2 * layer.layer_id], + forward_meta.caches[2 * layer.layer_id + 1], + metadata.cu_seqlens_q, + metadata.cu_seqlens_k, + metadata.rotary_embs, + forward_meta.seq_lens_this_time, + forward_meta.seq_lens_encoder, + forward_meta.seq_lens_decoder, + forward_meta.padding_offset, + forward_meta.cum_offsets, + metadata.block_tables, + metadata.kv_batch_ids, + metadata.kv_tile_ids_per_batch, + metadata.kv_num_blocks, + metadata.pre_cache_batch_ids, + metadata.pre_cache_tile_ids_per_batch, + metadata.pre_cache_num_blocks_cpu, + getattr(layer, "cache_k_scale", None), + getattr(layer, "cache_v_scale", None), + getattr(layer, "cache_k_out_scale", None), + getattr(layer, "cache_v_out_scale", None), + getattr(layer, "cache_k_zp", None), + getattr(layer, "cache_v_zp", None), + metadata.kv_signal_data_list[layer.layer_id], + metadata.kv_token_num_cpu[0], + self.max_seq_len, + getattr(layer, "cache_quant_type_str", "none"), + ) + res = 
flash_attention_v3_varlen( + q, + k, + v, + metadata.cu_seqlens_q, + metadata.cu_seqlens_k, + max_seqlen_q=metadata.set_max_lengths[0], + max_seqlen_k=metadata.set_max_lengths[3], + causal=self.causal, + )[0].reshape([-1, self.hidden_size]) + return res diff --git a/fastdeploy/model_executor/layers/attention/mla_attention_backend.py b/fastdeploy/model_executor/layers/attention/mla_attention_backend.py index 52489bd5f..1d9c9773b 100644 --- a/fastdeploy/model_executor/layers/attention/mla_attention_backend.py +++ b/fastdeploy/model_executor/layers/attention/mla_attention_backend.py @@ -38,7 +38,7 @@ if TYPE_CHECKING: from paddle._typing.dtype_like import _DTypeLiteral from fastdeploy.config import FDConfig -from fastdeploy.model_executor.layers.attention import Attention +from fastdeploy.model_executor.layers.attention.attention import Attention from fastdeploy.model_executor.layers.attention.base_attention_backend import ( AttentionBackend, AttentionMetadata) from fastdeploy.worker.forward_meta import ForwardMeta diff --git a/fastdeploy/model_executor/layers/attention/ops/__init__.py b/fastdeploy/model_executor/layers/attention/ops/__init__.py index 8b75ce6f0..95cc06129 100644 --- a/fastdeploy/model_executor/layers/attention/ops/__init__.py +++ b/fastdeploy/model_executor/layers/attention/ops/__init__.py @@ -17,10 +17,16 @@ from .append_attention import append_attention from .get_block_shape_and_split_kv_block import \ get_block_shape_and_split_kv_block +from .gqa_rope_write_cache import gqa_rope_write_cache from .init_signal_layerwise import init_signal_layerwise from .open_shm_and_get_meta_signal import open_shm_and_get_meta_signal +from .pre_cache_len_concat import pre_cache_len_concat __all__ = [ - "get_block_shape_and_split_kv_block", "append_attention", - "open_shm_and_get_meta_signal", "init_signal_layerwise" + "get_block_shape_and_split_kv_block", + "append_attention", + "open_shm_and_get_meta_signal", + "init_signal_layerwise", + "gqa_rope_write_cache", + "pre_cache_len_concat", ] diff --git a/fastdeploy/model_executor/layers/attention/ops/gqa_rope_write_cache.py b/fastdeploy/model_executor/layers/attention/ops/gqa_rope_write_cache.py new file mode 100644 index 000000000..c012d932a --- /dev/null +++ b/fastdeploy/model_executor/layers/attention/ops/gqa_rope_write_cache.py @@ -0,0 +1,66 @@ +""" +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
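Reviewer note on `forward_mixed` above: once `gqa_rope_write_cache` has packed the query/key/value into varlen layout and written the KV cache, the attention itself is a single `flash_attention_v3_varlen` call. A hedged stand-alone sketch of that call follows; only the call signature comes from the backend code above, while the packed `[total_tokens, num_heads, head_dim]` layout, the bfloat16 dtype, and the int32 `cu_seqlens` are assumptions about the kernel's expectations:

```python
import paddle
from paddle.nn.functional.flash_attention import flash_attention_v3_varlen

num_heads, head_dim = 8, 128
seq_lens = [32, 17, 64]                      # three sequences packed together
total_tokens = sum(seq_lens)

# Cumulative sequence offsets: one entry more than the number of sequences.
cu_seqlens = paddle.to_tensor([0, 32, 49, 113], dtype=paddle.int32)

q = paddle.randn((total_tokens, num_heads, head_dim), dtype=paddle.bfloat16)
k = paddle.randn((total_tokens, num_heads, head_dim), dtype=paddle.bfloat16)
v = paddle.randn((total_tokens, num_heads, head_dim), dtype=paddle.bfloat16)

out = flash_attention_v3_varlen(
    q, k, v,
    cu_seqlens, cu_seqlens,                  # self-attention: same offsets for q and k
    max_seqlen_q=max(seq_lens),
    max_seqlen_k=max(seq_lens),
    causal=True,
)[0]                                         # first element of the returned tuple is the output
print(out.shape)                             # expected: [total_tokens, num_heads, head_dim]
```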
+""" + +from typing import Optional + +import paddle + +from fastdeploy.platforms import current_platform + + +def gqa_rope_write_cache( + qkv: paddle.Tensor, + key_cache: paddle.Tensor, + value_cache: paddle.Tensor, + cu_seqlens_q: paddle.Tensor, + cu_seqlens_k: paddle.Tensor, + rotary_embs: paddle.Tensor, + seq_lens_this_time: paddle.Tensor, + seq_lens_encoder: paddle.Tensor, + seq_lens_decoder: paddle.Tensor, + padding_offsets: paddle.Tensor, + cum_offsets: paddle.Tensor, + block_tables: paddle.Tensor, + kv_batch_ids: paddle.Tensor, + kv_tile_ids_per_batch: paddle.Tensor, + kv_num_blocks: paddle.Tensor, + cache_batch_ids: paddle.Tensor, + cache_tile_ids_per_batch: paddle.Tensor, + cache_num_blocks: paddle.Tensor, + cache_k_quant_scales: Optional[paddle.Tensor] = None, + cache_v_quant_scales: Optional[paddle.Tensor] = None, + cache_k_dequant_scales: Optional[paddle.Tensor] = None, + cache_v_dequant_scales: Optional[paddle.Tensor] = None, + cache_k_zp: Optional[paddle.Tensor] = None, + cache_v_zp: Optional[paddle.Tensor] = None, + kv_signal_data: Optional[paddle.Tensor] = None, + kv_token_num: int = 1, + max_seq_len: int = 0, + cache_quant_type: str = "none"): + if current_platform.is_cuda(): + from fastdeploy.model_executor.ops.gpu import gqa_rope_write_cache + q, k, v, qkv_ = gqa_rope_write_cache( + qkv, key_cache, value_cache, cu_seqlens_q, cu_seqlens_k, + rotary_embs, seq_lens_this_time, seq_lens_encoder, + seq_lens_decoder, padding_offsets, cum_offsets, block_tables, + kv_batch_ids, kv_tile_ids_per_batch, kv_num_blocks, + cache_batch_ids, cache_tile_ids_per_batch, cache_num_blocks, + cache_k_quant_scales, cache_v_quant_scales, cache_k_dequant_scales, + cache_v_dequant_scales, cache_k_zp, cache_v_zp, kv_signal_data, + kv_token_num, max_seq_len, cache_quant_type) + return q, k, v, qkv_ + else: + raise NotImplementedError() diff --git a/fastdeploy/model_executor/layers/attention/ops/pre_cache_len_concat.py b/fastdeploy/model_executor/layers/attention/ops/pre_cache_len_concat.py new file mode 100644 index 000000000..f0f0780a3 --- /dev/null +++ b/fastdeploy/model_executor/layers/attention/ops/pre_cache_len_concat.py @@ -0,0 +1,36 @@ +""" +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, + Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, + software +# distributed under the License is distributed on an "AS IS" BASIS, + +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, + either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" + +import paddle + +from fastdeploy.platforms import current_platform + + +def pre_cache_len_concat(seq_lens_decoder: paddle.Tensor, + seq_lens_this_time: paddle.Tensor, + max_dec_len: int = 0, + block_size: int = 64): + if current_platform.is_cuda(): + from fastdeploy.model_executor.ops.gpu import pre_cache_len_concat + out = pre_cache_len_concat(seq_lens_decoder, seq_lens_this_time, + max_dec_len, block_size) + return out + else: + raise NotImplementedError() diff --git a/fastdeploy/model_executor/layers/attention/xpu_attn_backend.py b/fastdeploy/model_executor/layers/attention/xpu_attn_backend.py index 2ea49f299..9ecc01fb8 100644 --- a/fastdeploy/model_executor/layers/attention/xpu_attn_backend.py +++ b/fastdeploy/model_executor/layers/attention/xpu_attn_backend.py @@ -29,7 +29,7 @@ if TYPE_CHECKING: from paddle._typing.dtype_like import _DTypeLiteral from fastdeploy.config import FDConfig -from fastdeploy.model_executor.layers.attention import Attention +from fastdeploy.model_executor.layers.attention.attention import Attention from fastdeploy.model_executor.layers.attention.base_attention_backend import ( AttentionBackend, AttentionMetadata) from fastdeploy.worker.forward_meta import ForwardMeta diff --git a/fastdeploy/model_executor/layers/linear.py b/fastdeploy/model_executor/layers/linear.py index 054ffe7f8..b8dc49e1b 100644 --- a/fastdeploy/model_executor/layers/linear.py +++ b/fastdeploy/model_executor/layers/linear.py @@ -329,7 +329,6 @@ class MergedColumnParallelLinear(ColumnParallelLinear): with_bias: bool = False, add_bias: bool = False, activation: str = "gelu", - use_fast_ffn: bool = False, skip_quant: bool = False, ): """ @@ -344,11 +343,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear): with_bias (bool): Whether to include bias or not. Defaults to False. add_bias (bool): Whether to add bias in the current layer or in the pre/post layer. Defaults to False. activation (str): Activation function to use. Defaults to "gelu". - use_fast_ffn (bool): Whether to use a faster FFN implementation. - Defaults to False. skip_quant (bool): Whether to skip quantization. Defaults to False. 
""" - self.use_fast_ffn = use_fast_ffn self.activation = activation self.hidden_size = fd_config.model_config.hidden_size self.nranks = fd_config.parallel_config.tensor_parallel_degree @@ -385,23 +381,10 @@ class MergedColumnParallelLinear(ColumnParallelLinear): "gate_proj") bias_tensor = get_tensor(state_dict.pop(gate_bias_key)).astype( paddle.get_default_dtype()) - converted_bias_tensor = paddle.zeros(shape=list( - bias_tensor.shape), - dtype=bias_tensor.dtype) - if not self.use_fast_ffn: - converted_bias_tensor = paddle.concat( - [bias_tensor[::2], bias_tensor[1::2]], axis=0) - else: - converted_bias_tensor = bias_tensor - state_dict[self.bias_key] = converted_bias_tensor - if not self.use_fast_ffn: - converted_weight_tensor = paddle.concat( - [weight_tensor[:, ::2], weight_tensor[:, 1::2]], axis=1) - else: - converted_weight_tensor = weight_tensor + state_dict[self.bias_key] = bias_tensor - state_dict[self.weight_key] = converted_weight_tensor + state_dict[self.weight_key] = weight_tensor super().load_state_dict(state_dict) diff --git a/fastdeploy/model_executor/layers/quantization/mix_quant.py b/fastdeploy/model_executor/layers/quantization/mix_quant.py index 7fbb3d88d..4868b346b 100644 --- a/fastdeploy/model_executor/layers/quantization/mix_quant.py +++ b/fastdeploy/model_executor/layers/quantization/mix_quant.py @@ -15,8 +15,9 @@ """ from typing import Optional -from ..attention import Attention -from ..moe import FusedMoE +from fastdeploy.model_executor.layers.attention.attention import Attention +from fastdeploy.model_executor.layers.moe.moe import FusedMoE + from . import get_quantization_config from .quant_base import QuantConfigBase, QuantMethodBase diff --git a/fastdeploy/model_executor/load_weight_utils.py b/fastdeploy/model_executor/load_weight_utils.py index a5c84a365..c8ba1f673 100644 --- a/fastdeploy/model_executor/load_weight_utils.py +++ b/fastdeploy/model_executor/load_weight_utils.py @@ -214,7 +214,7 @@ def load_tp_checkpoint_v1( need_tp = True if tensor_parallel_filtered_map else False state_dict = {} for key, weight in weights_iterator: - paddle.device.cuda.synchronize() + paddle.device.synchronize() if need_tp and key in tensor_parallel_filtered_map: action = tensor_parallel_filtered_map.pop(key) tensor = action(weight).clone() diff --git a/fastdeploy/model_executor/model_loader.py b/fastdeploy/model_executor/model_loader.py index 2010c2021..03ea7fcc6 100644 --- a/fastdeploy/model_executor/model_loader.py +++ b/fastdeploy/model_executor/model_loader.py @@ -86,7 +86,7 @@ class DefaultModelLoader(BaseModelLoader): if isinstance(v, paddle.Tensor): v.value().get_tensor()._clear() paddle.device.cuda.empty_cache() - paddle.device.cuda.synchronize() + paddle.device.synchronize() def load_model(self, fd_config: FDConfig) -> nn.Layer: context = paddle.LazyGuard() diff --git a/fastdeploy/model_executor/models/deepseek_v3.py b/fastdeploy/model_executor/models/deepseek_v3.py index 8286e77d8..73997c2ac 100644 --- a/fastdeploy/model_executor/models/deepseek_v3.py +++ b/fastdeploy/model_executor/models/deepseek_v3.py @@ -28,7 +28,7 @@ from fastdeploy.config import FDConfig from fastdeploy.distributed.communication_op import \ tensor_model_parallel_all_reduce from fastdeploy.model_executor.layers.activation import SiluAndMul -from fastdeploy.model_executor.layers.attention import Attention +from fastdeploy.model_executor.layers.attention.attention import Attention from fastdeploy.model_executor.layers.embeddings import VocabParallelEmbedding from 
fastdeploy.model_executor.layers.linear import ( ColumnParallelLinear, KVBatchLinear, MergedColumnParallelLinear, @@ -68,7 +68,6 @@ class DeepSeekV3MLP(nn.Layer): output_size=intermediate_size * 2, with_bias=False, activation=fd_config.model_config.hidden_act, - use_fast_ffn=True, ) self.down_proj = RowParallelLinear( diff --git a/fastdeploy/model_executor/models/ernie4_5_moe.py b/fastdeploy/model_executor/models/ernie4_5_moe.py index 3a74f4114..f6b73622a 100644 --- a/fastdeploy/model_executor/models/ernie4_5_moe.py +++ b/fastdeploy/model_executor/models/ernie4_5_moe.py @@ -29,7 +29,7 @@ from fastdeploy.config import FDConfig, ModelConfig from fastdeploy.model_executor.graph_optimization.decorator import \ support_graph_optimization from fastdeploy.model_executor.layers.activation import SiluAndMul -from fastdeploy.model_executor.layers.attention import Attention +from fastdeploy.model_executor.layers.attention.attention import Attention from fastdeploy.model_executor.layers.embeddings import VocabParallelEmbedding from fastdeploy.model_executor.layers.linear import ( MergedColumnParallelLinear, QKVParallelLinear, RowParallelLinear) @@ -62,7 +62,6 @@ class Ernie4_5_MLP(nn.Layer): output_size=intermediate_size * 2, with_bias=False, activation=fd_config.model_config.hidden_act, - use_fast_ffn=True, ) self.down_proj = RowParallelLinear( diff --git a/fastdeploy/model_executor/models/qwen2.py b/fastdeploy/model_executor/models/qwen2.py index a8e6955db..0a5912afb 100644 --- a/fastdeploy/model_executor/models/qwen2.py +++ b/fastdeploy/model_executor/models/qwen2.py @@ -27,7 +27,7 @@ from fastdeploy.config import FDConfig, ModelConfig from fastdeploy.model_executor.graph_optimization.decorator import \ support_graph_optimization from fastdeploy.model_executor.layers.activation import SiluAndMul -from fastdeploy.model_executor.layers.attention import Attention +from fastdeploy.model_executor.layers.attention.attention import Attention from fastdeploy.model_executor.layers.embeddings import VocabParallelEmbedding from fastdeploy.model_executor.layers.linear import ( MergedColumnParallelLinear, QKVParallelLinear, RowParallelLinear) @@ -55,7 +55,6 @@ class Qwen2MLP(nn.Layer): output_size=fd_config.model_config.ffn_hidden_size * 2, with_bias=False, activation=fd_config.model_config.hidden_act, - use_fast_ffn=True, ) self.down_proj = RowParallelLinear( diff --git a/fastdeploy/model_executor/models/qwen3.py b/fastdeploy/model_executor/models/qwen3.py index 0c5ecc96f..c1654f414 100644 --- a/fastdeploy/model_executor/models/qwen3.py +++ b/fastdeploy/model_executor/models/qwen3.py @@ -26,7 +26,7 @@ from paddleformers.utils.log import logger from fastdeploy.config import FDConfig, ModelConfig from fastdeploy.model_executor.graph_optimization.decorator import \ support_graph_optimization -from fastdeploy.model_executor.layers.attention import Attention +from fastdeploy.model_executor.layers.attention.attention import Attention from fastdeploy.model_executor.layers.embeddings import VocabParallelEmbedding from fastdeploy.model_executor.layers.linear import (QKVParallelLinear, RowParallelLinear) diff --git a/fastdeploy/model_executor/models/qwen3moe.py b/fastdeploy/model_executor/models/qwen3moe.py index f73063516..c4d01ef6e 100644 --- a/fastdeploy/model_executor/models/qwen3moe.py +++ b/fastdeploy/model_executor/models/qwen3moe.py @@ -27,7 +27,7 @@ from fastdeploy.config import FDConfig, ModelConfig from fastdeploy.model_executor.graph_optimization.decorator import \ support_graph_optimization from 
fastdeploy.model_executor.layers.activation import SiluAndMul -from fastdeploy.model_executor.layers.attention import Attention +from fastdeploy.model_executor.layers.attention.attention import Attention from fastdeploy.model_executor.layers.embeddings import VocabParallelEmbedding from fastdeploy.model_executor.layers.linear import ( MergedColumnParallelLinear, QKVParallelLinear, RowParallelLinear) @@ -57,7 +57,6 @@ class Qwen3MLP(nn.Layer): output_size=fd_config.model_config.ffn_hidden_size * 2, with_bias=False, activation=fd_config.model_config.hidden_act, - use_fast_ffn=True, ) self.down_proj = RowParallelLinear( diff --git a/fastdeploy/platforms/base.py b/fastdeploy/platforms/base.py index 769410c29..9b0c86a99 100644 --- a/fastdeploy/platforms/base.py +++ b/fastdeploy/platforms/base.py @@ -24,6 +24,7 @@ class _Backend(enum.Enum): NATIVE_ATTN = enum.auto() APPEND_ATTN = enum.auto() MLA_ATTN = enum.auto() + FLASH_ATTN = enum.auto() class Platform: diff --git a/fastdeploy/platforms/cuda.py b/fastdeploy/platforms/cuda.py index f5b3082b5..294506b4b 100644 --- a/fastdeploy/platforms/cuda.py +++ b/fastdeploy/platforms/cuda.py @@ -13,9 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. """ -""" -cuda platform file -""" import paddle @@ -65,6 +62,11 @@ class CUDAPlatform(Platform): return ( "fastdeploy.model_executor.layers.attention.MLAAttentionBackend" ) + elif selected_backend == _Backend.FLASH_ATTN: + logger.info("Using FLASH ATTN backend.") + return ( + "fastdeploy.model_executor.layers.attention.FlashAttentionBackend" + ) else: raise ValueError( "Invalid attention backend you specified.\n" diff --git a/fastdeploy/spec_decode/mtp.py b/fastdeploy/spec_decode/mtp.py index ec962f574..264656cbf 100644 --- a/fastdeploy/spec_decode/mtp.py +++ b/fastdeploy/spec_decode/mtp.py @@ -90,7 +90,8 @@ class MTPProposer(Proposer): self.model = get_model_from_loader(self.cfg) - def dummy_prefill_inputs(self, num_tokens: int, batch_size: int, expected_decode_len: int): + def dummy_prefill_inputs(self, num_tokens: int, batch_size: int, + expected_decode_len: int): """Set dummy prefill inputs to model_inputs""" max_dec_len = expected_decode_len + 1 self.num_gpu_blocks = self.parallel_config.max_block_num @@ -130,10 +131,10 @@ class MTPProposer(Proposer): self.cache_kvs = {} cache_type = self.parallel_config.dtype - - if (self.quant_config and - hasattr(self.quant_config, "kv_cache_quant_type") and - self.quant_config.kv_cache_quant_type is not None): + + if (self.quant_config + and hasattr(self.quant_config, "kv_cache_quant_type") + and self.quant_config.kv_cache_quant_type is not None): cache_type = 'uint8' # Get kv cache shape @@ -190,8 +191,7 @@ class MTPProposer(Proposer): head_dim = self.model_config.head_dim # Get the attention backend - attn_cls = get_attention_backend( - self.parallel_config.attention_backend) + attn_cls = get_attention_backend() attn_backend = attn_cls( self.cfg, kv_num_heads=self.model_config.kv_num_heads, @@ -200,8 +200,8 @@ class MTPProposer(Proposer): ) if attn_backend is None: raise NotImplementedError( - f"{ self.parallel_config.attention_backend} attention backend" - " is not support by GPUModelRunner") + "Attention backend which you specified is not supported, please set FD_ATTENTION_BACKEND correctly." 
+ ) self.attn_backends.append(attn_backend) def clear_dummy_input(self): diff --git a/fastdeploy/worker/gpu_model_runner.py b/fastdeploy/worker/gpu_model_runner.py index b211aa9a8..c13f232d3 100644 --- a/fastdeploy/worker/gpu_model_runner.py +++ b/fastdeploy/worker/gpu_model_runner.py @@ -593,7 +593,8 @@ class GPUModelRunner(ModelRunnerBase): self.model = get_model_from_loader(fd_config=self.fd_config) # 1.1 Load RL dynamic model if self.fd_config.load_config.dynamic_load_weight: - from fastdeploy.rl.dynamic_weight_manager import DynamicWeightManager + from fastdeploy.rl.dynamic_weight_manager import \ + DynamicWeightManager self.dynamic_weight_manager = DynamicWeightManager(self.fd_config, self.model) # 2. Load lora model @@ -622,7 +623,7 @@ class GPUModelRunner(ModelRunnerBase): # Initialzie attention meta data for attn_backend in self.attn_backends: attn_backend.init_attention_metadata(self.forward_meta) - + def clear_cache(self): """Clear cached data from shared inputs and forward metadata.""" self.share_inputs.pop("caches", None) @@ -719,7 +720,7 @@ class GPUModelRunner(ModelRunnerBase): head_dim=head_dim) if attn_backend is None: raise NotImplementedError( - "Attention backend which you chose is not support by GPUModelRunner" + "Attention backend which you specified is not supported, please set FD_ATTENTION_BACKEND correctly." ) self.attn_backends.append(attn_backend) @@ -1150,7 +1151,6 @@ class GPUModelRunner(ModelRunnerBase): if self.speculative_method in ["mtp"]: self.proposer.clear_dummy_input() - # paddle.device.cuda.synchronize() def update_share_input_block_num(self, num_gpu_blocks: int) -> None: """ diff --git a/fastdeploy/worker/xpu_model_runner.py b/fastdeploy/worker/xpu_model_runner.py index 8df9357b5..b075356f9 100644 --- a/fastdeploy/worker/xpu_model_runner.py +++ b/fastdeploy/worker/xpu_model_runner.py @@ -590,7 +590,7 @@ class XPUModelRunner(ModelRunnerBase): head_dim=head_dim) if attn_backend is None: raise NotImplementedError( - "Attention backend which you chose is not support by GPUModelRunner" + "Attention backend which you specified is not supported, please set FD_ATTENTION_BACKEND correctly." ) self.attn_backends.append(attn_backend)
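Overall note on the backend plumbing: with `_Backend.FLASH_ATTN` registered in `platforms/base.py` and `CUDAPlatform` returning the `FlashAttentionBackend` class path, a backend is obtained through `get_attention_backend()` and the `FD_ATTENTION_BACKEND` setting referenced by the new error messages. A hedged usage sketch, assuming the selector reads `FD_ATTENTION_BACKEND` from the environment and that `model_config` exposes the attribute names shown (the `num_attention_heads` name in particular is an assumption made for illustration):

```python
import os

from fastdeploy.config import FDConfig
from fastdeploy.model_executor.layers.attention import get_attention_backend

# Must be set before the selector runs; the value mirrors _Backend.FLASH_ATTN.
os.environ["FD_ATTENTION_BACKEND"] = "FLASH_ATTN"


def build_flash_attn_backend(fd_config: FDConfig):
    """Instantiate the newly added backend the same way the runners above do."""
    attn_cls = get_attention_backend()  # resolves to FlashAttentionBackend on CUDA
    return attn_cls(
        fd_config,
        kv_num_heads=fd_config.model_config.kv_num_heads,
        num_heads=fd_config.model_config.num_attention_heads,  # assumed attribute name
        head_dim=fd_config.model_config.head_dim,
    )
# The model runner then calls backend.init_attention_metadata(forward_meta)
# before each forward pass, as gpu_model_runner.py shows above.
```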